Commit efd602c8 authored by xuxzh1

last

parent f1b779fc
......@@ -26,6 +26,7 @@ async def flash_llava_next(flash_llava_next_handle):
return flash_llava_next_handle.client
@pytest.mark.release
@pytest.mark.asyncio
@pytest.mark.private
async def test_flash_llava_next_simple(flash_llava_next, response_snapshot):
......@@ -41,6 +42,7 @@ async def test_flash_llava_next_simple(flash_llava_next, response_snapshot):
assert response == response_snapshot
@pytest.mark.release
@pytest.mark.asyncio
@pytest.mark.private
async def test_flash_llava_next_all_params(flash_llava_next, response_snapshot):
......@@ -64,6 +66,7 @@ async def test_flash_llava_next_all_params(flash_llava_next, response_snapshot):
assert response == response_snapshot
@pytest.mark.release
@pytest.mark.asyncio
@pytest.mark.private
async def test_flash_llava_next_load(
......
......@@ -13,6 +13,7 @@ async def fused_kernel_mamba(fused_kernel_mamba_handle):
return fused_kernel_mamba_handle.client
@pytest.mark.release
@pytest.mark.asyncio
async def test_mamba(fused_kernel_mamba, response_snapshot):
response = await fused_kernel_mamba.generate(
......@@ -24,6 +25,7 @@ async def test_mamba(fused_kernel_mamba, response_snapshot):
assert response == response_snapshot
@pytest.mark.release
@pytest.mark.asyncio
async def test_mamba_all_params(fused_kernel_mamba, response_snapshot):
response = await fused_kernel_mamba.generate(
......@@ -50,6 +52,7 @@ async def test_mamba_all_params(fused_kernel_mamba, response_snapshot):
assert response == response_snapshot
@pytest.mark.release
@pytest.mark.asyncio
async def test_mamba_load(
fused_kernel_mamba, generate_load, generous_response_snapshot
......
......@@ -13,6 +13,7 @@ async def mpt_sharded(mpt_sharded_handle):
return mpt_sharded_handle.client
@pytest.mark.release
@pytest.mark.asyncio
async def test_mpt(mpt_sharded, response_snapshot):
response = await mpt_sharded.generate(
......@@ -29,6 +30,7 @@ async def test_mpt(mpt_sharded, response_snapshot):
assert response == response_snapshot
@pytest.mark.release
@pytest.mark.asyncio
async def test_mpt_load(mpt_sharded, generate_load, response_snapshot):
responses = await generate_load(
......
......@@ -13,6 +13,7 @@ async def mt0_base(mt0_base_handle):
return mt0_base_handle.client
@pytest.mark.release
@pytest.mark.asyncio
async def test_mt0_base(mt0_base, response_snapshot):
response = await mt0_base.generate(
......@@ -27,6 +28,7 @@ async def test_mt0_base(mt0_base, response_snapshot):
assert response == response_snapshot
@pytest.mark.release
@pytest.mark.asyncio
async def test_mt0_base_all_params(mt0_base, response_snapshot):
response = await mt0_base.generate(
......@@ -49,6 +51,7 @@ async def test_mt0_base_all_params(mt0_base, response_snapshot):
assert response == response_snapshot
@pytest.mark.release
@pytest.mark.asyncio
async def test_mt0_base_load(mt0_base, generate_load, response_snapshot):
responses = await generate_load(
......
......@@ -15,6 +15,7 @@ async def neox(neox_handle):
return neox_handle.client
@pytest.mark.release
@pytest.mark.skip
@pytest.mark.asyncio
async def test_neox(neox, response_snapshot):
......@@ -28,6 +29,7 @@ async def test_neox(neox, response_snapshot):
assert response == response_snapshot
@pytest.mark.release
@pytest.mark.skip
@pytest.mark.asyncio
async def test_neox_load(neox, generate_load, response_snapshot):
......
......@@ -15,6 +15,7 @@ async def neox_sharded(neox_sharded_handle):
return neox_sharded_handle.client
@pytest.mark.release
@pytest.mark.skip
@pytest.mark.asyncio
async def test_neox(neox_sharded, response_snapshot):
......@@ -28,6 +29,7 @@ async def test_neox(neox_sharded, response_snapshot):
assert response == response_snapshot
@pytest.mark.release
@pytest.mark.skip
@pytest.mark.asyncio
async def test_neox_load(neox_sharded, generate_load, response_snapshot):
......
......@@ -13,6 +13,7 @@ async def t5_sharded(t5_sharded_handle):
return t5_sharded_handle.client
@pytest.mark.release
@pytest.mark.asyncio
async def test_t5_sharded(t5_sharded, response_snapshot):
response = await t5_sharded.generate(
......@@ -24,6 +25,7 @@ async def test_t5_sharded(t5_sharded, response_snapshot):
assert response == response_snapshot
@pytest.mark.release
@pytest.mark.asyncio
async def test_t5_sharded_load(t5_sharded, generate_load, response_snapshot):
responses = await generate_load(
......
......@@ -14,6 +14,7 @@ nix = { version = "0.28.0", features = ["signal"] }
once_cell = "1.19.0"
serde = { version = "1.0.188", features = ["derive"] }
serde_json = "1.0.107"
thiserror = "1.0.59"
tracing = "0.1.37"
tracing-subscriber = { version = "0.3.17", features = ["json", "env-filter"] }
......
......@@ -16,16 +16,35 @@ use std::thread;
use std::thread::sleep;
use std::time::{Duration, Instant};
use std::{fs, io};
use tracing_subscriber::EnvFilter;
use thiserror::Error;
use tracing_subscriber::{filter::LevelFilter, EnvFilter};
mod env_runtime;
#[derive(Deserialize)]
struct Config {
struct RawConfig {
max_position_embeddings: Option<usize>,
n_positions: Option<usize>,
max_seq_len: Option<usize>,
}
#[derive(Deserialize)]
struct Config {
max_position_embeddings: Option<usize>,
}
impl From<RawConfig> for Config {
fn from(other: RawConfig) -> Self {
let max_position_embeddings = other
.max_position_embeddings
.or(other.max_seq_len)
.or(other.n_positions);
Config {
max_position_embeddings,
}
}
}
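A minimal sketch (not part of this commit, assuming the two structs above are in scope) of why the RawConfig indirection is introduced: a config.json that only defines max_seq_len or n_positions still yields a populated max_position_embeddings after the conversion.

    // Illustrative only: an MPT-style config without `max_position_embeddings`.
    let raw: RawConfig = serde_json::from_str(r#"{"max_seq_len": 2048}"#).unwrap();
    let config: Config = raw.into();
    assert_eq!(config.max_position_embeddings, Some(2048));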
#[derive(Clone, Copy, Debug, ValueEnum)]
enum Quantization {
/// 4 bit quantization. Requires a specific AWQ quantized model:
......@@ -36,11 +55,17 @@ enum Quantization {
/// Should be a drop-in replacement to bitsandbytes with much better performance.
/// Kernels are from <https://github.com/NetEase-FuXi/EETQ.git>
Eetq,
/// Variable bit quantization. Requires a specific EXL2 quantized model:
/// <https://hf.co/models?search=exl2>. Requires exllama2 kernels and does
/// not support tensor parallelism (num_shard > 1).
Exl2,
/// 4 bit quantization. Requires a specific GPTQ quantized model: <https://hf.co/models?search=gptq>.
/// text-generation-inference will use exllama (faster) kernels wherever possible, and use
/// triton kernel (wider support) when it's not.
/// AWQ has faster kernels.
Gptq,
/// 4 bit quantization. Requires a specific Marlin quantized model: <https://hf.co/models?search=marlin>.
Marlin,
/// Bitsandbytes 8bit. Can be applied on any model, will cut the memory requirement in half,
/// but it is known that the model will be much slower to run than the native f16.
#[deprecated(
......@@ -76,9 +101,15 @@ impl std::fmt::Display for Quantization {
Quantization::BitsandbytesFP4 => {
write!(f, "bitsandbytes-fp4")
}
Quantization::Exl2 => {
write!(f, "exl2")
}
Quantization::Gptq => {
write!(f, "gptq")
}
Quantization::Marlin => {
write!(f, "marlin")
}
Quantization::Awq => {
write!(f, "awq")
}
......@@ -210,7 +241,7 @@ struct Args {
max_stop_sequences: usize,
/// This is the maximum allowed value for clients to set `top_n_tokens`.
/// `top_n_tokens is used to return information about the the `n` most likely
/// `top_n_tokens` is used to return information about the `n` most likely
/// tokens at each generation step, instead of just the sampled token. This
/// information can be used for downstream tasks like for classification or
/// ranking.
......@@ -382,6 +413,9 @@ struct Args {
#[clap(long, env)]
otlp_endpoint: Option<String>,
#[clap(default_value = "text-generation-inference.router", long, env)]
otlp_service_name: String,
#[clap(long, env)]
cors_allow_origin: Vec<String>,
#[clap(long, env)]
......@@ -418,6 +452,11 @@ struct Args {
/// Control the maximum number of inputs that a client can send in a single request
#[clap(default_value = "4", long, env)]
max_client_batch_size: usize,
/// Lora Adapters a list of adapter ids i.e. `repo/adapter1,repo/adapter2` to load during
/// startup that will be available to callers via the `adapter_id` field in a request.
#[clap(long, env)]
lora_adapters: Option<String>,
}
#[derive(Debug)]
......@@ -450,7 +489,11 @@ fn shard_manager(
rope_factor: Option<f32>,
max_total_tokens: usize,
max_batch_size: Option<usize>,
max_input_tokens: usize,
lora_adapters: Option<String>,
otlp_endpoint: Option<String>,
otlp_service_name: String,
log_level: LevelFilter,
status_sender: mpsc::Sender<ShardStatus>,
shutdown: Arc<AtomicBool>,
_shutdown_sender: mpsc::Sender<()>,
......@@ -473,7 +516,7 @@ fn shard_manager(
"--uds-path".to_string(),
uds_path,
"--logger-level".to_string(),
"INFO".to_string(),
log_level.to_string().to_uppercase(),
"--json-output".to_string(),
];
......@@ -515,12 +558,20 @@ fn shard_manager(
(None, Some(factor)) => Some((RopeScaling::Linear, factor)),
};
// OpenTelemetry
// OpenTelemetry Endpoint
if let Some(otlp_endpoint) = otlp_endpoint {
shard_args.push("--otlp-endpoint".to_string());
shard_args.push(otlp_endpoint);
}
// OpenTelemetry Service Name
shard_args.push("--otlp-service-name".to_string());
shard_args.push(otlp_service_name);
// In case we use sliding window, we may ignore the sliding in flash for some backends depending on the parameter.
shard_args.push("--max-input-tokens".to_string());
shard_args.push(max_input_tokens.to_string());
// Copy current process env
let mut envs: Vec<(OsString, OsString)> = env::vars_os().collect();
......@@ -555,7 +606,7 @@ fn shard_manager(
// Parse Inference API token
if let Ok(api_token) = env::var("HF_API_TOKEN") {
envs.push(("HUGGING_FACE_HUB_TOKEN".into(), api_token.into()))
envs.push(("HF_TOKEN".into(), api_token.into()))
};
// Detect rope scaling
......@@ -575,6 +626,11 @@ fn shard_manager(
envs.push(("MAX_BATCH_SIZE".into(), max_batch_size.to_string().into()));
}
// Lora Adapters
if let Some(lora_adapters) = lora_adapters {
envs.push(("LORA_ADAPTERS".into(), lora_adapters.into()));
}
// If huggingface_hub_cache is some, pass it to the shard
// Useful when running inside a docker container
if let Some(huggingface_hub_cache) = huggingface_hub_cache {
......@@ -714,7 +770,10 @@ fn shutdown_shards(shutdown: Arc<AtomicBool>, shutdown_receiver: &mpsc::Receiver
fn num_cuda_devices() -> Option<usize> {
let devices = match env::var("CUDA_VISIBLE_DEVICES") {
Ok(devices) => devices,
Err(_) => env::var("NVIDIA_VISIBLE_DEVICES").ok()?,
Err(_) => match env::var("NVIDIA_VISIBLE_DEVICES") {
Ok(devices) => devices,
Err(_) => env::var("ZE_AFFINITY_MASK").ok()?,
},
};
let n_devices = devices.split(',').count();
Some(n_devices)
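Illustrative sketch (not from the commit) of the new fallback order: CUDA_VISIBLE_DEVICES, then NVIDIA_VISIBLE_DEVICES, then ZE_AFFINITY_MASK (the Level Zero affinity mask used for Intel XPUs).

    // Single-threaded sketch: with only ZE_AFFINITY_MASK set, the device count
    // now comes from that mask instead of returning None.
    std::env::remove_var("CUDA_VISIBLE_DEVICES");
    std::env::remove_var("NVIDIA_VISIBLE_DEVICES");
    std::env::set_var("ZE_AFFINITY_MASK", "0,1");
    assert_eq!(num_cuda_devices(), Some(2));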
......@@ -751,13 +810,13 @@ struct PythonLogMessage {
impl PythonLogMessage {
fn trace(&self) {
match self.record.level.name {
PythonLogLevelEnum::Trace => tracing::trace!("{}", self.text),
PythonLogLevelEnum::Debug => tracing::debug!("{}", self.text),
PythonLogLevelEnum::Info => tracing::info!("{}", self.text),
PythonLogLevelEnum::Success => tracing::info!("{}", self.text),
PythonLogLevelEnum::Warning => tracing::warn!("{}", self.text),
PythonLogLevelEnum::Error => tracing::error!("{}", self.text),
PythonLogLevelEnum::Critical => tracing::error!("{}", self.text),
PythonLogLevelEnum::Trace => tracing::trace!("{}", self.text.trim_end()),
PythonLogLevelEnum::Debug => tracing::debug!("{}", self.text.trim_end()),
PythonLogLevelEnum::Info => tracing::info!("{}", self.text.trim_end()),
PythonLogLevelEnum::Success => tracing::info!("{}", self.text.trim_end()),
PythonLogLevelEnum::Warning => tracing::warn!("{}", self.text.trim_end()),
PythonLogLevelEnum::Error => tracing::error!("{}", self.text.trim_end()),
PythonLogLevelEnum::Critical => tracing::error!("{}", self.text.trim_end()),
}
}
}
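The added trim_end() calls deal with trailing newlines on forwarded Python log records; a trivial illustration (not from the commit):

    // Python log lines arrive newline-terminated and `tracing` adds its own
    // line break, so forwarding the raw text produced blank lines.
    let text = "Shard ready\n";
    assert_eq!(text.trim_end(), "Shard ready");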
......@@ -787,9 +846,9 @@ fn find_num_shards(
let num_shard = match (sharded, num_shard) {
(Some(true), None) => {
// try to default to the number of available GPUs
tracing::info!("Parsing num_shard from CUDA_VISIBLE_DEVICES/NVIDIA_VISIBLE_DEVICES");
tracing::info!("Parsing num_shard from CUDA_VISIBLE_DEVICES/NVIDIA_VISIBLE_DEVICES/ZE_AFFINITY_MASK");
let n_devices = num_cuda_devices()
.expect("--num-shard and CUDA_VISIBLE_DEVICES/NVIDIA_VISIBLE_DEVICES are not set");
.expect("--num-shard and CUDA_VISIBLE_DEVICES/NVIDIA_VISIBLE_DEVICES/ZE_AFFINITY_MASK are not set");
if n_devices <= 1 {
return Err(LauncherError::NotEnoughCUDADevices(format!(
"`sharded` is true but only found {n_devices} CUDA devices"
......@@ -819,33 +878,40 @@ fn find_num_shards(
Ok(num_shard)
}
#[derive(Debug)]
#[derive(Debug, Error)]
enum LauncherError {
#[error("Invalid argument: {0}")]
ArgumentValidation(String),
#[error("not enough cuda devices: {0}")]
NotEnoughCUDADevices(String),
#[error("Download error")]
DownloadError,
#[error("Shard cannot start")]
ShardCannotStart,
#[error("Shard disconnected")]
ShardDisconnected,
#[error("Shard failed")]
ShardFailed,
#[error("Webserver failed")]
WebserverFailed,
#[error("Webserver cannot start")]
WebserverCannotStart,
}
impl core::fmt::Display for LauncherError {
fn fmt(&self, f: &mut core::fmt::Formatter) -> core::fmt::Result {
write!(f, "{self:?}")
}
}
impl std::error::Error for LauncherError {}
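With `#[derive(Error)]` from thiserror, the `#[error("...")]` attributes generate the Display and std::error::Error implementations, which is why the hand-written impls above are dropped. Rough illustration (not from the commit):

    // The derived Display uses the attribute string, interpolating `{0}`.
    let err = LauncherError::ArgumentValidation("no max defined".to_string());
    assert_eq!(err.to_string(), "Invalid argument: no max defined");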
fn download_convert_model(args: &Args, running: Arc<AtomicBool>) -> Result<(), LauncherError> {
fn download_convert_model(
model_id: &str,
revision: Option<&str>,
trust_remote_code: bool,
huggingface_hub_cache: Option<&str>,
weights_cache_override: Option<&str>,
running: Arc<AtomicBool>,
) -> Result<(), LauncherError> {
// Enter download tracing span
let _span = tracing::span!(tracing::Level::INFO, "download").entered();
let mut download_args = vec![
"download-weights".to_string(),
args.model_id.to_string(),
model_id.to_string(),
"--extension".to_string(),
".safetensors".to_string(),
"--logger-level".to_string(),
......@@ -854,13 +920,13 @@ fn download_convert_model(args: &Args, running: Arc<AtomicBool>) -> Result<(), L
];
// Model optional revision
if let Some(revision) = &args.revision {
if let Some(revision) = &revision {
download_args.push("--revision".to_string());
download_args.push(revision.to_string())
}
// Trust remote code for automatic peft fusion
if args.trust_remote_code {
if trust_remote_code {
download_args.push("--trust-remote-code".to_string());
}
......@@ -875,7 +941,7 @@ fn download_convert_model(args: &Args, running: Arc<AtomicBool>) -> Result<(), L
// If huggingface_hub_cache is set, pass it to the download process
// Useful when running inside a docker container
if let Some(ref huggingface_hub_cache) = args.huggingface_hub_cache {
if let Some(ref huggingface_hub_cache) = huggingface_hub_cache {
envs.push(("HUGGINGFACE_HUB_CACHE".into(), huggingface_hub_cache.into()));
};
......@@ -888,12 +954,12 @@ fn download_convert_model(args: &Args, running: Arc<AtomicBool>) -> Result<(), L
// Parse Inference API token
if let Ok(api_token) = env::var("HF_API_TOKEN") {
envs.push(("HUGGING_FACE_HUB_TOKEN".into(), api_token.into()))
envs.push(("HF_TOKEN".into(), api_token.into()))
};
// If args.weights_cache_override is some, pass it to the download process
// Useful when running inside a HuggingFace Inference Endpoint
if let Some(weights_cache_override) = &args.weights_cache_override {
if let Some(weights_cache_override) = &weights_cache_override {
envs.push((
"WEIGHTS_CACHE_OVERRIDE".into(),
weights_cache_override.into(),
......@@ -901,7 +967,7 @@ fn download_convert_model(args: &Args, running: Arc<AtomicBool>) -> Result<(), L
};
// Start process
tracing::info!("Starting download process.");
tracing::info!("Starting check and download process for {model_id}");
let mut download_process = match Command::new("text-generation-server")
.args(download_args)
.env_clear()
......@@ -943,7 +1009,7 @@ fn download_convert_model(args: &Args, running: Arc<AtomicBool>) -> Result<(), L
loop {
if let Some(status) = download_process.try_wait().unwrap() {
if status.success() {
tracing::info!("Successfully downloaded weights.");
tracing::info!("Successfully downloaded weights for {model_id}");
break;
}
......@@ -977,6 +1043,8 @@ fn spawn_shards(
args: &Args,
cuda_graphs: Vec<usize>,
max_total_tokens: usize,
max_input_tokens: usize,
max_log_level: LevelFilter,
shutdown: Arc<AtomicBool>,
shutdown_receiver: &mpsc::Receiver<()>,
shutdown_sender: mpsc::Sender<()>,
......@@ -996,6 +1064,7 @@ fn spawn_shards(
let shutdown = shutdown.clone();
let shutdown_sender = shutdown_sender.clone();
let otlp_endpoint = args.otlp_endpoint.clone();
let otlp_service_name = args.otlp_service_name.clone();
let quantize = args.quantize;
let speculate = args.speculate;
let dtype = args.dtype;
......@@ -1009,6 +1078,7 @@ fn spawn_shards(
let rope_scaling = args.rope_scaling;
let rope_factor = args.rope_factor;
let max_batch_size = args.max_batch_size;
let lora_adapters = args.lora_adapters.clone();
thread::spawn(move || {
shard_manager(
model_id,
......@@ -1033,7 +1103,11 @@ fn spawn_shards(
rope_factor,
max_total_tokens,
max_batch_size,
max_input_tokens,
lora_adapters,
otlp_endpoint,
otlp_service_name,
max_log_level,
status_sender,
shutdown,
shutdown_sender,
......@@ -1166,6 +1240,11 @@ fn spawn_webserver(
router_args.push(otlp_endpoint);
}
// OpenTelemetry
let otlp_service_name = args.otlp_service_name;
router_args.push("--otlp-service-name".to_string());
router_args.push(otlp_service_name);
// CORS origins
for origin in args.cors_allow_origin.into_iter() {
router_args.push("--cors-allow-origin".to_string());
......@@ -1186,7 +1265,7 @@ fn spawn_webserver(
// Parse Inference API token
if let Ok(api_token) = env::var("HF_API_TOKEN") {
envs.push(("HUGGING_FACE_HUB_TOKEN".into(), api_token.into()))
envs.push(("HF_TOKEN".into(), api_token.into()))
};
// Parse Compute type
......@@ -1264,8 +1343,22 @@ fn main() -> Result<(), LauncherError> {
let args: Args = Args::parse();
// Filter events with LOG_LEVEL
let env_filter =
EnvFilter::try_from_env("LOG_LEVEL").unwrap_or_else(|_| EnvFilter::new("info"));
let varname = "LOG_LEVEL";
let env_filter = if let Ok(log_level) = std::env::var(varname) {
// Override to avoid simple logs being spammed with tokio-level information
let log_level = match &log_level[..] {
"warn" => "text_generation_launcher=warn,text_generation_router=warn",
"info" => "text_generation_launcher=info,text_generation_router=info",
"debug" => "text_generation_launcher=debug,text_generation_router=debug",
log_level => log_level,
};
EnvFilter::builder()
.with_default_directive(LevelFilter::INFO.into())
.parse_lossy(log_level)
} else {
EnvFilter::new("info")
};
let max_log_level = env_filter.max_level_hint().unwrap_or(LevelFilter::INFO);
if args.json_output {
tracing_subscriber::fmt()
......@@ -1308,33 +1401,30 @@ fn main() -> Result<(), LauncherError> {
};
let content = std::fs::read_to_string(filename)?;
let config: Config = serde_json::from_str(&content)?;
let config: RawConfig = serde_json::from_str(&content)?;
let config: Config = config.into();
// Quantization usually means you're even more RAM constrained.
let max_default = 4096;
let max_position_embeddings = match (config.max_position_embeddings, config.max_seq_len) {
(Some(max_position_embeddings), _) | (None, Some(max_position_embeddings)) => {
if max_position_embeddings > max_default {
let max = max_position_embeddings;
if args.max_input_tokens.is_none()
&& args.max_total_tokens.is_none()
&& args.max_batch_prefill_tokens.is_none()
{
tracing::info!("Model supports up to {max} but tgi will now set its default to {max_default} instead. This is to save VRAM by refusing large prompts in order to allow more users on the same hardware. You can increase that size using `--max-batch-prefill-tokens={} --max-total-tokens={max} --max-input-tokens={}`.", max + 50, max - 1);
}
max_default
} else {
max_position_embeddings
if let Some(max_position_embeddings) = config.max_position_embeddings {
if max_position_embeddings > max_default {
let max = max_position_embeddings;
if args.max_input_tokens.is_none()
&& args.max_total_tokens.is_none()
&& args.max_batch_prefill_tokens.is_none()
{
tracing::info!("Model supports up to {max} but tgi will now set its default to {max_default} instead. This is to save VRAM by refusing large prompts in order to allow more users on the same hardware. You can increase that size using `--max-batch-prefill-tokens={} --max-total-tokens={max} --max-input-tokens={}`.", max + 50, max - 1);
}
Ok(max_default)
} else {
Ok(max_position_embeddings)
}
_ => {
return Err(Box::new(LauncherError::ArgumentValidation(
"no max defined".to_string(),
)));
}
};
Ok(max_position_embeddings)
} else {
Err(Box::new(LauncherError::ArgumentValidation(
"no max defined".to_string(),
)))
}
};
let max_position_embeddings: usize = get_max_position_embeddings().unwrap_or(4096);
......@@ -1430,6 +1520,11 @@ fn main() -> Result<(), LauncherError> {
let num_shard = find_num_shards(args.sharded, args.num_shard)?;
if num_shard > 1 {
if matches!(args.quantize, Some(Quantization::Exl2)) {
return Err(LauncherError::ArgumentValidation(
"Sharding is currently not supported with `exl2` quantization".into(),
));
}
tracing::info!("Sharding model on {num_shard} processes");
}
......@@ -1471,7 +1566,28 @@ fn main() -> Result<(), LauncherError> {
.expect("Error setting Ctrl-C handler");
// Download and convert model weights
download_convert_model(&args, running.clone())?;
download_convert_model(
&args.model_id,
args.revision.as_deref(),
args.trust_remote_code,
args.huggingface_hub_cache.as_deref(),
args.weights_cache_override.as_deref(),
running.clone(),
)?;
// Download and convert lora adapters if any
if let Some(lora_adapters) = &args.lora_adapters {
for adapter in lora_adapters.split(',') {
download_convert_model(
adapter,
None,
args.trust_remote_code,
args.huggingface_hub_cache.as_deref(),
args.weights_cache_override.as_deref(),
running.clone(),
)?;
}
}
if !running.load(Ordering::SeqCst) {
// Launcher was asked to stop
......@@ -1492,6 +1608,8 @@ fn main() -> Result<(), LauncherError> {
&args,
cuda_graphs,
max_total_tokens,
max_input_tokens,
max_log_level,
shutdown.clone(),
&shutdown_receiver,
shutdown_sender,
......
ShareGPT_V3_unfiltered_cleaned_split.json:
wget https://huggingface.co/datasets/anon8231489123/ShareGPT_Vicuna_unfiltered/resolve/main/ShareGPT_V3_unfiltered_cleaned_split.json
prepare_share: ShareGPT_V3_unfiltered_cleaned_split.json
python filter.py
prepare_orca:
python orca.py
import json
def main():
with open("./ShareGPT_V3_unfiltered_cleaned_split.json", "r") as f:
data = json.load(f)
# Select only the first 2k conversations that start with a human.
max = 2000
conversations = []
for conversation in data:
conv = conversation.get("conversations")
if conv and conv[0]["from"] == "human":
# Trim the rest of the output
conversation["conversations"] = conversation["conversations"][:1]
conversations.append(conversation)
if len(conversations) >= max:
break
with open("./small.json", "w") as f:
data = json.dump(conversations, f, indent=4)
if __name__ == "__main__":
main()
import json
import datasets
import tqdm
def main():
dataset = datasets.load_dataset("Open-Orca/OpenOrca", split="train")
# Select only the first 2k conversations that start with a human.
max = min(2000, len(dataset))
conversations = []
for item in tqdm.tqdm(dataset, total=max):
conversation = {
"conversations": [
{"from": "human", "value": item["question"]},
],
"id": item["id"],
}
conversations.append(conversation)
if len(conversations) >= max:
break
with open("./small.json", "w") as f:
data = json.dump(conversations, f, indent=4)
if __name__ == "__main__":
main()
syntax = "proto3";
package generate.v3;
service TextGenerationService {
/// Model Info
rpc Info (InfoRequest) returns (InfoResponse) {}
/// Service discovery
rpc ServiceDiscovery (ServiceDiscoveryRequest) returns (ServiceDiscoveryResponse) {}
/// Empties batch cache
rpc ClearCache (ClearCacheRequest) returns (ClearCacheResponse);
/// Remove requests from a cached batch
rpc FilterBatch (FilterBatchRequest) returns (FilterBatchResponse);
/// Warmup the model and compute max cache size
rpc Warmup (WarmupRequest) returns (WarmupResponse);
/// Prefill batch and decode first token
rpc Prefill (PrefillRequest) returns (PrefillResponse);
/// Decode token for a list of prefilled batches
rpc Decode (DecodeRequest) returns (DecodeResponse);
/// Health check
rpc Health (HealthRequest) returns (HealthResponse);
}
message HealthRequest {}
message HealthResponse {}
/// Empty request
message InfoRequest {}
message InfoResponse {
bool requires_padding = 1;
string dtype = 2;
string device_type = 3;
optional uint32 window_size = 4;
uint32 speculate = 5;
}
/// Empty request
message ServiceDiscoveryRequest {}
message ServiceDiscoveryResponse {
/// Other shards urls
repeated string urls = 1;
}
message ClearCacheRequest {
/// Optional batch id
optional uint64 id = 1;
}
/// Empty response
message ClearCacheResponse {}
message Image {
/// Binary image data.
bytes data = 1;
/// Image MIME type.
string mimetype = 2;
}
message InputChunk {
oneof chunk {
/// Plain text data
string text = 1;
/// Image data
Image image = 2;
}
}
message Input {
repeated InputChunk chunks = 1;
}
enum GrammarType {
GRAMMAR_TYPE_NONE = 0;
GRAMMAR_TYPE_JSON = 1;
GRAMMAR_TYPE_REGEX = 2;
}
message NextTokenChooserParameters {
/// exponential scaling output probability distribution
float temperature = 1;
/// restricting to the k highest probability elements
uint32 top_k = 2;
/// restricting to top tokens summing to prob_cut_off <= prob_cut_off
float top_p = 3;
/// restricting to top tokens summing to prob_cut_off <= prob_cut_off
float typical_p = 4;
/// apply sampling on the logits
bool do_sample = 5;
/// random seed for sampling
uint64 seed = 6;
/// repetition penalty
float repetition_penalty = 7;
/// frequency penalty
float frequency_penalty = 9;
/// token watermarking using "A Watermark for Large Language Models"
bool watermark = 8;
/// grammar (applied if not empty)
string grammar = 10;
/// grammar type
GrammarType grammar_type = 11;
}
message StoppingCriteriaParameters {
/// Maximum number of generated tokens
uint32 max_new_tokens = 1;
/// Optional stopping sequences
repeated string stop_sequences = 2;
/// Ignore end of sequence token
/// used for benchmarking
bool ignore_eos_token = 3;
}
message Request {
/// Request ID
uint64 id = 1;
/// The generation context as chunks
Input input_chunks = 8;
/// The generation context, stringified input_chunks
string inputs = 2;
/// Context truncation
uint32 truncate = 3;
/// Next Token Chooser Parameters
NextTokenChooserParameters parameters = 4;
/// Stopping Criteria Parameters
StoppingCriteriaParameters stopping_parameters = 5;
/// Return prefill logprobs
bool prefill_logprobs = 6;
/// Return most likely n tokens
uint32 top_n_tokens = 7;
/// Paged attention blocks
repeated uint32 blocks = 9;
/// Paged attention slots
repeated uint32 slots = 10;
/// LORA adapter index
optional string adapter_id = 11;
}
message Batch {
/// Batch ID
uint64 id = 1;
/// Individual requests
repeated Request requests = 2;
/// Batch size (==len(requests))
uint32 size = 3;
/// Maximum number of tokens this batch will grow to
uint32 max_tokens = 4;
/// Maximum number of Paged Attention blocks
uint32 max_blocks = 5;
}
message CachedBatch {
/// Batch ID
uint64 id = 1;
/// Individual requests ids
repeated uint64 request_ids = 2;
/// Batch size (==len(requests))
uint32 size = 3;
/// Maximum number of tokens this batch will grow to
uint32 max_tokens = 4;
}
enum FinishReason {
FINISH_REASON_LENGTH = 0;
FINISH_REASON_EOS_TOKEN = 1;
FINISH_REASON_STOP_SEQUENCE = 2;
}
message GeneratedText {
/// Output
string text = 1;
/// Number of generated tokens
uint32 generated_tokens = 2;
/// Finish reason
FinishReason finish_reason = 3;
/// Seed
optional uint64 seed = 4;
}
message Tokens {
/// Token IDs
repeated uint32 ids = 1;
/// Logprobs
repeated float logprobs = 2;
/// tokens
repeated string texts = 3;
/// special
repeated bool is_special = 4;
}
message Generation {
/// Request ID
uint64 request_id = 1;
/// Prefill tokens (optional)
Tokens prefill_tokens = 2;
Tokens tokens = 3;
/// Complete generated text
optional GeneratedText generated_text = 4;
/// Top tokens
repeated Tokens top_tokens = 5;
}
message FilterBatchRequest {
/// Batch ID
uint64 batch_id = 1;
/// Requests to keep
repeated uint64 request_ids = 2;
}
message FilterBatchResponse {
/// Filtered Batch (cached)
CachedBatch batch = 1;
}
message PrefillRequest {
/// Batch
Batch batch = 1;
}
message PrefillResponse {
/// Generation
repeated Generation generations = 1;
/// Next batch (cached)
optional CachedBatch batch = 2;
/// Forward elapsed time in nanoseconds
uint64 forward_ns = 3;
/// Decode elapsed time in nanoseconds
uint64 decode_ns = 4;
/// Total elapsed time in nanoseconds
uint64 total_ns = 5;
}
message DecodeRequest {
/// Cached batches
repeated CachedBatch batches = 1;
}
message DecodeResponse {
/// Decodes
repeated Generation generations = 1;
/// Next batch (cached)
optional CachedBatch batch = 2;
/// Forward elapsed time in nanoseconds
uint64 forward_ns = 3;
/// Decode elapsed time in nanoseconds
uint64 decode_ns = 4;
/// Total elapsed time in nanoseconds
uint64 total_ns = 5;
/// Concatenate elapsed time in nanoseconds
optional uint64 concat_ns = 6;
}
message WarmupRequest {
/// Batch to warmup on
Batch batch = 1;
uint32 max_input_length = 2;
uint32 max_prefill_tokens = 3;
uint32 max_total_tokens = 4;
}
message WarmupResponse {
/// Maximum number of tokens supported by the model
optional uint32 max_supported_total_tokens = 1;
}
......@@ -16,15 +16,16 @@ path = "src/main.rs"
[dependencies]
async-stream = "0.3.5"
axum = { version = "0.6.20", features = ["json"] }
axum-tracing-opentelemetry = "0.14.1"
axum = { version = "0.7", features = ["json"] }
axum-tracing-opentelemetry = "0.16"
text-generation-client = { path = "client" }
clap = { version = "4.4.5", features = ["derive", "env"] }
futures = "0.3.28"
hf-hub = { workspace = true }
itertools = "0.10"
jsonschema = { version = "0.17.1", features = ["draft202012"] }
metrics = "0.21.1"
metrics-exporter-prometheus = { version = "0.12.1", features = [] }
metrics-exporter-prometheus = { version = "0.15.1", features = [] }
nohash-hasher = "0.2.0"
opentelemetry = { version = "0.20.0", features = ["rt-tokio"] }
opentelemetry-otlp = "0.13.0"
......@@ -36,20 +37,21 @@ thiserror = "1.0.48"
tokenizers = { workspace = true}
tokio = { version = "1.32.0", features = ["rt", "rt-multi-thread", "parking_lot", "signal", "sync"] }
tokio-stream = "0.1.14"
tower-http = { version = "0.4.4", features = ["cors"] }
tracing = "0.1.37"
tower-http = { version = "0.5.1", features = ["cors"] }
tracing = "0.1.40"
tracing-opentelemetry = "0.21.0"
tracing-subscriber = { version = "0.3.17", features = ["json", "env-filter"] }
utoipa = { version = "3.5.0", features = ["axum_extras"] }
utoipa-swagger-ui = { version = "3.1.5", features = ["axum"] }
tracing-subscriber = { version = "0.3.18", features = ["json", "env-filter"] }
utoipa = { version = "4.2.0", features = ["axum_extras"] }
utoipa-swagger-ui = { version = "6.0.0", features = ["axum"] }
ngrok = { version = "0.13.1", features = ["axum"], optional = true }
init-tracing-opentelemetry = { version = "0.14.1", features = ["opentelemetry-otlp"] }
minijinja = { git = "https://github.com/mitsuhiko/minijinja.git", rev = "5cd4efb" }
minijinja = { version = "2.0.2" }
minijinja-contrib = { version = "2.0.2", features = ["pycompat"] }
futures-util = "0.3.30"
regex = "1.10.3"
once_cell = "1.19.0"
image = "0.25.1"
base64 = "0.22.0"
base64 = { workspace = true }
[build-dependencies]
vergen = { version = "8.2.5", features = ["build", "git", "gitcl"] }
......@@ -58,3 +60,4 @@ vergen = { version = "8.2.5", features = ["build", "git", "gitcl"] }
default = ["ngrok"]
ngrok = ["dep:ngrok"]
google = []
kserve = []
......@@ -6,6 +6,8 @@ authors.workspace = true
homepage.workspace = true
[dependencies]
async-trait = "^0.1"
base64 = { workspace = true }
futures = "^0.3"
grpc-metadata = { path = "../grpc-metadata" }
prost = "^0.12"
......
use std::fs;
fn main() -> Result<(), Box<dyn std::error::Error>> {
println!("cargo:rerun-if-changed=../../proto/generate.proto");
fs::create_dir("src/pb").unwrap_or(());
println!("cargo:rerun-if-changed=../../proto/");
fs::create_dir_all("src/v2/pb").unwrap_or(());
let mut config = prost_build::Config::new();
config.protoc_arg("--experimental_allow_proto3_optional");
tonic_build::configure()
.build_client(true)
.build_server(false)
.out_dir("src/pb")
.out_dir("src/v2/pb")
.include_file("mod.rs")
.compile_with_config(config, &["../../proto/generate.proto"], &["../../proto"])
.map_err(|e| match e.kind(){
std::io::ErrorKind::NotFound => {panic!("`protoc` not found, install libprotoc")},
std::io::ErrorKind::Other => {panic!("`protoc` version unsupported, upgrade protoc: https://github.com/protocolbuffers/protobuf/releases")},
e => {e}
}).unwrap_or_else(|e| panic!("protobuf compilation failed: {e}"));
fs::create_dir_all("src/v3/pb").unwrap_or(());
let mut config = prost_build::Config::new();
config.protoc_arg("--experimental_allow_proto3_optional");
tonic_build::configure()
.build_client(true)
.build_server(false)
.out_dir("src/v3/pb")
.include_file("mod.rs")
.compile_with_config(config, &["../../proto/v3/generate.proto"], &["../../proto"])
.unwrap_or_else(|e| panic!("protobuf compilation failed: {e}"));
Ok(())
......
//! Text Generation gRPC client library
mod client;
#[allow(clippy::derive_partial_eq_without_eq)]
mod pb;
mod sharded_client;
pub use client::Client;
pub use pb::generate::v2::HealthResponse;
pub use pb::generate::v2::InfoResponse as ShardInfo;
pub use pb::generate::v2::{
Batch, CachedBatch, FinishReason, GeneratedText, Generation, GrammarType,
NextTokenChooserParameters, Request, StoppingCriteriaParameters, Tokens,
};
pub use sharded_client::ShardedClient;
pub mod v2;
pub mod v3;
use async_trait::async_trait;
use base64::{engine::general_purpose::STANDARD, Engine};
use thiserror::Error;
use tonic::transport;
use tonic::Status;
pub use v3::{Chunk, Image, Input, InputChunk};
#[async_trait]
pub trait Health {
/// Check if a generate server is healthy by asking it to allocate a tensor on device
async fn device_health(&self) -> Result<()>;
/// Check if a generate server is healthy by doing a forward pass.
/// EXPENSIVE
async fn model_health(&self) -> Result<()>;
}
#[derive(Debug)]
pub struct ShardInfo {
pub requires_padding: bool,
pub dtype: String,
pub device_type: String,
pub window_size: Option<u32>,
pub speculate: u32,
}
#[derive(Error, Debug, Clone)]
pub enum ClientError {
#[error("Could not connect to Text Generation server: {0}")]
......@@ -43,4 +56,36 @@ impl From<transport::Error> for ClientError {
}
}
// Small convenience re-wrapping of `Chunk`.
impl From<Chunk> for InputChunk {
fn from(chunk: Chunk) -> Self {
InputChunk { chunk: Some(chunk) }
}
}
/// Convert input chunks to a stringly-typed input for backwards
/// compat for backends that haven't implemented chunked inputs.
pub trait ChunksToString {
/// Convert chunks to string.
fn chunks_to_string(&self) -> String;
}
impl ChunksToString for Vec<InputChunk> {
fn chunks_to_string(&self) -> String {
let mut output = String::new();
self.iter().for_each(|c| match &c.chunk {
Some(Chunk::Text(text)) => output.push_str(text),
Some(Chunk::Image(Image { data, mimetype })) => {
let encoded = STANDARD.encode(data);
output.push_str(&format!("![](data:{};base64,{})", mimetype, encoded))
}
// We don't create empty chunks, so this should be unreachable.
None => unreachable!("Chunks should never be empty"),
});
output
}
}
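A hedged usage sketch (not part of the commit) of ChunksToString together with the From<Chunk> conversion above, turning a mixed text-and-image input back into the legacy markdown-image string form:

    // Placeholder bytes stand in for real image data in this sketch.
    let chunks: Vec<InputChunk> = vec![
        Chunk::Text("What is in this picture? ".to_string()).into(),
        Chunk::Image(Image {
            data: vec![0u8; 4],
            mimetype: "image/png".to_string(),
        })
        .into(),
    ];
    let stringified = chunks.chunks_to_string();
    assert!(stringified.starts_with("What is in this picture? ![](data:image/png;base64,"));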
static WARMUP_IMAGE_BASE64 :&str = "iVBORw0KGgoAAAANSUhEUgAAABQAAAAUCAIAAAAC64paAAABg2lDQ1BJQ0MgcHJvZmlsZQAAKJF9kT1Iw0AcxV/TSotUROxQxCFDdbKLijjWKhShQqgVWnUwufQLmrQkKS6OgmvBwY/FqoOLs64OroIg+AHi7OCk6CIl/i8ptIjx4Lgf7+497t4BQqvKNDOQADTdMjKppJjLr4rBVwQQwhAERGVm1uckKQ3P8XUPH1/v4jzL+9yfY0AtmAzwicQJVjcs4g3imU2rznmfOMLKskp8Tjxh0AWJH7muuPzGueSwwDMjRjYzTxwhFks9rPQwKxsa8TRxTNV0yhdyLquctzhr1Qbr3JO/MFzQV5a5TnMUKSxiCRJEKGiggiosxGnVSTGRof2kh3/E8UvkUshVASPHAmrQIDt+8D/43a1ZnJp0k8JJoO/Ftj/GgOAu0G7a9vexbbdPAP8zcKV3/bUWMPtJerOrxY6AwW3g4rqrKXvA5Q4QfarLhuxIfppCsQi8n9E35YHhW6B/ze2ts4/TByBLXaVvgINDYLxE2ese7w719vbvmU5/PycecohsjayNAAAACXBIWXMAAC4jAAAuIwF4pT92AAAAB3RJTUUH6AQIEQMnlTSSjwAAABl0RVh0Q29tbWVudABDcmVhdGVkIHdpdGggR0lNUFeBDhcAAAASSURBVDjLY2AYBaNgFIyCoQsABMQAAeRw1DoAAAAASUVORK5CYII=";
pub type Result<T> = std::result::Result<T, ClientError>;
/// Single shard Client
use crate::v2::pb;
use crate::{ClientError, Result};
use crate::WARMUP_IMAGE_BASE64;
use grpc_metadata::InjectTelemetryContext;
use pb::generate::v2::text_generation_service_client::TextGenerationServiceClient;
use pb::generate::v2::*;
use std::cmp::min;
use std::time::Duration;
use tonic::transport::{Channel, Uri};
use tracing::instrument;
/// Text Generation Inference gRPC client
#[derive(Debug, Clone)]
pub struct Client {
stub: TextGenerationServiceClient<Channel>,
}
impl Client {
/// Returns a client connected to the given url
pub async fn connect(uri: Uri) -> Result<Self> {
let channel = Channel::builder(uri).connect().await?;
Ok(Self {
stub: TextGenerationServiceClient::new(channel),
})
}
/// Returns a client connected to the given unix socket
pub async fn connect_uds(path: String) -> Result<Self> {
let channel = Channel::from_shared("http://[::]:50051".to_string())
.unwrap()
.connect_with_connector(tower::service_fn(move |_: Uri| {
tokio::net::UnixStream::connect(path.clone())
}))
.await?;
Ok(Self {
stub: TextGenerationServiceClient::new(channel),
})
}
/// Returns a list of uris or unix sockets of all shards
#[instrument(skip(self))]
pub async fn service_discovery(&mut self) -> Result<Vec<String>> {
let request = tonic::Request::new(ServiceDiscoveryRequest {}).inject_context();
let response = self.stub.service_discovery(request).await.map_err(|_| {
ClientError::Connection("Server does not support v2 interface".to_string())
})?;
let urls = response
.into_inner()
.urls
.into_iter()
// Remove unix socket prefix
.map(|url| match url.strip_prefix("unix://") {
None => url,
Some(stripped_url) => stripped_url.to_string(),
})
.collect();
Ok(urls)
}
/// Get model info
#[instrument(skip(self))]
pub async fn info(&mut self) -> Result<InfoResponse> {
let request = tonic::Request::new(InfoRequest {}).inject_context();
let response = self.stub.info(request).await?.into_inner();
Ok(response)
}
/// Get model health
#[instrument(skip(self))]
pub async fn health(&mut self) -> Result<HealthResponse> {
let request = tonic::Request::new(HealthRequest {}).inject_context();
let response = self.stub.health(request).await?.into_inner();
Ok(response)
}
/// Clear the past generations cache
#[instrument(skip(self))]
pub async fn clear_cache(&mut self, batch_id: Option<u64>) -> Result<()> {
let request = tonic::Request::new(ClearCacheRequest { id: batch_id }).inject_context();
self.stub.clear_cache(request).await?;
Ok(())
}
/// Filter a cached batch
#[instrument(skip(self))]
pub async fn filter_batch(
&mut self,
batch_id: u64,
request_ids: Vec<u64>,
) -> Result<Option<CachedBatch>> {
let request = tonic::Request::new(FilterBatchRequest {
batch_id,
request_ids,
})
.inject_context();
let filtered_batch = self.stub.filter_batch(request).await?.into_inner();
Ok(filtered_batch.batch)
}
/// Warmup on a max size batch
///
/// Returns the maximum amount of tokens supported by the hardware
#[instrument(skip_all)]
pub async fn warmup(
&mut self,
max_input_length: u32,
max_prefill_tokens: u32,
max_total_tokens: u32,
max_batch_size: Option<usize>,
) -> Result<Option<u32>> {
let mut n_tokens = 0;
let mut requests = Vec::new();
// Create requests
while n_tokens < max_prefill_tokens {
let truncate = min(max_input_length, max_prefill_tokens - n_tokens);
let mut inputs = String::new();
inputs.push_str(&"_test ".to_string().repeat(max_input_length as usize));
if n_tokens == 0 {
// 1 request is enough to test vision heads.
// Sending images on other queries messes up easily with truncation.
inputs.push_str(&format!(
"![](data:image/jpeg;base64,{WARMUP_IMAGE_BASE64})",
));
}
requests.push(Request {
id: 0,
inputs,
// We truncate the input on the server side to be sure that it has the correct size
truncate,
// Set sampling parameters to also take these ops into account in the max memory
parameters: Some(NextTokenChooserParameters {
temperature: 0.9,
top_k: 10,
top_p: 0.9,
typical_p: 0.9,
do_sample: false,
seed: 0,
repetition_penalty: 1.2,
frequency_penalty: 0.1,
watermark: true,
grammar: String::new(),
grammar_type: GrammarType::None as i32,
}),
stopping_parameters: Some(StoppingCriteriaParameters {
max_new_tokens: max_total_tokens - truncate,
stop_sequences: vec![],
ignore_eos_token: true,
}),
prefill_logprobs: true,
top_n_tokens: 20,
});
n_tokens += max_input_length;
// Check max_batch_size
if Some(requests.len()) == max_batch_size {
break;
}
}
let batch = Batch {
id: 0,
size: requests.len() as u32,
requests,
max_tokens: 0,
};
let request = tonic::Request::new(WarmupRequest {
batch: Some(batch),
max_input_length,
max_prefill_tokens,
max_total_tokens,
})
.inject_context();
let response = self.stub.warmup(request).await?.into_inner();
Ok(response.max_supported_total_tokens)
}
/// Generate one token for each request in the given batch
///
/// Returns Generation for each request in batch
/// and the next cached batch
#[instrument(skip_all, fields(id = &batch.id, size = &batch.size))]
pub async fn prefill(
&mut self,
batch: Batch,
) -> Result<(Vec<Generation>, Option<CachedBatch>, PrefillTimings)> {
let request = tonic::Request::new(PrefillRequest { batch: Some(batch) }).inject_context();
let response = self.stub.prefill(request).await?.into_inner();
Ok((
response.generations,
response.batch,
PrefillTimings::new(response.forward_ns, response.decode_ns, response.total_ns),
))
}
/// Generate one token for each request in the given cached batches
///
/// Returns Generation for each request in batches
/// and the next cached batch
#[instrument(skip_all, fields(size = batches.iter().map(|batch|{batch.size}).sum::<u32>()))]
pub async fn decode(
&mut self,
batches: Vec<CachedBatch>,
) -> Result<(Vec<Generation>, Option<CachedBatch>, DecodeTimings)> {
let request = tonic::Request::new(DecodeRequest { batches }).inject_context();
let response = self.stub.decode(request).await?.into_inner();
Ok((
response.generations,
response.batch,
DecodeTimings::new(
response.concat_ns,
response.forward_ns,
response.decode_ns,
response.total_ns,
),
))
}
}
pub struct PrefillTimings {
pub forward: Duration,
pub decode: Duration,
pub total: Duration,
}
impl PrefillTimings {
fn new(forward_ns: u64, decode_ns: u64, total_ns: u64) -> Self {
Self {
forward: Duration::from_nanos(forward_ns),
decode: Duration::from_nanos(decode_ns),
total: Duration::from_nanos(total_ns),
}
}
}
pub struct DecodeTimings {
pub concat: Option<Duration>,
pub forward: Duration,
pub decode: Duration,
pub total: Duration,
}
impl DecodeTimings {
fn new(concat_ns: Option<u64>, forward_ns: u64, decode_ns: u64, total_ns: u64) -> Self {
Self {
concat: concat_ns.map(Duration::from_nanos),
forward: Duration::from_nanos(forward_ns),
decode: Duration::from_nanos(decode_ns),
total: Duration::from_nanos(total_ns),
}
}
}
#[allow(clippy::derive_partial_eq_without_eq)]
mod pb;
mod client;
mod sharded_client;
pub use client::Client;
pub use pb::generate::v2::HealthResponse;
pub use pb::generate::v2::{
Batch, CachedBatch, FinishReason, GeneratedText, Generation, GrammarType, InfoResponse,
NextTokenChooserParameters, Request, StoppingCriteriaParameters, Tokens,
};
pub use sharded_client::ShardedClient;
/// Multi shard Client
use crate::{v2, Health, ShardInfo};
use crate::{ClientError, Result};
use crate::v2::InfoResponse;
use async_trait::async_trait;
use futures::future::join_all;
use tonic::transport::Uri;
use tracing::instrument;
use v2::client::{DecodeTimings, PrefillTimings};
use v2::{
Batch, CachedBatch, Client, Generation, GrammarType, HealthResponse,
NextTokenChooserParameters, Request, StoppingCriteriaParameters,
};
#[derive(Debug, Clone)]
/// Text Generation Inference gRPC multi client
pub struct ShardedClient {
clients: Vec<Client>,
}
impl ShardedClient {
fn new(clients: Vec<Client>) -> Self {
Self { clients }
}
/// Create a new ShardedClient from a master client. The master client will communicate with
/// the other shards and returns all uris/unix sockets with the `service_discovery` gRPC method.
async fn from_master_client(mut master_client: Client) -> Result<Self> {
// Get all uris/unix sockets from the master client
let uris = master_client.service_discovery().await?;
let futures = uris.into_iter().map(Client::connect_uds);
let clients: Result<Vec<Client>> = join_all(futures).await.into_iter().collect();
Ok(Self::new(clients?))
}
/// Returns a client connected to the given uri
pub async fn connect(uri: Uri) -> Result<Self> {
let master_client = Client::connect(uri).await?;
Self::from_master_client(master_client).await
}
/// Returns a client connected to the given unix socket
pub async fn connect_uds(path: String) -> Result<Self> {
let master_client = Client::connect_uds(path).await?;
Self::from_master_client(master_client).await
}
/// Get the model info
#[instrument(skip(self))]
pub async fn info(&mut self) -> Result<ShardInfo> {
let futures: Vec<_> = self
.clients
.iter_mut()
.map(|client| client.info())
.collect();
join_all(futures).await.pop().unwrap().map(ShardInfo::from)
}
/// GRPC health check
#[instrument(skip(self))]
pub async fn health(&mut self) -> Result<HealthResponse> {
let futures: Vec<_> = self
.clients
.iter_mut()
.map(|client| client.health())
.collect();
join_all(futures).await.pop().unwrap()
}
/// Clear the past generations cache
#[instrument(skip(self))]
pub async fn clear_cache(&mut self, batch_id: Option<u64>) -> Result<()> {
let futures: Vec<_> = self
.clients
.iter_mut()
.map(|client| client.clear_cache(batch_id))
.collect();
join_all(futures).await.into_iter().collect()
}
/// Filter a cached batch
#[instrument(skip(self))]
pub async fn filter_batch(
&mut self,
batch_id: u64,
request_ids: Vec<u64>,
) -> Result<Option<CachedBatch>> {
let futures: Vec<_> = self
.clients
.iter_mut()
.map(|client| Box::pin(client.filter_batch(batch_id, request_ids.clone())))
.collect();
// all shards return the same message
join_all(futures).await.pop().unwrap()
}
/// Warmup on a max size batch
///
/// Returns the maximum amount of tokens supported by the hardware
#[instrument(skip(self))]
pub async fn warmup(
&mut self,
max_input_length: u32,
max_prefill_tokens: u32,
max_total_tokens: u32,
max_batch_size: Option<usize>,
) -> Result<Option<u32>> {
let futures: Vec<_> = self
.clients
.iter_mut()
.map(|client| {
Box::pin(client.warmup(
max_input_length,
max_prefill_tokens,
max_total_tokens,
max_batch_size,
))
})
.collect();
// Take the minimum value
let results = join_all(futures)
.await
.into_iter()
.collect::<Result<Vec<Option<u32>>>>()?;
Ok(results.into_iter().flatten().min())
}
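The final reduction keeps the most conservative limit across shards, skipping shards that report no limit. Illustration (not from the commit):

    // flatten() drops the Nones, min() keeps the smallest reported limit.
    let per_shard: Vec<Option<u32>> = vec![Some(4096), None, Some(3072)];
    assert_eq!(per_shard.into_iter().flatten().min(), Some(3072));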
/// Generate one token for each request in the given batch
///
/// Returns Generation for each request in batch
/// and the next cached batch
#[instrument(skip_all, fields(id = & batch.id, size = & batch.size))]
pub async fn prefill(
&mut self,
batch: Batch,
) -> Result<(Vec<Generation>, Option<CachedBatch>, PrefillTimings)> {
let futures: Vec<_> = self
.clients
.iter_mut()
.map(|client| Box::pin(client.prefill(batch.clone())))
.collect();
#[allow(clippy::type_complexity)]
let results: Result<Vec<(Vec<Generation>, Option<CachedBatch>, PrefillTimings)>> =
join_all(futures).await.into_iter().collect();
let mut results = results?;
let (mut generations, next_batch, mut timings) =
results.pop().ok_or(ClientError::EmptyResults)?;
// Merge generations from different model shards
for (mut shard_generations, _, shard_timings) in results.into_iter() {
generations.append(&mut shard_generations);
// Return the timings of the slowest shard
if shard_timings.total > timings.total {
timings = shard_timings;
}
}
Ok((generations, next_batch, timings))
}
/// Generate one token for each request in the given cached batches
///
/// Returns Generation for each request in batches
/// and the next cached batch
#[instrument(skip_all, fields(size = batches.iter().map(| batch | {batch.size}).sum::< u32 > ()))]
pub async fn decode(
&mut self,
batches: Vec<CachedBatch>,
) -> Result<(Vec<Generation>, Option<CachedBatch>, DecodeTimings)> {
let futures: Vec<_> = self
.clients
.iter_mut()
.map(|client| Box::pin(client.decode(batches.clone())))
.collect();
#[allow(clippy::type_complexity)]
let results: Result<Vec<(Vec<Generation>, Option<CachedBatch>, DecodeTimings)>> =
join_all(futures).await.into_iter().collect();
let mut results = results?;
let (mut generations, next_batch, mut timings) =
results.pop().ok_or(ClientError::EmptyResults)?;
// Merge generations from different model shards
for (mut shard_generations, _, shard_timings) in results.into_iter() {
generations.append(&mut shard_generations);
// Return the timings of the slowest shard
if shard_timings.total > timings.total {
timings = shard_timings;
}
}
Ok((generations, next_batch, timings))
}
}
impl From<InfoResponse> for ShardInfo {
fn from(value: InfoResponse) -> Self {
Self {
requires_padding: value.requires_padding,
dtype: value.dtype,
device_type: value.device_type,
window_size: value.window_size,
speculate: value.speculate,
}
}
}
#[async_trait]
impl Health for ShardedClient {
async fn device_health(&self) -> Result<()> {
self.clone().health().await?;
Ok(())
}
async fn model_health(&self) -> Result<()> {
// Dummy batch of 1 token and 1 generated token
let liveness_request = Request {
id: u64::MAX,
inputs: "liveness".to_string(),
truncate: 10,
prefill_logprobs: false,
parameters: Some(NextTokenChooserParameters {
temperature: 1.0,
top_k: 0,
top_p: 1.0,
typical_p: 1.0,
do_sample: false,
seed: 0,
repetition_penalty: 1.0,
frequency_penalty: 0.0,
watermark: false,
grammar: String::new(),
grammar_type: GrammarType::None as i32,
}),
stopping_parameters: Some(StoppingCriteriaParameters {
max_new_tokens: 1,
stop_sequences: vec![],
ignore_eos_token: false,
}),
top_n_tokens: 0,
};
let batch = Batch {
id: u64::MAX,
requests: vec![liveness_request],
size: 1,
max_tokens: 2,
};
self.clone().prefill(batch).await?;
Ok(())
}
}