Unverified commit 8aece3bd, authored by OlivierDehaene and committed by GitHub

feat: move allocation logic to rust (#1835)

Close #2007
parent 9ffe1f1e
@@ -26,7 +26,12 @@ incremental = true
 inherits = "release"
 debug = 1
 incremental = true
+panic = "abort"
+
+[profile.release-opt]
+inherits = "release"
+debug = 0
+incremental = false
 lto = "fat"
 opt-level = 3
 codegen-units = 1
-panic = "abort"
@@ -25,7 +25,7 @@ RUN PROTOC_ZIP=protoc-21.12-linux-x86_64.zip && \
     rm -f $PROTOC_ZIP
 COPY --from=planner /usr/src/recipe.json recipe.json
-RUN cargo chef cook --release --recipe-path recipe.json
+RUN cargo chef cook --profile release-opt --recipe-path recipe.json
 COPY Cargo.toml Cargo.toml
 COPY rust-toolchain.toml rust-toolchain.toml
@@ -33,7 +33,7 @@ COPY proto proto
 COPY benchmark benchmark
 COPY router router
 COPY launcher launcher
-RUN cargo build --release
+RUN cargo build --profile release-opt
 # Python builder
 # Adapted from: https://github.com/pytorch/pytorch/blob/master/Dockerfile
@@ -226,11 +226,11 @@ RUN cd server && \
     pip install ".[bnb, accelerate, quantize, peft, outlines]" --no-cache-dir
 # Install benchmarker
-COPY --from=builder /usr/src/target/release/text-generation-benchmark /usr/local/bin/text-generation-benchmark
+COPY --from=builder /usr/src/target/release-opt/text-generation-benchmark /usr/local/bin/text-generation-benchmark
 # Install router
-COPY --from=builder /usr/src/target/release/text-generation-router /usr/local/bin/text-generation-router
+COPY --from=builder /usr/src/target/release-opt/text-generation-router /usr/local/bin/text-generation-router
 # Install launcher
-COPY --from=builder /usr/src/target/release/text-generation-launcher /usr/local/bin/text-generation-launcher
+COPY --from=builder /usr/src/target/release-opt/text-generation-launcher /usr/local/bin/text-generation-launcher
 RUN apt-get update && DEBIAN_FRONTEND=noninteractive apt-get install -y --no-install-recommends \
     build-essential \
......
@@ -25,7 +25,7 @@ RUN PROTOC_ZIP=protoc-21.12-linux-x86_64.zip && \
     rm -f $PROTOC_ZIP
 COPY --from=planner /usr/src/recipe.json recipe.json
-RUN cargo chef cook --release --recipe-path recipe.json
+RUN cargo chef cook --profile release-opt --recipe-path recipe.json
 COPY Cargo.toml Cargo.toml
 COPY rust-toolchain.toml rust-toolchain.toml
@@ -33,7 +33,7 @@ COPY proto proto
 COPY benchmark benchmark
 COPY router router
 COPY launcher launcher
-RUN cargo build --release
+RUN cargo build --profile release-opt
 # Text Generation Inference base image for RoCm
 FROM rocm/dev-ubuntu-22.04:6.1.1_hip_update as base
@@ -193,11 +193,11 @@ RUN cd server && \
     pip install ".[accelerate, peft, outlines]" --no-cache-dir
 # Install benchmarker
-COPY --from=builder /usr/src/target/release/text-generation-benchmark /usr/local/bin/text-generation-benchmark
+COPY --from=builder /usr/src/target/release-opt/text-generation-benchmark /usr/local/bin/text-generation-benchmark
 # Install router
-COPY --from=builder /usr/src/target/release/text-generation-router /usr/local/bin/text-generation-router
+COPY --from=builder /usr/src/target/release-opt/text-generation-router /usr/local/bin/text-generation-router
 # Install launcher
-COPY --from=builder /usr/src/target/release/text-generation-launcher /usr/local/bin/text-generation-launcher
+COPY --from=builder /usr/src/target/release-opt/text-generation-launcher /usr/local/bin/text-generation-launcher
 # AWS Sagemaker compatible image
 FROM base as sagemaker
......
@@ -24,7 +24,7 @@ RUN PROTOC_ZIP=protoc-21.12-linux-x86_64.zip && \
     rm -f $PROTOC_ZIP
 COPY --from=planner /usr/src/recipe.json recipe.json
-RUN cargo chef cook --release --recipe-path recipe.json
+RUN cargo chef cook --profile release-opt --recipe-path recipe.json
 COPY Cargo.toml Cargo.toml
 COPY rust-toolchain.toml rust-toolchain.toml
@@ -32,7 +32,7 @@ COPY proto proto
 COPY benchmark benchmark
 COPY router router
 COPY launcher launcher
-RUN cargo build --release
+RUN cargo build --profile release-opt
 # Text Generation Inference base image for Intel
@@ -78,11 +78,11 @@ ENV PATH=/opt/intel/oneapi/mpi/latest/opt/mpi/libfabric/bin:/opt/intel/oneapi/mp
 ENV CCL_ZE_IPC_EXCHANGE=sockets
 # Install benchmarker
-COPY --from=builder /usr/src/target/release/text-generation-benchmark /usr/local/bin/text-generation-benchmark
+COPY --from=builder /usr/src/target/release-opt/text-generation-benchmark /usr/local/bin/text-generation-benchmark
 # Install router
-COPY --from=builder /usr/src/target/release/text-generation-router /usr/local/bin/text-generation-router
+COPY --from=builder /usr/src/target/release-opt/text-generation-router /usr/local/bin/text-generation-router
 # Install launcher
-COPY --from=builder /usr/src/target/release/text-generation-launcher /usr/local/bin/text-generation-launcher
+COPY --from=builder /usr/src/target/release-opt/text-generation-launcher /usr/local/bin/text-generation-launcher
 # Final image
 FROM base
......
@@ -155,6 +155,8 @@ async fn prefill(
                 ignore_eos_token: true, // Will not stop even if a eos token is generated
             }),
             top_n_tokens: top_n_tokens.unwrap_or(0),
+            blocks: vec![],
+            slots: vec![],
         })
         .collect();
@@ -163,6 +165,7 @@ async fn prefill(
         requests,
         size: batch_size,
         max_tokens: batch_size * (sequence_length + decode_length),
+        max_blocks: 0,
     };
     // Run prefill
......
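The benchmark leaves `blocks` and `slots` empty: as the client comment later in this diff notes, the server side fills them in when the router-side allocator is not driving paged attention. A minimal sketch of such a request, assuming the prost-generated `Request` type (prost messages derive `Default`):

```rust
use text_generation_client::v3::Request;

fn main() {
    // Hypothetical warmup/benchmark-style request: allocation fields are left
    // empty so the Python shard performs its own block/slot assignment.
    let request = Request {
        id: 0,
        blocks: vec![],
        slots: vec![],
        ..Default::default()
    };
    assert!(request.blocks.is_empty() && request.slots.is_empty());
}
```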
@@ -130,6 +130,10 @@ message Request {
    bool prefill_logprobs = 6;
    /// Return most likely n tokens
    uint32 top_n_tokens = 7;
+   /// Paged attention blocks
+   repeated uint32 blocks = 9;
+   /// Paged attention slots
+   repeated uint32 slots = 10;
 }
 message Batch {
@@ -141,6 +145,8 @@ message Batch {
    uint32 size = 3;
    /// Maximum number of tokens this batch will grow to
    uint32 max_tokens = 4;
+   /// Maximum number of Paged Attention blocks
+   uint32 max_blocks = 5;
 }
 message CachedBatch {
......
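In these new proto fields, `blocks` index fixed-size pages of the KV cache while `slots` index individual token positions inside them; with the block size of 16 used throughout this change, block `b` covers slots `b * 16` through `b * 16 + 15`. A small sketch of that mapping, mirroring the slot loop in the Rust allocator further down:

```rust
// Expand a list of paged-attention blocks into per-token slots.
fn slots_for(blocks: &[u32], block_size: u32, tokens: usize) -> Vec<u32> {
    blocks
        .iter()
        .flat_map(|b| (b * block_size)..((b + 1) * block_size))
        .take(tokens)
        .collect()
}

fn main() {
    // Block 0 covers slots 0..16 -- exactly the health-check request below.
    assert_eq!(slots_for(&[0], 16, 16), (0..16).collect::<Vec<u32>>());
}
```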
@@ -153,6 +153,9 @@ impl Client {
                 }),
                 // We truncate the input on the server side to be sure that it has the correct size
                 truncate,
+                // Blocks and slots will be set on the server side if we use paged attention
+                blocks: vec![],
+                slots: vec![],
                 // Set sampling parameters to also take these ops into account in the max memory
                 parameters: Some(NextTokenChooserParameters {
                     temperature: 0.9,
@@ -187,7 +190,8 @@ impl Client {
             id: 0,
             size: requests.len() as u32,
             requests,
-            max_tokens: 0,
+            max_tokens: max_input_length,
+            max_blocks: 0,
         };
         let request = tonic::Request::new(WarmupRequest {
......
@@ -241,12 +241,16 @@ impl Health for ShardedClient {
                 ignore_eos_token: false,
             }),
             top_n_tokens: 0,
+            // Block 0 is reserved for health checks
+            blocks: vec![0],
+            slots: (0..16).collect(),
         };
         let batch = Batch {
             id: u64::MAX,
             requests: vec![liveness_request],
             size: 1,
             max_tokens: 2,
+            max_blocks: 1,
         };
         self.clone().prefill(batch).await?;
         Ok(())
......
use std::cmp::min;
use tokio::sync::{mpsc, oneshot};

#[derive(Debug, Clone)]
pub(crate) struct BlockAllocation {
    pub blocks: Vec<u32>,
    pub slots: Vec<u32>,
    block_allocator: BlockAllocator,
}

impl Drop for BlockAllocation {
    fn drop(&mut self) {
        self.block_allocator.free(self.blocks.clone())
    }
}

#[derive(Debug, Clone)]
pub(crate) struct BlockAllocator {
    /// Channel to communicate with the background task
    block_allocator: mpsc::UnboundedSender<BlockAllocatorCommand>,
}

impl BlockAllocator {
    pub(crate) fn new(
        max_batch_total_tokens: u32,
        block_size: u32,
        window_size: Option<u32>,
    ) -> Self {
        // Create channel
        let (sender, receiver) = mpsc::unbounded_channel();

        // Launch background queue task
        tokio::spawn(block_allocator_task(
            max_batch_total_tokens / block_size,
            block_size,
            window_size,
            receiver,
        ));

        Self {
            block_allocator: sender,
        }
    }

    pub(crate) async fn allocate(&self, tokens: u32) -> Option<BlockAllocation> {
        let (response_sender, response_receiver) = oneshot::channel();
        self.block_allocator
            .send(BlockAllocatorCommand::Allocate {
                tokens,
                response_sender,
            })
            .unwrap();

        response_receiver
            .await
            .unwrap()
            .map(|(blocks, slots)| BlockAllocation {
                blocks,
                slots,
                block_allocator: self.clone(),
            })
    }

    pub(crate) fn free(&self, blocks: Vec<u32>) {
        self.block_allocator
            .send(BlockAllocatorCommand::Free { blocks })
            .unwrap();
    }
}

async fn block_allocator_task(
    blocks: u32,
    block_size: u32,
    window_size: Option<u32>,
    mut receiver: mpsc::UnboundedReceiver<BlockAllocatorCommand>,
) {
    // Block 0 is reserved for health checks
    let mut free_blocks: Vec<u32> = (1..blocks).collect();
    while let Some(cmd) = receiver.recv().await {
        match cmd {
            BlockAllocatorCommand::Free { blocks } => free_blocks.extend(blocks),
            BlockAllocatorCommand::Allocate {
                tokens,
                response_sender,
            } => {
                // Apply window size
                let (required_blocks, repeats) = {
                    let (tokens, repeats) = match window_size {
                        None => (tokens, 1),
                        Some(window_size) => {
                            let repeats = (tokens + window_size - 1) / window_size;
                            let tokens = min(tokens, window_size);
                            (tokens, repeats as usize)
                        }
                    };
                    // Pad to a multiple of block size
                    let required_blocks = (tokens + block_size - 1) / block_size;
                    (required_blocks, repeats)
                };

                let tokens = tokens as usize;
                let allocation = if required_blocks > free_blocks.len() as u32 {
                    None
                } else {
                    let blocks =
                        free_blocks.split_off(free_blocks.len() - required_blocks as usize);
                    let mut slots = Vec::with_capacity(
                        (required_blocks * block_size * repeats as u32) as usize,
                    );

                    'slots: for block_id in blocks.repeat(repeats).iter() {
                        for s in (block_id * block_size)..((block_id + 1) * block_size) {
                            slots.push(s);
                            if slots.len() == tokens {
                                break 'slots;
                            }
                        }
                    }
                    Some((blocks, slots))
                };

                response_sender.send(allocation).unwrap();
            }
        }
    }
}

#[derive(Debug)]
enum BlockAllocatorCommand {
    Free {
        blocks: Vec<u32>,
    },
    Allocate {
        tokens: u32,
        response_sender: oneshot::Sender<Option<(Vec<u32>, Vec<u32>)>>,
    },
}
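A minimal usage sketch of the allocator (inside the router crate, on a tokio runtime; the numbers are illustrative, not taken from the commit). Note the RAII pattern: freeing happens automatically when the `BlockAllocation` is dropped.

```rust
async fn demo() {
    // 64 total KV-cache tokens at 16 tokens per block => 4 blocks,
    // of which blocks 1..=3 are allocatable (block 0 stays reserved
    // for health checks).
    let allocator = BlockAllocator::new(64, 16, None);

    if let Some(allocation) = allocator.allocate(20).await {
        // 20 tokens round up to (20 + 15) / 16 = 2 blocks and exactly 20 slots.
        assert_eq!(allocation.blocks.len(), 2);
        assert_eq!(allocation.slots.len(), 20);
    } // Dropping `allocation` here sends a `Free` command back to the task.
}
```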
+mod block_allocator;
 mod queue;
 mod scheduler;
......
-use crate::infer::{InferError, InferStreamResponse};
+use crate::infer::v3::block_allocator::{BlockAllocation, BlockAllocator};
+use crate::infer::InferError;
+use crate::infer::InferStreamResponse;
 use crate::validation::{
     ValidGenerateRequest, ValidGrammar, ValidParameters, ValidStoppingParameters,
 };
 use nohash_hasher::{BuildNoHashHasher, IntMap};
-use std::cmp::min;
+use std::cmp::{max, min};
 use std::collections::VecDeque;
 use text_generation_client::v3::{
     Batch, GrammarType, NextTokenChooserParameters, Request, StoppingCriteriaParameters,
 };
-use text_generation_client::{ChunksToString, Input};
+use text_generation_client::ChunksToString;
+use text_generation_client::Input;
 use tokio::sync::{mpsc, oneshot};
 use tokio::time::Instant;
-use tracing::{info_span, instrument, Span};
+use tracing::{info_span, instrument, Instrument, Span};

 /// Queue entry
 #[derive(Debug)]
@@ -28,6 +31,8 @@ pub(crate) struct Entry {
     pub queue_time: Instant,
     /// Instant when this entry was added to a batch
     pub batch_time: Option<Instant>,
+    /// Block Allocation
+    pub block_allocation: Option<BlockAllocation>,
 }

 /// Request Queue
@@ -43,6 +48,7 @@ impl Queue {
         block_size: u32,
         window_size: Option<u32>,
         speculate: u32,
+        max_batch_total_tokens: u32,
     ) -> Self {
         // Create channel
         let (queue_sender, queue_receiver) = mpsc::unbounded_channel();
@@ -53,12 +59,14 @@ impl Queue {
             block_size,
             window_size,
             speculate,
+            max_batch_total_tokens,
             queue_receiver,
         ));

         Self { queue_sender }
     }

+    /// Append an entry to the queue
     #[instrument(skip_all)]
     pub(crate) fn append(&self, entry: Entry) {
         // Send append command to the background task managing the state
@@ -103,9 +111,16 @@ async fn queue_task(
     block_size: u32,
     window_size: Option<u32>,
     speculate: u32,
+    max_batch_total_tokens: u32,
     mut receiver: mpsc::UnboundedReceiver<QueueCommand>,
 ) {
-    let mut state = State::new(requires_padding, block_size, window_size, speculate);
+    let mut state = State::new(
+        requires_padding,
+        block_size,
+        window_size,
+        speculate,
+        max_batch_total_tokens,
+    );

     while let Some(cmd) = receiver.recv().await {
         match cmd {
@@ -120,12 +135,14 @@ async fn queue_task(
                 token_budget,
                 response_sender,
                 span,
-            } => span.in_scope(|| {
-                let next_batch =
-                    state.next_batch(min_size, max_size, prefill_token_budget, token_budget);
+            } => {
+                let next_batch = state
+                    .next_batch(min_size, max_size, prefill_token_budget, token_budget)
+                    .instrument(span)
+                    .await;
                 response_sender.send(next_batch).unwrap();
                 metrics::gauge!("tgi_queue_size", state.entries.len() as f64);
-            }),
+            }
         }
     }
 }
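The switch from `span.in_scope(...)` to `.instrument(span)` follows directly from `next_batch` becoming async: `in_scope` only enters the span for the duration of a synchronous closure, while `Instrument` keeps the span entered across every poll of the future. A minimal sketch of the pattern (`do_work` and `caller` are hypothetical names):

```rust
use tracing::{info_span, Instrument};

async fn do_work() { /* awaits the block allocator, etc. */ }

async fn caller() {
    let span = info_span!("next_batch");
    // The span is re-entered each time the instrumented future is polled.
    do_work().instrument(span).await;
}
```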
@@ -142,9 +159,6 @@ struct State {
     /// Id of the next batch
     next_batch_id: u64,

-    /// Whether the model is using padding
-    requires_padding: bool,
-
     /// Paged Attention block size
     block_size: u32,
@@ -153,6 +167,9 @@ struct State {
     /// Speculation amount
     speculate: u32,
+
+    /// Paged Attention Block Allocation
+    block_allocator: Option<BlockAllocator>,
 }
 impl State {
@@ -161,15 +178,19 @@ impl State {
         block_size: u32,
         window_size: Option<u32>,
         speculate: u32,
+        max_batch_total_tokens: u32,
     ) -> Self {
+        let block_allocator = (!requires_padding)
+            .then(|| BlockAllocator::new(max_batch_total_tokens, block_size, window_size));
+
         Self {
             entries: VecDeque::with_capacity(128),
             next_id: 0,
             next_batch_id: 0,
-            requires_padding,
             block_size,
             window_size,
             speculate,
+            block_allocator,
         }
     }
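The `(!requires_padding).then(...)` gate means only flash/paged-attention models get a `BlockAllocator`; padded models keep `None` and continue to budget by padding in `next_batch`. A one-line illustration of `bool::then`:

```rust
fn main() {
    let requires_padding = true;
    // `bool::then` yields Some(value) when the bool is true, else None.
    let allocator = (!requires_padding).then(|| "paged allocator");
    assert!(allocator.is_none()); // padded models take the padding branch
}
```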
@@ -185,7 +206,7 @@ impl State {
     }

     // Get the next batch
-    fn next_batch(
+    async fn next_batch(
         &mut self,
         min_size: Option<usize>,
         max_size: Option<usize>,
@@ -220,9 +241,10 @@ impl State {
         let mut max_input_length = 0;
         let mut prefill_tokens: u32 = 0;
         let mut decode_tokens: u32 = 0;
+        let mut max_blocks = 0;

         // Pop entries starting from the front of the queue
-        while let Some((id, mut entry)) = self.entries.pop_front() {
+        'entry_loop: while let Some((id, mut entry)) = self.entries.pop_front() {
             // Filter entries where the response receiver was dropped (== entries where the request
             // was dropped by the client)
             if entry.response_tx.is_closed() {
@@ -231,43 +253,67 @@ impl State {
                 continue;
             }
-            if self.requires_padding {
-                // We pad to max input length in the Python shards
-                // We need to take these padding tokens into the equation
-                max_input_length = max_input_length.max(entry.request.input_length);
-                prefill_tokens = (batch_requests.len() + 1) as u32 * max_input_length
-            } else {
-                // pad to block size
-                prefill_tokens += ((entry.request.input_length + self.block_size - 1)
-                    / self.block_size)
-                    * self.block_size;
-            }
-
-            if self.requires_padding {
-                decode_tokens += entry.request.stopping_parameters.max_new_tokens;
-            } else {
-                let max_new_tokens = match self.window_size {
-                    None => entry.request.stopping_parameters.max_new_tokens,
-                    Some(window_size) => min(
-                        window_size.saturating_sub(entry.request.input_length),
-                        entry.request.stopping_parameters.max_new_tokens,
-                    ),
-                };
-
-                // pad to block size
-                decode_tokens +=
-                    ((max_new_tokens + self.block_size - 1) / self.block_size) * self.block_size;
-            }
-
-            if prefill_tokens > prefill_token_budget
-                || (prefill_tokens + decode_tokens + self.speculate) > token_budget
-            {
-                // Entry is over budget
-                // Add it back to the front
-                tracing::debug!("Over budget: prefill_tokens={prefill_tokens} > {prefill_token_budget} || {prefill_tokens} + {decode_tokens} + {} > {token_budget}", self.speculate);
-                self.entries.push_front((id, entry));
-                break;
-            }
+            let block_allocation = match &self.block_allocator {
+                None => {
+                    // We pad to max input length in the Python shards
+                    // We need to take these padding tokens into the equation
+                    max_input_length = max_input_length.max(entry.request.input_length);
+                    prefill_tokens = (batch_requests.len() + 1) as u32 * max_input_length;
+
+                    decode_tokens += entry.request.stopping_parameters.max_new_tokens;
+                    let total_tokens = prefill_tokens + decode_tokens + self.speculate;
+
+                    if prefill_tokens > prefill_token_budget || total_tokens > token_budget {
+                        // Entry is over budget
+                        // Add it back to the front
+                        tracing::debug!("Over budget: prefill_tokens={prefill_tokens} > {prefill_token_budget} || {prefill_tokens} + {decode_tokens} + {} > {token_budget}", self.speculate);
+                        self.entries.push_front((id, entry));
+                        break 'entry_loop;
+                    }
+                    None
+                }
+                Some(block_allocator) => {
+                    prefill_tokens += entry.request.input_length;
+                    let max_new_tokens = match self.window_size {
+                        None => entry.request.stopping_parameters.max_new_tokens,
+                        Some(window_size) => min(
+                            window_size.saturating_sub(entry.request.input_length),
+                            entry.request.stopping_parameters.max_new_tokens,
+                        ),
+                    };
+                    decode_tokens += max_new_tokens;
+
+                    if prefill_tokens > prefill_token_budget
+                        || (prefill_tokens + decode_tokens + self.speculate) > token_budget
+                    {
+                        // Entry is over budget
+                        // Add it back to the front
+                        tracing::debug!("Over budget: prefill_tokens={prefill_tokens} > {prefill_token_budget} || {prefill_tokens} + {decode_tokens} + {} > {token_budget}", self.speculate);
+                        self.entries.push_front((id, entry));
+                        break;
+                    }
+
+                    let tokens = entry.request.input_length
+                        + entry.request.stopping_parameters.max_new_tokens
+                        + self.speculate
+                        - 1;
+
+                    match block_allocator.allocate(tokens).await {
+                        None => {
+                            // Entry is over budget
+                            // Add it back to the front
+                            tracing::debug!("Over budget: not enough free blocks");
+                            self.entries.push_front((id, entry));
+                            break 'entry_loop;
+                        }
+                        Some(block_allocation) => {
+                            tracing::debug!("Allocation: {block_allocation:?}");
+                            max_blocks = max(max_blocks, block_allocation.blocks.len() as u32);
+                            Some(block_allocation)
+                        }
+                    }
+                }
+            };
tracing::debug!("Accepting entry"); tracing::debug!("Accepting entry");
// Create a new span to link the batch back to this entry // Create a new span to link the batch back to this entry
...@@ -278,13 +324,23 @@ impl State { ...@@ -278,13 +324,23 @@ impl State {
// Update entry // Update entry
entry.temp_span = Some(entry_batch_span); entry.temp_span = Some(entry_batch_span);
let (blocks, slots) = match &block_allocation {
None => (Vec::new(), Vec::new()),
Some(block_allocation) => (
block_allocation.blocks.clone(),
block_allocation.slots.clone(),
),
};
entry.block_allocation = block_allocation;
batch_requests.push(Request { batch_requests.push(Request {
id, id,
prefill_logprobs: entry.request.decoder_input_details, prefill_logprobs: entry.request.decoder_input_details,
inputs: entry.request.inputs.chunks_to_string(),
input_chunks: Some(Input { input_chunks: Some(Input {
chunks: entry.request.inputs.clone(), chunks: entry.request.inputs.clone(),
}), }),
inputs: entry.request.inputs.chunks_to_string(),
truncate: entry.request.truncate, truncate: entry.request.truncate,
parameters: Some(NextTokenChooserParameters::from( parameters: Some(NextTokenChooserParameters::from(
entry.request.parameters.clone(), entry.request.parameters.clone(),
...@@ -293,6 +349,8 @@ impl State { ...@@ -293,6 +349,8 @@ impl State {
entry.request.stopping_parameters.clone(), entry.request.stopping_parameters.clone(),
)), )),
top_n_tokens: entry.request.top_n_tokens, top_n_tokens: entry.request.top_n_tokens,
blocks,
slots,
}); });
// Set batch_time // Set batch_time
entry.batch_time = Some(Instant::now()); entry.batch_time = Some(Instant::now());
...@@ -335,6 +393,7 @@ impl State { ...@@ -335,6 +393,7 @@ impl State {
requests: batch_requests, requests: batch_requests,
size, size,
max_tokens: (prefill_tokens + decode_tokens), max_tokens: (prefill_tokens + decode_tokens),
max_blocks,
}; };
// Increment batch id // Increment batch id
self.next_batch_id += 1; self.next_batch_id += 1;
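Worth noting in the paged branch above: the allocator is asked for `input_length + max_new_tokens + speculate - 1` tokens. My reading of the `- 1` (not stated in the commit) is that the final generated token is returned to the client without its KV entry ever being attended to, so it never needs a slot. A worked example with illustrative numbers:

```rust
fn main() {
    // A 10-token prompt with max_new_tokens = 5 and speculate = 0
    // reserves 10 + 5 + 0 - 1 = 14 slots; at block_size = 16 that
    // rounds up to a single block.
    let (input_length, max_new_tokens, speculate, block_size) = (10u32, 5u32, 0u32, 16u32);
    let tokens = input_length + max_new_tokens + speculate - 1;
    let required_blocks = (tokens + block_size - 1) / block_size;
    assert_eq!((tokens, required_blocks), (14, 1));
}
```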
@@ -438,13 +497,14 @@ mod tests {
             temp_span: None,
             queue_time: Instant::now(),
             batch_time: None,
+            block_allocation: None,
         };
         (entry, receiver_tx)
     }

-    #[test]
-    fn test_append() {
-        let mut state = State::new(false, 1, None, 0);
+    #[tokio::test]
+    async fn test_append() {
+        let mut state = State::new(false, 1, None, 0, 16);
         let (entry, _guard) = default_entry();

         assert_eq!(state.next_id, 0);
@@ -458,23 +518,23 @@ mod tests {
         assert_eq!(id, 0);
     }

-    #[test]
-    fn test_next_batch_empty() {
-        let mut state = State::new(false, 1, None, 0);
+    #[tokio::test]
+    async fn test_next_batch_empty() {
+        let mut state = State::new(false, 1, None, 0, 16);

-        assert!(state.next_batch(None, None, 1, 1).is_none());
-        assert!(state.next_batch(Some(1), None, 1, 1).is_none());
+        assert!(state.next_batch(None, None, 1, 1).await.is_none());
+        assert!(state.next_batch(Some(1), None, 1, 1).await.is_none());
     }

-    #[test]
-    fn test_next_batch_min_size() {
-        let mut state = State::new(false, 1, None, 0);
+    #[tokio::test]
+    async fn test_next_batch_min_size() {
+        let mut state = State::new(false, 1, None, 0, 16);
         let (entry1, _guard1) = default_entry();
         let (entry2, _guard2) = default_entry();
         state.append(entry1);
         state.append(entry2);

-        let (entries, batch, _) = state.next_batch(None, None, 2, 2).unwrap();
+        let (entries, batch, _) = state.next_batch(None, None, 2, 2).await.unwrap();
         assert_eq!(entries.len(), 2);
         assert!(entries.contains_key(&0));
         assert!(entries.contains_key(&1));
@@ -490,7 +550,7 @@ mod tests {
         let (entry3, _guard3) = default_entry();
         state.append(entry3);

-        assert!(state.next_batch(Some(2), None, 2, 2).is_none());
+        assert!(state.next_batch(Some(2), None, 2, 2).await.is_none());

         assert_eq!(state.next_id, 3);
         assert_eq!(state.entries.len(), 1);
@@ -498,15 +558,15 @@ mod tests {
         assert_eq!(id, 2);
     }

-    #[test]
-    fn test_next_batch_max_size() {
-        let mut state = State::new(false, 1, None, 0);
+    #[tokio::test]
+    async fn test_next_batch_max_size() {
+        let mut state = State::new(false, 1, None, 0, 16);
         let (entry1, _guard1) = default_entry();
         let (entry2, _guard2) = default_entry();
         state.append(entry1);
         state.append(entry2);

-        let (entries, batch, _) = state.next_batch(None, Some(1), 2, 2).unwrap();
+        let (entries, batch, _) = state.next_batch(None, Some(1), 2, 2).await.unwrap();
         assert_eq!(entries.len(), 1);
         assert!(entries.contains_key(&0));
         assert!(entries.get(&0).unwrap().batch_time.is_some());
@@ -518,15 +578,15 @@ mod tests {
         assert_eq!(state.next_batch_id, 1);
     }

-    #[test]
-    fn test_next_batch_token_budget() {
-        let mut state = State::new(false, 1, None, 0);
+    #[tokio::test]
+    async fn test_next_batch_token_budget() {
+        let mut state = State::new(false, 1, None, 0, 2);
         let (entry1, _guard1) = default_entry();
         let (entry2, _guard2) = default_entry();
         state.append(entry1);
         state.append(entry2);

-        let (entries, batch, _) = state.next_batch(None, None, 1, 1).unwrap();
+        let (entries, batch, _) = state.next_batch(None, None, 1, 1).await.unwrap();
         assert_eq!(entries.len(), 1);
         assert!(entries.contains_key(&0));
         assert_eq!(batch.id, 0);
@@ -539,7 +599,7 @@ mod tests {
         let (entry3, _guard3) = default_entry();
         state.append(entry3);

-        let (entries, batch, _) = state.next_batch(None, None, 3, 3).unwrap();
+        let (entries, batch, _) = state.next_batch(None, None, 3, 3).await.unwrap();
         assert_eq!(entries.len(), 2);
         assert!(entries.contains_key(&1));
         assert!(entries.contains_key(&2));
@@ -553,14 +613,14 @@ mod tests {
     #[tokio::test]
     async fn test_queue_append() {
-        let queue = Queue::new(false, 1, None, 0);
+        let queue = Queue::new(false, 1, None, 0, 16);
         let (entry, _guard) = default_entry();
         queue.append(entry);
     }

     #[tokio::test]
     async fn test_queue_next_batch_empty() {
-        let queue = Queue::new(false, 1, None, 0);
+        let queue = Queue::new(false, 1, None, 0, 16);

         assert!(queue.next_batch(None, None, 1, 1).await.is_none());
         assert!(queue.next_batch(Some(1), None, 1, 1).await.is_none());
@@ -568,7 +628,7 @@ mod tests {
     #[tokio::test]
     async fn test_queue_next_batch_min_size() {
-        let queue = Queue::new(false, 1, None, 0);
+        let queue = Queue::new(false, 1, None, 0, 16);
         let (entry1, _guard1) = default_entry();
         let (entry2, _guard2) = default_entry();
         queue.append(entry1);
@@ -601,7 +661,7 @@ mod tests {
     #[tokio::test]
     async fn test_queue_next_batch_max_size() {
-        let queue = Queue::new(false, 1, None, 0);
+        let queue = Queue::new(false, 1, None, 0, 16);
         let (entry1, _guard1) = default_entry();
         let (entry2, _guard2) = default_entry();
         queue.append(entry1);
@@ -617,7 +677,7 @@ mod tests {
     #[tokio::test]
     async fn test_queue_next_batch_token_budget() {
-        let queue = Queue::new(false, 1, None, 0);
+        let queue = Queue::new(false, 1, None, 0, 16);
         let (entry1, _guard1) = default_entry();
         let (entry2, _guard2) = default_entry();
         queue.append(entry1);
@@ -642,7 +702,7 @@ mod tests {
     #[tokio::test]
     async fn test_queue_next_batch_token_speculate() {
-        let queue = Queue::new(false, 1, None, 2);
+        let queue = Queue::new(false, 1, None, 2, 16);
         let (entry1, _guard1) = default_entry();
         let (entry2, _guard2) = default_entry();
         queue.append(entry1);
@@ -661,7 +721,7 @@ mod tests {
     #[tokio::test]
     async fn test_queue_next_batch_dropped_receiver() {
-        let queue = Queue::new(false, 1, None, 0);
+        let queue = Queue::new(false, 1, None, 0, 16);
         let (entry, _) = default_entry();
         queue.append(entry);
......
@@ -39,7 +39,13 @@ impl SchedulerV3 {
         speculate: u32,
         generation_health: Arc<AtomicBool>,
     ) -> Self {
-        let queue = Queue::new(requires_padding, 16, window_size, speculate);
+        let queue = Queue::new(
+            requires_padding,
+            16,
+            window_size,
+            speculate,
+            max_batch_total_tokens,
+        );
         let batching_task_notifier = Arc::new(Notify::new());

         // Spawn batching background task that contains all the inference logic
@@ -81,6 +87,7 @@ impl Scheduler for SchedulerV3 {
             temp_span: None,
             queue_time: Instant::now(),
             batch_time: None,
+            block_allocation: None,
         });
         // Notify the background task that we have a new entry in the queue that needs
......
import math

import torch

from typing import Optional, List, Tuple

from text_generation_server.utils.import_utils import SYSTEM

BLOCK_SIZE: int = 16
# Will be set in warmup
CACHE_MANAGER: Optional["CacheManager"] = None


class CacheManager:
    def __init__(
        self,
        num_blocks: int,
        num_layers: int,
        num_heads: int,
        head_size: int,
        repeat_slots: bool,
        dtype: torch.dtype,
        device: torch.device,
    ):
        self.block_size = BLOCK_SIZE
        self.num_blocks = num_blocks
        self.repeat_slots = repeat_slots

        element_size = torch.tensor([], dtype=dtype).element_size()
        if SYSTEM == "xpu":
            x = 1
        else:
            x = self.block_size // element_size

        self.kv_cache = [
            (
                torch.empty(
                    (num_blocks, num_heads, head_size // x, self.block_size, x),
                    dtype=dtype,
                    device=device,
                ),
                torch.empty(
                    (num_blocks, num_heads, head_size, self.block_size),
                    dtype=dtype,
                    device=device,
                ),
            )
            for _ in range(num_layers)
        ]
        self.free_block_mask = torch.ones(num_blocks, dtype=torch.int32, device="cpu")
        self.slots = torch.arange(
            0, num_blocks * self.block_size, dtype=torch.int64
        ).view(num_blocks, self.block_size)

    def allocate(
        self,
        needed_blocks_slots: List[Tuple[int, int]],
        blocks: int,
        max_blocks: int,
        device: torch.device,
    ):
        # Get free block indices by finding values in the mask that are not set to 0
        free_block_indices = self.free_block_mask.nonzero()
        if blocks > len(free_block_indices):
            raise RuntimeError(
                f"Out of available cache blocks: asked {blocks}, only {len(free_block_indices)} free blocks"
            )

        # Slice by the number of required blocks
        block_indices = free_block_indices[:blocks]
        block_indices = block_indices.flatten()

        # Padded block tables
        block_tables_tensor = torch.zeros(
            (len(needed_blocks_slots), max_blocks), dtype=torch.int32
        )

        # Allocate paged attention blocks
        cumulative_blocks = 0
        slots = []
        block_tables = []
        for i, (needed_blocks, needed_slots) in enumerate(needed_blocks_slots):
            # Get allocated blocks for this sequence
            allocated_blocks = block_indices[
                cumulative_blocks : cumulative_blocks + needed_blocks
            ]
            # Get slots for the allocated blocks
            all_slots = self.slots[allocated_blocks].flatten()

            # Repeat slots in the case of context sliding window
            if needed_slots > len(all_slots) and self.repeat_slots:
                repeats = math.ceil(needed_slots / len(all_slots))
                all_slots = all_slots.repeat(repeats)

            allocated_slots = all_slots[:needed_slots]

            slots.append(allocated_slots)
            block_tables.append(allocated_blocks.tolist())
            block_tables_tensor[i, :needed_blocks] = allocated_blocks
            cumulative_blocks += needed_blocks

        block_tables_tensor = block_tables_tensor.to(device)
        slots = torch.concat(slots).to(device)

        # Allocate the required number of blocks by setting the mask to 0
        self.free_block_mask[block_indices] = 0

        return block_tables, block_tables_tensor, slots

    def free(self, block_indices: Optional[List[int]]):
        if block_indices is not None and block_indices:
            # Reset mask
            self.free_block_mask[block_indices] = 1


def set_cache_manager(
    num_blocks: int,
    num_layers: int,
    num_heads: int,
    head_size: int,
    repeat_slots: bool,
    dtype: torch.dtype,
    device: torch.device,
) -> CacheManager:
    global CACHE_MANAGER
    if CACHE_MANAGER is not None:
        del CACHE_MANAGER
        torch.cuda.empty_cache()

    CACHE_MANAGER = CacheManager(
        num_blocks, num_layers, num_heads, head_size, repeat_slots, dtype, device
    )
    return CACHE_MANAGER


def get_cache_manager() -> CacheManager:
    global CACHE_MANAGER
    if CACHE_MANAGER is None:
        raise RuntimeError("cache manager was not initialized")

    return CACHE_MANAGER
@@ -512,6 +512,7 @@ class FlashCohereForCausalLM(torch.nn.Module):
         slots: torch.Tensor,
         input_lengths: torch.Tensor,
         max_s: int,
+        prefill_cache_indices: Optional[torch.Tensor],
         lm_head_indices: Optional[torch.Tensor] = None,
     ) -> Tuple[torch.Tensor, Optional[torch.Tensor]]:
         hidden_states = self.model(
......
@@ -834,6 +834,7 @@ class FlashDbrxForCausalLM(torch.nn.Module):
         slots: torch.Tensor,
         input_lengths: torch.Tensor,
         max_s: int,
+        prefill_cache_indices: Optional[torch.Tensor],
         lm_head_indices: Optional[torch.Tensor] = None,
     ) -> Tuple[torch.Tensor, Optional[torch.Tensor]]:
         hidden_states = self.model(
......
@@ -458,6 +458,7 @@ class FlashGemmaForCausalLM(torch.nn.Module):
         slots: torch.Tensor,
         input_lengths: torch.Tensor,
         max_s: int,
+        prefill_cache_indices: Optional[torch.Tensor],
         lm_head_indices: Optional[torch.Tensor] = None,
     ) -> Tuple[torch.Tensor, Optional[torch.Tensor]]:
         input_embeds = self.embed_tokens(input_ids)
......
@@ -388,6 +388,7 @@ class FlashGPTNeoXForCausalLM(FlashGPTNeoXPreTrainedModel):
         slots: torch.Tensor,
         input_lengths: torch.Tensor,
         max_s: int,
+        prefill_cache_indices: Optional[torch.Tensor],
         lm_head_indices: Optional[torch.Tensor] = None,
     ) -> torch.Tensor:
         hidden_states = self.gpt_neox(
......
@@ -398,6 +398,7 @@ class FlashPhiForCausalLM(torch.nn.Module):
         slots: torch.Tensor,
         input_lengths: torch.Tensor,
         max_s: int,
+        prefill_cache_indices: Optional[torch.Tensor],
         lm_head_indices: Optional[torch.Tensor] = None,
     ) -> torch.Tensor:
         hidden_states = self.model(
......
@@ -670,6 +670,7 @@ class FlashRWForCausalLM(FlashRWPreTrainedModel):
         slots: torch.Tensor,
         input_lengths: torch.Tensor,
         max_s: int,
+        prefill_cache_indices: Optional[torch.Tensor],
         lm_head_indices: Optional[torch.Tensor] = None,
     ) -> torch.Tensor:
         hidden_states = self.transformer(
......
@@ -482,6 +482,7 @@ class FlashSantacoderForCausalLM(nn.Module):
         slots: torch.Tensor,
         input_lengths: torch.Tensor,
         max_s: int,
+        prefill_cache_indices: Optional[torch.Tensor],
         lm_head_indices: Optional[torch.Tensor] = None,
     ) -> torch.Tensor:
         hidden_states = self.transformer(