Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
dynamo
Commits
1e5b20b2
Unverified
Commit
1e5b20b2
authored
Dec 09, 2025
by
Yan Ru Pei
Committed by
GitHub
Dec 09, 2025
Browse files
chore: cleanups of passing around prefill and decode worker ids (#4829)
Signed-off-by:
PeaBrane
<
yanrpei@gmail.com
>
parent
14321c8f
Changes
6
Hide whitespace changes
Inline
Side-by-side
Showing
6 changed files
with
152 additions
and
87 deletions
+152
-87
lib/llm/src/kv_router.rs
lib/llm/src/kv_router.rs
+21
-11
lib/llm/src/kv_router/prefill_router.rs
lib/llm/src/kv_router/prefill_router.rs
+8
-25
lib/llm/src/protocols/openai/chat_completions/delta.rs
lib/llm/src/protocols/openai/chat_completions/delta.rs
+9
-19
lib/llm/src/protocols/openai/completions/delta.rs
lib/llm/src/protocols/openai/completions/delta.rs
+12
-19
tests/router/common.py
tests/router/common.py
+98
-9
tests/router/test_router_e2e_with_mockers.py
tests/router/test_router_e2e_with_mockers.py
+4
-4
No files found.
lib/llm/src/kv_router.rs
View file @
1e5b20b2
...
...
@@ -22,6 +22,8 @@ use futures::stream::{self, StreamExt};
use
serde
::{
Deserialize
,
Serialize
};
use
serde_json
::
json
;
use
crate
::
protocols
::
openai
::
nvext
::
WorkerIdInfo
;
pub
mod
approx
;
pub
mod
indexer
;
pub
mod
prefill_router
;
...
...
@@ -646,13 +648,19 @@ impl AsyncEngine<SingleIn<PreprocessedRequest>, ManyOut<Annotated<LLMEngineOutpu
backend_input
.estimated_prefix_hit_num_blocks
=
Some
(
overlap_amount
);
backend_input
.dp_rank
=
Some
(
dp_rank
);
// Get prefill worker ID
if available (stored by PrefillRouter)
// In aggregated mode, prefill_
worker_id
is None, so we use decode_worker_id for both
// Get prefill worker ID
from prefill_result if available
// In aggregated mode, prefill_
result
is None, so we use decode_worker_id for both
let
decode_worker_id
=
instance_id
;
let
prefill_worker_id
=
context
.get
::
<
u64
>
(
"prefill_worker_id"
)
.ok
()
.map
(|
arc
|
*
arc
)
let
prefill_worker_id
=
backend_input
.prefill_result
.as_ref
()
.and_then
(|
prefill_result
|
{
prefill_result
.disaggregated_params
.get
(
"worker_id"
)
.and_then
(|
v
|
serde_json
::
from_value
::
<
WorkerIdInfo
>
(
v
.clone
())
.ok
())
.and_then
(|
info
|
info
.prefill_worker_id
)
})
.or
(
Some
(
decode_worker_id
));
// Use decode_worker_id if no separate prefill worker
let
updated_request
=
context
.map
(|
_
|
backend_input
);
...
...
@@ -699,12 +707,14 @@ impl AsyncEngine<SingleIn<PreprocessedRequest>, ManyOut<Annotated<LLMEngineOutpu
continue
;
};
// prefill_worker_id comes from
context (set by PrefillRouter)
or falls back to instance_id
// prefill_worker_id comes from
prefill_result.disaggregated_params
or falls back to instance_id
// decode_worker_id is always the current instance_id
let
worker_id_json
=
json!
({
"prefill_worker_id"
:
prefill_worker_id
,
"decode_worker_id"
:
decode_worker_id
,
});
let
worker_id_info
=
WorkerIdInfo
{
prefill_worker_id
,
decode_worker_id
:
Some
(
decode_worker_id
),
};
let
worker_id_json
=
serde_json
::
to_value
(
&
worker_id_info
)
.expect
(
"WorkerIdInfo serialization should not fail"
);
if
let
Some
(
obj
)
=
data
.disaggregated_params
.as_mut
()
.and_then
(|
p
|
p
.as_object_mut
())
{
obj
.insert
(
"worker_id"
.to_string
(),
worker_id_json
);
...
...
lib/llm/src/kv_router/prefill_router.rs
View file @
1e5b20b2
...
...
@@ -176,11 +176,11 @@ impl PrefillRouter {
Ok
(())
}
/// Call the prefill router and extract structured prefill result
and worker ID
/// Call the prefill router and extract structured prefill result
async
fn
call_prefill
(
&
self
,
request
:
SingleIn
<
PreprocessedRequest
>
,
)
->
Result
<
(
PrefillResult
,
Option
<
u64
>
),
PrefillError
>
{
)
->
Result
<
PrefillResult
,
PrefillError
>
{
// Get the prefill router, error if not activated
let
Some
(
prefill_router
)
=
self
.prefill_router
.get
()
else
{
return
Err
(
PrefillError
::
NotActivated
);
...
...
@@ -239,21 +239,10 @@ impl PrefillRouter {
));
};
// Extract prefill worker ID from disaggregated_params
let
prefill_worker_id
=
disaggregated_params
.get
(
"worker_id"
)
.and_then
(|
worker_id_json
|
{
worker_id_json
.get
(
"prefill_worker_id"
)
.and_then
(|
v
|
v
.as_u64
())
});
Ok
((
PrefillResult
{
disaggregated_params
,
prompt_tokens_details
,
},
prefill_worker_id
,
))
Ok
(
PrefillResult
{
disaggregated_params
,
prompt_tokens_details
,
})
}
}
...
...
@@ -310,7 +299,7 @@ impl
// Handle prefill result
match
prefill_result
{
Ok
(
(
prefill_result
,
prefill_worker_id
)
)
=>
{
Ok
(
prefill_result
)
=>
{
tracing
::
debug!
(
"Prefill succeeded, using disaggregated params for decode"
);
let
mut
decode_req
=
req
;
...
...
@@ -326,14 +315,8 @@ impl
..
existing_override
.unwrap_or_default
()
});
// Store prefill worker ID in context if available
let
mut
decode_context
=
context
;
if
let
Some
(
worker_id
)
=
prefill_worker_id
{
decode_context
.insert
(
"prefill_worker_id"
,
worker_id
);
}
// Map the modified request through with preserved context
let
decode_request
=
decode_
context
.map
(|
_
|
decode_req
);
let
decode_request
=
context
.map
(|
_
|
decode_req
);
next
.generate
(
decode_request
)
.await
}
Err
(
PrefillError
::
NotActivated
)
=>
{
...
...
lib/llm/src/protocols/openai/chat_completions/delta.rs
View file @
1e5b20b2
...
...
@@ -4,7 +4,10 @@
use
super
::{
NvCreateChatCompletionRequest
,
NvCreateChatCompletionStreamResponse
};
use
crate
::{
local_model
::
runtime_config
::
ModelRuntimeConfig
,
protocols
::
common
::{
self
},
protocols
::{
common
,
openai
::
nvext
::{
NvExtResponse
,
WorkerIdInfo
},
},
types
::
TokenIdType
,
};
...
...
@@ -363,35 +366,22 @@ impl crate::protocols::openai::DeltaGeneratorExt<NvCreateChatCompletionStreamRes
let
mut
stream_response
=
self
.create_choice
(
index
,
delta
.text
,
finish_reason
,
logprobs
);
// Extract worker_id from disaggregated_params and inject into nvext if present
if
let
Some
(
worker_id_
json
)
=
delta
if
let
Some
(
worker_id_
info
)
=
delta
.disaggregated_params
.as_ref
()
.and_then
(|
params
|
params
.get
(
"worker_id"
))
.and_then
(|
v
|
serde_json
::
from_value
::
<
WorkerIdInfo
>
(
v
.clone
())
.ok
())
{
use
crate
::
protocols
::
openai
::
nvext
::{
NvExtResponse
,
WorkerIdInfo
};
let
prefill_worker_id
=
worker_id_json
.get
(
"prefill_worker_id"
)
.and_then
(|
v
|
v
.as_u64
());
let
decode_worker_id
=
worker_id_json
.get
(
"decode_worker_id"
)
.and_then
(|
v
|
v
.as_u64
());
let
worker_id_info
=
WorkerIdInfo
{
prefill_worker_id
,
decode_worker_id
,
};
let
nvext_response
=
NvExtResponse
{
worker_id
:
Some
(
worker_id_info
),
worker_id
:
Some
(
worker_id_info
.clone
()
),
};
if
let
Ok
(
nvext_json
)
=
serde_json
::
to_value
(
&
nvext_response
)
{
stream_response
.nvext
=
Some
(
nvext_json
);
tracing
::
debug!
(
"Injected worker_id into chat completion nvext: prefill={:?}, decode={:?}"
,
prefill_worker_id
,
decode_worker_id
worker_id_info
.
prefill_worker_id
,
worker_id_info
.
decode_worker_id
);
}
}
...
...
lib/llm/src/protocols/openai/completions/delta.rs
View file @
1e5b20b2
...
...
@@ -2,7 +2,13 @@
// SPDX-License-Identifier: Apache-2.0
use
super
::{
NvCreateCompletionRequest
,
NvCreateCompletionResponse
};
use
crate
::{
protocols
::
common
,
types
::
TokenIdType
};
use
crate
::{
protocols
::{
common
,
openai
::
nvext
::{
NvExtResponse
,
WorkerIdInfo
},
},
types
::
TokenIdType
,
};
impl
NvCreateCompletionRequest
{
/// Enables usage tracking for non-streaming requests to comply with OpenAI API specification.
...
...
@@ -266,35 +272,22 @@ impl crate::protocols::openai::DeltaGeneratorExt<NvCreateCompletionResponse> for
let
mut
response
=
self
.create_choice
(
index
,
delta
.text
.clone
(),
finish_reason
,
logprobs
);
// Extract worker_id from disaggregated_params and inject into nvext if present
if
let
Some
(
worker_id_
json
)
=
delta
if
let
Some
(
worker_id_
info
)
=
delta
.disaggregated_params
.as_ref
()
.and_then
(|
params
|
params
.get
(
"worker_id"
))
.and_then
(|
v
|
serde_json
::
from_value
::
<
WorkerIdInfo
>
(
v
.clone
())
.ok
())
{
use
crate
::
protocols
::
openai
::
nvext
::{
NvExtResponse
,
WorkerIdInfo
};
let
prefill_worker_id
=
worker_id_json
.get
(
"prefill_worker_id"
)
.and_then
(|
v
|
v
.as_u64
());
let
decode_worker_id
=
worker_id_json
.get
(
"decode_worker_id"
)
.and_then
(|
v
|
v
.as_u64
());
let
worker_id_info
=
WorkerIdInfo
{
prefill_worker_id
,
decode_worker_id
,
};
let
nvext_response
=
NvExtResponse
{
worker_id
:
Some
(
worker_id_info
),
worker_id
:
Some
(
worker_id_info
.clone
()
),
};
if
let
Ok
(
nvext_json
)
=
serde_json
::
to_value
(
&
nvext_response
)
{
response
.inner.nvext
=
Some
(
nvext_json
);
tracing
::
debug!
(
"Injected worker_id into completions nvext: prefill={:?}, decode={:?}"
,
prefill_worker_id
,
decode_worker_id
worker_id_info
.
prefill_worker_id
,
worker_id_info
.
decode_worker_id
);
}
}
...
...
tests/router/common.py
View file @
1e5b20b2
...
...
@@ -87,6 +87,47 @@ def generate_random_suffix() -> str:
return
""
.
join
(
random
.
choices
(
string
.
ascii_lowercase
,
k
=
10
))
# noqa: S311
def
verify_response_worker_ids
(
response_worker_ids
:
list
[
dict
[
str
,
Optional
[
int
]]],
key
:
str
,
expected_worker_id
:
int
,
)
->
None
:
"""Verify that all responses have the same worker ID for a given key.
Args:
response_worker_ids: List of dicts with worker ID info from responses.
key: The key to check (e.g., "decode_worker_id" or "prefill_worker_id").
expected_worker_id: The expected worker ID value.
Raises:
AssertionError: If any response is missing the key, values differ, or don't match expected.
"""
worker_ids
=
[
r
.
get
(
key
)
for
r
in
response_worker_ids
]
logger
.
info
(
f
"Response
{
key
}
s:
{
worker_ids
}
"
)
# All responses should have the key
assert
all
(
wid
is
not
None
for
wid
in
worker_ids
),
f
"Expected all
{
len
(
response_worker_ids
)
}
responses to have
{
key
}
, got:
{
worker_ids
}
"
# All values should be the same (due to prefix reuse routing)
unique_ids
=
set
(
worker_ids
)
assert
len
(
unique_ids
)
==
1
,
(
f
"Expected all responses to have the same
{
key
}
(due to prefix reuse), "
f
"but found
{
len
(
unique_ids
)
}
unique values:
{
unique_ids
}
"
)
# The value should match the expected worker ID
actual_worker_id
=
worker_ids
[
0
]
assert
actual_worker_id
==
expected_worker_id
,
(
f
"Expected
{
key
}
=
{
expected_worker_id
}
(forced in first request), "
f
"but got
{
key
}
=
{
actual_worker_id
}
"
)
logger
.
info
(
f
"✓ Verified all
{
len
(
response_worker_ids
)
}
responses have
{
key
}
=
{
actual_worker_id
}
"
)
########################################################
# Utility functions
########################################################
...
...
@@ -420,9 +461,17 @@ async def send_request_via_python_kv_router(
int
]
=
None
,
# If None, Router will select the best available worker
dp_rank
:
Optional
[
int
]
=
None
,
# Data parallel rank (defaults to 0)
)
->
bool
:
return_worker_ids
:
bool
=
False
,
# If True, return worker IDs from response
)
->
bool
|
dict
[
str
,
Optional
[
int
]]:
"""Send a request to the specified worker instance.
Returns True if workers respond, otherwise raises or returns False.
Args:
return_worker_ids: If True, returns a dict with prefill_worker_id and decode_worker_id.
If False, returns True on success or False on failure.
Returns:
If return_worker_ids=False: True if workers respond, otherwise raises or returns False.
If return_worker_ids=True: Dict with 'prefill_worker_id' and 'decode_worker_id' keys.
"""
wait_time
=
initial_wait
...
...
@@ -463,8 +512,11 @@ async def send_request_via_python_kv_router(
f
"Failed to connect to workers after
{
max_retries
+
1
}
attempts"
)
from
e
# Collect tokens from the SSE stream
# Collect tokens
and worker IDs
from the SSE stream
generated_tokens
=
[]
prefill_worker_id
:
Optional
[
int
]
=
None
decode_worker_id
:
Optional
[
int
]
=
None
async
for
response
in
stream
:
if
isinstance
(
response
,
dict
):
# Check if response has token_ids
...
...
@@ -480,6 +532,17 @@ async def send_request_via_python_kv_router(
f
"Stream finished with reason:
{
response
[
'finish_reason'
]
}
"
)
# Extract worker IDs from disaggregated_params if present
if
return_worker_ids
and
"disaggregated_params"
in
response
:
disagg_params
=
response
[
"disaggregated_params"
]
if
isinstance
(
disagg_params
,
dict
)
and
"worker_id"
in
disagg_params
:
worker_id_info
=
disagg_params
[
"worker_id"
]
if
isinstance
(
worker_id_info
,
dict
):
if
"prefill_worker_id"
in
worker_id_info
:
prefill_worker_id
=
worker_id_info
[
"prefill_worker_id"
]
if
"decode_worker_id"
in
worker_id_info
:
decode_worker_id
=
worker_id_info
[
"decode_worker_id"
]
# Verify if expected number of tokens are generated if max_tokens specified and ignore_eos is True
logger
.
debug
(
f
"Total generated tokens:
{
len
(
generated_tokens
)
}
"
)
if
(
...
...
@@ -497,9 +560,14 @@ async def send_request_via_python_kv_router(
logger
.
debug
(
f
"Successfully verified
{
max_tokens
}
tokens generated as expected via KvPushRouter with ignore_eos=True"
)
return
True
return
False
if
return_worker_ids
:
return
{
"prefill_worker_id"
:
prefill_worker_id
,
"decode_worker_id"
:
decode_worker_id
,
}
return
True
########################################################
...
...
@@ -1498,7 +1566,7 @@ def _test_router_indexers_sync(
logger
.
info
(
"Indexers sync test completed successfully"
)
def
_test_router_
disagg_
decisions
(
def
_test_router_decisions
_disagg
(
prefill_workers
,
decode_workers
,
block_size
:
int
,
...
...
@@ -1743,6 +1811,7 @@ def _test_router_decisions(
# Send 4 progressive requests with overlapping prefixes
cumulative_tokens
=
[]
response_worker_ids
:
list
[
dict
[
str
,
Optional
[
int
]]]
=
[]
for
i
in
range
(
4
):
# Add BLOCK_SIZE new random tokens
...
...
@@ -1764,7 +1833,7 @@ def _test_router_decisions(
log_msg
+=
f
" - FORCING worker_id=
{
worker_id_override
}
"
logger
.
info
(
log_msg
)
await
send_request_via_python_kv_router
(
result
=
await
send_request_via_python_kv_router
(
kv_python_router
=
kv_push_router
,
model_name
=
model_name
,
token_ids
=
cumulative_tokens
.
copy
(),
...
...
@@ -1776,6 +1845,13 @@ def _test_router_decisions(
},
worker_id
=
worker_id_override
,
dp_rank
=
dp_rank_override
,
return_worker_ids
=
True
,
)
assert
isinstance
(
result
,
dict
),
f
"Expected dict result, got
{
type
(
result
)
}
"
response_worker_ids
.
append
(
result
)
logger
.
info
(
f
"Request
{
i
+
1
}
response: prefill_worker_id=
{
result
.
get
(
'prefill_worker_id'
)
}
, "
f
"decode_worker_id=
{
result
.
get
(
'decode_worker_id'
)
}
"
)
# Wait a bit between requests
...
...
@@ -1787,10 +1863,23 @@ def _test_router_decisions(
# Dump events from the router
events_json
=
await
kv_push_router
.
dump_events
()
return
events_json
,
forced_worker_id
,
forced_dp_rank
return
events_json
,
forced_worker_id
,
forced_dp_rank
,
response_worker_ids
# Run the async test
events_json
,
expected_worker_id
,
expected_dp_rank
=
asyncio
.
run
(
test_sync
())
(
events_json
,
expected_worker_id
,
expected_dp_rank
,
response_worker_ids
,
)
=
asyncio
.
run
(
test_sync
())
# Verify worker IDs from responses
verify_response_worker_ids
(
response_worker_ids
,
"decode_worker_id"
,
expected_worker_id
)
verify_response_worker_ids
(
response_worker_ids
,
"prefill_worker_id"
,
expected_worker_id
)
# Parse events and count by worker routing key (worker_id or (worker_id, dp_rank))
events
=
json
.
loads
(
events_json
)
...
...
tests/router/test_router_e2e_with_mockers.py
View file @
1e5b20b2
...
...
@@ -11,7 +11,7 @@ from tests.router.common import ( # utilities
_test_python_router_bindings
,
_test_router_basic
,
_test_router_decisions
,
_test_router_
disagg_
decisions
,
_test_router_decisions
_disagg
,
_test_router_indexers_sync
,
_test_router_overload_503
,
_test_router_query_instance_id
,
...
...
@@ -66,7 +66,7 @@ def get_unique_ports(
"test_mocker_two_kv_router"
:
100
,
"test_mocker_kv_router_overload_503"
:
200
,
"test_query_instance_id_returns_worker_and_tokens"
:
300
,
"test_router_
disagg_
decisions"
:
400
,
"test_router_decisions
_disagg
"
:
400
,
"test_busy_threshold_endpoint"
:
500
,
}
...
...
@@ -583,7 +583,7 @@ def test_router_decisions(request, runtime_services_session, predownload_tokeniz
@
pytest
.
mark
.
parallel
def
test_router_
disagg_
decisions
(
def
test_router_decisions
_disagg
(
request
,
runtime_services_session
,
predownload_tokenizers
):
"""Validate KV cache prefix reuse in disaggregated prefill-decode setup.
...
...
@@ -632,7 +632,7 @@ def test_router_disagg_decisions(
frontend_port
=
get_unique_ports
(
request
,
num_ports
=
1
)[
0
]
# Run disagg routing test
_test_router_
disagg_
decisions
(
_test_router_decisions
_disagg
(
prefill_workers
=
prefill_workers
,
decode_workers
=
decode_workers
,
block_size
=
BLOCK_SIZE
,
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment