Unverified commit c58a0c18, authored by OlivierDehaene and committed by GitHub

v0.9.2 (#616)

parent 5b9de4a1
@@ -8,7 +8,7 @@ members = [
 ]
 
 [workspace.package]
-version = "0.9.1"
+version = "0.9.2"
 edition = "2021"
 authors = ["Olivier Dehaene"]
 homepage = "https://github.com/huggingface/text-generation-inference"
...
@@ -10,7 +10,7 @@
       "name": "Apache 2.0",
       "url": "https://www.apache.org/licenses/LICENSE-2.0"
     },
-    "version": "0.9.1"
+    "version": "0.9.2"
   },
   "paths": {
     "/": {
...
@@ -7,7 +7,7 @@ use std::ffi::OsString;
 use std::io::{BufRead, BufReader, Read};
 use std::os::unix::process::{CommandExt, ExitStatusExt};
 use std::path::Path;
-use std::process::{Child, Command, Stdio};
+use std::process::{Child, Command, ExitStatus, Stdio};
 use std::sync::atomic::{AtomicBool, Ordering};
 use std::sync::mpsc::TryRecvError;
 use std::sync::{mpsc, Arc};
@@ -319,7 +319,7 @@ fn shard_manager(
     }
 
     // Process args
-    let mut shard_argv = vec![
+    let mut shard_args = vec![
         "serve".to_string(),
         model_id,
         "--uds-path".to_string(),
@@ -331,77 +331,77 @@ fn shard_manager(
 
     // Activate trust remote code
     if trust_remote_code {
-        shard_argv.push("--trust-remote-code".to_string());
+        shard_args.push("--trust-remote-code".to_string());
     }
 
     // Activate tensor parallelism
     if world_size > 1 {
-        shard_argv.push("--sharded".to_string());
+        shard_args.push("--sharded".to_string());
     }
 
     if let Some(quantize) = quantize {
-        shard_argv.push("--quantize".to_string());
-        shard_argv.push(quantize.to_string())
+        shard_args.push("--quantize".to_string());
+        shard_args.push(quantize.to_string())
     }
 
     if let Some(dtype) = dtype {
-        shard_argv.push("--dtype".to_string());
-        shard_argv.push(dtype.to_string())
+        shard_args.push("--dtype".to_string());
+        shard_args.push(dtype.to_string())
     }
 
     // Model optional revision
     if let Some(revision) = revision {
-        shard_argv.push("--revision".to_string());
-        shard_argv.push(revision)
+        shard_args.push("--revision".to_string());
+        shard_args.push(revision)
     }
 
     // OpenTelemetry
     if let Some(otlp_endpoint) = otlp_endpoint {
-        shard_argv.push("--otlp-endpoint".to_string());
-        shard_argv.push(otlp_endpoint);
+        shard_args.push("--otlp-endpoint".to_string());
+        shard_args.push(otlp_endpoint);
     }
 
     // Copy current process env
-    let mut env: Vec<(OsString, OsString)> = env::vars_os().collect();
+    let mut envs: Vec<(OsString, OsString)> = env::vars_os().collect();
 
     // Use cuda allocator. It leads to less memory fragmentation
-    env.push((
+    envs.push((
         "PYTORCH_CUDA_ALLOC_CONF".into(),
         "backend:cudaMallocAsync".into(),
     ));
 
     // Torch Distributed Env vars
-    env.push(("RANK".into(), rank.to_string().into()));
-    env.push(("WORLD_SIZE".into(), world_size.to_string().into()));
-    env.push(("MASTER_ADDR".into(), master_addr.into()));
-    env.push(("MASTER_PORT".into(), master_port.to_string().into()));
-    env.push(("NCCL_ASYNC_ERROR_HANDLING".into(), "1".into()));
+    envs.push(("RANK".into(), rank.to_string().into()));
+    envs.push(("WORLD_SIZE".into(), world_size.to_string().into()));
+    envs.push(("MASTER_ADDR".into(), master_addr.into()));
+    envs.push(("MASTER_PORT".into(), master_port.to_string().into()));
+    envs.push(("NCCL_ASYNC_ERROR_HANDLING".into(), "1".into()));
 
     // Safetensors load fast
-    env.push(("SAFETENSORS_FAST_GPU".into(), "1".into()));
+    envs.push(("SAFETENSORS_FAST_GPU".into(), "1".into()));
 
     // Enable hf transfer for insane download speeds
     let enable_hf_transfer = env::var("HF_HUB_ENABLE_HF_TRANSFER").unwrap_or("1".to_string());
-    env.push((
+    envs.push((
         "HF_HUB_ENABLE_HF_TRANSFER".into(),
         enable_hf_transfer.into(),
     ));
 
     // Parse Inference API token
     if let Ok(api_token) = env::var("HF_API_TOKEN") {
-        env.push(("HUGGING_FACE_HUB_TOKEN".into(), api_token.into()))
+        envs.push(("HUGGING_FACE_HUB_TOKEN".into(), api_token.into()))
     };
 
     // If huggingface_hub_cache is some, pass it to the shard
     // Useful when running inside a docker container
     if let Some(huggingface_hub_cache) = huggingface_hub_cache {
-        env.push(("HUGGINGFACE_HUB_CACHE".into(), huggingface_hub_cache.into()));
+        envs.push(("HUGGINGFACE_HUB_CACHE".into(), huggingface_hub_cache.into()));
     };
 
     // If weights_cache_override is some, pass it to the shard
     // Useful when running inside a HuggingFace Inference Endpoint
     if let Some(weights_cache_override) = weights_cache_override {
-        env.push((
+        envs.push((
             "WEIGHTS_CACHE_OVERRIDE".into(),
             weights_cache_override.into(),
         ));
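The env-propagation pattern in the hunk above — collecting `env::vars_os()` into a `Vec` and pushing overrides onto the end — works because `Command::envs` applies entries in iteration order, so a key pushed later replaces an inherited value with the same name. A minimal sketch of the idea (illustrative helper, not code from this commit):

```rust
use std::env;
use std::ffi::OsString;

// Illustrative helper (not from this commit): start from the parent
// process's full environment, then append per-shard overrides. Once the
// Vec is handed to `Command::envs`, later duplicates win.
fn shard_env(rank: usize, world_size: usize) -> Vec<(OsString, OsString)> {
    let mut envs: Vec<(OsString, OsString)> = env::vars_os().collect();
    envs.push((
        "PYTORCH_CUDA_ALLOC_CONF".into(),
        "backend:cudaMallocAsync".into(),
    ));
    envs.push(("RANK".into(), rank.to_string().into()));
    envs.push(("WORLD_SIZE".into(), world_size.to_string().into()));
    envs
}
```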
@@ -409,24 +409,24 @@ fn shard_manager(
 
     // If disable_custom_kernels is true, pass it to the shard as an env var
     if disable_custom_kernels {
-        env.push(("DISABLE_CUSTOM_KERNELS".into(), "True".into()))
+        envs.push(("DISABLE_CUSTOM_KERNELS".into(), "True".into()))
     }
 
     // Watermark Gamma
     if let Some(watermark_gamma) = watermark_gamma {
-        env.push(("WATERMARK_GAMMA".into(), watermark_gamma.to_string().into()))
+        envs.push(("WATERMARK_GAMMA".into(), watermark_gamma.to_string().into()))
     }
 
     // Watermark Delta
     if let Some(watermark_delta) = watermark_delta {
-        env.push(("WATERMARK_DELTA".into(), watermark_delta.to_string().into()))
+        envs.push(("WATERMARK_DELTA".into(), watermark_delta.to_string().into()))
     }
 
     // Start process
     tracing::info!("Starting shard {rank}");
 
     let mut p = match Command::new("text-generation-server")
-        .args(shard_argv)
-        .envs(env)
+        .args(shard_args)
+        .envs(envs)
         .stdout(Stdio::piped())
         .stderr(Stdio::piped())
         .process_group(0)
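The `.process_group(0)` call in the spawn chain above (from `std::os::unix::process::CommandExt`, stable since Rust 1.64) places the child in its own process group, so a Ctrl+C aimed at the launcher's foreground group does not reach the shards directly; the launcher signals them itself. A standalone sketch of the pattern, assuming a Unix target:

```rust
use std::io;
use std::os::unix::process::CommandExt; // provides process_group (Rust 1.64+)
use std::process::{Child, Command, Stdio};

// Illustrative spawn helper (not TGI code): a child in its own process
// group does not receive terminal-generated signals such as Ctrl+C, so
// the parent stays in charge of its shutdown.
fn spawn_detached(program: &str, args: &[String]) -> io::Result<Child> {
    Command::new(program)
        .args(args)
        .stdout(Stdio::piped()) // capture output instead of sharing the tty
        .stderr(Stdio::piped())
        .process_group(0) // 0 = new group whose pgid is the child's pid
        .spawn()
}
```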
@@ -632,7 +632,7 @@ enum LauncherError {
 }
 
 fn download_convert_model(args: &Args, running: Arc<AtomicBool>) -> Result<(), LauncherError> {
-    let mut download_argv = vec![
+    let mut download_args = vec![
         "download-weights".to_string(),
         args.model_id.to_string(),
         "--extension".to_string(),
@@ -644,35 +644,35 @@ fn download_convert_model(args: &Args, running: Arc<AtomicBool>) -> Result<(), LauncherError> {
 
     // Model optional revision
     if let Some(revision) = &args.revision {
-        download_argv.push("--revision".to_string());
-        download_argv.push(revision.to_string())
+        download_args.push("--revision".to_string());
+        download_args.push(revision.to_string())
     }
 
     // Copy current process env
-    let mut env: Vec<(OsString, OsString)> = env::vars_os().collect();
+    let mut envs: Vec<(OsString, OsString)> = env::vars_os().collect();
 
     // If huggingface_hub_cache is set, pass it to the download process
     // Useful when running inside a docker container
     if let Some(ref huggingface_hub_cache) = args.huggingface_hub_cache {
-        env.push(("HUGGINGFACE_HUB_CACHE".into(), huggingface_hub_cache.into()));
+        envs.push(("HUGGINGFACE_HUB_CACHE".into(), huggingface_hub_cache.into()));
     };
 
     // Enable hf transfer for insane download speeds
     let enable_hf_transfer = env::var("HF_HUB_ENABLE_HF_TRANSFER").unwrap_or("1".to_string());
-    env.push((
+    envs.push((
         "HF_HUB_ENABLE_HF_TRANSFER".into(),
         enable_hf_transfer.into(),
     ));
 
     // Parse Inference API token
     if let Ok(api_token) = env::var("HF_API_TOKEN") {
-        env.push(("HUGGING_FACE_HUB_TOKEN".into(), api_token.into()))
+        envs.push(("HUGGING_FACE_HUB_TOKEN".into(), api_token.into()))
     };
 
     // If args.weights_cache_override is some, pass it to the download process
     // Useful when running inside a HuggingFace Inference Endpoint
     if let Some(weights_cache_override) = &args.weights_cache_override {
-        env.push((
+        envs.push((
             "WEIGHTS_CACHE_OVERRIDE".into(),
             weights_cache_override.into(),
         ));
@@ -681,8 +681,8 @@ fn download_convert_model(args: &Args, running: Arc<AtomicBool>) -> Result<(), LauncherError> {
     // Start process
     tracing::info!("Starting download process.");
     let mut download_process = match Command::new("text-generation-server")
-        .args(download_argv)
-        .envs(env)
+        .args(download_args)
+        .envs(envs)
         .stdout(Stdio::piped())
         .stderr(Stdio::piped())
         .process_group(0)
@@ -738,10 +738,7 @@ fn download_convert_model(args: &Args, running: Arc<AtomicBool>) -> Result<(), LauncherError> {
             return Err(LauncherError::DownloadError);
         }
         if !running.load(Ordering::SeqCst) {
-            signal::kill(Pid::from_raw(download_process.id() as i32), Signal::SIGTERM).unwrap();
-            tracing::info!("Waiting for download process to gracefully shutdown");
-            download_process.wait().unwrap();
-            tracing::info!("Download process terminated");
+            terminate("download", download_process, Duration::from_secs(10)).unwrap();
             return Ok(());
         }
         sleep(Duration::from_millis(100));
@@ -844,7 +841,7 @@ fn spawn_webserver(
     // All shard started
     // Start webserver
     tracing::info!("Starting Webserver");
-    let mut argv = vec![
+    let mut router_args = vec![
         "--max-concurrent-requests".to_string(),
         args.max_concurrent_requests.to_string(),
         "--max-best-of".to_string(),
@@ -877,24 +874,24 @@ fn spawn_webserver(
 
     // Model optional revision
     if let Some(ref revision) = args.revision {
-        argv.push("--revision".to_string());
-        argv.push(revision.to_string())
+        router_args.push("--revision".to_string());
+        router_args.push(revision.to_string())
     }
 
     if args.json_output {
-        argv.push("--json-output".to_string());
+        router_args.push("--json-output".to_string());
    }
 
     // OpenTelemetry
     if let Some(otlp_endpoint) = args.otlp_endpoint {
-        argv.push("--otlp-endpoint".to_string());
-        argv.push(otlp_endpoint);
+        router_args.push("--otlp-endpoint".to_string());
+        router_args.push(otlp_endpoint);
     }
 
     // CORS origins
     for origin in args.cors_allow_origin.into_iter() {
-        argv.push("--cors-allow-origin".to_string());
-        argv.push(origin);
+        router_args.push("--cors-allow-origin".to_string());
+        router_args.push(origin);
     }
 
     // Ngrok
@@ -904,34 +901,34 @@ fn spawn_webserver(
             LauncherError::WebserverCannotStart
         })?;
 
-        argv.push("--ngrok".to_string());
-        argv.push("--ngrok-authtoken".to_string());
-        argv.push(authtoken);
+        router_args.push("--ngrok".to_string());
+        router_args.push("--ngrok-authtoken".to_string());
+        router_args.push(authtoken);
 
         if let Some(domain) = args.ngrok_domain {
-            argv.push("--ngrok-domain".to_string());
-            argv.push(domain);
+            router_args.push("--ngrok-domain".to_string());
+            router_args.push(domain);
         }
 
         if let (Some(username), Some(password)) = (args.ngrok_username, args.ngrok_password) {
-            argv.push("--ngrok-username".to_string());
-            argv.push(username);
-            argv.push("--ngrok-password".to_string());
-            argv.push(password);
+            router_args.push("--ngrok-username".to_string());
+            router_args.push(username);
+            router_args.push("--ngrok-password".to_string());
+            router_args.push(password);
         }
     }
 
     // Copy current process env
-    let mut env: Vec<(OsString, OsString)> = env::vars_os().collect();
+    let mut envs: Vec<(OsString, OsString)> = env::vars_os().collect();
 
     // Parse Inference API token
     if let Ok(api_token) = env::var("HF_API_TOKEN") {
-        env.push(("HUGGING_FACE_HUB_TOKEN".into(), api_token.into()))
+        envs.push(("HUGGING_FACE_HUB_TOKEN".into(), api_token.into()))
     };
 
     let mut webserver = match Command::new("text-generation-router")
-        .args(argv)
-        .envs(env)
+        .args(router_args)
+        .envs(envs)
         .stdout(Stdio::piped())
         .stderr(Stdio::piped())
         .process_group(0)
@@ -969,6 +966,31 @@ fn spawn_webserver(
     Ok(webserver)
 }
 
+fn terminate(process_name: &str, mut process: Child, timeout: Duration) -> io::Result<ExitStatus> {
+    tracing::info!("Terminating {process_name}");
+
+    let terminate_time = Instant::now();
+    signal::kill(Pid::from_raw(process.id() as i32), Signal::SIGTERM).unwrap();
+
+    tracing::info!("Waiting for {process_name} to gracefully shutdown");
+
+    while terminate_time.elapsed() < timeout {
+        if let Some(status) = process.try_wait()? {
+            tracing::info!("{process_name} terminated");
+            return Ok(status);
+        }
+        sleep(Duration::from_millis(100));
+    }
+
+    tracing::info!("Killing {process_name}");
+
+    process.kill()?;
+    let exit_status = process.wait()?;
+
+    tracing::info!("{process_name} killed");
+    Ok(exit_status)
+}
+
 fn main() -> Result<(), LauncherError> {
     // Pattern match configuration
     let args = Args::parse();
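The new `terminate` helper escalates in two stages: SIGTERM first, then a non-blocking `try_wait()` poll every 100 ms until the deadline, then `kill()`, which on Unix delivers SIGKILL and cannot be caught by the child. A hedged demo of calling it, with `sleep` standing in for a real worker:

```rust
use std::process::Command;
use std::time::Duration;

// Hypothetical demo, not part of the commit: give a long-running child
// two seconds to exit after SIGTERM before it is SIGKILLed. `terminate`
// is the helper introduced in the diff above.
fn demo() -> std::io::Result<()> {
    let child = Command::new("sleep").arg("300").spawn()?;
    let status = terminate("sleep", child, Duration::from_secs(2))?;
    println!("child exited with {status}");
    Ok(())
}
```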
@@ -1038,6 +1060,11 @@ fn main() -> Result<(), LauncherError> {
     // Download and convert model weights
     download_convert_model(&args, running.clone())?;
 
+    if !running.load(Ordering::SeqCst) {
+        // Launcher was asked to stop
+        return Ok(());
+    }
+
     // Shared shutdown bool
     let shutdown = Arc::new(AtomicBool::new(false));
     // Shared shutdown channel
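The early return added above assumes the `running` flag was flipped by a signal handler while `download_convert_model` was polling. A sketch of the typical wiring (assumed; this commit does not show the handler, and the `ctrlc` crate here is an illustration):

```rust
use std::sync::atomic::{AtomicBool, Ordering};
use std::sync::Arc;

// Install a Ctrl+C handler that flips a shared flag; long-running steps
// such as the weight download poll `running.load(Ordering::SeqCst)` and
// bail out cleanly once it turns false.
fn install_shutdown_flag() -> Arc<AtomicBool> {
    let running = Arc::new(AtomicBool::new(true));
    let r = running.clone();
    ctrlc::set_handler(move || r.store(false, Ordering::SeqCst))
        .expect("failed to set Ctrl+C handler");
    running
}
```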
@@ -1096,10 +1123,7 @@ fn main() -> Result<(), LauncherError> {
     }
 
     // Graceful termination
-    signal::kill(Pid::from_raw(webserver.id() as i32), Signal::SIGTERM).unwrap();
-    tracing::info!("Waiting for webserver to gracefully shutdown");
-    webserver.wait().unwrap();
-    tracing::info!("Webserver terminated");
+    terminate("webserver", webserver, Duration::from_secs(90)).unwrap();
     shutdown_shards(shutdown, &shutdown_receiver);
 
     exit_code
...
@@ -1,6 +1,6 @@
 [tool.poetry]
 name = "text-generation-server"
-version = "0.9.1"
+version = "0.9.2"
 description = "Text Generation Inference Python gRPC Server"
 authors = ["Olivier Dehaene <olivier@huggingface.co>"]
...