Unverified Commit a473402f authored by Thomas Montfort's avatar Thomas Montfort Committed by GitHub
Browse files

feat: fault tolerance rolling upgrade test scenarios (#4558)

parent 5585f803
......@@ -18,12 +18,14 @@
import json
import logging
import os
import signal
import subprocess
import time
from pathlib import Path
from typing import Any, Dict, List, Optional, Tuple
import requests
from kr8s.objects import Pod
from tests.utils.managed_deployment import ManagedDeployment
......@@ -44,7 +46,7 @@ def get_frontend_port(
deployment_spec: Any,
pod_ports: Dict[str, Any],
logger: logging.Logger,
) -> Tuple[Optional[str], Optional[int], Optional[str]]:
) -> Tuple[Optional[str], Optional[int], Optional[Pod]]:
"""
Select a frontend pod using round-robin and setup port forwarding.
......@@ -60,7 +62,7 @@ def get_frontend_port(
Returns:
Tuple of (pod_name, local_port, pod_instance) or (None, None, None) if failed
"""
pods = managed_deployment.get_pods(managed_deployment.frontend_service_name)
pods = managed_deployment.get_pods([managed_deployment.frontend_service_name])
port = 0
pod_name = None
......@@ -270,6 +272,7 @@ def run_aiperf(
logger: logging.Logger,
max_retries: int = 1,
retry_delay: float = 1,
continuous_load: bool = False,
) -> bool:
"""
Execute AI-Perf with specified parameters.
......@@ -280,13 +283,14 @@ def run_aiperf(
model: Model name
pod_name: Selected pod name for logging
port: Local port number
requests_per_client: Number of requests to send
requests_per_client: Number of requests to send (used if continuous load not enabled)
input_token_length: Input token count
output_token_length: Output token count
output_dir: Directory for AI-Perf artifacts
logger: Logger instance
max_retries: Maximum number of retry attempts (default: 1)
retry_delay: Delay in seconds between retries (default: 1)
continuous_load: If True, use continuous load instead of fixed request count
Returns:
True if successful, False otherwise
......@@ -315,8 +319,6 @@ def run_aiperf(
# Enable streaming for TTFT and ITL metrics
"--streaming",
# Request parameters
"--request-count",
str(requests_per_client), # Required: how many requests
"--concurrency",
"1", # Optional: we set to 1 for sequential
# Token configuration
......@@ -338,8 +340,13 @@ def run_aiperf(
"100", # For reproducible results
]
# Calculate timeout (same as legacy would for all requests)
timeout = max(requests_per_client * 2 + 60, 300) # At least 5 minutes
if continuous_load:
cmd.extend(["--benchmark-duration", "1800"]) # 30 minutes for continuous load
logger.info("Using continuous load with duration: 30 minutes")
timeout = 1860 # 31 minutes default for duration-based tests (30 minutes + 1 minute buffer)
else:
cmd.extend(["--request-count", str(requests_per_client)])
timeout = max(requests_per_client * 2 + 60, 300) # At least 5 minutes
# Log execution
logger.info(f"Starting AI-Perf for Pod {pod_name} Local Port {port}")
......@@ -354,15 +361,19 @@ def run_aiperf(
logger.info(f"Command: {' '.join(cmd)}")
# Retry logic for fault tolerance - retry FULL request count until success
max_attempts = max_retries if max_retries > 0 else 1
# Note: For continuous load, we only run once and expect SIGINT to stop it
max_attempts = 1 if continuous_load else (max_retries if max_retries > 0 else 1)
success = False
all_results = []
for attempt in range(max_attempts):
logger.info(
f"AI-Perf attempt {attempt + 1}/{max_attempts} with {requests_per_client} requests"
)
if continuous_load:
logger.info(
"AI-Perf continuous load (will run until interrupted by SIGINT)"
)
else:
logger.info(
f"AI-Perf attempt {attempt + 1}/{max_attempts} with {requests_per_client} requests"
)
# Update output directory for this attempt
attempt_dir = output_dir / f"attempt_{attempt}"
......@@ -374,13 +385,7 @@ def run_aiperf(
cmd_attempt[artifact_dir_idx] = str(attempt_dir)
try:
result = subprocess.run(
cmd_attempt,
capture_output=True,
text=True,
timeout=timeout,
stdin=subprocess.DEVNULL, # Prevent stdin reading which can cause process suspension
)
result = run_aiperf_with_signal_handling(cmd_attempt, logger, timeout)
# Save logs for this attempt
with open(attempt_dir / "genai_perf.log", "w") as f:
......@@ -389,15 +394,6 @@ def run_aiperf(
f.write("\n\n=== STDERR ===\n")
f.write(result.stderr)
all_results.append(
{
"attempt": attempt + 1,
"returncode": result.returncode,
"stdout": result.stdout,
"stderr": result.stderr,
}
)
if result.returncode == 0:
# AI-Perf returns 0 even if all requests failed, so we need to check the output
json_path = attempt_dir / "profile_export_aiperf.json"
......@@ -412,6 +408,19 @@ def run_aiperf(
)
if success:
break # Success - exit the retry loop
## TODO: bug with aiperf git+https://github.com/ai-dynamo/aiperf.git@4d3fa29403c8f75da22a14f1f7b3aeb27db9288f
## where sending a SIGINT on Mac can sometimes produce a return code of -9 (i.e. the process was killed by signal 9, SIGKILL; SIGABRT would be -6), which results in profile_export_aiperf.json not being created
elif result.returncode == -9 and continuous_load:
logger.warning(
f"""
Attempt {attempt + 1} failed with return code {result.returncode}
This is a known bug with aiperf on Mac where sending a SIGINT can sometimes have an error code of -9 (SIGABRT)
which results in profile_export_aiperf.json not being created
"""
)
logger.debug(
f"Stderr: {result.stderr[:500] if result.stderr else 'No stderr'}"
)
else:
logger.warning(
f"Attempt {attempt + 1} failed with return code {result.returncode}"
......@@ -421,22 +430,84 @@ def run_aiperf(
)
except Exception as e:
logger.error(f"Error in attempt {attempt + 1}: {str(e)}")
all_results.append({"attempt": attempt + 1, "error": str(e)})
# Sleep before next attempt (if not the last attempt)
if not success and attempt < max_attempts - 1:
# Sleep before next attempt (if not the last attempt and not continuous load)
if not success and attempt < max_attempts - 1 and not continuous_load:
time.sleep(retry_delay)
if success:
if success and not continuous_load:
logger.info(
f"AI-Perf successfully completed all {requests_per_client} requests for {pod_name}"
)
elif success and continuous_load:
logger.info(
f"AI-Perf sustained continuous load for {pod_name} and existed succesfully"
)
else:
logger.error(f"AI-Perf failed all {max_attempts} attempts for {pod_name}")
return success
# TODO: use file redirection and wait() instead of pipes and communicate
def run_aiperf_with_signal_handling(
    cmd_attempt: List[str],
    logger: logging.Logger,
    timeout: int,
) -> subprocess.CompletedProcess:
    """
    Run aiperf as a subprocess, forwarding SIGINT/SIGTERM to it.

    While the subprocess runs, SIGINT and SIGTERM delivered to this process are
    forwarded to the child so it can shut down gracefully and write its result
    files. The handlers that were installed before the call are restored before
    returning, so the forwarding closure does not outlive the subprocess it
    refers to.

    Args:
        cmd_attempt: Full command line (argv list) to execute.
        logger: Logger used for progress and warning messages.
        timeout: Seconds to wait for the subprocess before it is killed.

    Returns:
        A CompletedProcess carrying the command, the subprocess return code
        and the captured stdout / stderr text.
    """
    proc = subprocess.Popen(
        cmd_attempt,
        stdout=subprocess.PIPE,
        stderr=subprocess.PIPE,
        text=True,
        stdin=subprocess.DEVNULL,
    )

    def signal_handler(signum, frame):
        signal_name = signal.Signals(signum).name
        logger.info(f"Received {signal_name}, forwarding to aiperf subprocess")
        try:
            proc.send_signal(signum)
        except ProcessLookupError:
            pass  # Process already terminated

    # Handlers are process-wide state: remember what was installed so it can
    # be put back once the subprocess is done.
    previous_handlers = {}
    try:
        for forwarded in (signal.SIGINT, signal.SIGTERM):
            previous_handlers[forwarded] = signal.signal(forwarded, signal_handler)

        try:
            stdout, stderr = proc.communicate(timeout=timeout)
        except subprocess.TimeoutExpired:
            logger.warning(f"AI-Perf subprocess timed out after {timeout}s")
            proc.kill()
            stdout, stderr = proc.communicate()
        except KeyboardInterrupt:
            logger.info(
                "Received KeyboardInterrupt, sending SIGINT to aiperf subprocess"
            )
            proc.send_signal(signal.SIGINT)
            try:
                # Give it time to clean up
                stdout, stderr = proc.communicate(timeout=30)
            except subprocess.TimeoutExpired:
                logger.warning("Subprocess didn't terminate gracefully, killing it")
                proc.kill()
                stdout, stderr = proc.communicate()
    finally:
        for forwarded, previous in previous_handlers.items():
            # signal.signal() returns None for handlers that were not installed
            # from Python; those cannot be re-installed, so leave them alone.
            if previous is not None:
                signal.signal(forwarded, previous)
        # Only reachable with a live child when something unexpected was raised
        # above (e.g. handler installation failed); do not leak the process.
        if proc.poll() is None:
            proc.kill()
            proc.wait()

    return subprocess.CompletedProcess(cmd_attempt, proc.returncode, stdout, stderr)
def log_summary_metrics(
output_dir: Path, logger: logging.Logger, pod_name: str, port: int
) -> None:
......@@ -513,6 +584,7 @@ def client(
output_token_length: int,
max_retries: int,
retry_delay: float = 1,
continuous_load: bool = False,
):
"""
Generate load using AI-Perf for fault tolerance testing.
......@@ -527,11 +599,12 @@ def client(
model: Model name
log_dir: Directory for output logs and AI-Perf artifacts
index: Client index used for round-robin pod selection
requests_per_client: Number of requests to generate
requests_per_client: Number of requests to generate (used if continuous load not enabled)
input_token_length: Number of input tokens per request
output_token_length: Number of output tokens per request
max_retries: Maximum retry attempts for AI-Perf execution
retry_delay: Delay in seconds between retry attempts
continuous_load: If True, use continuous load instead of fixed request count
"""
logger = logging.getLogger(f"CLIENT: {index}")
logging.getLogger("httpx").setLevel(logging.WARNING)
......@@ -578,6 +651,7 @@ def client(
logger=logger,
max_retries=max_retries,
retry_delay=retry_delay,
continuous_load=continuous_load,
)
if not success:
......
......@@ -42,6 +42,7 @@ def get_client_function(client_type: str) -> Callable:
output_token_length,
max_retries,
retry_delay_or_rate, # Differs between implementations
continuous_load,
)
Raises:
......
......@@ -35,6 +35,13 @@ def pytest_addoption(parser):
help="Include tests that require custom builds (e.g., MoE models). "
"By default, these tests are excluded.",
)
parser.addoption(
"--skip-service-restart",
action="store_true",
default=False,
help="Skip restarting NATS and etcd services before deployment. "
"By default, these services are restarted.",
)
def pytest_generate_tests(metafunc):
......@@ -109,3 +116,9 @@ def namespace(request):
def client_type(request):
"""Get client type from command line or use scenario default."""
return request.config.getoption("--client-type")
@pytest.fixture
def skip_service_restart(request):
    """Expose the value of the ``--skip-service-restart`` command line option."""
    option_value = request.config.getoption("--skip-service-restart")
    return option_value
......@@ -192,6 +192,7 @@ def client(
max_retries,
max_request_rate,
retry_delay=1,
continuous_load=False,
):
"""Legacy custom client for fault tolerance testing.
......@@ -211,7 +212,11 @@ def client(
max_retries: Maximum retry attempts per request
max_request_rate: Maximum requests per second (for rate limiting)
retry_delay: Delay in seconds between retries
continuous_load: If True, use continuous load instead of fixed request count
"""
if continuous_load:
raise ValueError("Continuous load is not supported for legacy client")
logger = logging.getLogger(f"CLIENT: {index}")
logging.getLogger("httpx").setLevel(logging.WARNING)
......@@ -228,7 +233,7 @@ def client(
for i in range(requests_per_client):
# Get available pods
pods = managed_deployment.get_pods(
managed_deployment.frontend_service_name
[managed_deployment.frontend_service_name]
)
port = 0
pod_name = None
......
......@@ -341,6 +341,7 @@ def parse_aiperf_client_results(log_dir: str) -> Dict[str, Any]:
Returns:
Dictionary with aggregated metrics and client count
"""
logger = logging.getLogger(__name__)
all_metrics: Dict[str, Any] = {
"total_requests": 0,
"successful_requests": 0,
......@@ -382,22 +383,28 @@ def parse_aiperf_client_results(log_dir: str) -> Dict[str, Any]:
with open(profile_json) as f:
client_metrics = json.load(f)
# AI-Perf format has "records" dictionary at the top level
# AI-Perf format can have "records" dictionary or metrics at top level
# Try records first (older format), then fall back to top level (newer format)
records = client_metrics.get("records", {})
# Extract successful request count
request_count_record = records.get("request_count", {})
# Extract successful request count - check both locations
request_count_record = records.get(
"request_count"
) or client_metrics.get("request_count", {})
successful_count = (
int(request_count_record.get("avg", 0))
if request_count_record
if request_count_record and isinstance(request_count_record, dict)
else 0
)
# Extract error request count
error_request_count_record = records.get("error_request_count", {})
# Extract error request count - check both locations
error_request_count_record = records.get(
"error_request_count"
) or client_metrics.get("error_request_count", {})
error_request_count = (
int(error_request_count_record.get("avg", 0))
if error_request_count_record
and isinstance(error_request_count_record, dict)
else 0
)
......@@ -418,9 +425,17 @@ def parse_aiperf_client_results(log_dir: str) -> Dict[str, Any]:
# Sum up actual error counts from each error type
error_count = sum(error.get("count", 0) for error in error_summary)
# Check if test was cancelled
# Log if test was cancelled (expected for continuous load mode)
if client_metrics.get("was_cancelled", False):
error_count = request_count # Mark all as failed if cancelled
logger.info(
f"AI-Perf client {item} was cancelled - anticipated if running with continuous load mode. "
f"Completed {request_count} requests before cancellation."
)
# Note: If test was cancelled (was_cancelled=True), we still count the requests
# that were successfully completed before cancellation. The request_count
# represents successful requests, and error_count represents actual errors.
# We don't mark cancelled requests as failed - they were just interrupted.
# Validate data consistency
if request_count < error_count:
......
......@@ -13,14 +13,17 @@
# See the License for the specific language governing permissions and
# limitations under the License.
import asyncio
import logging
import re
from abc import ABC, abstractmethod
from dataclasses import dataclass, field
from enum import Enum, auto
from typing import TYPE_CHECKING, Dict, List, Optional, Pattern
from typing_extensions import TypedDict
from typing_extensions import Required, TypedDict
from tests.utils.managed_deployment import DeploymentSpec
from tests.utils.managed_deployment import DeploymentSpec, ManagedDeployment
if TYPE_CHECKING:
from tests.fault_tolerance.deploy.base_checker import BaseChecker
......@@ -54,8 +57,8 @@ class DeploymentInfo(TypedDict, total=False):
is_moe: Optional flag indicating if this is a Mixture-of-Experts model
"""
spec: DeploymentSpec
backend: str
spec: Required[DeploymentSpec]
backend: Required[str]
model: str
is_moe: bool
......@@ -155,14 +158,144 @@ class Load:
overflow_request_count: int = 15 # Number of overflow requests
normal_request_count: int = 15 # Number of normal requests after overflow
continuous_load: bool = (
False # If True, use continuous load instead of fixed request count
)
@dataclass
class Failure:
class Failure(ABC):
"""Base class for all failure types."""
# time to wait in seconds before the failure is injected
time: int
pod_name: str
command: str
signal: str = "SIGINT"
replicas: int = 1
# names of DGD services to inject the failure into the corresponding pods for
service_names: list[str]
@abstractmethod
async def execute(
self, deployment: ManagedDeployment, logger: logging.Logger
) -> list[str]:
"""Execute the failure injection.
Args:
deployment: The managed deployment to inject the failure into
logger: Logger instance for logging failure injection
Returns: List of affected pod names
"""
pass
@abstractmethod
def get_failure_key(self) -> str:
"""Get the failure key for the failure."""
pass
@dataclass
class RollingUpgradeFailure(Failure):
    """Failure that triggers a rolling upgrade of the configured services."""

    async def execute(
        self, deployment: ManagedDeployment, logger: logging.Logger
    ) -> list[str]:
        """Trigger the rolling upgrade, wait for it to finish, then report pods.

        Args:
            deployment: The managed deployment to upgrade
            logger: Logger instance (not used by this failure type)

        Returns: Pod names reported by the deployment for the upgraded services
        """
        await deployment.trigger_rolling_upgrade(self.service_names)

        # The deployment going unready is the signal that the upgrade started.
        await deployment.wait_for_unready(timeout=60, log_interval=10)

        # 30 minute timeout for the deployment to become ready again.
        await deployment._wait_for_ready(timeout=1800)

        # Let some requests be processed after the rolling upgrade completed.
        await asyncio.sleep(self.time)

        upgraded_pod_names = await deployment.get_pod_names(self.service_names)
        return upgraded_pod_names

    def get_failure_key(self) -> str:
        """Return the key identifying this rolling upgrade failure."""
        joined_services = ",".join(self.service_names)
        return f"rolling_upgrade:{joined_services}"
@dataclass
class DeletePodFailure(Failure):
    """Failure that force-deletes every pod of the configured services."""

    async def execute(
        self, deployment: ManagedDeployment, logger: logging.Logger
    ) -> list[str]:
        """Delete the pods of each configured service.

        Args:
            deployment: The managed deployment whose pods are deleted
            logger: Logger instance (not used by this failure type)

        Returns: Names of the deleted pods, in deletion order
        """
        deleted: list[str] = []
        pods_by_service = deployment.get_pods(self.service_names)

        for svc, svc_pods in pods_by_service.items():
            for target in svc_pods:
                # Capture manifest/logs/metrics before the pod disappears.
                deployment.get_pod_manifest_logs_metrics(svc, target, ".before_delete")
                # force means no graceful termination
                target.delete(force=True)
                deleted.append(target.name)

        return deleted

    def get_failure_key(self) -> str:
        """Return the key identifying this delete pod failure."""
        joined_services = ",".join(self.service_names)
        return f"delete_pod:{joined_services}"
class TerminateProcessFailure(Failure):
    """Failure type for terminating specific processes by name."""

    def __init__(
        self,
        time: int,
        service_names: list[str],
        signal: str = "SIGINT",
        process_name: str = "",
    ):
        """Initialize TerminateProcessFailure.

        Args:
            time: Time to wait in seconds before the failure is injected
            service_names: Names of DGD services to inject the failure into
            signal: Signal to send (default: "SIGINT"); must be non-empty
            process_name: Substring of the process command to match; must be
                non-empty even though it has a default for signature reasons

        Raises:
            ValueError: If process_name or signal is empty
        """
        # Validate before touching any state so a bad spec fails fast.
        if not process_name or not signal:
            raise ValueError(
                "process_name and signal are required for TerminateProcessFailure"
            )
        super().__init__(
            time=time,
            service_names=service_names,
        )
        self.process_name = process_name
        self.signal = signal

    def __repr__(self) -> str:
        """Include process_name and signal, which this class sets itself.

        NOTE(review): the inherited repr presumably only covers the base
        class fields, which hides the target process in log output.
        """
        return (
            f"{type(self).__name__}(time={self.time!r}, "
            f"service_names={self.service_names!r}, "
            f"signal={self.signal!r}, process_name={self.process_name!r})"
        )

    async def execute(
        self, deployment: ManagedDeployment, logger: logging.Logger
    ) -> list[str]:
        """Signal every matching process in the pods of the configured services.

        Args:
            deployment: The managed deployment to inject the failure into
            logger: Logger instance for logging each signalled process

        Returns: Names of the pods that were targeted (one entry per pod)
        """
        service_pod_dict = deployment.get_pods(self.service_names)
        pod_names: list[str] = []
        for service_name, pods in service_pod_dict.items():
            for pod in pods:
                for process in deployment.get_processes(pod):
                    # Substring match against the process command line.
                    if self.process_name in process.command:
                        logger.info(
                            f"Terminating {service_name} pod {pod} Pid {process.pid} Command {process.command}"
                        )
                        process.kill(self.signal)
                # Recorded once per targeted pod, not once per matching process.
                pod_names.append(pod.name)

        return pod_names

    def get_failure_key(self) -> str:
        """Get the failure key for the terminate process failure."""
        return f"terminate_process:{','.join(self.service_names)}:{self.process_name}:{self.signal}"
@dataclass
......@@ -182,13 +315,25 @@ class TokenOverflowFailure(Failure):
):
super().__init__(
time=time,
pod_name="Client",
command="token_overflow",
service_names=["Client"],
)
self.max_seq_len = max_seq_len
self.overflow_multiplier = overflow_multiplier
self.overflow_token_count = int(max_seq_len * overflow_multiplier)
async def execute(
self, deployment: ManagedDeployment, logger: logging.Logger
) -> list[str]:
"""Token overflow is handled client-side, so this is a no-op."""
# The actual overflow is handled by the client configuration
# which uses the input_token_length from the Load config
# This is just a placeholder for the abstract method
return []
def get_failure_key(self) -> str:
"""Get the failure key for the token overflow failure."""
return f"token_overflow:{self.overflow_token_count}"
@dataclass
class Scenario:
......@@ -206,7 +351,7 @@ class Scenario:
# Helper functions to create deployment specs
def _create_deployment_spec(backend: str, yaml_path: str) -> DeploymentInfo:
def _create_deployment_info(backend: str, yaml_path: str) -> DeploymentInfo:
"""Create a deployment spec with backend information.
Args:
......@@ -240,7 +385,9 @@ def _set_replicas(deployment_spec, backend, deploy_type, replicas):
spec[WORKER_MAP[backend]["prefill"]].replicas = replicas
def _set_tensor_parallel(deployment_spec, backend, deploy_type, tp_size):
def _set_tensor_parallel(
deployment_spec: DeploymentInfo, backend: str, deploy_type: str, tp_size: int
):
"""Set tensor parallel size for worker components."""
spec = deployment_spec["spec"]
......@@ -308,7 +455,7 @@ def _create_deployments_for_backend(backend: str) -> Dict[str, DeploymentInfo]:
scenario_name = "-".join(name_parts)
# Create and configure the deployment
deployment = _create_deployment_spec(backend, yaml_files[deploy_type])
deployment = _create_deployment_info(backend, yaml_files[deploy_type])
if tp_size > 1:
_set_tensor_parallel(deployment, backend, deploy_type, tp_size)
if dp_replicas > 1:
......@@ -397,34 +544,69 @@ def _create_backend_failures(backend, deploy_type="disagg"):
process_name = f"dynamo.{backend}"
failures = {
"frontend": [Failure(30, "Frontend", "dynamo.frontend")],
"frontend_pod": [Failure(30, "Frontend", "delete_pod")],
"decode_worker": [Failure(30, decode_worker, process_name, "SIGKILL")],
"decode_worker_pod": [Failure(30, decode_worker, "delete_pod")],
"prefill_worker": [Failure(30, prefill_worker, process_name, "SIGKILL")],
"prefill_worker_pod": [Failure(30, prefill_worker, "delete_pod")],
"frontend": [
TerminateProcessFailure(
30, ["Frontend"], "SIGINT", process_name="dynamo.frontend"
)
],
"frontend_pod": [DeletePodFailure(30, ["Frontend"])],
"decode_worker": [
TerminateProcessFailure(
30, [decode_worker], "SIGKILL", process_name=process_name
)
],
"decode_worker_pod": [DeletePodFailure(30, [decode_worker])],
"prefill_worker": [
TerminateProcessFailure(
30, [prefill_worker], "SIGKILL", process_name=process_name
)
],
"prefill_worker_pod": [DeletePodFailure(30, [prefill_worker])],
"none": [],
}
if backend == "vllm":
failures["vllm_decode_engine_core"] = [
Failure(30, decode_worker, "VLLM::EngineCore", "SIGKILL")
TerminateProcessFailure(
30, [decode_worker], "SIGKILL", process_name="VLLM::EngineCore"
)
]
failures["vllm_prefill_engine_core"] = [
Failure(30, prefill_worker, "VLLM::EngineCore", "SIGKILL")
TerminateProcessFailure(
30, [prefill_worker], "SIGKILL", process_name="VLLM::EngineCore"
)
]
elif backend == "sglang":
failures["sglang_decode_scheduler"] = [
Failure(30, decode_worker, "sglang::scheduler", "SIGKILL")
TerminateProcessFailure(
30, [decode_worker], "SIGKILL", process_name="sglang::scheduler"
)
]
failures["sglang_decode_detokenizer"] = [
Failure(30, decode_worker, "sglang::detokenizer", "SIGKILL")
TerminateProcessFailure(
30, [decode_worker], "SIGKILL", process_name="sglang::detokenizer"
)
]
failures["sglang_prefill_scheduler"] = [
Failure(30, prefill_worker, "sglang::scheduler", "SIGKILL")
TerminateProcessFailure(
30, [prefill_worker], "SIGKILL", process_name="sglang::scheduler"
)
]
failures["sglang_prefill_detokenizer"] = [
Failure(30, prefill_worker, "sglang::detokenizer", "SIGKILL")
TerminateProcessFailure(
30, [prefill_worker], "SIGKILL", process_name="sglang::detokenizer"
)
]
elif backend == "trtllm":
failures["trtllm_decode_engine_core"] = [
TerminateProcessFailure(
30, [decode_worker], "SIGKILL", process_name="TRTLLM::EngineCore"
)
]
failures["trtllm_prefill_engine_core"] = [
TerminateProcessFailure(
30, [prefill_worker], "SIGKILL", process_name="TRTLLM::EngineCore"
)
]
return failures
......@@ -533,7 +715,7 @@ model = None
# Populate Scenarios
scenarios = {}
scenarios: dict[str, Scenario] = {}
# Map of backend+deploy_type to failure definitions
backend_failure_map = {}
......@@ -729,5 +911,59 @@ def add_token_overflow_scenarios():
)
def add_rolling_upgrade_scenarios():
    """Register one rolling-upgrade scenario per backend and worker mode.

    Each scenario runs its worker services with two replicas, drives continuous
    AI-Perf load, and injects a single RollingUpgradeFailure on those services.
    Scenarios are stored in the module-level ``scenarios`` mapping under
    ``"<backend>-<worker_mode>-rolling-upgrade"``.
    """
    for backend in ["vllm", "sglang", "trtllm"]:
        worker_names = WORKER_MAP[backend]

        for worker_mode in ["agg", "disagg"]:
            yaml_path = f"examples/backends/{backend}/deploy/{worker_mode}.yaml"
            deployment_info = _create_deployment_info(backend, yaml_path)
            deployment_spec: DeploymentSpec = deployment_info["spec"]

            # trtllm uses a dedicated worker name for aggregated decode.
            uses_agg_decode = worker_mode == "agg" and backend == "trtllm"
            decode_key = "decode_agg" if uses_agg_decode else "decode"
            service_names: list[str] = [worker_names[decode_key]]
            if worker_mode == "disagg":
                service_names.append(worker_names["prefill"])

            # setting replicas to 2 so we have availability of 1 replica at a time
            for service_name in service_names:
                deployment_spec.set_service_replicas(service_name, 2)

            continuous_aiperf_load = Load(
                clients=10,
                input_token_length=100,
                output_token_length=100,
                max_retries=1,
                client_type="aiperf",
                max_request_rate=1.0,
                success_threshold=100.0,
                continuous_load=True,
            )
            upgrade_failure = RollingUpgradeFailure(
                time=30,
                service_names=service_names,
            )

            scenarios[f"{backend}-{worker_mode}-rolling-upgrade"] = Scenario(
                deployment=deployment_spec,
                load=continuous_aiperf_load,
                failures=[upgrade_failure],
                model="Qwen/Qwen3-0.6B",
                backend=backend,
            )
# Add the token overflow scenarios
add_token_overflow_scenarios()
# Add the rolling upgrade scenarios
add_rolling_upgrade_scenarios()
# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
import asyncio
import logging
import multiprocessing
import os
import re
import time
import signal
from contextlib import contextmanager
from typing import Any
from multiprocessing.context import SpawnProcess
from typing import Any, Optional
import pytest
......@@ -17,11 +20,12 @@ from tests.fault_tolerance.deploy.parse_results import process_overflow_recovery
from tests.fault_tolerance.deploy.scenarios import (
OVERFLOW_SUFFIX,
RECOVERY_SUFFIX,
Failure,
Load,
TokenOverflowFailure,
Scenario,
scenarios,
)
from tests.utils.managed_deployment import ManagedDeployment
from tests.utils.managed_deployment import DeploymentSpec, ManagedDeployment
@pytest.fixture
......@@ -55,18 +59,18 @@ def scenario(scenario_name, client_type):
@contextmanager
def _clients(
logger,
request,
deployment_spec,
namespace,
model,
logger: logging.Logger,
log_dir: str,
deployment_spec: DeploymentSpec,
namespace: str,
model: str,
load_config: Load,
):
"""Start client processes using factory pattern for client selection.
Args:
logger: Logger instance
request: Pytest request fixture
log_dir: Log directory for output logs and client logs/artifacts
deployment_spec: Deployment specification
namespace: Kubernetes namespace
model: Model name to test
......@@ -79,7 +83,7 @@ def _clients(
f"Starting {load_config.clients} clients using '{load_config.client_type}' client"
)
procs = []
procs: list[SpawnProcess] = []
ctx = multiprocessing.get_context("spawn")
# Determine retry_delay_or_rate based on client type
......@@ -90,6 +94,9 @@ def _clients(
# AI-Perf client uses retry_delay between attempts (default 5s)
retry_delay_or_rate = 5
# Check if this is a continuous load test (rolling upgrade scenarios)
continuous_load = getattr(load_config, "continuous_load", False)
# Check if this is a mixed token test (overflow + recovery)
# If mixed_token_test is True, run two phases; otherwise run normally
if hasattr(load_config, "mixed_token_test") and load_config.mixed_token_test:
......@@ -108,13 +115,14 @@ def _clients(
deployment_spec,
namespace,
model,
request.node.name + OVERFLOW_SUFFIX,
f"{log_dir}{OVERFLOW_SUFFIX}",
i,
load_config.overflow_request_count, # 15 overflow requests
load_config.overflow_token_length, # 2x max_seq_len tokens
load_config.output_token_length,
load_config.max_retries,
retry_delay_or_rate,
continuous_load,
),
)
proc_overflow.start()
......@@ -128,7 +136,7 @@ def _clients(
logger.info("Overflow requests completed. Starting recovery phase...")
# Second phase: Send normal requests to test recovery
procs_recovery = []
procs_recovery: list[SpawnProcess] = []
for i in range(load_config.clients):
proc_normal = ctx.Process(
target=client_func,
......@@ -136,7 +144,7 @@ def _clients(
deployment_spec,
namespace,
model,
request.node.name + RECOVERY_SUFFIX,
f"{log_dir}{RECOVERY_SUFFIX}",
i,
load_config.normal_request_count, # 15 normal requests
load_config.input_token_length, # Normal token count
......@@ -161,13 +169,14 @@ def _clients(
deployment_spec,
namespace,
model,
request.node.name,
log_dir,
i,
load_config.requests_per_client,
load_config.input_token_length,
load_config.output_token_length,
load_config.max_retries,
retry_delay_or_rate,
continuous_load, # Pass continuous_load flag
),
)
)
......@@ -182,65 +191,50 @@ def _clients(
logger.debug(f"{proc} joined")
def _inject_failures(failures, logger, deployment: ManagedDeployment): # noqa: F811
"""Inject failures and return info about affected pods.
Returns:
Dict mapping failure info to list of affected pod names
Example: {"VllmDecodeWorker:delete_pod": ["pod-abc123", "pod-xyz789"]}
def _terminate_client_processes(
client_procs: list[SpawnProcess],
logger: logging.Logger,
):
"""
affected_pods: dict[str, list] = {}
for failure in failures:
time.sleep(failure.time)
# Handle TokenOverflowFailure differently - it's a client-side injection
if isinstance(failure, TokenOverflowFailure):
# The actual overflow is handled by the client configuration
# which uses the input_token_length from the Load config
# This is just logging for visibility
continue
pods = deployment.get_pods(failure.pod_name)[failure.pod_name]
num_pods = len(pods)
Terminate client processes.
"""
# Send SIGINT to client processes to stop continuous load
if client_procs:
logger.info(f"Sending SIGINT to {len(client_procs)} client processes...")
for proc in client_procs:
if proc.is_alive():
try:
if proc.pid is not None:
logger.debug(f"Sending SIGINT to client process {proc.pid}")
os.kill(proc.pid, signal.SIGINT)
else:
raise ValueError(f"Process {proc} has no PID")
except ProcessLookupError:
logger.debug(f"Process {proc.pid} already terminated")
except Exception as e:
logger.warning(f"Failed to send SIGINT to process {proc.pid}: {e}")
logger.info(
"SIGINT sent to all client processes, waiting for graceful shutdown..."
)
else:
logger.warning("No client processes provided to terminate")
if not pods:
continue
replicas = failure.replicas
async def _inject_failures(
failures: list[Failure],
logger: logging.Logger,
deployment: ManagedDeployment,
) -> dict[str, list]: # noqa: F811
affected_pods: dict[str, list] = {}
if not replicas:
replicas = num_pods
for failure in failures:
await asyncio.sleep(failure.time)
logger.info(f"Injecting failure for: {failure}")
# Track which pods were affected by this failure
failure_key = f"{failure.pod_name}:{failure.command}"
if failure_key not in affected_pods:
affected_pods[failure_key] = []
for x in range(replicas):
pod = pods[x % num_pods]
# Capture the exact pod name before we kill it
pod_name = pod.name
affected_pods[failure_key].append(pod_name)
logger.info(f"Target pod for failure: {pod_name}")
if failure.command == "delete_pod":
deployment.get_pod_logs(failure.pod_name, pod, ".before_delete")
logger.info(f"Deleting pod: {pod_name}")
pod.delete(force=True)
else:
processes = deployment.get_processes(pod)
for process in processes:
if failure.command in process.command:
logger.info(
f"Terminating {failure.pod_name} Pid {process.pid} Command {process.command} in pod {pod_name}"
)
process.kill(failure.signal)
affected_pods[failure.get_failure_key()] = await failure.execute(
deployment, logger
)
return affected_pods
......@@ -445,11 +439,12 @@ def results_summary():
@pytest.mark.slow
@pytest.mark.filterwarnings("ignore::DeprecationWarning")
async def test_fault_scenario(
scenario, # noqa: F811
scenario: Scenario, # noqa: F811
request,
image,
namespace,
image: str,
namespace: str,
validation_context, # noqa: F811 # Shared context for passing data to validation
skip_service_restart: bool,
):
"""
Test dynamo serve deployments with injected failures
......@@ -468,6 +463,7 @@ async def test_fault_scenario(
if image:
scenario.deployment.set_image(image)
model: Optional[str] = None
if scenario.model:
scenario.deployment.set_model(scenario.model)
model = scenario.model
......@@ -500,6 +496,7 @@ async def test_fault_scenario(
namespace=namespace,
log_dir=request.node.name,
deployment_spec=scenario.deployment,
skip_service_restart=skip_service_restart,
) as deployment:
# Populate shared context for validation
validation_context["deployment"] = deployment
......@@ -507,14 +504,17 @@ async def test_fault_scenario(
with _clients(
logger,
request,
request.node.name,
scenario.deployment,
namespace,
model,
scenario.load, # Pass entire Load config object
):
) as client_procs:
# Inject failures and capture which pods were affected
affected_pods = _inject_failures(scenario.failures, logger, deployment)
validation_context["affected_pods"] = affected_pods
affected_pods = await _inject_failures(
scenario.failures, logger, deployment
)
logger.info(f"Affected pods during test: {affected_pods}")
if scenario.load.continuous_load:
_terminate_client_processes(client_procs, logger)
......@@ -5,18 +5,18 @@ import asyncio
import logging
import os
import re
import secrets
import shlex
import time
from dataclasses import dataclass, field
from typing import Any, List, Optional
import kr8s
import kubernetes
import requests
import yaml
from kr8s.objects import Pod as kr8s_Pod
from kr8s.objects import Service as kr8s_Service
from kr8s.objects import Pod, Service
from kubernetes_asyncio import client, config
from kubernetes_asyncio.client import exceptions
def _get_workspace_dir() -> str:
......@@ -65,6 +65,15 @@ class ServiceSpec:
self._spec["extraPodSpec"]["mainContainer"] = {}
self._spec["extraPodSpec"]["mainContainer"]["image"] = value
@property
def envs(self) -> list[dict[str, str]]:
    """Return the value stored under the ``envs`` key of the service spec.

    When the key is absent a new empty list is returned. That list is not
    stored in the spec, so changes made to it only persist once it is
    assigned back through the setter.
    """
    return self._spec["envs"] if "envs" in self._spec else []

@envs.setter
def envs(self, value: list[dict[str, str]]):
    """Store ``value`` under the ``envs`` key of the service spec."""
    self._spec["envs"] = value
# ----- Replicas -----
@property
def replicas(self) -> int:
......@@ -314,8 +323,36 @@ class DeploymentSpec:
return {"jsonl_enabled": jsonl_enabled, "log_level": log_level}
def set_service_env_var(self, service_name: str, name: str, value: str):
    """Create or overwrite one environment variable on a service.

    Args:
        service_name: Service whose ``envs`` list is modified; resolved
            through ``self.get_service``.
        name: Environment variable name to look for.
        value: Value to store for that variable.

    Only the first entry whose ``"name"`` equals ``name`` is updated. When no
    entry matches, a new ``{"name": ..., "value": ...}`` dict is appended.
    """
    target = self.get_service(service_name)
    current = target.envs
    entries = [] if current is None else current

    for entry in entries:
        if entry["name"] == name:
            # Existing variable: overwrite in place and stop searching.
            entry["value"] = value
            break
    else:
        # No entry carried this name: add a fresh one at the end.
        entries.append({"name": name, "value": value})

    # Assign back so the service's setter persists the (possibly new) list.
    target.envs = entries
def get_service_env_vars(self, service_name: str) -> list[dict]:
    """Return the ``envs`` value of the named service.

    Args:
        service_name: Service to look up through ``self.get_service``.

    Returns:
        List of environment variable dicts
        (e.g., [{"name": "VAR", "value": "val"}])
    """
    return self.get_service(service_name).envs
@property
def services(self) -> list:
def services(self) -> list[ServiceSpec]:
"""List of ServiceSpec objects"""
return [
ServiceSpec(svc, spec)
......@@ -340,28 +377,25 @@ class DeploymentSpec:
arg_name: Argument name (e.g., "--max-model-len", "--max-seq-len")
arg_value: Argument value (e.g., "1024")
"""
# Get the service
if service_name not in self._deployment_spec["spec"]["services"]:
raise ValueError(f"Service '{service_name}' not found in deployment spec")
service = self._deployment_spec["spec"]["services"][service_name]
service = self.get_service(service_name)
service_spec = service._spec
# Ensure args list exists
if "extraPodSpec" not in service:
service["extraPodSpec"] = {"mainContainer": {}}
if "mainContainer" not in service["extraPodSpec"]:
service["extraPodSpec"]["mainContainer"] = {}
if "args" not in service["extraPodSpec"]["mainContainer"]:
service["extraPodSpec"]["mainContainer"]["args"] = []
if "extraPodSpec" not in service_spec:
service_spec["extraPodSpec"] = {"mainContainer": {}}
if "mainContainer" not in service_spec["extraPodSpec"]:
service_spec["extraPodSpec"]["mainContainer"] = {}
if "args" not in service_spec["extraPodSpec"]["mainContainer"]:
service_spec["extraPodSpec"]["mainContainer"]["args"] = []
args_list = service["extraPodSpec"]["mainContainer"]["args"]
args_list = service_spec["extraPodSpec"]["mainContainer"]["args"]
# Convert to list if needed (sometimes it's a single string)
if isinstance(args_list, str):
import shlex
args_list = shlex.split(args_list)
service["extraPodSpec"]["mainContainer"]["args"] = args_list
service_spec["extraPodSpec"]["mainContainer"]["args"] = args_list
# Find existing argument
arg_index = None
......@@ -384,6 +418,24 @@ class DeploymentSpec:
# Add new argument
args_list.extend([arg_name, arg_value])
def get_service(self, service_name: str) -> ServiceSpec:
    """Wrap one entry of ``spec.services`` in a ``ServiceSpec``.

    Args:
        service_name: Key to look up under ``spec.services`` of the
            deployment spec.

    Returns:
        A ``ServiceSpec`` built from the service name and its spec entry.

    Raises:
        ValueError: If ``service_name`` is not a key of ``spec.services``.
    """
    services = self._deployment_spec["spec"]["services"]
    if service_name in services:
        return ServiceSpec(service_name, services[service_name])
    raise ValueError(f"Service '{service_name}' not found in deployment spec")
def set_service_replicas(self, service_name: str, replicas: int):
    """Assign ``replicas`` to the ``replicas`` attribute of the named service.

    Args:
        service_name: Service to look up through ``self.get_service``.
        replicas: Replica count to assign.
    """
    self.get_service(service_name).replicas = replicas
def save(self, out_file: str):
"""Save updated deployment to file"""
with open(out_file, "w") as f:
......@@ -391,7 +443,7 @@ class DeploymentSpec:
class PodProcess:
def __init__(self, pod: kr8s_Pod, line: str):
def __init__(self, pod: Pod, line: str):
self.pid = int(re.split(r"\s+", line)[1])
self.command = " ".join(
re.split(r"\s+", line)[10:]
......@@ -439,10 +491,13 @@ class ManagedDeployment:
log_dir: str
deployment_spec: DeploymentSpec
namespace: str
frontend_service_name: Optional[str] = "Frontend"
# TODO: this should be determined by the deployment_spec
# the service containing component_type: Frontend determines what is actually the frontend service
frontend_service_name: str = "Frontend"
skip_service_restart: bool = False
_custom_api: Optional[Any] = None
_core_api: Optional[Any] = None
_custom_api: Optional[client.CustomObjectsApi] = None
_core_api: Optional[client.CoreV1Api] = None
_in_cluster: bool = False
_logger: logging.Logger = logging.getLogger()
_port_forward: Optional[Any] = None
......@@ -457,7 +512,7 @@ class ManagedDeployment:
"""Initialize kubernetes client"""
try:
# Try in-cluster config first (for pods with service accounts)
await config.load_incluster_config()
config.load_incluster_config()
self._in_cluster = True
except Exception:
# Fallback to kube config file (for local development)
......@@ -511,6 +566,17 @@ class ManagedDeployment:
self._logger.info(f"Restarted {name} {label}")
async def wait_for_unready(self, timeout: int = 1800, sleep=1, log_interval=60):
    """
    Wait until the custom resource reports Ready condition False and state "pending".

    Args:
        timeout: Maximum time to wait in seconds, default to 30 mins
        sleep: Forwarded unchanged to ``_wait_for_condition``
            (presumably the delay between polls; confirm there)
        log_interval: Forwarded unchanged to ``_wait_for_condition``

    Returns:
        Whatever ``_wait_for_condition`` returns for these targets.
    """
    return await self._wait_for_condition(
        timeout=timeout,
        sleep=sleep,
        log_interval=log_interval,
        desired_ready_condition_val=False,
        desired_state_val="pending",
    )
async def _wait_for_ready(self, timeout: int = 1800, sleep=1, log_interval=60):
"""
Wait for the custom resource to be ready.
......@@ -518,9 +584,23 @@ class ManagedDeployment:
Args:
timeout: Maximum time to wait in seconds, default to 30 mins (image pulling can take a while)
"""
return await self._wait_for_condition(
timeout, sleep, log_interval, True, "successful"
)
async def _wait_for_condition(
self,
timeout: int = 1800,
sleep=1,
log_interval=60,
desired_ready_condition_val: bool = True,
desired_state_val: str = "successful",
):
start_time = time.time()
self._logger.info(f"Waiting for Deployment {self._deployment_name}")
self._logger.info(
f"Waiting for Deployment {self._deployment_name} to have Ready condition {desired_ready_condition_val} and state {desired_state_val}"
)
attempt = 0
......@@ -528,7 +608,7 @@ class ManagedDeployment:
try:
attempt += 1
assert self._custom_api is not None, "Kubernetes API not initialized"
status = await self._custom_api.get_namespaced_custom_object(
status = await self._custom_api.get_namespaced_custom_object( # type: ignore[awaitable-is-not-coroutine]
group="nvidia.com",
version="v1alpha1",
namespace=self.namespace,
......@@ -538,29 +618,34 @@ class ManagedDeployment:
# Check both conditions:
# 1. Ready condition is True
# 2. State is successful
status_obj = status.get("status", {})
conditions = status_obj.get("conditions", [])
current_state = status_obj.get("state", "unknown")
status_obj = status.get("status", {}) # type: ignore[attr-defined]
conditions = status_obj.get("conditions", []) # type: ignore[attr-defined]
current_state = status_obj.get("state", "unknown") # type: ignore[attr-defined]
ready_condition = False
observed_ready_condition_val = ""
for condition in conditions:
if (
condition.get("type") == "Ready"
and condition.get("status") == "True"
):
ready_condition = True
break
state_successful = status_obj.get("state") == "successful"
if ready_condition and state_successful:
if condition.get("type") == "Ready":
observed_ready_condition_val = condition.get("status")
if observed_ready_condition_val == str(
desired_ready_condition_val
):
break
observed_state_val = status_obj.get("state") # type: ignore[attr-defined]
if (
observed_ready_condition_val == str(desired_ready_condition_val)
and observed_state_val == desired_state_val
):
self._logger.info(f"Current deployment state: {current_state}")
self._logger.info(f"Current conditions: {conditions}")
self._logger.info(
f"Elapsed time: {time.time() - start_time:.1f}s / {timeout}s"
)
self._logger.info(f"Deployment {self._deployment_name} is ready")
self._logger.info(
f"Deployment {self._deployment_name} has Ready condition {desired_ready_condition_val} and state {desired_state_val}"
)
return True
else:
if attempt % log_interval == 0:
......@@ -570,10 +655,10 @@ class ManagedDeployment:
f"Elapsed time: {time.time() - start_time:.1f}s / {timeout}s"
)
self._logger.info(
f"Deployment not ready yet - Ready condition: {ready_condition}, State successful: {state_successful}"
f"Deployment has Ready condition {observed_ready_condition_val} and state {observed_state_val}, desired condition {desired_ready_condition_val} and state {desired_state_val}"
)
except kubernetes.client.rest.ApiException as e:
except exceptions.ApiException as e:
self._logger.info(
f"API Exception while checking deployment status: {e}"
)
......@@ -624,7 +709,7 @@ class ManagedDeployment:
)
self._logger.info(self.deployment_spec.spec())
self._logger.info(f"Deployment Started {self._deployment_name}")
except kubernetes.client.rest.ApiException as e:
except exceptions.ApiException as e:
if e.status == 409: # Already exists
self._logger.info(f"Deployment {self._deployment_name} already exists")
else:
......@@ -633,7 +718,64 @@ class ManagedDeployment:
)
raise
def get_processes(self, pod) -> list:
async def trigger_rolling_upgrade(self, service_names: list[str]):
    """
    Patch the deployment so the listed services get a changed env var.

    This is a dummy update meant to trigger a rolling update: each service
    gets ``TEST_ROLLING_UPDATE_TRIGGER`` set to a fresh random hex token in
    the local deployment spec, and the resulting ``envs`` lists are sent to
    the cluster as a single merge patch.

    Args:
        service_names: Services to update; must not be empty.

    Raises:
        ValueError: If ``service_names`` is empty.
        exceptions.ApiException: Re-raised after logging when the patch
            request fails.
    """
    if not service_names:
        raise ValueError(
            "service_names cannot be empty for trigger_rolling_upgrade"
        )

    # Build the per-service part of the patch first, then wrap it once.
    services_patch: dict[str, Any] = {}
    for name in service_names:
        self.deployment_spec.set_service_env_var(
            name, "TEST_ROLLING_UPDATE_TRIGGER", secrets.token_hex(8)
        )
        services_patch[name] = {
            "envs": self.deployment_spec.get_service_env_vars(name)
        }
    patch_body: dict[str, Any] = {"spec": {"services": services_patch}}

    assert self._custom_api is not None, "Kubernetes API not initialized"
    try:
        await self._custom_api.patch_namespaced_custom_object(
            group="nvidia.com",
            version="v1alpha1",
            namespace=self.namespace,
            plural="dynamographdeployments",
            name=self._deployment_name,
            body=patch_body,
            _content_type="application/merge-patch+json",
        )
    except exceptions.ApiException as e:
        self._logger.info(
            f"Failed to patch deployment {self._deployment_name}: {e}"
        )
        raise
async def get_pod_names(self, service_names: list[str] | None = None) -> list[str]:
if not service_names:
service_names = [service.name for service in self.deployment_spec.services]
pod_names: list[str] = []
for service_name in service_names:
label_selector = (
f"nvidia.com/selector={self._deployment_name}-{service_name.lower()}"
)
assert self._core_api is not None, "Kubernetes API not initialized"
pods: client.V1PodList = await self._core_api.list_namespaced_pod(
self.namespace, label_selector=label_selector
)
for pod in pods.items:
pod_names.append(pod.metadata.name)
return pod_names
def get_processes(self, pod: Pod) -> list[PodProcess]:
"""Get list of processes in the given pod"""
result = pod.exec(["ps", "-aux"])
lines = result.stdout.decode().splitlines()
......@@ -646,38 +788,34 @@ class ManagedDeployment:
service_name = ""
full_service_name = f"{self._deployment_name}-{service_name.lower()}"
return kr8s_Service.get(full_service_name, namespace=self.namespace)
return Service.get(full_service_name, namespace=self.namespace)
def get_pods(self, service_names: list[str] | None = None) -> dict[str, list[Pod]]:
    """Look up the pods of the requested services via their selector label.

    Args:
        service_names: Services to query. When empty or None, every service
            in ``self.deployment_spec.services`` is queried. A bare string is
            treated as a single service name, so callers written against the
            earlier single-name form of this method keep working.

    Returns:
        Mapping of service name to the pods ``kr8s.get`` returned for it.
    """
    if isinstance(service_names, str):
        # Iterating a str would otherwise query one "service" per character.
        service_names = [service_names]
    if not service_names:
        service_names = [service.name for service in self.deployment_spec.services]

    result: dict[str, list[Pod]] = {}
    for service_name in service_names:
        # List pods for this service using the selector label
        # nvidia.com/selector: deployment-name-service
        label_selector = (
            f"nvidia.com/selector={self._deployment_name}-{service_name.lower()}"
        )

        pods: list[Pod] = []
        for pod in kr8s.get(
            "pods", namespace=self.namespace, label_selector=label_selector
        ):
            pods.append(pod)  # type: ignore[arg-type]

        result[service_name] = pods

    return result
def get_pod_logs(self, service, pod, suffix=""):
directory = os.path.join(self.log_dir, service)
def get_pod_manifest_logs_metrics(self, service_name: str, pod: Pod, suffix=""):
directory = os.path.join(self.log_dir, service_name)
os.makedirs(directory, exist_ok=True)
try:
......@@ -699,16 +837,20 @@ class ManagedDeployment:
except Exception as e:
self._logger.debug(e)
self._get_pod_metrics(pod, service, suffix)
self._get_pod_metrics(pod, service_name, suffix)
def _get_service_logs(self, service_name=None, suffix=""):
service_pods = self.get_pods(service_name)
service_names = None
if service_name:
service_names = [service_name]
service_pods = self.get_pods(service_names)
for service, pods in service_pods.items():
for i, pod in enumerate(pods):
self.get_pod_logs(service, pod, suffix)
for pod in pods:
self.get_pod_manifest_logs_metrics(service, pod, suffix)
def _get_pod_metrics(self, pod, service_name, suffix=""):
def _get_pod_metrics(self, pod: Pod, service_name: str, suffix=""):
directory = os.path.join(self.log_dir, service_name)
os.makedirs(directory, exist_ok=True)
port = None
......@@ -757,11 +899,13 @@ class ManagedDeployment:
plural="dynamographdeployments",
name=self._deployment_name,
)
except client.exceptions.ApiException as e:
except exceptions.ApiException as e:
if e.status != 404: # Ignore if already deleted
raise
def port_forward(self, pod, remote_port, max_connection_attempts=3):
def port_forward(
self, pod: Pod, remote_port: int, max_connection_attempts: int = 3
):
"""Attempt to connect to a pod and return the port-forward object on success.
Note: Port forwards run in background threads. When pods are terminated,
......@@ -866,9 +1010,13 @@ class ManagedDeployment:
self._deployment_name = self.deployment_spec.name
logging.getLogger("httpx").setLevel(logging.WARNING)
await self._init_kubernetes()
await self._delete_deployment()
await self._restart_etcd()
await self._restart_nats()
# Run delete deployment and service restarts in parallel
tasks = [self._delete_deployment()]
if not self.skip_service_restart:
tasks.extend([self._restart_etcd(), self._restart_nats()])
await asyncio.gather(*tasks)
await self._create_deployment()
await self._wait_for_ready()
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment