Unverified Commit 7d604dd3 authored by Yan Ru Pei's avatar Yan Ru Pei Committed by GitHub
Browse files

feat: tie break on tree size when routing (#4257)


Signed-off-by: default avatarPeaBrane <yanrpei@gmail.com>
parent f8219b12
......@@ -182,6 +182,7 @@ impl Indexer {
Indexer::None => Ok(OverlapScores {
scores: HashMap::new(),
frequencies: Vec::new(),
tree_sizes: HashMap::new(),
}),
}
}
......
......@@ -326,6 +326,16 @@ impl RadixTree {
tracing::trace!("RadixTree::find_matches: final scores={:?}", scores.scores);
// Populate tree sizes for all workers that have scores
for worker in scores.scores.keys() {
let tree_size = self
.lookup
.get(worker)
.expect("worker in scores must exist in lookup table")
.len();
scores.tree_sizes.insert(*worker, tree_size);
}
scores
}
......@@ -680,6 +690,8 @@ pub struct OverlapScores {
pub scores: HashMap<WorkerWithDpRank, u32>,
// List of frequencies that the blocks have been accessed. Entries with value 0 are omitted.
pub frequencies: Vec<usize>,
// Map of worker to their tree size (number of blocks in the tree for that worker)
pub tree_sizes: HashMap<WorkerWithDpRank, usize>,
}
impl Default for OverlapScores {
......@@ -698,6 +710,7 @@ impl OverlapScores {
Self {
scores: HashMap::new(),
frequencies: Vec::with_capacity(32),
tree_sizes: HashMap::new(),
}
}
......@@ -1225,6 +1238,7 @@ impl KvIndexerInterface for KvIndexerSharded {
match match_rx.recv().await {
Some(response) => {
scores.scores.extend(response.scores);
scores.tree_sizes.extend(response.tree_sizes);
if response_num == 0 {
scores.frequencies = response.frequencies;
......
......@@ -386,12 +386,16 @@ impl KvScheduler {
}
// Helper function for softmax sampling
fn softmax_sample(logits: &HashMap<WorkerWithDpRank, f64>, temperature: f64) -> WorkerWithDpRank {
// Returns a vec of workers: multiple if tied, single if sampled
fn softmax_sample(
logits: &HashMap<WorkerWithDpRank, f64>,
temperature: f64,
) -> Vec<WorkerWithDpRank> {
if logits.is_empty() {
panic!("Empty logits for softmax sampling");
}
// Guard: if temperature is 0, return the key with the smallest logit value
// Guard: if temperature is 0, return all keys with the smallest logit value (ties)
if temperature == 0.0 {
// Find the minimum logit value
let min_logit = logits.values().fold(f64::INFINITY, |a, &b| a.min(b));
......@@ -403,10 +407,7 @@ fn softmax_sample(logits: &HashMap<WorkerWithDpRank, f64>, temperature: f64) ->
.map(|(k, _)| *k)
.collect();
// Randomly select from the minimum keys (handles single key case naturally)
let mut rng = rand::rng();
let index = rng.random_range(0..min_keys.len());
return min_keys[index];
return min_keys;
}
let keys: Vec<_> = logits.keys().copied().collect();
......@@ -449,12 +450,12 @@ fn softmax_sample(logits: &HashMap<WorkerWithDpRank, f64>, temperature: f64) ->
for (i, &prob) in probabilities.iter().enumerate() {
cumsum += prob;
if sample <= cumsum {
return keys[i];
return vec![keys[i]];
}
}
// Fallback to last key (shouldn't normally reach here)
keys[keys.len() - 1]
vec![keys[keys.len() - 1]]
}
// Default implementation matching the Python _cost_function
......@@ -542,14 +543,34 @@ impl WorkerSelector for DefaultWorkerSelector {
}
}
// Use softmax sampling to select worker
// Use softmax sampling to select worker(s)
// Use override if provided, otherwise use default config
let temperature = request
.router_config_override
.as_ref()
.and_then(|cfg| cfg.router_temperature)
.unwrap_or(self.kv_router_config.router_temperature);
let best_worker = softmax_sample(&worker_logits, temperature);
let candidates = softmax_sample(&worker_logits, temperature);
// If multiple candidates (tied), use tree size as tie-breaker
// If tree sizes are also equal, min_by_key uses HashMap iteration order (pseudo-random)
let best_worker = if candidates.len() > 1 {
tracing::info!("Multiple workers tied with same logit, using tree size as tie-breaker");
*candidates
.iter()
.min_by_key(|worker| {
request
.overlaps
.tree_sizes
.get(worker)
.copied()
.unwrap_or(0)
})
.expect("candidates should not be empty")
} else {
candidates[0]
};
let best_logit = worker_logits[&best_worker];
let best_overlap = *overlaps.get(&best_worker).unwrap_or(&0);
......@@ -562,12 +583,20 @@ impl WorkerSelector for DefaultWorkerSelector {
.map(|blocks| format!(", total blocks: {}", blocks))
.unwrap_or_default();
let tree_size = request
.overlaps
.tree_sizes
.get(&best_worker)
.copied()
.unwrap_or(0);
tracing::info!(
"Selected worker: worker_id={} dp_rank={:?}, logit: {:.3}, cached blocks: {}{}",
"Selected worker: worker_id={} dp_rank={:?}, logit: {:.3}, cached blocks: {}, tree size: {}{}",
best_worker.worker_id,
best_worker.dp_rank,
best_logit,
best_overlap,
tree_size,
total_blocks_info
);
......@@ -593,26 +622,33 @@ mod tests {
// Test with different temperatures
for temperature in &[0.1, 1.0, 10.0] {
let result = softmax_sample(&logits, *temperature);
assert_eq!(result, worker, "Should return the only available worker");
assert_eq!(result.len(), 1, "Should return exactly one worker");
assert_eq!(result[0], worker, "Should return the only available worker");
}
// Test with different logit values
logits.clear();
logits.insert(worker, -100.0); // Very negative value
assert_eq!(softmax_sample(&logits, 1.0), worker);
let result = softmax_sample(&logits, 1.0);
assert_eq!(result.len(), 1);
assert_eq!(result[0], worker);
logits.clear();
logits.insert(worker, 100.0); // Very positive value
assert_eq!(softmax_sample(&logits, 1.0), worker);
let result = softmax_sample(&logits, 1.0);
assert_eq!(result.len(), 1);
assert_eq!(result[0], worker);
logits.clear();
logits.insert(worker, 0.0); // Zero value
assert_eq!(softmax_sample(&logits, 1.0), worker);
let result = softmax_sample(&logits, 1.0);
assert_eq!(result.len(), 1);
assert_eq!(result[0], worker);
}
#[test]
fn test_softmax_sample_zero_temperature() {
// Test that with temperature 0, softmax_sample returns the key with smallest logit
// Test that with temperature 0, softmax_sample returns all keys with smallest logit
let mut logits = HashMap::new();
let worker1 = WorkerWithDpRank::from_worker_id(1);
let worker2 = WorkerWithDpRank::from_worker_id(2);
......@@ -623,14 +659,37 @@ mod tests {
logits.insert(worker3, 7.0);
logits.insert(worker4, 3.5);
// With temperature 0, should always return worker 2 (smallest logit)
for _ in 0..10 {
// With temperature 0, should always return only worker2 (smallest logit)
let result = softmax_sample(&logits, 0.0);
assert_eq!(
result, worker2,
result.len(),
1,
"Should return one worker when there's no tie"
);
assert_eq!(
result[0], worker2,
"Should return worker with smallest logit when temperature is 0"
);
}
// Test with tied minimum logits
logits.clear();
let worker5 = WorkerWithDpRank::from_worker_id(5);
let worker6 = WorkerWithDpRank::from_worker_id(6);
logits.insert(worker1, 5.0);
logits.insert(worker2, 3.0); // Tied for smallest
logits.insert(worker5, 3.0); // Tied for smallest
logits.insert(worker6, 7.0);
let result = softmax_sample(&logits, 0.0);
assert_eq!(
result.len(),
2,
"Should return all workers with smallest logit when tied"
);
assert!(
result.contains(&worker2) && result.contains(&worker5),
"Should contain both tied workers"
);
// Test with negative values
logits.clear();
......@@ -642,6 +701,10 @@ mod tests {
logits.insert(worker30, 0.0);
let result = softmax_sample(&logits, 0.0);
assert_eq!(result, worker20, "Should handle negative logits correctly");
assert_eq!(result.len(), 1);
assert_eq!(
result[0], worker20,
"Should handle negative logits correctly"
);
}
}
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment