# SPDX-FileCopyrightText: Copyright (c) 2025-2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0

"""Common base classes and utilities for engine tests (vLLM, TRT-LLM, etc.)"""

import dataclasses
7
import logging
8
import os
9
import time
from collections.abc import Mapping
11
from copy import deepcopy
12
from typing import Any, Dict, Optional
import pytest

16
from dynamo.common.utils.paths import WORKSPACE_DIR
17
from tests.conftest import ServicePorts
18
from tests.utils.client import send_request
19
from tests.utils.constants import DefaultPort
20
from tests.utils.engine_process import EngineConfig, EngineProcess
21
from tests.utils.port_utils import allocate_port, deallocate_port
22

23
# Shared default timeout for engine tests (units presumably seconds — confirm at
# the call sites); it is not referenced elsewhere in this module.
DEFAULT_TIMEOUT: int = 10

# Root of the serve test suite inside the workspace checkout.
SERVE_TEST_DIR: str = os.path.join(WORKSPACE_DIR, "tests/serve")


def _is_system_port_env_key(key: str) -> bool:
    """Return True for ``DYN_SYSTEM_PORT`` and ``DYN_SYSTEM_PORT<digits>`` keys."""
    prefix = "DYN_SYSTEM_PORT"
    if not key.startswith(prefix):
        return False
    suffix = key[len(prefix) :]
    return not suffix or suffix.isdigit()


def _set_system_port_env(env: dict[str, str], system_ports: list[int]) -> None:
    """Replace every ``DYN_SYSTEM_PORT*`` entry in ``env`` with ``system_ports``.

    Stale values inherited from ``extra_env`` are dropped first. This way a
    launch script never sees a port index that was not allocated for this test
    (e.g. ``DYN_SYSTEM_PORT3`` when only two ports were provided). With an empty
    ``system_ports`` no ``DYN_SYSTEM_PORT*`` variable is passed at all.
    """
    for key in [k for k in env if _is_system_port_env_key(k)]:
        del env[key]
    if not system_ports:
        return
    # Alias for PORT1 (many scripts only read this).
    env["DYN_SYSTEM_PORT"] = str(system_ports[0])
    for idx, port in enumerate(system_ports, start=1):
        env[f"DYN_SYSTEM_PORT{idx}"] = str(port)


def _stagger_xdist_startup(logger: logging.Logger) -> None:
    """Sleep ``15 * N`` seconds on xdist worker ``gwN``; no-op on gw0 / serial."""
    # Stagger engine startup under xdist to avoid vLLM profiling race
    # (vLLM bug #10643: concurrent profilers miscount each other's memory).
    worker_id = os.environ.get("PYTEST_XDIST_WORKER", "")
    if not worker_id.startswith("gw"):
        return
    worker_num = int(worker_id.removeprefix("gw"))
    if worker_num > 0:
        stagger_s = worker_num * 15
        logger.info("Staggering startup by %ds (xdist %s)", stagger_s, worker_id)
        time.sleep(stagger_s)


def _remap_default_system_port(
    port: int, system_ports: list[int], *, subject: str, payload_name: str
) -> int:
    """Map a ``DefaultPort.SYSTEM1``/``SYSTEM2`` placeholder to this test's port.

    Any other port value is returned unchanged.

    Raises:
        RuntimeError: if the placeholder's index has no entry in ``system_ports``.
    """
    defaults = (DefaultPort.SYSTEM1.value, DefaultPort.SYSTEM2.value)
    for idx, default in enumerate(defaults):
        if port != default:
            continue
        if idx >= len(system_ports):
            provided = (
                f"only {len(system_ports)} system port was provided"
                if system_ports
                else "no system ports were provided"
            )
            raise RuntimeError(
                f"{subject} SYSTEM_PORT{idx + 1} but {provided} "
                f"(payload={payload_name})"
            )
        return system_ports[idx]
    return port


def _prepare_payload(
    template: Any, model: str, frontend_port: int, system_ports: list[int]
) -> Any:
    """Return a per-test copy of ``template`` with model and ports filled in.

    Working on a deep copy keeps parametrized cases from mutating shared config
    instances.
    """
    payload = deepcopy(template)
    # inject model
    if hasattr(payload, "with_model"):
        payload = payload.with_model(model)

    payload_name = payload.__class__.__name__

    # Default behavior: requests go to the frontend port, except metrics which target
    # worker system ports (mapped from DefaultPort -> per-test ports).
    if getattr(payload, "endpoint", "") == "/metrics":
        payload.port = _remap_default_system_port(
            payload.port,
            system_ports,
            subject="Payload targets",
            payload_name=payload_name,
        )
    else:
        payload.port = frontend_port

    # Optional extra system ports for specialized payloads (e.g. LoRA control-plane APIs).
    # BasePayload always defines `system_ports` (usually empty); map defaults
    # (SYSTEM_PORT1/2) to per-test system ports when present.
    if payload.system_ports:
        payload.system_ports = [
            _remap_default_system_port(
                p,
                system_ports,
                subject="Payload.system_ports includes",
                payload_name=payload_name,
            )
            for p in payload.system_ports
        ]
    return payload


def run_serve_deployment(
    config: EngineConfig,
    request: Any,
    *,
    ports: ServicePorts | None = None,  # pass `dynamo_dynamic_ports` here
    extra_env: Optional[Dict[str, str]] = None,
) -> None:
    """Run a standard serve deployment test for any EngineConfig.

    - Launches the engine via ``EngineProcess.from_script``.
    - Sends every payload in ``config.request_payloads`` (``repeat_count``
      times each) and validates each response with ``check_response``.
    - Runs a payload's ``final_validation`` hook when it defines one.

    Args:
        config: Engine/launch configuration. ``request_payloads`` must be
            non-empty.
        request: pytest request object. ``request.node`` supplies the logger
            name and the ``requested_vllm_kv_cache_bytes`` marker.
        ports: Per-test ports. When given, they are exported to the launch
            script as ``DYN_HTTP_PORT``, ``DYN_SYSTEM_PORT*`` and
            ``DYN_VLLM_KV_EVENT_PORT``. Any ``DYN_SYSTEM_PORT*`` values from
            ``extra_env`` are replaced. The frontend port also overrides
            ``config.frontend_port``. When omitted, the frontend port comes
            from ``config``. System ports then come from ``extra_env``, or
            ``DefaultPort`` if ``extra_env`` does not set them.
        extra_env: Extra environment variables for the launch script. The
            caller's dict is not mutated.

    Raises:
        AssertionError: if ``config.request_payloads`` is empty.
        RuntimeError: if a payload targets a system port that was not provided.
    """
    logger = logging.getLogger(request.node.name)
    logger.info("Starting %s test_deployment", config.name)

    assert (
        config.request_payloads is not None and len(config.request_payloads) > 0
    ), "request_payloads must be provided on EngineConfig"

    logger.info("Using model: %s", config.model)
    logger.info("Script: %s", config.script_name)

    # Work on a private copy so the caller's extra_env is never mutated.
    merged_env: dict[str, str] = dict(extra_env or {})

    # In serial mode (no parallel scheduler), pass the marker's KV cache budget
    # so the launch script's small default doesn't starve larger models.
    # The parallel scheduler already sets this env var per-test.
    if "_PROFILE_OVERRIDE_VLLM_KV_CACHE_BYTES" not in os.environ:
        kv_mark = request.node.get_closest_marker("requested_vllm_kv_cache_bytes")
        if kv_mark:
            merged_env.setdefault(
                "_PROFILE_OVERRIDE_VLLM_KV_CACHE_BYTES", str(int(kv_mark.args[0]))
            )

    _stagger_xdist_startup(logger)

    if ports is not None:
        dynamic_frontend_port = int(ports.frontend_port)
        dynamic_system_ports = [int(p) for p in ports.system_ports]

        # The environments are used by the bash scripts to set the ports.
        merged_env["DYN_HTTP_PORT"] = str(dynamic_frontend_port)
        _set_system_port_env(merged_env, dynamic_system_ports)

        # Unique ZMQ port for vLLM KV event publishing (avoids xdist collisions).
        if ports.kv_event_port:
            merged_env["DYN_VLLM_KV_EVENT_PORT"] = str(ports.kv_event_port)

        # Ensure EngineProcess health checks hit the correct frontend port.
        config = dataclasses.replace(config, frontend_port=dynamic_frontend_port)
    else:
        # Backward compat: the frontend port comes from config and the system
        # ports from extra_env (then DefaultPort) when no explicit ports are passed.
        dynamic_frontend_port = int(config.frontend_port)
        # Preserve the historical two-port behavior in this branch. Tests that
        # need tighter control should pass `ports=...` to avoid default port
        # collisions under xdist.
        dynamic_system_ports = [
            int(
                merged_env.get("DYN_SYSTEM_PORT1")
                or merged_env.get("DYN_SYSTEM_PORT")
                or DefaultPort.SYSTEM1.value
            ),
            int(merged_env.get("DYN_SYSTEM_PORT2") or DefaultPort.SYSTEM2.value),
        ]

    # Disagg scripts need a unique bootstrap port so parallel runs don't collide.
    disagg_bootstrap_port: int | None = None
    if config.script_name and "disagg" in config.script_name:
        disagg_bootstrap_port = allocate_port(12000)
        merged_env["DYN_DISAGG_BOOTSTRAP_PORT"] = str(disagg_bootstrap_port)

    try:
        with EngineProcess.from_script(
            config, request, extra_env=merged_env
        ) as server_process:
            for template in config.request_payloads:
                logger.info("TESTING: Payload: %s", template.__class__.__name__)
                payload = _prepare_payload(
                    template,
                    config.model,
                    dynamic_frontend_port,
                    dynamic_system_ports,
                )

                for _ in range(payload.repeat_count):
                    response = send_request(
                        url=payload.url(),
                        payload=payload.body,
                        timeout=payload.timeout,
                        method=payload.method,
                        stream=payload.http_stream,
                    )
                    server_process.check_response(payload, response)

                # Call final_validation if the payload has one (e.g., CachedTokensChatPayload)
                if hasattr(payload, "final_validation"):
                    payload.final_validation()
    finally:
        # Return the bootstrap port to the allocator even if the test failed.
        if disagg_bootstrap_port is not None:
            deallocate_port(disagg_bootstrap_port)


def params_with_model_mark(configs: Mapping[str, EngineConfig]):
    """Build one ``pytest.param`` per config name, tagged with its model.

    Each param carries the config's own ``marks`` (empty if it has none),
    followed by ``pytest.mark.model(<cfg.model>)``. This lets the set of models
    that survive pytest filtering be collected from the remaining params.
    """
    return [
        pytest.param(
            config_name,
            marks=[*getattr(cfg, "marks", []), pytest.mark.model(cfg.model)],
        )
        for config_name, cfg in configs.items()
    ]