Unverified Commit 66231cf0 authored by Yan Ru Pei's avatar Yan Ru Pei Committed by GitHub
Browse files

feat: reduce / revert routing overheads, do not consider output tokens (#2182)

parent dbd33df6
......@@ -31,8 +31,8 @@ use crate::{
kv_router::{
approx::ApproxKvIndexer,
indexer::{
compute_block_hash_for_seq, KvIndexer, KvIndexerInterface, KvRouterError,
OverlapScores, RouterEvent,
compute_block_hash_for_seq, compute_seq_hash_for_block, KvIndexer, KvIndexerInterface,
KvRouterError, OverlapScores, RouterEvent,
},
// metrics_aggregator::EndpointCollector,
protocols::{LocalBlockHash, RouterRequest, RouterResponse, WorkerSelectionResult},
......@@ -71,7 +71,8 @@ pub struct KvRouterConfig {
pub use_kv_events: bool,
// note: this is not actually used for now
// TODO: this is not actually used for now
// Would need this (along with total kv blocks) to trigger AllWorkersBusy error for e.g. rate-limiting
pub max_num_batched_tokens: u32,
}
......@@ -231,25 +232,25 @@ impl KvRouter {
let _guard = self.find_best_match_mutex.lock().await;
let isl_tokens = tokens.len();
let block_size = self.block_size;
let local_block_hashes = compute_block_hash_for_seq(tokens, self.block_size);
let overlap_scores = self.indexer.find_matches(local_block_hashes).await?;
let block_hashes = compute_block_hash_for_seq(tokens, self.block_size);
let seq_hashes = compute_seq_hash_for_block(&block_hashes);
let overlap_scores = self.indexer.find_matches(block_hashes.clone()).await?;
let best_worker_id = self
.scheduler
.schedule(
context_id.to_string(),
isl_tokens,
block_size,
tokens,
seq_hashes.clone(),
overlap_scores.clone(),
)
.await?;
if let Indexer::ApproxKvIndexer(ref indexer) = self.indexer {
indexer
.process_routing_decision_for_request(tokens, best_worker_id)
.process_routing_decision(best_worker_id, block_hashes, seq_hashes)
.await
.unwrap();
};
......@@ -262,9 +263,9 @@ impl KvRouter {
Ok((best_worker_id, overlap_amount))
}
/// Push tokens to a specific request's sequence
pub async fn push(&self, request_id: &String, tokens: &[u32]) {
self.scheduler.push(request_id, tokens).await
/// Free all blocks associated with a request
pub async fn mark_prefill_completed(&self, request_id: &String) {
self.scheduler.mark_prefill_completed(request_id).await
}
/// Free all blocks associated with a request
......@@ -331,7 +332,6 @@ impl AsyncEngine<SingleIn<PreprocessedRequest>, ManyOut<Annotated<LLMEngineOutpu
let stream_context = request.context().clone();
// Update the request with the estimated prefix hit blocks
let (mut backend_input, context) = request.into_parts();
let isl = backend_input.token_ids.len();
backend_input.estimated_prefix_hit_num_blocks = Some(overlap_amount);
let updated_request = context.map(|_| backend_input);
......@@ -345,55 +345,22 @@ impl AsyncEngine<SingleIn<PreprocessedRequest>, ManyOut<Annotated<LLMEngineOutpu
let stream = stream::iter(vec![response]);
return Ok(ResponseStream::new(Box::pin(stream), stream_context));
}
// Get the response stream from the worker
let mut response_stream = self.inner.direct(updated_request, instance_id).await?;
// Wrap the stream to track tokens
let mut response_stream = self.inner.direct(updated_request, instance_id).await?;
let stream_context = response_stream.context();
let chooser = self.chooser.clone();
let request_id = context_id.clone();
let block_size = chooser.block_size() as usize;
let wrapped_stream = Box::pin(async_stream::stream! {
let mut accumulated_tokens = Vec::new();
let mut total_output_length = 0usize;
let mut last_block_index = (isl.saturating_sub(1)) / block_size;
let mut first_push_done = false;
while let Some(item) = response_stream.next().await {
// Track tokens if they exist in the response
let Some(ref output) = item.data else {
yield item;
continue;
};
if output.token_ids.is_empty() {
yield item;
continue;
}
// Add tokens to accumulator
accumulated_tokens.extend_from_slice(&output.token_ids);
total_output_length += output.token_ids.len();
// Always push for the first generated token (to mark prefill done)
// or when we've moved to a new block
let current_block_index = (isl + total_output_length).saturating_sub(1) / block_size;
let should_push = (!first_push_done && total_output_length >= 1) ||
(first_push_done && current_block_index > last_block_index);
if should_push {
chooser.push(&request_id, &accumulated_tokens).await;
accumulated_tokens.clear();
last_block_index = current_block_index;
if !first_push_done {
first_push_done = true;
}
if let Some(first_item) = response_stream.next().await {
chooser.mark_prefill_completed(&context_id).await;
yield first_item;
}
while let Some(item) = response_stream.next().await {
yield item;
}
chooser.free(&request_id).await;
chooser.free(&context_id).await;
});
Ok(ResponseStream::new(wrapped_stream, stream_context))
}
......
......@@ -23,7 +23,7 @@ use tokio::sync::{mpsc, oneshot};
use tokio::time::{Duration, Instant};
use tokio_util::sync::CancellationToken;
use crate::tokens::TokenBlockSequence;
use crate::tokens::{SequenceHash, TokenBlockSequence};
use crate::kv_router::indexer::{
compute_block_hash_for_seq, DumpRequest, KvIndexerInterface, KvRouterError, OverlapScores,
......@@ -295,6 +295,26 @@ impl ApproxKvIndexer {
self.kv_block_size
}
/// Core function to process a routing decision with pre-computed hashes
pub async fn process_routing_decision(
&self,
worker_id: WorkerId,
local_hashes: Vec<LocalBlockHash>,
sequence_hashes: Vec<SequenceHash>,
) -> Result<(), KvRouterError> {
self.route_tx
.send(RouterResult {
worker_id,
local_hashes,
sequence_hashes,
})
.await
.map_err(|_| KvRouterError::IndexerDroppedRequest)?;
Ok(())
}
/// Wrapper function that computes hashes from tokens and calls the core function
pub async fn process_routing_decision_for_request(
&self,
tokens: &[u32],
......@@ -309,16 +329,8 @@ impl ApproxKvIndexer {
.map(|b| b.sequence_hash())
.collect::<Vec<_>>();
self.route_tx
.send(RouterResult {
worker_id,
local_hashes,
sequence_hashes,
})
self.process_routing_decision(worker_id, local_hashes, sequence_hashes)
.await
.map_err(|_| KvRouterError::IndexerDroppedRequest)?;
Ok(())
}
}
......
......@@ -63,6 +63,7 @@ use xxhash_rust::xxh3;
pub const XXH3_SEED: u64 = 1337;
use crate::kv_router::protocols::*;
use crate::tokens::SequenceHash;
/// Errors that can occur in the KV Router.
#[derive(Debug, thiserror::Error)]
......@@ -133,6 +134,40 @@ pub fn compute_block_hash_for_seq(tokens: &[u32], kv_block_size: u32) -> Vec<Loc
.collect()
}
/// Compute rolling sequence hashes for a vector of block hashes.
///
/// This mirrors the behavior in tokens.rs where:
/// - The first block's sequence hash equals its block hash
/// - Subsequent blocks' sequence hash = hash([parent_sequence_hash, current_block_hash], seed)
///
/// ### Arguments
///
/// * `block_hashes` - A vector of `LocalBlockHash` values representing the block hashes.
///
/// ### Returns
///
/// A vector of u64 values representing the sequence hashes for each block.
pub fn compute_seq_hash_for_block(block_hashes: &[LocalBlockHash]) -> Vec<SequenceHash> {
if block_hashes.is_empty() {
return Vec::new();
}
let mut sequence_hashes = Vec::with_capacity(block_hashes.len());
sequence_hashes.push(block_hashes[0].0);
for i in 1..block_hashes.len() {
let parent_seq_hash = sequence_hashes[i - 1];
let current_block_hash = block_hashes[i].0;
let combined = [parent_seq_hash, current_block_hash];
let bytes: Vec<u8> = combined.iter().flat_map(|&num| num.to_le_bytes()).collect();
let seq_hash = compute_hash(&bytes);
sequence_hashes.push(seq_hash);
}
sequence_hashes
}
/// A [`KvCacheEvent`] on a specific LLM worker denoted by [`WorkerId`].
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct RouterEvent {
......
......@@ -29,7 +29,7 @@ use crate::kv_router::protocols::LoadMetrics;
use crate::kv_router::sequence::ActiveSequencesMultiWorker;
use crate::kv_router::KvRouterConfig;
use crate::kv_router::KV_HIT_RATE_SUBJECT;
use crate::tokens::TokenBlockSequence;
use crate::tokens::SequenceHash;
use dynamo_runtime::component::Instance;
#[derive(Debug, Clone, Serialize, Deserialize)]
......@@ -217,15 +217,13 @@ impl KvScheduler {
&self,
request_id: String,
isl_tokens: usize,
block_size: u32,
tokens: &[u32],
token_seq: Vec<SequenceHash>,
overlaps: OverlapScores,
) -> Result<i64, KvSchedulerError> {
let mut sequences = self.sequences.lock().await;
let token_sequence = TokenBlockSequence::from_slice(tokens, block_size, None);
let (potential_blocks, potential_tokens) =
sequences.potential_blocks_and_tokens(token_sequence, overlaps.clone());
sequences.potential_blocks_and_tokens(token_seq.clone(), isl_tokens, overlaps.clone());
let (resp_tx, resp_rx) = tokio::sync::oneshot::channel();
let request = SchedulingRequest {
......@@ -247,10 +245,10 @@ impl KvScheduler {
sequences.update_workers(new_worker_ids);
}
let token_sequence = TokenBlockSequence::from_slice(tokens, block_size, None);
sequences.add_request(
request_id,
token_sequence,
token_seq,
isl_tokens,
response.overlap_blocks,
response.best_worker_id,
);
......@@ -258,10 +256,9 @@ impl KvScheduler {
Ok(response.best_worker_id)
}
/// Push tokens to a specific request's sequence
pub async fn push(&self, request_id: &String, tokens: &[u32]) {
pub async fn mark_prefill_completed(&self, request_id: &String) {
let mut sequences = self.sequences.lock().await;
sequences.push(request_id, tokens)
sequences.mark_prefill_completed(request_id)
}
/// Free all blocks associated with a request
......
......@@ -36,50 +36,24 @@
use crate::kv_router::indexer::OverlapScores;
use crate::kv_router::indexer::WorkerId;
use crate::tokens::blocks::UniqueBlock;
use crate::tokens::TokenBlockSequence;
use crate::tokens::SequenceHash;
use derive_getters::Getters;
use std::collections::{HashMap, HashSet};
use std::sync::{mpsc, Arc};
use std::thread;
use std::time::Duration;
use uuid;
// TODO: use the common request_id if it exists in the repo
pub type RequestId = String;
/// Create unique blocks from a TokenBlockSequence
fn create_unique_blocks_from_sequence(
tokens: &TokenBlockSequence,
uuid: Option<uuid::Uuid>,
block_size: usize,
) -> Vec<UniqueBlock> {
let mut unique_blocks: Vec<UniqueBlock> = tokens
.blocks()
.iter()
.map(|block| UniqueBlock::FullBlock(block.sequence_hash()))
.collect();
// Only push the partial block if tokens count isn't a multiple of block_size
if tokens.total_tokens() % block_size != 0 {
unique_blocks.push(match uuid {
Some(uuid) => UniqueBlock::PartialBlock(uuid),
None => UniqueBlock::default(),
});
}
unique_blocks
}
/// A multi-request sequence manager that handles multiple active sequences with shared KV cache
#[derive(Debug, Getters)]
pub struct ActiveSequences {
active_seqs: HashMap<RequestId, TokenBlockSequence>,
partial_blocks: HashMap<RequestId, UniqueBlock>,
active_seqs: HashMap<RequestId, Vec<SequenceHash>>,
prefill_tokens: HashMap<RequestId, usize>,
unique_blocks: HashMap<UniqueBlock, HashSet<RequestId>>,
unique_blocks: HashMap<SequenceHash, HashSet<RequestId>>,
#[getter(copy)]
block_size: usize,
......@@ -99,7 +73,6 @@ impl ActiveSequences {
Self {
active_seqs: HashMap::new(),
partial_blocks: HashMap::new(),
prefill_tokens: HashMap::new(),
unique_blocks: HashMap::new(),
block_size,
......@@ -108,24 +81,20 @@ impl ActiveSequences {
}
}
fn add_block(&mut self, request_id: RequestId, block: &UniqueBlock) {
fn add_block(&mut self, request_id: RequestId, block: &SequenceHash) {
let is_new_block = !self.unique_blocks.contains_key(block);
self.unique_blocks
.entry(block.clone())
.entry(*block)
.or_default()
.insert(request_id.clone());
if is_new_block {
self.active_blocks += 1;
}
if matches!(block, UniqueBlock::PartialBlock(_)) {
self.partial_blocks.insert(request_id, block.clone());
};
}
fn remove_block(&mut self, request_id: &RequestId, block: &UniqueBlock) {
fn remove_block(&mut self, request_id: &RequestId, block: &SequenceHash) {
let Some(request_ids) = self.unique_blocks.get_mut(block) else {
panic!("Cannot remove a block that does not exist.")
};
......@@ -142,17 +111,16 @@ impl ActiveSequences {
pub fn add_request(
&mut self,
request_id: RequestId,
token_sequence: TokenBlockSequence,
token_sequence: Vec<SequenceHash>,
isl: usize,
overlap: u32,
) -> usize {
let prefill_tokens = self.new_tokens(&token_sequence, overlap);
let prefill_tokens = self.new_tokens(isl, overlap);
self.prefill_tokens
.insert(request_id.clone(), prefill_tokens);
self.active_tokens += prefill_tokens;
let blocks = create_unique_blocks_from_sequence(&token_sequence, None, self.block_size);
for block in &blocks {
for block in &token_sequence {
self.add_block(request_id.clone(), block);
}
......@@ -161,30 +129,35 @@ impl ActiveSequences {
self.active_blocks
}
pub fn new_tokens(&self, token_sequence: &TokenBlockSequence, overlap: u32) -> usize {
let input_tokens = token_sequence.total_tokens();
input_tokens
.checked_sub((overlap as usize) * self.block_size)
.unwrap_or_else(|| {
panic!("prefill_tokens < 0 with overlap {overlap} and ISL {input_tokens}")
})
/// Mark prefill as completed for a request, removing it from prefill_tokens tracking
pub fn mark_prefill_completed(&mut self, request_id: &RequestId) {
if let Some(tokens) = self.prefill_tokens.remove(request_id) {
self.active_tokens = self
.active_tokens
.checked_sub(tokens.saturating_sub(1)) // Keep 1 token for decoding
.expect("active_tokens underflow");
}
}
pub fn new_tokens(&self, isl: usize, overlap: u32) -> usize {
isl.checked_sub((overlap as usize) * self.block_size)
.unwrap_or_else(|| panic!("prefill_tokens < 0 with overlap {overlap} and ISL {isl}"))
}
pub fn potential_blocks_and_tokens(
&self,
token_sequence: &TokenBlockSequence,
token_sequence: &[SequenceHash],
isl: usize,
overlap: u32,
) -> (usize, usize) {
let potential_blocks = self.new_blocks(token_sequence) + self.active_blocks;
let potential_tokens = self.new_tokens(token_sequence, overlap) + self.active_tokens;
let potential_tokens = self.new_tokens(isl, overlap) + self.active_tokens;
(potential_blocks, potential_tokens)
}
/// Match a request against existing blocks and return the number of new blocks that would be added
pub fn new_blocks(&self, token_sequence: &TokenBlockSequence) -> usize {
let blocks = create_unique_blocks_from_sequence(token_sequence, None, self.block_size);
blocks
pub fn new_blocks(&self, token_sequence: &[SequenceHash]) -> usize {
token_sequence
.iter()
.filter(|block| !self.unique_blocks.contains_key(block))
.count()
......@@ -192,7 +165,7 @@ impl ActiveSequences {
/// Return the total number of blocks that would be used if the token sequence was added
/// This is the sum of new blocks that would be added plus the current active blocks
pub fn potential_blocks(&self, token_sequence: &TokenBlockSequence) -> usize {
pub fn potential_blocks(&self, token_sequence: &[SequenceHash]) -> usize {
self.new_blocks(token_sequence) + self.active_blocks
}
......@@ -209,110 +182,49 @@ impl ActiveSequences {
return 0;
};
let blocks = create_unique_blocks_from_sequence(token_seq, None, self.block_size);
for block in blocks {
if matches!(block, UniqueBlock::FullBlock(_)) {
self.remove_block(request_id, &block);
}
}
if let Some(partial_block) = self.partial_blocks.remove(request_id) {
self.remove_block(request_id, &partial_block);
for block in token_seq.clone() {
self.remove_block(request_id, &block)
}
self.active_seqs.remove(request_id).unwrap();
self.active_blocks
}
/// Push tokens to a specific request's sequence
pub fn push(&mut self, request_id: &RequestId, tokens: &[u32]) -> usize {
if let Some(prefill_tokens) = self.prefill_tokens.get(request_id).cloned() {
self.prefill_tokens.remove(request_id);
// decoding has one active token
self.active_tokens = self
.active_tokens
.checked_sub(prefill_tokens)
.expect("active_tokens < 0")
+ 1;
};
// Collect operations to perform after releasing the borrow
let mut blocks_to_remove = Vec::new();
let mut blocks_to_add = Vec::new();
{
let token_seq = self
.active_seqs
.get_mut(request_id)
.expect("Request ID not found for token push");
for &token in tokens {
token_seq.append(token).expect("Token push failed.");
// Guard: skip if we didn't cross a block boundary
if token_seq.total_tokens() % self.block_size != 1 {
continue;
}
let last_seq_hash = token_seq
.last_complete_block()
.map(|block| block.sequence_hash());
// Queue operations for later
if let Some(partial_block) = self.partial_blocks.get(request_id).cloned() {
blocks_to_remove.push(partial_block);
}
if let Some(full_block) = last_seq_hash {
blocks_to_add.push(UniqueBlock::FullBlock(full_block));
}
blocks_to_add.push(UniqueBlock::default());
}
} // token_seq borrow is dropped here
// Now perform all the queued operations
for block in blocks_to_remove {
self.remove_block(request_id, &block);
}
for block in blocks_to_add {
self.add_block(request_id.clone(), &block);
}
self.active_blocks
}
}
#[derive(Debug)]
enum UpdateSequences {
AddRequest {
request_id: RequestId,
token_sequence: TokenBlockSequence,
token_sequence: Vec<SequenceHash>,
isl: usize,
overlap: u32,
},
Free {
request_id: RequestId,
},
Push {
MarkPrefillCompleted {
request_id: RequestId,
tokens: Vec<u32>, // Changed from token: u32
},
NewBlocks {
token_sequence: Arc<TokenBlockSequence>,
token_sequence: Arc<Vec<SequenceHash>>,
resp_tx: mpsc::SyncSender<usize>,
},
PotentialBlocks {
token_sequence: Arc<TokenBlockSequence>,
token_sequence: Arc<Vec<SequenceHash>>,
resp_tx: mpsc::SyncSender<usize>,
},
PotentialBlocksAndTokens {
token_sequence: Arc<TokenBlockSequence>,
token_sequence: Arc<Vec<SequenceHash>>,
isl: usize,
overlap: u32,
resp_tx: mpsc::SyncSender<(usize, usize)>,
},
ActiveBlocks {
resp_tx: mpsc::SyncSender<usize>,
},
ActiveTokens {
resp_tx: mpsc::SyncSender<usize>,
},
Shutdown,
}
......@@ -357,15 +269,16 @@ impl ActiveSequencesMultiWorker {
UpdateSequences::AddRequest {
request_id,
token_sequence,
isl,
overlap,
} => {
active_sequences.add_request(request_id, token_sequence, overlap);
active_sequences.add_request(request_id, token_sequence, isl, overlap);
}
UpdateSequences::Free { request_id } => {
active_sequences.free(&request_id);
}
UpdateSequences::Push { request_id, tokens } => {
active_sequences.push(&request_id, &tokens); // Changed to pass tokens slice
UpdateSequences::MarkPrefillCompleted { request_id } => {
active_sequences.mark_prefill_completed(&request_id);
}
UpdateSequences::NewBlocks {
token_sequence,
......@@ -383,17 +296,25 @@ impl ActiveSequencesMultiWorker {
}
UpdateSequences::PotentialBlocksAndTokens {
token_sequence,
isl,
overlap,
resp_tx,
} => {
let potential_tokens =
active_sequences.potential_blocks_and_tokens(&token_sequence, overlap);
let potential_tokens = active_sequences.potential_blocks_and_tokens(
&token_sequence,
isl,
overlap,
);
let _ = resp_tx.send(potential_tokens);
}
UpdateSequences::ActiveBlocks { resp_tx } => {
let active_blocks = active_sequences.active_blocks();
let _ = resp_tx.send(active_blocks);
}
UpdateSequences::ActiveTokens { resp_tx } => {
let active_tokens = active_sequences.active_tokens();
let _ = resp_tx.send(active_tokens);
}
UpdateSequences::Shutdown => {
break;
}
......@@ -443,7 +364,8 @@ impl ActiveSequencesMultiWorker {
pub fn add_request(
&mut self,
request_id: RequestId,
token_sequence: TokenBlockSequence,
token_sequence: Vec<SequenceHash>,
isl: usize,
overlap: u32,
worker_id: WorkerId,
) {
......@@ -457,6 +379,7 @@ impl ActiveSequencesMultiWorker {
.send(UpdateSequences::AddRequest {
request_id,
token_sequence,
isl,
overlap,
})
.expect("Failed to send add_request command to worker");
......@@ -478,18 +401,19 @@ impl ActiveSequencesMultiWorker {
self.request_to_worker.remove(request_id);
}
pub fn push(&mut self, request_id: &RequestId, tokens: &[u32]) {
/// Mark prefill as completed for a request
pub fn mark_prefill_completed(&mut self, request_id: &RequestId) {
let worker_id = self
.request_to_worker
.get(request_id)
.copied()
.expect("Request ID not found in request_to_worker mapping");
self.senders[&worker_id]
.send(UpdateSequences::Push {
.send(UpdateSequences::MarkPrefillCompleted {
request_id: request_id.clone(),
tokens: tokens.to_vec(), // Convert to Vec
})
.expect("Failed to send push command to worker");
.expect("Failed to send mark_prefill_completed command to worker");
}
/// Get the number of workers
......@@ -500,8 +424,8 @@ impl ActiveSequencesMultiWorker {
/// Generic method to query all workers with a given command
fn query_workers(
&self,
token_sequence: Option<TokenBlockSequence>,
command_fn: impl Fn(Option<Arc<TokenBlockSequence>>, mpsc::SyncSender<usize>) -> UpdateSequences,
token_sequence: Option<Vec<SequenceHash>>,
command_fn: impl Fn(Option<Arc<Vec<SequenceHash>>>, mpsc::SyncSender<usize>) -> UpdateSequences,
) -> HashMap<WorkerId, usize> {
let mut results = HashMap::new();
let token_sequence_shared = token_sequence.map(Arc::new);
......@@ -528,7 +452,7 @@ impl ActiveSequencesMultiWorker {
}
/// Query all workers for the number of new blocks that would be added by a token sequence
pub fn new_blocks(&self, token_sequence: TokenBlockSequence) -> HashMap<WorkerId, usize> {
pub fn new_blocks(&self, token_sequence: Vec<SequenceHash>) -> HashMap<WorkerId, usize> {
self.query_workers(Some(token_sequence), |ts, resp_tx| match ts {
Some(ts) => UpdateSequences::NewBlocks {
token_sequence: ts,
......@@ -539,7 +463,7 @@ impl ActiveSequencesMultiWorker {
}
/// Query all workers for the total number of blocks (new + active) that would be used by a token sequence
pub fn potential_blocks(&self, token_sequence: TokenBlockSequence) -> HashMap<WorkerId, usize> {
pub fn potential_blocks(&self, token_sequence: Vec<SequenceHash>) -> HashMap<WorkerId, usize> {
self.query_workers(Some(token_sequence), |ts, resp_tx| match ts {
Some(ts) => UpdateSequences::PotentialBlocks {
token_sequence: ts,
......@@ -552,7 +476,8 @@ impl ActiveSequencesMultiWorker {
/// Query all workers for the potential tokens (new + active) that would be used by a token sequence with overlap
pub fn potential_blocks_and_tokens(
&self,
token_sequence: TokenBlockSequence,
token_sequence: Vec<SequenceHash>,
isl: usize,
overlaps: OverlapScores,
) -> (HashMap<WorkerId, usize>, HashMap<WorkerId, usize>) {
let mut potential_blocks = HashMap::new();
......@@ -568,6 +493,7 @@ impl ActiveSequencesMultiWorker {
sender
.send(UpdateSequences::PotentialBlocksAndTokens {
token_sequence: token_sequence_shared.clone(),
isl,
overlap: overlaps.scores.get(worker_id).copied().unwrap_or(0),
resp_tx,
})
......@@ -590,6 +516,11 @@ impl ActiveSequencesMultiWorker {
pub fn active_blocks(&self) -> HashMap<WorkerId, usize> {
self.query_workers(None, |_, resp_tx| UpdateSequences::ActiveBlocks { resp_tx })
}
/// Query all workers for their current number of active tokens
pub fn active_tokens(&self) -> HashMap<WorkerId, usize> {
self.query_workers(None, |_, resp_tx| UpdateSequences::ActiveTokens { resp_tx })
}
}
impl Drop for ActiveSequencesMultiWorker {
......@@ -609,91 +540,102 @@ impl Drop for ActiveSequencesMultiWorker {
#[cfg(test)]
mod tests {
use super::*;
use crate::tokens::Tokens;
#[test]
fn test_shared_sequence_manager_operations() {
let block_size = 4;
let mut manager = ActiveSequences::new(block_size);
let to_sequence =
|tokens: Vec<u32>| Tokens::from(tokens).into_sequence(block_size as u32, None);
// Step 1: Add request 0 with tokens [0, 1, 2], then push 3 and 4
manager.add_request("0".to_string(), to_sequence(vec![0, 1, 2]), 0);
manager.push(&"0".to_string(), &[3, 4]); // Push both tokens at once
assert_eq!(manager.active_tokens(), 1);
assert_eq!(manager.active_blocks(), 2);
assert_eq!(manager.partial_blocks.len(), 1);
// Step 2: Add request 1 with tokens [0, 1, 2, 3, 4, 5, 6]
manager.add_request("1".to_string(), to_sequence(vec![0, 1, 2, 3, 4, 5, 6]), 1);
assert_eq!(manager.active_tokens(), 1 + 3);
assert_eq!(manager.active_blocks(), 3);
// Check that only one key is FullBlock with both requests sharing it
let mut full_block_count = 0;
let mut shared_block_requests = None;
for (block, requests) in &manager.unique_blocks {
if let UniqueBlock::FullBlock(_) = block {
full_block_count += 1;
if requests.len() == 2 {
shared_block_requests = Some(requests.clone());
}
}
}
assert_eq!(full_block_count, 1);
assert!(shared_block_requests.is_some());
let shared_requests = shared_block_requests.unwrap();
assert!(shared_requests.contains("0"));
assert!(shared_requests.contains("1"));
let new_blocks = manager.new_blocks(&to_sequence(vec![0, 1, 2, 3, 4, 5]));
assert_eq!(new_blocks, 1);
// Step 3: Free request 1
manager.free(&"1".to_string());
assert_eq!(manager.active_blocks(), 2);
// Step 4: Free request 0
manager.free(&"0".to_string());
assert_eq!(manager.active_tokens(), 0);
assert_eq!(manager.active_blocks(), 0);
assert_eq!(manager.unique_blocks.len(), 0);
assert_eq!(manager.partial_blocks.len(), 0);
assert_eq!(manager.active_seqs.len(), 0);
}
#[test]
fn test_active_sequences_multi_worker() {
let block_size = 4;
fn test_multi_worker_block_sharing() {
// Create multi-worker sequence manager with 3 workers
let block_size = 4; // arbitrary block size
let worker_ids = vec![0, 1, 2];
let mut manager = ActiveSequencesMultiWorker::new(block_size, worker_ids);
let to_sequence =
|tokens: Vec<u32>| Tokens::from(tokens).into_sequence(block_size as u32, None);
// Send request [0, 1, 2, 3] to worker 0
manager.add_request("req0".to_string(), to_sequence(vec![0, 1, 2, 3]), 0, 0);
// Send request [0, 1, 2] to worker 1, then push 3 and 4
manager.add_request("req1".to_string(), to_sequence(vec![0, 1, 2]), 0, 1);
manager.push(&"req1".to_string(), &[3, 4]); // Push both tokens at once
// Send request [0, 1, 2] to worker 2
manager.add_request("req2".to_string(), to_sequence(vec![0, 1, 2]), 0, 2);
// Check new_blocks on tokens [0, 1, 2, 3, 4]
let new_blocks_map = manager.new_blocks(to_sequence(vec![0, 1, 2, 3, 4]));
assert_eq!(new_blocks_map[&0], 1); // Worker 0 would have 1 new block
assert_eq!(new_blocks_map[&1], 1); // Worker 1 would have 1 new block
assert_eq!(new_blocks_map[&2], 2); // Worker 2 would have 2 new blocks
manager.update_workers(vec![0, 1]);
manager.update_workers(vec![0, 1, 3]);
let new_blocks_map = manager.new_blocks(to_sequence(vec![0, 1, 2, 3, 4]));
assert_eq!(new_blocks_map.len(), 3);
assert_eq!(new_blocks_map[&3], 2);
let mut seq_manager = ActiveSequencesMultiWorker::new(block_size, worker_ids);
// Add requests to each worker
// Worker 0: sequence [0, 1, 2]
seq_manager.add_request(
"request_0".to_string(),
vec![0, 1, 2],
12, // ISL (3 blocks * 4 block_size)
0, // no overlap
0, // worker_id
);
// Worker 1: sequence [3, 4]
seq_manager.add_request(
"request_1".to_string(),
vec![3, 4],
8, // ISL (2 blocks * 4 block_size)
0, // no overlap
1, // worker_id
);
// Worker 2: sequence [0, 1, 2, 3]
seq_manager.add_request(
"request_2".to_string(),
vec![0, 1, 2, 3],
16, // ISL (4 blocks * 4 block_size)
0, // no overlap
2, // worker_id
);
// Verify active tokens after adding requests
let tokens_after_add = seq_manager.active_tokens();
assert_eq!(
tokens_after_add[&0], 12,
"Worker 0 should have 12 active tokens"
);
assert_eq!(
tokens_after_add[&1], 8,
"Worker 1 should have 8 active tokens"
);
assert_eq!(
tokens_after_add[&2], 16,
"Worker 2 should have 16 active tokens"
);
// Test potential blocks for sequence [0, 1]
let potential_blocks = seq_manager.potential_blocks(vec![0, 1]);
// Worker 0 should return 3 (already has blocks 0, 1, 2, so no new blocks needed for [0, 1])
assert_eq!(
potential_blocks[&0], 3,
"Worker 0 should have 3 potential blocks"
);
// Worker 1 should return 4 (has blocks 3, 4, would need to add blocks 0, 1)
assert_eq!(
potential_blocks[&1], 4,
"Worker 1 should have 4 potential blocks"
);
// Worker 2 should return 4 (already has blocks 0, 1, 2, 3, so no new blocks needed for [0, 1])
assert_eq!(
potential_blocks[&2], 4,
"Worker 2 should have 4 potential blocks"
);
// Free all original requests
seq_manager.free(&"request_0".to_string());
seq_manager.free(&"request_1".to_string());
seq_manager.free(&"request_2".to_string());
// Verify active blocks are zero for all workers
let active_blocks = seq_manager.active_blocks();
assert_eq!(active_blocks[&0], 0, "Worker 0 should have 0 active blocks");
assert_eq!(active_blocks[&1], 0, "Worker 1 should have 0 active blocks");
assert_eq!(active_blocks[&2], 0, "Worker 2 should have 0 active blocks");
// Verify active tokens are zero for all workers
let final_tokens = seq_manager.active_tokens();
assert_eq!(
final_tokens[&0], 0,
"Worker 0 should have 0 active tokens after freeing all"
);
assert_eq!(
final_tokens[&1], 0,
"Worker 1 should have 0 active tokens after freeing all"
);
assert_eq!(
final_tokens[&2], 0,
"Worker 2 should have 0 active tokens after freeing all"
);
}
}
// SPDX-FileCopyrightText: Copyright (c) 2024-2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
// SPDX-License-Identifier: Apache-2.0
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
use std::sync::Arc;
pub use crate::kv_router::protocols::ForwardPassMetrics;
use anyhow::Result;
use derive_builder::Builder;
use dynamo_runtime::pipeline::network::{
ingress::push_endpoint::PushEndpoint,
PushWorkHandler,
};
use dynamo_runtime::transports::nats::{self, ServiceExt};
use tokio::sync::watch;
use tokio_util::sync::CancellationToken;
use tracing as log;
#[derive(Builder)]
pub struct KvRoutedIngress {
#[builder(setter(into))]
pub service_name: String,
#[builder(setter(into))]
pub worker_id: String,
pub nats: nats::Client,
pub service_handler: Arc<dyn PushWorkHandler>,
pub metrics_rx: watch::Receiver<Arc<ForwardPassMetrics>>,
pub cancellation_token: CancellationToken,
}
/// version of crate
pub const VERSION: &str = env!("CARGO_PKG_VERSION");
impl KvRoutedIngress {
pub fn builder() -> KvRoutedIngressBuilder {
KvRoutedIngressBuilder::default()
}
pub async fn start(self) -> Result<()> {
let worker_id = self.worker_id;
log::trace!(
worker_id,
"Starting nats service: {}:{}",
self.service_name,
VERSION
);
let mut metrics_rx = self.metrics_rx;
let worker_id_clone = worker_id.clone();
let service = self
.nats
.client()
.service_builder()
.description("A handy min max service")
.stats_handler(move |name, stats| {
log::debug!(
worker_id = worker_id_clone.as_str(),
"[IN worker?] Stats for service {}: {:?}",
name,
stats
);
let metrics = metrics_rx.borrow_and_update().clone();
serde_json::to_value(&*metrics).unwrap()
})
.start(self.service_name.as_str(), VERSION)
.await
.map_err(|e| anyhow::anyhow!("Failed to start service: {e}"))?;
let group = service.group(self.service_name.as_str());
log::trace!(worker_id, "Starting endpoint: {}", worker_id);
// creates an endpoint for the service
let service_endpoint = group
.endpoint(worker_id.clone())
.await
.map_err(|e| anyhow::anyhow!("Failed to start endpoint: {e}"))?;
let push_endpoint = PushEndpoint::builder()
.service_handler(self.service_handler)
.cancellation_token(self.cancellation_token)
.build()
.map_err(|e| anyhow::anyhow!("Failed to build push endpoint: {e}"))?;
push_endpoint.start(service_endpoint).await
}
}
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment