Unverified commit a6a0c97e, authored by OlivierDehaene and committed by GitHub

feat: prefill chunking (#2600)



* wip

* rollback

* refactor to use prefix/postfix naming + fix all_input_ids_tensor

* maybe patching vlms?

* fix filter and concat

* wip, no filter, no concat

* current

* add prepare_for_prefill

* working

* load tested

* re-create slots

* re-create slots

* fix slot_filtering_indices

* feedback loop

* remove log

* fix benchmarker

* fix vlm and seq2seq

* rename to cache and input lengths

* fix prefill logprobs

* fix launcher

* fix logprobs?

* idk at this point

* max input length

* omfg

* remove debugging lines

* fix tests

* fix mllama

* fix cargo tests

* remove support chunking for paged

* Fixing non blocked attentions

* Fixing dtype + AMD, Ipex targets.

* lint fix.

* rename

* Fix prefix_caching variable, remove defaults in server (confusing a lot
of the times).

* Add simple resolution when user specifies ATTENTION=paged.

* Put back non default simple tests.

* Fix env name

---------
Co-authored-by: Nicolas Patry <patry.nicolas@protonmail.com>
parent 704a58c8
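
Note (not part of the commit): the diff below threads a prefill token budget through the router, queue, and gRPC clients so that a long prompt can be prefilled over several passes instead of a single one, with `cache_len` tracking what is already in the KV cache and `chunk_len` bounding each pass. As a rough illustration of the idea only, under made-up names that do not exist in the repository, splitting a prompt into budget-sized chunks looks like this:

// Minimal sketch, assuming a fixed per-pass prefill budget; not TGI code.
fn plan_prefill_chunks(input_len: u32, prefill_token_budget: u32) -> Vec<(u32, u32)> {
    // Returns (cache_len, chunk_len) pairs, one per prefill pass.
    let mut chunks = Vec::new();
    if prefill_token_budget == 0 {
        // Mirrors the queue returning no batch when the budget is zero.
        return chunks;
    }
    let mut cache_len = 0u32;
    while cache_len < input_len {
        // chunk_len plays the role of the new `chunk_len: Option<u32>` request field,
        // cache_len the role of the renamed `cache_len` field.
        let chunk_len = (input_len - cache_len).min(prefill_token_budget);
        chunks.push((cache_len, chunk_len));
        cache_len += chunk_len;
    }
    chunks
}

fn main() {
    // A 10-token prompt with a 4-token prefill budget is prefilled in three passes:
    // [(0, 4), (4, 4), (8, 2)]
    println!("{:?}", plan_prefill_chunks(10, 4));
}
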
@@ -327,7 +327,8 @@ ENV ROCM_USE_CUSTOM_PAGED_ATTN=1
 ENV PYTORCH_TUNABLEOP_TUNING_AFTER_WARMUP=0
 ENV VLLM_MOE_PADDING=0
 ENV ATTENTION=paged
-ENV USE_PREFIX_CACHING=0
+ENV PREFIX_CACHING=0
+ENV PREFILL_CHUNKING=0
 ENV ROCM_USE_SKINNY_GEMM=1
 COPY ./tgi-entrypoint.sh /tgi-entrypoint.sh
...
@@ -218,7 +218,8 @@ COPY --from=builder /usr/src/target/release-opt/text-generation-launcher /usr/lo
 FROM ${PLATFORM} AS final
 ENV ATTENTION=paged
-ENV USE_PREFIX_CACHING=0
+ENV PREFIX_CACHING=0
+ENV PREFILL_CHUNKING=0
 ENV CUDA_GRAPHS=0
 ENTRYPOINT ["text-generation-launcher"]
 CMD ["--json-output"]
@@ -158,7 +158,8 @@ impl Client {
             // Blocks and slots will be set on the server side if we use paged attention
             blocks: vec![],
             slots: vec![],
-            prefix_len: 0,
+            cache_len: 0,
+            chunk_len: None,
             // Set sampling parameters to also take these ops into account in the max memory
             parameters: Some(NextTokenChooserParameters {
                 temperature: 0.9,
@@ -217,8 +218,13 @@ impl Client {
     pub async fn prefill(
         &mut self,
         batch: Batch,
+        cached_batch: Option<CachedBatch>,
     ) -> Result<(Vec<Generation>, Option<CachedBatch>, PrefillTimings)> {
-        let request = tonic::Request::new(PrefillRequest { batch: Some(batch) }).inject_context();
+        let request = tonic::Request::new(PrefillRequest {
+            batch: Some(batch),
+            cached_batch,
+        })
+        .inject_context();
         let response = self.stub.prefill(request).await?.into_inner();
         Ok((
             response.generations,
...
@@ -134,11 +134,12 @@ impl ShardedClient {
     pub async fn prefill(
         &mut self,
         batch: Batch,
+        cached_batch: Option<CachedBatch>,
     ) -> Result<(Vec<Generation>, Option<CachedBatch>, PrefillTimings)> {
         let futures: Vec<_> = self
             .clients
             .iter_mut()
-            .map(|client| Box::pin(client.prefill(batch.clone())))
+            .map(|client| Box::pin(client.prefill(batch.clone(), cached_batch.clone())))
             .collect();
         #[allow(clippy::type_complexity)]
         let results: Result<Vec<(Vec<Generation>, Option<CachedBatch>, PrefillTimings)>> =
@@ -245,7 +246,8 @@ impl Health for ShardedClient {
             // Block 0 is reserved for health checks
             blocks: vec![0],
             slots: (0..16).collect(),
-            prefix_len: 0,
+            cache_len: 0,
+            chunk_len: None,
             adapter_id: None,
         };
         let batch = Batch {
@@ -255,7 +257,7 @@ impl Health for ShardedClient {
             max_tokens: 2,
             max_blocks: 1,
         };
-        self.clone().prefill(batch).await?;
+        self.clone().prefill(batch, None).await?;
         Ok(())
     }
 }
@@ -6,7 +6,7 @@ use nohash_hasher::IntMap;
 use std::sync::Arc;
 use text_generation_router::infer::{Backend, GeneratedText, InferError, InferStreamResponse};
 use text_generation_router::validation::ValidGenerateRequest;
-use text_generation_router::{Attention, FinishReason, PrefillToken, Token};
+use text_generation_router::{FinishReason, PrefillToken, Token};
 use tokio::sync::mpsc::error::SendError;
 use tokio::sync::{mpsc, Notify};
 use tokio::time::Instant;
@@ -36,18 +36,14 @@ impl BackendV2 {
         speculate: u32,
     ) -> Self {
         // Infer shared state
-        let attention = if let Ok(attention) = std::env::var("ATTENTION") {
-            attention
-                .parse()
-                .unwrap_or_else(|_| panic!("Invalid attention was specified :`{attention}`"))
-        } else {
-            Attention::Paged
-        };
-        let block_size = if attention == Attention::FlashDecoding {
-            256
-        } else {
-            16
-        };
+        let attention = std::env::var("ATTENTION").unwrap_or("paged".to_string());
+        let block_size = match attention.as_str() {
+            "flashinfer" => 1,
+            "flashdecoding" => 256,
+            "paged" => 16,
+            _ => unreachable!(),
+        };
         let queue = Queue::new(requires_padding, block_size, window_size, speculate);
         let batching_task_notifier = Arc::new(Notify::new());
...
-use crate::client::{Batch, CachedBatch, ClientError, Generation, Health, ShardedClient};
 /// Batching and inference logic
+use crate::client::{
+    Batch, CachedBatch, ClientError, Generation, Health, InfoResponse, ShardedClient,
+};
 use crate::queue::{Entry, Queue};
 use async_trait::async_trait;
 use nohash_hasher::IntMap;
 use std::sync::Arc;
 use text_generation_router::infer::{Backend, GeneratedText, InferError, InferStreamResponse};
 use text_generation_router::validation::ValidGenerateRequest;
-use text_generation_router::{Attention, FinishReason, PrefillToken, Token};
+use text_generation_router::{FinishReason, PrefillToken, Token};
 use tokio::sync::mpsc::error::SendError;
 use tokio::sync::{mpsc, Notify};
 use tokio::time::Instant;
@@ -31,27 +33,22 @@ impl BackendV3 {
         max_batch_total_tokens: u32,
         max_waiting_tokens: usize,
         max_batch_size: Option<usize>,
-        requires_padding: bool,
-        window_size: Option<u32>,
-        speculate: u32,
+        shard_info: InfoResponse,
     ) -> Self {
-        let prefix_caching =
-            std::env::var("USE_PREFIX_CACHING").expect("Expect prefix caching env var");
-        let prefix_caching = matches!(prefix_caching.as_str(), "true" | "1");
-        let attention: String = std::env::var("ATTENTION").expect("attention env var");
-        let attention: Attention = attention
-            .parse()
-            .unwrap_or_else(|_| panic!("Invalid attention was specified :`{attention}`"));
-        let block_size = attention.block_size();
+        if shard_info.support_chunking {
+            tracing::warn!("Model supports prefill chunking. `waiting_served_ratio` and `max_waiting_tokens` will be ignored.");
+        }
+
+        let block_size = shard_info.block_size;

         let queue = Queue::new(
-            requires_padding,
+            shard_info.requires_padding,
             block_size,
-            prefix_caching,
-            window_size,
-            speculate,
+            shard_info.use_prefix_caching,
+            shard_info.window_size,
+            shard_info.speculate,
             max_batch_total_tokens,
+            shard_info.support_chunking,
         );
         let batching_task_notifier = Arc::new(Notify::new());
@@ -63,6 +60,7 @@ impl BackendV3 {
             max_batch_total_tokens,
             max_waiting_tokens,
             max_batch_size,
+            shard_info.support_chunking,
            queue.clone(),
            batching_task_notifier.clone(),
         ));
@@ -127,6 +125,7 @@ pub(crate) async fn batching_task(
     max_batch_total_tokens: u32,
     max_waiting_tokens: usize,
     max_batch_size: Option<usize>,
+    support_chunking: bool,
     queue: Queue,
     notifier: Arc<Notify>,
 ) {
@@ -147,7 +146,7 @@ pub(crate) async fn batching_task(
             )
             .await
         {
-            let mut cached_batch = prefill(&mut client, batch, &mut entries)
+            let mut cached_batch = prefill(&mut client, batch, None, &mut entries)
                 .instrument(span)
                 .await;
             let mut waiting_tokens = 1;
@@ -158,28 +157,44 @@ pub(crate) async fn batching_task(
                 // Get current batch info
                 let batch_size = batch.size;
                 let batch_max_tokens = batch.max_tokens;
+                let current_tokens = batch.current_tokens;
                 let mut batches = vec![batch];
                 metrics::gauge!("tgi_batch_current_size").set(batch_size as f64);
                 metrics::gauge!("tgi_batch_current_max_tokens").set(batch_max_tokens as f64);
-                let min_size = if waiting_tokens >= max_waiting_tokens {
-                    // If we didn't onboard any new requests since >= max_waiting_tokens, we try
-                    // to add a new batch even though its size might be small
-                    None
-                } else {
-                    // Minimum batch size
-                    // TODO: temporarily disable to avoid incorrect deallocation +
-                    // reallocation when using prefix caching.
-                    Some((batch_size as f32 * waiting_served_ratio).floor() as usize)
-                };
-
-                let token_budget = max_batch_total_tokens.saturating_sub(batch_max_tokens);
-                let max_size =
-                    max_batch_size.map(|max_size| max_size.saturating_sub(batch_size as usize));
+                let token_budget = max_batch_total_tokens.saturating_sub(batch_max_tokens);
+
+                let (min_size, max_size, prefill_token_budget) = if support_chunking {
+                    // Since the next batch will be concatenated with the current batch,
+                    // the current batch tokens must be subtracted to the prefill budget
+                    let prefill_token_budget =
+                        max_batch_prefill_tokens.saturating_sub(current_tokens);
+                    // We can ignore min_size and max_size
+                    // Models than rely on max_size cannot support chunking
+                    // Regarding min_size, chunking allow us to consistently run at the compute
+                    // bound, making min_size useless.
+                    (None, None, prefill_token_budget)
+                } else {
+                    let min_size = if waiting_tokens >= max_waiting_tokens {
+                        // If we didn't onboard any new requests since >= max_waiting_tokens, we try
+                        // to add a new batch even though its size might be small
+                        None
+                    } else {
+                        // Minimum batch size
+                        // TODO: temporarily disable to avoid incorrect deallocation +
+                        // reallocation when using prefix caching.
+                        Some((batch_size as f32 * waiting_served_ratio).floor() as usize)
+                    };
+
+                    let max_size =
+                        max_batch_size.map(|max_size| max_size.saturating_sub(batch_size as usize));
+
+                    (min_size, max_size, max_batch_prefill_tokens)
+                };
                 // Try to get a new batch
-                if let Some((mut new_entries, new_batch, span)) = queue
-                    .next_batch(min_size, max_size, max_batch_prefill_tokens, token_budget)
+                if let Some((new_entries, new_batch, span)) = queue
+                    .next_batch(min_size, max_size, prefill_token_budget, token_budget)
                     .await
                 {
                     // Tracking metrics
@@ -187,31 +202,45 @@ pub(crate) async fn batching_task(
                         metrics::counter!("tgi_batch_concat", "reason" => "backpressure")
                             .increment(1);
                     } else {
-                        metrics::counter!("tgi_batch_concat", "reason" => "wait_exceeded")
-                            .increment(1);
+                        let counter = if support_chunking {
+                            metrics::counter!("tgi_batch_concat", "reason" => "chunking")
+                        } else {
+                            metrics::counter!("tgi_batch_concat", "reason" => "wait_exceeded")
+                        };
+                        counter.increment(1);
                     }
-                    entries.iter_mut().for_each(|(_, entry)| {
-                        // Create a new span to add the info that this entry is waiting
-                        // because a new batch is being computed
-                        let entry_waiting_span = info_span!(parent: &entry.span, "waiting");
-                        // Add relationships
-                        span.follows_from(&entry_waiting_span);
-                        entry_waiting_span.follows_from(&span);
-                        // Update entry
-                        entry.temp_span = Some(entry_waiting_span);
-                    });
+                    let cached_batch = if support_chunking {
+                        // Concat current batch to the new one
+                        batches.pop()
+                    } else {
+                        // Request are waiting only if we don't support chunking
+                        entries.iter_mut().for_each(|(_, entry)| {
+                            // Create a new span to add the info that this entry is waiting
+                            // because a new batch is being computed
+                            let entry_waiting_span = info_span!(parent: &entry.span, "waiting");
+                            // Add relationships
+                            span.follows_from(&entry_waiting_span);
+                            entry_waiting_span.follows_from(&span);
+                            // Update entry
+                            entry.temp_span = Some(entry_waiting_span);
+                        });
+                        None
+                    };
+                    entries.extend(new_entries);
                     // Generate one token for this new batch to have the attention past in cache
-                    let new_cached_batch = prefill(&mut client, new_batch, &mut new_entries)
-                        .instrument(span)
-                        .await;
+                    let new_cached_batch =
+                        prefill(&mut client, new_batch, cached_batch, &mut entries)
+                            .instrument(span)
+                            .await;
                     // Reset waiting counter
                     waiting_tokens = 1;
                     // Extend current batch with the new batch
                     if let Some(new_cached_batch) = new_cached_batch {
-                        entries.extend(new_entries);
                         batches.push(new_cached_batch);
+                    } else if support_chunking {
+                        // New cached batch is empty, no work left
+                        break;
                     }
                 }
@@ -244,13 +273,14 @@ pub(crate) async fn batching_task(
 async fn prefill(
     client: &mut ShardedClient,
     batch: Batch,
+    cached_batch: Option<CachedBatch>,
     entries: &mut IntMap<u64, Entry>,
 ) -> Option<CachedBatch> {
     let start_time = Instant::now();
     let batch_id = batch.id;
     metrics::counter!("tgi_batch_inference_count", "method" => "prefill").increment(1);
-    match client.prefill(batch).await {
+    match client.prefill(batch, cached_batch).await {
         Ok((generations, next_batch, timings)) => {
             let start_filtering_time = Instant::now();
             // Send generated tokens and filter stopped entries
@@ -259,6 +289,10 @@ async fn prefill(
             // Filter next batch and remove requests that were stopped
             let next_batch = filter_batch(client, next_batch, entries).await;

+            if let Some(concat_duration) = timings.concat {
+                metrics::histogram!("tgi_batch_concat_duration", "method" => "decode")
+                    .record(concat_duration.as_secs_f64());
+            }
             metrics::histogram!("tgi_batch_forward_duration", "method" => "prefill")
                 .record(timings.forward.as_secs_f64());
             metrics::histogram!("tgi_batch_decode_duration", "method" => "prefill")
...
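Aside (illustrative only, not part of the diff): the comments in the hunk above explain that, when the shard supports chunking, the freshly prefilled batch is concatenated onto the running batch, so tokens already in flight come out of the prefill budget and the min/max batch sizes are ignored. Restated in isolation, with free-standing arguments standing in for the batching task's local state:

// Sketch of the budget selection above; argument names are stand-ins, not real APIs.
fn prefill_budget(
    support_chunking: bool,
    max_batch_prefill_tokens: u32,
    current_tokens: u32,
    waiting_tokens: u32,
    max_waiting_tokens: u32,
    batch_size: u32,
    waiting_served_ratio: f32,
    max_batch_size: Option<usize>,
) -> (Option<usize>, Option<usize>, u32) {
    if support_chunking {
        // The next batch is concatenated with the current one, so tokens already
        // in flight are subtracted from the prefill budget; sizes are ignored.
        (
            None,
            None,
            max_batch_prefill_tokens.saturating_sub(current_tokens),
        )
    } else {
        let min_size = if waiting_tokens >= max_waiting_tokens {
            None
        } else {
            Some((batch_size as f32 * waiting_served_ratio).floor() as usize)
        };
        let max_size = max_batch_size.map(|m| m.saturating_sub(batch_size as usize));
        (min_size, max_size, max_batch_prefill_tokens)
    }
}

For example, with max_batch_prefill_tokens = 4096 and current_tokens = 1024 already in the running batch, a chunking-capable shard is left with a prefill budget of 3072 for the next batch.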
@@ -158,7 +158,8 @@ impl Client {
             // Blocks and slots will be set on the server side if we use paged attention
             blocks: vec![],
             slots: vec![],
-            prefix_len: 0,
+            cache_len: 0,
+            chunk_len: None,
             // Set sampling parameters to also take these ops into account in the max memory
             parameters: Some(NextTokenChooserParameters {
                 temperature: 0.9,
@@ -217,13 +218,23 @@ impl Client {
     pub async fn prefill(
         &mut self,
         batch: Batch,
+        cached_batch: Option<CachedBatch>,
     ) -> Result<(Vec<Generation>, Option<CachedBatch>, PrefillTimings)> {
-        let request = tonic::Request::new(PrefillRequest { batch: Some(batch) }).inject_context();
+        let request = tonic::Request::new(PrefillRequest {
+            batch: Some(batch),
+            cached_batch,
+        })
+        .inject_context();
         let response = self.stub.prefill(request).await?.into_inner();
         Ok((
             response.generations,
             response.batch,
-            PrefillTimings::new(response.forward_ns, response.decode_ns, response.total_ns),
+            PrefillTimings::new(
+                response.concat_ns,
+                response.forward_ns,
+                response.decode_ns,
+                response.total_ns,
+            ),
         ))
     }
@@ -252,14 +263,16 @@ impl Client {
 }

 pub struct PrefillTimings {
+    pub concat: Option<Duration>,
     pub forward: Duration,
     pub decode: Duration,
     pub total: Duration,
 }

 impl PrefillTimings {
-    fn new(forward_ns: u64, decode_ns: u64, total_ns: u64) -> Self {
+    fn new(concat_ns: Option<u64>, forward_ns: u64, decode_ns: u64, total_ns: u64) -> Self {
         Self {
+            concat: concat_ns.map(Duration::from_nanos),
             forward: Duration::from_nanos(forward_ns),
             decode: Duration::from_nanos(decode_ns),
             total: Duration::from_nanos(total_ns),
...
@@ -29,15 +29,6 @@ pub trait Health {
     async fn model_health(&self) -> Result<()>;
 }

-#[derive(Debug)]
-pub struct ShardInfo {
-    pub requires_padding: bool,
-    pub dtype: String,
-    pub device_type: String,
-    pub window_size: Option<u32>,
-    pub speculate: u32,
-}
-
 #[derive(Error, Debug, Clone)]
 pub enum ClientError {
     #[error("Could not connect to Text Generation server: {0}")]
...
-use crate::client::{ClientError, Result};
+use crate::client::Health;
 /// Multi shard Client
-use crate::client::{Health, ShardInfo};
+use crate::client::{ClientError, Result};

 use crate::client::grpc_client::{DecodeTimings, PrefillTimings};
 use crate::client::{
@@ -49,13 +49,13 @@ impl ShardedClient {
     /// Get the model info
     #[instrument(skip(self))]
-    pub async fn info(&mut self) -> Result<ShardInfo> {
+    pub async fn info(&mut self) -> Result<InfoResponse> {
         let futures: Vec<_> = self
             .clients
             .iter_mut()
             .map(|client| client.info())
             .collect();
-        join_all(futures).await.pop().unwrap().map(ShardInfo::from)
+        join_all(futures).await.pop().unwrap()
     }

     /// GRPC health check
@@ -135,11 +135,12 @@ impl ShardedClient {
     pub async fn prefill(
         &mut self,
         batch: Batch,
+        cached_batch: Option<CachedBatch>,
     ) -> Result<(Vec<Generation>, Option<CachedBatch>, PrefillTimings)> {
         let futures: Vec<_> = self
             .clients
             .iter_mut()
-            .map(|client| Box::pin(client.prefill(batch.clone())))
+            .map(|client| Box::pin(client.prefill(batch.clone(), cached_batch.clone())))
             .collect();
         #[allow(clippy::type_complexity)]
         let results: Result<Vec<(Vec<Generation>, Option<CachedBatch>, PrefillTimings)>> =
@@ -194,18 +195,6 @@ impl ShardedClient {
     }
 }

-impl From<InfoResponse> for ShardInfo {
-    fn from(value: InfoResponse) -> Self {
-        Self {
-            requires_padding: value.requires_padding,
-            dtype: value.dtype,
-            device_type: value.device_type,
-            window_size: value.window_size,
-            speculate: value.speculate,
-        }
-    }
-}
-
 #[async_trait]
 impl Health for ShardedClient {
     async fn device_health(&self) -> Result<()> {
@@ -246,8 +235,9 @@ impl Health for ShardedClient {
             // Block 0 is reserved for health checks
             blocks: vec![0],
             slots: (0..16).collect(),
-            prefix_len: 0,
+            cache_len: 0,
             adapter_id: None,
+            chunk_len: None,
         };
         let batch = Batch {
             id: u64::MAX,
@@ -256,7 +246,7 @@ impl Health for ShardedClient {
             max_tokens: 2,
             max_blocks: 1,
         };
-        self.clone().prefill(batch).await?;
+        self.clone().prefill(batch, None).await?;
         Ok(())
     }
 }
@@ -29,6 +29,14 @@ pub struct BackendInfo {
     pub max_waiting_tokens: usize,
     #[schema(nullable = true, example = "null")]
     pub max_batch_size: Option<usize>,
+    #[schema(example = "false")]
+    pub support_chunking: bool,
+    #[schema(example = "false")]
+    pub prefix_caching: bool,
+    #[schema(example = "flashinfer")]
+    pub attention_impl: String,
+    #[schema(example = "1")]
+    pub block_size: u32,
 }

 #[allow(clippy::too_many_arguments)]
@@ -110,6 +118,10 @@ pub async fn connect_backend(
         model_device_type: shard_info.device_type.clone(),
         model_dtype: shard_info.dtype.clone(),
         speculate: shard_info.speculate as usize,
+        support_chunking: shard_info.support_chunking,
+        prefix_caching: shard_info.use_prefix_caching,
+        attention_impl: shard_info.attention_impl.clone(),
+        block_size: shard_info.block_size,
     };

     let backend = BackendV3::new(
@@ -119,9 +131,7 @@ pub async fn connect_backend(
         max_batch_total_tokens,
         max_waiting_tokens,
         max_batch_size,
-        shard_info.requires_padding,
-        shard_info.window_size,
-        shard_info.speculate,
+        shard_info,
     );
     tracing::info!("Using backend V3");
...
@@ -131,25 +131,12 @@ async fn main() -> Result<(), RouterError> {
             "`max_input_tokens` must be < `max_total_tokens`".to_string(),
         ));
     }
-    if max_input_tokens as u32 > max_batch_prefill_tokens {
-        return Err(RouterError::ArgumentValidation(format!("`max_batch_prefill_tokens` must be >= `max_input_tokens`. Given: {max_batch_prefill_tokens} and {max_input_tokens}")));
-    }
-
     if validation_workers == 0 {
         return Err(RouterError::ArgumentValidation(
             "`validation_workers` must be > 0".to_string(),
         ));
     }
-    if let Some(ref max_batch_total_tokens) = max_batch_total_tokens {
-        if max_batch_prefill_tokens > *max_batch_total_tokens {
-            return Err(RouterError::ArgumentValidation(format!("`max_batch_prefill_tokens` must be <= `max_batch_total_tokens`. Given: {max_batch_prefill_tokens} and {max_batch_total_tokens}")));
-        }
-        if max_total_tokens as u32 > *max_batch_total_tokens {
-            return Err(RouterError::ArgumentValidation(format!("`max_total_tokens` must be <= `max_batch_total_tokens`. Given: {max_total_tokens} and {max_batch_total_tokens}")));
-        }
-    }
-
     if let Some(max_batch_size) = max_batch_size {
         if max_batch_size == 0 {
             return Err(RouterError::ArgumentValidation(
@@ -158,7 +145,7 @@ async fn main() -> Result<(), RouterError> {
         }
     }

-    let (backend, _backend_info) = connect_backend(
+    let (backend, backend_info) = connect_backend(
         max_input_tokens,
         max_total_tokens,
         master_shard_uds_path,
@@ -170,6 +157,19 @@ async fn main() -> Result<(), RouterError> {
     )
     .await?;

+    // Validate remaining args now that the backend is known
+    let support_chunking = backend_info.support_chunking;
+    let max_batch_total_tokens = backend_info.max_batch_total_tokens;
+    if max_input_tokens as u32 > max_batch_prefill_tokens && !support_chunking {
+        return Err(RouterError::ArgumentValidation(format!("`max_batch_prefill_tokens` must be >= `max_input_tokens`. Given: {max_batch_prefill_tokens} and {max_input_tokens}")));
+    }
+    if max_batch_prefill_tokens > max_batch_total_tokens {
+        return Err(RouterError::ArgumentValidation(format!("`max_batch_prefill_tokens` must be <= `max_batch_total_tokens`. Given: {max_batch_prefill_tokens} and {max_batch_total_tokens}")));
+    }
+    if max_total_tokens as u32 > max_batch_total_tokens {
+        return Err(RouterError::ArgumentValidation(format!("`max_total_tokens` must be <= `max_batch_total_tokens`. Given: {max_total_tokens} and {max_batch_total_tokens}")));
+    }
+
     // Run server
     server::run(
         backend,
...
@@ -4,7 +4,7 @@ use crate::client::{
     Batch, GrammarType, NextTokenChooserParameters, Request, StoppingCriteriaParameters,
 };
 use nohash_hasher::{BuildNoHashHasher, IntMap};
-use std::cmp::{max, min};
+use std::cmp::max;
 use std::collections::VecDeque;
 use text_generation_router::infer::InferError;
 use text_generation_router::infer::InferStreamResponse;
@@ -50,6 +50,7 @@ impl Queue {
         window_size: Option<u32>,
         speculate: u32,
         max_batch_total_tokens: u32,
+        support_chunking: bool,
     ) -> Self {
         // Create channel
         let (queue_sender, queue_receiver) = mpsc::unbounded_channel();
@@ -62,6 +63,7 @@ impl Queue {
             window_size,
             speculate,
             max_batch_total_tokens,
+            support_chunking,
             queue_receiver,
         ));
@@ -87,6 +89,10 @@ impl Queue {
         prefill_token_budget: u32,
         token_budget: u32,
     ) -> Option<NextBatch> {
+        if prefill_token_budget == 0 || token_budget == 0 {
+            return None;
+        };
+
         // Create response channel
         let (response_sender, response_receiver) = oneshot::channel();
         // Send next batch command to the background task managing the state
@@ -108,6 +114,7 @@ impl Queue {
 }

 // Background task responsible of the queue state
+#[allow(clippy::too_many_arguments)]
 async fn queue_task(
     requires_padding: bool,
     block_size: u32,
@@ -115,6 +122,7 @@ async fn queue_task(
     window_size: Option<u32>,
     speculate: u32,
     max_batch_total_tokens: u32,
+    support_chunking: bool,
     mut receiver: mpsc::UnboundedReceiver<QueueCommand>,
 ) {
     let mut state = State::new(
@@ -124,6 +132,7 @@ async fn queue_task(
         window_size,
         speculate,
         max_batch_total_tokens,
+        support_chunking,
     );

     while let Some(cmd) = receiver.recv().await {
@@ -166,12 +175,14 @@ struct State {
     /// Paged Attention block size
     block_size: u32,

-    /// Sliding window
-    window_size: Option<u32>,
-
     /// Speculation amount
     speculate: u32,

+    /// Whether the model allow the prefill chunking
+    /// If it does, the last request in the batch will be split to exactly match the prefill
+    /// token budget
+    support_chunking: bool,
+
     /// Paged Attention Block Allocation
     block_allocator: Option<BlockAllocator>,
 }
@@ -184,6 +195,7 @@ impl State {
         window_size: Option<u32>,
         speculate: u32,
         max_batch_total_tokens: u32,
+        support_chunking: bool,
     ) -> Self {
         let block_allocator = (!requires_padding).then(|| {
             BlockAllocator::new(
@@ -199,8 +211,8 @@ impl State {
             next_id: 0,
             next_batch_id: 0,
             block_size,
-            window_size,
             speculate,
+            support_chunking,
             block_allocator,
         }
     }
@@ -287,32 +299,7 @@ impl State {
                    }
                    None
                }
-                Some(_block_allocator) => {
-                    prefill_tokens += entry.request.input_length;
-                    let max_new_tokens = match self.window_size {
-                        None => entry.request.stopping_parameters.max_new_tokens,
-                        Some(window_size) => min(
-                            window_size.saturating_sub(entry.request.input_length),
-                            entry.request.stopping_parameters.max_new_tokens,
-                        ),
-                    };
-                    decode_tokens += max_new_tokens;
-
-                    if prefill_tokens > prefill_token_budget
-                        || (prefill_tokens + decode_tokens + self.speculate) > token_budget
-                    {
-                        // Entry is over budget
-                        // Add it back to the front
-                        tracing::debug!("Over budget: prefill_tokens={prefill_tokens} > {prefill_token_budget} || {prefill_tokens} + {decode_tokens} + {} > {token_budget}", self.speculate);
-                        self.entries.push_front((id, entry));
-                        break;
-                    }
-
-                    let tokens = entry.request.input_length
-                        + entry.request.stopping_parameters.max_new_tokens
-                        + self.speculate
-                        - 1;
-
+                Some(block_allocator) => {
                    // If users wants the prefill logprobs, we cannot reuse the cache.
                    // So no input_ids for the radix tree.
                    let input_ids = if entry.request.decoder_input_details {
@@ -321,10 +308,73 @@ impl State {
                        entry.request.input_ids.clone()
                    };

-                    Some((tokens, input_ids))
+                    let tokens = entry.request.input_length
+                        + entry.request.stopping_parameters.max_new_tokens
+                        + self.speculate
+                        - 1;
+
+                    tracing::debug!("Allocating {tokens} with {input_ids:?}");
+                    let block_allocation = match block_allocator.allocate(tokens, input_ids).await {
+                        None => {
+                            // Entry is over budget
+                            // Add it back to the front
+                            tracing::debug!("Over budget: not enough free blocks");
+                            self.entries.push_front((id, entry));
+                            break 'entry_loop;
+                        }
+                        Some(mut block_allocation) => {
+                            tracing::debug!("Allocation: {block_allocation:?}");
+                            max_blocks = max(max_blocks, block_allocation.blocks.len() as u32);
+
+                            if block_allocation.prefix_len == entry.request.input_length {
+                                // The whole request was found in the radix trie
+                                // However, for the transformer forward to work, we need to
+                                // have at least one token of postfix.
+                                block_allocation.prefix_len -= 1;
+                            }
+
+                            block_allocation
+                        }
+                    };
+
+                    let postfix_len = entry.request.input_length - block_allocation.prefix_len;
+
+                    if prefill_tokens + postfix_len > prefill_token_budget {
+                        // Entry is over budget
+                        if self.support_chunking {
+                            // We support chunking, just set postfix_len to exactly match prefill_token_budget
+                            let chunk_len = prefill_token_budget.saturating_sub(prefill_tokens);
+                            if chunk_len > 0 {
+                                // Push this entry inside the batch
+                                batch.push((id, entry, Some(block_allocation), Some(chunk_len)));
+                            } else {
+                                // We cannot prefill even one token for this entry
+                                // Add it back to the queue
+                                self.entries.push_front((id, entry));
+                            }
+                            tracing::debug!(
+                                "Matched budget: prefill_tokens={} == {prefill_token_budget}",
+                                prefill_tokens + postfix_len
+                            );
+                            break 'entry_loop;
+                        } else {
+                            // We don't support chunking, this entry needs to go back to the buffer
+                            // Add it back to the front
+                            tracing::debug!(
+                                "Over budget: prefill_tokens={} > {prefill_token_budget}",
+                                prefill_tokens + postfix_len
+                            );
+                            self.entries.push_front((id, entry));
+                            break 'entry_loop;
+                        }
+                    }
+
+                    prefill_tokens += postfix_len;
+
+                    Some(block_allocation)
                }
            };

-            batch.push((id, entry, block_allocation));
+            batch.push((id, entry, block_allocation, None));
            if Some(batch.len()) == max_size {
                break;
            }
@@ -342,7 +392,7 @@ impl State {
            // Batch is too small
            if batch.len() < min_size {
                // Add back entries to the queue in the correct order
-                for (id, entry, _) in batch.into_iter().rev() {
+                for (id, entry, _, _) in batch.into_iter().rev() {
                    self.entries.push_front((id, entry));
                }
                return None;
@@ -353,29 +403,7 @@ impl State {
        let mut batch_entries =
            IntMap::with_capacity_and_hasher(self.entries.len(), BuildNoHashHasher::default());

-        for (id, mut entry, block_allocation) in batch {
-            let block_allocation = if let (Some((tokens, input_ids)), Some(block_allocator)) =
-                (block_allocation, &self.block_allocator)
-            {
-                tracing::debug!("Allocating {tokens} with {input_ids:?}");
-                match block_allocator.allocate(tokens, input_ids).await {
-                    None => {
-                        // Entry is over budget
-                        // Add it back to the front
-                        tracing::debug!("Over budget: not enough free blocks");
-                        self.entries.push_front((id, entry));
-                        continue;
-                    }
-                    Some(block_allocation) => {
-                        tracing::debug!("Allocation: {block_allocation:?}");
-                        max_blocks = max(max_blocks, block_allocation.blocks.len() as u32);
-                        Some(block_allocation)
-                    }
-                }
-            } else {
-                None
-            };
-            tracing::debug!("Accepting entry");
+        for (id, mut entry, block_allocation, chunk_len) in batch {
            // Create a new span to link the batch back to this entry
            let entry_batch_span = info_span!(parent: &entry.span, "infer");
            // Add relationships
@@ -427,8 +455,9 @@ impl State {
                top_n_tokens: entry.request.top_n_tokens,
                blocks,
                slots,
-                prefix_len,
+                cache_len: prefix_len,
                adapter_id: entry.request.adapter_id.clone(),
+                chunk_len,
            });
            // Set batch_time
            entry.batch_time = Some(Instant::now());
@@ -436,12 +465,6 @@ impl State {
            batch_entries.insert(id, entry);
        }

-        // Empty batch
-        if batch_requests.is_empty() {
-            tracing::debug!("Filterered out all entries");
-            return None;
-        }
-
        // Final batch size
        let size = batch_requests.len() as u32;
        next_batch_span.record("batch_size", size);
@@ -531,7 +554,7 @@ mod tests {
            request: ValidGenerateRequest {
                inputs: vec![],
                input_ids: Some(Arc::new(vec![])),
-                input_length: 0,
+                input_length: 1,
                add_special_tokens: true,
                truncate: 0,
                decoder_input_details: false,
@@ -567,7 +590,7 @@ mod tests {
    #[tokio::test]
    async fn test_append() {
-        let mut state = State::new(false, 1, false, None, 0, 16);
+        let mut state = State::new(false, 1, false, None, 0, 16, false);
        let (entry, _guard) = default_entry();

        assert_eq!(state.next_id, 0);
@@ -583,7 +606,7 @@ mod tests {
    #[tokio::test]
    async fn test_next_batch_empty() {
-        let mut state = State::new(false, 1, false, None, 0, 16);
+        let mut state = State::new(false, 1, false, None, 0, 16, false);

        assert!(state.next_batch(None, None, 1, 1).await.is_none());
        assert!(state.next_batch(Some(1), None, 1, 1).await.is_none());
@@ -591,7 +614,7 @@ mod tests {
    #[tokio::test]
    async fn test_next_batch_min_size() {
-        let mut state = State::new(false, 1, false, None, 0, 16);
+        let mut state = State::new(false, 1, false, None, 0, 16, false);
        let (entry1, _guard1) = default_entry();
        let (entry2, _guard2) = default_entry();
        state.append(entry1);
@@ -623,7 +646,7 @@ mod tests {
    #[tokio::test]
    async fn test_next_batch_max_size() {
-        let mut state = State::new(false, 1, false, None, 0, 16);
+        let mut state = State::new(false, 1, false, None, 0, 16, false);
        let (entry1, _guard1) = default_entry();
        let (entry2, _guard2) = default_entry();
        state.append(entry1);
@@ -643,7 +666,7 @@ mod tests {
    #[tokio::test]
    async fn test_next_batch_token_budget() {
-        let mut state = State::new(false, 1, false, None, 0, 2);
+        let mut state = State::new(false, 1, false, None, 0, 16, false);
        let (entry1, _guard1) = default_entry();
        let (entry2, _guard2) = default_entry();
        state.append(entry1);
@@ -676,14 +699,14 @@ mod tests {
    #[tokio::test]
    async fn test_queue_append() {
-        let queue = Queue::new(false, 1, false, None, 0, 16);
+        let queue = Queue::new(false, 1, false, None, 0, 16, false);
        let (entry, _guard) = default_entry();
        queue.append(entry);
    }

    #[tokio::test]
    async fn test_queue_next_batch_empty() {
-        let queue = Queue::new(false, 1, false, None, 0, 16);
+        let queue = Queue::new(false, 1, false, None, 0, 16, false);

        assert!(queue.next_batch(None, None, 1, 1).await.is_none());
        assert!(queue.next_batch(Some(1), None, 1, 1).await.is_none());
@@ -691,7 +714,7 @@ mod tests {
    #[tokio::test]
    async fn test_queue_next_batch_min_size() {
-        let queue = Queue::new(false, 1, false, None, 0, 16);
+        let queue = Queue::new(false, 1, false, None, 0, 16, false);
        let (entry1, _guard1) = default_entry();
        let (entry2, _guard2) = default_entry();
        queue.append(entry1);
@@ -724,7 +747,7 @@ mod tests {
    #[tokio::test]
    async fn test_queue_next_batch_max_size() {
-        let queue = Queue::new(false, 1, false, None, 0, 16);
+        let queue = Queue::new(false, 1, false, None, 0, 16, false);
        let (entry1, _guard1) = default_entry();
        let (entry2, _guard2) = default_entry();
        queue.append(entry1);
@@ -740,7 +763,7 @@ mod tests {
    #[tokio::test]
    async fn test_queue_next_batch_token_budget() {
-        let queue = Queue::new(false, 1, false, None, 0, 16);
+        let queue = Queue::new(false, 1, false, None, 0, 16, false);
        let (entry1, _guard1) = default_entry();
        let (entry2, _guard2) = default_entry();
        queue.append(entry1);
@@ -765,7 +788,7 @@ mod tests {
    #[tokio::test]
    async fn test_queue_next_batch_token_speculate() {
-        let queue = Queue::new(false, 1, false, None, 2, 16);
+        let queue = Queue::new(true, 1, false, None, 2, 16, false);
        let (entry1, _guard1) = default_entry();
        let (entry2, _guard2) = default_entry();
        queue.append(entry1);
@@ -784,7 +807,7 @@ mod tests {
    #[tokio::test]
    async fn test_queue_next_batch_dropped_receiver() {
-        let queue = Queue::new(false, 1, false, None, 0, 16);
+        let queue = Queue::new(false, 1, false, None, 0, 16, false);
        let (entry, _) = default_entry();
        queue.append(entry);
...
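Aside (illustrative only, not part of the diff): the queue hunk above is where chunking actually splits a request. Once an entry's remaining prompt (its postfix after the radix-trie prefix) would push the batch past prefill_token_budget, the entry is either admitted with a chunk_len that exactly fills the budget or pushed back onto the queue. A compact, hypothetical restatement of that decision:

// Hypothetical helper mirroring the `chunk_len` computation above; not real TGI API.
fn chunk_len_for_last_entry(
    prefill_tokens: u32,       // tokens already planned for this prefill batch
    postfix_len: u32,          // prompt tokens of the entry not yet in the KV cache
    prefill_token_budget: u32, // per-batch prefill budget
) -> Option<u32> {
    if prefill_tokens + postfix_len <= prefill_token_budget {
        // The entry fits entirely; the queue pushes it with `chunk_len: None`.
        return None;
    }
    // Over budget: with chunking, prefill only what is left of the budget.
    // A zero-length chunk is useless, so the queue puts the entry back instead.
    let chunk_len = prefill_token_budget.saturating_sub(prefill_tokens);
    (chunk_len > 0).then_some(chunk_len)
}

With a budget of 4096, 3000 tokens already planned for the batch and a 2000-token postfix, the entry is admitted with chunk_len = 1096 and its remaining 904 prompt tokens are prefilled on a later pass against the cached batch.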
@@ -158,7 +158,8 @@ async fn prefill(
            top_n_tokens: top_n_tokens.unwrap_or(0),
            blocks: vec![],
            slots: vec![],
-            prefix_len: 0,
+            cache_len: 0,
+            chunk_len: None,
            adapter_id: None,
        })
        .collect();
@@ -173,7 +174,7 @@ async fn prefill(
    // Run prefill
    let start_time = Instant::now();
-    let (_, decode_batch, _) = client.prefill(batch.clone()).await?;
+    let (_, decode_batch, _) = client.prefill(batch.clone(), None).await?;
    // Get latency
    let latency = start_time.elapsed();
...
@@ -178,6 +178,7 @@ fn main() -> Result<(), Box<dyn std::error::Error>> {
        .clear_cache(None)
        .await
        .expect("Unable to clear cache");
+
    tracing::info!("Connected");

    // Run app
...
@@ -9,13 +9,16 @@ import subprocess
 import sys
 import tempfile
 import time
-from typing import Dict, List, Optional

 import docker
 import pytest
+import base64
+from pathlib import Path
+from typing import Dict, List, Optional
 from aiohttp import ClientConnectorError, ClientOSError, ServerDisconnectedError
 from docker.errors import NotFound
 from syrupy.extensions.json import JSONSnapshotExtension
 from text_generation import AsyncClient
 from text_generation.types import (
     BestOfSequence,
@@ -403,6 +406,7 @@ def launcher(event_loop):
            print(" ".join(args), file=sys.stderr)

            env["LOG_LEVEL"] = "info,text_generation_router=debug"
+            env["PREFILL_CHUNKING"] = "1"

            if not use_flash_attention:
                env["USE_FLASH_ATTENTION"] = "false"
@@ -501,6 +505,7 @@ def launcher(event_loop):
        env = {
            "LOG_LEVEL": "info,text_generation_router=debug",
+            "PREFILL_CHUNKING": "1",
        }

        if not use_flash_attention:
            env["USE_FLASH_ATTENTION"] = "false"
@@ -642,3 +647,22 @@ def generate_multi():
        return responses

    return generate_load_inner
+
+
+# TODO fix the server parsser to count inline image tokens correctly
+@pytest.fixture
+def chicken():
+    path = Path(__file__).parent / "images" / "chicken_on_money.png"
+    with open(path, "rb") as image_file:
+        encoded_string = base64.b64encode(image_file.read())
+    return f"data:image/png;base64,{encoded_string.decode('utf-8')}"
+
+
+@pytest.fixture
+def cow_beach():
+    path = Path(__file__).parent / "images" / "cow_beach.png"
+    with open(path, "rb") as image_file:
+        encoded_string = base64.b64encode(image_file.read())
+    return f"data:image/png;base64,{encoded_string.decode('utf-8')}"
 import pytest
-import base64

 @pytest.fixture(scope="module")
@@ -20,24 +19,11 @@ async def flash_pali_gemma(flash_pali_gemma_handle):
    return flash_pali_gemma_handle.client


-def get_chicken():
-    with open("integration-tests/images/chicken_on_money.png", "rb") as image_file:
-        encoded_string = base64.b64encode(image_file.read())
-    return f"data:image/png;base64,{encoded_string.decode('utf-8')}"
-
-
-def get_cow_beach():
-    with open("integration-tests/images/cow_beach.png", "rb") as image_file:
-        encoded_string = base64.b64encode(image_file.read())
-    return f"data:image/png;base64,{encoded_string.decode('utf-8')}"
-
-
 @pytest.mark.release
 @pytest.mark.asyncio
 @pytest.mark.private
-async def test_flash_pali_gemma(flash_pali_gemma, response_snapshot):
-    cow = get_cow_beach()
-    inputs = f"![]({cow})Where is the cow standing?\n"
+async def test_flash_pali_gemma(flash_pali_gemma, response_snapshot, cow_beach):
+    inputs = f"![]({cow_beach})Where is the cow standing?\n"
    response = await flash_pali_gemma.generate(inputs, max_new_tokens=20)

    assert response.generated_text == "beach"
@@ -47,9 +33,9 @@ async def test_flash_pali_gemma(flash_pali_gemma, response_snapshot):
 @pytest.mark.release
 @pytest.mark.asyncio
 @pytest.mark.private
-async def test_flash_pali_gemma_two_images(flash_pali_gemma, response_snapshot):
-    chicken = get_chicken()
-    cow_beach = get_cow_beach()
+async def test_flash_pali_gemma_two_images(
+    flash_pali_gemma, response_snapshot, chicken, cow_beach
+):
    response = await flash_pali_gemma.generate(
        f"caption![]({chicken})![]({cow_beach})\n",
        max_new_tokens=20,
...
@@ -25,7 +25,6 @@ async def llama_grammar(llama_grammar_handle):
 @pytest.mark.release
 @pytest.mark.asyncio
 async def test_grammar_response_format_llama_json(llama_grammar, response_snapshot):
-
     class Weather(BaseModel):
         unit: str
         temperature: List[int]
...
 import pytest
-import base64

 @pytest.fixture(scope="module")
@@ -16,22 +15,8 @@ async def idefics(idefics_handle):
    return idefics_handle.client


-# TODO fix the server parsser to count inline image tokens correctly
-def get_chicken():
-    with open("integration-tests/images/chicken_on_money.png", "rb") as image_file:
-        encoded_string = base64.b64encode(image_file.read())
-    return f"data:image/png;base64,{encoded_string.decode('utf-8')}"
-
-
-def get_cow_beach():
-    with open("integration-tests/images/cow_beach.png", "rb") as image_file:
-        encoded_string = base64.b64encode(image_file.read())
-    return f"data:image/png;base64,{encoded_string.decode('utf-8')}"
-
-
 @pytest.mark.asyncio
-async def test_idefics(idefics, response_snapshot):
-    chicken = get_chicken()
+async def test_idefics(idefics, response_snapshot, chicken):
    response = await idefics.generate(
        f"User:![]({chicken})Can you tell me a very short story based on the image?",
        max_new_tokens=10,
@@ -48,9 +33,7 @@ async def test_idefics(idefics, response_snapshot):
 @pytest.mark.release
 @pytest.mark.asyncio
 @pytest.mark.private
-async def test_idefics_two_images(idefics, response_snapshot):
-    chicken = get_chicken()
-    cow_beach = get_cow_beach()
+async def test_idefics_two_images(idefics, response_snapshot, chicken, cow_beach):
    response = await idefics.generate(
        f"User:![]({chicken})![]({cow_beach})Where are the cow and chicken?<end_of_utterance> \nAssistant:",
        max_new_tokens=20,
@@ -63,8 +46,7 @@ async def test_idefics_two_images(idefics, response_snapshot):
 @pytest.mark.release
 @pytest.mark.asyncio
-async def test_idefics_load(idefics, generate_load, response_snapshot):
-    chicken = get_chicken()
+async def test_idefics_load(idefics, generate_load, response_snapshot, chicken):
    responses = await generate_load(
        idefics,
        f"User:![]({chicken})Can you tell me a very short story based on the image?",
...
 import pytest
-import base64
-
-# TODO fix the server parsser to count inline image tokens correctly
-def get_chicken():
-    with open("integration-tests/images/chicken_on_money.png", "rb") as image_file:
-        encoded_string = base64.b64encode(image_file.read())
-    return f"data:image/png;base64,{encoded_string.decode('utf-8')}"
-
-
-def get_cow_beach():
-    with open("integration-tests/images/cow_beach.png", "rb") as image_file:
-        encoded_string = base64.b64encode(image_file.read())
-    return f"data:image/png;base64,{encoded_string.decode('utf-8')}"
-

 @pytest.fixture(scope="module")
@@ -31,8 +17,9 @@ async def flash_idefics2_next(flash_idefics2_next_handle):
 @pytest.mark.asyncio
 @pytest.mark.private
-async def test_flash_idefics2_next_simple(flash_idefics2_next, response_snapshot):
-    chicken = get_chicken()
+async def test_flash_idefics2_next_simple(
+    flash_idefics2_next, response_snapshot, chicken
+):
    response = await flash_idefics2_next.generate(
        f"User:![]({chicken})Write me a short story<end_of_utterance> \nAssistant:",
        max_new_tokens=10,
@@ -46,9 +33,9 @@ async def test_flash_idefics2_next_simple(flash_idefics2_next, response_snapshot
 @pytest.mark.asyncio
 @pytest.mark.private
-async def test_flash_idefics2_two_images(flash_idefics2_next, response_snapshot):
-    chicken = get_chicken()
-    cow_beach = get_cow_beach()
+async def test_flash_idefics2_two_images(
+    flash_idefics2_next, response_snapshot, chicken, cow_beach
+):
    response = await flash_idefics2_next.generate(
        f"User:![]({chicken})![]({cow_beach})Where are the cow and chicken?<end_of_utterance> \nAssistant:",
        max_new_tokens=20,
@@ -87,9 +74,8 @@ async def test_flash_idefics2_next_all_params(flash_idefics2_next, response_snap
 @pytest.mark.asyncio
 @pytest.mark.private
 async def test_flash_idefics2_next_load(
-    flash_idefics2_next, generate_load, response_snapshot
+    flash_idefics2_next, generate_load, response_snapshot, chicken
 ):
-    chicken = get_chicken()
    responses = await generate_load(
        flash_idefics2_next,
        f"User:![]({chicken})Write me a short story<end_of_utterance> \nAssistant:",
...
 import pytest
-import base64
-
-# TODO fix the server parsser to count inline image tokens correctly
-def get_chicken():
-    with open("integration-tests/images/chicken_on_money.png", "rb") as image_file:
-        encoded_string = base64.b64encode(image_file.read())
-    return f"data:image/png;base64,{encoded_string.decode('utf-8')}"
-

 @pytest.fixture(scope="module")
@@ -29,8 +21,7 @@ async def flash_llava_next(flash_llava_next_handle):
 @pytest.mark.release
 @pytest.mark.asyncio
 @pytest.mark.private
-async def test_flash_llava_next_simple(flash_llava_next, response_snapshot):
-    chicken = get_chicken()
+async def test_flash_llava_next_simple(flash_llava_next, response_snapshot, chicken):
    response = await flash_llava_next.generate(
        f"User:![]({chicken})Can you tell me a very short story based on the image?",
        max_new_tokens=10,
@@ -70,9 +61,8 @@ async def test_flash_llava_next_all_params(flash_llava_next, response_snapshot):
 @pytest.mark.asyncio
 @pytest.mark.private
 async def test_flash_llava_next_load(
-    flash_llava_next, generate_load, response_snapshot
+    flash_llava_next, generate_load, response_snapshot, chicken
 ):
-    chicken = get_chicken()
    responses = await generate_load(
        flash_llava_next,
        f"User:![]({chicken})Can you tell me a very short story based on the image?",
...