# test_deployment.py
# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0

import logging
import multiprocessing
import time
from contextlib import contextmanager

import pytest

11
12
13
from tests.fault_tolerance.deploy.client_factory import get_client_function
from tests.fault_tolerance.deploy.parse_factory import parse_test_results
from tests.fault_tolerance.deploy.scenarios import Load, scenarios
14
15
16
17
from tests.utils.managed_deployment import ManagedDeployment


@pytest.fixture(params=scenarios.keys())
def scenario(request, client_type):
    """Return a private copy of the parametrized scenario.

    The scenario is always deep-copied out of the module-level ``scenarios``
    registry. The test mutates the returned object in place (deployment name,
    image, model, logging), and without a copy those changes would leak into
    the shared registry entry and into every later test that uses it.

    If ``client_type`` is not None (the --client-type command line option),
    it overrides the scenario's default client type. The retry budget is then
    adjusted to suit that client:

    * ``"legacy"`` retries per request, so ``max_retries`` is capped at 1.
    * ``"aiperf"`` retries the whole test run, so ``max_retries`` is raised
      to at least 3.
    """
    import copy

    # Copy unconditionally: registry objects are shared between
    # parametrizations, and test_fault_scenario mutates scenario.deployment.
    scenario_obj = copy.deepcopy(scenarios[request.param])

    if client_type is None:
        return scenario_obj

    scenario_obj.load.client_type = client_type

    # Adjust retry settings based on client type
    if client_type == "legacy":
        # Legacy uses per-request retries
        if scenario_obj.load.max_retries > 1:
            scenario_obj.load.max_retries = 1
    elif client_type == "aiperf":
        # AI-Perf uses full test retries
        if scenario_obj.load.max_retries < 3:
            scenario_obj.load.max_retries = 3

    return scenario_obj


@contextmanager
def _clients(
    logger,
    request,
    deployment_spec,
    namespace,
    model,
    load_config: Load,
):
    """Run load-generating client processes for the duration of a ``with`` block.

    The client implementation is selected by ``get_client_function`` from
    ``load_config.client_type``. ``load_config.clients`` copies of it are
    started, each in its own process created with the "spawn" start method.

    Args:
        logger: Logger instance
        request: Pytest request fixture; ``request.node.name`` is passed to
            every client.
        deployment_spec: Deployment specification
        namespace: Kubernetes namespace
        model: Model name to test
        load_config: Load configuration object containing client settings

    Yields:
        The list of started ``multiprocessing.Process`` objects, in client
        index order.

    When the ``with`` block finishes normally, every client is joined (waited
    for). If starting the clients or the ``with`` block raises, clients that
    are still alive are terminated before being joined, so a failing test
    neither hangs on nor leaks its load generators. The exception is then
    re-raised.
    """
    # Get appropriate client function based on configuration
    client_func = get_client_function(load_config.client_type)

    logger.info(
        f"Starting {load_config.clients} clients using '{load_config.client_type}' client"
    )

    # Determine retry_delay_or_rate based on client type
    if load_config.client_type == "legacy":
        # Legacy client uses max_request_rate for rate limiting
        retry_delay_or_rate = load_config.max_request_rate
    else:
        # AI-Perf client uses retry_delay between attempts (default 5s)
        retry_delay_or_rate = 5

    ctx = multiprocessing.get_context("spawn")
    procs = []
    try:
        for i in range(load_config.clients):
            proc = ctx.Process(
                target=client_func,
                args=(
                    deployment_spec,
                    namespace,
                    model,
                    request.node.name,
                    i,
                    load_config.requests_per_client,
                    load_config.input_token_length,
                    load_config.output_token_length,
                    load_config.max_retries,
                    retry_delay_or_rate,
                ),
            )
            proc.start()
            # Track only processes that actually started: join() on a process
            # that was never started raises.
            procs.append(proc)
            logger.debug(f"Started client {i} (PID: {proc.pid})")

        yield procs
    except BaseException:
        # Startup or the body of the with-block failed: stop the load
        # generators instead of waiting for them to run to completion.
        for proc in procs:
            if proc.is_alive():
                logger.warning(f"Terminating client {proc} after failure")
                proc.terminate()
        raise
    finally:
        for proc in procs:
            logger.debug(f"{proc} waiting for join")
            proc.join()
            logger.debug(f"{proc} joined")


def _inject_failures(failures, logger, deployment: ManagedDeployment):  # noqa: F811
    """Inject each failure in ``failures`` into ``deployment``, in order.

    Each failure is handled as follows:

    1. Sleep ``failure.time`` seconds. The sleeps run back to back, so each
       delay is measured from the end of the previous failure's injection.
    2. Look up the pods of ``failure.pod_name``.
    3. Act on ``failure.replicas`` pods, or on every pod when ``replicas`` is
       falsy. Pod indices wrap around modulo the pod count.

    The action depends on ``failure.command``:

    * ``"delete_pod"``: fetch the pod's logs via ``deployment.get_pod_logs``
      (tagged ``".before_delete"``), then force-delete the pod.
    * Anything else: call ``kill(failure.signal)`` on every process in the
      pod whose command contains ``failure.command`` as a substring.

    A failure whose pod group currently has no pods is skipped with a
    warning.
    """
    for failure in failures:
        time.sleep(failure.time)

        pods = deployment.get_pods(failure.pod_name)[failure.pod_name]
        if not pods:
            # Make the skip visible: otherwise the scenario appears to
            # survive a failure that was never actually injected.
            logger.warning(
                f"No pods found for {failure.pod_name}; skipping failure: {failure}"
            )
            continue

        num_pods = len(pods)
        replicas = failure.replicas or num_pods

        logger.info(f"Injecting failure for: {failure}")

        for x in range(replicas):
            pod = pods[x % num_pods]

            if failure.command == "delete_pod":
                deployment.get_pod_logs(failure.pod_name, pod, ".before_delete")
                pod.delete(force=True)
                continue

            for process in deployment.get_processes(pod):
                if failure.command in process.command:
                    logger.info(
                        f"Terminating {failure.pod_name} Pid {process.pid} Command {process.command}"
                    )
                    process.kill(failure.signal)


# Test node names recorded by the per-test ``results_table`` fixture; the
# session-scoped ``results_summary`` fixture parses them all together at the end.
global_result_list = list()


@pytest.fixture(autouse=True)
def results_table(request, scenario):  # noqa: F811
    """Parse and display results for the test that just ran.

    This is an autouse fixture, so the code after ``yield`` runs as teardown
    for each test. ``parse_test_results`` auto-detects the result type
    (AI-Perf or legacy) in the test's log path (``request.node.name``) and
    applies the scenario's SLA.

    Parsing errors are logged together with their traceback rather than
    failing the test. The test name is always appended to
    ``global_result_list`` for the session summary.
    """
    yield

    # Use factory to auto-detect and parse results
    try:
        parse_test_results(
            log_dir=None,
            log_paths=[request.node.name],
            tablefmt="fancy_grid",
            sla=scenario.load.sla,
            # NOTE: force_parser=scenario.load.client_type could pin the parser
            # instead of relying on auto-detection.
        )
    except Exception as e:
        # Reporting is best effort; logging.exception keeps the traceback so
        # parser failures remain debuggable.
        logging.exception(f"Failed to parse results for {request.node.name}: {e}")

    global_result_list.append(request.node.name)


@pytest.fixture(autouse=True, scope="session")
def results_summary():
    """Parse and display one combined results table for the session.

    This is a session-scoped autouse fixture, so the code after ``yield``
    runs once at the end of the session. The test names collected in
    ``global_result_list`` are passed to ``parse_test_results``, which
    auto-detects each result type and uses the matching parser.

    If no results were collected, only an info message is logged. Parsing
    errors are logged with their traceback instead of failing session
    teardown.
    """
    yield

    if not global_result_list:
        logging.info("No test results to summarize")
        return

    # Use factory to auto-detect and parse combined results
    try:
        parse_test_results(
            log_dir=None,
            log_paths=global_result_list,
            tablefmt="fancy_grid",
        )
    except Exception as e:
        # Summary is best-effort reporting; keep the traceback for debugging.
        logging.exception(f"Failed to parse combined results: {e}")


def _default_model(scenario, deployment_name):
    """Return the model configured on the scenario's main worker, or None.

    The worker service is chosen by ``scenario.backend``:

    * ``"vllm"``: ``VllmDecodeWorker``.
    * ``"sglang"``: ``decode``.
    * ``"trtllm"``: ``TRTLLMWorker`` when ``deployment_name`` contains "agg"
      but not "disagg", otherwise ``TRTLLMDecodeWorker``.

    Unknown backends, and lookups raising KeyError or AttributeError, yield
    None.
    """
    try:
        if scenario.backend == "vllm":
            return scenario.deployment["VllmDecodeWorker"].model
        if scenario.backend == "sglang":
            return scenario.deployment["decode"].model
        if scenario.backend == "trtllm":
            if "agg" in deployment_name and "disagg" not in deployment_name:
                return scenario.deployment["TRTLLMWorker"].model
            return scenario.deployment["TRTLLMDecodeWorker"].model
    except (KeyError, AttributeError):
        # Best effort: the caller falls back to a default model.
        pass
    return None


@pytest.mark.e2e
@pytest.mark.slow
@pytest.mark.filterwarnings("ignore::DeprecationWarning")
async def test_fault_scenario(
    scenario,  # noqa: F811
    request,
    image,
    namespace,
):
    """
    Test dynamo serve deployments with injected failures.

    The scenario's deployment is renamed to "fault-tolerance-test" and,
    optionally, given the command-line image and the scenario's model. It is
    then deployed with logging at "info" level. While the client processes
    generate load, the scenario's failures are injected. The clients are
    joined before the deployment context exits.
    """

    logger = logging.getLogger(request.node.name)

    # Capture the spec's own name before renaming it: the trtllm model lookup
    # infers agg vs. disagg from this name, and after the rename it would
    # always be "fault-tolerance-test", which never contains "agg".
    original_deployment_name = scenario.deployment.name or ""
    scenario.deployment.name = "fault-tolerance-test"

    if image:
        scenario.deployment.set_image(image)

    if scenario.model:
        scenario.deployment.set_model(scenario.model)
        model = scenario.model
    else:
        # Get model from the appropriate worker based on backend
        model = _default_model(scenario, original_deployment_name)
    # Fallback to default if still None
    model = model or "Qwen/Qwen3-0.6B"

    scenario.deployment.set_logging(True, "info")

    async with ManagedDeployment(
        namespace=namespace,
        log_dir=request.node.name,
        deployment_spec=scenario.deployment,
    ) as deployment:
        with _clients(
            logger,
            request,
            scenario.deployment,
            namespace,
            model,
            scenario.load,  # Pass entire Load config object
        ):
            _inject_failures(scenario.failures, logger, deployment)