use clap::{Parser, ValueEnum};
use nix::sys::signal::{self, Signal};
use nix::unistd::Pid;
use serde::Deserialize;
use std::env;
use std::ffi::OsString;
use std::io::{BufRead, BufReader, Lines, Read};
use std::os::unix::process::{CommandExt, ExitStatusExt};
use std::path::Path;
use std::process::{Child, Command, ExitStatus, Stdio};
use std::sync::atomic::{AtomicBool, Ordering};
use std::sync::mpsc::TryRecvError;
use std::sync::{mpsc, Arc};
use std::thread;
use std::thread::sleep;
use std::time::{Duration, Instant};
use std::{fs, io};
use tracing_subscriber::EnvFilter;

mod env_runtime;

#[derive(Clone, Copy, Debug, ValueEnum)]
enum Quantization {
    Bitsandbytes,
    BitsandbytesNF4,
    BitsandbytesFP4,
    Gptq,
}

impl std::fmt::Display for Quantization {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        // To keep in sync with `server`.
        match self {
            Quantization::Bitsandbytes => {
                write!(f, "bitsandbytes")
            }
            Quantization::BitsandbytesNF4 => {
                write!(f, "bitsandbytes-nf4")
            }
            Quantization::BitsandbytesFP4 => {
                write!(f, "bitsandbytes-fp4")
            }
            Quantization::Gptq => {
                write!(f, "gptq")
            }
        }
    }
}

#[derive(Clone, Copy, Debug, ValueEnum)]
enum Dtype {
    Float16,
    #[clap(name = "bfloat16")]
    BFloat16,
}

impl std::fmt::Display for Dtype {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        // To keep in sync with `server`.
        match self {
            Dtype::Float16 => {
                write!(f, "float16")
            }
            Dtype::BFloat16 => {
                write!(f, "bfloat16")
            }
        }
    }
}

#[derive(Clone, Copy, Debug, ValueEnum)]
enum RopeScaling {
    Linear,
    Dynamic,
}

impl std::fmt::Display for RopeScaling {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        // To keep in sync with `server`.
        match self {
            RopeScaling::Linear => {
                write!(f, "linear")
            }
            RopeScaling::Dynamic => {
                write!(f, "dynamic")
            }
        }
    }
}

/// App Configuration
#[derive(Parser, Debug)]
#[clap(author, version, about, long_about = None)]
struct Args {
    /// The name of the model to load.
    /// Can be a MODEL_ID as listed on <https://hf.co/models> like
    /// `gpt2` or `OpenAssistant/oasst-sft-1-pythia-12b`.
    /// Or it can be a local directory containing the necessary files
    /// as saved by `save_pretrained(...)` methods of transformers
    #[clap(default_value = "bigscience/bloom-560m", long, env)]
    model_id: String,

    /// The actual revision of the model if you're referring to a model
    /// on the hub. You can use a specific commit id or a branch like `refs/pr/2`.
    #[clap(long, env)]
    revision: Option<String>,

    /// The number of tokenizer workers used for payload validation and truncation inside the
    /// router.
    #[clap(default_value = "2", long, env)]
    validation_workers: usize,

    /// Whether to shard the model across multiple GPUs.
    /// By default text-generation-inference will use all available GPUs to run
    /// the model. Setting it to `false` deactivates `num_shard`.
    #[clap(long, env)]
    sharded: Option<bool>,

    /// The number of shards to use if you don't want to use all GPUs on a given machine.
    /// You can use `CUDA_VISIBLE_DEVICES=0,1 text-generation-launcher... --num_shard 2`
    /// and `CUDA_VISIBLE_DEVICES=2,3 text-generation-launcher... --num_shard 2` to
    /// launch 2 copies with 2 shards each on a given machine with 4 GPUs for instance.
    #[clap(long, env)]
    num_shard: Option<usize>,

    /// Whether you want the model to be quantized. This will use `bitsandbytes` for
    /// quantization on the fly, or `gptq`. 4bit quantization is available through
    /// `bitsandbytes` by providing the `bitsandbytes-fp4` or `bitsandbytes-nf4` options.
    #[clap(long, env, value_enum)]
    quantize: Option<Quantization>,

    /// The dtype to be forced upon the model. This option cannot be used with `--quantize`.
    #[clap(long, env, value_enum)]
    dtype: Option<Dtype>,

    /// Whether you want to execute hub modelling code. Explicitly passing a `revision` is
    /// encouraged when loading a model with custom code to ensure no malicious code has been
    /// contributed in a newer revision.
    #[clap(long, env, value_enum)]
    trust_remote_code: bool,

    /// The maximum amount of concurrent requests for this particular deployment.
    /// Having a low limit will refuse client requests instead of having them
    /// wait for too long, and is usually good to handle backpressure correctly.
    #[clap(default_value = "128", long, env)]
    max_concurrent_requests: usize,

    /// This is the maximum allowed value for clients to set `best_of`.
    /// Best of makes `n` generations at the same time, and returns the best
    /// in terms of overall log probability over the entire generated sequence.
    #[clap(default_value = "2", long, env)]
    max_best_of: usize,

    /// This is the maximum allowed value for clients to set `stop_sequences`.
    /// Stop sequences are used to allow the model to stop on more than just
    /// the EOS token, and enable more complex "prompting" where users can preprompt
    /// the model in a specific way and define their "own" stop token aligned with
    /// their prompt.
    #[clap(default_value = "4", long, env)]
    max_stop_sequences: usize,

    /// This is the maximum allowed value for clients to set `top_n_tokens`.
    /// `top_n_tokens` is used to return information about the `n` most likely
    /// tokens at each generation step, instead of just the sampled token. This
    /// information can be used for downstream tasks like classification or
    /// ranking.
    #[clap(default_value = "5", long, env)]
    max_top_n_tokens: u32,

    /// This is the maximum allowed input length (expressed in number of tokens)
    /// for users. The larger this value, the longer the prompts users can send, which
    /// can impact the overall memory required to handle the load.
    /// Please note that some models have a finite range of sequences they can handle.
    #[clap(default_value = "1024", long, env)]
    max_input_length: usize,

    /// This is the most important value to set as it defines the "memory budget"
    /// of running clients' requests.
    /// Clients will send input sequences and ask to generate `max_new_tokens`
    /// on top. With a value of `1512`, users can send either a prompt of
    /// `1000` and ask for `512` new tokens, or send a prompt of `1` and ask for
    /// `1511` max_new_tokens.
    /// The larger this value, the more memory each request will take in your RAM
    /// and the less effective batching can be.
    #[clap(default_value = "2048", long, env)]
    max_total_tokens: usize,

    /// This represents the ratio of waiting queries vs running queries where
    /// you want to start considering pausing the running queries to include the waiting
    /// ones into the same batch.
    /// `waiting_served_ratio=1.2` means that when 12 queries are waiting and there are
    /// only 10 queries left in the current batch, we check if we can fit those 12
    /// waiting queries into the batching strategy, and if yes, then batching happens,
    /// delaying the 10 running queries by a `prefill` run.
    ///
    /// This setting is only applied if there is room in the batch
    /// as defined by `max_batch_total_tokens`.
    #[clap(default_value = "1.2", long, env)]
    waiting_served_ratio: f32,

    /// Limits the number of tokens for the prefill operation.
    /// Since this operation takes the most memory and is compute bound, it is interesting
    /// to limit the number of requests that can be sent.
    #[clap(default_value = "4096", long, env)]
    max_batch_prefill_tokens: u32,

    /// **IMPORTANT** This is one critical control to allow maximum usage
    /// of the available hardware.
    ///
    /// This represents the total amount of potential tokens within a batch.
    /// When using padding (not recommended) this would be equivalent to
    /// `batch_size` * `max_total_tokens`.
    ///
    /// However in the non-padded (flash attention) version this can be much finer.
    ///
    /// For `max_batch_total_tokens=1000`, you could fit `10` queries of `total_tokens=100`
    /// or a single query of `1000` tokens.
    ///
    /// Overall this number should be the largest possible amount that fits the
    /// remaining memory (after the model is loaded). Since the actual memory overhead
    /// depends on other parameters like whether you're using quantization, flash attention
    /// or the model implementation, text-generation-inference cannot infer this number
    /// automatically.
    #[clap(long, env)]
    max_batch_total_tokens: Option<u32>,

    /// This setting defines how many tokens can be passed before forcing the waiting
    /// queries to be put on the batch (if the size of the batch allows for it).
    /// New queries require 1 `prefill` forward, which is different from `decode`
    /// and therefore you need to pause the running batch in order to run `prefill`
    /// to create the correct values for the waiting queries to be able to join the batch.
    ///
    /// With a value too small, queries will always "steal" the compute to run `prefill`
    /// and running queries will be delayed by a lot.
    ///
    /// With a value too big, waiting queries could wait for a very long time
    /// before being allowed a slot in the running batch. If your server is busy
    /// that means that requests that could run in ~2s on an empty server could
    /// end up running in ~20s because the query had to wait for 18s.
    ///
    /// This number is expressed in number of tokens to make it a bit more
    /// "model" agnostic, but what should really matter is the overall latency
    /// for end users.
    #[clap(default_value = "20", long, env)]
    max_waiting_tokens: usize,

    /// The IP address to listen on
    #[clap(default_value = "0.0.0.0", long, env)]
    hostname: String,

    /// The port to listen on.
    #[clap(default_value = "3000", long, short, env)]
    port: u16,

    /// The name of the socket for gRPC communication between the webserver
    /// and the shards.
    #[clap(default_value = "/tmp/text-generation-server", long, env)]
    shard_uds_path: String,

    /// The address the master shard will listen on. (setting used by torch distributed)
    #[clap(default_value = "localhost", long, env)]
    master_addr: String,

    /// The port the master shard will listen on. (setting used by torch distributed)
    #[clap(default_value = "29500", long, env)]
    master_port: usize,

    /// The location of the huggingface hub cache.
    /// Used to override the location if you want to provide a mounted disk for instance
    #[clap(long, env)]
    huggingface_hub_cache: Option<String>,

    /// The location of the weights cache.
    /// Used to override the location if you want to provide a mounted disk for instance
    #[clap(long, env)]
    weights_cache_override: Option<String>,

    /// For some models (like bloom), text-generation-inference implemented custom
    /// cuda kernels to speed up inference. Those kernels were only tested on A100.
    /// Use this flag to disable them if you're running on different hardware and
    /// encounter issues.
    #[clap(long, env)]
    disable_custom_kernels: bool,

    /// Limit the available CUDA memory.
    /// The allowed value equals the total visible memory multiplied by cuda-memory-fraction.
    #[clap(default_value = "1.0", long, env)]
    cuda_memory_fraction: f32,

    /// Rope scaling will only be used for RoPE models
    /// and allows rescaling the rotary position embeddings to accommodate
    /// larger prompts.
    ///
    /// Goes together with `rope_factor`.
    ///
    /// `--rope-factor 2.0` gives linear scaling with a factor of 2.0
    /// `--rope-scaling dynamic` gives dynamic scaling with a factor of 1.0
    /// `--rope-scaling linear` gives linear scaling with a factor of 1.0 (basically nothing
    /// will be changed)
    ///
    /// `--rope-scaling linear --rope-factor` fully describes the scaling you want
    #[clap(long, env)]
    rope_scaling: Option<RopeScaling>,

    /// Rope scaling will only be used for RoPE models.
    /// See `rope_scaling`.
    #[clap(long, env)]
    rope_factor: Option<f32>,

    /// Outputs the logs in JSON format (useful for telemetry)
    #[clap(long, env)]
    json_output: bool,

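    /// The OpenTelemetry endpoint to export traces to; forwarded to the shards
    /// and the router as `--otlp-endpoint`.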
    #[clap(long, env)]
    otlp_endpoint: Option<String>,

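    /// Origins allowed for CORS; each one is forwarded to the router as a
    /// `--cors-allow-origin` flag.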
    #[clap(long, env)]
    cors_allow_origin: Vec<String>,
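    /// Watermarking gamma; forwarded to the shards through the `WATERMARK_GAMMA` env var.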
    #[clap(long, env)]
    watermark_gamma: Option<f32>,
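    /// Watermarking delta; forwarded to the shards through the `WATERMARK_DELTA` env var.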
    #[clap(long, env)]
    watermark_delta: Option<f32>,

    /// Enable ngrok tunneling
    #[clap(long, env)]
    ngrok: bool,

    /// ngrok authentication token
    #[clap(long, env)]
    ngrok_authtoken: Option<String>,

    /// ngrok edge
    #[clap(long, env)]
    ngrok_edge: Option<String>,

    /// Display a lot of information about your runtime environment
    #[clap(long, short, action)]
    env: bool,
}

#[derive(Debug)]
enum ShardStatus {
    Ready,
    Failed(usize),
}

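/// Spawns a single `text-generation-server serve` child process for `rank`,
/// forwards its configuration through CLI args and env vars, and supervises it:
/// readiness (the shard UDS appearing) and failures are reported over
/// `status_sender`, and the child is killed when `shutdown` is set.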
#[allow(clippy::too_many_arguments)]
fn shard_manager(
    model_id: String,
    revision: Option<String>,
    quantize: Option<Quantization>,
    dtype: Option<Dtype>,
    trust_remote_code: bool,
    uds_path: String,
    rank: usize,
    world_size: usize,
    master_addr: String,
    master_port: usize,
    huggingface_hub_cache: Option<String>,
    weights_cache_override: Option<String>,
    disable_custom_kernels: bool,
    watermark_gamma: Option<f32>,
    watermark_delta: Option<f32>,
    cuda_memory_fraction: f32,
    rope_scaling: Option<RopeScaling>,
    rope_factor: Option<f32>,
    otlp_endpoint: Option<String>,
    status_sender: mpsc::Sender<ShardStatus>,
    shutdown: Arc<AtomicBool>,
    _shutdown_sender: mpsc::Sender<()>,
) {
    // Enter shard-manager tracing span
    let _span = tracing::span!(tracing::Level::INFO, "shard-manager", rank = rank).entered();

    // Get UDS path
    let uds_string = format!("{uds_path}-{rank}");
    let uds = Path::new(&uds_string);
    // Clean previous runs
    if uds.exists() {
        fs::remove_file(uds).unwrap();
    }

    // Process args
    let mut shard_args = vec![
        "serve".to_string(),
        model_id,
        "--uds-path".to_string(),
        uds_path,
        "--logger-level".to_string(),
        "INFO".to_string(),
        "--json-output".to_string(),
    ];

    // Activate trust remote code
    if trust_remote_code {
        shard_args.push("--trust-remote-code".to_string());
    }

    // Activate tensor parallelism
    if world_size > 1 {
        shard_args.push("--sharded".to_string());
    }

    if let Some(quantize) = quantize {
        shard_args.push("--quantize".to_string());
        shard_args.push(quantize.to_string())
    }

    if let Some(dtype) = dtype {
        shard_args.push("--dtype".to_string());
        shard_args.push(dtype.to_string())
    }

    // Model optional revision
    if let Some(revision) = revision {
        shard_args.push("--revision".to_string());
        shard_args.push(revision)
    }

    let rope = match (rope_scaling, rope_factor) {
        (None, None) => None,
        (Some(scaling), None) => Some((scaling, 1.0)),
        (Some(scaling), Some(factor)) => Some((scaling, factor)),
        (None, Some(factor)) => Some((RopeScaling::Linear, factor)),
    };
    // OpenTelemetry
    if let Some(otlp_endpoint) = otlp_endpoint {
        shard_args.push("--otlp-endpoint".to_string());
        shard_args.push(otlp_endpoint);
    }

    // Copy current process env
    let mut envs: Vec<(OsString, OsString)> = env::vars_os().collect();

    // Torch Distributed Env vars
    envs.push(("RANK".into(), rank.to_string().into()));
    envs.push(("WORLD_SIZE".into(), world_size.to_string().into()));
    envs.push(("MASTER_ADDR".into(), master_addr.into()));
    envs.push(("MASTER_PORT".into(), master_port.to_string().into()));
    envs.push(("NCCL_ASYNC_ERROR_HANDLING".into(), "1".into()));

    // CUDA memory fraction
    envs.push((
        "CUDA_MEMORY_FRACTION".into(),
        cuda_memory_fraction.to_string().into(),
    ));

    // Safetensors load fast
    envs.push(("SAFETENSORS_FAST_GPU".into(), "1".into()));

    // Enable hf transfer for insane download speeds
    let enable_hf_transfer = env::var("HF_HUB_ENABLE_HF_TRANSFER").unwrap_or("1".to_string());
    envs.push((
        "HF_HUB_ENABLE_HF_TRANSFER".into(),
        enable_hf_transfer.into(),
    ));

    // Parse Inference API token
    if let Ok(api_token) = env::var("HF_API_TOKEN") {
        envs.push(("HUGGING_FACE_HUB_TOKEN".into(), api_token.into()))
    };

    // Detect rope scaling.
    // Sent as env vars instead of CLI args to avoid bloating everything:
    // only RoPE models can use them, so passing this information around
    // for all models would complicate the code unnecessarily.
    if let Some((scaling, factor)) = rope {
        envs.push(("ROPE_SCALING".into(), scaling.to_string().into()));
        envs.push(("ROPE_FACTOR".into(), factor.to_string().into()));
    }

    // If huggingface_hub_cache is some, pass it to the shard
    // Useful when running inside a docker container
    if let Some(huggingface_hub_cache) = huggingface_hub_cache {
        envs.push(("HUGGINGFACE_HUB_CACHE".into(), huggingface_hub_cache.into()));
    };

    // If weights_cache_override is some, pass it to the shard
    // Useful when running inside a HuggingFace Inference Endpoint
    if let Some(weights_cache_override) = weights_cache_override {
        envs.push((
            "WEIGHTS_CACHE_OVERRIDE".into(),
            weights_cache_override.into(),
        ));
    };

    // If disable_custom_kernels is true, pass it to the shard as an env var
    if disable_custom_kernels {
        envs.push(("DISABLE_CUSTOM_KERNELS".into(), "True".into()))
    }

    // Watermark Gamma
    if let Some(watermark_gamma) = watermark_gamma {
        envs.push(("WATERMARK_GAMMA".into(), watermark_gamma.to_string().into()))
    }

    // Watermark Delta
    if let Some(watermark_delta) = watermark_delta {
        envs.push(("WATERMARK_DELTA".into(), watermark_delta.to_string().into()))
    }

    // Start process
    tracing::info!("Starting shard");
    let mut p = match Command::new("text-generation-server")
        .args(shard_args)
        .envs(envs)
        .stdout(Stdio::piped())
        .stderr(Stdio::piped())
        .process_group(0)
        .spawn()
    {
        Ok(p) => p,
        Err(err) => {
            if err.kind() == io::ErrorKind::NotFound {
                tracing::error!("text-generation-server not found in PATH");
                tracing::error!("Please install it with `make install-server`")
            } else {
                tracing::error!("{}", err);
            }

            status_sender.send(ShardStatus::Failed(rank)).unwrap();
            return;
        }
    };

    // Redirect STDOUT to the console
    let shard_stdout_reader = BufReader::new(p.stdout.take().unwrap());
    let shard_stderr_reader = BufReader::new(p.stderr.take().unwrap());

    // stdout tracing thread
    thread::spawn(move || {
        log_lines(shard_stdout_reader.lines());
    });

    let mut ready = false;
    let start_time = Instant::now();
    let mut wait_time = Instant::now();
    loop {
        // Process exited
        if let Some(exit_status) = p.try_wait().unwrap() {
            // We read stderr in another thread as it seems that lines() can block in some cases
            let (err_sender, err_receiver) = mpsc::channel();
            thread::spawn(move || {
                for line in shard_stderr_reader.lines().flatten() {
                    err_sender.send(line).unwrap_or(());
                }
            });
            let mut err = String::new();
            while let Ok(line) = err_receiver.recv_timeout(Duration::from_millis(10)) {
                err = err + "\n" + &line;
            }

            tracing::error!("Shard complete standard error output:\n{err}");

            if let Some(signal) = exit_status.signal() {
                tracing::error!("Shard process was signaled to shutdown with signal {signal}");
            }

            status_sender.send(ShardStatus::Failed(rank)).unwrap();
            return;
        }

        // We received a shutdown signal
        if shutdown.load(Ordering::SeqCst) {
            p.kill().unwrap();
            let _ = p.wait();
            tracing::info!("Shard terminated");
            return;
        }

        // Shard is ready
        if uds.exists() && !ready {
            tracing::info!("Shard ready in {:?}", start_time.elapsed());
            status_sender.send(ShardStatus::Ready).unwrap();
            ready = true;
        } else if !ready && wait_time.elapsed() > Duration::from_secs(10) {
            tracing::info!("Waiting for shard to be ready...");
            wait_time = Instant::now();
        }
        sleep(Duration::from_millis(100));
    }
}

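/// Signals every shard manager thread to stop and blocks until all
/// `shutdown_sender` clones have been dropped, i.e. until all shards exited.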
fn shutdown_shards(shutdown: Arc<AtomicBool>, shutdown_receiver: &mpsc::Receiver<()>) {
    tracing::info!("Shutting down shards");
    // Update shutdown value to true
    // This will be picked up by the shard manager
    shutdown.store(true, Ordering::SeqCst);

    // Wait for shards to shutdown
    // This will block till all shutdown_sender are dropped
    let _ = shutdown_receiver.recv();
}

/// Infer the number of GPUs from `CUDA_VISIBLE_DEVICES` or `NVIDIA_VISIBLE_DEVICES`.
fn num_cuda_devices() -> Option<usize> {
    let devices = match env::var("CUDA_VISIBLE_DEVICES") {
        Ok(devices) => devices,
        Err(_) => env::var("NVIDIA_VISIBLE_DEVICES").ok()?,
    };
    let n_devices = devices.split(',').count();
    Some(n_devices)
}

// Minimal mirrors of the JSON log records emitted by `text-generation-server`
// when launched with `--json-output`: each line has a `text` and a `record.level.name`.
#[derive(Deserialize)]
#[serde(rename_all = "UPPERCASE")]
enum PythonLogLevelEnum {
    Trace,
    Debug,
    Info,
    Success,
    Warning,
    Error,
    Critical,
}

#[derive(Deserialize)]
struct PythonLogLevel {
    name: PythonLogLevelEnum,
}

#[derive(Deserialize)]
struct PythonLogRecord {
    level: PythonLogLevel,
}

#[derive(Deserialize)]
struct PythonLogMessage {
    text: String,
    record: PythonLogRecord,
}

impl PythonLogMessage {
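    // Re-emit a Python log record at the closest tracing level. SUCCESS and
    // CRITICAL have no direct `tracing` equivalent and map to `info` and
    // `error` respectively.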
    fn trace(&self) {
        match self.record.level.name {
            PythonLogLevelEnum::Trace => tracing::trace!("{}", self.text),
            PythonLogLevelEnum::Debug => tracing::debug!("{}", self.text),
            PythonLogLevelEnum::Info => tracing::info!("{}", self.text),
            PythonLogLevelEnum::Success => tracing::info!("{}", self.text),
            PythonLogLevelEnum::Warning => tracing::warn!("{}", self.text),
            PythonLogLevelEnum::Error => tracing::error!("{}", self.text),
            PythonLogLevelEnum::Critical => tracing::error!("{}", self.text),
        }
    }
}

impl TryFrom<&String> for PythonLogMessage {
    type Error = serde_json::Error;

    fn try_from(value: &String) -> Result<Self, Self::Error> {
        serde_json::from_str::<Self>(value)
    }
}

/// Forward child-process output: lines that parse as Python JSON log records
/// are re-emitted at the matching tracing level, anything else at debug level.
fn log_lines<S: Sized + BufRead>(lines: Lines<S>) {
    for line in lines.flatten() {
        match PythonLogMessage::try_from(&line) {
            Ok(log) => log.trace(),
            Err(_) => tracing::debug!("{line}"),
        }
    }
}

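/// Resolves the number of shards to spawn from `--sharded`/`--num-shard`,
/// falling back to the number of visible CUDA devices.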
fn find_num_shards(
    sharded: Option<bool>,
    num_shard: Option<usize>,
) -> Result<usize, LauncherError> {
    // get the number of shards given `sharded` and `num_shard`
    let num_shard = match (sharded, num_shard) {
        (Some(true), None) => {
            // try to default to the number of available GPUs
            tracing::info!("Parsing num_shard from CUDA_VISIBLE_DEVICES/NVIDIA_VISIBLE_DEVICES");
            let n_devices = num_cuda_devices()
                .expect("--num-shard and CUDA_VISIBLE_DEVICES/NVIDIA_VISIBLE_DEVICES are not set");
            if n_devices <= 1 {
                return Err(LauncherError::NotEnoughCUDADevices(format!(
                    "`sharded` is true but only found {n_devices} CUDA devices"
                )));
            }
            n_devices
        }
        (Some(true), Some(num_shard)) => {
            // we can't have only one shard while sharded
            if num_shard <= 1 {
                return Err(LauncherError::ArgumentValidation(
                    "`sharded` is true but `num_shard` <= 1".to_string(),
                ));
            }
            num_shard
        }
        (Some(false), Some(num_shard)) => num_shard,
        (Some(false), None) => 1,
        (None, None) => num_cuda_devices().unwrap_or(1),
        (None, Some(num_shard)) => num_shard,
    };
    if num_shard < 1 {
        return Err(LauncherError::ArgumentValidation(
            "`num_shard` cannot be < 1".to_string(),
        ));
    }
    Ok(num_shard)
}

#[derive(Debug)]
enum LauncherError {
    ArgumentValidation(String),
    NotEnoughCUDADevices(String),
    DownloadError,
    ShardCannotStart,
    ShardDisconnected,
    ShardFailed,
    WebserverFailed,
    WebserverCannotStart,
}

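/// Runs `text-generation-server download-weights` as a child process to fetch
/// the model weights (as `.safetensors`) before any shard is started.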
fn download_convert_model(args: &Args, running: Arc<AtomicBool>) -> Result<(), LauncherError> {
    // Enter download tracing span
    let _span = tracing::span!(tracing::Level::INFO, "download").entered();

    let mut download_args = vec![
        "download-weights".to_string(),
        args.model_id.to_string(),
        "--extension".to_string(),
        ".safetensors".to_string(),
        "--logger-level".to_string(),
        "INFO".to_string(),
        "--json-output".to_string(),
    ];

    // Model optional revision
    if let Some(revision) = &args.revision {
        download_args.push("--revision".to_string());
        download_args.push(revision.to_string())
    }

    // Trust remote code for automatic peft fusion
    if args.trust_remote_code {
        download_args.push("--trust-remote-code".to_string());
    }

    // Copy current process env
    let mut envs: Vec<(OsString, OsString)> = env::vars_os().collect();

    // If huggingface_hub_cache is set, pass it to the download process
    // Useful when running inside a docker container
    if let Some(ref huggingface_hub_cache) = args.huggingface_hub_cache {
        envs.push(("HUGGINGFACE_HUB_CACHE".into(), huggingface_hub_cache.into()));
    };

    // Enable hf transfer for insane download speeds
    let enable_hf_transfer = env::var("HF_HUB_ENABLE_HF_TRANSFER").unwrap_or("1".to_string());
    envs.push((
        "HF_HUB_ENABLE_HF_TRANSFER".into(),
        enable_hf_transfer.into(),
    ));

    // Parse Inference API token
    if let Ok(api_token) = env::var("HF_API_TOKEN") {
        envs.push(("HUGGING_FACE_HUB_TOKEN".into(), api_token.into()))
    };

    // If args.weights_cache_override is some, pass it to the download process
    // Useful when running inside a HuggingFace Inference Endpoint
    if let Some(weights_cache_override) = &args.weights_cache_override {
        envs.push((
            "WEIGHTS_CACHE_OVERRIDE".into(),
            weights_cache_override.into(),
        ));
    };

    // Start process
    tracing::info!("Starting download process.");
    let mut download_process = match Command::new("text-generation-server")
        .args(download_args)
        .envs(envs)
        .stdout(Stdio::piped())
        .stderr(Stdio::piped())
        .process_group(0)
        .spawn()
    {
        Ok(p) => p,
        Err(err) => {
            if err.kind() == io::ErrorKind::NotFound {
                tracing::error!("text-generation-server not found in PATH");
                tracing::error!("Please install it with `make install-server`")
            } else {
                tracing::error!("{}", err);
            }

            return Err(LauncherError::DownloadError);
        }
    };

    // Redirect STDOUT to the console
    let download_stdout = download_process.stdout.take().unwrap();
    let stdout = BufReader::new(download_stdout);

    thread::spawn(move || {
        log_lines(stdout.lines());
    });

    loop {
        if let Some(status) = download_process.try_wait().unwrap() {
            if status.success() {
                tracing::info!("Successfully downloaded weights.");
                break;
            }

            let mut err = String::new();
            download_process
                .stderr
                .take()
                .unwrap()
                .read_to_string(&mut err)
                .unwrap();
            if let Some(signal) = status.signal() {
                tracing::error!(
                    "Download process was signaled to shutdown with signal {signal}: {err}"
                );
            } else {
                tracing::error!("Download encountered an error: {err}");
            }

            return Err(LauncherError::DownloadError);
        }
        if !running.load(Ordering::SeqCst) {
            terminate("download", download_process, Duration::from_secs(10)).unwrap();
            return Ok(());
        }
        sleep(Duration::from_millis(100));
    }
    Ok(())
}

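/// Starts one `shard_manager` thread per shard and blocks until every shard
/// reports `Ready`, a shard fails, or the launcher is interrupted.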
#[allow(clippy::too_many_arguments)]
fn spawn_shards(
    num_shard: usize,
    args: &Args,
    shutdown: Arc<AtomicBool>,
    shutdown_receiver: &mpsc::Receiver<()>,
    shutdown_sender: mpsc::Sender<()>,
    status_receiver: &mpsc::Receiver<ShardStatus>,
    status_sender: mpsc::Sender<ShardStatus>,
    running: Arc<AtomicBool>,
) -> Result<(), LauncherError> {
    // Start shard processes
    for rank in 0..num_shard {
        let model_id = args.model_id.clone();
        let revision = args.revision.clone();
        let uds_path = args.shard_uds_path.clone();
        let master_addr = args.master_addr.clone();
        let huggingface_hub_cache = args.huggingface_hub_cache.clone();
        let weights_cache_override = args.weights_cache_override.clone();
        let status_sender = status_sender.clone();
        let shutdown = shutdown.clone();
        let shutdown_sender = shutdown_sender.clone();
        let otlp_endpoint = args.otlp_endpoint.clone();
        let quantize = args.quantize;
        let dtype = args.dtype;
        let trust_remote_code = args.trust_remote_code;
        let master_port = args.master_port;
        let disable_custom_kernels = args.disable_custom_kernels;
        let watermark_gamma = args.watermark_gamma;
        let watermark_delta = args.watermark_delta;
        let cuda_memory_fraction = args.cuda_memory_fraction;
        let rope_scaling = args.rope_scaling;
        let rope_factor = args.rope_factor;
        thread::spawn(move || {
            shard_manager(
                model_id,
                revision,
                quantize,
                dtype,
                trust_remote_code,
                uds_path,
                rank,
                num_shard,
                master_addr,
                master_port,
                huggingface_hub_cache,
                weights_cache_override,
                disable_custom_kernels,
                watermark_gamma,
                watermark_delta,
                cuda_memory_fraction,
                rope_scaling,
                rope_factor,
                otlp_endpoint,
                status_sender,
                shutdown,
                shutdown_sender,
            )
        });
    }
    drop(shutdown_sender);

    // Wait for shard to start
    let mut shard_ready = 0;
    while running.load(Ordering::SeqCst) {
        match status_receiver.try_recv() {
            Ok(ShardStatus::Ready) => {
                shard_ready += 1;
                if shard_ready == num_shard {
                    break;
                }
            }
            Err(TryRecvError::Empty) => {
                sleep(Duration::from_millis(100));
            }
            Ok(ShardStatus::Failed(rank)) => {
                tracing::error!("Shard {rank} failed to start");
                shutdown_shards(shutdown, shutdown_receiver);
                return Err(LauncherError::ShardCannotStart);
            }
            Err(TryRecvError::Disconnected) => {
                tracing::error!("Shard status channel disconnected");
                shutdown_shards(shutdown, shutdown_receiver);
                return Err(LauncherError::ShardDisconnected);
            }
        }
    }
    Ok(())
}

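/// Builds the `text-generation-router` invocation from the launcher args and
/// spawns it, forwarding its stdout/stderr to the console.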
fn spawn_webserver(
    args: Args,
    shutdown: Arc<AtomicBool>,
    shutdown_receiver: &mpsc::Receiver<()>,
) -> Result<Child, LauncherError> {
    // All shard started
    // Start webserver
    tracing::info!("Starting Webserver");
    let mut router_args = vec![
        "--max-concurrent-requests".to_string(),
        args.max_concurrent_requests.to_string(),
        "--max-best-of".to_string(),
        args.max_best_of.to_string(),
        "--max-stop-sequences".to_string(),
        args.max_stop_sequences.to_string(),
        "--max-top-n-tokens".to_string(),
        args.max_top_n_tokens.to_string(),
        "--max-input-length".to_string(),
        args.max_input_length.to_string(),
        "--max-total-tokens".to_string(),
        args.max_total_tokens.to_string(),
        "--max-batch-prefill-tokens".to_string(),
        args.max_batch_prefill_tokens.to_string(),
        "--waiting-served-ratio".to_string(),
        args.waiting_served_ratio.to_string(),
        "--max-waiting-tokens".to_string(),
        args.max_waiting_tokens.to_string(),
        "--validation-workers".to_string(),
        args.validation_workers.to_string(),
        "--hostname".to_string(),
        args.hostname.to_string(),
        "--port".to_string(),
        args.port.to_string(),
        "--master-shard-uds-path".to_string(),
        format!("{}-0", args.shard_uds_path),
        "--tokenizer-name".to_string(),
        args.model_id,
    ];

    // Model optional max batch total tokens
    if let Some(max_batch_total_tokens) = args.max_batch_total_tokens {
        router_args.push("--max-batch-total-tokens".to_string());
        router_args.push(max_batch_total_tokens.to_string());
    }

    // Model optional revision
    if let Some(ref revision) = args.revision {
        router_args.push("--revision".to_string());
        router_args.push(revision.to_string())
    }

    if args.json_output {
        router_args.push("--json-output".to_string());
    }

    // OpenTelemetry
    if let Some(otlp_endpoint) = args.otlp_endpoint {
        router_args.push("--otlp-endpoint".to_string());
        router_args.push(otlp_endpoint);
    }

    // CORS origins
    for origin in args.cors_allow_origin.into_iter() {
        router_args.push("--cors-allow-origin".to_string());
        router_args.push(origin);
    }

    // Ngrok
    if args.ngrok {
        router_args.push("--ngrok".to_string());
        router_args.push("--ngrok-authtoken".to_string());
        router_args.push(args.ngrok_authtoken.unwrap());
        router_args.push("--ngrok-edge".to_string());
        router_args.push(args.ngrok_edge.unwrap());
    }

    // Copy current process env
    let mut envs: Vec<(OsString, OsString)> = env::vars_os().collect();

    // Parse Inference API token
    if let Ok(api_token) = env::var("HF_API_TOKEN") {
        envs.push(("HUGGING_FACE_HUB_TOKEN".into(), api_token.into()))
    };

    let mut webserver = match Command::new("text-generation-router")
        .args(router_args)
        .envs(envs)
        .stdout(Stdio::piped())
        .stderr(Stdio::piped())
        .process_group(0)
        .spawn()
    {
        Ok(p) => p,
        Err(err) => {
            tracing::error!("Failed to start webserver: {}", err);
            if err.kind() == io::ErrorKind::NotFound {
                tracing::error!("text-generation-router not found in PATH");
                tracing::error!("Please install it with `make install-router`")
            } else {
                tracing::error!("{}", err);
            }

            shutdown_shards(shutdown, shutdown_receiver);
            return Err(LauncherError::WebserverCannotStart);
1031
        }
    };

    // Redirect STDOUT and STDERR to the console
    let webserver_stdout = webserver.stdout.take().unwrap();
    let webserver_stderr = webserver.stderr.take().unwrap();

    thread::spawn(move || {
        let stdout = BufReader::new(webserver_stdout);
        let stderr = BufReader::new(webserver_stderr);
        for line in stdout.lines() {
            println!("{}", line.unwrap());
        }
        for line in stderr.lines() {
            println!("{}", line.unwrap());
        }
    });
    Ok(webserver)
}

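/// Gracefully terminates a child process: sends SIGTERM, waits up to `timeout`
/// for it to exit, then kills it if it is still running.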
fn terminate(process_name: &str, mut process: Child, timeout: Duration) -> io::Result<ExitStatus> {
    tracing::info!("Terminating {process_name}");

    let terminate_time = Instant::now();
    signal::kill(Pid::from_raw(process.id() as i32), Signal::SIGTERM).unwrap();

    tracing::info!("Waiting for {process_name} to gracefully shutdown");

    while terminate_time.elapsed() < timeout {
        if let Some(status) = process.try_wait()? {
            tracing::info!("{process_name} terminated");
            return Ok(status);
        }
        sleep(Duration::from_millis(100));
    }

    tracing::info!("Killing {process_name}");

    process.kill()?;
    let exit_status = process.wait()?;

    tracing::info!("{process_name} killed");
    Ok(exit_status)
}

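/// Launcher entrypoint: parses args, downloads the weights, spawns the shards
/// and the webserver, then supervises them until exit or shutdown.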
fn main() -> Result<(), LauncherError> {
    // Pattern match configuration
    let args: Args = Args::parse();

    // Filter events with LOG_LEVEL
    let env_filter =
        EnvFilter::try_from_env("LOG_LEVEL").unwrap_or_else(|_| EnvFilter::new("info"));

    if args.json_output {
        tracing_subscriber::fmt()
            .with_env_filter(env_filter)
            .json()
            .init();
    } else {
        tracing_subscriber::fmt()
            .with_env_filter(env_filter)
            .compact()
            .init();
    }

    if args.env {
        let env_runtime = env_runtime::Env::new();
        tracing::info!("{}", env_runtime);
    }

    tracing::info!("{:?}", args);

    // Validate args
    if args.max_input_length >= args.max_total_tokens {
        return Err(LauncherError::ArgumentValidation(
            "`max_input_length` must be < `max_total_tokens`".to_string(),
        ));
    }
    if args.max_input_length as u32 > args.max_batch_prefill_tokens {
        return Err(LauncherError::ArgumentValidation(format!(
            "`max_batch_prefill_tokens` must be >= `max_input_length`. Given: {} and {}",
            args.max_batch_prefill_tokens, args.max_input_length
        )));
    }

    if args.validation_workers == 0 {
        return Err(LauncherError::ArgumentValidation(
            "`validation_workers` must be > 0".to_string(),
        ));
    }
    if args.trust_remote_code {
        tracing::warn!(
            "`trust_remote_code` is set. Trusting that model `{}` does not contain malicious code.",
            args.model_id
        );
    }

    let num_shard = find_num_shards(args.sharded, args.num_shard)?;
    if num_shard > 1 {
        tracing::info!("Sharding model on {num_shard} processes");
    }

    if let Some(ref max_batch_total_tokens) = args.max_batch_total_tokens {
        if args.max_batch_prefill_tokens > *max_batch_total_tokens {
            return Err(LauncherError::ArgumentValidation(format!(
                "`max_batch_prefill_tokens` must be <= `max_batch_total_tokens`. Given: {} and {}",
                args.max_batch_prefill_tokens, max_batch_total_tokens
            )));
        }
        if args.max_total_tokens as u32 > *max_batch_total_tokens {
            return Err(LauncherError::ArgumentValidation(format!(
                "`max_total_tokens` must be <= `max_batch_total_tokens`. Given: {} and {}",
                args.max_total_tokens, max_batch_total_tokens
            )));
        }
    }

    if args.ngrok {
        if args.ngrok_authtoken.is_none() {
            return Err(LauncherError::ArgumentValidation(
                "`ngrok-authtoken` must be set when using ngrok tunneling".to_string(),
            ));
        }

        if args.ngrok_edge.is_none() {
            return Err(LauncherError::ArgumentValidation(
                "`ngrok-edge` must be set when using ngrok tunneling".to_string(),
            ));
        }
    }

    // Signal handler
    let running = Arc::new(AtomicBool::new(true));
    let r = running.clone();
    ctrlc::set_handler(move || {
        r.store(false, Ordering::SeqCst);
    })
    .expect("Error setting Ctrl-C handler");

    // Download and convert model weights
    download_convert_model(&args, running.clone())?;

    if !running.load(Ordering::SeqCst) {
        // Launcher was asked to stop
        return Ok(());
    }

    // Shared shutdown bool
    let shutdown = Arc::new(AtomicBool::new(false));
    // Shared shutdown channel
    // When shutting down, the main thread will wait for all senders to be dropped
    let (shutdown_sender, shutdown_receiver) = mpsc::channel();

    // Shared channel to track shard status
    let (status_sender, status_receiver) = mpsc::channel();

    spawn_shards(
        num_shard,
        &args,
        shutdown.clone(),
        &shutdown_receiver,
        shutdown_sender,
        &status_receiver,
        status_sender,
        running.clone(),
    )?;

    // We might have received a termination signal
    if !running.load(Ordering::SeqCst) {
        shutdown_shards(shutdown, &shutdown_receiver);
        return Ok(());
    }

    let mut webserver =
        spawn_webserver(args, shutdown.clone(), &shutdown_receiver).map_err(|err| {
            shutdown_shards(shutdown.clone(), &shutdown_receiver);
            err
        })?;

    // Default exit code
    let mut exit_code = Ok(());

    while running.load(Ordering::SeqCst) {
        if let Ok(ShardStatus::Failed(rank)) = status_receiver.try_recv() {
            tracing::error!("Shard {rank} crashed");
            exit_code = Err(LauncherError::ShardFailed);
            break;
        };

        match webserver.try_wait().unwrap() {
            Some(_) => {
                tracing::error!("Webserver crashed");
                shutdown_shards(shutdown, &shutdown_receiver);
                return Err(LauncherError::WebserverFailed);
            }
            None => {
                sleep(Duration::from_millis(100));
            }
        };
    }

    // Graceful termination
    terminate("webserver", webserver, Duration::from_secs(90)).unwrap();
    shutdown_shards(shutdown, &shutdown_receiver);

    exit_code
}