Unverified Commit 6fe2152b authored by Karen Chung's avatar Karen Chung Committed by GitHub
Browse files

test: refactor router e2e tests to use context managers for process lifecycle (#6088)

parent 1cd3b724
...@@ -528,6 +528,7 @@ async def send_request_via_python_kv_router( ...@@ -528,6 +528,7 @@ async def send_request_via_python_kv_router(
) )
# Retry loop sending request to worker with exponential backoff # Retry loop sending request to worker with exponential backoff
stream = None
for attempt in range(max_retries + 1): for attempt in range(max_retries + 1):
try: try:
logger.debug(f"Sending request to {log_message} (attempt {attempt + 1})") logger.debug(f"Sending request to {log_message} (attempt {attempt + 1})")
...@@ -557,6 +558,11 @@ async def send_request_via_python_kv_router( ...@@ -557,6 +558,11 @@ async def send_request_via_python_kv_router(
f"Failed to connect to workers after {max_retries + 1} attempts" f"Failed to connect to workers after {max_retries + 1} attempts"
) from e ) from e
if stream is None:
raise RuntimeError(
f"Failed to get a valid stream from workers after {max_retries + 1} attempts"
)
# Collect tokens and worker IDs from the SSE stream # Collect tokens and worker IDs from the SSE stream
generated_tokens = [] generated_tokens = []
prefill_worker_id: Optional[int] = None prefill_worker_id: Optional[int] = None
...@@ -653,18 +659,16 @@ def _test_router_basic( ...@@ -653,18 +659,16 @@ def _test_router_basic(
AssertionError: If requests fail or frontend doesn't become ready AssertionError: If requests fail or frontend doesn't become ready
TimeoutError: If frontend doesn't become ready within timeout TimeoutError: If frontend doesn't become ready within timeout
""" """
try: with KVRouterProcess(
# Start KV router frontend
logger.info(f"Starting KV router frontend on port {frontend_port}")
kv_router = KVRouterProcess(
request, request,
block_size, block_size,
frontend_port, frontend_port,
engine_workers.namespace, engine_workers.namespace,
store_backend, store_backend,
request_plane=request_plane, request_plane=request_plane,
) ):
kv_router.__enter__() # Start KV router frontend
logger.info(f"Starting KV router frontend on port {frontend_port}")
frontend_url = f"http://localhost:{frontend_port}" frontend_url = f"http://localhost:{frontend_port}"
...@@ -690,10 +694,6 @@ def _test_router_basic( ...@@ -690,10 +694,6 @@ def _test_router_basic(
logger.info(f"Successfully completed {num_requests} requests") logger.info(f"Successfully completed {num_requests} requests")
finally:
if "kv_router" in locals():
kv_router.__exit__(None, None, None)
def _test_router_two_routers( def _test_router_two_routers(
engine_workers, engine_workers,
...@@ -1036,13 +1036,11 @@ def _test_router_query_instance_id( ...@@ -1036,13 +1036,11 @@ def _test_router_query_instance_id(
AssertionError: If annotation response structure is incorrect or contains generation content AssertionError: If annotation response structure is incorrect or contains generation content
""" """
try: with KVRouterProcess(
request, block_size, frontend_port, engine_workers.namespace, store_backend
):
# Start KV router (frontend) # Start KV router (frontend)
logger.info(f"Starting KV router frontend on port {frontend_port}") logger.info(f"Starting KV router frontend on port {frontend_port}")
kv_router = KVRouterProcess(
request, block_size, frontend_port, engine_workers.namespace, store_backend
)
kv_router.__enter__()
url = f"http://localhost:{frontend_port}/v1/chat/completions" url = f"http://localhost:{frontend_port}/v1/chat/completions"
...@@ -1164,10 +1162,6 @@ def _test_router_query_instance_id( ...@@ -1164,10 +1162,6 @@ def _test_router_query_instance_id(
logger.info(f"Decode Worker ID: {result['decode_worker_id']}") logger.info(f"Decode Worker ID: {result['decode_worker_id']}")
logger.info(f"Token count: {result['token_count']}") logger.info(f"Token count: {result['token_count']}")
finally:
if "kv_router" in locals():
kv_router.__exit__(None, None, None)
def _test_router_overload_503( def _test_router_overload_503(
engine_workers, engine_workers,
...@@ -1194,42 +1188,17 @@ def _test_router_overload_503( ...@@ -1194,42 +1188,17 @@ def _test_router_overload_503(
AssertionError: If 503 response is not received when expected AssertionError: If 503 response is not received when expected
""" """
try:
logger.info( logger.info(
f"Starting KV router frontend on port {frontend_port} with limited resources" f"Starting KV router frontend on port {frontend_port} with limited resources"
) )
# Custom command for router with limited block size with KVRouterProcess(
command = [ request=request,
"python", block_size=block_size,
"-m", frontend_port=frontend_port,
"dynamo.frontend", namespace=engine_workers.namespace,
"--active-decode-blocks-threshold", blocks_threshold=blocks_threshold,
str(blocks_threshold), ):
"--kv-cache-block-size",
str(block_size),
"--router-mode",
"kv",
"--http-port",
str(frontend_port),
]
kv_router = ManagedProcess(
command=command,
timeout=60,
display_output=True,
health_check_ports=[frontend_port],
health_check_urls=[
(
f"http://localhost:{frontend_port}/v1/models",
lambda r: r.status_code == 200,
)
],
log_dir=request.node.name,
terminate_all_matching_process_names=False,
)
kv_router.__enter__()
url = f"http://localhost:{frontend_port}/v1/chat/completions" url = f"http://localhost:{frontend_port}/v1/chat/completions"
# Custom payload for 503 test with more tokens to consume resources # Custom payload for 503 test with more tokens to consume resources
...@@ -1325,10 +1294,6 @@ def _test_router_overload_503( ...@@ -1325,10 +1294,6 @@ def _test_router_overload_503(
logger.info("Successfully verified 503 response when all workers are busy") logger.info("Successfully verified 503 response when all workers are busy")
finally:
if "kv_router" in locals():
kv_router.__exit__(None, None, None)
def _test_router_indexers_sync( def _test_router_indexers_sync(
engine_workers, engine_workers,
...@@ -1727,13 +1692,7 @@ def _test_router_decisions_disagg( ...@@ -1727,13 +1692,7 @@ def _test_router_decisions_disagg(
AssertionError: If prefill_worker_ids differ across requests (prefix reuse failure) AssertionError: If prefill_worker_ids differ across requests (prefix reuse failure)
AssertionError: If prefill_worker_id is in decode_worker_ids (not true disagg) AssertionError: If prefill_worker_id is in decode_worker_ids (not true disagg)
""" """
try: with KVRouterProcess(
# Start KV router frontend - uses decode_workers namespace for discovery
# The frontend will auto-discover both prefill and decode workers
logger.info(
f"Starting KV router frontend on port {frontend_port} for disagg test"
)
kv_router = KVRouterProcess(
request, request,
block_size, block_size,
frontend_port, frontend_port,
...@@ -1742,8 +1701,12 @@ def _test_router_decisions_disagg( ...@@ -1742,8 +1701,12 @@ def _test_router_decisions_disagg(
enforce_disagg=True, enforce_disagg=True,
request_plane=request_plane, request_plane=request_plane,
durable_kv_events=durable_kv_events, durable_kv_events=durable_kv_events,
):
# Start KV router frontend - uses decode_workers namespace for discovery
# The frontend will auto-discover both prefill and decode workers
logger.info(
f"Starting KV router frontend on port {frontend_port} for disagg test"
) )
kv_router.__enter__()
frontend_url = f"http://localhost:{frontend_port}" frontend_url = f"http://localhost:{frontend_port}"
chat_url = f"{frontend_url}/v1/chat/completions" chat_url = f"{frontend_url}/v1/chat/completions"
...@@ -1908,10 +1871,6 @@ def _test_router_decisions_disagg( ...@@ -1908,10 +1871,6 @@ def _test_router_decisions_disagg(
f" - Prefill worker is NOT in decode worker set {unique_decode_ids} (true disagg)" f" - Prefill worker is NOT in decode worker set {unique_decode_ids} (true disagg)"
) )
finally:
if "kv_router" in locals():
kv_router.__exit__(None, None, None)
def _test_router_decisions( def _test_router_decisions(
engine_workers, engine_workers,
...@@ -2190,10 +2149,7 @@ def _test_busy_threshold_endpoint( ...@@ -2190,10 +2149,7 @@ def _test_busy_threshold_endpoint(
initial_active_decode_blocks_threshold = 0.9 initial_active_decode_blocks_threshold = 0.9
initial_active_prefill_tokens_threshold = 1000 # Literal token count threshold initial_active_prefill_tokens_threshold = 1000 # Literal token count threshold
try: with KVRouterProcess(
# Start KV router frontend with initial thresholds to create monitor
logger.info(f"Starting KV router frontend on port {frontend_port}")
kv_router = KVRouterProcess(
request, request,
block_size, block_size,
frontend_port, frontend_port,
...@@ -2202,8 +2158,9 @@ def _test_busy_threshold_endpoint( ...@@ -2202,8 +2158,9 @@ def _test_busy_threshold_endpoint(
blocks_threshold=initial_active_decode_blocks_threshold, blocks_threshold=initial_active_decode_blocks_threshold,
tokens_threshold=initial_active_prefill_tokens_threshold, tokens_threshold=initial_active_prefill_tokens_threshold,
request_plane=request_plane, request_plane=request_plane,
) ):
kv_router.__enter__() # Start KV router frontend with initial thresholds to create monitor
logger.info(f"Starting KV router frontend on port {frontend_port}")
frontend_url = f"http://localhost:{frontend_port}" frontend_url = f"http://localhost:{frontend_port}"
busy_threshold_url = f"{frontend_url}/busy_threshold" busy_threshold_url = f"{frontend_url}/busy_threshold"
...@@ -2464,7 +2421,3 @@ def _test_busy_threshold_endpoint( ...@@ -2464,7 +2421,3 @@ def _test_busy_threshold_endpoint(
logger.info("All busy_threshold endpoint tests passed!") logger.info("All busy_threshold endpoint tests passed!")
asyncio.run(test_busy_threshold_api()) asyncio.run(test_busy_threshold_api())
finally:
if "kv_router" in locals():
kv_router.__exit__(None, None, None)
...@@ -351,17 +351,15 @@ def test_mocker_kv_router( ...@@ -351,17 +351,15 @@ def test_mocker_kv_router(
"durable_kv_events": durable_kv_events, "durable_kv_events": durable_kv_events,
} }
try: with MockerProcess(
# Start mocker instances with the new CLI interface
logger.info(f"Starting {NUM_MOCKERS} mocker instances")
mockers = MockerProcess(
request, request,
mocker_args=mocker_args, mocker_args=mocker_args,
num_mockers=NUM_MOCKERS, num_mockers=NUM_MOCKERS,
request_plane=request_plane, request_plane=request_plane,
) ) as mockers:
# Start mocker instances with the new CLI interface
logger.info(f"Starting {NUM_MOCKERS} mocker instances")
logger.info(f"All mockers using endpoint: {mockers.endpoint}") logger.info(f"All mockers using endpoint: {mockers.endpoint}")
mockers.__enter__()
# Get unique port for this test # Get unique port for this test
frontend_port = get_unique_ports( frontend_port = get_unique_ports(
...@@ -379,10 +377,6 @@ def test_mocker_kv_router( ...@@ -379,10 +377,6 @@ def test_mocker_kv_router(
request_plane=request_plane, request_plane=request_plane,
) )
finally:
if "mockers" in locals():
mockers.__exit__(None, None, None)
@pytest.mark.parametrize("store_backend", ["etcd", "file"]) @pytest.mark.parametrize("store_backend", ["etcd", "file"])
@pytest.mark.parametrize( @pytest.mark.parametrize(
...@@ -415,17 +409,15 @@ def test_mocker_two_kv_router( ...@@ -415,17 +409,15 @@ def test_mocker_two_kv_router(
"durable_kv_events": durable_kv_events, "durable_kv_events": durable_kv_events,
} }
try: with MockerProcess(
# Start mocker instances with the new CLI interface
logger.info(f"Starting {NUM_MOCKERS} mocker instances")
mockers = MockerProcess(
request, request,
mocker_args=mocker_args, mocker_args=mocker_args,
num_mockers=NUM_MOCKERS, num_mockers=NUM_MOCKERS,
store_backend=store_backend, store_backend=store_backend,
) ) as mockers:
# Start mocker instances with the new CLI interface
logger.info(f"Starting {NUM_MOCKERS} mocker instances")
logger.info(f"All mockers using endpoint: {mockers.endpoint}") logger.info(f"All mockers using endpoint: {mockers.endpoint}")
mockers.__enter__()
# Get unique ports for this test (2 ports for two routers) # Get unique ports for this test (2 ports for two routers)
router_ports = get_unique_ports( router_ports = get_unique_ports(
...@@ -444,10 +436,6 @@ def test_mocker_two_kv_router( ...@@ -444,10 +436,6 @@ def test_mocker_two_kv_router(
skip_consumer_verification=not durable_kv_events, # Skip JetStream checks in NATS Core mode skip_consumer_verification=not durable_kv_events, # Skip JetStream checks in NATS Core mode
) )
finally:
if "mockers" in locals():
mockers.__exit__(None, None, None)
@pytest.mark.skip(reason="Flaky, temporarily disabled") @pytest.mark.skip(reason="Flaky, temporarily disabled")
@pytest.mark.parametrize( @pytest.mark.parametrize(
...@@ -467,12 +455,10 @@ def test_mocker_kv_router_overload_503( ...@@ -467,12 +455,10 @@ def test_mocker_kv_router_overload_503(
"durable_kv_events": durable_kv_events, "durable_kv_events": durable_kv_events,
} }
try: with MockerProcess(request, mocker_args=mocker_args, num_mockers=1) as mockers:
# Start single mocker instance with limited resources # Start single mocker instance with limited resources
logger.info("Starting single mocker instance with limited resources") logger.info("Starting single mocker instance with limited resources")
mockers = MockerProcess(request, mocker_args=mocker_args, num_mockers=1)
logger.info(f"Mocker using endpoint: {mockers.endpoint}") logger.info(f"Mocker using endpoint: {mockers.endpoint}")
mockers.__enter__()
# Get unique port for this test # Get unique port for this test
frontend_port = get_unique_ports(request, num_ports=1)[0] frontend_port = get_unique_ports(request, num_ports=1)[0]
...@@ -487,10 +473,6 @@ def test_mocker_kv_router_overload_503( ...@@ -487,10 +473,6 @@ def test_mocker_kv_router_overload_503(
blocks_threshold=0.2, blocks_threshold=0.2,
) )
finally:
if "mockers" in locals():
mockers.__exit__(None, None, None)
@pytest.mark.timeout(90) # bumped for xdist contention (was 22s; ~7.10s serial avg) @pytest.mark.timeout(90) # bumped for xdist contention (was 22s; ~7.10s serial avg)
@pytest.mark.parametrize("request_plane", ["nats", "tcp"], indirect=True) @pytest.mark.parametrize("request_plane", ["nats", "tcp"], indirect=True)
...@@ -513,17 +495,15 @@ def test_kv_push_router_bindings( ...@@ -513,17 +495,15 @@ def test_kv_push_router_bindings(
"durable_kv_events": durable_kv_events, "durable_kv_events": durable_kv_events,
} }
try: with MockerProcess(
# Start mocker instances
logger.info(f"Starting {NUM_MOCKERS} mocker instances")
mockers = MockerProcess(
request, request,
mocker_args=mocker_args, mocker_args=mocker_args,
num_mockers=NUM_MOCKERS, num_mockers=NUM_MOCKERS,
request_plane=request_plane, request_plane=request_plane,
) ) as mockers:
# Start mocker instances
logger.info(f"Starting {NUM_MOCKERS} mocker instances")
logger.info(f"All mockers using endpoint: {mockers.endpoint}") logger.info(f"All mockers using endpoint: {mockers.endpoint}")
mockers.__enter__()
# Get runtime and create endpoint # Get runtime and create endpoint
runtime = get_runtime(request_plane=request_plane) runtime = get_runtime(request_plane=request_plane)
...@@ -540,10 +520,6 @@ def test_kv_push_router_bindings( ...@@ -540,10 +520,6 @@ def test_kv_push_router_bindings(
num_workers=NUM_MOCKERS, num_workers=NUM_MOCKERS,
) )
finally:
if "mockers" in locals():
mockers.__exit__(None, None, None)
@pytest.mark.parametrize( @pytest.mark.parametrize(
"store_backend,durable_kv_events,request_plane", "store_backend,durable_kv_events,request_plane",
...@@ -596,18 +572,16 @@ def test_indexers_sync( ...@@ -596,18 +572,16 @@ def test_indexers_sync(
"dp_size": 2, "dp_size": 2,
} }
try: with MockerProcess(
# Start mocker instances (2 workers x 2 DP ranks = 4 independent event streams)
logger.info(f"Starting {NUM_MOCKERS} mocker instances with dp_size=2")
mockers = MockerProcess(
request, request,
mocker_args=mocker_args, mocker_args=mocker_args,
num_mockers=NUM_MOCKERS, num_mockers=NUM_MOCKERS,
store_backend=store_backend, store_backend=store_backend,
request_plane=request_plane, request_plane=request_plane,
) ) as mockers:
# Start mocker instances (2 workers x 2 DP ranks = 4 independent event streams)
logger.info(f"Starting {NUM_MOCKERS} mocker instances with dp_size=2")
logger.info(f"All mockers using endpoint: {mockers.endpoint}") logger.info(f"All mockers using endpoint: {mockers.endpoint}")
mockers.__enter__()
# Use the common test implementation (creates its own runtimes for each router) # Use the common test implementation (creates its own runtimes for each router)
# Note: Consumer verification is done inside _test_router_indexers_sync while routers are alive # Note: Consumer verification is done inside _test_router_indexers_sync while routers are alive
...@@ -626,10 +600,6 @@ def test_indexers_sync( ...@@ -626,10 +600,6 @@ def test_indexers_sync(
logger.info("Indexers sync test completed successfully") logger.info("Indexers sync test completed successfully")
finally:
if "mockers" in locals():
mockers.__exit__(None, None, None)
@pytest.mark.timeout(120) # bumped for xdist contention (was 42s; ~13.80s serial avg) @pytest.mark.timeout(120) # bumped for xdist contention (was 42s; ~13.80s serial avg)
@pytest.mark.parametrize( @pytest.mark.parametrize(
...@@ -648,14 +618,12 @@ def test_query_instance_id_returns_worker_and_tokens( ...@@ -648,14 +618,12 @@ def test_query_instance_id_returns_worker_and_tokens(
} }
os.makedirs(request.node.name, exist_ok=True) os.makedirs(request.node.name, exist_ok=True)
try: with MockerProcess(
request, mocker_args=mocker_args, num_mockers=NUM_MOCKERS
) as mockers:
# Start mocker instances # Start mocker instances
logger.info(f"Starting {NUM_MOCKERS} mocker instances") logger.info(f"Starting {NUM_MOCKERS} mocker instances")
mockers = MockerProcess(
request, mocker_args=mocker_args, num_mockers=NUM_MOCKERS
)
logger.info(f"All mockers using endpoint: {mockers.endpoint}") logger.info(f"All mockers using endpoint: {mockers.endpoint}")
mockers.__enter__()
# Get unique port for this test # Get unique port for this test
frontend_port = get_unique_ports(request, num_ports=1)[0] frontend_port = get_unique_ports(request, num_ports=1)[0]
...@@ -669,10 +637,6 @@ def test_query_instance_id_returns_worker_and_tokens( ...@@ -669,10 +637,6 @@ def test_query_instance_id_returns_worker_and_tokens(
test_payload=TEST_PAYLOAD, test_payload=TEST_PAYLOAD,
) )
finally:
if "mockers" in locals():
mockers.__exit__(None, None, None)
@pytest.mark.timeout(90) # bumped for xdist contention (was 29s; ~9.55s serial avg) @pytest.mark.timeout(90) # bumped for xdist contention (was 29s; ~9.55s serial avg)
@pytest.mark.parametrize("request_plane", ["nats", "tcp"], indirect=True) @pytest.mark.parametrize("request_plane", ["nats", "tcp"], indirect=True)
...@@ -716,18 +680,15 @@ def test_router_decisions( ...@@ -716,18 +680,15 @@ def test_router_decisions(
"durable_kv_events": durable_kv_events and use_kv_events, "durable_kv_events": durable_kv_events and use_kv_events,
} }
try: with MockerProcess(
mockers = MockerProcess(
request, request,
mocker_args=mocker_args, mocker_args=mocker_args,
num_mockers=2, num_mockers=2,
request_plane=request_plane, request_plane=request_plane,
) ) as mockers:
logger.info(f"All mockers using endpoint: {mockers.endpoint}") logger.info(f"All mockers using endpoint: {mockers.endpoint}")
# Initialize mockers # Initialize mockers
mockers.__enter__()
# Get runtime and create endpoint # Get runtime and create endpoint
runtime = get_runtime(request_plane=request_plane) runtime = get_runtime(request_plane=request_plane)
# Use the namespace from the mockers # Use the namespace from the mockers
...@@ -745,10 +706,6 @@ def test_router_decisions( ...@@ -745,10 +706,6 @@ def test_router_decisions(
durable_kv_events=durable_kv_events, durable_kv_events=durable_kv_events,
) )
finally:
if "mockers" in locals():
mockers.__exit__(None, None, None)
@pytest.mark.parametrize("registration_order", ["prefill_first", "decode_first"]) @pytest.mark.parametrize("registration_order", ["prefill_first", "decode_first"])
@pytest.mark.parametrize( @pytest.mark.parametrize(
...@@ -788,14 +745,10 @@ def test_router_decisions_disagg( ...@@ -788,14 +745,10 @@ def test_router_decisions_disagg(
# durable_kv_events defaults to False (NATS Core mode) # durable_kv_events defaults to False (NATS Core mode)
} }
prefill_workers = None
decode_workers = None
try:
if registration_order == "prefill_first": if registration_order == "prefill_first":
# Start prefill workers first # Start prefill workers first
logger.info("Starting 4 prefill mocker instances (first)") logger.info("Starting 4 prefill mocker instances (first)")
prefill_workers = DisaggMockerProcess( with DisaggMockerProcess(
request, request,
namespace=shared_namespace, namespace=shared_namespace,
worker_type="prefill", worker_type="prefill",
...@@ -803,39 +756,52 @@ def test_router_decisions_disagg( ...@@ -803,39 +756,52 @@ def test_router_decisions_disagg(
num_mockers=4, num_mockers=4,
request_plane="nats", request_plane="nats",
enable_bootstrap=enable_disagg_bootstrap, enable_bootstrap=enable_disagg_bootstrap,
) ) as prefill_workers:
prefill_workers.__enter__()
logger.info(f"Prefill workers using endpoint: {prefill_workers.endpoint}") logger.info(f"Prefill workers using endpoint: {prefill_workers.endpoint}")
# Then start decode workers # Then start decode workers
logger.info("Starting 4 decode mocker instances (second)") logger.info("Starting 4 decode mocker instances (second)")
decode_workers = DisaggMockerProcess( with DisaggMockerProcess(
request, request,
namespace=shared_namespace, namespace=shared_namespace,
worker_type="decode", worker_type="decode",
mocker_args=mocker_args, mocker_args=mocker_args,
num_mockers=4, num_mockers=4,
request_plane="nats", request_plane="nats",
) ) as decode_workers:
decode_workers.__enter__()
logger.info(f"Decode workers using endpoint: {decode_workers.endpoint}") logger.info(f"Decode workers using endpoint: {decode_workers.endpoint}")
# Get unique port for this test
frontend_port = get_unique_ports(
request, num_ports=1, registration_order=registration_order
)[0]
# Run disagg routing test
_test_router_decisions_disagg(
prefill_workers=prefill_workers,
decode_workers=decode_workers,
block_size=BLOCK_SIZE,
request=request,
frontend_port=frontend_port,
test_payload=TEST_PAYLOAD,
request_plane="nats",
)
else: else:
# Start decode workers first # Start decode workers first
logger.info("Starting 4 decode mocker instances (first)") logger.info("Starting 4 decode mocker instances (first)")
decode_workers = DisaggMockerProcess( with DisaggMockerProcess(
request, request,
namespace=shared_namespace, namespace=shared_namespace,
worker_type="decode", worker_type="decode",
mocker_args=mocker_args, mocker_args=mocker_args,
num_mockers=4, num_mockers=4,
request_plane="nats", request_plane="nats",
) ) as decode_workers:
decode_workers.__enter__()
logger.info(f"Decode workers using endpoint: {decode_workers.endpoint}") logger.info(f"Decode workers using endpoint: {decode_workers.endpoint}")
# Then start prefill workers # Then start prefill workers
logger.info("Starting 4 prefill mocker instances (second)") logger.info("Starting 4 prefill mocker instances (second)")
prefill_workers = DisaggMockerProcess( with DisaggMockerProcess(
request, request,
namespace=shared_namespace, namespace=shared_namespace,
worker_type="prefill", worker_type="prefill",
...@@ -843,9 +809,10 @@ def test_router_decisions_disagg( ...@@ -843,9 +809,10 @@ def test_router_decisions_disagg(
num_mockers=4, num_mockers=4,
request_plane="nats", request_plane="nats",
enable_bootstrap=enable_disagg_bootstrap, enable_bootstrap=enable_disagg_bootstrap,
) as prefill_workers:
logger.info(
f"Prefill workers using endpoint: {prefill_workers.endpoint}"
) )
prefill_workers.__enter__()
logger.info(f"Prefill workers using endpoint: {prefill_workers.endpoint}")
# Get unique port for this test # Get unique port for this test
frontend_port = get_unique_ports( frontend_port = get_unique_ports(
...@@ -863,12 +830,6 @@ def test_router_decisions_disagg( ...@@ -863,12 +830,6 @@ def test_router_decisions_disagg(
request_plane="nats", request_plane="nats",
) )
finally:
if decode_workers is not None:
decode_workers.__exit__(None, None, None)
if prefill_workers is not None:
prefill_workers.__exit__(None, None, None)
@pytest.mark.parametrize("request_plane", ["nats", "tcp"], indirect=True) @pytest.mark.parametrize("request_plane", ["nats", "tcp"], indirect=True)
@pytest.mark.parametrize( @pytest.mark.parametrize(
...@@ -903,16 +864,14 @@ def test_busy_threshold_endpoint( ...@@ -903,16 +864,14 @@ def test_busy_threshold_endpoint(
"durable_kv_events": durable_kv_events, "durable_kv_events": durable_kv_events,
} }
try: with MockerProcess(
logger.info(f"Starting {NUM_MOCKERS} mocker instances")
mockers = MockerProcess(
request, request,
mocker_args=mocker_args, mocker_args=mocker_args,
num_mockers=NUM_MOCKERS, num_mockers=NUM_MOCKERS,
request_plane=request_plane, request_plane=request_plane,
) ) as mockers:
logger.info(f"Starting {NUM_MOCKERS} mocker instances")
logger.info(f"All mockers using endpoint: {mockers.endpoint}") logger.info(f"All mockers using endpoint: {mockers.endpoint}")
mockers.__enter__()
frontend_port = get_unique_ports( frontend_port = get_unique_ports(
request, num_ports=1, request_plane=request_plane request, num_ports=1, request_plane=request_plane
...@@ -926,7 +885,3 @@ def test_busy_threshold_endpoint( ...@@ -926,7 +885,3 @@ def test_busy_threshold_endpoint(
test_payload=TEST_PAYLOAD, test_payload=TEST_PAYLOAD,
request_plane=request_plane, request_plane=request_plane,
) )
finally:
if "mockers" in locals():
mockers.__exit__(None, None, None)
...@@ -342,18 +342,16 @@ def test_sglang_kv_router_basic( ...@@ -342,18 +342,16 @@ def test_sglang_kv_router_basic(
f"Starting SGLang KV router test with {N_SGLANG_WORKERS} workers using request_plane={request_plane}" f"Starting SGLang KV router test with {N_SGLANG_WORKERS} workers using request_plane={request_plane}"
) )
try: with SGLangProcess(
# Start SGLang workers
logger.info(f"Starting {N_SGLANG_WORKERS} SGLang workers")
sglang_workers = SGLangProcess(
request, request,
sglang_args=SGLANG_ARGS, sglang_args=SGLANG_ARGS,
num_workers=N_SGLANG_WORKERS, num_workers=N_SGLANG_WORKERS,
single_gpu=True, # fit workers into one GPU single_gpu=True, # fit workers into one GPU
request_plane=request_plane, request_plane=request_plane,
) ) as sglang_workers:
# Start SGLang workers
logger.info(f"Starting {N_SGLANG_WORKERS} SGLang workers")
logger.info(f"All SGLang workers using namespace: {sglang_workers.namespace}") logger.info(f"All SGLang workers using namespace: {sglang_workers.namespace}")
sglang_workers.__enter__()
# Run basic router test (starts router internally and waits for workers to be ready) # Run basic router test (starts router internally and waits for workers to be ready)
frontend_port = allocate_frontend_ports(request, 1)[0] frontend_port = allocate_frontend_ports(request, 1)[0]
...@@ -369,10 +367,6 @@ def test_sglang_kv_router_basic( ...@@ -369,10 +367,6 @@ def test_sglang_kv_router_basic(
request_plane=request_plane, request_plane=request_plane,
) )
finally:
if "sglang_workers" in locals():
sglang_workers.__exit__(None, None, None)
@pytest.mark.pre_merge @pytest.mark.pre_merge
@pytest.mark.gpu_1 @pytest.mark.gpu_1
...@@ -390,21 +384,18 @@ def test_router_decisions_sglang_multiple_workers( ...@@ -390,21 +384,18 @@ def test_router_decisions_sglang_multiple_workers(
logger.info("Starting SGLang router prefix reuse test with two workers") logger.info("Starting SGLang router prefix reuse test with two workers")
N_WORKERS = 2 N_WORKERS = 2
try: with SGLangProcess(
# Start 2 worker processes on the same GPU
logger.info("Starting 2 SGLang worker processes on single GPU (mem_frac=0.4)")
sglang_workers = SGLangProcess(
request, request,
sglang_args=SGLANG_ARGS, sglang_args=SGLANG_ARGS,
num_workers=N_WORKERS, num_workers=N_WORKERS,
single_gpu=True, # Worker uses GPU 0 single_gpu=True, # Worker uses GPU 0
request_plane=request_plane, request_plane=request_plane,
) ) as sglang_workers:
# Start 2 worker processes on the same GPU
logger.info("Starting 2 SGLang worker processes on single GPU (mem_frac=0.4)")
logger.info(f"All SGLang workers using namespace: {sglang_workers.namespace}") logger.info(f"All SGLang workers using namespace: {sglang_workers.namespace}")
# Initialize SGLang workers # Initialize SGLang workers
sglang_workers.__enter__()
# Get runtime and create endpoint # Get runtime and create endpoint
runtime = get_runtime(request_plane=request_plane) runtime = get_runtime(request_plane=request_plane)
namespace = runtime.namespace(sglang_workers.namespace) namespace = runtime.namespace(sglang_workers.namespace)
...@@ -415,11 +406,6 @@ def test_router_decisions_sglang_multiple_workers( ...@@ -415,11 +406,6 @@ def test_router_decisions_sglang_multiple_workers(
sglang_workers, endpoint, MODEL_NAME, request, test_dp_rank=False sglang_workers, endpoint, MODEL_NAME, request, test_dp_rank=False
) )
finally:
# Clean up SGLang workers
if "sglang_workers" in locals():
sglang_workers.__exit__(None, None, None)
@pytest.mark.gpu_2 @pytest.mark.gpu_2
@pytest.mark.post_merge @pytest.mark.post_merge
...@@ -442,18 +428,16 @@ def test_router_decisions_sglang_dp( ...@@ -442,18 +428,16 @@ def test_router_decisions_sglang_dp(
N_WORKERS = 1 N_WORKERS = 1
DP_SIZE = 2 DP_SIZE = 2
try: with SGLangProcess(
logger.info("Starting 2 SGLang DP ranks (dp_size=2) (mem_frac=0.4)")
sglang_workers = SGLangProcess(
request, request,
sglang_args=SGLANG_ARGS, sglang_args=SGLANG_ARGS,
num_workers=N_WORKERS, # Ignored when data_parallel_size is set num_workers=N_WORKERS, # Ignored when data_parallel_size is set
single_gpu=False, single_gpu=False,
data_parallel_size=DP_SIZE, # Creates DP_SIZE processes (one per rank) data_parallel_size=DP_SIZE, # Creates DP_SIZE processes (one per rank)
request_plane=request_plane, request_plane=request_plane,
) ) as sglang_workers:
logger.info("Starting 2 SGLang DP ranks (dp_size=2) (mem_frac=0.4)")
logger.info(f"All SGLang workers using namespace: {sglang_workers.namespace}") logger.info(f"All SGLang workers using namespace: {sglang_workers.namespace}")
sglang_workers.__enter__()
# Get runtime and create endpoint # Get runtime and create endpoint
runtime = get_runtime(request_plane=request_plane) runtime = get_runtime(request_plane=request_plane)
...@@ -466,11 +450,6 @@ def test_router_decisions_sglang_dp( ...@@ -466,11 +450,6 @@ def test_router_decisions_sglang_dp(
sglang_workers, endpoint, MODEL_NAME, request, test_dp_rank=True sglang_workers, endpoint, MODEL_NAME, request, test_dp_rank=True
) )
finally:
# Clean up SGLang workers
if "sglang_workers" in locals():
sglang_workers.__exit__(None, None, None)
@pytest.mark.pre_merge @pytest.mark.pre_merge
@pytest.mark.gpu_1 @pytest.mark.gpu_1
...@@ -511,10 +490,7 @@ def test_sglang_indexers_sync( ...@@ -511,10 +490,7 @@ def test_sglang_indexers_sync(
N_SGLANG_WORKERS = 2 N_SGLANG_WORKERS = 2
try: with SGLangProcess(
# Start SGLang workers
logger.info(f"Starting {N_SGLANG_WORKERS} SGLang workers")
sglang_workers = SGLangProcess(
request, request,
sglang_args=SGLANG_ARGS, sglang_args=SGLANG_ARGS,
num_workers=N_SGLANG_WORKERS, num_workers=N_SGLANG_WORKERS,
...@@ -522,9 +498,10 @@ def test_sglang_indexers_sync( ...@@ -522,9 +498,10 @@ def test_sglang_indexers_sync(
request_plane=request_plane, request_plane=request_plane,
store_backend=store_backend, store_backend=store_backend,
durable_kv_events=durable_kv_events, durable_kv_events=durable_kv_events,
) ) as sglang_workers:
# Start SGLang workers
logger.info(f"Starting {N_SGLANG_WORKERS} SGLang workers")
logger.info(f"All SGLang workers using namespace: {sglang_workers.namespace}") logger.info(f"All SGLang workers using namespace: {sglang_workers.namespace}")
sglang_workers.__enter__()
# Use the common test implementation (creates its own runtimes for each router) # Use the common test implementation (creates its own runtimes for each router)
# Note: Consumer verification is done inside _test_router_indexers_sync while routers are alive # Note: Consumer verification is done inside _test_router_indexers_sync while routers are alive
...@@ -542,7 +519,3 @@ def test_sglang_indexers_sync( ...@@ -542,7 +519,3 @@ def test_sglang_indexers_sync(
) )
logger.info("SGLang indexers sync test completed successfully") logger.info("SGLang indexers sync test completed successfully")
finally:
if "sglang_workers" in locals():
sglang_workers.__exit__(None, None, None)
...@@ -332,18 +332,16 @@ def test_trtllm_kv_router_basic( ...@@ -332,18 +332,16 @@ def test_trtllm_kv_router_basic(
f"Starting TRT-LLM KV router test with {N_TRTLLM_WORKERS} workers using request_plane={request_plane}" f"Starting TRT-LLM KV router test with {N_TRTLLM_WORKERS} workers using request_plane={request_plane}"
) )
try: with TRTLLMProcess(
# Start TRT-LLM workers
logger.info(f"Starting {N_TRTLLM_WORKERS} TRT-LLM workers")
trtllm_workers = TRTLLMProcess(
request, request,
trtllm_args=TRTLLM_ARGS, trtllm_args=TRTLLM_ARGS,
num_workers=N_TRTLLM_WORKERS, num_workers=N_TRTLLM_WORKERS,
single_gpu=True, # fit workers into one GPU single_gpu=True, # fit workers into one GPU
request_plane=request_plane, request_plane=request_plane,
) ) as trtllm_workers:
# Start TRT-LLM workers
logger.info(f"Starting {N_TRTLLM_WORKERS} TRT-LLM workers")
logger.info(f"All TRT-LLM workers using namespace: {trtllm_workers.namespace}") logger.info(f"All TRT-LLM workers using namespace: {trtllm_workers.namespace}")
trtllm_workers.__enter__()
# Run basic router test (starts router internally and waits for workers to be ready) # Run basic router test (starts router internally and waits for workers to be ready)
frontend_port = allocate_frontend_ports(request, 1)[0] frontend_port = allocate_frontend_ports(request, 1)[0]
...@@ -359,10 +357,6 @@ def test_trtllm_kv_router_basic( ...@@ -359,10 +357,6 @@ def test_trtllm_kv_router_basic(
request_plane=request_plane, request_plane=request_plane,
) )
finally:
if "trtllm_workers" in locals():
trtllm_workers.__exit__(None, None, None)
@pytest.mark.gpu_2 @pytest.mark.gpu_2
@pytest.mark.nightly @pytest.mark.nightly
...@@ -392,19 +386,17 @@ def test_router_decisions_trtllm_attention_dp( ...@@ -392,19 +386,17 @@ def test_router_decisions_trtllm_attention_dp(
"tensor_parallel_size": N_ATTENTION_DP_RANKS, "tensor_parallel_size": N_ATTENTION_DP_RANKS,
} }
try: with TRTLLMProcess(
logger.info(
f"Starting 1 TRT-LLM worker with attention DP enabled (attention_dp_size={N_ATTENTION_DP_RANKS})"
)
trtllm_workers = TRTLLMProcess(
request, request,
trtllm_args=TRTLLM_ADP_ARGS, trtllm_args=TRTLLM_ADP_ARGS,
num_workers=N_TRTLLM_WORKERS, num_workers=N_TRTLLM_WORKERS,
single_gpu=False, single_gpu=False,
request_plane=request_plane, request_plane=request_plane,
) as trtllm_workers:
logger.info(
f"Starting 1 TRT-LLM worker with attention DP enabled (attention_dp_size={N_ATTENTION_DP_RANKS})"
) )
logger.info(f"All TRT-LLM workers using namespace: {trtllm_workers.namespace}") logger.info(f"All TRT-LLM workers using namespace: {trtllm_workers.namespace}")
trtllm_workers.__enter__()
# Get runtime and create endpoint # Get runtime and create endpoint
runtime = get_runtime(request_plane=request_plane) runtime = get_runtime(request_plane=request_plane)
...@@ -422,11 +414,6 @@ def test_router_decisions_trtllm_attention_dp( ...@@ -422,11 +414,6 @@ def test_router_decisions_trtllm_attention_dp(
block_size=TRTLLM_BLOCK_SIZE, block_size=TRTLLM_BLOCK_SIZE,
) )
finally:
# Clean up TRTLLM workers
if "trtllm_workers" in locals():
trtllm_workers.__exit__(None, None, None)
@pytest.mark.pre_merge @pytest.mark.pre_merge
@pytest.mark.gpu_1 @pytest.mark.gpu_1
...@@ -443,23 +430,20 @@ def test_router_decisions_trtllm_multiple_workers( ...@@ -443,23 +430,20 @@ def test_router_decisions_trtllm_multiple_workers(
logger.info("Starting TRT-LLM router prefix reuse test with two workers") logger.info("Starting TRT-LLM router prefix reuse test with two workers")
N_WORKERS = 2 N_WORKERS = 2
try: with TRTLLMProcess(
# Start 2 worker processes on the same GPU
logger.info(
"Starting 2 TRT-LLM worker processes on single GPU (gpu_mem_frac=0.4)"
)
trtllm_workers = TRTLLMProcess(
request, request,
trtllm_args=TRTLLM_ARGS, trtllm_args=TRTLLM_ARGS,
num_workers=N_WORKERS, num_workers=N_WORKERS,
single_gpu=True, # Worker uses GPU 0 single_gpu=True, # Worker uses GPU 0
request_plane=request_plane, request_plane=request_plane,
) as trtllm_workers:
# Start 2 worker processes on the same GPU
logger.info(
"Starting 2 TRT-LLM worker processes on single GPU (gpu_mem_frac=0.4)"
) )
logger.info(f"All TRT-LLM workers using namespace: {trtllm_workers.namespace}") logger.info(f"All TRT-LLM workers using namespace: {trtllm_workers.namespace}")
# Initialize TRT-LLM workers # Initialize TRT-LLM workers
trtllm_workers.__enter__()
# Get runtime and create endpoint # Get runtime and create endpoint
runtime = get_runtime(request_plane=request_plane) runtime = get_runtime(request_plane=request_plane)
namespace = runtime.namespace(trtllm_workers.namespace) namespace = runtime.namespace(trtllm_workers.namespace)
...@@ -475,11 +459,6 @@ def test_router_decisions_trtllm_multiple_workers( ...@@ -475,11 +459,6 @@ def test_router_decisions_trtllm_multiple_workers(
block_size=TRTLLM_BLOCK_SIZE, block_size=TRTLLM_BLOCK_SIZE,
) )
finally:
# Clean up TRT-LLM workers
if "trtllm_workers" in locals():
trtllm_workers.__exit__(None, None, None)
@pytest.mark.pre_merge @pytest.mark.pre_merge
@pytest.mark.gpu_1 @pytest.mark.gpu_1
...@@ -520,10 +499,7 @@ def test_trtllm_indexers_sync( ...@@ -520,10 +499,7 @@ def test_trtllm_indexers_sync(
N_TRTLLM_WORKERS = 2 N_TRTLLM_WORKERS = 2
try: with TRTLLMProcess(
# Start TRT-LLM workers
logger.info(f"Starting {N_TRTLLM_WORKERS} TRT-LLM workers")
trtllm_workers = TRTLLMProcess(
request, request,
trtllm_args=TRTLLM_ARGS, trtllm_args=TRTLLM_ARGS,
num_workers=N_TRTLLM_WORKERS, num_workers=N_TRTLLM_WORKERS,
...@@ -531,9 +507,10 @@ def test_trtllm_indexers_sync( ...@@ -531,9 +507,10 @@ def test_trtllm_indexers_sync(
request_plane=request_plane, request_plane=request_plane,
store_backend=store_backend, store_backend=store_backend,
durable_kv_events=durable_kv_events, durable_kv_events=durable_kv_events,
) ) as trtllm_workers:
# Start TRT-LLM workers
logger.info(f"Starting {N_TRTLLM_WORKERS} TRT-LLM workers")
logger.info(f"All TRT-LLM workers using namespace: {trtllm_workers.namespace}") logger.info(f"All TRT-LLM workers using namespace: {trtllm_workers.namespace}")
trtllm_workers.__enter__()
# Use the common test implementation (creates its own runtimes for each router) # Use the common test implementation (creates its own runtimes for each router)
# Note: Consumer verification is done inside _test_router_indexers_sync while routers are alive # Note: Consumer verification is done inside _test_router_indexers_sync while routers are alive
...@@ -551,7 +528,3 @@ def test_trtllm_indexers_sync( ...@@ -551,7 +528,3 @@ def test_trtllm_indexers_sync(
) )
logger.info("TRT-LLM indexers sync test completed successfully") logger.info("TRT-LLM indexers sync test completed successfully")
finally:
if "trtllm_workers" in locals():
trtllm_workers.__exit__(None, None, None)
...@@ -354,18 +354,16 @@ def test_vllm_kv_router_basic( ...@@ -354,18 +354,16 @@ def test_vllm_kv_router_basic(
f"Starting vLLM KV router test with {N_VLLM_WORKERS} workers using request_plane={request_plane}" f"Starting vLLM KV router test with {N_VLLM_WORKERS} workers using request_plane={request_plane}"
) )
try: with VLLMProcess(
# Start vLLM workers
logger.info(f"Starting {N_VLLM_WORKERS} vLLM workers")
vllm_workers = VLLMProcess(
request, request,
vllm_args=VLLM_ARGS, vllm_args=VLLM_ARGS,
num_workers=N_VLLM_WORKERS, num_workers=N_VLLM_WORKERS,
single_gpu=True, # fit workers into one GPU single_gpu=True, # fit workers into one GPU
request_plane=request_plane, request_plane=request_plane,
) ) as vllm_workers:
# Start vLLM workers
logger.info(f"Starting {N_VLLM_WORKERS} vLLM workers")
logger.info(f"All vLLM workers using namespace: {vllm_workers.namespace}") logger.info(f"All vLLM workers using namespace: {vllm_workers.namespace}")
vllm_workers.__enter__()
# Run basic router test (starts router internally and waits for workers to be ready) # Run basic router test (starts router internally and waits for workers to be ready)
frontend_port = allocate_frontend_ports(request, 1)[0] frontend_port = allocate_frontend_ports(request, 1)[0]
...@@ -381,10 +379,6 @@ def test_vllm_kv_router_basic( ...@@ -381,10 +379,6 @@ def test_vllm_kv_router_basic(
request_plane=request_plane, request_plane=request_plane,
) )
finally:
if "vllm_workers" in locals():
vllm_workers.__exit__(None, None, None)
@pytest.mark.pre_merge @pytest.mark.pre_merge
@pytest.mark.gpu_1 @pytest.mark.gpu_1
...@@ -401,21 +395,17 @@ def test_router_decisions_vllm_multiple_workers( ...@@ -401,21 +395,17 @@ def test_router_decisions_vllm_multiple_workers(
logger.info("Starting vLLM router prefix reuse test with two workers") logger.info("Starting vLLM router prefix reuse test with two workers")
N_WORKERS = 2 N_WORKERS = 2
try: with VLLMProcess(
# Start 2 worker processes on the same GPU
logger.info("Starting 2 vLLM worker processes on single GPU (gpu_mem=0.4)")
vllm_workers = VLLMProcess(
request, request,
vllm_args=VLLM_ARGS, vllm_args=VLLM_ARGS,
num_workers=N_WORKERS, num_workers=N_WORKERS,
single_gpu=True, # Worker uses GPU 0 single_gpu=True, # Worker uses GPU 0
request_plane=request_plane, request_plane=request_plane,
) ) as vllm_workers:
# Start 2 worker processes on the same GPU
logger.info("Starting 2 vLLM worker processes on single GPU (gpu_mem=0.4)")
logger.info(f"All vLLM workers using namespace: {vllm_workers.namespace}") logger.info(f"All vLLM workers using namespace: {vllm_workers.namespace}")
# Initialize vLLM workers
vllm_workers.__enter__()
# Get runtime and create endpoint # Get runtime and create endpoint
runtime = get_runtime(request_plane=request_plane) runtime = get_runtime(request_plane=request_plane)
namespace = runtime.namespace(vllm_workers.namespace) namespace = runtime.namespace(vllm_workers.namespace)
...@@ -426,11 +416,6 @@ def test_router_decisions_vllm_multiple_workers( ...@@ -426,11 +416,6 @@ def test_router_decisions_vllm_multiple_workers(
vllm_workers, endpoint, MODEL_NAME, request, test_dp_rank=False vllm_workers, endpoint, MODEL_NAME, request, test_dp_rank=False
) )
finally:
# Clean up vLLM workers
if "vllm_workers" in locals():
vllm_workers.__exit__(None, None, None)
@pytest.mark.gpu_2 @pytest.mark.gpu_2
@pytest.mark.nightly @pytest.mark.nightly
...@@ -453,18 +438,16 @@ def test_router_decisions_vllm_dp( ...@@ -453,18 +438,16 @@ def test_router_decisions_vllm_dp(
N_WORKERS = 1 N_WORKERS = 1
DP_SIZE = 2 DP_SIZE = 2
try: with VLLMProcess(
logger.info("Starting 2 vLLM DP ranks (dp_size=2) (gpu_mem=0.4)")
vllm_workers = VLLMProcess(
request, request,
vllm_args=VLLM_ARGS, vllm_args=VLLM_ARGS,
num_workers=N_WORKERS, # Ignored when data_parallel_size is set num_workers=N_WORKERS, # Ignored when data_parallel_size is set
single_gpu=False, single_gpu=False,
data_parallel_size=DP_SIZE, # Creates DP_SIZE processes (one per rank) data_parallel_size=DP_SIZE, # Creates DP_SIZE processes (one per rank)
request_plane=request_plane, request_plane=request_plane,
) ) as vllm_workers:
logger.info("Starting 2 vLLM DP ranks (dp_size=2) (gpu_mem=0.4)")
logger.info(f"All vLLM workers using namespace: {vllm_workers.namespace}") logger.info(f"All vLLM workers using namespace: {vllm_workers.namespace}")
vllm_workers.__enter__()
# Get runtime and create endpoint # Get runtime and create endpoint
runtime = get_runtime(request_plane=request_plane) runtime = get_runtime(request_plane=request_plane)
...@@ -477,11 +460,6 @@ def test_router_decisions_vllm_dp( ...@@ -477,11 +460,6 @@ def test_router_decisions_vllm_dp(
vllm_workers, endpoint, MODEL_NAME, request, test_dp_rank=True vllm_workers, endpoint, MODEL_NAME, request, test_dp_rank=True
) )
finally:
# Clean up vLLM workers
if "vllm_workers" in locals():
vllm_workers.__exit__(None, None, None)
@pytest.mark.pre_merge @pytest.mark.pre_merge
@pytest.mark.gpu_1 @pytest.mark.gpu_1
...@@ -522,10 +500,7 @@ def test_vllm_indexers_sync( ...@@ -522,10 +500,7 @@ def test_vllm_indexers_sync(
N_VLLM_WORKERS = 2 N_VLLM_WORKERS = 2
try: with VLLMProcess(
# Start vLLM workers
logger.info(f"Starting {N_VLLM_WORKERS} vLLM workers")
vllm_workers = VLLMProcess(
request, request,
vllm_args=VLLM_ARGS, vllm_args=VLLM_ARGS,
num_workers=N_VLLM_WORKERS, num_workers=N_VLLM_WORKERS,
...@@ -533,9 +508,10 @@ def test_vllm_indexers_sync( ...@@ -533,9 +508,10 @@ def test_vllm_indexers_sync(
request_plane=request_plane, request_plane=request_plane,
store_backend=store_backend, store_backend=store_backend,
durable_kv_events=durable_kv_events, durable_kv_events=durable_kv_events,
) ) as vllm_workers:
# Start vLLM workers
logger.info(f"Starting {N_VLLM_WORKERS} vLLM workers")
logger.info(f"All vLLM workers using namespace: {vllm_workers.namespace}") logger.info(f"All vLLM workers using namespace: {vllm_workers.namespace}")
vllm_workers.__enter__()
# Use the common test implementation (creates its own runtimes for each router) # Use the common test implementation (creates its own runtimes for each router)
# Note: Consumer verification is done inside _test_router_indexers_sync while routers are alive # Note: Consumer verification is done inside _test_router_indexers_sync while routers are alive
...@@ -553,7 +529,3 @@ def test_vllm_indexers_sync( ...@@ -553,7 +529,3 @@ def test_vllm_indexers_sync(
) )
logger.info("vLLM indexers sync test completed successfully") logger.info("vLLM indexers sync test completed successfully")
finally:
if "vllm_workers" in locals():
vllm_workers.__exit__(None, None, None)
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment