use axum::http::HeaderValue;
use clap::Parser;
use clap::Subcommand;
use hf_hub::api::tokio::{Api, ApiBuilder, ApiRepo};
use hf_hub::{Cache, Repo, RepoType};
use opentelemetry::sdk::propagation::TraceContextPropagator;
use opentelemetry::sdk::trace;
use opentelemetry::sdk::trace::Sampler;
use opentelemetry::sdk::Resource;
use opentelemetry::{global, KeyValue};
use opentelemetry_otlp::WithExportConfig;
use std::fs::File;
use std::io::BufReader;
use std::net::{IpAddr, Ipv4Addr, SocketAddr};
use std::path::{Path, PathBuf};
use text_generation_router::config::Config;
use text_generation_router::usage_stats;
use text_generation_router::{
    server, HubModelInfo, HubPreprocessorConfig, HubProcessorConfig, HubTokenizerConfig,
};
use thiserror::Error;
use tokenizers::{processors::template::TemplateProcessing, Tokenizer};
use tower_http::cors::AllowOrigin;
use tracing_subscriber::layer::SubscriberExt;
use tracing_subscriber::util::SubscriberInitExt;
use tracing_subscriber::{filter::LevelFilter, EnvFilter, Layer};

/// App Configuration
#[derive(Parser, Debug)]
#[clap(author, version, about, long_about = None)]
struct Args {
    #[command(subcommand)]
    command: Option<Commands>,

    #[clap(default_value = "128", long, env)]
    max_concurrent_requests: usize,
    #[clap(default_value = "2", long, env)]
    max_best_of: usize,
    #[clap(default_value = "4", long, env)]
    max_stop_sequences: usize,
    #[clap(default_value = "5", long, env)]
    max_top_n_tokens: u32,
    #[clap(default_value = "1024", long, env)]
    max_input_tokens: usize,
    #[clap(default_value = "2048", long, env)]
    max_total_tokens: usize,
    #[clap(default_value = "1.2", long, env)]
    waiting_served_ratio: f32,
    #[clap(default_value = "4096", long, env)]
    max_batch_prefill_tokens: u32,
    #[clap(long, env)]
    max_batch_total_tokens: Option<u32>,
    #[clap(default_value = "20", long, env)]
    max_waiting_tokens: usize,
    #[clap(long, env)]
    max_batch_size: Option<usize>,
    #[clap(default_value = "0.0.0.0", long, env)]
    hostname: String,
    #[clap(default_value = "3000", long, short, env)]
    port: u16,
    #[clap(default_value = "/tmp/text-generation-server-0", long, env)]
    master_shard_uds_path: String,
    #[clap(default_value = "bigscience/bloom", long, env)]
    tokenizer_name: String,
    #[clap(long, env)]
    tokenizer_config_path: Option<String>,
    #[clap(long, env)]
    revision: Option<String>,
    #[clap(default_value = "2", long, env)]
    validation_workers: usize,
    #[clap(long, env)]
    json_output: bool,
    #[clap(long, env)]
    otlp_endpoint: Option<String>,
    #[clap(default_value = "text-generation-inference.router", long, env)]
    otlp_service_name: String,
    #[clap(long, env)]
    cors_allow_origin: Option<Vec<String>>,
    #[clap(long, env)]
    ngrok: bool,
    #[clap(long, env)]
    ngrok_authtoken: Option<String>,
    #[clap(long, env)]
    ngrok_edge: Option<String>,
    #[clap(long, env, default_value_t = false)]
    messages_api_enabled: bool,
    #[clap(long, env, default_value_t = false)]
    disable_grammar_support: bool,
    #[clap(default_value = "4", long, env)]
    max_client_batch_size: usize,
    #[clap(long, env, default_value_t)]
    disable_usage_stats: bool,
    #[clap(long, env, default_value_t)]
    disable_crash_reports: bool,
}
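
// Illustrative invocation (example values only; the binary name `text-generation-router`
// is assumed here, and every flag can also be supplied through its matching environment
// variable thanks to the `env` attribute above):
//
//     text-generation-router \
//         --tokenizer-name bigscience/bloom \
//         --max-input-tokens 1024 \
//         --max-total-tokens 2048 \
//         --port 3000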

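/// Router subcommands; `print-schema` prints the API schema instead of starting the server.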
#[derive(Debug, Subcommand)]
enum Commands {
    PrintSchema,
}

#[tokio::main]
async fn main() -> Result<(), RouterError> {
    let args = Args::parse();

    // Pattern match configuration
    let Args {
        max_concurrent_requests,
        max_best_of,
        max_stop_sequences,
        max_top_n_tokens,
        max_input_tokens,
        max_total_tokens,
        waiting_served_ratio,
        max_batch_prefill_tokens,
        max_batch_total_tokens,
        max_waiting_tokens,
        max_batch_size,
        hostname,
        port,
        master_shard_uds_path,
        tokenizer_name,
        tokenizer_config_path,
        revision,
        validation_workers,
        json_output,
        otlp_endpoint,
        otlp_service_name,
        cors_allow_origin,
        ngrok,
        ngrok_authtoken,
        ngrok_edge,
        messages_api_enabled,
        disable_grammar_support,
        max_client_batch_size,
        disable_usage_stats,
        disable_crash_reports,
        command,
    } = args;

    let print_schema_command = match command {
        Some(Commands::PrintSchema) => true,
        None => {
            // only init logging if we are not running the print schema command
            init_logging(otlp_endpoint, otlp_service_name, json_output);
            false
        }
    };

    // Validate args
    if max_input_tokens >= max_total_tokens {
        return Err(RouterError::ArgumentValidation(
            "`max_input_tokens` must be < `max_total_tokens`".to_string(),
        ));
    }
    if max_input_tokens as u32 > max_batch_prefill_tokens {
        return Err(RouterError::ArgumentValidation(format!("`max_batch_prefill_tokens` must be >= `max_input_tokens`. Given: {max_batch_prefill_tokens} and {max_input_tokens}")));
    }

    if validation_workers == 0 {
        return Err(RouterError::ArgumentValidation(
            "`validation_workers` must be > 0".to_string(),
        ));
    }

    if let Some(ref max_batch_total_tokens) = max_batch_total_tokens {
        if max_batch_prefill_tokens > *max_batch_total_tokens {
            return Err(RouterError::ArgumentValidation(format!("`max_batch_prefill_tokens` must be <= `max_batch_total_tokens`. Given: {max_batch_prefill_tokens} and {max_batch_total_tokens}")));
        }
        if max_total_tokens as u32 > *max_batch_total_tokens {
            return Err(RouterError::ArgumentValidation(format!("`max_total_tokens` must be <= `max_batch_total_tokens`. Given: {max_total_tokens} and {max_batch_total_tokens}")));
        }
    }

    // CORS allowed origins
    // Map inside the Option, parse each String into a HeaderValue,
    // and finally convert the list into an AllowOrigin
    let cors_allow_origin: Option<AllowOrigin> = cors_allow_origin.map(|cors_allow_origin| {
        AllowOrigin::list(
            cors_allow_origin
                .iter()
                .map(|origin| origin.parse::<HeaderValue>().unwrap()),
        )
    });

    // Parse Huggingface hub token
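    // `HF_TOKEN` takes precedence; `HUGGING_FACE_HUB_TOKEN` is kept as a legacy fallback.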
    let authorization_token = std::env::var("HF_TOKEN")
        .or_else(|_| std::env::var("HUGGING_FACE_HUB_TOKEN"))
        .ok();

    // Tokenizer instance
    // This will only be used to validate payloads
    let local_path = Path::new(&tokenizer_name);

    // Shared API builder initialization
    let api_builder = || {
        let mut builder = ApiBuilder::new()
            .with_progress(false)
            .with_token(authorization_token);

        if let Ok(cache_dir) = std::env::var("HUGGINGFACE_HUB_CACHE") {
            builder = builder.with_cache_dir(cache_dir.into());
        }

        builder
    };

    // Decide if we need to use the API based on the revision and local path
    let use_api = revision.is_some() || !local_path.exists() || !local_path.is_dir();

    // Initialize API if needed
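    // Where model artifacts are resolved from: the Hub API, the local Hugging Face cache
    // (offline mode), or neither (files are read from `tokenizer_name` as a local directory).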
    #[derive(Clone)]
    enum Type {
        Api(Api),
        Cache(Cache),
        None,
    }
    let api = if use_api {
        if std::env::var("HF_HUB_OFFLINE") == Ok("1".to_string()) {
            let cache = std::env::var("HUGGINGFACE_HUB_CACHE")
                .map_err(|_| ())
                .map(|cache_dir| Cache::new(cache_dir.into()))
                .unwrap_or_else(|_| Cache::default());

            tracing::warn!("Offline mode active using cache defaults");
            Type::Cache(cache)
        } else {
            tracing::info!("Using the Hugging Face API");
            match api_builder().build() {
                Ok(api) => Type::Api(api),
                Err(_) => {
                    tracing::warn!("Unable to build the Hugging Face API");
                    Type::None
                }
            }
        }
    } else {
        Type::None
    };

    // Load tokenizer and model info
    let (
        tokenizer_filename,
        config_filename,
        tokenizer_config_filename,
        preprocessor_config_filename,
        processor_config_filename,
        model_info,
    ) = match api {
        Type::None => (
            Some(local_path.join("tokenizer.json")),
            Some(local_path.join("config.json")),
            Some(local_path.join("tokenizer_config.json")),
            Some(local_path.join("preprocessor_config.json")),
            Some(local_path.join("processor_config.json")),
            None,
        ),
        Type::Api(api) => {
            let api_repo = api.repo(Repo::with_revision(
                tokenizer_name.to_string(),
                RepoType::Model,
                revision.clone().unwrap_or_else(|| "main".to_string()),
            ));

            let tokenizer_filename = match api_repo.get("tokenizer.json").await {
                Ok(tokenizer_filename) => Some(tokenizer_filename),
                Err(_) => get_base_tokenizer(&api, &api_repo).await,
            };
            let config_filename = api_repo.get("config.json").await.ok();
            let tokenizer_config_filename = api_repo.get("tokenizer_config.json").await.ok();
            let preprocessor_config_filename = api_repo.get("preprocessor_config.json").await.ok();
            let processor_config_filename = api_repo.get("processor_config.json").await.ok();

            let model_info = if let Some(model_info) = get_model_info(&api_repo).await {
                Some(model_info)
            } else {
                tracing::warn!("Could not retrieve model info from the Hugging Face hub.");
                None
            };
            (
                tokenizer_filename,
                config_filename,
                tokenizer_config_filename,
                preprocessor_config_filename,
                processor_config_filename,
                model_info,
            )
        }
        Type::Cache(cache) => {
            let repo = cache.repo(Repo::with_revision(
                tokenizer_name.to_string(),
                RepoType::Model,
                revision.clone().unwrap_or_else(|| "main".to_string()),
            ));
            (
                repo.get("tokenizer.json"),
                repo.get("config.json"),
                repo.get("tokenizer_config.json"),
                repo.get("preprocessor_config.json"),
                repo.get("processor_config.json"),
                None,
            )
        }
    };
    let config: Option<Config> = config_filename.and_then(|filename| {
        std::fs::read_to_string(filename)
            .ok()
            .as_ref()
            .and_then(|c| {
                let config: Result<Config, _> = serde_json::from_str(c);
                if let Err(err) = &config {
                    tracing::warn!("Could not parse config {err:?}");
                }
                config.ok()
            })
    });
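    // Fall back to a minimal HubModelInfo (just the model id) when the Hub did not provide one.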
    let model_info = model_info.unwrap_or_else(|| HubModelInfo {
        model_id: tokenizer_name.to_string(),
        sha: None,
        pipeline_tag: None,
    });
    // Read the JSON contents of the file as an instance of 'HubTokenizerConfig'.
    let tokenizer_config: Option<HubTokenizerConfig> = if let Some(filename) = tokenizer_config_path
    {
        HubTokenizerConfig::from_file(filename)
    } else {
        tokenizer_config_filename.and_then(HubTokenizerConfig::from_file)
    };
    let tokenizer_config = tokenizer_config.unwrap_or_else(|| {
        tracing::warn!("Could not find tokenizer config locally and no API specified");
        HubTokenizerConfig::default()
    });
    let tokenizer_class = tokenizer_config.tokenizer_class.clone();

    let tokenizer: Option<Tokenizer> = tokenizer_filename.and_then(|filename| {
        let mut tokenizer = Tokenizer::from_file(filename).ok();
        if let Some(tokenizer) = &mut tokenizer {
            if let Some(class) = &tokenizer_config.tokenizer_class {
                if class == "LlamaTokenizer" || class == "LlamaTokenizerFast" {
                    if let Ok(post_processor) = create_post_processor(tokenizer, &tokenizer_config) {
                        tracing::info!("Overriding LlamaTokenizer with TemplateProcessing to follow python override defined in https://github.com/huggingface/transformers/blob/4aa17d00690b7f82c95bb2949ea57e22c35b4336/src/transformers/models/llama/tokenization_llama_fast.py#L203-L205");
                        tokenizer.with_post_processor(post_processor);
                    }
                }
            }
        }
        tokenizer
    });

    let preprocessor_config =
        preprocessor_config_filename.and_then(HubPreprocessorConfig::from_file);
    let processor_config = processor_config_filename
        .and_then(HubProcessorConfig::from_file)
        .unwrap_or_default();

    tracing::info!("Using config {config:?}");
    if tokenizer.is_none() {
        tracing::warn!("Could not find a fast tokenizer implementation for {tokenizer_name}");
        tracing::warn!("Rust input length validation and truncation is disabled");
    }

    // if pipeline-tag == text-generation we default to return_full_text = true
    let compat_return_full_text = match &model_info.pipeline_tag {
        None => {
            tracing::warn!("no pipeline tag found for model {tokenizer_name}");
            true
        }
        Some(pipeline_tag) => pipeline_tag.as_str() == "text-generation",
    };

    // Determine the server port based on the feature and environment variable.
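    // With the `google` feature enabled, `AIP_HTTP_PORT` (set by Vertex AI for custom
    // prediction containers) overrides `--port` when it is present and parses as a u16.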
    let port = if cfg!(feature = "google") {
        std::env::var("AIP_HTTP_PORT")
            .map(|aip_http_port| aip_http_port.parse::<u16>().unwrap_or(port))
            .unwrap_or(port)
    } else {
        port
    };

    let addr = match hostname.parse() {
        Ok(ip) => SocketAddr::new(ip, port),
        Err(_) => {
            tracing::warn!("Invalid hostname, defaulting to 0.0.0.0");
            SocketAddr::new(IpAddr::V4(Ipv4Addr::new(0, 0, 0, 0)), port)
        }
    };

    // Only build a user agent (and send usage stats) when TGI runs inside a container
    // and usage stats have not been disabled
    let is_container = matches!(usage_stats::is_container(), Ok(true));

    let user_agent = if !disable_usage_stats && is_container {
        let reduced_args = usage_stats::Args::new(
            config.clone(),
            tokenizer_class,
            max_concurrent_requests,
            max_best_of,
            max_stop_sequences,
            max_top_n_tokens,
            max_input_tokens,
            max_total_tokens,
            waiting_served_ratio,
            max_batch_prefill_tokens,
            max_batch_total_tokens,
            max_waiting_tokens,
            max_batch_size,
            revision,
            validation_workers,
            messages_api_enabled,
            disable_grammar_support,
            max_client_batch_size,
            disable_usage_stats,
            disable_crash_reports,
        );
        Some(usage_stats::UserAgent::new(reduced_args))
    } else {
        None
    };

    if let Some(ref ua) = user_agent {
        let start_event =
            usage_stats::UsageStatsEvent::new(ua.clone(), usage_stats::EventType::Start, None);
        tokio::spawn(async move {
            start_event.send().await;
        });
    };

    // Run server
    let result = server::run(
        master_shard_uds_path,
        model_info,
        compat_return_full_text,
        max_concurrent_requests,
        max_best_of,
        max_stop_sequences,
        max_top_n_tokens,
        max_input_tokens,
        max_total_tokens,
        waiting_served_ratio,
        max_batch_prefill_tokens,
        max_batch_total_tokens,
        max_waiting_tokens,
        max_batch_size,
        tokenizer,
        config,
        validation_workers,
        addr,
        cors_allow_origin,
        ngrok,
        ngrok_authtoken,
        ngrok_edge,
        tokenizer_config,
        preprocessor_config,
        processor_config,
        messages_api_enabled,
        disable_grammar_support,
        max_client_batch_size,
        print_schema_command,
    )
    .await;

    match result {
        Ok(_) => {
            if let Some(ref ua) = user_agent {
                let stop_event = usage_stats::UsageStatsEvent::new(
                    ua.clone(),
                    usage_stats::EventType::Stop,
                    None,
                );
                stop_event.send().await;
            };
            Ok(())
        }
        Err(e) => {
            if let Some(ref ua) = user_agent {
                if !disable_crash_reports {
                    let error_event = usage_stats::UsageStatsEvent::new(
                        ua.clone(),
                        usage_stats::EventType::Error,
                        Some(e.to_string()),
                    );
                    error_event.send().await;
                } else {
                    let unknown_error_event = usage_stats::UsageStatsEvent::new(
                        ua.clone(),
                        usage_stats::EventType::Error,
                        Some("unknown_error".to_string()),
                    );
                    unknown_error_event.send().await;
                }
            };
            Err(RouterError::WebServer(e))
        }
    }
}

/// Init logging using env variables LOG_LEVEL and LOG_FORMAT:
///     - otlp_endpoint is an optional URL to an OpenTelemetry collector
///     - otlp_service_name is the service name to appear in APM
///     - LOG_LEVEL may be TRACE, DEBUG, INFO, WARN or ERROR (defaults to INFO)
///     - LOG_FORMAT may be TEXT or JSON (defaults to TEXT)
///     - LOG_COLORIZE may be "false" or "true" (defaults to "true" on ANSI-supporting platforms)
fn init_logging(otlp_endpoint: Option<String>, otlp_service_name: String, json_output: bool) {
    let mut layers = Vec::new();

    // STDOUT/STDERR layer
    let ansi = std::env::var("LOG_COLORIZE") != Ok("1".to_string());
    let fmt_layer = tracing_subscriber::fmt::layer()
        .with_file(true)
        .with_ansi(ansi)
        .with_line_number(true);

    let fmt_layer = match json_output {
        true => fmt_layer.json().flatten_event(true).boxed(),
        false => fmt_layer.boxed(),
    };
    layers.push(fmt_layer);

    // OpenTelemetry tracing layer
    if let Some(otlp_endpoint) = otlp_endpoint {
        global::set_text_map_propagator(TraceContextPropagator::new());

        let tracer = opentelemetry_otlp::new_pipeline()
            .tracing()
            .with_exporter(
                opentelemetry_otlp::new_exporter()
                    .tonic()
                    .with_endpoint(otlp_endpoint),
            )
            .with_trace_config(
                trace::config()
                    .with_resource(Resource::new(vec![KeyValue::new(
                        "service.name",
                        otlp_service_name,
                    )]))
                    .with_sampler(Sampler::AlwaysOn),
            )
            .install_batch(opentelemetry::runtime::Tokio);

        if let Ok(tracer) = tracer {
            layers.push(tracing_opentelemetry::layer().with_tracer(tracer).boxed());
            init_tracing_opentelemetry::init_propagator().unwrap();
        };
    }

    // Filter events with LOG_LEVEL
    let varname = "LOG_LEVEL";
    let env_filter = if let Ok(log_level) = std::env::var(varname) {
        // Override to avoid simple logs being spammed with tokio-level information
        let log_level = match &log_level[..] {
            "warn" => "text_generation_launcher=warn,text_generation_router=warn",
            "info" => "text_generation_launcher=info,text_generation_router=info",
            "debug" => "text_generation_launcher=debug,text_generation_router=debug",
            log_level => log_level,
        };
        EnvFilter::builder()
            .with_default_directive(LevelFilter::INFO.into())
            .parse_lossy(log_level)
    } else {
        EnvFilter::new("info")
    };

    tracing_subscriber::registry()
        .with(env_filter)
        .with(layers)
        .init();
}

/// Get model info from the Hugging Face Hub
pub async fn get_model_info(api: &ApiRepo) -> Option<HubModelInfo> {
    let response = api.info_request().send().await.ok()?;

    if response.status().is_success() {
        let hub_model_info: HubModelInfo =
            serde_json::from_str(&response.text().await.ok()?).ok()?;
        if let Some(sha) = &hub_model_info.sha {
            tracing::info!(
                "Serving revision {sha} of model {}",
                hub_model_info.model_id
            );
        }
        Some(hub_model_info)
    } else {
        None
    }
}

/// Get the tokenizer of the base model referenced by `base_model_name_or_path` in config.json
pub async fn get_base_tokenizer(api: &Api, api_repo: &ApiRepo) -> Option<PathBuf> {
    let config_filename = api_repo.get("config.json").await.ok()?;

    // Open the file in read-only mode with buffer.
    let file = File::open(config_filename).ok()?;
    let reader = BufReader::new(file);

    // Read the JSON contents of the file as a generic `serde_json::Value`.
    let config: serde_json::Value = serde_json::from_reader(reader).ok()?;

    if let Some(serde_json::Value::String(base_model_id)) = config.get("base_model_name_or_path") {
        let api_base_repo = api.repo(Repo::with_revision(
            base_model_id.to_string(),
            RepoType::Model,
            "main".to_string(),
        ));

        api_base_repo.get("tokenizer.json").await.ok()
    } else {
        None
    }
}

/// Get the tokenizer config from the Hugging Face Hub
pub async fn get_tokenizer_config(api_repo: &ApiRepo) -> Option<HubTokenizerConfig> {
    let tokenizer_config_filename = api_repo.get("tokenizer_config.json").await.ok()?;

    // Open the file in read-only mode with buffer.
    let file = File::open(tokenizer_config_filename).ok()?;
    let reader = BufReader::new(file);

    // Read the JSON contents of the file as an instance of 'HubTokenizerConfig'.
    let tokenizer_config: HubTokenizerConfig = serde_json::from_reader(reader)
        .map_err(|e| {
            tracing::warn!("Unable to parse tokenizer config: {}", e);
            e
        })
        .ok()?;

    Some(tokenizer_config)
}

/// Create a post_processor for the LlamaTokenizer
pub fn create_post_processor(
    tokenizer: &Tokenizer,
    tokenizer_config: &HubTokenizerConfig,
) -> Result<TemplateProcessing, tokenizers::processors::template::TemplateProcessingBuilderError> {
    let add_bos_token = tokenizer_config.add_bos_token.unwrap_or(true);
    let add_eos_token = tokenizer_config.add_eos_token.unwrap_or(false);

    let bos_token = tokenizer_config.bos_token.as_ref();
    let eos_token = tokenizer_config.eos_token.as_ref();

    if add_bos_token && bos_token.is_none() {
        panic!("add_bos_token = true but bos_token is None");
    }

    if add_eos_token && eos_token.is_none() {
        panic!("add_eos_token = true but eos_token is None");
    }

    let mut single = Vec::new();
    let mut pair = Vec::new();
    let mut special_tokens = Vec::new();

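    // Build the single-sequence and pair templates, mirroring the Python `LlamaTokenizerFast`
    // behaviour, e.g. "<s>:0 $A:0" for one sequence and "<s>:0 $A:0 <s>:1 $B:1" for a pair.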
    if add_bos_token {
        if let Some(bos) = bos_token {
            let bos_token_id = tokenizer
                .token_to_id(bos.as_str())
                .expect("Should have found the bos token id");
            special_tokens.push((bos.as_str(), bos_token_id));
            single.push(format!("{}:0", bos.as_str()));
            pair.push(format!("{}:0", bos.as_str()));
        }
    }

    single.push("$A:0".to_string());
    pair.push("$A:0".to_string());

    if add_eos_token {
        if let Some(eos) = eos_token {
            let eos_token_id = tokenizer
                .token_to_id(eos.as_str())
                .expect("Should have found the eos token id");
            special_tokens.push((eos.as_str(), eos_token_id));
            single.push(format!("{}:0", eos.as_str()));
            pair.push(format!("{}:0", eos.as_str()));
        }
    }

    if add_bos_token {
        if let Some(bos) = bos_token {
            pair.push(format!("{}:1", bos.as_str()));
        }
    }

    pair.push("$B:1".to_string());

    if add_eos_token {
        if let Some(eos) = eos_token {
            pair.push(format!("{}:1", eos.as_str()));
        }
    }

    let post_processor = TemplateProcessing::builder()
        .try_single(single)?
        .try_pair(pair)?
        .special_tokens(special_tokens)
        .build()?;

    Ok(post_processor)
}

#[derive(Debug, Error)]
enum RouterError {
    #[error("Argument validation error: {0}")]
    ArgumentValidation(String),
    #[error("WebServer error: {0}")]
    WebServer(#[from] server::WebServerError),
    #[error("Tokio runtime failed to start: {0}")]
    Tokio(#[from] std::io::Error),
}

#[cfg(test)]
mod tests {
    use super::*;
    use text_generation_router::TokenizerConfigToken;

    #[test]
    fn test_create_post_processor() {
        let tokenizer_config = HubTokenizerConfig {
            add_bos_token: None,
            add_eos_token: None,
            bos_token: Some(TokenizerConfigToken::String("<s>".to_string())),
            eos_token: Some(TokenizerConfigToken::String("</s>".to_string())),
            chat_template: None,
            tokenizer_class: None,
            completion_template: None,
        };

        let tokenizer =
            Tokenizer::from_pretrained("TinyLlama/TinyLlama-1.1B-Chat-v1.0", None).unwrap();
        let post_processor = create_post_processor(&tokenizer, &tokenizer_config).unwrap();

        let expected = TemplateProcessing::builder()
            .try_single("<s>:0 $A:0")
            .unwrap()
            .try_pair("<s>:0 $A:0 <s>:1 $B:1")
            .unwrap()
            .special_tokens(vec![("<s>".to_string(), 1)])
            .build()
            .unwrap();

        assert_eq!(post_processor, expected);
    }
}