Unverified Commit 23de4e86 authored by zhongdaor-nv's avatar zhongdaor-nv Committed by GitHub
Browse files

feat: e2e mm aware kv cache routing support for trtllm backend (#5480)


Signed-off-by: default avatarzhongdaor <zhongdaor@nvidia.com>
Signed-off-by: default avatarzhongdaor-nv <zhongdaor@nvidia.com>
Signed-off-by: default avatarZhongdao Ren <zhongdaor@zhongdaor-mlt.client.nvidia.com>
Co-authored-by: default avatarClaude Sonnet 4.5 <noreply@anthropic.com>
Co-authored-by: default avatarZhongdao Ren <zhongdaor@zhongdaor-mlt.client.nvidia.com>
parent 8a098a66
......@@ -118,6 +118,7 @@ class ZmqKvEventPublisher:
block_hashes: list[int],
lora_id: int = 0,
parent_hash: Optional[int] = None,
block_mm_infos: Optional[list[dict | None]] = None,
attention_dp_rank: int = 0,
):
"""Publish a BlockStored event.
......@@ -141,6 +142,10 @@ class ZmqKvEventPublisher:
"lora_id": lora_id if lora_id != 0 else None,
}
# Add multimodal info if present
if block_mm_infos is not None:
event["block_mm_infos"] = block_mm_infos
self._publish_event(event, attention_dp_rank)
def publish_removed(self, block_hashes: list[int], attention_dp_rank: int = 0):
......@@ -537,6 +542,7 @@ class Publisher:
token_ids: list[int] = []
num_block_tokens: list[int] = []
block_hashes: list[int] = []
block_mm_infos: list[dict | None] = []
for block in data["blocks"]:
token_num_in_block = len(block["tokens"])
block_hash = _to_signed_i64(block["block_hash"])
......@@ -561,6 +567,26 @@ class Publisher:
for token in block["tokens"]:
token_ids.append(int(token["token_id"]))
# Extract multimodal hash info for this block
# {"mm_keys": [{"type":"mm_key","hash":"<hex>","start_offset":N}]}
mm_keys = block.get("mm_keys", [])
mm_hashes = [
int(mm_key["hash"][:16], 16)
for mm_key in mm_keys
if mm_key.get("type") == "mm_key" and mm_key.get("hash")
]
if mm_hashes:
block_mm_infos.append(
{
"mm_objects": [
{"mm_hash": mm_hash, "offsets": []}
for mm_hash in mm_hashes
]
}
)
else:
block_mm_infos.append(None)
# Note: Currently data does not have lora_id.
# Using 0 as default value. If later data has
# lora_id, we need to verify if this is correct.
......@@ -583,6 +609,7 @@ class Publisher:
block_hashes,
lora_id,
parent_hash,
block_mm_infos,
attention_dp_rank,
)
elif self.kv_event_publishers:
......@@ -596,6 +623,7 @@ class Publisher:
block_hashes,
lora_id,
parent_hash,
block_mm_infos,
)
else:
logging.warning(
......
<!--
SPDX-FileCopyrightText: Copyright (c) 2025-2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
SPDX-License-Identifier: Apache-2.0
-->
# MM Router Worker
Multimodal-aware KV cache routing worker for TRT-LLM backends.
## Overview
This worker sits between the Dynamo frontend and TRT-LLM workers, providing MM-aware KV cache routing:
1. **Receives** OpenAI-format requests from the frontend
2. **Downloads** images and computes `mm_hash` (for routing decision only)
3. **Builds** multimodal routing metadata (`mm_routing_info`)
4. **Uses** KvRouter to select and route to the best TRT-LLM worker
5. **Streams** responses back to the frontend
## Architecture
```
Frontend (standard) MM Router Worker (this) TRT-LLM Worker (standard)
┌──────────────┐ ┌─────────────────────┐ ┌───────────────────┐
│ │───────>│ 1. Download images │───────>│ python -m │
│ round-robin │ │ 2. Compute mm_hash │ │ dynamo.trtllm │
│ to mm_router│<───────│ 3. Build routing │<───────│ --modality mm │
└──────────────┘ │ 4. KvRouter route │ │ (processes images)│
└─────────────────────┘ └───────────────────┘
│ Subscribe KV events
v
┌──────────┐
│ NATS │
└──────────┘
```
**Note**: Images are downloaded twice - once in MM Router (for mm_hash computation) and once in TRT-LLM worker (for actual processing). This simplifies the design by avoiding tensor serialization.
## Usage
### Quick Start
```bash
# Start all services
./launch.sh
```
### Manual Start
```bash
# 1. Start etcd and NATS
docker compose -f deploy/docker-compose.yml up -d
# 2. Start TRT-LLM worker(s)
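# Serve the workers under a different model name so the Frontend routes
# requests to the MM Router instead of directly to the TRT-LLM workers.
python -m dynamo.trtllm \
--model-path Qwen/Qwen2-VL-2B-Instruct \
--served-model-name Qwen/Qwen2-VL-2B-Instruct__internal \
--endpoint dyn://default.trtllm.generate \
--modality multimodal \
--publish-events-and-metrics \
--kv-block-size 32 &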
# 3. Start MM Router Worker
python -m examples.backends.trtllm.mm_router_worker \
--model Qwen/Qwen2-VL-2B-Instruct \
--model-type qwen2_vl \
--namespace default \
--component mm_router \
--endpoint generate \
--downstream-component trtllm \
--downstream-endpoint generate &
# 4. Start Frontend
python -m dynamo.frontend \
--http-port 8000 \
--router-mode round-robin
```
### Test Request
```bash
curl http://localhost:8000/v1/chat/completions \
-H "Content-Type: application/json" \
-d '{
"model": "Qwen/Qwen2-VL-2B-Instruct",
"messages": [{
"role": "user",
"content": [
{"type": "text", "text": "Describe this image"},
{"type": "image_url", "image_url": {"url": "https://example.com/image.jpg"}}
]
}],
"max_tokens": 100
}'
```
## Configuration
| Argument | Default | Description |
|----------|---------|-------------|
| `--model` | `Qwen/Qwen2-VL-2B-Instruct` | Model path or HuggingFace ID |
| `--model-type` | `qwen2_vl` | TRT-LLM model type for multimodal loader |
| `--block-size` | `32` | KV cache block size |
| `--namespace` | `default` | Dynamo namespace |
| `--component` | `mm_router` | This worker's component name |
| `--endpoint` | `generate` | This worker's endpoint name |
| `--downstream-component` | `trtllm` | TRT-LLM workers' component name |
| `--downstream-endpoint` | `generate` | TRT-LLM workers' endpoint name |
## How It Works
### MM Hash Computation
The worker uses TRT-LLM's `apply_mm_hashes()` function to compute a hash of each image's tensor representation. This hash is included in the block hash computation, ensuring that:
- Same image = Same mm_hash = Same block hashes = Cache hit
- Different image = Different mm_hash = Different block hashes = No false cache hit
### KV-Aware Routing
The worker uses `KvRouter.generate(...)` with explicit multimodal routing hints.
When a request comes in:
1. Build routing tokens (`routing_token_ids`) for the request
2. Build `block_mm_infos` with per-block image `mm_hash` metadata
3. Pass both as `mm_routing_info` to `KvRouter.generate(...)`
4. KvRouter computes overlap internally and routes to the best worker
### Block MM Info Structure
For each block that contains image tokens, we build `block_mm_infos`:
```python
block_mm_infos = [
None, # Block 0: no image
{"mm_objects": [{"mm_hash": 12345, "offsets": []}]}, # Block 1: has image (offsets are currently left empty)
{"mm_objects": [{"mm_hash": 12345, "offsets": []}]}, # Block 2: same image spans this block too
None, # Block 3: no image
]
```
This is included in `mm_routing_info` so KvRouter can compute MM-aware overlap.
## Files
| File | Description |
|------|-------------|
| `mm_router_worker.py` | Main worker with `@dynamo_worker()` |
| `handler.py` | `MMRouterHandler` - routing logic |
| `mm_processor.py` | MM processing utilities |
| `__main__.py` | Entry point |
| `launch.sh` | Launch script |
## Dependencies
- `tensorrt_llm >= 1.2.0rc6` - For `apply_mm_hashes()` and `default_multimodal_input_loader()`. Earlier versions may not include multimodal hash support in KV events.
- `transformers` - For `AutoProcessor`
- `dynamo` - For runtime and KvRouter
## Known Limitations
- **Qwen2-VL specific**: The `_compute_tokens_per_image()` logic in `mm_processor.py` currently supports only the `qwen2_vl` model type. Supporting other multimodal models requires adding their visual-token computation logic.
# SPDX-FileCopyrightText: Copyright (c) 2025-2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
"""
MM Router Worker - Multimodal-aware KV cache routing for TRT-LLM.
This worker sits between the frontend and TRT-LLM workers, computing mm_hash
for images and routing requests to the worker with the best KV cache overlap.
"""
from .handler import MMRouterHandler
from .mm_processor import ProcessedInput, build_block_mm_infos
__all__ = [
"MMRouterHandler",
"ProcessedInput",
"build_block_mm_infos",
]
# SPDX-FileCopyrightText: Copyright (c) 2025-2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
"""Entry point for running MM Router Worker as a module."""
from .mm_router_worker import main
if __name__ == "__main__":
main()
# SPDX-FileCopyrightText: Copyright (c) 2025-2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
"""
MM Router Handler - Routes requests to best TRT-LLM worker based on KV cache overlap.
"""
import logging
from typing import Any, AsyncGenerator
from dynamo.llm import KvRouter
from dynamo.runtime.logging import configure_dynamo_logging
from .mm_processor import build_block_mm_infos, extract_image_urls, process_multimodal
configure_dynamo_logging()
logger = logging.getLogger(__name__)
class MMRouterHandler:
    """
    Handler that computes mm_hash for multimodal requests and routes
    to the best worker based on KV cache overlap.

    Text-only requests are routed on the frontend-provided ``token_ids``;
    requests containing images are routed on processor-expanded tokens plus
    per-block multimodal hash metadata.
    """

    def __init__(
        self,
        kv_router: KvRouter,
        tokenizer: Any,
        processor: Any,
        model: str,
        model_type: str,
        block_size: int,
    ):
        """
        Initialize the MM Router Handler.

        Args:
            kv_router: KvRouter for KV-aware worker selection and routing
            tokenizer: TRT-LLM tokenizer
            processor: HuggingFace AutoProcessor, or None when unavailable
            model: Model path/name
            model_type: Model type (e.g., "qwen2_vl")
            block_size: KV cache block size
        """
        self.kv_router = kv_router
        self.tokenizer = tokenizer
        self.processor = processor
        self.model = model
        self.model_type = model_type
        self.block_size = block_size

    def _num_blocks(self, num_tokens: int) -> int:
        """Return how many KV blocks are needed to hold ``num_tokens`` tokens."""
        return (num_tokens + self.block_size - 1) // self.block_size

    async def generate(self, request: dict) -> AsyncGenerator[dict, None]:
        """
        Main entry point - receives request, computes routing, forwards to best worker.

        The request format (after Frontend preprocessing with ModelInput.Tokens):
        {
            "token_ids": [...],
            "sampling_options": {...},
            "stop_conditions": {...},
            "extra_args": {"messages": [...]}
        }

        Args:
            request: Preprocessed request from Frontend

        Yields:
            Response chunks from the downstream TRT-LLM worker via KvRouter

        Raises:
            ValueError: If ``token_ids`` is missing or empty, or if per-block
                multimodal info cannot be built for a request with images.
            KeyError: If the request has no ``"model"`` entry.
        """
        # "extra_args" (and "messages" inside it) may be absent *or* explicitly
        # None; both cases mean "no messages" rather than a crash.
        extra_args = request.get("extra_args") or {}
        messages = extra_args.get("messages") or []
        image_urls = extract_image_urls(messages)

        # Validate the execution payload up front so a malformed request fails
        # before any image download / hashing work is done.
        token_ids = request.get("token_ids")
        if not token_ids:
            if image_urls:
                raise ValueError("Missing or empty token_ids in preprocessed request")
            raise ValueError(
                "Missing or empty token_ids in preprocessed request for text-only routing"
            )

        if image_urls:
            # Process multimodal: download images, compute mm_hash.
            # Do not reuse request["token_ids"] for MM routing: those are
            # placeholder-level tokens from frontend. We need processor-expanded
            # tokens to build block_mm_infos.
            processed = process_multimodal(
                messages=messages,
                image_urls=image_urls,
                tokenizer=self.tokenizer,
                processor=self.processor,
                model=self.model,
                model_type=self.model_type,
            )

            # Build block_mm_infos for MM-aware hash computation
            block_mm_infos = build_block_mm_infos(
                num_tokens=len(processed.tokens),
                block_size=self.block_size,
                mm_hashes=processed.mm_hashes,
                image_ranges=processed.image_ranges,
            )
            if block_mm_infos is None:
                raise ValueError(
                    "Failed to build block_mm_infos for multimodal request"
                )

            routing_tokens = processed.tokens
            routing_blocks = self._num_blocks(len(routing_tokens))
            logger.debug(
                "MM request: %d routing tokens, %d images, %d routing blocks",
                len(routing_tokens),
                len(image_urls),
                routing_blocks,
            )
        else:
            # Text-only: rely on frontend-preprocessed token_ids
            # (ModelInput.Tokens contract).
            routing_tokens = token_ids
            routing_blocks = self._num_blocks(len(routing_tokens))
            logger.debug(
                "Text request: %d routing tokens, %d routing blocks",
                len(routing_tokens),
                routing_blocks,
            )
            # Text-only routing has no multimodal objects; provide per-block
            # None entries.
            block_mm_infos = [None] * routing_blocks

        # Route and generate through KvRouter with explicit fields.
        # We pass:
        # - execution payload (token_ids + multi_modal_data)
        # - routing payload (mm_routing_info: routing_token_ids + block_mm_infos)
        # so generate() can select worker internally while preserving MM correctness.
        mm_routing_info: dict[str, Any] = {
            "routing_token_ids": routing_tokens,
            "block_mm_infos": block_mm_infos,
        }

        stream = await self.kv_router.generate(
            token_ids=token_ids,
            model=request["model"],
            stop_conditions=request.get("stop_conditions"),
            sampling_options=request.get("sampling_options"),
            output_options=request.get("output_options"),
            router_config_override=request.get("router_config_override"),
            extra_args=request.get("extra_args"),
            multi_modal_data=request.get("multi_modal_data"),
            mm_routing_info=mm_routing_info,
        )

        async for response in stream:
            yield response
#!/bin/bash
# SPDX-FileCopyrightText: Copyright (c) 2025-2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
# Launch script for MM Router Worker with TRT-LLM backend
#
# This script starts:
# 1. TRT-LLM workers (standard, with KV event publishing)
# 2. MM Router Worker (computes mm_hash, routes to best worker)
# 3. Frontend (HTTP ingress)
set -e
# Get the directory where this script is located and navigate to dynamo root
SCRIPT_DIR="$(cd "$(dirname "${BASH_SOURCE[0]}")" && pwd)"
DYNAMO_ROOT="$(cd "${SCRIPT_DIR}/../../../.." && pwd)"
cd "$DYNAMO_ROOT"
echo "Working directory: $(pwd)"
# Configuration
MODEL="${MODEL:-Qwen/Qwen2-VL-2B-Instruct}"
MODEL_TYPE="${MODEL_TYPE:-qwen2_vl}"
NAMESPACE="${NAMESPACE:-default}"
BLOCK_SIZE="${BLOCK_SIZE:-32}"
HTTP_PORT="${HTTP_PORT:-8000}"
NUM_WORKERS="${NUM_WORKERS:-1}"
echo "=== MM Router Worker Launch Script ==="
echo "Model: $MODEL"
echo "Model Type: $MODEL_TYPE"
echo "Namespace: $NAMESPACE"
echo "Block Size: $BLOCK_SIZE"
echo "HTTP Port: $HTTP_PORT"
echo "Num Workers: $NUM_WORKERS"
echo ""
# Collect PIDs for cleanup
PIDS=()

# Stop every background service started by this script.
#
# Each service is launched as "python ... | sed ... &", and for a background
# pipeline "$!" is the PID of the LAST command (sed). The PIDs recorded in
# PIDS therefore belong to the log-prefixing sed processes, not to the Python
# services. "jobs -p" reports the first process of each background job, so
# killing both sets makes sure the services themselves are signalled.
cleanup() {
    # Disarm the trap so cleanup cannot run a second time while exiting.
    trap - EXIT
    echo "Cleaning up..."
    local pid
    for pid in "${PIDS[@]}" $(jobs -p); do
        kill "$pid" 2>/dev/null || true
    done
    # Reap the children; never let a failing wait abort the exit path.
    wait 2>/dev/null || true
}
trap cleanup EXIT
# Start TRT-LLM workers
# Use a different served-model-name so Frontend routes to MM Router instead
# Use NATS request plane to match MM Router
#
# One worker is started per index 0..NUM_WORKERS-1. Each worker's combined
# stdout/stderr is piped through sed so every log line carries a
# "[trtllm-<i>]" prefix.
# NOTE: because the worker runs as the first stage of a background pipeline,
# "$!" below is the PID of the trailing sed process, not of the Python worker.
echo ""
echo "=== Starting TRT-LLM Workers ==="
for i in $(seq 0 $((NUM_WORKERS - 1))); do
    echo "Starting TRT-LLM worker $i..."
    DYN_REQUEST_PLANE=nats python -m dynamo.trtllm \
        --model-path "$MODEL" \
        --served-model-name "${MODEL}__internal" \
        --endpoint "dyn://${NAMESPACE}.trtllm.generate" \
        --modality multimodal \
        --publish-events-and-metrics \
        --kv-block-size "$BLOCK_SIZE" \
        2>&1 | sed "s/^/[trtllm-$i] /" &
    PIDS+=($!)
done
# Wait for workers to initialize
echo "Waiting for TRT-LLM workers to initialize..."
sleep 15
# Start MM Router Worker
# Use NATS request plane to match Frontend
echo ""
echo "=== Starting MM Router Worker ==="
DYN_REQUEST_PLANE=nats python -m examples.backends.trtllm.mm_router_worker \
--model "$MODEL" \
--model-type "$MODEL_TYPE" \
--namespace "$NAMESPACE" \
--component mm_router \
--endpoint generate \
--downstream-component trtllm \
--downstream-endpoint generate \
--block-size "$BLOCK_SIZE" \
2>&1 | sed "s/^/[mm_router] /" &
PIDS+=($!)
# Wait for router to initialize
echo "Waiting for MM Router to initialize..."
sleep 5
# Start Frontend
# Use NATS request plane to match MM Router
echo ""
echo "=== Starting Frontend ==="
DYN_REQUEST_PLANE=nats python -m dynamo.frontend \
--http-port "$HTTP_PORT" \
--router-mode round-robin \
2>&1 | sed "s/^/[frontend] /" &
PIDS+=($!)
echo ""
echo "=== All services started ==="
echo "Frontend available at http://localhost:$HTTP_PORT"
echo ""
echo "Test with:"
echo "curl http://localhost:$HTTP_PORT/v1/chat/completions \\"
echo " -H 'Content-Type: application/json' \\"
echo " -d '{"
echo " \"model\": \"$MODEL\","
echo " \"messages\": [{"
echo " \"role\": \"user\","
echo " \"content\": [{"
echo " \"type\": \"text\","
echo " \"text\": \"Describe this image\""
echo " }, {"
echo " \"type\": \"image_url\","
echo " \"image_url\": {\"url\": \"https://example.com/image.jpg\"}"
echo " }]"
echo " }],"
echo " \"max_tokens\": 100"
echo " }'"
echo ""
echo "Press Ctrl+C to stop all services"
# Wait for all background processes
wait
# SPDX-FileCopyrightText: Copyright (c) 2025-2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
"""Multimodal processing utilities for MM Router Worker."""
import logging
from dataclasses import dataclass
from typing import Any
from tensorrt_llm.inputs.multimodal import apply_mm_hashes
from tensorrt_llm.inputs.utils import default_multimodal_input_loader, load_image
from transformers import AutoConfig
logger = logging.getLogger(__name__)
# =============================================================================
# Data structures
# =============================================================================
@dataclass
class ProcessedInput:
    """Processed multimodal input.

    Result of ``process_multimodal``: the routing token sequence together with
    the per-image hashes and token ranges used to build per-block routing info.
    """

    # Token ids for the request (processor-expanded when a processor is available).
    tokens: list[int]
    # One integer hash per image, or None when hashes could not be computed.
    mm_hashes: list[int] | None
    # Half-open token ranges, one per image; None when they could not be derived.
    image_ranges: list[tuple[int, int]] | None  # [(start, end), ...] per image
# =============================================================================
# Public functions
# =============================================================================
def extract_image_urls(messages: list[dict]) -> list[str]:
    """Extract image URLs from OpenAI-format messages.

    Only list-style message content is inspected; each part whose ``type`` is
    ``"image_url"`` contributes its ``image_url["url"]`` when that value is
    non-empty. Malformed entries (non-dict messages or parts, a missing or
    non-dict ``image_url``) are skipped instead of raising.

    Args:
        messages: Chat messages in OpenAI format. ``None`` or an empty list
            yields an empty result.

    Returns:
        The image URLs in the order they appear in the conversation.
    """
    urls: list[str] = []
    for msg in messages or []:
        if not isinstance(msg, dict):
            continue
        content = msg.get("content")
        if not isinstance(content, list):
            # Plain-string (or missing) content cannot carry image parts.
            continue
        for part in content:
            if not isinstance(part, dict) or part.get("type") != "image_url":
                continue
            image_url = part.get("image_url")
            # "image_url" may be absent or None; only the dict form has a URL.
            url = image_url.get("url") if isinstance(image_url, dict) else None
            if url:
                urls.append(url)
    return urls
def build_prompt_from_messages(messages: list[dict]) -> str:
    """Flatten chat messages into a plain ``"role: text"`` prompt.

    String content is used verbatim. For list content, the ``text`` of every
    ``"text"`` part is joined with single spaces; a message whose list has no
    text parts contributes nothing. Messages are separated by newlines.
    """
    lines = []
    for message in messages:
        speaker = message.get("role", "user")
        body = message.get("content", "")
        if isinstance(body, str):
            lines.append(f"{speaker}: {body}")
            continue
        if not isinstance(body, list):
            # Any other content type is ignored.
            continue
        text_chunks = []
        for chunk in body:
            if chunk.get("type") == "text":
                text_chunks.append(chunk.get("text", ""))
        if text_chunks:
            lines.append(f"{speaker}: {' '.join(text_chunks)}")
    return "\n".join(lines)
def process_multimodal(
    messages: list[dict],
    image_urls: list[str],
    tokenizer: Any,
    processor: Any,
    model: str,
    model_type: str,
) -> ProcessedInput:
    """Turn a multimodal chat request into routing tokens and image hashes.

    Builds a text prompt from ``messages``, runs the TRT-LLM multimodal input
    loader on the images, derives the expanded token sequence with per-image
    ranges, and hashes the loaded multimodal data.

    Any exception is logged with its traceback and re-raised unchanged.
    """
    try:
        prompt = build_prompt_from_messages(messages)

        # A single image and several images use different loader modalities.
        modality = "multiple_image" if len(image_urls) > 1 else "image"

        # The TRT-LLM loader fetches the images and produces the mm data.
        loaded = default_multimodal_input_loader(
            tokenizer=tokenizer,
            model_dir=model,
            model_type=model_type,
            modality=modality,
            prompts=[prompt],
            media=[image_urls],
            image_data_format="pt",
            device="cuda",
        )
        first_input = loaded[0]
        loader_prompt = first_input.get("prompt", prompt)
        mm_payload = first_input.get("multi_modal_data")

        # Expanded tokens plus the token range occupied by each image.
        expanded_tokens, per_image_ranges = _get_expanded_tokens(
            loader_prompt, image_urls, tokenizer, processor, model, model_type
        )

        # One hash per image.
        image_hashes = _compute_mm_hashes(mm_payload)

        return ProcessedInput(
            tokens=expanded_tokens,
            mm_hashes=image_hashes,
            image_ranges=per_image_ranges,
        )
    except Exception as e:
        logger.error(f"MM processing failed: {e}", exc_info=True)
        raise
def build_block_mm_infos(
num_tokens: int,
block_size: int,
mm_hashes: list[int] | None,
image_ranges: list[tuple[int, int]] | None,
) -> list[dict | None] | None:
"""
Build per-block mm_info for routing.
For each block, check which images overlap with it and add their mm_hash.
Assumption: mm_hashes and image_ranges are in the same order as images appear
in the request (which matches their order in the token sequence).
"""
if not mm_hashes or not image_ranges or len(mm_hashes) != len(image_ranges):
return None
num_blocks = (num_tokens + block_size - 1) // block_size
result = []
for block_idx in range(num_blocks):
block_start = block_idx * block_size
block_end = block_start + block_size
# Find images overlapping this block
mm_objects = [
{"mm_hash": mm_hash, "offsets": []}
for mm_hash, (img_start, img_end) in zip(mm_hashes, image_ranges)
if block_end > img_start and block_start < img_end
]
result.append({"mm_objects": mm_objects} if mm_objects else None)
return result
# =============================================================================
# Internal functions
# =============================================================================
def _get_expanded_tokens(
prompt: str,
image_urls: list[str],
tokenizer: Any,
processor: Any,
model_path: str,
model_type: str,
) -> tuple[list[int], list[tuple[int, int]] | None]:
"""Get tokens with visual expansion and find each image's token range."""
if processor is None:
return tokenizer.encode(prompt), None
try:
# TODO @zdren: use async_load_image or batch load
pil_images = [load_image(url, format="pil") for url in image_urls]
output = processor(
text=[prompt], images=pil_images, return_tensors="pt", padding=True
)
tokens = output["input_ids"][0].tolist()
# Get image_token_id from processor
image_token_id = getattr(processor, "image_token_id", None)
if image_token_id is None:
raise ValueError("processor.image_token_id not found")
# Get replacement_id: TRTLLM uses vocab_size + 1 in KV events
replacement_id = _get_replacement_id(model_path)
# Find contiguous image token ranges and replace them in one pass
contiguous_ranges = _find_and_replace_image_tokens(
tokens, image_token_id, replacement_id
)
# Compute tokens per image from processor output
tokens_per_image = _compute_tokens_per_image(output, processor, model_type)
# Split ranges according to tokens_per_image
image_ranges = _compute_per_image_ranges(contiguous_ranges, tokens_per_image)
return tokens, image_ranges
except Exception as e:
logger.warning(f"HF processor failed: {e}", exc_info=True)
return tokenizer.encode(prompt), None
def _compute_tokens_per_image(
processor_output: dict, processor: Any, model_type: str
) -> list[int]:
"""Compute the number of visual tokens for each image from processor output."""
if model_type == "qwen2_vl":
grid_thw = processor_output.get("image_grid_thw")
if grid_thw is None:
raise ValueError(
"image_grid_thw not found in processor output for Qwen2-VL"
)
merge_size = getattr(processor.image_processor, "merge_size", 2)
return [int(t * h * w) // (merge_size**2) for t, h, w in grid_thw]
else:
raise NotImplementedError(f"Model type '{model_type}' is not supported yet")
def _get_replacement_id(model_path: str) -> int:
"""
Get the replacement token ID for image tokens to match TRTLLM's KV event format.
TRTLLM replaces image placeholder tokens with (vocab_size + 1) in KV events.
The vocab_size comes from the model config, not the tokenizer.
"""
try:
config = AutoConfig.from_pretrained(model_path, trust_remote_code=True)
replacement_id = config.vocab_size + 1
logger.info(f"Got vocab_size={config.vocab_size} from AutoConfig")
return replacement_id
except Exception as e:
raise RuntimeError(
f"Failed to get vocab_size from model config '{model_path}': {e}"
) from e
def _find_and_replace_image_tokens(
tokens: list[int], image_token_id: int, replacement_id: int
) -> list[tuple[int, int]]:
"""
Find all contiguous ranges of image tokens and replace them in place.
Returns: list of (start, end) ranges for contiguous image token regions.
"""
ranges = []
start = None
for i, t in enumerate(tokens):
if t == image_token_id:
tokens[i] = replacement_id
if start is None:
start = i
elif start is not None:
ranges.append((start, i))
start = None
if start is not None:
ranges.append((start, len(tokens)))
if ranges:
logger.info(
f"Replaced {sum(e - s for s, e in ranges)} image tokens: {image_token_id} -> {replacement_id}"
)
return ranges
def _compute_per_image_ranges(
contiguous_ranges: list[tuple[int, int]],
tokens_per_image: list[int],
) -> list[tuple[int, int]] | None:
"""
Split contiguous image token ranges by each image's token count.
Example: contiguous_ranges=[(0, 100)], tokens_per_image=[60, 40]
Returns: [(0, 60), (60, 100)] # image 1 at 0-60, image 2 at 60-100
"""
if not contiguous_ranges:
if tokens_per_image:
logger.warning(
f"No image tokens found but {len(tokens_per_image)} images expected"
)
return None
# Greedily assign images to ranges in order
result = []
image_idx = 0
for range_start, range_end in contiguous_ranges:
range_size = range_end - range_start
pos = range_start
consumed = 0
# Consume images that fit entirely in this range
# (a single image's tokens are always contiguous, cannot span ranges)
while image_idx < len(tokens_per_image):
needed = tokens_per_image[image_idx]
if consumed + needed <= range_size:
result.append((pos, pos + needed))
pos += needed
consumed += needed
image_idx += 1
else:
break
# Range must be exactly filled (no leftover image tokens)
if consumed != range_size:
logger.warning(
f"Range size mismatch: consumed {consumed} != range {range_size}"
)
return None
# All images must be consumed
if image_idx != len(tokens_per_image):
logger.warning(f"Not all images mapped: {image_idx} < {len(tokens_per_image)}")
return None
return result
def _compute_mm_hashes(multi_modal_data: dict | None) -> list[int] | None:
"""Compute mm_hash for each image."""
if not multi_modal_data:
return None
try:
result = apply_mm_hashes(multi_modal_data)
if "image" in result and result["image"]:
return [int(h[:16], 16) for h in result["image"]]
except Exception as e:
logger.warning(f"Failed to compute mm_hashes: {e}")
return None
# SPDX-FileCopyrightText: Copyright (c) 2025-2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
"""
MM Router Worker - Multimodal-aware KV cache routing for TRT-LLM.
This worker receives OpenAI-format requests from the frontend, computes
mm_hash for any images, finds the best TRT-LLM worker based on KV cache
overlap, and forwards the request to that worker.
Usage:
python -m examples.backends.trtllm.mm_router_worker \
--model Qwen/Qwen2-VL-2B-Instruct \
--model-type qwen2_vl \
--namespace default \
--component mm_router \
--endpoint generate \
--downstream-component trtllm \
--downstream-endpoint generate
"""
import argparse
import asyncio
import logging
import signal
import uvloop
from tensorrt_llm.llmapi.tokenizer import tokenizer_factory
from transformers import AutoProcessor
from dynamo.llm import KvRouter, KvRouterConfig, ModelInput, ModelType, register_llm
from dynamo.runtime import DistributedRuntime, dynamo_worker
from .handler import MMRouterHandler
logger = logging.getLogger(__name__)
def parse_args() -> argparse.Namespace:
    """Build the command-line parser and parse ``sys.argv``.

    Options fall into three groups: model configuration, this worker's own
    endpoint, and the downstream TRT-LLM workers' endpoint.
    """
    cli = argparse.ArgumentParser(
        description="MM Router Worker - Multimodal-aware KV cache routing"
    )

    # (flag, value type, default, help text), registered in this order.
    option_table = (
        # Model configuration
        ("--model", str, "Qwen/Qwen2-VL-2B-Instruct", "Model path or HuggingFace model ID"),
        ("--model-type", str, "qwen2_vl", "Model type for TRT-LLM multimodal loader"),
        ("--block-size", int, 32, "KV cache block size"),
        # This worker's endpoint configuration
        ("--namespace", str, "default", "Dynamo namespace"),
        ("--component", str, "mm_router", "This worker's component name"),
        ("--endpoint", str, "generate", "This worker's endpoint name"),
        # Downstream TRT-LLM worker configuration
        ("--downstream-component", str, "trtllm", "Downstream TRT-LLM workers' component name"),
        ("--downstream-endpoint", str, "generate", "Downstream TRT-LLM workers' endpoint name"),
    )
    for flag, value_type, default_value, help_text in option_table:
        cli.add_argument(flag, type=value_type, default=default_value, help=help_text)

    return cli.parse_args()
async def graceful_shutdown(runtime: DistributedRuntime) -> None:
    """Handle graceful shutdown.

    Logs that a shutdown signal was received, calls ``runtime.shutdown()``,
    and logs completion once that call returns.

    Args:
        runtime: The distributed runtime to shut down.
    """
    logger.info("Received shutdown signal, shutting down...")
    runtime.shutdown()
    logger.info("Shutdown complete")
@dynamo_worker()
async def worker(runtime: DistributedRuntime) -> None:
    """
    Main worker function.

    Sets up connections to downstream TRT-LLM workers, creates KvRouter
    for KV-aware routing, and serves the MM router endpoint.

    Steps, in order: parse CLI arguments, install SIGTERM/SIGINT handlers,
    wait for downstream instances, create the KvRouter, load the tokenizer and
    (best effort) the HuggingFace processor, register this endpoint as a
    token-input chat model, and serve ``MMRouterHandler.generate``.

    Args:
        runtime: Distributed runtime supplied by the ``dynamo_worker`` decorator.
    """
    args = parse_args()

    # Set up signal handlers for graceful shutdown
    loop = asyncio.get_running_loop()

    # The event loop keeps only weak references to tasks, so hold a strong
    # reference until the shutdown task finishes; otherwise it could be
    # garbage-collected before it has run to completion.
    shutdown_tasks: set[asyncio.Task] = set()

    def signal_handler():
        task = asyncio.create_task(graceful_shutdown(runtime))
        shutdown_tasks.add(task)
        task.add_done_callback(shutdown_tasks.discard)

    for sig in (signal.SIGTERM, signal.SIGINT):
        loop.add_signal_handler(sig, signal_handler)

    logger.info("MM Router Worker starting...")
    logger.info(f"Model: {args.model}")
    logger.info(f"This worker: {args.namespace}.{args.component}.{args.endpoint}")
    logger.info(
        f"Downstream: {args.namespace}.{args.downstream_component}.{args.downstream_endpoint}"
    )

    # Connect to downstream TRT-LLM workers
    downstream_endpoint = (
        runtime.namespace(args.namespace)
        .component(args.downstream_component)
        .endpoint(args.downstream_endpoint)
    )
    downstream_client = await downstream_endpoint.client()

    logger.info("Waiting for downstream TRT-LLM workers...")
    instance_ids = await downstream_client.wait_for_instances()
    logger.info(f"Found {len(instance_ids)} workers: {list(instance_ids)}")

    # Create KvRouter to select workers based on KV overlap
    kv_router = KvRouter(
        endpoint=downstream_endpoint,
        block_size=args.block_size,
        kv_router_config=KvRouterConfig(),
    )
    logger.info("KvRouter created successfully")

    # Initialize tokenizer and processor for MM processing
    logger.info(f"Loading tokenizer from {args.model}...")
    tokenizer = tokenizer_factory(args.model)

    # The processor is best effort: the worker still starts without it and the
    # handler is simply given ``None``.
    processor = None
    try:
        logger.info(f"Loading HuggingFace processor from {args.model}...")
        processor = AutoProcessor.from_pretrained(args.model, trust_remote_code=True)
    except Exception as e:
        logger.warning(f"Failed to load HF processor: {e}")
        logger.warning("Visual token expansion will not be available")

    # Create handler
    handler = MMRouterHandler(
        kv_router=kv_router,
        tokenizer=tokenizer,
        processor=processor,
        model=args.model,
        model_type=args.model_type,
        block_size=args.block_size,
    )

    # Register this worker's endpoint
    component = runtime.namespace(args.namespace).component(args.component)
    endpoint = component.endpoint(args.endpoint)

    # Use ModelInput.Tokens so Frontend preprocesses the request
    # Request format: {token_ids, sampling_options, stop_conditions, extra_args: {messages}}
    await register_llm(
        ModelInput.Tokens,
        ModelType.Chat,
        endpoint,
        args.model,
        kv_cache_block_size=args.block_size,
    )

    logger.info(f"MM Router Worker ready, serving {args.endpoint} endpoint...")

    # Serve the endpoint
    await endpoint.serve_endpoint(handler.generate)
def main() -> None:
    """Entry point for the MM Router Worker.

    Configures root logging at INFO level with a timestamped format, calls
    ``uvloop.install()``, and then runs ``worker()`` via ``asyncio.run``.
    """
    logging.basicConfig(
        level=logging.INFO,
        format="%(asctime)s - %(name)s - %(levelname)s - %(message)s",
    )
    # Must happen before asyncio.run so the loop created there is affected.
    uvloop.install()
    asyncio.run(worker())
if __name__ == "__main__":
main()
......@@ -435,7 +435,7 @@ class TrtllmWorker:
await asyncio.sleep(2)
try:
events = self.llm.get_kv_cache_events_async(timeout=None)
events = self.llm.get_kv_cache_events_async(timeout=5)
logger.info(f"Worker {self.worker_id}: KV events iterator obtained")
async for event in events:
......
......@@ -742,6 +742,7 @@ class KvEventPublisher:
block_hashes: List[int],
lora_id: int,
parent_hash: Optional[int] = None,
block_mm_infos: Optional[List[Optional[Dict[str, Any]]]] = None,
) -> None:
"""
Publish a KV stored event.
......@@ -754,6 +755,9 @@ class KvEventPublisher:
block_hashes: List of block hashes (signed 64-bit integers)
lora_id: The LoRA ID
parent_hash: Optional parent hash (signed 64-bit integer)
block_mm_infos: Optional list of multimodal info for each block.
Each item is either None or a dict with "mm_objects" key containing
a list of {"mm_hash": int, "offsets": [[start, end], ...]} dicts.
"""
...
......
# SPDX-FileCopyrightText: Copyright (c) 2025-2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
"""Multimodal router tests."""
# SPDX-FileCopyrightText: Copyright (c) 2025-2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
"""End-to-end tests for multimodal KV routing with TRT-LLM.
Architecture:
Frontend -> MM Router Worker -> TRT-LLM Worker
(computes mm_hash) (publishes KV events)
This test validates MM-aware routing by sending repeated multimodal requests and
asserting that router overlap is greater than 1 block (regression guard against
text-only/partially-matched hash paths that typically show 1/N overlap).
"""
from __future__ import annotations
import base64
import os
import re
import shutil
import time
from io import BytesIO
from typing import Any, Generator
import pytest
import requests
from tests.conftest import EtcdServer, NatsServer
from tests.utils.managed_process import ManagedProcess
from tests.utils.payloads import check_models_api
from tests.utils.port_utils import allocate_ports
# Model identifier passed to the TRT-LLM worker (--model-path), to the MM router
# worker (--model), and used as the "model" field of every request payload.
TRTLLM_MM_MODEL = "Qwen/Qwen2-VL-2B-Instruct"
# Value handed to the MM router worker's --model-type flag.
TRTLLM_MM_MODEL_TYPE = "qwen2_vl"
# Block size given to both the worker (--kv-block-size) and the router
# (--block-size); the two must agree for block counts to be comparable.
BLOCK_SIZE = 32
# Namespace exported as DYN_NAMESPACE and used in the worker endpoint URL.
NAMESPACE = "test-mm"
# Broad guardrails for TRT-LLM + Qwen2-VL-2B under block size 32.
# Each tuple is an inclusive (low, high) bound on the "total" block count
# parsed from the router log.
THREE_IMAGE_TOTAL_BLOCKS_RANGE = (80, 520)
SINGLE_IMAGE_TOTAL_BLOCKS_RANGE = (20, 260)
# Markers applied to every test in this module.
pytestmark = [
    pytest.mark.e2e,
    pytest.mark.trtllm,
    pytest.mark.multimodal,
    pytest.mark.gpu_1,
    pytest.mark.model(TRTLLM_MM_MODEL),
]
# RGB fill colors for generated solid-color images. _COLORS is shared by
# several tests; _ALT_COLORS supplies the "different images" probe.
_COLORS = [
    (255, 0, 0),
    (0, 255, 0),
    (0, 0, 255),
]
_ALT_COLORS = [
    (255, 255, 0),
    (0, 255, 255),
    (255, 0, 255),
]
# Colors that are each referenced by exactly one test in this module, so that
# test's first request uses images no other test here sends.
_SINGLE_IMAGE_FRESH_COLOR = (123, 45, 67)
_DOUBLE_IMAGE_FRESH_COLOR = (89, 210, 34)
_STAIRCASE_IMAGE_FRESH_COLOR = (17, 99, 201)
_SWAP_ORDER_FRESH_COLORS = [(14, 141, 77), (211, 66, 101), (44, 91, 233)]
# Contract with lib/llm/src/kv_router/push_router.rs "[ROUTING]" debug log.
# Keep this parser in sync with the router log format.
# Group 1 is the overlapping block count, group 2 the total block count.
_ROUTING_RECORD_PATTERN = re.compile(
    r"\[ROUTING\].*with\s*(\d+)/(\d+)\s*blocks overlap"
)
def _check_ready(response) -> bool:
    """Return True when a health-check response reports ``"status": "ready"``.

    Args:
        response: Object exposing a ``json()`` method that returns the decoded
            body and raises ``ValueError`` (or a subclass) when the body is
            not valid JSON.

    Returns:
        True only if the decoded body is a JSON object whose ``"status"`` value
        equals ``"ready"``; False for undecodable, empty, or non-object bodies.
    """
    try:
        payload = response.json()
    except ValueError:
        # Body was not valid JSON; treat the service as not ready yet.
        return False
    # A syntactically valid body may still be a list, string, number or null;
    # only a JSON object can carry a "status" key.
    if not isinstance(payload, dict):
        return False
    return payload.get("status") == "ready"
def _make_process_env(log_level: str = "debug", **extra) -> dict[str, str]:
    """Build the environment mapping for a managed child process.

    Starts from a snapshot of the current process environment, then applies
    the Dynamo defaults (``DYN_LOG``, ``DYN_NAMESPACE``, ``DYN_REQUEST_PLANE``)
    and finally any caller-supplied ``extra`` variables.

    Args:
        log_level: Value stored under ``DYN_LOG``.
        **extra: Additional variables; these win over every default above.

    Returns:
        A new dict; ``os.environ`` itself is not modified.
    """
    dynamo_defaults = {
        "DYN_LOG": log_level,
        "DYN_NAMESPACE": NAMESPACE,
        "DYN_REQUEST_PLANE": "nats",
    }
    # Later mappings override earlier ones: inherited env < defaults < extra.
    return {**os.environ, **dynamo_defaults, **extra}
def _prepare_log_dir(request, suffix: str) -> str:
    """Return a per-test log directory name after clearing any stale copy.

    Args:
        request: Object whose ``node.name`` identifies the requesting test
            (a pytest ``request`` fixture in this module).
        suffix: Label appended to the test name to tell processes apart.

    Returns:
        The relative directory name ``"<node name>_<suffix>"``. The directory
        is removed if it exists; it is not created here.
    """
    target = "{}_{}".format(request.node.name, suffix)
    # ignore_errors=True makes a missing directory a no-op.
    shutil.rmtree(target, ignore_errors=True)
    return target
# Keyword arguments splatted into every ManagedProcess constructor call in
# this module, so all three processes share the same output/cleanup settings.
_COMMON_PROCESS_KWARGS: dict[str, Any] = {
    "display_output": True,
    "terminate_all_matching_process_names": False,
}
class TRTLLMWorkerProcess(ManagedProcess):
    """Managed ``dynamo.trtllm`` multimodal worker launched with event/metric publishing enabled."""

    def __init__(self, request, *, system_port: int):
        """Assemble the worker command line and hand it to ``ManagedProcess``.

        Args:
            request: Pytest request object, used to derive the log directory.
            system_port: Port exported as ``DYN_SYSTEM_PORT`` and polled at
                ``/health`` for readiness.
        """
        launch_cmd = ["python3", "-m", "dynamo.trtllm"]
        launch_cmd += ["--model-path", TRTLLM_MM_MODEL]
        # The worker is served under a distinct "__internal" model name.
        launch_cmd += ["--served-model-name", f"{TRTLLM_MM_MODEL}__internal"]
        launch_cmd += ["--endpoint", f"dyn://{NAMESPACE}.trtllm.generate"]
        launch_cmd += ["--modality", "multimodal"]
        launch_cmd.append("--publish-events-and-metrics")
        launch_cmd += ["--kv-block-size", str(BLOCK_SIZE)]

        worker_env = _make_process_env(DYN_SYSTEM_PORT=str(system_port))
        health_url = f"http://localhost:{system_port}/health"

        super().__init__(
            command=launch_cmd,
            env=worker_env,
            health_check_urls=[(health_url, _check_ready)],
            timeout=900,
            straggler_commands=["-m dynamo.trtllm"],
            log_dir=_prepare_log_dir(request, "trtllm-worker"),
            **_COMMON_PROCESS_KWARGS,
        )
class TRTLLMMMRouterWorkerProcess(ManagedProcess):
    """Managed MM router worker (``examples.backends.trtllm.mm_router_worker``)."""

    def __init__(self, request, *, system_port: int):
        """Assemble the router command line and hand it to ``ManagedProcess``.

        Args:
            request: Pytest request object, used to derive the log directory.
            system_port: Port exported as ``DYN_SYSTEM_PORT`` and polled at
                ``/health`` for readiness.
        """
        # Insertion order is preserved, so flags appear on the command line
        # in exactly the order listed here.
        flag_values = {
            "--model": TRTLLM_MM_MODEL,
            "--model-type": TRTLLM_MM_MODEL_TYPE,
            "--namespace": NAMESPACE,
            "--component": "mm_router",
            "--endpoint": "generate",
            "--downstream-component": "trtllm",
            "--downstream-endpoint": "generate",
            "--block-size": str(BLOCK_SIZE),
        }
        launch_cmd = ["python3", "-m", "examples.backends.trtllm.mm_router_worker"]
        for flag, value in flag_values.items():
            launch_cmd.extend((flag, value))

        router_env = _make_process_env(
            DYN_SYSTEM_USE_ENDPOINT_HEALTH_STATUS='["generate"]',
            DYN_SYSTEM_PORT=str(system_port),
        )
        health_url = f"http://localhost:{system_port}/health"

        super().__init__(
            command=launch_cmd,
            env=router_env,
            health_check_urls=[(health_url, _check_ready)],
            timeout=240,
            straggler_commands=["mm_router_worker"],
            log_dir=_prepare_log_dir(request, "trtllm-mm-router"),
            **_COMMON_PROCESS_KWARGS,
        )
class FrontendProcess(ManagedProcess):
    """Managed ``dynamo.frontend`` HTTP ingress running in round-robin router mode."""

    def __init__(self, request, *, frontend_port: int):
        """Assemble the frontend command line and hand it to ``ManagedProcess``.

        Args:
            request: Pytest request object, used to derive the log directory.
            frontend_port: HTTP port the frontend listens on; readiness is
                checked through ``/v1/models`` on that port.
        """
        launch_cmd = ["python3", "-m", "dynamo.frontend"]
        launch_cmd += ["--http-port", str(frontend_port)]
        launch_cmd += ["--router-mode", "round-robin"]

        # Unlike the two workers, the frontend logs at "info" level.
        frontend_env = _make_process_env(log_level="info")
        models_url = f"http://localhost:{frontend_port}/v1/models"

        super().__init__(
            command=launch_cmd,
            env=frontend_env,
            health_check_urls=[(models_url, check_models_api)],
            timeout=240,
            straggler_commands=["-m dynamo.frontend"],
            log_dir=_prepare_log_dir(request, "trtllm-mm-frontend"),
            **_COMMON_PROCESS_KWARGS,
        )
@pytest.fixture(scope="module")
def mm_runtime_services(request):
    """Start NATS and etcd for this module and point the environment at them.

    While the fixture is active, ``NATS_SERVER`` and ``ETCD_ENDPOINTS`` refer to
    the servers started here. On teardown both variables are put back to the
    values they had before the fixture ran (or removed if they were unset), so
    later test modules are not affected by this module's setup.
    """
    managed_vars = ("NATS_SERVER", "ETCD_ENDPOINTS")
    with NatsServer(request, port=0) as nats, EtcdServer(request, port=0) as etcd:
        # Remember pre-existing values so teardown restores rather than deletes.
        previous = {name: os.environ.get(name) for name in managed_vars}
        os.environ["NATS_SERVER"] = f"nats://localhost:{nats.port}"
        os.environ["ETCD_ENDPOINTS"] = f"http://localhost:{etcd.port}"
        try:
            yield
        finally:
            # Runs even if the generator is closed or an exception is thrown
            # into it, so the environment never keeps stale endpoints.
            for name, old_value in previous.items():
                if old_value is None:
                    os.environ.pop(name, None)
                else:
                    os.environ[name] = old_value
@pytest.fixture(scope="module")
def start_trtllm_mm_services(
    request, mm_runtime_services
) -> Generator[tuple[int, ManagedProcess], None, None]:
    """Bring up worker, MM router and frontend (in that order) for the module.

    Yields:
        ``(frontend_port, router_proc)``: the HTTP port to send requests to and
        the router process whose logs the tests inspect.
    """
    ports = allocate_ports(count=3, start_port=10000)
    frontend_port, trtllm_port, router_port = ports

    with TRTLLMWorkerProcess(request, system_port=trtllm_port):
        # Fixed pause before starting the router; presumably lets the worker
        # finish registering -- TODO confirm whether this is still needed.
        time.sleep(15)
        router = TRTLLMMMRouterWorkerProcess(request, system_port=router_port)
        with router as router_proc:
            # Fixed pause before starting the frontend (same caveat as above).
            time.sleep(5)
            frontend = FrontendProcess(request, frontend_port=frontend_port)
            with frontend:
                yield frontend_port, router_proc
def _make_data_uri(color: tuple[int, int, int], size: int = 1024) -> str:
    """Return a ``data:image/png;base64,...`` URI for a solid-color square image.

    Args:
        color: RGB fill color.
        size: Width and height of the image in pixels.

    Returns:
        The PNG-encoded image as a base64 data URI string.
    """
    # Imported lazily so that merely importing this module does not need Pillow.
    from PIL import Image

    png_buffer = BytesIO()
    Image.new("RGB", (size, size), color).save(png_buffer, format="PNG")
    encoded = base64.b64encode(png_buffer.getvalue()).decode("utf-8")
    return "data:image/png;base64," + encoded
def _build_payload(
    image_uris: list[str], prompt: str = "Describe what you see."
) -> dict[str, Any]:
    """Build a chat-completions request body with one text part and N images.

    Args:
        image_uris: Image URLs (data URIs in this module), kept in order.
        prompt: Text placed as the first content part of the user message.

    Returns:
        Request dict with a single user message and ``max_tokens`` set to 1.
    """
    text_part = {"type": "text", "text": prompt}
    image_parts = [
        {"type": "image_url", "image_url": {"url": uri}} for uri in image_uris
    ]
    # The text part always comes first, followed by images in the given order.
    user_message = {"role": "user", "content": [text_part, *image_parts]}
    return {
        "model": TRTLLM_MM_MODEL,
        "messages": [user_message],
        "max_tokens": 1,
    }
def _extract_routing_records(log_text: str) -> list[tuple[int, int]]:
    """Parse every routing record in ``log_text`` into ``(overlap, total)`` ints.

    Records are returned in the order they appear in the text; text with no
    match for ``_ROUTING_RECORD_PATTERN`` yields an empty list.
    """
    records = []
    for match in _ROUTING_RECORD_PATTERN.finditer(log_text):
        overlap_text, total_text = match.groups()
        records.append((int(overlap_text), int(total_text)))
    return records
def _wait_for_new_routing_score(
    router_proc: ManagedProcess,
    start_offset: int,
    pre_request_routing_count: int,
    timeout_s: float = 120.0,
) -> tuple[int, int, str]:
    """Poll the router logs until a routing record newer than the request appears.

    Args:
        router_proc: Process whose ``read_logs()`` returns the full log text.
        start_offset: Length of the log text captured before the request was
            sent; the returned segment starts at this offset.
        pre_request_routing_count: Number of routing records present before
            the request was sent.
        timeout_s: Maximum time to keep polling, in seconds.

    Returns:
        ``(overlap, total, segment)`` where ``overlap``/``total`` come from the
        most recent routing record and ``segment`` is the log text written
        since ``start_offset``. If no new record shows up before the timeout,
        the last record found in the segment is used, or ``(0, 0, segment)``
        when the segment has none.
    """
    # Use the monotonic clock so wall-clock adjustments (NTP, manual changes)
    # cannot shorten or stretch the timeout.
    deadline = time.monotonic() + timeout_s
    last_segment = ""
    while time.monotonic() < deadline:
        full_logs = router_proc.read_logs()
        last_segment = full_logs[start_offset:]
        # Count over the full log, since the baseline count was taken that way.
        records = _extract_routing_records(full_logs)
        if len(records) > pre_request_routing_count:
            overlap, total = records[-1]
            return overlap, total, last_segment
        time.sleep(1)

    # Timed out: fall back to whatever the post-request segment contains.
    fallback_records = _extract_routing_records(last_segment)
    if fallback_records:
        overlap, total = fallback_records[-1]
        return overlap, total, last_segment
    return 0, 0, last_segment
def _send_request_get_overlap(
    frontend_port: int,
    router_proc: ManagedProcess,
    payload: dict[str, Any],
    label: str,
) -> tuple[int, int, str]:
    """Post one chat-completions request and return the routing score it produced.

    Args:
        frontend_port: Port of the frontend HTTP server on localhost.
        router_proc: Router process whose logs are scanned for routing records.
        payload: JSON body to post.
        label: Tag used in the progress line printed for this request.

    Returns:
        ``(overlap, total, segment)`` as reported by
        ``_wait_for_new_routing_score`` for log output written after the
        snapshot taken just before the request.
    """
    # Snapshot the log state first so only records written afterwards count.
    logs_before = router_proc.read_logs()
    records_before = _extract_routing_records(logs_before)

    url = f"http://localhost:{frontend_port}/v1/chat/completions"
    resp = requests.post(url, json=payload, timeout=240)
    assert resp.status_code == 200, f"HTTP {resp.status_code}: {resp.text}"
    data = resp.json()
    assert "choices" in data, f"Missing choices in response: {data}"

    overlap, total, segment = _wait_for_new_routing_score(
        router_proc=router_proc,
        start_offset=len(logs_before),
        pre_request_routing_count=len(records_before),
        timeout_s=120,
    )
    print(f"[MM_ROUTER_E2E] {label}: current={overlap}/{total}")
    return overlap, total, segment
@pytest.mark.timeout(1800)
@pytest.mark.nightly
def test_trtllm_text_only_overlap_repeated_prompt(
    start_trtllm_mm_services, predownload_models
):
    """Send the same long text-only prompt three times.

    Expects non-zero, stable totals; a higher overlap on the second request
    than on the first; and an unchanged overlap on the third.
    """
    frontend_port, router_proc = start_trtllm_mm_services
    sentence = (
        "TEXT routing e2e unique case zeta-7f31. "
        "Repeat this sentence to force multiple KV blocks. "
    )
    payload = _build_payload([], prompt=sentence * 80)

    # Requests are issued strictly in order: req1, req2, req3.
    runs = [
        _send_request_get_overlap(
            frontend_port, router_proc, payload, f"text_only_req{attempt}"
        )
        for attempt in (1, 2, 3)
    ]
    overlap_1, total_1, _ = runs[0]
    overlap_2, total_2, _ = runs[1]
    overlap_3, total_3, segment_3 = runs[2]
    log_tail = f"Recent router logs:\n{segment_3[-4000:]}"

    assert min(total_1, total_2, total_3) > 0, (
        f"Expected non-zero total blocks for text-only request, got "
        f"{total_1}, {total_2}, {total_3}.\n{log_tail}"
    )
    totals_stable = abs(total_1 - total_2) <= 2 and abs(total_2 - total_3) <= 2
    assert totals_stable, (
        f"Expected text-only total blocks to remain stable across repeats, got "
        f"req1={total_1}, req2={total_2}, req3={total_3}.\n{log_tail}"
    )
    assert overlap_2 > overlap_1, (
        f"Expected second text-only overlap > first, got "
        f"req1={overlap_1}/{total_1}, req2={overlap_2}/{total_2}.\n{log_tail}"
    )
    assert overlap_3 == overlap_2, (
        f"Expected third text-only overlap == second, got "
        f"req2={overlap_2}/{total_2}, req3={overlap_3}/{total_3}.\n{log_tail}"
    )
@pytest.mark.timeout(1800)
@pytest.mark.nightly
def test_trtllm_mm_overlap_repeated_three_images(
    start_trtllm_mm_services, predownload_models
):
    """Send the same 3-image request three times.

    Expects at most 1 block of overlap on the first request, a higher overlap
    on the second, the same overlap on the third, and a total block count
    inside ``THREE_IMAGE_TOTAL_BLOCKS_RANGE``.
    """
    frontend_port, router_proc = start_trtllm_mm_services
    payload = _build_payload(
        [_make_data_uri(color) for color in _COLORS],
        prompt="MM routing e2e: repeated same 3-image request.",
    )

    # Requests are issued strictly in order: req1, req2, req3.
    runs = [
        _send_request_get_overlap(
            frontend_port, router_proc, payload, f"same_3_images_req{attempt}"
        )
        for attempt in (1, 2, 3)
    ]
    overlap_1, total_1, _ = runs[0]
    overlap_2, total_2, _ = runs[1]
    overlap_3, total_3, segment_3 = runs[2]
    log_tail = f"Recent router logs:\n{segment_3[-4000:]}"

    assert overlap_1 <= 1, (
        f"Expected first overlap <=1, got req1={overlap_1}/{total_1}.\n{log_tail}"
    )
    assert overlap_2 > overlap_1, (
        f"Expected second overlap > first, got "
        f"req1={overlap_1}/{total_1}, req2={overlap_2}/{total_2}.\n{log_tail}"
    )
    assert overlap_3 == overlap_2, (
        f"Expected third overlap == second, got "
        f"req2={overlap_2}/{total_2}, req3={overlap_3}/{total_3}.\n{log_tail}"
    )

    min_blocks, max_blocks = THREE_IMAGE_TOTAL_BLOCKS_RANGE
    assert min_blocks <= total_3 <= max_blocks, (
        f"Unexpected total blocks for same 3 images (1024): "
        f"got {total_3}, expected in [{min_blocks}, {max_blocks}]"
    )
@pytest.mark.timeout(1800)
@pytest.mark.nightly
def test_trtllm_mm_overlap_repeated_single_image(
    start_trtllm_mm_services, predownload_models
):
    """Send the same single-image request three times.

    Expects at most 1 block of overlap on the first request, a higher overlap
    on the second, the same overlap on the third, and a total block count
    inside ``SINGLE_IMAGE_TOTAL_BLOCKS_RANGE``.
    """
    frontend_port, router_proc = start_trtllm_mm_services
    image_uri = _make_data_uri(_SINGLE_IMAGE_FRESH_COLOR)
    payload = _build_payload(
        [image_uri],
        prompt="MM routing e2e: repeated same single-image request.",
    )

    # Requests are issued strictly in order: req1, req2, req3.
    runs = [
        _send_request_get_overlap(
            frontend_port, router_proc, payload, f"same_single_image_req{attempt}"
        )
        for attempt in (1, 2, 3)
    ]
    overlap_1, total_1, _ = runs[0]
    overlap_2, total_2, _ = runs[1]
    overlap_3, total_3, segment_3 = runs[2]
    log_tail = f"Recent router logs:\n{segment_3[-4000:]}"

    assert overlap_1 <= 1, (
        f"Expected first overlap <=1, got req1={overlap_1}/{total_1}.\n{log_tail}"
    )
    assert overlap_2 > overlap_1, (
        f"Expected second overlap > first, got "
        f"req1={overlap_1}/{total_1}, req2={overlap_2}/{total_2}.\n{log_tail}"
    )
    assert overlap_3 == overlap_2, (
        f"Expected third overlap == second, got "
        f"req2={overlap_2}/{total_2}, req3={overlap_3}/{total_3}.\n{log_tail}"
    )

    min_blocks, max_blocks = SINGLE_IMAGE_TOTAL_BLOCKS_RANGE
    assert min_blocks <= total_3 <= max_blocks, (
        f"Unexpected total blocks for same 1 image (1024): "
        f"got {total_3}, expected in [{min_blocks}, {max_blocks}]"
    )
@pytest.mark.timeout(1800)
@pytest.mark.nightly
def test_trtllm_mm_overlap_repeated_two_identical_images(
    start_trtllm_mm_services, predownload_models
):
    """Send a request containing the same image twice, three times in a row.

    Expects at most 1 block of overlap on the first request, a higher overlap
    on the second, and the same overlap on the third.
    """
    frontend_port, router_proc = start_trtllm_mm_services
    image_uri = _make_data_uri(_DOUBLE_IMAGE_FRESH_COLOR)
    payload = _build_payload(
        [image_uri] * 2,
        prompt="MM routing e2e: repeated same two-identical-image request.",
    )

    # Requests are issued strictly in order: req1, req2, req3.
    runs = [
        _send_request_get_overlap(
            frontend_port,
            router_proc,
            payload,
            f"same_two_identical_images_req{attempt}",
        )
        for attempt in (1, 2, 3)
    ]
    overlap_1, total_1, _ = runs[0]
    overlap_2, total_2, _ = runs[1]
    overlap_3, total_3, segment_3 = runs[2]
    log_tail = f"Recent router logs:\n{segment_3[-4000:]}"

    assert overlap_1 <= 1, (
        f"Expected first overlap <=1, got req1={overlap_1}/{total_1}.\n{log_tail}"
    )
    assert overlap_2 > overlap_1, (
        f"Expected second overlap > first, got "
        f"req1={overlap_1}/{total_1}, req2={overlap_2}/{total_2}.\n{log_tail}"
    )
    assert overlap_3 == overlap_2, (
        f"Expected third overlap == second, got "
        f"req2={overlap_2}/{total_2}, req3={overlap_3}/{total_3}.\n{log_tail}"
    )
@pytest.mark.timeout(1800)
@pytest.mark.nightly
def test_trtllm_mm_overlap_staircase_single_to_double_to_triple_identical_image(
    start_trtllm_mm_services, predownload_models
):
    """Send requests with 1, then 2, then 3 copies of one image under one prompt.

    Checks that overlap rises from the 1-image to the 2-image request, that
    the 3-image overlap stays within 1 block of the 2-image overlap, and that
    each extra image adds a similar number of total blocks.
    """
    frontend_port, router_proc = start_trtllm_mm_services
    image_uri = _make_data_uri(_STAIRCASE_IMAGE_FRESH_COLOR)
    staircase_prompt = "MM routing e2e: staircase."
    # All three payloads are built up front, before any request is sent.
    payload_single, payload_double, payload_triple = (
        _build_payload([image_uri] * copies, prompt=staircase_prompt)
        for copies in (1, 2, 3)
    )

    overlap_1, total_1, _ = _send_request_get_overlap(
        frontend_port, router_proc, payload_single, "staircase_1x_image"
    )
    overlap_2, total_2, segment_2 = _send_request_get_overlap(
        frontend_port, router_proc, payload_double, "staircase_2x_image"
    )
    # Short fixed pause before the 3-image request (kept from the original
    # sequence; purpose not documented -- TODO confirm).
    time.sleep(1)
    overlap_3, total_3, segment_3 = _send_request_get_overlap(
        frontend_port, router_proc, payload_triple, "staircase_3x_image"
    )
    tail_after_double = f"Recent router logs:\n{segment_2[-4000:]}"
    tail_after_triple = f"Recent router logs:\n{segment_3[-4000:]}"

    assert overlap_2 > overlap_1, (
        f"Expected overlap to increase from 1 image to 2 images, got "
        f"1x={overlap_1}/{total_1}, 2x={overlap_2}/{total_2}.\n"
        f"{tail_after_double}"
    )
    assert abs(overlap_3 - overlap_2) <= 1, (
        "Expected first 3-image request overlap to stay near 2-image overlap "
        "(third-image suffix is cold on first 3-image request), got "
        f"2x={overlap_2}/{total_2}, 3x={overlap_3}/{total_3}.\n"
        f"{tail_after_triple}"
    )

    total_step_12 = total_2 - total_1
    total_step_23 = total_3 - total_2
    assert abs(total_step_12 - total_step_23) <= 4, (
        "Expected similar total-block increment per additional identical image, got "
        f"step(1->2)={total_step_12}, step(2->3)={total_step_23}.\n"
        f"{tail_after_triple}"
    )
@pytest.mark.timeout(1800)
@pytest.mark.nightly
def test_trtllm_mm_overlap_diff_images_less_than_same(
    start_trtllm_mm_services, predownload_models
):
    """Compare a repeated same-images request against a different-images probe.

    The baseline (``_COLORS``, sent twice) must reach at least 2 blocks of
    overlap; the probe (``_ALT_COLORS``, same prompt) must have a similar total
    block count but strictly lower overlap.
    """
    frontend_port, router_proc = start_trtllm_mm_services
    shared_prompt = "MM routing e2e: baseline same-images overlap."

    baseline_payload = _build_payload(
        [_make_data_uri(color) for color in _COLORS],
        prompt=shared_prompt,
    )
    overlap_baseline_1, total_baseline_1, _ = _send_request_get_overlap(
        frontend_port, router_proc, baseline_payload, "baseline_same_images_req1"
    )
    overlap_baseline_2, total_baseline_2, segment_baseline = _send_request_get_overlap(
        frontend_port, router_proc, baseline_payload, "baseline_same_images_req2"
    )
    # Take the better of the two baseline requests as the reference overlap.
    overlap_baseline = max(overlap_baseline_1, overlap_baseline_2)
    total_baseline = total_baseline_2

    assert abs(total_baseline_1 - total_baseline_2) <= 4, (
        "Expected total blocks to stay nearly identical for repeated same request, "
        f"got req1={total_baseline_1}, req2={total_baseline_2}"
    )
    assert overlap_baseline >= 2, (
        f"Baseline overlap did not reach 2 blocks. "
        f"got {overlap_baseline}/{total_baseline}.\n"
        f"Recent router logs:\n{segment_baseline[-4000:]}"
    )
    min_blocks, max_blocks = THREE_IMAGE_TOTAL_BLOCKS_RANGE
    assert min_blocks <= total_baseline <= max_blocks, (
        f"Unexpected total blocks for baseline same-images request: "
        f"got {total_baseline}, expected in [{min_blocks}, {max_blocks}]"
    )

    # Probe: identical prompt text, but a different set of three images.
    probe_payload = _build_payload(
        [_make_data_uri(color) for color in _ALT_COLORS],
        prompt=shared_prompt,
    )
    overlap_probe, total_probe, segment_probe = _send_request_get_overlap(
        frontend_port, router_proc, probe_payload, "probe_different_images_req1"
    )

    assert (
        total_probe > 0
    ), f"No routing score found.\nRecent logs:\n{segment_probe[-4000:]}"
    assert abs(total_probe - total_baseline) <= 4, (
        f"Expected different-images total blocks to stay near baseline, "
        f"got different={total_probe}, baseline={total_baseline}"
    )
    assert overlap_probe < overlap_baseline, (
        f"Expected different-images overlap < baseline overlap, "
        f"got different={overlap_probe}/{total_probe}, "
        f"baseline={overlap_baseline}/{total_baseline}.\n"
        f"Recent router logs:\n{segment_probe[-4000:]}"
    )
@pytest.mark.timeout(1800)
@pytest.mark.nightly
def test_trtllm_mm_overlap_same_images_different_prompt_less_than_same_prompt(
    start_trtllm_mm_services, predownload_models
):
    """Compare a repeated request against a probe with the same images but a new prompt.

    The baseline (prompt "alpha", sent twice) must reach at least 2 blocks of
    overlap; the probe (prompt "omega", same images) must have a similar total
    block count but strictly lower overlap.
    """
    frontend_port, router_proc = start_trtllm_mm_services
    # Both payloads carry the same three images, regenerated from _COLORS.
    baseline_payload = _build_payload(
        [_make_data_uri(color) for color in _COLORS],
        prompt="MM routing e2e: prompt-sensitive baseline alpha.",
    )

    overlap_baseline_1, total_baseline_1, _ = _send_request_get_overlap(
        frontend_port,
        router_proc,
        baseline_payload,
        "baseline_same_images_prompt_a_req1",
    )
    overlap_baseline_2, total_baseline_2, segment_baseline = _send_request_get_overlap(
        frontend_port,
        router_proc,
        baseline_payload,
        "baseline_same_images_prompt_a_req2",
    )
    # Take the better of the two baseline requests as the reference overlap.
    overlap_baseline = max(overlap_baseline_1, overlap_baseline_2)
    total_baseline = total_baseline_2

    assert abs(total_baseline_1 - total_baseline_2) <= 4, (
        "Expected total blocks to stay nearly identical for repeated same request, "
        f"got req1={total_baseline_1}, req2={total_baseline_2}"
    )
    assert overlap_baseline >= 2, (
        f"Baseline overlap did not reach 2 blocks. "
        f"got {overlap_baseline}/{total_baseline}.\n"
        f"Recent router logs:\n{segment_baseline[-4000:]}"
    )
    min_blocks, max_blocks = THREE_IMAGE_TOTAL_BLOCKS_RANGE
    assert min_blocks <= total_baseline <= max_blocks, (
        f"Unexpected total blocks for baseline same-images request: "
        f"got {total_baseline}, expected in [{min_blocks}, {max_blocks}]"
    )

    # Probe: same images, only the prompt text differs.
    probe_payload = _build_payload(
        [_make_data_uri(color) for color in _COLORS],
        prompt="MM routing e2e: prompt-sensitive baseline omega.",
    )
    overlap_probe, total_probe, segment_probe = _send_request_get_overlap(
        frontend_port, router_proc, probe_payload, "probe_same_images_prompt_b_req1"
    )

    assert (
        total_probe > 0
    ), f"No routing score found.\nRecent logs:\n{segment_probe[-4000:]}"
    assert abs(total_probe - total_baseline) <= 4, (
        f"Expected different-prompt total blocks to stay near baseline, "
        f"got different_prompt={total_probe}, baseline={total_baseline}"
    )
    assert overlap_probe < overlap_baseline, (
        f"Expected different-prompt overlap < baseline overlap, "
        f"got different_prompt={overlap_probe}/{total_probe}, "
        f"baseline={overlap_baseline}/{total_baseline}.\n"
        f"Recent router logs:\n{segment_probe[-4000:]}"
    )
@pytest.mark.timeout(1800)
@pytest.mark.nightly
def test_trtllm_mm_overlap_swapped_order_less_than_same_order(
    start_trtllm_mm_services, predownload_models
):
    """Send three distinct images twice in one order, then once reversed.

    Expects overlap to rise on the repeated ordered request, the reversed
    request to have a similar total block count, and its overlap to be at
    most 1 block.
    """
    frontend_port, router_proc = start_trtllm_mm_services
    shared_prompt = "MM routing e2e: order sensitivity ordered baseline."
    ordered_uris = [_make_data_uri(color) for color in _SWAP_ORDER_FRESH_COLORS]
    ordered_payload = _build_payload(ordered_uris, prompt=shared_prompt)
    # Same prompt and same images; only the image order is reversed.
    swapped_payload = _build_payload(ordered_uris[::-1], prompt=shared_prompt)

    overlap_ordered_1, total_ordered_1, _ = _send_request_get_overlap(
        frontend_port,
        router_proc,
        ordered_payload,
        "ordered_distinct_images_req1",
    )
    overlap_ordered_2, total_ordered_2, segment_ordered_2 = _send_request_get_overlap(
        frontend_port,
        router_proc,
        ordered_payload,
        "ordered_distinct_images_req2",
    )
    overlap_swapped, total_swapped, segment_swapped = _send_request_get_overlap(
        frontend_port,
        router_proc,
        swapped_payload,
        "swapped_distinct_images_req1",
    )

    assert overlap_ordered_2 > overlap_ordered_1, (
        "Expected repeated identical order to increase overlap before swapped-order "
        "probe, "
        f"got req1={overlap_ordered_1}/{total_ordered_1}, "
        f"req2={overlap_ordered_2}/{total_ordered_2}.\n"
        f"Recent router logs:\n{segment_ordered_2[-4000:]}"
    )
    assert abs(total_swapped - total_ordered_2) <= 4, (
        f"Expected swapped-order total blocks to stay near ordered baseline, "
        f"got swapped={total_swapped}, ordered={total_ordered_2}"
    )
    assert overlap_swapped <= 1, (
        "Expected near-zero overlap for swapped order of three distinct images "
        f"(allowing 1 shared text block), got {overlap_swapped}/{total_swapped}.\n"
        f"Recent router logs:\n{segment_swapped[-4000:]}"
    )
# Allow running this test module directly as a script: -v enables verbose
# output and -s disables pytest's output capture so prints are shown.
if __name__ == "__main__":
    pytest.main([__file__, "-v", "-s"])
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment