# SPDX-FileCopyrightText: Copyright (c) 2025-2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0

"""Common base classes and utilities for engine tests (vLLM, TRT-LLM, etc.)"""

import dataclasses
7
import logging
8
import os
9
import time
from collections.abc import Mapping
11
from copy import deepcopy
12
from typing import Any, Dict, Optional
13

import pytest

16
from dynamo.common.utils.paths import WORKSPACE_DIR
17
from tests.conftest import ServicePorts
18
from tests.utils.client import send_request
19
from tests.utils.constants import DefaultPort
20
from tests.utils.engine_process import EngineConfig, EngineProcess
21
from tests.utils.port_utils import allocate_port, deallocate_port
22

23
# Default timeout for engine tests. NOTE(review): the unit is not visible in
# this file — presumably seconds; confirm against the callers.
DEFAULT_TIMEOUT = 10

# Location of the serve test suite, resolved against the workspace root.
SERVE_TEST_DIR = os.path.join(WORKSPACE_DIR, "tests/serve")
26
27


28
29
30
def _xdist_stagger_seconds(worker_id: str, seconds_per_worker: int = 15) -> int:
    """Return the startup delay, in seconds, for a pytest-xdist worker id.

    An id of the form ``gw<N>`` with ``N > 0`` is delayed by
    ``N * seconds_per_worker``. Any other value (including the empty string
    used when the variable is unset) yields 0.
    """
    if not worker_id.startswith("gw"):
        return 0
    try:
        worker_num = int(worker_id.removeprefix("gw"))
    except ValueError:
        # Unexpected worker naming: skip the stagger rather than fail the test.
        return 0
    return worker_num * seconds_per_worker if worker_num > 0 else 0


def _export_dynamic_ports(
    env: dict[str, str], ports: ServicePorts
) -> tuple[int, list[int]]:
    """Write the per-test ports from *ports* into *env* (mutated in place).

    Returns:
        ``(frontend_port, system_ports)`` as plain ints.
    """
    frontend_port = int(ports.frontend_port)
    system_ports = [int(p) for p in ports.system_ports]

    # The environments are used by the bash scripts to set the ports.
    env["DYN_HTTP_PORT"] = str(frontend_port)

    if system_ports:
        # DYN_SYSTEM_PORT is an alias for PORT1 (many scripts only read this).
        env["DYN_SYSTEM_PORT"] = str(system_ports[0])
        for idx, port in enumerate(system_ports, start=1):
            env[f"DYN_SYSTEM_PORT{idx}"] = str(port)
    else:
        # If no system ports are provided, explicitly ensure we don't pass any
        # stale DYN_SYSTEM_PORT / DYN_SYSTEM_PORT<N> values via extra_env.
        prefix = "DYN_SYSTEM_PORT"
        stale_keys = [
            key
            for key in env
            if key.startswith(prefix)
            and (key == prefix or key.removeprefix(prefix).isdigit())
        ]
        for key in stale_keys:
            del env[key]

    # Unique ZMQ port for vLLM KV event publishing (avoids xdist collisions).
    if ports.kv_event_port:
        env["DYN_VLLM_KV_EVENT_PORT"] = str(ports.kv_event_port)

    return frontend_port, system_ports


def _resolve_system_port(
    port: int,
    system_ports: list[int],
    *,
    subject: str,
    payload_name: str,
) -> int:
    """Map a ``DefaultPort.SYSTEM1/2`` placeholder to the per-test system port.

    Ports that are not one of the two placeholders are returned unchanged.

    Raises:
        RuntimeError: If the placeholder refers to a system port slot that
            *system_ports* does not contain.
    """
    placeholders = (
        (DefaultPort.SYSTEM1.value, 0, "SYSTEM_PORT1"),
        (DefaultPort.SYSTEM2.value, 1, "SYSTEM_PORT2"),
    )
    for placeholder, index, label in placeholders:
        if port != placeholder:
            continue
        if index >= len(system_ports):
            shortfall = (
                "only 1 system port was provided"
                if system_ports
                else "no system ports were provided"
            )
            raise RuntimeError(
                f"{subject} {label} but {shortfall} "
                f"(payload={payload_name})"
            )
        return system_ports[index]
    return port


def run_serve_deployment(
    config: EngineConfig,
    request: Any,
    *,
    ports: ServicePorts | None = None,  # pass `dynamo_dynamic_ports` here
    extra_env: Optional[Dict[str, str]] = None,
) -> None:
    """Run a standard serve deployment test for any EngineConfig.

    - Launches the engine via EngineProcess.from_script
    - For each configured request payload: deep-copies it, injects the model
      (when the payload exposes ``with_model``) and rewrites its ports to the
      per-test ports
    - Sends each request ``repeat_count`` times, validating every response via
      ``check_response``, then calls ``final_validation`` when the payload
      defines it

    Args:
        config: Engine configuration; ``request_payloads`` must be non-empty.
        request: Test request object (the pytest ``request`` fixture in
            practice). ``request.node.name`` names the logger and the object
            is forwarded to ``EngineProcess.from_script``.
        ports: Per-test ports. When given, they are exported to the launch
            environment (``DYN_HTTP_PORT``, ``DYN_SYSTEM_PORT*``,
            ``DYN_VLLM_KV_EVENT_PORT``) and override ``config.frontend_port``.
            When omitted, ports are inferred from ``config`` / ``extra_env``
            with ``DefaultPort`` fallbacks.
        extra_env: Extra environment variables for the engine script. The
            mapping is copied, never mutated.

    Raises:
        AssertionError: If ``config.request_payloads`` is missing or empty.
        RuntimeError: If a payload targets a system port placeholder for which
            no per-test system port is available.
    """
    logger = logging.getLogger(request.node.name)
    logger.info("Starting %s test_deployment", config.name)

    assert (
        config.request_payloads is not None and len(config.request_payloads) > 0
    ), "request_payloads must be provided on EngineConfig"

    logger.info("Using model: %s", config.model)
    logger.info("Script: %s", config.script_name)

    # Work on a private copy so the caller's mapping is never modified.
    merged_env: dict[str, str] = dict(extra_env) if extra_env else {}

    # Stagger engine startup under xdist to avoid vLLM profiling race
    # (vLLM bug #10643: concurrent profilers miscount each other's memory).
    worker_id = os.environ.get("PYTEST_XDIST_WORKER", "")
    stagger_s = _xdist_stagger_seconds(worker_id)
    if stagger_s > 0:
        logger.info("Staggering startup by %ds (xdist %s)", stagger_s, worker_id)
        time.sleep(stagger_s)

    if ports is not None:
        dynamic_frontend_port, dynamic_system_ports = _export_dynamic_ports(
            merged_env, ports
        )
        # Ensure EngineProcess health checks hit the correct frontend port.
        config = dataclasses.replace(config, frontend_port=dynamic_frontend_port)
    else:
        # Backward compat: infer from config/extra_env if no explicit ports are passed.
        dynamic_frontend_port = int(config.frontend_port)
        # Preserve the historical two-port behavior in this branch. Tests that
        # need tighter control should pass `ports=...` to avoid default port
        # collisions under xdist.
        dynamic_system_ports = [
            int(
                merged_env.get("DYN_SYSTEM_PORT1")
                or merged_env.get("DYN_SYSTEM_PORT")
                or DefaultPort.SYSTEM1.value
            ),
            int(merged_env.get("DYN_SYSTEM_PORT2") or DefaultPort.SYSTEM2.value),
        ]

    disagg_bootstrap_port: int | None = None
    try:
        # Disagg scripts need a unique bootstrap port so parallel runs don't collide.
        # Allocated inside the `try` so the `finally` below always releases it.
        if config.script_name and "disagg" in config.script_name:
            disagg_bootstrap_port = allocate_port(12000)
            merged_env["DYN_DISAGG_BOOTSTRAP_PORT"] = str(disagg_bootstrap_port)

        with EngineProcess.from_script(
            config, request, extra_env=merged_env
        ) as server_process:
            for _payload in config.request_payloads:
                logger.info("TESTING: Payload: %s", _payload.__class__.__name__)

                # Make a per-iteration copy so tests can safely override ports/fields
                # without mutating shared config instances across parametrized cases.
                payload = deepcopy(_payload)
                # inject model
                if hasattr(payload, "with_model"):
                    payload = payload.with_model(config.model)
                payload_name = payload.__class__.__name__

                # Default behavior: requests go to the frontend port, except metrics which target
                # worker system ports (mapped from DefaultPort -> per-test ports).
                if getattr(payload, "endpoint", "") == "/metrics":
                    payload.port = _resolve_system_port(
                        payload.port,
                        dynamic_system_ports,
                        subject="Payload targets",
                        payload_name=payload_name,
                    )
                else:
                    payload.port = dynamic_frontend_port

                # Optional extra system ports for specialized payloads (e.g. LoRA control-plane APIs).
                # BasePayload always defines `system_ports` (usually empty); map defaults
                # (SYSTEM_PORT1/2) to per-test system ports when present.
                if payload.system_ports:
                    payload.system_ports = [
                        _resolve_system_port(
                            p,
                            dynamic_system_ports,
                            subject="Payload.system_ports includes",
                            payload_name=payload_name,
                        )
                        for p in payload.system_ports
                    ]

                for _ in range(payload.repeat_count):
                    response = send_request(
                        url=payload.url(),
                        payload=payload.body,
                        timeout=payload.timeout,
                        method=payload.method,
                        stream=payload.http_stream,
                    )
                    server_process.check_response(payload, response)

                # Call final_validation if the payload has one (e.g., CachedTokensChatPayload)
                if hasattr(payload, "final_validation"):
                    payload.final_validation()
    finally:
        if disagg_bootstrap_port is not None:
            deallocate_port(disagg_bootstrap_port)
193

Alec's avatar
Alec committed
194
195
196
197
198
199
200
201
202
203
204
205

def params_with_model_mark(configs: Mapping[str, EngineConfig]):
    """Build one ``pytest.param`` per config entry, tagged with its model.

    Each param carries the config's name as its value. Its marks are the
    config's own ``marks`` (if the attribute exists) followed by a
    ``pytest.mark.model(<cfg.model>)`` marker, which enables simple model
    collection after pytest filtering.

    Args:
        configs: Mapping of config name to engine config.

    Returns:
        A list of ``pytest.param`` objects, in the mapping's iteration order.
    """
    return [
        pytest.param(
            name,
            marks=[
                *getattr(engine_cfg, "marks", []),
                pytest.mark.model(engine_cfg.model),
            ],
        )
        for name, engine_cfg in configs.items()
    ]