/// Text Generation Inference webserver entrypoint
use axum::http::HeaderValue;
use clap::Parser;
use opentelemetry::sdk::propagation::TraceContextPropagator;
use opentelemetry::sdk::trace;
use opentelemetry::sdk::trace::Sampler;
use opentelemetry::sdk::Resource;
use opentelemetry::{global, KeyValue};
use opentelemetry_otlp::WithExportConfig;
use std::net::{IpAddr, Ipv4Addr, SocketAddr};
use std::path::Path;
use std::time::Duration;
use text_generation_client::ShardedClient;
use text_generation_router::{server, HubModelInfo};
use tokenizers::{FromPretrainedParameters, Tokenizer};
use tower_http::cors::AllowOrigin;
use tracing_subscriber::layer::SubscriberExt;
use tracing_subscriber::util::SubscriberInitExt;
use tracing_subscriber::{EnvFilter, Layer};

/// App Configuration
#[derive(Parser, Debug)]
#[clap(author, version, about, long_about = None)]
struct Args {
Olivier Dehaene's avatar
v0.1.0  
Olivier Dehaene committed
25
26
    #[clap(default_value = "128", long, env)]
    max_concurrent_requests: usize,
27
28
    #[clap(default_value = "2", long, env)]
    max_best_of: usize,
29
30
    #[clap(default_value = "4", long, env)]
    max_stop_sequences: usize,
Olivier Dehaene's avatar
v0.1.0  
Olivier Dehaene committed
31
32
    #[clap(default_value = "1000", long, env)]
    max_input_length: usize,
33
34
    #[clap(default_value = "1512", long, env)]
    max_total_tokens: usize,
35
36
    #[clap(default_value = "1.2", long, env)]
    waiting_served_ratio: f32,
37
38
    #[clap(default_value = "4096", long, env)]
    max_batch_prefill_tokens: u32,
39
40
    #[clap(default_value = "32000", long, env)]
    max_batch_total_tokens: u32,
41
42
    #[clap(default_value = "20", long, env)]
    max_waiting_tokens: usize,
Olivier Dehaene's avatar
Olivier Dehaene committed
43
44
    #[clap(default_value = "3000", long, short, env)]
    port: u16,
45
    #[clap(default_value = "/tmp/text-generation-server-0", long, env)]
Olivier Dehaene's avatar
v0.1.0  
Olivier Dehaene committed
46
    master_shard_uds_path: String,
Olivier Dehaene's avatar
Olivier Dehaene committed
47
48
    #[clap(default_value = "bigscience/bloom", long, env)]
    tokenizer_name: String,
49
50
    #[clap(default_value = "main", long, env)]
    revision: String,
Olivier Dehaene's avatar
v0.1.0  
Olivier Dehaene committed
51
52
    #[clap(default_value = "2", long, env)]
    validation_workers: usize,
53
54
    #[clap(long, env)]
    json_output: bool,
55
56
    #[clap(long, env)]
    otlp_endpoint: Option<String>,
57
58
    #[clap(long, env)]
    cors_allow_origin: Option<Vec<String>>,
59
60
61
62
63
64
65
66
67
68
    #[clap(long, env)]
    ngrok: bool,
    #[clap(long, env)]
    ngrok_authtoken: Option<String>,
    #[clap(long, env)]
    ngrok_domain: Option<String>,
    #[clap(long, env)]
    ngrok_username: Option<String>,
    #[clap(long, env)]
    ngrok_password: Option<String>,
Olivier Dehaene's avatar
Olivier Dehaene committed
69
}
Olivier Dehaene's avatar
Init  
Olivier Dehaene committed
70

Olivier Dehaene's avatar
Olivier Dehaene committed
71
fn main() -> Result<(), std::io::Error> {
Olivier Dehaene's avatar
Olivier Dehaene committed
72
73
    // Get args
    let args = Args::parse();
Olivier Dehaene's avatar
v0.1.0  
Olivier Dehaene committed
74
    // Pattern match configuration
Olivier Dehaene's avatar
Olivier Dehaene committed
75
    let Args {
Olivier Dehaene's avatar
v0.1.0  
Olivier Dehaene committed
76
        max_concurrent_requests,
77
        max_best_of,
78
        max_stop_sequences,
Olivier Dehaene's avatar
v0.1.0  
Olivier Dehaene committed
79
        max_input_length,
80
        max_total_tokens,
81
        waiting_served_ratio,
82
83
        max_batch_prefill_tokens,
        max_batch_total_tokens,
84
        max_waiting_tokens,
Olivier Dehaene's avatar
Olivier Dehaene committed
85
        port,
Olivier Dehaene's avatar
v0.1.0  
Olivier Dehaene committed
86
        master_shard_uds_path,
Olivier Dehaene's avatar
Olivier Dehaene committed
87
        tokenizer_name,
88
        revision,
Olivier Dehaene's avatar
v0.1.0  
Olivier Dehaene committed
89
        validation_workers,
90
        json_output,
91
        otlp_endpoint,
92
        cors_allow_origin,
93
94
95
96
97
        ngrok,
        ngrok_authtoken,
        ngrok_domain,
        ngrok_username,
        ngrok_password,
Olivier Dehaene's avatar
Olivier Dehaene committed
98
99
    } = args;

100
    if validation_workers == 0 {
Olivier Dehaene's avatar
v0.1.0  
Olivier Dehaene committed
101
102
103
        panic!("validation_workers must be > 0");
    }

104
105
106
107
108
109
110
111
112
113
114
    // CORS allowed origins
    // map to go inside the option and then map to parse from String to HeaderValue
    // Finally, convert to AllowOrigin
    let cors_allow_origin: Option<AllowOrigin> = cors_allow_origin.map(|cors_allow_origin| {
        AllowOrigin::list(
            cors_allow_origin
                .iter()
                .map(|origin| origin.parse::<HeaderValue>().unwrap()),
        )
    });

115
116
117
    // Parse Huggingface hub token
    let authorization_token = std::env::var("HUGGING_FACE_HUB_TOKEN").ok();

118
    // Tokenizer instance
Olivier Dehaene's avatar
v0.1.0  
Olivier Dehaene committed
119
    // This will only be used to validate payloads
120
    let local_path = Path::new(&tokenizer_name);
121
122
123
124
125
126
127
128
129
    let local_model = local_path.exists() && local_path.is_dir();
    let tokenizer = if local_model {
        // Load local tokenizer
        Tokenizer::from_file(local_path.join("tokenizer.json")).ok()
    } else {
        // Download and instantiate tokenizer
        // We need to download it outside of the Tokio runtime
        let params = FromPretrainedParameters {
            revision: revision.clone(),
130
            auth_token: authorization_token.clone(),
131
            ..Default::default()
132
        };
133
134
        Tokenizer::from_pretrained(tokenizer_name.clone(), Some(params)).ok()
    };
Olivier Dehaene's avatar
Olivier Dehaene committed
135

Olivier Dehaene's avatar
v0.1.0  
Olivier Dehaene committed
136
    // Launch Tokio runtime
Olivier Dehaene's avatar
Olivier Dehaene committed
137
138
139
140
141
    tokio::runtime::Builder::new_multi_thread()
        .enable_all()
        .build()
        .unwrap()
        .block_on(async {
OlivierDehaene's avatar
OlivierDehaene committed
142
143
            init_logging(otlp_endpoint, json_output);

144
145
146
147
148
149
150
            if tokenizer.is_none() {
                tracing::warn!(
                    "Could not find a fast tokenizer implementation for {tokenizer_name}"
                );
                tracing::warn!("Rust input length validation and truncation is disabled");
            }

151
152
            // Get Model info
            let model_info = match local_model {
153
                true => HubModelInfo {
154
155
156
157
                    model_id: tokenizer_name.clone(),
                    sha: None,
                    pipeline_tag: None,
                },
158
159
160
161
162
163
164
165
166
167
                false => get_model_info(&tokenizer_name, &revision, authorization_token)
                    .await
                    .unwrap_or_else(|| {
                        tracing::warn!("Could not retrieve model info from the Hugging Face hub.");
                        HubModelInfo {
                            model_id: tokenizer_name.to_string(),
                            sha: None,
                            pipeline_tag: None,
                        }
                    }),
168
            };
169
170

            // if pipeline-tag == text-generation we default to return_full_text = true
171
            let compat_return_full_text = match &model_info.pipeline_tag {
172
173
174
175
                None => {
                    tracing::warn!("no pipeline tag found for model {tokenizer_name}");
                    false
                }
176
                Some(pipeline_tag) => pipeline_tag.as_str() == "text-generation",
177
178
            };

Olivier Dehaene's avatar
v0.1.0  
Olivier Dehaene committed
179
            // Instantiate sharded client from the master unix socket
180
            let mut sharded_client = ShardedClient::connect_uds(master_shard_uds_path)
Olivier Dehaene's avatar
Olivier Dehaene committed
181
182
                .await
                .expect("Could not connect to server");
Olivier Dehaene's avatar
v0.1.0  
Olivier Dehaene committed
183
            // Clear the cache; useful if the webserver rebooted
Olivier Dehaene's avatar
Olivier Dehaene committed
184
            sharded_client
185
                .clear_cache(None)
Olivier Dehaene's avatar
Olivier Dehaene committed
186
187
                .await
                .expect("Unable to clear cache");
188
189
190
191
192
            // Get info from the shard
            let shard_info = sharded_client
                .info()
                .await
                .expect("Unable to get shard info");
193
194
195
196
197
198
199
200
201
202
203

            // Warmup model
            tracing::info!("Warming up model");
            sharded_client
                .warmup(
                    max_input_length as u32,
                    max_batch_prefill_tokens,
                    max_batch_total_tokens,
                )
                .await
                .expect("Unable to warmup model");
Olivier Dehaene's avatar
Olivier Dehaene committed
204
            tracing::info!("Connected");
Olivier Dehaene's avatar
Init  
Olivier Dehaene committed
205

Olivier Dehaene's avatar
v0.1.0  
Olivier Dehaene committed
206
            // Binds on localhost
Olivier Dehaene's avatar
Olivier Dehaene committed
207
            let addr = SocketAddr::new(IpAddr::V4(Ipv4Addr::new(0, 0, 0, 0)), port);
Olivier Dehaene's avatar
Init  
Olivier Dehaene committed
208

Olivier Dehaene's avatar
v0.1.0  
Olivier Dehaene committed
209
210
            // Run server
            server::run(
211
                model_info,
212
                shard_info,
213
                compat_return_full_text,
Olivier Dehaene's avatar
v0.1.0  
Olivier Dehaene committed
214
                max_concurrent_requests,
215
                max_best_of,
216
                max_stop_sequences,
Olivier Dehaene's avatar
v0.1.0  
Olivier Dehaene committed
217
                max_input_length,
218
                max_total_tokens,
219
                waiting_served_ratio,
220
                max_batch_prefill_tokens,
221
                max_batch_total_tokens,
222
                max_waiting_tokens,
Olivier Dehaene's avatar
v0.1.0  
Olivier Dehaene committed
223
224
225
226
                sharded_client,
                tokenizer,
                validation_workers,
                addr,
227
                cors_allow_origin,
228
229
230
231
232
                ngrok,
                ngrok_authtoken,
                ngrok_domain,
                ngrok_username,
                ngrok_password,
Olivier Dehaene's avatar
v0.1.0  
Olivier Dehaene committed
233
            )
234
            .await;
Olivier Dehaene's avatar
Olivier Dehaene committed
235
            Ok(())
Olivier Dehaene's avatar
Olivier Dehaene committed
236
        })
Olivier Dehaene's avatar
Init  
Olivier Dehaene committed
237
}
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292

/// Initialise the tracing/logging stack.
///
///  - `otlp_endpoint`: optional URL of an OpenTelemetry collector; when set,
///    spans are exported over OTLP in addition to console output
///  - `json_output`: emit JSON-formatted events instead of plain text
///  - env var `LOG_LEVEL` selects the filter (TRACE/DEBUG/INFO/WARN/ERROR,
///    defaults to INFO)
fn init_logging(otlp_endpoint: Option<String>, json_output: bool) {
    let mut subscriber_layers = Vec::new();

    // Console (STDOUT/STDERR) layer, optionally JSON formatted
    let console_layer = {
        let base = tracing_subscriber::fmt::layer()
            .with_file(true)
            .with_line_number(true);
        if json_output {
            base.json().flatten_event(true).boxed()
        } else {
            base.boxed()
        }
    };
    subscriber_layers.push(console_layer);

    // Optional OpenTelemetry export layer
    if let Some(endpoint) = otlp_endpoint {
        global::set_text_map_propagator(TraceContextPropagator::new());

        let pipeline = opentelemetry_otlp::new_pipeline()
            .tracing()
            .with_exporter(
                opentelemetry_otlp::new_exporter()
                    .tonic()
                    .with_endpoint(endpoint),
            )
            .with_trace_config(
                trace::config()
                    .with_resource(Resource::new(vec![KeyValue::new(
                        "service.name",
                        "text-generation-inference.router",
                    )]))
                    .with_sampler(Sampler::AlwaysOn),
            )
            .install_batch(opentelemetry::runtime::Tokio);

        if let Ok(tracer) = pipeline {
            subscriber_layers.push(tracing_opentelemetry::layer().with_tracer(tracer).boxed());
            axum_tracing_opentelemetry::init_propagator().unwrap();
        }
    }

    // Filter events with LOG_LEVEL, defaulting to "info" when unset/invalid
    let filter = EnvFilter::try_from_env("LOG_LEVEL").unwrap_or_else(|_| EnvFilter::new("info"));

    tracing_subscriber::registry()
        .with(filter)
        .with(subscriber_layers)
        .init();
}
293
294

/// get model info from the Huggingface Hub
295
296
297
298
299
pub async fn get_model_info(
    model_id: &str,
    revision: &str,
    token: Option<String>,
) -> Option<HubModelInfo> {
300
    let client = reqwest::Client::new();
301
    // Poor man's urlencode
302
    let revision = revision.replace('/', "%2F");
303
    let url = format!("https://huggingface.co/api/models/{model_id}/revision/{revision}");
304
    let mut builder = client.get(url).timeout(Duration::from_secs(5));
305
306
307
308
    if let Some(token) = token {
        builder = builder.bearer_auth(token);
    }

309
310
311
312
313
314
    let response = builder.send().await.ok()?;

    if response.status().is_success() {
        return serde_json::from_str(&response.text().await.ok()?).ok();
    }
    None
315
}