use axum::http::HeaderValue;
use clap::Parser;
use clap::Subcommand;
use hf_hub::api::tokio::{Api, ApiBuilder, ApiRepo};
use hf_hub::{Cache, Repo, RepoType};
use opentelemetry::sdk::propagation::TraceContextPropagator;
use opentelemetry::sdk::trace;
use opentelemetry::sdk::trace::Sampler;
use opentelemetry::sdk::Resource;
use opentelemetry::{global, KeyValue};
use opentelemetry_otlp::WithExportConfig;
use std::fs::File;
use std::io::BufReader;
use std::net::{IpAddr, Ipv4Addr, SocketAddr};
use std::path::{Path, PathBuf};
use text_generation_router::config::Config;
use text_generation_router::usage_stats;
use text_generation_router::{
    server, HubModelInfo, HubPreprocessorConfig, HubProcessorConfig, HubTokenizerConfig,
};
use thiserror::Error;
use tokenizers::{processors::template::TemplateProcessing, Tokenizer};
use tower_http::cors::AllowOrigin;
use tracing_subscriber::layer::SubscriberExt;
use tracing_subscriber::util::SubscriberInitExt;
use tracing_subscriber::{filter::LevelFilter, EnvFilter, Layer};

/// App Configuration
#[derive(Parser, Debug)]
#[clap(author, version, about, long_about = None)]
struct Args {
    #[command(subcommand)]
    command: Option<Commands>,

    #[clap(default_value = "128", long, env)]
    max_concurrent_requests: usize,
    #[clap(default_value = "2", long, env)]
    max_best_of: usize,
    #[clap(default_value = "4", long, env)]
    max_stop_sequences: usize,
    #[clap(default_value = "5", long, env)]
    max_top_n_tokens: u32,
    #[clap(default_value = "1024", long, env)]
    max_input_tokens: usize,
    #[clap(default_value = "2048", long, env)]
    max_total_tokens: usize,
    #[clap(default_value = "1.2", long, env)]
    waiting_served_ratio: f32,
    #[clap(default_value = "4096", long, env)]
    max_batch_prefill_tokens: u32,
    #[clap(long, env)]
    max_batch_total_tokens: Option<u32>,
    #[clap(default_value = "20", long, env)]
    max_waiting_tokens: usize,
    #[clap(long, env)]
    max_batch_size: Option<usize>,
    #[clap(default_value = "0.0.0.0", long, env)]
    hostname: String,
    #[clap(default_value = "3000", long, short, env)]
    port: u16,
    #[clap(default_value = "/tmp/text-generation-server-0", long, env)]
    master_shard_uds_path: String,
    #[clap(default_value = "bigscience/bloom", long, env)]
    tokenizer_name: String,
    #[clap(long, env)]
    tokenizer_config_path: Option<String>,
    #[clap(long, env)]
    revision: Option<String>,
    #[clap(default_value = "2", long, env)]
    validation_workers: usize,
    #[clap(long, env)]
    json_output: bool,
    #[clap(long, env)]
    otlp_endpoint: Option<String>,
    #[clap(default_value = "text-generation-inference.router", long, env)]
    otlp_service_name: String,
    #[clap(long, env)]
    cors_allow_origin: Option<Vec<String>>,
    #[clap(long, env)]
    api_key: Option<String>,
    #[clap(long, env)]
    ngrok: bool,
    #[clap(long, env)]
    ngrok_authtoken: Option<String>,
    #[clap(long, env)]
    ngrok_edge: Option<String>,
    #[clap(long, env, default_value_t = false)]
    messages_api_enabled: bool,
    #[clap(long, env, default_value_t = false)]
    disable_grammar_support: bool,
    #[clap(default_value = "4", long, env)]
    max_client_batch_size: usize,
    #[clap(long, env, default_value_t)]
    disable_usage_stats: bool,
    #[clap(long, env, default_value_t)]
    disable_crash_reports: bool,
}
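
// Example invocation (hypothetical values; each flag can also be supplied via the
// matching environment variable, e.g. MAX_INPUT_TOKENS):
//
//   text-generation-router \
//       --tokenizer-name bigscience/bloom \
//       --port 3000 \
//       --max-input-tokens 1024 \
//       --max-total-tokens 2048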

#[derive(Debug, Subcommand)]
enum Commands {
    PrintSchema,
}

#[tokio::main]
async fn main() -> Result<(), RouterError> {
    let args = Args::parse();

    // Pattern match configuration
    let Args {
        max_concurrent_requests,
        max_best_of,
        max_stop_sequences,
        max_top_n_tokens,
        max_input_tokens,
        max_total_tokens,
        waiting_served_ratio,
        max_batch_prefill_tokens,
        max_batch_total_tokens,
        max_waiting_tokens,
        max_batch_size,
        hostname,
        port,
        master_shard_uds_path,
        tokenizer_name,
        tokenizer_config_path,
        revision,
        validation_workers,
        json_output,
        otlp_endpoint,
        otlp_service_name,
        cors_allow_origin,
        api_key,
        ngrok,
        ngrok_authtoken,
        ngrok_edge,
        messages_api_enabled,
        disable_grammar_support,
        max_client_batch_size,
        disable_usage_stats,
        disable_crash_reports,
        command,
    } = args;

    let print_schema_command = match command {
        Some(Commands::PrintSchema) => true,
        None => {
            // only init logging if we are not running the print schema command
            init_logging(otlp_endpoint, otlp_service_name, json_output);
            false
        }
    };

    // Validate args
    if max_input_tokens >= max_total_tokens {
        return Err(RouterError::ArgumentValidation(
            "`max_input_tokens` must be < `max_total_tokens`".to_string(),
        ));
    }
    if max_input_tokens as u32 > max_batch_prefill_tokens {
        return Err(RouterError::ArgumentValidation(format!("`max_batch_prefill_tokens` must be >= `max_input_tokens`. Given: {max_batch_prefill_tokens} and {max_input_tokens}")));
    }

    if validation_workers == 0 {
        return Err(RouterError::ArgumentValidation(
            "`validation_workers` must be > 0".to_string(),
        ));
    }

    if let Some(ref max_batch_total_tokens) = max_batch_total_tokens {
        if max_batch_prefill_tokens > *max_batch_total_tokens {
            return Err(RouterError::ArgumentValidation(format!("`max_batch_prefill_tokens` must be <= `max_batch_total_tokens`. Given: {max_batch_prefill_tokens} and {max_batch_total_tokens}")));
        }
        if max_total_tokens as u32 > *max_batch_total_tokens {
            return Err(RouterError::ArgumentValidation(format!("`max_total_tokens` must be <= `max_batch_total_tokens`. Given: {max_total_tokens} and {max_batch_total_tokens}")));
        }
    }
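
    // Taken together, the checks above enforce the token budget ordering:
    //   max_input_tokens < max_total_tokens
    //   max_input_tokens <= max_batch_prefill_tokens
    //   max_batch_prefill_tokens <= max_batch_total_tokens (when set)
    //   max_total_tokens <= max_batch_total_tokens (when set)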

    // CORS allowed origins
    // Map into the Option, parse each origin from String to HeaderValue,
    // then convert the list into an AllowOrigin
    let cors_allow_origin: Option<AllowOrigin> = cors_allow_origin.map(|cors_allow_origin| {
        AllowOrigin::list(
            cors_allow_origin
                .iter()
                .map(|origin| origin.parse::<HeaderValue>().unwrap()),
        )
    });
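
    // Sketch of the resulting behavior (origins here are assumed for illustration):
    // `--cors-allow-origin https://example.com --cors-allow-origin https://foo.dev`
    // becomes an AllowOrigin::list of both values; an origin that does not parse
    // as a HeaderValue panics on the `unwrap` above.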

    // Parse Huggingface hub token
    let authorization_token = std::env::var("HF_TOKEN")
        .or_else(|_| std::env::var("HUGGING_FACE_HUB_TOKEN"))
        .ok();

    // Tokenizer instance
    // This will only be used to validate payloads
    let local_path = Path::new(&tokenizer_name);

    // Shared API builder initialization
    let api_builder = || {
        let mut builder = ApiBuilder::new()
            .with_progress(false)
            .with_token(authorization_token);

        if let Ok(cache_dir) = std::env::var("HUGGINGFACE_HUB_CACHE") {
            builder = builder.with_cache_dir(cache_dir.into());
        }

        builder
    };

    // Decide if we need to use the API based on the revision and local path
    let use_api = revision.is_some() || !local_path.exists() || !local_path.is_dir();

    // Initialize API if needed
    #[derive(Clone)]
    enum Type {
        Api(Api),
        Cache(Cache),
        None,
    }
    let api = if use_api {
        if std::env::var("HF_HUB_OFFLINE") == Ok("1".to_string()) {
            let cache = std::env::var("HUGGINGFACE_HUB_CACHE")
                .map_err(|_| ())
                .map(|cache_dir| Cache::new(cache_dir.into()))
                .unwrap_or_else(|_| Cache::default());

            tracing::warn!("Offline mode active using cache defaults");
            Type::Cache(cache)
        } else {
            tracing::info!("Using the Hugging Face API");
            match api_builder().build() {
                Ok(api) => Type::Api(api),
                Err(_) => {
                    tracing::warn!("Unable to build the Hugging Face API");
                    Type::None
                }
            }
        }
    } else {
        Type::None
    };

    // Load tokenizer and model info
    let (
        tokenizer_filename,
        config_filename,
        tokenizer_config_filename,
        preprocessor_config_filename,
        processor_config_filename,
        model_info,
    ) = match api {
        Type::None => (
            Some(local_path.join("tokenizer.json")),
            Some(local_path.join("config.json")),
            Some(local_path.join("tokenizer_config.json")),
            Some(local_path.join("preprocessor_config.json")),
            Some(local_path.join("processor_config.json")),
            None,
        ),
        Type::Api(api) => {
            let api_repo = api.repo(Repo::with_revision(
                tokenizer_name.to_string(),
                RepoType::Model,
                revision.clone().unwrap_or_else(|| "main".to_string()),
            ));

            let tokenizer_filename = match api_repo.get("tokenizer.json").await {
                Ok(tokenizer_filename) => Some(tokenizer_filename),
                Err(_) => get_base_tokenizer(&api, &api_repo).await,
            };
            let config_filename = api_repo.get("config.json").await.ok();
            let tokenizer_config_filename = api_repo.get("tokenizer_config.json").await.ok();
            let preprocessor_config_filename = api_repo.get("preprocessor_config.json").await.ok();
            let processor_config_filename = api_repo.get("processor_config.json").await.ok();

            let model_info = if let Some(model_info) = get_model_info(&api_repo).await {
                Some(model_info)
            } else {
                tracing::warn!("Could not retrieve model info from the Hugging Face hub.");
                None
            };
            (
                tokenizer_filename,
                config_filename,
                tokenizer_config_filename,
                preprocessor_config_filename,
                processor_config_filename,
                model_info,
            )
        }
        Type::Cache(cache) => {
            let repo = cache.repo(Repo::with_revision(
                tokenizer_name.to_string(),
                RepoType::Model,
                revision.clone().unwrap_or_else(|| "main".to_string()),
            ));
            (
                repo.get("tokenizer.json"),
                repo.get("config.json"),
                repo.get("tokenizer_config.json"),
                repo.get("preprocessor_config.json"),
                repo.get("processor_config.json"),
                None,
            )
        }
    };
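
    // Summary: with no API (Type::None) the files are assumed to live in the
    // local tokenizer directory; with the API they are downloaded on demand;
    // with the offline cache they are looked up, yielding None when absent.
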
    let config: Option<Config> = config_filename.and_then(|filename| {
        std::fs::read_to_string(filename)
            .ok()
            .as_ref()
            .and_then(|c| {
                let config: Result<Config, _> = serde_json::from_str(c);
                if let Err(err) = &config {
                    tracing::warn!("Could not parse config {err:?}");
                }
                config.ok()
            })
    });
    let model_info = model_info.unwrap_or_else(|| HubModelInfo {
        model_id: tokenizer_name.to_string(),
        sha: None,
        pipeline_tag: None,
    });

    // Read the JSON contents of the file as an instance of `HubTokenizerConfig`.
    let tokenizer_config: Option<HubTokenizerConfig> = if let Some(filename) = tokenizer_config_path
    {
        HubTokenizerConfig::from_file(filename)
    } else {
        tokenizer_config_filename.and_then(HubTokenizerConfig::from_file)
    };
    let tokenizer_config = tokenizer_config.unwrap_or_else(|| {
        tracing::warn!("Could not find tokenizer config locally and no API specified");
        HubTokenizerConfig::default()
    });
    let tokenizer_class = tokenizer_config.tokenizer_class.clone();

    let tokenizer: Option<Tokenizer> = tokenizer_filename.and_then(|filename| {
        let mut tokenizer = Tokenizer::from_file(filename).ok();
        if let Some(tokenizer) = &mut tokenizer {
            if let Some(class) = &tokenizer_config.tokenizer_class {
                if class == "LlamaTokenizer" || class == "LlamaTokenizerFast" {
                    if let Ok(post_processor) = create_post_processor(tokenizer, &tokenizer_config) {
                        tracing::info!("Overriding LlamaTokenizer with TemplateProcessing to follow python override defined in https://github.com/huggingface/transformers/blob/4aa17d00690b7f82c95bb2949ea57e22c35b4336/src/transformers/models/llama/tokenization_llama_fast.py#L203-L205");
                        tokenizer.with_post_processor(post_processor);
                    }
                }
            }
        }
        tokenizer
    });

    let preprocessor_config =
        preprocessor_config_filename.and_then(HubPreprocessorConfig::from_file);
    let processor_config = processor_config_filename
        .and_then(HubProcessorConfig::from_file)
        .unwrap_or_default();

    tracing::info!("Using config {config:?}");
    if tokenizer.is_none() {
        tracing::warn!("Could not find a fast tokenizer implementation for {tokenizer_name}");
        tracing::warn!("Rust input length validation and truncation is disabled");
    }

    // if pipeline-tag == text-generation we default to return_full_text = true
    let compat_return_full_text = match &model_info.pipeline_tag {
        None => {
            tracing::warn!("no pipeline tag found for model {tokenizer_name}");
            true
        }
        Some(pipeline_tag) => pipeline_tag.as_str() == "text-generation",
    };

    // Determine the server port based on the feature and environment variable.
    let port = if cfg!(feature = "google") {
        std::env::var("AIP_HTTP_PORT")
            .map(|aip_http_port| aip_http_port.parse::<u16>().unwrap_or(port))
            .unwrap_or(port)
    } else {
        port
    };
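
    // E.g. when built with the `google` feature (as for Vertex AI), AIP_HTTP_PORT=8080
    // takes precedence over `--port`; an unset or unparsable value falls back to the
    // CLI value.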

    let addr = match hostname.parse() {
        Ok(ip) => SocketAddr::new(ip, port),
        Err(_) => {
            tracing::warn!("Invalid hostname, defaulting to 0.0.0.0");
            SocketAddr::new(IpAddr::V4(Ipv4Addr::new(0, 0, 0, 0)), port)
        }
    };

    // Only send usage stats when TGI is run in a container and the function returns Some
    let is_container = matches!(usage_stats::is_container(), Ok(true));

    let user_agent = if !disable_usage_stats && is_container {
        let reduced_args = usage_stats::Args::new(
            config.clone(),
            tokenizer_class,
            max_concurrent_requests,
            max_best_of,
            max_stop_sequences,
            max_top_n_tokens,
            max_input_tokens,
            max_total_tokens,
            waiting_served_ratio,
            max_batch_prefill_tokens,
            max_batch_total_tokens,
            max_waiting_tokens,
            max_batch_size,
            revision,
            validation_workers,
            messages_api_enabled,
            disable_grammar_support,
            max_client_batch_size,
            disable_usage_stats,
            disable_crash_reports,
        );
        Some(usage_stats::UserAgent::new(reduced_args))
    } else {
        None
    };

    if let Some(ref ua) = user_agent {
        let start_event =
            usage_stats::UsageStatsEvent::new(ua.clone(), usage_stats::EventType::Start, None);
        tokio::spawn(async move {
            start_event.send().await;
        });
    };

    // Run server
    let result = server::run(
        master_shard_uds_path,
        model_info,
        compat_return_full_text,
        max_concurrent_requests,
        max_best_of,
        max_stop_sequences,
        max_top_n_tokens,
        max_input_tokens,
        max_total_tokens,
        waiting_served_ratio,
        max_batch_prefill_tokens,
        max_batch_total_tokens,
        max_waiting_tokens,
        max_batch_size,
        tokenizer,
        config,
        validation_workers,
        addr,
        cors_allow_origin,
        api_key,
        ngrok,
        ngrok_authtoken,
        ngrok_edge,
        tokenizer_config,
        preprocessor_config,
        processor_config,
        messages_api_enabled,
        disable_grammar_support,
        max_client_batch_size,
        print_schema_command,
    )
    .await;

    match result {
        Ok(_) => {
            if let Some(ref ua) = user_agent {
                let stop_event = usage_stats::UsageStatsEvent::new(
                    ua.clone(),
                    usage_stats::EventType::Stop,
                    None,
                );
                stop_event.send().await;
            };
            Ok(())
        }
        Err(e) => {
            if let Some(ref ua) = user_agent {
                if !disable_crash_reports {
                    let error_event = usage_stats::UsageStatsEvent::new(
                        ua.clone(),
                        usage_stats::EventType::Error,
                        Some(e.to_string()),
                    );
                    error_event.send().await;
                } else {
                    let unknown_error_event = usage_stats::UsageStatsEvent::new(
                        ua.clone(),
                        usage_stats::EventType::Error,
                        Some("unknown_error".to_string()),
                    );
                    unknown_error_event.send().await;
                }
            };
            Err(RouterError::WebServer(e))
        }
    }
}

/// Init logging using env variables LOG_LEVEL and LOG_COLORIZE:
///     - otlp_endpoint is an optional URL to an OpenTelemetry collector
///     - otlp_service_name is the service name to appear in APM
///     - LOG_LEVEL may be TRACE, DEBUG, INFO, WARN or ERROR (defaults to INFO)
///     - LOG_COLORIZE may be "false" or "true" (defaults to "true" on ansi-supported platforms)
///     - json_output (from `--json-output`) switches the stdout format from text to JSON
fn init_logging(otlp_endpoint: Option<String>, otlp_service_name: String, json_output: bool) {
    let mut layers = Vec::new();

    // STDOUT/STDERR layer
    let ansi = std::env::var("LOG_COLORIZE") != Ok("1".to_string());
    let fmt_layer = tracing_subscriber::fmt::layer()
        .with_file(true)
        .with_ansi(ansi)
        .with_line_number(true);

    let fmt_layer = match json_output {
        true => fmt_layer.json().flatten_event(true).boxed(),
        false => fmt_layer.boxed(),
    };
    layers.push(fmt_layer);

    // OpenTelemetry tracing layer
    if let Some(otlp_endpoint) = otlp_endpoint {
        global::set_text_map_propagator(TraceContextPropagator::new());

        let tracer = opentelemetry_otlp::new_pipeline()
            .tracing()
            .with_exporter(
                opentelemetry_otlp::new_exporter()
                    .tonic()
                    .with_endpoint(otlp_endpoint),
            )
            .with_trace_config(
                trace::config()
                    .with_resource(Resource::new(vec![KeyValue::new(
                        "service.name",
538
                        otlp_service_name,
539
540
541
542
543
544
545
                    )]))
                    .with_sampler(Sampler::AlwaysOn),
            )
            .install_batch(opentelemetry::runtime::Tokio);

        if let Ok(tracer) = tracer {
            layers.push(tracing_opentelemetry::layer().with_tracer(tracer).boxed());
            init_tracing_opentelemetry::init_propagator().unwrap();
        };
    }

    // Filter events with LOG_LEVEL
    let varname = "LOG_LEVEL";
    let env_filter = if let Ok(log_level) = std::env::var(varname) {
        // Override to avoid simple logs being spammed with tokio-level information
        let log_level = match &log_level[..] {
            "warn" => "text_generation_launcher=warn,text_generation_router=warn",
            "info" => "text_generation_launcher=info,text_generation_router=info",
            "debug" => "text_generation_launcher=debug,text_generation_router=debug",
            log_level => log_level,
        };
        EnvFilter::builder()
            .with_default_directive(LevelFilter::INFO.into())
            .parse_lossy(log_level)
    } else {
        EnvFilter::new("info")
    };
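
    // A minimal sketch of how the filter above reacts to LOG_LEVEL (assumed values):
    //   LOG_LEVEL=warn  -> "text_generation_launcher=warn,text_generation_router=warn"
    //   LOG_LEVEL=trace -> passed through to EnvFilter unchanged
    //   unset           -> EnvFilter::new("info")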

    tracing_subscriber::registry()
        .with(env_filter)
        .with(layers)
        .init();
}

/// get model info from the Huggingface Hub
pub async fn get_model_info(api: &ApiRepo) -> Option<HubModelInfo> {
    let response = api.info_request().send().await.ok()?;

    if response.status().is_success() {
        let hub_model_info: HubModelInfo =
            serde_json::from_str(&response.text().await.ok()?).ok()?;
        if let Some(sha) = &hub_model_info.sha {
            tracing::info!(
                "Serving revision {sha} of model {}",
                hub_model_info.model_id
            );
        }
        Some(hub_model_info)
    } else {
        None
    }
}

/// get base tokenizer
pub async fn get_base_tokenizer(api: &Api, api_repo: &ApiRepo) -> Option<PathBuf> {
    let config_filename = api_repo.get("config.json").await.ok()?;

    // Open the file in read-only mode with buffer.
    let file = File::open(config_filename).ok()?;
    let reader = BufReader::new(file);

    // Read the JSON contents of the file as a generic JSON value.
    let config: serde_json::Value = serde_json::from_reader(reader).ok()?;

    if let Some(serde_json::Value::String(base_model_id)) = config.get("base_model_name_or_path") {
        let api_base_repo = api.repo(Repo::with_revision(
            base_model_id.to_string(),
            RepoType::Model,
            "main".to_string(),
        ));

        api_base_repo.get("tokenizer.json").await.ok()
    } else {
        None
    }
}

/// get tokenizer_config from the Huggingface Hub
pub async fn get_tokenizer_config(api_repo: &ApiRepo) -> Option<HubTokenizerConfig> {
    let tokenizer_config_filename = api_repo.get("tokenizer_config.json").await.ok()?;

    // Open the file in read-only mode with buffer.
    let file = File::open(tokenizer_config_filename).ok()?;
    let reader = BufReader::new(file);

    // Read the JSON contents of the file as an instance of 'HubTokenizerConfig'.
    let tokenizer_config: HubTokenizerConfig = serde_json::from_reader(reader)
        .map_err(|e| {
            tracing::warn!("Unable to parse tokenizer config: {}", e);
            e
        })
        .ok()?;

    Some(tokenizer_config)
}

/// Create a post_processor for the LlamaTokenizer
pub fn create_post_processor(
    tokenizer: &Tokenizer,
    tokenizer_config: &HubTokenizerConfig,
) -> Result<TemplateProcessing, tokenizers::processors::template::TemplateProcessingBuilderError> {
    let add_bos_token = tokenizer_config.add_bos_token.unwrap_or(true);
    let add_eos_token = tokenizer_config.add_eos_token.unwrap_or(false);

    let bos_token = tokenizer_config.bos_token.as_ref();
    let eos_token = tokenizer_config.eos_token.as_ref();

    if add_bos_token && bos_token.is_none() {
        panic!("add_bos_token = true but bos_token is None");
    }

    if add_eos_token && eos_token.is_none() {
        panic!("add_eos_token = true but eos_token is None");
    }

    let mut single = Vec::new();
    let mut pair = Vec::new();
    let mut special_tokens = Vec::new();

    if add_bos_token {
        if let Some(bos) = bos_token {
            let bos_token_id = tokenizer
                .token_to_id(bos.as_str())
                .expect("Should have found the bos token id");
            special_tokens.push((bos.as_str(), bos_token_id));
            single.push(format!("{}:0", bos.as_str()));
            pair.push(format!("{}:0", bos.as_str()));
        }
    }

    single.push("$A:0".to_string());
    pair.push("$A:0".to_string());

    if add_eos_token {
        if let Some(eos) = eos_token {
            let eos_token_id = tokenizer
                .token_to_id(eos.as_str())
                .expect("Should have found the eos token id");
            special_tokens.push((eos.as_str(), eos_token_id));
            single.push(format!("{}:0", eos.as_str()));
            pair.push(format!("{}:0", eos.as_str()));
        }
    }

    if add_bos_token {
        if let Some(bos) = bos_token {
            pair.push(format!("{}:1", bos.as_str()));
        }
    }

    pair.push("$B:1".to_string());

    if add_eos_token {
        if let Some(eos) = eos_token {
            pair.push(format!("{}:1", eos.as_str()));
        }
    }

    let post_processor = TemplateProcessing::builder()
        .try_single(single)?
        .try_pair(pair)?
        .special_tokens(special_tokens)
        .build()?;

    Ok(post_processor)
}

#[derive(Debug, Error)]
enum RouterError {
    #[error("Argument validation error: {0}")]
    ArgumentValidation(String),
    #[error("WebServer error: {0}")]
    WebServer(#[from] server::WebServerError),
    #[error("Tokio runtime failed to start: {0}")]
    Tokio(#[from] std::io::Error),
}

#[cfg(test)]
mod tests {
    use super::*;
    use text_generation_router::TokenizerConfigToken;

    #[test]
    fn test_create_post_processor() {
        let tokenizer_config = HubTokenizerConfig {
            add_bos_token: None,
            add_eos_token: None,
            bos_token: Some(TokenizerConfigToken::String("<s>".to_string())),
            eos_token: Some(TokenizerConfigToken::String("</s>".to_string())),
            chat_template: None,
            tokenizer_class: None,
            completion_template: None,
        };

        let tokenizer =
            Tokenizer::from_pretrained("TinyLlama/TinyLlama-1.1B-Chat-v1.0", None).unwrap();
        let post_processor = create_post_processor(&tokenizer, &tokenizer_config).unwrap();

        let expected = TemplateProcessing::builder()
            .try_single("<s>:0 $A:0")
            .unwrap()
            .try_pair("<s>:0 $A:0 <s>:1 $B:1")
            .unwrap()
            .special_tokens(vec![("<s>".to_string(), 1)])
            .build()
            .unwrap();

        assert_eq!(post_processor, expected);
    }
}