Unverified Commit 77758f60 authored by Nicolas Patry's avatar Nicolas Patry Committed by GitHub
Browse files

chore(launcher): refactor logic (#242)

Hopefully it's cleaner
parent 7de8a377
...@@ -4,7 +4,6 @@ use std::env; ...@@ -4,7 +4,6 @@ use std::env;
use std::ffi::OsString; use std::ffi::OsString;
use std::io::{BufRead, BufReader, Read}; use std::io::{BufRead, BufReader, Read};
use std::path::Path; use std::path::Path;
use std::process::ExitCode;
use std::sync::atomic::{AtomicBool, Ordering}; use std::sync::atomic::{AtomicBool, Ordering};
use std::sync::mpsc::TryRecvError; use std::sync::mpsc::TryRecvError;
use std::sync::Arc; use std::sync::Arc;
...@@ -73,248 +72,454 @@ struct Args { ...@@ -73,248 +72,454 @@ struct Args {
watermark_delta: Option<f32>, watermark_delta: Option<f32>,
} }
fn main() -> ExitCode { #[derive(Debug)]
// Pattern match configuration enum ShardStatus {
let args = Args::parse(); Ready,
Failed((usize, String)),
}
if args.json_output { #[allow(clippy::too_many_arguments)]
tracing_subscriber::fmt().json().init(); fn shard_manager(
} else { model_id: String,
tracing_subscriber::fmt().compact().init(); revision: Option<String>,
quantize: bool,
uds_path: String,
rank: usize,
world_size: usize,
master_addr: String,
master_port: usize,
huggingface_hub_cache: Option<String>,
weights_cache_override: Option<String>,
disable_custom_kernels: bool,
watermark_gamma: Option<f32>,
watermark_delta: Option<f32>,
otlp_endpoint: Option<String>,
status_sender: mpsc::Sender<ShardStatus>,
shutdown: Arc<Mutex<bool>>,
_shutdown_sender: mpsc::Sender<()>,
) {
// Get UDS path
let uds_string = format!("{uds_path}-{rank}");
let uds = Path::new(&uds_string);
// Clean previous runs
fs::remove_file(uds).unwrap_or_default();
// Process args
let mut shard_argv = vec![
"text-generation-server".to_string(),
"serve".to_string(),
model_id,
"--uds-path".to_string(),
uds_path,
"--logger-level".to_string(),
"INFO".to_string(),
"--json-output".to_string(),
];
// Activate tensor parallelism
if world_size > 1 {
shard_argv.push("--sharded".to_string());
} }
tracing::info!("{:?}", args); if quantize {
shard_argv.push("--quantize".to_string())
}
let Args { // Model optional revision
model_id, if let Some(revision) = revision {
revision, shard_argv.push("--revision".to_string());
sharded, shard_argv.push(revision)
num_shard, }
quantize,
max_concurrent_requests,
max_best_of,
max_stop_sequences,
max_input_length,
max_total_tokens,
max_batch_size,
max_batch_total_tokens,
waiting_served_ratio,
max_waiting_tokens,
port,
shard_uds_path,
master_addr,
master_port,
huggingface_hub_cache,
weights_cache_override,
disable_custom_kernels,
json_output,
otlp_endpoint,
cors_allow_origin,
watermark_gamma,
watermark_delta,
} = args;
// get the number of shards given `sharded` and `num_shard` // OpenTelemetry
let num_shard = if let Some(sharded) = sharded { if let Some(otlp_endpoint) = otlp_endpoint {
// sharded is set shard_argv.push("--otlp-endpoint".to_string());
match sharded { shard_argv.push(otlp_endpoint);
// sharded is set and true }
true => {
match num_shard { // Copy current process env
None => { let mut env: Vec<(OsString, OsString)> = env::vars_os().collect();
// try to default to the number of available GPUs
tracing::info!("Parsing num_shard from CUDA_VISIBLE_DEVICES"); // Torch Distributed Env vars
let n_devices = num_cuda_devices() env.push(("RANK".into(), rank.to_string().into()));
.expect("--num-shard and CUDA_VISIBLE_DEVICES are not set"); env.push(("WORLD_SIZE".into(), world_size.to_string().into()));
if n_devices <= 1 { env.push(("MASTER_ADDR".into(), master_addr.into()));
panic!("`sharded` is true but only found {n_devices} CUDA devices"); env.push(("MASTER_PORT".into(), master_port.to_string().into()));
} env.push(("NCCL_ASYNC_ERROR_HANDLING".into(), "1".into()));
n_devices
} // Safetensors load fast
Some(num_shard) => { env.push(("SAFETENSORS_FAST_GPU".into(), "1".into()));
// we can't have only one shard while sharded
if num_shard <= 1 { // Enable hf transfer for insane download speeds
panic!("`sharded` is true but `num_shard` <= 1"); let enable_hf_transfer = env::var("HF_HUB_ENABLE_HF_TRANSFER").unwrap_or("1".to_string());
} env.push((
num_shard "HF_HUB_ENABLE_HF_TRANSFER".into(),
} enable_hf_transfer.into(),
));
// Parse Inference API token
if let Ok(api_token) = env::var("HF_API_TOKEN") {
env.push(("HUGGING_FACE_HUB_TOKEN".into(), api_token.into()))
};
// If huggingface_hub_cache is some, pass it to the shard
// Useful when running inside a docker container
if let Some(huggingface_hub_cache) = huggingface_hub_cache {
env.push(("HUGGINGFACE_HUB_CACHE".into(), huggingface_hub_cache.into()));
};
// If weights_cache_override is some, pass it to the shard
// Useful when running inside a HuggingFace Inference Endpoint
if let Some(weights_cache_override) = weights_cache_override {
env.push((
"WEIGHTS_CACHE_OVERRIDE".into(),
weights_cache_override.into(),
));
};
// If disable_custom_kernels is true, pass it to the shard as an env var
if disable_custom_kernels {
env.push(("DISABLE_CUSTOM_KERNELS".into(), "True".into()))
}
// Watermark Gamma
if let Some(watermark_gamma) = watermark_gamma {
env.push(("WATERMARK_GAMMA".into(), watermark_gamma.to_string().into()))
}
// Watermark Delta
if let Some(watermark_delta) = watermark_delta {
env.push(("WATERMARK_DELTA".into(), watermark_delta.to_string().into()))
}
// Start process
tracing::info!("Starting shard {rank}");
let mut p = match Popen::create(
&shard_argv,
PopenConfig {
stdout: Redirection::Pipe,
stderr: Redirection::Pipe,
// Needed for the shutdown procedure
setpgid: true,
// NCCL env vars
env: Some(env),
..Default::default()
},
) {
Ok(p) => p,
Err(err) => {
if let PopenError::IoError(ref err) = err {
if err.kind() == io::ErrorKind::NotFound {
tracing::error!("text-generation-server not found in PATH");
tracing::error!("Please install it with `make install-server`")
} }
} }
// sharded is set and false status_sender
false => { .send(ShardStatus::Failed((rank, err.to_string())))
let num_shard = num_shard.unwrap_or(1); .unwrap();
// we can't have more than one shard while not sharded return;
if num_shard != 1 { }
panic!("`sharded` is false but `num_shard` != 1"); };
}
num_shard // Redirect STDOUT to the console
let shard_stdout = p.stdout.take().unwrap();
thread::spawn(move || {
// Enter shard-manager tracing span
let stdout = BufReader::new(shard_stdout);
let _span = tracing::span!(tracing::Level::INFO, "shard-manager", rank = rank).entered();
for line in stdout.lines() {
// Parse loguru logs
if let Ok(log) = serde_json::from_str::<PythonLogMessage>(&line.unwrap()) {
log.trace();
}
}
});
let mut ready = false;
let start_time = Instant::now();
let mut wait_time = Instant::now();
loop {
// Process exited
if p.poll().is_some() {
let mut err = String::new();
p.stderr.take().unwrap().read_to_string(&mut err).unwrap();
status_sender
.send(ShardStatus::Failed((rank, err)))
.unwrap();
return;
}
// We received a shutdown signal
if *shutdown.lock().unwrap() {
p.terminate().unwrap();
let _ = p.wait_timeout(Duration::from_secs(90));
tracing::info!("Shard {rank} terminated");
return;
}
// Shard is ready
if uds.exists() && !ready {
tracing::info!("Shard {rank} ready in {:?}", start_time.elapsed());
status_sender.send(ShardStatus::Ready).unwrap();
ready = true;
} else if !ready && wait_time.elapsed() > Duration::from_secs(10) {
tracing::info!("Waiting for shard {rank} to be ready...");
wait_time = Instant::now();
}
sleep(Duration::from_millis(100));
}
}
/// Signal every shard manager thread to stop and wait until they all have.
///
/// Flips the shared `shutdown` flag (polled by each shard manager loop),
/// then blocks on `shutdown_receiver` until every shard thread has dropped
/// its cloned sender — `recv()` only errors out once all senders are gone.
fn shutdown_shards(shutdown: Arc<Mutex<bool>>, shutdown_receiver: &mpsc::Receiver<()>) {
    tracing::info!("Shutting down shards");

    // Set the flag; the temporary lock guard is released at the end of
    // this statement, so shard managers are never blocked on the mutex.
    *shutdown.lock().unwrap() = true;

    // Block until all `shutdown_sender` clones are dropped by the shard
    // manager threads — the resulting Err is the expected signal.
    let _ = shutdown_receiver.recv();
}
/// Number of visible CUDA devices, derived from `CUDA_VISIBLE_DEVICES`.
///
/// Returns `Some(count)` where `count` is the number of comma-separated
/// entries in the variable, or `None` when the variable is unset/unreadable.
/// Note: an empty string still yields `Some(1)` since splitting "" produces
/// one (empty) element — callers treat the variable's presence as meaningful.
fn num_cuda_devices() -> Option<usize> {
    env::var("CUDA_VISIBLE_DEVICES")
        .ok()
        .map(|devices| devices.split(',').count())
}
/// Log levels emitted by the Python server's loguru logger.
///
/// Deserialized from the `record.level.name` field of the JSON log lines;
/// loguru spells level names in uppercase, hence the rename rule.
#[derive(Deserialize)]
#[serde(rename_all = "UPPERCASE")]
enum PythonLogLevelEnum {
    Trace,
    Debug,
    Info,
    // loguru-specific level with no direct tracing equivalent;
    // mapped to `info` when forwarded (see `PythonLogMessage::trace`).
    Success,
    Warning,
    Error,
    // loguru-specific level; forwarded as `error`.
    Critical,
}
/// The `level` object inside a loguru JSON record; only the level
/// name is needed, other fields are ignored during deserialization.
#[derive(Deserialize)]
struct PythonLogLevel {
    name: PythonLogLevelEnum,
}
/// The `record` object of a loguru JSON log line; only the severity
/// level is extracted, remaining fields are ignored.
#[derive(Deserialize)]
struct PythonLogRecord {
    level: PythonLogLevel,
}
/// One JSON-formatted log line produced by the Python server (loguru
/// with `--json-output`): the rendered message text plus its record
/// metadata, of which only the level is used.
#[derive(Deserialize)]
struct PythonLogMessage {
    text: String,
    record: PythonLogRecord,
}
impl PythonLogMessage {
fn trace(&self) {
match self.record.level.name {
PythonLogLevelEnum::Trace => tracing::trace!("{}", self.text),
PythonLogLevelEnum::Debug => tracing::debug!("{}", self.text),
PythonLogLevelEnum::Info => tracing::info!("{}", self.text),
PythonLogLevelEnum::Success => tracing::info!("{}", self.text),
PythonLogLevelEnum::Warning => tracing::warn!("{}", self.text),
PythonLogLevelEnum::Error => tracing::error!("{}", self.text),
PythonLogLevelEnum::Critical => tracing::error!("{}", self.text),
}
}
}
fn find_num_shards(sharded: Option<bool>, num_shard: Option<usize>) -> usize {
// get the number of shards given `sharded` and `num_shard`
let num_shard = match (sharded, num_shard) {
(Some(true), None) => {
// try to default to the number of available GPUs
tracing::info!("Parsing num_shard from CUDA_VISIBLE_DEVICES");
let n_devices =
num_cuda_devices().expect("--num-shard and CUDA_VISIBLE_DEVICES are not set");
if n_devices <= 1 {
panic!("`sharded` is true but only found {n_devices} CUDA devices");
} }
n_devices
} }
} else { (Some(true), Some(num_shard)) => {
match num_shard { // we can't have only one shard while sharded
// get num_shard from CUDA_VISIBLE_DEVICES or default to a single shard if num_shard <= 1 {
None => num_cuda_devices().unwrap_or(1), panic!("`sharded` is true but `num_shard` <= 1");
Some(num_shard) => num_shard, }
num_shard
} }
(Some(false), Some(num_shard)) => num_shard,
(Some(false), None) => 1,
(None, None) => num_cuda_devices().unwrap_or(1),
(None, Some(num_shard)) => num_shard,
}; };
if num_shard < 1 { if num_shard < 1 {
panic!("`num_shard` cannot be < 1"); panic!("`num_shard` cannot be < 1");
} }
num_shard
}
if num_shard > 1 { #[derive(Debug)]
tracing::info!("Sharding model on {num_shard} processes"); enum LauncherError {
} DownloadError,
ShardCannotStart,
// Signal handler ShardDisconnected,
let running = Arc::new(AtomicBool::new(true)); ShardFailed,
let r = running.clone(); WebserverFailed,
ctrlc::set_handler(move || { WebserverCannotStart,
r.store(false, Ordering::SeqCst); }
})
.expect("Error setting Ctrl-C handler");
// Check if model_id is a local model fn download_model(args: &Args, running: Arc<AtomicBool>) -> Result<(), LauncherError> {
let local_path = Path::new(&model_id); let mut download_argv = vec![
let is_local_model = local_path.exists() && local_path.is_dir(); "text-generation-server".to_string(),
"download-weights".to_string(),
args.model_id.to_string(),
"--extension".to_string(),
".safetensors".to_string(),
"--logger-level".to_string(),
"INFO".to_string(),
"--json-output".to_string(),
];
// Download weights for sharded models // Model optional revision
if !is_local_model && weights_cache_override.is_none() && num_shard > 1 { if let Some(revision) = &args.revision {
let mut download_argv = vec![ download_argv.push("--revision".to_string());
"text-generation-server".to_string(), download_argv.push(revision.to_string())
"download-weights".to_string(), }
model_id.clone(),
"--extension".to_string(),
".safetensors".to_string(),
"--logger-level".to_string(),
"INFO".to_string(),
"--json-output".to_string(),
];
// Model optional revision
if let Some(ref revision) = revision {
download_argv.push("--revision".to_string());
download_argv.push(revision.to_string())
}
// Copy current process env // Copy current process env
let mut env: Vec<(OsString, OsString)> = env::vars_os().collect(); let mut env: Vec<(OsString, OsString)> = env::vars_os().collect();
// If huggingface_hub_cache is set, pass it to the shard // If huggingface_hub_cache is set, pass it to the shard
// Useful when running inside a docker container // Useful when running inside a docker container
if let Some(ref huggingface_hub_cache) = huggingface_hub_cache { if let Some(ref huggingface_hub_cache) = args.huggingface_hub_cache {
env.push(("HUGGINGFACE_HUB_CACHE".into(), huggingface_hub_cache.into())); env.push(("HUGGINGFACE_HUB_CACHE".into(), huggingface_hub_cache.into()));
}; };
// Enable hf transfer for insane download speeds // Enable hf transfer for insane download speeds
let enable_hf_transfer = env::var("HF_HUB_ENABLE_HF_TRANSFER").unwrap_or("1".to_string()); let enable_hf_transfer = env::var("HF_HUB_ENABLE_HF_TRANSFER").unwrap_or("1".to_string());
env.push(( env.push((
"HF_HUB_ENABLE_HF_TRANSFER".into(), "HF_HUB_ENABLE_HF_TRANSFER".into(),
enable_hf_transfer.into(), enable_hf_transfer.into(),
)); ));
// Parse Inference API token // Parse Inference API token
if let Ok(api_token) = env::var("HF_API_TOKEN") { if let Ok(api_token) = env::var("HF_API_TOKEN") {
env.push(("HUGGING_FACE_HUB_TOKEN".into(), api_token.into())) env.push(("HUGGING_FACE_HUB_TOKEN".into(), api_token.into()))
}; };
// Start process // Start process
tracing::info!("Starting download process."); tracing::info!("Starting download process.");
let mut download_process = match Popen::create( let mut download_process = match Popen::create(
&download_argv, &download_argv,
PopenConfig { PopenConfig {
stdout: Redirection::Pipe, stdout: Redirection::Pipe,
stderr: Redirection::Pipe, stderr: Redirection::Pipe,
// Needed for the shutdown procedure // Needed for the shutdown procedure
setpgid: true, setpgid: true,
env: Some(env), env: Some(env),
..Default::default() ..Default::default()
}, },
) { ) {
Ok(p) => p, Ok(p) => p,
Err(err) => { Err(err) => {
if let PopenError::IoError(ref err) = err { if let PopenError::IoError(ref err) = err {
if err.kind() == io::ErrorKind::NotFound { if err.kind() == io::ErrorKind::NotFound {
tracing::error!("text-generation-server not found in PATH"); tracing::error!("text-generation-server not found in PATH");
tracing::error!("Please install it with `make install-server`") tracing::error!("Please install it with `make install-server`")
}
} }
return ExitCode::FAILURE;
} }
}; return Err(LauncherError::DownloadError);
}
};
// Redirect STDOUT to the console // Redirect STDOUT to the console
let download_stdout = download_process.stdout.take().unwrap(); let download_stdout = download_process.stdout.take().unwrap();
thread::spawn(move || { thread::spawn(move || {
// Enter download tracing span // Enter download tracing span
let stdout = BufReader::new(download_stdout); let stdout = BufReader::new(download_stdout);
let _span = tracing::span!(tracing::Level::INFO, "download").entered(); let _span = tracing::span!(tracing::Level::INFO, "download").entered();
for line in stdout.lines() { for line in stdout.lines() {
// Parse loguru logs // Parse loguru logs
if let Ok(log) = serde_json::from_str::<PythonLogMessage>(&line.unwrap()) { if let Ok(log) = serde_json::from_str::<PythonLogMessage>(&line.unwrap()) {
log.trace(); log.trace();
}
} }
}); }
});
loop { loop {
if let Some(status) = download_process.poll() { if let Some(status) = download_process.poll() {
match status { match status {
ExitStatus::Exited(exit_code) => { ExitStatus::Exited(exit_code) => {
if exit_code == 0 { if exit_code == 0 {
tracing::info!("Successfully downloaded weights."); tracing::info!("Successfully downloaded weights.");
break; break;
} else { } else {
let mut err = String::new(); let mut err = String::new();
download_process download_process
.stderr .stderr
.take() .take()
.unwrap() .unwrap()
.read_to_string(&mut err) .read_to_string(&mut err)
.unwrap(); .unwrap();
tracing::error!("Download encountered an error: {err}"); tracing::error!("Download encountered an error: {err}");
return ExitCode::FAILURE; return Err(LauncherError::DownloadError);
}
}
_ => {
tracing::error!("Download process exited with an unknown status.");
return ExitCode::FAILURE;
} }
} }
_ => {
tracing::error!("Download process exited with an unknown status.");
return Err(LauncherError::DownloadError);
}
} }
if !running.load(Ordering::SeqCst) {
download_process.terminate().unwrap();
tracing::info!("Waiting for download process to gracefully shutdown");
download_process
.wait_timeout(Duration::from_secs(90))
.unwrap();
tracing::info!("Download process terminated");
return ExitCode::SUCCESS;
}
sleep(Duration::from_millis(100));
} }
if !running.load(Ordering::SeqCst) {
download_process.terminate().unwrap();
tracing::info!("Waiting for download process to gracefully shutdown");
download_process
.wait_timeout(Duration::from_secs(90))
.unwrap();
tracing::info!("Download process terminated");
return Ok(());
}
sleep(Duration::from_millis(100));
} }
Ok(())
}
// Shared shutdown bool fn spawn_shards(
let shutdown = Arc::new(Mutex::new(false)); num_shard: usize,
// Shared shutdown channel args: &Args,
// When shutting down, the main thread will wait for all senders to be dropped shutdown: Arc<Mutex<bool>>,
let (shutdown_sender, shutdown_receiver) = mpsc::channel(); shutdown_receiver: &mpsc::Receiver<()>,
shutdown_sender: mpsc::Sender<()>,
// Shared channel to track shard status status_receiver: &mpsc::Receiver<ShardStatus>,
let (status_sender, status_receiver) = mpsc::channel(); status_sender: mpsc::Sender<ShardStatus>,
running: Arc<AtomicBool>,
) -> Result<(), LauncherError> {
// Start shard processes // Start shard processes
for rank in 0..num_shard { for rank in 0..num_shard {
let model_id = model_id.clone(); let model_id = args.model_id.clone();
let revision = revision.clone(); let revision = args.revision.clone();
let uds_path = shard_uds_path.clone(); let uds_path = args.shard_uds_path.clone();
let master_addr = master_addr.clone(); let master_addr = args.master_addr.clone();
let huggingface_hub_cache = huggingface_hub_cache.clone(); let huggingface_hub_cache = args.huggingface_hub_cache.clone();
let weights_cache_override = weights_cache_override.clone(); let weights_cache_override = args.weights_cache_override.clone();
let status_sender = status_sender.clone(); let status_sender = status_sender.clone();
let shutdown = shutdown.clone(); let shutdown = shutdown.clone();
let shutdown_sender = shutdown_sender.clone(); let shutdown_sender = shutdown_sender.clone();
let otlp_endpoint = otlp_endpoint.clone(); let otlp_endpoint = args.otlp_endpoint.clone();
let quantize = args.quantize.clone();
let master_port = args.master_port.clone();
let disable_custom_kernels = args.disable_custom_kernels.clone();
let watermark_gamma = args.watermark_gamma.clone();
let watermark_delta = args.watermark_delta.clone();
thread::spawn(move || { thread::spawn(move || {
shard_manager( shard_manager(
model_id, model_id,
...@@ -355,422 +560,224 @@ fn main() -> ExitCode { ...@@ -355,422 +560,224 @@ fn main() -> ExitCode {
Ok(ShardStatus::Failed((rank, err))) => { Ok(ShardStatus::Failed((rank, err))) => {
tracing::error!("Shard {} failed to start:\n{}", rank, err); tracing::error!("Shard {} failed to start:\n{}", rank, err);
shutdown_shards(shutdown, &shutdown_receiver); shutdown_shards(shutdown, &shutdown_receiver);
return ExitCode::FAILURE; return Err(LauncherError::ShardCannotStart);
} }
Err(TryRecvError::Disconnected) => { Err(TryRecvError::Disconnected) => {
tracing::error!("Shard status channel disconnected"); tracing::error!("Shard status channel disconnected");
shutdown_shards(shutdown, &shutdown_receiver); shutdown_shards(shutdown, &shutdown_receiver);
return ExitCode::FAILURE; return Err(LauncherError::ShardDisconnected);
} }
} }
} }
Ok(())
}
// We might have received a termination signal fn spawn_webserver(
if !running.load(Ordering::SeqCst) { args: Args,
shutdown_shards(shutdown, &shutdown_receiver); shutdown: Arc<Mutex<bool>>,
return ExitCode::SUCCESS; shutdown_receiver: &mpsc::Receiver<()>,
} ) -> Result<Popen, LauncherError> {
// All shard started // All shard started
// Start webserver // Start webserver
tracing::info!("Starting Webserver"); tracing::info!("Starting Webserver");
let mut argv = vec![ let mut argv = vec![
"text-generation-router".to_string(), "text-generation-router".to_string(),
"--max-concurrent-requests".to_string(), "--max-concurrent-requests".to_string(),
max_concurrent_requests.to_string(), args.max_concurrent_requests.to_string(),
"--max-best-of".to_string(), "--max-best-of".to_string(),
max_best_of.to_string(), args.max_best_of.to_string(),
"--max-stop-sequences".to_string(), "--max-stop-sequences".to_string(),
max_stop_sequences.to_string(), args.max_stop_sequences.to_string(),
"--max-input-length".to_string(), "--max-input-length".to_string(),
max_input_length.to_string(), args.max_input_length.to_string(),
"--max-total-tokens".to_string(), "--max-total-tokens".to_string(),
max_total_tokens.to_string(), args.max_total_tokens.to_string(),
"--waiting-served-ratio".to_string(), "--waiting-served-ratio".to_string(),
waiting_served_ratio.to_string(), args.waiting_served_ratio.to_string(),
"--max-waiting-tokens".to_string(), "--max-waiting-tokens".to_string(),
max_waiting_tokens.to_string(), args.max_waiting_tokens.to_string(),
"--port".to_string(), "--port".to_string(),
port.to_string(), args.port.to_string(),
"--master-shard-uds-path".to_string(), "--master-shard-uds-path".to_string(),
format!("{shard_uds_path}-0"), format!("{}-0", args.shard_uds_path),
"--tokenizer-name".to_string(), "--tokenizer-name".to_string(),
model_id, args.model_id,
]; ];
// Deprecate max_batch_size // Deprecate max_batch_size
if let Some(max_batch_size) = max_batch_size { if let Some(max_batch_size) = args.max_batch_size {
argv.push("--max-batch-size".to_string()); argv.push("--max-batch-size".to_string());
argv.push(max_batch_size.to_string()) argv.push(max_batch_size.to_string())
} else { } else {
argv.push("--max-batch-total-tokens".to_string()); argv.push("--max-batch-total-tokens".to_string());
argv.push(max_batch_total_tokens.to_string()) argv.push(args.max_batch_total_tokens.to_string())
}
// Model optional revision
if let Some(ref revision) = revision {
argv.push("--revision".to_string());
argv.push(revision.to_string())
}
if json_output {
argv.push("--json-output".to_string());
}
// OpenTelemetry
if let Some(otlp_endpoint) = otlp_endpoint {
argv.push("--otlp-endpoint".to_string());
argv.push(otlp_endpoint);
}
// CORS origins
for origin in cors_allow_origin.into_iter() {
argv.push("--cors-allow-origin".to_string());
argv.push(origin);
}
// Copy current process env
let mut env: Vec<(OsString, OsString)> = env::vars_os().collect();
// Parse Inference API token
if let Ok(api_token) = env::var("HF_API_TOKEN") {
env.push(("HUGGING_FACE_HUB_TOKEN".into(), api_token.into()))
};
let mut webserver = match Popen::create(
&argv,
PopenConfig {
stdout: Redirection::Pipe,
stderr: Redirection::Pipe,
// Needed for the shutdown procedure
setpgid: true,
env: Some(env),
..Default::default()
},
) {
Ok(p) => p,
Err(err) => {
tracing::error!("Failed to start webserver: {}", err);
if let PopenError::IoError(err) = err {
if err.kind() == io::ErrorKind::NotFound {
tracing::error!("text-generation-router not found in PATH");
tracing::error!("Please install it with `make install-router`")
}
} else {
tracing::error!("{}", err);
}
shutdown_shards(shutdown, &shutdown_receiver);
return ExitCode::FAILURE;
}
};
// Redirect STDOUT and STDERR to the console
let webserver_stdout = webserver.stdout.take().unwrap();
let webserver_stderr = webserver.stderr.take().unwrap();
thread::spawn(move || {
let stdout = BufReader::new(webserver_stdout);
let stderr = BufReader::new(webserver_stderr);
for line in stdout.lines() {
println!("{}", line.unwrap());
}
for line in stderr.lines() {
println!("{}", line.unwrap());
}
});
// Default exit code
let mut exit_code = ExitCode::SUCCESS;
while running.load(Ordering::SeqCst) {
if let Ok(ShardStatus::Failed((rank, err))) = status_receiver.try_recv() {
tracing::error!("Shard {rank} failed:\n{err}");
exit_code = ExitCode::FAILURE;
break;
};
match webserver.poll() {
Some(_) => {
tracing::error!("Webserver Crashed");
shutdown_shards(shutdown, &shutdown_receiver);
return ExitCode::FAILURE;
}
None => {
sleep(Duration::from_millis(100));
}
};
}
// Graceful termination
webserver.terminate().unwrap();
tracing::info!("Waiting for webserver to gracefully shutdown");
webserver.wait_timeout(Duration::from_secs(90)).unwrap();
tracing::info!("Webserver terminated");
shutdown_shards(shutdown, &shutdown_receiver);
exit_code
}
#[derive(Debug)]
enum ShardStatus {
Ready,
Failed((usize, String)),
}
#[allow(clippy::too_many_arguments)]
fn shard_manager(
model_id: String,
revision: Option<String>,
quantize: bool,
uds_path: String,
rank: usize,
world_size: usize,
master_addr: String,
master_port: usize,
huggingface_hub_cache: Option<String>,
weights_cache_override: Option<String>,
disable_custom_kernels: bool,
watermark_gamma: Option<f32>,
watermark_delta: Option<f32>,
otlp_endpoint: Option<String>,
status_sender: mpsc::Sender<ShardStatus>,
shutdown: Arc<Mutex<bool>>,
_shutdown_sender: mpsc::Sender<()>,
) {
// Get UDS path
let uds_string = format!("{uds_path}-{rank}");
let uds = Path::new(&uds_string);
// Clean previous runs
fs::remove_file(uds).unwrap_or_default();
// Process args
let mut shard_argv = vec![
"text-generation-server".to_string(),
"serve".to_string(),
model_id,
"--uds-path".to_string(),
uds_path,
"--logger-level".to_string(),
"INFO".to_string(),
"--json-output".to_string(),
];
// Activate tensor parallelism
if world_size > 1 {
shard_argv.push("--sharded".to_string());
} }
if quantize { // Model optional revision
shard_argv.push("--quantize".to_string()) if let Some(ref revision) = args.revision {
argv.push("--revision".to_string());
argv.push(revision.to_string())
} }
// Model optional revision if args.json_output {
if let Some(revision) = revision { argv.push("--json-output".to_string());
shard_argv.push("--revision".to_string());
shard_argv.push(revision)
} }
// OpenTelemetry // OpenTelemetry
if let Some(otlp_endpoint) = otlp_endpoint { if let Some(otlp_endpoint) = args.otlp_endpoint {
shard_argv.push("--otlp-endpoint".to_string()); argv.push("--otlp-endpoint".to_string());
shard_argv.push(otlp_endpoint); argv.push(otlp_endpoint);
}
// CORS origins
for origin in args.cors_allow_origin.into_iter() {
argv.push("--cors-allow-origin".to_string());
argv.push(origin);
} }
// Copy current process env // Copy current process env
let mut env: Vec<(OsString, OsString)> = env::vars_os().collect(); let mut env: Vec<(OsString, OsString)> = env::vars_os().collect();
// Torch Distributed Env vars
env.push(("RANK".into(), rank.to_string().into()));
env.push(("WORLD_SIZE".into(), world_size.to_string().into()));
env.push(("MASTER_ADDR".into(), master_addr.into()));
env.push(("MASTER_PORT".into(), master_port.to_string().into()));
env.push(("NCCL_ASYNC_ERROR_HANDLING".into(), "1".into()));
// Safetensors load fast
env.push(("SAFETENSORS_FAST_GPU".into(), "1".into()));
// Enable hf transfer for insane download speeds
let enable_hf_transfer = env::var("HF_HUB_ENABLE_HF_TRANSFER").unwrap_or("1".to_string());
env.push((
"HF_HUB_ENABLE_HF_TRANSFER".into(),
enable_hf_transfer.into(),
));
// Parse Inference API token // Parse Inference API token
if let Ok(api_token) = env::var("HF_API_TOKEN") { if let Ok(api_token) = env::var("HF_API_TOKEN") {
env.push(("HUGGING_FACE_HUB_TOKEN".into(), api_token.into())) env.push(("HUGGING_FACE_HUB_TOKEN".into(), api_token.into()))
}; };
// If huggingface_hub_cache is some, pass it to the shard let mut webserver = match Popen::create(
// Useful when running inside a docker container &argv,
if let Some(huggingface_hub_cache) = huggingface_hub_cache {
env.push(("HUGGINGFACE_HUB_CACHE".into(), huggingface_hub_cache.into()));
};
// If weights_cache_override is some, pass it to the shard
// Useful when running inside a HuggingFace Inference Endpoint
if let Some(weights_cache_override) = weights_cache_override {
env.push((
"WEIGHTS_CACHE_OVERRIDE".into(),
weights_cache_override.into(),
));
};
// If disable_custom_kernels is true, pass it to the shard as an env var
if disable_custom_kernels {
env.push(("DISABLE_CUSTOM_KERNELS".into(), "True".into()))
}
// Watermark Gamma
if let Some(watermark_gamma) = watermark_gamma {
env.push(("WATERMARK_GAMMA".into(), watermark_gamma.to_string().into()))
}
// Watermark Delta
if let Some(watermark_delta) = watermark_delta {
env.push(("WATERMARK_DELTA".into(), watermark_delta.to_string().into()))
}
// Start process
tracing::info!("Starting shard {rank}");
let mut p = match Popen::create(
&shard_argv,
PopenConfig { PopenConfig {
stdout: Redirection::Pipe, stdout: Redirection::Pipe,
stderr: Redirection::Pipe, stderr: Redirection::Pipe,
// Needed for the shutdown procedure // Needed for the shutdown procedure
setpgid: true, setpgid: true,
// NCCL env vars
env: Some(env), env: Some(env),
..Default::default() ..Default::default()
}, },
) { ) {
Ok(p) => p, Ok(p) => p,
Err(err) => { Err(err) => {
if let PopenError::IoError(ref err) = err { tracing::error!("Failed to start webserver: {}", err);
if let PopenError::IoError(err) = err {
if err.kind() == io::ErrorKind::NotFound { if err.kind() == io::ErrorKind::NotFound {
tracing::error!("text-generation-server not found in PATH"); tracing::error!("text-generation-router not found in PATH");
tracing::error!("Please install it with `make install-server`") tracing::error!("Please install it with `make install-router`")
} }
} else {
tracing::error!("{}", err);
} }
status_sender
.send(ShardStatus::Failed((rank, err.to_string()))) shutdown_shards(shutdown, &shutdown_receiver);
.unwrap(); return Err(LauncherError::WebserverCannotStart);
return;
} }
}; };
// Redirect STDOUT to the console // Redirect STDOUT and STDERR to the console
let shard_stdout = p.stdout.take().unwrap(); let webserver_stdout = webserver.stdout.take().unwrap();
let webserver_stderr = webserver.stderr.take().unwrap();
thread::spawn(move || { thread::spawn(move || {
// Enter shard-manager tracing span let stdout = BufReader::new(webserver_stdout);
let stdout = BufReader::new(shard_stdout); let stderr = BufReader::new(webserver_stderr);
let _span = tracing::span!(tracing::Level::INFO, "shard-manager", rank = rank).entered();
for line in stdout.lines() { for line in stdout.lines() {
// Parse loguru logs println!("{}", line.unwrap());
if let Ok(log) = serde_json::from_str::<PythonLogMessage>(&line.unwrap()) {
log.trace();
}
} }
}); for line in stderr.lines() {
println!("{}", line.unwrap());
let mut ready = false;
let start_time = Instant::now();
let mut wait_time = Instant::now();
loop {
// Process exited
if p.poll().is_some() {
let mut err = String::new();
p.stderr.take().unwrap().read_to_string(&mut err).unwrap();
status_sender
.send(ShardStatus::Failed((rank, err)))
.unwrap();
return;
} }
});
Ok(webserver)
}
// We received a shutdown signal fn main() -> Result<(), LauncherError> {
if *shutdown.lock().unwrap() { // Pattern match configuration
p.terminate().unwrap(); let args = Args::parse();
let _ = p.wait_timeout(Duration::from_secs(90));
tracing::info!("Shard {rank} terminated");
return;
}
// Shard is ready if args.json_output {
if uds.exists() && !ready { tracing_subscriber::fmt().json().init();
tracing::info!("Shard {rank} ready in {:?}", start_time.elapsed()); } else {
status_sender.send(ShardStatus::Ready).unwrap(); tracing_subscriber::fmt().compact().init();
ready = true;
} else if !ready && wait_time.elapsed() > Duration::from_secs(10) {
tracing::info!("Waiting for shard {rank} to be ready...");
wait_time = Instant::now();
}
sleep(Duration::from_millis(100));
} }
}
fn shutdown_shards(shutdown: Arc<Mutex<bool>>, shutdown_receiver: &mpsc::Receiver<()>) { tracing::info!("{:?}", args);
tracing::info!("Shutting down shards");
// Update shutdown value to true let num_shard = find_num_shards(args.sharded, args.num_shard);
// This will be picked up by the shard manager if num_shard > 1 {
{ tracing::info!("Sharding model on {num_shard} processes");
let mut shutdown = shutdown.lock().unwrap();
*shutdown = true;
} }
// Wait for shards to shutdown // Signal handler
// This will block till all shutdown_sender are dropped let running = Arc::new(AtomicBool::new(true));
let _ = shutdown_receiver.recv(); let r = running.clone();
} ctrlc::set_handler(move || {
r.store(false, Ordering::SeqCst);
})
.expect("Error setting Ctrl-C handler");
fn num_cuda_devices() -> Option<usize> { // Check if model_id is a local model
if let Ok(cuda_visible_devices) = env::var("CUDA_VISIBLE_DEVICES") { let local_path = Path::new(&args.model_id);
let n_devices = cuda_visible_devices.split(',').count(); let is_local_model = local_path.exists() && local_path.is_dir();
return Some(n_devices);
// Download weights for sharded models
if !is_local_model && args.weights_cache_override.is_none() && num_shard > 1 {
download_model(&args, running.clone())?;
} }
None
}
#[derive(Deserialize)] // Shared shutdown bool
#[serde(rename_all = "UPPERCASE")] let shutdown = Arc::new(Mutex::new(false));
enum PythonLogLevelEnum { // Shared shutdown channel
Trace, // When shutting down, the main thread will wait for all senders to be dropped
Debug, let (shutdown_sender, shutdown_receiver) = mpsc::channel();
Info,
Success,
Warning,
Error,
Critical,
}
#[derive(Deserialize)] // Shared channel to track shard status
struct PythonLogLevel { let (status_sender, status_receiver) = mpsc::channel();
name: PythonLogLevelEnum,
}
#[derive(Deserialize)] spawn_shards(
struct PythonLogRecord { num_shard,
level: PythonLogLevel, &args,
} shutdown.clone(),
&shutdown_receiver,
shutdown_sender,
&status_receiver,
status_sender,
running.clone(),
)?;
#[derive(Deserialize)] // We might have received a termination signal
struct PythonLogMessage { if !running.load(Ordering::SeqCst) {
text: String, shutdown_shards(shutdown, &shutdown_receiver);
record: PythonLogRecord, return Ok(());
} }
impl PythonLogMessage { let mut webserver = spawn_webserver(args, shutdown.clone(), &shutdown_receiver)?;
fn trace(&self) {
match self.record.level.name { // Default exit code
PythonLogLevelEnum::Trace => tracing::trace!("{}", self.text), let mut exit_code = Ok(());
PythonLogLevelEnum::Debug => tracing::debug!("{}", self.text),
PythonLogLevelEnum::Info => tracing::info!("{}", self.text), while running.load(Ordering::SeqCst) {
PythonLogLevelEnum::Success => tracing::info!("{}", self.text), if let Ok(ShardStatus::Failed((rank, err))) = status_receiver.try_recv() {
PythonLogLevelEnum::Warning => tracing::warn!("{}", self.text), tracing::error!("Shard {rank} failed:\n{err}");
PythonLogLevelEnum::Error => tracing::error!("{}", self.text), exit_code = Err(LauncherError::ShardFailed);
PythonLogLevelEnum::Critical => tracing::error!("{}", self.text), break;
} };
match webserver.poll() {
Some(_) => {
tracing::error!("Webserver Crashed");
shutdown_shards(shutdown, &shutdown_receiver);
return Err(LauncherError::WebserverFailed);
}
None => {
sleep(Duration::from_millis(100));
}
};
} }
// Graceful termination
webserver.terminate().unwrap();
tracing::info!("Waiting for webserver to gracefully shutdown");
webserver.wait_timeout(Duration::from_secs(90)).unwrap();
tracing::info!("Webserver terminated");
shutdown_shards(shutdown, &shutdown_receiver);
exit_code
} }
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment