Unverified Commit 36f03d40 authored by Neelay Shah's avatar Neelay Shah Committed by GitHub
Browse files

test: fault tolerance tests (#1444)


Signed-off-by: default avatarNeelay Shah <neelays@nvidia.com>
Co-authored-by: default avatarcoderabbitai[bot] <136622811+coderabbitai[bot]@users.noreply.github.com>
parent fb213a2f
# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import pytest
from tests.utils.deployment_graph import (
DeploymentGraph,
Payload,
chat_completions_response_handler,
)
# Initial payload used for testing
# initial deployment readiness.
text_prompt = "Tell me a short joke about AI."
text_payload = Payload(
payload_chat={
"model": "deepseek-ai/DeepSeek-R1-Distill-Llama-8B",
"messages": [
{
"role": "user",
"content": text_prompt, # Shorter prompt
}
],
"max_tokens": 150,
"temperature": 0.1,
# "seed": 10,
"ignore_eos": True,
"min_tokens": 150,
"stream": False,
},
expected_log=[],
expected_response=["AI"],
)
# Each Deployment Graph contains
# the dynamo serve module and configuration as well
# as the endpoint for interaction
deployment_graphs = {
"agg-tp-1-dp-1": (
DeploymentGraph(
module="graphs.agg:Frontend",
config="/workspace/tests/fault_tolerance/configs/agg_tp_1_dp_1.yaml",
directory="/workspace/examples/llm",
endpoints=["v1/chat/completions"],
response_handlers=[chat_completions_response_handler],
marks=[pytest.mark.gpu_1, pytest.mark.vllm],
),
text_payload,
),
"agg-tp-1-dp-8": (
DeploymentGraph(
module="graphs.agg:Frontend",
config="/workspace/tests/fault_tolerance/configs/agg_tp_1_dp_8.yaml",
directory="/workspace/examples/llm",
endpoints=["v1/chat/completions"],
response_handlers=[chat_completions_response_handler],
marks=[pytest.mark.gpu_8, pytest.mark.vllm],
),
text_payload,
),
"agg-tp-1-dp-4": (
DeploymentGraph(
module="graphs.agg:Frontend",
config="/workspace/tests/fault_tolerance/configs/agg_tp_1_dp_4.yaml",
directory="/workspace/examples/llm",
endpoints=["v1/chat/completions"],
response_handlers=[chat_completions_response_handler],
marks=[pytest.mark.gpu_4, pytest.mark.vllm],
),
text_payload,
),
"agg-tp-2-dp-1": (
DeploymentGraph(
module="graphs.agg:Frontend",
config="/workspace/tests/fault_tolerance/configs/agg_tp_2_dp_1.yaml",
directory="/workspace/examples/llm",
endpoints=["v1/chat/completions"],
response_handlers=[chat_completions_response_handler],
marks=[pytest.mark.gpu_2, pytest.mark.vllm],
),
text_payload,
),
"agg-tp-2-dp-2": (
DeploymentGraph(
module="graphs.agg:Frontend",
config="/workspace/tests/fault_tolerance/configs/agg_tp_2_dp_2.yaml",
directory="/workspace/examples/llm",
endpoints=["v1/chat/completions"],
response_handlers=[chat_completions_response_handler],
marks=[pytest.mark.gpu_4, pytest.mark.vllm],
),
text_payload,
),
"agg-tp-2-dp-4": (
DeploymentGraph(
module="graphs.agg:Frontend",
config="/workspace/tests/fault_tolerance/configs/agg_tp_2_dp_4.yaml",
directory="/workspace/examples/llm",
endpoints=["v1/chat/completions"],
response_handlers=[chat_completions_response_handler],
marks=[pytest.mark.gpu_8, pytest.mark.vllm],
),
text_payload,
),
"disagg-p-tp-1-dp-1-d-tp-1-dp-1": (
DeploymentGraph(
module="graphs.disagg:Frontend",
config="/workspace/tests/fault_tolerance/configs/disagg_p_tp_1_dp_1_d_tp_1_dp_1.yaml",
directory="/workspace/examples/llm",
endpoints=["v1/chat/completions"],
response_handlers=[chat_completions_response_handler],
marks=[pytest.mark.gpu_2, pytest.mark.vllm],
),
text_payload,
),
"disagg-p-tp-1-dp-4-d-tp-4-dp-1": (
DeploymentGraph(
module="graphs.disagg:Frontend",
config="/workspace/tests/fault_tolerance/configs/disagg_p_tp_1_dp_4_d_tp_4_dp_1.yaml",
directory="/workspace/examples/llm",
endpoints=["v1/chat/completions"],
response_handlers=[chat_completions_response_handler],
marks=[pytest.mark.gpu_8, pytest.mark.vllm],
),
text_payload,
),
"disagg-p-tp-2-dp-2-d-tp-4-dp-1": (
DeploymentGraph(
module="graphs.disagg:Frontend",
config="/workspace/tests/fault_tolerance/configs/disagg_p_tp_2_dp_2_d_tp_4_dp_1.yaml",
directory="/workspace/examples/llm",
endpoints=["v1/chat/completions"],
response_handlers=[chat_completions_response_handler],
marks=[pytest.mark.gpu_8, pytest.mark.vllm],
),
text_payload,
),
"disagg-p-tp-2-dp-1-d-tp-4-dp-1": (
DeploymentGraph(
module="graphs.disagg:Frontend",
config="/workspace/tests/fault_tolerance/configs/disagg_p_tp_2_dp_1_d_tp_4_dp_1.yaml",
directory="/workspace/examples/llm",
endpoints=["v1/chat/completions"],
response_handlers=[chat_completions_response_handler],
marks=[pytest.mark.gpu_8, pytest.mark.vllm],
),
text_payload,
),
"disagg-p-tp-1-dp-2-d-tp-2-dp-1": (
DeploymentGraph(
module="graphs.disagg:Frontend",
config="/workspace/tests/fault_tolerance/configs/disagg_p_tp_1_dp_2_d_tp_2_dp_1.yaml",
directory="/workspace/examples/llm",
endpoints=["v1/chat/completions"],
response_handlers=[chat_completions_response_handler],
marks=[pytest.mark.gpu_4, pytest.mark.vllm],
),
text_payload,
),
"disagg-p-tp-1-dp-1-d-tp-2-dp-1": (
DeploymentGraph(
module="graphs.disagg:Frontend",
config="/workspace/tests/fault_tolerance/configs/disagg_p_tp_1_dp_1_d_tp_2_dp_1.yaml",
directory="/workspace/examples/llm",
endpoints=["v1/chat/completions"],
response_handlers=[chat_completions_response_handler],
marks=[pytest.mark.gpu_4, pytest.mark.vllm],
),
text_payload,
),
}
# Each failure scenaro contains a list of failure injections
# Each failure injection has a time in seconds after the pervious injection and
# a list of failures to inject including the number of failures for each type.
# Failures are currently process termination.
#
# Example:
#
# "prefill_worker": [[30, [("dynamo_prefillworker", 1)]]],
#
# terminates 1 prefill worker after 30 seconds
failure_scenarios = {
"decode_worker": [[30, [("dynamo_vllmworker", 1)]]],
"prefill_worker": [[30, [("dynamo_prefillworker", 1)]]],
"frontend": [[30, [("dynamo_frontend", 1)]]],
"processor": [[30, [("dynamo_processor", 1)]]],
"vllm_worker": [[30, [("vllm_worker", 1)]]],
"none": [],
}
@pytest.fixture(params=list(failure_scenarios.keys()))
def failures(request):
return failure_scenarios[request.param]
@pytest.fixture(params=list(deployment_graphs.keys()))
def deployment_graph_test(request):
"""
Fixture that provides different deployment graph test configurations.
"""
return deployment_graphs[request.param]
# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import logging
import os
import time
from contextlib import contextmanager
from multiprocessing import Process
import psutil
import pytest
from tests.fault_tolerance.client import client
from tests.fault_tolerance.parse_results import main as parse_results
from tests.fault_tolerance.scenarios import ( # noqa: F401
deployment_graph_test,
failures,
)
from tests.fault_tolerance.utils.circus_controller import CircusController
from tests.fault_tolerance.utils.metrics import nvidia_smi # noqa: F401
from tests.fault_tolerance.utils.metrics import worker_metrics # noqa: F401
from tests.serve.test_dynamo_serve import DynamoServeProcess
from tests.utils.managed_process import terminate_process_tree
def _set_deployment_args(request, max_num_seqs):
decode_worker_name = "VllmWorker"
args = {}
if max_num_seqs is not None:
args[f"--{decode_worker_name}.max_num_seqs"] = max_num_seqs
return args
def _list_vllm_worker_processes():
processes = []
for ps_process in psutil.process_iter(["name", "cmdline"]):
try:
if "from multiprocessing.spawn import spawn_main;" in " ".join(
ps_process.cmdline()
):
processes.append(ps_process.pid)
except Exception:
pass
return processes
@contextmanager
def _clients(
logger,
num_clients,
request,
deployment_graph,
server_process,
payload,
requests_per_client,
input_token_length,
output_token_length,
max_retries,
):
procs = []
for i in range(num_clients):
procs.append(
Process(
target=client,
args=(
deployment_graph,
server_process,
payload,
request.node.name,
i,
requests_per_client,
input_token_length,
output_token_length,
max_retries,
),
)
)
procs[-1].start()
yield procs
for proc in procs:
logger.debug(f"{proc} waiting for join")
proc.join()
logger.debug(f"{proc} joined")
def _inject_failures(failures, logger): # noqa: F811
circus_controller = CircusController.from_state_file("dynamo")
for failure_time, component in failures:
time.sleep(failure_time)
for component_name, number in component:
logger.info(f"Injecting failure for: {component_name}")
if "dynamo" in component_name:
result = circus_controller.client.call(
{"command": "list", "properties": {"name": f"{component_name}"}}
)
if result["status"] == "error":
logger.warning(f"component {component_name} not found {result}")
continue
num_processes = len(result["pids"])
if number is None:
number = num_processes
for x in range(number):
pid = result["pids"][x % num_processes]
logger.info(f"Terminating {component_name} Pid {pid}")
terminate_process_tree(pid, logger, immediate_kill=True)
elif "vllm" in component_name:
vllm_processes = _list_vllm_worker_processes()
num_processes = len(vllm_processes)
if number is None:
number = len(vllm_processes)
for x in range(number):
pid = vllm_processes[x % num_processes]
terminate_process_tree(pid, logger, immediate_kill=True)
circus_controller.close()
global_result_list = []
@pytest.fixture(autouse=True)
def results_table(request):
yield
parse_results(logs_dir=None, log_paths=[request.node.name], tablefmt="fancy")
global_result_list.append(request.node.name)
@pytest.fixture(autouse=True, scope="session")
def results_summary():
yield
parse_results(logs_dir=None, log_paths=global_result_list, tablefmt="fancy")
@pytest.mark.e2e
@pytest.mark.slow
def test_worker_failure(
deployment_graph_test, # noqa: F811
request,
runtime_services,
num_clients,
requests_per_client,
worker_metrics, # noqa: F811
respawn,
failures, # noqa: F811
input_token_length,
output_token_length,
max_num_seqs,
max_retries,
display_dynamo_output,
nvidia_smi, # noqa: F811
separate_process_logs,
hf_hub_offline,
):
"""
Test dynamo serve deployments with injected failures
"""
# runtime_services is used to start nats and etcd
logger = logging.getLogger(request.node.name)
logger.info("Starting test_deployment")
deployment_graph, payload = deployment_graph_test
if hf_hub_offline:
os.environ["HF_HUB_OFFLINE"] = "1"
else:
if "HF_HUB_OFFLINE" in os.environ:
del os.environ["HF_HUB_OFFLINE"]
if respawn:
os.environ["DYN_CIRCUS_RESPAWN"] = "1"
else:
if "DYN_CIRCUS_RESPAWN" in os.environ:
del os.environ["DYN_CIRCUS_RESPAWN"]
if separate_process_logs:
os.environ["DYN_CIRCUS_LOG_DIR"] = os.path.abspath(request.node.name)
deployment_args = _set_deployment_args(request, max_num_seqs)
with DynamoServeProcess(
deployment_graph,
request,
display_output=display_dynamo_output,
args=deployment_args,
) as server_process:
server_process.wait_for_ready(payload)
with _clients(
logger,
num_clients,
request,
deployment_graph,
server_process,
payload,
requests_per_client,
input_token_length,
output_token_length,
max_retries,
):
_inject_failures(failures, logger)
# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import json
import logging
import os
from pathlib import Path
from typing import List, Optional
from circus.client import CircusClient
from circus.exc import CallError
logger = logging.getLogger(__name__)
class CircusController:
"""A circus client implementation for Dynamo"""
def __init__(self, endpoint: str):
"""Initialize connection to arbiter.
Args:
endpoint: The circus endpoint (e.g., tcp://127.0.0.1:54927)
"""
self.endpoint = endpoint
self.client = CircusClient(endpoint=endpoint, timeout=15.0)
@classmethod
def from_state_file(cls, namespace: str) -> "CircusController":
"""
Create a CircusController from a Dynamo state file.
Args:
namespace: The Dynamo namespace
Returns:
CircusController instance
Raises:
FileNotFoundError: If state file doesn't exist
ValueError: If no endpoint found in state file
"""
state_file = (
Path(
os.environ.get("DYN_LOCAL_STATE_DIR", Path.home() / ".dynamo" / "state")
)
/ f"{namespace}.json"
)
if not state_file.exists():
raise FileNotFoundError(f"State file not found: {state_file}")
with open(state_file, "r") as f:
state = json.load(f)
endpoint = state.get("circus_endpoint")
if not endpoint:
raise ValueError(f"No endpoint found in state file: {state_file}")
return cls(endpoint)
async def _get_watcher_processes(self, name: str) -> Optional[int]:
"""
Get number of processes for a watcher.
Args:
name: The name of the watcher
Returns:
Number of processes for the watcher. Returns None operation fails.
"""
try:
response = self.client.send_message("numprocesses", name=name)
return int(response.get("numprocesses", 0))
except (CallError, Exception) as e:
logger.error(f"Failed to get process count for {name}: {e}")
return None
async def _list_watchers(self) -> List[str]:
"""
List all watchers managed by circus.
Returns:
List of watcher names. Returns None if the list operation fails.
"""
try:
response = self.client.send_message("list")
return response.get("watchers", [])
except (CallError, Exception) as e:
logger.error(f"Failed to list watchers: {e}")
return []
def close(self) -> None:
"""Close the connection to the arbiter."""
if hasattr(self, "client"):
self.client.stop()
# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import asyncio
import json
import os
from datetime import datetime
from multiprocessing import Process
import psutil
import pytest
from dynamo.runtime import dynamo_worker
from tests.fault_tolerance.utils.circus_controller import CircusController
from tests.utils.managed_process import ManagedProcess
def run_metrics_process(log_dir):
asyncio.run(get_metrics(log_dir))
@dynamo_worker()
async def get_metrics(runtime, log_dir):
# Log # processes
# Log # metrics per vllm worker
circus_controller = None
pipeline = None
log_path = os.path.join(log_dir, "watcher.log.txt")
with open(log_path, "w") as log:
while True:
try:
await asyncio.sleep(0.5)
if not circus_controller:
circus_controller = CircusController.from_state_file("dynamo")
if not pipeline:
pipeline = (
await runtime.namespace("dynamo")
.component("VllmWorker")
.endpoint("load_metrics")
.client()
)
watchers = []
for x in await circus_controller._list_watchers():
result = circus_controller.client.call(
{"command": "list", "properties": {"name": f"{x}"}}
)
watchers.append((x, result))
metrics = []
for x in pipeline.instance_ids():
async for worker_metric in await pipeline.direct(None, x):
metrics.append((x, worker_metric.data()))
vllm_processes = []
for ps_process in psutil.process_iter(["name", "cmdline"]):
try:
if "from multiprocessing.spawn import spawn_main;" in " ".join(
ps_process.cmdline()
):
vllm_processes.append(ps_process.pid)
except (psutil.NoSuchProcess, psutil.AccessDenied):
# Process may have terminated or become inaccessible during iteration
pass
record = {
"time": datetime.now().strftime("%Y-%m-%dT%H:%M:%S"),
"watchers": watchers,
"metrics": metrics,
"vllm_processes": vllm_processes,
}
log.write(json.dumps(record) + "\n")
log.flush()
except Exception as e:
record = {
"time": datetime.now().strftime("%Y-%m-%dT%H:%M:%S"),
"watchers": [],
"metrics": [],
"vllm_processes": [],
"error": str(e),
}
log.write(json.dumps(record) + "\n")
log.flush()
@pytest.fixture
def worker_metrics(request):
process = Process(target=run_metrics_process, args=(request.node.name,))
process.start()
yield
process.kill()
class NvidiaSMI(ManagedProcess):
def __init__(self, request):
super().__init__(
command=[
"nvidia-smi",
"dmon",
"--select=puc",
],
health_check_ports=[],
terminate_existing=True,
display_output=False,
data_dir=None,
log_dir=request.node.name,
)
@pytest.fixture
def nvidia_smi(request):
with NvidiaSMI(request) as nvidia_smi_process:
yield nvidia_smi_process
...@@ -251,20 +251,35 @@ deployment_graphs = { ...@@ -251,20 +251,35 @@ deployment_graphs = {
class DynamoServeProcess(ManagedProcess): class DynamoServeProcess(ManagedProcess):
def __init__(self, graph: DeploymentGraph, request, port=8000, timeout=900): def __init__(
self,
graph: DeploymentGraph,
request,
port=8000,
timeout=900,
display_output=True,
args=None,
):
command = ["dynamo", "serve", graph.module] command = ["dynamo", "serve", graph.module]
if graph.config: if graph.config:
command.extend(["-f", os.path.join(graph.directory, graph.config)]) command.extend(["-f", os.path.join(graph.directory, graph.config)])
command.extend(["--Frontend.port", str(port)])
if args:
for k, v in args.items():
command.extend([f"{k}", f"{v}"])
health_check_urls = []
health_check_ports = []
# Handle multimodal deployments differently # Handle multimodal deployments differently
if "multimodal" in graph.directory: if "multimodal" in graph.directory:
# Set DYNAMO_PORT environment variable for multimodal # Set DYNAMO_PORT environment variable for multimodal
env = os.environ.copy() env = os.environ.copy()
env["DYNAMO_PORT"] = str(port) env["DYNAMO_PORT"] = str(port)
health_check_urls = []
# Don't add health check on port since multimodal uses DYNAMO_PORT # Don't add health check on port since multimodal uses DYNAMO_PORT
health_check_ports = []
else: else:
# Regular LLM deployments # Regular LLM deployments
command.extend(["--Frontend.port", str(port)]) command.extend(["--Frontend.port", str(port)])
...@@ -275,16 +290,22 @@ class DynamoServeProcess(ManagedProcess): ...@@ -275,16 +290,22 @@ class DynamoServeProcess(ManagedProcess):
env = None env = None
self.port = port self.port = port
self.graph = graph
super().__init__( super().__init__(
command=command, command=command,
timeout=timeout, timeout=timeout,
display_output=True, display_output=display_output,
working_dir=graph.directory, working_dir=graph.directory,
health_check_ports=health_check_ports, health_check_ports=health_check_ports,
health_check_urls=health_check_urls, health_check_urls=health_check_urls,
delayed_start=graph.delayed_start, delayed_start=graph.delayed_start,
stragglers=["http"], stragglers=["http"],
straggler_commands=[
"dynamo.sdk.cli.serve_dynamo",
"from multiprocessing.resource_tracker",
"from multiprocessing.spawn",
],
log_dir=request.node.name, log_dir=request.node.name,
env=env, # Pass the environment variables env=env, # Pass the environment variables
) )
...@@ -298,50 +319,10 @@ class DynamoServeProcess(ManagedProcess): ...@@ -298,50 +319,10 @@ class DynamoServeProcess(ManagedProcess):
return True return True
return False return False
def check_response(
@pytest.fixture( self, payload, response, response_handler, logger=logging.getLogger()
params=[ ):
pytest.param("agg", marks=[pytest.mark.vllm, pytest.mark.gpu_1]), assert response.status_code == 200, "Response Error"
pytest.param("agg_router", marks=[pytest.mark.vllm, pytest.mark.gpu_1]),
pytest.param("disagg", marks=[pytest.mark.vllm, pytest.mark.gpu_2]),
pytest.param("disagg_router", marks=[pytest.mark.vllm, pytest.mark.gpu_2]),
pytest.param("multimodal_agg", marks=[pytest.mark.vllm, pytest.mark.gpu_2]),
pytest.param("trtllm_agg", marks=[pytest.mark.tensorrtllm, pytest.mark.gpu_1]),
pytest.param(
"trtllm_agg_router", marks=[pytest.mark.tensorrtllm, pytest.mark.gpu_1]
),
pytest.param(
"trtllm_disagg", marks=[pytest.mark.tensorrtllm, pytest.mark.gpu_2]
),
pytest.param(
"trtllm_disagg_router", marks=[pytest.mark.tensorrtllm, pytest.mark.gpu_2]
),
# pytest.param("sglang", marks=[pytest.mark.sglang, pytest.mark.gpu_2]),
]
)
def deployment_graph_test(request):
"""
Fixture that provides different deployment graph test configurations.
"""
return deployment_graphs[request.param]
@pytest.mark.e2e
@pytest.mark.slow
def test_serve_deployment(deployment_graph_test, request, runtime_services):
"""
Test dynamo serve deployments with different graph configurations.
"""
# runtime_services is used to start nats and etcd
logger = logging.getLogger(request.node.name)
logger.info("Starting test_deployment")
deployment_graph, payload = deployment_graph_test
def check_response(response, response_handler):
assert response.status_code == 200, "Server is not healthy"
content = response_handler(response) content = response_handler(response)
logger.info("Received Content: %s", content) logger.info("Received Content: %s", content)
# Check for expected responses # Check for expected responses
...@@ -349,32 +330,25 @@ def test_serve_deployment(deployment_graph_test, request, runtime_services): ...@@ -349,32 +330,25 @@ def test_serve_deployment(deployment_graph_test, request, runtime_services):
for expected in payload.expected_response: for expected in payload.expected_response:
assert expected in content, "Expected '%s' not found in response" % expected assert expected in content, "Expected '%s' not found in response" % expected
with DynamoServeProcess(deployment_graph, request) as server_process: def wait_for_ready(self, payload, logger=logging.getLogger()):
first_success_pending = True url = f"http://localhost:{self.port}/{self.graph.endpoints[0]}"
for endpoint, response_handler in zip(
deployment_graph.endpoints, deployment_graph.response_handlers
):
url = f"http://localhost:{server_process.port}/{endpoint}"
start_time = time.time() start_time = time.time()
retry_delay = 5 retry_delay = 5
elapsed = 0.0 elapsed = 0.0
request_body = ( logger.info("Waiting for Deployment Ready")
json_payload = (
payload.payload_chat payload.payload_chat
if endpoint == "v1/chat/completions" if self.graph.endpoints[0] == "v1/chat/completions"
else payload.payload_completions else payload.payload_completions
) )
# We can skip this while time.time() - start_time < self.graph.timeout:
while (
time.time() - start_time < deployment_graph.timeout
and first_success_pending
):
elapsed = time.time() - start_time elapsed = time.time() - start_time
try: try:
response = requests.post( response = requests.post(
url, url,
json=request_body, json=json_payload,
timeout=deployment_graph.timeout - elapsed, timeout=self.graph.timeout - elapsed,
) )
except (requests.RequestException, requests.Timeout) as e: except (requests.RequestException, requests.Timeout) as e:
logger.warning("Retrying due to Request failed: %s", e) logger.warning("Retrying due to Request failed: %s", e)
...@@ -405,24 +379,87 @@ def test_serve_deployment(deployment_graph_test, request, runtime_services): ...@@ -405,24 +379,87 @@ def test_serve_deployment(deployment_graph_test, request, runtime_services):
% (response.status_code, response.text) % (response.status_code, response.text)
) )
else: else:
check_response(response, response_handler)
first_success_pending = False
break break
else: else:
if first_success_pending:
logger.error( logger.error(
"Service did not return a successful response within %s s", "Service did not return a successful response within %s s",
deployment_graph.timeout, self.graph.timeout,
) )
pytest.fail( pytest.fail(
"Service did not return a successful response within %s s" "Service did not return a successful response within %s s"
% deployment_graph.timeout % self.graph.timeout
)
self.check_response(payload, response, self.graph.response_handlers[0], logger)
logger.info("Deployment Ready")
@pytest.fixture(
params=[
pytest.param("agg", marks=[pytest.mark.vllm, pytest.mark.gpu_1]),
pytest.param("agg_router", marks=[pytest.mark.vllm, pytest.mark.gpu_1]),
pytest.param("disagg", marks=[pytest.mark.vllm, pytest.mark.gpu_2]),
pytest.param("disagg_router", marks=[pytest.mark.vllm, pytest.mark.gpu_2]),
pytest.param("multimodal_agg", marks=[pytest.mark.vllm, pytest.mark.gpu_2]),
pytest.param("trtllm_agg", marks=[pytest.mark.tensorrtllm, pytest.mark.gpu_1]),
pytest.param(
"trtllm_agg_router", marks=[pytest.mark.tensorrtllm, pytest.mark.gpu_1]
),
pytest.param(
"trtllm_disagg", marks=[pytest.mark.tensorrtllm, pytest.mark.gpu_2]
),
pytest.param(
"trtllm_disagg_router", marks=[pytest.mark.tensorrtllm, pytest.mark.gpu_2]
),
# pytest.param("sglang", marks=[pytest.mark.sglang, pytest.mark.gpu_2]),
]
)
def deployment_graph_test(request):
"""
Fixture that provides different deployment graph test configurations.
"""
return deployment_graphs[request.param]
@pytest.mark.e2e
@pytest.mark.slow
def test_serve_deployment(deployment_graph_test, request, runtime_services):
"""
Test dynamo serve deployments with different graph configurations.
"""
# runtime_services is used to start nats and etcd
logger = logging.getLogger(request.node.name)
logger.info("Starting test_deployment")
deployment_graph, payload = deployment_graph_test
with DynamoServeProcess(deployment_graph, request) as server_process:
server_process.wait_for_ready(payload, logger)
for endpoint, response_handler in zip(
deployment_graph.endpoints, deployment_graph.response_handlers
):
url = f"http://localhost:{server_process.port}/{endpoint}"
start_time = time.time()
elapsed = 0.0
request_body = (
payload.payload_chat
if endpoint == "v1/chat/completions"
else payload.payload_completions
) )
for _ in range(payload.repeat_count): for _ in range(payload.repeat_count):
elapsed = time.time() - start_time
response = requests.post( response = requests.post(
url, url,
json=request_body, json=request_body,
timeout=deployment_graph.timeout - elapsed, timeout=deployment_graph.timeout - elapsed,
) )
check_response(response, response_handler) server_process.check_response(
payload, response, response_handler, logger
)
...@@ -26,6 +26,45 @@ import psutil ...@@ -26,6 +26,45 @@ import psutil
import requests import requests
def terminate_process(process, logger=logging.getLogger(), immediate_kill=False):
try:
logger.info("Terminating PID: %s name: %s", process.pid, process.name())
if immediate_kill:
logger.info("Sending Kill: %s %s", process.pid, process.name())
process.kill()
else:
process.terminate()
except psutil.AccessDenied:
logger.warning("Access denied for PID %s", process.pid)
except psutil.NoSuchProcess:
logger.warning("PID %s no longer exists", process.pid)
def terminate_process_tree(
pid, logger=logging.getLogger(), immediate_kill=False, timeout=10
):
try:
parent = psutil.Process(pid)
for child in parent.children(recursive=True):
terminate_process(child, logger, immediate_kill)
terminate_process(parent, logger, immediate_kill)
for child in parent.children(recursive=True):
try:
child.wait(timeout)
except psutil.TimeoutExpired:
terminate_process(child, logger, immediate_kill=True)
try:
parent.wait(timeout)
except psutil.TimeoutExpired:
terminate_process(parent, logger, immediate_kill=True)
except psutil.NoSuchProcess:
# Process already terminated
pass
@dataclass @dataclass
class ManagedProcess: class ManagedProcess:
command: List[str] command: List[str]
...@@ -39,6 +78,7 @@ class ManagedProcess: ...@@ -39,6 +78,7 @@ class ManagedProcess:
data_dir: Optional[str] = None data_dir: Optional[str] = None
terminate_existing: bool = True terminate_existing: bool = True
stragglers: List[str] = field(default_factory=list) stragglers: List[str] = field(default_factory=list)
straggler_commands: List[str] = field(default_factory=list)
log_dir: str = os.getcwd() log_dir: str = os.getcwd()
_logger = logging.getLogger() _logger = logging.getLogger()
...@@ -78,7 +118,7 @@ class ManagedProcess: ...@@ -78,7 +118,7 @@ class ManagedProcess:
process.stdout.close() process.stdout.close()
if process.stdin: if process.stdin:
process.stdin.close() process.stdin.close()
self._terminate_process_tree(process.pid) terminate_process_tree(process.pid, self._logger)
process.wait() process.wait()
if self.data_dir: if self.data_dir:
self._remove_directory(self.data_dir) self._remove_directory(self.data_dir)
...@@ -86,7 +126,20 @@ class ManagedProcess: ...@@ -86,7 +126,20 @@ class ManagedProcess:
for ps_process in psutil.process_iter(["name", "cmdline"]): for ps_process in psutil.process_iter(["name", "cmdline"]):
try: try:
if ps_process.name() in self.stragglers: if ps_process.name() in self.stragglers:
self._terminate_process_tree(ps_process.pid) self._logger.info(
"Terminating Straggler %s %s", ps_process.name(), ps_process.pid
)
terminate_process_tree(ps_process.pid, self._logger)
for cmdline in self.straggler_commands:
if cmdline in " ".join(ps_process.cmdline()):
self._logger.info(
"Terminating Straggler Cmdline %s %s %s",
ps_process.name(),
ps_process.pid,
cmdline,
)
terminate_process_tree(ps_process.pid, self._logger)
except (psutil.NoSuchProcess, psutil.AccessDenied): except (psutil.NoSuchProcess, psutil.AccessDenied):
# Process may have terminated or become inaccessible during iteration # Process may have terminated or become inaccessible during iteration
pass pass
...@@ -202,34 +255,22 @@ class ManagedProcess: ...@@ -202,34 +255,22 @@ class ManagedProcess:
def _terminate_existing(self): def _terminate_existing(self):
if self.terminate_existing: if self.terminate_existing:
self._logger.info("Terminating Existing %s", self._command_name)
for proc in psutil.process_iter(["name", "cmdline"]): for proc in psutil.process_iter(["name", "cmdline"]):
if proc.name() == self._command_name or proc.name() in self.stragglers: if proc.name() == self._command_name or proc.name() in self.stragglers:
self._terminate_process_tree(proc.pid) self._logger.info(
"Terminating Existing %s %s", proc.name(), proc.pid
def _terminate_process(self, process):
try:
self._logger.info("Terminating %s", process)
process.terminate()
except psutil.AccessDenied:
self._logger.warning("Access denied for PID %s", process.pid)
except psutil.NoSuchProcess:
self._logger.warning("PID %s no longer exists", process.pid)
except psutil.TimeoutExpired:
self._logger.warning(
"PID %s did not terminate before timeout, killing", process.pid
) )
process.kill()
def _terminate_process_tree(self, pid): terminate_process_tree(proc.pid, self._logger)
try: for cmdline in self.straggler_commands:
parent = psutil.Process(pid) if cmdline in " ".join(proc.cmdline()):
for child in parent.children(recursive=True): self._logger.info(
self._terminate_process(child) "Terminating Existing CmdLine %s %s %s",
self._terminate_process(parent) proc.name(),
except psutil.NoSuchProcess: proc.pid,
# Process already terminated proc.cmdline(),
pass )
terminate_process_tree(proc.pid, self._logger)
def main(): def main():
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment