use clap::{Parser, ValueEnum};
use hf_hub::{
    api::sync::{Api, ApiBuilder},
    Repo, RepoType,
};
use nix::sys::signal::{self, Signal};
use nix::unistd::Pid;
use regex::Regex;
use serde::Deserialize;
use std::env;
use std::ffi::OsString;
use std::io::{BufRead, BufReader};
use std::os::unix::process::{CommandExt, ExitStatusExt};
use std::path::Path;
use std::process::{Child, Command, ExitStatus, Stdio};
use std::sync::atomic::{AtomicBool, Ordering};
use std::sync::mpsc::TryRecvError;
use std::sync::{mpsc, Arc};
use std::thread;
use std::thread::sleep;
use std::time::{Duration, Instant};
use std::{
    fs, io,
    io::{Read, Write},
};
use thiserror::Error;
use tracing_subscriber::{filter::LevelFilter, EnvFilter};

mod env_runtime;
mod gpu;

fn compute_optimal(config: Option<&Config>, compute: Option<&ComputeType>) -> Option<usize> {
    if let (Some(config), Some(compute)) = (config, compute) {
        if let (Some(f16_max_compute), Some(model_compute)) = (compute.f16_flop(), config.flop()) {
            tracing::debug!("Max compute {f16_max_compute} model compute {model_compute}");
            let optimal_size = (f16_max_compute / model_compute) as usize;
            if optimal_size > 100 {
                // Ignore calculations that are too low
                // Most likely an error
                Some(optimal_size)
            } else {
                None
            }
        } else {
            None
        }
    } else {
        None
    }
}
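
// Worked example (illustrative only; the numbers are hypothetical and not taken
// from a real GPU or model): with `compute.f16_flop()` = 312e12 FLOP/s and
// `config.flop()` = 1e9 FLOP per token, the ratio is 312_000. That passes the
// `> 100` sanity check, so `Some(312_000)` is returned; a ratio of 100 or less
// is treated as a likely measurement error and yields `None`.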

fn get_config(
    model_id: &str,
    revision: &Option<String>,
) -> Result<Config, Box<dyn std::error::Error>> {
    let mut path = std::path::Path::new(model_id).to_path_buf();
    let model_id = model_id.to_string();
    let filename = if !path.exists() {
        // Assume it's a hub id

        let api = if let Ok(token) = std::env::var("HF_TOKEN") {
            // env variable has precedence over on file token.
            ApiBuilder::new().with_token(Some(token)).build()?
        } else {
            Api::new()?
        };
        let repo = if let Some(ref revision) = revision {
            api.repo(Repo::with_revision(
                model_id,
                RepoType::Model,
                revision.to_string(),
            ))
        } else {
            api.model(model_id)
        };
        repo.get("config.json")?
    } else {
        path.push("config.json");
        path
    };

    let content = std::fs::read_to_string(filename)?;
    let config: RawConfig = serde_json::from_str(&content)?;

    let config: Config = config.into();
    Ok(config)
}

fn resolve_attention(config: &Option<Config>, lora_adapters: &Option<String>) -> (String, String) {
    let compute_capability = gpu::get_cuda_capability();
    let mut prefix_caching: Option<String> = std::env::var("PREFIX_CACHING").ok();
    let mut attention: Option<String> = std::env::var("ATTENTION").ok();
    if let Some(config) = config {
        if prefix_caching.is_none() {
            if config.vision_config.is_some() {
                tracing::info!("Disabling prefix caching because of VLM model");
                prefix_caching = Some("0".to_string());
            } else if config.is_encoder_decoder {
                tracing::info!("Disabling prefix caching because of seq2seq model");
                prefix_caching = Some("0".to_string());
            }
        }

        let fallback_attention = if matches!(compute_capability, Some((major, _)) if major < 8) {
            "paged"
        } else {
            "flashdecoding"
        };

        match config.head_dim {
            Some(h) if h == 64 || h == 128 || h == 256 => {
                if lora_adapters.is_some() && prefix_caching.is_none() {
                    tracing::info!("Disabling prefix caching because of lora adapters");
                    prefix_caching = Some("0".to_string());
                }
                match config.model_type.as_deref() {
                    Some("falcon") | Some("deepseek_v2") => {
                        // Required because gemma2 needs bfloat16 which is not supported by
                        // flashinfer ?
                        if attention.is_none() {
                            tracing::info!(
                                "Forcing attention to '{fallback_attention}' because model {} requires it",
                                config.model_type.as_ref().unwrap()
                            );
                            attention = Some(fallback_attention.to_string());
                        }
                        if fallback_attention == "paged" && prefix_caching.is_none() {
                            tracing::info!("Disabling prefix caching because it is not supported with 'paged' attention");
                            prefix_caching = Some("0".to_string());
                        }
                    }
                    Some("t5") => {}
                    _ => {}
                }
            }
            _ => {
                if attention.is_none() {
                    tracing::info!("Forcing attention to '{fallback_attention}' because head dim is not supported by flashinfer, also disabling prefix caching");
                    attention = Some(fallback_attention.to_string());
                }
                if prefix_caching.is_none() {
                    prefix_caching = Some("0".to_string());
                }
            }
        }
    }
    if attention == Some("paged".to_string()) && prefix_caching.is_none() {
        tracing::info!("Disabling prefix caching on paged attention");
        prefix_caching = Some("0".to_string());
    }

    let attention = attention.unwrap_or("flashinfer".to_string());
    let prefix_caching = prefix_caching.unwrap_or("true".to_string());

    (prefix_caching, attention)
}
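
// Illustrative outcomes (hypothetical configs, not taken from real models):
//   head_dim = 128, no LoRA, no vision tower, ATTENTION unset       -> ("true", "flashinfer")
//   model_type = "falcon", head_dim = 64, compute capability < 8.0  -> ("0", "paged")
//   vision_config present, head_dim = 64, ATTENTION unset           -> ("0", "flashinfer")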

#[derive(Deserialize)]
struct RawConfig {
    max_position_embeddings: Option<usize>,
    n_positions: Option<usize>,
    model_type: Option<String>,
    max_seq_len: Option<usize>,
    quantization_config: Option<QuantizationConfig>,
    n_embd: Option<usize>,
    hidden_size: Option<usize>,
    intermediate_size: Option<usize>,
    num_attention_heads: Option<usize>,
    num_key_value_heads: Option<usize>,
    num_hidden_layers: Option<usize>,
    head_dim: Option<usize>,
    vision_config: Option<VisionConfig>,
    is_encoder_decoder: Option<bool>,
    #[serde(rename = "num_experts_per_tok")]
    num_experts_per_token: Option<usize>,
    #[serde(rename = "n_shared_experts")]
    num_shared_experts: Option<usize>,
}

#[derive(Deserialize)]
struct QuantizationConfig {
    quant_method: Option<Quantization>,
}

#[derive(Debug, Deserialize)]
struct VisionConfig {}

#[derive(Debug, Deserialize)]
struct Config {
    max_position_embeddings: Option<usize>,
    quantize: Option<Quantization>,
    head_dim: Option<usize>,
    num_heads: Option<usize>,
    num_kv_heads: Option<usize>,
    num_layers: Option<usize>,
    intermediate_size: Option<usize>,
    hidden_size: Option<usize>,
    model_type: Option<String>,
    vision_config: Option<VisionConfig>,
    is_encoder_decoder: bool,
    num_experts_per_token: usize,
    num_shared_experts: usize,
}

impl Config {
    fn flop(&self) -> Option<u64> {
        if self.vision_config.is_some() {
            // VLMs are much harder to predict and their VRAM
            // requirements are more complex.
            return None;
        }
        let num_heads = self.num_heads? as u64;
        let num_kv_heads = self.num_kv_heads? as u64;
        let head_dim = self.head_dim? as u64;
        let hidden_size = self.hidden_size? as u64;
        let intermediate_size = (self.intermediate_size?
            * (self.num_experts_per_token + self.num_shared_experts))
            as u64;
        let num_layers = self.num_layers? as u64;

        let q_flops = 2 * num_heads * head_dim * hidden_size;
        let k_flops = 2 * num_kv_heads * head_dim * hidden_size;
        let v_flops = 2 * num_kv_heads * head_dim * hidden_size;
        let attn_flops = 2 * num_heads * head_dim * hidden_size;
        let o_flops = 2 * num_heads * head_dim * hidden_size;
        let attn_layer_flops = q_flops + k_flops + v_flops + attn_flops + o_flops;

        let gate_up_down_flops = 2 * 3 * hidden_size * intermediate_size;

        let layer_flops = attn_layer_flops + gate_up_down_flops;
        let total = layer_flops * num_layers;
        Some(total)
    }
}
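
// A minimal sketch of how the estimate above composes, using a small hypothetical
// dense configuration (the values below are made up for illustration, not a real model).
#[cfg(test)]
mod flop_estimate_tests {
    use super::*;

    #[test]
    fn flop_estimate_for_small_dense_config() {
        let config = Config {
            max_position_embeddings: None,
            quantize: None,
            head_dim: Some(64),
            num_heads: Some(8),
            num_kv_heads: Some(8),
            num_layers: Some(2),
            intermediate_size: Some(2048),
            hidden_size: Some(512),
            model_type: None,
            vision_config: None,
            is_encoder_decoder: false,
            num_experts_per_token: 1,
            num_shared_experts: 0,
        };
        // Attention projections: 5 * (2 * 8 * 64 * 512) = 2_621_440 FLOP per layer.
        // MLP (gate/up/down):    2 * 3 * 512 * 2048     = 6_291_456 FLOP per layer.
        // Total over 2 layers:  (2_621_440 + 6_291_456) * 2 = 17_825_792 FLOP.
        assert_eq!(config.flop(), Some(17_825_792));
    }
}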

impl From<RawConfig> for Config {
    fn from(other: RawConfig) -> Self {
        let max_position_embeddings = other
            .max_position_embeddings
            .or(other.max_seq_len)
            .or(other.n_positions);
        let quantize = other.quantization_config.and_then(|q| q.quant_method);
        let hidden_size = other.hidden_size.or(other.n_embd);
        let head_dim = other
            .head_dim
            .or_else(|| match (hidden_size, other.num_attention_heads) {
                (Some(hidden_size), Some(num_attention_heads))
                    if hidden_size % num_attention_heads == 0 =>
                {
                    Some(hidden_size / num_attention_heads)
                }
                _ => None,
            });
        let num_heads = other.num_attention_heads;
        let num_layers = other.num_hidden_layers;
        let num_kv_heads = other.num_key_value_heads.or(other.num_attention_heads);
        let intermediate_size = other.intermediate_size;
        let model_type = other.model_type;
        let vision_config = other.vision_config;
        let is_encoder_decoder = other.is_encoder_decoder.unwrap_or(false);
        let num_experts_per_token = other.num_experts_per_token.unwrap_or(1);
        let num_shared_experts = other.num_shared_experts.unwrap_or(0);
        Config {
            max_position_embeddings,
            quantize,
            head_dim,
            model_type,
            vision_config,
            is_encoder_decoder,
            hidden_size,
            num_heads,
            num_kv_heads,
            intermediate_size,
            num_layers,
            num_experts_per_token,
            num_shared_experts,
        }
    }
}
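
// A minimal sketch of the conversion above, using a hypothetical config.json
// fragment (the JSON values are made up for illustration).
#[cfg(test)]
mod raw_config_conversion_tests {
    use super::*;

    #[test]
    fn derives_head_dim_and_kv_heads() {
        let raw: RawConfig = serde_json::from_str(
            r#"{"hidden_size": 512, "num_attention_heads": 8, "num_hidden_layers": 2}"#,
        )
        .unwrap();
        let config: Config = raw.into();
        // head_dim falls back to hidden_size / num_attention_heads = 512 / 8.
        assert_eq!(config.head_dim, Some(64));
        // num_kv_heads defaults to num_attention_heads when not provided.
        assert_eq!(config.num_kv_heads, Some(8));
        assert!(!config.is_encoder_decoder);
    }
}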

#[derive(Clone, Copy, Debug, ValueEnum, Deserialize)]
#[serde(rename_all = "kebab-case")]
enum Quantization {
    /// 4 bit quantization. Requires a specific AWQ quantized model:
    ///   <https://hf.co/models?search=awq>.
    /// Should replace GPTQ models wherever possible because of the better latency
    Awq,
    /// Compressed tensors, which can be a mixture of different quantization methods.
    CompressedTensors,
    /// 8 bit quantization, doesn't require a specific model.
    /// Should be a drop-in replacement to bitsandbytes with much better performance.
    /// Kernels are from <https://github.com/NetEase-FuXi/EETQ.git>
    Eetq,
    /// Variable bit quantization. Requires a specific EXL2 quantized model:
    /// <https://hf.co/models?search=exl2>. Requires exllama2 kernels and does
    /// not support tensor parallelism (num_shard > 1).
    Exl2,
    /// 4 bit quantization. Requires a specific GPTQ quantized model: <https://hf.co/models?search=gptq>.
    /// text-generation-inference will use exllama (faster) kernels wherever possible, and use
    /// triton kernel (wider support) when it's not.
    /// AWQ has faster kernels.
    Gptq,
    /// 4 bit quantization. Requires a specific Marlin quantized model: <https://hf.co/models?search=marlin>.
    Marlin,
    /// Bitsandbytes 8bit. Can be applied on any model, will cut the memory requirement in half,
    /// but it is known that the model will be much slower to run than the native f16.
    // #[deprecated(
    //     since = "1.1.0",
    //     note = "Use `eetq` instead, which provides better latencies overall and is drop-in in most cases"
    // )]
    Bitsandbytes,
    /// Bitsandbytes 4bit. Can be applied on any model, will cut the memory requirement by 4x,
    /// but it is known that the model will be much slower to run than the native f16.
    BitsandbytesNf4,
    /// Bitsandbytes 4bit. nf4 should be preferred in most cases but maybe this one has better
    /// perplexity performance for your model
    BitsandbytesFp4,
    /// [FP8](https://developer.nvidia.com/blog/nvidia-arm-and-intel-publish-fp8-specification-for-standardization-as-an-interchange-format-for-ai/) (e4m3) works on H100 and above
    /// This dtype has native ops and should be the fastest if available.
    /// This is currently not the fastest because of local unpacking + padding to satisfy matrix
    /// multiplication limitations.
    Fp8,
}

impl std::fmt::Display for Quantization {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        // To keep in sync with `server`.
        match self {
            #[allow(deprecated)]
            // Use `eetq` instead, which provides better latencies overall and is drop-in in most cases
            Quantization::Bitsandbytes => {
                write!(f, "bitsandbytes")
            }
            Quantization::BitsandbytesNf4 => {
                write!(f, "bitsandbytes-nf4")
            }
            Quantization::BitsandbytesFp4 => {
                write!(f, "bitsandbytes-fp4")
            }
            Quantization::Exl2 => {
                write!(f, "exl2")
            }
            Quantization::Gptq => {
                write!(f, "gptq")
            }
            Quantization::Marlin => {
                write!(f, "marlin")
            }
            Quantization::Awq => {
                write!(f, "awq")
            }
            Quantization::CompressedTensors => {
                write!(f, "compressed-tensors")
            }
            Quantization::Eetq => {
                write!(f, "eetq")
            }
            Quantization::Fp8 => {
                write!(f, "fp8")
            }
        }
    }
}

#[derive(Clone, Copy, Debug, ValueEnum)]
enum Dtype {
    Float16,
    #[clap(name = "bfloat16")]
    BFloat16,
}

impl std::fmt::Display for Dtype {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        // To keep in sync with `server`.
        match self {
            Dtype::Float16 => {
                write!(f, "float16")
            }
            Dtype::BFloat16 => {
                write!(f, "bfloat16")
            }
        }
    }
}

#[derive(Clone, Copy, Debug, ValueEnum)]
enum KVCacheDtype {
    #[clap(name = "fp8_e4m3fn")]
    Fp8e4m3fn,

    #[clap(name = "fp8_e5m2")]
    Fp8e5m2,
}

impl std::fmt::Display for KVCacheDtype {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        match self {
            KVCacheDtype::Fp8e4m3fn => {
                write!(f, "fp8_e4m3fn")
            }
            KVCacheDtype::Fp8e5m2 => {
                write!(f, "fp8_e5m2")
            }
        }
    }
}

#[derive(Clone, Copy, Debug, ValueEnum)]
enum RopeScaling {
    Linear,
    Dynamic,
}

impl std::fmt::Display for RopeScaling {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        // To keep in sync with `server`.
        match self {
            RopeScaling::Linear => {
                write!(f, "linear")
            }
            RopeScaling::Dynamic => {
                write!(f, "dynamic")
            }
        }
    }
}

#[derive(Clone, Copy, Debug, ValueEnum)]
pub enum UsageStatsLevel {
    /// Default option, usage statistics are collected anonymously
    On,
    /// Disables all collection of usage statistics
    Off,
    /// Doesn't send the error stack trace or error type, but allows sending a crash event
    NoStack,
}

impl std::fmt::Display for UsageStatsLevel {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        // To keep in sync with `server`.
        match self {
            UsageStatsLevel::On => {
                write!(f, "on")
            }
            UsageStatsLevel::Off => {
                write!(f, "off")
            }
            UsageStatsLevel::NoStack => {
                write!(f, "no-stack")
            }
        }
    }
}

/// App Configuration
#[derive(Parser, Debug)]
#[clap(author, version, about, long_about = None)]
struct Args {
    /// The name of the model to load.
    /// Can be a MODEL_ID as listed on <https://hf.co/models> like
    /// `gpt2` or `OpenAssistant/oasst-sft-1-pythia-12b`.
    /// Or it can be a local directory containing the necessary files
    /// as saved by `save_pretrained(...)` methods of transformers
    #[clap(default_value = "bigscience/bloom-560m", long, env)]
    model_id: String,

    /// The actual revision of the model if you're referring to a model
    /// on the hub. You can use a specific commit id or a branch like `refs/pr/2`.
    #[clap(long, env)]
    revision: Option<String>,

    /// The number of tokenizer workers used for payload validation and truncation inside the
    /// router.
    #[clap(default_value = "2", long, env)]
    validation_workers: usize,

    /// Whether to shard the model across multiple GPUs
    /// By default text-generation-inference will use all available GPUs to run
    /// the model. Setting it to `false` deactivates `num_shard`.
    #[clap(long, env)]
    sharded: Option<bool>,

    /// The number of shards to use if you don't want to use all GPUs on a given machine.
    /// You can use `CUDA_VISIBLE_DEVICES=0,1 text-generation-launcher... --num-shard 2`
    /// and `CUDA_VISIBLE_DEVICES=2,3 text-generation-launcher... --num-shard 2` to
    /// launch 2 copies with 2 shards each on a given machine with 4 GPUs for instance.
    #[clap(long, env)]
    num_shard: Option<usize>,

    /// Quantization method to use for the model. It is not necessary to specify this option
    /// for pre-quantized models, since the quantization method is read from the model
    /// configuration.
    ///
    /// Marlin kernels will be used automatically for GPTQ/AWQ models.
    #[clap(long, env, value_enum)]
    quantize: Option<Quantization>,

    /// The number of input_ids to speculate on
    /// If using a medusa model, the heads will be picked up automatically
    /// Otherwise, it will use n-gram speculation which is relatively free
    /// in terms of compute, but the speedup heavily depends on the task.
    #[clap(long, env)]
    speculate: Option<usize>,

    /// The dtype to be forced upon the model. This option cannot be used with `--quantize`.
    #[clap(long, env, value_enum)]
    dtype: Option<Dtype>,

    /// Specify the dtype for the key-value cache. When this option is not provided,
    /// the dtype of the model is used (typically `float16` or `bfloat16`). Currently
    /// the only supported values are `fp8_e4m3fn` and `fp8_e5m2` on CUDA.
    #[clap(long, env, value_enum)]
    kv_cache_dtype: Option<KVCacheDtype>,

    /// Whether you want to execute hub modelling code. Explicitly passing a `revision` is
    /// encouraged when loading a model with custom code to ensure no malicious code has been
    /// contributed in a newer revision.
    #[clap(long, env, value_enum)]
    trust_remote_code: bool,

    /// The maximum amount of concurrent requests for this particular deployment.
    /// Having a low limit will refuse client requests instead of having them
    /// wait for too long and is usually good to handle backpressure correctly.
    #[clap(default_value = "128", long, env)]
    max_concurrent_requests: usize,

    /// This is the maximum allowed value for clients to set `best_of`.
    /// Best of makes `n` generations at the same time, and returns the best
    /// in terms of overall log probability over the entire generated sequence
    #[clap(default_value = "2", long, env)]
    max_best_of: usize,

    /// This is the maximum allowed value for clients to set `stop_sequences`.
    /// Stop sequences are used to allow the model to stop on more than just
    /// the EOS token, and enable more complex "prompting" where users can preprompt
    /// the model in a specific way and define their "own" stop token aligned with
    /// their prompt.
    #[clap(default_value = "4", long, env)]
    max_stop_sequences: usize,

    /// This is the maximum allowed value for clients to set `top_n_tokens`.
    /// `top_n_tokens` is used to return information about the `n` most likely
    /// tokens at each generation step, instead of just the sampled token. This
    /// information can be used for downstream tasks like classification or
    /// ranking.
    #[clap(default_value = "5", long, env)]
    max_top_n_tokens: u32,

    /// This is the maximum allowed input length (expressed in number of tokens)
    /// for users. The larger this value, the longer prompt users can send which
    /// can impact the overall memory required to handle the load.
    /// Please note that some models have a finite range of sequence lengths they can handle.
    /// Default to min(max_allocatable, max_position_embeddings) - 1
    #[clap(long, env)]
    max_input_tokens: Option<usize>,

    /// Legacy version of [`Args::max_input_tokens`].
    #[clap(long, env)]
    max_input_length: Option<usize>,

    /// This is the most important value to set as it defines the "memory budget"
    /// of running clients requests.
    /// Clients will send input sequences and ask to generate `max_new_tokens`
    /// on top. With a value of `1512` users can send either a prompt of
    /// `1000` and ask for `512` new tokens, or send a prompt of `1` and ask for
    /// `1511` max_new_tokens.
    /// The larger this value, the more RAM each request will occupy
    /// and the less effective batching can be.
    /// Default to min(max_allocatable, max_position_embeddings)
    #[clap(long, env)]
    max_total_tokens: Option<usize>,

    /// This represents the ratio of waiting queries vs running queries where
    /// you want to start considering pausing the running queries to include the waiting
    /// ones into the same batch.
    /// `waiting_served_ratio=1.2` means when 12 queries are waiting and there's
    /// only 10 queries left in the current batch we check if we can fit those 12
    /// waiting queries into the batching strategy, and if yes, then batching happens
    /// delaying the 10 running queries by a `prefill` run.
    ///
    /// This setting is only applied if there is room in the batch
    /// as defined by `max_batch_total_tokens`.
    #[clap(default_value = "0.3", long, env)]
    waiting_served_ratio: f32,

    /// Limits the number of tokens for the prefill operation.
    /// Since this operation takes the most memory and is compute bound, it is interesting
    /// to limit the number of requests that can be sent.
    /// Default to `max_input_tokens + 50` to give a bit of room.
    #[clap(long, env)]
    max_batch_prefill_tokens: Option<u32>,

    /// **IMPORTANT** This is one critical control to allow maximum usage
    /// of the available hardware.
    ///
    /// This represents the total amount of potential tokens within a batch.
    /// When using padding (not recommended) this would be equivalent of
    /// `batch_size` * `max_total_tokens`.
    ///
    /// However in the non-padded (flash attention) version this can be much finer.
    ///
    /// For `max_batch_total_tokens=1000`, you could fit `10` queries of `total_tokens=100`
    /// or a single query of `1000` tokens.
    ///
    /// Overall this number should be the largest possible amount that fits the
    /// remaining memory (after the model is loaded). Since the actual memory overhead
    /// depends on other parameters like if you're using quantization, flash attention
    /// or the model implementation, text-generation-inference cannot infer this number
    /// automatically.
    #[clap(long, env)]
    max_batch_total_tokens: Option<u32>,

    /// This setting defines how many tokens can be passed before forcing the waiting
    /// queries to be put on the batch (if the size of the batch allows for it).
    /// New queries require 1 `prefill` forward, which is different from `decode`
    /// and therefore you need to pause the running batch in order to run `prefill`
    /// to create the correct values for the waiting queries to be able to join the batch.
    ///
    /// With a value too small, queries will always "steal" the compute to run `prefill`
    /// and running queries will be delayed by a lot.
    ///
    /// With a value too big, waiting queries could wait for a very long time
    /// before being allowed a slot in the running batch. If your server is busy
    /// that means that requests that could run in ~2s on an empty server could
    /// end up running in ~20s because the query had to wait for 18s.
    ///
    /// This number is expressed in number of tokens to make it a bit more
    /// "model" agnostic, but what should really matter is the overall latency
    /// for end users.
    #[clap(default_value = "20", long, env)]
    max_waiting_tokens: usize,

    /// Enforce a maximum number of requests per batch
    /// Specific flag for hardware targets that do not support unpadded inference
    #[clap(long, env)]
    max_batch_size: Option<usize>,

    /// Specify the batch sizes to compute cuda graphs for.
    /// Use "0" to disable.
    /// Default = "1,2,4,8,16,32"
    #[clap(long, env, value_delimiter = ',')]
    cuda_graphs: Option<Vec<usize>>,

    /// The IP address to listen on
    #[clap(default_value = "0.0.0.0", long, env)]
    hostname: String,

    /// The port to listen on.
    #[clap(default_value = "3000", long, short, env)]
    port: u16,

    /// The name of the socket for gRPC communication between the webserver
    /// and the shards.
    #[clap(default_value = "/tmp/text-generation-server", long, env)]
    shard_uds_path: String,

    /// The address the master shard will listen on. (setting used by torch distributed)
    #[clap(default_value = "localhost", long, env)]
    master_addr: String,

    /// The port the master shard will listen on. (setting used by torch distributed)
    #[clap(default_value = "29500", long, env)]
    master_port: usize,

    /// The location of the huggingface hub cache.
    /// Used to override the location if you want to provide a mounted disk for instance
    #[clap(long, env)]
    huggingface_hub_cache: Option<String>,

    /// The location of the huggingface hub cache.
    /// Used to override the location if you want to provide a mounted disk for instance
    #[clap(long, env)]
    weights_cache_override: Option<String>,

    /// For some models (like bloom), text-generation-inference implemented custom
    /// cuda kernels to speed up inference. Those kernels were only tested on A100.
    /// Use this flag to disable them if you're running on different hardware and
    /// encounter issues.
    #[clap(long, env)]
    disable_custom_kernels: bool,

    /// Limit the CUDA available memory.
    /// The allowed value equals the total visible memory multiplied by cuda-memory-fraction.
    #[clap(default_value = "1.0", long, env)]
    cuda_memory_fraction: f32,

    /// Rope scaling will only be used for RoPE models
    /// and allow rescaling the position rotary to accommodate
    /// larger prompts.
    ///
    /// Goes together with `rope_factor`.
    ///
    /// `--rope-factor 2.0` gives linear scaling with a factor of 2.0
    /// `--rope-scaling dynamic` gives dynamic scaling with a factor of 1.0
    /// `--rope-scaling linear` gives linear scaling with a factor of 1.0 (Nothing will be changed
    /// basically)
    ///
    /// `--rope-scaling linear --rope-factor` fully describes the scaling you want
    #[clap(long, env)]
    rope_scaling: Option<RopeScaling>,

    /// Rope scaling will only be used for RoPE models
    /// See `rope_scaling`
    #[clap(long, env)]
    rope_factor: Option<f32>,

    /// Outputs the logs in JSON format (useful for telemetry)
    #[clap(long, env)]
    json_output: bool,

    #[clap(long, env)]
    otlp_endpoint: Option<String>,

    #[clap(default_value = "text-generation-inference.router", long, env)]
    otlp_service_name: String,

    #[clap(long, env)]
    cors_allow_origin: Vec<String>,

    #[clap(long, env)]
    api_key: Option<String>,

    #[clap(long, env)]
    watermark_gamma: Option<f32>,
    #[clap(long, env)]
    watermark_delta: Option<f32>,

    /// Enable ngrok tunneling
    #[clap(long, env)]
    ngrok: bool,

    /// ngrok authentication token
    #[clap(long, env)]
    ngrok_authtoken: Option<String>,

    /// ngrok edge
    #[clap(long, env)]
    ngrok_edge: Option<String>,

    /// The path to the tokenizer config file. This path is used to load the tokenizer configuration which may
    /// include a `chat_template`. If not provided, the default config will be used from the model hub.
    #[clap(long, env)]
    tokenizer_config_path: Option<String>,

    /// Disable outlines grammar constrained generation.
    /// This is a feature that allows you to generate text that follows a specific grammar.
    #[clap(long, env)]
    disable_grammar_support: bool,

    /// Display a lot of information about your runtime environment
    #[clap(long, short, action)]
    env: bool,

    /// Control the maximum number of inputs that a client can send in a single request
    #[clap(default_value = "4", long, env)]
    max_client_batch_size: usize,

    /// Lora Adapters a list of adapter ids i.e. `repo/adapter1,repo/adapter2` to load during
    /// startup that will be available to callers via the `adapter_id` field in a request.
    #[clap(long, env)]
    lora_adapters: Option<String>,

    /// Control if anonymous usage stats are collected.
    /// Options are "on", "off" and "no-stack"
    /// Default is on.
    #[clap(default_value = "on", long, env)]
    usage_stats: UsageStatsLevel,

    /// Payload size limit in bytes
    ///
    /// Default is 2MB
    #[clap(default_value = "2000000", long, env)]
    payload_limit: usize,

    /// Enables prefill logprobs
    ///
    /// Logprobs in the prompt are deactivated by default because they consume
    /// a large amount of VRAM (especially for long prompts).
    /// Using this flag allows users to ask for them again.
    #[clap(long, env)]
    enable_prefill_logprobs: bool,
}
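
// Example invocation (illustrative only; the model id, shard count and token
// budgets below are hypothetical, not recommendations):
//
//   text-generation-launcher \
//       --model-id bigscience/bloom-560m \
//       --num-shard 2 \
//       --max-input-tokens 1024 \
//       --max-total-tokens 2048 \
//       --quantize bitsandbytes-nf4
//
// Each flag can also be set through the matching environment variable
// (e.g. MODEL_ID, NUM_SHARD) since the clap attributes above declare `env`.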

#[derive(Debug)]
enum ShardStatus {
    Ready,
    Failed(usize),
}

#[allow(clippy::too_many_arguments)]
fn shard_manager(
    model_id: String,
    revision: Option<String>,
    quantize: Option<Quantization>,
    speculate: Option<usize>,
    dtype: Option<Dtype>,
    kv_cache_dtype: Option<KVCacheDtype>,
    trust_remote_code: bool,
    uds_path: String,
    rank: usize,
    world_size: usize,
    master_addr: String,
    master_port: usize,
    huggingface_hub_cache: Option<String>,
    weights_cache_override: Option<String>,
    disable_custom_kernels: bool,
    watermark_gamma: Option<f32>,
    watermark_delta: Option<f32>,
    cuda_graphs: Vec<usize>,
    cuda_memory_fraction: f32,
    rope_scaling: Option<RopeScaling>,
    rope_factor: Option<f32>,
    max_total_tokens: Option<usize>,
    max_batch_size: Option<usize>,
    max_input_tokens: Option<usize>,
    lora_adapters: Option<String>,
    enable_prefill_logprobs: bool,
    otlp_endpoint: Option<String>,
    otlp_service_name: String,
    log_level: LevelFilter,
    status_sender: mpsc::Sender<ShardStatus>,
    shutdown: Arc<AtomicBool>,
    _shutdown_sender: mpsc::Sender<()>,
) {
    // Enter shard-manager tracing span
    let _span = tracing::span!(tracing::Level::INFO, "shard-manager", rank = rank).entered();

    // Get UDS path
    let uds_string = format!("{uds_path}-{rank}");
    let uds = Path::new(&uds_string);
    // Clean previous runs
    if uds.exists() {
        fs::remove_file(uds).unwrap();
    }

    // Process args
    let mut shard_args = vec![
        "serve".to_string(),
        model_id,
        "--uds-path".to_string(),
        uds_path,
        "--logger-level".to_string(),
        log_level.to_string().to_uppercase(),
        "--json-output".to_string(),
    ];

    // Activate trust remote code
    if trust_remote_code {
        shard_args.push("--trust-remote-code".to_string());
    }

    // Activate tensor parallelism
    if world_size > 1 {
        shard_args.push("--sharded".to_string());
    }

    if let Some(quantize) = quantize {
        shard_args.push("--quantize".to_string());
        shard_args.push(quantize.to_string())
    }

    if let Some(speculate) = speculate {
        shard_args.push("--speculate".to_string());
        shard_args.push(speculate.to_string())
    }

    if let Some(dtype) = dtype {
        shard_args.push("--dtype".to_string());
        shard_args.push(dtype.to_string())
    }

    if let Some(kv_cache_dtype) = kv_cache_dtype {
        shard_args.push("--kv-cache-dtype".to_string());
        shard_args.push(kv_cache_dtype.to_string())
    }

    // Model optional revision
    if let Some(revision) = revision {
        shard_args.push("--revision".to_string());
        shard_args.push(revision)
    }

    let rope = match (rope_scaling, rope_factor) {
        (None, None) => None,
        (Some(scaling), None) => Some((scaling, 1.0)),
        (Some(scaling), Some(factor)) => Some((scaling, factor)),
        (None, Some(factor)) => Some((RopeScaling::Linear, factor)),
    };
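
    // How the two flags combine (restating the match above for reference):
    //   neither --rope-scaling nor --rope-factor  -> None (no rope env vars exported)
    //   --rope-scaling dynamic only               -> Some((Dynamic, 1.0))
    //   --rope-factor 2.0 without --rope-scaling  -> Some((Linear, 2.0))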

    // OpenTelemetry Endpoint
    if let Some(otlp_endpoint) = otlp_endpoint {
        shard_args.push("--otlp-endpoint".to_string());
        shard_args.push(otlp_endpoint);
    }

    // OpenTelemetry Service Name
    shard_args.push("--otlp-service-name".to_string());
    shard_args.push(otlp_service_name);

    // When a sliding window is used, some backends may ignore it in flash attention depending on this parameter.
    if let Some(max_input_tokens) = max_input_tokens {
        shard_args.push("--max-input-tokens".to_string());
        shard_args.push(max_input_tokens.to_string());
    }

    // Copy current process env
    let mut envs: Vec<(OsString, OsString)> = env::vars_os().collect();

    // Remove LOG_LEVEL if present
    envs.retain(|(name, _)| name != "LOG_LEVEL");

    // Torch Distributed Env vars
    envs.push(("RANK".into(), rank.to_string().into()));
    envs.push(("WORLD_SIZE".into(), world_size.to_string().into()));
    envs.push(("MASTER_ADDR".into(), master_addr.into()));
    envs.push(("MASTER_PORT".into(), master_port.to_string().into()));
    envs.push(("TORCH_NCCL_AVOID_RECORD_STREAMS".into(), "1".into()));

    // CUDA memory fraction
    envs.push((
        "CUDA_MEMORY_FRACTION".into(),
        cuda_memory_fraction.to_string().into(),
    ));

    // Safetensors load fast
    envs.push(("SAFETENSORS_FAST_GPU".into(), "1".into()));

    // Disable progress bar
    envs.push(("HF_HUB_DISABLE_PROGRESS_BARS".into(), "1".into()));

    // Enable hf transfer for insane download speeds
    let enable_hf_transfer = env::var("HF_HUB_ENABLE_HF_TRANSFER").unwrap_or("1".to_string());
    envs.push((
        "HF_HUB_ENABLE_HF_TRANSFER".into(),
        enable_hf_transfer.into(),
    ));

    // Parse Inference API token
    if let Ok(api_token) = env::var("HF_API_TOKEN") {
        envs.push(("HF_TOKEN".into(), api_token.into()))
    };

    // Detect rope scaling
    // Sending as env instead of CLI args to not bloat everything
    // those only can be used by RoPE models, so passing information around
    // for all models will complexify code unnecessarily
    if let Some((scaling, factor)) = rope {
        envs.push(("ROPE_SCALING".into(), scaling.to_string().into()));
        envs.push(("ROPE_FACTOR".into(), factor.to_string().into()));
    }

    if let Some(max_total_tokens) = max_total_tokens {
        envs.push((
            "MAX_TOTAL_TOKENS".into(),
            max_total_tokens.to_string().into(),
        ));
    }
    if let Some(max_batch_size) = max_batch_size {
        envs.push(("MAX_BATCH_SIZE".into(), max_batch_size.to_string().into()));
    }

    // Lora Adapters
    if let Some(lora_adapters) = lora_adapters {
        envs.push(("LORA_ADAPTERS".into(), lora_adapters.into()));
    }

    // Logprobs
    if enable_prefill_logprobs {
        envs.push(("REQUEST_LOGPROBS".into(), "1".into()));
    }

    // If huggingface_hub_cache is some, pass it to the shard
    // Useful when running inside a docker container
    if let Some(huggingface_hub_cache) = huggingface_hub_cache {
        envs.push(("HUGGINGFACE_HUB_CACHE".into(), huggingface_hub_cache.into()));
    };

    // If weights_cache_override is some, pass it to the shard
    // Useful when running inside a HuggingFace Inference Endpoint
    if let Some(weights_cache_override) = weights_cache_override {
        envs.push((
            "WEIGHTS_CACHE_OVERRIDE".into(),
            weights_cache_override.into(),
        ));
    };

    // Enable experimental support for cuda graphs
    if !cuda_graphs.is_empty() {
        envs.push((
            "CUDA_GRAPHS".into(),
            cuda_graphs
                .into_iter()
                .map(|c| c.to_string())
                .collect::<Vec<_>>()
                .join(",")
                .into(),
        ));
    }

    // If disable_custom_kernels is true, pass it to the shard as an env var
    if disable_custom_kernels {
        envs.push(("DISABLE_CUSTOM_KERNELS".into(), "True".into()))
    }

    // Watermark Gamma
    if let Some(watermark_gamma) = watermark_gamma {
        envs.push(("WATERMARK_GAMMA".into(), watermark_gamma.to_string().into()))
    }

    // Watermark Delta
    if let Some(watermark_delta) = watermark_delta {
        envs.push(("WATERMARK_DELTA".into(), watermark_delta.to_string().into()))
    }

    // Start process
    tracing::info!("Starting shard");
    let mut p = match Command::new("text-generation-server")
        .args(shard_args)
        .env_clear()
        .envs(envs)
        .stdin(Stdio::piped())
        .stdout(Stdio::piped())
        .stderr(Stdio::piped())
        .process_group(0)
        .spawn()
    {
        Ok(p) => p,
        Err(err) => {
            if err.kind() == io::ErrorKind::NotFound {
                tracing::error!("text-generation-server not found in PATH");
                tracing::error!("Please install it with `make install-server`")
            }
            {
                tracing::error!("{}", err);
            }

            status_sender.send(ShardStatus::Failed(rank)).unwrap();
            return;
        }
    };

    // Redirect STDOUT to the console
    let mut pstdin = p.stdin.take().unwrap();
    let shard_stdout_reader = BufReader::new(p.stdout.take().unwrap());
    let shard_stderr_reader = BufReader::new(p.stderr.take().unwrap());

    //stdout tracing thread
    thread::spawn(move || {
        log_lines(shard_stdout_reader);
    });
    // We read stderr in another thread as it seems that lines() can block in some cases
    let (err_sender, err_receiver) = mpsc::channel();
    thread::spawn(move || {
        for line in shard_stderr_reader.lines().map_while(Result::ok) {
            err_sender.send(line).unwrap_or(());
        }
    });
    // Forward our stdin to the shard's stdin in another thread (useful for interactive debugging)
    if LevelFilter::current() >= tracing::Level::DEBUG {
        thread::spawn(move || {
            let mut stdin = io::stdin(); // We get `Stdin` here.
            loop {
                let mut buffer = vec![0; 4096];
                if let Ok(n) = stdin.read(&mut buffer) {
                    if n > 0 {
                        let _ = pstdin.write_all(&buffer[..n]);
                    }
                }
            }
        });
    }

    let mut ready = false;
    let start_time = Instant::now();
    let mut wait_time = Instant::now();
    loop {
        // Process exited
        if let Some(exit_status) = p.try_wait().unwrap() {
            let mut err = String::new();
            while let Ok(line) = err_receiver.recv_timeout(Duration::from_millis(10)) {
                err = err + "\n" + &line;
            }

            tracing::error!("Shard complete standard error output:\n{err}");

            if let Some(signal) = exit_status.signal() {
                tracing::error!("Shard process was signaled to shutdown with signal {signal}");
            }

            status_sender.send(ShardStatus::Failed(rank)).unwrap();
            return;
        }

        // We received a shutdown signal
        if shutdown.load(Ordering::SeqCst) {
            terminate("shard", p, Duration::from_secs(90)).unwrap();
            return;
        }

        // Shard is ready
        if uds.exists() && !ready {
            tracing::info!("Shard ready in {:?}", start_time.elapsed());
            status_sender.send(ShardStatus::Ready).unwrap();
            ready = true;
        } else if !ready && wait_time.elapsed() > Duration::from_secs(10) {
            tracing::info!("Waiting for shard to be ready...");
            wait_time = Instant::now();
        }
        sleep(Duration::from_millis(100));
    }
}

fn shutdown_shards(shutdown: Arc<AtomicBool>, shutdown_receiver: &mpsc::Receiver<()>) {
    tracing::info!("Shutting down shards");
    // Update shutdown value to true
    // This will be picked up by the shard manager
    shutdown.store(true, Ordering::SeqCst);

    // Wait for shards to shutdown
    // This will block till all shutdown_sender are dropped
    let _ = shutdown_receiver.recv();
}

fn num_cuda_devices() -> Option<usize> {
    let devices = match env::var("CUDA_VISIBLE_DEVICES") {
        Ok(devices) => devices,
        Err(_) => match env::var("NVIDIA_VISIBLE_DEVICES") {
            Ok(devices) => devices,
            Err(_) => env::var("ZE_AFFINITY_MASK").ok()?,
        },
    };
    let n_devices = devices.split(',').count();
    Some(n_devices)
}
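
// Illustrative behaviour (hypothetical values): CUDA_VISIBLE_DEVICES="0,1,3"
// yields Some(3); if none of CUDA_VISIBLE_DEVICES, NVIDIA_VISIBLE_DEVICES or
// ZE_AFFINITY_MASK is set, the function returns None.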

#[derive(Deserialize)]
#[serde(rename_all = "UPPERCASE")]
enum PythonLogLevelEnum {
    Trace,
    Debug,
    Info,
    Success,
    Warning,
    Error,
    Critical,
}

#[derive(Deserialize)]
struct PythonLogLevel {
    name: PythonLogLevelEnum,
}

#[derive(Deserialize)]
struct PythonLogRecord {
    level: PythonLogLevel,
}

#[derive(Deserialize)]
struct PythonLogMessage {
    text: String,
    record: PythonLogRecord,
}

impl PythonLogMessage {
    fn trace(&self) {
        match self.record.level.name {
            PythonLogLevelEnum::Trace => tracing::trace!("{}", self.text.trim_end()),
            PythonLogLevelEnum::Debug => tracing::debug!("{}", self.text.trim_end()),
            PythonLogLevelEnum::Info => tracing::info!("{}", self.text.trim_end()),
            PythonLogLevelEnum::Success => tracing::info!("{}", self.text.trim_end()),
            PythonLogLevelEnum::Warning => tracing::warn!("{}", self.text.trim_end()),
            PythonLogLevelEnum::Error => tracing::error!("{}", self.text.trim_end()),
            PythonLogLevelEnum::Critical => tracing::error!("{}", self.text.trim_end()),
        }
    }
}

impl TryFrom<&[u8]> for PythonLogMessage {
    type Error = serde_json::Error;

    fn try_from(value: &[u8]) -> Result<Self, Self::Error> {
        serde_json::from_slice::<Self>(value)
    }
}
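
// A minimal sketch of the framing this expects: one JSON object per line with a
// `text` field and a nested `record.level.name`. The payload below is hypothetical.
#[cfg(test)]
mod python_log_parse_tests {
    use super::*;

    #[test]
    fn parses_minimal_json_record() {
        let raw = br#"{"text": "Server started", "record": {"level": {"name": "INFO"}}}"#;
        assert!(PythonLogMessage::try_from(&raw[..]).is_ok());
    }
}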

fn log_lines<R: Sized + Read>(mut bufread: BufReader<R>) {
    let mut buffer = vec![0u8; 8 * 4096];
    let mut stdout = std::io::stdout();
    loop {
        let n = bufread.read(&mut buffer);
        if let Ok(n) = n {
            if n > 0 {
                let mut lines = buffer[..n].split(|i| *i == b'\n').peekable();
                while let Some(line) = lines.next() {
                    match PythonLogMessage::try_from(line) {
                        Ok(log) => log.trace(),
                        // For interactive debugging ?
                        Err(_) => {
                            if LevelFilter::current() >= tracing::Level::DEBUG {
                                stdout.write_all(line).unwrap();
                                if lines.peek().is_some() {
                                    stdout.write_all(b"\n").unwrap();
                                }
                                stdout.flush().unwrap();
                            }
                        }
                    }
                }
            } else {
                break;
            }
        }
    }
}

fn find_num_shards(
    sharded: Option<bool>,
    num_shard: Option<usize>,
) -> Result<usize, LauncherError> {
    // get the number of shards given `sharded` and `num_shard`
    let num_shard = match (sharded, num_shard) {
        (Some(true), None) => {
            // try to default to the number of available GPUs
            tracing::info!("Parsing num_shard from CUDA_VISIBLE_DEVICES/NVIDIA_VISIBLE_DEVICES/ZE_AFFINITY_MASK");
            let n_devices = num_cuda_devices()
                .expect("--num-shard and CUDA_VISIBLE_DEVICES/NVIDIA_VISIBLE_DEVICES/ZE_AFFINITY_MASK are not set");
            if n_devices <= 1 {
                return Err(LauncherError::NotEnoughCUDADevices(format!(
                    "`sharded` is true but only found {n_devices} CUDA devices"
                )));
            }
            n_devices
        }
        (Some(true), Some(num_shard)) => {
            // we can't have only one shard while sharded
            if num_shard <= 1 {
                return Err(LauncherError::ArgumentValidation(
                    "`sharded` is true but `num_shard` <= 1".to_string(),
                ));
            }
            num_shard
        }
        (Some(false), Some(num_shard)) => num_shard,
        (Some(false), None) => 1,
        (None, None) => num_cuda_devices().unwrap_or(1),
        (None, Some(num_shard)) => num_shard,
    };
    if num_shard < 1 {
        return Err(LauncherError::ArgumentValidation(
            "`num_shard` cannot be < 1".to_string(),
        ));
    }
    Ok(num_shard)
}
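
// A minimal sketch of the precedence rules above (no GPUs or env vars involved;
// the shard counts are arbitrary illustration values).
#[cfg(test)]
mod num_shard_tests {
    use super::*;

    #[test]
    fn explicit_num_shard_wins_when_not_sharded() {
        assert_eq!(find_num_shards(Some(false), Some(4)).unwrap(), 4);
        assert_eq!(find_num_shards(Some(false), None).unwrap(), 1);
    }

    #[test]
    fn sharded_with_a_single_shard_is_rejected() {
        assert!(find_num_shards(Some(true), Some(1)).is_err());
    }
}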

#[derive(Debug, Error)]
enum LauncherError {
    #[error("Invalid argument: {0}")]
    ArgumentValidation(String),
    #[error("not enough cuda devices: {0}")]
    NotEnoughCUDADevices(String),
    #[error("Download error")]
    DownloadError,
    #[error("Shard cannot start")]
    ShardCannotStart,
    #[error("Shard disconnected")]
    ShardDisconnected,
    #[error("Shard failed")]
    ShardFailed,
    #[error("Webserver failed")]
    WebserverFailed,
    #[error("Webserver cannot start")]
    WebserverCannotStart,
}

fn download_convert_model(
    model_id: &str,
    revision: Option<&str>,
    trust_remote_code: bool,
    huggingface_hub_cache: Option<&str>,
    weights_cache_override: Option<&str>,
    running: Arc<AtomicBool>,
    merge_lora: bool,
) -> Result<(), LauncherError> {
    // Enter download tracing span
    let _span = tracing::span!(tracing::Level::INFO, "download").entered();

    let mut download_args = vec![
        "download-weights".to_string(),
        model_id.to_string(),
        "--extension".to_string(),
        ".safetensors".to_string(),
        "--logger-level".to_string(),
        "INFO".to_string(),
        "--json-output".to_string(),
    ];

    if merge_lora {
        download_args.push("--merge-lora".to_string());
    }

    // Model optional revision
    if let Some(revision) = &revision {
        download_args.push("--revision".to_string());
        download_args.push(revision.to_string())
    }

    // Trust remote code for automatic peft fusion
    if trust_remote_code {
        download_args.push("--trust-remote-code".to_string());
    }

    // Copy current process env
    let mut envs: Vec<(OsString, OsString)> = env::vars_os().collect();

    // Remove LOG_LEVEL if present
    envs.retain(|(name, _)| name != "LOG_LEVEL");

    // Disable progress bar
    envs.push(("HF_HUB_DISABLE_PROGRESS_BARS".into(), "1".into()));

    // If huggingface_hub_cache is set, pass it to the download process
    // Useful when running inside a docker container
    if let Some(ref huggingface_hub_cache) = huggingface_hub_cache {
        envs.push(("HUGGINGFACE_HUB_CACHE".into(), huggingface_hub_cache.into()));
    };

    // Enable hf transfer for insane download speeds
    let enable_hf_transfer = env::var("HF_HUB_ENABLE_HF_TRANSFER").unwrap_or("1".to_string());
    envs.push((
        "HF_HUB_ENABLE_HF_TRANSFER".into(),
        enable_hf_transfer.into(),
    ));

    // Parse Inference API token
    if let Ok(api_token) = env::var("HF_API_TOKEN") {
        envs.push(("HF_TOKEN".into(), api_token.into()))
    };

    // If args.weights_cache_override is some, pass it to the download process
    // Useful when running inside a HuggingFace Inference Endpoint
    if let Some(weights_cache_override) = &weights_cache_override {
        envs.push((
            "WEIGHTS_CACHE_OVERRIDE".into(),
            weights_cache_override.into(),
        ));
    };

    // Start process
    tracing::info!("Starting check and download process for {model_id}");
    let mut download_process = match Command::new("text-generation-server")
        .args(download_args)
        .env_clear()
        .envs(envs)
        .stdout(Stdio::piped())
        .stderr(Stdio::piped())
        .process_group(0)
        .spawn()
    {
        Ok(p) => p,
        Err(err) => {
            if err.kind() == io::ErrorKind::NotFound {
                tracing::error!("text-generation-server not found in PATH");
                tracing::error!("Please install it with `make install-server`")
            } else {
                tracing::error!("{}", err);
            }

            return Err(LauncherError::DownloadError);
        }
    };

    let download_stdout = BufReader::new(download_process.stdout.take().unwrap());

    thread::spawn(move || {
        log_lines(download_stdout);
    });

    let download_stderr = BufReader::new(download_process.stderr.take().unwrap());

    // We read stderr in another thread as it seems that lines() can block in some cases
    let (err_sender, err_receiver) = mpsc::channel();
    thread::spawn(move || {
        for line in download_stderr.lines().map_while(Result::ok) {
            err_sender.send(line).unwrap_or(());
        }
    });

    loop {
        if let Some(status) = download_process.try_wait().unwrap() {
            if status.success() {
                tracing::info!("Successfully downloaded weights for {model_id}");
                break;
            }

            let mut err = String::new();
            while let Ok(line) = err_receiver.recv_timeout(Duration::from_millis(10)) {
                err = err + "\n" + &line;
            }

1401
1402
1403
1404
1405
1406
1407
1408
1409
            if let Some(signal) = status.signal() {
                tracing::error!(
                    "Download process was signaled to shutdown with signal {signal}: {err}"
                );
            } else {
                tracing::error!("Download encountered an error: {err}");
            }

            return Err(LauncherError::DownloadError);
1410
        }
1411
        if !running.load(Ordering::SeqCst) {
OlivierDehaene's avatar
OlivierDehaene committed
1412
            terminate("download", download_process, Duration::from_secs(10)).unwrap();
1413
1414
1415
            return Ok(());
        }
        sleep(Duration::from_millis(100));
1416
    }
1417
1418
    Ok(())
}
1419

#[allow(clippy::too_many_arguments)]
fn spawn_shards(
    num_shard: usize,
    args: &Args,
    cuda_graphs: Vec<usize>,
    max_total_tokens: Option<usize>,
    max_input_tokens: Option<usize>,
    quantize: Option<Quantization>,
    max_log_level: LevelFilter,
    shutdown: Arc<AtomicBool>,
    shutdown_receiver: &mpsc::Receiver<()>,
    shutdown_sender: mpsc::Sender<()>,
    status_receiver: &mpsc::Receiver<ShardStatus>,
    status_sender: mpsc::Sender<ShardStatus>,
    running: Arc<AtomicBool>,
) -> Result<(), LauncherError> {
    // Start shard processes
    for rank in 0..num_shard {
        let model_id = args.model_id.clone();
        let revision = args.revision.clone();
        let uds_path = args.shard_uds_path.clone();
        let master_addr = args.master_addr.clone();
        let huggingface_hub_cache = args.huggingface_hub_cache.clone();
        let weights_cache_override = args.weights_cache_override.clone();
        let status_sender = status_sender.clone();
        let shutdown = shutdown.clone();
        let shutdown_sender = shutdown_sender.clone();
        let otlp_endpoint = args.otlp_endpoint.clone();
        let otlp_service_name = args.otlp_service_name.clone();
        let speculate = args.speculate;
        let dtype = args.dtype;
        let kv_cache_dtype = args.kv_cache_dtype;
        let trust_remote_code = args.trust_remote_code;
        let master_port = args.master_port;
        let disable_custom_kernels = args.disable_custom_kernels;
        let watermark_gamma = args.watermark_gamma;
        let watermark_delta = args.watermark_delta;
        let cuda_graphs_clone = cuda_graphs.clone();
        let cuda_memory_fraction = args.cuda_memory_fraction;
        let rope_scaling = args.rope_scaling;
        let rope_factor = args.rope_factor;
        let max_batch_size = args.max_batch_size;
        let lora_adapters = args.lora_adapters.clone();
        let enable_prefill_logprobs = args.enable_prefill_logprobs;
        thread::spawn(move || {
            shard_manager(
                model_id,
                revision,
                quantize,
                speculate,
                dtype,
                kv_cache_dtype,
                trust_remote_code,
                uds_path,
                rank,
                num_shard,
                master_addr,
                master_port,
                huggingface_hub_cache,
                weights_cache_override,
                disable_custom_kernels,
                watermark_gamma,
                watermark_delta,
                cuda_graphs_clone,
                cuda_memory_fraction,
                rope_scaling,
                rope_factor,
                max_total_tokens,
                max_batch_size,
                max_input_tokens,
                lora_adapters,
                enable_prefill_logprobs,
                otlp_endpoint,
                otlp_service_name,
                max_log_level,
                status_sender,
                shutdown,
                shutdown_sender,
            )
        });
    }
    drop(shutdown_sender);

    // Wait for all shards to start
    let mut shard_ready = 0;
    while running.load(Ordering::SeqCst) {
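        // Tally Ready notifications until every shard has reported in; abort early
        // if a shard fails to start or the status channel disconnects.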
        match status_receiver.try_recv() {
            Ok(ShardStatus::Ready) => {
                shard_ready += 1;
                if shard_ready == num_shard {
                    break;
                }
            }
            Err(TryRecvError::Empty) => {
                sleep(Duration::from_millis(100));
            }
            Ok(ShardStatus::Failed(rank)) => {
                tracing::error!("Shard {rank} failed to start");
                shutdown_shards(shutdown, shutdown_receiver);
                return Err(LauncherError::ShardCannotStart);
            }
            Err(TryRecvError::Disconnected) => {
                tracing::error!("Shard status channel disconnected");
                shutdown_shards(shutdown, shutdown_receiver);
                return Err(LauncherError::ShardDisconnected);
            }
        }
    }
    Ok(())
}

#[derive(Debug)]
struct ComputeType {
    count: usize,
    card: String,
}

impl ComputeType {
    fn f16_flop(&self) -> Option<u64> {
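        // Peak FP16 throughput in FLOP/s for the detected card, scaled by the
        // number of cards; unknown cards yield `None`.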
        let card_flop = match &self.card[..] {
            // https://www.nvidia.com/en-us/geforce/graphics-cards/40-series/rtx-4090/
            // Specs are unclear https://www.itcreations.com/nvidia-gpu/nvidia-geforce-rtx-4090-gpu
            "nvidia-4090" => Some(82 * 10u64.pow(12)),
            // https://www.nvidia.com/en-us/data-center/tesla-t4/
            "nvidia-t4" => Some(65 * 10u64.pow(12)),
            // https://www.nvidia.com/en-us/data-center/l4/
            "nvidia-l4" => Some(121 * 10u64.pow(12)),
            // https://www.nvidia.com/en-us/data-center/products/a10-gpu/
            "nvidia-a10g" => Some(125 * 10u64.pow(12)),
            // https://www.nvidia.com/en-us/data-center/h100/
            // https://www.techpowerup.com/gpu-specs/docs/nvidia-gh100-architecture.pdf
            "nvidia-h100-80gb-hbm3" => Some(900 * 10u64.pow(12)),
            // https://www.nvidia.com/content/dam/en-zz/Solutions/Data-Center/a100/pdf/nvidia-a100-datasheet-us-nvidia-1758950-r4-web.pdf
            "nvidia-a100-sxm4-80gb" => Some(312 * 10u64.pow(12)),
            "nvidia-a100" => Some(312 * 10u64.pow(12)),
            card => {
                tracing::warn!("Unknown compute for card {card}");
                None
            }
        };
        card_flop.map(|f| f * self.count as u64)
    }
}

impl From<ComputeType> for OsString {
    fn from(value: ComputeType) -> Self {
        format!("{}-{}", value.count, value.card).into()
    }
}

fn compute_type(num_shard: usize) -> Option<ComputeType> {
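    // Detect the GPU model from `nvidia-smi --query-gpu=gpu_name` output; returns
    // `None` when `nvidia-smi` is missing or its output cannot be parsed.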
    let output = Command::new("nvidia-smi")
        .args(["--query-gpu=gpu_name", "--format=csv"])
        .output()
        .ok()?;
    let output = String::from_utf8(output.stdout).ok()?;
    let fullname = output.split('\n').nth(1)?;
    let cardname = fullname.replace(' ', "-").to_lowercase();
    Some(ComputeType {
        count: num_shard,
        card: cardname,
    })
}

fn spawn_webserver(
    num_shard: usize,
    args: Args,
    max_input_tokens: Option<usize>,
    max_total_tokens: Option<usize>,
    max_batch_prefill_tokens: u32,
    shutdown: Arc<AtomicBool>,
    shutdown_receiver: &mpsc::Receiver<()>,
) -> Result<Child, LauncherError> {
    // All shards started
    // Start webserver
    tracing::info!("Starting Webserver");
    let mut router_args = vec![
        "--max-client-batch-size".to_string(),
        args.max_client_batch_size.to_string(),
        "--max-concurrent-requests".to_string(),
        args.max_concurrent_requests.to_string(),
        "--max-best-of".to_string(),
        args.max_best_of.to_string(),
        "--max-stop-sequences".to_string(),
        args.max_stop_sequences.to_string(),
        "--max-top-n-tokens".to_string(),
        args.max_top_n_tokens.to_string(),
        "--max-batch-prefill-tokens".to_string(),
        max_batch_prefill_tokens.to_string(),
        "--waiting-served-ratio".to_string(),
        args.waiting_served_ratio.to_string(),
        "--max-waiting-tokens".to_string(),
        args.max_waiting_tokens.to_string(),
        "--validation-workers".to_string(),
        args.validation_workers.to_string(),
        "--hostname".to_string(),
        args.hostname.to_string(),
        "--port".to_string(),
        args.port.to_string(),
        "--master-shard-uds-path".to_string(),
        format!("{}-0", args.shard_uds_path),
        "--tokenizer-name".to_string(),
        args.model_id,
        "--payload-limit".to_string(),
        args.payload_limit.to_string(),
    ];
    if let Some(max_input_tokens) = max_input_tokens {
        router_args.extend_from_slice(&[
            "--max-input-tokens".to_string(),
            max_input_tokens.to_string(),
        ]);
    }
    if let Some(max_total_tokens) = max_total_tokens {
        router_args.extend_from_slice(&[
            "--max-total-tokens".to_string(),
            max_total_tokens.to_string(),
        ]);
    }

    // Pass usage stats flags to router
    router_args.push("--usage-stats".to_string());
    router_args.push(args.usage_stats.to_string());

    // Grammar support
    if args.disable_grammar_support {
        router_args.push("--disable-grammar-support".to_string());
    }

    // Tokenizer config path
    if let Some(ref tokenizer_config_path) = args.tokenizer_config_path {
        router_args.push("--tokenizer-config-path".to_string());
        router_args.push(tokenizer_config_path.to_string());
    }

    // Model optional max batch total tokens
    if let Some(max_batch_total_tokens) = args.max_batch_total_tokens {
        router_args.push("--max-batch-total-tokens".to_string());
        router_args.push(max_batch_total_tokens.to_string());
    }

    // Router optional max batch size
    if let Some(max_batch_size) = args.max_batch_size {
        router_args.push("--max-batch-size".to_string());
        router_args.push(max_batch_size.to_string());
    }

    // Model optional revision
    if let Some(ref revision) = args.revision {
        router_args.push("--revision".to_string());
        router_args.push(revision.to_string())
    }

    if args.trust_remote_code {
        router_args.push("--trust-remote-code".to_string());
    }

    if args.json_output {
        router_args.push("--json-output".to_string());
    }

    // OpenTelemetry
    if let Some(otlp_endpoint) = args.otlp_endpoint {
        router_args.push("--otlp-endpoint".to_string());
        router_args.push(otlp_endpoint);
    }

    // OpenTelemetry
    let otlp_service_name = args.otlp_service_name;
    router_args.push("--otlp-service-name".to_string());
    router_args.push(otlp_service_name);

    // CORS origins
    for origin in args.cors_allow_origin.into_iter() {
        router_args.push("--cors-allow-origin".to_string());
        router_args.push(origin);
    }

    // API Key
    if let Some(api_key) = args.api_key {
        router_args.push("--api-key".to_string());
        router_args.push(api_key);
    }
    // Ngrok
    if args.ngrok {
        router_args.push("--ngrok".to_string());
        router_args.push("--ngrok-authtoken".to_string());
        router_args.push(args.ngrok_authtoken.unwrap());
        router_args.push("--ngrok-edge".to_string());
        router_args.push(args.ngrok_edge.unwrap());
    }

    // Copy current process env
    let mut envs: Vec<(OsString, OsString)> = env::vars_os().collect();

    // Parse Inference API token
    if let Ok(api_token) = env::var("HF_API_TOKEN") {
        envs.push(("HF_TOKEN".into(), api_token.into()))
    };

    // Parse Compute type
    if let Ok(compute_type) = env::var("COMPUTE_TYPE") {
        envs.push(("COMPUTE_TYPE".into(), compute_type.into()))
    } else if let Some(compute_type) = compute_type(num_shard) {
        envs.push(("COMPUTE_TYPE".into(), compute_type.into()))
    }

    let mut webserver = match Command::new("text-generation-router")
        .args(router_args)
        .envs(envs)
        .stdout(Stdio::piped())
        .stderr(Stdio::piped())
        .process_group(0)
        .spawn()
    {
        Ok(p) => p,
        Err(err) => {
            tracing::error!("Failed to start webserver: {}", err);
            if err.kind() == io::ErrorKind::NotFound {
                tracing::error!("text-generation-router not found in PATH");
                tracing::error!("Please install it with `make install-router`")
            } else {
                tracing::error!("{}", err);
            }

            shutdown_shards(shutdown, shutdown_receiver);
            return Err(LauncherError::WebserverCannotStart);
        }
    };

    // Redirect STDOUT and STDERR to the console
    let webserver_stdout = webserver.stdout.take().unwrap();
    let webserver_stderr = webserver.stderr.take().unwrap();

    thread::spawn(move || {
        let stdout = BufReader::new(webserver_stdout);
        let stderr = BufReader::new(webserver_stderr);
        for line in stdout.lines() {
            println!("{}", line.unwrap());
        }
        for line in stderr.lines() {
            println!("{}", line.unwrap());
        }
    });
    Ok(webserver)
}

fn terminate(process_name: &str, mut process: Child, timeout: Duration) -> io::Result<ExitStatus> {
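    // Ask the process to exit with SIGTERM, then poll until `timeout` elapses
    // before escalating to SIGKILL.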
    tracing::info!("Terminating {process_name}");

    let terminate_time = Instant::now();
    signal::kill(Pid::from_raw(process.id() as i32), Signal::SIGTERM).unwrap();

    tracing::info!("Waiting for {process_name} to gracefully shutdown");
    while terminate_time.elapsed() < timeout {
        if let Some(status) = process.try_wait()? {
            tracing::info!("{process_name} terminated");
            return Ok(status);
        }
        sleep(Duration::from_millis(100));
    }
    tracing::info!("Killing {process_name}");

    process.kill()?;
    let exit_status = process.wait()?;

    tracing::info!("{process_name} killed");
    Ok(exit_status)
}

fn main() -> Result<(), LauncherError> {
    // Pattern match configuration
    let args: Args = Args::parse();

    // Filter events with LOG_LEVEL
    let varname = "LOG_LEVEL";
    let env_filter = if let Ok(log_level) = std::env::var(varname) {
        // Override to avoid simple logs being spammed with tokio-level information
        let log_level = match &log_level[..] {
            "warn" => "text_generation_launcher=warn,text_generation_router=warn",
            "info" => "text_generation_launcher=info,text_generation_router=info",
            "debug" => "text_generation_launcher=debug,text_generation_router=debug",
            log_level => log_level,
        };
        EnvFilter::builder()
            .with_default_directive(LevelFilter::INFO.into())
            .parse_lossy(log_level)
    } else {
        EnvFilter::new("info")
    };
    let max_log_level = env_filter.max_level_hint().unwrap_or(LevelFilter::INFO);

    if args.json_output {
        tracing_subscriber::fmt()
            .with_env_filter(env_filter)
            .json()
            .init();
    } else {
        tracing_subscriber::fmt()
            .with_env_filter(env_filter)
            .compact()
            .init();
    }

    if args.env {
        let env_runtime = env_runtime::Env::new();
        tracing::info!("{}", env_runtime);
    }

    tracing::info!("{:#?}", args);

    let config: Option<Config> = get_config(&args.model_id, &args.revision).ok();
    let quantize = config.as_ref().and_then(|c| c.quantize);
    // Quantization usually means you're even more RAM constrained.

    let (prefix_caching, attention) = resolve_attention(&config, &args.lora_adapters);
    tracing::info!("Using attention {attention} - Prefix caching {prefix_caching}");
    std::env::set_var("PREFIX_CACHING", prefix_caching);
    std::env::set_var("ATTENTION", attention);

    let num_shard = find_num_shards(args.sharded, args.num_shard)?;
    if num_shard > 1 {
        if matches!(args.quantize, Some(Quantization::Exl2)) {
            return Err(LauncherError::ArgumentValidation(
                "Sharding is currently not supported with `exl2` quantization".into(),
            ));
        }
        tracing::info!("Sharding model on {num_shard} processes");
    }

    let max_input_tokens = {
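        // Reconcile the deprecated `max_input_length` alias with `max_input_tokens`.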
        match (args.max_input_tokens, args.max_input_length) {
            (Some(max_input_tokens), Some(max_input_length)) => {
                return Err(LauncherError::ArgumentValidation(
                    format!("Both `max_input_tokens` ({max_input_tokens}) and `max_input_length` ({max_input_length}) are set. Please define only `max_input_tokens` as `max_input_length` is deprecated for naming consistency.",
                )));
            }
            (Some(max_input_tokens), None) | (None, Some(max_input_tokens)) => {
                Some(max_input_tokens)
            }
            (None, None) => None,
        }
    };
    let max_total_tokens = args.max_total_tokens;
    let max_batch_prefill_tokens = {
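        // Use the user-provided value when set; otherwise derive a default from the
        // detected GPU's peak FP16 throughput, capped by `max_position_embeddings`.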
        match args.max_batch_prefill_tokens {
            Some(max_batch_prefill_tokens) => max_batch_prefill_tokens,
            None => {
                // TODO figure out hardware optimal value
                let compute_type = compute_type(num_shard);
                let compute_optimal = compute_optimal(config.as_ref(), compute_type.as_ref());
                let default = compute_optimal.unwrap_or(4096);
                let max_position_embeddings = config.and_then(|c| c.max_position_embeddings);
                let value = if let Some(max_position_embeddings) = max_position_embeddings {
                    default.min(max_position_embeddings)
                } else {
                    default
                };
                tracing::info!("Default `max_batch_prefill_tokens` to {value}");
                value as u32
            }
        }
    };

    // Validate args
    if let (Some(max_input_tokens), Some(max_total_tokens)) = (max_input_tokens, max_total_tokens) {
        if max_input_tokens >= max_total_tokens {
            return Err(LauncherError::ArgumentValidation(
                    format!("`max_input_tokens`({max_input_tokens}) must be < `max_total_tokens`({max_total_tokens})"),
                ));
        }
    }

    if matches!(args.quantize, Some(Quantization::Bitsandbytes)) {
        tracing::warn!("Bitsandbytes is deprecated, use `eetq` instead, which provides better latencies overall and is drop-in in most cases.");
    }
    let quantize = args.quantize.or(quantize);
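    // Resolve cuda graph sizes: honor user-provided values, disable cuda graphs for
    // quantization schemes known not to support them, and fall back to defaults otherwise.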
    let cuda_graphs = match (&args.cuda_graphs, &quantize) {
        (Some(cuda_graphs), _) => cuda_graphs.iter().cloned().filter(|&c| c > 0).collect(),
        #[allow(deprecated)]
        (
            None,
            Some(
                Quantization::Bitsandbytes
                | Quantization::BitsandbytesNf4
                | Quantization::BitsandbytesFp4,
            ),
        ) => {
            tracing::warn!("Bitsandbytes doesn't work with cuda graphs, deactivating them");
            vec![]
        }
        (None, Some(Quantization::Exl2)) => {
            tracing::warn!("Exl2 doesn't work with cuda graphs, deactivating them");
            vec![]
        }
        _ => {
            let cuda_graphs = vec![1, 2, 4, 8, 16, 32];
            tracing::info!("Using default cuda graphs {cuda_graphs:?}");
            cuda_graphs
        }
    };

    if args.validation_workers == 0 {
        return Err(LauncherError::ArgumentValidation(
            "`validation_workers` must be > 0".to_string(),
        ));
    }
    if args.trust_remote_code {
        tracing::warn!(
            "`trust_remote_code` is set. Trusting that model `{}` does not contain malicious code.",
            args.model_id
        );
    }

    if let Some(ref max_batch_total_tokens) = args.max_batch_total_tokens {
        if let Some(max_total_tokens) = max_total_tokens {
            if max_total_tokens as u32 > *max_batch_total_tokens {
                return Err(LauncherError::ArgumentValidation(format!(
                    "`max_total_tokens` must be <= `max_batch_total_tokens`. Given: {} and {}",
                    max_total_tokens, max_batch_total_tokens
                )));
            }
        }
    }

    if args.ngrok {
        if args.ngrok_authtoken.is_none() {
            return Err(LauncherError::ArgumentValidation(
                "`ngrok-authtoken` must be set when using ngrok tunneling".to_string(),
            ));
        }

        if args.ngrok_edge.is_none() {
            return Err(LauncherError::ArgumentValidation(
                "`ngrok-edge` must be set when using ngrok tunneling".to_string(),
            ));
        }
    }

    // Signal handler
    let running = Arc::new(AtomicBool::new(true));
    let r = running.clone();
    ctrlc::set_handler(move || {
        r.store(false, Ordering::SeqCst);
    })
    .expect("Error setting Ctrl-C handler");

    // Download and convert model weights
    download_convert_model(
        &args.model_id,
        args.revision.as_deref(),
        args.trust_remote_code,
        args.huggingface_hub_cache.as_deref(),
        args.weights_cache_override.as_deref(),
        running.clone(),
        true, // if it's only a lora model, we should merge the lora adapters
    )?;

    // Download and convert lora adapters if any
    if let Some(lora_adapters) = &args.lora_adapters {
        for adapter in lora_adapters.split(',') {
            // skip download if a path is provided
            if adapter.contains('=') {
                continue;
            }

            let adapter = adapter.trim();

            // check if adapter has more than 1 '@'
            if adapter.matches('@').count() > 1 {
                return Err(LauncherError::ArgumentValidation(format!(
                    "Invalid LoRA adapter format: {}",
                    adapter
                )));
            }

            // capture adapter_id, path, revision in format of adapter_id=path@revision
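            // e.g. "org/adapter", "org/adapter=/data/adapter", or "org/adapter@main" (illustrative ids)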
            let re = Regex::new(r"^([^=@]+)(?:=([^@]+))?(?:@(.+))?$").unwrap();
            if let Some(caps) = re.captures(adapter) {
                let adapter_id = caps.get(1).map_or("", |m| m.as_str());
                let revision = caps.get(3).map(|m| m.as_str());

                download_convert_model(
                    adapter_id,
                    revision,
                    args.trust_remote_code,
                    args.huggingface_hub_cache.as_deref(),
                    args.weights_cache_override.as_deref(),
                    running.clone(),
                    false, // avoid merging lora adapters if using multi-lora
                )?;
            } else {
                return Err(LauncherError::ArgumentValidation(format!(
                    "Invalid LoRA adapter format: {}",
                    adapter
                )));
            }
        }
    }

    if !running.load(Ordering::SeqCst) {
        // Launcher was asked to stop
        return Ok(());
    }

    // Shared shutdown bool
    let shutdown = Arc::new(AtomicBool::new(false));
    // Shared shutdown channel
    // When shutting down, the main thread will wait for all senders to be dropped
    let (shutdown_sender, shutdown_receiver) = mpsc::channel();

    // Shared channel to track shard status
    let (status_sender, status_receiver) = mpsc::channel();

    spawn_shards(
        num_shard,
        &args,
        cuda_graphs,
        max_total_tokens,
        max_input_tokens,
        quantize,
        max_log_level,
        shutdown.clone(),
        &shutdown_receiver,
        shutdown_sender,
        &status_receiver,
        status_sender,
        running.clone(),
    )?;

    // We might have received a termination signal
    if !running.load(Ordering::SeqCst) {
        shutdown_shards(shutdown, &shutdown_receiver);
        return Ok(());
    }

    let mut webserver = spawn_webserver(
        num_shard,
        args,
        max_input_tokens,
        max_total_tokens,
        max_batch_prefill_tokens,
        shutdown.clone(),
        &shutdown_receiver,
    )
    .inspect_err(|_| {
        shutdown_shards(shutdown.clone(), &shutdown_receiver);
    })?;

    // Default exit code
    let mut exit_code = Ok(());

    while running.load(Ordering::SeqCst) {
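        // Exit when a shard crashes, the webserver exits, or a termination signal flips `running`.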
        if let Ok(ShardStatus::Failed(rank)) = status_receiver.try_recv() {
            tracing::error!("Shard {rank} crashed");
            exit_code = Err(LauncherError::ShardFailed);
            break;
        };

        match webserver.try_wait().unwrap() {
            Some(_) => {
                tracing::error!("Webserver Crashed");
                shutdown_shards(shutdown, &shutdown_receiver);
                return Err(LauncherError::WebserverFailed);
            }
            None => {
                sleep(Duration::from_millis(100));
            }
        };
    }

    // Graceful termination
    terminate("webserver", webserver, Duration::from_secs(90)).unwrap();
    shutdown_shards(shutdown, &shutdown_receiver);

    exit_code
}