use clap::{Parser, ValueEnum};
use hf_hub::{
    api::sync::{Api, ApiBuilder},
    Repo, RepoType,
};
use nix::sys::signal::{self, Signal};
use nix::unistd::Pid;
use regex::Regex;
use serde::Deserialize;
use std::env;
use std::ffi::OsString;
use std::io::{BufRead, BufReader};
use std::os::unix::process::{CommandExt, ExitStatusExt};
use std::path::Path;
use std::process::{Child, Command, ExitStatus, Stdio};
use std::sync::atomic::{AtomicBool, Ordering};
use std::sync::mpsc::TryRecvError;
use std::sync::{mpsc, Arc};
use std::thread;
use std::thread::sleep;
use std::time::{Duration, Instant};
use std::{
    fs, io,
    io::{Read, Write},
};
use thiserror::Error;
use tracing_subscriber::{filter::LevelFilter, EnvFilter};

mod env_runtime;
mod gpu;

fn get_config(
    model_id: &str,
    revision: &Option<String>,
) -> Result<Config, Box<dyn std::error::Error>> {
    let mut path = std::path::Path::new(model_id).to_path_buf();
    let model_id = model_id.to_string();
    let filename = if !path.exists() {
        // Assume it's a hub id

        let api = if let Ok(token) = std::env::var("HF_TOKEN") {
            // The env variable takes precedence over the token stored on disk.
            ApiBuilder::new().with_token(Some(token)).build()?
        } else {
            Api::new()?
        };
        let repo = if let Some(ref revision) = revision {
            api.repo(Repo::with_revision(
                model_id,
                RepoType::Model,
                revision.to_string(),
            ))
        } else {
            api.model(model_id)
        };
        repo.get("config.json")?
    } else {
        path.push("config.json");
        path
    };

    let content = std::fs::read_to_string(filename)?;
    let config: RawConfig = serde_json::from_str(&content)?;

    let config: Config = config.into();
    Ok(config)
}

fn resolve_attention(config: &Option<Config>, lora_adapters: &Option<String>) -> (String, String) {
    let compute_capability = gpu::get_cuda_capability();
    let mut prefix_caching: Option<String> = std::env::var("PREFIX_CACHING").ok();
    let mut attention: Option<String> = std::env::var("ATTENTION").ok();
    if let Some(config) = config {
        if prefix_caching.is_none() {
            if config.vision_config.is_some() {
                tracing::info!("Disabling prefix caching because of VLM model");
                prefix_caching = Some("0".to_string());
            } else if config.is_encoder_decoder {
                tracing::info!("Disabling prefix caching because of seq2seq model");
                prefix_caching = Some("0".to_string());
            }
        }

        let fallback_attention = if matches!(compute_capability, Some((major, _)) if major < 8) {
            "paged"
        } else {
            "flashdecoding"
        };

        match config.head_dim {
            Some(h) if h == 64 || h == 128 || h == 256 => {
                if lora_adapters.is_some() && prefix_caching.is_none() {
                    tracing::info!("Disabling prefix caching because of lora adapters");
                    prefix_caching = Some("0".to_string());
                }
                match config.model_type.as_deref() {
                    Some("falcon") | Some("deepseek_v2") => {
                        // Required because gemma2 needs bfloat16 which is not supported by
                        // flashinfer ?
                        if attention.is_none() {
                            tracing::info!(
                                "Forcing attention to '{fallback_attention}' because model {} requires it",
                                config.model_type.as_ref().unwrap()
                            );
                            attention = Some(fallback_attention.to_string());
                        }
                        if fallback_attention == "paged" && prefix_caching.is_none() {
                            tracing::info!("Disabling prefix caching because it is not supported with 'paged' attention");
                            prefix_caching = Some("0".to_string());
                        }
                    }
                    Some("t5") => {}
                    _ => {}
                }
            }
            _ => {
                if attention.is_none() {
                    tracing::info!("Forcing attention to '{fallback_attention}' because head dim is not supported by flashinfer, also disabling prefix caching");
                    attention = Some(fallback_attention.to_string());
                }
                if prefix_caching.is_none() {
                    prefix_caching = Some("0".to_string());
                }
            }
        }
    }
    if attention == Some("paged".to_string()) && prefix_caching.is_none() {
        tracing::info!("Disabling prefix caching on paged attention");
        prefix_caching = Some("0".to_string());
    }

    let attention = attention.unwrap_or("flashinfer".to_string());
    let prefix_caching = prefix_caching.unwrap_or("true".to_string());

    (prefix_caching, attention)
}
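// Illustrative outcome of `resolve_attention` above (values are examples, not exhaustive):
// with no `PREFIX_CACHING`/`ATTENTION` environment overrides, a non-VLM decoder-only model
// with a supported `head_dim` (64/128/256), no lora adapters and most model types resolves
// to `("true", "flashinfer")`; unsupported head dims fall back to paged/flashdecoding
// attention with prefix caching disabled ("0").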

#[derive(Deserialize)]
struct RawConfig {
    max_position_embeddings: Option<usize>,
    n_positions: Option<usize>,
    model_type: Option<String>,
    max_seq_len: Option<usize>,
    quantization_config: Option<QuantizationConfig>,
    n_embd: Option<usize>,
    hidden_size: Option<usize>,
    num_attention_heads: Option<usize>,
    head_dim: Option<usize>,
    vision_config: Option<VisionConfig>,
    is_encoder_decoder: Option<bool>,
}

#[derive(Deserialize)]
struct QuantizationConfig {
    quant_method: Option<Quantization>,
}

#[derive(Deserialize)]
struct VisionConfig {}

#[derive(Deserialize)]
struct Config {
    max_position_embeddings: Option<usize>,
    quantize: Option<Quantization>,
    head_dim: Option<usize>,
    model_type: Option<String>,
    vision_config: Option<VisionConfig>,
    is_encoder_decoder: bool,
}

impl From<RawConfig> for Config {
    fn from(other: RawConfig) -> Self {
        let max_position_embeddings = other
            .max_position_embeddings
            .or(other.max_seq_len)
            .or(other.n_positions);
        let quantize = other.quantization_config.and_then(|q| q.quant_method);
        let head_dim = other.head_dim.or_else(|| {
            match (other.hidden_size, other.n_embd, other.num_attention_heads) {
                (Some(hidden_size), _, Some(num_attention_heads))
                    if hidden_size % num_attention_heads == 0 =>
                {
                    Some(hidden_size / num_attention_heads)
                }
                // Legacy
                (_, Some(hidden_size), Some(num_attention_heads))
                    if hidden_size % num_attention_heads == 0 =>
                {
                    Some(hidden_size / num_attention_heads)
                }
                _ => None,
            }
        });
        let model_type = other.model_type;
        let vision_config = other.vision_config;
        let is_encoder_decoder = other.is_encoder_decoder.unwrap_or(false);
        Config {
            max_position_embeddings,
            quantize,
            head_dim,
            model_type,
            vision_config,
            is_encoder_decoder,
        }
    }
}
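// A minimal test sketch (not part of the upstream file) illustrating the `head_dim`
// fallback above: when the config omits `head_dim`, it is derived from
// `hidden_size / num_attention_heads`.
#[cfg(test)]
mod raw_config_head_dim_tests {
    use super::*;

    #[test]
    fn head_dim_is_derived_from_hidden_size() {
        let raw: RawConfig =
            serde_json::from_str(r#"{"hidden_size": 4096, "num_attention_heads": 32}"#).unwrap();
        let config: Config = raw.into();
        // 4096 / 32 = 128
        assert_eq!(config.head_dim, Some(128));
    }
}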

#[derive(Clone, Copy, Debug, ValueEnum, Deserialize)]
#[serde(rename_all = "kebab-case")]
enum Quantization {
    /// 4 bit quantization. Requires a specific AWQ quantized model:
    ///   <https://hf.co/models?search=awq>.
    /// Should replace GPTQ models wherever possible because of the better latency
    Awq,
    /// 8 bit quantization, doesn't require specific model.
    /// Should be a drop-in replacement to bitsandbytes with much better performance.
    /// Kernels are from <https://github.com/NetEase-FuXi/EETQ.git>
    Eetq,
    /// Variable bit quantization. Requires a specific EXL2 quantized model:
    /// <https://hf.co/models?search=exl2>. Requires exllama2 kernels and does
    /// not support tensor parallelism (num_shard > 1).
    Exl2,
    /// 4 bit quantization. Requires a specific GPTQ quantized model: <https://hf.co/models?search=gptq>.
    /// text-generation-inference will use exllama (faster) kernels wherever possible, and use
    /// triton kernels (wider support) when it's not possible.
    /// AWQ has faster kernels.
    Gptq,
    /// 4 bit quantization. Requires a specific Marlin quantized model: <https://hf.co/models?search=marlin>.
    Marlin,
    /// Bitsandbytes 8bit. Can be applied on any model, will cut the memory requirement in half,
    /// but it is known that the model will be much slower to run than the native f16.
    // #[deprecated(
    //     since = "1.1.0",
    //     note = "Use `eetq` instead, which provides better latencies overall and is drop-in in most cases"
    // )]
    Bitsandbytes,
    /// Bitsandbytes 4bit. Can be applied on any model, will cut the memory requirement by 4x,
    /// but it is known that the model will be much slower to run than the native f16.
    BitsandbytesNf4,
    /// Bitsandbytes 4bit. nf4 should be preferred in most cases but maybe this one has better
    /// perplexity performance for your model
    BitsandbytesFp4,
    /// [FP8](https://developer.nvidia.com/blog/nvidia-arm-and-intel-publish-fp8-specification-for-standardization-as-an-interchange-format-for-ai/) (e4m3) works on H100 and above
    /// This dtype has native ops and should be the fastest if available.
    /// This is currently not the fastest because of local unpacking + padding to satisfy matrix
    /// multiplication limitations.
    Fp8,
}

impl std::fmt::Display for Quantization {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        // To keep in sync with `server`.
        match self {
            #[allow(deprecated)]
            // Use `eetq` instead, which provides better latencies overall and is drop-in in most cases
            Quantization::Bitsandbytes => {
                write!(f, "bitsandbytes")
            }
            Quantization::BitsandbytesNf4 => {
                write!(f, "bitsandbytes-nf4")
            }
            Quantization::BitsandbytesFp4 => {
                write!(f, "bitsandbytes-fp4")
            }
            Quantization::Exl2 => {
                write!(f, "exl2")
            }
            Quantization::Gptq => {
                write!(f, "gptq")
            }
            Quantization::Marlin => {
                write!(f, "marlin")
            }
            Quantization::Awq => {
                write!(f, "awq")
            }
            Quantization::Eetq => {
                write!(f, "eetq")
            }
            Quantization::Fp8 => {
                write!(f, "fp8")
            }
        }
    }
}
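// A small sketch (not part of the upstream file) checking that the `Display` output above
// matches the kebab-case names passed to the server via `--quantize`.
#[cfg(test)]
mod quantization_display_tests {
    use super::*;

    #[test]
    fn display_uses_server_facing_names() {
        assert_eq!(Quantization::BitsandbytesNf4.to_string(), "bitsandbytes-nf4");
        assert_eq!(Quantization::Gptq.to_string(), "gptq");
        assert_eq!(Quantization::Fp8.to_string(), "fp8");
    }
}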

#[derive(Clone, Copy, Debug, ValueEnum)]
enum Dtype {
    Float16,
    #[clap(name = "bfloat16")]
    BFloat16,
}

impl std::fmt::Display for Dtype {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        // To keep in sync with `server`.
        match self {
            Dtype::Float16 => {
                write!(f, "float16")
            }
            Dtype::BFloat16 => {
                write!(f, "bfloat16")
            }
        }
    }
}

#[derive(Clone, Copy, Debug, ValueEnum)]
enum KVCacheDtype {
    #[clap(name = "fp8_e5m2")]
    Fp8e5m2,
}

impl std::fmt::Display for KVCacheDtype {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        match self {
            KVCacheDtype::Fp8e5m2 => {
                write!(f, "fp8_e5m2")
            }
        }
    }
}

#[derive(Clone, Copy, Debug, ValueEnum)]
enum RopeScaling {
    Linear,
    Dynamic,
}

impl std::fmt::Display for RopeScaling {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        // To keep in sync with `server`.
        match self {
            RopeScaling::Linear => {
                write!(f, "linear")
            }
            RopeScaling::Dynamic => {
                write!(f, "dynamic")
            }
        }
    }
}

#[derive(Clone, Copy, Debug, ValueEnum)]
pub enum UsageStatsLevel {
    /// Default option, usage statistics are collected anonymously
    On,
    /// Disables all collection of usage statistics
    Off,
    /// Doesn't send the error stack trace or error type, but allows sending a crash event
    NoStack,
}

impl std::fmt::Display for UsageStatsLevel {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        // To keep in sync with `server`.
        match self {
            UsageStatsLevel::On => {
                write!(f, "on")
            }
            UsageStatsLevel::Off => {
                write!(f, "off")
            }
            UsageStatsLevel::NoStack => {
                write!(f, "no-stack")
            }
        }
    }
}

/// App Configuration
#[derive(Parser, Debug)]
#[clap(author, version, about, long_about = None)]
struct Args {
    /// The name of the model to load.
    /// Can be a MODEL_ID as listed on <https://hf.co/models> like
    /// `gpt2` or `OpenAssistant/oasst-sft-1-pythia-12b`.
    /// Or it can be a local directory containing the necessary files
    /// as saved by `save_pretrained(...)` methods of transformers
    #[clap(default_value = "bigscience/bloom-560m", long, env)]
    model_id: String,

    /// The actual revision of the model if you're referring to a model
    /// on the hub. You can use a specific commit id or a branch like `refs/pr/2`.
    #[clap(long, env)]
    revision: Option<String>,

    /// The number of tokenizer workers used for payload validation and truncation inside the
    /// router.
    #[clap(default_value = "2", long, env)]
    validation_workers: usize,

    /// Whether to shard the model across multiple GPUs
    /// By default text-generation-inference will use all available GPUs to run
    /// the model. Setting it to `false` deactivates `num_shard`.
    #[clap(long, env)]
    sharded: Option<bool>,

    /// The number of shards to use if you don't want to use all GPUs on a given machine.
    /// You can use `CUDA_VISIBLE_DEVICES=0,1 text-generation-launcher... --num_shard 2`
    /// and `CUDA_VISIBLE_DEVICES=2,3 text-generation-launcher... --num_shard 2` to
    /// launch 2 copies with 2 shards each on a given machine with 4 GPUs for instance.
    #[clap(long, env)]
    num_shard: Option<usize>,

    /// Quantization method to use for the model. It is not necessary to specify this option
    /// for pre-quantized models, since the quantization method is read from the model
    /// configuration.
    ///
    /// Marlin kernels will be used automatically for GPTQ/AWQ models.
    #[clap(long, env, value_enum)]
    quantize: Option<Quantization>,

    /// The number of input_ids to speculate on
    /// If using a medusa model, the heads will be picked up automatically
    /// Otherwise, it will use n-gram speculation which is relatively free
    /// in terms of compute, but the speedup heavily depends on the task.
    #[clap(long, env)]
    speculate: Option<usize>,

    /// The dtype to be forced upon the model. This option cannot be used with `--quantize`.
    #[clap(long, env, value_enum)]
    dtype: Option<Dtype>,

    /// Specify the dtype for the key-value cache. When this option is not provided,
    /// the dtype of the model is used (typically `float16` or `bfloat16`). Currently
    /// the only supported value is `fp8_e5m2` on CUDA.
    #[clap(long, env, value_enum)]
    kv_cache_dtype: Option<KVCacheDtype>,

    /// Whether you want to execute hub modelling code. Explicitly passing a `revision` is
    /// encouraged when loading a model with custom code to ensure no malicious code has been
    /// contributed in a newer revision.
    #[clap(long, env, value_enum)]
    trust_remote_code: bool,

    /// The maximum amount of concurrent requests for this particular deployment.
    /// Having a low limit will refuse client requests instead of having them
    /// wait for too long and is usually good to handle backpressure correctly.
    #[clap(default_value = "128", long, env)]
    max_concurrent_requests: usize,

    /// This is the maximum allowed value for clients to set `best_of`.
    /// Best of makes `n` generations at the same time, and returns the best
    /// in terms of overall log probability over the entire generated sequence
    #[clap(default_value = "2", long, env)]
    max_best_of: usize,

    /// This is the maximum allowed value for clients to set `stop_sequences`.
    /// Stop sequences are used to allow the model to stop on more than just
    /// the EOS token, and enable more complex "prompting" where users can preprompt
    /// the model in a specific way and define their "own" stop token aligned with
    /// their prompt.
    #[clap(default_value = "4", long, env)]
    max_stop_sequences: usize,

    /// This is the maximum allowed value for clients to set `top_n_tokens`.
    /// `top_n_tokens` is used to return information about the `n` most likely
    /// tokens at each generation step, instead of just the sampled token. This
    /// information can be used for downstream tasks like classification or
    /// ranking.
    #[clap(default_value = "5", long, env)]
    max_top_n_tokens: u32,

    /// This is the maximum allowed input length (expressed in number of tokens)
    /// for users. The larger this value, the longer prompt users can send which
    /// can impact the overall memory required to handle the load.
    /// Please note that some models have a finite range of sequences they can handle.
    /// Defaults to min(max_position_embeddings - 1, 4095)
    #[clap(long, env)]
    max_input_tokens: Option<usize>,

    /// Legacy version of [`Args::max_input_tokens`].
    #[clap(long, env)]
    max_input_length: Option<usize>,

    /// This is the most important value to set as it defines the "memory budget"
    /// of running client requests.
    /// Clients will send input sequences and ask to generate `max_new_tokens`
    /// on top. With a value of `1512`, users can send either a prompt of
    /// `1000` and ask for `512` new tokens, or send a prompt of `1` and ask for
    /// `1511` max_new_tokens.
    /// The larger this value, the larger amount each request will be in your RAM
    /// and the less effective batching can be.
    /// Defaults to min(max_position_embeddings, 4096)
    #[clap(long, env)]
    max_total_tokens: Option<usize>,

    /// This represents the ratio of waiting queries vs running queries where
    /// you want to start considering pausing the running queries to include the waiting
    /// ones into the same batch.
    /// `waiting_served_ratio=1.2` means that when 12 queries are waiting and there are
    /// only 10 queries left in the current batch, we check whether we can fit those 12
    /// waiting queries into the batching strategy, and if so, batching happens,
    /// delaying the 10 running queries by a `prefill` run.
    ///
    /// This setting is only applied if there is room in the batch
    /// as defined by `max_batch_total_tokens`.
    #[clap(default_value = "0.3", long, env)]
    waiting_served_ratio: f32,

    /// Limits the number of tokens for the prefill operation.
    /// Since this operation takes the most memory and is compute bound, it is interesting
    /// to limit the number of requests that can be sent.
    /// Defaults to `max_input_tokens + 50` to give a bit of room.
    #[clap(long, env)]
    max_batch_prefill_tokens: Option<u32>,

    /// **IMPORTANT** This is one critical control to allow maximum usage
    /// of the available hardware.
    ///
    /// This represents the total amount of potential tokens within a batch.
    /// When using padding (not recommended) this would be equivalent to
    /// `batch_size` * `max_total_tokens`.
    ///
    /// However in the non-padded (flash attention) version this can be much finer.
    ///
    /// For `max_batch_total_tokens=1000`, you could fit `10` queries of `total_tokens=100`
    /// or a single query of `1000` tokens.
    ///
    /// Overall this number should be the largest possible amount that fits the
    /// remaining memory (after the model is loaded). Since the actual memory overhead
    /// depends on other parameters like if you're using quantization, flash attention
    /// or the model implementation, text-generation-inference cannot infer this number
    /// automatically.
    #[clap(long, env)]
    max_batch_total_tokens: Option<u32>,

    /// This setting defines how many tokens can be passed before forcing the waiting
    /// queries to be put on the batch (if the size of the batch allows for it).
    /// New queries require 1 `prefill` forward, which is different from `decode`
    /// and therefore you need to pause the running batch in order to run `prefill`
    /// to create the correct values for the waiting queries to be able to join the batch.
    ///
    /// With a value too small, queries will always "steal" the compute to run `prefill`
    /// and running queries will be delayed by a lot.
    ///
    /// With a value too big, waiting queries could wait for a very long time
    /// before being allowed a slot in the running batch. If your server is busy
    /// that means that requests that could run in ~2s on an empty server could
    /// end up running in ~20s because the query had to wait for 18s.
    ///
    /// This number is expressed in number of tokens to make it a bit more
    /// "model" agnostic, but what should really matter is the overall latency
    /// for end users.
    #[clap(default_value = "20", long, env)]
    max_waiting_tokens: usize,

    /// Enforce a maximum number of requests per batch
    /// Specific flag for hardware targets that do not support unpadded inference
    #[clap(long, env)]
    max_batch_size: Option<usize>,

    /// Specify the batch sizes to compute cuda graphs for.
    /// Use "0" to disable.
    /// Default = "1,2,4,8,16,32"
    #[clap(long, env, value_delimiter = ',')]
    cuda_graphs: Option<Vec<usize>>,

    /// The IP address to listen on
    #[clap(default_value = "0.0.0.0", long, env)]
    hostname: String,

    /// The port to listen on.
    #[clap(default_value = "3000", long, short, env)]
    port: u16,

    /// The name of the socket for gRPC communication between the webserver
    /// and the shards.
    #[clap(default_value = "/tmp/text-generation-server", long, env)]
    shard_uds_path: String,

    /// The address the master shard will listen on. (setting used by torch distributed)
    #[clap(default_value = "localhost", long, env)]
    master_addr: String,

    /// The port the master shard will listen on. (setting used by torch distributed)
    #[clap(default_value = "29500", long, env)]
    master_port: usize,

    /// The location of the huggingface hub cache.
    /// Used to override the location if you want to provide a mounted disk for instance
    #[clap(long, env)]
    huggingface_hub_cache: Option<String>,

    /// The location of the huggingface hub cache.
    /// Used to override the location if you want to provide a mounted disk for instance
    #[clap(long, env)]
    weights_cache_override: Option<String>,

    /// For some models (like bloom), text-generation-inference implemented custom
    /// cuda kernels to speed up inference. Those kernels were only tested on A100.
    /// Use this flag to disable them if you're running on different hardware and
    /// encounter issues.
    #[clap(long, env)]
    disable_custom_kernels: bool,

    /// Limit the CUDA available memory.
    /// The allowed value equals the total visible memory multiplied by cuda-memory-fraction.
    #[clap(default_value = "1.0", long, env)]
    cuda_memory_fraction: f32,

    /// Rope scaling will only be used for RoPE models
    /// and allows rescaling the position rotary to accommodate
    /// larger prompts.
    ///
    /// Goes together with `rope_factor`.
    ///
    /// `--rope-factor 2.0` gives linear scaling with a factor of 2.0
    /// `--rope-scaling dynamic` gives dynamic scaling with a factor of 1.0
    /// `--rope-scaling linear` gives linear scaling with a factor of 1.0 (Nothing will be changed
    /// basically)
    ///
    /// `--rope-scaling linear --rope-factor` fully describes the scaling you want
    #[clap(long, env)]
    rope_scaling: Option<RopeScaling>,

    /// Rope scaling will only be used for RoPE models
    /// See `rope_scaling`
    #[clap(long, env)]
    rope_factor: Option<f32>,

    /// Outputs the logs in JSON format (useful for telemetry)
    #[clap(long, env)]
    json_output: bool,

    #[clap(long, env)]
    otlp_endpoint: Option<String>,

    #[clap(default_value = "text-generation-inference.router", long, env)]
    otlp_service_name: String,

    #[clap(long, env)]
    cors_allow_origin: Vec<String>,

    #[clap(long, env)]
    api_key: Option<String>,

    #[clap(long, env)]
    watermark_gamma: Option<f32>,
    #[clap(long, env)]
    watermark_delta: Option<f32>,

    /// Enable ngrok tunneling
    #[clap(long, env)]
    ngrok: bool,

    /// ngrok authentication token
    #[clap(long, env)]
    ngrok_authtoken: Option<String>,

    /// ngrok edge
    #[clap(long, env)]
    ngrok_edge: Option<String>,

    /// The path to the tokenizer config file. This path is used to load the tokenizer configuration which may
    /// include a `chat_template`. If not provided, the default config will be used from the model hub.
    #[clap(long, env)]
    tokenizer_config_path: Option<String>,

    /// Disable outlines grammar constrained generation.
    /// This is a feature that allows you to generate text that follows a specific grammar.
    #[clap(long, env)]
    disable_grammar_support: bool,

    /// Display a lot of information about your runtime environment
    #[clap(long, short, action)]
    env: bool,

    /// Control the maximum number of inputs that a client can send in a single request
    #[clap(default_value = "4", long, env)]
    max_client_batch_size: usize,

    /// Lora Adapters: a comma-separated list of adapter ids, i.e. `repo/adapter1,repo/adapter2`, to load during
    /// startup that will be available to callers via the `adapter_id` field in a request.
    #[clap(long, env)]
    lora_adapters: Option<String>,

    /// Control if anonymous usage stats are collected.
    /// Options are "on", "off" and "no-stack"
    /// Default is on.
    #[clap(default_value = "on", long, env)]
    usage_stats: UsageStatsLevel,
}
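// Illustrative launcher invocation (hypothetical values; the long flags are derived by clap
// from the field names above):
//
//   text-generation-launcher \
//     --model-id bigscience/bloom-560m \
//     --num-shard 2 \
//     --quantize bitsandbytes-nf4 \
//     --max-input-tokens 1024 \
//     --max-total-tokens 2048 \
//     --port 8080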

#[derive(Debug)]
enum ShardStatus {
    Ready,
    Failed(usize),
}

#[allow(clippy::too_many_arguments)]
fn shard_manager(
    model_id: String,
    revision: Option<String>,
    quantize: Option<Quantization>,
    speculate: Option<usize>,
    dtype: Option<Dtype>,
    kv_cache_dtype: Option<KVCacheDtype>,
    trust_remote_code: bool,
    uds_path: String,
    rank: usize,
    world_size: usize,
    master_addr: String,
    master_port: usize,
    huggingface_hub_cache: Option<String>,
    weights_cache_override: Option<String>,
    disable_custom_kernels: bool,
    watermark_gamma: Option<f32>,
    watermark_delta: Option<f32>,
    cuda_graphs: Vec<usize>,
    cuda_memory_fraction: f32,
    rope_scaling: Option<RopeScaling>,
    rope_factor: Option<f32>,
    max_total_tokens: usize,
    max_batch_size: Option<usize>,
    max_input_tokens: usize,
    lora_adapters: Option<String>,
    otlp_endpoint: Option<String>,
    otlp_service_name: String,
    log_level: LevelFilter,
    status_sender: mpsc::Sender<ShardStatus>,
    shutdown: Arc<AtomicBool>,
    _shutdown_sender: mpsc::Sender<()>,
) {
    // Enter shard-manager tracing span
    let _span = tracing::span!(tracing::Level::INFO, "shard-manager", rank = rank).entered();

    // Get UDS path
    let uds_string = format!("{uds_path}-{rank}");
    let uds = Path::new(&uds_string);
    // Clean previous runs
    if uds.exists() {
        fs::remove_file(uds).unwrap();
    }

    // Process args
    let mut shard_args = vec![
        "serve".to_string(),
        model_id,
        "--uds-path".to_string(),
        uds_path,
        "--logger-level".to_string(),
        log_level.to_string().to_uppercase(),
        "--json-output".to_string(),
    ];

    // Activate trust remote code
    if trust_remote_code {
        shard_args.push("--trust-remote-code".to_string());
    }

    // Activate tensor parallelism
    if world_size > 1 {
        shard_args.push("--sharded".to_string());
    }

    if let Some(quantize) = quantize {
        shard_args.push("--quantize".to_string());
        shard_args.push(quantize.to_string())
    }

    if let Some(speculate) = speculate {
        shard_args.push("--speculate".to_string());
        shard_args.push(speculate.to_string())
    }

    if let Some(dtype) = dtype {
        shard_args.push("--dtype".to_string());
        shard_args.push(dtype.to_string())
    }

    if let Some(kv_cache_dtype) = kv_cache_dtype {
        shard_args.push("--kv-cache-dtype".to_string());
        shard_args.push(kv_cache_dtype.to_string())
    }

    // Model optional revision
    if let Some(revision) = revision {
        shard_args.push("--revision".to_string());
        shard_args.push(revision)
    }

    let rope = match (rope_scaling, rope_factor) {
        (None, None) => None,
        (Some(scaling), None) => Some((scaling, 1.0)),
        (Some(scaling), Some(factor)) => Some((scaling, factor)),
        (None, Some(factor)) => Some((RopeScaling::Linear, factor)),
    };

    // OpenTelemetry Endpoint
    if let Some(otlp_endpoint) = otlp_endpoint {
        shard_args.push("--otlp-endpoint".to_string());
        shard_args.push(otlp_endpoint);
    }

    // OpenTelemetry Service Name
    shard_args.push("--otlp-service-name".to_string());
    shard_args.push(otlp_service_name);

    // In case we use sliding window, we may ignore the sliding in flash for some backends depending on the parameter.
    shard_args.push("--max-input-tokens".to_string());
    shard_args.push(max_input_tokens.to_string());

    // Copy current process env
    let mut envs: Vec<(OsString, OsString)> = env::vars_os().collect();

    // Remove LOG_LEVEL if present
    envs.retain(|(name, _)| name != "LOG_LEVEL");

    // Torch Distributed Env vars
    envs.push(("RANK".into(), rank.to_string().into()));
    envs.push(("WORLD_SIZE".into(), world_size.to_string().into()));
    envs.push(("MASTER_ADDR".into(), master_addr.into()));
    envs.push(("MASTER_PORT".into(), master_port.to_string().into()));
    envs.push(("TORCH_NCCL_AVOID_RECORD_STREAMS".into(), "1".into()));

    // CUDA memory fraction
    envs.push((
        "CUDA_MEMORY_FRACTION".into(),
        cuda_memory_fraction.to_string().into(),
    ));

    // Safetensors load fast
    envs.push(("SAFETENSORS_FAST_GPU".into(), "1".into()));

    // Disable progress bar
    envs.push(("HF_HUB_DISABLE_PROGRESS_BARS".into(), "1".into()));

    // Enable hf transfer for insane download speeds
    let enable_hf_transfer = env::var("HF_HUB_ENABLE_HF_TRANSFER").unwrap_or("1".to_string());
    envs.push((
        "HF_HUB_ENABLE_HF_TRANSFER".into(),
        enable_hf_transfer.into(),
    ));

    // Parse Inference API token
    if let Ok(api_token) = env::var("HF_API_TOKEN") {
        envs.push(("HF_TOKEN".into(), api_token.into()))
    };

    // Detect rope scaling
    // Sending as env instead of CLI args to not bloat everything
    // these can only be used by RoPE models, so passing the information around
    // for all models would complicate the code unnecessarily
    if let Some((scaling, factor)) = rope {
        envs.push(("ROPE_SCALING".into(), scaling.to_string().into()));
        envs.push(("ROPE_FACTOR".into(), factor.to_string().into()));
    }

    envs.push((
        "MAX_TOTAL_TOKENS".into(),
        max_total_tokens.to_string().into(),
    ));
    if let Some(max_batch_size) = max_batch_size {
        envs.push(("MAX_BATCH_SIZE".into(), max_batch_size.to_string().into()));
    }

    // Lora Adapters
    if let Some(lora_adapters) = lora_adapters {
        envs.push(("LORA_ADAPTERS".into(), lora_adapters.into()));
    }

    // If huggingface_hub_cache is some, pass it to the shard
    // Useful when running inside a docker container
    if let Some(huggingface_hub_cache) = huggingface_hub_cache {
        envs.push(("HUGGINGFACE_HUB_CACHE".into(), huggingface_hub_cache.into()));
    };

    // If weights_cache_override is some, pass it to the shard
    // Useful when running inside a HuggingFace Inference Endpoint
    if let Some(weights_cache_override) = weights_cache_override {
        envs.push((
            "WEIGHTS_CACHE_OVERRIDE".into(),
            weights_cache_override.into(),
        ));
    };

    // Enable experimental support for cuda graphs
    if !cuda_graphs.is_empty() {
        envs.push((
            "CUDA_GRAPHS".into(),
            cuda_graphs
                .into_iter()
                .map(|c| c.to_string())
                .collect::<Vec<_>>()
                .join(",")
                .into(),
        ));
    }

    // If disable_custom_kernels is true, pass it to the shard as an env var
    if disable_custom_kernels {
        envs.push(("DISABLE_CUSTOM_KERNELS".into(), "True".into()))
    }

    // Watermark Gamma
    if let Some(watermark_gamma) = watermark_gamma {
        envs.push(("WATERMARK_GAMMA".into(), watermark_gamma.to_string().into()))
    }

    // Watermark Delta
    if let Some(watermark_delta) = watermark_delta {
        envs.push(("WATERMARK_DELTA".into(), watermark_delta.to_string().into()))
    }

    // Start process
    tracing::info!("Starting shard");
    let mut p = match Command::new("text-generation-server")
        .args(shard_args)
        .env_clear()
        .envs(envs)
        .stdin(Stdio::piped())
        .stdout(Stdio::piped())
        .stderr(Stdio::piped())
        .process_group(0)
        .spawn()
    {
        Ok(p) => p,
        Err(err) => {
            if err.kind() == io::ErrorKind::NotFound {
                tracing::error!("text-generation-server not found in PATH");
                tracing::error!("Please install it with `make install-server`")
            }
            {
                tracing::error!("{}", err);
            }

            status_sender.send(ShardStatus::Failed(rank)).unwrap();
            return;
        }
    };

    // Redirect STDOUT to the console
    let mut pstdin = p.stdin.take().unwrap();
    let shard_stdout_reader = BufReader::new(p.stdout.take().unwrap());
    let shard_stderr_reader = BufReader::new(p.stderr.take().unwrap());

    // stdout tracing thread
    thread::spawn(move || {
        log_lines(shard_stdout_reader);
    });
    // We read stderr in another thread as it seems that lines() can block in some cases
    let (err_sender, err_receiver) = mpsc::channel();
    thread::spawn(move || {
        for line in shard_stderr_reader.lines().map_while(Result::ok) {
            err_sender.send(line).unwrap_or(());
        }
    });
    // Forward this process' stdin to the shard (useful for interactive debugging)
    if LevelFilter::current() >= tracing::Level::DEBUG {
        thread::spawn(move || {
            let mut stdin = io::stdin(); // We get `Stdin` here.
            loop {
                let mut buffer = vec![0; 4096];
                if let Ok(n) = stdin.read(&mut buffer) {
                    if n > 0 {
                        let _ = pstdin.write_all(&buffer[..n]);
                    }
                }
            }
        });
    }

    let mut ready = false;
    let start_time = Instant::now();
    let mut wait_time = Instant::now();
    loop {
        // Process exited
        if let Some(exit_status) = p.try_wait().unwrap() {
            let mut err = String::new();
            while let Ok(line) = err_receiver.recv_timeout(Duration::from_millis(10)) {
                err = err + "\n" + &line;
            }

            tracing::error!("Shard complete standard error output:\n{err}");

            if let Some(signal) = exit_status.signal() {
                tracing::error!("Shard process was signaled to shutdown with signal {signal}");
            }

            status_sender.send(ShardStatus::Failed(rank)).unwrap();
            return;
        }

        // We received a shutdown signal
        if shutdown.load(Ordering::SeqCst) {
            terminate("shard", p, Duration::from_secs(90)).unwrap();
            return;
        }

        // Shard is ready
        if uds.exists() && !ready {
            tracing::info!("Shard ready in {:?}", start_time.elapsed());
            status_sender.send(ShardStatus::Ready).unwrap();
            ready = true;
        } else if !ready && wait_time.elapsed() > Duration::from_secs(10) {
            tracing::info!("Waiting for shard to be ready...");
            wait_time = Instant::now();
        }
        sleep(Duration::from_millis(100));
    }
}

fn shutdown_shards(shutdown: Arc<AtomicBool>, shutdown_receiver: &mpsc::Receiver<()>) {
    tracing::info!("Shutting down shards");
    // Update shutdown value to true
    // This will be picked up by the shard manager
    shutdown.store(true, Ordering::SeqCst);

    // Wait for shards to shutdown
    // This will block till all shutdown_sender are dropped
    let _ = shutdown_receiver.recv();
}

fn num_cuda_devices() -> Option<usize> {
    let devices = match env::var("CUDA_VISIBLE_DEVICES") {
        Ok(devices) => devices,
        Err(_) => match env::var("NVIDIA_VISIBLE_DEVICES") {
            Ok(devices) => devices,
            Err(_) => env::var("ZE_AFFINITY_MASK").ok()?,
        },
    };
    let n_devices = devices.split(',').count();
    Some(n_devices)
}
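// Example of the device detection above: with `CUDA_VISIBLE_DEVICES=0,1,2` the function
// returns `Some(3)`; when none of the visibility variables are set it returns `None`.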

#[derive(Deserialize)]
#[serde(rename_all = "UPPERCASE")]
enum PythonLogLevelEnum {
    Trace,
    Debug,
    Info,
    Success,
    Warning,
    Error,
    Critical,
}

#[derive(Deserialize)]
struct PythonLogLevel {
    name: PythonLogLevelEnum,
}

#[derive(Deserialize)]
struct PythonLogRecord {
    level: PythonLogLevel,
}

#[derive(Deserialize)]
struct PythonLogMessage {
    text: String,
    record: PythonLogRecord,
}

impl PythonLogMessage {
    fn trace(&self) {
        match self.record.level.name {
            PythonLogLevelEnum::Trace => tracing::trace!("{}", self.text.trim_end()),
            PythonLogLevelEnum::Debug => tracing::debug!("{}", self.text.trim_end()),
            PythonLogLevelEnum::Info => tracing::info!("{}", self.text.trim_end()),
            PythonLogLevelEnum::Success => tracing::info!("{}", self.text.trim_end()),
            PythonLogLevelEnum::Warning => tracing::warn!("{}", self.text.trim_end()),
            PythonLogLevelEnum::Error => tracing::error!("{}", self.text.trim_end()),
            PythonLogLevelEnum::Critical => tracing::error!("{}", self.text.trim_end()),
        }
    }
}

impl TryFrom<&[u8]> for PythonLogMessage {
    type Error = serde_json::Error;

    fn try_from(value: &[u8]) -> Result<Self, Self::Error> {
        serde_json::from_slice::<Self>(value)
    }
}
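// A minimal sketch (not part of the upstream file) of how a loguru-style JSON record
// emitted by the Python server is parsed by the `TryFrom` impl above.
#[cfg(test)]
mod python_log_message_tests {
    use super::*;

    #[test]
    fn parses_a_json_log_record() {
        let line: &[u8] = br#"{"text": "Server started\n", "record": {"level": {"name": "INFO"}}}"#;
        let msg = PythonLogMessage::try_from(line).unwrap();
        assert_eq!(msg.text, "Server started\n");
        assert!(matches!(msg.record.level.name, PythonLogLevelEnum::Info));
    }
}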

fn log_lines<R: Sized + Read>(mut bufread: BufReader<R>) {
    let mut buffer = vec![0u8; 8 * 4096];
    let mut stdout = std::io::stdout();
    loop {
        let n = bufread.read(&mut buffer);
        if let Ok(n) = n {
            if n > 0 {
                let mut lines = buffer[..n].split(|i| *i == b'\n').peekable();
                while let Some(line) = lines.next() {
                    match PythonLogMessage::try_from(line) {
                        Ok(log) => log.trace(),
                        // For interactive debugging ?
                        Err(_) => {
                            if LevelFilter::current() >= tracing::Level::DEBUG {
                                stdout.write_all(line).unwrap();
                                if lines.peek().is_some() {
                                    stdout.write_all(b"\n").unwrap();
                                }
                                stdout.flush().unwrap();
                            }
                        }
                    }
                }
            }
        }
    }
}

fn find_num_shards(
    sharded: Option<bool>,
    num_shard: Option<usize>,
) -> Result<usize, LauncherError> {
    // get the number of shards given `sharded` and `num_shard`
    let num_shard = match (sharded, num_shard) {
        (Some(true), None) => {
            // try to default to the number of available GPUs
            tracing::info!("Parsing num_shard from CUDA_VISIBLE_DEVICES/NVIDIA_VISIBLE_DEVICES/ZE_AFFINITY_MASK");
            let n_devices = num_cuda_devices()
                .expect("--num-shard and CUDA_VISIBLE_DEVICES/NVIDIA_VISIBLE_DEVICES/ZE_AFFINITY_MASK are not set");
            if n_devices <= 1 {
                return Err(LauncherError::NotEnoughCUDADevices(format!(
                    "`sharded` is true but only found {n_devices} CUDA devices"
                )));
            }
            n_devices
        }
        (Some(true), Some(num_shard)) => {
            // we can't have only one shard while sharded
            if num_shard <= 1 {
                return Err(LauncherError::ArgumentValidation(
                    "`sharded` is true but `num_shard` <= 1".to_string(),
                ));
            }
            num_shard
        }
        (Some(false), Some(num_shard)) => num_shard,
        (Some(false), None) => 1,
        (None, None) => num_cuda_devices().unwrap_or(1),
        (None, Some(num_shard)) => num_shard,
    };
    if num_shard < 1 {
        return Err(LauncherError::ArgumentValidation(
            "`num_shard` cannot be < 1".to_string(),
        ));
    }
    Ok(num_shard)
}
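// A small sketch (not part of the upstream file) of the shard-count resolution above,
// exercising only the branches that do not inspect the GPU visibility env variables.
#[cfg(test)]
mod find_num_shards_tests {
    use super::*;

    #[test]
    fn resolves_shard_count_from_flags() {
        // An explicit count is kept when sharding is enabled.
        assert_eq!(find_num_shards(Some(true), Some(2)).unwrap(), 2);
        // Not sharded and no count defaults to a single shard.
        assert_eq!(find_num_shards(Some(false), None).unwrap(), 1);
        // Sharding with a single shard is rejected.
        assert!(find_num_shards(Some(true), Some(1)).is_err());
    }
}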

#[derive(Debug, Error)]
enum LauncherError {
    #[error("Invalid argument: {0}")]
    ArgumentValidation(String),
    #[error("not enough cuda devices: {0}")]
    NotEnoughCUDADevices(String),
    #[error("Download error")]
    DownloadError,
    #[error("Shard cannot start")]
    ShardCannotStart,
    #[error("Shard disconnected")]
    ShardDisconnected,
    #[error("Shard failed")]
    ShardFailed,
    #[error("Webserver failed")]
    WebserverFailed,
    #[error("Webserver cannot start")]
    WebserverCannotStart,
}

fn download_convert_model(
    model_id: &str,
    revision: Option<&str>,
    trust_remote_code: bool,
    huggingface_hub_cache: Option<&str>,
    weights_cache_override: Option<&str>,
    running: Arc<AtomicBool>,
) -> Result<(), LauncherError> {
    // Enter download tracing span
    let _span = tracing::span!(tracing::Level::INFO, "download").entered();

    let mut download_args = vec![
        "download-weights".to_string(),
        model_id.to_string(),
        "--extension".to_string(),
        ".safetensors".to_string(),
        "--logger-level".to_string(),
        "INFO".to_string(),
        "--json-output".to_string(),
    ];

    // Model optional revision
    if let Some(revision) = &revision {
        download_args.push("--revision".to_string());
        download_args.push(revision.to_string())
    }

    // Trust remote code for automatic peft fusion
    if trust_remote_code {
        download_args.push("--trust-remote-code".to_string());
    }

    // Copy current process env
    let mut envs: Vec<(OsString, OsString)> = env::vars_os().collect();

    // Remove LOG_LEVEL if present
    envs.retain(|(name, _)| name != "LOG_LEVEL");

    // Disable progress bar
    envs.push(("HF_HUB_DISABLE_PROGRESS_BARS".into(), "1".into()));

    // If huggingface_hub_cache is set, pass it to the download process
    // Useful when running inside a docker container
    if let Some(ref huggingface_hub_cache) = huggingface_hub_cache {
        envs.push(("HUGGINGFACE_HUB_CACHE".into(), huggingface_hub_cache.into()));
    };

    // Enable hf transfer for insane download speeds
    let enable_hf_transfer = env::var("HF_HUB_ENABLE_HF_TRANSFER").unwrap_or("1".to_string());
    envs.push((
        "HF_HUB_ENABLE_HF_TRANSFER".into(),
        enable_hf_transfer.into(),
    ));

    // Parse Inference API token
    if let Ok(api_token) = env::var("HF_API_TOKEN") {
        envs.push(("HF_TOKEN".into(), api_token.into()))
    };

    // If args.weights_cache_override is some, pass it to the download process
    // Useful when running inside a HuggingFace Inference Endpoint
    if let Some(weights_cache_override) = &weights_cache_override {
        envs.push((
            "WEIGHTS_CACHE_OVERRIDE".into(),
            weights_cache_override.into(),
        ));
    };

    // Start process
    tracing::info!("Starting check and download process for {model_id}");
    let mut download_process = match Command::new("text-generation-server")
        .args(download_args)
        .env_clear()
        .envs(envs)
        .stdout(Stdio::piped())
        .stderr(Stdio::piped())
        .process_group(0)
        .spawn()
    {
        Ok(p) => p,
        Err(err) => {
            if err.kind() == io::ErrorKind::NotFound {
                tracing::error!("text-generation-server not found in PATH");
                tracing::error!("Please install it with `make install-server`")
            } else {
                tracing::error!("{}", err);
            }

            return Err(LauncherError::DownloadError);
        }
    };

    let download_stdout = BufReader::new(download_process.stdout.take().unwrap());

    thread::spawn(move || {
        log_lines(download_stdout);
    });

    let download_stderr = BufReader::new(download_process.stderr.take().unwrap());

    // We read stderr in another thread as it seems that lines() can block in some cases
    let (err_sender, err_receiver) = mpsc::channel();
    thread::spawn(move || {
        for line in download_stderr.lines().map_while(Result::ok) {
            err_sender.send(line).unwrap_or(());
        }
    });

    loop {
        if let Some(status) = download_process.try_wait().unwrap() {
            if status.success() {
                tracing::info!("Successfully downloaded weights for {model_id}");
                break;
            }

            let mut err = String::new();
            while let Ok(line) = err_receiver.recv_timeout(Duration::from_millis(10)) {
                err = err + "\n" + &line;
            }

            if let Some(signal) = status.signal() {
                tracing::error!(
                    "Download process was signaled to shutdown with signal {signal}: {err}"
                );
            } else {
                tracing::error!("Download encountered an error: {err}");
            }

            return Err(LauncherError::DownloadError);
        }
        if !running.load(Ordering::SeqCst) {
            terminate("download", download_process, Duration::from_secs(10)).unwrap();
            return Ok(());
        }
        sleep(Duration::from_millis(100));
    }
    Ok(())
}
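// The download step above shells out to roughly the following command (assuming no
// revision and no trusted remote code):
//
//   text-generation-server download-weights <model_id> \
//     --extension .safetensors --logger-level INFO --json-output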

#[allow(clippy::too_many_arguments)]
fn spawn_shards(
    num_shard: usize,
    args: &Args,
    cuda_graphs: Vec<usize>,
    max_total_tokens: usize,
    max_input_tokens: usize,
    quantize: Option<Quantization>,
    max_log_level: LevelFilter,
    shutdown: Arc<AtomicBool>,
    shutdown_receiver: &mpsc::Receiver<()>,
    shutdown_sender: mpsc::Sender<()>,
    status_receiver: &mpsc::Receiver<ShardStatus>,
    status_sender: mpsc::Sender<ShardStatus>,
    running: Arc<AtomicBool>,
) -> Result<(), LauncherError> {
    // Start shard processes
    for rank in 0..num_shard {
        let model_id = args.model_id.clone();
        let revision = args.revision.clone();
        let uds_path = args.shard_uds_path.clone();
        let master_addr = args.master_addr.clone();
        let huggingface_hub_cache = args.huggingface_hub_cache.clone();
        let weights_cache_override = args.weights_cache_override.clone();
        let status_sender = status_sender.clone();
        let shutdown = shutdown.clone();
        let shutdown_sender = shutdown_sender.clone();
        let otlp_endpoint = args.otlp_endpoint.clone();
        let otlp_service_name = args.otlp_service_name.clone();
        let speculate = args.speculate;
        let dtype = args.dtype;
        let kv_cache_dtype = args.kv_cache_dtype;
        let trust_remote_code = args.trust_remote_code;
        let master_port = args.master_port;
        let disable_custom_kernels = args.disable_custom_kernels;
        let watermark_gamma = args.watermark_gamma;
        let watermark_delta = args.watermark_delta;
        let cuda_graphs_clone = cuda_graphs.clone();
        let cuda_memory_fraction = args.cuda_memory_fraction;
        let rope_scaling = args.rope_scaling;
        let rope_factor = args.rope_factor;
        let max_batch_size = args.max_batch_size;
        let lora_adapters = args.lora_adapters.clone();
        thread::spawn(move || {
            shard_manager(
                model_id,
                revision,
                quantize,
                speculate,
                dtype,
                kv_cache_dtype,
                trust_remote_code,
                uds_path,
                rank,
                num_shard,
                master_addr,
                master_port,
                huggingface_hub_cache,
                weights_cache_override,
                disable_custom_kernels,
                watermark_gamma,
                watermark_delta,
                cuda_graphs_clone,
                cuda_memory_fraction,
                rope_scaling,
                rope_factor,
                max_total_tokens,
                max_batch_size,
                max_input_tokens,
                lora_adapters,
                otlp_endpoint,
                otlp_service_name,
                max_log_level,
                status_sender,
                shutdown,
                shutdown_sender,
            )
        });
    }
    drop(shutdown_sender);

    // Wait for the shards to start
    let mut shard_ready = 0;
    while running.load(Ordering::SeqCst) {
        match status_receiver.try_recv() {
            Ok(ShardStatus::Ready) => {
                shard_ready += 1;
                if shard_ready == num_shard {
                    break;
                }
            }
            Err(TryRecvError::Empty) => {
                sleep(Duration::from_millis(100));
            }
            Ok(ShardStatus::Failed(rank)) => {
                tracing::error!("Shard {rank} failed to start");
                shutdown_shards(shutdown, shutdown_receiver);
                return Err(LauncherError::ShardCannotStart);
            }
            Err(TryRecvError::Disconnected) => {
                tracing::error!("Shard status channel disconnected");
                shutdown_shards(shutdown, shutdown_receiver);
                return Err(LauncherError::ShardDisconnected);
            }
        }
    }
    Ok(())
}

fn compute_type(num_shard: usize) -> Option<String> {
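    // Query `nvidia-smi` for the first GPU name and build an identifier of the
    // form "{num_shard}-{lowercased-gpu-name}" (e.g. "1-nvidia-a10g", hypothetical
    // example); it is exported to the router through the COMPUTE_TYPE env var below.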
    let output = Command::new("nvidia-smi")
        .args(["--query-gpu=gpu_name", "--format=csv"])
        .output()
        .ok()?;
    let output = String::from_utf8(output.stdout).ok()?;
    let fullname = output.split('\n').nth(1)?;
    let cardname = fullname.replace(' ', "-").to_lowercase();
    let compute_type = format!("{num_shard}-{cardname}");
    Some(compute_type)
}

fn spawn_webserver(
    num_shard: usize,
    args: Args,
    max_input_tokens: usize,
    max_total_tokens: usize,
    max_batch_prefill_tokens: u32,
    shutdown: Arc<AtomicBool>,
    shutdown_receiver: &mpsc::Receiver<()>,
) -> Result<Child, LauncherError> {
    // All shards have started
    // Start webserver
    tracing::info!("Starting Webserver");
    let mut router_args = vec![
        "--max-client-batch-size".to_string(),
        args.max_client_batch_size.to_string(),
        "--max-concurrent-requests".to_string(),
        args.max_concurrent_requests.to_string(),
        "--max-best-of".to_string(),
        args.max_best_of.to_string(),
        "--max-stop-sequences".to_string(),
        args.max_stop_sequences.to_string(),
        "--max-top-n-tokens".to_string(),
        args.max_top_n_tokens.to_string(),
        "--max-input-tokens".to_string(),
        max_input_tokens.to_string(),
        "--max-total-tokens".to_string(),
        max_total_tokens.to_string(),
        "--max-batch-prefill-tokens".to_string(),
        max_batch_prefill_tokens.to_string(),
        "--waiting-served-ratio".to_string(),
        args.waiting_served_ratio.to_string(),
        "--max-waiting-tokens".to_string(),
        args.max_waiting_tokens.to_string(),
        "--validation-workers".to_string(),
        args.validation_workers.to_string(),
        "--hostname".to_string(),
        args.hostname.to_string(),
        "--port".to_string(),
        args.port.to_string(),
        "--master-shard-uds-path".to_string(),
        format!("{}-0", args.shard_uds_path),
        "--tokenizer-name".to_string(),
        args.model_id,
    ];

    // Pass usage stats flags to router
    router_args.push("--usage-stats".to_string());
    router_args.push(args.usage_stats.to_string());

    // Grammar support
    if args.disable_grammar_support {
        router_args.push("--disable-grammar-support".to_string());
    }

    // Tokenizer config path
    if let Some(ref tokenizer_config_path) = args.tokenizer_config_path {
        router_args.push("--tokenizer-config-path".to_string());
        router_args.push(tokenizer_config_path.to_string());
    }

    // Model optional max batch total tokens
    if let Some(max_batch_total_tokens) = args.max_batch_total_tokens {
        router_args.push("--max-batch-total-tokens".to_string());
        router_args.push(max_batch_total_tokens.to_string());
    }

    // Router optional max batch size
    if let Some(max_batch_size) = args.max_batch_size {
        router_args.push("--max-batch-size".to_string());
        router_args.push(max_batch_size.to_string());
    }

    // Model optional revision
    if let Some(ref revision) = args.revision {
        router_args.push("--revision".to_string());
        router_args.push(revision.to_string())
    }

    if args.json_output {
        router_args.push("--json-output".to_string());
    }

    // OpenTelemetry
    if let Some(otlp_endpoint) = args.otlp_endpoint {
        router_args.push("--otlp-endpoint".to_string());
        router_args.push(otlp_endpoint);
    }

    // OpenTelemetry
    let otlp_service_name = args.otlp_service_name;
    router_args.push("--otlp-service-name".to_string());
    router_args.push(otlp_service_name);

    // CORS origins
    for origin in args.cors_allow_origin.into_iter() {
        router_args.push("--cors-allow-origin".to_string());
        router_args.push(origin);
    }

    // API Key
    if let Some(api_key) = args.api_key {
        router_args.push("--api-key".to_string());
        router_args.push(api_key);
    }
    // Ngrok
    if args.ngrok {
        router_args.push("--ngrok".to_string());
        router_args.push("--ngrok-authtoken".to_string());
        router_args.push(args.ngrok_authtoken.unwrap());
        router_args.push("--ngrok-edge".to_string());
        router_args.push(args.ngrok_edge.unwrap());
    }

    // Copy current process env
    let mut envs: Vec<(OsString, OsString)> = env::vars_os().collect();

    // Parse Inference API token
    if let Ok(api_token) = env::var("HF_API_TOKEN") {
        envs.push(("HF_TOKEN".into(), api_token.into()))
    };

    // Parse Compute type
    if let Ok(compute_type) = env::var("COMPUTE_TYPE") {
        envs.push(("COMPUTE_TYPE".into(), compute_type.into()))
    } else if let Some(compute_type) = compute_type(num_shard) {
        envs.push(("COMPUTE_TYPE".into(), compute_type.into()))
    }

    let mut webserver = match Command::new("text-generation-router")
        .args(router_args)
        .envs(envs)
        .stdout(Stdio::piped())
        .stderr(Stdio::piped())
        .process_group(0)
        .spawn()
    {
        Ok(p) => p,
        Err(err) => {
            tracing::error!("Failed to start webserver: {}", err);
            if err.kind() == io::ErrorKind::NotFound {
                tracing::error!("text-generation-router not found in PATH");
                tracing::error!("Please install it with `make install-router`")
            } else {
                tracing::error!("{}", err);
            }

            shutdown_shards(shutdown, shutdown_receiver);
            return Err(LauncherError::WebserverCannotStart);
        }
    };

    // Redirect STDOUT and STDERR to the console
    let webserver_stdout = webserver.stdout.take().unwrap();
    let webserver_stderr = webserver.stderr.take().unwrap();

    thread::spawn(move || {
        let stdout = BufReader::new(webserver_stdout);
        let stderr = BufReader::new(webserver_stderr);
        for line in stdout.lines() {
            println!("{}", line.unwrap());
        }
        for line in stderr.lines() {
            println!("{}", line.unwrap());
        }
    });
    Ok(webserver)
}

fn terminate(process_name: &str, mut process: Child, timeout: Duration) -> io::Result<ExitStatus> {
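    // Ask the process to exit with SIGTERM, poll until it does, and escalate to
    // SIGKILL once `timeout` has elapsed.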
    tracing::info!("Terminating {process_name}");

    let terminate_time = Instant::now();
    signal::kill(Pid::from_raw(process.id() as i32), Signal::SIGTERM).unwrap();

    tracing::info!("Waiting for {process_name} to gracefully shutdown");
    while terminate_time.elapsed() < timeout {
        if let Some(status) = process.try_wait()? {
            tracing::info!("{process_name} terminated");
            return Ok(status);
        }
        sleep(Duration::from_millis(100));
    }
    tracing::info!("Killing {process_name}");

    process.kill()?;
    let exit_status = process.wait()?;

    tracing::info!("{process_name} killed");
    Ok(exit_status)
}

fn main() -> Result<(), LauncherError> {
    // Pattern match configuration
    let args: Args = Args::parse();

    // Filter events with LOG_LEVEL
    let varname = "LOG_LEVEL";
    let env_filter = if let Ok(log_level) = std::env::var(varname) {
        // Scope plain log levels to the TGI crates so that tokio internals do not spam the output
        let log_level = match &log_level[..] {
            "warn" => "text_generation_launcher=warn,text_generation_router=warn",
            "info" => "text_generation_launcher=info,text_generation_router=info",
            "debug" => "text_generation_launcher=debug,text_generation_router=debug",
            log_level => log_level,
        };
        EnvFilter::builder()
            .with_default_directive(LevelFilter::INFO.into())
            .parse_lossy(log_level)
    } else {
        EnvFilter::new("info")
    };
    let max_log_level = env_filter.max_level_hint().unwrap_or(LevelFilter::INFO);

    if args.json_output {
        tracing_subscriber::fmt()
            .with_env_filter(env_filter)
            .json()
            .init();
    } else {
        tracing_subscriber::fmt()
            .with_env_filter(env_filter)
            .compact()
            .init();
    }

    if args.env {
        let env_runtime = env_runtime::Env::new();
        tracing::info!("{}", env_runtime);
    }

    tracing::info!("{:#?}", args);

    let config: Option<Config> = get_config(&args.model_id, &args.revision).ok();
    let quantize = config.as_ref().and_then(|c| c.quantize);
    // Quantization usually means you're even more RAM constrained.
    let max_default = 4096;

    let max_position_embeddings = if let Some(config) = &config {
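        // Cap the default context length at `max_default` unless the user sets the
        // limits explicitly, to avoid reserving VRAM for very long prompts.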
        if let Some(max_position_embeddings) = config.max_position_embeddings {
            if max_position_embeddings > max_default {
                let max = max_position_embeddings;
                if args.max_input_tokens.is_none()
                    && args.max_total_tokens.is_none()
                    && args.max_batch_prefill_tokens.is_none()
                {
                    tracing::info!("Model supports up to {max} but tgi will now set its default to {max_default} instead. This is to save VRAM by refusing large prompts in order to allow more users on the same hardware. You can increase that size using `--max-batch-prefill-tokens={} --max-total-tokens={max} --max-input-tokens={}`.", max + 50, max - 1);
                }
                max_default
            } else {
                max_position_embeddings
            }
        } else {
            max_default
        }
    } else {
        max_default
    };
    let (prefix_caching, attention) = resolve_attention(&config, &args.lora_adapters);
    tracing::info!("Using attention {attention} - Prefix caching {prefix_caching}");
    std::env::set_var("PREFIX_CACHING", prefix_caching);
    std::env::set_var("ATTENTION", attention);

    let max_input_tokens = {
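        // Prefer --max-input-tokens; the deprecated --max-input-length is still
        // honoured, and otherwise the value is derived from the model's context size.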
        match (args.max_input_tokens, args.max_input_length) {
            (Some(max_input_tokens), Some(max_input_length)) => {
                return Err(LauncherError::ArgumentValidation(
                    format!("Both `max_input_tokens` ({max_input_tokens}) and `max_input_length` ({max_input_length}) are set. Please define only `max_input_tokens` as `max_input_length is deprecated for naming consistency.",
                )));
            }
            (Some(max_input_tokens), None) | (None, Some(max_input_tokens)) => max_input_tokens,
            (None, None) => {
                let value = max_position_embeddings - 1;
                tracing::info!("Default `max_input_tokens` to {value}");
                value
            }
        }
    };
    let max_total_tokens = {
        match args.max_total_tokens {
            Some(max_total_tokens) => max_total_tokens,
            None => {
                let value = max_position_embeddings;
                tracing::info!("Default `max_total_tokens` to {value}");
                value
            }
        }
    };
    let max_batch_prefill_tokens = {
        match args.max_batch_prefill_tokens {
            Some(max_batch_prefill_tokens) => max_batch_prefill_tokens,
            None => {
                let value: u32 = if let Some(max_batch_size) = args.max_batch_size {
                    max_batch_size * max_input_tokens
                } else {
                    // Add some headroom to account for potential block_size
                    // alignment issues.
                    max_input_tokens + 50
                } as u32;
                tracing::info!("Default `max_batch_prefill_tokens` to {value}");
                value
            }
        }
    };

    // Validate args
    if max_input_tokens >= max_total_tokens {
        return Err(LauncherError::ArgumentValidation(
            "`max_input_tokens` must be < `max_total_tokens`".to_string(),
        ));
    }

    if matches!(args.quantize, Some(Quantization::Bitsandbytes)) {
        tracing::warn!("Bitsandbytes is deprecated, use `eetq` instead, which provides better latencies overall and is drop-in in most cases.");
    }
    let quantize = args.quantize.or(quantize);
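    // Decide which CUDA graph batch sizes to capture: an explicit --cuda-graphs wins,
    // while bitsandbytes and exl2 quantization disable graphs entirely.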
    let cuda_graphs = match (&args.cuda_graphs, &quantize) {
        (Some(cuda_graphs), _) => cuda_graphs.iter().cloned().filter(|&c| c > 0).collect(),
        #[allow(deprecated)]
        (
            None,
            Some(
                Quantization::Bitsandbytes
                | Quantization::BitsandbytesNf4
                | Quantization::BitsandbytesFp4,
            ),
        ) => {
            tracing::warn!("Bitsandbytes doesn't work with cuda graphs, deactivating them");
            vec![]
        }
        (None, Some(Quantization::Exl2)) => {
            tracing::warn!("Exl2 doesn't work with cuda graphs, deactivating them");
            vec![]
        }
        _ => {
            let cuda_graphs = vec![1, 2, 4, 8, 16, 32];
            tracing::info!("Using default cuda graphs {cuda_graphs:?}");
            cuda_graphs
        }
    };

    if args.validation_workers == 0 {
        return Err(LauncherError::ArgumentValidation(
            "`validation_workers` must be > 0".to_string(),
        ));
    }
    if args.trust_remote_code {
        tracing::warn!(
            "`trust_remote_code` is set. Trusting that model `{}` do not contain malicious code.",
            args.model_id
        );
    }

    let num_shard = find_num_shards(args.sharded, args.num_shard)?;
    if num_shard > 1 {
        if matches!(args.quantize, Some(Quantization::Exl2)) {
            return Err(LauncherError::ArgumentValidation(
                "Sharding is currently not supported with `exl2` quantization".into(),
            ));
        }
        tracing::info!("Sharding model on {num_shard} processes");
    }

    if let Some(ref max_batch_total_tokens) = args.max_batch_total_tokens {
        if max_total_tokens as u32 > *max_batch_total_tokens {
            return Err(LauncherError::ArgumentValidation(format!(
                "`max_total_tokens` must be <= `max_batch_total_tokens`. Given: {} and {}",
                max_total_tokens, max_batch_total_tokens
            )));
        }
    }

    if args.ngrok {
        if args.ngrok_authtoken.is_none() {
            return Err(LauncherError::ArgumentValidation(
                "`ngrok-authtoken` must be set when using ngrok tunneling".to_string(),
            ));
        }

        if args.ngrok_edge.is_none() {
            return Err(LauncherError::ArgumentValidation(
                "`ngrok-edge` must be set when using ngrok tunneling".to_string(),
            ));
        }
    }

    // Signal handler
    let running = Arc::new(AtomicBool::new(true));
    let r = running.clone();
    ctrlc::set_handler(move || {
        r.store(false, Ordering::SeqCst);
    })
    .expect("Error setting Ctrl-C handler");

    // Download and convert model weights
    download_convert_model(
        &args.model_id,
        args.revision.as_deref(),
        args.trust_remote_code,
        args.huggingface_hub_cache.as_deref(),
        args.weights_cache_override.as_deref(),
        running.clone(),
    )?;

    // Download and convert lora adapters if any
    if let Some(lora_adapters) = &args.lora_adapters {
        for adapter in lora_adapters.split(',') {
            // skip download if a path is provided
            if adapter.contains('=') {
                continue;
            }

            let adapter = adapter.trim();

            // check if adapter has more than 1 '@'
            if adapter.matches('@').count() > 1 {
                return Err(LauncherError::ArgumentValidation(format!(
                    "Invalid LoRA adapter format: {}",
                    adapter
                )));
            }

            // capture adapter_id, path, revision in format of adapter_id=path@revision
            let re = Regex::new(r"^([^=@]+)(?:=([^@]+))?(?:@(.+))?$").unwrap();
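            // e.g. "org/adapter@main" (hypothetical) captures adapter_id="org/adapter"
            // and revision="main"; the optional "=path" part is not used for downloads.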
            if let Some(caps) = re.captures(adapter) {
                let adapter_id = caps.get(1).map_or("", |m| m.as_str());
                let revision = caps.get(3).map(|m| m.as_str());

                download_convert_model(
                    adapter_id,
                    revision,
                    args.trust_remote_code,
                    args.huggingface_hub_cache.as_deref(),
                    args.weights_cache_override.as_deref(),
                    running.clone(),
                )?;
            } else {
                return Err(LauncherError::ArgumentValidation(format!(
                    "Invalid LoRA adapter format: {}",
                    adapter
                )));
            }
        }
    }

    if !running.load(Ordering::SeqCst) {
        // Launcher was asked to stop
        return Ok(());
    }

    // Shared shutdown bool
    let shutdown = Arc::new(AtomicBool::new(false));

    // Shared shutdown channel
    // When shutting down, the main thread will wait for all senders to be dropped
    let (shutdown_sender, shutdown_receiver) = mpsc::channel();

    // Shared channel to track shard status
    let (status_sender, status_receiver) = mpsc::channel();

    spawn_shards(
        num_shard,
        &args,
        cuda_graphs,
        max_total_tokens,
        max_input_tokens,
        quantize,
        max_log_level,
        shutdown.clone(),
        &shutdown_receiver,
        shutdown_sender,
        &status_receiver,
        status_sender,
        running.clone(),
    )?;

    // We might have received a termination signal
    if !running.load(Ordering::SeqCst) {
        shutdown_shards(shutdown, &shutdown_receiver);
        return Ok(());
    }

    let mut webserver = spawn_webserver(
        num_shard,
        args,
        max_input_tokens,
        max_total_tokens,
        max_batch_prefill_tokens,
        shutdown.clone(),
        &shutdown_receiver,
    )
    .inspect_err(|_| {
        shutdown_shards(shutdown.clone(), &shutdown_receiver);
    })?;

    // Default exit code
    let mut exit_code = Ok(());

    while running.load(Ordering::SeqCst) {
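        // Main supervision loop: bail out if a shard crashes, and exit with an
        // error if the webserver process terminates on its own.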
        if let Ok(ShardStatus::Failed(rank)) = status_receiver.try_recv() {
            tracing::error!("Shard {rank} crashed");
            exit_code = Err(LauncherError::ShardFailed);
            break;
        };

        match webserver.try_wait().unwrap() {
            Some(_) => {
                tracing::error!("Webserver Crashed");
                shutdown_shards(shutdown, &shutdown_receiver);
                return Err(LauncherError::WebserverFailed);
            }
            None => {
                sleep(Duration::from_millis(100));
            }
        };
    }

    // Graceful termination
    terminate("webserver", webserver, Duration::from_secs(90)).unwrap();
    shutdown_shards(shutdown, &shutdown_receiver);

    exit_code
}