Unverified Commit 66231cf0 authored by Yan Ru Pei's avatar Yan Ru Pei Committed by GitHub
Browse files

feat: reduce / revert routing overheads, do not consider output tokens (#2182)

parent dbd33df6
...@@ -31,8 +31,8 @@ use crate::{ ...@@ -31,8 +31,8 @@ use crate::{
kv_router::{ kv_router::{
approx::ApproxKvIndexer, approx::ApproxKvIndexer,
indexer::{ indexer::{
compute_block_hash_for_seq, KvIndexer, KvIndexerInterface, KvRouterError, compute_block_hash_for_seq, compute_seq_hash_for_block, KvIndexer, KvIndexerInterface,
OverlapScores, RouterEvent, KvRouterError, OverlapScores, RouterEvent,
}, },
// metrics_aggregator::EndpointCollector, // metrics_aggregator::EndpointCollector,
protocols::{LocalBlockHash, RouterRequest, RouterResponse, WorkerSelectionResult}, protocols::{LocalBlockHash, RouterRequest, RouterResponse, WorkerSelectionResult},
...@@ -71,7 +71,8 @@ pub struct KvRouterConfig { ...@@ -71,7 +71,8 @@ pub struct KvRouterConfig {
pub use_kv_events: bool, pub use_kv_events: bool,
// note: this is not actually used for now // TODO: this is not actually used for now
// Would need this (along with total kv blocks) to trigger AllWorkersBusy error for e.g. rate-limiting
pub max_num_batched_tokens: u32, pub max_num_batched_tokens: u32,
} }
...@@ -231,25 +232,25 @@ impl KvRouter { ...@@ -231,25 +232,25 @@ impl KvRouter {
let _guard = self.find_best_match_mutex.lock().await; let _guard = self.find_best_match_mutex.lock().await;
let isl_tokens = tokens.len(); let isl_tokens = tokens.len();
let block_size = self.block_size;
let local_block_hashes = compute_block_hash_for_seq(tokens, self.block_size); let block_hashes = compute_block_hash_for_seq(tokens, self.block_size);
let overlap_scores = self.indexer.find_matches(local_block_hashes).await?; let seq_hashes = compute_seq_hash_for_block(&block_hashes);
let overlap_scores = self.indexer.find_matches(block_hashes.clone()).await?;
let best_worker_id = self let best_worker_id = self
.scheduler .scheduler
.schedule( .schedule(
context_id.to_string(), context_id.to_string(),
isl_tokens, isl_tokens,
block_size, seq_hashes.clone(),
tokens,
overlap_scores.clone(), overlap_scores.clone(),
) )
.await?; .await?;
if let Indexer::ApproxKvIndexer(ref indexer) = self.indexer { if let Indexer::ApproxKvIndexer(ref indexer) = self.indexer {
indexer indexer
.process_routing_decision_for_request(tokens, best_worker_id) .process_routing_decision(best_worker_id, block_hashes, seq_hashes)
.await .await
.unwrap(); .unwrap();
}; };
...@@ -262,9 +263,9 @@ impl KvRouter { ...@@ -262,9 +263,9 @@ impl KvRouter {
Ok((best_worker_id, overlap_amount)) Ok((best_worker_id, overlap_amount))
} }
/// Push tokens to a specific request's sequence /// Mark prefill as completed for a request
pub async fn push(&self, request_id: &String, tokens: &[u32]) { pub async fn mark_prefill_completed(&self, request_id: &String) {
self.scheduler.push(request_id, tokens).await self.scheduler.mark_prefill_completed(request_id).await
} }
/// Free all blocks associated with a request /// Free all blocks associated with a request
...@@ -331,7 +332,6 @@ impl AsyncEngine<SingleIn<PreprocessedRequest>, ManyOut<Annotated<LLMEngineOutpu ...@@ -331,7 +332,6 @@ impl AsyncEngine<SingleIn<PreprocessedRequest>, ManyOut<Annotated<LLMEngineOutpu
let stream_context = request.context().clone(); let stream_context = request.context().clone();
// Update the request with the estimated prefix hit blocks // Update the request with the estimated prefix hit blocks
let (mut backend_input, context) = request.into_parts(); let (mut backend_input, context) = request.into_parts();
let isl = backend_input.token_ids.len();
backend_input.estimated_prefix_hit_num_blocks = Some(overlap_amount); backend_input.estimated_prefix_hit_num_blocks = Some(overlap_amount);
let updated_request = context.map(|_| backend_input); let updated_request = context.map(|_| backend_input);
...@@ -345,55 +345,22 @@ impl AsyncEngine<SingleIn<PreprocessedRequest>, ManyOut<Annotated<LLMEngineOutpu ...@@ -345,55 +345,22 @@ impl AsyncEngine<SingleIn<PreprocessedRequest>, ManyOut<Annotated<LLMEngineOutpu
let stream = stream::iter(vec![response]); let stream = stream::iter(vec![response]);
return Ok(ResponseStream::new(Box::pin(stream), stream_context)); return Ok(ResponseStream::new(Box::pin(stream), stream_context));
} }
// Get the response stream from the worker
let mut response_stream = self.inner.direct(updated_request, instance_id).await?;
// Wrap the stream to track tokens let mut response_stream = self.inner.direct(updated_request, instance_id).await?;
let stream_context = response_stream.context(); let stream_context = response_stream.context();
let chooser = self.chooser.clone(); let chooser = self.chooser.clone();
let request_id = context_id.clone();
let block_size = chooser.block_size() as usize;
let wrapped_stream = Box::pin(async_stream::stream! { let wrapped_stream = Box::pin(async_stream::stream! {
let mut accumulated_tokens = Vec::new(); if let Some(first_item) = response_stream.next().await {
let mut total_output_length = 0usize; chooser.mark_prefill_completed(&context_id).await;
let mut last_block_index = (isl.saturating_sub(1)) / block_size; yield first_item;
let mut first_push_done = false; }
while let Some(item) = response_stream.next().await { while let Some(item) = response_stream.next().await {
// Track tokens if they exist in the response
let Some(ref output) = item.data else {
yield item;
continue;
};
if output.token_ids.is_empty() {
yield item;
continue;
}
// Add tokens to accumulator
accumulated_tokens.extend_from_slice(&output.token_ids);
total_output_length += output.token_ids.len();
// Always push for the first generated token (to mark prefill done)
// or when we've moved to a new block
let current_block_index = (isl + total_output_length).saturating_sub(1) / block_size;
let should_push = (!first_push_done && total_output_length >= 1) ||
(first_push_done && current_block_index > last_block_index);
if should_push {
chooser.push(&request_id, &accumulated_tokens).await;
accumulated_tokens.clear();
last_block_index = current_block_index;
if !first_push_done {
first_push_done = true;
}
}
yield item; yield item;
} }
chooser.free(&request_id).await; chooser.free(&context_id).await;
}); });
Ok(ResponseStream::new(wrapped_stream, stream_context)) Ok(ResponseStream::new(wrapped_stream, stream_context))
} }
......
...@@ -23,7 +23,7 @@ use tokio::sync::{mpsc, oneshot}; ...@@ -23,7 +23,7 @@ use tokio::sync::{mpsc, oneshot};
use tokio::time::{Duration, Instant}; use tokio::time::{Duration, Instant};
use tokio_util::sync::CancellationToken; use tokio_util::sync::CancellationToken;
use crate::tokens::TokenBlockSequence; use crate::tokens::{SequenceHash, TokenBlockSequence};
use crate::kv_router::indexer::{ use crate::kv_router::indexer::{
compute_block_hash_for_seq, DumpRequest, KvIndexerInterface, KvRouterError, OverlapScores, compute_block_hash_for_seq, DumpRequest, KvIndexerInterface, KvRouterError, OverlapScores,
...@@ -295,6 +295,26 @@ impl ApproxKvIndexer { ...@@ -295,6 +295,26 @@ impl ApproxKvIndexer {
self.kv_block_size self.kv_block_size
} }
/// Records a routing decision using hashes the caller has already computed.
///
/// Packs `worker_id`, `local_hashes` and `sequence_hashes` into a
/// `RouterResult` and sends it on `route_tx`.
///
/// # Errors
///
/// Returns [`KvRouterError::IndexerDroppedRequest`] if the send on `route_tx`
/// fails (presumably because the receiving side has gone away — confirm
/// against the channel type).
pub async fn process_routing_decision(
    &self,
    worker_id: WorkerId,
    local_hashes: Vec<LocalBlockHash>,
    sequence_hashes: Vec<SequenceHash>,
) -> Result<(), KvRouterError> {
    let decision = RouterResult {
        worker_id,
        local_hashes,
        sequence_hashes,
    };

    // The underlying send error carries no extra information we expose;
    // any failure is reported as a dropped request.
    match self.route_tx.send(decision).await {
        Ok(_) => Ok(()),
        Err(_) => Err(KvRouterError::IndexerDroppedRequest),
    }
}
/// Wrapper function that computes hashes from tokens and calls the core function
pub async fn process_routing_decision_for_request( pub async fn process_routing_decision_for_request(
&self, &self,
tokens: &[u32], tokens: &[u32],
...@@ -309,16 +329,8 @@ impl ApproxKvIndexer { ...@@ -309,16 +329,8 @@ impl ApproxKvIndexer {
.map(|b| b.sequence_hash()) .map(|b| b.sequence_hash())
.collect::<Vec<_>>(); .collect::<Vec<_>>();
self.route_tx self.process_routing_decision(worker_id, local_hashes, sequence_hashes)
.send(RouterResult {
worker_id,
local_hashes,
sequence_hashes,
})
.await .await
.map_err(|_| KvRouterError::IndexerDroppedRequest)?;
Ok(())
} }
} }
......
...@@ -63,6 +63,7 @@ use xxhash_rust::xxh3; ...@@ -63,6 +63,7 @@ use xxhash_rust::xxh3;
pub const XXH3_SEED: u64 = 1337; pub const XXH3_SEED: u64 = 1337;
use crate::kv_router::protocols::*; use crate::kv_router::protocols::*;
use crate::tokens::SequenceHash;
/// Errors that can occur in the KV Router. /// Errors that can occur in the KV Router.
#[derive(Debug, thiserror::Error)] #[derive(Debug, thiserror::Error)]
...@@ -133,6 +134,40 @@ pub fn compute_block_hash_for_seq(tokens: &[u32], kv_block_size: u32) -> Vec<Loc ...@@ -133,6 +134,40 @@ pub fn compute_block_hash_for_seq(tokens: &[u32], kv_block_size: u32) -> Vec<Loc
.collect() .collect()
} }
/// Compute rolling sequence hashes for a slice of block hashes.
///
/// This mirrors the behavior in tokens.rs where:
/// - The first block's sequence hash equals its block hash
/// - Each subsequent block's sequence hash is `compute_hash` over the
///   little-endian bytes of the parent sequence hash followed by the
///   little-endian bytes of the current block hash
///
/// ### Arguments
///
/// * `block_hashes` - A slice of `LocalBlockHash` values representing the block hashes.
///
/// ### Returns
///
/// A vector of u64 values representing the sequence hashes for each block,
/// in the same order as `block_hashes`. An empty input yields an empty vector.
pub fn compute_seq_hash_for_block(block_hashes: &[LocalBlockHash]) -> Vec<SequenceHash> {
    let mut sequence_hashes = Vec::with_capacity(block_hashes.len());
    // Sequence hash of the previous block; `None` until the first block is seen.
    let mut parent: Option<SequenceHash> = None;

    for block in block_hashes {
        let seq_hash = match parent {
            // The chain is seeded with the first block's own hash.
            None => block.0,
            Some(parent_seq_hash) => {
                // Fixed-size stack buffer instead of a per-block heap allocation.
                // The explicit `[u8; 8]` annotations make a width change a
                // compile error rather than a runtime length mismatch.
                let parent_bytes: [u8; 8] = parent_seq_hash.to_le_bytes();
                let block_bytes: [u8; 8] = block.0.to_le_bytes();
                let mut bytes = [0u8; 16];
                bytes[..8].copy_from_slice(&parent_bytes);
                bytes[8..].copy_from_slice(&block_bytes);
                compute_hash(&bytes)
            }
        };
        sequence_hashes.push(seq_hash);
        parent = Some(seq_hash);
    }

    sequence_hashes
}
/// A [`KvCacheEvent`] on a specific LLM worker denoted by [`WorkerId`]. /// A [`KvCacheEvent`] on a specific LLM worker denoted by [`WorkerId`].
#[derive(Debug, Clone, Serialize, Deserialize)] #[derive(Debug, Clone, Serialize, Deserialize)]
pub struct RouterEvent { pub struct RouterEvent {
......
...@@ -29,7 +29,7 @@ use crate::kv_router::protocols::LoadMetrics; ...@@ -29,7 +29,7 @@ use crate::kv_router::protocols::LoadMetrics;
use crate::kv_router::sequence::ActiveSequencesMultiWorker; use crate::kv_router::sequence::ActiveSequencesMultiWorker;
use crate::kv_router::KvRouterConfig; use crate::kv_router::KvRouterConfig;
use crate::kv_router::KV_HIT_RATE_SUBJECT; use crate::kv_router::KV_HIT_RATE_SUBJECT;
use crate::tokens::TokenBlockSequence; use crate::tokens::SequenceHash;
use dynamo_runtime::component::Instance; use dynamo_runtime::component::Instance;
#[derive(Debug, Clone, Serialize, Deserialize)] #[derive(Debug, Clone, Serialize, Deserialize)]
...@@ -217,15 +217,13 @@ impl KvScheduler { ...@@ -217,15 +217,13 @@ impl KvScheduler {
&self, &self,
request_id: String, request_id: String,
isl_tokens: usize, isl_tokens: usize,
block_size: u32, token_seq: Vec<SequenceHash>,
tokens: &[u32],
overlaps: OverlapScores, overlaps: OverlapScores,
) -> Result<i64, KvSchedulerError> { ) -> Result<i64, KvSchedulerError> {
let mut sequences = self.sequences.lock().await; let mut sequences = self.sequences.lock().await;
let token_sequence = TokenBlockSequence::from_slice(tokens, block_size, None);
let (potential_blocks, potential_tokens) = let (potential_blocks, potential_tokens) =
sequences.potential_blocks_and_tokens(token_sequence, overlaps.clone()); sequences.potential_blocks_and_tokens(token_seq.clone(), isl_tokens, overlaps.clone());
let (resp_tx, resp_rx) = tokio::sync::oneshot::channel(); let (resp_tx, resp_rx) = tokio::sync::oneshot::channel();
let request = SchedulingRequest { let request = SchedulingRequest {
...@@ -247,10 +245,10 @@ impl KvScheduler { ...@@ -247,10 +245,10 @@ impl KvScheduler {
sequences.update_workers(new_worker_ids); sequences.update_workers(new_worker_ids);
} }
let token_sequence = TokenBlockSequence::from_slice(tokens, block_size, None);
sequences.add_request( sequences.add_request(
request_id, request_id,
token_sequence, token_seq,
isl_tokens,
response.overlap_blocks, response.overlap_blocks,
response.best_worker_id, response.best_worker_id,
); );
...@@ -258,10 +256,9 @@ impl KvScheduler { ...@@ -258,10 +256,9 @@ impl KvScheduler {
Ok(response.best_worker_id) Ok(response.best_worker_id)
} }
/// Push tokens to a specific request's sequence pub async fn mark_prefill_completed(&self, request_id: &String) {
pub async fn push(&self, request_id: &String, tokens: &[u32]) {
let mut sequences = self.sequences.lock().await; let mut sequences = self.sequences.lock().await;
sequences.push(request_id, tokens) sequences.mark_prefill_completed(request_id)
} }
/// Free all blocks associated with a request /// Free all blocks associated with a request
......
...@@ -36,50 +36,24 @@ ...@@ -36,50 +36,24 @@
use crate::kv_router::indexer::OverlapScores; use crate::kv_router::indexer::OverlapScores;
use crate::kv_router::indexer::WorkerId; use crate::kv_router::indexer::WorkerId;
use crate::tokens::blocks::UniqueBlock; use crate::tokens::SequenceHash;
use crate::tokens::TokenBlockSequence;
use derive_getters::Getters; use derive_getters::Getters;
use std::collections::{HashMap, HashSet}; use std::collections::{HashMap, HashSet};
use std::sync::{mpsc, Arc}; use std::sync::{mpsc, Arc};
use std::thread; use std::thread;
use std::time::Duration; use std::time::Duration;
use uuid;
// TODO: use the common request_id if it exists in the repo // TODO: use the common request_id if it exists in the repo
pub type RequestId = String; pub type RequestId = String;
/// Create unique blocks from a TokenBlockSequence
fn create_unique_blocks_from_sequence(
tokens: &TokenBlockSequence,
uuid: Option<uuid::Uuid>,
block_size: usize,
) -> Vec<UniqueBlock> {
let mut unique_blocks: Vec<UniqueBlock> = tokens
.blocks()
.iter()
.map(|block| UniqueBlock::FullBlock(block.sequence_hash()))
.collect();
// Only push the partial block if tokens count isn't a multiple of block_size
if tokens.total_tokens() % block_size != 0 {
unique_blocks.push(match uuid {
Some(uuid) => UniqueBlock::PartialBlock(uuid),
None => UniqueBlock::default(),
});
}
unique_blocks
}
/// A multi-request sequence manager that handles multiple active sequences with shared KV cache /// A multi-request sequence manager that handles multiple active sequences with shared KV cache
#[derive(Debug, Getters)] #[derive(Debug, Getters)]
pub struct ActiveSequences { pub struct ActiveSequences {
active_seqs: HashMap<RequestId, TokenBlockSequence>, active_seqs: HashMap<RequestId, Vec<SequenceHash>>,
partial_blocks: HashMap<RequestId, UniqueBlock>,
prefill_tokens: HashMap<RequestId, usize>, prefill_tokens: HashMap<RequestId, usize>,
unique_blocks: HashMap<UniqueBlock, HashSet<RequestId>>, unique_blocks: HashMap<SequenceHash, HashSet<RequestId>>,
#[getter(copy)] #[getter(copy)]
block_size: usize, block_size: usize,
...@@ -99,7 +73,6 @@ impl ActiveSequences { ...@@ -99,7 +73,6 @@ impl ActiveSequences {
Self { Self {
active_seqs: HashMap::new(), active_seqs: HashMap::new(),
partial_blocks: HashMap::new(),
prefill_tokens: HashMap::new(), prefill_tokens: HashMap::new(),
unique_blocks: HashMap::new(), unique_blocks: HashMap::new(),
block_size, block_size,
...@@ -108,24 +81,20 @@ impl ActiveSequences { ...@@ -108,24 +81,20 @@ impl ActiveSequences {
} }
} }
fn add_block(&mut self, request_id: RequestId, block: &UniqueBlock) { fn add_block(&mut self, request_id: RequestId, block: &SequenceHash) {
let is_new_block = !self.unique_blocks.contains_key(block); let is_new_block = !self.unique_blocks.contains_key(block);
self.unique_blocks self.unique_blocks
.entry(block.clone()) .entry(*block)
.or_default() .or_default()
.insert(request_id.clone()); .insert(request_id.clone());
if is_new_block { if is_new_block {
self.active_blocks += 1; self.active_blocks += 1;
} }
if matches!(block, UniqueBlock::PartialBlock(_)) {
self.partial_blocks.insert(request_id, block.clone());
};
} }
fn remove_block(&mut self, request_id: &RequestId, block: &UniqueBlock) { fn remove_block(&mut self, request_id: &RequestId, block: &SequenceHash) {
let Some(request_ids) = self.unique_blocks.get_mut(block) else { let Some(request_ids) = self.unique_blocks.get_mut(block) else {
panic!("Cannot remove a block that does not exist.") panic!("Cannot remove a block that does not exist.")
}; };
...@@ -142,17 +111,16 @@ impl ActiveSequences { ...@@ -142,17 +111,16 @@ impl ActiveSequences {
pub fn add_request( pub fn add_request(
&mut self, &mut self,
request_id: RequestId, request_id: RequestId,
token_sequence: TokenBlockSequence, token_sequence: Vec<SequenceHash>,
isl: usize,
overlap: u32, overlap: u32,
) -> usize { ) -> usize {
let prefill_tokens = self.new_tokens(&token_sequence, overlap); let prefill_tokens = self.new_tokens(isl, overlap);
self.prefill_tokens self.prefill_tokens
.insert(request_id.clone(), prefill_tokens); .insert(request_id.clone(), prefill_tokens);
self.active_tokens += prefill_tokens; self.active_tokens += prefill_tokens;
let blocks = create_unique_blocks_from_sequence(&token_sequence, None, self.block_size); for block in &token_sequence {
for block in &blocks {
self.add_block(request_id.clone(), block); self.add_block(request_id.clone(), block);
} }
...@@ -161,30 +129,35 @@ impl ActiveSequences { ...@@ -161,30 +129,35 @@ impl ActiveSequences {
self.active_blocks self.active_blocks
} }
pub fn new_tokens(&self, token_sequence: &TokenBlockSequence, overlap: u32) -> usize { /// Mark prefill as completed for a request, removing it from prefill_tokens tracking
let input_tokens = token_sequence.total_tokens(); pub fn mark_prefill_completed(&mut self, request_id: &RequestId) {
input_tokens if let Some(tokens) = self.prefill_tokens.remove(request_id) {
.checked_sub((overlap as usize) * self.block_size) self.active_tokens = self
.unwrap_or_else(|| { .active_tokens
panic!("prefill_tokens < 0 with overlap {overlap} and ISL {input_tokens}") .checked_sub(tokens.saturating_sub(1)) // Keep 1 token for decoding
}) .expect("active_tokens underflow");
}
}
pub fn new_tokens(&self, isl: usize, overlap: u32) -> usize {
isl.checked_sub((overlap as usize) * self.block_size)
.unwrap_or_else(|| panic!("prefill_tokens < 0 with overlap {overlap} and ISL {isl}"))
} }
pub fn potential_blocks_and_tokens( pub fn potential_blocks_and_tokens(
&self, &self,
token_sequence: &TokenBlockSequence, token_sequence: &[SequenceHash],
isl: usize,
overlap: u32, overlap: u32,
) -> (usize, usize) { ) -> (usize, usize) {
let potential_blocks = self.new_blocks(token_sequence) + self.active_blocks; let potential_blocks = self.new_blocks(token_sequence) + self.active_blocks;
let potential_tokens = self.new_tokens(token_sequence, overlap) + self.active_tokens; let potential_tokens = self.new_tokens(isl, overlap) + self.active_tokens;
(potential_blocks, potential_tokens) (potential_blocks, potential_tokens)
} }
/// Match a request against existing blocks and return the number of new blocks that would be added /// Match a request against existing blocks and return the number of new blocks that would be added
pub fn new_blocks(&self, token_sequence: &TokenBlockSequence) -> usize { pub fn new_blocks(&self, token_sequence: &[SequenceHash]) -> usize {
let blocks = create_unique_blocks_from_sequence(token_sequence, None, self.block_size); token_sequence
blocks
.iter() .iter()
.filter(|block| !self.unique_blocks.contains_key(block)) .filter(|block| !self.unique_blocks.contains_key(block))
.count() .count()
...@@ -192,7 +165,7 @@ impl ActiveSequences { ...@@ -192,7 +165,7 @@ impl ActiveSequences {
/// Return the total number of blocks that would be used if the token sequence was added /// Return the total number of blocks that would be used if the token sequence was added
/// This is the sum of new blocks that would be added plus the current active blocks /// This is the sum of new blocks that would be added plus the current active blocks
pub fn potential_blocks(&self, token_sequence: &TokenBlockSequence) -> usize { pub fn potential_blocks(&self, token_sequence: &[SequenceHash]) -> usize {
self.new_blocks(token_sequence) + self.active_blocks self.new_blocks(token_sequence) + self.active_blocks
} }
...@@ -209,110 +182,49 @@ impl ActiveSequences { ...@@ -209,110 +182,49 @@ impl ActiveSequences {
return 0; return 0;
}; };
let blocks = create_unique_blocks_from_sequence(token_seq, None, self.block_size); for block in token_seq.clone() {
for block in blocks { self.remove_block(request_id, &block)
if matches!(block, UniqueBlock::FullBlock(_)) {
self.remove_block(request_id, &block);
}
}
if let Some(partial_block) = self.partial_blocks.remove(request_id) {
self.remove_block(request_id, &partial_block);
} }
self.active_seqs.remove(request_id).unwrap(); self.active_seqs.remove(request_id).unwrap();
self.active_blocks self.active_blocks
} }
/// Push tokens to a specific request's sequence
pub fn push(&mut self, request_id: &RequestId, tokens: &[u32]) -> usize {
if let Some(prefill_tokens) = self.prefill_tokens.get(request_id).cloned() {
self.prefill_tokens.remove(request_id);
// decoding has one active token
self.active_tokens = self
.active_tokens
.checked_sub(prefill_tokens)
.expect("active_tokens < 0")
+ 1;
};
// Collect operations to perform after releasing the borrow
let mut blocks_to_remove = Vec::new();
let mut blocks_to_add = Vec::new();
{
let token_seq = self
.active_seqs
.get_mut(request_id)
.expect("Request ID not found for token push");
for &token in tokens {
token_seq.append(token).expect("Token push failed.");
// Guard: skip if we didn't cross a block boundary
if token_seq.total_tokens() % self.block_size != 1 {
continue;
}
let last_seq_hash = token_seq
.last_complete_block()
.map(|block| block.sequence_hash());
// Queue operations for later
if let Some(partial_block) = self.partial_blocks.get(request_id).cloned() {
blocks_to_remove.push(partial_block);
}
if let Some(full_block) = last_seq_hash {
blocks_to_add.push(UniqueBlock::FullBlock(full_block));
}
blocks_to_add.push(UniqueBlock::default());
}
} // token_seq borrow is dropped here
// Now perform all the queued operations
for block in blocks_to_remove {
self.remove_block(request_id, &block);
}
for block in blocks_to_add {
self.add_block(request_id.clone(), &block);
}
self.active_blocks
}
} }
#[derive(Debug)]
enum UpdateSequences { enum UpdateSequences {
AddRequest { AddRequest {
request_id: RequestId, request_id: RequestId,
token_sequence: TokenBlockSequence, token_sequence: Vec<SequenceHash>,
isl: usize,
overlap: u32, overlap: u32,
}, },
Free { Free {
request_id: RequestId, request_id: RequestId,
}, },
Push { MarkPrefillCompleted {
request_id: RequestId, request_id: RequestId,
tokens: Vec<u32>, // Changed from token: u32
}, },
NewBlocks { NewBlocks {
token_sequence: Arc<TokenBlockSequence>, token_sequence: Arc<Vec<SequenceHash>>,
resp_tx: mpsc::SyncSender<usize>, resp_tx: mpsc::SyncSender<usize>,
}, },
PotentialBlocks { PotentialBlocks {
token_sequence: Arc<TokenBlockSequence>, token_sequence: Arc<Vec<SequenceHash>>,
resp_tx: mpsc::SyncSender<usize>, resp_tx: mpsc::SyncSender<usize>,
}, },
PotentialBlocksAndTokens { PotentialBlocksAndTokens {
token_sequence: Arc<TokenBlockSequence>, token_sequence: Arc<Vec<SequenceHash>>,
isl: usize,
overlap: u32, overlap: u32,
resp_tx: mpsc::SyncSender<(usize, usize)>, resp_tx: mpsc::SyncSender<(usize, usize)>,
}, },
ActiveBlocks { ActiveBlocks {
resp_tx: mpsc::SyncSender<usize>, resp_tx: mpsc::SyncSender<usize>,
}, },
ActiveTokens {
resp_tx: mpsc::SyncSender<usize>,
},
Shutdown, Shutdown,
} }
...@@ -357,15 +269,16 @@ impl ActiveSequencesMultiWorker { ...@@ -357,15 +269,16 @@ impl ActiveSequencesMultiWorker {
UpdateSequences::AddRequest { UpdateSequences::AddRequest {
request_id, request_id,
token_sequence, token_sequence,
isl,
overlap, overlap,
} => { } => {
active_sequences.add_request(request_id, token_sequence, overlap); active_sequences.add_request(request_id, token_sequence, isl, overlap);
} }
UpdateSequences::Free { request_id } => { UpdateSequences::Free { request_id } => {
active_sequences.free(&request_id); active_sequences.free(&request_id);
} }
UpdateSequences::Push { request_id, tokens } => { UpdateSequences::MarkPrefillCompleted { request_id } => {
active_sequences.push(&request_id, &tokens); // Changed to pass tokens slice active_sequences.mark_prefill_completed(&request_id);
} }
UpdateSequences::NewBlocks { UpdateSequences::NewBlocks {
token_sequence, token_sequence,
...@@ -383,17 +296,25 @@ impl ActiveSequencesMultiWorker { ...@@ -383,17 +296,25 @@ impl ActiveSequencesMultiWorker {
} }
UpdateSequences::PotentialBlocksAndTokens { UpdateSequences::PotentialBlocksAndTokens {
token_sequence, token_sequence,
isl,
overlap, overlap,
resp_tx, resp_tx,
} => { } => {
let potential_tokens = let potential_tokens = active_sequences.potential_blocks_and_tokens(
active_sequences.potential_blocks_and_tokens(&token_sequence, overlap); &token_sequence,
isl,
overlap,
);
let _ = resp_tx.send(potential_tokens); let _ = resp_tx.send(potential_tokens);
} }
UpdateSequences::ActiveBlocks { resp_tx } => { UpdateSequences::ActiveBlocks { resp_tx } => {
let active_blocks = active_sequences.active_blocks(); let active_blocks = active_sequences.active_blocks();
let _ = resp_tx.send(active_blocks); let _ = resp_tx.send(active_blocks);
} }
UpdateSequences::ActiveTokens { resp_tx } => {
let active_tokens = active_sequences.active_tokens();
let _ = resp_tx.send(active_tokens);
}
UpdateSequences::Shutdown => { UpdateSequences::Shutdown => {
break; break;
} }
...@@ -443,7 +364,8 @@ impl ActiveSequencesMultiWorker { ...@@ -443,7 +364,8 @@ impl ActiveSequencesMultiWorker {
pub fn add_request( pub fn add_request(
&mut self, &mut self,
request_id: RequestId, request_id: RequestId,
token_sequence: TokenBlockSequence, token_sequence: Vec<SequenceHash>,
isl: usize,
overlap: u32, overlap: u32,
worker_id: WorkerId, worker_id: WorkerId,
) { ) {
...@@ -457,6 +379,7 @@ impl ActiveSequencesMultiWorker { ...@@ -457,6 +379,7 @@ impl ActiveSequencesMultiWorker {
.send(UpdateSequences::AddRequest { .send(UpdateSequences::AddRequest {
request_id, request_id,
token_sequence, token_sequence,
isl,
overlap, overlap,
}) })
.expect("Failed to send add_request command to worker"); .expect("Failed to send add_request command to worker");
...@@ -478,18 +401,19 @@ impl ActiveSequencesMultiWorker { ...@@ -478,18 +401,19 @@ impl ActiveSequencesMultiWorker {
self.request_to_worker.remove(request_id); self.request_to_worker.remove(request_id);
} }
pub fn push(&mut self, request_id: &RequestId, tokens: &[u32]) { /// Mark prefill as completed for a request
pub fn mark_prefill_completed(&mut self, request_id: &RequestId) {
let worker_id = self let worker_id = self
.request_to_worker .request_to_worker
.get(request_id) .get(request_id)
.copied() .copied()
.expect("Request ID not found in request_to_worker mapping"); .expect("Request ID not found in request_to_worker mapping");
self.senders[&worker_id] self.senders[&worker_id]
.send(UpdateSequences::Push { .send(UpdateSequences::MarkPrefillCompleted {
request_id: request_id.clone(), request_id: request_id.clone(),
tokens: tokens.to_vec(), // Convert to Vec
}) })
.expect("Failed to send push command to worker"); .expect("Failed to send mark_prefill_completed command to worker");
} }
/// Get the number of workers /// Get the number of workers
...@@ -500,8 +424,8 @@ impl ActiveSequencesMultiWorker { ...@@ -500,8 +424,8 @@ impl ActiveSequencesMultiWorker {
/// Generic method to query all workers with a given command /// Generic method to query all workers with a given command
fn query_workers( fn query_workers(
&self, &self,
token_sequence: Option<TokenBlockSequence>, token_sequence: Option<Vec<SequenceHash>>,
command_fn: impl Fn(Option<Arc<TokenBlockSequence>>, mpsc::SyncSender<usize>) -> UpdateSequences, command_fn: impl Fn(Option<Arc<Vec<SequenceHash>>>, mpsc::SyncSender<usize>) -> UpdateSequences,
) -> HashMap<WorkerId, usize> { ) -> HashMap<WorkerId, usize> {
let mut results = HashMap::new(); let mut results = HashMap::new();
let token_sequence_shared = token_sequence.map(Arc::new); let token_sequence_shared = token_sequence.map(Arc::new);
...@@ -528,7 +452,7 @@ impl ActiveSequencesMultiWorker { ...@@ -528,7 +452,7 @@ impl ActiveSequencesMultiWorker {
} }
/// Query all workers for the number of new blocks that would be added by a token sequence /// Query all workers for the number of new blocks that would be added by a token sequence
pub fn new_blocks(&self, token_sequence: TokenBlockSequence) -> HashMap<WorkerId, usize> { pub fn new_blocks(&self, token_sequence: Vec<SequenceHash>) -> HashMap<WorkerId, usize> {
self.query_workers(Some(token_sequence), |ts, resp_tx| match ts { self.query_workers(Some(token_sequence), |ts, resp_tx| match ts {
Some(ts) => UpdateSequences::NewBlocks { Some(ts) => UpdateSequences::NewBlocks {
token_sequence: ts, token_sequence: ts,
...@@ -539,7 +463,7 @@ impl ActiveSequencesMultiWorker { ...@@ -539,7 +463,7 @@ impl ActiveSequencesMultiWorker {
} }
/// Query all workers for the total number of blocks (new + active) that would be used by a token sequence /// Query all workers for the total number of blocks (new + active) that would be used by a token sequence
pub fn potential_blocks(&self, token_sequence: TokenBlockSequence) -> HashMap<WorkerId, usize> { pub fn potential_blocks(&self, token_sequence: Vec<SequenceHash>) -> HashMap<WorkerId, usize> {
self.query_workers(Some(token_sequence), |ts, resp_tx| match ts { self.query_workers(Some(token_sequence), |ts, resp_tx| match ts {
Some(ts) => UpdateSequences::PotentialBlocks { Some(ts) => UpdateSequences::PotentialBlocks {
token_sequence: ts, token_sequence: ts,
...@@ -552,7 +476,8 @@ impl ActiveSequencesMultiWorker { ...@@ -552,7 +476,8 @@ impl ActiveSequencesMultiWorker {
/// Query all workers for the potential tokens (new + active) that would be used by a token sequence with overlap /// Query all workers for the potential tokens (new + active) that would be used by a token sequence with overlap
pub fn potential_blocks_and_tokens( pub fn potential_blocks_and_tokens(
&self, &self,
token_sequence: TokenBlockSequence, token_sequence: Vec<SequenceHash>,
isl: usize,
overlaps: OverlapScores, overlaps: OverlapScores,
) -> (HashMap<WorkerId, usize>, HashMap<WorkerId, usize>) { ) -> (HashMap<WorkerId, usize>, HashMap<WorkerId, usize>) {
let mut potential_blocks = HashMap::new(); let mut potential_blocks = HashMap::new();
...@@ -568,6 +493,7 @@ impl ActiveSequencesMultiWorker { ...@@ -568,6 +493,7 @@ impl ActiveSequencesMultiWorker {
sender sender
.send(UpdateSequences::PotentialBlocksAndTokens { .send(UpdateSequences::PotentialBlocksAndTokens {
token_sequence: token_sequence_shared.clone(), token_sequence: token_sequence_shared.clone(),
isl,
overlap: overlaps.scores.get(worker_id).copied().unwrap_or(0), overlap: overlaps.scores.get(worker_id).copied().unwrap_or(0),
resp_tx, resp_tx,
}) })
...@@ -590,6 +516,11 @@ impl ActiveSequencesMultiWorker { ...@@ -590,6 +516,11 @@ impl ActiveSequencesMultiWorker {
pub fn active_blocks(&self) -> HashMap<WorkerId, usize> { pub fn active_blocks(&self) -> HashMap<WorkerId, usize> {
self.query_workers(None, |_, resp_tx| UpdateSequences::ActiveBlocks { resp_tx }) self.query_workers(None, |_, resp_tx| UpdateSequences::ActiveBlocks { resp_tx })
} }
/// Ask every worker how many active tokens it currently tracks.
///
/// Sends an `UpdateSequences::ActiveTokens` query through `query_workers`
/// (no token sequence is attached) and returns the per-worker results
/// keyed by worker id.
pub fn active_tokens(&self) -> HashMap<WorkerId, usize> {
    self.query_workers(None, |_unused_sequence, resp_tx| {
        UpdateSequences::ActiveTokens { resp_tx }
    })
}
} }
impl Drop for ActiveSequencesMultiWorker { impl Drop for ActiveSequencesMultiWorker {
...@@ -609,91 +540,102 @@ impl Drop for ActiveSequencesMultiWorker { ...@@ -609,91 +540,102 @@ impl Drop for ActiveSequencesMultiWorker {
#[cfg(test)] #[cfg(test)]
mod tests { mod tests {
use super::*; use super::*;
use crate::tokens::Tokens;
#[test]
fn test_shared_sequence_manager_operations() {
let block_size = 4;
let mut manager = ActiveSequences::new(block_size);
let to_sequence =
|tokens: Vec<u32>| Tokens::from(tokens).into_sequence(block_size as u32, None);
// Step 1: Add request 0 with tokens [0, 1, 2], then push 3 and 4
manager.add_request("0".to_string(), to_sequence(vec![0, 1, 2]), 0);
manager.push(&"0".to_string(), &[3, 4]); // Push both tokens at once
assert_eq!(manager.active_tokens(), 1);
assert_eq!(manager.active_blocks(), 2);
assert_eq!(manager.partial_blocks.len(), 1);
// Step 2: Add request 1 with tokens [0, 1, 2, 3, 4, 5, 6]
manager.add_request("1".to_string(), to_sequence(vec![0, 1, 2, 3, 4, 5, 6]), 1);
assert_eq!(manager.active_tokens(), 1 + 3);
assert_eq!(manager.active_blocks(), 3);
// Check that only one key is FullBlock with both requests sharing it
let mut full_block_count = 0;
let mut shared_block_requests = None;
for (block, requests) in &manager.unique_blocks {
if let UniqueBlock::FullBlock(_) = block {
full_block_count += 1;
if requests.len() == 2 {
shared_block_requests = Some(requests.clone());
}
}
}
assert_eq!(full_block_count, 1);
assert!(shared_block_requests.is_some());
let shared_requests = shared_block_requests.unwrap();
assert!(shared_requests.contains("0"));
assert!(shared_requests.contains("1"));
let new_blocks = manager.new_blocks(&to_sequence(vec![0, 1, 2, 3, 4, 5]));
assert_eq!(new_blocks, 1);
// Step 3: Free request 1
manager.free(&"1".to_string());
assert_eq!(manager.active_blocks(), 2);
// Step 4: Free request 0
manager.free(&"0".to_string());
assert_eq!(manager.active_tokens(), 0);
assert_eq!(manager.active_blocks(), 0);
assert_eq!(manager.unique_blocks.len(), 0);
assert_eq!(manager.partial_blocks.len(), 0);
assert_eq!(manager.active_seqs.len(), 0);
}
#[test] #[test]
fn test_active_sequences_multi_worker() { fn test_multi_worker_block_sharing() {
let block_size = 4; // Create multi-worker sequence manager with 3 workers
let block_size = 4; // arbitrary block size
let worker_ids = vec![0, 1, 2]; let worker_ids = vec![0, 1, 2];
let mut manager = ActiveSequencesMultiWorker::new(block_size, worker_ids); let mut seq_manager = ActiveSequencesMultiWorker::new(block_size, worker_ids);
let to_sequence =
|tokens: Vec<u32>| Tokens::from(tokens).into_sequence(block_size as u32, None); // Add requests to each worker
// Worker 0: sequence [0, 1, 2]
// Send request [0, 1, 2, 3] to worker 0 seq_manager.add_request(
manager.add_request("req0".to_string(), to_sequence(vec![0, 1, 2, 3]), 0, 0); "request_0".to_string(),
vec![0, 1, 2],
// Send request [0, 1, 2] to worker 1, then push 3 and 4 12, // ISL (3 blocks * 4 block_size)
manager.add_request("req1".to_string(), to_sequence(vec![0, 1, 2]), 0, 1); 0, // no overlap
manager.push(&"req1".to_string(), &[3, 4]); // Push both tokens at once 0, // worker_id
);
// Send request [0, 1, 2] to worker 2
manager.add_request("req2".to_string(), to_sequence(vec![0, 1, 2]), 0, 2); // Worker 1: sequence [3, 4]
seq_manager.add_request(
// Check new_blocks on tokens [0, 1, 2, 3, 4] "request_1".to_string(),
let new_blocks_map = manager.new_blocks(to_sequence(vec![0, 1, 2, 3, 4])); vec![3, 4],
8, // ISL (2 blocks * 4 block_size)
assert_eq!(new_blocks_map[&0], 1); // Worker 0 would have 1 new block 0, // no overlap
assert_eq!(new_blocks_map[&1], 1); // Worker 1 would have 1 new block 1, // worker_id
assert_eq!(new_blocks_map[&2], 2); // Worker 2 would have 2 new blocks );
manager.update_workers(vec![0, 1]); // Worker 2: sequence [0, 1, 2, 3]
manager.update_workers(vec![0, 1, 3]); seq_manager.add_request(
"request_2".to_string(),
let new_blocks_map = manager.new_blocks(to_sequence(vec![0, 1, 2, 3, 4])); vec![0, 1, 2, 3],
16, // ISL (4 blocks * 4 block_size)
assert_eq!(new_blocks_map.len(), 3); 0, // no overlap
assert_eq!(new_blocks_map[&3], 2); 2, // worker_id
);
// Verify active tokens after adding requests
let tokens_after_add = seq_manager.active_tokens();
assert_eq!(
tokens_after_add[&0], 12,
"Worker 0 should have 12 active tokens"
);
assert_eq!(
tokens_after_add[&1], 8,
"Worker 1 should have 8 active tokens"
);
assert_eq!(
tokens_after_add[&2], 16,
"Worker 2 should have 16 active tokens"
);
// Test potential blocks for sequence [0, 1]
let potential_blocks = seq_manager.potential_blocks(vec![0, 1]);
// Worker 0 should return 3 (already has blocks 0, 1, 2, so no new blocks needed for [0, 1])
assert_eq!(
potential_blocks[&0], 3,
"Worker 0 should have 3 potential blocks"
);
// Worker 1 should return 4 (has blocks 3, 4, would need to add blocks 0, 1)
assert_eq!(
potential_blocks[&1], 4,
"Worker 1 should have 4 potential blocks"
);
// Worker 2 should return 4 (already has blocks 0, 1, 2, 3, so no new blocks needed for [0, 1])
assert_eq!(
potential_blocks[&2], 4,
"Worker 2 should have 4 potential blocks"
);
// Free all original requests
seq_manager.free(&"request_0".to_string());
seq_manager.free(&"request_1".to_string());
seq_manager.free(&"request_2".to_string());
// Verify active blocks are zero for all workers
let active_blocks = seq_manager.active_blocks();
assert_eq!(active_blocks[&0], 0, "Worker 0 should have 0 active blocks");
assert_eq!(active_blocks[&1], 0, "Worker 1 should have 0 active blocks");
assert_eq!(active_blocks[&2], 0, "Worker 2 should have 0 active blocks");
// Verify active tokens are zero for all workers
let final_tokens = seq_manager.active_tokens();
assert_eq!(
final_tokens[&0], 0,
"Worker 0 should have 0 active tokens after freeing all"
);
assert_eq!(
final_tokens[&1], 0,
"Worker 1 should have 0 active tokens after freeing all"
);
assert_eq!(
final_tokens[&2], 0,
"Worker 2 should have 0 active tokens after freeing all"
);
} }
} }
// SPDX-FileCopyrightText: Copyright (c) 2024-2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
// SPDX-License-Identifier: Apache-2.0
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
use std::sync::Arc;
pub use crate::kv_router::protocols::ForwardPassMetrics;
use anyhow::Result;
use derive_builder::Builder;
use dynamo_runtime::pipeline::network::{
ingress::push_endpoint::PushEndpoint,
PushWorkHandler,
};
use dynamo_runtime::transports::nats::{self, ServiceExt};
use tokio::sync::watch;
use tokio_util::sync::CancellationToken;
use tracing as log;
/// Settings needed to expose one worker as a NATS service endpoint whose
/// stats handler reports the worker's latest forward-pass metrics.
///
/// Build it with [`KvRoutedIngress::builder`] and run it with
/// [`KvRoutedIngress::start`].
#[derive(Builder)]
pub struct KvRoutedIngress {
/// Name used both for the NATS service and for the endpoint group.
#[builder(setter(into))]
pub service_name: String,
/// Name of the endpoint created inside the service group; also attached to
/// log events as the `worker_id` field.
#[builder(setter(into))]
pub worker_id: String,
/// NATS client from which the service is built.
pub nats: nats::Client,
/// Handler handed to the `PushEndpoint` that serves the endpoint.
pub service_handler: Arc<dyn PushWorkHandler>,
/// Watch channel carrying the most recent `ForwardPassMetrics`; its current
/// value is serialized to JSON every time the stats handler runs.
pub metrics_rx: watch::Receiver<Arc<ForwardPassMetrics>>,
/// Token passed to the `PushEndpoint` builder.
// NOTE(review): presumably cancels the running endpoint — confirm against
// `PushEndpoint`.
pub cancellation_token: CancellationToken,
}
/// Version of this crate, read from `CARGO_PKG_VERSION` at compile time.
///
/// It is passed as the service version when the NATS service is started.
pub const VERSION: &str = env!("CARGO_PKG_VERSION");
impl KvRoutedIngress {
    /// Returns a default [`KvRoutedIngressBuilder`].
    pub fn builder() -> KvRoutedIngressBuilder {
        KvRoutedIngressBuilder::default()
    }

    /// Starts the NATS service for this worker and runs its push endpoint.
    ///
    /// The method:
    /// 1. builds a NATS service named `service_name` with version [`VERSION`],
    ///    installing a stats handler that returns the latest
    ///    `ForwardPassMetrics` as JSON;
    /// 2. creates an endpoint named `worker_id` inside the group
    ///    `service_name`;
    /// 3. builds a `PushEndpoint` from `service_handler` and
    ///    `cancellation_token` and awaits its `start` on that endpoint.
    ///
    /// # Errors
    ///
    /// Returns an error when the service cannot be started, the endpoint
    /// cannot be created, or the `PushEndpoint` cannot be built. Otherwise the
    /// result of `PushEndpoint::start` is returned as-is.
    pub async fn start(self) -> Result<()> {
        let worker_id = self.worker_id;

        log::trace!(
            worker_id,
            "Starting nats service: {}:{}",
            self.service_name,
            VERSION
        );

        // Both values are moved into the stats handler closure below.
        let mut metrics_rx = self.metrics_rx;
        let worker_id_clone = worker_id.clone();

        let service = self
            .nats
            .client()
            .service_builder()
            // NOTE(review): this description looks like a leftover from an
            // example service — confirm whether it should describe this one.
            .description("A handy min max service")
            .stats_handler(move |name, stats| {
                log::debug!(
                    worker_id = worker_id_clone.as_str(),
                    "[IN worker?] Stats for service {}: {:?}",
                    name,
                    stats
                );
                let metrics = metrics_rx.borrow_and_update().clone();
                // A serialization failure must not panic inside the stats
                // callback; report it and answer with JSON null instead.
                serde_json::to_value(&*metrics).unwrap_or_else(|err| {
                    log::warn!(
                        worker_id = worker_id_clone.as_str(),
                        "Failed to serialize forward pass metrics: {err}"
                    );
                    serde_json::Value::Null
                })
            })
            .start(self.service_name.as_str(), VERSION)
            .await
            .map_err(|e| anyhow::anyhow!("Failed to start service: {e}"))?;

        let group = service.group(self.service_name.as_str());

        log::trace!(worker_id, "Starting endpoint: {}", worker_id);

        // creates an endpoint for the service
        // `worker_id` is not needed afterwards, so it is moved rather than cloned.
        let service_endpoint = group
            .endpoint(worker_id)
            .await
            .map_err(|e| anyhow::anyhow!("Failed to start endpoint: {e}"))?;

        let push_endpoint = PushEndpoint::builder()
            .service_handler(self.service_handler)
            .cancellation_token(self.cancellation_token)
            .build()
            .map_err(|e| anyhow::anyhow!("Failed to build push endpoint: {e}"))?;

        push_endpoint.start(service_endpoint).await
    }
}
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment