# test_deployment.py
1
# SPDX-FileCopyrightText: Copyright (c) 2025-2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2
3
# SPDX-License-Identifier: Apache-2.0

4
import asyncio
import copy
import logging
import multiprocessing
import os
import re
import signal
from contextlib import contextmanager
from multiprocessing.context import SpawnProcess
from typing import Any, Optional

import pytest

from tests.fault_tolerance.deploy.base_checker import ValidationContext
from tests.fault_tolerance.deploy.client_factory import get_client_function
from tests.fault_tolerance.deploy.parse_factory import parse_test_results
from tests.fault_tolerance.deploy.parse_results import process_overflow_recovery_test
from tests.fault_tolerance.deploy.scenarios import (
    OVERFLOW_SUFFIX,
    RECOVERY_SUFFIX,
    Failure,
    Load,
    Scenario,
    scenarios,
)
from tests.utils.managed_deployment import DeploymentSpec, ManagedDeployment
from tests.utils.test_output import resolve_test_output_path
30
31


32
33
def get_model_from_deployment(
    deployment_spec: DeploymentSpec,
    scenario: Optional[Scenario] = None,
    service_name: Optional[str] = None,
) -> str:
    """Resolve the model name to use for a deployment.

    Lookup order (the first non-empty value wins):

    1. ``scenario.model``, when a scenario is given.
    2. ``deployment_spec[service_name].model``, when a service name is given.
    3. The ``model`` of the worker service that matches ``scenario.backend``.
    4. The default model ``Qwen/Qwen3-0.6B``.

    Args:
        deployment_spec: Deployment specification
        scenario: Optional Scenario object with backend and model info
        service_name: Optional specific service to get model from

    Returns:
        Model name (never None, falls back to default)
    """
    default_model = "Qwen/Qwen3-0.6B"

    # If scenario specifies a model, use that
    if scenario and scenario.model:
        return scenario.model

    # Try to get model from specified service
    if service_name:
        try:
            service_spec = deployment_spec[service_name]
            if service_spec and service_spec.model:
                return service_spec.model
        except (KeyError, AttributeError):
            # Best effort: an unknown service simply falls through to the
            # backend-specific lookup below.
            pass

    # Get model from backend-specific worker (if scenario provided)
    if scenario:
        candidates: list[str] = []
        if scenario.backend == "vllm":
            candidates = ["VllmDecodeWorker"]
        elif scenario.backend == "sglang":
            candidates = ["decode"]
        elif scenario.backend == "trtllm":
            # The deployment name is only a hint for which worker exists: it
            # may have been overwritten by the caller, so the other worker is
            # always tried as a fallback instead of giving up on a KeyError.
            spec_name = getattr(deployment_spec, "name", None) or ""
            if "agg" in spec_name and "disagg" not in spec_name:
                candidates = ["TRTLLMWorker", "TRTLLMDecodeWorker"]
            else:
                candidates = ["TRTLLMDecodeWorker", "TRTLLMWorker"]

        last_error: Optional[Exception] = None
        for worker_name in candidates:
            try:
                model = deployment_spec[worker_name].model
            except (KeyError, AttributeError) as e:
                last_error = e
                continue
            if model:
                return model

        if last_error is not None:
            logging.warning(
                f"Could not get model from backend-specific worker "
                f"(backend={scenario.backend}): {last_error}"
            )

    # Fallback to default
    logging.info("Using default model: Qwen/Qwen3-0.6B")
    return default_model


90
91
@pytest.fixture
def scenario(scenario_name, client_type):
    """Get a private copy of the scenario, optionally overriding its client type.

    The scenario is looked up by name in ``scenarios`` and always deep-copied,
    so changes made to it during a test (deployment name, image, model,
    logging, load settings) do not leak into the shared registry.

    If --client-type is specified, it overrides the scenario's default client
    type and the retry budget is adjusted to match:

    * ``legacy``: ``max_retries`` is capped at 1 (per-request retries).
    * ``aiperf``: ``max_retries`` is raised to at least 3 (full test retries).
    """
    scenario_obj = copy.deepcopy(scenarios[scenario_name])

    # Override client type if specified on command line
    if client_type is not None:
        load = scenario_obj.load
        load.client_type = client_type

        # Adjust retry settings based on client type
        if client_type == "legacy":
            load.max_retries = min(load.max_retries, 1)
        elif client_type == "aiperf":
            load.max_retries = max(load.max_retries, 3)

    return scenario_obj
117
118
119
120


@contextmanager
def _clients(
    logger: logging.Logger,
    log_dir: str,
    deployment_spec: DeploymentSpec,
    namespace: str,
    model: str,
    load_config: Load,
):
    """Start client processes using factory pattern for client selection.

    The client function is chosen from ``load_config.client_type`` and run in
    ``load_config.clients`` processes created from a "spawn" multiprocessing
    context.

    For a mixed token test (``load_config.mixed_token_test``) two phases run:
    overflow clients are started and joined first, then recovery clients are
    started. Otherwise a single set of clients is started.

    Args:
        logger: Logger instance
        log_dir: Log directory for output logs and client logs/artifacts
        deployment_spec: Deployment specification
        namespace: Kubernetes namespace
        model: Model name to test
        load_config: Load configuration object containing client settings

    Yields:
        The list of all client processes that were started (for a mixed token
        test this includes the already-finished overflow clients).

    On normal exit every process is joined. If starting the clients or the
    ``with`` body raises, clients that are still alive are terminated and
    joined before the exception propagates, so they are not left running
    after a failed test.
    """
    # Get appropriate client function based on configuration
    client_func = get_client_function(load_config.client_type)

    logger.info(
        f"Starting {load_config.clients} clients using '{load_config.client_type}' client"
    )

    procs: list[SpawnProcess] = []
    ctx = multiprocessing.get_context("spawn")

    # Both client types use max_request_rate for rate limiting (requests/sec)
    max_request_rate = load_config.max_request_rate

    # Check if this is a continuous load test (rolling upgrade scenarios)
    continuous_load = getattr(load_config, "continuous_load", False)

    def start_phase(label, phase_log_dir, request_count, input_token_length):
        """Start one process per client for a load phase and return them."""
        phase_procs: list[SpawnProcess] = []
        for i in range(load_config.clients):
            proc = ctx.Process(
                target=client_func,
                args=(
                    deployment_spec,
                    namespace,
                    model,
                    phase_log_dir,
                    i,
                    request_count,
                    input_token_length,
                    load_config.output_token_length,
                    load_config.max_retries,
                    max_request_rate,
                    continuous_load,
                ),
            )
            proc.start()
            # Track only started processes so cleanup can safely inspect them.
            procs.append(proc)
            phase_procs.append(proc)
            logger.debug(f"Started {label} {i} (PID: {proc.pid})")
        return phase_procs

    try:
        # Check if this is a mixed token test (overflow + recovery)
        if getattr(load_config, "mixed_token_test", False):
            logger.info(
                f"Mixed token test: {load_config.overflow_request_count} overflow requests "
                f"({load_config.overflow_token_length} tokens) + "
                f"{load_config.normal_request_count} normal requests "
                f"({load_config.input_token_length} tokens)"
            )

            # First phase: send overflow requests and wait for them to finish
            overflow_procs = start_phase(
                "overflow client",
                f"{log_dir}{OVERFLOW_SUFFIX}",
                load_config.overflow_request_count,
                load_config.overflow_token_length,
            )
            for proc in overflow_procs:
                proc.join()

            logger.info("Overflow requests completed. Starting recovery phase...")

            # Second phase: send normal requests to test recovery
            start_phase(
                "recovery client",
                f"{log_dir}{RECOVERY_SUFFIX}",
                load_config.normal_request_count,
                load_config.input_token_length,
            )
        else:
            # Normal test - single phase
            start_phase(
                "client",
                log_dir,
                load_config.requests_per_client,
                load_config.input_token_length,
            )

        yield procs
    except BaseException:
        # Startup or the with-body failed: stop the load instead of waiting
        # for it (a continuous-load client would never finish on its own).
        for proc in procs:
            if proc.is_alive():
                proc.terminate()
        for proc in procs:
            proc.join()
        raise

    for proc in procs:
        logger.debug(f"{proc} waiting for join")
        proc.join()
        logger.debug(f"{proc} joined")


248
249
250
251
def _terminate_client_processes(
    client_procs: list[SpawnProcess],
    logger: logging.Logger,
):
    """Ask running client processes to stop by sending each of them SIGINT.

    Processes that are no longer alive are skipped. Failures to signal a
    single process are logged and do not stop the remaining processes from
    being signalled. This function only sends the signal; it does not wait
    for the processes to exit.

    Args:
        client_procs: Client processes to signal
        logger: Logger instance
    """
    if not client_procs:
        logger.warning("No client processes provided to terminate")
        return

    # Send SIGINT to client processes to stop continuous load
    logger.info(f"Sending SIGINT to {len(client_procs)} client processes...")
    for proc in client_procs:
        if not proc.is_alive():
            continue

        pid = proc.pid
        if pid is None:
            logger.warning(
                f"Failed to send SIGINT to process {pid}: Process {proc} has no PID"
            )
            continue

        try:
            logger.debug(f"Sending SIGINT to client process {pid}")
            os.kill(pid, signal.SIGINT)
        except ProcessLookupError:
            # The process exited between the liveness check and the signal.
            logger.debug(f"Process {pid} already terminated")
        except Exception as e:
            logger.warning(f"Failed to send SIGINT to process {pid}: {e}")

    logger.info("SIGINT sent to all client processes, waiting for graceful shutdown...")
275
276


277
278
279
280
281
282
async def _inject_failures(
    failures: list[Failure],
    logger: logging.Logger,
    deployment: ManagedDeployment,
) -> dict[str, list]:  # noqa: F811
    """Execute the given failures one after another and record their results.

    For each failure, in order: sleep for ``failure.time`` seconds, then await
    ``failure.execute(deployment, logger)``.

    Args:
        failures: Failures to inject, in execution order
        logger: Logger instance
        deployment: Deployment the failures are executed against

    Returns:
        Mapping of ``failure.get_failure_key()`` to the value returned by that
        failure's ``execute`` call.
    """
    affected_pods: dict[str, list] = {}

    for failure in failures:
        # The delay is applied before every failure, so it is relative to the
        # end of the previous injection rather than to the start of the test.
        await asyncio.sleep(failure.time)

        logger.info(f"Injecting failure for: {failure}")

        # NOTE(review): two failures with the same key overwrite each other;
        # confirm keys are unique per scenario.
        affected_pods[failure.get_failure_key()] = await failure.execute(
            deployment, logger
        )

    return affected_pods

295

296
297
# TODO: These globals might not work in parallel testing. FIXME

# Log directories of every finished test, appended by the validation fixture
# and consumed by the session-level results summary.
global_result_list: list[str] = []

# Global storage for test results (used by validation fixture), keyed by test name
test_results_cache: dict[str, Any] = {}
301
302
303


@pytest.fixture(autouse=True)
def validation_context(request, scenario):  # noqa: F811
    """Provides shared context between test execution and validation.

    This fixture creates a shared dictionary that the test populates during
    execution (deployment, namespace, affected_pods), then uses that data
    in teardown to parse results and run checkers.

    Teardown steps:

    1. Work out the result directories (two for a mixed token test, one
       otherwise) and record them for the session summary.
    2. Parse the results with ``parse_test_results``. A parsing error is
       logged and ends teardown without failing the test.
    3. Run every checker in ``scenario.checkers``. An ``AssertionError`` from
       a checker is re-raised so the test is reported as failed; any other
       exception is logged and validation is skipped.
    """
    # Shared context that test will populate during execution
    context: dict[str, Any] = {
        "deployment": None,
        "namespace": None,
        "affected_pods": {},
    }

    yield context  # Test receives this and populates it

    test_name = request.node.name
    logger = logging.getLogger(test_name)

    # Determine log paths based on whether this is a mixed token test
    if getattr(scenario.load, "mixed_token_test", False):
        # For mixed token tests, we have separate overflow and recovery directories
        overflow_dir = resolve_test_output_path(f"{test_name}{OVERFLOW_SUFFIX}")
        recovery_dir = resolve_test_output_path(f"{test_name}{RECOVERY_SUFFIX}")
        log_paths = [overflow_dir, recovery_dir]

        logging.info("Mixed token test detected. Looking for results in:")
        logging.info(f"  - Overflow phase: {overflow_dir}")
        logging.info(f"  - Recovery phase: {recovery_dir}")
    else:
        # Standard test with single directory
        log_paths = [resolve_test_output_path(test_name)]

    # Add all directories to global list for session summary. Done before
    # validation so a failing checker does not drop this test from the summary.
    global_result_list.extend(log_paths)

    # Use factory to auto-detect and parse results. Only parsing is guarded
    # here: wrapping validation in the same handler would swallow the
    # AssertionError that is supposed to fail the test.
    try:
        results = parse_test_results(
            log_dir=None,
            log_paths=log_paths,
            tablefmt="fancy_grid",
            sla=scenario.load.sla,
            success_threshold=scenario.load.success_threshold,
            print_output=True,
            # force_parser can be set based on client_type if needed
        )
    except Exception:
        logging.exception("Failed to parse results for %s", test_name)
        return

    if not results:
        return

    # Store results for reference
    logging.info(f"Results parsed: {type(results)}")
    test_results_cache[test_name] = results

    # Run validation now that we have results
    try:
        logger.info("\n" + "=" * 60)
        logger.info("Running validation checks...")
        logger.info("=" * 60)

        # Extract metrics and recovery time from parsed results
        if isinstance(results, list) and len(results) > 0:
            result = results[0]
        elif isinstance(results, dict):
            result = results
        else:
            logger.warning(f"Unexpected result format: {type(results)}")
            result = None

        if result:
            # Create ValidationContext for all checkers
            validation_ctx = ValidationContext(
                scenario=scenario,
                log_dir=resolve_test_output_path(test_name),
                metrics=result.get("metrics", {}),
                deployment=context.get("deployment"),
                namespace=context.get("namespace"),
                recovery_time=result.get("recovery_time"),
                affected_pods=context.get("affected_pods", {}),
            )

            # Use the checkers attached to the scenario
            for checker in scenario.checkers or []:
                logger.info(f"\nRunning checker: {checker.name}")
                checker.check(validation_ctx)

            logger.info("=" * 60)
            logger.info("✓ All validation checks passed")
            logger.info("=" * 60 + "\n")

    except AssertionError as e:
        logger.error("=" * 60)
        logger.error(f"✗ Validation failed: {e}")
        logger.error("=" * 60 + "\n")
        # Re-raise to fail the test
        raise
    except Exception as e:
        logger.error(f"Validation error: {e}")
        # Don't fail test on validation errors (non-assertion exceptions)
        logger.warning("Skipping validation due to error")
417
418
419
420


@pytest.fixture(autouse=True, scope="session")
def results_summary():
    """
    Session summary that processes all tests but only prints paired tests.

    After the session, the directories collected in ``global_result_list`` are
    grouped by test: a directory ending in ``OVERFLOW_SUFFIX`` and one ending
    in ``RECOVERY_SUFFIX`` with the same prefix form a pair. All directories
    are parsed silently; a combined summary is printed for each complete pair.
    Errors are logged and never fail the session, and an error in one pair
    does not prevent the remaining pairs from being summarised.
    """
    yield

    if not global_result_list:
        return

    # Step 1: Group directories by the path with its phase suffix removed
    test_groups: dict[str, dict[str, str]] = {}
    phases = ((OVERFLOW_SUFFIX, "overflow"), (RECOVERY_SUFFIX, "recovery"))

    for log_path in global_result_list:
        for suffix, phase in phases:
            if log_path.endswith(suffix):
                base_name = log_path[: -len(suffix)]
                test_groups.setdefault(base_name, {})[phase] = log_path
                break

    # Step 2: Process all tests (get results) but only print paired ones
    try:
        # Silently parse all tests to get results (for any downstream processing)
        parse_test_results(
            log_dir=None,
            log_paths=global_result_list,
            tablefmt="fancy_grid",
            print_output=False,  # Don't print anything
        )
    except Exception as e:
        logging.exception(f"Failed to parse combined results: {e}")

    for base_name, paths in test_groups.items():
        if "overflow" not in paths or "recovery" not in paths:
            continue

        try:
            # Extract scenario from test name to pass configs; the scenario
            # name is the text between the square brackets of the test id.
            scenario_obj = None
            match = re.search(r"\[(.*)\]", base_name)
            if match:
                scenario_name = match.group(1)
                if scenario_name in scenarios:
                    scenario_obj = scenarios[scenario_name]
                    logging.info(
                        f"Found scenario '{scenario_name}' for combined results."
                    )

            if not scenario_obj:
                logging.warning(
                    f"Could not find scenario for '{base_name}'. Using default thresholds."
                )

            success_threshold = (
                scenario_obj.load.success_threshold if scenario_obj else 90.0
            )
            logging.info(
                f"Using success_threshold: {success_threshold} for combined summary of '{base_name}'"
            )

            # This function will print the combined summary
            process_overflow_recovery_test(
                overflow_path=paths["overflow"],
                recovery_path=paths["recovery"],
                tablefmt="fancy_grid",
                sla=scenario_obj.load.sla if scenario_obj else None,
                success_threshold=success_threshold,
            )
        except Exception as e:
            logging.exception(f"Failed to parse combined results: {e}")
490
491


492
493
@pytest.mark.k8s
@pytest.mark.fault_tolerance
@pytest.mark.post_merge
@pytest.mark.e2e
@pytest.mark.slow
@pytest.mark.gpu_0
@pytest.mark.filterwarnings("ignore::DeprecationWarning")
async def test_fault_scenario(
    scenario: Scenario,  # noqa: F811
    request,
    image: str,
    namespace: str,
    validation_context,  # noqa: F811  # Shared context for passing data to validation
    skip_service_restart: bool,
):
    """
    Test dynamo serve deployments with injected failures

    Flow:
    1. validation_context fixture creates empty dict: {"deployment": None, "namespace": None, "affected_pods": {}}
    2. This test populates it with the deployment, the namespace and the pods
       affected by the injected failures.
    3. After test completes, fixture reads validation_context and runs validation checkers
    4. Checkers use the populated ValidationContext to verify test results and K8s events
    """

    logger = logging.getLogger(request.node.name)

    scenario.deployment.name = "fault-tolerance-test"

    if image:
        scenario.deployment.set_image(image)

    # Get model using helper function and ensure it's set on all services
    model = get_model_from_deployment(scenario.deployment, scenario)
    scenario.deployment.set_model(model)  # Set model on all services including Frontend

    scenario.deployment.set_logging(True, "info")

    async with ManagedDeployment(
        namespace=namespace,
        log_dir=request.node.name,
        deployment_spec=scenario.deployment,
        skip_service_restart=skip_service_restart,
    ) as deployment:
        # Populate shared context for validation
        validation_context["deployment"] = deployment
        validation_context["namespace"] = namespace

        with _clients(
            logger,
            resolve_test_output_path(request.node.name),
            scenario.deployment,
            namespace,
            model,
            scenario.load,  # Pass entire Load config object
        ) as client_procs:
            # Inject failures and capture which pods were affected
            affected_pods = await _inject_failures(
                scenario.failures, logger, deployment
            )
            # Hand the affected pods to the validation fixture; without this
            # the checkers always see an empty mapping.
            validation_context["affected_pods"] = affected_pods
            logger.info(f"Affected pods during test: {affected_pods}")

            # Continuous-load clients run until signalled, so stop them
            # before the client context manager waits for them to exit.
            if getattr(scenario.load, "continuous_load", False):
                _terminate_client_processes(client_procs, logger)