use clap::{Parser, ValueEnum};
use nix::sys::signal::{self, Signal};
use nix::unistd::Pid;
use serde::Deserialize;
use std::env;
use std::ffi::OsString;
use std::io::{BufRead, BufReader, Lines, Read};
use std::os::unix::process::{CommandExt, ExitStatusExt};
use std::path::Path;
use std::process::{Child, Command, ExitStatus, Stdio};
use std::sync::atomic::{AtomicBool, Ordering};
use std::sync::mpsc::TryRecvError;
use std::sync::{mpsc, Arc};
use std::thread;
use std::thread::sleep;
use std::time::{Duration, Instant};
use std::{fs, io};
use tracing_subscriber::EnvFilter;

mod env_runtime;

#[derive(Clone, Copy, Debug, ValueEnum)]
enum Quantization {
    Bitsandbytes,
    Gptq,
}

impl std::fmt::Display for Quantization {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        // Keep in sync with the values expected by `server`.
        match self {
            Quantization::Bitsandbytes => {
                write!(f, "bitsandbytes")
            }
            Quantization::Gptq => {
                write!(f, "gptq")
            }
        }
    }
}

#[derive(Clone, Copy, Debug, ValueEnum)]
enum Dtype {
    Float16,
    #[clap(name = "bfloat16")]
    BFloat16,
}

impl std::fmt::Display for Dtype {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        // Keep in sync with the values expected by `server`.
        match self {
            Dtype::Float16 => {
                write!(f, "float16")
            }
            Dtype::BFloat16 => {
                write!(f, "bfloat16")
            }
        }
    }
}

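// A typical invocation, for illustration (flags map to the `Args` fields below):
//   text-generation-launcher --model-id bigscience/bloom-560m --num-shard 2 --port 3000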
/// App Configuration
#[derive(Parser, Debug)]
#[clap(author, version, about, long_about = None)]
struct Args {
    /// The name of the model to load.
    /// Can be a MODEL_ID as listed on <https://hf.co/models> like
    /// `gpt2` or `OpenAssistant/oasst-sft-1-pythia-12b`.
    /// Or it can be a local directory containing the necessary files
    /// as saved by the `save_pretrained(...)` method of transformers.
    #[clap(default_value = "bigscience/bloom-560m", long, env)]
    model_id: String,

    /// The actual revision of the model if you're referring to a model
    /// on the hub. You can use a specific commit id or a branch like `refs/pr/2`.
    #[clap(long, env)]
    revision: Option<String>,

    /// The number of tokenizer workers used for payload validation and truncation inside the
    /// router.
    #[clap(default_value = "2", long, env)]
    validation_workers: usize,

    /// Whether to shard the model across multiple GPUs.
    /// By default text-generation-inference will use all available GPUs to run
    /// the model. Setting it to `false` deactivates `num_shard`.
    #[clap(long, env)]
    sharded: Option<bool>,

    /// The number of shards to use if you don't want to use all GPUs on a given machine.
    /// You can use `CUDA_VISIBLE_DEVICES=0,1 text-generation-launcher... --num-shard 2`
    /// and `CUDA_VISIBLE_DEVICES=2,3 text-generation-launcher... --num-shard 2` to
    /// launch 2 copies with 2 shards each on a given machine with 4 GPUs for instance.
    #[clap(long, env)]
    num_shard: Option<usize>,

    /// Whether you want the model to be quantized. This will use `bitsandbytes` for
    /// quantization on the fly, or `gptq`.
    #[clap(long, env, value_enum)]
    quantize: Option<Quantization>,

    /// The dtype to be forced upon the model. This option cannot be used with `--quantize`.
    #[clap(long, env, value_enum)]
    dtype: Option<Dtype>,

    /// Whether you want to execute hub modelling code. Explicitly passing a `revision` is
    /// encouraged when loading a model with custom code to ensure no malicious code has been
    /// contributed in a newer revision.
    #[clap(long, env, value_enum)]
    trust_remote_code: bool,

    /// The maximum amount of concurrent requests for this particular deployment.
    /// Having a low limit will refuse client requests instead of having them
    /// wait for too long, and is usually good for handling backpressure correctly.
    #[clap(default_value = "128", long, env)]
    max_concurrent_requests: usize,

    /// This is the maximum allowed value for clients to set `best_of`.
    /// Best of makes `n` generations at the same time, and returns the best
    /// in terms of overall log probability over the entire generated sequence.
    #[clap(default_value = "2", long, env)]
    max_best_of: usize,

    /// This is the maximum allowed value for clients to set `stop_sequences`.
    /// Stop sequences are used to allow the model to stop on more than just
    /// the EOS token, and enable more complex "prompting" where users can preprompt
    /// the model in a specific way and define their "own" stop token aligned with
    /// their prompt.
    #[clap(default_value = "4", long, env)]
    max_stop_sequences: usize,

    /// This is the maximum allowed input length (expressed in number of tokens)
    /// for users. The larger this value, the longer prompts users can send, which
    /// can impact the overall memory required to handle the load.
    /// Please note that some models have a finite range of sequences they can handle.
    #[clap(default_value = "1024", long, env)]
    max_input_length: usize,

    /// This is the most important value to set as it defines the "memory budget"
    /// of running client requests.
    /// Clients will send input sequences and ask to generate `max_new_tokens`
    /// on top. With a value of `1512` users can send either a prompt of
    /// `1000` and ask for `512` new tokens, or send a prompt of `1` and ask for
    /// `1511` max_new_tokens.
    /// The larger this value, the more memory each request will take in your RAM
    /// and the less effective batching can be.
    #[clap(default_value = "2048", long, env)]
    max_total_tokens: usize,

    /// This represents the ratio of waiting queries vs running queries where
    /// you want to start considering pausing the running queries to include the waiting
    /// ones into the same batch.
    /// `waiting_served_ratio=1.2` means that when 12 queries are waiting and there's
    /// only 10 queries left in the current batch, we check if we can fit those 12
    /// waiting queries into the batching strategy, and if so, batching happens,
    /// delaying the 10 running queries by a `prefill` run.
    ///
    /// This setting is only applied if there is room in the batch
    /// as defined by `max_batch_total_tokens`.
    #[clap(default_value = "1.2", long, env)]
    waiting_served_ratio: f32,

    /// Limits the number of tokens for the prefill operation.
    /// Since this operation takes the most memory and is compute-bound, it is useful
    /// to limit the number of requests that can be sent.
    #[clap(default_value = "4096", long, env)]
    max_batch_prefill_tokens: u32,

    /// **IMPORTANT** This is one critical control to allow maximum usage
    /// of the available hardware.
    ///
    /// This represents the total amount of potential tokens within a batch.
    /// When using padding (not recommended) this would be equivalent to
    /// `batch_size` * `max_total_tokens`.
    ///
    /// However in the non-padded (flash attention) version this can be much finer.
    ///
    /// For `max_batch_total_tokens=1000`, you could fit `10` queries of `total_tokens=100`
    /// or a single query of `1000` tokens.
    ///
    /// Overall this number should be the largest possible amount that fits the
    /// remaining memory (after the model is loaded). Since the actual memory overhead
    /// depends on other parameters such as whether you're using quantization, flash attention
    /// or the model implementation, text-generation-inference cannot infer this number
    /// automatically.
    #[clap(long, env)]
    max_batch_total_tokens: Option<u32>,

    /// This setting defines how many tokens can be passed before forcing the waiting
    /// queries to be put on the batch (if the size of the batch allows for it).
    /// New queries require 1 `prefill` forward, which is different from `decode`,
    /// and therefore you need to pause the running batch in order to run `prefill`
    /// to create the correct values for the waiting queries to be able to join the batch.
    ///
    /// With a value too small, queries will always "steal" the compute to run `prefill`,
    /// and running queries will be delayed by a lot.
    ///
    /// With a value too big, waiting queries could wait for a very long time
    /// before being allowed a slot in the running batch. If your server is busy
    /// that means that requests that could run in ~2s on an empty server could
    /// end up running in ~20s because the query had to wait for 18s.
    ///
    /// This number is expressed in number of tokens to make it a bit more
    /// model-agnostic, but what should really matter is the overall latency
    /// for end users.
    #[clap(default_value = "20", long, env)]
    max_waiting_tokens: usize,

    /// The IP address to listen on
    #[clap(default_value = "0.0.0.0", long, env)]
    hostname: String,

    /// The port to listen on.
    #[clap(default_value = "3000", long, short, env)]
    port: u16,

    /// The name of the socket for gRPC communication between the webserver
    /// and the shards.
    #[clap(default_value = "/tmp/text-generation-server", long, env)]
    shard_uds_path: String,

    /// The address the master shard will listen on. (setting used by torch distributed)
    #[clap(default_value = "localhost", long, env)]
    master_addr: String,

    /// The port the master shard will listen on. (setting used by torch distributed)
    #[clap(default_value = "29500", long, env)]
    master_port: usize,

    /// The location of the huggingface hub cache.
    /// Used to override the location if you want to provide a mounted disk for instance
    #[clap(long, env)]
    huggingface_hub_cache: Option<String>,

    /// The location of the weights cache.
    /// Used to override where downloaded weights are stored, for instance when running
    /// inside a HuggingFace Inference Endpoint.
    #[clap(long, env)]
    weights_cache_override: Option<String>,

    /// For some models (like bloom), text-generation-inference implemented custom
    /// CUDA kernels to speed up inference. Those kernels were only tested on A100.
    /// Use this flag to disable them if you're running on different hardware and
    /// encounter issues.
    #[clap(long, env)]
    disable_custom_kernels: bool,

    /// Limit the CUDA available memory.
    /// The allowed value equals the total visible memory multiplied by cuda-memory-fraction.
    #[clap(default_value = "1.0", long, env)]
    cuda_memory_fraction: f32,

    /// Outputs the logs in JSON format (useful for telemetry)
    #[clap(long, env)]
    json_output: bool,

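    /// The OpenTelemetry endpoint to send traces to, forwarded to the router and
    /// shards as `--otlp-endpoint`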
    #[clap(long, env)]
    otlp_endpoint: Option<String>,

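    /// Origins allowed by the webserver's CORS policy, forwarded to the router as
    /// `--cors-allow-origin`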
    #[clap(long, env)]
    cors_allow_origin: Vec<String>,
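
    /// The gamma parameter used for watermarking, forwarded to the shards as the
    /// `WATERMARK_GAMMA` environment variable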
    #[clap(long, env)]
    watermark_gamma: Option<f32>,
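    /// The delta parameter used for watermarking, forwarded to the shards as the
    /// `WATERMARK_DELTA` environment variable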
    #[clap(long, env)]
    watermark_delta: Option<f32>,

    /// Enable ngrok tunneling
    #[clap(long, env)]
    ngrok: bool,

    /// ngrok authentication token
    #[clap(long, env)]
    ngrok_authtoken: Option<String>,

    /// ngrok edge
    #[clap(long, env)]
    ngrok_edge: Option<String>,

    /// Display a lot of information about your runtime environment
    #[clap(long, short, action)]
    env: bool,
}

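/// Status messages a shard manager thread sends back to the main thread.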
#[derive(Debug)]
enum ShardStatus {
    Ready,
    Failed(usize),
}

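/// Spawn and supervise a single `text-generation-server` shard process,
/// reporting readiness or failure over `status_sender`.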
#[allow(clippy::too_many_arguments)]
fn shard_manager(
    model_id: String,
    revision: Option<String>,
    quantize: Option<Quantization>,
    dtype: Option<Dtype>,
    trust_remote_code: bool,
    uds_path: String,
    rank: usize,
    world_size: usize,
    master_addr: String,
    master_port: usize,
    huggingface_hub_cache: Option<String>,
    weights_cache_override: Option<String>,
    disable_custom_kernels: bool,
    watermark_gamma: Option<f32>,
    watermark_delta: Option<f32>,
    cuda_memory_fraction: f32,
    otlp_endpoint: Option<String>,
    status_sender: mpsc::Sender<ShardStatus>,
    shutdown: Arc<AtomicBool>,
    _shutdown_sender: mpsc::Sender<()>,
) {
    // Enter shard-manager tracing span
    let _span = tracing::span!(tracing::Level::INFO, "shard-manager", rank = rank).entered();

    // Get UDS path
    let uds_string = format!("{uds_path}-{rank}");
    let uds = Path::new(&uds_string);
    // Clean previous runs
    if uds.exists() {
        fs::remove_file(uds).unwrap();
    }

    // Process args
    let mut shard_args = vec![
        "serve".to_string(),
        model_id,
        "--uds-path".to_string(),
        uds_path,
        "--logger-level".to_string(),
        "INFO".to_string(),
        "--json-output".to_string(),
    ];

    // Activate trust remote code
    if trust_remote_code {
        shard_args.push("--trust-remote-code".to_string());
    }

    // Activate tensor parallelism
    if world_size > 1 {
        shard_args.push("--sharded".to_string());
    }

    if let Some(quantize) = quantize {
        shard_args.push("--quantize".to_string());
        shard_args.push(quantize.to_string())
    }

    if let Some(dtype) = dtype {
        shard_args.push("--dtype".to_string());
        shard_args.push(dtype.to_string())
    }

    // Model optional revision
    if let Some(revision) = revision {
        shard_args.push("--revision".to_string());
        shard_args.push(revision)
    }

    // OpenTelemetry
    if let Some(otlp_endpoint) = otlp_endpoint {
        shard_args.push("--otlp-endpoint".to_string());
        shard_args.push(otlp_endpoint);
    }

    // Copy current process env
    let mut envs: Vec<(OsString, OsString)> = env::vars_os().collect();

    // Torch Distributed Env vars
    envs.push(("RANK".into(), rank.to_string().into()));
    envs.push(("WORLD_SIZE".into(), world_size.to_string().into()));
    envs.push(("MASTER_ADDR".into(), master_addr.into()));
    envs.push(("MASTER_PORT".into(), master_port.to_string().into()));
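    // Fail fast instead of hanging if an asynchronous NCCL error occurs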
    envs.push(("NCCL_ASYNC_ERROR_HANDLING".into(), "1".into()));

    // CUDA memory fraction
    envs.push((
        "CUDA_MEMORY_FRACTION".into(),
        cuda_memory_fraction.to_string().into(),
    ));

    // Safetensors load fast
    envs.push(("SAFETENSORS_FAST_GPU".into(), "1".into()));

    // Enable hf transfer for insane download speeds
    let enable_hf_transfer = env::var("HF_HUB_ENABLE_HF_TRANSFER").unwrap_or("1".to_string());
    envs.push((
        "HF_HUB_ENABLE_HF_TRANSFER".into(),
        enable_hf_transfer.into(),
    ));

    // Parse Inference API token
    if let Ok(api_token) = env::var("HF_API_TOKEN") {
        envs.push(("HUGGING_FACE_HUB_TOKEN".into(), api_token.into()))
    };

    // If huggingface_hub_cache is some, pass it to the shard
    // Useful when running inside a docker container
    if let Some(huggingface_hub_cache) = huggingface_hub_cache {
        envs.push(("HUGGINGFACE_HUB_CACHE".into(), huggingface_hub_cache.into()));
    };

    // If weights_cache_override is some, pass it to the shard
    // Useful when running inside a HuggingFace Inference Endpoint
    if let Some(weights_cache_override) = weights_cache_override {
        envs.push((
            "WEIGHTS_CACHE_OVERRIDE".into(),
            weights_cache_override.into(),
        ));
    };

    // If disable_custom_kernels is true, pass it to the shard as an env var
    if disable_custom_kernels {
        envs.push(("DISABLE_CUSTOM_KERNELS".into(), "True".into()))
    }

    // Watermark Gamma
    if let Some(watermark_gamma) = watermark_gamma {
        envs.push(("WATERMARK_GAMMA".into(), watermark_gamma.to_string().into()))
    }

    // Watermark Delta
    if let Some(watermark_delta) = watermark_delta {
        envs.push(("WATERMARK_DELTA".into(), watermark_delta.to_string().into()))
    }

    // Start process
    tracing::info!("Starting shard");
    let mut p = match Command::new("text-generation-server")
        .args(shard_args)
        .envs(envs)
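        // Pipe stdio so shard logs can be forwarded, and run the shard in its own
        // process group so terminal signals reach the launcher first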
        .stdout(Stdio::piped())
        .stderr(Stdio::piped())
        .process_group(0)
        .spawn()
    {
        Ok(p) => p,
        Err(err) => {
            if err.kind() == io::ErrorKind::NotFound {
                tracing::error!("text-generation-server not found in PATH");
                tracing::error!("Please install it with `make install-server`")
            } else {
                tracing::error!("{}", err);
            }

            status_sender.send(ShardStatus::Failed(rank)).unwrap();
            return;
        }
    };

    // Redirect STDOUT to the console
    let shard_stdout_reader = BufReader::new(p.stdout.take().unwrap());
    let shard_stderr_reader = BufReader::new(p.stderr.take().unwrap());

    // stdout tracing thread
    thread::spawn(move || {
        log_lines(shard_stdout_reader.lines());
    });

    let mut ready = false;
    let start_time = Instant::now();
    let mut wait_time = Instant::now();
    loop {
        // Process exited
        if let Some(exit_status) = p.try_wait().unwrap() {
            // We read stderr in another thread as it seems that lines() can block in some cases
            let (err_sender, err_receiver) = mpsc::channel();
            thread::spawn(move || {
                for line in shard_stderr_reader.lines().flatten() {
                    err_sender.send(line).unwrap_or(());
                }
            });

            let mut err = String::new();
            while let Ok(line) = err_receiver.recv_timeout(Duration::from_millis(10)) {
                err = err + "\n" + &line;
            }

            tracing::error!("Shard complete standard error output:\n{err}");

            if let Some(signal) = exit_status.signal() {
                tracing::error!("Shard process was signaled to shutdown with signal {signal}");
            }

            status_sender.send(ShardStatus::Failed(rank)).unwrap();
            return;
        }

        // We received a shutdown signal
        if shutdown.load(Ordering::SeqCst) {
            p.kill().unwrap();
            let _ = p.wait();
            tracing::info!("Shard terminated");
            return;
        }

        // Shard is ready
        if uds.exists() && !ready {
            tracing::info!("Shard ready in {:?}", start_time.elapsed());
            status_sender.send(ShardStatus::Ready).unwrap();
            ready = true;
        } else if !ready && wait_time.elapsed() > Duration::from_secs(10) {
            tracing::info!("Waiting for shard to be ready...");
            wait_time = Instant::now();
        }
        sleep(Duration::from_millis(100));
    }
}

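/// Flip the shared shutdown flag and block until every shard manager thread
/// has exited.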
fn shutdown_shards(shutdown: Arc<AtomicBool>, shutdown_receiver: &mpsc::Receiver<()>) {
    tracing::info!("Shutting down shards");
    // Update shutdown value to true
    // This will be picked up by the shard manager
    shutdown.store(true, Ordering::SeqCst);

    // Wait for shards to shutdown
    // This will block until all shutdown senders are dropped
    let _ = shutdown_receiver.recv();
}

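/// Infer the number of visible CUDA devices from `CUDA_VISIBLE_DEVICES` or
/// `NVIDIA_VISIBLE_DEVICES`.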
fn num_cuda_devices() -> Option<usize> {
    let devices = match env::var("CUDA_VISIBLE_DEVICES") {
        Ok(devices) => devices,
        Err(_) => env::var("NVIDIA_VISIBLE_DEVICES").ok()?,
    };
    let n_devices = devices.split(',').count();
    Some(n_devices)
}

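// Mirrors the JSON log records emitted by the Python server's logger (loguru-style
// levels, including SUCCESS) so shard output can be re-logged at a matching level.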
#[derive(Deserialize)]
#[serde(rename_all = "UPPERCASE")]
enum PythonLogLevelEnum {
    Trace,
    Debug,
    Info,
    Success,
    Warning,
    Error,
    Critical,
}

#[derive(Deserialize)]
struct PythonLogLevel {
    name: PythonLogLevelEnum,
}

#[derive(Deserialize)]
struct PythonLogRecord {
    level: PythonLogLevel,
}

#[derive(Deserialize)]
struct PythonLogMessage {
    text: String,
    record: PythonLogRecord,
}

impl PythonLogMessage {
    fn trace(&self) {
        match self.record.level.name {
            PythonLogLevelEnum::Trace => tracing::trace!("{}", self.text),
            PythonLogLevelEnum::Debug => tracing::debug!("{}", self.text),
            PythonLogLevelEnum::Info => tracing::info!("{}", self.text),
            PythonLogLevelEnum::Success => tracing::info!("{}", self.text),
            PythonLogLevelEnum::Warning => tracing::warn!("{}", self.text),
            PythonLogLevelEnum::Error => tracing::error!("{}", self.text),
            PythonLogLevelEnum::Critical => tracing::error!("{}", self.text),
        }
    }
}

impl TryFrom<&String> for PythonLogMessage {
    type Error = serde_json::Error;

    fn try_from(value: &String) -> Result<Self, Self::Error> {
        serde_json::from_str::<Self>(value)
    }
}

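/// Forward each line as a structured Python log record when it parses as one,
/// falling back to a raw `debug!` log otherwise.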
fn log_lines<S: Sized + BufRead>(lines: Lines<S>) {
    for line in lines.flatten() {
        match PythonLogMessage::try_from(&line) {
            Ok(log) => log.trace(),
            Err(_) => tracing::debug!("{line}"),
        }
    }
}

fn find_num_shards(
    sharded: Option<bool>,
    num_shard: Option<usize>,
) -> Result<usize, LauncherError> {
    // get the number of shards given `sharded` and `num_shard`
    let num_shard = match (sharded, num_shard) {
        (Some(true), None) => {
            // try to default to the number of available GPUs
            tracing::info!("Parsing num_shard from CUDA_VISIBLE_DEVICES/NVIDIA_VISIBLE_DEVICES");
            let n_devices = num_cuda_devices()
                .expect("--num-shard and CUDA_VISIBLE_DEVICES/NVIDIA_VISIBLE_DEVICES are not set");
            if n_devices <= 1 {
                return Err(LauncherError::NotEnoughCUDADevices(format!(
                    "`sharded` is true but only found {n_devices} CUDA devices"
                )));
            }
            n_devices
        }
        (Some(true), Some(num_shard)) => {
            // we can't have only one shard while sharded
            if num_shard <= 1 {
                return Err(LauncherError::ArgumentValidation(
                    "`sharded` is true but `num_shard` <= 1".to_string(),
                ));
            }
            num_shard
        }
        (Some(false), Some(num_shard)) => num_shard,
        (Some(false), None) => 1,
        (None, None) => num_cuda_devices().unwrap_or(1),
        (None, Some(num_shard)) => num_shard,
    };
    if num_shard < 1 {
        return Err(LauncherError::ArgumentValidation(
            "`num_shard` cannot be < 1".to_string(),
        ));
    }
    Ok(num_shard)
}

#[derive(Debug)]
enum LauncherError {
    ArgumentValidation(String),
    NotEnoughCUDADevices(String),
    DownloadError,
    ShardCannotStart,
    ShardDisconnected,
    ShardFailed,
    WebserverFailed,
    WebserverCannotStart,
}

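/// Run `text-generation-server download-weights` to fetch the model weights
/// (as safetensors) before any shard starts.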
fn download_convert_model(args: &Args, running: Arc<AtomicBool>) -> Result<(), LauncherError> {
    // Enter download tracing span
    let _span = tracing::span!(tracing::Level::INFO, "download").entered();

    let mut download_args = vec![
        "download-weights".to_string(),
        args.model_id.to_string(),
        "--extension".to_string(),
        ".safetensors".to_string(),
        "--logger-level".to_string(),
        "INFO".to_string(),
        "--json-output".to_string(),
    ];

    // Model optional revision
    if let Some(revision) = &args.revision {
        download_args.push("--revision".to_string());
        download_args.push(revision.to_string())
    }

    // Copy current process env
    let mut envs: Vec<(OsString, OsString)> = env::vars_os().collect();

    // If huggingface_hub_cache is set, pass it to the download process
    // Useful when running inside a docker container
    if let Some(ref huggingface_hub_cache) = args.huggingface_hub_cache {
        envs.push(("HUGGINGFACE_HUB_CACHE".into(), huggingface_hub_cache.into()));
    };

    // Enable hf transfer for insane download speeds
    let enable_hf_transfer = env::var("HF_HUB_ENABLE_HF_TRANSFER").unwrap_or("1".to_string());
    envs.push((
        "HF_HUB_ENABLE_HF_TRANSFER".into(),
        enable_hf_transfer.into(),
    ));

    // Parse Inference API token
    if let Ok(api_token) = env::var("HF_API_TOKEN") {
        envs.push(("HUGGING_FACE_HUB_TOKEN".into(), api_token.into()))
    };

    // If args.weights_cache_override is some, pass it to the download process
    // Useful when running inside a HuggingFace Inference Endpoint
    if let Some(weights_cache_override) = &args.weights_cache_override {
        envs.push((
            "WEIGHTS_CACHE_OVERRIDE".into(),
            weights_cache_override.into(),
        ));
    };

    // Start process
    tracing::info!("Starting download process.");
    let mut download_process = match Command::new("text-generation-server")
        .args(download_args)
        .envs(envs)
        .stdout(Stdio::piped())
        .stderr(Stdio::piped())
        .process_group(0)
        .spawn()
    {
        Ok(p) => p,
        Err(err) => {
            if err.kind() == io::ErrorKind::NotFound {
                tracing::error!("text-generation-server not found in PATH");
                tracing::error!("Please install it with `make install-server`")
            } else {
                tracing::error!("{}", err);
            }

            return Err(LauncherError::DownloadError);
        }
    };

    // Redirect STDOUT to the console
    let download_stdout = download_process.stdout.take().unwrap();
    let stdout = BufReader::new(download_stdout);

    thread::spawn(move || {
        log_lines(stdout.lines());
    });

    loop {
        if let Some(status) = download_process.try_wait().unwrap() {
            if status.success() {
                tracing::info!("Successfully downloaded weights.");
                break;
            }

            let mut err = String::new();
            download_process
                .stderr
                .take()
                .unwrap()
                .read_to_string(&mut err)
                .unwrap();
            if let Some(signal) = status.signal() {
                tracing::error!(
                    "Download process was signaled to shutdown with signal {signal}: {err}"
                );
            } else {
                tracing::error!("Download encountered an error: {err}");
            }

            return Err(LauncherError::DownloadError);
        }
        if !running.load(Ordering::SeqCst) {
            terminate("download", download_process, Duration::from_secs(10)).unwrap();
            return Ok(());
        }
        sleep(Duration::from_millis(100));
    }
    Ok(())
}

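/// Start one shard manager thread per rank and wait until every shard reports
/// `Ready` (or one of them fails to start).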
#[allow(clippy::too_many_arguments)]
fn spawn_shards(
    num_shard: usize,
    args: &Args,
    shutdown: Arc<AtomicBool>,
    shutdown_receiver: &mpsc::Receiver<()>,
    shutdown_sender: mpsc::Sender<()>,
    status_receiver: &mpsc::Receiver<ShardStatus>,
    status_sender: mpsc::Sender<ShardStatus>,
    running: Arc<AtomicBool>,
) -> Result<(), LauncherError> {
    // Start shard processes
    for rank in 0..num_shard {
        let model_id = args.model_id.clone();
        let revision = args.revision.clone();
        let uds_path = args.shard_uds_path.clone();
        let master_addr = args.master_addr.clone();
        let huggingface_hub_cache = args.huggingface_hub_cache.clone();
        let weights_cache_override = args.weights_cache_override.clone();
        let status_sender = status_sender.clone();
        let shutdown = shutdown.clone();
        let shutdown_sender = shutdown_sender.clone();
        let otlp_endpoint = args.otlp_endpoint.clone();
        let quantize = args.quantize;
        let dtype = args.dtype;
        let trust_remote_code = args.trust_remote_code;
        let master_port = args.master_port;
        let disable_custom_kernels = args.disable_custom_kernels;
        let watermark_gamma = args.watermark_gamma;
        let watermark_delta = args.watermark_delta;
        let cuda_memory_fraction = args.cuda_memory_fraction;
        thread::spawn(move || {
            shard_manager(
                model_id,
                revision,
                quantize,
                dtype,
                trust_remote_code,
                uds_path,
                rank,
                num_shard,
                master_addr,
                master_port,
                huggingface_hub_cache,
                weights_cache_override,
                disable_custom_kernels,
                watermark_gamma,
                watermark_delta,
                cuda_memory_fraction,
                otlp_endpoint,
                status_sender,
                shutdown,
                shutdown_sender,
            )
        });
    }
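    // Drop the main thread's sender so `shutdown_receiver.recv()` returns once
    // every shard thread has dropped its own clone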
    drop(shutdown_sender);

    // Wait for shard to start
    let mut shard_ready = 0;
    while running.load(Ordering::SeqCst) {
        match status_receiver.try_recv() {
            Ok(ShardStatus::Ready) => {
                shard_ready += 1;
                if shard_ready == num_shard {
                    break;
                }
            }
            Err(TryRecvError::Empty) => {
                sleep(Duration::from_millis(100));
            }
            Ok(ShardStatus::Failed(rank)) => {
                tracing::error!("Shard {rank} failed to start");
                shutdown_shards(shutdown, shutdown_receiver);
                return Err(LauncherError::ShardCannotStart);
            }
            Err(TryRecvError::Disconnected) => {
                tracing::error!("Shard status channel disconnected");
                shutdown_shards(shutdown, shutdown_receiver);
                return Err(LauncherError::ShardDisconnected);
            }
        }
    }
    Ok(())
}

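/// Launch the `text-generation-router` webserver once all shards are ready.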
fn spawn_webserver(
    args: Args,
    shutdown: Arc<AtomicBool>,
    shutdown_receiver: &mpsc::Receiver<()>,
) -> Result<Child, LauncherError> {
    // All shards started
    // Start webserver
    tracing::info!("Starting Webserver");
    let mut router_args = vec![
        "--max-concurrent-requests".to_string(),
        args.max_concurrent_requests.to_string(),
        "--max-best-of".to_string(),
        args.max_best_of.to_string(),
        "--max-stop-sequences".to_string(),
        args.max_stop_sequences.to_string(),
        "--max-input-length".to_string(),
        args.max_input_length.to_string(),
        "--max-total-tokens".to_string(),
        args.max_total_tokens.to_string(),
        "--max-batch-prefill-tokens".to_string(),
        args.max_batch_prefill_tokens.to_string(),
        "--waiting-served-ratio".to_string(),
        args.waiting_served_ratio.to_string(),
        "--max-waiting-tokens".to_string(),
        args.max_waiting_tokens.to_string(),
        "--validation-workers".to_string(),
        args.validation_workers.to_string(),
        "--hostname".to_string(),
        args.hostname.to_string(),
        "--port".to_string(),
        args.port.to_string(),
        "--master-shard-uds-path".to_string(),
        format!("{}-0", args.shard_uds_path),
        "--tokenizer-name".to_string(),
        args.model_id,
    ];

    // Model optional max batch total tokens
    if let Some(max_batch_total_tokens) = args.max_batch_total_tokens {
        router_args.push("--max-batch-total-tokens".to_string());
        router_args.push(max_batch_total_tokens.to_string());
    }

    // Model optional revision
    if let Some(ref revision) = args.revision {
        router_args.push("--revision".to_string());
        router_args.push(revision.to_string())
    }

    if args.json_output {
        router_args.push("--json-output".to_string());
    }

    // OpenTelemetry
    if let Some(otlp_endpoint) = args.otlp_endpoint {
        router_args.push("--otlp-endpoint".to_string());
        router_args.push(otlp_endpoint);
    }

    // CORS origins
    for origin in args.cors_allow_origin.into_iter() {
        router_args.push("--cors-allow-origin".to_string());
        router_args.push(origin);
    }

    // Ngrok
    if args.ngrok {
        router_args.push("--ngrok".to_string());
        router_args.push("--ngrok-authtoken".to_string());
        router_args.push(args.ngrok_authtoken.unwrap());
        router_args.push("--ngrok-edge".to_string());
        router_args.push(args.ngrok_edge.unwrap());
    }

    // Copy current process env
    let mut envs: Vec<(OsString, OsString)> = env::vars_os().collect();

    // Parse Inference API token
    if let Ok(api_token) = env::var("HF_API_TOKEN") {
        envs.push(("HUGGING_FACE_HUB_TOKEN".into(), api_token.into()))
    };

    let mut webserver = match Command::new("text-generation-router")
        .args(router_args)
        .envs(envs)
        .stdout(Stdio::piped())
        .stderr(Stdio::piped())
        .process_group(0)
        .spawn()
    {
        Ok(p) => p,
        Err(err) => {
            tracing::error!("Failed to start webserver: {}", err);
            if err.kind() == io::ErrorKind::NotFound {
                tracing::error!("text-generation-router not found in PATH");
                tracing::error!("Please install it with `make install-router`")
            } else {
                tracing::error!("{}", err);
            }

            shutdown_shards(shutdown, shutdown_receiver);
            return Err(LauncherError::WebserverCannotStart);
        }
    };

    // Redirect STDOUT and STDERR to the console
    let webserver_stdout = webserver.stdout.take().unwrap();
    let webserver_stderr = webserver.stderr.take().unwrap();

    thread::spawn(move || {
        let stdout = BufReader::new(webserver_stdout);
        let stderr = BufReader::new(webserver_stderr);
        for line in stdout.lines() {
            println!("{}", line.unwrap());
        }
        for line in stderr.lines() {
            println!("{}", line.unwrap());
        }
    });
    Ok(webserver)
}

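/// Send SIGTERM to a child process and wait up to `timeout` for it to exit,
/// escalating to a hard kill if it does not.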
fn terminate(process_name: &str, mut process: Child, timeout: Duration) -> io::Result<ExitStatus> {
    tracing::info!("Terminating {process_name}");

    let terminate_time = Instant::now();
    signal::kill(Pid::from_raw(process.id() as i32), Signal::SIGTERM).unwrap();

    tracing::info!("Waiting for {process_name} to gracefully shutdown");

    while terminate_time.elapsed() < timeout {
        if let Some(status) = process.try_wait()? {
            tracing::info!("{process_name} terminated");
            return Ok(status);
        }
        sleep(Duration::from_millis(100));
    }

    tracing::info!("Killing {process_name}");

    process.kill()?;
    let exit_status = process.wait()?;

    tracing::info!("{process_name} killed");
    Ok(exit_status)
}

fn main() -> Result<(), LauncherError> {
    // Pattern match configuration
    let args: Args = Args::parse();

    // Filter events with LOG_LEVEL
    let env_filter =
        EnvFilter::try_from_env("LOG_LEVEL").unwrap_or_else(|_| EnvFilter::new("info"));

    if args.json_output {
        tracing_subscriber::fmt()
            .with_env_filter(env_filter)
            .json()
            .init();
    } else {
        tracing_subscriber::fmt()
            .with_env_filter(env_filter)
            .compact()
            .init();
    }

    if args.env {
        let env_runtime = env_runtime::Env::new();
        tracing::info!("{}", env_runtime);
    }

    tracing::info!("{:?}", args);

    // Validate args
    if args.max_input_length >= args.max_total_tokens {
        return Err(LauncherError::ArgumentValidation(
            "`max_input_length` must be < `max_total_tokens`".to_string(),
        ));
    }
    if args.max_input_length as u32 > args.max_batch_prefill_tokens {
        return Err(LauncherError::ArgumentValidation(format!(
            "`max_batch_prefill_tokens` must be >= `max_input_length`. Given: {} and {}",
            args.max_batch_prefill_tokens, args.max_input_length
        )));
    }

    if args.validation_workers == 0 {
        return Err(LauncherError::ArgumentValidation(
            "`validation_workers` must be > 0".to_string(),
        ));
    }

    if args.trust_remote_code {
        tracing::warn!(
            "`trust_remote_code` is set. Trusting that model `{}` does not contain malicious code.",
            args.model_id
        );
    }

    let num_shard = find_num_shards(args.sharded, args.num_shard)?;
    if num_shard > 1 {
        tracing::info!("Sharding model on {num_shard} processes");
    }

    if let Some(ref max_batch_total_tokens) = args.max_batch_total_tokens {
        if args.max_batch_prefill_tokens > *max_batch_total_tokens {
            return Err(LauncherError::ArgumentValidation(format!(
                "`max_batch_prefill_tokens` must be <= `max_batch_total_tokens`. Given: {} and {}",
                args.max_batch_prefill_tokens, max_batch_total_tokens
            )));
        }
        if args.max_total_tokens as u32 > *max_batch_total_tokens {
            return Err(LauncherError::ArgumentValidation(format!(
                "`max_total_tokens` must be <= `max_batch_total_tokens`. Given: {} and {}",
                args.max_total_tokens, max_batch_total_tokens
            )));
        }
    }

    if args.ngrok {
        if args.ngrok_authtoken.is_none() {
            return Err(LauncherError::ArgumentValidation(
                "`ngrok-authtoken` must be set when using ngrok tunneling".to_string(),
            ));
        }

        if args.ngrok_edge.is_none() {
            return Err(LauncherError::ArgumentValidation(
                "`ngrok-edge` must be set when using ngrok tunneling".to_string(),
            ));
        }
    }

    // Signal handler
    let running = Arc::new(AtomicBool::new(true));
    let r = running.clone();
    ctrlc::set_handler(move || {
        r.store(false, Ordering::SeqCst);
    })
    .expect("Error setting Ctrl-C handler");

    // Download and convert model weights
    download_convert_model(&args, running.clone())?;

    if !running.load(Ordering::SeqCst) {
        // Launcher was asked to stop
        return Ok(());
    }

    // Shared shutdown bool
    let shutdown = Arc::new(AtomicBool::new(false));
    // Shared shutdown channel
    // When shutting down, the main thread will wait for all senders to be dropped
    let (shutdown_sender, shutdown_receiver) = mpsc::channel();

    // Shared channel to track shard status
    let (status_sender, status_receiver) = mpsc::channel();

    spawn_shards(
        num_shard,
        &args,
        shutdown.clone(),
        &shutdown_receiver,
        shutdown_sender,
        &status_receiver,
        status_sender,
        running.clone(),
    )?;

    // We might have received a termination signal
    if !running.load(Ordering::SeqCst) {
        shutdown_shards(shutdown, &shutdown_receiver);
        return Ok(());
    }

    let mut webserver =
        spawn_webserver(args, shutdown.clone(), &shutdown_receiver).map_err(|err| {
            shutdown_shards(shutdown.clone(), &shutdown_receiver);
            err
        })?;

    // Default exit code
    let mut exit_code = Ok(());

    while running.load(Ordering::SeqCst) {
        if let Ok(ShardStatus::Failed(rank)) = status_receiver.try_recv() {
            tracing::error!("Shard {rank} crashed");
            exit_code = Err(LauncherError::ShardFailed);
            break;
        };

        match webserver.try_wait().unwrap() {
            Some(_) => {
                tracing::error!("Webserver Crashed");
                shutdown_shards(shutdown, &shutdown_receiver);
                return Err(LauncherError::WebserverFailed);
            }
            None => {
                sleep(Duration::from_millis(100));
            }
        };
    }

    // Graceful termination
    terminate("webserver", webserver, Duration::from_secs(90)).unwrap();
    shutdown_shards(shutdown, &shutdown_receiver);

    exit_code
}