Commit 81a882ad authored by jixx

add tgi2.4.0

parent 9822d7f6
use crate::client::Health;
/// Multi shard Client
use crate::client::{ClientError, Result};
use crate::client::grpc_client::{DecodeTimings, PrefillTimings};
use crate::client::{
Batch, CachedBatch, Client, Generation, GrammarType, HealthResponse,
NextTokenChooserParameters, Request, StoppingCriteriaParameters,
};
use crate::client::{Chunk, InfoResponse, Input};
use async_trait::async_trait;
use futures::future::join_all;
use tonic::transport::Uri;
use tracing::instrument;
#[derive(Debug, Clone)]
/// Text Generation Inference gRPC multi client
pub struct ShardedClient {
clients: Vec<Client>,
}
impl ShardedClient {
fn new(clients: Vec<Client>) -> Self {
Self { clients }
}
/// Create a new ShardedClient from a master client. The master client will communicate with
/// the other shards and return all URIs/UNIX sockets via the `service_discovery` gRPC method.
async fn from_master_client(mut master_client: Client) -> Result<Self> {
// Get all uris/unix sockets from the master client
let uris = master_client.service_discovery().await?;
let futures = uris.into_iter().map(Client::connect_uds);
let clients: Result<Vec<Client>> = join_all(futures).await.into_iter().collect();
Ok(Self::new(clients?))
}
/// Returns a client connected to the given uri
#[allow(dead_code)]
pub async fn connect(uri: Uri) -> Result<Self> {
let master_client = Client::connect(uri).await?;
Self::from_master_client(master_client).await
}
/// Returns a client connected to the given unix socket
pub async fn connect_uds(path: String) -> Result<Self> {
let master_client = Client::connect_uds(path).await?;
Self::from_master_client(master_client).await
}
/// Get the model info
#[instrument(skip(self))]
pub async fn info(&mut self) -> Result<InfoResponse> {
let futures: Vec<_> = self
.clients
.iter_mut()
.map(|client| client.info())
.collect();
join_all(futures).await.pop().unwrap()
}
/// GRPC health check
#[instrument(skip(self))]
pub async fn health(&mut self) -> Result<HealthResponse> {
let futures: Vec<_> = self
.clients
.iter_mut()
.map(|client| client.health())
.collect();
join_all(futures).await.pop().unwrap()
}
/// Clear the past generations cache
#[instrument(skip(self))]
pub async fn clear_cache(&mut self, batch_id: Option<u64>) -> Result<()> {
let futures: Vec<_> = self
.clients
.iter_mut()
.map(|client| client.clear_cache(batch_id))
.collect();
join_all(futures).await.into_iter().collect()
}
/// Filter a cached batch
#[instrument(skip(self))]
pub async fn filter_batch(
&mut self,
batch_id: u64,
request_ids: Vec<u64>,
) -> Result<Option<CachedBatch>> {
let futures: Vec<_> = self
.clients
.iter_mut()
.map(|client| Box::pin(client.filter_batch(batch_id, request_ids.clone())))
.collect();
// all shards return the same message
join_all(futures).await.pop().unwrap()
}
/// Warmup on a max size batch
///
/// Returns the maximum number of tokens supported by the hardware
#[instrument(skip(self))]
pub async fn warmup(
&mut self,
max_input_length: u32,
max_prefill_tokens: u32,
max_total_tokens: u32,
max_batch_size: Option<usize>,
) -> Result<Option<u32>> {
let futures: Vec<_> = self
.clients
.iter_mut()
.map(|client| {
Box::pin(client.warmup(
max_input_length,
max_prefill_tokens,
max_total_tokens,
max_batch_size,
))
})
.collect();
// Take the minimum value
let results = join_all(futures)
.await
.into_iter()
.collect::<Result<Vec<Option<u32>>>>()?;
Ok(results.into_iter().flatten().min())
}
/// Generate one token for each request in the given batch
///
/// Returns Generation for each request in batch
/// and the next cached batch
#[instrument(skip_all, fields(id = &batch.id, size = &batch.size))]
pub async fn prefill(
&mut self,
batch: Batch,
cached_batch: Option<CachedBatch>,
) -> Result<(Vec<Generation>, Option<CachedBatch>, PrefillTimings)> {
let futures: Vec<_> = self
.clients
.iter_mut()
.map(|client| Box::pin(client.prefill(batch.clone(), cached_batch.clone())))
.collect();
#[allow(clippy::type_complexity)]
let results: Result<Vec<(Vec<Generation>, Option<CachedBatch>, PrefillTimings)>> =
join_all(futures).await.into_iter().collect();
let mut results = results?;
let (mut generations, next_batch, mut timings) =
results.pop().ok_or(ClientError::EmptyResults)?;
// Merge generations from different model shards
for (mut shard_generations, _, shard_timings) in results.into_iter() {
generations.append(&mut shard_generations);
// Return the timings of the slowest shard
if shard_timings.total > timings.total {
timings = shard_timings;
}
}
Ok((generations, next_batch, timings))
}
/// Generate one token for each request in the given cached batches
///
/// Returns Generation for each request in batches
/// and the next cached batch
#[instrument(skip_all, fields(size = batches.iter().map(|batch| batch.size).sum::<u32>()))]
pub async fn decode(
&mut self,
batches: Vec<CachedBatch>,
) -> Result<(Vec<Generation>, Option<CachedBatch>, DecodeTimings)> {
let futures: Vec<_> = self
.clients
.iter_mut()
.map(|client| Box::pin(client.decode(batches.clone())))
.collect();
#[allow(clippy::type_complexity)]
let results: Result<Vec<(Vec<Generation>, Option<CachedBatch>, DecodeTimings)>> =
join_all(futures).await.into_iter().collect();
let mut results = results?;
let (mut generations, next_batch, mut timings) =
results.pop().ok_or(ClientError::EmptyResults)?;
// Merge generations from different model shards
for (mut shard_generations, _, shard_timings) in results.into_iter() {
generations.append(&mut shard_generations);
// Return the timings of the slowest shard
if shard_timings.total > timings.total {
timings = shard_timings;
}
}
Ok((generations, next_batch, timings))
}
}
#[async_trait]
impl Health for ShardedClient {
async fn device_health(&self) -> Result<()> {
self.clone().health().await?;
Ok(())
}
async fn model_health(&self) -> Result<()> {
// Dummy batch of 1 token and 1 generated token
let liveness_request = Request {
id: u64::MAX,
inputs: "liveness".to_string(),
input_chunks: Some(Input {
chunks: vec![Chunk::Text("liveness".into()).into()],
}),
truncate: 10,
add_special_tokens: true,
prefill_logprobs: false,
parameters: Some(NextTokenChooserParameters {
temperature: 1.0,
top_k: 0,
top_p: 1.0,
typical_p: 1.0,
do_sample: false,
seed: 0,
repetition_penalty: 1.0,
frequency_penalty: 0.0,
watermark: false,
grammar: String::new(),
grammar_type: GrammarType::None as i32,
}),
stopping_parameters: Some(StoppingCriteriaParameters {
max_new_tokens: 1,
stop_sequences: vec![],
ignore_eos_token: false,
}),
top_n_tokens: 0,
// Block 0 is reserved for health checks
blocks: vec![0],
slots: (0..16).collect(),
cache_len: 0,
adapter_id: None,
chunk_len: None,
};
let batch = Batch {
id: u64::MAX,
requests: vec![liveness_request],
size: 1,
max_tokens: 2,
max_blocks: 1,
};
self.clone().prefill(batch, None).await?;
Ok(())
}
}
mod backend;
pub mod block_allocator;
mod client;
mod queue;
pub mod radix;
use crate::client::{ClientError, ShardedClient};
pub(crate) use backend::BackendV3;
use serde::Serialize;
use thiserror::Error;
use utoipa::ToSchema;
#[derive(Clone, Debug, Serialize, ToSchema)]
pub struct BackendInfo {
/// Mandatory
#[schema(example = "cuda")]
pub model_device_type: String,
#[schema(example = "torch.float16")]
pub model_dtype: String,
/// Backend parameters
#[schema(example = "1")]
pub speculate: usize,
#[schema(example = "1.2")]
pub waiting_served_ratio: f32,
#[schema(example = "32000")]
pub max_batch_total_tokens: u32,
#[schema(example = "20")]
pub max_waiting_tokens: usize,
#[schema(nullable = true, example = "null")]
pub max_batch_size: Option<usize>,
#[schema(example = "false")]
pub support_chunking: bool,
#[schema(example = "false")]
pub prefix_caching: bool,
#[schema(example = "flashinfer")]
pub attention_impl: String,
#[schema(example = "1")]
pub block_size: u32,
}
#[allow(clippy::too_many_arguments)]
pub async fn connect_backend(
max_input_tokens: usize,
max_total_tokens: usize,
master_shard_uds_path: String,
waiting_served_ratio: f32,
max_batch_prefill_tokens: u32,
max_batch_total_tokens: Option<u32>,
max_waiting_tokens: usize,
max_batch_size: Option<usize>,
) -> Result<(BackendV3, BackendInfo), V3Error> {
// Helper function
let check_max_batch_total_tokens = |max_supported_batch_total_tokens: Option<u32>| {
match max_supported_batch_total_tokens {
// Older models do not support automatic max-batch-total-tokens
None => {
let max_batch_total_tokens = max_batch_total_tokens
.unwrap_or(16000.max((max_total_tokens as u32).max(max_batch_prefill_tokens)));
tracing::warn!("Model does not support automatic max batch total tokens");
Ok(max_batch_total_tokens)
}
// Flash attention models return their max supported total tokens
Some(max_supported_batch_total_tokens) => {
// Warn if the user set their own max-batch-total-tokens, as we will ignore it
if max_batch_total_tokens.is_some() {
tracing::warn!(
"`--max-batch-total-tokens` is deprecated for Flash \
Attention models."
);
tracing::warn!(
"Inferred max batch total tokens: {max_supported_batch_total_tokens}"
);
}
if max_total_tokens as u32 > max_supported_batch_total_tokens {
return Err(V3Error::NotEnoughMemory(max_total_tokens));
}
Ok(max_supported_batch_total_tokens)
}
}
};
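// For illustration: with the router CLI defaults further down in this commit
// (`max_total_tokens` = 2048, `max_batch_prefill_tokens` = 4096) and no explicit
// `--max-batch-total-tokens`, the fallback above resolves to
// max(16000, max(2048, 4096)) = 16000.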
let mut sharded_client = ShardedClient::connect_uds(master_shard_uds_path)
.await
.map_err(V3Error::Connection)?;
// server is running on v3
// Clear the cache; useful if the webserver rebooted
sharded_client
.clear_cache(None)
.await
.map_err(V3Error::Cache)?;
// Get info from the shard
let shard_info = sharded_client.info().await.map_err(V3Error::Info)?;
// Warmup model
tracing::info!("Warming up model");
let max_batch_total_tokens = check_max_batch_total_tokens(
sharded_client
.warmup(
max_input_tokens as u32,
max_batch_prefill_tokens,
max_total_tokens as u32,
max_batch_size,
)
.await
.map_err(V3Error::Warmup)?,
)?;
tracing::info!("Setting max batch total tokens to {max_batch_total_tokens}");
metrics::gauge!("tgi_batch_max_total_tokens").set(max_batch_total_tokens);
let backend_info = BackendInfo {
waiting_served_ratio,
max_batch_total_tokens,
max_waiting_tokens,
max_batch_size,
model_device_type: shard_info.device_type.clone(),
model_dtype: shard_info.dtype.clone(),
speculate: shard_info.speculate as usize,
support_chunking: shard_info.support_chunking,
prefix_caching: shard_info.use_prefix_caching,
attention_impl: shard_info.attention_impl.clone(),
block_size: shard_info.block_size,
};
let backend = BackendV3::new(
sharded_client,
waiting_served_ratio,
max_batch_prefill_tokens,
max_batch_total_tokens,
max_waiting_tokens,
max_batch_size,
shard_info,
);
tracing::info!("Using backend V3");
Ok((backend, backend_info))
}
#[derive(Debug, Error)]
pub enum V3Error {
#[error("Unable to clear the Python model shards cache: {0}")]
Cache(ClientError),
#[error("Unable to connect to the Python model shards: {0}")]
Connection(ClientError),
#[error("Unable to get the Python model shards info: {0}")]
Info(ClientError),
#[error("Unable to warmup the Python model shards: {0}")]
Warmup(ClientError),
#[error("Not enough memory to handle `max_total_tokens={0}`")]
NotEnoughMemory(usize),
}
use clap::{Parser, Subcommand};
use text_generation_router::{server, usage_stats};
use text_generation_router_v3::{connect_backend, V3Error};
use thiserror::Error;
/// App Configuration
#[derive(Parser, Debug)]
#[clap(author, version, about, long_about = None)]
struct Args {
#[command(subcommand)]
command: Option<Commands>,
#[clap(default_value = "128", long, env)]
max_concurrent_requests: usize,
#[clap(default_value = "2", long, env)]
max_best_of: usize,
#[clap(default_value = "4", long, env)]
max_stop_sequences: usize,
#[clap(default_value = "5", long, env)]
max_top_n_tokens: u32,
#[clap(default_value = "1024", long, env)]
max_input_tokens: usize,
#[clap(default_value = "2048", long, env)]
max_total_tokens: usize,
#[clap(default_value = "1.2", long, env)]
waiting_served_ratio: f32,
#[clap(default_value = "4096", long, env)]
max_batch_prefill_tokens: u32,
#[clap(long, env)]
max_batch_total_tokens: Option<u32>,
#[clap(default_value = "20", long, env)]
max_waiting_tokens: usize,
#[clap(long, env)]
max_batch_size: Option<usize>,
#[clap(default_value = "0.0.0.0", long, env)]
hostname: String,
#[clap(default_value = "3000", long, short, env)]
port: u16,
#[clap(default_value = "/tmp/text-generation-server-0", long, env)]
master_shard_uds_path: String,
#[clap(default_value = "bigscience/bloom", long, env)]
tokenizer_name: String,
#[clap(long, env)]
tokenizer_config_path: Option<String>,
#[clap(long, env)]
revision: Option<String>,
#[clap(long, env, value_enum)]
trust_remote_code: bool,
#[clap(default_value = "2", long, env)]
validation_workers: usize,
#[clap(long, env)]
api_key: Option<String>,
#[clap(long, env)]
json_output: bool,
#[clap(long, env)]
otlp_endpoint: Option<String>,
#[clap(default_value = "text-generation-inference.router", long, env)]
otlp_service_name: String,
#[clap(long, env)]
cors_allow_origin: Option<Vec<String>>,
#[clap(long, env)]
ngrok: bool,
#[clap(long, env)]
ngrok_authtoken: Option<String>,
#[clap(long, env)]
ngrok_edge: Option<String>,
#[clap(long, env, default_value_t = false)]
disable_grammar_support: bool,
#[clap(default_value = "4", long, env)]
max_client_batch_size: usize,
#[clap(default_value = "on", long, env)]
usage_stats: usage_stats::UsageStatsLevel,
}
#[derive(Debug, Subcommand)]
enum Commands {
PrintSchema,
}
#[tokio::main]
async fn main() -> Result<(), RouterError> {
// Get args
let args = Args::parse();
// Pattern match configuration
let Args {
command,
max_concurrent_requests,
max_best_of,
max_stop_sequences,
max_top_n_tokens,
max_input_tokens,
max_total_tokens,
waiting_served_ratio,
max_batch_prefill_tokens,
max_batch_total_tokens,
max_waiting_tokens,
max_batch_size,
hostname,
port,
master_shard_uds_path,
tokenizer_name,
tokenizer_config_path,
revision,
trust_remote_code,
validation_workers,
api_key,
json_output,
otlp_endpoint,
otlp_service_name,
cors_allow_origin,
ngrok,
ngrok_authtoken,
ngrok_edge,
disable_grammar_support,
max_client_batch_size,
usage_stats,
} = args;
if let Some(Commands::PrintSchema) = command {
use utoipa::OpenApi;
let api_doc = text_generation_router::server::ApiDoc::openapi();
let api_doc = serde_json::to_string_pretty(&api_doc).unwrap();
println!("{}", api_doc);
std::process::exit(0);
};
text_generation_router::logging::init_logging(otlp_endpoint, otlp_service_name, json_output);
// Validate args
if max_input_tokens >= max_total_tokens {
return Err(RouterError::ArgumentValidation(
"`max_input_tokens` must be < `max_total_tokens`".to_string(),
));
}
if validation_workers == 0 {
return Err(RouterError::ArgumentValidation(
"`validation_workers` must be > 0".to_string(),
));
}
if let Some(max_batch_size) = max_batch_size {
if max_batch_size == 0 {
return Err(RouterError::ArgumentValidation(
"`max_batch_size` must be > 0".to_string(),
));
}
}
let (backend, backend_info) = connect_backend(
max_input_tokens,
max_total_tokens,
master_shard_uds_path,
waiting_served_ratio,
max_batch_prefill_tokens,
max_batch_total_tokens,
max_waiting_tokens,
max_batch_size,
)
.await?;
// Validate remaining args now that the backend is known
let support_chunking = backend_info.support_chunking;
let max_batch_total_tokens = backend_info.max_batch_total_tokens;
if max_input_tokens as u32 > max_batch_prefill_tokens && !support_chunking {
return Err(RouterError::ArgumentValidation(format!("`max_batch_prefill_tokens` must be >= `max_input_tokens`. Given: {max_batch_prefill_tokens} and {max_input_tokens}")));
}
if max_batch_prefill_tokens > max_batch_total_tokens {
return Err(RouterError::ArgumentValidation(format!("`max_batch_prefill_tokens` must be <= `max_batch_total_tokens`. Given: {max_batch_prefill_tokens} and {max_batch_total_tokens}")));
}
if max_total_tokens as u32 > max_batch_total_tokens {
return Err(RouterError::ArgumentValidation(format!("`max_total_tokens` must be <= `max_batch_total_tokens`. Given: {max_total_tokens} and {max_batch_total_tokens}")));
}
// Run server
server::run(
backend,
max_concurrent_requests,
max_best_of,
max_stop_sequences,
max_top_n_tokens,
max_input_tokens,
max_total_tokens,
validation_workers,
api_key,
tokenizer_name,
tokenizer_config_path,
revision,
trust_remote_code,
hostname,
port,
cors_allow_origin,
ngrok,
ngrok_authtoken,
ngrok_edge,
disable_grammar_support,
max_client_batch_size,
usage_stats,
)
.await?;
Ok(())
}
#[derive(Debug, Error)]
enum RouterError {
#[error("Argument validation error: {0}")]
ArgumentValidation(String),
#[error("Backend failed: {0}")]
Backend(#[from] V3Error),
#[error("WebServer error: {0}")]
WebServer(#[from] server::WebServerError),
#[error("Tokio runtime failed to start: {0}")]
Tokio(#[from] std::io::Error),
}
use crate::block_allocator::{BlockAllocation, BlockAllocator};
use crate::client;
use crate::client::{
Batch, GrammarType, NextTokenChooserParameters, Request, StoppingCriteriaParameters,
};
use nohash_hasher::{BuildNoHashHasher, IntMap};
use std::cmp::max;
use std::collections::VecDeque;
use text_generation_router::infer::InferError;
use text_generation_router::infer::InferStreamResponse;
use text_generation_router::validation::{
Chunk, ChunksToString, ValidGenerateRequest, ValidGrammar, ValidParameters,
ValidStoppingParameters,
};
use tokio::sync::{mpsc, oneshot};
use tokio::time::Instant;
use tracing::{info_span, instrument, Instrument, Span};
......@@ -46,9 +46,11 @@ impl Queue {
pub(crate) fn new(
requires_padding: bool,
block_size: u32,
prefix_caching: bool,
window_size: Option<u32>,
speculate: u32,
max_batch_total_tokens: u32,
support_chunking: bool,
) -> Self {
// Create channel
let (queue_sender, queue_receiver) = mpsc::unbounded_channel();
......@@ -57,9 +59,11 @@ impl Queue {
tokio::spawn(queue_task(
requires_padding,
block_size,
prefix_caching,
window_size,
speculate,
max_batch_total_tokens,
support_chunking,
queue_receiver,
));
......@@ -85,6 +89,10 @@ impl Queue {
prefill_token_budget: u32,
token_budget: u32,
) -> Option<NextBatch> {
if prefill_token_budget == 0 || token_budget == 0 {
return None;
};
// Create response channel
let (response_sender, response_receiver) = oneshot::channel();
// Send next batch command to the background task managing the state
......@@ -106,27 +114,32 @@ impl Queue {
}
// Background task responsible for the queue state
#[allow(clippy::too_many_arguments)]
async fn queue_task(
requires_padding: bool,
block_size: u32,
prefix_caching: bool,
window_size: Option<u32>,
speculate: u32,
max_batch_total_tokens: u32,
support_chunking: bool,
mut receiver: mpsc::UnboundedReceiver<QueueCommand>,
) {
let mut state = State::new(
requires_padding,
block_size,
prefix_caching,
window_size,
speculate,
max_batch_total_tokens,
support_chunking,
);
while let Some(cmd) = receiver.recv().await {
match cmd {
QueueCommand::Append(entry, span) => {
span.in_scope(|| state.append(*entry));
metrics::increment_gauge!("tgi_queue_size", 1.0);
metrics::gauge!("tgi_queue_size").increment(1.0);
}
QueueCommand::NextBatch {
min_size,
......@@ -141,7 +154,7 @@ async fn queue_task(
.instrument(span)
.await;
response_sender.send(next_batch).unwrap();
metrics::gauge!("tgi_queue_size", state.entries.len() as f64);
metrics::gauge!("tgi_queue_size").set(state.entries.len() as f64);
}
}
}
......@@ -162,12 +175,14 @@ struct State {
/// Paged Attention block size
block_size: u32,
/// Sliding window
window_size: Option<u32>,
/// Speculation amount
speculate: u32,
/// Whether the model allows prefill chunking
/// If it does, the last request in the batch will be split to exactly match the prefill
/// token budget
support_chunking: bool,
/// Paged Attention Block Allocation
block_allocator: Option<BlockAllocator>,
}
......@@ -176,20 +191,28 @@ impl State {
fn new(
requires_padding: bool,
block_size: u32,
prefix_caching: bool,
window_size: Option<u32>,
speculate: u32,
max_batch_total_tokens: u32,
support_chunking: bool,
) -> Self {
let block_allocator = (!requires_padding)
.then(|| BlockAllocator::new(max_batch_total_tokens, block_size, window_size));
let block_allocator = (!requires_padding).then(|| {
BlockAllocator::new(
max_batch_total_tokens,
block_size,
prefix_caching,
window_size,
)
});
Self {
entries: VecDeque::with_capacity(128),
next_id: 0,
next_batch_id: 0,
block_size,
window_size,
speculate,
support_chunking,
block_allocator,
}
}
......@@ -226,29 +249,33 @@ impl State {
}
}
if let Some(max_size) = max_size {
if max_size == 0 {
tracing::debug!("No capacity");
return None;
}
}
// Pad prefill_token_budget to be a multiple of block size
let prefill_token_budget =
((prefill_token_budget + self.block_size - 1) / self.block_size) * self.block_size;
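// For example, with block_size = 16 a prefill_token_budget of 100 is rounded up to
// ((100 + 15) / 16) * 16 = 112, keeping the budget block-aligned.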
// Create span for this batch to add context to inference calls
let next_batch_span = info_span!(parent: None, "batch", batch_size = tracing::field::Empty);
next_batch_span.follows_from(&Span::current());
let mut batch_requests = Vec::with_capacity(self.entries.len());
let mut batch_entries =
IntMap::with_capacity_and_hasher(self.entries.len(), BuildNoHashHasher::default());
next_batch_span.follows_from(Span::current());
let mut batch = Vec::with_capacity(self.entries.len());
let mut max_input_length = 0;
let mut prefill_tokens: u32 = 0;
let mut decode_tokens: u32 = 0;
let mut max_blocks = 0;
// Pop entries starting from the front of the queue
'entry_loop: while let Some((id, mut entry)) = self.entries.pop_front() {
'entry_loop: while let Some((id, entry)) = self.entries.pop_front() {
// Filter entries where the response receiver was dropped (== entries where the request
// was dropped by the client)
if entry.response_tx.is_closed() {
metrics::increment_counter!("tgi_request_failure", "err" => "dropped");
metrics::counter!("tgi_request_failure", "err" => "dropped").increment(1);
tracing::debug!("Dropping entry");
continue;
}
......@@ -258,7 +285,7 @@ impl State {
// We pad to max input length in the Python shards
// We need to take these padding tokens into account
max_input_length = max_input_length.max(entry.request.input_length);
prefill_tokens = (batch_requests.len() + 1) as u32 * max_input_length;
prefill_tokens = (batch.len() + 1) as u32 * max_input_length;
decode_tokens += entry.request.stopping_parameters.max_new_tokens;
let total_tokens = prefill_tokens + decode_tokens + self.speculate;
......@@ -273,32 +300,21 @@ impl State {
None
}
Some(block_allocator) => {
prefill_tokens += entry.request.input_length;
let max_new_tokens = match self.window_size {
None => entry.request.stopping_parameters.max_new_tokens,
Some(window_size) => min(
window_size.saturating_sub(entry.request.input_length),
entry.request.stopping_parameters.max_new_tokens,
),
// If the user wants prefill logprobs, we cannot reuse the cache.
// So no input_ids for the radix tree.
let input_ids = if entry.request.decoder_input_details {
None
} else {
entry.request.input_ids.clone()
};
decode_tokens += max_new_tokens;
if prefill_tokens > prefill_token_budget
|| (prefill_tokens + decode_tokens + self.speculate) > token_budget
{
// Entry is over budget
// Add it back to the front
tracing::debug!("Over budget: prefill_tokens={prefill_tokens} > {prefill_token_budget} || {prefill_tokens} + {decode_tokens} + {} > {token_budget}", self.speculate);
self.entries.push_front((id, entry));
break;
}
let tokens = entry.request.input_length
+ entry.request.stopping_parameters.max_new_tokens
+ self.speculate
- 1;
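// For example: input_length = 20, max_new_tokens = 10, speculate = 2
// gives 20 + 10 + 2 - 1 = 31 tokens to allocate.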
tracing::debug!("Allocating {tokens} with {input_ids:?}");
match block_allocator.allocate(tokens).await {
let block_allocation = match block_allocator.allocate(tokens, input_ids).await {
None => {
// Entry is over budget
// Add it back to the front
......@@ -306,16 +322,88 @@ impl State {
self.entries.push_front((id, entry));
break 'entry_loop;
}
Some(block_allocation) => {
Some(mut block_allocation) => {
tracing::debug!("Allocation: {block_allocation:?}");
max_blocks = max(max_blocks, block_allocation.blocks.len() as u32);
Some(block_allocation)
if block_allocation.prefix_len == entry.request.input_length {
// The whole request was found in the radix trie
// However, for the transformer forward to work, we need to
// have at least one token of postfix.
block_allocation.prefix_len -= 1;
}
block_allocation
}
};
let postfix_len = entry.request.input_length - block_allocation.prefix_len;
if prefill_tokens + postfix_len > prefill_token_budget {
// Entry is over budget
if self.support_chunking {
// We support chunking, just set postfix_len to exactly match prefill_token_budget
let chunk_len = prefill_token_budget.saturating_sub(prefill_tokens);
if chunk_len > 0 {
// Push this entry inside the batch
batch.push((id, entry, Some(block_allocation), Some(chunk_len)));
} else {
// We cannot prefill even one token for this entry
// Add it back to the queue
self.entries.push_front((id, entry));
}
tracing::debug!(
"Matched budget: prefill_tokens={} == {prefill_token_budget}",
prefill_tokens + postfix_len
);
break 'entry_loop;
} else {
// We don't support chunking, this entry needs to go back to the buffer
// Add it back to the front
tracing::debug!(
"Over budget: prefill_tokens={} > {prefill_token_budget}",
prefill_tokens + postfix_len
);
self.entries.push_front((id, entry));
break 'entry_loop;
}
}
prefill_tokens += postfix_len;
Some(block_allocation)
}
};
batch.push((id, entry, block_allocation, None));
if Some(batch.len()) == max_size {
break;
}
}
tracing::debug!("Accepting entry");
// Empty batch
if batch.is_empty() {
tracing::debug!("Filterered out all entries");
return None;
}
// XXX We haven't allocated yet, so we're allowed to ditch the results.
// Check if our batch is big enough
if let Some(min_size) = min_size {
// Batch is too small
if batch.len() < min_size {
// Add back entries to the queue in the correct order
for (id, entry, _, _) in batch.into_iter().rev() {
self.entries.push_front((id, entry));
}
return None;
}
}
let mut batch_requests = Vec::with_capacity(self.entries.len());
let mut batch_entries =
IntMap::with_capacity_and_hasher(self.entries.len(), BuildNoHashHasher::default());
for (id, mut entry, block_allocation, chunk_len) in batch {
// Create a new span to link the batch back to this entry
let entry_batch_span = info_span!(parent: &entry.span, "infer");
// Add relationships
......@@ -324,11 +412,12 @@ impl State {
// Update entry
entry.temp_span = Some(entry_batch_span);
let (blocks, slots) = match &block_allocation {
None => (Vec::new(), Vec::new()),
let (blocks, slots, prefix_len) = match &block_allocation {
None => (Vec::new(), Vec::new(), 0),
Some(block_allocation) => (
block_allocation.blocks.clone(),
block_allocation.slots.clone(),
block_allocation.prefix_len,
),
};
......@@ -337,11 +426,26 @@ impl State {
batch_requests.push(Request {
id,
prefill_logprobs: entry.request.decoder_input_details,
input_chunks: Some(Input {
chunks: entry.request.inputs.clone(),
input_chunks: Some(client::Input {
chunks: entry
.request
.inputs
.clone()
.into_iter()
.map(|c| client::InputChunk {
chunk: Some(match c {
Chunk::Text(text) => client::Chunk::Text(text),
Chunk::Image(image) => client::Chunk::Image(client::Image {
data: image.data,
mimetype: image.mimetype,
}),
}),
})
.collect(),
}),
inputs: entry.request.inputs.chunks_to_string(),
truncate: entry.request.truncate,
add_special_tokens: entry.request.add_special_tokens,
parameters: Some(NextTokenChooserParameters::from(
entry.request.parameters.clone(),
)),
......@@ -351,38 +455,14 @@ impl State {
top_n_tokens: entry.request.top_n_tokens,
blocks,
slots,
cache_len: prefix_len,
adapter_id: entry.request.adapter_id.clone(),
chunk_len,
});
// Set batch_time
entry.batch_time = Some(Instant::now());
// Insert in batch_entries IntMap
batch_entries.insert(id, entry);
// Check if max_size
if Some(batch_requests.len()) == max_size {
break;
}
}
// Empty batch
if batch_requests.is_empty() {
tracing::debug!("Filterered out all entries");
return None;
}
// Check if our batch is big enough
if let Some(min_size) = min_size {
// Batch is too small
if batch_requests.len() < min_size {
// Add back entries to the queue in the correct order
for r in batch_requests.into_iter().rev() {
let id = r.id;
let entry = batch_entries.remove(&id).unwrap();
self.entries.push_front((id, entry));
}
return None;
}
}
// Final batch size
......@@ -399,7 +479,7 @@ impl State {
// Increment batch id
self.next_batch_id += 1;
metrics::histogram!("tgi_batch_next_size", batch.size as f64);
metrics::histogram!("tgi_batch_next_size").record(batch.size as f64);
Some((batch_entries, batch, next_batch_span))
}
......@@ -459,6 +539,8 @@ impl From<ValidStoppingParameters> for StoppingCriteriaParameters {
#[cfg(test)]
mod tests {
use std::sync::Arc;
use super::*;
use tracing::info_span;
......@@ -471,7 +553,9 @@ mod tests {
let entry = Entry {
request: ValidGenerateRequest {
inputs: vec![],
input_length: 0,
input_ids: Some(Arc::new(vec![])),
input_length: 1,
add_special_tokens: true,
truncate: 0,
decoder_input_details: false,
parameters: ValidParameters {
......@@ -506,7 +590,7 @@ mod tests {
#[tokio::test]
async fn test_append() {
let mut state = State::new(false, 1, None, 0, 16);
let mut state = State::new(false, 1, false, None, 0, 16, false);
let (entry, _guard) = default_entry();
assert_eq!(state.next_id, 0);
......@@ -522,7 +606,7 @@ mod tests {
#[tokio::test]
async fn test_next_batch_empty() {
let mut state = State::new(false, 1, None, 0, 16);
let mut state = State::new(false, 1, false, None, 0, 16, false);
assert!(state.next_batch(None, None, 1, 1).await.is_none());
assert!(state.next_batch(Some(1), None, 1, 1).await.is_none());
......@@ -530,7 +614,7 @@ mod tests {
#[tokio::test]
async fn test_next_batch_min_size() {
let mut state = State::new(false, 1, None, 0, 16);
let mut state = State::new(false, 1, false, None, 0, 16, false);
let (entry1, _guard1) = default_entry();
let (entry2, _guard2) = default_entry();
state.append(entry1);
......@@ -562,7 +646,7 @@ mod tests {
#[tokio::test]
async fn test_next_batch_max_size() {
let mut state = State::new(false, 1, None, 0, 16);
let mut state = State::new(false, 1, false, None, 0, 16, false);
let (entry1, _guard1) = default_entry();
let (entry2, _guard2) = default_entry();
state.append(entry1);
......@@ -582,7 +666,7 @@ mod tests {
#[tokio::test]
async fn test_next_batch_token_budget() {
let mut state = State::new(false, 1, None, 0, 2);
let mut state = State::new(false, 1, false, None, 0, 16, false);
let (entry1, _guard1) = default_entry();
let (entry2, _guard2) = default_entry();
state.append(entry1);
......@@ -615,14 +699,14 @@ mod tests {
#[tokio::test]
async fn test_queue_append() {
let queue = Queue::new(false, 1, None, 0, 16);
let queue = Queue::new(false, 1, false, None, 0, 16, false);
let (entry, _guard) = default_entry();
queue.append(entry);
}
#[tokio::test]
async fn test_queue_next_batch_empty() {
let queue = Queue::new(false, 1, None, 0, 16);
let queue = Queue::new(false, 1, false, None, 0, 16, false);
assert!(queue.next_batch(None, None, 1, 1).await.is_none());
assert!(queue.next_batch(Some(1), None, 1, 1).await.is_none());
......@@ -630,7 +714,7 @@ mod tests {
#[tokio::test]
async fn test_queue_next_batch_min_size() {
let queue = Queue::new(false, 1, None, 0, 16);
let queue = Queue::new(false, 1, false, None, 0, 16, false);
let (entry1, _guard1) = default_entry();
let (entry2, _guard2) = default_entry();
queue.append(entry1);
......@@ -663,7 +747,7 @@ mod tests {
#[tokio::test]
async fn test_queue_next_batch_max_size() {
let queue = Queue::new(false, 1, None, 0, 16);
let queue = Queue::new(false, 1, false, None, 0, 16, false);
let (entry1, _guard1) = default_entry();
let (entry2, _guard2) = default_entry();
queue.append(entry1);
......@@ -679,7 +763,7 @@ mod tests {
#[tokio::test]
async fn test_queue_next_batch_token_budget() {
let queue = Queue::new(false, 1, None, 0, 16);
let queue = Queue::new(false, 1, false, None, 0, 16, false);
let (entry1, _guard1) = default_entry();
let (entry2, _guard2) = default_entry();
queue.append(entry1);
......@@ -704,7 +788,7 @@ mod tests {
#[tokio::test]
async fn test_queue_next_batch_token_speculate() {
let queue = Queue::new(false, 1, None, 2, 16);
let queue = Queue::new(true, 1, false, None, 2, 16, false);
let (entry1, _guard1) = default_entry();
let (entry2, _guard2) = default_entry();
queue.append(entry1);
......@@ -723,7 +807,7 @@ mod tests {
#[tokio::test]
async fn test_queue_next_batch_dropped_receiver() {
let queue = Queue::new(false, 1, None, 0, 16);
let queue = Queue::new(false, 1, false, None, 0, 16, false);
let (entry, _) = default_entry();
queue.append(entry);
......
use crate::block_allocator::{Allocator, BlockAllocation};
use slotmap::{DefaultKey, SlotMap};
use std::hash::{Hash, Hasher};
use std::{
collections::{BTreeSet, HashMap},
sync::Arc,
};
fn hash(slice: &[u32]) -> u64 {
assert!(!slice.is_empty());
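// Single-element keys hash to the token value itself; longer keys go through the
// standard `DefaultHasher`.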
if slice.len() == 1 {
slice[0] as u64
} else {
let mut s = std::hash::DefaultHasher::new();
slice.hash(&mut s);
s.finish()
}
}
pub struct RadixAllocator {
allocation_id: u64,
allocations: HashMap<u64, RadixAllocation>,
cache_blocks: RadixTrie,
/// Blocks that are immediately available for allocation.
free_blocks: Vec<u32>,
#[allow(dead_code)]
// This isn't used because the prefix needs to match without the windowing
// mechanism. At worst this over-allocates; it is not incorrect.
window_size: Option<u32>,
block_size: u32,
}
impl RadixAllocator {
pub fn new(block_size: u32, n_blocks: u32, window_size: Option<u32>) -> Self {
RadixAllocator {
allocation_id: 0,
allocations: HashMap::new(),
cache_blocks: RadixTrie::new(block_size as usize),
// Block 0 is reserved for health checks.
free_blocks: (1..n_blocks).collect(),
window_size,
block_size,
}
}
fn alloc_or_reclaim(&mut self, n_blocks_needed: usize) -> Option<Vec<u32>> {
if self.free_blocks.len() < n_blocks_needed {
// This is a bit annoying, we first extend the free list and then
// split it off again below. This is because we need to put it on
// the free list if we cannot allocate enough blocks. This is only
// temporary, the trie needs to be able to report whether it can
// allocate the requested amount. Just not implemented yet.
tracing::debug!(
"Free blocks {} need {n_blocks_needed}",
self.free_blocks.len()
);
self.free_blocks.extend(
self.cache_blocks
.evict(n_blocks_needed - self.free_blocks.len()),
);
}
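// For example, if 5 blocks are needed and only 2 are free, we ask the trie to evict
// 3 more; the check below only hands out blocks if eviction recovered enough of them.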
if self.free_blocks.len() >= n_blocks_needed {
Some(
self.free_blocks
.split_off(self.free_blocks.len() - n_blocks_needed),
)
} else {
None
}
}
}
// Allocator trait
impl Allocator for RadixAllocator {
fn allocate(
&mut self,
tokens: u32,
prefill_tokens: Option<Arc<Vec<u32>>>,
) -> Option<BlockAllocation> {
let mut blocks = vec![];
let prefix_node = if let Some(prefill_tokens) = prefill_tokens.as_ref() {
let node_id = self
.cache_blocks
.find(prefill_tokens.as_slice(), &mut blocks);
node_id
} else {
self.cache_blocks.root_id()
};
// Even if this allocation fails below, we need to increase the
// refcount to ensure that the prefix that was found is not evicted.
self.cache_blocks
.incref(prefix_node)
.expect("Failed to increment refcount");
let prefix_len = blocks.len() * self.block_size as usize;
let suffix_len = tokens - prefix_len as u32;
let suffix_blocks = (suffix_len + self.block_size - 1) / self.block_size;
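// Ceiling division: e.g. with block_size = 2, tokens = 7 and one cached prefix block,
// prefix_len = 2, suffix_len = 5 and suffix_blocks = (5 + 1) / 2 = 3.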
tracing::info!("Prefix {prefix_len} - Suffix {suffix_len}");
match self.alloc_or_reclaim(suffix_blocks as usize) {
Some(suffix_blocks) => blocks.extend(suffix_blocks),
None => {
tracing::debug!("Cannot allocate {:?}", self.cache_blocks);
tracing::debug!("Found {prefix_len} prefix tokens need {suffix_blocks} suffix blocks for {tokens} tokens");
tracing::debug!("Block size {}", self.block_size);
self.cache_blocks
.decref(prefix_node)
.expect("Failed to decrement refcount");
return None;
}
}
// 1:1 mapping of blocks and slots.
let slots = if self.block_size == 1 {
blocks.clone()
} else {
let mut slots = Vec::with_capacity(blocks.len() * self.block_size as usize);
'slots: for block_id in &blocks {
for s in (block_id * self.block_size)..((block_id + 1) * self.block_size) {
slots.push(s);
if slots.len() as u32 == tokens {
break 'slots;
}
}
}
slots
};
let allocation = RadixAllocation {
prefix_node,
cached_prefix_len: prefix_len,
prefill_tokens: prefill_tokens.clone(),
};
self.allocation_id += 1;
self.allocations.insert(self.allocation_id, allocation);
Some(BlockAllocation {
allocation_id: self.allocation_id,
block_allocator: None,
blocks,
slots,
prefix_len: prefix_len as u32,
})
}
fn free(&mut self, blocks: Vec<u32>, allocation_id: u64) {
let allocation = match self.allocations.remove(&allocation_id) {
Some(allocation) => allocation,
None => unreachable!("Tried to free an unknown allocation."),
};
self.cache_blocks
.decref(allocation.prefix_node)
.expect("Failed to decrement refcount");
if let Some(prefill_tokens) = allocation.prefill_tokens {
let prefill_tokens = prefill_tokens.as_slice();
// If there are prefill tokens that did not come from the cache,
// add them to the cache.
if prefill_tokens.len() > allocation.cached_prefix_len {
let aligned =
(prefill_tokens.len() / self.block_size as usize) * self.block_size as usize;
if aligned > 0 {
let prefix_len = self
.cache_blocks
.insert(
&prefill_tokens[..aligned],
&blocks[..aligned / self.block_size as usize],
)
// Unwrap, failing is a programming error.
.expect("Failed to store prefill tokens");
// We can have a prefill with the following structure:
//
// |---| From the prefix cache.
// A B C D E F G
//|--------| Found in the trie during insertion.
//
// This means that while processing this request there was a
// partially overlapping request that had A..=E in its
// prefill. In this case we need to free the blocks D E.
if prefix_len > allocation.cached_prefix_len {
self.free_blocks.extend(
&blocks[allocation.cached_prefix_len / self.block_size as usize
..prefix_len / self.block_size as usize],
);
}
}
}
// Free non-prefill blocks.
self.free_blocks
.extend(&blocks[prefill_tokens.len() / self.block_size as usize..]);
} else {
self.free_blocks.extend(blocks);
}
}
}
struct RadixAllocation {
prefix_node: NodeId,
cached_prefix_len: usize,
prefill_tokens: Option<Arc<Vec<u32>>>,
}
// Radix trie that is heavily inspired by radix attention from sglang.
//
// The trie is optimized for prefix caching:
//
// - A normal radix trie stores discrete values. In this radix trie,
// inserting *abc* with value *xyz* will also enable lookup for
// *a* (*x*) and *ab* (*xy*).
// - As a result, every value is required to have the same length as
// the key.
// - We store additional information in each node, such as last access
// time and a reference count.
#[derive(Debug)]
pub enum TrieError {
InvalidNodeId,
RefCountUnderflow,
}
pub type NodeId = DefaultKey;
#[derive(Debug)]
pub struct RadixTrie {
/// Identifier of the root node.
root: DefaultKey,
/// Leaf node identifiers ordered by increasing recency.
leaves: BTreeSet<(u64, NodeId)>,
/// All trie nodes.
nodes: SlotMap<NodeId, TrieNode>,
/// Time as a monotonically increasing counter to avoid the system
/// call that a real time lookup would require.
time: u64,
/// All blocks need to be aligned with this
block_size: usize,
}
impl RadixTrie {
/// Construct a new radix trie.
pub fn new(block_size: usize) -> Self {
let root = TrieNode::new(vec![], vec![], 0, None);
let mut nodes = SlotMap::new();
let root = nodes.insert(root);
RadixTrie {
leaves: BTreeSet::new(),
nodes,
root,
time: 0,
block_size,
}
}
/// Find the prefix of the given tokens.
///
/// The blocks corresponding to the part of the prefix that could be found
/// are written to `blocks`. The number of blocks is in `0..=tokens.len()`.
/// Returns the identifier of the trie node that contains the longest
/// prefix. The node identifier can be used by callers to e.g. increase its
/// reference count.
///
/// Using this method will update the access time of the traversed nodes.
pub fn find(&mut self, key: &[u32], blocks: &mut Vec<u32>) -> NodeId {
self.time += 1;
self.find_(self.root, key, blocks)
}
/// Find worker.
fn find_(&mut self, mut node_id: NodeId, key: &[u32], blocks: &mut Vec<u32>) -> NodeId {
let node = &self.nodes[node_id];
if key.len() >= self.block_size {
let node_key = hash(&key[..self.block_size]);
if let Some(&child_id) = node.children.get(&node_key) {
self.update_access_time(child_id);
let child = self.nodes.get(child_id).expect("Invalid child identifier");
let shared_prefix_len = shared_prefix(&child.key, key, self.block_size);
assert_eq!(shared_prefix_len % self.block_size, 0);
blocks.extend(&child.blocks[..shared_prefix_len / self.block_size]);
let key = &key[shared_prefix_len..];
if !key.is_empty() {
node_id = self.find_(child_id, key, blocks);
}
}
}
node_id
}
/// Decrease the reference count of a node.
pub fn decref(&mut self, node_id: NodeId) -> Result<(), TrieError> {
// We don't care about refcounting for root, since it will never
// be evicted.
if node_id == self.root {
return Ok(());
}
let node = self
.nodes
.get_mut(node_id)
.ok_or(TrieError::InvalidNodeId)?;
if node.ref_count == 0 {
return Err(TrieError::RefCountUnderflow);
}
node.ref_count -= 1;
if node.ref_count == 0 {
assert!(
node.children.is_empty(),
"Nodes with children must have refcount > 0"
);
self.leaves.insert((node.last_accessed, node_id));
}
Ok(())
}
/// Increase the reference count of a node.
pub fn incref(&mut self, node_id: NodeId) -> Result<(), TrieError> {
if node_id == self.root {
return Ok(());
}
let node = self
.nodes
.get_mut(node_id)
.ok_or(TrieError::InvalidNodeId)?;
if node.ref_count == 0 {
self.leaves.remove(&(node.last_accessed, node_id));
}
node.ref_count += 1;
Ok(())
}
/// Evict `n_blocks` from the trie.
///
/// Returns the evicted blocks. When the length is less than `n_blocks`,
/// not enough blocks could be evicted.
pub fn evict(&mut self, n_blocks: usize) -> Vec<u32> {
// NOTE: we don't return Result here. If any of the unwrapping fails,
// it's a programming error in the trie implementation, not a user
// error caused by e.g. an invalid argument.
// TODO: add some bookkeeping in the future to check whether we can
// evict n_blocks and return `None` if we can't. We are now needlessly
// evicting prefixes from the cache in such a case.
let mut evicted = Vec::new();
tracing::debug!("Evicting in search of {n_blocks}");
while let Some((last_access, node_id)) = self.leaves.pop_first() {
let blocks_needed = n_blocks.saturating_sub(evicted.len());
tracing::debug!("Evicting node {node_id:?} ");
let node = self.nodes.get(node_id).expect("Leaf does not exist");
assert_eq!(
node.ref_count, 0,
"Leaf must have refcount of 0, got {}",
node.ref_count
);
if blocks_needed >= node.blocks.len() {
// We need to evict the whole node if we need more blocks than it has.
let node = self.remove_node(node_id);
evicted.extend(node.blocks);
if evicted.len() >= n_blocks {
break;
}
} else {
// The node has more blocks than needed, so we'll just remove
// the required number of blocks and leave the remaining blocks
// untouched.
let node = self.nodes.get_mut(node_id).expect("Leaf does not exist");
let truncate_blocks = node.blocks.len() - blocks_needed;
let truncate_tokens = truncate_blocks * self.block_size;
node.key.truncate(truncate_tokens);
evicted.extend(node.blocks.split_off(truncate_blocks));
self.leaves.insert((last_access, node_id));
break;
}
}
evicted
}
/// Insert a prefill along with its blocks.
///
/// This method returns the length of the prefix that was already
/// in the trie. E.g. if the length is 10, this means that for
/// the first 10 elements of the tree **the blocks are not updated**.
pub fn insert(&mut self, tokens: &[u32], blocks: &[u32]) -> Result<usize, TrieError> {
self.time += 1;
let common = self.insert_(self.root, tokens, blocks)?;
Ok(common)
}
/// Insertion worker.
fn insert_(
&mut self,
node_id: NodeId,
tokens: &[u32],
blocks: &[u32],
) -> Result<usize, TrieError> {
// TODO: in the future we may want to check that the blocks match for
// the part of the prefix that is already in the trie to detect
// mismatches.
assert_eq!(tokens.len(), blocks.len() * self.block_size);
let node_key = hash(&tokens[..self.block_size]);
if let Some(&child_id) = self.nodes[node_id].children.get(&node_key) {
self.update_access_time(child_id);
let child = self
.nodes
.get_mut(child_id)
// Unwrap here, since failure is a bug.
.expect("Child node does not exist");
let shared_prefix_len = shared_prefix(&child.key, tokens, self.block_size);
// We are done, the prefix is already in the trie.
if shared_prefix_len == tokens.len() || shared_prefix_len == 0 {
return Ok(shared_prefix_len);
}
// The node's prefix is a prefix of the insertion prefix.
if shared_prefix_len == child.key.len() {
return Ok(shared_prefix_len
+ self.insert_(
child_id,
&tokens[shared_prefix_len..],
&blocks[shared_prefix_len / self.block_size..],
)?);
}
// The node's prefix and the insertion prefix only match partially,
// split the node to just contain the matching part. Then insert the
// remainder of the prefix into the node again
let child_id = self.split_node(child_id, shared_prefix_len);
let key = &tokens[shared_prefix_len..];
let blocks = &blocks[shared_prefix_len / self.block_size..];
Ok(shared_prefix_len + self.insert_(child_id, key, blocks)?)
} else {
self.add_node(node_id, tokens, blocks);
Ok(0)
}
}
fn split_node(&mut self, node_id: NodeId, prefix_len: usize) -> NodeId {
// We have to make the current node a child to ensure that its
// properties and node id stay the same.
// This function unwraps, an invalid node_id is a programming error.
let node = self
.nodes
.get_mut(node_id)
.expect("Node to-be split does not exist");
let mut parent_key = node.key.split_off(prefix_len);
let prefix_blocks = prefix_len / self.block_size;
let mut parent_blocks = node.blocks.split_off(prefix_blocks);
// Move first part of the prefix to the parent. We swap to avoid
// an allocation + copy for both splits of the key/blocks.
std::mem::swap(&mut node.key, &mut parent_key);
std::mem::swap(&mut node.blocks, &mut parent_blocks);
let node_key = hash(&node.key[..self.block_size]);
let grandparent_id = node.parent.expect("Node does not have a parent");
let parent_id = self.add_node(grandparent_id, parent_key, parent_blocks);
self.add_node_to_parent(parent_id, node_key, node_id);
// Reborrow to make the borrow checker happy.
let node = self
.nodes
.get_mut(node_id)
.expect("Node to-be split does not exist");
node.parent = Some(parent_id);
parent_id
}
/// Create a node and add it to the parent.
fn add_node(
&mut self,
parent_id: NodeId,
key: impl Into<Vec<u32>>,
blocks: impl Into<Vec<u32>>,
) -> NodeId {
let key = key.into();
let blocks = blocks.into();
let first = hash(&key[..self.block_size]);
let child = TrieNode::new(key, blocks, self.time, Some(parent_id));
let child_id = self.nodes.insert(child);
self.add_node_to_parent(parent_id, first, child_id);
self.leaves.insert((self.time, child_id));
child_id
}
/// Add a node to the parent.
fn add_node_to_parent(&mut self, parent_id: NodeId, hash: u64, child_id: NodeId) {
// Unwrap here, passing in an unknown id is a programming error.
let parent = self.nodes.get_mut(parent_id).expect("Unknown parent node");
if parent.children.insert(hash, child_id).is_none() {
// Only increase reference count if child does not replace another child.
self.incref(parent_id)
.expect("Failed to increase parent refcount");
}
}
/// Remove a node from the trie.
fn remove_node(&mut self, node_id: NodeId) -> TrieNode {
// Unwrap here, passing in an unknown id is a programming error.
let node = self.nodes.remove(node_id).expect("Unknown node");
assert!(
node.children.is_empty(),
"Tried to remove a node with {} children",
node.children.len()
);
let parent_id = node.parent.expect("Attempted to remove root node");
let parent = self.nodes.get_mut(parent_id).expect("Unknown parent node");
let node_key = hash(&node.key[..self.block_size]);
parent.children.remove(&node_key);
self.decref(parent_id)
.expect("Failed to decrease parent refcount");
node
}
fn update_access_time(&mut self, node_id: NodeId) {
// Unwrap here, passing in an unknown id is a programming error.
let node = self.nodes.get_mut(node_id).expect("Unknown node");
// Update the ordered leaves set if the node is a leaf.
if self.leaves.remove(&(node.last_accessed, node_id)) {
self.leaves.insert((self.time, node_id));
}
node.last_accessed = self.time;
}
#[allow(dead_code)]
#[doc(hidden)]
/// Print debugging output for the trie.
///
/// In contrast to `Debug`, this output is nicely formatted.
pub fn print_debug(&self) {
self.print_debug_(self.root, 0);
}
fn print_debug_(&self, node_id: NodeId, indent: usize) {
let node = &self.nodes[node_id];
eprintln!(
"{}{:?}, key: {:?}, blocks: {:?}, ref_count: {}, last_accessed: {}, parent: {:?}, children: {:?}",
" ".repeat(indent),
node_id,
node.key,
node.blocks,
node.ref_count,
node.last_accessed,
node.parent,
node.children
);
for child_id in self.nodes[node_id].children.values() {
self.print_debug_(*child_id, indent + 2);
}
}
pub(crate) fn root_id(&self) -> DefaultKey {
self.root
}
}
/// Trie node.
#[derive(Debug)]
struct TrieNode {
blocks: Vec<u32>,
children: HashMap<u64, NodeId>,
key: Vec<u32>,
last_accessed: u64,
parent: Option<NodeId>,
ref_count: usize,
}
impl TrieNode {
fn new(key: Vec<u32>, blocks: Vec<u32>, last_accessed: u64, parent: Option<NodeId>) -> Self {
TrieNode {
children: HashMap::new(),
key,
blocks,
last_accessed,
parent,
ref_count: 0,
}
}
}
fn shared_prefix(left: &[u32], right: &[u32], block_size: usize) -> usize {
let full = left.iter().zip(right).take_while(|(a, b)| a == b).count();
// NOTE: this is the case because the child node was chosen based on
// matching the first character of the key/prefix.
assert!(full > 0, "Prefixes must at least share 1 token");
(full / block_size) * block_size
}
#[cfg(test)]
mod tests {
use std::sync::Arc;
use super::*;
#[test]
fn allocator_block_size() {
let mut cache = RadixAllocator::new(2, 12, None);
let allocation = cache.allocate(8, Some(Arc::new(vec![0, 1, 2, 3]))).unwrap();
assert_eq!(allocation.blocks, vec![8, 9, 10, 11]);
assert_eq!(allocation.slots, vec![16, 17, 18, 19, 20, 21, 22, 23]);
assert_eq!(allocation.prefix_len, 0);
cache.free(allocation.blocks.clone(), allocation.allocation_id);
let allocation = cache.allocate(8, Some(Arc::new(vec![0, 1, 2, 3]))).unwrap();
assert_eq!(allocation.blocks, vec![8, 9, 10, 11]);
assert_eq!(allocation.slots, vec![16, 17, 18, 19, 20, 21, 22, 23]);
assert_eq!(allocation.prefix_len, 4);
}
#[test]
fn allocator_block_size_non_aligned() {
let mut cache = RadixAllocator::new(2, 12, None);
let allocation = cache.allocate(7, Some(Arc::new(vec![0, 1, 2]))).unwrap();
assert_eq!(allocation.blocks, vec![8, 9, 10, 11]);
assert_eq!(allocation.slots, vec![16, 17, 18, 19, 20, 21, 22]);
assert_eq!(allocation.prefix_len, 0);
cache.free(allocation.blocks.clone(), allocation.allocation_id);
let allocation = cache.allocate(7, Some(Arc::new(vec![0, 1, 2]))).unwrap();
assert_eq!(allocation.blocks, vec![8, 9, 10, 11]);
assert_eq!(allocation.slots, vec![16, 17, 18, 19, 20, 21, 22]);
assert_eq!(allocation.prefix_len, 2);
}
#[test]
fn allocator_reuses_prefixes() {
let mut cache = RadixAllocator::new(1, 12, None);
let allocation = cache.allocate(8, Some(Arc::new(vec![0, 1, 2, 3]))).unwrap();
assert_eq!(allocation.blocks, vec![4, 5, 6, 7, 8, 9, 10, 11]);
assert_eq!(allocation.blocks, allocation.slots);
assert_eq!(allocation.prefix_len, 0);
cache.free(allocation.blocks.clone(), allocation.allocation_id);
let allocation = cache.allocate(8, Some(Arc::new(vec![0, 1, 2, 3]))).unwrap();
assert_eq!(allocation.blocks, vec![4, 5, 6, 7, 8, 9, 10, 11]);
assert_eq!(allocation.prefix_len, 4);
}
#[test]
fn allocator_collects_older_prefixes_first() {
let mut cache = RadixAllocator::new(1, 7, None);
let allocation1 = cache.allocate(4, Some(Arc::new(vec![0, 1, 2, 3]))).unwrap();
assert_eq!(allocation1.blocks, vec![3, 4, 5, 6]);
assert_eq!(allocation1.prefix_len, 0);
let allocation2 = cache.allocate(2, Some(Arc::new(vec![4, 5]))).unwrap();
assert_eq!(allocation2.blocks, vec![1, 2]);
assert_eq!(allocation2.prefix_len, 0);
cache.free(allocation1.blocks.clone(), allocation1.allocation_id);
cache.free(allocation2.blocks.clone(), allocation2.allocation_id);
// We should get the blocks of the first allocation, since they are more recent.
let allocation3 = cache.allocate(4, Some(Arc::new(vec![6, 7, 8, 9]))).unwrap();
assert_eq!(allocation3.blocks, vec![3, 4, 5, 6]);
assert_eq!(allocation3.prefix_len, 0);
}
#[test]
fn allocator_frees_fully_overlapping_prefills() {
let mut cache = RadixAllocator::new(1, 10, None);
let allocation1 = cache.allocate(4, Some(Arc::new(vec![0, 1, 2, 3]))).unwrap();
let allocation2 = cache.allocate(4, Some(Arc::new(vec![0, 1, 2, 3]))).unwrap();
cache.free(allocation2.blocks.clone(), allocation2.allocation_id);
cache.free(allocation1.blocks.clone(), allocation1.allocation_id);
let allocation3 = cache.allocate(4, Some(Arc::new(vec![0, 1, 2, 3]))).unwrap();
assert_eq!(allocation3.prefix_len, 4);
// 10 blocks, of which 1 reserved for health checks, 4 for the cached blocks.
assert_eq!(cache.free_blocks.len(), 5);
}
#[test]
fn allocator_frees_partially_overlapping_prefills() {
let mut cache = RadixAllocator::new(1, 20, None);
let allocation1 = cache.allocate(4, Some(Arc::new(vec![0, 1]))).unwrap();
assert_eq!(allocation1.blocks, vec![16, 17, 18, 19]);
assert_eq!(allocation1.prefix_len, 0);
cache.free(allocation1.blocks.clone(), allocation1.allocation_id);
let allocation2 = cache
.allocate(8, Some(Arc::new(vec![0, 1, 2, 3, 4, 5])))
.unwrap();
assert_eq!(allocation2.blocks, vec![16, 17, 12, 13, 14, 15, 18, 19]);
assert_eq!(allocation2.prefix_len, 2);
let allocation3 = cache
.allocate(8, Some(Arc::new(vec![0, 1, 2, 3, 6, 7])))
.unwrap();
assert_eq!(allocation3.blocks, vec![16, 17, 6, 7, 8, 9, 10, 11]);
assert_eq!(allocation3.prefix_len, 2);
cache.free(allocation3.blocks.clone(), allocation3.allocation_id);
cache.free(allocation2.blocks.clone(), allocation2.allocation_id);
// 20 blocks, of which 1 reserved for health checks, 6 for allocation3, 2 for allocation2.
assert_eq!(cache.free_blocks.len(), 11);
let allocation4 = cache
.allocate(6, Some(Arc::new(vec![0, 1, 2, 3, 4, 5])))
.unwrap();
assert_eq!(allocation4.blocks, vec![16, 17, 6, 7, 14, 15]);
assert_eq!(allocation4.prefix_len, 6);
assert_eq!(cache.free_blocks.len(), 11);
let allocation5 = cache
.allocate(6, Some(Arc::new(vec![0, 1, 2, 3, 6, 7])))
.unwrap();
assert_eq!(allocation5.blocks, vec![16, 17, 6, 7, 8, 9]);
assert_eq!(allocation5.prefix_len, 6);
assert_eq!(cache.free_blocks.len(), 11);
}
#[test]
fn trie_insertions_have_correct_prefix_len() {
let mut trie = RadixTrie::new(1);
assert_eq!(trie.insert(&[0, 1, 2], &[0, 1, 2]).unwrap(), 0);
// Already exists.
assert_eq!(trie.insert(&[0, 1, 2], &[0, 1, 2]).unwrap(), 3);
// Completely new at root-level
assert_eq!(trie.insert(&[1, 2, 3], &[1, 2, 3]).unwrap(), 0);
// Contains full prefix, but longer.
assert_eq!(trie.insert(&[0, 1, 2, 3, 4], &[0, 1, 2, 3, 4]).unwrap(), 3);
// Shares partial prefix, we need a split.
assert_eq!(
trie.insert(&[0, 1, 2, 3, 5, 6, 7], &[0, 1, 2, 3, 5, 6, 7])
.unwrap(),
4
);
}
#[test]
fn trie_insertions_block_size() {
let mut trie = RadixTrie::new(2);
assert_eq!(trie.insert(&[0, 1, 2, 3], &[0, 1]).unwrap(), 0);
// Already exists.
// But needs to be block_size aligned
assert_eq!(trie.insert(&[0, 1, 2, 3], &[0, 1]).unwrap(), 4);
// Completely new at root-level
assert_eq!(trie.insert(&[1, 2, 3, 4], &[1, 2]).unwrap(), 0);
// Contains full prefix, but longer.
assert_eq!(trie.insert(&[0, 1, 2, 3, 4, 5], &[0, 1, 2]).unwrap(), 4);
// Shares partial prefix, we need a split.
assert_eq!(
trie.insert(&[0, 1, 3, 4, 5, 6, 7, 8], &[0, 1, 2, 3])
.unwrap(),
2
);
}
#[test]
fn trie_get_returns_correct_blocks() {
let mut trie = RadixTrie::new(1);
trie.insert(&[0, 1, 2], &[0, 1, 2]).unwrap();
trie.insert(&[1, 2, 3], &[1, 2, 3]).unwrap();
trie.insert(&[0, 1, 2, 3, 4], &[0, 1, 2, 3, 4]).unwrap();
trie.insert(&[0, 1, 2, 3, 5, 6, 7], &[0, 1, 2, 3, 5, 6, 7])
.unwrap();
let mut blocks = Vec::new();
trie.find(&[0], &mut blocks);
assert_eq!(blocks, vec![0]);
blocks.clear();
trie.find(&[0, 1, 2], &mut blocks);
assert_eq!(blocks, vec![0, 1, 2]);
blocks.clear();
trie.find(&[1, 2, 3], &mut blocks);
assert_eq!(blocks, vec![1, 2, 3]);
blocks.clear();
trie.find(&[0, 1, 2, 3], &mut blocks);
assert_eq!(blocks, vec![0, 1, 2, 3]);
blocks.clear();
trie.find(&[0, 1, 2, 3, 4], &mut blocks);
assert_eq!(blocks, vec![0, 1, 2, 3, 4]);
blocks.clear();
trie.find(&[0, 1, 2, 3, 5], &mut blocks);
assert_eq!(blocks, vec![0, 1, 2, 3, 5]);
}
#[test]
fn trie_evict_removes_correct_blocks() {
let mut trie = RadixTrie::new(1);
trie.insert(&[0, 1, 2], &[0, 1, 2]).unwrap();
trie.insert(&[0, 1, 2, 3, 5, 6, 7], &[0, 1, 2, 3, 5, 6, 7])
.unwrap();
trie.insert(&[0, 1, 2, 3, 4], &[0, 1, 2, 3, 4]).unwrap();
trie.insert(&[1, 2, 3], &[1, 2, 3]).unwrap();
let mut blocks = Vec::new();
        // Remove fewer blocks than the leaf holds.
assert_eq!(trie.evict(1), vec![7]);
trie.find(&[0, 1, 2, 3, 5, 6, 7], &mut blocks);
assert_eq!(blocks, vec![0, 1, 2, 3, 5, 6]);
// Refresh other leaf.
trie.find(&[0, 1, 2, 3, 4], &mut blocks);
trie.find(&[1, 2, 3], &mut blocks);
        // Remove exactly the blocks remaining in the leaf.
assert_eq!(trie.evict(2), vec![5, 6]);
blocks.clear();
trie.find(&[0, 1, 2, 3, 5, 6, 7], &mut blocks);
assert_eq!(blocks, vec![0, 1, 2, 3]);
trie.find(&[1, 2, 3], &mut blocks);
        // Remove more blocks than the leaf holds.
assert_eq!(trie.evict(3), vec![4, 3, 2]);
blocks.clear();
trie.find(&[0, 1, 2, 3, 4], &mut blocks);
assert_eq!(blocks, vec![0, 1]);
// Clear out the whole trie.
assert_eq!(trie.evict(10), vec![1, 2, 3, 0, 1]);
}
}
......@@ -16,16 +16,15 @@ path = "src/main.rs"
[dependencies]
average = "0.14"
clap = { version = "4.4.5", features = ["derive", "env"] }
crossterm = "0.27"
float-ord = "0.3.2"
serde = {version = "1.0.188", features = ["derive"]}
serde_json = "1.0"
tabled = "0.14.0"
text-generation-client = { path = "../router/client" }
text-generation-client = { path = "../backends/client" }
thiserror = "1.0.48"
tokenizers = { workspace = true }
tokio = { version = "1.32.0", features = ["rt", "rt-multi-thread", "parking_lot", "signal", "sync", "macros"] }
tui = {package = "ratatui", version = "0.23", default-features = false, features = ["crossterm"]}
ratatui = "0.28.1"
tracing = "0.1.37"
tracing-subscriber = { version = "0.3.17", features = ["json", "env-filter"] }
hf-hub = { workspace = true }
......@@ -7,7 +7,7 @@
</div>
A lightweight benchmarking tool inspired by [oha](https://github.com/hatoo/oha)
and powered by [tui](https://github.com/tui-rs-revival/ratatui).
and powered by [Ratatui](https://github.com/ratatui/ratatui).
## Install
......
/// Inspired by https://github.com/hatoo/oha/blob/bb989ea3cd77727e7743e7daa60a19894bb5e901/src/monitor.rs
use crate::generation::{Decode, Message, Prefill};
use crossterm::event::{KeyCode, KeyEvent, KeyModifiers};
use text_generation_client::ClientError;
use tokio::sync::mpsc;
use tui::backend::Backend;
use tui::layout::{Alignment, Constraint, Direction, Layout};
use tui::style::{Color, Modifier, Style};
use tui::text::{Line, Span};
use tui::widgets::{
use ratatui::crossterm::event::{KeyCode, KeyEvent, KeyModifiers};
use ratatui::layout::{Alignment, Constraint, Direction, Layout};
use ratatui::style::{Color, Modifier, Style};
use ratatui::text::{Line, Span};
use ratatui::widgets::{
Axis, BarChart, Block, Borders, Chart, Dataset, Gauge, GraphType, Paragraph, Tabs,
};
use tui::{symbols, Frame};
use ratatui::{symbols, Frame};
use text_generation_client::ClientError;
use tokio::sync::mpsc;
/// TUI powered App
pub(crate) struct App {
......@@ -153,7 +152,7 @@ impl App {
}
/// Render frame
pub fn render<B: Backend>(&mut self, f: &mut Frame<'_, B>) {
pub fn render(&mut self, f: &mut Frame) {
let batch_progress =
(self.completed_batch as f64 / self.data.batch_size.len() as f64).clamp(0.0, 1.0);
let run_progress =
......@@ -172,7 +171,7 @@ impl App {
]
.as_ref(),
)
.split(f.size());
.split(f.area());
// Top row horizontal layout
let top = Layout::default()
......@@ -239,7 +238,7 @@ impl App {
f.render_widget(helper, row5[0]);
// Batch tabs
let titles = self
let titles: Vec<Line> = self
.data
.batch_size
.iter()
......
/// Inspired by https://github.com/orhun/rust-tui-template/blob/472aa515119d4c94903eac12d9784417281dc7f5/src/event.rs
use crossterm::event;
use ratatui::crossterm::event;
use std::time::{Duration, Instant};
use tokio::sync::{broadcast, mpsc};
......
......@@ -148,6 +148,7 @@ async fn prefill(
}),
inputs: sequence.clone(),
truncate: sequence_length,
add_special_tokens: true,
parameters: Some(parameters.clone()),
stopping_parameters: Some(StoppingCriteriaParameters {
max_new_tokens: decode_length,
......@@ -157,6 +158,8 @@ async fn prefill(
top_n_tokens: top_n_tokens.unwrap_or(0),
blocks: vec![],
slots: vec![],
cache_len: 0,
chunk_len: None,
adapter_id: None,
})
.collect();
......@@ -171,7 +174,7 @@ async fn prefill(
// Run prefill
let start_time = Instant::now();
let (_, decode_batch, _) = client.prefill(batch.clone()).await?;
let (_, decode_batch, _) = client.prefill(batch.clone(), None).await?;
// Get latency
let latency = start_time.elapsed();
......
......@@ -6,13 +6,13 @@ mod utils;
use crate::app::App;
use crate::event::Event;
use crossterm::ExecutableCommand;
use ratatui::backend::CrosstermBackend;
use ratatui::crossterm::ExecutableCommand;
use ratatui::Terminal;
use std::io;
use text_generation_client::v3::{GrammarType, NextTokenChooserParameters, ShardedClient};
use tokenizers::Tokenizer;
use tokio::sync::{broadcast, mpsc};
use tui::backend::CrosstermBackend;
use tui::Terminal;
/// Run benchmarking app
#[allow(clippy::too_many_arguments)]
......@@ -50,9 +50,9 @@ pub async fn run(
};
// Initialize terminal properties
crossterm::terminal::enable_raw_mode()?;
io::stdout().execute(crossterm::terminal::EnterAlternateScreen)?;
io::stdout().execute(crossterm::cursor::Hide)?;
ratatui::crossterm::terminal::enable_raw_mode()?;
io::stdout().execute(ratatui::crossterm::terminal::EnterAlternateScreen)?;
io::stdout().execute(ratatui::crossterm::cursor::Hide)?;
// Initialize terminal
let mut terminal = {
......@@ -128,9 +128,9 @@ pub async fn run(
let _ = shutdown_guard_receiver.recv().await;
// Revert terminal to original view
io::stdout().execute(crossterm::terminal::LeaveAlternateScreen)?;
crossterm::terminal::disable_raw_mode()?;
io::stdout().execute(crossterm::cursor::Show)?;
io::stdout().execute(ratatui::crossterm::terminal::LeaveAlternateScreen)?;
ratatui::crossterm::terminal::disable_raw_mode()?;
io::stdout().execute(ratatui::crossterm::cursor::Show)?;
let parameters_table = table::parameters_table(
tokenizer_name,
......
......@@ -178,6 +178,7 @@ fn main() -> Result<(), Box<dyn std::error::Error>> {
.clear_cache(None)
.await
.expect("Unable to clear cache");
tracing::info!("Connected");
// Run app
......
# Legacy warning ⚠️
The inference clients from [huggingface_hub](https://huggingface.co/docs/huggingface_hub/guides/inference) are recommended over `text_generation`.
# Text Generation
The Hugging Face Text Generation Python library provides a convenient way of interfacing with a
......
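For orientation (not part of the diff above), here is a minimal sketch of using the `text_generation` client; the endpoint URL is a placeholder and assumes a running TGI server:

```python
from text_generation import Client

# Placeholder endpoint: assumes a TGI server is listening locally.
client = Client("http://127.0.0.1:8080")

# Single-shot generation.
response = client.generate("What is Deep Learning?", max_new_tokens=20)
print(response.generated_text)

# Token-by-token streaming.
text = ""
for stream_response in client.generate_stream("What is Deep Learning?", max_new_tokens=20):
    if not stream_response.token.special:
        text += stream_response.token.text
print(text)
```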
......@@ -27,3 +27,6 @@ asyncio_mode = "auto"
[build-system]
requires = ["poetry-core>=1.0.0"]
build-backend = "poetry.core.masonry.api"
[tool.isort]
profile = "black"
import pytest
from text_generation import __version__
from huggingface_hub.utils import build_hf_headers
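# The llama_7b_url fixture points at a locally running TGI server; the other
# *_url fixtures still target the hosted Inference API.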
@pytest.fixture
def flan_t5_xxl():
return "google/flan-t5-xxl"
@pytest.fixture
def llama_7b():
return "meta-llama/Llama-2-7b-chat-hf"
@pytest.fixture
def fake_model():
return "fake/model"
@pytest.fixture
def unsupported_model():
return "gpt2"
@pytest.fixture
def base_url():
return "https://api-inference.huggingface.co/models"
@pytest.fixture
def bloom_url(base_url, bloom_model):
return f"{base_url}/{bloom_model}"
@pytest.fixture
def flan_t5_xxl_url(base_url, flan_t5_xxl):
return f"{base_url}/{flan_t5_xxl}"
@pytest.fixture
def llama_7b_url(base_url, llama_7b):
# return f"{base_url}/{llama_7b}"
return "http://localhost:3001"
@pytest.fixture
def fake_url(base_url, fake_model):
return f"{base_url}/{fake_model}"
@pytest.fixture
def unsupported_url(base_url, unsupported_model):
return f"{base_url}/{unsupported_model}"
@pytest.fixture(scope="session")
def hf_headers():
# return build_hf_headers(
# library_name="text-generation-tests", library_version=__version__
# )
header = {'content-type': 'application/json'}
return header
import pytest
from text_generation import Client, AsyncClient
from text_generation.errors import NotFoundError, ValidationError
from text_generation.types import FinishReason, InputToken
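# These tests exercise a live endpoint (llama_7b_url fixture) and assert exact
# Llama-2 token ids and logprobs.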
def test_generate(llama_7b_url, hf_headers):
client = Client(llama_7b_url, hf_headers)
response = client.generate("test", max_new_tokens=1, decoder_input_details=True)
assert response.generated_text == "_"
assert response.details.finish_reason == FinishReason.Length
assert response.details.generated_tokens == 1
assert response.details.seed is None
assert len(response.details.prefill) == 2
assert response.details.prefill[0] == InputToken(id=1, text="<s>", logprob=None)
assert len(response.details.tokens) == 1
assert response.details.tokens[0].id == 29918
assert response.details.tokens[0].text == "_"
assert not response.details.tokens[0].special
def test_generate_best_of(llama_7b_url, hf_headers):
client = Client(llama_7b_url, hf_headers)
response = client.generate(
"test", max_new_tokens=1, best_of=2, do_sample=True, decoder_input_details=True
)
assert response.details.seed is not None
assert response.details.best_of_sequences is not None
assert len(response.details.best_of_sequences) == 1
assert response.details.best_of_sequences[0].seed is not None
def test_generate_validation_error(llama_7b_url, hf_headers):
client = Client(llama_7b_url, hf_headers)
with pytest.raises(ValidationError):
client.generate("test", max_new_tokens=10_000)
def test_generate_stream(llama_7b_url, hf_headers):
client = Client(llama_7b_url, hf_headers)
responses = [
response for response in client.generate_stream("test", max_new_tokens=1)
]
assert len(responses) == 1
response = responses[0]
assert response.generated_text == "_"
assert response.details.finish_reason == FinishReason.Length
assert response.details.generated_tokens == 1
assert response.details.seed is None
def test_generate_stream_validation_error(llama_7b_url, hf_headers):
client = Client(llama_7b_url, hf_headers)
with pytest.raises(ValidationError):
list(client.generate_stream("test", max_new_tokens=10_000))
@pytest.mark.asyncio
async def test_generate_async(llama_7b_url, hf_headers):
client = AsyncClient(llama_7b_url, hf_headers)
response = await client.generate(
"test", max_new_tokens=1, decoder_input_details=True
)
assert response.generated_text == "_"
assert response.details.finish_reason == FinishReason.Length
assert response.details.generated_tokens == 1
assert response.details.seed is None
assert len(response.details.prefill) == 2
assert response.details.prefill[0] == InputToken(id=1, text="<s>", logprob=None)
assert response.details.prefill[1] == InputToken(
id=1243, text="test", logprob=-10.9375
)
assert len(response.details.tokens) == 1
assert response.details.tokens[0].id == 29918
assert response.details.tokens[0].text == "_"
assert not response.details.tokens[0].special
@pytest.mark.asyncio
async def test_generate_async_best_of(llama_7b_url, hf_headers):
client = AsyncClient(llama_7b_url, hf_headers)
response = await client.generate(
"test", max_new_tokens=1, best_of=2, do_sample=True, decoder_input_details=True
)
assert response.details.seed is not None
assert response.details.best_of_sequences is not None
assert len(response.details.best_of_sequences) == 1
assert response.details.best_of_sequences[0].seed is not None
@pytest.mark.asyncio
async def test_generate_async_validation_error(llama_7b_url, hf_headers):
client = AsyncClient(llama_7b_url, hf_headers)
with pytest.raises(ValidationError):
await client.generate("test", max_new_tokens=10_000)
@pytest.mark.asyncio
async def test_generate_stream_async(llama_7b_url, hf_headers):
client = AsyncClient(llama_7b_url, hf_headers)
responses = [
response async for response in client.generate_stream("test", max_new_tokens=1)
]
assert len(responses) == 1
response = responses[0]
assert response.generated_text == "_"
assert response.details.finish_reason == FinishReason.Length
assert response.details.generated_tokens == 1
assert response.details.seed is None
@pytest.mark.asyncio
async def test_generate_stream_async_validation_error(llama_7b_url, hf_headers):
client = AsyncClient(llama_7b_url, hf_headers)
with pytest.raises(ValidationError):
async for _ in client.generate_stream("test", max_new_tokens=10_000):
pass
from text_generation.errors import (
parse_error,
GenerationError,
IncompleteGenerationError,
OverloadedError,
ValidationError,
BadRequestError,
ShardNotReadyError,
ShardTimeoutError,
NotFoundError,
RateLimitExceededError,
UnknownError,
)
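# parse_error maps an HTTP status code plus error payload onto the matching
# typed exception class.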
def test_generation_error():
payload = {"error_type": "generation", "error": "test"}
assert isinstance(parse_error(400, payload), GenerationError)
def test_incomplete_generation_error():
payload = {"error_type": "incomplete_generation", "error": "test"}
assert isinstance(parse_error(400, payload), IncompleteGenerationError)
def test_overloaded_error():
payload = {"error_type": "overloaded", "error": "test"}
assert isinstance(parse_error(400, payload), OverloadedError)
def test_validation_error():
payload = {"error_type": "validation", "error": "test"}
assert isinstance(parse_error(400, payload), ValidationError)
def test_bad_request_error():
payload = {"error": "test"}
assert isinstance(parse_error(400, payload), BadRequestError)
def test_shard_not_ready_error():
payload = {"error": "test"}
assert isinstance(parse_error(403, payload), ShardNotReadyError)
assert isinstance(parse_error(424, payload), ShardNotReadyError)
def test_shard_timeout_error():
payload = {"error": "test"}
assert isinstance(parse_error(504, payload), ShardTimeoutError)
def test_not_found_error():
payload = {"error": "test"}
assert isinstance(parse_error(404, payload), NotFoundError)
def test_rate_limit_exceeded_error():
payload = {"error": "test"}
assert isinstance(parse_error(429, payload), RateLimitExceededError)
def test_unknown_error():
payload = {"error": "test"}
assert isinstance(parse_error(500, payload), UnknownError)
import pytest
from text_generation.types import Parameters, Request
from text_generation.errors import ValidationError
def test_parameters_validation():
# Test best_of
Parameters(best_of=1)
with pytest.raises(ValidationError):
Parameters(best_of=0)
with pytest.raises(ValidationError):
Parameters(best_of=-1)
Parameters(best_of=2, do_sample=True)
with pytest.raises(ValidationError):
Parameters(best_of=2)
with pytest.raises(ValidationError):
Parameters(best_of=2, seed=1)
# Test repetition_penalty
Parameters(repetition_penalty=1)
with pytest.raises(ValidationError):
Parameters(repetition_penalty=0)
with pytest.raises(ValidationError):
Parameters(repetition_penalty=-1)
# Test seed
Parameters(seed=1)
with pytest.raises(ValidationError):
Parameters(seed=-1)
# Test temperature
Parameters(temperature=1)
with pytest.raises(ValidationError):
Parameters(temperature=0)
with pytest.raises(ValidationError):
Parameters(temperature=-1)
# Test top_k
Parameters(top_k=1)
with pytest.raises(ValidationError):
Parameters(top_k=0)
with pytest.raises(ValidationError):
Parameters(top_k=-1)
# Test top_p
Parameters(top_p=0.5)
with pytest.raises(ValidationError):
Parameters(top_p=0)
with pytest.raises(ValidationError):
Parameters(top_p=-1)
with pytest.raises(ValidationError):
Parameters(top_p=1)
# Test truncate
Parameters(truncate=1)
with pytest.raises(ValidationError):
Parameters(truncate=0)
with pytest.raises(ValidationError):
Parameters(truncate=-1)
# Test typical_p
Parameters(typical_p=0.5)
with pytest.raises(ValidationError):
Parameters(typical_p=0)
with pytest.raises(ValidationError):
Parameters(typical_p=-1)
with pytest.raises(ValidationError):
Parameters(typical_p=1)
def test_request_validation():
Request(inputs="test")
with pytest.raises(ValidationError):
Request(inputs="")
Request(inputs="test", stream=True)
Request(inputs="test", parameters=Parameters(best_of=2, do_sample=True))
with pytest.raises(ValidationError):
Request(
inputs="test", parameters=Parameters(best_of=2, do_sample=True), stream=True
)
......@@ -19,5 +19,15 @@ DEPRECATION_WARNING = (
"Please use the `InferenceClient` from the `huggingface_hub` package instead."
)
from text_generation.client import Client, AsyncClient
from text_generation.inference_api import InferenceAPIClient, InferenceAPIAsyncClient
from text_generation.client import Client, AsyncClient # noqa E402
from text_generation.inference_api import ( # noqa E402
InferenceAPIClient,
InferenceAPIAsyncClient,
)
__all__ = [
"Client",
"AsyncClient",
"InferenceAPIClient",
"InferenceAPIAsyncClient",
]