"...git@developer.sourcefind.cn:2222/OpenDAS/vllm_cscc.git" did not exist on "5ea71ff46fe503df12f18ad41d40f5c2b18dcfcd"
Unverified Commit cd6984b9 authored by Karen Chung, committed by GitHub
Browse files

feat: use RNG when dp routing targets are tied; override no-assume-kv-reuse...


feat: use RNG when dp routing targets are tied; override no-assume-kv-reuse for decode requests (#6253)
Signed-off-by: Karen Chung <karenc@nvidia.com>
parent dd6c3995
...@@ -921,6 +921,7 @@ pub unsafe extern "C" fn dynamo_router_add_request( ...@@ -921,6 +921,7 @@ pub unsafe extern "C" fn dynamo_router_add_request(
None, None,
worker, worker,
None, // lora_name not exposed in C API yet None, // lora_name not exposed in C API yet
None, // router_config_override not exposed in C API yet
) )
.await; .await;
......
...@@ -1051,7 +1051,7 @@ impl KvPushRouter { ...@@ -1051,7 +1051,7 @@ impl KvPushRouter {
pyo3_async_runtimes::tokio::future_into_py(py, async move { pyo3_async_runtimes::tokio::future_into_py(py, async move {
let loads = chooser let loads = chooser
.get_potential_loads(&token_ids) .get_potential_loads(&token_ids, None)
.await .await
.map_err(to_pyerr)?; .map_err(to_pyerr)?;
......
...@@ -387,9 +387,11 @@ impl KvRouter { ...@@ -387,9 +387,11 @@ impl KvRouter {
let find_matches_elapsed = start.elapsed(); let find_matches_elapsed = start.elapsed();
// Compute seq_hashes only if scheduler needs it for active blocks tracking // Compute seq_hashes only if scheduler needs it for active blocks tracking
let maybe_seq_hashes = self let maybe_seq_hashes = self.kv_router_config.compute_seq_hashes_for_tracking(
.kv_router_config tokens,
.compute_seq_hashes_for_tracking(tokens, self.block_size); self.block_size,
router_config_override,
);
let seq_hash_elapsed = start.elapsed(); let seq_hash_elapsed = start.elapsed();
let best_worker = self let best_worker = self
...@@ -444,12 +446,15 @@ impl KvRouter { ...@@ -444,12 +446,15 @@ impl KvRouter {
expected_output_tokens: Option<u32>, expected_output_tokens: Option<u32>,
worker: WorkerWithDpRank, worker: WorkerWithDpRank,
lora_name: Option<String>, lora_name: Option<String>,
router_config_override: Option<&RouterConfigOverride>,
) { ) {
let isl_tokens = tokens.len(); let isl_tokens = tokens.len();
let maybe_seq_hashes = self let maybe_seq_hashes = self.kv_router_config.compute_seq_hashes_for_tracking(
.kv_router_config tokens,
.compute_seq_hashes_for_tracking(tokens, self.block_size); self.block_size,
router_config_override,
);
if let Err(e) = self if let Err(e) = self
.scheduler .scheduler
...@@ -509,14 +514,20 @@ impl KvRouter { ...@@ -509,14 +514,20 @@ impl KvRouter {
} }
/// Get potential prefill and decode loads for all workers /// Get potential prefill and decode loads for all workers
pub async fn get_potential_loads(&self, tokens: &[u32]) -> Result<Vec<PotentialLoad>> { pub async fn get_potential_loads(
&self,
tokens: &[u32],
router_config_override: Option<&RouterConfigOverride>,
) -> Result<Vec<PotentialLoad>> {
let isl_tokens = tokens.len(); let isl_tokens = tokens.len();
let block_hashes = compute_block_hash_for_seq(tokens, self.block_size, None); let block_hashes = compute_block_hash_for_seq(tokens, self.block_size, None);
let overlap_scores = self.indexer.find_matches(block_hashes.clone()).await?; let overlap_scores = self.indexer.find_matches(block_hashes.clone()).await?;
let maybe_seq_hashes = self let maybe_seq_hashes = self.kv_router_config.compute_seq_hashes_for_tracking(
.kv_router_config tokens,
.compute_seq_hashes_for_tracking(tokens, self.block_size); self.block_size,
router_config_override,
);
Ok(self Ok(self
.scheduler .scheduler
......
...@@ -17,6 +17,9 @@ pub struct RouterConfigOverride { ...@@ -17,6 +17,9 @@ pub struct RouterConfigOverride {
#[builder(default)] #[builder(default)]
#[validate(range(min = 0.0))] #[validate(range(min = 0.0))]
pub router_temperature: Option<f64>, pub router_temperature: Option<f64>,
#[builder(default)]
pub assume_kv_reuse: Option<bool>,
} }
/// KV Router configuration parameters /// KV Router configuration parameters
...@@ -129,6 +132,7 @@ impl KvRouterConfig { ...@@ -129,6 +132,7 @@ impl KvRouterConfig {
&self, &self,
tokens: &[u32], tokens: &[u32],
block_size: u32, block_size: u32,
config_override: Option<&RouterConfigOverride>,
) -> Option<Vec<u64>> { ) -> Option<Vec<u64>> {
if !self.router_track_active_blocks { if !self.router_track_active_blocks {
return None; return None;
...@@ -139,7 +143,12 @@ impl KvRouterConfig { ...@@ -139,7 +143,12 @@ impl KvRouterConfig {
return Some(Vec::new()); return Some(Vec::new());
} }
if self.router_assume_kv_reuse { // Use override if provided, otherwise use default config
let assume_kv_reuse = config_override
.and_then(|cfg| cfg.assume_kv_reuse)
.unwrap_or(self.router_assume_kv_reuse);
if assume_kv_reuse {
// Compute actual block hashes and sequence hashes // Compute actual block hashes and sequence hashes
let block_hashes = compute_block_hash_for_seq(tokens, block_size, None); let block_hashes = compute_block_hash_for_seq(tokens, block_size, None);
Some(compute_seq_hash_for_block(&block_hashes)) Some(compute_seq_hash_for_block(&block_hashes))
......
...@@ -627,10 +627,14 @@ impl ...@@ -627,10 +627,14 @@ impl
decode_req.bootstrap_info = Some(info); decode_req.bootstrap_info = Some(info);
} }
// Set router_config_override for decode: overlap_score_weight = 0 // Set router_config_override for decode:
// - overlap_score_weight = 0 (no KV cache overlap scoring for decode)
// - assume_kv_reuse = false (generate random hashes since decode workers
// may already have blocks cached from prefill transfer)
let existing_override = decode_req.router_config_override.take(); let existing_override = decode_req.router_config_override.take();
decode_req.router_config_override = Some(RouterConfigOverride { decode_req.router_config_override = Some(RouterConfigOverride {
overlap_score_weight: Some(0.0), overlap_score_weight: Some(0.0),
assume_kv_reuse: Some(false),
..existing_override.unwrap_or_default() ..existing_override.unwrap_or_default()
}); });
......
...@@ -113,6 +113,7 @@ impl KvPushRouter { ...@@ -113,6 +113,7 @@ impl KvPushRouter {
expected_output_tokens, expected_output_tokens,
worker, worker,
lora_name, lora_name,
request.router_config_override.as_ref(),
) )
.await; .await;
} else { } else {
......
...@@ -540,20 +540,20 @@ impl WorkerSelector for DefaultWorkerSelector { ...@@ -540,20 +540,20 @@ impl WorkerSelector for DefaultWorkerSelector {
let candidates = softmax_sample(&worker_logits, temperature); let candidates = softmax_sample(&worker_logits, temperature);
// If multiple candidates (tied), use tree size as tie-breaker // If multiple candidates (tied), use tree size as tie-breaker
// If tree sizes are also equal, min_by_key uses HashMap iteration order (pseudo-random) // If tree sizes are also equal, use random selection to avoid bias
let best_worker = if candidates.len() > 1 { let best_worker = if candidates.len() > 1 {
tracing::info!("Multiple workers tied with same logit, using tree size as tie-breaker"); tracing::info!("Multiple workers tied with same logit, using tree size as tie-breaker");
*candidates let tree_sizes: Vec<(usize, &WorkerWithDpRank)> = candidates
.iter() .iter()
.min_by_key(|worker| { .map(|w| (request.overlaps.tree_sizes.get(w).copied().unwrap_or(0), w))
request .collect();
.overlaps
.tree_sizes if tree_sizes.iter().all(|(s, _)| *s == tree_sizes[0].0) {
.get(worker) let idx = rand::rng().random_range(0..candidates.len());
.copied() candidates[idx]
.unwrap_or(0) } else {
}) *tree_sizes.iter().min_by_key(|(s, _)| *s).unwrap().1
.expect("candidates should not be empty") }
} else { } else {
candidates[0] candidates[0]
}; };
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment