Unverified commit 7d604dd3, authored by Yan Ru Pei, committed by GitHub
Browse files

feat: tie break on tree size when routing (#4257)


Signed-off-by: PeaBrane <yanrpei@gmail.com>
parent f8219b12
...@@ -182,6 +182,7 @@ impl Indexer { ...@@ -182,6 +182,7 @@ impl Indexer {
Indexer::None => Ok(OverlapScores { Indexer::None => Ok(OverlapScores {
scores: HashMap::new(), scores: HashMap::new(),
frequencies: Vec::new(), frequencies: Vec::new(),
tree_sizes: HashMap::new(),
}), }),
} }
} }
......
...@@ -326,6 +326,16 @@ impl RadixTree { ...@@ -326,6 +326,16 @@ impl RadixTree {
tracing::trace!("RadixTree::find_matches: final scores={:?}", scores.scores); tracing::trace!("RadixTree::find_matches: final scores={:?}", scores.scores);
// Populate tree sizes for all workers that have scores
for worker in scores.scores.keys() {
let tree_size = self
.lookup
.get(worker)
.expect("worker in scores must exist in lookup table")
.len();
scores.tree_sizes.insert(*worker, tree_size);
}
scores scores
} }
...@@ -680,6 +690,8 @@ pub struct OverlapScores { ...@@ -680,6 +690,8 @@ pub struct OverlapScores {
pub scores: HashMap<WorkerWithDpRank, u32>, pub scores: HashMap<WorkerWithDpRank, u32>,
// List of frequencies that the blocks have been accessed. Entries with value 0 are omitted. // List of frequencies that the blocks have been accessed. Entries with value 0 are omitted.
pub frequencies: Vec<usize>, pub frequencies: Vec<usize>,
// Map of worker to their tree size (number of blocks in the tree for that worker)
pub tree_sizes: HashMap<WorkerWithDpRank, usize>,
} }
impl Default for OverlapScores { impl Default for OverlapScores {
...@@ -698,6 +710,7 @@ impl OverlapScores { ...@@ -698,6 +710,7 @@ impl OverlapScores {
Self { Self {
scores: HashMap::new(), scores: HashMap::new(),
frequencies: Vec::with_capacity(32), frequencies: Vec::with_capacity(32),
tree_sizes: HashMap::new(),
} }
} }
...@@ -1225,6 +1238,7 @@ impl KvIndexerInterface for KvIndexerSharded { ...@@ -1225,6 +1238,7 @@ impl KvIndexerInterface for KvIndexerSharded {
match match_rx.recv().await { match match_rx.recv().await {
Some(response) => { Some(response) => {
scores.scores.extend(response.scores); scores.scores.extend(response.scores);
scores.tree_sizes.extend(response.tree_sizes);
if response_num == 0 { if response_num == 0 {
scores.frequencies = response.frequencies; scores.frequencies = response.frequencies;
......
...@@ -386,12 +386,16 @@ impl KvScheduler { ...@@ -386,12 +386,16 @@ impl KvScheduler {
} }
// Helper function for softmax sampling // Helper function for softmax sampling
fn softmax_sample(logits: &HashMap<WorkerWithDpRank, f64>, temperature: f64) -> WorkerWithDpRank { // Returns a vec of workers: multiple if tied, single if sampled
fn softmax_sample(
logits: &HashMap<WorkerWithDpRank, f64>,
temperature: f64,
) -> Vec<WorkerWithDpRank> {
if logits.is_empty() { if logits.is_empty() {
panic!("Empty logits for softmax sampling"); panic!("Empty logits for softmax sampling");
} }
// Guard: if temperature is 0, return the key with the smallest logit value // Guard: if temperature is 0, return all keys with the smallest logit value (ties)
if temperature == 0.0 { if temperature == 0.0 {
// Find the minimum logit value // Find the minimum logit value
let min_logit = logits.values().fold(f64::INFINITY, |a, &b| a.min(b)); let min_logit = logits.values().fold(f64::INFINITY, |a, &b| a.min(b));
...@@ -403,10 +407,7 @@ fn softmax_sample(logits: &HashMap<WorkerWithDpRank, f64>, temperature: f64) -> ...@@ -403,10 +407,7 @@ fn softmax_sample(logits: &HashMap<WorkerWithDpRank, f64>, temperature: f64) ->
.map(|(k, _)| *k) .map(|(k, _)| *k)
.collect(); .collect();
// Randomly select from the minimum keys (handles single key case naturally) return min_keys;
let mut rng = rand::rng();
let index = rng.random_range(0..min_keys.len());
return min_keys[index];
} }
let keys: Vec<_> = logits.keys().copied().collect(); let keys: Vec<_> = logits.keys().copied().collect();
...@@ -449,12 +450,12 @@ fn softmax_sample(logits: &HashMap<WorkerWithDpRank, f64>, temperature: f64) -> ...@@ -449,12 +450,12 @@ fn softmax_sample(logits: &HashMap<WorkerWithDpRank, f64>, temperature: f64) ->
for (i, &prob) in probabilities.iter().enumerate() { for (i, &prob) in probabilities.iter().enumerate() {
cumsum += prob; cumsum += prob;
if sample <= cumsum { if sample <= cumsum {
return keys[i]; return vec![keys[i]];
} }
} }
// Fallback to last key (shouldn't normally reach here) // Fallback to last key (shouldn't normally reach here)
keys[keys.len() - 1] vec![keys[keys.len() - 1]]
} }
// Default implementation matching the Python _cost_function // Default implementation matching the Python _cost_function
...@@ -542,14 +543,34 @@ impl WorkerSelector for DefaultWorkerSelector { ...@@ -542,14 +543,34 @@ impl WorkerSelector for DefaultWorkerSelector {
} }
} }
// Use softmax sampling to select worker // Use softmax sampling to select worker(s)
// Use override if provided, otherwise use default config // Use override if provided, otherwise use default config
let temperature = request let temperature = request
.router_config_override .router_config_override
.as_ref() .as_ref()
.and_then(|cfg| cfg.router_temperature) .and_then(|cfg| cfg.router_temperature)
.unwrap_or(self.kv_router_config.router_temperature); .unwrap_or(self.kv_router_config.router_temperature);
let best_worker = softmax_sample(&worker_logits, temperature); let candidates = softmax_sample(&worker_logits, temperature);
// If multiple candidates (tied), use tree size as tie-breaker
// If tree sizes are also equal, min_by_key uses HashMap iteration order (pseudo-random)
let best_worker = if candidates.len() > 1 {
tracing::info!("Multiple workers tied with same logit, using tree size as tie-breaker");
*candidates
.iter()
.min_by_key(|worker| {
request
.overlaps
.tree_sizes
.get(worker)
.copied()
.unwrap_or(0)
})
.expect("candidates should not be empty")
} else {
candidates[0]
};
let best_logit = worker_logits[&best_worker]; let best_logit = worker_logits[&best_worker];
let best_overlap = *overlaps.get(&best_worker).unwrap_or(&0); let best_overlap = *overlaps.get(&best_worker).unwrap_or(&0);
...@@ -562,12 +583,20 @@ impl WorkerSelector for DefaultWorkerSelector { ...@@ -562,12 +583,20 @@ impl WorkerSelector for DefaultWorkerSelector {
.map(|blocks| format!(", total blocks: {}", blocks)) .map(|blocks| format!(", total blocks: {}", blocks))
.unwrap_or_default(); .unwrap_or_default();
let tree_size = request
.overlaps
.tree_sizes
.get(&best_worker)
.copied()
.unwrap_or(0);
tracing::info!( tracing::info!(
"Selected worker: worker_id={} dp_rank={:?}, logit: {:.3}, cached blocks: {}{}", "Selected worker: worker_id={} dp_rank={:?}, logit: {:.3}, cached blocks: {}, tree size: {}{}",
best_worker.worker_id, best_worker.worker_id,
best_worker.dp_rank, best_worker.dp_rank,
best_logit, best_logit,
best_overlap, best_overlap,
tree_size,
total_blocks_info total_blocks_info
); );
...@@ -593,26 +622,33 @@ mod tests { ...@@ -593,26 +622,33 @@ mod tests {
// Test with different temperatures // Test with different temperatures
for temperature in &[0.1, 1.0, 10.0] { for temperature in &[0.1, 1.0, 10.0] {
let result = softmax_sample(&logits, *temperature); let result = softmax_sample(&logits, *temperature);
assert_eq!(result, worker, "Should return the only available worker"); assert_eq!(result.len(), 1, "Should return exactly one worker");
assert_eq!(result[0], worker, "Should return the only available worker");
} }
// Test with different logit values // Test with different logit values
logits.clear(); logits.clear();
logits.insert(worker, -100.0); // Very negative value logits.insert(worker, -100.0); // Very negative value
assert_eq!(softmax_sample(&logits, 1.0), worker); let result = softmax_sample(&logits, 1.0);
assert_eq!(result.len(), 1);
assert_eq!(result[0], worker);
logits.clear(); logits.clear();
logits.insert(worker, 100.0); // Very positive value logits.insert(worker, 100.0); // Very positive value
assert_eq!(softmax_sample(&logits, 1.0), worker); let result = softmax_sample(&logits, 1.0);
assert_eq!(result.len(), 1);
assert_eq!(result[0], worker);
logits.clear(); logits.clear();
logits.insert(worker, 0.0); // Zero value logits.insert(worker, 0.0); // Zero value
assert_eq!(softmax_sample(&logits, 1.0), worker); let result = softmax_sample(&logits, 1.0);
assert_eq!(result.len(), 1);
assert_eq!(result[0], worker);
} }
#[test] #[test]
fn test_softmax_sample_zero_temperature() { fn test_softmax_sample_zero_temperature() {
// Test that with temperature 0, softmax_sample returns the key with smallest logit // Test that with temperature 0, softmax_sample returns all keys with smallest logit
let mut logits = HashMap::new(); let mut logits = HashMap::new();
let worker1 = WorkerWithDpRank::from_worker_id(1); let worker1 = WorkerWithDpRank::from_worker_id(1);
let worker2 = WorkerWithDpRank::from_worker_id(2); let worker2 = WorkerWithDpRank::from_worker_id(2);
...@@ -623,14 +659,37 @@ mod tests { ...@@ -623,14 +659,37 @@ mod tests {
logits.insert(worker3, 7.0); logits.insert(worker3, 7.0);
logits.insert(worker4, 3.5); logits.insert(worker4, 3.5);
// With temperature 0, should always return worker 2 (smallest logit) // With temperature 0, should always return only worker2 (smallest logit)
for _ in 0..10 { let result = softmax_sample(&logits, 0.0);
let result = softmax_sample(&logits, 0.0); assert_eq!(
assert_eq!( result.len(),
result, worker2, 1,
"Should return worker with smallest logit when temperature is 0" "Should return one worker when there's no tie"
); );
} assert_eq!(
result[0], worker2,
"Should return worker with smallest logit when temperature is 0"
);
// Test with tied minimum logits
logits.clear();
let worker5 = WorkerWithDpRank::from_worker_id(5);
let worker6 = WorkerWithDpRank::from_worker_id(6);
logits.insert(worker1, 5.0);
logits.insert(worker2, 3.0); // Tied for smallest
logits.insert(worker5, 3.0); // Tied for smallest
logits.insert(worker6, 7.0);
let result = softmax_sample(&logits, 0.0);
assert_eq!(
result.len(),
2,
"Should return all workers with smallest logit when tied"
);
assert!(
result.contains(&worker2) && result.contains(&worker5),
"Should contain both tied workers"
);
// Test with negative values // Test with negative values
logits.clear(); logits.clear();
...@@ -642,6 +701,10 @@ mod tests { ...@@ -642,6 +701,10 @@ mod tests {
logits.insert(worker30, 0.0); logits.insert(worker30, 0.0);
let result = softmax_sample(&logits, 0.0); let result = softmax_sample(&logits, 0.0);
assert_eq!(result, worker20, "Should handle negative logits correctly"); assert_eq!(result.len(), 1);
assert_eq!(
result[0], worker20,
"Should handle negative logits correctly"
);
} }
} }
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment