use axum::http::HeaderValue;
use clap::Parser;
use hf_hub::api::tokio::{Api, ApiBuilder, ApiRepo};
use hf_hub::{Cache, Repo, RepoType};
use opentelemetry::sdk::propagation::TraceContextPropagator;
use opentelemetry::sdk::trace;
use opentelemetry::sdk::trace::Sampler;
use opentelemetry::sdk::Resource;
use opentelemetry::{global, KeyValue};
use opentelemetry_otlp::WithExportConfig;
use std::fs::File;
use std::io::BufReader;
use std::net::{IpAddr, Ipv4Addr, SocketAddr};
use std::path::{Path, PathBuf};
use text_generation_router::config::Config;
use text_generation_router::{
    server, HubModelInfo, HubPreprocessorConfig, HubProcessorConfig, HubTokenizerConfig,
};
use thiserror::Error;
use tokenizers::{processors::template::TemplateProcessing, Tokenizer};
use tower_http::cors::AllowOrigin;
use tracing_subscriber::layer::SubscriberExt;
use tracing_subscriber::util::SubscriberInitExt;
use tracing_subscriber::{filter::LevelFilter, EnvFilter, Layer};

/// App Configuration
#[derive(Parser, Debug)]
#[clap(author, version, about, long_about = None)]
struct Args {
    #[clap(default_value = "128", long, env)]
    max_concurrent_requests: usize,
    #[clap(default_value = "2", long, env)]
    max_best_of: usize,
    #[clap(default_value = "4", long, env)]
    max_stop_sequences: usize,
    #[clap(default_value = "5", long, env)]
    max_top_n_tokens: u32,
    #[clap(default_value = "1024", long, env)]
    max_input_tokens: usize,
    #[clap(default_value = "2048", long, env)]
    max_total_tokens: usize,
    #[clap(default_value = "1.2", long, env)]
    waiting_served_ratio: f32,
    #[clap(default_value = "4096", long, env)]
    max_batch_prefill_tokens: u32,
    #[clap(long, env)]
    max_batch_total_tokens: Option<u32>,
    #[clap(default_value = "20", long, env)]
    max_waiting_tokens: usize,
    #[clap(long, env)]
    max_batch_size: Option<usize>,
    #[clap(default_value = "0.0.0.0", long, env)]
    hostname: String,
    #[clap(default_value = "3000", long, short, env)]
    port: u16,
    #[clap(default_value = "/tmp/text-generation-server-0", long, env)]
    master_shard_uds_path: String,
    #[clap(default_value = "bigscience/bloom", long, env)]
    tokenizer_name: String,
    #[clap(long, env)]
    tokenizer_config_path: Option<String>,
    #[clap(long, env)]
    revision: Option<String>,
    #[clap(default_value = "2", long, env)]
    validation_workers: usize,
    #[clap(long, env)]
    json_output: bool,
    #[clap(long, env)]
    otlp_endpoint: Option<String>,
    #[clap(default_value = "text-generation-inference.router", long, env)]
    otlp_service_name: String,
    #[clap(long, env)]
    cors_allow_origin: Option<Vec<String>>,
    #[clap(long, env)]
    ngrok: bool,
    #[clap(long, env)]
    ngrok_authtoken: Option<String>,
    #[clap(long, env)]
    ngrok_edge: Option<String>,
    #[clap(long, env, default_value_t = false)]
    messages_api_enabled: bool,
    #[clap(long, env, default_value_t = false)]
    disable_grammar_support: bool,
    #[clap(default_value = "4", long, env)]
    max_client_batch_size: usize,
}
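
// Example invocation (a minimal sketch; the binary name and flag values are
// hypothetical, and every flag can also be supplied through the matching
// environment variable because each clap attribute above sets `env`):
//
//     text-generation-router \
//         --tokenizer-name bigscience/bloom \
//         --max-input-tokens 1024 \
//         --max-total-tokens 2048 \
//         --port 3000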

#[tokio::main]
async fn main() -> Result<(), RouterError> {
    // Get args
    let args = Args::parse();
    // Pattern match configuration
    let Args {
        max_concurrent_requests,
        max_best_of,
        max_stop_sequences,
        max_top_n_tokens,
        max_input_tokens,
        max_total_tokens,
        waiting_served_ratio,
        max_batch_prefill_tokens,
        max_batch_total_tokens,
        max_waiting_tokens,
        max_batch_size,
        hostname,
        port,
        master_shard_uds_path,
        tokenizer_name,
        tokenizer_config_path,
        revision,
        validation_workers,
        json_output,
        otlp_endpoint,
        otlp_service_name,
        cors_allow_origin,
        ngrok,
        ngrok_authtoken,
        ngrok_edge,
        messages_api_enabled,
        disable_grammar_support,
        max_client_batch_size,
    } = args;

    // Initialize logging (stdout/stderr fmt layer + optional OTLP tracing)
    init_logging(otlp_endpoint, otlp_service_name, json_output);

    // Validate args
    if max_input_tokens >= max_total_tokens {
        return Err(RouterError::ArgumentValidation(
            "`max_input_tokens` must be < `max_total_tokens`".to_string(),
        ));
    }
    if max_input_tokens as u32 > max_batch_prefill_tokens {
        return Err(RouterError::ArgumentValidation(format!("`max_batch_prefill_tokens` must be >= `max_input_tokens`. Given: {max_batch_prefill_tokens} and {max_input_tokens}")));
    }

    if validation_workers == 0 {
        return Err(RouterError::ArgumentValidation(
            "`validation_workers` must be > 0".to_string(),
        ));
    }

    if let Some(ref max_batch_total_tokens) = max_batch_total_tokens {
        if max_batch_prefill_tokens > *max_batch_total_tokens {
            return Err(RouterError::ArgumentValidation(format!("`max_batch_prefill_tokens` must be <= `max_batch_total_tokens`. Given: {max_batch_prefill_tokens} and {max_batch_total_tokens}")));
        }
        if max_total_tokens as u32 > *max_batch_total_tokens {
            return Err(RouterError::ArgumentValidation(format!("`max_total_tokens` must be <= `max_batch_total_tokens`. Given: {max_total_tokens} and {max_batch_total_tokens}")));
        }
    }
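
    // Taken together, the checks above enforce (clap defaults shown as
    // illustrative numbers):
    //   max_input_tokens (1024)  <  max_total_tokens (2048)
    //   max_input_tokens (1024)  <= max_batch_prefill_tokens (4096)
    //   max_batch_prefill_tokens <= max_batch_total_tokens (when set)
    //   max_total_tokens         <= max_batch_total_tokens (when set)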

    // CORS allowed origins
    // Map each origin String to a HeaderValue, then collect them into an AllowOrigin list
    let cors_allow_origin: Option<AllowOrigin> = cors_allow_origin.map(|cors_allow_origin| {
        AllowOrigin::list(
            cors_allow_origin
                .iter()
                .map(|origin| origin.parse::<HeaderValue>().unwrap()),
        )
    });
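
    // e.g. `--cors-allow-origin https://ui.example.com` (a hypothetical origin)
    // becomes one entry of the `AllowOrigin::list`; a value that does not parse
    // as a `HeaderValue` panics on the `unwrap` above.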

    // Parse Hugging Face Hub token
    let authorization_token = std::env::var("HF_TOKEN")
        .or_else(|_| std::env::var("HUGGING_FACE_HUB_TOKEN"))
        .ok();

    // Tokenizer instance
    // This will only be used to validate payloads
    let local_path = Path::new(&tokenizer_name);

    // Shared API builder initialization
    let api_builder = || {
        let mut builder = ApiBuilder::new()
            .with_progress(false)
            .with_token(authorization_token);

        if let Ok(cache_dir) = std::env::var("HUGGINGFACE_HUB_CACHE") {
            builder = builder.with_cache_dir(cache_dir.into());
        }

        builder
    };

    // Decide if we need to use the API based on the revision and local path
    let use_api = revision.is_some() || !local_path.exists() || !local_path.is_dir();

    // Initialize API if needed
    #[derive(Clone)]
    enum Type {
        Api(Api),
        Cache(Cache),
        None,
    }
    let api = if use_api {
        if std::env::var("HF_HUB_OFFLINE") == Ok("1".to_string()) {
            let cache = Cache::default();
            tracing::warn!("Offline mode active, using cache defaults");
            Type::Cache(cache)
        } else {
            tracing::info!("Using the Hugging Face API");
            match api_builder().build() {
                Ok(api) => Type::Api(api),
                Err(_) => {
                    tracing::warn!("Unable to build the Hugging Face API");
                    Type::None
                }
            }
        }
    } else {
        Type::None
    };

    // Load tokenizer and model info
    let (
        tokenizer_filename,
        config_filename,
        tokenizer_config_filename,
        preprocessor_config_filename,
        processor_config_filename,
        model_info,
    ) = match api {
        Type::None => (
            Some(local_path.join("tokenizer.json")),
            Some(local_path.join("config.json")),
            Some(local_path.join("tokenizer_config.json")),
            Some(local_path.join("preprocessor_config.json")),
            Some(local_path.join("processor_config.json")),
            None,
        ),
        Type::Api(api) => {
            let api_repo = api.repo(Repo::with_revision(
                tokenizer_name.to_string(),
                RepoType::Model,
                revision.clone().unwrap_or_else(|| "main".to_string()),
            ));

            // Fall back to the base model tokenizer if this repo does not ship one
            let tokenizer_filename = match api_repo.get("tokenizer.json").await {
                Ok(tokenizer_filename) => Some(tokenizer_filename),
                Err(_) => get_base_tokenizer(&api, &api_repo).await,
            };
            let config_filename = api_repo.get("config.json").await.ok();
            let tokenizer_config_filename = api_repo.get("tokenizer_config.json").await.ok();
            let preprocessor_config_filename = api_repo.get("preprocessor_config.json").await.ok();
            let processor_config_filename = api_repo.get("processor_config.json").await.ok();

            let model_info = if let Some(model_info) = get_model_info(&api_repo).await {
                Some(model_info)
            } else {
                tracing::warn!("Could not retrieve model info from the Hugging Face hub.");
                None
            };
            (
                tokenizer_filename,
                config_filename,
                tokenizer_config_filename,
                preprocessor_config_filename,
                processor_config_filename,
                model_info,
            )
        }
        Type::Cache(cache) => {
            let repo = cache.repo(Repo::with_revision(
                tokenizer_name.to_string(),
                RepoType::Model,
                revision.clone().unwrap_or_else(|| "main".to_string()),
            ));
            (
                repo.get("tokenizer.json"),
                repo.get("config.json"),
                repo.get("tokenizer_config.json"),
                repo.get("preprocessor_config.json"),
                repo.get("processor_config.json"),
                None,
            )
        }
    };
    let config: Option<Config> = config_filename.and_then(|filename| {
        std::fs::read_to_string(filename)
            .ok()
            .as_ref()
            .and_then(|c| {
                let config: Result<Config, _> = serde_json::from_str(c);
                if let Err(err) = &config {
                    tracing::warn!("Could not parse config {err:?}");
                }
                config.ok()
            })
    });
    let model_info = model_info.unwrap_or_else(|| HubModelInfo {
        model_id: tokenizer_name.to_string(),
        sha: None,
        pipeline_tag: None,
    });

    // Read the JSON contents of the file as an instance of `HubTokenizerConfig`.
    let tokenizer_config: Option<HubTokenizerConfig> = if let Some(filename) = tokenizer_config_path
    {
        HubTokenizerConfig::from_file(filename)
    } else {
        tokenizer_config_filename.and_then(HubTokenizerConfig::from_file)
    };
    let tokenizer_config = tokenizer_config.unwrap_or_else(|| {
        tracing::warn!("Could not find tokenizer config locally and no API specified");
        HubTokenizerConfig::default()
    });
    let tokenizer: Option<Tokenizer> = tokenizer_filename.and_then(|filename| {
        let mut tokenizer = Tokenizer::from_file(filename).ok();
        if let Some(tokenizer) = &mut tokenizer {
            if let Some(class) = &tokenizer_config.tokenizer_class {
                if class == "LlamaTokenizer" || class == "LlamaTokenizerFast" {
                    tracing::info!("Overriding LlamaTokenizer with TemplateProcessing to follow python override defined in https://github.com/huggingface/transformers/blob/4aa17d00690b7f82c95bb2949ea57e22c35b4336/src/transformers/models/llama/tokenization_llama_fast.py#L203-L205");
                    // Mirror the slow->fast Llama conversion: optionally wrap the
                    // sequence with BOS/EOS depending on the tokenizer config
                    let mut single = vec![];
                    let mut special_tokens = vec![];
                    if let Some(true) = &tokenizer_config.add_bos_token {
                        if let Some(bos_token) = &tokenizer_config.bos_token {
                            let bos_token_id = tokenizer
                                .token_to_id(bos_token)
                                .expect("Should have found the bos token id");
                            special_tokens.push((bos_token.clone(), bos_token_id));
                            single.push(bos_token.to_string());
                        }
                    }
                    // "$0" is the TemplateProcessing placeholder for the input sequence
                    single.push("$0".to_string());
                    if let Some(true) = &tokenizer_config.add_eos_token {
                        if let Some(eos_token) = &tokenizer_config.eos_token {
                            let eos_token_id = tokenizer
                                .token_to_id(eos_token)
                                .expect("Should have found the eos token id");
                            special_tokens.push((eos_token.clone(), eos_token_id));
                            single.push(eos_token.to_string());
                        }
                    }
                    let post_processor = TemplateProcessing::builder()
                        .try_single(single)
                        .unwrap()
                        .special_tokens(special_tokens)
                        .build()
                        .unwrap();
                    tokenizer.with_post_processor(post_processor);
                }
            }
        }
        tokenizer
    });

    let preprocessor_config =
        preprocessor_config_filename.and_then(HubPreprocessorConfig::from_file);
    let processor_config = processor_config_filename
        .and_then(HubProcessorConfig::from_file)
        .unwrap_or_default();

    tracing::info!("Using config {config:?}");
    if tokenizer.is_none() {
        tracing::warn!("Could not find a fast tokenizer implementation for {tokenizer_name}");
        tracing::warn!("Rust input length validation and truncation is disabled");
    }

    // if pipeline-tag == text-generation we default to return_full_text = true
    let compat_return_full_text = match &model_info.pipeline_tag {
        None => {
            tracing::warn!("no pipeline tag found for model {tokenizer_name}");
            true
        }
        Some(pipeline_tag) => pipeline_tag.as_str() == "text-generation",
    };

    // Determine the server port based on the feature and environment variable.
    let port = if cfg!(feature = "google") {
        std::env::var("AIP_HTTP_PORT")
            .map(|aip_http_port| aip_http_port.parse::<u16>().unwrap_or(port))
            .unwrap_or(port)
    } else {
        port
    };
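
    // e.g. when compiled with the `google` feature, `AIP_HTTP_PORT=8080` (a
    // hypothetical Vertex AI-provided value) takes precedence over `--port`.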

    let addr = match hostname.parse() {
        Ok(ip) => SocketAddr::new(ip, port),
        Err(_) => {
            tracing::warn!("Invalid hostname, defaulting to 0.0.0.0");
            SocketAddr::new(IpAddr::V4(Ipv4Addr::new(0, 0, 0, 0)), port)
        }
    };
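
    // e.g. `--hostname ::` (hypothetical) binds all IPv6 interfaces, while a
    // non-IP value such as `localhost` triggers the 0.0.0.0 fallback above.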

    // Run server
    server::run(
        master_shard_uds_path,
        model_info,
        compat_return_full_text,
        max_concurrent_requests,
        max_best_of,
        max_stop_sequences,
        max_top_n_tokens,
        max_input_tokens,
        max_total_tokens,
        waiting_served_ratio,
        max_batch_prefill_tokens,
        max_batch_total_tokens,
        max_waiting_tokens,
        max_batch_size,
        tokenizer,
        config,
        validation_workers,
        addr,
        cors_allow_origin,
        ngrok,
        ngrok_authtoken,
        ngrok_edge,
        tokenizer_config,
        preprocessor_config,
        processor_config,
        messages_api_enabled,
        disable_grammar_support,
        max_client_batch_size,
    )
    .await?;
    Ok(())
}

/// Init logging using env variables LOG_LEVEL and LOG_COLORIZE:
///     - otlp_endpoint is an optional URL to an OpenTelemetry collector
///     - otlp_service_name is the service name to appear in APM
///     - LOG_LEVEL may be TRACE, DEBUG, INFO, WARN or ERROR (defaults to INFO)
///     - LOG_COLORIZE may be "false" or "true" (defaults to "true" on ANSI-capable platforms)
///     - json_output switches the log format from TEXT to JSON
fn init_logging(otlp_endpoint: Option<String>, otlp_service_name: String, json_output: bool) {
    let mut layers = Vec::new();

    // STDOUT/STDERR layer
    // Colors are on by default; set LOG_COLORIZE=false to disable them
    let ansi = std::env::var("LOG_COLORIZE") != Ok("false".to_string());
    let fmt_layer = tracing_subscriber::fmt::layer()
        .with_file(true)
        .with_ansi(ansi)
        .with_line_number(true);

    let fmt_layer = match json_output {
        true => fmt_layer.json().flatten_event(true).boxed(),
        false => fmt_layer.boxed(),
    };
    layers.push(fmt_layer);

    // OpenTelemetry tracing layer
    if let Some(otlp_endpoint) = otlp_endpoint {
        global::set_text_map_propagator(TraceContextPropagator::new());

        let tracer = opentelemetry_otlp::new_pipeline()
            .tracing()
            .with_exporter(
                opentelemetry_otlp::new_exporter()
                    .tonic()
                    .with_endpoint(otlp_endpoint),
            )
            .with_trace_config(
                trace::config()
                    .with_resource(Resource::new(vec![KeyValue::new(
                        "service.name",
                        otlp_service_name,
                    )]))
                    .with_sampler(Sampler::AlwaysOn),
            )
            .install_batch(opentelemetry::runtime::Tokio);

        if let Ok(tracer) = tracer {
            layers.push(tracing_opentelemetry::layer().with_tracer(tracer).boxed());
            init_tracing_opentelemetry::init_propagator().unwrap();
        };
    }

    // Filter events with LOG_LEVEL
    let varname = "LOG_LEVEL";
    let env_filter = if let Ok(log_level) = std::env::var(varname) {
        // Override to avoid simple logs to be spammed with tokio level informations
        let log_level = match &log_level[..] {
            "warn" => "text_generation_launcher=warn,text_generation_router=warn",
            "info" => "text_generation_launcher=info,text_generation_router=info",
            "debug" => "text_generation_launcher=debug,text_generation_router=debug",
            log_level => log_level,
        };
        EnvFilter::builder()
            .with_default_directive(LevelFilter::INFO.into())
            .parse_lossy(log_level)
    } else {
        EnvFilter::new("info")
    };

    tracing_subscriber::registry()
        .with(env_filter)
        .with(layers)
        .init();
}
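
// Example environment (hypothetical values): `LOG_LEVEL=debug` expands to
// `text_generation_launcher=debug,text_generation_router=debug` via the match
// above, `LOG_COLORIZE=false` disables ANSI colors, and the `--json-output`
// flag switches the fmt layer to flattened JSON events.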

/// Get model info from the Hugging Face Hub
pub async fn get_model_info(api: &ApiRepo) -> Option<HubModelInfo> {
    let response = api.info_request().send().await.ok()?;

    if response.status().is_success() {
        let hub_model_info: HubModelInfo =
            serde_json::from_str(&response.text().await.ok()?).ok()?;
        if let Some(sha) = &hub_model_info.sha {
            tracing::info!(
                "Serving revision {sha} of model {}",
                hub_model_info.model_id
            );
        }
        Some(hub_model_info)
    } else {
        None
    }
}

/// Get the base tokenizer
pub async fn get_base_tokenizer(api: &Api, api_repo: &ApiRepo) -> Option<PathBuf> {
    let config_filename = api_repo.get("config.json").await.ok()?;

    // Open the file in read-only mode with buffer.
    let file = File::open(config_filename).ok()?;
    let reader = BufReader::new(file);

    // Read the JSON contents of the file as a generic `serde_json::Value`.
    let config: serde_json::Value = serde_json::from_reader(reader).ok()?;

    if let Some(serde_json::Value::String(base_model_id)) = config.get("base_model_name_or_path") {
        let api_base_repo = api.repo(Repo::with_revision(
            base_model_id.to_string(),
            RepoType::Model,
            "main".to_string(),
        ));

        api_base_repo.get("tokenizer.json").await.ok()
    } else {
        None
    }
}

/// Get tokenizer_config from the Hugging Face Hub
pub async fn get_tokenizer_config(api_repo: &ApiRepo) -> Option<HubTokenizerConfig> {
    let tokenizer_config_filename = api_repo.get("tokenizer_config.json").await.ok()?;

    // Open the file in read-only mode with buffer.
    let file = File::open(tokenizer_config_filename).ok()?;
    let reader = BufReader::new(file);

    // Read the JSON contents of the file as an instance of `HubTokenizerConfig`.
    let tokenizer_config: HubTokenizerConfig = serde_json::from_reader(reader)
        .map_err(|e| {
            tracing::warn!("Unable to parse tokenizer config: {}", e);
            e
        })
        .ok()?;

    Some(tokenizer_config)
}

#[derive(Debug, Error)]
enum RouterError {
    #[error("Argument validation error: {0}")]
    ArgumentValidation(String),
    #[error("WebServer error: {0}")]
    WebServer(#[from] server::WebServerError),
    #[error("Tokio runtime failed to start: {0}")]
    Tokio(#[from] std::io::Error),
}
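
// A minimal sketch of exercising the clap-derived `Args` in a unit test; the
// binary name and flag values below are hypothetical, and only the listed
// fields are asserted.
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn args_parse_overrides_defaults() {
        let args = Args::try_parse_from([
            "text-generation-router",
            "--max-input-tokens",
            "512",
            "--max-total-tokens",
            "1024",
        ])
        .expect("args should parse");
        assert_eq!(args.max_input_tokens, 512);
        assert_eq!(args.max_total_tokens, 1024);
        // Unspecified flags keep their clap defaults
        assert_eq!(args.port, 3000);
        assert_eq!(args.max_concurrent_requests, 128);
    }
}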