"vscode:/vscode.git/clone" did not exist on "8560d58bdd0f81c9eb53b8d1fa06dcf857f7af34"
main.rs 60.8 KB
Newer Older
1
use clap::{Parser, ValueEnum};
use hf_hub::{
    api::sync::{Api, ApiBuilder},
    Repo, RepoType,
};
use nix::sys::signal::{self, Signal};
use nix::unistd::Pid;
use serde::Deserialize;
use std::env;
use std::ffi::OsString;
use std::io::{BufRead, BufReader, Lines};
use std::os::unix::process::{CommandExt, ExitStatusExt};
use std::path::Path;
use std::process::{Child, Command, ExitStatus, Stdio};
use std::sync::atomic::{AtomicBool, Ordering};
use std::sync::mpsc::TryRecvError;
use std::sync::{mpsc, Arc};
use std::thread;
use std::thread::sleep;
use std::time::{Duration, Instant};
use std::{fs, io};
use thiserror::Error;
use tracing_subscriber::{filter::LevelFilter, EnvFilter};

mod env_runtime;

#[derive(Deserialize)]
struct RawConfig {
    max_position_embeddings: Option<usize>,
    n_positions: Option<usize>,
    model_type: Option<String>,
    max_seq_len: Option<usize>,
    quantization_config: Option<QuantizationConfig>,
}

#[derive(Deserialize)]
struct QuantizationConfig {
    quant_method: Option<Quantization>,
}

#[derive(Deserialize)]
struct Config {
    max_position_embeddings: Option<usize>,
    quantize: Option<Quantization>,
}
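
// `RawConfig` mirrors the fields we care about in the model's `config.json`, while
// `Config` keeps only what the launcher needs: `max_position_embeddings` falls back
// to `max_seq_len`, then `n_positions`, and `quantize` comes from `quantization_config`.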

impl From<RawConfig> for Config {
    fn from(other: RawConfig) -> Self {
        let max_position_embeddings = other
            .max_position_embeddings
            .or(other.max_seq_len)
            .or(other.n_positions);
        let quantize = other.quantization_config.and_then(|q| q.quant_method);
        Config {
            max_position_embeddings,
            quantize,
        }
    }
}

#[derive(Clone, Copy, Debug, ValueEnum, Deserialize)]
#[serde(rename_all = "kebab-case")]
enum Quantization {
    /// 4 bit quantization. Requires a specific AWQ quantized model:
    ///   <https://hf.co/models?search=awq>.
    /// Should replace GPTQ models wherever possible because of the better latency
    Awq,
    /// 8 bit quantization, doesn't require specific model.
    /// Should be a drop-in replacement to bitsandbytes with much better performance.
    /// Kernels are from <https://github.com/NetEase-FuXi/EETQ.git>
    Eetq,
    /// Variable bit quantization. Requires a specific EXL2 quantized model:
    /// <https://hf.co/models?search=exl2>. Requires exllama2 kernels and does
    /// not support tensor parallelism (num_shard > 1).
    Exl2,
    /// 4 bit quantization. Requires a specific GPTQ quantized model: <https://hf.co/models?search=gptq>.
    /// text-generation-inference will use exllama (faster) kernels wherever possible, and use
    /// triton kernel (wider support) when it's not.
    /// AWQ has faster kernels.
    Gptq,
    /// 4 bit quantization. Requires a specific Marlin quantized model: <https://hf.co/models?search=marlin>.
    Marlin,
    /// Bitsandbytes 8bit. Can be applied on any model, will cut the memory requirement in half,
    /// but it is known that the model will be much slower to run than the native f16.
    // #[deprecated(
    //     since = "1.1.0",
    //     note = "Use `eetq` instead, which provides better latencies overall and is drop-in in most cases"
    // )]
    Bitsandbytes,
    /// Bitsandbytes 4bit. Can be applied on any model, will cut the memory requirement by 4x,
    /// but it is known that the model will be much slower to run than the native f16.
    BitsandbytesNf4,
    /// Bitsandbytes 4bit. nf4 should be preferred in most cases but maybe this one has better
    /// perplexity performance for your model
    BitsandbytesFp4,
    /// [FP8](https://developer.nvidia.com/blog/nvidia-arm-and-intel-publish-fp8-specification-for-standardization-as-an-interchange-format-for-ai/) (e4m3) works on H100 and above
    /// This dtype has native ops and should be the fastest if available.
    /// This is currently not the fastest because of local unpacking + padding to satisfy matrix
    /// multiplication limitations.
    Fp8,
}

impl std::fmt::Display for Quantization {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        // To keep in sync with `server`.
        match self {
            #[allow(deprecated)]
            // Use `eetq` instead, which provides better latencies overall and is drop-in in most cases
            Quantization::Bitsandbytes => {
                write!(f, "bitsandbytes")
            }
            Quantization::BitsandbytesNf4 => {
                write!(f, "bitsandbytes-nf4")
            }
            Quantization::BitsandbytesFp4 => {
                write!(f, "bitsandbytes-fp4")
            }
            Quantization::Exl2 => {
                write!(f, "exl2")
            }
            Quantization::Gptq => {
                write!(f, "gptq")
            }
            Quantization::Marlin => {
                write!(f, "marlin")
            }
            Quantization::Awq => {
                write!(f, "awq")
            }
            Quantization::Eetq => {
                write!(f, "eetq")
            }
            Quantization::Fp8 => {
                write!(f, "fp8")
            }
        }
    }
}

#[derive(Clone, Copy, Debug, ValueEnum)]
enum Dtype {
    Float16,
    #[clap(name = "bfloat16")]
    BFloat16,
}

impl std::fmt::Display for Dtype {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        // To keep in sync with `server`.
        match self {
            Dtype::Float16 => {
                write!(f, "float16")
            }
            Dtype::BFloat16 => {
                write!(f, "bfloat16")
            }
        }
    }
}

#[derive(Clone, Copy, Debug, ValueEnum)]
enum RopeScaling {
    Linear,
    Dynamic,
}

impl std::fmt::Display for RopeScaling {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        // To keep in sync with `server`.
        match self {
            RopeScaling::Linear => {
                write!(f, "linear")
            }
            RopeScaling::Dynamic => {
                write!(f, "dynamic")
            }
        }
    }
}

#[derive(Clone, Copy, Debug, ValueEnum)]
pub enum UsageStatsLevel {
    /// Default option, usage statistics are collected anonymously
    On,
    /// Disables all collection of usage statistics
    Off,
    /// Doesn't send the error stack trace or error type, but allows sending a crash event
    NoStack,
}

impl std::fmt::Display for UsageStatsLevel {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        // To keep in sync with `server`.
        match self {
            UsageStatsLevel::On => {
                write!(f, "on")
            }
            UsageStatsLevel::Off => {
                write!(f, "off")
            }
            UsageStatsLevel::NoStack => {
                write!(f, "no-stack")
            }
        }
    }
}
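
// Example invocation (hypothetical flag values, shown for illustration only):
//   text-generation-launcher --model-id bigscience/bloom-560m --num-shard 2 --port 8080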

/// App Configuration
#[derive(Parser, Debug)]
#[clap(author, version, about, long_about = None)]
struct Args {
    /// The name of the model to load.
    /// Can be a MODEL_ID as listed on <https://hf.co/models> like
    /// `gpt2` or `OpenAssistant/oasst-sft-1-pythia-12b`.
    /// Or it can be a local directory containing the necessary files
    /// as saved by `save_pretrained(...)` methods of transformers
    #[clap(default_value = "bigscience/bloom-560m", long, env)]
    model_id: String,

    /// The actual revision of the model if you're referring to a model
    /// on the hub. You can use a specific commit id or a branch like `refs/pr/2`.
    #[clap(long, env)]
    revision: Option<String>,

    /// The number of tokenizer workers used for payload validation and truncation inside the
    /// router.
    #[clap(default_value = "2", long, env)]
    validation_workers: usize,

    /// Whether to shard the model across multiple GPUs
    /// By default text-generation-inference will use all available GPUs to run
    /// the model. Setting it to `false` deactivates `num_shard`.
    #[clap(long, env)]
    sharded: Option<bool>,

    /// The number of shards to use if you don't want to use all GPUs on a given machine.
    /// You can use `CUDA_VISIBLE_DEVICES=0,1 text-generation-launcher... --num_shard 2`
    /// and `CUDA_VISIBLE_DEVICES=2,3 text-generation-launcher... --num_shard 2` to
    /// launch 2 copies with 2 shards each on a given machine with 4 GPUs for instance.
    #[clap(long, env)]
    num_shard: Option<usize>,

    /// Whether you want the model to be quantized.
    #[clap(long, env, value_enum)]
    quantize: Option<Quantization>,

    /// The number of input_ids to speculate on
    /// If using a medusa model, the heads will be picked up automatically
    /// Otherwise, it will use n-gram speculation which is relatively free
    /// in terms of compute, but the speedup heavily depends on the task.
    #[clap(long, env)]
    speculate: Option<usize>,

    /// The dtype to be forced upon the model. This option cannot be used with `--quantize`.
    #[clap(long, env, value_enum)]
    dtype: Option<Dtype>,

    /// Whether you want to execute hub modelling code. Explicitly passing a `revision` is
    /// encouraged when loading a model with custom code to ensure no malicious code has been
    /// contributed in a newer revision.
    #[clap(long, env, value_enum)]
    trust_remote_code: bool,

    /// The maximum amount of concurrent requests for this particular deployment.
    /// Having a low limit will refuse client requests instead of having them
    /// wait for too long and is usually good to handle backpressure correctly.
    #[clap(default_value = "128", long, env)]
    max_concurrent_requests: usize,

    /// This is the maximum allowed value for clients to set `best_of`.
    /// Best of makes `n` generations at the same time, and returns the best
    /// in terms of overall log probability over the entire generated sequence
    #[clap(default_value = "2", long, env)]
    max_best_of: usize,

    /// This is the maximum allowed value for clients to set `stop_sequences`.
    /// Stop sequences are used to allow the model to stop on more than just
    /// the EOS token, and enable more complex "prompting" where users can preprompt
    /// the model in a specific way and define their "own" stop token aligned with
    /// their prompt.
    #[clap(default_value = "4", long, env)]
    max_stop_sequences: usize,

    /// This is the maximum allowed value for clients to set `top_n_tokens`.
    /// `top_n_tokens` is used to return information about the `n` most likely
    /// tokens at each generation step, instead of just the sampled token. This
    /// information can be used for downstream tasks like classification or
    /// ranking.
    #[clap(default_value = "5", long, env)]
    max_top_n_tokens: u32,

    /// This is the maximum allowed input length (expressed in number of tokens)
    /// for users. The larger this value, the longer prompt users can send which
    /// can impact the overall memory required to handle the load.
    /// Please note that some models have a finite range of sequences they can handle.
    /// Defaults to min(max_position_embeddings - 1, 4095)
    #[clap(long, env)]
    max_input_tokens: Option<usize>,

    /// Legacy version of [`Args::max_input_tokens`].
    #[clap(long, env)]
    max_input_length: Option<usize>,

    /// This is the most important value to set as it defines the "memory budget"
    /// of running clients requests.
    /// Clients will send input sequences and ask to generate `max_new_tokens`
    /// on top. With a value of `1512` users can send either a prompt of
    /// `1000` and ask for `512` new tokens, or send a prompt of `1` and ask for
    /// `1511` max_new_tokens.
    /// The larger this value, the larger each request will be in your RAM
    /// and the less effective batching can be.
    /// Defaults to min(max_position_embeddings, 4096)
    #[clap(long, env)]
    max_total_tokens: Option<usize>,

    /// This represents the ratio of waiting queries vs running queries where
    /// you want to start considering pausing the running queries to include the waiting
    /// ones into the same batch.
    /// `waiting_served_ratio=1.2` means when 12 queries are waiting and there's
    /// only 10 queries left in the current batch we check if we can fit those 12
    /// waiting queries into the batching strategy, and if yes, then batching happens
    /// delaying the 10 running queries by a `prefill` run.
    ///
    /// This setting is only applied if there is room in the batch
    /// as defined by `max_batch_total_tokens`.
    #[clap(default_value = "0.3", long, env)]
    waiting_served_ratio: f32,

    /// Limits the number of tokens for the prefill operation.
    /// Since this operation takes the most memory and is compute bound, it is interesting
    /// to limit the number of requests that can be sent.
    /// Defaults to `max_input_tokens + 50` to give a bit of room.
    #[clap(long, env)]
    max_batch_prefill_tokens: Option<u32>,

    /// **IMPORTANT** This is one critical control to allow maximum usage
    /// of the available hardware.
    ///
    /// This represents the total amount of potential tokens within a batch.
    /// When using padding (not recommended) this would be equivalent of
    /// `batch_size` * `max_total_tokens`.
    ///
    /// However in the non-padded (flash attention) version this can be much finer.
    ///
    /// For `max_batch_total_tokens=1000`, you could fit `10` queries of `total_tokens=100`
    /// or a single query of `1000` tokens.
    ///
    /// Overall this number should be the largest possible amount that fits the
    /// remaining memory (after the model is loaded). Since the actual memory overhead
    /// depends on other parameters like if you're using quantization, flash attention
    /// or the model implementation, text-generation-inference cannot infer this number
    /// automatically.
    #[clap(long, env)]
    max_batch_total_tokens: Option<u32>,

    /// This setting defines how many tokens can be passed before forcing the waiting
    /// queries to be put on the batch (if the size of the batch allows for it).
    /// New queries require 1 `prefill` forward, which is different from `decode`
    /// and therefore you need to pause the running batch in order to run `prefill`
    /// to create the correct values for the waiting queries to be able to join the batch.
    ///
    /// With a value too small, queries will always "steal" the compute to run `prefill`
    /// and running queries will be delayed by a lot.
    ///
    /// With a value too big, waiting queries could wait for a very long time
    /// before being allowed a slot in the running batch. If your server is busy
    /// that means that requests that could run in ~2s on an empty server could
    /// end up running in ~20s because the query had to wait for 18s.
    ///
    /// This number is expressed in number of tokens to make it a bit more
    /// "model" agnostic, but what should really matter is the overall latency
    /// for end users.
    #[clap(default_value = "20", long, env)]
    max_waiting_tokens: usize,

    /// Enforce a maximum number of requests per batch
    /// Specific flag for hardware targets that do not support unpadded inference
    #[clap(long, env)]
    max_batch_size: Option<usize>,

    /// Specify the batch sizes to compute cuda graphs for.
    /// Use "0" to disable.
    /// Default = "1,2,4,8,16,32"
    #[clap(long, env, value_delimiter = ',')]
    cuda_graphs: Option<Vec<usize>>,

    /// The IP address to listen on
    #[clap(default_value = "0.0.0.0", long, env)]
    hostname: String,

    /// The port to listen on.
    #[clap(default_value = "3000", long, short, env)]
    port: u16,

    /// The name of the socket for gRPC communication between the webserver
    /// and the shards.
    #[clap(default_value = "/tmp/text-generation-server", long, env)]
    shard_uds_path: String,

    /// The address the master shard will listen on. (setting used by torch distributed)
    #[clap(default_value = "localhost", long, env)]
    master_addr: String,

    /// The port that the master shard will listen on. (setting used by torch distributed)
    #[clap(default_value = "29500", long, env)]
    master_port: usize,

    /// The location of the huggingface hub cache.
    /// Used to override the location if you want to provide a mounted disk for instance
    #[clap(long, env)]
    huggingface_hub_cache: Option<String>,

    /// The location of the model weights cache.
    /// Used to override the location if you want to provide a mounted disk for instance
    #[clap(long, env)]
    weights_cache_override: Option<String>,

    /// For some models (like bloom), text-generation-inference implemented custom
    /// cuda kernels to speed up inference. Those kernels were only tested on A100.
    /// Use this flag to disable them if you're running on different hardware and
    /// encounter issues.
    #[clap(long, env)]
    disable_custom_kernels: bool,

    /// Limit the CUDA available memory.
    /// The allowed value equals the total visible memory multiplied by cuda-memory-fraction.
    #[clap(default_value = "1.0", long, env)]
    cuda_memory_fraction: f32,

    /// Rope scaling will only be used for RoPE models
    /// and allow rescaling the position rotary to accommodate
    /// larger prompts.
    ///
    /// Goes together with `rope_factor`.
    ///
    /// `--rope-factor 2.0` gives linear scaling with a factor of 2.0
    /// `--rope-scaling dynamic` gives dynamic scaling with a factor of 1.0
    /// `--rope-scaling linear` gives linear scaling with a factor of 1.0 (Nothing will be changed
    /// basically)
    ///
    /// `--rope-scaling linear --rope-factor` fully describes the scaling you want
    #[clap(long, env)]
    rope_scaling: Option<RopeScaling>,

    /// Rope scaling will only be used for RoPE models
    /// See `rope_scaling`
    #[clap(long, env)]
    rope_factor: Option<f32>,

    /// Outputs the logs in JSON format (useful for telemetry)
    #[clap(long, env)]
    json_output: bool,

    #[clap(long, env)]
    otlp_endpoint: Option<String>,

    #[clap(default_value = "text-generation-inference.router", long, env)]
    otlp_service_name: String,

    #[clap(long, env)]
    cors_allow_origin: Vec<String>,

    #[clap(long, env)]
    api_key: Option<String>,

    #[clap(long, env)]
    watermark_gamma: Option<f32>,
    #[clap(long, env)]
    watermark_delta: Option<f32>,

    /// Enable ngrok tunneling
    #[clap(long, env)]
    ngrok: bool,

    /// ngrok authentication token
    #[clap(long, env)]
    ngrok_authtoken: Option<String>,

    /// ngrok edge
    #[clap(long, env)]
    ngrok_edge: Option<String>,

    /// The path to the tokenizer config file. This path is used to load the tokenizer configuration which may
    /// include a `chat_template`. If not provided, the default config will be used from the model hub.
    #[clap(long, env)]
    tokenizer_config_path: Option<String>,

    /// Disable outlines grammar constrained generation.
    /// This is a feature that allows you to generate text that follows a specific grammar.
    #[clap(long, env)]
    disable_grammar_support: bool,

    /// Display a lot of information about your runtime environment
    #[clap(long, short, action)]
    env: bool,

    /// Control the maximum number of inputs that a client can send in a single request
    #[clap(default_value = "4", long, env)]
    max_client_batch_size: usize,

    /// Lora Adapters: a list of adapter ids i.e. `repo/adapter1,repo/adapter2` to load during
    /// startup that will be available to callers via the `adapter_id` field in a request.
    #[clap(long, env)]
    lora_adapters: Option<String>,

    /// Control if anonymous usage stats are collected.
    /// Options are "on", "off" and "no-stack".
    /// Default is on.
    #[clap(default_value = "on", long, env)]
    usage_stats: UsageStatsLevel,
}

#[derive(Debug)]
enum ShardStatus {
    Ready,
    Failed(usize),
}
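
// Spawn and supervise a single `text-generation-server serve` process (one shard):
// report `Ready`/`Failed` on `status_sender` and terminate it when `shutdown` is set.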

#[allow(clippy::too_many_arguments)]
fn shard_manager(
    model_id: String,
    revision: Option<String>,
    quantize: Option<Quantization>,
    speculate: Option<usize>,
    dtype: Option<Dtype>,
    trust_remote_code: bool,
    uds_path: String,
    rank: usize,
    world_size: usize,
    master_addr: String,
    master_port: usize,
    huggingface_hub_cache: Option<String>,
    weights_cache_override: Option<String>,
    disable_custom_kernels: bool,
    watermark_gamma: Option<f32>,
    watermark_delta: Option<f32>,
    cuda_graphs: Vec<usize>,
    cuda_memory_fraction: f32,
    rope_scaling: Option<RopeScaling>,
    rope_factor: Option<f32>,
    max_total_tokens: usize,
    max_batch_size: Option<usize>,
    max_input_tokens: usize,
    lora_adapters: Option<String>,
    otlp_endpoint: Option<String>,
    otlp_service_name: String,
    log_level: LevelFilter,
    status_sender: mpsc::Sender<ShardStatus>,
    shutdown: Arc<AtomicBool>,
    _shutdown_sender: mpsc::Sender<()>,
) {
    // Enter shard-manager tracing span
    let _span = tracing::span!(tracing::Level::INFO, "shard-manager", rank = rank).entered();

    // Get UDS path
    let uds_string = format!("{uds_path}-{rank}");
    let uds = Path::new(&uds_string);
    // Clean previous runs
    if uds.exists() {
        fs::remove_file(uds).unwrap();
    }

    // Process args
    let mut shard_args = vec![
        "serve".to_string(),
        model_id,
        "--uds-path".to_string(),
        uds_path,
        "--logger-level".to_string(),
        log_level.to_string().to_uppercase(),
        "--json-output".to_string(),
    ];

    // Activate trust remote code
    if trust_remote_code {
        shard_args.push("--trust-remote-code".to_string());
    }

    // Activate tensor parallelism
    if world_size > 1 {
        shard_args.push("--sharded".to_string());
    }

    if let Some(quantize) = quantize {
        shard_args.push("--quantize".to_string());
        shard_args.push(quantize.to_string())
    }

    if let Some(speculate) = speculate {
        shard_args.push("--speculate".to_string());
        shard_args.push(speculate.to_string())
    }

    if let Some(dtype) = dtype {
        shard_args.push("--dtype".to_string());
        shard_args.push(dtype.to_string())
    }

    // Model optional revision
    if let Some(revision) = revision {
        shard_args.push("--revision".to_string());
        shard_args.push(revision)
    }
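
    // Combine the two rope flags: a factor without a scaling kind implies linear
    // scaling, and a scaling kind without a factor defaults to a factor of 1.0.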

    let rope = match (rope_scaling, rope_factor) {
        (None, None) => None,
        (Some(scaling), None) => Some((scaling, 1.0)),
        (Some(scaling), Some(factor)) => Some((scaling, factor)),
        (None, Some(factor)) => Some((RopeScaling::Linear, factor)),
    };

    // OpenTelemetry Endpoint
    if let Some(otlp_endpoint) = otlp_endpoint {
        shard_args.push("--otlp-endpoint".to_string());
        shard_args.push(otlp_endpoint);
    }

    // OpenTelemetry Service Name
    shard_args.push("--otlp-service-name".to_string());
    shard_args.push(otlp_service_name);

    // In case we use sliding window, we may ignore the sliding in flash for some backends depending on the parameter.
    shard_args.push("--max-input-tokens".to_string());
    shard_args.push(max_input_tokens.to_string());

    // Copy current process env
    let mut envs: Vec<(OsString, OsString)> = env::vars_os().collect();

    // Remove LOG_LEVEL if present
    envs.retain(|(name, _)| name != "LOG_LEVEL");

    // Torch Distributed Env vars
    envs.push(("RANK".into(), rank.to_string().into()));
    envs.push(("WORLD_SIZE".into(), world_size.to_string().into()));
    envs.push(("MASTER_ADDR".into(), master_addr.into()));
    envs.push(("MASTER_PORT".into(), master_port.to_string().into()));
    envs.push(("TORCH_NCCL_AVOID_RECORD_STREAMS".into(), "1".into()));

    // CUDA memory fraction
    envs.push((
        "CUDA_MEMORY_FRACTION".into(),
        cuda_memory_fraction.to_string().into(),
    ));

    // Safetensors load fast
    envs.push(("SAFETENSORS_FAST_GPU".into(), "1".into()));

    // Disable progress bar
    envs.push(("HF_HUB_DISABLE_PROGRESS_BARS".into(), "1".into()));

    // Enable hf transfer for insane download speeds
    let enable_hf_transfer = env::var("HF_HUB_ENABLE_HF_TRANSFER").unwrap_or("1".to_string());
    envs.push((
        "HF_HUB_ENABLE_HF_TRANSFER".into(),
        enable_hf_transfer.into(),
    ));

    // Parse Inference API token
    if let Ok(api_token) = env::var("HF_API_TOKEN") {
        envs.push(("HF_TOKEN".into(), api_token.into()))
    };

    // Detect rope scaling
    // Sending as env instead of CLI args to not bloat everything
    // these can only be used by RoPE models, so passing the information around
    // for all models would complicate the code unnecessarily
    if let Some((scaling, factor)) = rope {
        envs.push(("ROPE_SCALING".into(), scaling.to_string().into()));
        envs.push(("ROPE_FACTOR".into(), factor.to_string().into()));
    }

    envs.push((
        "MAX_TOTAL_TOKENS".into(),
        max_total_tokens.to_string().into(),
    ));
    if let Some(max_batch_size) = max_batch_size {
        envs.push(("MAX_BATCH_SIZE".into(), max_batch_size.to_string().into()));
    }

    // Lora Adapters
    if let Some(lora_adapters) = lora_adapters {
        envs.push(("LORA_ADAPTERS".into(), lora_adapters.into()));
    }

    // If huggingface_hub_cache is some, pass it to the shard
    // Useful when running inside a docker container
    if let Some(huggingface_hub_cache) = huggingface_hub_cache {
        envs.push(("HUGGINGFACE_HUB_CACHE".into(), huggingface_hub_cache.into()));
    };

    // If weights_cache_override is some, pass it to the shard
    // Useful when running inside a HuggingFace Inference Endpoint
    if let Some(weights_cache_override) = weights_cache_override {
        envs.push((
            "WEIGHTS_CACHE_OVERRIDE".into(),
            weights_cache_override.into(),
        ));
    };

    // Enable experimental support for cuda graphs
    if !cuda_graphs.is_empty() {
        envs.push((
            "CUDA_GRAPHS".into(),
            cuda_graphs
                .into_iter()
                .map(|c| c.to_string())
                .collect::<Vec<_>>()
                .join(",")
                .into(),
        ));
    }

    // If disable_custom_kernels is true, pass it to the shard as an env var
    if disable_custom_kernels {
        envs.push(("DISABLE_CUSTOM_KERNELS".into(), "True".into()))
    }

    // Watermark Gamma
    if let Some(watermark_gamma) = watermark_gamma {
        envs.push(("WATERMARK_GAMMA".into(), watermark_gamma.to_string().into()))
    }

    // Watermark Delta
    if let Some(watermark_delta) = watermark_delta {
        envs.push(("WATERMARK_DELTA".into(), watermark_delta.to_string().into()))
    }

    // Start process
    tracing::info!("Starting shard");
    let mut p = match Command::new("text-generation-server")
        .args(shard_args)
        .env_clear()
        .envs(envs)
        .stdout(Stdio::piped())
        .stderr(Stdio::piped())
        .process_group(0)
        .spawn()
    {
        Ok(p) => p,
        Err(err) => {
            if err.kind() == io::ErrorKind::NotFound {
                tracing::error!("text-generation-server not found in PATH");
                tracing::error!("Please install it with `make install-server`")
            }
            tracing::error!("{}", err);

            status_sender.send(ShardStatus::Failed(rank)).unwrap();
            return;
        }
    };

    // Redirect STDOUT to the console
    let shard_stdout_reader = BufReader::new(p.stdout.take().unwrap());
    let shard_stderr_reader = BufReader::new(p.stderr.take().unwrap());

    // stdout tracing thread
    thread::spawn(move || {
        log_lines(shard_stdout_reader.lines());
    });
    // We read stderr in another thread as it seems that lines() can block in some cases
    let (err_sender, err_receiver) = mpsc::channel();
    thread::spawn(move || {
        for line in shard_stderr_reader.lines().map_while(Result::ok) {
            err_sender.send(line).unwrap_or(());
        }
    });

    let mut ready = false;
    let start_time = Instant::now();
    let mut wait_time = Instant::now();
    loop {
        // Process exited
        if let Some(exit_status) = p.try_wait().unwrap() {
            let mut err = String::new();
            while let Ok(line) = err_receiver.recv_timeout(Duration::from_millis(10)) {
                err = err + "\n" + &line;
            }

            tracing::error!("Shard complete standard error output:\n{err}");

            if let Some(signal) = exit_status.signal() {
                tracing::error!("Shard process was signaled to shutdown with signal {signal}");
            }

            status_sender.send(ShardStatus::Failed(rank)).unwrap();
            return;
        }

        // We received a shutdown signal
        if shutdown.load(Ordering::SeqCst) {
            terminate("shard", p, Duration::from_secs(90)).unwrap();
            return;
        }

        // Shard is ready
        if uds.exists() && !ready {
            tracing::info!("Shard ready in {:?}", start_time.elapsed());
            status_sender.send(ShardStatus::Ready).unwrap();
            ready = true;
        } else if !ready && wait_time.elapsed() > Duration::from_secs(10) {
            tracing::info!("Waiting for shard to be ready...");
            wait_time = Instant::now();
        }
        sleep(Duration::from_millis(100));
    }
}

fn shutdown_shards(shutdown: Arc<AtomicBool>, shutdown_receiver: &mpsc::Receiver<()>) {
    tracing::info!("Shutting down shards");
    // Update shutdown value to true
    // This will be picked up by the shard manager
    shutdown.store(true, Ordering::SeqCst);

    // Wait for shards to shutdown
    // This will block till all shutdown_sender are dropped
    let _ = shutdown_receiver.recv();
}
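
// Count the GPUs visible to this process by parsing the comma-separated device list
// in CUDA_VISIBLE_DEVICES, NVIDIA_VISIBLE_DEVICES or ZE_AFFINITY_MASK (in that order).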

fn num_cuda_devices() -> Option<usize> {
    let devices = match env::var("CUDA_VISIBLE_DEVICES") {
        Ok(devices) => devices,
        Err(_) => match env::var("NVIDIA_VISIBLE_DEVICES") {
            Ok(devices) => devices,
            Err(_) => env::var("ZE_AFFINITY_MASK").ok()?,
        },
    };
    let n_devices = devices.split(',').count();
    Some(n_devices)
}

#[derive(Deserialize)]
#[serde(rename_all = "UPPERCASE")]
enum PythonLogLevelEnum {
    Trace,
    Debug,
    Info,
    Success,
    Warning,
    Error,
    Critical,
}

#[derive(Deserialize)]
struct PythonLogLevel {
    name: PythonLogLevelEnum,
}

#[derive(Deserialize)]
struct PythonLogRecord {
    level: PythonLogLevel,
}

#[derive(Deserialize)]
struct PythonLogMessage {
    text: String,
    record: PythonLogRecord,
}

impl PythonLogMessage {
    fn trace(&self) {
        match self.record.level.name {
            PythonLogLevelEnum::Trace => tracing::trace!("{}", self.text.trim_end()),
            PythonLogLevelEnum::Debug => tracing::debug!("{}", self.text.trim_end()),
            PythonLogLevelEnum::Info => tracing::info!("{}", self.text.trim_end()),
            PythonLogLevelEnum::Success => tracing::info!("{}", self.text.trim_end()),
            PythonLogLevelEnum::Warning => tracing::warn!("{}", self.text.trim_end()),
            PythonLogLevelEnum::Error => tracing::error!("{}", self.text.trim_end()),
            PythonLogLevelEnum::Critical => tracing::error!("{}", self.text.trim_end()),
        }
    }
}

impl TryFrom<&String> for PythonLogMessage {
    type Error = serde_json::Error;

    fn try_from(value: &String) -> Result<Self, Self::Error> {
        serde_json::from_str::<Self>(value)
    }
}
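
// Forward each JSON log record emitted by the Python server to `tracing` at the
// matching level; lines that are not valid JSON are logged verbatim at debug level.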

fn log_lines<S: Sized + BufRead>(lines: Lines<S>) {
    for line in lines.map_while(Result::ok) {
        match PythonLogMessage::try_from(&line) {
            Ok(log) => log.trace(),
            Err(_) => tracing::debug!("{line}"),
        }
    }
}

fn find_num_shards(
    sharded: Option<bool>,
    num_shard: Option<usize>,
) -> Result<usize, LauncherError> {
    // get the number of shards given `sharded` and `num_shard`
    let num_shard = match (sharded, num_shard) {
        (Some(true), None) => {
            // try to default to the number of available GPUs
            tracing::info!("Parsing num_shard from CUDA_VISIBLE_DEVICES/NVIDIA_VISIBLE_DEVICES/ZE_AFFINITY_MASK");
            let n_devices = num_cuda_devices()
                .expect("--num-shard and CUDA_VISIBLE_DEVICES/NVIDIA_VISIBLE_DEVICES/ZE_AFFINITY_MASK are not set");
            if n_devices <= 1 {
                return Err(LauncherError::NotEnoughCUDADevices(format!(
                    "`sharded` is true but only found {n_devices} CUDA devices"
                )));
            }
            n_devices
        }
        (Some(true), Some(num_shard)) => {
            // we can't have only one shard while sharded
            if num_shard <= 1 {
                return Err(LauncherError::ArgumentValidation(
                    "`sharded` is true but `num_shard` <= 1".to_string(),
                ));
            }
            num_shard
        }
        (Some(false), Some(num_shard)) => num_shard,
        (Some(false), None) => 1,
        (None, None) => num_cuda_devices().unwrap_or(1),
        (None, Some(num_shard)) => num_shard,
    };
    if num_shard < 1 {
        return Err(LauncherError::ArgumentValidation(
            "`num_shard` cannot be < 1".to_string(),
        ));
    }
    Ok(num_shard)
}

#[derive(Debug, Error)]
enum LauncherError {
    #[error("Invalid argument: {0}")]
    ArgumentValidation(String),
    #[error("not enough cuda devices: {0}")]
    NotEnoughCUDADevices(String),
    #[error("Download error")]
    DownloadError,
    #[error("Shard cannot start")]
    ShardCannotStart,
    #[error("Shard disconnected")]
    ShardDisconnected,
    #[error("Shard failed")]
    ShardFailed,
    #[error("Webserver failed")]
    WebserverFailed,
    #[error("Webserver cannot start")]
    WebserverCannotStart,
}
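
// Run `text-generation-server download-weights` to fetch the model weights (and convert
// them to `.safetensors`) before any shard is started.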

fn download_convert_model(
    model_id: &str,
    revision: Option<&str>,
    trust_remote_code: bool,
    huggingface_hub_cache: Option<&str>,
    weights_cache_override: Option<&str>,
    running: Arc<AtomicBool>,
) -> Result<(), LauncherError> {
    // Enter download tracing span
    let _span = tracing::span!(tracing::Level::INFO, "download").entered();

    let mut download_args = vec![
        "download-weights".to_string(),
        model_id.to_string(),
        "--extension".to_string(),
        ".safetensors".to_string(),
        "--logger-level".to_string(),
        "INFO".to_string(),
        "--json-output".to_string(),
    ];

    // Model optional revision
    if let Some(revision) = &revision {
        download_args.push("--revision".to_string());
        download_args.push(revision.to_string())
    }

    // Trust remote code for automatic peft fusion
    if trust_remote_code {
        download_args.push("--trust-remote-code".to_string());
    }

    // Copy current process env
    let mut envs: Vec<(OsString, OsString)> = env::vars_os().collect();

    // Remove LOG_LEVEL if present
    envs.retain(|(name, _)| name != "LOG_LEVEL");

    // Disable progress bar
    envs.push(("HF_HUB_DISABLE_PROGRESS_BARS".into(), "1".into()));

    // If huggingface_hub_cache is set, pass it to the download process
    // Useful when running inside a docker container
    if let Some(ref huggingface_hub_cache) = huggingface_hub_cache {
        envs.push(("HUGGINGFACE_HUB_CACHE".into(), huggingface_hub_cache.into()));
    };

    // Enable hf transfer for insane download speeds
    let enable_hf_transfer = env::var("HF_HUB_ENABLE_HF_TRANSFER").unwrap_or("1".to_string());
    envs.push((
        "HF_HUB_ENABLE_HF_TRANSFER".into(),
        enable_hf_transfer.into(),
    ));

    // Parse Inference API token
    if let Ok(api_token) = env::var("HF_API_TOKEN") {
        envs.push(("HF_TOKEN".into(), api_token.into()))
    };

    // If args.weights_cache_override is some, pass it to the download process
    // Useful when running inside a HuggingFace Inference Endpoint
    if let Some(weights_cache_override) = &weights_cache_override {
        envs.push((
            "WEIGHTS_CACHE_OVERRIDE".into(),
            weights_cache_override.into(),
        ));
    };

    // Start process
    tracing::info!("Starting check and download process for {model_id}");
    let mut download_process = match Command::new("text-generation-server")
        .args(download_args)
        .env_clear()
        .envs(envs)
        .stdout(Stdio::piped())
        .stderr(Stdio::piped())
        .process_group(0)
        .spawn()
    {
        Ok(p) => p,
        Err(err) => {
            if err.kind() == io::ErrorKind::NotFound {
                tracing::error!("text-generation-server not found in PATH");
                tracing::error!("Please install it with `make install-server`")
            } else {
                tracing::error!("{}", err);
            }

            return Err(LauncherError::DownloadError);
        }
    };

    let download_stdout = BufReader::new(download_process.stdout.take().unwrap());

    thread::spawn(move || {
        log_lines(download_stdout.lines());
    });

    let download_stderr = BufReader::new(download_process.stderr.take().unwrap());

    // We read stderr in another thread as it seems that lines() can block in some cases
    let (err_sender, err_receiver) = mpsc::channel();
    thread::spawn(move || {
        for line in download_stderr.lines().map_while(Result::ok) {
            err_sender.send(line).unwrap_or(());
        }
    });

    loop {
        if let Some(status) = download_process.try_wait().unwrap() {
            if status.success() {
                tracing::info!("Successfully downloaded weights for {model_id}");
                break;
            }

            let mut err = String::new();
            while let Ok(line) = err_receiver.recv_timeout(Duration::from_millis(10)) {
                err = err + "\n" + &line;
            }

            if let Some(signal) = status.signal() {
                tracing::error!(
                    "Download process was signaled to shutdown with signal {signal}: {err}"
                );
            } else {
                tracing::error!("Download encountered an error: {err}");
            }

            return Err(LauncherError::DownloadError);
        }
        if !running.load(Ordering::SeqCst) {
            terminate("download", download_process, Duration::from_secs(10)).unwrap();
            return Ok(());
        }
        sleep(Duration::from_millis(100));
    }
    Ok(())
}
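
// Start one `shard_manager` thread per rank and block until every shard reports
// `Ready`, bailing out early if a shard fails or its status channel disconnects.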

#[allow(clippy::too_many_arguments)]
fn spawn_shards(
    num_shard: usize,
    args: &Args,
    cuda_graphs: Vec<usize>,
    max_total_tokens: usize,
    max_input_tokens: usize,
    quantize: Option<Quantization>,
    max_log_level: LevelFilter,
    shutdown: Arc<AtomicBool>,
    shutdown_receiver: &mpsc::Receiver<()>,
    shutdown_sender: mpsc::Sender<()>,
    status_receiver: &mpsc::Receiver<ShardStatus>,
    status_sender: mpsc::Sender<ShardStatus>,
    running: Arc<AtomicBool>,
) -> Result<(), LauncherError> {
    // Start shard processes
    for rank in 0..num_shard {
        let model_id = args.model_id.clone();
        let revision = args.revision.clone();
        let uds_path = args.shard_uds_path.clone();
        let master_addr = args.master_addr.clone();
        let huggingface_hub_cache = args.huggingface_hub_cache.clone();
        let weights_cache_override = args.weights_cache_override.clone();
        let status_sender = status_sender.clone();
        let shutdown = shutdown.clone();
        let shutdown_sender = shutdown_sender.clone();
        let otlp_endpoint = args.otlp_endpoint.clone();
        let otlp_service_name = args.otlp_service_name.clone();
        let speculate = args.speculate;
        let dtype = args.dtype;
        let trust_remote_code = args.trust_remote_code;
        let master_port = args.master_port;
        let disable_custom_kernels = args.disable_custom_kernels;
        let watermark_gamma = args.watermark_gamma;
        let watermark_delta = args.watermark_delta;
        let cuda_graphs_clone = cuda_graphs.clone();
        let cuda_memory_fraction = args.cuda_memory_fraction;
        let rope_scaling = args.rope_scaling;
        let rope_factor = args.rope_factor;
        let max_batch_size = args.max_batch_size;
        let lora_adapters = args.lora_adapters.clone();
        thread::spawn(move || {
            shard_manager(
                model_id,
                revision,
                quantize,
                speculate,
                dtype,
                trust_remote_code,
                uds_path,
                rank,
                num_shard,
                master_addr,
                master_port,
                huggingface_hub_cache,
                weights_cache_override,
                disable_custom_kernels,
                watermark_gamma,
                watermark_delta,
                cuda_graphs_clone,
                cuda_memory_fraction,
                rope_scaling,
                rope_factor,
                max_total_tokens,
                max_batch_size,
                max_input_tokens,
                lora_adapters,
                otlp_endpoint,
                otlp_service_name,
                max_log_level,
                status_sender,
                shutdown,
                shutdown_sender,
            )
        });
    }
    drop(shutdown_sender);

    // Wait for shard to start
    let mut shard_ready = 0;
    while running.load(Ordering::SeqCst) {
        match status_receiver.try_recv() {
            Ok(ShardStatus::Ready) => {
                shard_ready += 1;
                if shard_ready == num_shard {
                    break;
                }
            }
            Err(TryRecvError::Empty) => {
                sleep(Duration::from_millis(100));
            }
            Ok(ShardStatus::Failed(rank)) => {
                tracing::error!("Shard {rank} failed to start");
                shutdown_shards(shutdown, shutdown_receiver);
                return Err(LauncherError::ShardCannotStart);
            }
            Err(TryRecvError::Disconnected) => {
                tracing::error!("Shard status channel disconnected");
                shutdown_shards(shutdown, shutdown_receiver);
                return Err(LauncherError::ShardDisconnected);
            }
        }
    }
    Ok(())
}
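
// Build the COMPUTE_TYPE string ("<num_shard>-<gpu-name>") from the first GPU reported
// by `nvidia-smi`, e.g. a hypothetical "2-nvidia-a100-sxm4-80gb" on a two-shard A100 node.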

fn compute_type(num_shard: usize) -> Option<String> {
    let output = Command::new("nvidia-smi")
        .args(["--query-gpu=gpu_name", "--format=csv"])
        .output()
        .ok()?;
    let output = String::from_utf8(output.stdout).ok()?;
    let fullname = output.split('\n').nth(1)?;
    let cardname = fullname.replace(' ', "-").to_lowercase();
    let compute_type = format!("{num_shard}-{cardname}");
    Some(compute_type)
}
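
// Launch the `text-generation-router` webserver, forwarding the relevant CLI flags and
// pointing it at the master shard's unix socket.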

fn spawn_webserver(
    num_shard: usize,
    args: Args,
    max_input_tokens: usize,
    max_total_tokens: usize,
    max_batch_prefill_tokens: u32,
    shutdown: Arc<AtomicBool>,
    shutdown_receiver: &mpsc::Receiver<()>,
) -> Result<Child, LauncherError> {
    // All shard started
    // Start webserver
    tracing::info!("Starting Webserver");
    let mut router_args = vec![
        "--max-client-batch-size".to_string(),
        args.max_client_batch_size.to_string(),
        "--max-concurrent-requests".to_string(),
        args.max_concurrent_requests.to_string(),
        "--max-best-of".to_string(),
        args.max_best_of.to_string(),
        "--max-stop-sequences".to_string(),
        args.max_stop_sequences.to_string(),
        "--max-top-n-tokens".to_string(),
        args.max_top_n_tokens.to_string(),
        "--max-input-tokens".to_string(),
        max_input_tokens.to_string(),
        "--max-total-tokens".to_string(),
        max_total_tokens.to_string(),
        "--max-batch-prefill-tokens".to_string(),
        max_batch_prefill_tokens.to_string(),
        "--waiting-served-ratio".to_string(),
        args.waiting_served_ratio.to_string(),
        "--max-waiting-tokens".to_string(),
        args.max_waiting_tokens.to_string(),
        "--validation-workers".to_string(),
        args.validation_workers.to_string(),
        "--hostname".to_string(),
        args.hostname.to_string(),
        "--port".to_string(),
        args.port.to_string(),
        "--master-shard-uds-path".to_string(),
        format!("{}-0", args.shard_uds_path),
        "--tokenizer-name".to_string(),
        args.model_id,
    ];

    // Pass usage stats flags to router
    router_args.push("--usage-stats".to_string());
    router_args.push(args.usage_stats.to_string());

    // Grammar support
    if args.disable_grammar_support {
        router_args.push("--disable-grammar-support".to_string());
    }

    // Tokenizer config path
    if let Some(ref tokenizer_config_path) = args.tokenizer_config_path {
        router_args.push("--tokenizer-config-path".to_string());
        router_args.push(tokenizer_config_path.to_string());
    }

    // Model optional max batch total tokens
    if let Some(max_batch_total_tokens) = args.max_batch_total_tokens {
        router_args.push("--max-batch-total-tokens".to_string());
        router_args.push(max_batch_total_tokens.to_string());
    }

    // Router optional max batch size
    if let Some(max_batch_size) = args.max_batch_size {
        router_args.push("--max-batch-size".to_string());
        router_args.push(max_batch_size.to_string());
    }

    // Model optional revision
    if let Some(ref revision) = args.revision {
        router_args.push("--revision".to_string());
        router_args.push(revision.to_string())
    }

    if args.json_output {
        router_args.push("--json-output".to_string());
    }

    // OpenTelemetry
    if let Some(otlp_endpoint) = args.otlp_endpoint {
        router_args.push("--otlp-endpoint".to_string());
        router_args.push(otlp_endpoint);
    }

    // OpenTelemetry
    let otlp_service_name = args.otlp_service_name;
    router_args.push("--otlp-service-name".to_string());
    router_args.push(otlp_service_name);

    // CORS origins
    for origin in args.cors_allow_origin.into_iter() {
        router_args.push("--cors-allow-origin".to_string());
        router_args.push(origin);
    }

    // API Key
    if let Some(api_key) = args.api_key {
        router_args.push("--api-key".to_string());
        router_args.push(api_key);
    }
    // Ngrok
    if args.ngrok {
        router_args.push("--ngrok".to_string());
        router_args.push("--ngrok-authtoken".to_string());
        router_args.push(args.ngrok_authtoken.unwrap());
        router_args.push("--ngrok-edge".to_string());
        router_args.push(args.ngrok_edge.unwrap());
    }

    // Copy current process env
    let mut envs: Vec<(OsString, OsString)> = env::vars_os().collect();

    // Parse Inference API token
    if let Ok(api_token) = env::var("HF_API_TOKEN") {
        envs.push(("HF_TOKEN".into(), api_token.into()))
    };

    // Parse Compute type
    if let Ok(compute_type) = env::var("COMPUTE_TYPE") {
        envs.push(("COMPUTE_TYPE".into(), compute_type.into()))
    } else if let Some(compute_type) = compute_type(num_shard) {
        envs.push(("COMPUTE_TYPE".into(), compute_type.into()))
    }

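    // Spawn the router in its own process group with piped stdout/stderr so the
    // launcher can forward its logs and terminate it independently of the shards.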
    let mut webserver = match Command::new("text-generation-router")
        .args(router_args)
        .envs(envs)
        .stdout(Stdio::piped())
        .stderr(Stdio::piped())
        .process_group(0)
        .spawn()
    {
        Ok(p) => p,
        Err(err) => {
            tracing::error!("Failed to start webserver: {}", err);
            if err.kind() == io::ErrorKind::NotFound {
                tracing::error!("text-generation-router not found in PATH");
                tracing::error!("Please install it with `make install-router`")
            } else {
                tracing::error!("{}", err);
            }

            shutdown_shards(shutdown, shutdown_receiver);
            return Err(LauncherError::WebserverCannotStart);
        }
    };

    // Redirect STDOUT and STDERR to the console
    let webserver_stdout = webserver.stdout.take().unwrap();
    let webserver_stderr = webserver.stderr.take().unwrap();

    thread::spawn(move || {
        let stdout = BufReader::new(webserver_stdout);
        let stderr = BufReader::new(webserver_stderr);
        for line in stdout.lines() {
            println!("{}", line.unwrap());
        }
        for line in stderr.lines() {
            println!("{}", line.unwrap());
        }
    });
    Ok(webserver)
}

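/// Ask `process` to exit with SIGTERM and wait up to `timeout` for it to do so,
/// polling every 100ms; if it is still alive afterwards, kill it and reap the exit status.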
fn terminate(process_name: &str, mut process: Child, timeout: Duration) -> io::Result<ExitStatus> {
    tracing::info!("Terminating {process_name}");

    let terminate_time = Instant::now();
    signal::kill(Pid::from_raw(process.id() as i32), Signal::SIGTERM).unwrap();

    tracing::info!("Waiting for {process_name} to gracefully shutdown");
    while terminate_time.elapsed() < timeout {
        if let Some(status) = process.try_wait()? {
            tracing::info!("{process_name} terminated");
            return Ok(status);
        }
        sleep(Duration::from_millis(100));
    }
    tracing::info!("Killing {process_name}");

    process.kill()?;
    let exit_status = process.wait()?;

    tracing::info!("{process_name} killed");
    Ok(exit_status)
}

fn main() -> Result<(), LauncherError> {
    // Pattern match configuration
    let args: Args = Args::parse();

    // Filter events with LOG_LEVEL
    let varname = "LOG_LEVEL";
    let env_filter = if let Ok(log_level) = std::env::var(varname) {
        // Override to avoid simple logs being spammed with tokio-level information
        let log_level = match &log_level[..] {
            "warn" => "text_generation_launcher=warn,text_generation_router=warn",
            "info" => "text_generation_launcher=info,text_generation_router=info",
            "debug" => "text_generation_launcher=debug,text_generation_router=debug",
            log_level => log_level,
        };
        EnvFilter::builder()
            .with_default_directive(LevelFilter::INFO.into())
            .parse_lossy(log_level)
    } else {
        EnvFilter::new("info")
    };
    let max_log_level = env_filter.max_level_hint().unwrap_or(LevelFilter::INFO);

    if args.json_output {
        tracing_subscriber::fmt()
            .with_env_filter(env_filter)
            .json()
            .init();
    } else {
        tracing_subscriber::fmt()
            .with_env_filter(env_filter)
            .compact()
            .init();
    }

    if args.env {
        let env_runtime = env_runtime::Env::new();
        tracing::info!("{}", env_runtime);
    }

    tracing::info!("{:#?}", args);

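    // Inspect the model's config.json (local path or Hugging Face Hub download) to
    // derive a default context length and detect an already-quantized checkpoint.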
    let get_max_positions_quantize =
        || -> Result<(usize, Option<Quantization>), Box<dyn std::error::Error>> {
            let model_id = args.model_id.clone();
            let mut path = std::path::Path::new(&args.model_id).to_path_buf();
            let filename = if !path.exists() {
                // Assume it's a hub id

                let api = if let Ok(token) = std::env::var("HF_TOKEN") {
                    // The env variable takes precedence over the on-disk token.
                    ApiBuilder::new().with_token(Some(token)).build()?
                } else {
                    Api::new()?
                };
                let repo = if let Some(ref revision) = args.revision {
                    api.repo(Repo::with_revision(
                        model_id,
                        RepoType::Model,
                        revision.to_string(),
                    ))
                } else {
                    api.model(model_id)
                };
                repo.get("config.json")?
            } else {
                path.push("config.json");
                path
            };

            let content = std::fs::read_to_string(filename)?;
            let config: RawConfig = serde_json::from_str(&content)?;

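            // gemma2 uses attention logit softcapping, which is why the flashdecoding
            // attention backend is forced below.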
            if config.model_type == Some("gemma2".to_string()) {
                tracing::info!("Forcing flash decoding because of softcap usage");
                std::env::set_var("ATTENTION", "flashdecoding");
            }
            let config: Config = config.into();
            let quantize = config.quantize;

            // Quantization usually means you're even more RAM constrained.
            let max_default = 4096;

            if let Some(max_position_embeddings) = config.max_position_embeddings {
                if max_position_embeddings > max_default {
                    let max = max_position_embeddings;
                    if args.max_input_tokens.is_none()
                        && args.max_total_tokens.is_none()
                        && args.max_batch_prefill_tokens.is_none()
                    {
                        tracing::info!("Model supports up to {max} but tgi will now set its default to {max_default} instead. This is to save VRAM by refusing large prompts in order to allow more users on the same hardware. You can increase that size using `--max-batch-prefill-tokens={} --max-total-tokens={max} --max-input-tokens={}`.", max + 50, max - 1);
                    }
                    Ok((max_default, quantize))
                } else {
                    Ok((max_position_embeddings, quantize))
                }
            } else {
                Err(Box::new(LauncherError::ArgumentValidation(
                    "no max defined".to_string(),
                )))
            }
        };
    let (max_position_embeddings, quantize): (usize, Option<Quantization>) =
        get_max_positions_quantize().unwrap_or((4096, None));

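    // Resolve the token budget flags, defaulting anything the user did not set from
    // the model's maximum position embeddings resolved above.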
    let max_input_tokens = {
        match (args.max_input_tokens, args.max_input_length) {
            (Some(max_input_tokens), Some(max_input_length)) => {
                return Err(LauncherError::ArgumentValidation(
                    format!("Both `max_input_tokens` ({max_input_tokens}) and `max_input_length` ({max_input_length}) are set. Please define only `max_input_tokens` as `max_input_length is deprecated for naming consistency.",
                )));
            }
            (Some(max_input_tokens), None) | (None, Some(max_input_tokens)) => max_input_tokens,
            (None, None) => {
                let value = max_position_embeddings - 1;
                tracing::info!("Default `max_input_tokens` to {value}");
                value
            }
        }
    };
    let max_total_tokens = {
        match args.max_total_tokens {
            Some(max_total_tokens) => max_total_tokens,
            None => {
                let value = max_position_embeddings;
                tracing::info!("Default `max_total_tokens` to {value}");
                value
            }
        }
    };
    let max_batch_prefill_tokens = {
        match args.max_batch_prefill_tokens {
            Some(max_batch_prefill_tokens) => max_batch_prefill_tokens,
            None => {
                let value: u32 = if let Some(max_batch_size) = args.max_batch_size {
                    max_batch_size * max_input_tokens
                } else {
                    // Add some headroom to account for potential block_size alignment
                    // issues.
                    max_input_tokens + 50
                } as u32;
                tracing::info!("Default `max_batch_prefill_tokens` to {value}");
                value
            }
        }
    };

    // Validate args
    if max_input_tokens >= max_total_tokens {
        return Err(LauncherError::ArgumentValidation(
            "`max_input_tokens` must be < `max_total_tokens`".to_string(),
        ));
    }
    if max_input_tokens as u32 > max_batch_prefill_tokens {
        return Err(LauncherError::ArgumentValidation(format!(
            "`max_batch_prefill_tokens` must be >= `max_input_tokens`. Given: {} and {}",
            max_batch_prefill_tokens, max_input_tokens
        )));
    }

    if matches!(args.quantize, Some(Quantization::Bitsandbytes)) {
        tracing::warn!("Bitsandbytes is deprecated, use `eetq` instead, which provides better latencies overall and is drop-in in most cases.");
    }
    let quantize = args.quantize.or(quantize);
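    // CUDA graphs: keep user-provided batch sizes (dropping zeros), disable graphs for
    // quantization schemes that do not support them, otherwise fall back to defaults.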
    let cuda_graphs = match (&args.cuda_graphs, &quantize) {
        (Some(cuda_graphs), _) => cuda_graphs.iter().cloned().filter(|&c| c > 0).collect(),
        #[allow(deprecated)]
        (
            None,
            Some(
                Quantization::Bitsandbytes
                | Quantization::BitsandbytesNf4
                | Quantization::BitsandbytesFp4,
            ),
        ) => {
            tracing::warn!("Bitsandbytes doesn't work with cuda graphs, deactivating them");
            vec![]
        }
        (None, Some(Quantization::Exl2)) => {
            tracing::warn!("Exl2 doesn't work with cuda graphs, deactivating them");
            vec![]
        }
        _ => {
            let cuda_graphs = vec![1, 2, 4, 8, 16, 32];
            tracing::info!("Using default cuda graphs {cuda_graphs:?}");
            cuda_graphs
        }
    };

    if args.validation_workers == 0 {
        return Err(LauncherError::ArgumentValidation(
            "`validation_workers` must be > 0".to_string(),
        ));
    }
    if args.trust_remote_code {
        tracing::warn!(
            "`trust_remote_code` is set. Trusting that model `{}` do not contain malicious code.",
            args.model_id
        );
    }

    let num_shard = find_num_shards(args.sharded, args.num_shard)?;
    if num_shard > 1 {
        if matches!(args.quantize, Some(Quantization::Exl2)) {
            return Err(LauncherError::ArgumentValidation(
                "Sharding is currently not supported with `exl2` quantization".into(),
            ));
        }
        tracing::info!("Sharding model on {num_shard} processes");
    }

    if let Some(ref max_batch_total_tokens) = args.max_batch_total_tokens {
        if max_batch_prefill_tokens > *max_batch_total_tokens {
            return Err(LauncherError::ArgumentValidation(format!(
                "`max_batch_prefill_tokens` must be <= `max_batch_total_tokens`. Given: {} and {}",
                max_batch_prefill_tokens, max_batch_total_tokens
            )));
        }
        if max_total_tokens as u32 > *max_batch_total_tokens {
            return Err(LauncherError::ArgumentValidation(format!(
                "`max_total_tokens` must be <= `max_batch_total_tokens`. Given: {} and {}",
                max_total_tokens, max_batch_total_tokens
            )));
        }
    }

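    // Ngrok tunneling needs both an auth token and an edge to be configured.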
    if args.ngrok {
        if args.ngrok_authtoken.is_none() {
            return Err(LauncherError::ArgumentValidation(
                "`ngrok-authtoken` must be set when using ngrok tunneling".to_string(),
            ));
        }

        if args.ngrok_edge.is_none() {
            return Err(LauncherError::ArgumentValidation(
                "`ngrok-edge` must be set when using ngrok tunneling".to_string(),
            ));
        }
    }

    // Signal handler
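    // Ctrl-C flips the shared `running` flag; the download steps and the supervision
    // loop below poll it so the launcher can shut down gracefully.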
    let running = Arc::new(AtomicBool::new(true));
    let r = running.clone();
    ctrlc::set_handler(move || {
        r.store(false, Ordering::SeqCst);
    })
    .expect("Error setting Ctrl-C handler");

    // Download and convert model weights
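    // Weights are fetched (and converted if necessary) before any shard starts, so the
    // shards can load them from the local cache.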
    download_convert_model(
        &args.model_id,
        args.revision.as_deref(),
        args.trust_remote_code,
        args.huggingface_hub_cache.as_deref(),
        args.weights_cache_override.as_deref(),
        running.clone(),
    )?;

    // Download and convert lora adapters if any
    if let Some(lora_adapters) = &args.lora_adapters {
        for adapter in lora_adapters.split(',') {
            // skip download if a path is provided
            if adapter.contains('=') {
                continue;
            }
            download_convert_model(
                adapter,
                None,
                args.trust_remote_code,
                args.huggingface_hub_cache.as_deref(),
                args.weights_cache_override.as_deref(),
                running.clone(),
            )?;
        }
    }

    if !running.load(Ordering::SeqCst) {
        // Launcher was asked to stop
        return Ok(());
    }

    // Shared shutdown bool
    let shutdown = Arc::new(AtomicBool::new(false));
    // Shared shutdown channel
    // When shutting down, the main thread will wait for all senders to be dropped
    let (shutdown_sender, shutdown_receiver) = mpsc::channel();

    // Shared channel to track shard status
    let (status_sender, status_receiver) = mpsc::channel();

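    // Spawn one shard process per requested shard and wait for each of them to report
    // readiness (or failure) through the status channel.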
    spawn_shards(
        num_shard,
        &args,
        cuda_graphs,
        max_total_tokens,
        max_input_tokens,
        quantize,
        max_log_level,
        shutdown.clone(),
        &shutdown_receiver,
        shutdown_sender,
        &status_receiver,
        status_sender,
        running.clone(),
    )?;

    // We might have received a termination signal
    if !running.load(Ordering::SeqCst) {
        shutdown_shards(shutdown, &shutdown_receiver);
        return Ok(());
    }

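    // All shards are ready: start the HTTP router. If it cannot start, shut the shards
    // down before surfacing the error.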
    let mut webserver = spawn_webserver(
        num_shard,
        args,
        max_input_tokens,
        max_total_tokens,
        max_batch_prefill_tokens,
        shutdown.clone(),
        &shutdown_receiver,
    )
    .map_err(|err| {
        shutdown_shards(shutdown.clone(), &shutdown_receiver);
        err
    })?;

    // Default exit code
    let mut exit_code = Ok(());

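    // Supervision loop: exit if a shard crashes, the webserver exits, or a termination
    // signal clears `running`.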
    while running.load(Ordering::SeqCst) {
        if let Ok(ShardStatus::Failed(rank)) = status_receiver.try_recv() {
            tracing::error!("Shard {rank} crashed");
            exit_code = Err(LauncherError::ShardFailed);
            break;
        };

        match webserver.try_wait().unwrap() {
            Some(_) => {
                tracing::error!("Webserver Crashed");
                shutdown_shards(shutdown, &shutdown_receiver);
                return Err(LauncherError::WebserverFailed);
            }
            None => {
                sleep(Duration::from_millis(100));
            }
        };
    }

    // Graceful termination
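    // Give the webserver up to 90 seconds to exit after SIGTERM before force-killing it,
    // then signal the shards to shut down.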
    terminate("webserver", webserver, Duration::from_secs(90)).unwrap();
    shutdown_shards(shutdown, &shutdown_receiver);

    exit_code
}