Unverified Commit 8deeaca4 authored by Daniël de Kok's avatar Daniël de Kok Committed by GitHub
Browse files

Add support for prefix caching to the v3 router (#2392)

This change adds support for prefix caching to the v3 router. This
is broken up from the backend support to ease reviewing.

For now prefix caching is only enabled with `USE_PREFIX_CACHING=1`
in this case, the router will switch to `RadixAllocator`. This
allocator uses a radix trie to keep track of prefills that were
seen prior. If a new prefill is a prefix of a previously-seen
prefil, the router will send a request with `prefix_len>0`, which
can be used by the backend to decide to reuse KV blocks from the
cache, rather than recomputing them.

Even though backend support is not added in this PR, the backend
will still work with prefix caching enabled. The prefix lengths
are just ignored and not used.
parent b6bb1d51
......@@ -4045,6 +4045,7 @@ dependencies = [
"reqwest",
"serde",
"serde_json",
"slotmap",
"text-generation-router",
"thiserror",
"tokenizers",
......
......@@ -156,6 +156,7 @@ impl Client {
// Blocks and slots will be set on the server side if we use paged attention
blocks: vec![],
slots: vec![],
prefix_len: 0,
// Set sampling parameters to also take these ops into account in the max memory
parameters: Some(NextTokenChooserParameters {
temperature: 0.9,
......
......@@ -244,6 +244,7 @@ impl Health for ShardedClient {
// Block 0 is reserved for health checks
blocks: vec![0],
slots: (0..16).collect(),
prefix_len: 0,
adapter_id: None,
};
let batch = Batch {
......
......@@ -33,6 +33,7 @@ rand = "0.8.5"
reqwest = { version = "0.11.20", features = [] }
serde = "1.0.188"
serde_json = "1.0.107"
slotmap = "1.0.7"
thiserror = "1.0.48"
tokenizers = { workspace = true}
tokio = { version = "1.32.0", features = ["rt", "rt-multi-thread", "parking_lot", "signal", "sync"] }
......
......@@ -35,15 +35,24 @@ impl BackendV3 {
window_size: Option<u32>,
speculate: u32,
) -> Self {
let prefix_caching = if let Ok(prefix_caching) = std::env::var("USE_PREFIX_CACHING") {
matches!(prefix_caching.as_str(), "true" | "1")
} else {
false
};
let attention = if let Ok(attention) = std::env::var("ATTENTION") {
attention
.parse()
.unwrap_or_else(|_| panic!("Invalid attention was specified :`{attention}`"))
} else if prefix_caching {
Attention::FlashInfer
} else {
Attention::Paged
};
let block_size = if attention == Attention::FlashDecoding {
256
} else if attention == Attention::FlashInfer {
1
} else {
16
};
......@@ -51,6 +60,7 @@ impl BackendV3 {
let queue = Queue::new(
requires_padding,
block_size,
prefix_caching,
window_size,
speculate,
max_batch_total_tokens,
......
use std::cmp::min;
use std::{cmp::min, sync::Arc};
use tokio::sync::{mpsc, oneshot};
use crate::radix::RadixAllocator;
#[derive(Debug, Clone)]
pub(crate) struct BlockAllocation {
pub allocation_id: u64,
pub blocks: Vec<u32>,
pub slots: Vec<u32>,
block_allocator: BlockAllocator,
/// Prefix that was cached and for which the KV does not have to
/// be recomputed.
pub prefix_len: u32,
pub(crate) block_allocator: Option<BlockAllocator>,
}
impl Drop for BlockAllocation {
fn drop(&mut self) {
self.block_allocator.free(self.blocks.clone())
if let Some(block_allocator) = self.block_allocator.as_mut() {
block_allocator.free(self.blocks.clone(), self.allocation_id)
}
}
}
......@@ -24,6 +34,7 @@ impl BlockAllocator {
pub(crate) fn new(
max_batch_total_tokens: u32,
block_size: u32,
prefix_caching: bool,
window_size: Option<u32>,
) -> Self {
// Create channel
......@@ -33,6 +44,7 @@ impl BlockAllocator {
tokio::spawn(block_allocator_task(
max_batch_total_tokens / block_size,
block_size,
prefix_caching,
window_size,
receiver,
));
......@@ -42,28 +54,32 @@ impl BlockAllocator {
}
}
pub(crate) async fn allocate(&self, tokens: u32) -> Option<BlockAllocation> {
pub(crate) async fn allocate(
&self,
tokens: u32,
prefill_tokens: Option<Arc<Vec<u32>>>,
) -> Option<BlockAllocation> {
let (response_sender, response_receiver) = oneshot::channel();
self.block_allocator
.send(BlockAllocatorCommand::Allocate {
tokens,
prefill_tokens,
response_sender,
})
.unwrap();
response_receiver
.await
.unwrap()
.map(|(blocks, slots)| BlockAllocation {
blocks,
slots,
block_allocator: self.clone(),
response_receiver.await.unwrap().map(|mut allocation| {
allocation.block_allocator = Some(self.clone());
allocation
})
}
pub(crate) fn free(&self, blocks: Vec<u32>) {
pub(crate) fn free(&self, blocks: Vec<u32>, allocation_id: u64) {
self.block_allocator
.send(BlockAllocatorCommand::Free { blocks })
.send(BlockAllocatorCommand::Free {
allocation_id,
blocks,
})
.unwrap();
}
}
......@@ -71,21 +87,83 @@ impl BlockAllocator {
async fn block_allocator_task(
blocks: u32,
block_size: u32,
prefix_caching: bool,
window_size: Option<u32>,
mut receiver: mpsc::UnboundedReceiver<BlockAllocatorCommand>,
) {
// Block 0 is reserved for health checks
let mut free_blocks: Vec<u32> = (1..blocks).collect();
let mut allocator: Box<dyn Allocator + Send> = if prefix_caching {
Box::new(RadixAllocator::new(block_size, blocks, window_size))
} else {
Box::new(SimpleAllocator::new(blocks, block_size, window_size))
};
while let Some(cmd) = receiver.recv().await {
match cmd {
BlockAllocatorCommand::Free { blocks } => free_blocks.extend(blocks),
BlockAllocatorCommand::Free {
blocks,
allocation_id,
} => allocator.free(blocks, allocation_id),
BlockAllocatorCommand::Allocate {
tokens,
prefill_tokens,
response_sender,
} => {
response_sender
.send(allocator.allocate(tokens, prefill_tokens))
.unwrap();
}
}
}
}
#[derive(Debug)]
enum BlockAllocatorCommand {
Free {
blocks: Vec<u32>,
allocation_id: u64,
},
Allocate {
tokens: u32,
prefill_tokens: Option<Arc<Vec<u32>>>,
response_sender: oneshot::Sender<Option<BlockAllocation>>,
},
}
pub(crate) trait Allocator {
fn allocate(
&mut self,
tokens: u32,
prefill_tokens: Option<Arc<Vec<u32>>>,
) -> Option<BlockAllocation>;
fn free(&mut self, blocks: Vec<u32>, allocation_id: u64);
}
pub struct SimpleAllocator {
free_blocks: Vec<u32>,
block_size: u32,
window_size: Option<u32>,
}
impl SimpleAllocator {
fn new(blocks: u32, block_size: u32, window_size: Option<u32>) -> Self {
SimpleAllocator {
block_size,
// Block 0 is reserved for health checks
free_blocks: (1..blocks).collect(),
window_size,
}
}
}
impl Allocator for SimpleAllocator {
fn allocate(
&mut self,
tokens: u32,
_prefill_tokens: Option<Arc<Vec<u32>>>,
) -> Option<BlockAllocation> {
// Apply window size
let (required_blocks, repeats) = {
let (tokens, repeats) = match window_size {
let (tokens, repeats) = match self.window_size {
None => (tokens, 1),
Some(window_size) => {
let repeats = (tokens + window_size - 1) / window_size;
......@@ -94,43 +172,39 @@ async fn block_allocator_task(
}
};
// Pad to a multiple of block size
let required_blocks = (tokens + block_size - 1) / block_size;
let required_blocks = (tokens + self.block_size - 1) / self.block_size;
(required_blocks, repeats)
};
let tokens = tokens as usize;
let allocation = if required_blocks > free_blocks.len() as u32 {
if required_blocks > self.free_blocks.len() as u32 {
None
} else {
let blocks =
free_blocks.split_off(free_blocks.len() - required_blocks as usize);
let mut slots = Vec::with_capacity(
(required_blocks * block_size * repeats as u32) as usize,
);
let blocks = self
.free_blocks
.split_off(self.free_blocks.len() - required_blocks as usize);
let mut slots =
Vec::with_capacity((required_blocks * self.block_size * repeats as u32) as usize);
'slots: for block_id in blocks.repeat(repeats).iter() {
for s in (block_id * block_size)..((block_id + 1) * block_size) {
for s in (block_id * self.block_size)..((block_id + 1) * self.block_size) {
slots.push(s);
if slots.len() == tokens {
break 'slots;
}
}
}
Some((blocks, slots))
};
response_sender.send(allocation).unwrap();
}
Some(BlockAllocation {
allocation_id: 0,
blocks,
slots,
prefix_len: 0,
block_allocator: None,
})
}
}
}
#[derive(Debug)]
enum BlockAllocatorCommand {
Free {
blocks: Vec<u32>,
},
Allocate {
tokens: u32,
response_sender: oneshot::Sender<Option<(Vec<u32>, Vec<u32>)>>,
},
fn free(&mut self, blocks: Vec<u32>, _allocation_id: u64) {
self.free_blocks.extend(blocks)
}
}
......@@ -157,6 +157,7 @@ impl Client {
// Blocks and slots will be set on the server side if we use paged attention
blocks: vec![],
slots: vec![],
prefix_len: 0,
// Set sampling parameters to also take these ops into account in the max memory
parameters: Some(NextTokenChooserParameters {
temperature: 0.9,
......
......@@ -245,6 +245,7 @@ impl Health for ShardedClient {
// Block 0 is reserved for health checks
blocks: vec![0],
slots: (0..16).collect(),
prefix_len: 0,
adapter_id: None,
};
let batch = Batch {
......
......@@ -2,6 +2,7 @@ mod backend;
mod block_allocator;
mod client;
mod queue;
mod radix;
use crate::client::{ClientError, ShardedClient};
pub(crate) use backend::BackendV3;
......
......@@ -46,6 +46,7 @@ impl Queue {
pub(crate) fn new(
requires_padding: bool,
block_size: u32,
prefix_caching: bool,
window_size: Option<u32>,
speculate: u32,
max_batch_total_tokens: u32,
......@@ -57,6 +58,7 @@ impl Queue {
tokio::spawn(queue_task(
requires_padding,
block_size,
prefix_caching,
window_size,
speculate,
max_batch_total_tokens,
......@@ -109,6 +111,7 @@ impl Queue {
async fn queue_task(
requires_padding: bool,
block_size: u32,
prefix_caching: bool,
window_size: Option<u32>,
speculate: u32,
max_batch_total_tokens: u32,
......@@ -117,6 +120,7 @@ async fn queue_task(
let mut state = State::new(
requires_padding,
block_size,
prefix_caching,
window_size,
speculate,
max_batch_total_tokens,
......@@ -176,12 +180,19 @@ impl State {
fn new(
requires_padding: bool,
block_size: u32,
prefix_caching: bool,
window_size: Option<u32>,
speculate: u32,
max_batch_total_tokens: u32,
) -> Self {
let block_allocator = (!requires_padding)
.then(|| BlockAllocator::new(max_batch_total_tokens, block_size, window_size));
let block_allocator = (!requires_padding).then(|| {
BlockAllocator::new(
max_batch_total_tokens,
block_size,
prefix_caching,
window_size,
)
});
Self {
entries: VecDeque::with_capacity(128),
......@@ -305,7 +316,10 @@ impl State {
+ self.speculate
- 1;
match block_allocator.allocate(tokens).await {
match block_allocator
.allocate(tokens, entry.request.input_ids.clone())
.await
{
None => {
// Entry is over budget
// Add it back to the front
......@@ -331,11 +345,12 @@ impl State {
// Update entry
entry.temp_span = Some(entry_batch_span);
let (blocks, slots) = match &block_allocation {
None => (Vec::new(), Vec::new()),
let (blocks, slots, prefix_len) = match &block_allocation {
None => (Vec::new(), Vec::new(), 0),
Some(block_allocation) => (
block_allocation.blocks.clone(),
block_allocation.slots.clone(),
block_allocation.prefix_len,
),
};
......@@ -372,6 +387,7 @@ impl State {
top_n_tokens: entry.request.top_n_tokens,
blocks,
slots,
prefix_len,
adapter_id: entry.request.adapter_id.clone(),
});
// Set batch_time
......@@ -480,6 +496,8 @@ impl From<ValidStoppingParameters> for StoppingCriteriaParameters {
#[cfg(test)]
mod tests {
use std::sync::Arc;
use super::*;
use tracing::info_span;
......@@ -492,6 +510,7 @@ mod tests {
let entry = Entry {
request: ValidGenerateRequest {
inputs: vec![],
input_ids: Some(Arc::new(vec![])),
input_length: 0,
truncate: 0,
decoder_input_details: false,
......@@ -527,7 +546,7 @@ mod tests {
#[tokio::test]
async fn test_append() {
let mut state = State::new(false, 1, None, 0, 16);
let mut state = State::new(false, 1, false, None, 0, 16);
let (entry, _guard) = default_entry();
assert_eq!(state.next_id, 0);
......@@ -543,7 +562,7 @@ mod tests {
#[tokio::test]
async fn test_next_batch_empty() {
let mut state = State::new(false, 1, None, 0, 16);
let mut state = State::new(false, 1, false, None, 0, 16);
assert!(state.next_batch(None, None, 1, 1).await.is_none());
assert!(state.next_batch(Some(1), None, 1, 1).await.is_none());
......@@ -551,7 +570,7 @@ mod tests {
#[tokio::test]
async fn test_next_batch_min_size() {
let mut state = State::new(false, 1, None, 0, 16);
let mut state = State::new(false, 1, false, None, 0, 16);
let (entry1, _guard1) = default_entry();
let (entry2, _guard2) = default_entry();
state.append(entry1);
......@@ -583,7 +602,7 @@ mod tests {
#[tokio::test]
async fn test_next_batch_max_size() {
let mut state = State::new(false, 1, None, 0, 16);
let mut state = State::new(false, 1, false, None, 0, 16);
let (entry1, _guard1) = default_entry();
let (entry2, _guard2) = default_entry();
state.append(entry1);
......@@ -603,7 +622,7 @@ mod tests {
#[tokio::test]
async fn test_next_batch_token_budget() {
let mut state = State::new(false, 1, None, 0, 2);
let mut state = State::new(false, 1, false, None, 0, 2);
let (entry1, _guard1) = default_entry();
let (entry2, _guard2) = default_entry();
state.append(entry1);
......@@ -636,14 +655,14 @@ mod tests {
#[tokio::test]
async fn test_queue_append() {
let queue = Queue::new(false, 1, None, 0, 16);
let queue = Queue::new(false, 1, false, None, 0, 16);
let (entry, _guard) = default_entry();
queue.append(entry);
}
#[tokio::test]
async fn test_queue_next_batch_empty() {
let queue = Queue::new(false, 1, None, 0, 16);
let queue = Queue::new(false, 1, false, None, 0, 16);
assert!(queue.next_batch(None, None, 1, 1).await.is_none());
assert!(queue.next_batch(Some(1), None, 1, 1).await.is_none());
......@@ -651,7 +670,7 @@ mod tests {
#[tokio::test]
async fn test_queue_next_batch_min_size() {
let queue = Queue::new(false, 1, None, 0, 16);
let queue = Queue::new(false, 1, false, None, 0, 16);
let (entry1, _guard1) = default_entry();
let (entry2, _guard2) = default_entry();
queue.append(entry1);
......@@ -684,7 +703,7 @@ mod tests {
#[tokio::test]
async fn test_queue_next_batch_max_size() {
let queue = Queue::new(false, 1, None, 0, 16);
let queue = Queue::new(false, 1, false, None, 0, 16);
let (entry1, _guard1) = default_entry();
let (entry2, _guard2) = default_entry();
queue.append(entry1);
......@@ -700,7 +719,7 @@ mod tests {
#[tokio::test]
async fn test_queue_next_batch_token_budget() {
let queue = Queue::new(false, 1, None, 0, 16);
let queue = Queue::new(false, 1, false, None, 0, 16);
let (entry1, _guard1) = default_entry();
let (entry2, _guard2) = default_entry();
queue.append(entry1);
......@@ -725,7 +744,7 @@ mod tests {
#[tokio::test]
async fn test_queue_next_batch_token_speculate() {
let queue = Queue::new(false, 1, None, 2, 16);
let queue = Queue::new(false, 1, false, None, 2, 16);
let (entry1, _guard1) = default_entry();
let (entry2, _guard2) = default_entry();
queue.append(entry1);
......@@ -744,7 +763,7 @@ mod tests {
#[tokio::test]
async fn test_queue_next_batch_dropped_receiver() {
let queue = Queue::new(false, 1, None, 0, 16);
let queue = Queue::new(false, 1, false, None, 0, 16);
let (entry, _) = default_entry();
queue.append(entry);
......
This diff is collapsed.
......@@ -157,6 +157,7 @@ async fn prefill(
top_n_tokens: top_n_tokens.unwrap_or(0),
blocks: vec![],
slots: vec![],
prefix_len: 0,
adapter_id: None,
})
.collect();
......
......@@ -4,21 +4,22 @@ package generate.v3;
service TextGenerationService {
/// Model Info
rpc Info (InfoRequest) returns (InfoResponse) {}
rpc Info(InfoRequest) returns (InfoResponse) {}
/// Service discovery
rpc ServiceDiscovery (ServiceDiscoveryRequest) returns (ServiceDiscoveryResponse) {}
rpc ServiceDiscovery(ServiceDiscoveryRequest)
returns (ServiceDiscoveryResponse) {}
/// Empties batch cache
rpc ClearCache (ClearCacheRequest) returns (ClearCacheResponse);
rpc ClearCache(ClearCacheRequest) returns (ClearCacheResponse);
/// Remove requests from a cached batch
rpc FilterBatch (FilterBatchRequest) returns (FilterBatchResponse);
rpc FilterBatch(FilterBatchRequest) returns (FilterBatchResponse);
/// Warmup the model and compute max cache size
rpc Warmup (WarmupRequest) returns (WarmupResponse);
rpc Warmup(WarmupRequest) returns (WarmupResponse);
/// Prefill batch and decode first token
rpc Prefill (PrefillRequest) returns (PrefillResponse);
rpc Prefill(PrefillRequest) returns (PrefillResponse);
/// Decode token for a list of prefilled batches
rpc Decode (DecodeRequest) returns (DecodeResponse);
rpc Decode(DecodeRequest) returns (DecodeResponse);
/// Health check
rpc Health (HealthRequest) returns (HealthResponse);
rpc Health(HealthRequest) returns (HealthResponse);
}
message HealthRequest {}
......@@ -68,9 +69,7 @@ message InputChunk {
}
}
message Input {
repeated InputChunk chunks = 1;
}
message Input { repeated InputChunk chunks = 1; }
enum GrammarType {
GRAMMAR_TYPE_NONE = 0;
......@@ -136,6 +135,8 @@ message Request {
repeated uint32 slots = 10;
/// LORA adapter index
optional string adapter_id = 11;
/// Prefix length that can be retrieved from the KV cache.
uint32 prefix_len = 12;
}
message Batch {
......@@ -214,7 +215,6 @@ message FilterBatchResponse {
CachedBatch batch = 1;
}
message PrefillRequest {
/// Batch
Batch batch = 1;
......
......@@ -11,6 +11,7 @@ use rand::{thread_rng, Rng};
use serde_json::Value;
use std::io::Cursor;
use std::iter;
use std::sync::Arc;
use thiserror::Error;
use tokenizers::tokenizer::Tokenizer;
use tokio::sync::mpsc;
......@@ -115,13 +116,14 @@ impl Validation {
}
}
#[allow(clippy::type_complexity)]
#[instrument(skip(self, inputs))]
async fn validate_input(
&self,
inputs: String,
truncate: Option<usize>,
max_new_tokens: Option<u32>,
) -> Result<(Vec<Chunk>, usize, u32), ValidationError> {
) -> Result<(Vec<Chunk>, Option<Vec<u32>>, usize, u32), ValidationError> {
// If we have a fast tokenizer
if let Some((encoding, inputs)) = self.tokenize(inputs.clone(), truncate).await? {
// Create response channel
......@@ -156,8 +158,10 @@ impl Validation {
));
}
let input_ids = encoding.get_ids()[..input_length].to_owned();
metrics::histogram!("tgi_request_input_length").record(input_length as f64);
Ok((inputs, input_length, max_new_tokens))
Ok((inputs, Some(input_ids), input_length, max_new_tokens))
}
// Return inputs without validation
else {
......@@ -180,7 +184,12 @@ impl Validation {
input_length = input_length.saturating_sub(max_new_tokens as usize);
}
Ok((vec![Chunk::Text(inputs)], input_length, max_new_tokens))
Ok((
vec![Chunk::Text(inputs)],
None,
input_length,
max_new_tokens,
))
}
}
......@@ -314,7 +323,7 @@ impl Validation {
.unwrap_or(Ok(None))?;
// Validate inputs
let (inputs, input_length, max_new_tokens) = self
let (inputs, input_ids, input_length, max_new_tokens) = self
.validate_input(request.inputs, truncate, max_new_tokens)
.await?;
......@@ -391,6 +400,7 @@ impl Validation {
Ok(ValidGenerateRequest {
inputs,
input_ids: input_ids.map(Arc::new),
decoder_input_details,
input_length: input_length as u32,
truncate: truncate.unwrap_or(self.max_input_length) as u32,
......@@ -707,6 +717,7 @@ pub struct ValidStoppingParameters {
#[derive(Debug, Clone)]
pub struct ValidGenerateRequest {
pub inputs: Vec<Chunk>,
pub input_ids: Option<Arc<Vec<u32>>>,
pub input_length: u32,
pub truncate: u32,
pub decoder_input_details: bool,
......
......@@ -5,16 +5,29 @@ from typing import Dict, Optional
from text_generation_server.utils.log import log_master
ATTENTION = os.getenv("ATTENTION", "paged")
PREFIX_CACHING = os.getenv("USE_PREFIX_CACHING", False)
log_master(logger.info, f"Using Attention = {PREFIX_CACHING}")
ATTENTION = os.getenv("ATTENTION", "flashinfer" if PREFIX_CACHING else "paged")
_expected = {"paged", "flashdecoding", "flashinfer"}
assert (
ATTENTION in _expected
), f"Attention is not valid {ATTENTION}, expected {_expected}"
log_master(logger.info, f"Using Attention = {ATTENTION}")
if PREFIX_CACHING and ATTENTION != "flashinfer":
raise RuntimeError("Prefix caching is only supported with flashinfer")
MEM_POOL = torch.cuda.graph_pool_handle() if torch.cuda.is_available() else None
# This is overridden by the cli
BLOCK_SIZE: int = 256 if ATTENTION == "flashdecoding" else 16
BLOCK_SIZE: int
if ATTENTION == "flashdecoding":
BLOCK_SIZE = 256
elif ATTENTION == "flashinfer":
BLOCK_SIZE = 1
else:
BLOCK_SIZE = 16
cuda_graphs = os.getenv("CUDA_GRAPHS")
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment