Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
dynamo
Commits
6fe2152b
Unverified
Commit
6fe2152b
authored
Feb 10, 2026
by
Karen Chung
Committed by
GitHub
Feb 11, 2026
Browse files
test: refactor router e2e tests to use context managers for process lifecycle (#6088)
parent
1cd3b724
Changes
5
Show whitespace changes
Inline
Side-by-side
Showing
5 changed files
with
243 additions
and
417 deletions
+243
-417
tests/router/common.py
tests/router/common.py
+47
-94
tests/router/test_router_e2e_with_mockers.py
tests/router/test_router_e2e_with_mockers.py
+104
-149
tests/router/test_router_e2e_with_sglang.py
tests/router/test_router_e2e_with_sglang.py
+31
-58
tests/router/test_router_e2e_with_trtllm.py
tests/router/test_router_e2e_with_trtllm.py
+30
-57
tests/router/test_router_e2e_with_vllm.py
tests/router/test_router_e2e_with_vllm.py
+31
-59
No files found.
tests/router/common.py
View file @
6fe2152b
...
@@ -528,6 +528,7 @@ async def send_request_via_python_kv_router(
...
@@ -528,6 +528,7 @@ async def send_request_via_python_kv_router(
)
)
# Retry loop sending request to worker with exponential backoff
# Retry loop sending request to worker with exponential backoff
stream
=
None
for
attempt
in
range
(
max_retries
+
1
):
for
attempt
in
range
(
max_retries
+
1
):
try
:
try
:
logger
.
debug
(
f
"Sending request to
{
log_message
}
(attempt
{
attempt
+
1
}
)"
)
logger
.
debug
(
f
"Sending request to
{
log_message
}
(attempt
{
attempt
+
1
}
)"
)
...
@@ -557,6 +558,11 @@ async def send_request_via_python_kv_router(
...
@@ -557,6 +558,11 @@ async def send_request_via_python_kv_router(
f
"Failed to connect to workers after
{
max_retries
+
1
}
attempts"
f
"Failed to connect to workers after
{
max_retries
+
1
}
attempts"
)
from
e
)
from
e
if
stream
is
None
:
raise
RuntimeError
(
f
"Failed to get a valid stream from workers after
{
max_retries
+
1
}
attempts"
)
# Collect tokens and worker IDs from the SSE stream
# Collect tokens and worker IDs from the SSE stream
generated_tokens
=
[]
generated_tokens
=
[]
prefill_worker_id
:
Optional
[
int
]
=
None
prefill_worker_id
:
Optional
[
int
]
=
None
...
@@ -653,18 +659,16 @@ def _test_router_basic(
...
@@ -653,18 +659,16 @@ def _test_router_basic(
AssertionError: If requests fail or frontend doesn't become ready
AssertionError: If requests fail or frontend doesn't become ready
TimeoutError: If frontend doesn't become ready within timeout
TimeoutError: If frontend doesn't become ready within timeout
"""
"""
try
:
with
KVRouterProcess
(
# Start KV router frontend
logger
.
info
(
f
"Starting KV router frontend on port
{
frontend_port
}
"
)
kv_router
=
KVRouterProcess
(
request
,
request
,
block_size
,
block_size
,
frontend_port
,
frontend_port
,
engine_workers
.
namespace
,
engine_workers
.
namespace
,
store_backend
,
store_backend
,
request_plane
=
request_plane
,
request_plane
=
request_plane
,
)
):
kv_router
.
__enter__
()
# Start KV router frontend
logger
.
info
(
f
"Starting KV router frontend on port
{
frontend_port
}
"
)
frontend_url
=
f
"http://localhost:
{
frontend_port
}
"
frontend_url
=
f
"http://localhost:
{
frontend_port
}
"
...
@@ -690,10 +694,6 @@ def _test_router_basic(
...
@@ -690,10 +694,6 @@ def _test_router_basic(
logger
.
info
(
f
"Successfully completed
{
num_requests
}
requests"
)
logger
.
info
(
f
"Successfully completed
{
num_requests
}
requests"
)
finally
:
if
"kv_router"
in
locals
():
kv_router
.
__exit__
(
None
,
None
,
None
)
def
_test_router_two_routers
(
def
_test_router_two_routers
(
engine_workers
,
engine_workers
,
...
@@ -1036,13 +1036,11 @@ def _test_router_query_instance_id(
...
@@ -1036,13 +1036,11 @@ def _test_router_query_instance_id(
AssertionError: If annotation response structure is incorrect or contains generation content
AssertionError: If annotation response structure is incorrect or contains generation content
"""
"""
try
:
with
KVRouterProcess
(
request
,
block_size
,
frontend_port
,
engine_workers
.
namespace
,
store_backend
):
# Start KV router (frontend)
# Start KV router (frontend)
logger
.
info
(
f
"Starting KV router frontend on port
{
frontend_port
}
"
)
logger
.
info
(
f
"Starting KV router frontend on port
{
frontend_port
}
"
)
kv_router
=
KVRouterProcess
(
request
,
block_size
,
frontend_port
,
engine_workers
.
namespace
,
store_backend
)
kv_router
.
__enter__
()
url
=
f
"http://localhost:
{
frontend_port
}
/v1/chat/completions"
url
=
f
"http://localhost:
{
frontend_port
}
/v1/chat/completions"
...
@@ -1164,10 +1162,6 @@ def _test_router_query_instance_id(
...
@@ -1164,10 +1162,6 @@ def _test_router_query_instance_id(
logger
.
info
(
f
"Decode Worker ID:
{
result
[
'decode_worker_id'
]
}
"
)
logger
.
info
(
f
"Decode Worker ID:
{
result
[
'decode_worker_id'
]
}
"
)
logger
.
info
(
f
"Token count:
{
result
[
'token_count'
]
}
"
)
logger
.
info
(
f
"Token count:
{
result
[
'token_count'
]
}
"
)
finally
:
if
"kv_router"
in
locals
():
kv_router
.
__exit__
(
None
,
None
,
None
)
def
_test_router_overload_503
(
def
_test_router_overload_503
(
engine_workers
,
engine_workers
,
...
@@ -1194,42 +1188,17 @@ def _test_router_overload_503(
...
@@ -1194,42 +1188,17 @@ def _test_router_overload_503(
AssertionError: If 503 response is not received when expected
AssertionError: If 503 response is not received when expected
"""
"""
try
:
logger
.
info
(
logger
.
info
(
f
"Starting KV router frontend on port
{
frontend_port
}
with limited resources"
f
"Starting KV router frontend on port
{
frontend_port
}
with limited resources"
)
)
# Custom command for router with limited block size
with
KVRouterProcess
(
command
=
[
request
=
request
,
"python"
,
block_size
=
block_size
,
"-m"
,
frontend_port
=
frontend_port
,
"dynamo.frontend"
,
namespace
=
engine_workers
.
namespace
,
"--active-decode-blocks-threshold"
,
blocks_threshold
=
blocks_threshold
,
str
(
blocks_threshold
),
):
"--kv-cache-block-size"
,
str
(
block_size
),
"--router-mode"
,
"kv"
,
"--http-port"
,
str
(
frontend_port
),
]
kv_router
=
ManagedProcess
(
command
=
command
,
timeout
=
60
,
display_output
=
True
,
health_check_ports
=
[
frontend_port
],
health_check_urls
=
[
(
f
"http://localhost:
{
frontend_port
}
/v1/models"
,
lambda
r
:
r
.
status_code
==
200
,
)
],
log_dir
=
request
.
node
.
name
,
terminate_all_matching_process_names
=
False
,
)
kv_router
.
__enter__
()
url
=
f
"http://localhost:
{
frontend_port
}
/v1/chat/completions"
url
=
f
"http://localhost:
{
frontend_port
}
/v1/chat/completions"
# Custom payload for 503 test with more tokens to consume resources
# Custom payload for 503 test with more tokens to consume resources
...
@@ -1325,10 +1294,6 @@ def _test_router_overload_503(
...
@@ -1325,10 +1294,6 @@ def _test_router_overload_503(
logger
.
info
(
"Successfully verified 503 response when all workers are busy"
)
logger
.
info
(
"Successfully verified 503 response when all workers are busy"
)
finally
:
if
"kv_router"
in
locals
():
kv_router
.
__exit__
(
None
,
None
,
None
)
def
_test_router_indexers_sync
(
def
_test_router_indexers_sync
(
engine_workers
,
engine_workers
,
...
@@ -1727,13 +1692,7 @@ def _test_router_decisions_disagg(
...
@@ -1727,13 +1692,7 @@ def _test_router_decisions_disagg(
AssertionError: If prefill_worker_ids differ across requests (prefix reuse failure)
AssertionError: If prefill_worker_ids differ across requests (prefix reuse failure)
AssertionError: If prefill_worker_id is in decode_worker_ids (not true disagg)
AssertionError: If prefill_worker_id is in decode_worker_ids (not true disagg)
"""
"""
try
:
with
KVRouterProcess
(
# Start KV router frontend - uses decode_workers namespace for discovery
# The frontend will auto-discover both prefill and decode workers
logger
.
info
(
f
"Starting KV router frontend on port
{
frontend_port
}
for disagg test"
)
kv_router
=
KVRouterProcess
(
request
,
request
,
block_size
,
block_size
,
frontend_port
,
frontend_port
,
...
@@ -1742,8 +1701,12 @@ def _test_router_decisions_disagg(
...
@@ -1742,8 +1701,12 @@ def _test_router_decisions_disagg(
enforce_disagg
=
True
,
enforce_disagg
=
True
,
request_plane
=
request_plane
,
request_plane
=
request_plane
,
durable_kv_events
=
durable_kv_events
,
durable_kv_events
=
durable_kv_events
,
):
# Start KV router frontend - uses decode_workers namespace for discovery
# The frontend will auto-discover both prefill and decode workers
logger
.
info
(
f
"Starting KV router frontend on port
{
frontend_port
}
for disagg test"
)
)
kv_router
.
__enter__
()
frontend_url
=
f
"http://localhost:
{
frontend_port
}
"
frontend_url
=
f
"http://localhost:
{
frontend_port
}
"
chat_url
=
f
"
{
frontend_url
}
/v1/chat/completions"
chat_url
=
f
"
{
frontend_url
}
/v1/chat/completions"
...
@@ -1908,10 +1871,6 @@ def _test_router_decisions_disagg(
...
@@ -1908,10 +1871,6 @@ def _test_router_decisions_disagg(
f
" - Prefill worker is NOT in decode worker set
{
unique_decode_ids
}
(true disagg)"
f
" - Prefill worker is NOT in decode worker set
{
unique_decode_ids
}
(true disagg)"
)
)
finally
:
if
"kv_router"
in
locals
():
kv_router
.
__exit__
(
None
,
None
,
None
)
def
_test_router_decisions
(
def
_test_router_decisions
(
engine_workers
,
engine_workers
,
...
@@ -2190,10 +2149,7 @@ def _test_busy_threshold_endpoint(
...
@@ -2190,10 +2149,7 @@ def _test_busy_threshold_endpoint(
initial_active_decode_blocks_threshold
=
0.9
initial_active_decode_blocks_threshold
=
0.9
initial_active_prefill_tokens_threshold
=
1000
# Literal token count threshold
initial_active_prefill_tokens_threshold
=
1000
# Literal token count threshold
try
:
with
KVRouterProcess
(
# Start KV router frontend with initial thresholds to create monitor
logger
.
info
(
f
"Starting KV router frontend on port
{
frontend_port
}
"
)
kv_router
=
KVRouterProcess
(
request
,
request
,
block_size
,
block_size
,
frontend_port
,
frontend_port
,
...
@@ -2202,8 +2158,9 @@ def _test_busy_threshold_endpoint(
...
@@ -2202,8 +2158,9 @@ def _test_busy_threshold_endpoint(
blocks_threshold
=
initial_active_decode_blocks_threshold
,
blocks_threshold
=
initial_active_decode_blocks_threshold
,
tokens_threshold
=
initial_active_prefill_tokens_threshold
,
tokens_threshold
=
initial_active_prefill_tokens_threshold
,
request_plane
=
request_plane
,
request_plane
=
request_plane
,
)
):
kv_router
.
__enter__
()
# Start KV router frontend with initial thresholds to create monitor
logger
.
info
(
f
"Starting KV router frontend on port
{
frontend_port
}
"
)
frontend_url
=
f
"http://localhost:
{
frontend_port
}
"
frontend_url
=
f
"http://localhost:
{
frontend_port
}
"
busy_threshold_url
=
f
"
{
frontend_url
}
/busy_threshold"
busy_threshold_url
=
f
"
{
frontend_url
}
/busy_threshold"
...
@@ -2464,7 +2421,3 @@ def _test_busy_threshold_endpoint(
...
@@ -2464,7 +2421,3 @@ def _test_busy_threshold_endpoint(
logger
.
info
(
"All busy_threshold endpoint tests passed!"
)
logger
.
info
(
"All busy_threshold endpoint tests passed!"
)
asyncio
.
run
(
test_busy_threshold_api
())
asyncio
.
run
(
test_busy_threshold_api
())
finally
:
if
"kv_router"
in
locals
():
kv_router
.
__exit__
(
None
,
None
,
None
)
tests/router/test_router_e2e_with_mockers.py
View file @
6fe2152b
...
@@ -351,17 +351,15 @@ def test_mocker_kv_router(
...
@@ -351,17 +351,15 @@ def test_mocker_kv_router(
"durable_kv_events"
:
durable_kv_events
,
"durable_kv_events"
:
durable_kv_events
,
}
}
try
:
with
MockerProcess
(
# Start mocker instances with the new CLI interface
logger
.
info
(
f
"Starting
{
NUM_MOCKERS
}
mocker instances"
)
mockers
=
MockerProcess
(
request
,
request
,
mocker_args
=
mocker_args
,
mocker_args
=
mocker_args
,
num_mockers
=
NUM_MOCKERS
,
num_mockers
=
NUM_MOCKERS
,
request_plane
=
request_plane
,
request_plane
=
request_plane
,
)
)
as
mockers
:
# Start mocker instances with the new CLI interface
logger
.
info
(
f
"Starting
{
NUM_MOCKERS
}
mocker instances"
)
logger
.
info
(
f
"All mockers using endpoint:
{
mockers
.
endpoint
}
"
)
logger
.
info
(
f
"All mockers using endpoint:
{
mockers
.
endpoint
}
"
)
mockers
.
__enter__
()
# Get unique port for this test
# Get unique port for this test
frontend_port
=
get_unique_ports
(
frontend_port
=
get_unique_ports
(
...
@@ -379,10 +377,6 @@ def test_mocker_kv_router(
...
@@ -379,10 +377,6 @@ def test_mocker_kv_router(
request_plane
=
request_plane
,
request_plane
=
request_plane
,
)
)
finally
:
if
"mockers"
in
locals
():
mockers
.
__exit__
(
None
,
None
,
None
)
@
pytest
.
mark
.
parametrize
(
"store_backend"
,
[
"etcd"
,
"file"
])
@
pytest
.
mark
.
parametrize
(
"store_backend"
,
[
"etcd"
,
"file"
])
@
pytest
.
mark
.
parametrize
(
@
pytest
.
mark
.
parametrize
(
...
@@ -415,17 +409,15 @@ def test_mocker_two_kv_router(
...
@@ -415,17 +409,15 @@ def test_mocker_two_kv_router(
"durable_kv_events"
:
durable_kv_events
,
"durable_kv_events"
:
durable_kv_events
,
}
}
try
:
with
MockerProcess
(
# Start mocker instances with the new CLI interface
logger
.
info
(
f
"Starting
{
NUM_MOCKERS
}
mocker instances"
)
mockers
=
MockerProcess
(
request
,
request
,
mocker_args
=
mocker_args
,
mocker_args
=
mocker_args
,
num_mockers
=
NUM_MOCKERS
,
num_mockers
=
NUM_MOCKERS
,
store_backend
=
store_backend
,
store_backend
=
store_backend
,
)
)
as
mockers
:
# Start mocker instances with the new CLI interface
logger
.
info
(
f
"Starting
{
NUM_MOCKERS
}
mocker instances"
)
logger
.
info
(
f
"All mockers using endpoint:
{
mockers
.
endpoint
}
"
)
logger
.
info
(
f
"All mockers using endpoint:
{
mockers
.
endpoint
}
"
)
mockers
.
__enter__
()
# Get unique ports for this test (2 ports for two routers)
# Get unique ports for this test (2 ports for two routers)
router_ports
=
get_unique_ports
(
router_ports
=
get_unique_ports
(
...
@@ -444,10 +436,6 @@ def test_mocker_two_kv_router(
...
@@ -444,10 +436,6 @@ def test_mocker_two_kv_router(
skip_consumer_verification
=
not
durable_kv_events
,
# Skip JetStream checks in NATS Core mode
skip_consumer_verification
=
not
durable_kv_events
,
# Skip JetStream checks in NATS Core mode
)
)
finally
:
if
"mockers"
in
locals
():
mockers
.
__exit__
(
None
,
None
,
None
)
@
pytest
.
mark
.
skip
(
reason
=
"Flaky, temporarily disabled"
)
@
pytest
.
mark
.
skip
(
reason
=
"Flaky, temporarily disabled"
)
@
pytest
.
mark
.
parametrize
(
@
pytest
.
mark
.
parametrize
(
...
@@ -467,12 +455,10 @@ def test_mocker_kv_router_overload_503(
...
@@ -467,12 +455,10 @@ def test_mocker_kv_router_overload_503(
"durable_kv_events"
:
durable_kv_events
,
"durable_kv_events"
:
durable_kv_events
,
}
}
try
:
with
MockerProcess
(
request
,
mocker_args
=
mocker_args
,
num_mockers
=
1
)
as
mockers
:
# Start single mocker instance with limited resources
# Start single mocker instance with limited resources
logger
.
info
(
"Starting single mocker instance with limited resources"
)
logger
.
info
(
"Starting single mocker instance with limited resources"
)
mockers
=
MockerProcess
(
request
,
mocker_args
=
mocker_args
,
num_mockers
=
1
)
logger
.
info
(
f
"Mocker using endpoint:
{
mockers
.
endpoint
}
"
)
logger
.
info
(
f
"Mocker using endpoint:
{
mockers
.
endpoint
}
"
)
mockers
.
__enter__
()
# Get unique port for this test
# Get unique port for this test
frontend_port
=
get_unique_ports
(
request
,
num_ports
=
1
)[
0
]
frontend_port
=
get_unique_ports
(
request
,
num_ports
=
1
)[
0
]
...
@@ -487,10 +473,6 @@ def test_mocker_kv_router_overload_503(
...
@@ -487,10 +473,6 @@ def test_mocker_kv_router_overload_503(
blocks_threshold
=
0.2
,
blocks_threshold
=
0.2
,
)
)
finally
:
if
"mockers"
in
locals
():
mockers
.
__exit__
(
None
,
None
,
None
)
@
pytest
.
mark
.
timeout
(
90
)
# bumped for xdist contention (was 22s; ~7.10s serial avg)
@
pytest
.
mark
.
timeout
(
90
)
# bumped for xdist contention (was 22s; ~7.10s serial avg)
@
pytest
.
mark
.
parametrize
(
"request_plane"
,
[
"nats"
,
"tcp"
],
indirect
=
True
)
@
pytest
.
mark
.
parametrize
(
"request_plane"
,
[
"nats"
,
"tcp"
],
indirect
=
True
)
...
@@ -513,17 +495,15 @@ def test_kv_push_router_bindings(
...
@@ -513,17 +495,15 @@ def test_kv_push_router_bindings(
"durable_kv_events"
:
durable_kv_events
,
"durable_kv_events"
:
durable_kv_events
,
}
}
try
:
with
MockerProcess
(
# Start mocker instances
logger
.
info
(
f
"Starting
{
NUM_MOCKERS
}
mocker instances"
)
mockers
=
MockerProcess
(
request
,
request
,
mocker_args
=
mocker_args
,
mocker_args
=
mocker_args
,
num_mockers
=
NUM_MOCKERS
,
num_mockers
=
NUM_MOCKERS
,
request_plane
=
request_plane
,
request_plane
=
request_plane
,
)
)
as
mockers
:
# Start mocker instances
logger
.
info
(
f
"Starting
{
NUM_MOCKERS
}
mocker instances"
)
logger
.
info
(
f
"All mockers using endpoint:
{
mockers
.
endpoint
}
"
)
logger
.
info
(
f
"All mockers using endpoint:
{
mockers
.
endpoint
}
"
)
mockers
.
__enter__
()
# Get runtime and create endpoint
# Get runtime and create endpoint
runtime
=
get_runtime
(
request_plane
=
request_plane
)
runtime
=
get_runtime
(
request_plane
=
request_plane
)
...
@@ -540,10 +520,6 @@ def test_kv_push_router_bindings(
...
@@ -540,10 +520,6 @@ def test_kv_push_router_bindings(
num_workers
=
NUM_MOCKERS
,
num_workers
=
NUM_MOCKERS
,
)
)
finally
:
if
"mockers"
in
locals
():
mockers
.
__exit__
(
None
,
None
,
None
)
@
pytest
.
mark
.
parametrize
(
@
pytest
.
mark
.
parametrize
(
"store_backend,durable_kv_events,request_plane"
,
"store_backend,durable_kv_events,request_plane"
,
...
@@ -596,18 +572,16 @@ def test_indexers_sync(
...
@@ -596,18 +572,16 @@ def test_indexers_sync(
"dp_size"
:
2
,
"dp_size"
:
2
,
}
}
try
:
with
MockerProcess
(
# Start mocker instances (2 workers x 2 DP ranks = 4 independent event streams)
logger
.
info
(
f
"Starting
{
NUM_MOCKERS
}
mocker instances with dp_size=2"
)
mockers
=
MockerProcess
(
request
,
request
,
mocker_args
=
mocker_args
,
mocker_args
=
mocker_args
,
num_mockers
=
NUM_MOCKERS
,
num_mockers
=
NUM_MOCKERS
,
store_backend
=
store_backend
,
store_backend
=
store_backend
,
request_plane
=
request_plane
,
request_plane
=
request_plane
,
)
)
as
mockers
:
# Start mocker instances (2 workers x 2 DP ranks = 4 independent event streams)
logger
.
info
(
f
"Starting
{
NUM_MOCKERS
}
mocker instances with dp_size=2"
)
logger
.
info
(
f
"All mockers using endpoint:
{
mockers
.
endpoint
}
"
)
logger
.
info
(
f
"All mockers using endpoint:
{
mockers
.
endpoint
}
"
)
mockers
.
__enter__
()
# Use the common test implementation (creates its own runtimes for each router)
# Use the common test implementation (creates its own runtimes for each router)
# Note: Consumer verification is done inside _test_router_indexers_sync while routers are alive
# Note: Consumer verification is done inside _test_router_indexers_sync while routers are alive
...
@@ -626,10 +600,6 @@ def test_indexers_sync(
...
@@ -626,10 +600,6 @@ def test_indexers_sync(
logger
.
info
(
"Indexers sync test completed successfully"
)
logger
.
info
(
"Indexers sync test completed successfully"
)
finally
:
if
"mockers"
in
locals
():
mockers
.
__exit__
(
None
,
None
,
None
)
@
pytest
.
mark
.
timeout
(
120
)
# bumped for xdist contention (was 42s; ~13.80s serial avg)
@
pytest
.
mark
.
timeout
(
120
)
# bumped for xdist contention (was 42s; ~13.80s serial avg)
@
pytest
.
mark
.
parametrize
(
@
pytest
.
mark
.
parametrize
(
...
@@ -648,14 +618,12 @@ def test_query_instance_id_returns_worker_and_tokens(
...
@@ -648,14 +618,12 @@ def test_query_instance_id_returns_worker_and_tokens(
}
}
os
.
makedirs
(
request
.
node
.
name
,
exist_ok
=
True
)
os
.
makedirs
(
request
.
node
.
name
,
exist_ok
=
True
)
try
:
with
MockerProcess
(
request
,
mocker_args
=
mocker_args
,
num_mockers
=
NUM_MOCKERS
)
as
mockers
:
# Start mocker instances
# Start mocker instances
logger
.
info
(
f
"Starting
{
NUM_MOCKERS
}
mocker instances"
)
logger
.
info
(
f
"Starting
{
NUM_MOCKERS
}
mocker instances"
)
mockers
=
MockerProcess
(
request
,
mocker_args
=
mocker_args
,
num_mockers
=
NUM_MOCKERS
)
logger
.
info
(
f
"All mockers using endpoint:
{
mockers
.
endpoint
}
"
)
logger
.
info
(
f
"All mockers using endpoint:
{
mockers
.
endpoint
}
"
)
mockers
.
__enter__
()
# Get unique port for this test
# Get unique port for this test
frontend_port
=
get_unique_ports
(
request
,
num_ports
=
1
)[
0
]
frontend_port
=
get_unique_ports
(
request
,
num_ports
=
1
)[
0
]
...
@@ -669,10 +637,6 @@ def test_query_instance_id_returns_worker_and_tokens(
...
@@ -669,10 +637,6 @@ def test_query_instance_id_returns_worker_and_tokens(
test_payload
=
TEST_PAYLOAD
,
test_payload
=
TEST_PAYLOAD
,
)
)
finally
:
if
"mockers"
in
locals
():
mockers
.
__exit__
(
None
,
None
,
None
)
@
pytest
.
mark
.
timeout
(
90
)
# bumped for xdist contention (was 29s; ~9.55s serial avg)
@
pytest
.
mark
.
timeout
(
90
)
# bumped for xdist contention (was 29s; ~9.55s serial avg)
@
pytest
.
mark
.
parametrize
(
"request_plane"
,
[
"nats"
,
"tcp"
],
indirect
=
True
)
@
pytest
.
mark
.
parametrize
(
"request_plane"
,
[
"nats"
,
"tcp"
],
indirect
=
True
)
...
@@ -716,18 +680,15 @@ def test_router_decisions(
...
@@ -716,18 +680,15 @@ def test_router_decisions(
"durable_kv_events"
:
durable_kv_events
and
use_kv_events
,
"durable_kv_events"
:
durable_kv_events
and
use_kv_events
,
}
}
try
:
with
MockerProcess
(
mockers
=
MockerProcess
(
request
,
request
,
mocker_args
=
mocker_args
,
mocker_args
=
mocker_args
,
num_mockers
=
2
,
num_mockers
=
2
,
request_plane
=
request_plane
,
request_plane
=
request_plane
,
)
)
as
mockers
:
logger
.
info
(
f
"All mockers using endpoint:
{
mockers
.
endpoint
}
"
)
logger
.
info
(
f
"All mockers using endpoint:
{
mockers
.
endpoint
}
"
)
# Initialize mockers
# Initialize mockers
mockers
.
__enter__
()
# Get runtime and create endpoint
# Get runtime and create endpoint
runtime
=
get_runtime
(
request_plane
=
request_plane
)
runtime
=
get_runtime
(
request_plane
=
request_plane
)
# Use the namespace from the mockers
# Use the namespace from the mockers
...
@@ -745,10 +706,6 @@ def test_router_decisions(
...
@@ -745,10 +706,6 @@ def test_router_decisions(
durable_kv_events
=
durable_kv_events
,
durable_kv_events
=
durable_kv_events
,
)
)
finally
:
if
"mockers"
in
locals
():
mockers
.
__exit__
(
None
,
None
,
None
)
@
pytest
.
mark
.
parametrize
(
"registration_order"
,
[
"prefill_first"
,
"decode_first"
])
@
pytest
.
mark
.
parametrize
(
"registration_order"
,
[
"prefill_first"
,
"decode_first"
])
@
pytest
.
mark
.
parametrize
(
@
pytest
.
mark
.
parametrize
(
...
@@ -788,14 +745,10 @@ def test_router_decisions_disagg(
...
@@ -788,14 +745,10 @@ def test_router_decisions_disagg(
# durable_kv_events defaults to False (NATS Core mode)
# durable_kv_events defaults to False (NATS Core mode)
}
}
prefill_workers
=
None
decode_workers
=
None
try
:
if
registration_order
==
"prefill_first"
:
if
registration_order
==
"prefill_first"
:
# Start prefill workers first
# Start prefill workers first
logger
.
info
(
"Starting 4 prefill mocker instances (first)"
)
logger
.
info
(
"Starting 4 prefill mocker instances (first)"
)
prefill_workers
=
DisaggMockerProcess
(
with
DisaggMockerProcess
(
request
,
request
,
namespace
=
shared_namespace
,
namespace
=
shared_namespace
,
worker_type
=
"prefill"
,
worker_type
=
"prefill"
,
...
@@ -803,39 +756,52 @@ def test_router_decisions_disagg(
...
@@ -803,39 +756,52 @@ def test_router_decisions_disagg(
num_mockers
=
4
,
num_mockers
=
4
,
request_plane
=
"nats"
,
request_plane
=
"nats"
,
enable_bootstrap
=
enable_disagg_bootstrap
,
enable_bootstrap
=
enable_disagg_bootstrap
,
)
)
as
prefill_workers
:
prefill_workers
.
__enter__
()
logger
.
info
(
f
"Prefill workers using endpoint:
{
prefill_workers
.
endpoint
}
"
)
logger
.
info
(
f
"Prefill workers using endpoint:
{
prefill_workers
.
endpoint
}
"
)
# Then start decode workers
# Then start decode workers
logger
.
info
(
"Starting 4 decode mocker instances (second)"
)
logger
.
info
(
"Starting 4 decode mocker instances (second)"
)
decode_workers
=
DisaggMockerProcess
(
with
DisaggMockerProcess
(
request
,
request
,
namespace
=
shared_namespace
,
namespace
=
shared_namespace
,
worker_type
=
"decode"
,
worker_type
=
"decode"
,
mocker_args
=
mocker_args
,
mocker_args
=
mocker_args
,
num_mockers
=
4
,
num_mockers
=
4
,
request_plane
=
"nats"
,
request_plane
=
"nats"
,
)
)
as
decode_workers
:
decode_workers
.
__enter__
()
logger
.
info
(
f
"Decode workers using endpoint:
{
decode_workers
.
endpoint
}
"
)
logger
.
info
(
f
"Decode workers using endpoint:
{
decode_workers
.
endpoint
}
"
)
# Get unique port for this test
frontend_port
=
get_unique_ports
(
request
,
num_ports
=
1
,
registration_order
=
registration_order
)[
0
]
# Run disagg routing test
_test_router_decisions_disagg
(
prefill_workers
=
prefill_workers
,
decode_workers
=
decode_workers
,
block_size
=
BLOCK_SIZE
,
request
=
request
,
frontend_port
=
frontend_port
,
test_payload
=
TEST_PAYLOAD
,
request_plane
=
"nats"
,
)
else
:
else
:
# Start decode workers first
# Start decode workers first
logger
.
info
(
"Starting 4 decode mocker instances (first)"
)
logger
.
info
(
"Starting 4 decode mocker instances (first)"
)
decode_workers
=
DisaggMockerProcess
(
with
DisaggMockerProcess
(
request
,
request
,
namespace
=
shared_namespace
,
namespace
=
shared_namespace
,
worker_type
=
"decode"
,
worker_type
=
"decode"
,
mocker_args
=
mocker_args
,
mocker_args
=
mocker_args
,
num_mockers
=
4
,
num_mockers
=
4
,
request_plane
=
"nats"
,
request_plane
=
"nats"
,
)
)
as
decode_workers
:
decode_workers
.
__enter__
()
logger
.
info
(
f
"Decode workers using endpoint:
{
decode_workers
.
endpoint
}
"
)
logger
.
info
(
f
"Decode workers using endpoint:
{
decode_workers
.
endpoint
}
"
)
# Then start prefill workers
# Then start prefill workers
logger
.
info
(
"Starting 4 prefill mocker instances (second)"
)
logger
.
info
(
"Starting 4 prefill mocker instances (second)"
)
prefill_workers
=
DisaggMockerProcess
(
with
DisaggMockerProcess
(
request
,
request
,
namespace
=
shared_namespace
,
namespace
=
shared_namespace
,
worker_type
=
"prefill"
,
worker_type
=
"prefill"
,
...
@@ -843,9 +809,10 @@ def test_router_decisions_disagg(
...
@@ -843,9 +809,10 @@ def test_router_decisions_disagg(
num_mockers
=
4
,
num_mockers
=
4
,
request_plane
=
"nats"
,
request_plane
=
"nats"
,
enable_bootstrap
=
enable_disagg_bootstrap
,
enable_bootstrap
=
enable_disagg_bootstrap
,
)
as
prefill_workers
:
logger
.
info
(
f
"Prefill workers using endpoint:
{
prefill_workers
.
endpoint
}
"
)
)
prefill_workers
.
__enter__
()
logger
.
info
(
f
"Prefill workers using endpoint:
{
prefill_workers
.
endpoint
}
"
)
# Get unique port for this test
# Get unique port for this test
frontend_port
=
get_unique_ports
(
frontend_port
=
get_unique_ports
(
...
@@ -863,12 +830,6 @@ def test_router_decisions_disagg(
...
@@ -863,12 +830,6 @@ def test_router_decisions_disagg(
request_plane
=
"nats"
,
request_plane
=
"nats"
,
)
)
finally
:
if
decode_workers
is
not
None
:
decode_workers
.
__exit__
(
None
,
None
,
None
)
if
prefill_workers
is
not
None
:
prefill_workers
.
__exit__
(
None
,
None
,
None
)
@
pytest
.
mark
.
parametrize
(
"request_plane"
,
[
"nats"
,
"tcp"
],
indirect
=
True
)
@
pytest
.
mark
.
parametrize
(
"request_plane"
,
[
"nats"
,
"tcp"
],
indirect
=
True
)
@
pytest
.
mark
.
parametrize
(
@
pytest
.
mark
.
parametrize
(
...
@@ -903,16 +864,14 @@ def test_busy_threshold_endpoint(
...
@@ -903,16 +864,14 @@ def test_busy_threshold_endpoint(
"durable_kv_events"
:
durable_kv_events
,
"durable_kv_events"
:
durable_kv_events
,
}
}
try
:
with
MockerProcess
(
logger
.
info
(
f
"Starting
{
NUM_MOCKERS
}
mocker instances"
)
mockers
=
MockerProcess
(
request
,
request
,
mocker_args
=
mocker_args
,
mocker_args
=
mocker_args
,
num_mockers
=
NUM_MOCKERS
,
num_mockers
=
NUM_MOCKERS
,
request_plane
=
request_plane
,
request_plane
=
request_plane
,
)
)
as
mockers
:
logger
.
info
(
f
"Starting
{
NUM_MOCKERS
}
mocker instances"
)
logger
.
info
(
f
"All mockers using endpoint:
{
mockers
.
endpoint
}
"
)
logger
.
info
(
f
"All mockers using endpoint:
{
mockers
.
endpoint
}
"
)
mockers
.
__enter__
()
frontend_port
=
get_unique_ports
(
frontend_port
=
get_unique_ports
(
request
,
num_ports
=
1
,
request_plane
=
request_plane
request
,
num_ports
=
1
,
request_plane
=
request_plane
...
@@ -926,7 +885,3 @@ def test_busy_threshold_endpoint(
...
@@ -926,7 +885,3 @@ def test_busy_threshold_endpoint(
test_payload
=
TEST_PAYLOAD
,
test_payload
=
TEST_PAYLOAD
,
request_plane
=
request_plane
,
request_plane
=
request_plane
,
)
)
finally
:
if
"mockers"
in
locals
():
mockers
.
__exit__
(
None
,
None
,
None
)
tests/router/test_router_e2e_with_sglang.py
View file @
6fe2152b
...
@@ -342,18 +342,16 @@ def test_sglang_kv_router_basic(
...
@@ -342,18 +342,16 @@ def test_sglang_kv_router_basic(
f
"Starting SGLang KV router test with
{
N_SGLANG_WORKERS
}
workers using request_plane=
{
request_plane
}
"
f
"Starting SGLang KV router test with
{
N_SGLANG_WORKERS
}
workers using request_plane=
{
request_plane
}
"
)
)
try
:
with
SGLangProcess
(
# Start SGLang workers
logger
.
info
(
f
"Starting
{
N_SGLANG_WORKERS
}
SGLang workers"
)
sglang_workers
=
SGLangProcess
(
request
,
request
,
sglang_args
=
SGLANG_ARGS
,
sglang_args
=
SGLANG_ARGS
,
num_workers
=
N_SGLANG_WORKERS
,
num_workers
=
N_SGLANG_WORKERS
,
single_gpu
=
True
,
# fit workers into one GPU
single_gpu
=
True
,
# fit workers into one GPU
request_plane
=
request_plane
,
request_plane
=
request_plane
,
)
)
as
sglang_workers
:
# Start SGLang workers
logger
.
info
(
f
"Starting
{
N_SGLANG_WORKERS
}
SGLang workers"
)
logger
.
info
(
f
"All SGLang workers using namespace:
{
sglang_workers
.
namespace
}
"
)
logger
.
info
(
f
"All SGLang workers using namespace:
{
sglang_workers
.
namespace
}
"
)
sglang_workers
.
__enter__
()
# Run basic router test (starts router internally and waits for workers to be ready)
# Run basic router test (starts router internally and waits for workers to be ready)
frontend_port
=
allocate_frontend_ports
(
request
,
1
)[
0
]
frontend_port
=
allocate_frontend_ports
(
request
,
1
)[
0
]
...
@@ -369,10 +367,6 @@ def test_sglang_kv_router_basic(
...
@@ -369,10 +367,6 @@ def test_sglang_kv_router_basic(
request_plane
=
request_plane
,
request_plane
=
request_plane
,
)
)
finally
:
if
"sglang_workers"
in
locals
():
sglang_workers
.
__exit__
(
None
,
None
,
None
)
@
pytest
.
mark
.
pre_merge
@
pytest
.
mark
.
pre_merge
@
pytest
.
mark
.
gpu_1
@
pytest
.
mark
.
gpu_1
...
@@ -390,21 +384,18 @@ def test_router_decisions_sglang_multiple_workers(
...
@@ -390,21 +384,18 @@ def test_router_decisions_sglang_multiple_workers(
logger
.
info
(
"Starting SGLang router prefix reuse test with two workers"
)
logger
.
info
(
"Starting SGLang router prefix reuse test with two workers"
)
N_WORKERS
=
2
N_WORKERS
=
2
try
:
with
SGLangProcess
(
# Start 2 worker processes on the same GPU
logger
.
info
(
"Starting 2 SGLang worker processes on single GPU (mem_frac=0.4)"
)
sglang_workers
=
SGLangProcess
(
request
,
request
,
sglang_args
=
SGLANG_ARGS
,
sglang_args
=
SGLANG_ARGS
,
num_workers
=
N_WORKERS
,
num_workers
=
N_WORKERS
,
single_gpu
=
True
,
# Worker uses GPU 0
single_gpu
=
True
,
# Worker uses GPU 0
request_plane
=
request_plane
,
request_plane
=
request_plane
,
)
)
as
sglang_workers
:
# Start 2 worker processes on the same GPU
logger
.
info
(
"Starting 2 SGLang worker processes on single GPU (mem_frac=0.4)"
)
logger
.
info
(
f
"All SGLang workers using namespace:
{
sglang_workers
.
namespace
}
"
)
logger
.
info
(
f
"All SGLang workers using namespace:
{
sglang_workers
.
namespace
}
"
)
# Initialize SGLang workers
# Initialize SGLang workers
sglang_workers
.
__enter__
()
# Get runtime and create endpoint
# Get runtime and create endpoint
runtime
=
get_runtime
(
request_plane
=
request_plane
)
runtime
=
get_runtime
(
request_plane
=
request_plane
)
namespace
=
runtime
.
namespace
(
sglang_workers
.
namespace
)
namespace
=
runtime
.
namespace
(
sglang_workers
.
namespace
)
...
@@ -415,11 +406,6 @@ def test_router_decisions_sglang_multiple_workers(
...
@@ -415,11 +406,6 @@ def test_router_decisions_sglang_multiple_workers(
sglang_workers
,
endpoint
,
MODEL_NAME
,
request
,
test_dp_rank
=
False
sglang_workers
,
endpoint
,
MODEL_NAME
,
request
,
test_dp_rank
=
False
)
)
finally
:
# Clean up SGLang workers
if
"sglang_workers"
in
locals
():
sglang_workers
.
__exit__
(
None
,
None
,
None
)
@
pytest
.
mark
.
gpu_2
@
pytest
.
mark
.
gpu_2
@
pytest
.
mark
.
post_merge
@
pytest
.
mark
.
post_merge
...
@@ -442,18 +428,16 @@ def test_router_decisions_sglang_dp(
...
@@ -442,18 +428,16 @@ def test_router_decisions_sglang_dp(
N_WORKERS
=
1
N_WORKERS
=
1
DP_SIZE
=
2
DP_SIZE
=
2
try
:
with
SGLangProcess
(
logger
.
info
(
"Starting 2 SGLang DP ranks (dp_size=2) (mem_frac=0.4)"
)
sglang_workers
=
SGLangProcess
(
request
,
request
,
sglang_args
=
SGLANG_ARGS
,
sglang_args
=
SGLANG_ARGS
,
num_workers
=
N_WORKERS
,
# Ignored when data_parallel_size is set
num_workers
=
N_WORKERS
,
# Ignored when data_parallel_size is set
single_gpu
=
False
,
single_gpu
=
False
,
data_parallel_size
=
DP_SIZE
,
# Creates DP_SIZE processes (one per rank)
data_parallel_size
=
DP_SIZE
,
# Creates DP_SIZE processes (one per rank)
request_plane
=
request_plane
,
request_plane
=
request_plane
,
)
)
as
sglang_workers
:
logger
.
info
(
"Starting 2 SGLang DP ranks (dp_size=2) (mem_frac=0.4)"
)
logger
.
info
(
f
"All SGLang workers using namespace:
{
sglang_workers
.
namespace
}
"
)
logger
.
info
(
f
"All SGLang workers using namespace:
{
sglang_workers
.
namespace
}
"
)
sglang_workers
.
__enter__
()
# Get runtime and create endpoint
# Get runtime and create endpoint
runtime
=
get_runtime
(
request_plane
=
request_plane
)
runtime
=
get_runtime
(
request_plane
=
request_plane
)
...
@@ -466,11 +450,6 @@ def test_router_decisions_sglang_dp(
...
@@ -466,11 +450,6 @@ def test_router_decisions_sglang_dp(
sglang_workers
,
endpoint
,
MODEL_NAME
,
request
,
test_dp_rank
=
True
sglang_workers
,
endpoint
,
MODEL_NAME
,
request
,
test_dp_rank
=
True
)
)
finally
:
# Clean up SGLang workers
if
"sglang_workers"
in
locals
():
sglang_workers
.
__exit__
(
None
,
None
,
None
)
@
pytest
.
mark
.
pre_merge
@
pytest
.
mark
.
pre_merge
@
pytest
.
mark
.
gpu_1
@
pytest
.
mark
.
gpu_1
...
@@ -511,10 +490,7 @@ def test_sglang_indexers_sync(
...
@@ -511,10 +490,7 @@ def test_sglang_indexers_sync(
N_SGLANG_WORKERS
=
2
N_SGLANG_WORKERS
=
2
try
:
with
SGLangProcess
(
# Start SGLang workers
logger
.
info
(
f
"Starting
{
N_SGLANG_WORKERS
}
SGLang workers"
)
sglang_workers
=
SGLangProcess
(
request
,
request
,
sglang_args
=
SGLANG_ARGS
,
sglang_args
=
SGLANG_ARGS
,
num_workers
=
N_SGLANG_WORKERS
,
num_workers
=
N_SGLANG_WORKERS
,
...
@@ -522,9 +498,10 @@ def test_sglang_indexers_sync(
...
@@ -522,9 +498,10 @@ def test_sglang_indexers_sync(
request_plane
=
request_plane
,
request_plane
=
request_plane
,
store_backend
=
store_backend
,
store_backend
=
store_backend
,
durable_kv_events
=
durable_kv_events
,
durable_kv_events
=
durable_kv_events
,
)
)
as
sglang_workers
:
# Start SGLang workers
logger
.
info
(
f
"Starting
{
N_SGLANG_WORKERS
}
SGLang workers"
)
logger
.
info
(
f
"All SGLang workers using namespace:
{
sglang_workers
.
namespace
}
"
)
logger
.
info
(
f
"All SGLang workers using namespace:
{
sglang_workers
.
namespace
}
"
)
sglang_workers
.
__enter__
()
# Use the common test implementation (creates its own runtimes for each router)
# Use the common test implementation (creates its own runtimes for each router)
# Note: Consumer verification is done inside _test_router_indexers_sync while routers are alive
# Note: Consumer verification is done inside _test_router_indexers_sync while routers are alive
...
@@ -542,7 +519,3 @@ def test_sglang_indexers_sync(
...
@@ -542,7 +519,3 @@ def test_sglang_indexers_sync(
)
)
logger
.
info
(
"SGLang indexers sync test completed successfully"
)
logger
.
info
(
"SGLang indexers sync test completed successfully"
)
finally
:
if
"sglang_workers"
in
locals
():
sglang_workers
.
__exit__
(
None
,
None
,
None
)
tests/router/test_router_e2e_with_trtllm.py
View file @
6fe2152b
...
@@ -332,18 +332,16 @@ def test_trtllm_kv_router_basic(
...
@@ -332,18 +332,16 @@ def test_trtllm_kv_router_basic(
f
"Starting TRT-LLM KV router test with
{
N_TRTLLM_WORKERS
}
workers using request_plane=
{
request_plane
}
"
f
"Starting TRT-LLM KV router test with
{
N_TRTLLM_WORKERS
}
workers using request_plane=
{
request_plane
}
"
)
)
try
:
with
TRTLLMProcess
(
# Start TRT-LLM workers
logger
.
info
(
f
"Starting
{
N_TRTLLM_WORKERS
}
TRT-LLM workers"
)
trtllm_workers
=
TRTLLMProcess
(
request
,
request
,
trtllm_args
=
TRTLLM_ARGS
,
trtllm_args
=
TRTLLM_ARGS
,
num_workers
=
N_TRTLLM_WORKERS
,
num_workers
=
N_TRTLLM_WORKERS
,
single_gpu
=
True
,
# fit workers into one GPU
single_gpu
=
True
,
# fit workers into one GPU
request_plane
=
request_plane
,
request_plane
=
request_plane
,
)
)
as
trtllm_workers
:
# Start TRT-LLM workers
logger
.
info
(
f
"Starting
{
N_TRTLLM_WORKERS
}
TRT-LLM workers"
)
logger
.
info
(
f
"All TRT-LLM workers using namespace:
{
trtllm_workers
.
namespace
}
"
)
logger
.
info
(
f
"All TRT-LLM workers using namespace:
{
trtllm_workers
.
namespace
}
"
)
trtllm_workers
.
__enter__
()
# Run basic router test (starts router internally and waits for workers to be ready)
# Run basic router test (starts router internally and waits for workers to be ready)
frontend_port
=
allocate_frontend_ports
(
request
,
1
)[
0
]
frontend_port
=
allocate_frontend_ports
(
request
,
1
)[
0
]
...
@@ -359,10 +357,6 @@ def test_trtllm_kv_router_basic(
...
@@ -359,10 +357,6 @@ def test_trtllm_kv_router_basic(
request_plane
=
request_plane
,
request_plane
=
request_plane
,
)
)
finally
:
if
"trtllm_workers"
in
locals
():
trtllm_workers
.
__exit__
(
None
,
None
,
None
)
@
pytest
.
mark
.
gpu_2
@
pytest
.
mark
.
gpu_2
@
pytest
.
mark
.
nightly
@
pytest
.
mark
.
nightly
...
@@ -392,19 +386,17 @@ def test_router_decisions_trtllm_attention_dp(
...
@@ -392,19 +386,17 @@ def test_router_decisions_trtllm_attention_dp(
"tensor_parallel_size"
:
N_ATTENTION_DP_RANKS
,
"tensor_parallel_size"
:
N_ATTENTION_DP_RANKS
,
}
}
try
:
with
TRTLLMProcess
(
logger
.
info
(
f
"Starting 1 TRT-LLM worker with attention DP enabled (attention_dp_size=
{
N_ATTENTION_DP_RANKS
}
)"
)
trtllm_workers
=
TRTLLMProcess
(
request
,
request
,
trtllm_args
=
TRTLLM_ADP_ARGS
,
trtllm_args
=
TRTLLM_ADP_ARGS
,
num_workers
=
N_TRTLLM_WORKERS
,
num_workers
=
N_TRTLLM_WORKERS
,
single_gpu
=
False
,
single_gpu
=
False
,
request_plane
=
request_plane
,
request_plane
=
request_plane
,
)
as
trtllm_workers
:
logger
.
info
(
f
"Starting 1 TRT-LLM worker with attention DP enabled (attention_dp_size=
{
N_ATTENTION_DP_RANKS
}
)"
)
)
logger
.
info
(
f
"All TRT-LLM workers using namespace:
{
trtllm_workers
.
namespace
}
"
)
logger
.
info
(
f
"All TRT-LLM workers using namespace:
{
trtllm_workers
.
namespace
}
"
)
trtllm_workers
.
__enter__
()
# Get runtime and create endpoint
# Get runtime and create endpoint
runtime
=
get_runtime
(
request_plane
=
request_plane
)
runtime
=
get_runtime
(
request_plane
=
request_plane
)
...
@@ -422,11 +414,6 @@ def test_router_decisions_trtllm_attention_dp(
...
@@ -422,11 +414,6 @@ def test_router_decisions_trtllm_attention_dp(
block_size
=
TRTLLM_BLOCK_SIZE
,
block_size
=
TRTLLM_BLOCK_SIZE
,
)
)
finally
:
# Clean up TRTLLM workers
if
"trtllm_workers"
in
locals
():
trtllm_workers
.
__exit__
(
None
,
None
,
None
)
@
pytest
.
mark
.
pre_merge
@
pytest
.
mark
.
pre_merge
@
pytest
.
mark
.
gpu_1
@
pytest
.
mark
.
gpu_1
...
@@ -443,23 +430,20 @@ def test_router_decisions_trtllm_multiple_workers(
...
@@ -443,23 +430,20 @@ def test_router_decisions_trtllm_multiple_workers(
logger
.
info
(
"Starting TRT-LLM router prefix reuse test with two workers"
)
logger
.
info
(
"Starting TRT-LLM router prefix reuse test with two workers"
)
N_WORKERS
=
2
N_WORKERS
=
2
try
:
with
TRTLLMProcess
(
# Start 2 worker processes on the same GPU
logger
.
info
(
"Starting 2 TRT-LLM worker processes on single GPU (gpu_mem_frac=0.4)"
)
trtllm_workers
=
TRTLLMProcess
(
request
,
request
,
trtllm_args
=
TRTLLM_ARGS
,
trtllm_args
=
TRTLLM_ARGS
,
num_workers
=
N_WORKERS
,
num_workers
=
N_WORKERS
,
single_gpu
=
True
,
# Worker uses GPU 0
single_gpu
=
True
,
# Worker uses GPU 0
request_plane
=
request_plane
,
request_plane
=
request_plane
,
)
as
trtllm_workers
:
# Start 2 worker processes on the same GPU
logger
.
info
(
"Starting 2 TRT-LLM worker processes on single GPU (gpu_mem_frac=0.4)"
)
)
logger
.
info
(
f
"All TRT-LLM workers using namespace:
{
trtllm_workers
.
namespace
}
"
)
logger
.
info
(
f
"All TRT-LLM workers using namespace:
{
trtllm_workers
.
namespace
}
"
)
# Initialize TRT-LLM workers
# Initialize TRT-LLM workers
trtllm_workers
.
__enter__
()
# Get runtime and create endpoint
# Get runtime and create endpoint
runtime
=
get_runtime
(
request_plane
=
request_plane
)
runtime
=
get_runtime
(
request_plane
=
request_plane
)
namespace
=
runtime
.
namespace
(
trtllm_workers
.
namespace
)
namespace
=
runtime
.
namespace
(
trtllm_workers
.
namespace
)
...
@@ -475,11 +459,6 @@ def test_router_decisions_trtllm_multiple_workers(
...
@@ -475,11 +459,6 @@ def test_router_decisions_trtllm_multiple_workers(
block_size
=
TRTLLM_BLOCK_SIZE
,
block_size
=
TRTLLM_BLOCK_SIZE
,
)
)
finally
:
# Clean up TRT-LLM workers
if
"trtllm_workers"
in
locals
():
trtllm_workers
.
__exit__
(
None
,
None
,
None
)
@
pytest
.
mark
.
pre_merge
@
pytest
.
mark
.
pre_merge
@
pytest
.
mark
.
gpu_1
@
pytest
.
mark
.
gpu_1
...
@@ -520,10 +499,7 @@ def test_trtllm_indexers_sync(
...
@@ -520,10 +499,7 @@ def test_trtllm_indexers_sync(
N_TRTLLM_WORKERS
=
2
N_TRTLLM_WORKERS
=
2
try
:
with
TRTLLMProcess
(
# Start TRT-LLM workers
logger
.
info
(
f
"Starting
{
N_TRTLLM_WORKERS
}
TRT-LLM workers"
)
trtllm_workers
=
TRTLLMProcess
(
request
,
request
,
trtllm_args
=
TRTLLM_ARGS
,
trtllm_args
=
TRTLLM_ARGS
,
num_workers
=
N_TRTLLM_WORKERS
,
num_workers
=
N_TRTLLM_WORKERS
,
...
@@ -531,9 +507,10 @@ def test_trtllm_indexers_sync(
...
@@ -531,9 +507,10 @@ def test_trtllm_indexers_sync(
request_plane
=
request_plane
,
request_plane
=
request_plane
,
store_backend
=
store_backend
,
store_backend
=
store_backend
,
durable_kv_events
=
durable_kv_events
,
durable_kv_events
=
durable_kv_events
,
)
)
as
trtllm_workers
:
# Start TRT-LLM workers
logger
.
info
(
f
"Starting
{
N_TRTLLM_WORKERS
}
TRT-LLM workers"
)
logger
.
info
(
f
"All TRT-LLM workers using namespace:
{
trtllm_workers
.
namespace
}
"
)
logger
.
info
(
f
"All TRT-LLM workers using namespace:
{
trtllm_workers
.
namespace
}
"
)
trtllm_workers
.
__enter__
()
# Use the common test implementation (creates its own runtimes for each router)
# Use the common test implementation (creates its own runtimes for each router)
# Note: Consumer verification is done inside _test_router_indexers_sync while routers are alive
# Note: Consumer verification is done inside _test_router_indexers_sync while routers are alive
...
@@ -551,7 +528,3 @@ def test_trtllm_indexers_sync(
...
@@ -551,7 +528,3 @@ def test_trtllm_indexers_sync(
)
)
logger
.
info
(
"TRT-LLM indexers sync test completed successfully"
)
logger
.
info
(
"TRT-LLM indexers sync test completed successfully"
)
finally
:
if
"trtllm_workers"
in
locals
():
trtllm_workers
.
__exit__
(
None
,
None
,
None
)
tests/router/test_router_e2e_with_vllm.py
View file @
6fe2152b
...
@@ -354,18 +354,16 @@ def test_vllm_kv_router_basic(
...
@@ -354,18 +354,16 @@ def test_vllm_kv_router_basic(
f
"Starting vLLM KV router test with
{
N_VLLM_WORKERS
}
workers using request_plane=
{
request_plane
}
"
f
"Starting vLLM KV router test with
{
N_VLLM_WORKERS
}
workers using request_plane=
{
request_plane
}
"
)
)
try
:
with
VLLMProcess
(
# Start vLLM workers
logger
.
info
(
f
"Starting
{
N_VLLM_WORKERS
}
vLLM workers"
)
vllm_workers
=
VLLMProcess
(
request
,
request
,
vllm_args
=
VLLM_ARGS
,
vllm_args
=
VLLM_ARGS
,
num_workers
=
N_VLLM_WORKERS
,
num_workers
=
N_VLLM_WORKERS
,
single_gpu
=
True
,
# fit workers into one GPU
single_gpu
=
True
,
# fit workers into one GPU
request_plane
=
request_plane
,
request_plane
=
request_plane
,
)
)
as
vllm_workers
:
# Start vLLM workers
logger
.
info
(
f
"Starting
{
N_VLLM_WORKERS
}
vLLM workers"
)
logger
.
info
(
f
"All vLLM workers using namespace:
{
vllm_workers
.
namespace
}
"
)
logger
.
info
(
f
"All vLLM workers using namespace:
{
vllm_workers
.
namespace
}
"
)
vllm_workers
.
__enter__
()
# Run basic router test (starts router internally and waits for workers to be ready)
# Run basic router test (starts router internally and waits for workers to be ready)
frontend_port
=
allocate_frontend_ports
(
request
,
1
)[
0
]
frontend_port
=
allocate_frontend_ports
(
request
,
1
)[
0
]
...
@@ -381,10 +379,6 @@ def test_vllm_kv_router_basic(
...
@@ -381,10 +379,6 @@ def test_vllm_kv_router_basic(
request_plane
=
request_plane
,
request_plane
=
request_plane
,
)
)
finally
:
if
"vllm_workers"
in
locals
():
vllm_workers
.
__exit__
(
None
,
None
,
None
)
@
pytest
.
mark
.
pre_merge
@
pytest
.
mark
.
pre_merge
@
pytest
.
mark
.
gpu_1
@
pytest
.
mark
.
gpu_1
...
@@ -401,21 +395,17 @@ def test_router_decisions_vllm_multiple_workers(
...
@@ -401,21 +395,17 @@ def test_router_decisions_vllm_multiple_workers(
logger
.
info
(
"Starting vLLM router prefix reuse test with two workers"
)
logger
.
info
(
"Starting vLLM router prefix reuse test with two workers"
)
N_WORKERS
=
2
N_WORKERS
=
2
try
:
with
VLLMProcess
(
# Start 2 worker processes on the same GPU
logger
.
info
(
"Starting 2 vLLM worker processes on single GPU (gpu_mem=0.4)"
)
vllm_workers
=
VLLMProcess
(
request
,
request
,
vllm_args
=
VLLM_ARGS
,
vllm_args
=
VLLM_ARGS
,
num_workers
=
N_WORKERS
,
num_workers
=
N_WORKERS
,
single_gpu
=
True
,
# Worker uses GPU 0
single_gpu
=
True
,
# Worker uses GPU 0
request_plane
=
request_plane
,
request_plane
=
request_plane
,
)
)
as
vllm_workers
:
# Start 2 worker processes on the same GPU
logger
.
info
(
"Starting 2 vLLM worker processes on single GPU (gpu_mem=0.4)"
)
logger
.
info
(
f
"All vLLM workers using namespace:
{
vllm_workers
.
namespace
}
"
)
logger
.
info
(
f
"All vLLM workers using namespace:
{
vllm_workers
.
namespace
}
"
)
# Initialize vLLM workers
vllm_workers
.
__enter__
()
# Get runtime and create endpoint
# Get runtime and create endpoint
runtime
=
get_runtime
(
request_plane
=
request_plane
)
runtime
=
get_runtime
(
request_plane
=
request_plane
)
namespace
=
runtime
.
namespace
(
vllm_workers
.
namespace
)
namespace
=
runtime
.
namespace
(
vllm_workers
.
namespace
)
...
@@ -426,11 +416,6 @@ def test_router_decisions_vllm_multiple_workers(
...
@@ -426,11 +416,6 @@ def test_router_decisions_vllm_multiple_workers(
vllm_workers
,
endpoint
,
MODEL_NAME
,
request
,
test_dp_rank
=
False
vllm_workers
,
endpoint
,
MODEL_NAME
,
request
,
test_dp_rank
=
False
)
)
finally
:
# Clean up vLLM workers
if
"vllm_workers"
in
locals
():
vllm_workers
.
__exit__
(
None
,
None
,
None
)
@
pytest
.
mark
.
gpu_2
@
pytest
.
mark
.
gpu_2
@
pytest
.
mark
.
nightly
@
pytest
.
mark
.
nightly
...
@@ -453,18 +438,16 @@ def test_router_decisions_vllm_dp(
...
@@ -453,18 +438,16 @@ def test_router_decisions_vllm_dp(
N_WORKERS
=
1
N_WORKERS
=
1
DP_SIZE
=
2
DP_SIZE
=
2
try
:
with
VLLMProcess
(
logger
.
info
(
"Starting 2 vLLM DP ranks (dp_size=2) (gpu_mem=0.4)"
)
vllm_workers
=
VLLMProcess
(
request
,
request
,
vllm_args
=
VLLM_ARGS
,
vllm_args
=
VLLM_ARGS
,
num_workers
=
N_WORKERS
,
# Ignored when data_parallel_size is set
num_workers
=
N_WORKERS
,
# Ignored when data_parallel_size is set
single_gpu
=
False
,
single_gpu
=
False
,
data_parallel_size
=
DP_SIZE
,
# Creates DP_SIZE processes (one per rank)
data_parallel_size
=
DP_SIZE
,
# Creates DP_SIZE processes (one per rank)
request_plane
=
request_plane
,
request_plane
=
request_plane
,
)
)
as
vllm_workers
:
logger
.
info
(
"Starting 2 vLLM DP ranks (dp_size=2) (gpu_mem=0.4)"
)
logger
.
info
(
f
"All vLLM workers using namespace:
{
vllm_workers
.
namespace
}
"
)
logger
.
info
(
f
"All vLLM workers using namespace:
{
vllm_workers
.
namespace
}
"
)
vllm_workers
.
__enter__
()
# Get runtime and create endpoint
# Get runtime and create endpoint
runtime
=
get_runtime
(
request_plane
=
request_plane
)
runtime
=
get_runtime
(
request_plane
=
request_plane
)
...
@@ -477,11 +460,6 @@ def test_router_decisions_vllm_dp(
...
@@ -477,11 +460,6 @@ def test_router_decisions_vllm_dp(
vllm_workers
,
endpoint
,
MODEL_NAME
,
request
,
test_dp_rank
=
True
vllm_workers
,
endpoint
,
MODEL_NAME
,
request
,
test_dp_rank
=
True
)
)
finally
:
# Clean up vLLM workers
if
"vllm_workers"
in
locals
():
vllm_workers
.
__exit__
(
None
,
None
,
None
)
@
pytest
.
mark
.
pre_merge
@
pytest
.
mark
.
pre_merge
@
pytest
.
mark
.
gpu_1
@
pytest
.
mark
.
gpu_1
...
@@ -522,10 +500,7 @@ def test_vllm_indexers_sync(
...
@@ -522,10 +500,7 @@ def test_vllm_indexers_sync(
N_VLLM_WORKERS
=
2
N_VLLM_WORKERS
=
2
try
:
with
VLLMProcess
(
# Start vLLM workers
logger
.
info
(
f
"Starting
{
N_VLLM_WORKERS
}
vLLM workers"
)
vllm_workers
=
VLLMProcess
(
request
,
request
,
vllm_args
=
VLLM_ARGS
,
vllm_args
=
VLLM_ARGS
,
num_workers
=
N_VLLM_WORKERS
,
num_workers
=
N_VLLM_WORKERS
,
...
@@ -533,9 +508,10 @@ def test_vllm_indexers_sync(
...
@@ -533,9 +508,10 @@ def test_vllm_indexers_sync(
request_plane
=
request_plane
,
request_plane
=
request_plane
,
store_backend
=
store_backend
,
store_backend
=
store_backend
,
durable_kv_events
=
durable_kv_events
,
durable_kv_events
=
durable_kv_events
,
)
)
as
vllm_workers
:
# Start vLLM workers
logger
.
info
(
f
"Starting
{
N_VLLM_WORKERS
}
vLLM workers"
)
logger
.
info
(
f
"All vLLM workers using namespace:
{
vllm_workers
.
namespace
}
"
)
logger
.
info
(
f
"All vLLM workers using namespace:
{
vllm_workers
.
namespace
}
"
)
vllm_workers
.
__enter__
()
# Use the common test implementation (creates its own runtimes for each router)
# Use the common test implementation (creates its own runtimes for each router)
# Note: Consumer verification is done inside _test_router_indexers_sync while routers are alive
# Note: Consumer verification is done inside _test_router_indexers_sync while routers are alive
...
@@ -553,7 +529,3 @@ def test_vllm_indexers_sync(
...
@@ -553,7 +529,3 @@ def test_vllm_indexers_sync(
)
)
logger
.
info
(
"vLLM indexers sync test completed successfully"
)
logger
.
info
(
"vLLM indexers sync test completed successfully"
)
finally
:
if
"vllm_workers"
in
locals
():
vllm_workers
.
__exit__
(
None
,
None
,
None
)
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment