use axum::http::HeaderValue;
use clap::Parser;
use clap::Subcommand;
use hf_hub::api::tokio::{Api, ApiBuilder, ApiRepo};
use hf_hub::{Cache, Repo, RepoType};
use opentelemetry::sdk::propagation::TraceContextPropagator;
use opentelemetry::sdk::trace;
use opentelemetry::sdk::trace::Sampler;
use opentelemetry::sdk::Resource;
use opentelemetry::{global, KeyValue};
use opentelemetry_otlp::WithExportConfig;
use std::fs::File;
use std::io::BufReader;
use std::net::{IpAddr, Ipv4Addr, SocketAddr};
use std::path::{Path, PathBuf};
use text_generation_router::config::Config;
use text_generation_router::{
    server, HubModelInfo, HubPreprocessorConfig, HubProcessorConfig, HubTokenizerConfig,
};
use thiserror::Error;
use tokenizers::{processors::template::TemplateProcessing, Tokenizer};
use tower_http::cors::AllowOrigin;
use tracing_subscriber::layer::SubscriberExt;
use tracing_subscriber::util::SubscriberInitExt;
use tracing_subscriber::{filter::LevelFilter, EnvFilter, Layer};

/// App Configuration
#[derive(Parser, Debug)]
#[clap(author, version, about, long_about = None)]
struct Args {
    #[command(subcommand)]
    command: Option<Commands>,

    #[clap(default_value = "128", long, env)]
    max_concurrent_requests: usize,
    #[clap(default_value = "2", long, env)]
    max_best_of: usize,
    #[clap(default_value = "4", long, env)]
    max_stop_sequences: usize,
    #[clap(default_value = "5", long, env)]
    max_top_n_tokens: u32,
    #[clap(default_value = "1024", long, env)]
    max_input_tokens: usize,
    #[clap(default_value = "2048", long, env)]
    max_total_tokens: usize,
    #[clap(default_value = "1.2", long, env)]
    waiting_served_ratio: f32,
    #[clap(default_value = "4096", long, env)]
    max_batch_prefill_tokens: u32,
    #[clap(long, env)]
    max_batch_total_tokens: Option<u32>,
    #[clap(default_value = "20", long, env)]
    max_waiting_tokens: usize,
    #[clap(long, env)]
    max_batch_size: Option<usize>,
    #[clap(default_value = "0.0.0.0", long, env)]
    hostname: String,
    #[clap(default_value = "3000", long, short, env)]
    port: u16,
    #[clap(default_value = "/tmp/text-generation-server-0", long, env)]
    master_shard_uds_path: String,
    #[clap(default_value = "bigscience/bloom", long, env)]
    tokenizer_name: String,
    #[clap(long, env)]
    tokenizer_config_path: Option<String>,
    #[clap(long, env)]
    revision: Option<String>,
    #[clap(default_value = "2", long, env)]
    validation_workers: usize,
    #[clap(long, env)]
    json_output: bool,
    #[clap(long, env)]
    otlp_endpoint: Option<String>,
    #[clap(default_value = "text-generation-inference.router", long, env)]
    otlp_service_name: String,
    #[clap(long, env)]
    cors_allow_origin: Option<Vec<String>>,
    #[clap(long, env)]
    ngrok: bool,
    #[clap(long, env)]
    ngrok_authtoken: Option<String>,
    #[clap(long, env)]
    ngrok_edge: Option<String>,
    #[clap(long, env, default_value_t = false)]
    messages_api_enabled: bool,
    #[clap(long, env, default_value_t = false)]
    disable_grammar_support: bool,
    #[clap(default_value = "4", long, env)]
    max_client_batch_size: usize,
}

#[derive(Debug, Subcommand)]
enum Commands {
    PrintSchema,
}

#[tokio::main]
async fn main() -> Result<(), RouterError> {
    let args = Args::parse();

    // Pattern match configuration
    let Args {
        max_concurrent_requests,
        max_best_of,
        max_stop_sequences,
        max_top_n_tokens,
        max_input_tokens,
        max_total_tokens,
        waiting_served_ratio,
        max_batch_prefill_tokens,
        max_batch_total_tokens,
        max_waiting_tokens,
        max_batch_size,
        hostname,
        port,
        master_shard_uds_path,
        tokenizer_name,
        tokenizer_config_path,
        revision,
        validation_workers,
        json_output,
        otlp_endpoint,
        otlp_service_name,
        cors_allow_origin,
        ngrok,
        ngrok_authtoken,
        ngrok_edge,
        messages_api_enabled,
        disable_grammar_support,
        max_client_batch_size,
        command,
    } = args;

    let print_schema_command = match command {
        Some(Commands::PrintSchema) => true,
        None => {
            // only init logging if we are not running the print schema command
            init_logging(otlp_endpoint, otlp_service_name, json_output);
            false
        }
    };

    // Validate args
    if max_input_tokens >= max_total_tokens {
        return Err(RouterError::ArgumentValidation(
            "`max_input_tokens` must be < `max_total_tokens`".to_string(),
        ));
    }
    if max_input_tokens as u32 > max_batch_prefill_tokens {
        return Err(RouterError::ArgumentValidation(format!("`max_batch_prefill_tokens` must be >= `max_input_tokens`. Given: {max_batch_prefill_tokens} and {max_input_tokens}")));
    }

    if validation_workers == 0 {
        return Err(RouterError::ArgumentValidation(
            "`validation_workers` must be > 0".to_string(),
        ));
    }

    if let Some(ref max_batch_total_tokens) = max_batch_total_tokens {
        if max_batch_prefill_tokens > *max_batch_total_tokens {
            return Err(RouterError::ArgumentValidation(format!("`max_batch_prefill_tokens` must be <= `max_batch_total_tokens`. Given: {max_batch_prefill_tokens} and {max_batch_total_tokens}")));
        }
        if max_total_tokens as u32 > *max_batch_total_tokens {
            return Err(RouterError::ArgumentValidation(format!("`max_total_tokens` must be <= `max_batch_total_tokens`. Given: {max_total_tokens} and {max_batch_total_tokens}")));
        }
    }
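    // Net effect of the validations above: max_input_tokens < max_total_tokens,
    // max_input_tokens <= max_batch_prefill_tokens, and when `max_batch_total_tokens`
    // is set it bounds both max_batch_prefill_tokens and max_total_tokens.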

    // CORS allowed origins
    // map to go inside the option and then map to parse from String to HeaderValue
    // Finally, convert to AllowOrigin
    let cors_allow_origin: Option<AllowOrigin> = cors_allow_origin.map(|cors_allow_origin| {
        AllowOrigin::list(
            cors_allow_origin
                .iter()
                .map(|origin| origin.parse::<HeaderValue>().unwrap()),
        )
    });
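    // Note: an origin that does not parse as a `HeaderValue` will panic here at startup.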

    // Parse Huggingface hub token
    let authorization_token = std::env::var("HF_TOKEN")
        .or_else(|_| std::env::var("HUGGING_FACE_HUB_TOKEN"))
        .ok();
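    // `HF_TOKEN` takes precedence; `HUGGING_FACE_HUB_TOKEN` is kept as a fallback name.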

    // Tokenizer instance
    // This will only be used to validate payloads
    let local_path = Path::new(&tokenizer_name);

    // Shared API builder initialization
    let api_builder = || {
        let mut builder = ApiBuilder::new()
            .with_progress(false)
            .with_token(authorization_token);

        if let Ok(cache_dir) = std::env::var("HUGGINGFACE_HUB_CACHE") {
            builder = builder.with_cache_dir(cache_dir.into());
        }

        builder
    };

    // Decide if we need to use the API based on the revision and local path
    let use_api = revision.is_some() || !local_path.exists() || !local_path.is_dir();

    // Initialize API if needed
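    // Tokenizer artifacts can come from three sources: the live Hub API, the local
    // Hub cache (offline mode), or a plain local directory handled without any API.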
    #[derive(Clone)]
    enum Type {
        Api(Api),
        Cache(Cache),
        None,
    }
    let api = if use_api {
        if std::env::var("HF_HUB_OFFLINE") == Ok("1".to_string()) {
            let cache = Cache::default();
            tracing::warn!("Offline mode active using cache defaults");
            Type::Cache(cache)
        } else {
            tracing::info!("Using the Hugging Face API");
            match api_builder().build() {
                Ok(api) => Type::Api(api),
                Err(_) => {
                    tracing::warn!("Unable to build the Hugging Face API");
                    Type::None
                }
            }
        }
    } else {
        Type::None
    };

    // Load tokenizer and model info
    let (
        tokenizer_filename,
        config_filename,
        tokenizer_config_filename,
        preprocessor_config_filename,
        processor_config_filename,
        model_info,
    ) = match api {
        Type::None => (
            Some(local_path.join("tokenizer.json")),
            Some(local_path.join("config.json")),
            Some(local_path.join("tokenizer_config.json")),
            Some(local_path.join("preprocessor_config.json")),
            Some(local_path.join("processor_config.json")),
            None,
        ),
        Type::Api(api) => {
            let api_repo = api.repo(Repo::with_revision(
                tokenizer_name.to_string(),
                RepoType::Model,
                revision.clone().unwrap_or_else(|| "main".to_string()),
            ));

            let tokenizer_filename = match api_repo.get("tokenizer.json").await {
                Ok(tokenizer_filename) => Some(tokenizer_filename),
                Err(_) => get_base_tokenizer(&api, &api_repo).await,
            };
            let config_filename = api_repo.get("config.json").await.ok();
            let tokenizer_config_filename = api_repo.get("tokenizer_config.json").await.ok();
            let preprocessor_config_filename = api_repo.get("preprocessor_config.json").await.ok();
            let processor_config_filename = api_repo.get("processor_config.json").await.ok();

            let model_info = if let Some(model_info) = get_model_info(&api_repo).await {
                Some(model_info)
            } else {
                tracing::warn!("Could not retrieve model info from the Hugging Face hub.");
                None
            };
            (
                tokenizer_filename,
                config_filename,
                tokenizer_config_filename,
                preprocessor_config_filename,
                processor_config_filename,
                model_info,
            )
        }
        Type::Cache(cache) => {
            let repo = cache.repo(Repo::with_revision(
                tokenizer_name.to_string(),
                RepoType::Model,
                revision.clone().unwrap_or_else(|| "main".to_string()),
            ));
            (
                repo.get("tokenizer.json"),
                repo.get("config.json"),
                repo.get("tokenizer_config.json"),
                repo.get("preprocessor_config.json"),
                repo.get("processor_config.json"),
                None,
            )
        }
    };
    let config: Option<Config> = config_filename.and_then(|filename| {
        std::fs::read_to_string(filename)
            .ok()
            .as_ref()
            .and_then(|c| {
                let config: Result<Config, _> = serde_json::from_str(c);
                if let Err(err) = &config {
                    tracing::warn!("Could not parse config {err:?}");
                }
                config.ok()
            })
    });
    let model_info = model_info.unwrap_or_else(|| HubModelInfo {
        model_id: tokenizer_name.to_string(),
        sha: None,
        pipeline_tag: None,
    });

    // Read the JSON contents of the file as an instance of 'HubTokenizerConfig'.
    let tokenizer_config: Option<HubTokenizerConfig> = if let Some(filename) = tokenizer_config_path
    {
        HubTokenizerConfig::from_file(filename)
    } else {
        tokenizer_config_filename.and_then(HubTokenizerConfig::from_file)
    };
    let tokenizer_config = tokenizer_config.unwrap_or_else(|| {
        tracing::warn!("Could not find tokenizer config locally and no API specified");
        HubTokenizerConfig::default()
    });

    let tokenizer: Option<Tokenizer> = tokenizer_filename.and_then(|filename| {
        let mut tokenizer = Tokenizer::from_file(filename).ok();
        if let Some(tokenizer) = &mut tokenizer {
            if let Some(class) = &tokenizer_config.tokenizer_class {
                if class == "LlamaTokenizer" || class == "LlamaTokenizerFast" {
                    if let Ok(post_processor) = create_post_processor(tokenizer, &tokenizer_config) {
                        tracing::info!("Overriding LlamaTokenizer with TemplateProcessing to follow python override defined in https://github.com/huggingface/transformers/blob/4aa17d00690b7f82c95bb2949ea57e22c35b4336/src/transformers/models/llama/tokenization_llama_fast.py#L203-L205");
                        tokenizer.with_post_processor(post_processor);
                    }
                }
            }
        }
        tokenizer
    });

    let preprocessor_config =
        preprocessor_config_filename.and_then(HubPreprocessorConfig::from_file);
    let processor_config = processor_config_filename
        .and_then(HubProcessorConfig::from_file)
        .unwrap_or_default();

    tracing::info!("Using config {config:?}");
    if tokenizer.is_none() {
        tracing::warn!("Could not find a fast tokenizer implementation for {tokenizer_name}");
        tracing::warn!("Rust input length validation and truncation is disabled");
    }

    // if pipeline-tag == text-generation we default to return_full_text = true
    let compat_return_full_text = match &model_info.pipeline_tag {
        None => {
            tracing::warn!("no pipeline tag found for model {tokenizer_name}");
            true
        }
        Some(pipeline_tag) => pipeline_tag.as_str() == "text-generation",
    };

    // Determine the server port based on the feature and environment variable.
    let port = if cfg!(feature = "google") {
        std::env::var("AIP_HTTP_PORT")
            .map(|aip_http_port| aip_http_port.parse::<u16>().unwrap_or(port))
            .unwrap_or(port)
    } else {
        port
    };

    let addr = match hostname.parse() {
        Ok(ip) => SocketAddr::new(ip, port),
        Err(_) => {
            tracing::warn!("Invalid hostname, defaulting to 0.0.0.0");
            SocketAddr::new(IpAddr::V4(Ipv4Addr::new(0, 0, 0, 0)), port)
        }
    };

    // Run server
    server::run(
        master_shard_uds_path,
        model_info,
        compat_return_full_text,
        max_concurrent_requests,
        max_best_of,
        max_stop_sequences,
        max_top_n_tokens,
        max_input_tokens,
        max_total_tokens,
        waiting_served_ratio,
        max_batch_prefill_tokens,
        max_batch_total_tokens,
        max_waiting_tokens,
        max_batch_size,
        tokenizer,
        config,
        validation_workers,
        addr,
        cors_allow_origin,
        ngrok,
        ngrok_authtoken,
        ngrok_edge,
        tokenizer_config,
        preprocessor_config,
        processor_config,
        messages_api_enabled,
        disable_grammar_support,
        max_client_batch_size,
        print_schema_command,
    )
    .await?;
    Ok(())
}

/// Init logging using env variables LOG_LEVEL and LOG_FORMAT:
///     - otlp_endpoint is an optional URL to an Open Telemetry collector
///     - otlp_service_name service name to appear in APM
///     - LOG_LEVEL may be TRACE, DEBUG, INFO, WARN or ERROR (default to INFO)
///     - LOG_FORMAT may be TEXT or JSON (default to TEXT)
///     - LOG_COLORIZE may be "false" or "true" (default to "true" on ANSI-supported platforms)
fn init_logging(otlp_endpoint: Option<String>, otlp_service_name: String, json_output: bool) {
    let mut layers = Vec::new();

    // STDOUT/STDERR layer
    let ansi = std::env::var("LOG_COLORIZE") != Ok("false".to_string());
    let fmt_layer = tracing_subscriber::fmt::layer()
        .with_file(true)
        .with_ansi(ansi)
        .with_line_number(true);

    let fmt_layer = match json_output {
        true => fmt_layer.json().flatten_event(true).boxed(),
        false => fmt_layer.boxed(),
    };
    layers.push(fmt_layer);

    // OpenTelemetry tracing layer
    if let Some(otlp_endpoint) = otlp_endpoint {
        global::set_text_map_propagator(TraceContextPropagator::new());

        let tracer = opentelemetry_otlp::new_pipeline()
            .tracing()
            .with_exporter(
                opentelemetry_otlp::new_exporter()
                    .tonic()
                    .with_endpoint(otlp_endpoint),
            )
            .with_trace_config(
                trace::config()
                    .with_resource(Resource::new(vec![KeyValue::new(
                        "service.name",
                        otlp_service_name,
                    )]))
                    .with_sampler(Sampler::AlwaysOn),
            )
            .install_batch(opentelemetry::runtime::Tokio);

        if let Ok(tracer) = tracer {
            layers.push(tracing_opentelemetry::layer().with_tracer(tracer).boxed());
            init_tracing_opentelemetry::init_propagator().unwrap();
        };
    }

    // Filter events with LOG_LEVEL
    let varname = "LOG_LEVEL";
    let env_filter = if let Ok(log_level) = std::env::var(varname) {
        // Override to avoid simple logs to be spammed with tokio level informations
        let log_level = match &log_level[..] {
            "warn" => "text_generation_launcher=warn,text_generation_router=warn",
            "info" => "text_generation_launcher=info,text_generation_router=info",
            "debug" => "text_generation_launcher=debug,text_generation_router=debug",
            log_level => log_level,
        };
        EnvFilter::builder()
            .with_default_directive(LevelFilter::INFO.into())
            .parse_lossy(log_level)
    } else {
        EnvFilter::new("info")
    };

    tracing_subscriber::registry()
        .with(env_filter)
        .with(layers)
        .init();
}

/// get model info from the Huggingface Hub
pub async fn get_model_info(api: &ApiRepo) -> Option<HubModelInfo> {
    let response = api.info_request().send().await.ok()?;

    if response.status().is_success() {
        let hub_model_info: HubModelInfo =
            serde_json::from_str(&response.text().await.ok()?).ok()?;
        if let Some(sha) = &hub_model_info.sha {
            tracing::info!(
                "Serving revision {sha} of model {}",
                hub_model_info.model_id
            );
        }
        Some(hub_model_info)
    } else {
        None
    }
}

/// get base tokenizer
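/// If the repo has no `tokenizer.json` of its own but its `config.json` names a
/// `base_model_name_or_path`, fall back to that base model's tokenizer.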
pub async fn get_base_tokenizer(api: &Api, api_repo: &ApiRepo) -> Option<PathBuf> {
    let config_filename = api_repo.get("config.json").await.ok()?;

    // Open the file in read-only mode with buffer.
    let file = File::open(config_filename).ok()?;
    let reader = BufReader::new(file);

    // Read the JSON contents of the file as a generic JSON value.
    let config: serde_json::Value = serde_json::from_reader(reader).ok()?;

    if let Some(serde_json::Value::String(base_model_id)) = config.get("base_model_name_or_path") {
        let api_base_repo = api.repo(Repo::with_revision(
            base_model_id.to_string(),
            RepoType::Model,
            "main".to_string(),
        ));

        api_base_repo.get("tokenizer.json").await.ok()
    } else {
        None
    }
}

/// get tokenizer_config from the Huggingface Hub
pub async fn get_tokenizer_config(api_repo: &ApiRepo) -> Option<HubTokenizerConfig> {
    let tokenizer_config_filename = api_repo.get("tokenizer_config.json").await.ok()?;

    // Open the file in read-only mode with buffer.
    let file = File::open(tokenizer_config_filename).ok()?;
    let reader = BufReader::new(file);

    // Read the JSON contents of the file as an instance of 'HubTokenizerConfig'.
    let tokenizer_config: HubTokenizerConfig = serde_json::from_reader(reader)
        .map_err(|e| {
            tracing::warn!("Unable to parse tokenizer config: {}", e);
            e
        })
        .ok()?;

    Some(tokenizer_config)
}

/// Create a post_processor for the LlamaTokenizer
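///
/// With the defaults (`add_bos_token = true`, `add_eos_token = false`) this builds the
/// templates `single = "<s>:0 $A:0"` and `pair = "<s>:0 $A:0 <s>:1 $B:1"`, as exercised
/// by the test at the bottom of this file.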
pub fn create_post_processor(
    tokenizer: &Tokenizer,
    tokenizer_config: &HubTokenizerConfig,
) -> Result<TemplateProcessing, tokenizers::processors::template::TemplateProcessingBuilderError> {
    let add_bos_token = tokenizer_config.add_bos_token.unwrap_or(true);
    let add_eos_token = tokenizer_config.add_eos_token.unwrap_or(false);

    let bos_token = tokenizer_config.bos_token.as_ref();
    let eos_token = tokenizer_config.eos_token.as_ref();

    if add_bos_token && bos_token.is_none() {
        panic!("add_bos_token = true but bos_token is None");
    }

    if add_eos_token && eos_token.is_none() {
        panic!("add_eos_token = true but eos_token is None");
    }

    let mut single = Vec::new();
    let mut pair = Vec::new();
    let mut special_tokens = Vec::new();

    if add_bos_token {
        if let Some(bos) = bos_token {
            let bos_token_id = tokenizer
                .token_to_id(bos.as_str())
                .expect("Should have found the bos token id");
            special_tokens.push((bos.as_str(), bos_token_id));
            single.push(format!("{}:0", bos.as_str()));
            pair.push(format!("{}:0", bos.as_str()));
        }
    }

    single.push("$A:0".to_string());
    pair.push("$A:0".to_string());

    if add_eos_token {
        if let Some(eos) = eos_token {
            let eos_token_id = tokenizer
                .token_to_id(eos.as_str())
                .expect("Should have found the eos token id");
            special_tokens.push((eos.as_str(), eos_token_id));
            single.push(format!("{}:0", eos.as_str()));
            pair.push(format!("{}:0", eos.as_str()));
        }
    }

    if add_bos_token {
        if let Some(bos) = bos_token {
            pair.push(format!("{}:1", bos.as_str()));
        }
    }

    pair.push("$B:1".to_string());

    if add_eos_token {
        if let Some(eos) = eos_token {
            pair.push(format!("{}:1", eos.as_str()));
        }
    }

    let post_processor = TemplateProcessing::builder()
        .try_single(single)?
        .try_pair(pair)?
        .special_tokens(special_tokens)
        .build()?;

    Ok(post_processor)
}

#[derive(Debug, Error)]
enum RouterError {
    #[error("Argument validation error: {0}")]
    ArgumentValidation(String),
    #[error("WebServer error: {0}")]
    WebServer(#[from] server::WebServerError),
    #[error("Tokio runtime failed to start: {0}")]
    Tokio(#[from] std::io::Error),
}

#[cfg(test)]
mod tests {
    use super::*;
    use text_generation_router::TokenizerConfigToken;

    #[test]
    fn test_create_post_processor() {
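        // Note: pulls the TinyLlama tokenizer from the Hub via `Tokenizer::from_pretrained`,
        // so this test needs network access (or a warm local cache).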
        let tokenizer_config = HubTokenizerConfig {
            add_bos_token: None,
            add_eos_token: None,
            bos_token: Some(TokenizerConfigToken::String("<s>".to_string())),
            eos_token: Some(TokenizerConfigToken::String("</s>".to_string())),
            chat_template: None,
            tokenizer_class: None,
            completion_template: None,
        };

        let tokenizer =
            Tokenizer::from_pretrained("TinyLlama/TinyLlama-1.1B-Chat-v1.0", None).unwrap();
        let post_processor = create_post_processor(&tokenizer, &tokenizer_config).unwrap();

        let expected = TemplateProcessing::builder()
            .try_single("<s>:0 $A:0")
            .unwrap()
            .try_pair("<s>:0 $A:0 <s>:1 $B:1")
            .unwrap()
            .special_tokens(vec![("<s>".to_string(), 1)])
            .build()
            .unwrap();

        assert_eq!(post_processor, expected);
    }
}