Unverified Commit f978f4d1 authored by Yan Ru Pei's avatar Yan Ru Pei Committed by GitHub
Browse files

feat: dp rank routing (#3597)


Signed-off-by: default avatarPeaBrane <yanrpei@gmail.com>
parent 29f5b822
......@@ -7,7 +7,7 @@ use serde::{Deserialize, Serialize, de::DeserializeOwned};
use crate::protocols::tensor;
#[derive(Debug, Default, Clone, Serialize, Deserialize, Eq, PartialEq)]
#[derive(Debug, Clone, Serialize, Deserialize, Eq, PartialEq)]
pub struct ModelRuntimeConfig {
pub total_kv_blocks: Option<u64>,
......@@ -19,6 +19,10 @@ pub struct ModelRuntimeConfig {
pub reasoning_parser: Option<String>,
/// Total number of data parallel ranks for this worker (1 if DP not enabled)
#[serde(default = "default_data_parallel_size")]
pub data_parallel_size: u32,
/// Mapping of engine-specific runtime configs
#[serde(default, skip_serializing_if = "HashMap::is_empty")]
pub runtime_data: HashMap<String, serde_json::Value>,
......@@ -34,6 +38,25 @@ pub struct ModelRuntimeConfig {
pub tensor_model_config: Option<tensor::TensorModelConfig>,
}
const fn default_data_parallel_size() -> u32 {
1
}
impl Default for ModelRuntimeConfig {
fn default() -> Self {
Self {
total_kv_blocks: None,
max_num_seqs: None,
max_num_batched_tokens: None,
tool_call_parser: None,
reasoning_parser: None,
data_parallel_size: default_data_parallel_size(),
runtime_data: HashMap::new(),
tensor_model_config: None,
}
}
}
impl ModelRuntimeConfig {
pub fn new() -> Self {
Self::default()
......
......@@ -124,7 +124,7 @@ impl MockVllmEngine {
let scheduler = Scheduler::new(
args.clone(),
Some(dp_rank),
dp_rank,
Some(output_tx),
Some(kv_events_tx), // Pass the KV events sender to scheduler
Some(cancel_token.clone()),
......@@ -283,6 +283,7 @@ impl MockVllmEngine {
let event = KvCacheEvent {
event_id: Uuid::new_v4().as_u128() as u64,
data: event_data,
dp_rank,
};
// Publish the event
......@@ -316,18 +317,8 @@ impl AsyncEngine<SingleIn<PreprocessedRequest>, ManyOut<LLMEngineOutput>, Error>
) -> Result<ManyOut<LLMEngineOutput>, Error> {
let (request, ctx) = input.into_parts();
// Extract dp_rank from annotations if present
let dp_rank = request
.annotations
.iter()
.find_map(|ann| {
if ann.starts_with("dp_rank:") {
ann.strip_prefix("dp_rank:").and_then(|s| s.parse().ok())
} else {
None
}
})
.unwrap_or(0);
// Extract dp_rank from request field (defaults to 0 if not set)
let dp_rank = request.dp_rank.unwrap_or(0);
// Validate dp_rank
if dp_rank >= self.engine_args.dp_size {
......@@ -348,7 +339,7 @@ impl AsyncEngine<SingleIn<PreprocessedRequest>, ManyOut<LLMEngineOutput>, Error>
.expect("max_output_tokens must be specified for mocker")
as usize,
uuid: Some(request_uuid),
dp_rank: Some(dp_rank),
dp_rank,
};
let (request_tx, mut request_rx) = mpsc::unbounded_channel::<OutputSignal>();
......@@ -512,7 +503,7 @@ pub async fn make_mocker_engine(
args: MockEngineArgs,
) -> Result<crate::backend::ExecutionContext, Error> {
// Create the mocker engine
tracing::debug!("Creating mocker engine with config: {args:?}");
tracing::info!("Creating mocker engine with config: {args:?}");
let annotated_engine =
AnnotatedMockEngine::new(MockVllmEngine::new(args), distributed_runtime, endpoint_id);
......
......@@ -37,7 +37,7 @@ pub struct DirectRequest {
pub tokens: Vec<Token>,
pub max_output_tokens: usize,
pub uuid: Option<Uuid>,
pub dp_rank: Option<u32>,
pub dp_rank: u32,
}
/// Represents the cost of prefilling content in the cache
......
......@@ -248,7 +248,7 @@ impl Scheduler {
/// Create a new Scheduler with the given parameters
pub fn new(
args: MockEngineArgs,
dp_rank: Option<u32>,
dp_rank: u32,
output_tx: Option<mpsc::UnboundedSender<OutputSignal>>,
kv_events_tx: Option<mpsc::UnboundedSender<KvCacheEventData>>,
cancellation_token: Option<CancellationToken>,
......@@ -280,7 +280,7 @@ impl Scheduler {
// Create channel for request handling
let (request_tx, mut request_rx) = mpsc::unbounded_channel::<DirectRequest>();
let mut initial_metrics = ForwardPassMetrics::default();
initial_metrics.worker_stats.data_parallel_rank = dp_rank;
initial_metrics.worker_stats.data_parallel_rank = Some(dp_rank);
let (metrics_tx, metrics_rx) =
tokio::sync::watch::channel::<ForwardPassMetrics>(initial_metrics);
......@@ -573,7 +573,7 @@ fn get_fwd_pass_metrics(
state: &SchedulerState,
kv_manager: &KvManager,
hit_rates: &VecDeque<f32>,
dp_rank: Option<u32>,
dp_rank: u32,
) -> ForwardPassMetrics {
// Get state metrics
let request_active_slots = state.decode.len() as u64;
......@@ -597,7 +597,7 @@ fn get_fwd_pass_metrics(
};
let worker_stats = WorkerStats {
data_parallel_rank: dp_rank,
data_parallel_rank: Some(dp_rank),
request_active_slots,
request_total_slots: 1024, // vllm max_num_seqs for gpu >= 70 vram, otherwise 256, fallback is 128
num_requests_waiting,
......@@ -728,7 +728,7 @@ mod tests {
.unwrap();
// Create scheduler with new args struct
let scheduler = Scheduler::new(args, None, Some(output_tx), None, None);
let scheduler = Scheduler::new(args, 0, Some(output_tx), None, None);
// Create shared tokens for caching case
let shared_tokens = if use_shared_tokens {
......@@ -759,7 +759,7 @@ mod tests {
tokens: input_tokens,
max_output_tokens,
uuid: None,
dp_rank: None,
dp_rank: 0,
};
scheduler.receive(request).await;
}
......@@ -853,7 +853,7 @@ mod tests {
.unwrap();
// Create scheduler
let scheduler = Scheduler::new(args, None, Some(output_tx), None, None);
let scheduler = Scheduler::new(args, 0, Some(output_tx), None, None);
// Create identical tokens for all requests
let identical_tokens: Vec<u32> = (0..token_length).map(|i| i as u32).collect();
......@@ -864,7 +864,7 @@ mod tests {
tokens: identical_tokens.clone(),
max_output_tokens,
uuid: None,
dp_rank: None,
dp_rank: 0,
};
scheduler.receive(request).await;
// Sleep for 0.1 second after each request
......@@ -950,7 +950,7 @@ mod tests {
.unwrap();
// Create scheduler
let scheduler = Scheduler::new(args, None, Some(output_tx), None, None);
let scheduler = Scheduler::new(args, 0, Some(output_tx), None, None);
// Create request with 256 tokens
let tokens: Vec<u32> = (0..input_tokens).map(|i| i as u32).collect();
......@@ -958,7 +958,7 @@ mod tests {
tokens,
max_output_tokens,
uuid: None,
dp_rank: None,
dp_rank: 0,
};
scheduler.receive(request).await;
......
......@@ -61,6 +61,11 @@ pub struct PreprocessedRequest {
#[serde(default, skip_serializing_if = "Option::is_none")]
pub disaggregated_params: Option<serde_json::Value>,
/// Data parallel rank for the request (used with data parallelism)
#[builder(default)]
#[serde(default, skip_serializing_if = "Option::is_none")]
pub dp_rank: Option<u32>,
/// Additional arguments for extensibility
#[builder(default)]
#[serde(default, skip_serializing_if = "Option::is_none")]
......
......@@ -294,6 +294,7 @@ pub mod llm_kvbm {
let event = KvCacheEvent {
data,
event_id: event_id_counter,
dp_rank: 0,
};
let router_event = RouterEvent::new(worker_identifier as i64, event);
event_id_counter += 1;
......@@ -313,6 +314,7 @@ pub mod llm_kvbm {
block_hashes: vec![ExternalSequenceBlockHash(sequence_hash)],
}),
event_id: event_id_counter,
dp_rank: 0,
};
let router_event = RouterEvent::new(worker_identifier as i64, event);
event_id_counter += 1;
......@@ -573,6 +575,7 @@ mod tests {
}],
parent_hash: None,
}),
dp_rank: 0,
},
);
......@@ -587,6 +590,7 @@ mod tests {
}],
parent_hash: None,
}),
dp_rank: 0,
},
);
......@@ -630,6 +634,7 @@ mod tests {
}],
parent_hash: None,
}),
dp_rank: 0,
},
);
......@@ -678,6 +683,7 @@ mod tests {
data: KvCacheEventData::Removed(KvCacheRemoveData {
block_hashes: vec![ExternalSequenceBlockHash(4)],
}),
dp_rank: 0,
},
);
......
......@@ -26,34 +26,59 @@ struct LoadEvent {
#[derive(serde::Deserialize)]
struct ForwardPassMetrics {
worker_stats: WorkerStats,
kv_stats: KvStats,
}
#[derive(serde::Deserialize)]
struct WorkerStats {
data_parallel_rank: Option<u32>,
}
#[derive(serde::Deserialize)]
struct KvStats {
kv_active_blocks: u64,
}
#[derive(serde::Deserialize)]
#[derive(serde::Deserialize, Clone)]
struct RuntimeConfig {
total_kv_blocks: Option<u64>,
data_parallel_size: u32,
}
/// Worker load monitoring state
#[derive(Clone, Debug)]
/// Worker load monitoring state per dp_rank
#[derive(Clone, Debug, Default)]
pub struct WorkerLoadState {
pub kv_active_blocks: Option<u64>,
pub kv_total_blocks: Option<u64>,
pub kv_active_blocks: HashMap<u32, u64>,
pub kv_total_blocks: HashMap<u32, u64>,
}
impl WorkerLoadState {
/// Returns true if ALL dp_ranks (that have data in both maps) exceed the threshold
pub fn is_busy(&self, threshold: f64) -> bool {
match (self.kv_active_blocks, self.kv_total_blocks) {
(Some(active), Some(total)) if total > 0 => {
(active as f64) > (threshold * total as f64)
}
_ => false,
// Get all dp_ranks that exist in both active and total blocks
let common_dp_ranks: Vec<_> = self
.kv_active_blocks
.keys()
.filter(|dp_rank| self.kv_total_blocks.contains_key(dp_rank))
.collect();
// If no common dp_ranks, not busy
if common_dp_ranks.is_empty() {
return false;
}
// Check if ALL common dp_ranks exceed threshold
common_dp_ranks.iter().all(|&&dp_rank| {
if let (Some(&active), Some(&total)) = (
self.kv_active_blocks.get(&dp_rank),
self.kv_total_blocks.get(&dp_rank),
) {
total > 0 && (active as f64) > (threshold * total as f64)
} else {
false
}
})
}
}
......@@ -97,9 +122,10 @@ impl WorkerMonitor {
"v1/mdc/", // should be model_card::ROOT_PREFIX but wrong crate
key_extractors::lease_id,
|card: serde_json::Value| {
card.get("runtime_config")
.and_then(|rc| rc.get("total_kv_blocks"))
.and_then(|t_kv| t_kv.as_u64())
let runtime_config: Option<RuntimeConfig> = card
.get("runtime_config")
.and_then(|rc| serde_json::from_value(rc.clone()).ok());
runtime_config
},
component.drt().child_token(),
)
......@@ -132,13 +158,17 @@ impl WorkerMonitor {
let mut states = worker_load_states.write().unwrap();
states.retain(|lease_id, _| runtime_configs.contains_key(lease_id));
// Update worker load states with total blocks
for (lease_id, total_blocks) in runtime_configs.iter() {
let state = states.entry(*lease_id).or_insert(WorkerLoadState {
kv_active_blocks: None,
kv_total_blocks: None,
});
state.kv_total_blocks = Some(*total_blocks);
// Update worker load states with total blocks for all dp_ranks
for (lease_id, runtime_config) in runtime_configs.iter() {
let state = states.entry(*lease_id).or_default();
// Populate total_blocks for all dp_ranks (they share the same total)
// data_parallel_size defaults to 1 via serde in ModelRuntimeConfig
if let Some(total_blocks) = runtime_config.total_kv_blocks {
for dp_rank in 0..runtime_config.data_parallel_size {
state.kv_total_blocks.insert(dp_rank, total_blocks);
}
}
}
}
......@@ -152,14 +182,12 @@ impl WorkerMonitor {
if let Ok(load_event) = serde_json::from_slice::<LoadEvent>(&event.payload) {
let worker_id = load_event.worker_id;
let active_blocks = load_event.data.kv_stats.kv_active_blocks;
let dp_rank = load_event.data.worker_stats.data_parallel_rank.unwrap_or(0);
// Update worker load state
// Update worker load state per dp_rank
let mut states = worker_load_states.write().unwrap();
let state = states.entry(worker_id).or_insert(WorkerLoadState {
kv_active_blocks: None,
kv_total_blocks: None,
});
state.kv_active_blocks = Some(active_blocks);
let state = states.entry(worker_id).or_default();
state.kv_active_blocks.insert(dp_rank, active_blocks);
drop(states);
// Recalculate all busy instances and update
......
......@@ -298,6 +298,7 @@ async def send_request_via_python_kv_router(
worker_id: Optional[
int
] = None, # If None, Router will select the best available worker
dp_rank: Optional[int] = None, # Data parallel rank (defaults to 0)
):
"""Send a request to the specified mocker instance.
Returns True if mockers respond, otherwise raises or returns False.
......@@ -324,6 +325,7 @@ async def send_request_via_python_kv_router(
output_options=output_options,
router_config_override=router_config_override,
worker_id=worker_id,
dp_rank=dp_rank,
)
if stream is not None:
......@@ -1314,33 +1316,38 @@ def test_query_instance_id_returns_worker_and_tokens(
@pytest.mark.pre_merge
@pytest.mark.model(MODEL_NAME)
def test_router_decisions(request, runtime_services, predownload_tokenizers):
"""Validate KV cache prefix reuse by sending progressive requests with overlapping prefixes.
"""Validate KV cache prefix reuse and dp_rank routing by sending progressive requests with overlapping prefixes.
Flow:
- Start two mocker workers sharing a namespace.
- Start two mocker workers, each with dp_size=4 (8 total dp ranks).
- Wait for workers to be ready.
- Send 4 progressive requests, each extending the previous tokens:
* Request 1: BLOCK_SIZE random tokens
* Request 2: Request 1 tokens + BLOCK_SIZE new random tokens
* Request 3: Request 2 tokens + BLOCK_SIZE new random tokens
* Request 4: Request 3 tokens + BLOCK_SIZE new random tokens
* Request 1: BLOCK_SIZE random tokens (forced to specific worker_id and dp_rank=1)
* Request 2: Request 1 tokens + BLOCK_SIZE new random tokens (naturally routed)
* Request 3: Request 2 tokens + BLOCK_SIZE new random tokens (naturally routed)
* Request 4: Request 3 tokens + BLOCK_SIZE new random tokens (naturally routed)
- Dump events from router and verify:
* All but one worker should have no events (one worker handles all due to prefix reuse)
* The worker with events should have exactly 4 events (one per request)
* All but one (worker_id, dp_rank) should have no events (due to prefix reuse)
* The (worker_id, dp_rank) with events should have exactly 4 events (one per request)
* All events should be on the forced (worker_id, dp_rank=1) (verifying forced routing and prefix reuse)
"""
# runtime_services starts etcd and nats
logger.info("Starting test router prefix reuse and KV events synchronization")
# Create mocker args dictionary
mocker_args = {"speedup_ratio": SPEEDUP_RATIO, "block_size": BLOCK_SIZE}
# Create mocker args dictionary with dp_size=4
mocker_args = {
"speedup_ratio": SPEEDUP_RATIO,
"block_size": BLOCK_SIZE,
"dp_size": 4,
}
try:
# Start mocker instances with the new CLI interface
logger.info(f"Starting {NUM_MOCKERS} mocker instances")
mockers = MockerProcess(
request, mocker_args=mocker_args, num_mockers=NUM_MOCKERS
# Start 2 mocker instances, each with dp_size=4 (8 total dp ranks)
logger.info(
"Starting 2 mocker instances with dp_size=4 each (8 total dp ranks)"
)
mockers = MockerProcess(request, mocker_args=mocker_args, num_mockers=2)
logger.info(f"All mockers using endpoint: {mockers.endpoint}")
# Initialize mockers
mockers.__enter__()
......@@ -1363,9 +1370,19 @@ def test_router_decisions(request, runtime_services, predownload_tokenizers):
# Use async to manage the test flow
async def test_sync():
# Wait for workers to be ready and get their instance IDs
mocker_worker_ids = await wait_for_mockers_ready(endpoint, kv_push_router)
mocker_worker_ids = await wait_for_mockers_ready(
endpoint, kv_push_router, expected_num_workers=2
)
logger.info(f"Workers ready: {mocker_worker_ids}")
# Use the first worker_id for forced routing
forced_worker_id = mocker_worker_ids[0]
forced_dp_rank = 1
logger.info(
f"Will force first request to worker_id={forced_worker_id}, dp_rank={forced_dp_rank}"
)
# Send 4 progressive requests with overlapping prefixes
cumulative_tokens = []
......@@ -1374,9 +1391,14 @@ def test_router_decisions(request, runtime_services, predownload_tokenizers):
new_tokens = [random.randint(1, 10000) for _ in range(BLOCK_SIZE)]
cumulative_tokens.extend(new_tokens)
# Force first request to specific worker_id and dp_rank=1, let subsequent requests follow naturally
worker_id_override = forced_worker_id if i == 0 else None
dp_rank_override = forced_dp_rank if i == 0 else None
logger.info(
f"Sending request {i + 1}/4 with {len(cumulative_tokens)} tokens "
f"(added {len(new_tokens)} new tokens)"
f"{f' - FORCING worker_id={worker_id_override}, dp_rank={dp_rank_override}' if worker_id_override is not None else ''}"
)
await send_request_via_python_kv_router(
......@@ -1388,6 +1410,8 @@ def test_router_decisions(request, runtime_services, predownload_tokenizers):
"ignore_eos": True, # Don't stop on EOS token
"max_tokens": 2, # Generate exactly 2 tokens
},
worker_id=worker_id_override,
dp_rank=dp_rank_override,
)
# Wait a bit between requests
......@@ -1398,46 +1422,64 @@ def test_router_decisions(request, runtime_services, predownload_tokenizers):
# Dump events from the router
events_json = await kv_push_router.dump_events()
return events_json
return events_json, forced_worker_id, forced_dp_rank
# Run the async test
events_json = asyncio.run(test_sync())
events_json, expected_worker_id, expected_dp_rank = asyncio.run(test_sync())
# Parse events and count by worker
# Parse events and count by (worker_id, dp_rank)
events = json.loads(events_json)
events_by_worker: dict[int, list[Any]] = {}
events_by_worker_dp: dict[tuple[int, int], list[Any]] = {}
for event in events:
worker_id = event.get("worker_id")
if worker_id not in events_by_worker:
events_by_worker[worker_id] = []
events_by_worker[worker_id].append(event)
# Extract dp_rank from the event's KvCacheEvent
dp_rank = event.get("event", {}).get("dp_rank", 0)
key = (worker_id, dp_rank)
if key not in events_by_worker_dp:
events_by_worker_dp[key] = []
events_by_worker_dp[key].append(event)
logger.info(
f"Events by worker: {[(wid, len(evts)) for wid, evts in events_by_worker.items()]}"
f"Events by (worker_id, dp_rank): {[(key, len(evts)) for key, evts in events_by_worker_dp.items()]}"
)
# Verify: All but one worker should have no events
# Verify: All but one (worker_id, dp_rank) should have no events
workers_with_events = [
wid for wid, evts in events_by_worker.items() if len(evts) > 0
key for key, evts in events_by_worker_dp.items() if len(evts) > 0
]
assert len(workers_with_events) == 1, (
f"Expected exactly 1 worker to have events (due to prefix reuse), "
f"but found {len(workers_with_events)} workers with events: {workers_with_events}"
f"Expected exactly 1 (worker_id, dp_rank) to have events (due to prefix reuse), "
f"but found {len(workers_with_events)} with events: {workers_with_events}"
)
# Verify: The worker with events should have exactly 4 events
active_worker = workers_with_events[0]
num_events = len(events_by_worker[active_worker])
# Verify: The (worker_id, dp_rank) with events should have exactly 4 events
active_worker_dp = workers_with_events[0]
num_events = len(events_by_worker_dp[active_worker_dp])
assert num_events == 4, (
f"Expected worker {active_worker} to have exactly 4 events, "
f"Expected (worker_id, dp_rank) {active_worker_dp} to have exactly 4 events, "
f"but found {num_events} events"
)
# Verify: Both worker_id and dp_rank should match the forced values
active_worker_id = active_worker_dp[0]
active_dp_rank = active_worker_dp[1]
assert active_worker_id == expected_worker_id, (
f"Expected all events to have worker_id={expected_worker_id} (forced in first request), "
f"but found worker_id={active_worker_id}"
)
assert active_dp_rank == expected_dp_rank, (
f"Expected all events to have dp_rank={expected_dp_rank} (forced in first request), "
f"but found dp_rank={active_dp_rank}"
)
logger.info(
f"Successfully verified: Worker {active_worker} handled all 4 requests with prefix reuse. "
f"Successfully verified: Worker {active_worker_id} dp_rank {active_dp_rank} handled all 4 requests with prefix reuse. "
f"All events correctly routed to worker_id={expected_worker_id}, dp_rank={expected_dp_rank} as expected. "
f"KV events synchronized correctly."
)
......
......@@ -69,7 +69,7 @@ sglang_configs = {
expected_log=[
r"ZMQ listener .* received batch with \d+ events \(seq=\d+\)",
r"Event processor for worker_id \d+ processing event: Stored\(",
r"Selected worker: \d+, logit: ",
r"Selected worker: worker_id=\d+ dp_rank=.*?, logit: ",
]
)
],
......
......@@ -60,7 +60,7 @@ trtllm_configs = {
chat_payload_default(
expected_log=[
r"Event processor for worker_id \d+ processing event: Stored\(",
r"Selected worker: \d+, logit: ",
r"Selected worker: worker_id=\d+ dp_rank=.*?, logit: ",
]
)
],
......
......@@ -53,7 +53,7 @@ vllm_configs = {
expected_log=[
r"ZMQ listener .* received batch with \d+ events \(seq=\d+\)",
r"Event processor for worker_id \d+ processing event: Stored\(",
r"Selected worker: \d+, logit: ",
r"Selected worker: worker_id=\d+ dp_rank=.*?, logit: ",
]
)
],
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment