# SPDX-FileCopyrightText: Copyright (c) 2025-2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0

"""Common base classes and utilities for engine tests (vLLM, TRT-LLM, etc.)"""

6
import dataclasses
7
import logging
8
import os
9
import time
Alec's avatar
Alec committed
10
from collections.abc import Mapping
11
from copy import deepcopy
12
from typing import Any, Dict, Optional
13

Alec's avatar
Alec committed
14
15
import pytest

16
from dynamo.common.utils.paths import WORKSPACE_DIR
17
from tests.conftest import ServicePorts
18
from tests.utils.client import send_request
19
from tests.utils.constants import DefaultPort
20
from tests.utils.engine_process import EngineConfig, EngineProcess
21

# Module-wide default timeout value; not referenced within this module's
# visible code. NOTE(review): unit presumably seconds — confirm at call sites.
DEFAULT_TIMEOUT: int = 10

# Location of the serve-test assets, resolved relative to the workspace root.
SERVE_TEST_DIR: str = os.path.join(WORKSPACE_DIR, "tests/serve")


# Per-xdist-worker startup delay in seconds: worker gwN sleeps N times this
# value before launching its engine (see _stagger_xdist_startup).
_XDIST_STAGGER_SECONDS = 15

# Prefix of the system-port environment variables read by the launch scripts
# (DYN_SYSTEM_PORT, DYN_SYSTEM_PORT1, DYN_SYSTEM_PORT2, ...).
_SYSTEM_PORT_ENV_PREFIX = "DYN_SYSTEM_PORT"


def _stagger_xdist_startup(logger: logging.Logger) -> None:
    """Sleep proportionally to the pytest-xdist worker index before engine startup.

    Works around vLLM bug #10643 (concurrent profilers miscount each other's
    memory). Runs without xdist and worker ``gw0`` start immediately; worker
    ``gwN`` sleeps ``N * _XDIST_STAGGER_SECONDS`` seconds.
    """
    worker_id = os.environ.get("PYTEST_XDIST_WORKER", "")
    if not worker_id.startswith("gw"):
        return
    worker_num = int(worker_id.removeprefix("gw"))
    if worker_num > 0:
        stagger_s = worker_num * _XDIST_STAGGER_SECONDS
        logger.info("Staggering startup by %ds (xdist %s)", stagger_s, worker_id)
        time.sleep(stagger_s)


def _is_system_port_env_key(key: str) -> bool:
    """Return True for ``DYN_SYSTEM_PORT`` and ``DYN_SYSTEM_PORT<digits>`` keys."""
    if not key.startswith(_SYSTEM_PORT_ENV_PREFIX):
        return False
    suffix = key.removeprefix(_SYSTEM_PORT_ENV_PREFIX)
    return suffix == "" or suffix.isdigit()


def _apply_system_port_env(env: dict[str, str], system_ports: list[int]) -> None:
    """Write the per-test system ports into *env* in place.

    With no system ports, every ``DYN_SYSTEM_PORT`` / ``DYN_SYSTEM_PORT<N>``
    key (e.g. inherited from ``extra_env``) is removed so no stale value is
    passed to the launch script. Otherwise ``DYN_SYSTEM_PORT`` aliases the
    first port (many scripts only read that variable) and
    ``DYN_SYSTEM_PORT1..N`` are set in order.
    """
    if not system_ports:
        for key in [k for k in env if _is_system_port_env_key(k)]:
            del env[key]
        return
    env[_SYSTEM_PORT_ENV_PREFIX] = str(system_ports[0])
    for idx, port in enumerate(system_ports, start=1):
        env[f"{_SYSTEM_PORT_ENV_PREFIX}{idx}"] = str(port)


def _remap_default_system_port(
    port: int,
    system_ports: list[int],
    *,
    subject: str,
    payload_name: str,
) -> int:
    """Map a ``DefaultPort.SYSTEM1``/``SYSTEM2`` placeholder onto a per-test port.

    Ports that are not one of those defaults are returned unchanged.

    Raises:
        RuntimeError: if *port* is the SYSTEM_PORT<i> default but fewer than
            *i* per-test system ports are available.
    """
    defaults = (DefaultPort.SYSTEM1.value, DefaultPort.SYSTEM2.value)
    for idx, default_port in enumerate(defaults, start=1):
        if port != default_port:
            continue
        if len(system_ports) < idx:
            count = len(system_ports)
            if count == 0:
                provided = "no system ports were provided"
            else:
                noun = "port was" if count == 1 else "ports were"
                provided = f"only {count} system {noun} provided"
            raise RuntimeError(
                f"{subject} SYSTEM_PORT{idx} but {provided} "
                f"(payload={payload_name})"
            )
        return system_ports[idx - 1]
    return port


def run_serve_deployment(
    config: EngineConfig,
    request: Any,
    *,
    ports: ServicePorts | None = None,  # pass `dynamo_dynamic_ports` here
    extra_env: Optional[Dict[str, str]] = None,
) -> None:
    """Run a standard serve deployment test for any EngineConfig.

    - Launches the engine via EngineProcess.from_script
    - For each configured request payload, works on a deep copy, injects the
      model name (when the payload supports ``with_model``) and rewrites its
      ports to the per-test ports
    - Sends each payload ``repeat_count`` times, validates every response via
      the engine process, then runs the payload's ``final_validation`` hook
      if it has one

    Args:
        config: Engine configuration; ``request_payloads`` must be non-empty.
        request: The pytest ``request`` fixture; used to name the logger and
            passed through to ``EngineProcess.from_script``.
        ports: Per-test ports (e.g. ``dynamo_dynamic_ports``). When given, they
            are exported to the launch script as ``DYN_HTTP_PORT``,
            ``DYN_SYSTEM_PORT*`` and ``DYN_VLLM_KV_EVENT_PORT``, and replace
            ``config.frontend_port``. When omitted, ports are inferred from
            ``config`` / ``extra_env`` / ``DefaultPort`` (legacy behavior).
        extra_env: Extra environment variables for the launch script. The
            caller's mapping is copied, never mutated.

    Raises:
        RuntimeError: if a payload targets SYSTEM_PORT1/2 but not enough
            per-test system ports are available.
    """
    logger = logging.getLogger(request.node.name)
    logger.info("Starting %s test_deployment", config.name)

    assert (
        config.request_payloads is not None and len(config.request_payloads) > 0
    ), "request_payloads must be provided on EngineConfig"

    logger.info("Using model: %s", config.model)
    logger.info("Script: %s", config.script_name)

    # Work on a copy so the caller's extra_env is never mutated.
    merged_env: dict[str, str] = dict(extra_env) if extra_env else {}

    _stagger_xdist_startup(logger)

    dynamic_system_ports: list[int]
    if ports is not None:
        dynamic_frontend_port = int(ports.frontend_port)
        dynamic_system_ports = [int(p) for p in ports.system_ports]

        # These environment variables are used by the bash scripts to set the ports.
        merged_env["DYN_HTTP_PORT"] = str(dynamic_frontend_port)
        _apply_system_port_env(merged_env, dynamic_system_ports)

        # Unique ZMQ port for vLLM KV event publishing (avoids xdist collisions).
        if ports.kv_event_port:
            merged_env["DYN_VLLM_KV_EVENT_PORT"] = str(ports.kv_event_port)

        # Ensure EngineProcess health checks hit the correct frontend port.
        config = dataclasses.replace(config, frontend_port=dynamic_frontend_port)
    else:
        # Backward compat: infer from config/extra_env if no explicit ports are
        # passed. Preserve the historical two-port behavior here; tests that
        # need tighter control should pass `ports=...` to avoid default port
        # collisions under xdist.
        dynamic_frontend_port = int(config.frontend_port)
        dynamic_system_ports = [
            int(
                merged_env.get("DYN_SYSTEM_PORT1")
                or merged_env.get("DYN_SYSTEM_PORT")
                or DefaultPort.SYSTEM1.value
            ),
            int(merged_env.get("DYN_SYSTEM_PORT2") or DefaultPort.SYSTEM2.value),
        ]

    with EngineProcess.from_script(
        config, request, extra_env=merged_env
    ) as server_process:
        for template in config.request_payloads:
            logger.info("TESTING: Payload: %s", template.__class__.__name__)

            # Per-iteration copy so tests can safely override ports/fields
            # without mutating shared config instances across parametrized cases.
            payload = deepcopy(template)
            if hasattr(payload, "with_model"):
                payload = payload.with_model(config.model)
            payload_name = payload.__class__.__name__

            # Requests go to the frontend port, except /metrics, which targets a
            # worker system port (DefaultPort placeholders -> per-test ports).
            if getattr(payload, "endpoint", "") == "/metrics":
                payload.port = _remap_default_system_port(
                    payload.port,
                    dynamic_system_ports,
                    subject="Payload targets",
                    payload_name=payload_name,
                )
            else:
                payload.port = dynamic_frontend_port

            # Optional extra system ports for specialized payloads (e.g. LoRA
            # control-plane APIs). BasePayload always defines `system_ports`
            # (usually empty); map SYSTEM_PORT1/2 defaults to per-test ports.
            if payload.system_ports:
                payload.system_ports = [
                    _remap_default_system_port(
                        p,
                        dynamic_system_ports,
                        subject="Payload.system_ports includes",
                        payload_name=payload_name,
                    )
                    for p in payload.system_ports
                ]

            for _ in range(payload.repeat_count):
                response = send_request(
                    url=payload.url(),
                    payload=payload.body,
                    timeout=payload.timeout,
                    method=payload.method,
                    stream=payload.http_stream,
                )
                server_process.check_response(payload, response)

            # Call final_validation if the payload has one (e.g., CachedTokensChatPayload)
            if hasattr(payload, "final_validation"):
                payload.final_validation()


def params_with_model_mark(configs: Mapping[str, EngineConfig]):
    """Build one ``pytest.param`` per config name, tagged with its model.

    Each param carries the config's own ``marks`` (when the config has that
    attribute) followed by a ``pytest.mark.model(<model>)`` marker, so the
    models that survive pytest filtering can be collected from the markers.
    """
    return [
        pytest.param(
            config_name,
            marks=[*getattr(cfg, "marks", []), pytest.mark.model(cfg.model)],
        )
        for config_name, cfg in configs.items()
    ]