use clap::{Parser, ValueEnum};
use hf_hub::{
    api::sync::{Api, ApiBuilder},
    Repo, RepoType,
};
use nix::sys::signal::{self, Signal};
use nix::unistd::Pid;
use serde::Deserialize;
use std::env;
use std::ffi::OsString;
use std::io::{BufRead, BufReader, Lines};
use std::os::unix::process::{CommandExt, ExitStatusExt};
use std::path::Path;
use std::process::{Child, Command, ExitStatus, Stdio};
use std::sync::atomic::{AtomicBool, Ordering};
use std::sync::mpsc::TryRecvError;
use std::sync::{mpsc, Arc};
use std::thread;
use std::thread::sleep;
use std::time::{Duration, Instant};
use std::{fs, io};
use thiserror::Error;
use tracing_subscriber::{filter::LevelFilter, EnvFilter};

mod env_runtime;

#[derive(Deserialize)]
struct RawConfig {
    max_position_embeddings: Option<usize>,
    n_positions: Option<usize>,
    model_type: Option<String>,
    max_seq_len: Option<usize>,
}

#[derive(Deserialize)]
struct Config {
    max_position_embeddings: Option<usize>,
}

impl From<RawConfig> for Config {
    fn from(other: RawConfig) -> Self {
        let max_position_embeddings = other
            .max_position_embeddings
            .or(other.max_seq_len)
            .or(other.n_positions);
        Config {
            max_position_embeddings,
        }
    }
}
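// A small sanity check added here as an illustrative example (not part of the original
// launcher): it only exercises the fallback order implemented in `From<RawConfig>` above,
// i.e. `max_position_embeddings`, then `max_seq_len`, then `n_positions`.
#[cfg(test)]
mod config_fallback_tests {
    use super::*;

    #[test]
    fn max_seq_len_wins_over_n_positions() {
        let raw = RawConfig {
            max_position_embeddings: None,
            n_positions: Some(2048),
            model_type: None,
            max_seq_len: Some(4096),
        };
        assert_eq!(Config::from(raw).max_position_embeddings, Some(4096));
    }

    #[test]
    fn max_position_embeddings_wins_when_present() {
        let raw = RawConfig {
            max_position_embeddings: Some(8192),
            n_positions: Some(2048),
            model_type: None,
            max_seq_len: None,
        };
        assert_eq!(Config::from(raw).max_position_embeddings, Some(8192));
    }
}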

#[derive(Clone, Copy, Debug, ValueEnum)]
enum Quantization {
    /// 4 bit quantization. Requires a specific AWQ quantized model:
    ///   <https://hf.co/models?search=awq>.
    /// Should replace GPTQ models wherever possible because of the better latency
    Awq,
    /// 8 bit quantization, doesn't require a specific model.
    /// Should be a drop-in replacement for bitsandbytes with much better performance.
    /// Kernels are from <https://github.com/NetEase-FuXi/EETQ.git>
    Eetq,
    /// Variable bit quantization. Requires a specific EXL2 quantized model:
    /// <https://hf.co/models?search=exl2>. Requires exllama2 kernels and does
    /// not support tensor parallelism (num_shard > 1).
    Exl2,
    /// 4 bit quantization. Requires a specific GPTQ quantized model: <https://hf.co/models?search=gptq>.
    /// text-generation-inference will use exllama (faster) kernels wherever possible, and use
    /// the triton kernel (wider support) when it's not.
    /// AWQ has faster kernels.
    Gptq,
    /// 4 bit quantization. Requires a specific Marlin quantized model: <https://hf.co/models?search=marlin>.
    Marlin,
    /// Bitsandbytes 8bit. Can be applied on any model, will cut the memory requirement in half,
    /// but it is known that the model will be much slower to run than the native f16.
    #[deprecated(
        since = "1.1.0",
        note = "Use `eetq` instead, which provides better latencies overall and is drop-in in most cases"
    )]
    Bitsandbytes,
    /// Bitsandbytes 4bit. Can be applied on any model, will cut the memory requirement by 4x,
    /// but it is known that the model will be much slower to run than the native f16.
    BitsandbytesNF4,
    /// Bitsandbytes 4bit. nf4 should be preferred in most cases but maybe this one has better
    /// perplexity performance for your model
    BitsandbytesFP4,
    /// [FP8](https://developer.nvidia.com/blog/nvidia-arm-and-intel-publish-fp8-specification-for-standardization-as-an-interchange-format-for-ai/) (e4m3) works on H100 and above
    /// This dtype has native ops and should be the fastest if available.
    /// This is currently not the fastest because of local unpacking + padding to satisfy matrix
    /// multiplication limitations.
    Fp8,
}

impl std::fmt::Display for Quantization {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        // Keep in sync with `server`.
        match self {
            #[allow(deprecated)]
            // Use `eetq` instead, which provides better latencies overall and is drop-in in most cases
            Quantization::Bitsandbytes => {
                write!(f, "bitsandbytes")
            }
            Quantization::BitsandbytesNF4 => {
                write!(f, "bitsandbytes-nf4")
            }
            Quantization::BitsandbytesFP4 => {
                write!(f, "bitsandbytes-fp4")
            }
            Quantization::Exl2 => {
                write!(f, "exl2")
            }
            Quantization::Gptq => {
                write!(f, "gptq")
            }
            Quantization::Marlin => {
                write!(f, "marlin")
            }
            Quantization::Awq => {
                write!(f, "awq")
            }
            Quantization::Eetq => {
                write!(f, "eetq")
            }
            Quantization::Fp8 => {
                write!(f, "fp8")
            }
        }
    }
}

#[derive(Clone, Copy, Debug, ValueEnum)]
enum Dtype {
    Float16,
    #[clap(name = "bfloat16")]
    BFloat16,
}

impl std::fmt::Display for Dtype {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        // Keep in sync with `server`.
        match self {
            Dtype::Float16 => {
                write!(f, "float16")
            }
            Dtype::BFloat16 => {
                write!(f, "bfloat16")
            }
        }
    }
}

#[derive(Clone, Copy, Debug, ValueEnum)]
enum RopeScaling {
    Linear,
    Dynamic,
}

impl std::fmt::Display for RopeScaling {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        // Keep in sync with `server`.
        match self {
            RopeScaling::Linear => {
                write!(f, "linear")
            }
            RopeScaling::Dynamic => {
                write!(f, "dynamic")
            }
        }
    }
}

#[derive(Clone, Copy, Debug, ValueEnum)]
pub enum UsageStatsLevel {
    /// Default option, usage statistics are collected anonymously
    On,
    /// Disables all collection of usage statistics
    Off,
    /// Doesn't send the error stack trace or error type, but allows sending a crash event
    NoStack,
}

impl std::fmt::Display for UsageStatsLevel {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        // Keep in sync with `server`.
        match self {
            UsageStatsLevel::On => {
                write!(f, "on")
            }
            UsageStatsLevel::Off => {
                write!(f, "off")
            }
            UsageStatsLevel::NoStack => {
                write!(f, "no-stack")
            }
        }
    }
}
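// Illustrative check (not in the original file): the strings produced by the `Display`
// impls above are the values the launcher later forwards to the shard/router as CLI
// arguments or env vars, so a few of them are pinned here as an example.
#[cfg(test)]
mod cli_value_tests {
    use super::*;

    #[test]
    fn enums_render_to_expected_cli_strings() {
        assert_eq!(Quantization::BitsandbytesNF4.to_string(), "bitsandbytes-nf4");
        assert_eq!(Quantization::Eetq.to_string(), "eetq");
        assert_eq!(Dtype::BFloat16.to_string(), "bfloat16");
        assert_eq!(RopeScaling::Dynamic.to_string(), "dynamic");
        assert_eq!(UsageStatsLevel::NoStack.to_string(), "no-stack");
    }
}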

/// App Configuration
#[derive(Parser, Debug)]
#[clap(author, version, about, long_about = None)]
struct Args {
    /// The name of the model to load.
    /// Can be a MODEL_ID as listed on <https://hf.co/models> like
    /// `gpt2` or `OpenAssistant/oasst-sft-1-pythia-12b`.
    /// Or it can be a local directory containing the necessary files
    /// as saved by `save_pretrained(...)` methods of transformers
    #[clap(default_value = "bigscience/bloom-560m", long, env)]
    model_id: String,

    /// The actual revision of the model if you're referring to a model
    /// on the hub. You can use a specific commit id or a branch like `refs/pr/2`.
    #[clap(long, env)]
    revision: Option<String>,

    /// The number of tokenizer workers used for payload validation and truncation inside the
    /// router.
    #[clap(default_value = "2", long, env)]
    validation_workers: usize,

    /// Whether to shard the model across multiple GPUs
    /// By default text-generation-inference will use all available GPUs to run
    /// the model. Setting it to `false` deactivates `num_shard`.
    #[clap(long, env)]
    sharded: Option<bool>,

    /// The number of shards to use if you don't want to use all GPUs on a given machine.
    /// You can use `CUDA_VISIBLE_DEVICES=0,1 text-generation-launcher... --num_shard 2`
    /// and `CUDA_VISIBLE_DEVICES=2,3 text-generation-launcher... --num_shard 2` to
    /// launch 2 copies with 2 shards each on a given machine with 4 GPUs for instance.
    #[clap(long, env)]
    num_shard: Option<usize>,

    /// Whether you want the model to be quantized.
    #[clap(long, env, value_enum)]
    quantize: Option<Quantization>,

    /// The number of input_ids to speculate on
    /// If using a medusa model, the heads will be picked up automatically
    /// Otherwise, it will use n-gram speculation which is relatively free
    /// in terms of compute, but the speedup heavily depends on the task.
    #[clap(long, env)]
    speculate: Option<usize>,

    /// The dtype to be forced upon the model. This option cannot be used with `--quantize`.
    #[clap(long, env, value_enum)]
    dtype: Option<Dtype>,

    /// Whether you want to execute hub modelling code. Explicitly passing a `revision` is
    /// encouraged when loading a model with custom code to ensure no malicious code has been
    /// contributed in a newer revision.
    #[clap(long, env, value_enum)]
    trust_remote_code: bool,

    /// The maximum amount of concurrent requests for this particular deployment.
    /// Having a low limit will refuse client requests instead of having them
    /// wait for too long and is usually good to handle backpressure correctly.
    #[clap(default_value = "128", long, env)]
    max_concurrent_requests: usize,

    /// This is the maximum allowed value for clients to set `best_of`.
    /// Best of makes `n` generations at the same time, and returns the best
    /// in terms of overall log probability over the entire generated sequence
    #[clap(default_value = "2", long, env)]
    max_best_of: usize,

    /// This is the maximum allowed value for clients to set `stop_sequences`.
    /// Stop sequences are used to allow the model to stop on more than just
    /// the EOS token, and enable more complex "prompting" where users can preprompt
    /// the model in a specific way and define their "own" stop token aligned with
    /// their prompt.
    #[clap(default_value = "4", long, env)]
    max_stop_sequences: usize,

    /// This is the maximum allowed value for clients to set `top_n_tokens`.
    /// `top_n_tokens` is used to return information about the `n` most likely
    /// tokens at each generation step, instead of just the sampled token. This
    /// information can be used for downstream tasks such as classification or
    /// ranking.
    #[clap(default_value = "5", long, env)]
    max_top_n_tokens: u32,

    /// This is the maximum allowed input length (expressed in number of tokens)
    /// for users. The larger this value, the longer prompt users can send which
    /// can impact the overall memory required to handle the load.
    /// Please note that some models have a finite range of sequence they can handle.
    /// Defaults to min(max_position_embeddings - 1, 4095)
    #[clap(long, env)]
    max_input_tokens: Option<usize>,

    /// Legacy version of [`Args::max_input_tokens`].
    #[clap(long, env)]
    max_input_length: Option<usize>,

    /// This is the most important value to set as it defines the "memory budget"
    /// of running client requests.
    /// Clients will send input sequences and ask to generate `max_new_tokens`
    /// on top. With a value of `1512`, users can send either a prompt of
    /// `1000` and ask for `512` new tokens, or send a prompt of `1` and ask for
    /// `1511` max_new_tokens.
    /// The larger this value, the more RAM each request will use
    /// and the less effective batching can be.
    /// Defaults to min(max_position_embeddings, 4096)
    #[clap(long, env)]
    max_total_tokens: Option<usize>,

    /// This represents the ratio of waiting queries vs running queries where
    /// you want to start considering pausing the running queries to include the waiting
    /// ones into the same batch.
    /// `waiting_served_ratio=1.2` means that when 12 queries are waiting and there are
    /// only 10 queries left in the current batch, we check if we can fit those 12
    /// waiting queries into the batching strategy, and if yes, then batching happens,
    /// delaying the 10 running queries by a `prefill` run.
    ///
    /// This setting is only applied if there is room in the batch
    /// as defined by `max_batch_total_tokens`.
    #[clap(default_value = "0.3", long, env)]
    waiting_served_ratio: f32,

    /// Limits the number of tokens for the prefill operation.
    /// Since this operation takes the most memory and is compute bound, it is useful
    /// to limit the number of requests that can be sent.
    /// Defaults to `max_input_tokens + 50` to give a bit of room.
    #[clap(long, env)]
    max_batch_prefill_tokens: Option<u32>,

    /// **IMPORTANT** This is one critical control to allow maximum usage
    /// of the available hardware.
    ///
    /// This represents the total amount of potential tokens within a batch.
    /// When using padding (not recommended) this would be equivalent to
    /// `batch_size` * `max_total_tokens`.
    ///
    /// However in the non-padded (flash attention) version this can be much finer.
    ///
    /// For `max_batch_total_tokens=1000`, you could fit `10` queries of `total_tokens=100`
    /// or a single query of `1000` tokens.
    ///
    /// Overall this number should be the largest possible amount that fits the
    /// remaining memory (after the model is loaded). Since the actual memory overhead
    /// depends on other parameters such as whether you're using quantization, flash attention
    /// or the model implementation, text-generation-inference cannot infer this number
    /// automatically.
    #[clap(long, env)]
    max_batch_total_tokens: Option<u32>,

    /// This setting defines how many tokens can be passed before forcing the waiting
    /// queries to be put on the batch (if the size of the batch allows for it).
    /// New queries require 1 `prefill` forward, which is different from `decode`
    /// and therefore you need to pause the running batch in order to run `prefill`
    /// to create the correct values for the waiting queries to be able to join the batch.
    ///
    /// With a value too small, queries will always "steal" the compute to run `prefill`
    /// and running queries will be delayed by a lot.
    ///
    /// With a value too big, waiting queries could wait for a very long time
    /// before being allowed a slot in the running batch. If your server is busy
    /// that means that requests that could run in ~2s on an empty server could
    /// end up running in ~20s because the query had to wait for 18s.
    ///
    /// This number is expressed in number of tokens to make it a bit more
    /// "model" agnostic, but what should really matter is the overall latency
    /// for end users.
    #[clap(default_value = "20", long, env)]
    max_waiting_tokens: usize,

    /// Enforce a maximum number of requests per batch
    /// Specific flag for hardware targets that do not support unpadded inference
    #[clap(long, env)]
    max_batch_size: Option<usize>,

    /// Specify the batch sizes to compute cuda graphs for.
    /// Use "0" to disable.
    /// Default = "1,2,4,8,16,32"
    #[clap(long, env, value_delimiter = ',')]
    cuda_graphs: Option<Vec<usize>>,

    /// The IP address to listen on
    #[clap(default_value = "0.0.0.0", long, env)]
    hostname: String,

    /// The port to listen on.
    #[clap(default_value = "3000", long, short, env)]
    port: u16,

    /// The name of the socket for gRPC communication between the webserver
    /// and the shards.
    #[clap(default_value = "/tmp/text-generation-server", long, env)]
    shard_uds_path: String,

    /// The address the master shard will listen on. (setting used by torch distributed)
    #[clap(default_value = "localhost", long, env)]
    master_addr: String,

    /// The port the master shard will listen on. (setting used by torch distributed)
    #[clap(default_value = "29500", long, env)]
    master_port: usize,

    /// The location of the huggingface hub cache.
    /// Used to override the location if you want to provide a mounted disk for instance
    #[clap(long, env)]
    huggingface_hub_cache: Option<String>,

    /// The location of the weights cache.
    /// Used to override the location if you want to provide a mounted disk for instance
    #[clap(long, env)]
    weights_cache_override: Option<String>,

    /// For some models (like bloom), text-generation-inference implemented custom
    /// cuda kernels to speed up inference. Those kernels were only tested on A100.
    /// Use this flag to disable them if you're running on different hardware and
    /// encounter issues.
    #[clap(long, env)]
    disable_custom_kernels: bool,

    /// Limit the CUDA available memory.
    /// The allowed value equals the total visible memory multiplied by cuda-memory-fraction.
    #[clap(default_value = "1.0", long, env)]
    cuda_memory_fraction: f32,

    /// Rope scaling will only be used for RoPE models
    /// and allows rescaling the rotary position embedding to accommodate
    /// larger prompts.
    ///
    /// Goes together with `rope_factor`.
    ///
    /// `--rope-factor 2.0` gives linear scaling with a factor of 2.0
    /// `--rope-scaling dynamic` gives dynamic scaling with a factor of 1.0
    /// `--rope-scaling linear` gives linear scaling with a factor of 1.0 (Nothing will be changed
    /// basically)
    ///
    /// `--rope-scaling linear --rope-factor` fully describes the scaling you want
    #[clap(long, env)]
    rope_scaling: Option<RopeScaling>,

    /// Rope scaling will only be used for RoPE models
    /// See `rope_scaling`
    #[clap(long, env)]
    rope_factor: Option<f32>,

    /// Outputs the logs in JSON format (useful for telemetry)
    #[clap(long, env)]
    json_output: bool,

    #[clap(long, env)]
    otlp_endpoint: Option<String>,

    #[clap(default_value = "text-generation-inference.router", long, env)]
    otlp_service_name: String,

    #[clap(long, env)]
    cors_allow_origin: Vec<String>,

    #[clap(long, env)]
    api_key: Option<String>,

    #[clap(long, env)]
    watermark_gamma: Option<f32>,
    #[clap(long, env)]
    watermark_delta: Option<f32>,

    /// Enable ngrok tunneling
    #[clap(long, env)]
    ngrok: bool,

    /// ngrok authentication token
    #[clap(long, env)]
    ngrok_authtoken: Option<String>,

    /// ngrok edge
    #[clap(long, env)]
    ngrok_edge: Option<String>,

    /// The path to the tokenizer config file. This path is used to load the tokenizer configuration which may
    /// include a `chat_template`. If not provided, the default config will be used from the model hub.
    #[clap(long, env)]
    tokenizer_config_path: Option<String>,

    /// Disable outlines grammar constrained generation.
    /// This is a feature that allows you to generate text that follows a specific grammar.
    #[clap(long, env)]
    disable_grammar_support: bool,

    /// Display a lot of information about your runtime environment
    #[clap(long, short, action)]
    env: bool,

    /// Control the maximum number of inputs that a client can send in a single request
    #[clap(default_value = "4", long, env)]
    max_client_batch_size: usize,

    /// Lora Adapters: a list of adapter ids, i.e. `repo/adapter1,repo/adapter2`, to load during
    /// startup that will be available to callers via the `adapter_id` field in a request.
    #[clap(long, env)]
    lora_adapters: Option<String>,

    /// Control if anonymous usage stats are collected.
    /// Options are "on", "off" and "no-stack".
    /// Default is on.
    #[clap(default_value = "on", long, env)]
    usage_stats: UsageStatsLevel,
}

#[derive(Debug)]
enum ShardStatus {
    Ready,
    Failed(usize),
}

#[allow(clippy::too_many_arguments)]
fn shard_manager(
    model_id: String,
    revision: Option<String>,
    quantize: Option<Quantization>,
    speculate: Option<usize>,
    dtype: Option<Dtype>,
    trust_remote_code: bool,
    uds_path: String,
    rank: usize,
    world_size: usize,
    master_addr: String,
    master_port: usize,
    huggingface_hub_cache: Option<String>,
    weights_cache_override: Option<String>,
    disable_custom_kernels: bool,
    watermark_gamma: Option<f32>,
    watermark_delta: Option<f32>,
    cuda_graphs: Vec<usize>,
    cuda_memory_fraction: f32,
    rope_scaling: Option<RopeScaling>,
    rope_factor: Option<f32>,
    max_total_tokens: usize,
    max_batch_size: Option<usize>,
    max_input_tokens: usize,
    lora_adapters: Option<String>,
    otlp_endpoint: Option<String>,
    otlp_service_name: String,
    log_level: LevelFilter,
    status_sender: mpsc::Sender<ShardStatus>,
    shutdown: Arc<AtomicBool>,
    _shutdown_sender: mpsc::Sender<()>,
) {
    // Enter shard-manager tracing span
    let _span = tracing::span!(tracing::Level::INFO, "shard-manager", rank = rank).entered();

    // Get UDS path
    let uds_string = format!("{uds_path}-{rank}");
    let uds = Path::new(&uds_string);
    // Clean previous runs
    if uds.exists() {
        fs::remove_file(uds).unwrap();
    }

    // Process args
    let mut shard_args = vec![
        "serve".to_string(),
        model_id,
        "--uds-path".to_string(),
        uds_path,
        "--logger-level".to_string(),
        log_level.to_string().to_uppercase(),
        "--json-output".to_string(),
    ];

    // Activate trust remote code
    if trust_remote_code {
        shard_args.push("--trust-remote-code".to_string());
    }

    // Activate tensor parallelism
    if world_size > 1 {
        shard_args.push("--sharded".to_string());
    }

    if let Some(quantize) = quantize {
        shard_args.push("--quantize".to_string());
        shard_args.push(quantize.to_string())
    }

    if let Some(speculate) = speculate {
        shard_args.push("--speculate".to_string());
        shard_args.push(speculate.to_string())
    }

    if let Some(dtype) = dtype {
        shard_args.push("--dtype".to_string());
        shard_args.push(dtype.to_string())
    }

    // Model optional revision
    if let Some(revision) = revision {
        shard_args.push("--revision".to_string());
        shard_args.push(revision)
    }

    let rope = match (rope_scaling, rope_factor) {
        (None, None) => None,
        (Some(scaling), None) => Some((scaling, 1.0)),
        (Some(scaling), Some(factor)) => Some((scaling, factor)),
        (None, Some(factor)) => Some((RopeScaling::Linear, factor)),
    };

    // OpenTelemetry Endpoint
    if let Some(otlp_endpoint) = otlp_endpoint {
        shard_args.push("--otlp-endpoint".to_string());
        shard_args.push(otlp_endpoint);
    }

    // OpenTelemetry Service Name
    shard_args.push("--otlp-service-name".to_string());
    shard_args.push(otlp_service_name);

    // In case we use sliding window, we may ignore the sliding in flash for some backends depending on the parameter.
    shard_args.push("--max-input-tokens".to_string());
    shard_args.push(max_input_tokens.to_string());

    // Copy current process env
    let mut envs: Vec<(OsString, OsString)> = env::vars_os().collect();

    // Remove LOG_LEVEL if present
    envs.retain(|(name, _)| name != "LOG_LEVEL");

    // Torch Distributed Env vars
    envs.push(("RANK".into(), rank.to_string().into()));
    envs.push(("WORLD_SIZE".into(), world_size.to_string().into()));
    envs.push(("MASTER_ADDR".into(), master_addr.into()));
    envs.push(("MASTER_PORT".into(), master_port.to_string().into()));
    envs.push(("TORCH_NCCL_AVOID_RECORD_STREAMS".into(), "1".into()));

    // CUDA memory fraction
    envs.push((
        "CUDA_MEMORY_FRACTION".into(),
        cuda_memory_fraction.to_string().into(),
    ));

    // Safetensors load fast
    envs.push(("SAFETENSORS_FAST_GPU".into(), "1".into()));

    // Disable progress bar
    envs.push(("HF_HUB_DISABLE_PROGRESS_BARS".into(), "1".into()));

    // Enable hf transfer for insane download speeds
    let enable_hf_transfer = env::var("HF_HUB_ENABLE_HF_TRANSFER").unwrap_or("1".to_string());
    envs.push((
        "HF_HUB_ENABLE_HF_TRANSFER".into(),
        enable_hf_transfer.into(),
    ));

    // Parse Inference API token
    if let Ok(api_token) = env::var("HF_API_TOKEN") {
        envs.push(("HF_TOKEN".into(), api_token.into()))
    };

    // Detect rope scaling
    // Sent as env vars instead of CLI args to avoid bloating the command line;
    // only RoPE models can use them, so passing them around for all models
    // would complicate the code unnecessarily
    if let Some((scaling, factor)) = rope {
        envs.push(("ROPE_SCALING".into(), scaling.to_string().into()));
        envs.push(("ROPE_FACTOR".into(), factor.to_string().into()));
    }

    envs.push((
        "MAX_TOTAL_TOKENS".into(),
        max_total_tokens.to_string().into(),
    ));
    if let Some(max_batch_size) = max_batch_size {
        envs.push(("MAX_BATCH_SIZE".into(), max_batch_size.to_string().into()));
    }

    // Lora Adapters
    if let Some(lora_adapters) = lora_adapters {
        envs.push(("LORA_ADAPTERS".into(), lora_adapters.into()));
    }

    // If huggingface_hub_cache is some, pass it to the shard
    // Useful when running inside a docker container
    if let Some(huggingface_hub_cache) = huggingface_hub_cache {
        envs.push(("HUGGINGFACE_HUB_CACHE".into(), huggingface_hub_cache.into()));
    };

    // If weights_cache_override is some, pass it to the shard
    // Useful when running inside a HuggingFace Inference Endpoint
    if let Some(weights_cache_override) = weights_cache_override {
        envs.push((
            "WEIGHTS_CACHE_OVERRIDE".into(),
            weights_cache_override.into(),
        ));
    };

    // Enable experimental support for cuda graphs
    if !cuda_graphs.is_empty() {
        envs.push((
            "CUDA_GRAPHS".into(),
            cuda_graphs
                .into_iter()
                .map(|c| c.to_string())
                .collect::<Vec<_>>()
                .join(",")
                .into(),
        ));
    }

    // If disable_custom_kernels is true, pass it to the shard as an env var
    if disable_custom_kernels {
        envs.push(("DISABLE_CUSTOM_KERNELS".into(), "True".into()))
    }

    // Watermark Gamma
    if let Some(watermark_gamma) = watermark_gamma {
        envs.push(("WATERMARK_GAMMA".into(), watermark_gamma.to_string().into()))
    }

    // Watermark Delta
    if let Some(watermark_delta) = watermark_delta {
        envs.push(("WATERMARK_DELTA".into(), watermark_delta.to_string().into()))
    }

    // Start process
    tracing::info!("Starting shard");
    let mut p = match Command::new("text-generation-server")
        .args(shard_args)
        .env_clear()
        .envs(envs)
        .stdout(Stdio::piped())
        .stderr(Stdio::piped())
        .process_group(0)
        .spawn()
    {
        Ok(p) => p,
        Err(err) => {
            if err.kind() == io::ErrorKind::NotFound {
                tracing::error!("text-generation-server not found in PATH");
                tracing::error!("Please install it with `make install-server`")
            } else {
                tracing::error!("{}", err);
            }

            status_sender.send(ShardStatus::Failed(rank)).unwrap();
            return;
        }
    };

    // Redirect STDOUT to the console
    let shard_stdout_reader = BufReader::new(p.stdout.take().unwrap());
    let shard_stderr_reader = BufReader::new(p.stderr.take().unwrap());

    // stdout tracing thread
    thread::spawn(move || {
        log_lines(shard_stdout_reader.lines());
    });

    // We read stderr in another thread as it seems that lines() can block in some cases
    let (err_sender, err_receiver) = mpsc::channel();
    thread::spawn(move || {
        for line in shard_stderr_reader.lines().map_while(Result::ok) {
            err_sender.send(line).unwrap_or(());
        }
    });

    let mut ready = false;
    let start_time = Instant::now();
    let mut wait_time = Instant::now();
    loop {
        // Process exited
        if let Some(exit_status) = p.try_wait().unwrap() {
            let mut err = String::new();
            while let Ok(line) = err_receiver.recv_timeout(Duration::from_millis(10)) {
                err = err + "\n" + &line;
            }

            tracing::error!("Shard complete standard error output:\n{err}");

            if let Some(signal) = exit_status.signal() {
                tracing::error!("Shard process was signaled to shutdown with signal {signal}");
            }

            status_sender.send(ShardStatus::Failed(rank)).unwrap();
            return;
        }

        // We received a shutdown signal
        if shutdown.load(Ordering::SeqCst) {
            terminate("shard", p, Duration::from_secs(90)).unwrap();
            return;
        }

        // Shard is ready
        if uds.exists() && !ready {
            tracing::info!("Shard ready in {:?}", start_time.elapsed());
            status_sender.send(ShardStatus::Ready).unwrap();
            ready = true;
        } else if !ready && wait_time.elapsed() > Duration::from_secs(10) {
            tracing::info!("Waiting for shard to be ready...");
            wait_time = Instant::now();
        }
        sleep(Duration::from_millis(100));
    }
}

fn shutdown_shards(shutdown: Arc<AtomicBool>, shutdown_receiver: &mpsc::Receiver<()>) {
    tracing::info!("Shutting down shards");
    // Update shutdown value to true
    // This will be picked up by the shard manager
    shutdown.store(true, Ordering::SeqCst);

    // Wait for shards to shutdown
    // This will block until all shutdown_sender are dropped
    let _ = shutdown_receiver.recv();
}

fn num_cuda_devices() -> Option<usize> {
    let devices = match env::var("CUDA_VISIBLE_DEVICES") {
        Ok(devices) => devices,
        Err(_) => match env::var("NVIDIA_VISIBLE_DEVICES") {
            Ok(devices) => devices,
            Err(_) => env::var("ZE_AFFINITY_MASK").ok()?,
        },
    };
    let n_devices = devices.split(',').count();
    Some(n_devices)
}

#[derive(Deserialize)]
#[serde(rename_all = "UPPERCASE")]
enum PythonLogLevelEnum {
    Trace,
    Debug,
    Info,
    Success,
    Warning,
    Error,
    Critical,
}

#[derive(Deserialize)]
struct PythonLogLevel {
    name: PythonLogLevelEnum,
}

#[derive(Deserialize)]
struct PythonLogRecord {
    level: PythonLogLevel,
}

#[derive(Deserialize)]
struct PythonLogMessage {
    text: String,
    record: PythonLogRecord,
}

impl PythonLogMessage {
    fn trace(&self) {
        match self.record.level.name {
            PythonLogLevelEnum::Trace => tracing::trace!("{}", self.text.trim_end()),
            PythonLogLevelEnum::Debug => tracing::debug!("{}", self.text.trim_end()),
            PythonLogLevelEnum::Info => tracing::info!("{}", self.text.trim_end()),
            PythonLogLevelEnum::Success => tracing::info!("{}", self.text.trim_end()),
            PythonLogLevelEnum::Warning => tracing::warn!("{}", self.text.trim_end()),
            PythonLogLevelEnum::Error => tracing::error!("{}", self.text.trim_end()),
            PythonLogLevelEnum::Critical => tracing::error!("{}", self.text.trim_end()),
        }
    }
}

impl TryFrom<&String> for PythonLogMessage {
    type Error = serde_json::Error;

    fn try_from(value: &String) -> Result<Self, Self::Error> {
        serde_json::from_str::<Self>(value)
    }
}
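// Illustrative example (not in the original source): the shard/download processes are
// started with `--json-output`, so their stdout lines are JSON records parsed by
// `PythonLogMessage::try_from` above. The JSON sample below is a hand-written
// assumption of that shape, matching only the `Deserialize` structs defined above.
#[cfg(test)]
mod python_log_tests {
    use super::*;

    #[test]
    fn parses_a_json_log_record() {
        let line = r#"{"text":"Server started\n","record":{"level":{"name":"INFO"}}}"#.to_string();
        let msg = PythonLogMessage::try_from(&line).expect("valid record should parse");
        assert_eq!(msg.text, "Server started\n");
        assert!(matches!(msg.record.level.name, PythonLogLevelEnum::Info));
    }

    #[test]
    fn rejects_non_json_lines() {
        // `log_lines` below falls back to `tracing::debug!` for such lines.
        let line = "plain text, not JSON".to_string();
        assert!(PythonLogMessage::try_from(&line).is_err());
    }
}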

fn log_lines<S: Sized + BufRead>(lines: Lines<S>) {
    for line in lines.map_while(Result::ok) {
        match PythonLogMessage::try_from(&line) {
            Ok(log) => log.trace(),
            Err(_) => tracing::debug!("{line}"),
        }
    }
}

fn find_num_shards(
    sharded: Option<bool>,
    num_shard: Option<usize>,
) -> Result<usize, LauncherError> {
    // get the number of shards given `sharded` and `num_shard`
    let num_shard = match (sharded, num_shard) {
        (Some(true), None) => {
            // try to default to the number of available GPUs
            tracing::info!("Parsing num_shard from CUDA_VISIBLE_DEVICES/NVIDIA_VISIBLE_DEVICES/ZE_AFFINITY_MASK");
            let n_devices = num_cuda_devices()
                .expect("--num-shard and CUDA_VISIBLE_DEVICES/NVIDIA_VISIBLE_DEVICES/ZE_AFFINITY_MASK are not set");
            if n_devices <= 1 {
                return Err(LauncherError::NotEnoughCUDADevices(format!(
                    "`sharded` is true but only found {n_devices} CUDA devices"
                )));
            }
            n_devices
        }
        (Some(true), Some(num_shard)) => {
            // we can't have only one shard while sharded
            if num_shard <= 1 {
                return Err(LauncherError::ArgumentValidation(
                    "`sharded` is true but `num_shard` <= 1".to_string(),
                ));
            }
            num_shard
        }
        (Some(false), Some(num_shard)) => num_shard,
        (Some(false), None) => 1,
        (None, None) => num_cuda_devices().unwrap_or(1),
        (None, Some(num_shard)) => num_shard,
    };
    if num_shard < 1 {
        return Err(LauncherError::ArgumentValidation(
            "`num_shard` cannot be < 1".to_string(),
        ));
    }
    Ok(num_shard)
}
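// Illustrative example (not part of the original launcher): these cases only cover the
// combinations of `sharded` / `num_shard` that do not consult the GPU-related environment
// variables, so they behave the same on any machine.
#[cfg(test)]
mod num_shards_tests {
    use super::*;

    #[test]
    fn explicit_values_are_respected() {
        assert_eq!(find_num_shards(Some(false), None).unwrap(), 1);
        assert_eq!(find_num_shards(Some(false), Some(4)).unwrap(), 4);
        assert_eq!(find_num_shards(None, Some(2)).unwrap(), 2);
    }

    #[test]
    fn sharded_with_a_single_shard_is_rejected() {
        assert!(find_num_shards(Some(true), Some(1)).is_err());
        assert!(find_num_shards(None, Some(0)).is_err());
    }
}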

#[derive(Debug, Error)]
enum LauncherError {
    #[error("Invalid argument: {0}")]
    ArgumentValidation(String),
    #[error("not enough cuda devices: {0}")]
    NotEnoughCUDADevices(String),
    #[error("Download error")]
    DownloadError,
    #[error("Shard cannot start")]
    ShardCannotStart,
    #[error("Shard disconnected")]
    ShardDisconnected,
    #[error("Shard failed")]
    ShardFailed,
    #[error("Webserver failed")]
    WebserverFailed,
    #[error("Webserver cannot start")]
    WebserverCannotStart,
}

fn download_convert_model(
    model_id: &str,
    revision: Option<&str>,
    trust_remote_code: bool,
    huggingface_hub_cache: Option<&str>,
    weights_cache_override: Option<&str>,
    running: Arc<AtomicBool>,
) -> Result<(), LauncherError> {
    // Enter download tracing span
    let _span = tracing::span!(tracing::Level::INFO, "download").entered();

    let mut download_args = vec![
        "download-weights".to_string(),
        model_id.to_string(),
        "--extension".to_string(),
        ".safetensors".to_string(),
        "--logger-level".to_string(),
        "INFO".to_string(),
        "--json-output".to_string(),
    ];

    // Model optional revision
    if let Some(revision) = &revision {
        download_args.push("--revision".to_string());
        download_args.push(revision.to_string())
    }

    // Trust remote code for automatic peft fusion
    if trust_remote_code {
        download_args.push("--trust-remote-code".to_string());
    }

    // Copy current process env
    let mut envs: Vec<(OsString, OsString)> = env::vars_os().collect();

    // Remove LOG_LEVEL if present
    envs.retain(|(name, _)| name != "LOG_LEVEL");

    // Disable progress bar
    envs.push(("HF_HUB_DISABLE_PROGRESS_BARS".into(), "1".into()));

    // If huggingface_hub_cache is set, pass it to the download process
    // Useful when running inside a docker container
    if let Some(ref huggingface_hub_cache) = huggingface_hub_cache {
        envs.push(("HUGGINGFACE_HUB_CACHE".into(), huggingface_hub_cache.into()));
    };

    // Enable hf transfer for insane download speeds
    let enable_hf_transfer = env::var("HF_HUB_ENABLE_HF_TRANSFER").unwrap_or("1".to_string());
    envs.push((
        "HF_HUB_ENABLE_HF_TRANSFER".into(),
        enable_hf_transfer.into(),
    ));

    // Parse Inference API token
    if let Ok(api_token) = env::var("HF_API_TOKEN") {
        envs.push(("HF_TOKEN".into(), api_token.into()))
    };

    // If args.weights_cache_override is some, pass it to the download process
    // Useful when running inside a HuggingFace Inference Endpoint
    if let Some(weights_cache_override) = &weights_cache_override {
        envs.push((
            "WEIGHTS_CACHE_OVERRIDE".into(),
            weights_cache_override.into(),
        ));
    };

    // Start process
    tracing::info!("Starting check and download process for {model_id}");
    let mut download_process = match Command::new("text-generation-server")
        .args(download_args)
        .env_clear()
        .envs(envs)
        .stdout(Stdio::piped())
        .stderr(Stdio::piped())
        .process_group(0)
        .spawn()
    {
        Ok(p) => p,
        Err(err) => {
            if err.kind() == io::ErrorKind::NotFound {
                tracing::error!("text-generation-server not found in PATH");
                tracing::error!("Please install it with `make install-server`")
            } else {
                tracing::error!("{}", err);
            }

            return Err(LauncherError::DownloadError);
        }
    };

    let download_stdout = BufReader::new(download_process.stdout.take().unwrap());

    thread::spawn(move || {
        log_lines(download_stdout.lines());
    });

    let download_stderr = BufReader::new(download_process.stderr.take().unwrap());

    // We read stderr in another thread as it seems that lines() can block in some cases
    let (err_sender, err_receiver) = mpsc::channel();
    thread::spawn(move || {
        for line in download_stderr.lines().map_while(Result::ok) {
            err_sender.send(line).unwrap_or(());
        }
    });

    loop {
        if let Some(status) = download_process.try_wait().unwrap() {
            if status.success() {
                tracing::info!("Successfully downloaded weights for {model_id}");
                break;
            }

            let mut err = String::new();
            while let Ok(line) = err_receiver.recv_timeout(Duration::from_millis(10)) {
                err = err + "\n" + &line;
            }

            if let Some(signal) = status.signal() {
                tracing::error!(
                    "Download process was signaled to shutdown with signal {signal}: {err}"
                );
            } else {
                tracing::error!("Download encountered an error: {err}");
            }

            return Err(LauncherError::DownloadError);
        }
        if !running.load(Ordering::SeqCst) {
            terminate("download", download_process, Duration::from_secs(10)).unwrap();
            return Ok(());
        }
        sleep(Duration::from_millis(100));
    }
    Ok(())
}

#[allow(clippy::too_many_arguments)]
fn spawn_shards(
    num_shard: usize,
    args: &Args,
    cuda_graphs: Vec<usize>,
    max_total_tokens: usize,
    max_input_tokens: usize,
    max_log_level: LevelFilter,
    shutdown: Arc<AtomicBool>,
    shutdown_receiver: &mpsc::Receiver<()>,
    shutdown_sender: mpsc::Sender<()>,
    status_receiver: &mpsc::Receiver<ShardStatus>,
    status_sender: mpsc::Sender<ShardStatus>,
    running: Arc<AtomicBool>,
) -> Result<(), LauncherError> {
    // Start shard processes
    for rank in 0..num_shard {
        let model_id = args.model_id.clone();
        let revision = args.revision.clone();
        let uds_path = args.shard_uds_path.clone();
        let master_addr = args.master_addr.clone();
        let huggingface_hub_cache = args.huggingface_hub_cache.clone();
        let weights_cache_override = args.weights_cache_override.clone();
        let status_sender = status_sender.clone();
        let shutdown = shutdown.clone();
        let shutdown_sender = shutdown_sender.clone();
        let otlp_endpoint = args.otlp_endpoint.clone();
        let otlp_service_name = args.otlp_service_name.clone();
        let quantize = args.quantize;
        let speculate = args.speculate;
        let dtype = args.dtype;
        let trust_remote_code = args.trust_remote_code;
        let master_port = args.master_port;
        let disable_custom_kernels = args.disable_custom_kernels;
        let watermark_gamma = args.watermark_gamma;
        let watermark_delta = args.watermark_delta;
        let cuda_graphs_clone = cuda_graphs.clone();
        let cuda_memory_fraction = args.cuda_memory_fraction;
        let rope_scaling = args.rope_scaling;
        let rope_factor = args.rope_factor;
        let max_batch_size = args.max_batch_size;
        let lora_adapters = args.lora_adapters.clone();
        thread::spawn(move || {
            shard_manager(
                model_id,
                revision,
                quantize,
                speculate,
                dtype,
                trust_remote_code,
                uds_path,
                rank,
                num_shard,
                master_addr,
                master_port,
                huggingface_hub_cache,
                weights_cache_override,
                disable_custom_kernels,
                watermark_gamma,
                watermark_delta,
                cuda_graphs_clone,
                cuda_memory_fraction,
                rope_scaling,
                rope_factor,
                max_total_tokens,
                max_batch_size,
                max_input_tokens,
                lora_adapters,
                otlp_endpoint,
                otlp_service_name,
                max_log_level,
                status_sender,
                shutdown,
                shutdown_sender,
            )
        });
    }
    drop(shutdown_sender);

    // Wait for shard to start
    let mut shard_ready = 0;
    while running.load(Ordering::SeqCst) {
        match status_receiver.try_recv() {
            Ok(ShardStatus::Ready) => {
                shard_ready += 1;
                if shard_ready == num_shard {
                    break;
                }
            }
            Err(TryRecvError::Empty) => {
                sleep(Duration::from_millis(100));
            }
            Ok(ShardStatus::Failed(rank)) => {
                tracing::error!("Shard {rank} failed to start");
                shutdown_shards(shutdown, shutdown_receiver);
                return Err(LauncherError::ShardCannotStart);
            }
            Err(TryRecvError::Disconnected) => {
                tracing::error!("Shard status channel disconnected");
                shutdown_shards(shutdown, shutdown_receiver);
                return Err(LauncherError::ShardDisconnected);
            }
        }
    }
    Ok(())
}

1188
1189
1190
1191
1192
1193
1194
1195
1196
1197
1198
1199
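/// Builds the `COMPUTE_TYPE` string from the first GPU name reported by `nvidia-smi`:
/// with one shard and a (hypothetical) card reported as "NVIDIA A10G", this would yield
/// "1-nvidia-a10g". Returns `None` when `nvidia-smi` is unavailable.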
fn compute_type(num_shard: usize) -> Option<String> {
    let output = Command::new("nvidia-smi")
        .args(["--query-gpu=gpu_name", "--format=csv"])
        .output()
        .ok()?;
    let output = String::from_utf8(output.stdout).ok()?;
    let fullname = output.split('\n').nth(1)?;
    let cardname = fullname.replace(' ', "-").to_lowercase();
    let compute_type = format!("{num_shard}-{cardname}");
    Some(compute_type)
}

fn spawn_webserver(
    num_shard: usize,
    args: Args,
    max_input_tokens: usize,
    max_total_tokens: usize,
    max_batch_prefill_tokens: u32,
    shutdown: Arc<AtomicBool>,
    shutdown_receiver: &mpsc::Receiver<()>,
) -> Result<Child, LauncherError> {
    // All shards started
    // Start webserver
    tracing::info!("Starting Webserver");
    let mut router_args = vec![
        "--max-client-batch-size".to_string(),
        args.max_client_batch_size.to_string(),
        "--max-concurrent-requests".to_string(),
        args.max_concurrent_requests.to_string(),
        "--max-best-of".to_string(),
        args.max_best_of.to_string(),
        "--max-stop-sequences".to_string(),
        args.max_stop_sequences.to_string(),
        "--max-top-n-tokens".to_string(),
        args.max_top_n_tokens.to_string(),
        "--max-input-tokens".to_string(),
        max_input_tokens.to_string(),
        "--max-total-tokens".to_string(),
        max_total_tokens.to_string(),
        "--max-batch-prefill-tokens".to_string(),
        max_batch_prefill_tokens.to_string(),
        "--waiting-served-ratio".to_string(),
        args.waiting_served_ratio.to_string(),
        "--max-waiting-tokens".to_string(),
        args.max_waiting_tokens.to_string(),
        "--validation-workers".to_string(),
        args.validation_workers.to_string(),
        "--hostname".to_string(),
        args.hostname.to_string(),
        "--port".to_string(),
        args.port.to_string(),
        "--master-shard-uds-path".to_string(),
        format!("{}-0", args.shard_uds_path),
        "--tokenizer-name".to_string(),
        args.model_id,
    ];

    // Pass usage stats flags to router
    router_args.push("--usage-stats".to_string());
    router_args.push(args.usage_stats.to_string());

    // Grammar support
    if args.disable_grammar_support {
        router_args.push("--disable-grammar-support".to_string());
    }

    // Tokenizer config path
    if let Some(ref tokenizer_config_path) = args.tokenizer_config_path {
        router_args.push("--tokenizer-config-path".to_string());
        router_args.push(tokenizer_config_path.to_string());
    }

    // Model optional max batch total tokens
    if let Some(max_batch_total_tokens) = args.max_batch_total_tokens {
        router_args.push("--max-batch-total-tokens".to_string());
        router_args.push(max_batch_total_tokens.to_string());
    }

    // Router optional max batch size
    if let Some(max_batch_size) = args.max_batch_size {
        router_args.push("--max-batch-size".to_string());
        router_args.push(max_batch_size.to_string());
    }

    // Model optional revision
    if let Some(ref revision) = args.revision {
        router_args.push("--revision".to_string());
        router_args.push(revision.to_string())
    }

    if args.json_output {
        router_args.push("--json-output".to_string());
    }

    // OpenTelemetry
    if let Some(otlp_endpoint) = args.otlp_endpoint {
        router_args.push("--otlp-endpoint".to_string());
        router_args.push(otlp_endpoint);
    }

    // OpenTelemetry Service Name
    let otlp_service_name = args.otlp_service_name;
    router_args.push("--otlp-service-name".to_string());
    router_args.push(otlp_service_name);

    // CORS origins
    for origin in args.cors_allow_origin.into_iter() {
        router_args.push("--cors-allow-origin".to_string());
        router_args.push(origin);
    }

    // API Key
    if let Some(api_key) = args.api_key {
        router_args.push("--api-key".to_string());
        router_args.push(api_key);
    }

    // Ngrok
    if args.ngrok {
        router_args.push("--ngrok".to_string());
        router_args.push("--ngrok-authtoken".to_string());
        router_args.push(args.ngrok_authtoken.unwrap());
        router_args.push("--ngrok-edge".to_string());
        router_args.push(args.ngrok_edge.unwrap());
    }

    // Copy current process env
    let mut envs: Vec<(OsString, OsString)> = env::vars_os().collect();

    // Parse Inference API token
    if let Ok(api_token) = env::var("HF_API_TOKEN") {
        envs.push(("HF_TOKEN".into(), api_token.into()))
    };

    // Parse Compute type
    if let Ok(compute_type) = env::var("COMPUTE_TYPE") {
        envs.push(("COMPUTE_TYPE".into(), compute_type.into()))
    } else if let Some(compute_type) = compute_type(num_shard) {
        envs.push(("COMPUTE_TYPE".into(), compute_type.into()))
    }

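    // Spawn the router in its own process group with piped stdout/stderr so the
    // launcher can forward its logs and shut it down explicitly on exit.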
    let mut webserver = match Command::new("text-generation-router")
        .args(router_args)
        .envs(envs)
        .stdout(Stdio::piped())
        .stderr(Stdio::piped())
        .process_group(0)
        .spawn()
    {
        Ok(p) => p,
        Err(err) => {
            tracing::error!("Failed to start webserver: {}", err);
            if err.kind() == io::ErrorKind::NotFound {
                tracing::error!("text-generation-router not found in PATH");
                tracing::error!("Please install it with `make install-router`")
            } else {
                tracing::error!("{}", err);
            }

            shutdown_shards(shutdown, shutdown_receiver);
            return Err(LauncherError::WebserverCannotStart);
        }
    };

    // Redirect STDOUT and STDERR to the console
    let webserver_stdout = webserver.stdout.take().unwrap();
    let webserver_stderr = webserver.stderr.take().unwrap();

    thread::spawn(move || {
        let stdout = BufReader::new(webserver_stdout);
        let stderr = BufReader::new(webserver_stderr);
        for line in stdout.lines() {
            println!("{}", line.unwrap());
        }
        for line in stderr.lines() {
            println!("{}", line.unwrap());
        }
    });
    Ok(webserver)
}

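/// Send SIGTERM to `process` and wait up to `timeout` for it to exit gracefully,
/// escalating to SIGKILL if it does not.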
fn terminate(process_name: &str, mut process: Child, timeout: Duration) -> io::Result<ExitStatus> {
    tracing::info!("Terminating {process_name}");

    let terminate_time = Instant::now();
    signal::kill(Pid::from_raw(process.id() as i32), Signal::SIGTERM).unwrap();

    tracing::info!("Waiting for {process_name} to gracefully shutdown");
    while terminate_time.elapsed() < timeout {
        if let Some(status) = process.try_wait()? {
            tracing::info!("{process_name} terminated");
            return Ok(status);
        }
        sleep(Duration::from_millis(100));
    }
    tracing::info!("Killing {process_name}");

    process.kill()?;
    let exit_status = process.wait()?;

    tracing::info!("{process_name} killed");
    Ok(exit_status)
}

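/// Launcher entrypoint: parse arguments, configure logging, download model weights,
/// spawn the shard processes and the router webserver, then supervise them until
/// shutdown or failure.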
fn main() -> Result<(), LauncherError> {
    // Parse command line arguments
    let args: Args = Args::parse();

    // Filter events with LOG_LEVEL
    let varname = "LOG_LEVEL";
    let env_filter = if let Ok(log_level) = std::env::var(varname) {
        // Override to avoid simple logs being spammed with tokio-level information
        let log_level = match &log_level[..] {
            "warn" => "text_generation_launcher=warn,text_generation_router=warn",
            "info" => "text_generation_launcher=info,text_generation_router=info",
            "debug" => "text_generation_launcher=debug,text_generation_router=debug",
            log_level => log_level,
        };
        EnvFilter::builder()
            .with_default_directive(LevelFilter::INFO.into())
            .parse_lossy(log_level)
    } else {
        EnvFilter::new("info")
    };
    let max_log_level = env_filter.max_level_hint().unwrap_or(LevelFilter::INFO);

    if args.json_output {
        tracing_subscriber::fmt()
            .with_env_filter(env_filter)
            .json()
            .init();
    } else {
        tracing_subscriber::fmt()
            .with_env_filter(env_filter)
            .compact()
            .init();
    }

    if args.env {
        let env_runtime = env_runtime::Env::new();
        tracing::info!("{}", env_runtime);
    }

    tracing::info!("{:#?}", args);

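    // Read the model's `config.json` (from a local path or the Hugging Face hub) to
    // infer its maximum context length, which drives the token defaults computed below.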
    let get_max_position_embeddings = || -> Result<usize, Box<dyn std::error::Error>> {
        let model_id = args.model_id.clone();
        let mut path = std::path::Path::new(&args.model_id).to_path_buf();
        let filename = if !path.exists() {
            // Assume it's a hub id

            let api = if let Ok(token) = std::env::var("HF_TOKEN") {
                // The env variable takes precedence over the token stored on disk.
                ApiBuilder::new().with_token(Some(token)).build()?
            } else {
                Api::new()?
            };
            let repo = if let Some(ref revision) = args.revision {
                api.repo(Repo::with_revision(
                    model_id,
                    RepoType::Model,
                    revision.to_string(),
                ))
            } else {
                api.model(model_id)
            };
            repo.get("config.json")?
        } else {
            path.push("config.json");
            path
        };

        let content = std::fs::read_to_string(filename)?;
        let config: RawConfig = serde_json::from_str(&content)?;

        if config.model_type == Some("gemma2".to_string()) {
            tracing::info!("Forcing flash decoding because of softcap usage");
            std::env::set_var("FLASH_DECODING", "1");
        }
        let config: Config = config.into();

        // Quantization usually means you're even more RAM constrained.
        let max_default = 4096;

        if let Some(max_position_embeddings) = config.max_position_embeddings {
            if max_position_embeddings > max_default {
                let max = max_position_embeddings;
                if args.max_input_tokens.is_none()
                    && args.max_total_tokens.is_none()
                    && args.max_batch_prefill_tokens.is_none()
                {
                    tracing::info!("Model supports up to {max} but tgi will now set its default to {max_default} instead. This is to save VRAM by refusing large prompts in order to allow more users on the same hardware. You can increase that size using `--max-batch-prefill-tokens={} --max-total-tokens={max} --max-input-tokens={}`.", max + 50, max - 1);
                }
                Ok(max_default)
            } else {
                Ok(max_position_embeddings)
            }
        } else {
            Err(Box::new(LauncherError::ArgumentValidation(
                "no max defined".to_string(),
            )))
        }
    };
    let max_position_embeddings: usize = get_max_position_embeddings().unwrap_or(4096);

    let max_input_tokens = {
        match (args.max_input_tokens, args.max_input_length) {
            (Some(max_input_tokens), Some(max_input_length)) => {
                return Err(LauncherError::ArgumentValidation(
                    format!("Both `max_input_tokens` ({max_input_tokens}) and `max_input_length` ({max_input_length}) are set. Please define only `max_input_tokens` as `max_input_length is deprecated for naming consistency.",
                )));
            }
            (Some(max_input_tokens), None) | (None, Some(max_input_tokens)) => max_input_tokens,
            (None, None) => {
                let value = max_position_embeddings - 1;
                tracing::info!("Default `max_input_tokens` to {value}");
                value
            }
        }
    };
    let max_total_tokens = {
        match args.max_total_tokens {
            Some(max_total_tokens) => max_total_tokens,
            None => {
                let value = max_position_embeddings;
                tracing::info!("Default `max_total_tokens` to {value}");
                value
            }
        }
    };
    let max_batch_prefill_tokens = {
        match args.max_batch_prefill_tokens {
            Some(max_batch_prefill_tokens) => max_batch_prefill_tokens,
            None => {
                let value: u32 = if let Some(max_batch_size) = args.max_batch_size {
                    max_batch_size * max_input_tokens
                } else {
                    // Add some headroom to account for potential block_size alignment issues.
                    max_input_tokens + 50
                } as u32;
                tracing::info!("Default `max_batch_prefill_tokens` to {value}");
                value
            }
        }
    };

    // Validate args
    if max_input_tokens >= max_total_tokens {
        return Err(LauncherError::ArgumentValidation(
            "`max_input_tokens must be < `max_total_tokens`".to_string(),
        ));
    }
    if max_input_tokens as u32 > max_batch_prefill_tokens {
        return Err(LauncherError::ArgumentValidation(format!(
            "`max_batch_prefill_tokens` must be >= `max_input_tokens`. Given: {} and {}",
            max_batch_prefill_tokens, max_input_tokens
        )));
    }

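    // Decide which CUDA graph batch sizes to use: honor an explicit user setting,
    // disable them for bitsandbytes quantization, otherwise fall back to the defaults.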
    let cuda_graphs = match (&args.cuda_graphs, &args.quantize) {
        (Some(cuda_graphs), _) => cuda_graphs.iter().cloned().filter(|&c| c > 0).collect(),
        #[allow(deprecated)]
        (
            None,
            Some(
                Quantization::Bitsandbytes
                | Quantization::BitsandbytesNF4
                | Quantization::BitsandbytesFP4,
            ),
        ) => {
            tracing::info!("Bitsandbytes doesn't work with cuda graphs, deactivating them");
            vec![]
        }
        _ => {
            let cuda_graphs = vec![1, 2, 4, 8, 16, 32];
            tracing::info!("Using default cuda graphs {cuda_graphs:?}");
            cuda_graphs
        }
    };

    if args.validation_workers == 0 {
        return Err(LauncherError::ArgumentValidation(
            "`validation_workers` must be > 0".to_string(),
        ));
    }
    if args.trust_remote_code {
        tracing::warn!(
            "`trust_remote_code` is set. Trusting that model `{}` do not contain malicious code.",
            args.model_id
        );
    }

    let num_shard = find_num_shards(args.sharded, args.num_shard)?;
    if num_shard > 1 {
        if matches!(args.quantize, Some(Quantization::Exl2)) {
            return Err(LauncherError::ArgumentValidation(
                "Sharding is currently not supported with `exl2` quantization".into(),
            ));
        }
        tracing::info!("Sharding model on {num_shard} processes");
    }

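    // When a total batch token budget is set, both the prefill budget and a single
    // request's total tokens must fit inside it.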
    if let Some(ref max_batch_total_tokens) = args.max_batch_total_tokens {
        if max_batch_prefill_tokens > *max_batch_total_tokens {
            return Err(LauncherError::ArgumentValidation(format!(
                "`max_batch_prefill_tokens` must be <= `max_batch_total_tokens`. Given: {} and {}",
                max_batch_prefill_tokens, max_batch_total_tokens
            )));
        }
        if max_total_tokens as u32 > *max_batch_total_tokens {
            return Err(LauncherError::ArgumentValidation(format!(
                "`max_total_tokens` must be <= `max_batch_total_tokens`. Given: {} and {}",
                max_total_tokens, max_batch_total_tokens
            )));
        }
    }

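    // ngrok tunneling requires both an auth token and an edge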
    if args.ngrok {
        if args.ngrok_authtoken.is_none() {
            return Err(LauncherError::ArgumentValidation(
                "`ngrok-authtoken` must be set when using ngrok tunneling".to_string(),
            ));
        }

        if args.ngrok_edge.is_none() {
            return Err(LauncherError::ArgumentValidation(
                "`ngrok-edge` must be set when using ngrok tunneling".to_string(),
            ));
        }
    }

    // Signal handler
    let running = Arc::new(AtomicBool::new(true));
    let r = running.clone();
    ctrlc::set_handler(move || {
        r.store(false, Ordering::SeqCst);
    })
    .expect("Error setting Ctrl-C handler");

    // Download and convert model weights
    download_convert_model(
        &args.model_id,
        args.revision.as_deref(),
        args.trust_remote_code,
        args.huggingface_hub_cache.as_deref(),
        args.weights_cache_override.as_deref(),
        running.clone(),
    )?;

    // Download and convert lora adapters if any
    if let Some(lora_adapters) = &args.lora_adapters {
        for adapter in lora_adapters.split(',') {
            // skip download if a path is provided
            if adapter.contains('=') {
                continue;
            }
            download_convert_model(
                adapter,
                None,
                args.trust_remote_code,
                args.huggingface_hub_cache.as_deref(),
                args.weights_cache_override.as_deref(),
                running.clone(),
            )?;
        }
    }

    if !running.load(Ordering::SeqCst) {
        // Launcher was asked to stop
        return Ok(());
    }

    // Shared shutdown bool
    let shutdown = Arc::new(AtomicBool::new(false));
    // Shared shutdown channel
    // When shutting down, the main thread will wait for all senders to be dropped
    let (shutdown_sender, shutdown_receiver) = mpsc::channel();

    // Shared channel to track shard status
    let (status_sender, status_receiver) = mpsc::channel();

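    // Start one shard process per rank; their status is reported through the channels above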
    spawn_shards(
        num_shard,
        &args,
        cuda_graphs,
        max_total_tokens,
        max_input_tokens,
        max_log_level,
        shutdown.clone(),
        &shutdown_receiver,
        shutdown_sender,
        &status_receiver,
        status_sender,
        running.clone(),
    )?;

    // We might have received a termination signal
    if !running.load(Ordering::SeqCst) {
        shutdown_shards(shutdown, &shutdown_receiver);
        return Ok(());
    }

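    // Start the router webserver; if it fails to start, take the shards down before returning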
    let mut webserver = spawn_webserver(
        num_shard,
        args,
        max_input_tokens,
        max_total_tokens,
        max_batch_prefill_tokens,
        shutdown.clone(),
        &shutdown_receiver,
    )
    .map_err(|err| {
        shutdown_shards(shutdown.clone(), &shutdown_receiver);
        err
    })?;

    // Default exit code
    let mut exit_code = Ok(());

    while running.load(Ordering::SeqCst) {
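        // Poll shard status and the webserver process; exit the loop on failure or shutdown signal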
        if let Ok(ShardStatus::Failed(rank)) = status_receiver.try_recv() {
            tracing::error!("Shard {rank} crashed");
            exit_code = Err(LauncherError::ShardFailed);
            break;
        };

        match webserver.try_wait().unwrap() {
            Some(_) => {
                tracing::error!("Webserver Crashed");
                shutdown_shards(shutdown, &shutdown_receiver);
                return Err(LauncherError::WebserverFailed);
            }
            None => {
                sleep(Duration::from_millis(100));
            }
        };
    }

    // Graceful termination
    terminate("webserver", webserver, Duration::from_secs(90)).unwrap();
    shutdown_shards(shutdown, &shutdown_receiver);

    exit_code
}