Unverified commit 5e511e92, authored by Yan Ru Pei, committed by GitHub
Browse files

feat: update active blocks in chunks only when necessary (#1848)

parent 7bc70bc9
......@@ -206,9 +206,9 @@ impl KvRouter {
Ok((best_worker_id, overlap_amount))
}
/// Push a token to a specific request's sequence
pub async fn push(&self, request_id: &String, token: u32) {
self.scheduler.push(request_id, token).await
/// Push tokens to a specific request's sequence
pub async fn push(&self, request_id: &String, tokens: &[u32]) {
self.scheduler.push(request_id, tokens).await
}
/// Free all blocks associated with a request
......@@ -273,6 +273,7 @@ impl AsyncEngine<SingleIn<PreprocessedRequest>, ManyOut<Annotated<LLMEngineOutpu
.await?;
// Update the request with the estimated prefix hit blocks
let (mut backend_input, context) = request.into_parts();
let isl = backend_input.token_ids.len();
backend_input.estimated_prefix_hit_num_blocks = Some(overlap_amount);
let updated_request = context.map(|_| backend_input);
......@@ -283,17 +284,39 @@ impl AsyncEngine<SingleIn<PreprocessedRequest>, ManyOut<Annotated<LLMEngineOutpu
let stream_context = response_stream.context();
let chooser = self.chooser.clone();
let request_id = context_id.clone();
let block_size = chooser.block_size() as usize;
let wrapped_stream = Box::pin(async_stream::stream! {
let mut accumulated_tokens = Vec::new();
let mut total_output_length = 0usize;
let mut last_block_index = (isl.saturating_sub(1)) / block_size;
while let Some(item) = response_stream.next().await {
// Track tokens if they exist in the response
if let Some(ref output) = item.data {
for token_id in &output.token_ids {
chooser.push(&request_id, *token_id).await;
let Some(ref output) = item.data else {
yield item;
continue;
};
if output.token_ids.is_empty() {
yield item;
continue;
}
// Add tokens to accumulator
accumulated_tokens.extend_from_slice(&output.token_ids);
total_output_length += output.token_ids.len();
// Check if we've moved to a new block
let current_block_index = (isl + total_output_length).saturating_sub(1) / block_size;
if current_block_index > last_block_index {
chooser.push(&request_id, &accumulated_tokens).await;
accumulated_tokens.clear();
last_block_index = current_block_index;
}
yield item;
}
chooser.free(&request_id).await;
});
......
......@@ -259,10 +259,10 @@ impl KvScheduler {
sequences.add_request(request_id, token_sequence, worker_id)
}
/// Push a token to a specific request's sequence
pub async fn push(&self, request_id: &String, token: u32) {
/// Push tokens to a specific request's sequence
pub async fn push(&self, request_id: &String, tokens: &[u32]) {
let mut sequences = self.sequences.lock().await;
sequences.push(request_id, token)
sequences.push(request_id, tokens)
}
/// Free all blocks associated with a request
......@@ -401,9 +401,11 @@ impl WorkerSelector for DefaultWorkerSelector {
}
// Normalize by dividing by max value
if max_logit > 0.0 {
for logit in worker_logits.values_mut() {
*logit /= max_logit;
}
}
// Use softmax sampling to select worker
let temperature = self.kv_router_config.router_temperature;
......
......@@ -185,32 +185,50 @@ impl ActiveSequences {
self.active_blocks
}
/// Push a token to a specific request's sequence
pub fn push(&mut self, request_id: &RequestId, token: u32) -> usize {
/// Push tokens to a specific request's sequence
pub fn push(&mut self, request_id: &RequestId, tokens: &[u32]) -> usize {
// Collect operations to perform after releasing the borrow
let mut blocks_to_remove = Vec::new();
let mut blocks_to_add = Vec::new();
{
let token_seq = self
.active_seqs
.get_mut(request_id)
.expect("Request ID not found for token push");
for &token in tokens {
token_seq.append(token).expect("Token push failed.");
// No need to update anything
// Guard: skip if we didn't cross a block boundary
if token_seq.total_tokens() % self.block_size != 1 {
return self.active_blocks;
continue;
}
let last_seq_hash = token_seq
.last_complete_block()
.map(|block| block.sequence_hash());
// Promote a partial block into a full block if not already
// Queue operations for later
if let Some(partial_block) = self.partial_blocks.get(request_id).cloned() {
self.remove_block(request_id, &partial_block);
blocks_to_remove.push(partial_block);
}
if let Some(full_block) = last_seq_hash {
self.add_block(request_id.clone(), &UniqueBlock::FullBlock(full_block));
blocks_to_add.push(UniqueBlock::FullBlock(full_block));
}
blocks_to_add.push(UniqueBlock::default());
}
} // token_seq borrow is dropped here
self.add_block(request_id.clone(), &UniqueBlock::default());
// Now perform all the queued operations
for block in blocks_to_remove {
self.remove_block(request_id, &block);
}
for block in blocks_to_add {
self.add_block(request_id.clone(), &block);
}
self.active_blocks
}
......@@ -227,7 +245,7 @@ enum UpdateSequences {
},
Push {
request_id: RequestId,
token: u32,
tokens: Vec<u32>, // Changed from token: u32
},
NewBlocks {
token_sequence: Arc<TokenBlockSequence>,
......@@ -290,8 +308,8 @@ impl ActiveSequencesMultiWorker {
UpdateSequences::Free { request_id } => {
active_sequences.free(&request_id);
}
UpdateSequences::Push { request_id, token } => {
active_sequences.push(&request_id, token);
UpdateSequences::Push { request_id, tokens } => {
active_sequences.push(&request_id, &tokens); // Changed to pass tokens slice
}
UpdateSequences::NewBlocks {
token_sequence,
......@@ -393,7 +411,7 @@ impl ActiveSequencesMultiWorker {
self.request_to_worker.remove(request_id);
}
pub fn push(&mut self, request_id: &RequestId, token: u32) {
pub fn push(&mut self, request_id: &RequestId, tokens: &[u32]) {
let worker_id = self
.request_to_worker
.get(request_id)
......@@ -402,7 +420,7 @@ impl ActiveSequencesMultiWorker {
self.senders[&worker_id]
.send(UpdateSequences::Push {
request_id: request_id.clone(),
token,
tokens: tokens.to_vec(), // Convert to Vec
})
.expect("Failed to send push command to worker");
}
......@@ -498,8 +516,7 @@ mod tests {
// Step 1: Add request 0 with tokens [0, 1, 2], then push 3 and 4
manager.add_request("0".to_string(), to_sequence(vec![0, 1, 2]));
manager.push(&"0".to_string(), 3);
manager.push(&"0".to_string(), 4);
manager.push(&"0".to_string(), &[3, 4]); // Push both tokens at once
assert_eq!(manager.active_blocks(), 2);
assert_eq!(manager.partial_blocks.len(), 1);
......@@ -551,10 +568,9 @@ mod tests {
// Send request [0, 1, 2, 3] to worker 0
manager.add_request("req0".to_string(), to_sequence(vec![0, 1, 2, 3]), 0);
// Send request [0, 1, 2] to worker 1, then push 3 and push 4
// Send request [0, 1, 2] to worker 1, then push 3 and 4
manager.add_request("req1".to_string(), to_sequence(vec![0, 1, 2]), 1);
manager.push(&"req1".to_string(), 3);
manager.push(&"req1".to_string(), 4);
manager.push(&"req1".to_string(), &[3, 4]); // Push both tokens at once
// Send request [0, 1, 2] to worker 2
manager.add_request("req2".to_string(), to_sequence(vec![0, 1, 2]), 2);
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment