Unverified Commit 00f8615e authored by Jacky's avatar Jacky Committed by GitHub
Browse files

test: Ensure Request Cancellation, Migration, Rejection Work with TCP transport (#4875)


Signed-off-by: default avatarJacky <18255193+kthui@users.noreply.github.com>
parent c8845b41
...@@ -165,6 +165,7 @@ async def client(runtime, namespace): ...@@ -165,6 +165,7 @@ async def client(runtime, namespace):
@pytest.mark.forked @pytest.mark.forked
@pytest.mark.asyncio @pytest.mark.asyncio
@pytest.mark.parametrize("request_plane", ["nats", "tcp"], indirect=True)
async def test_client_context_cancel(temp_file_store, server, client): async def test_client_context_cancel(temp_file_store, server, client):
_, handler = server _, handler = server
context = Context() context = Context()
...@@ -198,6 +199,7 @@ async def test_client_context_cancel(temp_file_store, server, client): ...@@ -198,6 +199,7 @@ async def test_client_context_cancel(temp_file_store, server, client):
@pytest.mark.forked @pytest.mark.forked
@pytest.mark.asyncio @pytest.mark.asyncio
@pytest.mark.parametrize("request_plane", ["nats", "tcp"], indirect=True)
async def test_client_loop_break(temp_file_store, server, client): async def test_client_loop_break(temp_file_store, server, client):
_, handler = server _, handler = server
stream = await client.generate("_generate_until_context_cancelled") stream = await client.generate("_generate_until_context_cancelled")
...@@ -230,6 +232,7 @@ async def test_client_loop_break(temp_file_store, server, client): ...@@ -230,6 +232,7 @@ async def test_client_loop_break(temp_file_store, server, client):
@pytest.mark.forked @pytest.mark.forked
@pytest.mark.asyncio @pytest.mark.asyncio
@pytest.mark.parametrize("request_plane", ["nats", "tcp"], indirect=True)
async def test_server_context_cancel(temp_file_store, server, client): async def test_server_context_cancel(temp_file_store, server, client):
_, handler = server _, handler = server
stream = await client.generate("_generate_and_cancel_context") stream = await client.generate("_generate_and_cancel_context")
...@@ -254,6 +257,7 @@ async def test_server_context_cancel(temp_file_store, server, client): ...@@ -254,6 +257,7 @@ async def test_server_context_cancel(temp_file_store, server, client):
@pytest.mark.forked @pytest.mark.forked
@pytest.mark.asyncio @pytest.mark.asyncio
@pytest.mark.parametrize("request_plane", ["nats", "tcp"], indirect=True)
async def test_server_raise_cancelled(temp_file_store, server, client): async def test_server_raise_cancelled(temp_file_store, server, client):
_, handler = server _, handler = server
stream = await client.generate("_generate_and_raise_cancelled") stream = await client.generate("_generate_and_raise_cancelled")
...@@ -282,6 +286,7 @@ async def test_server_raise_cancelled(temp_file_store, server, client): ...@@ -282,6 +286,7 @@ async def test_server_raise_cancelled(temp_file_store, server, client):
@pytest.mark.forked @pytest.mark.forked
@pytest.mark.asyncio @pytest.mark.asyncio
@pytest.mark.parametrize("request_plane", ["nats", "tcp"], indirect=True)
async def test_client_context_already_cancelled(temp_file_store, server, client): async def test_client_context_already_cancelled(temp_file_store, server, client):
_, handler = server _, handler = server
context = Context() context = Context()
...@@ -304,6 +309,7 @@ async def test_client_context_already_cancelled(temp_file_store, server, client) ...@@ -304,6 +309,7 @@ async def test_client_context_already_cancelled(temp_file_store, server, client)
@pytest.mark.forked @pytest.mark.forked
@pytest.mark.asyncio @pytest.mark.asyncio
@pytest.mark.parametrize("request_plane", ["nats", "tcp"], indirect=True)
async def test_client_context_cancel_before_await_request( async def test_client_context_cancel_before_await_request(
temp_file_store, server, client temp_file_store, server, client
): ):
......
...@@ -402,8 +402,34 @@ def temp_file_store(): ...@@ -402,8 +402,34 @@ def temp_file_store():
yield tmpdir yield tmpdir
@pytest.fixture
def store_kv(request):
"""
KV store for runtime. Defaults to "file".
To iterate over multiple stores in a test:
@pytest.mark.parametrize("store_kv", ["file", "etcd"], indirect=True)
async def test_example(runtime):
...
"""
return getattr(request, "param", "file")
@pytest.fixture
def request_plane(request):
"""
Request plane for runtime. Defaults to "nats".
To iterate over multiple transports in a test:
@pytest.mark.parametrize("request_plane", ["tcp", "nats"], indirect=True)
async def test_example(runtime):
...
"""
return getattr(request, "param", "nats")
@pytest.fixture(scope="function", autouse=False) @pytest.fixture(scope="function", autouse=False)
async def runtime(request): async def runtime(request, store_kv, request_plane):
""" """
Create a DistributedRuntime for testing. Create a DistributedRuntime for testing.
...@@ -413,6 +439,14 @@ async def runtime(request): ...@@ -413,6 +439,14 @@ async def runtime(request):
Without @pytest.mark.forked in isolated mode, you will get "Worker already initialized" Without @pytest.mark.forked in isolated mode, you will get "Worker already initialized"
errors when multiple tests try to create runtimes in the same process. errors when multiple tests try to create runtimes in the same process.
The store_kv and request_plane can be customized by overriding their fixtures
or using @pytest.mark.parametrize with indirect=True:
@pytest.mark.forked
@pytest.mark.parametrize("store_kv", ["etcd"], indirect=True)
async def test_with_etcd(runtime):
...
""" """
# Check if the test is marked with @pytest.mark.forked (only in isolated mode) # Check if the test is marked with @pytest.mark.forked (only in isolated mode)
if ENABLE_ISOLATED_ETCD_AND_NATS: if ENABLE_ISOLATED_ETCD_AND_NATS:
...@@ -435,6 +469,6 @@ This is required because DistributedRuntime is a process-level singleton. ...@@ -435,6 +469,6 @@ This is required because DistributedRuntime is a process-level singleton.
) )
loop = asyncio.get_running_loop() loop = asyncio.get_running_loop()
runtime = DistributedRuntime(loop, "file", "nats") runtime = DistributedRuntime(loop, store_kv, request_plane)
yield runtime yield runtime
runtime.shutdown() runtime.shutdown()
...@@ -412,11 +412,52 @@ class SharedNatsServer(SharedManagedProcess): ...@@ -412,11 +412,52 @@ class SharedNatsServer(SharedManagedProcess):
return server return server
@pytest.fixture
def store_kv(request):
"""
KV store for runtime. Defaults to "etcd".
To iterate over multiple stores in a test:
@pytest.mark.parametrize("store_kv", ["file", "etcd"], indirect=True)
def test_example(runtime_services):
...
"""
return getattr(request, "param", "etcd")
@pytest.fixture
def request_plane(request):
"""
Request plane for runtime. Defaults to "nats".
To iterate over multiple transports in a test:
@pytest.mark.parametrize("request_plane", ["nats", "tcp"], indirect=True)
def test_example(runtime_services):
...
"""
return getattr(request, "param", "nats")
@pytest.fixture() @pytest.fixture()
def runtime_services(request): def runtime_services(request, store_kv, request_plane):
with NatsServer(request) as nats_process: """
Start runtime services (NATS and/or etcd) based on store_kv and request_plane.
- If store_kv != "etcd", etcd is not started (returns None)
- If request_plane != "nats", NATS is not started (returns None)
"""
if request_plane == "nats" and store_kv == "etcd":
with NatsServer(request) as nats_process:
with EtcdServer(request) as etcd_process:
yield nats_process, etcd_process
elif request_plane == "nats":
with NatsServer(request) as nats_process:
yield nats_process, None
elif store_kv == "etcd":
with EtcdServer(request) as etcd_process: with EtcdServer(request) as etcd_process:
yield nats_process, etcd_process yield None, etcd_process
else:
yield None, None
@pytest.fixture(scope="session") @pytest.fixture(scope="session")
......
...@@ -89,8 +89,10 @@ class DynamoWorkerProcess(ManagedProcess): ...@@ -89,8 +89,10 @@ class DynamoWorkerProcess(ManagedProcess):
else: # agg (aggregated mode) else: # agg (aggregated mode)
port = "8081" port = "8081"
# Set debug logging environment # Set environment variables
env = os.environ.copy() env = os.environ.copy()
env["DYN_REQUEST_PLANE"] = request.getfixturevalue("request_plane")
env["DYN_LOG"] = "debug" env["DYN_LOG"] = "debug"
# Disable canary health check - these tests expect full control over requests # Disable canary health check - these tests expect full control over requests
# sent to the workers where canary health check intermittently sends dummy # sent to the workers where canary health check intermittently sends dummy
...@@ -161,6 +163,7 @@ class DynamoWorkerProcess(ManagedProcess): ...@@ -161,6 +163,7 @@ class DynamoWorkerProcess(ManagedProcess):
@pytest.mark.timeout(160) # 3x average @pytest.mark.timeout(160) # 3x average
@pytest.mark.gpu_1 @pytest.mark.gpu_1
@pytest.mark.xfail(strict=False) @pytest.mark.xfail(strict=False)
@pytest.mark.parametrize("request_plane", ["nats", "tcp"], indirect=True)
def test_request_cancellation_sglang_aggregated(request, runtime_services): def test_request_cancellation_sglang_aggregated(request, runtime_services):
""" """
End-to-end test for request cancellation functionality in aggregated mode. End-to-end test for request cancellation functionality in aggregated mode.
...@@ -245,6 +248,17 @@ def test_request_cancellation_sglang_aggregated(request, runtime_services): ...@@ -245,6 +248,17 @@ def test_request_cancellation_sglang_aggregated(request, runtime_services):
@pytest.mark.timeout(185) # 3x average @pytest.mark.timeout(185) # 3x average
@pytest.mark.gpu_2 @pytest.mark.gpu_2
@pytest.mark.parametrize(
"request_plane",
[
"nats",
pytest.param(
"tcp",
marks=pytest.mark.xfail(reason="Multi-worker TCP unstable", strict=False),
),
],
indirect=True,
)
def test_request_cancellation_sglang_decode_cancel(request, runtime_services): def test_request_cancellation_sglang_decode_cancel(request, runtime_services):
""" """
End-to-end test for request cancellation during decode phase. End-to-end test for request cancellation during decode phase.
......
...@@ -85,8 +85,10 @@ class DynamoWorkerProcess(ManagedProcess): ...@@ -85,8 +85,10 @@ class DynamoWorkerProcess(ManagedProcess):
else: # prefill_and_decode else: # prefill_and_decode
port = "8081" port = "8081"
# Set debug logging environment # Set environment variables
env = os.environ.copy() env = os.environ.copy()
env["DYN_REQUEST_PLANE"] = request.getfixturevalue("request_plane")
env["DYN_LOG"] = "debug" env["DYN_LOG"] = "debug"
# Disable canary health check - these tests expect full control over requests # Disable canary health check - these tests expect full control over requests
# sent to the workers where canary health check intermittently sends dummy # sent to the workers where canary health check intermittently sends dummy
...@@ -141,6 +143,7 @@ class DynamoWorkerProcess(ManagedProcess): ...@@ -141,6 +143,7 @@ class DynamoWorkerProcess(ManagedProcess):
@pytest.mark.timeout(140) # 3x average @pytest.mark.timeout(140) # 3x average
@pytest.mark.parametrize("request_plane", ["nats", "tcp"], indirect=True)
def test_request_cancellation_trtllm_aggregated(request, runtime_services): def test_request_cancellation_trtllm_aggregated(request, runtime_services):
""" """
End-to-end test for request cancellation functionality in aggregated mode. End-to-end test for request cancellation functionality in aggregated mode.
...@@ -213,6 +216,17 @@ def test_request_cancellation_trtllm_aggregated(request, runtime_services): ...@@ -213,6 +216,17 @@ def test_request_cancellation_trtllm_aggregated(request, runtime_services):
@pytest.mark.timeout(350) # 3x average @pytest.mark.timeout(350) # 3x average
@pytest.mark.parametrize(
"request_plane",
[
"nats",
pytest.param(
"tcp",
marks=pytest.mark.xfail(reason="Multi-worker TCP unstable", strict=False),
),
],
indirect=True,
)
def test_request_cancellation_trtllm_decode_cancel(request, runtime_services): def test_request_cancellation_trtllm_decode_cancel(request, runtime_services):
""" """
End-to-end test for request cancellation during decode phase with unified frontend. End-to-end test for request cancellation during decode phase with unified frontend.
...@@ -284,6 +298,17 @@ def test_request_cancellation_trtllm_decode_cancel(request, runtime_services): ...@@ -284,6 +298,17 @@ def test_request_cancellation_trtllm_decode_cancel(request, runtime_services):
@pytest.mark.timeout(350) # 3x average @pytest.mark.timeout(350) # 3x average
@pytest.mark.parametrize(
"request_plane",
[
"nats",
pytest.param(
"tcp",
marks=pytest.mark.xfail(reason="Multi-worker TCP unstable", strict=False),
),
],
indirect=True,
)
def test_request_cancellation_trtllm_prefill_cancel(request, runtime_services): def test_request_cancellation_trtllm_prefill_cancel(request, runtime_services):
""" """
End-to-end test for request cancellation during prefill phase with unified frontend. End-to-end test for request cancellation during prefill phase with unified frontend.
...@@ -365,6 +390,7 @@ def test_request_cancellation_trtllm_prefill_cancel(request, runtime_services): ...@@ -365,6 +390,7 @@ def test_request_cancellation_trtllm_prefill_cancel(request, runtime_services):
@pytest.mark.timeout(350) # 3x average @pytest.mark.timeout(350) # 3x average
@pytest.mark.parametrize("request_plane", ["nats", "tcp"], indirect=True)
@pytest.mark.xfail( @pytest.mark.xfail(
reason="May fail due to unknown reason with TRT-LLM or backend implementation", reason="May fail due to unknown reason with TRT-LLM or backend implementation",
strict=False, strict=False,
......
...@@ -65,8 +65,10 @@ class DynamoWorkerProcess(ManagedProcess): ...@@ -65,8 +65,10 @@ class DynamoWorkerProcess(ManagedProcess):
(f"http://localhost:{FRONTEND_PORT}/health", check_health_generate), (f"http://localhost:{FRONTEND_PORT}/health", check_health_generate),
] ]
# Set debug logging environment # Set environment variables
env = os.environ.copy() env = os.environ.copy()
env["DYN_REQUEST_PLANE"] = request.getfixturevalue("request_plane")
env["DYN_LOG"] = "debug" env["DYN_LOG"] = "debug"
# Disable canary health check - these tests expect full control over requests # Disable canary health check - these tests expect full control over requests
# sent to the workers where canary health check intermittently sends dummy # sent to the workers where canary health check intermittently sends dummy
...@@ -134,6 +136,7 @@ class DynamoWorkerProcess(ManagedProcess): ...@@ -134,6 +136,7 @@ class DynamoWorkerProcess(ManagedProcess):
@pytest.mark.timeout(110) # 3x average @pytest.mark.timeout(110) # 3x average
@pytest.mark.parametrize("request_plane", ["nats", "tcp"], indirect=True)
def test_request_cancellation_vllm_aggregated(request, runtime_services): def test_request_cancellation_vllm_aggregated(request, runtime_services):
""" """
End-to-end test for request cancellation functionality in aggregated mode. End-to-end test for request cancellation functionality in aggregated mode.
...@@ -206,6 +209,17 @@ def test_request_cancellation_vllm_aggregated(request, runtime_services): ...@@ -206,6 +209,17 @@ def test_request_cancellation_vllm_aggregated(request, runtime_services):
@pytest.mark.timeout(150) # 3x average @pytest.mark.timeout(150) # 3x average
@pytest.mark.parametrize(
"request_plane",
[
"nats",
pytest.param(
"tcp",
marks=pytest.mark.xfail(reason="Multi-worker TCP unstable", strict=False),
),
],
indirect=True,
)
def test_request_cancellation_vllm_decode_cancel( def test_request_cancellation_vllm_decode_cancel(
request, runtime_services, set_ucx_tls_no_mm request, runtime_services, set_ucx_tls_no_mm
): ):
...@@ -276,6 +290,17 @@ def test_request_cancellation_vllm_decode_cancel( ...@@ -276,6 +290,17 @@ def test_request_cancellation_vllm_decode_cancel(
@pytest.mark.timeout(150) # 3x average @pytest.mark.timeout(150) # 3x average
@pytest.mark.parametrize(
"request_plane",
[
"nats",
pytest.param(
"tcp",
marks=pytest.mark.xfail(reason="Multi-worker TCP unstable", strict=False),
),
],
indirect=True,
)
def test_request_cancellation_vllm_prefill_cancel( def test_request_cancellation_vllm_prefill_cancel(
request, runtime_services, set_ucx_tls_no_mm request, runtime_services, set_ucx_tls_no_mm
): ):
......
...@@ -26,8 +26,8 @@ class DynamoFrontendProcess(ManagedProcess): ...@@ -26,8 +26,8 @@ class DynamoFrontendProcess(ManagedProcess):
def __init__(self, request): def __init__(self, request):
command = ["python", "-m", "dynamo.frontend"] command = ["python", "-m", "dynamo.frontend"]
# Set debug logging environment
env = os.environ.copy() env = os.environ.copy()
env["DYN_REQUEST_PLANE"] = request.getfixturevalue("request_plane")
env["DYN_LOG"] = "debug" env["DYN_LOG"] = "debug"
# Disable canary health check - these tests expect full control over requests # Disable canary health check - these tests expect full control over requests
# sent to the workers where canary health check intermittently sends dummy # sent to the workers where canary health check intermittently sends dummy
......
...@@ -56,8 +56,9 @@ class DynamoWorkerProcess(ManagedProcess): ...@@ -56,8 +56,9 @@ class DynamoWorkerProcess(ManagedProcess):
str(migration_limit), str(migration_limit),
] ]
# Set debug logging environment # Set environment variables
env = os.environ.copy() env = os.environ.copy()
env["DYN_REQUEST_PLANE"] = request.getfixturevalue("request_plane")
env["DYN_LOG"] = "debug" env["DYN_LOG"] = "debug"
# Disable canary health check - these tests expect full control over requests # Disable canary health check - these tests expect full control over requests
# sent to the workers where canary health check intermittently sends dummy # sent to the workers where canary health check intermittently sends dummy
...@@ -114,6 +115,17 @@ class DynamoWorkerProcess(ManagedProcess): ...@@ -114,6 +115,17 @@ class DynamoWorkerProcess(ManagedProcess):
@pytest.mark.timeout(235) # 3x average @pytest.mark.timeout(235) # 3x average
@pytest.mark.parametrize(
"request_plane",
[
"nats",
pytest.param(
"tcp",
marks=pytest.mark.xfail(reason="Multi-worker TCP unstable", strict=False),
),
],
indirect=True,
)
def test_request_migration_sglang_worker_failure( def test_request_migration_sglang_worker_failure(
request, runtime_services, set_ucx_tls_no_mm request, runtime_services, set_ucx_tls_no_mm
): ):
...@@ -158,6 +170,17 @@ def test_request_migration_sglang_worker_failure( ...@@ -158,6 +170,17 @@ def test_request_migration_sglang_worker_failure(
@pytest.mark.skip(reason="SGLang graceful shutdown not yet implemented") @pytest.mark.skip(reason="SGLang graceful shutdown not yet implemented")
@pytest.mark.parametrize(
"request_plane",
[
"nats",
pytest.param(
"tcp",
marks=pytest.mark.xfail(reason="Multi-worker TCP unstable", strict=False),
),
],
indirect=True,
)
def test_request_migration_sglang_graceful_shutdown( def test_request_migration_sglang_graceful_shutdown(
request, runtime_services, set_ucx_tls_no_mm request, runtime_services, set_ucx_tls_no_mm
): ):
...@@ -206,6 +229,17 @@ def test_request_migration_sglang_graceful_shutdown( ...@@ -206,6 +229,17 @@ def test_request_migration_sglang_graceful_shutdown(
@pytest.mark.timeout(135) # 3x average @pytest.mark.timeout(135) # 3x average
@pytest.mark.parametrize(
"request_plane",
[
"nats",
pytest.param(
"tcp",
marks=pytest.mark.xfail(reason="Multi-worker TCP unstable", strict=False),
),
],
indirect=True,
)
def test_no_request_migration_sglang_worker_failure( def test_no_request_migration_sglang_worker_failure(
request, runtime_services, set_ucx_tls_no_mm request, runtime_services, set_ucx_tls_no_mm
): ):
...@@ -266,6 +300,17 @@ def test_no_request_migration_sglang_worker_failure( ...@@ -266,6 +300,17 @@ def test_no_request_migration_sglang_worker_failure(
@pytest.mark.skip(reason="SGLang graceful shutdown not yet implemented") @pytest.mark.skip(reason="SGLang graceful shutdown not yet implemented")
@pytest.mark.parametrize(
"request_plane",
[
"nats",
pytest.param(
"tcp",
marks=pytest.mark.xfail(reason="Multi-worker TCP unstable", strict=False),
),
],
indirect=True,
)
def test_no_request_migration_sglang_graceful_shutdown( def test_no_request_migration_sglang_graceful_shutdown(
request, runtime_services, set_ucx_tls_no_mm request, runtime_services, set_ucx_tls_no_mm
): ):
......
...@@ -54,8 +54,9 @@ class DynamoWorkerProcess(ManagedProcess): ...@@ -54,8 +54,9 @@ class DynamoWorkerProcess(ManagedProcess):
str(migration_limit), str(migration_limit),
] ]
# Set debug logging environment # Set environment variables
env = os.environ.copy() env = os.environ.copy()
env["DYN_REQUEST_PLANE"] = request.getfixturevalue("request_plane")
env["DYN_LOG"] = "debug" env["DYN_LOG"] = "debug"
# Disable canary health check - these tests expect full control over requests # Disable canary health check - these tests expect full control over requests
# sent to the workers where canary health check intermittently sends dummy # sent to the workers where canary health check intermittently sends dummy
...@@ -110,6 +111,17 @@ class DynamoWorkerProcess(ManagedProcess): ...@@ -110,6 +111,17 @@ class DynamoWorkerProcess(ManagedProcess):
@pytest.mark.timeout(290) # 3x average @pytest.mark.timeout(290) # 3x average
@pytest.mark.parametrize(
"request_plane",
[
"nats",
pytest.param(
"tcp",
marks=pytest.mark.xfail(reason="Multi-worker TCP unstable", strict=False),
),
],
indirect=True,
)
def test_request_migration_trtllm_worker_failure( def test_request_migration_trtllm_worker_failure(
request, runtime_services, set_ucx_tls_no_mm request, runtime_services, set_ucx_tls_no_mm
): ):
...@@ -154,6 +166,17 @@ def test_request_migration_trtllm_worker_failure( ...@@ -154,6 +166,17 @@ def test_request_migration_trtllm_worker_failure(
@pytest.mark.skip(reason="TRT-LLM graceful shutdown not yet implemented") @pytest.mark.skip(reason="TRT-LLM graceful shutdown not yet implemented")
@pytest.mark.parametrize(
"request_plane",
[
"nats",
pytest.param(
"tcp",
marks=pytest.mark.xfail(reason="Multi-worker TCP unstable", strict=False),
),
],
indirect=True,
)
def test_request_migration_trtllm_graceful_shutdown( def test_request_migration_trtllm_graceful_shutdown(
request, runtime_services, set_ucx_tls_no_mm request, runtime_services, set_ucx_tls_no_mm
): ):
...@@ -202,6 +225,17 @@ def test_request_migration_trtllm_graceful_shutdown( ...@@ -202,6 +225,17 @@ def test_request_migration_trtllm_graceful_shutdown(
@pytest.mark.timeout(185) # 3x average @pytest.mark.timeout(185) # 3x average
@pytest.mark.parametrize(
"request_plane",
[
"nats",
pytest.param(
"tcp",
marks=pytest.mark.xfail(reason="Multi-worker TCP unstable", strict=False),
),
],
indirect=True,
)
def test_no_request_migration_trtllm_worker_failure( def test_no_request_migration_trtllm_worker_failure(
request, runtime_services, set_ucx_tls_no_mm request, runtime_services, set_ucx_tls_no_mm
): ):
...@@ -262,6 +296,17 @@ def test_no_request_migration_trtllm_worker_failure( ...@@ -262,6 +296,17 @@ def test_no_request_migration_trtllm_worker_failure(
@pytest.mark.skip(reason="TRT-LLM graceful shutdown not yet implemented") @pytest.mark.skip(reason="TRT-LLM graceful shutdown not yet implemented")
@pytest.mark.parametrize(
"request_plane",
[
"nats",
pytest.param(
"tcp",
marks=pytest.mark.xfail(reason="Multi-worker TCP unstable", strict=False),
),
],
indirect=True,
)
def test_no_request_migration_trtllm_graceful_shutdown( def test_no_request_migration_trtllm_graceful_shutdown(
request, runtime_services, set_ucx_tls_no_mm request, runtime_services, set_ucx_tls_no_mm
): ):
......
...@@ -53,8 +53,10 @@ class DynamoWorkerProcess(ManagedProcess): ...@@ -53,8 +53,10 @@ class DynamoWorkerProcess(ManagedProcess):
str(migration_limit), str(migration_limit),
] ]
# Set debug logging environment # Set environment variables
env = os.environ.copy() env = os.environ.copy()
env["DYN_REQUEST_PLANE"] = request.getfixturevalue("request_plane")
env["DYN_VLLM_KV_EVENT_PORT"] = f"2008{worker_id[-1]}" env["DYN_VLLM_KV_EVENT_PORT"] = f"2008{worker_id[-1]}"
env["VLLM_NIXL_SIDE_CHANNEL_PORT"] = f"560{worker_id[-1]}" env["VLLM_NIXL_SIDE_CHANNEL_PORT"] = f"560{worker_id[-1]}"
...@@ -114,6 +116,17 @@ class DynamoWorkerProcess(ManagedProcess): ...@@ -114,6 +116,17 @@ class DynamoWorkerProcess(ManagedProcess):
@pytest.mark.timeout(290) # 3x average @pytest.mark.timeout(290) # 3x average
@pytest.mark.parametrize(
"request_plane",
[
"nats",
pytest.param(
"tcp",
marks=pytest.mark.xfail(reason="Multi-worker TCP unstable", strict=False),
),
],
indirect=True,
)
def test_request_migration_vllm_worker_failure( def test_request_migration_vllm_worker_failure(
request, runtime_services, set_ucx_tls_no_mm request, runtime_services, set_ucx_tls_no_mm
): ):
...@@ -158,6 +171,17 @@ def test_request_migration_vllm_worker_failure( ...@@ -158,6 +171,17 @@ def test_request_migration_vllm_worker_failure(
@pytest.mark.timeout(280) # 3x average @pytest.mark.timeout(280) # 3x average
@pytest.mark.parametrize(
"request_plane",
[
"nats",
pytest.param(
"tcp",
marks=pytest.mark.xfail(reason="Multi-worker TCP unstable", strict=False),
),
],
indirect=True,
)
def test_request_migration_vllm_graceful_shutdown( def test_request_migration_vllm_graceful_shutdown(
request, runtime_services, set_ucx_tls_no_mm request, runtime_services, set_ucx_tls_no_mm
): ):
...@@ -206,6 +230,17 @@ def test_request_migration_vllm_graceful_shutdown( ...@@ -206,6 +230,17 @@ def test_request_migration_vllm_graceful_shutdown(
@pytest.mark.timeout(150) # 3x average @pytest.mark.timeout(150) # 3x average
@pytest.mark.parametrize(
"request_plane",
[
"nats",
pytest.param(
"tcp",
marks=pytest.mark.xfail(reason="Multi-worker TCP unstable", strict=False),
),
],
indirect=True,
)
def test_no_request_migration_vllm_worker_failure( def test_no_request_migration_vllm_worker_failure(
request, runtime_services, set_ucx_tls_no_mm request, runtime_services, set_ucx_tls_no_mm
): ):
...@@ -266,6 +301,17 @@ def test_no_request_migration_vllm_worker_failure( ...@@ -266,6 +301,17 @@ def test_no_request_migration_vllm_worker_failure(
@pytest.mark.timeout(140) # 3x average @pytest.mark.timeout(140) # 3x average
@pytest.mark.parametrize(
"request_plane",
[
"nats",
pytest.param(
"tcp",
marks=pytest.mark.xfail(reason="Multi-worker TCP unstable", strict=False),
),
],
indirect=True,
)
def test_no_request_migration_vllm_graceful_shutdown( def test_no_request_migration_vllm_graceful_shutdown(
request, runtime_services, set_ucx_tls_no_mm request, runtime_services, set_ucx_tls_no_mm
): ):
......
...@@ -23,13 +23,14 @@ class DynamoFrontendProcess(ManagedProcess): ...@@ -23,13 +23,14 @@ class DynamoFrontendProcess(ManagedProcess):
def __init__(self, request): def __init__(self, request):
command = ["python", "-m", "dynamo.frontend", "--router-mode", "round-robin"] command = ["python", "-m", "dynamo.frontend", "--router-mode", "round-robin"]
# Unset DYN_SYSTEM_PORT - frontend doesn't use system metrics server
env = os.environ.copy() env = os.environ.copy()
env["DYN_REQUEST_PLANE"] = request.getfixturevalue("request_plane")
# Disable canary health check - these tests expect full control over requests # Disable canary health check - these tests expect full control over requests
# sent to the workers where canary health check intermittently sends dummy # sent to the workers where canary health check intermittently sends dummy
# requests to workers interfering with the test process which may cause # requests to workers interfering with the test process which may cause
# intermittent failures # intermittent failures
env["DYN_HEALTH_CHECK_ENABLED"] = "false" env["DYN_HEALTH_CHECK_ENABLED"] = "false"
# Unset DYN_SYSTEM_PORT - frontend doesn't use system metrics server
env.pop("DYN_SYSTEM_PORT", None) env.pop("DYN_SYSTEM_PORT", None)
log_dir = f"{request.node.name}_frontend" log_dir = f"{request.node.name}_frontend"
......
...@@ -4,6 +4,7 @@ ...@@ -4,6 +4,7 @@
import asyncio import asyncio
import json import json
import logging import logging
import os
import random import random
import string import string
import time import time
...@@ -38,6 +39,7 @@ class KVRouterProcess(ManagedProcess): ...@@ -38,6 +39,7 @@ class KVRouterProcess(ManagedProcess):
store_backend: str = "etcd", store_backend: str = "etcd",
enforce_disagg: bool = False, enforce_disagg: bool = False,
busy_threshold: float | None = None, busy_threshold: float | None = None,
request_plane: str = "nats",
): ):
command = [ command = [
"python3", "python3",
...@@ -61,8 +63,12 @@ class KVRouterProcess(ManagedProcess): ...@@ -61,8 +63,12 @@ class KVRouterProcess(ManagedProcess):
if busy_threshold is not None: if busy_threshold is not None:
command.extend(["--busy-threshold", str(busy_threshold)]) command.extend(["--busy-threshold", str(busy_threshold)])
env = os.environ.copy()
env["DYN_REQUEST_PLANE"] = request_plane
super().__init__( super().__init__(
command=command, command=command,
env=env,
timeout=60, timeout=60,
display_output=True, display_output=True,
health_check_ports=[frontend_port], health_check_ports=[frontend_port],
...@@ -1980,6 +1986,7 @@ def _test_busy_threshold_endpoint( ...@@ -1980,6 +1986,7 @@ def _test_busy_threshold_endpoint(
frontend_port: int, frontend_port: int,
test_payload: dict, test_payload: dict,
store_backend: str = "etcd", store_backend: str = "etcd",
request_plane: str = "nats",
): ):
"""Test that the /busy_threshold endpoint can be hit and responds correctly. """Test that the /busy_threshold endpoint can be hit and responds correctly.
...@@ -1997,6 +2004,7 @@ def _test_busy_threshold_endpoint( ...@@ -1997,6 +2004,7 @@ def _test_busy_threshold_endpoint(
frontend_port: Port for the frontend HTTP server frontend_port: Port for the frontend HTTP server
test_payload: Base test payload (used to extract model name) test_payload: Base test payload (used to extract model name)
store_backend: Storage backend to use ("etcd" or "file"). Defaults to "etcd". store_backend: Storage backend to use ("etcd" or "file"). Defaults to "etcd".
request_plane: Request plane to use ("nats" or "tcp"). Defaults to "nats".
Raises: Raises:
AssertionError: If endpoint responses are incorrect AssertionError: If endpoint responses are incorrect
...@@ -2014,6 +2022,7 @@ def _test_busy_threshold_endpoint( ...@@ -2014,6 +2022,7 @@ def _test_busy_threshold_endpoint(
engine_workers.namespace, engine_workers.namespace,
store_backend, store_backend,
busy_threshold=initial_threshold, busy_threshold=initial_threshold,
request_plane=request_plane,
) )
kv_router.__enter__() kv_router.__enter__()
......
...@@ -40,19 +40,23 @@ BLOCK_SIZE = 16 ...@@ -40,19 +40,23 @@ BLOCK_SIZE = 16
def get_unique_ports( def get_unique_ports(
request, num_ports: int = 1, store_backend: str = "etcd" request,
num_ports: int = 1,
store_backend: str = "etcd",
request_plane: str = "nats",
) -> list[int]: ) -> list[int]:
"""Generate unique ports for parallel test execution. """Generate unique ports for parallel test execution.
Ports are unique based on: Ports are unique based on:
- Test function name (each test gets a base offset) - Test function name (each test gets a base offset)
- Parametrization value (etcd=0, file=50) - Parametrization value (etcd=0, file=50; nats=0, tcp=25)
- Port index (for multi-port tests) - Port index (for multi-port tests)
Args: Args:
request: Pytest request fixture request: Pytest request fixture
num_ports: Number of ports needed (1 for single router, 2 for two routers) num_ports: Number of ports needed (1 for single router, 2 for two routers)
store_backend: Storage backend parameter ("etcd" or "file") store_backend: Storage backend parameter ("etcd" or "file")
request_plane: Request plane parameter ("nats" or "tcp")
Returns: Returns:
List of unique port numbers List of unique port numbers
...@@ -72,11 +76,15 @@ def get_unique_ports( ...@@ -72,11 +76,15 @@ def get_unique_ports(
base_offset = test_offsets.get(test_name, 0) base_offset = test_offsets.get(test_name, 0)
# Parametrization offset (etcd=0, file=50) # Parametrization offset (etcd=0, file=50; nats=0, tcp=25)
param_offset = 0 if store_backend == "etcd" else 50 store_offset = 0 if store_backend == "etcd" else 50
plane_offset = 0 if request_plane == "nats" else 25
# Generate ports # Generate ports
ports = [BASE_PORT + base_offset + param_offset + i for i in range(num_ports)] ports = [
BASE_PORT + base_offset + store_offset + plane_offset + i
for i in range(num_ports)
]
return ports return ports
...@@ -175,6 +183,7 @@ class MockerProcess: ...@@ -175,6 +183,7 @@ class MockerProcess:
mocker_args: Optional[Dict[str, Any]] = None, mocker_args: Optional[Dict[str, Any]] = None,
num_mockers: int = 1, num_mockers: int = 1,
store_backend: str = "etcd", store_backend: str = "etcd",
request_plane: str = "nats",
): ):
namespace_suffix = generate_random_suffix() namespace_suffix = generate_random_suffix()
self.namespace = f"test-namespace-{namespace_suffix}" self.namespace = f"test-namespace-{namespace_suffix}"
...@@ -191,8 +200,12 @@ class MockerProcess: ...@@ -191,8 +200,12 @@ class MockerProcess:
mocker_args=mocker_args, mocker_args=mocker_args,
) )
env = os.environ.copy()
env["DYN_REQUEST_PLANE"] = request_plane
self._process = ManagedProcess( self._process = ManagedProcess(
command=command, command=command,
env=env,
timeout=60, timeout=60,
display_output=True, display_output=True,
health_check_ports=[], health_check_ports=[],
...@@ -649,8 +662,9 @@ def test_router_decisions_disagg( ...@@ -649,8 +662,9 @@ def test_router_decisions_disagg(
@pytest.mark.parallel @pytest.mark.parallel
@pytest.mark.parametrize("request_plane", ["nats", "tcp"], indirect=True)
def test_busy_threshold_endpoint( def test_busy_threshold_endpoint(
request, runtime_services_session, predownload_tokenizers request, runtime_services_session, predownload_tokenizers, request_plane
): ):
"""Test that the /busy_threshold endpoint can be hit and responds correctly. """Test that the /busy_threshold endpoint can be hit and responds correctly.
...@@ -661,19 +675,26 @@ def test_busy_threshold_endpoint( ...@@ -661,19 +675,26 @@ def test_busy_threshold_endpoint(
For now, this test only verifies the endpoint is accessible and returns valid responses. For now, this test only verifies the endpoint is accessible and returns valid responses.
""" """
logger.info("Starting busy_threshold endpoint test") logger.info(
f"Starting busy_threshold endpoint test with request_plane={request_plane}"
)
mocker_args = {"speedup_ratio": SPEEDUP_RATIO, "block_size": BLOCK_SIZE} mocker_args = {"speedup_ratio": SPEEDUP_RATIO, "block_size": BLOCK_SIZE}
try: try:
logger.info(f"Starting {NUM_MOCKERS} mocker instances") logger.info(f"Starting {NUM_MOCKERS} mocker instances")
mockers = MockerProcess( mockers = MockerProcess(
request, mocker_args=mocker_args, num_mockers=NUM_MOCKERS request,
mocker_args=mocker_args,
num_mockers=NUM_MOCKERS,
request_plane=request_plane,
) )
logger.info(f"All mockers using endpoint: {mockers.endpoint}") logger.info(f"All mockers using endpoint: {mockers.endpoint}")
mockers.__enter__() mockers.__enter__()
frontend_port = get_unique_ports(request, num_ports=1)[0] frontend_port = get_unique_ports(
request, num_ports=1, request_plane=request_plane
)[0]
_test_busy_threshold_endpoint( _test_busy_threshold_endpoint(
engine_workers=mockers, engine_workers=mockers,
...@@ -681,6 +702,7 @@ def test_busy_threshold_endpoint( ...@@ -681,6 +702,7 @@ def test_busy_threshold_endpoint(
request=request, request=request,
frontend_port=frontend_port, frontend_port=frontend_port,
test_payload=TEST_PAYLOAD, test_payload=TEST_PAYLOAD,
request_plane=request_plane,
) )
finally: finally:
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment