Unverified commit cd6984b9, authored by Karen Chung, committed by GitHub
Browse files

feat: use RNG when dp routing targets are tied; override no-assume-kv-reuse...


feat: use RNG when dp routing targets are tied; override no-assume-kv-reuse for decode requests (#6253)
Signed-off-by: Karen Chung <karenc@nvidia.com>
parent dd6c3995
......@@ -921,6 +921,7 @@ pub unsafe extern "C" fn dynamo_router_add_request(
None,
worker,
None, // lora_name not exposed in C API yet
None, // router_config_override not exposed in C API yet
)
.await;
......
......@@ -1051,7 +1051,7 @@ impl KvPushRouter {
pyo3_async_runtimes::tokio::future_into_py(py, async move {
let loads = chooser
.get_potential_loads(&token_ids)
.get_potential_loads(&token_ids, None)
.await
.map_err(to_pyerr)?;
......
......@@ -387,9 +387,11 @@ impl KvRouter {
let find_matches_elapsed = start.elapsed();
// Compute seq_hashes only if scheduler needs it for active blocks tracking
let maybe_seq_hashes = self
.kv_router_config
.compute_seq_hashes_for_tracking(tokens, self.block_size);
let maybe_seq_hashes = self.kv_router_config.compute_seq_hashes_for_tracking(
tokens,
self.block_size,
router_config_override,
);
let seq_hash_elapsed = start.elapsed();
let best_worker = self
......@@ -444,12 +446,15 @@ impl KvRouter {
expected_output_tokens: Option<u32>,
worker: WorkerWithDpRank,
lora_name: Option<String>,
router_config_override: Option<&RouterConfigOverride>,
) {
let isl_tokens = tokens.len();
let maybe_seq_hashes = self
.kv_router_config
.compute_seq_hashes_for_tracking(tokens, self.block_size);
let maybe_seq_hashes = self.kv_router_config.compute_seq_hashes_for_tracking(
tokens,
self.block_size,
router_config_override,
);
if let Err(e) = self
.scheduler
......@@ -509,14 +514,20 @@ impl KvRouter {
}
/// Get potential prefill and decode loads for all workers
pub async fn get_potential_loads(&self, tokens: &[u32]) -> Result<Vec<PotentialLoad>> {
pub async fn get_potential_loads(
&self,
tokens: &[u32],
router_config_override: Option<&RouterConfigOverride>,
) -> Result<Vec<PotentialLoad>> {
let isl_tokens = tokens.len();
let block_hashes = compute_block_hash_for_seq(tokens, self.block_size, None);
let overlap_scores = self.indexer.find_matches(block_hashes.clone()).await?;
let maybe_seq_hashes = self
.kv_router_config
.compute_seq_hashes_for_tracking(tokens, self.block_size);
let maybe_seq_hashes = self.kv_router_config.compute_seq_hashes_for_tracking(
tokens,
self.block_size,
router_config_override,
);
Ok(self
.scheduler
......
......@@ -17,6 +17,9 @@ pub struct RouterConfigOverride {
#[builder(default)]
#[validate(range(min = 0.0))]
pub router_temperature: Option<f64>,
#[builder(default)]
pub assume_kv_reuse: Option<bool>,
}
/// KV Router configuration parameters
......@@ -129,6 +132,7 @@ impl KvRouterConfig {
&self,
tokens: &[u32],
block_size: u32,
config_override: Option<&RouterConfigOverride>,
) -> Option<Vec<u64>> {
if !self.router_track_active_blocks {
return None;
......@@ -139,7 +143,12 @@ impl KvRouterConfig {
return Some(Vec::new());
}
if self.router_assume_kv_reuse {
// Use override if provided, otherwise use default config
let assume_kv_reuse = config_override
.and_then(|cfg| cfg.assume_kv_reuse)
.unwrap_or(self.router_assume_kv_reuse);
if assume_kv_reuse {
// Compute actual block hashes and sequence hashes
let block_hashes = compute_block_hash_for_seq(tokens, block_size, None);
Some(compute_seq_hash_for_block(&block_hashes))
......
......@@ -627,10 +627,14 @@ impl
decode_req.bootstrap_info = Some(info);
}
// Set router_config_override for decode: overlap_score_weight = 0
// Set router_config_override for decode:
// - overlap_score_weight = 0 (no KV cache overlap scoring for decode)
// - assume_kv_reuse = false (generate random hashes since decode workers
// may already have blocks cached from prefill transfer)
let existing_override = decode_req.router_config_override.take();
decode_req.router_config_override = Some(RouterConfigOverride {
overlap_score_weight: Some(0.0),
assume_kv_reuse: Some(false),
..existing_override.unwrap_or_default()
});
......
......@@ -113,6 +113,7 @@ impl KvPushRouter {
expected_output_tokens,
worker,
lora_name,
request.router_config_override.as_ref(),
)
.await;
} else {
......
......@@ -540,20 +540,20 @@ impl WorkerSelector for DefaultWorkerSelector {
let candidates = softmax_sample(&worker_logits, temperature);
// If multiple candidates (tied), use tree size as tie-breaker
// If tree sizes are also equal, min_by_key uses HashMap iteration order (pseudo-random)
// If tree sizes are also equal, use random selection to avoid bias
let best_worker = if candidates.len() > 1 {
tracing::info!("Multiple workers tied with same logit, using tree size as tie-breaker");
*candidates
let tree_sizes: Vec<(usize, &WorkerWithDpRank)> = candidates
.iter()
.min_by_key(|worker| {
request
.overlaps
.tree_sizes
.get(worker)
.copied()
.unwrap_or(0)
})
.expect("candidates should not be empty")
.map(|w| (request.overlaps.tree_sizes.get(w).copied().unwrap_or(0), w))
.collect();
if tree_sizes.iter().all(|(s, _)| *s == tree_sizes[0].0) {
let idx = rand::rng().random_range(0..candidates.len());
candidates[idx]
} else {
*tree_sizes.iter().min_by_key(|(s, _)| *s).unwrap().1
}
} else {
candidates[0]
};
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment