Unverified Commit 815b1291 authored by zhongdaor-nv's avatar zhongdaor-nv Committed by GitHub
Browse files

feat: mm aware routing for vllm (#6235)


Signed-off-by: zhongdaor <zhongdaor@nvidia.com>
Signed-off-by: zhongdaor-nv <zhongdaor@nvidia.com>
parent 0abebe38
...@@ -37,6 +37,7 @@ from dynamo.llm import ( ...@@ -37,6 +37,7 @@ from dynamo.llm import (
from dynamo.runtime.logging import configure_dynamo_logging from dynamo.runtime.logging import configure_dynamo_logging
from .engine_monitor import VllmEngineMonitor from .engine_monitor import VllmEngineMonitor
from .multimodal_utils.hash_utils import compute_mm_uuids_from_images
from .multimodal_utils.image_loader import ImageLoader from .multimodal_utils.image_loader import ImageLoader
# Multimodal data dictionary keys # Multimodal data dictionary keys
...@@ -48,6 +49,27 @@ DECODED_VARIANT_KEY: Final = "Decoded" ...@@ -48,6 +49,27 @@ DECODED_VARIANT_KEY: Final = "Decoded"
configure_dynamo_logging() configure_dynamo_logging()
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
def _compute_mm_uuids(
multi_modal_data: Dict[str, Any] | None
) -> Dict[str, list[str]] | None:
"""
Compute multi_modal_uuids from multi_modal_data.
Each image gets a SHA256 hex digest as its UUID, ensuring consistent
hashing across the MM Router, vLLM handler, and Rust KV publisher.
"""
if not multi_modal_data or "image" not in multi_modal_data:
return None
images = multi_modal_data["image"]
if not isinstance(images, list):
images = [images]
if not images:
return None
uuids = compute_mm_uuids_from_images(images)
return {"image": uuids}
# LoRAManager singleton - initialized lazily when DYN_LORA_ENABLED is set # LoRAManager singleton - initialized lazily when DYN_LORA_ENABLED is set
# None = not yet initialized, False = disabled/failed, LoRAManager = initialized # None = not yet initialized, False = disabled/failed, LoRAManager = initialized
_lora_manager = None _lora_manager = None
...@@ -1010,12 +1032,17 @@ class BaseWorkerHandler(ABC): ...@@ -1010,12 +1032,17 @@ class BaseWorkerHandler(ABC):
"token_ids": [], "token_ids": [],
}, },
) )
else: # Normal path: use token IDs
# Normal path: use token IDs mm_uuids = _compute_mm_uuids(multi_modal_data)
prompt = TokensPrompt( prompt_kwargs = dict[str, Any](
prompt_token_ids=request["token_ids"], multi_modal_data=multi_modal_data prompt_token_ids=request["token_ids"],
) multi_modal_data=multi_modal_data,
return prompt, embedding_sequence_length, None )
if mm_uuids is not None:
prompt_kwargs["multi_modal_uuids"] = mm_uuids
prompt = TokensPrompt(**prompt_kwargs)
return prompt, embedding_sequence_length, None
@staticmethod @staticmethod
def _build_completion_usage( def _build_completion_usage(
......
# SPDX-FileCopyrightText: Copyright (c) 2025-2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
import hashlib
import io
import logging
from typing import Any, Sequence
logger = logging.getLogger(__name__)
def image_to_bytes(img: Any) -> bytes:
    """Return the byte string used to hash an image.

    Args:
        img: ``bytes`` (returned unchanged), a ``PIL.Image.Image``, or a
            ``numpy.ndarray`` (both re-encoded as PNG).

    Returns:
        The bytes to feed into the hash function.

    Raises:
        TypeError: If ``img`` is none of the supported types.
    """
    # Raw bytes need no conversion, so answer before importing Pillow; this
    # keeps the bytes path usable even where Pillow is not installed.
    if isinstance(img, bytes):
        return img

    from PIL import Image

    if not isinstance(img, Image.Image):
        # Frontend-decoding can provide image tensors as numpy arrays.
        # numpy is optional: only the import itself is guarded, so genuine
        # conversion errors below are not mistaken for "numpy missing".
        try:
            import numpy as np
        except ImportError:
            np = None
        if np is None or not isinstance(img, np.ndarray):
            raise TypeError(f"Unsupported image type for hashing: {type(img)}")
        img = Image.fromarray(img)

    buf = io.BytesIO()
    img.save(buf, format="PNG")
    return buf.getvalue()
def compute_mm_uuids_from_images(images: Sequence[Any]) -> list[str]:
    """
    Return one SHA256 hex digest per image, in input order.

    Each image is first converted to bytes with ``image_to_bytes``.
    """
    return [hashlib.sha256(image_to_bytes(image)).hexdigest() for image in images]
...@@ -49,6 +49,7 @@ class TestVllmKvEventsApi: ...@@ -49,6 +49,7 @@ class TestVllmKvEventsApi:
5. lora_id 5. lora_id
6. medium 6. medium
7. lora_name (added in vLLM 0.14.0) 7. lora_name (added in vLLM 0.14.0)
8. extra_keys (contains MM info, cache_salt, etc.)
If vLLM adds/removes/reorders fields, this test will fail. If vLLM adds/removes/reorders fields, this test will fail.
""" """
...@@ -60,6 +61,7 @@ class TestVllmKvEventsApi: ...@@ -60,6 +61,7 @@ class TestVllmKvEventsApi:
"lora_id", "lora_id",
"medium", "medium",
"lora_name", "lora_name",
"extra_keys",
) )
actual_fields = BlockStored.__struct_fields__ actual_fields = BlockStored.__struct_fields__
...@@ -146,6 +148,7 @@ class TestVllmKvEventsApi: ...@@ -146,6 +148,7 @@ class TestVllmKvEventsApi:
lora_id=None, lora_id=None,
medium="GPU", medium="GPU",
lora_name=None, lora_name=None,
extra_keys=None,
) )
encoded = msgspec.msgpack.encode(event) encoded = msgspec.msgpack.encode(event)
...@@ -157,9 +160,9 @@ class TestVllmKvEventsApi: ...@@ -157,9 +160,9 @@ class TestVllmKvEventsApi:
decoded[0] == "BlockStored" decoded[0] == "BlockStored"
), f"Expected tag 'BlockStored', got {decoded[0]}" ), f"Expected tag 'BlockStored', got {decoded[0]}"
# Verify field count (tag + 7 fields = 8 elements) # Verify field count (tag + 8 fields = 9 elements)
assert len(decoded) == 8, ( assert len(decoded) == 9, (
f"Expected 8 elements (tag + 7 fields), got {len(decoded)}.\n" f"Expected 9 elements (tag + 8 fields), got {len(decoded)}.\n"
f"Decoded: {decoded}\n" f"Decoded: {decoded}\n"
f"If field count changed, update Rust deserializers." f"If field count changed, update Rust deserializers."
) )
...@@ -172,3 +175,4 @@ class TestVllmKvEventsApi: ...@@ -172,3 +175,4 @@ class TestVllmKvEventsApi:
assert decoded[5] is None, f"lora_id at wrong position: {decoded[5]}" assert decoded[5] is None, f"lora_id at wrong position: {decoded[5]}"
assert decoded[6] == "GPU", f"medium at wrong position: {decoded[6]}" assert decoded[6] == "GPU", f"medium at wrong position: {decoded[6]}"
assert decoded[7] is None, f"lora_name at wrong position: {decoded[7]}" assert decoded[7] is None, f"lora_name at wrong position: {decoded[7]}"
assert decoded[8] is None, f"extra_keys at wrong position: {decoded[8]}"
# SPDX-FileCopyrightText: Copyright (c) 2025-2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
"""
MM Router Worker - Multimodal-aware KV cache routing for vLLM.
This worker sits between the frontend and vLLM workers, computing mm_hash
for images and routing requests to the worker with the best KV cache overlap.
"""
from .handler import MMRouterHandler
from .mm_processor import ProcessedInput, build_block_mm_infos
__all__ = [
"MMRouterHandler",
"ProcessedInput",
"build_block_mm_infos",
]
# SPDX-FileCopyrightText: Copyright (c) 2025-2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
"""Entry point for running MM Router Worker as a module."""
from .mm_router_worker import main
if __name__ == "__main__":
main()
# SPDX-FileCopyrightText: Copyright (c) 2025-2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
"""
MM Router Handler - Routes requests to best vLLM worker based on KV cache overlap.
"""
import logging
from typing import Any, AsyncGenerator
from dynamo.llm import KvRouter
from dynamo.runtime.logging import configure_dynamo_logging
from .mm_processor import build_block_mm_infos, extract_image_urls, process_multimodal
configure_dynamo_logging()
logger = logging.getLogger(__name__)
class MMRouterHandler:
    """
    Handler that computes mm_hash for multimodal requests and routes
    to the best vLLM worker based on KV cache overlap.
    """

    def __init__(
        self,
        kv_router: KvRouter,
        tokenizer: Any,
        processor: Any,
        model: str,
        block_size: int,
    ):
        """
        Initialize the MM Router Handler.

        Args:
            kv_router: KvRouter for KV-aware worker selection and routing
            tokenizer: HuggingFace AutoTokenizer
            processor: HuggingFace AutoProcessor (optional)
            model: Model path/name
            block_size: KV cache block size
        """
        self.kv_router = kv_router
        self.tokenizer = tokenizer
        self.processor = processor
        self.model = model
        self.block_size = block_size

    def _count_blocks(self, num_tokens: int) -> int:
        """Return how many KV blocks are needed for ``num_tokens`` (ceil division)."""
        return (num_tokens + self.block_size - 1) // self.block_size

    async def generate(self, request: dict) -> AsyncGenerator[dict, None]:
        """
        Main entry point - receives request, computes routing, forwards to best worker.

        The request format (after Frontend preprocessing with ModelInput.Tokens):
        {
            "token_ids": [...],
            "sampling_options": {...},
            "stop_conditions": {...},
            "extra_args": {"messages": [...]}
        }

        Args:
            request: Preprocessed request from Frontend

        Yields:
            Response chunks from the downstream vLLM worker

        Raises:
            ValueError: If ``token_ids`` is missing or empty, or if
                block_mm_infos cannot be built for a multimodal request.
            KeyError: If the request has no ``"model"`` entry.
        """
        # Extract messages from extra_args (set by Frontend preprocessor).
        # Both "extra_args" and "messages" may be absent *or* explicitly null,
        # so normalise each with `or` instead of relying on .get() defaults.
        extra_args = request.get("extra_args") or {}
        messages = extra_args.get("messages") or []
        image_urls = extract_image_urls(messages)

        # The execution payload always needs token_ids; check before doing any
        # image download / processor work so invalid requests fail fast.
        token_ids = request.get("token_ids")

        if image_urls:
            if not token_ids:
                raise ValueError("Missing or empty token_ids in preprocessed request")

            # Process multimodal: download images, compute mm_hash
            # Do not reuse request["token_ids"] for MM routing: those are placeholder-level
            # tokens from frontend. We need processor-expanded tokens to build block_mm_infos.
            # Request payload does not include a rendered template string; extra_args carries
            # original messages, so mm_processor reapplies chat template locally.
            processed = process_multimodal(
                messages=messages,
                image_urls=image_urls,
                tokenizer=self.tokenizer,
                processor=self.processor,
                model=self.model,
            )

            # Build block_mm_infos for MM-aware hash computation
            block_mm_infos = build_block_mm_infos(
                num_tokens=len(processed.tokens),
                block_size=self.block_size,
                mm_hashes=processed.mm_hashes,
                image_ranges=processed.image_ranges,
            )
            if block_mm_infos is None:
                raise ValueError(
                    "Failed to build block_mm_infos for multimodal request"
                )

            routing_tokens = processed.tokens
            routing_blocks = self._count_blocks(len(routing_tokens))
            logger.debug(
                "MM request: %d routing tokens, %d images, %d routing blocks",
                len(routing_tokens),
                len(image_urls),
                routing_blocks,
            )
        else:
            # Text-only: rely on frontend-preprocessed token_ids (ModelInput.Tokens contract)
            if not token_ids:
                raise ValueError(
                    "Missing or empty token_ids in preprocessed request for text-only routing"
                )
            routing_tokens = token_ids
            routing_blocks = self._count_blocks(len(routing_tokens))
            logger.debug(
                "Text request: %d routing tokens, %d routing blocks",
                len(routing_tokens),
                routing_blocks,
            )
            # Text-only routing has no multimodal objects; provide per-block None entries.
            block_mm_infos = [None] * routing_blocks

        # Route and generate through KvRouter with explicit fields.
        # We pass:
        # - execution payload (token_ids + multi_modal_data)
        # - routing payload (mm_routing_info: routing_token_ids + block_mm_infos)
        # so generate() can select worker internally while preserving MM correctness.
        mm_routing_info: dict[str, Any] = {
            "routing_token_ids": routing_tokens,
            "block_mm_infos": block_mm_infos,
        }

        stream = await self.kv_router.generate(
            token_ids=token_ids,
            model=request["model"],
            stop_conditions=request.get("stop_conditions"),
            sampling_options=request.get("sampling_options"),
            output_options=request.get("output_options"),
            router_config_override=request.get("router_config_override"),
            # Forward the caller's extra_args untouched (may be None).
            extra_args=request.get("extra_args"),
            multi_modal_data=request.get("multi_modal_data"),
            mm_routing_info=mm_routing_info,
        )

        async for response in stream:
            yield response
# SPDX-FileCopyrightText: Copyright (c) 2025-2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
"""
Multimodal processing utilities for vLLM MM Router Worker.
Key differences from TRT-LLM version:
- Image loading: PIL + requests/base64 (no TRT-LLM dependency)
- mm_hash: SHA256 of normalized PNG bytes (matches vLLM multi_modal_uuids)
- Token replacement: NOT needed — vLLM keeps the original image_token_id as-is
"""
import base64
import logging
from dataclasses import dataclass
from io import BytesIO
from typing import Any
from urllib.parse import urlparse
import requests
from PIL import Image
from dynamo.vllm.multimodal_utils.hash_utils import compute_mm_uuids_from_images
logger = logging.getLogger(__name__)
# =============================================================================
# Data structures
# =============================================================================
@dataclass
class ProcessedInput:
"""Processed multimodal input."""
tokens: list[int]
mm_hashes: list[int] | None
image_ranges: list[tuple[int, int]] | None # [(start, end), ...] per image
# =============================================================================
# Public functions
# =============================================================================
def extract_image_urls(messages: list[dict]) -> list[str]:
    """Extract image URLs from OpenAI-format messages.

    Only list-valued ``content`` is inspected. Within it, parts of type
    ``"image_url"`` contribute their ``image_url["url"]`` when it is non-empty.
    Parts that are not dicts, or whose ``image_url`` is missing, null or not a
    dict, are skipped instead of raising.

    Args:
        messages: Chat messages, each a dict with an optional ``content`` key.

    Returns:
        The URLs in the order they appear across all messages.
    """
    urls: list[str] = []
    for msg in messages:
        content = msg.get("content", [])
        if not isinstance(content, list):
            # Plain-string (or null) content carries no image parts.
            continue
        for part in content:
            if not isinstance(part, dict) or part.get("type") != "image_url":
                continue
            image_url = part.get("image_url")
            # Guard against `"image_url": null` and other non-dict payloads.
            url = image_url.get("url") if isinstance(image_url, dict) else None
            if url:
                urls.append(url)
    return urls
def process_multimodal(
    messages: list[dict],
    image_urls: list[str],
    tokenizer: Any,
    processor: Any,
    model: str,
) -> ProcessedInput:
    """
    Turn a multimodal chat request into routing tokens, image hashes and spans.

    Steps: render the chat template, load every image, run token expansion,
    then derive one integer hash per image from the first 16 hex characters
    of its UUID. ``model`` is accepted but not used here.
    """
    # The preprocessed request only carries the original messages (no rendered
    # template string), so the chat template is applied again at this point.
    rendered = _build_prompt_with_images(messages, tokenizer, processor)
    logger.info(f"Prompt (first 300 chars): {rendered[:300]}")

    loaded_images = [_load_image(source) for source in image_urls]

    # vLLM keeps the image token id as-is, so no token replacement happens.
    expanded_ids, per_image_spans = _get_expanded_tokens(
        rendered, loaded_images, tokenizer, processor
    )
    logger.info(
        f"Expanded: {len(expanded_ids)} tokens, " f"image_ranges={per_image_spans}"
    )

    # Same UUID helper as the vLLM handler's multi_modal_uuids path.
    digests = compute_mm_uuids_from_images(loaded_images)
    hashes = [int(digest[:16], 16) for digest in digests]
    logger.info(f"mm_hashes={hashes}")

    return ProcessedInput(
        tokens=expanded_ids, mm_hashes=hashes, image_ranges=per_image_spans
    )
def build_block_mm_infos(
num_tokens: int,
block_size: int,
mm_hashes: list[int] | None,
image_ranges: list[tuple[int, int]] | None,
) -> list[dict | None] | None:
"""
Build per-block mm_info for routing.
For each block, check which images overlap with it and add their mm_hash.
Assumption: mm_hashes and image_ranges are in the same order as images appear
in the request (which matches their order in the token sequence).
"""
if not mm_hashes or not image_ranges or len(mm_hashes) != len(image_ranges):
return None
num_blocks = (num_tokens + block_size - 1) // block_size
result = []
for block_idx in range(num_blocks):
block_start = block_idx * block_size
block_end = block_start + block_size
# Find images overlapping this block
mm_objects = [
{"mm_hash": mm_hash, "offsets": []}
for mm_hash, (img_start, img_end) in zip(mm_hashes, image_ranges)
if block_end > img_start and block_start < img_end
]
result.append({"mm_objects": mm_objects} if mm_objects else None)
return result
# =============================================================================
# Internal functions
# =============================================================================
def _build_prompt_with_images(
messages: list[dict], tokenizer: Any, processor: Any
) -> str:
"""
Build a prompt that includes image placeholders using the tokenizer's
chat template. This is critical for Qwen2-VL/Qwen2.5-VL models which
need <|vision_start|><|image_pad|>...<|vision_end|> in the prompt for
the processor to expand image tokens correctly.
Raises if chat template cannot be applied. For MM routing correctness, we do
not silently fall back to text-only prompts.
"""
# Try processor first (has the best chat template for multimodal)
if processor is not None and hasattr(processor, "apply_chat_template"):
return processor.apply_chat_template(
messages, tokenize=False, add_generation_prompt=True
)
# Fall back to tokenizer if available
if hasattr(tokenizer, "apply_chat_template"):
return tokenizer.apply_chat_template(
messages, tokenize=False, add_generation_prompt=True
)
raise ValueError("Neither processor nor tokenizer provides apply_chat_template")
def _load_image(url: str) -> Image.Image:
    """
    Load an image from URL (http/https or data URI) and return a PIL RGB image.

    Args:
        url: ``http(s)://...`` location or ``data:<mime>;base64,<payload>`` URI.

    Returns:
        The decoded image converted to RGB mode.

    Raises:
        ValueError: For an unsupported scheme or a data URI with no ","
            separating the header from the payload.
        Exception: ``raise_for_status()`` propagates HTTP error responses.
    """
    parsed = urlparse(url)

    if parsed.scheme == "data":
        # data:image/png;base64,<data>
        # partition() lets a missing comma be reported explicitly instead of
        # surfacing as an opaque tuple-unpacking error.
        _, separator, payload = parsed.path.partition(",")
        if not separator:
            raise ValueError("Malformed data URI: missing ',' before the payload")
        # assumes the payload is base64-encoded — TODO confirm no
        # percent-encoded data URIs reach this path
        raw_bytes = base64.b64decode(payload)
    elif parsed.scheme in ("http", "https"):
        # NOTE(review): the URL originates from the request messages and is
        # fetched with no host allow-list or response size cap.
        response = requests.get(url, timeout=30)
        response.raise_for_status()
        raw_bytes = response.content
    else:
        raise ValueError(f"Unsupported URL scheme: {parsed.scheme}")

    return Image.open(BytesIO(raw_bytes)).convert("RGB")
def _get_expanded_tokens(
    prompt: str,
    pil_images: list[Image.Image],
    tokenizer: Any,
    processor: Any,
) -> tuple[list[int], list[tuple[int, int]] | None]:
    """
    Tokenize ``prompt`` with visual expansion and locate each image's span.

    Returns ``(tokens, image_ranges)``. Without a processor, or when any step
    of the processor path fails, the plain tokenizer encoding is returned with
    ``None`` ranges. The image token id is left untouched (no replacement).
    """
    if processor is None:
        return tokenizer.encode(prompt), None

    try:
        encoded = processor(
            text=[prompt], images=pil_images, return_tensors="pt", padding=True
        )
        token_list = encoded["input_ids"][0].tolist()

        image_token_id = getattr(processor, "image_token_id", None)
        if image_token_id is None:
            # Raised inside the try on purpose: it takes the fallback path below.
            raise ValueError("processor.image_token_id not found")

        # Contiguous runs of image tokens first, then per-image token counts,
        # then the runs are split into one span per image.
        token_runs = _find_image_token_ranges(token_list, image_token_id)
        per_image_counts = _compute_tokens_per_image(encoded, processor)
        spans = _compute_per_image_ranges(token_runs, per_image_counts)
        return token_list, spans
    except Exception as e:
        logger.warning(f"HF processor failed: {e}", exc_info=True)
        return tokenizer.encode(prompt), None
def _compute_tokens_per_image(processor_output: dict, processor: Any) -> list[int]:
"""
Compute the number of visual tokens for each image from processor output.
Only Qwen-style processors (Qwen2-VL, Qwen2.5-VL) are supported.
Other model families will raise ValueError.
"""
processor_cls = type(processor).__qualname__
if "qwen" not in processor_cls.lower():
raise NotImplementedError(
f"_compute_tokens_per_image only supports Qwen-style processors "
f"tuples. Got processor class: {processor_cls}"
)
grid_thw = processor_output.get("image_grid_thw")
if grid_thw is None:
raise ValueError("image_grid_thw not found in processor output")
merge_size = getattr(processor.image_processor, "merge_size", 2)
return [int(t * h * w) // (merge_size**2) for t, h, w in grid_thw]
def _find_image_token_ranges(
tokens: list[int], image_token_id: int
) -> list[tuple[int, int]]:
"""
Find all contiguous ranges of image tokens.
Unlike the TRT-LLM version, this does NOT replace tokens — vLLM keeps
the original image_token_id as-is in KV events.
Returns: list of (start, end) ranges for contiguous image token regions.
"""
ranges = []
start = None
for i, t in enumerate(tokens):
if t == image_token_id:
if start is None:
start = i
elif start is not None:
ranges.append((start, i))
start = None
if start is not None:
ranges.append((start, len(tokens)))
if ranges:
logger.info(
f"Found {sum(e - s for s, e in ranges)} image tokens "
f"(id={image_token_id}) in {len(ranges)} range(s)"
)
return ranges
def _compute_per_image_ranges(
contiguous_ranges: list[tuple[int, int]],
tokens_per_image: list[int],
) -> list[tuple[int, int]] | None:
"""
Split contiguous image token ranges by each image's token count.
Example: contiguous_ranges=[(0, 100)], tokens_per_image=[60, 40]
Returns: [(0, 60), (60, 100)] # image 1 at 0-60, image 2 at 60-100
"""
if not contiguous_ranges:
if tokens_per_image:
logger.warning(
f"No image tokens found but {len(tokens_per_image)} images expected"
)
return None
# Greedily assign images to ranges in order
result = []
image_idx = 0
for range_start, range_end in contiguous_ranges:
range_size = range_end - range_start
pos = range_start
consumed = 0
# Consume images that fit entirely in this range
# (a single image's tokens are always contiguous, cannot span ranges)
while image_idx < len(tokens_per_image):
needed = tokens_per_image[image_idx]
if consumed + needed <= range_size:
result.append((pos, pos + needed))
pos += needed
consumed += needed
image_idx += 1
else:
break
# Range must be exactly filled (no leftover image tokens)
if consumed != range_size:
logger.warning(
f"Range size mismatch: consumed {consumed} != range {range_size}"
)
return None
# All images must be consumed
if image_idx != len(tokens_per_image):
logger.warning(f"Not all images mapped: {image_idx} < {len(tokens_per_image)}")
return None
return result
# SPDX-FileCopyrightText: Copyright (c) 2025-2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
"""
MM Router Worker - Multimodal-aware KV cache routing for vLLM.
This worker receives OpenAI-format requests from the frontend, computes
mm_hash for any images, finds the best vLLM worker based on KV cache
overlap, and forwards the request to that worker.
Usage:
python -m examples.backends.vllm.mm_router_worker \
--model Qwen/Qwen2.5-VL-7B-Instruct \
--namespace default \
--component mm_router \
--endpoint generate \
--downstream-component VllmWorker \
--downstream-endpoint generate
"""
import argparse
import asyncio
import logging
import signal
import uvloop
from transformers import AutoProcessor, AutoTokenizer
from dynamo.llm import KvRouter, KvRouterConfig, ModelInput, ModelType, register_llm
from dynamo.runtime import DistributedRuntime, dynamo_worker
from .handler import MMRouterHandler
logger = logging.getLogger(__name__)
def parse_args() -> argparse.Namespace:
    """Build the CLI parser for the MM Router Worker and parse ``sys.argv``."""
    parser = argparse.ArgumentParser(
        description="MM Router Worker - Multimodal-aware KV cache routing for vLLM"
    )

    # (flag, type, default, help) in display order: model configuration,
    # then this worker's endpoint, then the downstream vLLM workers.
    option_table = (
        (
            "--model",
            str,
            "Qwen/Qwen2.5-VL-7B-Instruct",
            "Model path or HuggingFace model ID",
        ),
        ("--block-size", int, 32, "KV cache block size"),
        ("--namespace", str, "default", "Dynamo namespace"),
        ("--component", str, "mm_router", "This worker's component name"),
        ("--endpoint", str, "generate", "This worker's endpoint name"),
        (
            "--downstream-component",
            str,
            "VllmWorker",
            "Downstream vLLM workers' component name",
        ),
        (
            "--downstream-endpoint",
            str,
            "generate",
            "Downstream vLLM workers' endpoint name",
        ),
    )
    for flag, value_type, default_value, help_text in option_table:
        parser.add_argument(
            flag, type=value_type, default=default_value, help=help_text
        )

    return parser.parse_args()
async def graceful_shutdown(runtime: DistributedRuntime) -> None:
    """Shut the runtime down, logging before and after the call."""
    logger.info("Received shutdown signal, shutting down...")
    # The only real work here: ask the runtime to stop.
    runtime.shutdown()
    logger.info("Shutdown complete")
@dynamo_worker()
async def worker(runtime: DistributedRuntime) -> None:
    """
    Main worker function.

    Sets up connections to downstream vLLM workers, creates KvRouter
    for tracking their cache states, and serves the MM router endpoint.

    Args:
        runtime: The distributed runtime supplied by ``@dynamo_worker()``.
    """
    args = parse_args()

    # Set up signal handlers for graceful shutdown
    loop = asyncio.get_running_loop()

    # asyncio only keeps weak references to tasks, so hold a strong reference
    # until the shutdown task finishes; otherwise it may be garbage-collected
    # before it has run to completion.
    shutdown_tasks: set[asyncio.Task] = set()

    def signal_handler():
        task = asyncio.create_task(graceful_shutdown(runtime))
        shutdown_tasks.add(task)
        task.add_done_callback(shutdown_tasks.discard)

    for sig in (signal.SIGTERM, signal.SIGINT):
        loop.add_signal_handler(sig, signal_handler)

    logger.info("MM Router Worker (vLLM) starting...")
    logger.info(f"Model: {args.model}")
    logger.info(f"This worker: {args.namespace}.{args.component}.{args.endpoint}")
    logger.info(
        f"Downstream: {args.namespace}.{args.downstream_component}.{args.downstream_endpoint}"
    )

    # Connect to downstream vLLM workers
    downstream_endpoint = (
        runtime.namespace(args.namespace)
        .component(args.downstream_component)
        .endpoint(args.downstream_endpoint)
    )
    downstream_client = await downstream_endpoint.client()

    logger.info("Waiting for downstream vLLM workers...")
    instance_ids = await downstream_client.wait_for_instances()
    logger.info(f"Found {len(instance_ids)} workers: {list(instance_ids)}")

    # Create KvRouter to select workers based on KV overlap
    kv_router = KvRouter(
        endpoint=downstream_endpoint,
        block_size=args.block_size,
        kv_router_config=KvRouterConfig(),
    )
    logger.info("KvRouter created successfully")

    # Initialize tokenizer and processor for MM processing
    # NOTE(review): trust_remote_code=True executes code shipped with the
    # model repository; only point --model at trusted sources.
    logger.info(f"Loading tokenizer from {args.model}...")
    tokenizer = AutoTokenizer.from_pretrained(args.model, trust_remote_code=True)

    logger.info(f"Loading HuggingFace processor from {args.model}...")
    # TODO: the HF AutoProcessor may be slower than the vLLM equivalent @zhongdaor
    processor = AutoProcessor.from_pretrained(args.model, trust_remote_code=True)

    # Create handler
    handler = MMRouterHandler(
        kv_router=kv_router,
        tokenizer=tokenizer,
        processor=processor,
        model=args.model,
        block_size=args.block_size,
    )

    # Register this worker's endpoint
    component = runtime.namespace(args.namespace).component(args.component)
    endpoint = component.endpoint(args.endpoint)

    # Use ModelInput.Tokens so Frontend preprocesses the request
    # Request format: {token_ids, sampling_options, stop_conditions, extra_args: {messages}}
    await register_llm(
        ModelInput.Tokens,
        ModelType.Chat,
        endpoint,
        args.model,
        kv_cache_block_size=args.block_size,
    )

    logger.info(f"MM Router Worker (vLLM) ready, serving {args.endpoint} endpoint...")

    # Serve the endpoint
    await endpoint.serve_endpoint(handler.generate)
def main() -> None:
    """Configure logging, install uvloop and run the worker coroutine."""
    log_format = "%(asctime)s - %(name)s - %(levelname)s - %(message)s"
    logging.basicConfig(level=logging.INFO, format=log_format)

    # Install uvloop before asyncio.run() creates the event loop.
    uvloop.install()
    asyncio.run(worker())
if __name__ == "__main__":
main()
...@@ -416,12 +416,19 @@ impl RouterHandles { ...@@ -416,12 +416,19 @@ impl RouterHandles {
async fn query_prefill_worker( async fn query_prefill_worker(
&self, &self,
tokens: &[u32], tokens: &[u32],
block_mm_infos: Option<&[Option<dynamo_llm::kv_router::protocols::BlockExtraInfo>]>,
update_states: bool, update_states: bool,
lora_name: Option<String>, lora_name: Option<String>,
priority_jump: f64, priority_jump: f64,
) -> Result<u64, QueryRouterResult> { ) -> Result<u64, QueryRouterResult> {
self.prefill_router self.prefill_router
.query_prefill_worker(tokens, update_states, lora_name, priority_jump) .query_prefill_worker(
tokens,
block_mm_infos,
update_states,
lora_name,
priority_jump,
)
.await .await
.map(|(worker_id, _dp_rank)| worker_id) .map(|(worker_id, _dp_rank)| worker_id)
.map_err(|e| { .map_err(|e| {
...@@ -455,7 +462,15 @@ impl RouterHandles { ...@@ -455,7 +462,15 @@ impl RouterHandles {
}; };
self.decode_router self.decode_router
.find_best_match(None, tokens, config_override.as_ref(), false, None, 0.0) .find_best_match(
None,
tokens,
None,
config_override.as_ref(),
false,
None,
0.0,
)
.await .await
.map_err(|e| { .map_err(|e| {
tracing::error!(error = ?e, "Decode query failed"); tracing::error!(error = ?e, "Decode query failed");
...@@ -1026,7 +1041,7 @@ pub unsafe extern "C" fn route_request( ...@@ -1026,7 +1041,7 @@ pub unsafe extern "C" fn route_request(
let result = handles.runtime.secondary().block_on(async { let result = handles.runtime.secondary().block_on(async {
let prefill_worker_id = if is_disaggregated { let prefill_worker_id = if is_disaggregated {
handles handles
.query_prefill_worker(tokens, false, None, 0.0) .query_prefill_worker(tokens, None, false, None, 0.0)
.await? .await?
} else { } else {
0 0
......
...@@ -813,7 +813,7 @@ impl KvRouter { ...@@ -813,7 +813,7 @@ impl KvRouter {
} }
#[allow(clippy::too_many_arguments)] #[allow(clippy::too_many_arguments)]
#[pyo3(signature = (token_ids, model, stop_conditions=None, sampling_options=None, output_options=None, router_config_override=None, worker_id=None, dp_rank=None, extra_args=None))] #[pyo3(signature = (token_ids, model, stop_conditions=None, sampling_options=None, output_options=None, router_config_override=None, worker_id=None, dp_rank=None, extra_args=None, block_mm_infos=None, multi_modal_data=None, mm_routing_info=None))]
fn generate<'p>( fn generate<'p>(
&self, &self,
py: Python<'p>, py: Python<'p>,
...@@ -826,6 +826,9 @@ impl KvRouter { ...@@ -826,6 +826,9 @@ impl KvRouter {
worker_id: Option<WorkerId>, worker_id: Option<WorkerId>,
dp_rank: Option<DpRank>, dp_rank: Option<DpRank>,
extra_args: Option<PyObject>, extra_args: Option<PyObject>,
block_mm_infos: Option<PyObject>,
multi_modal_data: Option<PyObject>,
mm_routing_info: Option<PyObject>,
) -> PyResult<Bound<'p, PyAny>> { ) -> PyResult<Bound<'p, PyAny>> {
// Depythonize the options with defaults // Depythonize the options with defaults
let stop_conditions: StopConditions = if let Some(obj) = stop_conditions { let stop_conditions: StopConditions = if let Some(obj) = stop_conditions {
...@@ -859,6 +862,32 @@ impl KvRouter { ...@@ -859,6 +862,32 @@ impl KvRouter {
None None
}; };
let block_mm_infos: Option<Vec<Option<BlockExtraInfo>>> = if let Some(obj) = block_mm_infos
{
Some(depythonize(obj.bind(py)).map_err(to_pyerr)?)
} else {
None
};
let multi_modal_data: Option<llm_rs::protocols::common::preprocessor::MultimodalDataMap> =
if let Some(obj) = multi_modal_data {
Some(depythonize(obj.bind(py)).map_err(to_pyerr)?)
} else {
None
};
let mm_routing_info: Option<llm_rs::protocols::common::preprocessor::MmRoutingInfo> =
if let Some(obj) = mm_routing_info {
Some(depythonize(obj.bind(py)).map_err(to_pyerr)?)
} else {
block_mm_infos.map(
|infos| llm_rs::protocols::common::preprocessor::MmRoutingInfo {
routing_token_ids: token_ids.clone(),
block_mm_infos: infos,
},
)
};
// Create tracker to capture worker routing info from KvRouter // Create tracker to capture worker routing info from KvRouter
let tracker = Arc::new(RequestTracker::new()); let tracker = Arc::new(RequestTracker::new());
...@@ -872,6 +901,8 @@ impl KvRouter { ...@@ -872,6 +901,8 @@ impl KvRouter {
.sampling_options(sampling_options) .sampling_options(sampling_options)
.output_options(output_options) .output_options(output_options)
.router_config_override(router_config_override) .router_config_override(router_config_override)
.multi_modal_data(multi_modal_data)
.mm_routing_info(mm_routing_info)
.extra_args(extra_args) .extra_args(extra_args)
.tracker(Some(tracker.clone())); .tracker(Some(tracker.clone()));
...@@ -914,13 +945,14 @@ impl KvRouter { ...@@ -914,13 +945,14 @@ impl KvRouter {
Self::process_request_to_stream(py, self.inner.clone(), request, Some(tracker)) Self::process_request_to_stream(py, self.inner.clone(), request, Some(tracker))
} }
#[pyo3(signature = (token_ids, router_config_override=None, request_id=None))] #[pyo3(signature = (token_ids, router_config_override=None, request_id=None, block_mm_infos=None))]
fn best_worker<'p>( fn best_worker<'p>(
&self, &self,
py: Python<'p>, py: Python<'p>,
token_ids: Vec<u32>, token_ids: Vec<u32>,
router_config_override: Option<PyObject>, router_config_override: Option<PyObject>,
request_id: Option<String>, request_id: Option<String>,
block_mm_infos: Option<PyObject>,
) -> PyResult<Bound<'p, PyAny>> { ) -> PyResult<Bound<'p, PyAny>> {
let router_config_override = if let Some(obj) = router_config_override { let router_config_override = if let Some(obj) = router_config_override {
let override_config: llm_rs::kv_router::RouterConfigOverride = let override_config: llm_rs::kv_router::RouterConfigOverride =
...@@ -930,6 +962,13 @@ impl KvRouter { ...@@ -930,6 +962,13 @@ impl KvRouter {
None None
}; };
let block_mm_infos: Option<Vec<Option<BlockExtraInfo>>> = if let Some(obj) = block_mm_infos
{
Some(depythonize(obj.bind(py)).map_err(to_pyerr)?)
} else {
None
};
let chooser = self.inner.chooser.clone(); let chooser = self.inner.chooser.clone();
let update_states = request_id.is_some(); let update_states = request_id.is_some();
...@@ -938,6 +977,7 @@ impl KvRouter { ...@@ -938,6 +977,7 @@ impl KvRouter {
.find_best_match( .find_best_match(
request_id.as_deref(), request_id.as_deref(),
&token_ids, &token_ids,
block_mm_infos.as_deref(),
router_config_override.as_ref(), router_config_override.as_ref(),
update_states, update_states,
None, // lora_name not exposed in Python API yet None, // lora_name not exposed in Python API yet
......
...@@ -1360,6 +1360,10 @@ class KvRouter: ...@@ -1360,6 +1360,10 @@ class KvRouter:
router_config_override: Optional[JsonLike] = None, router_config_override: Optional[JsonLike] = None,
worker_id: Optional[int] = None, worker_id: Optional[int] = None,
dp_rank: Optional[int] = None, dp_rank: Optional[int] = None,
extra_args: Optional[JsonLike] = None,
block_mm_infos: Optional[List[Optional[Dict[str, Any]]]] = None,
multi_modal_data: Optional[JsonLike] = None,
mm_routing_info: Optional[JsonLike] = None,
) -> AsyncIterator[JsonLike]: ) -> AsyncIterator[JsonLike]:
""" """
Generate text using the KV-aware router. Generate text using the KV-aware router.
...@@ -1378,6 +1382,16 @@ class KvRouter: ...@@ -1378,6 +1382,16 @@ class KvRouter:
the request will be routed to the specific (worker_id, dp_rank) pair. the request will be routed to the specific (worker_id, dp_rank) pair.
If only dp_rank is set, the router will select the best worker but If only dp_rank is set, the router will select the best worker but
force routing to the specified dp_rank. force routing to the specified dp_rank.
extra_args: Optional extra request arguments to include in the
PreprocessedRequest.
block_mm_infos: Optional block-level multimodal metadata aligned to
request blocks. Backward-compatible shortcut; this is
converted to mm_routing_info with routing_token_ids=token_ids.
multi_modal_data: Optional multimodal payload map to preserve image/video
data for downstream model execution.
mm_routing_info: Optional structured routing-only multimodal payload
(e.g., {"routing_token_ids": [...], "block_mm_infos": [...]})
used by router selection without changing execution token_ids.
Returns: Returns:
An async iterator yielding generation responses An async iterator yielding generation responses
...@@ -1396,6 +1410,7 @@ class KvRouter: ...@@ -1396,6 +1410,7 @@ class KvRouter:
token_ids: List[int], token_ids: List[int],
router_config_override: Optional[JsonLike] = None, router_config_override: Optional[JsonLike] = None,
request_id: Optional[str] = None, request_id: Optional[str] = None,
block_mm_infos: Optional[List[Optional[Dict[str, Any]]]] = None,
) -> Tuple[int, int, int]: ) -> Tuple[int, int, int]:
""" """
Find the best matching worker for the given tokens. Find the best matching worker for the given tokens.
...@@ -1406,6 +1421,9 @@ class KvRouter: ...@@ -1406,6 +1421,9 @@ class KvRouter:
request_id: Optional request ID. If provided, router states will be updated request_id: Optional request ID. If provided, router states will be updated
to track this request (active blocks, lifecycle events). If not to track this request (active blocks, lifecycle events). If not
provided, this is a query-only operation that doesn't affect state. provided, this is a query-only operation that doesn't affect state.
block_mm_infos: Optional block-level multimodal metadata aligned to request
blocks. When provided, this is used in block hash computation
to enable MM-aware worker selection.
Returns: Returns:
A tuple of (worker_id, dp_rank, overlap_blocks) where: A tuple of (worker_id, dp_rank, overlap_blocks) where:
......
...@@ -114,6 +114,8 @@ pub enum RouterRequest { ...@@ -114,6 +114,8 @@ pub enum RouterRequest {
#[serde(rename = "new")] #[serde(rename = "new")]
New { New {
tokens: Vec<Token>, tokens: Vec<Token>,
#[serde(default, skip_serializing_if = "Option::is_none")]
block_mm_infos: Option<Vec<Option<BlockExtraInfo>>>,
}, },
MarkPrefill, MarkPrefill,
MarkFree, MarkFree,
...@@ -121,7 +123,10 @@ pub enum RouterRequest { ...@@ -121,7 +123,10 @@ pub enum RouterRequest {
impl Default for RouterRequest { impl Default for RouterRequest {
fn default() -> Self { fn default() -> Self {
RouterRequest::New { tokens: vec![] } RouterRequest::New {
tokens: vec![],
block_mm_infos: None,
}
} }
} }
......
...@@ -50,8 +50,8 @@ use crate::{ ...@@ -50,8 +50,8 @@ use crate::{
approx::PruneConfig, approx::PruneConfig,
indexer::{GetWorkersRequest, KvIndexer, KvIndexerInterface, KvRouterError}, indexer::{GetWorkersRequest, KvIndexer, KvIndexerInterface, KvRouterError},
protocols::{ protocols::{
DpRank, LocalBlockHash, OverlapScores, RouterEvent, RouterRequest, RouterResponse, BlockExtraInfo, DpRank, LocalBlockHash, OverlapScores, RouterEvent, RouterRequest,
TokensWithHashes, WorkerId, WorkerSelectionResult, WorkerWithDpRank, RouterResponse, TokensWithHashes, WorkerId, WorkerSelectionResult, WorkerWithDpRank,
compute_block_hash_for_seq, compute_block_hash_for_seq,
}, },
scheduler::{KvScheduler, KvSchedulerError, PotentialLoad, SchedulingRequest}, scheduler::{KvScheduler, KvSchedulerError, PotentialLoad, SchedulingRequest},
...@@ -371,6 +371,7 @@ impl KvRouter { ...@@ -371,6 +371,7 @@ impl KvRouter {
&self, &self,
context_id: Option<&str>, context_id: Option<&str>,
tokens: &[u32], tokens: &[u32],
block_mm_infos: Option<&[Option<BlockExtraInfo>]>,
router_config_override: Option<&RouterConfigOverride>, router_config_override: Option<&RouterConfigOverride>,
update_states: bool, update_states: bool,
lora_name: Option<String>, lora_name: Option<String>,
...@@ -385,7 +386,7 @@ impl KvRouter { ...@@ -385,7 +386,7 @@ impl KvRouter {
let isl_tokens = tokens.len(); let isl_tokens = tokens.len();
let block_hashes = tracing::info_span!("kv_router.compute_block_hashes") let block_hashes = tracing::info_span!("kv_router.compute_block_hashes")
.in_scope(|| compute_block_hash_for_seq(tokens, self.block_size, None)); .in_scope(|| compute_block_hash_for_seq(tokens, self.block_size, block_mm_infos));
let hash_elapsed = start.elapsed(); let hash_elapsed = start.elapsed();
let overlap_scores = self let overlap_scores = self
...@@ -566,9 +567,20 @@ impl AsyncEngine<SingleIn<RouterRequest>, ManyOut<Annotated<RouterResponse>>, Er ...@@ -566,9 +567,20 @@ impl AsyncEngine<SingleIn<RouterRequest>, ManyOut<Annotated<RouterResponse>>, Er
let context_id = ctx.context().id().to_string(); let context_id = ctx.context().id().to_string();
// Handle different request types // Handle different request types
let response = match request { let response = match request {
RouterRequest::New { tokens } => { RouterRequest::New {
tokens,
block_mm_infos,
} => {
let (best_worker, overlap_blocks) = self let (best_worker, overlap_blocks) = self
.find_best_match(Some(&context_id), &tokens, None, true, None, 0.0) .find_best_match(
Some(&context_id),
&tokens,
block_mm_infos.as_deref(),
None,
true,
None,
0.0,
)
.await?; .await?;
RouterResponse::New { RouterResponse::New {
......
...@@ -21,7 +21,7 @@ use dynamo_runtime::{ ...@@ -21,7 +21,7 @@ use dynamo_runtime::{
use crate::{ use crate::{
discovery::ModelManager, discovery::ModelManager,
kv_router::{KvPushRouter, KvRouterConfig, RouterConfigOverride}, kv_router::{KvPushRouter, KvRouterConfig, RouterConfigOverride, protocols::BlockExtraInfo},
protocols::common::llm_backend::{LLMEngineOutput, PreprocessedRequest}, protocols::common::llm_backend::{LLMEngineOutput, PreprocessedRequest},
protocols::common::preprocessor::{BootstrapInfo, PrefillResult}, protocols::common::preprocessor::{BootstrapInfo, PrefillResult},
protocols::common::timing::{RequestPhase, RequestTracker, WORKER_TYPE_PREFILL}, protocols::common::timing::{RequestPhase, RequestTracker, WORKER_TYPE_PREFILL},
...@@ -103,6 +103,19 @@ pub struct PrefillRouter { ...@@ -103,6 +103,19 @@ pub struct PrefillRouter {
} }
impl PrefillRouter { impl PrefillRouter {
/// Pick the token sequence and per-block multimodal metadata used for routing.
///
/// When the request carries `mm_routing_info` with a non-empty
/// `routing_token_ids`, those tokens and the accompanying `block_mm_infos`
/// are returned. Otherwise the request's own `token_ids` are returned with
/// no multimodal metadata.
fn routing_inputs(req: &PreprocessedRequest) -> (&[u32], Option<&[Option<BlockExtraInfo>]>) {
    match req.mm_routing_info.as_ref() {
        // Routing-only tokens were supplied: route on them together with their MM info.
        Some(info) if !info.routing_token_ids.is_empty() => (
            info.routing_token_ids.as_slice(),
            Some(info.block_mm_infos.as_slice()),
        ),
        // Absent or empty routing info: fall back to the execution tokens.
        _ => (req.token_ids.as_slice(), None),
    }
}
/// Create a disabled prefill router that will never activate (passthrough only) /// Create a disabled prefill router that will never activate (passthrough only)
pub fn disabled( pub fn disabled(
model_manager: Arc<ModelManager>, model_manager: Arc<ModelManager>,
...@@ -285,8 +298,15 @@ impl PrefillRouter { ...@@ -285,8 +298,15 @@ impl PrefillRouter {
.as_ref() .as_ref()
.and_then(|r| r.priority_jump) .and_then(|r| r.priority_jump)
.unwrap_or(0.0); .unwrap_or(0.0);
let (routing_token_ids, block_mm_infos) = Self::routing_inputs(req);
match self match self
.query_prefill_worker(&req.token_ids, false, lora_name, priority_jump) .query_prefill_worker(
routing_token_ids,
block_mm_infos,
false,
lora_name,
priority_jump,
)
.await .await
{ {
Ok((worker_id, dp_rank)) => (worker_id, dp_rank), Ok((worker_id, dp_rank)) => (worker_id, dp_rank),
...@@ -475,6 +495,7 @@ impl PrefillRouter { ...@@ -475,6 +495,7 @@ impl PrefillRouter {
pub async fn query_prefill_worker( pub async fn query_prefill_worker(
&self, &self,
token_ids: &[u32], token_ids: &[u32],
block_mm_infos: Option<&[Option<BlockExtraInfo>]>,
update_states: bool, update_states: bool,
lora_name: Option<String>, lora_name: Option<String>,
priority_jump: f64, priority_jump: f64,
...@@ -491,6 +512,7 @@ impl PrefillRouter { ...@@ -491,6 +512,7 @@ impl PrefillRouter {
.find_best_match( .find_best_match(
None, None,
token_ids, token_ids,
block_mm_infos,
None, None,
update_states, update_states,
lora_name, lora_name,
......
...@@ -787,6 +787,58 @@ enum RawKvEvent { ...@@ -787,6 +787,58 @@ enum RawKvEvent {
AllBlocksCleared, AllBlocksCleared,
} }
/// Parse an MM hash out of a single `extra_keys` string.
///
/// - Only strings that are exactly 64 ASCII hex characters are accepted.
/// - The result is the first 16 hex characters interpreted as a `u64`.
/// - Any other string yields `None`.
fn parse_mm_hash_from_extra_key(s: &str) -> Option<u64> {
    // extra_keys mixes MM identifiers with LoRA/cache_salt/prompt-embed metadata.
    // Only MM identifiers should be mapped into BlockExtraInfo.
    let looks_like_digest = s.len() == 64 && s.bytes().all(|b| b.is_ascii_hexdigit());
    if !looks_like_digest {
        return None;
    }
    // All bytes are ASCII here, so slicing at byte 16 is always on a char boundary.
    u64::from_str_radix(&s[..16], 16).ok()
}
/// Translate the `extra_keys` list of a BlockStored event into block-level MM infos.
///
/// `extra_keys` holds one entry per block:
/// - `None` => the block has no extra keys
/// - `Some([...])` => zero or more keys, of which only those accepted by
///   `parse_mm_hash_from_extra_key` become MM objects
///
/// Returns `None` when `extra_keys` is absent, empty, or contains no
/// parseable MM hash in any block. Otherwise the returned vector has one
/// entry per input block.
fn extra_keys_to_block_mm_infos(
    extra_keys: Option<Vec<Option<Vec<String>>>>,
) -> Option<Vec<Option<BlockExtraInfo>>> {
    let per_block_keys = extra_keys.filter(|keys| !keys.is_empty())?;

    let mut found_any_mm = false;
    let mut infos: Vec<Option<BlockExtraInfo>> = Vec::with_capacity(per_block_keys.len());
    for block_keys in per_block_keys {
        let mut mm_objects: Vec<BlockMmObjectInfo> = Vec::new();
        // A `None` block contributes no keys; non-MM keys are silently skipped.
        for key in block_keys.iter().flatten() {
            if let Some(mm_hash) = parse_mm_hash_from_extra_key(key) {
                mm_objects.push(BlockMmObjectInfo {
                    mm_hash,
                    offsets: vec![], // extra_keys does not carry offsets today
                });
            }
        }

        if mm_objects.is_empty() {
            infos.push(None);
        } else {
            found_any_mm = true;
            infos.push(Some(BlockExtraInfo { mm_objects }));
        }
    }

    // A list made only of `None` entries carries no information; report it as absent.
    found_any_mm.then_some(infos)
}
/// Our producers use msgspec with `tag=True` and `array_like=True`, which /// Our producers use msgspec with `tag=True` and `array_like=True`, which
/// encodes each event as either a tagged map or a tagged tuple. To be tolerant of /// encodes each event as either a tagged map or a tagged tuple. To be tolerant of
/// additional fields that may be appended in the future, we implement a custom /// additional fields that may be appended in the future, we implement a custom
...@@ -824,6 +876,7 @@ impl<'de> Visitor<'de> for RawKvEventVisitor { ...@@ -824,6 +876,7 @@ impl<'de> Visitor<'de> for RawKvEventVisitor {
let mut lora_id: Option<Option<u64>> = None; let mut lora_id: Option<Option<u64>> = None;
let mut medium: Option<Option<String>> = None; let mut medium: Option<Option<String>> = None;
let mut lora_name: Option<Option<String>> = None; let mut lora_name: Option<Option<String>> = None;
let mut extra_keys: Option<Option<Vec<Option<Vec<String>>>>> = None;
let mut block_mm_infos: Option<Option<Vec<Option<BlockExtraInfo>>>> = None; let mut block_mm_infos: Option<Option<Vec<Option<BlockExtraInfo>>>> = None;
while let Some(key) = map.next_key::<String>()? { while let Some(key) = map.next_key::<String>()? {
...@@ -852,6 +905,9 @@ impl<'de> Visitor<'de> for RawKvEventVisitor { ...@@ -852,6 +905,9 @@ impl<'de> Visitor<'de> for RawKvEventVisitor {
"lora_name" => { "lora_name" => {
lora_name = Some(map.next_value()?); lora_name = Some(map.next_value()?);
} }
"extra_keys" => {
extra_keys = Some(map.next_value()?);
}
"block_mm_infos" => { "block_mm_infos" => {
block_mm_infos = Some(map.next_value()?); block_mm_infos = Some(map.next_value()?);
} }
...@@ -868,6 +924,9 @@ impl<'de> Visitor<'de> for RawKvEventVisitor { ...@@ -868,6 +924,9 @@ impl<'de> Visitor<'de> for RawKvEventVisitor {
let token_ids = token_ids.ok_or_else(|| de::Error::missing_field("token_ids"))?; let token_ids = token_ids.ok_or_else(|| de::Error::missing_field("token_ids"))?;
let block_size = let block_size =
block_size.ok_or_else(|| de::Error::missing_field("block_size"))?; block_size.ok_or_else(|| de::Error::missing_field("block_size"))?;
let block_mm_infos = block_mm_infos
.unwrap_or(None)
.or_else(|| extra_keys_to_block_mm_infos(extra_keys.unwrap_or(None)));
Ok(RawKvEvent::BlockStored { Ok(RawKvEvent::BlockStored {
block_hashes, block_hashes,
parent_block_hash: parent_block_hash.unwrap_or(None), parent_block_hash: parent_block_hash.unwrap_or(None),
...@@ -876,7 +935,7 @@ impl<'de> Visitor<'de> for RawKvEventVisitor { ...@@ -876,7 +935,7 @@ impl<'de> Visitor<'de> for RawKvEventVisitor {
lora_id: lora_id.unwrap_or(None), lora_id: lora_id.unwrap_or(None),
medium: medium.unwrap_or(None), medium: medium.unwrap_or(None),
lora_name: lora_name.unwrap_or(None), lora_name: lora_name.unwrap_or(None),
block_mm_infos: block_mm_infos.unwrap_or(None), block_mm_infos,
}) })
} }
Some("BlockRemoved") => { Some("BlockRemoved") => {
...@@ -923,11 +982,16 @@ impl<'de> Visitor<'de> for RawKvEventVisitor { ...@@ -923,11 +982,16 @@ impl<'de> Visitor<'de> for RawKvEventVisitor {
let lora_id: Option<u64> = seq.next_element()?.unwrap_or(None); let lora_id: Option<u64> = seq.next_element()?.unwrap_or(None);
let medium: Option<String> = seq.next_element()?.unwrap_or(None); let medium: Option<String> = seq.next_element()?.unwrap_or(None);
let lora_name: Option<String> = seq.next_element()?.unwrap_or(None); let lora_name: Option<String> = seq.next_element()?.unwrap_or(None);
let extra_keys: Option<Vec<Option<Vec<String>>>> =
seq.next_element()?.unwrap_or(None);
let block_mm_infos: Option<Vec<Option<BlockExtraInfo>>> = let block_mm_infos: Option<Vec<Option<BlockExtraInfo>>> =
seq.next_element()?.unwrap_or(None); seq.next_element()?.unwrap_or(None);
while seq.next_element::<IgnoredAny>()?.is_some() {} while seq.next_element::<IgnoredAny>()?.is_some() {}
let block_mm_infos =
block_mm_infos.or_else(|| extra_keys_to_block_mm_infos(extra_keys));
Ok(RawKvEvent::BlockStored { Ok(RawKvEvent::BlockStored {
block_hashes, block_hashes,
parent_block_hash, parent_block_hash,
...@@ -1206,6 +1270,114 @@ mod test_event_processing { ...@@ -1206,6 +1270,114 @@ mod test_event_processing {
let out = convert_event(raw_evt, 1, kv_block_size, 0, &Arc::new(AtomicU32::new(0))); let out = convert_event(raw_evt, 1, kv_block_size, 0, &Arc::new(AtomicU32::new(0)));
assert!(matches!(out.data, KvCacheEventData::Cleared)); assert!(matches!(out.data, KvCacheEventData::Cleared));
} }
#[test]
fn test_parse_mm_hash_from_extra_key() {
    // A 64-char hex digest maps to its leading 16 hex chars read as a u64.
    let digest = "0123456789abcdef00112233445566778899aabbccddeefffedcba9876543210";
    assert_eq!(
        parse_mm_hash_from_extra_key(digest),
        Some(0x0123_4567_89ab_cdef)
    );

    // Strings that are too short or not hex are rejected.
    for rejected in ["123", "not_a_hash"] {
        assert_eq!(parse_mm_hash_from_extra_key(rejected), None);
    }
}
#[test]
fn test_extra_keys_to_block_mm_infos() {
    const DIGEST: &str = "0123456789abcdef00112233445566778899aabbccddeefffedcba9876543210";
    const EXPECTED_HASH: u64 = 0x0123_4567_89ab_cdef;

    // Block 0: one MM key. Block 1: no keys. Block 2: a non-MM key followed by an MM key.
    let per_block_keys = vec![
        Some(vec![DIGEST.to_string()]),
        None,
        Some(vec!["invalid".to_string(), DIGEST.to_string()]),
    ];

    let infos =
        extra_keys_to_block_mm_infos(Some(per_block_keys)).expect("expected parsed MM infos");

    // Output stays aligned with the input blocks.
    assert_eq!(infos.len(), 3);
    assert_eq!(
        infos[0].as_ref().unwrap().mm_objects[0].mm_hash,
        EXPECTED_HASH
    );
    assert!(infos[1].is_none());
    // The invalid key is dropped, so the MM key is the first object in block 2.
    assert_eq!(
        infos[2].as_ref().unwrap().mm_objects[0].mm_hash,
        EXPECTED_HASH
    );
}
// Verifies that a tuple-encoded (array-like) BlockStored event whose element
// after `lora_name` is an `extra_keys` list gets decoded into `block_mm_infos`.
#[test]
fn test_seq_block_stored_field8_supports_extra_keys() {
let mm_hash =
"0123456789abcdef00112233445566778899aabbccddeefffedcba9876543210".to_string();
// Positional payload: tag first, then the event fields in wire order.
// NOTE(review): the first four fields are presumably block_hashes,
// parent_block_hash, token_ids and block_size — confirm against the seq visitor.
let extra_keys_payload = rmps::to_vec(&(
"BlockStored",
vec![10_u64],
None::<u64>,
vec![1_u32, 2, 3, 4],
4_usize,
// lora_id
None::<u64>,
// medium
None::<String>,
// lora_name
None::<String>,
// extra_keys: one block carrying a single MM digest
vec![Some(vec![mm_hash])],
))
.unwrap();
let extra_keys_event: RawKvEvent = rmps::from_slice(&extra_keys_payload).unwrap();
let RawKvEvent::BlockStored {
lora_name,
block_mm_infos,
..
} = extra_keys_event
else {
panic!("expected BlockStored");
};
// The extra_keys element must not be mistaken for lora_name.
assert!(lora_name.is_none());
// The MM hash is the first 16 hex chars of the digest.
assert_eq!(
block_mm_infos.unwrap()[0].as_ref().unwrap().mm_objects[0].mm_hash,
0x0123_4567_89ab_cdef
);
}
// Verifies that a BlockStored event carrying an `extra_keys` field and no
// `block_mm_infos` field is decoded with `block_mm_infos` derived from it.
#[test]
fn test_map_block_stored_supports_extra_keys() {
// Local mirror of the producer-side event; the field names are the wire keys.
#[derive(serde::Serialize)]
struct MapBlockStoredEvent {
// Serialized under the key "type", which carries the event tag.
#[serde(rename = "type")]
event_type: &'static str,
block_hashes: Vec<u64>,
parent_block_hash: Option<u64>,
token_ids: Vec<u32>,
block_size: usize,
lora_id: Option<u64>,
medium: Option<String>,
lora_name: Option<String>,
extra_keys: Option<Vec<Option<Vec<String>>>>,
}
let payload = rmps::to_vec(&MapBlockStoredEvent {
event_type: "BlockStored",
block_hashes: vec![10],
parent_block_hash: None,
token_ids: vec![1, 2, 3, 4],
block_size: 4,
lora_id: None,
medium: Some("GPU".to_string()),
lora_name: None,
// One block carrying a single 64-char hex MM digest.
extra_keys: Some(vec![Some(vec![
"0123456789abcdef00112233445566778899aabbccddeefffedcba9876543210".to_string(),
])]),
})
.unwrap();
let event: RawKvEvent = rmps::from_slice(&payload).unwrap();
let RawKvEvent::BlockStored { block_mm_infos, .. } = event else {
panic!("expected BlockStored");
};
// The MM hash is the first 16 hex chars of the digest.
assert_eq!(
block_mm_infos.unwrap()[0].as_ref().unwrap().mm_objects[0].mm_hash,
0x0123_4567_89ab_cdef
);
}
} }
#[cfg(test)] #[cfg(test)]
......
...@@ -19,7 +19,7 @@ use crate::{ ...@@ -19,7 +19,7 @@ use crate::{
kv_router::{ kv_router::{
KvRouter, KvRouter,
metrics::RouterRequestMetrics, metrics::RouterRequestMetrics,
protocols::{TokensWithHashes, WorkerWithDpRank}, protocols::{BlockExtraInfo, TokensWithHashes, WorkerWithDpRank},
}, },
preprocessor::PreprocessedRequest, preprocessor::PreprocessedRequest,
protocols::common::{ protocols::common::{
...@@ -108,6 +108,21 @@ impl KvPushRouter { ...@@ -108,6 +108,21 @@ impl KvPushRouter {
KvPushRouter { inner, chooser } KvPushRouter { inner, chooser }
} }
/// Choose which tokens and multimodal block metadata drive worker selection.
///
/// Returns the request's `mm_routing_info` tokens plus its `block_mm_infos`
/// when that info is present and its `routing_token_ids` is non-empty.
/// Otherwise returns the request's `token_ids` with no multimodal metadata.
fn routing_inputs(
    request: &PreprocessedRequest,
) -> (&[u32], Option<&[Option<BlockExtraInfo>]>) {
    request
        .mm_routing_info
        .as_ref()
        // An empty routing token list is treated the same as "not provided".
        .filter(|info| !info.routing_token_ids.is_empty())
        .map(|info| {
            (
                info.routing_token_ids.as_slice(),
                Some(info.block_mm_infos.as_slice()),
            )
        })
        .unwrap_or((request.token_ids.as_slice(), None))
}
/// Select a worker for the request, either using a preselected worker or finding the best match. /// Select a worker for the request, either using a preselected worker or finding the best match.
/// ///
/// When `is_query_only` is false, this also registers the request with the scheduler via `add_request`. /// When `is_query_only` is false, this also registers the request with the scheduler via `add_request`.
...@@ -123,6 +138,7 @@ impl KvPushRouter { ...@@ -123,6 +138,7 @@ impl KvPushRouter {
let priority_jump = routing.and_then(|r| r.priority_jump).unwrap_or(0.0); let priority_jump = routing.and_then(|r| r.priority_jump).unwrap_or(0.0);
let dp_rank = routing.and_then(|r| r.dp_rank).unwrap_or(0); let dp_rank = routing.and_then(|r| r.dp_rank).unwrap_or(0);
let expected_output_tokens = routing.and_then(|r| r.expected_output_tokens); let expected_output_tokens = routing.and_then(|r| r.expected_output_tokens);
let (routing_token_ids, block_mm_infos) = Self::routing_inputs(request);
// Get pre-selected worker based on phase, with backend_instance_id as fallback // Get pre-selected worker based on phase, with backend_instance_id as fallback
let preselected_id = match phase { let preselected_id = match phase {
...@@ -140,7 +156,8 @@ impl KvPushRouter { ...@@ -140,7 +156,8 @@ impl KvPushRouter {
.chooser .chooser
.find_best_match( .find_best_match(
Some(context_id), Some(context_id),
&request.token_ids, routing_token_ids,
block_mm_infos,
request.router_config_override.as_ref(), request.router_config_override.as_ref(),
!is_query_only, !is_query_only,
lora_name, lora_name,
...@@ -148,6 +165,27 @@ impl KvPushRouter { ...@@ -148,6 +165,27 @@ impl KvPushRouter {
) )
.await?; .await?;
if !is_query_only {
let total_blocks = routing_token_ids
.len()
.div_ceil(self.chooser.block_size() as usize);
// NOTE: tests/mm_router/test_vllm_mm_router_e2e.py parses this log line.
// Keep the "[ROUTING] ... with X/Y blocks overlap" shape stable unless
// router tests are updated together.
tracing::debug!(
request_id = %context_id,
worker_id = best_worker.worker_id,
dp_rank = best_worker.dp_rank,
overlap_blocks = overlap_amount,
total_blocks = total_blocks,
"[ROUTING] Best: worker_{} dp_rank={} with {}/{} blocks overlap",
best_worker.worker_id,
best_worker.dp_rank,
overlap_amount,
total_blocks,
);
}
return Ok(WorkerSelection { return Ok(WorkerSelection {
instance_id: best_worker.worker_id, instance_id: best_worker.worker_id,
dp_rank: best_worker.dp_rank, dp_rank: best_worker.dp_rank,
...@@ -165,14 +203,14 @@ impl KvPushRouter { ...@@ -165,14 +203,14 @@ impl KvPushRouter {
let worker = WorkerWithDpRank::new(id, dp_rank); let worker = WorkerWithDpRank::new(id, dp_rank);
let overlap_blocks = self let overlap_blocks = self
.chooser .chooser
.get_overlap_blocks(&request.token_ids, worker) .get_overlap_blocks(routing_token_ids, worker)
.await?; .await?;
if !is_query_only { if !is_query_only {
self.chooser self.chooser
.add_request( .add_request(
context_id.to_string(), context_id.to_string(),
&request.token_ids, routing_token_ids,
overlap_blocks, overlap_blocks,
expected_output_tokens, expected_output_tokens,
worker, worker,
...@@ -275,10 +313,11 @@ impl AsyncEngine<SingleIn<PreprocessedRequest>, ManyOut<Annotated<LLMEngineOutpu ...@@ -275,10 +313,11 @@ impl AsyncEngine<SingleIn<PreprocessedRequest>, ManyOut<Annotated<LLMEngineOutpu
let request_metrics = let request_metrics =
RouterRequestMetrics::from_component(self.chooser.client().endpoint.component()); RouterRequestMetrics::from_component(self.chooser.client().endpoint.component());
if let Some(ref tracker) = request.tracker { if let Some(ref tracker) = request.tracker {
let isl_blocks = request.token_ids.len().div_ceil(block_size); let (routing_token_ids, _) = Self::routing_inputs(&request);
let isl_blocks = routing_token_ids.len().div_ceil(block_size);
tracker.record_kv_hit(overlap_amount, isl_blocks); tracker.record_kv_hit(overlap_amount, isl_blocks);
tracker.record_isl( tracker.record_isl(
request.token_ids.len(), routing_token_ids.len(),
overlap_amount as usize * block_size, overlap_amount as usize * block_size,
); );
tracker.record_worker_full(instance_id, dp_rank, self.chooser.worker_type()); tracker.record_worker_full(instance_id, dp_rank, self.chooser.worker_type());
......
...@@ -9,6 +9,7 @@ use serde::{Deserialize, Serialize}; ...@@ -9,6 +9,7 @@ use serde::{Deserialize, Serialize};
use super::timing::RequestTracker; use super::timing::RequestTracker;
use super::{OutputOptions, SamplingOptions, StopConditions}; use super::{OutputOptions, SamplingOptions, StopConditions};
use crate::kv_router::RouterConfigOverride; use crate::kv_router::RouterConfigOverride;
use crate::kv_router::protocols::BlockExtraInfo;
use crate::preprocessor::media::RdmaMediaDataDescriptor; use crate::preprocessor::media::RdmaMediaDataDescriptor;
use crate::protocols::TokenIdType; use crate::protocols::TokenIdType;
...@@ -72,6 +73,20 @@ pub struct PrefillResult { ...@@ -72,6 +73,20 @@ pub struct PrefillResult {
pub prompt_tokens_details: Option<dynamo_async_openai::types::PromptTokensDetails>, pub prompt_tokens_details: Option<dynamo_async_openai::types::PromptTokensDetails>,
} }
/// Optional multimodal routing-only data.
/// This is used by the router to compute overlaps on an alternate token sequence
/// (for example, MM-expanded tokens) without changing execution token_ids.
///
/// The derived `Default` leaves both vectors empty.
#[derive(Serialize, Deserialize, Debug, Clone, Default, Builder)]
#[builder(default)]
pub struct MmRoutingInfo {
/// Token IDs to use for routing overlap computation.
/// When this is empty, the `routing_inputs` helpers fall back to the
/// request's own `token_ids` and ignore `block_mm_infos`.
pub routing_token_ids: Vec<TokenIdType>,
/// Block-level multimodal metadata aligned with routing_token_ids blocks.
/// Use `None` entries for blocks without multimodal objects.
pub block_mm_infos: Vec<Option<BlockExtraInfo>>,
}
#[derive(Serialize, Deserialize, Debug, Clone)] #[derive(Serialize, Deserialize, Debug, Clone)]
pub enum MultimodalData { pub enum MultimodalData {
Url(url::Url), Url(url::Url),
...@@ -102,6 +117,11 @@ pub struct PreprocessedRequest { ...@@ -102,6 +117,11 @@ pub struct PreprocessedRequest {
#[serde(default, skip_serializing_if = "Option::is_none")] #[serde(default, skip_serializing_if = "Option::is_none")]
pub multi_modal_data: Option<MultimodalDataMap>, pub multi_modal_data: Option<MultimodalDataMap>,
/// Optional multimodal routing-only fields (separate from execution payload).
#[builder(default)]
#[serde(default, skip_serializing_if = "Option::is_none")]
pub mm_routing_info: Option<MmRoutingInfo>,
/// StopConditions are conditions that the inference engine will use to stop generation. /// StopConditions are conditions that the inference engine will use to stop generation.
pub stop_conditions: StopConditions, pub stop_conditions: StopConditions,
......
# SPDX-FileCopyrightText: Copyright (c) 2025-2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
"""End-to-end tests for multimodal KV routing with vLLM.
Architecture:
Frontend -> MM Router Worker -> vLLM Worker
(computes mm_hash) (publishes KV events)
This test validates MM-aware routing by sending repeated multimodal requests and
asserting that router overlap is greater than 1 block (regression guard against
text-only/partially-matched hash paths that typically show 1/N overlap).
"""
from __future__ import annotations
import base64
import os
import re
import shutil
import time
from io import BytesIO
from typing import Any, Generator
import pytest
import requests
from tests.conftest import EtcdServer, NatsServer
from tests.utils.managed_process import ManagedProcess
from tests.utils.payloads import check_models_api
from tests.utils.port_utils import allocate_ports
VLLM_MM_MODEL = "Qwen/Qwen2.5-VL-7B-Instruct"
BLOCK_SIZE = 16
NAMESPACE = "dynamo"
THREE_IMAGE_TOTAL_BLOCKS_RANGE = (200, 340)
SINGLE_IMAGE_TOTAL_BLOCKS_RANGE = (60, 160)
pytestmark = [
pytest.mark.e2e,
pytest.mark.vllm,
pytest.mark.multimodal,
pytest.mark.gpu_1,
pytest.mark.model(VLLM_MM_MODEL),
]
_COLORS = [
(255, 0, 0),
(0, 255, 0),
(0, 0, 255),
]
_ALT_COLORS = [
(255, 255, 0),
(0, 255, 255),
(255, 0, 255),
]
_SINGLE_IMAGE_FRESH_COLOR = (123, 45, 67)
_DOUBLE_IMAGE_FRESH_COLOR = (89, 210, 34)
_STAIRCASE_IMAGE_FRESH_COLOR = (17, 99, 201)
_SWAP_ORDER_FRESH_COLORS = [(14, 141, 77), (211, 66, 101), (44, 91, 233)]
# Contract with lib/llm/src/kv_router/push_router.rs "[ROUTING]" debug log.
# Keep this parser in sync with the router log format.
_ROUTING_RECORD_PATTERN = re.compile(
r"\[ROUTING\].*with\s*(\d+)/(\d+)\s*blocks overlap"
)
def _check_ready(response) -> bool:
    """Return True when a health response body reports ``"status": "ready"``.

    Args:
        response: Object exposing a ``json()`` method (an HTTP response).

    Returns:
        True only if the body decodes to a JSON object whose ``status`` is
        ``"ready"``. A body that is not valid JSON, is empty/null, or is not a
        JSON object (for example a list) yields False instead of raising, so
        a malformed health reply is treated as "not ready yet".
    """
    try:
        body = response.json()
    except ValueError:
        return False
    # A non-object body has no ``.get``; treat it as not ready rather than crash.
    if not isinstance(body, dict):
        return False
    return body.get("status") == "ready"
def _make_process_env(log_level: str = "debug", **extra) -> dict[str, str]:
    """Build the environment for a managed test process.

    Starts from a copy of the current process environment, sets the Dynamo
    log level, namespace and request plane, then applies ``extra`` last so
    callers can override any of them.
    """
    dynamo_settings = {
        "DYN_LOG": log_level,
        "DYN_NAMESPACE": NAMESPACE,
        "DYN_REQUEST_PLANE": "nats",
    }
    return {**os.environ, **dynamo_settings, **extra}
def _prepare_log_dir(request, suffix: str) -> str:
    """Return a per-test log directory name, removing any leftover directory.

    The name is ``<test node name>_<suffix>``. An existing directory at that
    path is deleted (errors ignored) so stale logs are not mixed into this run.
    """
    target = "{}_{}".format(request.node.name, suffix)
    shutil.rmtree(target, ignore_errors=True)
    return target
_COMMON_PROCESS_KWARGS: dict[str, Any] = {
"display_output": True,
"terminate_all_matching_process_names": False,
}
class VLLMWorkerProcess(ManagedProcess):
    """vLLM backend worker that emits KV events."""

    def __init__(self, request, *, system_port: int):
        """Launch ``dynamo.vllm`` with multimodal support enabled.

        Args:
            request: The pytest request; used to derive the log directory.
            system_port: Port exported as ``DYN_SYSTEM_PORT`` and polled for health.
        """
        # The worker is served under an "__internal" name, distinct from VLLM_MM_MODEL.
        served_model_name = f"{VLLM_MM_MODEL}__internal"
        launch_command = [
            "python3",
            "-m",
            "dynamo.vllm",
            "--model",
            VLLM_MM_MODEL,
            "--enable-multimodal",
            "--gpu-memory-utilization",
            "0.85",
            "--max-model-len",
            "8192",
            "--connector",
            "none",
            "--served-model-name",
            served_model_name,
        ]
        health_endpoint = f"http://localhost:{system_port}/health"
        super().__init__(
            command=launch_command,
            env=_make_process_env(DYN_SYSTEM_PORT=str(system_port)),
            health_check_urls=[(health_endpoint, _check_ready)],
            timeout=900,
            straggler_commands=["-m dynamo.vllm"],
            log_dir=_prepare_log_dir(request, "vllm-worker"),
            **_COMMON_PROCESS_KWARGS,
        )
class VLLMMMRouterWorkerProcess(ManagedProcess):
    """vLLM MM router worker."""

    def __init__(self, request, *, system_port: int):
        """Launch the example MM router worker in front of the ``backend`` component.

        Args:
            request: The pytest request; used to derive the log directory.
            system_port: Port exported as ``DYN_SYSTEM_PORT`` and polled for health.
        """
        launch_command = [
            "python3",
            "-m",
            "examples.backends.vllm.mm_router_worker",
            "--model",
            VLLM_MM_MODEL,
            "--namespace",
            NAMESPACE,
            "--component",
            "mm_router",
            "--endpoint",
            "generate",
            "--downstream-component",
            "backend",
            "--downstream-endpoint",
            "generate",
            "--block-size",
            str(BLOCK_SIZE),
        ]
        process_env = _make_process_env(
            DYN_SYSTEM_USE_ENDPOINT_HEALTH_STATUS='["generate"]',
            DYN_SYSTEM_PORT=str(system_port),
        )
        health_endpoint = f"http://localhost:{system_port}/health"
        super().__init__(
            command=launch_command,
            env=process_env,
            health_check_urls=[(health_endpoint, _check_ready)],
            timeout=240,
            straggler_commands=["mm_router_worker"],
            log_dir=_prepare_log_dir(request, "vllm-mm-router"),
            **_COMMON_PROCESS_KWARGS,
        )
class FrontendProcess(ManagedProcess):
    """Frontend HTTP ingress."""

    def __init__(self, request, *, frontend_port: int):
        """Launch ``dynamo.frontend`` in round-robin router mode.

        Args:
            request: The pytest request; used to derive the log directory.
            frontend_port: HTTP port to serve on; readiness is probed via ``/v1/models``.
        """
        launch_command = [
            "python3",
            "-m",
            "dynamo.frontend",
            "--http-port",
            str(frontend_port),
            "--router-mode",
            "round-robin",
        ]
        models_endpoint = f"http://localhost:{frontend_port}/v1/models"
        super().__init__(
            command=launch_command,
            # The frontend runs at "info" rather than the default "debug" level.
            env=_make_process_env(log_level="info"),
            health_check_urls=[(models_endpoint, check_models_api)],
            timeout=240,
            straggler_commands=["-m dynamo.frontend"],
            log_dir=_prepare_log_dir(request, "vllm-mm-frontend"),
            **_COMMON_PROCESS_KWARGS,
        )
@pytest.fixture(scope="module")
def mm_runtime_services(request):
    """Start NATS and etcd for the module and point the environment at them.

    ``NATS_SERVER`` and ``ETCD_ENDPOINTS`` are set to the dynamically chosen
    ports for the lifetime of the module. On teardown each variable is put
    back to the value it had before the fixture ran (or removed if it was
    unset), so a pre-existing configuration is not lost for later tests.
    """
    with NatsServer(request, port=0) as nats, EtcdServer(request, port=0) as etcd:
        overrides = {
            "NATS_SERVER": f"nats://localhost:{nats.port}",
            "ETCD_ENDPOINTS": f"http://localhost:{etcd.port}",
        }
        # Remember prior values so teardown restores them instead of deleting them.
        previous = {key: os.environ.get(key) for key in overrides}
        os.environ.update(overrides)
        try:
            yield
        finally:
            for key, old_value in previous.items():
                if old_value is None:
                    os.environ.pop(key, None)
                else:
                    os.environ[key] = old_value
@pytest.fixture(scope="module")
def start_vllm_mm_services(
    request, mm_runtime_services
) -> Generator[tuple[int, ManagedProcess], None, None]:
    """Bring up the vLLM worker, MM router worker and frontend, in that order.

    Yields:
        ``(frontend_port, router_proc)``: the HTTP port to send requests to
        and the MM router process whose logs carry the routing records.
    """
    http_port, worker_system_port, router_system_port = allocate_ports(
        count=3, start_port=10000
    )
    with VLLMWorkerProcess(request, system_port=worker_system_port):
        # NOTE(review): fixed pause, presumably to let the worker settle
        # before the router starts — confirm.
        time.sleep(10)
        with VLLMMMRouterWorkerProcess(
            request, system_port=router_system_port
        ) as mm_router_proc:
            # NOTE(review): fixed pause before starting the frontend — reason
            # not visible here.
            time.sleep(3)
            with FrontendProcess(request, frontend_port=http_port):
                yield http_port, mm_router_proc
def _make_data_uri(color: tuple[int, int, int], size: int = 1024) -> str:
    """Return a ``data:image/png;base64,...`` URI for a solid-color square image.

    Args:
        color: RGB fill color.
        size: Width and height of the image in pixels.
    """
    # Imported lazily so the module can be imported without Pillow installed.
    from PIL import Image

    png_buffer = BytesIO()
    Image.new("RGB", (size, size), color).save(png_buffer, format="PNG")
    encoded = base64.b64encode(png_buffer.getvalue()).decode("utf-8")
    return "data:image/png;base64," + encoded
def _build_payload(
    image_uris: list[str], prompt: str = "Describe what you see."
) -> dict[str, Any]:
    """Build a chat-completions request body with one text part and N image parts.

    The text part comes first, followed by the images in the order given.
    ``max_tokens`` is fixed at 1.
    """
    text_part: dict[str, Any] = {"type": "text", "text": prompt}
    image_parts: list[dict[str, Any]] = [
        {"type": "image_url", "image_url": {"url": uri}} for uri in image_uris
    ]
    user_message = {"role": "user", "content": [text_part, *image_parts]}
    return {
        "model": VLLM_MM_MODEL,
        "messages": [user_message],
        "max_tokens": 1,
    }
def _extract_routing_records(log_text: str) -> list[tuple[int, int]]:
    """Return every ``(overlap_blocks, total_blocks)`` pair logged by the router.

    Pairs are returned in the order their ``[ROUTING]`` lines appear in
    ``log_text``.
    """
    records: list[tuple[int, int]] = []
    for match in _ROUTING_RECORD_PATTERN.finditer(log_text):
        overlap_text, total_text = match.group(1), match.group(2)
        records.append((int(overlap_text), int(total_text)))
    return records
def _wait_for_new_routing_score(
    router_proc: ManagedProcess,
    start_offset: int,
    pre_request_routing_count: int,
    timeout_s: float = 120.0,
) -> tuple[int, int, str]:
    """Poll the router logs until a new routing record shows up.

    Args:
        router_proc: Process handle whose ``read_logs()`` returns the full
            router log text.
        start_offset: Character offset into the log text captured before the
            request was sent; the returned segment starts here.
        pre_request_routing_count: Number of routing records present in the
            logs before the request was sent.
        timeout_s: Maximum time to keep polling, in seconds.

    Returns:
        ``(overlap, total, segment)`` where ``overlap``/``total`` come from
        the most recent routing record and ``segment`` is the log text from
        ``start_offset`` onwards. If no new record appears before the
        timeout, the last record found in the segment is used, or ``(0, 0)``
        when the segment contains none.
    """
    # Use the monotonic clock so wall-clock adjustments cannot stretch or
    # shrink the wait.
    deadline = time.monotonic() + timeout_s
    segment = ""
    while True:
        # Poll before checking the deadline: the logs are always inspected
        # at least once, and once more after the final sleep.
        full_logs = router_proc.read_logs()
        segment = full_logs[start_offset:]
        records = _extract_routing_records(full_logs)
        if len(records) > pre_request_routing_count:
            overlap, total = records[-1]
            return overlap, total, segment

        remaining = deadline - time.monotonic()
        if remaining <= 0:
            break
        # Never sleep past the deadline.
        time.sleep(min(1.0, remaining))

    # Timed out: best-effort fallback to whatever the post-request segment holds.
    fallback_records = _extract_routing_records(segment)
    if fallback_records:
        overlap, total = fallback_records[-1]
        return overlap, total, segment
    return 0, 0, segment
def _send_request_get_overlap(
    frontend_port: int,
    router_proc: ManagedProcess,
    payload: dict[str, Any],
    label: str,
) -> tuple[int, int, str]:
    """Send one chat-completions request and report the resulting routing score.

    The router log is snapshotted before the request so that only a routing
    record written afterwards counts as the score for this request.

    Args:
        frontend_port: Port of the HTTP frontend on localhost.
        router_proc: Router process handle used to read logs.
        payload: JSON body to POST.
        label: Tag printed alongside the observed score.

    Returns:
        ``(overlap, total, segment)`` as produced by
        ``_wait_for_new_routing_score``.
    """
    logs_before = router_proc.read_logs()
    records_before = _extract_routing_records(logs_before)

    url = f"http://localhost:{frontend_port}/v1/chat/completions"
    response = requests.post(url, json=payload, timeout=240)
    assert (
        response.status_code == 200
    ), f"HTTP {response.status_code}: {response.text}"
    body = response.json()
    assert "choices" in body, f"Missing choices in response: {body}"

    score = _wait_for_new_routing_score(
        router_proc=router_proc,
        start_offset=len(logs_before),
        pre_request_routing_count=len(records_before),
        timeout_s=120,
    )
    overlap, total, segment = score
    print(f"[MM_ROUTER_E2E] {label}: current={overlap}/{total}")
    return overlap, total, segment
@pytest.mark.timeout(1800)
@pytest.mark.nightly
def test_vllm_text_only_overlap_repeated_prompt(
    start_vllm_mm_services, predownload_models
):
    """Text-only routing should increase overlap on repeat and then stabilize.

    The same long text-only request is sent three times; totals must be
    non-zero and stable, overlap must rise on the second request and stay
    equal on the third.
    """
    frontend_port, router_proc = start_vllm_mm_services

    sentence = (
        "TEXT routing e2e unique case zeta-7f31. "
        "Repeat this sentence to force multiple KV blocks. "
    )
    payload = _build_payload([], prompt=sentence * 80)

    # Send the identical request three times, strictly in order.
    first, second, third = (
        _send_request_get_overlap(frontend_port, router_proc, payload, label)
        for label in ("text_only_req1", "text_only_req2", "text_only_req3")
    )
    overlap_1, total_1, _ = first
    overlap_2, total_2, _ = second
    overlap_3, total_3, segment_3 = third
    log_tail = segment_3[-4000:]

    assert min(total_1, total_2, total_3) > 0, (
        f"Expected non-zero total blocks for text-only request, got "
        f"{total_1}, {total_2}, {total_3}.\n"
        f"Recent router logs:\n{log_tail}"
    )

    largest_drift = max(abs(total_1 - total_2), abs(total_2 - total_3))
    assert largest_drift <= 2, (
        f"Expected text-only total blocks to remain stable across repeats, got "
        f"req1={total_1}, req2={total_2}, req3={total_3}.\n"
        f"Recent router logs:\n{log_tail}"
    )

    assert overlap_2 > overlap_1, (
        f"Expected second text-only overlap > first, got "
        f"req1={overlap_1}/{total_1}, req2={overlap_2}/{total_2}.\n"
        f"Recent router logs:\n{log_tail}"
    )
    assert overlap_3 == overlap_2, (
        f"Expected third text-only overlap == second, got "
        f"req2={overlap_2}/{total_2}, req3={overlap_3}/{total_3}.\n"
        f"Recent router logs:\n{log_tail}"
    )
@pytest.mark.timeout(1800)
@pytest.mark.nightly
def test_vllm_mm_overlap_repeated_three_images(
    start_vllm_mm_services, predownload_models
):
    """For repeated same 3-image request: low first overlap, then increase, then stable.

    Also checks that the third request's total block count falls inside
    ``THREE_IMAGE_TOTAL_BLOCKS_RANGE``.
    """
    frontend_port, router_proc = start_vllm_mm_services

    payload = _build_payload(
        [_make_data_uri(color) for color in _COLORS],
        prompt="MM routing e2e: repeated same 3-image request.",
    )

    # Send the identical request three times, strictly in order.
    first, second, third = (
        _send_request_get_overlap(frontend_port, router_proc, payload, label)
        for label in (
            "same_3_images_req1",
            "same_3_images_req2",
            "same_3_images_req3",
        )
    )
    overlap_1, total_1, _ = first
    overlap_2, total_2, _ = second
    overlap_3, total_3, segment_3 = third
    log_tail = segment_3[-4000:]

    assert overlap_1 <= 1, (
        f"Expected first overlap <=1, got req1={overlap_1}/{total_1}.\n"
        f"Recent router logs:\n{log_tail}"
    )
    assert overlap_2 > overlap_1, (
        f"Expected second overlap > first, got req1={overlap_1}/{total_1}, req2={overlap_2}/{total_2}.\n"
        f"Recent router logs:\n{log_tail}"
    )
    assert overlap_3 == overlap_2, (
        f"Expected third overlap == second, got req2={overlap_2}/{total_2}, req3={overlap_3}/{total_3}.\n"
        f"Recent router logs:\n{log_tail}"
    )

    min_blocks, max_blocks = THREE_IMAGE_TOTAL_BLOCKS_RANGE
    assert min_blocks <= total_3 <= max_blocks, (
        f"Unexpected total blocks for same 3 images (1024): "
        f"got {total_3}, expected in [{min_blocks}, {max_blocks}]"
    )
@pytest.mark.timeout(1800)
@pytest.mark.nightly
def test_vllm_mm_overlap_repeated_single_image(
    start_vllm_mm_services, predownload_models
):
    """For repeated same single-image request: low first overlap, then increase, then stable.

    Also checks that the third request's total block count falls inside
    ``SINGLE_IMAGE_TOTAL_BLOCKS_RANGE``.
    """
    frontend_port, router_proc = start_vllm_mm_services

    image_uri = _make_data_uri(_SINGLE_IMAGE_FRESH_COLOR)
    payload = _build_payload(
        [image_uri],
        prompt="MM routing e2e: repeated same single-image request.",
    )

    # Send the identical request three times, strictly in order.
    first, second, third = (
        _send_request_get_overlap(frontend_port, router_proc, payload, label)
        for label in (
            "same_single_image_req1",
            "same_single_image_req2",
            "same_single_image_req3",
        )
    )
    overlap_1, total_1, _ = first
    overlap_2, total_2, _ = second
    overlap_3, total_3, segment_3 = third
    log_tail = segment_3[-4000:]

    assert overlap_1 <= 1, (
        f"Expected first overlap <=1, got req1={overlap_1}/{total_1}.\n"
        f"Recent router logs:\n{log_tail}"
    )
    assert overlap_2 > overlap_1, (
        f"Expected second overlap > first, got req1={overlap_1}/{total_1}, req2={overlap_2}/{total_2}.\n"
        f"Recent router logs:\n{log_tail}"
    )
    assert overlap_3 == overlap_2, (
        f"Expected third overlap == second, got req2={overlap_2}/{total_2}, req3={overlap_3}/{total_3}.\n"
        f"Recent router logs:\n{log_tail}"
    )

    min_blocks, max_blocks = SINGLE_IMAGE_TOTAL_BLOCKS_RANGE
    assert min_blocks <= total_3 <= max_blocks, (
        f"Unexpected total blocks for same 1 image (1024): "
        f"got {total_3}, expected in [{min_blocks}, {max_blocks}]"
    )
@pytest.mark.timeout(1800)
@pytest.mark.nightly
def test_vllm_mm_overlap_repeated_two_identical_images(
    start_vllm_mm_services, predownload_models
):
    """For repeated same two-identical-image request: low first overlap, then increase, then stable.

    The request carries the same image URI twice.
    """
    frontend_port, router_proc = start_vllm_mm_services

    image_uri = _make_data_uri(_DOUBLE_IMAGE_FRESH_COLOR)
    payload = _build_payload(
        [image_uri] * 2,
        prompt="MM routing e2e: repeated same two-identical-image request.",
    )

    # Send the identical request three times, strictly in order.
    first, second, third = (
        _send_request_get_overlap(frontend_port, router_proc, payload, label)
        for label in (
            "same_two_identical_images_req1",
            "same_two_identical_images_req2",
            "same_two_identical_images_req3",
        )
    )
    overlap_1, total_1, _ = first
    overlap_2, total_2, _ = second
    overlap_3, total_3, segment_3 = third
    log_tail = segment_3[-4000:]

    assert overlap_1 <= 1, (
        f"Expected first overlap <=1, got req1={overlap_1}/{total_1}.\n"
        f"Recent router logs:\n{log_tail}"
    )
    assert overlap_2 > overlap_1, (
        f"Expected second overlap > first, got req1={overlap_1}/{total_1}, req2={overlap_2}/{total_2}.\n"
        f"Recent router logs:\n{log_tail}"
    )
    assert overlap_3 == overlap_2, (
        f"Expected third overlap == second, got req2={overlap_2}/{total_2}, req3={overlap_3}/{total_3}.\n"
        f"Recent router logs:\n{log_tail}"
    )
@pytest.mark.timeout(1800)
@pytest.mark.nightly
def test_vllm_mm_overlap_staircase_single_to_double_to_triple_identical_image(
    start_vllm_mm_services, predownload_models
):
    """Single->double->triple identical image requests follow prefix-overlap semantics.

    One, two and three copies of the same image are sent with the same
    prompt. Overlap must grow from the 1x to the 2x request, stay within one
    block between 2x and 3x, and the total-block increment per added image
    must be similar for both steps.
    """
    frontend_port, router_proc = start_vllm_mm_services

    image_uri = _make_data_uri(_STAIRCASE_IMAGE_FRESH_COLOR)
    staircase_prompt = "MM routing e2e: staircase."
    payload_single, payload_double, payload_triple = (
        _build_payload([image_uri] * copies, prompt=staircase_prompt)
        for copies in (1, 2, 3)
    )

    overlap_1, total_1, _ = _send_request_get_overlap(
        frontend_port, router_proc, payload_single, "staircase_1x_image"
    )
    overlap_2, total_2, segment_2 = _send_request_get_overlap(
        frontend_port, router_proc, payload_double, "staircase_2x_image"
    )
    # Short pause before the 3x probe, kept exactly as in the original flow.
    time.sleep(1)
    overlap_3, total_3, segment_3 = _send_request_get_overlap(
        frontend_port, router_proc, payload_triple, "staircase_3x_image"
    )
    tail_2 = segment_2[-4000:]
    tail_3 = segment_3[-4000:]

    assert overlap_2 > overlap_1, (
        f"Expected overlap to increase from 1 image to 2 images, got "
        f"1x={overlap_1}/{total_1}, 2x={overlap_2}/{total_2}.\n"
        f"Recent router logs:\n{tail_2}"
    )

    assert abs(overlap_3 - overlap_2) <= 1, (
        "Expected first 3-image request overlap to stay near 2-image overlap "
        "(third-image suffix is cold on first 3-image request), got "
        f"2x={overlap_2}/{total_2}, 3x={overlap_3}/{total_3}.\n"
        f"Recent router logs:\n{tail_3}"
    )

    total_step_12 = total_2 - total_1
    total_step_23 = total_3 - total_2
    step_gap = abs(total_step_12 - total_step_23)
    assert step_gap <= 4, (
        "Expected similar total-block increment per additional identical image, got "
        f"step(1->2)={total_step_12}, step(2->3)={total_step_23}.\n"
        f"Recent router logs:\n{tail_3}"
    )
@pytest.mark.timeout(1800)
@pytest.mark.nightly
def test_vllm_mm_overlap_diff_images_less_than_same(
    start_vllm_mm_services, predownload_models
):
    """Different images should produce lower overlap than repeated identical images.

    A baseline request (``_COLORS`` images) is sent twice, then a probe with
    the same prompt but ``_ALT_COLORS`` images is sent once and compared to
    the baseline.
    """
    frontend_port, router_proc = start_vllm_mm_services

    # Baseline and probe deliberately share the prompt; only the images differ.
    shared_prompt = "MM routing e2e: baseline same-images overlap."
    baseline_payload = _build_payload(
        [_make_data_uri(color) for color in _COLORS],
        prompt=shared_prompt,
    )

    baseline_first, baseline_second = (
        _send_request_get_overlap(
            frontend_port, router_proc, baseline_payload, label
        )
        for label in ("baseline_same_images_req1", "baseline_same_images_req2")
    )
    overlap_baseline_1, total_baseline_1, _ = baseline_first
    overlap_baseline_2, total_baseline_2, segment_baseline = baseline_second

    overlap_baseline = max(overlap_baseline_1, overlap_baseline_2)
    total_baseline = total_baseline_2

    assert abs(total_baseline_1 - total_baseline_2) <= 4, (
        "Expected total blocks to stay nearly identical for repeated same request, "
        f"got req1={total_baseline_1}, req2={total_baseline_2}"
    )
    assert overlap_baseline >= 2, (
        f"Baseline overlap did not reach 2 blocks. got {overlap_baseline}/{total_baseline}.\n"
        f"Recent router logs:\n{segment_baseline[-4000:]}"
    )
    min_blocks, max_blocks = THREE_IMAGE_TOTAL_BLOCKS_RANGE
    assert min_blocks <= total_baseline <= max_blocks, (
        f"Unexpected total blocks for baseline same-images request: "
        f"got {total_baseline}, expected in [{min_blocks}, {max_blocks}]"
    )

    probe_payload = _build_payload(
        [_make_data_uri(color) for color in _ALT_COLORS],
        prompt=shared_prompt,
    )
    overlap_probe, total_probe, segment_probe = _send_request_get_overlap(
        frontend_port, router_proc, probe_payload, "probe_different_images_req1"
    )
    probe_tail = segment_probe[-4000:]

    assert total_probe > 0, f"No routing score found.\nRecent logs:\n{probe_tail}"
    assert abs(total_probe - total_baseline) <= 4, (
        f"Expected different-images total blocks to stay near baseline, "
        f"got different={total_probe}, baseline={total_baseline}"
    )
    assert overlap_probe < overlap_baseline, (
        f"Expected different-images overlap < baseline overlap, "
        f"got different={overlap_probe}/{total_probe}, "
        f"baseline={overlap_baseline}/{total_baseline}.\n"
        f"Recent router logs:\n{probe_tail}"
    )
@pytest.mark.timeout(1800)
@pytest.mark.nightly
def test_vllm_mm_overlap_same_images_different_prompt_less_than_same_prompt(
    start_vllm_mm_services, predownload_models
):
    """Same images but different prompt should produce lower overlap than repeated same prompt.

    A baseline request (``_COLORS`` images, prompt "alpha") is sent twice,
    then a probe with the same images and prompt "omega" is sent once and
    compared to the baseline.
    """
    frontend_port, router_proc = start_vllm_mm_services

    baseline_payload = _build_payload(
        [_make_data_uri(color) for color in _COLORS],
        prompt="MM routing e2e: prompt-sensitive baseline alpha.",
    )

    baseline_first, baseline_second = (
        _send_request_get_overlap(
            frontend_port, router_proc, baseline_payload, label
        )
        for label in (
            "baseline_same_images_prompt_a_req1",
            "baseline_same_images_prompt_a_req2",
        )
    )
    overlap_baseline_1, total_baseline_1, _ = baseline_first
    overlap_baseline_2, total_baseline_2, segment_baseline = baseline_second

    overlap_baseline = max(overlap_baseline_1, overlap_baseline_2)
    total_baseline = total_baseline_2

    assert abs(total_baseline_1 - total_baseline_2) <= 4, (
        "Expected total blocks to stay nearly identical for repeated same request, "
        f"got req1={total_baseline_1}, req2={total_baseline_2}"
    )
    assert overlap_baseline >= 2, (
        f"Baseline overlap did not reach 2 blocks. got {overlap_baseline}/{total_baseline}.\n"
        f"Recent router logs:\n{segment_baseline[-4000:]}"
    )
    min_blocks, max_blocks = THREE_IMAGE_TOTAL_BLOCKS_RANGE
    assert min_blocks <= total_baseline <= max_blocks, (
        f"Unexpected total blocks for baseline same-images request: "
        f"got {total_baseline}, expected in [{min_blocks}, {max_blocks}]"
    )

    probe_payload = _build_payload(
        [_make_data_uri(color) for color in _COLORS],
        prompt="MM routing e2e: prompt-sensitive baseline omega.",
    )
    overlap_probe, total_probe, segment_probe = _send_request_get_overlap(
        frontend_port, router_proc, probe_payload, "probe_same_images_prompt_b_req1"
    )
    probe_tail = segment_probe[-4000:]

    assert total_probe > 0, f"No routing score found.\nRecent logs:\n{probe_tail}"
    assert abs(total_probe - total_baseline) <= 4, (
        f"Expected different-prompt total blocks to stay near baseline, "
        f"got different_prompt={total_probe}, baseline={total_baseline}"
    )
    assert overlap_probe < overlap_baseline, (
        f"Expected different-prompt overlap < baseline overlap, "
        f"got different_prompt={overlap_probe}/{total_probe}, "
        f"baseline={overlap_baseline}/{total_baseline}.\n"
        f"Recent router logs:\n{probe_tail}"
    )
@pytest.mark.timeout(1800)
@pytest.mark.nightly
def test_vllm_mm_overlap_swapped_order_less_than_same_order(
    start_vllm_mm_services, predownload_models
):
    """Swapping order of three distinct images should result in near-zero overlap.

    The ordered request is sent twice, then the same images are sent once in
    reversed order with the same prompt.
    """
    frontend_port, router_proc = start_vllm_mm_services

    # Both payloads deliberately share the prompt; only the image order differs.
    shared_prompt = "MM routing e2e: order sensitivity ordered baseline."
    ordered_uris = [_make_data_uri(color) for color in _SWAP_ORDER_FRESH_COLORS]
    ordered_payload = _build_payload(ordered_uris, prompt=shared_prompt)
    swapped_payload = _build_payload(ordered_uris[::-1], prompt=shared_prompt)

    request_plan = (
        (ordered_payload, "ordered_distinct_images_req1"),
        (ordered_payload, "ordered_distinct_images_req2"),
        (swapped_payload, "swapped_distinct_images_req1"),
    )
    ordered_first, ordered_second, swapped = (
        _send_request_get_overlap(frontend_port, router_proc, body, label)
        for body, label in request_plan
    )
    overlap_ordered_1, total_ordered_1, _ = ordered_first
    overlap_ordered_2, total_ordered_2, segment_ordered_2 = ordered_second
    overlap_swapped, total_swapped, segment_swapped = swapped

    assert overlap_ordered_2 > overlap_ordered_1, (
        "Expected repeated identical order to increase overlap before swapped-order probe, "
        f"got req1={overlap_ordered_1}/{total_ordered_1}, req2={overlap_ordered_2}/{total_ordered_2}.\n"
        f"Recent router logs:\n{segment_ordered_2[-4000:]}"
    )
    assert abs(total_swapped - total_ordered_2) <= 4, (
        f"Expected swapped-order total blocks to stay near ordered baseline, "
        f"got swapped={total_swapped}, ordered={total_ordered_2}"
    )
    assert overlap_swapped <= 1, (
        "Expected near-zero overlap for swapped order of three distinct images "
        f"(allowing 1 shared text block), got {overlap_swapped}/{total_swapped}.\n"
        f"Recent router logs:\n{segment_swapped[-4000:]}"
    )
# Allow running this module directly: invoke pytest on this file with
# verbose reporting (-v) and output capture disabled (-s).
if __name__ == "__main__":
    pytest.main([__file__, "-v", "-s"])
Markdown is supported
0% Try again or attach a new file.
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment