/// Text Generation Inference webserver entrypoint
use axum::http::HeaderValue;
use clap::Parser;
use opentelemetry::sdk::propagation::TraceContextPropagator;
use opentelemetry::sdk::trace;
use opentelemetry::sdk::trace::Sampler;
use opentelemetry::sdk::Resource;
use opentelemetry::{global, KeyValue};
use opentelemetry_otlp::WithExportConfig;
use std::net::{IpAddr, Ipv4Addr, SocketAddr};
use std::path::Path;
use std::time::Duration;
use text_generation_client::{ClientError, ShardedClient};
use text_generation_router::{server, HubModelInfo};
use thiserror::Error;
use tokenizers::{FromPretrainedParameters, Tokenizer};
use tower_http::cors::AllowOrigin;
use tracing_subscriber::layer::SubscriberExt;
use tracing_subscriber::util::SubscriberInitExt;
use tracing_subscriber::{EnvFilter, Layer};

/// App Configuration
#[derive(Parser, Debug)]
#[clap(author, version, about, long_about = None)]
struct Args {
    #[clap(default_value = "128", long, env)]
    max_concurrent_requests: usize,
    #[clap(default_value = "2", long, env)]
    max_best_of: usize,
    #[clap(default_value = "4", long, env)]
    max_stop_sequences: usize,
    #[clap(default_value = "5", long, env)]
    max_top_n_tokens: u32,
    #[clap(default_value = "1024", long, env)]
    max_input_length: usize,
    #[clap(default_value = "2048", long, env)]
    max_total_tokens: usize,
    #[clap(default_value = "1.2", long, env)]
    waiting_served_ratio: f32,
    #[clap(default_value = "4096", long, env)]
    max_batch_prefill_tokens: u32,
    #[clap(long, env)]
    max_batch_total_tokens: Option<u32>,
    #[clap(default_value = "20", long, env)]
    max_waiting_tokens: usize,
    #[clap(default_value = "0.0.0.0", long, env)]
    hostname: String,
    #[clap(default_value = "3000", long, short, env)]
    port: u16,
    #[clap(default_value = "/tmp/text-generation-server-0", long, env)]
    master_shard_uds_path: String,
    #[clap(default_value = "bigscience/bloom", long, env)]
    tokenizer_name: String,
    #[clap(long, env)]
    revision: Option<String>,
    #[clap(default_value = "2", long, env)]
    validation_workers: usize,
    #[clap(long, env)]
    json_output: bool,
    #[clap(long, env)]
    otlp_endpoint: Option<String>,
    #[clap(long, env)]
    cors_allow_origin: Option<Vec<String>>,
    #[clap(long, env)]
    ngrok: bool,
    #[clap(long, env)]
    ngrok_authtoken: Option<String>,
    #[clap(long, env)]
    ngrok_edge: Option<String>,
}
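
// Example invocation (a sketch: the binary name and values shown are illustrative,
// and every flag can also be supplied through its SCREAMING_SNAKE_CASE env var
// thanks to the `env` attribute on each field):
//
//   text-generation-router \
//       --tokenizer-name bigscience/bloom \
//       --port 3000 \
//       --max-input-length 1024 \
//       --max-total-tokens 2048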

fn main() -> Result<(), RouterError> {
    // Get args
    let args = Args::parse();
    // Pattern match configuration
    let Args {
        max_concurrent_requests,
        max_best_of,
        max_stop_sequences,
        max_top_n_tokens,
        max_input_length,
        max_total_tokens,
        waiting_served_ratio,
        max_batch_prefill_tokens,
        max_batch_total_tokens,
        max_waiting_tokens,
        hostname,
        port,
        master_shard_uds_path,
        tokenizer_name,
        revision,
        validation_workers,
        json_output,
        otlp_endpoint,
        cors_allow_origin,
        ngrok,
        ngrok_authtoken,
        ngrok_edge,
    } = args;

    // Validate args
    if max_input_length >= max_total_tokens {
        return Err(RouterError::ArgumentValidation(
            "`max_input_length` must be < `max_total_tokens`".to_string(),
        ));
    }
    if max_input_length as u32 > max_batch_prefill_tokens {
        return Err(RouterError::ArgumentValidation(format!("`max_batch_prefill_tokens` must be >= `max_input_length`. Given: {max_batch_prefill_tokens} and {max_input_length}")));
    }

    if validation_workers == 0 {
        return Err(RouterError::ArgumentValidation(
            "`validation_workers` must be > 0".to_string(),
        ));
    }

    if let Some(ref max_batch_total_tokens) = max_batch_total_tokens {
        if max_batch_prefill_tokens > *max_batch_total_tokens {
            return Err(RouterError::ArgumentValidation(format!("`max_batch_prefill_tokens` must be <= `max_batch_total_tokens`. Given: {max_batch_prefill_tokens} and {max_batch_total_tokens}")));
        }
        if max_total_tokens as u32 > *max_batch_total_tokens {
            return Err(RouterError::ArgumentValidation(format!("`max_total_tokens` must be <= `max_batch_total_tokens`. Given: {max_total_tokens} and {max_batch_total_tokens}")));
        }
    }
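
    // Net effect of the validations above (a summary, no additional checks):
    //   max_input_length < max_total_tokens
    //   max_input_length <= max_batch_prefill_tokens
    //   validation_workers > 0
    //   and, when `--max-batch-total-tokens` is given:
    //     max_batch_prefill_tokens <= max_batch_total_tokens
    //     max_total_tokens <= max_batch_total_tokens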

    // CORS allowed origins
    // Map inside the Option, parse each String into a HeaderValue,
    // and finally convert the list into an AllowOrigin
    let cors_allow_origin: Option<AllowOrigin> = cors_allow_origin.map(|cors_allow_origin| {
        AllowOrigin::list(
            cors_allow_origin
                .iter()
                .map(|origin| origin.parse::<HeaderValue>().unwrap()),
        )
    });
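
    // e.g. `--cors-allow-origin https://huggingface.co --cors-allow-origin http://localhost:3000`
    // (illustrative origins) yields an AllowOrigin::list over those two header values.
    // Note that an origin which fails to parse panics at startup via the `unwrap` above.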

    // Read the Hugging Face Hub token from the environment
    let authorization_token = std::env::var("HUGGING_FACE_HUB_TOKEN").ok();

    // Tokenizer instance
    // This will only be used to validate payloads
    let local_path = Path::new(&tokenizer_name);
    let local_model = local_path.exists() && local_path.is_dir();
    let tokenizer = if local_model {
        // Load local tokenizer
        Tokenizer::from_file(local_path.join("tokenizer.json")).ok()
    } else {
        // Download and instantiate tokenizer
        // We need to download it outside of the Tokio runtime
        let params = FromPretrainedParameters {
            revision: revision.clone().unwrap_or("main".to_string()),
            auth_token: authorization_token.clone(),
            ..Default::default()
        };
        Tokenizer::from_pretrained(tokenizer_name.clone(), Some(params)).ok()
    };
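
    // Note the `.ok()` on both branches: a missing or unloadable tokenizer becomes
    // `None` rather than a hard error, and request validation is degraded (see the
    // warning logged once the runtime is up).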

    // Launch Tokio runtime
    tokio::runtime::Builder::new_multi_thread()
        .enable_all()
        .build()?
        .block_on(async {
            init_logging(otlp_endpoint, json_output);

            if tokenizer.is_none() {
                tracing::warn!(
                    "Could not find a fast tokenizer implementation for {tokenizer_name}"
                );
                tracing::warn!("Rust input length validation and truncation is disabled");
            }

            // Get Model info
            let model_info = match local_model {
                true => HubModelInfo {
                    model_id: tokenizer_name.clone(),
                    sha: None,
                    pipeline_tag: None,
                },
                false => get_model_info(&tokenizer_name, revision, authorization_token)
                    .await
                    .unwrap_or_else(|| {
                        tracing::warn!("Could not retrieve model info from the Hugging Face Hub.");
                        HubModelInfo {
                            model_id: tokenizer_name.to_string(),
                            sha: None,
                            pipeline_tag: None,
                        }
                    }),
            };

            // if pipeline-tag == text-generation we default to return_full_text = true
            let compat_return_full_text = match &model_info.pipeline_tag {
                None => {
                    tracing::warn!("no pipeline tag found for model {tokenizer_name}");
                    false
                }
                Some(pipeline_tag) => pipeline_tag.as_str() == "text-generation",
            };

            // Instantiate sharded client from the master unix socket
            let mut sharded_client = ShardedClient::connect_uds(master_shard_uds_path)
                .await
                .map_err(RouterError::Connection)?;
            // Clear the cache; useful if the webserver rebooted
            sharded_client
                .clear_cache(None)
                .await
                .map_err(RouterError::Cache)?;
            // Get info from the shard
            let shard_info = sharded_client.info().await.map_err(RouterError::Info)?;
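
            // The router reaches the shards over a unix domain socket rather than TCP
            // (default path /tmp/text-generation-server-0, see `--master-shard-uds-path`).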

            // Warmup model
            tracing::info!("Warming up model");
            let max_supported_batch_total_tokens = match sharded_client
                .warmup(
                    max_input_length as u32,
                    max_batch_prefill_tokens,
                    max_total_tokens as u32,
                )
                .await
                .map_err(RouterError::Warmup)?
            {
                // Older models do not support automatic max-batch-total-tokens
                None => {
                    let max_batch_total_tokens = max_batch_total_tokens.unwrap_or(
                        16000.max((max_total_tokens as u32).max(max_batch_prefill_tokens)),
                    );
                    tracing::warn!("Model does not support automatic max batch total tokens");
                    max_batch_total_tokens
                }
                // Flash attention models return their max supported total tokens
                Some(max_supported_batch_total_tokens) => {
                    // Warn if the user set their own max-batch-total-tokens as we will ignore it
                    if max_batch_total_tokens.is_some() {
                        tracing::warn!(
                            "`--max-batch-total-tokens` is deprecated for Flash \
                        Attention models."
                        );
                        tracing::warn!(
                            "Inferred max batch total tokens: {max_supported_batch_total_tokens}"
                        );
                    }
                    if max_total_tokens as u32 > max_supported_batch_total_tokens {
                        return Err(RouterError::ArgumentValidation(format!("`max_total_tokens` must be <= `max_batch_total_tokens`. Given: {max_total_tokens} and {max_supported_batch_total_tokens}")));
                    }

                    max_supported_batch_total_tokens
                }
            };
            tracing::info!("Setting max batch total tokens to {max_supported_batch_total_tokens}");
            tracing::info!("Connected");

            let addr = match hostname.parse() {
                Ok(ip) => SocketAddr::new(ip, port),
                Err(_) => {
                    tracing::warn!("Invalid hostname, defaulting to 0.0.0.0");
                    SocketAddr::new(IpAddr::V4(Ipv4Addr::new(0, 0, 0, 0)), port)
                }
            };
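            // The default "0.0.0.0" parses successfully as an IpAddr, so this fallback
            // only fires for non-IP hostnames such as "localhost" or a DNS name.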

            // Run server
            server::run(
                model_info,
                shard_info,
                compat_return_full_text,
                max_concurrent_requests,
                max_best_of,
                max_stop_sequences,
                max_top_n_tokens,
                max_input_length,
                max_total_tokens,
                waiting_served_ratio,
                max_batch_prefill_tokens,
                max_supported_batch_total_tokens,
                max_waiting_tokens,
                sharded_client,
                tokenizer,
                validation_workers,
                addr,
                cors_allow_origin,
                ngrok,
                ngrok_authtoken,
                ngrok_edge,
            )
            .await?;
            Ok(())
        })
}

/// Init logging using the LOG_LEVEL env variable and the `json_output` flag:
///     - otlp_endpoint is an optional URL to an OpenTelemetry collector
///     - LOG_LEVEL may be TRACE, DEBUG, INFO, WARN or ERROR (defaults to INFO)
///     - json_output switches the stdout layer from human-readable text to flattened JSON
fn init_logging(otlp_endpoint: Option<String>, json_output: bool) {
    let mut layers = Vec::new();

    // STDOUT/STDERR layer
    let fmt_layer = tracing_subscriber::fmt::layer()
        .with_file(true)
        .with_line_number(true);

    let fmt_layer = match json_output {
        true => fmt_layer.json().flatten_event(true).boxed(),
        false => fmt_layer.boxed(),
    };
    layers.push(fmt_layer);

    // OpenTelemetry tracing layer
    if let Some(otlp_endpoint) = otlp_endpoint {
        global::set_text_map_propagator(TraceContextPropagator::new());

        let tracer = opentelemetry_otlp::new_pipeline()
            .tracing()
            .with_exporter(
                opentelemetry_otlp::new_exporter()
                    .tonic()
                    .with_endpoint(otlp_endpoint),
            )
            .with_trace_config(
                trace::config()
                    .with_resource(Resource::new(vec![KeyValue::new(
                        "service.name",
                        "text-generation-inference.router",
                    )]))
                    .with_sampler(Sampler::AlwaysOn),
            )
            .install_batch(opentelemetry::runtime::Tokio);

        if let Ok(tracer) = tracer {
            layers.push(tracing_opentelemetry::layer().with_tracer(tracer).boxed());
            init_tracing_opentelemetry::init_propagator().unwrap();
        };
    }

    // Filter events with LOG_LEVEL
    let env_filter =
        EnvFilter::try_from_env("LOG_LEVEL").unwrap_or_else(|_| EnvFilter::new("info"));

    tracing_subscriber::registry()
        .with(env_filter)
        .with(layers)
        .init();
}
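
// LOG_LEVEL accepts full `tracing_subscriber::EnvFilter` directives, not only a bare
// level, e.g. LOG_LEVEL="info,text_generation_router=debug" (illustrative values).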

/// Get model info from the Hugging Face Hub
pub async fn get_model_info(
    model_id: &str,
    revision: Option<String>,
    token: Option<String>,
) -> Option<HubModelInfo> {
    let revision = match revision {
        None => {
            tracing::warn!("`--revision` is not set");
            tracing::warn!("We strongly advise to set it to a known supported commit.");
            "main".to_string()
        }
        Some(revision) => revision,
    };

    let client = reqwest::Client::new();
    // Poor man's urlencode
    let revision = revision.replace('/', "%2F");
    let url = format!("https://huggingface.co/api/models/{model_id}/revision/{revision}");
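    // e.g. https://huggingface.co/api/models/bigscience/bloom/revision/main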
    let mut builder = client.get(url).timeout(Duration::from_secs(5));
    if let Some(token) = token {
        builder = builder.bearer_auth(token);
    }

    let response = builder.send().await.ok()?;

    if response.status().is_success() {
        let hub_model_info: HubModelInfo =
            serde_json::from_str(&response.text().await.ok()?).ok()?;
        if let Some(sha) = &hub_model_info.sha {
            tracing::info!(
                "Serving revision {sha} of model {}",
                hub_model_info.model_id
            );
        }
        Some(hub_model_info)
    } else {
        None
    }
}
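
// An illustrative (not verbatim) success payload, trimmed to the fields `HubModelInfo`
// uses; the real Hub response carries many more keys, and the struct's serde attributes
// decide the exact JSON names it deserializes:
//   { "id": "bigscience/bloom", "sha": "abc123...", "pipeline_tag": "text-generation" }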

#[derive(Debug, Error)]
enum RouterError {
    #[error("Argument validation error: {0}")]
    ArgumentValidation(String),
    #[error("Unable to connect to the Python model shards: {0}")]
    Connection(ClientError),
    #[error("Unable to clear the Python model shards cache: {0}")]
    Cache(ClientError),
    #[error("Unable to get the Python model shards info: {0}")]
    Info(ClientError),
    #[error("Unable to warmup the Python model shards: {0}")]
    Warmup(ClientError),
    #[error("Tokio runtime failed to start: {0}")]
    Tokio(#[from] std::io::Error),
    #[error("Axum webserver failed: {0}")]
    Axum(#[from] axum::BoxError),
}