test_deployment.py 18.8 KB
Newer Older
1
2
3
# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0

4
import asyncio
5
6
import logging
import multiprocessing
7
import os
8
import re
9
import signal
10
from contextlib import contextmanager
11
12
from multiprocessing.context import SpawnProcess
from typing import Any, Optional
13
14
15

import pytest

16
from tests.fault_tolerance.deploy.base_checker import ValidationContext
17
18
from tests.fault_tolerance.deploy.client_factory import get_client_function
from tests.fault_tolerance.deploy.parse_factory import parse_test_results
19
20
21
22
from tests.fault_tolerance.deploy.parse_results import process_overflow_recovery_test
from tests.fault_tolerance.deploy.scenarios import (
    OVERFLOW_SUFFIX,
    RECOVERY_SUFFIX,
23
    Failure,
24
    Load,
25
    Scenario,
26
27
    scenarios,
)
28
from tests.utils.managed_deployment import DeploymentSpec, ManagedDeployment
29
30


31
32
@pytest.fixture
def scenario(scenario_name, client_type):
    """Get scenario and optionally override client type from command line.

    A deep copy of the registered scenario is always returned. The test
    mutates the scenario it receives (deployment name, image, model,
    logging), and those mutations must not leak into the shared
    ``scenarios`` registry used by later tests.

    If --client-type is specified, it overrides the scenario's default
    client type and adjusts ``load.max_retries`` for that client:
    ``legacy`` is capped at 1 (per-request retries) and ``aiperf`` is
    raised to at least 3 (full test retries).
    """
    import copy

    scenario_obj = copy.deepcopy(scenarios[scenario_name])

    if client_type is not None:
        scenario_obj.load.client_type = client_type

        if client_type == "legacy":
            # Legacy uses per-request retries
            scenario_obj.load.max_retries = min(scenario_obj.load.max_retries, 1)
        elif client_type == "aiperf":
            # AI-Perf uses full test retries
            scenario_obj.load.max_retries = max(scenario_obj.load.max_retries, 3)

    return scenario_obj


@contextmanager
def _clients(
    logger: logging.Logger,
    log_dir: str,
    deployment_spec: DeploymentSpec,
    namespace: str,
    model: str,
    load_config: Load,
):
    """Start client processes using factory pattern for client selection.

    Clients are started with the multiprocessing "spawn" context, and the
    list of started processes is yielded to the ``with`` body. On normal
    exit every process is joined. If client start-up or the ``with`` body
    raises, every client that is still running is terminated before joining.
    This keeps a failing test from leaking client processes; continuous-load
    clients are otherwise only stopped by the caller's SIGINT.

    For mixed token tests (``load_config.mixed_token_test``), the overflow
    phase is run to completion before the recovery phase is started.

    Args:
        logger: Logger instance
        log_dir: Log directory for output logs and client logs/artifacts
        deployment_spec: Deployment specification
        namespace: Kubernetes namespace
        model: Model name to test
        load_config: Load configuration object containing client settings

    Yields:
        list[SpawnProcess]: every client process started (both phases for
        mixed token tests).
    """
    # Get appropriate client function based on configuration
    client_func = get_client_function(load_config.client_type)

    logger.info(
        f"Starting {load_config.clients} clients using '{load_config.client_type}' client"
    )

    procs: list[SpawnProcess] = []
    ctx = multiprocessing.get_context("spawn")

    # Legacy clients receive max_request_rate in this argument slot for rate
    # limiting; AI-Perf clients receive a retry delay between attempts (5s).
    retry_delay_or_rate = (
        load_config.max_request_rate if load_config.client_type == "legacy" else 5
    )

    # Check if this is a continuous load test (rolling upgrade scenarios)
    continuous_load = getattr(load_config, "continuous_load", False)

    def start_phase(phase_log_dir, request_count, input_token_length, label):
        # Start one client per ``load_config.clients`` and record each in
        # ``procs`` as soon as it has started, so cleanup can always see it.
        for i in range(load_config.clients):
            proc = ctx.Process(
                target=client_func,
                args=(
                    deployment_spec,
                    namespace,
                    model,
                    phase_log_dir,
                    i,
                    request_count,
                    input_token_length,
                    load_config.output_token_length,
                    load_config.max_retries,
                    retry_delay_or_rate,
                    continuous_load,
                ),
            )
            proc.start()
            procs.append(proc)
            logger.debug(f"Started {label}client {i} (PID: {proc.pid})")

    try:
        if getattr(load_config, "mixed_token_test", False):
            logger.info(
                f"Mixed token test: {load_config.overflow_request_count} overflow requests "
                f"({load_config.overflow_token_length} tokens) + "
                f"{load_config.normal_request_count} normal requests "
                f"({load_config.input_token_length} tokens)"
            )

            # First phase: send overflow requests and wait for them to finish
            start_phase(
                f"{log_dir}{OVERFLOW_SUFFIX}",
                load_config.overflow_request_count,
                load_config.overflow_token_length,
                "overflow ",
            )
            for proc in procs:
                proc.join()

            logger.info("Overflow requests completed. Starting recovery phase...")

            # Second phase: send normal requests to test recovery
            start_phase(
                f"{log_dir}{RECOVERY_SUFFIX}",
                load_config.normal_request_count,
                load_config.input_token_length,
                "recovery ",
            )
        else:
            # Normal test - single phase
            start_phase(
                log_dir,
                load_config.requests_per_client,
                load_config.input_token_length,
                "",
            )

        yield procs
    except BaseException:
        # The caller failed: don't wait on clients that may never exit
        # (e.g. continuous load) — stop them, then let ``finally`` reap them.
        for proc in procs:
            if proc.is_alive():
                logger.warning(f"Terminating client process {proc.pid} after error")
                proc.terminate()
        raise
    finally:
        for proc in procs:
            logger.debug(f"{proc} waiting for join")
            proc.join()
            logger.debug(f"{proc} joined")


def _terminate_client_processes(
    client_procs: list[SpawnProcess],
    logger: logging.Logger,
):
    """Ask running client processes to stop by sending them SIGINT.

    This only delivers the signal; it does not wait for the processes to
    exit. Signalling is best-effort: processes that are not alive are
    skipped, and processes that cannot be signalled are logged.

    Args:
        client_procs: Client processes to signal.
        logger: Logger for progress messages and warnings.
    """
    if not client_procs:
        logger.warning("No client processes provided to terminate")
        return

    # Send SIGINT to client processes to stop continuous load
    logger.info(f"Sending SIGINT to {len(client_procs)} client processes...")
    for proc in client_procs:
        if not proc.is_alive():
            continue
        if proc.pid is None:
            # A live (started) process should always have a PID; log and keep
            # signalling the remaining clients rather than aborting.
            logger.warning(f"Cannot send SIGINT to process {proc}: it has no PID")
            continue
        try:
            logger.debug(f"Sending SIGINT to client process {proc.pid}")
            os.kill(proc.pid, signal.SIGINT)
        except ProcessLookupError:
            logger.debug(f"Process {proc.pid} already terminated")
        except OSError as e:
            logger.warning(f"Failed to send SIGINT to process {proc.pid}: {e}")
    logger.info(
        "SIGINT sent to all client processes, waiting for graceful shutdown..."
    )


async def _inject_failures(
    failures: list[Failure],
    logger: logging.Logger,
    deployment: ManagedDeployment,
) -> dict[str, list]:  # noqa: F811
    """Inject each failure in order, sleeping ``failure.time`` seconds first.

    Delays are sequential: each delay starts after the previous injection
    has finished, not from the start of the test.

    Returns:
        A mapping from ``failure.get_failure_key()`` to the value returned by
        ``failure.execute`` (the pods affected by that failure). A later
        failure with the same key overwrites an earlier entry.
    """
    pods_by_failure: dict[str, list] = {}

    for fault in failures:
        await asyncio.sleep(fault.time)
        logger.info(f"Injecting failure for: {fault}")
        hit_pods = await fault.execute(deployment, logger)
        # Key is computed after execute(), matching the original evaluation order.
        pods_by_failure[fault.get_failure_key()] = hit_pods

    return pods_by_failure

# Log directories of every test in this session. Filled in by the
# validation fixture's teardown and read by the session-scoped summary.
global_result_list: list[str] = []

# Global storage for test results (used by validation fixture):
# test node name -> parsed results.
test_results_cache: dict[str, Any] = {}


def _validation_log_paths(scenario, test_name: str, logger: logging.Logger) -> list[str]:
    """Return the client log directories a test writes its results into.

    Mixed token tests write the overflow and recovery phases to separate
    directories (test name + ``OVERFLOW_SUFFIX`` / ``RECOVERY_SUFFIX``);
    every other test uses a single directory named after the test.
    """
    if getattr(scenario.load, "mixed_token_test", False):
        overflow_dir = f"{test_name}{OVERFLOW_SUFFIX}"
        recovery_dir = f"{test_name}{RECOVERY_SUFFIX}"
        logger.info("Mixed token test detected. Looking for results in:")
        logger.info(f"  - Overflow phase: {overflow_dir}")
        logger.info(f"  - Recovery phase: {recovery_dir}")
        return [overflow_dir, recovery_dir]
    # Standard test with single directory
    return [test_name]


def _first_parsed_result(results: Any, logger: logging.Logger) -> Any:
    """Pick the result record to validate from ``parse_test_results`` output.

    A non-empty list yields its first element, and a dict is used as-is.
    For any other value, a warning is logged and ``None`` is returned.
    """
    if isinstance(results, list) and results:
        return results[0]
    if isinstance(results, dict):
        return results
    logger.warning(f"Unexpected result format: {type(results)}")
    return None


def _run_scenario_checkers(
    scenario,
    test_name: str,
    result: Any,
    context: dict[str, Any],
    logger: logging.Logger,
) -> None:
    """Run every checker in ``scenario.checkers`` against one parsed result.

    Raises:
        AssertionError: re-raised from a failing checker so the test fails.

    Any other exception is logged and validation is skipped.
    """
    try:
        logger.info("\n" + "=" * 60)
        logger.info("Running validation checks...")
        logger.info("=" * 60)

        # Create ValidationContext for all checkers
        validation_ctx = ValidationContext(
            scenario=scenario,
            log_dir=test_name,
            metrics=result.get("metrics", {}),
            deployment=context.get("deployment"),
            namespace=context.get("namespace"),
            recovery_time=result.get("recovery_time"),
            affected_pods=context.get("affected_pods", {}),
        )

        # Checkers were already determined during scenario creation
        for checker in scenario.checkers or []:
            logger.info(f"\nRunning checker: {checker.name}")
            checker.check(validation_ctx)

        logger.info("=" * 60)
        logger.info("✓ All validation checks passed")
        logger.info("=" * 60 + "\n")
    except AssertionError as e:
        logger.error("=" * 60)
        logger.error(f"✗ Validation failed: {e}")
        logger.error("=" * 60 + "\n")
        # Re-raise to fail the test
        raise
    except Exception as e:
        logger.error(f"Validation error: {e}")
        # Don't fail test on validation errors (non-assertion exceptions)
        logger.warning("Skipping validation due to error")


@pytest.fixture(autouse=True)
def validation_context(request, scenario):  # noqa: F811
    """Provides shared context between test execution and validation.

    This fixture yields a dictionary that the test populates during
    execution (``deployment``, ``namespace``, ``affected_pods``). In
    teardown it:

    - registers the test's log directories for the session summary,
    - parses the results (the parser factory auto-detects AI-Perf vs.
      legacy output),
    - caches them in ``test_results_cache``, and
    - runs the scenario's checkers.

    A failing checker (AssertionError) fails the test. Parse failures and
    non-assertion checker errors are only logged.
    """
    # Shared context that test will populate during execution
    context: dict[str, Any] = {
        "deployment": None,
        "namespace": None,
        "affected_pods": {},
    }

    yield context  # Test receives this and populates it

    test_name = request.node.name
    logger = logging.getLogger(test_name)

    log_paths = _validation_log_paths(scenario, test_name, logger)
    # Register directories for the session summary up front, so they are
    # recorded even when parsing fails or validation raises below.
    global_result_list.extend(log_paths)

    # Use factory to auto-detect and parse results (best-effort)
    try:
        results = parse_test_results(
            log_dir=None,
            log_paths=log_paths,
            tablefmt="fancy_grid",
            sla=scenario.load.sla,
            success_threshold=scenario.load.success_threshold,
            print_output=True,
            # force_parser can be set based on client_type if needed
            # force_parser=scenario.load.client_type,
        )
    except Exception:
        logger.exception("Failed to parse results for %s", test_name)
        return

    if not results:
        return

    logger.info(f"Results parsed: {type(results)}")
    test_results_cache[test_name] = results

    # Deliberately outside the parse ``try``: a checker's AssertionError must
    # propagate and fail the test, not be logged as a parse failure.
    result = _first_parsed_result(results, logger)
    if result:
        _run_scenario_checkers(scenario, test_name, result, context, logger)


def _print_overflow_recovery_summary(base_name: str, paths: dict[str, str]) -> None:
    """Print the combined summary for one paired overflow/recovery test.

    The scenario is looked up from the bracketed part of ``base_name``
    (presumably the pytest parametrize id). If it cannot be found, a 90.0
    success threshold and no SLA are used.
    """
    scenario_obj = None
    match = re.search(r"\[(.*)\]", base_name)
    if match:
        scenario_name = match.group(1)
        if scenario_name in scenarios:
            scenario_obj = scenarios[scenario_name]
            logging.info(f"Found scenario '{scenario_name}' for combined results.")

    if not scenario_obj:
        logging.warning(
            f"Could not find scenario for '{base_name}'. Using default thresholds."
        )

    success_threshold = scenario_obj.load.success_threshold if scenario_obj else 90.0
    logging.info(
        f"Using success_threshold: {success_threshold} for combined summary of '{base_name}'"
    )

    # This function will print the combined summary
    process_overflow_recovery_test(
        overflow_path=paths["overflow"],
        recovery_path=paths["recovery"],
        tablefmt="fancy_grid",
        sla=scenario_obj.load.sla if scenario_obj else None,
        success_threshold=success_threshold,
    )


@pytest.fixture(autouse=True, scope="session")
def results_summary():
    """
    Session summary that processes all tests but only prints paired tests.

    After all tests finish, every collected log directory is parsed once
    without printing. Directories ending in ``OVERFLOW_SUFFIX`` and
    ``RECOVERY_SUFFIX`` that share a base test name are then paired, and a
    combined summary is printed for each complete pair.

    Errors are logged and never fail the session. A failure in the silent
    parse does not suppress the combined summaries, and a failure in one
    pair does not suppress the others.
    """
    yield

    if not global_result_list:
        return

    # Step 1: Group overflow/recovery directories by their base test name
    test_groups: dict[str, dict[str, str]] = {}
    for log_path in global_result_list:
        for phase, suffix in (
            ("overflow", OVERFLOW_SUFFIX),
            ("recovery", RECOVERY_SUFFIX),
        ):
            if log_path.endswith(suffix):
                base_name = log_path[: -len(suffix)]
                test_groups.setdefault(base_name, {})[phase] = log_path
                break

    # Step 2: Silently parse all tests (nothing is printed; the return value
    # is not used here)
    try:
        parse_test_results(
            log_dir=None,
            log_paths=global_result_list,
            tablefmt="fancy_grid",
            print_output=False,  # Don't print anything
        )
    except Exception:
        logging.exception("Failed to parse combined results")

    # Step 3: Print combined summaries only for complete pairs
    for base_name, paths in test_groups.items():
        if "overflow" not in paths or "recovery" not in paths:
            continue
        try:
            _print_overflow_recovery_summary(base_name, paths)
        except Exception:
            logging.exception("Failed to parse combined results for %s", base_name)


def _resolve_scenario_model(scenario: Scenario, deployment_name: str) -> str:
    """Return the model name that clients should request for ``scenario``.

    Uses ``scenario.model`` when it is set. Otherwise the model is read from
    the backend's worker in ``scenario.deployment``. Falls back to
    "Qwen/Qwen3-0.6B" for unknown backends, or when the worker or its model
    cannot be found.

    Args:
        scenario: Scenario under test.
        deployment_name: The deployment's name as defined by the scenario
            (before any renaming). For trtllm it selects between the
            aggregated ("agg") and disaggregated decode worker.
    """
    if scenario.model:
        return scenario.model

    model: Optional[str] = None
    name = deployment_name or ""
    # Get model from the appropriate worker based on backend
    try:
        if scenario.backend == "vllm":
            model = scenario.deployment["VllmDecodeWorker"].model
        elif scenario.backend == "sglang":
            model = scenario.deployment["decode"].model
        elif scenario.backend == "trtllm":
            is_aggregated = "agg" in name and "disagg" not in name
            worker = "TRTLLMWorker" if is_aggregated else "TRTLLMDecodeWorker"
            model = scenario.deployment[worker].model
    except (KeyError, AttributeError):
        model = None
    # Fallback to default if still None
    return model or "Qwen/Qwen3-0.6B"


@pytest.mark.k8s
@pytest.mark.fault_tolerance
@pytest.mark.e2e
@pytest.mark.slow
@pytest.mark.filterwarnings("ignore::DeprecationWarning")
async def test_fault_scenario(
    scenario: Scenario,  # noqa: F811
    request,
    image: str,
    namespace: str,
    validation_context,  # noqa: F811  # Shared context for passing data to validation
    skip_service_restart: bool,
):
    """
    Test dynamo serve deployments with injected failures

    Flow:
    1. validation_context fixture creates empty dict: {"deployment": None, "namespace": None, "affected_pods": {}}
    2. This test populates it: validation_context["deployment"] = deployment, etc.
    3. After test completes, fixture reads validation_context and runs validation checkers
    4. Checkers use the populated ValidationContext to verify test results and K8s events
    """

    logger = logging.getLogger(request.node.name)

    # Capture the scenario's own deployment name before renaming it. trtllm
    # model resolution tells agg from disagg by this name, and checking the
    # renamed value ("fault-tolerance-test") would always pick the disagg worker.
    original_deployment_name = scenario.deployment.name
    scenario.deployment.name = "fault-tolerance-test"

    if image:
        scenario.deployment.set_image(image)

    if scenario.model:
        scenario.deployment.set_model(scenario.model)

    model = _resolve_scenario_model(scenario, original_deployment_name)

    scenario.deployment.set_logging(True, "info")

    async with ManagedDeployment(
        namespace=namespace,
        log_dir=request.node.name,
        deployment_spec=scenario.deployment,
        skip_service_restart=skip_service_restart,
    ) as deployment:
        # Populate shared context for validation
        validation_context["deployment"] = deployment
        validation_context["namespace"] = namespace

        with _clients(
            logger,
            request.node.name,
            scenario.deployment,
            namespace,
            model,
            scenario.load,  # Pass entire Load config object
        ) as client_procs:
            # Inject failures and capture which pods were affected
            affected_pods = await _inject_failures(
                scenario.failures, logger, deployment
            )
            logger.info(f"Affected pods during test: {affected_pods}")

            if scenario.load.continuous_load:
                _terminate_client_processes(client_procs, logger)