use axum::http::HeaderValue;
use clap::Parser;
use hf_hub::api::tokio::{Api, ApiBuilder, ApiRepo};
use hf_hub::{Cache, Repo, RepoType};
use opentelemetry::sdk::propagation::TraceContextPropagator;
use opentelemetry::sdk::trace;
use opentelemetry::sdk::trace::Sampler;
use opentelemetry::sdk::Resource;
use opentelemetry::{global, KeyValue};
use opentelemetry_otlp::WithExportConfig;
use std::fs::File;
use std::io::BufReader;
use std::net::{IpAddr, Ipv4Addr, SocketAddr};
use std::path::{Path, PathBuf};
use text_generation_router::config::Config;
use text_generation_router::{server, HubModelInfo, HubProcessorConfig, HubTokenizerConfig};
use thiserror::Error;
use tokenizers::Tokenizer;
use tower_http::cors::AllowOrigin;
use tracing_subscriber::layer::SubscriberExt;
use tracing_subscriber::util::SubscriberInitExt;
use tracing_subscriber::{filter::LevelFilter, EnvFilter, Layer};

/// App Configuration
#[derive(Parser, Debug)]
#[clap(author, version, about, long_about = None)]
struct Args {
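    // Every field below is exposed both as a CLI flag (kebab-case) and, via
    // clap's `env` attribute, as an environment variable (upper snake case).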
    #[clap(default_value = "128", long, env)]
    max_concurrent_requests: usize,
    #[clap(default_value = "2", long, env)]
    max_best_of: usize,
    #[clap(default_value = "4", long, env)]
    max_stop_sequences: usize,
    #[clap(default_value = "5", long, env)]
    max_top_n_tokens: u32,
    #[clap(default_value = "1024", long, env)]
    max_input_tokens: usize,
    #[clap(default_value = "2048", long, env)]
    max_total_tokens: usize,
    #[clap(default_value = "1.2", long, env)]
    waiting_served_ratio: f32,
    #[clap(default_value = "4096", long, env)]
    max_batch_prefill_tokens: u32,
    #[clap(long, env)]
    max_batch_total_tokens: Option<u32>,
    #[clap(default_value = "20", long, env)]
    max_waiting_tokens: usize,
    #[clap(long, env)]
    max_batch_size: Option<usize>,
    #[clap(default_value = "0.0.0.0", long, env)]
    hostname: String,
    #[clap(default_value = "3000", long, short, env)]
    port: u16,
    #[clap(default_value = "/tmp/text-generation-server-0", long, env)]
    master_shard_uds_path: String,
    #[clap(default_value = "bigscience/bloom", long, env)]
    tokenizer_name: String,
    #[clap(long, env)]
    tokenizer_config_path: Option<String>,
    #[clap(long, env)]
    revision: Option<String>,
    #[clap(default_value = "2", long, env)]
    validation_workers: usize,
    #[clap(long, env)]
    json_output: bool,
    #[clap(long, env)]
    otlp_endpoint: Option<String>,
    #[clap(default_value = "text-generation-inference.router", long, env)]
    otlp_service_name: String,
    #[clap(long, env)]
    cors_allow_origin: Option<Vec<String>>,
    #[clap(long, env)]
    ngrok: bool,
    #[clap(long, env)]
    ngrok_authtoken: Option<String>,
    #[clap(long, env)]
    ngrok_edge: Option<String>,
    #[clap(long, env, default_value_t = false)]
    messages_api_enabled: bool,
    #[clap(long, env, default_value_t = false)]
    disable_grammar_support: bool,
    #[clap(default_value = "4", long, env)]
    max_client_batch_size: usize,
}
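
// Example invocation (hypothetical values; any flag can equally be supplied
// through its environment variable):
//
//     text-generation-router --port 8080 --tokenizer-name bigscience/bloom
//     PORT=8080 TOKENIZER_NAME=bigscience/bloom text-generation-router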

#[tokio::main]
async fn main() -> Result<(), RouterError> {
    // Get args
    let args = Args::parse();
    // Pattern match configuration
    let Args {
        max_concurrent_requests,
        max_best_of,
        max_stop_sequences,
        max_top_n_tokens,
        max_input_tokens,
        max_total_tokens,
        waiting_served_ratio,
        max_batch_prefill_tokens,
        max_batch_total_tokens,
        max_waiting_tokens,
        max_batch_size,
        hostname,
        port,
        master_shard_uds_path,
        tokenizer_name,
        tokenizer_config_path,
        revision,
        validation_workers,
        json_output,
        otlp_endpoint,
        otlp_service_name,
        cors_allow_origin,
        ngrok,
        ngrok_authtoken,
        ngrok_edge,
        messages_api_enabled,
        disable_grammar_support,
        max_client_batch_size,
    } = args;

    // Initialize logging and telemetry export
    init_logging(otlp_endpoint, otlp_service_name, json_output);

    // Validate args
    if max_input_tokens >= max_total_tokens {
        return Err(RouterError::ArgumentValidation(
            "`max_input_tokens` must be < `max_total_tokens`".to_string(),
        ));
    }
    if max_input_tokens as u32 > max_batch_prefill_tokens {
        return Err(RouterError::ArgumentValidation(format!("`max_batch_prefill_tokens` must be >= `max_input_tokens`. Given: {max_batch_prefill_tokens} and {max_input_tokens}")));
    }

    if validation_workers == 0 {
        return Err(RouterError::ArgumentValidation(
            "`validation_workers` must be > 0".to_string(),
        ));
    }

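    // When a batch budget is set, both the prefill budget and a single
    // request's total tokens must fit inside it.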
    if let Some(ref max_batch_total_tokens) = max_batch_total_tokens {
        if max_batch_prefill_tokens > *max_batch_total_tokens {
            return Err(RouterError::ArgumentValidation(format!("`max_batch_prefill_tokens` must be <= `max_batch_total_tokens`. Given: {max_batch_prefill_tokens} and {max_batch_total_tokens}")));
        }
        if max_total_tokens as u32 > *max_batch_total_tokens {
            return Err(RouterError::ArgumentValidation(format!("`max_total_tokens` must be <= `max_batch_total_tokens`. Given: {max_total_tokens} and {max_batch_total_tokens}")));
        }
    }

    // CORS allowed origins
    // map to go inside the option and then map to parse from String to HeaderValue
    // Finally, convert to AllowOrigin
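    // NOTE: an origin that fails to parse as a `HeaderValue` panics here at startup.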
    let cors_allow_origin: Option<AllowOrigin> = cors_allow_origin.map(|cors_allow_origin| {
        AllowOrigin::list(
            cors_allow_origin
                .iter()
                .map(|origin| origin.parse::<HeaderValue>().unwrap()),
        )
    });

    // Parse the Hugging Face Hub token
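    // HF_TOKEN takes precedence; HUGGING_FACE_HUB_TOKEN is kept as a legacy fallback.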
    let authorization_token = std::env::var("HF_TOKEN")
        .or_else(|_| std::env::var("HUGGING_FACE_HUB_TOKEN"))
        .ok();

    // Tokenizer instance
    // This will only be used to validate payloads
    let local_path = Path::new(&tokenizer_name);

    // Shared API builder initialization
    let api_builder = || {
        let mut builder = ApiBuilder::new()
            .with_progress(false)
            .with_token(authorization_token);

        if let Ok(cache_dir) = std::env::var("HUGGINGFACE_HUB_CACHE") {
            builder = builder.with_cache_dir(cache_dir.into());
        }

        builder
    };

    // Decide if we need to use the API based on the revision and local path
    let use_api = revision.is_some() || !local_path.exists() || !local_path.is_dir();

    // Initialize API if needed
    #[derive(Clone)]
    enum Type {
        Api(Api),
        Cache(Cache),
        None,
    }
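    // Where tokenizer and config files come from: the Hub API, the local Hub
    // cache (offline mode), or files already on disk.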
    let api = if use_api {
        if std::env::var("HF_HUB_OFFLINE") == Ok("1".to_string()) {
            let cache = Cache::default();
            tracing::warn!("Offline mode active using cache defaults");
            Type::Cache(cache)
        } else {
            tracing::info!("Using the Hugging Face API");
            match api_builder().build() {
                Ok(api) => Type::Api(api),
                Err(_) => {
                    tracing::warn!("Unable to build the Hugging Face API");
                    Type::None
                }
            }
        }
    } else {
        Type::None
    };

    // Load tokenizer and model info
    let (
        tokenizer_filename,
        config_filename,
        tokenizer_config_filename,
        processor_config_filename,
        model_info,
    ) = match api {
        Type::None => (
            Some(local_path.join("tokenizer.json")),
            Some(local_path.join("config.json")),
            Some(local_path.join("tokenizer_config.json")),
            Some(local_path.join("processor_config.json")),
            None,
        ),
        Type::Api(api) => {
            let api_repo = api.repo(Repo::with_revision(
                tokenizer_name.to_string(),
                RepoType::Model,
                revision.clone().unwrap_or_else(|| "main".to_string()),
            ));

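            // Prefer the repo's own tokenizer.json; if it is missing, fall back
            // to the tokenizer of the base model referenced in config.json.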
            let tokenizer_filename = match api_repo.get("tokenizer.json").await {
                Ok(tokenizer_filename) => Some(tokenizer_filename),
                Err(_) => get_base_tokenizer(&api, &api_repo).await,
            };
            let config_filename = api_repo.get("config.json").await.ok();
            let tokenizer_config_filename = api_repo.get("tokenizer_config.json").await.ok();
            let processor_config_filename = api_repo.get("processor_config.json").await.ok();

            let model_info = if let Some(model_info) = get_model_info(&api_repo).await {
                Some(model_info)
            } else {
                tracing::warn!("Could not retrieve model info from the Hugging Face hub.");
                None
            };
            (
                tokenizer_filename,
                config_filename,
                tokenizer_config_filename,
                processor_config_filename,
                model_info,
            )
        }
        Type::Cache(cache) => {
            let repo = cache.repo(Repo::with_revision(
                tokenizer_name.to_string(),
                RepoType::Model,
                revision.clone().unwrap_or_else(|| "main".to_string()),
            ));
            (
                repo.get("tokenizer.json"),
                repo.get("config.json"),
                repo.get("tokenizer_config.json"),
                repo.get("processor_config.json"),
                None,
            )
        }
    };
    let tokenizer: Option<Tokenizer> =
        tokenizer_filename.and_then(|filename| Tokenizer::from_file(filename).ok());
    let config: Option<Config> = config_filename.and_then(|filename| {
        std::fs::read_to_string(filename)
            .ok()
            .as_ref()
            .and_then(|c| {
                let config: Result<Config, _> = serde_json::from_str(c);
                if let Err(err) = &config {
                    tracing::warn!("Could not parse config {err:?}");
                }
                config.ok()
            })
    });
    let model_info = model_info.unwrap_or_else(|| HubModelInfo {
        model_id: tokenizer_name.to_string(),
        sha: None,
        pipeline_tag: None,
    });

    // Read the JSON contents of the file as an instance of 'HubTokenizerConfig'.
    let tokenizer_config: Option<HubTokenizerConfig> = if let Some(filename) = tokenizer_config_path
    {
        HubTokenizerConfig::from_file(filename)
    } else {
        tokenizer_config_filename.and_then(HubTokenizerConfig::from_file)
    };
    let tokenizer_config = tokenizer_config.unwrap_or_else(|| {
        tracing::warn!("Could not find tokenizer config locally and no API specified");
        HubTokenizerConfig::default()
    });

    let processor_config = processor_config_filename
        .and_then(HubProcessorConfig::from_file)
        .unwrap_or_default();

    tracing::info!("Using config {config:?}");
    if tokenizer.is_none() {
        tracing::warn!("Could not find a fast tokenizer implementation for {tokenizer_name}");
        tracing::warn!("Rust input length validation and truncation is disabled");
    }

    // If pipeline-tag == text-generation, we default to return_full_text = true
    let compat_return_full_text = match &model_info.pipeline_tag {
        None => {
            tracing::warn!("no pipeline tag found for model {tokenizer_name}");
            true
        }
        Some(pipeline_tag) => pipeline_tag.as_str() == "text-generation",
    };

    // Determine the server port based on the feature and environment variable.
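    // Under the `google` feature, Vertex AI's AIP_HTTP_PORT (when set and valid) overrides --port.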
    let port = if cfg!(feature = "google") {
        std::env::var("AIP_HTTP_PORT")
            .map(|aip_http_port| aip_http_port.parse::<u16>().unwrap_or(port))
            .unwrap_or(port)
    } else {
        port
    };

    let addr = match hostname.parse() {
        Ok(ip) => SocketAddr::new(ip, port),
        Err(_) => {
            tracing::warn!("Invalid hostname, defaulting to 0.0.0.0");
            SocketAddr::new(IpAddr::V4(Ipv4Addr::new(0, 0, 0, 0)), port)
        }
    };

    // Run server
    server::run(
        master_shard_uds_path,
        model_info,
        compat_return_full_text,
        max_concurrent_requests,
        max_best_of,
        max_stop_sequences,
        max_top_n_tokens,
        max_input_tokens,
        max_total_tokens,
        waiting_served_ratio,
        max_batch_prefill_tokens,
        max_batch_total_tokens,
        max_waiting_tokens,
        max_batch_size,
        tokenizer,
        config,
        validation_workers,
        addr,
        cors_allow_origin,
        ngrok,
        ngrok_authtoken,
        ngrok_edge,
        tokenizer_config,
        processor_config,
        messages_api_enabled,
        disable_grammar_support,
        max_client_batch_size,
    )
    .await?;
    Ok(())
}

/// Init logging using env variables LOG_LEVEL and LOG_FORMAT:
///     - otlp_endpoint is an optional URL to an OpenTelemetry collector
///     - otlp_service_name is the service name to appear in APM
///     - LOG_LEVEL may be TRACE, DEBUG, INFO, WARN or ERROR (defaults to INFO)
///     - LOG_FORMAT may be TEXT or JSON (defaults to TEXT)
///     - LOG_COLORIZE may be "false" or "true" (defaults to "true" on ANSI-supported platforms)
fn init_logging(otlp_endpoint: Option<String>, otlp_service_name: String, json_output: bool) {
    let mut layers = Vec::new();

    // STDOUT/STDERR layer
    let ansi = std::env::var("LOG_COLORIZE") != Ok("false".to_string());
    let fmt_layer = tracing_subscriber::fmt::layer()
        .with_file(true)
        .with_ansi(ansi)
        .with_line_number(true);

    let fmt_layer = match json_output {
        true => fmt_layer.json().flatten_event(true).boxed(),
        false => fmt_layer.boxed(),
    };
    layers.push(fmt_layer);

    // OpenTelemetry tracing layer
    if let Some(otlp_endpoint) = otlp_endpoint {
        global::set_text_map_propagator(TraceContextPropagator::new());

        let tracer = opentelemetry_otlp::new_pipeline()
            .tracing()
            .with_exporter(
                opentelemetry_otlp::new_exporter()
                    .tonic()
                    .with_endpoint(otlp_endpoint),
            )
            .with_trace_config(
                trace::config()
                    .with_resource(Resource::new(vec![KeyValue::new(
                        "service.name",
                        otlp_service_name,
                    )]))
                    .with_sampler(Sampler::AlwaysOn),
            )
            .install_batch(opentelemetry::runtime::Tokio);

        if let Ok(tracer) = tracer {
            layers.push(tracing_opentelemetry::layer().with_tracer(tracer).boxed());
            init_tracing_opentelemetry::init_propagator().unwrap();
        };
    }

    // Filter events with LOG_LEVEL
    let varname = "LOG_LEVEL";
    let env_filter = if let Ok(log_level) = std::env::var(varname) {
        // Override so that simple logs are not spammed with tokio-level information
        let log_level = match &log_level[..] {
            "warn" => "text_generation_launcher=warn,text_generation_router=warn",
            "info" => "text_generation_launcher=info,text_generation_router=info",
            "debug" => "text_generation_launcher=debug,text_generation_router=debug",
            log_level => log_level,
        };
        EnvFilter::builder()
            .with_default_directive(LevelFilter::INFO.into())
            .parse_lossy(log_level)
    } else {
        EnvFilter::new("info")
    };

    tracing_subscriber::registry()
        .with(env_filter)
        .with(layers)
        .init();
}

/// Get model info from the Hugging Face Hub
pub async fn get_model_info(api: &ApiRepo) -> Option<HubModelInfo> {
    let response = api.info_request().send().await.ok()?;

    if response.status().is_success() {
        let hub_model_info: HubModelInfo =
            serde_json::from_str(&response.text().await.ok()?).ok()?;
        if let Some(sha) = &hub_model_info.sha {
            tracing::info!(
                "Serving revision {sha} of model {}",
                hub_model_info.model_id
            );
        }
        Some(hub_model_info)
    } else {
        None
    }
}

/// Get the base model's tokenizer
pub async fn get_base_tokenizer(api: &Api, api_repo: &ApiRepo) -> Option<PathBuf> {
    let config_filename = api_repo.get("config.json").await.ok()?;

    // Open the file in read-only mode with buffer.
    let file = File::open(config_filename).ok()?;
    let reader = BufReader::new(file);

    // Read the JSON contents of the file as a generic JSON value.
    let config: serde_json::Value = serde_json::from_reader(reader).ok()?;

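    // Fine-tuned or adapter repos often reference their base model here; reuse
    // its tokenizer when the repo itself does not ship one.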
    if let Some(serde_json::Value::String(base_model_id)) = config.get("base_model_name_or_path") {
        let api_base_repo = api.repo(Repo::with_revision(
            base_model_id.to_string(),
            RepoType::Model,
            "main".to_string(),
        ));

        api_base_repo.get("tokenizer.json").await.ok()
    } else {
        None
    }
}

/// Get tokenizer_config from the Hugging Face Hub
pub async fn get_tokenizer_config(api_repo: &ApiRepo) -> Option<HubTokenizerConfig> {
    let tokenizer_config_filename = api_repo.get("tokenizer_config.json").await.ok()?;

    // Open the file in read-only mode with buffer.
    let file = File::open(tokenizer_config_filename).ok()?;
    let reader = BufReader::new(file);

    // Read the JSON contents of the file as an instance of 'HubTokenizerConfig'.
    let tokenizer_config: HubTokenizerConfig = serde_json::from_reader(reader)
        .map_err(|e| {
            tracing::warn!("Unable to parse tokenizer config: {}", e);
            e
        })
        .ok()?;

    Some(tokenizer_config)
}

#[derive(Debug, Error)]
enum RouterError {
    #[error("Argument validation error: {0}")]
    ArgumentValidation(String),
    #[error("WebServer error: {0}")]
    WebServer(#[from] server::WebServerError),
    #[error("Tokio runtime failed to start: {0}")]
    Tokio(#[from] std::io::Error),
}