common.py 5.49 KB
Newer Older
1
2
3
4
5
# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0

"""Common base classes and utilities for engine tests (vLLM, TRT-LLM, etc.)"""

6
import dataclasses
7
import logging
8
import os
Alec's avatar
Alec committed
9
from collections.abc import Mapping
10
from copy import deepcopy
11
from typing import Any, Dict, Optional
12

Alec's avatar
Alec committed
13
14
import pytest

15
from dynamo.common.utils.paths import WORKSPACE_DIR
16
from tests.serve.conftest import ServicePorts
17
from tests.utils.client import send_request
18
from tests.utils.constants import DefaultPort
19
from tests.utils.engine_process import EngineConfig, EngineProcess
20

21
# Default timeout value for engine tests. Not referenced by the code visible in
# this module -- presumably seconds and consumed by importers; TODO confirm.
DEFAULT_TIMEOUT = 10

# Path of the serve test directory, built by joining "tests/serve" onto WORKSPACE_DIR.
SERVE_TEST_DIR = os.path.join(WORKSPACE_DIR, "tests/serve")
24
25


26
27
28
def _remap_default_port(port: int, system_port1: int, system_port2: int) -> int:
    """Translate a DefaultPort system port into its per-test replacement.

    Args:
        port: Port number as declared on the payload.
        system_port1: Per-test port that replaces ``DefaultPort.SYSTEM1``.
        system_port2: Per-test port that replaces ``DefaultPort.SYSTEM2``.

    Returns:
        ``system_port1`` / ``system_port2`` when ``port`` equals the matching
        ``DefaultPort`` value; otherwise ``port`` unchanged.
    """
    # SYSTEM1 is checked first so it wins if both defaults ever share a value.
    if port == DefaultPort.SYSTEM1.value:
        return system_port1
    if port == DefaultPort.SYSTEM2.value:
        return system_port2
    return port


def run_serve_deployment(
    config: EngineConfig,
    request: Any,
    *,
    ports: ServicePorts | None = None,  # pass `dynamo_dynamic_ports` here
    extra_env: Optional[Dict[str, str]] = None,
) -> None:
    """Run a standard serve deployment test for any EngineConfig.

    - Launches the engine via EngineProcess.from_script
    - Deep-copies each configured payload, injects the model when the payload
      supports it, and rewrites its ports to the per-test ports
    - Sends every payload ``repeat_count`` times and validates each response
      through ``EngineProcess.check_response``

    Args:
        config: Engine configuration; must carry a non-empty
            ``request_payloads``.
        request: The pytest request object; ``request.node.name`` names the
            logger and the object is forwarded to ``EngineProcess.from_script``.
        ports: Optional per-test ports. When given, they are exported through
            ``DYN_HTTP_PORT`` / ``DYN_SYSTEM_PORT`` / ``DYN_SYSTEM_PORT1`` /
            ``DYN_SYSTEM_PORT2`` and override ``config.frontend_port``.
        extra_env: Optional extra environment variables for the engine
            process. Values set from ``ports`` take precedence on key clashes.

    Raises:
        AssertionError: If ``config.request_payloads`` is ``None`` or empty.
    """
    logger = logging.getLogger(request.node.name)
    logger.info("Starting %s test_deployment", config.name)

    assert (
        config.request_payloads is not None and len(config.request_payloads) > 0
    ), "request_payloads must be provided on EngineConfig"

    logger.info("Using model: %s", config.model)
    logger.info("Script: %s", config.script_name)

    # Work on a private copy so the caller's extra_env mapping is never mutated.
    merged_env: dict[str, str] = dict(extra_env) if extra_env else {}

    if ports is not None:
        dynamic_frontend_port = int(ports.frontend_port)
        dynamic_system_port1 = int(ports.system_port1)
        dynamic_system_port2 = int(ports.system_port2)
        # The environments are used by the bash scripts to set the ports.
        merged_env.update(
            {
                "DYN_HTTP_PORT": str(dynamic_frontend_port),
                # Alias for PORT1 (many scripts only read this).
                "DYN_SYSTEM_PORT": str(dynamic_system_port1),
                "DYN_SYSTEM_PORT1": str(dynamic_system_port1),
                "DYN_SYSTEM_PORT2": str(dynamic_system_port2),
            }
        )
        # Ensure EngineProcess health checks hit the correct frontend port.
        # dataclasses.replace returns a new config; the caller's is untouched.
        config = dataclasses.replace(config, frontend_port=dynamic_frontend_port)
    else:
        # Backward compat: infer from config/extra_env if no explicit ports are passed.
        dynamic_frontend_port = int(config.frontend_port)
        dynamic_system_port1 = int(
            merged_env.get("DYN_SYSTEM_PORT1")
            or merged_env.get("DYN_SYSTEM_PORT")
            or DefaultPort.SYSTEM1.value
        )
        dynamic_system_port2 = int(
            merged_env.get("DYN_SYSTEM_PORT2") or DefaultPort.SYSTEM2.value
        )

    with EngineProcess.from_script(
        config, request, extra_env=merged_env
    ) as server_process:
        for _payload in config.request_payloads:
            logger.info("TESTING: Payload: %s", _payload.__class__.__name__)

            # Make a per-iteration copy so tests can safely override ports/fields
            # without mutating shared config instances across parametrized cases.
            payload = deepcopy(_payload)

            # Inject the model when the payload type supports it.
            if hasattr(payload, "with_model"):
                payload = payload.with_model(config.model)

            # Default behavior: requests go to the frontend port, except metrics which target
            # worker system ports (mapped from DefaultPort -> per-test ports).
            if getattr(payload, "endpoint", "") == "/metrics":
                payload.port = _remap_default_port(
                    payload.port, dynamic_system_port1, dynamic_system_port2
                )
            else:
                payload.port = dynamic_frontend_port

            # Optional extra system ports for specialized payloads (e.g. LoRA control-plane APIs).
            # BasePayload always defines `system_ports` (usually empty); map defaults
            # (SYSTEM_PORT1/2) to per-test system ports when present.
            if payload.system_ports:
                payload.system_ports = [
                    _remap_default_port(
                        port, dynamic_system_port1, dynamic_system_port2
                    )
                    for port in payload.system_ports
                ]

            for _ in range(payload.repeat_count):
                response = send_request(
                    url=payload.url(),
                    payload=payload.body,
                    timeout=payload.timeout,
                    method=payload.method,
                )
                server_process.check_response(payload, response)
Alec's avatar
Alec committed
127
128
129
130
131
132
133
134
135
136
137
138
139


def params_with_model_mark(configs: Mapping[str, EngineConfig]):
    """Build one ``pytest.param`` per entry of a name -> EngineConfig mapping.

    Every param holds the config name as its value. Its marks are the config's
    own ``marks`` (an empty list when the attribute is absent) followed by
    ``pytest.mark.model(cfg.model)``.

    This enables simple model collection after pytest filtering.
    """
    return [
        pytest.param(
            name,
            marks=[
                *getattr(engine_cfg, "marks", []),
                pytest.mark.model(engine_cfg.model),
            ],
        )
        for name, engine_cfg in configs.items()
    ]