Unverified commit 6634f33f, authored by zhongdaor-nv and committed by GitHub
Browse files

fix(perf): Skip duplicate image downloads and unnecessary image processing in...


fix(perf): Skip duplicate image downloads and unnecessary image processing in MM Router (vLLM) (#7080)
Signed-off-by: zhongdaor <zhongdaor@nvidia.com>
Signed-off-by: zhongdaor-nv <zhongdaor@nvidia.com>
Co-authored-by: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
parent ae770ad7
...@@ -17,6 +17,7 @@ import asyncio ...@@ -17,6 +17,7 @@ import asyncio
import base64 import base64
import binascii import binascii
import logging import logging
import os
from io import BytesIO from io import BytesIO
from typing import Any, Dict, Final, List, Optional from typing import Any, Dict, Final, List, Optional
from urllib.parse import urlparse from urllib.parse import urlparse
...@@ -38,7 +39,7 @@ DECODED_VARIANT_KEY: Final = "Decoded" ...@@ -38,7 +39,7 @@ DECODED_VARIANT_KEY: Final = "Decoded"
class ImageLoader: class ImageLoader:
CACHE_SIZE_MAXIMUM = 8 CACHE_SIZE_MAXIMUM = int(os.environ.get("DYN_MM_IMAGE_CACHE_SIZE", "8"))
def __init__( def __init__(
self, cache_size: int = CACHE_SIZE_MAXIMUM, http_timeout: float = 30.0 self, cache_size: int = CACHE_SIZE_MAXIMUM, http_timeout: float = 30.0
......
...@@ -130,7 +130,7 @@ Use the same environment in terminals 2/3/4/5: ...@@ -130,7 +130,7 @@ Use the same environment in terminals 2/3/4/5:
cd "$DYNAMO_ROOT" cd "$DYNAMO_ROOT"
export DYN_NAMESPACE=dynamo export DYN_NAMESPACE=dynamo
export DYN_REQUEST_PLANE=nats export DYN_REQUEST_PLANE=tcp
export NATS_SERVER=nats://127.0.0.1:4222 export NATS_SERVER=nats://127.0.0.1:4222
export ETCD_ENDPOINTS=http://127.0.0.1:2379 export ETCD_ENDPOINTS=http://127.0.0.1:2379
``` ```
...@@ -143,7 +143,7 @@ Use the same model string here and in the MM router. ...@@ -143,7 +143,7 @@ Use the same model string here and in the MM router.
cd "$DYNAMO_ROOT" cd "$DYNAMO_ROOT"
export DYN_NAMESPACE=dynamo export DYN_NAMESPACE=dynamo
export DYN_REQUEST_PLANE=nats export DYN_REQUEST_PLANE=tcp
export NATS_SERVER=nats://127.0.0.1:4222 export NATS_SERVER=nats://127.0.0.1:4222
export ETCD_ENDPOINTS=http://127.0.0.1:2379 export ETCD_ENDPOINTS=http://127.0.0.1:2379
export DYN_SYSTEM_PORT=18081 export DYN_SYSTEM_PORT=18081
...@@ -172,7 +172,7 @@ worker again for a repeated multimodal request. ...@@ -172,7 +172,7 @@ worker again for a repeated multimodal request.
cd "$DYNAMO_ROOT" cd "$DYNAMO_ROOT"
export DYN_NAMESPACE=dynamo export DYN_NAMESPACE=dynamo
export DYN_REQUEST_PLANE=nats export DYN_REQUEST_PLANE=tcp
export NATS_SERVER=nats://127.0.0.1:4222 export NATS_SERVER=nats://127.0.0.1:4222
export ETCD_ENDPOINTS=http://127.0.0.1:2379 export ETCD_ENDPOINTS=http://127.0.0.1:2379
export DYN_SYSTEM_PORT=18083 export DYN_SYSTEM_PORT=18083
...@@ -204,12 +204,12 @@ Important: ...@@ -204,12 +204,12 @@ Important:
- If you customize backend/MM router component names, update the MM router CLI args to match. - If you customize backend/MM router component names, update the MM router CLI args to match.
- `--block-size` defaults to `16`; if your vLLM backend uses a different KV cache block size, - `--block-size` defaults to `16`; if your vLLM backend uses a different KV cache block size,
pass the same value to the MM router. pass the same value to the MM router.
```bash ```bash
cd "$DYNAMO_ROOT" cd "$DYNAMO_ROOT"
export DYN_NAMESPACE=dynamo export DYN_NAMESPACE=dynamo
export DYN_REQUEST_PLANE=nats export DYN_REQUEST_PLANE=tcp
export NATS_SERVER=nats://127.0.0.1:4222 export NATS_SERVER=nats://127.0.0.1:4222
export ETCD_ENDPOINTS=http://127.0.0.1:2379 export ETCD_ENDPOINTS=http://127.0.0.1:2379
export DYN_LOG=debug export DYN_LOG=debug
...@@ -226,7 +226,7 @@ python -m examples.backends.vllm.mm_router_worker \ ...@@ -226,7 +226,7 @@ python -m examples.backends.vllm.mm_router_worker \
cd "$DYNAMO_ROOT" cd "$DYNAMO_ROOT"
export DYN_NAMESPACE=dynamo export DYN_NAMESPACE=dynamo
export DYN_REQUEST_PLANE=nats export DYN_REQUEST_PLANE=tcp
export NATS_SERVER=nats://127.0.0.1:4222 export NATS_SERVER=nats://127.0.0.1:4222
export ETCD_ENDPOINTS=http://127.0.0.1:2379 export ETCD_ENDPOINTS=http://127.0.0.1:2379
......
# SPDX-FileCopyrightText: Copyright (c) 2025-2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. # SPDX-FileCopyrightText: Copyright (c) 2025-2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0 # SPDX-License-Identifier: Apache-2.0
""" """MM Router Handler — routes multimodal requests via KV-cache-aware worker selection."""
MM Router Handler - Routes requests to best vLLM worker based on KV cache overlap.
"""
import logging import logging
from typing import Any, AsyncGenerator from typing import Any, AsyncGenerator
from dynamo.common.multimodal.image_loader import ImageLoader
from dynamo.llm import KvRouter from dynamo.llm import KvRouter
from dynamo.runtime.logging import configure_dynamo_logging from dynamo.runtime.logging import configure_dynamo_logging
...@@ -18,10 +17,7 @@ logger = logging.getLogger(__name__) ...@@ -18,10 +17,7 @@ logger = logging.getLogger(__name__)
class MMRouterHandler: class MMRouterHandler:
""" """Routes requests to the vLLM worker with the best KV cache overlap."""
Handler that computes mm_hash for multimodal requests and routes
to the best vLLM worker based on KV cache overlap.
"""
def __init__( def __init__(
self, self,
...@@ -31,112 +27,31 @@ class MMRouterHandler: ...@@ -31,112 +27,31 @@ class MMRouterHandler:
model: str, model: str,
block_size: int, block_size: int,
): ):
"""
Initialize the MM Router Handler.
Args:
kv_router: KvRouter for KV-aware worker selection and routing
tokenizer: HuggingFace AutoTokenizer
processor: HuggingFace AutoProcessor (optional)
model: Model path/name
block_size: KV cache block size
"""
self.kv_router = kv_router self.kv_router = kv_router
self.tokenizer = tokenizer self.tokenizer = tokenizer
self.processor = processor self.processor = processor
self.model = model self.model = model
self.block_size = block_size self.block_size = block_size
self._image_loader = ImageLoader()
async def generate(self, request: dict) -> AsyncGenerator[dict, None]: async def generate(self, request: dict) -> AsyncGenerator[dict, None]:
""" """Main entry point: process request, compute routing, forward to best worker."""
Main entry point - receives request, computes routing, forwards to best worker.
The request format (after Frontend preprocessing with ModelInput.Tokens):
{
"token_ids": [...],
"sampling_options": {...},
"stop_conditions": {...},
"extra_args": {"messages": [...]}
}
Args:
request: Preprocessed request from Frontend
Yields:
Response chunks from the downstream vLLM worker
"""
# Extract messages from extra_args (set by Frontend preprocessor)
messages = request.get("extra_args", {}).get("messages", []) messages = request.get("extra_args", {}).get("messages", [])
image_urls = extract_image_urls(messages) image_urls = extract_image_urls(messages)
if image_urls: if image_urls:
# Process multimodal: download images, compute mm_hash routing_tokens, block_mm_infos = await self._process_mm_request(
# Do not reuse request["token_ids"] for MM routing: those are placeholder-level request, messages, image_urls
# tokens from frontend. We need processor-expanded tokens to build block_mm_infos.
# Request payload does not include a rendered template string; extra_args carries
# original messages, so mm_processor reapplies chat template locally.
processed = process_multimodal(
messages=messages,
image_urls=image_urls,
tokenizer=self.tokenizer,
processor=self.processor,
model=self.model,
)
# Build block_mm_infos for MM-aware hash computation
block_mm_infos = build_block_mm_infos(
num_tokens=len(processed.tokens),
block_size=self.block_size,
mm_hashes=processed.mm_hashes,
image_ranges=processed.image_ranges,
)
if block_mm_infos is None:
raise ValueError(
"Failed to build block_mm_infos for multimodal request"
)
routing_tokens = processed.tokens
routing_blocks = (
len(routing_tokens) + self.block_size - 1
) // self.block_size
logger.debug(
f"MM request: {len(routing_tokens)} routing tokens, "
f"{len(image_urls)} images, {routing_blocks} routing blocks"
) )
else: else:
# Text-only: rely on frontend-preprocessed token_ids (ModelInput.Tokens contract) routing_tokens = request.get("token_ids")
tokens = request.get("token_ids") if not routing_tokens:
if not tokens: raise ValueError("Missing token_ids in preprocessed request")
raise ValueError( n_blocks = (len(routing_tokens) + self.block_size - 1) // self.block_size
"Missing or empty token_ids in preprocessed request for text-only routing" block_mm_infos = [None] * n_blocks
)
routing_tokens = tokens
routing_blocks = (
len(routing_tokens) + self.block_size - 1
) // self.block_size
logger.debug(
f"Text request: {len(routing_tokens)} routing tokens, {routing_blocks} routing blocks"
)
# Text-only routing has no multimodal objects; provide per-block None entries.
block_mm_infos = [None] * routing_blocks
# Route and generate through KvRouter with explicit fields.
# We pass:
# - execution payload (token_ids + multi_modal_data)
# - routing payload (mm_routing_info: routing_token_ids + block_mm_infos)
# so generate() can select worker internally while preserving MM correctness.
token_ids = request.get("token_ids")
if not token_ids:
raise ValueError("Missing or empty token_ids in preprocessed request")
mm_routing_info: dict[str, Any] = {
"routing_token_ids": routing_tokens,
"block_mm_infos": block_mm_infos,
}
stream = await self.kv_router.generate( stream = await self.kv_router.generate(
token_ids=token_ids, token_ids=request.get("token_ids"),
model=request["model"], model=request["model"],
stop_conditions=request.get("stop_conditions"), stop_conditions=request.get("stop_conditions"),
sampling_options=request.get("sampling_options"), sampling_options=request.get("sampling_options"),
...@@ -144,8 +59,52 @@ class MMRouterHandler: ...@@ -144,8 +59,52 @@ class MMRouterHandler:
router_config_override=request.get("router_config_override"), router_config_override=request.get("router_config_override"),
extra_args=request.get("extra_args"), extra_args=request.get("extra_args"),
multi_modal_data=request.get("multi_modal_data"), multi_modal_data=request.get("multi_modal_data"),
mm_routing_info=mm_routing_info, mm_routing_info={
"routing_token_ids": routing_tokens,
"block_mm_infos": block_mm_infos,
},
) )
async for response in stream: async for response in stream:
yield response yield response
async def _process_mm_request(
    self,
    request: dict,
    messages: list[dict],
    image_urls: list[str],
) -> tuple[list[int], list[dict | None]]:
    """Expand a multimodal request into routing tokens and per-block MM info.

    Awaits ``process_multimodal`` to load the images and produce the expanded
    tokens, then trims the outgoing payload in place before building the
    per-block multimodal info.

    Args:
        request: Preprocessed request. The ``image_url`` items of its
            ``multi_modal_data`` entry are mutated in place (``Url`` keys are
            renamed to ``RawUrl``).
        messages: Chat messages. The URL of every ``image_url`` content part
            is overwritten in place with the ``"<stripped>"`` placeholder.
        image_urls: Image URLs to load for this request.

    Returns:
        A ``(routing_tokens, block_mm_infos)`` tuple.

    Raises:
        ValueError: If ``build_block_mm_infos`` returns ``None``.
    """
    processed = await process_multimodal(
        messages=messages,
        image_urls=image_urls,
        tokenizer=self.tokenizer,
        processor=self.processor,
        model=self.model,
        image_loader=self._image_loader,
    )

    # Strip image content from messages to reduce serialization payload.
    # Malformed parts (non-dict entries, or an "image_url" that is missing or
    # not a dict) are skipped instead of raising after the images were
    # already processed.
    for msg in messages:
        content = msg.get("content", [])
        if not isinstance(content, list):
            continue
        for part in content:
            if not isinstance(part, dict) or part.get("type") != "image_url":
                continue
            image_ref = part.get("image_url")
            if isinstance(image_ref, dict):
                image_ref["url"] = "<stripped>"

    # Rewrite Url → RawUrl to skip url::Url::parse in Rust depythonize.
    mm_data = request.get("multi_modal_data", {})
    if isinstance(mm_data, dict):
        # `or []` tolerates an explicit None stored under "image_url".
        for item in mm_data.get("image_url") or []:
            if isinstance(item, dict) and "Url" in item:
                item["RawUrl"] = item.pop("Url")

    block_mm_infos = build_block_mm_infos(
        num_tokens=len(processed.tokens),
        block_size=self.block_size,
        mm_hashes=processed.mm_hashes,
        image_ranges=processed.image_ranges,
    )
    if block_mm_infos is None:
        raise ValueError("Failed to build block_mm_infos")

    return processed.tokens, block_mm_infos
...@@ -105,7 +105,7 @@ echo ...@@ -105,7 +105,7 @@ echo
COMMON_ENV=( COMMON_ENV=(
"DYN_NAMESPACE=${NAMESPACE}" "DYN_NAMESPACE=${NAMESPACE}"
"DYN_REQUEST_PLANE=nats" "DYN_REQUEST_PLANE=tcp"
"NATS_SERVER=${NATS_SERVER}" "NATS_SERVER=${NATS_SERVER}"
"ETCD_ENDPOINTS=${ETCD_ENDPOINTS}" "ETCD_ENDPOINTS=${ETCD_ENDPOINTS}"
) )
......
...@@ -4,22 +4,19 @@ ...@@ -4,22 +4,19 @@
""" """
Multimodal processing utilities for vLLM MM Router Worker. Multimodal processing utilities for vLLM MM Router Worker.
Key differences from TRT-LLM version: Key differences from the TRT-LLM version:
- Image loading: PIL + requests/base64 (no TRT-LLM dependency) - mm_hash uses PIL image bytes to match the vLLM backend's multi_modal_uuids.
- mm_hash: SHA256 of normalized PNG bytes (matches vLLM multi_modal_uuids) - Token replacement is not needed — vLLM keeps the original image_token_id.
- Token replacement: NOT needed — vLLM keeps the original image_token_id as-is - Fast path token expansion computes token counts from image dimensions directly.
""" """
import base64
import logging import logging
from dataclasses import dataclass from dataclasses import dataclass
from io import BytesIO from typing import Any, Sequence
from typing import Any
from urllib.parse import urlparse
import requests
from PIL import Image from PIL import Image
from dynamo.common.multimodal.image_loader import ImageLoader
from dynamo.vllm.multimodal_utils.hash_utils import compute_mm_uuids_from_images from dynamo.vllm.multimodal_utils.hash_utils import compute_mm_uuids_from_images
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
...@@ -58,42 +55,32 @@ def extract_image_urls(messages: list[dict]) -> list[str]: ...@@ -58,42 +55,32 @@ def extract_image_urls(messages: list[dict]) -> list[str]:
return urls return urls
def process_multimodal( async def process_multimodal(
messages: list[dict], messages: list[dict],
image_urls: list[str], image_urls: list[str],
tokenizer: Any, tokenizer: Any,
processor: Any, processor: Any,
model: str, model: str,
image_loader: ImageLoader,
) -> ProcessedInput: ) -> ProcessedInput:
""" """Process multimodal request: load images, get expanded tokens and mm_hashes.
Process multimodal request: load images, get expanded tokens and mm_hashes.
Uses PIL for image loading and hashlib for mm_hash computation. Uses the shared ImageLoader for async loading with HTTP cache.
Unlike TRT-LLM, vLLM keeps original image_token_id (no replacement). Hashes PIL images to natively match the vLLM backend's multi_modal_uuids.
""" """
# The preprocessed request does not carry a rendered template string; it carries prompt = _apply_chat_template(messages, tokenizer, processor)
# original messages in extra_args, so we must apply chat template again here.
prompt = _build_prompt_with_images(messages, tokenizer, processor) image_mm_items = [{"Url": url} for url in image_urls]
logger.info(f"Prompt (first 300 chars): {prompt[:300]}") pil_images = await image_loader.load_image_batch(image_mm_items)
image_dims = [(img.width, img.height) for img in pil_images]
# Load images as PIL
pil_images = []
for url in image_urls:
pil_img = _load_image(url)
pil_images.append(pil_img)
# Get expanded tokens and image ranges (no token replacement for vLLM)
tokens, image_ranges = _get_expanded_tokens( tokens, image_ranges = _get_expanded_tokens(
prompt, pil_images, tokenizer, processor prompt, image_dims, pil_images, tokenizer, processor
) )
logger.info(f"Expanded: {len(tokens)} tokens, " f"image_ranges={image_ranges}")
# Compute mm_hashes exactly like vLLM handler's multi_modal_uuids path.
mm_uuids = compute_mm_uuids_from_images(pil_images) mm_uuids = compute_mm_uuids_from_images(pil_images)
mm_hashes = [int(uuid[:16], 16) for uuid in mm_uuids] mm_hashes = [int(uuid[:16], 16) for uuid in mm_uuids]
logger.info(f"mm_hashes={mm_hashes}")
return ProcessedInput(tokens=tokens, mm_hashes=mm_hashes, image_ranges=image_ranges) return ProcessedInput(tokens=tokens, mm_hashes=mm_hashes, image_ranges=image_ranges)
...@@ -136,202 +123,138 @@ def build_block_mm_infos( ...@@ -136,202 +123,138 @@ def build_block_mm_infos(
# ============================================================================= # =============================================================================
# Internal functions # Token expansion: fast path (dimensions) -> slow path (HF processor)
# ============================================================================= # =============================================================================
def _build_prompt_with_images( def _apply_chat_template(messages: list[dict], tokenizer: Any, processor: Any) -> str:
messages: list[dict], tokenizer: Any, processor: Any """Re-apply chat template for routing token expansion.
) -> str:
"""
Build a prompt that includes image placeholders using the tokenizer's
chat template. This is critical for Qwen2-VL/Qwen2.5-VL models which
need <|vision_start|><|image_pad|>...<|vision_end|> in the prompt for
the processor to expand image tokens correctly.
Raises if chat template cannot be applied. For MM routing correctness, we do Cannot reuse Frontend's token_ids because the Frontend tokenizer may lack
not silently fall back to text-only prompts. vision-specific markers (e.g. <|vision_start|><|image_pad|><|vision_end|>
for Qwen). The processor's template produces the correct placeholder
structure needed for image token expansion and block_mm_infos.
""" """
# Try processor first (has the best chat template for multimodal) for obj in (processor, tokenizer):
if processor is not None and hasattr(processor, "apply_chat_template"): if obj is not None and hasattr(obj, "apply_chat_template"):
return processor.apply_chat_template( return obj.apply_chat_template(
messages, tokenize=False, add_generation_prompt=True messages, tokenize=False, add_generation_prompt=True
) )
# Fall back to tokenizer if available
if hasattr(tokenizer, "apply_chat_template"):
return tokenizer.apply_chat_template(
messages, tokenize=False, add_generation_prompt=True
)
raise ValueError("Neither processor nor tokenizer provides apply_chat_template") raise ValueError("Neither processor nor tokenizer provides apply_chat_template")
def _load_image(url: str) -> Image.Image:
"""
Load an image from URL (http/https or data URI) and return a PIL RGB image.
"""
parsed = urlparse(url)
if parsed.scheme == "data":
# data:image/png;base64,<data>
_, data = parsed.path.split(",", 1)
raw_bytes = base64.b64decode(data)
elif parsed.scheme in ("http", "https"):
response = requests.get(url, timeout=30)
response.raise_for_status()
raw_bytes = response.content
else:
raise ValueError(f"Unsupported URL scheme: {parsed.scheme}")
return Image.open(BytesIO(raw_bytes)).convert("RGB")
def _get_expanded_tokens( def _get_expanded_tokens(
prompt: str, prompt: str,
image_dims: list[tuple[int, int]],
pil_images: list[Image.Image], pil_images: list[Image.Image],
tokenizer: Any, tokenizer: Any,
processor: Any, processor: Any,
) -> tuple[list[int], list[tuple[int, int]] | None]: ) -> tuple[list[int], list[tuple[int, int]] | None]:
""" """Expand image placeholder tokens. Fast path from dims, slow path via processor."""
Get tokens with visual expansion and find each image's token range.
Unlike TRT-LLM, vLLM keeps the original image_token_id (no replacement).
"""
if processor is None: if processor is None:
return tokenizer.encode(prompt), None return tokenizer.encode(prompt), None
try: try:
output = processor( return _expand_from_dims(prompt, image_dims, tokenizer, processor)
text=[prompt], images=pil_images, return_tensors="pt", padding=True except Exception as e:
) logger.info("Fast path failed (%s), falling back to processor", e)
tokens = output["input_ids"][0].tolist()
# Get image_token_id from processor
image_token_id = getattr(processor, "image_token_id", None)
if image_token_id is None:
raise ValueError("processor.image_token_id not found")
# Find contiguous image token ranges (NO replacement for vLLM)
contiguous_ranges = _find_image_token_ranges(tokens, image_token_id)
# Compute tokens per image from processor output
tokens_per_image = _compute_tokens_per_image(output, processor)
# Split ranges according to tokens_per_image
image_ranges = _compute_per_image_ranges(contiguous_ranges, tokens_per_image)
return tokens, image_ranges
try:
return _expand_with_processor(prompt, pil_images, tokenizer, processor)
except Exception as e: except Exception as e:
logger.warning(f"HF processor failed: {e}", exc_info=True) logger.warning("Slow path also failed: %s", e, exc_info=True)
return tokenizer.encode(prompt), None return tokenizer.encode(prompt), None
def _compute_tokens_per_image(processor_output: dict, processor: Any) -> list[int]: # -- Fast path --
"""
Compute the number of visual tokens for each image from processor output.
Only Qwen-style processors (Qwen2-VL, Qwen2.5-VL) are supported.
Other model families will raise ValueError. def _expand_from_dims(
""" prompt: str,
processor_cls = type(processor).__qualname__ image_dims: list[tuple[int, int]],
if "qwen" not in processor_cls.lower(): tokenizer: Any,
raise NotImplementedError( processor: Any,
f"_compute_tokens_per_image only supports Qwen-style processors " ) -> tuple[list[int], list[tuple[int, int]]]:
f"tuples. Got processor class: {processor_cls}" """Expand placeholders using dimension-based token counts (Qwen-style)."""
image_processor = processor.image_processor
get_num_patches = image_processor.get_number_of_image_patches
merge_size = image_processor.merge_size
image_token_id = processor.image_token_id
tokens_per_image = []
for w, h in image_dims:
n_patches: int = int(get_num_patches(h, w, {})) # type: ignore[arg-type]
tokens_per_image.append(n_patches // (merge_size**2))
base_tokens = tokenizer.encode(prompt)
placeholders = [i for i, t in enumerate(base_tokens) if t == image_token_id]
if len(placeholders) != len(image_dims):
raise ValueError(
f"Placeholder count ({len(placeholders)}) != image count ({len(image_dims)})"
) )
grid_thw = processor_output.get("image_grid_thw") expanded: list[int] = []
if grid_thw is None: ranges: list[tuple[int, int]] = []
raise ValueError("image_grid_thw not found in processor output") prev = 0
for idx, pos in enumerate(placeholders):
expanded.extend(base_tokens[prev:pos])
start = len(expanded)
n = tokens_per_image[idx]
expanded.extend([image_token_id] * n)
ranges.append((start, start + n))
prev = pos + 1
expanded.extend(base_tokens[prev:])
return expanded, ranges
merge_size = getattr(processor.image_processor, "merge_size", 2)
return [int(t * h * w) // (merge_size**2) for t, h, w in grid_thw]
# -- Slow path --
def _find_image_token_ranges(
tokens: list[int], image_token_id: int
) -> list[tuple[int, int]]:
"""
Find all contiguous ranges of image tokens.
Unlike the TRT-LLM version, this does NOT replace tokens — vLLM keeps def _expand_with_processor(
the original image_token_id as-is in KV events. prompt: str,
pil_images: Sequence[Image.Image],
tokenizer: Any,
processor: Any,
) -> tuple[list[int], list[tuple[int, int]] | None]:
"""Expand using full HF processor (works for any model, ~55ms)."""
output = processor(
text=[prompt], images=pil_images, return_tensors="pt", padding=True
)
tokens = output["input_ids"][0].tolist()
image_token_id = getattr(processor, "image_token_id", None)
if image_token_id is None:
return tokens, None
Returns: list of (start, end) ranges for contiguous image token regions. merge_size = getattr(processor.image_processor, "merge_size", 2)
""" grid_thw = output.get("image_grid_thw")
ranges = [] if grid_thw is None:
start = None return tokens, None
tokens_per_image = [int(t * h * w) // (merge_size**2) for t, h, w in grid_thw]
contiguous: list[tuple[int, int]] = []
run_start = None
for i, t in enumerate(tokens): for i, t in enumerate(tokens):
if t == image_token_id: if t == image_token_id:
if start is None: if run_start is None:
start = i run_start = i
elif start is not None: elif run_start is not None:
ranges.append((start, i)) contiguous.append((run_start, i))
start = None run_start = None
if run_start is not None:
if start is not None: contiguous.append((run_start, len(tokens)))
ranges.append((start, len(tokens)))
result: list[tuple[int, int]] = []
if ranges: img_idx = 0
logger.info( for rng_start, rng_end in contiguous:
f"Found {sum(e - s for s, e in ranges)} image tokens " pos = rng_start
f"(id={image_token_id}) in {len(ranges)} range(s)" while img_idx < len(tokens_per_image):
) needed = tokens_per_image[img_idx]
if pos + needed <= rng_end:
return ranges
def _compute_per_image_ranges(
contiguous_ranges: list[tuple[int, int]],
tokens_per_image: list[int],
) -> list[tuple[int, int]] | None:
"""
Split contiguous image token ranges by each image's token count.
Example: contiguous_ranges=[(0, 100)], tokens_per_image=[60, 40]
Returns: [(0, 60), (60, 100)] # image 1 at 0-60, image 2 at 60-100
"""
if not contiguous_ranges:
if tokens_per_image:
logger.warning(
f"No image tokens found but {len(tokens_per_image)} images expected"
)
return None
# Greedily assign images to ranges in order
result = []
image_idx = 0
for range_start, range_end in contiguous_ranges:
range_size = range_end - range_start
pos = range_start
consumed = 0
# Consume images that fit entirely in this range
# (a single image's tokens are always contiguous, cannot span ranges)
while image_idx < len(tokens_per_image):
needed = tokens_per_image[image_idx]
if consumed + needed <= range_size:
result.append((pos, pos + needed)) result.append((pos, pos + needed))
pos += needed pos += needed
consumed += needed img_idx += 1
image_idx += 1
else: else:
break break
return tokens, result if len(result) == len(tokens_per_image) else None
# Range must be exactly filled (no leftover image tokens)
if consumed != range_size:
logger.warning(
f"Range size mismatch: consumed {consumed} != range {range_size}"
)
return None
# All images must be consumed
if image_idx != len(tokens_per_image):
logger.warning(f"Not all images mapped: {image_idx} < {len(tokens_per_image)}")
return None
return result
...@@ -104,6 +104,8 @@ pub struct MmRoutingInfo { ...@@ -104,6 +104,8 @@ pub struct MmRoutingInfo {
#[derive(Serialize, Deserialize, Debug, Clone)] #[derive(Serialize, Deserialize, Debug, Clone)]
pub enum MultimodalData { pub enum MultimodalData {
Url(url::Url), Url(url::Url),
#[serde(rename(serialize = "Url"))]
RawUrl(String),
Decoded(RdmaMediaDataDescriptor), Decoded(RdmaMediaDataDescriptor),
} }
......
...@@ -18,7 +18,9 @@ import base64 ...@@ -18,7 +18,9 @@ import base64
import os import os
import re import re
import shutil import shutil
import threading
import time import time
from http.server import BaseHTTPRequestHandler, HTTPServer
from io import BytesIO from io import BytesIO
from typing import Any, Generator from typing import Any, Generator
...@@ -58,6 +60,7 @@ _SINGLE_IMAGE_FRESH_COLOR = (123, 45, 67) ...@@ -58,6 +60,7 @@ _SINGLE_IMAGE_FRESH_COLOR = (123, 45, 67)
_DOUBLE_IMAGE_FRESH_COLOR = (89, 210, 34) _DOUBLE_IMAGE_FRESH_COLOR = (89, 210, 34)
_STAIRCASE_IMAGE_FRESH_COLOR = (17, 99, 201) _STAIRCASE_IMAGE_FRESH_COLOR = (17, 99, 201)
_SWAP_ORDER_FRESH_COLORS = [(14, 141, 77), (211, 66, 101), (44, 91, 233)] _SWAP_ORDER_FRESH_COLORS = [(14, 141, 77), (211, 66, 101), (44, 91, 233)]
_HTTP_IMAGE_COLORS = [(180, 30, 90), (30, 180, 90), (90, 30, 180)]
# Contract with lib/llm/src/kv_router/push_router.rs "[ROUTING]" debug log. # Contract with lib/llm/src/kv_router/push_router.rs "[ROUTING]" debug log.
# Keep this parser in sync with the router log format. # Keep this parser in sync with the router log format.
_ROUTING_RECORD_PATTERN = re.compile( _ROUTING_RECORD_PATTERN = re.compile(
...@@ -76,7 +79,7 @@ def _make_process_env(log_level: str = "debug", **extra) -> dict[str, str]: ...@@ -76,7 +79,7 @@ def _make_process_env(log_level: str = "debug", **extra) -> dict[str, str]:
env = os.environ.copy() env = os.environ.copy()
env["DYN_LOG"] = log_level env["DYN_LOG"] = log_level
env["DYN_NAMESPACE"] = NAMESPACE env["DYN_NAMESPACE"] = NAMESPACE
env["DYN_REQUEST_PLANE"] = "nats" env["DYN_REQUEST_PLANE"] = "tcp"
env.update(extra) env.update(extra)
return env return env
...@@ -210,13 +213,17 @@ def start_vllm_mm_services( ...@@ -210,13 +213,17 @@ def start_vllm_mm_services(
yield frontend_port, router_proc yield frontend_port, router_proc
def _make_data_uri(color: tuple[int, int, int], size: int = 1024) -> str: def _make_png_bytes(color: tuple[int, int, int], size: int = 1024) -> bytes:
from PIL import Image from PIL import Image
img = Image.new("RGB", (size, size), color) img = Image.new("RGB", (size, size), color)
buf = BytesIO() buf = BytesIO()
img.save(buf, format="PNG") img.save(buf, format="PNG")
b64 = base64.b64encode(buf.getvalue()).decode("utf-8") return buf.getvalue()
def _make_data_uri(color: tuple[int, int, int], size: int = 1024) -> str:
b64 = base64.b64encode(_make_png_bytes(color, size)).decode("utf-8")
return f"data:image/png;base64,{b64}" return f"data:image/png;base64,{b64}"
...@@ -293,6 +300,7 @@ def _send_request_get_overlap( ...@@ -293,6 +300,7 @@ def _send_request_get_overlap(
timeout_s=120, timeout_s=120,
) )
print(f"[MM_ROUTER_E2E] {label}: current={overlap}/{total}") print(f"[MM_ROUTER_E2E] {label}: current={overlap}/{total}")
time.sleep(1)
return overlap, total, segment return overlap, total, segment
...@@ -684,5 +692,139 @@ def test_vllm_mm_overlap_swapped_order_less_than_same_order( ...@@ -684,5 +692,139 @@ def test_vllm_mm_overlap_swapped_order_less_than_same_order(
) )
def _make_image_handler(image_map: dict[str, bytes]) -> type:
"""Create an HTTP handler class that serves images from the given map."""
class _ImageHandler(BaseHTTPRequestHandler):
def do_GET(self):
data = image_map.get(self.path)
if data is None:
self.send_error(404)
return
self.send_response(200)
self.send_header("Content-Type", "image/png")
self.send_header("Content-Length", str(len(data)))
self.end_headers()
self.wfile.write(data)
def log_message(self, format, *args):
pass # suppress noisy request logs
return _ImageHandler
@pytest.fixture(scope="module")
def http_image_server() -> Generator[list[str], None, None]:
    """Module-scoped fixture: run a local HTTP server that serves generated PNGs.

    Yields one URL per entry in ``_HTTP_IMAGE_COLORS``. After the yield the
    server is shut down and its serving thread is joined.
    """
    (port,) = allocate_ports(count=1, start_port=18000)

    # One solid-color PNG per configured color, keyed by its request path.
    image_map: dict[str, bytes] = {
        f"/image_{i}.png": _make_png_bytes(color)
        for i, color in enumerate(_HTTP_IMAGE_COLORS)
    }

    handler_cls = _make_image_handler(image_map)
    server = HTTPServer(("127.0.0.1", port), handler_cls)
    # Daemon thread so a stuck server cannot keep the test process alive.
    worker = threading.Thread(target=server.serve_forever, daemon=True)
    worker.start()

    yield [
        f"http://127.0.0.1:{port}/image_{i}.png" for i in range(len(_HTTP_IMAGE_COLORS))
    ]

    server.shutdown()
    server.server_close()
    worker.join(timeout=5)
@pytest.mark.timeout(1800)
@pytest.mark.nightly
def test_vllm_mm_overlap_repeated_http_images(
    start_vllm_mm_services, predownload_models, http_image_server
):
    """Send the same 3-HTTP-image request three times and check the overlap trend.

    Expected: first overlap is at most 1, the second is strictly higher, and
    the third equals the second; the final total block count must fall inside
    ``THREE_IMAGE_TOTAL_BLOCKS_RANGE``.
    """
    frontend_port, router_proc = start_vllm_mm_services
    payload = _build_payload(
        http_image_server, prompt="MM routing e2e: repeated same 3 HTTP images."
    )

    # Issue the identical request three times, pausing 1s between sends.
    labels = ("http_3_images_req1", "http_3_images_req2", "http_3_images_req3")
    outcomes = []
    for attempt, label in enumerate(labels):
        if attempt:
            time.sleep(1)
        outcomes.append(
            _send_request_get_overlap(frontend_port, router_proc, payload, label)
        )
    (
        (overlap_1, total_1, _),
        (overlap_2, total_2, _),
        (overlap_3, total_3, segment_3),
    ) = outcomes

    assert overlap_1 <= 1, (
        f"Expected first overlap <=1, got req1={overlap_1}/{total_1}.\n"
        f"Recent router logs:\n{segment_3[-4000:]}"
    )
    assert overlap_2 > overlap_1, (
        f"Expected second overlap > first, got req1={overlap_1}/{total_1}, req2={overlap_2}/{total_2}.\n"
        f"Recent router logs:\n{segment_3[-4000:]}"
    )
    assert overlap_3 == overlap_2, (
        f"Expected third overlap == second, got req2={overlap_2}/{total_2}, req3={overlap_3}/{total_3}.\n"
        f"Recent router logs:\n{segment_3[-4000:]}"
    )

    low, high = THREE_IMAGE_TOTAL_BLOCKS_RANGE
    assert low <= total_3 <= high, (
        f"Unexpected total blocks for same 3 HTTP images (1024): "
        f"got {total_3}, expected in [{low}, {high}]"
    )
@pytest.mark.timeout(1800)
@pytest.mark.nightly
def test_vllm_mm_overlap_http_vs_data_uri_same_image(
    start_vllm_mm_services, predownload_models, http_image_server
):
    """Check that an HTTP URL and a data URI for the same image overlap in the KV cache.

    A data-URI request seeds the cache; an HTTP request for the identical
    image (same prompt) must then report a higher overlap than the seed did,
    with total block counts within 2 of each other.
    """
    frontend_port, router_proc = start_vllm_mm_services
    shared_prompt = "MM routing e2e: HTTP vs data URI same image."

    # Both representations are built from the first HTTP image color.
    data_uri = _make_data_uri(_HTTP_IMAGE_COLORS[0])
    http_url = http_image_server[0]

    # Seed the KV cache with the data-URI form of the image.
    seed_payload = _build_payload([data_uri], prompt=shared_prompt)
    overlap_data, total_data, _ = _send_request_get_overlap(
        frontend_port, router_proc, seed_payload, "data_uri_seed"
    )

    time.sleep(1)

    # Probe with the HTTP-URL form of the identical image.
    probe_payload = _build_payload([http_url], prompt=shared_prompt)
    overlap_http, total_http, segment_http = _send_request_get_overlap(
        frontend_port, router_proc, probe_payload, "http_probe"
    )

    assert total_http > 0, (
        f"No routing score for HTTP request.\n"
        f"Recent router logs:\n{segment_http[-4000:]}"
    )
    assert abs(total_http - total_data) <= 2, (
        f"Expected HTTP and data URI total blocks to match, "
        f"got http={total_http}, data_uri={total_data}.\n"
        f"Recent router logs:\n{segment_http[-4000:]}"
    )
    assert overlap_http > overlap_data, (
        f"Expected HTTP probe overlap > data URI seed overlap "
        f"(proving image cache hit, not just text overlap), "
        f"got http={overlap_http}/{total_http}, data_uri={overlap_data}/{total_data}.\n"
        f"Recent router logs:\n{segment_http[-4000:]}"
    )
if __name__ == "__main__": if __name__ == "__main__":
pytest.main([__file__, "-v", "-s"]) pytest.main([__file__, "-v", "-s"])
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment