"...models/task_modules/coders/centerpoint_bbox_coders.py" did not exist on "d154e24dd83435e031128e8b684ddb3e77d7b173"
main.rs 22.6 KB
Newer Older
1
use axum::http::HeaderValue;
Olivier Dehaene's avatar
v0.1.0  
Olivier Dehaene committed
2
use clap::Parser;
3
use clap::Subcommand;
4
use hf_hub::api::tokio::{Api, ApiBuilder, ApiRepo};
5
use hf_hub::{Cache, Repo, RepoType};
6
7
8
9
10
11
use opentelemetry::sdk::propagation::TraceContextPropagator;
use opentelemetry::sdk::trace;
use opentelemetry::sdk::trace::Sampler;
use opentelemetry::sdk::Resource;
use opentelemetry::{global, KeyValue};
use opentelemetry_otlp::WithExportConfig;
12
13
use std::fs::File;
use std::io::BufReader;
Olivier Dehaene's avatar
Olivier Dehaene committed
14
use std::net::{IpAddr, Ipv4Addr, SocketAddr};
15
use std::path::{Path, PathBuf};
16
use text_generation_router::config::Config;
17
18
19
use text_generation_router::{
    server, HubModelInfo, HubPreprocessorConfig, HubProcessorConfig, HubTokenizerConfig,
};
20
use thiserror::Error;
Nicolas Patry's avatar
Nicolas Patry committed
21
use tokenizers::{processors::template::TemplateProcessing, Tokenizer};
22
use tower_http::cors::AllowOrigin;
23
24
use tracing_subscriber::layer::SubscriberExt;
use tracing_subscriber::util::SubscriberInitExt;
25
use tracing_subscriber::{filter::LevelFilter, EnvFilter, Layer};
Olivier Dehaene's avatar
Olivier Dehaene committed
26
27
28
29
30

/// App Configuration
#[derive(Parser, Debug)]
#[clap(author, version, about, long_about = None)]
struct Args {
    #[command(subcommand)]
    command: Option<Commands>,

    #[clap(default_value = "128", long, env)]
    max_concurrent_requests: usize,
    #[clap(default_value = "2", long, env)]
    max_best_of: usize,
    #[clap(default_value = "4", long, env)]
    max_stop_sequences: usize,
    #[clap(default_value = "5", long, env)]
    max_top_n_tokens: u32,
    #[clap(default_value = "1024", long, env)]
    max_input_tokens: usize,
    #[clap(default_value = "2048", long, env)]
    max_total_tokens: usize,
    #[clap(default_value = "1.2", long, env)]
    waiting_served_ratio: f32,
    #[clap(default_value = "4096", long, env)]
    max_batch_prefill_tokens: u32,
    #[clap(long, env)]
    max_batch_total_tokens: Option<u32>,
    #[clap(default_value = "20", long, env)]
    max_waiting_tokens: usize,
    #[clap(long, env)]
    max_batch_size: Option<usize>,
    #[clap(default_value = "0.0.0.0", long, env)]
    hostname: String,
    #[clap(default_value = "3000", long, short, env)]
    port: u16,
    #[clap(default_value = "/tmp/text-generation-server-0", long, env)]
    master_shard_uds_path: String,
    #[clap(default_value = "bigscience/bloom", long, env)]
    tokenizer_name: String,
    #[clap(long, env)]
    tokenizer_config_path: Option<String>,
    #[clap(long, env)]
    revision: Option<String>,
    #[clap(default_value = "2", long, env)]
    validation_workers: usize,
    #[clap(long, env)]
    json_output: bool,
    #[clap(long, env)]
    otlp_endpoint: Option<String>,
    #[clap(default_value = "text-generation-inference.router", long, env)]
    otlp_service_name: String,
    #[clap(long, env)]
    cors_allow_origin: Option<Vec<String>>,
    #[clap(long, env)]
    ngrok: bool,
    #[clap(long, env)]
    ngrok_authtoken: Option<String>,
    #[clap(long, env)]
    ngrok_edge: Option<String>,
    #[clap(long, env, default_value_t = false)]
    messages_api_enabled: bool,
    #[clap(long, env, default_value_t = false)]
    disable_grammar_support: bool,
    #[clap(default_value = "4", long, env)]
    max_client_batch_size: usize,
}
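
// Illustrative invocation only (not part of the source): every `long` flag can
// equally be supplied through its matching environment variable thanks to the
// `env` attribute, e.g.
//
//   text-generation-router --tokenizer-name bigscience/bloom --port 3000 \
//       --max-input-tokens 1024 --max-total-tokens 2048
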
#[derive(Debug, Subcommand)]
enum Commands {
    PrintSchema,
}

#[tokio::main]
async fn main() -> Result<(), RouterError> {
    let args = Args::parse();

    // Pattern match configuration
    let Args {
        max_concurrent_requests,
        max_best_of,
        max_stop_sequences,
        max_top_n_tokens,
        max_input_tokens,
        max_total_tokens,
        waiting_served_ratio,
        max_batch_prefill_tokens,
        max_batch_total_tokens,
        max_waiting_tokens,
        max_batch_size,
        hostname,
        port,
        master_shard_uds_path,
        tokenizer_name,
        tokenizer_config_path,
        revision,
        validation_workers,
        json_output,
        otlp_endpoint,
        otlp_service_name,
        cors_allow_origin,
        ngrok,
        ngrok_authtoken,
        ngrok_edge,
        messages_api_enabled,
        disable_grammar_support,
        max_client_batch_size,
        command,
    } = args;

    let print_schema_command = match command {
        Some(Commands::PrintSchema) => true,
        None => {
            // only init logging if we are not running the print schema command
            init_logging(otlp_endpoint, otlp_service_name, json_output);
            false
        }
    };

    // Validate args
    if max_input_tokens >= max_total_tokens {
        return Err(RouterError::ArgumentValidation(
            "`max_input_tokens` must be < `max_total_tokens`".to_string(),
        ));
    }
    if max_input_tokens as u32 > max_batch_prefill_tokens {
        return Err(RouterError::ArgumentValidation(format!("`max_batch_prefill_tokens` must be >= `max_input_tokens`. Given: {max_batch_prefill_tokens} and {max_input_tokens}")));
    }

    if validation_workers == 0 {
        return Err(RouterError::ArgumentValidation(
            "`validation_workers` must be > 0".to_string(),
        ));
    }

    if let Some(ref max_batch_total_tokens) = max_batch_total_tokens {
        if max_batch_prefill_tokens > *max_batch_total_tokens {
            return Err(RouterError::ArgumentValidation(format!("`max_batch_prefill_tokens` must be <= `max_batch_total_tokens`. Given: {max_batch_prefill_tokens} and {max_batch_total_tokens}")));
        }
        if max_total_tokens as u32 > *max_batch_total_tokens {
            return Err(RouterError::ArgumentValidation(format!("`max_total_tokens` must be <= `max_batch_total_tokens`. Given: {max_total_tokens} and {max_batch_total_tokens}")));
        }
    }
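    // Net effect of the validation above (when `max_batch_total_tokens` is set):
    //   max_input_tokens < max_total_tokens <= max_batch_total_tokens
    //   max_input_tokens <= max_batch_prefill_tokens <= max_batch_total_tokens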

    // CORS allowed origins
    // map to go inside the option and then map to parse from String to HeaderValue
    // Finally, convert to AllowOrigin
    let cors_allow_origin: Option<AllowOrigin> = cors_allow_origin.map(|cors_allow_origin| {
        AllowOrigin::list(
            cors_allow_origin
                .iter()
                .map(|origin| origin.parse::<HeaderValue>().unwrap()),
        )
    });
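    // Note: an origin that fails to parse as a `HeaderValue` panics here at
    // startup rather than being silently dropped.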

    // Parse Huggingface hub token
    let authorization_token = std::env::var("HF_TOKEN")
        .or_else(|_| std::env::var("HUGGING_FACE_HUB_TOKEN"))
        .ok();
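    // `HF_TOKEN` takes precedence; `HUGGING_FACE_HUB_TOKEN` is the older
    // variable name kept as a fallback.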

    // Tokenizer instance
    // This will only be used to validate payloads
    let local_path = Path::new(&tokenizer_name);

    // Shared API builder initialization
    let api_builder = || {
        let mut builder = ApiBuilder::new()
            .with_progress(false)
            .with_token(authorization_token);

        if let Ok(cache_dir) = std::env::var("HUGGINGFACE_HUB_CACHE") {
            builder = builder.with_cache_dir(cache_dir.into());
        }

        builder
    };

    // Decide if we need to use the API based on the revision and local path
    let use_api = revision.is_some() || !local_path.exists() || !local_path.is_dir();

    // Initialize API if needed
    #[derive(Clone)]
    enum Type {
        Api(Api),
        Cache(Cache),
        None,
    }
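    // `Api` talks to the live Hub, `Cache` reads a previously populated local
    // cache (offline mode), and `None` means a plain local directory is used.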
    let api = if use_api {
        if std::env::var("HF_HUB_OFFLINE") == Ok("1".to_string()) {
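            // Honour HUGGINGFACE_HUB_CACHE when set, otherwise fall back to
            // the default cache location.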
            let cache = std::env::var("HUGGINGFACE_HUB_CACHE")
                .map_err(|_| ())
                .map(|cache_dir| Cache::new(cache_dir.into()))
                .unwrap_or_else(|_| Cache::default());

            tracing::warn!("Offline mode active using cache defaults");
            Type::Cache(cache)
        } else {
            tracing::info!("Using the Hugging Face API");
            match api_builder().build() {
                Ok(api) => Type::Api(api),
                Err(_) => {
                    tracing::warn!("Unable to build the Hugging Face API");
                    Type::None
                }
            }
        }
    } else {
        Type::None
    };

    // Load tokenizer and model info
    let (
        tokenizer_filename,
        config_filename,
        tokenizer_config_filename,
        preprocessor_config_filename,
        processor_config_filename,
        model_info,
    ) = match api {
        Type::None => (
            Some(local_path.join("tokenizer.json")),
            Some(local_path.join("config.json")),
            Some(local_path.join("tokenizer_config.json")),
            Some(local_path.join("preprocessor_config.json")),
            Some(local_path.join("processor_config.json")),
            None,
        ),
        Type::Api(api) => {
            let api_repo = api.repo(Repo::with_revision(
                tokenizer_name.to_string(),
                RepoType::Model,
                revision.clone().unwrap_or_else(|| "main".to_string()),
            ));

            let tokenizer_filename = match api_repo.get("tokenizer.json").await {
                Ok(tokenizer_filename) => Some(tokenizer_filename),
                Err(_) => get_base_tokenizer(&api, &api_repo).await,
            };
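            // If tokenizer.json is missing, fall back to the base model's
            // tokenizer via `base_model_name_or_path` (see get_base_tokenizer).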
            let config_filename = api_repo.get("config.json").await.ok();
            let tokenizer_config_filename = api_repo.get("tokenizer_config.json").await.ok();
            let preprocessor_config_filename = api_repo.get("preprocessor_config.json").await.ok();
            let processor_config_filename = api_repo.get("processor_config.json").await.ok();

            let model_info = if let Some(model_info) = get_model_info(&api_repo).await {
                Some(model_info)
            } else {
                tracing::warn!("Could not retrieve model info from the Hugging Face hub.");
                None
            };
            (
                tokenizer_filename,
                config_filename,
                tokenizer_config_filename,
                preprocessor_config_filename,
                processor_config_filename,
                model_info,
            )
        }
        Type::Cache(cache) => {
            let repo = cache.repo(Repo::with_revision(
                tokenizer_name.to_string(),
                RepoType::Model,
                revision.clone().unwrap_or_else(|| "main".to_string()),
            ));
            (
                repo.get("tokenizer.json"),
                repo.get("config.json"),
                repo.get("tokenizer_config.json"),
                repo.get("preprocessor_config.json"),
                repo.get("processor_config.json"),
                None,
            )
        }
    };
    let config: Option<Config> = config_filename.and_then(|filename| {
        std::fs::read_to_string(filename)
            .ok()
            .as_ref()
            .and_then(|c| {
                let config: Result<Config, _> = serde_json::from_str(c);
                if let Err(err) = &config {
                    tracing::warn!("Could not parse config {err:?}");
                }
                config.ok()
            })
    });
    let model_info = model_info.unwrap_or_else(|| HubModelInfo {
        model_id: tokenizer_name.to_string(),
        sha: None,
        pipeline_tag: None,
    });
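    // Fall back to minimal model info (the id only) when the Hub did not
    // provide any.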

    // Read the JSON contents of the file as an instance of 'HubTokenizerConfig'.
    let tokenizer_config: Option<HubTokenizerConfig> = if let Some(filename) = tokenizer_config_path
    {
        HubTokenizerConfig::from_file(filename)
    } else {
        tokenizer_config_filename.and_then(HubTokenizerConfig::from_file)
    };
    let tokenizer_config = tokenizer_config.unwrap_or_else(|| {
        tracing::warn!("Could not find tokenizer config locally and no API specified");
        HubTokenizerConfig::default()
    });

    let tokenizer: Option<Tokenizer> = tokenizer_filename.and_then(|filename| {
        let mut tokenizer = Tokenizer::from_file(filename).ok();
        if let Some(tokenizer) = &mut tokenizer {
            if let Some(class) = &tokenizer_config.tokenizer_class {
                if class == "LlamaTokenizer" || class == "LlamaTokenizerFast" {
                    if let Ok(post_processor) = create_post_processor(tokenizer, &tokenizer_config) {
                        tracing::info!("Overriding LlamaTokenizer with TemplateProcessing to follow python override defined in https://github.com/huggingface/transformers/blob/4aa17d00690b7f82c95bb2949ea57e22c35b4336/src/transformers/models/llama/tokenization_llama_fast.py#L203-L205");
                        tokenizer.with_post_processor(post_processor);
                    }
                }
            }
        }
        tokenizer
    });

    let preprocessor_config =
        preprocessor_config_filename.and_then(HubPreprocessorConfig::from_file);
    let processor_config = processor_config_filename
        .and_then(HubProcessorConfig::from_file)
        .unwrap_or_default();

    tracing::info!("Using config {config:?}");
    if tokenizer.is_none() {
        tracing::warn!("Could not find a fast tokenizer implementation for {tokenizer_name}");
        tracing::warn!("Rust input length validation and truncation is disabled");
    }

    // if pipeline-tag == text-generation we default to return_full_text = true
    let compat_return_full_text = match &model_info.pipeline_tag {
        None => {
            tracing::warn!("no pipeline tag found for model {tokenizer_name}");
            true
        }
        Some(pipeline_tag) => pipeline_tag.as_str() == "text-generation",
    };

    // Determine the server port based on the feature and environment variable.
    let port = if cfg!(feature = "google") {
        std::env::var("AIP_HTTP_PORT")
            .map(|aip_http_port| aip_http_port.parse::<u16>().unwrap_or(port))
            .unwrap_or(port)
    } else {
        port
    };

    let addr = match hostname.parse() {
        Ok(ip) => SocketAddr::new(ip, port),
        Err(_) => {
            tracing::warn!("Invalid hostname, defaulting to 0.0.0.0");
            SocketAddr::new(IpAddr::V4(Ipv4Addr::new(0, 0, 0, 0)), port)
        }
    };

    // Run server
    server::run(
        master_shard_uds_path,
        model_info,
        compat_return_full_text,
        max_concurrent_requests,
        max_best_of,
        max_stop_sequences,
        max_top_n_tokens,
        max_input_tokens,
        max_total_tokens,
        waiting_served_ratio,
        max_batch_prefill_tokens,
        max_batch_total_tokens,
        max_waiting_tokens,
        max_batch_size,
        tokenizer,
        config,
        validation_workers,
        addr,
        cors_allow_origin,
        ngrok,
        ngrok_authtoken,
        ngrok_edge,
        tokenizer_config,
        preprocessor_config,
        processor_config,
        messages_api_enabled,
        disable_grammar_support,
        max_client_batch_size,
        print_schema_command,
    )
    .await?;
    Ok(())
}

/// Init logging using env variables LOG_LEVEL and LOG_FORMAT:
///     - otlp_endpoint is an optional URL to an Open Telemetry collector
///     - otlp_service_name service name to appear in APM
///     - LOG_LEVEL may be TRACE, DEBUG, INFO, WARN or ERROR (default to INFO)
///     - LOG_FORMAT may be TEXT or JSON (default to TEXT)
///     - LOG_COLORIZE may be "false" or "true" (default to "true" on ANSI-supported platforms)
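///
/// Illustrative usage (not from the source):
///     LOG_LEVEL=debug LOG_COLORIZE=false text-generation-router --json-output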
fn init_logging(otlp_endpoint: Option<String>, otlp_service_name: String, json_output: bool) {
    let mut layers = Vec::new();

    // STDOUT/STDERR layer
    let ansi = std::env::var("LOG_COLORIZE") != Ok("false".to_string());
    let fmt_layer = tracing_subscriber::fmt::layer()
        .with_file(true)
        .with_ansi(ansi)
        .with_line_number(true);

    let fmt_layer = match json_output {
        true => fmt_layer.json().flatten_event(true).boxed(),
        false => fmt_layer.boxed(),
    };
    layers.push(fmt_layer);

    // OpenTelemetry tracing layer
    if let Some(otlp_endpoint) = otlp_endpoint {
        global::set_text_map_propagator(TraceContextPropagator::new());

        let tracer = opentelemetry_otlp::new_pipeline()
            .tracing()
            .with_exporter(
                opentelemetry_otlp::new_exporter()
                    .tonic()
                    .with_endpoint(otlp_endpoint),
            )
            .with_trace_config(
                trace::config()
                    .with_resource(Resource::new(vec![KeyValue::new(
                        "service.name",
                        otlp_service_name,
                    )]))
                    .with_sampler(Sampler::AlwaysOn),
            )
            .install_batch(opentelemetry::runtime::Tokio);

        if let Ok(tracer) = tracer {
            layers.push(tracing_opentelemetry::layer().with_tracer(tracer).boxed());
            init_tracing_opentelemetry::init_propagator().unwrap();
        };
    }

    // Filter events with LOG_LEVEL
    let varname = "LOG_LEVEL";
    let env_filter = if let Ok(log_level) = std::env::var(varname) {
        // Override to avoid simple logs being spammed with tokio-level information
        let log_level = match &log_level[..] {
            "warn" => "text_generation_launcher=warn,text_generation_router=warn",
            "info" => "text_generation_launcher=info,text_generation_router=info",
            "debug" => "text_generation_launcher=debug,text_generation_router=debug",
            log_level => log_level,
        };
        EnvFilter::builder()
            .with_default_directive(LevelFilter::INFO.into())
            .parse_lossy(log_level)
    } else {
        EnvFilter::new("info")
    };

    tracing_subscriber::registry()
        .with(env_filter)
        .with(layers)
        .init();
}

/// get model info from the Huggingface Hub
pub async fn get_model_info(api: &ApiRepo) -> Option<HubModelInfo> {
    let response = api.info_request().send().await.ok()?;

    if response.status().is_success() {
        let hub_model_info: HubModelInfo =
            serde_json::from_str(&response.text().await.ok()?).ok()?;
        if let Some(sha) = &hub_model_info.sha {
            tracing::info!(
                "Serving revision {sha} of model {}",
                hub_model_info.model_id
            );
        }
        Some(hub_model_info)
    } else {
        None
    }
}

/// get base tokenizer
pub async fn get_base_tokenizer(api: &Api, api_repo: &ApiRepo) -> Option<PathBuf> {
    let config_filename = api_repo.get("config.json").await.ok()?;

    // Open the file in read-only mode with buffer.
    let file = File::open(config_filename).ok()?;
    let reader = BufReader::new(file);

    // Read the JSON contents of the file as a generic JSON value.
    let config: serde_json::Value = serde_json::from_reader(reader).ok()?;

    if let Some(serde_json::Value::String(base_model_id)) = config.get("base_model_name_or_path") {
        let api_base_repo = api.repo(Repo::with_revision(
            base_model_id.to_string(),
            RepoType::Model,
            "main".to_string(),
        ));

        api_base_repo.get("tokenizer.json").await.ok()
    } else {
        None
    }
}

/// get tokenizer_config from the Huggingface Hub
pub async fn get_tokenizer_config(api_repo: &ApiRepo) -> Option<HubTokenizerConfig> {
    let tokenizer_config_filename = api_repo.get("tokenizer_config.json").await.ok()?;

    // Open the file in read-only mode with buffer.
    let file = File::open(tokenizer_config_filename).ok()?;
    let reader = BufReader::new(file);

    // Read the JSON contents of the file as an instance of 'HubTokenizerConfig'.
    let tokenizer_config: HubTokenizerConfig = serde_json::from_reader(reader)
        .map_err(|e| {
            tracing::warn!("Unable to parse tokenizer config: {}", e);
            e
        })
        .ok()?;

    Some(tokenizer_config)
}

/// Create a post_processor for the LlamaTokenizer
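/// following the python override in `transformers`. With a BOS token and
/// `add_eos_token = false` this yields the single template "&lt;s&gt;:0 $A:0" and
/// the pair template "&lt;s&gt;:0 $A:0 &lt;s&gt;:1 $B:1" (see the test below).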
pub fn create_post_processor(
    tokenizer: &Tokenizer,
    tokenizer_config: &HubTokenizerConfig,
) -> Result<TemplateProcessing, tokenizers::processors::template::TemplateProcessingBuilderError> {
    let add_bos_token = tokenizer_config.add_bos_token.unwrap_or(true);
    let add_eos_token = tokenizer_config.add_eos_token.unwrap_or(false);

    let bos_token = tokenizer_config.bos_token.as_ref();
    let eos_token = tokenizer_config.eos_token.as_ref();

    if add_bos_token && bos_token.is_none() {
        panic!("add_bos_token = true but bos_token is None");
    }

    if add_eos_token && eos_token.is_none() {
        panic!("add_eos_token = true but eos_token is None");
    }

    let mut single = Vec::new();
    let mut pair = Vec::new();
    let mut special_tokens = Vec::new();

    if add_bos_token {
        if let Some(bos) = bos_token {
            let bos_token_id = tokenizer
                .token_to_id(bos.as_str())
                .expect("Should have found the bos token id");
            special_tokens.push((bos.as_str(), bos_token_id));
            single.push(format!("{}:0", bos.as_str()));
            pair.push(format!("{}:0", bos.as_str()));
        }
    }

    single.push("$A:0".to_string());
    pair.push("$A:0".to_string());

    if add_eos_token {
        if let Some(eos) = eos_token {
            let eos_token_id = tokenizer
                .token_to_id(eos.as_str())
                .expect("Should have found the eos token id");
            special_tokens.push((eos.as_str(), eos_token_id));
            single.push(format!("{}:0", eos.as_str()));
            pair.push(format!("{}:0", eos.as_str()));
        }
    }

    if add_bos_token {
        if let Some(bos) = bos_token {
            pair.push(format!("{}:1", bos.as_str()));
        }
    }

    pair.push("$B:1".to_string());

    if add_eos_token {
        if let Some(eos) = eos_token {
            pair.push(format!("{}:1", eos.as_str()));
        }
    }

    let post_processor = TemplateProcessing::builder()
        .try_single(single)?
        .try_pair(pair)?
        .special_tokens(special_tokens)
        .build()?;

    Ok(post_processor)
}

#[derive(Debug, Error)]
enum RouterError {
    #[error("Argument validation error: {0}")]
    ArgumentValidation(String),
    #[error("WebServer error: {0}")]
    WebServer(#[from] server::WebServerError),
    #[error("Tokio runtime failed to start: {0}")]
    Tokio(#[from] std::io::Error),
}

#[cfg(test)]
mod tests {
    use super::*;
    use text_generation_router::TokenizerConfigToken;

    #[test]
    fn test_create_post_processor() {
        let tokenizer_config = HubTokenizerConfig {
            add_bos_token: None,
            add_eos_token: None,
            bos_token: Some(TokenizerConfigToken::String("<s>".to_string())),
            eos_token: Some(TokenizerConfigToken::String("</s>".to_string())),
            chat_template: None,
            tokenizer_class: None,
            completion_template: None,
        };

        let tokenizer =
            Tokenizer::from_pretrained("TinyLlama/TinyLlama-1.1B-Chat-v1.0", None).unwrap();
        let post_processor = create_post_processor(&tokenizer, &tokenizer_config).unwrap();

        let expected = TemplateProcessing::builder()
            .try_single("<s>:0 $A:0")
            .unwrap()
            .try_pair("<s>:0 $A:0 <s>:1 $B:1")
            .unwrap()
            .special_tokens(vec![("<s>".to_string(), 1)])
            .build()
            .unwrap();

        assert_eq!(post_processor, expected);
    }
}