Unverified commit 1e5b20b2, authored by Yan Ru Pei, committed by GitHub
Browse files

chore: cleanups of passing around prefill and decode worker ids (#4829)


Signed-off-by: PeaBrane <yanrpei@gmail.com>
parent 14321c8f
......@@ -22,6 +22,8 @@ use futures::stream::{self, StreamExt};
use serde::{Deserialize, Serialize};
use serde_json::json;
use crate::protocols::openai::nvext::WorkerIdInfo;
pub mod approx;
pub mod indexer;
pub mod prefill_router;
......@@ -646,13 +648,19 @@ impl AsyncEngine<SingleIn<PreprocessedRequest>, ManyOut<Annotated<LLMEngineOutpu
backend_input.estimated_prefix_hit_num_blocks = Some(overlap_amount);
backend_input.dp_rank = Some(dp_rank);
// Get prefill worker ID if available (stored by PrefillRouter)
// In aggregated mode, prefill_worker_id is None, so we use decode_worker_id for both
// Get prefill worker ID from prefill_result if available
// In aggregated mode, prefill_result is None, so we use decode_worker_id for both
let decode_worker_id = instance_id;
let prefill_worker_id = context
.get::<u64>("prefill_worker_id")
.ok()
.map(|arc| *arc)
let prefill_worker_id = backend_input
.prefill_result
.as_ref()
.and_then(|prefill_result| {
prefill_result
.disaggregated_params
.get("worker_id")
.and_then(|v| serde_json::from_value::<WorkerIdInfo>(v.clone()).ok())
.and_then(|info| info.prefill_worker_id)
})
.or(Some(decode_worker_id)); // Use decode_worker_id if no separate prefill worker
let updated_request = context.map(|_| backend_input);
......@@ -699,12 +707,14 @@ impl AsyncEngine<SingleIn<PreprocessedRequest>, ManyOut<Annotated<LLMEngineOutpu
continue;
};
// prefill_worker_id comes from context (set by PrefillRouter) or falls back to instance_id
// prefill_worker_id comes from prefill_result.disaggregated_params or falls back to instance_id
// decode_worker_id is always the current instance_id
let worker_id_json = json!({
"prefill_worker_id": prefill_worker_id,
"decode_worker_id": decode_worker_id,
});
let worker_id_info = WorkerIdInfo {
prefill_worker_id,
decode_worker_id: Some(decode_worker_id),
};
let worker_id_json = serde_json::to_value(&worker_id_info)
.expect("WorkerIdInfo serialization should not fail");
if let Some(obj) = data.disaggregated_params.as_mut().and_then(|p| p.as_object_mut()) {
obj.insert("worker_id".to_string(), worker_id_json);
......
......@@ -176,11 +176,11 @@ impl PrefillRouter {
Ok(())
}
/// Call the prefill router and extract structured prefill result and worker ID
/// Call the prefill router and extract structured prefill result
async fn call_prefill(
&self,
request: SingleIn<PreprocessedRequest>,
) -> Result<(PrefillResult, Option<u64>), PrefillError> {
) -> Result<PrefillResult, PrefillError> {
// Get the prefill router, error if not activated
let Some(prefill_router) = self.prefill_router.get() else {
return Err(PrefillError::NotActivated);
......@@ -239,21 +239,10 @@ impl PrefillRouter {
));
};
// Extract prefill worker ID from disaggregated_params
let prefill_worker_id = disaggregated_params
.get("worker_id")
.and_then(|worker_id_json| {
worker_id_json
.get("prefill_worker_id")
.and_then(|v| v.as_u64())
});
Ok((
PrefillResult {
Ok(PrefillResult {
disaggregated_params,
prompt_tokens_details,
},
prefill_worker_id,
))
})
}
}
......@@ -310,7 +299,7 @@ impl
// Handle prefill result
match prefill_result {
Ok((prefill_result, prefill_worker_id)) => {
Ok(prefill_result) => {
tracing::debug!("Prefill succeeded, using disaggregated params for decode");
let mut decode_req = req;
......@@ -326,14 +315,8 @@ impl
..existing_override.unwrap_or_default()
});
// Store prefill worker ID in context if available
let mut decode_context = context;
if let Some(worker_id) = prefill_worker_id {
decode_context.insert("prefill_worker_id", worker_id);
}
// Map the modified request through with preserved context
let decode_request = decode_context.map(|_| decode_req);
let decode_request = context.map(|_| decode_req);
next.generate(decode_request).await
}
Err(PrefillError::NotActivated) => {
......
......@@ -4,7 +4,10 @@
use super::{NvCreateChatCompletionRequest, NvCreateChatCompletionStreamResponse};
use crate::{
local_model::runtime_config::ModelRuntimeConfig,
protocols::common::{self},
protocols::{
common,
openai::nvext::{NvExtResponse, WorkerIdInfo},
},
types::TokenIdType,
};
......@@ -363,35 +366,22 @@ impl crate::protocols::openai::DeltaGeneratorExt<NvCreateChatCompletionStreamRes
let mut stream_response = self.create_choice(index, delta.text, finish_reason, logprobs);
// Extract worker_id from disaggregated_params and inject into nvext if present
if let Some(worker_id_json) = delta
if let Some(worker_id_info) = delta
.disaggregated_params
.as_ref()
.and_then(|params| params.get("worker_id"))
.and_then(|v| serde_json::from_value::<WorkerIdInfo>(v.clone()).ok())
{
use crate::protocols::openai::nvext::{NvExtResponse, WorkerIdInfo};
let prefill_worker_id = worker_id_json
.get("prefill_worker_id")
.and_then(|v| v.as_u64());
let decode_worker_id = worker_id_json
.get("decode_worker_id")
.and_then(|v| v.as_u64());
let worker_id_info = WorkerIdInfo {
prefill_worker_id,
decode_worker_id,
};
let nvext_response = NvExtResponse {
worker_id: Some(worker_id_info),
worker_id: Some(worker_id_info.clone()),
};
if let Ok(nvext_json) = serde_json::to_value(&nvext_response) {
stream_response.nvext = Some(nvext_json);
tracing::debug!(
"Injected worker_id into chat completion nvext: prefill={:?}, decode={:?}",
prefill_worker_id,
decode_worker_id
worker_id_info.prefill_worker_id,
worker_id_info.decode_worker_id
);
}
}
......
......@@ -2,7 +2,13 @@
// SPDX-License-Identifier: Apache-2.0
use super::{NvCreateCompletionRequest, NvCreateCompletionResponse};
use crate::{protocols::common, types::TokenIdType};
use crate::{
protocols::{
common,
openai::nvext::{NvExtResponse, WorkerIdInfo},
},
types::TokenIdType,
};
impl NvCreateCompletionRequest {
/// Enables usage tracking for non-streaming requests to comply with OpenAI API specification.
......@@ -266,35 +272,22 @@ impl crate::protocols::openai::DeltaGeneratorExt<NvCreateCompletionResponse> for
let mut response = self.create_choice(index, delta.text.clone(), finish_reason, logprobs);
// Extract worker_id from disaggregated_params and inject into nvext if present
if let Some(worker_id_json) = delta
if let Some(worker_id_info) = delta
.disaggregated_params
.as_ref()
.and_then(|params| params.get("worker_id"))
.and_then(|v| serde_json::from_value::<WorkerIdInfo>(v.clone()).ok())
{
use crate::protocols::openai::nvext::{NvExtResponse, WorkerIdInfo};
let prefill_worker_id = worker_id_json
.get("prefill_worker_id")
.and_then(|v| v.as_u64());
let decode_worker_id = worker_id_json
.get("decode_worker_id")
.and_then(|v| v.as_u64());
let worker_id_info = WorkerIdInfo {
prefill_worker_id,
decode_worker_id,
};
let nvext_response = NvExtResponse {
worker_id: Some(worker_id_info),
worker_id: Some(worker_id_info.clone()),
};
if let Ok(nvext_json) = serde_json::to_value(&nvext_response) {
response.inner.nvext = Some(nvext_json);
tracing::debug!(
"Injected worker_id into completions nvext: prefill={:?}, decode={:?}",
prefill_worker_id,
decode_worker_id
worker_id_info.prefill_worker_id,
worker_id_info.decode_worker_id
);
}
}
......
......@@ -87,6 +87,47 @@ def generate_random_suffix() -> str:
return "".join(random.choices(string.ascii_lowercase, k=10)) # noqa: S311
def verify_response_worker_ids(
    response_worker_ids: list[dict[str, Optional[int]]],
    key: str,
    expected_worker_id: int,
) -> None:
    """Verify that every response reports the same, expected worker ID for a key.

    Args:
        response_worker_ids: List of dicts with worker ID info from responses.
        key: The key to check (e.g., "decode_worker_id" or "prefill_worker_id").
        expected_worker_id: The expected worker ID value.

    Raises:
        AssertionError: If there are no responses at all, any response is missing
            the key, the values differ, or they don't match the expected ID.
    """
    # An empty list would pass the all() check below vacuously and then trip the
    # uniqueness check with a confusing "0 unique values" message, so fail early
    # with a message that names the real problem.
    assert (
        response_worker_ids
    ), f"Expected at least one response to verify {key}, got an empty list"

    worker_ids = [r.get(key) for r in response_worker_ids]
    logger.info(f"Response {key}s: {worker_ids}")

    # All responses should have the key (a missing key shows up as None via .get)
    assert all(
        wid is not None for wid in worker_ids
    ), f"Expected all {len(response_worker_ids)} responses to have {key}, got: {worker_ids}"

    # All values should be the same (due to prefix reuse routing)
    unique_ids = set(worker_ids)
    assert len(unique_ids) == 1, (
        f"Expected all responses to have the same {key} (due to prefix reuse), "
        f"but found {len(unique_ids)} unique values: {unique_ids}"
    )

    # The single shared value should match the expected worker ID
    actual_worker_id = worker_ids[0]
    assert actual_worker_id == expected_worker_id, (
        f"Expected {key}={expected_worker_id} (forced in first request), "
        f"but got {key}={actual_worker_id}"
    )

    logger.info(
        f"✓ Verified all {len(response_worker_ids)} responses have {key}={actual_worker_id}"
    )
########################################################
# Utility functions
########################################################
......@@ -420,9 +461,17 @@ async def send_request_via_python_kv_router(
int
] = None, # If None, Router will select the best available worker
dp_rank: Optional[int] = None, # Data parallel rank (defaults to 0)
) -> bool:
return_worker_ids: bool = False, # If True, return worker IDs from response
) -> bool | dict[str, Optional[int]]:
"""Send a request to the specified worker instance.
Returns True if workers respond, otherwise raises or returns False.
Args:
return_worker_ids: If True, returns a dict with prefill_worker_id and decode_worker_id.
If False, returns True on success or False on failure.
Returns:
If return_worker_ids=False: True if workers respond, otherwise raises or returns False.
If return_worker_ids=True: Dict with 'prefill_worker_id' and 'decode_worker_id' keys.
"""
wait_time = initial_wait
......@@ -463,8 +512,11 @@ async def send_request_via_python_kv_router(
f"Failed to connect to workers after {max_retries + 1} attempts"
) from e
# Collect tokens from the SSE stream
# Collect tokens and worker IDs from the SSE stream
generated_tokens = []
prefill_worker_id: Optional[int] = None
decode_worker_id: Optional[int] = None
async for response in stream:
if isinstance(response, dict):
# Check if response has token_ids
......@@ -480,6 +532,17 @@ async def send_request_via_python_kv_router(
f"Stream finished with reason: {response['finish_reason']}"
)
# Extract worker IDs from disaggregated_params if present
if return_worker_ids and "disaggregated_params" in response:
disagg_params = response["disaggregated_params"]
if isinstance(disagg_params, dict) and "worker_id" in disagg_params:
worker_id_info = disagg_params["worker_id"]
if isinstance(worker_id_info, dict):
if "prefill_worker_id" in worker_id_info:
prefill_worker_id = worker_id_info["prefill_worker_id"]
if "decode_worker_id" in worker_id_info:
decode_worker_id = worker_id_info["decode_worker_id"]
# Verify if expected number of tokens are generated if max_tokens specified and ignore_eos is True
logger.debug(f"Total generated tokens: {len(generated_tokens)}")
if (
......@@ -497,9 +560,14 @@ async def send_request_via_python_kv_router(
logger.debug(
f"Successfully verified {max_tokens} tokens generated as expected via KvPushRouter with ignore_eos=True"
)
return True
return False
if return_worker_ids:
return {
"prefill_worker_id": prefill_worker_id,
"decode_worker_id": decode_worker_id,
}
return True
########################################################
......@@ -1498,7 +1566,7 @@ def _test_router_indexers_sync(
logger.info("Indexers sync test completed successfully")
def _test_router_disagg_decisions(
def _test_router_decisions_disagg(
prefill_workers,
decode_workers,
block_size: int,
......@@ -1743,6 +1811,7 @@ def _test_router_decisions(
# Send 4 progressive requests with overlapping prefixes
cumulative_tokens = []
response_worker_ids: list[dict[str, Optional[int]]] = []
for i in range(4):
# Add BLOCK_SIZE new random tokens
......@@ -1764,7 +1833,7 @@ def _test_router_decisions(
log_msg += f" - FORCING worker_id={worker_id_override}"
logger.info(log_msg)
await send_request_via_python_kv_router(
result = await send_request_via_python_kv_router(
kv_python_router=kv_push_router,
model_name=model_name,
token_ids=cumulative_tokens.copy(),
......@@ -1776,6 +1845,13 @@ def _test_router_decisions(
},
worker_id=worker_id_override,
dp_rank=dp_rank_override,
return_worker_ids=True,
)
assert isinstance(result, dict), f"Expected dict result, got {type(result)}"
response_worker_ids.append(result)
logger.info(
f"Request {i + 1} response: prefill_worker_id={result.get('prefill_worker_id')}, "
f"decode_worker_id={result.get('decode_worker_id')}"
)
# Wait a bit between requests
......@@ -1787,10 +1863,23 @@ def _test_router_decisions(
# Dump events from the router
events_json = await kv_push_router.dump_events()
return events_json, forced_worker_id, forced_dp_rank
return events_json, forced_worker_id, forced_dp_rank, response_worker_ids
# Run the async test
events_json, expected_worker_id, expected_dp_rank = asyncio.run(test_sync())
(
events_json,
expected_worker_id,
expected_dp_rank,
response_worker_ids,
) = asyncio.run(test_sync())
# Verify worker IDs from responses
verify_response_worker_ids(
response_worker_ids, "decode_worker_id", expected_worker_id
)
verify_response_worker_ids(
response_worker_ids, "prefill_worker_id", expected_worker_id
)
# Parse events and count by worker routing key (worker_id or (worker_id, dp_rank))
events = json.loads(events_json)
......
......@@ -11,7 +11,7 @@ from tests.router.common import ( # utilities
_test_python_router_bindings,
_test_router_basic,
_test_router_decisions,
_test_router_disagg_decisions,
_test_router_decisions_disagg,
_test_router_indexers_sync,
_test_router_overload_503,
_test_router_query_instance_id,
......@@ -66,7 +66,7 @@ def get_unique_ports(
"test_mocker_two_kv_router": 100,
"test_mocker_kv_router_overload_503": 200,
"test_query_instance_id_returns_worker_and_tokens": 300,
"test_router_disagg_decisions": 400,
"test_router_decisions_disagg": 400,
"test_busy_threshold_endpoint": 500,
}
......@@ -583,7 +583,7 @@ def test_router_decisions(request, runtime_services_session, predownload_tokeniz
@pytest.mark.parallel
def test_router_disagg_decisions(
def test_router_decisions_disagg(
request, runtime_services_session, predownload_tokenizers
):
"""Validate KV cache prefix reuse in disaggregated prefill-decode setup.
......@@ -632,7 +632,7 @@ def test_router_disagg_decisions(
frontend_port = get_unique_ports(request, num_ports=1)[0]
# Run disagg routing test
_test_router_disagg_decisions(
_test_router_decisions_disagg(
prefill_workers=prefill_workers,
decode_workers=decode_workers,
block_size=BLOCK_SIZE,
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment