# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0

"""Common base classes and utilities for engine tests (vLLM, TRT-LLM, etc.)"""

import logging
import os
from collections.abc import Mapping
from typing import Any, Dict, Optional

import pytest

from tests.utils.client import send_request
from tests.utils.engine_process import EngineConfig, EngineProcess

DEFAULT_TIMEOUT = 10
17
18
19
20
21
22
23
24
25
26
27
28
29
30

def _resolve_workspace_dir() -> str:
    """Return the workspace root directory for these tests.

    Candidates are tried in order, and the first match wins:

    1. the current working directory, if it contains ``Cargo.toml``;
    2. the ``WORKSPACE_DIR`` environment variable, if set and non-empty;
    3. ``/workspace``, if that path exists;
    4. otherwise, the current working directory.
    """
    cwd = os.getcwd()
    if os.path.exists(os.path.join(cwd, "Cargo.toml")):
        return cwd

    env_dir = os.environ.get("WORKSPACE_DIR")
    if env_dir:
        return env_dir

    if os.path.exists("/workspace"):
        return "/workspace"

    return cwd


WORKSPACE_DIR = _resolve_workspace_dir()
# The ``tests/serve`` directory under WORKSPACE_DIR.
SERVE_TEST_DIR = os.path.join(WORKSPACE_DIR, "tests/serve")


def run_serve_deployment(
    config: EngineConfig,
    request: Any,
    extra_env: Optional[Dict[str, str]] = None,
) -> None:
    """Run a standard serve deployment test for any EngineConfig.

    - Launches the engine via ``EngineProcess.from_script``.
    - For each payload in ``config.request_payloads``, injects ``config.model``
      when the payload exposes a ``with_model`` method.
    - Sends each payload ``repeat_count`` times and validates every response
      with ``server_process.check_response``.

    Args:
        config: Engine configuration. Must provide a non-empty
            ``request_payloads`` collection.
        request: Test request object (presumably the pytest ``request``
            fixture). ``request.node.name`` names the logger, and the object
            is forwarded to ``EngineProcess.from_script``.
        extra_env: Optional extra environment variables, forwarded to
            ``EngineProcess.from_script``.

    Raises:
        AssertionError: If ``config.request_payloads`` is ``None`` or empty.
    """
    logger = logging.getLogger(request.node.name)
    logger.info("Starting %s test_deployment", config.name)

    # Validate before launching the engine so a misconfigured test fails fast.
    assert (
        config.request_payloads is not None and len(config.request_payloads) > 0
    ), "request_payloads must be provided on EngineConfig"

    logger.info("Using model: %s", config.model)
    logger.info("Script: %s", config.script_name)

    with EngineProcess.from_script(
        config, request, extra_env=extra_env
    ) as server_process:
        for payload in config.request_payloads:
            logger.info("TESTING: Payload: %s", payload.__class__.__name__)

            payload_item = payload
            # Inject the configured model when the payload type supports it.
            if hasattr(payload_item, "with_model"):
                payload_item = payload_item.with_model(config.model)

            # A port mismatch is only logged, not fatal; the request is still
            # sent to payload_item.url().
            if payload_item.port != config.models_port:
                logger.warning(
                    "Current payload port: %s doesn't match the model port: %s",
                    payload_item.port,
                    config.models_port,
                )

            for _ in range(payload_item.repeat_count):
                response = send_request(
                    url=payload_item.url(),
                    payload=payload_item.body,
                    timeout=payload_item.timeout,
                    method=payload_item.method,
                )
                server_process.check_response(payload_item, response)


def params_with_model_mark(configs: Mapping[str, EngineConfig]):
    """Build one ``pytest.param`` per named config, tagged with its model.

    Each param is keyed by the config's name. It carries the config's own
    ``marks`` (when the config defines that attribute), followed by
    ``pytest.mark.model(cfg.model)``. This makes it simple to collect the
    models that remain after pytest filtering.
    """
    return [
        pytest.param(
            name,
            marks=[*getattr(cfg, "marks", []), pytest.mark.model(cfg.model)],
        )
        for name, cfg in configs.items()
    ]