# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0

# TODO:
# - Support disaggregated serving
# - Update examples to use this engine.
#
# `dynamo-run out=trtllm` runs this script
# Can be used standalone: `python3 trtllm_inc.py` - lots of optional cmd line params

import argparse
import asyncio
import logging
import sys
15
import warnings
16
from typing import Optional
17
18
19
20

import uvloop

# Import TRTLLM and related modules
21
from tensorrt_llm import SamplingParams
22
23
24
from tensorrt_llm.llmapi.llm_utils import update_llm_args_with_extra_options
from tensorrt_llm.llmapi.tokenizer import tokenizer_factory

25
26
27
28
29
30
from dynamo.llm import (
    ModelType,
    get_tensorrtllm_engine,
    get_tensorrtllm_publisher,
    register_llm,
)
31
32
33
34
35
36
37
from dynamo.runtime import DistributedRuntime, dynamo_worker

# Only used if you run it manually from the command line
DEFAULT_ENDPOINT = "dyn://dynamo.backend.generate"
# Qwen/Qwen3-0.6B is not supported by TRTLLM yet.
DEFAULT_MODEL = "TinyLlama/TinyLlama-1.1B-Chat-v1.0"

# Default buffer size for kv cache events.
DEFAULT_KV_EVENT_BUFFER_MAX_SIZE = 1024

# Configure the root logger at import time so every message (DEBUG and up)
# emitted by this script is shown.
logging.basicConfig(level=logging.DEBUG)


class Config:
    """Command line parameters or defaults"""

    # The three pieces of the `namespace.component.endpoint` endpoint string.
    namespace: str
    component: str
    endpoint: str
    # Path to a model on disk or a HuggingFace model identifier.
    model_path: str
    # Name to serve the model under; None means "derive it from model_path".
    model_name: Optional[str] = None
    # Number of GPUs to use.
    tensor_parallel_size: int
    # Size of a KV cache block.
    kv_block_size: int
    # Path to a YAML file with extra engine keyword arguments ("" for none).
    extra_engine_args: str
    # Whether to publish events and metrics to the dynamo components.
    publish_events_and_metrics: bool
56
57
58
59
60
61
62


class RequestHandler:
    """
    Request handler for the generate endpoint
    """

    def __init__(self, component, engine, default_sampling_params, publishers):
        """Store the collaborators used while serving requests.

        Args:
            component: Component this handler belongs to (stored, not used here).
            engine: Object exposing ``llm.generate_async(...)``.
            default_sampling_params: Template sampling parameters; a copy is
                customised for every request.
            publishers: Optional publisher object offering
                ``check_error_queue()`` and ``start()``, or ``None``.
        """
        self.engine = engine
        self.component = component
        self.default_sampling_params = default_sampling_params
        self.publishers = publishers
        # The publishers are started lazily, on the first generated result.
        self.first_generation = True

    async def generate(self, request):
        """Stream generated tokens for a single request.

        Args:
            request: Mapping with ``"token_ids"`` (the prompt),
                ``"sampling_options"`` (a mapping of per-request overrides)
                and ``"stop_conditions"`` (a mapping holding ``"max_tokens"``).

        Yields:
            Dicts with ``"token_ids"`` holding only the tokens produced since
            the previous chunk, plus ``"finish_reason"`` / ``"stop_reason"``
            when the engine reports them.

        Raises:
            Whatever error object the publishers' error queue returns.
        """
        # Function-scope import keeps this block self-contained.
        import copy

        # Check if there is an error in the publishers error queue
        publishers_error = (
            self.publishers.check_error_queue() if self.publishers else None
        )
        if publishers_error:
            raise publishers_error

        inputs = request["token_ids"]

        # Work on a private copy: mutating the shared defaults would leak one
        # request's sampling options and max_tokens into every later request.
        sampling_params = copy.deepcopy(self.default_sampling_params)
        for key, value in request["sampling_options"].items():
            # NOTE(review): this skips every falsy value, so an explicit 0 /
            # 0.0 / False override is ignored as well as None - confirm intended.
            if not value:
                continue
            # Options the sampling params object does not know are ignored.
            if hasattr(sampling_params, key):
                setattr(sampling_params, key, value)

        max_tokens = request["stop_conditions"]["max_tokens"]
        if max_tokens:
            sampling_params.max_tokens = max_tokens

        num_output_tokens_so_far = 0
        # TODO: Disable streaming for context only requests when adding disagg support
        async for res in self.engine.llm.generate_async(
            inputs=inputs, sampling_params=sampling_params, streaming=True
        ):
            # TRTLLM engine needs to start generating tokens first before stats
            # can be retrieved.
            if self.first_generation and self.publishers:
                self.publishers.start()
                self.first_generation = False

            if res.finished:
                yield {"finish_reason": "stop", "token_ids": []}
                break

            if not res.outputs:
                yield {"finish_reason": "error", "token_ids": []}
                break

            output = res.outputs[0]
            # output.token_ids is sliced from the count already sent, so each
            # chunk carries only the newly generated tail.
            next_total_toks = len(output.token_ids)
            out = {"token_ids": output.token_ids[num_output_tokens_so_far:]}
            if output.finish_reason:
                out["finish_reason"] = output.finish_reason
            if output.stop_reason:
                out["stop_reason"] = output.stop_reason
            yield out
            num_output_tokens_so_far = next_total_toks


@dynamo_worker(static=False)
async def worker(runtime: DistributedRuntime):
    """Parse the command line into a Config, then instantiate and serve the engine."""
    config = cmd_line_args()
    await init(runtime, config)


async def init(runtime: DistributedRuntime, config: Config):
    """
    Instantiate and serve

    Creates the component service, assembles the TRTLLM engine arguments
    (optionally extended from the file named by ``config.extra_engine_args``),
    starts the engine, registers the model on the endpoint and serves
    ``RequestHandler.generate`` on it.

    Args:
        runtime: Distributed runtime used to look up the namespace/component.
        config: Parsed command line parameters.

    Raises:
        SystemExit: If publishing events and metrics is requested together
            with a backend other than "pytorch".
    """
    component = runtime.namespace(config.namespace).component(config.component)
    await component.create_service()

    # The engine arguments carry the model location as a plain string.
    model_path = str(config.model_path)

    arg_map = {
        "model": model_path,
        "tensor_parallel_size": config.tensor_parallel_size,
        "skip_tokenizer_init": True,
        "disable_log_requests": True,
        "enable_prefix_caching": True,
        # KV routing relies on logging KV metrics
        "disable_log_stats": False,
    }
    if config.extra_engine_args != "":
        # TODO: Support extra engine args from json file as well.
        arg_map = update_llm_args_with_extra_options(arg_map, config.extra_engine_args)

    if config.publish_events_and_metrics:
        # 'event_buffer_max_size' is required to enable TRTLLM to publish kv cache events.
        kv_cache_config = None
        if "kv_cache_config" not in arg_map:
            kv_cache_config = {}
            kv_cache_config["event_buffer_max_size"] = DEFAULT_KV_EVENT_BUFFER_MAX_SIZE
        else:
            # A config supplied through the extra engine args is accessed by
            # attribute; only fill in the buffer size when it is unset.
            kv_cache_config = arg_map["kv_cache_config"]
            if not kv_cache_config.event_buffer_max_size:
                kv_cache_config.event_buffer_max_size = DEFAULT_KV_EVENT_BUFFER_MAX_SIZE
        arg_map["kv_cache_config"] = kv_cache_config

        # Only pytorch backend is supported for now to publish events and metrics.
        if "backend" not in arg_map:
            arg_map["backend"] = "pytorch"
        elif arg_map["backend"] != "pytorch":
            logging.error(
                "Only pytorch backend is supported for now to publish events and metrics."
            )
            sys.exit(1)

    logging.info(f"TRTLLM engine args: {arg_map}")
    engine_args = arg_map

    # Populate default sampling params from the model
    tokenizer = tokenizer_factory(arg_map["model"])
    default_sampling_params = SamplingParams()
    default_sampling_params._setup(tokenizer)
    # Reset `stop` after `_setup`.
    # NOTE(review): presumably so no default stop sequences apply - confirm.
    default_sampling_params.stop = None

    async with get_tensorrtllm_engine(engine_args) as engine:
        endpoint = component.endpoint(config.endpoint)
        await register_llm(
            ModelType.Backend, endpoint, config.model_path, config.model_name
        )

        if config.publish_events_and_metrics:
            # Initialize and pass in the publishers to the request handler to
            # publish events and metrics.
            kv_listener = runtime.namespace(config.namespace).component(
                config.component
            )
            async with get_tensorrtllm_publisher(
                component,
                engine,
                kv_listener,
                int(endpoint.lease_id()),
                config.kv_block_size,
            ) as publisher:
                handler = RequestHandler(
                    component, engine, default_sampling_params, publisher
                )
                await endpoint.serve_endpoint(handler.generate)
        else:
            # No publishers, so just pass in None to the request handler.
            handler = RequestHandler(component, engine, default_sampling_params, None)
            await endpoint.serve_endpoint(handler.generate)
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231


def cmd_line_args():
    """Parse the command line into a ``Config``.

    Returns:
        Config: Populated with the parsed (or default) values. The
        ``--endpoint`` string is split into namespace, component and endpoint.

    Raises:
        SystemExit: If ``--endpoint`` is not exactly three non-empty,
            dot-separated parts (after an optional leading ``dyn://``), or on
            the usual argparse errors.
    """
    parser = argparse.ArgumentParser(
        description="TensorRT-LLM server integrated with Dynamo LLM."
    )
    parser.add_argument(
        "--endpoint",
        type=str,
        default=DEFAULT_ENDPOINT,
        help=f"Dynamo endpoint string in 'dyn://namespace.component.endpoint' format. Default: {DEFAULT_ENDPOINT}",
    )
    parser.add_argument(
        "--model-path",
        type=str,
        default=DEFAULT_MODEL,
        help=f"Path to disk model or HuggingFace model identifier to load. Default: {DEFAULT_MODEL}",
    )
    parser.add_argument(
        "--model-name",
        type=str,
        default="",
        help="Name to serve the model under. Defaults to deriving it from model path.",
    )
    parser.add_argument(
        "--tensor-parallel-size", type=int, default=1, help="Number of GPUs to use."
    )
    # IMPORTANT: We should ideally not expose this to users. We should be able to
    # query the block size from the TRTLLM engine.
    parser.add_argument(
        "--kv-block-size", type=int, default=32, help="Size of a KV cache block."
    )
    parser.add_argument(
        "--context-length",
        type=int,
        default=None,
        help="This argument is not used by TRTLLM. Please provide max_input_len, max_seq_len and max_output_len in yaml file and point --extra-engine-args to the yaml file.",
    )
    parser.add_argument(
        "--extra-engine-args",
        type=str,
        default="",
        help="Path to a YAML file containing additional keyword arguments to pass to the TRTLLM engine.",
    )
    parser.add_argument(
        "--publish-events-and-metrics",
        action="store_true",
        help="Publish events and metrics to the dynamo components.",
    )
    args = parser.parse_args()

    if args.context_length is not None:
        warnings.warn(
            "--context-length is accepted for compatibility but will be ignored for TensorRT-LLM. Please provide max_input_len, max_seq_len and max_output_len in yaml file and point --extra-engine-args to the yaml file.",
            UserWarning,
        )

    config = Config()
    config.model_path = args.model_path
    # An empty name becomes None, which is an `Option` on the Rust side.
    config.model_name = args.model_name or None

    # Strip the scheme only when it is a real prefix; removing "dyn://" from
    # the middle of the string would silently accept a malformed endpoint.
    scheme = "dyn://"
    endpoint_str = args.endpoint
    if endpoint_str.startswith(scheme):
        endpoint_str = endpoint_str[len(scheme):]
    endpoint_parts = endpoint_str.split(".")
    # Exactly three parts, none of them empty (rejects e.g. "ns..endpoint").
    if len(endpoint_parts) != 3 or not all(endpoint_parts):
        logging.error(
            f"Invalid endpoint format: '{args.endpoint}'. Expected 'dyn://namespace.component.endpoint' or 'namespace.component.endpoint'."
        )
        sys.exit(1)

    parsed_namespace, parsed_component_name, parsed_endpoint_name = endpoint_parts

    config.namespace = parsed_namespace
    config.component = parsed_component_name
    config.endpoint = parsed_endpoint_name
    config.tensor_parallel_size = args.tensor_parallel_size
    config.kv_block_size = args.kv_block_size
    config.extra_engine_args = args.extra_engine_args
    config.publish_events_and_metrics = args.publish_events_and_metrics

    return config


if __name__ == "__main__":
    # Install uvloop before the asyncio event loop is created.
    uvloop.install()
    try:
        asyncio.run(worker())
    except KeyboardInterrupt:
        # Ctrl-C: log the shutdown instead of surfacing a traceback.
        logging.info("Received SIGINT, shutting down...")