Unverified commit 1b8869f5, authored by Keiven C and committed by GitHub
Browse files

refactor: router tests to be pytest-xdist parallel safe (#5005)


Signed-off-by: Keiven Chang <keivenchang@users.noreply.github.com>
Co-authored-by: Keiven Chang <keivenchang@users.noreply.github.com>
parent 9be3df8f
...@@ -556,6 +556,11 @@ def runtime_services_dynamic_ports(request, store_kv, request_plane): ...@@ -556,6 +556,11 @@ def runtime_services_dynamic_ports(request, store_kv, request_plane):
It also sets the NATS_SERVER and ETCD_ENDPOINTS environment variables so that It also sets the NATS_SERVER and ETCD_ENDPOINTS environment variables so that
Dynamo processes can find the services on the dynamic ports. Dynamo processes can find the services on the dynamic ports.
xdist/parallel safety:
- Function-scoped: each test gets its own NATS/etcd instances and ports.
- Each pytest-xdist worker runs tests in a separate process, so env vars do not
leak across workers.
- If store_kv != "etcd", etcd is not started (returns None) - If store_kv != "etcd", etcd is not started (returns None)
- If request_plane != "nats", NATS is not started (returns None) - If request_plane != "nats", NATS is not started (returns None)
...@@ -598,7 +603,14 @@ def runtime_services_session(request, tmp_path_factory): ...@@ -598,7 +603,14 @@ def runtime_services_session(request, tmp_path_factory):
Uses file-based reference counting to coordinate between pytest-xdist worker processes. Uses file-based reference counting to coordinate between pytest-xdist worker processes.
Only the first worker starts services, and only the last worker tears them down. Only the first worker starts services, and only the last worker tears them down.
Test isolation is achieved through unique namespaces (test-namespace-{random-suffix}). WARNING: may not be parallel/xdist safe.
- This fixture shares one NATS + one etcd across many tests (and across xdist workers).
- It is only safe if tests fully isolate state (e.g. unique namespaces) and do not
assume exclusive access to global streams/keys/ports.
- Prefer `runtime_services_dynamic_ports` for true per-test isolation in parallel runs.
TODO: once nothing uses `runtime_services_session`, make the per-test dynamic ports
behavior the default for router/frontend integration tests.
""" """
with SharedNatsServer(request, tmp_path_factory) as nats: with SharedNatsServer(request, tmp_path_factory) as nats:
with SharedEtcdServer(request, tmp_path_factory) as etcd: with SharedEtcdServer(request, tmp_path_factory) as etcd:
......
...@@ -25,6 +25,11 @@ NUM_REQUESTS = 100 ...@@ -25,6 +25,11 @@ NUM_REQUESTS = 100
BLOCK_SIZE = 16 BLOCK_SIZE = 16
def _nats_server() -> str:
    """Return the NATS server URL that the router test helpers connect to.

    The value is read from the ``NATS_SERVER`` environment variable at call
    time, so it follows whatever the currently active fixture exported.

    Returns:
        The value of ``NATS_SERVER`` when it is set to a non-empty string,
        otherwise ``"nats://localhost:4222"``.
    """
    # Prefer dynamically-started NATS from per-test fixtures when present.
    # An empty NATS_SERVER is treated like an unset one: handing "" to the
    # NATS client would only fail later with a less obvious connection error.
    return os.environ.get("NATS_SERVER") or "nats://localhost:4222"
######################################################## ########################################################
# Helper Classes # Helper Classes
######################################################## ########################################################
...@@ -394,7 +399,7 @@ async def check_nats_consumers(namespace: str, expected_count: Optional[int] = N ...@@ -394,7 +399,7 @@ async def check_nats_consumers(namespace: str, expected_count: Optional[int] = N
stream_name = f"{slugified}-kv-events" stream_name = f"{slugified}-kv-events"
logger.info(f"Checking consumers for stream: {stream_name}") logger.info(f"Checking consumers for stream: {stream_name}")
nc = await nats.connect("nats://localhost:4222") nc = await nats.connect(servers=_nats_server())
try: try:
js = nc.jetstream() js = nc.jetstream()
consumer_infos = await js.consumers_info(stream_name) consumer_infos = await js.consumers_info(stream_name)
...@@ -770,7 +775,7 @@ def _test_router_two_routers( ...@@ -770,7 +775,7 @@ def _test_router_two_routers(
logger.info(f"Checking consumers for stream: {stream_name}") logger.info(f"Checking consumers for stream: {stream_name}")
# Connect to NATS and list consumers # Connect to NATS and list consumers
nc = await nats.connect("nats://localhost:4222") nc = await nats.connect(servers=_nats_server())
try: try:
js = nc.jetstream() js = nc.jetstream()
...@@ -1529,29 +1534,30 @@ def _test_router_indexers_sync( ...@@ -1529,29 +1534,30 @@ def _test_router_indexers_sync(
logger.info(f"Verifying NATS object store bucket exists: {expected_bucket}") logger.info(f"Verifying NATS object store bucket exists: {expected_bucket}")
snapshot_verified = False snapshot_verified = False
# Connect to NATS and check object store. This honors per-test NATS instances
# started by fixtures (xdist-safe) instead of assuming localhost:4222.
nc = await nats.connect(servers=_nats_server())
try: try:
# Connect to NATS and check object store js = nc.jetstream()
nc = await nats.connect("nats://localhost:4222") obj_store = await js.object_store(expected_bucket)
try:
js = nc.jetstream()
obj_store = await js.object_store(expected_bucket)
# Try to get the expected file # Try to get the expected file
try: try:
result = await obj_store.get(expected_file) result = await obj_store.get(expected_file)
logger.info( logger.info(
f"✓ Snapshot file '{expected_file}' found in bucket '{expected_bucket}' " f"✓ Snapshot file '{expected_file}' found in bucket '{expected_bucket}' "
f"(size: {len(result.data) if result.data else 0} bytes)" f"(size: {len(result.data) if result.data else 0} bytes)"
) )
snapshot_verified = True snapshot_verified = True
except Exception as e: except Exception as e:
logger.error( logger.error(
f"Snapshot file '{expected_file}' not found in bucket '{expected_bucket}': {e}" f"Snapshot file '{expected_file}' not found in bucket '{expected_bucket}': {e}"
) )
finally:
await nc.close()
except Exception as e: except Exception as e:
logger.error(f"Error checking NATS object store: {e}") logger.error(f"Error checking NATS object store: {e}")
finally:
await nc.close()
# Assert that snapshot was created (threshold=20, sent 25 requests) # Assert that snapshot was created (threshold=20, sent 25 requests)
if not snapshot_verified: if not snapshot_verified:
...@@ -1647,7 +1653,7 @@ def _test_router_indexers_sync( ...@@ -1647,7 +1653,7 @@ def _test_router_indexers_sync(
slugified = component_subject.lower().replace(".", "-").replace("_", "-") slugified = component_subject.lower().replace(".", "-").replace("_", "-")
stream_name = f"{slugified}-kv-events" stream_name = f"{slugified}-kv-events"
nc = await nats.connect("nats://localhost:4222") nc = await nats.connect(servers=_nats_server())
try: try:
js = nc.jetstream() js = nc.jetstream()
consumer_infos = await js.consumers_info(stream_name) consumer_infos = await js.consumers_info(stream_name)
......
# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. # SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0 # SPDX-License-Identifier: Apache-2.0
# Parallelization: Hermetic tests (xdist-safe via dynamic ports + per-test namespaces).
# Tested on: Linux container.
# Combined pre_merge wall time (this file):
# - Serialized: 304.01s.
# - Parallel (-n auto): 34.55s (269.46s saved, 8.80x).
import logging import logging
import os import os
from contextlib import nullcontext
from typing import Any, Dict, Optional from typing import Any, Dict, Optional
import pytest import pytest
from tests.conftest import EtcdServer, NatsServer
from tests.router.common import ( # utilities from tests.router.common import ( # utilities
_test_busy_threshold_endpoint, _test_busy_threshold_endpoint,
_test_python_router_bindings, _test_python_router_bindings,
...@@ -23,6 +27,7 @@ from tests.router.common import ( # utilities ...@@ -23,6 +27,7 @@ from tests.router.common import ( # utilities
) )
from tests.utils.constants import ROUTER_MODEL_NAME from tests.utils.constants import ROUTER_MODEL_NAME
from tests.utils.managed_process import ManagedProcess from tests.utils.managed_process import ManagedProcess
from tests.utils.port_utils import allocate_ports, deallocate_ports
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
...@@ -32,6 +37,7 @@ pytestmark = [ ...@@ -32,6 +37,7 @@ pytestmark = [
pytest.mark.pre_merge, pytest.mark.pre_merge,
pytest.mark.gpu_0, pytest.mark.gpu_0,
pytest.mark.integration, pytest.mark.integration,
pytest.mark.parallel,
pytest.mark.model(MODEL_NAME), pytest.mark.model(MODEL_NAME),
] ]
NUM_MOCKERS = 2 NUM_MOCKERS = 2
...@@ -48,48 +54,20 @@ def get_unique_ports( ...@@ -48,48 +54,20 @@ def get_unique_ports(
request_plane: str = "nats", request_plane: str = "nats",
registration_order: str = "prefill_first", registration_order: str = "prefill_first",
) -> list[int]: ) -> list[int]:
"""Generate unique ports for parallel test execution. """Allocate random free ports for xdist-safe router tests.
Ports are unique based on:
- Test function name (each test gets a base offset)
- Parametrization value (etcd=0, file=50; nats=0, tcp=25; prefill_first=0, decode_first=10)
- Port index (for multi-port tests)
Args: This replaces the previous "test-name offset" scheme with the shared flock-backed
request: Pytest request fixture allocator from `tests.utils.port_utils`, which avoids collisions across pytest-xdist
num_ports: Number of ports needed (1 for single router, 2 for two routers) worker processes.
store_backend: Storage backend parameter ("etcd" or "file")
request_plane: Request plane parameter ("nats" or "tcp")
registration_order: Registration order parameter ("prefill_first" or "decode_first")
Returns: Notes:
List of unique port numbers - The extra parameters are kept for call-site compatibility (they no longer affect
the chosen ports).
- Ports are released at the end of the test via a pytest finalizer.
""" """
# Get test name without parametrization suffix _ = (store_backend, request_plane, registration_order)
test_name = request.node.name.split("[")[0] ports = allocate_ports(num_ports, BASE_PORT)
request.addfinalizer(lambda: deallocate_ports(ports))
# Base offsets per test function (ensures each test gets unique range)
test_offsets = {
"test_mocker_kv_router": 0,
"test_mocker_two_kv_router": 100,
"test_mocker_kv_router_overload_503": 200,
"test_query_instance_id_returns_worker_and_tokens": 300,
"test_router_decisions_disagg": 400,
"test_busy_threshold_endpoint": 500,
}
base_offset = test_offsets.get(test_name, 0)
# Parametrization offset (etcd=0, file=50; nats=0, tcp=25; prefill_first=0, decode_first=10)
store_offset = 0 if store_backend == "etcd" else 50
plane_offset = 0 if request_plane == "nats" else 25
order_offset = 0 if registration_order == "prefill_first" else 10
# Generate ports
ports = [
BASE_PORT + base_offset + store_offset + plane_offset + order_offset + i
for i in range(num_ports)
]
return ports return ports
...@@ -306,8 +284,10 @@ class DisaggMockerProcess: ...@@ -306,8 +284,10 @@ class DisaggMockerProcess:
self._process.__exit__(exc_type, exc_val, exc_tb) self._process.__exit__(exc_type, exc_val, exc_tb)
@pytest.mark.parallel @pytest.mark.timeout(42) # ~3x average (~13.80s), rounded up
def test_mocker_kv_router(request, runtime_services_session, predownload_tokenizers): def test_mocker_kv_router(
request, runtime_services_dynamic_ports, predownload_tokenizers
):
""" """
Test KV router with multiple mocker engine instances. Test KV router with multiple mocker engine instances.
This test doesn't require GPUs and runs quickly for pre-merge validation. This test doesn't require GPUs and runs quickly for pre-merge validation.
...@@ -316,7 +296,7 @@ def test_mocker_kv_router(request, runtime_services_session, predownload_tokeniz ...@@ -316,7 +296,7 @@ def test_mocker_kv_router(request, runtime_services_session, predownload_tokeniz
# runtime_services starts etcd and nats # runtime_services starts etcd and nats
logger.info("Starting mocker KV router test") logger.info("Starting mocker KV router test")
# Create mocker args dictiona: FixtureRequestry: tuple[NatsServer, EtcdServer]: NoneType # Create mocker args dictionary
mocker_args = {"speedup_ratio": SPEEDUP_RATIO, "block_size": BLOCK_SIZE} mocker_args = {"speedup_ratio": SPEEDUP_RATIO, "block_size": BLOCK_SIZE}
try: try:
...@@ -346,11 +326,11 @@ def test_mocker_kv_router(request, runtime_services_session, predownload_tokeniz ...@@ -346,11 +326,11 @@ def test_mocker_kv_router(request, runtime_services_session, predownload_tokeniz
mockers.__exit__(None, None, None) mockers.__exit__(None, None, None)
@pytest.mark.parallel
@pytest.mark.parametrize("store_backend", ["etcd", "file"]) @pytest.mark.parametrize("store_backend", ["etcd", "file"])
@pytest.mark.timeout(60) # ~3x average (~19.86s), rounded up
def test_mocker_two_kv_router( def test_mocker_two_kv_router(
request, request,
runtime_services_session, runtime_services_dynamic_ports,
predownload_tokenizers, predownload_tokenizers,
file_storage_backend, file_storage_backend,
store_backend, store_backend,
...@@ -402,10 +382,10 @@ def test_mocker_two_kv_router( ...@@ -402,10 +382,10 @@ def test_mocker_two_kv_router(
mockers.__exit__(None, None, None) mockers.__exit__(None, None, None)
@pytest.mark.parallel
@pytest.mark.skip(reason="Flaky, temporarily disabled") @pytest.mark.skip(reason="Flaky, temporarily disabled")
@pytest.mark.timeout(60) # ~3x average (~19.86s), rounded up (when enabled)
def test_mocker_kv_router_overload_503( def test_mocker_kv_router_overload_503(
request, runtime_services_session, predownload_tokenizers request, runtime_services_dynamic_ports, predownload_tokenizers
): ):
"""Test that KV router returns 503 when mocker workers are overloaded.""" """Test that KV router returns 503 when mocker workers are overloaded."""
logger.info("Starting mocker KV router overload test for 503 status") logger.info("Starting mocker KV router overload test for 503 status")
...@@ -441,9 +421,9 @@ def test_mocker_kv_router_overload_503( ...@@ -441,9 +421,9 @@ def test_mocker_kv_router_overload_503(
mockers.__exit__(None, None, None) mockers.__exit__(None, None, None)
@pytest.mark.parallel @pytest.mark.timeout(22) # ~3x average (~7.10s), rounded up
def test_kv_push_router_bindings( def test_kv_push_router_bindings(
request, runtime_services_session, predownload_tokenizers request, runtime_services_dynamic_ports, predownload_tokenizers
): ):
"""Test KvPushRouter Python bindings with mocker engines.""" """Test KvPushRouter Python bindings with mocker engines."""
logger.info("Starting KvPushRouter bindings test") logger.info("Starting KvPushRouter bindings test")
...@@ -488,8 +468,10 @@ def test_kv_push_router_bindings( ...@@ -488,8 +468,10 @@ def test_kv_push_router_bindings(
], ],
ids=["jetstream", "file"], # "nats_core" commented out to match commented test case ids=["jetstream", "file"], # "nats_core" commented out to match commented test case
) )
@pytest.mark.timeout(27) # ~3x average (~8.93s), rounded up
def test_indexers_sync( def test_indexers_sync(
request, request,
runtime_services_dynamic_ports,
predownload_tokenizers, predownload_tokenizers,
file_storage_backend, file_storage_backend,
store_backend, store_backend,
...@@ -511,54 +493,52 @@ def test_indexers_sync( ...@@ -511,54 +493,52 @@ def test_indexers_sync(
f"use_nats_core={use_nats_core}, request_plane={request_plane}" f"use_nats_core={use_nats_core}, request_plane={request_plane}"
) )
# Start NATS manually (needed for all variants - KV event sync) # Use the dynamic-port fixture to avoid hardcoded localhost:4222/2379 in parallel runs.
with NatsServer(request) as nats_server: nats_process, _etcd_process = runtime_services_dynamic_ports
# Start etcd if needed
etcd_ctx = EtcdServer(request) if store_backend == "etcd" else nullcontext() # Create mocker args dictionary
with etcd_ctx: mocker_args = {
# Create mocker args dictionary "speedup_ratio": SPEEDUP_RATIO,
mocker_args = { "block_size": BLOCK_SIZE,
"speedup_ratio": SPEEDUP_RATIO, "enable_local_indexer": use_nats_core,
"block_size": BLOCK_SIZE, }
"enable_local_indexer": use_nats_core,
} try:
# Start mocker instances
try: logger.info(f"Starting {NUM_MOCKERS} mocker instances")
# Start mocker instances mockers = MockerProcess(
logger.info(f"Starting {NUM_MOCKERS} mocker instances") request,
mockers = MockerProcess( mocker_args=mocker_args,
request, num_mockers=NUM_MOCKERS,
mocker_args=mocker_args, store_backend=store_backend,
num_mockers=NUM_MOCKERS, request_plane=request_plane,
store_backend=store_backend, )
request_plane=request_plane, logger.info(f"All mockers using endpoint: {mockers.endpoint}")
) mockers.__enter__()
logger.info(f"All mockers using endpoint: {mockers.endpoint}")
mockers.__enter__() # Use the common test implementation (creates its own runtimes for each router)
# Note: Consumer verification is done inside _test_router_indexers_sync while routers are alive
# Use the common test implementation (creates its own runtimes for each router) _test_router_indexers_sync(
# Note: Consumer verification is done inside _test_router_indexers_sync while routers are alive engine_workers=mockers,
_test_router_indexers_sync( block_size=BLOCK_SIZE,
engine_workers=mockers, model_name=MODEL_NAME,
block_size=BLOCK_SIZE, num_workers=NUM_MOCKERS,
model_name=MODEL_NAME, store_backend=store_backend,
num_workers=NUM_MOCKERS, request_plane=request_plane,
store_backend=store_backend, test_nats_interruption=use_nats_core,
request_plane=request_plane, nats_server=nats_process if use_nats_core else None,
test_nats_interruption=use_nats_core, )
nats_server=nats_server if use_nats_core else None,
) logger.info("Indexers sync test completed successfully")
logger.info("Indexers sync test completed successfully") finally:
if "mockers" in locals():
finally: mockers.__exit__(None, None, None)
if "mockers" in locals():
mockers.__exit__(None, None, None)
@pytest.mark.timeout(42) # ~3x average (~13.80s), rounded up
@pytest.mark.parallel
def test_query_instance_id_returns_worker_and_tokens( def test_query_instance_id_returns_worker_and_tokens(
request, runtime_services_session, predownload_tokenizers request, runtime_services_dynamic_ports, predownload_tokenizers
): ):
"""Test query_instance_id annotation with mocker engines.""" """Test query_instance_id annotation with mocker engines."""
logger.info("Starting KV router query_instance_id annotation test") logger.info("Starting KV router query_instance_id annotation test")
...@@ -591,10 +571,10 @@ def test_query_instance_id_returns_worker_and_tokens( ...@@ -591,10 +571,10 @@ def test_query_instance_id_returns_worker_and_tokens(
mockers.__exit__(None, None, None) mockers.__exit__(None, None, None)
@pytest.mark.parallel
@pytest.mark.parametrize("use_nats_core", [False, True], ids=["jetstream", "nats_core"]) @pytest.mark.parametrize("use_nats_core", [False, True], ids=["jetstream", "nats_core"])
@pytest.mark.timeout(29) # ~3x average (~9.55s), rounded up
def test_router_decisions( def test_router_decisions(
request, runtime_services_session, predownload_tokenizers, use_nats_core request, runtime_services_dynamic_ports, predownload_tokenizers, use_nats_core
): ):
"""Validate KV cache prefix reuse and dp_rank routing by sending progressive requests with overlapping prefixes. """Validate KV cache prefix reuse and dp_rank routing by sending progressive requests with overlapping prefixes.
...@@ -641,10 +621,10 @@ def test_router_decisions( ...@@ -641,10 +621,10 @@ def test_router_decisions(
mockers.__exit__(None, None, None) mockers.__exit__(None, None, None)
@pytest.mark.parallel
@pytest.mark.parametrize("registration_order", ["prefill_first", "decode_first"]) @pytest.mark.parametrize("registration_order", ["prefill_first", "decode_first"])
@pytest.mark.timeout(59) # ~3x average (~19.51s), rounded up
def test_router_decisions_disagg( def test_router_decisions_disagg(
request, runtime_services_session, predownload_tokenizers, registration_order request, runtime_services_dynamic_ports, predownload_tokenizers, registration_order
): ):
"""Validate KV cache prefix reuse in disaggregated prefill-decode setup. """Validate KV cache prefix reuse in disaggregated prefill-decode setup.
...@@ -742,10 +722,10 @@ def test_router_decisions_disagg( ...@@ -742,10 +722,10 @@ def test_router_decisions_disagg(
prefill_workers.__exit__(None, None, None) prefill_workers.__exit__(None, None, None)
@pytest.mark.parallel
@pytest.mark.parametrize("request_plane", ["nats", "tcp"], indirect=True) @pytest.mark.parametrize("request_plane", ["nats", "tcp"], indirect=True)
@pytest.mark.timeout(39) # ~3x average (~12.84s), rounded up
def test_busy_threshold_endpoint( def test_busy_threshold_endpoint(
request, runtime_services_session, predownload_tokenizers, request_plane request, runtime_services_dynamic_ports, predownload_tokenizers, request_plane
): ):
"""Test that the /busy_threshold endpoint can be hit and responds correctly. """Test that the /busy_threshold endpoint can be hit and responds correctly.
......
# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. # SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0 # SPDX-License-Identifier: Apache-2.0
# Timing notes (measured in an SGLang-enabled container):
# - GPU-1 subset (`-m "gpu_1"`): 92.35s total for 2 tests (+ 1 skipped).
# These tests load a real model and can be slow/flaky when GPU resources are contended,
# so we set explicit pytest timeouts to fail fast on hangs (see per-test markers below).
import logging import logging
import os import os
import time import time
...@@ -14,7 +19,9 @@ from tests.router.common import ( # utilities ...@@ -14,7 +19,9 @@ from tests.router.common import ( # utilities
generate_random_suffix, generate_random_suffix,
get_runtime, get_runtime,
) )
from tests.utils.constants import DefaultPort
from tests.utils.managed_process import ManagedProcess from tests.utils.managed_process import ManagedProcess
from tests.utils.port_utils import allocate_ports, deallocate_ports
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
...@@ -26,13 +33,17 @@ pytestmark = [ ...@@ -26,13 +33,17 @@ pytestmark = [
pytest.mark.model(MODEL_NAME), pytest.mark.model(MODEL_NAME),
] ]
SPEEDUP_RATIO = 10.0 SPEEDUP_RATIO = 10.0
PORTS = [
8011,
8022,
] # Frontend ports: use PORTS[0] for single router, PORTS for multi-router
NUM_REQUESTS = 10 NUM_REQUESTS = 10
PAGE_SIZE = 16 # SGLang uses "page_size" instead of "block_size" PAGE_SIZE = 16 # SGLang uses "page_size" instead of "block_size"
def allocate_frontend_ports(request, count: int) -> list[int]:
    """Reserve ``count`` frontend ports for the calling test (xdist-safe).

    Args:
        request: Pytest request object; used to register the cleanup finalizer.
        count: Number of ports to reserve.

    Returns:
        The list produced by ``allocate_ports(count, DefaultPort.FRONTEND.value)``.
    """
    reserved = allocate_ports(count, DefaultPort.FRONTEND.value)

    def _release():
        # Hand the same list back to the allocator once the test is torn down.
        return deallocate_ports(reserved)

    request.addfinalizer(_release)
    return reserved
# Shared test payload for all tests # Shared test payload for all tests
TEST_PAYLOAD: Dict[str, Any] = { TEST_PAYLOAD: Dict[str, Any] = {
"model": MODEL_NAME, "model": MODEL_NAME,
...@@ -291,8 +302,9 @@ class SGLangProcess: ...@@ -291,8 +302,9 @@ class SGLangProcess:
@pytest.mark.pre_merge @pytest.mark.pre_merge
@pytest.mark.gpu_1 @pytest.mark.gpu_1
@pytest.mark.timeout(150) # ~3x average (~46s/test), rounded up
def test_sglang_kv_router_basic( def test_sglang_kv_router_basic(
request, runtime_services, predownload_models, set_ucx_tls_no_mm request, runtime_services_dynamic_ports, predownload_models, set_ucx_tls_no_mm
): ):
""" """
Quick e2e sanity test for KV router with SGLang engine instances. Quick e2e sanity test for KV router with SGLang engine instances.
...@@ -315,11 +327,12 @@ def test_sglang_kv_router_basic( ...@@ -315,11 +327,12 @@ def test_sglang_kv_router_basic(
sglang_workers.__enter__() sglang_workers.__enter__()
# Run basic router test (starts router internally and waits for workers to be ready) # Run basic router test (starts router internally and waits for workers to be ready)
frontend_port = allocate_frontend_ports(request, 1)[0]
_test_router_basic( _test_router_basic(
engine_workers=sglang_workers, engine_workers=sglang_workers,
block_size=PAGE_SIZE, block_size=PAGE_SIZE,
request=request, request=request,
frontend_port=PORTS[0], frontend_port=frontend_port,
test_payload=TEST_PAYLOAD, test_payload=TEST_PAYLOAD,
num_requests=NUM_REQUESTS, num_requests=NUM_REQUESTS,
frontend_timeout=180, # 3 minutes should be plenty for TinyLlama frontend_timeout=180, # 3 minutes should be plenty for TinyLlama
...@@ -336,7 +349,7 @@ def test_sglang_kv_router_basic( ...@@ -336,7 +349,7 @@ def test_sglang_kv_router_basic(
@pytest.mark.skip(reason="Broken by sglang changes") @pytest.mark.skip(reason="Broken by sglang changes")
# TODO: Re-enable this test once https://github.com/sgl-project/sglang/pull/14934 is merged # TODO: Re-enable this test once https://github.com/sgl-project/sglang/pull/14934 is merged
def test_router_decisions_sglang_multiple_workers( def test_router_decisions_sglang_multiple_workers(
request, runtime_services, predownload_models, set_ucx_tls_no_mm request, runtime_services_dynamic_ports, predownload_models, set_ucx_tls_no_mm
): ):
# runtime_services starts etcd and nats # runtime_services starts etcd and nats
logger.info("Starting SGLang router prefix reuse test with two workers") logger.info("Starting SGLang router prefix reuse test with two workers")
...@@ -373,8 +386,9 @@ def test_router_decisions_sglang_multiple_workers( ...@@ -373,8 +386,9 @@ def test_router_decisions_sglang_multiple_workers(
@pytest.mark.gpu_2 @pytest.mark.gpu_2
@pytest.mark.timeout(600) # 10 min max (multi-GPU + DP startup variance)
def test_router_decisions_sglang_dp( def test_router_decisions_sglang_dp(
request, runtime_services, predownload_models, set_ucx_tls_no_mm request, runtime_services_dynamic_ports, predownload_models, set_ucx_tls_no_mm
): ):
"""Validate KV cache prefix reuse with SGLang by sending progressive requests with overlapping prefixes. """Validate KV cache prefix reuse with SGLang by sending progressive requests with overlapping prefixes.
Same flow as test_router_decisions_sglang_multiple_workers; force first request to (worker_id, dp_rank=1). Same flow as test_router_decisions_sglang_multiple_workers; force first request to (worker_id, dp_rank=1).
...@@ -417,8 +431,9 @@ def test_router_decisions_sglang_dp( ...@@ -417,8 +431,9 @@ def test_router_decisions_sglang_dp(
@pytest.mark.pre_merge @pytest.mark.pre_merge
@pytest.mark.gpu_1 @pytest.mark.gpu_1
@pytest.mark.timeout(150) # ~3x average (~46s/test), rounded up
def test_sglang_indexers_sync( def test_sglang_indexers_sync(
request, runtime_services, predownload_models, set_ucx_tls_no_mm request, runtime_services_dynamic_ports, predownload_models, set_ucx_tls_no_mm
): ):
""" """
Test that two KV routers have synchronized indexer states after processing requests Test that two KV routers have synchronized indexer states after processing requests
......
# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. # SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0 # SPDX-License-Identifier: Apache-2.0
# Timing notes (measured in a TRT-LLM-enabled container):
# - GPU-1 subset (`-m "gpu_1"`): 136.36s total for 3 tests.
# These tests load a real model and can be slow/flaky when GPU resources are contended,
# so we set explicit pytest timeouts to fail fast on hangs (see per-test markers below).
import logging import logging
import os import os
import time import time
...@@ -14,7 +19,9 @@ from tests.router.common import ( # utilities ...@@ -14,7 +19,9 @@ from tests.router.common import ( # utilities
generate_random_suffix, generate_random_suffix,
get_runtime, get_runtime,
) )
from tests.utils.constants import DefaultPort
from tests.utils.managed_process import ManagedProcess from tests.utils.managed_process import ManagedProcess
from tests.utils.port_utils import allocate_ports, deallocate_ports
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
...@@ -26,12 +33,16 @@ pytestmark = [ ...@@ -26,12 +33,16 @@ pytestmark = [
pytest.mark.trtllm, pytest.mark.trtllm,
pytest.mark.model(MODEL_NAME), pytest.mark.model(MODEL_NAME),
] ]
PORTS = [
8011,
8022,
] # Frontend ports: use PORTS[0] for single router, PORTS for multi-router
NUM_REQUESTS = 10 NUM_REQUESTS = 10
def allocate_frontend_ports(request, count: int) -> list[int]:
    """Reserve ``count`` frontend ports for the calling test (xdist-safe).

    Args:
        request: Pytest request object; used to register the cleanup finalizer.
        count: Number of ports to reserve.

    Returns:
        The list produced by ``allocate_ports(count, DefaultPort.FRONTEND.value)``.
    """
    reserved = allocate_ports(count, DefaultPort.FRONTEND.value)

    def _release():
        # Hand the same list back to the allocator once the test is torn down.
        return deallocate_ports(reserved)

    request.addfinalizer(_release)
    return reserved
# Shared test payload for all tests # Shared test payload for all tests
TEST_PAYLOAD: Dict[str, Any] = { TEST_PAYLOAD: Dict[str, Any] = {
"model": MODEL_NAME, "model": MODEL_NAME,
...@@ -265,8 +276,9 @@ class TRTLLMProcess: ...@@ -265,8 +276,9 @@ class TRTLLMProcess:
@pytest.mark.pre_merge @pytest.mark.pre_merge
@pytest.mark.gpu_1 @pytest.mark.gpu_1
@pytest.mark.timeout(150) # ~3x average (~45s/test), rounded up
def test_trtllm_kv_router_basic( def test_trtllm_kv_router_basic(
request, runtime_services, predownload_models, set_ucx_tls_no_mm request, runtime_services_dynamic_ports, predownload_models, set_ucx_tls_no_mm
): ):
""" """
Quick e2e sanity test for KV router with TRT-LLM engine instances. Quick e2e sanity test for KV router with TRT-LLM engine instances.
...@@ -289,11 +301,12 @@ def test_trtllm_kv_router_basic( ...@@ -289,11 +301,12 @@ def test_trtllm_kv_router_basic(
trtllm_workers.__enter__() trtllm_workers.__enter__()
# Run basic router test (starts router internally and waits for workers to be ready) # Run basic router test (starts router internally and waits for workers to be ready)
frontend_port = allocate_frontend_ports(request, 1)[0]
_test_router_basic( _test_router_basic(
engine_workers=trtllm_workers, engine_workers=trtllm_workers,
block_size=TRTLLM_BLOCK_SIZE, block_size=TRTLLM_BLOCK_SIZE,
request=request, request=request,
frontend_port=PORTS[0], frontend_port=frontend_port,
test_payload=TEST_PAYLOAD, test_payload=TEST_PAYLOAD,
num_requests=NUM_REQUESTS, num_requests=NUM_REQUESTS,
frontend_timeout=180, # 3 minutes should be plenty for TinyLlama frontend_timeout=180, # 3 minutes should be plenty for TinyLlama
...@@ -307,8 +320,9 @@ def test_trtllm_kv_router_basic( ...@@ -307,8 +320,9 @@ def test_trtllm_kv_router_basic(
@pytest.mark.pre_merge @pytest.mark.pre_merge
@pytest.mark.gpu_1 @pytest.mark.gpu_1
@pytest.mark.timeout(150) # ~3x average (~45s/test), rounded up
def test_router_decisions_trtllm_multiple_workers( def test_router_decisions_trtllm_multiple_workers(
request, runtime_services, predownload_models, set_ucx_tls_no_mm request, runtime_services_dynamic_ports, predownload_models, set_ucx_tls_no_mm
): ):
# runtime_services starts etcd and nats # runtime_services starts etcd and nats
logger.info("Starting TRT-LLM router prefix reuse test with two workers") logger.info("Starting TRT-LLM router prefix reuse test with two workers")
...@@ -353,8 +367,9 @@ def test_router_decisions_trtllm_multiple_workers( ...@@ -353,8 +367,9 @@ def test_router_decisions_trtllm_multiple_workers(
@pytest.mark.pre_merge @pytest.mark.pre_merge
@pytest.mark.gpu_1 @pytest.mark.gpu_1
@pytest.mark.timeout(150) # ~3x average (~45s/test), rounded up
def test_trtllm_indexers_sync( def test_trtllm_indexers_sync(
request, runtime_services, predownload_models, set_ucx_tls_no_mm request, runtime_services_dynamic_ports, predownload_models, set_ucx_tls_no_mm
): ):
""" """
Test that two KV routers have synchronized indexer states after processing requests Test that two KV routers have synchronized indexer states after processing requests
......
# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. # SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0 # SPDX-License-Identifier: Apache-2.0
# Timing notes (measured locally):
# - GPU-1 subset (`-m "gpu_1 and not gpu_2"`): 130.43s total for 3 tests.
# These tests load a real model and can be slow/flaky when GPU resources are contended,
# so we set explicit pytest timeouts to fail fast on hangs (see per-test markers below).
import logging import logging
import os import os
import time import time
...@@ -14,7 +19,9 @@ from tests.router.common import ( # utilities ...@@ -14,7 +19,9 @@ from tests.router.common import ( # utilities
generate_random_suffix, generate_random_suffix,
get_runtime, get_runtime,
) )
from tests.utils.constants import DefaultPort
from tests.utils.managed_process import ManagedProcess from tests.utils.managed_process import ManagedProcess
from tests.utils.port_utils import allocate_ports, deallocate_ports
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
...@@ -26,13 +33,17 @@ pytestmark = [ ...@@ -26,13 +33,17 @@ pytestmark = [
pytest.mark.model(MODEL_NAME), pytest.mark.model(MODEL_NAME),
] ]
SPEEDUP_RATIO = 10.0 SPEEDUP_RATIO = 10.0
PORTS = [
8011,
8022,
] # Frontend ports: use PORTS[0] for single router, PORTS for multi-router
NUM_REQUESTS = 10 NUM_REQUESTS = 10
BLOCK_SIZE = 16 BLOCK_SIZE = 16
def allocate_frontend_ports(request, count: int) -> list[int]:
    """Reserve frontend ports for one test and release them at teardown.

    Intended to keep frontend ports from colliding when tests run in
    parallel under pytest-xdist.

    Args:
        request: The pytest ``request`` fixture of the calling test; used to
            register the cleanup finalizer.
        count: Number of ports to allocate.

    Returns:
        The ports returned by ``allocate_ports``.
    """
    # NOTE(review): the second argument presumably seeds/anchors the port
    # search at the default frontend port -- confirm against port_utils.
    frontend_ports = allocate_ports(count, DefaultPort.FRONTEND.value)

    def _release_ports() -> None:
        # Hand the ports back once the requesting test has finished.
        deallocate_ports(frontend_ports)

    request.addfinalizer(_release_ports)
    return frontend_ports
# Shared test payload for all tests # Shared test payload for all tests
TEST_PAYLOAD: Dict[str, Any] = { TEST_PAYLOAD: Dict[str, Any] = {
"model": MODEL_NAME, "model": MODEL_NAME,
...@@ -306,8 +317,9 @@ class VLLMProcess: ...@@ -306,8 +317,9 @@ class VLLMProcess:
@pytest.mark.pre_merge @pytest.mark.pre_merge
@pytest.mark.gpu_1 @pytest.mark.gpu_1
@pytest.mark.timeout(150) # ~3x average (~43s/test), rounded up
def test_vllm_kv_router_basic( def test_vllm_kv_router_basic(
request, runtime_services, predownload_models, set_ucx_tls_no_mm request, runtime_services_dynamic_ports, predownload_models, set_ucx_tls_no_mm
): ):
""" """
Quick e2e sanity test for KV router with vLLM engine instances. Quick e2e sanity test for KV router with vLLM engine instances.
...@@ -330,11 +342,12 @@ def test_vllm_kv_router_basic( ...@@ -330,11 +342,12 @@ def test_vllm_kv_router_basic(
vllm_workers.__enter__() vllm_workers.__enter__()
# Run basic router test (starts router internally and waits for workers to be ready) # Run basic router test (starts router internally and waits for workers to be ready)
frontend_port = allocate_frontend_ports(request, 1)[0]
_test_router_basic( _test_router_basic(
engine_workers=vllm_workers, engine_workers=vllm_workers,
block_size=BLOCK_SIZE, block_size=BLOCK_SIZE,
request=request, request=request,
frontend_port=PORTS[0], frontend_port=frontend_port,
test_payload=TEST_PAYLOAD, test_payload=TEST_PAYLOAD,
num_requests=NUM_REQUESTS, num_requests=NUM_REQUESTS,
frontend_timeout=180, # 3 minutes should be plenty for TinyLlama frontend_timeout=180, # 3 minutes should be plenty for TinyLlama
...@@ -348,8 +361,9 @@ def test_vllm_kv_router_basic( ...@@ -348,8 +361,9 @@ def test_vllm_kv_router_basic(
@pytest.mark.pre_merge @pytest.mark.pre_merge
@pytest.mark.gpu_1 @pytest.mark.gpu_1
@pytest.mark.timeout(150) # ~3x average (~43s/test), rounded up
def test_router_decisions_vllm_multiple_workers( def test_router_decisions_vllm_multiple_workers(
request, runtime_services, predownload_models, set_ucx_tls_no_mm request, runtime_services_dynamic_ports, predownload_models, set_ucx_tls_no_mm
): ):
# runtime_services starts etcd and nats # runtime_services starts etcd and nats
logger.info("Starting vLLM router prefix reuse test with two workers") logger.info("Starting vLLM router prefix reuse test with two workers")
...@@ -386,8 +400,9 @@ def test_router_decisions_vllm_multiple_workers( ...@@ -386,8 +400,9 @@ def test_router_decisions_vllm_multiple_workers(
@pytest.mark.gpu_2 @pytest.mark.gpu_2
@pytest.mark.timeout(600) # 10 min max (multi-GPU + DP startup variance)
def test_router_decisions_vllm_dp( def test_router_decisions_vllm_dp(
request, runtime_services, predownload_models, set_ucx_tls_no_mm request, runtime_services_dynamic_ports, predownload_models, set_ucx_tls_no_mm
): ):
"""Validate KV cache prefix reuse with vLLM by sending progressive requests with overlapping prefixes. """Validate KV cache prefix reuse with vLLM by sending progressive requests with overlapping prefixes.
Same flow as test_router_decisions_vllm_multiple_workers; force first request to (worker_id, dp_rank=1). Same flow as test_router_decisions_vllm_multiple_workers; force first request to (worker_id, dp_rank=1).
...@@ -430,8 +445,9 @@ def test_router_decisions_vllm_dp( ...@@ -430,8 +445,9 @@ def test_router_decisions_vllm_dp(
@pytest.mark.pre_merge @pytest.mark.pre_merge
@pytest.mark.gpu_1 @pytest.mark.gpu_1
@pytest.mark.timeout(150) # ~3x average (~43s/test), rounded up
def test_vllm_indexers_sync( def test_vllm_indexers_sync(
request, runtime_services, predownload_models, set_ucx_tls_no_mm request, runtime_services_dynamic_ports, predownload_models, set_ucx_tls_no_mm
): ):
""" """
Test that two KV routers have synchronized indexer states after processing requests Test that two KV routers have synchronized indexer states after processing requests
......
...@@ -602,10 +602,6 @@ class DynamoFrontendProcess(ManagedProcess): ...@@ -602,10 +602,6 @@ class DynamoFrontendProcess(ManagedProcess):
# - tests/frontend/test_completion_mocker_engine.py # - tests/frontend/test_completion_mocker_engine.py
# - tests/frontend/grpc/test_tensor_parameters.py # - tests/frontend/grpc/test_tensor_parameters.py
# - tests/frontend/grpc/test_tensor_mocker_engine.py # - tests/frontend/grpc/test_tensor_mocker_engine.py
# - tests/router/common.py
# - tests/router/test_router_e2e_with_vllm.py
# - tests/router/test_router_e2e_with_sglang.py
# - tests/router/test_router_e2e_with_trtllm.py
# - tests/fault_tolerance/cancellation/utils.py # - tests/fault_tolerance/cancellation/utils.py
# - tests/fault_tolerance/migration/utils.py # - tests/fault_tolerance/migration/utils.py
# - tests/fault_tolerance/etcd_ha/utils.py # - tests/fault_tolerance/etcd_ha/utils.py
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment