Commit 5a1cf2f0 authored by huangwb

Merge tag 'v2.0.2' into dev-rocm

parents 24f58bb6 6073ece4
{
  "details": {
    "best_of_sequences": null,
    "finish_reason": "length",
    "generated_tokens": 10,
    "prefill": [],
    "seed": null,
    "tokens": [
      {
        "id": 330,
        "logprob": -0.13000488,
        "special": false,
        "text": " A"
      },
      {
        "id": 13088,
        "logprob": -0.6713867,
        "special": false,
        "text": " chicken"
      },
      {
        "id": 349,
        "logprob": -0.2980957,
        "special": false,
        "text": " is"
      },
      {
        "id": 6398,
        "logprob": -0.060638428,
        "special": false,
        "text": " sitting"
      },
      {
        "id": 356,
        "logprob": -0.27319336,
        "special": false,
        "text": " on"
      },
      {
        "id": 264,
        "logprob": -0.140625,
        "special": false,
        "text": " a"
      },
      {
        "id": 17972,
        "logprob": -0.040405273,
        "special": false,
        "text": " pile"
      },
      {
        "id": 302,
        "logprob": -0.0002708435,
        "special": false,
        "text": " of"
      },
      {
        "id": 2445,
        "logprob": -0.095336914,
        "special": false,
        "text": " money"
      },
      {
        "id": 28723,
        "logprob": -0.0068359375,
        "special": false,
        "text": "."
      }
    ],
    "top_tokens": null
  },
  "generated_text": " A chicken is sitting on a pile of money."
}
import pytest
import base64


# TODO fix the server parsser to count inline image tokens correctly
def get_chicken():
    with open("integration-tests/images/chicken_on_money.png", "rb") as image_file:
        encoded_string = base64.b64encode(image_file.read())
    return f"data:image/png;base64,{encoded_string.decode('utf-8')}"


@pytest.fixture(scope="module")
def flash_idefics2_next_handle(launcher):
    with launcher(
        "HuggingFaceM4/idefics2-8b",
    ) as handle:
        yield handle


@pytest.fixture(scope="module")
async def flash_idefics2_next(flash_idefics2_next_handle):
    await flash_idefics2_next_handle.health(300)
    return flash_idefics2_next_handle.client


@pytest.mark.asyncio
@pytest.mark.private
async def test_flash_idefics2_next_simple(flash_idefics2_next, response_snapshot):
    chicken = get_chicken()
    response = await flash_idefics2_next.generate(
        f"User:![]({chicken})Write me a short story<end_of_utterance> \nAssistant:",
        max_new_tokens=10,
    )
    assert (
        response.generated_text == " A chicken is sitting on a pile of money."
    ), f"{repr(response.generated_text)}"
    assert response.details.generated_tokens == 10
    assert response == response_snapshot


@pytest.mark.asyncio
@pytest.mark.private
async def test_flash_idefics2_next_all_params(flash_idefics2_next, response_snapshot):
    response = await flash_idefics2_next.generate(
        "Test request",
        max_new_tokens=10,
        repetition_penalty=1.2,
        return_full_text=True,
        stop_sequences=["test"],
        temperature=0.5,
        top_p=0.9,
        top_k=10,
        truncate=5,
        typical_p=0.9,
        watermark=True,
        decoder_input_details=True,
        seed=0,
    )
    assert response.details.generated_tokens == 10
    assert response == response_snapshot


@pytest.mark.asyncio
@pytest.mark.private
async def test_flash_idefics2_next_load(
    flash_idefics2_next, generate_load, response_snapshot
):
    chicken = get_chicken()
    responses = await generate_load(
        flash_idefics2_next,
        f"User:![]({chicken})Write me a short story<end_of_utterance> \nAssistant:",
        max_new_tokens=10,
        n=4,
    )
    generated_texts = [r.generated_text for r in responses]

    assert generated_texts[0] == " A chicken is sitting on a pile of money."
    assert len(generated_texts) == 4
    assert all([r.generated_text == generated_texts[0] for r in responses])

    assert responses == response_snapshot
@@ -7,14 +7,17 @@ pub(crate) struct Env {
     git_sha: &'static str,
     docker_label: &'static str,
     nvidia_env: String,
+    xpu_env: String,
 }

 impl Env {
     pub fn new() -> Self {
         let nvidia_env = nvidia_smi();
+        let xpu_env = xpu_smi();

         Self {
             nvidia_env: nvidia_env.unwrap_or("N/A".to_string()),
+            xpu_env: xpu_env.unwrap_or("N/A".to_string()),
             cargo_target: env!("VERGEN_CARGO_TARGET_TRIPLE"),
             cargo_version: env!("VERGEN_RUSTC_SEMVER"),
             git_sha: option_env!("VERGEN_GIT_SHA").unwrap_or("N/A"),
@@ -31,7 +34,8 @@ impl fmt::Display for Env {
         writeln!(f, "Cargo version: {}", self.cargo_version)?;
         writeln!(f, "Commit sha: {}", self.git_sha)?;
         writeln!(f, "Docker label: {}", self.docker_label)?;
-        write!(f, "nvidia-smi:\n{}", self.nvidia_env)?;
+        writeln!(f, "nvidia-smi:\n{}", self.nvidia_env)?;
+        write!(f, "xpu-smi:\n{}", self.xpu_env)?;

         Ok(())
     }
@@ -43,3 +47,10 @@ fn nvidia_smi() -> Option<String> {
     let output = nvidia_smi.replace('\n', "\n ");
     Some(output.trim().to_string())
 }
+
+fn xpu_smi() -> Option<String> {
+    let output = Command::new("xpu-smi").arg("discovery").output().ok()?;
+    let xpu_smi = String::from_utf8(output.stdout).ok()?;
+    let output = xpu_smi.replace('\n', "\n ");
+    Some(output.trim().to_string())
+}
@@ -251,7 +251,7 @@ struct Args {
     ///
     /// This setting is only applied if there is room in the batch
     /// as defined by `max_batch_total_tokens`.
-    #[clap(default_value = "1.2", long, env)]
+    #[clap(default_value = "0.3", long, env)]
     waiting_served_ratio: f32,

     /// Limits the number of tokens for the prefill operation.
@@ -448,6 +448,8 @@ fn shard_manager(
     cuda_memory_fraction: f32,
     rope_scaling: Option<RopeScaling>,
     rope_factor: Option<f32>,
+    max_total_tokens: usize,
+    max_batch_size: Option<usize>,
     otlp_endpoint: Option<String>,
     status_sender: mpsc::Sender<ShardStatus>,
     shutdown: Arc<AtomicBool>,
@@ -512,6 +514,7 @@ fn shard_manager(
         (Some(scaling), Some(factor)) => Some((scaling, factor)),
         (None, Some(factor)) => Some((RopeScaling::Linear, factor)),
     };

     // OpenTelemetry
     if let Some(otlp_endpoint) = otlp_endpoint {
         shard_args.push("--otlp-endpoint".to_string());
@@ -564,6 +567,14 @@ fn shard_manager(
         envs.push(("ROPE_FACTOR".into(), factor.to_string().into()));
     }

+    envs.push((
+        "MAX_TOTAL_TOKENS".into(),
+        max_total_tokens.to_string().into(),
+    ));
+    if let Some(max_batch_size) = max_batch_size {
+        envs.push(("MAX_BATCH_SIZE".into(), max_batch_size.to_string().into()));
+    }
+
     // If huggingface_hub_cache is some, pass it to the shard
     // Useful when running inside a docker container
     if let Some(huggingface_hub_cache) = huggingface_hub_cache {
@@ -672,9 +683,7 @@ fn shard_manager(
         // We received a shutdown signal
         if shutdown.load(Ordering::SeqCst) {
-            p.kill().unwrap();
-            let _ = p.wait();
-            tracing::info!("Shard terminated");
+            terminate("shard", p, Duration::from_secs(90)).unwrap();
             return;
         }
@@ -967,6 +976,7 @@ fn spawn_shards(
     num_shard: usize,
     args: &Args,
     cuda_graphs: Vec<usize>,
+    max_total_tokens: usize,
     shutdown: Arc<AtomicBool>,
     shutdown_receiver: &mpsc::Receiver<()>,
     shutdown_sender: mpsc::Sender<()>,
@@ -998,6 +1008,7 @@ fn spawn_shards(
         let cuda_memory_fraction = args.cuda_memory_fraction;
         let rope_scaling = args.rope_scaling;
         let rope_factor = args.rope_factor;
+        let max_batch_size = args.max_batch_size;
         thread::spawn(move || {
             shard_manager(
                 model_id,
@@ -1020,6 +1031,8 @@ fn spawn_shards(
                 cuda_memory_fraction,
                 rope_scaling,
                 rope_factor,
+                max_total_tokens,
+                max_batch_size,
                 otlp_endpoint,
                 status_sender,
                 shutdown,
@@ -1230,7 +1243,6 @@ fn terminate(process_name: &str, mut process: Child, timeout: Duration) -> io::R
     signal::kill(Pid::from_raw(process.id() as i32), Signal::SIGTERM).unwrap();

     tracing::info!("Waiting for {process_name} to gracefully shutdown");
     while terminate_time.elapsed() < timeout {
         if let Some(status) = process.try_wait()? {
             tracing::info!("{process_name} terminated");
@@ -1238,7 +1250,6 @@ fn terminate(process_name: &str, mut process: Child, timeout: Duration) -> io::R
         }
         sleep(Duration::from_millis(100));
     }

     tracing::info!("Killing {process_name}");
     process.kill()?;
@@ -1273,7 +1284,7 @@ fn main() -> Result<(), LauncherError> {
         tracing::info!("{}", env_runtime);
     }

-    tracing::info!("{:?}", args);
+    tracing::info!("{:#?}", args);

     let get_max_position_embeddings = || -> Result<usize, Box<dyn std::error::Error>> {
         let model_id = args.model_id.clone();
@@ -1306,7 +1317,12 @@ fn main() -> Result<(), LauncherError> {
         (Some(max_position_embeddings), _) | (None, Some(max_position_embeddings)) => {
             if max_position_embeddings > max_default {
                 let max = max_position_embeddings;
-                tracing::info!("Model supports up to {max} but tgi will now set its default to {max_default} instead. This is to save VRAM by refusing large prompts in order to allow more users on the same hardware. You can increase that size using `--max-batch-prefill-tokens={} --max-total-tokens={max} --max-input-tokens={}`.", max + 50, max - 1);
+                if args.max_input_tokens.is_none()
+                    && args.max_total_tokens.is_none()
+                    && args.max_batch_prefill_tokens.is_none()
+                {
+                    tracing::info!("Model supports up to {max} but tgi will now set its default to {max_default} instead. This is to save VRAM by refusing large prompts in order to allow more users on the same hardware. You can increase that size using `--max-batch-prefill-tokens={} --max-total-tokens={max} --max-input-tokens={}`.", max + 50, max - 1);
+                }
                 max_default
             } else {
                 max_position_embeddings
@@ -1378,7 +1394,7 @@ fn main() -> Result<(), LauncherError> {
     }

     let cuda_graphs = match (&args.cuda_graphs, &args.quantize) {
-        (Some(cuda_graphs), Some(_q)) => cuda_graphs.clone(),
+        (Some(cuda_graphs), _) => cuda_graphs.iter().cloned().filter(|&c| c > 0).collect(),
         #[allow(deprecated)]
         (
             None,
@@ -1475,6 +1491,7 @@ fn main() -> Result<(), LauncherError> {
         num_shard,
         &args,
         cuda_graphs,
+        max_total_tokens,
         shutdown.clone(),
         &shutdown_receiver,
         shutdown_sender,
...
-import { check, randomSeed } from 'k6';
+import { check } from 'k6';
+import { scenario } from 'k6/execution';
 import http from 'k6/http';
 import { Trend, Counter } from 'k6/metrics';
-import { randomItem } from 'https://jslib.k6.io/k6-utils/1.2.0/index.js';

-const seed = 0;
-
-const host = __ENV.HOST || '127.0.0.1:8000';
+const host = __ENV.HOST;
+const model_id = __ENV.MODEL_ID;
 const timePerToken = new Trend('time_per_token', true);
 const tokens = new Counter('tokens');
 const new_tokens = new Counter('new_tokens');
 const input_tokens = new Counter('input_tokens');
+const max_new_tokens = 50;

-randomSeed(seed);
 // const shareGPT = JSON.parse(open("ShareGPT_V3_unfiltered_cleaned_split.json"))
 const shareGPT = JSON.parse(open("small.json"))

-export function get_options(reference_latency_ms){
+export function get_options() {
     return {
         thresholds: {
             http_req_failed: ['rate==0'],
-            time_per_token: [{
-                threshold: `p(50)<${5 * reference_latency_ms}`,
-                abortOnFail: true,
-                delayAbortEval: '10s'
-            }],
+            // time_per_token: [{
+            //     threshold: `p(50)<${5 * reference_latency_ms}`,
+            //     abortOnFail: true,
+            //     delayAbortEval: '10s'
+            // }],
         },
         scenarios: {
+            // single_user: {
+            //     executor: 'constant-arrival-rate',
+            //     duration: '60s',
+            //     preAllocatedVUs: 1,
+            //     rate: 20,
+            //     timeUnit: '1s',
+            // },
             load_test: {
                 executor: 'constant-arrival-rate',
                 duration: '60s',
-                preAllocatedVUs: 10,
-                rate: 10,
+                preAllocatedVUs: 100,
+                rate: 1,
                 timeUnit: '1s',
             },
+            // breakpoint: {
+            //     executor: 'ramping-arrival-rate', //Assure load increase if the system slows
+            //     preAllocatedVUs: 300,
+            //     stages: [
+            //         { duration: '60s', target: 100 }, // just slowly ramp-up to a HUGE load
+            //     ],
+            // },
+            // throughput: {
+            //     executor: 'shared-iterations',
+            //     vus: 100,
+            //     iterations: 200,
+            //     maxDuration: '40s',
+            // },
         },
     };
 }

-export function run(host, generate_payload, max_new_tokens) {
-    const headers = {'Content-Type': 'application/json'};
-    const query = randomItem(shareGPT);
-    const payload = JSON.stringify(generate_payload(query));
-    const res = http.post(`http://${host}/generate`, payload, {
+function generate_payload(gpt, max_new_tokens) {
+    const input = gpt["conversations"][0]["value"];
+    return { "messages": [{ "role": "user", "content": input }], "temperature": 0, "model": `${model_id}`, "max_tokens": max_new_tokens }
+}
+
+export const options = get_options();
+
+export default function run() {
+    const headers = { 'Content-Type': 'application/json' };
+    const query = shareGPT[scenario.iterationInTest % shareGPT.length];
+    const payload = JSON.stringify(generate_payload(query, max_new_tokens));
+    const res = http.post(`http://${host}/v1/chat/completions`, payload, {
         headers,
     });
-    if(res.status >= 400 && res.status < 500){
+    if (res.status >= 400 && res.status < 500) {
         return;
     }

     check(res, {
-        'Post status is 200': (r) => res.status === 200,
+        'Post status is 200': (res) => res.status === 200,
     });
     const duration = res.timings.duration;

     if (res.status === 200) {
         const body = res.json();
-        const n_tokens = body.details.tokens.length;
-        const latency_ms_per_token = duration / n_tokens;
+        const completion_tokens = body.usage.completion_tokens;
+        const latency_ms_per_token = duration / completion_tokens;
         timePerToken.add(latency_ms_per_token);
-        const latency_in_s = latency_ms_per_token / 1000;
-        const individual_throughput = 1 / latency_in_s;
-        const _input_tokens = body.details.prefill.length;
-        tokens.add(n_tokens + _input_tokens);
-        input_tokens.add(_input_tokens);
-        new_tokens.add(n_tokens);
+        const prompt_tokens = body.usage.prompt_tokens;
+        input_tokens.add(prompt_tokens);
+        new_tokens.add(completion_tokens);
+        tokens.add(completion_tokens + prompt_tokens);
     }
 }
import { get_options, run } from "./common.js";

const reference_latency_ms = 70;
const host = __ENV.HOST || '127.0.0.1:8000';
const max_new_tokens = 50;

function generate_payload(gpt){
    const input = gpt["conversations"][0]["value"];
    return {"inputs": input, "parameters": {"max_new_tokens": max_new_tokens, "decoder_input_details": true}}
}

export const options = get_options(reference_latency_ms);

export default function(){
    run(host, generate_payload, max_new_tokens);
}
import { get_options, run } from "./common.js";

const reference_latency_ms = 22;
const host = __ENV.HOST || '127.0.0.1:8000';
const max_new_tokens = 50;

function generate_payload(gpt){
    const input = gpt["conversations"][0]["value"];
    return {"prompt": input, "temperature": 0.5, "ignore_eos": true}
}

export const options = get_options(reference_latency_ms);

export default function(){
    run(host, generate_payload, max_new_tokens);
}
@@ -114,8 +114,12 @@ impl Client {
        let truncate = min(max_input_length, max_prefill_tokens - n_tokens);

        let mut inputs = String::new();
inputs.push_str("![](data:image/jpeg;base64,iVBORw0KGgoAAAANSUhEUgAAABQAAAAUCAIAAAAC64paAAABg2lDQ1BJQ0MgcHJvZmlsZQAAKJF9kT1Iw0AcxV/TSotUROxQxCFDdbKLijjWKhShQqgVWnUwufQLmrQkKS6OgmvBwY/FqoOLs64OroIg+AHi7OCk6CIl/i8ptIjx4Lgf7+497t4BQqvKNDOQADTdMjKppJjLr4rBVwQQwhAERGVm1uckKQ3P8XUPH1/v4jzL+9yfY0AtmAzwicQJVjcs4g3imU2rznmfOMLKskp8Tjxh0AWJH7muuPzGueSwwDMjRjYzTxwhFks9rPQwKxsa8TRxTNV0yhdyLquctzhr1Qbr3JO/MFzQV5a5TnMUKSxiCRJEKGiggiosxGnVSTGRof2kh3/E8UvkUshVASPHAmrQIDt+8D/43a1ZnJp0k8JJoO/Ftj/GgOAu0G7a9vexbbdPAP8zcKV3/bUWMPtJerOrxY6AwW3g4rqrKXvA5Q4QfarLhuxIfppCsQi8n9E35YHhW6B/ze2ts4/TByBLXaVvgINDYLxE2ese7w719vbvmU5/PycecohsjayNAAAACXBIWXMAAC4jAAAuIwF4pT92AAAAB3RJTUUH6AQIEQMnlTSSjwAAABl0RVh0Q29tbWVudABDcmVhdGVkIHdpdGggR0lNUFeBDhcAAAASSURBVDjLY2AYBaNgFIyCoQsABMQAAeRw1DoAAAAASUVORK5CYII=");
        inputs.push_str(&"_test ".to_string().repeat(max_input_length as usize));
if n_tokens == 0 {
// 1 request is enough to test vision heads.
// Sending images on other queries messes up easily with truncation.
inputs.push_str("![](data:image/jpeg;base64,iVBORw0KGgoAAAANSUhEUgAAABQAAAAUCAIAAAAC64paAAABg2lDQ1BJQ0MgcHJvZmlsZQAAKJF9kT1Iw0AcxV/TSotUROxQxCFDdbKLijjWKhShQqgVWnUwufQLmrQkKS6OgmvBwY/FqoOLs64OroIg+AHi7OCk6CIl/i8ptIjx4Lgf7+497t4BQqvKNDOQADTdMjKppJjLr4rBVwQQwhAERGVm1uckKQ3P8XUPH1/v4jzL+9yfY0AtmAzwicQJVjcs4g3imU2rznmfOMLKskp8Tjxh0AWJH7muuPzGueSwwDMjRjYzTxwhFks9rPQwKxsa8TRxTNV0yhdyLquctzhr1Qbr3JO/MFzQV5a5TnMUKSxiCRJEKGiggiosxGnVSTGRof2kh3/E8UvkUshVASPHAmrQIDt+8D/43a1ZnJp0k8JJoO/Ftj/GgOAu0G7a9vexbbdPAP8zcKV3/bUWMPtJerOrxY6AwW3g4rqrKXvA5Q4QfarLhuxIfppCsQi8n9E35YHhW6B/ze2ts4/TByBLXaVvgINDYLxE2ese7w719vbvmU5/PycecohsjayNAAAACXBIWXMAAC4jAAAuIwF4pT92AAAAB3RJTUUH6AQIEQMnlTSSjwAAABl0RVh0Q29tbWVudABDcmVhdGVkIHdpdGggR0lNUFeBDhcAAAASSURBVDjLY2AYBaNgFIyCoQsABMQAAeRw1DoAAAAASUVORK5CYII=)");
}
        requests.push(Request {
            id: 0,
...
@@ -57,6 +57,31 @@ fn select_best_resolution(
     best_fit.unwrap_or((original_height, original_width))
 }

+fn get_unpadded_features(
+    height: usize,
+    width: usize,
+    npatches: usize,
+    num_patch_height: usize,
+    num_patch_width: usize,
+) -> (usize, usize) {
+    let current_height = npatches * num_patch_height;
+    let current_width = npatches * num_patch_width;
+
+    let aspect_ratio: f64 = width as f64 / height as f64;
+    let current_aspect_ratio: f64 = current_width as f64 / current_height as f64;
+    let (current_height, current_width) = if aspect_ratio > current_aspect_ratio {
+        let new_height = (height * current_width) / width;
+        (new_height, current_width)
+    } else {
+        let new_width = (width * current_height) / height;
+        (current_height, new_width)
+    };
+
+    let unpadded_features = current_height * current_width;
+    let newline_features = current_height;
+    (unpadded_features, newline_features)
+}
+
 impl LlavaNext {
     pub fn get_number_of_features(&self, height: usize, width: usize) -> usize {
         let image_size = self.vision_config.image_size;
@@ -65,11 +90,9 @@ impl LlavaNext {
         let npatches = image_size / patch_size;
         let (num_patch_height, num_patch_width) =
             get_anyres_image_grid_shape(height, width, &self.image_grid_pinpoints, image_size);
-        // Ceil
-        let height_of_patch = (height * npatches + width - 1) / width;
-        let unpadded_features = npatches * height_of_patch * num_patch_height * num_patch_width;
-        // They are only added after width
-        let newline_features = height_of_patch * num_patch_width;
+        let (unpadded_features, newline_features) =
+            get_unpadded_features(height, width, npatches, num_patch_height, num_patch_width);
         // The base patch covers the entire image
         let base_features = npatches.pow(2);
         unpadded_features + newline_features + base_features
@@ -84,6 +107,17 @@ pub struct ClipVisionModel {
     patch_size: usize,
 }

+#[derive(Clone, Debug, Serialize, Deserialize)]
+#[serde(tag = "model_type")]
+#[serde(rename_all = "snake_case")]
+pub struct Idefics2 {}
+
+impl Idefics2 {
+    pub fn get_number_of_features(&self, _height: usize, _width: usize) -> usize {
+        320
+    }
+}
+
 #[derive(Clone, Debug, Serialize, Deserialize)]
 #[serde(tag = "model_type")]
 #[serde(rename_all = "snake_case")]
@@ -92,6 +126,7 @@ pub enum Config {
     ClipVisionModel(ClipVisionModel),
     Mistral,
     Idefics,
+    Idefics2(Idefics2),
     Ssm,
     GptBigcode,
     Santacoder,
@@ -146,13 +181,17 @@ mod test {
             ],
         };

+        let slots = config.get_number_of_features(20, 20);
+        assert_eq!(slots, 1176);
         let slots = config.get_number_of_features(640, 640);
         assert_eq!(slots, 2928);
         let slots = config.get_number_of_features(480, 640);
         assert_eq!(slots, 2340);
         let slots = config.get_number_of_features(899, 1024);
-        assert_eq!(slots, 2732);
+        assert_eq!(slots, 2634);
         let slots = config.get_number_of_features(1024, 899);
-        assert_eq!(slots, 3320);
+        assert_eq!(slots, 2640);
+        let slots = config.get_number_of_features(1067, 1600);
+        assert_eq!(slots, 2144);
     }
 }
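Editor's note: as a sanity check on the feature-count math in the hunk above, here is a minimal standalone sketch of get_unpadded_features. The image_size of 336, patch_size of 14 (hence npatches = 24) and the 2x2 grid used for the 640x640 case are assumed values for illustration only; they are not stated in this diff.

// Standalone sketch of the LLaVA-Next feature-count arithmetic above.
// Assumed values: image_size = 336, patch_size = 14 (npatches = 24) and a
// 2x2 anyres grid for a 640x640 image; illustrative, not from this diff.
fn get_unpadded_features(
    height: usize,
    width: usize,
    npatches: usize,
    num_patch_height: usize,
    num_patch_width: usize,
) -> (usize, usize) {
    let current_height = npatches * num_patch_height;
    let current_width = npatches * num_patch_width;
    let aspect_ratio = width as f64 / height as f64;
    let current_aspect_ratio = current_width as f64 / current_height as f64;
    let (current_height, current_width) = if aspect_ratio > current_aspect_ratio {
        ((height * current_width) / width, current_width)
    } else {
        (current_height, (width * current_height) / height)
    };
    (current_height * current_width, current_height)
}

fn main() {
    let npatches = 336 / 14; // 24 patches per side
    let (unpadded, newline) = get_unpadded_features(640, 640, npatches, 2, 2);
    // 48 * 48 = 2304 unpadded features + 48 newline features + 24^2 = 576 base
    // features gives 2928, matching the 640x640 assertion in the test above.
    assert_eq!(unpadded + newline + npatches * npatches, 2928);
}

Under these assumptions the unpadded and newline terms replace the old height_of_patch computation, which is why the 899x1024 and 1024x899 expectations in the test change.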
@@ -73,9 +73,9 @@ pub struct HubTokenizerConfig {
 }

 impl HubTokenizerConfig {
-    pub fn from_file(filename: &std::path::Path) -> Self {
-        let content = std::fs::read_to_string(filename).unwrap();
-        serde_json::from_str(&content).unwrap_or_default()
+    pub fn from_file<P: AsRef<std::path::Path>>(filename: P) -> Option<Self> {
+        let content = std::fs::read_to_string(filename).ok()?;
+        serde_json::from_str(&content).ok()
     }
 }
@@ -116,6 +116,7 @@ mod token_serde {
                     ))
                 }
             }
+            Value::Null => Ok(None),
             _ => Err(de::Error::custom("invalid token format")),
         }
     }
@@ -168,9 +169,12 @@ pub struct Info {
 #[derive(Clone, Debug, Deserialize, ToSchema, Default)]
 pub(crate) struct GenerateParameters {
+    /// Generate best_of sequences and return the one if the highest token logprobs.
     #[serde(default)]
     #[schema(exclusive_minimum = 0, nullable = true, default = "null", example = 1)]
     pub best_of: Option<usize>,

+    /// The value used to module the logits distribution.
     #[serde(default)]
     #[schema(
         exclusive_minimum = 0.0,
@@ -179,6 +183,9 @@ pub(crate) struct GenerateParameters {
         example = 0.5
     )]
     pub temperature: Option<f32>,

+    /// The parameter for repetition penalty. 1.0 means no penalty.
+    /// See [this paper](https://arxiv.org/pdf/1909.05858.pdf) for more details.
     #[serde(default)]
     #[schema(
         exclusive_minimum = 0.0,
@@ -187,6 +194,10 @@ pub(crate) struct GenerateParameters {
         example = 1.03
     )]
     pub repetition_penalty: Option<f32>,

+    /// The parameter for frequency penalty. 1.0 means no penalty
+    /// Penalize new tokens based on their existing frequency in the text so far,
+    /// decreasing the model's likelihood to repeat the same line verbatim.
     #[serde(default)]
     #[schema(
         exclusive_minimum = -2.0,
@@ -195,9 +206,13 @@ pub(crate) struct GenerateParameters {
         example = 0.1
     )]
     pub frequency_penalty: Option<f32>,

+    /// The number of highest probability vocabulary tokens to keep for top-k-filtering.
     #[serde(default)]
     #[schema(exclusive_minimum = 0, nullable = true, default = "null", example = 10)]
     pub top_k: Option<i32>,

+    /// Top-p value for nucleus sampling.
     #[serde(default)]
     #[schema(
         exclusive_minimum = 0.0,
@@ -207,6 +222,9 @@ pub(crate) struct GenerateParameters {
         example = 0.95
     )]
     pub top_p: Option<f32>,

+    /// Typical Decoding mass
+    /// See [Typical Decoding for Natural Language Generation](https://arxiv.org/abs/2202.00666) for more information.
     #[serde(default)]
     #[schema(
         exclusive_minimum = 0.0,
@@ -216,30 +234,48 @@ pub(crate) struct GenerateParameters {
         example = 0.95
     )]
     pub typical_p: Option<f32>,

+    /// Activate logits sampling.
     #[serde(default)]
     #[schema(default = "false", example = true)]
     pub do_sample: bool,

+    /// Maximum number of tokens to generate.
     #[serde(default = "default_max_new_tokens")]
     #[schema(nullable = true, default = "100", example = "20")]
     pub max_new_tokens: Option<u32>,

+    /// Whether to prepend the prompt to the generated text
     #[serde(default)]
     #[schema(nullable = true, default = "null", example = false)]
     pub return_full_text: Option<bool>,

+    /// Stop generating tokens if a member of `stop` is generated.
     #[serde(default)]
     #[schema(inline, max_items = 4, example = json ! (["photographer"]))]
     pub stop: Vec<String>,

+    /// Truncate inputs tokens to the given size.
     #[serde(default)]
     #[schema(nullable = true, default = "null", example = "null")]
     pub truncate: Option<usize>,

+    /// Watermarking with [A Watermark for Large Language Models](https://arxiv.org/abs/2301.10226).
     #[serde(default)]
     #[schema(default = "false", example = true)]
     pub watermark: bool,

+    /// Whether to return generation details.
     #[serde(default)]
     #[schema(default = "true")]
     pub details: bool,

+    /// Whether to return decoder input token logprobs and ids.
     #[serde(default)]
     #[schema(default = "false")]
     pub decoder_input_details: bool,

+    /// Random sampling seed.
     #[serde(default)]
     #[schema(
         exclusive_minimum = 0,
@@ -248,9 +284,13 @@ pub(crate) struct GenerateParameters {
         example = "null"
     )]
     pub seed: Option<u64>,

+    /// The number of highest probability vocabulary tokens to keep for top-n-filtering.
     #[serde(default)]
     #[schema(exclusive_minimum = 0, nullable = true, default = "null", example = 5)]
     pub top_n_tokens: Option<u32>,

+    /// Grammar constraints for the generation.
     #[serde(default)]
     #[schema(nullable = true, default = "null", example = "null")]
     pub grammar: Option<GrammarType>,
@@ -549,7 +589,9 @@ pub(crate) struct ChatCompletionChoice {
 #[derive(Clone, Debug, Deserialize, Serialize, ToSchema)]
 pub(crate) struct ChatCompletionDelta {
     #[schema(example = "user")]
-    pub role: String,
+    // TODO Modify this to a true enum.
+    #[serde(default, skip_serializing_if = "Option::is_none")]
+    pub role: Option<String>,
     #[serde(default, skip_serializing_if = "Option::is_none")]
     #[schema(example = "What is Deep Learning?")]
     pub content: Option<String>,
@@ -583,6 +625,31 @@ impl ChatCompletionChunk {
         logprobs: Option<ChatCompletionLogprobs>,
         finish_reason: Option<String>,
     ) -> Self {
+        let delta = match (delta, tool_calls) {
+            (Some(delta), _) => ChatCompletionDelta {
+                role: Some("assistant".to_string()),
+                content: Some(delta),
+                tool_calls: None,
+            },
+            (None, Some(tool_calls)) => ChatCompletionDelta {
+                role: Some("assistant".to_string()),
+                content: None,
+                tool_calls: Some(DeltaToolCall {
+                    index: 0,
+                    id: String::new(),
+                    r#type: "function".to_string(),
+                    function: Function {
+                        name: None,
+                        arguments: tool_calls[0].to_string(),
+                    },
+                }),
+            },
+            (None, None) => ChatCompletionDelta {
+                role: None,
+                content: None,
+                tool_calls: None,
+            },
+        };
         Self {
             id: String::new(),
             object: "text_completion".to_string(),
@@ -591,19 +658,7 @@ impl ChatCompletionChunk {
             system_fingerprint,
             choices: vec![ChatCompletionChoice {
                 index: 0,
-                delta: ChatCompletionDelta {
-                    role: "assistant".to_string(),
-                    content: delta,
-                    tool_calls: tool_calls.map(|tc| DeltaToolCall {
-                        index: 0,
-                        id: String::new(),
-                        r#type: "function".to_string(),
-                        function: Function {
-                            name: None,
-                            arguments: tc[0].to_string(),
-                        },
-                    }),
-                },
+                delta,
                 logprobs,
                 finish_reason,
             }],
@@ -829,12 +884,75 @@ pub(crate) struct ToolCall {
     pub function: FunctionDefinition,
 }

-#[derive(Clone, Deserialize, ToSchema, Serialize)]
+#[derive(Clone, Deserialize, Serialize, ToSchema, Default, Debug)]
+pub(crate) struct Text {
+    #[serde(default)]
+    pub text: String,
+}
+
+#[derive(Clone, Deserialize, Serialize, ToSchema, Default, Debug)]
+pub(crate) struct ImageUrl {
+    #[serde(default)]
+    pub url: String,
+}
+
+#[derive(Clone, Deserialize, Serialize, ToSchema, Default, Debug)]
+pub(crate) struct Content {
+    pub r#type: String,
+    #[serde(default, skip_serializing_if = "Option::is_none")]
+    pub text: Option<String>,
+    #[serde(default, skip_serializing_if = "Option::is_none")]
+    pub image_url: Option<ImageUrl>,
+}
+
+mod message_content_serde {
+    use super::*;
+    use serde::de;
+    use serde::Deserializer;
+    use serde_json::Value;
+
+    pub fn deserialize<'de, D>(deserializer: D) -> Result<Option<String>, D::Error>
+    where
+        D: Deserializer<'de>,
+    {
+        let value = Value::deserialize(deserializer)?;
+        match value {
+            Value::String(s) => Ok(Some(s)),
+            Value::Array(arr) => {
+                let results: Result<Vec<String>, _> = arr
+                    .into_iter()
+                    .map(|v| {
+                        let content: Content =
+                            serde_json::from_value(v).map_err(de::Error::custom)?;
+                        match content.r#type.as_str() {
+                            "text" => Ok(content.text.unwrap_or_default()),
+                            "image_url" => {
+                                if let Some(url) = content.image_url {
+                                    Ok(format!("![]({})", url.url))
+                                } else {
+                                    Ok(String::new())
+                                }
+                            }
+                            _ => Err(de::Error::custom("invalid content type")),
+                        }
+                    })
+                    .collect();
+                results.map(|strings| Some(strings.join("")))
+            }
+            Value::Null => Ok(None),
+            _ => Err(de::Error::custom("invalid token format")),
+        }
+    }
+}
+
+#[derive(Clone, Deserialize, ToSchema, Serialize, Debug)]
 pub(crate) struct Message {
     #[schema(example = "user")]
     pub role: String,
     #[serde(skip_serializing_if = "Option::is_none")]
     #[schema(example = "My name is David and I")]
+    #[serde(deserialize_with = "message_content_serde::deserialize")]
     pub content: Option<String>,
     #[serde(default, skip_serializing_if = "Option::is_none")]
     #[schema(example = "\"David\"")]
...
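Editor's note: to make the new message_content_serde behaviour concrete, here is a rough self-contained sketch of the flattening it performs, written against serde_json only. The router's Content and ImageUrl types are not reused, the example URL is made up, and unlike the real deserializer this sketch silently skips unknown part types instead of erroring.

// Sketch of the content flattening performed by message_content_serde above:
// a plain string passes through, an OpenAI-style array of parts collapses into
// one string where each image_url part becomes a "![](url)" markdown image.
use serde_json::{json, Value};

fn flatten_content(value: &Value) -> Option<String> {
    match value {
        Value::String(s) => Some(s.clone()),
        Value::Array(parts) => Some(
            parts
                .iter()
                .map(|part| match part["type"].as_str() {
                    Some("text") => part["text"].as_str().unwrap_or_default().to_string(),
                    Some("image_url") => format!(
                        "![]({})",
                        part["image_url"]["url"].as_str().unwrap_or_default()
                    ),
                    _ => String::new(),
                })
                .collect::<Vec<_>>()
                .join(""),
        ),
        _ => None,
    }
}

fn main() {
    let content = json!([
        { "type": "text", "text": "What is in this image?" },
        { "type": "image_url", "image_url": { "url": "https://example.com/cat.png" } }
    ]);
    assert_eq!(
        flatten_content(&content).as_deref(),
        Some("What is in this image?![](https://example.com/cat.png)")
    );
}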
 use axum::http::HeaderValue;
 use clap::Parser;
 use hf_hub::api::tokio::{Api, ApiBuilder, ApiRepo};
-use hf_hub::{Repo, RepoType};
+use hf_hub::{Cache, Repo, RepoType};
 use opentelemetry::sdk::propagation::TraceContextPropagator;
 use opentelemetry::sdk::trace;
 use opentelemetry::sdk::trace::Sampler;
@@ -11,7 +11,7 @@ use opentelemetry_otlp::WithExportConfig;
 use std::fs::File;
 use std::io::BufReader;
 use std::net::{IpAddr, Ipv4Addr, SocketAddr};
-use std::path::Path;
+use std::path::{Path, PathBuf};
 use text_generation_client::{ClientError, ShardedClient};
 use text_generation_router::config::Config;
 use text_generation_router::{server, HubModelInfo, HubTokenizerConfig};
@@ -162,7 +162,6 @@ async fn main() -> Result<(), RouterError> {
     // Tokenizer instance
     // This will only be used to validate payloads
     let local_path = Path::new(&tokenizer_name);
-    let local_model = local_path.exists() && local_path.is_dir();

     // Shared API builder initialization
     let api_builder = || {
@@ -181,109 +180,113 @@ async fn main() -> Result<(), RouterError> {
     let use_api = revision.is_some() || !local_path.exists() || !local_path.is_dir();

     // Initialize API if needed
+    #[derive(Clone)]
+    enum Type {
+        Api(Api),
+        Cache(Cache),
+        None,
+    }
     let api = if use_api {
-        tracing::info!("Using the Hugging Face API");
-        match api_builder().build() {
-            Ok(api) => Some(api),
-            Err(_) => {
-                tracing::warn!("Unable to build the Hugging Face API");
-                None
+        if std::env::var("HF_HUB_OFFLINE") == Ok("1".to_string()) {
+            let cache = Cache::default();
+            tracing::warn!("Offline mode active using cache defaults");
+            Type::Cache(cache)
+        } else {
+            tracing::info!("Using the Hugging Face API");
+            match api_builder().build() {
+                Ok(api) => Type::Api(api),
+                Err(_) => {
+                    tracing::warn!("Unable to build the Hugging Face API");
+                    Type::None
+                }
             }
         }
     } else {
-        None
+        Type::None
     };

     // Load tokenizer and model info
-    let (tokenizer, model_info, config) = if local_model {
-        let tokenizer = Tokenizer::from_file(local_path.join("tokenizer.json")).ok();
-        let model_info = HubModelInfo {
-            model_id: tokenizer_name.to_string(),
-            sha: None,
-            pipeline_tag: None,
-        };
-        let config: Option<Config> = std::fs::read_to_string(local_path.join("config.json"))
-            .ok()
-            .as_ref()
-            .and_then(|c| serde_json::from_str(c).ok());
-
-        (tokenizer, model_info, config)
-    } else if let Some(api) = api.clone() {
-        let api_repo = api.repo(Repo::with_revision(
-            tokenizer_name.to_string(),
-            RepoType::Model,
-            revision.clone().unwrap_or_else(|| "main".to_string()),
-        ));
-
-        let tokenizer = match api_repo.get("tokenizer.json").await {
-            Ok(tokenizer_filename) => Tokenizer::from_file(tokenizer_filename).ok(),
-            Err(_) => get_base_tokenizer(&api, &api_repo).await,
-        };
-        let config: Option<Config> = api_repo.get("config.json").await.ok().and_then(|filename| {
-            std::fs::read_to_string(filename)
-                .ok()
-                .as_ref()
-                .and_then(|c| {
-                    let config: Result<Config, _> = serde_json::from_str(c);
-                    if let Err(err) = &config {
-                        tracing::warn!("Could not parse config {err:?}");
-                    }
-                    config.ok()
-                })
-        });
-
-        let model_info = get_model_info(&api_repo).await.unwrap_or_else(|| {
-            tracing::warn!("Could not retrieve model info from the Hugging Face hub.");
-            HubModelInfo {
-                model_id: tokenizer_name.to_string(),
-                sha: None,
-                pipeline_tag: None,
-            }
-        });
-
-        (tokenizer, model_info, config)
-    } else {
-        // No API and no local model
-        return Err(RouterError::ArgumentValidation(
-            "No local model found and no revision specified".to_string(),
-        ));
-    };
+    let (tokenizer_filename, config_filename, tokenizer_config_filename, model_info) = match api {
+        Type::None => (
+            Some(local_path.join("tokenizer.json")),
+            Some(local_path.join("config.json")),
+            Some(local_path.join("tokenizer_config.json")),
+            None,
+        ),
+        Type::Api(api) => {
+            let api_repo = api.repo(Repo::with_revision(
+                tokenizer_name.to_string(),
+                RepoType::Model,
+                revision.clone().unwrap_or_else(|| "main".to_string()),
+            ));
+
+            let tokenizer_filename = match api_repo.get("tokenizer.json").await {
+                Ok(tokenizer_filename) => Some(tokenizer_filename),
+                Err(_) => get_base_tokenizer(&api, &api_repo).await,
+            };
+            let config_filename = api_repo.get("config.json").await.ok();
+            let tokenizer_config_filename = api_repo.get("tokenizer_config.json").await.ok();
+
+            let model_info = if let Some(model_info) = get_model_info(&api_repo).await {
+                Some(model_info)
+            } else {
+                tracing::warn!("Could not retrieve model info from the Hugging Face hub.");
+                None
+            };
+            (
+                tokenizer_filename,
+                config_filename,
+                tokenizer_config_filename,
+                model_info,
+            )
+        }
+        Type::Cache(cache) => {
+            let repo = cache.repo(Repo::with_revision(
+                tokenizer_name.to_string(),
+                RepoType::Model,
+                revision.clone().unwrap_or_else(|| "main".to_string()),
+            ));
+            (
+                repo.get("tokenizer.json"),
+                repo.get("config.json"),
+                repo.get("tokenizer_config.json"),
+                None,
+            )
+        }
+    };
+
+    let tokenizer: Option<Tokenizer> =
+        tokenizer_filename.and_then(|filename| Tokenizer::from_file(filename).ok());
+    let config: Option<Config> = config_filename.and_then(|filename| {
+        std::fs::read_to_string(filename)
+            .ok()
+            .as_ref()
+            .and_then(|c| {
+                let config: Result<Config, _> = serde_json::from_str(c);
+                if let Err(err) = &config {
+                    tracing::warn!("Could not parse config {err:?}");
+                }
+                config.ok()
+            })
+    });
+    let model_info = model_info.unwrap_or_else(|| HubModelInfo {
+        model_id: tokenizer_name.to_string(),
+        sha: None,
+        pipeline_tag: None,
+    });
+
+    // Read the JSON contents of the file as an instance of 'HubTokenizerConfig'.
+    let tokenizer_config: Option<HubTokenizerConfig> = if let Some(filename) = tokenizer_config_path
+    {
+        HubTokenizerConfig::from_file(filename)
+    } else {
+        tokenizer_config_filename.and_then(HubTokenizerConfig::from_file)
+    };
+    let tokenizer_config = tokenizer_config.unwrap_or_else(|| {
+        tracing::warn!("Could not find tokenizer config locally and no API specified");
+        HubTokenizerConfig::default()
+    });

     tracing::info!("Using config {config:?}");

-    // Load tokenizer config if found locally, or check if we can get it from the API if needed
-    let tokenizer_config = if let Some(path) = tokenizer_config_path {
-        tracing::info!("Using local tokenizer config from user specified path");
-        HubTokenizerConfig::from_file(&std::path::PathBuf::from(path))
-    } else if local_model {
-        tracing::info!("Using local tokenizer config");
-        HubTokenizerConfig::from_file(&local_path.join("tokenizer_config.json"))
-    } else {
-        match api {
-            Some(api) => {
-                tracing::info!("Using the Hugging Face API to retrieve tokenizer config");
-                let repo = Repo::with_revision(
-                    tokenizer_name.to_string(),
-                    RepoType::Model,
-                    revision.unwrap_or("main".to_string()),
-                );
-                get_tokenizer_config(&api.repo(repo))
-                    .await
-                    .unwrap_or_else(|| {
-                        tracing::warn!(
-                            "Could not retrieve tokenizer config from the Hugging Face hub."
-                        );
-                        HubTokenizerConfig::default()
-                    })
-            }
-            None => {
-                tracing::warn!("Could not find tokenizer config locally and no API specified");
-                HubTokenizerConfig::default()
-            }
-        }
-    };
-
     if tokenizer.is_none() {
         tracing::warn!("Could not find a fast tokenizer implementation for {tokenizer_name}");
         tracing::warn!("Rust input length validation and truncation is disabled");
@@ -480,7 +483,7 @@ pub async fn get_model_info(api: &ApiRepo) -> Option<HubModelInfo> {
 }

 /// get base tokenizer
-pub async fn get_base_tokenizer(api: &Api, api_repo: &ApiRepo) -> Option<Tokenizer> {
+pub async fn get_base_tokenizer(api: &Api, api_repo: &ApiRepo) -> Option<PathBuf> {
     let config_filename = api_repo.get("config.json").await.ok()?;

     // Open the file in read-only mode with buffer.
@@ -497,8 +500,7 @@ pub async fn get_base_tokenizer(api: &Api, api_repo: &ApiRepo) -> Option<Tokeniz
             "main".to_string(),
         ));

-        let tokenizer_filename = api_base_repo.get("tokenizer.json").await.ok()?;
-        Tokenizer::from_file(tokenizer_filename).ok()
+        api_base_repo.get("tokenizer.json").await.ok()
     } else {
         None
     }
...
@@ -1000,6 +1000,7 @@ async fn chat_completions(
         tools,
         tool_choice,
         tool_prompt,
+        temperature,
         ..
     } = req;
@@ -1008,6 +1009,11 @@ async fn chat_completions(
     let logprobs = logprobs.unwrap_or(false);
     let tool_prompt = tool_prompt.unwrap_or_default();
     let stop = stop.unwrap_or_default();
+    // enable greedy only when temperature is 0
+    let (do_sample, temperature) = match temperature {
+        Some(temperature) if temperature == 0.0 => (false, None),
+        other => (true, other),
+    };

     // extract tool grammar if present
     let tool_grammar = match ToolGrammar::apply(tools, tool_choice) {
@@ -1054,13 +1060,13 @@ async fn chat_completions(
         inputs: inputs.to_string(),
         parameters: GenerateParameters {
             best_of: None,
-            temperature: req.temperature,
+            temperature,
             repetition_penalty,
             frequency_penalty: req.frequency_penalty,
             top_k: None,
             top_p: req.top_p,
             typical_p: None,
-            do_sample: true,
+            do_sample,
             max_new_tokens,
             return_full_text: None,
             stop,
@@ -1097,7 +1103,13 @@ async fn chat_completions(
                 let (content, tool_calls) = if tool_grammar.is_some() {
                     (None, Some(vec![stream_token.token.text]))
                 } else {
-                    (Some(stream_token.token.text), None)
+                    let content = if !stream_token.token.special {
+                        Some(stream_token.token.text)
+                    } else {
+                        None
+                    };
+
+                    (content, None)
                 };

                 event
...
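Editor's note: a small sketch of the new temperature handling in chat_completions, assuming (as the hunk above suggests) that only an explicit temperature of 0.0 switches a request to greedy decoding, while any other value, or no value at all, keeps sampling enabled.

// Sketch of the (do_sample, temperature) mapping introduced above:
// temperature == 0.0 selects greedy decoding, everything else keeps sampling.
fn map_temperature(temperature: Option<f32>) -> (bool, Option<f32>) {
    match temperature {
        Some(t) if t == 0.0 => (false, None),
        other => (true, other),
    }
}

fn main() {
    assert_eq!(map_temperature(Some(0.0)), (false, None)); // greedy
    assert_eq!(map_temperature(Some(0.7)), (true, Some(0.7))); // sampling
    assert_eq!(map_temperature(None), (true, None)); // sampling with defaults
}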
@@ -540,7 +540,57 @@ fn prepare_input(
                 inputs = modified_inputs;
                 tokenizer_query
             }
-            Some(Config::Idefics) => RE.replace_all(&inputs, "<image>").into(),
+            Some(Config::Idefics2(config)) => {
+                let mut modified_inputs = String::with_capacity(inputs.len());
+                let mut tokenizer_query = String::with_capacity(inputs.len());
+                let mut start = 0;
+                for chunk in RE.find_iter(&inputs) {
+                    let chunk_start = chunk.start();
+                    let chunk_end = chunk.end();
+                    if chunk_start != start {
+                        modified_inputs.push_str(&inputs[start..chunk_start]);
+                        tokenizer_query.push_str(&inputs[start..chunk_start]);
+                    }
+                    let (image_uri, height, width) = fetch_image(&inputs[chunk_start..chunk_end])?;
+                    let slots = config.get_number_of_features(height, width);
+                    tokenizer_query.push_str("<fake_token_around_image>");
+                    tokenizer_query.push_str(&"<image>".repeat(slots));
+                    tokenizer_query.push_str("<fake_token_around_image>");
+                    modified_inputs.push_str(&image_uri);
+                    start = chunk_end;
+                }
+                if start != inputs.len() - 1 {
+                    modified_inputs.push_str(&inputs[start..]);
+                    tokenizer_query.push_str(&inputs[start..]);
+                }
+                inputs = modified_inputs;
+                tokenizer_query
+            }
+            Some(Config::Idefics) => {
+                let mut modified_inputs = String::with_capacity(inputs.len());
+                let mut tokenizer_query = String::with_capacity(inputs.len());
+                let mut start = 0;
+                for chunk in RE.find_iter(&inputs) {
+                    let chunk_start = chunk.start();
+                    let chunk_end = chunk.end();
+                    if chunk_start != start {
+                        modified_inputs.push_str(&inputs[start..chunk_start]);
+                        tokenizer_query.push_str(&inputs[start..chunk_start]);
+                    }
+                    let (image_uri, _height, _width) = fetch_image(&inputs[chunk_start..chunk_end])?;
+                    let slots = 1;
+                    tokenizer_query.push_str(&"<image>".repeat(slots));
+                    modified_inputs.push_str(&image_uri);
+                    start = chunk_end;
+                }
+                if start != inputs.len() - 1 {
+                    modified_inputs.push_str(&inputs[start..]);
+                    tokenizer_query.push_str(&inputs[start..]);
+                }
+                inputs = modified_inputs;
+                tokenizer_query
+            }
             _ => inputs.clone(),
         };
...
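Editor's note: to make the Idefics2 branch above concrete, here is a rough sketch of the placeholder expansion with the image fetch left out; the prompt and URL are made up, and the original image markdown stands in for the fetched URI that the real code inserts.

// Sketch of the Idefics2 placeholder expansion above, with fetch_image omitted.
// Each image contributes a fixed 320 slots (Idefics2::get_number_of_features).
fn expand_idefics2_image(text_before: &str, image_markdown: &str) -> (String, String) {
    let slots = 320;
    // The model input keeps the image reference; the tokenizer query replaces it
    // with <image> placeholders bracketed by <fake_token_around_image> markers.
    let modified_input = format!("{text_before}{image_markdown}");
    let tokenizer_query = format!(
        "{text_before}<fake_token_around_image>{}<fake_token_around_image>",
        "<image>".repeat(slots)
    );
    (modified_input, tokenizer_query)
}

fn main() {
    let (input, query) =
        expand_idefics2_image("What is this? ", "![](https://example.com/cat.png)");
    assert!(input.ends_with("![](https://example.com/cat.png)"));
    assert_eq!(query.matches("<image>").count(), 320);
}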
 vllm-cuda:
 	# Clone vllm
 	pip install -U ninja packaging --no-cache-dir
-	git clone https://github.com/OlivierDehaene/vllm.git vllm
+	git clone https://github.com/Narsil/vllm.git vllm

 build-vllm-cuda: vllm-cuda
-	cd vllm && git fetch && git checkout 4bec8cee87f6bb8cebaec297029713cd2082e0b2
+	cd vllm && git fetch && git checkout b5dfc61db88a81069e45b44f7cc99bd9e62a60fa
 	cd vllm && python setup.py build

 install-vllm-cuda: build-vllm-cuda
...
This diff is collapsed.
 [tool.poetry]
 name = "text-generation-server"
-version = "2.0.1"
+version = "2.0.2"
 description = "Text Generation Inference Python gRPC Server"
 authors = ["Olivier Dehaene <olivier@huggingface.co>"]
@@ -31,10 +31,12 @@ einops = "^0.6.1"
 texttable = { version = "^1.6.7", optional = true }
 datasets = { version = "^2.14.0", optional = true }
 peft = { version = "^0.10", optional = true }
-torch = { version = "^2.1.1", optional = true }
+torch = { version = "^2.3.0", optional = true }
 scipy = "^1.11.1"
 pillow = "^10.0.0"
 outlines= { version = "^0.0.36", optional = true }
+prometheus-client = "^0.20.0"
+py-cpuinfo = "^9.0.0"

 [tool.poetry.extras]
 torch = ["torch"]
...
@@ -5,13 +5,13 @@ click==8.1.7 ; python_version >= "3.9" and python_version < "3.13"
 colorama==0.4.6 ; python_version >= "3.9" and python_version < "3.13" and (sys_platform == "win32" or platform_system == "Windows")
 deprecated==1.2.14 ; python_version >= "3.9" and python_version < "3.13"
 einops==0.6.1 ; python_version >= "3.9" and python_version < "3.13"
-filelock==3.13.4 ; python_version >= "3.9" and python_version < "3.13"
-fsspec==2024.2.0 ; python_version >= "3.9" and python_version < "3.13"
+filelock==3.14.0 ; python_version >= "3.9" and python_version < "3.13"
+fsspec==2024.3.1 ; python_version >= "3.9" and python_version < "3.13"
 googleapis-common-protos==1.63.0 ; python_version >= "3.9" and python_version < "3.13"
 grpc-interceptor==0.15.4 ; python_version >= "3.9" and python_version < "3.13"
-grpcio-reflection==1.62.1 ; python_version >= "3.9" and python_version < "3.13"
-grpcio-status==1.62.1 ; python_version >= "3.9" and python_version < "3.13"
-grpcio==1.62.1 ; python_version >= "3.9" and python_version < "3.13"
+grpcio-reflection==1.62.2 ; python_version >= "3.9" and python_version < "3.13"
+grpcio-status==1.62.2 ; python_version >= "3.9" and python_version < "3.13"
+grpcio==1.62.2 ; python_version >= "3.9" and python_version < "3.13"
 hf-transfer==0.1.6 ; python_version >= "3.9" and python_version < "3.13"
 huggingface-hub==0.19.4 ; python_version >= "3.9" and python_version < "3.13"
 idna==3.7 ; python_version >= "3.9" and python_version < "3.13"
@@ -28,9 +28,11 @@ opentelemetry-sdk==1.15.0 ; python_version >= "3.9" and python_version < "3.13"
 opentelemetry-semantic-conventions==0.36b0 ; python_version >= "3.9" and python_version < "3.13"
 packaging==24.0 ; python_version >= "3.9" and python_version < "3.13"
 pillow==10.3.0 ; python_version >= "3.9" and python_version < "3.13"
+prometheus-client==0.20.0 ; python_version >= "3.9" and python_version < "3.13"
 protobuf==4.25.3 ; python_version >= "3.9" and python_version < "3.13"
+py-cpuinfo==9.0.0 ; python_version >= "3.9" and python_version < "3.13"
 pyyaml==6.0.1 ; python_version >= "3.9" and python_version < "3.13"
-regex==2024.4.16 ; python_version >= "3.9" and python_version < "3.13"
+regex==2024.4.28 ; python_version >= "3.9" and python_version < "3.13"
 requests==2.31.0 ; python_version >= "3.9" and python_version < "3.13"
 safetensors==0.4.3 ; python_version >= "3.9" and python_version < "3.13"
 scipy==1.13.0 ; python_version >= "3.9" and python_version < "3.13"
@@ -38,7 +40,7 @@ sentencepiece==0.1.99 ; python_version >= "3.9" and python_version < "3.13"
 setuptools==69.5.1 ; python_version >= "3.9" and python_version < "3.13"
 tokenizers==0.19.1 ; python_version >= "3.9" and python_version < "3.13"
 tqdm==4.66.2 ; python_version >= "3.9" and python_version < "3.13"
-transformers==4.40.0 ; python_version >= "3.9" and python_version < "3.13"
+transformers==4.40.1 ; python_version >= "3.9" and python_version < "3.13"
 typer==0.6.1 ; python_version >= "3.9" and python_version < "3.13"
 typing-extensions==4.11.0 ; python_version >= "3.9" and python_version < "3.13"
 urllib3==2.2.1 ; python_version >= "3.9" and python_version < "3.13"
...
@@ -5,13 +5,13 @@ click==8.1.7 ; python_version >= "3.9" and python_version < "3.13"
 colorama==0.4.6 ; python_version >= "3.9" and python_version < "3.13" and (sys_platform == "win32" or platform_system == "Windows")
 deprecated==1.2.14 ; python_version >= "3.9" and python_version < "3.13"
 einops==0.6.1 ; python_version >= "3.9" and python_version < "3.13"
-filelock==3.13.4 ; python_version >= "3.9" and python_version < "3.13"
-fsspec==2024.2.0 ; python_version >= "3.9" and python_version < "3.13"
+filelock==3.14.0 ; python_version >= "3.9" and python_version < "3.13"
+fsspec==2024.3.1 ; python_version >= "3.9" and python_version < "3.13"
 googleapis-common-protos==1.63.0 ; python_version >= "3.9" and python_version < "3.13"
 grpc-interceptor==0.15.4 ; python_version >= "3.9" and python_version < "3.13"
-grpcio-reflection==1.62.1 ; python_version >= "3.9" and python_version < "3.13"
-grpcio-status==1.62.1 ; python_version >= "3.9" and python_version < "3.13"
-grpcio==1.62.1 ; python_version >= "3.9" and python_version < "3.13"
+grpcio-reflection==1.62.2 ; python_version >= "3.9" and python_version < "3.13"
+grpcio-status==1.62.2 ; python_version >= "3.9" and python_version < "3.13"
+grpcio==1.62.2 ; python_version >= "3.9" and python_version < "3.13"
 hf-transfer==0.1.6 ; python_version >= "3.9" and python_version < "3.13"
 huggingface-hub==0.19.4 ; python_version >= "3.9" and python_version < "3.13"
 idna==3.7 ; python_version >= "3.9" and python_version < "3.13"
@@ -28,9 +28,11 @@ opentelemetry-sdk==1.15.0 ; python_version >= "3.9" and python_version < "3.13"
 opentelemetry-semantic-conventions==0.36b0 ; python_version >= "3.9" and python_version < "3.13"
 packaging==24.0 ; python_version >= "3.9" and python_version < "3.13"
 pillow==10.3.0 ; python_version >= "3.9" and python_version < "3.13"
+prometheus-client==0.20.0 ; python_version >= "3.9" and python_version < "3.13"
 protobuf==4.25.3 ; python_version >= "3.9" and python_version < "3.13"
+py-cpuinfo==9.0.0 ; python_version >= "3.9" and python_version < "3.13"
 pyyaml==6.0.1 ; python_version >= "3.9" and python_version < "3.13"
-regex==2024.4.16 ; python_version >= "3.9" and python_version < "3.13"
+regex==2024.4.28 ; python_version >= "3.9" and python_version < "3.13"
 requests==2.31.0 ; python_version >= "3.9" and python_version < "3.13"
 safetensors==0.4.3 ; python_version >= "3.9" and python_version < "3.13"
 scipy==1.13.0 ; python_version >= "3.9" and python_version < "3.13"
@@ -38,7 +40,7 @@ sentencepiece==0.1.99 ; python_version >= "3.9" and python_version < "3.13"
 setuptools==69.5.1 ; python_version >= "3.9" and python_version < "3.13"
 tokenizers==0.19.1 ; python_version >= "3.9" and python_version < "3.13"
 tqdm==4.66.2 ; python_version >= "3.9" and python_version < "3.13"
-transformers==4.40.0 ; python_version >= "3.9" and python_version < "3.13"
+transformers==4.40.1 ; python_version >= "3.9" and python_version < "3.13"
 typer==0.6.1 ; python_version >= "3.9" and python_version < "3.13"
 typing-extensions==4.11.0 ; python_version >= "3.9" and python_version < "3.13"
 urllib3==2.2.1 ; python_version >= "3.9" and python_version < "3.13"
...
@@ -68,6 +68,7 @@ try:
     )
     from text_generation_server.models.idefics import IDEFICSSharded
     from text_generation_server.models.llava_next import LlavaNext
+    from text_generation_server.models.idefics2 import Idefics2
     from text_generation_server.models.flash_mistral import FlashMistral

     # from text_generation_server.models.flash_mixtral import FlashMixtral
     from text_generation_server.models.flash_phi import FlashPhi
@@ -327,7 +328,7 @@ def get_model(
             trust_remote_code=trust_remote_code,
         )

-    elif model_type == "llama" or model_type == "baichuan":
+    elif model_type == "llama" or model_type == "baichuan" or model_type == "phi3":
        if FLASH_ATTENTION:
            return FlashLlama(
                model_id,
@@ -579,6 +580,18 @@ def get_model(
            )
        else:
            raise NotImplementedError(FLASH_ATT_ERROR_MESSAGE.format("Idefics"))

+    if model_type == "idefics2":
+        if FLASH_ATTENTION:
+            return Idefics2(
+                model_id,
+                revision,
+                quantize=quantize,
+                use_medusa=use_medusa,
+                dtype=dtype,
+                trust_remote_code=trust_remote_code,
+            )
+        else:
+            raise NotImplementedError(FLASH_ATT_ERROR_MESSAGE.format("Idefics"))
+
     if model_type == "llava_next":
         if FLASH_ATTENTION:
...