handler.py 4.06 KB
Newer Older
1
2
3
# SPDX-FileCopyrightText: Copyright (c) 2025-2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0

4
"""MM Router Handler — routes multimodal requests via KV-cache-aware worker selection."""
5
6
7
8

import logging
from typing import Any, AsyncGenerator

9
from dynamo.common.multimodal.image_loader import ImageLoader
10
11
12
13
14
15
16
17
18
19
from dynamo.llm import KvRouter
from dynamo.runtime.logging import configure_dynamo_logging

from .mm_processor import build_block_mm_infos, extract_image_urls, process_multimodal

configure_dynamo_logging()
logger = logging.getLogger(__name__)


class MMRouterHandler:
    """Routes requests to the vLLM worker with the best KV cache overlap.

    Requests whose chat messages contain image URLs are routed using the
    token sequence and per-block multimodal info derived from
    ``process_multimodal`` / ``build_block_mm_infos``. All other requests
    are routed on their own ``token_ids`` with no multimodal info.
    """

    def __init__(
        self,
        kv_router: KvRouter,
        tokenizer: Any,
        processor: Any,
        model: str,
        block_size: int,
    ):
        """Store the routing dependencies and create the image loader.

        Args:
            kv_router: Router whose ``generate`` every request is forwarded to.
            tokenizer: Passed through to ``process_multimodal``.
            processor: Passed through to ``process_multimodal``.
            model: Passed through to ``process_multimodal``.
            block_size: Number of tokens per block, used to size the
                per-block multimodal info list.
        """
        self.kv_router = kv_router
        self.tokenizer = tokenizer
        self.processor = processor
        self.model = model
        self.block_size = block_size
        self._image_loader = ImageLoader()

    async def generate(self, request: dict) -> AsyncGenerator[dict, None]:
        """Main entry point: process request, compute routing, forward to best worker.

        Args:
            request: Preprocessed request dict. ``model`` is required;
                ``token_ids``, ``stop_conditions``, ``sampling_options``,
                ``output_options``, ``router_config_override``,
                ``extra_args`` and ``multi_modal_data`` are forwarded as-is
                (``None`` when absent). ``extra_args["messages"]`` is
                scanned for image URLs.

        Yields:
            Every response produced by the router's stream, unchanged.

        Raises:
            ValueError: If a request without images has no ``token_ids``,
                or if the per-block multimodal info cannot be built.
            KeyError: If ``request`` has no ``model`` key.
        """
        # ``extra_args`` (or its ``messages`` entry) can be present but None;
        # a ``.get(key, default)`` default would not cover that case.
        extra_args = request.get("extra_args") or {}
        messages = extra_args.get("messages") or []
        image_urls = extract_image_urls(messages)

        if image_urls:
            routing_tokens, block_mm_infos = await self._process_mm_request(
                request, messages, image_urls
            )
        else:
            routing_tokens = request.get("token_ids")
            if not routing_tokens:
                raise ValueError("Missing token_ids in preprocessed request")
            # Ceiling division: a trailing partial block still counts.
            n_blocks = (len(routing_tokens) + self.block_size - 1) // self.block_size
            block_mm_infos = [None] * n_blocks

        # The worker receives the request's original fields; the routing
        # tokens computed above travel only inside ``mm_routing_info``.
        stream = await self.kv_router.generate(
            token_ids=request.get("token_ids"),
            model=request["model"],
            stop_conditions=request.get("stop_conditions"),
            sampling_options=request.get("sampling_options"),
            output_options=request.get("output_options"),
            router_config_override=request.get("router_config_override"),
            extra_args=request.get("extra_args"),
            multi_modal_data=request.get("multi_modal_data"),
            mm_routing_info={
                "routing_token_ids": routing_tokens,
                "block_mm_infos": block_mm_infos,
            },
        )
        async for response in stream:
            yield response

    async def _process_mm_request(
        self,
        request: dict,
        messages: list[dict],
        image_urls: list[str],
    ) -> tuple[list[int], list[dict | None]]:
        """Process multimodal: load images, expand tokens, build routing info.

        Side effects: ``messages`` and ``request["multi_modal_data"]`` are
        modified in place (image URLs replaced by a placeholder, ``Url``
        keys renamed to ``RawUrl``), so the changes are visible in the
        request that the caller forwards afterwards.

        Args:
            request: The incoming request dict.
            messages: Chat messages taken from the request's ``extra_args``.
            image_urls: Image URLs previously extracted from ``messages``.

        Returns:
            ``(tokens, block_mm_infos)`` — the token list reported by
            ``process_multimodal`` and the result of ``build_block_mm_infos``.

        Raises:
            ValueError: If ``build_block_mm_infos`` returns ``None``.
        """
        processed = await process_multimodal(
            messages=messages,
            image_urls=image_urls,
            tokenizer=self.tokenizer,
            processor=self.processor,
            model=self.model,
            image_loader=self._image_loader,
        )

        # Only strip after process_multimodal has consumed the messages.
        self._strip_image_urls(messages)
        self._rewrite_mm_urls(request)

        block_mm_infos = build_block_mm_infos(
            num_tokens=len(processed.tokens),
            block_size=self.block_size,
            mm_hashes=processed.mm_hashes,
            image_ranges=processed.image_ranges,
        )
        if block_mm_infos is None:
            raise ValueError("Failed to build block_mm_infos")

        return processed.tokens, block_mm_infos

    @staticmethod
    def _strip_image_urls(messages: list[dict]) -> None:
        """Replace image URLs in *messages* with a placeholder, in place.

        Strips image content from messages to reduce serialization payload.
        This is a best-effort size optimisation, so entries that do not have
        the expected shape (non-dict message or part, ``image_url`` that is
        not a dict) are left untouched instead of raising.
        """
        for msg in messages:
            if not isinstance(msg, dict):
                continue
            content = msg.get("content", [])
            if not isinstance(content, list):
                # Plain-string (or missing) content carries no image parts.
                continue
            for part in content:
                if not isinstance(part, dict) or part.get("type") != "image_url":
                    continue
                image_url = part.get("image_url")
                if isinstance(image_url, dict):
                    image_url["url"] = "<stripped>"

    @staticmethod
    def _rewrite_mm_urls(request: dict) -> None:
        """Rename ``Url`` to ``RawUrl`` in the request's image entries, in place.

        Rewrites Url → RawUrl to skip url::Url::parse in Rust depythonize.
        A missing or ``None`` ``multi_modal_data`` / ``image_url`` entry is
        treated as "nothing to rewrite".
        """
        mm_data = request.get("multi_modal_data")
        if not isinstance(mm_data, dict):
            return
        for item in mm_data.get("image_url") or []:
            if isinstance(item, dict) and "Url" in item:
                item["RawUrl"] = item.pop("Url")