Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
dynamo
Commits
57c701fb
Unverified
Commit
57c701fb
authored
Nov 24, 2025
by
Yan Ru Pei
Committed by
GitHub
Nov 25, 2025
Browse files
fix: expose prefill worker id in disagg (#4563)
Signed-off-by:
PeaBrane
<
yanrpei@gmail.com
>
parent
550bf98c
Changes
3
Show whitespace changes
Inline
Side-by-side
Showing
3 changed files
with
458 additions
and
100 deletions
+458
-100
lib/llm/src/kv_router.rs
lib/llm/src/kv_router.rs
+23
-25
tests/router/common.py
tests/router/common.py
+194
-0
tests/router/test_router_e2e_with_mockers.py
tests/router/test_router_e2e_with_mockers.py
+241
-75
No files found.
lib/llm/src/kv_router.rs
View file @
57c701fb
...
...
@@ -623,14 +623,6 @@ impl AsyncEngine<SingleIn<PreprocessedRequest>, ManyOut<Annotated<LLMEngineOutpu
backend_input
.estimated_prefix_hit_num_blocks
=
Some
(
overlap_amount
);
backend_input
.dp_rank
=
Some
(
dp_rank
);
// Check if worker_id is requested in extra_fields
let
should_populate_worker_id
=
backend_input
.extra_fields
.as_deref
()
.unwrap_or
(
&
[])
.iter
()
.any
(|
s
|
s
==
"worker_id"
);
// Get prefill worker ID if available (stored by PrefillRouter)
// In aggregated mode, prefill_worker_id is None, so we use decode_worker_id for both
let
decode_worker_id
=
instance_id
;
...
...
@@ -672,25 +664,31 @@ impl AsyncEngine<SingleIn<PreprocessedRequest>, ManyOut<Annotated<LLMEngineOutpu
prefill_marked
=
true
;
}
// Inject worker_id in first item's disaggregated_params if requested
if
first_item
&&
should_populate_worker_id
{
if
let
Some
(
ref
mut
data
)
=
item
.data
{
// Add worker_id to disaggregated_params
// Always inject worker_id in first item's disaggregated_params
// This is needed for:
// 1. PrefillRouter to know which prefill worker was chosen
// 2. Client response when extra_fields contains "worker_id"
if
first_item
{
first_item
=
false
;
let
Some
(
ref
mut
data
)
=
item
.data
else
{
yield
item
;
continue
;
};
// prefill_worker_id comes from context (set by PrefillRouter) or falls back to instance_id
// decode_worker_id is always the current instance_id
let
worker_id_json
=
json!
({
"prefill_worker_id"
:
prefill_worker_id
,
"decode_worker_id"
:
decode_worker_id
,
});
if
let
Some
(
ref
mut
params
)
=
data
.disaggregated_params
{
if
let
Some
(
obj
)
=
params
.as_object_mut
()
{
if
let
Some
(
obj
)
=
data
.disaggregated_params
.as_mut
()
.and_then
(|
p
|
p
.as_object_mut
())
{
obj
.insert
(
"worker_id"
.to_string
(),
worker_id_json
);
}
}
else
{
data
.disaggregated_params
=
Some
(
json!
({
"worker_id"
:
worker_id_json
}));
}
}
first_item
=
false
;
}
yield
item
;
}
...
...
tests/router/common.py
View file @
57c701fb
...
...
@@ -36,6 +36,7 @@ class KVRouterProcess(ManagedProcess):
frontend_port
:
int
,
namespace
:
str
,
store_backend
:
str
=
"etcd"
,
enforce_disagg
:
bool
=
False
,
):
command
=
[
"python3"
,
...
...
@@ -53,6 +54,9 @@ class KVRouterProcess(ManagedProcess):
namespace
,
]
if
enforce_disagg
:
command
.
append
(
"--enforce-disagg"
)
super
().
__init__
(
command
=
command
,
timeout
=
60
,
...
...
@@ -1490,6 +1494,196 @@ def _test_router_indexers_sync(
logger
.
info
(
"Indexers sync test completed successfully"
)
def
_test_router_disagg_decisions
(
prefill_workers
,
decode_workers
,
block_size
:
int
,
request
,
frontend_port
:
int
,
test_payload
:
dict
,
store_backend
:
str
=
"etcd"
,
):
"""Validate KV cache prefix reuse in disaggregated prefill-decode setup via HTTP frontend.
Assumes prefill_workers and decode_workers are already initialized. This function manages
router lifecycle and sends progressive requests with overlapping prefixes.
This test:
1. Starts the KV router frontend with disagg support
2. Sends 4 progressive requests where each extends the previous tokens by block_size
3. Extracts prefill_worker_id and decode_worker_id from response nvext
4. Verifies all prefill_worker_ids are the same (due to prefix reuse routing)
5. Verifies prefill_worker_id is NOT in the set of decode_worker_ids (true disagg)
Args:
prefill_workers: Prefill workers already initialized with __enter__()
decode_workers: Decode workers already initialized with __enter__()
block_size: Block size for KV cache
request: Pytest request fixture for managing resources
frontend_port: Port for the frontend HTTP server
test_payload: Base test payload to send to /v1/chat/completions
store_backend: Storage backend to use ("etcd" or "file"). Defaults to "etcd".
Raises:
AssertionError: If prefill_worker_ids differ across requests (prefix reuse failure)
AssertionError: If prefill_worker_id is in decode_worker_ids (not true disagg)
"""
try
:
# Start KV router frontend - uses decode_workers namespace for discovery
# The frontend will auto-discover both prefill and decode workers
logger
.
info
(
f
"Starting KV router frontend on port
{
frontend_port
}
for disagg test"
)
kv_router
=
KVRouterProcess
(
request
,
block_size
,
frontend_port
,
decode_workers
.
namespace
,
store_backend
,
enforce_disagg
=
True
,
)
kv_router
.
__enter__
()
frontend_url
=
f
"http://localhost:
{
frontend_port
}
"
chat_url
=
f
"
{
frontend_url
}
/v1/chat/completions"
# Wait for workers to register with frontend
logger
.
info
(
"Waiting for prefill and decode workers to register with frontend..."
)
asyncio
.
run
(
wait_for_frontend_ready
(
frontend_url
=
frontend_url
,
expected_num_workers
=
decode_workers
.
num_workers
,
timeout
=
120
,
)
)
async
def
send_progressive_requests
():
"""Send 4 progressive requests with overlapping prefixes and collect worker IDs."""
prefill_worker_ids
=
[]
decode_worker_ids
=
[]
# Generate base tokens for progressive prefix extension
base_content
=
test_payload
[
"messages"
][
0
][
"content"
]
async
with
aiohttp
.
ClientSession
()
as
session
:
for
i
in
range
(
4
):
# Build progressive content by repeating base content
# Each iteration adds more content to extend the prefix
progressive_content
=
" "
.
join
([
base_content
]
*
(
i
+
1
))
# Create payload with worker_id in extra_fields to get prefill/decode worker IDs
payload
=
{
**
test_payload
,
"messages"
:
[
{
"role"
:
"user"
,
"content"
:
progressive_content
,
}
],
"nvext"
:
{
"extra_fields"
:
[
"worker_id"
]},
"stream"
:
True
,
}
logger
.
info
(
f
"Sending request
{
i
+
1
}
/4 with progressive prefix "
f
"(~
{
len
(
progressive_content
)
}
chars)"
)
async
with
session
.
post
(
chat_url
,
json
=
payload
)
as
response
:
assert
(
response
.
status
==
200
),
f
"Request
{
i
+
1
}
failed with status
{
response
.
status
}
"
# Collect all chunks and look for nvext with worker_id
prefill_wid
=
None
decode_wid
=
None
async
for
line
in
response
.
content
:
if
not
line
:
continue
line_str
=
line
.
decode
(
"utf-8"
,
errors
=
"replace"
).
strip
()
if
not
line_str
.
startswith
(
"data:"
):
continue
data_str
=
line_str
[
5
:].
strip
()
if
data_str
==
"[DONE]"
:
break
try
:
data
=
json
.
loads
(
data_str
)
# Check for nvext.worker_id in the response
nvext
=
data
.
get
(
"nvext"
,
{})
worker_id_info
=
nvext
.
get
(
"worker_id"
,
{})
if
worker_id_info
:
if
"prefill_worker_id"
in
worker_id_info
:
prefill_wid
=
worker_id_info
[
"prefill_worker_id"
]
if
"decode_worker_id"
in
worker_id_info
:
decode_wid
=
worker_id_info
[
"decode_worker_id"
]
except
json
.
JSONDecodeError
:
continue
logger
.
info
(
f
"Request
{
i
+
1
}
: prefill_worker_id=
{
prefill_wid
}
, "
f
"decode_worker_id=
{
decode_wid
}
"
)
if
prefill_wid
is
not
None
:
prefill_worker_ids
.
append
(
prefill_wid
)
if
decode_wid
is
not
None
:
decode_worker_ids
.
append
(
decode_wid
)
# Small delay between requests
await
asyncio
.
sleep
(
0.5
)
return
prefill_worker_ids
,
decode_worker_ids
# Run the progressive requests
prefill_ids
,
decode_ids
=
asyncio
.
run
(
send_progressive_requests
())
logger
.
info
(
f
"Collected prefill_worker_ids:
{
prefill_ids
}
"
)
logger
.
info
(
f
"Collected decode_worker_ids:
{
decode_ids
}
"
)
# Verify we got worker IDs from all requests
assert
len
(
prefill_ids
)
==
4
,
(
f
"Expected 4 prefill_worker_ids, got
{
len
(
prefill_ids
)
}
. "
f
"Make sure nvext.extra_fields=['worker_id'] is being processed."
)
# Verify all prefill_worker_ids are the same (prefix reuse)
unique_prefill_ids
=
set
(
prefill_ids
)
assert
len
(
unique_prefill_ids
)
==
1
,
(
f
"Expected all prefill requests to route to the same worker due to prefix reuse, "
f
"but found
{
len
(
unique_prefill_ids
)
}
unique prefill_worker_ids:
{
unique_prefill_ids
}
. "
f
"Full list:
{
prefill_ids
}
"
)
# Verify prefill_worker_id is NOT in decode_worker_ids (true disagg)
unique_decode_ids
=
set
(
decode_ids
)
prefill_id
=
prefill_ids
[
0
]
assert
prefill_id
not
in
unique_decode_ids
,
(
f
"Prefill worker
{
prefill_id
}
should NOT be in decode workers
{
unique_decode_ids
}
. "
f
"This suggests disaggregated mode is not working correctly - "
f
"prefill and decode should use separate worker pools."
)
logger
.
info
(
f
"Successfully verified disaggregated routing:
\n
"
f
" - All 4 requests routed to same prefill_worker_id=
{
prefill_id
}
(prefix reuse)
\n
"
f
" - Prefill worker is NOT in decode worker set
{
unique_decode_ids
}
(true disagg)"
)
finally
:
if
"kv_router"
in
locals
():
kv_router
.
__exit__
(
None
,
None
,
None
)
def
_test_router_decisions
(
engine_workers
,
endpoint
,
...
...
tests/router/test_router_e2e_with_mockers.py
View file @
57c701fb
...
...
@@ -10,6 +10,7 @@ from tests.router.common import ( # utilities
_test_python_router_bindings
,
_test_router_basic
,
_test_router_decisions
,
_test_router_disagg_decisions
,
_test_router_indexers_sync
,
_test_router_overload_503
,
_test_router_query_instance_id
,
...
...
@@ -61,6 +62,7 @@ def get_unique_ports(
"test_mocker_two_kv_router"
:
100
,
"test_mocker_kv_router_overload_503"
:
200
,
"test_query_instance_id_returns_worker_and_tokens"
:
300
,
"test_router_disagg_decisions"
:
400
,
}
base_offset
=
test_offsets
.
get
(
test_name
,
0
)
...
...
@@ -87,31 +89,25 @@ TEST_PAYLOAD: Dict[str, Any] = {
}
class
MockerProcess
:
"""Manages multiple mocker engine instances with the same namespace"""
def
_build_mocker_command
(
endpoint
:
str
,
store_backend
:
str
,
num_workers
:
int
,
mocker_args
:
Dict
[
str
,
Any
],
worker_type
:
Optional
[
str
]
=
None
,
)
->
list
[
str
]:
"""Build the mocker CLI command with all arguments.
def
__init__
(
self
,
request
,
mocker_args
:
Optional
[
Dict
[
str
,
Any
]]
=
None
,
num_mockers
:
int
=
1
,
store_backend
:
str
=
"etcd"
,
):
# Generate a unique namespace suffix shared by all mockers
namespace_suffix
=
generate_random_suffix
()
self
.
namespace
=
f
"test-namespace-
{
namespace_suffix
}
"
self
.
component_name
=
"mocker"
self
.
endpoint
=
f
"dyn://
{
self
.
namespace
}
.
{
self
.
component_name
}
.generate"
self
.
num_mockers
=
num_mockers
self
.
num_workers
=
self
.
num_mockers
# for compatibility with common.py
self
.
mocker_processes
=
[]
# Default mocker args if not provided
if
mocker_args
is
None
:
mocker_args
=
{}
Args:
endpoint: The dynamo endpoint string
store_backend: Storage backend ("etcd" or "file")
num_workers: Number of workers to spawn (uses --num-workers flag)
mocker_args: Dictionary of mocker arguments
worker_type: Optional worker type ("prefill" or "decode") for disagg mode
# Create multiple mocker processes with the same namespace
for
i
in
range
(
num_mockers
):
Returns:
List of command arguments for subprocess
"""
command
=
[
"python"
,
"-m"
,
...
...
@@ -119,11 +115,19 @@ class MockerProcess:
"--model-path"
,
MODEL_NAME
,
"--endpoint"
,
self
.
endpoint
,
endpoint
,
"--store-kv"
,
store_backend
,
"--num-workers"
,
str
(
num_workers
),
]
# Add worker type flag for disaggregated mode
if
worker_type
==
"prefill"
:
command
.
append
(
"--is-prefill-worker"
)
elif
worker_type
==
"decode"
:
command
.
append
(
"--is-decode-worker"
)
# Add individual CLI arguments from mocker_args
if
"speedup_ratio"
in
mocker_args
:
command
.
extend
([
"--speedup-ratio"
,
str
(
mocker_args
[
"speedup_ratio"
])])
...
...
@@ -137,10 +141,7 @@ class MockerProcess:
command
.
extend
([
"--max-num-seqs"
,
str
(
mocker_args
[
"max_num_seqs"
])])
if
"max_num_batched_tokens"
in
mocker_args
:
command
.
extend
(
[
"--max-num-batched-tokens"
,
str
(
mocker_args
[
"max_num_batched_tokens"
]),
]
[
"--max-num-batched-tokens"
,
str
(
mocker_args
[
"max_num_batched_tokens"
])]
)
if
"enable_prefix_caching"
in
mocker_args
:
if
mocker_args
[
"enable_prefix_caching"
]:
...
...
@@ -157,7 +158,35 @@ class MockerProcess:
if
"dp_size"
in
mocker_args
:
command
.
extend
([
"--data-parallel-size"
,
str
(
mocker_args
[
"dp_size"
])])
process
=
ManagedProcess
(
return
command
class
MockerProcess
:
"""Manages mocker engine instances with shared tokio runtime via --num-workers."""
def
__init__
(
self
,
request
,
mocker_args
:
Optional
[
Dict
[
str
,
Any
]]
=
None
,
num_mockers
:
int
=
1
,
store_backend
:
str
=
"etcd"
,
):
namespace_suffix
=
generate_random_suffix
()
self
.
namespace
=
f
"test-namespace-
{
namespace_suffix
}
"
self
.
component_name
=
"mocker"
self
.
endpoint
=
f
"dyn://
{
self
.
namespace
}
.
{
self
.
component_name
}
.generate"
self
.
num_workers
=
num_mockers
mocker_args
=
mocker_args
or
{}
command
=
_build_mocker_command
(
endpoint
=
self
.
endpoint
,
store_backend
=
store_backend
,
num_workers
=
num_mockers
,
mocker_args
=
mocker_args
,
)
self
.
_process
=
ManagedProcess
(
command
=
command
,
timeout
=
60
,
display_output
=
True
,
...
...
@@ -166,21 +195,90 @@ class MockerProcess:
log_dir
=
request
.
node
.
name
,
terminate_existing
=
False
,
)
self
.
mocker_processes
.
append
(
process
)
logger
.
info
(
f
"Created mocker instance
{
i
}
with endpoint:
{
self
.
endpoint
}
"
)
logger
.
info
(
f
"Created mocker process with
{
num_mockers
}
worker(s), endpoint:
{
self
.
endpoint
}
"
)
def
__enter__
(
self
):
"""Start all mocker processes"""
for
i
,
process
in
enumerate
(
self
.
mocker_processes
):
logger
.
info
(
f
"Starting mocker instance
{
i
}
"
)
process
.
__enter__
()
logger
.
info
(
f
"Starting mocker process with
{
self
.
num_workers
}
worker(s)"
)
self
.
_process
.
__enter__
()
return
self
def
__exit__
(
self
,
exc_type
,
exc_val
,
exc_tb
):
"""Stop all mocker processes"""
for
i
,
process
in
enumerate
(
self
.
mocker_processes
):
logger
.
info
(
f
"Stopping mocker instance
{
i
}
"
)
process
.
__exit__
(
exc_type
,
exc_val
,
exc_tb
)
logger
.
info
(
"Stopping mocker process"
)
self
.
_process
.
__exit__
(
exc_type
,
exc_val
,
exc_tb
)
class
DisaggMockerProcess
:
"""Manages prefill or decode mocker instances for disaggregated serving.
Uses --num-workers for shared tokio runtime. For disaggregated serving:
- Prefill workers: worker_type="prefill", endpoint is namespace.prefill.generate
- Decode workers: worker_type="decode", endpoint is namespace.backend.generate
Both prefill and decode workers should share the same namespace for proper discovery.
"""
def
__init__
(
self
,
request
,
namespace
:
str
,
worker_type
:
str
,
mocker_args
:
Optional
[
Dict
[
str
,
Any
]]
=
None
,
num_mockers
:
int
=
1
,
store_backend
:
str
=
"etcd"
,
):
if
worker_type
not
in
(
"prefill"
,
"decode"
):
raise
ValueError
(
f
"worker_type must be 'prefill' or 'decode', got
{
worker_type
}
"
)
self
.
namespace
=
namespace
self
.
worker_type
=
worker_type
self
.
num_workers
=
num_mockers
# Set component name and endpoint based on worker type
if
worker_type
==
"prefill"
:
self
.
component_name
=
"prefill"
self
.
endpoint
=
f
"dyn://
{
self
.
namespace
}
.prefill.generate"
else
:
self
.
component_name
=
"backend"
self
.
endpoint
=
f
"dyn://
{
self
.
namespace
}
.backend.generate"
mocker_args
=
mocker_args
or
{}
command
=
_build_mocker_command
(
endpoint
=
self
.
endpoint
,
store_backend
=
store_backend
,
num_workers
=
num_mockers
,
mocker_args
=
mocker_args
,
worker_type
=
worker_type
,
)
self
.
_process
=
ManagedProcess
(
command
=
command
,
timeout
=
60
,
display_output
=
True
,
health_check_ports
=
[],
health_check_urls
=
[],
log_dir
=
request
.
node
.
name
,
terminate_existing
=
False
,
)
logger
.
info
(
f
"Created
{
worker_type
}
mocker process with
{
num_mockers
}
worker(s), "
f
"endpoint:
{
self
.
endpoint
}
"
)
def
__enter__
(
self
):
logger
.
info
(
f
"Starting
{
self
.
worker_type
}
mocker process with
{
self
.
num_workers
}
worker(s)"
)
self
.
_process
.
__enter__
()
return
self
def
__exit__
(
self
,
exc_type
,
exc_val
,
exc_tb
):
logger
.
info
(
f
"Stopping
{
self
.
worker_type
}
mocker process"
)
self
.
_process
.
__exit__
(
exc_type
,
exc_val
,
exc_tb
)
@
pytest
.
mark
.
pre_merge
...
...
@@ -492,3 +590,71 @@ def test_router_decisions(request, runtime_services_session, predownload_tokeniz
finally
:
if
"mockers"
in
locals
():
mockers
.
__exit__
(
None
,
None
,
None
)
@
pytest
.
mark
.
pre_merge
@
pytest
.
mark
.
parallel
@
pytest
.
mark
.
model
(
MODEL_NAME
)
def
test_router_disagg_decisions
(
request
,
runtime_services_session
,
predownload_tokenizers
):
"""Validate KV cache prefix reuse in disaggregated prefill-decode setup.
Tests that progressive requests with overlapping prefixes are routed to the
same prefill worker due to KV cache reuse.
"""
logger
.
info
(
"Starting disaggregated router prefix reuse test"
)
# Generate shared namespace for prefill and decode workers
namespace_suffix
=
generate_random_suffix
()
shared_namespace
=
f
"test-namespace-
{
namespace_suffix
}
"
# Create mocker args
mocker_args
=
{
"speedup_ratio"
:
SPEEDUP_RATIO
,
"block_size"
:
BLOCK_SIZE
}
prefill_workers
=
None
decode_workers
=
None
try
:
# Start prefill workers (4 instances)
logger
.
info
(
"Starting 4 prefill mocker instances"
)
prefill_workers
=
DisaggMockerProcess
(
request
,
namespace
=
shared_namespace
,
worker_type
=
"prefill"
,
mocker_args
=
mocker_args
,
num_mockers
=
4
,
)
prefill_workers
.
__enter__
()
logger
.
info
(
f
"Prefill workers using endpoint:
{
prefill_workers
.
endpoint
}
"
)
# Start decode workers (4 instances)
logger
.
info
(
"Starting 4 decode mocker instances"
)
decode_workers
=
DisaggMockerProcess
(
request
,
namespace
=
shared_namespace
,
worker_type
=
"decode"
,
mocker_args
=
mocker_args
,
num_mockers
=
4
,
)
decode_workers
.
__enter__
()
logger
.
info
(
f
"Decode workers using endpoint:
{
decode_workers
.
endpoint
}
"
)
# Get unique port for this test
frontend_port
=
get_unique_ports
(
request
,
num_ports
=
1
)[
0
]
# Run disagg routing test
_test_router_disagg_decisions
(
prefill_workers
=
prefill_workers
,
decode_workers
=
decode_workers
,
block_size
=
BLOCK_SIZE
,
request
=
request
,
frontend_port
=
frontend_port
,
test_payload
=
TEST_PAYLOAD
,
)
finally
:
if
decode_workers
is
not
None
:
decode_workers
.
__exit__
(
None
,
None
,
None
)
if
prefill_workers
is
not
None
:
prefill_workers
.
__exit__
(
None
,
None
,
None
)
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment