# handler_base.py
# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

16
import asyncio
17
import copy
18
import logging
19
import os
20
from contextlib import asynccontextmanager
21
22
from dataclasses import asdict, dataclass
from enum import Enum
23
from typing import Any, AsyncGenerator, Optional, Union
24

25
import torch
26
from tensorrt_llm.executor.result import GenerationResult
27
from tensorrt_llm.executor.utils import RequestError
28
from tensorrt_llm.llmapi import DisaggregatedParams as LlmDisaggregatedParams
29
from tensorrt_llm.llmapi.llm import SamplingParams
30

31
from dynamo._core import Context
32
from dynamo.logits_processing.examples import HelloWorldLogitsProcessor
33
from dynamo.nixl_connect import Connector
34
from dynamo.runtime import DistributedRuntime
35
from dynamo.runtime.logging import configure_dynamo_logging
36
from dynamo.trtllm.engine import TensorRTLLMEngine
37
from dynamo.trtllm.logits_processing.adapter import create_trtllm_adapters
38
from dynamo.trtllm.multimodal_processor import MultimodalRequestProcessor
39
from dynamo.trtllm.publisher import Publisher
40
41
42
43
from dynamo.trtllm.utils.disagg_utils import (
    DisaggregatedParams,
    DisaggregatedParamsCodec,
)
44
45
46
47
48
49
50
51

# Module-level side effect: logging is configured as soon as this module is imported.
configure_dynamo_logging()


class DisaggregationMode(Enum):
    """Operating mode of a request handler.

    Each member's value is the string identifier of that mode.
    """

    AGGREGATED = "prefill_and_decode"
    PREFILL = "prefill"
    DECODE = "decode"
    ENCODE = "encode"
53
54
55
56
57
58
59
60
61
62
63
64
65


@dataclass
class RequestHandlerConfig:
    """
    Configuration for the request handler.

    Bundles the collaborators and settings that ``HandlerBase.__init__``
    copies onto the handler instance. The first five fields are required;
    the remaining ones are optional and default to ``None`` (or 32 for
    ``kv_block_size``).
    """

    component: object
    engine: TensorRTLLMEngine
    # Deep-copied for every request and then overridden with per-request options.
    default_sampling_params: SamplingParams
    publisher: Publisher
    disaggregation_mode: DisaggregationMode
    encode_client: Optional[object] = None
    # for multimodal support
    multimodal_processor: Optional[MultimodalRequestProcessor] = None
    connector: Optional[Connector] = None
    # DistributedRuntime reference for graceful shutdown
    runtime: Optional[DistributedRuntime] = None
    # TensorRT-LLM MetricsCollector
    metrics_collector: Optional[Any] = None
    # Multiplied by the number of reused KV blocks to estimate cached prompt tokens.
    kv_block_size: int = 32
76
77
78
79
80
81
82
83
84
85
86
87


class HandlerBase:
    """
    Base class for request handlers.

    Holds the engine, publisher and related collaborators taken from a
    ``RequestHandlerConfig`` and implements the shared generation loop
    (``generate_locally``), per-request cancellation monitoring, and the
    shutdown path taken after a fatal error.
    """

    def __init__(self, config: RequestHandlerConfig):
        """Copy the collaborators and settings from ``config`` onto the handler."""
        self.engine = config.engine
        self.component = config.component
        self.default_sampling_params = config.default_sampling_params
        self.publisher = config.publisher
        self.metrics_collector = config.metrics_collector
        self.disaggregation_mode = config.disaggregation_mode
        self.encode_client = config.encode_client
        self.multimodal_processor = config.multimodal_processor
        # The publisher is started lazily, on the first result produced by
        # generate_locally; this flag records whether that has happened yet.
        self.first_generation = True
        self.connector = config.connector
        # Store runtime reference for graceful shutdown
        self.runtime = config.runtime
        self.kv_block_size: int = config.kv_block_size

    def check_error(self, result: dict):
        """
        Check if there is an error in the result.

        Args:
            result: A response chunk containing a "finish_reason" key.

        Returns:
            In prefill mode, True only when the finish reason is "error".
            In every other mode, True when it is "stop" or "error".

        Raises:
            KeyError: If ``result`` has no "finish_reason" key.
        """
        finish_reason = result["finish_reason"]
        if self.disaggregation_mode == DisaggregationMode.PREFILL:
            return finish_reason == "error"
        return finish_reason in ("stop", "error")

    async def _handle_cancellation(
        self, generation_result: GenerationResult, context: Context
    ):
        """Background task to handle cancellation by monitoring context state.

        Waits until ``context`` reports that the request was killed or
        stopped, then aborts ``generation_result``. If this task is cancelled
        first, it exits quietly without aborting anything.
        """
        try:
            # Wait asynchronously for cancellation signal instead of polling
            await context.async_killed_or_stopped()
            # Abort the generation
            generation_result.abort()
            logging.debug("Aborted Request ID: %s", context.id())
        except asyncio.CancelledError:
            # Task was cancelled, which is expected when generation completes
            pass

    @asynccontextmanager
    async def _cancellation_monitor(
        self, generation_result: GenerationResult, context: Context
    ) -> AsyncGenerator[asyncio.Task, None]:
        """
        Context manager for monitoring request cancellation.

        Automatically creates a background task to monitor for cancellation and
        cleans it up when the context exits.

        Yields:
            asyncio.Task: The cancellation monitoring task
        """
        cancellation_task = asyncio.create_task(
            self._handle_cancellation(generation_result, context)
        )

        try:
            yield cancellation_task
        finally:
            # Clean up the background cancellation task
            if not cancellation_task.done():
                cancellation_task.cancel()
                try:
                    await cancellation_task
                except asyncio.CancelledError:
                    pass

    async def _initiate_shutdown(self, error: Exception):
        """Initiate graceful shutdown after fatal error.

        Shuts down the runtime and cleans up the engine when they are set,
        logging (not raising) any failure along the way, and then always
        terminates the process with exit status 1. This coroutine therefore
        never returns to its caller.
        """
        logging.warning("Initiating graceful shutdown due to: %s", error)

        try:
            if self.runtime:
                logging.info("Shutting down Dynamo runtime...")
                self.runtime.shutdown()

            if self.engine:
                logging.info("Shutting down TensorRT-LLM engine...")
                await self.engine.cleanup()
        except Exception as cleanup_error:
            logging.error("Error during graceful shutdown: %s", cleanup_error)
        finally:
            logging.critical("Forcing process exit for restart")
            os._exit(1)

    def _resolve_disaggregated_params(self, request: dict):
        """Return the disaggregated params to pass to the engine, or None.

        In prefill mode this also forces ``request["stop_conditions"]
        ["max_tokens"]`` to 1 (mutating ``request``) and returns
        "context_only" params. When the request carries a "prefill_result",
        its encoded params are decoded and marked "generation_only".

        Raises:
            ValueError: If a "prefill_result" is supplied in prefill mode, or
                if decode mode ends up without disaggregated params.
        """
        disaggregated_params = None

        if self.disaggregation_mode == DisaggregationMode.PREFILL:
            request["stop_conditions"]["max_tokens"] = 1
            disaggregated_params = LlmDisaggregatedParams(request_type="context_only")

        if "prefill_result" in request:
            if self.disaggregation_mode == DisaggregationMode.PREFILL:
                raise ValueError("Cannot provide disaggregated_params in prefill mode")
            disaggregated_params = DisaggregatedParamsCodec.decode(
                DisaggregatedParams(
                    **request["prefill_result"].get("disaggregated_params")
                )
            )
            disaggregated_params.request_type = "generation_only"

        if (
            self.disaggregation_mode == DisaggregationMode.DECODE
            and disaggregated_params is None
        ):
            raise ValueError("Disaggregated params are required for decode mode")

        return disaggregated_params

    def _build_sampling_params(self, request: dict):
        """Return a per-request copy of the default sampling params with overrides applied.

        Overrides come from ``request["sampling_options"]`` (only keys that
        already exist as attributes on the params object) and from
        ``request["stop_conditions"]``. The shared defaults are never mutated.
        """
        sampling_params = copy.deepcopy(self.default_sampling_params)

        # NOTE(review): falsy values (0, 0.0, False, empty containers) are
        # skipped here, so an explicit ``temperature=0`` cannot override the
        # default. Confirm this is intended before relying on it.
        for key, value in request["sampling_options"].items():
            if not value:
                continue
            if hasattr(sampling_params, key):
                setattr(sampling_params, key, value)

        stop_conditions = request["stop_conditions"]

        max_tokens = stop_conditions["max_tokens"]
        if max_tokens:
            sampling_params.max_tokens = max_tokens

        ignore_eos = stop_conditions.get("ignore_eos")
        if ignore_eos:
            sampling_params.ignore_eos = ignore_eos

        min_tokens = stop_conditions.get("min_tokens")
        if min_tokens:
            sampling_params.min_tokens = min_tokens

        stop_token_ids = stop_conditions.get("stop_token_ids_hidden")
        if stop_token_ids:
            # Merge with the defaults and de-duplicate; order is not preserved.
            existing = sampling_params.stop_token_ids or []
            sampling_params.stop_token_ids = list(set(existing).union(stop_token_ids))

        # Optional test-only logits processing (enable with DYNAMO_ENABLE_TEST_LOGITS_PROCESSOR=1)
        if os.getenv("DYNAMO_ENABLE_TEST_LOGITS_PROCESSOR") == "1":
            processors = [HelloWorldLogitsProcessor(self.engine.llm.tokenizer)]
            sampling_params.logits_processor = create_trtllm_adapters(processors)

        return sampling_params

    def _build_completion_usage(
        self,
        request: dict,
        output,
        num_output_tokens: int,
        prefill_prompt_tokens_details,
    ) -> dict:
        """Return the "completion_usage" dict attached to a final response chunk.

        Prompt token details from a prefill result take precedence. Otherwise,
        when the output carries perf metrics, the cached-token count is
        computed as reused KV blocks times ``kv_block_size``, capped at the
        number of input tokens, and reported only when greater than zero.
        """
        num_input_tokens = len(request.get("token_ids", []))

        prompt_tokens_details = None
        if prefill_prompt_tokens_details:
            prompt_tokens_details = prefill_prompt_tokens_details
        elif output.request_perf_metrics is not None:
            kv_cache_metrics = output.request_perf_metrics.kv_cache_metrics
            cached_tokens = min(
                num_input_tokens,
                kv_cache_metrics.num_reused_blocks * self.kv_block_size,
            )
            if cached_tokens > 0:
                prompt_tokens_details = {"cached_tokens": int(cached_tokens)}

        return {
            "prompt_tokens": int(num_input_tokens),
            "completion_tokens": int(num_output_tokens),
            "total_tokens": int(num_input_tokens + num_output_tokens),
            "prompt_tokens_details": prompt_tokens_details,
        }

    async def generate_locally(
        self,
        request: dict,
        context: Context,
        embeddings: Optional[Union[torch.Tensor, dict]] = None,
    ):
        """
        Generate responses based on the disaggregation mode in the request.

        Args:
            request: The request dictionary containing generation parameters
            context: Context object for cancellation handling
            embeddings: Optional tensor or dict containing embeddings for multimodal processing

        Yields:
            dict: One chunk per engine result. "token_ids" holds only the
            tokens produced since the previous chunk. "finish_reason" and
            "stop_reason" are included when set; "disaggregated_params" is
            included in prefill mode; "completion_usage" is included on the
            chunk that carries a finish reason. Error chunks have an empty
            "token_ids" list and an error "finish_reason".

        Raises:
            ValueError: If the request's disaggregated params are inconsistent
                with the handler's mode.
            Exception: Whatever error object the publisher's error queue
                returns, if any.

        Note:
            Errors other than client cancellation and ``RequestError`` are
            treated as fatal: a best-effort error chunk is yielded and the
            process is then shut down via ``_initiate_shutdown``.
        """
        logging.debug("Request: %s", request)

        if self.multimodal_processor:
            # Multimodal flow: the processor builds the engine input.
            processed_input = await self.multimodal_processor.process_openai_request(
                request, embeddings
            )
        else:
            # text-only flow
            processed_input = request.get("token_ids")

        # Check if there is an error in the publisher error queue
        publishers_error = (
            self.publisher.check_error_queue() if self.publisher else None
        )
        if publishers_error:
            raise publishers_error

        # Must run before the sampling params are built: in prefill mode it
        # rewrites request["stop_conditions"]["max_tokens"], which is read there.
        disaggregated_params = self._resolve_disaggregated_params(request)
        sampling_params = self._build_sampling_params(request)

        # TODO: Instead of True, we should use streaming from the request.
        # However, currently dynamo run does not send streaming in the request.
        streaming = self.disaggregation_mode != DisaggregationMode.PREFILL

        request_id = request.get("id") or request.get("request_id", "unknown-id")

        prefill_result = request.get("prefill_result")
        prefill_prompt_tokens_details = (
            prefill_result.get("prompt_tokens_details") if prefill_result else None
        )

        num_output_tokens_so_far = 0

        try:
            generation_result = self.engine.llm.generate_async(
                inputs=processed_input,
                sampling_params=sampling_params,
                disaggregated_params=disaggregated_params,
                streaming=streaming,
            )

            # Use the context manager to handle cancellation monitoring
            async with self._cancellation_monitor(generation_result, context):
                async for res in generation_result:
                    # TRTLLM engine needs to start generating tokens first before stats
                    # can be retrieved.
                    if self.first_generation and self.publisher:
                        self.publisher.start()
                        self.first_generation = False

                    # A result without outputs cannot be turned into a chunk.
                    # Report it as a per-request error instead of indexing
                    # into an empty list, which would raise IndexError and be
                    # treated as fatal (process shutdown) by the handler below.
                    if not res.outputs:
                        if res.finished:
                            logging.warning(
                                "Request %s finished without any outputs", request_id
                            )
                        yield {"finish_reason": "error", "token_ids": []}
                        break

                    output = res.outputs[0]
                    # The engine returns all tokens generated so far. We must calculate the new
                    # tokens generated in this iteration to create the "delta".
                    next_total_toks = len(output.token_ids)

                    out = {"token_ids": output.token_ids[num_output_tokens_so_far:]}

                    if output.finish_reason:
                        out["finish_reason"] = output.finish_reason
                    if output.stop_reason:
                        out["stop_reason"] = output.stop_reason
                    if self.disaggregation_mode == DisaggregationMode.PREFILL:
                        # Return the disaggregated params only when operating in prefill mode.
                        out["disaggregated_params"] = asdict(
                            DisaggregatedParamsCodec.encode(output.disaggregated_params)
                        )

                    if out.get("finish_reason"):
                        out["completion_usage"] = self._build_completion_usage(
                            request,
                            output,
                            next_total_toks,
                            prefill_prompt_tokens_details,
                        )

                    if res.finished and not out.get("finish_reason"):
                        out["finish_reason"] = "unknown"
                        logging.warning(
                            "Request finished with no finish reason set - this indicates a possible bug"
                        )

                    # Log metrics to TensorRT-LLM MetricsCollector when request finishes
                    if (
                        res.finished
                        and self.metrics_collector
                        and hasattr(res, "metrics_dict")
                    ):
                        try:
                            self.metrics_collector.log_metrics_dict(res.metrics_dict)
                        except Exception as e:
                            logging.warning("Failed to log TensorRT-LLM metrics: %s", e)

                    # Yield the chunk to the client and update the token count for the next iteration.
                    yield out
                    num_output_tokens_so_far = next_total_toks

        # 1. Client cancellation - don't shutdown
        except asyncio.CancelledError:
            logging.debug("Request %s: Client cancelled", request_id)
            # _cancellation_monitor already called abort_request
            return  # Just stop, no error response

        # 2. Per-request errors - send to client, don't shutdown
        except RequestError as e:
            error_msg = str(e)
            logging.warning("Request %s error: %s", request_id, error_msg)
            yield {
                "finish_reason": {"error": error_msg},
                "token_ids": [],
            }

        # 3. ALL OTHER ERRORS - graceful shutdown
        except Exception as e:
            error_type = type(e).__name__
            error_msg = str(e)
            logging.error(
                "Fatal %s in request %s: %s",
                error_type,
                request_id,
                error_msg,
                exc_info=True,
            )

            # Try to send error to client before shutdown
            try:
                yield {
                    "finish_reason": {"error": error_msg},
                    "token_ids": [],
                }
            except Exception:
                pass  # Best effort

            # Initiate graceful shutdown
            await self._initiate_shutdown(e)