Unverified commit c0e394d5, authored by Indrajit Bhosale and committed by GitHub

feat: Add k8s Fault Tolerance Validations infra (#4067)


Signed-off-by: Indrajit Bhosale <iamindrajitb@gmail.com>
Co-authored-by: Neelay Shah <neelays@nvidia.com>
Co-authored-by: Claude <noreply@anthropic.com>
parent 9b8b9988
@@ -236,6 +236,179 @@ test_fault_scenario[sglang-agg-tp-1-dp-1-frontend]
```
{"time": "2025-10-03T10:30:47", "results": [{"status": 200, "request_elapsed_time": 1.18, "url": "http://localhost:8000/v1/chat/completions", "pod": "frontend-pod"}], "total_time": 1.20}
```
## Validation Framework
### Overview
The fault tolerance test suite includes an automated validation framework that verifies both the test execution and the results. Validation runs automatically after each test completes, ensuring that:
1. **The failure was actually injected** (Stage 1: Scenario Verification)
2. **The system recovered appropriately** (Stage 2: Results Verification)
### Two-Stage Validation Approach
#### Stage 1: Scenario Verification
Verifies that the test scenario executed correctly by checking Kubernetes events and pod states:
**For Pod Deletions (`*_pod` failures):**
- Confirms specific pods were deleted via K8s events (`Killing`, `Terminating`)
- Validates pod recreation and lifecycle transitions
- Logs deletion confirmation with timestamps
**For Process Terminations (non-`*_pod` failures):**
- Checks container restart counts (`restartCount` field)
- Looks for container restart events (`Started`, `BackOff`, `CrashLoopBackOff`)
- **Main process terminations** (e.g., `decode_worker`) → container restarts (verifiable via `restartCount`)
- **Subprocess terminations** (e.g., `sglang_*_scheduler`, `sglang_*_detokenizer`) → no container restart (subprocess becomes zombie/defunct). These produce warnings but are documented known limitations (see Backend-Specific Validations below)
**Example Stage 1 Output:**
```
╔══════════════════════════════════════════════════════════════════════════════╗
║ STAGE 1: SCENARIO VERIFICATION ║
║ (Verify test scenario executed correctly) ║
╚══════════════════════════════════════════════════════════════════════════════╝
────────────────────────────────────────────────────────────────────────────────
1.1 Verifying Specific Pod Deletion via K8s Events
────────────────────────────────────────────────────────────────────────────────
Target pod(s) for deletion: ['fault-tolerance-test-0-vllmdecodeworker-abc123']
✓ DELETION CONFIRMED: [Normal] Killing - Stopping container main
✓ Pod fault-tolerance-test-0-vllmdecodeworker-abc123 deletion verified via K8s events
✓ STAGE 1.1 PASSED: Pod deletion confirmed via K8s events
```
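The Stage 1 decision for process terminations can be sketched as a small pure function. This is only an illustration, not the framework's `ProcessTerminationChecker`: the function name `classify_termination` and its return labels are hypothetical, and the inputs simply mirror the shapes described above (a container-name to `restartCount` mapping and a list of event reasons).

```python
from typing import Dict, List

# Event reasons that the text above associates with container restarts.
RESTART_REASONS = {"Started", "BackOff", "CrashLoopBackOff"}


def classify_termination(restart_counts: Dict[str, int], event_reasons: List[str]) -> str:
    """Return a hypothetical Stage 1 verdict for a process-termination failure."""
    if any(count > 0 for count in restart_counts.values()):
        # Main process died, so Kubernetes restarted the container.
        return "container_restarted"
    if RESTART_REASONS.intersection(event_reasons):
        # restartCount is still zero, but restart-related events were recorded.
        return "restart_events_only"
    # Likely a subprocess kill: PID 1 survived, so nothing was restarted.
    return "no_restart_observed"
```

A verdict of `no_restart_observed` corresponds to the subprocess case above, which produces a warning rather than a failure.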
#### Stage 2: Results Verification
Validates system behavior based on deployment redundancy:
**High Availability (DP > 1):**
- Success rate: ≥99%
- Recovery time: <60 seconds
- Minimal impact on ongoing requests
**Single Worker (DP = 1):**
- Success rate: ≥10% (allows for failures during recovery)
- Recovery time: <180 seconds
- System eventually recovers
**Baseline (No Failures):**
- Success rate: 100%
- No failed requests
**Example Stage 2 Output:**
```
╔══════════════════════════════════════════════════════════════════════════════╗
║ STAGE 2: RESULTS VERIFICATION ║
║ (Single worker - no redundancy) ║
╚══════════════════════════════════════════════════════════════════════════════╝
────────────────────────────────────────────────────────────────────────────────
2.1 Basic Recovery Check
────────────────────────────────────────────────────────────────────────────────
✓ System recovered: 1470 requests succeeded
────────────────────────────────────────────────────────────────────────────────
2.2 Success Rate Validation (Single Worker)
────────────────────────────────────────────────────────────────────────────────
Success rate: 98.00% (1470/1500 requests)
✓ STAGE 2.2 PASSED: Success rate meets threshold (10%)
────────────────────────────────────────────────────────────────────────────────
2.3 Recovery Time Validation
────────────────────────────────────────────────────────────────────────────────
Recovery time: 150.52 seconds
✓ STAGE 2.3 PASSED: Recovery time within acceptable range (180s max)
```
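The threshold logic behind Stage 2 can be sketched as follows. This is a minimal illustration under stated assumptions: `check_results` is a hypothetical helper, the thresholds are passed in by the caller rather than taken from the real checker classes, and a recovery time exactly equal to the limit is treated as passing.

```python
from typing import Optional


def check_results(
    successful: int,
    total: int,
    recovery_time: Optional[float],
    min_success_rate: float,
    max_recovery_s: float,
) -> float:
    """Raise AssertionError if a threshold is missed; return the success rate in percent."""
    if total <= 0:
        raise AssertionError("no requests were recorded")
    rate = 100.0 * successful / total
    if rate < min_success_rate:
        raise AssertionError(
            f"success rate {rate:.2f}% is below the {min_success_rate}% threshold"
        )
    # Recovery time may be missing (for example, in a baseline run).
    if recovery_time is not None and recovery_time > max_recovery_s:
        raise AssertionError(
            f"recovery took {recovery_time:.2f}s, limit is {max_recovery_s}s"
        )
    return rate


# Values from the single-worker example output above.
rate = check_results(1470, 1500, 150.52, min_success_rate=10.0, max_recovery_s=180.0)
```

With the numbers from the example output, the call passes and `rate` is 98.0.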
### Validation Architecture
The validation system uses a **factory pattern** for flexible, extensible validation:
```
┌─────────────────────────────────────────────────────────────┐
│ test_deployment.py │
│ (validation_context fixture) │
└──────────────────────┬───────────────────────────────────────┘
│ After test completes
┌────────────▼─────────────┐
│ checker_factory.py │
│ (Checker Factory) │
└──────┬───────────┬───────┘
│ │
┌────────────▼───┐ ┌───▼────────────────┐
│ Scenario │ │ Results │
│ Checkers │ │ Checkers │
│ (Stage 1) │ │ (Stage 2) │
└────────┬───────┘ └───┬────────────────┘
│ │
┌────────▼───────┐ ┌───▼────────────────┐
│ k8s_utils.py │ │ validation_checks │
│ - Pod events │ │ - Success rate │
│ - Restart cnt │ │ - Recovery time │
└────────────────┘ └────────────────────┘
```
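The flow in the diagram (a list of checkers, Stage 1 before Stage 2, each failing by raising `AssertionError`) can be sketched without any Kubernetes access. Everything named here is hypothetical and stands in for the real classes: `scenario_check`, `results_check`, `run_checkers`, and the plain-dict context are illustrative only.

```python
from typing import Any, Callable, Dict, List, Tuple

# A checker here is just (name, function); the function raises AssertionError on failure.
Checker = Tuple[str, Callable[[Dict[str, Any]], None]]


def scenario_check(ctx: Dict[str, Any]) -> None:
    """Stage 1 stand-in: the failure must have touched at least one pod."""
    if not ctx.get("affected_pods"):
        raise AssertionError("no affected pods recorded; failure may not have been injected")


def results_check(ctx: Dict[str, Any]) -> None:
    """Stage 2 stand-in: at least one request must have succeeded."""
    if ctx["metrics"].get("successful_requests", 0) <= 0:
        raise AssertionError("no successful requests; system did not recover")


def run_checkers(checkers: List[Checker], ctx: Dict[str, Any]) -> List[str]:
    """Run checkers in order and return the names of those that passed."""
    passed: List[str] = []
    for name, check in checkers:
        check(ctx)  # an AssertionError propagates and stops the pipeline
        passed.append(name)
    return passed


PIPELINE: List[Checker] = [("scenario", scenario_check), ("results", results_check)]
```

Keeping every checker behind the same call signature is what lets the factory return an arbitrary list without the caller knowing which stages are present.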
### Factory Functions
#### `get_checkers_for_scenario(test_name, scenario)`
Determines which checkers to run based on:
1. **Explicit checkers** in `scenario.checkers` (highest priority)
- Allows scenarios to specify custom checker lists
2. **Pattern matching** on test name:
- Delegates to `get_scenario_checker()` for Stage 1
- Delegates to `get_results_checker()` for Stage 2
#### `get_scenario_checker(test_name, scenario)`
Selects scenario checker (Stage 1) based on test name pattern:
- `*-none]` → `NoFailureChecker` (baseline)
- `*_pod]` → `PodDeletionChecker` (pod deletions)
- `*decode_worker]`, `*prefill_worker]`, `*frontend]`, `*scheduler]`, `*detokenizer]`, `*engine_core]` → `ProcessTerminationChecker` (process terminations)
#### `get_results_checker(test_name, scenario)`
Selects results checker (Stage 2) based on deployment redundancy:
- `*-none]` → `BaselineResultsChecker` (100% success required)
- **DP > 1** → `HighAvailabilityResultsChecker` (≥90% success, ≤60s recovery)
- **DP = 1** → `SingleWorkerResultsChecker` (≥10% success, ≤180s recovery)
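The selection rules above can be condensed into one function for illustration. This sketch returns checker class names as strings instead of instances, and it takes a plain `replicas` integer where the real factory reads the replica count from the scenario's deployment spec; `select_checker_names` is a hypothetical name.

```python
from typing import Optional, Tuple

# Test-name fragments that indicate a process termination (not a pod deletion).
PROCESS_PATTERNS = (
    "decode_worker]",
    "prefill_worker]",
    "frontend]",
    "scheduler]",
    "detokenizer]",
    "engine_core]",
)


def select_checker_names(test_name: str, replicas: int) -> Tuple[Optional[str], str]:
    """Return (Stage 1 checker name or None, Stage 2 checker name)."""
    if test_name.endswith("-none]"):
        return "NoFailureChecker", "BaselineResultsChecker"

    # Stage 2 depends only on redundancy (DP > 1 or not).
    results = (
        "HighAvailabilityResultsChecker" if replicas > 1 else "SingleWorkerResultsChecker"
    )

    # Pod deletions are checked first, since "decode_worker_pod]" also ends in "_pod]".
    if "_pod]" in test_name:
        return "PodDeletionChecker", results
    if any(pattern in test_name for pattern in PROCESS_PATTERNS):
        return "ProcessTerminationChecker", results
    return None, results
```

The order of the two Stage 1 tests matters only for readability here: a name such as `...decode_worker_pod]` never matches `decode_worker]`, because the bracket follows `_pod`.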
### Backend-Specific Validations
#### SGLang Subprocess Limitations
SGLang has a **known limitation** where subprocess termination leads to zombie processes without automatic recovery:
**Affected Failures:**
- `sglang_decode_scheduler` - Scheduler subprocess becomes `<defunct>`
- `sglang_decode_detokenizer` - Detokenizer subprocess becomes `<defunct>`
**Expected Behavior:**
```
Process killed → becomes zombie (PID exists with Z state, <defunct>)
Container does NOT restart (main process PID 1 still running)
No new subprocess spawned
System does NOT recover automatically
```
**Validation Approach:**
- Confirms container restart count = 0 (subprocess kill, not container crash)
- Documents limitation in test output
- Does not expect recovery
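One way to observe the zombie state described above is to scan `ps` output for a status beginning with `Z`. The sketch below is not part of the test suite: `find_zombies` is a hypothetical helper, and the sample text is made-up output in the shape produced by `ps -eo pid,stat,comm`.

```python
from typing import List, Tuple

# Made-up sample in the column layout of `ps -eo pid,stat,comm`.
SAMPLE_PS_OUTPUT = """\
  PID STAT COMMAND
    1 Ssl  python3
   57 Z    sglang::schedul <defunct>
   58 Sl   sglang::detoken
"""


def find_zombies(ps_output: str) -> List[Tuple[int, str]]:
    """Return (pid, command) for every process whose STAT column starts with 'Z'."""
    zombies: List[Tuple[int, str]] = []
    for line in ps_output.strip().splitlines()[1:]:  # skip the header row
        parts = line.split(None, 2)  # pid, stat, rest of line as the command
        if len(parts) < 3:
            continue
        pid, stat, command = parts
        if stat.startswith("Z"):
            zombies.append((int(pid), command.strip()))
    return zombies


zombies = find_zombies(SAMPLE_PS_OUTPUT)
```

For the sample above, only PID 57 is reported; PID 1 is still running, which is why the container restart count stays at 0.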
#### vLLM Disaggregated Prefill Worker Resilience
vLLM decode workers use `--kv-connector-role kv_both` by default, allowing them to handle both prefill and decode operations. When a prefill worker fails, decode workers automatically take over prefill requests, resulting in 100% success rate with minimal impact.
**Expected Behavior:** Prefill worker failures don't cause request failures - this is vLLM's built-in fault tolerance, not a test issue.
### Summary Results
Results are parsed from AI-Perf metrics and presented in table format after each test. The parsing script (`parse_results.py`) extracts comprehensive metrics for each scenario:
......
# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Base checker class and validation context for fault tolerance testing.
This module provides:
1. ValidationContext - standardized input for all checkers
2. BaseChecker - abstract base class for all validation checks
3. Common interface for scenario and results validation
"""

import logging
from abc import ABC, abstractmethod
from dataclasses import dataclass
from typing import Any, Dict, Optional

logger = logging.getLogger(__name__)


@dataclass
class ValidationContext:
    """Standardized context passed to all checkers.

    This ensures all checkers receive the same input structure.

    Attributes:
        scenario: Scenario object being tested
        log_dir: Test log directory
        metrics: Parsed metrics from results (success rate, latencies, etc.)
        deployment: ManagedDeployment instance (optional)
        namespace: Kubernetes namespace (optional)
        recovery_time: Recovery time in seconds (optional)
        affected_pods: Dict mapping failure key to affected pod names (optional)
    """

    scenario: Any  # Scenario type to avoid circular import
    log_dir: str
    metrics: Dict[str, Any]
    deployment: Optional[Any] = None  # ManagedDeployment
    namespace: Optional[str] = None
    recovery_time: Optional[float] = None
    affected_pods: Optional[Dict[str, list]] = None


class BaseChecker(ABC):
    """Abstract base class for all validation checkers.

    All checkers must:
    1. Implement check() method
    2. Accept ValidationContext as input
    3. Raise AssertionError on validation failure
    4. Log validation progress and results

    Usage:
        class MyChecker(BaseChecker):
            def check(self, context: ValidationContext) -> None:
                if not validate_something(context.metrics):
                    raise AssertionError("Validation failed")
    """

    def __init__(self, name: Optional[str] = None):
        """Initialize checker with optional name.

        Args:
            name: Human-readable name for the checker (defaults to class name)
        """
        self.name = name or self.__class__.__name__
        self.logger = logging.getLogger(f"{__name__}.{self.name}")

    @abstractmethod
    def check(self, context: ValidationContext) -> None:
        """Perform validation check.

        Args:
            context: ValidationContext with all necessary data

        Raises:
            AssertionError: If validation fails
        """
        pass

    def __call__(self, context: ValidationContext) -> None:
        """Allow checker to be called directly.

        Args:
            context: ValidationContext with all necessary data
        """
        self.logger.info(f"Running checker: {self.name}")
        self.check(context)
        self.logger.debug(f"✓ {self.name} passed")

    def __repr__(self) -> str:
        return f"{self.__class__.__name__}(name='{self.name}')"
# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Factory functions for generating checker lists for scenarios.
This module provides factory functions that determine which checkers
to run based on:
1. Explicit checkers in scenario (highest priority)
2. Pattern matching on scenario name
3. Deployment configuration (redundancy level)
"""

import logging
from typing import List, Optional

from tests.fault_tolerance.deploy.base_checker import BaseChecker
from tests.fault_tolerance.deploy.checkers import (
    BaselineResultsChecker,
    HighAvailabilityResultsChecker,
    NoFailureChecker,
    PodDeletionChecker,
    ProcessTerminationChecker,
    SingleWorkerResultsChecker,
)
from tests.fault_tolerance.deploy.scenarios import Scenario

logger = logging.getLogger(__name__)


def get_checkers_for_scenario(test_name: str, scenario: Scenario) -> List[BaseChecker]:
    """Get appropriate list of checkers for a scenario.

    This factory function determines which checkers to use based on:
    1. Explicit checkers in scenario object (highest priority)
    2. Pattern matching on test name
    3. Deployment redundancy (DP > 1)

    Args:
        test_name: Full test name (e.g., "test_fault_scenario[vllm-agg-tp-1-dp-1-decode_worker_pod]")
        scenario: Scenario object

    Returns:
        List of BaseChecker instances to run
    """
    # 1. Explicit checkers take priority
    if scenario.checkers is not None:
        logger.info(f"Using explicit checkers for {test_name}: {scenario.checkers}")
        return scenario.checkers

    # 2. Pattern-based checker selection
    logger.info(f"Using pattern-based checker selection for {test_name}")
    checkers: List[BaseChecker] = []

    # Stage 1: Scenario verification
    scenario_checker = get_scenario_checker(test_name, scenario)
    if scenario_checker:
        checkers.append(scenario_checker)

    # Stage 2: Results verification
    results_checker = get_results_checker(test_name, scenario)
    if results_checker:
        checkers.append(results_checker)

    logger.info(f"Selected checkers: {[c.name for c in checkers]}")
    return checkers


def get_scenario_checker(test_name: str, scenario: Scenario) -> Optional[BaseChecker]:
    """Get appropriate scenario checker (Stage 1).

    Args:
        test_name: Full test name
        scenario: Scenario object

    Returns:
        Scenario checker instance or None
    """
    # No failures scenario
    if test_name.endswith("-none]"):
        return NoFailureChecker()

    # Pod deletion scenarios
    if "_pod]" in test_name:
        return PodDeletionChecker()

    # Process termination scenarios (not pod deletions)
    if any(
        x in test_name
        for x in [
            "decode_worker]",
            "prefill_worker]",
            "frontend]",
            "scheduler]",
            "detokenizer]",
            "engine_core]",
        ]
    ):
        return ProcessTerminationChecker()

    # Default: no specific scenario checker
    logger.info(f"No specific scenario checker for {test_name}")
    return None


def get_results_checker(test_name: str, scenario: Scenario) -> BaseChecker:
    """Get appropriate results checker (Stage 2).

    Determines checker based on deployment redundancy (DP).

    Args:
        test_name: Full test name
        scenario: Scenario object

    Returns:
        Results checker instance
    """
    # No failures baseline
    if test_name.endswith("-none]"):
        return BaselineResultsChecker()

    # Determine if deployment has redundancy (DP > 1)
    has_redundancy = False

    # Determine worker service name based on backend and deployment type
    if scenario.backend == "vllm":
        worker_service_name = "VllmDecodeWorker"
    elif scenario.backend == "sglang":
        worker_service_name = "decode"
    elif scenario.backend == "trtllm":
        # TensorRT-LLM uses different names for agg vs disagg
        # Check test name to determine deployment type
        if "disagg" in test_name:
            worker_service_name = "TRTLLMDecodeWorker"
        else:
            # Agg deployment uses TRTLLMWorker
            worker_service_name = "TRTLLMWorker"
    else:
        logger.warning(
            f"Unsupported backend: {scenario.backend}, using default checker"
        )
        return SingleWorkerResultsChecker()

    try:
        worker_spec = scenario.deployment[worker_service_name]
        if worker_spec and hasattr(worker_spec, "replicas"):
            has_redundancy = worker_spec.replicas > 1
    except (KeyError, AttributeError) as e:
        logger.warning(f"Could not determine redundancy: {e}")

    # Select appropriate results checker
    if has_redundancy:
        logger.info("Using HighAvailabilityResultsChecker (DP > 1)")
        return HighAvailabilityResultsChecker()
    else:
        logger.info("Using SingleWorkerResultsChecker (DP = 1)")
        return SingleWorkerResultsChecker()
# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Kubernetes utility functions for fault tolerance testing.
This module provides utilities for interacting with Kubernetes:
- Fetching pod events
- Listing pods in namespaces
- Logging K8s event summaries
"""

import json
import logging
import subprocess

logger = logging.getLogger(__name__)


def get_pod_restart_count(deployment, pod_name: str, namespace: str) -> dict:
    """Get container restart counts for a pod.

    Args:
        deployment: ManagedDeployment instance
        pod_name: Name of the pod
        namespace: Kubernetes namespace

    Returns:
        Dict with container names as keys and restart counts as values
        Example: {"main": 2, "sidecar": 0}
        An empty dict is returned if the lookup fails.
    """
    try:
        cmd = [
            "kubectl",
            "get",
            "pod",
            pod_name,
            "-n",
            namespace,
            "-o",
            "json",
        ]
        result = subprocess.run(cmd, capture_output=True, text=True, timeout=10)

        if result.returncode == 0:
            pod_data = json.loads(result.stdout)
            restart_counts = {}

            # Get restart counts from container statuses
            container_statuses = pod_data.get("status", {}).get("containerStatuses", [])
            for container in container_statuses:
                container_name = container.get("name", "unknown")
                restart_count = container.get("restartCount", 0)
                restart_counts[container_name] = restart_count

                # Also log if container recently restarted
                state = container.get("state", {})
                if "running" in state and restart_count > 0:
                    started_at = state["running"].get("startedAt", "unknown")
                    logger.info(
                        f"Container {container_name} restarted {restart_count} times, "
                        f"last started at {started_at}"
                    )

            return restart_counts

    except Exception as e:
        logger.debug(f"Could not get pod restart count: {e}")

    return {}


def get_k8s_events_for_pod(deployment, pod_name: str, namespace: str) -> list:
    """Get Kubernetes events for a specific pod using kubectl.

    Args:
        deployment: ManagedDeployment instance
        pod_name: Name of the pod (must match exactly; the field selector
            below does not support partial matches)
        namespace: Kubernetes namespace

    Returns:
        List of event dictionaries with keys: type, reason, message, timestamp, count
        An empty list is returned if the lookup fails.
    """
    try:
        # Get events for the pod using kubectl
        cmd = [
            "kubectl",
            "get",
            "events",
            "-n",
            namespace,
            "--field-selector",
            f"involvedObject.name={pod_name}",
            "-o",
            "json",
        ]
        result = subprocess.run(cmd, capture_output=True, text=True, timeout=10)

        if result.returncode == 0:
            events_data = json.loads(result.stdout)
            events = []
            for item in events_data.get("items", []):
                events.append(
                    {
                        "type": item.get("type", ""),
                        "reason": item.get("reason", ""),
                        "message": item.get("message", ""),
                        "timestamp": item.get(
                            "lastTimestamp", item.get("eventTime", "")
                        ),
                        "count": item.get("count", 1),
                    }
                )
            return events

    except Exception as e:
        logger.debug(f"Could not get K8s events: {e}")

    return []


def check_container_restart_events(deployment, pod_name: str, namespace: str) -> bool:
    """Check if there are container restart/crash events for a pod.

    This looks for events like:
    - BackOff, CrashLoopBackOff: Container keeps crashing
    - Killing: Container was terminated
    - Started: Container was restarted

    Args:
        deployment: ManagedDeployment instance
        pod_name: Name of the pod
        namespace: Kubernetes namespace

    Returns:
        True if restart/crash events found, False otherwise
    """
    events = get_k8s_events_for_pod(deployment, pod_name, namespace)

    restart_related_reasons = [
        "BackOff",
        "CrashLoopBackOff",
        "Killing",
        "Started",
        "Unhealthy",
        "FailedMount",
    ]

    found_restart = False
    for event in events:
        if event["reason"] in restart_related_reasons:
            logger.info(
                f"Container event detected: [{event['type']}] {event['reason']} - "
                f"{event['message']} (count: {event.get('count', 1)})"
            )
            found_restart = True

    return found_restart
@@ -279,7 +279,6 @@ def client(
                 f"Status: {result['results'][-1]['status']} "
                 f"Latency: {result['results'][-1]['request_elapsed_time']}"
             )
             # Write to JSONL log file
             log.write(json.dumps(result) + "\n")
             log.flush()
......
@@ -559,6 +559,66 @@ def main(logs_dir, tablefmt, log_paths=None, sla=None, print_output=True):
     )
     logging.info("\n" + "=" * 80)

+    # Return results for programmatic use (e.g., validation)
+    # Transform legacy format to validation-compatible format
+    if results and len(results) == 1:
+        # Single test result - return in validation-compatible format
+        r = results[0]
+        success_before = r.get("success_before_requests") or 0
+        failed_before = r.get("failed_before_requests") or 0
+        success_after = r.get("success_after_requests") or 0
+        failed_after = r.get("failed_after_requests") or 0
+
+        return {
+            "log_dir": log_paths[0] if log_paths else logs_dir,
+            "num_clients": 10,  # Default from Load config
+            "startup_time": r.get("start_time"),
+            "recovery_time": r.get("recovery_time"),
+            "metrics": {
+                "total_requests": success_before
+                + failed_before
+                + success_after
+                + failed_after,
+                "successful_requests": success_before + success_after,
+                "failed_requests": failed_before + failed_after,
+                "latencies": [],  # Legacy doesn't track per-client latencies
+                "p50_latencies": [],
+                "p90_latencies": [],
+                "p99_latencies": [],
+                "ttft": [],
+                "itl": [],
+                "throughputs": [],
+                "num_clients": 10,
+            },
+        }
+    elif results:
+        # Multiple test results - return list
+        return [
+            {
+                "log_dir": r.get("test", ""),
+                "metrics": {
+                    "total_requests": (
+                        (r.get("success_before_requests") or 0)
+                        + (r.get("failed_before_requests") or 0)
+                        + (r.get("success_after_requests") or 0)
+                        + (r.get("failed_after_requests") or 0)
+                    ),
+                    "successful_requests": (
+                        (r.get("success_before_requests") or 0)
+                        + (r.get("success_after_requests") or 0)
+                    ),
+                    "failed_requests": (
+                        (r.get("failed_before_requests") or 0)
+                        + (r.get("failed_after_requests") or 0)
+                    ),
+                },
+                "recovery_time": r.get("recovery_time"),
+            }
+            for r in results
+        ]
+
+    return None
+

 if __name__ == "__main__":
     # Configure logging
......
@@ -14,14 +14,27 @@
 # limitations under the License.

 import re
-from dataclasses import dataclass
+from dataclasses import dataclass, field
 from enum import Enum, auto
-from typing import Dict, Optional, Pattern
+from typing import TYPE_CHECKING, Dict, List, Optional, Pattern

 from typing_extensions import TypedDict

 from tests.utils.managed_deployment import DeploymentSpec

+if TYPE_CHECKING:
+    from tests.fault_tolerance.deploy.base_checker import BaseChecker
+
+
+# Import checker factory (actual import, not TYPE_CHECKING)
+def _get_checkers_for_scenario(
+    scenario_name: str, scenario: "Scenario"
+) -> List["BaseChecker"]:
+    """Lazy import to avoid circular dependencies during module initialization."""
+    from tests.fault_tolerance.deploy.checker_factory import get_checkers_for_scenario
+
+    return get_checkers_for_scenario(scenario_name, scenario)
+

 class TestPhase(Enum):
     """Enum representing different test phases in fault tolerance testing."""
@@ -187,6 +200,9 @@ class Scenario:
     # When set to True, the test will be automatically marked with @pytest.mark.custom_build
     # and excluded from default test runs unless --include-custom-build flag is used
     requires_custom_build: bool = False  # Flag for tests needing custom builds/setup
+    # List of checkers to run for validation (scenario + results checkers)
+    # If None, factory will determine checkers based on scenario name and deployment
+    checkers: Optional[List["BaseChecker"]] = field(default=None)


 # Helper functions to create deployment specs
@@ -410,13 +426,6 @@ def _create_backend_failures(backend, deploy_type="disagg"):
         failures["sglang_prefill_detokenizer"] = [
             Failure(30, prefill_worker, "sglang::detokenizer", "SIGKILL")
         ]
-    elif backend == "trtllm":
-        failures["trtllm_decode_engine_core"] = [
-            Failure(30, decode_worker, "TRTLLM::EngineCore", "SIGKILL")
-        ]
-        failures["trtllm_prefill_engine_core"] = [
-            Failure(30, prefill_worker, "TRTLLM::EngineCore", "SIGKILL")
-        ]

     return failures
@@ -569,15 +578,23 @@ for deployment_name, deployment_info in DEPLOYMENT_SPECS.items():
         # Get model from deployment info or use the global model
         scenario_model = deployment_info.get("model", model)

-        scenarios[scenario_name] = Scenario(
+        # Create scenario first (without checkers)
+        scenario = Scenario(
             deployment=deployment_info["spec"],
             load=load_config,
             failures=failure,
             model=scenario_model,
             backend=backend,
+            checkers=None,  # Will be populated below
             requires_custom_build=is_moe,  # MoE models require custom builds
         )

+        # Generate checkers for this scenario
+        # This uses the checker factory to determine appropriate validation checks
+        scenario.checkers = _get_checkers_for_scenario(scenario_name, scenario)
+
+        scenarios[scenario_name] = scenario
+

 # Add token overflow test scenarios
 def add_token_overflow_scenarios():
......
@@ -6,9 +6,11 @@ import multiprocessing
 import re
 import time
 from contextlib import contextmanager
+from typing import Any

 import pytest

+from tests.fault_tolerance.deploy.base_checker import ValidationContext
 from tests.fault_tolerance.deploy.client_factory import get_client_function
 from tests.fault_tolerance.deploy.parse_factory import parse_test_results
 from tests.fault_tolerance.deploy.parse_results import process_overflow_recovery_test
@@ -181,6 +183,14 @@ def _clients(


 def _inject_failures(failures, logger, deployment: ManagedDeployment):  # noqa: F811
+    """Inject failures and return info about affected pods.
+
+    Returns:
+        Dict mapping failure info to list of affected pod names
+        Example: {"VllmDecodeWorker:delete_pod": ["pod-abc123", "pod-xyz789"]}
+    """
+    affected_pods: dict[str, list] = {}
+
     for failure in failures:
         time.sleep(failure.time)
@@ -205,36 +215,66 @@ def _inject_failures(failures, logger, deployment: ManagedDeployment):  # noqa:
         logger.info(f"Injecting failure for: {failure}")

+        # Track which pods were affected by this failure
+        failure_key = f"{failure.pod_name}:{failure.command}"
+        if failure_key not in affected_pods:
+            affected_pods[failure_key] = []
+
         for x in range(replicas):
             pod = pods[x % num_pods]

+            # Capture the exact pod name before we kill it
+            pod_name = pod.name
+            affected_pods[failure_key].append(pod_name)
+
+            logger.info(f"Target pod for failure: {pod_name}")
+
             if failure.command == "delete_pod":
                 deployment.get_pod_logs(failure.pod_name, pod, ".before_delete")
+                logger.info(f"Deleting pod: {pod_name}")
                 pod.delete(force=True)
             else:
                 processes = deployment.get_processes(pod)
                 for process in processes:
                     if failure.command in process.command:
                         logger.info(
-                            f"Terminating {failure.pod_name} Pid {process.pid} Command {process.command}"
+                            f"Terminating {failure.pod_name} Pid {process.pid} Command {process.command} in pod {pod_name}"
                         )
                         process.kill(failure.signal)

+    return affected_pods
+

 global_result_list = []

+# Global storage for test results (used by validation fixture)
+test_results_cache = {}
+

 @pytest.fixture(autouse=True)
-def results_table(request, scenario):  # noqa: F811
-    """Parse and display results for individual test using factory pattern.
+def validation_context(request, scenario):  # noqa: F811
+    """Provides shared context between test execution and validation.
+
+    This fixture creates a shared dictionary that the test populates during
+    execution (deployment, namespace, affected_pods), then uses that data
+    in teardown to parse results and run checkers.

     Automatically detects result type (AI-Perf or legacy) and uses
-    the appropriate parser.
+    the appropriate parser. After parsing, immediately runs validation checkers.
     """
-    yield
+    # Shared context that test will populate during execution
+    context: dict[str, Any] = {
+        "deployment": None,
+        "namespace": None,
+        "affected_pods": {},
+    }
+
+    yield context  # Test receives this and populates it

     # Determine log paths based on whether this is a mixed token test
     log_paths = []
+    test_name = request.node.name
+    logger = logging.getLogger(test_name)

     if hasattr(scenario.load, "mixed_token_test") and scenario.load.mixed_token_test:
         # For mixed token tests, we have separate overflow and recovery directories
         overflow_dir = f"{request.node.name}{OVERFLOW_SUFFIX}"
@@ -250,7 +290,7 @@ def results_table(request, scenario):  # noqa: F811

     # Use factory to auto-detect and parse results
     try:
-        parse_test_results(
+        results = parse_test_results(
             log_dir=None,
             log_paths=log_paths,
             tablefmt="fancy_grid",
...@@ -260,8 +300,67 @@ def results_table(request, scenario): # noqa: F811 ...@@ -260,8 +300,67 @@ def results_table(request, scenario): # noqa: F811
# force_parser can be set based on client_type if needed # force_parser can be set based on client_type if needed
# force_parser=scenario.load.client_type, # force_parser=scenario.load.client_type,
) )
# Store results for reference
if results:
logging.info(f"Results parsed: {type(results)}")
test_results_cache[test_name] = results
# IMMEDIATELY run validation now that we have results
try:
logger.info("\n" + "=" * 60)
logger.info("Running validation checks...")
logger.info("=" * 60)
# Extract metrics and recovery time from parsed results
if isinstance(results, list) and len(results) > 0:
result = results[0]
elif isinstance(results, dict):
result = results
else:
logger.warning(f"Unexpected result format: {type(results)}")
result = None
if result:
metrics = result.get("metrics", {})
recovery_time = result.get("recovery_time")
# Create ValidationContext for all checkers
validation_ctx = ValidationContext(
scenario=scenario,
log_dir=test_name,
metrics=metrics,
deployment=context.get("deployment"),
namespace=context.get("namespace"),
recovery_time=recovery_time,
affected_pods=context.get("affected_pods", {}),
)
# Use pre-generated checkers from scenario
# Checkers were already determined during scenario creation
checkers = scenario.checkers or []
# Run all checkers
for checker in checkers:
logger.info(f"\nRunning checker: {checker.name}")
checker.check(validation_ctx)
logger.info("=" * 60)
logger.info("✓ All validation checks passed")
logger.info("=" * 60 + "\n")
except AssertionError as e:
logger.error("=" * 60)
logger.error(f"✗ Validation failed: {e}")
logger.error("=" * 60 + "\n")
# Re-raise to fail the test
raise
except Exception as e:
logger.error(f"Validation error: {e}")
# Don't fail test on validation errors (non-assertion exceptions)
logger.warning("Skipping validation due to error")
except Exception:
logging.exception("Failed to parse results for %s", test_name)
# Add all directories to global list for session summary
global_result_list.extend(log_paths)
@@ -350,9 +449,16 @@ async def test_fault_scenario(
request,
image,
namespace,
validation_context, # noqa: F811 # Shared context for passing data to validation
):
"""
Test dynamo serve deployments with injected failures
Flow:
1. validation_context fixture creates empty dict: {"deployment": None, "namespace": None, "affected_pods": {}}
2. This test populates it: validation_context["deployment"] = deployment, etc.
3. After test completes, fixture reads validation_context and runs validation checkers
4. Checkers use the populated ValidationContext to verify test results and K8s events
"""
logger = logging.getLogger(request.node.name)
@@ -395,6 +501,10 @@ async def test_fault_scenario(
log_dir=request.node.name,
deployment_spec=scenario.deployment,
) as deployment:
# Populate shared context for validation
validation_context["deployment"] = deployment
validation_context["namespace"] = namespace
with _clients(
logger,
request,
@@ -403,4 +513,8 @@ async def test_fault_scenario(
model,
scenario.load, # Pass entire Load config object
):
# Inject failures and capture which pods were affected
affected_pods = _inject_failures(scenario.failures, logger, deployment)
validation_context["affected_pods"] = affected_pods
logger.info(f"Affected pods during test: {affected_pods}")
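The flow in the docstring above (the fixture yields an empty dict, the test fills it in, and teardown reads it back) can be sketched without pytest as a plain generator. This is a minimal illustration under stated assumptions, not the PR's code: `shared_context` and `run_checks` are hypothetical names, and the namespace and pod values are made up.

```python
from typing import Any, Callable, Dict, Generator, List


def shared_context(
    run_checks: Callable[[Dict[str, Any]], None],
) -> Generator[Dict[str, Any], None, None]:
    """Mimic a yield-style fixture: setup, hand over a dict, then teardown."""
    context: Dict[str, Any] = {
        "deployment": None,
        "namespace": None,
        "affected_pods": {},
    }
    yield context  # the "test" mutates this dict in place
    run_checks(context)  # teardown sees whatever the test stored


# Hypothetical stand-in for the validation step: just record the context.
seen: List[Dict[str, Any]] = []

fixture = shared_context(seen.append)
ctx = next(fixture)  # setup phase: receive the empty context
ctx["namespace"] = "demo-namespace"  # made-up value
ctx["affected_pods"] = {"decode_worker": ["pod-a"]}  # made-up value
next(fixture, None)  # teardown phase: runs the code after `yield`
```

Because the dict is shared by reference, the teardown code needs no return value from the test. It only needs the test to have written its keys before finishing.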
# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Basic validation check functions for fault tolerance testing.

This module provides atomic validation primitives:
- Success rate validation
- Recovery time validation
- Latency SLA validation
- Kubernetes health checks
"""

import logging
from typing import Any, Dict, Optional

logger = logging.getLogger(__name__)


def check_success_rate(metrics: Dict[str, Any], min_threshold: float = 0.80) -> None:
    """Check that success rate meets minimum threshold.

    Args:
        metrics: Parsed metrics dictionary with request counts
        min_threshold: Minimum acceptable success rate (0.0 to 1.0)

    Raises:
        AssertionError: If success rate is below threshold
    """
    total_requests = metrics.get("total_requests", 0)
    successful_requests = metrics.get("successful_requests", 0)

    if total_requests == 0:
        logger.warning("No requests found in metrics")
        raise AssertionError("Validation failed: No requests were executed")

    success_rate = successful_requests / total_requests
    logger.info(
        f"Success rate: {success_rate:.2%} "
        f"({successful_requests}/{total_requests} requests)"
    )

    if success_rate < min_threshold:
        raise AssertionError(
            f"Success rate {success_rate:.2%} is below threshold {min_threshold:.2%}"
        )


def check_recovery_time(
    recovery_time: Optional[float], max_seconds: Optional[float] = None
) -> None:
    """Check that recovery time is within acceptable bounds.

    Args:
        recovery_time: Recovery time in seconds (can be None for no-failure scenarios)
        max_seconds: Maximum acceptable recovery time (None = no check)

    Raises:
        AssertionError: If recovery time exceeds maximum
    """
    if recovery_time is None:
        logger.info("No recovery time measured (expected for no-failure scenarios)")
        return

    logger.info(f"Recovery time: {recovery_time:.2f} seconds")

    if max_seconds is not None and recovery_time > max_seconds:
        raise AssertionError(
            f"Recovery time {recovery_time:.2f}s exceeds maximum {max_seconds}s"
        )


def check_no_failures(metrics: Dict[str, Any]) -> None:
    """Check that there were no failed requests.

    Args:
        metrics: Parsed metrics dictionary

    Raises:
        AssertionError: If any requests failed
    """
    failed_requests = metrics.get("failed_requests", 0)
    total_requests = metrics.get("total_requests", 0)

    if failed_requests > 0:
        raise AssertionError(
            f"Expected no failures, but {failed_requests}/{total_requests} requests failed"
        )

    logger.info(f"All {total_requests} requests succeeded")
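
The threshold behaviour of `check_success_rate` is easiest to see with a few sample inputs. The sketch below uses a condensed copy of the function (logging omitted) so it runs on its own. The `passes` helper and the metrics dictionaries are hypothetical and exist only for illustration.

```python
from typing import Any, Dict


def check_success_rate(metrics: Dict[str, Any], min_threshold: float = 0.80) -> None:
    """Condensed copy of the check above, with logging omitted."""
    total = metrics.get("total_requests", 0)
    ok = metrics.get("successful_requests", 0)
    if total == 0:
        raise AssertionError("Validation failed: No requests were executed")
    rate = ok / total
    if rate < min_threshold:
        raise AssertionError(
            f"Success rate {rate:.2%} is below threshold {min_threshold:.2%}"
        )


def passes(metrics: Dict[str, Any]) -> bool:
    """Hypothetical helper: turn the raising check into a boolean."""
    try:
        check_success_rate(metrics)
        return True
    except AssertionError:
        return False


# Made-up sample metrics
at_threshold = {"total_requests": 100, "successful_requests": 80}
below = {"total_requests": 100, "successful_requests": 79}
empty: Dict[str, Any] = {}

results = {
    "at_threshold": passes(at_threshold),  # 80% is not below 80%, so it passes
    "below": passes(below),  # 79% is below 80%, so it fails
    "empty": passes(empty),  # zero requests is treated as a failure
}
```

Two properties follow from the comparison used: the threshold is inclusive, and an empty metrics dictionary fails validation instead of passing vacuously.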