"deploy/operator/Makefile" did not exist on "e28ff8d2b1cfd433a4bed58461a9b13c0c9345dc"
Unverified Commit 815b1291 authored by zhongdaor-nv's avatar zhongdaor-nv Committed by GitHub
Browse files

feat: mm aware routing for vllm (#6235)


Signed-off-by: default avatarzhongdaor <zhongdaor@nvidia.com>
Signed-off-by: default avatarzhongdaor-nv <zhongdaor@nvidia.com>
parent 0abebe38
......@@ -37,6 +37,7 @@ from dynamo.llm import (
from dynamo.runtime.logging import configure_dynamo_logging
from .engine_monitor import VllmEngineMonitor
from .multimodal_utils.hash_utils import compute_mm_uuids_from_images
from .multimodal_utils.image_loader import ImageLoader
# Multimodal data dictionary keys
......@@ -48,6 +49,27 @@ DECODED_VARIANT_KEY: Final = "Decoded"
configure_dynamo_logging()
logger = logging.getLogger(__name__)
def _compute_mm_uuids(
multi_modal_data: Dict[str, Any] | None
) -> Dict[str, list[str]] | None:
"""
Compute multi_modal_uuids from multi_modal_data.
Each image gets a SHA256 hex digest as its UUID, ensuring consistent
hashing across the MM Router, vLLM handler, and Rust KV publisher.
"""
if not multi_modal_data or "image" not in multi_modal_data:
return None
images = multi_modal_data["image"]
if not isinstance(images, list):
images = [images]
if not images:
return None
uuids = compute_mm_uuids_from_images(images)
return {"image": uuids}
# LoRAManager singleton - initialized lazily when DYN_LORA_ENABLED is set
# None = not yet initialized, False = disabled/failed, LoRAManager = initialized
_lora_manager = None
......@@ -1010,12 +1032,17 @@ class BaseWorkerHandler(ABC):
"token_ids": [],
},
)
else:
# Normal path: use token IDs
prompt = TokensPrompt(
prompt_token_ids=request["token_ids"], multi_modal_data=multi_modal_data
)
return prompt, embedding_sequence_length, None
# Normal path: use token IDs
mm_uuids = _compute_mm_uuids(multi_modal_data)
prompt_kwargs = dict[str, Any](
prompt_token_ids=request["token_ids"],
multi_modal_data=multi_modal_data,
)
if mm_uuids is not None:
prompt_kwargs["multi_modal_uuids"] = mm_uuids
prompt = TokensPrompt(**prompt_kwargs)
return prompt, embedding_sequence_length, None
@staticmethod
def _build_completion_usage(
......
# SPDX-FileCopyrightText: Copyright (c) 2025-2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
import hashlib
import io
import logging
from typing import Any, Sequence
logger = logging.getLogger(__name__)
def image_to_bytes(img: Any) -> bytes:
"""Convert a supported image object to PNG bytes for hashing."""
from PIL import Image
if isinstance(img, bytes):
return img
if isinstance(img, Image.Image):
buf = io.BytesIO()
img.save(buf, format="PNG")
return buf.getvalue()
# Frontend-decoding can provide image tensors as numpy arrays.
try:
import numpy as np
if isinstance(img, np.ndarray):
pil_img = Image.fromarray(img)
buf = io.BytesIO()
pil_img.save(buf, format="PNG")
return buf.getvalue()
except ImportError:
pass
raise TypeError(f"Unsupported image type for hashing: {type(img)}")
def compute_mm_uuids_from_images(images: Sequence[Any]) -> list[str]:
"""
Compute SHA256 hex UUIDs for image inputs.
"""
uuids: list[str] = []
for img in images:
raw_bytes = image_to_bytes(img)
uuids.append(hashlib.sha256(raw_bytes).hexdigest())
return uuids
......@@ -49,6 +49,7 @@ class TestVllmKvEventsApi:
5. lora_id
6. medium
7. lora_name (added in vLLM 0.14.0)
8. extra_keys (contains MM info, cache_salt, etc.)
If vLLM adds/removes/reorders fields, this test will fail.
"""
......@@ -60,6 +61,7 @@ class TestVllmKvEventsApi:
"lora_id",
"medium",
"lora_name",
"extra_keys",
)
actual_fields = BlockStored.__struct_fields__
......@@ -146,6 +148,7 @@ class TestVllmKvEventsApi:
lora_id=None,
medium="GPU",
lora_name=None,
extra_keys=None,
)
encoded = msgspec.msgpack.encode(event)
......@@ -157,9 +160,9 @@ class TestVllmKvEventsApi:
decoded[0] == "BlockStored"
), f"Expected tag 'BlockStored', got {decoded[0]}"
# Verify field count (tag + 7 fields = 8 elements)
assert len(decoded) == 8, (
f"Expected 8 elements (tag + 7 fields), got {len(decoded)}.\n"
# Verify field count (tag + 8 fields = 9 elements)
assert len(decoded) == 9, (
f"Expected 9 elements (tag + 8 fields), got {len(decoded)}.\n"
f"Decoded: {decoded}\n"
f"If field count changed, update Rust deserializers."
)
......@@ -172,3 +175,4 @@ class TestVllmKvEventsApi:
assert decoded[5] is None, f"lora_id at wrong position: {decoded[5]}"
assert decoded[6] == "GPU", f"medium at wrong position: {decoded[6]}"
assert decoded[7] is None, f"lora_name at wrong position: {decoded[7]}"
assert decoded[8] is None, f"extra_keys at wrong position: {decoded[8]}"
# SPDX-FileCopyrightText: Copyright (c) 2025-2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
"""
MM Router Worker - Multimodal-aware KV cache routing for vLLM.
This worker sits between the frontend and vLLM workers, computing mm_hash
for images and routing requests to the worker with the best KV cache overlap.
"""
from .handler import MMRouterHandler
from .mm_processor import ProcessedInput, build_block_mm_infos
__all__ = [
"MMRouterHandler",
"ProcessedInput",
"build_block_mm_infos",
]
# SPDX-FileCopyrightText: Copyright (c) 2025-2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
"""Entry point for running MM Router Worker as a module."""
from .mm_router_worker import main
if __name__ == "__main__":
main()
# SPDX-FileCopyrightText: Copyright (c) 2025-2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
"""
MM Router Handler - Routes requests to best vLLM worker based on KV cache overlap.
"""
import logging
from typing import Any, AsyncGenerator
from dynamo.llm import KvRouter
from dynamo.runtime.logging import configure_dynamo_logging
from .mm_processor import build_block_mm_infos, extract_image_urls, process_multimodal
configure_dynamo_logging()
logger = logging.getLogger(__name__)
class MMRouterHandler:
"""
Handler that computes mm_hash for multimodal requests and routes
to the best vLLM worker based on KV cache overlap.
"""
def __init__(
self,
kv_router: KvRouter,
tokenizer: Any,
processor: Any,
model: str,
block_size: int,
):
"""
Initialize the MM Router Handler.
Args:
kv_router: KvRouter for KV-aware worker selection and routing
tokenizer: HuggingFace AutoTokenizer
processor: HuggingFace AutoProcessor (optional)
model: Model path/name
block_size: KV cache block size
"""
self.kv_router = kv_router
self.tokenizer = tokenizer
self.processor = processor
self.model = model
self.block_size = block_size
async def generate(self, request: dict) -> AsyncGenerator[dict, None]:
"""
Main entry point - receives request, computes routing, forwards to best worker.
The request format (after Frontend preprocessing with ModelInput.Tokens):
{
"token_ids": [...],
"sampling_options": {...},
"stop_conditions": {...},
"extra_args": {"messages": [...]}
}
Args:
request: Preprocessed request from Frontend
Yields:
Response chunks from the downstream vLLM worker
"""
# Extract messages from extra_args (set by Frontend preprocessor)
messages = request.get("extra_args", {}).get("messages", [])
image_urls = extract_image_urls(messages)
if image_urls:
# Process multimodal: download images, compute mm_hash
# Do not reuse request["token_ids"] for MM routing: those are placeholder-level
# tokens from frontend. We need processor-expanded tokens to build block_mm_infos.
# Request payload does not include a rendered template string; extra_args carries
# original messages, so mm_processor reapplies chat template locally.
processed = process_multimodal(
messages=messages,
image_urls=image_urls,
tokenizer=self.tokenizer,
processor=self.processor,
model=self.model,
)
# Build block_mm_infos for MM-aware hash computation
block_mm_infos = build_block_mm_infos(
num_tokens=len(processed.tokens),
block_size=self.block_size,
mm_hashes=processed.mm_hashes,
image_ranges=processed.image_ranges,
)
if block_mm_infos is None:
raise ValueError(
"Failed to build block_mm_infos for multimodal request"
)
routing_tokens = processed.tokens
routing_blocks = (
len(routing_tokens) + self.block_size - 1
) // self.block_size
logger.debug(
f"MM request: {len(routing_tokens)} routing tokens, "
f"{len(image_urls)} images, {routing_blocks} routing blocks"
)
else:
# Text-only: rely on frontend-preprocessed token_ids (ModelInput.Tokens contract)
tokens = request.get("token_ids")
if not tokens:
raise ValueError(
"Missing or empty token_ids in preprocessed request for text-only routing"
)
routing_tokens = tokens
routing_blocks = (
len(routing_tokens) + self.block_size - 1
) // self.block_size
logger.debug(
f"Text request: {len(routing_tokens)} routing tokens, {routing_blocks} routing blocks"
)
# Text-only routing has no multimodal objects; provide per-block None entries.
block_mm_infos = [None] * routing_blocks
# Route and generate through KvRouter with explicit fields.
# We pass:
# - execution payload (token_ids + multi_modal_data)
# - routing payload (mm_routing_info: routing_token_ids + block_mm_infos)
# so generate() can select worker internally while preserving MM correctness.
token_ids = request.get("token_ids")
if not token_ids:
raise ValueError("Missing or empty token_ids in preprocessed request")
mm_routing_info: dict[str, Any] = {
"routing_token_ids": routing_tokens,
"block_mm_infos": block_mm_infos,
}
stream = await self.kv_router.generate(
token_ids=token_ids,
model=request["model"],
stop_conditions=request.get("stop_conditions"),
sampling_options=request.get("sampling_options"),
output_options=request.get("output_options"),
router_config_override=request.get("router_config_override"),
extra_args=request.get("extra_args"),
multi_modal_data=request.get("multi_modal_data"),
mm_routing_info=mm_routing_info,
)
async for response in stream:
yield response
# SPDX-FileCopyrightText: Copyright (c) 2025-2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
"""
Multimodal processing utilities for vLLM MM Router Worker.
Key differences from TRT-LLM version:
- Image loading: PIL + requests/base64 (no TRT-LLM dependency)
- mm_hash: SHA256 of normalized PNG bytes (matches vLLM multi_modal_uuids)
- Token replacement: NOT needed — vLLM keeps the original image_token_id as-is
"""
import base64
import logging
from dataclasses import dataclass
from io import BytesIO
from typing import Any
from urllib.parse import urlparse
import requests
from PIL import Image
from dynamo.vllm.multimodal_utils.hash_utils import compute_mm_uuids_from_images
logger = logging.getLogger(__name__)
# =============================================================================
# Data structures
# =============================================================================
@dataclass
class ProcessedInput:
"""Processed multimodal input."""
tokens: list[int]
mm_hashes: list[int] | None
image_ranges: list[tuple[int, int]] | None # [(start, end), ...] per image
# =============================================================================
# Public functions
# =============================================================================
def extract_image_urls(messages: list[dict]) -> list[str]:
"""Extract image URLs from OpenAI-format messages."""
urls = []
for msg in messages:
content = msg.get("content", [])
if isinstance(content, list):
for part in content:
if part.get("type") == "image_url":
url = part.get("image_url", {}).get("url")
if url:
urls.append(url)
return urls
def process_multimodal(
messages: list[dict],
image_urls: list[str],
tokenizer: Any,
processor: Any,
model: str,
) -> ProcessedInput:
"""
Process multimodal request: load images, get expanded tokens and mm_hashes.
Uses PIL for image loading and hashlib for mm_hash computation.
Unlike TRT-LLM, vLLM keeps original image_token_id (no replacement).
"""
# The preprocessed request does not carry a rendered template string; it carries
# original messages in extra_args, so we must apply chat template again here.
prompt = _build_prompt_with_images(messages, tokenizer, processor)
logger.info(f"Prompt (first 300 chars): {prompt[:300]}")
# Load images as PIL
pil_images = []
for url in image_urls:
pil_img = _load_image(url)
pil_images.append(pil_img)
# Get expanded tokens and image ranges (no token replacement for vLLM)
tokens, image_ranges = _get_expanded_tokens(
prompt, pil_images, tokenizer, processor
)
logger.info(f"Expanded: {len(tokens)} tokens, " f"image_ranges={image_ranges}")
# Compute mm_hashes exactly like vLLM handler's multi_modal_uuids path.
mm_uuids = compute_mm_uuids_from_images(pil_images)
mm_hashes = [int(uuid[:16], 16) for uuid in mm_uuids]
logger.info(f"mm_hashes={mm_hashes}")
return ProcessedInput(tokens=tokens, mm_hashes=mm_hashes, image_ranges=image_ranges)
def build_block_mm_infos(
num_tokens: int,
block_size: int,
mm_hashes: list[int] | None,
image_ranges: list[tuple[int, int]] | None,
) -> list[dict | None] | None:
"""
Build per-block mm_info for routing.
For each block, check which images overlap with it and add their mm_hash.
Assumption: mm_hashes and image_ranges are in the same order as images appear
in the request (which matches their order in the token sequence).
"""
if not mm_hashes or not image_ranges or len(mm_hashes) != len(image_ranges):
return None
num_blocks = (num_tokens + block_size - 1) // block_size
result = []
for block_idx in range(num_blocks):
block_start = block_idx * block_size
block_end = block_start + block_size
# Find images overlapping this block
mm_objects = [
{"mm_hash": mm_hash, "offsets": []}
for mm_hash, (img_start, img_end) in zip(mm_hashes, image_ranges)
if block_end > img_start and block_start < img_end
]
result.append({"mm_objects": mm_objects} if mm_objects else None)
return result
# =============================================================================
# Internal functions
# =============================================================================
def _build_prompt_with_images(
messages: list[dict], tokenizer: Any, processor: Any
) -> str:
"""
Build a prompt that includes image placeholders using the tokenizer's
chat template. This is critical for Qwen2-VL/Qwen2.5-VL models which
need <|vision_start|><|image_pad|>...<|vision_end|> in the prompt for
the processor to expand image tokens correctly.
Raises if chat template cannot be applied. For MM routing correctness, we do
not silently fall back to text-only prompts.
"""
# Try processor first (has the best chat template for multimodal)
if processor is not None and hasattr(processor, "apply_chat_template"):
return processor.apply_chat_template(
messages, tokenize=False, add_generation_prompt=True
)
# Fall back to tokenizer if available
if hasattr(tokenizer, "apply_chat_template"):
return tokenizer.apply_chat_template(
messages, tokenize=False, add_generation_prompt=True
)
raise ValueError("Neither processor nor tokenizer provides apply_chat_template")
def _load_image(url: str) -> Image.Image:
"""
Load an image from URL (http/https or data URI) and return a PIL RGB image.
"""
parsed = urlparse(url)
if parsed.scheme == "data":
# data:image/png;base64,<data>
_, data = parsed.path.split(",", 1)
raw_bytes = base64.b64decode(data)
elif parsed.scheme in ("http", "https"):
response = requests.get(url, timeout=30)
response.raise_for_status()
raw_bytes = response.content
else:
raise ValueError(f"Unsupported URL scheme: {parsed.scheme}")
return Image.open(BytesIO(raw_bytes)).convert("RGB")
def _get_expanded_tokens(
prompt: str,
pil_images: list[Image.Image],
tokenizer: Any,
processor: Any,
) -> tuple[list[int], list[tuple[int, int]] | None]:
"""
Get tokens with visual expansion and find each image's token range.
Unlike TRT-LLM, vLLM keeps the original image_token_id (no replacement).
"""
if processor is None:
return tokenizer.encode(prompt), None
try:
output = processor(
text=[prompt], images=pil_images, return_tensors="pt", padding=True
)
tokens = output["input_ids"][0].tolist()
# Get image_token_id from processor
image_token_id = getattr(processor, "image_token_id", None)
if image_token_id is None:
raise ValueError("processor.image_token_id not found")
# Find contiguous image token ranges (NO replacement for vLLM)
contiguous_ranges = _find_image_token_ranges(tokens, image_token_id)
# Compute tokens per image from processor output
tokens_per_image = _compute_tokens_per_image(output, processor)
# Split ranges according to tokens_per_image
image_ranges = _compute_per_image_ranges(contiguous_ranges, tokens_per_image)
return tokens, image_ranges
except Exception as e:
logger.warning(f"HF processor failed: {e}", exc_info=True)
return tokenizer.encode(prompt), None
def _compute_tokens_per_image(processor_output: dict, processor: Any) -> list[int]:
"""
Compute the number of visual tokens for each image from processor output.
Only Qwen-style processors (Qwen2-VL, Qwen2.5-VL) are supported.
Other model families will raise ValueError.
"""
processor_cls = type(processor).__qualname__
if "qwen" not in processor_cls.lower():
raise NotImplementedError(
f"_compute_tokens_per_image only supports Qwen-style processors "
f"tuples. Got processor class: {processor_cls}"
)
grid_thw = processor_output.get("image_grid_thw")
if grid_thw is None:
raise ValueError("image_grid_thw not found in processor output")
merge_size = getattr(processor.image_processor, "merge_size", 2)
return [int(t * h * w) // (merge_size**2) for t, h, w in grid_thw]
def _find_image_token_ranges(
tokens: list[int], image_token_id: int
) -> list[tuple[int, int]]:
"""
Find all contiguous ranges of image tokens.
Unlike the TRT-LLM version, this does NOT replace tokens — vLLM keeps
the original image_token_id as-is in KV events.
Returns: list of (start, end) ranges for contiguous image token regions.
"""
ranges = []
start = None
for i, t in enumerate(tokens):
if t == image_token_id:
if start is None:
start = i
elif start is not None:
ranges.append((start, i))
start = None
if start is not None:
ranges.append((start, len(tokens)))
if ranges:
logger.info(
f"Found {sum(e - s for s, e in ranges)} image tokens "
f"(id={image_token_id}) in {len(ranges)} range(s)"
)
return ranges
def _compute_per_image_ranges(
contiguous_ranges: list[tuple[int, int]],
tokens_per_image: list[int],
) -> list[tuple[int, int]] | None:
"""
Split contiguous image token ranges by each image's token count.
Example: contiguous_ranges=[(0, 100)], tokens_per_image=[60, 40]
Returns: [(0, 60), (60, 100)] # image 1 at 0-60, image 2 at 60-100
"""
if not contiguous_ranges:
if tokens_per_image:
logger.warning(
f"No image tokens found but {len(tokens_per_image)} images expected"
)
return None
# Greedily assign images to ranges in order
result = []
image_idx = 0
for range_start, range_end in contiguous_ranges:
range_size = range_end - range_start
pos = range_start
consumed = 0
# Consume images that fit entirely in this range
# (a single image's tokens are always contiguous, cannot span ranges)
while image_idx < len(tokens_per_image):
needed = tokens_per_image[image_idx]
if consumed + needed <= range_size:
result.append((pos, pos + needed))
pos += needed
consumed += needed
image_idx += 1
else:
break
# Range must be exactly filled (no leftover image tokens)
if consumed != range_size:
logger.warning(
f"Range size mismatch: consumed {consumed} != range {range_size}"
)
return None
# All images must be consumed
if image_idx != len(tokens_per_image):
logger.warning(f"Not all images mapped: {image_idx} < {len(tokens_per_image)}")
return None
return result
# SPDX-FileCopyrightText: Copyright (c) 2025-2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
"""
MM Router Worker - Multimodal-aware KV cache routing for vLLM.
This worker receives OpenAI-format requests from the frontend, computes
mm_hash for any images, finds the best vLLM worker based on KV cache
overlap, and forwards the request to that worker.
Usage:
python -m examples.backends.vllm.mm_router_worker \
--model Qwen/Qwen2.5-VL-7B-Instruct \
--namespace default \
--component mm_router \
--endpoint generate \
--downstream-component VllmWorker \
--downstream-endpoint generate
"""
import argparse
import asyncio
import logging
import signal
import uvloop
from transformers import AutoProcessor, AutoTokenizer
from dynamo.llm import KvRouter, KvRouterConfig, ModelInput, ModelType, register_llm
from dynamo.runtime import DistributedRuntime, dynamo_worker
from .handler import MMRouterHandler
logger = logging.getLogger(__name__)
def parse_args() -> argparse.Namespace:
"""Parse command line arguments."""
parser = argparse.ArgumentParser(
description="MM Router Worker - Multimodal-aware KV cache routing for vLLM"
)
# Model configuration
parser.add_argument(
"--model",
type=str,
default="Qwen/Qwen2.5-VL-7B-Instruct",
help="Model path or HuggingFace model ID",
)
parser.add_argument(
"--block-size",
type=int,
default=32,
help="KV cache block size",
)
# This worker's endpoint configuration
parser.add_argument(
"--namespace",
type=str,
default="default",
help="Dynamo namespace",
)
parser.add_argument(
"--component",
type=str,
default="mm_router",
help="This worker's component name",
)
parser.add_argument(
"--endpoint",
type=str,
default="generate",
help="This worker's endpoint name",
)
# Downstream vLLM worker configuration
parser.add_argument(
"--downstream-component",
type=str,
default="VllmWorker",
help="Downstream vLLM workers' component name",
)
parser.add_argument(
"--downstream-endpoint",
type=str,
default="generate",
help="Downstream vLLM workers' endpoint name",
)
return parser.parse_args()
async def graceful_shutdown(runtime: DistributedRuntime) -> None:
"""Handle graceful shutdown."""
logger.info("Received shutdown signal, shutting down...")
runtime.shutdown()
logger.info("Shutdown complete")
@dynamo_worker()
async def worker(runtime: DistributedRuntime) -> None:
"""
Main worker function.
Sets up connections to downstream vLLM workers, creates KvRouter
for tracking their cache states, and serves the MM router endpoint.
"""
args = parse_args()
# Set up signal handlers for graceful shutdown
loop = asyncio.get_running_loop()
def signal_handler():
asyncio.create_task(graceful_shutdown(runtime))
for sig in (signal.SIGTERM, signal.SIGINT):
loop.add_signal_handler(sig, signal_handler)
logger.info("MM Router Worker (vLLM) starting...")
logger.info(f"Model: {args.model}")
logger.info(f"This worker: {args.namespace}.{args.component}.{args.endpoint}")
logger.info(
f"Downstream: {args.namespace}.{args.downstream_component}.{args.downstream_endpoint}"
)
# Connect to downstream vLLM workers
downstream_endpoint = (
runtime.namespace(args.namespace)
.component(args.downstream_component)
.endpoint(args.downstream_endpoint)
)
downstream_client = await downstream_endpoint.client()
logger.info("Waiting for downstream vLLM workers...")
instance_ids = await downstream_client.wait_for_instances()
logger.info(f"Found {len(instance_ids)} workers: {list(instance_ids)}")
# Create KvRouter to select workers based on KV overlap
kv_router = KvRouter(
endpoint=downstream_endpoint,
block_size=args.block_size,
kv_router_config=KvRouterConfig(),
)
logger.info("KvRouter created successfully")
# Initialize tokenizer and processor for MM processing
logger.info(f"Loading tokenizer from {args.model}...")
tokenizer = AutoTokenizer.from_pretrained(args.model, trust_remote_code=True)
logger.info(f"Loading HuggingFace processor from {args.model}...")
# TODO: hf AutoProcessor may be slow than the vllm equivalent @zhongdaor
processor = AutoProcessor.from_pretrained(args.model, trust_remote_code=True)
# Create handler
handler = MMRouterHandler(
kv_router=kv_router,
tokenizer=tokenizer,
processor=processor,
model=args.model,
block_size=args.block_size,
)
# Register this worker's endpoint
component = runtime.namespace(args.namespace).component(args.component)
endpoint = component.endpoint(args.endpoint)
# Use ModelInput.Tokens so Frontend preprocesses the request
# Request format: {token_ids, sampling_options, stop_conditions, extra_args: {messages}}
await register_llm(
ModelInput.Tokens,
ModelType.Chat,
endpoint,
args.model,
kv_cache_block_size=args.block_size,
)
logger.info(f"MM Router Worker (vLLM) ready, serving {args.endpoint} endpoint...")
# Serve the endpoint
await endpoint.serve_endpoint(handler.generate)
def main() -> None:
"""Entry point for the MM Router Worker."""
logging.basicConfig(
level=logging.INFO,
format="%(asctime)s - %(name)s - %(levelname)s - %(message)s",
)
uvloop.install()
asyncio.run(worker())
if __name__ == "__main__":
main()
......@@ -416,12 +416,19 @@ impl RouterHandles {
async fn query_prefill_worker(
&self,
tokens: &[u32],
block_mm_infos: Option<&[Option<dynamo_llm::kv_router::protocols::BlockExtraInfo>]>,
update_states: bool,
lora_name: Option<String>,
priority_jump: f64,
) -> Result<u64, QueryRouterResult> {
self.prefill_router
.query_prefill_worker(tokens, update_states, lora_name, priority_jump)
.query_prefill_worker(
tokens,
block_mm_infos,
update_states,
lora_name,
priority_jump,
)
.await
.map(|(worker_id, _dp_rank)| worker_id)
.map_err(|e| {
......@@ -455,7 +462,15 @@ impl RouterHandles {
};
self.decode_router
.find_best_match(None, tokens, config_override.as_ref(), false, None, 0.0)
.find_best_match(
None,
tokens,
None,
config_override.as_ref(),
false,
None,
0.0,
)
.await
.map_err(|e| {
tracing::error!(error = ?e, "Decode query failed");
......@@ -1026,7 +1041,7 @@ pub unsafe extern "C" fn route_request(
let result = handles.runtime.secondary().block_on(async {
let prefill_worker_id = if is_disaggregated {
handles
.query_prefill_worker(tokens, false, None, 0.0)
.query_prefill_worker(tokens, None, false, None, 0.0)
.await?
} else {
0
......
......@@ -813,7 +813,7 @@ impl KvRouter {
}
#[allow(clippy::too_many_arguments)]
#[pyo3(signature = (token_ids, model, stop_conditions=None, sampling_options=None, output_options=None, router_config_override=None, worker_id=None, dp_rank=None, extra_args=None))]
#[pyo3(signature = (token_ids, model, stop_conditions=None, sampling_options=None, output_options=None, router_config_override=None, worker_id=None, dp_rank=None, extra_args=None, block_mm_infos=None, multi_modal_data=None, mm_routing_info=None))]
fn generate<'p>(
&self,
py: Python<'p>,
......@@ -826,6 +826,9 @@ impl KvRouter {
worker_id: Option<WorkerId>,
dp_rank: Option<DpRank>,
extra_args: Option<PyObject>,
block_mm_infos: Option<PyObject>,
multi_modal_data: Option<PyObject>,
mm_routing_info: Option<PyObject>,
) -> PyResult<Bound<'p, PyAny>> {
// Depythonize the options with defaults
let stop_conditions: StopConditions = if let Some(obj) = stop_conditions {
......@@ -859,6 +862,32 @@ impl KvRouter {
None
};
let block_mm_infos: Option<Vec<Option<BlockExtraInfo>>> = if let Some(obj) = block_mm_infos
{
Some(depythonize(obj.bind(py)).map_err(to_pyerr)?)
} else {
None
};
let multi_modal_data: Option<llm_rs::protocols::common::preprocessor::MultimodalDataMap> =
if let Some(obj) = multi_modal_data {
Some(depythonize(obj.bind(py)).map_err(to_pyerr)?)
} else {
None
};
let mm_routing_info: Option<llm_rs::protocols::common::preprocessor::MmRoutingInfo> =
if let Some(obj) = mm_routing_info {
Some(depythonize(obj.bind(py)).map_err(to_pyerr)?)
} else {
block_mm_infos.map(
|infos| llm_rs::protocols::common::preprocessor::MmRoutingInfo {
routing_token_ids: token_ids.clone(),
block_mm_infos: infos,
},
)
};
// Create tracker to capture worker routing info from KvRouter
let tracker = Arc::new(RequestTracker::new());
......@@ -872,6 +901,8 @@ impl KvRouter {
.sampling_options(sampling_options)
.output_options(output_options)
.router_config_override(router_config_override)
.multi_modal_data(multi_modal_data)
.mm_routing_info(mm_routing_info)
.extra_args(extra_args)
.tracker(Some(tracker.clone()));
......@@ -914,13 +945,14 @@ impl KvRouter {
Self::process_request_to_stream(py, self.inner.clone(), request, Some(tracker))
}
#[pyo3(signature = (token_ids, router_config_override=None, request_id=None))]
#[pyo3(signature = (token_ids, router_config_override=None, request_id=None, block_mm_infos=None))]
fn best_worker<'p>(
&self,
py: Python<'p>,
token_ids: Vec<u32>,
router_config_override: Option<PyObject>,
request_id: Option<String>,
block_mm_infos: Option<PyObject>,
) -> PyResult<Bound<'p, PyAny>> {
let router_config_override = if let Some(obj) = router_config_override {
let override_config: llm_rs::kv_router::RouterConfigOverride =
......@@ -930,6 +962,13 @@ impl KvRouter {
None
};
let block_mm_infos: Option<Vec<Option<BlockExtraInfo>>> = if let Some(obj) = block_mm_infos
{
Some(depythonize(obj.bind(py)).map_err(to_pyerr)?)
} else {
None
};
let chooser = self.inner.chooser.clone();
let update_states = request_id.is_some();
......@@ -938,6 +977,7 @@ impl KvRouter {
.find_best_match(
request_id.as_deref(),
&token_ids,
block_mm_infos.as_deref(),
router_config_override.as_ref(),
update_states,
None, // lora_name not exposed in Python API yet
......
......@@ -1360,6 +1360,10 @@ class KvRouter:
router_config_override: Optional[JsonLike] = None,
worker_id: Optional[int] = None,
dp_rank: Optional[int] = None,
extra_args: Optional[JsonLike] = None,
block_mm_infos: Optional[List[Optional[Dict[str, Any]]]] = None,
multi_modal_data: Optional[JsonLike] = None,
mm_routing_info: Optional[JsonLike] = None,
) -> AsyncIterator[JsonLike]:
"""
Generate text using the KV-aware router.
......@@ -1378,6 +1382,16 @@ class KvRouter:
the request will be routed to the specific (worker_id, dp_rank) pair.
If only dp_rank is set, the router will select the best worker but
force routing to the specified dp_rank.
extra_args: Optional extra request arguments to include in the
PreprocessedRequest.
block_mm_infos: Optional block-level multimodal metadata aligned to
request blocks. Backward-compatible shortcut; this is
converted to mm_routing_info with routing_token_ids=token_ids.
multi_modal_data: Optional multimodal payload map to preserve image/video
data for downstream model execution.
mm_routing_info: Optional structured routing-only multimodal payload
(e.g., {"routing_token_ids": [...], "block_mm_infos": [...]})
used by router selection without changing execution token_ids.
Returns:
An async iterator yielding generation responses
......@@ -1396,6 +1410,7 @@ class KvRouter:
token_ids: List[int],
router_config_override: Optional[JsonLike] = None,
request_id: Optional[str] = None,
block_mm_infos: Optional[List[Optional[Dict[str, Any]]]] = None,
) -> Tuple[int, int, int]:
"""
Find the best matching worker for the given tokens.
......@@ -1406,6 +1421,9 @@ class KvRouter:
request_id: Optional request ID. If provided, router states will be updated
to track this request (active blocks, lifecycle events). If not
provided, this is a query-only operation that doesn't affect state.
block_mm_infos: Optional block-level multimodal metadata aligned to request
blocks. When provided, this is used in block hash computation
to enable MM-aware worker selection.
Returns:
A tuple of (worker_id, dp_rank, overlap_blocks) where:
......
......@@ -114,6 +114,8 @@ pub enum RouterRequest {
#[serde(rename = "new")]
New {
tokens: Vec<Token>,
#[serde(default, skip_serializing_if = "Option::is_none")]
block_mm_infos: Option<Vec<Option<BlockExtraInfo>>>,
},
MarkPrefill,
MarkFree,
......@@ -121,7 +123,10 @@ pub enum RouterRequest {
impl Default for RouterRequest {
fn default() -> Self {
RouterRequest::New { tokens: vec![] }
RouterRequest::New {
tokens: vec![],
block_mm_infos: None,
}
}
}
......
......@@ -50,8 +50,8 @@ use crate::{
approx::PruneConfig,
indexer::{GetWorkersRequest, KvIndexer, KvIndexerInterface, KvRouterError},
protocols::{
DpRank, LocalBlockHash, OverlapScores, RouterEvent, RouterRequest, RouterResponse,
TokensWithHashes, WorkerId, WorkerSelectionResult, WorkerWithDpRank,
BlockExtraInfo, DpRank, LocalBlockHash, OverlapScores, RouterEvent, RouterRequest,
RouterResponse, TokensWithHashes, WorkerId, WorkerSelectionResult, WorkerWithDpRank,
compute_block_hash_for_seq,
},
scheduler::{KvScheduler, KvSchedulerError, PotentialLoad, SchedulingRequest},
......@@ -371,6 +371,7 @@ impl KvRouter {
&self,
context_id: Option<&str>,
tokens: &[u32],
block_mm_infos: Option<&[Option<BlockExtraInfo>]>,
router_config_override: Option<&RouterConfigOverride>,
update_states: bool,
lora_name: Option<String>,
......@@ -385,7 +386,7 @@ impl KvRouter {
let isl_tokens = tokens.len();
let block_hashes = tracing::info_span!("kv_router.compute_block_hashes")
.in_scope(|| compute_block_hash_for_seq(tokens, self.block_size, None));
.in_scope(|| compute_block_hash_for_seq(tokens, self.block_size, block_mm_infos));
let hash_elapsed = start.elapsed();
let overlap_scores = self
......@@ -566,9 +567,20 @@ impl AsyncEngine<SingleIn<RouterRequest>, ManyOut<Annotated<RouterResponse>>, Er
let context_id = ctx.context().id().to_string();
// Handle different request types
let response = match request {
RouterRequest::New { tokens } => {
RouterRequest::New {
tokens,
block_mm_infos,
} => {
let (best_worker, overlap_blocks) = self
.find_best_match(Some(&context_id), &tokens, None, true, None, 0.0)
.find_best_match(
Some(&context_id),
&tokens,
block_mm_infos.as_deref(),
None,
true,
None,
0.0,
)
.await?;
RouterResponse::New {
......
......@@ -21,7 +21,7 @@ use dynamo_runtime::{
use crate::{
discovery::ModelManager,
kv_router::{KvPushRouter, KvRouterConfig, RouterConfigOverride},
kv_router::{KvPushRouter, KvRouterConfig, RouterConfigOverride, protocols::BlockExtraInfo},
protocols::common::llm_backend::{LLMEngineOutput, PreprocessedRequest},
protocols::common::preprocessor::{BootstrapInfo, PrefillResult},
protocols::common::timing::{RequestPhase, RequestTracker, WORKER_TYPE_PREFILL},
......@@ -103,6 +103,19 @@ pub struct PrefillRouter {
}
impl PrefillRouter {
fn routing_inputs(req: &PreprocessedRequest) -> (&[u32], Option<&[Option<BlockExtraInfo>]>) {
if let Some(mm_routing_info) = req.mm_routing_info.as_ref() {
let routing_tokens = mm_routing_info.routing_token_ids.as_slice();
if !routing_tokens.is_empty() {
return (
routing_tokens,
Some(mm_routing_info.block_mm_infos.as_slice()),
);
}
}
(&req.token_ids, None)
}
/// Create a disabled prefill router that will never activate (passthrough only)
pub fn disabled(
model_manager: Arc<ModelManager>,
......@@ -285,8 +298,15 @@ impl PrefillRouter {
.as_ref()
.and_then(|r| r.priority_jump)
.unwrap_or(0.0);
let (routing_token_ids, block_mm_infos) = Self::routing_inputs(req);
match self
.query_prefill_worker(&req.token_ids, false, lora_name, priority_jump)
.query_prefill_worker(
routing_token_ids,
block_mm_infos,
false,
lora_name,
priority_jump,
)
.await
{
Ok((worker_id, dp_rank)) => (worker_id, dp_rank),
......@@ -475,6 +495,7 @@ impl PrefillRouter {
pub async fn query_prefill_worker(
&self,
token_ids: &[u32],
block_mm_infos: Option<&[Option<BlockExtraInfo>]>,
update_states: bool,
lora_name: Option<String>,
priority_jump: f64,
......@@ -491,6 +512,7 @@ impl PrefillRouter {
.find_best_match(
None,
token_ids,
block_mm_infos,
None,
update_states,
lora_name,
......
......@@ -787,6 +787,58 @@ enum RawKvEvent {
AllBlocksCleared,
}
/// Parse MM hash from extra_keys string:
/// - Only accept canonical vLLM MM identifiers (64-char hex digest)
/// - Convert by taking the first 16 hex chars as u64
fn parse_mm_hash_from_extra_key(s: &str) -> Option<u64> {
// extra_keys mixes MM identifiers with LoRA/cache_salt/prompt-embed metadata.
// Only MM identifiers should be mapped into BlockExtraInfo.
if s.len() == 64 && s.chars().all(|c| c.is_ascii_hexdigit()) {
return u64::from_str_radix(&s[..16], 16).ok();
}
None
}
/// Convert vLLM BlockStored extra_keys to block-level MM infos.
/// extra_keys is a list aligned with blocks:
/// - None => no MM content in that block
/// - ["hash1", "hash2", ...] => one or more MM objects in that block
fn extra_keys_to_block_mm_infos(
extra_keys: Option<Vec<Option<Vec<String>>>>,
) -> Option<Vec<Option<BlockExtraInfo>>> {
let extra_keys = extra_keys?;
if extra_keys.is_empty() {
return None;
}
let infos: Vec<Option<BlockExtraInfo>> = extra_keys
.into_iter()
.map(|block_keys| {
let mm_objects: Vec<BlockMmObjectInfo> = block_keys
.unwrap_or_default()
.iter()
.filter_map(|key| parse_mm_hash_from_extra_key(key))
.map(|mm_hash| BlockMmObjectInfo {
mm_hash,
offsets: vec![], // extra_keys does not carry offsets today
})
.collect();
if mm_objects.is_empty() {
None
} else {
Some(BlockExtraInfo { mm_objects })
}
})
.collect();
if infos.iter().all(|i| i.is_none()) {
return None;
}
Some(infos)
}
/// Our producers use msgspec with `tag=True` and `array_like=True`, which
/// encodes each event as either a tagged map or a tagged tuple. To be tolerant of
/// additional fields that may be appended in the future, we implement a custom
......@@ -824,6 +876,7 @@ impl<'de> Visitor<'de> for RawKvEventVisitor {
let mut lora_id: Option<Option<u64>> = None;
let mut medium: Option<Option<String>> = None;
let mut lora_name: Option<Option<String>> = None;
let mut extra_keys: Option<Option<Vec<Option<Vec<String>>>>> = None;
let mut block_mm_infos: Option<Option<Vec<Option<BlockExtraInfo>>>> = None;
while let Some(key) = map.next_key::<String>()? {
......@@ -852,6 +905,9 @@ impl<'de> Visitor<'de> for RawKvEventVisitor {
"lora_name" => {
lora_name = Some(map.next_value()?);
}
"extra_keys" => {
extra_keys = Some(map.next_value()?);
}
"block_mm_infos" => {
block_mm_infos = Some(map.next_value()?);
}
......@@ -868,6 +924,9 @@ impl<'de> Visitor<'de> for RawKvEventVisitor {
let token_ids = token_ids.ok_or_else(|| de::Error::missing_field("token_ids"))?;
let block_size =
block_size.ok_or_else(|| de::Error::missing_field("block_size"))?;
let block_mm_infos = block_mm_infos
.unwrap_or(None)
.or_else(|| extra_keys_to_block_mm_infos(extra_keys.unwrap_or(None)));
Ok(RawKvEvent::BlockStored {
block_hashes,
parent_block_hash: parent_block_hash.unwrap_or(None),
......@@ -876,7 +935,7 @@ impl<'de> Visitor<'de> for RawKvEventVisitor {
lora_id: lora_id.unwrap_or(None),
medium: medium.unwrap_or(None),
lora_name: lora_name.unwrap_or(None),
block_mm_infos: block_mm_infos.unwrap_or(None),
block_mm_infos,
})
}
Some("BlockRemoved") => {
......@@ -923,11 +982,16 @@ impl<'de> Visitor<'de> for RawKvEventVisitor {
let lora_id: Option<u64> = seq.next_element()?.unwrap_or(None);
let medium: Option<String> = seq.next_element()?.unwrap_or(None);
let lora_name: Option<String> = seq.next_element()?.unwrap_or(None);
let extra_keys: Option<Vec<Option<Vec<String>>>> =
seq.next_element()?.unwrap_or(None);
let block_mm_infos: Option<Vec<Option<BlockExtraInfo>>> =
seq.next_element()?.unwrap_or(None);
while seq.next_element::<IgnoredAny>()?.is_some() {}
let block_mm_infos =
block_mm_infos.or_else(|| extra_keys_to_block_mm_infos(extra_keys));
Ok(RawKvEvent::BlockStored {
block_hashes,
parent_block_hash,
......@@ -1206,6 +1270,114 @@ mod test_event_processing {
let out = convert_event(raw_evt, 1, kv_block_size, 0, &Arc::new(AtomicU32::new(0)));
assert!(matches!(out.data, KvCacheEventData::Cleared));
}
#[test]
fn test_parse_mm_hash_from_extra_key() {
assert_eq!(
parse_mm_hash_from_extra_key(
"0123456789abcdef00112233445566778899aabbccddeefffedcba9876543210"
),
Some(0x0123_4567_89ab_cdef)
);
assert_eq!(parse_mm_hash_from_extra_key("123"), None);
assert_eq!(parse_mm_hash_from_extra_key("not_a_hash"), None);
}
#[test]
fn test_extra_keys_to_block_mm_infos() {
let mm_hash =
"0123456789abcdef00112233445566778899aabbccddeefffedcba9876543210".to_string();
let infos = extra_keys_to_block_mm_infos(Some(vec![
Some(vec![mm_hash.clone()]),
None,
Some(vec!["invalid".to_string(), mm_hash]),
]))
.expect("expected parsed MM infos");
assert_eq!(infos.len(), 3);
assert_eq!(
infos[0].as_ref().unwrap().mm_objects[0].mm_hash,
0x0123_4567_89ab_cdef
);
assert!(infos[1].is_none());
assert_eq!(
infos[2].as_ref().unwrap().mm_objects[0].mm_hash,
0x0123_4567_89ab_cdef
);
}
#[test]
fn test_seq_block_stored_field8_supports_extra_keys() {
let mm_hash =
"0123456789abcdef00112233445566778899aabbccddeefffedcba9876543210".to_string();
let extra_keys_payload = rmps::to_vec(&(
"BlockStored",
vec![10_u64],
None::<u64>,
vec![1_u32, 2, 3, 4],
4_usize,
None::<u64>,
None::<String>,
None::<String>,
vec![Some(vec![mm_hash])],
))
.unwrap();
let extra_keys_event: RawKvEvent = rmps::from_slice(&extra_keys_payload).unwrap();
let RawKvEvent::BlockStored {
lora_name,
block_mm_infos,
..
} = extra_keys_event
else {
panic!("expected BlockStored");
};
assert!(lora_name.is_none());
assert_eq!(
block_mm_infos.unwrap()[0].as_ref().unwrap().mm_objects[0].mm_hash,
0x0123_4567_89ab_cdef
);
}
#[test]
fn test_map_block_stored_supports_extra_keys() {
#[derive(serde::Serialize)]
struct MapBlockStoredEvent {
#[serde(rename = "type")]
event_type: &'static str,
block_hashes: Vec<u64>,
parent_block_hash: Option<u64>,
token_ids: Vec<u32>,
block_size: usize,
lora_id: Option<u64>,
medium: Option<String>,
lora_name: Option<String>,
extra_keys: Option<Vec<Option<Vec<String>>>>,
}
let payload = rmps::to_vec(&MapBlockStoredEvent {
event_type: "BlockStored",
block_hashes: vec![10],
parent_block_hash: None,
token_ids: vec![1, 2, 3, 4],
block_size: 4,
lora_id: None,
medium: Some("GPU".to_string()),
lora_name: None,
extra_keys: Some(vec![Some(vec![
"0123456789abcdef00112233445566778899aabbccddeefffedcba9876543210".to_string(),
])]),
})
.unwrap();
let event: RawKvEvent = rmps::from_slice(&payload).unwrap();
let RawKvEvent::BlockStored { block_mm_infos, .. } = event else {
panic!("expected BlockStored");
};
assert_eq!(
block_mm_infos.unwrap()[0].as_ref().unwrap().mm_objects[0].mm_hash,
0x0123_4567_89ab_cdef
);
}
}
#[cfg(test)]
......
......@@ -19,7 +19,7 @@ use crate::{
kv_router::{
KvRouter,
metrics::RouterRequestMetrics,
protocols::{TokensWithHashes, WorkerWithDpRank},
protocols::{BlockExtraInfo, TokensWithHashes, WorkerWithDpRank},
},
preprocessor::PreprocessedRequest,
protocols::common::{
......@@ -108,6 +108,21 @@ impl KvPushRouter {
KvPushRouter { inner, chooser }
}
fn routing_inputs(
request: &PreprocessedRequest,
) -> (&[u32], Option<&[Option<BlockExtraInfo>]>) {
if let Some(mm_routing_info) = request.mm_routing_info.as_ref() {
let routing_tokens = mm_routing_info.routing_token_ids.as_slice();
if !routing_tokens.is_empty() {
return (
routing_tokens,
Some(mm_routing_info.block_mm_infos.as_slice()),
);
}
}
(&request.token_ids, None)
}
/// Select a worker for the request, either using a preselected worker or finding the best match.
///
/// When `is_query_only` is false, this also registers the request with the scheduler via `add_request`.
......@@ -123,6 +138,7 @@ impl KvPushRouter {
let priority_jump = routing.and_then(|r| r.priority_jump).unwrap_or(0.0);
let dp_rank = routing.and_then(|r| r.dp_rank).unwrap_or(0);
let expected_output_tokens = routing.and_then(|r| r.expected_output_tokens);
let (routing_token_ids, block_mm_infos) = Self::routing_inputs(request);
// Get pre-selected worker based on phase, with backend_instance_id as fallback
let preselected_id = match phase {
......@@ -140,7 +156,8 @@ impl KvPushRouter {
.chooser
.find_best_match(
Some(context_id),
&request.token_ids,
routing_token_ids,
block_mm_infos,
request.router_config_override.as_ref(),
!is_query_only,
lora_name,
......@@ -148,6 +165,27 @@ impl KvPushRouter {
)
.await?;
if !is_query_only {
let total_blocks = routing_token_ids
.len()
.div_ceil(self.chooser.block_size() as usize);
// NOTE: tests/mm_router/test_vllm_mm_router_e2e.py parses this log line.
// Keep the "[ROUTING] ... with X/Y blocks overlap" shape stable unless
// router tests are updated together.
tracing::debug!(
request_id = %context_id,
worker_id = best_worker.worker_id,
dp_rank = best_worker.dp_rank,
overlap_blocks = overlap_amount,
total_blocks = total_blocks,
"[ROUTING] Best: worker_{} dp_rank={} with {}/{} blocks overlap",
best_worker.worker_id,
best_worker.dp_rank,
overlap_amount,
total_blocks,
);
}
return Ok(WorkerSelection {
instance_id: best_worker.worker_id,
dp_rank: best_worker.dp_rank,
......@@ -165,14 +203,14 @@ impl KvPushRouter {
let worker = WorkerWithDpRank::new(id, dp_rank);
let overlap_blocks = self
.chooser
.get_overlap_blocks(&request.token_ids, worker)
.get_overlap_blocks(routing_token_ids, worker)
.await?;
if !is_query_only {
self.chooser
.add_request(
context_id.to_string(),
&request.token_ids,
routing_token_ids,
overlap_blocks,
expected_output_tokens,
worker,
......@@ -275,10 +313,11 @@ impl AsyncEngine<SingleIn<PreprocessedRequest>, ManyOut<Annotated<LLMEngineOutpu
let request_metrics =
RouterRequestMetrics::from_component(self.chooser.client().endpoint.component());
if let Some(ref tracker) = request.tracker {
let isl_blocks = request.token_ids.len().div_ceil(block_size);
let (routing_token_ids, _) = Self::routing_inputs(&request);
let isl_blocks = routing_token_ids.len().div_ceil(block_size);
tracker.record_kv_hit(overlap_amount, isl_blocks);
tracker.record_isl(
request.token_ids.len(),
routing_token_ids.len(),
overlap_amount as usize * block_size,
);
tracker.record_worker_full(instance_id, dp_rank, self.chooser.worker_type());
......
......@@ -9,6 +9,7 @@ use serde::{Deserialize, Serialize};
use super::timing::RequestTracker;
use super::{OutputOptions, SamplingOptions, StopConditions};
use crate::kv_router::RouterConfigOverride;
use crate::kv_router::protocols::BlockExtraInfo;
use crate::preprocessor::media::RdmaMediaDataDescriptor;
use crate::protocols::TokenIdType;
......@@ -72,6 +73,20 @@ pub struct PrefillResult {
pub prompt_tokens_details: Option<dynamo_async_openai::types::PromptTokensDetails>,
}
/// Optional multimodal routing-only data.
/// This is used by the router to compute overlaps on an alternate token sequence
/// (for example, MM-expanded tokens) without changing execution token_ids.
#[derive(Serialize, Deserialize, Debug, Clone, Default, Builder)]
#[builder(default)]
pub struct MmRoutingInfo {
/// Token IDs to use for routing overlap computation.
pub routing_token_ids: Vec<TokenIdType>,
/// Block-level multimodal metadata aligned with routing_token_ids blocks.
/// Use `None` entries for blocks without multimodal objects.
pub block_mm_infos: Vec<Option<BlockExtraInfo>>,
}
#[derive(Serialize, Deserialize, Debug, Clone)]
pub enum MultimodalData {
Url(url::Url),
......@@ -102,6 +117,11 @@ pub struct PreprocessedRequest {
#[serde(default, skip_serializing_if = "Option::is_none")]
pub multi_modal_data: Option<MultimodalDataMap>,
/// Optional multimodal routing-only fields (separate from execution payload).
#[builder(default)]
#[serde(default, skip_serializing_if = "Option::is_none")]
pub mm_routing_info: Option<MmRoutingInfo>,
/// StopConditions are conditions that the inference engine will use to stop generation.
pub stop_conditions: StopConditions,
......
This diff is collapsed.
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment