# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0

"""End-to-end tests that launch TensorRT-LLM example deployments and query them."""

import dataclasses
import logging
import os
from dataclasses import dataclass, field

import pytest

from tests.serve.common import (
    WORKSPACE_DIR,
    params_with_model_mark,
    run_serve_deployment,
)
from tests.utils.constants import DefaultPort
from tests.utils.engine_process import EngineConfig
from tests.utils.payload_builder import (
    TEXT_PROMPT,
    chat_payload,
    chat_payload_default,
    completion_payload,
    completion_payload_default,
    metric_payload_default,
    multimodal_payload_default,
)

logger = logging.getLogger(__name__)


@dataclass
class TRTLLMConfig(EngineConfig):
    """Configuration for trtllm test scenarios.

    Identical to :class:`EngineConfig` except for the default of the
    ``stragglers`` field declared below; all other fields (name, directory,
    script_name, marks, model, env, request_payloads, ...) come from the
    base class.
    """

    # Default list of straggler process names for TensorRT-LLM deployments.
    # NOTE(review): how stragglers are used (e.g. cleanup after shutdown) is
    # defined in EngineConfig / the runner — confirm there.
    stragglers: list[str] = field(default_factory=lambda: ["TRTLLM:EngineCore"])

# Directory containing the TensorRT-LLM launch scripts referenced by
# ``script_name`` below. TRTLLM_DIR overrides it; an unset or empty value
# falls back to the in-repo examples directory.
trtllm_dir = os.environ.get("TRTLLM_DIR") or os.path.join(
    WORKSPACE_DIR, "examples/backends/trtllm"
)

# TensorRT-LLM test configurations, keyed by scenario name. Each entry is fed
# to test_deployment through the trtllm_config_test fixture.
# NOTE: pytest.mark.gpu_1 tests take ~442s (7m 22s) total to run sequentially (with models pre-cached)
# TODO: Parallelize these tests to reduce total execution time
trtllm_configs = {
    "aggregated": TRTLLMConfig(
        name="aggregated",
        directory=trtllm_dir,
        script_name="agg_metrics.sh",
        marks=[
            pytest.mark.gpu_1,
            pytest.mark.pre_merge,
            pytest.mark.trtllm,
            pytest.mark.timeout(
                300
            ),  # 3x measured time (44.66s) + download time (150s)
        ],
        model="Qwen/Qwen3-0.6B",
        frontend_port=DefaultPort.FRONTEND.value,
        request_payloads=[
            chat_payload_default(),
            completion_payload_default(),
            metric_payload_default(min_num_requests=6, backend="trtllm"),
        ],
    ),
    "disaggregated": TRTLLMConfig(
        name="disaggregated",
        directory=trtllm_dir,
        script_name="disagg.sh",
        marks=[pytest.mark.gpu_2, pytest.mark.trtllm, pytest.mark.post_merge],
        model="Qwen/Qwen3-0.6B",
        frontend_port=DefaultPort.FRONTEND.value,
        request_payloads=[
            chat_payload_default(),
            completion_payload_default(),
        ],
    ),
    "disaggregated_same_gpu": TRTLLMConfig(
        name="disaggregated_same_gpu",
        directory=trtllm_dir,
        script_name="disagg_same_gpu.sh",
        marks=[
            pytest.mark.gpu_1,
            pytest.mark.pre_merge,
            pytest.mark.trtllm,
            pytest.mark.skip(reason="unstable"),
            pytest.mark.timeout(
                480
            ),  # 3x measured time (103.66s) + download time (150s)
        ],
        model="Qwen/Qwen3-0.6B",
        frontend_port=DefaultPort.FRONTEND.value,
        request_payloads=[
            chat_payload_default(),
            completion_payload_default(),
            # One metrics check per worker system port.
            metric_payload_default(
                port=DefaultPort.SYSTEM1.value, min_num_requests=6, backend="trtllm"
            ),
            metric_payload_default(
                port=DefaultPort.SYSTEM2.value, min_num_requests=6, backend="trtllm"
            ),
        ],
    ),
    "aggregated_logprobs": TRTLLMConfig(
        name="aggregated_logprobs",
        directory=trtllm_dir,
        script_name="agg.sh",
        marks=[pytest.mark.gpu_1, pytest.mark.pre_merge, pytest.mark.trtllm],
        model="Qwen/Qwen3-0.6B",
        frontend_port=DefaultPort.FRONTEND.value,
        # Cover the logprobs / top_logprobs combinations.
        request_payloads=[
            chat_payload(content=TEXT_PROMPT, logprobs=True, top_logprobs=5),
            chat_payload(content=TEXT_PROMPT, logprobs=False, top_logprobs=5),
            chat_payload(content=TEXT_PROMPT, logprobs=True, top_logprobs=None),
            chat_payload(content=TEXT_PROMPT, logprobs=True, top_logprobs=0),
        ],
    ),
    "disaggregated_logprobs": TRTLLMConfig(
        name="disaggregated_logprobs",
        directory=trtllm_dir,
        script_name="disagg.sh",
        marks=[pytest.mark.gpu_2, pytest.mark.post_merge, pytest.mark.trtllm],
        model="Qwen/Qwen3-0.6B",
        frontend_port=DefaultPort.FRONTEND.value,
        # Same logprobs / top_logprobs matrix as "aggregated_logprobs".
        request_payloads=[
            chat_payload(content=TEXT_PROMPT, logprobs=True, top_logprobs=5),
            chat_payload(content=TEXT_PROMPT, logprobs=False, top_logprobs=5),
            chat_payload(content=TEXT_PROMPT, logprobs=True, top_logprobs=None),
            chat_payload(content=TEXT_PROMPT, logprobs=True, top_logprobs=0),
        ],
    ),
    "aggregated_router": TRTLLMConfig(
        name="aggregated_router",
        directory=trtllm_dir,
        script_name="agg_router.sh",
        marks=[
            pytest.mark.gpu_1,
            pytest.mark.pre_merge,
            pytest.mark.trtllm,
            pytest.mark.timeout(
                300
            ),  # 3x measured time (37.91s) + download time (180s)
        ],
        model="Qwen/Qwen3-0.6B",
        frontend_port=DefaultPort.FRONTEND.value,
        request_payloads=[
            chat_payload_default(
                # Log lines that must appear; they depend on the DYN_LOG
                # levels set in ``env`` below.
                expected_log=[
                    r"Event processor for worker_id \d+ processing event: Stored\(",
                    r"Selected worker: worker_id=\d+ dp_rank=.*?, logit: ",
                ]
            )
        ],
        env={
            "DYN_LOG": "dynamo_llm::kv_router::publisher=trace,dynamo_llm::kv_router::scheduler=info",
        },
    ),
    "disaggregated_router": TRTLLMConfig(
        name="disaggregated_router",
        directory=trtllm_dir,
        script_name="disagg_router.sh",
        marks=[pytest.mark.gpu_2, pytest.mark.trtllm, pytest.mark.nightly],
        model="Qwen/Qwen3-0.6B",
        frontend_port=DefaultPort.FRONTEND.value,
        request_payloads=[
            chat_payload_default(),
            completion_payload_default(),
        ],
    ),
    "disaggregated_multimodal": TRTLLMConfig(
        name="disaggregated_multimodal",
        directory=trtllm_dir,
        script_name="disagg_multimodal.sh",
        marks=[pytest.mark.gpu_2, pytest.mark.trtllm, pytest.mark.multimodal],
        model="Qwen/Qwen2-VL-7B-Instruct",
        frontend_port=DefaultPort.FRONTEND.value,
        timeout=900,
        delayed_start=60,
        request_payloads=[multimodal_payload_default()],
    ),
    "completions_only": TRTLLMConfig(
        name="completions_only",
        directory=trtllm_dir,
        script_name="agg.sh",
        marks=[
            pytest.mark.gpu_1,
            pytest.mark.trtllm,
            pytest.mark.timeout(
                480
            ),  # 3x measured time (83.85s) + download time (210s) for 7B model
        ],
        model="deepseek-ai/deepseek-llm-7b-base",
        # Expose only the completions endpoint; no chat payloads are sent.
        script_args=["--dyn-endpoint-types", "completions"],
        env={
            "MODEL_PATH": "deepseek-ai/deepseek-llm-7b-base",
            "SERVED_MODEL_NAME": "deepseek-ai/deepseek-llm-7b-base",
        },
        request_payloads=[
            completion_payload_default(),
            completion_payload(prompt=TEXT_PROMPT, logprobs=3),
        ],
    ),
}


@pytest.fixture(params=params_with_model_mark(trtllm_configs))
def trtllm_config_test(request):
    """Fixture that provides different trtllm test configurations.

    ``request.param`` is a key of ``trtllm_configs``; the parameter list is
    produced by ``params_with_model_mark`` (presumably attaching a per-model
    mark to each key — see tests.serve.common).

    Returns the shared module-level config object; callers must not mutate it.
    """
    return trtllm_configs[request.param]


@pytest.mark.trtllm
@pytest.mark.e2e
def test_deployment(
    trtllm_config_test,
    request,
    runtime_services_dynamic_ports,
    dynamo_dynamic_ports,
    predownload_models,
):
    """
    Test dynamo deployments with different configurations.

    Each parametrization of ``trtllm_config_test`` is launched through
    ``run_serve_deployment`` on per-test ports, with MODEL_PATH and
    SERVED_MODEL_NAME set to the config's model.
    """
    # Use per-test ports so tests can run safely under pytest-xdist.
    # Non-port env stays here; ports are wired by run_serve_deployment(ports=...).
    # dataclasses.replace() is a shallow copy, so the new config would share
    # its ``env`` dict with the module-level entry in trtllm_configs. Build a
    # fresh dict instead of calling env.update() to avoid mutating that
    # shared config.
    config = dataclasses.replace(
        trtllm_config_test,
        frontend_port=dynamo_dynamic_ports.frontend_port,
        env={
            **trtllm_config_test.env,
            "MODEL_PATH": trtllm_config_test.model,
            "SERVED_MODEL_NAME": trtllm_config_test.model,
        },
    )
    run_serve_deployment(config, request, ports=dynamo_dynamic_ports)


# TODO: turn this into a regular entry in trtllm_configs. It is presumably
# kept separate because it needs DYNAMO_ENABLE_TEST_LOGITS_PROCESSOR set.
@pytest.mark.e2e
@pytest.mark.gpu_1
@pytest.mark.trtllm
@pytest.mark.timeout(660)  # 3x measured time (159.68s) + download time (180s)
def test_chat_only_aggregated_with_test_logits_processor(
    request,
    runtime_services_dynamic_ports,
    dynamo_dynamic_ports,
    predownload_models,
    monkeypatch,
):
    """
    Run a single aggregated chat-completions test using Qwen 0.6B with the
    test logits processor enabled, and expect "Hello world!" in the response.
    """
    # Enable HelloWorld logits processor only for this test; monkeypatch
    # restores the process environment when the test finishes.
    monkeypatch.setenv("DYNAMO_ENABLE_TEST_LOGITS_PROCESSOR", "1")

    base = trtllm_configs["aggregated"]
    config = TRTLLMConfig(
        name="aggregated_qwen_chatonly",
        directory=base.directory,
        # Reuse the "aggregated" config's launch script (currently
        # agg_metrics.sh, not agg.sh).
        script_name=base.script_name,
        marks=[],  # not used by this direct test
        request_payloads=[
            chat_payload_default(expected_response=["Hello world!"]),
        ],
        model="Qwen/Qwen3-0.6B",
        delayed_start=base.delayed_start,
        timeout=base.timeout,
    )

    # Per-test frontend port (safe under pytest-xdist) plus model env; the
    # remaining ports are wired by run_serve_deployment(ports=...). A fresh
    # env dict is built rather than mutating config.env in place.
    config = dataclasses.replace(
        config,
        frontend_port=dynamo_dynamic_ports.frontend_port,
        env={
            **config.env,
            "MODEL_PATH": config.model,
            "SERVED_MODEL_NAME": config.model,
        },
    )
    run_serve_deployment(config, request, ports=dynamo_dynamic_ports)