Unverified Commit 52c75363 authored by Tzu-Ling Kan's avatar Tzu-Ling Kan Committed by GitHub
Browse files

test: fault injection tests for k8s (#3194)


Signed-off-by: nnshah1 <neelays@nvidia.com>
Signed-off-by: tzulingk@nvidia.com <tzulingk@nvidia.com>
Co-authored-by: nnshah1 <neelays@nvidia.com>
parent 116b9b43
This diff is collapsed.
# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import json
import logging
import os
import random
import time
from copy import deepcopy
from datetime import datetime
from typing import Any, Dict
import requests
from tests.utils.managed_deployment import ManagedDeployment
# Layout of this module's log records: a literal "[TEST]" tag followed by an
# ISO-8601-style timestamp (second resolution, no timezone suffix), the level,
# the logger name and the message.
DATE_FORMAT = "%Y-%m-%dT%H:%M:%S"
LOG_FORMAT = "[TEST] %(asctime)s %(levelname)s %(name)s: %(message)s"

# Template request body.  The blank / zero fields ("model", the message
# "content", "max_tokens", "min_tokens") are placeholders that
# _single_request fills in on a deep copy for every request it sends.
payload = dict(
    model="",
    messages=[dict(role="user", content="")],
    max_tokens=0,
    temperature=0.1,
    # "seed": 10,
    ignore_eos=True,
    min_tokens=0,
    stream=False,
)

# Configure the root logger once at import time so every client process
# emits records in the same layout.
logging.basicConfig(
    format=LOG_FORMAT,
    datefmt=DATE_FORMAT,
    level=logging.INFO,
)
def _get_random_prompt(length):
    """Return ``length`` random single-digit tokens joined by single spaces.

    Every token is drawn independently, with replacement, from the strings
    "0" through "9" using the module-level ``random`` generator.
    """
    digits = [str(digit) for digit in range(10)]
    tokens = random.choices(digits, k=length)
    return " ".join(tokens)
def _single_request(
    url,
    pod,
    payload,
    model,
    logger,
    retry_attempts=1,
    input_token_length=100,
    output_token_length=100,
    timeout=30,
    retry_delay=1,
):
    """POST one chat request to ``url``, retrying until it returns HTTP 200.

    Args:
        url: Endpoint the request is posted to.
        pod: Name of the pod behind ``url``; only recorded in the result.
        payload: Template request body. It is deep-copied, never mutated.
        model: Value written to the copy's ``"model"`` field.
        logger: Accepted for interface compatibility; not used here.
        retry_attempts: Maximum number of attempts. Zero or a negative value
            sends nothing and yields an empty ``"results"`` list.
        input_token_length: Number of random tokens in the generated prompt.
        output_token_length: Written to both ``"max_tokens"`` and
            ``"min_tokens"`` of the request.
        timeout: Per-attempt timeout in seconds handed to ``requests.post``.
        retry_delay: Seconds slept after every unsuccessful attempt.

    Returns:
        A dict with ``"time"`` (local completion timestamp), ``"results"``
        (one entry per attempt with ``status``, ``result``,
        ``request_elapsed_time``, ``url`` and ``pod``), ``"total_time"``
        (seconds spent including retry sleeps), ``"url"`` and ``"pod"``.
        For an attempt that raised, ``status`` is the exception text and
        ``result`` is ``None``.
    """
    request_body = deepcopy(payload)
    request_body["messages"][0]["content"] = _get_random_prompt(input_token_length)
    request_body["max_tokens"] = output_token_length
    request_body["min_tokens"] = output_token_length
    request_body["model"] = model

    start_time = time.time()
    results = []
    attempts_left = retry_attempts
    # Compare against zero instead of relying on truthiness: a negative budget
    # would otherwise count away from zero and loop forever on a failing
    # endpoint.
    while attempts_left > 0:
        attempt_start = time.time()
        try:
            response = requests.post(
                url,
                json=request_body,
                timeout=timeout,
            )
            elapsed = time.time() - attempt_start
            try:
                content = response.json()
            except ValueError:
                # Non-JSON body (e.g. a plain-text error page): keep the
                # status code and record no content.
                content = None
            results.append(
                {
                    "status": response.status_code,
                    "result": content,
                    "request_elapsed_time": elapsed,
                    "url": url,
                    "pod": pod,
                }
            )
            if response.status_code == 200:
                break
        except requests.RequestException as e:
            # requests.Timeout is a RequestException subclass, so this single
            # clause also covers timeouts.
            results.append(
                {
                    "status": str(e),
                    "result": None,
                    "request_elapsed_time": time.time() - attempt_start,
                    "url": url,
                    "pod": pod,
                }
            )
        # Reached only after an unsuccessful attempt (success breaks above).
        time.sleep(retry_delay)
        attempts_left -= 1

    return {
        "time": datetime.now().strftime("%Y-%m-%dT%H:%M:%S"),
        "results": results,
        "total_time": time.time() - start_time,
        "url": url,
        "pod": pod,
    }
def client(
    deployment_spec,
    namespace,
    model,
    log_dir,
    index,
    requests_per_client,
    input_token_length,
    output_token_length,
    max_retries,
    max_request_rate,
    retry_delay=1,
):
    """Send ``requests_per_client`` requests to the frontend pods of a deployment.

    For every request the ready frontend pods are listed, one is chosen
    round-robin by request index, a port-forward to it is opened (and cached
    per pod name) and the request is sent through the forwarded local port.
    Each result is appended as one JSON line to
    ``<log_dir>/client_<index>.log.txt``.

    Args:
        deployment_spec: Deployment description; ``port`` and ``endpoint``
            are read from it and it is handed to ``ManagedDeployment``.
        namespace: Namespace handed to ``ManagedDeployment``.
        model: Model name placed in every request.
        log_dir: Directory for the client log; created if missing.
        index: Client number, used in the logger name and log file name.
        requests_per_client: Number of requests to send.
        input_token_length: Prompt length passed to ``_single_request``.
        output_token_length: Output length passed to ``_single_request``.
        max_retries: Attempt budget passed to ``_single_request``.
        max_request_rate: Upper bound on requests per second for this client;
            zero or less disables pacing.
        retry_delay: Seconds ``_single_request`` sleeps between attempts.

    Any exception is logged and ends the client early; it is not re-raised.
    """
    logger = logging.getLogger(f"CLIENT: {index}")
    logging.getLogger("httpx").setLevel(logging.WARNING)
    managed_deployment = ManagedDeployment(log_dir, deployment_spec, namespace)
    # Open port-forwards keyed by pod name.
    pod_ports: Dict[str, Any] = {}
    min_elapsed_time = (1 / max_request_rate) if max_request_rate > 0 else 0.0
    try:
        os.makedirs(log_dir, exist_ok=True)
        log_path = os.path.join(log_dir, f"client_{index}.log.txt")
        with open(log_path, "w") as log:
            for i in range(requests_per_client):
                pods = managed_deployment.get_pods(
                    managed_deployment.frontend_service_name
                )
                # With no ready pod these stay at 0 / None, so the request is
                # still sent (and recorded as a failure) rather than skipped.
                port = 0
                pod_name = None
                pods_ready = []
                for pod in pods[managed_deployment.frontend_service_name]:
                    if pod.ready():
                        pods_ready.append(pod)
                    elif pod.name in pod_ports:
                        # Pod went away or is restarting: drop its forward so
                        # a fresh one is opened once it is ready again.
                        pod_ports[pod.name].stop()
                        del pod_ports[pod.name]
                if pods_ready:
                    pod = pods_ready[i % len(pods_ready)]
                    if pod.name not in pod_ports:
                        port_forward = managed_deployment.port_forward(
                            pod, deployment_spec.port
                        )
                        if port_forward:
                            pod_ports[pod.name] = port_forward
                    if pod.name in pod_ports:
                        port = pod_ports[pod.name].local_port
                        pod_name = pod.name
                url = f"http://localhost:{port}/{deployment_spec.endpoint}"
                result = _single_request(
                    url,
                    pod_name,
                    payload,
                    model,
                    logger,
                    max_retries,
                    input_token_length=input_token_length,
                    output_token_length=output_token_length,
                    retry_delay=retry_delay,
                )
                attempts = result["results"]
                if attempts:
                    last = attempts[-1]
                    logger.info(
                        f"Request: {i} Pod {pod_name} Local Port {port} Status: {last['status']} Latency: {last['request_elapsed_time']}"
                    )
                else:
                    # An attempt budget of zero produces no attempts; indexing
                    # [-1] here used to abort the whole client.
                    logger.warning(
                        f"Request: {i} Pod {pod_name} Local Port {port} made no attempts"
                    )
                log.write(json.dumps(result) + "\n")
                log.flush()
                # Pace the client so it never exceeds max_request_rate.
                if result["total_time"] < min_elapsed_time:
                    time.sleep(min_elapsed_time - result["total_time"])
    except Exception as e:
        logger.error(str(e))
    finally:
        # Release every port-forward that is still open, whether the loop
        # finished or failed.
        for name, port_forward in list(pod_ports.items()):
            try:
                port_forward.stop()
            except Exception as e:
                logger.warning(f"Failed to stop port forward for {name}: {e}")
        pod_ports.clear()
    logger.info("Exiting")
# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import pytest
def pytest_addoption(parser):
    """Register the ``--image`` and ``--namespace`` command-line options.

    Both are string options; ``--image`` defaults to ``None`` and
    ``--namespace`` to ``"fault-tolerance-test"``.
    """
    option_defaults = (
        ("--image", None),
        ("--namespace", "fault-tolerance-test"),
    )
    for flag, default in option_defaults:
        parser.addoption(flag, type=str, default=default)
@pytest.fixture
def image(request):
    """Return the value given for ``--image`` on the pytest command line."""
    pytest_config = request.config
    return pytest_config.getoption("--image")
@pytest.fixture
def namespace(request):
    """Return the value given for ``--namespace`` on the pytest command line."""
    pytest_config = request.config
    return pytest_config.getoption("--namespace")
# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import argparse
import json
import os
import re
from datetime import datetime
from typing import Any, Dict, List, Optional, Tuple
import pandas as pd
from tabulate import tabulate
def parse_test_log(file_path):
    """Extract deployment timing milestones from a test log file.

    The log is scanned for three marker messages; the timestamp is taken
    from the second space-separated field of the matching line.  When a
    marker appears several times, the last occurrence wins.

    Args:
        file_path: Path of the test log.

    Returns:
        ``(startup_time, fault_time, start_cmd)`` where ``startup_time`` is
        the seconds between the "Starting Deployment" and "is ready" lines
        (``None`` unless both were seen), ``fault_time`` is the datetime of
        the "Injecting failure for:" line (or ``None``) and ``start_cmd`` is
        an empty list once the start line was seen, else ``None``.  A missing
        file yields ``(None, None, None)``.
    """
    if not os.path.isfile(file_path):
        return None, None, None

    def _timestamp_of(log_line):
        # Field 0 is the "[TEST]" tag, field 1 the ISO timestamp.
        return datetime.fromisoformat(log_line.split(" ")[1].replace("T", " "))

    started_at = None
    became_ready_at = None
    fault_injected_at = None
    start_cmd: Optional[List[str]] = None

    with open(file_path, "r") as log_file:
        for raw_line in log_file:
            entry = raw_line.strip()
            if "Starting Deployment fault-tolerance-test with spec" in entry:
                started_at = _timestamp_of(entry)
                start_cmd = []
            elif "Deployment fault-tolerance-test is ready" in entry:
                became_ready_at = _timestamp_of(entry)
            elif "Injecting failure for:" in entry:
                fault_injected_at = _timestamp_of(entry)

    if started_at and became_ready_at:
        startup_time = (became_ready_at - started_at).total_seconds()
    else:
        startup_time = None
    return startup_time, fault_injected_at, start_cmd
def parse_client_logs(test_dir, expected_length=100):
    """Load every ``client_*.log.txt`` file in ``test_dir`` into one DataFrame.

    Each non-blank line of a client log is a JSON object with a ``"time"``
    stamp and a ``"results"`` list holding one entry per attempt.  Every
    attempt becomes one row.

    An attempt counts as successful only when its response has a non-empty
    ``choices`` list whose first message carries text (``content``, or
    ``reasoning_content`` when there is no ``content`` key) of at least
    ``expected_length`` characters.

    Args:
        test_dir: Directory holding the client logs.
        expected_length: Minimum number of characters of generated text for
            an attempt to count as successful.

    Returns:
        A DataFrame sorted by ``time`` with columns ``time``, ``status``,
        ``request_elapsed_time``, ``request_number`` (zero-based position of
        the request within its file), ``client`` and ``success``; or ``None``
        when the directory is missing or holds no attempts.
    """
    # Same guard as parse_process_log: a test that never created its log
    # directory yields "no data" rather than an exception.
    if not os.path.isdir(test_dir):
        return None

    all_logs = []
    for file in os.listdir(test_dir):
        if not (file.startswith("client_") and file.endswith(".log.txt")):
            continue
        client_id = file.split("_")[1].split(".")[0]
        with open(os.path.join(test_dir, file), "r") as f:
            request_number = 0
            for line in f:
                line = line.strip()
                if not line:
                    # Blank lines are not requests and are not valid JSON.
                    continue
                data = json.loads(line)
                request_time = datetime.fromisoformat(data["time"].replace("T", " "))
                for result in data["results"]:
                    body = result.get("result")
                    choices = body.get("choices") if isinstance(body, dict) else None

                    # Reset per attempt: without this, a message lacking both
                    # keys reused the previous attempt's text (or raised
                    # UnboundLocalError on the very first attempt).
                    content = None
                    if choices:
                        message = choices[0]["message"]
                        if "content" in message:
                            content = message["content"]
                        elif "reasoning_content" in message:
                            content = message["reasoning_content"]

                    all_logs.append(
                        {
                            "time": request_time,
                            "status": result["status"],
                            "request_elapsed_time": result["request_elapsed_time"],
                            "request_number": request_number,
                            "client": client_id,
                            "success": bool(content)
                            and len(content) >= expected_length,
                        }
                    )
                request_number += 1

    if all_logs:
        df = pd.DataFrame(all_logs)
        df.sort_values("time", inplace=True)
        return df
    return None
def calculate_metrics(df, fault_time, sla=None):
    """Summarise request outcomes before and after a fault injection.

    Rows with ``time <= fault_time`` are "before", later rows are "after".
    Without a ``fault_time`` every row counts as "before" and all "after"
    values are ``None``; they are also ``None`` when no row falls after the
    fault.  Latency statistics and SLA violations only consider successful
    rows.

    Args:
        df: DataFrame with ``time``, ``success`` and ``request_elapsed_time``
            columns.
        fault_time: Moment the fault was injected, or a falsy value.
        sla: Latency threshold; when falsy, both violation counts are
            ``None``.

    Returns:
        A 10-tuple: success count before, failure count before, success count
        after, failure count after, mean latency before, latency std before,
        mean latency after, latency std after, SLA violations before, SLA
        violations after.
    """

    def _summarise(frame):
        # -> (successful rows, success count, failure count, mean, std)
        succeeded = frame[frame["success"]]
        latencies = succeeded["request_elapsed_time"]
        success_count = frame["success"].sum()
        return (
            succeeded,
            success_count,
            len(frame) - success_count,
            latencies.mean(),
            latencies.std(),
        )

    if fault_time:
        before_fault = df[df["time"] <= fault_time]
        after_fault = df[df["time"] > fault_time]
    else:
        before_fault, after_fault = df, None
    has_after = after_fault is not None and not after_fault.empty

    (
        successful_before,
        success_before_count,
        failure_before_count,
        avg_before,
        std_before,
    ) = _summarise(before_fault)

    if has_after:
        (
            successful_after,
            success_after_count,
            failure_after_count,
            avg_after,
            std_after,
        ) = _summarise(after_fault)
    else:
        successful_after = None
        success_after_count = failure_after_count = None
        avg_after = std_after = None

    violations_before = None
    violations_after = None
    if sla:
        # Only successful requests that took longer than the SLA count.
        violations_before = (successful_before["request_elapsed_time"] > sla).sum()
        if has_after:
            violations_after = (successful_after["request_elapsed_time"] > sla).sum()

    return (
        success_before_count,
        failure_before_count,
        success_after_count,
        failure_after_count,
        avg_before,
        std_before,
        avg_after,
        std_after,
        violations_before,
        violations_after,
    )
def _parse_log_line(line):
    """Split one non-empty log line into ``(timestamp, message)``.

    JSONL records (an object with a ``"time"`` field) are tried first, then
    the readable ``<timestamp> <message...>`` layout with ANSI colour codes
    removed.  Returns ``None`` for lines that carry no parseable timestamp.
    """
    try:
        json_data = json.loads(line)
        # A line holding a bare JSON scalar (e.g. "7") is not a log record;
        # let it fall through to the readable-format parser.
        if isinstance(json_data, dict):
            if "time" not in json_data:
                return None
            timestamp = datetime.fromisoformat(json_data["time"].replace("Z", ""))
            return timestamp, json_data.get("message", "")
    except (json.JSONDecodeError, ValueError, KeyError, AttributeError):
        pass

    clean_line = re.sub(r"\x1b\[.*?m", "", line)  # Remove ANSI codes
    parts = clean_line.split()
    if len(parts) < 2:
        return None
    try:
        # Drop the trailing 'Z' so the result is a naive datetime.
        timestamp = datetime.fromisoformat(parts[0].replace("Z", ""))
    except ValueError:
        return None
    return timestamp, " ".join(parts[1:])


def parse_process_log(log_dir, process_name):
    """Collect the "ready" events found in the logs of one process type.

    Every ``*.log`` file in ``log_dir`` whose name does not contain
    ``"metrics"`` is scanned.  The part of the file name before the first
    ``"."`` identifies the replica.

    Args:
        log_dir: Directory holding the log files of one process type.
        process_name: ``"Frontend"``, ``"VllmDecodeWorker"`` or
            ``"VllmPrefillWorker"``.  For any other name no events are
            matched, but replica keys are still created.

    Returns:
        ``{replica: [(timestamp, message, seconds_since_first_line), ...]}``.
        Events from a file are kept in file order; events from files with
        ``"previous"`` in the name are placed ahead of the others.  An empty
        dict is returned when ``log_dir`` is not a directory.
    """
    process_ready_pattern = {
        "Frontend": re.compile(r"added model"),
        "VllmDecodeWorker": re.compile(
            r"VllmWorker for (?P<model_name>.*?) has been initialized"
        ),
        "VllmPrefillWorker": re.compile(
            r"VllmWorker for (?P<model_name>.*?) has been initialized"
        ),
    }
    if not os.path.isdir(log_dir):
        return {}

    ready_pattern = process_ready_pattern.get(process_name)
    ready_times: Dict[str, List[Tuple[datetime, str, float]]] = {}
    for entry in os.listdir(log_dir):
        if not entry.endswith(".log") or "metrics" in entry:
            continue
        replica_number = entry.split(".")[0]
        replica_events = ready_times.setdefault(replica_number, [])

        file_events = []
        process_start_time = None
        with open(os.path.join(log_dir, entry), "r") as f:
            for line in f:
                line = line.strip()
                if not line:
                    continue
                parsed = _parse_log_line(line)
                if parsed is None:
                    continue
                timestamp, log_message = parsed
                # The first timestamped line marks the process start.
                if not process_start_time:
                    process_start_time = timestamp
                if ready_pattern is not None and ready_pattern.search(log_message):
                    relative_time = (timestamp - process_start_time).total_seconds()
                    file_events.append((timestamp, log_message, relative_time))

        # list.insert(-1, x) places x *before* the last element, which
        # scrambled the order of events; add each file's events as one
        # ordered group instead.
        if "previous" in entry:
            replica_events[0:0] = file_events
        else:
            replica_events.extend(file_events)
    return ready_times
def calculate_recovery_time(test_dir, failure_type, fault_time):
    """Return the seconds from fault injection to the last "ready" event.

    The ready events of the ``Frontend``, ``VllmDecodeWorker`` and
    ``VllmPrefillWorker`` logs under ``test_dir`` are inspected and the
    largest positive delay relative to ``fault_time`` is reported.

    Args:
        test_dir: Directory with one sub-directory of logs per process type.
        failure_type: Not used by this function.
        fault_time: Moment the fault was injected.

    Returns:
        The largest positive delay in seconds, or ``None`` when
        ``fault_time`` is falsy or no ready event occurred after it.
    """
    if not fault_time:
        return None

    ready_by_process = {}
    for process_name in ("Frontend", "VllmDecodeWorker", "VllmPrefillWorker"):
        ready_events = parse_process_log(
            os.path.join(test_dir, process_name), process_name
        )
        if ready_events:
            ready_by_process[process_name] = ready_events

    delays = [
        (event[0] - fault_time).total_seconds()
        for replicas in ready_by_process.values()
        for events in replicas.values()
        for event in events
    ]
    # Events at or before the fault never count as a recovery.
    slowest = max(delays, default=0)
    return slowest if slowest > 0 else None
def process_test_directory(test_dir, sla):
    """Build the result summary for one ``test_fault_scenario[...]`` directory.

    The text between ``test_fault_scenario[`` and the trailing ``]`` is the
    test name; its last ``-``-separated part is the failure type and the rest
    is the test prefix.

    Args:
        test_dir: Path of the per-test log directory.
        sla: Latency threshold forwarded to ``calculate_metrics``.

    Returns:
        A dict with the test prefix, start command, failure type, startup
        time, the ten metrics from ``calculate_metrics`` and the recovery
        time.  ``{}`` when the path does not contain ``test_fault_scenario``;
        ``None`` when no client log data was found.
    """
    if "test_fault_scenario" not in test_dir:
        return {}

    test_name = test_dir.split("test_fault_scenario[", 1)[1].rstrip("]")
    *prefix_parts, failure_type = test_name.split("-")
    test_prefix = "-".join(prefix_parts)

    startup_time, fault_time, start_cmd = parse_test_log(
        os.path.join(test_dir, "test.log.txt")
    )
    df = parse_client_logs(test_dir)
    if df is None or df.empty:
        return None

    metrics = calculate_metrics(df, fault_time, sla)
    recovery_time = calculate_recovery_time(test_dir, failure_type, fault_time)

    # Keys in the same order as the tuple returned by calculate_metrics.
    metric_keys = (
        "success_before_requests",
        "failed_before_requests",
        "success_after_requests",
        "failed_after_requests",
        "avg_latency_before",
        "std_latency_before",
        "avg_latency_after",
        "std_latency_after",
        "violations_before",
        "violations_after",
    )
    summary = {
        "test": test_prefix,
        "cmd": start_cmd,
        "failure": failure_type,
        "start_time": startup_time,
    }
    summary.update(zip(metric_keys, metrics))
    summary["recovery_time"] = recovery_time
    return summary
def main(logs_dir, tablefmt, log_paths=None, sla=None):
    """Print one result table per test prefix.

    Test directories come from ``log_paths`` when it is given and non-empty;
    otherwise every ``test_fault_scenario[...]`` sub-directory of ``logs_dir``
    is used.  Directories that yield no result are skipped.

    Args:
        logs_dir: Directory to scan for test directories (may be ``None``).
        tablefmt: ``tabulate`` table format name.
        log_paths: Explicit list of test directories; takes precedence over
            ``logs_dir``.
        sla: Latency threshold.  When set, the tables gain two SLA-violation
            columns.
    """
    if log_paths:
        candidate_dirs = list(log_paths)
    elif logs_dir:
        candidate_dirs = [
            os.path.join(logs_dir, entry)
            for entry in os.listdir(logs_dir)
            if entry.startswith("test_fault_scenario[")
            and os.path.isdir(os.path.join(logs_dir, entry))
        ]
    else:
        candidate_dirs = []

    results = []
    for candidate in candidate_dirs:
        result = process_test_directory(candidate, sla)
        if result:
            results.append(result)

    # Group results by test prefix
    grouped: dict[str, list[dict[str, Any]]] = {}
    for res in results:
        grouped.setdefault(res["test"], []).append(res)

    order = [
        "none",
        "frontend",
        "frontend_pod",
        "decode_worker",
        "decode_worker_pod",
        "prefill_worker",
        "prefill_worker_pod",
        "vllm_decode_engine_core",
        "vllm_prefill_engine_core",
    ]
    rank = {failure: position for position, failure in enumerate(order)}

    # (header, result key) pairs; the violation columns only exist with an SLA.
    columns = [
        ("Failure", "failure"),
        ("Startup", "start_time"),
        ("Success\nBefore", "success_before_requests"),
        ("Failed\nBefore", "failed_before_requests"),
        ("Success\nAfter", "success_after_requests"),
        ("Failed\nAfter", "failed_after_requests"),
        ("Latency\nBefore", "avg_latency_before"),
        ("Latency\nAfter", "avg_latency_after"),
    ]
    if sla:
        columns += [
            ("Violations\nBefore", "violations_before"),
            ("Violations\nAfter", "violations_after"),
        ]
    columns.append(("Recovery", "recovery_time"))
    headers = [header for header, _ in columns]

    # Print grouped tables
    for test_prefix, group in grouped.items():
        # Stable sort: known failure types follow `order`; types missing from
        # it are listed last instead of being silently left out of the table.
        ordered_group = sorted(
            group, key=lambda res: rank.get(res["failure"], len(order))
        )
        rows = [[res[key] for _, key in columns] for res in ordered_group]
        print(f"\nTest Group: {test_prefix}")
        print(
            tabulate(
                rows,
                headers,
                tablefmt=tablefmt,
                floatfmt=".2f",
                missingval="N/A",
                numalign="right",
                stralign="center",
            )
        )
        print("\n" + "=" * 80)
if __name__ == "__main__":
    # Command-line entry point: summarise every test directory under --log-dir.
    parser = argparse.ArgumentParser(description="Parse test results")
    parser.add_argument("--log-dir", default=".", help="Path to the logs directory")
    parser.add_argument(
        "--format", choices=["fancy", "markdown"], default="fancy", help="Table format"
    )
    parser.add_argument("--sla", type=float, default=None)
    args = parser.parse_args()
    # Map format choices to tabulate formats
    tablefmt = (
        "fancy_grid" if args.format == "fancy" else "pipe"
    )  # Using pipe for markdown compatibility
    # `sla` must go by keyword: the third positional parameter of main() is
    # `log_paths`, so passing it positionally made --sla break the script.
    main(args.log_dir, tablefmt, sla=args.sla)
# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from dataclasses import dataclass
from typing import Optional
from tests.utils.managed_deployment import DeploymentSpec
@dataclass
class Load:
    """Client load applied to a deployment while a scenario runs."""

    # Number of client processes started in parallel.
    clients: int = 10
    # Requests each client sends before it exits.
    requests_per_client: int = 150
    # Number of random tokens in each generated prompt.
    input_token_length: int = 100
    # Used as both "max_tokens" and "min_tokens" of each request.
    output_token_length: int = 100
    # Attempt budget for a single request.
    max_retries: int = 1
    # Upper bound on requests per second for each client.
    max_request_rate: float = 1
    # Optional latency threshold used when reporting SLA violations.
    sla: Optional[float] = None
@dataclass
class Failure:
    """One fault to inject into a running deployment."""

    # Seconds to wait before injecting this failure.
    time: int
    # Name of the service whose pods are targeted.
    pod_name: str
    # "delete_pod" to delete the pod; otherwise a substring matched against
    # the command line of the processes to signal.
    command: str
    # Signal delivered to each matching process.
    signal: str = "SIGINT"
    # Number of pods to hit; a falsy value means every pod of the service.
    replicas: int = 1
@dataclass
class Scenario:
    """A deployment, the load applied to it and the failures to inject."""

    # Deployment configuration under test.
    deployment: DeploymentSpec
    # Client load profile.
    load: Load
    # Failures injected one after another, in list order.
    failures: list[Failure]
    # Optional model override; None keeps the deployment's own model.
    model: Optional[str] = None
def _build_deployment_specs():
    """Create the named deployment configurations used by the scenarios.

    Entries are inserted in this order: the two base specs, the TP-2 pair,
    the TP-4 pair, then the two specs with two replicas per service.
    """
    agg_yaml = "/workspace/components/backends/vllm/deploy/agg.yaml"
    disagg_yaml = "/workspace/components/backends/vllm/deploy/disagg.yaml"

    specs = {
        "agg-tp-1-dp-1": DeploymentSpec(agg_yaml),
        "disagg-tp-1-dp-1": DeploymentSpec(disagg_yaml),
    }

    # TP-2 and TP-4 scenarios
    for tp_size in (2, 4):
        agg_name = f"agg-tp-{tp_size}-dp-1"
        specs[agg_name] = DeploymentSpec(agg_yaml)
        specs[agg_name].set_tensor_parallel(tp_size, ["VllmDecodeWorker"])

        disagg_name = f"disagg-prefill-tp-{tp_size}-decode-tp-{tp_size}-dp-1"
        specs[disagg_name] = DeploymentSpec(disagg_yaml)
        for worker in ("VllmPrefillWorker", "VllmDecodeWorker"):
            specs[disagg_name][worker].tensor_parallel_size = tp_size

    # Derivative specs with incremented replicas
    specs["agg-tp-1-dp-2"] = DeploymentSpec(agg_yaml)
    for service in ("Frontend", "VllmDecodeWorker"):
        specs["agg-tp-1-dp-2"][service].replicas = 2

    specs["disagg-tp-1-dp-2"] = DeploymentSpec(disagg_yaml)
    for service in ("Frontend", "VllmDecodeWorker", "VllmPrefillWorker"):
        specs["disagg-tp-1-dp-2"][service].replicas = 2

    return specs


# Each Deployment Spec contains
# the dynamo deployment configuration
deployment_specs = _build_deployment_specs()
# Each failure scenario contains a list of failure injections.
# Each failure injection has a time in seconds after the previous injection,
# the service whose pods are targeted, the command to match (or "delete_pod")
# and, optionally, the signal to send and the number of replicas to hit.
# Failures are currently process termination or pod deletion
#
# Example:
#
# "prefill_worker": [Failure(30, "VllmPrefillWorker", "dynamo.vllm", "SIGKILL")],
#
# terminates 1 prefill worker after 30 seconds
# Failure injections by name.  The last "-"-separated part of a scenario name
# is one of these keys.
failures = {
    "frontend": [
        Failure(time=30, pod_name="Frontend", command="dynamo.frontend"),
    ],
    "frontend_pod": [
        Failure(time=30, pod_name="Frontend", command="delete_pod"),
    ],
    "decode_worker": [
        Failure(
            time=30, pod_name="VllmDecodeWorker", command="dynamo.vllm", signal="SIGKILL"
        ),
    ],
    "decode_worker_pod": [
        Failure(time=30, pod_name="VllmDecodeWorker", command="delete_pod"),
    ],
    "prefill_worker": [
        Failure(
            time=30, pod_name="VllmPrefillWorker", command="dynamo.vllm", signal="SIGKILL"
        ),
    ],
    "prefill_worker_pod": [
        Failure(time=30, pod_name="VllmPrefillWorker", command="delete_pod"),
    ],
    "vllm_decode_engine_core": [
        Failure(
            time=30,
            pod_name="VllmDecodeWorker",
            command="VLLM::EngineCore",
            signal="SIGKILL",
        ),
    ],
    "vllm_prefill_engine_core": [
        Failure(
            time=30,
            pod_name="VllmPrefillWorker",
            command="VLLM::EngineCore",
            signal="SIGKILL",
        ),
    ],
    "none": [],
}

# A single default load profile shared by every scenario.
load = Load()

# model = "deepseek-ai/DeepSeek-R1-Distill-Llama-8B"
model = None

# Populate Scenarios: every deployment is paired with every failure, except
# that prefill failures are only paired with disaggregated deployments.
scenarios = {}
for deployment_name, deployment_spec in deployment_specs.items():
    for failure_name, failure in failures.items():
        if "prefill" not in failure_name or "disagg" in deployment_name:
            scenarios[f"{deployment_name}-{failure_name}"] = Scenario(
                deployment=deployment_spec,
                load=load,
                failures=failure,
                model=model,
            )
# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
import logging
import multiprocessing
import time
from contextlib import contextmanager
import pytest
from tests.fault_tolerance.deploy.client import client
from tests.fault_tolerance.deploy.parse_results import main as parse_results
from tests.fault_tolerance.deploy.scenarios import scenarios
from tests.utils.managed_deployment import ManagedDeployment
@pytest.fixture(params=scenarios.keys())
def scenario(request):
    """Parametrized fixture: return the Scenario registered under each name."""
    scenario_name = request.param
    return scenarios[scenario_name]
@contextmanager
def _clients(
    logger,
    num_clients,
    request,
    deployment_spec,
    namespace,
    model,
    requests_per_client,
    input_token_length,
    output_token_length,
    max_retries,
    max_request_rate,
):
    """Run ``num_clients`` client processes for the duration of a ``with`` block.

    Each process is created from the "spawn" multiprocessing context and runs
    ``client`` with the given arguments, using ``request.node.name`` as its
    log directory.  The list of started processes is yielded.  On exit every
    started process is joined, including when the ``with`` body raises, in
    which case the exception propagates after the clients have finished.
    """
    ctx = multiprocessing.get_context("spawn")
    procs = []
    try:
        for i in range(num_clients):
            proc = ctx.Process(
                target=client,
                args=(
                    deployment_spec,
                    namespace,
                    model,
                    request.node.name,
                    i,
                    requests_per_client,
                    input_token_length,
                    output_token_length,
                    max_retries,
                    max_request_rate,
                ),
            )
            proc.start()
            # Track only processes that actually started, so the joins below
            # are always valid.
            procs.append(proc)
        yield procs
    finally:
        # Without try/finally a failure inside the with-body skipped the
        # joins and left the client processes unattended.
        for proc in procs:
            logger.debug(f"{proc} waiting for join")
            proc.join()
            logger.debug(f"{proc} joined")
def _inject_failures(failures, logger, deployment: ManagedDeployment):  # noqa: F811
    """Inject each failure in turn into the running deployment.

    For every failure: sleep ``failure.time`` seconds, look up the pods of
    ``failure.pod_name`` and skip the failure if there are none.  Otherwise
    hit ``failure.replicas`` pods (all of them when that value is falsy),
    cycling through the pod list.  A ``"delete_pod"`` command saves the pod
    logs and force-deletes the pod; any other command signals every process
    whose command line contains it.
    """
    for failure in failures:
        time.sleep(failure.time)

        service = failure.pod_name
        pods = deployment.get_pods(service)[service]
        num_pods = len(pods)
        if not pods:
            continue

        target_count = failure.replicas or num_pods
        logger.info(f"Injecting failure for: {failure}")

        for x in range(target_count):
            pod = pods[x % num_pods]

            if failure.command == "delete_pod":
                # Save the logs first; they are gone once the pod is deleted.
                deployment.get_pod_logs(service, pod, ".before_delete")
                pod.delete(force=True)
                continue

            for process in deployment.get_processes(pod):
                if failure.command not in process.command:
                    continue
                logger.info(
                    f"Terminating {failure.pod_name} Pid {process.pid} Command {process.command}"
                )
                process.kill(failure.signal)
# Names of the test nodes that have finished, in execution order; consumed
# by the session-level summary.
global_result_list = []


@pytest.fixture(autouse=True)
def results_table(request, scenario):  # noqa: F811
    """After each test, print its result table and record its node name."""
    yield
    node_name = request.node.name
    parse_results(
        logs_dir=None,
        log_paths=[node_name],
        tablefmt="fancy_grid",
        sla=scenario.load.sla,
    )
    global_result_list.append(node_name)
@pytest.fixture(autouse=True, scope="session")
def results_summary():
    """At the end of the session, print the tables for every recorded test."""
    yield
    summary_options = {
        "logs_dir": None,
        "log_paths": global_result_list,
        "tablefmt": "fancy_grid",
    }
    parse_results(**summary_options)
@pytest.mark.e2e
@pytest.mark.slow
@pytest.mark.filterwarnings("ignore::DeprecationWarning")
async def test_fault_scenario(
    scenario,  # noqa: F811
    request,
    image,
    namespace,
):
    """
    Deploy the scenario, drive client load against it and inject its failures.

    The test's node name is used both as the logger name and as the log
    directory of the deployment and of the clients.
    """
    test_name = request.node.name
    logger = logging.getLogger(test_name)

    deployment_spec = scenario.deployment
    deployment_spec.disable_grove()
    deployment_spec.name = "fault-tolerance-test"
    if image:
        deployment_spec.set_image(image)

    # An explicit scenario model overrides the deployment's; otherwise the
    # clients request whatever model the decode worker is configured with.
    model = scenario.model
    if model:
        deployment_spec.set_model(model)
    else:
        model = deployment_spec["VllmDecodeWorker"].model

    deployment_spec.set_logging(True, "info")

    load = scenario.load
    async with ManagedDeployment(
        namespace=namespace,
        log_dir=test_name,
        deployment_spec=deployment_spec,
    ) as deployment:
        with _clients(
            logger,
            load.clients,
            request,
            deployment_spec,
            namespace,
            model,
            load.requests_per_client,
            load.input_token_length,
            load.output_token_length,
            load.max_retries,
            load.max_request_rate,
        ):
            _inject_failures(scenario.failures, logger, deployment)
This diff is collapsed.
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment