scenarios.py 32.6 KB
Newer Older
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

16
17
import asyncio
import logging
18
import re
19
from abc import ABC, abstractmethod
20
from dataclasses import dataclass, field
21
from enum import Enum, auto
22
from typing import TYPE_CHECKING, Dict, List, Optional, Pattern
23

24
from typing_extensions import Required, TypedDict
25

26
from tests.utils.managed_deployment import DeploymentSpec, ManagedDeployment
27

28
29
30
31
32
33
34
35
36
37
38
39
40
if TYPE_CHECKING:
    from tests.fault_tolerance.deploy.base_checker import BaseChecker


# The checker factory is imported inside the function body (not at module
# import time) so that importing this module cannot trigger a circular import.
def _get_checkers_for_scenario(
    scenario_name: str, scenario: "Scenario"
) -> List["BaseChecker"]:
    """Return the validation checkers chosen by the checker factory.

    Args:
        scenario_name: Name of the scenario being built.
        scenario: The Scenario the checkers will validate.

    Returns:
        Whatever ``checker_factory.get_checkers_for_scenario`` returns for
        this scenario.
    """
    # Deferred import: resolved on first call, after module initialization.
    from tests.fault_tolerance.deploy.checker_factory import (
        get_checkers_for_scenario as factory_get_checkers,
    )

    checkers = factory_get_checkers(scenario_name, scenario)
    return checkers

41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59

class TestPhase(Enum):
    """Phases of a fault-tolerance test run.

    The lower-cased member names are used to build the ``OVERFLOW_SUFFIX`` and
    ``RECOVERY_SUFFIX`` module constants.
    """

    # Values spelled out explicitly; identical to what ``auto()`` assigns.
    STANDARD = 1
    OVERFLOW = 2
    RECOVERY = 3


class DeploymentInfo(TypedDict, total=False):
    """Information about a deployment configuration.

    Attributes:
        spec: DeploymentSpec object defining the deployment configuration (required)
        backend: Backend type - "vllm", "sglang", or "trtllm" (required)
        model: Optional model identifier (e.g., "deepseek-ai/DeepSeek-V2-Lite")
        is_moe: Optional flag indicating if this is a Mixture-of-Experts model
    """

    # The TypedDict is declared total=False, so the mandatory keys are marked
    # explicitly with Required[...].
    spec: Required[DeploymentSpec]
    backend: Required[str]
    # Optional keys; readers in this module use .get(...) with a fallback.
    model: str
    is_moe: bool


# Suffixes identifying the overflow / recovery test phases, built from the
# lower-cased TestPhase member names ("_overflow", "_recovery").
OVERFLOW_SUFFIX = "_" + TestPhase.OVERFLOW.name.lower()
RECOVERY_SUFFIX = "_" + TestPhase.RECOVERY.name.lower()

70
# Worker name mapping for different backends: backend -> {role: DGD service name}.
# trtllm aggregated deployments name their single worker differently, hence
# the extra "decode_agg" role for that backend.
WORKER_MAP = {
    "vllm": {
        "decode": "VllmDecodeWorker",
        "prefill": "VllmPrefillWorker",
    },
    "sglang": {
        "decode": "decode",
        "prefill": "prefill",
    },
    "trtllm": {
        "decode": "TRTLLMDecodeWorker",
        "decode_agg": "TRTLLMWorker",  # Aggregated uses different name
        "prefill": "TRTLLMPrefillWorker",
    },
}

87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
# Process ready patterns for recovery detection: worker type -> compiled regex.
# Presumably searched against a worker's log output to detect that it is
# (again) ready — confirm against the callers of get_worker_ready_pattern().
WORKER_READY_PATTERNS: Dict[str, Pattern] = {
    # Frontend
    "Frontend": re.compile(r"added model"),
    # vLLM workers
    "VllmDecodeWorker": re.compile(
        r"VllmWorker for (?P<model_name>.*?) has been initialized"
    ),
    "VllmPrefillWorker": re.compile(
        r"VllmWorker for (?P<model_name>.*?) has been initialized"
    ),
    # SGLang workers - look for their specific initialization messages
    "decode": re.compile(
        r"Model registration succeeded|Decode worker handler initialized|Worker handler initialized"
    ),
    "prefill": re.compile(
        r"Model registration succeeded|Prefill worker handler initialized|Worker handler initialized"
    ),
    # TensorRT-LLM workers
    "TRTLLMWorker": re.compile(
        r"TrtllmWorker for (?P<model_name>.*?) has been initialized|Model registration succeeded"
    ),
    "TRTLLMDecodeWorker": re.compile(
        r"TrtllmWorker for (?P<model_name>.*?) has been initialized|Model registration succeeded"
    ),
    "TRTLLMPrefillWorker": re.compile(
        r"TrtllmWorker for (?P<model_name>.*?) has been initialized|Model registration succeeded"
    ),
}


def get_all_worker_types(
    worker_map: Optional[Dict[str, Dict[str, str]]] = None,
) -> list[str]:
    """Return every known worker type name, with ``"Frontend"`` first.

    Collects the worker names of every backend in ``worker_map`` (vLLM,
    SGLang and TRT-LLM with the default map) and drops duplicates while
    keeping first-seen order.

    Args:
        worker_map: Mapping of backend -> {role: worker name} to read from.
            Defaults to the module-level ``WORKER_MAP``.

    Returns:
        De-duplicated list of worker type names.
    """
    if worker_map is None:
        worker_map = WORKER_MAP
    names = ["Frontend"]
    for role_to_worker in worker_map.values():
        names.extend(role_to_worker.values())
    # dicts preserve insertion order, so fromkeys() de-duplicates stably.
    return list(dict.fromkeys(names))


def get_worker_ready_pattern(worker_name: str) -> Optional[Pattern]:
    """Look up the ready regex registered for *worker_name*.

    Returns:
        The compiled pattern from ``WORKER_READY_PATTERNS``, or None when no
        pattern is registered for that worker type.
    """
    try:
        return WORKER_READY_PATTERNS[worker_name]
    except KeyError:
        return None


def get_backend_workers(backend: str) -> Dict[str, str]:
    """Return the role -> worker-name mapping for *backend*.

    Unknown backends yield a new empty dict. For known backends the entry
    stored in ``WORKER_MAP`` itself is returned (not a copy).
    """
    if backend in WORKER_MAP:
        return WORKER_MAP[backend]
    return {}

142
143
144
145
146
147
148

@dataclass
class Load:
    """Client load configuration used by a fault-tolerance scenario."""

    clients: int = 10
    requests_per_client: int = 150
    input_token_length: int = 100
    output_token_length: int = 100
    max_retries: int = 3  # Increased for fault tolerance
    sla: Optional[float] = None
    client_type: str = "aiperf"  # "aiperf" or "legacy"
    max_request_rate: float = 1.0  # Request rate limit (requests/sec)
    success_threshold: float = 90.0  # Success rate threshold for tests

    # For mixed token testing (overflow + recovery)
    mixed_token_test: bool = False
    overflow_token_length: Optional[int] = None  # Tokens for overflow requests
    overflow_request_count: int = 15  # Number of overflow requests
    normal_request_count: int = 15  # Number of normal requests after overflow

    # If True, use continuous load instead of a fixed request count
    continuous_load: bool = False

165
166

@dataclass
class Failure(ABC):
    """Base class for all failure types.

    Subclasses implement ``execute`` to inject the failure into a managed
    deployment and ``get_failure_key`` to return a string identifying it.
    """

    # Time to wait in seconds before the failure is injected.
    # NOTE: RollingUpgradeFailure.execute instead sleeps this long *after* the
    # rolling upgrade has completed.
    time: int

    # Names of DGD services whose corresponding pods the failure is injected into.
    service_names: list[str]

    @abstractmethod
    async def execute(
        self, deployment: ManagedDeployment, logger: logging.Logger
    ) -> list[str]:
        """Execute the failure injection.

        Args:
            deployment: The managed deployment to inject the failure into
            logger: Logger instance for logging failure injection

        Returns:
            List of affected pod names
        """
        pass

    @abstractmethod
    def get_failure_key(self) -> str:
        """Return a string key identifying this failure."""
        pass


@dataclass
class RollingUpgradeFailure(Failure):
    """Inject a failure by triggering a rolling upgrade of ``service_names``."""

    async def execute(
        self, deployment: ManagedDeployment, logger: logging.Logger
    ) -> list[str]:
        """Roll the target services and wait until the deployment recovers.

        Steps: trigger the upgrade; wait (up to 60s) for the deployment to
        become unready, which shows the upgrade has actually started; wait (up
        to 30 minutes) for it to be ready again; then sleep ``self.time``
        seconds so some requests are processed after the upgrade.

        Returns:
            Pod names of ``service_names`` queried after the upgrade.
        """
        targets = self.service_names
        await deployment.trigger_rolling_upgrade(targets)

        # Seeing the deployment go unready proves the rolling upgrade began.
        await deployment.wait_for_unready(timeout=60, log_interval=10)

        ready_timeout_s = 30 * 60  # 30 minute timeout
        await deployment._wait_for_ready(timeout=ready_timeout_s)

        # Have some requests processed after the rolling upgrade has completed.
        await asyncio.sleep(self.time)

        return await deployment.get_pod_names(targets)

    def get_failure_key(self) -> str:
        """Return ``"rolling_upgrade:<svc1>,<svc2>,..."``."""
        joined = ",".join(self.service_names)
        return f"rolling_upgrade:{joined}"


@dataclass
class DeletePodFailure(Failure):
    """Inject a failure by force-deleting every pod of the target services."""

    async def execute(
        self, deployment: ManagedDeployment, logger: logging.Logger
    ) -> list[str]:
        """Collect per-pod diagnostics, then force-delete each target pod.

        Returns:
            Names of the deleted pods, in iteration order.
        """
        deleted: list[str] = []
        pods_by_service = deployment.get_pods(self.service_names)
        for svc, svc_pods in pods_by_service.items():
            for target_pod in svc_pods:
                # Helper name suggests it saves manifest/logs/metrics; tagged
                # ".before_delete" so they are distinguishable afterwards.
                deployment.get_pod_manifest_logs_metrics(
                    svc, target_pod, ".before_delete"
                )
                # force=True means no graceful termination
                target_pod.delete(force=True)
                deleted.append(target_pod.name)
        return deleted

    def get_failure_key(self) -> str:
        """Return ``"delete_pod:<svc1>,<svc2>,..."``."""
        return "delete_pod:" + ",".join(self.service_names)


class TerminateProcessFailure(Failure):
    """Failure type for terminating specific processes by name.

    In every pod of ``service_names``, each process whose command contains
    ``process_name`` is sent ``signal``.
    """

    def __init__(
        self,
        time: int,
        service_names: list[str],
        signal: str = "SIGINT",
        process_name: str = "",
    ):
        """Initialize TerminateProcessFailure.

        Args:
            time: Time to wait in seconds before the failure is injected
            service_names: Names of DGD services to inject the failure into
            signal: Signal to send (default: "SIGINT")
            process_name: Substring matched against each process's command to
                select the process(es) to terminate (required)

        Raises:
            ValueError: If ``process_name`` or ``signal`` is empty.
        """
        super().__init__(
            time=time,
            service_names=service_names,
        )
        if not process_name or not signal:
            raise ValueError(
                "process_name and signal are required for TerminateProcessFailure"
            )
        self.process_name = process_name
        self.signal = signal

    async def execute(
        self, deployment: ManagedDeployment, logger: logging.Logger
    ) -> list[str]:
        """Send ``signal`` to matching processes in every target pod.

        Returns:
            Names of all target pods, including pods in which no matching
            process was found (a warning is logged for those).
        """
        service_pod_dict = deployment.get_pods(self.service_names)
        pod_names: list[str] = []
        for service_name, pods in service_pod_dict.items():
            for pod in pods:
                matched = False
                for process in deployment.get_processes(pod):
                    if self.process_name in process.command:
                        logger.info(
                            f"Terminating {service_name} pod {pod} Pid {process.pid} Command {process.command}"
                        )
                        process.kill(self.signal)
                        matched = True
                if not matched:
                    # Without this, a renamed/misspelled process would turn the
                    # injected failure into a silent no-op.
                    logger.warning(
                        f"No process matching '{self.process_name}' found in {service_name} pod {pod.name}"
                    )
                pod_names.append(pod.name)

        return pod_names

    def get_failure_key(self) -> str:
        """Get the failure key for the terminate process failure."""
        return f"terminate_process:{','.join(self.service_names)}:{self.process_name}:{self.signal}"
299
300


301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
@dataclass
class TokenOverflowFailure(Failure):
    """
    Failure type for injecting token overflow (prompt > max_seq_len).

    Nothing is injected server-side; the oversized prompts come from the
    client. ``service_names`` is always ``["Client"]``.
    """

    overflow_multiplier: float = 2.0  # How much to exceed max_seq_len (e.g., 2.0 = 2x)
    max_seq_len: int = 1024

    # An explicit __init__ in the class body is kept by @dataclass (it is not
    # overwritten by the generated one).
    def __init__(
        self,
        time: int,
        max_seq_len: int = 1024,
        overflow_multiplier: float = 2.0,
    ):
        """Initialize TokenOverflowFailure.

        Args:
            time: Time to wait in seconds before the failure is injected
            max_seq_len: Sequence-length limit the requests should exceed
            overflow_multiplier: Factor applied to ``max_seq_len`` to obtain
                the overflowing prompt length
        """
        super().__init__(
            time=time,
            service_names=["Client"],
        )
        self.max_seq_len = max_seq_len
        self.overflow_multiplier = overflow_multiplier
        # Prompt length (tokens) used for the overflowing requests.
        self.overflow_token_count = int(max_seq_len * overflow_multiplier)

    async def execute(
        self, deployment: ManagedDeployment, logger: logging.Logger
    ) -> list[str]:
        """Token overflow is handled client-side, so this is a no-op."""
        # The actual overflow is handled by the client configuration
        # which uses the input_token_length from the Load config
        # This is just a placeholder for the abstract method
        return []

    def get_failure_key(self) -> str:
        """Get the failure key for the token overflow failure."""
        return f"token_overflow:{self.overflow_token_count}"

337

338
339
340
341
342
343
@dataclass
class Scenario:
    """A fault-tolerance test case: deployment, client load and failures to inject."""

    deployment: DeploymentSpec
    load: Load
    failures: list[Failure]
    model: Optional[str] = None
    backend: str = "vllm"  # Backend type for tracking

    # When set to True, the test will be automatically marked with @pytest.mark.custom_build
    # and excluded from default test runs unless --include-custom-build flag is used
    requires_custom_build: bool = False  # Flag for tests needing custom builds/setup

    # List of checkers to run for validation (scenario + results checkers)
    # If None, factory will determine checkers based on scenario name and deployment
    checkers: Optional[List["BaseChecker"]] = field(default=None)
351
352


353
# Helper functions to create deployment specs
def _create_deployment_info(backend: str, yaml_path: str) -> DeploymentInfo:
    """Create a deployment spec with backend information.

    Args:
        backend: Backend type ("vllm", "sglang", or "trtllm")
        yaml_path: Path to the deployment YAML file

    Returns:
        DeploymentInfo dictionary with ``spec`` (a DeploymentSpec built from
        ``yaml_path``) and ``backend``
    """
    return DeploymentInfo(spec=DeploymentSpec(yaml_path), backend=backend)
365
366


367
368
369
370
371
372
373
374
def _set_replicas(
    deployment_spec: DeploymentInfo, backend: str, deploy_type: str, replicas: int
) -> None:
    """Set replicas for all components in a deployment based on backend type.

    The Frontend and the decode worker are always scaled; the prefill worker
    is scaled only for ``"disagg"`` deployments. For a backend not present in
    ``WORKER_MAP`` only the Frontend is scaled.

    Args:
        deployment_spec: Deployment whose ``"spec"`` entry is modified in place.
        backend: Backend type (a key of ``WORKER_MAP``).
        deploy_type: ``"agg"`` or ``"disagg"``.
        replicas: Replica count to apply.
    """
    spec = deployment_spec["spec"]

    # Frontend is common for all backends
    spec["Frontend"].replicas = replicas

    if backend not in WORKER_MAP:
        return
    workers = WORKER_MAP[backend]

    # For trtllm agg deployments, use different worker name
    if backend == "trtllm" and deploy_type == "agg":
        decode_worker = workers["decode_agg"]
    else:
        decode_worker = workers["decode"]

    # always scale decode
    spec[decode_worker].replicas = replicas

    # scale prefill only for disagg
    if deploy_type == "disagg":
        spec[workers["prefill"]].replicas = replicas


388
389
390
def _set_tensor_parallel(
    deployment_spec: DeploymentInfo, backend: str, deploy_type: str, tp_size: int
):
    """Set tensor parallel size for worker components.

    For ``"agg"`` only the decode worker is updated; for ``"disagg"`` both the
    prefill and decode workers are. Backends not in ``WORKER_MAP`` and any
    other ``deploy_type`` are left untouched.

    Args:
        deployment_spec: Deployment whose ``"spec"`` entry is modified in place.
        backend: Backend type (a key of ``WORKER_MAP``).
        deploy_type: ``"agg"`` or ``"disagg"``.
        tp_size: Tensor-parallel size to apply.
    """
    spec = deployment_spec["spec"]

    if backend not in WORKER_MAP:
        return
    workers = WORKER_MAP[backend]

    # For trtllm agg deployments, use different worker name
    if backend == "trtllm" and deploy_type == "agg":
        decode_worker = workers["decode_agg"]
    else:
        decode_worker = workers["decode"]
    prefill_worker = workers["prefill"]

    if deploy_type == "agg":
        # Prefer the spec-level helper when the spec object provides one.
        # NOTE(review): the disagg branch below always sets the attribute
        # directly; confirm whether it should use set_tensor_parallel too.
        if hasattr(spec, "set_tensor_parallel"):
            spec.set_tensor_parallel(tp_size, [decode_worker])
        else:
            spec[decode_worker].tensor_parallel_size = tp_size
    elif deploy_type == "disagg":
        spec[prefill_worker].tensor_parallel_size = tp_size
        spec[decode_worker].tensor_parallel_size = tp_size


412
413
414
415
416
417
418
419
420
421
def _create_deployments_for_backend(backend: str) -> Dict[str, DeploymentInfo]:
    """Create all deployment specifications for a given backend.

    For each deploy type (``agg``/``disagg``) and each (TP, DP) configuration,
    a fresh DeploymentSpec is created from the backend's example YAML and
    patched with the tensor-parallel size and replica count. Names look like
    ``vllm-agg-tp-2-dp-1`` or ``vllm-disagg-prefill-tp-2-decode-tp-2-dp-1``.

    Args:
        backend: Backend type ("vllm", "sglang", or "trtllm")

    Returns:
        Dictionary mapping deployment names to DeploymentInfo objects
    """
    deployments: Dict[str, DeploymentInfo] = {}

    # Define the yaml files for agg and disagg deployments
    yaml_files = {
        "agg": f"examples/backends/{backend}/deploy/agg.yaml",
        "disagg": f"examples/backends/{backend}/deploy/disagg.yaml",
    }

    # Define the different configurations to test
    configurations = [
        {"tp": 1, "dp": 1},
        {"tp": 1, "dp": 2},
        {"tp": 2, "dp": 1},
        {"tp": 4, "dp": 1},
    ]

    for deploy_type in ["agg", "disagg"]:
        for config in configurations:
            tp_size = config["tp"]
            dp_replicas = config["dp"]
            # Skip creating disagg scenarios for TP > 1 if DP is also > 1 (uncommon case)
            if deploy_type == "disagg" and tp_size > 1 and dp_replicas > 1:
                continue

            # Construct the scenario name
            if deploy_type == "agg":
                tp_part = f"tp-{tp_size}"
            else:
                tp_part = f"prefill-tp-{tp_size}-decode-tp-{tp_size}"
            scenario_name = f"{backend}-{deploy_type}-{tp_part}-dp-{dp_replicas}"

            # Create and configure the deployment
            deployment = _create_deployment_info(backend, yaml_files[deploy_type])
            if tp_size > 1:
                _set_tensor_parallel(deployment, backend, deploy_type, tp_size)
            if dp_replicas > 1:
                _set_replicas(deployment, backend, deploy_type, dp_replicas)

            deployments[scenario_name] = deployment

    return deployments


469
470
471
472
473
474
475
476
477
478
479
480
def _create_moe_deployments_for_backend(
    backend: str = "vllm",
) -> Dict[str, DeploymentInfo]:
    """Create MoE-specific deployment configurations for DeepSeek-V2-Lite.

    One ``agg`` and one ``disagg`` deployment are created from the MoE
    templates under ``tests/fault_tolerance/deploy/templates/<backend>/``.

    Args:
        backend: Backend type (default: "vllm")

    Returns:
        Dictionary mapping deployment names to DeploymentInfo objects
    """
    deployments: Dict[str, DeploymentInfo] = {}

    # Only test tp=1, dp=2 for now. These values only appear in the scenario
    # name; the templates are used as-is.
    # Note: data parallelism is handled internally by vLLM with --data-parallel-size.
    tp_size = 1
    dp_replicas = 2

    template_dir = "tests/fault_tolerance/deploy/templates"
    yaml_files = {
        "agg": f"{template_dir}/{backend}/moe_agg.yaml",
        "disagg": f"{template_dir}/{backend}/moe_disagg.yaml",
    }

    for deploy_type in ["agg", "disagg"]:
        scenario_name = f"{backend}-moe-{deploy_type}-tp-{tp_size}-dp-{dp_replicas}"
        deployments[scenario_name] = DeploymentInfo(
            spec=DeploymentSpec(yaml_files[deploy_type]),
            backend=backend,
            model="deepseek-ai/DeepSeek-V2-Lite",
            is_moe=True,
        )

    return deployments


508
# Create all deployment specifications: the standard agg/disagg matrix for
# every backend (later entries win on a name clash, as with dict.update).
DEPLOYMENT_SPECS: Dict[str, DeploymentInfo] = {
    deployment_name: info
    for backend_name in ("vllm", "sglang", "trtllm")
    for deployment_name, info in _create_deployments_for_backend(backend_name).items()
}

# Add MoE deployments for vLLM only
DEPLOYMENT_SPECS.update(_create_moe_deployments_for_backend("vllm"))
516

517
518
519
520
521
522
523
524

# Each failure scenario contains a list of failure injections.
# Each failure injection has a time in seconds after the previous injection and
# the names of the DGD services to inject it into.
# Failures are currently process termination or pod deletion.
#
# Example:
#
#   "prefill_worker": [
#       TerminateProcessFailure(
#           30, ["VllmPrefillWorker"], "SIGKILL", process_name="dynamo.vllm"
#       )
#   ],
#
# sends SIGKILL to the dynamo.vllm process(es) in every prefill worker pod
# after 30 seconds.
def _create_backend_failures(
    backend: str, deploy_type: str = "disagg"
) -> Dict[str, List[Failure]]:
    """Generate backend-specific failure scenarios.

    Args:
        backend: Backend type (vllm, sglang, trtllm); must be a key of WORKER_MAP
        deploy_type: Deployment type (agg or disagg)

    Returns:
        Mapping of failure name -> list of failures to inject. Prefill failures
        are always included; callers skip them for aggregated deployments.

    Raises:
        KeyError: If ``backend`` is not in ``WORKER_MAP``.
    """
    workers = WORKER_MAP[backend]

    # Use correct worker name based on deployment type
    if backend == "trtllm" and deploy_type == "agg":
        decode_worker = workers["decode_agg"]
    else:
        decode_worker = workers["decode"]

    prefill_worker = workers["prefill"]
    process_name = f"dynamo.{backend}"

    failures: Dict[str, List[Failure]] = {
        "frontend": [
            TerminateProcessFailure(
                30, ["Frontend"], "SIGINT", process_name="dynamo.frontend"
            )
        ],
        "frontend_pod": [DeletePodFailure(30, ["Frontend"])],
        "decode_worker": [
            TerminateProcessFailure(
                30, [decode_worker], "SIGKILL", process_name=process_name
            )
        ],
        "decode_worker_pod": [DeletePodFailure(30, [decode_worker])],
        "prefill_worker": [
            TerminateProcessFailure(
                30, [prefill_worker], "SIGKILL", process_name=process_name
            )
        ],
        "prefill_worker_pod": [DeletePodFailure(30, [prefill_worker])],
        "none": [],
    }

    # Backend-specific processes inside the worker pods, killed with SIGKILL:
    # (failure name, target service, process name). Order is preserved so the
    # resulting scenario ordering is stable.
    backend_process_failures = {
        "vllm": [
            ("vllm_decode_engine_core", decode_worker, "VLLM::EngineCore"),
            ("vllm_prefill_engine_core", prefill_worker, "VLLM::EngineCore"),
        ],
        "sglang": [
            ("sglang_decode_scheduler", decode_worker, "sglang::scheduler"),
            ("sglang_decode_detokenizer", decode_worker, "sglang::detokenizer"),
            ("sglang_prefill_scheduler", prefill_worker, "sglang::scheduler"),
            ("sglang_prefill_detokenizer", prefill_worker, "sglang::detokenizer"),
        ],
        "trtllm": [
            ("trtllm_decode_engine_core", decode_worker, "TRTLLM::EngineCore"),
            ("trtllm_prefill_engine_core", prefill_worker, "TRTLLM::EngineCore"),
        ],
    }
    for failure_name, service, target_process in backend_process_failures.get(
        backend, []
    ):
        failures[failure_name] = [
            TerminateProcessFailure(
                30, [service], "SIGKILL", process_name=target_process
            )
        ]

    return failures
613
614


615
616
617
618
619
620
621
622
def create_aiperf_load(
    clients: int = 10,
    requests_per_client: int = 150,
    input_token_length: int = 100,
    output_token_length: int = 100,
    max_retries: int = 3,
    sla: Optional[float] = None,
    max_request_rate: float = 1.0,
    success_threshold: float = 90.0,
) -> Load:
    """Create a Load configuration for AI-Perf client.

    Args:
        clients: Number of concurrent clients (default: 10)
        requests_per_client: Number of requests per client (default: 150)
        input_token_length: Input token count (default: 100)
        output_token_length: Output token count (default: 100)
        max_retries: Maximum retry attempts - AI-Perf retries entire test (default: 3)
        sla: Optional SLA threshold for latency (default: None)
        max_request_rate: Rate limiting for requests/sec (default: 1.0)
        success_threshold: Success rate threshold for pass/fail (default: 90.0)

    Returns:
        Load instance configured for AI-Perf client (``client_type="aiperf"``)

    Example:
        >>> load = create_aiperf_load(clients=20, requests_per_client=200)
    """
    return Load(
        clients=clients,
        requests_per_client=requests_per_client,
        input_token_length=input_token_length,
        output_token_length=output_token_length,
        max_retries=max_retries,
        sla=sla,
        client_type="aiperf",
        max_request_rate=max_request_rate,
        success_threshold=success_threshold,
    )


def create_legacy_load(
    clients: int = 10,
    requests_per_client: int = 100,
    input_token_length: int = 100,
    output_token_length: int = 100,
    max_retries: int = 1,
    sla: Optional[float] = None,
    max_request_rate: float = 1.0,
    success_threshold: float = 90.0,
) -> Load:
    """Create a Load configuration for legacy custom client.

    Args:
        clients: Number of concurrent clients (default: 10)
        requests_per_client: Number of requests per client (default: 100, fewer than AI-Perf)
        input_token_length: Input token count (default: 100)
        output_token_length: Output token count (default: 100)
        max_retries: Maximum retry attempts - legacy retries per request (default: 1)
        sla: Optional SLA threshold for latency (default: None)
        max_request_rate: Rate limiting for requests/sec (default: 1.0)
        success_threshold: Success rate threshold for pass/fail (default: 90.0)

    Returns:
        Load instance configured for legacy client (``client_type="legacy"``)

    Example:
        >>> load = create_legacy_load(clients=10, max_request_rate=2.0)
    """
    return Load(
        clients=clients,
        requests_per_client=requests_per_client,
        input_token_length=input_token_length,
        output_token_length=output_token_length,
        max_retries=max_retries,
        sla=sla,
        client_type="legacy",
        max_request_rate=max_request_rate,
        success_threshold=success_threshold,
    )


# Default load configuration (all Load() defaults, i.e. the AI-Perf client)
load = Load()

# MoE-specific load configuration: fewer clients, fewer requests per client
# and a lower request rate than the default, to account for MoE complexity.
moe_load = Load(
    clients=3,
    requests_per_client=30,
    max_request_rate=0.5,
    # The remaining values are spelled out explicitly for this profile.
    input_token_length=100,
    output_token_length=100,
    max_retries=3,
    sla=None,
    client_type="aiperf",
)

712
713
714
715
716
717
# Global fallback model, used when a deployment entry has no "model" key.
# To pin one for all such scenarios, set e.g.
# "deepseek-ai/DeepSeek-R1-Distill-Llama-8B".
model = None

# Populate Scenarios: scenario name -> Scenario (filled in below).
scenarios: dict[str, Scenario] = {}
719

720
721
722
723
724
725
726
# Map of "<backend>_<deploy_type>" (e.g. "vllm_agg") to failure definitions
backend_failure_map = {}
for backend in ["vllm", "sglang", "trtllm"]:
    for deploy_type in ("agg", "disagg"):
        backend_failure_map[f"{backend}_{deploy_type}"] = _create_backend_failures(
            backend, deploy_type
        )
727

728
for deployment_name, deployment_info in DEPLOYMENT_SPECS.items():
    backend = deployment_info["backend"]

    # Check if this is an MoE deployment
    is_moe = deployment_info.get("is_moe", False)

    # Determine deployment type from deployment name ("disagg" contains "agg",
    # so it has to be excluded explicitly)
    deploy_type = (
        "agg"
        if ("agg" in deployment_name and "disagg" not in deployment_name)
        else "disagg"
    )

    # Get the appropriate failure set for this backend+deploy_type
    failure_map_key = f"{backend}_{deploy_type}"
    if failure_map_key not in backend_failure_map:
        raise ValueError(
            f"Unsupported backend+deploy_type: {failure_map_key}. Available: {list(backend_failure_map.keys())}"
        )
    failure_set = backend_failure_map[failure_map_key]

    # Use MoE-specific load configuration if it's an MoE model.
    # NOTE: the same Load instance (and the same failure lists) are shared by
    # every scenario built from it; copy before mutating per scenario.
    load_config = moe_load if is_moe else load

    # Get model from deployment info or use the global model
    scenario_model = deployment_info.get("model", model)

    for failure_name, failure in failure_set.items():
        # Skip prefill failures for aggregated deployments
        if "prefill" in failure_name and deploy_type == "agg":
            continue

        scenario_name = f"{deployment_name}-{failure_name}"

        # Create scenario first (without checkers)
        scenario = Scenario(
            deployment=deployment_info["spec"],
            load=load_config,
            failures=failure,
            model=scenario_model,
            backend=backend,
            checkers=None,  # Will be populated below
            requires_custom_build=is_moe,  # MoE models require custom builds
        )

        # Generate checkers for this scenario
        # This uses the checker factory to determine appropriate validation checks
        scenario.checkers = _get_checkers_for_scenario(scenario_name, scenario)

        scenarios[scenario_name] = scenario

780
781
782
783
784
785
786
787
788
789
790
791
792
793
794
795
796
797
798
799
800
801
802
803
804
805
806
807
808
809
810
811
812
813
814
815
816
817
818
819
820
821
822
823
824
825
826
827
828
829
830
831
832
833
834
835
836
837
838
839
840
841
842
843
844
845
846
847
848
849
850
851
852
853
854
855
856
857
858
859
860
861
862
863
864
865
866
867
868
869
870
871
872
873
874
875
876
877
878
879
880
881
882
883
884
885
886
887
888
889
890
891
892
893
894
895
896
897
898
899
900
901
902
903
904
905
906
907
908
909
910
911
912
913

# Add token overflow test scenarios
def add_token_overflow_scenarios():
    """
    Add test scenarios for token overflow (prompt > max_seq_len) failures.

    For each backend/deployment pair below whose ``deployment_key`` exists in
    ``DEPLOYMENT_SPECS``, this:

    * deep-copies the deployment spec and adds the backend-specific
      sequence-length flag (set to ``MAX_SEQ_LEN``) to the worker service(s)
      of that copy,
    * builds a ``TokenOverflowFailure`` and a mixed ``Load`` configured with
      ``OVERFLOW_REQUESTS`` oversized and ``NORMAL_REQUESTS`` normal-sized
      requests, and
    * registers the resulting ``Scenario`` in the module-level ``scenarios``
      dict under the config's ``name``.

    The spec is copied rather than edited in place because ``DEPLOYMENT_SPECS``
    entries are shared: the per-failure scenarios pass ``deployment_info["spec"]``
    straight into ``Scenario``. Adding arguments to the shared object would
    leak the sequence-length cap into those deployments as well.
    """
    import copy

    overflow_test_configs = [
        # vLLM tests
        {
            "name": "vllm_agg_token_overflow_2x",
            "deployment_key": "vllm-agg-tp-1-dp-1",
            "backend": "vllm",
        },
        {
            "name": "vllm_disagg_token_overflow_2x",
            "deployment_key": "vllm-disagg-prefill-tp-2-decode-tp-2-dp-1",
            "backend": "vllm",
        },
        # TRT-LLM tests
        {
            "name": "trtllm_agg_token_overflow_2x",
            "deployment_key": "trtllm-agg-tp-1-dp-1",
            "backend": "trtllm",
        },
        {
            "name": "trtllm_disagg_token_overflow_2x",
            "deployment_key": "trtllm-disagg-prefill-tp-2-decode-tp-2-dp-1",
            "backend": "trtllm",
        },
        # SGLang tests
        {
            "name": "sglang_agg_token_overflow_2x",
            "deployment_key": "sglang-agg-tp-1-dp-1",
            "backend": "sglang",
        },
        {
            "name": "sglang_disagg_token_overflow_2x",
            "deployment_key": "sglang-disagg-prefill-tp-2-decode-tp-2-dp-1",
            "backend": "sglang",
        },
    ]

    # Common configuration for all tests
    MAX_SEQ_LEN = 1024
    OVERFLOW_MULTIPLIER = 2.0
    OVERFLOW_REQUESTS = 15  # Number of oversized requests to send
    NORMAL_REQUESTS = 15  # Number of normal requests to send after overflow

    # CLI flag each backend uses to cap the sequence length. Any backend not
    # listed falls back to the vLLM flag, matching the previous if/elif/else.
    seq_len_arg_by_backend = {
        "trtllm": "--max-seq-len",
        "sglang": "--context-length",
        "vllm": "--max-model-len",
    }

    for config in overflow_test_configs:
        deployment_key = config["deployment_key"]

        # Skip if deployment doesn't exist
        if deployment_key not in DEPLOYMENT_SPECS:
            continue

        overflow_scenario_name = config["name"]
        deployment_info = DEPLOYMENT_SPECS[deployment_key]

        scenario_model = deployment_info.get("model", model)

        # Private copy: the shared spec is also referenced by other scenarios,
        # so it must not be mutated by add_arg_to_service below.
        deployment_spec = copy.deepcopy(deployment_info["spec"])

        backend = config["backend"]
        # If not disaggregated, then it's aggregated
        is_agg = "disagg" not in deployment_key

        workers = WORKER_MAP[backend]

        # TRT-LLM uses a distinct worker name for its aggregated decode worker
        if backend == "trtllm" and is_agg:
            decode_worker = workers["decode_agg"]
        else:
            decode_worker = workers["decode"]

        arg_name = seq_len_arg_by_backend.get(backend, "--max-model-len")

        # Aggregated: only the decode worker is configured.
        # Disaggregated: both prefill and decode workers (prefill first).
        if is_agg:
            target_workers = [decode_worker]
        else:
            target_workers = [workers["prefill"], decode_worker]

        for worker in target_workers:
            deployment_spec.add_arg_to_service(worker, arg_name, str(MAX_SEQ_LEN))

        # Create overflow failure
        overflow_failure = TokenOverflowFailure(
            time=30,  # Start after 30 seconds
            max_seq_len=MAX_SEQ_LEN,
            overflow_multiplier=OVERFLOW_MULTIPLIER,
        )

        # Create mixed load configuration for overflow + recovery testing
        overflow_tokens = int(MAX_SEQ_LEN * OVERFLOW_MULTIPLIER)
        normal_tokens = 512  # Well within MAX_SEQ_LEN

        # Total requests = overflow + normal
        total_requests = OVERFLOW_REQUESTS + NORMAL_REQUESTS

        # Mixed load that tests both rejection and recovery
        mixed_load = Load(
            clients=3,
            requests_per_client=total_requests,
            input_token_length=normal_tokens,
            output_token_length=50,
            # Mixed token test configuration
            mixed_token_test=True,
            overflow_token_length=overflow_tokens,
            overflow_request_count=OVERFLOW_REQUESTS,
            normal_request_count=NORMAL_REQUESTS,
        )

        scenarios[overflow_scenario_name] = Scenario(
            deployment=deployment_spec,
            load=mixed_load,
            failures=[overflow_failure],
            model=scenario_model,
            backend=backend,
        )


914
915
916
917
918
919
920
921
922
923
924
925
926
927
928
929
930
931
932
933
934
935
936
937
938
939
940
941
942
943
944
945
946
947
948
949
950
951
952
953
954
955
956
957
958
959
960
961
962
963
964
def add_rolling_upgrade_scenarios():
    """Register a rolling-upgrade scenario per backend and worker mode.

    For every backend ("vllm", "sglang", "trtllm") and worker mode
    ("agg", "disagg"):

    * build a fresh deployment from
      ``examples/backends/<backend>/deploy/<worker_mode>.yaml``,
    * scale the decode worker (plus the prefill worker when disaggregated)
      to 2 replicas,
    * attach a continuous aiperf load and a ``RollingUpgradeFailure`` over
      those services, and
    * store the scenario in ``scenarios`` as
      ``"<backend>-<worker_mode>-rolling-upgrade"``.
    """
    for backend in ("vllm", "sglang", "trtllm"):
        for worker_mode in ("agg", "disagg"):
            yaml_path = f"examples/backends/{backend}/deploy/{worker_mode}.yaml"
            deployment_info = _create_deployment_info(backend, yaml_path)
            spec: DeploymentSpec = deployment_info["spec"]

            backend_workers = WORKER_MAP[backend]
            # TRT-LLM's aggregated worker has its own name in WORKER_MAP.
            decode_key = (
                "decode_agg"
                if worker_mode == "agg" and backend == "trtllm"
                else "decode"
            )
            upgraded_services: list[str] = [backend_workers[decode_key]]
            if worker_mode == "disagg":
                upgraded_services.append(backend_workers["prefill"])

            # Two replicas per service, so one replica stays available while
            # the other one is being upgraded.
            for service in upgraded_services:
                spec.set_service_replicas(service, 2)

            upgrade_load = Load(
                clients=10,
                input_token_length=100,
                output_token_length=100,
                max_retries=1,
                client_type="aiperf",
                max_request_rate=1.0,
                success_threshold=100.0,
                continuous_load=True,
            )
            upgrade_failure = RollingUpgradeFailure(
                time=30,
                service_names=upgraded_services,
            )

            scenarios[f"{backend}-{worker_mode}-rolling-upgrade"] = Scenario(
                deployment=spec,
                load=upgrade_load,
                failures=[upgrade_failure],
                model="Qwen/Qwen3-0.6B",
                backend=backend,
            )


965
966
# Populate the module-level scenario registry at import time:
# token-overflow scenarios first, then rolling-upgrade scenarios.
add_token_overflow_scenarios()
add_rolling_upgrade_scenarios()