trtllm_inc.py 20.7 KB
Newer Older
1
2
3
4
5
# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0

# TODO:
# - Support disaggregated serving
# - Update examples to use this engine.
#
# `dynamo-run out=trtllm` runs this script.
# It can also be used standalone: `python3 trtllm_inc.py` (accepts many optional
# command-line parameters; see `--help`).
#
# Disaggregated serving:
# - Ingress: dynamo run in=http out=dyn
# - Decode Worker: python3 trtllm_inc.py --task=decode --extra-engine-args=trtllm_config/sample.yaml
# - Prefill Worker: python3 trtllm_inc.py --task=prefill --extra-engine-args=trtllm_config/sample.yaml

16
17
18

import argparse
import asyncio
19
20
import base64
import copy
21
22
import logging
import sys
23
import warnings
24
from dataclasses import asdict, dataclass
25
from typing import Optional
26
27
28
29

import uvloop

# Import TRTLLM and related modules
30
from tensorrt_llm import SamplingParams
31
from tensorrt_llm.llmapi import DisaggregatedParams
32
33
34
from tensorrt_llm.llmapi.llm_utils import update_llm_args_with_extra_options
from tensorrt_llm.llmapi.tokenizer import tokenizer_factory

35
36
37
38
39
40
from dynamo.llm import (
    ModelType,
    get_tensorrtllm_engine,
    get_tensorrtllm_publisher,
    register_llm,
)
41
from dynamo.runtime import DistributedRuntime, dynamo_worker
42
from dynamo.runtime.logging import configure_dynamo_logging
43
44
45
46
47

# --- Defaults for command-line options -------------------------------------

# Endpoint served when the script is run manually from the command line.
DEFAULT_ENDPOINT = "dyn://dynamo.backend.generate"

# Endpoint of the remote prefill service: the default target of decode
# workers and the default serving endpoint of prefill workers.
DEFAULT_PREFILL_ENDPOINT = "dyn://dynamo.prefill.generate"

# Default model (Qwen/Qwen3-0.6B is not supported by TRTLLM yet).
DEFAULT_MODEL = "TinyLlama/TinyLlama-1.1B-Chat-v1.0"

# Default buffer size for kv cache events.
DEFAULT_KV_EVENT_BUFFER_MAX_SIZE = 1024

configure_dynamo_logging()
55
56


57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
def parse_endpoint(endpoint: str) -> tuple[str, str, str]:
    """
    Split a Dynamo endpoint string into its namespace, component and endpoint.

    Both ``dyn://namespace.component.endpoint`` and the bare
    ``namespace.component.endpoint`` form are accepted.

    Args:
        endpoint: The endpoint string to parse.

    Returns:
        A ``(namespace, component, endpoint)`` tuple of non-empty strings.

    Raises:
        ValueError: If the string does not consist of exactly three non-empty,
            dot-separated parts.
    """
    # Strip the scheme only when it is a real prefix; str.replace would also
    # remove a "dyn://" occurring in the middle of the string.
    endpoint_str = endpoint.removeprefix("dyn://")
    endpoint_parts = endpoint_str.split(".")
    # Reject both a wrong number of parts and empty parts such as "ns..ep".
    if len(endpoint_parts) != 3 or not all(endpoint_parts):
        raise ValueError(
            f"Invalid endpoint format: '{endpoint}'. "
            "Expected 'dyn://namespace.component.endpoint' or 'namespace.component.endpoint'."
        )

    namespace, component, endpoint_name = endpoint_parts
    return namespace, component, endpoint_name


class DisaggregatedParamsCodec:
    """
    Codec for encoding and decoding disaggregated params for network transfer.

    Only ``opaque_state`` is translated. On encode, its bytes are base64-encoded
    into a UTF-8 string so the params can be serialized with the rest of the
    request/response. On decode, that string is turned back into bytes. All
    other fields are copied unchanged, and a ``None`` params object passes
    through as ``None``.
    """

    @staticmethod
    def _map_opaque_state(disaggregated_params, transform):
        """
        Return a copy of ``disaggregated_params`` with ``transform`` applied to
        its ``opaque_state``. An unset (None) ``opaque_state`` stays None.

        ``dataclasses.replace`` copies every field of the params dataclass, so
        fields beyond the ones this module knows about are forwarded as well.
        """
        import dataclasses

        if disaggregated_params is None:
            return None

        opaque_state = disaggregated_params.opaque_state
        return dataclasses.replace(
            disaggregated_params,
            opaque_state=None if opaque_state is None else transform(opaque_state),
        )

    @staticmethod
    def decode(
        disaggregated_params: Optional[DisaggregatedParams],
    ) -> Optional[DisaggregatedParams]:
        """Return a copy whose base64 ``opaque_state`` string is decoded back to bytes."""
        return DisaggregatedParamsCodec._map_opaque_state(
            disaggregated_params, base64.b64decode
        )

    @staticmethod
    def encode(
        disaggregated_params: Optional[DisaggregatedParams],
    ) -> Optional[DisaggregatedParams]:
        """Return a copy whose ``opaque_state`` bytes are base64-encoded to a UTF-8 string."""
        return DisaggregatedParamsCodec._map_opaque_state(
            disaggregated_params,
            lambda raw: base64.b64encode(raw).decode("utf-8"),
        )


115
116
117
118
119
120
121
class Config:
    """Worker settings, filled in from command line parameters or defaults."""

    # Where this worker is served: parsed from the endpoint string.
    namespace: str
    component: str
    endpoint: str
    # Model to load (local path or HuggingFace id) and the optional served name.
    model_path: str
    model_name: Optional[str] = None
    # Engine, KV-cache and publishing settings.
    tensor_parallel_size: int
    kv_block_size: int
    extra_engine_args: str
    publish_events_and_metrics: bool
    # Disaggregated-serving settings.
    disaggregation_mode: str
    remote_prefill_endpoint: str

    def __str__(self) -> str:
        # Attributes rendered as "name=value", in this order.
        field_names = (
            "namespace",
            "component",
            "endpoint",
            "model_path",
            "model_name",
            "tensor_parallel_size",
            "kv_block_size",
            "extra_engine_args",
            "publish_events_and_metrics",
            "disaggregation_mode",
            "remote_prefill_endpoint",
        )
        rendered = ", ".join(
            f"{name}={getattr(self, name)}" for name in field_names
        )
        return f"Config({rendered})"


@dataclass
class RequestHandlerConfig:
    """
    Configuration for the request handler.

    Bundles the objects that ``RequestHandler.__init__`` copies onto the handler.
    """

    # Dynamo component the handler is served from.
    component: object
    # Engine wrapper; the handler calls ``engine.llm.generate_async``.
    engine: object
    # SamplingParams that per-request sampling options are applied on top of.
    default_sampling_params: object
    # Events/metrics publisher, or None when publishing is disabled.
    publisher: object
    # One of "prefill", "decode" or "prefill_and_decode".
    disaggregation_mode: str
    # Client for the remote prefill endpoint (decode mode), otherwise None.
    remote_prefill_client: object
158
159
160
161
162
163
164


class RequestHandler:
    """
    Request handler for the generate endpoint.

    The behaviour depends on ``disaggregation_mode``:

    * ``"prefill_and_decode"``: run the whole request on the local engine and
      stream the tokens back.
    * ``"prefill"``: run the request non-streaming on the local engine and
      return the result together with the encoded disaggregated params, so that
      a decode worker can continue the generation.
    * ``"decode"``: send the request to a remote prefill worker first, forward
      its first token, then generate the remaining tokens locally using the
      disaggregated params it returned.
    """

    def __init__(self, config: RequestHandlerConfig):
        self.engine = config.engine
        self.component = config.component
        self.default_sampling_params = config.default_sampling_params
        self.publisher = config.publisher
        self.disaggregation_mode = config.disaggregation_mode
        self.remote_prefill_client = config.remote_prefill_client
        # The publisher is started lazily on the first engine result (see generate).
        self.first_generation = True

    async def remote_prefill(self, request):
        """
        Send a prefill request to the remote prefill worker.

        The request is deep-copied, limited to a single output token and marked
        as ``context_only`` before it is sent round-robin to a prefill worker.

        Args:
            request: The original request to be sent for prefill. It is not
                modified.

        Returns:
            The single response from the remote prefill worker.

        Raises:
            ValueError: If the prefill client is not initialized, sending the
                request fails, or the worker returns zero or multiple responses.
        """
        # Fail fast before doing any work for a request that cannot be sent.
        if self.remote_prefill_client is None:
            raise ValueError("Prefill client not initialized")

        prefill_request = copy.deepcopy(request)
        # TRTLLM requires max_tokens to be set for prefill requests.
        prefill_request["stop_conditions"]["max_tokens"] = 1

        # Set the disaggregated params to context_only for remote prefill
        prefill_request["disaggregated_params"] = asdict(
            DisaggregatedParamsCodec.encode(
                DisaggregatedParams(request_type="context_only")
            )
        )

        try:
            # TODO: Use smart KV router to determine which prefill worker to use. This would also require supporting publishing events for prefill workers.
            response_stream = await self.remote_prefill_client.round_robin(
                prefill_request
            )
            remote_prefill_responses = [
                response async for response in response_stream
            ]
        except Exception as e:
            raise ValueError(f"Error in remote prefill: {e}") from e

        if len(remote_prefill_responses) > 1:
            raise ValueError(
                "Prefill worker returned more than one response. This is currently not supported in remote prefill mode."
            )

        if not remote_prefill_responses:
            raise ValueError("No response received from remote prefill worker")

        return remote_prefill_responses[0]

    def _build_sampling_params(self, request):
        """
        Return a per-request copy of the default sampling params with the
        request's sampling options and ``max_tokens`` applied.

        The defaults are copied rather than mutated. Otherwise one request's
        settings (temperature, max_tokens, ...) would leak into every later
        request served by this handler.
        """
        sampling_params = copy.deepcopy(self.default_sampling_params)
        for key, value in request["sampling_options"].items():
            # Skip only unset options. Falsy values such as temperature=0.0
            # are legitimate settings and must be applied.
            if value is not None and hasattr(sampling_params, key):
                setattr(sampling_params, key, value)

        max_tokens = request["stop_conditions"]["max_tokens"]
        if max_tokens:
            sampling_params.max_tokens = max_tokens
        return sampling_params

    async def generate(self, request):
        """
        Generate tokens for ``request`` and yield them as response dicts.

        Args:
            request: Request dict. Reads ``token_ids``, ``sampling_options``,
                ``stop_conditions["max_tokens"]`` and, if present, base64-encoded
                ``disaggregated_params`` (see ``DisaggregatedParamsCodec``).

        Yields:
            Dicts with the ``token_ids`` produced since the previous response,
            plus ``finish_reason`` / ``stop_reason`` when known. In prefill mode
            each response also carries the encoded ``disaggregated_params``.

        Raises:
            The error queued by the publisher, if any. ValueError if the remote
            prefill fails in decode mode.
        """
        # Check if there is an error in the publisher error queue
        publishers_error = (
            self.publisher.check_error_queue() if self.publisher else None
        )
        if publishers_error:
            raise publishers_error

        inputs = request["token_ids"]

        # Decode the disaggregated params from the request
        if "disaggregated_params" in request:
            disaggregated_params = DisaggregatedParamsCodec.decode(
                DisaggregatedParams(**request["disaggregated_params"])
            )
        else:
            disaggregated_params = None

        num_output_tokens_so_far = 0

        if self.disaggregation_mode == "decode":
            # Run prefill/context phase remotely if disaggregation mode is decode.
            try:
                prefill_result = await self.remote_prefill(request)
            except ValueError:
                # Already describes the remote prefill failure; don't re-wrap.
                raise
            except Exception as e:
                raise ValueError(f"Error in remote prefill: {e}") from e

            remote_prefill_response = prefill_result.data()
            if remote_prefill_response["finish_reason"] in ("stop", "error"):
                # Prefill alone ended the request. The disaggregated params are
                # internal state and are not forwarded to the client.
                remote_prefill_response.pop("disaggregated_params", None)
                yield remote_prefill_response
                return
            num_output_tokens_so_far = len(remote_prefill_response["token_ids"])

            # Decode the disaggregated params from the remote prefill response
            # and strip them from what is sent to the client.
            disaggregated_params = DisaggregatedParamsCodec.decode(
                DisaggregatedParams(
                    **remote_prefill_response.pop("disaggregated_params")
                )
            )

            # Send the first token response to the client
            yield remote_prefill_response

            # Set the disaggregated params to generation_only for the rest of the generation
            disaggregated_params.request_type = "generation_only"

        sampling_params = self._build_sampling_params(request)

        async for res in self.engine.llm.generate_async(
            inputs=inputs,
            sampling_params=sampling_params,
            disaggregated_params=disaggregated_params,
            # Prefill workers return a single (non-streamed) result.
            streaming=(self.disaggregation_mode != "prefill"),
        ):
            # TRTLLM engine needs to start generating tokens first before stats
            # can be retrieved.
            if self.first_generation and self.publisher:
                self.publisher.start()
                self.first_generation = False

            if not res.outputs:
                # Nothing to forward: a finished stream ends normally, anything
                # else is reported as an error.
                if res.finished and self.disaggregation_mode != "prefill":
                    yield {"finish_reason": "stop", "token_ids": []}
                else:
                    yield {"finish_reason": "error", "token_ids": []}
                break

            output = res.outputs[0]
            next_total_toks = len(output.token_ids)
            # Only send the tokens that were not sent in earlier responses.
            out = {"token_ids": output.token_ids[num_output_tokens_so_far:]}
            if output.finish_reason:
                out["finish_reason"] = output.finish_reason
            if output.stop_reason:
                out["stop_reason"] = output.stop_reason

            # The final streamed result still carries the last token(s). Forward
            # them, and make sure the client sees the stream end.
            if (
                res.finished
                and self.disaggregation_mode != "prefill"
                and "finish_reason" not in out
            ):
                out["finish_reason"] = "stop"

            if self.disaggregation_mode == "prefill":
                # Return the disaggregated params only when operating in prefill mode.
                out["disaggregated_params"] = asdict(
                    DisaggregatedParamsCodec.encode(output.disaggregated_params)
                )

            yield out
            num_output_tokens_so_far = next_total_toks

            if res.finished:
                break


@dynamo_worker(static=False)
async def worker(runtime: DistributedRuntime):
    """Dynamo worker entry point: build the Config from the command line, then serve."""
    config = cmd_line_args()
    await init(runtime, config)


async def init(runtime: DistributedRuntime, config: Config):
    """
    Instantiate the TRTLLM engine and serve it on the configured endpoint.

    Args:
        runtime: Dynamo distributed runtime, used to create the component,
            the endpoint and (in decode mode) the remote prefill client.
        config: Parsed worker configuration (see ``cmd_line_args``).
    """
    logging.info(f"Initializing the worker with config: {config}")

    remote_prefill_client = await _create_remote_prefill_client(runtime, config)

    component = runtime.namespace(config.namespace).component(config.component)
    await component.create_service()

    engine_args = _build_engine_args(config)
    logging.info(f"TRTLLM engine args: {engine_args}")

    # Populate default sampling params from the model
    tokenizer = tokenizer_factory(engine_args["model"])
    default_sampling_params = SamplingParams()
    default_sampling_params._setup(tokenizer)
    default_sampling_params.stop = None

    async with get_tensorrtllm_engine(engine_args) as engine:
        endpoint = component.endpoint(config.endpoint)

        if config.disaggregation_mode != "prefill":
            # Register the model with the endpoint if disaggregation mode is not prefill.
            # Prefill worker will get the request directly from the Decode worker and not
            # through the ingress.
            # FIXME: Enable publishing events and metrics for disaggregated prefill.
            # Currently prefill workers are chosen in round-robin fashion.
            await register_llm(
                ModelType.Backend,
                endpoint,
                config.model_path,
                config.model_name,
                kv_cache_block_size=config.kv_block_size,
            )

        # publisher will be set later if publishing is enabled.
        handler_config = RequestHandlerConfig(
            component=component,
            engine=engine,
            default_sampling_params=default_sampling_params,
            publisher=None,
            disaggregation_mode=config.disaggregation_mode,
            remote_prefill_client=remote_prefill_client,
        )

        if (
            config.publish_events_and_metrics
            and config.disaggregation_mode != "prefill"
        ):
            # Initialize and pass in the publisher to the request handler to
            # publish events and metrics.
            kv_listener = runtime.namespace(config.namespace).component(
                config.component
            )
            async with get_tensorrtllm_publisher(
                component,
                engine,
                kv_listener,
                int(endpoint.lease_id()),
                config.kv_block_size,
            ) as publisher:
                handler_config.publisher = publisher
                handler = RequestHandler(handler_config)
                await endpoint.serve_endpoint(handler.generate)
        else:
            handler = RequestHandler(handler_config)
            await endpoint.serve_endpoint(handler.generate)


async def _create_remote_prefill_client(runtime: DistributedRuntime, config: Config):
    """Return a client for ``config.remote_prefill_endpoint`` in decode mode, else None."""
    if config.disaggregation_mode != "decode":
        return None

    logging.info(
        f"Initializing remote prefill client for endpoint: {config.remote_prefill_endpoint}"
    )
    namespace, component_name, endpoint_name = parse_endpoint(
        config.remote_prefill_endpoint
    )
    return await (
        runtime.namespace(namespace)
        .component(component_name)
        .endpoint(endpoint_name)
        .client()
    )


def _build_engine_args(config: Config) -> dict:
    """
    Build the keyword arguments for the TRTLLM engine from ``config``.

    Applies the optional extra-engine-args YAML file. When publishing is
    enabled, it also makes sure a KV-cache event buffer is configured and that
    the pytorch backend is used. If another backend is configured, an error is
    logged and the process exits with status 1.
    """
    arg_map = {
        # The engine is given the model path as a plain string.
        "model": str(config.model_path),
        "tensor_parallel_size": config.tensor_parallel_size,
        "backend": "pytorch",
        "skip_tokenizer_init": True,
    }
    if config.extra_engine_args != "":
        # TODO: Support extra engine args from json file as well.
        arg_map = update_llm_args_with_extra_options(arg_map, config.extra_engine_args)

    if config.publish_events_and_metrics:
        # 'event_buffer_max_size' is required to enable TRTLLM to publish kv cache events.
        if "kv_cache_config" not in arg_map:
            kv_cache_config = {"event_buffer_max_size": DEFAULT_KV_EVENT_BUFFER_MAX_SIZE}
        else:
            # NOTE(review): assumes the extra-args loader turned kv_cache_config
            # into a config object (attribute access below) — confirm.
            kv_cache_config = arg_map["kv_cache_config"]
            if not kv_cache_config.event_buffer_max_size:
                kv_cache_config.event_buffer_max_size = DEFAULT_KV_EVENT_BUFFER_MAX_SIZE
        arg_map["kv_cache_config"] = kv_cache_config

        # Only pytorch backend is supported for now to publish events and metrics.
        if "backend" not in arg_map:
            arg_map["backend"] = "pytorch"
        elif arg_map["backend"] != "pytorch":
            logging.error(
                "Only pytorch backend is supported for now to publish events and metrics."
            )
            sys.exit(1)

    return arg_map
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467


def cmd_line_args():
    """
    Parse the command line into a ``Config``.

    Returns:
        The populated ``Config``.

    Raises:
        ValueError: If conflicting ``--task`` values are given, or an endpoint
            string is malformed.

    Exits the process with status 1 if ``--remote-prefill-endpoint`` is changed
    while running in prefill mode.
    """
    parser = argparse.ArgumentParser(
        description="TensorRT-LLM server integrated with Dynamo LLM."
    )
    parser.add_argument(
        "--endpoint",
        type=str,
        default=DEFAULT_ENDPOINT,
        help=f"Dynamo endpoint string in 'dyn://namespace.component.endpoint' format. Default: {DEFAULT_ENDPOINT} ({DEFAULT_PREFILL_ENDPOINT} when --task=prefill)",
    )
    parser.add_argument(
        "--model-path",
        type=str,
        default=DEFAULT_MODEL,
        help=f"Path to disk model or HuggingFace model identifier to load. Default: {DEFAULT_MODEL}",
    )
    parser.add_argument(
        "--model-name",
        type=str,
        default="",
        help="Name to serve the model under. Defaults to deriving it from model path.",
    )
    parser.add_argument(
        "--tensor-parallel-size", type=int, default=1, help="Number of GPUs to use."
    )
    # IMPORTANT: We should ideally not expose this to users. We should be able to
    # query the block size from the TRTLLM engine.
    parser.add_argument(
        "--kv-block-size", type=int, default=32, help="Size of a KV cache block."
    )
    parser.add_argument(
        "--context-length",
        type=int,
        default=None,
        help="This argument is not used by TRTLLM. Please provide max_input_len, max_seq_len and max_output_len in yaml file and point --extra-engine-args to the yaml file.",
    )
    parser.add_argument(
        "--extra-engine-args",
        type=str,
        default="",
        help="Path to a YAML file containing additional keyword arguments to pass to the TRTLLM engine.",
    )
    parser.add_argument(
        "--publish-events-and-metrics",
        action="store_true",
        help="Publish events and metrics to the dynamo components. Note: This is not supported when running in prefill disaggregation mode.",
    )
    parser.add_argument(
        "--task",
        type=str,
        action="append",
        choices=["prefill", "decode", "prefill_and_decode"],
        default=[],
        help="Specifies the task for the engine. Can be specified multiple times for different tasks. Will raise an error if conflicting tasks are specified.",
    )
    parser.add_argument(
        "--remote-prefill-endpoint",
        type=str,
        default=DEFAULT_PREFILL_ENDPOINT,
        help=f"Endpoint (in 'dyn://namespace.component.endpoint' format) to send prefill requests to when running in decode disaggregation mode. Default: {DEFAULT_PREFILL_ENDPOINT}",
    )
    args = parser.parse_args()

    # Validate arguments
    if args.context_length is not None:
        warnings.warn(
            "--context-length is accepted for compatibility but will be ignored for TensorRT-LLM. Please provide max_input_len, max_seq_len and max_output_len in yaml file and point --extra-engine-args to the yaml file.",
            UserWarning,
        )

    disaggregation_mode = _resolve_disaggregation_mode(args.task)

    endpoint = args.endpoint
    if disaggregation_mode == "prefill":
        if args.remote_prefill_endpoint != DEFAULT_PREFILL_ENDPOINT:
            logging.error(
                "--remote-prefill-endpoint is not supported when running in prefill disaggregation mode."
            )
            sys.exit(1)

        # Prefill workers serve on the endpoint decode workers target by
        # default, unless a different --endpoint was given explicitly. That
        # allows pairing with a decode worker started with a custom
        # --remote-prefill-endpoint.
        if endpoint == DEFAULT_ENDPOINT:
            endpoint = DEFAULT_PREFILL_ENDPOINT

        if args.publish_events_and_metrics:
            warnings.warn(
                "--publish-events-and-metrics is not supported when running in prefill disaggregation mode.",
                UserWarning,
            )

    parsed_namespace, parsed_component_name, parsed_endpoint_name = parse_endpoint(
        endpoint
    )

    config = Config()
    config.namespace = parsed_namespace
    config.component = parsed_component_name
    config.endpoint = parsed_endpoint_name
    config.model_path = args.model_path
    # An empty --model-name becomes None, which is an `Option` on the Rust side.
    config.model_name = args.model_name or None
    config.tensor_parallel_size = args.tensor_parallel_size
    config.kv_block_size = args.kv_block_size
    config.extra_engine_args = args.extra_engine_args
    config.publish_events_and_metrics = args.publish_events_and_metrics
    config.disaggregation_mode = disaggregation_mode
    config.remote_prefill_endpoint = args.remote_prefill_endpoint

    return config


def _resolve_disaggregation_mode(tasks: list[str]) -> str:
    """Return the single mode chosen via ``--task`` ("prefill_and_decode" if none)."""
    selected = [
        choice
        for choice in ("prefill", "decode", "prefill_and_decode")
        if choice in tasks
    ]
    if len(selected) > 1:
        raise ValueError(
            f"Conflicting tasks specified: {tasks}. Please specify only one task."
        )
    return selected[0] if selected else "prefill_and_decode"


if __name__ == "__main__":
    # Install uvloop's event loop policy so that asyncio.run below uses it.
    uvloop.install()
    try:
        asyncio.run(worker())
    except KeyboardInterrupt:
        # Ctrl-C: log a shutdown message instead of printing a traceback.
        logging.info("Received SIGINT, shutting down...")