Unverified commit df91fce2, authored by Yan Ru Pei and committed by GitHub
Browse files

feat: prefill aware routing (#1895)

parent ad8ad66b
...@@ -93,7 +93,7 @@ async fn mock_event_publisher(namespace: Namespace) { ...@@ -93,7 +93,7 @@ async fn mock_event_publisher(namespace: Namespace) {
let event = KVHitRateEvent { let event = KVHitRateEvent {
worker_id, worker_id,
isl_blocks, isl_blocks,
overlap_blocks, overlap_blocks: overlap_blocks as u32,
}; };
if let Err(e) = namespace.publish(KV_HIT_RATE_SUBJECT, &event).await { if let Err(e) = namespace.publish(KV_HIT_RATE_SUBJECT, &event).await {
......
...@@ -199,7 +199,7 @@ async fn app(runtime: Runtime) -> Result<()> { ...@@ -199,7 +199,7 @@ async fn app(runtime: Runtime) -> Result<()> {
&config_clone, &config_clone,
event.worker_id, event.worker_id,
event.isl_blocks, event.isl_blocks,
event.overlap_blocks, event.overlap_blocks as usize,
); );
} }
Err(e) => { Err(e) => {
......
...@@ -8,7 +8,7 @@ It supports these engines: mistralrs, llamacpp, sglang, vllm, and tensorrt-llm. ...@@ -8,7 +8,7 @@ It supports these engines: mistralrs, llamacpp, sglang, vllm, and tensorrt-llm.
Usage: Usage:
``` ```
dynamo-run in=[http|text|dyn://<path>|batch:<folder>] out=echo_core|echo_full|mistralrs|llamacpp|sglang|vllm|dyn [--http-port 8080] [--model-path <path>] [--model-name <served-model-name>] [--model-config <hf-repo>] [--tensor-parallel-size=1] [--context-length=N] [--num-nodes=1] [--node-rank=0] [--leader-addr=127.0.0.1:9876] [--base-gpu-id=0] [--extra-engine-args=args.json] [--router-mode random|round-robin|kv] [--kv-overlap-score-weight=1.0] [--router-temperature=0.5] [--use-kv-events=true] [--verbosity (-v|-vv)] dynamo-run in=[http|text|dyn://<path>|batch:<folder>] out=echo_core|echo_full|mistralrs|llamacpp|sglang|vllm|dyn [--http-port 8080] [--model-path <path>] [--model-name <served-model-name>] [--model-config <hf-repo>] [--tensor-parallel-size=1] [--context-length=N] [--num-nodes=1] [--node-rank=0] [--leader-addr=127.0.0.1:9876] [--base-gpu-id=0] [--extra-engine-args=args.json] [--router-mode random|round-robin|kv] [--kv-overlap-score-weight=1.0] [--router-temperature=0.0] [--use-kv-events=true] [--verbosity (-v|-vv)]
``` ```
Example: `dynamo run Qwen/Qwen3-0.6B` Example: `dynamo run Qwen/Qwen3-0.6B`
......
...@@ -118,13 +118,13 @@ pub struct Flags { ...@@ -118,13 +118,13 @@ pub struct Flags {
pub max_num_batched_tokens: Option<u32>, pub max_num_batched_tokens: Option<u32>,
/// KV Router: Weight for overlap score in worker selection. /// KV Router: Weight for overlap score in worker selection.
/// Higher values prioritize KV cache reuse. Default: 2.0 /// Higher values prioritize KV cache reuse. Default: 1.0
#[arg(long)] #[arg(long)]
pub kv_overlap_score_weight: Option<f64>, pub kv_overlap_score_weight: Option<f64>,
/// KV Router: Temperature for worker sampling via softmax. /// KV Router: Temperature for worker sampling via softmax.
/// Higher values promote more randomness, and 0 fallbacks to deterministic. /// Higher values promote more randomness, and 0 fallbacks to deterministic.
/// Default: 0.5 /// Default: 0.0
#[arg(long)] #[arg(long)]
pub router_temperature: Option<f64>, pub router_temperature: Option<f64>,
......
...@@ -78,7 +78,7 @@ impl Default for KvRouterConfig { ...@@ -78,7 +78,7 @@ impl Default for KvRouterConfig {
fn default() -> Self { fn default() -> Self {
Self { Self {
overlap_score_weight: 1.0, overlap_score_weight: 1.0,
router_temperature: 0.5, router_temperature: 0.0,
use_kv_events: true, use_kv_events: true,
max_num_batched_tokens: 8192, max_num_batched_tokens: 8192,
} }
...@@ -337,6 +337,7 @@ impl AsyncEngine<SingleIn<PreprocessedRequest>, ManyOut<Annotated<LLMEngineOutpu ...@@ -337,6 +337,7 @@ impl AsyncEngine<SingleIn<PreprocessedRequest>, ManyOut<Annotated<LLMEngineOutpu
let mut accumulated_tokens = Vec::new(); let mut accumulated_tokens = Vec::new();
let mut total_output_length = 0usize; let mut total_output_length = 0usize;
let mut last_block_index = (isl.saturating_sub(1)) / block_size; let mut last_block_index = (isl.saturating_sub(1)) / block_size;
let mut first_push_done = false;
while let Some(item) = response_stream.next().await { while let Some(item) = response_stream.next().await {
// Track tokens if they exist in the response // Track tokens if they exist in the response
...@@ -353,12 +354,19 @@ impl AsyncEngine<SingleIn<PreprocessedRequest>, ManyOut<Annotated<LLMEngineOutpu ...@@ -353,12 +354,19 @@ impl AsyncEngine<SingleIn<PreprocessedRequest>, ManyOut<Annotated<LLMEngineOutpu
accumulated_tokens.extend_from_slice(&output.token_ids); accumulated_tokens.extend_from_slice(&output.token_ids);
total_output_length += output.token_ids.len(); total_output_length += output.token_ids.len();
// Check if we've moved to a new block // Always push for the first generated token (to mark prefill done)
// or when we've moved to a new block
let current_block_index = (isl + total_output_length).saturating_sub(1) / block_size; let current_block_index = (isl + total_output_length).saturating_sub(1) / block_size;
if current_block_index > last_block_index { let should_push = (!first_push_done && total_output_length >= 1) ||
(first_push_done && current_block_index > last_block_index);
if should_push {
chooser.push(&request_id, &accumulated_tokens).await; chooser.push(&request_id, &accumulated_tokens).await;
accumulated_tokens.clear(); accumulated_tokens.clear();
last_block_index = current_block_index; last_block_index = current_block_index;
if !first_push_done {
first_push_done = true;
}
} }
yield item; yield item;
......
...@@ -36,7 +36,7 @@ pub struct WorkerSelectionResult { ...@@ -36,7 +36,7 @@ pub struct WorkerSelectionResult {
/// The number of blocks that the selected worker may already have cached. /// The number of blocks that the selected worker may already have cached.
/// This is not a guarantee, but an estimate. /// This is not a guarantee, but an estimate.
pub overlap_blocks: usize, pub overlap_blocks: u32,
} }
#[derive(Debug, Clone, Serialize, Deserialize, Default, PartialEq)] #[derive(Debug, Clone, Serialize, Deserialize, Default, PartialEq)]
......
...@@ -25,7 +25,6 @@ use tokio::sync::Mutex; ...@@ -25,7 +25,6 @@ use tokio::sync::Mutex;
use super::protocols::WorkerSelectionResult; use super::protocols::WorkerSelectionResult;
use super::WorkerSelector; use super::WorkerSelector;
use crate::kv_router::indexer::OverlapScores; use crate::kv_router::indexer::OverlapScores;
use crate::kv_router::indexer::WorkerId;
use crate::kv_router::protocols::LoadMetrics; use crate::kv_router::protocols::LoadMetrics;
use crate::kv_router::scoring::ProcessedEndpoints; use crate::kv_router::scoring::ProcessedEndpoints;
use crate::kv_router::sequence::ActiveSequencesMultiWorker; use crate::kv_router::sequence::ActiveSequencesMultiWorker;
...@@ -37,7 +36,7 @@ use crate::tokens::TokenBlockSequence; ...@@ -37,7 +36,7 @@ use crate::tokens::TokenBlockSequence;
pub struct KVHitRateEvent { pub struct KVHitRateEvent {
pub worker_id: i64, pub worker_id: i64,
pub isl_blocks: usize, pub isl_blocks: usize,
pub overlap_blocks: usize, pub overlap_blocks: u32,
} }
#[derive(Debug, thiserror::Error)] #[derive(Debug, thiserror::Error)]
...@@ -79,13 +78,15 @@ impl Endpoint { ...@@ -79,13 +78,15 @@ impl Endpoint {
#[derive(Debug)] #[derive(Debug)]
pub struct SchedulingResponse { pub struct SchedulingResponse {
pub best_worker_id: i64, pub best_worker_id: i64,
pub overlap_blocks: u32, // Add this field
pub endpoints_changed: Option<Vec<i64>>, pub endpoints_changed: Option<Vec<i64>>,
} }
pub struct SchedulingRequest { pub struct SchedulingRequest {
pub isl_tokens: usize, pub isl_tokens: usize,
pub overlap: OverlapScores, pub overlaps: OverlapScores,
pub potential_blocks: HashMap<i64, usize>, pub potential_blocks: HashMap<i64, usize>,
pub potential_tokens: HashMap<i64, usize>,
resp_tx: tokio::sync::oneshot::Sender<SchedulingResponse>, resp_tx: tokio::sync::oneshot::Sender<SchedulingResponse>,
} }
...@@ -174,6 +175,7 @@ impl KvScheduler { ...@@ -174,6 +175,7 @@ impl KvScheduler {
let response = SchedulingResponse { let response = SchedulingResponse {
best_worker_id: selection.worker_id, best_worker_id: selection.worker_id,
overlap_blocks: selection.overlap_blocks,
endpoints_changed: pending_endpoint_update.take(), endpoints_changed: pending_endpoint_update.take(),
}; };
request.respond(response); request.respond(response);
...@@ -207,18 +209,20 @@ impl KvScheduler { ...@@ -207,18 +209,20 @@ impl KvScheduler {
isl_tokens: usize, isl_tokens: usize,
block_size: u32, block_size: u32,
tokens: &[u32], tokens: &[u32],
overlap: OverlapScores, overlaps: OverlapScores,
) -> Result<i64, KvSchedulerError> { ) -> Result<i64, KvSchedulerError> {
let mut sequences = self.sequences.lock().await; let mut sequences = self.sequences.lock().await;
let token_sequence = TokenBlockSequence::from_slice(tokens, block_size, None); let token_sequence = TokenBlockSequence::from_slice(tokens, block_size, None);
let potential_blocks = sequences.potential_blocks(token_sequence); let (potential_blocks, potential_tokens) =
sequences.potential_blocks_and_tokens(token_sequence, overlaps.clone());
let (resp_tx, resp_rx) = tokio::sync::oneshot::channel(); let (resp_tx, resp_rx) = tokio::sync::oneshot::channel();
let request = SchedulingRequest { let request = SchedulingRequest {
isl_tokens, isl_tokens,
overlap, overlaps,
potential_blocks, potential_blocks,
potential_tokens,
resp_tx, resp_tx,
}; };
self.request_tx self.request_tx
...@@ -234,31 +238,16 @@ impl KvScheduler { ...@@ -234,31 +238,16 @@ impl KvScheduler {
} }
let token_sequence = TokenBlockSequence::from_slice(tokens, block_size, None); let token_sequence = TokenBlockSequence::from_slice(tokens, block_size, None);
sequences.add_request(request_id, token_sequence, response.best_worker_id); sequences.add_request(
request_id,
token_sequence,
response.overlap_blocks,
response.best_worker_id,
);
Ok(response.best_worker_id) Ok(response.best_worker_id)
} }
/// Find the potential blocks for each worker if the sequence were routed there
pub async fn potential_blocks(
&self,
token_sequence: TokenBlockSequence,
) -> HashMap<i64, usize> {
let sequences = self.sequences.lock().await;
sequences.potential_blocks(token_sequence)
}
/// Add a new request with its initial tokens to a specific worker
pub async fn add_request(
&self,
request_id: String,
token_sequence: TokenBlockSequence,
worker_id: WorkerId,
) {
let mut sequences = self.sequences.lock().await;
sequences.add_request(request_id, token_sequence, worker_id)
}
/// Push tokens to a specific request's sequence /// Push tokens to a specific request's sequence
pub async fn push(&self, request_id: &String, tokens: &[u32]) { pub async fn push(&self, request_id: &String, tokens: &[u32]) {
let mut sequences = self.sequences.lock().await; let mut sequences = self.sequences.lock().await;
...@@ -370,34 +359,47 @@ impl WorkerSelector for DefaultWorkerSelector { ...@@ -370,34 +359,47 @@ impl WorkerSelector for DefaultWorkerSelector {
return Err(KvSchedulerError::NoEndpoints); return Err(KvSchedulerError::NoEndpoints);
} }
let request_blocks = request.isl_tokens.div_ceil(block_size as usize); let isl = request.isl_tokens;
let request_blocks = isl.div_ceil(block_size as usize);
let overlaps = &request.overlaps.scores;
// active blocks for decoding
let potential_active_blocks = &request.potential_blocks; let potential_active_blocks = &request.potential_blocks;
// active tokens in the batch (processed by the linear layers), mostly prefill tokens
let potential_active_tokens = &request.potential_tokens;
let mut worker_logits = HashMap::new(); let mut worker_logits = HashMap::new();
let mut max_logit = f64::NEG_INFINITY; let mut max_logit = f64::NEG_INFINITY;
// Calculate logits for each worker // Calculate logits for each worker
for (worker_id, _) in workers.endpoints.iter() { for (worker_id, _) in workers.endpoints.iter() {
let cached_blocks = request.overlap.scores.get(worker_id).copied().unwrap_or(0) as f64; // this is the number of tokens each worker would have if the request were scheduled there
let prefill_blocks = request_blocks as f64 - cached_blocks; let potential_tokens = *potential_active_tokens.get(worker_id).unwrap_or_else(|| {
tracing::warn!(
"assuming {isl} tokens for {worker_id}, as the endpoint does not exist yet"
);
&isl
}) as f64;
// this is the number of blocks each worker would have if the request were scheduled there // this is the number of blocks each worker would have if the request were scheduled there
let potential_blocks = *potential_active_blocks.get(worker_id).unwrap_or_else(|| let potential_blocks = *potential_active_blocks.get(worker_id).unwrap_or_else(||
{tracing::warn!("assuming 0 decoding blocks for {worker_id}, as the load metrics endpoint does not exist yet"); {tracing::warn!("assuming {request_blocks} decoding blocks for {worker_id}, as the endpoint does not exist yet");
&0 &request_blocks
}) as f64; }) as f64;
let potential_prefill_blocks = potential_tokens / (block_size as f64);
// Calculate logit (lower is better) // Calculate logit (lower is better)
let logit = let logit = self.kv_router_config.overlap_score_weight * potential_prefill_blocks
self.kv_router_config.overlap_score_weight * prefill_blocks + potential_blocks; + potential_blocks;
max_logit = max_logit.max(logit); max_logit = max_logit.max(logit);
worker_logits.insert(*worker_id, logit); worker_logits.insert(*worker_id, logit);
tracing::info!( tracing::info!(
"Formula for {worker_id}: {logit:.3} = {:.1} * {prefill_blocks:.3} + {potential_blocks:.3} (cached_blocks: {cached_blocks})", "Formula for {worker_id}: {logit:.3} = {:.1} * {potential_prefill_blocks:.3} + {potential_blocks:.3} (cached_blocks: {})",
self.kv_router_config.overlap_score_weight, self.kv_router_config.overlap_score_weight,
cached_blocks = cached_blocks overlaps.get(worker_id).unwrap_or(&0),
); );
} }
...@@ -412,12 +414,7 @@ impl WorkerSelector for DefaultWorkerSelector { ...@@ -412,12 +414,7 @@ impl WorkerSelector for DefaultWorkerSelector {
let temperature = self.kv_router_config.router_temperature; let temperature = self.kv_router_config.router_temperature;
let best_worker_id = softmax_sample(&worker_logits, temperature); let best_worker_id = softmax_sample(&worker_logits, temperature);
let overlap_blocks = request let overlap_blocks = overlaps.get(&best_worker_id).copied().unwrap_or(0);
.overlap
.scores
.get(&best_worker_id)
.copied()
.unwrap_or(0) as usize;
let best_logit = worker_logits[&best_worker_id]; let best_logit = worker_logits[&best_worker_id];
tracing::info!( tracing::info!(
......
...@@ -34,6 +34,7 @@ ...@@ -34,6 +34,7 @@
//! Each block is identified by a hash of its contents, allowing for deduplication when multiple //! Each block is identified by a hash of its contents, allowing for deduplication when multiple
//! requests share common prefixes (e.g., system prompts, few-shot examples). //! requests share common prefixes (e.g., system prompts, few-shot examples).
use crate::kv_router::indexer::OverlapScores;
use crate::kv_router::indexer::WorkerId; use crate::kv_router::indexer::WorkerId;
use crate::tokens::blocks::UniqueBlock; use crate::tokens::blocks::UniqueBlock;
use crate::tokens::TokenBlockSequence; use crate::tokens::TokenBlockSequence;
...@@ -76,6 +77,8 @@ pub struct ActiveSequences { ...@@ -76,6 +77,8 @@ pub struct ActiveSequences {
partial_blocks: HashMap<RequestId, UniqueBlock>, partial_blocks: HashMap<RequestId, UniqueBlock>,
prefill_tokens: HashMap<RequestId, usize>,
unique_blocks: HashMap<UniqueBlock, HashSet<RequestId>>, unique_blocks: HashMap<UniqueBlock, HashSet<RequestId>>,
#[getter(copy)] #[getter(copy)]
...@@ -83,6 +86,9 @@ pub struct ActiveSequences { ...@@ -83,6 +86,9 @@ pub struct ActiveSequences {
#[getter(copy)] #[getter(copy)]
active_blocks: usize, active_blocks: usize,
#[getter(copy)]
active_tokens: usize,
} }
impl ActiveSequences { impl ActiveSequences {
...@@ -94,9 +100,11 @@ impl ActiveSequences { ...@@ -94,9 +100,11 @@ impl ActiveSequences {
Self { Self {
active_seqs: HashMap::new(), active_seqs: HashMap::new(),
partial_blocks: HashMap::new(), partial_blocks: HashMap::new(),
prefill_tokens: HashMap::new(),
unique_blocks: HashMap::new(), unique_blocks: HashMap::new(),
block_size, block_size,
active_blocks: 0, active_blocks: 0,
active_tokens: 0,
} }
} }
...@@ -135,7 +143,13 @@ impl ActiveSequences { ...@@ -135,7 +143,13 @@ impl ActiveSequences {
&mut self, &mut self,
request_id: RequestId, request_id: RequestId,
token_sequence: TokenBlockSequence, token_sequence: TokenBlockSequence,
overlap: u32,
) -> usize { ) -> usize {
let prefill_tokens = self.new_tokens(&token_sequence, overlap);
self.prefill_tokens
.insert(request_id.clone(), prefill_tokens);
self.active_tokens += prefill_tokens;
let blocks = create_unique_blocks_from_sequence(&token_sequence, None, self.block_size); let blocks = create_unique_blocks_from_sequence(&token_sequence, None, self.block_size);
for block in &blocks { for block in &blocks {
...@@ -147,6 +161,25 @@ impl ActiveSequences { ...@@ -147,6 +161,25 @@ impl ActiveSequences {
self.active_blocks self.active_blocks
} }
/// Returns how many tokens of `token_sequence` remain after discounting
/// `overlap` blocks (each worth `block_size` tokens).
///
/// The result is `total_tokens - overlap * block_size`; callers in this
/// file treat it as the number of prefill tokens for the request.
///
/// # Panics
///
/// Panics when `overlap * block_size` exceeds the sequence's total token
/// count, since the subtraction would otherwise go negative.
pub fn new_tokens(&self, token_sequence: &TokenBlockSequence, overlap: u32) -> usize {
    let input_tokens = token_sequence.total_tokens();
    let covered_tokens = (overlap as usize) * self.block_size;

    match input_tokens.checked_sub(covered_tokens) {
        Some(remaining) => remaining,
        None => panic!("prefill_tokens < 0 with overlap {overlap} and ISL {input_tokens}"),
    }
}
/// Projects the block and token totals this tracker would hold if
/// `token_sequence` were added with the given `overlap`.
///
/// Returns `(blocks, tokens)` where:
/// - `blocks` is the current `active_blocks` plus `new_blocks(token_sequence)`;
/// - `tokens` is the current `active_tokens` plus
///   `new_tokens(token_sequence, overlap)`.
///
/// Nothing is mutated; this is a read-only estimate.
pub fn potential_blocks_and_tokens(
    &self,
    token_sequence: &TokenBlockSequence,
    overlap: u32,
) -> (usize, usize) {
    let added_blocks = self.new_blocks(token_sequence);
    let added_tokens = self.new_tokens(token_sequence, overlap);

    (
        added_blocks + self.active_blocks,
        added_tokens + self.active_tokens,
    )
}
/// Match a request against existing blocks and return the number of new blocks that would be added /// Match a request against existing blocks and return the number of new blocks that would be added
pub fn new_blocks(&self, token_sequence: &TokenBlockSequence) -> usize { pub fn new_blocks(&self, token_sequence: &TokenBlockSequence) -> usize {
let blocks = create_unique_blocks_from_sequence(token_sequence, None, self.block_size); let blocks = create_unique_blocks_from_sequence(token_sequence, None, self.block_size);
...@@ -165,6 +198,12 @@ impl ActiveSequences { ...@@ -165,6 +198,12 @@ impl ActiveSequences {
/// Free all blocks associated with a request /// Free all blocks associated with a request
pub fn free(&mut self, request_id: &RequestId) -> usize { pub fn free(&mut self, request_id: &RequestId) -> usize {
// decoding has one active token
self.active_tokens = self
.active_tokens
.checked_sub(self.prefill_tokens.remove(request_id).unwrap_or(1))
.expect("active_tokens < 0");
let Some(token_seq) = self.active_seqs.get(request_id) else { let Some(token_seq) = self.active_seqs.get(request_id) else {
tracing::warn!("Trying to free free non-existent request {request_id}"); tracing::warn!("Trying to free free non-existent request {request_id}");
return 0; return 0;
...@@ -187,6 +226,16 @@ impl ActiveSequences { ...@@ -187,6 +226,16 @@ impl ActiveSequences {
/// Push tokens to a specific request's sequence /// Push tokens to a specific request's sequence
pub fn push(&mut self, request_id: &RequestId, tokens: &[u32]) -> usize { pub fn push(&mut self, request_id: &RequestId, tokens: &[u32]) -> usize {
if let Some(prefill_tokens) = self.prefill_tokens.get(request_id).cloned() {
self.prefill_tokens.remove(request_id);
// decoding has one active token
self.active_tokens = self
.active_tokens
.checked_sub(prefill_tokens)
.expect("active_tokens < 0")
+ 1;
};
// Collect operations to perform after releasing the borrow // Collect operations to perform after releasing the borrow
let mut blocks_to_remove = Vec::new(); let mut blocks_to_remove = Vec::new();
let mut blocks_to_add = Vec::new(); let mut blocks_to_add = Vec::new();
...@@ -239,6 +288,7 @@ enum UpdateSequences { ...@@ -239,6 +288,7 @@ enum UpdateSequences {
AddRequest { AddRequest {
request_id: RequestId, request_id: RequestId,
token_sequence: TokenBlockSequence, token_sequence: TokenBlockSequence,
overlap: u32,
}, },
Free { Free {
request_id: RequestId, request_id: RequestId,
...@@ -255,6 +305,11 @@ enum UpdateSequences { ...@@ -255,6 +305,11 @@ enum UpdateSequences {
token_sequence: Arc<TokenBlockSequence>, token_sequence: Arc<TokenBlockSequence>,
resp_tx: mpsc::SyncSender<usize>, resp_tx: mpsc::SyncSender<usize>,
}, },
PotentialBlocksAndTokens {
token_sequence: Arc<TokenBlockSequence>,
overlap: u32,
resp_tx: mpsc::SyncSender<(usize, usize)>,
},
ActiveBlocks { ActiveBlocks {
resp_tx: mpsc::SyncSender<usize>, resp_tx: mpsc::SyncSender<usize>,
}, },
...@@ -302,8 +357,9 @@ impl ActiveSequencesMultiWorker { ...@@ -302,8 +357,9 @@ impl ActiveSequencesMultiWorker {
UpdateSequences::AddRequest { UpdateSequences::AddRequest {
request_id, request_id,
token_sequence, token_sequence,
overlap,
} => { } => {
active_sequences.add_request(request_id, token_sequence); active_sequences.add_request(request_id, token_sequence, overlap);
} }
UpdateSequences::Free { request_id } => { UpdateSequences::Free { request_id } => {
active_sequences.free(&request_id); active_sequences.free(&request_id);
...@@ -325,6 +381,15 @@ impl ActiveSequencesMultiWorker { ...@@ -325,6 +381,15 @@ impl ActiveSequencesMultiWorker {
let potential_blocks = active_sequences.potential_blocks(&token_sequence); let potential_blocks = active_sequences.potential_blocks(&token_sequence);
let _ = resp_tx.send(potential_blocks); let _ = resp_tx.send(potential_blocks);
} }
UpdateSequences::PotentialBlocksAndTokens {
token_sequence,
overlap,
resp_tx,
} => {
let potential_tokens =
active_sequences.potential_blocks_and_tokens(&token_sequence, overlap);
let _ = resp_tx.send(potential_tokens);
}
UpdateSequences::ActiveBlocks { resp_tx } => { UpdateSequences::ActiveBlocks { resp_tx } => {
let active_blocks = active_sequences.active_blocks(); let active_blocks = active_sequences.active_blocks();
let _ = resp_tx.send(active_blocks); let _ = resp_tx.send(active_blocks);
...@@ -379,6 +444,7 @@ impl ActiveSequencesMultiWorker { ...@@ -379,6 +444,7 @@ impl ActiveSequencesMultiWorker {
&mut self, &mut self,
request_id: RequestId, request_id: RequestId,
token_sequence: TokenBlockSequence, token_sequence: TokenBlockSequence,
overlap: u32,
worker_id: WorkerId, worker_id: WorkerId,
) { ) {
if !self.senders.contains_key(&worker_id) { if !self.senders.contains_key(&worker_id) {
...@@ -391,6 +457,7 @@ impl ActiveSequencesMultiWorker { ...@@ -391,6 +457,7 @@ impl ActiveSequencesMultiWorker {
.send(UpdateSequences::AddRequest { .send(UpdateSequences::AddRequest {
request_id, request_id,
token_sequence, token_sequence,
overlap,
}) })
.expect("Failed to send add_request command to worker"); .expect("Failed to send add_request command to worker");
} }
...@@ -482,6 +549,43 @@ impl ActiveSequencesMultiWorker { ...@@ -482,6 +549,43 @@ impl ActiveSequencesMultiWorker {
}) })
} }
/// Ask every worker for the blocks and tokens (new + active) it would hold
/// if `token_sequence` were routed to it.
///
/// Each worker receives its own overlap value looked up in
/// `overlaps.scores`; a worker without an entry is queried with an overlap
/// of 0.
///
/// Returns `(potential_blocks, potential_tokens)`, both keyed by worker id.
///
/// # Panics
///
/// Panics if a query cannot be sent to a worker, or if a worker does not
/// answer within one second.
pub fn potential_blocks_and_tokens(
    &self,
    token_sequence: TokenBlockSequence,
    overlaps: OverlapScores,
) -> (HashMap<WorkerId, usize>, HashMap<WorkerId, usize>) {
    // One shared copy of the sequence is handed to every worker.
    let shared_sequence = Arc::new(token_sequence);

    // Phase 1: dispatch a query to every worker before waiting on any reply,
    // so the workers can compute their answers concurrently.
    let pending: Vec<_> = self
        .senders
        .iter()
        .map(|(worker_id, sender)| {
            // Zero-capacity channel: the worker's reply is handed over
            // directly when this side calls `recv_timeout`.
            let (resp_tx, resp_rx) = mpsc::sync_channel(0);
            let query = UpdateSequences::PotentialBlocksAndTokens {
                token_sequence: shared_sequence.clone(),
                overlap: overlaps.scores.get(worker_id).copied().unwrap_or(0),
                resp_tx,
            };
            sender
                .send(query)
                .expect("Failed to send potential_tokens command to worker");
            (worker_id, resp_rx)
        })
        .collect();

    // Phase 2: gather the replies in dispatch order.
    let mut blocks_by_worker = HashMap::new();
    let mut tokens_by_worker = HashMap::new();
    for (worker_id, resp_rx) in pending {
        let (blocks, tokens) = resp_rx
            .recv_timeout(Duration::from_secs(1))
            .expect("Failed to receive response from worker");
        blocks_by_worker.insert(*worker_id, blocks);
        tokens_by_worker.insert(*worker_id, tokens);
    }

    (blocks_by_worker, tokens_by_worker)
}
/// Query all workers for their current number of active blocks /// Query all workers for their current number of active blocks
pub fn active_blocks(&self) -> HashMap<WorkerId, usize> { pub fn active_blocks(&self) -> HashMap<WorkerId, usize> {
self.query_workers(None, |_, resp_tx| UpdateSequences::ActiveBlocks { resp_tx }) self.query_workers(None, |_, resp_tx| UpdateSequences::ActiveBlocks { resp_tx })
...@@ -515,14 +619,15 @@ mod tests { ...@@ -515,14 +619,15 @@ mod tests {
|tokens: Vec<u32>| Tokens::from(tokens).into_sequence(block_size as u32, None); |tokens: Vec<u32>| Tokens::from(tokens).into_sequence(block_size as u32, None);
// Step 1: Add request 0 with tokens [0, 1, 2], then push 3 and 4 // Step 1: Add request 0 with tokens [0, 1, 2], then push 3 and 4
manager.add_request("0".to_string(), to_sequence(vec![0, 1, 2])); manager.add_request("0".to_string(), to_sequence(vec![0, 1, 2]), 0);
manager.push(&"0".to_string(), &[3, 4]); // Push both tokens at once manager.push(&"0".to_string(), &[3, 4]); // Push both tokens at once
assert_eq!(manager.active_tokens(), 1);
assert_eq!(manager.active_blocks(), 2); assert_eq!(manager.active_blocks(), 2);
assert_eq!(manager.partial_blocks.len(), 1); assert_eq!(manager.partial_blocks.len(), 1);
// Step 2: Add request 1 with tokens [0, 1, 2, 3, 4, 5, 6] // Step 2: Add request 1 with tokens [0, 1, 2, 3, 4, 5, 6]
manager.add_request("1".to_string(), to_sequence(vec![0, 1, 2, 3, 4, 5, 6])); manager.add_request("1".to_string(), to_sequence(vec![0, 1, 2, 3, 4, 5, 6]), 1);
assert_eq!(manager.active_tokens(), 1 + 3);
assert_eq!(manager.active_blocks(), 3); assert_eq!(manager.active_blocks(), 3);
// Check that only one key is FullBlock with both requests sharing it // Check that only one key is FullBlock with both requests sharing it
...@@ -551,6 +656,7 @@ mod tests { ...@@ -551,6 +656,7 @@ mod tests {
// Step 4: Free request 0 // Step 4: Free request 0
manager.free(&"0".to_string()); manager.free(&"0".to_string());
assert_eq!(manager.active_tokens(), 0);
assert_eq!(manager.active_blocks(), 0); assert_eq!(manager.active_blocks(), 0);
assert_eq!(manager.unique_blocks.len(), 0); assert_eq!(manager.unique_blocks.len(), 0);
assert_eq!(manager.partial_blocks.len(), 0); assert_eq!(manager.partial_blocks.len(), 0);
...@@ -566,14 +672,14 @@ mod tests { ...@@ -566,14 +672,14 @@ mod tests {
|tokens: Vec<u32>| Tokens::from(tokens).into_sequence(block_size as u32, None); |tokens: Vec<u32>| Tokens::from(tokens).into_sequence(block_size as u32, None);
// Send request [0, 1, 2, 3] to worker 0 // Send request [0, 1, 2, 3] to worker 0
manager.add_request("req0".to_string(), to_sequence(vec![0, 1, 2, 3]), 0); manager.add_request("req0".to_string(), to_sequence(vec![0, 1, 2, 3]), 0, 0);
// Send request [0, 1, 2] to worker 1, then push 3 and 4 // Send request [0, 1, 2] to worker 1, then push 3 and 4
manager.add_request("req1".to_string(), to_sequence(vec![0, 1, 2]), 1); manager.add_request("req1".to_string(), to_sequence(vec![0, 1, 2]), 0, 1);
manager.push(&"req1".to_string(), &[3, 4]); // Push both tokens at once manager.push(&"req1".to_string(), &[3, 4]); // Push both tokens at once
// Send request [0, 1, 2] to worker 2 // Send request [0, 1, 2] to worker 2
manager.add_request("req2".to_string(), to_sequence(vec![0, 1, 2]), 2); manager.add_request("req2".to_string(), to_sequence(vec![0, 1, 2]), 0, 2);
// Check new_blocks on tokens [0, 1, 2, 3, 4] // Check new_blocks on tokens [0, 1, 2, 3, 4]
let new_blocks_map = manager.new_blocks(to_sequence(vec![0, 1, 2, 3, 4])); let new_blocks_map = manager.new_blocks(to_sequence(vec![0, 1, 2, 3, 4]));
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment