Unverified Commit 8deeaca4 authored by Daniël de Kok's avatar Daniël de Kok Committed by GitHub
Browse files

Add support for prefix caching to the v3 router (#2392)

This change adds support for prefix caching to the v3 router. This
is broken up from the backend support to ease reviewing.

For now prefix caching is only enabled with `USE_PREFIX_CACHING=1`
in this case, the router will switch to `RadixAllocator`. This
allocator uses a radix trie to keep track of prefills that were
seen prior. If a new prefill is a prefix of a previously-seen
prefil, the router will send a request with `prefix_len>0`, which
can be used by the backend to decide to reuse KV blocks from the
cache, rather than recomputing them.

Even though backend support is not added in this PR, the backend
will still work with prefix caching enabled. The prefix lengths
are just ignored and not used.
parent b6bb1d51
...@@ -4045,6 +4045,7 @@ dependencies = [ ...@@ -4045,6 +4045,7 @@ dependencies = [
"reqwest", "reqwest",
"serde", "serde",
"serde_json", "serde_json",
"slotmap",
"text-generation-router", "text-generation-router",
"thiserror", "thiserror",
"tokenizers", "tokenizers",
......
...@@ -156,6 +156,7 @@ impl Client { ...@@ -156,6 +156,7 @@ impl Client {
// Blocks and slots will be set on the server side if we use paged attention // Blocks and slots will be set on the server side if we use paged attention
blocks: vec![], blocks: vec![],
slots: vec![], slots: vec![],
prefix_len: 0,
// Set sampling parameters to also take these ops into account in the max memory // Set sampling parameters to also take these ops into account in the max memory
parameters: Some(NextTokenChooserParameters { parameters: Some(NextTokenChooserParameters {
temperature: 0.9, temperature: 0.9,
......
...@@ -244,6 +244,7 @@ impl Health for ShardedClient { ...@@ -244,6 +244,7 @@ impl Health for ShardedClient {
// Block 0 is reserved for health checks // Block 0 is reserved for health checks
blocks: vec![0], blocks: vec![0],
slots: (0..16).collect(), slots: (0..16).collect(),
prefix_len: 0,
adapter_id: None, adapter_id: None,
}; };
let batch = Batch { let batch = Batch {
......
...@@ -33,6 +33,7 @@ rand = "0.8.5" ...@@ -33,6 +33,7 @@ rand = "0.8.5"
reqwest = { version = "0.11.20", features = [] } reqwest = { version = "0.11.20", features = [] }
serde = "1.0.188" serde = "1.0.188"
serde_json = "1.0.107" serde_json = "1.0.107"
slotmap = "1.0.7"
thiserror = "1.0.48" thiserror = "1.0.48"
tokenizers = { workspace = true} tokenizers = { workspace = true}
tokio = { version = "1.32.0", features = ["rt", "rt-multi-thread", "parking_lot", "signal", "sync"] } tokio = { version = "1.32.0", features = ["rt", "rt-multi-thread", "parking_lot", "signal", "sync"] }
......
...@@ -35,15 +35,24 @@ impl BackendV3 { ...@@ -35,15 +35,24 @@ impl BackendV3 {
window_size: Option<u32>, window_size: Option<u32>,
speculate: u32, speculate: u32,
) -> Self { ) -> Self {
let prefix_caching = if let Ok(prefix_caching) = std::env::var("USE_PREFIX_CACHING") {
matches!(prefix_caching.as_str(), "true" | "1")
} else {
false
};
let attention = if let Ok(attention) = std::env::var("ATTENTION") { let attention = if let Ok(attention) = std::env::var("ATTENTION") {
attention attention
.parse() .parse()
.unwrap_or_else(|_| panic!("Invalid attention was specified :`{attention}`")) .unwrap_or_else(|_| panic!("Invalid attention was specified :`{attention}`"))
} else if prefix_caching {
Attention::FlashInfer
} else { } else {
Attention::Paged Attention::Paged
}; };
let block_size = if attention == Attention::FlashDecoding { let block_size = if attention == Attention::FlashDecoding {
256 256
} else if attention == Attention::FlashInfer {
1
} else { } else {
16 16
}; };
...@@ -51,6 +60,7 @@ impl BackendV3 { ...@@ -51,6 +60,7 @@ impl BackendV3 {
let queue = Queue::new( let queue = Queue::new(
requires_padding, requires_padding,
block_size, block_size,
prefix_caching,
window_size, window_size,
speculate, speculate,
max_batch_total_tokens, max_batch_total_tokens,
......
use std::cmp::min; use std::{cmp::min, sync::Arc};
use tokio::sync::{mpsc, oneshot}; use tokio::sync::{mpsc, oneshot};
use crate::radix::RadixAllocator;
#[derive(Debug, Clone)] #[derive(Debug, Clone)]
pub(crate) struct BlockAllocation { pub(crate) struct BlockAllocation {
pub allocation_id: u64,
pub blocks: Vec<u32>, pub blocks: Vec<u32>,
pub slots: Vec<u32>, pub slots: Vec<u32>,
block_allocator: BlockAllocator,
/// Prefix that was cached and for which the KV does not have to
/// be recomputed.
pub prefix_len: u32,
pub(crate) block_allocator: Option<BlockAllocator>,
} }
impl Drop for BlockAllocation { impl Drop for BlockAllocation {
fn drop(&mut self) { fn drop(&mut self) {
self.block_allocator.free(self.blocks.clone()) if let Some(block_allocator) = self.block_allocator.as_mut() {
block_allocator.free(self.blocks.clone(), self.allocation_id)
}
} }
} }
...@@ -24,6 +34,7 @@ impl BlockAllocator { ...@@ -24,6 +34,7 @@ impl BlockAllocator {
pub(crate) fn new( pub(crate) fn new(
max_batch_total_tokens: u32, max_batch_total_tokens: u32,
block_size: u32, block_size: u32,
prefix_caching: bool,
window_size: Option<u32>, window_size: Option<u32>,
) -> Self { ) -> Self {
// Create channel // Create channel
...@@ -33,6 +44,7 @@ impl BlockAllocator { ...@@ -33,6 +44,7 @@ impl BlockAllocator {
tokio::spawn(block_allocator_task( tokio::spawn(block_allocator_task(
max_batch_total_tokens / block_size, max_batch_total_tokens / block_size,
block_size, block_size,
prefix_caching,
window_size, window_size,
receiver, receiver,
)); ));
...@@ -42,28 +54,32 @@ impl BlockAllocator { ...@@ -42,28 +54,32 @@ impl BlockAllocator {
} }
} }
pub(crate) async fn allocate(&self, tokens: u32) -> Option<BlockAllocation> { pub(crate) async fn allocate(
&self,
tokens: u32,
prefill_tokens: Option<Arc<Vec<u32>>>,
) -> Option<BlockAllocation> {
let (response_sender, response_receiver) = oneshot::channel(); let (response_sender, response_receiver) = oneshot::channel();
self.block_allocator self.block_allocator
.send(BlockAllocatorCommand::Allocate { .send(BlockAllocatorCommand::Allocate {
tokens, tokens,
prefill_tokens,
response_sender, response_sender,
}) })
.unwrap(); .unwrap();
response_receiver response_receiver.await.unwrap().map(|mut allocation| {
.await allocation.block_allocator = Some(self.clone());
.unwrap() allocation
.map(|(blocks, slots)| BlockAllocation { })
blocks,
slots,
block_allocator: self.clone(),
})
} }
pub(crate) fn free(&self, blocks: Vec<u32>) { pub(crate) fn free(&self, blocks: Vec<u32>, allocation_id: u64) {
self.block_allocator self.block_allocator
.send(BlockAllocatorCommand::Free { blocks }) .send(BlockAllocatorCommand::Free {
allocation_id,
blocks,
})
.unwrap(); .unwrap();
} }
} }
...@@ -71,54 +87,29 @@ impl BlockAllocator { ...@@ -71,54 +87,29 @@ impl BlockAllocator {
async fn block_allocator_task( async fn block_allocator_task(
blocks: u32, blocks: u32,
block_size: u32, block_size: u32,
prefix_caching: bool,
window_size: Option<u32>, window_size: Option<u32>,
mut receiver: mpsc::UnboundedReceiver<BlockAllocatorCommand>, mut receiver: mpsc::UnboundedReceiver<BlockAllocatorCommand>,
) { ) {
// Block 0 is reserved for health checks let mut allocator: Box<dyn Allocator + Send> = if prefix_caching {
let mut free_blocks: Vec<u32> = (1..blocks).collect(); Box::new(RadixAllocator::new(block_size, blocks, window_size))
} else {
Box::new(SimpleAllocator::new(blocks, block_size, window_size))
};
while let Some(cmd) = receiver.recv().await { while let Some(cmd) = receiver.recv().await {
match cmd { match cmd {
BlockAllocatorCommand::Free { blocks } => free_blocks.extend(blocks), BlockAllocatorCommand::Free {
blocks,
allocation_id,
} => allocator.free(blocks, allocation_id),
BlockAllocatorCommand::Allocate { BlockAllocatorCommand::Allocate {
tokens, tokens,
prefill_tokens,
response_sender, response_sender,
} => { } => {
// Apply window size response_sender
let (required_blocks, repeats) = { .send(allocator.allocate(tokens, prefill_tokens))
let (tokens, repeats) = match window_size { .unwrap();
None => (tokens, 1),
Some(window_size) => {
let repeats = (tokens + window_size - 1) / window_size;
let tokens = min(tokens, window_size);
(tokens, repeats as usize)
}
};
// Pad to a multiple of block size
let required_blocks = (tokens + block_size - 1) / block_size;
(required_blocks, repeats)
};
let tokens = tokens as usize;
let allocation = if required_blocks > free_blocks.len() as u32 {
None
} else {
let blocks =
free_blocks.split_off(free_blocks.len() - required_blocks as usize);
let mut slots = Vec::with_capacity(
(required_blocks * block_size * repeats as u32) as usize,
);
'slots: for block_id in blocks.repeat(repeats).iter() {
for s in (block_id * block_size)..((block_id + 1) * block_size) {
slots.push(s);
if slots.len() == tokens {
break 'slots;
}
}
}
Some((blocks, slots))
};
response_sender.send(allocation).unwrap();
} }
} }
} }
...@@ -128,9 +119,92 @@ async fn block_allocator_task( ...@@ -128,9 +119,92 @@ async fn block_allocator_task(
enum BlockAllocatorCommand { enum BlockAllocatorCommand {
Free { Free {
blocks: Vec<u32>, blocks: Vec<u32>,
allocation_id: u64,
}, },
Allocate { Allocate {
tokens: u32, tokens: u32,
response_sender: oneshot::Sender<Option<(Vec<u32>, Vec<u32>)>>, prefill_tokens: Option<Arc<Vec<u32>>>,
response_sender: oneshot::Sender<Option<BlockAllocation>>,
}, },
} }
pub(crate) trait Allocator {
fn allocate(
&mut self,
tokens: u32,
prefill_tokens: Option<Arc<Vec<u32>>>,
) -> Option<BlockAllocation>;
fn free(&mut self, blocks: Vec<u32>, allocation_id: u64);
}
pub struct SimpleAllocator {
free_blocks: Vec<u32>,
block_size: u32,
window_size: Option<u32>,
}
impl SimpleAllocator {
fn new(blocks: u32, block_size: u32, window_size: Option<u32>) -> Self {
SimpleAllocator {
block_size,
// Block 0 is reserved for health checks
free_blocks: (1..blocks).collect(),
window_size,
}
}
}
impl Allocator for SimpleAllocator {
fn allocate(
&mut self,
tokens: u32,
_prefill_tokens: Option<Arc<Vec<u32>>>,
) -> Option<BlockAllocation> {
// Apply window size
let (required_blocks, repeats) = {
let (tokens, repeats) = match self.window_size {
None => (tokens, 1),
Some(window_size) => {
let repeats = (tokens + window_size - 1) / window_size;
let tokens = min(tokens, window_size);
(tokens, repeats as usize)
}
};
// Pad to a multiple of block size
let required_blocks = (tokens + self.block_size - 1) / self.block_size;
(required_blocks, repeats)
};
let tokens = tokens as usize;
if required_blocks > self.free_blocks.len() as u32 {
None
} else {
let blocks = self
.free_blocks
.split_off(self.free_blocks.len() - required_blocks as usize);
let mut slots =
Vec::with_capacity((required_blocks * self.block_size * repeats as u32) as usize);
'slots: for block_id in blocks.repeat(repeats).iter() {
for s in (block_id * self.block_size)..((block_id + 1) * self.block_size) {
slots.push(s);
if slots.len() == tokens {
break 'slots;
}
}
}
Some(BlockAllocation {
allocation_id: 0,
blocks,
slots,
prefix_len: 0,
block_allocator: None,
})
}
}
fn free(&mut self, blocks: Vec<u32>, _allocation_id: u64) {
self.free_blocks.extend(blocks)
}
}
...@@ -157,6 +157,7 @@ impl Client { ...@@ -157,6 +157,7 @@ impl Client {
// Blocks and slots will be set on the server side if we use paged attention // Blocks and slots will be set on the server side if we use paged attention
blocks: vec![], blocks: vec![],
slots: vec![], slots: vec![],
prefix_len: 0,
// Set sampling parameters to also take these ops into account in the max memory // Set sampling parameters to also take these ops into account in the max memory
parameters: Some(NextTokenChooserParameters { parameters: Some(NextTokenChooserParameters {
temperature: 0.9, temperature: 0.9,
......
...@@ -245,6 +245,7 @@ impl Health for ShardedClient { ...@@ -245,6 +245,7 @@ impl Health for ShardedClient {
// Block 0 is reserved for health checks // Block 0 is reserved for health checks
blocks: vec![0], blocks: vec![0],
slots: (0..16).collect(), slots: (0..16).collect(),
prefix_len: 0,
adapter_id: None, adapter_id: None,
}; };
let batch = Batch { let batch = Batch {
......
...@@ -2,6 +2,7 @@ mod backend; ...@@ -2,6 +2,7 @@ mod backend;
mod block_allocator; mod block_allocator;
mod client; mod client;
mod queue; mod queue;
mod radix;
use crate::client::{ClientError, ShardedClient}; use crate::client::{ClientError, ShardedClient};
pub(crate) use backend::BackendV3; pub(crate) use backend::BackendV3;
......
...@@ -46,6 +46,7 @@ impl Queue { ...@@ -46,6 +46,7 @@ impl Queue {
pub(crate) fn new( pub(crate) fn new(
requires_padding: bool, requires_padding: bool,
block_size: u32, block_size: u32,
prefix_caching: bool,
window_size: Option<u32>, window_size: Option<u32>,
speculate: u32, speculate: u32,
max_batch_total_tokens: u32, max_batch_total_tokens: u32,
...@@ -57,6 +58,7 @@ impl Queue { ...@@ -57,6 +58,7 @@ impl Queue {
tokio::spawn(queue_task( tokio::spawn(queue_task(
requires_padding, requires_padding,
block_size, block_size,
prefix_caching,
window_size, window_size,
speculate, speculate,
max_batch_total_tokens, max_batch_total_tokens,
...@@ -109,6 +111,7 @@ impl Queue { ...@@ -109,6 +111,7 @@ impl Queue {
async fn queue_task( async fn queue_task(
requires_padding: bool, requires_padding: bool,
block_size: u32, block_size: u32,
prefix_caching: bool,
window_size: Option<u32>, window_size: Option<u32>,
speculate: u32, speculate: u32,
max_batch_total_tokens: u32, max_batch_total_tokens: u32,
...@@ -117,6 +120,7 @@ async fn queue_task( ...@@ -117,6 +120,7 @@ async fn queue_task(
let mut state = State::new( let mut state = State::new(
requires_padding, requires_padding,
block_size, block_size,
prefix_caching,
window_size, window_size,
speculate, speculate,
max_batch_total_tokens, max_batch_total_tokens,
...@@ -176,12 +180,19 @@ impl State { ...@@ -176,12 +180,19 @@ impl State {
fn new( fn new(
requires_padding: bool, requires_padding: bool,
block_size: u32, block_size: u32,
prefix_caching: bool,
window_size: Option<u32>, window_size: Option<u32>,
speculate: u32, speculate: u32,
max_batch_total_tokens: u32, max_batch_total_tokens: u32,
) -> Self { ) -> Self {
let block_allocator = (!requires_padding) let block_allocator = (!requires_padding).then(|| {
.then(|| BlockAllocator::new(max_batch_total_tokens, block_size, window_size)); BlockAllocator::new(
max_batch_total_tokens,
block_size,
prefix_caching,
window_size,
)
});
Self { Self {
entries: VecDeque::with_capacity(128), entries: VecDeque::with_capacity(128),
...@@ -305,7 +316,10 @@ impl State { ...@@ -305,7 +316,10 @@ impl State {
+ self.speculate + self.speculate
- 1; - 1;
match block_allocator.allocate(tokens).await { match block_allocator
.allocate(tokens, entry.request.input_ids.clone())
.await
{
None => { None => {
// Entry is over budget // Entry is over budget
// Add it back to the front // Add it back to the front
...@@ -331,11 +345,12 @@ impl State { ...@@ -331,11 +345,12 @@ impl State {
// Update entry // Update entry
entry.temp_span = Some(entry_batch_span); entry.temp_span = Some(entry_batch_span);
let (blocks, slots) = match &block_allocation { let (blocks, slots, prefix_len) = match &block_allocation {
None => (Vec::new(), Vec::new()), None => (Vec::new(), Vec::new(), 0),
Some(block_allocation) => ( Some(block_allocation) => (
block_allocation.blocks.clone(), block_allocation.blocks.clone(),
block_allocation.slots.clone(), block_allocation.slots.clone(),
block_allocation.prefix_len,
), ),
}; };
...@@ -372,6 +387,7 @@ impl State { ...@@ -372,6 +387,7 @@ impl State {
top_n_tokens: entry.request.top_n_tokens, top_n_tokens: entry.request.top_n_tokens,
blocks, blocks,
slots, slots,
prefix_len,
adapter_id: entry.request.adapter_id.clone(), adapter_id: entry.request.adapter_id.clone(),
}); });
// Set batch_time // Set batch_time
...@@ -480,6 +496,8 @@ impl From<ValidStoppingParameters> for StoppingCriteriaParameters { ...@@ -480,6 +496,8 @@ impl From<ValidStoppingParameters> for StoppingCriteriaParameters {
#[cfg(test)] #[cfg(test)]
mod tests { mod tests {
use std::sync::Arc;
use super::*; use super::*;
use tracing::info_span; use tracing::info_span;
...@@ -492,6 +510,7 @@ mod tests { ...@@ -492,6 +510,7 @@ mod tests {
let entry = Entry { let entry = Entry {
request: ValidGenerateRequest { request: ValidGenerateRequest {
inputs: vec![], inputs: vec![],
input_ids: Some(Arc::new(vec![])),
input_length: 0, input_length: 0,
truncate: 0, truncate: 0,
decoder_input_details: false, decoder_input_details: false,
...@@ -527,7 +546,7 @@ mod tests { ...@@ -527,7 +546,7 @@ mod tests {
#[tokio::test] #[tokio::test]
async fn test_append() { async fn test_append() {
let mut state = State::new(false, 1, None, 0, 16); let mut state = State::new(false, 1, false, None, 0, 16);
let (entry, _guard) = default_entry(); let (entry, _guard) = default_entry();
assert_eq!(state.next_id, 0); assert_eq!(state.next_id, 0);
...@@ -543,7 +562,7 @@ mod tests { ...@@ -543,7 +562,7 @@ mod tests {
#[tokio::test] #[tokio::test]
async fn test_next_batch_empty() { async fn test_next_batch_empty() {
let mut state = State::new(false, 1, None, 0, 16); let mut state = State::new(false, 1, false, None, 0, 16);
assert!(state.next_batch(None, None, 1, 1).await.is_none()); assert!(state.next_batch(None, None, 1, 1).await.is_none());
assert!(state.next_batch(Some(1), None, 1, 1).await.is_none()); assert!(state.next_batch(Some(1), None, 1, 1).await.is_none());
...@@ -551,7 +570,7 @@ mod tests { ...@@ -551,7 +570,7 @@ mod tests {
#[tokio::test] #[tokio::test]
async fn test_next_batch_min_size() { async fn test_next_batch_min_size() {
let mut state = State::new(false, 1, None, 0, 16); let mut state = State::new(false, 1, false, None, 0, 16);
let (entry1, _guard1) = default_entry(); let (entry1, _guard1) = default_entry();
let (entry2, _guard2) = default_entry(); let (entry2, _guard2) = default_entry();
state.append(entry1); state.append(entry1);
...@@ -583,7 +602,7 @@ mod tests { ...@@ -583,7 +602,7 @@ mod tests {
#[tokio::test] #[tokio::test]
async fn test_next_batch_max_size() { async fn test_next_batch_max_size() {
let mut state = State::new(false, 1, None, 0, 16); let mut state = State::new(false, 1, false, None, 0, 16);
let (entry1, _guard1) = default_entry(); let (entry1, _guard1) = default_entry();
let (entry2, _guard2) = default_entry(); let (entry2, _guard2) = default_entry();
state.append(entry1); state.append(entry1);
...@@ -603,7 +622,7 @@ mod tests { ...@@ -603,7 +622,7 @@ mod tests {
#[tokio::test] #[tokio::test]
async fn test_next_batch_token_budget() { async fn test_next_batch_token_budget() {
let mut state = State::new(false, 1, None, 0, 2); let mut state = State::new(false, 1, false, None, 0, 2);
let (entry1, _guard1) = default_entry(); let (entry1, _guard1) = default_entry();
let (entry2, _guard2) = default_entry(); let (entry2, _guard2) = default_entry();
state.append(entry1); state.append(entry1);
...@@ -636,14 +655,14 @@ mod tests { ...@@ -636,14 +655,14 @@ mod tests {
#[tokio::test] #[tokio::test]
async fn test_queue_append() { async fn test_queue_append() {
let queue = Queue::new(false, 1, None, 0, 16); let queue = Queue::new(false, 1, false, None, 0, 16);
let (entry, _guard) = default_entry(); let (entry, _guard) = default_entry();
queue.append(entry); queue.append(entry);
} }
#[tokio::test] #[tokio::test]
async fn test_queue_next_batch_empty() { async fn test_queue_next_batch_empty() {
let queue = Queue::new(false, 1, None, 0, 16); let queue = Queue::new(false, 1, false, None, 0, 16);
assert!(queue.next_batch(None, None, 1, 1).await.is_none()); assert!(queue.next_batch(None, None, 1, 1).await.is_none());
assert!(queue.next_batch(Some(1), None, 1, 1).await.is_none()); assert!(queue.next_batch(Some(1), None, 1, 1).await.is_none());
...@@ -651,7 +670,7 @@ mod tests { ...@@ -651,7 +670,7 @@ mod tests {
#[tokio::test] #[tokio::test]
async fn test_queue_next_batch_min_size() { async fn test_queue_next_batch_min_size() {
let queue = Queue::new(false, 1, None, 0, 16); let queue = Queue::new(false, 1, false, None, 0, 16);
let (entry1, _guard1) = default_entry(); let (entry1, _guard1) = default_entry();
let (entry2, _guard2) = default_entry(); let (entry2, _guard2) = default_entry();
queue.append(entry1); queue.append(entry1);
...@@ -684,7 +703,7 @@ mod tests { ...@@ -684,7 +703,7 @@ mod tests {
#[tokio::test] #[tokio::test]
async fn test_queue_next_batch_max_size() { async fn test_queue_next_batch_max_size() {
let queue = Queue::new(false, 1, None, 0, 16); let queue = Queue::new(false, 1, false, None, 0, 16);
let (entry1, _guard1) = default_entry(); let (entry1, _guard1) = default_entry();
let (entry2, _guard2) = default_entry(); let (entry2, _guard2) = default_entry();
queue.append(entry1); queue.append(entry1);
...@@ -700,7 +719,7 @@ mod tests { ...@@ -700,7 +719,7 @@ mod tests {
#[tokio::test] #[tokio::test]
async fn test_queue_next_batch_token_budget() { async fn test_queue_next_batch_token_budget() {
let queue = Queue::new(false, 1, None, 0, 16); let queue = Queue::new(false, 1, false, None, 0, 16);
let (entry1, _guard1) = default_entry(); let (entry1, _guard1) = default_entry();
let (entry2, _guard2) = default_entry(); let (entry2, _guard2) = default_entry();
queue.append(entry1); queue.append(entry1);
...@@ -725,7 +744,7 @@ mod tests { ...@@ -725,7 +744,7 @@ mod tests {
#[tokio::test] #[tokio::test]
async fn test_queue_next_batch_token_speculate() { async fn test_queue_next_batch_token_speculate() {
let queue = Queue::new(false, 1, None, 2, 16); let queue = Queue::new(false, 1, false, None, 2, 16);
let (entry1, _guard1) = default_entry(); let (entry1, _guard1) = default_entry();
let (entry2, _guard2) = default_entry(); let (entry2, _guard2) = default_entry();
queue.append(entry1); queue.append(entry1);
...@@ -744,7 +763,7 @@ mod tests { ...@@ -744,7 +763,7 @@ mod tests {
#[tokio::test] #[tokio::test]
async fn test_queue_next_batch_dropped_receiver() { async fn test_queue_next_batch_dropped_receiver() {
let queue = Queue::new(false, 1, None, 0, 16); let queue = Queue::new(false, 1, false, None, 0, 16);
let (entry, _) = default_entry(); let (entry, _) = default_entry();
queue.append(entry); queue.append(entry);
......
This diff is collapsed.
...@@ -157,6 +157,7 @@ async fn prefill( ...@@ -157,6 +157,7 @@ async fn prefill(
top_n_tokens: top_n_tokens.unwrap_or(0), top_n_tokens: top_n_tokens.unwrap_or(0),
blocks: vec![], blocks: vec![],
slots: vec![], slots: vec![],
prefix_len: 0,
adapter_id: None, adapter_id: None,
}) })
.collect(); .collect();
......
...@@ -3,22 +3,23 @@ syntax = "proto3"; ...@@ -3,22 +3,23 @@ syntax = "proto3";
package generate.v3; package generate.v3;
service TextGenerationService { service TextGenerationService {
/// Model Info /// Model Info
rpc Info (InfoRequest) returns (InfoResponse) {} rpc Info(InfoRequest) returns (InfoResponse) {}
/// Service discovery /// Service discovery
rpc ServiceDiscovery (ServiceDiscoveryRequest) returns (ServiceDiscoveryResponse) {} rpc ServiceDiscovery(ServiceDiscoveryRequest)
/// Empties batch cache returns (ServiceDiscoveryResponse) {}
rpc ClearCache (ClearCacheRequest) returns (ClearCacheResponse); /// Empties batch cache
/// Remove requests from a cached batch rpc ClearCache(ClearCacheRequest) returns (ClearCacheResponse);
rpc FilterBatch (FilterBatchRequest) returns (FilterBatchResponse); /// Remove requests from a cached batch
/// Warmup the model and compute max cache size rpc FilterBatch(FilterBatchRequest) returns (FilterBatchResponse);
rpc Warmup (WarmupRequest) returns (WarmupResponse); /// Warmup the model and compute max cache size
/// Prefill batch and decode first token rpc Warmup(WarmupRequest) returns (WarmupResponse);
rpc Prefill (PrefillRequest) returns (PrefillResponse); /// Prefill batch and decode first token
/// Decode token for a list of prefilled batches rpc Prefill(PrefillRequest) returns (PrefillResponse);
rpc Decode (DecodeRequest) returns (DecodeResponse); /// Decode token for a list of prefilled batches
/// Health check rpc Decode(DecodeRequest) returns (DecodeResponse);
rpc Health (HealthRequest) returns (HealthResponse); /// Health check
rpc Health(HealthRequest) returns (HealthResponse);
} }
message HealthRequest {} message HealthRequest {}
...@@ -28,240 +29,239 @@ message HealthResponse {} ...@@ -28,240 +29,239 @@ message HealthResponse {}
message InfoRequest {} message InfoRequest {}
message InfoResponse { message InfoResponse {
bool requires_padding = 1; bool requires_padding = 1;
string dtype = 2; string dtype = 2;
string device_type = 3; string device_type = 3;
optional uint32 window_size = 4; optional uint32 window_size = 4;
uint32 speculate = 5; uint32 speculate = 5;
} }
/// Empty request /// Empty request
message ServiceDiscoveryRequest {} message ServiceDiscoveryRequest {}
message ServiceDiscoveryResponse { message ServiceDiscoveryResponse {
/// Other shards urls /// Other shards urls
repeated string urls = 1; repeated string urls = 1;
} }
message ClearCacheRequest { message ClearCacheRequest {
/// Optional batch id /// Optional batch id
optional uint64 id = 1; optional uint64 id = 1;
} }
/// Empty response /// Empty response
message ClearCacheResponse {} message ClearCacheResponse {}
message Image { message Image {
/// Binary image data. /// Binary image data.
bytes data = 1; bytes data = 1;
/// Image MIME type. /// Image MIME type.
string mimetype = 2; string mimetype = 2;
} }
message InputChunk { message InputChunk {
oneof chunk { oneof chunk {
/// Plain text data /// Plain text data
string text = 1; string text = 1;
/// Image data /// Image data
Image image = 2; Image image = 2;
} }
} }
message Input { message Input { repeated InputChunk chunks = 1; }
repeated InputChunk chunks = 1;
}
enum GrammarType { enum GrammarType {
GRAMMAR_TYPE_NONE = 0; GRAMMAR_TYPE_NONE = 0;
GRAMMAR_TYPE_JSON = 1; GRAMMAR_TYPE_JSON = 1;
GRAMMAR_TYPE_REGEX = 2; GRAMMAR_TYPE_REGEX = 2;
} }
message NextTokenChooserParameters { message NextTokenChooserParameters {
/// exponential scaling output probability distribution /// exponential scaling output probability distribution
float temperature = 1; float temperature = 1;
/// restricting to the k highest probability elements /// restricting to the k highest probability elements
uint32 top_k = 2; uint32 top_k = 2;
/// restricting to top tokens summing to prob_cut_off <= prob_cut_off /// restricting to top tokens summing to prob_cut_off <= prob_cut_off
float top_p = 3; float top_p = 3;
/// restricting to top tokens summing to prob_cut_off <= prob_cut_off /// restricting to top tokens summing to prob_cut_off <= prob_cut_off
float typical_p = 4; float typical_p = 4;
/// apply sampling on the logits /// apply sampling on the logits
bool do_sample = 5; bool do_sample = 5;
/// random seed for sampling /// random seed for sampling
uint64 seed = 6; uint64 seed = 6;
/// repetition penalty /// repetition penalty
float repetition_penalty = 7; float repetition_penalty = 7;
/// frequency penalty /// frequency penalty
float frequency_penalty = 9; float frequency_penalty = 9;
/// token watermarking using "A Watermark for Large Language Models" /// token watermarking using "A Watermark for Large Language Models"
bool watermark = 8; bool watermark = 8;
/// grammar (applied if not empty) /// grammar (applied if not empty)
string grammar = 10; string grammar = 10;
/// grammar type /// grammar type
GrammarType grammar_type = 11; GrammarType grammar_type = 11;
} }
message StoppingCriteriaParameters { message StoppingCriteriaParameters {
/// Maximum number of generated tokens /// Maximum number of generated tokens
uint32 max_new_tokens = 1; uint32 max_new_tokens = 1;
/// Optional stopping sequences /// Optional stopping sequences
repeated string stop_sequences = 2; repeated string stop_sequences = 2;
/// Ignore end of sequence token /// Ignore end of sequence token
/// used for benchmarking /// used for benchmarking
bool ignore_eos_token = 3; bool ignore_eos_token = 3;
} }
message Request { message Request {
/// Request ID /// Request ID
uint64 id = 1; uint64 id = 1;
/// The generation context as chunks /// The generation context as chunks
Input input_chunks = 8; Input input_chunks = 8;
/// The generation context, stringified input_chunks /// The generation context, stringified input_chunks
string inputs = 2; string inputs = 2;
/// Context truncation /// Context truncation
uint32 truncate = 3; uint32 truncate = 3;
/// Next Token Chooser Parameters /// Next Token Chooser Parameters
NextTokenChooserParameters parameters = 4; NextTokenChooserParameters parameters = 4;
/// Stopping Criteria Parameters /// Stopping Criteria Parameters
StoppingCriteriaParameters stopping_parameters = 5; StoppingCriteriaParameters stopping_parameters = 5;
/// Return prefill logprobs /// Return prefill logprobs
bool prefill_logprobs = 6; bool prefill_logprobs = 6;
/// Return most likely n tokens /// Return most likely n tokens
uint32 top_n_tokens = 7; uint32 top_n_tokens = 7;
/// Paged attention blocks /// Paged attention blocks
repeated uint32 blocks = 9; repeated uint32 blocks = 9;
/// Paged attention slots /// Paged attention slots
repeated uint32 slots = 10; repeated uint32 slots = 10;
/// LORA adapter index /// LORA adapter index
optional string adapter_id = 11; optional string adapter_id = 11;
/// Prefix length that can be retrieved from the KV cache.
uint32 prefix_len = 12;
} }
message Batch { message Batch {
/// Batch ID /// Batch ID
uint64 id = 1; uint64 id = 1;
/// Individual requests /// Individual requests
repeated Request requests = 2; repeated Request requests = 2;
/// Batch size (==len(requests)) /// Batch size (==len(requests))
uint32 size = 3; uint32 size = 3;
/// Maximum number of tokens this batch will grow to /// Maximum number of tokens this batch will grow to
uint32 max_tokens = 4; uint32 max_tokens = 4;
/// Maximum number of Paged Attention blocks /// Maximum number of Paged Attention blocks
uint32 max_blocks = 5; uint32 max_blocks = 5;
} }
message CachedBatch { message CachedBatch {
/// Batch ID /// Batch ID
uint64 id = 1; uint64 id = 1;
/// Individual requests ids /// Individual requests ids
repeated uint64 request_ids = 2; repeated uint64 request_ids = 2;
/// Batch size (==len(requests)) /// Batch size (==len(requests))
uint32 size = 3; uint32 size = 3;
/// Maximum number of tokens this batch will grow to /// Maximum number of tokens this batch will grow to
uint32 max_tokens = 4; uint32 max_tokens = 4;
} }
enum FinishReason { enum FinishReason {
FINISH_REASON_LENGTH = 0; FINISH_REASON_LENGTH = 0;
FINISH_REASON_EOS_TOKEN = 1; FINISH_REASON_EOS_TOKEN = 1;
FINISH_REASON_STOP_SEQUENCE = 2; FINISH_REASON_STOP_SEQUENCE = 2;
} }
message GeneratedText { message GeneratedText {
/// Output /// Output
string text = 1; string text = 1;
/// Number of generated tokens /// Number of generated tokens
uint32 generated_tokens = 2; uint32 generated_tokens = 2;
/// Finish reason /// Finish reason
FinishReason finish_reason = 3; FinishReason finish_reason = 3;
/// Seed /// Seed
optional uint64 seed = 4; optional uint64 seed = 4;
} }
message Tokens { message Tokens {
/// Token IDs /// Token IDs
repeated uint32 ids = 1; repeated uint32 ids = 1;
/// Logprobs /// Logprobs
repeated float logprobs = 2; repeated float logprobs = 2;
/// tokens /// tokens
repeated string texts = 3; repeated string texts = 3;
/// special /// special
repeated bool is_special = 4; repeated bool is_special = 4;
} }
message Generation { message Generation {
/// Request ID /// Request ID
uint64 request_id = 1; uint64 request_id = 1;
/// Prefill tokens (optional) /// Prefill tokens (optional)
Tokens prefill_tokens = 2; Tokens prefill_tokens = 2;
Tokens tokens = 3; Tokens tokens = 3;
/// Complete generated text /// Complete generated text
optional GeneratedText generated_text = 4; optional GeneratedText generated_text = 4;
/// Top tokens /// Top tokens
repeated Tokens top_tokens = 5; repeated Tokens top_tokens = 5;
} }
message FilterBatchRequest { message FilterBatchRequest {
/// Batch ID /// Batch ID
uint64 batch_id = 1; uint64 batch_id = 1;
/// Requests to keep /// Requests to keep
repeated uint64 request_ids = 2; repeated uint64 request_ids = 2;
} }
message FilterBatchResponse { message FilterBatchResponse {
/// Filtered Batch (cached) /// Filtered Batch (cached)
CachedBatch batch = 1; CachedBatch batch = 1;
} }
message PrefillRequest { message PrefillRequest {
/// Batch /// Batch
Batch batch = 1; Batch batch = 1;
} }
message PrefillResponse { message PrefillResponse {
/// Generation /// Generation
repeated Generation generations = 1; repeated Generation generations = 1;
/// Next batch (cached) /// Next batch (cached)
optional CachedBatch batch = 2; optional CachedBatch batch = 2;
/// Forward elapsed time in nanoseconds /// Forward elapsed time in nanoseconds
uint64 forward_ns = 3; uint64 forward_ns = 3;
/// Decode elapsed time in nanoseconds /// Decode elapsed time in nanoseconds
uint64 decode_ns = 4; uint64 decode_ns = 4;
/// Total elapsed time in nanoseconds /// Total elapsed time in nanoseconds
uint64 total_ns = 5; uint64 total_ns = 5;
} }
message DecodeRequest { message DecodeRequest {
/// Cached batches /// Cached batches
repeated CachedBatch batches = 1; repeated CachedBatch batches = 1;
} }
message DecodeResponse { message DecodeResponse {
/// Decodes /// Decodes
repeated Generation generations = 1; repeated Generation generations = 1;
/// Next batch (cached) /// Next batch (cached)
optional CachedBatch batch = 2; optional CachedBatch batch = 2;
/// Forward elapsed time in nanoseconds /// Forward elapsed time in nanoseconds
uint64 forward_ns = 3; uint64 forward_ns = 3;
/// Decode elapsed time in nanoseconds /// Decode elapsed time in nanoseconds
uint64 decode_ns = 4; uint64 decode_ns = 4;
/// Total elapsed time in nanoseconds /// Total elapsed time in nanoseconds
uint64 total_ns = 5; uint64 total_ns = 5;
/// Concatenate elapsed time in nanoseconds /// Concatenate elapsed time in nanoseconds
optional uint64 concat_ns = 6; optional uint64 concat_ns = 6;
} }
message WarmupRequest { message WarmupRequest {
/// Batch to warmup on /// Batch to warmup on
Batch batch = 1; Batch batch = 1;
uint32 max_input_length = 2; uint32 max_input_length = 2;
uint32 max_prefill_tokens = 3; uint32 max_prefill_tokens = 3;
uint32 max_total_tokens = 4; uint32 max_total_tokens = 4;
} }
message WarmupResponse { message WarmupResponse {
/// Maximum number of tokens supported by the model /// Maximum number of tokens supported by the model
optional uint32 max_supported_total_tokens = 1; optional uint32 max_supported_total_tokens = 1;
} }
...@@ -11,6 +11,7 @@ use rand::{thread_rng, Rng}; ...@@ -11,6 +11,7 @@ use rand::{thread_rng, Rng};
use serde_json::Value; use serde_json::Value;
use std::io::Cursor; use std::io::Cursor;
use std::iter; use std::iter;
use std::sync::Arc;
use thiserror::Error; use thiserror::Error;
use tokenizers::tokenizer::Tokenizer; use tokenizers::tokenizer::Tokenizer;
use tokio::sync::mpsc; use tokio::sync::mpsc;
...@@ -115,13 +116,14 @@ impl Validation { ...@@ -115,13 +116,14 @@ impl Validation {
} }
} }
#[allow(clippy::type_complexity)]
#[instrument(skip(self, inputs))] #[instrument(skip(self, inputs))]
async fn validate_input( async fn validate_input(
&self, &self,
inputs: String, inputs: String,
truncate: Option<usize>, truncate: Option<usize>,
max_new_tokens: Option<u32>, max_new_tokens: Option<u32>,
) -> Result<(Vec<Chunk>, usize, u32), ValidationError> { ) -> Result<(Vec<Chunk>, Option<Vec<u32>>, usize, u32), ValidationError> {
// If we have a fast tokenizer // If we have a fast tokenizer
if let Some((encoding, inputs)) = self.tokenize(inputs.clone(), truncate).await? { if let Some((encoding, inputs)) = self.tokenize(inputs.clone(), truncate).await? {
// Create response channel // Create response channel
...@@ -156,8 +158,10 @@ impl Validation { ...@@ -156,8 +158,10 @@ impl Validation {
)); ));
} }
let input_ids = encoding.get_ids()[..input_length].to_owned();
metrics::histogram!("tgi_request_input_length").record(input_length as f64); metrics::histogram!("tgi_request_input_length").record(input_length as f64);
Ok((inputs, input_length, max_new_tokens)) Ok((inputs, Some(input_ids), input_length, max_new_tokens))
} }
// Return inputs without validation // Return inputs without validation
else { else {
...@@ -180,7 +184,12 @@ impl Validation { ...@@ -180,7 +184,12 @@ impl Validation {
input_length = input_length.saturating_sub(max_new_tokens as usize); input_length = input_length.saturating_sub(max_new_tokens as usize);
} }
Ok((vec![Chunk::Text(inputs)], input_length, max_new_tokens)) Ok((
vec![Chunk::Text(inputs)],
None,
input_length,
max_new_tokens,
))
} }
} }
...@@ -314,7 +323,7 @@ impl Validation { ...@@ -314,7 +323,7 @@ impl Validation {
.unwrap_or(Ok(None))?; .unwrap_or(Ok(None))?;
// Validate inputs // Validate inputs
let (inputs, input_length, max_new_tokens) = self let (inputs, input_ids, input_length, max_new_tokens) = self
.validate_input(request.inputs, truncate, max_new_tokens) .validate_input(request.inputs, truncate, max_new_tokens)
.await?; .await?;
...@@ -391,6 +400,7 @@ impl Validation { ...@@ -391,6 +400,7 @@ impl Validation {
Ok(ValidGenerateRequest { Ok(ValidGenerateRequest {
inputs, inputs,
input_ids: input_ids.map(Arc::new),
decoder_input_details, decoder_input_details,
input_length: input_length as u32, input_length: input_length as u32,
truncate: truncate.unwrap_or(self.max_input_length) as u32, truncate: truncate.unwrap_or(self.max_input_length) as u32,
...@@ -707,6 +717,7 @@ pub struct ValidStoppingParameters { ...@@ -707,6 +717,7 @@ pub struct ValidStoppingParameters {
#[derive(Debug, Clone)] #[derive(Debug, Clone)]
pub struct ValidGenerateRequest { pub struct ValidGenerateRequest {
pub inputs: Vec<Chunk>, pub inputs: Vec<Chunk>,
pub input_ids: Option<Arc<Vec<u32>>>,
pub input_length: u32, pub input_length: u32,
pub truncate: u32, pub truncate: u32,
pub decoder_input_details: bool, pub decoder_input_details: bool,
......
...@@ -5,16 +5,29 @@ from typing import Dict, Optional ...@@ -5,16 +5,29 @@ from typing import Dict, Optional
from text_generation_server.utils.log import log_master from text_generation_server.utils.log import log_master
ATTENTION = os.getenv("ATTENTION", "paged") PREFIX_CACHING = os.getenv("USE_PREFIX_CACHING", False)
log_master(logger.info, f"Using Attention = {PREFIX_CACHING}")
ATTENTION = os.getenv("ATTENTION", "flashinfer" if PREFIX_CACHING else "paged")
_expected = {"paged", "flashdecoding", "flashinfer"} _expected = {"paged", "flashdecoding", "flashinfer"}
assert ( assert (
ATTENTION in _expected ATTENTION in _expected
), f"Attention is not valid {ATTENTION}, expected {_expected}" ), f"Attention is not valid {ATTENTION}, expected {_expected}"
log_master(logger.info, f"Using Attention = {ATTENTION}") log_master(logger.info, f"Using Attention = {ATTENTION}")
if PREFIX_CACHING and ATTENTION != "flashinfer":
raise RuntimeError("Prefix caching is only supported with flashinfer")
MEM_POOL = torch.cuda.graph_pool_handle() if torch.cuda.is_available() else None MEM_POOL = torch.cuda.graph_pool_handle() if torch.cuda.is_available() else None
# This is overridden by the cli # This is overridden by the cli
BLOCK_SIZE: int = 256 if ATTENTION == "flashdecoding" else 16 BLOCK_SIZE: int
if ATTENTION == "flashdecoding":
BLOCK_SIZE = 256
elif ATTENTION == "flashinfer":
BLOCK_SIZE = 1
else:
BLOCK_SIZE = 16
cuda_graphs = os.getenv("CUDA_GRAPHS") cuda_graphs = os.getenv("CUDA_GRAPHS")
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment