"vscode:/vscode.git/clone" did not exist on "a60cdf59ed4baf95d472736c4ee23fe33062d2ca"
Unverified commit 5e511e92, authored by Yan Ru Pei, committed by GitHub
Browse files

feat: update active blocks in chunks only when necessary (#1848)

parent 7bc70bc9
...@@ -206,9 +206,9 @@ impl KvRouter { ...@@ -206,9 +206,9 @@ impl KvRouter {
Ok((best_worker_id, overlap_amount)) Ok((best_worker_id, overlap_amount))
} }
/// Push a token to a specific request's sequence /// Push tokens to a specific request's sequence
pub async fn push(&self, request_id: &String, token: u32) { pub async fn push(&self, request_id: &String, tokens: &[u32]) {
self.scheduler.push(request_id, token).await self.scheduler.push(request_id, tokens).await
} }
/// Free all blocks associated with a request /// Free all blocks associated with a request
...@@ -273,6 +273,7 @@ impl AsyncEngine<SingleIn<PreprocessedRequest>, ManyOut<Annotated<LLMEngineOutpu ...@@ -273,6 +273,7 @@ impl AsyncEngine<SingleIn<PreprocessedRequest>, ManyOut<Annotated<LLMEngineOutpu
.await?; .await?;
// Update the request with the estimated prefix hit blocks // Update the request with the estimated prefix hit blocks
let (mut backend_input, context) = request.into_parts(); let (mut backend_input, context) = request.into_parts();
let isl = backend_input.token_ids.len();
backend_input.estimated_prefix_hit_num_blocks = Some(overlap_amount); backend_input.estimated_prefix_hit_num_blocks = Some(overlap_amount);
let updated_request = context.map(|_| backend_input); let updated_request = context.map(|_| backend_input);
...@@ -283,17 +284,39 @@ impl AsyncEngine<SingleIn<PreprocessedRequest>, ManyOut<Annotated<LLMEngineOutpu ...@@ -283,17 +284,39 @@ impl AsyncEngine<SingleIn<PreprocessedRequest>, ManyOut<Annotated<LLMEngineOutpu
let stream_context = response_stream.context(); let stream_context = response_stream.context();
let chooser = self.chooser.clone(); let chooser = self.chooser.clone();
let request_id = context_id.clone(); let request_id = context_id.clone();
let block_size = chooser.block_size() as usize;
let wrapped_stream = Box::pin(async_stream::stream! { let wrapped_stream = Box::pin(async_stream::stream! {
let mut accumulated_tokens = Vec::new();
let mut total_output_length = 0usize;
let mut last_block_index = (isl.saturating_sub(1)) / block_size;
while let Some(item) = response_stream.next().await { while let Some(item) = response_stream.next().await {
// Track tokens if they exist in the response // Track tokens if they exist in the response
if let Some(ref output) = item.data { let Some(ref output) = item.data else {
for token_id in &output.token_ids { yield item;
chooser.push(&request_id, *token_id).await; continue;
};
if output.token_ids.is_empty() {
yield item;
continue;
} }
// Add tokens to accumulator
accumulated_tokens.extend_from_slice(&output.token_ids);
total_output_length += output.token_ids.len();
// Check if we've moved to a new block
let current_block_index = (isl + total_output_length).saturating_sub(1) / block_size;
if current_block_index > last_block_index {
chooser.push(&request_id, &accumulated_tokens).await;
accumulated_tokens.clear();
last_block_index = current_block_index;
} }
yield item; yield item;
} }
chooser.free(&request_id).await; chooser.free(&request_id).await;
}); });
......
...@@ -259,10 +259,10 @@ impl KvScheduler { ...@@ -259,10 +259,10 @@ impl KvScheduler {
sequences.add_request(request_id, token_sequence, worker_id) sequences.add_request(request_id, token_sequence, worker_id)
} }
/// Push a token to a specific request's sequence /// Push tokens to a specific request's sequence
pub async fn push(&self, request_id: &String, token: u32) { pub async fn push(&self, request_id: &String, tokens: &[u32]) {
let mut sequences = self.sequences.lock().await; let mut sequences = self.sequences.lock().await;
sequences.push(request_id, token) sequences.push(request_id, tokens)
} }
/// Free all blocks associated with a request /// Free all blocks associated with a request
...@@ -401,9 +401,11 @@ impl WorkerSelector for DefaultWorkerSelector { ...@@ -401,9 +401,11 @@ impl WorkerSelector for DefaultWorkerSelector {
} }
// Normalize by dividing by max value // Normalize by dividing by max value
if max_logit > 0.0 {
for logit in worker_logits.values_mut() { for logit in worker_logits.values_mut() {
*logit /= max_logit; *logit /= max_logit;
} }
}
// Use softmax sampling to select worker // Use softmax sampling to select worker
let temperature = self.kv_router_config.router_temperature; let temperature = self.kv_router_config.router_temperature;
......
...@@ -185,32 +185,50 @@ impl ActiveSequences { ...@@ -185,32 +185,50 @@ impl ActiveSequences {
self.active_blocks self.active_blocks
} }
/// Push a token to a specific request's sequence /// Push tokens to a specific request's sequence
pub fn push(&mut self, request_id: &RequestId, token: u32) -> usize { pub fn push(&mut self, request_id: &RequestId, tokens: &[u32]) -> usize {
// Collect operations to perform after releasing the borrow
let mut blocks_to_remove = Vec::new();
let mut blocks_to_add = Vec::new();
{
let token_seq = self let token_seq = self
.active_seqs .active_seqs
.get_mut(request_id) .get_mut(request_id)
.expect("Request ID not found for token push"); .expect("Request ID not found for token push");
for &token in tokens {
token_seq.append(token).expect("Token push failed."); token_seq.append(token).expect("Token push failed.");
// No need to update anything // Guard: skip if we didn't cross a block boundary
if token_seq.total_tokens() % self.block_size != 1 { if token_seq.total_tokens() % self.block_size != 1 {
return self.active_blocks; continue;
} }
let last_seq_hash = token_seq let last_seq_hash = token_seq
.last_complete_block() .last_complete_block()
.map(|block| block.sequence_hash()); .map(|block| block.sequence_hash());
// Promote a partial block into a full block if not already // Queue operations for later
if let Some(partial_block) = self.partial_blocks.get(request_id).cloned() { if let Some(partial_block) = self.partial_blocks.get(request_id).cloned() {
self.remove_block(request_id, &partial_block); blocks_to_remove.push(partial_block);
} }
if let Some(full_block) = last_seq_hash { if let Some(full_block) = last_seq_hash {
self.add_block(request_id.clone(), &UniqueBlock::FullBlock(full_block)); blocks_to_add.push(UniqueBlock::FullBlock(full_block));
}
blocks_to_add.push(UniqueBlock::default());
} }
} // token_seq borrow is dropped here
self.add_block(request_id.clone(), &UniqueBlock::default()); // Now perform all the queued operations
for block in blocks_to_remove {
self.remove_block(request_id, &block);
}
for block in blocks_to_add {
self.add_block(request_id.clone(), &block);
}
self.active_blocks self.active_blocks
} }
...@@ -227,7 +245,7 @@ enum UpdateSequences { ...@@ -227,7 +245,7 @@ enum UpdateSequences {
}, },
Push { Push {
request_id: RequestId, request_id: RequestId,
token: u32, tokens: Vec<u32>, // Changed from token: u32
}, },
NewBlocks { NewBlocks {
token_sequence: Arc<TokenBlockSequence>, token_sequence: Arc<TokenBlockSequence>,
...@@ -290,8 +308,8 @@ impl ActiveSequencesMultiWorker { ...@@ -290,8 +308,8 @@ impl ActiveSequencesMultiWorker {
UpdateSequences::Free { request_id } => { UpdateSequences::Free { request_id } => {
active_sequences.free(&request_id); active_sequences.free(&request_id);
} }
UpdateSequences::Push { request_id, token } => { UpdateSequences::Push { request_id, tokens } => {
active_sequences.push(&request_id, token); active_sequences.push(&request_id, &tokens); // Changed to pass tokens slice
} }
UpdateSequences::NewBlocks { UpdateSequences::NewBlocks {
token_sequence, token_sequence,
...@@ -393,7 +411,7 @@ impl ActiveSequencesMultiWorker { ...@@ -393,7 +411,7 @@ impl ActiveSequencesMultiWorker {
self.request_to_worker.remove(request_id); self.request_to_worker.remove(request_id);
} }
pub fn push(&mut self, request_id: &RequestId, token: u32) { pub fn push(&mut self, request_id: &RequestId, tokens: &[u32]) {
let worker_id = self let worker_id = self
.request_to_worker .request_to_worker
.get(request_id) .get(request_id)
...@@ -402,7 +420,7 @@ impl ActiveSequencesMultiWorker { ...@@ -402,7 +420,7 @@ impl ActiveSequencesMultiWorker {
self.senders[&worker_id] self.senders[&worker_id]
.send(UpdateSequences::Push { .send(UpdateSequences::Push {
request_id: request_id.clone(), request_id: request_id.clone(),
token, tokens: tokens.to_vec(), // Convert to Vec
}) })
.expect("Failed to send push command to worker"); .expect("Failed to send push command to worker");
} }
...@@ -498,8 +516,7 @@ mod tests { ...@@ -498,8 +516,7 @@ mod tests {
// Step 1: Add request 0 with tokens [0, 1, 2], then push 3 and 4 // Step 1: Add request 0 with tokens [0, 1, 2], then push 3 and 4
manager.add_request("0".to_string(), to_sequence(vec![0, 1, 2])); manager.add_request("0".to_string(), to_sequence(vec![0, 1, 2]));
manager.push(&"0".to_string(), 3); manager.push(&"0".to_string(), &[3, 4]); // Push both tokens at once
manager.push(&"0".to_string(), 4);
assert_eq!(manager.active_blocks(), 2); assert_eq!(manager.active_blocks(), 2);
assert_eq!(manager.partial_blocks.len(), 1); assert_eq!(manager.partial_blocks.len(), 1);
...@@ -551,10 +568,9 @@ mod tests { ...@@ -551,10 +568,9 @@ mod tests {
// Send request [0, 1, 2, 3] to worker 0 // Send request [0, 1, 2, 3] to worker 0
manager.add_request("req0".to_string(), to_sequence(vec![0, 1, 2, 3]), 0); manager.add_request("req0".to_string(), to_sequence(vec![0, 1, 2, 3]), 0);
// Send request [0, 1, 2] to worker 1, then push 3 and push 4 // Send request [0, 1, 2] to worker 1, then push 3 and 4
manager.add_request("req1".to_string(), to_sequence(vec![0, 1, 2]), 1); manager.add_request("req1".to_string(), to_sequence(vec![0, 1, 2]), 1);
manager.push(&"req1".to_string(), 3); manager.push(&"req1".to_string(), &[3, 4]); // Push both tokens at once
manager.push(&"req1".to_string(), 4);
// Send request [0, 1, 2] to worker 2 // Send request [0, 1, 2] to worker 2
manager.add_request("req2".to_string(), to_sequence(vec![0, 1, 2]), 2); manager.add_request("req2".to_string(), to_sequence(vec![0, 1, 2]), 2);
......
Markdown is supported
0% Try again or attach a new file.
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment