Unverified Commit a473402f authored by Thomas Montfort's avatar Thomas Montfort Committed by GitHub
Browse files

feat: fault tolerance rolling upgrade test scenarios (#4558)

parent 5585f803
...@@ -18,12 +18,14 @@ ...@@ -18,12 +18,14 @@
import json import json
import logging import logging
import os import os
import signal
import subprocess import subprocess
import time import time
from pathlib import Path from pathlib import Path
from typing import Any, Dict, List, Optional, Tuple from typing import Any, Dict, List, Optional, Tuple
import requests import requests
from kr8s.objects import Pod
from tests.utils.managed_deployment import ManagedDeployment from tests.utils.managed_deployment import ManagedDeployment
...@@ -44,7 +46,7 @@ def get_frontend_port( ...@@ -44,7 +46,7 @@ def get_frontend_port(
deployment_spec: Any, deployment_spec: Any,
pod_ports: Dict[str, Any], pod_ports: Dict[str, Any],
logger: logging.Logger, logger: logging.Logger,
) -> Tuple[Optional[str], Optional[int], Optional[str]]: ) -> Tuple[Optional[str], Optional[int], Optional[Pod]]:
""" """
Select a frontend pod using round-robin and setup port forwarding. Select a frontend pod using round-robin and setup port forwarding.
...@@ -60,7 +62,7 @@ def get_frontend_port( ...@@ -60,7 +62,7 @@ def get_frontend_port(
Returns: Returns:
Tuple of (pod_name, local_port, pod_instance) or (None, None, None) if failed Tuple of (pod_name, local_port, pod_instance) or (None, None, None) if failed
""" """
pods = managed_deployment.get_pods(managed_deployment.frontend_service_name) pods = managed_deployment.get_pods([managed_deployment.frontend_service_name])
port = 0 port = 0
pod_name = None pod_name = None
...@@ -270,6 +272,7 @@ def run_aiperf( ...@@ -270,6 +272,7 @@ def run_aiperf(
logger: logging.Logger, logger: logging.Logger,
max_retries: int = 1, max_retries: int = 1,
retry_delay: float = 1, retry_delay: float = 1,
continuous_load: bool = False,
) -> bool: ) -> bool:
""" """
Execute AI-Perf with specified parameters. Execute AI-Perf with specified parameters.
...@@ -280,13 +283,14 @@ def run_aiperf( ...@@ -280,13 +283,14 @@ def run_aiperf(
model: Model name model: Model name
pod_name: Selected pod name for logging pod_name: Selected pod name for logging
port: Local port number port: Local port number
requests_per_client: Number of requests to send requests_per_client: Number of requests to send (used if continuous load not enabled)
input_token_length: Input token count input_token_length: Input token count
output_token_length: Output token count output_token_length: Output token count
output_dir: Directory for AI-Perf artifacts output_dir: Directory for AI-Perf artifacts
logger: Logger instance logger: Logger instance
max_retries: Maximum number of retry attempts (default: 1) max_retries: Maximum number of retry attempts (default: 1)
retry_delay: Delay in seconds between retries (default: 1) retry_delay: Delay in seconds between retries (default: 1)
continuous_load: If True, use continuous load instead of fixed request count
Returns: Returns:
True if successful, False otherwise True if successful, False otherwise
...@@ -315,8 +319,6 @@ def run_aiperf( ...@@ -315,8 +319,6 @@ def run_aiperf(
# Enable streaming for TTFT and ITL metrics # Enable streaming for TTFT and ITL metrics
"--streaming", "--streaming",
# Request parameters # Request parameters
"--request-count",
str(requests_per_client), # Required: how many requests
"--concurrency", "--concurrency",
"1", # Optional: we set to 1 for sequential "1", # Optional: we set to 1 for sequential
# Token configuration # Token configuration
...@@ -338,8 +340,13 @@ def run_aiperf( ...@@ -338,8 +340,13 @@ def run_aiperf(
"100", # For reproducible results "100", # For reproducible results
] ]
# Calculate timeout (same as legacy would for all requests) if continuous_load:
timeout = max(requests_per_client * 2 + 60, 300) # At least 5 minutes cmd.extend(["--benchmark-duration", "1800"]) # 30 minutes for continuous load
logger.info("Using continuous load with duration: 30 minutes")
timeout = 1860 # 31 minutes default for duration-based tests (30 minutes + 1 minute buffer)
else:
cmd.extend(["--request-count", str(requests_per_client)])
timeout = max(requests_per_client * 2 + 60, 300) # At least 5 minutes
# Log execution # Log execution
logger.info(f"Starting AI-Perf for Pod {pod_name} Local Port {port}") logger.info(f"Starting AI-Perf for Pod {pod_name} Local Port {port}")
...@@ -354,15 +361,19 @@ def run_aiperf( ...@@ -354,15 +361,19 @@ def run_aiperf(
logger.info(f"Command: {' '.join(cmd)}") logger.info(f"Command: {' '.join(cmd)}")
# Retry logic for fault tolerance - retry FULL request count until success # Retry logic for fault tolerance - retry FULL request count until success
# Note: For continuous load, we only run once and expect SIGINT to stop it
max_attempts = max_retries if max_retries > 0 else 1 max_attempts = 1 if continuous_load else (max_retries if max_retries > 0 else 1)
success = False success = False
all_results = []
for attempt in range(max_attempts): for attempt in range(max_attempts):
logger.info( if continuous_load:
f"AI-Perf attempt {attempt + 1}/{max_attempts} with {requests_per_client} requests" logger.info(
) "AI-Perf continuous load (will run until interrupted by SIGINT)"
)
else:
logger.info(
f"AI-Perf attempt {attempt + 1}/{max_attempts} with {requests_per_client} requests"
)
# Update output directory for this attempt # Update output directory for this attempt
attempt_dir = output_dir / f"attempt_{attempt}" attempt_dir = output_dir / f"attempt_{attempt}"
...@@ -374,13 +385,7 @@ def run_aiperf( ...@@ -374,13 +385,7 @@ def run_aiperf(
cmd_attempt[artifact_dir_idx] = str(attempt_dir) cmd_attempt[artifact_dir_idx] = str(attempt_dir)
try: try:
result = subprocess.run( result = run_aiperf_with_signal_handling(cmd_attempt, logger, timeout)
cmd_attempt,
capture_output=True,
text=True,
timeout=timeout,
stdin=subprocess.DEVNULL, # Prevent stdin reading which can cause process suspension
)
# Save logs for this attempt # Save logs for this attempt
with open(attempt_dir / "genai_perf.log", "w") as f: with open(attempt_dir / "genai_perf.log", "w") as f:
...@@ -389,15 +394,6 @@ def run_aiperf( ...@@ -389,15 +394,6 @@ def run_aiperf(
f.write("\n\n=== STDERR ===\n") f.write("\n\n=== STDERR ===\n")
f.write(result.stderr) f.write(result.stderr)
all_results.append(
{
"attempt": attempt + 1,
"returncode": result.returncode,
"stdout": result.stdout,
"stderr": result.stderr,
}
)
if result.returncode == 0: if result.returncode == 0:
# AI-Perf returns 0 even if all requests failed, so we need to check the output # AI-Perf returns 0 even if all requests failed, so we need to check the output
json_path = attempt_dir / "profile_export_aiperf.json" json_path = attempt_dir / "profile_export_aiperf.json"
...@@ -412,6 +408,19 @@ def run_aiperf( ...@@ -412,6 +408,19 @@ def run_aiperf(
) )
if success: if success:
break # Success - exit the retry loop break # Success - exit the retry loop
## TODO: bug with aiperf git+https://github.com/ai-dynamo/aiperf.git@4d3fa29403c8f75da22a14f1f7b3aeb27db9288f
## where sending a SIGINT on Mac can sometimes have an error code of -9 (SIGABRT) which results in profile_export_aiperf.json not being created
elif result.returncode == -9 and continuous_load:
logger.warning(
f"""
Attempt {attempt + 1} failed with return code {result.returncode}
This is a known bug with aiperf on Mac where sending a SIGINT can sometimes have an error code of -9 (SIGABRT)
which results in profile_export_aiperf.json not being created
"""
)
logger.debug(
f"Stderr: {result.stderr[:500] if result.stderr else 'No stderr'}"
)
else: else:
logger.warning( logger.warning(
f"Attempt {attempt + 1} failed with return code {result.returncode}" f"Attempt {attempt + 1} failed with return code {result.returncode}"
...@@ -421,22 +430,84 @@ def run_aiperf( ...@@ -421,22 +430,84 @@ def run_aiperf(
) )
except Exception as e: except Exception as e:
logger.error(f"Error in attempt {attempt + 1}: {str(e)}") logger.error(f"Error in attempt {attempt + 1}: {str(e)}")
all_results.append({"attempt": attempt + 1, "error": str(e)})
# Sleep before next attempt (if not the last attempt) # Sleep before next attempt (if not the last attempt and not continuous load)
if not success and attempt < max_attempts - 1: if not success and attempt < max_attempts - 1 and not continuous_load:
time.sleep(retry_delay) time.sleep(retry_delay)
if success: if success and not continuous_load:
logger.info( logger.info(
f"AI-Perf successfully completed all {requests_per_client} requests for {pod_name}" f"AI-Perf successfully completed all {requests_per_client} requests for {pod_name}"
) )
elif success and continuous_load:
logger.info(
f"AI-Perf sustained continuous load for {pod_name} and existed succesfully"
)
else: else:
logger.error(f"AI-Perf failed all {max_attempts} attempts for {pod_name}") logger.error(f"AI-Perf failed all {max_attempts} attempts for {pod_name}")
return success return success
# TODO: use file redirection and wait() instead of pipes and communicate
def run_aiperf_with_signal_handling(
cmd_attempt: List[str],
logger: logging.Logger,
timeout: int,
) -> subprocess.CompletedProcess:
"""
Run aiperf with signal handling for graceful shutdown.
Handles SIGINT and SIGTERM forwarding and timeout when running with subprocess.Popen.
This ensures that Ctrl-C (SIGINT) and graceful termination signals (SIGTERM)
are properly forwarded to the subprocess so it can clean up gracefully and write results files.
"""
proc = subprocess.Popen(
cmd_attempt,
stdout=subprocess.PIPE,
stderr=subprocess.PIPE,
text=True,
stdin=subprocess.DEVNULL,
)
def signal_handler(signum, frame):
signal_names = {
signal.SIGINT: "SIGINT",
signal.SIGTERM: "SIGTERM",
}
signal_name = signal_names.get(signum, f"signal {signum}")
logger.info(f"Received {signal_name}, forwarding to aiperf subprocess")
try:
proc.send_signal(signum)
except ProcessLookupError:
pass # Process already terminated
signal.signal(signal.SIGINT, signal_handler)
signal.signal(signal.SIGTERM, signal_handler)
try:
stdout, stderr = proc.communicate(timeout=timeout)
returncode = proc.returncode
except subprocess.TimeoutExpired:
logger.warning(f"AI-Perf subprocess timed out after {timeout}s")
proc.kill()
stdout, stderr = proc.communicate()
returncode = proc.returncode
except KeyboardInterrupt:
logger.info("Received KeyboardInterrupt, sending SIGINT to aiperf subprocess")
proc.send_signal(signal.SIGINT)
try:
stdout, stderr = proc.communicate(timeout=30) # Give it time to clean up
returncode = proc.returncode
except subprocess.TimeoutExpired:
logger.warning("Subprocess didn't terminate gracefully, killing it")
proc.kill()
stdout, stderr = proc.communicate()
returncode = proc.returncode
return subprocess.CompletedProcess(cmd_attempt, returncode, stdout, stderr)
def log_summary_metrics( def log_summary_metrics(
output_dir: Path, logger: logging.Logger, pod_name: str, port: int output_dir: Path, logger: logging.Logger, pod_name: str, port: int
) -> None: ) -> None:
...@@ -513,6 +584,7 @@ def client( ...@@ -513,6 +584,7 @@ def client(
output_token_length: int, output_token_length: int,
max_retries: int, max_retries: int,
retry_delay: float = 1, retry_delay: float = 1,
continuous_load: bool = False,
): ):
""" """
Generate load using AI-Perf for fault tolerance testing. Generate load using AI-Perf for fault tolerance testing.
...@@ -527,11 +599,12 @@ def client( ...@@ -527,11 +599,12 @@ def client(
model: Model name model: Model name
log_dir: Directory for output logs and AI-Perf artifacts log_dir: Directory for output logs and AI-Perf artifacts
index: Client index used for round-robin pod selection index: Client index used for round-robin pod selection
requests_per_client: Number of requests to generate requests_per_client: Number of requests to generate (used if continuous load not enabled)
input_token_length: Number of input tokens per request input_token_length: Number of input tokens per request
output_token_length: Number of output tokens per request output_token_length: Number of output tokens per request
max_retries: Maximum retry attempts for AI-Perf execution max_retries: Maximum retry attempts for AI-Perf execution
retry_delay: Delay in seconds between retry attempts retry_delay: Delay in seconds between retry attempts
continuous_load: If True, use continuous load instead of fixed request count
""" """
logger = logging.getLogger(f"CLIENT: {index}") logger = logging.getLogger(f"CLIENT: {index}")
logging.getLogger("httpx").setLevel(logging.WARNING) logging.getLogger("httpx").setLevel(logging.WARNING)
...@@ -578,6 +651,7 @@ def client( ...@@ -578,6 +651,7 @@ def client(
logger=logger, logger=logger,
max_retries=max_retries, max_retries=max_retries,
retry_delay=retry_delay, retry_delay=retry_delay,
continuous_load=continuous_load,
) )
if not success: if not success:
......
...@@ -42,6 +42,7 @@ def get_client_function(client_type: str) -> Callable: ...@@ -42,6 +42,7 @@ def get_client_function(client_type: str) -> Callable:
output_token_length, output_token_length,
max_retries, max_retries,
retry_delay_or_rate, # Differs between implementations retry_delay_or_rate, # Differs between implementations
continuous_load,
) )
Raises: Raises:
......
...@@ -35,6 +35,13 @@ def pytest_addoption(parser): ...@@ -35,6 +35,13 @@ def pytest_addoption(parser):
help="Include tests that require custom builds (e.g., MoE models). " help="Include tests that require custom builds (e.g., MoE models). "
"By default, these tests are excluded.", "By default, these tests are excluded.",
) )
parser.addoption(
"--skip-service-restart",
action="store_true",
default=False,
help="Skip restarting NATS and etcd services before deployment. "
"By default, these services are restarted.",
)
def pytest_generate_tests(metafunc): def pytest_generate_tests(metafunc):
...@@ -109,3 +116,9 @@ def namespace(request): ...@@ -109,3 +116,9 @@ def namespace(request):
def client_type(request): def client_type(request):
"""Get client type from command line or use scenario default.""" """Get client type from command line or use scenario default."""
return request.config.getoption("--client-type") return request.config.getoption("--client-type")
@pytest.fixture
def skip_service_restart(request):
"""Get skip restart services flag from command line."""
return request.config.getoption("--skip-service-restart")
...@@ -192,6 +192,7 @@ def client( ...@@ -192,6 +192,7 @@ def client(
max_retries, max_retries,
max_request_rate, max_request_rate,
retry_delay=1, retry_delay=1,
continuous_load=False,
): ):
"""Legacy custom client for fault tolerance testing. """Legacy custom client for fault tolerance testing.
...@@ -211,7 +212,11 @@ def client( ...@@ -211,7 +212,11 @@ def client(
max_retries: Maximum retry attempts per request max_retries: Maximum retry attempts per request
max_request_rate: Maximum requests per second (for rate limiting) max_request_rate: Maximum requests per second (for rate limiting)
retry_delay: Delay in seconds between retries retry_delay: Delay in seconds between retries
continuous_load: If True, use continuous load instead of fixed request count
""" """
if continuous_load:
raise ValueError("Continuous load is not supported for legacy client")
logger = logging.getLogger(f"CLIENT: {index}") logger = logging.getLogger(f"CLIENT: {index}")
logging.getLogger("httpx").setLevel(logging.WARNING) logging.getLogger("httpx").setLevel(logging.WARNING)
...@@ -228,7 +233,7 @@ def client( ...@@ -228,7 +233,7 @@ def client(
for i in range(requests_per_client): for i in range(requests_per_client):
# Get available pods # Get available pods
pods = managed_deployment.get_pods( pods = managed_deployment.get_pods(
managed_deployment.frontend_service_name [managed_deployment.frontend_service_name]
) )
port = 0 port = 0
pod_name = None pod_name = None
......
...@@ -341,6 +341,7 @@ def parse_aiperf_client_results(log_dir: str) -> Dict[str, Any]: ...@@ -341,6 +341,7 @@ def parse_aiperf_client_results(log_dir: str) -> Dict[str, Any]:
Returns: Returns:
Dictionary with aggregated metrics and client count Dictionary with aggregated metrics and client count
""" """
logger = logging.getLogger(__name__)
all_metrics: Dict[str, Any] = { all_metrics: Dict[str, Any] = {
"total_requests": 0, "total_requests": 0,
"successful_requests": 0, "successful_requests": 0,
...@@ -382,22 +383,28 @@ def parse_aiperf_client_results(log_dir: str) -> Dict[str, Any]: ...@@ -382,22 +383,28 @@ def parse_aiperf_client_results(log_dir: str) -> Dict[str, Any]:
with open(profile_json) as f: with open(profile_json) as f:
client_metrics = json.load(f) client_metrics = json.load(f)
# AI-Perf format has "records" dictionary at the top level # AI-Perf format can have "records" dictionary or metrics at top level
# Try records first (older format), then fall back to top level (newer format)
records = client_metrics.get("records", {}) records = client_metrics.get("records", {})
# Extract successful request count # Extract successful request count - check both locations
request_count_record = records.get("request_count", {}) request_count_record = records.get(
"request_count"
) or client_metrics.get("request_count", {})
successful_count = ( successful_count = (
int(request_count_record.get("avg", 0)) int(request_count_record.get("avg", 0))
if request_count_record if request_count_record and isinstance(request_count_record, dict)
else 0 else 0
) )
# Extract error request count # Extract error request count - check both locations
error_request_count_record = records.get("error_request_count", {}) error_request_count_record = records.get(
"error_request_count"
) or client_metrics.get("error_request_count", {})
error_request_count = ( error_request_count = (
int(error_request_count_record.get("avg", 0)) int(error_request_count_record.get("avg", 0))
if error_request_count_record if error_request_count_record
and isinstance(error_request_count_record, dict)
else 0 else 0
) )
...@@ -418,9 +425,17 @@ def parse_aiperf_client_results(log_dir: str) -> Dict[str, Any]: ...@@ -418,9 +425,17 @@ def parse_aiperf_client_results(log_dir: str) -> Dict[str, Any]:
# Sum up actual error counts from each error type # Sum up actual error counts from each error type
error_count = sum(error.get("count", 0) for error in error_summary) error_count = sum(error.get("count", 0) for error in error_summary)
# Check if test was cancelled # Log if test was cancelled (expected for continuous load mode)
if client_metrics.get("was_cancelled", False): if client_metrics.get("was_cancelled", False):
error_count = request_count # Mark all as failed if cancelled logger.info(
f"AI-Perf client {item} was cancelled - anticipated if running with continuous load mode. "
f"Completed {request_count} requests before cancellation."
)
# Note: If test was cancelled (was_cancelled=True), we still count the requests
# that were successfully completed before cancellation. The request_count
# represents successful requests, and error_count represents actual errors.
# We don't mark cancelled requests as failed - they were just interrupted.
# Validate data consistency # Validate data consistency
if request_count < error_count: if request_count < error_count:
......
...@@ -13,14 +13,17 @@ ...@@ -13,14 +13,17 @@
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
import asyncio
import logging
import re import re
from abc import ABC, abstractmethod
from dataclasses import dataclass, field from dataclasses import dataclass, field
from enum import Enum, auto from enum import Enum, auto
from typing import TYPE_CHECKING, Dict, List, Optional, Pattern from typing import TYPE_CHECKING, Dict, List, Optional, Pattern
from typing_extensions import TypedDict from typing_extensions import Required, TypedDict
from tests.utils.managed_deployment import DeploymentSpec from tests.utils.managed_deployment import DeploymentSpec, ManagedDeployment
if TYPE_CHECKING: if TYPE_CHECKING:
from tests.fault_tolerance.deploy.base_checker import BaseChecker from tests.fault_tolerance.deploy.base_checker import BaseChecker
...@@ -54,8 +57,8 @@ class DeploymentInfo(TypedDict, total=False): ...@@ -54,8 +57,8 @@ class DeploymentInfo(TypedDict, total=False):
is_moe: Optional flag indicating if this is a Mixture-of-Experts model is_moe: Optional flag indicating if this is a Mixture-of-Experts model
""" """
spec: DeploymentSpec spec: Required[DeploymentSpec]
backend: str backend: Required[str]
model: str model: str
is_moe: bool is_moe: bool
...@@ -155,14 +158,144 @@ class Load: ...@@ -155,14 +158,144 @@ class Load:
overflow_request_count: int = 15 # Number of overflow requests overflow_request_count: int = 15 # Number of overflow requests
normal_request_count: int = 15 # Number of normal requests after overflow normal_request_count: int = 15 # Number of normal requests after overflow
continuous_load: bool = (
False # If True, use continuous load instead of fixed request count
)
@dataclass @dataclass
class Failure: class Failure(ABC):
"""Base class for all failure types."""
# time to wait in seconds before the failure is injected
time: int time: int
pod_name: str
command: str # names of DGD services to inject the failure into the corresponding pods for
signal: str = "SIGINT" service_names: list[str]
replicas: int = 1
@abstractmethod
async def execute(
self, deployment: ManagedDeployment, logger: logging.Logger
) -> list[str]:
"""Execute the failure injection.
Args:
deployment: The managed deployment to inject the failure into
logger: Logger instance for logging failure injection
Returns: List of affected pod names
"""
pass
@abstractmethod
def get_failure_key(self) -> str:
"""Get the failure key for the failure."""
pass
@dataclass
class RollingUpgradeFailure(Failure):
"""Failure type for triggering rolling upgrades."""
async def execute(
self, deployment: ManagedDeployment, logger: logging.Logger
) -> list[str]:
"""Execute rolling upgrade failure injection."""
await deployment.trigger_rolling_upgrade(self.service_names)
# Need to wait for the deployment to be unready so we know the rolling upgrade has started
await deployment.wait_for_unready(timeout=60, log_interval=10)
await deployment._wait_for_ready(timeout=1800) # 30 minute timeout
await asyncio.sleep(
self.time
) # have some requests processed after the rolling upgrade has completed
return await deployment.get_pod_names(self.service_names)
def get_failure_key(self) -> str:
"""Get the failure key for the rolling upgrade failure."""
return f"rolling_upgrade:{','.join(self.service_names)}"
@dataclass
class DeletePodFailure(Failure):
"""Failure type for deleting pods."""
async def execute(
self, deployment: ManagedDeployment, logger: logging.Logger
) -> list[str]:
"""Execute pod deletion failure injection."""
service_pod_dict = deployment.get_pods(self.service_names)
pod_names: list[str] = []
for service_name, pods in service_pod_dict.items():
for pod in pods:
deployment.get_pod_manifest_logs_metrics(
service_name, pod, ".before_delete"
)
pod.delete(force=True) # force means no graceful termination
pod_names.append(pod.name)
return pod_names
def get_failure_key(self) -> str:
"""Get the failure key for the delete pod failure."""
return f"delete_pod:{','.join(self.service_names)}"
class TerminateProcessFailure(Failure):
"""Failure type for terminating specific processes by name."""
def __init__(
self,
time: int,
service_names: list[str],
signal: str = "SIGINT",
process_name: str = "",
):
"""Initialize TerminateProcessFailure.
Args:
time: Time to wait in seconds before the failure is injected
service_names: Names of DGD services to inject the failure into
signal: Signal to send (default: "SIGINT")
process_name: Name of the process to terminate (required)
end_condition: End condition for failure (e.g., "dgd_ready")
"""
super().__init__(
time=time,
service_names=service_names,
)
if not process_name or not signal:
raise ValueError(
"process_name and signal are required for TerminateProcessFailure"
)
self.process_name = process_name
self.signal = signal
async def execute(
self, deployment: ManagedDeployment, logger: logging.Logger
) -> list[str]:
"""Execute process termination failure injection."""
service_pod_dict = deployment.get_pods(self.service_names)
pod_names: list[str] = []
for service_name, pods in service_pod_dict.items():
for pod in pods:
processes = deployment.get_processes(pod)
for process in processes:
if self.process_name in process.command:
logger.info(
f"Terminating {service_name} pod {pod} Pid {process.pid} Command {process.command}"
)
process.kill(self.signal)
pod_names.append(pod.name)
return pod_names
def get_failure_key(self) -> str:
"""Get the failure key for the terminate process failure."""
return f"terminate_process:{','.join(self.service_names)}:{self.process_name}:{self.signal}"
@dataclass @dataclass
...@@ -182,13 +315,25 @@ class TokenOverflowFailure(Failure): ...@@ -182,13 +315,25 @@ class TokenOverflowFailure(Failure):
): ):
super().__init__( super().__init__(
time=time, time=time,
pod_name="Client", service_names=["Client"],
command="token_overflow",
) )
self.max_seq_len = max_seq_len self.max_seq_len = max_seq_len
self.overflow_multiplier = overflow_multiplier self.overflow_multiplier = overflow_multiplier
self.overflow_token_count = int(max_seq_len * overflow_multiplier) self.overflow_token_count = int(max_seq_len * overflow_multiplier)
async def execute(
self, deployment: ManagedDeployment, logger: logging.Logger
) -> list[str]:
"""Token overflow is handled client-side, so this is a no-op."""
# The actual overflow is handled by the client configuration
# which uses the input_token_length from the Load config
# This is just a placeholder for the abstract method
return []
def get_failure_key(self) -> str:
"""Get the failure key for the token overflow failure."""
return f"token_overflow:{self.overflow_token_count}"
@dataclass @dataclass
class Scenario: class Scenario:
...@@ -206,7 +351,7 @@ class Scenario: ...@@ -206,7 +351,7 @@ class Scenario:
# Helper functions to create deployment specs # Helper functions to create deployment specs
def _create_deployment_spec(backend: str, yaml_path: str) -> DeploymentInfo: def _create_deployment_info(backend: str, yaml_path: str) -> DeploymentInfo:
"""Create a deployment spec with backend information. """Create a deployment spec with backend information.
Args: Args:
...@@ -240,7 +385,9 @@ def _set_replicas(deployment_spec, backend, deploy_type, replicas): ...@@ -240,7 +385,9 @@ def _set_replicas(deployment_spec, backend, deploy_type, replicas):
spec[WORKER_MAP[backend]["prefill"]].replicas = replicas spec[WORKER_MAP[backend]["prefill"]].replicas = replicas
def _set_tensor_parallel(deployment_spec, backend, deploy_type, tp_size): def _set_tensor_parallel(
deployment_spec: DeploymentInfo, backend: str, deploy_type: str, tp_size: int
):
"""Set tensor parallel size for worker components.""" """Set tensor parallel size for worker components."""
spec = deployment_spec["spec"] spec = deployment_spec["spec"]
...@@ -308,7 +455,7 @@ def _create_deployments_for_backend(backend: str) -> Dict[str, DeploymentInfo]: ...@@ -308,7 +455,7 @@ def _create_deployments_for_backend(backend: str) -> Dict[str, DeploymentInfo]:
scenario_name = "-".join(name_parts) scenario_name = "-".join(name_parts)
# Create and configure the deployment # Create and configure the deployment
deployment = _create_deployment_spec(backend, yaml_files[deploy_type]) deployment = _create_deployment_info(backend, yaml_files[deploy_type])
if tp_size > 1: if tp_size > 1:
_set_tensor_parallel(deployment, backend, deploy_type, tp_size) _set_tensor_parallel(deployment, backend, deploy_type, tp_size)
if dp_replicas > 1: if dp_replicas > 1:
...@@ -397,34 +544,69 @@ def _create_backend_failures(backend, deploy_type="disagg"): ...@@ -397,34 +544,69 @@ def _create_backend_failures(backend, deploy_type="disagg"):
process_name = f"dynamo.{backend}" process_name = f"dynamo.{backend}"
failures = { failures = {
"frontend": [Failure(30, "Frontend", "dynamo.frontend")], "frontend": [
"frontend_pod": [Failure(30, "Frontend", "delete_pod")], TerminateProcessFailure(
"decode_worker": [Failure(30, decode_worker, process_name, "SIGKILL")], 30, ["Frontend"], "SIGINT", process_name="dynamo.frontend"
"decode_worker_pod": [Failure(30, decode_worker, "delete_pod")], )
"prefill_worker": [Failure(30, prefill_worker, process_name, "SIGKILL")], ],
"prefill_worker_pod": [Failure(30, prefill_worker, "delete_pod")], "frontend_pod": [DeletePodFailure(30, ["Frontend"])],
"decode_worker": [
TerminateProcessFailure(
30, [decode_worker], "SIGKILL", process_name=process_name
)
],
"decode_worker_pod": [DeletePodFailure(30, [decode_worker])],
"prefill_worker": [
TerminateProcessFailure(
30, [prefill_worker], "SIGKILL", process_name=process_name
)
],
"prefill_worker_pod": [DeletePodFailure(30, [prefill_worker])],
"none": [], "none": [],
} }
if backend == "vllm": if backend == "vllm":
failures["vllm_decode_engine_core"] = [ failures["vllm_decode_engine_core"] = [
Failure(30, decode_worker, "VLLM::EngineCore", "SIGKILL") TerminateProcessFailure(
30, [decode_worker], "SIGKILL", process_name="VLLM::EngineCore"
)
] ]
failures["vllm_prefill_engine_core"] = [ failures["vllm_prefill_engine_core"] = [
Failure(30, prefill_worker, "VLLM::EngineCore", "SIGKILL") TerminateProcessFailure(
30, [prefill_worker], "SIGKILL", process_name="VLLM::EngineCore"
)
] ]
elif backend == "sglang": elif backend == "sglang":
failures["sglang_decode_scheduler"] = [ failures["sglang_decode_scheduler"] = [
Failure(30, decode_worker, "sglang::scheduler", "SIGKILL") TerminateProcessFailure(
30, [decode_worker], "SIGKILL", process_name="sglang::scheduler"
)
] ]
failures["sglang_decode_detokenizer"] = [ failures["sglang_decode_detokenizer"] = [
Failure(30, decode_worker, "sglang::detokenizer", "SIGKILL") TerminateProcessFailure(
30, [decode_worker], "SIGKILL", process_name="sglang::detokenizer"
)
] ]
failures["sglang_prefill_scheduler"] = [ failures["sglang_prefill_scheduler"] = [
Failure(30, prefill_worker, "sglang::scheduler", "SIGKILL") TerminateProcessFailure(
30, [prefill_worker], "SIGKILL", process_name="sglang::scheduler"
)
] ]
failures["sglang_prefill_detokenizer"] = [ failures["sglang_prefill_detokenizer"] = [
Failure(30, prefill_worker, "sglang::detokenizer", "SIGKILL") TerminateProcessFailure(
30, [prefill_worker], "SIGKILL", process_name="sglang::detokenizer"
)
]
elif backend == "trtllm":
failures["trtllm_decode_engine_core"] = [
TerminateProcessFailure(
30, [decode_worker], "SIGKILL", process_name="TRTLLM::EngineCore"
)
]
failures["trtllm_prefill_engine_core"] = [
TerminateProcessFailure(
30, [prefill_worker], "SIGKILL", process_name="TRTLLM::EngineCore"
)
] ]
return failures return failures
...@@ -533,7 +715,7 @@ model = None ...@@ -533,7 +715,7 @@ model = None
# Populate Scenarios # Populate Scenarios
scenarios = {} scenarios: dict[str, Scenario] = {}
# Map of backend+deploy_type to failure definitions # Map of backend+deploy_type to failure definitions
backend_failure_map = {} backend_failure_map = {}
...@@ -729,5 +911,59 @@ def add_token_overflow_scenarios(): ...@@ -729,5 +911,59 @@ def add_token_overflow_scenarios():
) )
def add_rolling_upgrade_scenarios():
for backend in ["vllm", "sglang", "trtllm"]:
for worker_mode in ["agg", "disagg"]:
yaml_files = {
"agg": f"examples/backends/{backend}/deploy/agg.yaml",
"disagg": f"examples/backends/{backend}/deploy/disagg.yaml",
}
deployment_info = _create_deployment_info(backend, yaml_files[worker_mode])
deployment_spec: DeploymentSpec = deployment_info["spec"]
service_names: list[str] = []
# setting replicas to 2 so we have availability of 1 replica at a time
if worker_mode == "agg" and backend == "trtllm":
service_names.append(WORKER_MAP[backend]["decode_agg"])
else:
service_names.append(WORKER_MAP[backend]["decode"])
if worker_mode == "disagg":
service_names.append(WORKER_MAP[backend]["prefill"])
for service_name in service_names:
deployment_spec.set_service_replicas(service_name, 2)
load = Load(
clients=10,
input_token_length=100,
output_token_length=100,
max_retries=1,
client_type="aiperf",
max_request_rate=1.0,
success_threshold=100.0,
continuous_load=True,
)
scenario_name = f"{backend}-{worker_mode}-rolling-upgrade"
model = "Qwen/Qwen3-0.6B"
failure = RollingUpgradeFailure(
time=30,
service_names=service_names,
)
scenarios[scenario_name] = Scenario(
deployment=deployment_info["spec"],
load=load,
failures=[failure],
model=model,
backend=backend,
)
# Add the token overflow scenarios # Add the token overflow scenarios
add_token_overflow_scenarios() add_token_overflow_scenarios()
# Add the rolling upgrade scenarios
add_rolling_upgrade_scenarios()
# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. # SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0 # SPDX-License-Identifier: Apache-2.0
import asyncio
import logging import logging
import multiprocessing import multiprocessing
import os
import re import re
import time import signal
from contextlib import contextmanager from contextlib import contextmanager
from typing import Any from multiprocessing.context import SpawnProcess
from typing import Any, Optional
import pytest import pytest
...@@ -17,11 +20,12 @@ from tests.fault_tolerance.deploy.parse_results import process_overflow_recovery ...@@ -17,11 +20,12 @@ from tests.fault_tolerance.deploy.parse_results import process_overflow_recovery
from tests.fault_tolerance.deploy.scenarios import ( from tests.fault_tolerance.deploy.scenarios import (
OVERFLOW_SUFFIX, OVERFLOW_SUFFIX,
RECOVERY_SUFFIX, RECOVERY_SUFFIX,
Failure,
Load, Load,
TokenOverflowFailure, Scenario,
scenarios, scenarios,
) )
from tests.utils.managed_deployment import ManagedDeployment from tests.utils.managed_deployment import DeploymentSpec, ManagedDeployment
@pytest.fixture @pytest.fixture
...@@ -55,18 +59,18 @@ def scenario(scenario_name, client_type): ...@@ -55,18 +59,18 @@ def scenario(scenario_name, client_type):
@contextmanager @contextmanager
def _clients( def _clients(
logger, logger: logging.Logger,
request, log_dir: str,
deployment_spec, deployment_spec: DeploymentSpec,
namespace, namespace: str,
model, model: str,
load_config: Load, load_config: Load,
): ):
"""Start client processes using factory pattern for client selection. """Start client processes using factory pattern for client selection.
Args: Args:
logger: Logger instance logger: Logger instance
request: Pytest request fixture log_dir: Log directory for output logs and client logs/artifacts
deployment_spec: Deployment specification deployment_spec: Deployment specification
namespace: Kubernetes namespace namespace: Kubernetes namespace
model: Model name to test model: Model name to test
...@@ -79,7 +83,7 @@ def _clients( ...@@ -79,7 +83,7 @@ def _clients(
f"Starting {load_config.clients} clients using '{load_config.client_type}' client" f"Starting {load_config.clients} clients using '{load_config.client_type}' client"
) )
procs = [] procs: list[SpawnProcess] = []
ctx = multiprocessing.get_context("spawn") ctx = multiprocessing.get_context("spawn")
# Determine retry_delay_or_rate based on client type # Determine retry_delay_or_rate based on client type
...@@ -90,6 +94,9 @@ def _clients( ...@@ -90,6 +94,9 @@ def _clients(
# AI-Perf client uses retry_delay between attempts (default 5s) # AI-Perf client uses retry_delay between attempts (default 5s)
retry_delay_or_rate = 5 retry_delay_or_rate = 5
# Check if this is a continuous load test (rolling upgrade scenarios)
continuous_load = getattr(load_config, "continuous_load", False)
# Check if this is a mixed token test (overflow + recovery) # Check if this is a mixed token test (overflow + recovery)
# If mixed_token_test is True, run two phases; otherwise run normally # If mixed_token_test is True, run two phases; otherwise run normally
if hasattr(load_config, "mixed_token_test") and load_config.mixed_token_test: if hasattr(load_config, "mixed_token_test") and load_config.mixed_token_test:
...@@ -108,13 +115,14 @@ def _clients( ...@@ -108,13 +115,14 @@ def _clients(
deployment_spec, deployment_spec,
namespace, namespace,
model, model,
request.node.name + OVERFLOW_SUFFIX, f"{log_dir}{OVERFLOW_SUFFIX}",
i, i,
load_config.overflow_request_count, # 15 overflow requests load_config.overflow_request_count, # 15 overflow requests
load_config.overflow_token_length, # 2x max_seq_len tokens load_config.overflow_token_length, # 2x max_seq_len tokens
load_config.output_token_length, load_config.output_token_length,
load_config.max_retries, load_config.max_retries,
retry_delay_or_rate, retry_delay_or_rate,
continuous_load,
), ),
) )
proc_overflow.start() proc_overflow.start()
...@@ -128,7 +136,7 @@ def _clients( ...@@ -128,7 +136,7 @@ def _clients(
logger.info("Overflow requests completed. Starting recovery phase...") logger.info("Overflow requests completed. Starting recovery phase...")
# Second phase: Send normal requests to test recovery # Second phase: Send normal requests to test recovery
procs_recovery = [] procs_recovery: list[SpawnProcess] = []
for i in range(load_config.clients): for i in range(load_config.clients):
proc_normal = ctx.Process( proc_normal = ctx.Process(
target=client_func, target=client_func,
...@@ -136,7 +144,7 @@ def _clients( ...@@ -136,7 +144,7 @@ def _clients(
deployment_spec, deployment_spec,
namespace, namespace,
model, model,
request.node.name + RECOVERY_SUFFIX, f"{log_dir}{RECOVERY_SUFFIX}",
i, i,
load_config.normal_request_count, # 15 normal requests load_config.normal_request_count, # 15 normal requests
load_config.input_token_length, # Normal token count load_config.input_token_length, # Normal token count
...@@ -161,13 +169,14 @@ def _clients( ...@@ -161,13 +169,14 @@ def _clients(
deployment_spec, deployment_spec,
namespace, namespace,
model, model,
request.node.name, log_dir,
i, i,
load_config.requests_per_client, load_config.requests_per_client,
load_config.input_token_length, load_config.input_token_length,
load_config.output_token_length, load_config.output_token_length,
load_config.max_retries, load_config.max_retries,
retry_delay_or_rate, retry_delay_or_rate,
continuous_load, # Pass continuous_load flag
), ),
) )
) )
...@@ -182,65 +191,50 @@ def _clients( ...@@ -182,65 +191,50 @@ def _clients(
logger.debug(f"{proc} joined") logger.debug(f"{proc} joined")
def _inject_failures(failures, logger, deployment: ManagedDeployment): # noqa: F811 def _terminate_client_processes(
"""Inject failures and return info about affected pods. client_procs: list[SpawnProcess],
logger: logging.Logger,
Returns: ):
Dict mapping failure info to list of affected pod names
Example: {"VllmDecodeWorker:delete_pod": ["pod-abc123", "pod-xyz789"]}
""" """
affected_pods: dict[str, list] = {} Terminate client processes.
"""
for failure in failures: # Send SIGINT to client processes to stop continuous load
time.sleep(failure.time) if client_procs:
logger.info(f"Sending SIGINT to {len(client_procs)} client processes...")
# Handle TokenOverflowFailure differently - it's a client-side injection for proc in client_procs:
if isinstance(failure, TokenOverflowFailure): if proc.is_alive():
# The actual overflow is handled by the client configuration try:
# which uses the input_token_length from the Load config if proc.pid is not None:
# This is just logging for visibility logger.debug(f"Sending SIGINT to client process {proc.pid}")
continue os.kill(proc.pid, signal.SIGINT)
else:
pods = deployment.get_pods(failure.pod_name)[failure.pod_name] raise ValueError(f"Process {proc} has no PID")
except ProcessLookupError:
num_pods = len(pods) logger.debug(f"Process {proc.pid} already terminated")
except Exception as e:
logger.warning(f"Failed to send SIGINT to process {proc.pid}: {e}")
logger.info(
"SIGINT sent to all client processes, waiting for graceful shutdown..."
)
else:
logger.warning("No client processes provided to terminate")
if not pods:
continue
replicas = failure.replicas async def _inject_failures(
failures: list[Failure],
logger: logging.Logger,
deployment: ManagedDeployment,
) -> dict[str, list]: # noqa: F811
affected_pods: dict[str, list] = {}
if not replicas: for failure in failures:
replicas = num_pods await asyncio.sleep(failure.time)
logger.info(f"Injecting failure for: {failure}") logger.info(f"Injecting failure for: {failure}")
# Track which pods were affected by this failure affected_pods[failure.get_failure_key()] = await failure.execute(
failure_key = f"{failure.pod_name}:{failure.command}" deployment, logger
if failure_key not in affected_pods: )
affected_pods[failure_key] = []
for x in range(replicas):
pod = pods[x % num_pods]
# Capture the exact pod name before we kill it
pod_name = pod.name
affected_pods[failure_key].append(pod_name)
logger.info(f"Target pod for failure: {pod_name}")
if failure.command == "delete_pod":
deployment.get_pod_logs(failure.pod_name, pod, ".before_delete")
logger.info(f"Deleting pod: {pod_name}")
pod.delete(force=True)
else:
processes = deployment.get_processes(pod)
for process in processes:
if failure.command in process.command:
logger.info(
f"Terminating {failure.pod_name} Pid {process.pid} Command {process.command} in pod {pod_name}"
)
process.kill(failure.signal)
return affected_pods return affected_pods
...@@ -445,11 +439,12 @@ def results_summary(): ...@@ -445,11 +439,12 @@ def results_summary():
@pytest.mark.slow @pytest.mark.slow
@pytest.mark.filterwarnings("ignore::DeprecationWarning") @pytest.mark.filterwarnings("ignore::DeprecationWarning")
async def test_fault_scenario( async def test_fault_scenario(
scenario, # noqa: F811 scenario: Scenario, # noqa: F811
request, request,
image, image: str,
namespace, namespace: str,
validation_context, # noqa: F811 # Shared context for passing data to validation validation_context, # noqa: F811 # Shared context for passing data to validation
skip_service_restart: bool,
): ):
""" """
Test dynamo serve deployments with injected failures Test dynamo serve deployments with injected failures
...@@ -468,6 +463,7 @@ async def test_fault_scenario( ...@@ -468,6 +463,7 @@ async def test_fault_scenario(
if image: if image:
scenario.deployment.set_image(image) scenario.deployment.set_image(image)
model: Optional[str] = None
if scenario.model: if scenario.model:
scenario.deployment.set_model(scenario.model) scenario.deployment.set_model(scenario.model)
model = scenario.model model = scenario.model
...@@ -500,6 +496,7 @@ async def test_fault_scenario( ...@@ -500,6 +496,7 @@ async def test_fault_scenario(
namespace=namespace, namespace=namespace,
log_dir=request.node.name, log_dir=request.node.name,
deployment_spec=scenario.deployment, deployment_spec=scenario.deployment,
skip_service_restart=skip_service_restart,
) as deployment: ) as deployment:
# Populate shared context for validation # Populate shared context for validation
validation_context["deployment"] = deployment validation_context["deployment"] = deployment
...@@ -507,14 +504,17 @@ async def test_fault_scenario( ...@@ -507,14 +504,17 @@ async def test_fault_scenario(
with _clients( with _clients(
logger, logger,
request, request.node.name,
scenario.deployment, scenario.deployment,
namespace, namespace,
model, model,
scenario.load, # Pass entire Load config object scenario.load, # Pass entire Load config object
): ) as client_procs:
# Inject failures and capture which pods were affected # Inject failures and capture which pods were affected
affected_pods = _inject_failures(scenario.failures, logger, deployment) affected_pods = await _inject_failures(
validation_context["affected_pods"] = affected_pods scenario.failures, logger, deployment
)
logger.info(f"Affected pods during test: {affected_pods}") logger.info(f"Affected pods during test: {affected_pods}")
if scenario.load.continuous_load:
_terminate_client_processes(client_procs, logger)
...@@ -5,18 +5,18 @@ import asyncio ...@@ -5,18 +5,18 @@ import asyncio
import logging import logging
import os import os
import re import re
import secrets
import shlex import shlex
import time import time
from dataclasses import dataclass, field from dataclasses import dataclass, field
from typing import Any, List, Optional from typing import Any, List, Optional
import kr8s import kr8s
import kubernetes
import requests import requests
import yaml import yaml
from kr8s.objects import Pod as kr8s_Pod from kr8s.objects import Pod, Service
from kr8s.objects import Service as kr8s_Service
from kubernetes_asyncio import client, config from kubernetes_asyncio import client, config
from kubernetes_asyncio.client import exceptions
def _get_workspace_dir() -> str: def _get_workspace_dir() -> str:
...@@ -65,6 +65,15 @@ class ServiceSpec: ...@@ -65,6 +65,15 @@ class ServiceSpec:
self._spec["extraPodSpec"]["mainContainer"] = {} self._spec["extraPodSpec"]["mainContainer"] = {}
self._spec["extraPodSpec"]["mainContainer"]["image"] = value self._spec["extraPodSpec"]["mainContainer"]["image"] = value
@property
def envs(self) -> list[dict[str, str]]:
"""Environment variables for the service"""
return self._spec.get("envs", [])
@envs.setter
def envs(self, value: list[dict[str, str]]):
self._spec["envs"] = value
# ----- Replicas ----- # ----- Replicas -----
@property @property
def replicas(self) -> int: def replicas(self) -> int:
...@@ -314,8 +323,36 @@ class DeploymentSpec: ...@@ -314,8 +323,36 @@ class DeploymentSpec:
return {"jsonl_enabled": jsonl_enabled, "log_level": log_level} return {"jsonl_enabled": jsonl_enabled, "log_level": log_level}
def set_service_env_var(self, service_name: str, name: str, value: str):
"""
Set an environment variable for a specific service
"""
service = self.get_service(service_name)
envs = service.envs if service.envs is not None else []
# if env var already exists, update it
for env in envs:
if env["name"] == name:
env["value"] = value
service.envs = envs # Save back to trigger the setter
return
# if env var does not exist, add it
envs.append({"name": name, "value": value})
service.envs = envs # Save back to trigger the setter
def get_service_env_vars(self, service_name: str) -> list[dict]:
"""
Get all environment variables for a specific service
Returns:
List of environment variable dicts (e.g., [{"name": "VAR", "value": "val"}])
"""
service = self.get_service(service_name)
return service.envs
@property @property
def services(self) -> list: def services(self) -> list[ServiceSpec]:
"""List of ServiceSpec objects""" """List of ServiceSpec objects"""
return [ return [
ServiceSpec(svc, spec) ServiceSpec(svc, spec)
...@@ -340,28 +377,25 @@ class DeploymentSpec: ...@@ -340,28 +377,25 @@ class DeploymentSpec:
arg_name: Argument name (e.g., "--max-model-len", "--max-seq-len") arg_name: Argument name (e.g., "--max-model-len", "--max-seq-len")
arg_value: Argument value (e.g., "1024") arg_value: Argument value (e.g., "1024")
""" """
# Get the service service = self.get_service(service_name)
if service_name not in self._deployment_spec["spec"]["services"]: service_spec = service._spec
raise ValueError(f"Service '{service_name}' not found in deployment spec")
service = self._deployment_spec["spec"]["services"][service_name]
# Ensure args list exists # Ensure args list exists
if "extraPodSpec" not in service: if "extraPodSpec" not in service_spec:
service["extraPodSpec"] = {"mainContainer": {}} service_spec["extraPodSpec"] = {"mainContainer": {}}
if "mainContainer" not in service["extraPodSpec"]: if "mainContainer" not in service_spec["extraPodSpec"]:
service["extraPodSpec"]["mainContainer"] = {} service_spec["extraPodSpec"]["mainContainer"] = {}
if "args" not in service["extraPodSpec"]["mainContainer"]: if "args" not in service_spec["extraPodSpec"]["mainContainer"]:
service["extraPodSpec"]["mainContainer"]["args"] = [] service_spec["extraPodSpec"]["mainContainer"]["args"] = []
args_list = service["extraPodSpec"]["mainContainer"]["args"] args_list = service_spec["extraPodSpec"]["mainContainer"]["args"]
# Convert to list if needed (sometimes it's a single string) # Convert to list if needed (sometimes it's a single string)
if isinstance(args_list, str): if isinstance(args_list, str):
import shlex import shlex
args_list = shlex.split(args_list) args_list = shlex.split(args_list)
service["extraPodSpec"]["mainContainer"]["args"] = args_list service_spec["extraPodSpec"]["mainContainer"]["args"] = args_list
# Find existing argument # Find existing argument
arg_index = None arg_index = None
...@@ -384,6 +418,24 @@ class DeploymentSpec: ...@@ -384,6 +418,24 @@ class DeploymentSpec:
# Add new argument # Add new argument
args_list.extend([arg_name, arg_value]) args_list.extend([arg_name, arg_value])
def get_service(self, service_name: str) -> ServiceSpec:
"""
Get a specific service from the deployment spec
"""
if service_name not in self._deployment_spec["spec"]["services"]:
raise ValueError(f"Service '{service_name}' not found in deployment spec")
return ServiceSpec(
service_name, self._deployment_spec["spec"]["services"][service_name]
)
def set_service_replicas(self, service_name: str, replicas: int):
"""
Set the number of replicas for a specific service
"""
service = self.get_service(service_name)
service.replicas = replicas
def save(self, out_file: str): def save(self, out_file: str):
"""Save updated deployment to file""" """Save updated deployment to file"""
with open(out_file, "w") as f: with open(out_file, "w") as f:
...@@ -391,7 +443,7 @@ class DeploymentSpec: ...@@ -391,7 +443,7 @@ class DeploymentSpec:
class PodProcess: class PodProcess:
def __init__(self, pod: kr8s_Pod, line: str): def __init__(self, pod: Pod, line: str):
self.pid = int(re.split(r"\s+", line)[1]) self.pid = int(re.split(r"\s+", line)[1])
self.command = " ".join( self.command = " ".join(
re.split(r"\s+", line)[10:] re.split(r"\s+", line)[10:]
...@@ -439,10 +491,13 @@ class ManagedDeployment: ...@@ -439,10 +491,13 @@ class ManagedDeployment:
log_dir: str log_dir: str
deployment_spec: DeploymentSpec deployment_spec: DeploymentSpec
namespace: str namespace: str
frontend_service_name: Optional[str] = "Frontend" # TODO: this should be determined by the deployment_spec
# the service containing component_type: Frontend determines what is actually the frontend service
frontend_service_name: str = "Frontend"
skip_service_restart: bool = False
_custom_api: Optional[Any] = None _custom_api: Optional[client.CustomObjectsApi] = None
_core_api: Optional[Any] = None _core_api: Optional[client.CoreV1Api] = None
_in_cluster: bool = False _in_cluster: bool = False
_logger: logging.Logger = logging.getLogger() _logger: logging.Logger = logging.getLogger()
_port_forward: Optional[Any] = None _port_forward: Optional[Any] = None
...@@ -457,7 +512,7 @@ class ManagedDeployment: ...@@ -457,7 +512,7 @@ class ManagedDeployment:
"""Initialize kubernetes client""" """Initialize kubernetes client"""
try: try:
# Try in-cluster config first (for pods with service accounts) # Try in-cluster config first (for pods with service accounts)
await config.load_incluster_config() config.load_incluster_config()
self._in_cluster = True self._in_cluster = True
except Exception: except Exception:
# Fallback to kube config file (for local development) # Fallback to kube config file (for local development)
...@@ -511,6 +566,17 @@ class ManagedDeployment: ...@@ -511,6 +566,17 @@ class ManagedDeployment:
self._logger.info(f"Restarted {name} {label}") self._logger.info(f"Restarted {name} {label}")
async def wait_for_unready(self, timeout: int = 1800, sleep=1, log_interval=60):
"""
Wait for the custom resource to be unready.
Args:
timeout: Maximum time to wait in seconds, default to 30 mins (image pulling can take a while)
"""
return await self._wait_for_condition(
timeout, sleep, log_interval, False, "pending"
)
async def _wait_for_ready(self, timeout: int = 1800, sleep=1, log_interval=60): async def _wait_for_ready(self, timeout: int = 1800, sleep=1, log_interval=60):
""" """
Wait for the custom resource to be ready. Wait for the custom resource to be ready.
...@@ -518,9 +584,23 @@ class ManagedDeployment: ...@@ -518,9 +584,23 @@ class ManagedDeployment:
Args: Args:
timeout: Maximum time to wait in seconds, default to 30 mins (image pulling can take a while) timeout: Maximum time to wait in seconds, default to 30 mins (image pulling can take a while)
""" """
return await self._wait_for_condition(
timeout, sleep, log_interval, True, "successful"
)
async def _wait_for_condition(
self,
timeout: int = 1800,
sleep=1,
log_interval=60,
desired_ready_condition_val: bool = True,
desired_state_val: str = "successful",
):
start_time = time.time() start_time = time.time()
self._logger.info(f"Waiting for Deployment {self._deployment_name}") self._logger.info(
f"Waiting for Deployment {self._deployment_name} to have Ready condition {desired_ready_condition_val} and state {desired_state_val}"
)
attempt = 0 attempt = 0
...@@ -528,7 +608,7 @@ class ManagedDeployment: ...@@ -528,7 +608,7 @@ class ManagedDeployment:
try: try:
attempt += 1 attempt += 1
assert self._custom_api is not None, "Kubernetes API not initialized" assert self._custom_api is not None, "Kubernetes API not initialized"
status = await self._custom_api.get_namespaced_custom_object( status = await self._custom_api.get_namespaced_custom_object( # type: ignore[awaitable-is-not-coroutine]
group="nvidia.com", group="nvidia.com",
version="v1alpha1", version="v1alpha1",
namespace=self.namespace, namespace=self.namespace,
...@@ -538,29 +618,34 @@ class ManagedDeployment: ...@@ -538,29 +618,34 @@ class ManagedDeployment:
# Check both conditions: # Check both conditions:
# 1. Ready condition is True # 1. Ready condition is True
# 2. State is successful # 2. State is successful
status_obj = status.get("status", {}) status_obj = status.get("status", {}) # type: ignore[attr-defined]
conditions = status_obj.get("conditions", []) conditions = status_obj.get("conditions", []) # type: ignore[attr-defined]
current_state = status_obj.get("state", "unknown") current_state = status_obj.get("state", "unknown") # type: ignore[attr-defined]
ready_condition = False observed_ready_condition_val = ""
for condition in conditions: for condition in conditions:
if ( if condition.get("type") == "Ready":
condition.get("type") == "Ready" observed_ready_condition_val = condition.get("status")
and condition.get("status") == "True" if observed_ready_condition_val == str(
): desired_ready_condition_val
ready_condition = True ):
break break
state_successful = status_obj.get("state") == "successful" observed_state_val = status_obj.get("state") # type: ignore[attr-defined]
if ready_condition and state_successful: if (
observed_ready_condition_val == str(desired_ready_condition_val)
and observed_state_val == desired_state_val
):
self._logger.info(f"Current deployment state: {current_state}") self._logger.info(f"Current deployment state: {current_state}")
self._logger.info(f"Current conditions: {conditions}") self._logger.info(f"Current conditions: {conditions}")
self._logger.info( self._logger.info(
f"Elapsed time: {time.time() - start_time:.1f}s / {timeout}s" f"Elapsed time: {time.time() - start_time:.1f}s / {timeout}s"
) )
self._logger.info(f"Deployment {self._deployment_name} is ready") self._logger.info(
f"Deployment {self._deployment_name} has Ready condition {desired_ready_condition_val} and state {desired_state_val}"
)
return True return True
else: else:
if attempt % log_interval == 0: if attempt % log_interval == 0:
...@@ -570,10 +655,10 @@ class ManagedDeployment: ...@@ -570,10 +655,10 @@ class ManagedDeployment:
f"Elapsed time: {time.time() - start_time:.1f}s / {timeout}s" f"Elapsed time: {time.time() - start_time:.1f}s / {timeout}s"
) )
self._logger.info( self._logger.info(
f"Deployment not ready yet - Ready condition: {ready_condition}, State successful: {state_successful}" f"Deployment has Ready condition {observed_ready_condition_val} and state {observed_state_val}, desired condition {desired_ready_condition_val} and state {desired_state_val}"
) )
except kubernetes.client.rest.ApiException as e: except exceptions.ApiException as e:
self._logger.info( self._logger.info(
f"API Exception while checking deployment status: {e}" f"API Exception while checking deployment status: {e}"
) )
...@@ -624,7 +709,7 @@ class ManagedDeployment: ...@@ -624,7 +709,7 @@ class ManagedDeployment:
) )
self._logger.info(self.deployment_spec.spec()) self._logger.info(self.deployment_spec.spec())
self._logger.info(f"Deployment Started {self._deployment_name}") self._logger.info(f"Deployment Started {self._deployment_name}")
except kubernetes.client.rest.ApiException as e: except exceptions.ApiException as e:
if e.status == 409: # Already exists if e.status == 409: # Already exists
self._logger.info(f"Deployment {self._deployment_name} already exists") self._logger.info(f"Deployment {self._deployment_name} already exists")
else: else:
...@@ -633,7 +718,64 @@ class ManagedDeployment: ...@@ -633,7 +718,64 @@ class ManagedDeployment:
) )
raise raise
def get_processes(self, pod) -> list: async def trigger_rolling_upgrade(self, service_names: list[str]):
"""
Triggers a rolling update for a list of services
This is a dummy update - sets an env var on the service
"""
if not service_names:
raise ValueError(
"service_names cannot be empty for trigger_rolling_upgrade"
)
patch_body: dict[str, Any] = {"spec": {"services": {}}}
for service_name in service_names:
self.deployment_spec.set_service_env_var(
service_name, "TEST_ROLLING_UPDATE_TRIGGER", secrets.token_hex(8)
)
updated_envs = self.deployment_spec.get_service_env_vars(service_name)
patch_body["spec"]["services"][service_name] = {"envs": updated_envs}
try:
assert self._custom_api is not None, "Kubernetes API not initialized"
await self._custom_api.patch_namespaced_custom_object(
group="nvidia.com",
version="v1alpha1",
namespace=self.namespace,
plural="dynamographdeployments",
name=self._deployment_name,
body=patch_body,
_content_type="application/merge-patch+json",
)
except exceptions.ApiException as e:
self._logger.info(
f"Failed to patch deployment {self._deployment_name}: {e}"
)
raise
async def get_pod_names(self, service_names: list[str] | None = None) -> list[str]:
if not service_names:
service_names = [service.name for service in self.deployment_spec.services]
pod_names: list[str] = []
for service_name in service_names:
label_selector = (
f"nvidia.com/selector={self._deployment_name}-{service_name.lower()}"
)
assert self._core_api is not None, "Kubernetes API not initialized"
pods: client.V1PodList = await self._core_api.list_namespaced_pod(
self.namespace, label_selector=label_selector
)
for pod in pods.items:
pod_names.append(pod.metadata.name)
return pod_names
def get_processes(self, pod: Pod) -> list[PodProcess]:
"""Get list of processes in the given pod""" """Get list of processes in the given pod"""
result = pod.exec(["ps", "-aux"]) result = pod.exec(["ps", "-aux"])
lines = result.stdout.decode().splitlines() lines = result.stdout.decode().splitlines()
...@@ -646,38 +788,34 @@ class ManagedDeployment: ...@@ -646,38 +788,34 @@ class ManagedDeployment:
service_name = "" service_name = ""
full_service_name = f"{self._deployment_name}-{service_name.lower()}" full_service_name = f"{self._deployment_name}-{service_name.lower()}"
return kr8s_Service.get(full_service_name, namespace=self.namespace) return Service.get(full_service_name, namespace=self.namespace)
def get_pods(self, service_name=None): def get_pods(self, service_names: list[str] | None = None) -> dict[str, list[Pod]]:
result = {} result: dict[str, list[Pod]] = {}
service_list = [] if not service_names:
service_names = [service.name for service in self.deployment_spec.services]
if not service_name: for service_name in service_names:
service_list = [service.name for service in self.deployment_spec.services]
else:
service_list = [service_name]
for service in service_list:
# List pods for this service using the selector label # List pods for this service using the selector label
# nvidia.com/selector: deployment-name-service # nvidia.com/selector: deployment-name-service
label_selector = ( label_selector = (
f"nvidia.com/selector={self._deployment_name}-{service.lower()}" f"nvidia.com/selector={self._deployment_name}-{service_name.lower()}"
) )
pods = [] pods: list[Pod] = []
for pod in kr8s.get( for pod in kr8s.get(
"pods", namespace=self.namespace, label_selector=label_selector "pods", namespace=self.namespace, label_selector=label_selector
): ):
pods.append(pod) pods.append(pod) # type: ignore[arg-type]
result[service] = pods result[service_name] = pods
return result return result
def get_pod_logs(self, service, pod, suffix=""): def get_pod_manifest_logs_metrics(self, service_name: str, pod: Pod, suffix=""):
directory = os.path.join(self.log_dir, service) directory = os.path.join(self.log_dir, service_name)
os.makedirs(directory, exist_ok=True) os.makedirs(directory, exist_ok=True)
try: try:
...@@ -699,16 +837,20 @@ class ManagedDeployment: ...@@ -699,16 +837,20 @@ class ManagedDeployment:
except Exception as e: except Exception as e:
self._logger.debug(e) self._logger.debug(e)
self._get_pod_metrics(pod, service, suffix) self._get_pod_metrics(pod, service_name, suffix)
def _get_service_logs(self, service_name=None, suffix=""): def _get_service_logs(self, service_name=None, suffix=""):
service_pods = self.get_pods(service_name) service_names = None
if service_name:
service_names = [service_name]
service_pods = self.get_pods(service_names)
for service, pods in service_pods.items(): for service, pods in service_pods.items():
for i, pod in enumerate(pods): for pod in pods:
self.get_pod_logs(service, pod, suffix) self.get_pod_manifest_logs_metrics(service, pod, suffix)
def _get_pod_metrics(self, pod, service_name, suffix=""): def _get_pod_metrics(self, pod: Pod, service_name: str, suffix=""):
directory = os.path.join(self.log_dir, service_name) directory = os.path.join(self.log_dir, service_name)
os.makedirs(directory, exist_ok=True) os.makedirs(directory, exist_ok=True)
port = None port = None
...@@ -757,11 +899,13 @@ class ManagedDeployment: ...@@ -757,11 +899,13 @@ class ManagedDeployment:
plural="dynamographdeployments", plural="dynamographdeployments",
name=self._deployment_name, name=self._deployment_name,
) )
except client.exceptions.ApiException as e: except exceptions.ApiException as e:
if e.status != 404: # Ignore if already deleted if e.status != 404: # Ignore if already deleted
raise raise
def port_forward(self, pod, remote_port, max_connection_attempts=3): def port_forward(
self, pod: Pod, remote_port: int, max_connection_attempts: int = 3
):
"""Attempt to connect to a pod and return the port-forward object on success. """Attempt to connect to a pod and return the port-forward object on success.
Note: Port forwards run in background threads. When pods are terminated, Note: Port forwards run in background threads. When pods are terminated,
...@@ -866,9 +1010,13 @@ class ManagedDeployment: ...@@ -866,9 +1010,13 @@ class ManagedDeployment:
self._deployment_name = self.deployment_spec.name self._deployment_name = self.deployment_spec.name
logging.getLogger("httpx").setLevel(logging.WARNING) logging.getLogger("httpx").setLevel(logging.WARNING)
await self._init_kubernetes() await self._init_kubernetes()
await self._delete_deployment()
await self._restart_etcd() # Run delete deployment and service restarts in parallel
await self._restart_nats() tasks = [self._delete_deployment()]
if not self.skip_service_restart:
tasks.extend([self._restart_etcd(), self._restart_nats()])
await asyncio.gather(*tasks)
await self._create_deployment() await self._create_deployment()
await self._wait_for_ready() await self._wait_for_ready()
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment