use clap::{Parser, ValueEnum};
use nix::sys::signal::{self, Signal};
use nix::unistd::Pid;
use serde::Deserialize;
use std::env;
use std::ffi::OsString;
use std::io::{BufRead, BufReader, Lines, Read};
use std::os::unix::process::{CommandExt, ExitStatusExt};
use std::path::Path;
use std::process::{Child, Command, ExitStatus, Stdio};
use std::sync::atomic::{AtomicBool, Ordering};
use std::sync::mpsc::TryRecvError;
use std::sync::{mpsc, Arc};
use std::thread;
use std::thread::sleep;
use std::time::{Duration, Instant};
use std::{fs, io};
use tracing_subscriber::EnvFilter;

mod env_runtime;

#[derive(Clone, Copy, Debug, ValueEnum)]
enum Quantization {
    Bitsandbytes,
    Gptq,
}

impl std::fmt::Display for Quantization {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        // Kept in sync with `server`.
        match self {
            Quantization::Bitsandbytes => {
                write!(f, "bitsandbytes")
            }
            Quantization::Gptq => {
                write!(f, "gptq")
            }
        }
    }
}

#[derive(Clone, Copy, Debug, ValueEnum)]
enum Dtype {
    Float16,
    #[clap(name = "bfloat16")]
    BFloat16,
}

impl std::fmt::Display for Dtype {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        // Kept in sync with `server`.
        match self {
            Dtype::Float16 => {
                write!(f, "float16")
            }
            Dtype::BFloat16 => {
                write!(f, "bfloat16")
            }
        }
    }
}

/// App Configuration
#[derive(Parser, Debug)]
#[clap(author, version, about, long_about = None)]
struct Args {
    /// The name of the model to load.
    /// Can be a MODEL_ID as listed on <https://hf.co/models> like
    /// `gpt2` or `OpenAssistant/oasst-sft-1-pythia-12b`.
    /// Or it can be a local directory containing the necessary files
    /// as saved by `save_pretrained(...)` methods of transformers
    #[clap(default_value = "bigscience/bloom-560m", long, env)]
    model_id: String,

    /// The actual revision of the model if you're referring to a model
    /// on the hub. You can use a specific commit id or a branch like `refs/pr/2`.
    #[clap(long, env)]
    revision: Option<String>,

    /// The number of tokenizer workers used for payload validation and truncation inside the
    /// router.
    #[clap(default_value = "2", long, env)]
    validation_workers: usize,

    /// Whether to shard the model across multiple GPUs
    /// By default text-generation-inference will use all available GPUs to run
    /// the model. Setting it to `false` deactivates `num_shard`.
    #[clap(long, env)]
    sharded: Option<bool>,

    /// The number of shards to use if you don't want to use all GPUs on a given machine.
    /// You can use `CUDA_VISIBLE_DEVICES=0,1 text-generation-launcher... --num_shard 2`
    /// and `CUDA_VISIBLE_DEVICES=2,3 text-generation-launcher... --num_shard 2` to
    /// launch 2 copies with 2 shards each on a given machine with 4 GPUs for instance.
    #[clap(long, env)]
    num_shard: Option<usize>,

    /// Whether you want the model to be quantized. This will use `bitsandbytes` for
    /// quantization on the fly, or `gptq`.
    #[clap(long, env, value_enum)]
    quantize: Option<Quantization>,

    /// The dtype to be forced upon the model. This option cannot be used with `--quantize`.
    #[clap(long, env, value_enum)]
    dtype: Option<Dtype>,

    /// Whether you want to execute hub modelling code. Explicitly passing a `revision` is
    /// encouraged when loading a model with custom code to ensure no malicious code has been
    /// contributed in a newer revision.
    #[clap(long, env, value_enum)]
    trust_remote_code: bool,

    /// The maximum number of concurrent requests for this particular deployment.
    /// Having a low limit will refuse client requests instead of having them
    /// wait for too long, and is usually a good way to handle backpressure correctly.
    #[clap(default_value = "128", long, env)]
    max_concurrent_requests: usize,

    /// This is the maximum allowed value for clients to set `best_of`.
    /// Best of makes `n` generations at the same time, and returns the best
    /// in terms of overall log probability over the entire generated sequence.
    #[clap(default_value = "2", long, env)]
    max_best_of: usize,

    /// This is the maximum allowed value for clients to set `stop_sequences`.
    /// Stop sequences are used to allow the model to stop on more than just
    /// the EOS token, and enable more complex "prompting" where users can preprompt
    /// the model in a specific way and define their "own" stop token aligned with
    /// their prompt.
    #[clap(default_value = "4", long, env)]
    max_stop_sequences: usize,

    /// This is the maximum allowed input length (expressed in number of tokens)
    /// for users. The larger this value, the longer the prompts users can send, which
    /// can impact the overall memory required to handle the load.
    /// Please note that some models have a finite range of sequence lengths they can handle.
    #[clap(default_value = "1024", long, env)]
    max_input_length: usize,

    /// This is the most important value to set as it defines the "memory budget"
    /// of running client requests.
    /// Clients will send input sequences and ask to generate `max_new_tokens`
    /// on top. With a value of `1512`, users can send either a prompt of
    /// `1000` and ask for `512` new tokens, or send a prompt of `1` and ask for
    /// `1511` max_new_tokens.
    /// The larger this value, the more memory each request takes up in your RAM
    /// and the less effective batching can be.
    #[clap(default_value = "2048", long, env)]
    max_total_tokens: usize,

    /// This represents the ratio of waiting queries vs running queries where
    /// you want to start considering pausing the running queries to include the waiting
    /// ones into the same batch.
    /// `waiting_served_ratio=1.2` means that when 12 queries are waiting and there are
    /// only 10 queries left in the current batch, we check if we can fit those 12
    /// waiting queries into the batching strategy, and if yes, then batching happens,
    /// delaying the 10 running queries by a `prefill` run.
    ///
    /// This setting is only applied if there is room in the batch
    /// as defined by `max_batch_total_tokens`.
    #[clap(default_value = "1.2", long, env)]
    waiting_served_ratio: f32,

    /// Limits the number of tokens for the prefill operation.
    /// Since this operation takes the most memory and is compute bound, it is worth
    /// limiting the number of requests that can be sent at once.
    #[clap(default_value = "4096", long, env)]
    max_batch_prefill_tokens: u32,

    /// **IMPORTANT** This is one critical control to allow maximum usage
    /// of the available hardware.
    ///
    /// This represents the total amount of potential tokens within a batch.
    /// When using padding (not recommended), this would be equivalent to
    /// `batch_size` * `max_total_tokens`.
    ///
    /// However, in the non-padded (flash attention) version this can be much finer-grained.
    ///
    /// For `max_batch_total_tokens=1000`, you could fit `10` queries of `total_tokens=100`
    /// or a single query of `1000` tokens.
    ///
    /// Overall this number should be the largest possible amount that fits the
    /// remaining memory (after the model is loaded). Since the actual memory overhead
    /// depends on other parameters like quantization, flash attention
    /// or the model implementation, text-generation-inference cannot infer this number
    /// automatically.
    #[clap(default_value = "16000", long, env)]
    max_batch_total_tokens: u32,

    /// This setting defines how many tokens can be passed before forcing the waiting
    /// queries to be put on the batch (if the size of the batch allows for it).
    /// New queries require 1 `prefill` forward, which is different from `decode`,
    /// so the running batch has to be paused in order to run `prefill`
    /// and create the correct values for the waiting queries to be able to join the batch.
    ///
    /// With a value too small, queries will always "steal" the compute to run `prefill`,
    /// and running queries will be delayed by a lot.
    ///
    /// With a value too big, waiting queries could wait for a very long time
    /// before being allowed a slot in the running batch. If your server is busy,
    /// that means that requests that could run in ~2s on an empty server could
    /// end up running in ~20s because the query had to wait for 18s.
    ///
    /// This number is expressed in number of tokens to make it a bit more
    /// "model" agnostic, but what should really matter is the overall latency
    /// for end users.
    #[clap(default_value = "20", long, env)]
    max_waiting_tokens: usize,

    /// The IP address to listen on
    #[clap(default_value = "0.0.0.0", long, env)]
    hostname: String,

    /// The port to listen on.
    #[clap(default_value = "3000", long, short, env)]
    port: u16,

    /// The name of the socket for gRPC communication between the webserver
    /// and the shards.
    #[clap(default_value = "/tmp/text-generation-server", long, env)]
    shard_uds_path: String,

    /// The address the master shard will listen on. (setting used by torch distributed)
    #[clap(default_value = "localhost", long, env)]
    master_addr: String,

    /// The port the master shard will listen on. (setting used by torch distributed)
    #[clap(default_value = "29500", long, env)]
    master_port: usize,

    /// The location of the huggingface hub cache.
    /// Used to override the location if you want to provide a mounted disk for instance
    #[clap(long, env)]
    huggingface_hub_cache: Option<String>,

    /// Override the location of the model weights cache.
    /// Useful when running inside a HuggingFace Inference Endpoint for instance.
    #[clap(long, env)]
    weights_cache_override: Option<String>,

    /// For some models (like bloom), text-generation-inference implemented custom
    /// cuda kernels to speed up inference. Those kernels were only tested on A100.
    /// Use this flag to disable them if you're running on different hardware and
    /// encounter issues.
    #[clap(long, env)]
    disable_custom_kernels: bool,

    /// Outputs the logs in JSON format (useful for telemetry)
    #[clap(long, env)]
    json_output: bool,

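    /// The OTLP endpoint to export OpenTelemetry traces to
    /// (forwarded to the router and shards as `--otlp-endpoint`)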
    #[clap(long, env)]
    otlp_endpoint: Option<String>,

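    /// Origins allowed for CORS requests on the webserver
    /// (forwarded to the router as repeated `--cors-allow-origin` arguments)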
    #[clap(long, env)]
    cors_allow_origin: Vec<String>,
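    /// Watermarking gamma, forwarded to the shards as the `WATERMARK_GAMMA` env var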
    #[clap(long, env)]
    watermark_gamma: Option<f32>,
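    /// Watermarking delta, forwarded to the shards as the `WATERMARK_DELTA` env var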
    #[clap(long, env)]
    watermark_delta: Option<f32>,

    /// Enable ngrok tunneling
    #[clap(long, env)]
    ngrok: bool,

    /// ngrok authentication token
    #[clap(long, env)]
    ngrok_authtoken: Option<String>,

    /// ngrok domain name where the axum webserver will be available
    #[clap(long, env)]
    ngrok_domain: Option<String>,

    /// ngrok basic auth username
    #[clap(long, env)]
    ngrok_username: Option<String>,

    /// ngrok basic auth password
    #[clap(long, env)]
    ngrok_password: Option<String>,

    /// Display a lot of information about your runtime environment
    #[clap(long, short, action)]
    env: bool,
}

#[derive(Debug)]
enum ShardStatus {
    Ready,
    Failed(usize),
}

#[allow(clippy::too_many_arguments)]
fn shard_manager(
    model_id: String,
    revision: Option<String>,
    quantize: Option<Quantization>,
    dtype: Option<Dtype>,
    trust_remote_code: bool,
    uds_path: String,
    rank: usize,
    world_size: usize,
    master_addr: String,
    master_port: usize,
    huggingface_hub_cache: Option<String>,
    weights_cache_override: Option<String>,
    disable_custom_kernels: bool,
    watermark_gamma: Option<f32>,
    watermark_delta: Option<f32>,
    otlp_endpoint: Option<String>,
    status_sender: mpsc::Sender<ShardStatus>,
    shutdown: Arc<AtomicBool>,
    _shutdown_sender: mpsc::Sender<()>,
) {
    // Enter shard-manager tracing span
    let _span = tracing::span!(tracing::Level::INFO, "shard-manager", rank = rank).entered();

    // Get UDS path
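    // e.g. "/tmp/text-generation-server-0" for rank 0 with the default `--shard-uds-path`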
    let uds_string = format!("{uds_path}-{rank}");
    let uds = Path::new(&uds_string);
    // Clean previous runs
    if uds.exists() {
        fs::remove_file(uds).unwrap();
    }

    // Process args
    let mut shard_args = vec![
        "serve".to_string(),
        model_id,
        "--uds-path".to_string(),
        uds_path,
        "--logger-level".to_string(),
        "INFO".to_string(),
        "--json-output".to_string(),
    ];

    // Activate trust remote code
    if trust_remote_code {
        shard_args.push("--trust-remote-code".to_string());
    }

    // Activate tensor parallelism
    if world_size > 1 {
        shard_args.push("--sharded".to_string());
    }

    if let Some(quantize) = quantize {
        shard_args.push("--quantize".to_string());
        shard_args.push(quantize.to_string())
    }

    if let Some(dtype) = dtype {
        shard_args.push("--dtype".to_string());
        shard_args.push(dtype.to_string())
    }

    // Model optional revision
    if let Some(revision) = revision {
        shard_args.push("--revision".to_string());
        shard_args.push(revision)
    }

    // OpenTelemetry
    if let Some(otlp_endpoint) = otlp_endpoint {
        shard_args.push("--otlp-endpoint".to_string());
        shard_args.push(otlp_endpoint);
    }

    // Copy current process env
    let mut envs: Vec<(OsString, OsString)> = env::vars_os().collect();

    // Use cuda allocator. It leads to less memory fragmentation
    envs.push((
        "PYTORCH_CUDA_ALLOC_CONF".into(),
        "backend:cudaMallocAsync".into(),
    ));

    // Torch Distributed Env vars
    envs.push(("RANK".into(), rank.to_string().into()));
    envs.push(("WORLD_SIZE".into(), world_size.to_string().into()));
    envs.push(("MASTER_ADDR".into(), master_addr.into()));
    envs.push(("MASTER_PORT".into(), master_port.to_string().into()));
    envs.push(("NCCL_ASYNC_ERROR_HANDLING".into(), "1".into()));

    // Safetensors load fast
    envs.push(("SAFETENSORS_FAST_GPU".into(), "1".into()));

    // Enable hf transfer for insane download speeds
    let enable_hf_transfer = env::var("HF_HUB_ENABLE_HF_TRANSFER").unwrap_or("1".to_string());
    envs.push((
        "HF_HUB_ENABLE_HF_TRANSFER".into(),
        enable_hf_transfer.into(),
    ));

    // Parse Inference API token
    if let Ok(api_token) = env::var("HF_API_TOKEN") {
        envs.push(("HUGGING_FACE_HUB_TOKEN".into(), api_token.into()))
    };

    // If huggingface_hub_cache is some, pass it to the shard
    // Useful when running inside a docker container
    if let Some(huggingface_hub_cache) = huggingface_hub_cache {
        envs.push(("HUGGINGFACE_HUB_CACHE".into(), huggingface_hub_cache.into()));
    };

    // If weights_cache_override is some, pass it to the shard
    // Useful when running inside a HuggingFace Inference Endpoint
    if let Some(weights_cache_override) = weights_cache_override {
        envs.push((
            "WEIGHTS_CACHE_OVERRIDE".into(),
            weights_cache_override.into(),
        ));
    };

    // If disable_custom_kernels is true, pass it to the shard as an env var
    if disable_custom_kernels {
        envs.push(("DISABLE_CUSTOM_KERNELS".into(), "True".into()))
    }

    // Watermark Gamma
    if let Some(watermark_gamma) = watermark_gamma {
        envs.push(("WATERMARK_GAMMA".into(), watermark_gamma.to_string().into()))
    }

    // Watermark Delta
    if let Some(watermark_delta) = watermark_delta {
        envs.push(("WATERMARK_DELTA".into(), watermark_delta.to_string().into()))
    }

    // Start process
    tracing::info!("Starting shard {rank}");
    let mut p = match Command::new("text-generation-server")
        .args(shard_args)
        .envs(envs)
        .stdout(Stdio::piped())
        .stderr(Stdio::piped())
        .process_group(0)
        .spawn()
    {
        Ok(p) => p,
        Err(err) => {
            if err.kind() == io::ErrorKind::NotFound {
                tracing::error!("text-generation-server not found in PATH");
                tracing::error!("Please install it with `make install-server`")
            } else {
                tracing::error!("{}", err);
            }

            status_sender.send(ShardStatus::Failed(rank)).unwrap();
            return;
        }
    };

    // Redirect STDOUT to the console
    let shard_stdout_reader = BufReader::new(p.stdout.take().unwrap());
    let shard_stderr_reader = BufReader::new(p.stderr.take().unwrap());

    // Stdout tracing thread
    thread::spawn(move || {
        log_lines(shard_stdout_reader.lines());
    });

    let mut ready = false;
    let start_time = Instant::now();
    let mut wait_time = Instant::now();
    loop {
        // Process exited
        if let Some(exit_status) = p.try_wait().unwrap() {
            // We read stderr in another thread as it seems that lines() can block in some cases
            let (err_sender, err_receiver) = mpsc::channel();
            thread::spawn(move || {
                for line in shard_stderr_reader.lines().flatten() {
                    err_sender.send(line).unwrap_or(());
                }
            });
            let mut err = String::new();
            while let Ok(line) = err_receiver.recv_timeout(Duration::from_millis(10)) {
                err = err + "\n" + &line;
            }

            tracing::error!("Shard complete standard error output:\n{err}");

            if let Some(signal) = exit_status.signal() {
                tracing::error!("Shard process was signaled to shutdown with signal {signal}");
            }

            status_sender.send(ShardStatus::Failed(rank)).unwrap();
            return;
        }

        // We received a shutdown signal
        if shutdown.load(Ordering::SeqCst) {
            p.kill().unwrap();
            let _ = p.wait();
            tracing::info!("Shard {rank} terminated");
            return;
        }

        // Shard is ready
        if uds.exists() && !ready {
            tracing::info!("Shard {rank} ready in {:?}", start_time.elapsed());
            status_sender.send(ShardStatus::Ready).unwrap();
            ready = true;
        } else if !ready && wait_time.elapsed() > Duration::from_secs(10) {
            tracing::info!("Waiting for shard {rank} to be ready...");
            wait_time = Instant::now();
        }
        sleep(Duration::from_millis(100));
    }
}

fn shutdown_shards(shutdown: Arc<AtomicBool>, shutdown_receiver: &mpsc::Receiver<()>) {
    tracing::info!("Shutting down shards");
    // Update shutdown value to true
    // This will be picked up by the shard manager
    shutdown.store(true, Ordering::SeqCst);

    // Wait for shards to shutdown
    // This will block until all `shutdown_sender` handles are dropped
    let _ = shutdown_receiver.recv();
}

fn num_cuda_devices() -> Option<usize> {
    let devices = match env::var("CUDA_VISIBLE_DEVICES") {
        Ok(devices) => devices,
        Err(_) => env::var("NVIDIA_VISIBLE_DEVICES").ok()?,
    };
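    // e.g. CUDA_VISIBLE_DEVICES="0,1,2,3" yields Some(4)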
    let n_devices = devices.split(',').count();
    Some(n_devices)
}

#[derive(Deserialize)]
#[serde(rename_all = "UPPERCASE")]
enum PythonLogLevelEnum {
    Trace,
    Debug,
    Info,
    Success,
    Warning,
    Error,
    Critical,
}

#[derive(Deserialize)]
struct PythonLogLevel {
    name: PythonLogLevelEnum,
}

#[derive(Deserialize)]
struct PythonLogRecord {
    level: PythonLogLevel,
}

#[derive(Deserialize)]
struct PythonLogMessage {
    text: String,
    record: PythonLogRecord,
}

impl PythonLogMessage {
    fn trace(&self) {
        match self.record.level.name {
            PythonLogLevelEnum::Trace => tracing::trace!("{}", self.text),
            PythonLogLevelEnum::Debug => tracing::debug!("{}", self.text),
            PythonLogLevelEnum::Info => tracing::info!("{}", self.text),
            PythonLogLevelEnum::Success => tracing::info!("{}", self.text),
            PythonLogLevelEnum::Warning => tracing::warn!("{}", self.text),
            PythonLogLevelEnum::Error => tracing::error!("{}", self.text),
            PythonLogLevelEnum::Critical => tracing::error!("{}", self.text),
        }
    }
}

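// The Python server emits loguru-style JSON records on stdout; based on the structs
// above, a line is assumed to look roughly like:
//   {"text": "...", "record": {"level": {"name": "INFO"}}}
// Lines that fail to parse are logged verbatim at debug level by `log_lines`.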
impl TryFrom<&String> for PythonLogMessage {
    type Error = serde_json::Error;

    fn try_from(value: &String) -> Result<Self, Self::Error> {
        serde_json::from_str::<Self>(value)
    }
}

fn log_lines<S: Sized + BufRead>(lines: Lines<S>) {
    for line in lines.flatten() {
        match PythonLogMessage::try_from(&line) {
            Ok(log) => log.trace(),
            Err(_) => tracing::debug!("{line}"),
        }
    }
}

fn find_num_shards(
    sharded: Option<bool>,
    num_shard: Option<usize>,
) -> Result<usize, LauncherError> {
    // get the number of shards given `sharded` and `num_shard`
    let num_shard = match (sharded, num_shard) {
        (Some(true), None) => {
            // try to default to the number of available GPUs
            tracing::info!("Parsing num_shard from CUDA_VISIBLE_DEVICES/NVIDIA_VISIBLE_DEVICES");
            let n_devices = num_cuda_devices()
                .expect("--num-shard and CUDA_VISIBLE_DEVICES/NVIDIA_VISIBLE_DEVICES are not set");
            if n_devices <= 1 {
                return Err(LauncherError::NotEnoughCUDADevices(format!(
                    "`sharded` is true but only found {n_devices} CUDA devices"
                )));
            }
            n_devices
        }
        (Some(true), Some(num_shard)) => {
            // we can't have only one shard while sharded
            if num_shard <= 1 {
                return Err(LauncherError::ArgumentValidation(
                    "`sharded` is true but `num_shard` <= 1".to_string(),
                ));
            }
            num_shard
        }
        (Some(false), Some(num_shard)) => num_shard,
        (Some(false), None) => 1,
        (None, None) => num_cuda_devices().unwrap_or(1),
        (None, Some(num_shard)) => num_shard,
    };
    if num_shard < 1 {
        return Err(LauncherError::ArgumentValidation(
            "`num_shard` cannot be < 1".to_string(),
        ));
    }
    Ok(num_shard)
}

#[derive(Debug)]
enum LauncherError {
    ArgumentValidation(String),
    NotEnoughCUDADevices(String),
    DownloadError,
    ShardCannotStart,
    ShardDisconnected,
    ShardFailed,
    WebserverFailed,
    WebserverCannotStart,
}

fn download_convert_model(args: &Args, running: Arc<AtomicBool>) -> Result<(), LauncherError> {
    // Enter download tracing span
    let _span = tracing::span!(tracing::Level::INFO, "download").entered();

    let mut download_args = vec![
        "download-weights".to_string(),
        args.model_id.to_string(),
        "--extension".to_string(),
        ".safetensors".to_string(),
        "--logger-level".to_string(),
        "INFO".to_string(),
        "--json-output".to_string(),
    ];

    // Model optional revision
    if let Some(revision) = &args.revision {
        download_args.push("--revision".to_string());
        download_args.push(revision.to_string())
    }

    // Copy current process env
    let mut envs: Vec<(OsString, OsString)> = env::vars_os().collect();

    // If huggingface_hub_cache is set, pass it to the download process
    // Useful when running inside a docker container
    if let Some(ref huggingface_hub_cache) = args.huggingface_hub_cache {
        envs.push(("HUGGINGFACE_HUB_CACHE".into(), huggingface_hub_cache.into()));
    };

    // Enable hf transfer for insane download speeds
    let enable_hf_transfer = env::var("HF_HUB_ENABLE_HF_TRANSFER").unwrap_or("1".to_string());
    envs.push((
        "HF_HUB_ENABLE_HF_TRANSFER".into(),
        enable_hf_transfer.into(),
    ));

    // Parse Inference API token
    if let Ok(api_token) = env::var("HF_API_TOKEN") {
        envs.push(("HUGGING_FACE_HUB_TOKEN".into(), api_token.into()))
    };

    // If args.weights_cache_override is some, pass it to the download process
    // Useful when running inside a HuggingFace Inference Endpoint
    if let Some(weights_cache_override) = &args.weights_cache_override {
        envs.push((
            "WEIGHTS_CACHE_OVERRIDE".into(),
            weights_cache_override.into(),
        ));
    };

    // Start process
    tracing::info!("Starting download process.");
    let mut download_process = match Command::new("text-generation-server")
        .args(download_args)
        .envs(envs)
        .stdout(Stdio::piped())
        .stderr(Stdio::piped())
        .process_group(0)
        .spawn()
    {
        Ok(p) => p,
        Err(err) => {
            if err.kind() == io::ErrorKind::NotFound {
                tracing::error!("text-generation-server not found in PATH");
                tracing::error!("Please install it with `make install-server`")
            } else {
                tracing::error!("{}", err);
            }

            return Err(LauncherError::DownloadError);
        }
    };

    // Redirect STDOUT to the console
    let download_stdout = download_process.stdout.take().unwrap();
    let stdout = BufReader::new(download_stdout);

    thread::spawn(move || {
        log_lines(stdout.lines());
    });

    loop {
        if let Some(status) = download_process.try_wait().unwrap() {
            if status.success() {
                tracing::info!("Successfully downloaded weights.");
                break;
            }

            let mut err = String::new();
            download_process
                .stderr
                .take()
                .unwrap()
                .read_to_string(&mut err)
                .unwrap();
            if let Some(signal) = status.signal() {
                tracing::error!(
                    "Download process was signaled to shutdown with signal {signal}: {err}"
                );
            } else {
                tracing::error!("Download encountered an error: {err}");
            }

            return Err(LauncherError::DownloadError);
        }
        if !running.load(Ordering::SeqCst) {
            terminate("download", download_process, Duration::from_secs(10)).unwrap();
            return Ok(());
        }
        sleep(Duration::from_millis(100));
    }
    Ok(())
}

#[allow(clippy::too_many_arguments)]
fn spawn_shards(
    num_shard: usize,
    args: &Args,
    shutdown: Arc<AtomicBool>,
    shutdown_receiver: &mpsc::Receiver<()>,
    shutdown_sender: mpsc::Sender<()>,
    status_receiver: &mpsc::Receiver<ShardStatus>,
    status_sender: mpsc::Sender<ShardStatus>,
    running: Arc<AtomicBool>,
) -> Result<(), LauncherError> {
    // Start shard processes
    for rank in 0..num_shard {
        let model_id = args.model_id.clone();
        let revision = args.revision.clone();
        let uds_path = args.shard_uds_path.clone();
        let master_addr = args.master_addr.clone();
        let huggingface_hub_cache = args.huggingface_hub_cache.clone();
        let weights_cache_override = args.weights_cache_override.clone();
        let status_sender = status_sender.clone();
        let shutdown = shutdown.clone();
        let shutdown_sender = shutdown_sender.clone();
        let otlp_endpoint = args.otlp_endpoint.clone();
        let quantize = args.quantize;
        let dtype = args.dtype;
        let trust_remote_code = args.trust_remote_code;
        let master_port = args.master_port;
        let disable_custom_kernels = args.disable_custom_kernels;
        let watermark_gamma = args.watermark_gamma;
        let watermark_delta = args.watermark_delta;
        thread::spawn(move || {
            shard_manager(
                model_id,
                revision,
                quantize,
                dtype,
                trust_remote_code,
                uds_path,
                rank,
                num_shard,
                master_addr,
                master_port,
                huggingface_hub_cache,
                weights_cache_override,
                disable_custom_kernels,
                watermark_gamma,
                watermark_delta,
                otlp_endpoint,
                status_sender,
                shutdown,
                shutdown_sender,
            )
        });
    }
    drop(shutdown_sender);

    // Wait for the shards to start
    let mut shard_ready = 0;
    while running.load(Ordering::SeqCst) {
        match status_receiver.try_recv() {
            Ok(ShardStatus::Ready) => {
                shard_ready += 1;
                if shard_ready == num_shard {
                    break;
                }
            }
            Err(TryRecvError::Empty) => {
                sleep(Duration::from_millis(100));
            }
            Ok(ShardStatus::Failed(rank)) => {
                tracing::error!("Shard {rank} failed to start");
                shutdown_shards(shutdown, shutdown_receiver);
                return Err(LauncherError::ShardCannotStart);
            }
            Err(TryRecvError::Disconnected) => {
                tracing::error!("Shard status channel disconnected");
                shutdown_shards(shutdown, shutdown_receiver);
                return Err(LauncherError::ShardDisconnected);
            }
        }
    }
    Ok(())
}

fn spawn_webserver(
    args: Args,
    shutdown: Arc<AtomicBool>,
    shutdown_receiver: &mpsc::Receiver<()>,
) -> Result<Child, LauncherError> {
    // All shards started
    // Start webserver
    tracing::info!("Starting Webserver");
    let mut router_args = vec![
        "--max-concurrent-requests".to_string(),
        args.max_concurrent_requests.to_string(),
        "--max-best-of".to_string(),
        args.max_best_of.to_string(),
        "--max-stop-sequences".to_string(),
        args.max_stop_sequences.to_string(),
        "--max-input-length".to_string(),
        args.max_input_length.to_string(),
        "--max-total-tokens".to_string(),
        args.max_total_tokens.to_string(),
        "--max-batch-prefill-tokens".to_string(),
        args.max_batch_prefill_tokens.to_string(),
        "--max-batch-total-tokens".to_string(),
        args.max_batch_total_tokens.to_string(),
        "--waiting-served-ratio".to_string(),
        args.waiting_served_ratio.to_string(),
        "--max-waiting-tokens".to_string(),
        args.max_waiting_tokens.to_string(),
        "--validation-workers".to_string(),
        args.validation_workers.to_string(),
        "--hostname".to_string(),
        args.hostname.to_string(),
        "--port".to_string(),
        args.port.to_string(),
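        // The router connects to the rank-0 ("master") shard through its UDS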
        "--master-shard-uds-path".to_string(),
        format!("{}-0", args.shard_uds_path),
        "--tokenizer-name".to_string(),
        args.model_id,
    ];

    // Model optional revision
    if let Some(ref revision) = args.revision {
        router_args.push("--revision".to_string());
        router_args.push(revision.to_string())
    }

    if args.json_output {
        router_args.push("--json-output".to_string());
    }

    // OpenTelemetry
    if let Some(otlp_endpoint) = args.otlp_endpoint {
        router_args.push("--otlp-endpoint".to_string());
        router_args.push(otlp_endpoint);
    }

    // CORS origins
    for origin in args.cors_allow_origin.into_iter() {
        router_args.push("--cors-allow-origin".to_string());
        router_args.push(origin);
    }

    // Ngrok
    if args.ngrok {
        let authtoken = args.ngrok_authtoken.ok_or_else(|| {
            tracing::error!("`ngrok-authtoken` must be set when using ngrok tunneling");
            LauncherError::WebserverCannotStart
        })?;

        router_args.push("--ngrok".to_string());
        router_args.push("--ngrok-authtoken".to_string());
        router_args.push(authtoken);

        if let Some(domain) = args.ngrok_domain {
            router_args.push("--ngrok-domain".to_string());
            router_args.push(domain);
        }

        if let (Some(username), Some(password)) = (args.ngrok_username, args.ngrok_password) {
            router_args.push("--ngrok-username".to_string());
            router_args.push(username);
            router_args.push("--ngrok-password".to_string());
            router_args.push(password);
        }
    }

    // Copy current process env
    let mut envs: Vec<(OsString, OsString)> = env::vars_os().collect();

    // Parse Inference API token
    if let Ok(api_token) = env::var("HF_API_TOKEN") {
        envs.push(("HUGGING_FACE_HUB_TOKEN".into(), api_token.into()))
    };

    let mut webserver = match Command::new("text-generation-router")
        .args(router_args)
        .envs(envs)
        .stdout(Stdio::piped())
        .stderr(Stdio::piped())
        .process_group(0)
        .spawn()
    {
        Ok(p) => p,
        Err(err) => {
            tracing::error!("Failed to start webserver: {}", err);
            if err.kind() == io::ErrorKind::NotFound {
                tracing::error!("text-generation-router not found in PATH");
                tracing::error!("Please install it with `make install-router`")
            } else {
                tracing::error!("{}", err);
            }

            shutdown_shards(shutdown, shutdown_receiver);
            return Err(LauncherError::WebserverCannotStart);
        }
    };

    // Redirect STDOUT and STDERR to the console
    let webserver_stdout = webserver.stdout.take().unwrap();
    let webserver_stderr = webserver.stderr.take().unwrap();

    thread::spawn(move || {
        let stdout = BufReader::new(webserver_stdout);
        let stderr = BufReader::new(webserver_stderr);
        for line in stdout.lines() {
            println!("{}", line.unwrap());
        }
        for line in stderr.lines() {
            println!("{}", line.unwrap());
        }
    });
    Ok(webserver)
}

fn terminate(process_name: &str, mut process: Child, timeout: Duration) -> io::Result<ExitStatus> {
    tracing::info!("Terminating {process_name}");

    let terminate_time = Instant::now();
    signal::kill(Pid::from_raw(process.id() as i32), Signal::SIGTERM).unwrap();

    tracing::info!("Waiting for {process_name} to gracefully shutdown");

    while terminate_time.elapsed() < timeout {
        if let Some(status) = process.try_wait()? {
            tracing::info!("{process_name} terminated");
            return Ok(status);
        }
        sleep(Duration::from_millis(100));
    }

    tracing::info!("Killing {process_name}");

    process.kill()?;
    let exit_status = process.wait()?;

    tracing::info!("{process_name} killed");
    Ok(exit_status)
}

fn main() -> Result<(), LauncherError> {
    // Pattern match configuration
    let args = Args::parse();

    // Filter events with LOG_LEVEL
    let env_filter =
        EnvFilter::try_from_env("LOG_LEVEL").unwrap_or_else(|_| EnvFilter::new("info"));

    if args.json_output {
        tracing_subscriber::fmt()
            .with_env_filter(env_filter)
            .json()
            .init();
    } else {
        tracing_subscriber::fmt()
            .with_env_filter(env_filter)
            .compact()
            .init();
    }

    if args.env {
        let env_runtime = env_runtime::Env::new();
        tracing::info!("{}", env_runtime);
    }

    tracing::info!("{:?}", args);

    // Validate args
    if args.max_input_length >= args.max_total_tokens {
        return Err(LauncherError::ArgumentValidation(
            "`max_input_length` must be < `max_total_tokens`".to_string(),
        ));
    }
    if args.max_input_length as u32 > args.max_batch_prefill_tokens {
        return Err(LauncherError::ArgumentValidation(format!(
            "`max_batch_prefill_tokens` must be >= `max_input_length`. Given: {} and {}",
            args.max_batch_prefill_tokens, args.max_input_length
        )));
    }
    if args.max_batch_prefill_tokens > args.max_batch_total_tokens {
        return Err(LauncherError::ArgumentValidation(format!(
            "`max_batch_prefill_tokens` must be <= `max_batch_total_tokens`. Given: {} and {}",
            args.max_batch_prefill_tokens, args.max_batch_total_tokens
        )));
    }
    if args.max_total_tokens as u32 > args.max_batch_total_tokens {
        return Err(LauncherError::ArgumentValidation(format!(
            "`max_total_tokens` must be <= `max_batch_total_tokens`. Given: {} and {}",
            args.max_total_tokens, args.max_batch_total_tokens
        )));
    }
    if args.validation_workers == 0 {
        return Err(LauncherError::ArgumentValidation(
            "`validation_workers` must be > 0".to_string(),
        ));
    }
    if args.trust_remote_code {
        tracing::warn!(
            "`trust_remote_code` is set. Trusting that model `{}` does not contain malicious code.",
            args.model_id
        );
    }

    let num_shard = find_num_shards(args.sharded, args.num_shard)?;
    if num_shard > 1 {
        tracing::info!("Sharding model on {num_shard} processes");
    }

    // Signal handler
    let running = Arc::new(AtomicBool::new(true));
    let r = running.clone();
    ctrlc::set_handler(move || {
        r.store(false, Ordering::SeqCst);
    })
    .expect("Error setting Ctrl-C handler");

    // Download and convert model weights
    download_convert_model(&args, running.clone())?;

    if !running.load(Ordering::SeqCst) {
        // Launcher was asked to stop
        return Ok(());
    }

    // Shared shutdown bool
    let shutdown = Arc::new(AtomicBool::new(false));
    // Shared shutdown channel
    // When shutting down, the main thread will wait for all senders to be dropped
    let (shutdown_sender, shutdown_receiver) = mpsc::channel();

    // Shared channel to track shard status
    let (status_sender, status_receiver) = mpsc::channel();

    spawn_shards(
        num_shard,
        &args,
        shutdown.clone(),
        &shutdown_receiver,
        shutdown_sender,
        &status_receiver,
        status_sender,
        running.clone(),
    )?;

    // We might have received a termination signal
    if !running.load(Ordering::SeqCst) {
        shutdown_shards(shutdown, &shutdown_receiver);
        return Ok(());
    }

    let mut webserver =
        spawn_webserver(args, shutdown.clone(), &shutdown_receiver).map_err(|err| {
            shutdown_shards(shutdown.clone(), &shutdown_receiver);
            err
        })?;

    // Default exit code
    let mut exit_code = Ok(());

    while running.load(Ordering::SeqCst) {
        if let Ok(ShardStatus::Failed(rank)) = status_receiver.try_recv() {
            tracing::error!("Shard {rank} crashed");
            exit_code = Err(LauncherError::ShardFailed);
            break;
        };

        match webserver.try_wait().unwrap() {
            Some(_) => {
                tracing::error!("Webserver Crashed");
                shutdown_shards(shutdown, &shutdown_receiver);
                return Err(LauncherError::WebserverFailed);
            }
            None => {
                sleep(Duration::from_millis(100));
            }
        };
    }

    // Graceful termination
    terminate("webserver", webserver, Duration::from_secs(90)).unwrap();
    shutdown_shards(shutdown, &shutdown_receiver);

    exit_code
}