Unverified commit 1e5b20b2, authored by Yan Ru Pei and committed by GitHub
Browse files

chore: cleanups of passing around prefill and decode worker ids (#4829)


Signed-off-by: PeaBrane <yanrpei@gmail.com>
parent 14321c8f
...@@ -22,6 +22,8 @@ use futures::stream::{self, StreamExt}; ...@@ -22,6 +22,8 @@ use futures::stream::{self, StreamExt};
use serde::{Deserialize, Serialize}; use serde::{Deserialize, Serialize};
use serde_json::json; use serde_json::json;
use crate::protocols::openai::nvext::WorkerIdInfo;
pub mod approx; pub mod approx;
pub mod indexer; pub mod indexer;
pub mod prefill_router; pub mod prefill_router;
...@@ -646,13 +648,19 @@ impl AsyncEngine<SingleIn<PreprocessedRequest>, ManyOut<Annotated<LLMEngineOutpu ...@@ -646,13 +648,19 @@ impl AsyncEngine<SingleIn<PreprocessedRequest>, ManyOut<Annotated<LLMEngineOutpu
backend_input.estimated_prefix_hit_num_blocks = Some(overlap_amount); backend_input.estimated_prefix_hit_num_blocks = Some(overlap_amount);
backend_input.dp_rank = Some(dp_rank); backend_input.dp_rank = Some(dp_rank);
// Get prefill worker ID if available (stored by PrefillRouter) // Get prefill worker ID from prefill_result if available
// In aggregated mode, prefill_worker_id is None, so we use decode_worker_id for both // In aggregated mode, prefill_result is None, so we use decode_worker_id for both
let decode_worker_id = instance_id; let decode_worker_id = instance_id;
let prefill_worker_id = context let prefill_worker_id = backend_input
.get::<u64>("prefill_worker_id") .prefill_result
.ok() .as_ref()
.map(|arc| *arc) .and_then(|prefill_result| {
prefill_result
.disaggregated_params
.get("worker_id")
.and_then(|v| serde_json::from_value::<WorkerIdInfo>(v.clone()).ok())
.and_then(|info| info.prefill_worker_id)
})
.or(Some(decode_worker_id)); // Use decode_worker_id if no separate prefill worker .or(Some(decode_worker_id)); // Use decode_worker_id if no separate prefill worker
let updated_request = context.map(|_| backend_input); let updated_request = context.map(|_| backend_input);
...@@ -699,12 +707,14 @@ impl AsyncEngine<SingleIn<PreprocessedRequest>, ManyOut<Annotated<LLMEngineOutpu ...@@ -699,12 +707,14 @@ impl AsyncEngine<SingleIn<PreprocessedRequest>, ManyOut<Annotated<LLMEngineOutpu
continue; continue;
}; };
// prefill_worker_id comes from context (set by PrefillRouter) or falls back to instance_id // prefill_worker_id comes from prefill_result.disaggregated_params or falls back to instance_id
// decode_worker_id is always the current instance_id // decode_worker_id is always the current instance_id
let worker_id_json = json!({ let worker_id_info = WorkerIdInfo {
"prefill_worker_id": prefill_worker_id, prefill_worker_id,
"decode_worker_id": decode_worker_id, decode_worker_id: Some(decode_worker_id),
}); };
let worker_id_json = serde_json::to_value(&worker_id_info)
.expect("WorkerIdInfo serialization should not fail");
if let Some(obj) = data.disaggregated_params.as_mut().and_then(|p| p.as_object_mut()) { if let Some(obj) = data.disaggregated_params.as_mut().and_then(|p| p.as_object_mut()) {
obj.insert("worker_id".to_string(), worker_id_json); obj.insert("worker_id".to_string(), worker_id_json);
......
...@@ -176,11 +176,11 @@ impl PrefillRouter { ...@@ -176,11 +176,11 @@ impl PrefillRouter {
Ok(()) Ok(())
} }
/// Call the prefill router and extract structured prefill result and worker ID /// Call the prefill router and extract structured prefill result
async fn call_prefill( async fn call_prefill(
&self, &self,
request: SingleIn<PreprocessedRequest>, request: SingleIn<PreprocessedRequest>,
) -> Result<(PrefillResult, Option<u64>), PrefillError> { ) -> Result<PrefillResult, PrefillError> {
// Get the prefill router, error if not activated // Get the prefill router, error if not activated
let Some(prefill_router) = self.prefill_router.get() else { let Some(prefill_router) = self.prefill_router.get() else {
return Err(PrefillError::NotActivated); return Err(PrefillError::NotActivated);
...@@ -239,21 +239,10 @@ impl PrefillRouter { ...@@ -239,21 +239,10 @@ impl PrefillRouter {
)); ));
}; };
// Extract prefill worker ID from disaggregated_params Ok(PrefillResult {
let prefill_worker_id = disaggregated_params
.get("worker_id")
.and_then(|worker_id_json| {
worker_id_json
.get("prefill_worker_id")
.and_then(|v| v.as_u64())
});
Ok((
PrefillResult {
disaggregated_params, disaggregated_params,
prompt_tokens_details, prompt_tokens_details,
}, })
prefill_worker_id,
))
} }
} }
...@@ -310,7 +299,7 @@ impl ...@@ -310,7 +299,7 @@ impl
// Handle prefill result // Handle prefill result
match prefill_result { match prefill_result {
Ok((prefill_result, prefill_worker_id)) => { Ok(prefill_result) => {
tracing::debug!("Prefill succeeded, using disaggregated params for decode"); tracing::debug!("Prefill succeeded, using disaggregated params for decode");
let mut decode_req = req; let mut decode_req = req;
...@@ -326,14 +315,8 @@ impl ...@@ -326,14 +315,8 @@ impl
..existing_override.unwrap_or_default() ..existing_override.unwrap_or_default()
}); });
// Store prefill worker ID in context if available
let mut decode_context = context;
if let Some(worker_id) = prefill_worker_id {
decode_context.insert("prefill_worker_id", worker_id);
}
// Map the modified request through with preserved context // Map the modified request through with preserved context
let decode_request = decode_context.map(|_| decode_req); let decode_request = context.map(|_| decode_req);
next.generate(decode_request).await next.generate(decode_request).await
} }
Err(PrefillError::NotActivated) => { Err(PrefillError::NotActivated) => {
......
...@@ -4,7 +4,10 @@ ...@@ -4,7 +4,10 @@
use super::{NvCreateChatCompletionRequest, NvCreateChatCompletionStreamResponse}; use super::{NvCreateChatCompletionRequest, NvCreateChatCompletionStreamResponse};
use crate::{ use crate::{
local_model::runtime_config::ModelRuntimeConfig, local_model::runtime_config::ModelRuntimeConfig,
protocols::common::{self}, protocols::{
common,
openai::nvext::{NvExtResponse, WorkerIdInfo},
},
types::TokenIdType, types::TokenIdType,
}; };
...@@ -363,35 +366,22 @@ impl crate::protocols::openai::DeltaGeneratorExt<NvCreateChatCompletionStreamRes ...@@ -363,35 +366,22 @@ impl crate::protocols::openai::DeltaGeneratorExt<NvCreateChatCompletionStreamRes
let mut stream_response = self.create_choice(index, delta.text, finish_reason, logprobs); let mut stream_response = self.create_choice(index, delta.text, finish_reason, logprobs);
// Extract worker_id from disaggregated_params and inject into nvext if present // Extract worker_id from disaggregated_params and inject into nvext if present
if let Some(worker_id_json) = delta if let Some(worker_id_info) = delta
.disaggregated_params .disaggregated_params
.as_ref() .as_ref()
.and_then(|params| params.get("worker_id")) .and_then(|params| params.get("worker_id"))
.and_then(|v| serde_json::from_value::<WorkerIdInfo>(v.clone()).ok())
{ {
use crate::protocols::openai::nvext::{NvExtResponse, WorkerIdInfo};
let prefill_worker_id = worker_id_json
.get("prefill_worker_id")
.and_then(|v| v.as_u64());
let decode_worker_id = worker_id_json
.get("decode_worker_id")
.and_then(|v| v.as_u64());
let worker_id_info = WorkerIdInfo {
prefill_worker_id,
decode_worker_id,
};
let nvext_response = NvExtResponse { let nvext_response = NvExtResponse {
worker_id: Some(worker_id_info), worker_id: Some(worker_id_info.clone()),
}; };
if let Ok(nvext_json) = serde_json::to_value(&nvext_response) { if let Ok(nvext_json) = serde_json::to_value(&nvext_response) {
stream_response.nvext = Some(nvext_json); stream_response.nvext = Some(nvext_json);
tracing::debug!( tracing::debug!(
"Injected worker_id into chat completion nvext: prefill={:?}, decode={:?}", "Injected worker_id into chat completion nvext: prefill={:?}, decode={:?}",
prefill_worker_id, worker_id_info.prefill_worker_id,
decode_worker_id worker_id_info.decode_worker_id
); );
} }
} }
......
...@@ -2,7 +2,13 @@ ...@@ -2,7 +2,13 @@
// SPDX-License-Identifier: Apache-2.0 // SPDX-License-Identifier: Apache-2.0
use super::{NvCreateCompletionRequest, NvCreateCompletionResponse}; use super::{NvCreateCompletionRequest, NvCreateCompletionResponse};
use crate::{protocols::common, types::TokenIdType}; use crate::{
protocols::{
common,
openai::nvext::{NvExtResponse, WorkerIdInfo},
},
types::TokenIdType,
};
impl NvCreateCompletionRequest { impl NvCreateCompletionRequest {
/// Enables usage tracking for non-streaming requests to comply with OpenAI API specification. /// Enables usage tracking for non-streaming requests to comply with OpenAI API specification.
...@@ -266,35 +272,22 @@ impl crate::protocols::openai::DeltaGeneratorExt<NvCreateCompletionResponse> for ...@@ -266,35 +272,22 @@ impl crate::protocols::openai::DeltaGeneratorExt<NvCreateCompletionResponse> for
let mut response = self.create_choice(index, delta.text.clone(), finish_reason, logprobs); let mut response = self.create_choice(index, delta.text.clone(), finish_reason, logprobs);
// Extract worker_id from disaggregated_params and inject into nvext if present // Extract worker_id from disaggregated_params and inject into nvext if present
if let Some(worker_id_json) = delta if let Some(worker_id_info) = delta
.disaggregated_params .disaggregated_params
.as_ref() .as_ref()
.and_then(|params| params.get("worker_id")) .and_then(|params| params.get("worker_id"))
.and_then(|v| serde_json::from_value::<WorkerIdInfo>(v.clone()).ok())
{ {
use crate::protocols::openai::nvext::{NvExtResponse, WorkerIdInfo};
let prefill_worker_id = worker_id_json
.get("prefill_worker_id")
.and_then(|v| v.as_u64());
let decode_worker_id = worker_id_json
.get("decode_worker_id")
.and_then(|v| v.as_u64());
let worker_id_info = WorkerIdInfo {
prefill_worker_id,
decode_worker_id,
};
let nvext_response = NvExtResponse { let nvext_response = NvExtResponse {
worker_id: Some(worker_id_info), worker_id: Some(worker_id_info.clone()),
}; };
if let Ok(nvext_json) = serde_json::to_value(&nvext_response) { if let Ok(nvext_json) = serde_json::to_value(&nvext_response) {
response.inner.nvext = Some(nvext_json); response.inner.nvext = Some(nvext_json);
tracing::debug!( tracing::debug!(
"Injected worker_id into completions nvext: prefill={:?}, decode={:?}", "Injected worker_id into completions nvext: prefill={:?}, decode={:?}",
prefill_worker_id, worker_id_info.prefill_worker_id,
decode_worker_id worker_id_info.decode_worker_id
); );
} }
} }
......
...@@ -87,6 +87,47 @@ def generate_random_suffix() -> str: ...@@ -87,6 +87,47 @@ def generate_random_suffix() -> str:
return "".join(random.choices(string.ascii_lowercase, k=10)) # noqa: S311 return "".join(random.choices(string.ascii_lowercase, k=10)) # noqa: S311
def verify_response_worker_ids(
    response_worker_ids: list[dict[str, Optional[int]]],
    key: str,
    expected_worker_id: int,
) -> None:
    """Verify that every response reports the same, expected worker ID for a key.

    Args:
        response_worker_ids: List of dicts with worker ID info from responses.
        key: The key to check (e.g., "decode_worker_id" or "prefill_worker_id").
        expected_worker_id: The expected worker ID value.

    Raises:
        AssertionError: If there are no responses at all, any response is
            missing the key, the values differ, or they don't match expected.
    """
    # An empty list would pass the "all responses have the key" check vacuously
    # and then fail below with a confusing "0 unique values" message, so reject
    # it up front with an explicit explanation.
    assert (
        response_worker_ids
    ), f"Expected at least one response to verify {key}, got none"

    worker_ids = [r.get(key) for r in response_worker_ids]
    logger.info(f"Response {key}s: {worker_ids}")

    # All responses should have the key
    assert all(
        wid is not None for wid in worker_ids
    ), f"Expected all {len(response_worker_ids)} responses to have {key}, got: {worker_ids}"

    # All values should be the same (due to prefix reuse routing)
    unique_ids = set(worker_ids)
    assert len(unique_ids) == 1, (
        f"Expected all responses to have the same {key} (due to prefix reuse), "
        f"but found {len(unique_ids)} unique values: {unique_ids}"
    )

    # The value should match the expected worker ID; every entry is identical
    # at this point, so the first one is representative.
    actual_worker_id = worker_ids[0]
    assert actual_worker_id == expected_worker_id, (
        f"Expected {key}={expected_worker_id} (forced in first request), "
        f"but got {key}={actual_worker_id}"
    )
    logger.info(
        f"✓ Verified all {len(response_worker_ids)} responses have {key}={actual_worker_id}"
    )
######################################################## ########################################################
# Utility functions # Utility functions
######################################################## ########################################################
...@@ -420,9 +461,17 @@ async def send_request_via_python_kv_router( ...@@ -420,9 +461,17 @@ async def send_request_via_python_kv_router(
int int
] = None, # If None, Router will select the best available worker ] = None, # If None, Router will select the best available worker
dp_rank: Optional[int] = None, # Data parallel rank (defaults to 0) dp_rank: Optional[int] = None, # Data parallel rank (defaults to 0)
) -> bool: return_worker_ids: bool = False, # If True, return worker IDs from response
) -> bool | dict[str, Optional[int]]:
"""Send a request to the specified worker instance. """Send a request to the specified worker instance.
Returns True if workers respond, otherwise raises or returns False.
Args:
return_worker_ids: If True, returns a dict with prefill_worker_id and decode_worker_id.
If False, returns True on success or False on failure.
Returns:
If return_worker_ids=False: True if workers respond, otherwise raises or returns False.
If return_worker_ids=True: Dict with 'prefill_worker_id' and 'decode_worker_id' keys.
""" """
wait_time = initial_wait wait_time = initial_wait
...@@ -463,8 +512,11 @@ async def send_request_via_python_kv_router( ...@@ -463,8 +512,11 @@ async def send_request_via_python_kv_router(
f"Failed to connect to workers after {max_retries + 1} attempts" f"Failed to connect to workers after {max_retries + 1} attempts"
) from e ) from e
# Collect tokens from the SSE stream # Collect tokens and worker IDs from the SSE stream
generated_tokens = [] generated_tokens = []
prefill_worker_id: Optional[int] = None
decode_worker_id: Optional[int] = None
async for response in stream: async for response in stream:
if isinstance(response, dict): if isinstance(response, dict):
# Check if response has token_ids # Check if response has token_ids
...@@ -480,6 +532,17 @@ async def send_request_via_python_kv_router( ...@@ -480,6 +532,17 @@ async def send_request_via_python_kv_router(
f"Stream finished with reason: {response['finish_reason']}" f"Stream finished with reason: {response['finish_reason']}"
) )
# Extract worker IDs from disaggregated_params if present
if return_worker_ids and "disaggregated_params" in response:
disagg_params = response["disaggregated_params"]
if isinstance(disagg_params, dict) and "worker_id" in disagg_params:
worker_id_info = disagg_params["worker_id"]
if isinstance(worker_id_info, dict):
if "prefill_worker_id" in worker_id_info:
prefill_worker_id = worker_id_info["prefill_worker_id"]
if "decode_worker_id" in worker_id_info:
decode_worker_id = worker_id_info["decode_worker_id"]
# Verify if expected number of tokens are generated if max_tokens specified and ignore_eos is True # Verify if expected number of tokens are generated if max_tokens specified and ignore_eos is True
logger.debug(f"Total generated tokens: {len(generated_tokens)}") logger.debug(f"Total generated tokens: {len(generated_tokens)}")
if ( if (
...@@ -497,9 +560,14 @@ async def send_request_via_python_kv_router( ...@@ -497,9 +560,14 @@ async def send_request_via_python_kv_router(
logger.debug( logger.debug(
f"Successfully verified {max_tokens} tokens generated as expected via KvPushRouter with ignore_eos=True" f"Successfully verified {max_tokens} tokens generated as expected via KvPushRouter with ignore_eos=True"
) )
return True
return False if return_worker_ids:
return {
"prefill_worker_id": prefill_worker_id,
"decode_worker_id": decode_worker_id,
}
return True
######################################################## ########################################################
...@@ -1498,7 +1566,7 @@ def _test_router_indexers_sync( ...@@ -1498,7 +1566,7 @@ def _test_router_indexers_sync(
logger.info("Indexers sync test completed successfully") logger.info("Indexers sync test completed successfully")
def _test_router_disagg_decisions( def _test_router_decisions_disagg(
prefill_workers, prefill_workers,
decode_workers, decode_workers,
block_size: int, block_size: int,
...@@ -1743,6 +1811,7 @@ def _test_router_decisions( ...@@ -1743,6 +1811,7 @@ def _test_router_decisions(
# Send 4 progressive requests with overlapping prefixes # Send 4 progressive requests with overlapping prefixes
cumulative_tokens = [] cumulative_tokens = []
response_worker_ids: list[dict[str, Optional[int]]] = []
for i in range(4): for i in range(4):
# Add BLOCK_SIZE new random tokens # Add BLOCK_SIZE new random tokens
...@@ -1764,7 +1833,7 @@ def _test_router_decisions( ...@@ -1764,7 +1833,7 @@ def _test_router_decisions(
log_msg += f" - FORCING worker_id={worker_id_override}" log_msg += f" - FORCING worker_id={worker_id_override}"
logger.info(log_msg) logger.info(log_msg)
await send_request_via_python_kv_router( result = await send_request_via_python_kv_router(
kv_python_router=kv_push_router, kv_python_router=kv_push_router,
model_name=model_name, model_name=model_name,
token_ids=cumulative_tokens.copy(), token_ids=cumulative_tokens.copy(),
...@@ -1776,6 +1845,13 @@ def _test_router_decisions( ...@@ -1776,6 +1845,13 @@ def _test_router_decisions(
}, },
worker_id=worker_id_override, worker_id=worker_id_override,
dp_rank=dp_rank_override, dp_rank=dp_rank_override,
return_worker_ids=True,
)
assert isinstance(result, dict), f"Expected dict result, got {type(result)}"
response_worker_ids.append(result)
logger.info(
f"Request {i + 1} response: prefill_worker_id={result.get('prefill_worker_id')}, "
f"decode_worker_id={result.get('decode_worker_id')}"
) )
# Wait a bit between requests # Wait a bit between requests
...@@ -1787,10 +1863,23 @@ def _test_router_decisions( ...@@ -1787,10 +1863,23 @@ def _test_router_decisions(
# Dump events from the router # Dump events from the router
events_json = await kv_push_router.dump_events() events_json = await kv_push_router.dump_events()
return events_json, forced_worker_id, forced_dp_rank return events_json, forced_worker_id, forced_dp_rank, response_worker_ids
# Run the async test # Run the async test
events_json, expected_worker_id, expected_dp_rank = asyncio.run(test_sync()) (
events_json,
expected_worker_id,
expected_dp_rank,
response_worker_ids,
) = asyncio.run(test_sync())
# Verify worker IDs from responses
verify_response_worker_ids(
response_worker_ids, "decode_worker_id", expected_worker_id
)
verify_response_worker_ids(
response_worker_ids, "prefill_worker_id", expected_worker_id
)
# Parse events and count by worker routing key (worker_id or (worker_id, dp_rank)) # Parse events and count by worker routing key (worker_id or (worker_id, dp_rank))
events = json.loads(events_json) events = json.loads(events_json)
......
...@@ -11,7 +11,7 @@ from tests.router.common import ( # utilities ...@@ -11,7 +11,7 @@ from tests.router.common import ( # utilities
_test_python_router_bindings, _test_python_router_bindings,
_test_router_basic, _test_router_basic,
_test_router_decisions, _test_router_decisions,
_test_router_disagg_decisions, _test_router_decisions_disagg,
_test_router_indexers_sync, _test_router_indexers_sync,
_test_router_overload_503, _test_router_overload_503,
_test_router_query_instance_id, _test_router_query_instance_id,
...@@ -66,7 +66,7 @@ def get_unique_ports( ...@@ -66,7 +66,7 @@ def get_unique_ports(
"test_mocker_two_kv_router": 100, "test_mocker_two_kv_router": 100,
"test_mocker_kv_router_overload_503": 200, "test_mocker_kv_router_overload_503": 200,
"test_query_instance_id_returns_worker_and_tokens": 300, "test_query_instance_id_returns_worker_and_tokens": 300,
"test_router_disagg_decisions": 400, "test_router_decisions_disagg": 400,
"test_busy_threshold_endpoint": 500, "test_busy_threshold_endpoint": 500,
} }
...@@ -583,7 +583,7 @@ def test_router_decisions(request, runtime_services_session, predownload_tokeniz ...@@ -583,7 +583,7 @@ def test_router_decisions(request, runtime_services_session, predownload_tokeniz
@pytest.mark.parallel @pytest.mark.parallel
def test_router_disagg_decisions( def test_router_decisions_disagg(
request, runtime_services_session, predownload_tokenizers request, runtime_services_session, predownload_tokenizers
): ):
"""Validate KV cache prefix reuse in disaggregated prefill-decode setup. """Validate KV cache prefix reuse in disaggregated prefill-decode setup.
...@@ -632,7 +632,7 @@ def test_router_disagg_decisions( ...@@ -632,7 +632,7 @@ def test_router_disagg_decisions(
frontend_port = get_unique_ports(request, num_ports=1)[0] frontend_port = get_unique_ports(request, num_ports=1)[0]
# Run disagg routing test # Run disagg routing test
_test_router_disagg_decisions( _test_router_decisions_disagg(
prefill_workers=prefill_workers, prefill_workers=prefill_workers,
decode_workers=decode_workers, decode_workers=decode_workers,
block_size=BLOCK_SIZE, block_size=BLOCK_SIZE,
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment