test_deployment.py 18.8 KB
Newer Older
1
2
3
4
5
# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0

import logging
import multiprocessing
6
import re
7
8
import time
from contextlib import contextmanager
9
from typing import Any
10
11
12

import pytest

13
from tests.fault_tolerance.deploy.base_checker import ValidationContext
14
15
from tests.fault_tolerance.deploy.client_factory import get_client_function
from tests.fault_tolerance.deploy.parse_factory import parse_test_results
16
17
18
19
20
21
22
23
from tests.fault_tolerance.deploy.parse_results import process_overflow_recovery_test
from tests.fault_tolerance.deploy.scenarios import (
    OVERFLOW_SUFFIX,
    RECOVERY_SUFFIX,
    Load,
    TokenOverflowFailure,
    scenarios,
)
24
25
26
from tests.utils.managed_deployment import ManagedDeployment


27
28
@pytest.fixture
def scenario(scenario_name, client_type):
    """Return the scenario registered under ``scenario_name``.

    When ``client_type`` is None the registered scenario object is returned
    unchanged. Otherwise a deep copy is returned whose load config uses the
    requested client type, so the registered scenario is never mutated by the
    override.

    Retry settings on the copy are adjusted to fit the client:
      * "legacy" (per-request retries): ``max_retries`` is capped at 1.
      * "aiperf" (full test retries): ``max_retries`` is raised to at least 3.
    """
    registered = scenarios[scenario_name]

    # No override requested: hand back the registered scenario as-is.
    if client_type is None:
        return registered

    import copy

    overridden = copy.deepcopy(registered)
    load = overridden.load
    load.client_type = client_type

    if client_type == "legacy" and load.max_retries > 1:
        load.max_retries = 1
    elif client_type == "aiperf" and load.max_retries < 3:
        load.max_retries = 3

    return overridden
54
55
56
57
58
59
60
61
62


@contextmanager
def _clients(
    logger,
    request,
    deployment_spec,
    namespace,
    model,
    load_config: Load,
):
    """Start client processes using factory pattern for client selection.

    The client function is looked up from ``load_config.client_type`` and one
    "spawn" process is started per configured client.

    If ``load_config.mixed_token_test`` is set, two phases run before control is
    handed to the ``with`` body: an overflow phase (joined to completion) and
    then a recovery phase (left running). Otherwise a single phase is started.

    On exit every started process is joined. This also happens when the
    ``with`` body, or the start-up itself, raises, so no client process is left
    running unobserved after a failure.

    Args:
        logger: Logger instance
        request: Pytest request fixture
        deployment_spec: Deployment specification
        namespace: Kubernetes namespace
        model: Model name to test
        load_config: Load configuration object containing client settings

    Yields:
        List of all started client processes (both phases for mixed tests).
    """
    # Get appropriate client function based on configuration
    client_func = get_client_function(load_config.client_type)

    logger.info(
        f"Starting {load_config.clients} clients using '{load_config.client_type}' client"
    )

    procs = []
    ctx = multiprocessing.get_context("spawn")

    # The last positional client argument means different things per client:
    # the legacy client takes a request rate limit, any other client takes a
    # retry delay (fixed at 5 here).
    if load_config.client_type == "legacy":
        retry_delay_or_rate = load_config.max_request_rate
    else:
        retry_delay_or_rate = 5

    def launch(log_name, request_count, input_token_length, label):
        """Start one process per configured client and record it in ``procs``."""
        for i in range(load_config.clients):
            proc = ctx.Process(
                target=client_func,
                args=(
                    deployment_spec,
                    namespace,
                    model,
                    log_name,
                    i,
                    request_count,
                    input_token_length,
                    load_config.output_token_length,
                    load_config.max_retries,
                    retry_delay_or_rate,
                ),
            )
            proc.start()
            # Record only after a successful start: joining a process that was
            # never started raises.
            procs.append(proc)
            logger.debug(f"Started {label} {i} (PID: {proc.pid})")

    try:
        # Check if this is a mixed token test (overflow + recovery)
        if getattr(load_config, "mixed_token_test", False):
            logger.info(
                f"Mixed token test: {load_config.overflow_request_count} overflow requests "
                f"({load_config.overflow_token_length} tokens) + "
                f"{load_config.normal_request_count} normal requests "
                f"({load_config.input_token_length} tokens)"
            )

            # First phase: send overflow requests and wait for them to finish
            launch(
                request.node.name + OVERFLOW_SUFFIX,
                load_config.overflow_request_count,
                load_config.overflow_token_length,
                "overflow client",
            )
            for proc in procs:
                proc.join()

            logger.info("Overflow requests completed. Starting recovery phase...")

            # Second phase: send normal requests to test recovery
            launch(
                request.node.name + RECOVERY_SUFFIX,
                load_config.normal_request_count,
                load_config.input_token_length,
                "recovery client",
            )
        else:
            # Normal test - single phase
            launch(
                request.node.name,
                load_config.requests_per_client,
                load_config.input_token_length,
                "client",
            )

        yield procs
    finally:
        # Joining an already-finished process (e.g. overflow phase) is a no-op.
        for proc in procs:
            logger.debug(f"{proc} waiting for join")
            proc.join()
            logger.debug(f"{proc} joined")


def _inject_failures(failures, logger, deployment: ManagedDeployment):  # noqa: F811
    """Apply each configured failure in order and record the pods it hit.

    Before handling a failure the function sleeps for ``failure.time``.
    ``TokenOverflowFailure`` entries are skipped after the sleep: no pod is
    touched for them here.

    For every other failure the pods matching ``failure.pod_name`` are looked
    up on the deployment. ``failure.replicas`` pods are targeted (all of them
    when ``replicas`` is falsy); indices wrap around when more replicas are
    requested than pods exist. A ``"delete_pod"`` command saves the pod logs
    and force-deletes the pod; any other command is matched as a substring
    against each process command in the pod, and matching processes are
    killed with ``failure.signal``.

    Returns:
        Dict mapping "<pod_name>:<command>" to the list of targeted pod names.
        Example: {"VllmDecodeWorker:delete_pod": ["pod-abc123", "pod-xyz789"]}
    """
    affected_pods: dict[str, list] = {}

    for failure in failures:
        time.sleep(failure.time)

        # Token overflow is driven by the client load configuration, so there
        # is nothing to inject on the cluster side.
        if isinstance(failure, TokenOverflowFailure):
            continue

        pods = deployment.get_pods(failure.pod_name)[failure.pod_name]
        pod_count = len(pods)
        if not pods:
            continue

        replica_count = failure.replicas or pod_count

        logger.info(f"Injecting failure for: {failure}")

        # Collect targeted pod names under a per-(pod type, command) key.
        targeted = affected_pods.setdefault(
            f"{failure.pod_name}:{failure.command}", []
        )

        for replica_index in range(replica_count):
            pod = pods[replica_index % pod_count]

            # Capture the name up front; the pod may be gone a moment later.
            pod_name = pod.name
            targeted.append(pod_name)
            logger.info(f"Target pod for failure: {pod_name}")

            if failure.command == "delete_pod":
                deployment.get_pod_logs(failure.pod_name, pod, ".before_delete")
                logger.info(f"Deleting pod: {pod_name}")
                pod.delete(force=True)
                continue

            for process in deployment.get_processes(pod):
                if failure.command in process.command:
                    logger.info(
                        f"Terminating {failure.pod_name} Pid {process.pid} Command {process.command} in pod {pod_name}"
                    )
                    process.kill(failure.signal)

    return affected_pods

247
248

# Log directory names recorded after each test; read by the session summary.
global_result_list: list[str] = []

# Parsed results keyed by test name, stored by the validation fixture.
test_results_cache: dict[str, Any] = {}
251
252
253


def _run_validation_checkers(results, scenario, test_name, context, logger):
    """Run the scenario's checkers against parsed results.

    Args:
        results: Parsed results; a non-empty list (first entry is used) or a dict.
        scenario: Scenario whose ``checkers`` are executed.
        test_name: Test name, passed to the checkers as the log directory.
        context: Shared dict populated by the test (deployment, namespace,
            affected_pods).
        logger: Logger for progress and failure messages.

    Raises:
        AssertionError: When a checker fails. Any other exception is logged
            and validation is skipped.
    """
    try:
        logger.info("\n" + "=" * 60)
        logger.info("Running validation checks...")
        logger.info("=" * 60)

        # Extract the single result record to validate
        if isinstance(results, list) and len(results) > 0:
            result = results[0]
        elif isinstance(results, dict):
            result = results
        else:
            logger.warning(f"Unexpected result format: {type(results)}")
            result = None

        if not result:
            return

        # Create ValidationContext for all checkers
        validation_ctx = ValidationContext(
            scenario=scenario,
            log_dir=test_name,
            metrics=result.get("metrics", {}),
            deployment=context.get("deployment"),
            namespace=context.get("namespace"),
            recovery_time=result.get("recovery_time"),
            affected_pods=context.get("affected_pods", {}),
        )

        # Checkers were already determined during scenario creation
        for checker in scenario.checkers or []:
            logger.info(f"\nRunning checker: {checker.name}")
            checker.check(validation_ctx)

        logger.info("=" * 60)
        logger.info("✓ All validation checks passed")
        logger.info("=" * 60 + "\n")

    except AssertionError as e:
        logger.error("=" * 60)
        logger.error(f"✗ Validation failed: {e}")
        logger.error("=" * 60 + "\n")
        # Re-raise to fail the test
        raise
    except Exception as e:
        logger.error(f"Validation error: {e}")
        # Don't fail test on validation errors (non-assertion exceptions)
        logger.warning("Skipping validation due to error")


@pytest.fixture(autouse=True)
def validation_context(request, scenario):  # noqa: F811
    """Provides shared context between test execution and validation.

    This fixture creates a shared dictionary that the test populates during
    execution (deployment, namespace, affected_pods), then uses that data
    in teardown to parse results and run checkers.

    Teardown steps:
      1. Work out the log directories (two for mixed token tests, else one)
         and register them for the session summary.
      2. Parse the results; a parsing failure is logged and ends teardown.
      3. Cache the parsed results and run the scenario's checkers.

    Validation runs outside the parsing ``try`` block on purpose: an
    ``AssertionError`` from a checker must propagate so the test is reported
    as failed, instead of being swallowed by the parse error handler.
    """
    # Shared context that test will populate during execution
    context: dict[str, Any] = {
        "deployment": None,
        "namespace": None,
        "affected_pods": {},
    }

    yield context  # Test receives this and populates it

    test_name = request.node.name
    logger = logging.getLogger(test_name)

    # Determine log paths based on whether this is a mixed token test
    if getattr(scenario.load, "mixed_token_test", False):
        # For mixed token tests, we have separate overflow and recovery directories
        overflow_dir = f"{request.node.name}{OVERFLOW_SUFFIX}"
        recovery_dir = f"{request.node.name}{RECOVERY_SUFFIX}"
        log_paths = [overflow_dir, recovery_dir]

        logging.info("Mixed token test detected. Looking for results in:")
        logging.info(f"  - Overflow phase: {overflow_dir}")
        logging.info(f"  - Recovery phase: {recovery_dir}")
    else:
        # Standard test with single directory
        log_paths = [request.node.name]

    # Register before validating so the session summary still covers this test
    # when a checker fails below.
    global_result_list.extend(log_paths)

    # Use factory to auto-detect and parse results
    try:
        results = parse_test_results(
            log_dir=None,
            log_paths=log_paths,
            tablefmt="fancy_grid",
            sla=scenario.load.sla,
            success_threshold=scenario.load.success_threshold,
            print_output=True,
            # force_parser can be set based on client_type if needed
            # force_parser=scenario.load.client_type,
        )
    except Exception:
        logging.exception("Failed to parse results for %s", test_name)
        return

    if not results:
        return

    # Store results for reference
    logging.info(f"Results parsed: {type(results)}")
    test_results_cache[test_name] = results

    # Run validation now that we have results; AssertionError propagates.
    _run_validation_checkers(results, scenario, test_name, context, logger)
367
368
369
370


@pytest.fixture(autouse=True, scope="session")
def results_summary():
    """Session-level summary printed once all tests have run.

    After the session, every recorded log directory is parsed without
    printing. Directories are then paired by test name using the overflow and
    recovery suffixes, and a combined summary is produced only for tests that
    have both an overflow and a recovery directory. Any exception raised while
    parsing or summarising is logged and not propagated.
    """
    yield

    if not global_result_list:
        return

    # Step 1: group overflow/recovery directories by their base test name
    phase_by_suffix = (
        (OVERFLOW_SUFFIX, "overflow"),
        (RECOVERY_SUFFIX, "recovery"),
    )
    test_groups: dict[str, dict[str, str]] = {}
    for log_path in global_result_list:
        for suffix, phase in phase_by_suffix:
            if log_path.endswith(suffix):
                base = log_path[: -len(suffix)]
                test_groups.setdefault(base, {})[phase] = log_path
                break

    # Step 2: process all tests (get results) but only print paired ones
    try:
        # Silent parse of everything recorded during the session
        parse_test_results(
            log_dir=None,
            log_paths=global_result_list,
            tablefmt="fancy_grid",
            print_output=False,  # Don't print anything
        )

        for base_name, paths in test_groups.items():
            # Only tests with both phases get a combined summary
            if "overflow" not in paths or "recovery" not in paths:
                continue

            # The scenario name is the bracketed part of the test name
            scenario_obj = None
            match = re.search(r"\[(.*)\]", base_name)
            if match:
                scenario_name = match.group(1)
                if scenario_name in scenarios:
                    scenario_obj = scenarios[scenario_name]
                    logging.info(
                        f"Found scenario '{scenario_name}' for combined results."
                    )

            if not scenario_obj:
                logging.warning(
                    f"Could not find scenario for '{base_name}'. Using default thresholds."
                )

            if scenario_obj:
                success_threshold = scenario_obj.load.success_threshold
            else:
                success_threshold = 90.0
            logging.info(
                f"Using success_threshold: {success_threshold} for combined summary of '{base_name}'"
            )

            sla = scenario_obj.load.sla if scenario_obj else None

            # This function will print the combined summary
            process_overflow_recovery_test(
                overflow_path=paths["overflow"],
                recovery_path=paths["recovery"],
                tablefmt="fancy_grid",
                sla=sla,
                success_threshold=success_threshold,
            )

    except Exception as e:
        logging.error(f"Failed to parse combined results: {e}")
440
441


442
443
def _resolve_model(scenario, original_name):
    """Pick the model name the load clients should request.

    ``scenario.model`` wins when set. Otherwise the model is read from the
    backend's worker in the deployment spec, and "Qwen/Qwen3-0.6B" is used
    when nothing can be read.

    Args:
        scenario: Scenario providing ``model``, ``backend`` and ``deployment``.
        original_name: Deployment name as it was before the test renamed it;
            used only as a hint for which TRT-LLM worker to try first.
    """
    if scenario.model:
        return scenario.model

    if scenario.backend == "vllm":
        worker_names = ("VllmDecodeWorker",)
    elif scenario.backend == "sglang":
        worker_names = ("decode",)
    elif scenario.backend == "trtllm":
        # The name is only a hint for ordering; both worker names are probed
        # so an aggregated deployment is still resolved when the hint is
        # missing or stale.
        if "agg" in original_name and "disagg" not in original_name:
            worker_names = ("TRTLLMWorker", "TRTLLMDecodeWorker")
        else:
            worker_names = ("TRTLLMDecodeWorker", "TRTLLMWorker")
    else:
        worker_names = ()

    for worker_name in worker_names:
        try:
            model = scenario.deployment[worker_name].model
        except (KeyError, AttributeError):
            continue
        if model:
            return model

    # Fallback to default if nothing was found
    return "Qwen/Qwen3-0.6B"


@pytest.mark.k8s
@pytest.mark.fault_tolerance
@pytest.mark.e2e
@pytest.mark.slow
@pytest.mark.filterwarnings("ignore::DeprecationWarning")
async def test_fault_scenario(
    scenario,  # noqa: F811
    request,
    image,
    namespace,
    validation_context,  # noqa: F811  # Shared context for passing data to validation
):
    """
    Test dynamo serve deployments with injected failures

    Flow:
    1. validation_context fixture creates empty dict: {"deployment": None, "namespace": None, "affected_pods": {}}
    2. This test populates it: validation_context["deployment"] = deployment, etc.
    3. After test completes, fixture reads validation_context and runs validation checkers
    4. Checkers use the populated ValidationContext to verify test results and K8s events
    """

    logger = logging.getLogger(request.node.name)

    # Read the name BEFORE it is overwritten below: the TRT-LLM worker choice
    # looks for "agg"/"disagg" in it, which the fixed test name never contains.
    original_name = getattr(scenario.deployment, "name", "") or ""
    scenario.deployment.name = "fault-tolerance-test"

    if image:
        scenario.deployment.set_image(image)

    if scenario.model:
        scenario.deployment.set_model(scenario.model)

    model = _resolve_model(scenario, original_name)

    scenario.deployment.set_logging(True, "info")

    async with ManagedDeployment(
        namespace=namespace,
        log_dir=request.node.name,
        deployment_spec=scenario.deployment,
    ) as deployment:
        # Populate shared context for validation
        validation_context["deployment"] = deployment
        validation_context["namespace"] = namespace

        with _clients(
            logger,
            request,
            scenario.deployment,
            namespace,
            model,
            scenario.load,  # Pass entire Load config object
        ):
            # Inject failures and capture which pods were affected
            affected_pods = _inject_failures(scenario.failures, logger, deployment)
            validation_context["affected_pods"] = affected_pods

            logger.info(f"Affected pods during test: {affected_pods}")