Unverified Commit c58a0c18 authored by OlivierDehaene, committed by GitHub

v0.9.2 (#616)
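
Bump the version to 0.9.2 in the workspace Cargo.toml, the OpenAPI spec, and the
Python server's pyproject.toml. In the launcher: rename the `*_argv` / `env`
locals to `*_args` / `envs`, factor the repeated SIGTERM-then-wait shutdown
sequences into a `terminate` helper that falls back to SIGKILL after a timeout,
and return early from `main` when the launcher is interrupted during the weight
download.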

parent 5b9de4a1
Cargo.toml
@@ -8,7 +8,7 @@ members = [
]
[workspace.package]
version = "0.9.1"
version = "0.9.2"
edition = "2021"
authors = ["Olivier Dehaene"]
homepage = "https://github.com/huggingface/text-generation-inference"
docs/openapi.json
@@ -10,7 +10,7 @@
"name": "Apache 2.0",
"url": "https://www.apache.org/licenses/LICENSE-2.0"
},
"version": "0.9.1"
"version": "0.9.2"
},
"paths": {
"/": {
launcher/src/main.rs
@@ -7,7 +7,7 @@ use std::ffi::OsString;
use std::io::{BufRead, BufReader, Read};
use std::os::unix::process::{CommandExt, ExitStatusExt};
use std::path::Path;
-use std::process::{Child, Command, Stdio};
+use std::process::{Child, Command, ExitStatus, Stdio};
use std::sync::atomic::{AtomicBool, Ordering};
use std::sync::mpsc::TryRecvError;
use std::sync::{mpsc, Arc};
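(The new `ExitStatus` import is needed by the `terminate` helper added near the bottom of this diff, whose signature returns `io::Result<ExitStatus>`.)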
@@ -319,7 +319,7 @@ fn shard_manager(
}
// Process args
-let mut shard_argv = vec![
+let mut shard_args = vec![
"serve".to_string(),
model_id,
"--uds-path".to_string(),
@@ -331,77 +331,77 @@ fn shard_manager(
// Activate trust remote code
if trust_remote_code {
shard_argv.push("--trust-remote-code".to_string());
shard_args.push("--trust-remote-code".to_string());
}
// Activate tensor parallelism
if world_size > 1 {
shard_argv.push("--sharded".to_string());
shard_args.push("--sharded".to_string());
}
if let Some(quantize) = quantize {
shard_argv.push("--quantize".to_string());
shard_argv.push(quantize.to_string())
shard_args.push("--quantize".to_string());
shard_args.push(quantize.to_string())
}
if let Some(dtype) = dtype {
shard_argv.push("--dtype".to_string());
shard_argv.push(dtype.to_string())
shard_args.push("--dtype".to_string());
shard_args.push(dtype.to_string())
}
// Model optional revision
if let Some(revision) = revision {
shard_argv.push("--revision".to_string());
shard_argv.push(revision)
shard_args.push("--revision".to_string());
shard_args.push(revision)
}
// OpenTelemetry
if let Some(otlp_endpoint) = otlp_endpoint {
shard_argv.push("--otlp-endpoint".to_string());
shard_argv.push(otlp_endpoint);
shard_args.push("--otlp-endpoint".to_string());
shard_args.push(otlp_endpoint);
}
// Copy current process env
-let mut env: Vec<(OsString, OsString)> = env::vars_os().collect();
+let mut envs: Vec<(OsString, OsString)> = env::vars_os().collect();
// Use cuda allocator. It leads to less memory fragmentation
-env.push((
+envs.push((
"PYTORCH_CUDA_ALLOC_CONF".into(),
"backend:cudaMallocAsync".into(),
));
// Torch Distributed Env vars
env.push(("RANK".into(), rank.to_string().into()));
env.push(("WORLD_SIZE".into(), world_size.to_string().into()));
env.push(("MASTER_ADDR".into(), master_addr.into()));
env.push(("MASTER_PORT".into(), master_port.to_string().into()));
env.push(("NCCL_ASYNC_ERROR_HANDLING".into(), "1".into()));
envs.push(("RANK".into(), rank.to_string().into()));
envs.push(("WORLD_SIZE".into(), world_size.to_string().into()));
envs.push(("MASTER_ADDR".into(), master_addr.into()));
envs.push(("MASTER_PORT".into(), master_port.to_string().into()));
envs.push(("NCCL_ASYNC_ERROR_HANDLING".into(), "1".into()));
// Safetensors load fast
env.push(("SAFETENSORS_FAST_GPU".into(), "1".into()));
envs.push(("SAFETENSORS_FAST_GPU".into(), "1".into()));
// Enable hf transfer for insane download speeds
let enable_hf_transfer = env::var("HF_HUB_ENABLE_HF_TRANSFER").unwrap_or("1".to_string());
-env.push((
+envs.push((
"HF_HUB_ENABLE_HF_TRANSFER".into(),
enable_hf_transfer.into(),
));
// Parse Inference API token
if let Ok(api_token) = env::var("HF_API_TOKEN") {
env.push(("HUGGING_FACE_HUB_TOKEN".into(), api_token.into()))
envs.push(("HUGGING_FACE_HUB_TOKEN".into(), api_token.into()))
};
// If huggingface_hub_cache is some, pass it to the shard
// Useful when running inside a docker container
if let Some(huggingface_hub_cache) = huggingface_hub_cache {
env.push(("HUGGINGFACE_HUB_CACHE".into(), huggingface_hub_cache.into()));
envs.push(("HUGGINGFACE_HUB_CACHE".into(), huggingface_hub_cache.into()));
};
// If weights_cache_override is some, pass it to the shard
// Useful when running inside a HuggingFace Inference Endpoint
if let Some(weights_cache_override) = weights_cache_override {
-env.push((
+envs.push((
"WEIGHTS_CACHE_OVERRIDE".into(),
weights_cache_override.into(),
));
@@ -409,24 +409,24 @@ fn shard_manager(
// If disable_custom_kernels is true, pass it to the shard as an env var
if disable_custom_kernels {
env.push(("DISABLE_CUSTOM_KERNELS".into(), "True".into()))
envs.push(("DISABLE_CUSTOM_KERNELS".into(), "True".into()))
}
// Watermark Gamma
if let Some(watermark_gamma) = watermark_gamma {
env.push(("WATERMARK_GAMMA".into(), watermark_gamma.to_string().into()))
envs.push(("WATERMARK_GAMMA".into(), watermark_gamma.to_string().into()))
}
// Watermark Delta
if let Some(watermark_delta) = watermark_delta {
env.push(("WATERMARK_DELTA".into(), watermark_delta.to_string().into()))
envs.push(("WATERMARK_DELTA".into(), watermark_delta.to_string().into()))
}
// Start process
tracing::info!("Starting shard {rank}");
let mut p = match Command::new("text-generation-server")
-.args(shard_argv)
-.envs(env)
+.args(shard_args)
+.envs(envs)
.stdout(Stdio::piped())
.stderr(Stdio::piped())
.process_group(0)
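The renames above don't change behavior; the spawn pattern stays the same: collect the parent environment into a `Vec`, append overrides, then hand the argument and env vectors to `std::process::Command`. A minimal self-contained sketch of that pattern (the `echo` binary, model id, and env values are placeholders, not the launcher's real configuration):

```rust
use std::env;
use std::ffi::OsString;
use std::process::{Command, Stdio};

fn main() -> std::io::Result<()> {
    // Arguments accumulate in a plain Vec<String>, like `shard_args` above.
    let shard_args = vec!["serve".to_string(), "some/model-id".to_string()];

    // Start from the parent's environment so PATH and friends are inherited,
    // then append per-process overrides, like `envs` above.
    let mut envs: Vec<(OsString, OsString)> = env::vars_os().collect();
    envs.push(("RANK".into(), "0".into()));
    envs.push(("WORLD_SIZE".into(), "1".into()));

    // `echo` stands in for the real `text-generation-server` binary.
    let mut child = Command::new("echo")
        .args(shard_args)
        .envs(envs)
        .stdout(Stdio::null())
        .spawn()?;
    child.wait()?;
    Ok(())
}
```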
@@ -632,7 +632,7 @@ enum LauncherError {
}
fn download_convert_model(args: &Args, running: Arc<AtomicBool>) -> Result<(), LauncherError> {
-let mut download_argv = vec![
+let mut download_args = vec![
"download-weights".to_string(),
args.model_id.to_string(),
"--extension".to_string(),
@@ -644,35 +644,35 @@ fn download_convert_model(args: &Args, running: Arc<AtomicBool>) -> Result<(), L
// Model optional revision
if let Some(revision) = &args.revision {
download_argv.push("--revision".to_string());
download_argv.push(revision.to_string())
download_args.push("--revision".to_string());
download_args.push(revision.to_string())
}
// Copy current process env
-let mut env: Vec<(OsString, OsString)> = env::vars_os().collect();
+let mut envs: Vec<(OsString, OsString)> = env::vars_os().collect();
// If huggingface_hub_cache is set, pass it to the download process
// Useful when running inside a docker container
if let Some(ref huggingface_hub_cache) = args.huggingface_hub_cache {
env.push(("HUGGINGFACE_HUB_CACHE".into(), huggingface_hub_cache.into()));
envs.push(("HUGGINGFACE_HUB_CACHE".into(), huggingface_hub_cache.into()));
};
// Enable hf transfer for insane download speeds
let enable_hf_transfer = env::var("HF_HUB_ENABLE_HF_TRANSFER").unwrap_or("1".to_string());
-env.push((
+envs.push((
"HF_HUB_ENABLE_HF_TRANSFER".into(),
enable_hf_transfer.into(),
));
// Parse Inference API token
if let Ok(api_token) = env::var("HF_API_TOKEN") {
env.push(("HUGGING_FACE_HUB_TOKEN".into(), api_token.into()))
envs.push(("HUGGING_FACE_HUB_TOKEN".into(), api_token.into()))
};
// If args.weights_cache_override is some, pass it to the download process
// Useful when running inside a HuggingFace Inference Endpoint
if let Some(weights_cache_override) = &args.weights_cache_override {
-env.push((
+envs.push((
"WEIGHTS_CACHE_OVERRIDE".into(),
weights_cache_override.into(),
));
@@ -681,8 +681,8 @@ fn download_convert_model(args: &Args, running: Arc<AtomicBool>) -> Result<(), L
// Start process
tracing::info!("Starting download process.");
let mut download_process = match Command::new("text-generation-server")
-.args(download_argv)
-.envs(env)
+.args(download_args)
+.envs(envs)
.stdout(Stdio::piped())
.stderr(Stdio::piped())
.process_group(0)
@@ -738,10 +738,7 @@ fn download_convert_model(args: &Args, running: Arc<AtomicBool>) -> Result<(), L
return Err(LauncherError::DownloadError);
}
if !running.load(Ordering::SeqCst) {
-signal::kill(Pid::from_raw(download_process.id() as i32), Signal::SIGTERM).unwrap();
-tracing::info!("Waiting for download process to gracefully shutdown");
-download_process.wait().unwrap();
-tracing::info!("Download process terminated");
+terminate("download", download_process, Duration::from_secs(10)).unwrap();
return Ok(());
}
sleep(Duration::from_millis(100));
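Net effect of this hunk: where the old code sent SIGTERM and then blocked on `wait()` indefinitely, the download process now gets a 10-second grace period via the new `terminate` helper (defined further down) before it is force-killed.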
@@ -844,7 +841,7 @@ fn spawn_webserver(
// All shard started
// Start webserver
tracing::info!("Starting Webserver");
-let mut argv = vec![
+let mut router_args = vec![
"--max-concurrent-requests".to_string(),
args.max_concurrent_requests.to_string(),
"--max-best-of".to_string(),
@@ -877,24 +874,24 @@
// Model optional revision
if let Some(ref revision) = args.revision {
argv.push("--revision".to_string());
argv.push(revision.to_string())
router_args.push("--revision".to_string());
router_args.push(revision.to_string())
}
if args.json_output {
argv.push("--json-output".to_string());
router_args.push("--json-output".to_string());
}
// OpenTelemetry
if let Some(otlp_endpoint) = args.otlp_endpoint {
argv.push("--otlp-endpoint".to_string());
argv.push(otlp_endpoint);
router_args.push("--otlp-endpoint".to_string());
router_args.push(otlp_endpoint);
}
// CORS origins
for origin in args.cors_allow_origin.into_iter() {
argv.push("--cors-allow-origin".to_string());
argv.push(origin);
router_args.push("--cors-allow-origin".to_string());
router_args.push(origin);
}
// Ngrok
@@ -904,34 +901,34 @@
LauncherError::WebserverCannotStart
})?;
argv.push("--ngrok".to_string());
argv.push("--ngrok-authtoken".to_string());
argv.push(authtoken);
router_args.push("--ngrok".to_string());
router_args.push("--ngrok-authtoken".to_string());
router_args.push(authtoken);
if let Some(domain) = args.ngrok_domain {
argv.push("--ngrok-domain".to_string());
argv.push(domain);
router_args.push("--ngrok-domain".to_string());
router_args.push(domain);
}
if let (Some(username), Some(password)) = (args.ngrok_username, args.ngrok_password) {
argv.push("--ngrok-username".to_string());
argv.push(username);
argv.push("--ngrok-password".to_string());
argv.push(password);
router_args.push("--ngrok-username".to_string());
router_args.push(username);
router_args.push("--ngrok-password".to_string());
router_args.push(password);
}
}
// Copy current process env
-let mut env: Vec<(OsString, OsString)> = env::vars_os().collect();
+let mut envs: Vec<(OsString, OsString)> = env::vars_os().collect();
// Parse Inference API token
if let Ok(api_token) = env::var("HF_API_TOKEN") {
env.push(("HUGGING_FACE_HUB_TOKEN".into(), api_token.into()))
envs.push(("HUGGING_FACE_HUB_TOKEN".into(), api_token.into()))
};
let mut webserver = match Command::new("text-generation-router")
-.args(argv)
-.envs(env)
+.args(router_args)
+.envs(envs)
.stdout(Stdio::piped())
.stderr(Stdio::piped())
.process_group(0)
@@ -969,6 +966,31 @@ fn spawn_webserver(
Ok(webserver)
}
+fn terminate(process_name: &str, mut process: Child, timeout: Duration) -> io::Result<ExitStatus> {
+    tracing::info!("Terminating {process_name}");
+    let terminate_time = Instant::now();
+    signal::kill(Pid::from_raw(process.id() as i32), Signal::SIGTERM).unwrap();
+    tracing::info!("Waiting for {process_name} to gracefully shutdown");
+    while terminate_time.elapsed() < timeout {
+        if let Some(status) = process.try_wait()? {
+            tracing::info!("{process_name} terminated");
+            return Ok(status);
+        }
+        sleep(Duration::from_millis(100));
+    }
+    tracing::info!("Killing {process_name}");
+    process.kill()?;
+    let exit_status = process.wait()?;
+    tracing::info!("{process_name} killed");
+    Ok(exit_status)
+}
fn main() -> Result<(), LauncherError> {
// Pattern match configuration
let args = Args::parse();
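A usage sketch for the new helper, compiled against the same `nix` crate APIs the launcher uses (the `sleep` child, the 2-second budget, and `eprintln!` in place of `tracing` are illustrative choices, not part of this commit):

```rust
use std::io;
use std::process::{Child, Command, ExitStatus};
use std::thread::sleep;
use std::time::{Duration, Instant};

use nix::sys::signal::{self, Signal};
use nix::unistd::Pid;

fn terminate(name: &str, mut process: Child, timeout: Duration) -> io::Result<ExitStatus> {
    // Ask politely first: SIGTERM lets the child run its shutdown handlers.
    signal::kill(Pid::from_raw(process.id() as i32), Signal::SIGTERM).ok();
    let start = Instant::now();
    while start.elapsed() < timeout {
        // try_wait is non-blocking: Some(status) once the child has exited.
        if let Some(status) = process.try_wait()? {
            eprintln!("{name} terminated gracefully");
            return Ok(status);
        }
        sleep(Duration::from_millis(100));
    }
    // Grace period expired: force-kill (SIGKILL on Unix) and reap the child.
    eprintln!("killing {name}");
    process.kill()?;
    process.wait()
}

fn main() -> io::Result<()> {
    let child = Command::new("sleep").arg("30").spawn()?;
    let status = terminate("sleep", child, Duration::from_secs(2))?;
    eprintln!("sleep exited with {status}");
    Ok(())
}
```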
@@ -1038,6 +1060,11 @@ fn main() -> Result<(), LauncherError> {
// Download and convert model weights
download_convert_model(&args, running.clone())?;
+if !running.load(Ordering::SeqCst) {
+    // Launcher was asked to stop
+    return Ok(());
+}
// Shared shutdown bool
let shutdown = Arc::new(AtomicBool::new(false));
// Shared shutdown channel
@@ -1096,10 +1123,7 @@
}
// Graceful termination
-signal::kill(Pid::from_raw(webserver.id() as i32), Signal::SIGTERM).unwrap();
-tracing::info!("Waiting for webserver to gracefully shutdown");
-webserver.wait().unwrap();
-tracing::info!("Webserver terminated");
+terminate("webserver", webserver, Duration::from_secs(90)).unwrap();
shutdown_shards(shutdown, &shutdown_receiver);
exit_code
server/pyproject.toml
[tool.poetry]
name = "text-generation-server"
version = "0.9.1"
version = "0.9.2"
description = "Text Generation Inference Python gRPC Server"
authors = ["Olivier Dehaene <olivier@huggingface.co>"]