Unverified Commit 2b19d671 authored by Nicolas Patry's avatar Nicolas Patry Committed by GitHub
Browse files

Rebase TRT-llm (#2331)

* wip

wip

refacto

refacto

Initial setup for CXX binding to TRTLLM

Working FFI call for TGI and TRTLLM backend

Remove unused parameters annd force tokenizer name to be set

Overall build TRTLLM and deps through CMake build system

Enable end to end CMake build

First version loading engines and making it ready for inference

Remembering to check how we can detect support for chunked context

Move to latest TensorRT-LLM version

Specify which default log level to use depending on CMake build type

make leader executor mode working

unconditionally call InitializeBackend on the FFI layer

bind to CUDA::nvml to retrieve compute capabilities at runtime

updated logic and comment to detect cuda compute capabilities

implement the Stream method to send new tokens through a callback

use spdlog release 1.14.1 moving forward

update trtllm to latest version a96cccafcf6365c128f004f779160951f8c0801c

correctly tell cmake to build dependent tensorrt...
parent 53aec273
/// Batching and inference logic use crate::infer::InferError;
use crate::infer::v3::queue::{Entry, Queue}; use crate::{
use crate::infer::{ ChatTemplateInputs, GrammarType, Message, MessageChunk, TextMessage, TokenizerConfigToken,
GenerateStreamResponse, GeneratedText, InferError, InferStreamResponse, Scheduler,
}; };
use crate::validation::ValidGenerateRequest; use minijinja::{Environment, ErrorKind, Template};
use crate::{FinishReason, PrefillToken, Token}; use minijinja_contrib::pycompat;
use nohash_hasher::IntMap;
use std::sync::{ /// Raise a exception (custom function) used in the chat templates
atomic::{AtomicBool, Ordering}, pub(crate) fn raise_exception(err_text: String) -> Result<String, minijinja::Error> {
Arc, Err(minijinja::Error::new(ErrorKind::SyntaxError, err_text))
}; }
use text_generation_client::v3::{Batch, CachedBatch, Generation, ShardedClient};
use text_generation_client::ClientError;
use tokio::sync::mpsc::error::SendError;
use tokio::sync::{mpsc, Notify, OwnedSemaphorePermit};
use tokio::time::Instant;
use tokio_stream::wrappers::UnboundedReceiverStream;
use tracing::{info_span, instrument, Instrument, Span};
pub(crate) struct SchedulerV3 { #[derive(Clone)]
/// Request queue pub(crate) struct ChatTemplate {
queue: Queue, template: Template<'static, 'static>,
/// Notify batcher on queue appends bos_token: Option<String>,
batching_task_notifier: Arc<Notify>, eos_token: Option<String>,
use_default_tool_template: bool,
} }
impl SchedulerV3 { impl ChatTemplate {
#[allow(clippy::too_many_arguments)]
pub(crate) fn new( pub(crate) fn new(
client: ShardedClient, template: String,
waiting_served_ratio: f32, bos_token: Option<TokenizerConfigToken>,
max_batch_prefill_tokens: u32, eos_token: Option<TokenizerConfigToken>,
max_batch_total_tokens: u32,
max_waiting_tokens: usize,
max_batch_size: Option<usize>,
requires_padding: bool,
window_size: Option<u32>,
speculate: u32,
generation_health: Arc<AtomicBool>,
) -> Self { ) -> Self {
let flashdecoding = if let Ok(flashdecoding) = std::env::var("FLASH_DECODING") { let mut env = Box::new(Environment::new());
matches!(flashdecoding.to_lowercase().as_str(), "1" | "true") // enable things like .strip() or .capitalize()
} else { env.set_unknown_method_callback(pycompat::unknown_method_callback);
false let template_str = template.into_boxed_str();
}; env.add_function("raise_exception", raise_exception);
let block_size = if flashdecoding { 256 } else { 16 };
let queue = Queue::new(
requires_padding,
block_size,
window_size,
speculate,
max_batch_total_tokens,
);
let batching_task_notifier = Arc::new(Notify::new());
// Spawn batching background task that contains all the inference logic // check if contains the tools variable within the template
tokio::spawn(batching_task( let use_default_tool_template =
client, !template_str.as_ref().replace(' ', "").contains("{{tools}}");
waiting_served_ratio, // leaking env and template_str as read-only, static resources for performance.
max_batch_prefill_tokens, let template = Box::leak(env)
max_batch_total_tokens, .template_from_str(Box::leak(template_str))
max_waiting_tokens, .unwrap();
max_batch_size,
queue.clone(),
batching_task_notifier.clone(),
generation_health,
));
Self { Self {
queue, template,
batching_task_notifier, bos_token: bos_token.map(|token| token.as_str().to_string()),
eos_token: eos_token.map(|token| token.as_str().to_string()),
use_default_tool_template,
} }
} }
}
impl Scheduler for SchedulerV3 { pub(crate) fn apply(
#[instrument(skip_all)]
fn schedule(
&self, &self,
request: ValidGenerateRequest, mut messages: Vec<Message>,
permit: OwnedSemaphorePermit, grammar_with_prompt: Option<(GrammarType, String)>,
) -> Result<GenerateStreamResponse, InferError> { ) -> Result<String, InferError> {
// MPSC channel to communicate with the background batching task if self.use_default_tool_template {
let (response_tx, response_rx) = mpsc::unbounded_channel(); if let Some(last_message) = messages.last_mut() {
let input_length = request.input_length; if let Some((GrammarType::Json(tools), tool_prompt)) = grammar_with_prompt {
last_message.content.push(MessageChunk::Text {
// Append the request to the queue text: format!("\n---\n{}\n{}", tool_prompt, tools),
self.queue.append(Entry {
request,
response_tx,
span: Span::current(),
temp_span: None,
queue_time: Instant::now(),
batch_time: None,
block_allocation: None,
});
// Notify the background task that we have a new entry in the queue that needs
// to be batched
self.batching_task_notifier.notify_one();
// Return stream
Ok((
permit,
input_length,
UnboundedReceiverStream::new(response_rx),
))
}
}
/// Batching logic
/// Will be launched in a background Tokio task
///
/// Batches requests and sends them to the inference server
#[allow(clippy::too_many_arguments)]
pub(crate) async fn batching_task(
mut client: ShardedClient,
waiting_served_ratio: f32,
max_batch_prefill_tokens: u32,
max_batch_total_tokens: u32,
max_waiting_tokens: usize,
max_batch_size: Option<usize>,
queue: Queue,
notifier: Arc<Notify>,
generation_health: Arc<AtomicBool>,
) {
// Infinite loop
loop {
// Wait for a notification from the Infer struct
notifier.notified().await;
// Get the next batch from the queue
// This batch might be smaller than the maximum batch size if there are not enough requests
// waiting in the queue
while let Some((mut entries, batch, span)) = queue
.next_batch(
None,
max_batch_size,
max_batch_prefill_tokens,
max_batch_total_tokens,
)
.await
{
let mut cached_batch = prefill(&mut client, batch, &mut entries, &generation_health)
.instrument(span)
.await;
let mut waiting_tokens = 1;
// We loop until we do not receive any cached batch from the inference server (== until
// all requests have met their stopping criteria)
while let Some(batch) = cached_batch {
// Get current batch info
let batch_size = batch.size;
let batch_max_tokens = batch.max_tokens;
let mut batches = vec![batch];
metrics::gauge!("tgi_batch_current_size").set(batch_size as f64);
metrics::gauge!("tgi_batch_current_max_tokens").set(batch_max_tokens as f64);
let min_size = if waiting_tokens >= max_waiting_tokens {
// If we didn't onboard any new requests since >= max_waiting_tokens, we try
// to add a new batch even though its size might be small
None
} else {
// Minimum batch size
Some((batch_size as f32 * waiting_served_ratio).floor() as usize)
};
let token_budget = max_batch_total_tokens.saturating_sub(batch_max_tokens);
let max_size = max_batch_size.map(|max_size| max_size - batch_size as usize);
// Try to get a new batch
if let Some((mut new_entries, new_batch, span)) = queue
.next_batch(min_size, max_size, max_batch_prefill_tokens, token_budget)
.await
{
// Tracking metrics
if min_size.is_some() {
metrics::counter!("tgi_batch_concat", "reason" => "backpressure")
.increment(1);
} else {
metrics::counter!("tgi_batch_concat", "reason" => "wait_exceeded")
.increment(1);
}
entries.iter_mut().for_each(|(_, entry)| {
// Create a new span to add the info that this entry is waiting
// because a new batch is being computed
let entry_waiting_span = info_span!(parent: &entry.span, "waiting");
// Add relationships
span.follows_from(&entry_waiting_span);
entry_waiting_span.follows_from(&span);
// Update entry
entry.temp_span = Some(entry_waiting_span);
}); });
// Generate one token for this new batch to have the attention past in cache
let new_cached_batch =
prefill(&mut client, new_batch, &mut new_entries, &generation_health)
.instrument(span)
.await;
// Reset waiting counter
waiting_tokens = 1;
// Extend current batch with the new batch
if let Some(new_cached_batch) = new_cached_batch {
entries.extend(new_entries);
batches.push(new_cached_batch);
}
} }
// Create span for this batch to add context to inference calls
let next_batch_size = entries.len();
let next_batch_span =
info_span!(parent: None, "batch", batch_size = next_batch_size);
entries.iter_mut().for_each(|(_, entry)| {
// Create a new span to link the batch back to this entry
let entry_batch_span = info_span!(parent: &entry.span, "infer");
// Add relationships
next_batch_span.follows_from(&entry_batch_span);
entry_batch_span.follows_from(&next_batch_span);
// Update entry
entry.temp_span = Some(entry_batch_span);
});
cached_batch = decode(&mut client, batches, &mut entries, &generation_health)
.instrument(next_batch_span)
.await;
waiting_tokens += 1;
} }
metrics::gauge!("tgi_batch_current_size").set(0.0);
metrics::gauge!("tgi_batch_current_max_tokens").set(0.0);
}
}
}
#[instrument(skip_all)]
async fn prefill(
client: &mut ShardedClient,
batch: Batch,
entries: &mut IntMap<u64, Entry>,
generation_health: &Arc<AtomicBool>,
) -> Option<CachedBatch> {
let start_time = Instant::now();
let batch_id = batch.id;
metrics::counter!("tgi_batch_inference_count", "method" => "prefill").increment(1);
match client.prefill(batch).await {
Ok((generations, next_batch, timings)) => {
// Update health
generation_health.store(true, Ordering::SeqCst);
let start_filtering_time = Instant::now();
// Send generated tokens and filter stopped entries
filter_send_generations(generations, entries);
// Filter next batch and remove requests that were stopped
let next_batch = filter_batch(client, next_batch, entries).await;
metrics::histogram!("tgi_batch_forward_duration","method" => "prefill")
.record(timings.forward.as_secs_f64());
metrics::histogram!("tgi_batch_decode_duration", "method" => "prefill")
.record(timings.decode.as_secs_f64());
metrics::histogram!("tgi_batch_filter_duration", "method" => "prefill")
.record(start_filtering_time.elapsed().as_secs_f64());
metrics::histogram!("tgi_batch_inference_duration", "method" => "prefill")
.record(start_time.elapsed().as_secs_f64());
metrics::counter!("tgi_batch_inference_success", "method" => "prefill").increment(1);
next_batch
}
// If we have an error, we discard the whole batch
Err(err) => {
// Update health
generation_health.store(false, Ordering::SeqCst);
let _ = client.clear_cache(Some(batch_id)).await;
send_errors(err, entries);
metrics::counter!("tgi_batch_inference_failure", "method" => "prefill").increment(1);
None
} }
}
}
#[instrument(skip_all)]
async fn decode(
client: &mut ShardedClient,
batches: Vec<CachedBatch>,
entries: &mut IntMap<u64, Entry>,
generation_health: &Arc<AtomicBool>,
) -> Option<CachedBatch> {
let start_time = Instant::now();
let batch_ids: Vec<u64> = batches.iter().map(|b| b.id).collect();
metrics::counter!("tgi_batch_inference_count", "method" => "decode").increment(1);
match client.decode(batches).await { let messages: Vec<TextMessage> = messages.into_iter().map(|c| c.into()).collect();
Ok((generations, next_batch, timings)) => {
// Update health self.template
generation_health.store(true, Ordering::SeqCst); .render(ChatTemplateInputs {
messages,
let start_filtering_time = Instant::now(); bos_token: self.bos_token.as_deref(),
// Send generated tokens and filter stopped entries eos_token: self.eos_token.as_deref(),
filter_send_generations(generations, entries); add_generation_prompt: true,
tools: None,
// Filter next batch and remove requests that were stopped tools_prompt: None,
let next_batch = filter_batch(client, next_batch, entries).await; })
.map_err(InferError::TemplateError)
if let Some(concat_duration) = timings.concat {
metrics::histogram!("tgi_batch_concat_duration", "method" => "decode")
.record(concat_duration.as_secs_f64());
}
metrics::histogram!("tgi_batch_forward_duration", "method" => "decode")
.record(timings.forward.as_secs_f64());
metrics::histogram!("tgi_batch_decode_duration", "method" => "decode")
.record(timings.decode.as_secs_f64());
metrics::histogram!("tgi_batch_filter_duration", "method" => "decode")
.record(start_filtering_time.elapsed().as_secs_f64());
metrics::histogram!("tgi_batch_inference_duration", "method" => "decode")
.record(start_time.elapsed().as_secs_f64());
metrics::counter!("tgi_batch_inference_success", "method" => "decode").increment(1);
next_batch
}
// If we have an error, we discard the whole batch
Err(err) => {
generation_health.store(false, Ordering::SeqCst);
for id in batch_ids {
let _ = client.clear_cache(Some(id)).await;
}
send_errors(err, entries);
metrics::counter!("tgi_batch_inference_failure", "method" => "decode").increment(1);
None
}
}
}
/// Filter a `batch` and remove all requests not present in `entries`
#[instrument(skip_all)]
async fn filter_batch(
client: &mut ShardedClient,
next_batch: Option<CachedBatch>,
entries: &IntMap<u64, Entry>,
) -> Option<CachedBatch> {
let mut batch = next_batch?;
// No need to filter
if batch.size as usize == entries.len() {
return Some(batch);
}
let id = batch.id;
// Retain only requests that are still in entries
batch.request_ids.retain(|id| entries.contains_key(id));
if batch.request_ids.is_empty() {
// All requests have been filtered out
// Next batch is now empty
// Clear it from the Python shards cache
// We unwrap here as we need to panic since we cannot recover if this method fails
client.clear_cache(Some(id)).await.unwrap();
None
} else {
// Filter Python shard cache
// We unwrap here as we need to panic since we cannot recover if this method fails
client.filter_batch(id, batch.request_ids).await.unwrap()
}
}
/// Send one or multiple `InferStreamResponse` to Infer for all `entries`
/// and filter entries
#[instrument(skip_all)]
fn filter_send_generations(generations: Vec<Generation>, entries: &mut IntMap<u64, Entry>) {
generations.into_iter().for_each(|generation| {
let id = generation.request_id;
// Get entry
// We can `expect` here as the request id should always be in the entries
let entry = entries
.get(&id)
.expect("ID not found in entries. This is a bug.");
// Create and enter a span to link this function back to the entry
let _span = info_span!(parent: entry.temp_span.as_ref().expect("batch_span is None. This is a bug."), "send_generation", generation = ?generation).entered();
// Send generation responses back to the infer task
// If the receive an error from the Flume channel, it means that the client dropped the
// request and we need to stop generating hence why we unwrap_or(true)
let stopped = send_responses(generation, entry).map_err(|err| {
tracing::error!("Entry response channel error.");
metrics::counter!("tgi_request_failure", "err" => "dropped").increment(1);
err
}).unwrap_or(true);
if stopped {
entries.remove(&id).expect("ID not found in entries. This is a bug.");
}
});
}
/// Send responses through the `entry` response channel
fn send_responses(
generation: Generation,
entry: &Entry,
) -> Result<bool, Box<SendError<Result<InferStreamResponse, InferError>>>> {
// Return directly if the channel is disconnected
if entry.response_tx.is_closed() {
metrics::counter!("tgi_request_failure", "err" => "dropped").increment(1);
return Ok(true);
}
let mut stopped = false;
if let Some(prefill_tokens) = generation.prefill_tokens {
// Create Token objects
// We do that here instead of in the Python code as Rust for loops are faster
let prefill_tokens = prefill_tokens
.ids
.into_iter()
.zip(prefill_tokens.logprobs)
.zip(prefill_tokens.texts)
.map(|((id, logprob), text)| PrefillToken { id, text, logprob })
.collect();
// Send message
entry
.response_tx
.send(Ok(InferStreamResponse::Prefill(prefill_tokens)))?;
}
// Create last Token
let tokens_ = generation.tokens.expect("Non empty tokens in generation");
let n = tokens_.ids.len();
metrics::histogram!("tgi_request_skipped_tokens").record((n - 1) as f64);
let mut iterator = tokens_
.ids
.into_iter()
.zip(tokens_.logprobs)
.zip(tokens_.texts)
.zip(tokens_.is_special)
.enumerate()
.peekable();
while let Some((i, (((id, logprob), text), special))) = iterator.next() {
let token = Token {
id,
text,
logprob,
special,
};
let top_tokens = if let Some(top_tokens_) = generation.top_tokens.get(i) {
top_tokens_
.ids
.iter()
.zip(top_tokens_.logprobs.iter())
.zip(top_tokens_.texts.iter())
.zip(top_tokens_.is_special.iter())
.map(|(((&id, &logprob), text), &special)| Token {
id,
text: text.to_string(),
logprob,
special,
})
.collect()
} else {
vec![]
};
match (&generation.generated_text, iterator.peek()) {
(Some(generated_text), None) => {
// Generation has ended
stopped = true;
// Send message
entry.response_tx.send(Ok(InferStreamResponse::End {
token,
top_tokens,
generated_text: GeneratedText::from(generated_text.clone()),
queued: entry.queue_time,
start: entry.batch_time.unwrap(),
}))?;
}
_ => {
// Send message
entry
.response_tx
.send(Ok(InferStreamResponse::Intermediate { token, top_tokens }))?;
}
}
}
Ok(stopped)
}
/// Send errors to Infer for all `entries`
#[instrument(skip_all)]
fn send_errors(error: ClientError, entries: &mut IntMap<u64, Entry>) {
entries.drain().for_each(|(_, entry)| {
// Create and enter a span to link this function back to the entry
let _send_error_span = info_span!(parent: entry.temp_span.as_ref().expect("batch_span is None. This is a bug."), "send_error").entered();
let err = InferError::GenerationError(error.to_string());
metrics::counter!("tgi_request_failure", "err" => "generation").increment(1);
tracing::error!("{err}");
// unwrap_or is valid here as we don't care if the receiver is gone.
entry
.response_tx
.send(Err(err))
.unwrap_or(());
});
}
impl From<text_generation_client::v3::GeneratedText> for GeneratedText {
fn from(value: text_generation_client::v3::GeneratedText) -> Self {
let v3_finish_reason =
text_generation_client::v3::FinishReason::try_from(value.finish_reason).unwrap();
let finish_reason = match v3_finish_reason {
text_generation_client::v3::FinishReason::Length => FinishReason::Length,
text_generation_client::v3::FinishReason::EosToken => FinishReason::EndOfSequenceToken,
text_generation_client::v3::FinishReason::StopSequence => FinishReason::StopSequence,
};
Self {
text: value.text,
generated_tokens: value.generated_tokens,
finish_reason,
seed: value.seed,
}
} }
} }
// tests // tests
#[cfg(test)] #[cfg(test)]
mod tests { mod tests {
use crate::infer::raise_exception; use crate::infer::chat_template::raise_exception;
use crate::{ChatTemplateInputs, TextMessage}; use crate::{ChatTemplateInputs, TextMessage};
use minijinja::Environment; use minijinja::Environment;
......
use std::sync::atomic::{AtomicBool, Ordering};
use std::sync::Arc;
use text_generation_client::Health;
#[derive(Clone)]
pub(crate) struct HealthCheck {
client: Arc<dyn Health + Send + Sync>,
generation_health: Arc<AtomicBool>,
}
impl HealthCheck {
pub(crate) fn new(
client: Arc<dyn Health + Send + Sync>,
generation_health: Arc<AtomicBool>,
) -> Self {
Self {
client,
generation_health,
}
}
pub(crate) async fn check(&mut self) -> bool {
let value = if self.generation_health.load(Ordering::SeqCst) {
// Generation is healthy, we only check that the shards can allocate on device
self.client.device_health().await
} else {
self.client.model_health().await
}
.is_ok();
// Update generation health
self.generation_health.store(value, Ordering::SeqCst);
value
}
}
mod health; // pub(crate) mod v2;
pub(crate) mod v2; mod chat_template;
pub(crate) mod v3; pub mod tool_grammar;
pub(crate) use health::HealthCheck;
use crate::validation::{ValidGenerateRequest, Validation, ValidationError}; use crate::validation::{ValidGenerateRequest, Validation, ValidationError};
use crate::GrammarType;
use crate::{ use crate::{
ChatTemplateInputs, ChatTemplateVersions, FinishReason, GenerateRequest, HubProcessorConfig, ChatTemplateVersions, FinishReason, GenerateRequest, HubProcessorConfig, HubTokenizerConfig,
HubTokenizerConfig, Message, MessageChunk, PrefillToken, TextMessage, Token, ToolChoice, Message, PrefillToken, Token,
};
use crate::{
FunctionRef, FunctionsMap, GrammarType, Properties, TokenizerConfigToken, Tool, ToolType, Tools,
}; };
use async_trait::async_trait;
use chat_template::ChatTemplate;
use futures::future::try_join_all; use futures::future::try_join_all;
use minijinja::{Environment, ErrorKind, Template}; use minijinja::ErrorKind;
use minijinja_contrib::pycompat; use std::sync::atomic::{AtomicBool, Ordering};
use serde_json::{json, Map, Value};
use std::collections::HashMap;
use std::sync::Arc; use std::sync::Arc;
use thiserror::Error; use thiserror::Error;
use tokio::sync::{OwnedSemaphorePermit, Semaphore, TryAcquireError}; use tokio::sync::{OwnedSemaphorePermit, Semaphore, TryAcquireError};
...@@ -26,12 +21,14 @@ use tokio_stream::wrappers::UnboundedReceiverStream; ...@@ -26,12 +21,14 @@ use tokio_stream::wrappers::UnboundedReceiverStream;
use tokio_stream::StreamExt; use tokio_stream::StreamExt;
use tracing::instrument; use tracing::instrument;
pub(crate) trait Scheduler { #[async_trait]
pub trait Backend {
fn schedule( fn schedule(
&self, &self,
request: ValidGenerateRequest, request: ValidGenerateRequest,
permit: OwnedSemaphorePermit, ) -> Result<UnboundedReceiverStream<Result<InferStreamResponse, InferError>>, InferError>;
) -> Result<GenerateStreamResponse, InferError>;
async fn health(&self, current_health: bool) -> bool;
} }
/// Inference struct /// Inference struct
...@@ -39,18 +36,20 @@ pub(crate) trait Scheduler { ...@@ -39,18 +36,20 @@ pub(crate) trait Scheduler {
pub struct Infer { pub struct Infer {
/// Validation /// Validation
validation: Validation, validation: Validation,
/// Request scheduler /// Request backend
scheduler: Arc<dyn Scheduler + Send + Sync>, backend: Arc<dyn Backend + Send + Sync>,
/// Chat template /// Chat template
chat_template: Option<ChatTemplate>, chat_template: Option<ChatTemplate>,
/// Inference limit /// Inference limit
limit_concurrent_requests: Arc<Semaphore>, limit_concurrent_requests: Arc<Semaphore>,
/// Backend health
backend_health: Arc<AtomicBool>,
} }
impl Infer { impl Infer {
#[allow(clippy::too_many_arguments)] #[allow(clippy::too_many_arguments)]
pub(crate) fn new( pub(crate) fn new(
scheduler: Arc<dyn Scheduler + Send + Sync>, backend: impl Backend + Send + Sync + 'static,
validation: Validation, validation: Validation,
max_concurrent_requests: usize, max_concurrent_requests: usize,
tokenizer_config: HubTokenizerConfig, tokenizer_config: HubTokenizerConfig,
...@@ -71,18 +70,22 @@ impl Infer { ...@@ -71,18 +70,22 @@ impl Infer {
// Inference limit with a semaphore // Inference limit with a semaphore
let semaphore = Arc::new(Semaphore::new(max_concurrent_requests)); let semaphore = Arc::new(Semaphore::new(max_concurrent_requests));
// Backend health
let backend_health = Arc::new(AtomicBool::new(false));
Self { Self {
validation, validation,
scheduler, backend: Arc::new(backend),
chat_template, chat_template,
limit_concurrent_requests: semaphore, limit_concurrent_requests: semaphore,
backend_health,
} }
} }
/// Add a new request to the queue and return a stream of InferStreamResponse /// Add a new request to the queue and return a stream of InferStreamResponse
#[instrument(skip_all)] #[instrument(skip_all)]
pub(crate) async fn generate_stream( pub(crate) async fn generate_stream<'a>(
&self, &'a self,
request: GenerateRequest, request: GenerateRequest,
) -> Result<GenerateStreamResponse, InferError> { ) -> Result<GenerateStreamResponse, InferError> {
// Limit concurrent requests by acquiring a permit from the semaphore // Limit concurrent requests by acquiring a permit from the semaphore
...@@ -103,7 +106,10 @@ impl Infer { ...@@ -103,7 +106,10 @@ impl Infer {
err err
})?; })?;
self.scheduler.schedule(valid_request, permit) let input_length = valid_request.input_length;
let generation_stream = self.backend.schedule(valid_request)?;
Ok((permit, input_length, generation_stream))
} }
/// Tokenizer the input /// Tokenizer the input
...@@ -155,7 +161,7 @@ impl Infer { ...@@ -155,7 +161,7 @@ impl Infer {
let use_top_tokens = request.parameters.top_n_tokens.is_some_and(|x| x > 0); let use_top_tokens = request.parameters.top_n_tokens.is_some_and(|x| x > 0);
// Create stream and keep semaphore permit as long as generate lives // Create stream and keep semaphore permit as long as generate lives
let (_permit, _input_length, mut stream) = self.generate_stream(request).await?; let (_permit, _input_length, stream) = self.generate_stream(request).await?;
// Return values // Return values
let mut result_prefill = Vec::new(); let mut result_prefill = Vec::new();
...@@ -165,6 +171,8 @@ impl Infer { ...@@ -165,6 +171,8 @@ impl Infer {
let mut result_start = None; let mut result_start = None;
let mut result_queued = None; let mut result_queued = None;
let mut stream = Box::pin(stream);
// Iterate on stream // Iterate on stream
while let Some(response) = stream.next().await { while let Some(response) = stream.next().await {
match response? { match response? {
...@@ -256,207 +264,15 @@ impl Infer { ...@@ -256,207 +264,15 @@ impl Infer {
let best_response = infer_responses.remove(max_index); let best_response = infer_responses.remove(max_index);
Ok((best_response, infer_responses)) Ok((best_response, infer_responses))
} }
}
/// Raise a exception (custom function) used in the chat templates
fn raise_exception(err_text: String) -> Result<String, minijinja::Error> {
Err(minijinja::Error::new(ErrorKind::SyntaxError, err_text))
}
#[derive(Clone)]
struct ChatTemplate {
template: Template<'static, 'static>,
bos_token: Option<String>,
eos_token: Option<String>,
use_default_tool_template: bool,
}
impl ChatTemplate {
fn new(
template: String,
bos_token: Option<TokenizerConfigToken>,
eos_token: Option<TokenizerConfigToken>,
) -> Self {
let mut env = Box::new(Environment::new());
// enable things like .strip() or .capitalize()
env.set_unknown_method_callback(pycompat::unknown_method_callback);
let template_str = template.into_boxed_str();
env.add_function("raise_exception", raise_exception);
// check if contains the tools variable within the template
let use_default_tool_template =
!template_str.as_ref().replace(' ', "").contains("{{tools}}");
// leaking env and template_str as read-only, static resources for performance.
let template = Box::leak(env)
.template_from_str(Box::leak(template_str))
.unwrap();
Self {
template,
bos_token: bos_token.map(|token| token.as_str().to_string()),
eos_token: eos_token.map(|token| token.as_str().to_string()),
use_default_tool_template,
}
}
fn apply( #[instrument(skip(self))]
&self, pub(crate) async fn health(&self) -> bool {
mut messages: Vec<Message>, let health = self
grammar_with_prompt: Option<(GrammarType, String)>, .backend
) -> Result<String, InferError> { .health(self.backend_health.load(Ordering::SeqCst))
if self.use_default_tool_template { .await;
if let Some(last_message) = messages.last_mut() { self.backend_health.store(health, Ordering::SeqCst);
if let Some((GrammarType::Json(tools), tool_prompt)) = grammar_with_prompt { health
last_message.content.push(MessageChunk::Text {
text: format!("\n---\n{}\n{}", tool_prompt, tools),
});
}
}
}
let messages: Vec<TextMessage> = messages.into_iter().map(|c| c.into()).collect();
self.template
.render(ChatTemplateInputs {
messages,
bos_token: self.bos_token.as_deref(),
eos_token: self.eos_token.as_deref(),
add_generation_prompt: true,
tools: None,
tools_prompt: None,
})
.map_err(InferError::TemplateError)
}
}
pub struct ToolGrammar {}
impl ToolGrammar {
// find a tool by name
fn find_tool_by_name(tools: &[Tool], name: &str) -> Result<Tool, InferError> {
tools
.iter()
.find(|tool| tool.function.name == name)
.cloned()
.ok_or_else(|| InferError::ToolError(format!("Tool with name {} not found", name)))
}
pub fn apply(
tools: Option<Vec<Tool>>,
tool_choice: ToolChoice,
) -> Result<Option<Tools>, InferError> {
// if no tools are provided, we return None
let tools = match tools {
Some(tools) if !tools.is_empty() => tools,
_ => return Ok(None),
};
let tool_choice = tool_choice.0.unwrap_or(ToolType::OneOf);
// if tools are provided and no tool_choice we default to the OneOf
let tools_to_use = match tool_choice {
ToolType::FunctionName(name) => {
vec![Self::find_tool_by_name(&tools, &name)?]
}
ToolType::Function { function } => {
vec![Self::find_tool_by_name(&tools, &function.name)?]
}
ToolType::OneOf => tools,
ToolType::NoTool => return Ok(None),
};
// adds the error notification function for LLM feedback if required
let mut text_response_properties = Map::new();
text_response_properties.insert(
"error".to_string(),
serde_json::json!({
"type": "string",
"description": "The error or issue to notify"
}),
);
text_response_properties.insert(
"_name".to_string(),
serde_json::json!({
"type": "string",
"const": "notify_error"
}),
);
let functions: HashMap<String, serde_json::Value> = tools_to_use
.iter()
.map(|tool| {
let func = tool.function.clone();
// Clone the existing parameters, which are expected to be a JSON object
let mut params = if let Value::Object(params) = &func.arguments {
params.clone()
} else {
Map::new()
};
// Insert the function's description at the top level, outside of properties
params.insert(
"description".to_string(),
Value::String(func.description.clone().unwrap_or_default()),
);
// Ensure 'properties' exists and is an object
let properties = params
.entry("properties".to_string())
.or_insert_with(|| json!({}))
.as_object_mut()
.unwrap();
// Insert the constant for the function name inside 'properties'
properties.insert(
"_name".to_string(),
json!({
"type": "string",
"const": func.name.clone(),
// "description": "The name of the function"
}),
);
// Check if 'required' exists, and it is an array. If not, create an empty array.
let required = params
.entry("required".to_string())
.or_insert_with(|| json!([]))
.as_array_mut()
.unwrap();
// Add 'name' to the 'required' array if it is not already present
if !required.iter().any(|r| r == "_name") {
required.push(json!("_name"));
}
(func.name, Value::Object(params))
})
.chain([(
"notify_error".to_string(),
serde_json::json!({
"properties": text_response_properties,
"required": ["error", "_name"],
"type": "object"
}),
)])
.collect();
let tools = Tools {
functions_map: FunctionsMap { functions },
properties: Properties {
function: tools_to_use
.iter()
.map(|tool| FunctionRef {
ref_path: format!("#/$functions/{}", tool.function.name.clone()),
})
.chain(std::iter::once(FunctionRef {
ref_path: "#/$functions/notify_error".to_string(),
}))
.collect(),
},
};
Ok(Some(tools))
} }
} }
...@@ -468,15 +284,15 @@ pub(crate) type GenerateStreamResponse = ( ...@@ -468,15 +284,15 @@ pub(crate) type GenerateStreamResponse = (
); );
#[derive(Debug)] #[derive(Debug)]
pub(crate) struct GeneratedText { pub struct GeneratedText {
pub(crate) text: String, pub text: String,
pub(crate) generated_tokens: u32, pub generated_tokens: u32,
pub(crate) finish_reason: FinishReason, pub finish_reason: FinishReason,
pub(crate) seed: Option<u64>, pub seed: Option<u64>,
} }
#[derive(Debug)] #[derive(Debug)]
pub(crate) enum InferStreamResponse { pub enum InferStreamResponse {
// Optional first message // Optional first message
Prefill(Vec<PrefillToken>), Prefill(Vec<PrefillToken>),
// Intermediate messages // Intermediate messages
......
use crate::infer::InferError;
use crate::{FunctionRef, FunctionsMap, Properties, Tool, ToolChoice, ToolType, Tools};
use serde_json::{json, Map, Value};
use std::collections::HashMap;
pub(crate) struct ToolGrammar {}
impl ToolGrammar {
// find a tool by name
fn find_tool_by_name(tools: &[Tool], name: &str) -> Result<Tool, InferError> {
tools
.iter()
.find(|tool| tool.function.name == name)
.cloned()
.ok_or_else(|| InferError::ToolError(format!("Tool with name {} not found", name)))
}
pub fn apply(
tools: Option<Vec<Tool>>,
tool_choice: ToolChoice,
) -> Result<Option<Tools>, InferError> {
// if no tools are provided, we return None
let tools = match tools {
Some(tools) if !tools.is_empty() => tools,
_ => return Ok(None),
};
let tool_choice = tool_choice.0.unwrap_or(ToolType::OneOf);
// if tools are provided and no tool_choice we default to the OneOf
let tools_to_use = match tool_choice {
ToolType::FunctionName(name) => {
vec![Self::find_tool_by_name(&tools, &name)?]
}
ToolType::Function { function } => {
vec![Self::find_tool_by_name(&tools, &function.name)?]
}
ToolType::OneOf => tools,
ToolType::NoTool => return Ok(None),
};
// adds the error notification function for LLM feedback if required
let mut text_response_properties = Map::new();
text_response_properties.insert(
"error".to_string(),
serde_json::json!({
"type": "string",
"description": "The error or issue to notify"
}),
);
text_response_properties.insert(
"_name".to_string(),
serde_json::json!({
"type": "string",
"const": "notify_error"
}),
);
let functions: HashMap<String, serde_json::Value> = tools_to_use
.iter()
.map(|tool| {
let func = tool.function.clone();
// Clone the existing parameters, which are expected to be a JSON object
let mut params = if let Value::Object(params) = &func.arguments {
params.clone()
} else {
Map::new()
};
// Insert the function's description at the top level, outside of properties
params.insert(
"description".to_string(),
Value::String(func.description.clone().unwrap_or_default()),
);
// Ensure 'properties' exists and is an object
let properties = params
.entry("properties".to_string())
.or_insert_with(|| json!({}))
.as_object_mut()
.unwrap();
// Insert the constant for the function name inside 'properties'
properties.insert(
"_name".to_string(),
json!({
"type": "string",
"const": func.name.clone(),
// "description": "The name of the function"
}),
);
// Check if 'required' exists, and it is an array. If not, create an empty array.
let required = params
.entry("required".to_string())
.or_insert_with(|| json!([]))
.as_array_mut()
.unwrap();
// Add 'name' to the 'required' array if it is not already present
if !required.iter().any(|r| r == "_name") {
required.push(json!("_name"));
}
(func.name, Value::Object(params))
})
.chain([(
"notify_error".to_string(),
serde_json::json!({
"properties": text_response_properties,
"required": ["error", "_name"],
"type": "object"
}),
)])
.collect();
let tools = Tools {
functions_map: FunctionsMap { functions },
properties: Properties {
function: tools_to_use
.iter()
.map(|tool| FunctionRef {
ref_path: format!("#/$functions/{}", tool.function.name.clone()),
})
.chain(std::iter::once(FunctionRef {
ref_path: "#/$functions/notify_error".to_string(),
}))
.collect(),
},
};
Ok(Some(tools))
}
}
mod queue; mod queue;
mod scheduler; mod scheduler;
pub(crate) use scheduler::SchedulerV2; pub(crate) use scheduler::BackendV2;
/// Batching and inference logic /// Batching and inference logic
use crate::infer::v2::queue::{Entry, Queue}; use crate::infer::v2::queue::{Entry, Queue};
use crate::infer::{ use crate::infer::{
GenerateStreamResponse, GeneratedText, InferError, InferStreamResponse, Scheduler, Backend, GenerateStreamResponse, GeneratedText, InferError, InferStreamResponse,
}; };
use crate::validation::ValidGenerateRequest; use crate::validation::ValidGenerateRequest;
use crate::{FinishReason, PrefillToken, Token}; use crate::{FinishReason, PrefillToken, Token};
...@@ -18,14 +18,14 @@ use tokio::time::Instant; ...@@ -18,14 +18,14 @@ use tokio::time::Instant;
use tokio_stream::wrappers::UnboundedReceiverStream; use tokio_stream::wrappers::UnboundedReceiverStream;
use tracing::{info_span, instrument, Instrument, Span}; use tracing::{info_span, instrument, Instrument, Span};
pub(crate) struct SchedulerV2 { pub(crate) struct BackendV2 {
/// Request queue /// Request queue
queue: Queue, queue: Queue,
/// Notify batcher on queue appends /// Notify batcher on queue appends
batching_task_notifier: Arc<Notify>, batching_task_notifier: Arc<Notify>,
} }
impl SchedulerV2 { impl BackendV2 {
#[allow(clippy::too_many_arguments)] #[allow(clippy::too_many_arguments)]
pub(crate) fn new( pub(crate) fn new(
client: ShardedClient, client: ShardedClient,
...@@ -69,7 +69,7 @@ impl SchedulerV2 { ...@@ -69,7 +69,7 @@ impl SchedulerV2 {
} }
} }
impl Scheduler for SchedulerV2 { impl Backend for BackendV2 {
#[instrument(skip_all)] #[instrument(skip_all)]
fn schedule( fn schedule(
&self, &self,
......
mod block_allocator;
mod queue;
mod scheduler;
pub(crate) use scheduler::SchedulerV3;
/// Text Generation Inference Webserver /// Text Generation Inference Webserver
pub mod config; pub mod config;
mod infer; pub mod infer;
pub mod server; pub mod server;
mod validation; pub mod validation;
#[cfg(feature = "kserve")] #[cfg(feature = "kserve")]
mod kserve; mod kserve;
pub mod logging;
pub mod usage_stats; pub mod usage_stats;
...@@ -148,12 +149,13 @@ pub struct Info { ...@@ -148,12 +149,13 @@ pub struct Info {
pub model_id: String, pub model_id: String,
#[schema(nullable = true, example = "e985a63cdc139290c5f700ff1929f0b5942cced2")] #[schema(nullable = true, example = "e985a63cdc139290c5f700ff1929f0b5942cced2")]
pub model_sha: Option<String>, pub model_sha: Option<String>,
#[schema(example = "torch.float16")] // #[schema(example = "torch.float16")]
pub model_dtype: String, // pub model_dtype: String,
#[schema(example = "cuda")] // #[schema(example = "cuda")]
pub model_device_type: String, // pub model_device_type: String,
#[schema(nullable = true, example = "text-generation")] #[schema(nullable = true, example = "text-generation")]
pub model_pipeline_tag: Option<String>, pub model_pipeline_tag: Option<String>,
/// Router Parameters /// Router Parameters
#[schema(example = "128")] #[schema(example = "128")]
pub max_concurrent_requests: usize, pub max_concurrent_requests: usize,
...@@ -165,18 +167,11 @@ pub struct Info { ...@@ -165,18 +167,11 @@ pub struct Info {
pub max_input_tokens: usize, pub max_input_tokens: usize,
#[schema(example = "2048")] #[schema(example = "2048")]
pub max_total_tokens: usize, pub max_total_tokens: usize,
#[schema(example = "1.2")]
pub waiting_served_ratio: f32,
#[schema(example = "32000")]
pub max_batch_total_tokens: u32,
#[schema(example = "20")]
pub max_waiting_tokens: usize,
#[schema(nullable = true, example = "null")]
pub max_batch_size: Option<usize>,
#[schema(example = "2")] #[schema(example = "2")]
pub validation_workers: usize, pub validation_workers: usize,
#[schema(example = "32")] #[schema(example = "32")]
pub max_client_batch_size: usize, pub max_client_batch_size: usize,
/// Router Info /// Router Info
#[schema(example = "text-generation-router")] #[schema(example = "text-generation-router")]
pub router: &'static str, pub router: &'static str,
...@@ -1068,23 +1063,23 @@ impl From<CompatGenerateRequest> for GenerateRequest { ...@@ -1068,23 +1063,23 @@ impl From<CompatGenerateRequest> for GenerateRequest {
#[derive(Debug, Serialize, ToSchema)] #[derive(Debug, Serialize, ToSchema)]
pub struct PrefillToken { pub struct PrefillToken {
#[schema(example = 0)] #[schema(example = 0)]
id: u32, pub id: u32,
#[schema(example = "test")] #[schema(example = "test")]
text: String, pub text: String,
#[schema(nullable = true, example = - 0.34)] #[schema(nullable = true, example = - 0.34)]
logprob: f32, pub logprob: f32,
} }
#[derive(Debug, Serialize, ToSchema, Clone)] #[derive(Debug, Serialize, ToSchema, Clone)]
pub struct Token { pub struct Token {
#[schema(example = 0)] #[schema(example = 0)]
id: u32, pub id: u32,
#[schema(example = "test")] #[schema(example = "test")]
text: String, pub text: String,
#[schema(nullable = true, example = - 0.34)] #[schema(nullable = true, example = - 0.34)]
logprob: f32, pub logprob: f32,
#[schema(example = "false")] #[schema(example = "false")]
special: bool, pub special: bool,
} }
#[derive(Debug, Serialize, ToSchema)] #[derive(Debug, Serialize, ToSchema)]
...@@ -1102,7 +1097,7 @@ pub struct SimpleToken { ...@@ -1102,7 +1097,7 @@ pub struct SimpleToken {
#[derive(Debug, Serialize, ToSchema)] #[derive(Debug, Serialize, ToSchema)]
#[serde(rename_all(serialize = "snake_case"))] #[serde(rename_all(serialize = "snake_case"))]
#[schema(example = "Length")] #[schema(example = "Length")]
pub(crate) enum FinishReason { pub enum FinishReason {
#[schema(rename = "length")] #[schema(rename = "length")]
Length, Length,
#[serde(rename = "eos_token")] #[serde(rename = "eos_token")]
......
use opentelemetry::sdk::propagation::TraceContextPropagator;
use opentelemetry::sdk::trace;
use opentelemetry::sdk::trace::Sampler;
use opentelemetry::sdk::Resource;
use opentelemetry::{global, KeyValue};
use opentelemetry_otlp::WithExportConfig;
use tracing_subscriber::layer::SubscriberExt;
use tracing_subscriber::util::SubscriberInitExt;
use tracing_subscriber::{filter::LevelFilter, EnvFilter, Layer};
/// Init logging using env variables LOG_LEVEL and LOG_FORMAT:
/// - otlp_endpoint is an optional URL to an Open Telemetry collector
/// - otlp_service_name service name to appear in APM
/// - LOG_LEVEL may be TRACE, DEBUG, INFO, WARN or ERROR (default to INFO)
/// - LOG_FORMAT may be TEXT or JSON (default to TEXT)
/// - LOG_COLORIZE may be "false" or "true" (default to "true" or ansi supported platforms)
pub fn init_logging(otlp_endpoint: Option<String>, otlp_service_name: String, json_output: bool) {
let mut layers = Vec::new();
// STDOUT/STDERR layer
let ansi = std::env::var("LOG_COLORIZE") != Ok("1".to_string());
let fmt_layer = tracing_subscriber::fmt::layer()
.with_file(true)
.with_ansi(ansi)
.with_line_number(true);
let fmt_layer = match json_output {
true => fmt_layer.json().flatten_event(true).boxed(),
false => fmt_layer.boxed(),
};
layers.push(fmt_layer);
// OpenTelemetry tracing layer
if let Some(otlp_endpoint) = otlp_endpoint {
global::set_text_map_propagator(TraceContextPropagator::new());
let tracer = opentelemetry_otlp::new_pipeline()
.tracing()
.with_exporter(
opentelemetry_otlp::new_exporter()
.tonic()
.with_endpoint(otlp_endpoint),
)
.with_trace_config(
trace::config()
.with_resource(Resource::new(vec![KeyValue::new(
"service.name",
otlp_service_name,
)]))
.with_sampler(Sampler::AlwaysOn),
)
.install_batch(opentelemetry::runtime::Tokio);
if let Ok(tracer) = tracer {
layers.push(tracing_opentelemetry::layer().with_tracer(tracer).boxed());
init_tracing_opentelemetry::init_propagator().unwrap();
};
}
// Filter events with LOG_LEVEL
let varname = "LOG_LEVEL";
let env_filter = if let Ok(log_level) = std::env::var(varname) {
// Override to avoid simple logs to be spammed with tokio level informations
let log_level = match &log_level[..] {
"warn" => "text_generation_launcher=warn,text_generation_router=warn",
"info" => "text_generation_launcher=info,text_generation_router=info",
"debug" => "text_generation_launcher=debug,text_generation_router=debug",
log_level => log_level,
};
EnvFilter::builder()
.with_default_directive(LevelFilter::INFO.into())
.parse_lossy(log_level)
} else {
EnvFilter::new("info")
};
tracing_subscriber::registry()
.with(env_filter)
.with(layers)
.init();
}
/// HTTP Server logic /// HTTP Server logic
use crate::config::Config; use crate::config::Config;
use crate::infer::v2::SchedulerV2; use crate::infer::tool_grammar::ToolGrammar;
use crate::infer::v3::SchedulerV3; use crate::infer::{Backend, Infer, InferError, InferResponse, InferStreamResponse};
use crate::infer::{HealthCheck, Scheduler};
use crate::infer::{Infer, InferError, InferResponse, InferStreamResponse, ToolGrammar};
#[cfg(feature = "kserve")] #[cfg(feature = "kserve")]
use crate::kserve::{ use crate::kserve::{
kerve_server_metadata, kserve_health_live, kserve_health_ready, kserve_model_infer, kerve_server_metadata, kserve_health_live, kserve_health_ready, kserve_model_infer,
kserve_model_metadata, kserve_model_metadata_ready, kserve_model_metadata, kserve_model_metadata_ready,
}; };
use crate::usage_stats;
use crate::validation::ValidationError; use crate::validation::ValidationError;
use crate::{ use crate::{
BestOfSequence, Details, ErrorResponse, FinishReason, FunctionName, GenerateParameters, BestOfSequence, Details, ErrorResponse, FinishReason, FunctionName, GenerateParameters,
...@@ -27,7 +26,7 @@ use crate::{ ...@@ -27,7 +26,7 @@ use crate::{
use crate::{FunctionDefinition, HubPreprocessorConfig, ToolCall, ToolChoice, ToolType}; use crate::{FunctionDefinition, HubPreprocessorConfig, ToolCall, ToolChoice, ToolType};
use async_stream::__private::AsyncStream; use async_stream::__private::AsyncStream;
use axum::extract::Extension; use axum::extract::Extension;
use axum::http::{HeaderMap, Method, StatusCode}; use axum::http::{HeaderMap, HeaderValue, Method, StatusCode};
use axum::response::sse::{Event, KeepAlive, Sse}; use axum::response::sse::{Event, KeepAlive, Sse};
use axum::response::{IntoResponse, Response}; use axum::response::{IntoResponse, Response};
use axum::routing::{get, post}; use axum::routing::{get, post};
...@@ -37,15 +36,18 @@ use futures::stream::StreamExt; ...@@ -37,15 +36,18 @@ use futures::stream::StreamExt;
use futures::stream::{FuturesOrdered, FuturesUnordered}; use futures::stream::{FuturesOrdered, FuturesUnordered};
use futures::Stream; use futures::Stream;
use futures::TryStreamExt; use futures::TryStreamExt;
use hf_hub::api::tokio::{Api, ApiBuilder, ApiRepo};
use hf_hub::{Cache, Repo, RepoType};
use http::header::AUTHORIZATION; use http::header::AUTHORIZATION;
use metrics_exporter_prometheus::{Matcher, PrometheusBuilder, PrometheusHandle}; use metrics_exporter_prometheus::{Matcher, PrometheusBuilder, PrometheusHandle};
use serde_json::Value; use serde_json::Value;
use std::convert::Infallible; use std::convert::Infallible;
use std::net::SocketAddr; use std::fs::File;
use std::sync::atomic::AtomicBool; use std::io::BufReader;
use std::sync::Arc; use std::net::{IpAddr, Ipv4Addr, SocketAddr};
use text_generation_client::{v2, v3, ClientError, ShardInfo}; use std::path::{Path, PathBuf};
use thiserror::Error; use thiserror::Error;
use tokenizers::processors::template::TemplateProcessing;
use tokenizers::Tokenizer; use tokenizers::Tokenizer;
use tokio::select; use tokio::select;
use tokio::signal; use tokio::signal;
...@@ -124,12 +126,10 @@ responses( ...@@ -124,12 +126,10 @@ responses(
example = json ! ({"error": "unhealthy", "error_type": "healthcheck"})), example = json ! ({"error": "unhealthy", "error_type": "healthcheck"})),
) )
)] )]
#[instrument(skip(health))] #[instrument(skip(infer))]
/// Health check method /// Health check method
async fn health( async fn health(infer: Extension<Infer>) -> Result<(), (StatusCode, Json<ErrorResponse>)> {
mut health: Extension<HealthCheck>, match infer.health().await {
) -> Result<(), (StatusCode, Json<ErrorResponse>)> {
match health.check().await {
true => Ok(()), true => Ok(()),
false => Err(( false => Err((
StatusCode::SERVICE_UNAVAILABLE, StatusCode::SERVICE_UNAVAILABLE,
...@@ -430,8 +430,9 @@ async fn generate_stream_internal( ...@@ -430,8 +430,9 @@ async fn generate_stream_internal(
} else { } else {
match infer.generate_stream(req).instrument(info_span!(parent: &span, "async_stream")).await { match infer.generate_stream(req).instrument(info_span!(parent: &span, "async_stream")).await {
// Keep permit as long as generate_stream lives // Keep permit as long as generate_stream lives
Ok((_permit, _input_length, mut response_stream)) => { Ok((_permit, _input_length, response_stream)) => {
let mut index = 0; let mut index = 0;
let mut response_stream = Box::pin(response_stream);
// Server-Sent Event stream // Server-Sent Event stream
while let Some(response) = response_stream.next().await { while let Some(response) = response_stream.next().await {
index += 1; index += 1;
...@@ -1396,262 +1397,456 @@ async fn metrics(prom_handle: Extension<PrometheusHandle>) -> String { ...@@ -1396,262 +1397,456 @@ async fn metrics(prom_handle: Extension<PrometheusHandle>) -> String {
#[derive(Clone, Debug)] #[derive(Clone, Debug)]
pub(crate) struct ComputeType(String); pub(crate) struct ComputeType(String);
// OpenAPI documentation
#[derive(OpenApi)]
#[openapi(
paths(
health,
get_model_info,
compat_generate,
generate,
generate_stream,
chat_completions,
completions,
tokenize,
metrics,
),
components(
schemas(
Info,
CompatGenerateRequest,
GenerateRequest,
GrammarType,
ChatRequest,
Message,
MessageContent,
MessageChunk,
Url,
FunctionName,
OutputMessage,
TextMessage,
ToolCallMessage,
ToolCallDelta,
ChatCompletionComplete,
ChatCompletionChoice,
ChatCompletionDelta,
ChatCompletionChunk,
ChatCompletionLogprob,
ChatCompletionLogprobs,
ChatCompletionTopLogprob,
ChatCompletion,
CompletionRequest,
CompletionComplete,
Chunk,
Completion,
CompletionFinal,
Prompt,
GenerateParameters,
PrefillToken,
Token,
GenerateResponse,
TokenizeResponse,
SimpleToken,
BestOfSequence,
Details,
FinishReason,
StreamResponse,
StreamDetails,
ErrorResponse,
GrammarType,
Usage,
DeltaToolCall,
ToolType,
Tool,
ToolCall,
Function,
FunctionDefinition,
ToolChoice,
)
),
tags(
(name = "Text Generation Inference", description = "Hugging Face Text Generation Inference API")
),
info(
title = "Text Generation Inference",
license(
name = "Apache 2.0",
url = "https://www.apache.org/licenses/LICENSE-2.0"
)
)
)]
pub struct ApiDoc;
pub fn schema() -> ApiDoc {
ApiDoc
}
/// Serving method /// Serving method
#[allow(clippy::too_many_arguments)] #[allow(clippy::too_many_arguments)]
pub async fn run( pub async fn run(
master_shard_uds_path: String, backend: impl Backend + Send + Sync + 'static,
model_info: HubModelInfo,
compat_return_full_text: bool,
max_concurrent_requests: usize, max_concurrent_requests: usize,
max_best_of: usize, max_best_of: usize,
max_stop_sequences: usize, max_stop_sequences: usize,
max_top_n_tokens: u32, max_top_n_tokens: u32,
max_input_tokens: usize, max_input_tokens: usize,
max_total_tokens: usize, max_total_tokens: usize,
waiting_served_ratio: f32,
max_batch_prefill_tokens: u32,
max_batch_total_tokens: Option<u32>,
max_waiting_tokens: usize,
max_batch_size: Option<usize>,
tokenizer: Option<Tokenizer>,
config: Option<Config>,
validation_workers: usize, validation_workers: usize,
addr: SocketAddr,
allow_origin: Option<AllowOrigin>,
api_key: Option<String>, api_key: Option<String>,
tokenizer_name: String,
tokenizer_config_path: Option<String>,
revision: Option<String>,
hostname: String,
port: u16,
cors_allow_origin: Option<Vec<String>>,
ngrok: bool, ngrok: bool,
_ngrok_authtoken: Option<String>, _ngrok_authtoken: Option<String>,
_ngrok_edge: Option<String>, _ngrok_edge: Option<String>,
tokenizer_config: HubTokenizerConfig,
preprocessor_config: Option<HubPreprocessorConfig>,
processor_config: HubProcessorConfig,
messages_api_enabled: bool, messages_api_enabled: bool,
grammar_support: bool, disable_grammar_support: bool,
max_client_batch_size: usize, max_client_batch_size: usize,
print_schema_command: bool, disable_usage_stats: bool,
disable_crash_reports: bool,
) -> Result<(), WebServerError> { ) -> Result<(), WebServerError> {
// OpenAPI documentation // CORS allowed origins
#[derive(OpenApi)] // map to go inside the option and then map to parse from String to HeaderValue
#[openapi( // Finally, convert to AllowOrigin
paths( let allow_origin: Option<AllowOrigin> = cors_allow_origin.map(|cors_allow_origin| {
health, AllowOrigin::list(
get_model_info, cors_allow_origin
compat_generate, .iter()
generate, .map(|origin| origin.parse::<HeaderValue>().unwrap()),
generate_stream, )
chat_completions, });
completions,
tokenize,
metrics,
),
components(
schemas(
Info,
CompatGenerateRequest,
GenerateRequest,
GrammarType,
ChatRequest,
Message,
MessageContent,
MessageChunk,
Url,
FunctionName,
OutputMessage,
TextMessage,
ToolCallMessage,
ToolCallDelta,
ChatCompletionComplete,
ChatCompletionChoice,
ChatCompletionDelta,
ChatCompletionChunk,
ChatCompletionLogprob,
ChatCompletionLogprobs,
ChatCompletionTopLogprob,
ChatCompletion,
CompletionRequest,
CompletionComplete,
Chunk,
Completion,
CompletionFinal,
Prompt,
GenerateParameters,
PrefillToken,
Token,
GenerateResponse,
TokenizeResponse,
SimpleToken,
BestOfSequence,
Details,
FinishReason,
StreamResponse,
StreamDetails,
ErrorResponse,
GrammarType,
Usage,
DeltaToolCall,
ToolType,
Tool,
ToolCall,
Function,
FunctionDefinition,
ToolChoice,
)
),
tags(
(name = "Text Generation Inference", description = "Hugging Face Text Generation Inference API")
),
info(
title = "Text Generation Inference",
license(
name = "Apache 2.0",
url = "https://www.apache.org/licenses/LICENSE-2.0"
)
)
)]
struct ApiDoc;
// Create state // Parse Huggingface hub token
if print_schema_command { let authorization_token = std::env::var("HF_TOKEN")
let api_doc = ApiDoc::openapi(); .or_else(|_| std::env::var("HUGGING_FACE_HUB_TOKEN"))
let api_doc = serde_json::to_string_pretty(&api_doc).unwrap(); .ok();
println!("{}", api_doc);
std::process::exit(0);
}
// Open connection, get model info and warmup // Tokenizer instance
let (scheduler, health_ext, shard_info, max_batch_total_tokens): ( // This will only be used to validate payloads
Arc<dyn Scheduler + Send + Sync>, let local_path = Path::new(&tokenizer_name);
HealthCheck,
ShardInfo, // Shared API builder initialization
u32, let api_builder = || {
) = { let mut builder = ApiBuilder::new()
// Helper function to check both v2 and v3 .with_progress(false)
let check_max_batch_total_tokens = |max_supported_batch_total_tokens: Option<u32>| { .with_token(authorization_token);
match max_supported_batch_total_tokens {
// Older models do not support automatic max-batch-total-tokens if let Ok(cache_dir) = std::env::var("HUGGINGFACE_HUB_CACHE") {
None => { builder = builder.with_cache_dir(cache_dir.into());
let max_batch_total_tokens = max_batch_total_tokens.unwrap_or( }
16000.max((max_total_tokens as u32).max(max_batch_prefill_tokens)),
); builder
tracing::warn!("Model does not support automatic max batch total tokens"); };
Ok(max_batch_total_tokens)
// Decide if we need to use the API based on the revision and local path
let use_api = revision.is_some() || !local_path.exists() || !local_path.is_dir();
// Initialize API if needed
#[derive(Clone)]
enum Type {
Api(Api),
Cache(Cache),
None,
}
let api = if use_api {
if std::env::var("HF_HUB_OFFLINE") == Ok("1".to_string()) {
let cache = std::env::var("HUGGINGFACE_HUB_CACHE")
.map_err(|_| ())
.map(|cache_dir| Cache::new(cache_dir.into()))
.unwrap_or_else(|_| Cache::default());
tracing::warn!("Offline mode active using cache defaults");
Type::Cache(cache)
} else {
tracing::info!("Using the Hugging Face API");
match api_builder().build() {
Ok(api) => Type::Api(api),
Err(_) => {
tracing::warn!("Unable to build the Hugging Face API");
Type::None
} }
// Flash attention models return their max supported total tokens }
Some(max_supported_batch_total_tokens) => { }
// Warn if user added his own max-batch-total-tokens as we will ignore it } else {
if max_batch_total_tokens.is_some() { Type::None
tracing::warn!( };
"`--max-batch-total-tokens` is deprecated for Flash \
Attention models." // Load tokenizer and model info
); let (
tracing::warn!( tokenizer_filename,
"Inferred max batch total tokens: {max_supported_batch_total_tokens}" config_filename,
); tokenizer_config_filename,
} preprocessor_config_filename,
if max_total_tokens as u32 > max_supported_batch_total_tokens { processor_config_filename,
return Err(WebServerError::NotEnoughMemory(max_total_tokens)); model_info,
} ) = match api {
Type::None => (
Some(local_path.join("tokenizer.json")),
Some(local_path.join("config.json")),
Some(local_path.join("tokenizer_config.json")),
Some(local_path.join("preprocessor_config.json")),
Some(local_path.join("processor_config.json")),
None,
),
Type::Api(api) => {
let api_repo = api.repo(Repo::with_revision(
tokenizer_name.to_string(),
RepoType::Model,
revision.clone().unwrap_or_else(|| "main".to_string()),
));
let tokenizer_filename = match api_repo.get("tokenizer.json").await {
Ok(tokenizer_filename) => Some(tokenizer_filename),
Err(_) => get_base_tokenizer(&api, &api_repo).await,
};
let config_filename = api_repo.get("config.json").await.ok();
let tokenizer_config_filename = api_repo.get("tokenizer_config.json").await.ok();
let preprocessor_config_filename = api_repo.get("preprocessor_config.json").await.ok();
let processor_config_filename = api_repo.get("processor_config.json").await.ok();
Ok(max_supported_batch_total_tokens) let model_info = if let Some(model_info) = get_hub_model_info(&api_repo).await {
Some(model_info)
} else {
tracing::warn!("Could not retrieve model info from the Hugging Face hub.");
None
};
(
tokenizer_filename,
config_filename,
tokenizer_config_filename,
preprocessor_config_filename,
processor_config_filename,
model_info,
)
}
Type::Cache(cache) => {
let repo = cache.repo(Repo::with_revision(
tokenizer_name.to_string(),
RepoType::Model,
revision.clone().unwrap_or_else(|| "main".to_string()),
));
(
repo.get("tokenizer.json"),
repo.get("config.json"),
repo.get("tokenizer_config.json"),
repo.get("preprocessor_config.json"),
repo.get("processor_config.json"),
None,
)
}
};
// Read the JSON contents of the file as an instance of 'HubTokenizerConfig'.
let tokenizer_config: Option<HubTokenizerConfig> = if let Some(filename) = tokenizer_config_path
{
HubTokenizerConfig::from_file(filename)
} else {
tokenizer_config_filename.and_then(HubTokenizerConfig::from_file)
};
let tokenizer_config = tokenizer_config.unwrap_or_else(|| {
tracing::warn!("Could not find tokenizer config locally and no API specified");
HubTokenizerConfig::default()
});
let tokenizer: Option<Tokenizer> = tokenizer_filename.and_then(|filename| {
let mut tokenizer = Tokenizer::from_file(filename).ok();
if let Some(tokenizer) = &mut tokenizer {
if let Some(class) = &tokenizer_config.tokenizer_class {
if class == "LlamaTokenizer" || class == "LlamaTokenizerFast"{
if let Ok(post_processor) = create_post_processor(tokenizer, &tokenizer_config) {
tracing::info!("Overriding LlamaTokenizer with TemplateProcessing to follow python override defined in https://github.com/huggingface/transformers/blob/4aa17d00690b7f82c95bb2949ea57e22c35b4336/src/transformers/models/llama/tokenization_llama_fast.py#L203-L205");
tokenizer.with_post_processor(post_processor);
}
} }
} }
}; }
tokenizer
});
let config: Option<Config> = config_filename.and_then(|filename| {
std::fs::read_to_string(filename)
.ok()
.as_ref()
.and_then(|c| {
let config: Result<Config, _> = serde_json::from_str(c);
if let Err(err) = &config {
tracing::warn!("Could not parse config {err:?}");
}
config.ok()
})
});
let model_info = model_info.unwrap_or_else(|| HubModelInfo {
model_id: tokenizer_name.to_string(),
sha: None,
pipeline_tag: None,
});
let processor_config = processor_config_filename
.and_then(HubProcessorConfig::from_file)
.unwrap_or_default();
let preprocessor_config: Option<HubPreprocessorConfig> =
preprocessor_config_filename.and_then(HubPreprocessorConfig::from_file);
tracing::info!("Using config {config:?}");
if tokenizer.is_none() {
tracing::warn!("Could not find a fast tokenizer implementation for {tokenizer_name}");
tracing::warn!("Rust input length validation and truncation is disabled");
}
let generation_health = Arc::new(AtomicBool::new(false)); // Only send usage stats when TGI is run in container and the function returns Some
let is_container = matches!(usage_stats::is_container(), Ok(true));
match v3::ShardedClient::connect_uds(master_shard_uds_path.clone()).await {
Ok(mut sharded_client) => { let user_agent = if !disable_usage_stats && is_container {
// server is running on v3 let reduced_args = usage_stats::Args::new(
// Clear the cache; useful if the webserver rebooted config.clone(),
sharded_client tokenizer_config.tokenizer_class.clone(),
.clear_cache(None) max_concurrent_requests,
.await max_best_of,
.map_err(WebServerError::Cache)?; max_stop_sequences,
// Get info from the shard max_top_n_tokens,
let shard_info = sharded_client.info().await.map_err(WebServerError::Info)?; max_input_tokens,
max_total_tokens,
// Warmup model // waiting_served_ratio,
tracing::info!("Warming up model"); // max_batch_prefill_tokens,
let max_batch_total_tokens = check_max_batch_total_tokens( // max_batch_total_tokens,
sharded_client // max_waiting_tokens,
.warmup( // max_batch_size,
max_input_tokens as u32, revision.clone(),
max_batch_prefill_tokens, validation_workers,
max_total_tokens as u32, messages_api_enabled,
max_batch_size, disable_grammar_support,
) max_client_batch_size,
.await disable_usage_stats,
.map_err(WebServerError::Warmup)?, disable_crash_reports,
)?; );
Some(usage_stats::UserAgent::new(reduced_args))
let health_ext = } else {
HealthCheck::new(Arc::new(sharded_client.clone()), generation_health.clone()); None
let scheduler = Arc::new(SchedulerV3::new( };
sharded_client,
waiting_served_ratio, if let Some(ref ua) = user_agent {
max_batch_prefill_tokens, let start_event =
max_batch_total_tokens, usage_stats::UsageStatsEvent::new(ua.clone(), usage_stats::EventType::Start, None);
max_waiting_tokens, tokio::spawn(async move {
max_batch_size, start_event.send().await;
shard_info.requires_padding, });
shard_info.window_size, };
shard_info.speculate, let compat_return_full_text = match &model_info.pipeline_tag {
generation_health, None => {
)); tracing::warn!("no pipeline tag found for model {tokenizer_name}");
tracing::info!("Using scheduler V3"); true
}
(scheduler, health_ext, shard_info, max_batch_total_tokens) Some(pipeline_tag) => pipeline_tag.as_str() == "text-generation",
};
let result = start(
backend,
max_concurrent_requests,
max_best_of,
max_stop_sequences,
max_top_n_tokens,
max_input_tokens,
max_total_tokens,
validation_workers,
api_key,
config,
(tokenizer, tokenizer_config),
(preprocessor_config, processor_config),
hostname,
port,
ngrok,
_ngrok_authtoken,
_ngrok_edge,
messages_api_enabled,
disable_grammar_support,
max_client_batch_size,
model_info,
compat_return_full_text,
allow_origin,
)
.await;
if let Some(ua) = user_agent {
match result {
Ok(_) => {
let stop_event = usage_stats::UsageStatsEvent::new(
ua.clone(),
usage_stats::EventType::Stop,
None,
);
stop_event.send().await;
Ok(())
} }
Err(_) => { Err(e) => {
let mut sharded_client = v2::ShardedClient::connect_uds(master_shard_uds_path) if !disable_crash_reports {
.await let error_event = usage_stats::UsageStatsEvent::new(
.map_err(WebServerError::Connection)?; ua.clone(),
usage_stats::EventType::Error,
// server is running on v2 Some(e.to_string()),
// Clear the cache; useful if the webserver rebooted );
sharded_client error_event.send().await;
.clear_cache(None) } else {
.await let unknow_error_event = usage_stats::UsageStatsEvent::new(
.map_err(WebServerError::Cache)?; ua.clone(),
// Get info from the shard usage_stats::EventType::Error,
let shard_info = sharded_client.info().await.map_err(WebServerError::Info)?; Some("unknow_error".to_string()),
);
// Warmup model unknow_error_event.send().await;
tracing::info!("Warming up model"); }
let max_batch_total_tokens = check_max_batch_total_tokens( Err(e)
sharded_client
.warmup(
max_input_tokens as u32,
max_batch_prefill_tokens,
max_total_tokens as u32,
max_batch_size,
)
.await
.map_err(WebServerError::Warmup)?,
)?;
let health_ext =
HealthCheck::new(Arc::new(sharded_client.clone()), generation_health.clone());
let scheduler = Arc::new(SchedulerV2::new(
sharded_client,
waiting_served_ratio,
max_batch_prefill_tokens,
max_batch_total_tokens,
max_waiting_tokens,
max_batch_size,
shard_info.requires_padding,
shard_info.window_size,
shard_info.speculate,
generation_health,
));
tracing::info!("Using scheduler V2");
(scheduler, health_ext, shard_info, max_batch_total_tokens)
} }
} }
} else {
result
}
}
#[allow(clippy::too_many_arguments)]
async fn start(
backend: impl Backend + Send + Sync + 'static,
max_concurrent_requests: usize,
max_best_of: usize,
max_stop_sequences: usize,
max_top_n_tokens: u32,
max_input_tokens: usize,
max_total_tokens: usize,
validation_workers: usize,
api_key: Option<String>,
config: Option<Config>,
(tokenizer, tokenizer_config): (Option<Tokenizer>, HubTokenizerConfig),
(preprocessor_config, processor_config): (Option<HubPreprocessorConfig>, HubProcessorConfig),
hostname: String,
port: u16,
ngrok: bool,
_ngrok_authtoken: Option<String>,
_ngrok_edge: Option<String>,
messages_api_enabled: bool,
disable_grammar_support: bool,
max_client_batch_size: usize,
model_info: HubModelInfo,
compat_return_full_text: bool,
allow_origin: Option<AllowOrigin>,
) -> Result<(), WebServerError> {
// Determine the server port based on the feature and environment variable.
let port = if cfg!(feature = "google") {
std::env::var("AIP_HTTP_PORT")
.map(|aip_http_port| aip_http_port.parse::<u16>().unwrap_or(port))
.unwrap_or(port)
} else {
port
};
let addr = match hostname.parse() {
Ok(ip) => SocketAddr::new(ip, port),
Err(_) => {
tracing::warn!("Invalid hostname, defaulting to 0.0.0.0");
SocketAddr::new(IpAddr::V4(Ipv4Addr::new(0, 0, 0, 0)), port)
}
}; };
tracing::info!("Setting max batch total tokens to {max_batch_total_tokens}");
// Create state
let validation = Validation::new( let validation = Validation::new(
validation_workers, validation_workers,
tokenizer, tokenizer,
...@@ -1662,11 +1857,11 @@ pub async fn run( ...@@ -1662,11 +1857,11 @@ pub async fn run(
max_top_n_tokens, max_top_n_tokens,
max_input_tokens, max_input_tokens,
max_total_tokens, max_total_tokens,
grammar_support, disable_grammar_support,
); );
let infer = Infer::new( let infer = Infer::new(
scheduler, backend,
validation, validation,
max_concurrent_requests, max_concurrent_requests,
tokenizer_config, tokenizer_config,
...@@ -1703,8 +1898,8 @@ pub async fn run( ...@@ -1703,8 +1898,8 @@ pub async fn run(
let batch_size_matcher = Matcher::Full(String::from("tgi_batch_next_size")); let batch_size_matcher = Matcher::Full(String::from("tgi_batch_next_size"));
let batch_size_buckets: Vec<f64> = (0..1024).map(|x| (x + 1) as f64).collect(); let batch_size_buckets: Vec<f64> = (0..1024).map(|x| (x + 1) as f64).collect();
// Speculated tokens buckets // Speculated tokens buckets
let skipped_matcher = Matcher::Full(String::from("tgi_request_skipped_tokens")); // let skipped_matcher = Matcher::Full(String::from("tgi_request_skipped_tokens"));
let skipped_buckets: Vec<f64> = (0..shard_info.speculate + 1).map(|x| x as f64).collect(); // let skipped_buckets: Vec<f64> = (0..shard_info.speculate + 1).map(|x| x as f64).collect();
// Prometheus handler // Prometheus handler
let builder = PrometheusBuilder::new() let builder = PrometheusBuilder::new()
...@@ -1717,9 +1912,9 @@ pub async fn run( ...@@ -1717,9 +1912,9 @@ pub async fn run(
.set_buckets_for_metric(max_new_tokens_matcher, &max_new_tokens_buckets) .set_buckets_for_metric(max_new_tokens_matcher, &max_new_tokens_buckets)
.unwrap() .unwrap()
.set_buckets_for_metric(batch_size_matcher, &batch_size_buckets) .set_buckets_for_metric(batch_size_matcher, &batch_size_buckets)
.unwrap()
.set_buckets_for_metric(skipped_matcher, &skipped_buckets)
.unwrap(); .unwrap();
// .set_buckets_for_metric(skipped_matcher, &skipped_buckets)
// .unwrap();
let prom_handle = builder let prom_handle = builder
.install_recorder() .install_recorder()
.expect("failed to install metrics recorder"); .expect("failed to install metrics recorder");
...@@ -1735,18 +1930,18 @@ pub async fn run( ...@@ -1735,18 +1930,18 @@ pub async fn run(
let info = Info { let info = Info {
model_id: model_info.model_id, model_id: model_info.model_id,
model_sha: model_info.sha, model_sha: model_info.sha,
model_dtype: shard_info.dtype, // model_dtype: shard_info.dtype,
model_device_type: shard_info.device_type, // model_device_type: shard_info.device_type,
model_pipeline_tag: model_info.pipeline_tag, model_pipeline_tag: model_info.pipeline_tag,
max_concurrent_requests, max_concurrent_requests,
max_best_of, max_best_of,
max_stop_sequences, max_stop_sequences,
max_input_tokens, max_input_tokens,
max_total_tokens, max_total_tokens,
waiting_served_ratio, // waiting_served_ratio,
max_batch_total_tokens, // max_batch_total_tokens,
max_waiting_tokens, // max_waiting_tokens,
max_batch_size, // max_batch_size,
validation_workers, validation_workers,
max_client_batch_size, max_client_batch_size,
router: env!("CARGO_PKG_NAME"), router: env!("CARGO_PKG_NAME"),
...@@ -1907,7 +2102,6 @@ pub async fn run( ...@@ -1907,7 +2102,6 @@ pub async fn run(
// add layers after routes // add layers after routes
app = app app = app
.layer(Extension(info)) .layer(Extension(info))
.layer(Extension(health_ext.clone()))
.layer(Extension(compat_return_full_text)) .layer(Extension(compat_return_full_text))
.layer(Extension(infer)) .layer(Extension(infer))
.layer(Extension(compute_type)) .layer(Extension(compute_type))
...@@ -1945,6 +2139,68 @@ pub async fn run( ...@@ -1945,6 +2139,68 @@ pub async fn run(
Ok(()) Ok(())
} }
/// get model info from the Huggingface Hub
pub async fn get_hub_model_info(api: &ApiRepo) -> Option<HubModelInfo> {
let response = api.info_request().send().await.ok()?;
if response.status().is_success() {
let hub_model_info: HubModelInfo =
serde_json::from_str(&response.text().await.ok()?).ok()?;
if let Some(sha) = &hub_model_info.sha {
tracing::info!(
"Serving revision {sha} of model {}",
hub_model_info.model_id
);
}
Some(hub_model_info)
} else {
None
}
}
/// get base tokenizer
pub async fn get_base_tokenizer(api: &Api, api_repo: &ApiRepo) -> Option<PathBuf> {
let config_filename = api_repo.get("config.json").await.ok()?;
// Open the file in read-only mode with buffer.
let file = File::open(config_filename).ok()?;
let reader = BufReader::new(file);
// Read the JSON contents of the file as an instance of `User`.
let config: serde_json::Value = serde_json::from_reader(reader).ok()?;
if let Some(serde_json::Value::String(base_model_id)) = config.get("base_model_name_or_path") {
let api_base_repo = api.repo(Repo::with_revision(
base_model_id.to_string(),
RepoType::Model,
"main".to_string(),
));
api_base_repo.get("tokenizer.json").await.ok()
} else {
None
}
}
/// get tokenizer_config from the Huggingface Hub
pub async fn get_tokenizer_config(api_repo: &ApiRepo) -> Option<HubTokenizerConfig> {
let tokenizer_config_filename = api_repo.get("tokenizer_config.json").await.ok()?;
// Open the file in read-only mode with buffer.
let file = File::open(tokenizer_config_filename).ok()?;
let reader = BufReader::new(file);
// Read the JSON contents of the file as an instance of 'HubTokenizerConfig'.
let tokenizer_config: HubTokenizerConfig = serde_json::from_reader(reader)
.map_err(|e| {
tracing::warn!("Unable to parse tokenizer config: {}", e);
e
})
.ok()?;
Some(tokenizer_config)
}
/// Shutdown signal handler /// Shutdown signal handler
async fn shutdown_signal() { async fn shutdown_signal() {
let ctrl_c = async { let ctrl_c = async {
...@@ -2008,16 +2264,77 @@ impl From<InferError> for Event { ...@@ -2008,16 +2264,77 @@ impl From<InferError> for Event {
#[derive(Debug, Error)] #[derive(Debug, Error)]
pub enum WebServerError { pub enum WebServerError {
#[error("Unable to connect to the Python model shards: {0}")]
Connection(ClientError),
#[error("Unable to clear the Python model shards cache: {0}")]
Cache(ClientError),
#[error("Unable to get the Python model shards info: {0}")]
Info(ClientError),
#[error("Unable to warmup the Python model shards: {0}")]
Warmup(ClientError),
#[error("Not enough memory to handle `max_total_tokens={0}`")]
NotEnoughMemory(usize),
#[error("Axum error: {0}")] #[error("Axum error: {0}")]
Axum(#[from] axum::BoxError), Axum(#[from] axum::BoxError),
} }
/// Create a post_processor for the LlamaTokenizer
fn create_post_processor(
tokenizer: &Tokenizer,
tokenizer_config: &HubTokenizerConfig,
) -> Result<TemplateProcessing, tokenizers::processors::template::TemplateProcessingBuilderError> {
let add_bos_token = tokenizer_config.add_bos_token.unwrap_or(true);
let add_eos_token = tokenizer_config.add_eos_token.unwrap_or(false);
let bos_token = tokenizer_config.bos_token.as_ref();
let eos_token = tokenizer_config.eos_token.as_ref();
if add_bos_token && bos_token.is_none() {
panic!("add_bos_token = true but bos_token is None");
}
if add_eos_token && eos_token.is_none() {
panic!("add_eos_token = true but eos_token is None");
}
let mut single = Vec::new();
let mut pair = Vec::new();
let mut special_tokens = Vec::new();
if add_bos_token {
if let Some(bos) = bos_token {
let bos_token_id = tokenizer
.token_to_id(bos.as_str())
.expect("Should have found the bos token id");
special_tokens.push((bos.as_str(), bos_token_id));
single.push(format!("{}:0", bos.as_str()));
pair.push(format!("{}:0", bos.as_str()));
}
}
single.push("$A:0".to_string());
pair.push("$A:0".to_string());
if add_eos_token {
if let Some(eos) = eos_token {
let eos_token_id = tokenizer
.token_to_id(eos.as_str())
.expect("Should have found the eos token id");
special_tokens.push((eos.as_str(), eos_token_id));
single.push(format!("{}:0", eos.as_str()));
pair.push(format!("{}:0", eos.as_str()));
}
}
if add_bos_token {
if let Some(bos) = bos_token {
pair.push(format!("{}:1", bos.as_str()));
}
}
pair.push("$B:1".to_string());
if add_eos_token {
if let Some(eos) = eos_token {
pair.push(format!("{}:1", eos.as_str()));
}
}
let post_processor = TemplateProcessing::builder()
.try_single(single)?
.try_pair(pair)?
.special_tokens(special_tokens)
.build()?;
Ok(post_processor)
}
...@@ -78,11 +78,11 @@ pub struct Args { ...@@ -78,11 +78,11 @@ pub struct Args {
max_top_n_tokens: u32, max_top_n_tokens: u32,
max_input_tokens: usize, max_input_tokens: usize,
max_total_tokens: usize, max_total_tokens: usize,
waiting_served_ratio: f32, // waiting_served_ratio: f32,
max_batch_prefill_tokens: u32, // max_batch_prefill_tokens: u32,
max_batch_total_tokens: Option<u32>, // max_batch_total_tokens: Option<u32>,
max_waiting_tokens: usize, // max_waiting_tokens: usize,
max_batch_size: Option<usize>, // max_batch_size: Option<usize>,
revision: Option<String>, revision: Option<String>,
validation_workers: usize, validation_workers: usize,
messages_api_enabled: bool, messages_api_enabled: bool,
...@@ -103,11 +103,11 @@ impl Args { ...@@ -103,11 +103,11 @@ impl Args {
max_top_n_tokens: u32, max_top_n_tokens: u32,
max_input_tokens: usize, max_input_tokens: usize,
max_total_tokens: usize, max_total_tokens: usize,
waiting_served_ratio: f32, // waiting_served_ratio: f32,
max_batch_prefill_tokens: u32, // max_batch_prefill_tokens: u32,
max_batch_total_tokens: Option<u32>, // max_batch_total_tokens: Option<u32>,
max_waiting_tokens: usize, // max_waiting_tokens: usize,
max_batch_size: Option<usize>, // max_batch_size: Option<usize>,
revision: Option<String>, revision: Option<String>,
validation_workers: usize, validation_workers: usize,
messages_api_enabled: bool, messages_api_enabled: bool,
...@@ -125,11 +125,11 @@ impl Args { ...@@ -125,11 +125,11 @@ impl Args {
max_top_n_tokens, max_top_n_tokens,
max_input_tokens, max_input_tokens,
max_total_tokens, max_total_tokens,
waiting_served_ratio, // waiting_served_ratio,
max_batch_prefill_tokens, // max_batch_prefill_tokens,
max_batch_total_tokens, // max_batch_total_tokens,
max_waiting_tokens, // max_waiting_tokens,
max_batch_size, // max_batch_size,
revision, revision,
validation_workers, validation_workers,
messages_api_enabled, messages_api_enabled,
......
...@@ -5,13 +5,12 @@ use crate::{ ...@@ -5,13 +5,12 @@ use crate::{
GenerateParameters, GenerateRequest, GrammarType, HubPreprocessorConfig, Idefics2Preprocessor, GenerateParameters, GenerateRequest, GrammarType, HubPreprocessorConfig, Idefics2Preprocessor,
}; };
use base64::{engine::general_purpose::STANDARD, Engine}; use base64::{engine::general_purpose::STANDARD, Engine};
use image::{io::Reader as ImageReader, ImageFormat}; use image::{ImageFormat, ImageReader};
use jsonschema::{Draft, JSONSchema}; use jsonschema::{Draft, JSONSchema};
use rand::{thread_rng, Rng}; use rand::{thread_rng, Rng};
use serde_json::Value; use serde_json::Value;
use std::io::Cursor; use std::io::Cursor;
use std::iter; use std::iter;
use text_generation_client::{Chunk, Image, InputChunk};
use thiserror::Error; use thiserror::Error;
use tokenizers::tokenizer::Tokenizer; use tokenizers::tokenizer::Tokenizer;
use tokio::sync::mpsc; use tokio::sync::mpsc;
...@@ -96,7 +95,7 @@ impl Validation { ...@@ -96,7 +95,7 @@ impl Validation {
&self, &self,
inputs: String, inputs: String,
truncate: Option<usize>, truncate: Option<usize>,
) -> Result<Option<(tokenizers::Encoding, Vec<InputChunk>)>, ValidationError> { ) -> Result<Option<(tokenizers::Encoding, Vec<Chunk>)>, ValidationError> {
// If we have a fast tokenizer // If we have a fast tokenizer
if let Some(sender) = &self.sender { if let Some(sender) = &self.sender {
// Create response channel // Create response channel
...@@ -122,7 +121,7 @@ impl Validation { ...@@ -122,7 +121,7 @@ impl Validation {
inputs: String, inputs: String,
truncate: Option<usize>, truncate: Option<usize>,
max_new_tokens: Option<u32>, max_new_tokens: Option<u32>,
) -> Result<(Vec<InputChunk>, usize, u32), ValidationError> { ) -> Result<(Vec<Chunk>, usize, u32), ValidationError> {
// If we have a fast tokenizer // If we have a fast tokenizer
if let Some((encoding, inputs)) = self.tokenize(inputs.clone(), truncate).await? { if let Some((encoding, inputs)) = self.tokenize(inputs.clone(), truncate).await? {
// Create response channel // Create response channel
...@@ -181,11 +180,7 @@ impl Validation { ...@@ -181,11 +180,7 @@ impl Validation {
input_length = input_length.saturating_sub(max_new_tokens as usize); input_length = input_length.saturating_sub(max_new_tokens as usize);
} }
Ok(( Ok((vec![Chunk::Text(inputs)], input_length, max_new_tokens))
vec![Chunk::Text(inputs).into()],
input_length,
max_new_tokens,
))
} }
} }
...@@ -589,7 +584,7 @@ fn prepare_input( ...@@ -589,7 +584,7 @@ fn prepare_input(
tokenizer: &Tokenizer, tokenizer: &Tokenizer,
config: Option<&Config>, config: Option<&Config>,
preprocessor_config: Option<&HubPreprocessorConfig>, preprocessor_config: Option<&HubPreprocessorConfig>,
) -> Result<(tokenizers::Encoding, Vec<InputChunk>), ValidationError> { ) -> Result<(tokenizers::Encoding, Vec<Chunk>), ValidationError> {
use Config::*; use Config::*;
static RE: Lazy<Regex> = Lazy::new(|| Regex::new(r"!\[\]\([^\)]*\)").unwrap()); static RE: Lazy<Regex> = Lazy::new(|| Regex::new(r"!\[\]\([^\)]*\)").unwrap());
let (tokenizer_query, input_chunks) = match config { let (tokenizer_query, input_chunks) = match config {
...@@ -601,16 +596,16 @@ fn prepare_input( ...@@ -601,16 +596,16 @@ fn prepare_input(
let chunk_start = chunk.start(); let chunk_start = chunk.start();
let chunk_end = chunk.end(); let chunk_end = chunk.end();
if chunk_start != start { if chunk_start != start {
input_chunks.push(Chunk::Text(inputs[start..chunk_start].to_string()).into()); input_chunks.push(Chunk::Text(inputs[start..chunk_start].to_string()));
tokenizer_query.push_str(&inputs[start..chunk_start]); tokenizer_query.push_str(&inputs[start..chunk_start]);
} }
let (data, mimetype, height, width) = fetch_image(&inputs[chunk_start..chunk_end])?; let (data, mimetype, height, width) = fetch_image(&inputs[chunk_start..chunk_end])?;
input_chunks.push(Chunk::Image(Image { data, mimetype }).into()); input_chunks.push(Chunk::Image(Image { data, mimetype }));
tokenizer_query.push_str(&image_tokens(config, preprocessor_config, height, width)); tokenizer_query.push_str(&image_tokens(config, preprocessor_config, height, width));
start = chunk_end; start = chunk_end;
} }
if start != inputs.len() { if start != inputs.len() {
input_chunks.push(Chunk::Text(inputs[start..].to_string()).into()); input_chunks.push(Chunk::Text(inputs[start..].to_string()));
tokenizer_query.push_str(&inputs[start..]); tokenizer_query.push_str(&inputs[start..]);
} }
...@@ -618,7 +613,7 @@ fn prepare_input( ...@@ -618,7 +613,7 @@ fn prepare_input(
(tokenizer_query, input_chunks) (tokenizer_query, input_chunks)
} }
_ => (inputs.clone(), vec![Chunk::Text(inputs).into()]), _ => (inputs.clone(), vec![Chunk::Text(inputs)]),
}; };
// Get the number of tokens in the input // Get the number of tokens in the input
...@@ -631,18 +626,51 @@ fn prepare_input( ...@@ -631,18 +626,51 @@ fn prepare_input(
type TokenizerRequest = ( type TokenizerRequest = (
(String, Option<usize>), (String, Option<usize>),
oneshot::Sender<Result<(tokenizers::Encoding, Vec<InputChunk>), ValidationError>>, oneshot::Sender<Result<(tokenizers::Encoding, Vec<Chunk>), ValidationError>>,
Span, Span,
); );
#[derive(Debug, Clone, Eq, PartialEq)]
pub struct Image {
pub data: Vec<u8>,
pub mimetype: String,
}
#[derive(Debug, Clone, Eq, PartialEq)]
pub enum Chunk {
Text(String),
Image(Image),
}
/// Convert input chunks to a stringly-typed input for backwards
/// compat for backends that haven't implemented chunked inputs.
pub trait ChunksToString {
/// Convert chunks to string.
fn chunks_to_string(&self) -> String;
}
impl ChunksToString for Vec<Chunk> {
fn chunks_to_string(&self) -> String {
let mut output = String::new();
self.iter().for_each(|c| match &c {
Chunk::Text(text) => output.push_str(text),
Chunk::Image(Image { data, mimetype }) => {
let encoded = STANDARD.encode(data);
output.push_str(&format!("![](data:{};base64,{})", mimetype, encoded))
}
});
output
}
}
#[derive(Debug, Clone)] #[derive(Debug, Clone)]
pub(crate) enum ValidGrammar { pub enum ValidGrammar {
Json(String), Json(String),
Regex(String), Regex(String),
} }
#[derive(Debug, Clone)] #[derive(Debug, Clone)]
pub(crate) struct ValidParameters { pub struct ValidParameters {
/// / exponential scaling output probability distribution /// / exponential scaling output probability distribution
pub temperature: f32, pub temperature: f32,
/// / restricting to the k highest probability elements /// / restricting to the k highest probability elements
...@@ -666,7 +694,7 @@ pub(crate) struct ValidParameters { ...@@ -666,7 +694,7 @@ pub(crate) struct ValidParameters {
} }
#[derive(Debug, Clone)] #[derive(Debug, Clone)]
pub(crate) struct ValidStoppingParameters { pub struct ValidStoppingParameters {
/// / Maximum number of generated tokens /// / Maximum number of generated tokens
pub max_new_tokens: u32, pub max_new_tokens: u32,
/// / Optional stopping sequences /// / Optional stopping sequences
...@@ -677,8 +705,8 @@ pub(crate) struct ValidStoppingParameters { ...@@ -677,8 +705,8 @@ pub(crate) struct ValidStoppingParameters {
} }
#[derive(Debug, Clone)] #[derive(Debug, Clone)]
pub(crate) struct ValidGenerateRequest { pub struct ValidGenerateRequest {
pub inputs: Vec<InputChunk>, pub inputs: Vec<Chunk>,
pub input_length: u32, pub input_length: u32,
pub truncate: u32, pub truncate: u32,
pub decoder_input_details: bool, pub decoder_input_details: bool,
...@@ -750,6 +778,8 @@ pub enum ValidationError { ...@@ -750,6 +778,8 @@ pub enum ValidationError {
InvalidImageContent(String), InvalidImageContent(String),
#[error("Could not fetch image: {0}")] #[error("Could not fetch image: {0}")]
FailedFetchImage(#[from] reqwest::Error), FailedFetchImage(#[from] reqwest::Error),
#[error("{0} modality is not supported")]
UnsupportedModality(&'static str),
} }
#[cfg(test)] #[cfg(test)]
......
...@@ -167,22 +167,24 @@ def check_openapi(check: bool): ...@@ -167,22 +167,24 @@ def check_openapi(check: bool):
else: else:
os.rename(tmp_filename, filename) os.rename(tmp_filename, filename)
print("OpenAPI documentation updated.") print("OpenAPI documentation updated.")
errors = subprocess.run( p = subprocess.run(
[ [
"swagger-cli", "redocly",
# allow for trailing whitespace since it's not significant # allow for trailing whitespace since it's not significant
# and the precommit hook will remove it # and the precommit hook will remove it
"validate", "lint",
filename, filename,
], ],
capture_output=True, capture_output=True,
).stderr.decode("utf-8") )
errors = p.stderr.decode("utf-8")
# The openapi specs fails on `exclusive_minimum` which is expected to be a boolean where # The openapi specs fails on `exclusive_minimum` which is expected to be a boolean where
# utoipa outputs a value instead: https://github.com/juhaku/utoipa/issues/969 # utoipa outputs a value instead: https://github.com/juhaku/utoipa/issues/969
if not errors.startswith("Swagger schema validation failed."): print(errors)
if p.returncode != 0:
print(errors) print(errors)
raise Exception( raise Exception(
f"OpenAPI documentation is invalid, `swagger-cli validate` showed some error:\n {errors}" f"OpenAPI documentation is invalid, `redocly lint {filename}` showed some error:\n {errors}"
) )
return True return True
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment