Unverified Commit f978f4d1 authored by Yan Ru Pei's avatar Yan Ru Pei Committed by GitHub
Browse files

feat: dp rank routing (#3597)


Signed-off-by: default avatarPeaBrane <yanrpei@gmail.com>
parent 29f5b822
...@@ -7,7 +7,7 @@ use serde::{Deserialize, Serialize, de::DeserializeOwned}; ...@@ -7,7 +7,7 @@ use serde::{Deserialize, Serialize, de::DeserializeOwned};
use crate::protocols::tensor; use crate::protocols::tensor;
#[derive(Debug, Default, Clone, Serialize, Deserialize, Eq, PartialEq)] #[derive(Debug, Clone, Serialize, Deserialize, Eq, PartialEq)]
pub struct ModelRuntimeConfig { pub struct ModelRuntimeConfig {
pub total_kv_blocks: Option<u64>, pub total_kv_blocks: Option<u64>,
...@@ -19,6 +19,10 @@ pub struct ModelRuntimeConfig { ...@@ -19,6 +19,10 @@ pub struct ModelRuntimeConfig {
pub reasoning_parser: Option<String>, pub reasoning_parser: Option<String>,
/// Total number of data parallel ranks for this worker (1 if DP not enabled)
#[serde(default = "default_data_parallel_size")]
pub data_parallel_size: u32,
/// Mapping of engine-specific runtime configs /// Mapping of engine-specific runtime configs
#[serde(default, skip_serializing_if = "HashMap::is_empty")] #[serde(default, skip_serializing_if = "HashMap::is_empty")]
pub runtime_data: HashMap<String, serde_json::Value>, pub runtime_data: HashMap<String, serde_json::Value>,
...@@ -34,6 +38,25 @@ pub struct ModelRuntimeConfig { ...@@ -34,6 +38,25 @@ pub struct ModelRuntimeConfig {
pub tensor_model_config: Option<tensor::TensorModelConfig>, pub tensor_model_config: Option<tensor::TensorModelConfig>,
} }
const fn default_data_parallel_size() -> u32 {
1
}
impl Default for ModelRuntimeConfig {
fn default() -> Self {
Self {
total_kv_blocks: None,
max_num_seqs: None,
max_num_batched_tokens: None,
tool_call_parser: None,
reasoning_parser: None,
data_parallel_size: default_data_parallel_size(),
runtime_data: HashMap::new(),
tensor_model_config: None,
}
}
}
impl ModelRuntimeConfig { impl ModelRuntimeConfig {
pub fn new() -> Self { pub fn new() -> Self {
Self::default() Self::default()
......
...@@ -124,7 +124,7 @@ impl MockVllmEngine { ...@@ -124,7 +124,7 @@ impl MockVllmEngine {
let scheduler = Scheduler::new( let scheduler = Scheduler::new(
args.clone(), args.clone(),
Some(dp_rank), dp_rank,
Some(output_tx), Some(output_tx),
Some(kv_events_tx), // Pass the KV events sender to scheduler Some(kv_events_tx), // Pass the KV events sender to scheduler
Some(cancel_token.clone()), Some(cancel_token.clone()),
...@@ -283,6 +283,7 @@ impl MockVllmEngine { ...@@ -283,6 +283,7 @@ impl MockVllmEngine {
let event = KvCacheEvent { let event = KvCacheEvent {
event_id: Uuid::new_v4().as_u128() as u64, event_id: Uuid::new_v4().as_u128() as u64,
data: event_data, data: event_data,
dp_rank,
}; };
// Publish the event // Publish the event
...@@ -316,18 +317,8 @@ impl AsyncEngine<SingleIn<PreprocessedRequest>, ManyOut<LLMEngineOutput>, Error> ...@@ -316,18 +317,8 @@ impl AsyncEngine<SingleIn<PreprocessedRequest>, ManyOut<LLMEngineOutput>, Error>
) -> Result<ManyOut<LLMEngineOutput>, Error> { ) -> Result<ManyOut<LLMEngineOutput>, Error> {
let (request, ctx) = input.into_parts(); let (request, ctx) = input.into_parts();
// Extract dp_rank from annotations if present // Extract dp_rank from request field (defaults to 0 if not set)
let dp_rank = request let dp_rank = request.dp_rank.unwrap_or(0);
.annotations
.iter()
.find_map(|ann| {
if ann.starts_with("dp_rank:") {
ann.strip_prefix("dp_rank:").and_then(|s| s.parse().ok())
} else {
None
}
})
.unwrap_or(0);
// Validate dp_rank // Validate dp_rank
if dp_rank >= self.engine_args.dp_size { if dp_rank >= self.engine_args.dp_size {
...@@ -348,7 +339,7 @@ impl AsyncEngine<SingleIn<PreprocessedRequest>, ManyOut<LLMEngineOutput>, Error> ...@@ -348,7 +339,7 @@ impl AsyncEngine<SingleIn<PreprocessedRequest>, ManyOut<LLMEngineOutput>, Error>
.expect("max_output_tokens must be specified for mocker") .expect("max_output_tokens must be specified for mocker")
as usize, as usize,
uuid: Some(request_uuid), uuid: Some(request_uuid),
dp_rank: Some(dp_rank), dp_rank,
}; };
let (request_tx, mut request_rx) = mpsc::unbounded_channel::<OutputSignal>(); let (request_tx, mut request_rx) = mpsc::unbounded_channel::<OutputSignal>();
...@@ -512,7 +503,7 @@ pub async fn make_mocker_engine( ...@@ -512,7 +503,7 @@ pub async fn make_mocker_engine(
args: MockEngineArgs, args: MockEngineArgs,
) -> Result<crate::backend::ExecutionContext, Error> { ) -> Result<crate::backend::ExecutionContext, Error> {
// Create the mocker engine // Create the mocker engine
tracing::debug!("Creating mocker engine with config: {args:?}"); tracing::info!("Creating mocker engine with config: {args:?}");
let annotated_engine = let annotated_engine =
AnnotatedMockEngine::new(MockVllmEngine::new(args), distributed_runtime, endpoint_id); AnnotatedMockEngine::new(MockVllmEngine::new(args), distributed_runtime, endpoint_id);
......
...@@ -37,7 +37,7 @@ pub struct DirectRequest { ...@@ -37,7 +37,7 @@ pub struct DirectRequest {
pub tokens: Vec<Token>, pub tokens: Vec<Token>,
pub max_output_tokens: usize, pub max_output_tokens: usize,
pub uuid: Option<Uuid>, pub uuid: Option<Uuid>,
pub dp_rank: Option<u32>, pub dp_rank: u32,
} }
/// Represents the cost of prefilling content in the cache /// Represents the cost of prefilling content in the cache
......
...@@ -248,7 +248,7 @@ impl Scheduler { ...@@ -248,7 +248,7 @@ impl Scheduler {
/// Create a new Scheduler with the given parameters /// Create a new Scheduler with the given parameters
pub fn new( pub fn new(
args: MockEngineArgs, args: MockEngineArgs,
dp_rank: Option<u32>, dp_rank: u32,
output_tx: Option<mpsc::UnboundedSender<OutputSignal>>, output_tx: Option<mpsc::UnboundedSender<OutputSignal>>,
kv_events_tx: Option<mpsc::UnboundedSender<KvCacheEventData>>, kv_events_tx: Option<mpsc::UnboundedSender<KvCacheEventData>>,
cancellation_token: Option<CancellationToken>, cancellation_token: Option<CancellationToken>,
...@@ -280,7 +280,7 @@ impl Scheduler { ...@@ -280,7 +280,7 @@ impl Scheduler {
// Create channel for request handling // Create channel for request handling
let (request_tx, mut request_rx) = mpsc::unbounded_channel::<DirectRequest>(); let (request_tx, mut request_rx) = mpsc::unbounded_channel::<DirectRequest>();
let mut initial_metrics = ForwardPassMetrics::default(); let mut initial_metrics = ForwardPassMetrics::default();
initial_metrics.worker_stats.data_parallel_rank = dp_rank; initial_metrics.worker_stats.data_parallel_rank = Some(dp_rank);
let (metrics_tx, metrics_rx) = let (metrics_tx, metrics_rx) =
tokio::sync::watch::channel::<ForwardPassMetrics>(initial_metrics); tokio::sync::watch::channel::<ForwardPassMetrics>(initial_metrics);
...@@ -573,7 +573,7 @@ fn get_fwd_pass_metrics( ...@@ -573,7 +573,7 @@ fn get_fwd_pass_metrics(
state: &SchedulerState, state: &SchedulerState,
kv_manager: &KvManager, kv_manager: &KvManager,
hit_rates: &VecDeque<f32>, hit_rates: &VecDeque<f32>,
dp_rank: Option<u32>, dp_rank: u32,
) -> ForwardPassMetrics { ) -> ForwardPassMetrics {
// Get state metrics // Get state metrics
let request_active_slots = state.decode.len() as u64; let request_active_slots = state.decode.len() as u64;
...@@ -597,7 +597,7 @@ fn get_fwd_pass_metrics( ...@@ -597,7 +597,7 @@ fn get_fwd_pass_metrics(
}; };
let worker_stats = WorkerStats { let worker_stats = WorkerStats {
data_parallel_rank: dp_rank, data_parallel_rank: Some(dp_rank),
request_active_slots, request_active_slots,
request_total_slots: 1024, // vllm max_num_seqs for gpu >= 70 vram, otherwise 256, fallback is 128 request_total_slots: 1024, // vllm max_num_seqs for gpu >= 70 vram, otherwise 256, fallback is 128
num_requests_waiting, num_requests_waiting,
...@@ -728,7 +728,7 @@ mod tests { ...@@ -728,7 +728,7 @@ mod tests {
.unwrap(); .unwrap();
// Create scheduler with new args struct // Create scheduler with new args struct
let scheduler = Scheduler::new(args, None, Some(output_tx), None, None); let scheduler = Scheduler::new(args, 0, Some(output_tx), None, None);
// Create shared tokens for caching case // Create shared tokens for caching case
let shared_tokens = if use_shared_tokens { let shared_tokens = if use_shared_tokens {
...@@ -759,7 +759,7 @@ mod tests { ...@@ -759,7 +759,7 @@ mod tests {
tokens: input_tokens, tokens: input_tokens,
max_output_tokens, max_output_tokens,
uuid: None, uuid: None,
dp_rank: None, dp_rank: 0,
}; };
scheduler.receive(request).await; scheduler.receive(request).await;
} }
...@@ -853,7 +853,7 @@ mod tests { ...@@ -853,7 +853,7 @@ mod tests {
.unwrap(); .unwrap();
// Create scheduler // Create scheduler
let scheduler = Scheduler::new(args, None, Some(output_tx), None, None); let scheduler = Scheduler::new(args, 0, Some(output_tx), None, None);
// Create identical tokens for all requests // Create identical tokens for all requests
let identical_tokens: Vec<u32> = (0..token_length).map(|i| i as u32).collect(); let identical_tokens: Vec<u32> = (0..token_length).map(|i| i as u32).collect();
...@@ -864,7 +864,7 @@ mod tests { ...@@ -864,7 +864,7 @@ mod tests {
tokens: identical_tokens.clone(), tokens: identical_tokens.clone(),
max_output_tokens, max_output_tokens,
uuid: None, uuid: None,
dp_rank: None, dp_rank: 0,
}; };
scheduler.receive(request).await; scheduler.receive(request).await;
// Sleep for 0.1 second after each request // Sleep for 0.1 second after each request
...@@ -950,7 +950,7 @@ mod tests { ...@@ -950,7 +950,7 @@ mod tests {
.unwrap(); .unwrap();
// Create scheduler // Create scheduler
let scheduler = Scheduler::new(args, None, Some(output_tx), None, None); let scheduler = Scheduler::new(args, 0, Some(output_tx), None, None);
// Create request with 256 tokens // Create request with 256 tokens
let tokens: Vec<u32> = (0..input_tokens).map(|i| i as u32).collect(); let tokens: Vec<u32> = (0..input_tokens).map(|i| i as u32).collect();
...@@ -958,7 +958,7 @@ mod tests { ...@@ -958,7 +958,7 @@ mod tests {
tokens, tokens,
max_output_tokens, max_output_tokens,
uuid: None, uuid: None,
dp_rank: None, dp_rank: 0,
}; };
scheduler.receive(request).await; scheduler.receive(request).await;
......
...@@ -61,6 +61,11 @@ pub struct PreprocessedRequest { ...@@ -61,6 +61,11 @@ pub struct PreprocessedRequest {
#[serde(default, skip_serializing_if = "Option::is_none")] #[serde(default, skip_serializing_if = "Option::is_none")]
pub disaggregated_params: Option<serde_json::Value>, pub disaggregated_params: Option<serde_json::Value>,
/// Data parallel rank for the request (used with data parallelism)
#[builder(default)]
#[serde(default, skip_serializing_if = "Option::is_none")]
pub dp_rank: Option<u32>,
/// Additional arguments for extensibility /// Additional arguments for extensibility
#[builder(default)] #[builder(default)]
#[serde(default, skip_serializing_if = "Option::is_none")] #[serde(default, skip_serializing_if = "Option::is_none")]
......
...@@ -294,6 +294,7 @@ pub mod llm_kvbm { ...@@ -294,6 +294,7 @@ pub mod llm_kvbm {
let event = KvCacheEvent { let event = KvCacheEvent {
data, data,
event_id: event_id_counter, event_id: event_id_counter,
dp_rank: 0,
}; };
let router_event = RouterEvent::new(worker_identifier as i64, event); let router_event = RouterEvent::new(worker_identifier as i64, event);
event_id_counter += 1; event_id_counter += 1;
...@@ -313,6 +314,7 @@ pub mod llm_kvbm { ...@@ -313,6 +314,7 @@ pub mod llm_kvbm {
block_hashes: vec![ExternalSequenceBlockHash(sequence_hash)], block_hashes: vec![ExternalSequenceBlockHash(sequence_hash)],
}), }),
event_id: event_id_counter, event_id: event_id_counter,
dp_rank: 0,
}; };
let router_event = RouterEvent::new(worker_identifier as i64, event); let router_event = RouterEvent::new(worker_identifier as i64, event);
event_id_counter += 1; event_id_counter += 1;
...@@ -573,6 +575,7 @@ mod tests { ...@@ -573,6 +575,7 @@ mod tests {
}], }],
parent_hash: None, parent_hash: None,
}), }),
dp_rank: 0,
}, },
); );
...@@ -587,6 +590,7 @@ mod tests { ...@@ -587,6 +590,7 @@ mod tests {
}], }],
parent_hash: None, parent_hash: None,
}), }),
dp_rank: 0,
}, },
); );
...@@ -630,6 +634,7 @@ mod tests { ...@@ -630,6 +634,7 @@ mod tests {
}], }],
parent_hash: None, parent_hash: None,
}), }),
dp_rank: 0,
}, },
); );
...@@ -678,6 +683,7 @@ mod tests { ...@@ -678,6 +683,7 @@ mod tests {
data: KvCacheEventData::Removed(KvCacheRemoveData { data: KvCacheEventData::Removed(KvCacheRemoveData {
block_hashes: vec![ExternalSequenceBlockHash(4)], block_hashes: vec![ExternalSequenceBlockHash(4)],
}), }),
dp_rank: 0,
}, },
); );
......
...@@ -26,34 +26,59 @@ struct LoadEvent { ...@@ -26,34 +26,59 @@ struct LoadEvent {
#[derive(serde::Deserialize)] #[derive(serde::Deserialize)]
struct ForwardPassMetrics { struct ForwardPassMetrics {
worker_stats: WorkerStats,
kv_stats: KvStats, kv_stats: KvStats,
} }
#[derive(serde::Deserialize)]
struct WorkerStats {
data_parallel_rank: Option<u32>,
}
#[derive(serde::Deserialize)] #[derive(serde::Deserialize)]
struct KvStats { struct KvStats {
kv_active_blocks: u64, kv_active_blocks: u64,
} }
#[derive(serde::Deserialize)] #[derive(serde::Deserialize, Clone)]
struct RuntimeConfig { struct RuntimeConfig {
total_kv_blocks: Option<u64>, total_kv_blocks: Option<u64>,
data_parallel_size: u32,
} }
/// Worker load monitoring state /// Worker load monitoring state per dp_rank
#[derive(Clone, Debug)] #[derive(Clone, Debug, Default)]
pub struct WorkerLoadState { pub struct WorkerLoadState {
pub kv_active_blocks: Option<u64>, pub kv_active_blocks: HashMap<u32, u64>,
pub kv_total_blocks: Option<u64>, pub kv_total_blocks: HashMap<u32, u64>,
} }
impl WorkerLoadState { impl WorkerLoadState {
/// Returns true if ALL dp_ranks (that have data in both maps) exceed the threshold
pub fn is_busy(&self, threshold: f64) -> bool { pub fn is_busy(&self, threshold: f64) -> bool {
match (self.kv_active_blocks, self.kv_total_blocks) { // Get all dp_ranks that exist in both active and total blocks
(Some(active), Some(total)) if total > 0 => { let common_dp_ranks: Vec<_> = self
(active as f64) > (threshold * total as f64) .kv_active_blocks
} .keys()
_ => false, .filter(|dp_rank| self.kv_total_blocks.contains_key(dp_rank))
.collect();
// If no common dp_ranks, not busy
if common_dp_ranks.is_empty() {
return false;
} }
// Check if ALL common dp_ranks exceed threshold
common_dp_ranks.iter().all(|&&dp_rank| {
if let (Some(&active), Some(&total)) = (
self.kv_active_blocks.get(&dp_rank),
self.kv_total_blocks.get(&dp_rank),
) {
total > 0 && (active as f64) > (threshold * total as f64)
} else {
false
}
})
} }
} }
...@@ -97,9 +122,10 @@ impl WorkerMonitor { ...@@ -97,9 +122,10 @@ impl WorkerMonitor {
"v1/mdc/", // should be model_card::ROOT_PREFIX but wrong crate "v1/mdc/", // should be model_card::ROOT_PREFIX but wrong crate
key_extractors::lease_id, key_extractors::lease_id,
|card: serde_json::Value| { |card: serde_json::Value| {
card.get("runtime_config") let runtime_config: Option<RuntimeConfig> = card
.and_then(|rc| rc.get("total_kv_blocks")) .get("runtime_config")
.and_then(|t_kv| t_kv.as_u64()) .and_then(|rc| serde_json::from_value(rc.clone()).ok());
runtime_config
}, },
component.drt().child_token(), component.drt().child_token(),
) )
...@@ -132,13 +158,17 @@ impl WorkerMonitor { ...@@ -132,13 +158,17 @@ impl WorkerMonitor {
let mut states = worker_load_states.write().unwrap(); let mut states = worker_load_states.write().unwrap();
states.retain(|lease_id, _| runtime_configs.contains_key(lease_id)); states.retain(|lease_id, _| runtime_configs.contains_key(lease_id));
// Update worker load states with total blocks // Update worker load states with total blocks for all dp_ranks
for (lease_id, total_blocks) in runtime_configs.iter() { for (lease_id, runtime_config) in runtime_configs.iter() {
let state = states.entry(*lease_id).or_insert(WorkerLoadState { let state = states.entry(*lease_id).or_default();
kv_active_blocks: None,
kv_total_blocks: None, // Populate total_blocks for all dp_ranks (they share the same total)
}); // data_parallel_size defaults to 1 via serde in ModelRuntimeConfig
state.kv_total_blocks = Some(*total_blocks); if let Some(total_blocks) = runtime_config.total_kv_blocks {
for dp_rank in 0..runtime_config.data_parallel_size {
state.kv_total_blocks.insert(dp_rank, total_blocks);
}
}
} }
} }
...@@ -152,14 +182,12 @@ impl WorkerMonitor { ...@@ -152,14 +182,12 @@ impl WorkerMonitor {
if let Ok(load_event) = serde_json::from_slice::<LoadEvent>(&event.payload) { if let Ok(load_event) = serde_json::from_slice::<LoadEvent>(&event.payload) {
let worker_id = load_event.worker_id; let worker_id = load_event.worker_id;
let active_blocks = load_event.data.kv_stats.kv_active_blocks; let active_blocks = load_event.data.kv_stats.kv_active_blocks;
let dp_rank = load_event.data.worker_stats.data_parallel_rank.unwrap_or(0);
// Update worker load state // Update worker load state per dp_rank
let mut states = worker_load_states.write().unwrap(); let mut states = worker_load_states.write().unwrap();
let state = states.entry(worker_id).or_insert(WorkerLoadState { let state = states.entry(worker_id).or_default();
kv_active_blocks: None, state.kv_active_blocks.insert(dp_rank, active_blocks);
kv_total_blocks: None,
});
state.kv_active_blocks = Some(active_blocks);
drop(states); drop(states);
// Recalculate all busy instances and update // Recalculate all busy instances and update
......
...@@ -298,6 +298,7 @@ async def send_request_via_python_kv_router( ...@@ -298,6 +298,7 @@ async def send_request_via_python_kv_router(
worker_id: Optional[ worker_id: Optional[
int int
] = None, # If None, Router will select the best available worker ] = None, # If None, Router will select the best available worker
dp_rank: Optional[int] = None, # Data parallel rank (defaults to 0)
): ):
"""Send a request to the specified mocker instance. """Send a request to the specified mocker instance.
Returns True if mockers respond, otherwise raises or returns False. Returns True if mockers respond, otherwise raises or returns False.
...@@ -324,6 +325,7 @@ async def send_request_via_python_kv_router( ...@@ -324,6 +325,7 @@ async def send_request_via_python_kv_router(
output_options=output_options, output_options=output_options,
router_config_override=router_config_override, router_config_override=router_config_override,
worker_id=worker_id, worker_id=worker_id,
dp_rank=dp_rank,
) )
if stream is not None: if stream is not None:
...@@ -1314,33 +1316,38 @@ def test_query_instance_id_returns_worker_and_tokens( ...@@ -1314,33 +1316,38 @@ def test_query_instance_id_returns_worker_and_tokens(
@pytest.mark.pre_merge @pytest.mark.pre_merge
@pytest.mark.model(MODEL_NAME) @pytest.mark.model(MODEL_NAME)
def test_router_decisions(request, runtime_services, predownload_tokenizers): def test_router_decisions(request, runtime_services, predownload_tokenizers):
"""Validate KV cache prefix reuse by sending progressive requests with overlapping prefixes. """Validate KV cache prefix reuse and dp_rank routing by sending progressive requests with overlapping prefixes.
Flow: Flow:
- Start two mocker workers sharing a namespace. - Start two mocker workers, each with dp_size=4 (8 total dp ranks).
- Wait for workers to be ready. - Wait for workers to be ready.
- Send 4 progressive requests, each extending the previous tokens: - Send 4 progressive requests, each extending the previous tokens:
* Request 1: BLOCK_SIZE random tokens * Request 1: BLOCK_SIZE random tokens (forced to specific worker_id and dp_rank=1)
* Request 2: Request 1 tokens + BLOCK_SIZE new random tokens * Request 2: Request 1 tokens + BLOCK_SIZE new random tokens (naturally routed)
* Request 3: Request 2 tokens + BLOCK_SIZE new random tokens * Request 3: Request 2 tokens + BLOCK_SIZE new random tokens (naturally routed)
* Request 4: Request 3 tokens + BLOCK_SIZE new random tokens * Request 4: Request 3 tokens + BLOCK_SIZE new random tokens (naturally routed)
- Dump events from router and verify: - Dump events from router and verify:
* All but one worker should have no events (one worker handles all due to prefix reuse) * All but one (worker_id, dp_rank) should have no events (due to prefix reuse)
* The worker with events should have exactly 4 events (one per request) * The (worker_id, dp_rank) with events should have exactly 4 events (one per request)
* All events should be on the forced (worker_id, dp_rank=1) (verifying forced routing and prefix reuse)
""" """
# runtime_services starts etcd and nats # runtime_services starts etcd and nats
logger.info("Starting test router prefix reuse and KV events synchronization") logger.info("Starting test router prefix reuse and KV events synchronization")
# Create mocker args dictionary # Create mocker args dictionary with dp_size=4
mocker_args = {"speedup_ratio": SPEEDUP_RATIO, "block_size": BLOCK_SIZE} mocker_args = {
"speedup_ratio": SPEEDUP_RATIO,
"block_size": BLOCK_SIZE,
"dp_size": 4,
}
try: try:
# Start mocker instances with the new CLI interface # Start 2 mocker instances, each with dp_size=4 (8 total dp ranks)
logger.info(f"Starting {NUM_MOCKERS} mocker instances") logger.info(
mockers = MockerProcess( "Starting 2 mocker instances with dp_size=4 each (8 total dp ranks)"
request, mocker_args=mocker_args, num_mockers=NUM_MOCKERS
) )
mockers = MockerProcess(request, mocker_args=mocker_args, num_mockers=2)
logger.info(f"All mockers using endpoint: {mockers.endpoint}") logger.info(f"All mockers using endpoint: {mockers.endpoint}")
# Initialize mockers # Initialize mockers
mockers.__enter__() mockers.__enter__()
...@@ -1363,9 +1370,19 @@ def test_router_decisions(request, runtime_services, predownload_tokenizers): ...@@ -1363,9 +1370,19 @@ def test_router_decisions(request, runtime_services, predownload_tokenizers):
# Use async to manage the test flow # Use async to manage the test flow
async def test_sync(): async def test_sync():
# Wait for workers to be ready and get their instance IDs # Wait for workers to be ready and get their instance IDs
mocker_worker_ids = await wait_for_mockers_ready(endpoint, kv_push_router) mocker_worker_ids = await wait_for_mockers_ready(
endpoint, kv_push_router, expected_num_workers=2
)
logger.info(f"Workers ready: {mocker_worker_ids}") logger.info(f"Workers ready: {mocker_worker_ids}")
# Use the first worker_id for forced routing
forced_worker_id = mocker_worker_ids[0]
forced_dp_rank = 1
logger.info(
f"Will force first request to worker_id={forced_worker_id}, dp_rank={forced_dp_rank}"
)
# Send 4 progressive requests with overlapping prefixes # Send 4 progressive requests with overlapping prefixes
cumulative_tokens = [] cumulative_tokens = []
...@@ -1374,9 +1391,14 @@ def test_router_decisions(request, runtime_services, predownload_tokenizers): ...@@ -1374,9 +1391,14 @@ def test_router_decisions(request, runtime_services, predownload_tokenizers):
new_tokens = [random.randint(1, 10000) for _ in range(BLOCK_SIZE)] new_tokens = [random.randint(1, 10000) for _ in range(BLOCK_SIZE)]
cumulative_tokens.extend(new_tokens) cumulative_tokens.extend(new_tokens)
# Force first request to specific worker_id and dp_rank=1, let subsequent requests follow naturally
worker_id_override = forced_worker_id if i == 0 else None
dp_rank_override = forced_dp_rank if i == 0 else None
logger.info( logger.info(
f"Sending request {i + 1}/4 with {len(cumulative_tokens)} tokens " f"Sending request {i + 1}/4 with {len(cumulative_tokens)} tokens "
f"(added {len(new_tokens)} new tokens)" f"(added {len(new_tokens)} new tokens)"
f"{f' - FORCING worker_id={worker_id_override}, dp_rank={dp_rank_override}' if worker_id_override is not None else ''}"
) )
await send_request_via_python_kv_router( await send_request_via_python_kv_router(
...@@ -1388,6 +1410,8 @@ def test_router_decisions(request, runtime_services, predownload_tokenizers): ...@@ -1388,6 +1410,8 @@ def test_router_decisions(request, runtime_services, predownload_tokenizers):
"ignore_eos": True, # Don't stop on EOS token "ignore_eos": True, # Don't stop on EOS token
"max_tokens": 2, # Generate exactly 2 tokens "max_tokens": 2, # Generate exactly 2 tokens
}, },
worker_id=worker_id_override,
dp_rank=dp_rank_override,
) )
# Wait a bit between requests # Wait a bit between requests
...@@ -1398,46 +1422,64 @@ def test_router_decisions(request, runtime_services, predownload_tokenizers): ...@@ -1398,46 +1422,64 @@ def test_router_decisions(request, runtime_services, predownload_tokenizers):
# Dump events from the router # Dump events from the router
events_json = await kv_push_router.dump_events() events_json = await kv_push_router.dump_events()
return events_json return events_json, forced_worker_id, forced_dp_rank
# Run the async test # Run the async test
events_json = asyncio.run(test_sync()) events_json, expected_worker_id, expected_dp_rank = asyncio.run(test_sync())
# Parse events and count by worker # Parse events and count by (worker_id, dp_rank)
events = json.loads(events_json) events = json.loads(events_json)
events_by_worker: dict[int, list[Any]] = {} events_by_worker_dp: dict[tuple[int, int], list[Any]] = {}
for event in events: for event in events:
worker_id = event.get("worker_id") worker_id = event.get("worker_id")
if worker_id not in events_by_worker: # Extract dp_rank from the event's KvCacheEvent
events_by_worker[worker_id] = [] dp_rank = event.get("event", {}).get("dp_rank", 0)
events_by_worker[worker_id].append(event) key = (worker_id, dp_rank)
if key not in events_by_worker_dp:
events_by_worker_dp[key] = []
events_by_worker_dp[key].append(event)
logger.info( logger.info(
f"Events by worker: {[(wid, len(evts)) for wid, evts in events_by_worker.items()]}" f"Events by (worker_id, dp_rank): {[(key, len(evts)) for key, evts in events_by_worker_dp.items()]}"
) )
# Verify: All but one worker should have no events # Verify: All but one (worker_id, dp_rank) should have no events
workers_with_events = [ workers_with_events = [
wid for wid, evts in events_by_worker.items() if len(evts) > 0 key for key, evts in events_by_worker_dp.items() if len(evts) > 0
] ]
assert len(workers_with_events) == 1, ( assert len(workers_with_events) == 1, (
f"Expected exactly 1 worker to have events (due to prefix reuse), " f"Expected exactly 1 (worker_id, dp_rank) to have events (due to prefix reuse), "
f"but found {len(workers_with_events)} workers with events: {workers_with_events}" f"but found {len(workers_with_events)} with events: {workers_with_events}"
) )
# Verify: The worker with events should have exactly 4 events # Verify: The (worker_id, dp_rank) with events should have exactly 4 events
active_worker = workers_with_events[0] active_worker_dp = workers_with_events[0]
num_events = len(events_by_worker[active_worker]) num_events = len(events_by_worker_dp[active_worker_dp])
assert num_events == 4, ( assert num_events == 4, (
f"Expected worker {active_worker} to have exactly 4 events, " f"Expected (worker_id, dp_rank) {active_worker_dp} to have exactly 4 events, "
f"but found {num_events} events" f"but found {num_events} events"
) )
# Verify: Both worker_id and dp_rank should match the forced values
active_worker_id = active_worker_dp[0]
active_dp_rank = active_worker_dp[1]
assert active_worker_id == expected_worker_id, (
f"Expected all events to have worker_id={expected_worker_id} (forced in first request), "
f"but found worker_id={active_worker_id}"
)
assert active_dp_rank == expected_dp_rank, (
f"Expected all events to have dp_rank={expected_dp_rank} (forced in first request), "
f"but found dp_rank={active_dp_rank}"
)
logger.info( logger.info(
f"Successfully verified: Worker {active_worker} handled all 4 requests with prefix reuse. " f"Successfully verified: Worker {active_worker_id} dp_rank {active_dp_rank} handled all 4 requests with prefix reuse. "
f"All events correctly routed to worker_id={expected_worker_id}, dp_rank={expected_dp_rank} as expected. "
f"KV events synchronized correctly." f"KV events synchronized correctly."
) )
......
...@@ -69,7 +69,7 @@ sglang_configs = { ...@@ -69,7 +69,7 @@ sglang_configs = {
expected_log=[ expected_log=[
r"ZMQ listener .* received batch with \d+ events \(seq=\d+\)", r"ZMQ listener .* received batch with \d+ events \(seq=\d+\)",
r"Event processor for worker_id \d+ processing event: Stored\(", r"Event processor for worker_id \d+ processing event: Stored\(",
r"Selected worker: \d+, logit: ", r"Selected worker: worker_id=\d+ dp_rank=.*?, logit: ",
] ]
) )
], ],
......
...@@ -60,7 +60,7 @@ trtllm_configs = { ...@@ -60,7 +60,7 @@ trtllm_configs = {
chat_payload_default( chat_payload_default(
expected_log=[ expected_log=[
r"Event processor for worker_id \d+ processing event: Stored\(", r"Event processor for worker_id \d+ processing event: Stored\(",
r"Selected worker: \d+, logit: ", r"Selected worker: worker_id=\d+ dp_rank=.*?, logit: ",
] ]
) )
], ],
......
...@@ -53,7 +53,7 @@ vllm_configs = { ...@@ -53,7 +53,7 @@ vllm_configs = {
expected_log=[ expected_log=[
r"ZMQ listener .* received batch with \d+ events \(seq=\d+\)", r"ZMQ listener .* received batch with \d+ events \(seq=\d+\)",
r"Event processor for worker_id \d+ processing event: Stored\(", r"Event processor for worker_id \d+ processing event: Stored\(",
r"Selected worker: \d+, logit: ", r"Selected worker: worker_id=\d+ dp_rank=.*?, logit: ",
] ]
) )
], ],
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment