use axum::http::HeaderValue;
use clap::Parser;
use clap::Subcommand;
use hf_hub::api::tokio::{Api, ApiBuilder, ApiRepo};
use hf_hub::{Cache, Repo, RepoType};
use opentelemetry::sdk::propagation::TraceContextPropagator;
use opentelemetry::sdk::trace;
use opentelemetry::sdk::trace::Sampler;
use opentelemetry::sdk::Resource;
use opentelemetry::{global, KeyValue};
use opentelemetry_otlp::WithExportConfig;
use std::fs::File;
use std::io::BufReader;
use std::net::{IpAddr, Ipv4Addr, SocketAddr};
use std::net::{IpAddr, Ipv4Addr, SocketAddr};
use std::path::{Path, PathBuf};
use text_generation_router::config::Config;
use text_generation_router::{
    server, HubModelInfo, HubPreprocessorConfig, HubProcessorConfig, HubTokenizerConfig,
};
use thiserror::Error;
use tokenizers::{processors::template::TemplateProcessing, Tokenizer};
use tower_http::cors::AllowOrigin;
use tracing_subscriber::layer::SubscriberExt;
use tracing_subscriber::util::SubscriberInitExt;
use tracing_subscriber::{filter::LevelFilter, EnvFilter, Layer};

/// App Configuration
#[derive(Parser, Debug)]
#[clap(author, version, about, long_about = None)]
struct Args {
    #[command(subcommand)]
    command: Option<Commands>,

    #[clap(default_value = "128", long, env)]
    max_concurrent_requests: usize,
    #[clap(default_value = "2", long, env)]
    max_best_of: usize,
    #[clap(default_value = "4", long, env)]
    max_stop_sequences: usize,
    #[clap(default_value = "5", long, env)]
    max_top_n_tokens: u32,
    #[clap(default_value = "1024", long, env)]
    max_input_tokens: usize,
    #[clap(default_value = "2048", long, env)]
    max_total_tokens: usize,
    #[clap(default_value = "1.2", long, env)]
    waiting_served_ratio: f32,
    #[clap(default_value = "4096", long, env)]
    max_batch_prefill_tokens: u32,
    #[clap(long, env)]
    max_batch_total_tokens: Option<u32>,
    #[clap(default_value = "20", long, env)]
    max_waiting_tokens: usize,
    #[clap(long, env)]
    max_batch_size: Option<usize>,
    #[clap(default_value = "0.0.0.0", long, env)]
    hostname: String,
    #[clap(default_value = "3000", long, short, env)]
    port: u16,
    #[clap(default_value = "/tmp/text-generation-server-0", long, env)]
    master_shard_uds_path: String,
    #[clap(default_value = "bigscience/bloom", long, env)]
    tokenizer_name: String,
    #[clap(long, env)]
    tokenizer_config_path: Option<String>,
    #[clap(long, env)]
    revision: Option<String>,
    #[clap(default_value = "2", long, env)]
    validation_workers: usize,
    #[clap(long, env)]
    json_output: bool,
    #[clap(long, env)]
    otlp_endpoint: Option<String>,
    #[clap(default_value = "text-generation-inference.router", long, env)]
    otlp_service_name: String,
    #[clap(long, env)]
    cors_allow_origin: Option<Vec<String>>,
    #[clap(long, env)]
    ngrok: bool,
    #[clap(long, env)]
    ngrok_authtoken: Option<String>,
    #[clap(long, env)]
    ngrok_edge: Option<String>,
    #[clap(long, env, default_value_t = false)]
    messages_api_enabled: bool,
    #[clap(long, env, default_value_t = false)]
    disable_grammar_support: bool,
    #[clap(default_value = "4", long, env)]
    max_client_batch_size: usize,
}
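
// Example invocation, assuming the binary is named `text-generation-router`
// (hypothetical values; clap derives the kebab-case long flags from the field
// names above, and each flag can also be set through the matching env var):
//
//   text-generation-router \
//       --tokenizer-name bigscience/bloom \
//       --max-input-tokens 1024 \
//       --max-total-tokens 2048 \
//       --port 3000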

#[derive(Debug, Subcommand)]
enum Commands {
    PrintSchema,
}
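
// `print-schema` (clap's kebab-case name for `Commands::PrintSchema`) skips the
// logging initialization below and is forwarded to `server::run`, which is
// expected to dump the API schema and exit rather than serve traffic.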

#[tokio::main]
async fn main() -> Result<(), RouterError> {
    let args = Args::parse();

    // Pattern match configuration
    let Args {
        max_concurrent_requests,
        max_best_of,
        max_stop_sequences,
        max_top_n_tokens,
        max_input_tokens,
        max_total_tokens,
        waiting_served_ratio,
        max_batch_prefill_tokens,
        max_batch_total_tokens,
        max_waiting_tokens,
        max_batch_size,
        hostname,
        port,
        master_shard_uds_path,
        tokenizer_name,
        tokenizer_config_path,
        revision,
        validation_workers,
        json_output,
        otlp_endpoint,
        otlp_service_name,
        cors_allow_origin,
        ngrok,
        ngrok_authtoken,
        ngrok_edge,
        messages_api_enabled,
        disable_grammar_support,
        max_client_batch_size,
        command,
    } = args;

    let print_schema_command = match command {
        Some(Commands::PrintSchema) => true,
        None => {
            // only init logging if we are not running the print schema command
            init_logging(otlp_endpoint, otlp_service_name, json_output);
            false
        }
    };

    // Validate args
    if max_input_tokens >= max_total_tokens {
        return Err(RouterError::ArgumentValidation(
            "`max_input_tokens` must be < `max_total_tokens`".to_string(),
        ));
    }
    if max_input_tokens as u32 > max_batch_prefill_tokens {
        return Err(RouterError::ArgumentValidation(format!("`max_batch_prefill_tokens` must be >= `max_input_tokens`. Given: {max_batch_prefill_tokens} and {max_input_tokens}")));
    }

    if validation_workers == 0 {
        return Err(RouterError::ArgumentValidation(
            "`validation_workers` must be > 0".to_string(),
        ));
    }

    if let Some(ref max_batch_total_tokens) = max_batch_total_tokens {
        if max_batch_prefill_tokens > *max_batch_total_tokens {
            return Err(RouterError::ArgumentValidation(format!("`max_batch_prefill_tokens` must be <= `max_batch_total_tokens`. Given: {max_batch_prefill_tokens} and {max_batch_total_tokens}")));
        }
        if max_total_tokens as u32 > *max_batch_total_tokens {
            return Err(RouterError::ArgumentValidation(format!("`max_total_tokens` must be <= `max_batch_total_tokens`. Given: {max_total_tokens} and {max_batch_total_tokens}")));
        }
    }
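
    // Taken together, the validations above enforce the token-budget ordering:
    //   max_input_tokens < max_total_tokens <= max_batch_total_tokens
    //   max_input_tokens <= max_batch_prefill_tokens <= max_batch_total_tokens
    // (the `max_batch_total_tokens` bounds apply only when that flag is set).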

    // CORS allowed origins
    // map to go inside the option and then map to parse from String to HeaderValue
    // Finally, convert to AllowOrigin
    let cors_allow_origin: Option<AllowOrigin> = cors_allow_origin.map(|cors_allow_origin| {
        AllowOrigin::list(
            cors_allow_origin
                .iter()
                .map(|origin| origin.parse::<HeaderValue>().unwrap()),
        )
    });
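
    // Note: `parse::<HeaderValue>().unwrap()` panics at startup on an invalid
    // origin, surfacing a bad `--cors-allow-origin` value immediately rather
    // than at request time.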

    // Parse Huggingface hub token
    let authorization_token = std::env::var("HF_TOKEN")
        .or_else(|_| std::env::var("HUGGING_FACE_HUB_TOKEN"))
        .ok();
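    // `HF_TOKEN` takes precedence; `HUGGING_FACE_HUB_TOKEN` is kept as a fallback
    // for older deployments.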

    // Tokenizer instance
    // This will only be used to validate payloads
    let local_path = Path::new(&tokenizer_name);

    // Shared API builder initialization
    let api_builder = || {
        let mut builder = ApiBuilder::new()
            .with_progress(false)
            .with_token(authorization_token);

        if let Ok(cache_dir) = std::env::var("HUGGINGFACE_HUB_CACHE") {
            builder = builder.with_cache_dir(cache_dir.into());
        }

        builder
    };

    // Decide if we need to use the API based on the revision and local path
    let use_api = revision.is_some() || !local_path.exists() || !local_path.is_dir();

    // Initialize API if needed
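    // Model artifacts can come from three places: the live Hub API, the local Hub
    // cache (offline mode), or a plain local directory (`Type::None`).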
    #[derive(Clone)]
    enum Type {
        Api(Api),
        Cache(Cache),
        None,
    }
    let api = if use_api {
        if std::env::var("HF_HUB_OFFLINE") == Ok("1".to_string()) {
            let cache = Cache::default();
            tracing::warn!("Offline mode active using cache defaults");
            Type::Cache(cache)
        } else {
            tracing::info!("Using the Hugging Face API");
            match api_builder().build() {
                Ok(api) => Type::Api(api),
                Err(_) => {
                    tracing::warn!("Unable to build the Hugging Face API");
                    Type::None
                }
            }
        }
    } else {
        Type::None
    };

    // Load tokenizer and model info
    let (
        tokenizer_filename,
        config_filename,
        tokenizer_config_filename,
        preprocessor_config_filename,
        processor_config_filename,
        model_info,
    ) = match api {
        Type::None => (
            Some(local_path.join("tokenizer.json")),
            Some(local_path.join("config.json")),
            Some(local_path.join("tokenizer_config.json")),
            Some(local_path.join("preprocessor_config.json")),
            Some(local_path.join("processor_config.json")),
            None,
        ),
        Type::Api(api) => {
            let api_repo = api.repo(Repo::with_revision(
                tokenizer_name.to_string(),
                RepoType::Model,
                revision.clone().unwrap_or_else(|| "main".to_string()),
            ));

            let tokenizer_filename = match api_repo.get("tokenizer.json").await {
                Ok(tokenizer_filename) => Some(tokenizer_filename),
                Err(_) => get_base_tokenizer(&api, &api_repo).await,
            };
            let config_filename = api_repo.get("config.json").await.ok();
            let tokenizer_config_filename = api_repo.get("tokenizer_config.json").await.ok();
            let preprocessor_config_filename = api_repo.get("preprocessor_config.json").await.ok();
            let processor_config_filename = api_repo.get("processor_config.json").await.ok();

            let model_info = if let Some(model_info) = get_model_info(&api_repo).await {
                Some(model_info)
            } else {
                tracing::warn!("Could not retrieve model info from the Hugging Face hub.");
                None
            };
            (
                tokenizer_filename,
                config_filename,
                tokenizer_config_filename,
                preprocessor_config_filename,
                processor_config_filename,
                model_info,
            )
        }
        Type::Cache(cache) => {
            let repo = cache.repo(Repo::with_revision(
                tokenizer_name.to_string(),
                RepoType::Model,
                revision.clone().unwrap_or_else(|| "main".to_string()),
            ));
            (
                repo.get("tokenizer.json"),
                repo.get("config.json"),
                repo.get("tokenizer_config.json"),
                repo.get("preprocessor_config.json"),
                repo.get("processor_config.json"),
                None,
            )
        }
    };
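
    // Each artifact resolved above is optional: a missing file simply yields
    // `None` and the corresponding feature degrades gracefully further down.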
    let config: Option<Config> = config_filename.and_then(|filename| {
        std::fs::read_to_string(filename)
            .ok()
            .as_ref()
            .and_then(|c| {
                let config: Result<Config, _> = serde_json::from_str(c);
                if let Err(err) = &config {
                    tracing::warn!("Could not parse config {err:?}");
                }
                config.ok()
            })
    });
    let model_info = model_info.unwrap_or_else(|| HubModelInfo {
        model_id: tokenizer_name.to_string(),
        sha: None,
        pipeline_tag: None,
    });
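    // If Hub metadata was unavailable, this stub lets the router keep serving
    // under the tokenizer name, with no sha or pipeline tag.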

    // Read the JSON contents of the file as an instance of 'HubTokenizerConfig'.
    let tokenizer_config: Option<HubTokenizerConfig> = if let Some(filename) = tokenizer_config_path
    {
        HubTokenizerConfig::from_file(filename)
    } else {
        tokenizer_config_filename.and_then(HubTokenizerConfig::from_file)
    };
    let tokenizer_config = tokenizer_config.unwrap_or_else(|| {
        tracing::warn!("Could not find tokenizer config locally and no API specified");
        HubTokenizerConfig::default()
    });

    let tokenizer: Option<Tokenizer> = tokenizer_filename.and_then(|filename| {
        let mut tokenizer = Tokenizer::from_file(filename).ok();
        if let Some(tokenizer) = &mut tokenizer {
            if let Some(class) = &tokenizer_config.tokenizer_class {
                if class == "LlamaTokenizer" || class == "LlamaTokenizerFast" {
                    if let Ok(post_processor) = create_post_processor(tokenizer, &tokenizer_config) {
                        tracing::info!("Overriding LlamaTokenizer with TemplateProcessing to follow python override defined in https://github.com/huggingface/transformers/blob/4aa17d00690b7f82c95bb2949ea57e22c35b4336/src/transformers/models/llama/tokenization_llama_fast.py#L203-L205");
                        tokenizer.with_post_processor(post_processor);
                    }
                }
            }
        }
        tokenizer
    });
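
    // With the defaults in `create_post_processor` (`add_bos_token = true`,
    // `add_eos_token = false`), this override prepends the BOS token, i.e. a
    // single-sequence template of "<s>:0 $A:0" for Llama-style tokenizers.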

    let preprocessor_config =
        preprocessor_config_filename.and_then(HubPreprocessorConfig::from_file);
    let processor_config = processor_config_filename
        .and_then(HubProcessorConfig::from_file)
        .unwrap_or_default();

    tracing::info!("Using config {config:?}");
    if tokenizer.is_none() {
        tracing::warn!("Could not find a fast tokenizer implementation for {tokenizer_name}");
        tracing::warn!("Rust input length validation and truncation is disabled");
    }

    // if pipeline-tag == text-generation we default to return_full_text = true
    let compat_return_full_text = match &model_info.pipeline_tag {
        None => {
            tracing::warn!("no pipeline tag found for model {tokenizer_name}");
            true
        }
        Some(pipeline_tag) => pipeline_tag.as_str() == "text-generation",
    };
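
    // This mirrors the `transformers` text-generation pipeline, which returns
    // the prompt concatenated with the completion (`return_full_text`) by default.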

    // Determine the server port based on the feature and environment variable.
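    // (`AIP_HTTP_PORT` is the port that Vertex AI injects into custom serving
    // containers; when it is unset or unparseable, the CLI `--port` value is kept.)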
    let port = if cfg!(feature = "google") {
        std::env::var("AIP_HTTP_PORT")
            .map(|aip_http_port| aip_http_port.parse::<u16>().unwrap_or(port))
            .unwrap_or(port)
    } else {
        port
    };

    let addr = match hostname.parse() {
        Ok(ip) => SocketAddr::new(ip, port),
        Err(_) => {
            tracing::warn!("Invalid hostname, defaulting to 0.0.0.0");
            SocketAddr::new(IpAddr::V4(Ipv4Addr::new(0, 0, 0, 0)), port)
        }
    };

    // Run server
    server::run(
        master_shard_uds_path,
        model_info,
        compat_return_full_text,
        max_concurrent_requests,
        max_best_of,
        max_stop_sequences,
        max_top_n_tokens,
        max_input_tokens,
        max_total_tokens,
        waiting_served_ratio,
        max_batch_prefill_tokens,
        max_batch_total_tokens,
        max_waiting_tokens,
        max_batch_size,
        tokenizer,
        config,
        validation_workers,
        addr,
        cors_allow_origin,
        ngrok,
        ngrok_authtoken,
        ngrok_edge,
        tokenizer_config,
        preprocessor_config,
        processor_config,
        messages_api_enabled,
        disable_grammar_support,
        max_client_batch_size,
        print_schema_command,
    )
    .await?;
    Ok(())
}

/// Init logging using env variables LOG_LEVEL and LOG_FORMAT:
///     - otlp_endpoint is an optional URL to an OpenTelemetry collector
///     - otlp_service_name is the service name to appear in APM
///     - LOG_LEVEL may be TRACE, DEBUG, INFO, WARN or ERROR (defaults to INFO)
///     - LOG_FORMAT may be TEXT or JSON (defaults to TEXT)
///     - LOG_COLORIZE may be set to "1" to disable ANSI colors (colorized output is the default on ANSI-capable platforms)
fn init_logging(otlp_endpoint: Option<String>, otlp_service_name: String, json_output: bool) {
    let mut layers = Vec::new();

    // STDOUT/STDERR layer
    let ansi = std::env::var("LOG_COLORIZE") != Ok("1".to_string());
    let fmt_layer = tracing_subscriber::fmt::layer()
        .with_file(true)
        .with_ansi(ansi)
        .with_line_number(true);

    let fmt_layer = match json_output {
        true => fmt_layer.json().flatten_event(true).boxed(),
        false => fmt_layer.boxed(),
    };
    layers.push(fmt_layer);

    // OpenTelemetry tracing layer
    if let Some(otlp_endpoint) = otlp_endpoint {
        global::set_text_map_propagator(TraceContextPropagator::new());

        let tracer = opentelemetry_otlp::new_pipeline()
            .tracing()
            .with_exporter(
                opentelemetry_otlp::new_exporter()
                    .tonic()
                    .with_endpoint(otlp_endpoint),
            )
            .with_trace_config(
                trace::config()
                    .with_resource(Resource::new(vec![KeyValue::new(
                        "service.name",
                        otlp_service_name,
                    )]))
                    .with_sampler(Sampler::AlwaysOn),
            )
            .install_batch(opentelemetry::runtime::Tokio);

        if let Ok(tracer) = tracer {
            layers.push(tracing_opentelemetry::layer().with_tracer(tracer).boxed());
            init_tracing_opentelemetry::init_propagator().unwrap();
        };
    }

    // Filter events with LOG_LEVEL
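    // Bare "warn"/"info"/"debug" values are expanded into crate-scoped directives
    // below so that dependencies (e.g. tokio) are not pulled in at the same level.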
    let varname = "LOG_LEVEL";
    let env_filter = if let Ok(log_level) = std::env::var(varname) {
        // Override to avoid simple logs being spammed with tokio-level information
        let log_level = match &log_level[..] {
            "warn" => "text_generation_launcher=warn,text_generation_router=warn",
            "info" => "text_generation_launcher=info,text_generation_router=info",
            "debug" => "text_generation_launcher=debug,text_generation_router=debug",
            log_level => log_level,
        };
        EnvFilter::builder()
            .with_default_directive(LevelFilter::INFO.into())
            .parse_lossy(log_level)
    } else {
        EnvFilter::new("info")
    };

    tracing_subscriber::registry()
        .with(env_filter)
        .with(layers)
        .init();
}

/// Get model info from the Hugging Face Hub
pub async fn get_model_info(api: &ApiRepo) -> Option<HubModelInfo> {
    let response = api.info_request().send().await.ok()?;

    if response.status().is_success() {
        let hub_model_info: HubModelInfo =
            serde_json::from_str(&response.text().await.ok()?).ok()?;
        if let Some(sha) = &hub_model_info.sha {
            tracing::info!(
                "Serving revision {sha} of model {}",
                hub_model_info.model_id
            );
        }
        Some(hub_model_info)
    } else {
        None
    }
}

/// Get the base model's tokenizer (fallback when a repo ships no tokenizer.json)
pub async fn get_base_tokenizer(api: &Api, api_repo: &ApiRepo) -> Option<PathBuf> {
    let config_filename = api_repo.get("config.json").await.ok()?;

    // Open the file in read-only mode with buffer.
    let file = File::open(config_filename).ok()?;
    let reader = BufReader::new(file);

    // Read the JSON contents of the file as a generic JSON value.
    let config: serde_json::Value = serde_json::from_reader(reader).ok()?;

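    // Fine-tuned or adapter repos often ship only a config that names their base
    // model; in that case, fetch `tokenizer.json` from the base repo instead.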
    if let Some(serde_json::Value::String(base_model_id)) = config.get("base_model_name_or_path") {
        let api_base_repo = api.repo(Repo::with_revision(
            base_model_id.to_string(),
            RepoType::Model,
            "main".to_string(),
        ));

        api_base_repo.get("tokenizer.json").await.ok()
    } else {
        None
    }
}

/// Get tokenizer_config from the Hugging Face Hub
pub async fn get_tokenizer_config(api_repo: &ApiRepo) -> Option<HubTokenizerConfig> {
    let tokenizer_config_filename = api_repo.get("tokenizer_config.json").await.ok()?;

    // Open the file in read-only mode with buffer.
    let file = File::open(tokenizer_config_filename).ok()?;
    let reader = BufReader::new(file);

    // Read the JSON contents of the file as an instance of 'HubTokenizerConfig'.
    let tokenizer_config: HubTokenizerConfig = serde_json::from_reader(reader)
        .map_err(|e| {
            tracing::warn!("Unable to parse tokenizer config: {}", e);
            e
        })
        .ok()?;

    Some(tokenizer_config)
}

/// Create a post_processor for the LlamaTokenizer
pub fn create_post_processor(
    tokenizer: &Tokenizer,
    tokenizer_config: &HubTokenizerConfig,
) -> Result<TemplateProcessing, tokenizers::processors::template::TemplateProcessingBuilderError> {
    let add_bos_token = tokenizer_config.add_bos_token.unwrap_or(true);
    let add_eos_token = tokenizer_config.add_eos_token.unwrap_or(false);

    let bos_token = tokenizer_config.bos_token.as_ref();
    let eos_token = tokenizer_config.eos_token.as_ref();

    if add_bos_token && bos_token.is_none() {
        panic!("add_bos_token = true but bos_token is None");
    }

    if add_eos_token && eos_token.is_none() {
        panic!("add_eos_token = true but eos_token is None");
    }

    let mut single = Vec::new();
    let mut pair = Vec::new();
    let mut special_tokens = Vec::new();

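    // In `TemplateProcessing` templates, "<token>:0" / "<token>:1" pin a special
    // token to sequence id 0 or 1, and "$A" / "$B" stand for the first and second
    // input sequence respectively.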
    if add_bos_token {
        if let Some(bos) = bos_token {
            let bos_token_id = tokenizer
                .token_to_id(bos.as_str())
                .expect("Should have found the bos token id");
            special_tokens.push((bos.as_str(), bos_token_id));
            single.push(format!("{}:0", bos.as_str()));
            pair.push(format!("{}:0", bos.as_str()));
        }
    }

    single.push("$A:0".to_string());
    pair.push("$A:0".to_string());

    if add_eos_token {
        if let Some(eos) = eos_token {
            let eos_token_id = tokenizer
                .token_to_id(eos.as_str())
                .expect("Should have found the eos token id");
            special_tokens.push((eos.as_str(), eos_token_id));
            single.push(format!("{}:0", eos.as_str()));
            pair.push(format!("{}:0", eos.as_str()));
        }
    }

    if add_bos_token {
        if let Some(bos) = bos_token {
            pair.push(format!("{}:1", bos.as_str()));
        }
    }

    pair.push("$B:1".to_string());

    if add_eos_token {
        if let Some(eos) = eos_token {
            pair.push(format!("{}:1", eos.as_str()));
        }
    }

    let post_processor = TemplateProcessing::builder()
        .try_single(single)?
        .try_pair(pair)?
        .special_tokens(special_tokens)
        .build()?;

    Ok(post_processor)
}

#[derive(Debug, Error)]
enum RouterError {
    #[error("Argument validation error: {0}")]
    ArgumentValidation(String),
    #[error("WebServer error: {0}")]
    WebServer(#[from] server::WebServerError),
    #[error("Tokio runtime failed to start: {0}")]
    Tokio(#[from] std::io::Error),
}

#[cfg(test)]
mod tests {
    use super::*;
    use text_generation_router::TokenizerConfigToken;

    #[test]
    fn test_create_post_processor() {
        let tokenizer_config = HubTokenizerConfig {
            add_bos_token: None,
            add_eos_token: None,
            bos_token: Some(TokenizerConfigToken::String("<s>".to_string())),
            eos_token: Some(TokenizerConfigToken::String("</s>".to_string())),
            chat_template: None,
            tokenizer_class: None,
            completion_template: None,
        };

        let tokenizer =
            Tokenizer::from_pretrained("TinyLlama/TinyLlama-1.1B-Chat-v1.0", None).unwrap();
        let post_processor = create_post_processor(&tokenizer, &tokenizer_config).unwrap();

        let expected = TemplateProcessing::builder()
            .try_single("<s>:0 $A:0")
            .unwrap()
            .try_pair("<s>:0 $A:0 <s>:1 $B:1")
            .unwrap()
            .special_tokens(vec![("<s>".to_string(), 1)])
            .build()
            .unwrap();

        assert_eq!(post_processor, expected);
    }
}