Unverified Commit cdcf6d0a authored by Jacky, committed by GitHub
Browse files

test: Rewrite Request Migration Tests and Add Disagg Scenarios (#5448)


Signed-off-by: Jacky <18255193+kthui@users.noreply.github.com>
parent edf847e5
...@@ -2,12 +2,11 @@ ...@@ -2,12 +2,11 @@
# SPDX-License-Identifier: Apache-2.0 # SPDX-License-Identifier: Apache-2.0
""" """
Test Execution Times (Last Run: 2025-12-09): Test Execution Times (Last Run: 2026-01-13):
- test_request_migration_sglang_worker_failure: ~58s (gpu_1) - test_request_migration_sglang_aggregated: ~75s
- test_request_migration_sglang_graceful_shutdown: ~58s (gpu_1, skipped) - test_request_migration_sglang_prefill: N/A
- test_no_request_migration_sglang_worker_failure: ~38s (gpu_1) - test_request_migration_sglang_kv_transfer: N/A
- test_no_request_migration_sglang_graceful_shutdown: ~38s (gpu_1, skipped) - test_request_migration_sglang_decode: ~75s
- Total: 115.71s (0:01:55) for enabled tests
""" """
import logging import logging
...@@ -17,19 +16,12 @@ import shutil ...@@ -17,19 +16,12 @@ import shutil
import pytest import pytest
from tests.utils.constants import FAULT_TOLERANCE_MODEL_NAME from tests.utils.constants import FAULT_TOLERANCE_MODEL_NAME
from tests.utils.managed_process import ManagedProcess, terminate_process_tree from tests.utils.managed_process import ManagedProcess
from tests.utils.payloads import check_models_api from tests.utils.payloads import check_models_api
from tests.utils.port_utils import allocate_port, deallocate_port from tests.utils.port_utils import allocate_port, deallocate_port
# Import utilities from the refactored utils module # Customized utils for migration tests
from .utils import ( from .utils import DynamoFrontendProcess, run_migration_test
DynamoFrontendProcess,
determine_request_receiving_worker,
start_completion_request,
validate_completion_response,
verify_migration_metrics,
verify_migration_occurred,
)
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
...@@ -39,23 +31,77 @@ pytestmark = [ ...@@ -39,23 +31,77 @@ pytestmark = [
pytest.mark.e2e, pytest.mark.e2e,
pytest.mark.model(FAULT_TOLERANCE_MODEL_NAME), pytest.mark.model(FAULT_TOLERANCE_MODEL_NAME),
pytest.mark.post_merge, # post_merge to pinpoint failure commit pytest.mark.post_merge, # post_merge to pinpoint failure commit
pytest.mark.parametrize(
"migration_limit", [3, 0], ids=["migration_enabled", "migration_disabled"]
),
pytest.mark.parametrize(
"immediate_kill",
[
pytest.param(True, id="worker_failure"),
pytest.param(
False,
id="graceful_shutdown",
marks=pytest.mark.xfail(
strict=False, reason="SGLang graceful shutdown not yet implemented"
),
),
],
),
pytest.mark.parametrize(
"request_api",
[
pytest.param("chat"),
pytest.param(
"completion",
marks=pytest.mark.skip(reason="Behavior unverified yet"),
),
],
),
pytest.mark.parametrize(
"stream",
[
pytest.param(True, id="stream"),
pytest.param(
False,
id="unary",
marks=pytest.mark.skip(reason="Behavior unverified yet"),
),
],
),
pytest.mark.parametrize("request_plane", ["nats", "tcp"], indirect=True), pytest.mark.parametrize("request_plane", ["nats", "tcp"], indirect=True),
] ]
class DynamoWorkerProcess(ManagedProcess): class DynamoWorkerProcess(ManagedProcess):
"""Process manager for Dynamo worker with SGLang backend""" """Process manager for Dynamo worker with SGLang backend
Supports both aggregated mode (single worker) and disaggregated mode
(separate prefill and decode workers).
Args:
request: pytest request fixture
worker_id: Unique identifier for the worker (e.g., "worker1", "worker2")
frontend_port: Port where the frontend is running
migration_limit: Maximum number of migration attempts (default: 3)
disagg_mode: None for aggregated, "prefill" or "decode" for disaggregated
"""
def __init__( def __init__(
self, self,
request, request,
worker_id: str, worker_id: str,
system_port: int,
frontend_port: int, frontend_port: int,
migration_limit: int = 3, migration_limit: int = 3,
disagg_mode: str | None = None,
): ):
self.worker_id = worker_id self.worker_id = worker_id
self.system_port = system_port self.system_port = allocate_port(9100)
self.disagg_mode = disagg_mode
# Prefill workers require migration_limit=0 (no KV cache migration support)
if disagg_mode == "prefill":
logging.info("Prefill worker - setting migration_limit to 0")
migration_limit = 0
command = [ command = [
"python3", "python3",
...@@ -66,18 +112,41 @@ class DynamoWorkerProcess(ManagedProcess): ...@@ -66,18 +112,41 @@ class DynamoWorkerProcess(ManagedProcess):
"--served-model-name", "--served-model-name",
FAULT_TOLERANCE_MODEL_NAME, FAULT_TOLERANCE_MODEL_NAME,
"--trust-remote-code", "--trust-remote-code",
"--skip-tokenizer-init", "--page-size",
"16",
"--tp",
"1",
"--mem-fraction-static", "--mem-fraction-static",
"0.45", "0.3",
"--context-length", "--context-length",
"8192", "8192",
"--migration-limit", "--migration-limit",
str(migration_limit), str(migration_limit),
] ]
if disagg_mode is None:
# Aggregated
command.append("--skip-tokenizer-init")
else:
# Disaggregated
command.extend(
[
"--disaggregation-mode",
disagg_mode,
"--disaggregation-bootstrap-port",
f"1234{worker_id[-1]}",
"--host",
"0.0.0.0",
"--disaggregation-transfer-backend",
"nixl",
]
)
if disagg_mode == "prefill":
command.extend(["--port", "40000"])
# Set environment variables # Set environment variables
env = os.environ.copy() env = os.environ.copy()
env["DYN_REQUEST_PLANE"] = request.getfixturevalue("request_plane") env["DYN_REQUEST_PLANE"] = request.getfixturevalue("request_plane")
env["DYN_LOG"] = "debug" env["DYN_LOG"] = "debug"
# Disable canary health check - these tests expect full control over requests # Disable canary health check - these tests expect full control over requests
# sent to the workers where canary health check intermittently sends dummy # sent to the workers where canary health check intermittently sends dummy
...@@ -85,9 +154,18 @@ class DynamoWorkerProcess(ManagedProcess): ...@@ -85,9 +154,18 @@ class DynamoWorkerProcess(ManagedProcess):
# intermittent failures # intermittent failures
env["DYN_HEALTH_CHECK_ENABLED"] = "false" env["DYN_HEALTH_CHECK_ENABLED"] = "false"
env["DYN_SYSTEM_USE_ENDPOINT_HEALTH_STATUS"] = '["generate"]' env["DYN_SYSTEM_USE_ENDPOINT_HEALTH_STATUS"] = '["generate"]'
env["DYN_SYSTEM_PORT"] = str(system_port) env["DYN_SYSTEM_PORT"] = str(self.system_port)
env["DYN_HTTP_PORT"] = str(frontend_port) env["DYN_HTTP_PORT"] = str(frontend_port)
# Configure health check based on worker type
health_check_urls = [
(f"http://localhost:{self.system_port}/health", self.is_ready)
]
if disagg_mode is None or disagg_mode == "decode":
health_check_urls.append(
(f"http://localhost:{frontend_port}/v1/models", check_models_api)
)
# TODO: Have the managed process take a command name explicitly to distinguish # TODO: Have the managed process take a command name explicitly to distinguish
# between processes started with the same command. # between processes started with the same command.
log_dir = f"{request.node.name}_{worker_id}" log_dir = f"{request.node.name}_{worker_id}"
...@@ -103,10 +181,7 @@ class DynamoWorkerProcess(ManagedProcess): ...@@ -103,10 +181,7 @@ class DynamoWorkerProcess(ManagedProcess):
super().__init__( super().__init__(
command=command, command=command,
env=env, env=env,
health_check_urls=[ health_check_urls=health_check_urls,
(f"http://localhost:{frontend_port}/v1/models", check_models_api),
(f"http://localhost:{system_port}/health", self.is_ready),
],
timeout=300, timeout=300,
display_output=True, display_output=True,
terminate_existing=False, terminate_existing=False,
...@@ -140,316 +215,270 @@ class DynamoWorkerProcess(ManagedProcess): ...@@ -140,316 +215,270 @@ class DynamoWorkerProcess(ManagedProcess):
return False return False
@pytest.mark.timeout(235) # 3x average @pytest.mark.timeout(230) # 3x average
def test_request_migration_sglang_worker_failure( def test_request_migration_sglang_aggregated(
request, runtime_services_dynamic_ports, set_ucx_tls_no_mm, predownload_models request,
runtime_services_dynamic_ports,
set_ucx_tls_no_mm,
predownload_models,
migration_limit,
immediate_kill,
request_api,
stream,
): ):
""" """
End-to-end test for worker fault tolerance with migration support using SGLang. End-to-end test for aggregated worker request migration.
This test verifies that when a worker is killed during request processing,
the system can handle the failure gracefully and migrate the request to
another worker.
Timing (Last Run: 2025-12-09): ~58s total Parameters:
- Engine initialization: ~22s (Worker1: 12s, Worker2: 10s) immediate_kill: True for abrupt kill (SIGKILL), False for graceful shutdown (SIGTERM)
- Test execution (request + migration): ~21s migration_limit: > 0 to verify migration succeeds, 0 to verify request fails
- Teardown: ~15s request_api: "chat" for chat completion API, "completion" for completion API
stream: True for streaming, False for non-streaming
""" """
# Allocate ports to avoid conflicts with parallel tests # Step 1: Start the frontend
worker1_system_port = allocate_port(9100)
worker2_system_port = allocate_port(9200)
# Step 1: Start the frontend (allocates its own port)
with DynamoFrontendProcess(request) as frontend: with DynamoFrontendProcess(request) as frontend:
logger.info("Frontend started successfully") logger.info("Frontend started successfully")
# Step 2: Start 2 workers sequentially # Step 2: Start 2 workers
with DynamoWorkerProcess( with DynamoWorkerProcess(
request, request, "worker1", frontend.frontend_port, migration_limit=migration_limit
"worker1",
system_port=worker1_system_port,
frontend_port=frontend.frontend_port,
) as worker1: ) as worker1:
logger.info(f"Worker 1 PID: {worker1.get_pid()}") logger.info(f"Worker 1 PID: {worker1.get_pid()}")
with DynamoWorkerProcess( with DynamoWorkerProcess(
request, request,
"worker2", "worker2",
system_port=worker2_system_port, frontend.frontend_port,
frontend_port=frontend.frontend_port, migration_limit=migration_limit,
) as worker2: ) as worker2:
logger.info(f"Worker 2 PID: {worker2.get_pid()}") logger.info(f"Worker 2 PID: {worker2.get_pid()}")
# Step 3: Send the request # Step 3: Run migration test
request_thread, response_list = start_completion_request( run_migration_test(
frontend.frontend_port frontend,
) worker1,
worker2,
# Step 4: Use polling to determine which worker received the request receiving_pattern="New Request ID: ",
worker, worker_name = determine_request_receiving_worker( migration_limit=migration_limit,
worker1, worker2, receiving_pattern="New Request ID: " immediate_kill=immediate_kill,
) use_chat_completion=(request_api == "chat"),
stream=stream,
# Step 5: Kill the worker that has the request
logger.info(
f"Killing {worker_name} with PID {worker.get_pid()} processing the request"
)
terminate_process_tree(worker.get_pid(), immediate_kill=True, timeout=0)
# Step 6: Validate the completion response
validate_completion_response(request_thread, response_list)
# Step 7: Verify migration occurred
verify_migration_occurred(frontend)
# Step 8: Verify migration metrics
verify_migration_metrics(
frontend.frontend_port, expected_ongoing_request_count=1
) )
@pytest.mark.timeout(235) # 3x average @pytest.mark.skip(reason="Cannot reliably migrate at Prefill that finish < 1 ms")
@pytest.mark.skip(reason="SGLang graceful shutdown not yet implemented") @pytest.mark.xfail(strict=False, reason="Prefill migration not yet supported")
def test_request_migration_sglang_graceful_shutdown( @pytest.mark.timeout(230) # 3x average
request, runtime_services_dynamic_ports, set_ucx_tls_no_mm, predownload_models def test_request_migration_sglang_prefill(
request,
runtime_services_dynamic_ports,
set_ucx_tls_no_mm,
predownload_models,
migration_limit,
immediate_kill,
request_api,
stream,
): ):
""" """
End-to-end test for worker fault tolerance with graceful shutdown and migration support using SGLang. End-to-end test for prefill worker request migration in disaggregated mode.
This test verifies that when a worker receives a graceful shutdown signal (SIGTERM)
during request processing, the system can handle the shutdown gracefully and migrate
the request to another worker. Unlike the abrupt kill test, this simulates a more
controlled shutdown scenario where the worker has time to clean up and notify the
system about its shutdown.
Timing (Last Run: 2025-12-09): ~58s total (estimated, similar to worker_failure)
- Engine initialization: ~22s (Worker1: 12s, Worker2: 10s)
- Test execution (request + graceful shutdown + migration): ~21s
- Teardown: ~15s
"""
# Allocate ports to avoid conflicts with parallel tests Setup: 1 decode worker + 2 prefill workers
worker1_system_port = allocate_port(9100)
worker2_system_port = allocate_port(9200)
# Step 1: Start the frontend (allocates its own port) Parameters:
with DynamoFrontendProcess(request) as frontend: immediate_kill: True for abrupt kill (SIGKILL), False for graceful shutdown (SIGTERM)
migration_limit: > 0 to verify migration succeeds, 0 to verify request fails
request_api: "chat" for chat completion API, "completion" for completion API
stream: True for streaming, False for non-streaming
"""
# Step 1: Start the frontend
with DynamoFrontendProcess(request, enforce_disagg=True) as frontend:
logger.info("Frontend started successfully") logger.info("Frontend started successfully")
# Step 2: Start 2 workers sequentially # Step 2: Start decode worker first (required for prefill workers to connect)
with DynamoWorkerProcess( with DynamoWorkerProcess(
request, request,
"worker1", "worker0",
system_port=worker1_system_port, frontend.frontend_port,
frontend_port=frontend.frontend_port, migration_limit=migration_limit,
) as worker1: disagg_mode="decode",
logger.info(f"Worker 1 PID: {worker1.get_pid()}") ) as decode_worker:
logger.info(f"Decode Worker PID: {decode_worker.get_pid()}")
# Step 3: Start 2 prefill workers
with DynamoWorkerProcess( with DynamoWorkerProcess(
request, request,
"worker2", "worker1",
system_port=worker2_system_port, frontend.frontend_port,
frontend_port=frontend.frontend_port, migration_limit=migration_limit,
) as worker2: disagg_mode="prefill",
logger.info(f"Worker 2 PID: {worker2.get_pid()}") ) as prefill1:
logger.info(f"Prefill Worker 1 PID: {prefill1.get_pid()}")
# Step 3: Send the request
request_thread, response_list = start_completion_request( with DynamoWorkerProcess(
frontend.frontend_port request,
) "worker2",
frontend.frontend_port,
# Step 4: Use polling to determine which worker received the request migration_limit=migration_limit,
worker, worker_name = determine_request_receiving_worker( disagg_mode="prefill",
worker1, worker2, receiving_pattern="New Request ID: " ) as prefill2:
) logger.info(f"Prefill Worker 2 PID: {prefill2.get_pid()}")
# Step 5: Gracefully shutdown the worker that has the request # Step 4: Run migration test
logger.info( run_migration_test(
f"Gracefully shutting down {worker_name} with PID {worker.get_pid()} processing the request" frontend,
) prefill1,
terminate_process_tree( prefill2,
worker.get_pid(), immediate_kill=False, timeout=10 receiving_pattern="New Request ID: ",
) migration_limit=migration_limit,
immediate_kill=immediate_kill,
# Step 6: Validate the completion response use_chat_completion=(request_api == "chat"),
validate_completion_response(request_thread, response_list) stream=stream,
use_long_prompt=True,
# Step 7: Verify migration occurred during graceful shutdown )
verify_migration_occurred(frontend)
# Step 8: Verify migration metrics
verify_migration_metrics(
frontend.frontend_port, expected_ongoing_request_count=1
)
@pytest.mark.timeout(135) # 3x average @pytest.mark.skip(reason="KV cache transfer may fail")
def test_no_request_migration_sglang_worker_failure( @pytest.mark.timeout(230) # 3x average
request, runtime_services_dynamic_ports, set_ucx_tls_no_mm, predownload_models def test_request_migration_sglang_kv_transfer(
request,
runtime_services_dynamic_ports,
set_ucx_tls_no_mm,
predownload_models,
migration_limit,
immediate_kill,
request_api,
stream,
): ):
""" """
End-to-end test for worker fault tolerance with migration disabled using SGLang. End-to-end test for request migration during KV transfer in disaggregated mode.
This test verifies that when migration is disabled (migration_limit=0) and a worker Setup: 1 prefill worker + 2 decode workers
is killed during request processing, the request fails as expected without migration.
This is the opposite behavior of test_request_migration_sglang_worker_failure.
Timing (Last Run: 2025-12-09): ~38s total Parameters:
- Engine initialization: ~23s (Worker1: 13s, Worker2: 10s) immediate_kill: True for abrupt kill (SIGKILL), False for graceful shutdown (SIGTERM)
- Test execution (failure validation): <1s migration_limit: > 0 to verify migration succeeds, 0 to verify request fails
- Teardown: ~15s request_api: "chat" for chat completion API, "completion" for completion API
stream: True for streaming, False for non-streaming
""" """
# Allocate ports to avoid conflicts with parallel tests # Step 1: Start the frontend
worker1_system_port = allocate_port(9100) with DynamoFrontendProcess(request, enforce_disagg=True) as frontend:
worker2_system_port = allocate_port(9200)
# Step 1: Start the frontend (allocates its own port)
with DynamoFrontendProcess(request) as frontend:
logger.info("Frontend started successfully") logger.info("Frontend started successfully")
# Step 2: Start 2 workers sequentially with migration disabled # Step 2: Start prefill worker first
with DynamoWorkerProcess( with DynamoWorkerProcess(
request, request,
"worker1", "worker0",
system_port=worker1_system_port, frontend.frontend_port,
frontend_port=frontend.frontend_port, migration_limit=migration_limit,
migration_limit=0, disagg_mode="prefill",
) as worker1: ) as prefill_worker:
logger.info(f"Worker 1 PID: {worker1.get_pid()}") logger.info(f"Prefill Worker PID: {prefill_worker.get_pid()}")
# Step 3: Start 2 decode workers
with DynamoWorkerProcess( with DynamoWorkerProcess(
request, request,
"worker2", "worker1",
system_port=worker2_system_port, frontend.frontend_port,
frontend_port=frontend.frontend_port, migration_limit=migration_limit,
migration_limit=0, disagg_mode="decode",
) as worker2: ) as decode1:
logger.info(f"Worker 2 PID: {worker2.get_pid()}") logger.info(f"Decode Worker 1 PID: {decode1.get_pid()}")
# Step 3: Send the request with DynamoWorkerProcess(
request_thread, response_list = start_completion_request( request,
frontend.frontend_port "worker2",
) frontend.frontend_port,
migration_limit=migration_limit,
# Step 4: Use polling to determine which worker received the request disagg_mode="decode",
worker, worker_name = determine_request_receiving_worker( ) as decode2:
worker1, worker2, receiving_pattern="New Request ID: " logger.info(f"Decode Worker 2 PID: {decode2.get_pid()}")
)
# Step 4: Run migration test
# Step 5: Kill the worker that has the request run_migration_test(
logger.info( frontend,
f"Killing {worker_name} with PID {worker.get_pid()} processing the request" decode1,
) decode2,
terminate_process_tree(worker.get_pid(), immediate_kill=True, timeout=0) receiving_pattern="New Request ID: ",
migration_limit=migration_limit,
# Step 6: Validate the completion response - should fail without migration immediate_kill=immediate_kill,
try: use_chat_completion=(request_api == "chat"),
validate_completion_response(request_thread, response_list) stream=stream,
pytest.fail( use_long_prompt=True,
"Request succeeded unexpectedly when migration was disabled"
) )
except AssertionError as e:
assert "Request failed with status 500: " in str(
e
), f"Unexpected request error message: {e}"
# Step 7: Verify migration did NOT occur - should fail
try:
verify_migration_occurred(frontend)
pytest.fail(
"Migration verification unexpectedly passed when migration was disabled"
)
except AssertionError as e:
assert "'Cannot recreate stream: ...' error found in logs" in str(
e
), f"Unexpected migration message: {e}"
@pytest.mark.timeout(135) # 3x average @pytest.mark.timeout(230) # 3x average
@pytest.mark.skip(reason="SGLang graceful shutdown not yet implemented") def test_request_migration_sglang_decode(
def test_no_request_migration_sglang_graceful_shutdown( request,
request, runtime_services_dynamic_ports, set_ucx_tls_no_mm, predownload_models runtime_services_dynamic_ports,
set_ucx_tls_no_mm,
predownload_models,
migration_limit,
immediate_kill,
request_api,
stream,
): ):
""" """
End-to-end test for worker fault tolerance with graceful shutdown and migration disabled using SGLang. End-to-end test for decode worker request migration in disaggregated mode.
This test verifies that when migration is disabled (migration_limit=0) and a worker Setup: 1 prefill worker + 2 decode workers
receives a graceful shutdown signal (SIGTERM) during request processing, the request
fails as expected without migration. This is the opposite behavior of
test_request_migration_sglang_graceful_shutdown.
Timing (Last Run: 2025-12-09): ~38s total (estimated, similar to no_migration_worker_failure) Parameters:
- Engine initialization: ~23s (Worker1: 13s, Worker2: 10s) immediate_kill: True for abrupt kill (SIGKILL), False for graceful shutdown (SIGTERM)
- Test execution (graceful shutdown + failure validation): <1s migration_limit: > 0 to verify migration succeeds, 0 to verify request fails
- Teardown: ~15s request_api: "chat" for chat completion API, "completion" for completion API
stream: True for streaming, False for non-streaming
""" """
if not stream:
pytest.skip(
"Decode test requires streaming to wait for response before stopping worker"
)
# Allocate ports to avoid conflicts with parallel tests # Step 1: Start the frontend
worker1_system_port = allocate_port(9100) with DynamoFrontendProcess(request, enforce_disagg=True) as frontend:
worker2_system_port = allocate_port(9200)
# Step 1: Start the frontend (allocates its own port)
with DynamoFrontendProcess(request) as frontend:
logger.info("Frontend started successfully") logger.info("Frontend started successfully")
# Step 2: Start 2 workers sequentially with migration disabled # Step 2: Start prefill worker first
with DynamoWorkerProcess( with DynamoWorkerProcess(
request, request,
"worker1", "worker0",
system_port=worker1_system_port, frontend.frontend_port,
frontend_port=frontend.frontend_port, migration_limit=migration_limit,
migration_limit=0, disagg_mode="prefill",
) as worker1: ) as prefill_worker:
logger.info(f"Worker 1 PID: {worker1.get_pid()}") logger.info(f"Prefill Worker PID: {prefill_worker.get_pid()}")
# Step 3: Start 2 decode workers
with DynamoWorkerProcess( with DynamoWorkerProcess(
request, request,
"worker2", "worker1",
system_port=worker2_system_port, frontend.frontend_port,
frontend_port=frontend.frontend_port, migration_limit=migration_limit,
migration_limit=0, disagg_mode="decode",
) as worker2: ) as decode1:
logger.info(f"Worker 2 PID: {worker2.get_pid()}") logger.info(f"Decode Worker 1 PID: {decode1.get_pid()}")
# Step 3: Send the request with DynamoWorkerProcess(
request_thread, response_list = start_completion_request( request,
frontend.frontend_port "worker2",
) frontend.frontend_port,
migration_limit=migration_limit,
# Step 4: Use polling to determine which worker received the request disagg_mode="decode",
worker, worker_name = determine_request_receiving_worker( ) as decode2:
worker1, worker2, receiving_pattern="New Request ID: " logger.info(f"Decode Worker 2 PID: {decode2.get_pid()}")
)
# Step 4: Run migration test
# Step 5: Gracefully shutdown the worker that has the request run_migration_test(
logger.info( frontend,
f"Gracefully shutting down {worker_name} with PID {worker.get_pid()} processing the request" decode1,
) decode2,
terminate_process_tree( receiving_pattern="New Request ID: ",
worker.get_pid(), immediate_kill=False, timeout=10 migration_limit=migration_limit,
) immediate_kill=immediate_kill,
use_chat_completion=(request_api == "chat"),
# Step 6: Validate the completion response - should fail without migration stream=stream,
try: wait_for_new_response_before_stop=True,
validate_completion_response(request_thread, response_list)
pytest.fail(
"Request succeeded unexpectedly when migration was disabled"
)
except AssertionError as e:
assert "Request failed with status 500: " in str(
e
), f"Unexpected request error message: {e}"
# Step 7: Verify migration did NOT occur - should fail
try:
verify_migration_occurred(frontend)
pytest.fail(
"Migration verification unexpectedly passed when migration was disabled"
) )
except AssertionError as e:
assert "'Cannot recreate stream: ...' error found in logs" in str(
e
), f"Unexpected migration message: {e}"
...@@ -2,12 +2,11 @@ ...@@ -2,12 +2,11 @@
# SPDX-License-Identifier: Apache-2.0 # SPDX-License-Identifier: Apache-2.0
""" """
Test Execution Times (Last Run: 2025-12-09): Test Execution Times (Last Run: 2026-01-12):
- test_request_migration_trtllm_worker_failure: ~95s (gpu_1) - test_request_migration_trtllm_aggregated: ~95s
- test_request_migration_trtllm_graceful_shutdown: ~95s (gpu_1, skipped) - test_request_migration_trtllm_prefill: N/A
- test_no_request_migration_trtllm_worker_failure: ~60s (gpu_1) - test_request_migration_trtllm_kv_transfer: N/A
- test_no_request_migration_trtllm_graceful_shutdown: ~60s (gpu_1, skipped) - test_request_migration_trtllm_decode: N/A
- Total: ~155s (0:02:35) for enabled tests
""" """
import logging import logging
...@@ -17,19 +16,12 @@ import shutil ...@@ -17,19 +16,12 @@ import shutil
import pytest import pytest
from tests.utils.constants import FAULT_TOLERANCE_MODEL_NAME from tests.utils.constants import FAULT_TOLERANCE_MODEL_NAME
from tests.utils.managed_process import ManagedProcess, terminate_process_tree from tests.utils.managed_process import ManagedProcess
from tests.utils.payloads import check_models_api from tests.utils.payloads import check_models_api
from tests.utils.port_utils import allocate_port, deallocate_port from tests.utils.port_utils import allocate_port, deallocate_port
# Import utilities from the refactored utils module # Customized utils for migration tests
from .utils import ( from .utils import DynamoFrontendProcess, run_migration_test
DynamoFrontendProcess,
determine_request_receiving_worker,
start_completion_request,
validate_completion_response,
verify_migration_metrics,
verify_migration_occurred,
)
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
...@@ -39,12 +31,60 @@ pytestmark = [ ...@@ -39,12 +31,60 @@ pytestmark = [
pytest.mark.e2e, pytest.mark.e2e,
pytest.mark.model(FAULT_TOLERANCE_MODEL_NAME), pytest.mark.model(FAULT_TOLERANCE_MODEL_NAME),
pytest.mark.post_merge, # post_merge to pinpoint failure commit pytest.mark.post_merge, # post_merge to pinpoint failure commit
pytest.mark.parametrize(
"migration_limit", [3, 0], ids=["migration_enabled", "migration_disabled"]
),
pytest.mark.parametrize(
"immediate_kill",
[
pytest.param(True, id="worker_failure"),
pytest.param(
False,
id="graceful_shutdown",
marks=pytest.mark.xfail(
strict=False, reason="TRT-LLM graceful shutdown not yet implemented"
),
),
],
),
pytest.mark.parametrize(
"request_api",
[
pytest.param("chat"),
pytest.param(
"completion",
marks=pytest.mark.skip(reason="Behavior unverified yet"),
),
],
),
pytest.mark.parametrize(
"stream",
[
pytest.param(True, id="stream"),
pytest.param(
False,
id="unary",
marks=pytest.mark.skip(reason="Behavior unverified yet"),
),
],
),
pytest.mark.parametrize("request_plane", ["nats", "tcp"], indirect=True), pytest.mark.parametrize("request_plane", ["nats", "tcp"], indirect=True),
] ]
class DynamoWorkerProcess(ManagedProcess): class DynamoWorkerProcess(ManagedProcess):
"""Process manager for Dynamo worker with TRT-LLM backend""" """Process manager for Dynamo worker with TRT-LLM backend
Supports both aggregated mode (single worker) and disaggregated mode
(separate prefill and decode workers).
Args:
request: pytest request fixture
worker_id: Unique identifier for the worker (e.g., "worker1", "prefill1")
frontend_port: Port where the frontend is running
migration_limit: Maximum number of migration attempts (default: 3)
mode: "prefill_and_decode" for aggregated, "prefill" or "decode" for disaggregated
"""
def __init__( def __init__(
self, self,
...@@ -52,13 +92,16 @@ class DynamoWorkerProcess(ManagedProcess): ...@@ -52,13 +92,16 @@ class DynamoWorkerProcess(ManagedProcess):
worker_id: str, worker_id: str,
frontend_port: int, frontend_port: int,
migration_limit: int = 3, migration_limit: int = 3,
mode: str = "prefill_and_decode",
): ):
self.worker_id = worker_id self.worker_id = worker_id
self.frontend_port = frontend_port self.system_port = allocate_port(9100)
self.mode = mode
# Allocate system port for this worker # Prefill workers require migration_limit=0 (no KV cache migration support)
system_port = allocate_port(9100) if mode == "prefill":
self.system_port = system_port logging.info("Prefill worker - setting migration_limit to 0")
migration_limit = 0
command = [ command = [
"python3", "python3",
...@@ -67,18 +110,32 @@ class DynamoWorkerProcess(ManagedProcess): ...@@ -67,18 +110,32 @@ class DynamoWorkerProcess(ManagedProcess):
"--model", "--model",
FAULT_TOLERANCE_MODEL_NAME, FAULT_TOLERANCE_MODEL_NAME,
"--disaggregation-mode", "--disaggregation-mode",
"prefill_and_decode", mode,
"--free-gpu-memory-fraction",
"0.45",
"--max-seq-len", "--max-seq-len",
"8192", "8192",
"--max-num-tokens",
"8192",
"--free-gpu-memory-fraction",
"0.15", # avoid validation error on TRT-LLM available memory checks
"--migration-limit", "--migration-limit",
str(migration_limit), str(migration_limit),
] ]
if mode != "prefill_and_decode":
config_file = (
f"test_request_migration_trtllm_config_{self.system_port}.yaml"
)
with open(config_file, "w") as f:
f.write(
"cache_transceiver_config:\n backend: DEFAULT\n max_tokens_in_buffer: 8192\n"
)
f.write("disable_overlap_scheduler: true\n")
f.write("kv_cache_config:\n max_tokens: 8192\n")
command += ["--extra-engine-args", config_file]
# Set environment variables # Set environment variables
env = os.environ.copy() env = os.environ.copy()
env["DYN_REQUEST_PLANE"] = request.getfixturevalue("request_plane") env["DYN_REQUEST_PLANE"] = request.getfixturevalue("request_plane")
env["DYN_LOG"] = "debug" env["DYN_LOG"] = "debug"
# Disable canary health check - these tests expect full control over requests # Disable canary health check - these tests expect full control over requests
# sent to the workers where canary health check intermittently sends dummy # sent to the workers where canary health check intermittently sends dummy
...@@ -86,7 +143,17 @@ class DynamoWorkerProcess(ManagedProcess): ...@@ -86,7 +143,17 @@ class DynamoWorkerProcess(ManagedProcess):
# intermittent failures # intermittent failures
env["DYN_HEALTH_CHECK_ENABLED"] = "false" env["DYN_HEALTH_CHECK_ENABLED"] = "false"
env["DYN_SYSTEM_USE_ENDPOINT_HEALTH_STATUS"] = '["generate"]' env["DYN_SYSTEM_USE_ENDPOINT_HEALTH_STATUS"] = '["generate"]'
env["DYN_SYSTEM_PORT"] = str(system_port) env["DYN_SYSTEM_PORT"] = str(self.system_port)
env["DYN_HTTP_PORT"] = str(frontend_port)
# Configure health check based on worker type
health_check_urls = [
(f"http://localhost:{self.system_port}/health", self.is_ready)
]
if mode in ["decode", "prefill_and_decode"]:
health_check_urls.append(
(f"http://localhost:{frontend_port}/v1/models", check_models_api)
)
# TODO: Have the managed process take a command name explicitly to distinguish # TODO: Have the managed process take a command name explicitly to distinguish
# between processes started with the same command. # between processes started with the same command.
...@@ -103,10 +170,7 @@ class DynamoWorkerProcess(ManagedProcess): ...@@ -103,10 +170,7 @@ class DynamoWorkerProcess(ManagedProcess):
super().__init__( super().__init__(
command=command, command=command,
env=env, env=env,
health_check_urls=[ health_check_urls=health_check_urls,
(f"http://localhost:{frontend_port}/v1/models", check_models_api),
(f"http://localhost:{system_port}/health", self.is_ready),
],
timeout=300, timeout=300,
display_output=True, display_output=True,
terminate_existing=False, terminate_existing=False,
...@@ -139,279 +203,268 @@ class DynamoWorkerProcess(ManagedProcess): ...@@ -139,279 +203,268 @@ class DynamoWorkerProcess(ManagedProcess):
@pytest.mark.timeout(290) # 3x average @pytest.mark.timeout(290) # 3x average
def test_request_migration_trtllm_worker_failure( def test_request_migration_trtllm_aggregated(
request, runtime_services_dynamic_ports, set_ucx_tls_no_mm, predownload_models request,
runtime_services_dynamic_ports,
set_ucx_tls_no_mm,
predownload_models,
migration_limit,
immediate_kill,
request_api,
stream,
): ):
""" """
End-to-end test for worker fault tolerance with migration support using TRT-LLM. End-to-end test for aggregated worker request migration.
This test verifies that when a worker is killed during request processing,
the system can handle the failure gracefully and migrate the request to
another worker.
Timing (Last Run: 2025-12-09): ~95s total (2 workers at 45% GPU each) Parameters:
- Engine initialization: ~52s (frontend: 2s, worker1: 25s, worker2: 25s sequential) immediate_kill: True for abrupt kill (SIGKILL), False for graceful shutdown (SIGTERM)
- Test execution (request + migration): ~40s migration_limit: > 0 to verify migration succeeds, 0 to verify request fails
- Teardown: ~3s request_api: "chat" for chat completion API, "completion" for completion API
stream: True for streaming, False for non-streaming
""" """
# Step 1: Start the frontend (allocates its own frontend_port) # Step 1: Start the frontend
with DynamoFrontendProcess(request) as frontend: with DynamoFrontendProcess(request) as frontend:
logger.info("Frontend started successfully") logger.info("Frontend started successfully")
# Step 2: Start 2 workers sequentially # Step 2: Start 2 workers
with DynamoWorkerProcess(request, "worker1", frontend.frontend_port) as worker1: with DynamoWorkerProcess(
request, "worker1", frontend.frontend_port, migration_limit=migration_limit
) as worker1:
logger.info(f"Worker 1 PID: {worker1.get_pid()}") logger.info(f"Worker 1 PID: {worker1.get_pid()}")
with DynamoWorkerProcess( with DynamoWorkerProcess(
request, "worker2", frontend.frontend_port request,
"worker2",
frontend.frontend_port,
migration_limit=migration_limit,
) as worker2: ) as worker2:
logger.info(f"Worker 2 PID: {worker2.get_pid()}") logger.info(f"Worker 2 PID: {worker2.get_pid()}")
# Step 3: Send the request # Step 3: Run migration test
request_thread, response_list = start_completion_request( run_migration_test(
frontend.frontend_port frontend,
) worker1,
worker2,
# Step 4: Use polling to determine which worker received the request receiving_pattern="New Request ID: ",
worker, worker_name = determine_request_receiving_worker( migration_limit=migration_limit,
worker1, worker2, receiving_pattern="New Request ID: " immediate_kill=immediate_kill,
use_chat_completion=(request_api == "chat"),
stream=stream,
) )
# Step 5: Kill the worker that has the request
logger.info(
f"Killing {worker_name} with PID {worker.get_pid()} processing the request"
)
terminate_process_tree(worker.get_pid(), immediate_kill=True, timeout=0)
# Step 6: Validate the completion response
validate_completion_response(request_thread, response_list)
# Step 7: Verify migration occurred
verify_migration_occurred(frontend)
# Step 8: Verify migration metrics
verify_migration_metrics(
frontend.frontend_port, expected_ongoing_request_count=1
)
@pytest.mark.xfail(strict=False, reason="Prefill migration not yet supported")
@pytest.mark.timeout(290) # 3x average @pytest.mark.timeout(350) # 3x average
@pytest.mark.skip(reason="TRT-LLM graceful shutdown not yet implemented") def test_request_migration_trtllm_prefill(
def test_request_migration_trtllm_graceful_shutdown( request,
request, runtime_services_dynamic_ports, set_ucx_tls_no_mm, predownload_models runtime_services_dynamic_ports,
set_ucx_tls_no_mm,
predownload_models,
migration_limit,
immediate_kill,
request_api,
stream,
): ):
""" """
End-to-end test for worker fault tolerance with graceful shutdown and migration support using TRT-LLM. End-to-end test for prefill worker request migration in disaggregated mode.
This test verifies that when a worker receives a graceful shutdown signal (SIGTERM) Setup: 1 decode worker + 2 prefill workers
during request processing, the system can handle the shutdown gracefully and migrate
the request to another worker. Unlike the abrupt kill test, this simulates a more Parameters:
controlled shutdown scenario where the worker has time to clean up and notify the immediate_kill: True for abrupt kill (SIGKILL), False for graceful shutdown (SIGTERM)
system about its shutdown. migration_limit: > 0 to verify migration succeeds, 0 to verify request fails
request_api: "chat" for chat completion API, "completion" for completion API
Timing (Last Run: 2025-12-09): ~95s total (2 workers at 45% GPU each) stream: True for streaming, False for non-streaming
- Engine initialization: ~52s (frontend: 2s, worker1: 25s, worker2: 25s sequential)
- Test execution (request + graceful migration): ~40s
- Teardown: ~3s
""" """
# Step 1: Start the frontend (allocates its own frontend_port) # Step 1: Start the frontend
with DynamoFrontendProcess(request) as frontend: with DynamoFrontendProcess(request, enforce_disagg=True) as frontend:
logger.info("Frontend started successfully") logger.info("Frontend started successfully")
# Step 2: Start 2 workers sequentially # Step 2: Start decode worker first (required for prefill workers to connect)
with DynamoWorkerProcess(request, "worker1", frontend.frontend_port) as worker1: with DynamoWorkerProcess(
logger.info(f"Worker 1 PID: {worker1.get_pid()}") request,
"worker0",
frontend.frontend_port,
migration_limit=migration_limit,
mode="decode",
) as decode_worker:
logger.info(f"Decode Worker PID: {decode_worker.get_pid()}")
# Step 3: Start 2 prefill workers
with DynamoWorkerProcess( with DynamoWorkerProcess(
request, "worker2", frontend.frontend_port request,
) as worker2: "worker1",
logger.info(f"Worker 2 PID: {worker2.get_pid()}") frontend.frontend_port,
migration_limit=migration_limit,
# Step 3: Send the request mode="prefill",
request_thread, response_list = start_completion_request( ) as prefill1:
frontend.frontend_port logger.info(f"Prefill Worker 1 PID: {prefill1.get_pid()}")
)
with DynamoWorkerProcess(
# Step 4: Use polling to determine which worker received the request request,
worker, worker_name = determine_request_receiving_worker( "worker2",
worker1, worker2, receiving_pattern="New Request ID: " frontend.frontend_port,
) migration_limit=migration_limit,
mode="prefill",
# Step 5: Gracefully shutdown the worker that has the request ) as prefill2:
logger.info( logger.info(f"Prefill Worker 2 PID: {prefill2.get_pid()}")
f"Gracefully shutting down {worker_name} with PID {worker.get_pid()} processing the request"
) # Step 4: Run migration test
terminate_process_tree( run_migration_test(
worker.get_pid(), immediate_kill=False, timeout=10 frontend,
) prefill1,
prefill2,
# Step 6: Validate the completion response receiving_pattern="Prefill Request ID: ",
validate_completion_response(request_thread, response_list) migration_limit=migration_limit,
immediate_kill=immediate_kill,
# Step 7: Verify migration occurred during graceful shutdown use_chat_completion=(request_api == "chat"),
verify_migration_occurred(frontend) stream=stream,
use_long_prompt=True,
# Step 8: Verify migration metrics )
verify_migration_metrics(
frontend.frontend_port, expected_ongoing_request_count=1
)
@pytest.mark.timeout(185) # 3x average @pytest.mark.skip(reason="Decode worker can get stuck downloading kv cache")
def test_no_request_migration_trtllm_worker_failure( @pytest.mark.timeout(350) # 3x average
request, runtime_services_dynamic_ports, set_ucx_tls_no_mm, predownload_models def test_request_migration_trtllm_kv_transfer(
request,
runtime_services_dynamic_ports,
set_ucx_tls_no_mm,
predownload_models,
migration_limit,
immediate_kill,
request_api,
stream,
): ):
""" """
End-to-end test for worker fault tolerance with migration disabled using TRT-LLM. End-to-end test for request migration during KV transfer in disaggregated mode.
This test verifies that when migration is disabled (migration_limit=0) and a worker Setup: 1 prefill worker + 2 decode workers
is killed during request processing, the request fails as expected without migration.
This is the opposite behavior of test_request_migration_trtllm_worker_failure.
Timing (Last Run: 2025-12-09): ~60s total (2 workers at 45% GPU each) Parameters:
- Engine initialization: ~52s (frontend: 2s, worker1: 25s, worker2: 25s sequential) immediate_kill: True for abrupt kill (SIGKILL), False for graceful shutdown (SIGTERM)
- Test execution (request failure): ~6s migration_limit: > 0 to verify migration succeeds, 0 to verify request fails
- Teardown: ~2s request_api: "chat" for chat completion API, "completion" for completion API
stream: True for streaming, False for non-streaming
""" """
# Step 1: Start the frontend (allocates its own frontend_port) # Step 1: Start the frontend
with DynamoFrontendProcess(request) as frontend: with DynamoFrontendProcess(request, enforce_disagg=True) as frontend:
logger.info("Frontend started successfully") logger.info("Frontend started successfully")
# Step 2: Start 2 workers sequentially with migration disabled # Step 2: Start prefill worker first
with DynamoWorkerProcess( with DynamoWorkerProcess(
request, request,
"worker1", "worker0",
frontend.frontend_port, frontend.frontend_port,
migration_limit=0, migration_limit=migration_limit,
) as worker1: mode="prefill",
logger.info(f"Worker 1 PID: {worker1.get_pid()}") ) as prefill_worker:
logger.info(f"Prefill Worker PID: {prefill_worker.get_pid()}")
# Step 3: Start 2 decode workers
with DynamoWorkerProcess( with DynamoWorkerProcess(
request, request,
"worker2", "worker1",
frontend.frontend_port, frontend.frontend_port,
migration_limit=0, migration_limit=migration_limit,
) as worker2: mode="decode",
logger.info(f"Worker 2 PID: {worker2.get_pid()}") ) as decode1:
logger.info(f"Decode Worker 1 PID: {decode1.get_pid()}")
# Step 3: Send the request
request_thread, response_list = start_completion_request( with DynamoWorkerProcess(
frontend.frontend_port request,
) "worker2",
frontend.frontend_port,
# Step 4: Use polling to determine which worker received the request migration_limit=migration_limit,
worker, worker_name = determine_request_receiving_worker( mode="decode",
worker1, worker2, receiving_pattern="New Request ID: " ) as decode2:
) logger.info(f"Decode Worker 2 PID: {decode2.get_pid()}")
# Step 5: Kill the worker that has the request # Step 4: Run migration test
logger.info( run_migration_test(
f"Killing {worker_name} with PID {worker.get_pid()} processing the request" frontend,
) decode1,
terminate_process_tree(worker.get_pid(), immediate_kill=True, timeout=0) decode2,
receiving_pattern="Decode Request ID: ",
# Step 6: Validate the completion response - should fail without migration migration_limit=migration_limit,
try: immediate_kill=immediate_kill,
validate_completion_response(request_thread, response_list) use_chat_completion=(request_api == "chat"),
pytest.fail( stream=stream,
"Request succeeded unexpectedly when migration was disabled" use_long_prompt=True,
) )
except AssertionError as e:
assert "Request failed with status 500: " in str(
e
), f"Unexpected request error message: {e}"
# Step 7: Verify migration did NOT occur - should fail
try:
verify_migration_occurred(frontend)
pytest.fail(
"Migration verification unexpectedly passed when migration was disabled"
)
except AssertionError as e:
assert "'Cannot recreate stream: ...' error found in logs" in str(
e
), f"Unexpected migration message: {e}"
@pytest.mark.timeout(185) # 3x average @pytest.mark.timeout(350) # 3x average
@pytest.mark.skip(reason="TRT-LLM graceful shutdown not yet implemented") def test_request_migration_trtllm_decode(
def test_no_request_migration_trtllm_graceful_shutdown( request,
request, runtime_services_dynamic_ports, set_ucx_tls_no_mm, predownload_models runtime_services_dynamic_ports,
set_ucx_tls_no_mm,
predownload_models,
migration_limit,
immediate_kill,
request_api,
stream,
): ):
""" """
End-to-end test for worker fault tolerance with graceful shutdown and migration disabled using TRT-LLM. End-to-end test for decode worker request migration in disaggregated mode.
This test verifies that when migration is disabled (migration_limit=0) and a worker Setup: 1 prefill worker + 2 decode workers
receives a graceful shutdown signal (SIGTERM) during request processing, the request
fails as expected without migration. This is the opposite behavior of
test_request_migration_trtllm_graceful_shutdown.
Timing (Last Run: 2025-12-09): ~60s total (2 workers at 45% GPU each) Parameters:
- Engine initialization: ~52s (frontend: 2s, worker1: 25s, worker2: 25s sequential) immediate_kill: True for abrupt kill (SIGKILL), False for graceful shutdown (SIGTERM)
- Test execution (graceful shutdown failure): ~6s migration_limit: > 0 to verify migration succeeds, 0 to verify request fails
- Teardown: ~2s request_api: "chat" for chat completion API, "completion" for completion API
stream: True for streaming, False for non-streaming
""" """
if not stream:
pytest.skip(
"Decode test requires streaming to wait for response before stopping worker"
)
# Step 1: Start the frontend (allocates its own frontend_port) # Step 1: Start the frontend
with DynamoFrontendProcess(request) as frontend: with DynamoFrontendProcess(request, enforce_disagg=True) as frontend:
logger.info("Frontend started successfully") logger.info("Frontend started successfully")
# Step 2: Start 2 workers sequentially with migration disabled # Step 2: Start prefill worker first
with DynamoWorkerProcess( with DynamoWorkerProcess(
request, request,
"worker1", "worker0",
frontend.frontend_port, frontend.frontend_port,
migration_limit=0, migration_limit=migration_limit,
) as worker1: mode="prefill",
logger.info(f"Worker 1 PID: {worker1.get_pid()}") ) as prefill_worker:
logger.info(f"Prefill Worker PID: {prefill_worker.get_pid()}")
# Step 3: Start 2 decode workers
with DynamoWorkerProcess( with DynamoWorkerProcess(
request, request,
"worker2", "worker1",
frontend.frontend_port, frontend.frontend_port,
migration_limit=0, migration_limit=migration_limit,
) as worker2: mode="decode",
logger.info(f"Worker 2 PID: {worker2.get_pid()}") ) as decode1:
logger.info(f"Decode Worker 1 PID: {decode1.get_pid()}")
# Step 3: Send the request
request_thread, response_list = start_completion_request( with DynamoWorkerProcess(
frontend.frontend_port request,
) "worker2",
frontend.frontend_port,
# Step 4: Use polling to determine which worker received the request migration_limit=migration_limit,
worker, worker_name = determine_request_receiving_worker( mode="decode",
worker1, worker2, receiving_pattern="New Request ID: " ) as decode2:
) logger.info(f"Decode Worker 2 PID: {decode2.get_pid()}")
# Step 5: Gracefully shutdown the worker that has the request # Step 4: Run migration test
logger.info( run_migration_test(
f"Gracefully shutting down {worker_name} with PID {worker.get_pid()} processing the request" frontend,
) decode1,
terminate_process_tree( decode2,
worker.get_pid(), immediate_kill=False, timeout=10 receiving_pattern="Decode Request ID: ",
) migration_limit=migration_limit,
immediate_kill=immediate_kill,
# Step 6: Validate the completion response - should fail without migration use_chat_completion=(request_api == "chat"),
try: stream=stream,
validate_completion_response(request_thread, response_list) wait_for_new_response_before_stop=True,
pytest.fail(
"Request succeeded unexpectedly when migration was disabled"
)
except AssertionError as e:
assert "Request failed with status 500: " in str(
e
), f"Unexpected request error message: {e}"
# Step 7: Verify migration did NOT occur - should fail
try:
verify_migration_occurred(frontend)
pytest.fail(
"Migration verification unexpectedly passed when migration was disabled"
) )
except AssertionError as e:
assert "'Cannot recreate stream: ...' error found in logs" in str(
e
), f"Unexpected migration message: {e}"
...@@ -2,12 +2,11 @@ ...@@ -2,12 +2,11 @@
# SPDX-License-Identifier: Apache-2.0 # SPDX-License-Identifier: Apache-2.0
""" """
Test Execution Times (Last Run: 2025-12-09): Test Execution Times (Last Run: 2026-01-09):
- test_request_migration_vllm_worker_failure: ~90s (gpu_1) - test_request_migration_vllm_aggregated: ~95s
- test_request_migration_vllm_graceful_shutdown: ~80s (gpu_1) - test_request_migration_vllm_prefill: N/A
- test_no_request_migration_vllm_worker_failure: ~75s (gpu_1) - test_request_migration_vllm_kv_transfer: N/A
- test_no_request_migration_vllm_graceful_shutdown: ~75s (gpu_1) - test_request_migration_vllm_decode: ~115s
- Total: 318.73s (0:05:18)
""" """
import logging import logging
...@@ -17,19 +16,12 @@ import shutil ...@@ -17,19 +16,12 @@ import shutil
import pytest import pytest
from tests.utils.constants import FAULT_TOLERANCE_MODEL_NAME from tests.utils.constants import FAULT_TOLERANCE_MODEL_NAME
from tests.utils.managed_process import ManagedProcess, terminate_process_tree from tests.utils.managed_process import ManagedProcess
from tests.utils.payloads import check_models_api from tests.utils.payloads import check_models_api
from tests.utils.port_utils import allocate_port, deallocate_port from tests.utils.port_utils import allocate_port, deallocate_port
# Import utilities from the refactored utils module # Customized utils for migration tests
from .utils import ( from .utils import DynamoFrontendProcess, run_migration_test
DynamoFrontendProcess,
determine_request_receiving_worker,
start_completion_request,
validate_completion_response,
verify_migration_metrics,
verify_migration_occurred,
)
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
...@@ -39,12 +31,50 @@ pytestmark = [ ...@@ -39,12 +31,50 @@ pytestmark = [
pytest.mark.e2e, pytest.mark.e2e,
pytest.mark.model(FAULT_TOLERANCE_MODEL_NAME), pytest.mark.model(FAULT_TOLERANCE_MODEL_NAME),
pytest.mark.post_merge, # post_merge to pinpoint failure commit pytest.mark.post_merge, # post_merge to pinpoint failure commit
pytest.mark.parametrize(
"migration_limit", [3, 0], ids=["migration_enabled", "migration_disabled"]
),
pytest.mark.parametrize(
"immediate_kill", [True, False], ids=["worker_failure", "graceful_shutdown"]
),
pytest.mark.parametrize(
"request_api",
[
pytest.param("chat"),
pytest.param(
"completion",
marks=pytest.mark.skip(reason="Behavior unverified yet"),
),
],
),
pytest.mark.parametrize(
"stream",
[
pytest.param(True, id="stream"),
pytest.param(
False,
id="unary",
marks=pytest.mark.skip(reason="Behavior unverified yet"),
),
],
),
pytest.mark.parametrize("request_plane", ["nats", "tcp"], indirect=True), pytest.mark.parametrize("request_plane", ["nats", "tcp"], indirect=True),
] ]
class DynamoWorkerProcess(ManagedProcess): class DynamoWorkerProcess(ManagedProcess):
"""Process manager for Dynamo worker with vLLM backend""" """Process manager for Dynamo worker with vLLM backend
Supports both aggregated mode (single worker) and disaggregated mode
(separate prefill and decode workers).
Args:
request: pytest request fixture
worker_id: Unique identifier for the worker (e.g., "worker1", "prefill1")
frontend_port: Port where the frontend is running
migration_limit: Maximum number of migration attempts (default: 3)
is_prefill: None for aggregated mode, True for prefill worker, False for decode worker
"""
def __init__( def __init__(
self, self,
...@@ -52,13 +82,10 @@ class DynamoWorkerProcess(ManagedProcess): ...@@ -52,13 +82,10 @@ class DynamoWorkerProcess(ManagedProcess):
worker_id: str, worker_id: str,
frontend_port: int, frontend_port: int,
migration_limit: int = 3, migration_limit: int = 3,
is_prefill: bool | None = None,
): ):
self.worker_id = worker_id self.worker_id = worker_id
self.frontend_port = frontend_port self.system_port = allocate_port(9100)
# Allocate system port for this worker
system_port = allocate_port(9100)
self.system_port = system_port
command = [ command = [
"python3", "python3",
...@@ -67,25 +94,41 @@ class DynamoWorkerProcess(ManagedProcess): ...@@ -67,25 +94,41 @@ class DynamoWorkerProcess(ManagedProcess):
"--model", "--model",
FAULT_TOLERANCE_MODEL_NAME, FAULT_TOLERANCE_MODEL_NAME,
"--enforce-eager", "--enforce-eager",
"--gpu-memory-utilization",
"0.45",
"--max-model-len", "--max-model-len",
"8192", "8192", # input + output tokens
"--max-num-seqs",
"1", # number of requests at a time
"--num-gpu-blocks-override", # limit total KV cache allocation
"512", # 8192 tokens x 1 context / 16 tokens per block = 512 blocks
"--gpu-memory-utilization",
"0.15", # avoid assertion error on vLLM available memory checks
"--migration-limit", "--migration-limit",
str(migration_limit), str(migration_limit),
] ]
if is_prefill is True:
command.append("--is-prefill-worker")
elif is_prefill is False:
command.append("--is-decode-worker")
# Set environment variables # Set environment variables
env = os.environ.copy() env = os.environ.copy()
env["DYN_REQUEST_PLANE"] = request.getfixturevalue("request_plane") env["DYN_REQUEST_PLANE"] = request.getfixturevalue("request_plane")
env[ # Set KV event and NIXL ports based on worker mode
"DYN_VLLM_KV_EVENT_PORT" # All workers need unique NIXL side channel ports for KV transfer
] = f"2008{worker_id[-1]}" # TODO: use dynamic port allocation
env[ env[
"VLLM_NIXL_SIDE_CHANNEL_PORT" "VLLM_NIXL_SIDE_CHANNEL_PORT"
] = f"560{worker_id[-1]}" # TODO: use dynamic port allocation ] = f"560{worker_id[-1]}" # TODO: use dynamic port allocation
if is_prefill is False:
# Decode workers don't publish KV events
env.pop("DYN_VLLM_KV_EVENT_PORT", None)
else:
# Aggregated mode and prefill workers publish KV events
env[
"DYN_VLLM_KV_EVENT_PORT"
] = f"2008{worker_id[-1]}" # TODO: use dynamic port allocation
env["DYN_LOG"] = "debug" env["DYN_LOG"] = "debug"
# Disable canary health check - these tests expect full control over requests # Disable canary health check - these tests expect full control over requests
# sent to the workers where canary health check intermittently sends dummy # sent to the workers where canary health check intermittently sends dummy
...@@ -93,9 +136,19 @@ class DynamoWorkerProcess(ManagedProcess): ...@@ -93,9 +136,19 @@ class DynamoWorkerProcess(ManagedProcess):
# intermittent failures # intermittent failures
env["DYN_HEALTH_CHECK_ENABLED"] = "false" env["DYN_HEALTH_CHECK_ENABLED"] = "false"
env["DYN_SYSTEM_USE_ENDPOINT_HEALTH_STATUS"] = '["generate"]' env["DYN_SYSTEM_USE_ENDPOINT_HEALTH_STATUS"] = '["generate"]'
env["DYN_SYSTEM_PORT"] = str(system_port) env["DYN_SYSTEM_PORT"] = str(self.system_port)
env["DYN_HTTP_PORT"] = str(frontend_port) env["DYN_HTTP_PORT"] = str(frontend_port)
# Configure health check based on worker type
health_check_urls = [
(f"http://localhost:{self.system_port}/health", self.is_ready)
]
if is_prefill is None or is_prefill is False:
# aggregated or decode
health_check_urls.append(
(f"http://localhost:{frontend_port}/v1/models", check_models_api)
)
# TODO: Have the managed process take a command name explicitly to distinguish # TODO: Have the managed process take a command name explicitly to distinguish
# between processes started with the same command. # between processes started with the same command.
log_dir = f"{request.node.name}_{worker_id}" log_dir = f"{request.node.name}_{worker_id}"
...@@ -111,10 +164,7 @@ class DynamoWorkerProcess(ManagedProcess): ...@@ -111,10 +164,7 @@ class DynamoWorkerProcess(ManagedProcess):
super().__init__( super().__init__(
command=command, command=command,
env=env, env=env,
health_check_urls=[ health_check_urls=health_check_urls,
(f"http://localhost:{frontend_port}/v1/models", check_models_api),
(f"http://localhost:{system_port}/health", self.is_ready),
],
timeout=300, timeout=300,
display_output=True, display_output=True,
terminate_existing=False, terminate_existing=False,
...@@ -149,265 +199,287 @@ class DynamoWorkerProcess(ManagedProcess): ...@@ -149,265 +199,287 @@ class DynamoWorkerProcess(ManagedProcess):
@pytest.mark.timeout(290) # 3x average @pytest.mark.timeout(290) # 3x average
def test_request_migration_vllm_worker_failure( def test_request_migration_vllm_aggregated(
request, runtime_services_dynamic_ports, set_ucx_tls_no_mm, predownload_models request,
runtime_services_dynamic_ports,
set_ucx_tls_no_mm,
predownload_models,
migration_limit,
immediate_kill,
request_api,
stream,
): ):
""" """
End-to-end test for worker fault tolerance with migration support. End-to-end test for aggregated worker request migration.
This test verifies that when a worker is killed during request processing,
the system can handle the failure gracefully and migrate the request to
another worker.
Timing (Last Run: 2025-12-09): ~90s total Parameters:
- Engine initialization: ~40s (Worker1: 20s, Worker2: 20s) immediate_kill: True for abrupt kill (SIGKILL), False for graceful shutdown (SIGTERM)
- Test execution (request + migration): ~48s migration_limit: > 0 to verify migration succeeds, 0 to verify request fails
- Teardown: ~2s request_api: "chat" for chat completion API, "completion" for completion API
stream: True for streaming, False for non-streaming
""" """
# Step 1: Start the frontend (allocates its own frontend_port) # Step 1: Start the frontend
with DynamoFrontendProcess(request) as frontend: with DynamoFrontendProcess(request) as frontend:
logger.info("Frontend started successfully") logger.info("Frontend started successfully")
# Step 2: Start 2 workers sequentially (each allocates its own system_port) # Step 2: Start 2 workers
with DynamoWorkerProcess(request, "worker1", frontend.frontend_port) as worker1: with DynamoWorkerProcess(
request, "worker1", frontend.frontend_port, migration_limit=migration_limit
) as worker1:
logger.info(f"Worker 1 PID: {worker1.get_pid()}") logger.info(f"Worker 1 PID: {worker1.get_pid()}")
with DynamoWorkerProcess( with DynamoWorkerProcess(
request, "worker2", frontend.frontend_port request,
"worker2",
frontend.frontend_port,
migration_limit=migration_limit,
) as worker2: ) as worker2:
logger.info(f"Worker 2 PID: {worker2.get_pid()}") logger.info(f"Worker 2 PID: {worker2.get_pid()}")
# Step 3: Send the request # Step 3: Run migration test
request_thread, response_list = start_completion_request( run_migration_test(
frontend.frontend_port frontend,
) worker1,
worker2,
# Step 4: Use polling to determine which worker received the request receiving_pattern="Decode Request ID: ",
worker, worker_name = determine_request_receiving_worker( migration_limit=migration_limit,
worker1, worker2, receiving_pattern="Decode Request ID: " immediate_kill=immediate_kill,
) use_chat_completion=(request_api == "chat"),
stream=stream,
# Step 5: Kill the worker that has the request
logger.info(
f"Killing {worker_name} with PID {worker.get_pid()} processing the request"
) )
terminate_process_tree(worker.get_pid(), immediate_kill=True, timeout=0)
# Step 6: Validate the completion response
validate_completion_response(request_thread, response_list)
# Step 7: Verify migration occurred @pytest.mark.xfail(strict=False, reason="Prefill migration not yet supported")
verify_migration_occurred(frontend) @pytest.mark.timeout(350) # 3x average
def test_request_migration_vllm_prefill(
# Step 8: Verify migration metrics request,
verify_migration_metrics( runtime_services_dynamic_ports,
frontend.frontend_port, expected_ongoing_request_count=1 set_ucx_tls_no_mm,
) predownload_models,
migration_limit,
immediate_kill,
@pytest.mark.timeout(280) # 3x average request_api,
def test_request_migration_vllm_graceful_shutdown( stream,
request, runtime_services_dynamic_ports, set_ucx_tls_no_mm, predownload_models
): ):
""" """
End-to-end test for worker fault tolerance with graceful shutdown and migration support. End-to-end test for prefill worker request migration in disaggregated mode.
This test verifies that when a worker receives a graceful shutdown signal (SIGTERM) Setup: 1 decode worker + 2 prefill workers
during request processing, the system can handle the shutdown gracefully and migrate
the request to another worker. Unlike the abrupt kill test, this simulates a more Parameters:
controlled shutdown scenario where the worker has time to clean up and notify the immediate_kill: True for abrupt kill (SIGKILL), False for graceful shutdown (SIGTERM)
system about its shutdown. migration_limit: > 0 to verify migration succeeds, 0 to verify request fails
request_api: "chat" for chat completion API, "completion" for completion API
Timing (Last Run: 2025-12-09): ~80s total stream: True for streaming, False for non-streaming
- Engine initialization: ~40s (Worker1: 20s, Worker2: 20s)
- Test execution (graceful shutdown + migration): ~38s
- Teardown: ~2s
""" """
# Step 1: Start the frontend (allocates its own frontend_port) # Step 1: Start the frontend
with DynamoFrontendProcess(request) as frontend: with DynamoFrontendProcess(request, enforce_disagg=True) as frontend:
logger.info("Frontend started successfully") logger.info("Frontend started successfully")
# Step 2: Start 2 workers sequentially (each allocates its own system_port) # Step 2: Start decode worker first (required for prefill workers to connect)
with DynamoWorkerProcess(request, "worker1", frontend.frontend_port) as worker1: with DynamoWorkerProcess(
logger.info(f"Worker 1 PID: {worker1.get_pid()}") request,
"worker0",
frontend.frontend_port,
migration_limit=migration_limit,
is_prefill=False,
) as decode_worker:
logger.info(f"Decode Worker PID: {decode_worker.get_pid()}")
# Step 3: Start 2 prefill workers
with DynamoWorkerProcess( with DynamoWorkerProcess(
request, "worker2", frontend.frontend_port request,
) as worker2: "worker1",
logger.info(f"Worker 2 PID: {worker2.get_pid()}") frontend.frontend_port,
migration_limit=migration_limit,
# Step 3: Send the request is_prefill=True,
request_thread, response_list = start_completion_request( ) as prefill1:
frontend.frontend_port logger.info(f"Prefill Worker 1 PID: {prefill1.get_pid()}")
)
with DynamoWorkerProcess(
# Step 4: Use polling to determine which worker received the request request,
worker, worker_name = determine_request_receiving_worker( "worker2",
worker1, worker2, receiving_pattern="Decode Request ID: " frontend.frontend_port,
) migration_limit=migration_limit,
is_prefill=True,
# Step 5: Gracefully shutdown the worker that has the request ) as prefill2:
logger.info( logger.info(f"Prefill Worker 2 PID: {prefill2.get_pid()}")
f"Gracefully shutting down {worker_name} with PID {worker.get_pid()} processing the request"
) # Step 4: Run migration test
terminate_process_tree( run_migration_test(
worker.get_pid(), immediate_kill=False, timeout=10 frontend,
) prefill1,
prefill2,
# Step 6: Validate the completion response receiving_pattern="Prefill Request ID: ",
validate_completion_response(request_thread, response_list) migration_limit=migration_limit,
immediate_kill=immediate_kill,
# Step 7: Verify migration occurred during graceful shutdown use_chat_completion=(request_api == "chat"),
verify_migration_occurred(frontend) stream=stream,
use_long_prompt=True,
# Step 8: Verify migration metrics )
verify_migration_metrics(
frontend.frontend_port, expected_ongoing_request_count=1
)
@pytest.mark.timeout(150) # 3x average @pytest.mark.xfail(
def test_no_request_migration_vllm_worker_failure( strict=False,
request, runtime_services_dynamic_ports, set_ucx_tls_no_mm, predownload_models reason=(
"Migration reuses the same request_id for vLLM, but the prefill worker's "
"KV cache still holds the request due to delay_free_blocks in disaggregated mode. "
"With chat completions API, prefix cache hits on chat template tokens cause "
"an assertion error in vLLM's KV cache manager (save_new_computed_blocks expects "
"no new computed blocks for existing requests)."
),
)
@pytest.mark.timeout(350) # 3x average
def test_request_migration_vllm_kv_transfer(
request,
runtime_services_dynamic_ports,
set_ucx_tls_no_mm,
predownload_models,
migration_limit,
immediate_kill,
request_api,
stream,
): ):
""" """
End-to-end test for worker fault tolerance with migration disabled. End-to-end test for request migration during KV transfer in disaggregated mode.
This test verifies that when migration is disabled (migration_limit=0) and a worker Setup: 1 prefill worker + 2 decode workers
is killed during request processing, the request fails as expected without migration.
This is the opposite behavior of test_request_migration_vllm_worker_failure.
Timing (Last Run: 2025-12-09): ~75s total Parameters:
- Engine initialization: ~40s (Worker1: 20s, Worker2: 20s) immediate_kill: True for abrupt kill (SIGKILL), False for graceful shutdown (SIGTERM)
- Test execution (failure validation): ~33s migration_limit: > 0 to verify migration succeeds, 0 to verify request fails
- Teardown: ~2s request_api: "chat" for chat completion API, "completion" for completion API
stream: True for streaming, False for non-streaming
""" """
# Step 1: Start the frontend (allocates its own frontend_port) # Step 1: Start the frontend
with DynamoFrontendProcess(request) as frontend: with DynamoFrontendProcess(request, enforce_disagg=True) as frontend:
logger.info("Frontend started successfully") logger.info("Frontend started successfully")
# Step 2: Start 2 workers sequentially with migration disabled (each allocates its own system_port) # Step 2: Start prefill worker first
with DynamoWorkerProcess( with DynamoWorkerProcess(
request, "worker1", frontend.frontend_port, migration_limit=0 request,
) as worker1: "worker0",
logger.info(f"Worker 1 PID: {worker1.get_pid()}") frontend.frontend_port,
migration_limit=migration_limit,
is_prefill=True,
) as prefill_worker:
logger.info(f"Prefill Worker PID: {prefill_worker.get_pid()}")
# Step 3: Start 2 decode workers
with DynamoWorkerProcess( with DynamoWorkerProcess(
request, "worker2", frontend.frontend_port, migration_limit=0 request,
) as worker2: "worker1",
logger.info(f"Worker 2 PID: {worker2.get_pid()}") frontend.frontend_port,
migration_limit=migration_limit,
# Step 3: Send the request is_prefill=False,
request_thread, response_list = start_completion_request( ) as decode1:
frontend.frontend_port logger.info(f"Decode Worker 1 PID: {decode1.get_pid()}")
)
with DynamoWorkerProcess(
# Step 4: Use polling to determine which worker received the request request,
worker, worker_name = determine_request_receiving_worker( "worker2",
worker1, worker2, receiving_pattern="Decode Request ID: " frontend.frontend_port,
) migration_limit=migration_limit,
is_prefill=False,
# Step 5: Kill the worker that has the request ) as decode2:
logger.info( logger.info(f"Decode Worker 2 PID: {decode2.get_pid()}")
f"Killing {worker_name} with PID {worker.get_pid()} processing the request"
) # Step 4: Run migration test
terminate_process_tree(worker.get_pid(), immediate_kill=True, timeout=0) run_migration_test(
frontend,
# Step 6: Validate the completion response - should fail without migration decode1,
try: decode2,
validate_completion_response(request_thread, response_list) receiving_pattern="Decode Request ID: ",
pytest.fail( migration_limit=migration_limit,
"Request succeeded unexpectedly when migration was disabled" immediate_kill=immediate_kill,
use_chat_completion=(request_api == "chat"),
stream=stream,
use_long_prompt=True,
) )
except AssertionError as e:
assert "Request failed with status 500: " in str(
e
), f"Unexpected request error message: {e}"
# Step 7: Verify migration did NOT occur - should fail
try:
verify_migration_occurred(frontend)
pytest.fail(
"Migration verification unexpectedly passed when migration was disabled"
)
except AssertionError as e:
assert "'Cannot recreate stream: ...' error found in logs" in str(
e
), f"Unexpected migration message: {e}"
@pytest.mark.timeout(140) # 3x average @pytest.mark.xfail(
def test_no_request_migration_vllm_graceful_shutdown( strict=False,
request, runtime_services_dynamic_ports, set_ucx_tls_no_mm, predownload_models reason=(
"Migration reuses the same request_id for vLLM, but the prefill worker's "
"KV cache still holds the request due to delay_free_blocks in disaggregated mode. "
"With chat completions API, prefix cache hits on chat template tokens cause "
"an assertion error in vLLM's KV cache manager (save_new_computed_blocks expects "
"no new computed blocks for existing requests)."
),
)
@pytest.mark.timeout(350) # 3x average
def test_request_migration_vllm_decode(
request,
runtime_services_dynamic_ports,
set_ucx_tls_no_mm,
predownload_models,
migration_limit,
immediate_kill,
request_api,
stream,
): ):
""" """
End-to-end test for worker fault tolerance with graceful shutdown and migration disabled. End-to-end test for decode worker request migration in disaggregated mode.
This test verifies that when migration is disabled (migration_limit=0) and a worker Setup: 1 prefill worker + 2 decode workers
receives a graceful shutdown signal (SIGTERM) during request processing, the request
fails as expected without migration. This is the opposite behavior of
test_request_migration_vllm_graceful_shutdown.
Timing (Last Run: 2025-12-09): ~75s total Parameters:
- Engine initialization: ~40s (Worker1: 20s, Worker2: 20s) immediate_kill: True for abrupt kill (SIGKILL), False for graceful shutdown (SIGTERM)
- Test execution (graceful shutdown validation): ~33s migration_limit: > 0 to verify migration succeeds, 0 to verify request fails
- Teardown: ~2s request_api: "chat" for chat completion API, "completion" for completion API
stream: True for streaming, False for non-streaming
""" """
if not stream:
pytest.skip(
"Decode test requires streaming to wait for response before stopping worker"
)
# Step 1: Start the frontend (allocates its own frontend_port) # Step 1: Start the frontend
with DynamoFrontendProcess(request) as frontend: with DynamoFrontendProcess(request, enforce_disagg=True) as frontend:
logger.info("Frontend started successfully") logger.info("Frontend started successfully")
# Step 2: Start 2 workers sequentially with migration disabled (each allocates its own system_port) # Step 2: Start prefill worker first
with DynamoWorkerProcess( with DynamoWorkerProcess(
request, "worker1", frontend.frontend_port, migration_limit=0 request,
) as worker1: "worker0",
logger.info(f"Worker 1 PID: {worker1.get_pid()}") frontend.frontend_port,
migration_limit=migration_limit,
is_prefill=True,
) as prefill_worker:
logger.info(f"Prefill Worker PID: {prefill_worker.get_pid()}")
# Step 3: Start 2 decode workers
with DynamoWorkerProcess( with DynamoWorkerProcess(
request, "worker2", frontend.frontend_port, migration_limit=0 request,
) as worker2: "worker1",
logger.info(f"Worker 2 PID: {worker2.get_pid()}") frontend.frontend_port,
migration_limit=migration_limit,
# Step 3: Send the request is_prefill=False,
request_thread, response_list = start_completion_request( ) as decode1:
frontend.frontend_port logger.info(f"Decode Worker 1 PID: {decode1.get_pid()}")
)
with DynamoWorkerProcess(
# Step 4: Use polling to determine which worker received the request request,
worker, worker_name = determine_request_receiving_worker( "worker2",
worker1, worker2, receiving_pattern="Decode Request ID: " frontend.frontend_port,
) migration_limit=migration_limit,
is_prefill=False,
# Step 5: Gracefully shutdown the worker that has the request ) as decode2:
logger.info( logger.info(f"Decode Worker 2 PID: {decode2.get_pid()}")
f"Gracefully shutting down {worker_name} with PID {worker.get_pid()} processing the request"
) # Step 4: Run migration test
terminate_process_tree( run_migration_test(
worker.get_pid(), immediate_kill=False, timeout=10 frontend,
) decode1,
decode2,
# Step 6: Validate the completion response - should fail without migration receiving_pattern="Decode Request ID: ",
try: migration_limit=migration_limit,
validate_completion_response(request_thread, response_list) immediate_kill=immediate_kill,
pytest.fail( use_chat_completion=(request_api == "chat"),
"Request succeeded unexpectedly when migration was disabled" stream=stream,
) wait_for_new_response_before_stop=True,
except AssertionError as e:
assert "Request failed with status 500: " in str(
e
), f"Unexpected request error message: {e}"
# Step 7: Verify migration did NOT occur - should fail
try:
verify_migration_occurred(frontend)
pytest.fail(
"Migration verification unexpectedly passed when migration was disabled"
) )
except AssertionError as e:
assert "'Cannot recreate stream: ...' error found in logs" in str(
e
), f"Unexpected migration message: {e}"
# SPDX-FileCopyrightText: Copyright (c) 2025-2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. # SPDX-FileCopyrightText: Copyright (c) 2025-2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0 # SPDX-License-Identifier: Apache-2.0
import json
import logging import logging
import threading import threading
import time import time
...@@ -12,7 +13,7 @@ from tests.utils.constants import FAULT_TOLERANCE_MODEL_NAME ...@@ -12,7 +13,7 @@ from tests.utils.constants import FAULT_TOLERANCE_MODEL_NAME
from tests.utils.managed_process import ( from tests.utils.managed_process import (
DynamoFrontendProcess as BaseDynamoFrontendProcess, DynamoFrontendProcess as BaseDynamoFrontendProcess,
) )
from tests.utils.managed_process import ManagedProcess from tests.utils.managed_process import ManagedProcess, terminate_process_tree
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
...@@ -20,62 +21,259 @@ logger = logging.getLogger(__name__) ...@@ -20,62 +21,259 @@ logger = logging.getLogger(__name__)
class DynamoFrontendProcess(BaseDynamoFrontendProcess): class DynamoFrontendProcess(BaseDynamoFrontendProcess):
"""Fault-tolerance frontend wrapper (keeps env settings from the historical helper).""" """Fault-tolerance frontend wrapper (keeps env settings from the historical helper)."""
def __init__(self, request): def __init__(self, request, enforce_disagg: bool = False):
extra_env = { extra_env = {
"DYN_REQUEST_PLANE": request.getfixturevalue("request_plane"), "DYN_REQUEST_PLANE": request.getfixturevalue("request_plane"),
# These tests expect full control over requests sent to workers. The canary # These tests expect full control over requests sent to workers. The canary
# health check can inject extra requests and cause intermittent failures. # health check can inject extra requests and cause intermittent failures.
"DYN_HEALTH_CHECK_ENABLED": "false", "DYN_HEALTH_CHECK_ENABLED": "false",
} }
extra_args = []
if enforce_disagg:
extra_args.append("--enforce-disagg")
super().__init__( super().__init__(
request, request,
frontend_port=0, # allocate a free port (xdist-safe) frontend_port=0, # allocate a free port (xdist-safe)
router_mode="round-robin", router_mode="round-robin",
extra_args=extra_args if extra_args else None,
extra_env=extra_env, extra_env=extra_env,
terminate_existing=False, terminate_existing=False,
) )
def start_completion_request(frontend_port: int) -> tuple: def _parse_completion_sse_content(line: str) -> str | Exception | None:
"""
Parse an SSE line from the completions API and extract the text content.
Args:
line: Raw SSE line string
Returns:
str: The text content if found
Exception: If error event or parse error
None: If no content (e.g., [DONE] or empty)
"""
if line.startswith("event: error"):
return Exception(f"SSE error event received: {line}")
if not line.startswith("data: "):
return None # Skip non-data lines
data_str = line[6:] # Remove "data: " prefix
if data_str == "[DONE]":
return None
try:
chunk = json.loads(data_str)
text = chunk["choices"][0].get("text")
return text # May be None if no text content
except Exception as e:
return Exception(f"Error parsing response chunk: {e}")
def start_completion_request(
frontend_port: int, stream: bool, use_long_prompt: bool = False
) -> tuple:
""" """
Start a long-running completion request in a separate thread. Start a long-running completion request in a separate thread.
Responses are processed internally to extract content. First entry is (None, start_time)
to mark when request was sent. Subsequent entries contain extracted content or exceptions.
Args: Args:
frontend_port: Port where the frontend is running frontend_port: Port where the frontend is running
stream: Whether to use streaming responses
use_long_prompt: Whether to use a long prompt (~8000 tokens)
Returns: Returns:
tuple: (request_thread, response_list) tuple: (request_thread, response_list) where response_list contains
(str | None | Exception, float) tuples.
- For streaming: each entry is (content_word, timestamp)
- For non-streaming: single entry is (full_content, timestamp)
""" """
response_list = [] # Thread safe is not required as only one thread writes to it response_list: list[tuple[str | None | Exception, float]] = []
def send_request(): def send_request():
prompt = "Tell me a long long long story about yourself?" prompt = "Tell me a long long long story about yourself?"
max_tokens = 8000 if use_long_prompt:
prompt += " Make sure it is" + " long" * 8000 + "!"
timeout = 240 # Extended timeout for long request timeout = 240 # Extended timeout for long request
payload = { payload = {
"model": FAULT_TOLERANCE_MODEL_NAME, "model": FAULT_TOLERANCE_MODEL_NAME,
"prompt": prompt, "prompt": prompt,
"max_tokens": max_tokens, "stream": stream,
} }
headers = {"Content-Type": "application/json"} headers = {"Content-Type": "application/json"}
logger.info( logger.info(
f"Sending completion request with prompt: '{prompt[:50]}...' and max_tokens: {max_tokens}" f"Sending completion request (stream={stream}) with prompt: '{prompt[:50]}...'"
) )
response_list.append((None, time.time())) # start timestamp
try: try:
response = requests.post( with requests.post(
f"http://localhost:{frontend_port}/v1/completions", f"http://localhost:{frontend_port}/v1/completions",
headers=headers, headers=headers,
json=payload, json=payload,
timeout=timeout, timeout=timeout,
) stream=stream,
logger.info(f"Received response with status code: {response.status_code}") ) as response:
response_list.append(response) logger.info(
f"Received response with status code: {response.status_code}"
)
if response.status_code != 200:
response_list.append(
(
Exception(
f"Request failed with status {response.status_code}: {response.text}"
),
time.time(),
)
)
return
if stream:
for line in response.iter_lines():
if line:
content = _parse_completion_sse_content(
line.decode("utf-8")
)
if content is not None:
response_list.append((content, time.time()))
else:
try:
content = response.json()["choices"][0]["text"]
response_list.append((content, time.time()))
except Exception as e:
response_list.append(
(Exception(f"Error parsing response: {e}"), time.time())
)
except Exception as e:
logger.error(f"Request failed with error: {e}")
response_list.append((e, time.time()))
request_thread = threading.Thread(target=send_request, daemon=True)
request_thread.start()
return request_thread, response_list
def _parse_chat_completion_sse_content(line: str) -> str | Exception | None:
"""
Parse an SSE line and extract the content.
Args:
line: Raw SSE line string
Returns:
str: The content delta if found
Exception: If error event or parse error
None: If no content (e.g., [DONE] or empty delta)
"""
if line.startswith("event: error"):
return Exception(f"SSE error event received: {line}")
if not line.startswith("data: "):
return None # Skip non-data lines
data_str = line[6:] # Remove "data: " prefix
if data_str == "[DONE]":
return None
try:
chunk = json.loads(data_str)
content = chunk["choices"][0]["delta"].get("content")
return content # May be None if delta has no content
except Exception as e:
return Exception(f"Error parsing response chunk: {e}")
def start_chat_completion_request(
frontend_port: int, stream: bool, use_long_prompt: bool = False
) -> tuple:
"""
Start a long-running chat completion request in a separate thread.
Responses are processed internally to extract content. First entry is (None, start_time)
to mark when request was sent. Subsequent entries contain extracted content or exceptions.
Args:
frontend_port: Port where the frontend is running
stream: Whether to use streaming responses
use_long_prompt: Whether to use a long prompt (~8000 tokens)
Returns:
tuple: (request_thread, response_list) where response_list contains
(str | None | Exception, float) tuples.
- For streaming: each entry is (content_word, timestamp)
- For non-streaming: single entry is (full_content, timestamp)
"""
response_list: list[tuple[str | None | Exception, float]] = []
def send_request():
prompt = "Tell me a long long long story about yourself?"
if use_long_prompt:
prompt += " Make sure it is" + " long" * 8000 + "!"
timeout = 240 # Extended timeout for long request
payload = {
"model": FAULT_TOLERANCE_MODEL_NAME,
"messages": [{"role": "user", "content": prompt}],
"stream": stream,
}
headers = {"Content-Type": "application/json"}
logger.info(
f"Sending chat completion request (stream={stream}) with prompt: '{prompt[:50]}...'"
)
response_list.append((None, time.time())) # start timestamp
try:
with requests.post(
f"http://localhost:{frontend_port}/v1/chat/completions",
headers=headers,
json=payload,
timeout=timeout,
stream=stream,
) as response:
logger.info(
f"Received response with status code: {response.status_code}"
)
if response.status_code != 200:
response_list.append(
(
Exception(
f"Request failed with status {response.status_code}: {response.text}"
),
time.time(),
)
)
return
if stream:
for line in response.iter_lines():
if line:
content = _parse_chat_completion_sse_content(
line.decode("utf-8")
)
if content is not None:
response_list.append((content, time.time()))
else:
try:
content = response.json()["choices"][0]["message"]["content"]
response_list.append((content, time.time()))
except Exception as e:
response_list.append(
(Exception(f"Error parsing response: {e}"), time.time())
)
except Exception as e: except Exception as e:
logger.error(f"Request failed with error: {e}") logger.error(f"Request failed with error: {e}")
response_list.append((e, time.time()))
request_thread = threading.Thread(target=send_request, daemon=True) request_thread = threading.Thread(target=send_request, daemon=True)
request_thread.start() request_thread.start()
...@@ -99,6 +297,8 @@ def determine_request_receiving_worker( ...@@ -99,6 +297,8 @@ def determine_request_receiving_worker(
""" """
worker1_results: list[bool] = [] worker1_results: list[bool] = []
worker2_results: list[bool] = [] worker2_results: list[bool] = []
# Event to signal all threads to exit when one finds the pattern
found_event = threading.Event()
# Poll both workers in parallel # Poll both workers in parallel
def poll_worker(worker: ManagedProcess, result_list: list[bool]): def poll_worker(worker: ManagedProcess, result_list: list[bool]):
...@@ -107,13 +307,14 @@ def determine_request_receiving_worker( ...@@ -107,13 +307,14 @@ def determine_request_receiving_worker(
max_iterations = max_wait_ms // poll_interval_ms max_iterations = max_wait_ms // poll_interval_ms
iteration = 0 iteration = 0
while iteration < max_iterations: while iteration < max_iterations and not found_event.is_set():
# Check if the worker logs contain 'New Request ID:' message # Check if the worker logs contain the pattern
try: try:
with open(worker.log_path, "r") as f: with open(worker.log_path, "r") as f:
log_content = f.read() log_content = f.read()
if receiving_pattern in log_content: if receiving_pattern in log_content:
result_list.append(True) result_list.append(True)
found_event.set() # Signal other thread to exit
return return
except Exception as e: except Exception as e:
logger.error(f"Could not read log file {worker.log_path}: {e}") logger.error(f"Could not read log file {worker.log_path}: {e}")
...@@ -150,44 +351,77 @@ def determine_request_receiving_worker( ...@@ -150,44 +351,77 @@ def determine_request_receiving_worker(
pytest.fail("Neither worker received the request") pytest.fail("Neither worker received the request")
def validate_completion_response( def wait_for_response(
request_thread: threading.Thread, response_list: list response_list: list[tuple[str | None | Exception, float]],
num_responses: int = 5,
max_wait_time: float = 10.0,
) -> None: ) -> None:
""" """
Wait for and validate the completion response after worker failure. Block until num_responses new responses are received or max_wait_time is reached.
Args: Args:
request_thread: The thread running the completion request response_list: List being populated by background thread
response_list: List containing the response from the request num_responses: Number of new responses to wait for (default 5)
max_wait_time: Maximum time to wait in seconds (default 10s)
"""
initial_len = len(response_list)
target_len = initial_len + num_responses
poll_interval = 0.001 # 1ms
elapsed = 0.0
while elapsed < max_wait_time:
if len(response_list) >= target_len:
return
time.sleep(poll_interval)
elapsed += poll_interval
logger.warning(
f"Only received {len(response_list) - initial_len}/{num_responses} new responses within {max_wait_time}s"
)
def validate_response(
request_thread: threading.Thread,
response_list: list[tuple[str | None | Exception, float]],
validate_delay: bool = True,
) -> None:
"""
Wait for and validate the response after migration.
Checks that delay before each response is reasonable (covers both TTFT and TPOT).
Args:
request_thread: The thread running the request
response_list: List of (content_string | None | Exception, timestamp) tuples.
Content is already parsed - no SSE format parsing needed.
validate_delay: Whether to validate delay before each response.
""" """
request_thread.join(timeout=240) request_thread.join(timeout=240)
if request_thread.is_alive(): assert not request_thread.is_alive(), "Request did not complete within 240 seconds"
pytest.fail("Request did not complete within 240 seconds")
# Get the response assert len(response_list) > 0, "Missing first entry with start timestamp"
if len(response_list) != 1: assert response_list[0][0] is None, "First entry should be start timestamp only"
pytest.fail(f"Received {len(response_list)} responses, expected 1") prev_timestamp = response_list[0][1]
response = response_list[0]
assert ( response_words: list[str] = []
response.status_code == 200 for res, timestamp in response_list[1:]:
), f"Request failed with status {response.status_code}: {response.text}" delay = timestamp - prev_timestamp
if delay > 2.0 and validate_delay:
# Cold workers can take longer on first token - only warn but don't fail
logger.warning(f"Delay before response: {delay:.3f} secs")
# Capture cases like migration is blocked by engine graceful shutdown
assert delay <= 6.0, f"Delay before response > 6 secs, got {delay:.3f} secs"
prev_timestamp = timestamp
try: assert res is not None, "Response entry should not be None"
data = response.json() if isinstance(res, Exception):
except ValueError: raise res
pytest.fail(f"Response is not valid JSON: {response.text}")
# Validate OpenAI completion response structure # Content is already parsed - just collect it
assert "choices" in data, f"Response missing 'choices' field: {data}" response_words.append(res)
assert len(data["choices"]) > 0, f"Response has empty 'choices': {data}"
assert "text" in data["choices"][0], f"Response choice missing 'text' field: {data}"
assert data["choices"][0]["text"], f"Response text is empty: {data}"
logger.info( logger.info(
f"Received valid completion response: {data['choices'][0]['text'][:100]}..." f"Received {len(response_words)} response(s): {''.join(response_words)[:100]}..."
) )
logger.info("Request completed successfully")
def verify_migration_occurred(frontend_process: DynamoFrontendProcess) -> None: def verify_migration_occurred(frontend_process: DynamoFrontendProcess) -> None:
...@@ -198,11 +432,18 @@ def verify_migration_occurred(frontend_process: DynamoFrontendProcess) -> None: ...@@ -198,11 +432,18 @@ def verify_migration_occurred(frontend_process: DynamoFrontendProcess) -> None:
frontend_process: The frontend process to check logs for frontend_process: The frontend process to check logs for
""" """
log_path = frontend_process.log_path log_path = frontend_process.log_path
try: log_content = ""
with open(log_path, "r") as f: for i in range(10):
log_content = f.read() try:
except Exception as e: with open(log_path, "r") as f:
pytest.fail(f"Could not read frontend log file {log_path}: {e}") log_content = f.read()
except Exception as e:
pytest.fail(f"Could not read frontend log file {log_path}: {e}")
# Make sure this message is captured if any with the polling
if "Cannot recreate stream: " in log_content:
break
time.sleep(0.005)
assert ( assert (
"Stream disconnected... recreating stream..." in log_content "Stream disconnected... recreating stream..." in log_content
), "'Stream disconnected... recreating stream...' message not found in logs" ), "'Stream disconnected... recreating stream...' message not found in logs"
...@@ -293,3 +534,85 @@ def verify_migration_metrics( ...@@ -293,3 +534,85 @@ def verify_migration_metrics(
f"Expected at least {expected_new_request_count} new_request migrations, " f"Expected at least {expected_new_request_count} new_request migrations, "
f"but got {new_request_count}" f"but got {new_request_count}"
) )
def run_migration_test(
    frontend: DynamoFrontendProcess,
    worker1: ManagedProcess,
    worker2: ManagedProcess,
    receiving_pattern: str,
    migration_limit: int,
    immediate_kill: bool,
    use_chat_completion: bool,
    stream: bool,
    use_long_prompt: bool = False,
    wait_for_new_response_before_stop: bool = False,
) -> None:
    """
    Run the common migration test flow after frontend and workers are started.

    Flow: send one request in a background thread, find which of the two
    workers logged `receiving_pattern`, stop that worker, then check the
    outcome. With migration_limit > 0 the request must succeed and migration
    must be visible in the frontend logs and metrics. With migration_limit == 0
    the request must fail with a known error and the migration check must fail
    with the "Cannot recreate stream" assertion.

    Args:
        frontend: The frontend process
        worker1: First worker process
        worker2: Second worker process
        receiving_pattern: Log pattern to identify which worker received the request
        migration_limit: Migration limit setting (0 = disabled)
        immediate_kill: True for immediate kill, False for graceful shutdown
        use_chat_completion: Whether to use chat completion API (True) or completion API (False)
        stream: Whether to use streaming responses
        use_long_prompt: Whether to use long prompt (for prefill tests)
        wait_for_new_response_before_stop: Whether to wait for response before stopping (for decode tests)
    """
    # Step 1: Send the request through the selected API
    start_request = (
        start_chat_completion_request
        if use_chat_completion
        else start_completion_request
    )
    request_thread, response_list = start_request(
        frontend.frontend_port, stream=stream, use_long_prompt=use_long_prompt
    )

    # Step 2: Determine which worker received the request
    worker, worker_name = determine_request_receiving_worker(
        worker1, worker2, receiving_pattern=receiving_pattern
    )

    # Step 3: Optionally wait for new response before stop (for decode tests)
    if wait_for_new_response_before_stop:
        wait_for_response(response_list)

    # Step 4: Stop the worker (kill or graceful shutdown)
    if immediate_kill:
        logger.info(f"Killing {worker_name} with PID {worker.get_pid()}")
        terminate_process_tree(worker.get_pid(), immediate_kill=True, timeout=0)
    else:
        logger.info(
            f"Gracefully shutting down {worker_name} with PID {worker.get_pid()}"
        )
        terminate_process_tree(worker.get_pid(), immediate_kill=False, timeout=10)

    # Step 5: Validate response based on migration setting.
    # Delay validation is only applied to streaming requests.
    if migration_limit > 0:
        validate_response(request_thread, response_list, validate_delay=stream)
        verify_migration_occurred(frontend)
        verify_migration_metrics(
            frontend.frontend_port, expected_ongoing_request_count=1
        )
        return

    # Migration disabled: the request itself must fail.
    # NOTE(review): this relies on pytest.fail() raising an outcome exception
    # that is not an `Exception` subclass, so it escapes the handler below
    # instead of being mistaken for the expected request failure - confirm
    # against the pytest version in use.
    try:
        validate_response(request_thread, response_list, validate_delay=stream)
        pytest.fail("Request succeeded unexpectedly when migration was disabled")
    except Exception as e:
        # Request failed as expected - verify it's a known error type
        error_str = str(e)
        assert (
            "SSE error event received:" in error_str
            or "Request failed with status" in error_str
        ), f"Unexpected error: {e}"

    # Migration disabled: the migration check must fail, and for the expected
    # reason (the "Cannot recreate stream" error showing up in the logs).
    try:
        verify_migration_occurred(frontend)
        pytest.fail("Migration unexpectedly occurred when disabled")
    except AssertionError as e:
        assert "'Cannot recreate stream: ...' error found in logs" in str(e)
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment