Unverified commit dcbccbcd, authored by Yan Ru Pei, committed by GitHub
Browse files

feat(kv-router): add worker_type to router selection log (#7258)


Signed-off-by: PeaBrane <yanrpei@gmail.com>
parent 7b193905
...@@ -272,7 +272,7 @@ mod tests { ...@@ -272,7 +272,7 @@ mod tests {
let (cfg_tx, cfg_rx) = watch::channel(configs); let (cfg_tx, cfg_rx) = watch::channel(configs);
std::mem::forget(cfg_tx); std::mem::forget(cfg_tx);
let selector = Box::new(DefaultWorkerSelector::default()); let selector = Box::new(DefaultWorkerSelector::new(None, "test"));
let queue = Arc::new(SchedulerQueue::new( let queue = Arc::new(SchedulerQueue::new(
Arc::clone(&slots), Arc::clone(&slots),
cfg_rx, cfg_rx,
......
...@@ -79,15 +79,17 @@ fn softmax_sample( ...@@ -79,15 +79,17 @@ fn softmax_sample(
} }
/// Default implementation matching the Python _cost_function. /// Default implementation matching the Python _cost_function.
#[derive(Debug, Clone, Default)] #[derive(Debug, Clone)]
pub struct DefaultWorkerSelector { pub struct DefaultWorkerSelector {
pub kv_router_config: KvRouterConfig, pub kv_router_config: KvRouterConfig,
pub worker_type: &'static str,
} }
impl DefaultWorkerSelector { impl DefaultWorkerSelector {
pub fn new(kv_router_config: Option<KvRouterConfig>) -> Self { pub fn new(kv_router_config: Option<KvRouterConfig>, worker_type: &'static str) -> Self {
Self { Self {
kv_router_config: kv_router_config.unwrap_or_default(), kv_router_config: kv_router_config.unwrap_or_default(),
worker_type,
} }
} }
} }
...@@ -187,6 +189,21 @@ impl<C: WorkerConfigLike> WorkerSelector<C> for DefaultWorkerSelector { ...@@ -187,6 +189,21 @@ impl<C: WorkerConfigLike> WorkerSelector<C> for DefaultWorkerSelector {
let best_logit = worker_logits[&best_worker]; let best_logit = worker_logits[&best_worker];
if self.worker_type == "decode" {
tracing::info!(
"Selected worker: worker_type={}, worker_id={} dp_rank={:?}, logit: {:.3}",
self.worker_type,
best_worker.worker_id,
best_worker.dp_rank,
best_logit,
);
return Ok(WorkerSelectionResult {
worker: best_worker,
required_blocks: request_blocks as u64,
overlap_blocks: overlaps.get(&best_worker).copied().unwrap_or(0),
});
}
let best_overlap = *overlaps.get(&best_worker).unwrap_or(&0); let best_overlap = *overlaps.get(&best_worker).unwrap_or(&0);
let total_blocks_info = workers let total_blocks_info = workers
...@@ -203,7 +220,8 @@ impl<C: WorkerConfigLike> WorkerSelector<C> for DefaultWorkerSelector { ...@@ -203,7 +220,8 @@ impl<C: WorkerConfigLike> WorkerSelector<C> for DefaultWorkerSelector {
.unwrap_or(0); .unwrap_or(0);
tracing::info!( tracing::info!(
"Selected worker: worker_id={} dp_rank={:?}, logit: {:.3}, cached blocks: {}, tree size: {}{}", "Selected worker: worker_type={}, worker_id={} dp_rank={:?}, logit: {:.3}, cached blocks: {}, tree size: {}{}",
self.worker_type,
best_worker.worker_id, best_worker.worker_id,
best_worker.dp_rank, best_worker.dp_rank,
best_logit, best_logit,
......
...@@ -588,7 +588,7 @@ impl ModelManager { ...@@ -588,7 +588,7 @@ impl ModelManager {
// Get or create runtime config watcher for this endpoint // Get or create runtime config watcher for this endpoint
let workers_with_configs = self.get_or_create_runtime_config_watcher(endpoint).await?; let workers_with_configs = self.get_or_create_runtime_config_watcher(endpoint).await?;
let selector = Box::new(DefaultWorkerSelector::new(kv_router_config)); let selector = Box::new(DefaultWorkerSelector::new(kv_router_config, worker_type));
let chooser = KvRouter::new( let chooser = KvRouter::new(
endpoint.clone(), endpoint.clone(),
client, client,
......
...@@ -43,7 +43,7 @@ impl KvScheduler { ...@@ -43,7 +43,7 @@ impl KvScheduler {
kv_router_config: &KvRouterConfig, kv_router_config: &KvRouterConfig,
worker_type: &'static str, worker_type: &'static str,
) -> Result<Self, KvSchedulerError> { ) -> Result<Self, KvSchedulerError> {
let selector = selector.unwrap_or(Box::new(DefaultWorkerSelector::default())); let selector = selector.unwrap_or(Box::new(DefaultWorkerSelector::new(None, worker_type)));
// Get initial workers from watch receiver. // Get initial workers from watch receiver.
// Caller must ensure at least one worker is present (via wait_for). // Caller must ensure at least one worker is present (via wait_for).
......
...@@ -30,7 +30,12 @@ pub type NumBlocks = usize; ...@@ -30,7 +30,12 @@ pub type NumBlocks = usize;
/// For Use and Promote variants, block hashes are included for KV event publishing /// For Use and Promote variants, block hashes are included for KV event publishing
#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)] #[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
pub enum MoveBlock { pub enum MoveBlock {
Use(Vec<UniqueBlock>, Vec<BlockHash>, Option<Vec<Vec<u32>>>), Use(
Vec<UniqueBlock>,
Vec<BlockHash>,
Option<Vec<Vec<u32>>>,
Option<UniqueBlock>,
),
Destroy(Vec<UniqueBlock>), Destroy(Vec<UniqueBlock>),
Deref(Vec<UniqueBlock>), Deref(Vec<UniqueBlock>),
Promote(Uuid, SequenceHash, Option<u64>, BlockHash, Option<Vec<u32>>), Promote(Uuid, SequenceHash, Option<u64>, BlockHash, Option<Vec<u32>>),
......
...@@ -143,7 +143,12 @@ impl ActiveSequence { ...@@ -143,7 +143,12 @@ impl ActiveSequence {
None None
}; };
Some(MoveBlock::Use(blocks, hashes, token_ids)) let parent = if prev_blocks > 0 {
Some(self.unique_blocks[prev_blocks - 1].clone())
} else {
None
};
Some(MoveBlock::Use(blocks, hashes, token_ids, parent))
} }
/// Commit a successful allocation by advancing `num_allocated_tokens`. /// Commit a successful allocation by advancing `num_allocated_tokens`.
...@@ -237,7 +242,7 @@ impl ActiveSequence { ...@@ -237,7 +242,7 @@ impl ActiveSequence {
let new_partial_block = UniqueBlock::default(); let new_partial_block = UniqueBlock::default();
self.unique_blocks.push(new_partial_block.clone()); self.unique_blocks.push(new_partial_block.clone());
signals.push(MoveBlock::Use(vec![new_partial_block], vec![], None)); signals.push(MoveBlock::Use(vec![new_partial_block], vec![], None, None));
Some(signals) Some(signals)
} }
......
...@@ -186,12 +186,12 @@ impl KvManager { ...@@ -186,12 +186,12 @@ impl KvManager {
/// For other variants, returns the total block count (they always succeed or panic). /// For other variants, returns the total block count (they always succeed or panic).
pub fn process(&mut self, event: &MoveBlock) -> usize { pub fn process(&mut self, event: &MoveBlock) -> usize {
match event { match event {
MoveBlock::Use(hashes, local_hashes, token_ids) => { MoveBlock::Use(hashes, local_hashes, token_ids, parent) => {
let mut blocks_stored = Vec::<u64>::new(); let mut blocks_stored = Vec::<u64>::new();
let mut stored_token_ids: Option<Vec<Vec<u32>>> = let mut stored_token_ids: Option<Vec<Vec<u32>>> =
token_ids.as_ref().map(|_| Vec::new()); token_ids.as_ref().map(|_| Vec::new());
let mut parent_block: Option<&UniqueBlock> = None; let mut parent_block: Option<&UniqueBlock> = parent.as_ref();
let mut allocated = 0; let mut allocated = 0;
for (i, hash) in hashes.iter().enumerate() { for (i, hash) in hashes.iter().enumerate() {
// First check if it already exists in active blocks // First check if it already exists in active blocks
...@@ -420,7 +420,7 @@ mod tests { ...@@ -420,7 +420,7 @@ mod tests {
fn use_blocks(manager: &mut KvManager, ids: Vec<u64>) -> usize { fn use_blocks(manager: &mut KvManager, ids: Vec<u64>) -> usize {
let blocks: Vec<_> = ids.iter().map(|&id| UniqueBlock::FullBlock(id)).collect(); let blocks: Vec<_> = ids.iter().map(|&id| UniqueBlock::FullBlock(id)).collect();
let hashes: Vec<_> = ids.into_iter().collect(); let hashes: Vec<_> = ids.into_iter().collect();
manager.process(&MoveBlock::Use(blocks, hashes, None)) manager.process(&MoveBlock::Use(blocks, hashes, None, None))
} }
// First use 10 blocks (0 to 9) in a batch // First use 10 blocks (0 to 9) in a batch
...@@ -447,7 +447,7 @@ mod tests { ...@@ -447,7 +447,7 @@ mod tests {
fn use_blocks(manager: &mut KvManager, ids: Vec<u64>) { fn use_blocks(manager: &mut KvManager, ids: Vec<u64>) {
let blocks: Vec<_> = ids.iter().map(|&id| UniqueBlock::FullBlock(id)).collect(); let blocks: Vec<_> = ids.iter().map(|&id| UniqueBlock::FullBlock(id)).collect();
let hashes: Vec<_> = ids.into_iter().collect(); let hashes: Vec<_> = ids.into_iter().collect();
manager.process(&MoveBlock::Use(blocks, hashes, None)); manager.process(&MoveBlock::Use(blocks, hashes, None, None));
} }
// Helper function to destroy multiple blocks // Helper function to destroy multiple blocks
...@@ -561,4 +561,72 @@ mod tests { ...@@ -561,4 +561,72 @@ mod tests {
use_blocks(&mut manager, vec![13]); use_blocks(&mut manager, vec![13]);
} }
/// Checks parent-hash propagation across a chunked prefill: the store event
/// published for the first chunk must have no parent, and the store event for
/// the second chunk must name the last full block of the first chunk.
#[test]
fn test_chunked_prefill_parent_hash() {
    use crate::common::sequence::ActiveSequence;
    use std::sync::Mutex;

    /// Event sink that simply records every event handed to `publish`.
    #[derive(Default)]
    struct CapturingSink {
        events: Mutex<Vec<KvCacheEvent>>,
    }

    impl KvCacheEventSink for CapturingSink {
        fn publish(
            &self,
            event: KvCacheEvent,
            _block_token_ids: Option<&[Vec<u32>]>,
        ) -> anyhow::Result<()> {
            let mut recorded = self.events.lock().unwrap();
            recorded.push(event);
            Ok(())
        }
    }

    let block_size = 64;
    // 512 tokens at 64 tokens per block gives 8 blocks.
    let tokens: Vec<u32> = (0..512).collect();
    let mut seq = ActiveSequence::new(tokens, 100, Some(block_size), true, false);

    let sink = Arc::new(CapturingSink::default());
    let mut manager =
        KvManager::new_with_event_sink(256, block_size, Some(sink.clone() as _), 0);

    // First chunk: allocate up to token 256 (blocks 0-3).
    let first_chunk = seq.prepare_allocation(256).unwrap();
    manager.process(&first_chunk);
    seq.commit_allocation(256);

    // Second chunk: allocate up to token 512 (blocks 4-7).
    let second_chunk = seq.prepare_allocation(512).unwrap();
    manager.process(&second_chunk);
    seq.commit_allocation(512);

    let captured = sink.events.lock().unwrap();
    assert_eq!(captured.len(), 2, "expected two store events");

    // The first store event starts from the root, so it carries no parent.
    let first_store = match &captured[0].data {
        KvCacheEventData::Stored(stored) => stored,
        _ => panic!("expected store event"),
    };
    assert!(
        first_store.parent_hash.is_none(),
        "first chunk should have no parent"
    );

    // The second store event must point at block 3, the final block of the
    // first chunk.
    let second_store = match &captured[1].data {
        KvCacheEventData::Stored(stored) => stored,
        _ => panic!("expected store event"),
    };
    let expected_hash = match seq.unique_blocks()[3].clone() {
        UniqueBlock::FullBlock(hash) => hash,
        _ => panic!("expected full block"),
    };
    assert_eq!(
        second_store.parent_hash,
        Some(ExternalSequenceBlockHash(expected_hash)),
        "second chunk's parent should be block 3's seq_hash"
    );
}
} }
...@@ -139,7 +139,7 @@ sglang_configs = { ...@@ -139,7 +139,7 @@ sglang_configs = {
expected_log=[ expected_log=[
r"ZMQ listener .* received batch with \d+ events \(seq=\d+(?:, [^)]*)?\)", r"ZMQ listener .* received batch with \d+ events \(seq=\d+(?:, [^)]*)?\)",
r"Event processor for worker_id \d+ processing event: Stored\(", r"Event processor for worker_id \d+ processing event: Stored\(",
r"Selected worker: worker_id=\d+ dp_rank=.*?, logit: ", r"Selected worker: worker_type=\w+, worker_id=\d+ dp_rank=.*?, logit: ",
] ]
) )
], ],
......
...@@ -151,7 +151,7 @@ trtllm_configs = { ...@@ -151,7 +151,7 @@ trtllm_configs = {
chat_payload_default( chat_payload_default(
expected_log=[ expected_log=[
r"Event processor for worker_id \d+ processing event: Stored\(", r"Event processor for worker_id \d+ processing event: Stored\(",
r"Selected worker: worker_id=\d+ dp_rank=.*?, logit: ", r"Selected worker: worker_type=\w+, worker_id=\d+ dp_rank=.*?, logit: ",
] ]
) )
], ],
......
...@@ -203,7 +203,7 @@ vllm_configs = { ...@@ -203,7 +203,7 @@ vllm_configs = {
expected_log=[ expected_log=[
r"ZMQ listener .* received batch with \d+ events \(seq=\d+(?:, [^)]*)?\)", r"ZMQ listener .* received batch with \d+ events \(seq=\d+(?:, [^)]*)?\)",
r"Event processor for worker_id \d+ processing event: Stored\(", r"Event processor for worker_id \d+ processing event: Stored\(",
r"Selected worker: worker_id=\d+ dp_rank=.*?, logit: ", r"Selected worker: worker_type=\w+, worker_id=\d+ dp_rank=.*?, logit: ",
] ]
) )
], ],
...@@ -228,7 +228,7 @@ vllm_configs = { ...@@ -228,7 +228,7 @@ vllm_configs = {
repeat_count=3, repeat_count=3,
expected_log=[ expected_log=[
# Verify scheduler is selecting workers with cache awareness # Verify scheduler is selecting workers with cache awareness
r"Selected worker: worker_id=\d+ dp_rank=.*?, logit: ", r"Selected worker: worker_type=\w+, worker_id=\d+ dp_rank=.*?, logit: ",
# After first request, should see cached blocks being tracked # After first request, should see cached blocks being tracked
r"with \d+ cached blocks", r"with \d+ cached blocks",
], ],
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment