Unverified commit 8aece3bd authored by OlivierDehaene, committed by GitHub

feat: move allocation logic to rust (#1835)

Close #2007
parent 9ffe1f1e
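In short, this change moves paged-attention block allocation out of the Python server and into the Rust router: a new block_allocator module runs a background task that hands out block and slot ids, the v3 queue reserves an allocation for each entry while building a batch, and the gRPC protocol gains blocks, slots and max_blocks fields so the shards receive ready-made allocations instead of computing them. A release-opt Cargo profile is also introduced and wired into the Dockerfiles, and the flash model forwards gain a prefill_cache_indices argument. Short sketches and worked examples follow the relevant hunks below.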
......@@ -26,7 +26,12 @@ incremental = true
inherits = "release"
debug = 1
incremental = true
panic = "abort"
[profile.release-opt]
inherits = "release"
debug = 0
incremental = false
lto = "fat"
opt-level = 3
codegen-units = 1
panic = "abort"
......@@ -25,7 +25,7 @@ RUN PROTOC_ZIP=protoc-21.12-linux-x86_64.zip && \
rm -f $PROTOC_ZIP
COPY --from=planner /usr/src/recipe.json recipe.json
RUN cargo chef cook --release --recipe-path recipe.json
RUN cargo chef cook --profile release-opt --recipe-path recipe.json
COPY Cargo.toml Cargo.toml
COPY rust-toolchain.toml rust-toolchain.toml
......@@ -33,7 +33,7 @@ COPY proto proto
COPY benchmark benchmark
COPY router router
COPY launcher launcher
RUN cargo build --release
RUN cargo build --profile release-opt
# Python builder
# Adapted from: https://github.com/pytorch/pytorch/blob/master/Dockerfile
......@@ -226,11 +226,11 @@ RUN cd server && \
pip install ".[bnb, accelerate, quantize, peft, outlines]" --no-cache-dir
# Install benchmarker
COPY --from=builder /usr/src/target/release/text-generation-benchmark /usr/local/bin/text-generation-benchmark
COPY --from=builder /usr/src/target/release-opt/text-generation-benchmark /usr/local/bin/text-generation-benchmark
# Install router
COPY --from=builder /usr/src/target/release/text-generation-router /usr/local/bin/text-generation-router
COPY --from=builder /usr/src/target/release-opt/text-generation-router /usr/local/bin/text-generation-router
# Install launcher
COPY --from=builder /usr/src/target/release/text-generation-launcher /usr/local/bin/text-generation-launcher
COPY --from=builder /usr/src/target/release-opt/text-generation-launcher /usr/local/bin/text-generation-launcher
RUN apt-get update && DEBIAN_FRONTEND=noninteractive apt-get install -y --no-install-recommends \
build-essential \
......
......@@ -25,7 +25,7 @@ RUN PROTOC_ZIP=protoc-21.12-linux-x86_64.zip && \
rm -f $PROTOC_ZIP
COPY --from=planner /usr/src/recipe.json recipe.json
RUN cargo chef cook --release --recipe-path recipe.json
RUN cargo chef cook --profile release-opt --recipe-path recipe.json
COPY Cargo.toml Cargo.toml
COPY rust-toolchain.toml rust-toolchain.toml
......@@ -33,7 +33,7 @@ COPY proto proto
COPY benchmark benchmark
COPY router router
COPY launcher launcher
RUN cargo build --release
RUN cargo build --profile release-opt
# Text Generation Inference base image for RoCm
FROM rocm/dev-ubuntu-22.04:6.1.1_hip_update as base
......@@ -193,11 +193,11 @@ RUN cd server && \
pip install ".[accelerate, peft, outlines]" --no-cache-dir
# Install benchmarker
COPY --from=builder /usr/src/target/release/text-generation-benchmark /usr/local/bin/text-generation-benchmark
COPY --from=builder /usr/src/target/release-opt/text-generation-benchmark /usr/local/bin/text-generation-benchmark
# Install router
COPY --from=builder /usr/src/target/release/text-generation-router /usr/local/bin/text-generation-router
COPY --from=builder /usr/src/target/release-opt/text-generation-router /usr/local/bin/text-generation-router
# Install launcher
COPY --from=builder /usr/src/target/release/text-generation-launcher /usr/local/bin/text-generation-launcher
COPY --from=builder /usr/src/target/release-opt/text-generation-launcher /usr/local/bin/text-generation-launcher
# AWS Sagemaker compatible image
FROM base as sagemaker
......
......@@ -24,7 +24,7 @@ RUN PROTOC_ZIP=protoc-21.12-linux-x86_64.zip && \
rm -f $PROTOC_ZIP
COPY --from=planner /usr/src/recipe.json recipe.json
RUN cargo chef cook --release --recipe-path recipe.json
RUN cargo chef cook --profile release-opt --recipe-path recipe.json
COPY Cargo.toml Cargo.toml
COPY rust-toolchain.toml rust-toolchain.toml
......@@ -32,7 +32,7 @@ COPY proto proto
COPY benchmark benchmark
COPY router router
COPY launcher launcher
RUN cargo build --release
RUN cargo build --profile release-opt
# Text Generation Inference base image for Intel
......@@ -78,11 +78,11 @@ ENV PATH=/opt/intel/oneapi/mpi/latest/opt/mpi/libfabric/bin:/opt/intel/oneapi/mp
ENV CCL_ZE_IPC_EXCHANGE=sockets
# Install benchmarker
COPY --from=builder /usr/src/target/release/text-generation-benchmark /usr/local/bin/text-generation-benchmark
COPY --from=builder /usr/src/target/release-opt/text-generation-benchmark /usr/local/bin/text-generation-benchmark
# Install router
COPY --from=builder /usr/src/target/release/text-generation-router /usr/local/bin/text-generation-router
COPY --from=builder /usr/src/target/release-opt/text-generation-router /usr/local/bin/text-generation-router
# Install launcher
COPY --from=builder /usr/src/target/release/text-generation-launcher /usr/local/bin/text-generation-launcher
COPY --from=builder /usr/src/target/release-opt/text-generation-launcher /usr/local/bin/text-generation-launcher
# Final image
FROM base
......
......@@ -155,6 +155,8 @@ async fn prefill(
ignore_eos_token: true, // Will not stop even if an eos token is generated
}),
top_n_tokens: top_n_tokens.unwrap_or(0),
blocks: vec![],
slots: vec![],
})
.collect();
......@@ -163,6 +165,7 @@ async fn prefill(
requests,
size: batch_size,
max_tokens: batch_size * (sequence_length + decode_length),
max_blocks: 0,
};
// Run prefill
......
......@@ -130,6 +130,10 @@ message Request {
bool prefill_logprobs = 6;
/// Return most likely n tokens
uint32 top_n_tokens = 7;
/// Paged attention blocks
repeated uint32 blocks = 9;
/// Paged attention slots
repeated uint32 slots = 10;
}
message Batch {
......@@ -141,6 +145,8 @@ message Batch {
uint32 size = 3;
/// Maximum number of tokens this batch will grow to
uint32 max_tokens = 4;
/// Maximum number of Paged Attention blocks
uint32 max_blocks = 5;
}
message CachedBatch {
......
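The new fields follow the allocator's layout: a block is a fixed-size chunk of KV cache and a slot is a single token position inside it, so with the block size of 16 used throughout this change, block b covers slots 16*b through 16*b + 15. For example, a request granted blocks = [1, 2] for 20 tokens would carry slots = [16, 17, ..., 31, 32, 33, 34, 35], while max_blocks on the Batch is the largest per-request block count, which the shard can use to size padded block tables.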
......@@ -153,6 +153,9 @@ impl Client {
}),
// We truncate the input on the server side to be sure that it has the correct size
truncate,
// Blocks and slots will be set on the server side if we use paged attention
blocks: vec![],
slots: vec![],
// Set sampling parameters to also take these ops into account in the max memory
parameters: Some(NextTokenChooserParameters {
temperature: 0.9,
......@@ -187,7 +190,8 @@ impl Client {
id: 0,
size: requests.len() as u32,
requests,
max_tokens: 0,
max_tokens: max_input_length,
max_blocks: 0,
};
let request = tonic::Request::new(WarmupRequest {
......
......@@ -241,12 +241,16 @@ impl Health for ShardedClient {
ignore_eos_token: false,
}),
top_n_tokens: 0,
// Block 0 is reserved for health checks
blocks: vec![0],
slots: (0..16).collect(),
};
let batch = Batch {
id: u64::MAX,
requests: vec![liveness_request],
size: 1,
max_tokens: 2,
max_blocks: 1,
};
self.clone().prefill(batch).await?;
Ok(())
......
use std::cmp::min;
use tokio::sync::{mpsc, oneshot};

#[derive(Debug, Clone)]
pub(crate) struct BlockAllocation {
    pub blocks: Vec<u32>,
    pub slots: Vec<u32>,
    block_allocator: BlockAllocator,
}

impl Drop for BlockAllocation {
    fn drop(&mut self) {
        self.block_allocator.free(self.blocks.clone())
    }
}

#[derive(Debug, Clone)]
pub(crate) struct BlockAllocator {
    /// Channel to communicate with the background task
    block_allocator: mpsc::UnboundedSender<BlockAllocatorCommand>,
}

impl BlockAllocator {
    pub(crate) fn new(
        max_batch_total_tokens: u32,
        block_size: u32,
        window_size: Option<u32>,
    ) -> Self {
        // Create channel
        let (sender, receiver) = mpsc::unbounded_channel();

        // Launch background queue task
        tokio::spawn(block_allocator_task(
            max_batch_total_tokens / block_size,
            block_size,
            window_size,
            receiver,
        ));

        Self {
            block_allocator: sender,
        }
    }

    pub(crate) async fn allocate(&self, tokens: u32) -> Option<BlockAllocation> {
        let (response_sender, response_receiver) = oneshot::channel();
        self.block_allocator
            .send(BlockAllocatorCommand::Allocate {
                tokens,
                response_sender,
            })
            .unwrap();

        response_receiver
            .await
            .unwrap()
            .map(|(blocks, slots)| BlockAllocation {
                blocks,
                slots,
                block_allocator: self.clone(),
            })
    }

    pub(crate) fn free(&self, blocks: Vec<u32>) {
        self.block_allocator
            .send(BlockAllocatorCommand::Free { blocks })
            .unwrap();
    }
}

async fn block_allocator_task(
    blocks: u32,
    block_size: u32,
    window_size: Option<u32>,
    mut receiver: mpsc::UnboundedReceiver<BlockAllocatorCommand>,
) {
    // Block 0 is reserved for health checks
    let mut free_blocks: Vec<u32> = (1..blocks).collect();
    while let Some(cmd) = receiver.recv().await {
        match cmd {
            BlockAllocatorCommand::Free { blocks } => free_blocks.extend(blocks),
            BlockAllocatorCommand::Allocate {
                tokens,
                response_sender,
            } => {
                // Apply window size
                let (required_blocks, repeats) = {
                    let (tokens, repeats) = match window_size {
                        None => (tokens, 1),
                        Some(window_size) => {
                            let repeats = (tokens + window_size - 1) / window_size;
                            let tokens = min(tokens, window_size);
                            (tokens, repeats as usize)
                        }
                    };
                    // Pad to a multiple of block size
                    let required_blocks = (tokens + block_size - 1) / block_size;
                    (required_blocks, repeats)
                };

                let tokens = tokens as usize;
                let allocation = if required_blocks > free_blocks.len() as u32 {
                    None
                } else {
                    let blocks =
                        free_blocks.split_off(free_blocks.len() - required_blocks as usize);
                    let mut slots = Vec::with_capacity(
                        (required_blocks * block_size * repeats as u32) as usize,
                    );

                    'slots: for block_id in blocks.repeat(repeats).iter() {
                        for s in (block_id * block_size)..((block_id + 1) * block_size) {
                            slots.push(s);
                            if slots.len() == tokens {
                                break 'slots;
                            }
                        }
                    }
                    Some((blocks, slots))
                };

                response_sender.send(allocation).unwrap();
            }
        }
    }
}

#[derive(Debug)]
enum BlockAllocatorCommand {
    Free {
        blocks: Vec<u32>,
    },
    Allocate {
        tokens: u32,
        response_sender: oneshot::Sender<Option<(Vec<u32>, Vec<u32>)>>,
    },
}
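A minimal usage sketch of the allocator above, written as if from inside the router crate (the module is pub(crate)); the capacity and token counts are illustrative assumptions, not values taken from this diff:

use crate::infer::v3::block_allocator::BlockAllocator;

async fn allocator_example() {
    // 16384 max batch total tokens / block size 16 => 1024 blocks; block 0 stays reserved
    let allocator = BlockAllocator::new(16384, 16, None);

    // Reserve room for 49 tokens: ceil(49 / 16) = 4 blocks and the first 49 of their 64 slots
    match allocator.allocate(49).await {
        Some(allocation) => {
            assert_eq!(allocation.blocks.len(), 4);
            assert_eq!(allocation.slots.len(), 49);
            // Dropping the allocation sends BlockAllocatorCommand::Free and returns
            // the blocks to the background task's free list.
            drop(allocation);
        }
        // None means there were not enough free blocks; the queue reacts by putting
        // the entry back at the front and closing the batch.
        None => {}
    }
}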
mod block_allocator;
mod queue;
mod scheduler;
......
use crate::infer::{InferError, InferStreamResponse};
use crate::infer::v3::block_allocator::{BlockAllocation, BlockAllocator};
use crate::infer::InferError;
use crate::infer::InferStreamResponse;
use crate::validation::{
ValidGenerateRequest, ValidGrammar, ValidParameters, ValidStoppingParameters,
};
use nohash_hasher::{BuildNoHashHasher, IntMap};
use std::cmp::min;
use std::cmp::{max, min};
use std::collections::VecDeque;
use text_generation_client::v3::{
Batch, GrammarType, NextTokenChooserParameters, Request, StoppingCriteriaParameters,
};
use text_generation_client::{ChunksToString, Input};
use text_generation_client::ChunksToString;
use text_generation_client::Input;
use tokio::sync::{mpsc, oneshot};
use tokio::time::Instant;
use tracing::{info_span, instrument, Span};
use tracing::{info_span, instrument, Instrument, Span};
/// Queue entry
#[derive(Debug)]
......@@ -28,6 +31,8 @@ pub(crate) struct Entry {
pub queue_time: Instant,
/// Instant when this entry was added to a batch
pub batch_time: Option<Instant>,
/// Block Allocation
pub block_allocation: Option<BlockAllocation>,
}
/// Request Queue
......@@ -43,6 +48,7 @@ impl Queue {
block_size: u32,
window_size: Option<u32>,
speculate: u32,
max_batch_total_tokens: u32,
) -> Self {
// Create channel
let (queue_sender, queue_receiver) = mpsc::unbounded_channel();
......@@ -53,12 +59,14 @@ impl Queue {
block_size,
window_size,
speculate,
max_batch_total_tokens,
queue_receiver,
));
Self { queue_sender }
}
/// Append an entry to the queue
#[instrument(skip_all)]
pub(crate) fn append(&self, entry: Entry) {
// Send append command to the background task managing the state
......@@ -103,9 +111,16 @@ async fn queue_task(
block_size: u32,
window_size: Option<u32>,
speculate: u32,
max_batch_total_tokens: u32,
mut receiver: mpsc::UnboundedReceiver<QueueCommand>,
) {
let mut state = State::new(requires_padding, block_size, window_size, speculate);
let mut state = State::new(
requires_padding,
block_size,
window_size,
speculate,
max_batch_total_tokens,
);
while let Some(cmd) = receiver.recv().await {
match cmd {
......@@ -120,12 +135,14 @@ async fn queue_task(
token_budget,
response_sender,
span,
} => span.in_scope(|| {
let next_batch =
state.next_batch(min_size, max_size, prefill_token_budget, token_budget);
} => {
let next_batch = state
.next_batch(min_size, max_size, prefill_token_budget, token_budget)
.instrument(span)
.await;
response_sender.send(next_batch).unwrap();
metrics::gauge!("tgi_queue_size", state.entries.len() as f64);
}),
}
}
}
}
......@@ -142,9 +159,6 @@ struct State {
/// Id of the next batch
next_batch_id: u64,
/// Whether the model is using padding
requires_padding: bool,
/// Paged Attention block size
block_size: u32,
......@@ -153,6 +167,9 @@ struct State {
/// Speculation amount
speculate: u32,
/// Paged Attention Block Allocation
block_allocator: Option<BlockAllocator>,
}
impl State {
......@@ -161,15 +178,19 @@ impl State {
block_size: u32,
window_size: Option<u32>,
speculate: u32,
max_batch_total_tokens: u32,
) -> Self {
let block_allocator = (!requires_padding)
.then(|| BlockAllocator::new(max_batch_total_tokens, block_size, window_size));
Self {
entries: VecDeque::with_capacity(128),
next_id: 0,
next_batch_id: 0,
requires_padding,
block_size,
window_size,
speculate,
block_allocator,
}
}
......@@ -185,7 +206,7 @@ impl State {
}
// Get the next batch
fn next_batch(
async fn next_batch(
&mut self,
min_size: Option<usize>,
max_size: Option<usize>,
......@@ -220,9 +241,10 @@ impl State {
let mut max_input_length = 0;
let mut prefill_tokens: u32 = 0;
let mut decode_tokens: u32 = 0;
let mut max_blocks = 0;
// Pop entries starting from the front of the queue
while let Some((id, mut entry)) = self.entries.pop_front() {
'entry_loop: while let Some((id, mut entry)) = self.entries.pop_front() {
// Filter entries where the response receiver was dropped (== entries where the request
// was dropped by the client)
if entry.response_tx.is_closed() {
......@@ -231,43 +253,67 @@ impl State {
continue;
}
if self.requires_padding {
// We pad to max input length in the Python shards
// We need to take these padding tokens into the equation
max_input_length = max_input_length.max(entry.request.input_length);
prefill_tokens = (batch_requests.len() + 1) as u32 * max_input_length
} else {
// pad to block size
prefill_tokens += ((entry.request.input_length + self.block_size - 1)
/ self.block_size)
* self.block_size;
}
if self.requires_padding {
decode_tokens += entry.request.stopping_parameters.max_new_tokens;
} else {
let max_new_tokens = match self.window_size {
None => entry.request.stopping_parameters.max_new_tokens,
Some(window_size) => min(
window_size.saturating_sub(entry.request.input_length),
entry.request.stopping_parameters.max_new_tokens,
),
};
// pad to block size
decode_tokens +=
((max_new_tokens + self.block_size - 1) / self.block_size) * self.block_size;
}
if prefill_tokens > prefill_token_budget
|| (prefill_tokens + decode_tokens + self.speculate) > token_budget
{
// Entry is over budget
// Add it back to the front
tracing::debug!("Over budget: prefill_tokens={prefill_tokens} > {prefill_token_budget} || {prefill_tokens} + {decode_tokens} + {} > {token_budget}", self.speculate);
self.entries.push_front((id, entry));
break;
}
let block_allocation = match &self.block_allocator {
None => {
// We pad to max input length in the Python shards
// We need to take these padding tokens into the equation
max_input_length = max_input_length.max(entry.request.input_length);
prefill_tokens = (batch_requests.len() + 1) as u32 * max_input_length;
decode_tokens += entry.request.stopping_parameters.max_new_tokens;
let total_tokens = prefill_tokens + decode_tokens + self.speculate;
if prefill_tokens > prefill_token_budget || total_tokens > token_budget {
// Entry is over budget
// Add it back to the front
tracing::debug!("Over budget: prefill_tokens={prefill_tokens} > {prefill_token_budget} || {prefill_tokens} + {decode_tokens} + {} > {token_budget}", self.speculate);
self.entries.push_front((id, entry));
break 'entry_loop;
}
None
}
Some(block_allocator) => {
prefill_tokens += entry.request.input_length;
let max_new_tokens = match self.window_size {
None => entry.request.stopping_parameters.max_new_tokens,
Some(window_size) => min(
window_size.saturating_sub(entry.request.input_length),
entry.request.stopping_parameters.max_new_tokens,
),
};
decode_tokens += max_new_tokens;
if prefill_tokens > prefill_token_budget
|| (prefill_tokens + decode_tokens + self.speculate) > token_budget
{
// Entry is over budget
// Add it back to the front
tracing::debug!("Over budget: prefill_tokens={prefill_tokens} > {prefill_token_budget} || {prefill_tokens} + {decode_tokens} + {} > {token_budget}", self.speculate);
self.entries.push_front((id, entry));
break;
}
let tokens = entry.request.input_length
+ entry.request.stopping_parameters.max_new_tokens
+ self.speculate
- 1;
match block_allocator.allocate(tokens).await {
None => {
// Entry is over budget
// Add it back to the front
tracing::debug!("Over budget: not enough free blocks");
self.entries.push_front((id, entry));
break 'entry_loop;
}
Some(block_allocation) => {
tracing::debug!("Allocation: {block_allocation:?}");
max_blocks = max(max_blocks, block_allocation.blocks.len() as u32);
Some(block_allocation)
}
}
}
};
tracing::debug!("Accepting entry");
// Create a new span to link the batch back to this entry
......@@ -278,13 +324,23 @@ impl State {
// Update entry
entry.temp_span = Some(entry_batch_span);
let (blocks, slots) = match &block_allocation {
None => (Vec::new(), Vec::new()),
Some(block_allocation) => (
block_allocation.blocks.clone(),
block_allocation.slots.clone(),
),
};
entry.block_allocation = block_allocation;
batch_requests.push(Request {
id,
prefill_logprobs: entry.request.decoder_input_details,
inputs: entry.request.inputs.chunks_to_string(),
input_chunks: Some(Input {
chunks: entry.request.inputs.clone(),
}),
inputs: entry.request.inputs.chunks_to_string(),
truncate: entry.request.truncate,
parameters: Some(NextTokenChooserParameters::from(
entry.request.parameters.clone(),
......@@ -293,6 +349,8 @@ impl State {
entry.request.stopping_parameters.clone(),
)),
top_n_tokens: entry.request.top_n_tokens,
blocks,
slots,
});
// Set batch_time
entry.batch_time = Some(Instant::now());
......@@ -335,6 +393,7 @@ impl State {
requests: batch_requests,
size,
max_tokens: (prefill_tokens + decode_tokens),
max_blocks,
};
// Increment batch id
self.next_batch_id += 1;
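To make the new accounting concrete, a tiny standalone sketch with assumed values (block_size = 16, window_size = None, speculate = 0; none of these are taken from this diff): an entry with input_length = 20 and max_new_tokens = 30 adds 20 to prefill_tokens and 30 to decode_tokens, then requests 49 tokens from the allocator; if allocate returns None the entry is pushed back to the front of the queue and the batch is closed, otherwise max_blocks tracks the largest allocation in the batch before being forwarded on the Batch message.

fn main() {
    // Illustrative numbers only; block_size = 16, window_size = None, speculate = 0 assumed.
    let (input_length, max_new_tokens, speculate, block_size): (u32, u32, u32, u32) = (20, 30, 0, 16);

    // What next_batch requests from the allocator for this entry
    let tokens = input_length + max_new_tokens + speculate - 1; // 49
    // What the allocator reserves: ceil(tokens / block_size) blocks and `tokens` slots
    let required_blocks = (tokens + block_size - 1) / block_size; // 4
    println!("{tokens} tokens -> {required_blocks} blocks, {tokens} slots");
}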
......@@ -438,13 +497,14 @@ mod tests {
temp_span: None,
queue_time: Instant::now(),
batch_time: None,
block_allocation: None,
};
(entry, receiver_tx)
}
#[test]
fn test_append() {
let mut state = State::new(false, 1, None, 0);
#[tokio::test]
async fn test_append() {
let mut state = State::new(false, 1, None, 0, 16);
let (entry, _guard) = default_entry();
assert_eq!(state.next_id, 0);
......@@ -458,23 +518,23 @@ mod tests {
assert_eq!(id, 0);
}
#[test]
fn test_next_batch_empty() {
let mut state = State::new(false, 1, None, 0);
#[tokio::test]
async fn test_next_batch_empty() {
let mut state = State::new(false, 1, None, 0, 16);
assert!(state.next_batch(None, None, 1, 1).is_none());
assert!(state.next_batch(Some(1), None, 1, 1).is_none());
assert!(state.next_batch(None, None, 1, 1).await.is_none());
assert!(state.next_batch(Some(1), None, 1, 1).await.is_none());
}
#[test]
fn test_next_batch_min_size() {
let mut state = State::new(false, 1, None, 0);
#[tokio::test]
async fn test_next_batch_min_size() {
let mut state = State::new(false, 1, None, 0, 16);
let (entry1, _guard1) = default_entry();
let (entry2, _guard2) = default_entry();
state.append(entry1);
state.append(entry2);
let (entries, batch, _) = state.next_batch(None, None, 2, 2).unwrap();
let (entries, batch, _) = state.next_batch(None, None, 2, 2).await.unwrap();
assert_eq!(entries.len(), 2);
assert!(entries.contains_key(&0));
assert!(entries.contains_key(&1));
......@@ -490,7 +550,7 @@ mod tests {
let (entry3, _guard3) = default_entry();
state.append(entry3);
assert!(state.next_batch(Some(2), None, 2, 2).is_none());
assert!(state.next_batch(Some(2), None, 2, 2).await.is_none());
assert_eq!(state.next_id, 3);
assert_eq!(state.entries.len(), 1);
......@@ -498,15 +558,15 @@ mod tests {
assert_eq!(id, 2);
}
#[test]
fn test_next_batch_max_size() {
let mut state = State::new(false, 1, None, 0);
#[tokio::test]
async fn test_next_batch_max_size() {
let mut state = State::new(false, 1, None, 0, 16);
let (entry1, _guard1) = default_entry();
let (entry2, _guard2) = default_entry();
state.append(entry1);
state.append(entry2);
let (entries, batch, _) = state.next_batch(None, Some(1), 2, 2).unwrap();
let (entries, batch, _) = state.next_batch(None, Some(1), 2, 2).await.unwrap();
assert_eq!(entries.len(), 1);
assert!(entries.contains_key(&0));
assert!(entries.get(&0).unwrap().batch_time.is_some());
......@@ -518,15 +578,15 @@ mod tests {
assert_eq!(state.next_batch_id, 1);
}
#[test]
fn test_next_batch_token_budget() {
let mut state = State::new(false, 1, None, 0);
#[tokio::test]
async fn test_next_batch_token_budget() {
let mut state = State::new(false, 1, None, 0, 2);
let (entry1, _guard1) = default_entry();
let (entry2, _guard2) = default_entry();
state.append(entry1);
state.append(entry2);
let (entries, batch, _) = state.next_batch(None, None, 1, 1).unwrap();
let (entries, batch, _) = state.next_batch(None, None, 1, 1).await.unwrap();
assert_eq!(entries.len(), 1);
assert!(entries.contains_key(&0));
assert_eq!(batch.id, 0);
......@@ -539,7 +599,7 @@ mod tests {
let (entry3, _guard3) = default_entry();
state.append(entry3);
let (entries, batch, _) = state.next_batch(None, None, 3, 3).unwrap();
let (entries, batch, _) = state.next_batch(None, None, 3, 3).await.unwrap();
assert_eq!(entries.len(), 2);
assert!(entries.contains_key(&1));
assert!(entries.contains_key(&2));
......@@ -553,14 +613,14 @@ mod tests {
#[tokio::test]
async fn test_queue_append() {
let queue = Queue::new(false, 1, None, 0);
let queue = Queue::new(false, 1, None, 0, 16);
let (entry, _guard) = default_entry();
queue.append(entry);
}
#[tokio::test]
async fn test_queue_next_batch_empty() {
let queue = Queue::new(false, 1, None, 0);
let queue = Queue::new(false, 1, None, 0, 16);
assert!(queue.next_batch(None, None, 1, 1).await.is_none());
assert!(queue.next_batch(Some(1), None, 1, 1).await.is_none());
......@@ -568,7 +628,7 @@ mod tests {
#[tokio::test]
async fn test_queue_next_batch_min_size() {
let queue = Queue::new(false, 1, None, 0);
let queue = Queue::new(false, 1, None, 0, 16);
let (entry1, _guard1) = default_entry();
let (entry2, _guard2) = default_entry();
queue.append(entry1);
......@@ -601,7 +661,7 @@ mod tests {
#[tokio::test]
async fn test_queue_next_batch_max_size() {
let queue = Queue::new(false, 1, None, 0);
let queue = Queue::new(false, 1, None, 0, 16);
let (entry1, _guard1) = default_entry();
let (entry2, _guard2) = default_entry();
queue.append(entry1);
......@@ -617,7 +677,7 @@ mod tests {
#[tokio::test]
async fn test_queue_next_batch_token_budget() {
let queue = Queue::new(false, 1, None, 0);
let queue = Queue::new(false, 1, None, 0, 16);
let (entry1, _guard1) = default_entry();
let (entry2, _guard2) = default_entry();
queue.append(entry1);
......@@ -642,7 +702,7 @@ mod tests {
#[tokio::test]
async fn test_queue_next_batch_token_speculate() {
let queue = Queue::new(false, 1, None, 2);
let queue = Queue::new(false, 1, None, 2, 16);
let (entry1, _guard1) = default_entry();
let (entry2, _guard2) = default_entry();
queue.append(entry1);
......@@ -661,7 +721,7 @@ mod tests {
#[tokio::test]
async fn test_queue_next_batch_dropped_receiver() {
let queue = Queue::new(false, 1, None, 0);
let queue = Queue::new(false, 1, None, 0, 16);
let (entry, _) = default_entry();
queue.append(entry);
......
......@@ -39,7 +39,13 @@ impl SchedulerV3 {
speculate: u32,
generation_health: Arc<AtomicBool>,
) -> Self {
let queue = Queue::new(requires_padding, 16, window_size, speculate);
let queue = Queue::new(
requires_padding,
16,
window_size,
speculate,
max_batch_total_tokens,
);
let batching_task_notifier = Arc::new(Notify::new());
// Spawn batching background task that contains all the inference logic
......@@ -81,6 +87,7 @@ impl Scheduler for SchedulerV3 {
temp_span: None,
queue_time: Instant::now(),
batch_time: None,
block_allocation: None,
});
// Notify the background task that we have a new entry in the queue that needs
......
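SchedulerV3 keeps the paged-attention block size hard-coded to 16 and now threads max_batch_total_tokens through to the queue; the queue only creates a BlockAllocator when the model does not require padding, and that allocator manages max_batch_total_tokens / 16 blocks with block 0 held back for health checks. For example, max_batch_total_tokens = 16384 yields 1024 blocks, of which 1023 are allocatable.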
import math
import torch

from typing import Optional, List, Tuple
from text_generation_server.utils.import_utils import SYSTEM

BLOCK_SIZE: int = 16
# Will be set in warmup
CACHE_MANAGER: Optional["CacheManager"] = None


class CacheManager:
    def __init__(
        self,
        num_blocks: int,
        num_layers: int,
        num_heads: int,
        head_size: int,
        repeat_slots: bool,
        dtype: torch.dtype,
        device: torch.device,
    ):
        self.block_size = BLOCK_SIZE
        self.num_blocks = num_blocks
        self.repeat_slots = repeat_slots

        element_size = torch.tensor([], dtype=dtype).element_size()
        if SYSTEM == "xpu":
            x = 1
        else:
            x = self.block_size // element_size

        self.kv_cache = [
            (
                torch.empty(
                    (num_blocks, num_heads, head_size // x, self.block_size, x),
                    dtype=dtype,
                    device=device,
                ),
                torch.empty(
                    (num_blocks, num_heads, head_size, self.block_size),
                    dtype=dtype,
                    device=device,
                ),
            )
            for _ in range(num_layers)
        ]
        self.free_block_mask = torch.ones(num_blocks, dtype=torch.int32, device="cpu")
        self.slots = torch.arange(
            0, num_blocks * self.block_size, dtype=torch.int64
        ).view(num_blocks, self.block_size)

    def allocate(
        self,
        needed_blocks_slots: List[Tuple[int, int]],
        blocks: int,
        max_blocks: int,
        device: torch.device,
    ):
        # Get free blocks indices by finding values in mask that are not set to 0
        free_block_indices = self.free_block_mask.nonzero()
        if blocks > len(free_block_indices):
            raise RuntimeError(
                f"Out of available cache blocks: asked {blocks}, only {len(free_block_indices)} free blocks"
            )

        # Slice by the number of required blocks
        block_indices = free_block_indices[:blocks]
        block_indices = block_indices.flatten()

        # Padded block tables
        block_tables_tensor = torch.zeros(
            (len(needed_blocks_slots), max_blocks), dtype=torch.int32
        )

        # Allocate paged attention blocks
        cumulative_blocks = 0
        slots = []
        block_tables = []
        for i, (needed_blocks, needed_slots) in enumerate(needed_blocks_slots):
            # Get allocated blocks for this sequence
            allocated_blocks = block_indices[
                cumulative_blocks : cumulative_blocks + needed_blocks
            ]
            # Get slots for the allocated blocks
            all_slots = self.slots[allocated_blocks].flatten()

            # Repeat slots in the case of context sliding window
            if needed_slots > len(all_slots) and self.repeat_slots:
                repeats = math.ceil(needed_slots / len(all_slots))
                all_slots = all_slots.repeat(repeats)

            allocated_slots = all_slots[:needed_slots]

            slots.append(allocated_slots)
            block_tables.append(allocated_blocks.tolist())
            block_tables_tensor[i, :needed_blocks] = allocated_blocks
            cumulative_blocks += needed_blocks

        block_tables = block_tables
        block_tables_tensor = block_tables_tensor.to(device)
        slots = torch.concat(slots).to(device)

        # Allocate the required number of blocks by setting the mask to 0
        self.free_block_mask[block_indices] = 0

        return block_tables, block_tables_tensor, slots

    def free(self, block_indices: Optional[List[int]]):
        if block_indices is not None and block_indices:
            # Reset mask
            self.free_block_mask[block_indices] = 1


def set_cache_manager(
    num_blocks: int,
    num_layers: int,
    num_heads: int,
    head_size: int,
    repeat_slots: bool,
    dtype: torch.dtype,
    device: torch.device,
) -> CacheManager:
    global CACHE_MANAGER
    if CACHE_MANAGER is not None:
        del CACHE_MANAGER
        torch.cuda.empty_cache()

    CACHE_MANAGER = CacheManager(
        num_blocks, num_layers, num_heads, head_size, repeat_slots, dtype, device
    )
    return CACHE_MANAGER


def get_cache_manager() -> CacheManager:
    global CACHE_MANAGER
    if CACHE_MANAGER is None:
        raise RuntimeError("cache manager was not initialized")

    return CACHE_MANAGER
......@@ -512,6 +512,7 @@ class FlashCohereForCausalLM(torch.nn.Module):
slots: torch.Tensor,
input_lengths: torch.Tensor,
max_s: int,
prefill_cache_indices: Optional[torch.Tensor],
lm_head_indices: Optional[torch.Tensor] = None,
) -> Tuple[torch.Tensor, Optional[torch.Tensor]]:
hidden_states = self.model(
......
......@@ -834,6 +834,7 @@ class FlashDbrxForCausalLM(torch.nn.Module):
slots: torch.Tensor,
input_lengths: torch.Tensor,
max_s: int,
prefill_cache_indices: Optional[torch.Tensor],
lm_head_indices: Optional[torch.Tensor] = None,
) -> Tuple[torch.Tensor, Optional[torch.Tensor]]:
hidden_states = self.model(
......
......@@ -458,6 +458,7 @@ class FlashGemmaForCausalLM(torch.nn.Module):
slots: torch.Tensor,
input_lengths: torch.Tensor,
max_s: int,
prefill_cache_indices: Optional[torch.Tensor],
lm_head_indices: Optional[torch.Tensor] = None,
) -> Tuple[torch.Tensor, Optional[torch.Tensor]]:
input_embeds = self.embed_tokens(input_ids)
......
......@@ -388,6 +388,7 @@ class FlashGPTNeoXForCausalLM(FlashGPTNeoXPreTrainedModel):
slots: torch.Tensor,
input_lengths: torch.Tensor,
max_s: int,
prefill_cache_indices: Optional[torch.Tensor],
lm_head_indices: Optional[torch.Tensor] = None,
) -> torch.Tensor:
hidden_states = self.gpt_neox(
......
......@@ -398,6 +398,7 @@ class FlashPhiForCausalLM(torch.nn.Module):
slots: torch.Tensor,
input_lengths: torch.Tensor,
max_s: int,
prefill_cache_indices: Optional[torch.Tensor],
lm_head_indices: Optional[torch.Tensor] = None,
) -> torch.Tensor:
hidden_states = self.model(
......
......@@ -670,6 +670,7 @@ class FlashRWForCausalLM(FlashRWPreTrainedModel):
slots: torch.Tensor,
input_lengths: torch.Tensor,
max_s: int,
prefill_cache_indices: Optional[torch.Tensor],
lm_head_indices: Optional[torch.Tensor] = None,
) -> torch.Tensor:
hidden_states = self.transformer(
......
......@@ -482,6 +482,7 @@ class FlashSantacoderForCausalLM(nn.Module):
slots: torch.Tensor,
input_lengths: torch.Tensor,
max_s: int,
prefill_cache_indices: Optional[torch.Tensor],
lm_head_indices: Optional[torch.Tensor] = None,
) -> torch.Tensor:
hidden_states = self.transformer(
......