"...git@developer.sourcefind.cn:2222/OpenDAS/vllm_cscc.git" did not exist on "98fcba1575da8d80e47d0540898015d2906d4720"
Unverified Commit 1368ccd6 authored by jthomson04's avatar jthomson04 Committed by GitHub
Browse files

perf: Miscellaneous router perf improvements (#7477)


Signed-off-by: default avatarjthomson04 <jwillthomson19@gmail.com>
parent 0adfd98d
...@@ -31,6 +31,29 @@ pub struct BlockHashOptions<'a> { ...@@ -31,6 +31,29 @@ pub struct BlockHashOptions<'a> {
pub is_eagle: Option<bool>, pub is_eagle: Option<bool>,
} }
#[inline]
fn hash_block_no_mm(chunk: &[u32], seed: u64, scratch_bytes: &mut Vec<u8>) -> LocalBlockHash {
#[cfg(target_endian = "little")]
{
let _ = scratch_bytes;
// SAFETY: `u32` is plain-old-data, and on little-endian targets its in-memory
// representation matches the `to_le_bytes()` sequence used for hashing.
let chunk_bytes = unsafe {
std::slice::from_raw_parts(chunk.as_ptr().cast::<u8>(), std::mem::size_of_val(chunk))
};
LocalBlockHash(xxh3::xxh3_64_with_seed(chunk_bytes, seed))
}
#[cfg(not(target_endian = "little"))]
{
scratch_bytes.clear();
for &token in chunk {
scratch_bytes.extend_from_slice(&token.to_le_bytes());
}
LocalBlockHash(xxh3::xxh3_64_with_seed(scratch_bytes, seed))
}
}
/// Compute the hash for a sequence of tokens, optionally including multimodal metadata /// Compute the hash for a sequence of tokens, optionally including multimodal metadata
/// and LoRA adapter identity. /// and LoRA adapter identity.
/// ///
...@@ -56,35 +79,42 @@ pub fn compute_block_hash_for_seq( ...@@ -56,35 +79,42 @@ pub fn compute_block_hash_for_seq(
Some(name) => XXH3_SEED.wrapping_add(xxh3::xxh3_64(name.as_bytes())), Some(name) => XXH3_SEED.wrapping_add(xxh3::xxh3_64(name.as_bytes())),
None => XXH3_SEED, None => XXH3_SEED,
}; };
let is_eagle_flag = options.is_eagle.unwrap_or(false); let is_eagle_flag = options.is_eagle.unwrap_or(false);
let stride = kv_block_size as usize; let stride = kv_block_size as usize;
let window_size = if is_eagle_flag { stride + 1 } else { stride }; let window_size = if is_eagle_flag { stride + 1 } else { stride };
let estimated_blocks = if is_eagle_flag {
let mut hashes = Vec::new(); tokens.len().saturating_sub(1) / stride
} else {
tokens.len() / stride
};
let mut hashes = Vec::with_capacity(estimated_blocks);
let mut bytes = Vec::with_capacity(window_size * std::mem::size_of::<u32>());
let mut mm_hashes = Vec::new();
let mut block_idx = 0; let mut block_idx = 0;
let mut start = 0; let mut start = 0;
while start + window_size <= tokens.len() { while start + window_size <= tokens.len() {
let chunk = &tokens[start..start + window_size]; let chunk = &tokens[start..start + window_size];
let mut bytes: Vec<u8> = chunk.iter().flat_map(|&num| num.to_le_bytes()).collect();
if let Some(mm_infos) = options.block_mm_infos if let Some(mm_infos) = options.block_mm_infos
&& let Some(Some(block_mm_info)) = mm_infos.get(block_idx) && let Some(Some(block_mm_info)) = mm_infos.get(block_idx)
{ {
let mut mm_hashes: Vec<u64> = block_mm_info bytes.clear();
.mm_objects for &token in chunk {
.iter() bytes.extend_from_slice(&token.to_le_bytes());
.map(|obj| obj.mm_hash) }
.collect();
mm_hashes.clear();
mm_hashes.extend(block_mm_info.mm_objects.iter().map(|obj| obj.mm_hash));
mm_hashes.sort_unstable(); mm_hashes.sort_unstable();
for mm_hash in mm_hashes { for &mm_hash in &mm_hashes {
bytes.extend_from_slice(&mm_hash.to_le_bytes()); bytes.extend_from_slice(&mm_hash.to_le_bytes());
} }
}
hashes.push(LocalBlockHash(xxh3::xxh3_64_with_seed(&bytes, seed))); hashes.push(LocalBlockHash(xxh3::xxh3_64_with_seed(&bytes, seed)));
} else {
hashes.push(hash_block_no_mm(chunk, seed, &mut bytes));
}
start += stride; start += stride;
block_idx += 1; block_idx += 1;
...@@ -110,8 +140,25 @@ pub fn compute_seq_hash_for_block(block_hashes: &[LocalBlockHash]) -> Vec<Sequen ...@@ -110,8 +140,25 @@ pub fn compute_seq_hash_for_block(block_hashes: &[LocalBlockHash]) -> Vec<Sequen
let current_block_hash = block_hashes[i].0; let current_block_hash = block_hashes[i].0;
let combined = [parent_seq_hash, current_block_hash]; let combined = [parent_seq_hash, current_block_hash];
let bytes: Vec<u8> = combined.iter().flat_map(|&num| num.to_le_bytes()).collect(); #[cfg(target_endian = "little")]
let seq_hash = compute_hash(&bytes); let seq_hash = {
// SAFETY: `u64` is plain-old-data, and on little-endian targets its in-memory
// representation matches the `to_le_bytes()` sequence used by the previous code.
let bytes = unsafe {
std::slice::from_raw_parts(
combined.as_ptr().cast::<u8>(),
std::mem::size_of_val(&combined),
)
};
compute_hash(bytes)
};
#[cfg(not(target_endian = "little"))]
let seq_hash = {
let mut bytes = [0_u8; std::mem::size_of::<u64>() * 2];
bytes[..8].copy_from_slice(&parent_seq_hash.to_le_bytes());
bytes[8..].copy_from_slice(&current_block_hash.to_le_bytes());
compute_hash(&bytes)
};
sequence_hashes.push(seq_hash); sequence_hashes.push(seq_hash);
} }
......
...@@ -9,7 +9,9 @@ use rand::Rng; ...@@ -9,7 +9,9 @@ use rand::Rng;
use serde::{Deserialize, Serialize}; use serde::{Deserialize, Serialize};
use validator::{Validate, ValidationError}; use validator::{Validate, ValidationError};
use crate::protocols::{BlockHashOptions, compute_block_hash_for_seq, compute_seq_hash_for_block}; use crate::protocols::{
BlockHashOptions, LocalBlockHash, compute_block_hash_for_seq, compute_seq_hash_for_block,
};
const fn default_min_initial_workers() -> usize { const fn default_min_initial_workers() -> usize {
1 1
...@@ -218,6 +220,7 @@ impl KvRouterConfig { ...@@ -218,6 +220,7 @@ impl KvRouterConfig {
block_size: u32, block_size: u32,
config_override: Option<&RouterConfigOverride>, config_override: Option<&RouterConfigOverride>,
hash_options: BlockHashOptions<'_>, hash_options: BlockHashOptions<'_>,
precomputed_block_hashes: Option<&[LocalBlockHash]>,
) -> Option<Vec<u64>> { ) -> Option<Vec<u64>> {
if !self.router_track_active_blocks { if !self.router_track_active_blocks {
return None; return None;
...@@ -233,8 +236,14 @@ impl KvRouterConfig { ...@@ -233,8 +236,14 @@ impl KvRouterConfig {
.unwrap_or(self.router_assume_kv_reuse); .unwrap_or(self.router_assume_kv_reuse);
if assume_kv_reuse { if assume_kv_reuse {
let block_hashes = compute_block_hash_for_seq(tokens, block_size, hash_options); let block_hashes = match precomputed_block_hashes {
Some(compute_seq_hash_for_block(&block_hashes)) Some(block_hashes) => block_hashes,
None => {
let computed = compute_block_hash_for_seq(tokens, block_size, hash_options);
return Some(compute_seq_hash_for_block(&computed));
}
};
Some(compute_seq_hash_for_block(block_hashes))
} else { } else {
let mut rng = rand::rng(); let mut rng = rand::rng();
Some((0..num_blocks).map(|_| rng.random::<u64>()).collect()) Some((0..num_blocks).map(|_| rng.random::<u64>()).collect())
...@@ -305,7 +314,7 @@ mod tests { ...@@ -305,7 +314,7 @@ mod tests {
]; ];
let without_mm = cfg let without_mm = cfg
.compute_seq_hashes_for_tracking(&tokens, 2, None, BlockHashOptions::default()) .compute_seq_hashes_for_tracking(&tokens, 2, None, BlockHashOptions::default(), None)
.unwrap(); .unwrap();
let with_mm = cfg let with_mm = cfg
.compute_seq_hashes_for_tracking( .compute_seq_hashes_for_tracking(
...@@ -316,9 +325,27 @@ mod tests { ...@@ -316,9 +325,27 @@ mod tests {
block_mm_infos: Some(&mm_infos), block_mm_infos: Some(&mm_infos),
..Default::default() ..Default::default()
}, },
None,
) )
.unwrap(); .unwrap();
assert_ne!(without_mm, with_mm); assert_ne!(without_mm, with_mm);
} }
#[test]
fn compute_seq_hashes_for_tracking_uses_precomputed_block_hashes() {
let config = KvRouterConfig::default();
let tokens: Vec<u32> = (0..8).collect();
let precomputed = vec![LocalBlockHash(11), LocalBlockHash(29)];
let seq_hashes = config.compute_seq_hashes_for_tracking(
&tokens,
4,
None,
BlockHashOptions::default(),
Some(&precomputed),
);
assert_eq!(seq_hashes, Some(compute_seq_hash_for_block(&precomputed)));
}
} }
...@@ -22,36 +22,52 @@ pub trait WorkerSelector<C: WorkerConfigLike> { ...@@ -22,36 +22,52 @@ pub trait WorkerSelector<C: WorkerConfigLike> {
} }
/// Helper function for softmax sampling. /// Helper function for softmax sampling.
/// Returns a vec of workers: multiple if tied, single if sampled. /// Returns the selected worker and its logit.
fn softmax_sample( fn softmax_sample(
logits: &HashMap<WorkerWithDpRank, f64>, logits: &HashMap<WorkerWithDpRank, f64>,
temperature: f64, temperature: f64,
) -> Vec<WorkerWithDpRank> { ) -> (WorkerWithDpRank, f64) {
let mut rng = rand::rng();
softmax_sample_with_sample(logits, temperature, rng.random())
}
fn softmax_sample_with_sample(
logits: &HashMap<WorkerWithDpRank, f64>,
temperature: f64,
sample: f64,
) -> (WorkerWithDpRank, f64) {
if logits.is_empty() { if logits.is_empty() {
panic!("Empty logits for softmax sampling"); panic!("Empty logits for softmax sampling");
} }
// Guard: if temperature is 0, return all keys with the smallest logit value (ties) // Guard: at zero temperature, return a minimum-logit worker directly.
if temperature == 0.0 { if temperature == 0.0 {
let min_logit = logits.values().fold(f64::INFINITY, |a, &b| a.min(b)); let mut logit_iter = logits.iter();
let (first_key, first_logit) = logit_iter.next().unwrap();
let min_keys: Vec<_> = logits
.iter() let mut min_logit = first_logit;
.filter(|&(_, &v)| v == min_logit) let mut min_key = first_key;
.map(|(k, _)| *k) for (key, logit) in logit_iter {
.collect(); if logit < min_logit {
min_logit = logit;
min_key = key;
}
}
return min_keys; return (*min_key, *min_logit);
} }
let keys: Vec<_> = logits.keys().copied().collect(); let entries: Vec<_> = logits
let values: Vec<_> = logits.values().copied().collect(); .iter()
.map(|(worker, logit)| (*worker, *logit))
.collect();
let values: Vec<_> = entries.iter().map(|(_, logit)| *logit).collect();
let min_val = values.iter().fold(f64::INFINITY, |a, &b| a.min(b)); let min_val = values.iter().fold(f64::INFINITY, |a, &b| a.min(b));
let max_val = values.iter().fold(f64::NEG_INFINITY, |a, &b| a.max(b)); let max_val = values.iter().fold(f64::NEG_INFINITY, |a, &b| a.max(b));
let probabilities = if min_val == max_val { let probabilities = if min_val == max_val {
vec![1.0 / keys.len() as f64; keys.len()] vec![1.0 / entries.len() as f64; entries.len()]
} else { } else {
// Fused normalize -> negate -> scale -> exp, then normalize probabilities // Fused normalize -> negate -> scale -> exp, then normalize probabilities
let range = max_val - min_val; let range = max_val - min_val;
...@@ -63,19 +79,16 @@ fn softmax_sample( ...@@ -63,19 +79,16 @@ fn softmax_sample(
probs probs
}; };
let mut rng = rand::rng();
let sample: f64 = rng.random();
let mut cumsum = 0.0; let mut cumsum = 0.0;
for (i, &prob) in probabilities.iter().enumerate() { for (i, &prob) in probabilities.iter().enumerate() {
cumsum += prob; cumsum += prob;
if sample <= cumsum { if sample <= cumsum {
return vec![keys[i]]; return entries[i];
} }
} }
// Fallback to last key (shouldn't normally reach here) // Fallback to last key (shouldn't normally reach here)
vec![keys[keys.len() - 1]] entries[entries.len() - 1]
} }
/// Default implementation matching the Python _cost_function. /// Default implementation matching the Python _cost_function.
...@@ -118,30 +131,22 @@ impl<C: WorkerConfigLike> WorkerSelector<C> for DefaultWorkerSelector { ...@@ -118,30 +131,22 @@ impl<C: WorkerConfigLike> WorkerSelector<C> for DefaultWorkerSelector {
let decode_blocks = &request.decode_blocks; let decode_blocks = &request.decode_blocks;
let prefill_tokens = &request.prefill_tokens; let prefill_tokens = &request.prefill_tokens;
let mut worker_logits = HashMap::new();
let overlap_weight = request let overlap_weight = request
.router_config_override .router_config_override
.as_ref() .as_ref()
.and_then(|cfg| cfg.overlap_score_weight) .and_then(|cfg| cfg.overlap_score_weight)
.unwrap_or(self.kv_router_config.overlap_score_weight); .unwrap_or(self.kv_router_config.overlap_score_weight);
for (worker_id, config) in workers let temperature = request
.iter() .router_config_override
.filter(|(wid, _)| allowed_ids.is_none_or(|ids| ids.contains(wid))) .as_ref()
{ .and_then(|cfg| cfg.router_temperature)
let data_parallel_size = config.data_parallel_size(); .unwrap_or(self.kv_router_config.router_temperature);
let data_parallel_start_rank = config.data_parallel_start_rank();
for dp_rank in data_parallel_start_rank..(data_parallel_start_rank + data_parallel_size)
{
let worker = WorkerWithDpRank::new(*worker_id, dp_rank);
let get_score = |worker: WorkerWithDpRank| -> f64 {
let overlap = *overlaps.get(&worker).unwrap_or(&0); let overlap = *overlaps.get(&worker).unwrap_or(&0);
let prefill_token = *prefill_tokens.get(&worker).unwrap_or(&isl); let prefill_token = *prefill_tokens.get(&worker).unwrap_or(&isl);
let potential_prefill_block = (prefill_token as f64) / (block_size as f64); let potential_prefill_block = (prefill_token as f64) / (block_size as f64);
let decode_block = *decode_blocks let decode_block = *decode_blocks
.get(&worker) .get(&worker)
.unwrap_or(&(potential_prefill_block.floor() as usize)) .unwrap_or(&(potential_prefill_block.floor() as usize))
...@@ -149,8 +154,6 @@ impl<C: WorkerConfigLike> WorkerSelector<C> for DefaultWorkerSelector { ...@@ -149,8 +154,6 @@ impl<C: WorkerConfigLike> WorkerSelector<C> for DefaultWorkerSelector {
let logit = overlap_weight * potential_prefill_block + decode_block; let logit = overlap_weight * potential_prefill_block + decode_block;
worker_logits.insert(worker, logit);
tracing::debug!( tracing::debug!(
"Formula for worker_id={} dp_rank={:?} with {overlap} cached blocks: {logit:.3} \ "Formula for worker_id={} dp_rank={:?} with {overlap} cached blocks: {logit:.3} \
= {overlap_weight:.1} * prefill_blocks + decode_blocks \ = {overlap_weight:.1} * prefill_blocks + decode_blocks \
...@@ -158,36 +161,62 @@ impl<C: WorkerConfigLike> WorkerSelector<C> for DefaultWorkerSelector { ...@@ -158,36 +161,62 @@ impl<C: WorkerConfigLike> WorkerSelector<C> for DefaultWorkerSelector {
worker.worker_id, worker.worker_id,
worker.dp_rank worker.dp_rank
); );
logit
};
let worker_iter = workers
.iter()
.filter(move |(wid, _)| allowed_ids.is_none_or(|ids| ids.contains(wid)))
.flat_map(|(worker_id, config)| {
let data_parallel_size = config.data_parallel_size();
let data_parallel_start_rank = config.data_parallel_start_rank();
(data_parallel_start_rank..(data_parallel_start_rank + data_parallel_size))
.map(move |dp_rank| WorkerWithDpRank::new(*worker_id, dp_rank))
});
let (best_worker, best_logit) = if temperature == 0.0 {
let mut min_workers = Vec::new();
let mut min_score = f64::INFINITY;
for worker in worker_iter {
let score = get_score(worker);
if score < min_score {
min_workers.clear();
min_workers.push(worker);
min_score = score;
} else if score == min_score {
min_workers.push(worker);
} }
} }
let temperature = request if min_workers.len() > 1 {
.router_config_override
.as_ref()
.and_then(|cfg| cfg.router_temperature)
.unwrap_or(self.kv_router_config.router_temperature);
let candidates = softmax_sample(&worker_logits, temperature);
let best_worker = if candidates.len() > 1 {
tracing::debug!( tracing::debug!(
"Multiple workers tied with same logit, using tree size as tie-breaker" "Multiple workers tied with same logit, using tree size as tie-breaker"
); );
let tree_sizes: Vec<(usize, &WorkerWithDpRank)> = candidates let tree_sizes: Vec<(usize, &WorkerWithDpRank)> = min_workers
.iter() .iter()
.map(|w| (request.overlaps.tree_sizes.get(w).copied().unwrap_or(0), w)) .map(|w| (request.overlaps.tree_sizes.get(w).copied().unwrap_or(0), w))
.collect(); .collect();
if tree_sizes.iter().all(|(s, _)| *s == tree_sizes[0].0) { if tree_sizes.iter().all(|(s, _)| *s == tree_sizes[0].0) {
let idx = rand::rng().random_range(0..candidates.len()); let idx = rand::rng().random_range(0..min_workers.len());
candidates[idx] (min_workers[idx], min_score)
} else { } else {
*tree_sizes.iter().min_by_key(|(s, _)| *s).unwrap().1 let (_, worker) = *tree_sizes.iter().min_by_key(|(s, _)| *s).unwrap();
(*worker, min_score)
} }
} else { } else {
candidates[0] (min_workers[0], min_score)
}; }
} else {
let mut worker_logits = HashMap::new();
for worker in worker_iter {
let score = get_score(worker);
worker_logits.insert(worker, score);
}
let best_logit = worker_logits[&best_worker]; softmax_sample(&worker_logits, temperature)
};
if self.worker_type == "decode" { if self.worker_type == "decode" {
tracing::info!( tracing::info!(
...@@ -246,31 +275,22 @@ mod tests { ...@@ -246,31 +275,22 @@ mod tests {
fn test_softmax_sample_single_key() { fn test_softmax_sample_single_key() {
let mut logits = HashMap::new(); let mut logits = HashMap::new();
let worker = WorkerWithDpRank::from_worker_id(42); let worker = WorkerWithDpRank::from_worker_id(42);
logits.insert(worker, 0.5); for (logit, temperature) in [
(0.5, 0.1),
for temperature in &[0.1, 1.0, 10.0] { (0.5, 1.0),
let result = softmax_sample(&logits, *temperature); (0.5, 10.0),
assert_eq!(result.len(), 1, "Should return exactly one worker"); (-100.0, 1.0),
assert_eq!(result[0], worker, "Should return the only available worker"); (100.0, 1.0),
} (0.0, 1.0),
(0.0, 0.0),
] {
logits.clear(); logits.clear();
logits.insert(worker, -100.0); logits.insert(worker, logit);
let result = softmax_sample(&logits, 1.0);
assert_eq!(result.len(), 1);
assert_eq!(result[0], worker);
logits.clear(); let result = softmax_sample(&logits, temperature);
logits.insert(worker, 100.0); assert_eq!(result.0, worker, "Should return the only available worker");
let result = softmax_sample(&logits, 1.0); assert_eq!(result.1, logit, "Should return the selected worker's logit");
assert_eq!(result.len(), 1); }
assert_eq!(result[0], worker);
logits.clear();
logits.insert(worker, 0.0);
let result = softmax_sample(&logits, 1.0);
assert_eq!(result.len(), 1);
assert_eq!(result[0], worker);
} }
#[test] #[test]
...@@ -287,13 +307,12 @@ mod tests { ...@@ -287,13 +307,12 @@ mod tests {
let result = softmax_sample(&logits, 0.0); let result = softmax_sample(&logits, 0.0);
assert_eq!( assert_eq!(
result.len(), result.0, worker2,
1, "Should return worker with smallest logit when temperature is 0"
"Should return one worker when there's no tie"
); );
assert_eq!( assert_eq!(
result[0], worker2, result.1, 3.0,
"Should return worker with smallest logit when temperature is 0" "Should return the smallest logit when temperature is 0"
); );
logits.clear(); logits.clear();
...@@ -305,15 +324,11 @@ mod tests { ...@@ -305,15 +324,11 @@ mod tests {
logits.insert(worker6, 7.0); logits.insert(worker6, 7.0);
let result = softmax_sample(&logits, 0.0); let result = softmax_sample(&logits, 0.0);
assert_eq!(
result.len(),
2,
"Should return all workers with smallest logit when tied"
);
assert!( assert!(
result.contains(&worker2) && result.contains(&worker5), result.0 == worker2 || result.0 == worker5,
"Should contain both tied workers" "Should return one of the workers tied for the smallest logit"
); );
assert_eq!(result.1, 3.0, "Should return the tied minimum logit");
logits.clear(); logits.clear();
let worker10 = WorkerWithDpRank::from_worker_id(10); let worker10 = WorkerWithDpRank::from_worker_id(10);
...@@ -324,10 +339,44 @@ mod tests { ...@@ -324,10 +339,44 @@ mod tests {
logits.insert(worker30, 0.0); logits.insert(worker30, 0.0);
let result = softmax_sample(&logits, 0.0); let result = softmax_sample(&logits, 0.0);
assert_eq!(result.len(), 1);
assert_eq!( assert_eq!(
result[0], worker20, result.0, worker20,
"Should handle negative logits correctly" "Should handle negative logits correctly"
); );
assert_eq!(result.1, -5.0, "Should return the minimum negative logit");
}
#[test]
fn test_softmax_sample_with_sample_returns_selected_logit() {
let worker1 = WorkerWithDpRank::from_worker_id(1);
let worker2 = WorkerWithDpRank::from_worker_id(2);
let worker3 = WorkerWithDpRank::from_worker_id(3);
let logits = HashMap::from([(worker1, 0.0), (worker2, 3.0), (worker3, 9.0)]);
let entries: Vec<_> = logits
.iter()
.map(|(worker, logit)| (*worker, *logit))
.collect();
let values: Vec<_> = entries.iter().map(|(_, logit)| *logit).collect();
let min_val = values.iter().fold(f64::INFINITY, |a, &b| a.min(b));
let max_val = values.iter().fold(f64::NEG_INFINITY, |a, &b| a.max(b));
let temperature = 1.0;
let range = max_val - min_val;
let scaled: Vec<f64> = values.iter().map(|&v| -(v / range) / temperature).collect();
let max_scaled = scaled.iter().fold(f64::NEG_INFINITY, |a, &b| a.max(b));
let mut probabilities: Vec<f64> = scaled.iter().map(|&v| (v - max_scaled).exp()).collect();
let sum: f64 = probabilities.iter().sum();
probabilities.iter_mut().for_each(|p| *p /= sum);
let target_idx = entries
.iter()
.position(|(_, logit)| *logit > min_val)
.expect("expected at least one non-minimum logit");
let cumsum_before: f64 = probabilities.iter().take(target_idx).sum();
let sample = cumsum_before + probabilities[target_idx] / 2.0;
let result = softmax_sample_with_sample(&logits, temperature, sample);
assert_eq!(result, entries[target_idx]);
} }
} }
...@@ -429,42 +429,34 @@ where ...@@ -429,42 +429,34 @@ where
} }
let isl_tokens = tokens.len(); let isl_tokens = tokens.len();
let hash_options = BlockHashOptions {
let block_hashes = tracing::info_span!("kv_router.compute_block_hashes").in_scope(|| {
compute_block_hash_for_seq(
tokens,
self.block_size,
BlockHashOptions {
block_mm_infos, block_mm_infos,
lora_name: lora_name.as_deref(), lora_name: lora_name.as_deref(),
is_eagle: Some(self.is_eagle), is_eagle: Some(self.is_eagle),
}, };
)
});
let hash_elapsed = start.elapsed();
let overlap_scores = self
.indexer
.find_matches(block_hashes)
.instrument(tracing::info_span!("kv_router.find_matches"))
.await?;
let find_matches_elapsed = start.elapsed();
let block_hashes = tracing::info_span!("kv_router.compute_block_hashes")
.in_scope(|| compute_block_hash_for_seq(tokens, self.block_size, hash_options));
let hash_elapsed = start.elapsed();
// Compute seq_hashes only if scheduler needs it for active blocks tracking // Compute seq_hashes only if scheduler needs it for active blocks tracking
let maybe_seq_hashes = tracing::info_span!("kv_router.compute_seq_hashes").in_scope(|| { let maybe_seq_hashes = tracing::info_span!("kv_router.compute_seq_hashes").in_scope(|| {
self.kv_router_config.compute_seq_hashes_for_tracking( self.kv_router_config.compute_seq_hashes_for_tracking(
tokens, tokens,
self.block_size, self.block_size,
router_config_override, router_config_override,
BlockHashOptions { hash_options,
block_mm_infos, Some(&block_hashes),
lora_name: lora_name.as_deref(),
is_eagle: Some(self.is_eagle),
},
) )
}); });
let seq_hash_elapsed = start.elapsed(); let seq_hash_elapsed = start.elapsed();
let overlap_scores = self
.indexer
.find_matches(block_hashes)
.instrument(tracing::info_span!("kv_router.find_matches"))
.await?;
let find_matches_elapsed = start.elapsed();
let response = self let response = self
.scheduler .scheduler
.schedule( .schedule(
...@@ -486,8 +478,8 @@ where ...@@ -486,8 +478,8 @@ where
if let Some(m) = metrics::RoutingOverheadMetrics::get() { if let Some(m) = metrics::RoutingOverheadMetrics::get() {
m.observe( m.observe(
hash_elapsed, hash_elapsed,
find_matches_elapsed,
seq_hash_elapsed, seq_hash_elapsed,
find_matches_elapsed,
total_elapsed, total_elapsed,
); );
} }
...@@ -496,9 +488,9 @@ where ...@@ -496,9 +488,9 @@ where
tracing::info!( tracing::info!(
isl_tokens, isl_tokens,
hash_us = hash_elapsed.as_micros() as u64, hash_us = hash_elapsed.as_micros() as u64,
find_matches_us = (find_matches_elapsed - hash_elapsed).as_micros() as u64, seq_hash_us = (seq_hash_elapsed - hash_elapsed).as_micros() as u64,
seq_hash_us = (seq_hash_elapsed - find_matches_elapsed).as_micros() as u64, find_matches_us = (find_matches_elapsed - seq_hash_elapsed).as_micros() as u64,
schedule_us = (total_elapsed - seq_hash_elapsed).as_micros() as u64, schedule_us = (total_elapsed - find_matches_elapsed).as_micros() as u64,
total_us = total_elapsed.as_micros() as u64, total_us = total_elapsed.as_micros() as u64,
"find_best_match completed" "find_best_match completed"
); );
...@@ -524,16 +516,18 @@ where ...@@ -524,16 +516,18 @@ where
router_config_override: Option<&RouterConfigOverride>, router_config_override: Option<&RouterConfigOverride>,
) { ) {
let isl_tokens = tokens.len(); let isl_tokens = tokens.len();
let hash_options = BlockHashOptions {
block_mm_infos,
lora_name: lora_name.as_deref(),
is_eagle: Some(self.is_eagle),
};
let maybe_seq_hashes = self.kv_router_config.compute_seq_hashes_for_tracking( let maybe_seq_hashes = self.kv_router_config.compute_seq_hashes_for_tracking(
tokens, tokens,
self.block_size, self.block_size,
router_config_override, router_config_override,
BlockHashOptions { hash_options,
block_mm_infos, None,
lora_name: lora_name.as_deref(),
is_eagle: Some(self.is_eagle),
},
); );
if let Err(e) = self if let Err(e) = self
...@@ -615,28 +609,23 @@ where ...@@ -615,28 +609,23 @@ where
lora_name: Option<&str>, lora_name: Option<&str>,
) -> Result<Vec<PotentialLoad>> { ) -> Result<Vec<PotentialLoad>> {
let isl_tokens = tokens.len(); let isl_tokens = tokens.len();
let block_hashes = compute_block_hash_for_seq( let hash_options = BlockHashOptions {
tokens,
self.block_size,
BlockHashOptions {
block_mm_infos, block_mm_infos,
lora_name, lora_name,
is_eagle: Some(self.is_eagle), is_eagle: Some(self.is_eagle),
}, };
); let block_hashes = compute_block_hash_for_seq(tokens, self.block_size, hash_options);
let overlap_scores = self.indexer.find_matches(block_hashes.clone()).await?;
let maybe_seq_hashes = self.kv_router_config.compute_seq_hashes_for_tracking( let maybe_seq_hashes = self.kv_router_config.compute_seq_hashes_for_tracking(
tokens, tokens,
self.block_size, self.block_size,
router_config_override, router_config_override,
BlockHashOptions { hash_options,
block_mm_infos, Some(&block_hashes),
lora_name,
is_eagle: Some(self.is_eagle),
},
); );
let overlap_scores = self.indexer.find_matches(block_hashes).await?;
Ok(self Ok(self
.scheduler .scheduler
.get_potential_loads(maybe_seq_hashes, isl_tokens, overlap_scores)) .get_potential_loads(maybe_seq_hashes, isl_tokens, overlap_scores))
......
...@@ -263,26 +263,26 @@ impl RoutingOverheadMetrics { ...@@ -263,26 +263,26 @@ impl RoutingOverheadMetrics {
pub fn observe( pub fn observe(
&self, &self,
hash_elapsed: Duration, hash_elapsed: Duration,
find_matches_elapsed: Duration,
seq_hash_elapsed: Duration, seq_hash_elapsed: Duration,
find_matches_elapsed: Duration,
total_elapsed: Duration, total_elapsed: Duration,
) { ) {
self.block_hashing self.block_hashing
.observe(hash_elapsed.as_secs_f64() * 1000.0); .observe(hash_elapsed.as_secs_f64() * 1000.0);
self.seq_hashing
.observe(seq_hash_elapsed.saturating_sub(hash_elapsed).as_secs_f64() * 1000.0);
self.indexer_find_matches.observe( self.indexer_find_matches.observe(
find_matches_elapsed find_matches_elapsed
.saturating_sub(hash_elapsed) .saturating_sub(seq_hash_elapsed)
.as_secs_f64() .as_secs_f64()
* 1000.0, * 1000.0,
); );
self.seq_hashing.observe( self.scheduling.observe(
seq_hash_elapsed total_elapsed
.saturating_sub(find_matches_elapsed) .saturating_sub(find_matches_elapsed)
.as_secs_f64() .as_secs_f64()
* 1000.0, * 1000.0,
); );
self.scheduling
.observe(total_elapsed.saturating_sub(seq_hash_elapsed).as_secs_f64() * 1000.0);
self.total.observe(total_elapsed.as_secs_f64() * 1000.0); self.total.observe(total_elapsed.as_secs_f64() * 1000.0);
} }
} }
...@@ -557,7 +557,7 @@ dynamo_frontend_router_queue_pending_requests{worker_type=\"decode\"} 5 ...@@ -557,7 +557,7 @@ dynamo_frontend_router_queue_pending_requests{worker_type=\"decode\"} 5
total: make("test_total_ms"), total: make("test_total_ms"),
}; };
// Out-of-order durations: each phase < previous (would panic without saturating_sub) // Out-of-order cumulative durations: each phase < previous (would panic without saturating_sub)
metrics.observe( metrics.observe(
Duration::from_millis(10), Duration::from_millis(10),
Duration::from_millis(5), Duration::from_millis(5),
......
...@@ -518,8 +518,8 @@ impl WorkerQueryClient { ...@@ -518,8 +518,8 @@ impl WorkerQueryClient {
events.len(), events.len(),
last_event_id last_event_id
); );
for event in &events { for event in events {
self.indexer.apply_event(event.clone()).await; self.indexer.apply_event(event).await;
} }
new_cursor = new_cursor.advance_to(last_event_id); new_cursor = new_cursor.advance_to(last_event_id);
successful_response = true; successful_response = true;
......
...@@ -348,6 +348,7 @@ impl OfflineReplayRouter { ...@@ -348,6 +348,7 @@ impl OfflineReplayRouter {
self.block_size, self.block_size,
None, None,
BlockHashOptions::default(), BlockHashOptions::default(),
None,
) )
}; };
(overlaps, token_seq) (overlaps, token_seq)
...@@ -359,6 +360,7 @@ impl OfflineReplayRouter { ...@@ -359,6 +360,7 @@ impl OfflineReplayRouter {
self.block_size, self.block_size,
None, None,
BlockHashOptions::default(), BlockHashOptions::default(),
None,
); );
(overlaps, token_seq) (overlaps, token_seq)
} }
......
...@@ -188,6 +188,7 @@ impl KvReplayRouter { ...@@ -188,6 +188,7 @@ impl KvReplayRouter {
self.block_size, self.block_size,
None, None,
BlockHashOptions::default(), BlockHashOptions::default(),
None,
); );
let response = self let response = self
.scheduler .scheduler
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment