Unverified Commit 23de4e86 authored by zhongdaor-nv's avatar zhongdaor-nv Committed by GitHub
Browse files

feat: e2e mm aware kv cache routing support for trtllm backend (#5480)


Signed-off-by: default avatarzhongdaor <zhongdaor@nvidia.com>
Signed-off-by: default avatarzhongdaor-nv <zhongdaor@nvidia.com>
Signed-off-by: default avatarZhongdao Ren <zhongdaor@zhongdaor-mlt.client.nvidia.com>
Co-authored-by: default avatarClaude Sonnet 4.5 <noreply@anthropic.com>
Co-authored-by: default avatarZhongdao Ren <zhongdaor@zhongdaor-mlt.client.nvidia.com>
parent 8a098a66
......@@ -118,6 +118,7 @@ class ZmqKvEventPublisher:
block_hashes: list[int],
lora_id: int = 0,
parent_hash: Optional[int] = None,
block_mm_infos: Optional[list[dict | None]] = None,
attention_dp_rank: int = 0,
):
"""Publish a BlockStored event.
......@@ -141,6 +142,10 @@ class ZmqKvEventPublisher:
"lora_id": lora_id if lora_id != 0 else None,
}
# Add multimodal info if present
if block_mm_infos is not None:
event["block_mm_infos"] = block_mm_infos
self._publish_event(event, attention_dp_rank)
def publish_removed(self, block_hashes: list[int], attention_dp_rank: int = 0):
......@@ -537,6 +542,7 @@ class Publisher:
token_ids: list[int] = []
num_block_tokens: list[int] = []
block_hashes: list[int] = []
block_mm_infos: list[dict | None] = []
for block in data["blocks"]:
token_num_in_block = len(block["tokens"])
block_hash = _to_signed_i64(block["block_hash"])
......@@ -561,6 +567,26 @@ class Publisher:
for token in block["tokens"]:
token_ids.append(int(token["token_id"]))
# Extract multimodal hash info for this block
# {"mm_keys": [{"type":"mm_key","hash":"<hex>","start_offset":N}]}
mm_keys = block.get("mm_keys", [])
mm_hashes = [
int(mm_key["hash"][:16], 16)
for mm_key in mm_keys
if mm_key.get("type") == "mm_key" and mm_key.get("hash")
]
if mm_hashes:
block_mm_infos.append(
{
"mm_objects": [
{"mm_hash": mm_hash, "offsets": []}
for mm_hash in mm_hashes
]
}
)
else:
block_mm_infos.append(None)
# Note: Currently data does not have lora_id.
# Using 0 as default value. If later data has
# lora_id, we need to verify if this is correct.
......@@ -583,6 +609,7 @@ class Publisher:
block_hashes,
lora_id,
parent_hash,
block_mm_infos,
attention_dp_rank,
)
elif self.kv_event_publishers:
......@@ -596,6 +623,7 @@ class Publisher:
block_hashes,
lora_id,
parent_hash,
block_mm_infos,
)
else:
logging.warning(
......
<!--
SPDX-FileCopyrightText: Copyright (c) 2025-2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
SPDX-License-Identifier: Apache-2.0
-->
# MM Router Worker
Multimodal-aware KV cache routing worker for TRT-LLM backends.
## Overview
This worker sits between the Dynamo frontend and TRT-LLM workers, providing MM-aware KV cache routing:
1. **Receives** OpenAI-format requests from the frontend
2. **Downloads** images and computes `mm_hash` (for routing decision only)
3. **Builds** multimodal routing metadata (`mm_routing_info`)
4. **Uses** KvRouter to select and route to the best TRT-LLM worker
5. **Streams** responses back to the frontend
## Architecture
```
Frontend (standard) MM Router Worker (this) TRT-LLM Worker (standard)
┌──────────────┐ ┌─────────────────────┐ ┌───────────────────┐
│ │───────>│ 1. Download images │───────>│ python -m │
│ round-robin │ │ 2. Compute mm_hash │ │ dynamo.trtllm │
│ to mm_router│<───────│ 3. Build routing │<───────│ --modality mm │
└──────────────┘ │ 4. KvRouter route │ │ (processes images)│
└─────────────────────┘ └───────────────────┘
│ Subscribe KV events
v
┌──────────┐
│ NATS │
└──────────┘
```
**Note**: Images are downloaded twice - once in MM Router (for mm_hash computation) and once in TRT-LLM worker (for actual processing). This simplifies the design by avoiding tensor serialization.
## Usage
### Quick Start
```bash
# Start all services
./launch.sh
```
### Manual Start
```bash
# 1. Start etcd and NATS
docker compose -f deploy/docker-compose.yml up -d
# 2. Start TRT-LLM worker(s)
python -m dynamo.trtllm \
--model Qwen/Qwen2-VL-2B-Instruct \
--namespace default \
--component trtllm \
--endpoint generate \
--modality multimodal \
--publish-events-and-metrics &
# 3. Start MM Router Worker
python -m examples.backends.trtllm.mm_router_worker \
--model Qwen/Qwen2-VL-2B-Instruct \
--model-type qwen2_vl \
--namespace default \
--component mm_router \
--endpoint generate \
--downstream-component trtllm \
--downstream-endpoint generate &
# 4. Start Frontend
python -m dynamo.frontend \
--http-port 8000 \
--router-mode round-robin
```
### Test Request
```bash
curl http://localhost:8000/v1/chat/completions \
-H "Content-Type: application/json" \
-d '{
"model": "Qwen/Qwen2-VL-2B-Instruct",
"messages": [{
"role": "user",
"content": [
{"type": "text", "text": "Describe this image"},
{"type": "image_url", "image_url": {"url": "https://example.com/image.jpg"}}
]
}],
"max_tokens": 100
}'
```
## Configuration
| Argument | Default | Description |
|----------|---------|-------------|
| `--model` | `Qwen/Qwen2-VL-2B-Instruct` | Model path or HuggingFace ID |
| `--model-type` | `qwen2_vl` | TRT-LLM model type for multimodal loader |
| `--block-size` | `32` | KV cache block size |
| `--namespace` | `default` | Dynamo namespace |
| `--component` | `mm_router` | This worker's component name |
| `--endpoint` | `generate` | This worker's endpoint name |
| `--downstream-component` | `trtllm` | TRT-LLM workers' component name |
| `--downstream-endpoint` | `generate` | TRT-LLM workers' endpoint name |
## How It Works
### MM Hash Computation
The worker uses TRT-LLM's `apply_mm_hashes()` function to compute a hash of each image's tensor representation. This hash is included in the block hash computation, ensuring that:
- Same image = Same mm_hash = Same block hashes = Cache hit
- Different image = Different mm_hash = Different block hashes = No false cache hit
### KV-Aware Routing
The worker uses `KvRouter.generate(...)` with explicit multimodal routing hints.
When a request comes in:
1. Build routing tokens (`routing_token_ids`) for the request
2. Build `block_mm_infos` with per-block image `mm_hash` metadata
3. Pass both as `mm_routing_info` to `KvRouter.generate(...)`
4. KvRouter computes overlap internally and routes to the best worker
### Block MM Info Structure
For each block that contains image tokens, we build `block_mm_infos`:
```python
block_mm_infos = [
None, # Block 0: no image
{"mm_objects": [{"mm_hash": 12345, "offsets": [[32, 128]]}]}, # Block 1: has image
{"mm_objects": [{"mm_hash": 12345, "offsets": [[32, 128]]}]}, # Block 2: same image
None, # Block 3: no image
]
```
This is included in `mm_routing_info` so KvRouter can compute MM-aware overlap.
## Files
| File | Description |
|------|-------------|
| `mm_router_worker.py` | Main worker with `@dynamo_worker()` |
| `handler.py` | `MMRouterHandler` - routing logic |
| `mm_processor.py` | MM processing utilities |
| `__main__.py` | Entry point |
| `launch.sh` | Launch script |
## Dependencies
- `tensorrt_llm >= 1.2.0rc6` - For `apply_mm_hashes()` and `default_multimodal_input_loader()`. Earlier versions may not include multimodal hash support in KV events.
- `transformers` - For `AutoProcessor`
- `dynamo` - For runtime and KvRouter
## Known Limitations
- **Qwen2-VL specific**: The `_compute_tokens_per_image()` logic in `mm_processor.py` currently only supports `qwen2_vl` model type. Supporting other multimodal models requires adding their visual token computation logic.
# SPDX-FileCopyrightText: Copyright (c) 2025-2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
"""
MM Router Worker - Multimodal-aware KV cache routing for TRT-LLM.
This worker sits between the frontend and TRT-LLM workers, computing mm_hash
for images and routing requests to the worker with the best KV cache overlap.
"""
from .handler import MMRouterHandler
from .mm_processor import ProcessedInput, build_block_mm_infos
__all__ = [
"MMRouterHandler",
"ProcessedInput",
"build_block_mm_infos",
]
# SPDX-FileCopyrightText: Copyright (c) 2025-2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
"""Entry point for running MM Router Worker as a module."""
from .mm_router_worker import main
if __name__ == "__main__":
main()
# SPDX-FileCopyrightText: Copyright (c) 2025-2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
"""
MM Router Handler - Routes requests to best TRT-LLM worker based on KV cache overlap.
"""
import logging
from typing import Any, AsyncGenerator
from dynamo.llm import KvRouter
from dynamo.runtime.logging import configure_dynamo_logging
from .mm_processor import build_block_mm_infos, extract_image_urls, process_multimodal
configure_dynamo_logging()
logger = logging.getLogger(__name__)
class MMRouterHandler:
"""
Handler that computes mm_hash for multimodal requests and routes
to the best worker based on KV cache overlap.
"""
def __init__(
self,
kv_router: KvRouter,
tokenizer: Any,
processor: Any,
model: str,
model_type: str,
block_size: int,
):
"""
Initialize the MM Router Handler.
Args:
kv_router: KvRouter for KV-aware worker selection and routing
tokenizer: TRT-LLM tokenizer
processor: HuggingFace AutoProcessor (optional)
model: Model path/name
model_type: Model type (e.g., "qwen2_vl")
block_size: KV cache block size
"""
self.kv_router = kv_router
self.tokenizer = tokenizer
self.processor = processor
self.model = model
self.model_type = model_type
self.block_size = block_size
async def generate(self, request: dict) -> AsyncGenerator[dict, None]:
"""
Main entry point - receives request, computes routing, forwards to best worker.
The request format (after Frontend preprocessing with ModelInput.Tokens):
{
"token_ids": [...],
"sampling_options": {...},
"stop_conditions": {...},
"extra_args": {"messages": [...]}
}
Args:
request: Preprocessed request from Frontend
Yields:
Response chunks from the downstream TRT-LLM worker via KvRouter
"""
# Extract messages from extra_args (set by Frontend preprocessor)
messages = request.get("extra_args", {}).get("messages", [])
image_urls = extract_image_urls(messages)
if image_urls:
# Process multimodal: download images, compute mm_hash
# Do not reuse request["token_ids"] for MM routing: those are placeholder-level
# tokens from frontend. We need processor-expanded tokens to build block_mm_infos.
processed = process_multimodal(
messages=messages,
image_urls=image_urls,
tokenizer=self.tokenizer,
processor=self.processor,
model=self.model,
model_type=self.model_type,
)
# Build block_mm_infos for MM-aware hash computation
block_mm_infos = build_block_mm_infos(
num_tokens=len(processed.tokens),
block_size=self.block_size,
mm_hashes=processed.mm_hashes,
image_ranges=processed.image_ranges,
)
if block_mm_infos is None:
raise ValueError(
"Failed to build block_mm_infos for multimodal request"
)
routing_tokens = processed.tokens
routing_blocks = (
len(routing_tokens) + self.block_size - 1
) // self.block_size
logger.debug(
f"MM request: {len(routing_tokens)} routing tokens, "
f"{len(image_urls)} images, {routing_blocks} routing blocks"
)
else:
# Text-only: rely on frontend-preprocessed token_ids (ModelInput.Tokens contract)
token_ids = request.get("token_ids")
if not token_ids:
raise ValueError(
"Missing or empty token_ids in preprocessed request for text-only routing"
)
routing_tokens = token_ids
routing_blocks = (
len(routing_tokens) + self.block_size - 1
) // self.block_size
logger.debug(
f"Text request: {len(routing_tokens)} routing tokens, "
f"{routing_blocks} routing blocks"
)
# Text-only routing has no multimodal objects; provide per-block None entries.
block_mm_infos = [None] * routing_blocks
# Route and generate through KvRouter with explicit fields.
# We pass:
# - execution payload (token_ids + multi_modal_data)
# - routing payload (mm_routing_info: routing_token_ids + block_mm_infos)
# so generate() can select worker internally while preserving MM correctness.
token_ids = request.get("token_ids")
if not token_ids:
raise ValueError("Missing or empty token_ids in preprocessed request")
mm_routing_info: dict[str, Any] = {
"routing_token_ids": routing_tokens,
"block_mm_infos": block_mm_infos,
}
stream = await self.kv_router.generate(
token_ids=token_ids,
model=request["model"],
stop_conditions=request.get("stop_conditions"),
sampling_options=request.get("sampling_options"),
output_options=request.get("output_options"),
router_config_override=request.get("router_config_override"),
extra_args=request.get("extra_args"),
multi_modal_data=request.get("multi_modal_data"),
mm_routing_info=mm_routing_info,
)
async for response in stream:
yield response
#!/bin/bash
# SPDX-FileCopyrightText: Copyright (c) 2025-2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
# Launch script for MM Router Worker with TRT-LLM backend
#
# This script starts:
# 1. TRT-LLM workers (standard, with KV event publishing)
# 2. MM Router Worker (computes mm_hash, routes to best worker)
# 3. Frontend (HTTP ingress)
set -e
# Get the directory where this script is located and navigate to dynamo root
SCRIPT_DIR="$(cd "$(dirname "${BASH_SOURCE[0]}")" && pwd)"
DYNAMO_ROOT="$(cd "${SCRIPT_DIR}/../../../.." && pwd)"
cd "$DYNAMO_ROOT"
echo "Working directory: $(pwd)"
# Configuration
MODEL="${MODEL:-Qwen/Qwen2-VL-2B-Instruct}"
MODEL_TYPE="${MODEL_TYPE:-qwen2_vl}"
NAMESPACE="${NAMESPACE:-default}"
BLOCK_SIZE="${BLOCK_SIZE:-32}"
HTTP_PORT="${HTTP_PORT:-8000}"
NUM_WORKERS="${NUM_WORKERS:-1}"
echo "=== MM Router Worker Launch Script ==="
echo "Model: $MODEL"
echo "Model Type: $MODEL_TYPE"
echo "Namespace: $NAMESPACE"
echo "Block Size: $BLOCK_SIZE"
echo "HTTP Port: $HTTP_PORT"
echo "Num Workers: $NUM_WORKERS"
echo ""
# Collect PIDs for cleanup
PIDS=()
cleanup() {
echo "Cleaning up..."
for pid in "${PIDS[@]}"; do
kill "$pid" 2>/dev/null || true
done
wait 2>/dev/null
}
trap cleanup EXIT
# Start TRT-LLM workers
# Use a different served-model-name so Frontend routes to MM Router instead
# Use NATS request plane to match MM Router
echo ""
echo "=== Starting TRT-LLM Workers ==="
for i in $(seq 0 $((NUM_WORKERS - 1))); do
echo "Starting TRT-LLM worker $i..."
DYN_REQUEST_PLANE=nats python -m dynamo.trtllm \
--model-path "$MODEL" \
--served-model-name "${MODEL}__internal" \
--endpoint "dyn://${NAMESPACE}.trtllm.generate" \
--modality multimodal \
--publish-events-and-metrics \
--kv-block-size "$BLOCK_SIZE" \
2>&1 | sed "s/^/[trtllm-$i] /" &
PIDS+=($!)
done
# Wait for workers to initialize
echo "Waiting for TRT-LLM workers to initialize..."
sleep 15
# Start MM Router Worker
# Use NATS request plane to match Frontend
echo ""
echo "=== Starting MM Router Worker ==="
DYN_REQUEST_PLANE=nats python -m examples.backends.trtllm.mm_router_worker \
--model "$MODEL" \
--model-type "$MODEL_TYPE" \
--namespace "$NAMESPACE" \
--component mm_router \
--endpoint generate \
--downstream-component trtllm \
--downstream-endpoint generate \
--block-size "$BLOCK_SIZE" \
2>&1 | sed "s/^/[mm_router] /" &
PIDS+=($!)
# Wait for router to initialize
echo "Waiting for MM Router to initialize..."
sleep 5
# Start Frontend
# Use NATS request plane to match MM Router
echo ""
echo "=== Starting Frontend ==="
DYN_REQUEST_PLANE=nats python -m dynamo.frontend \
--http-port "$HTTP_PORT" \
--router-mode round-robin \
2>&1 | sed "s/^/[frontend] /" &
PIDS+=($!)
echo ""
echo "=== All services started ==="
echo "Frontend available at http://localhost:$HTTP_PORT"
echo ""
echo "Test with:"
echo "curl http://localhost:$HTTP_PORT/v1/chat/completions \\"
echo " -H 'Content-Type: application/json' \\"
echo " -d '{"
echo " \"model\": \"$MODEL\","
echo " \"messages\": [{"
echo " \"role\": \"user\","
echo " \"content\": [{"
echo " \"type\": \"text\","
echo " \"text\": \"Describe this image\""
echo " }, {"
echo " \"type\": \"image_url\","
echo " \"image_url\": {\"url\": \"https://example.com/image.jpg\"}"
echo " }]"
echo " }],"
echo " \"max_tokens\": 100"
echo " }'"
echo ""
echo "Press Ctrl+C to stop all services"
# Wait for all background processes
wait
# SPDX-FileCopyrightText: Copyright (c) 2025-2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
"""Multimodal processing utilities for MM Router Worker."""
import logging
from dataclasses import dataclass
from typing import Any
from tensorrt_llm.inputs.multimodal import apply_mm_hashes
from tensorrt_llm.inputs.utils import default_multimodal_input_loader, load_image
from transformers import AutoConfig
logger = logging.getLogger(__name__)
# =============================================================================
# Data structures
# =============================================================================
@dataclass
class ProcessedInput:
"""Processed multimodal input."""
tokens: list[int]
mm_hashes: list[int] | None
image_ranges: list[tuple[int, int]] | None # [(start, end), ...] per image
# =============================================================================
# Public functions
# =============================================================================
def extract_image_urls(messages: list[dict]) -> list[str]:
"""Extract image URLs from OpenAI-format messages."""
urls = []
for msg in messages:
content = msg.get("content", [])
if isinstance(content, list):
for part in content:
if part.get("type") == "image_url":
url = part.get("image_url", {}).get("url")
if url:
urls.append(url)
return urls
def build_prompt_from_messages(messages: list[dict]) -> str:
"""Build a simple prompt string from messages."""
parts = []
for msg in messages:
role = msg.get("role", "user")
content = msg.get("content", "")
if isinstance(content, str):
parts.append(f"{role}: {content}")
elif isinstance(content, list):
texts = [p.get("text", "") for p in content if p.get("type") == "text"]
if texts:
parts.append(f"{role}: {' '.join(texts)}")
return "\n".join(parts)
def process_multimodal(
messages: list[dict],
image_urls: list[str],
tokenizer: Any,
processor: Any,
model: str,
model_type: str,
) -> ProcessedInput:
"""Process multimodal request: load images, get expanded tokens and mm_hashes."""
try:
prompt = build_prompt_from_messages(messages)
# Use TRT-LLM loader to process images and get mm data
inputs = default_multimodal_input_loader(
tokenizer=tokenizer,
model_dir=model,
model_type=model_type,
modality="multiple_image" if len(image_urls) > 1 else "image",
prompts=[prompt],
media=[image_urls],
image_data_format="pt",
device="cuda",
)
mm_input = inputs[0]
processed_prompt = mm_input.get("prompt", prompt)
multi_modal_data = mm_input.get("multi_modal_data")
# Get expanded tokens and image ranges
tokens, image_ranges = _get_expanded_tokens(
processed_prompt, image_urls, tokenizer, processor, model, model_type
)
# Compute mm_hash for each image
mm_hashes = _compute_mm_hashes(multi_modal_data)
return ProcessedInput(
tokens=tokens, mm_hashes=mm_hashes, image_ranges=image_ranges
)
except Exception as e:
logger.error(f"MM processing failed: {e}", exc_info=True)
raise
def build_block_mm_infos(
num_tokens: int,
block_size: int,
mm_hashes: list[int] | None,
image_ranges: list[tuple[int, int]] | None,
) -> list[dict | None] | None:
"""
Build per-block mm_info for routing.
For each block, check which images overlap with it and add their mm_hash.
Assumption: mm_hashes and image_ranges are in the same order as images appear
in the request (which matches their order in the token sequence).
"""
if not mm_hashes or not image_ranges or len(mm_hashes) != len(image_ranges):
return None
num_blocks = (num_tokens + block_size - 1) // block_size
result = []
for block_idx in range(num_blocks):
block_start = block_idx * block_size
block_end = block_start + block_size
# Find images overlapping this block
mm_objects = [
{"mm_hash": mm_hash, "offsets": []}
for mm_hash, (img_start, img_end) in zip(mm_hashes, image_ranges)
if block_end > img_start and block_start < img_end
]
result.append({"mm_objects": mm_objects} if mm_objects else None)
return result
# =============================================================================
# Internal functions
# =============================================================================
def _get_expanded_tokens(
prompt: str,
image_urls: list[str],
tokenizer: Any,
processor: Any,
model_path: str,
model_type: str,
) -> tuple[list[int], list[tuple[int, int]] | None]:
"""Get tokens with visual expansion and find each image's token range."""
if processor is None:
return tokenizer.encode(prompt), None
try:
# TODO @zdren: use async_load_image or batch load
pil_images = [load_image(url, format="pil") for url in image_urls]
output = processor(
text=[prompt], images=pil_images, return_tensors="pt", padding=True
)
tokens = output["input_ids"][0].tolist()
# Get image_token_id from processor
image_token_id = getattr(processor, "image_token_id", None)
if image_token_id is None:
raise ValueError("processor.image_token_id not found")
# Get replacement_id: TRTLLM uses vocab_size + 1 in KV events
replacement_id = _get_replacement_id(model_path)
# Find contiguous image token ranges and replace them in one pass
contiguous_ranges = _find_and_replace_image_tokens(
tokens, image_token_id, replacement_id
)
# Compute tokens per image from processor output
tokens_per_image = _compute_tokens_per_image(output, processor, model_type)
# Split ranges according to tokens_per_image
image_ranges = _compute_per_image_ranges(contiguous_ranges, tokens_per_image)
return tokens, image_ranges
except Exception as e:
logger.warning(f"HF processor failed: {e}", exc_info=True)
return tokenizer.encode(prompt), None
def _compute_tokens_per_image(
processor_output: dict, processor: Any, model_type: str
) -> list[int]:
"""Compute the number of visual tokens for each image from processor output."""
if model_type == "qwen2_vl":
grid_thw = processor_output.get("image_grid_thw")
if grid_thw is None:
raise ValueError(
"image_grid_thw not found in processor output for Qwen2-VL"
)
merge_size = getattr(processor.image_processor, "merge_size", 2)
return [int(t * h * w) // (merge_size**2) for t, h, w in grid_thw]
else:
raise NotImplementedError(f"Model type '{model_type}' is not supported yet")
def _get_replacement_id(model_path: str) -> int:
"""
Get the replacement token ID for image tokens to match TRTLLM's KV event format.
TRTLLM replaces image placeholder tokens with (vocab_size + 1) in KV events.
The vocab_size comes from the model config, not the tokenizer.
"""
try:
config = AutoConfig.from_pretrained(model_path, trust_remote_code=True)
replacement_id = config.vocab_size + 1
logger.info(f"Got vocab_size={config.vocab_size} from AutoConfig")
return replacement_id
except Exception as e:
raise RuntimeError(
f"Failed to get vocab_size from model config '{model_path}': {e}"
) from e
def _find_and_replace_image_tokens(
tokens: list[int], image_token_id: int, replacement_id: int
) -> list[tuple[int, int]]:
"""
Find all contiguous ranges of image tokens and replace them in place.
Returns: list of (start, end) ranges for contiguous image token regions.
"""
ranges = []
start = None
for i, t in enumerate(tokens):
if t == image_token_id:
tokens[i] = replacement_id
if start is None:
start = i
elif start is not None:
ranges.append((start, i))
start = None
if start is not None:
ranges.append((start, len(tokens)))
if ranges:
logger.info(
f"Replaced {sum(e - s for s, e in ranges)} image tokens: {image_token_id} -> {replacement_id}"
)
return ranges
def _compute_per_image_ranges(
contiguous_ranges: list[tuple[int, int]],
tokens_per_image: list[int],
) -> list[tuple[int, int]] | None:
"""
Split contiguous image token ranges by each image's token count.
Example: contiguous_ranges=[(0, 100)], tokens_per_image=[60, 40]
Returns: [(0, 60), (60, 100)] # image 1 at 0-60, image 2 at 60-100
"""
if not contiguous_ranges:
if tokens_per_image:
logger.warning(
f"No image tokens found but {len(tokens_per_image)} images expected"
)
return None
# Greedily assign images to ranges in order
result = []
image_idx = 0
for range_start, range_end in contiguous_ranges:
range_size = range_end - range_start
pos = range_start
consumed = 0
# Consume images that fit entirely in this range
# (a single image's tokens are always contiguous, cannot span ranges)
while image_idx < len(tokens_per_image):
needed = tokens_per_image[image_idx]
if consumed + needed <= range_size:
result.append((pos, pos + needed))
pos += needed
consumed += needed
image_idx += 1
else:
break
# Range must be exactly filled (no leftover image tokens)
if consumed != range_size:
logger.warning(
f"Range size mismatch: consumed {consumed} != range {range_size}"
)
return None
# All images must be consumed
if image_idx != len(tokens_per_image):
logger.warning(f"Not all images mapped: {image_idx} < {len(tokens_per_image)}")
return None
return result
def _compute_mm_hashes(multi_modal_data: dict | None) -> list[int] | None:
"""Compute mm_hash for each image."""
if not multi_modal_data:
return None
try:
result = apply_mm_hashes(multi_modal_data)
if "image" in result and result["image"]:
return [int(h[:16], 16) for h in result["image"]]
except Exception as e:
logger.warning(f"Failed to compute mm_hashes: {e}")
return None
# SPDX-FileCopyrightText: Copyright (c) 2025-2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
"""
MM Router Worker - Multimodal-aware KV cache routing for TRT-LLM.
This worker receives OpenAI-format requests from the frontend, computes
mm_hash for any images, finds the best TRT-LLM worker based on KV cache
overlap, and forwards the request to that worker.
Usage:
python -m examples.backends.trtllm.mm_router_worker \
--model Qwen/Qwen2-VL-2B-Instruct \
--model-type qwen2_vl \
--namespace default \
--component mm_router \
--endpoint generate \
--downstream-component trtllm \
--downstream-endpoint generate
"""
import argparse
import asyncio
import logging
import signal
import uvloop
from tensorrt_llm.llmapi.tokenizer import tokenizer_factory
from transformers import AutoProcessor
from dynamo.llm import KvRouter, KvRouterConfig, ModelInput, ModelType, register_llm
from dynamo.runtime import DistributedRuntime, dynamo_worker
from .handler import MMRouterHandler
logger = logging.getLogger(__name__)
def parse_args() -> argparse.Namespace:
"""Parse command line arguments."""
parser = argparse.ArgumentParser(
description="MM Router Worker - Multimodal-aware KV cache routing"
)
# Model configuration
parser.add_argument(
"--model",
type=str,
default="Qwen/Qwen2-VL-2B-Instruct",
help="Model path or HuggingFace model ID",
)
parser.add_argument(
"--model-type",
type=str,
default="qwen2_vl",
help="Model type for TRT-LLM multimodal loader",
)
parser.add_argument(
"--block-size",
type=int,
default=32,
help="KV cache block size",
)
# This worker's endpoint configuration
parser.add_argument(
"--namespace",
type=str,
default="default",
help="Dynamo namespace",
)
parser.add_argument(
"--component",
type=str,
default="mm_router",
help="This worker's component name",
)
parser.add_argument(
"--endpoint",
type=str,
default="generate",
help="This worker's endpoint name",
)
# Downstream TRT-LLM worker configuration
parser.add_argument(
"--downstream-component",
type=str,
default="trtllm",
help="Downstream TRT-LLM workers' component name",
)
parser.add_argument(
"--downstream-endpoint",
type=str,
default="generate",
help="Downstream TRT-LLM workers' endpoint name",
)
return parser.parse_args()
async def graceful_shutdown(runtime: DistributedRuntime) -> None:
"""Handle graceful shutdown."""
logger.info("Received shutdown signal, shutting down...")
runtime.shutdown()
logger.info("Shutdown complete")
@dynamo_worker()
async def worker(runtime: DistributedRuntime) -> None:
"""
Main worker function.
Sets up connections to downstream TRT-LLM workers, creates KvRouter
for KV-aware routing, and serves the MM router endpoint.
"""
args = parse_args()
# Set up signal handlers for graceful shutdown
loop = asyncio.get_running_loop()
def signal_handler():
asyncio.create_task(graceful_shutdown(runtime))
for sig in (signal.SIGTERM, signal.SIGINT):
loop.add_signal_handler(sig, signal_handler)
logger.info("MM Router Worker starting...")
logger.info(f"Model: {args.model}")
logger.info(f"This worker: {args.namespace}.{args.component}.{args.endpoint}")
logger.info(
f"Downstream: {args.namespace}.{args.downstream_component}.{args.downstream_endpoint}"
)
# Connect to downstream TRT-LLM workers
downstream_endpoint = (
runtime.namespace(args.namespace)
.component(args.downstream_component)
.endpoint(args.downstream_endpoint)
)
downstream_client = await downstream_endpoint.client()
logger.info("Waiting for downstream TRT-LLM workers...")
instance_ids = await downstream_client.wait_for_instances()
logger.info(f"Found {len(instance_ids)} workers: {list(instance_ids)}")
# Create KvRouter to select workers based on KV overlap
kv_router = KvRouter(
endpoint=downstream_endpoint,
block_size=args.block_size,
kv_router_config=KvRouterConfig(),
)
logger.info("KvRouter created successfully")
# Initialize tokenizer and processor for MM processing
logger.info(f"Loading tokenizer from {args.model}...")
tokenizer = tokenizer_factory(args.model)
processor = None
try:
logger.info(f"Loading HuggingFace processor from {args.model}...")
processor = AutoProcessor.from_pretrained(args.model, trust_remote_code=True)
except Exception as e:
logger.warning(f"Failed to load HF processor: {e}")
logger.warning("Visual token expansion will not be available")
# Create handler
handler = MMRouterHandler(
kv_router=kv_router,
tokenizer=tokenizer,
processor=processor,
model=args.model,
model_type=args.model_type,
block_size=args.block_size,
)
# Register this worker's endpoint
component = runtime.namespace(args.namespace).component(args.component)
endpoint = component.endpoint(args.endpoint)
# Use ModelInput.Tokens so Frontend preprocesses the request
# Request format: {token_ids, sampling_options, stop_conditions, extra_args: {messages}}
await register_llm(
ModelInput.Tokens,
ModelType.Chat,
endpoint,
args.model,
kv_cache_block_size=args.block_size,
)
logger.info(f"MM Router Worker ready, serving {args.endpoint} endpoint...")
# Serve the endpoint
await endpoint.serve_endpoint(handler.generate)
def main() -> None:
"""Entry point for the MM Router Worker."""
logging.basicConfig(
level=logging.INFO,
format="%(asctime)s - %(name)s - %(levelname)s - %(message)s",
)
uvloop.install()
asyncio.run(worker())
if __name__ == "__main__":
main()
......@@ -435,7 +435,7 @@ class TrtllmWorker:
await asyncio.sleep(2)
try:
events = self.llm.get_kv_cache_events_async(timeout=None)
events = self.llm.get_kv_cache_events_async(timeout=5)
logger.info(f"Worker {self.worker_id}: KV events iterator obtained")
async for event in events:
......
......@@ -742,6 +742,7 @@ class KvEventPublisher:
block_hashes: List[int],
lora_id: int,
parent_hash: Optional[int] = None,
block_mm_infos: Optional[List[Optional[Dict[str, Any]]]] = None,
) -> None:
"""
Publish a KV stored event.
......@@ -754,6 +755,9 @@ class KvEventPublisher:
block_hashes: List of block hashes (signed 64-bit integers)
lora_id: The LoRA ID
parent_hash: Optional parent hash (signed 64-bit integer)
block_mm_infos: Optional list of multimodal info for each block.
Each item is either None or a dict with "mm_objects" key containing
a list of {"mm_hash": int, "offsets": [[start, end], ...]} dicts.
"""
...
......
# SPDX-FileCopyrightText: Copyright (c) 2025-2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
"""Multimodal router tests."""
This diff is collapsed.
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment