test_trtllm.py 4.19 KB
Newer Older
1
2
3
4
5
# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0

import logging
import os
6
from dataclasses import dataclass, field
7
8
9

import pytest

10
11
12
from tests.serve.common import run_serve_deployment
from tests.utils.engine_process import EngineConfig
from tests.utils.payload_builder import chat_payload_default, completion_payload_default
13
14
15
16
17

# Module-level logger named after this module (standard logging convention).
logger = logging.getLogger(__name__)


@dataclass
class TRTLLMConfig(EngineConfig):
    """Configuration for trtllm test scenarios.

    Subclass of ``EngineConfig`` whose ``stragglers`` field defaults to
    ``["TRTLLM:EngineCore"]``. All other fields come from ``EngineConfig``.
    """

    # A default_factory is used so each instance gets its own list.
    # NOTE(review): presumably names of leftover processes to clean up after a
    # run -- confirm against EngineConfig / the engine process helpers.
    stragglers: list[str] = field(default_factory=lambda: ["TRTLLM:EngineCore"])
22

23

24
# Base directory passed as ``directory`` to every scenario below. Override it
# by setting the TRTLLM_DIR environment variable.
trtllm_dir = os.environ.get("TRTLLM_DIR", "/workspace/components/backends/trtllm")

# trtllm test configurations, keyed by scenario name. Each entry names the
# script to run, the pytest marks to attach when the scenario is parametrized,
# the model, and the request payloads to send.
trtllm_configs = {
    # Aggregated scenario: agg.sh, marked gpu_1.
    "aggregated": TRTLLMConfig(
        name="aggregated",
        directory=trtllm_dir,
        script_name="agg.sh",
        marks=[pytest.mark.gpu_1, pytest.mark.trtllm_marker],
        model="Qwen/Qwen3-0.6B",
        models_port=8000,
        request_payloads=[
            chat_payload_default(),
            completion_payload_default(),
        ],
    ),
    # Disaggregated scenario: disagg.sh, marked gpu_2.
    "disaggregated": TRTLLMConfig(
        name="disaggregated",
        directory=trtllm_dir,
        script_name="disagg.sh",
        marks=[pytest.mark.gpu_2, pytest.mark.trtllm_marker],
        model="Qwen/Qwen3-0.6B",
        models_port=8000,
        request_payloads=[
            chat_payload_default(),
            completion_payload_default(),
        ],
    ),
    # Aggregated scenario with the router script. Sends only a chat payload and
    # supplies log patterns (ZMQ batches, stored events, worker selection) via
    # ``expected_log``.
    "aggregated_router": TRTLLMConfig(
        name="aggregated_router",
        directory=trtllm_dir,
        script_name="agg_router.sh",
        marks=[pytest.mark.gpu_1, pytest.mark.trtllm_marker],
        model="Qwen/Qwen3-0.6B",
        models_port=8000,
        request_payloads=[
            chat_payload_default(
                expected_log=[
                    r"ZMQ listener .* received batch with \d+ events \(seq=\d+\)",
                    r"Event processor for worker_id \d+ processing event: Stored\(",
                    r"Selected worker: \d+, logit: ",
                ]
            )
        ],
        # Raise log verbosity for the KV-router publisher and scheduler targets,
        # which the expected_log patterns above appear to rely on.
        env={
            "DYN_LOG": "dynamo_llm::kv_router::publisher=trace,dynamo_llm::kv_router::scheduler=info",
        },
    ),
    # Disaggregated scenario with the router script: disagg_router.sh, gpu_2.
    "disaggregated_router": TRTLLMConfig(
        name="disaggregated_router",
        directory=trtllm_dir,
        script_name="disagg_router.sh",
        marks=[pytest.mark.gpu_2, pytest.mark.trtllm_marker],
        model="Qwen/Qwen3-0.6B",
        models_port=8000,
        request_payloads=[
            chat_payload_default(),
            completion_payload_default(),
        ],
    ),
}


@pytest.fixture(
    params=[
        pytest.param(scenario, marks=cfg.marks)
        for scenario, cfg in trtllm_configs.items()
    ]
)
def trtllm_config_test(request):
    """Return the trtllm configuration selected by the current parametrization.

    One parameter is generated per key of ``trtllm_configs``; each parameter
    carries the marks declared on its configuration.
    """
    scenario = request.param
    return trtllm_configs[scenario]


98
@pytest.mark.trtllm_marker
@pytest.mark.e2e
def test_deployment(trtllm_config_test, request, runtime_services):
    """
    Test dynamo deployments with different configurations.

    Runs once per configuration supplied by the ``trtllm_config_test`` fixture.
    ``runtime_services`` is requested as a fixture but not referenced in the
    body. NOTE(review): presumably needed for its setup side effects --
    confirm in conftest.
    """
    config = trtllm_config_test
    # Export the configured model under both variable names.
    extra_env = {"MODEL_PATH": config.model, "SERVED_MODEL_NAME": config.model}
    run_serve_deployment(config, request, extra_env=extra_env)
107
108


109
# TODO: make this a normal test case (presumably a regular trtllm_configs
# scenario rather than a hand-built one-off).
@pytest.mark.e2e
@pytest.mark.gpu_1
@pytest.mark.trtllm_marker
@pytest.mark.slow
def test_chat_only_aggregated_with_test_logits_processor(
    request, runtime_services, monkeypatch
):
    """
    Run a single aggregated chat-completions test using Qwen 0.6B with the
    test logits processor enabled, and expect "Hello world" in the response.

    NOTE(review): nothing in this function asserts "Hello world" directly;
    ``chat_payload_default()`` is called with its defaults, so the expectation
    must come from that payload or from ``run_serve_deployment`` -- confirm.
    """

    # Enable HelloWorld logits processor only for this test
    # (monkeypatch restores the environment on teardown).
    monkeypatch.setenv("DYNAMO_ENABLE_TEST_LOGITS_PROCESSOR", "1")

    # Reuse directory, script and timing settings from the aggregated scenario,
    # but send only the chat payload.
    base = trtllm_configs["aggregated"]
    config = TRTLLMConfig(
        name="aggregated_qwen_chatonly",
        directory=base.directory,
        script_name=base.script_name,  # agg.sh
        marks=[],  # not used by this direct test
        request_payloads=[
            chat_payload_default(),
        ],
        model="Qwen/Qwen3-0.6B",
        delayed_start=base.delayed_start,
        timeout=base.timeout,
    )

    # NOTE(review): unlike test_deployment, no extra_env (MODEL_PATH /
    # SERVED_MODEL_NAME) is passed here -- confirm this is intentional.
    run_serve_deployment(config, request)