test_trtllm.py 5.8 KB
Newer Older
1
2
3
4
5
6
7
8
9
10
# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0

import logging
import os
import time
from dataclasses import dataclass

import pytest

11
from tests.serve.common import EngineConfig, create_payload_for_config
12
13
14
15
from tests.utils.deployment_graph import (
    chat_completions_response_handler,
    completions_response_handler,
)
16
from tests.utils.engine_process import EngineProcess
17
18
19
20
21

logger = logging.getLogger(__name__)


@dataclass
class TRTLLMConfig(EngineConfig):
    """Configuration for trtllm test scenarios.

    Extends ``EngineConfig`` with a per-scenario ``timeout``.
    """

    # Forwarded to the engine process as its ``timeout`` and used by the test
    # as the time budget (in seconds) for the requests sent to each endpoint.
    timeout: int = 60


28
class TRTLLMProcess(EngineProcess):
    """Simple process manager for trtllm shell scripts.

    Resolves the launch script of a ``TRTLLMConfig``, exports the model
    selection through environment variables, and hands the resulting ``bash``
    command to ``EngineProcess``.
    """

    def __init__(self, config: TRTLLMConfig, request):
        """Build the launch command and configure the underlying process.

        Args:
            config: Scenario to launch. ``directory``, ``script_name``,
                ``model``, ``timeout`` and ``delayed_start`` are read from it.
            request: pytest request object; ``request.node.name`` is used as
                the log directory.

        Raises:
            FileNotFoundError: If ``<config.directory>/launch/<script_name>``
                does not exist.
        """
        self.port = 8000
        self.config = config
        self.dir = config.directory
        script_path = os.path.join(self.dir, "launch", config.script_name)

        if not os.path.exists(script_path):
            raise FileNotFoundError(f"trtllm script not found: {script_path}")

        # Set these env vars to customize model launched by launch script to match test.
        # NOTE: this mutates the environment of the test process itself, so the
        # values stay set after this object is gone.
        os.environ["MODEL_PATH"] = config.model
        os.environ["SERVED_MODEL_NAME"] = config.model

        command = ["bash", script_path]

        super().__init__(
            command=command,
            timeout=config.timeout,
            display_output=True,
            working_dir=self.dir,
            health_check_ports=[],  # Disable port health check
            # `_check_models_api` is not defined in this class; presumably it
            # is inherited from EngineProcess -- TODO confirm.
            health_check_urls=[
                (f"http://localhost:{self.port}/v1/models", self._check_models_api)
            ],
            delayed_start=config.delayed_start,
            # If true, would presumably terminate every bash process,
            # including the one running this test.
            terminate_existing=False,
            stragglers=[],  # Don't kill any stragglers automatically
            log_dir=request.node.name,
        )


# trtllm test configurations
trtllm_configs = {
    # Maps a scenario name to the TRTLLMConfig describing how to launch and
    # exercise it. Each key is repeated as the config's ``name``.
    "aggregated": TRTLLMConfig(
        name="aggregated",
        directory="/workspace/components/backends/trtllm",
        script_name="agg.sh",
        marks=[pytest.mark.gpu_1, pytest.mark.trtllm_marker],
        endpoints=["v1/chat/completions", "v1/completions"],
        response_handlers=[
            chat_completions_response_handler,
            completions_response_handler,
        ],
        model="deepseek-ai/DeepSeek-R1-Distill-Llama-8B",
        delayed_start=0,
        timeout=360,
    ),
    "disaggregated": TRTLLMConfig(
        name="disaggregated",
        directory="/workspace/components/backends/trtllm",
        script_name="disagg.sh",
        marks=[pytest.mark.gpu_2, pytest.mark.trtllm_marker],
        endpoints=["v1/chat/completions", "v1/completions"],
        response_handlers=[
            chat_completions_response_handler,
            completions_response_handler,
        ],
        model="deepseek-ai/DeepSeek-R1-Distill-Llama-8B",
        delayed_start=0,
        timeout=360,
    ),
    # TODO: These are sanity tests that the kv router examples launch
    # and inference without error, but do not do detailed checks on the
    # behavior of KV routing.
    "aggregated_router": TRTLLMConfig(
        name="aggregated_router",
        directory="/workspace/components/backends/trtllm",
        script_name="agg_router.sh",
        marks=[pytest.mark.gpu_1, pytest.mark.trtllm_marker],
        endpoints=["v1/chat/completions", "v1/completions"],
        response_handlers=[
            chat_completions_response_handler,
            completions_response_handler,
        ],
        model="deepseek-ai/DeepSeek-R1-Distill-Llama-8B",
        delayed_start=0,
        timeout=360,
    ),
    "disaggregated_router": TRTLLMConfig(
        name="disaggregated_router",
        directory="/workspace/components/backends/trtllm",
        script_name="disagg_router.sh",
        marks=[pytest.mark.gpu_2, pytest.mark.trtllm_marker],
        endpoints=["v1/chat/completions", "v1/completions"],
        response_handlers=[
            chat_completions_response_handler,
            completions_response_handler,
        ],
        model="deepseek-ai/DeepSeek-R1-Distill-Llama-8B",
        delayed_start=0,
        timeout=360,
    ),
}


@pytest.fixture(
    params=[
        pytest.param(scenario, marks=trtllm_configs[scenario].marks)
        for scenario in trtllm_configs
    ]
)
def trtllm_config_test(request):
    """Yield one trtllm scenario per parametrized run.

    The fixture is parametrized over the keys of ``trtllm_configs``; each
    parameter carries the marks of its config. The config registered under
    the current parameter is returned.
    """
    scenario = request.param
    return trtllm_configs[scenario]


@pytest.mark.e2e
@pytest.mark.slow
def test_deployment(trtllm_config_test, request, runtime_services):
    """Test dynamo deployments with different configurations.

    Launches the configured trtllm script, then sends the payload to every
    configured endpoint ``payload.repeat_count`` times and validates each
    response with the handler paired to that endpoint.

    Args:
        trtllm_config_test: The ``TRTLLMConfig`` selected by the fixture.
        request: pytest request object; its node name labels the logger and
            is forwarded to ``TRTLLMProcess``.
        runtime_services: Not referenced in the body; requested so that nats
            and etcd are started.
    """
    logger = logging.getLogger(request.node.name)
    logger.info("Starting test_deployment")

    config = trtllm_config_test
    payload = create_payload_for_config(config)

    logger.info("Using model: %s", config.model)
    logger.info("Script: %s", config.script_name)

    # Check the endpoint/handler pairing before launching the server so a
    # misconfigured scenario fails without paying for the startup.
    assert len(config.endpoints) == len(config.response_handlers)

    with TRTLLMProcess(config, request) as server_process:
        for endpoint, response_handler in zip(
            config.endpoints, config.response_handlers
        ):
            url = f"http://localhost:{server_process.port}/{endpoint}"
            # The time budget is tracked per endpoint, not across the whole test.
            start_time = time.time()

            request_body = (
                payload.payload_chat
                if endpoint == "v1/chat/completions"
                else payload.payload_completions
            )

            for _ in range(payload.repeat_count):
                remaining = config.timeout - (time.time() - start_time)
                if remaining <= 0:
                    # Fail with a clear message instead of handing a
                    # non-positive timeout to the request layer.
                    pytest.fail(
                        f"Exceeded the {config.timeout}s budget for {endpoint} "
                        "before all requests were sent"
                    )

                response = server_process.send_request(
                    url, payload=request_body, timeout=remaining
                )
                server_process.check_response(payload, response, response_handler)