"lib/bindings/vscode:/vscode.git/clone" did not exist on "31f5ed3ce7db3eef7b2991963644f5fdb17ab063"
Unverified Commit df91fce2 authored by Yan Ru Pei's avatar Yan Ru Pei Committed by GitHub
Browse files

feat: prefill aware routing (#1895)

parent ad8ad66b
......@@ -93,7 +93,7 @@ async fn mock_event_publisher(namespace: Namespace) {
let event = KVHitRateEvent {
worker_id,
isl_blocks,
overlap_blocks,
overlap_blocks: overlap_blocks as u32,
};
if let Err(e) = namespace.publish(KV_HIT_RATE_SUBJECT, &event).await {
......
......@@ -199,7 +199,7 @@ async fn app(runtime: Runtime) -> Result<()> {
&config_clone,
event.worker_id,
event.isl_blocks,
event.overlap_blocks,
event.overlap_blocks as usize,
);
}
Err(e) => {
......
......@@ -8,7 +8,7 @@ It supports these engines: mistralrs, llamacpp, sglang, vllm, and tensorrt-llm.
Usage:
```
dynamo-run in=[http|text|dyn://<path>|batch:<folder>] out=echo_core|echo_full|mistralrs|llamacpp|sglang|vllm|dyn [--http-port 8080] [--model-path <path>] [--model-name <served-model-name>] [--model-config <hf-repo>] [--tensor-parallel-size=1] [--context-length=N] [--num-nodes=1] [--node-rank=0] [--leader-addr=127.0.0.1:9876] [--base-gpu-id=0] [--extra-engine-args=args.json] [--router-mode random|round-robin|kv] [--kv-overlap-score-weight=1.0] [--router-temperature=0.5] [--use-kv-events=true] [--verbosity (-v|-vv)]
dynamo-run in=[http|text|dyn://<path>|batch:<folder>] out=echo_core|echo_full|mistralrs|llamacpp|sglang|vllm|dyn [--http-port 8080] [--model-path <path>] [--model-name <served-model-name>] [--model-config <hf-repo>] [--tensor-parallel-size=1] [--context-length=N] [--num-nodes=1] [--node-rank=0] [--leader-addr=127.0.0.1:9876] [--base-gpu-id=0] [--extra-engine-args=args.json] [--router-mode random|round-robin|kv] [--kv-overlap-score-weight=1.0] [--router-temperature=0.0] [--use-kv-events=true] [--verbosity (-v|-vv)]
```
Example: `dynamo run Qwen/Qwen3-0.6B`
......
......@@ -118,13 +118,13 @@ pub struct Flags {
pub max_num_batched_tokens: Option<u32>,
/// KV Router: Weight for overlap score in worker selection.
/// Higher values prioritize KV cache reuse. Default: 2.0
/// Higher values prioritize KV cache reuse. Default: 1.0
#[arg(long)]
pub kv_overlap_score_weight: Option<f64>,
/// KV Router: Temperature for worker sampling via softmax.
/// Higher values promote more randomness; 0 falls back to deterministic selection.
/// Default: 0.5
/// Default: 0.0
#[arg(long)]
pub router_temperature: Option<f64>,
......
......@@ -78,7 +78,7 @@ impl Default for KvRouterConfig {
fn default() -> Self {
Self {
overlap_score_weight: 1.0,
router_temperature: 0.5,
router_temperature: 0.0,
use_kv_events: true,
max_num_batched_tokens: 8192,
}
......@@ -337,6 +337,7 @@ impl AsyncEngine<SingleIn<PreprocessedRequest>, ManyOut<Annotated<LLMEngineOutpu
let mut accumulated_tokens = Vec::new();
let mut total_output_length = 0usize;
let mut last_block_index = (isl.saturating_sub(1)) / block_size;
let mut first_push_done = false;
while let Some(item) = response_stream.next().await {
// Track tokens if they exist in the response
......@@ -353,12 +354,19 @@ impl AsyncEngine<SingleIn<PreprocessedRequest>, ManyOut<Annotated<LLMEngineOutpu
accumulated_tokens.extend_from_slice(&output.token_ids);
total_output_length += output.token_ids.len();
// Check if we've moved to a new block
// Always push for the first generated token (to mark prefill done)
// or when we've moved to a new block
let current_block_index = (isl + total_output_length).saturating_sub(1) / block_size;
if current_block_index > last_block_index {
let should_push = (!first_push_done && total_output_length >= 1) ||
(first_push_done && current_block_index > last_block_index);
if should_push {
chooser.push(&request_id, &accumulated_tokens).await;
accumulated_tokens.clear();
last_block_index = current_block_index;
if !first_push_done {
first_push_done = true;
}
}
yield item;
......
......@@ -36,7 +36,7 @@ pub struct WorkerSelectionResult {
/// The number of blocks that the selected worker may already have cached.
/// This is not a guarantee, but an estimate.
pub overlap_blocks: usize,
pub overlap_blocks: u32,
}
#[derive(Debug, Clone, Serialize, Deserialize, Default, PartialEq)]
......
......@@ -25,7 +25,6 @@ use tokio::sync::Mutex;
use super::protocols::WorkerSelectionResult;
use super::WorkerSelector;
use crate::kv_router::indexer::OverlapScores;
use crate::kv_router::indexer::WorkerId;
use crate::kv_router::protocols::LoadMetrics;
use crate::kv_router::scoring::ProcessedEndpoints;
use crate::kv_router::sequence::ActiveSequencesMultiWorker;
......@@ -37,7 +36,7 @@ use crate::tokens::TokenBlockSequence;
pub struct KVHitRateEvent {
pub worker_id: i64,
pub isl_blocks: usize,
pub overlap_blocks: usize,
pub overlap_blocks: u32,
}
#[derive(Debug, thiserror::Error)]
......@@ -79,13 +78,15 @@ impl Endpoint {
/// Reply produced by the scheduler for a single scheduling request.
#[derive(Debug)]
pub struct SchedulingResponse {
/// Id of the worker picked by the worker selector for this request.
pub best_worker_id: i64,
/// Overlap (in blocks) reported by the selector for the chosen worker.
/// This is an estimate of already-cached blocks, not a guarantee.
pub overlap_blocks: u32,
/// Pending endpoint update taken by the scheduler when responding, if any.
/// NOTE(review): presumably the ids of workers whose endpoints changed — confirm.
pub endpoints_changed: Option<Vec<i64>>,
}
pub struct SchedulingRequest {
pub isl_tokens: usize,
pub overlap: OverlapScores,
pub overlaps: OverlapScores,
pub potential_blocks: HashMap<i64, usize>,
pub potential_tokens: HashMap<i64, usize>,
resp_tx: tokio::sync::oneshot::Sender<SchedulingResponse>,
}
......@@ -174,6 +175,7 @@ impl KvScheduler {
let response = SchedulingResponse {
best_worker_id: selection.worker_id,
overlap_blocks: selection.overlap_blocks,
endpoints_changed: pending_endpoint_update.take(),
};
request.respond(response);
......@@ -207,18 +209,20 @@ impl KvScheduler {
isl_tokens: usize,
block_size: u32,
tokens: &[u32],
overlap: OverlapScores,
overlaps: OverlapScores,
) -> Result<i64, KvSchedulerError> {
let mut sequences = self.sequences.lock().await;
let token_sequence = TokenBlockSequence::from_slice(tokens, block_size, None);
let potential_blocks = sequences.potential_blocks(token_sequence);
let (potential_blocks, potential_tokens) =
sequences.potential_blocks_and_tokens(token_sequence, overlaps.clone());
let (resp_tx, resp_rx) = tokio::sync::oneshot::channel();
let request = SchedulingRequest {
isl_tokens,
overlap,
overlaps,
potential_blocks,
potential_tokens,
resp_tx,
};
self.request_tx
......@@ -234,31 +238,16 @@ impl KvScheduler {
}
let token_sequence = TokenBlockSequence::from_slice(tokens, block_size, None);
sequences.add_request(request_id, token_sequence, response.best_worker_id);
sequences.add_request(
request_id,
token_sequence,
response.overlap_blocks,
response.best_worker_id,
);
Ok(response.best_worker_id)
}
/// Report, per worker, the block count that would result from routing
/// `token_sequence` to that worker.
///
/// Holds the `sequences` lock only for the duration of the query.
pub async fn potential_blocks(
    &self,
    token_sequence: TokenBlockSequence,
) -> HashMap<i64, usize> {
    self.sequences.lock().await.potential_blocks(token_sequence)
}
/// Register a new request and its initial tokens on the given worker.
///
/// Acquires the `sequences` lock and forwards the call unchanged.
pub async fn add_request(
    &self,
    request_id: String,
    token_sequence: TokenBlockSequence,
    worker_id: WorkerId,
) {
    self.sequences
        .lock()
        .await
        .add_request(request_id, token_sequence, worker_id);
}
/// Push tokens to a specific request's sequence
pub async fn push(&self, request_id: &String, tokens: &[u32]) {
let mut sequences = self.sequences.lock().await;
......@@ -370,34 +359,47 @@ impl WorkerSelector for DefaultWorkerSelector {
return Err(KvSchedulerError::NoEndpoints);
}
let request_blocks = request.isl_tokens.div_ceil(block_size as usize);
let isl = request.isl_tokens;
let request_blocks = isl.div_ceil(block_size as usize);
let overlaps = &request.overlaps.scores;
// active blocks for decoding
let potential_active_blocks = &request.potential_blocks;
// active tokens in the batch (processed by the linear layers), mostly prefill tokens
let potential_active_tokens = &request.potential_tokens;
let mut worker_logits = HashMap::new();
let mut max_logit = f64::NEG_INFINITY;
// Calculate logits for each worker
for (worker_id, _) in workers.endpoints.iter() {
let cached_blocks = request.overlap.scores.get(worker_id).copied().unwrap_or(0) as f64;
let prefill_blocks = request_blocks as f64 - cached_blocks;
// this is the number of tokens each worker would have if the request were scheduled there
let potential_tokens = *potential_active_tokens.get(worker_id).unwrap_or_else(|| {
tracing::warn!(
"assuming {isl} tokens for {worker_id}, as the endpoint does not exist yet"
);
&isl
}) as f64;
// this is the number of blocks each worker would have if the request were scheduled there
let potential_blocks = *potential_active_blocks.get(worker_id).unwrap_or_else(||
{tracing::warn!("assuming 0 decoding blocks for {worker_id}, as the load metrics endpoint does not exist yet");
&0
{tracing::warn!("assuming {request_blocks} decoding blocks for {worker_id}, as the endpoint does not exist yet");
&request_blocks
}) as f64;
let potential_prefill_blocks = potential_tokens / (block_size as f64);
// Calculate logit (lower is better)
let logit =
self.kv_router_config.overlap_score_weight * prefill_blocks + potential_blocks;
let logit = self.kv_router_config.overlap_score_weight * potential_prefill_blocks
+ potential_blocks;
max_logit = max_logit.max(logit);
worker_logits.insert(*worker_id, logit);
tracing::info!(
"Formula for {worker_id}: {logit:.3} = {:.1} * {prefill_blocks:.3} + {potential_blocks:.3} (cached_blocks: {cached_blocks})",
"Formula for {worker_id}: {logit:.3} = {:.1} * {potential_prefill_blocks:.3} + {potential_blocks:.3} (cached_blocks: {})",
self.kv_router_config.overlap_score_weight,
cached_blocks = cached_blocks
overlaps.get(worker_id).unwrap_or(&0),
);
}
......@@ -412,12 +414,7 @@ impl WorkerSelector for DefaultWorkerSelector {
let temperature = self.kv_router_config.router_temperature;
let best_worker_id = softmax_sample(&worker_logits, temperature);
let overlap_blocks = request
.overlap
.scores
.get(&best_worker_id)
.copied()
.unwrap_or(0) as usize;
let overlap_blocks = overlaps.get(&best_worker_id).copied().unwrap_or(0);
let best_logit = worker_logits[&best_worker_id];
tracing::info!(
......
......@@ -34,6 +34,7 @@
//! Each block is identified by a hash of its contents, allowing for deduplication when multiple
//! requests share common prefixes (e.g., system prompts, few-shot examples).
use crate::kv_router::indexer::OverlapScores;
use crate::kv_router::indexer::WorkerId;
use crate::tokens::blocks::UniqueBlock;
use crate::tokens::TokenBlockSequence;
......@@ -76,6 +77,8 @@ pub struct ActiveSequences {
partial_blocks: HashMap<RequestId, UniqueBlock>,
prefill_tokens: HashMap<RequestId, usize>,
unique_blocks: HashMap<UniqueBlock, HashSet<RequestId>>,
#[getter(copy)]
......@@ -83,6 +86,9 @@ pub struct ActiveSequences {
#[getter(copy)]
active_blocks: usize,
#[getter(copy)]
active_tokens: usize,
}
impl ActiveSequences {
......@@ -94,9 +100,11 @@ impl ActiveSequences {
Self {
active_seqs: HashMap::new(),
partial_blocks: HashMap::new(),
prefill_tokens: HashMap::new(),
unique_blocks: HashMap::new(),
block_size,
active_blocks: 0,
active_tokens: 0,
}
}
......@@ -135,7 +143,13 @@ impl ActiveSequences {
&mut self,
request_id: RequestId,
token_sequence: TokenBlockSequence,
overlap: u32,
) -> usize {
let prefill_tokens = self.new_tokens(&token_sequence, overlap);
self.prefill_tokens
.insert(request_id.clone(), prefill_tokens);
self.active_tokens += prefill_tokens;
let blocks = create_unique_blocks_from_sequence(&token_sequence, None, self.block_size);
for block in &blocks {
......@@ -147,6 +161,25 @@ impl ActiveSequences {
self.active_blocks
}
/// Number of tokens in `token_sequence` that are not covered by `overlap`
/// already-cached blocks, i.e. `total_tokens - overlap * block_size`.
///
/// # Panics
///
/// Panics if `overlap * block_size` exceeds the total token count of the
/// sequence, since that would make the prefill token count negative.
pub fn new_tokens(&self, token_sequence: &TokenBlockSequence, overlap: u32) -> usize {
    let input_tokens = token_sequence.total_tokens();
    let cached_tokens = (overlap as usize) * self.block_size;
    match input_tokens.checked_sub(cached_tokens) {
        Some(prefill_tokens) => prefill_tokens,
        None => panic!("prefill_tokens < 0 with overlap {overlap} and ISL {input_tokens}"),
    }
}
/// Project the totals this tracker would hold if `token_sequence` were added
/// with the given `overlap`.
///
/// Returns `(blocks, tokens)`: the current active blocks plus the new blocks
/// the sequence would add, and the current active tokens plus the tokens not
/// covered by `overlap` blocks.
pub fn potential_blocks_and_tokens(
    &self,
    token_sequence: &TokenBlockSequence,
    overlap: u32,
) -> (usize, usize) {
    // Tuple fields are evaluated left to right: blocks first, then tokens.
    (
        self.new_blocks(token_sequence) + self.active_blocks,
        self.new_tokens(token_sequence, overlap) + self.active_tokens,
    )
}
/// Match a request against existing blocks and return the number of new blocks that would be added
pub fn new_blocks(&self, token_sequence: &TokenBlockSequence) -> usize {
let blocks = create_unique_blocks_from_sequence(token_sequence, None, self.block_size);
......@@ -165,6 +198,12 @@ impl ActiveSequences {
/// Free all blocks associated with a request
pub fn free(&mut self, request_id: &RequestId) -> usize {
// decoding has one active token
self.active_tokens = self
.active_tokens
.checked_sub(self.prefill_tokens.remove(request_id).unwrap_or(1))
.expect("active_tokens < 0");
let Some(token_seq) = self.active_seqs.get(request_id) else {
tracing::warn!("Trying to free free non-existent request {request_id}");
return 0;
......@@ -187,6 +226,16 @@ impl ActiveSequences {
/// Push tokens to a specific request's sequence
pub fn push(&mut self, request_id: &RequestId, tokens: &[u32]) -> usize {
if let Some(prefill_tokens) = self.prefill_tokens.get(request_id).cloned() {
self.prefill_tokens.remove(request_id);
// decoding has one active token
self.active_tokens = self
.active_tokens
.checked_sub(prefill_tokens)
.expect("active_tokens < 0")
+ 1;
};
// Collect operations to perform after releasing the borrow
let mut blocks_to_remove = Vec::new();
let mut blocks_to_add = Vec::new();
......@@ -239,6 +288,7 @@ enum UpdateSequences {
AddRequest {
request_id: RequestId,
token_sequence: TokenBlockSequence,
overlap: u32,
},
Free {
request_id: RequestId,
......@@ -255,6 +305,11 @@ enum UpdateSequences {
token_sequence: Arc<TokenBlockSequence>,
resp_tx: mpsc::SyncSender<usize>,
},
PotentialBlocksAndTokens {
token_sequence: Arc<TokenBlockSequence>,
overlap: u32,
resp_tx: mpsc::SyncSender<(usize, usize)>,
},
ActiveBlocks {
resp_tx: mpsc::SyncSender<usize>,
},
......@@ -302,8 +357,9 @@ impl ActiveSequencesMultiWorker {
UpdateSequences::AddRequest {
request_id,
token_sequence,
overlap,
} => {
active_sequences.add_request(request_id, token_sequence);
active_sequences.add_request(request_id, token_sequence, overlap);
}
UpdateSequences::Free { request_id } => {
active_sequences.free(&request_id);
......@@ -325,6 +381,15 @@ impl ActiveSequencesMultiWorker {
let potential_blocks = active_sequences.potential_blocks(&token_sequence);
let _ = resp_tx.send(potential_blocks);
}
UpdateSequences::PotentialBlocksAndTokens {
token_sequence,
overlap,
resp_tx,
} => {
let potential_tokens =
active_sequences.potential_blocks_and_tokens(&token_sequence, overlap);
let _ = resp_tx.send(potential_tokens);
}
UpdateSequences::ActiveBlocks { resp_tx } => {
let active_blocks = active_sequences.active_blocks();
let _ = resp_tx.send(active_blocks);
......@@ -379,6 +444,7 @@ impl ActiveSequencesMultiWorker {
&mut self,
request_id: RequestId,
token_sequence: TokenBlockSequence,
overlap: u32,
worker_id: WorkerId,
) {
if !self.senders.contains_key(&worker_id) {
......@@ -391,6 +457,7 @@ impl ActiveSequencesMultiWorker {
.send(UpdateSequences::AddRequest {
request_id,
token_sequence,
overlap,
})
.expect("Failed to send add_request command to worker");
}
......@@ -482,6 +549,43 @@ impl ActiveSequencesMultiWorker {
})
}
/// Ask every worker for the blocks and tokens it would hold if
/// `token_sequence` were routed to it, using each worker's entry in
/// `overlaps` (0 when the worker has no entry).
///
/// Returns two maps keyed by worker id: `(potential_blocks, potential_tokens)`.
///
/// # Panics
///
/// Panics if a query cannot be sent to a worker, or if a worker does not
/// answer within one second.
pub fn potential_blocks_and_tokens(
    &self,
    token_sequence: TokenBlockSequence,
    overlaps: OverlapScores,
) -> (HashMap<WorkerId, usize>, HashMap<WorkerId, usize>) {
    let shared_sequence = Arc::new(token_sequence);

    // Fan out: every query is sent before any reply is awaited, so the
    // workers compute their answers concurrently.
    let pending: Vec<_> = self
        .senders
        .iter()
        .map(|(worker_id, sender)| {
            let (resp_tx, resp_rx) = mpsc::sync_channel(0);
            let overlap = overlaps.scores.get(worker_id).copied().unwrap_or(0);
            sender
                .send(UpdateSequences::PotentialBlocksAndTokens {
                    token_sequence: shared_sequence.clone(),
                    overlap,
                    resp_tx,
                })
                .expect("Failed to send potential_tokens command to worker");
            (worker_id, resp_rx)
        })
        .collect();

    // Fan in: gather each worker's (blocks, tokens) reply.
    let mut potential_blocks = HashMap::with_capacity(pending.len());
    let mut potential_tokens = HashMap::with_capacity(pending.len());
    for (worker_id, resp_rx) in pending {
        let (blocks, tokens) = resp_rx
            .recv_timeout(Duration::from_secs(1))
            .expect("Failed to receive response from worker");
        potential_blocks.insert(*worker_id, blocks);
        potential_tokens.insert(*worker_id, tokens);
    }

    (potential_blocks, potential_tokens)
}
/// Query all workers for their current number of active blocks
pub fn active_blocks(&self) -> HashMap<WorkerId, usize> {
self.query_workers(None, |_, resp_tx| UpdateSequences::ActiveBlocks { resp_tx })
......@@ -515,14 +619,15 @@ mod tests {
|tokens: Vec<u32>| Tokens::from(tokens).into_sequence(block_size as u32, None);
// Step 1: Add request 0 with tokens [0, 1, 2], then push 3 and 4
manager.add_request("0".to_string(), to_sequence(vec![0, 1, 2]));
manager.add_request("0".to_string(), to_sequence(vec![0, 1, 2]), 0);
manager.push(&"0".to_string(), &[3, 4]); // Push both tokens at once
assert_eq!(manager.active_tokens(), 1);
assert_eq!(manager.active_blocks(), 2);
assert_eq!(manager.partial_blocks.len(), 1);
// Step 2: Add request 1 with tokens [0, 1, 2, 3, 4, 5, 6]
manager.add_request("1".to_string(), to_sequence(vec![0, 1, 2, 3, 4, 5, 6]));
manager.add_request("1".to_string(), to_sequence(vec![0, 1, 2, 3, 4, 5, 6]), 1);
assert_eq!(manager.active_tokens(), 1 + 3);
assert_eq!(manager.active_blocks(), 3);
// Check that only one key is FullBlock with both requests sharing it
......@@ -551,6 +656,7 @@ mod tests {
// Step 4: Free request 0
manager.free(&"0".to_string());
assert_eq!(manager.active_tokens(), 0);
assert_eq!(manager.active_blocks(), 0);
assert_eq!(manager.unique_blocks.len(), 0);
assert_eq!(manager.partial_blocks.len(), 0);
......@@ -566,14 +672,14 @@ mod tests {
|tokens: Vec<u32>| Tokens::from(tokens).into_sequence(block_size as u32, None);
// Send request [0, 1, 2, 3] to worker 0
manager.add_request("req0".to_string(), to_sequence(vec![0, 1, 2, 3]), 0);
manager.add_request("req0".to_string(), to_sequence(vec![0, 1, 2, 3]), 0, 0);
// Send request [0, 1, 2] to worker 1, then push 3 and 4
manager.add_request("req1".to_string(), to_sequence(vec![0, 1, 2]), 1);
manager.add_request("req1".to_string(), to_sequence(vec![0, 1, 2]), 0, 1);
manager.push(&"req1".to_string(), &[3, 4]); // Push both tokens at once
// Send request [0, 1, 2] to worker 2
manager.add_request("req2".to_string(), to_sequence(vec![0, 1, 2]), 2);
manager.add_request("req2".to_string(), to_sequence(vec![0, 1, 2]), 0, 2);
// Check new_blocks on tokens [0, 1, 2, 3, 4]
let new_blocks_map = manager.new_blocks(to_sequence(vec![0, 1, 2, 3, 4]));
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment