Commit 81a882ad authored by jixx

add tgi2.4.0

parent 9822d7f6
......@@ -62,6 +62,7 @@ async def test_mamba_load(
)
assert len(responses) == 4
assert responses[0].generated_text == "\n\nDeep learning is a new type of machine"
assert all([r.generated_text == responses[0].generated_text for r in responses])
......
import pytest
import asyncio
@pytest.fixture(scope="module")
def mllama_handle(launcher):
with launcher("meta-llama/Llama-3.2-11B-Vision-Instruct", num_shard=2) as handle:
yield handle
@pytest.fixture(scope="module")
async def mllama(mllama_handle):
await mllama_handle.health(300)
return mllama_handle.client
@pytest.mark.asyncio
async def test_mllama_simpl(mllama, response_snapshot):
response = await mllama.chat(
max_tokens=10,
temperature=0.0,
messages=[
{
"role": "user",
"content": [
{
"type": "text",
"text": "Can you tell me a very short story based on the image?",
},
{
"type": "image_url",
"image_url": {
"url": "https://raw.githubusercontent.com/huggingface/text-generation-inference/main/integration-tests/images/chicken_on_money.png"
},
},
],
},
],
)
assert response.usage == {
"completion_tokens": 10,
"prompt_tokens": 50,
"total_tokens": 60,
}
assert (
response.choices[0].message.content
== "In a bustling city, a chicken named Cluck"
)
assert response == response_snapshot
@pytest.mark.release
@pytest.mark.asyncio
async def test_mllama_load(mllama, generate_load, response_snapshot):
futures = [
mllama.chat(
max_tokens=10,
temperature=0.0,
messages=[
{
"role": "user",
"content": [
{
"type": "text",
"text": "Can you tell me a very short story based on the image?",
},
{
"type": "image_url",
"image_url": {
"url": "https://raw.githubusercontent.com/huggingface/text-generation-inference/main/integration-tests/images/chicken_on_money.png"
},
},
],
},
],
)
for i in range(4)
]
responses = await asyncio.gather(*futures)
generated_texts = [response.choices[0].message.content for response in responses]
assert generated_texts[0] == "In a bustling city, a chicken named Cluck"
assert len(generated_texts) == 4
assert all([text == generated_texts[0] for text in generated_texts])
assert responses == response_snapshot
import pytest
@pytest.fixture(scope="module")
def opt_sharded_handle(launcher):
with launcher("facebook/opt-6.7b", num_shard=2) as handle:
yield handle
@pytest.fixture(scope="module")
async def opt_sharded(opt_sharded_handle):
await opt_sharded_handle.health(300)
return opt_sharded_handle.client
@pytest.mark.release
@pytest.mark.asyncio
async def test_opt(opt_sharded):
pass
import pytest
import json
from text_generation.types import GrammarType
@pytest.fixture(scope="module")
def flash_llama_grammar_tools_handle(launcher):
with launcher(
"TinyLlama/TinyLlama-1.1B-Chat-v1.0", num_shard=2, disable_grammar_support=False
"meta-llama/Meta-Llama-3.1-8B-Instruct",
num_shard=2,
disable_grammar_support=False,
) as handle:
yield handle
......@@ -39,6 +38,7 @@ tools = [
},
},
"required": ["location", "format"],
"additionalProperties": False,
},
},
},
......@@ -65,13 +65,13 @@ tools = [
},
},
"required": ["location", "format", "num_days"],
"additionalProperties": False,
},
},
},
]
@pytest.mark.asyncio
@pytest.mark.private
async def test_flash_llama_grammar_tools(flash_llama_grammar_tools, response_snapshot):
......@@ -79,7 +79,7 @@ async def test_flash_llama_grammar_tools(flash_llama_grammar_tools, response_sna
max_tokens=100,
seed=1,
tools=tools,
temperature=0.0,
messages=[
{
"role": "system",
......@@ -91,22 +91,21 @@ async def test_flash_llama_grammar_tools(flash_llama_grammar_tools, response_sna
},
],
)
assert response.choices[0].message.content is None
assert response.choices[0].message.tool_calls == [
{
"id": 0,
"id": "0",
"type": "function",
"function": {
"description": None,
"name": "get_current_weather",
"arguments": {"format": "celsius", "location": "New York, NY"},
"arguments": {"format": "celsius", "location": "Brooklyn, NY"},
},
}
]
assert response == response_snapshot
@pytest.mark.asyncio
@pytest.mark.private
async def test_flash_llama_grammar_tools_auto(
......@@ -116,8 +115,8 @@ async def test_flash_llama_grammar_tools_auto(
max_tokens=100,
seed=1,
tools=tools,
temperature=0.0,
tool_choice="auto",
messages=[
{
"role": "system",
......@@ -129,15 +128,15 @@ async def test_flash_llama_grammar_tools_auto(
},
],
)
assert response.choices[0].message.content is None
assert response.choices[0].message.tool_calls == [
{
"id": 0,
"id": "0",
"type": "function",
"function": {
"description": None,
"name": "get_current_weather",
"arguments": {"format": "celsius", "location": "New York, NY"},
"arguments": {"format": "celsius", "location": "Brooklyn, NY"},
},
}
]
......@@ -145,7 +144,6 @@ async def test_flash_llama_grammar_tools_auto(
assert response == response_snapshot
@pytest.mark.asyncio
@pytest.mark.private
async def test_flash_llama_grammar_tools_choice(
......@@ -155,8 +153,8 @@ async def test_flash_llama_grammar_tools_choice(
max_tokens=100,
seed=1,
tools=tools,
temperature=0.0,
tool_choice="get_current_weather",
messages=[
{
"role": "system",
......@@ -168,15 +166,15 @@ async def test_flash_llama_grammar_tools_choice(
},
],
)
assert response.choices[0].message.content is None
assert response.choices[0].message.tool_calls == [
{
"id": 0,
"id": "0",
"type": "function",
"function": {
"description": None,
"name": "get_current_weather",
"arguments": {"format": "celsius", "location": "New York, NY"},
"arguments": {"format": "celsius", "location": "Brooklyn, NY"},
},
}
]
......@@ -184,7 +182,6 @@ async def test_flash_llama_grammar_tools_choice(
assert response == response_snapshot
@pytest.mark.asyncio
@pytest.mark.private
async def test_flash_llama_grammar_tools_stream(
......@@ -194,8 +191,8 @@ async def test_flash_llama_grammar_tools_stream(
max_tokens=100,
seed=1,
tools=tools,
temperature=0.0,
tool_choice="get_current_weather",
messages=[
{
"role": "system",
......@@ -210,14 +207,22 @@ async def test_flash_llama_grammar_tools_stream(
)
count = 0
tool_calls_generated = ""
last_response = None
async for response in responses:
count += 1
tool_calls_generated += response.choices[0].delta.tool_calls.function.arguments
last_response = response
assert response.choices[0].delta.content is None
assert (
tool_calls_generated
== '{"function": {"_name": "get_current_weather", "format": "celsius", "location": "Paris, France"}}<|eot_id|>'
)
assert count == 28
assert last_response == response_snapshot
@pytest.mark.asyncio
@pytest.mark.private
async def test_flash_llama_grammar_tools_insufficient_information(
......@@ -225,35 +230,100 @@ async def test_flash_llama_grammar_tools_insufficient_information(
):
responses = await flash_llama_grammar_tools.chat(
max_tokens=100,
seed=24,
tools=tools,
tool_choice="auto",
messages=[
{
"role": "system",
"content": "ONLY RESPOND IF THE USER ASKS A WEATHER RELATED QUESTION",
"content": "You're a helpful assistant! Answer the users question best you can.",
},
{
"role": "user",
"content": "Tell me a story about 3 sea creatures",
"content": "Who are you?",
},
],
stream=False,
)
assert responses.choices[0].message.tool_calls is None
assert responses.choices[0].message.content == "I am an AI assistant"
assert responses == response_snapshot
@pytest.mark.asyncio
@pytest.mark.private
async def test_flash_llama_grammar_tools_insufficient_information_stream(
flash_llama_grammar_tools, response_snapshot
):
responses = await flash_llama_grammar_tools.chat(
max_tokens=100,
seed=24,
tools=tools,
tool_choice="auto",
messages=[
{
"role": "system",
"content": "You're a helpful assistant! Answer the users question best you can.",
},
{
"role": "user",
"content": "Who are you?",
},
],
stream=True,
)
count = 0
content_generated = ""
last_response = None
async for response in responses:
count += 1
content_generated += response.choices[0].delta.content
last_response = response
assert response.choices[0].delta.tool_calls is None
assert count == 5
assert content_generated == "I am an AI assistant"
assert last_response == response_snapshot
@pytest.mark.asyncio
@pytest.mark.private
async def test_flash_llama_grammar_tools_sea_creatures_stream(
flash_llama_grammar_tools, response_snapshot
):
responses = await flash_llama_grammar_tools.chat(
max_tokens=100,
seed=24,
tools=tools,
tool_choice="auto",
messages=[
{
"role": "system",
"content": "You're a helpful assistant! Answer the users question best you can. If the question is not answerable by the tools, just generate a response.",
},
{
"role": "user",
"content": "Tell me a story about 3 sea creatures",
},
],
stream=True,
)
count = 0
content_generated = ""
last_response = None
async for response in responses:
count += 1
content_generated += response.choices[0].delta.content
last_response = response
assert response.choices[0].delta.tool_calls is None
assert count == 62
assert (
content_generated
== "Once upon a time, in the ocean, there lived three sea creatures. There was a wise old octopus named Bob, a mischievous seagull named Sam, and a gentle sea turtle named Luna. They all lived together in a beautiful coral reef, surrounded by colorful fish and swaying sea fans"
)
assert last_response == response_snapshot
This source diff could not be displayed because it is too large.
[tool.poetry]
name = "text-generation-integration-tests"
version = "2.0.1"
version = "2.4.0"
description = "Text Generation Inference integration tests"
authors = ["Nicolas Patry <nicolas@huggingface.co>"]
[tool.poetry.dependencies]
pydantic = "> 2, < 3"
python = ">=3.9,<3.13"
syrupy = "4.0.1"
python = ">=3.10,<3.13"
syrupy = "^4.7.1"
text-generation = "^0.6.0"
pytest = "^7.4.0"
pytest-asyncio = "^0.21.1"
docker = "^6.1.3"
docker = "^7"
numpy = "^1.20"
[tool.isort]
profile = "black"
aiohappyeyeballs==2.4.0 ; python_version >= "3.10" and python_version < "3.13"
aiohttp==3.10.5 ; python_version >= "3.10" and python_version < "3.13"
aiosignal==1.3.1 ; python_version >= "3.10" and python_version < "3.13"
annotated-types==0.7.0 ; python_version >= "3.10" and python_version < "3.13"
async-timeout==4.0.3 ; python_version >= "3.10" and python_version < "3.11"
attrs==24.2.0 ; python_version >= "3.10" and python_version < "3.13"
certifi==2024.8.30 ; python_version >= "3.10" and python_version < "3.13"
charset-normalizer==3.3.2 ; python_version >= "3.10" and python_version < "3.13"
colorama==0.4.6 ; python_version >= "3.10" and python_version < "3.13" and (sys_platform == "win32" or platform_system == "Windows")
docker==7.1.0 ; python_version >= "3.10" and python_version < "3.13"
exceptiongroup==1.2.2 ; python_version >= "3.10" and python_version < "3.11"
filelock==3.16.0 ; python_version >= "3.10" and python_version < "3.13"
frozenlist==1.4.1 ; python_version >= "3.10" and python_version < "3.13"
fsspec==2024.9.0 ; python_version >= "3.10" and python_version < "3.13"
huggingface-hub==0.24.6 ; python_version >= "3.10" and python_version < "3.13"
idna==3.8 ; python_version >= "3.10" and python_version < "3.13"
iniconfig==2.0.0 ; python_version >= "3.10" and python_version < "3.13"
multidict==6.1.0 ; python_version >= "3.10" and python_version < "3.13"
numpy==1.26.4 ; python_version >= "3.10" and python_version < "3.13"
packaging==24.1 ; python_version >= "3.10" and python_version < "3.13"
pluggy==1.5.0 ; python_version >= "3.10" and python_version < "3.13"
pydantic-core==2.23.3 ; python_version >= "3.10" and python_version < "3.13"
pydantic==2.9.1 ; python_version >= "3.10" and python_version < "3.13"
pytest-asyncio==0.21.2 ; python_version >= "3.10" and python_version < "3.13"
pytest==7.4.4 ; python_version >= "3.10" and python_version < "3.13"
pywin32==306 ; python_version >= "3.10" and python_version < "3.13" and sys_platform == "win32"
pyyaml==6.0.2 ; python_version >= "3.10" and python_version < "3.13"
requests==2.32.3 ; python_version >= "3.10" and python_version < "3.13"
syrupy==4.7.1 ; python_version >= "3.10" and python_version < "3.13"
text-generation==0.6.1 ; python_version >= "3.10" and python_version < "3.13"
tomli==2.0.1 ; python_version >= "3.10" and python_version < "3.11"
tqdm==4.66.5 ; python_version >= "3.10" and python_version < "3.13"
typing-extensions==4.12.2 ; python_version >= "3.10" and python_version < "3.13"
urllib3==2.2.2 ; python_version >= "3.10" and python_version < "3.13"
yarl==1.11.1 ; python_version >= "3.10" and python_version < "3.13"
......@@ -12,11 +12,13 @@ ctrlc = { version = "3.4.1", features = ["termination"] }
hf-hub = "0.3.2"
nix = { version = "0.28.0", features = ["signal"] }
once_cell = "1.19.0"
pyo3 = { workspace = true }
serde = { version = "1.0.188", features = ["derive"] }
serde_json = "1.0.107"
thiserror = "1.0.59"
tracing = "0.1.37"
tracing-subscriber = { version = "0.3.17", features = ["json", "env-filter"] }
regex = "1.11.0"
[dev-dependencies]
float_eq = "1.0.1"
......
pub fn get_cuda_capability() -> Option<(usize, usize)> {
use pyo3::prelude::*;
let py_get_capability = |py: Python| -> PyResult<(isize, isize)> {
let torch = py.import_bound("torch.cuda")?;
let get_device_capability = torch.getattr("get_device_capability")?;
get_device_capability.call0()?.extract()
};
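    // If importing `torch.cuda` fails (e.g. no Python interpreter or Torch install
    // is available), the closure returns Err and the match below logs a warning
    // and yields None instead of a capability.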
match pyo3::Python::with_gil(py_get_capability) {
Ok((major, minor)) if major < 0 || minor < 0 => {
tracing::warn!("Ignoring negative GPU compute capabilities: {major}.{minor}");
None
}
Ok((major, minor)) => Some((major as usize, minor as usize)),
Err(err) => {
tracing::warn!("Cannot determine GPU compute capability: {}", err);
None
}
}
}
use clap::{Parser, ValueEnum};
use hf_hub::{
api::sync::{Api, ApiBuilder},
Repo, RepoType,
};
use nix::sys::signal::{self, Signal};
use nix::unistd::Pid;
use regex::Regex;
use serde::Deserialize;
use std::env;
use std::ffi::OsString;
use std::io::{BufRead, BufReader};
use std::os::unix::process::{CommandExt, ExitStatusExt};
use std::path::Path;
use std::process::{Child, Command, ExitStatus, Stdio};
......@@ -15,22 +19,153 @@ use std::sync::{mpsc, Arc};
use std::thread;
use std::thread::sleep;
use std::time::{Duration, Instant};
use std::{
fs, io,
io::{Read, Write},
};
use thiserror::Error;
use tracing_subscriber::{filter::LevelFilter, EnvFilter};
mod env_runtime;
mod gpu;
fn get_config(
model_id: &str,
revision: &Option<String>,
) -> Result<Config, Box<dyn std::error::Error>> {
let mut path = std::path::Path::new(model_id).to_path_buf();
let model_id = model_id.to_string();
let filename = if !path.exists() {
// Assume it's a hub id
let api = if let Ok(token) = std::env::var("HF_TOKEN") {
// The env variable takes precedence over the token on file.
ApiBuilder::new().with_token(Some(token)).build()?
} else {
Api::new()?
};
let repo = if let Some(ref revision) = revision {
api.repo(Repo::with_revision(
model_id,
RepoType::Model,
revision.to_string(),
))
} else {
api.model(model_id)
};
repo.get("config.json")?
} else {
path.push("config.json");
path
};
let content = std::fs::read_to_string(filename)?;
let config: RawConfig = serde_json::from_str(&content)?;
let config: Config = config.into();
Ok(config)
}
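// Illustrative usage: `get_config("meta-llama/Meta-Llama-3.1-8B-Instruct", &None)`
// fetches `config.json` from the Hub, while passing a local directory path reads
// `<path>/config.json` directly.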
fn resolve_attention(config: &Option<Config>, lora_adapters: &Option<String>) -> (String, String) {
let compute_capability = gpu::get_cuda_capability();
let mut prefix_caching: Option<String> = std::env::var("PREFIX_CACHING").ok();
let mut attention: Option<String> = std::env::var("ATTENTION").ok();
if let Some(config) = config {
if prefix_caching.is_none() {
if config.vision_config.is_some() {
tracing::info!("Disabling prefix caching because of VLM model");
prefix_caching = Some("0".to_string());
} else if config.is_encoder_decoder {
tracing::info!("Disabling prefix caching because of seq2seq model");
prefix_caching = Some("0".to_string());
}
}
let fallback_attention = if matches!(compute_capability, Some((major, _)) if major < 8) {
"paged"
} else {
"flashdecoding"
};
match config.head_dim {
Some(h) if h == 64 || h == 128 || h == 256 => {
if lora_adapters.is_some() && prefix_caching.is_none() {
tracing::info!("Disabling prefix caching because of lora adapters");
prefix_caching = Some("0".to_string());
}
match config.model_type.as_deref() {
Some("falcon") | Some("deepseek_v2") => {
// Required because gemma2 needs bfloat16 which is not supported by
// flashinfer ?
if attention.is_none() {
tracing::info!(
"Forcing attention to '{fallback_attention}' because model {} requires it",
config.model_type.as_ref().unwrap()
);
attention = Some(fallback_attention.to_string());
}
if fallback_attention == "paged" && prefix_caching.is_none() {
tracing::info!("Disabling prefix caching because it is not supported with 'paged' attention");
prefix_caching = Some("0".to_string());
}
}
Some("t5") => {}
_ => {}
}
}
_ => {
if attention.is_none() {
tracing::info!("Forcing attention to '{fallback_attention}' because head dim is not supported by flashinfer, also disabling prefix caching");
attention = Some(fallback_attention.to_string());
}
if prefix_caching.is_none() {
prefix_caching = Some("0".to_string());
}
}
}
}
if attention == Some("paged".to_string()) && prefix_caching.is_none() {
tracing::info!("Disabling prefix caching on paged attention");
prefix_caching = Some("0".to_string());
}
let attention = attention.unwrap_or("flashinfer".to_string());
let prefix_caching = prefix_caching.unwrap_or("true".to_string());
(prefix_caching, attention)
}
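// Illustrative outcomes (assuming no `PREFIX_CACHING`/`ATTENTION` env overrides):
// - head_dim 128, compute capability >= 8.0, no lora -> ("true", "flashinfer")
// - model with a vision_config (VLM)                 -> ("0", "flashinfer")
// - falcon/deepseek_v2 on compute capability < 8.0   -> ("0", "paged")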
#[derive(Deserialize)]
struct RawConfig {
max_position_embeddings: Option<usize>,
n_positions: Option<usize>,
model_type: Option<String>,
max_seq_len: Option<usize>,
quantization_config: Option<QuantizationConfig>,
n_embd: Option<usize>,
hidden_size: Option<usize>,
num_attention_heads: Option<usize>,
head_dim: Option<usize>,
vision_config: Option<VisionConfig>,
is_encoder_decoder: Option<bool>,
}
#[derive(Deserialize)]
struct QuantizationConfig {
quant_method: Option<Quantization>,
}
#[derive(Deserialize)]
struct VisionConfig {}
#[derive(Deserialize)]
struct Config {
max_position_embeddings: Option<usize>,
quantize: Option<Quantization>,
head_dim: Option<usize>,
model_type: Option<String>,
vision_config: Option<VisionConfig>,
is_encoder_decoder: bool,
}
impl From<RawConfig> for Config {
......@@ -39,13 +174,39 @@ impl From<RawConfig> for Config {
.max_position_embeddings
.or(other.max_seq_len)
.or(other.n_positions);
let quantize = other.quantization_config.and_then(|q| q.quant_method);
let head_dim = other.head_dim.or_else(|| {
match (other.hidden_size, other.n_embd, other.num_attention_heads) {
(Some(hidden_size), _, Some(num_attention_heads))
if hidden_size % num_attention_heads == 0 =>
{
Some(hidden_size / num_attention_heads)
}
// Legacy
(_, Some(hidden_size), Some(num_attention_heads))
if hidden_size % num_attention_heads == 0 =>
{
Some(hidden_size / num_attention_heads)
}
_ => None,
}
});
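        // e.g. hidden_size = 4096 with num_attention_heads = 32 gives head_dim = 128.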
let model_type = other.model_type;
let vision_config = other.vision_config;
let is_encoder_decoder = other.is_encoder_decoder.unwrap_or(false);
Config {
max_position_embeddings,
quantize,
head_dim,
model_type,
vision_config,
is_encoder_decoder,
}
}
}
#[derive(Clone, Copy, Debug, ValueEnum, Deserialize)]
#[serde(rename_all = "kebab-case")]
enum Quantization {
/// 4 bit quantization. Requires a specific AWQ quantized model:
/// <https://hf.co/models?search=awq>.
......@@ -68,17 +229,17 @@ enum Quantization {
Marlin,
/// Bitsandbytes 8bit. Can be applied on any model, will cut the memory requirement in half,
/// but it is known that the model will be much slower to run than the native f16.
// #[deprecated(
// since = "1.1.0",
// note = "Use `eetq` instead, which provides better latencies overall and is drop-in in most cases"
// )]
Bitsandbytes,
/// Bitsandbytes 4bit. Can be applied on any model, will cut the memory requirement by 4x,
/// but it is known that the model will be much slower to run than the native f16.
BitsandbytesNf4,
/// Bitsandbytes 4bit. nf4 should be preferred in most cases but maybe this one has better
/// perplexity performance for you model
BitsandbytesFp4,
/// [FP8](https://developer.nvidia.com/blog/nvidia-arm-and-intel-publish-fp8-specification-for-standardization-as-an-interchange-format-for-ai/) (e4m3) works on H100 and above
/// This dtype has native ops should be the fastest if available.
/// This is currently not the fastest because of local unpacking + padding to satisfy matrix
......@@ -95,10 +256,10 @@ impl std::fmt::Display for Quantization {
Quantization::Bitsandbytes => {
write!(f, "bitsandbytes")
}
Quantization::BitsandbytesNf4 => {
write!(f, "bitsandbytes-nf4")
}
Quantization::BitsandbytesFp4 => {
write!(f, "bitsandbytes-fp4")
}
Quantization::Exl2 => {
......@@ -144,6 +305,28 @@ impl std::fmt::Display for Dtype {
}
}
#[derive(Clone, Copy, Debug, ValueEnum)]
enum KVCacheDtype {
#[clap(name = "fp8_e4m3fn")]
Fp8e4m3fn,
#[clap(name = "fp8_e5m2")]
Fp8e5m2,
}
impl std::fmt::Display for KVCacheDtype {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
match self {
KVCacheDtype::Fp8e4m3fn => {
write!(f, "fp8_e4m3fn")
}
KVCacheDtype::Fp8e5m2 => {
write!(f, "fp8_e5m2")
}
}
}
}
#[derive(Clone, Copy, Debug, ValueEnum)]
enum RopeScaling {
Linear,
......@@ -164,6 +347,33 @@ impl std::fmt::Display for RopeScaling {
}
}
#[derive(Clone, Copy, Debug, ValueEnum)]
pub enum UsageStatsLevel {
/// Default option, usage statistics are collected anonymously
On,
/// Disables all collection of usage statistics
Off,
/// Doesn't send the error stack trace or error type, but allows sending a crash event
NoStack,
}
impl std::fmt::Display for UsageStatsLevel {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
// Keep in sync with `server`.
match self {
UsageStatsLevel::On => {
write!(f, "on")
}
UsageStatsLevel::Off => {
write!(f, "off")
}
UsageStatsLevel::NoStack => {
write!(f, "no-stack")
}
}
}
}
/// App Configuration
#[derive(Parser, Debug)]
#[clap(author, version, about, long_about = None)]
......@@ -199,7 +409,11 @@ struct Args {
#[clap(long, env)]
num_shard: Option<usize>,
/// Whether you want the model to be quantized.
/// Quantization method to use for the model. It is not necessary to specify this option
/// for pre-quantized models, since the quantization method is read from the model
/// configuration.
///
/// Marlin kernels will be used automatically for GPTQ/AWQ models.
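///
/// For example, a GPTQ checkpoint is picked up from its `quantization_config`
/// without this flag, while `--quantize eetq` quantizes a float16 model on the fly.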
#[clap(long, env, value_enum)]
quantize: Option<Quantization>,
......@@ -214,6 +428,12 @@ struct Args {
#[clap(long, env, value_enum)]
dtype: Option<Dtype>,
/// Specify the dtype for the key-value cache. When this option is not provided,
/// the dtype of the model is used (typically `float16` or `bfloat16`). Currently
/// the only supported values are `fp8_e4m3fn` and `fp8_e5m2` on CUDA.
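///
/// For example, `--kv-cache-dtype fp8_e5m2` stores the KV cache in 8-bit floats,
/// roughly halving its memory footprint compared to a `float16` cache.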
#[clap(long, env, value_enum)]
kv_cache_dtype: Option<KVCacheDtype>,
/// Whether you want to execute hub modelling code. Explicitly passing a `revision` is
/// encouraged when loading a model with custom code to ensure no malicious code has been
/// contributed in a newer revision.
......@@ -418,6 +638,10 @@ struct Args {
#[clap(long, env)]
cors_allow_origin: Vec<String>,
#[clap(long, env)]
api_key: Option<String>,
#[clap(long, env)]
watermark_gamma: Option<f32>,
#[clap(long, env)]
......@@ -457,6 +681,12 @@ struct Args {
/// startup that will be available to callers via the `adapter_id` field in a request.
#[clap(long, env)]
lora_adapters: Option<String>,
/// Control if anonymous usage stats are collected.
/// Options are "on", "off" and "no-stack"
/// Defaul is on.
#[clap(default_value = "on", long, env)]
usage_stats: UsageStatsLevel,
}
#[derive(Debug)]
......@@ -472,6 +702,7 @@ fn shard_manager(
quantize: Option<Quantization>,
speculate: Option<usize>,
dtype: Option<Dtype>,
kv_cache_dtype: Option<KVCacheDtype>,
trust_remote_code: bool,
uds_path: String,
rank: usize,
......@@ -545,6 +776,11 @@ fn shard_manager(
shard_args.push(dtype.to_string())
}
if let Some(kv_cache_dtype) = kv_cache_dtype {
shard_args.push("--kv-cache-dtype".to_string());
shard_args.push(kv_cache_dtype.to_string())
}
// Model optional revision
if let Some(revision) = revision {
shard_args.push("--revision".to_string());
......@@ -680,6 +916,7 @@ fn shard_manager(
.args(shard_args)
.env_clear()
.envs(envs)
.stdin(Stdio::piped())
.stdout(Stdio::piped())
.stderr(Stdio::piped())
.process_group(0)
......@@ -701,12 +938,13 @@ fn shard_manager(
};
// Redirect STDOUT to the console
let mut pstdin = p.stdin.take().unwrap();
let shard_stdout_reader = BufReader::new(p.stdout.take().unwrap());
let shard_stderr_reader = BufReader::new(p.stderr.take().unwrap());
//stdout tracing thread
thread::spawn(move || {
log_lines(shard_stdout_reader);
});
// We read stderr in another thread as it seems that lines() can block in some cases
let (err_sender, err_receiver) = mpsc::channel();
......@@ -715,6 +953,20 @@ fn shard_manager(
err_sender.send(line).unwrap_or(());
}
});
// Forward our stdin to the shard process in another thread (used for interactive debugging)
if LevelFilter::current() >= tracing::Level::DEBUG {
thread::spawn(move || {
let mut stdin = io::stdin(); // We get `Stdin` here.
loop {
let mut buffer = vec![0; 4096];
if let Ok(n) = stdin.read(&mut buffer) {
if n > 0 {
let _ = pstdin.write_all(&buffer[..n]);
}
}
}
});
}
let mut ready = false;
let start_time = Instant::now();
......@@ -821,19 +1073,40 @@ impl PythonLogMessage {
}
}
impl TryFrom<&[u8]> for PythonLogMessage {
type Error = serde_json::Error;
fn try_from(value: &[u8]) -> Result<Self, Self::Error> {
serde_json::from_slice::<Self>(value)
}
}
fn log_lines<R: Sized + Read>(mut bufread: BufReader<R>) {
let mut buffer = vec![0u8; 8 * 4096];
let mut stdout = std::io::stdout();
loop {
let n = bufread.read(&mut buffer);
if let Ok(n) = n {
if n > 0 {
let mut lines = buffer[..n].split(|i| *i == b'\n').peekable();
while let Some(line) = lines.next() {
match PythonLogMessage::try_from(line) {
Ok(log) => log.trace(),
// For interactive debugging ?
Err(_) => {
if LevelFilter::current() >= tracing::Level::DEBUG {
stdout.write_all(line).unwrap();
if lines.peek().is_some() {
stdout.write_all(b"\n").unwrap();
}
stdout.flush().unwrap();
}
}
}
}
} else {
break;
}
}
}
}
......@@ -993,7 +1266,7 @@ fn download_convert_model(
let download_stdout = BufReader::new(download_process.stdout.take().unwrap());
thread::spawn(move || {
log_lines(download_stdout);
});
let download_stderr = BufReader::new(download_process.stderr.take().unwrap());
......@@ -1044,6 +1317,7 @@ fn spawn_shards(
cuda_graphs: Vec<usize>,
max_total_tokens: usize,
max_input_tokens: usize,
quantize: Option<Quantization>,
max_log_level: LevelFilter,
shutdown: Arc<AtomicBool>,
shutdown_receiver: &mpsc::Receiver<()>,
......@@ -1065,9 +1339,9 @@ fn spawn_shards(
let shutdown_sender = shutdown_sender.clone();
let otlp_endpoint = args.otlp_endpoint.clone();
let otlp_service_name = args.otlp_service_name.clone();
let quantize = args.quantize;
let speculate = args.speculate;
let dtype = args.dtype;
let kv_cache_dtype = args.kv_cache_dtype;
let trust_remote_code = args.trust_remote_code;
let master_port = args.master_port;
let disable_custom_kernels = args.disable_custom_kernels;
......@@ -1086,6 +1360,7 @@ fn spawn_shards(
quantize,
speculate,
dtype,
kv_cache_dtype,
trust_remote_code,
uds_path,
rank,
......@@ -1201,6 +1476,10 @@ fn spawn_webserver(
args.model_id,
];
// Pass usage stats flags to router
router_args.push("--usage-stats".to_string());
router_args.push(args.usage_stats.to_string());
// Grammar support
if args.disable_grammar_support {
router_args.push("--disable-grammar-support".to_string());
......@@ -1230,6 +1509,10 @@ fn spawn_webserver(
router_args.push(revision.to_string())
}
if args.trust_remote_code {
router_args.push("--trust-remote-code".to_string());
}
if args.json_output {
router_args.push("--json-output".to_string());
}
......@@ -1251,6 +1534,11 @@ fn spawn_webserver(
router_args.push(origin);
}
// API Key
if let Some(api_key) = args.api_key {
router_args.push("--api-key".to_string());
router_args.push(api_key);
}
// Ngrok
if args.ngrok {
router_args.push("--ngrok".to_string());
......@@ -1379,34 +1667,12 @@ fn main() -> Result<(), LauncherError> {
tracing::info!("{:#?}", args);
let config: Option<Config> = get_config(&args.model_id, &args.revision).ok();
let quantize = config.as_ref().and_then(|c| c.quantize);
// Quantization usually means you're even more RAM constrained.
let max_default = 4096;
let max_position_embeddings = if let Some(config) = &config {
if let Some(max_position_embeddings) = config.max_position_embeddings {
if max_position_embeddings > max_default {
let max = max_position_embeddings;
......@@ -1416,17 +1682,20 @@ fn main() -> Result<(), LauncherError> {
{
tracing::info!("Model supports up to {max} but tgi will now set its default to {max_default} instead. This is to save VRAM by refusing large prompts in order to allow more users on the same hardware. You can increase that size using `--max-batch-prefill-tokens={} --max-total-tokens={max} --max-input-tokens={}`.", max + 50, max - 1);
}
max_default
} else {
max_position_embeddings
}
} else {
max_default
}
} else {
max_default
};
let (prefix_caching, attention) = resolve_attention(&config, &args.lora_adapters);
tracing::info!("Using attention {attention} - Prefix caching {prefix_caching}");
std::env::set_var("PREFIX_CACHING", prefix_caching);
std::env::set_var("ATTENTION", attention);
let max_input_tokens = {
match (args.max_input_tokens, args.max_input_length) {
......@@ -1476,34 +1745,34 @@ fn main() -> Result<(), LauncherError> {
"`max_input_tokens must be < `max_total_tokens`".to_string(),
));
}
if matches!(args.quantize, Some(Quantization::Bitsandbytes)) {
tracing::warn!("Bitsandbytes is deprecated, use `eetq` instead, which provides better latencies overall and is drop-in in most cases.");
}
let quantize = args.quantize.or(quantize);
let cuda_graphs = match (&args.cuda_graphs, &quantize) {
(Some(cuda_graphs), _) => cuda_graphs.iter().cloned().filter(|&c| c > 0).collect(),
#[allow(deprecated)]
(
None,
Some(
Quantization::Bitsandbytes
| Quantization::BitsandbytesNF4
| Quantization::BitsandbytesFP4,
| Quantization::BitsandbytesNf4
| Quantization::BitsandbytesFp4,
),
) => {
tracing::info!("Bitsandbytes doesn't work with cuda graphs, deactivating them");
tracing::warn!("Bitsandbytes doesn't work with cuda graphs, deactivating them");
vec![]
}
(None, Some(Quantization::Exl2)) => {
tracing::warn!("Exl2 doesn't work with cuda graphs, deactivating them");
vec![]
}
_ => {
let cuda_graphs = vec![1, 2, 4, 8, 16, 32];
tracing::info!("Using default cuda graphs {cuda_graphs:?}");
cuda_graphs
}
};
if args.validation_workers == 0 {
......@@ -1529,12 +1798,6 @@ fn main() -> Result<(), LauncherError> {
}
if let Some(ref max_batch_total_tokens) = args.max_batch_total_tokens {
if max_total_tokens as u32 > *max_batch_total_tokens {
return Err(LauncherError::ArgumentValidation(format!(
"`max_total_tokens` must be <= `max_batch_total_tokens`. Given: {} and {}",
......@@ -1578,14 +1841,41 @@ fn main() -> Result<(), LauncherError> {
// Download and convert lora adapters if any
if let Some(lora_adapters) = &args.lora_adapters {
for adapter in lora_adapters.split(',') {
download_convert_model(
adapter,
None,
args.trust_remote_code,
args.huggingface_hub_cache.as_deref(),
args.weights_cache_override.as_deref(),
running.clone(),
)?;
// skip download if a path is provided
if adapter.contains('=') {
continue;
}
let adapter = adapter.trim();
// check if adapter has more than 1 '@'
if adapter.matches('@').count() > 1 {
return Err(LauncherError::ArgumentValidation(format!(
"Invalid LoRA adapter format: {}",
adapter
)));
}
// Capture adapter_id, path, and revision from the `adapter_id=path@revision` format
let re = Regex::new(r"^([^=@]+)(?:=([^@]+))?(?:@(.+))?$").unwrap();
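            // e.g. "myorg/my-adapter" (illustrative id) captures only `adapter_id`,
            // while "myorg/my-adapter@main" also captures the revision; "id=/local/path"
            // entries never reach this point because they were skipped above for
            // containing '='.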
if let Some(caps) = re.captures(adapter) {
let adapter_id = caps.get(1).map_or("", |m| m.as_str());
let revision = caps.get(3).map(|m| m.as_str());
download_convert_model(
adapter_id,
revision,
args.trust_remote_code,
args.huggingface_hub_cache.as_deref(),
args.weights_cache_override.as_deref(),
running.clone(),
)?;
} else {
return Err(LauncherError::ArgumentValidation(format!(
"Invalid LoRA adapter format: {}",
adapter
)));
}
}
}
......@@ -1609,6 +1899,7 @@ fn main() -> Result<(), LauncherError> {
cuda_graphs,
max_total_tokens,
max_input_tokens,
quantize,
max_log_level,
shutdown.clone(),
&shutdown_receiver,
......@@ -1633,9 +1924,8 @@ fn main() -> Result<(), LauncherError> {
shutdown.clone(),
&shutdown_receiver,
)
.inspect_err(|_| {
shutdown_shards(shutdown.clone(), &shutdown_receiver);
})?;
// Default exit code
......
......@@ -33,13 +33,13 @@ export function get_options() {
// rate: 20,
// timeUnit: '1s',
// },
// load_test: {
// executor: 'constant-arrival-rate',
// duration: '60s',
// preAllocatedVUs: 100,
// rate: 1,
// timeUnit: '1s',
// },
// breakpoint: {
// executor: 'ramping-arrival-rate', //Assure load increase if the system slows
// preAllocatedVUs: 300,
......@@ -47,12 +47,12 @@ export function get_options() {
// { duration: '60s', target: 100 }, // just slowly ramp-up to a HUGE load
// ],
// },
throughput: {
executor: 'shared-iterations',
vus: 100,
iterations: 200,
maxDuration: '40s',
},
},
};
}
......
......@@ -20,7 +20,7 @@ def main():
break
with open("./small.json", "w") as f:
json.dump(conversations, f, indent=4)
if __name__ == "__main__":
......
{
buildPythonPackage,
poetry-core,
huggingface-hub,
pydantic,
}:
buildPythonPackage {
name = "text-generation";
src = ../clients/python;
pyproject = true;
build-system = [ poetry-core ];
dependencies = [
huggingface-hub
pydantic
];
}
{ pkgs, nix-filter }:
let
filter = nix-filter.lib;
in
with pkgs;
defaultCrateOverrides
// {
aws-lc-rs = attrs: {
# aws-lc-rs does its own custom parsing of Cargo environment
# variables like DEP_.*_INCLUDE. However buildRustCrate does
# not use the version number, so the parsing fails.
postPatch = ''
substituteInPlace build.rs \
--replace-fail \
"assert!(!selected.is_empty()" \
"// assert!(!selected.is_empty()"
'';
};
rav1e = attrs: { env.CARGO_ENCODED_RUSTFLAGS = "-C target-feature=-crt-static"; };
grpc-metadata = attrs: {
src = filter {
root = ../backends/grpc-metadata;
include = with filter; [
isDirectory
(matchExt "rs")
];
};
};
pyo3-build-config = attrs: {
buildInputs = [ python3 ];
};
text-generation-benchmark = attrs: {
src = filter {
root = ../benchmark;
include = with filter; [
isDirectory
(matchExt "rs")
];
};
};
text-generation-client = attrs: {
src = filter {
root = ../.;
include = with filter; [
isDirectory
(and (inDirectory "backends/client") (matchExt "rs"))
(and (inDirectory "proto") (matchExt "proto"))
];
};
postPatch = "cd backends/client";
buildInputs = [ protobuf ];
};
text-generation-launcher = attrs: {
src = filter {
root = ../launcher;
include = with filter; [
isDirectory
(matchExt "rs")
];
};
};
text-generation-router = attrs: {
src = filter {
root = ../router;
include = with filter; [
isDirectory
(matchExt "rs")
];
};
};
text-generation-router-v3 = attrs: {
# We need to do the src/source root dance so that the build
# has access to the protobuf file.
src = filter {
root = ../.;
include = with filter; [
isDirectory
(and (inDirectory "backends/v3") (matchExt "rs"))
(and (inDirectory "proto") (matchExt "proto"))
];
};
postPatch = "cd backends/v3";
buildInputs = [ protobuf ];
};
}
{
dockerTools,
cacert,
text-generation-inference,
stream ? false,
}:
let
build = if stream then dockerTools.streamLayeredImage else dockerTools.buildLayeredImage;
in
build {
name = "tgi-docker";
tag = "latest";
config = {
EntryPoint = [ "${text-generation-inference}/bin/text-generation-inference" ];
Env = [
"HF_HOME=/data"
"PORT=80"
];
};
contents = [ cacert ];
}
{
lib,
mkShell,
black,
cmake,
isort,
ninja,
which,
cudaPackages,
openssl,
pkg-config,
protobuf,
python3,
pyright,
redocly,
ruff,
rust-bin,
server,
# Enable dependencies for building CUDA packages. Useful for e.g.
# developing marlin/moe-kernels in-place.
withCuda ? false,
}:
mkShell {
nativeBuildInputs =
[
black
isort
pkg-config
(rust-bin.stable.latest.default.override {
extensions = [
"rust-analyzer"
"rust-src"
];
})
protobuf
pyright
redocly
ruff
]
++ (lib.optionals withCuda [
cmake
ninja
which
# For most Torch-based extensions, setting CUDA_HOME is enough, but
# some custom CMake builds (e.g. vLLM) also need to have nvcc in PATH.
cudaPackages.cuda_nvcc
]);
buildInputs =
[
openssl.dev
]
++ (with python3.pkgs; [
venvShellHook
docker
pip
ipdb
click
pytest
pytest-asyncio
syrupy
])
++ (lib.optionals withCuda (
with cudaPackages;
[
cuda_cccl
cuda_cudart
cuda_nvrtc
cuda_nvtx
cuda_profiler_api
cudnn
libcublas
libcusolver
libcusparse
]
));
inputsFrom = [ server ];
env = lib.optionalAttrs withCuda {
CUDA_HOME = "${lib.getDev cudaPackages.cuda_nvcc}";
TORCH_CUDA_ARCH_LIST = lib.concatStringsSep ";" python3.pkgs.torch.cudaCapabilities;
};
venvDir = "./.venv";
postVenvCreation = ''
unset SOURCE_DATE_EPOCH
( cd server ; python -m pip install --no-dependencies -e . )
( cd clients/python ; python -m pip install --no-dependencies -e . )
'';
postShellHook = ''
unset SOURCE_DATE_EPOCH
export PATH=$PATH:~/.cargo/bin
'';
}
final: prev: {
# You can use this overlay to temporarily override packages for
# development. For permanent overrides, it's better to do this in
# our package flake:
#
# https://github.com/huggingface/text-generation-inference-nix
#
# Note that overriding packages that are in the transitive closure
# of many other packages (e.g. transformers) will require a large
# rebuild.
pythonPackagesExtensions = prev.pythonPackagesExtensions ++ [
(
python-self: python-super: with python-self; {
# Python package override example:
# transformers = python-super.transformers.overrideAttrs (
# _: _: {
# src = final.fetchFromGitHub {
# owner = "huggingface";
# repo = "transformers";
# rev = "2bd4d5897dc73e8b172832070a6f9e567a0df017";
# hash = "sha256-JOIpKH9ssDEfI2Tf15e0iPKtThJwQ9GxMvRAnm+M2Pg=";
# };
# }
# );
}
)
];
# Non-python package override example:
#
# ripgrep = prev.ripgrep.overrideAttrs (
# _: _: {
# src = final.fetchFromGitHub {
# owner = "BurntSushi";
# repo = "ripgrep";
# rev = "79cbe89deb1151e703f4d91b19af9cdcc128b765";
# hash = "sha256-JPTM2KNmGMb+/jOfK3X7OM1wnN+3TU35SJOIcqmp3mg=";
# };
# });
}
{
nix-filter,
buildPythonPackage,
poetry-core,
mypy-protobuf,
awq-inference-engine,
causal-conv1d,
eetq,
einops,
exllamav2,
flashinfer,
flash-attn,
flash-attn-layer-norm,
flash-attn-rotary,
flash-attn-v1,
grpc-interceptor,
grpcio-reflection,
grpcio-status,
grpcio-tools,
hf-transfer,
loguru,
mamba-ssm,
marlin-kernels,
moe-kernels,
opentelemetry-api,
opentelemetry-exporter-otlp,
opentelemetry-instrumentation-grpc,
opentelemetry-semantic-conventions,
peft,
punica-kernels,
safetensors,
tokenizers,
torch,
sentencepiece,
transformers,
typer,
vllm,
}:
let
filter = nix-filter.lib;
in
buildPythonPackage {
name = "text-generation-server";
src = filter {
root = ../.;
include = with filter; [
isDirectory
(and (inDirectory "server") (or_ (matchExt "py") (matchExt "pyi")))
"server/pyproject.toml"
(and (inDirectory "proto/v3") (matchExt "proto"))
];
};
pyproject = true;
build-system = [ poetry-core ];
nativeBuildInputs = [ mypy-protobuf ];
pythonRelaxDeps = [
"einops"
"huggingface-hub"
"loguru"
"opentelemetry-instrumentation-grpc"
"sentencepiece"
"typer"
];
pythonRemoveDeps = [ "scipy" ];
dependencies = [
awq-inference-engine
eetq
causal-conv1d
einops
exllamav2
flashinfer
flash-attn
flash-attn-layer-norm
flash-attn-rotary
grpc-interceptor
grpcio-reflection
grpcio-status
grpcio-tools
hf-transfer
loguru
mamba-ssm
marlin-kernels
moe-kernels
opentelemetry-api
opentelemetry-exporter-otlp
opentelemetry-instrumentation-grpc
opentelemetry-semantic-conventions
peft
punica-kernels
safetensors
sentencepiece
tokenizers
transformers
typer
vllm
];
prePatch = ''
python -m grpc_tools.protoc -Iproto/v3 --python_out=server/text_generation_server/pb \
--grpc_python_out=server/text_generation_server/pb --mypy_out=server/text_generation_server/pb proto/v3/generate.proto
find server/text_generation_server/pb/ -type f -name "*.py" -print0 -exec sed -i -e 's/^\(import.*pb2\)/from . \1/g' {} \;
touch server/text_generation_server/pb/__init__.py
cd server
'';
}
......@@ -3,22 +3,23 @@ syntax = "proto3";
package generate.v3;
service TextGenerationService {
/// Model Info
rpc Info(InfoRequest) returns (InfoResponse) {}
/// Service discovery
rpc ServiceDiscovery(ServiceDiscoveryRequest)
returns (ServiceDiscoveryResponse) {}
/// Empties batch cache
rpc ClearCache(ClearCacheRequest) returns (ClearCacheResponse);
/// Remove requests from a cached batch
rpc FilterBatch(FilterBatchRequest) returns (FilterBatchResponse);
/// Warmup the model and compute max cache size
rpc Warmup(WarmupRequest) returns (WarmupResponse);
/// Prefill batch and decode first token
rpc Prefill(PrefillRequest) returns (PrefillResponse);
/// Decode token for a list of prefilled batches
rpc Decode(DecodeRequest) returns (DecodeResponse);
/// Health check
rpc Health(HealthRequest) returns (HealthResponse);
}
message HealthRequest {}
......@@ -28,240 +29,255 @@ message HealthResponse {}
message InfoRequest {}
message InfoResponse {
bool requires_padding = 1;
string dtype = 2;
string device_type = 3;
optional uint32 window_size = 4;
uint32 speculate = 5;
bool support_chunking = 6;
bool use_prefix_caching = 7;
string attention_impl = 8;
uint32 block_size = 9;
}
/// Empty request
message ServiceDiscoveryRequest {}
message ServiceDiscoveryResponse {
/// Other shards urls
repeated string urls = 1;
}
message ClearCacheRequest {
/// Optional batch id
optional uint64 id = 1;
}
/// Empty response
message ClearCacheResponse {}
message Image {
/// Binary image data.
bytes data = 1;
/// Image MIME type.
string mimetype = 2;
}
message InputChunk {
oneof chunk {
/// Plain text data
string text = 1;
/// Image data
Image image = 2;
}
}
message Input { repeated InputChunk chunks = 1; }
enum GrammarType {
GRAMMAR_TYPE_NONE = 0;
GRAMMAR_TYPE_JSON = 1;
GRAMMAR_TYPE_REGEX = 2;
}
message NextTokenChooserParameters {
/// exponential scaling output probability distribution
float temperature = 1;
/// restricting to the k highest probability elements
uint32 top_k = 2;
/// restricting to top tokens summing to prob_cut_off <= prob_cut_off
float top_p = 3;
/// restricting to top tokens summing to prob_cut_off <= prob_cut_off
float typical_p = 4;
/// apply sampling on the logits
bool do_sample = 5;
/// random seed for sampling
uint64 seed = 6;
/// repetition penalty
float repetition_penalty = 7;
/// frequency penalty
float frequency_penalty = 9;
/// token watermarking using "A Watermark for Large Language Models"
bool watermark = 8;
/// grammar (applied if not empty)
string grammar = 10;
/// grammar type
GrammarType grammar_type = 11;
}
message StoppingCriteriaParameters {
/// Maximum number of generated tokens
uint32 max_new_tokens = 1;
/// Optional stopping sequences
repeated string stop_sequences = 2;
/// Ignore end of sequence token
/// used for benchmarking
bool ignore_eos_token = 3;
}
message Request {
/// Request ID
uint64 id = 1;
/// The generation context as chunks
Input input_chunks = 8;
/// The generation context, stringified input_chunks
string inputs = 2;
/// Context truncation
uint32 truncate = 3;
/// Next Token Chooser Parameters
NextTokenChooserParameters parameters = 4;
/// Stopping Criteria Parameters
StoppingCriteriaParameters stopping_parameters = 5;
/// Return prefill logprobs
bool prefill_logprobs = 6;
/// Return most likely n tokens
uint32 top_n_tokens = 7;
/// Paged attention blocks
repeated uint32 blocks = 9;
/// Paged attention slots
repeated uint32 slots = 10;
/// LORA adapter index
optional string adapter_id = 11;
/// Tokens that can be retrieved from the KV cache.
/// This value is set for the first prefill and never reset
uint32 cache_len = 12;
/// Context truncation
bool add_special_tokens = 13;
/// Chunk of tokens that must be computed for the first prefill
/// This value is set for the first prefill and never reset
optional uint32 chunk_len = 14;
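  // Illustrative example: a 1000-token prompt with 600 tokens already present in
  // the KV cache would set cache_len = 600, with chunk_len covering (part of) the
  // remaining 400 tokens in the first prefill.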
}
message Batch {
/// Batch ID
uint64 id = 1;
/// Individual requests
repeated Request requests = 2;
/// Batch size (==len(requests))
uint32 size = 3;
/// Maximum number of tokens this batch will grow to
uint32 max_tokens = 4;
/// Maximum number of Paged Attention blocks
uint32 max_blocks = 5;
}
message CachedBatch {
/// Batch ID
uint64 id = 1;
/// Individual requests ids
repeated uint64 request_ids = 2;
/// Batch size (==len(requests))
uint32 size = 3;
/// Maximum number of tokens this batch will grow to
uint32 max_tokens = 4;
/// Number of tokens in the next forward
uint32 current_tokens = 5;
}
enum FinishReason {
FINISH_REASON_LENGTH = 0;
FINISH_REASON_EOS_TOKEN = 1;
FINISH_REASON_STOP_SEQUENCE = 2;
}
message GeneratedText {
/// Output
string text = 1;
/// Number of generated tokens
uint32 generated_tokens = 2;
/// Finish reason
FinishReason finish_reason = 3;
/// Seed
optional uint64 seed = 4;
}
message Tokens {
/// Token IDs
repeated uint32 ids = 1;
/// Logprobs
repeated float logprobs = 2;
/// tokens
repeated string texts = 3;
/// special
repeated bool is_special = 4;
}
message Generation {
/// Request ID
uint64 request_id = 1;
/// Prefill tokens (optional)
Tokens prefill_tokens = 2;
Tokens tokens = 3;
/// Complete generated text
optional GeneratedText generated_text = 4;
/// Top tokens
repeated Tokens top_tokens = 5;
}
message FilterBatchRequest {
/// Batch ID
uint64 batch_id = 1;
/// Requests to keep
repeated uint64 request_ids = 2;
}
message FilterBatchResponse {
/// Filtered Batch (cached)
CachedBatch batch = 1;
}
message PrefillRequest {
/// Batch
Batch batch = 1;
/// Optional cached batch
CachedBatch cached_batch = 2;
}
message PrefillResponse {
/// Generation
repeated Generation generations = 1;
/// Next batch (cached)
optional CachedBatch batch = 2;
/// Forward elapsed time in nanoseconds
uint64 forward_ns = 3;
/// Decode elapsed time in nanoseconds
uint64 decode_ns = 4;
/// Total elapsed time in nanoseconds
uint64 total_ns = 5;
/// Concatenate elapsed time in nanoseconds
optional uint64 concat_ns = 6;
}
message DecodeRequest {
/// Cached batches
repeated CachedBatch batches = 1;
}
message DecodeResponse {
/// Decodes
repeated Generation generations = 1;
/// Next batch (cached)
optional CachedBatch batch = 2;
/// Forward elapsed time in nanoseconds
uint64 forward_ns = 3;
/// Decode elapsed time in nanoseconds
uint64 decode_ns = 4;
/// Total elapsed time in nanoseconds
uint64 total_ns = 5;
/// Concatenate elapsed time in nanoseconds
optional uint64 concat_ns = 6;
}
message WarmupRequest {
/// Batch to warmup on
Batch batch = 1;
uint32 max_input_length = 2;
uint32 max_prefill_tokens = 3;
uint32 max_total_tokens = 4;
}
message WarmupResponse {
/// Maximum number of tokens supported by the model
optional uint32 max_supported_total_tokens = 1;
}