use clap::{Parser, ValueEnum};
use nix::sys::signal::{self, Signal};
use nix::unistd::Pid;
use serde::Deserialize;
use std::env;
use std::ffi::OsString;
use std::io::{BufRead, BufReader, Lines, Read};
use std::os::unix::process::{CommandExt, ExitStatusExt};
use std::path::Path;
use std::process::{Child, Command, ExitStatus, Stdio};
use std::sync::atomic::{AtomicBool, Ordering};
use std::sync::mpsc::TryRecvError;
use std::sync::{mpsc, Arc};
use std::thread;
use std::thread::sleep;
use std::time::{Duration, Instant};
use std::{fs, io};
use tracing_subscriber::EnvFilter;

mod env_runtime;

#[derive(Clone, Copy, Debug, ValueEnum)]
enum Quantization {
    Bitsandbytes,
    BitsandbytesNF4,
    BitsandbytesFP4,
    Gptq,
    Awq,
}

impl std::fmt::Display for Quantization {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        // Keep in sync with `server`.
        match self {
            Quantization::Bitsandbytes => {
                write!(f, "bitsandbytes")
            }
            Quantization::BitsandbytesNF4 => {
                write!(f, "bitsandbytes-nf4")
            }
            Quantization::BitsandbytesFP4 => {
                write!(f, "bitsandbytes-fp4")
            }
            Quantization::Gptq => {
                write!(f, "gptq")
            }
            Quantization::Awq => {
                write!(f, "awq")
            }
        }
    }
}

#[derive(Clone, Copy, Debug, ValueEnum)]
enum Dtype {
    Float16,
    #[clap(name = "bfloat16")]
    BFloat16,
}

impl std::fmt::Display for Dtype {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        // Keep in sync with `server`.
        match self {
            Dtype::Float16 => {
                write!(f, "float16")
            }
            Dtype::BFloat16 => {
                write!(f, "bfloat16")
            }
        }
    }
}

#[derive(Clone, Copy, Debug, ValueEnum)]
enum RopeScaling {
    Linear,
    Dynamic,
}

impl std::fmt::Display for RopeScaling {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        // Keep in sync with `server`.
        match self {
            RopeScaling::Linear => {
                write!(f, "linear")
            }
            RopeScaling::Dynamic => {
                write!(f, "dynamic")
            }
        }
    }
}

/// App Configuration
#[derive(Parser, Debug)]
#[clap(author, version, about, long_about = None)]
struct Args {
    /// The name of the model to load.
    /// Can be a MODEL_ID as listed on <https://hf.co/models> like
    /// `gpt2` or `OpenAssistant/oasst-sft-1-pythia-12b`.
    /// Or it can be a local directory containing the necessary files
    /// as saved by `save_pretrained(...)` methods of transformers
    #[clap(default_value = "bigscience/bloom-560m", long, env)]
    model_id: String,

    /// The actual revision of the model if you're referring to a model
    /// on the hub. You can use a specific commit id or a branch like `refs/pr/2`.
    #[clap(long, env)]
    revision: Option<String>,

    /// The number of tokenizer workers used for payload validation and truncation inside the
    /// router.
    #[clap(default_value = "2", long, env)]
    validation_workers: usize,

    /// Whether to shard the model across multiple GPUs
    /// By default text-generation-inference will use all available GPUs to run
    /// the model. Setting it to `false` deactivates `num_shard`.
    #[clap(long, env)]
    sharded: Option<bool>,

    /// The number of shards to use if you don't want to use all GPUs on a given machine.
    /// You can use `CUDA_VISIBLE_DEVICES=0,1 text-generation-launcher... --num_shard 2`
    /// and `CUDA_VISIBLE_DEVICES=2,3 text-generation-launcher... --num_shard 2` to
    /// launch 2 copies with 2 shards each on a given machine with 4 GPUs for instance.
    #[clap(long, env)]
    num_shard: Option<usize>,

    /// Whether you want the model to be quantized. This will use `bitsandbytes` for
    /// quantization on the fly, or `gptq`. 4bit quantization is available through
    /// `bitsandbytes` by providing the `bitsandbytes-fp4` or `bitsandbytes-nf4` options.
    #[clap(long, env, value_enum)]
    quantize: Option<Quantization>,

    /// The dtype to be forced upon the model. This option cannot be used with `--quantize`.
    #[clap(long, env, value_enum)]
    dtype: Option<Dtype>,

    /// Whether you want to execute hub modelling code. Explicitly passing a `revision` is
    /// encouraged when loading a model with custom code to ensure no malicious code has been
    /// contributed in a newer revision.
    #[clap(long, env, value_enum)]
    trust_remote_code: bool,

    /// The maximum amount of concurrent requests for this particular deployment.
    /// Having a low limit will refuse client requests instead of having them
    /// wait for too long, and is usually good for handling backpressure correctly.
    #[clap(default_value = "128", long, env)]
    max_concurrent_requests: usize,

    /// This is the maximum allowed value for clients to set `best_of`.
    /// Best of makes `n` generations at the same time, and returns the best
    /// in terms of overall log probability over the entire generated sequence.
    #[clap(default_value = "2", long, env)]
    max_best_of: usize,

    /// This is the maximum allowed value for clients to set `stop_sequences`.
    /// Stop sequences are used to allow the model to stop on more than just
    /// the EOS token, and enable more complex "prompting" where users can preprompt
    /// the model in a specific way and define their "own" stop token aligned with
    /// their prompt.
    #[clap(default_value = "4", long, env)]
    max_stop_sequences: usize,

    /// This is the maximum allowed value for clients to set `top_n_tokens`.
    /// `top_n_tokens` is used to return information about the `n` most likely
    /// tokens at each generation step, instead of just the sampled token. This
    /// information can be used for downstream tasks like classification or
    /// ranking.
    #[clap(default_value = "5", long, env)]
    max_top_n_tokens: u32,

    /// This is the maximum allowed input length (expressed in number of tokens)
    /// for users. The larger this value, the longer prompts users can send, which
    /// can impact the overall memory required to handle the load.
    /// Please note that some models have a finite range of sequences they can handle.
    #[clap(default_value = "1024", long, env)]
    max_input_length: usize,

    /// This is the most important value to set as it defines the "memory budget"
    /// of running client requests.
    /// Clients will send input sequences and ask to generate `max_new_tokens`
    /// on top. With a value of `1512`, users can send either a prompt of
    /// `1000` and ask for `512` new tokens, or send a prompt of `1` and ask for
    /// `1511` max_new_tokens.
    /// The larger this value, the larger each request will be in your RAM
    /// and the less effective batching can be.
    #[clap(default_value = "2048", long, env)]
    max_total_tokens: usize,

    /// This represents the ratio of waiting queries vs running queries where
    /// you want to start considering pausing the running queries to include the waiting
    /// ones into the same batch.
    /// `waiting_served_ratio=1.2` means that when 12 queries are waiting and there are
    /// only 10 queries left in the current batch, we check if we can fit those 12
    /// waiting queries into the batching strategy, and if so, batching happens,
    /// delaying the 10 running queries by a `prefill` run.
    ///
    /// This setting is only applied if there is room in the batch
    /// as defined by `max_batch_total_tokens`.
    #[clap(default_value = "1.2", long, env)]
    waiting_served_ratio: f32,

    /// Limits the number of tokens for the prefill operation.
    /// Since this operation takes the most memory and is compute bound, it is interesting
    /// to limit the number of requests that can be sent.
    #[clap(default_value = "4096", long, env)]
    max_batch_prefill_tokens: u32,

    /// **IMPORTANT** This is one critical control to allow maximum usage
    /// of the available hardware.
    ///
    /// This represents the total amount of potential tokens within a batch.
    /// When using padding (not recommended) this would be equivalent to
    /// `batch_size` * `max_total_tokens`.
    ///
    /// However in the non-padded (flash attention) version this can be much finer.
    ///
    /// For `max_batch_total_tokens=1000`, you could fit `10` queries of `total_tokens=100`
    /// or a single query of `1000` tokens.
    ///
    /// Overall this number should be the largest possible amount that fits the
    /// remaining memory (after the model is loaded). Since the actual memory overhead
    /// depends on other parameters like whether you're using quantization, flash attention
    /// or the model implementation, text-generation-inference cannot infer this number
    /// automatically.
    #[clap(long, env)]
    max_batch_total_tokens: Option<u32>,

    /// This setting defines how many tokens can be passed before forcing the waiting
    /// queries to be put on the batch (if the size of the batch allows for it).
    /// New queries require 1 `prefill` forward, which is different from `decode`
    /// and therefore you need to pause the running batch in order to run `prefill`
    /// to create the correct values for the waiting queries to be able to join the batch.
    ///
    /// With a value too small, queries will always "steal" the compute to run `prefill`
    /// and running queries will be delayed by a lot.
    ///
    /// With a value too big, waiting queries could wait for a very long time
    /// before being allowed a slot in the running batch. If your server is busy
    /// that means that requests that could run in ~2s on an empty server could
    /// end up running in ~20s because the query had to wait for 18s.
    ///
    /// This number is expressed in number of tokens to make it a bit more
    /// "model" agnostic, but what should really matter is the overall latency
    /// for end users.
    #[clap(default_value = "20", long, env)]
    max_waiting_tokens: usize,

    /// The IP address to listen on
    #[clap(default_value = "0.0.0.0", long, env)]
    hostname: String,

    /// The port to listen on.
    #[clap(default_value = "3000", long, short, env)]
    port: u16,

    /// The name of the socket for gRPC communication between the webserver
    /// and the shards.
    #[clap(default_value = "/tmp/text-generation-server", long, env)]
    shard_uds_path: String,

    /// The address the master shard will listen on. (setting used by torch distributed)
    #[clap(default_value = "localhost", long, env)]
    master_addr: String,

    /// The port the master shard will listen on. (setting used by torch distributed)
    #[clap(default_value = "29500", long, env)]
    master_port: usize,

    /// The location of the huggingface hub cache.
    /// Used to override the location if you want to provide a mounted disk for instance
    #[clap(long, env)]
    huggingface_hub_cache: Option<String>,

    /// The location of the weights cache.
    /// Used to override the location weights are loaded from, for instance when
    /// running inside a HuggingFace Inference Endpoint
    #[clap(long, env)]
    weights_cache_override: Option<String>,

    /// For some models (like bloom), text-generation-inference implemented custom
    /// cuda kernels to speed up inference. Those kernels were only tested on A100.
    /// Use this flag to disable them if you're running on different hardware and
    /// encounter issues.
    #[clap(long, env)]
    disable_custom_kernels: bool,

    /// Limit the CUDA available memory.
    /// The allowed value equals the total visible memory multiplied by cuda-memory-fraction.
    #[clap(default_value = "1.0", long, env)]
    cuda_memory_fraction: f32,

    /// Rope scaling will only be used for RoPE models
    /// and allows rescaling the position rotary to accommodate
    /// larger prompts.
    ///
    /// Goes together with `rope_factor`.
    ///
    /// `--rope-factor 2.0` gives linear scaling with a factor of 2.0
    /// `--rope-scaling dynamic` gives dynamic scaling with a factor of 1.0
    /// `--rope-scaling linear` gives linear scaling with a factor of 1.0 (basically
    /// nothing will be changed)
    ///
    /// `--rope-scaling linear --rope-factor` fully describes the scaling you want
    #[clap(long, env)]
    rope_scaling: Option<RopeScaling>,

    /// Rope scaling will only be used for RoPE models
    /// See `rope_scaling`
    #[clap(long, env)]
    rope_factor: Option<f32>,

    /// Outputs the logs in JSON format (useful for telemetry)
    #[clap(long, env)]
    json_output: bool,

    #[clap(long, env)]
    otlp_endpoint: Option<String>,

    #[clap(long, env)]
    cors_allow_origin: Vec<String>,
    #[clap(long, env)]
    watermark_gamma: Option<f32>,
    #[clap(long, env)]
    watermark_delta: Option<f32>,

    /// Enable ngrok tunneling
    #[clap(long, env)]
    ngrok: bool,

    /// ngrok authentication token
    #[clap(long, env)]
    ngrok_authtoken: Option<String>,

    /// ngrok edge
    #[clap(long, env)]
    ngrok_edge: Option<String>,

    /// Display a lot of information about your runtime environment
    #[clap(long, short, action)]
    env: bool,
}

#[derive(Debug)]
enum ShardStatus {
    Ready,
    Failed(usize),
}

#[allow(clippy::too_many_arguments)]
fn shard_manager(
    model_id: String,
    revision: Option<String>,
    quantize: Option<Quantization>,
    dtype: Option<Dtype>,
    trust_remote_code: bool,
    uds_path: String,
    rank: usize,
    world_size: usize,
    master_addr: String,
    master_port: usize,
    huggingface_hub_cache: Option<String>,
    weights_cache_override: Option<String>,
    disable_custom_kernels: bool,
    watermark_gamma: Option<f32>,
    watermark_delta: Option<f32>,
    cuda_memory_fraction: f32,
    rope_scaling: Option<RopeScaling>,
    rope_factor: Option<f32>,
    otlp_endpoint: Option<String>,
    status_sender: mpsc::Sender<ShardStatus>,
    shutdown: Arc<AtomicBool>,
    _shutdown_sender: mpsc::Sender<()>,
) {
    // Enter shard-manager tracing span
    let _span = tracing::span!(tracing::Level::INFO, "shard-manager", rank = rank).entered();

    // Get UDS path
    let uds_string = format!("{uds_path}-{rank}");
    let uds = Path::new(&uds_string);
    // Clean previous runs
    if uds.exists() {
        fs::remove_file(uds).unwrap();
    }

    // Process args
    let mut shard_args = vec![
        "serve".to_string(),
        model_id,
        "--uds-path".to_string(),
        uds_path,
        "--logger-level".to_string(),
        "INFO".to_string(),
        "--json-output".to_string(),
    ];

    // Activate trust remote code
    if trust_remote_code {
        shard_args.push("--trust-remote-code".to_string());
    }

    // Activate tensor parallelism
    if world_size > 1 {
        shard_args.push("--sharded".to_string());
    }

    if let Some(quantize) = quantize {
        shard_args.push("--quantize".to_string());
        shard_args.push(quantize.to_string())
    }

    if let Some(dtype) = dtype {
        shard_args.push("--dtype".to_string());
        shard_args.push(dtype.to_string())
    }

    // Model optional revision
    if let Some(revision) = revision {
        shard_args.push("--revision".to_string());
        shard_args.push(revision)
    }

    let rope = match (rope_scaling, rope_factor) {
        (None, None) => None,
        (Some(scaling), None) => Some((scaling, 1.0)),
        (Some(scaling), Some(factor)) => Some((scaling, factor)),
        (None, Some(factor)) => Some((RopeScaling::Linear, factor)),
    };
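    // Note the defaulting above: a bare `--rope-factor 2.0` is treated as linear
    // scaling with factor 2.0, while a bare `--rope-scaling dynamic` gets factor 1.0.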
    // OpenTelemetry
    if let Some(otlp_endpoint) = otlp_endpoint {
        shard_args.push("--otlp-endpoint".to_string());
        shard_args.push(otlp_endpoint);
    }

    // Copy current process env
    let mut envs: Vec<(OsString, OsString)> = env::vars_os().collect();

    // Torch Distributed Env vars
    envs.push(("RANK".into(), rank.to_string().into()));
    envs.push(("WORLD_SIZE".into(), world_size.to_string().into()));
    envs.push(("MASTER_ADDR".into(), master_addr.into()));
    envs.push(("MASTER_PORT".into(), master_port.to_string().into()));
    envs.push(("NCCL_ASYNC_ERROR_HANDLING".into(), "1".into()));

    // CUDA memory fraction
    envs.push((
        "CUDA_MEMORY_FRACTION".into(),
        cuda_memory_fraction.to_string().into(),
    ));

    // Safetensors load fast
    envs.push(("SAFETENSORS_FAST_GPU".into(), "1".into()));

    // Enable hf transfer for insane download speeds
    let enable_hf_transfer = env::var("HF_HUB_ENABLE_HF_TRANSFER").unwrap_or("1".to_string());
    envs.push((
        "HF_HUB_ENABLE_HF_TRANSFER".into(),
        enable_hf_transfer.into(),
    ));

    // Parse Inference API token
    if let Ok(api_token) = env::var("HF_API_TOKEN") {
        envs.push(("HUGGING_FACE_HUB_TOKEN".into(), api_token.into()))
    };

    // Detect rope scaling
    // Sending these as env vars instead of CLI args keeps the interface lean:
    // only RoPE models can use them, so passing them around for all models
    // would complicate the code unnecessarily
    if let Some((scaling, factor)) = rope {
        envs.push(("ROPE_SCALING".into(), scaling.to_string().into()));
        envs.push(("ROPE_FACTOR".into(), factor.to_string().into()));
    }

    // If huggingface_hub_cache is some, pass it to the shard
    // Useful when running inside a docker container
    if let Some(huggingface_hub_cache) = huggingface_hub_cache {
        envs.push(("HUGGINGFACE_HUB_CACHE".into(), huggingface_hub_cache.into()));
    };

    // If weights_cache_override is some, pass it to the shard
    // Useful when running inside a HuggingFace Inference Endpoint
    if let Some(weights_cache_override) = weights_cache_override {
        envs.push((
            "WEIGHTS_CACHE_OVERRIDE".into(),
            weights_cache_override.into(),
        ));
    };

    // If disable_custom_kernels is true, pass it to the shard as an env var
    if disable_custom_kernels {
        envs.push(("DISABLE_CUSTOM_KERNELS".into(), "True".into()))
    }

    // Watermark Gamma
    if let Some(watermark_gamma) = watermark_gamma {
        envs.push(("WATERMARK_GAMMA".into(), watermark_gamma.to_string().into()))
    }

    // Watermark Delta
    if let Some(watermark_delta) = watermark_delta {
        envs.push(("WATERMARK_DELTA".into(), watermark_delta.to_string().into()))
    }

    // Start process
    tracing::info!("Starting shard");
    let mut p = match Command::new("text-generation-server")
        .args(shard_args)
        .envs(envs)
        .stdout(Stdio::piped())
        .stderr(Stdio::piped())
        .process_group(0)
        .spawn()
    {
        Ok(p) => p,
        Err(err) => {
            if err.kind() == io::ErrorKind::NotFound {
                tracing::error!("text-generation-server not found in PATH");
                tracing::error!("Please install it with `make install-server`")
            } else {
                tracing::error!("{}", err);
            }

            status_sender.send(ShardStatus::Failed(rank)).unwrap();
            return;
        }
    };

    // Redirect STDOUT to the console
    let shard_stdout_reader = BufReader::new(p.stdout.take().unwrap());
    let shard_stderr_reader = BufReader::new(p.stderr.take().unwrap());

    // stdout tracing thread
    thread::spawn(move || {
        log_lines(shard_stdout_reader.lines());
    });

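    // Monitor the shard process: surface crashes together with their stderr,
    // honor shutdown requests, and report Ready once the shard's unix socket
    // appears on disk.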
    let mut ready = false;
    let start_time = Instant::now();
    let mut wait_time = Instant::now();
    loop {
        // Process exited
        if let Some(exit_status) = p.try_wait().unwrap() {
            // We read stderr in another thread as it seems that lines() can block in some cases
            let (err_sender, err_receiver) = mpsc::channel();
            thread::spawn(move || {
                for line in shard_stderr_reader.lines().flatten() {
                    err_sender.send(line).unwrap_or(());
                }
            });

            let mut err = String::new();
            while let Ok(line) = err_receiver.recv_timeout(Duration::from_millis(10)) {
                err = err + "\n" + &line;
            }

            tracing::error!("Shard complete standard error output:\n{err}");

            if let Some(signal) = exit_status.signal() {
                tracing::error!("Shard process was signaled to shutdown with signal {signal}");
            }

            status_sender.send(ShardStatus::Failed(rank)).unwrap();
            return;
        }

        // We received a shutdown signal
        if shutdown.load(Ordering::SeqCst) {
            p.kill().unwrap();
            let _ = p.wait();
            tracing::info!("Shard terminated");
            return;
        }

        // Shard is ready
        if uds.exists() && !ready {
            tracing::info!("Shard ready in {:?}", start_time.elapsed());
            status_sender.send(ShardStatus::Ready).unwrap();
            ready = true;
        } else if !ready && wait_time.elapsed() > Duration::from_secs(10) {
            tracing::info!("Waiting for shard to be ready...");
            wait_time = Instant::now();
        }
        sleep(Duration::from_millis(100));
    }
}

fn shutdown_shards(shutdown: Arc<AtomicBool>, shutdown_receiver: &mpsc::Receiver<()>) {
    tracing::info!("Shutting down shards");
    // Update shutdown value to true
    // This will be picked up by the shard manager
    shutdown.store(true, Ordering::SeqCst);

    // Wait for shards to shutdown
    // This will block till all shutdown_sender are dropped
    let _ = shutdown_receiver.recv();
}

fn num_cuda_devices() -> Option<usize> {
    let devices = match env::var("CUDA_VISIBLE_DEVICES") {
        Ok(devices) => devices,
        Err(_) => env::var("NVIDIA_VISIBLE_DEVICES").ok()?,
    };
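    // e.g. `CUDA_VISIBLE_DEVICES=0,1,2` yields 3. This only counts the
    // comma-separated entries; it does not validate that the devices exist.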
    let n_devices = devices.split(',').count();
    Some(n_devices)
}

#[derive(Deserialize)]
#[serde(rename_all = "UPPERCASE")]
enum PythonLogLevelEnum {
    Trace,
    Debug,
    Info,
    Success,
    Warning,
    Error,
    Critical,
}

#[derive(Deserialize)]
struct PythonLogLevel {
    name: PythonLogLevelEnum,
}

#[derive(Deserialize)]
struct PythonLogRecord {
    level: PythonLogLevel,
}

#[derive(Deserialize)]
struct PythonLogMessage {
    text: String,
    record: PythonLogRecord,
}

impl PythonLogMessage {
    fn trace(&self) {
        match self.record.level.name {
            PythonLogLevelEnum::Trace => tracing::trace!("{}", self.text),
            PythonLogLevelEnum::Debug => tracing::debug!("{}", self.text),
            PythonLogLevelEnum::Info => tracing::info!("{}", self.text),
            PythonLogLevelEnum::Success => tracing::info!("{}", self.text),
            PythonLogLevelEnum::Warning => tracing::warn!("{}", self.text),
            PythonLogLevelEnum::Error => tracing::error!("{}", self.text),
            PythonLogLevelEnum::Critical => tracing::error!("{}", self.text),
        }
    }
}
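
// The Python shard emits its log records as JSON lines (note the loguru-style
// `Success` level); the structs above mirror just the fields needed here,
// e.g. {"text":"...","record":{"level":{"name":"INFO"}}}.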

impl TryFrom<&String> for PythonLogMessage {
    type Error = serde_json::Error;

    fn try_from(value: &String) -> Result<Self, Self::Error> {
        serde_json::from_str::<Self>(value)
    }
}

fn log_lines<S: Sized + BufRead>(lines: Lines<S>) {
    for line in lines.flatten() {
        match PythonLogMessage::try_from(&line) {
            Ok(log) => log.trace(),
            Err(_) => tracing::debug!("{line}"),
        }
    }
}

fn find_num_shards(
    sharded: Option<bool>,
    num_shard: Option<usize>,
) -> Result<usize, LauncherError> {
    // get the number of shards given `sharded` and `num_shard`
    let num_shard = match (sharded, num_shard) {
        (Some(true), None) => {
            // try to default to the number of available GPUs
            tracing::info!("Parsing num_shard from CUDA_VISIBLE_DEVICES/NVIDIA_VISIBLE_DEVICES");
            let n_devices = num_cuda_devices()
                .expect("--num-shard and CUDA_VISIBLE_DEVICES/NVIDIA_VISIBLE_DEVICES are not set");
            if n_devices <= 1 {
                return Err(LauncherError::NotEnoughCUDADevices(format!(
                    "`sharded` is true but only found {n_devices} CUDA devices"
                )));
            }
            n_devices
        }
        (Some(true), Some(num_shard)) => {
            // we can't have only one shard while sharded
            if num_shard <= 1 {
                return Err(LauncherError::ArgumentValidation(
                    "`sharded` is true but `num_shard` <= 1".to_string(),
                ));
            }
            num_shard
        }
        (Some(false), Some(num_shard)) => num_shard,
        (Some(false), None) => 1,
        (None, None) => num_cuda_devices().unwrap_or(1),
        (None, Some(num_shard)) => num_shard,
    };
    if num_shard < 1 {
        return Err(LauncherError::ArgumentValidation(
            "`num_shard` cannot be < 1".to_string(),
        ));
    }
    Ok(num_shard)
}

#[derive(Debug)]
enum LauncherError {
    ArgumentValidation(String),
    NotEnoughCUDADevices(String),
    DownloadError,
    ShardCannotStart,
    ShardDisconnected,
    ShardFailed,
    WebserverFailed,
    WebserverCannotStart,
}

fn download_convert_model(args: &Args, running: Arc<AtomicBool>) -> Result<(), LauncherError> {
    // Enter download tracing span
    let _span = tracing::span!(tracing::Level::INFO, "download").entered();

    let mut download_args = vec![
        "download-weights".to_string(),
        args.model_id.to_string(),
        "--extension".to_string(),
        ".safetensors".to_string(),
        "--logger-level".to_string(),
        "INFO".to_string(),
        "--json-output".to_string(),
    ];

    // Model optional revision
    if let Some(revision) = &args.revision {
        download_args.push("--revision".to_string());
        download_args.push(revision.to_string())
    }

    // Trust remote code for automatic peft fusion
    if args.trust_remote_code {
        download_args.push("--trust-remote-code".to_string());
    }

    // Copy current process env
    let mut envs: Vec<(OsString, OsString)> = env::vars_os().collect();

    // If huggingface_hub_cache is set, pass it to the download process
    // Useful when running inside a docker container
    if let Some(ref huggingface_hub_cache) = args.huggingface_hub_cache {
        envs.push(("HUGGINGFACE_HUB_CACHE".into(), huggingface_hub_cache.into()));
    };

    // Enable hf transfer for insane download speeds
    let enable_hf_transfer = env::var("HF_HUB_ENABLE_HF_TRANSFER").unwrap_or("1".to_string());
    envs.push((
        "HF_HUB_ENABLE_HF_TRANSFER".into(),
        enable_hf_transfer.into(),
    ));

    // Parse Inference API token
    if let Ok(api_token) = env::var("HF_API_TOKEN") {
        envs.push(("HUGGING_FACE_HUB_TOKEN".into(), api_token.into()))
    };

    // If args.weights_cache_override is some, pass it to the download process
    // Useful when running inside a HuggingFace Inference Endpoint
    if let Some(weights_cache_override) = &args.weights_cache_override {
        envs.push((
            "WEIGHTS_CACHE_OVERRIDE".into(),
            weights_cache_override.into(),
        ));
    };

    // Start process
    tracing::info!("Starting download process.");
    let mut download_process = match Command::new("text-generation-server")
        .args(download_args)
        .envs(envs)
        .stdout(Stdio::piped())
        .stderr(Stdio::piped())
        .process_group(0)
        .spawn()
    {
        Ok(p) => p,
        Err(err) => {
            if err.kind() == io::ErrorKind::NotFound {
                tracing::error!("text-generation-server not found in PATH");
                tracing::error!("Please install it with `make install-server`")
            } else {
                tracing::error!("{}", err);
            }

            return Err(LauncherError::DownloadError);
        }
    };

    // Redirect STDOUT to the console
    let download_stdout = download_process.stdout.take().unwrap();
    let stdout = BufReader::new(download_stdout);

    thread::spawn(move || {
        log_lines(stdout.lines());
    });

    loop {
        if let Some(status) = download_process.try_wait().unwrap() {
            if status.success() {
                tracing::info!("Successfully downloaded weights.");
                break;
            }

            let mut err = String::new();
            download_process
                .stderr
                .take()
                .unwrap()
                .read_to_string(&mut err)
                .unwrap();
            if let Some(signal) = status.signal() {
                tracing::error!(
                    "Download process was signaled to shutdown with signal {signal}: {err}"
                );
            } else {
                tracing::error!("Download encountered an error: {err}");
            }

            return Err(LauncherError::DownloadError);
        }
        if !running.load(Ordering::SeqCst) {
            terminate("download", download_process, Duration::from_secs(10)).unwrap();
            return Ok(());
        }
        sleep(Duration::from_millis(100));
    }
    Ok(())
}

#[allow(clippy::too_many_arguments)]
fn spawn_shards(
    num_shard: usize,
    args: &Args,
    shutdown: Arc<AtomicBool>,
    shutdown_receiver: &mpsc::Receiver<()>,
    shutdown_sender: mpsc::Sender<()>,
    status_receiver: &mpsc::Receiver<ShardStatus>,
    status_sender: mpsc::Sender<ShardStatus>,
    running: Arc<AtomicBool>,
) -> Result<(), LauncherError> {
    // Start shard processes
    for rank in 0..num_shard {
        let model_id = args.model_id.clone();
        let revision = args.revision.clone();
        let uds_path = args.shard_uds_path.clone();
        let master_addr = args.master_addr.clone();
        let huggingface_hub_cache = args.huggingface_hub_cache.clone();
        let weights_cache_override = args.weights_cache_override.clone();
        let status_sender = status_sender.clone();
        let shutdown = shutdown.clone();
        let shutdown_sender = shutdown_sender.clone();
        let otlp_endpoint = args.otlp_endpoint.clone();
        let quantize = args.quantize;
        let dtype = args.dtype;
        let trust_remote_code = args.trust_remote_code;
        let master_port = args.master_port;
        let disable_custom_kernels = args.disable_custom_kernels;
        let watermark_gamma = args.watermark_gamma;
        let watermark_delta = args.watermark_delta;
        let cuda_memory_fraction = args.cuda_memory_fraction;
        let rope_scaling = args.rope_scaling;
        let rope_factor = args.rope_factor;
        thread::spawn(move || {
            shard_manager(
                model_id,
                revision,
                quantize,
                dtype,
                trust_remote_code,
                uds_path,
                rank,
                num_shard,
                master_addr,
                master_port,
                huggingface_hub_cache,
                weights_cache_override,
                disable_custom_kernels,
                watermark_gamma,
                watermark_delta,
                cuda_memory_fraction,
                rope_scaling,
                rope_factor,
                otlp_endpoint,
                status_sender,
                shutdown,
                shutdown_sender,
            )
        });
    }
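    // Drop the launcher's own sender so that `shutdown_shards` only blocks on
    // the clones held by the shard threads.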
    drop(shutdown_sender);

    // Wait for shard to start
    let mut shard_ready = 0;
    while running.load(Ordering::SeqCst) {
        match status_receiver.try_recv() {
            Ok(ShardStatus::Ready) => {
                shard_ready += 1;
                if shard_ready == num_shard {
                    break;
                }
            }
            Err(TryRecvError::Empty) => {
                sleep(Duration::from_millis(100));
            }
            Ok(ShardStatus::Failed(rank)) => {
                tracing::error!("Shard {rank} failed to start");
                shutdown_shards(shutdown, shutdown_receiver);
                return Err(LauncherError::ShardCannotStart);
            }
            Err(TryRecvError::Disconnected) => {
                tracing::error!("Shard status channel disconnected");
                shutdown_shards(shutdown, shutdown_receiver);
                return Err(LauncherError::ShardDisconnected);
            }
        }
    }
    Ok(())
}

fn spawn_webserver(
    args: Args,
    shutdown: Arc<AtomicBool>,
    shutdown_receiver: &mpsc::Receiver<()>,
) -> Result<Child, LauncherError> {
    // All shards started
    // Start webserver
    tracing::info!("Starting Webserver");
    let mut router_args = vec![
        "--max-concurrent-requests".to_string(),
        args.max_concurrent_requests.to_string(),
        "--max-best-of".to_string(),
        args.max_best_of.to_string(),
        "--max-stop-sequences".to_string(),
        args.max_stop_sequences.to_string(),
        "--max-top-n-tokens".to_string(),
        args.max_top_n_tokens.to_string(),
        "--max-input-length".to_string(),
        args.max_input_length.to_string(),
        "--max-total-tokens".to_string(),
        args.max_total_tokens.to_string(),
        "--max-batch-prefill-tokens".to_string(),
        args.max_batch_prefill_tokens.to_string(),
        "--waiting-served-ratio".to_string(),
        args.waiting_served_ratio.to_string(),
        "--max-waiting-tokens".to_string(),
        args.max_waiting_tokens.to_string(),
        "--validation-workers".to_string(),
        args.validation_workers.to_string(),
        "--hostname".to_string(),
        args.hostname.to_string(),
        "--port".to_string(),
        args.port.to_string(),
        "--master-shard-uds-path".to_string(),
        format!("{}-0", args.shard_uds_path),
        "--tokenizer-name".to_string(),
        args.model_id,
    ];

    // Model optional max batch total tokens
    if let Some(max_batch_total_tokens) = args.max_batch_total_tokens {
        router_args.push("--max-batch-total-tokens".to_string());
        router_args.push(max_batch_total_tokens.to_string());
    }

    // Model optional revision
    if let Some(ref revision) = args.revision {
        router_args.push("--revision".to_string());
        router_args.push(revision.to_string())
    }

    if args.json_output {
        router_args.push("--json-output".to_string());
    }

    // OpenTelemetry
    if let Some(otlp_endpoint) = args.otlp_endpoint {
        router_args.push("--otlp-endpoint".to_string());
        router_args.push(otlp_endpoint);
    }

    // CORS origins
    for origin in args.cors_allow_origin.into_iter() {
        router_args.push("--cors-allow-origin".to_string());
        router_args.push(origin);
    }

    // Ngrok
    if args.ngrok {
        router_args.push("--ngrok".to_string());
        router_args.push("--ngrok-authtoken".to_string());
        router_args.push(args.ngrok_authtoken.unwrap());
        router_args.push("--ngrok-edge".to_string());
        router_args.push(args.ngrok_edge.unwrap());
    }

    // Copy current process env
    let mut envs: Vec<(OsString, OsString)> = env::vars_os().collect();

    // Parse Inference API token
    if let Ok(api_token) = env::var("HF_API_TOKEN") {
        envs.push(("HUGGING_FACE_HUB_TOKEN".into(), api_token.into()))
    };

    let mut webserver = match Command::new("text-generation-router")
        .args(router_args)
        .envs(envs)
        .stdout(Stdio::piped())
        .stderr(Stdio::piped())
        .process_group(0)
        .spawn()
    {
        Ok(p) => p,
        Err(err) => {
            tracing::error!("Failed to start webserver: {}", err);
            if err.kind() == io::ErrorKind::NotFound {
                tracing::error!("text-generation-router not found in PATH");
                tracing::error!("Please install it with `make install-router`")
            } else {
                tracing::error!("{}", err);
            }

            shutdown_shards(shutdown, shutdown_receiver);
            return Err(LauncherError::WebserverCannotStart);
        }
    };

    // Redirect STDOUT and STDERR to the console
    let webserver_stdout = webserver.stdout.take().unwrap();
    let webserver_stderr = webserver.stderr.take().unwrap();

    thread::spawn(move || {
        let stdout = BufReader::new(webserver_stdout);
        let stderr = BufReader::new(webserver_stderr);
        for line in stdout.lines() {
            println!("{}", line.unwrap());
        }
        for line in stderr.lines() {
            println!("{}", line.unwrap());
        }
    });
    Ok(webserver)
}

fn terminate(process_name: &str, mut process: Child, timeout: Duration) -> io::Result<ExitStatus> {
    tracing::info!("Terminating {process_name}");

    let terminate_time = Instant::now();
    signal::kill(Pid::from_raw(process.id() as i32), Signal::SIGTERM).unwrap();

    tracing::info!("Waiting for {process_name} to gracefully shutdown");

    while terminate_time.elapsed() < timeout {
        if let Some(status) = process.try_wait()? {
            tracing::info!("{process_name} terminated");
            return Ok(status);
        }
        sleep(Duration::from_millis(100));
    }

    tracing::info!("Killing {process_name}");

    process.kill()?;
    let exit_status = process.wait()?;

    tracing::info!("{process_name} killed");
    Ok(exit_status)
}

fn main() -> Result<(), LauncherError> {
    // Pattern match configuration
    let args: Args = Args::parse();

    // Filter events with LOG_LEVEL
    let env_filter =
        EnvFilter::try_from_env("LOG_LEVEL").unwrap_or_else(|_| EnvFilter::new("info"));
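    // Accepts standard tracing `EnvFilter` syntax, e.g. `LOG_LEVEL=debug` or a
    // comma-separated list of `target=level` directives.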

    if args.json_output {
        tracing_subscriber::fmt()
            .with_env_filter(env_filter)
            .json()
            .init();
    } else {
        tracing_subscriber::fmt()
            .with_env_filter(env_filter)
            .compact()
            .init();
    }

    if args.env {
        let env_runtime = env_runtime::Env::new();
        tracing::info!("{}", env_runtime);
    }

    tracing::info!("{:?}", args);

    // Validate args
    if args.max_input_length >= args.max_total_tokens {
        return Err(LauncherError::ArgumentValidation(
            "`max_input_length` must be < `max_total_tokens`".to_string(),
        ));
    }
    if args.max_input_length as u32 > args.max_batch_prefill_tokens {
        return Err(LauncherError::ArgumentValidation(format!(
            "`max_batch_prefill_tokens` must be >= `max_input_length`. Given: {} and {}",
            args.max_batch_prefill_tokens, args.max_input_length
        )));
    }

    if args.validation_workers == 0 {
        return Err(LauncherError::ArgumentValidation(
            "`validation_workers` must be > 0".to_string(),
        ));
    }
    if args.trust_remote_code {
        tracing::warn!(
            "`trust_remote_code` is set. Trusting that model `{}` does not contain malicious code.",
            args.model_id
        );
    }

    let num_shard = find_num_shards(args.sharded, args.num_shard)?;
    if num_shard > 1 {
        tracing::info!("Sharding model on {num_shard} processes");
    }

    if let Some(ref max_batch_total_tokens) = args.max_batch_total_tokens {
        if args.max_batch_prefill_tokens > *max_batch_total_tokens {
            return Err(LauncherError::ArgumentValidation(format!(
                "`max_batch_prefill_tokens` must be <= `max_batch_total_tokens`. Given: {} and {}",
                args.max_batch_prefill_tokens, max_batch_total_tokens
            )));
        }
        if args.max_total_tokens as u32 > *max_batch_total_tokens {
            return Err(LauncherError::ArgumentValidation(format!(
                "`max_total_tokens` must be <= `max_batch_total_tokens`. Given: {} and {}",
                args.max_total_tokens, max_batch_total_tokens
            )));
        }
    }

    if args.ngrok {
        if args.ngrok_authtoken.is_none() {
            return Err(LauncherError::ArgumentValidation(
                "`ngrok-authtoken` must be set when using ngrok tunneling".to_string(),
            ));
        }

        if args.ngrok_edge.is_none() {
            return Err(LauncherError::ArgumentValidation(
                "`ngrok-edge` must be set when using ngrok tunneling".to_string(),
            ));
        }
    }

    // Signal handler
    let running = Arc::new(AtomicBool::new(true));
    let r = running.clone();
    ctrlc::set_handler(move || {
        r.store(false, Ordering::SeqCst);
    })
    .expect("Error setting Ctrl-C handler");
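    // `running` flips to false on Ctrl-C; the download and shard-startup loops
    // below poll it so the launcher can exit early on interruption.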

    // Download and convert model weights
    download_convert_model(&args, running.clone())?;

    if !running.load(Ordering::SeqCst) {
        // Launcher was asked to stop
        return Ok(());
    }

    // Shared shutdown bool
    let shutdown = Arc::new(AtomicBool::new(false));
    // Shared shutdown channel
    // When shutting down, the main thread will wait for all senders to be dropped
    let (shutdown_sender, shutdown_receiver) = mpsc::channel();

    // Shared channel to track shard status
    let (status_sender, status_receiver) = mpsc::channel();

    spawn_shards(
        num_shard,
        &args,
        shutdown.clone(),
        &shutdown_receiver,
        shutdown_sender,
        &status_receiver,
        status_sender,
        running.clone(),
    )?;

    // We might have received a termination signal
    if !running.load(Ordering::SeqCst) {
        shutdown_shards(shutdown, &shutdown_receiver);
        return Ok(());
    }

    let mut webserver =
        spawn_webserver(args, shutdown.clone(), &shutdown_receiver).map_err(|err| {
            shutdown_shards(shutdown.clone(), &shutdown_receiver);
            err
        })?;

    // Default exit code
    let mut exit_code = Ok(());

    while running.load(Ordering::SeqCst) {
        if let Ok(ShardStatus::Failed(rank)) = status_receiver.try_recv() {
            tracing::error!("Shard {rank} crashed");
            exit_code = Err(LauncherError::ShardFailed);
            break;
        };

        match webserver.try_wait().unwrap() {
            Some(_) => {
                tracing::error!("Webserver Crashed");
                shutdown_shards(shutdown, &shutdown_receiver);
                return Err(LauncherError::WebserverFailed);
            }
            None => {
                sleep(Duration::from_millis(100));
            }
        };
    }

    // Graceful termination
    terminate("webserver", webserver, Duration::from_secs(90)).unwrap();
    shutdown_shards(shutdown, &shutdown_receiver);

    exit_code
}