Unverified Commit 6fe2152b authored by Karen Chung's avatar Karen Chung Committed by GitHub
Browse files

test: refactor router e2e tests to use context managers for process lifecycle (#6088)

parent 1cd3b724
......@@ -528,6 +528,7 @@ async def send_request_via_python_kv_router(
)
# Retry loop sending request to worker with exponential backoff
stream = None
for attempt in range(max_retries + 1):
try:
logger.debug(f"Sending request to {log_message} (attempt {attempt + 1})")
......@@ -557,6 +558,11 @@ async def send_request_via_python_kv_router(
f"Failed to connect to workers after {max_retries + 1} attempts"
) from e
if stream is None:
raise RuntimeError(
f"Failed to get a valid stream from workers after {max_retries + 1} attempts"
)
# Collect tokens and worker IDs from the SSE stream
generated_tokens = []
prefill_worker_id: Optional[int] = None
......@@ -653,18 +659,16 @@ def _test_router_basic(
AssertionError: If requests fail or frontend doesn't become ready
TimeoutError: If frontend doesn't become ready within timeout
"""
try:
with KVRouterProcess(
request,
block_size,
frontend_port,
engine_workers.namespace,
store_backend,
request_plane=request_plane,
):
# Start KV router frontend
logger.info(f"Starting KV router frontend on port {frontend_port}")
kv_router = KVRouterProcess(
request,
block_size,
frontend_port,
engine_workers.namespace,
store_backend,
request_plane=request_plane,
)
kv_router.__enter__()
frontend_url = f"http://localhost:{frontend_port}"
......@@ -690,10 +694,6 @@ def _test_router_basic(
logger.info(f"Successfully completed {num_requests} requests")
finally:
if "kv_router" in locals():
kv_router.__exit__(None, None, None)
def _test_router_two_routers(
engine_workers,
......@@ -1036,13 +1036,11 @@ def _test_router_query_instance_id(
AssertionError: If annotation response structure is incorrect or contains generation content
"""
try:
with KVRouterProcess(
request, block_size, frontend_port, engine_workers.namespace, store_backend
):
# Start KV router (frontend)
logger.info(f"Starting KV router frontend on port {frontend_port}")
kv_router = KVRouterProcess(
request, block_size, frontend_port, engine_workers.namespace, store_backend
)
kv_router.__enter__()
url = f"http://localhost:{frontend_port}/v1/chat/completions"
......@@ -1164,10 +1162,6 @@ def _test_router_query_instance_id(
logger.info(f"Decode Worker ID: {result['decode_worker_id']}")
logger.info(f"Token count: {result['token_count']}")
finally:
if "kv_router" in locals():
kv_router.__exit__(None, None, None)
def _test_router_overload_503(
engine_workers,
......@@ -1194,42 +1188,17 @@ def _test_router_overload_503(
AssertionError: If 503 response is not received when expected
"""
try:
logger.info(
f"Starting KV router frontend on port {frontend_port} with limited resources"
)
# Custom command for router with limited block size
command = [
"python",
"-m",
"dynamo.frontend",
"--active-decode-blocks-threshold",
str(blocks_threshold),
"--kv-cache-block-size",
str(block_size),
"--router-mode",
"kv",
"--http-port",
str(frontend_port),
]
kv_router = ManagedProcess(
command=command,
timeout=60,
display_output=True,
health_check_ports=[frontend_port],
health_check_urls=[
(
f"http://localhost:{frontend_port}/v1/models",
lambda r: r.status_code == 200,
)
],
log_dir=request.node.name,
terminate_all_matching_process_names=False,
)
kv_router.__enter__()
logger.info(
f"Starting KV router frontend on port {frontend_port} with limited resources"
)
with KVRouterProcess(
request=request,
block_size=block_size,
frontend_port=frontend_port,
namespace=engine_workers.namespace,
blocks_threshold=blocks_threshold,
):
url = f"http://localhost:{frontend_port}/v1/chat/completions"
# Custom payload for 503 test with more tokens to consume resources
......@@ -1325,10 +1294,6 @@ def _test_router_overload_503(
logger.info("Successfully verified 503 response when all workers are busy")
finally:
if "kv_router" in locals():
kv_router.__exit__(None, None, None)
def _test_router_indexers_sync(
engine_workers,
......@@ -1727,23 +1692,21 @@ def _test_router_decisions_disagg(
AssertionError: If prefill_worker_ids differ across requests (prefix reuse failure)
AssertionError: If prefill_worker_id is in decode_worker_ids (not true disagg)
"""
try:
with KVRouterProcess(
request,
block_size,
frontend_port,
decode_workers.namespace,
store_backend,
enforce_disagg=True,
request_plane=request_plane,
durable_kv_events=durable_kv_events,
):
# Start KV router frontend - uses decode_workers namespace for discovery
# The frontend will auto-discover both prefill and decode workers
logger.info(
f"Starting KV router frontend on port {frontend_port} for disagg test"
)
kv_router = KVRouterProcess(
request,
block_size,
frontend_port,
decode_workers.namespace,
store_backend,
enforce_disagg=True,
request_plane=request_plane,
durable_kv_events=durable_kv_events,
)
kv_router.__enter__()
frontend_url = f"http://localhost:{frontend_port}"
chat_url = f"{frontend_url}/v1/chat/completions"
......@@ -1908,10 +1871,6 @@ def _test_router_decisions_disagg(
f" - Prefill worker is NOT in decode worker set {unique_decode_ids} (true disagg)"
)
finally:
if "kv_router" in locals():
kv_router.__exit__(None, None, None)
def _test_router_decisions(
engine_workers,
......@@ -2190,20 +2149,18 @@ def _test_busy_threshold_endpoint(
initial_active_decode_blocks_threshold = 0.9
initial_active_prefill_tokens_threshold = 1000 # Literal token count threshold
try:
with KVRouterProcess(
request,
block_size,
frontend_port,
engine_workers.namespace,
store_backend,
blocks_threshold=initial_active_decode_blocks_threshold,
tokens_threshold=initial_active_prefill_tokens_threshold,
request_plane=request_plane,
):
# Start KV router frontend with initial thresholds to create monitor
logger.info(f"Starting KV router frontend on port {frontend_port}")
kv_router = KVRouterProcess(
request,
block_size,
frontend_port,
engine_workers.namespace,
store_backend,
blocks_threshold=initial_active_decode_blocks_threshold,
tokens_threshold=initial_active_prefill_tokens_threshold,
request_plane=request_plane,
)
kv_router.__enter__()
frontend_url = f"http://localhost:{frontend_port}"
busy_threshold_url = f"{frontend_url}/busy_threshold"
......@@ -2464,7 +2421,3 @@ def _test_busy_threshold_endpoint(
logger.info("All busy_threshold endpoint tests passed!")
asyncio.run(test_busy_threshold_api())
finally:
if "kv_router" in locals():
kv_router.__exit__(None, None, None)
......@@ -351,17 +351,15 @@ def test_mocker_kv_router(
"durable_kv_events": durable_kv_events,
}
try:
with MockerProcess(
request,
mocker_args=mocker_args,
num_mockers=NUM_MOCKERS,
request_plane=request_plane,
) as mockers:
# Start mocker instances with the new CLI interface
logger.info(f"Starting {NUM_MOCKERS} mocker instances")
mockers = MockerProcess(
request,
mocker_args=mocker_args,
num_mockers=NUM_MOCKERS,
request_plane=request_plane,
)
logger.info(f"All mockers using endpoint: {mockers.endpoint}")
mockers.__enter__()
# Get unique port for this test
frontend_port = get_unique_ports(
......@@ -379,10 +377,6 @@ def test_mocker_kv_router(
request_plane=request_plane,
)
finally:
if "mockers" in locals():
mockers.__exit__(None, None, None)
@pytest.mark.parametrize("store_backend", ["etcd", "file"])
@pytest.mark.parametrize(
......@@ -415,17 +409,15 @@ def test_mocker_two_kv_router(
"durable_kv_events": durable_kv_events,
}
try:
with MockerProcess(
request,
mocker_args=mocker_args,
num_mockers=NUM_MOCKERS,
store_backend=store_backend,
) as mockers:
# Start mocker instances with the new CLI interface
logger.info(f"Starting {NUM_MOCKERS} mocker instances")
mockers = MockerProcess(
request,
mocker_args=mocker_args,
num_mockers=NUM_MOCKERS,
store_backend=store_backend,
)
logger.info(f"All mockers using endpoint: {mockers.endpoint}")
mockers.__enter__()
# Get unique ports for this test (2 ports for two routers)
router_ports = get_unique_ports(
......@@ -444,10 +436,6 @@ def test_mocker_two_kv_router(
skip_consumer_verification=not durable_kv_events, # Skip JetStream checks in NATS Core mode
)
finally:
if "mockers" in locals():
mockers.__exit__(None, None, None)
@pytest.mark.skip(reason="Flaky, temporarily disabled")
@pytest.mark.parametrize(
......@@ -467,12 +455,10 @@ def test_mocker_kv_router_overload_503(
"durable_kv_events": durable_kv_events,
}
try:
with MockerProcess(request, mocker_args=mocker_args, num_mockers=1) as mockers:
# Start single mocker instance with limited resources
logger.info("Starting single mocker instance with limited resources")
mockers = MockerProcess(request, mocker_args=mocker_args, num_mockers=1)
logger.info(f"Mocker using endpoint: {mockers.endpoint}")
mockers.__enter__()
# Get unique port for this test
frontend_port = get_unique_ports(request, num_ports=1)[0]
......@@ -487,10 +473,6 @@ def test_mocker_kv_router_overload_503(
blocks_threshold=0.2,
)
finally:
if "mockers" in locals():
mockers.__exit__(None, None, None)
@pytest.mark.timeout(90) # bumped for xdist contention (was 22s; ~7.10s serial avg)
@pytest.mark.parametrize("request_plane", ["nats", "tcp"], indirect=True)
......@@ -513,17 +495,15 @@ def test_kv_push_router_bindings(
"durable_kv_events": durable_kv_events,
}
try:
with MockerProcess(
request,
mocker_args=mocker_args,
num_mockers=NUM_MOCKERS,
request_plane=request_plane,
) as mockers:
# Start mocker instances
logger.info(f"Starting {NUM_MOCKERS} mocker instances")
mockers = MockerProcess(
request,
mocker_args=mocker_args,
num_mockers=NUM_MOCKERS,
request_plane=request_plane,
)
logger.info(f"All mockers using endpoint: {mockers.endpoint}")
mockers.__enter__()
# Get runtime and create endpoint
runtime = get_runtime(request_plane=request_plane)
......@@ -540,10 +520,6 @@ def test_kv_push_router_bindings(
num_workers=NUM_MOCKERS,
)
finally:
if "mockers" in locals():
mockers.__exit__(None, None, None)
@pytest.mark.parametrize(
"store_backend,durable_kv_events,request_plane",
......@@ -596,18 +572,16 @@ def test_indexers_sync(
"dp_size": 2,
}
try:
with MockerProcess(
request,
mocker_args=mocker_args,
num_mockers=NUM_MOCKERS,
store_backend=store_backend,
request_plane=request_plane,
) as mockers:
# Start mocker instances (2 workers x 2 DP ranks = 4 independent event streams)
logger.info(f"Starting {NUM_MOCKERS} mocker instances with dp_size=2")
mockers = MockerProcess(
request,
mocker_args=mocker_args,
num_mockers=NUM_MOCKERS,
store_backend=store_backend,
request_plane=request_plane,
)
logger.info(f"All mockers using endpoint: {mockers.endpoint}")
mockers.__enter__()
# Use the common test implementation (creates its own runtimes for each router)
# Note: Consumer verification is done inside _test_router_indexers_sync while routers are alive
......@@ -626,10 +600,6 @@ def test_indexers_sync(
logger.info("Indexers sync test completed successfully")
finally:
if "mockers" in locals():
mockers.__exit__(None, None, None)
@pytest.mark.timeout(120) # bumped for xdist contention (was 42s; ~13.80s serial avg)
@pytest.mark.parametrize(
......@@ -648,14 +618,12 @@ def test_query_instance_id_returns_worker_and_tokens(
}
os.makedirs(request.node.name, exist_ok=True)
try:
with MockerProcess(
request, mocker_args=mocker_args, num_mockers=NUM_MOCKERS
) as mockers:
# Start mocker instances
logger.info(f"Starting {NUM_MOCKERS} mocker instances")
mockers = MockerProcess(
request, mocker_args=mocker_args, num_mockers=NUM_MOCKERS
)
logger.info(f"All mockers using endpoint: {mockers.endpoint}")
mockers.__enter__()
# Get unique port for this test
frontend_port = get_unique_ports(request, num_ports=1)[0]
......@@ -669,10 +637,6 @@ def test_query_instance_id_returns_worker_and_tokens(
test_payload=TEST_PAYLOAD,
)
finally:
if "mockers" in locals():
mockers.__exit__(None, None, None)
@pytest.mark.timeout(90) # bumped for xdist contention (was 29s; ~9.55s serial avg)
@pytest.mark.parametrize("request_plane", ["nats", "tcp"], indirect=True)
......@@ -716,18 +680,15 @@ def test_router_decisions(
"durable_kv_events": durable_kv_events and use_kv_events,
}
try:
mockers = MockerProcess(
request,
mocker_args=mocker_args,
num_mockers=2,
request_plane=request_plane,
)
with MockerProcess(
request,
mocker_args=mocker_args,
num_mockers=2,
request_plane=request_plane,
) as mockers:
logger.info(f"All mockers using endpoint: {mockers.endpoint}")
# Initialize mockers
mockers.__enter__()
# Get runtime and create endpoint
runtime = get_runtime(request_plane=request_plane)
# Use the namespace from the mockers
......@@ -745,10 +706,6 @@ def test_router_decisions(
durable_kv_events=durable_kv_events,
)
finally:
if "mockers" in locals():
mockers.__exit__(None, None, None)
@pytest.mark.parametrize("registration_order", ["prefill_first", "decode_first"])
@pytest.mark.parametrize(
......@@ -788,54 +745,63 @@ def test_router_decisions_disagg(
# durable_kv_events defaults to False (NATS Core mode)
}
prefill_workers = None
decode_workers = None
try:
if registration_order == "prefill_first":
# Start prefill workers first
logger.info("Starting 4 prefill mocker instances (first)")
prefill_workers = DisaggMockerProcess(
request,
namespace=shared_namespace,
worker_type="prefill",
mocker_args=mocker_args,
num_mockers=4,
request_plane="nats",
enable_bootstrap=enable_disagg_bootstrap,
)
prefill_workers.__enter__()
if registration_order == "prefill_first":
# Start prefill workers first
logger.info("Starting 4 prefill mocker instances (first)")
with DisaggMockerProcess(
request,
namespace=shared_namespace,
worker_type="prefill",
mocker_args=mocker_args,
num_mockers=4,
request_plane="nats",
enable_bootstrap=enable_disagg_bootstrap,
) as prefill_workers:
logger.info(f"Prefill workers using endpoint: {prefill_workers.endpoint}")
# Then start decode workers
logger.info("Starting 4 decode mocker instances (second)")
decode_workers = DisaggMockerProcess(
request,
namespace=shared_namespace,
worker_type="decode",
mocker_args=mocker_args,
num_mockers=4,
request_plane="nats",
)
decode_workers.__enter__()
logger.info(f"Decode workers using endpoint: {decode_workers.endpoint}")
else:
# Start decode workers first
logger.info("Starting 4 decode mocker instances (first)")
decode_workers = DisaggMockerProcess(
with DisaggMockerProcess(
request,
namespace=shared_namespace,
worker_type="decode",
mocker_args=mocker_args,
num_mockers=4,
request_plane="nats",
)
decode_workers.__enter__()
) as decode_workers:
logger.info(f"Decode workers using endpoint: {decode_workers.endpoint}")
# Get unique port for this test
frontend_port = get_unique_ports(
request, num_ports=1, registration_order=registration_order
)[0]
# Run disagg routing test
_test_router_decisions_disagg(
prefill_workers=prefill_workers,
decode_workers=decode_workers,
block_size=BLOCK_SIZE,
request=request,
frontend_port=frontend_port,
test_payload=TEST_PAYLOAD,
request_plane="nats",
)
else:
# Start decode workers first
logger.info("Starting 4 decode mocker instances (first)")
with DisaggMockerProcess(
request,
namespace=shared_namespace,
worker_type="decode",
mocker_args=mocker_args,
num_mockers=4,
request_plane="nats",
) as decode_workers:
logger.info(f"Decode workers using endpoint: {decode_workers.endpoint}")
# Then start prefill workers
logger.info("Starting 4 prefill mocker instances (second)")
prefill_workers = DisaggMockerProcess(
with DisaggMockerProcess(
request,
namespace=shared_namespace,
worker_type="prefill",
......@@ -843,31 +809,26 @@ def test_router_decisions_disagg(
num_mockers=4,
request_plane="nats",
enable_bootstrap=enable_disagg_bootstrap,
)
prefill_workers.__enter__()
logger.info(f"Prefill workers using endpoint: {prefill_workers.endpoint}")
# Get unique port for this test
frontend_port = get_unique_ports(
request, num_ports=1, registration_order=registration_order
)[0]
# Run disagg routing test
_test_router_decisions_disagg(
prefill_workers=prefill_workers,
decode_workers=decode_workers,
block_size=BLOCK_SIZE,
request=request,
frontend_port=frontend_port,
test_payload=TEST_PAYLOAD,
request_plane="nats",
)
finally:
if decode_workers is not None:
decode_workers.__exit__(None, None, None)
if prefill_workers is not None:
prefill_workers.__exit__(None, None, None)
) as prefill_workers:
logger.info(
f"Prefill workers using endpoint: {prefill_workers.endpoint}"
)
# Get unique port for this test
frontend_port = get_unique_ports(
request, num_ports=1, registration_order=registration_order
)[0]
# Run disagg routing test
_test_router_decisions_disagg(
prefill_workers=prefill_workers,
decode_workers=decode_workers,
block_size=BLOCK_SIZE,
request=request,
frontend_port=frontend_port,
test_payload=TEST_PAYLOAD,
request_plane="nats",
)
@pytest.mark.parametrize("request_plane", ["nats", "tcp"], indirect=True)
......@@ -903,16 +864,14 @@ def test_busy_threshold_endpoint(
"durable_kv_events": durable_kv_events,
}
try:
with MockerProcess(
request,
mocker_args=mocker_args,
num_mockers=NUM_MOCKERS,
request_plane=request_plane,
) as mockers:
logger.info(f"Starting {NUM_MOCKERS} mocker instances")
mockers = MockerProcess(
request,
mocker_args=mocker_args,
num_mockers=NUM_MOCKERS,
request_plane=request_plane,
)
logger.info(f"All mockers using endpoint: {mockers.endpoint}")
mockers.__enter__()
frontend_port = get_unique_ports(
request, num_ports=1, request_plane=request_plane
......@@ -926,7 +885,3 @@ def test_busy_threshold_endpoint(
test_payload=TEST_PAYLOAD,
request_plane=request_plane,
)
finally:
if "mockers" in locals():
mockers.__exit__(None, None, None)
......@@ -342,18 +342,16 @@ def test_sglang_kv_router_basic(
f"Starting SGLang KV router test with {N_SGLANG_WORKERS} workers using request_plane={request_plane}"
)
try:
with SGLangProcess(
request,
sglang_args=SGLANG_ARGS,
num_workers=N_SGLANG_WORKERS,
single_gpu=True, # fit workers into one GPU
request_plane=request_plane,
) as sglang_workers:
# Start SGLang workers
logger.info(f"Starting {N_SGLANG_WORKERS} SGLang workers")
sglang_workers = SGLangProcess(
request,
sglang_args=SGLANG_ARGS,
num_workers=N_SGLANG_WORKERS,
single_gpu=True, # fit workers into one GPU
request_plane=request_plane,
)
logger.info(f"All SGLang workers using namespace: {sglang_workers.namespace}")
sglang_workers.__enter__()
# Run basic router test (starts router internally and waits for workers to be ready)
frontend_port = allocate_frontend_ports(request, 1)[0]
......@@ -369,10 +367,6 @@ def test_sglang_kv_router_basic(
request_plane=request_plane,
)
finally:
if "sglang_workers" in locals():
sglang_workers.__exit__(None, None, None)
@pytest.mark.pre_merge
@pytest.mark.gpu_1
......@@ -390,21 +384,18 @@ def test_router_decisions_sglang_multiple_workers(
logger.info("Starting SGLang router prefix reuse test with two workers")
N_WORKERS = 2
try:
with SGLangProcess(
request,
sglang_args=SGLANG_ARGS,
num_workers=N_WORKERS,
single_gpu=True, # Worker uses GPU 0
request_plane=request_plane,
) as sglang_workers:
# Start 2 worker processes on the same GPU
logger.info("Starting 2 SGLang worker processes on single GPU (mem_frac=0.4)")
sglang_workers = SGLangProcess(
request,
sglang_args=SGLANG_ARGS,
num_workers=N_WORKERS,
single_gpu=True, # Worker uses GPU 0
request_plane=request_plane,
)
logger.info(f"All SGLang workers using namespace: {sglang_workers.namespace}")
# Initialize SGLang workers
sglang_workers.__enter__()
# Get runtime and create endpoint
runtime = get_runtime(request_plane=request_plane)
namespace = runtime.namespace(sglang_workers.namespace)
......@@ -415,11 +406,6 @@ def test_router_decisions_sglang_multiple_workers(
sglang_workers, endpoint, MODEL_NAME, request, test_dp_rank=False
)
finally:
# Clean up SGLang workers
if "sglang_workers" in locals():
sglang_workers.__exit__(None, None, None)
@pytest.mark.gpu_2
@pytest.mark.post_merge
......@@ -442,18 +428,16 @@ def test_router_decisions_sglang_dp(
N_WORKERS = 1
DP_SIZE = 2
try:
with SGLangProcess(
request,
sglang_args=SGLANG_ARGS,
num_workers=N_WORKERS, # Ignored when data_parallel_size is set
single_gpu=False,
data_parallel_size=DP_SIZE, # Creates DP_SIZE processes (one per rank)
request_plane=request_plane,
) as sglang_workers:
logger.info("Starting 2 SGLang DP ranks (dp_size=2) (mem_frac=0.4)")
sglang_workers = SGLangProcess(
request,
sglang_args=SGLANG_ARGS,
num_workers=N_WORKERS, # Ignored when data_parallel_size is set
single_gpu=False,
data_parallel_size=DP_SIZE, # Creates DP_SIZE processes (one per rank)
request_plane=request_plane,
)
logger.info(f"All SGLang workers using namespace: {sglang_workers.namespace}")
sglang_workers.__enter__()
# Get runtime and create endpoint
runtime = get_runtime(request_plane=request_plane)
......@@ -466,11 +450,6 @@ def test_router_decisions_sglang_dp(
sglang_workers, endpoint, MODEL_NAME, request, test_dp_rank=True
)
finally:
# Clean up SGLang workers
if "sglang_workers" in locals():
sglang_workers.__exit__(None, None, None)
@pytest.mark.pre_merge
@pytest.mark.gpu_1
......@@ -511,20 +490,18 @@ def test_sglang_indexers_sync(
N_SGLANG_WORKERS = 2
try:
with SGLangProcess(
request,
sglang_args=SGLANG_ARGS,
num_workers=N_SGLANG_WORKERS,
single_gpu=True, # fit workers into one GPU
request_plane=request_plane,
store_backend=store_backend,
durable_kv_events=durable_kv_events,
) as sglang_workers:
# Start SGLang workers
logger.info(f"Starting {N_SGLANG_WORKERS} SGLang workers")
sglang_workers = SGLangProcess(
request,
sglang_args=SGLANG_ARGS,
num_workers=N_SGLANG_WORKERS,
single_gpu=True, # fit workers into one GPU
request_plane=request_plane,
store_backend=store_backend,
durable_kv_events=durable_kv_events,
)
logger.info(f"All SGLang workers using namespace: {sglang_workers.namespace}")
sglang_workers.__enter__()
# Use the common test implementation (creates its own runtimes for each router)
# Note: Consumer verification is done inside _test_router_indexers_sync while routers are alive
......@@ -542,7 +519,3 @@ def test_sglang_indexers_sync(
)
logger.info("SGLang indexers sync test completed successfully")
finally:
if "sglang_workers" in locals():
sglang_workers.__exit__(None, None, None)
......@@ -332,18 +332,16 @@ def test_trtllm_kv_router_basic(
f"Starting TRT-LLM KV router test with {N_TRTLLM_WORKERS} workers using request_plane={request_plane}"
)
try:
with TRTLLMProcess(
request,
trtllm_args=TRTLLM_ARGS,
num_workers=N_TRTLLM_WORKERS,
single_gpu=True, # fit workers into one GPU
request_plane=request_plane,
) as trtllm_workers:
# Start TRT-LLM workers
logger.info(f"Starting {N_TRTLLM_WORKERS} TRT-LLM workers")
trtllm_workers = TRTLLMProcess(
request,
trtllm_args=TRTLLM_ARGS,
num_workers=N_TRTLLM_WORKERS,
single_gpu=True, # fit workers into one GPU
request_plane=request_plane,
)
logger.info(f"All TRT-LLM workers using namespace: {trtllm_workers.namespace}")
trtllm_workers.__enter__()
# Run basic router test (starts router internally and waits for workers to be ready)
frontend_port = allocate_frontend_ports(request, 1)[0]
......@@ -359,10 +357,6 @@ def test_trtllm_kv_router_basic(
request_plane=request_plane,
)
finally:
if "trtllm_workers" in locals():
trtllm_workers.__exit__(None, None, None)
@pytest.mark.gpu_2
@pytest.mark.nightly
......@@ -392,19 +386,17 @@ def test_router_decisions_trtllm_attention_dp(
"tensor_parallel_size": N_ATTENTION_DP_RANKS,
}
try:
with TRTLLMProcess(
request,
trtllm_args=TRTLLM_ADP_ARGS,
num_workers=N_TRTLLM_WORKERS,
single_gpu=False,
request_plane=request_plane,
) as trtllm_workers:
logger.info(
f"Starting 1 TRT-LLM worker with attention DP enabled (attention_dp_size={N_ATTENTION_DP_RANKS})"
)
trtllm_workers = TRTLLMProcess(
request,
trtllm_args=TRTLLM_ADP_ARGS,
num_workers=N_TRTLLM_WORKERS,
single_gpu=False,
request_plane=request_plane,
)
logger.info(f"All TRT-LLM workers using namespace: {trtllm_workers.namespace}")
trtllm_workers.__enter__()
# Get runtime and create endpoint
runtime = get_runtime(request_plane=request_plane)
......@@ -422,11 +414,6 @@ def test_router_decisions_trtllm_attention_dp(
block_size=TRTLLM_BLOCK_SIZE,
)
finally:
# Clean up TRTLLM workers
if "trtllm_workers" in locals():
trtllm_workers.__exit__(None, None, None)
@pytest.mark.pre_merge
@pytest.mark.gpu_1
......@@ -443,23 +430,20 @@ def test_router_decisions_trtllm_multiple_workers(
logger.info("Starting TRT-LLM router prefix reuse test with two workers")
N_WORKERS = 2
try:
with TRTLLMProcess(
request,
trtllm_args=TRTLLM_ARGS,
num_workers=N_WORKERS,
single_gpu=True, # Worker uses GPU 0
request_plane=request_plane,
) as trtllm_workers:
# Start 2 worker processes on the same GPU
logger.info(
"Starting 2 TRT-LLM worker processes on single GPU (gpu_mem_frac=0.4)"
)
trtllm_workers = TRTLLMProcess(
request,
trtllm_args=TRTLLM_ARGS,
num_workers=N_WORKERS,
single_gpu=True, # Worker uses GPU 0
request_plane=request_plane,
)
logger.info(f"All TRT-LLM workers using namespace: {trtllm_workers.namespace}")
# Initialize TRT-LLM workers
trtllm_workers.__enter__()
# Get runtime and create endpoint
runtime = get_runtime(request_plane=request_plane)
namespace = runtime.namespace(trtllm_workers.namespace)
......@@ -475,11 +459,6 @@ def test_router_decisions_trtllm_multiple_workers(
block_size=TRTLLM_BLOCK_SIZE,
)
finally:
# Clean up TRT-LLM workers
if "trtllm_workers" in locals():
trtllm_workers.__exit__(None, None, None)
@pytest.mark.pre_merge
@pytest.mark.gpu_1
......@@ -520,20 +499,18 @@ def test_trtllm_indexers_sync(
N_TRTLLM_WORKERS = 2
try:
with TRTLLMProcess(
request,
trtllm_args=TRTLLM_ARGS,
num_workers=N_TRTLLM_WORKERS,
single_gpu=True, # fit workers into one GPU
request_plane=request_plane,
store_backend=store_backend,
durable_kv_events=durable_kv_events,
) as trtllm_workers:
# Start TRT-LLM workers
logger.info(f"Starting {N_TRTLLM_WORKERS} TRT-LLM workers")
trtllm_workers = TRTLLMProcess(
request,
trtllm_args=TRTLLM_ARGS,
num_workers=N_TRTLLM_WORKERS,
single_gpu=True, # fit workers into one GPU
request_plane=request_plane,
store_backend=store_backend,
durable_kv_events=durable_kv_events,
)
logger.info(f"All TRT-LLM workers using namespace: {trtllm_workers.namespace}")
trtllm_workers.__enter__()
# Use the common test implementation (creates its own runtimes for each router)
# Note: Consumer verification is done inside _test_router_indexers_sync while routers are alive
......@@ -551,7 +528,3 @@ def test_trtllm_indexers_sync(
)
logger.info("TRT-LLM indexers sync test completed successfully")
finally:
if "trtllm_workers" in locals():
trtllm_workers.__exit__(None, None, None)
......@@ -354,18 +354,16 @@ def test_vllm_kv_router_basic(
f"Starting vLLM KV router test with {N_VLLM_WORKERS} workers using request_plane={request_plane}"
)
try:
with VLLMProcess(
request,
vllm_args=VLLM_ARGS,
num_workers=N_VLLM_WORKERS,
single_gpu=True, # fit workers into one GPU
request_plane=request_plane,
) as vllm_workers:
# Start vLLM workers
logger.info(f"Starting {N_VLLM_WORKERS} vLLM workers")
vllm_workers = VLLMProcess(
request,
vllm_args=VLLM_ARGS,
num_workers=N_VLLM_WORKERS,
single_gpu=True, # fit workers into one GPU
request_plane=request_plane,
)
logger.info(f"All vLLM workers using namespace: {vllm_workers.namespace}")
vllm_workers.__enter__()
# Run basic router test (starts router internally and waits for workers to be ready)
frontend_port = allocate_frontend_ports(request, 1)[0]
......@@ -381,10 +379,6 @@ def test_vllm_kv_router_basic(
request_plane=request_plane,
)
finally:
if "vllm_workers" in locals():
vllm_workers.__exit__(None, None, None)
@pytest.mark.pre_merge
@pytest.mark.gpu_1
......@@ -401,21 +395,17 @@ def test_router_decisions_vllm_multiple_workers(
logger.info("Starting vLLM router prefix reuse test with two workers")
N_WORKERS = 2
try:
with VLLMProcess(
request,
vllm_args=VLLM_ARGS,
num_workers=N_WORKERS,
single_gpu=True, # Worker uses GPU 0
request_plane=request_plane,
) as vllm_workers:
# Start 2 worker processes on the same GPU
logger.info("Starting 2 vLLM worker processes on single GPU (gpu_mem=0.4)")
vllm_workers = VLLMProcess(
request,
vllm_args=VLLM_ARGS,
num_workers=N_WORKERS,
single_gpu=True, # Worker uses GPU 0
request_plane=request_plane,
)
logger.info(f"All vLLM workers using namespace: {vllm_workers.namespace}")
# Initialize vLLM workers
vllm_workers.__enter__()
# Get runtime and create endpoint
runtime = get_runtime(request_plane=request_plane)
namespace = runtime.namespace(vllm_workers.namespace)
......@@ -426,11 +416,6 @@ def test_router_decisions_vllm_multiple_workers(
vllm_workers, endpoint, MODEL_NAME, request, test_dp_rank=False
)
finally:
# Clean up vLLM workers
if "vllm_workers" in locals():
vllm_workers.__exit__(None, None, None)
@pytest.mark.gpu_2
@pytest.mark.nightly
......@@ -453,18 +438,16 @@ def test_router_decisions_vllm_dp(
N_WORKERS = 1
DP_SIZE = 2
try:
with VLLMProcess(
request,
vllm_args=VLLM_ARGS,
num_workers=N_WORKERS, # Ignored when data_parallel_size is set
single_gpu=False,
data_parallel_size=DP_SIZE, # Creates DP_SIZE processes (one per rank)
request_plane=request_plane,
) as vllm_workers:
logger.info("Starting 2 vLLM DP ranks (dp_size=2) (gpu_mem=0.4)")
vllm_workers = VLLMProcess(
request,
vllm_args=VLLM_ARGS,
num_workers=N_WORKERS, # Ignored when data_parallel_size is set
single_gpu=False,
data_parallel_size=DP_SIZE, # Creates DP_SIZE processes (one per rank)
request_plane=request_plane,
)
logger.info(f"All vLLM workers using namespace: {vllm_workers.namespace}")
vllm_workers.__enter__()
# Get runtime and create endpoint
runtime = get_runtime(request_plane=request_plane)
......@@ -477,11 +460,6 @@ def test_router_decisions_vllm_dp(
vllm_workers, endpoint, MODEL_NAME, request, test_dp_rank=True
)
finally:
# Clean up vLLM workers
if "vllm_workers" in locals():
vllm_workers.__exit__(None, None, None)
@pytest.mark.pre_merge
@pytest.mark.gpu_1
......@@ -522,20 +500,18 @@ def test_vllm_indexers_sync(
N_VLLM_WORKERS = 2
try:
with VLLMProcess(
request,
vllm_args=VLLM_ARGS,
num_workers=N_VLLM_WORKERS,
single_gpu=True, # fit workers into one GPU
request_plane=request_plane,
store_backend=store_backend,
durable_kv_events=durable_kv_events,
) as vllm_workers:
# Start vLLM workers
logger.info(f"Starting {N_VLLM_WORKERS} vLLM workers")
vllm_workers = VLLMProcess(
request,
vllm_args=VLLM_ARGS,
num_workers=N_VLLM_WORKERS,
single_gpu=True, # fit workers into one GPU
request_plane=request_plane,
store_backend=store_backend,
durable_kv_events=durable_kv_events,
)
logger.info(f"All vLLM workers using namespace: {vllm_workers.namespace}")
vllm_workers.__enter__()
# Use the common test implementation (creates its own runtimes for each router)
# Note: Consumer verification is done inside _test_router_indexers_sync while routers are alive
......@@ -553,7 +529,3 @@ def test_vllm_indexers_sync(
)
logger.info("vLLM indexers sync test completed successfully")
finally:
if "vllm_workers" in locals():
vllm_workers.__exit__(None, None, None)
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment