Unverified Commit 6634f33f authored by zhongdaor-nv's avatar zhongdaor-nv Committed by GitHub
Browse files

fix(perf): Skip duplicate image downloads and unnecessary image processing in...


fix(perf): Skip duplicate image downloads and unnecessary image processing in MM Router (vLLM) (#7080)
Signed-off-by: default avatarzhongdaor <zhongdaor@nvidia.com>
Signed-off-by: default avatarzhongdaor-nv <zhongdaor@nvidia.com>
Co-authored-by: default avatarClaude Opus 4.6 (1M context) <noreply@anthropic.com>
parent ae770ad7
......@@ -17,6 +17,7 @@ import asyncio
import base64
import binascii
import logging
import os
from io import BytesIO
from typing import Any, Dict, Final, List, Optional
from urllib.parse import urlparse
......@@ -38,7 +39,7 @@ DECODED_VARIANT_KEY: Final = "Decoded"
class ImageLoader:
CACHE_SIZE_MAXIMUM = 8
CACHE_SIZE_MAXIMUM = int(os.environ.get("DYN_MM_IMAGE_CACHE_SIZE", "8"))
def __init__(
self, cache_size: int = CACHE_SIZE_MAXIMUM, http_timeout: float = 30.0
......
......@@ -130,7 +130,7 @@ Use the same environment in terminals 2/3/4/5:
cd "$DYNAMO_ROOT"
export DYN_NAMESPACE=dynamo
export DYN_REQUEST_PLANE=nats
export DYN_REQUEST_PLANE=tcp
export NATS_SERVER=nats://127.0.0.1:4222
export ETCD_ENDPOINTS=http://127.0.0.1:2379
```
......@@ -143,7 +143,7 @@ Use the same model string here and in the MM router.
cd "$DYNAMO_ROOT"
export DYN_NAMESPACE=dynamo
export DYN_REQUEST_PLANE=nats
export DYN_REQUEST_PLANE=tcp
export NATS_SERVER=nats://127.0.0.1:4222
export ETCD_ENDPOINTS=http://127.0.0.1:2379
export DYN_SYSTEM_PORT=18081
......@@ -172,7 +172,7 @@ worker again for a repeated multimodal request.
cd "$DYNAMO_ROOT"
export DYN_NAMESPACE=dynamo
export DYN_REQUEST_PLANE=nats
export DYN_REQUEST_PLANE=tcp
export NATS_SERVER=nats://127.0.0.1:4222
export ETCD_ENDPOINTS=http://127.0.0.1:2379
export DYN_SYSTEM_PORT=18083
......@@ -204,12 +204,12 @@ Important:
- If you customize backend/MM router component names, update the MM router CLI args to match.
- `--block-size` defaults to `16`; if your vLLM backend uses a different KV cache block size,
pass the same value to the MM router.
```bash
cd "$DYNAMO_ROOT"
export DYN_NAMESPACE=dynamo
export DYN_REQUEST_PLANE=nats
export DYN_REQUEST_PLANE=tcp
export NATS_SERVER=nats://127.0.0.1:4222
export ETCD_ENDPOINTS=http://127.0.0.1:2379
export DYN_LOG=debug
......@@ -226,7 +226,7 @@ python -m examples.backends.vllm.mm_router_worker \
cd "$DYNAMO_ROOT"
export DYN_NAMESPACE=dynamo
export DYN_REQUEST_PLANE=nats
export DYN_REQUEST_PLANE=tcp
export NATS_SERVER=nats://127.0.0.1:4222
export ETCD_ENDPOINTS=http://127.0.0.1:2379
......
# SPDX-FileCopyrightText: Copyright (c) 2025-2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
"""
MM Router Handler - Routes requests to best vLLM worker based on KV cache overlap.
"""
"""MM Router Handler — routes multimodal requests via KV-cache-aware worker selection."""
import logging
from typing import Any, AsyncGenerator
from dynamo.common.multimodal.image_loader import ImageLoader
from dynamo.llm import KvRouter
from dynamo.runtime.logging import configure_dynamo_logging
......@@ -18,10 +17,7 @@ logger = logging.getLogger(__name__)
class MMRouterHandler:
"""
Handler that computes mm_hash for multimodal requests and routes
to the best vLLM worker based on KV cache overlap.
"""
"""Routes requests to the vLLM worker with the best KV cache overlap."""
def __init__(
self,
......@@ -31,112 +27,31 @@ class MMRouterHandler:
model: str,
block_size: int,
):
"""
Initialize the MM Router Handler.
Args:
kv_router: KvRouter for KV-aware worker selection and routing
tokenizer: HuggingFace AutoTokenizer
processor: HuggingFace AutoProcessor (optional)
model: Model path/name
block_size: KV cache block size
"""
self.kv_router = kv_router
self.tokenizer = tokenizer
self.processor = processor
self.model = model
self.block_size = block_size
self._image_loader = ImageLoader()
async def generate(self, request: dict) -> AsyncGenerator[dict, None]:
"""
Main entry point - receives request, computes routing, forwards to best worker.
The request format (after Frontend preprocessing with ModelInput.Tokens):
{
"token_ids": [...],
"sampling_options": {...},
"stop_conditions": {...},
"extra_args": {"messages": [...]}
}
Args:
request: Preprocessed request from Frontend
Yields:
Response chunks from the downstream vLLM worker
"""
# Extract messages from extra_args (set by Frontend preprocessor)
"""Main entry point: process request, compute routing, forward to best worker."""
messages = request.get("extra_args", {}).get("messages", [])
image_urls = extract_image_urls(messages)
if image_urls:
# Process multimodal: download images, compute mm_hash
# Do not reuse request["token_ids"] for MM routing: those are placeholder-level
# tokens from frontend. We need processor-expanded tokens to build block_mm_infos.
# Request payload does not include a rendered template string; extra_args carries
# original messages, so mm_processor reapplies chat template locally.
processed = process_multimodal(
messages=messages,
image_urls=image_urls,
tokenizer=self.tokenizer,
processor=self.processor,
model=self.model,
)
# Build block_mm_infos for MM-aware hash computation
block_mm_infos = build_block_mm_infos(
num_tokens=len(processed.tokens),
block_size=self.block_size,
mm_hashes=processed.mm_hashes,
image_ranges=processed.image_ranges,
)
if block_mm_infos is None:
raise ValueError(
"Failed to build block_mm_infos for multimodal request"
)
routing_tokens = processed.tokens
routing_blocks = (
len(routing_tokens) + self.block_size - 1
) // self.block_size
logger.debug(
f"MM request: {len(routing_tokens)} routing tokens, "
f"{len(image_urls)} images, {routing_blocks} routing blocks"
routing_tokens, block_mm_infos = await self._process_mm_request(
request, messages, image_urls
)
else:
# Text-only: rely on frontend-preprocessed token_ids (ModelInput.Tokens contract)
tokens = request.get("token_ids")
if not tokens:
raise ValueError(
"Missing or empty token_ids in preprocessed request for text-only routing"
)
routing_tokens = tokens
routing_blocks = (
len(routing_tokens) + self.block_size - 1
) // self.block_size
logger.debug(
f"Text request: {len(routing_tokens)} routing tokens, {routing_blocks} routing blocks"
)
# Text-only routing has no multimodal objects; provide per-block None entries.
block_mm_infos = [None] * routing_blocks
# Route and generate through KvRouter with explicit fields.
# We pass:
# - execution payload (token_ids + multi_modal_data)
# - routing payload (mm_routing_info: routing_token_ids + block_mm_infos)
# so generate() can select worker internally while preserving MM correctness.
token_ids = request.get("token_ids")
if not token_ids:
raise ValueError("Missing or empty token_ids in preprocessed request")
mm_routing_info: dict[str, Any] = {
"routing_token_ids": routing_tokens,
"block_mm_infos": block_mm_infos,
}
routing_tokens = request.get("token_ids")
if not routing_tokens:
raise ValueError("Missing token_ids in preprocessed request")
n_blocks = (len(routing_tokens) + self.block_size - 1) // self.block_size
block_mm_infos = [None] * n_blocks
stream = await self.kv_router.generate(
token_ids=token_ids,
token_ids=request.get("token_ids"),
model=request["model"],
stop_conditions=request.get("stop_conditions"),
sampling_options=request.get("sampling_options"),
......@@ -144,8 +59,52 @@ class MMRouterHandler:
router_config_override=request.get("router_config_override"),
extra_args=request.get("extra_args"),
multi_modal_data=request.get("multi_modal_data"),
mm_routing_info=mm_routing_info,
mm_routing_info={
"routing_token_ids": routing_tokens,
"block_mm_infos": block_mm_infos,
},
)
async for response in stream:
yield response
async def _process_mm_request(
self,
request: dict,
messages: list[dict],
image_urls: list[str],
) -> tuple[list[int], list[dict | None]]:
"""Process multimodal: load images, expand tokens, build routing info."""
processed = await process_multimodal(
messages=messages,
image_urls=image_urls,
tokenizer=self.tokenizer,
processor=self.processor,
model=self.model,
image_loader=self._image_loader,
)
# Strip image content from messages to reduce serialization payload
for msg in messages:
content = msg.get("content", [])
if isinstance(content, list):
for part in content:
if part.get("type") == "image_url":
part["image_url"]["url"] = "<stripped>"
# Rewrite Url → RawUrl to skip url::Url::parse in Rust depythonize
mm_data = request.get("multi_modal_data", {})
if isinstance(mm_data, dict):
for item in mm_data.get("image_url", []):
if isinstance(item, dict) and "Url" in item:
item["RawUrl"] = item.pop("Url")
block_mm_infos = build_block_mm_infos(
num_tokens=len(processed.tokens),
block_size=self.block_size,
mm_hashes=processed.mm_hashes,
image_ranges=processed.image_ranges,
)
if block_mm_infos is None:
raise ValueError("Failed to build block_mm_infos")
return processed.tokens, block_mm_infos
......@@ -105,7 +105,7 @@ echo
COMMON_ENV=(
"DYN_NAMESPACE=${NAMESPACE}"
"DYN_REQUEST_PLANE=nats"
"DYN_REQUEST_PLANE=tcp"
"NATS_SERVER=${NATS_SERVER}"
"ETCD_ENDPOINTS=${ETCD_ENDPOINTS}"
)
......
......@@ -4,22 +4,19 @@
"""
Multimodal processing utilities for vLLM MM Router Worker.
Key differences from TRT-LLM version:
- Image loading: PIL + requests/base64 (no TRT-LLM dependency)
- mm_hash: SHA256 of normalized PNG bytes (matches vLLM multi_modal_uuids)
- Token replacement: NOT needed — vLLM keeps the original image_token_id as-is
Key differences from the TRT-LLM version:
- mm_hash uses PIL image bytes to match the vLLM backend's multi_modal_uuids.
- Token replacement is not needed — vLLM keeps the original image_token_id.
- Fast path token expansion computes token counts from image dimensions directly.
"""
import base64
import logging
from dataclasses import dataclass
from io import BytesIO
from typing import Any
from urllib.parse import urlparse
from typing import Any, Sequence
import requests
from PIL import Image
from dynamo.common.multimodal.image_loader import ImageLoader
from dynamo.vllm.multimodal_utils.hash_utils import compute_mm_uuids_from_images
logger = logging.getLogger(__name__)
......@@ -58,42 +55,32 @@ def extract_image_urls(messages: list[dict]) -> list[str]:
return urls
def process_multimodal(
async def process_multimodal(
messages: list[dict],
image_urls: list[str],
tokenizer: Any,
processor: Any,
model: str,
image_loader: ImageLoader,
) -> ProcessedInput:
"""
Process multimodal request: load images, get expanded tokens and mm_hashes.
"""Process multimodal request: load images, get expanded tokens and mm_hashes.
Uses PIL for image loading and hashlib for mm_hash computation.
Unlike TRT-LLM, vLLM keeps original image_token_id (no replacement).
Uses the shared ImageLoader for async loading with HTTP cache.
Hashes PIL images to natively match the vLLM backend's multi_modal_uuids.
"""
# The preprocessed request does not carry a rendered template string; it carries
# original messages in extra_args, so we must apply chat template again here.
prompt = _build_prompt_with_images(messages, tokenizer, processor)
logger.info(f"Prompt (first 300 chars): {prompt[:300]}")
# Load images as PIL
pil_images = []
for url in image_urls:
pil_img = _load_image(url)
pil_images.append(pil_img)
# Get expanded tokens and image ranges (no token replacement for vLLM)
prompt = _apply_chat_template(messages, tokenizer, processor)
image_mm_items = [{"Url": url} for url in image_urls]
pil_images = await image_loader.load_image_batch(image_mm_items)
image_dims = [(img.width, img.height) for img in pil_images]
tokens, image_ranges = _get_expanded_tokens(
prompt, pil_images, tokenizer, processor
prompt, image_dims, pil_images, tokenizer, processor
)
logger.info(f"Expanded: {len(tokens)} tokens, " f"image_ranges={image_ranges}")
# Compute mm_hashes exactly like vLLM handler's multi_modal_uuids path.
mm_uuids = compute_mm_uuids_from_images(pil_images)
mm_hashes = [int(uuid[:16], 16) for uuid in mm_uuids]
logger.info(f"mm_hashes={mm_hashes}")
return ProcessedInput(tokens=tokens, mm_hashes=mm_hashes, image_ranges=image_ranges)
......@@ -136,202 +123,138 @@ def build_block_mm_infos(
# =============================================================================
# Internal functions
# Token expansion: fast path (dimensions) -> slow path (HF processor)
# =============================================================================
def _build_prompt_with_images(
messages: list[dict], tokenizer: Any, processor: Any
) -> str:
"""
Build a prompt that includes image placeholders using the tokenizer's
chat template. This is critical for Qwen2-VL/Qwen2.5-VL models which
need <|vision_start|><|image_pad|>...<|vision_end|> in the prompt for
the processor to expand image tokens correctly.
def _apply_chat_template(messages: list[dict], tokenizer: Any, processor: Any) -> str:
"""Re-apply chat template for routing token expansion.
Raises if chat template cannot be applied. For MM routing correctness, we do
not silently fall back to text-only prompts.
Cannot reuse Frontend's token_ids because the Frontend tokenizer may lack
vision-specific markers (e.g. <|vision_start|><|image_pad|><|vision_end|>
for Qwen). The processor's template produces the correct placeholder
structure needed for image token expansion and block_mm_infos.
"""
# Try processor first (has the best chat template for multimodal)
if processor is not None and hasattr(processor, "apply_chat_template"):
return processor.apply_chat_template(
messages, tokenize=False, add_generation_prompt=True
)
# Fall back to tokenizer if available
if hasattr(tokenizer, "apply_chat_template"):
return tokenizer.apply_chat_template(
messages, tokenize=False, add_generation_prompt=True
)
for obj in (processor, tokenizer):
if obj is not None and hasattr(obj, "apply_chat_template"):
return obj.apply_chat_template(
messages, tokenize=False, add_generation_prompt=True
)
raise ValueError("Neither processor nor tokenizer provides apply_chat_template")
def _load_image(url: str) -> Image.Image:
"""
Load an image from URL (http/https or data URI) and return a PIL RGB image.
"""
parsed = urlparse(url)
if parsed.scheme == "data":
# data:image/png;base64,<data>
_, data = parsed.path.split(",", 1)
raw_bytes = base64.b64decode(data)
elif parsed.scheme in ("http", "https"):
response = requests.get(url, timeout=30)
response.raise_for_status()
raw_bytes = response.content
else:
raise ValueError(f"Unsupported URL scheme: {parsed.scheme}")
return Image.open(BytesIO(raw_bytes)).convert("RGB")
def _get_expanded_tokens(
prompt: str,
image_dims: list[tuple[int, int]],
pil_images: list[Image.Image],
tokenizer: Any,
processor: Any,
) -> tuple[list[int], list[tuple[int, int]] | None]:
"""
Get tokens with visual expansion and find each image's token range.
Unlike TRT-LLM, vLLM keeps the original image_token_id (no replacement).
"""
"""Expand image placeholder tokens. Fast path from dims, slow path via processor."""
if processor is None:
return tokenizer.encode(prompt), None
try:
output = processor(
text=[prompt], images=pil_images, return_tensors="pt", padding=True
)
tokens = output["input_ids"][0].tolist()
# Get image_token_id from processor
image_token_id = getattr(processor, "image_token_id", None)
if image_token_id is None:
raise ValueError("processor.image_token_id not found")
# Find contiguous image token ranges (NO replacement for vLLM)
contiguous_ranges = _find_image_token_ranges(tokens, image_token_id)
# Compute tokens per image from processor output
tokens_per_image = _compute_tokens_per_image(output, processor)
# Split ranges according to tokens_per_image
image_ranges = _compute_per_image_ranges(contiguous_ranges, tokens_per_image)
return tokens, image_ranges
return _expand_from_dims(prompt, image_dims, tokenizer, processor)
except Exception as e:
logger.info("Fast path failed (%s), falling back to processor", e)
try:
return _expand_with_processor(prompt, pil_images, tokenizer, processor)
except Exception as e:
logger.warning(f"HF processor failed: {e}", exc_info=True)
logger.warning("Slow path also failed: %s", e, exc_info=True)
return tokenizer.encode(prompt), None
def _compute_tokens_per_image(processor_output: dict, processor: Any) -> list[int]:
"""
Compute the number of visual tokens for each image from processor output.
# -- Fast path --
Only Qwen-style processors (Qwen2-VL, Qwen2.5-VL) are supported.
Other model families will raise ValueError.
"""
processor_cls = type(processor).__qualname__
if "qwen" not in processor_cls.lower():
raise NotImplementedError(
f"_compute_tokens_per_image only supports Qwen-style processors "
f"tuples. Got processor class: {processor_cls}"
def _expand_from_dims(
prompt: str,
image_dims: list[tuple[int, int]],
tokenizer: Any,
processor: Any,
) -> tuple[list[int], list[tuple[int, int]]]:
"""Expand placeholders using dimension-based token counts (Qwen-style)."""
image_processor = processor.image_processor
get_num_patches = image_processor.get_number_of_image_patches
merge_size = image_processor.merge_size
image_token_id = processor.image_token_id
tokens_per_image = []
for w, h in image_dims:
n_patches: int = int(get_num_patches(h, w, {})) # type: ignore[arg-type]
tokens_per_image.append(n_patches // (merge_size**2))
base_tokens = tokenizer.encode(prompt)
placeholders = [i for i, t in enumerate(base_tokens) if t == image_token_id]
if len(placeholders) != len(image_dims):
raise ValueError(
f"Placeholder count ({len(placeholders)}) != image count ({len(image_dims)})"
)
grid_thw = processor_output.get("image_grid_thw")
if grid_thw is None:
raise ValueError("image_grid_thw not found in processor output")
expanded: list[int] = []
ranges: list[tuple[int, int]] = []
prev = 0
for idx, pos in enumerate(placeholders):
expanded.extend(base_tokens[prev:pos])
start = len(expanded)
n = tokens_per_image[idx]
expanded.extend([image_token_id] * n)
ranges.append((start, start + n))
prev = pos + 1
expanded.extend(base_tokens[prev:])
return expanded, ranges
merge_size = getattr(processor.image_processor, "merge_size", 2)
return [int(t * h * w) // (merge_size**2) for t, h, w in grid_thw]
# -- Slow path --
def _find_image_token_ranges(
tokens: list[int], image_token_id: int
) -> list[tuple[int, int]]:
"""
Find all contiguous ranges of image tokens.
Unlike the TRT-LLM version, this does NOT replace tokens — vLLM keeps
the original image_token_id as-is in KV events.
def _expand_with_processor(
prompt: str,
pil_images: Sequence[Image.Image],
tokenizer: Any,
processor: Any,
) -> tuple[list[int], list[tuple[int, int]] | None]:
"""Expand using full HF processor (works for any model, ~55ms)."""
output = processor(
text=[prompt], images=pil_images, return_tensors="pt", padding=True
)
tokens = output["input_ids"][0].tolist()
image_token_id = getattr(processor, "image_token_id", None)
if image_token_id is None:
return tokens, None
Returns: list of (start, end) ranges for contiguous image token regions.
"""
ranges = []
start = None
merge_size = getattr(processor.image_processor, "merge_size", 2)
grid_thw = output.get("image_grid_thw")
if grid_thw is None:
return tokens, None
tokens_per_image = [int(t * h * w) // (merge_size**2) for t, h, w in grid_thw]
contiguous: list[tuple[int, int]] = []
run_start = None
for i, t in enumerate(tokens):
if t == image_token_id:
if start is None:
start = i
elif start is not None:
ranges.append((start, i))
start = None
if start is not None:
ranges.append((start, len(tokens)))
if ranges:
logger.info(
f"Found {sum(e - s for s, e in ranges)} image tokens "
f"(id={image_token_id}) in {len(ranges)} range(s)"
)
return ranges
def _compute_per_image_ranges(
contiguous_ranges: list[tuple[int, int]],
tokens_per_image: list[int],
) -> list[tuple[int, int]] | None:
"""
Split contiguous image token ranges by each image's token count.
Example: contiguous_ranges=[(0, 100)], tokens_per_image=[60, 40]
Returns: [(0, 60), (60, 100)] # image 1 at 0-60, image 2 at 60-100
"""
if not contiguous_ranges:
if tokens_per_image:
logger.warning(
f"No image tokens found but {len(tokens_per_image)} images expected"
)
return None
# Greedily assign images to ranges in order
result = []
image_idx = 0
for range_start, range_end in contiguous_ranges:
range_size = range_end - range_start
pos = range_start
consumed = 0
# Consume images that fit entirely in this range
# (a single image's tokens are always contiguous, cannot span ranges)
while image_idx < len(tokens_per_image):
needed = tokens_per_image[image_idx]
if consumed + needed <= range_size:
if run_start is None:
run_start = i
elif run_start is not None:
contiguous.append((run_start, i))
run_start = None
if run_start is not None:
contiguous.append((run_start, len(tokens)))
result: list[tuple[int, int]] = []
img_idx = 0
for rng_start, rng_end in contiguous:
pos = rng_start
while img_idx < len(tokens_per_image):
needed = tokens_per_image[img_idx]
if pos + needed <= rng_end:
result.append((pos, pos + needed))
pos += needed
consumed += needed
image_idx += 1
img_idx += 1
else:
break
# Range must be exactly filled (no leftover image tokens)
if consumed != range_size:
logger.warning(
f"Range size mismatch: consumed {consumed} != range {range_size}"
)
return None
# All images must be consumed
if image_idx != len(tokens_per_image):
logger.warning(f"Not all images mapped: {image_idx} < {len(tokens_per_image)}")
return None
return result
return tokens, result if len(result) == len(tokens_per_image) else None
......@@ -104,6 +104,8 @@ pub struct MmRoutingInfo {
#[derive(Serialize, Deserialize, Debug, Clone)]
pub enum MultimodalData {
Url(url::Url),
#[serde(rename(serialize = "Url"))]
RawUrl(String),
Decoded(RdmaMediaDataDescriptor),
}
......
......@@ -18,7 +18,9 @@ import base64
import os
import re
import shutil
import threading
import time
from http.server import BaseHTTPRequestHandler, HTTPServer
from io import BytesIO
from typing import Any, Generator
......@@ -58,6 +60,7 @@ _SINGLE_IMAGE_FRESH_COLOR = (123, 45, 67)
_DOUBLE_IMAGE_FRESH_COLOR = (89, 210, 34)
_STAIRCASE_IMAGE_FRESH_COLOR = (17, 99, 201)
_SWAP_ORDER_FRESH_COLORS = [(14, 141, 77), (211, 66, 101), (44, 91, 233)]
_HTTP_IMAGE_COLORS = [(180, 30, 90), (30, 180, 90), (90, 30, 180)]
# Contract with lib/llm/src/kv_router/push_router.rs "[ROUTING]" debug log.
# Keep this parser in sync with the router log format.
_ROUTING_RECORD_PATTERN = re.compile(
......@@ -76,7 +79,7 @@ def _make_process_env(log_level: str = "debug", **extra) -> dict[str, str]:
env = os.environ.copy()
env["DYN_LOG"] = log_level
env["DYN_NAMESPACE"] = NAMESPACE
env["DYN_REQUEST_PLANE"] = "nats"
env["DYN_REQUEST_PLANE"] = "tcp"
env.update(extra)
return env
......@@ -210,13 +213,17 @@ def start_vllm_mm_services(
yield frontend_port, router_proc
def _make_data_uri(color: tuple[int, int, int], size: int = 1024) -> str:
def _make_png_bytes(color: tuple[int, int, int], size: int = 1024) -> bytes:
from PIL import Image
img = Image.new("RGB", (size, size), color)
buf = BytesIO()
img.save(buf, format="PNG")
b64 = base64.b64encode(buf.getvalue()).decode("utf-8")
return buf.getvalue()
def _make_data_uri(color: tuple[int, int, int], size: int = 1024) -> str:
b64 = base64.b64encode(_make_png_bytes(color, size)).decode("utf-8")
return f"data:image/png;base64,{b64}"
......@@ -293,6 +300,7 @@ def _send_request_get_overlap(
timeout_s=120,
)
print(f"[MM_ROUTER_E2E] {label}: current={overlap}/{total}")
time.sleep(1)
return overlap, total, segment
......@@ -684,5 +692,139 @@ def test_vllm_mm_overlap_swapped_order_less_than_same_order(
)
def _make_image_handler(image_map: dict[str, bytes]) -> type:
"""Create an HTTP handler class that serves images from the given map."""
class _ImageHandler(BaseHTTPRequestHandler):
def do_GET(self):
data = image_map.get(self.path)
if data is None:
self.send_error(404)
return
self.send_response(200)
self.send_header("Content-Type", "image/png")
self.send_header("Content-Length", str(len(data)))
self.end_headers()
self.wfile.write(data)
def log_message(self, format, *args):
pass # suppress noisy request logs
return _ImageHandler
@pytest.fixture(scope="module")
def http_image_server() -> Generator[list[str], None, None]:
"""Serve pre-generated PNG images over HTTP for the duration of the module."""
(port,) = allocate_ports(count=1, start_port=18000)
image_map: dict[str, bytes] = {}
for i, color in enumerate(_HTTP_IMAGE_COLORS):
image_map[f"/image_{i}.png"] = _make_png_bytes(color)
server = HTTPServer(("127.0.0.1", port), _make_image_handler(image_map))
thread = threading.Thread(target=server.serve_forever, daemon=True)
thread.start()
urls = [
f"http://127.0.0.1:{port}/image_{i}.png" for i in range(len(_HTTP_IMAGE_COLORS))
]
yield urls
server.shutdown()
server.server_close()
thread.join(timeout=5)
@pytest.mark.timeout(1800)
@pytest.mark.nightly
def test_vllm_mm_overlap_repeated_http_images(
start_vllm_mm_services, predownload_models, http_image_server
):
"""For repeated same 3-HTTP-image request: low first overlap, then increase, then stable."""
frontend_port, router_proc = start_vllm_mm_services
payload = _build_payload(
http_image_server, prompt="MM routing e2e: repeated same 3 HTTP images."
)
overlap_1, total_1, _ = _send_request_get_overlap(
frontend_port, router_proc, payload, "http_3_images_req1"
)
time.sleep(1)
overlap_2, total_2, _ = _send_request_get_overlap(
frontend_port, router_proc, payload, "http_3_images_req2"
)
time.sleep(1)
overlap_3, total_3, segment_3 = _send_request_get_overlap(
frontend_port, router_proc, payload, "http_3_images_req3"
)
assert overlap_1 <= 1, (
f"Expected first overlap <=1, got req1={overlap_1}/{total_1}.\n"
f"Recent router logs:\n{segment_3[-4000:]}"
)
assert overlap_2 > overlap_1, (
f"Expected second overlap > first, got req1={overlap_1}/{total_1}, req2={overlap_2}/{total_2}.\n"
f"Recent router logs:\n{segment_3[-4000:]}"
)
assert overlap_3 == overlap_2, (
f"Expected third overlap == second, got req2={overlap_2}/{total_2}, req3={overlap_3}/{total_3}.\n"
f"Recent router logs:\n{segment_3[-4000:]}"
)
low, high = THREE_IMAGE_TOTAL_BLOCKS_RANGE
assert low <= total_3 <= high, (
f"Unexpected total blocks for same 3 HTTP images (1024): "
f"got {total_3}, expected in [{low}, {high}]"
)
@pytest.mark.timeout(1800)
@pytest.mark.nightly
def test_vllm_mm_overlap_http_vs_data_uri_same_image(
start_vllm_mm_services, predownload_models, http_image_server
):
"""HTTP URL and data URI for the same image should produce identical KV cache hashes."""
frontend_port, router_proc = start_vllm_mm_services
# Use the first HTTP image color to build both representations
color = _HTTP_IMAGE_COLORS[0]
data_uri = _make_data_uri(color)
http_url = http_image_server[0]
# Seed KV cache with data URI request
data_uri_payload = _build_payload(
[data_uri], prompt="MM routing e2e: HTTP vs data URI same image."
)
overlap_data, total_data, _ = _send_request_get_overlap(
frontend_port, router_proc, data_uri_payload, "data_uri_seed"
)
time.sleep(1)
# Now send HTTP URL request for the identical image
http_payload = _build_payload(
[http_url], prompt="MM routing e2e: HTTP vs data URI same image."
)
overlap_http, total_http, segment_http = _send_request_get_overlap(
frontend_port, router_proc, http_payload, "http_probe"
)
assert total_http > 0, (
f"No routing score for HTTP request.\n"
f"Recent router logs:\n{segment_http[-4000:]}"
)
assert abs(total_http - total_data) <= 2, (
f"Expected HTTP and data URI total blocks to match, "
f"got http={total_http}, data_uri={total_data}.\n"
f"Recent router logs:\n{segment_http[-4000:]}"
)
assert overlap_http > overlap_data, (
f"Expected HTTP probe overlap > data URI seed overlap "
f"(proving image cache hit, not just text overlap), "
f"got http={overlap_http}/{total_http}, data_uri={overlap_data}/{total_data}.\n"
f"Recent router logs:\n{segment_http[-4000:]}"
)
if __name__ == "__main__":
pytest.main([__file__, "-v", "-s"])
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment