# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0

"""Common base classes and utilities for engine tests (vLLM, TRT-LLM, etc.)"""

import logging
from collections.abc import Mapping
from typing import Any, Dict, Optional

import pytest

from tests.utils.client import send_request
from tests.utils.engine_process import EngineConfig, EngineProcess

# Default timeout for engine tests.
# NOTE(review): unit is presumably seconds, and the constant is not referenced
# in this part of the file -- confirm against callers.
DEFAULT_TIMEOUT = 10

# Absolute path of the serve test directory.
# NOTE(review): not referenced in this part of the file; presumably consumed
# by the engine-specific test modules -- confirm.
SERVE_TEST_DIR = "/workspace/tests/serve"


def run_serve_deployment(
    config: EngineConfig,
    request: Any,
    extra_env: Optional[Dict[str, str]] = None,
) -> None:
    """Run a standard serve deployment test for any EngineConfig.

    - Launches the engine via EngineProcess.from_script
    - Injects the configured model into every payload that supports it
    - Sends each payload ``repeat_count`` times and hands every response to
      the engine process for checking

    Args:
        config: Engine configuration. Its ``request_payloads`` must be a
            non-empty collection; ``name``, ``model``, ``script_name`` and
            ``models_port`` are read as well.
        request: Object whose ``node.name`` names the logger and which is
            forwarded to ``EngineProcess.from_script`` (presumably the pytest
            ``request`` fixture -- confirm against callers).
        extra_env: Optional extra environment variables forwarded to
            ``EngineProcess.from_script``.

    Raises:
        AssertionError: If ``config.request_payloads`` is ``None`` or empty.
    """
    logger = logging.getLogger(request.node.name)
    logger.info("Starting %s test_deployment", config.name)

    assert (
        config.request_payloads is not None and len(config.request_payloads) > 0
    ), "request_payloads must be provided on EngineConfig"

    logger.info("Using model: %s", config.model)
    logger.info("Script: %s", config.script_name)

    with EngineProcess.from_script(
        config, request, extra_env=extra_env
    ) as server_process:
        for payload in config.request_payloads:
            logger.info("TESTING: Payload: %s", payload.__class__.__name__)

            # Payloads that expose ``with_model`` get the configured model
            # injected; any other payload is sent unchanged.
            if hasattr(payload, "with_model"):
                payload_item = payload.with_model(config.model)
            else:
                payload_item = payload

            # A port mismatch is only reported, not treated as a failure.
            if payload_item.port != config.models_port:
                logger.warning(
                    "Current payload port: %s doesn't match the model port: %s",
                    payload_item.port,
                    config.models_port,
                )

            for _ in range(payload_item.repeat_count):
                response = send_request(
                    url=payload_item.url(),
                    payload=payload_item.body,
                    timeout=payload_item.timeout,
                    method=payload_item.method,
                )
                server_process.check_response(payload_item, response)


def params_with_model_mark(configs: Mapping[str, EngineConfig]):
    """Build one ``pytest.param`` per named config, tagged with its model.

    Each param carries the config's own ``marks`` (when the attribute exists)
    followed by a ``pytest.mark.model(<model>)`` marker, so the models in use
    can be collected after pytest filtering.
    """
    return [
        pytest.param(
            name,
            marks=[
                *getattr(engine_cfg, "marks", []),
                pytest.mark.model(engine_cfg.model),
            ],
        )
        for name, engine_cfg in configs.items()
    ]