scenarios.py 47.1 KB
Newer Older
1
# SPDX-FileCopyrightText: Copyright (c) 2025-2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2
3
4
5
6
7
8
9
10
11
12
13
14
15
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

16
17
import asyncio
import logging
18
import re
19
import time
20
from abc import ABC, abstractmethod
21
from dataclasses import dataclass, field
22
from enum import Enum, auto
23
from typing import TYPE_CHECKING, Any, Dict, List, Optional, Pattern
24

25
from typing_extensions import Required, TypedDict
26

27
from tests.utils.managed_deployment import DeploymentSpec, ManagedDeployment
28

29
30
31
if TYPE_CHECKING:
    from tests.fault_tolerance.deploy.base_checker import BaseChecker

32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
logger = logging.getLogger(__name__)


# Lazy import to avoid kubernetes dependency during module import
def _get_gpu_helpers():
    """Import and return the GPU helper callables plus ``ApiException``.

    The imports are performed inside the function so that merely importing
    this module does not require the ``kubernetes`` package.

    Returns:
        The 4-tuple ``(get_available_gpu_ids, get_gpu_info,
        get_processes_on_gpu, ApiException)``, in that order.
    """
    from kubernetes.client.rest import ApiException

    from tests.fault_tolerance.hardware.fault_injection_service.helpers import (
        get_available_gpu_ids,
        get_gpu_info,
        get_processes_on_gpu,
    )

    helpers = (
        get_available_gpu_ids,
        get_gpu_info,
        get_processes_on_gpu,
        ApiException,
    )
    return helpers

48
49
50
51
52
53
54
55
56
57

# Import checker factory (actual import, not TYPE_CHECKING)
def _get_checkers_for_scenario(
    scenario_name: str, scenario: "Scenario"
) -> List["BaseChecker"]:
    """Forward to the checker factory, importing it only when called.

    The factory is imported here instead of at module level to avoid a
    circular import while this module is still being initialized.

    Args:
        scenario_name: Name of the scenario, passed through to the factory.
        scenario: The scenario object, passed through to the factory.

    Returns:
        Whatever the factory's ``get_checkers_for_scenario`` returns.
    """
    from tests.fault_tolerance.deploy.checker_factory import (
        get_checkers_for_scenario as build_checkers,
    )

    checkers = build_checkers(scenario_name, scenario)
    return checkers

58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76

class TestPhase(Enum):
    """Phases a fault tolerance test can be split into.

    The lowercased member names are used elsewhere in this module to build
    phase suffixes, so renaming a member changes those suffixes.
    """

    # Explicit values equal to what ``auto()`` assigns (1, 2, 3 in order).
    STANDARD = 1
    OVERFLOW = 2
    RECOVERY = 3


class DeploymentInfo(TypedDict, total=False):
    """Information about a deployment configuration.

    Declared with ``total=False`` so keys are optional by default; ``spec``
    and ``backend`` are marked ``Required`` and must always be present.

    Attributes:
        spec: DeploymentSpec object defining the deployment configuration
        backend: Backend type - "vllm", "sglang", or "trtllm"
        model: Optional model identifier (e.g., "deepseek-ai/DeepSeek-V2-Lite")
        is_moe: Optional flag indicating if this is a Mixture-of-Experts model
    """

    # Mandatory keys.
    spec: Required[DeploymentSpec]
    backend: Required[str]

    # Optional keys.
    model: str
    is_moe: bool


# Phase suffixes built from the lowercased TestPhase member names
# (an underscore followed by the member name).
OVERFLOW_SUFFIX = "_" + TestPhase.OVERFLOW.name.lower()
RECOVERY_SUFFIX = "_" + TestPhase.RECOVERY.name.lower()

87
# Worker name mapping for different backends
88
89
90
91
92
93
94
95
96
WORKER_MAP = {
    # backend -> {role -> service (worker) name used in the deployment spec}
    "vllm": {
        "decode": "VllmDecodeWorker",
        "prefill": "VllmPrefillWorker",
    },
    "sglang": {
        "decode": "decode",
        "prefill": "prefill",
    },
    "trtllm": {
        "decode": "TRTLLMDecodeWorker",
        "decode_agg": "TRTLLMWorker",  # Aggregated uses different name
        "prefill": "TRTLLMPrefillWorker",
    },
}

104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
# Process ready patterns for recovery detection
WORKER_READY_PATTERNS: Dict[str, Pattern] = {
    # Keys are service/worker names (see WORKER_MAP, plus "Frontend"); each
    # value is a compiled regex for that worker's ready message.
    # Frontend
    "Frontend": re.compile(r"added model"),
    # vLLM workers
    "VllmDecodeWorker": re.compile(
        r"VllmWorker for (?P<model_name>.*?) has been initialized"
    ),
    "VllmPrefillWorker": re.compile(
        r"VllmWorker for (?P<model_name>.*?) has been initialized"
    ),
    # SGLang workers - look for their specific initialization messages
    "decode": re.compile(
        r"Model registration succeeded|Decode worker handler initialized|Worker handler initialized"
    ),
    "prefill": re.compile(
        r"Model registration succeeded|Prefill worker handler initialized|Worker handler initialized"
    ),
    # TensorRT-LLM workers
    "TRTLLMWorker": re.compile(
        r"TrtllmWorker for (?P<model_name>.*?) has been initialized|Model registration succeeded"
    ),
    "TRTLLMDecodeWorker": re.compile(
        r"TrtllmWorker for (?P<model_name>.*?) has been initialized|Model registration succeeded"
    ),
    "TRTLLMPrefillWorker": re.compile(
        r"TrtllmWorker for (?P<model_name>.*?) has been initialized|Model registration succeeded"
    ),
}


def get_all_worker_types() -> list[str]:
    """Return "Frontend" followed by every worker name listed in WORKER_MAP.

    Worker names are collected backend by backend in WORKER_MAP order. A name
    that occurs more than once is kept only at its first position.
    """
    names = ["Frontend"]
    for role_to_worker in WORKER_MAP.values():
        names += role_to_worker.values()
    # dict keys are unique and keep insertion order, which de-duplicates
    # while preserving first-seen ordering.
    return list(dict.fromkeys(names))


def get_worker_ready_pattern(worker_name: str) -> Optional[Pattern]:
    """Look up the compiled ready pattern registered for ``worker_name``.

    Returns:
        The pattern from WORKER_READY_PATTERNS, or None when the worker name
        has no entry.
    """
    if worker_name in WORKER_READY_PATTERNS:
        return WORKER_READY_PATTERNS[worker_name]
    return None


def get_backend_workers(backend: str) -> Dict[str, str]:
    """Return the role -> worker-name mapping for ``backend``.

    Returns:
        The WORKER_MAP entry itself (not a copy) for a known backend, or a new
        empty dict when the backend is not in WORKER_MAP.
    """
    try:
        return WORKER_MAP[backend]
    except KeyError:
        return {}

159
160
161
162
163
164
165

@dataclass
class Load:
    """Client load configuration for a scenario.

    Field order is part of the interface (positional construction), so new
    fields should only be appended.
    """

    clients: int = 10
    requests_per_client: int = 150
    input_token_length: int = 100
    output_token_length: int = 100
    max_retries: int = 3  # Increased for fault tolerance
    sla: Optional[float] = None
    client_type: str = "aiperf"  # "aiperf" or "legacy"
    # Rate limiting (requests/sec) for both AI-Perf and legacy clients
    max_request_rate: float = 1.0
    success_threshold: float = 90.0  # Success rate threshold for tests

    # For mixed token testing (overflow + recovery)
    mixed_token_test: bool = False
    overflow_token_length: Optional[int] = None  # Tokens for overflow requests
    overflow_request_count: int = 15  # Number of overflow requests
    normal_request_count: int = 15  # Number of normal requests after overflow

    # If True, use continuous load instead of fixed request count
    continuous_load: bool = False

184
185

@dataclass
class Failure(ABC):
    """Base class for all failure types.

    Subclasses implement ``execute`` to inject the failure and
    ``get_failure_key`` to produce a string key describing it.
    """

    # time to wait in seconds before the failure is injected
    time: int

    # names of DGD services to inject the failure into the corresponding pods for
    service_names: list[str]

    @abstractmethod
    async def execute(
        self, deployment: ManagedDeployment, logger: logging.Logger
    ) -> list[str]:
        """Execute the failure injection.

        Args:
            deployment: The managed deployment to inject the failure into
            logger: Logger instance for logging failure injection

        Returns: List of affected pod names
        """
        pass

    @abstractmethod
    def get_failure_key(self) -> str:
        """Get the failure key for the failure."""
        pass


@dataclass
class RollingUpgradeFailure(Failure):
    """Failure that triggers a rolling upgrade of the configured services."""

    async def execute(
        self, deployment: ManagedDeployment, logger: logging.Logger
    ) -> list[str]:
        """Trigger a rolling upgrade and wait for the deployment to recover.

        Args:
            deployment: The managed deployment to upgrade.
            logger: Logger instance (not used by this implementation).

        Returns:
            The pod names reported for ``service_names`` once the upgrade has
            completed and the post-upgrade delay has elapsed.
        """
        await deployment.trigger_rolling_upgrade(self.service_names)

        # The deployment going unready is how we know the upgrade has started.
        await deployment.wait_for_unready(timeout=60, log_interval=10)

        # Allow up to 30 minutes for the deployment to come back.
        await deployment._wait_for_ready(timeout=1800)

        # Let some requests be processed after the rolling upgrade completes.
        # Here ``self.time`` is used as a delay after the upgrade.
        await asyncio.sleep(self.time)

        affected_pods = await deployment.get_pod_names(self.service_names)
        return affected_pods

    def get_failure_key(self) -> str:
        """Return ``rolling_upgrade:`` followed by the comma-joined service names."""
        joined_services = ",".join(self.service_names)
        return f"rolling_upgrade:{joined_services}"


@dataclass
class DeletePodFailure(Failure):
    """Failure that force-deletes every pod of the configured services."""

    async def execute(
        self, deployment: ManagedDeployment, logger: logging.Logger
    ) -> list[str]:
        """Force-delete each pod of ``service_names``.

        Args:
            deployment: The managed deployment whose pods are deleted.
            logger: Logger instance (not used by this implementation).

        Returns:
            Names of the deleted pods, in deletion order.
        """
        pods_by_service = deployment.get_pods(self.service_names)
        deleted: list[str] = []
        for service_name in pods_by_service:
            for pod in pods_by_service[service_name]:
                # Capture pod state first, tagged with the ".before_delete" suffix.
                deployment.get_pod_manifest_logs_metrics(
                    service_name, pod, ".before_delete"
                )
                # force means no graceful termination
                pod.delete(force=True)
                deleted.append(pod.name)

        return deleted

    def get_failure_key(self) -> str:
        """Return ``delete_pod:`` followed by the comma-joined service names."""
        joined_services = ",".join(self.service_names)
        return f"delete_pod:{joined_services}"


class TerminateProcessFailure(Failure):
    """Failure type for terminating specific processes by name."""

    def __init__(
        self,
        time: int,
        service_names: list[str],
        signal: str = "SIGINT",
        process_name: str = "",
    ):
        """Initialize TerminateProcessFailure.

        Args:
            time: Time to wait in seconds before the failure is injected
            service_names: Names of DGD services to inject the failure into
            signal: Signal to send (default: "SIGINT")
            process_name: Name of the process to terminate (required; the
                empty default is rejected)

        Raises:
            ValueError: If ``process_name`` or ``signal`` is empty.
        """
        super().__init__(time=time, service_names=service_names)

        both_given = bool(process_name) and bool(signal)
        if not both_given:
            raise ValueError(
                "process_name and signal are required for TerminateProcessFailure"
            )

        self.process_name = process_name
        self.signal = signal

296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
495
496
497
498
499
500
501
502
503
504
505
506
507
508
509
510
511
512
513
514
515
516
517
518
519
520
521
522
523
524
525
526
527
528
529
530
531
532
533
534
535
536
537
538
539
540
541
542
543
544
545
546
547
548
549
550
551
552
553
554
555
556
557
558
559
560
561
562
563
564
565
566
567
568
569
570
    def _log_process_list(self, pod):
        """Run ``ps aux`` in the pod and log the rows that look relevant.

        A row is kept when its lowercased text contains "python", "vllm" or
        "dynamo". The header row is always logged. Errors are logged and
        swallowed: this is best-effort diagnostics.
        """
        *_, ApiException = _get_gpu_helpers()

        try:
            ps = pod.exec(["ps", "aux"])
            if ps.returncode != 0:
                logger.warning(f"ps aux command exited with code {ps.returncode}")
                return

            text = ps.stdout.decode() if ps.stdout else ""
            # split() always yields at least one element, so the header exists.
            header, *rows = text.split("\n")

            # Log as single block to avoid [TEST] prefix on each line
            block = ["\n--- Process List (ps aux) ---", header]
            for row in rows:
                lowered = row.lower()
                if any(keyword in lowered for keyword in ["python", "vllm", "dynamo"]):
                    block.append(row)
            logger.info("\n".join(block))

        except ApiException as e:
            logger.warning(f"Kubernetes API error getting ps aux: {e}")
        except Exception:
            logger.exception("Unexpected error getting process list")

    def _get_process_details_string(self, pod, pid: int) -> str:
        """Describe one PID using ``ps -p <pid> -o pid,comm,args`` inside the pod.

        Returns:
            ``"    PID <pid>: <second line of the ps output>"`` when ps succeeds
            and prints a row after its header; otherwise an empty string.
        """
        *_, ApiException = _get_gpu_helpers()

        try:
            ps = pod.exec(["ps", "-p", str(pid), "-o", "pid,comm,args"])
            if ps.returncode != 0:
                return ""

            rows = ps.stdout.decode().strip().split("\n")
            # rows[0] is the ps header; rows[1], if present, is the process.
            return f"    PID {pid}: {rows[1]}" if len(rows) > 1 else ""
        except ApiException:
            # Process may not exist or API unavailable - expected during termination
            return ""
        except Exception:
            # Unexpected error (AttributeError, IndexError, UnicodeDecodeError, etc.)
            logger.exception(f"Unexpected error getting process details for PID {pid}")
            return ""

    def _log_gpu_discovery_info(self, pod):
        """Log the pod's GPU ids and, per GPU, its info and process list.

        Output is assembled into one message and logged once. Errors while
        collecting the information are logged and swallowed: this is
        best-effort diagnostics.
        """
        # Resolve the helpers BEFORE the try block. If this lookup ran inside
        # it and failed, ``ApiException`` would still be unbound when the
        # ``except ApiException`` clause is evaluated, and the real error would
        # be masked by an UnboundLocalError.
        get_available_gpu_ids, *_, ApiException = _get_gpu_helpers()

        try:
            gpu_ids = get_available_gpu_ids(pod)

            if not gpu_ids:
                logger.warning("No GPUs found in pod")
                return

            # Build output as single message
            output_lines = [
                "\n--- GPU Information ---",
                f"Available GPUs: {gpu_ids}",
                "\n--- Per-GPU Process Mapping (from query-compute-apps) ---",
            ]

            for gpu_id in gpu_ids:
                output_lines.extend(self._get_single_gpu_info(pod, gpu_id))

            logger.info("\n".join(output_lines))

        except ApiException as e:
            logger.warning(f"Kubernetes API error getting GPU information: {e}")
        except Exception:
            logger.exception("Unexpected error getting GPU information")

    def _get_single_gpu_info(self, pod, gpu_id: int) -> list[str]:
        """Get information for a single GPU as list of strings.

        Args:
            pod: Pod to query.
            gpu_id: GPU identifier, as returned by ``get_available_gpu_ids``.

        Returns:
            A header line for the GPU (with name and total memory when
            ``get_gpu_info`` returns something truthy), followed by either the
            PID list plus one detail line per PID that could be described, or
            a "No processes running" note.
        """
        # _get_gpu_helpers() returns FOUR values (the last is ApiException).
        # Unpacking it into three names raised ValueError on every call.
        _, get_gpu_info, get_processes_on_gpu, _ = _get_gpu_helpers()

        lines = []
        gpu_info = get_gpu_info(pod, gpu_id)

        if gpu_info:
            lines.append(
                f"\nGPU {gpu_id}: {gpu_info.get('name', 'Unknown')} "
                f"(Memory: {gpu_info.get('memory_total', 'Unknown')})"
            )
        else:
            lines.append(f"\nGPU {gpu_id}:")

        pids = get_processes_on_gpu(pod, gpu_id)

        if pids:
            lines.append(f"  Processes (PIDs): {pids}")
            for pid in pids:
                proc_details = self._get_process_details_string(pod, pid)
                if proc_details:
                    lines.append(proc_details)
        else:
            lines.append(
                "  No processes running (note: small memory footprints may not appear)"
            )

        return lines

    def _parse_nvidia_smi_process_line(self, line: str):
        """Parse a single line from nvidia-smi processes section.

        Returns:
            Tuple of (gpu_id, pid, process_name, memory) or None if parsing fails
        """
        parts = [p.strip() for p in line.split("|") if p.strip()]
        if not parts:
            return None

        fields = parts[0].split()
        if len(fields) < 6:
            return None

        try:
            gpu_id = fields[0]
            pid = fields[3]
            process_name = " ".join(fields[5:-1])
            memory = fields[-1]
            return (gpu_id, pid, process_name, memory)
        except (ValueError, IndexError):
            return None

    def _log_nvidia_smi_output(self, pod):
        """Run ``nvidia-smi`` in the pod and log its output as one message.

        When the output contains a "Processes:" section, a parsed
        GPU -> PID summary is placed before the raw output. Errors are logged
        and swallowed: this is best-effort diagnostics.
        """
        *_, ApiException = _get_gpu_helpers()

        try:
            smi = pod.exec(["nvidia-smi"])
            if smi.returncode != 0:
                logger.warning(
                    f"nvidia-smi command exited with code {smi.returncode}"
                )
                return

            gpu_status = smi.stdout.decode() if smi.stdout else ""
            parsed_summary = (
                self._get_parsed_nvidia_smi_processes(gpu_status)
                if "Processes:" in gpu_status
                else []
            )

            report = [
                "\n--- Complete GPU->Process Mapping (from full nvidia-smi) ---",
                *parsed_summary,
                "\n--- Full nvidia-smi Output (for reference) ---",
                gpu_status,
            ]
            logger.info("\n".join(report))

        except ApiException as e:
            logger.warning(f"Kubernetes API error getting nvidia-smi: {e}")
        except Exception:
            logger.exception("Unexpected error getting nvidia-smi output")

    def _get_parsed_nvidia_smi_processes(self, gpu_status: str) -> list[str]:
        """Parse nvidia-smi processes section and return as list of strings."""
        lines = ["GPU -> PID -> Process Name -> Memory:"]

        try:
            processes_section = gpu_status.split("Processes:")[1]
            processes_lines = processes_section.split("\n")

            for line in processes_lines:
                if "MiB" in line and "|" in line:
                    parsed = self._parse_nvidia_smi_process_line(line)
                    if parsed:
                        gpu_id, pid, process_name, memory = parsed
                        lines.append(
                            f"  GPU {gpu_id}: PID {pid} ({process_name}) - {memory}"
                        )
        except (IndexError, ValueError) as e:
            # Expected if nvidia-smi output format is unexpected
            logger.debug(f"Failed to parse nvidia-smi processes: {e}")
        except Exception:
            # Unexpected error - should be investigated
            logger.exception("Unexpected error parsing nvidia-smi processes")

        return lines

    def _log_pod_diagnostics(self, pod, phase: str):
        """Log a diagnostics banner, then process list, GPU info and nvidia-smi output.

        Args:
            pod: Pod to inspect; its ``name`` appears in the banner.
            phase: Free-form label shown in the banner.
        """
        rule = "=" * 80
        logger.info(f"\n{rule}\nPOD DIAGNOSTICS - {phase}\nPod: {pod.name}\n{rule}")

        diagnostics = (
            self._log_process_list,
            self._log_gpu_discovery_info,
            self._log_nvidia_smi_output,
        )
        for emit in diagnostics:
            emit(pod)

        logger.info(rule)

    def _wait_for_pod_ready(
        self,
        pod,
        max_wait: int = 120,
        poll_interval: int = 1,
    ) -> Optional[int]:
        """Poll for pod to become ready and return elapsed time or None if timeout.

        Checks Kubernetes pod readiness (readiness probe passes). Clients perform
        their own service health checks independently.

        The timeout is measured against a monotonic clock, so ``max_wait`` is a
        bound in seconds whatever ``poll_interval`` is and however long each
        status refresh takes (the final poll may overshoot by one interval).
        This call blocks the calling thread with ``time.sleep``.

        Args:
            pod: Kubernetes pod to check
            max_wait: Maximum seconds to wait (default: 120)
            poll_interval: Seconds between polls (default: 1)

        Returns:
            Elapsed seconds (rounded, at least 1) when pod becomes ready, or
            None if timeout

        Raises:
            Exception: Any non-``ApiException`` error from the readiness check
                is logged and re-raised.
        """
        *_, ApiException = _get_gpu_helpers()

        started = time.monotonic()
        deadline = started + max_wait

        while time.monotonic() < deadline:
            time.sleep(poll_interval)
            try:
                pod.refresh()
                if pod.ready():
                    # Never report 0: a caller testing the result for
                    # truthiness would mistake it for a timeout.
                    actual_elapsed = max(1, round(time.monotonic() - started))
                    logger.info(
                        f"Pod '{pod.name}' became ready after ~{actual_elapsed}s"
                    )
                    return actual_elapsed
            except ApiException as e:
                # API errors are tolerated here; keep polling until the deadline.
                logger.debug(f"Kubernetes API error checking pod status: {e}")
            except Exception as e:
                logger.exception(
                    f"Unexpected error checking pod readiness for {pod.name}: {e}"
                )
                raise

        logger.warning(f"Pod '{pod.name}' did not become ready within {max_wait}s")
        return None

    def _check_frontend_health_after_restart(
        self,
        deployment,
        service_name: str,
        base_status: str,
    ) -> str:
        """Check Frontend service health after a pod restart.

        Sets up a port forward to a frontend pod, asks
        ``wait_for_model_availability`` whether the model is served there, and
        stops every port forward it registered before returning. It never
        raises: failures are reported in the returned status string.

        Args:
            deployment: ManagedDeployment instance
            service_name: Name of the service that was restarted
            base_status: Base status string (e.g., "ready after 102s")

        Returns:
            ``base_status`` with one of ", Frontend healthy",
            ", Frontend health check failed", ", Frontend port forward failed"
            or ", Frontend health check error" appended.
        """
        from tests.fault_tolerance.deploy.client import get_frontend_port
        from tests.utils.client import wait_for_model_availability

        logger.info(
            f"Checking Frontend service health (after {service_name} pod restart)..."
        )

        pod_ports: dict[str, Any] = {}  # Temporary dict for port forward tracking

        try:
            logger.info("Getting frontend pod and setting up port forward...")
            # The returned pod object itself is not needed here.
            frontend_pod_name, local_port, _frontend_pod = get_frontend_port(
                managed_deployment=deployment,
                client_index=0,  # Use first frontend pod
                deployment_spec=deployment.deployment_spec,
                pod_ports=pod_ports,
                logger=logger,
            )

            if not frontend_pod_name or not local_port:
                logger.warning("Failed to get frontend port forward")
                return f"{base_status}, Frontend port forward failed"

            # Get model from deployment spec
            model = self._get_model_from_deployment_spec(deployment, service_name)
            endpoint = getattr(
                deployment.deployment_spec, "_endpoint", "/v1/chat/completions"
            )

            logger.info(
                f"Checking model '{model}' availability at localhost:{local_port}..."
            )
            url = f"http://localhost:{local_port}"
            service_healthy = wait_for_model_availability(
                url=url,
                endpoint=endpoint,
                model=model,
                logger=logger,
            )

            if service_healthy:
                logger.info("Frontend service health check passed")
                return f"{base_status}, Frontend healthy"

            logger.warning("Frontend service health check failed")
            return f"{base_status}, Frontend health check failed"

        except Exception as e:
            # Best-effort check: report the problem in the status string
            # instead of failing the whole failure injection.
            logger.exception(f"Error checking Frontend health: {e}")
            return f"{base_status}, Frontend health check error"
        finally:
            # Clean up port forwards
            for port_forward in pod_ports.values():
                try:
                    port_forward.stop()
                except Exception as e:
                    logger.warning(f"Error stopping port forward: {e}")

    def _get_model_from_deployment_spec(
        self,
        deployment,
        service_name: str,
    ) -> str:
        """Resolve the model name to use for the frontend health check.

        The model configured on the terminated service is used when it is set.
        If the service is not in the spec, has no ``model`` attribute, or the
        value is empty, the hard-coded default is used.

        Args:
            deployment: ManagedDeployment instance
            service_name: Name of the service that was terminated

        Returns:
            Model name (always returns a value, uses default as fallback)
        """
        logger.info(f"Attempting to get model from terminated service '{service_name}'")
        try:
            if model := deployment.deployment_spec[service_name].model:
                logger.info(
                    f"Got model '{model}' from terminated service '{service_name}'"
                )
                return model
        except (KeyError, AttributeError) as e:
            logger.info(f"Could not get model from {service_name}: {e}")

        # Fallback to default
        default_model = "Qwen/Qwen3-0.6B"
        logger.info(f"Using default model: {default_model}")
        return default_model

654
655
656
657
658
659
660
661
    async def execute(
        self, deployment: ManagedDeployment, logger: logging.Logger
    ) -> list[str]:
        """Execute process termination failure injection.

        For every pod of ``service_names``: log diagnostics, send
        ``self.signal`` to each process whose command contains
        ``self.process_name``, wait for the pod to become ready again, check
        Frontend health if it did, and log diagnostics once more.

        NOTE(review): the readiness wait and the health check are blocking
        calls made from this coroutine - confirm nothing else needs the event
        loop in the meantime.

        Args:
            deployment: The managed deployment to inject the failure into
            logger: Logger used for the termination and wait messages (the
                helper methods log through the module-level logger)

        Returns:
            Names of all pods that were processed, whether or not they
            recovered within the wait.
        """
        # One source of truth for the wait bound, so the log messages cannot
        # drift from the value actually passed to the wait helper.
        max_wait = 120

        service_pod_dict = deployment.get_pods(self.service_names)
        pod_names: list[str] = []
        for service_name, pods in service_pod_dict.items():
            for pod in pods:
                # Log diagnostics before termination
                self._log_pod_diagnostics(pod, "BEFORE PROCESS TERMINATION")

                for process in deployment.get_processes(pod):
                    # Substring match against the full command line.
                    if self.process_name in process.command:
                        logger.info(
                            f"Terminating {service_name} pod {pod} Pid {process.pid} Command {process.command}"
                        )
                        process.kill(self.signal)

                # Wait for pod to recover after process termination
                logger.info(
                    f"\nWaiting for pod '{pod.name}' to become ready (max {max_wait}s)..."
                )
                elapsed = self._wait_for_pod_ready(pod, max_wait=max_wait)

                if elapsed is None:
                    # Timed out: skip the Frontend check, still record the pod.
                    restart_status = f"timeout after {max_wait}s"
                else:
                    # Check Frontend service health after pod is ready
                    restart_status = self._check_frontend_health_after_restart(
                        deployment=deployment,
                        service_name=service_name,
                        base_status=f"ready after {elapsed}s",
                    )

                self._log_pod_diagnostics(pod, f"AFTER RESTART ({restart_status})")
                pod_names.append(pod.name)

        return pod_names

    def get_failure_key(self) -> str:
        """Get the failure key for the terminate process failure."""
        return f"terminate_process:{','.join(self.service_names)}:{self.process_name}:{self.signal}"
701
702


703
704
705
706
707
708
709
710
711
712
713
714
715
716
717
718
719
@dataclass
class TokenOverflowFailure(Failure):
    """
    Failure type for injecting token overflow (prompt > max_seq_len)

    The failure always targets the pseudo-service "Client". The overflow
    prompt size is stored in ``overflow_token_count``, computed as
    ``int(max_seq_len * overflow_multiplier)``.
    """

    overflow_multiplier: float = 2.0  # How much to exceed max_seq_len (e.g., 2.0 = 2x)
    max_seq_len: int = 1024

    # @dataclass does not replace an __init__ defined in the class body, so
    # this hand-written constructor is the one that runs.
    def __init__(
        self,
        time: int,
        max_seq_len: int = 1024,
        overflow_multiplier: float = 2.0,
    ):
        """Initialize TokenOverflowFailure.

        Args:
            time: Time to wait in seconds before the failure is injected
            max_seq_len: Sequence length limit the prompt should exceed
            overflow_multiplier: Factor applied to ``max_seq_len`` to size the
                overflow prompt
        """
        super().__init__(
            time=time,
            service_names=["Client"],
        )
        self.max_seq_len = max_seq_len
        self.overflow_multiplier = overflow_multiplier
        self.overflow_token_count = int(max_seq_len * overflow_multiplier)

    async def execute(
        self, deployment: ManagedDeployment, logger: logging.Logger
    ) -> list[str]:
        """Token overflow is handled client-side, so this is a no-op.

        Returns:
            Always an empty list: no pods are affected.
        """
        # The actual overflow is handled by the client configuration
        # which uses the input_token_length from the Load config
        # This is just a placeholder for the abstract method
        return []

    def get_failure_key(self) -> str:
        """Get the failure key for the token overflow failure."""
        return f"token_overflow:{self.overflow_token_count}"

739

740
741
742
743
744
745
@dataclass
class Scenario:
    """A fault tolerance test scenario: a deployment, a load, and failures to inject."""

    deployment: DeploymentSpec
    load: Load
    failures: list[Failure]
    model: Optional[str] = None
    backend: str = "vllm"  # Backend type for tracking

    # When set to True, the test will be automatically marked with @pytest.mark.custom_build
    # and excluded from default test runs unless --include-custom-build flag is used
    requires_custom_build: bool = False  # Flag for tests needing custom builds/setup

    # List of checkers to run for validation (scenario + results checkers)
    # If None, factory will determine checkers based on scenario name and deployment
    checkers: Optional[List["BaseChecker"]] = field(default=None)
753
754


755
# Helper functions to create deployment specs
756
def _create_deployment_info(backend: str, yaml_path: str) -> DeploymentInfo:
    """Create a deployment spec with backend information.

    Args:
        backend: Backend type ("vllm", "sglang", or "trtllm")
        yaml_path: Path to the deployment YAML file

    Returns:
        DeploymentInfo dictionary with spec and backend (the optional
        ``model`` and ``is_moe`` keys are not set)
    """
    return DeploymentInfo(spec=DeploymentSpec(yaml_path), backend=backend)
767
768


769
770
771
772
773
774
775
776
def _set_replicas(
    deployment_spec: DeploymentInfo, backend: str, deploy_type: str, replicas: int
) -> None:
    """Set replicas for all components in a deployment based on backend type.

    The Frontend is always scaled. For backends listed in WORKER_MAP the decode
    worker is scaled too, and the prefill worker only for "disagg"
    deployments. Unknown backends get the Frontend change only.

    Args:
        deployment_spec: DeploymentInfo whose ``spec`` is modified in place
        backend: Backend type, used to look up worker names in WORKER_MAP
        deploy_type: "agg" or "disagg"
        replicas: Replica count to assign
    """
    spec = deployment_spec["spec"]

    # Frontend is common for all backends
    spec["Frontend"].replicas = replicas

    if backend not in WORKER_MAP:
        return

    workers = WORKER_MAP[backend]

    # For trtllm agg deployments, use different worker name
    if backend == "trtllm" and deploy_type == "agg":
        decode_worker = workers["decode_agg"]
    else:
        decode_worker = workers["decode"]

    # always scale decode
    spec[decode_worker].replicas = replicas

    # scale prefill only for disagg
    if deploy_type == "disagg":
        spec[workers["prefill"]].replicas = replicas


790
791
792
def _set_tensor_parallel(
    deployment_spec: DeploymentInfo, backend: str, deploy_type: str, tp_size: int
):
    """Set tensor parallel size for worker components.

    Does nothing for backends that are not in WORKER_MAP, or for a
    ``deploy_type`` other than "agg" / "disagg".

    Args:
        deployment_spec: DeploymentInfo whose ``spec`` is modified in place
        backend: Backend type, used to look up worker names in WORKER_MAP
        deploy_type: "agg" (decode worker only) or "disagg" (prefill and decode)
        tp_size: Tensor parallel size to assign
    """
    spec = deployment_spec["spec"]

    if backend not in WORKER_MAP:
        return

    workers = WORKER_MAP[backend]

    # For trtllm agg deployments, use different worker name
    if backend == "trtllm" and deploy_type == "agg":
        decode_worker = workers["decode_agg"]
    else:
        decode_worker = workers["decode"]

    prefill_worker = workers["prefill"]

    if deploy_type == "agg":
        # Prefer the spec's own setter when it offers one; otherwise fall back
        # to assigning the attribute on the decode service directly.
        if hasattr(spec, "set_tensor_parallel"):
            spec.set_tensor_parallel(tp_size, [decode_worker])
        else:
            spec[decode_worker].tensor_parallel_size = tp_size
    elif deploy_type == "disagg":
        spec[prefill_worker].tensor_parallel_size = tp_size
        spec[decode_worker].tensor_parallel_size = tp_size


814
815
816
817
818
819
820
821
822
823
def _create_deployments_for_backend(backend: str) -> Dict[str, DeploymentInfo]:
    """Build the agg and disagg deployment variants for one backend.

    Each variant is created from the backend's example YAML and then adjusted
    for a (tensor parallel, replica) combination. Entries are named
    ``<backend>-agg-tp-<tp>-dp-<dp>`` or
    ``<backend>-disagg-prefill-tp-<tp>-decode-tp-<tp>-dp-<dp>``.

    Args:
        backend: Backend type ("vllm", "sglang", or "trtllm")

    Returns:
        Dictionary mapping deployment names to DeploymentInfo objects
    """
    # Example YAML per deployment type; iteration order is agg, then disagg.
    yaml_by_type = {
        "agg": f"examples/backends/{backend}/deploy/agg.yaml",
        "disagg": f"examples/backends/{backend}/deploy/disagg.yaml",
    }

    # (tensor parallel size, replica count) combinations under test.
    tp_dp_grid = ((1, 1), (1, 2), (2, 1), (4, 1))

    result: Dict[str, DeploymentInfo] = {}

    for deploy_type, yaml_path in yaml_by_type.items():
        for tp_size, dp_replicas in tp_dp_grid:
            # Disagg with both TP > 1 and DP > 1 is an uncommon case; skip it.
            if deploy_type == "disagg" and tp_size > 1 and dp_replicas > 1:
                continue

            if deploy_type == "agg":
                tp_label = f"tp-{tp_size}"
            else:
                tp_label = f"prefill-tp-{tp_size}-decode-tp-{tp_size}"

            scenario_name = "-".join(
                [backend, deploy_type, tp_label, f"dp-{dp_replicas}"]
            )

            # Start from the stock spec and only touch it for non-default sizes.
            info = _create_deployment_info(backend, yaml_path)
            if tp_size > 1:
                _set_tensor_parallel(info, backend, deploy_type, tp_size)
            if dp_replicas > 1:
                _set_replicas(info, backend, deploy_type, dp_replicas)

            result[scenario_name] = info

    return result


871
872
873
874
875
876
877
878
879
880
881
882
def _create_moe_deployments_for_backend(
    backend: str = "vllm",
) -> Dict[str, DeploymentInfo]:
    """Build the MoE deployment variants (DeepSeek-V2-Lite) for a backend.

    One agg and one disagg deployment are produced from the MoE templates
    under ``tests/fault_tolerance/deploy/templates/<backend>/``. Both are
    flagged with ``is_moe=True`` and carry the DeepSeek-V2-Lite model name.

    Args:
        backend: Backend type (default: "vllm")

    Returns:
        Dictionary mapping deployment names to DeploymentInfo objects
    """
    # Only tp=1, dp=2 is covered for now. These values are used for naming
    # only: the replica count is handled internally by vLLM with
    # --data-parallel-size rather than being applied to the spec here.
    tp_size = 1
    dp_replicas = 2

    template_dir = "tests/fault_tolerance/deploy/templates"
    yaml_files = {
        "agg": f"{template_dir}/{backend}/moe_agg.yaml",
        "disagg": f"{template_dir}/{backend}/moe_disagg.yaml",
    }

    return {
        f"{backend}-moe-{deploy_type}-tp-{tp_size}-dp-{dp_replicas}": DeploymentInfo(
            spec=DeploymentSpec(yaml_files[deploy_type]),
            backend=backend,
            model="deepseek-ai/DeepSeek-V2-Lite",
            is_moe=True,
        )
        for deploy_type in ("agg", "disagg")
    }


910
# Registry of every deployment under test, keyed by deployment name.
# Insertion order is vllm, sglang, trtllm, then the vLLM-only MoE deployments.
DEPLOYMENT_SPECS: Dict[str, DeploymentInfo] = {
    **_create_deployments_for_backend("vllm"),
    **_create_deployments_for_backend("sglang"),
    **_create_deployments_for_backend("trtllm"),
    # MoE deployments are added for vLLM only
    **_create_moe_deployments_for_backend("vllm"),
}
918

919
920
921
922
923
924
925
926

# Each failure scenario contains a list of failure injections.
# Each failure injection has a time in seconds after the previous injection and
# a list of failures to inject, including the number of failures for each type.
# Failures are currently process termination or pod deletion.
#
# Example:
#
#   "prefill_worker": [Failure(30, "VllmPrefillWorker", "dynamo.vllm", "SIGKILL")],
#
# terminates 1 prefill worker after 30 seconds.
930
931
932
933
934
935
936
def _create_backend_failures(backend, deploy_type="disagg"):
    """Build the named failure injections available for one backend.

    The result always contains the generic entries (frontend, decode worker
    and prefill worker process kills and pod deletions, plus an empty "none"
    entry), followed by backend-specific entries that kill the backend's
    internal engine processes.

    Args:
        backend: Backend type (vllm, sglang, trtllm)
        deploy_type: Deployment type (agg or disagg)

    Returns:
        Dictionary mapping a failure name to its list of failure injections.
    """
    worker_names = WORKER_MAP[backend]

    # trtllm uses a dedicated worker name for aggregated deployments.
    uses_agg_name = backend == "trtllm" and deploy_type == "agg"
    decode_worker = worker_names["decode_agg" if uses_agg_name else "decode"]
    prefill_worker = worker_names["prefill"]
    process_name = f"dynamo.{backend}"

    def _sigkill(worker, target_process):
        # Single SIGKILL of `target_process` in `worker`, 30 seconds in.
        return [
            TerminateProcessFailure(
                30, [worker], "SIGKILL", process_name=target_process
            )
        ]

    failures = {
        "frontend": [
            TerminateProcessFailure(
                30, ["Frontend"], "SIGINT", process_name="dynamo.frontend"
            )
        ],
        "frontend_pod": [DeletePodFailure(30, ["Frontend"])],
        "decode_worker": _sigkill(decode_worker, process_name),
        "decode_worker_pod": [DeletePodFailure(30, [decode_worker])],
        "prefill_worker": _sigkill(prefill_worker, process_name),
        "prefill_worker_pod": [DeletePodFailure(30, [prefill_worker])],
        "none": [],
    }

    # Backend-specific process kills: (failure name, worker, process name).
    engine_kills = {
        "vllm": (
            ("vllm_decode_engine_core", decode_worker, "VLLM::EngineCore"),
            ("vllm_prefill_engine_core", prefill_worker, "VLLM::EngineCore"),
        ),
        "sglang": (
            ("sglang_decode_scheduler", decode_worker, "sglang::scheduler"),
            ("sglang_decode_detokenizer", decode_worker, "sglang::detokenizer"),
            ("sglang_prefill_scheduler", prefill_worker, "sglang::scheduler"),
            ("sglang_prefill_detokenizer", prefill_worker, "sglang::detokenizer"),
        ),
        "trtllm": (
            ("trtllm_decode_engine_core", decode_worker, "TRTLLM::EngineCore"),
            ("trtllm_prefill_engine_core", prefill_worker, "TRTLLM::EngineCore"),
        ),
    }

    for failure_name, worker, target_process in engine_kills.get(backend, ()):
        failures[failure_name] = _sigkill(worker, target_process)

    return failures
1015
1016


1017
1018
1019
1020
1021
1022
1023
1024
def create_aiperf_load(
    clients: int = 10,
    requests_per_client: int = 150,
    input_token_length: int = 100,
    output_token_length: int = 100,
    max_retries: int = 3,
    sla: Optional[float] = None,
    max_request_rate: float = 1.0,
    success_threshold: float = 90.0,
) -> Load:
    """Build a Load whose ``client_type`` is fixed to "aiperf".

    All arguments are forwarded unchanged to ``Load``.

    Args:
        clients: Number of concurrent clients (default: 10)
        requests_per_client: Number of requests per client (default: 150)
        input_token_length: Input token count (default: 100)
        output_token_length: Output token count (default: 100)
        max_retries: Maximum retry attempts - AI-Perf retries entire test (default: 3)
        sla: Optional SLA threshold for latency (default: None)
        max_request_rate: Rate limiting for requests/sec (default: 1.0)
        success_threshold: Success rate threshold for pass/fail (default: 90.0)

    Returns:
        Load instance configured for AI-Perf client

    Example:
        >>> load = create_aiperf_load(clients=20, requests_per_client=200)
    """
    return Load(
        client_type="aiperf",
        clients=clients,
        requests_per_client=requests_per_client,
        max_request_rate=max_request_rate,
        input_token_length=input_token_length,
        output_token_length=output_token_length,
        max_retries=max_retries,
        sla=sla,
        success_threshold=success_threshold,
    )


def create_legacy_load(
    clients: int = 10,
    requests_per_client: int = 100,
    input_token_length: int = 100,
    output_token_length: int = 100,
    max_retries: int = 1,
    sla: Optional[float] = None,
    max_request_rate: float = 1.0,
    success_threshold: float = 90.0,
) -> Load:
    """Build a Load whose ``client_type`` is fixed to "legacy".

    All arguments are forwarded unchanged to ``Load``.

    Args:
        clients: Number of concurrent clients (default: 10)
        requests_per_client: Number of requests per client (default: 100, fewer than AI-Perf)
        input_token_length: Input token count (default: 100)
        output_token_length: Output token count (default: 100)
        max_retries: Maximum retry attempts - legacy retries per request (default: 1)
        sla: Optional SLA threshold for latency (default: None)
        max_request_rate: Rate limiting for requests/sec (default: 1.0)
        success_threshold: Success rate threshold for pass/fail (default: 90.0)

    Returns:
        Load instance configured for legacy client

    Example:
        >>> load = create_legacy_load(clients=10, max_request_rate=2.0)
    """
    return Load(
        client_type="legacy",
        clients=clients,
        requests_per_client=requests_per_client,
        max_request_rate=max_request_rate,
        input_token_length=input_token_length,
        output_token_length=output_token_length,
        max_retries=max_retries,
        sla=sla,
        success_threshold=success_threshold,
    )


# Default load configuration: Load() with all of its default field values
# (described as the AI-Perf configuration by the original comment here).
load = Load()

# Lighter load used for MoE deployments: fewer clients, fewer requests per
# client and a lower request rate than the factory defaults above.
moe_load = Load(
    client_type="aiperf",
    clients=3,  # Fewer clients for MoE testing
    requests_per_client=30,  # Reduced for MoE complexity
    max_request_rate=0.5,  # Lower rate for MoE
    input_token_length=100,
    output_token_length=100,
    max_retries=3,
    sla=None,
)

# Fallback model for scenarios whose deployment info carries no "model" entry.
# Previously used value: "deepseek-ai/DeepSeek-R1-Distill-Llama-8B"
model = None

# Populate Scenarios
#
# One scenario is registered per (deployment, failure) pair, named
# "<deployment_name>-<failure_name>".

scenarios: dict[str, Scenario] = {}

# Failure definitions keyed by "<backend>_<deploy_type>"
backend_failure_map = {}
for backend in ["vllm", "sglang", "trtllm"]:
    backend_failure_map.update(
        {
            f"{backend}_agg": _create_backend_failures(backend, "agg"),
            f"{backend}_disagg": _create_backend_failures(backend, "disagg"),
        }
    )

for deployment_name, deployment_info in DEPLOYMENT_SPECS.items():
    backend = deployment_info["backend"]

    # MoE deployments get their own load profile and need a custom build.
    is_moe = deployment_info.get("is_moe", False)

    # The deployment type is encoded in the deployment name. Only a name that
    # contains "agg" without "disagg" counts as aggregated.
    deploy_type = (
        "disagg"
        if ("agg" not in deployment_name or "disagg" in deployment_name)
        else "agg"
    )

    # Pick the failure set matching this backend and deployment type.
    failure_map_key = f"{backend}_{deploy_type}"
    if failure_map_key not in backend_failure_map:
        raise ValueError(
            f"Unsupported backend+deploy_type: {failure_map_key}. Available: {list(backend_failure_map.keys())}"
        )
    failure_set = backend_failure_map[failure_map_key]

    # Both values depend on the deployment only, not on the failure.
    load_config = moe_load if is_moe else load
    scenario_model = deployment_info.get("model", model)

    for failure_name, failure in failure_set.items():
        # Aggregated deployments have no separate prefill worker to fail.
        if deploy_type == "agg" and "prefill" in failure_name:
            continue

        scenario_name = f"{deployment_name}-{failure_name}"

        # The checker factory takes the scenario itself, so the scenario is
        # built without checkers first and they are attached afterwards.
        scenario = Scenario(
            deployment=deployment_info["spec"],
            load=load_config,
            failures=failure,
            model=scenario_model,
            backend=backend,
            checkers=None,
            requires_custom_build=is_moe,
        )
        scenario.checkers = _get_checkers_for_scenario(scenario_name, scenario)

        scenarios[scenario_name] = scenario

1182
1183
1184
1185
1186
1187
1188
1189
1190
1191
1192
1193
1194
1195
1196
1197
1198
1199
1200
1201
1202
1203
1204
1205
1206
1207
1208
1209
1210
1211
1212
1213
1214
1215
1216
1217
1218
1219
1220
1221
1222
1223
1224
1225
1226
1227
1228
1229
1230
1231
1232
1233
1234
1235
1236
1237
1238
1239
1240
1241
1242
1243
1244
1245
1246
1247
1248
1249
1250
1251
1252
1253
1254
1255
1256
1257
1258
1259
1260
1261
1262
1263
1264
1265
1266
1267
1268
1269
1270
1271
1272
1273
1274
1275
1276
1277
1278
1279
1280
1281
1282
1283
1284
1285
1286
1287
1288
1289
1290
1291
1292
1293
1294
1295
1296
1297
1298
1299
1300
1301
1302
1303
1304
1305
1306
1307
1308
1309
1310
1311
1312
1313
1314
1315

# Add token overflow test scenarios
def add_token_overflow_scenarios():
    """Register scenarios for token overflow (prompt > max_seq_len) failures.

    For every entry of the config table whose deployment exists in
    ``DEPLOYMENT_SPECS``, this caps the workers' sequence length at
    ``MAX_SEQ_LEN`` through the backend-specific CLI argument, builds a
    ``TokenOverflowFailure`` plus a mixed overflow/normal ``Load``, and stores
    the resulting ``Scenario`` in the module-level ``scenarios`` dict under
    the configured name.

    The sequence-length argument is added to a private deep copy of the
    deployment spec. The spec objects held in ``DEPLOYMENT_SPECS`` are also
    referenced by the regular failure scenarios, so mutating them in place
    would apply the sequence-length limit to those scenarios as well.
    """
    # Local import: only this function needs it.
    import copy

    overflow_test_configs = [
        # vLLM tests
        {
            "name": "vllm_agg_token_overflow_2x",
            "deployment_key": "vllm-agg-tp-1-dp-1",
            "backend": "vllm",
        },
        {
            "name": "vllm_disagg_token_overflow_2x",
            "deployment_key": "vllm-disagg-prefill-tp-2-decode-tp-2-dp-1",
            "backend": "vllm",
        },
        # TRT-LLM tests
        {
            "name": "trtllm_agg_token_overflow_2x",
            "deployment_key": "trtllm-agg-tp-1-dp-1",
            "backend": "trtllm",
        },
        {
            "name": "trtllm_disagg_token_overflow_2x",
            "deployment_key": "trtllm-disagg-prefill-tp-2-decode-tp-2-dp-1",
            "backend": "trtllm",
        },
        # SGLang tests
        {
            "name": "sglang_agg_token_overflow_2x",
            "deployment_key": "sglang-agg-tp-1-dp-1",
            "backend": "sglang",
        },
        {
            "name": "sglang_disagg_token_overflow_2x",
            "deployment_key": "sglang-disagg-prefill-tp-2-decode-tp-2-dp-1",
            "backend": "sglang",
        },
    ]

    # Common configuration for all tests
    MAX_SEQ_LEN = 1024
    OVERFLOW_MULTIPLIER = 2.0
    OVERFLOW_REQUESTS = 15  # Number of oversized requests to send
    NORMAL_REQUESTS = 15  # Number of normal requests to send after overflow

    # CLI argument that limits the sequence length; vllm is the fallback.
    seq_len_arg_by_backend = {
        "trtllm": "--max-seq-len",
        "sglang": "--context-length",
    }

    for config in overflow_test_configs:
        # Skip if deployment doesn't exist
        if config["deployment_key"] not in DEPLOYMENT_SPECS:
            continue

        overflow_scenario_name = config["name"]
        deployment_info = DEPLOYMENT_SPECS[config["deployment_key"]]

        scenario_model = deployment_info.get("model", model)

        # Work on a copy so the shared spec in DEPLOYMENT_SPECS (and the
        # scenarios already built from it) keeps its original arguments.
        deployment_spec = copy.deepcopy(deployment_info["spec"])

        backend = config["backend"]
        # If not disaggregated, then it's aggregated
        is_agg = "disagg" not in config["deployment_key"]

        workers = WORKER_MAP[backend]

        # trtllm uses a dedicated worker name for aggregated deployments.
        if backend == "trtllm" and is_agg:
            decode_worker = workers["decode_agg"]
        else:
            decode_worker = workers["decode"]

        prefill_worker = workers["prefill"]

        arg_name = seq_len_arg_by_backend.get(backend, "--max-model-len")

        # Aggregated: limit the decode worker only.
        # Disaggregated: limit both prefill and decode workers.
        limited_workers = (
            [decode_worker] if is_agg else [prefill_worker, decode_worker]
        )
        for worker in limited_workers:
            deployment_spec.add_arg_to_service(worker, arg_name, str(MAX_SEQ_LEN))

        # Create overflow failure
        overflow_failure = TokenOverflowFailure(
            time=30,  # Start after 30 seconds
            max_seq_len=MAX_SEQ_LEN,
            overflow_multiplier=OVERFLOW_MULTIPLIER,
        )

        # Create mixed load configuration for overflow + recovery testing
        overflow_tokens = int(MAX_SEQ_LEN * OVERFLOW_MULTIPLIER)
        normal_tokens = 512  # Well within MAX_SEQ_LEN

        # Total requests = overflow + normal
        total_requests = OVERFLOW_REQUESTS + NORMAL_REQUESTS

        # Mixed load that tests both rejection and recovery
        mixed_load = Load(
            clients=3,
            requests_per_client=total_requests,
            input_token_length=normal_tokens,
            output_token_length=50,
            # Mixed token test configuration
            mixed_token_test=True,
            overflow_token_length=overflow_tokens,
            overflow_request_count=OVERFLOW_REQUESTS,
            normal_request_count=NORMAL_REQUESTS,
        )

        scenarios[overflow_scenario_name] = Scenario(
            deployment=deployment_spec,
            load=mixed_load,
            failures=[overflow_failure],
            model=scenario_model,
            backend=backend,
        )


1316
1317
1318
1319
1320
1321
1322
1323
1324
1325
1326
1327
1328
1329
1330
1331
1332
1333
1334
1335
1336
1337
1338
1339
1340
1341
1342
1343
1344
1345
1346
1347
1348
1349
1350
1351
1352
1353
1354
1355
1356
1357
1358
1359
1360
1361
1362
1363
1364
1365
1366
def add_rolling_upgrade_scenarios():
    """Register one rolling-upgrade scenario per backend and worker mode.

    Each scenario gets a freshly created deployment spec in which the worker
    services being upgraded run with two replicas, a continuous AI-Perf load
    with a 100% success threshold, and a single ``RollingUpgradeFailure``
    targeting those services. Scenarios are stored in the module-level
    ``scenarios`` dict as ``<backend>-<worker_mode>-rolling-upgrade``.
    """
    for backend in ["vllm", "sglang", "trtllm"]:
        yaml_files = {
            "agg": f"examples/backends/{backend}/deploy/agg.yaml",
            "disagg": f"examples/backends/{backend}/deploy/disagg.yaml",
        }
        worker_names = WORKER_MAP[backend]

        for worker_mode in ["agg", "disagg"]:
            deployment_info = _create_deployment_info(backend, yaml_files[worker_mode])
            deployment_spec: DeploymentSpec = deployment_info["spec"]

            # trtllm uses a dedicated worker name for aggregated deployments.
            uses_agg_name = worker_mode == "agg" and backend == "trtllm"
            service_names: list[str] = [
                worker_names["decode_agg" if uses_agg_name else "decode"]
            ]
            if worker_mode == "disagg":
                service_names.append(worker_names["prefill"])

            # setting replicas to 2 so we have availability of 1 replica at a time
            for service_name in service_names:
                deployment_spec.set_service_replicas(service_name, 2)

            upgrade_load = Load(
                clients=10,
                input_token_length=100,
                output_token_length=100,
                max_retries=1,
                client_type="aiperf",
                max_request_rate=1.0,
                success_threshold=100.0,
                continuous_load=True,
            )

            upgrade_failure = RollingUpgradeFailure(
                time=30,
                service_names=service_names,
            )

            scenarios[f"{backend}-{worker_mode}-rolling-upgrade"] = Scenario(
                deployment=deployment_info["spec"],
                load=upgrade_load,
                failures=[upgrade_failure],
                model="Qwen/Qwen3-0.6B",
                backend=backend,
            )


1367
1368
# Register the extra scenario families on top of the generated ones:
# token overflow first, then rolling upgrade.
add_token_overflow_scenarios()
add_rolling_upgrade_scenarios()