use axum::http::HeaderValue;
use clap::Parser;
use hf_hub::api::tokio::{Api, ApiBuilder, ApiRepo};
use hf_hub::{Cache, Repo, RepoType};
use opentelemetry::sdk::propagation::TraceContextPropagator;
use opentelemetry::sdk::trace;
use opentelemetry::sdk::trace::Sampler;
use opentelemetry::sdk::Resource;
use opentelemetry::{global, KeyValue};
use opentelemetry_otlp::WithExportConfig;
use std::fs::File;
use std::io::BufReader;
use std::net::{IpAddr, Ipv4Addr, SocketAddr};
use std::path::{Path, PathBuf};
use text_generation_router::config::Config;
use text_generation_router::{
    server, HubModelInfo, HubPreprocessorConfig, HubProcessorConfig, HubTokenizerConfig,
};
use thiserror::Error;
use tokenizers::{processors::template::TemplateProcessing, Tokenizer};
use tower_http::cors::AllowOrigin;
use tracing_subscriber::layer::SubscriberExt;
use tracing_subscriber::util::SubscriberInitExt;
use tracing_subscriber::{filter::LevelFilter, EnvFilter, Layer};

/// App Configuration
#[derive(Parser, Debug)]
#[clap(author, version, about, long_about = None)]
struct Args {
    #[clap(default_value = "128", long, env)]
    max_concurrent_requests: usize,
    #[clap(default_value = "2", long, env)]
    max_best_of: usize,
    #[clap(default_value = "4", long, env)]
    max_stop_sequences: usize,
    #[clap(default_value = "5", long, env)]
    max_top_n_tokens: u32,
    #[clap(default_value = "1024", long, env)]
    max_input_tokens: usize,
    #[clap(default_value = "2048", long, env)]
    max_total_tokens: usize,
    #[clap(default_value = "1.2", long, env)]
    waiting_served_ratio: f32,
    #[clap(default_value = "4096", long, env)]
    max_batch_prefill_tokens: u32,
    #[clap(long, env)]
    max_batch_total_tokens: Option<u32>,
    #[clap(default_value = "20", long, env)]
    max_waiting_tokens: usize,
    #[clap(long, env)]
    max_batch_size: Option<usize>,
    #[clap(default_value = "0.0.0.0", long, env)]
    hostname: String,
    #[clap(default_value = "3000", long, short, env)]
    port: u16,
    #[clap(default_value = "/tmp/text-generation-server-0", long, env)]
    master_shard_uds_path: String,
    #[clap(default_value = "bigscience/bloom", long, env)]
    tokenizer_name: String,
    #[clap(long, env)]
    tokenizer_config_path: Option<String>,
    #[clap(long, env)]
    revision: Option<String>,
    #[clap(default_value = "2", long, env)]
    validation_workers: usize,
    #[clap(long, env)]
    json_output: bool,
    #[clap(long, env)]
    otlp_endpoint: Option<String>,
    #[clap(default_value = "text-generation-inference.router", long, env)]
    otlp_service_name: String,
    #[clap(long, env)]
    cors_allow_origin: Option<Vec<String>>,
    #[clap(long, env)]
    ngrok: bool,
    #[clap(long, env)]
    ngrok_authtoken: Option<String>,
    #[clap(long, env)]
    ngrok_edge: Option<String>,
    #[clap(long, env, default_value_t = false)]
    messages_api_enabled: bool,
    #[clap(long, env, default_value_t = false)]
    disable_grammar_support: bool,
    #[clap(default_value = "4", long, env)]
    max_client_batch_size: usize,
}

#[tokio::main]
async fn main() -> Result<(), RouterError> {
    // Get args
    let args = Args::parse();
    // Pattern match configuration
    let Args {
        max_concurrent_requests,
        max_best_of,
        max_stop_sequences,
        max_top_n_tokens,
        max_input_tokens,
        max_total_tokens,
        waiting_served_ratio,
        max_batch_prefill_tokens,
        max_batch_total_tokens,
        max_waiting_tokens,
        max_batch_size,
        hostname,
        port,
        master_shard_uds_path,
        tokenizer_name,
        tokenizer_config_path,
        revision,
        validation_workers,
        json_output,
        otlp_endpoint,
        otlp_service_name,
        cors_allow_origin,
        ngrok,
        ngrok_authtoken,
        ngrok_edge,
        messages_api_enabled,
        disable_grammar_support,
        max_client_batch_size,
    } = args;

    // Init logging and OpenTelemetry tracing
    init_logging(otlp_endpoint, otlp_service_name, json_output);

    // Validate args
    if max_input_tokens >= max_total_tokens {
        return Err(RouterError::ArgumentValidation(
            "`max_input_tokens` must be < `max_total_tokens`".to_string(),
        ));
    }
    if max_input_tokens as u32 > max_batch_prefill_tokens {
        return Err(RouterError::ArgumentValidation(format!("`max_batch_prefill_tokens` must be >= `max_input_tokens`. Given: {max_batch_prefill_tokens} and {max_input_tokens}")));
    }

    if validation_workers == 0 {
        return Err(RouterError::ArgumentValidation(
            "`validation_workers` must be > 0".to_string(),
        ));
    }

    if let Some(ref max_batch_total_tokens) = max_batch_total_tokens {
        if max_batch_prefill_tokens > *max_batch_total_tokens {
            return Err(RouterError::ArgumentValidation(format!("`max_batch_prefill_tokens` must be <= `max_batch_total_tokens`. Given: {max_batch_prefill_tokens} and {max_batch_total_tokens}")));
        }
        if max_total_tokens as u32 > *max_batch_total_tokens {
            return Err(RouterError::ArgumentValidation(format!("`max_total_tokens` must be <= `max_batch_total_tokens`. Given: {max_total_tokens} and {max_batch_total_tokens}")));
        }
    }

    // CORS allowed origins
    // Map into the Option, parse each String origin into a HeaderValue,
    // then collect the results into an AllowOrigin list
    let cors_allow_origin: Option<AllowOrigin> = cors_allow_origin.map(|cors_allow_origin| {
        AllowOrigin::list(
            cors_allow_origin
                .iter()
                .map(|origin| origin.parse::<HeaderValue>().unwrap()),
        )
    });

    // Parse Huggingface hub token
    let authorization_token = std::env::var("HF_TOKEN")
        .or_else(|_| std::env::var("HUGGING_FACE_HUB_TOKEN"))
        .ok();

    // Tokenizer instance
    // This will only be used to validate payloads
    let local_path = Path::new(&tokenizer_name);

    // Shared API builder initialization
    let api_builder = || {
        let mut builder = ApiBuilder::new()
            .with_progress(false)
            .with_token(authorization_token);

        if let Ok(cache_dir) = std::env::var("HUGGINGFACE_HUB_CACHE") {
            builder = builder.with_cache_dir(cache_dir.into());
        }

        builder
    };

    // Decide if we need to use the API based on the revision and local path
    let use_api = revision.is_some() || !local_path.exists() || !local_path.is_dir();

    // Initialize API if needed
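    // Where model metadata can come from: the live Hub API, the local Hub
    // cache (offline mode), or neither when a plain local path is used.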
    #[derive(Clone)]
    enum Type {
        Api(Api),
        Cache(Cache),
        None,
    }
    let api = if use_api {
        if std::env::var("HF_HUB_OFFLINE") == Ok("1".to_string()) {
            let cache = Cache::default();
            tracing::warn!("Offline mode active using cache defaults");
            Type::Cache(cache)
        } else {
            tracing::info!("Using the Hugging Face API");
            match api_builder().build() {
                Ok(api) => Type::Api(api),
                Err(_) => {
                    tracing::warn!("Unable to build the Hugging Face API");
                    Type::None
                }
            }
        }
    } else {
        Type::None
    };

    // Load tokenizer and model info
    let (
        tokenizer_filename,
        config_filename,
        tokenizer_config_filename,
        preprocessor_config_filename,
        processor_config_filename,
        model_info,
    ) = match api {
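        // Local path: read every file directly from disk; no hub model info is available.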
        Type::None => (
            Some(local_path.join("tokenizer.json")),
            Some(local_path.join("config.json")),
            Some(local_path.join("tokenizer_config.json")),
            Some(local_path.join("preprocessor_config.json")),
            Some(local_path.join("processor_config.json")),
            None,
        ),
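        // Hub API: download each file for the pinned revision (defaulting to main).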
        Type::Api(api) => {
            let api_repo = api.repo(Repo::with_revision(
                tokenizer_name.to_string(),
                RepoType::Model,
                revision.clone().unwrap_or_else(|| "main".to_string()),
            ));

            let tokenizer_filename = match api_repo.get("tokenizer.json").await {
                Ok(tokenizer_filename) => Some(tokenizer_filename),
                Err(_) => get_base_tokenizer(&api, &api_repo).await,
            };
            let config_filename = api_repo.get("config.json").await.ok();
            let tokenizer_config_filename = api_repo.get("tokenizer_config.json").await.ok();
            let preprocessor_config_filename = api_repo.get("preprocessor_config.json").await.ok();
            let processor_config_filename = api_repo.get("processor_config.json").await.ok();

            let model_info = if let Some(model_info) = get_model_info(&api_repo).await {
                Some(model_info)
            } else {
                tracing::warn!("Could not retrieve model info from the Hugging Face hub.");
                None
            };
            (
                tokenizer_filename,
                config_filename,
                tokenizer_config_filename,
                preprocessor_config_filename,
                processor_config_filename,
                model_info,
            )
        }
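        // Offline: resolve the files from the local Hub cache; no hub model info is available.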
        Type::Cache(cache) => {
            let repo = cache.repo(Repo::with_revision(
                tokenizer_name.to_string(),
                RepoType::Model,
                revision.clone().unwrap_or_else(|| "main".to_string()),
            ));
            (
                repo.get("tokenizer.json"),
                repo.get("config.json"),
                repo.get("tokenizer_config.json"),
                repo.get("preprocessor_config.json"),
                repo.get("processor_config.json"),
                None,
            )
        }
    };
    let config: Option<Config> = config_filename.and_then(|filename| {
        std::fs::read_to_string(filename)
            .ok()
            .as_ref()
            .and_then(|c| {
                let config: Result<Config, _> = serde_json::from_str(c);
                if let Err(err) = &config {
                    tracing::warn!("Could not parse config {err:?}");
                }
                config.ok()
            })
    });
    let model_info = model_info.unwrap_or_else(|| HubModelInfo {
        model_id: tokenizer_name.to_string(),
        sha: None,
        pipeline_tag: None,
    });

    // Read the JSON contents of the file as an instance of 'HubTokenizerConfig'.
    let tokenizer_config: Option<HubTokenizerConfig> = if let Some(filename) = tokenizer_config_path
    {
        HubTokenizerConfig::from_file(filename)
    } else {
        tokenizer_config_filename.and_then(HubTokenizerConfig::from_file)
    };
    let tokenizer_config = tokenizer_config.unwrap_or_else(|| {
        tracing::warn!("Could not find tokenizer config locally and no API specified");
        HubTokenizerConfig::default()
    });

    let tokenizer: Option<Tokenizer> = tokenizer_filename.and_then(|filename| {
        let mut tokenizer = Tokenizer::from_file(filename).ok();
        if let Some(tokenizer) = &mut tokenizer {
            if let Some(class) = &tokenizer_config.tokenizer_class {
                if (class == "LlamaTokenizer" || class == "LlamaTokenizerFast")
                    && tokenizer.get_post_processor().is_none()
                {
                    if let Ok(post_processor) = create_post_processor(tokenizer, &tokenizer_config) {
                        tracing::info!("Overriding LlamaTokenizer with TemplateProcessing to follow python override defined in https://github.com/huggingface/transformers/blob/4aa17d00690b7f82c95bb2949ea57e22c35b4336/src/transformers/models/llama/tokenization_llama_fast.py#L203-L205");
                        tokenizer.with_post_processor(post_processor);
                    }
                }
            }
        }
        tokenizer
    });

    let preprocessor_config =
        preprocessor_config_filename.and_then(HubPreprocessorConfig::from_file);
    let processor_config = processor_config_filename
        .and_then(HubProcessorConfig::from_file)
        .unwrap_or_default();

    tracing::info!("Using config {config:?}");
    if tokenizer.is_none() {
        tracing::warn!("Could not find a fast tokenizer implementation for {tokenizer_name}");
        tracing::warn!("Rust input length validation and truncation is disabled");
    }

    // if pipeline-tag == text-generation we default to return_full_text = true
    let compat_return_full_text = match &model_info.pipeline_tag {
        None => {
            tracing::warn!("no pipeline tag found for model {tokenizer_name}");
            true
        }
        Some(pipeline_tag) => pipeline_tag.as_str() == "text-generation",
    };

    // Determine the server port based on the feature and environment variable.
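    // Under the `google` feature (Vertex AI deployments), the platform injects the
    // serving port through the AIP_HTTP_PORT environment variable.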
    let port = if cfg!(feature = "google") {
        std::env::var("AIP_HTTP_PORT")
            .map(|aip_http_port| aip_http_port.parse::<u16>().unwrap_or(port))
            .unwrap_or(port)
    } else {
        port
    };

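    // Resolve the bind address, falling back to 0.0.0.0 when the hostname does not parse.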
    let addr = match hostname.parse() {
        Ok(ip) => SocketAddr::new(ip, port),
        Err(_) => {
            tracing::warn!("Invalid hostname, defaulting to 0.0.0.0");
            SocketAddr::new(IpAddr::V4(Ipv4Addr::new(0, 0, 0, 0)), port)
        }
    };

    // Run server
    server::run(
        master_shard_uds_path,
        model_info,
        compat_return_full_text,
        max_concurrent_requests,
        max_best_of,
        max_stop_sequences,
        max_top_n_tokens,
        max_input_tokens,
        max_total_tokens,
        waiting_served_ratio,
        max_batch_prefill_tokens,
        max_batch_total_tokens,
        max_waiting_tokens,
        max_batch_size,
        tokenizer,
        config,
        validation_workers,
        addr,
        cors_allow_origin,
        ngrok,
        ngrok_authtoken,
        ngrok_edge,
        tokenizer_config,
        preprocessor_config,
        processor_config,
        messages_api_enabled,
        disable_grammar_support,
        max_client_batch_size,
    )
    .await?;
    Ok(())
}

/// Init logging using env variables LOG_LEVEL and LOG_COLORIZE:
///     - otlp_endpoint is an optional URL to an Open Telemetry collector
///     - otlp_service_name service name to appear in APM
///     - LOG_LEVEL may be TRACE, DEBUG, INFO, WARN or ERROR (default to INFO)
///     - json_output switches the log format from human-readable text to flattened JSON
///     - LOG_COLORIZE may be "false" or "true" (default to "true" on ANSI-supported platforms)
fn init_logging(otlp_endpoint: Option<String>, otlp_service_name: String, json_output: bool) {
    let mut layers = Vec::new();

    // STDOUT/STDERR layer
    let ansi = std::env::var("LOG_COLORIZE") != Ok("false".to_string());
    let fmt_layer = tracing_subscriber::fmt::layer()
        .with_file(true)
        .with_ansi(ansi)
        .with_line_number(true);

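    // Emit flattened JSON events instead of human-readable text when --json-output is set.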
    let fmt_layer = match json_output {
        true => fmt_layer.json().flatten_event(true).boxed(),
        false => fmt_layer.boxed(),
    };
    layers.push(fmt_layer);

    // OpenTelemetry tracing layer
    if let Some(otlp_endpoint) = otlp_endpoint {
        global::set_text_map_propagator(TraceContextPropagator::new());

        let tracer = opentelemetry_otlp::new_pipeline()
            .tracing()
            .with_exporter(
                opentelemetry_otlp::new_exporter()
                    .tonic()
                    .with_endpoint(otlp_endpoint),
            )
            .with_trace_config(
                trace::config()
                    .with_resource(Resource::new(vec![KeyValue::new(
                        "service.name",
                        otlp_service_name,
                    )]))
                    .with_sampler(Sampler::AlwaysOn),
            )
            .install_batch(opentelemetry::runtime::Tokio);

        if let Ok(tracer) = tracer {
            layers.push(tracing_opentelemetry::layer().with_tracer(tracer).boxed());
            init_tracing_opentelemetry::init_propagator().unwrap();
        };
    }

    // Filter events with LOG_LEVEL
    let varname = "LOG_LEVEL";
    let env_filter = if let Ok(log_level) = std::env::var(varname) {
        // Override to avoid simple logs being spammed with tokio-level information
        let log_level = match &log_level[..] {
            "warn" => "text_generation_launcher=warn,text_generation_router=warn",
            "info" => "text_generation_launcher=info,text_generation_router=info",
            "debug" => "text_generation_launcher=debug,text_generation_router=debug",
            log_level => log_level,
        };
        EnvFilter::builder()
            .with_default_directive(LevelFilter::INFO.into())
            .parse_lossy(log_level)
    } else {
        EnvFilter::new("info")
    };

    tracing_subscriber::registry()
        .with(env_filter)
        .with(layers)
        .init();
}

/// get model info from the Huggingface Hub
pub async fn get_model_info(api: &ApiRepo) -> Option<HubModelInfo> {
    let response = api.info_request().send().await.ok()?;

    if response.status().is_success() {
        let hub_model_info: HubModelInfo =
            serde_json::from_str(&response.text().await.ok()?).ok()?;
        if let Some(sha) = &hub_model_info.sha {
            tracing::info!(
                "Serving revision {sha} of model {}",
                hub_model_info.model_id
            );
        }
        Some(hub_model_info)
    } else {
        None
    }
}

/// get base tokenizer
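///
/// When a repository ships no `tokenizer.json` of its own, fall back to the tokenizer
/// of the model named by `base_model_name_or_path` in its `config.json`.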
pub async fn get_base_tokenizer(api: &Api, api_repo: &ApiRepo) -> Option<PathBuf> {
    let config_filename = api_repo.get("config.json").await.ok()?;

    // Open the file in read-only mode with buffer.
    let file = File::open(config_filename).ok()?;
    let reader = BufReader::new(file);

    // Read the JSON contents of the file as a generic JSON value.
    let config: serde_json::Value = serde_json::from_reader(reader).ok()?;

    if let Some(serde_json::Value::String(base_model_id)) = config.get("base_model_name_or_path") {
        let api_base_repo = api.repo(Repo::with_revision(
            base_model_id.to_string(),
            RepoType::Model,
            "main".to_string(),
        ));

        api_base_repo.get("tokenizer.json").await.ok()
    } else {
        None
    }
}

/// get tokenizer_config from the Huggingface Hub
pub async fn get_tokenizer_config(api_repo: &ApiRepo) -> Option<HubTokenizerConfig> {
    let tokenizer_config_filename = api_repo.get("tokenizer_config.json").await.ok()?;

    // Open the file in read-only mode with buffer.
    let file = File::open(tokenizer_config_filename).ok()?;
    let reader = BufReader::new(file);

    // Read the JSON contents of the file as an instance of 'HubTokenizerConfig'.
    let tokenizer_config: HubTokenizerConfig = serde_json::from_reader(reader)
        .map_err(|e| {
            tracing::warn!("Unable to parse tokenizer config: {}", e);
            e
        })
        .ok()?;

    Some(tokenizer_config)
}

/// Create a post_processor for the LlamaTokenizer
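///
/// Mirrors the python override in `transformers`' `LlamaTokenizerFast` by rebuilding the
/// `TemplateProcessing` post-processor from `add_bos_token`/`add_eos_token`. With the
/// defaults (`add_bos_token = true`, `add_eos_token = false`) and `bos_token = "<s>"`, the
/// resulting single-sequence template is `<s>:0 $A:0 <s>:1` (see the test below).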
pub fn create_post_processor(
    tokenizer: &Tokenizer,
    tokenizer_config: &HubTokenizerConfig,
) -> Result<TemplateProcessing, tokenizers::processors::template::TemplateProcessingBuilderError> {
    let add_bos_token = tokenizer_config.add_bos_token.unwrap_or(true);
    let add_eos_token = tokenizer_config.add_eos_token.unwrap_or(false);

    let bos_token = tokenizer_config.bos_token.as_ref();
    let eos_token = tokenizer_config.eos_token.as_ref();

    if add_bos_token && bos_token.is_none() {
        panic!("add_bos_token = true but bos_token is None");
    }

    if add_eos_token && eos_token.is_none() {
        panic!("add_eos_token = true but eos_token is None");
    }

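    // `single` is the template applied to a lone sequence, `pair` to a sequence pair;
    // the `:0`/`:1` suffixes are the sequence ids assigned to each token.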
    let mut single = Vec::new();
    let mut pair = Vec::new();
    let mut special_tokens = Vec::new();

    if add_bos_token {
        if let Some(bos) = bos_token {
            let bos_token_id = tokenizer
                .token_to_id(bos)
                .expect("Should have found the bos token id");
            special_tokens.push((bos.clone(), bos_token_id));
            single.push(format!("{}:0", bos));
            pair.push(format!("{}:0", bos));
        }
    }

    single.push("$A:0".to_string());
    pair.push("$A:0".to_string());

    if add_eos_token {
        if let Some(eos) = eos_token {
            let eos_token_id = tokenizer
                .token_to_id(eos)
                .expect("Should have found the eos token id");
            special_tokens.push((eos.clone(), eos_token_id));
            single.push(format!("{}:0", eos));
            pair.push(format!("{}:0", eos));
        }
    }

    if add_bos_token {
        if let Some(bos) = bos_token {
            single.push(format!("{}:1", bos));
        }
    }

    pair.push("$B:1".to_string());

    if add_eos_token {
        if let Some(eos) = eos_token {
            pair.push(format!("{}:1", eos));
        }
    }

    let post_processor = TemplateProcessing::builder()
        .try_single(single)?
        .try_pair(pair)?
        .special_tokens(special_tokens)
        .build()?;

    Ok(post_processor)
}

#[derive(Debug, Error)]
enum RouterError {
    #[error("Argument validation error: {0}")]
    ArgumentValidation(String),
    #[error("WebServer error: {0}")]
    WebServer(#[from] server::WebServerError),
    #[error("Tokio runtime failed to start: {0}")]
    Tokio(#[from] std::io::Error),
}

#[cfg(test)]
mod tests {
    use super::*;

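    // Note: `Tokenizer::from_pretrained` downloads the tokenizer from the Hub,
    // so this test requires network access.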
    #[test]
    fn test_create_post_processor() {
        let tokenizer_config = HubTokenizerConfig {
            add_bos_token: None,
            add_eos_token: None,
            bos_token: Some("<s>".to_string()),
            eos_token: Some("</s>".to_string()),
            chat_template: None,
            tokenizer_class: None,
            completion_template: None,
        };

        let tokenizer =
            Tokenizer::from_pretrained("TinyLlama/TinyLlama-1.1B-Chat-v1.0", None).unwrap();
        let post_processor = create_post_processor(&tokenizer, &tokenizer_config).unwrap();

        let expected = TemplateProcessing::builder()
            .try_single("<s>:0 $A:0 <s>:1")
            .unwrap()
            .try_pair("<s>:0 $A:0 $B:1")
            .unwrap()
            .special_tokens(vec![("<s>".to_string(), 1)])
            .build()
            .unwrap();

        assert_eq!(post_processor, expected);
    }
}