aggregated_handler.py 2.34 KB
Newer Older
1
2
3
4
5
6
# SPDX-FileCopyrightText: Copyright (c) 2025-2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0

"""Handler for aggregated (prefill + decode) mode with optional encoder disaggregation."""

import logging
7
from collections.abc import AsyncGenerator
jh-nv's avatar
jh-nv committed
8
9
10
from typing import Optional, Union

import torch
11
12

from dynamo._core import Context
13
14
15
from dynamo.common.memory.multimodal_embedding_cache_manager import (
    MultimodalEmbeddingCacheManager,
)
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
from dynamo.trtllm.multimodal.embedding_fetcher import fetch_embeddings_from_encoder
from dynamo.trtllm.request_handlers.handler_base import (
    HandlerBase,
    RequestHandlerConfig,
)


class AggregatedHandler(HandlerBase):
    """
    Handler for aggregated mode (prefill + decode in single worker).

    Supports optional encoder disaggregation (E_PD flow). The encoder path is
    taken only when both ``self.multimodal_processor`` and
    ``self.encode_client`` are truthy; ``encoder_cache`` is optional and is
    simply forwarded to ``fetch_embeddings_from_encoder``.
    """

    def __init__(
        self,
        config: RequestHandlerConfig,
        encoder_cache: Optional[MultimodalEmbeddingCacheManager] = None,
    ):
        """
        Initialize the handler.

        Args:
            config: Handler configuration, passed unchanged to the base class.
            encoder_cache: Optional embedding cache manager handed to
                ``fetch_embeddings_from_encoder`` on multimodal requests.
        """
        super().__init__(config)
        self._encoder_cache = encoder_cache

    async def generate(
        self, request: dict, context: Context
    ) -> AsyncGenerator[dict, None]:
        """
        Generate response, optionally using remote encoder for multimodal.

        Args:
            request: Request payload. Messages are read from
                ``request["extra_args"]["messages"]`` when present, otherwise
                from ``request["messages"]``.
            context: Request context; its ``id()`` is logged at debug level.

        Yields:
            Every item produced by ``self.generate_locally``.
        """
        # Lazy %-style arguments: the message is only formatted if DEBUG is on.
        logging.debug("AggregatedHandler Request ID: %s", context.id())

        embeddings: Optional[Union[torch.Tensor, dict]] = None
        ep_disaggregated_params = None

        # NOTE(review): multimodal_processor / encode_client are not assigned in
        # this class; presumably set up by HandlerBase - confirm.
        if self.multimodal_processor and self.encode_client:
            # `or {}` guards against an explicit `"extra_args": None`, which
            # the `.get(key, {})` default does not cover (it only applies when
            # the key is missing).
            extra_args = request.get("extra_args") or {}
            messages = extra_args.get("messages", request.get("messages", []))
            _, image_urls, _ = self.multimodal_processor.extract_prompt_and_media(
                messages
            )
            if image_urls:
                result = await fetch_embeddings_from_encoder(
                    image_urls,
                    request,
                    self.encode_client,
                    self._encoder_cache,
                )
                # A list result is forwarded as embeddings; any other result
                # is forwarded as the disaggregated params.
                if isinstance(result, list):
                    # The declared type of `embeddings` does not include list,
                    # hence the explicit ignore.
                    embeddings = result  # type: ignore[assignment]
                else:
                    ep_disaggregated_params = result

        async for res in self.generate_locally(
            request, context, embeddings, ep_disaggregated_params
        ):
            yield res