# SPDX-FileCopyrightText: Copyright (c) 2025-2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0

import asyncio
import copy
import logging
import multiprocessing
import os
import re
import signal
from collections.abc import Callable, Iterator
from contextlib import contextmanager
from multiprocessing.context import SpawnProcess
from typing import Any, Optional

import pytest

from tests.fault_tolerance.deploy.base_checker import ValidationContext
from tests.fault_tolerance.deploy.client_factory import get_client_function
from tests.fault_tolerance.deploy.parse_factory import parse_test_results
from tests.fault_tolerance.deploy.parse_results import process_overflow_recovery_test
from tests.fault_tolerance.deploy.scenarios import (
    OVERFLOW_SUFFIX,
    RECOVERY_SUFFIX,
    Failure,
    Load,
    Scenario,
    scenarios,
)
from tests.utils.managed_deployment import DeploymentSpec, ManagedDeployment
from tests.utils.test_output import resolve_test_output_path


@pytest.fixture
def scenario(scenario_name, client_type):
    """Return a private copy of the named scenario, honoring ``--client-type``.

    The registered scenario is deep-copied on every call, so changes a test
    makes to the returned object do not leak back into the shared
    ``scenarios`` registry.

    Args:
        scenario_name: Key into the ``scenarios`` registry.
        client_type: Client type given on the command line, or ``None`` to
            keep the scenario's own client type and retry settings.

    Returns:
        The copied scenario. When ``client_type`` is given,
        ``load.client_type`` is overridden and ``load.max_retries`` is
        clamped to suit that client.
    """
    scenario_obj = copy.deepcopy(scenarios[scenario_name])

    if client_type is None:
        return scenario_obj

    scenario_obj.load.client_type = client_type

    # Adjust retry settings based on client type
    if client_type == "legacy":
        # Legacy uses per-request retries: cap at one.
        scenario_obj.load.max_retries = min(scenario_obj.load.max_retries, 1)
    elif client_type == "aiperf":
        # AI-Perf uses full test retries: ensure at least three.
        scenario_obj.load.max_retries = max(scenario_obj.load.max_retries, 3)

    return scenario_obj


def _start_client_procs(
    ctx,
    client_func: Callable[..., Any],
    num_clients: int,
    make_args: Callable[[int], tuple],
    logger: logging.Logger,
    procs: list[SpawnProcess],
    label: str = "",
) -> None:
    """Spawn and start one client process per client index.

    Each process is appended to ``procs`` as soon as it has started, so the
    caller can still clean up the ones that did start if a later start fails.

    Args:
        ctx: Multiprocessing context used to create the processes.
        client_func: Target callable run in each process.
        num_clients: Number of processes to start.
        make_args: Builds the positional argument tuple for client ``i``.
        logger: Logger for the per-process debug line.
        procs: List that receives the started processes.
        label: Word inserted before "client" in the debug line
            (e.g. ``"overflow "``); empty for the single-phase case.
    """
    for i in range(num_clients):
        proc = ctx.Process(target=client_func, args=make_args(i))
        proc.start()
        procs.append(proc)
        logger.debug(f"Started {label}client {i} (PID: {proc.pid})")


@contextmanager
def _clients(
    logger: logging.Logger,
    log_dir: str,
    deployment_spec: DeploymentSpec,
    namespace: str,
    model: str,
    load_config: Load,
) -> Iterator[list[SpawnProcess]]:
    """Start client processes using factory pattern for client selection.

    For a mixed token test (``load_config.mixed_token_test`` truthy) the
    overflow clients are started and joined first, then the recovery clients
    are started. Otherwise a single set of clients is started. The processes
    are yielded to the ``with`` body and joined when the body exits. If
    startup or the body raises, processes that are still alive are terminated
    before being joined, so a failed test does not leave clients running.

    Args:
        logger: Logger instance
        log_dir: Log directory for output logs and client logs/artifacts
        deployment_spec: Deployment specification
        namespace: Kubernetes namespace
        model: Model name to test
        load_config: Load configuration object containing client settings

    Yields:
        The list of started client processes (overflow processes first for a
        mixed token test).
    """
    # Get appropriate client function based on configuration
    client_func = get_client_function(load_config.client_type)

    logger.info(
        f"Starting {load_config.clients} clients using '{load_config.client_type}' client"
    )

    procs: list[SpawnProcess] = []
    ctx = multiprocessing.get_context("spawn")

    # Both client types use max_request_rate for rate limiting (requests/sec)
    max_request_rate = load_config.max_request_rate

    # Check if this is a continuous load test (rolling upgrade scenarios)
    continuous_load = getattr(load_config, "continuous_load", False)

    try:
        if getattr(load_config, "mixed_token_test", False):
            logger.info(
                f"Mixed token test: {load_config.overflow_request_count} overflow requests "
                f"({load_config.overflow_token_length} tokens) + "
                f"{load_config.normal_request_count} normal requests "
                f"({load_config.input_token_length} tokens)"
            )

            # First phase: Send overflow requests
            _start_client_procs(
                ctx,
                client_func,
                load_config.clients,
                lambda i: (
                    deployment_spec,
                    namespace,
                    model,
                    f"{log_dir}{OVERFLOW_SUFFIX}",
                    i,
                    load_config.overflow_request_count,
                    load_config.overflow_token_length,
                    load_config.output_token_length,
                    load_config.max_retries,
                    max_request_rate,
                    continuous_load,
                ),
                logger,
                procs,
                "overflow ",
            )

            # Wait for overflow requests to complete
            for proc in procs:
                proc.join()

            logger.info("Overflow requests completed. Starting recovery phase...")

            # Second phase: Send normal requests to test recovery.
            # The continuous_load flag is not passed here, so the client
            # function's own default applies for the recovery phase.
            _start_client_procs(
                ctx,
                client_func,
                load_config.clients,
                lambda i: (
                    deployment_spec,
                    namespace,
                    model,
                    f"{log_dir}{RECOVERY_SUFFIX}",
                    i,
                    load_config.normal_request_count,
                    load_config.input_token_length,
                    load_config.output_token_length,
                    load_config.max_retries,
                    max_request_rate,
                ),
                logger,
                procs,
                "recovery ",
            )
        else:
            # Normal test - single phase
            _start_client_procs(
                ctx,
                client_func,
                load_config.clients,
                lambda i: (
                    deployment_spec,
                    namespace,
                    model,
                    log_dir,
                    i,
                    load_config.requests_per_client,
                    load_config.input_token_length,
                    load_config.output_token_length,
                    load_config.max_retries,
                    max_request_rate,
                    continuous_load,
                ),
                logger,
                procs,
            )

        yield procs
    except BaseException:
        # Startup or the with-body failed: stop clients that are still
        # running so the join below cannot block on them indefinitely.
        for proc in procs:
            if proc.is_alive():
                proc.terminate()
        raise
    finally:
        for proc in procs:
            logger.debug(f"{proc} waiting for join")
            proc.join()
            logger.debug(f"{proc} joined")


def _terminate_client_processes(
    client_procs: list[SpawnProcess],
    logger: logging.Logger,
):
194
    """
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
    Terminate client processes.
    """
    # Send SIGINT to client processes to stop continuous load
    if client_procs:
        logger.info(f"Sending SIGINT to {len(client_procs)} client processes...")
        for proc in client_procs:
            if proc.is_alive():
                try:
                    if proc.pid is not None:
                        logger.debug(f"Sending SIGINT to client process {proc.pid}")
                        os.kill(proc.pid, signal.SIGINT)
                    else:
                        raise ValueError(f"Process {proc} has no PID")
                except ProcessLookupError:
                    logger.debug(f"Process {proc.pid} already terminated")
                except Exception as e:
                    logger.warning(f"Failed to send SIGINT to process {proc.pid}: {e}")
        logger.info(
            "SIGINT sent to all client processes, waiting for graceful shutdown..."
        )
    else:
        logger.warning("No client processes provided to terminate")
217
218


219
220
221
222
223
224
async def _inject_failures(
    failures: list[Failure],
    logger: logging.Logger,
    deployment: ManagedDeployment,
) -> dict[str, list]:  # noqa: F811
    """Run the given failures one after another against the deployment.

    Failures are processed in list order. Before each one the coroutine
    sleeps for ``failure.time``, so each delay is counted from the completion
    of the previous failure.

    Args:
        failures: Failures to inject, in execution order.
        logger: Logger passed to each failure and used for progress messages.
        deployment: Deployment the failures are executed against.

    Returns:
        Mapping of ``failure.get_failure_key()`` to whatever
        ``failure.execute`` returned. A later failure with the same key
        replaces the earlier entry.
    """
    affected_pods: dict[str, list] = {}

    for failure in failures:
        await asyncio.sleep(failure.time)

        logger.info(f"Injecting failure for: {failure}")

        affected_pods[failure.get_failure_key()] = await failure.execute(
            deployment, logger
        )

    return affected_pods


# Output directories of every test run so far in this session; read by the
# session-scoped results_summary fixture.
global_result_list: list[str] = []

# Global storage for test results (used by validation fixture), keyed by test name
test_results_cache: dict[str, Any] = {}


def _run_validation_checkers(
    results: Any,
    scenario,
    context: dict[str, Any],
    test_name: str,
    logger: logging.Logger,
) -> None:
    """Build a ValidationContext from parsed results and run the scenario's checkers.

    Args:
        results: Parsed results; either a dict or a non-empty list whose
            first element is used.
        scenario: Scenario whose ``checkers`` are run.
        context: Shared dict populated by the test (deployment, namespace,
            affected_pods).
        test_name: Test node name, used to resolve the log directory.
        logger: Logger for progress and failure messages.

    Raises:
        AssertionError: Propagated from a checker so the test is reported as
            failed. Any other exception is logged and swallowed.
    """
    try:
        logger.info("\n" + "=" * 60)
        logger.info("Running validation checks...")
        logger.info("=" * 60)

        # Extract metrics and recovery time from parsed results
        if isinstance(results, list) and len(results) > 0:
            result = results[0]
        elif isinstance(results, dict):
            result = results
        else:
            logger.warning(f"Unexpected result format: {type(results)}")
            result = None

        if not result:
            return

        metrics = result.get("metrics", {})
        recovery_time = result.get("recovery_time")

        # Create ValidationContext for all checkers
        validation_ctx = ValidationContext(
            scenario=scenario,
            log_dir=resolve_test_output_path(test_name),
            metrics=metrics,
            deployment=context.get("deployment"),
            namespace=context.get("namespace"),
            recovery_time=recovery_time,
            affected_pods=context.get("affected_pods", {}),
        )

        # Use the checkers attached to the scenario
        for checker in scenario.checkers or []:
            logger.info(f"\nRunning checker: {checker.name}")
            checker.check(validation_ctx)

        logger.info("=" * 60)
        logger.info("✓ All validation checks passed")
        logger.info("=" * 60 + "\n")

    except AssertionError as e:
        logger.error("=" * 60)
        logger.error(f"✗ Validation failed: {e}")
        logger.error("=" * 60 + "\n")
        # Re-raise to fail the test
        raise
    except Exception as e:
        logger.error(f"Validation error: {e}")
        # Don't fail test on validation errors (non-assertion exceptions)
        logger.warning("Skipping validation due to error")


@pytest.fixture(autouse=True)
def validation_context(request, scenario):  # noqa: F811
    """Provides shared context between test execution and validation.

    This fixture creates a shared dictionary that the test populates during
    execution (deployment, namespace, affected_pods), then uses that data
    in teardown to parse results and run checkers.

    Result parsing is best-effort: a parsing error is logged and validation
    is skipped. Validation runs outside that error handling, so an
    AssertionError raised by a checker propagates out of the teardown and
    fails the test instead of being logged as a parsing failure.

    Yields:
        The shared context dict with keys ``deployment``, ``namespace`` and
        ``affected_pods``.
    """
    # Shared context that test will populate during execution
    context: dict[str, Any] = {
        "deployment": None,
        "namespace": None,
        "affected_pods": {},
    }

    yield context  # Test receives this and populates it

    test_name = request.node.name
    logger = logging.getLogger(test_name)

    # Determine log paths based on whether this is a mixed token test
    if getattr(scenario.load, "mixed_token_test", False):
        # For mixed token tests, we have separate overflow and recovery directories
        overflow_dir = resolve_test_output_path(f"{test_name}{OVERFLOW_SUFFIX}")
        recovery_dir = resolve_test_output_path(f"{test_name}{RECOVERY_SUFFIX}")
        log_paths = [overflow_dir, recovery_dir]

        logger.info("Mixed token test detected. Looking for results in:")
        logger.info(f"  - Overflow phase: {overflow_dir}")
        logger.info(f"  - Recovery phase: {recovery_dir}")
    else:
        # Standard test with single directory
        log_paths = [resolve_test_output_path(test_name)]

    # Add all directories to global list for session summary. Done before
    # validation so the directories are recorded even when a checker fails.
    global_result_list.extend(log_paths)

    # Use factory to auto-detect and parse results
    results = None
    try:
        results = parse_test_results(
            log_dir=None,
            log_paths=log_paths,
            tablefmt="fancy_grid",
            sla=scenario.load.sla,
            success_threshold=scenario.load.success_threshold,
            print_output=True,
        )
    except Exception:
        logger.exception("Failed to parse results for %s", test_name)

    if not results:
        return

    # Store results for reference
    logger.info(f"Results parsed: {type(results)}")
    test_results_cache[test_name] = results

    _run_validation_checkers(results, scenario, context, test_name, logger)


@pytest.fixture(autouse=True, scope="session")
def results_summary():
    """Session summary that processes all tests but only prints paired tests.

    After the session, every recorded result directory is parsed without
    printing. Directories whose names end in the overflow and recovery
    suffixes are paired by their common base name, and a combined summary is
    printed for each complete pair. Any error while parsing or summarising is
    logged with its traceback and does not fail the session.
    """
    yield

    if not global_result_list:
        return

    # Step 1: Group directories by base name (path without the phase suffix)
    test_groups: dict[str, dict[str, str]] = {}
    phase_suffixes = ((OVERFLOW_SUFFIX, "overflow"), (RECOVERY_SUFFIX, "recovery"))

    for log_path in global_result_list:
        for suffix, phase in phase_suffixes:
            if log_path.endswith(suffix):
                base_name = log_path[: -len(suffix)]
                test_groups.setdefault(base_name, {})[phase] = log_path
                break

    # Step 2: Process all tests (get results) but only print paired ones
    try:
        # First, silently parse all tests to get results (for any downstream processing)
        parse_test_results(
            log_dir=None,
            log_paths=global_result_list,
            tablefmt="fancy_grid",
            print_output=False,  # Don't print anything
        )

        for base_name, paths in test_groups.items():
            if "overflow" not in paths or "recovery" not in paths:
                continue

            # Extract scenario from test name to pass configs
            scenario_obj = None
            match = re.search(r"\[(.*)\]", base_name)
            if match:
                scenario_name = match.group(1)
                if scenario_name in scenarios:
                    scenario_obj = scenarios[scenario_name]
                    logging.info(
                        f"Found scenario '{scenario_name}' for combined results."
                    )

            if not scenario_obj:
                logging.warning(
                    f"Could not find scenario for '{base_name}'. Using default thresholds."
                )

            success_threshold = (
                scenario_obj.load.success_threshold if scenario_obj else 90.0
            )
            logging.info(
                f"Using success_threshold: {success_threshold} for combined summary of '{base_name}'"
            )

            # This function will print the combined summary
            process_overflow_recovery_test(
                overflow_path=paths["overflow"],
                recovery_path=paths["recovery"],
                tablefmt="fancy_grid",
                sla=scenario_obj.load.sla if scenario_obj else None,
                success_threshold=success_threshold,
            )

    except Exception as e:
        # logging.exception keeps the traceback alongside the message.
        logging.exception(f"Failed to parse combined results: {e}")


@pytest.mark.k8s
@pytest.mark.fault_tolerance
@pytest.mark.post_merge
@pytest.mark.e2e
@pytest.mark.slow
@pytest.mark.filterwarnings("ignore::DeprecationWarning")
async def test_fault_scenario(
    scenario: Scenario,  # noqa: F811
    request,
    image: str,
    namespace: str,
    validation_context,  # noqa: F811  # Shared context for passing data to validation
    skip_service_restart: bool,
):
    """
    Test dynamo serve deployments with injected failures

    Flow:
    1. validation_context fixture creates empty dict: {"deployment": None, "namespace": None, "affected_pods": {}}
    2. This test populates it: validation_context["deployment"] = deployment, etc.
    3. After test completes, fixture reads validation_context and runs validation checkers
    4. Checkers use the populated ValidationContext to verify test results and K8s events
    """

    logger = logging.getLogger(request.node.name)

    # Keep the scenario's own deployment name: the TRT-LLM worker lookup below
    # inspects it, and it is replaced on the next line.
    original_deployment_name = str(scenario.deployment.name or "")
    scenario.deployment.name = "fault-tolerance-test"

    if image:
        scenario.deployment.set_image(image)

    model: Optional[str] = None
    if scenario.model:
        scenario.deployment.set_model(scenario.model)
        model = scenario.model
    else:
        # Get model from the appropriate worker based on backend
        try:
            if scenario.backend == "vllm":
                model = scenario.deployment["VllmDecodeWorker"].model
            elif scenario.backend == "sglang":
                model = scenario.deployment["decode"].model
            elif scenario.backend == "trtllm":
                # Determine deployment type from the scenario's original
                # deployment name ("agg" but not "disagg" means aggregated).
                if (
                    "agg" in original_deployment_name
                    and "disagg" not in original_deployment_name
                ):
                    model = scenario.deployment["TRTLLMWorker"].model
                else:
                    model = scenario.deployment["TRTLLMDecodeWorker"].model
            else:
                model = None
        except (KeyError, AttributeError):
            model = None
    # Fallback to default if still None
    model = model or "Qwen/Qwen3-0.6B"

    scenario.deployment.set_logging(True, "info")

    async with ManagedDeployment(
        namespace=namespace,
        log_dir=request.node.name,
        deployment_spec=scenario.deployment,
        skip_service_restart=skip_service_restart,
    ) as deployment:
        # Populate shared context for validation
        validation_context["deployment"] = deployment
        validation_context["namespace"] = namespace

        with _clients(
            logger,
            request.node.name,
            scenario.deployment,
            namespace,
            model,
            scenario.load,  # Pass entire Load config object
        ) as client_procs:
            # Inject failures and capture which pods were affected
            affected_pods = await _inject_failures(
                scenario.failures, logger, deployment
            )
            # The validation fixture reads this key during teardown.
            validation_context["affected_pods"] = affected_pods
            logger.info(f"Affected pods during test: {affected_pods}")

            if scenario.load.continuous_load:
                _terminate_client_processes(client_procs, logger)