base_engine.py 23.8 KB
Newer Older
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.


import asyncio
18
19
20
21
import copy
import logging
import os
import signal
22
23
import threading
from contextlib import asynccontextmanager
24
from enum import Enum
25
26
27
28
from queue import Queue
from typing import Any, Optional

from common.parser import LLMAPIConfig
29
from common.protocol import DisaggregatedTypeConverter
30
31
32
33
34
35
36
from common.utils import ManagedThread, ServerType
from tensorrt_llm.executor import CppExecutorError
from tensorrt_llm.llmapi import LLM, SamplingParams
from tensorrt_llm.llmapi.disagg_utils import (
    CtxGenServerConfig,
    parse_disagg_config_file,
)
37
from tensorrt_llm.llmapi.tokenizer import tokenizer_factory
38
from tensorrt_llm.serve.openai_protocol import DisaggregatedParams
39

40
from dynamo.llm import KvEventPublisher, WorkerMetricsPublisher
41
from dynamo.sdk import dynamo_context
42

43
logger = logging.getLogger(__name__)
# Pin this module's logger to DEBUG so it does not inherit a higher level from
# its ancestors (attached handlers may still filter records).
logger.setLevel(level=logging.DEBUG)


class DisaggRequestType(Enum):
    """Values assigned to ``DisaggregatedParams.request_type`` in this module."""

    # Sent to the remote prefill (context) worker.
    CONTEXT_ONLY = "context_only"
    # Used by the generation worker once the remote prefill has completed.
    GENERATION_ONLY = "generation_only"


def update_args_from_disagg_config(
    engine_config: LLMAPIConfig, server_config: CtxGenServerConfig
):
    """Overlay one disaggregated server section onto an LLM API config.

    The ``other_args`` of ``server_config`` are merged into
    ``engine_config.extra_args`` and also passed to
    ``engine_config.update_sub_configs``. This lets context and generation
    servers run with different settings. ``engine_config`` is modified in
    place and returned for convenience.
    """
    overrides = server_config.other_args
    engine_config.extra_args.update(**overrides)
    engine_config.update_sub_configs(overrides)
    return engine_config


def _to_signed_i64(value: int | None) -> int | None:
    """Convert a Python int to signed 64-bit range by two's complement."""
    if value is None:
        return None

    if value >= 2**63:
        return value - 2**64
    if value < -(2**63):
        return ((value + 2**63) % 2**64) - 2**63
    return value


def get_sampling_params(sampling_params_dict, default_sampling_params):
    """Build per-request sampling params from defaults plus request overrides.

    ``default_sampling_params`` is deep-copied, so the defaults are never
    mutated. An entry of ``sampling_params_dict`` overrides an attribute of the
    copy only if its value is not ``None`` and the copy already has an
    attribute with that name. All other entries are ignored.
    """
    params = copy.deepcopy(default_sampling_params)
    for name, val in sampling_params_dict.items():
        if val is not None and hasattr(params, name):
            setattr(params, name, val)
    return params


class BaseTensorrtLLMEngine:
    """Shared plumbing for TensorRT-LLM workers.

    Owns a TensorRT-LLM ``LLM`` that lives on a dedicated thread running its
    own asyncio event loop. When ``router == "kv"`` it also publishes worker
    metrics and KV-cache events. It implements the streaming ``generate`` path
    for aggregated, context-only (``ServerType.CTX``) and generation with
    remote prefill.
    """

    def __init__(
        self,
        namespace_str: str = "dynamo",
        component_str: str = "tensorrt-llm",
        worker_id: Optional[str] = None,
        engine_config: LLMAPIConfig = None,
        remote_prefill: bool = False,
        min_workers: int = 0,
        disagg_config_file: Optional[str] = None,
        block_size: int = 32,
        router: str = "round_robin",
        server_type: ServerType = ServerType.GEN,
    ):
        """Configure the wrapper; the LLM itself is created by ``_init_engine``.

        Args:
            namespace_str: Namespace used to look up the KV event listener.
            component_str: Component used to look up the KV event listener.
            worker_id: Worker identifier. Required (and converted with ``int``)
                when KV events are published.
            engine_config: LLM API config. In disaggregated mode it is updated
                with the matching server section of the disagg config file.
            remote_prefill: Whether this worker delegates prefill remotely.
            min_workers: Stored only in disaggregated mode; otherwise 0.
            disagg_config_file: Path to the disaggregated config. Required when
                ``remote_prefill`` is set or ``server_type`` is ``ServerType.CTX``.
            block_size: Number of tokens per KV-cache block.
            router: Routing mode; ``"kv"`` enables metrics and KV event publishing.
            server_type: Role of this worker (``ServerType.GEN`` or ``ServerType.CTX``).

        Raises:
            ValueError: If the disagg config file is missing, has no section for
                ``server_type``, or KV events are enabled without a worker ID.
        """
        self._namespace_str = namespace_str
        self._component_str = component_str
        self._worker_id = worker_id
        self._remote_prefill = remote_prefill
        self._min_workers = 0
        self._kv_block_size = block_size
        self._router = router
        self._server_type = server_type
        self._prefill_client = None
        self._error_queue: Queue = Queue()
        self._kv_metrics_publisher = None
        # Initialise every attribute that later methods probe. Calling them
        # before ``_init_engine`` (or with publishing disabled) then hits the
        # explicit ``is None`` checks instead of raising AttributeError.
        self._kv_event_publisher = None
        self._llm_engine: Optional[Any] = None
        self.publish_stats_thread = None
        self.publish_kv_cache_events_thread = None
        self._partial_block_hashes: set = set()

        if self._remote_prefill or self._server_type == ServerType.CTX:
            self._min_workers = min_workers
            if disagg_config_file is None or not os.path.exists(disagg_config_file):
                raise ValueError(
                    "llmapi_disaggregated_config file does not exist or not provided"
                )
            disagg_config = parse_disagg_config_file(disagg_config_file)

            # Use the first server section whose type matches this worker's role.
            server_config: Optional[CtxGenServerConfig] = next(
                (
                    config
                    for config in disagg_config.server_configs
                    if config.type == server_type.value
                ),
                None,
            )

            if server_config is None:
                server_type_str = (
                    "generation" if server_type == ServerType.GEN else "context"
                )
                raise ValueError(
                    f"No {server_type_str} server config found. Please check the disaggregated config file."
                )

            engine_config = update_args_from_disagg_config(engine_config, server_config)

        # Metrics and KV events are published only when the KV router is selected.
        self._publish_stats = router == "kv"
        self._publish_events = router == "kv"

        if self._publish_stats:
            self._kv_metrics_publisher = WorkerMetricsPublisher()

        if self._publish_events:
            if self._worker_id is None:
                raise ValueError("Worker ID is None!")

            runtime = dynamo_context["runtime"]
            kv_listener = runtime.namespace(self._namespace_str).component(
                self._component_str
            )
            self._kv_event_publisher = KvEventPublisher(
                kv_listener, int(self._worker_id), self._kv_block_size
            )
            logger.info("KvEventPublisher is initialized")

        self._engine_config = engine_config

    def _init_engine(self):
        """Start the LLM on a background event-loop thread and prepare publishers.

        Blocks until ``_run_llm_engine`` has either produced an engine or
        reported a startup failure.

        Raises:
            Exception: The error the engine raised while starting, or any error
                raised while preparing the metrics / KV-event publisher threads.
        """
        logger.info("Initializing engine")
        self._llm_engine: Optional[Any] = None
        self._llm_engine_start_cv = threading.Condition()
        self._llm_engine_shutdown_event = asyncio.Event()

        # Populate default sampling params from the model. This runs before the
        # engine coroutine is created, so a tokenizer failure does not leave a
        # never-awaited coroutine behind.
        tokenizer = tokenizer_factory(self._engine_config.model_name)
        self._default_sampling_params = SamplingParams()
        self._default_sampling_params._setup(tokenizer)
        self._default_sampling_params.stop = None

        # Run the engine in a separate thread running its own asyncio event loop.
        self._event_thread = threading.Thread(
            target=asyncio.run, args=(self._run_llm_engine(),)
        )

        self.publish_kv_cache_events_thread = None
        self.publish_stats_thread = None

        self._event_thread.start()
        with self._llm_engine_start_cv:
            while self._llm_engine is None:
                self._llm_engine_start_cv.wait()

        # 'threading.Thread()' does not propagate exceptions, so a startup
        # failure is handed back via the engine variable instead.
        if isinstance(self._llm_engine, Exception):
            e = self._llm_engine
            # Do not leave the exception object where ``generate`` expects an engine.
            self._llm_engine = None
            logger.error(f"Failed to start engine: {e}")
            if self._event_thread is not None:
                self._event_thread.join()
                self._event_thread = None
            raise e

        try:
            if self._publish_stats:
                self._init_publish_metrics_thread()
        except Exception as e:
            logger.error(f"Failed to initialize publish metrics threads: {e}")
            raise

        try:
            if self._publish_events:
                self._init_publish_kv_cache_events_thread()
        except Exception as e:
            logger.error(f"Failed to initialize publish events threads: {e}")
            raise

    def _init_publish_metrics_thread(self):
        """Publish placeholder metrics once and create (not start) the stats thread."""
        # Stats must be published once so that this worker can be selected.
        # These are dummy values until real stats are available.
        request_active_slots = 0
        request_total_slots = 4
        kv_active_block = 0
        kv_total_blocks = 4
        num_requests_waiting = 0
        gpu_cache_usage_perc = 0.0
        gpu_prefix_cache_hit_rate = 0.0

        if self._kv_metrics_publisher is None:
            logger.error("KV metrics publisher not initialized!")
            return

        self._kv_metrics_publisher.publish(
            request_active_slots,
            request_total_slots,
            kv_active_block,
            kv_total_blocks,
            num_requests_waiting,
            gpu_cache_usage_perc,
            gpu_prefix_cache_hit_rate,
        )

        # Prepare the stats thread but don't start it yet: TRTLLM must start
        # generating tokens before stats can be retrieved.
        self.publish_stats_thread = ManagedThread(
            self.publish_stats_task,
            error_queue=self._error_queue,
            name="publish_stats_thread",
        )

    def _init_publish_kv_cache_events_thread(self):
        """Create (not start) the KV-cache events thread and reset partial-block state."""
        if self._kv_event_publisher is None:
            logger.error("KV event publisher not initialized!")
            return

        # Hashes of partial blocks (fewer than kv_block_size tokens). Partial
        # blocks are never published as stored, so their "removed" events must
        # not be forwarded to the KV router either.
        self._partial_block_hashes = set()

        # Prepare the events thread but don't start it yet: TRTLLM must start
        # generating tokens before KV-cache events can be retrieved.
        self.publish_kv_cache_events_thread = ManagedThread(
            self.publish_kv_cache_events_task,
            error_queue=self._error_queue,
            name="publish_kv_cache_events_thread",
        )

    async def publish_stats_task(self):
        """Forward engine stats to the metrics publisher until the stream ends.

        Returns:
            ``None`` if the engine is not running, ``False`` if the metrics
            publisher is missing, otherwise ``True`` once the stats stream ends.
        """
        if self._llm_engine is None:
            logger.error("LLM engine not initialized!")
            return

        if self._kv_metrics_publisher is None:
            logger.error("KV metrics publisher not initialized!")
            return False

        stats = self._llm_engine.get_stats_async(timeout=5)
        async for stat in stats:
            kv_stats = stat["kvCacheStats"]
            request_active_slots = stat["numActiveRequests"]
            request_total_slots = stat["maxNumActiveRequests"]
            kv_active_block = kv_stats["usedNumBlocks"]
            kv_total_blocks = kv_stats["maxNumBlocks"]
            reused_blocks = kv_stats["reusedBlocks"]
            freeNumBlocks = kv_stats["freeNumBlocks"]
            allocTotalBlocks = kv_stats["allocTotalBlocks"]
            allocNewBlocks = kv_stats["allocNewBlocks"]
            # NOTE: num paused requests is always 0 when using guarantee no evict scheduler (default).
            num_requests_waiting = (
                stat["numQueuedRequests"]
                + stat["inflightBatchingStats"]["numPausedRequests"]
            )
            # Guard against an empty/unsized cache instead of dividing by zero.
            # NOTE(review): this ratio uses allocTotalBlocks rather than
            # usedNumBlocks; confirm which counter is intended.
            gpu_cache_usage_perc = (
                allocTotalBlocks / kv_total_blocks if kv_total_blocks else 0.0
            )
            gpu_prefix_cache_hit_rate = kv_stats["cacheHitRate"]

            logger.debug(
                f"Publishing stats: request_active_slots: {request_active_slots}, request_total_slots: {request_total_slots}, kv_active_block: {kv_active_block}, kv_total_blocks: {kv_total_blocks}, num_requests_waiting: {num_requests_waiting}, reused_blocks: {reused_blocks}, freeNumBlocks: {freeNumBlocks}, allocTotalBlocks: {allocTotalBlocks}, allocNewBlocks: {allocNewBlocks}, gpu_cache_usage_perc: {gpu_cache_usage_perc}, gpu_prefix_cache_hit_rate: {gpu_prefix_cache_hit_rate}"
            )

            self._kv_metrics_publisher.publish(
                request_active_slots,
                request_total_slots,
                kv_active_block,
                kv_total_blocks,
                num_requests_waiting,
                gpu_cache_usage_perc,
                gpu_prefix_cache_hit_rate,
            )

        return True

    async def publish_kv_cache_events_task(self):
        """Forward TensorRT-LLM KV-cache events to the KV event publisher.

        For "stored" events, only the leading full blocks are published. The
        first partial block (fewer than ``kv_block_size`` tokens) and every
        block after it are dropped. The partial block's hash is remembered so
        that its later "removed" event is suppressed as well.

        Returns:
            ``None`` if the engine is not running or a block larger than
            ``kv_block_size`` is seen; otherwise ``True`` once the event stream
            ends.
        """
        if self._llm_engine is None:
            logger.error("LLM engine not initialized!")
            return

        events = self._llm_engine.get_kv_cache_events_async(timeout=5)
        async for event in events:
            event_id = event["event_id"]
            data = event["data"]
            if data["type"] == "stored":
                parent_hash = _to_signed_i64(data["parent_hash"])
                token_ids = []
                num_block_tokens = []
                block_hashes = []
                for block in data["blocks"]:
                    token_num_in_block = len(block["tokens"])
                    block_hash = _to_signed_i64(block["block_hash"])
                    if token_num_in_block > self._kv_block_size:
                        logger.error(
                            f"Block {block_hash} contains {token_num_in_block} tokens, which is greater than kv_block_size {self._kv_block_size}"
                        )
                        # NOTE(review): this ends the whole publishing task,
                        # not just the current event.
                        return
                    if token_num_in_block < self._kv_block_size:
                        logger.debug(
                            f"Early stop when block {block_hash} containing {token_num_in_block} tokens not equal to kv_block_size {self._kv_block_size}"
                        )
                        self._partial_block_hashes.add(block_hash)
                        break
                    num_block_tokens.append(token_num_in_block)
                    block_hashes.append(block_hash)
                    token_ids.extend(int(token["token_id"]) for token in block["tokens"])

                # Note: Currently data does not have lora_id.
                # Using 0 as default value. If later data has
                # lora_id, we need to verify if this is correct.
                lora_id = data.get("lora_id", 0)

                logger.debug(
                    f"publish stored event: event_id: {event_id}, token_ids: {token_ids}, num_block_tokens: {num_block_tokens}, block_hashes: {block_hashes}, lora_id: {lora_id}, parent_hash: {parent_hash}"
                )
                self._kv_event_publisher.publish_stored(
                    event_id,
                    token_ids,
                    num_block_tokens,
                    block_hashes,
                    lora_id,
                    parent_hash,
                )
            elif data["type"] == "removed":
                block_hashes = []
                for raw_hash in data["block_hashes"]:
                    block_hash = _to_signed_i64(raw_hash)
                    if block_hash in self._partial_block_hashes:
                        logger.debug(
                            f"Skipping removing block hash {block_hash} since it is a partial block"
                        )
                        self._partial_block_hashes.remove(block_hash)
                        continue
                    block_hashes.append(block_hash)

                logger.debug(
                    f"publish removed event: event_id: {event_id}, block_hashes: {block_hashes}"
                )
                self._kv_event_publisher.publish_removed(event_id, block_hashes)

        return True

    def _start_threads(self):
        """Start the prepared publisher threads on the caller's running loop.

        A thread is started only if it exists and is not currently alive.
        NOTE(review): if ManagedThread is a ``threading.Thread``, a thread that
        has already finished would fail to start a second time; confirm.
        """
        for thread, label in (
            (self.publish_kv_cache_events_thread, "kv cache events"),
            (self.publish_stats_thread, "stats"),
        ):
            if thread and not thread.is_alive():
                # [NOTE:] TRTLLM needs the stats to be collected on the same loop as the request handler.
                self._stats_loop = asyncio.get_running_loop()
                thread.set_loop(self._stats_loop)
                thread.start()
                logger.debug(f"Started {label} thread")

    async def _run_llm_engine(self):
        """Own the LLM for its whole lifetime on the engine thread's event loop.

        Steps:
            1. Build the LLM and store it in ``self._llm_engine``, notifying
               ``_init_engine``.
            2. Wait for the shutdown event.
            3. Stop the publisher threads, drain in-flight requests, and cancel
               the remaining tasks on this loop.

        A failure before the engine is published is handed back through
        ``self._llm_engine`` instead of being raised.
        """
        # Counter to keep track of ongoing request counts.
        self._ongoing_request_count = 0

        @asynccontextmanager
        async def async_llm_wrapper():
            # Build and shut down the LLM in an executor thread so these
            # blocking calls do not stall this event loop.
            loop = asyncio.get_running_loop()
            llm = None
            try:
                llm = await loop.run_in_executor(
                    None,
                    lambda: LLM(
                        model=self._engine_config.model_name,
                        **self._engine_config.to_dict(),
                    ),
                )
                yield llm
            finally:
                if llm is not None:
                    await loop.run_in_executor(None, llm.shutdown)

        try:
            async with async_llm_wrapper() as engine:
                # Capture the engine event loop and make it visible to other threads.
                self._event_loop = asyncio.get_running_loop()

                # Signal the engine is started and make it visible to other threads.
                with self._llm_engine_start_cv:
                    self._llm_engine = engine
                    self._llm_engine_start_cv.notify_all()

                logger.info("Engine loaded and ready to serve...")

                # Wait for the engine shutdown signal.
                await self._llm_engine_shutdown_event.wait()

                # Stop the publishing threads (stats first, then KV events).
                for thread in (
                    self.publish_stats_thread,
                    self.publish_kv_cache_events_thread,
                ):
                    if thread and thread.is_alive():
                        thread.stop()
                        thread.join()

                # Wait for the ongoing requests to complete.
                while self._ongoing_request_count > 0:
                    logger.info(
                        "Awaiting remaining {} requests".format(
                            self._ongoing_request_count
                        )
                    )
                    await asyncio.sleep(1)

                # Cancel every other task still scheduled on the engine loop.
                current = asyncio.current_task()
                for task in asyncio.all_tasks(loop=self._event_loop):
                    if task is not current:
                        task.cancel()

        except Exception as e:
            # If the engine never started, hand the exception back to
            # _init_engine via the engine variable; otherwise re-raise it.
            with self._llm_engine_start_cv:
                if self._llm_engine is None:
                    self._llm_engine = e
                    self._llm_engine_start_cv.notify_all()
                    return
            raise

        self._llm_engine = None
        logger.info("Shutdown complete")

    async def _get_remote_prefill_response(self, request):
        """Run the context (prefill) phase of ``request`` on a remote worker.

        A deep copy of ``request`` is sent with ``max_tokens = 1`` and a
        CONTEXT_ONLY request type, using round-robin selection.

        Returns:
            The single response object produced by the prefill worker.

        Raises:
            ValueError: If no prefill client is configured, or the prefill
                worker returns no response or more than one response.
        """
        # Validate before doing any work on the request.
        if self._prefill_client is None:
            raise ValueError("Prefill client not initialized")

        prefill_request = copy.deepcopy(request)
        # TRTLLM requires max_tokens to be set for prefill requests.
        prefill_request.stop_conditions.max_tokens = 1
        prefill_request.disaggregated_params = DisaggregatedParams(
            request_type=DisaggRequestType.CONTEXT_ONLY.value
        )

        # TODO: Use smart KV router to determine which prefill worker to use. This would also require supporting publishing events for prefill workers.
        ctx_responses = [
            ctx_response
            async for ctx_response in await self._prefill_client.round_robin(
                prefill_request.model_dump_json()
            )
        ]
        if not ctx_responses:
            raise ValueError("Prefill worker returned no response.")
        if len(ctx_responses) > 1:
            raise ValueError(
                "Prefill worker returned more than one response. This is currently not supported in remote prefill mode."
            )
        logger.debug(
            f"Received response from prefill worker: {ctx_responses[0].data()}"
        )
        return ctx_responses[0]

    async def generate(self, request):
        """Stream generation results for ``request`` as dicts.

        Each yielded dict has ``token_ids`` with only the tokens that are new
        since the previous chunk, plus ``finish_reason`` / ``stop_reason`` when
        set. Context-only (``ServerType.CTX``) workers also include
        ``disaggregated_params``. With remote prefill, the prefill worker's
        response is yielded first (without its disaggregated params), and
        generation then continues locally.

        Raises:
            RuntimeError: If the engine is not running or generation fails.
            Exception: The first error queued by a publisher thread, re-raised.
        """
        if self._llm_engine is None:
            raise RuntimeError("Engine not initialized")

        if not self._error_queue.empty():
            raise self._error_queue.get()

        self._ongoing_request_count += 1
        try:
            try:
                worker_inputs = request.token_ids

                disaggregated_params = (
                    DisaggregatedTypeConverter.to_llm_disaggregated_params(
                        request.disaggregated_params
                    )
                )

                num_output_tokens_so_far = 0

                if self._remote_prefill and self._server_type == ServerType.GEN:
                    ctx_response = await self._get_remote_prefill_response(request)
                    remote_prefill_response = ctx_response.data()
                    if remote_prefill_response["finish_reason"] in ("stop", "error"):
                        yield remote_prefill_response
                        return
                    num_output_tokens_so_far = len(remote_prefill_response["token_ids"])

                    # Decode the disaggregated params from the remote prefill response.
                    disaggregated_params = (
                        DisaggregatedTypeConverter.to_llm_disaggregated_params(
                            DisaggregatedParams(
                                **remote_prefill_response["disaggregated_params"]
                            )
                        )
                    )

                    # Send the first token response to the client.
                    first_token_response = remote_prefill_response
                    first_token_response.pop("disaggregated_params")
                    yield first_token_response

                    disaggregated_params.request_type = (
                        DisaggRequestType.GENERATION_ONLY.value
                    )

                logger.debug(
                    f"Worker inputs: {worker_inputs}, disaggregated params: {disaggregated_params}"
                )

                sampling_params = get_sampling_params(
                    request.sampling_options.dict(), self._default_sampling_params
                )
                max_tokens = request.stop_conditions.max_tokens
                if max_tokens:
                    sampling_params.max_tokens = max_tokens

                async for response in self._llm_engine.generate_async(
                    inputs=worker_inputs,
                    sampling_params=sampling_params,
                    disaggregated_params=disaggregated_params,
                    streaming=self._server_type != ServerType.CTX,
                ):
                    # NOTE(review): on non-CTX workers the finished response's
                    # outputs are not forwarded; confirm the last streamed
                    # response never carries new tokens.
                    if response.finished and self._server_type != ServerType.CTX:
                        yield {"finish_reason": "stop", "token_ids": []}
                        break

                    if not response.outputs:
                        yield {"finish_reason": "error", "token_ids": []}
                        break

                    output = response.outputs[0]
                    next_total_toks = len(output.token_ids)
                    out = {"token_ids": output.token_ids[num_output_tokens_so_far:]}
                    if output.finish_reason:
                        out["finish_reason"] = output.finish_reason
                    if output.stop_reason:
                        out["stop_reason"] = output.stop_reason
                    if self._server_type == ServerType.CTX:
                        # Return the disaggregated params only when operating in prefill mode.
                        out[
                            "disaggregated_params"
                        ] = DisaggregatedTypeConverter.to_oai_disaggregated_params(
                            output.disaggregated_params
                        ).dict()

                    yield out
                    num_output_tokens_so_far = next_total_toks

            except CppExecutorError:
                # Executor-level failure: raise SIGINT in this process.
                signal.raise_signal(signal.SIGINT)
            except Exception as e:
                raise RuntimeError("Failed to generate: " + str(e)) from e

            self._start_threads()
        finally:
            # Always release the in-flight slot (normal completion, early
            # return, error, or consumer closing the stream). Otherwise the
            # shutdown drain loop in _run_llm_engine could wait forever.
            self._ongoing_request_count -= 1