Unverified commit d993f9d3, authored by zhongdaor-nv and committed by GitHub
Browse files

fix: support TRT-LLM 1.3 apply_mm_hashes API (#6810)


Signed-off-by: zhongdaor <zhongdaor@nvidia.com>
parent 3c0fdd05
...@@ -154,7 +154,7 @@ This is included in `mm_routing_info` so KvRouter can compute MM-aware overlap. ...@@ -154,7 +154,7 @@ This is included in `mm_routing_info` so KvRouter can compute MM-aware overlap.
## Dependencies ## Dependencies
- `tensorrt_llm >= 1.2.0rc6` - For `apply_mm_hashes()` and `default_multimodal_input_loader()`. Earlier versions may not include multimodal hash support in KV events. - `tensorrt_llm >= 1.3.0rc5` - Required for the current `apply_mm_hashes()` tuple return contract (`(mm_hashes_by_modality, uuids)`), used by this worker's routing hash extraction path.
- `transformers` - For `AutoProcessor` - `transformers` - For `AutoProcessor`
- `dynamo` - For runtime and KvRouter - `dynamo` - For runtime and KvRouter
......
...@@ -83,6 +83,8 @@ class MMRouterHandler: ...@@ -83,6 +83,8 @@ class MMRouterHandler:
processor=self.processor, processor=self.processor,
model=self.model, model=self.model,
model_type=self.model_type, model_type=self.model_type,
request_token_ids=request.get("token_ids"),
request_multi_modal_data=request.get("multi_modal_data"),
) )
# Build block_mm_infos for MM-aware hash computation # Build block_mm_infos for MM-aware hash computation
......
...@@ -8,7 +8,7 @@ from dataclasses import dataclass ...@@ -8,7 +8,7 @@ from dataclasses import dataclass
from typing import Any from typing import Any
from tensorrt_llm.inputs.multimodal import apply_mm_hashes from tensorrt_llm.inputs.multimodal import apply_mm_hashes
from tensorrt_llm.inputs.utils import default_multimodal_input_loader, load_image from tensorrt_llm.inputs.utils import load_image
from transformers import AutoConfig from transformers import AutoConfig
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
...@@ -69,34 +69,48 @@ def process_multimodal( ...@@ -69,34 +69,48 @@ def process_multimodal(
processor: Any, processor: Any,
model: str, model: str,
model_type: str, model_type: str,
request_token_ids: list[int] | None = None,
request_multi_modal_data: dict | None = None,
) -> ProcessedInput: ) -> ProcessedInput:
"""Process multimodal request: load images, get expanded tokens and mm_hashes.""" """Process multimodal request: load images, get expanded tokens and mm_hashes."""
try: try:
prompt = build_prompt_from_messages(messages) # Prefer the exact image sources from preprocessed request payload so routing
# hash inputs follow the same media path as backend execution.
# Use TRT-LLM loader to process images and get mm data effective_image_urls = (
inputs = default_multimodal_input_loader( _extract_image_urls_from_request_mm_data(request_multi_modal_data)
tokenizer=tokenizer, or image_urls
model_dir=model,
model_type=model_type,
modality="multiple_image" if len(image_urls) > 1 else "image",
prompts=[prompt],
media=[image_urls],
image_data_format="pt",
device="cuda",
) )
if not effective_image_urls:
raise ValueError("No image URLs found for multimodal processing")
mm_input = inputs[0] if not request_token_ids:
processed_prompt = mm_input.get("prompt", prompt) raise ValueError("Missing request token_ids for multimodal routing")
multi_modal_data = mm_input.get("multi_modal_data")
# Match TRT-LLM 1.3 preprocessing path:
# request has prompt_token_ids -> decode to prompt text -> processor re-encodes.
processed_prompt = _decode_prompt_from_token_ids(tokenizer, request_token_ids)
if processed_prompt is None:
raise ValueError(
"Failed to decode request token_ids for multimodal routing"
)
# Load PIL images once from effective image sources. Reuse for both token
# expansion and mm_hash computation to keep routing inputs self-consistent.
pil_images = [load_image(url, format="pil") for url in effective_image_urls]
# Get expanded tokens and image ranges # Get expanded tokens and image ranges
tokens, image_ranges = _get_expanded_tokens( tokens, image_ranges = _get_expanded_tokens(
processed_prompt, image_urls, tokenizer, processor, model, model_type processed_prompt,
effective_image_urls,
tokenizer,
processor,
model,
model_type,
pil_images=pil_images,
) )
# Compute mm_hash for each image # Compute mm_hash for each image from backend-like multimodal structure.
mm_hashes = _compute_mm_hashes(multi_modal_data) mm_hashes = _compute_mm_hashes({"image": pil_images})
return ProcessedInput( return ProcessedInput(
tokens=tokens, mm_hashes=mm_hashes, image_ranges=image_ranges tokens=tokens, mm_hashes=mm_hashes, image_ranges=image_ranges
...@@ -147,6 +161,24 @@ def build_block_mm_infos( ...@@ -147,6 +161,24 @@ def build_block_mm_infos(
# ============================================================================= # =============================================================================
def _decode_prompt_from_token_ids(
tokenizer: Any, request_token_ids: list[int] | None
) -> str | None:
"""Decode frontend token_ids back to prompt text (TRT-LLM 1.3 VLM path)."""
if not request_token_ids:
return None
# tensorrt_llm tokenizers and HF tokenizers expose slightly different decode signatures.
decode = getattr(tokenizer, "decode", None)
if decode is None:
return None
try:
return decode(request_token_ids, skip_special_tokens=False)
except TypeError:
return decode(request_token_ids)
def _get_expanded_tokens( def _get_expanded_tokens(
prompt: str, prompt: str,
image_urls: list[str], image_urls: list[str],
...@@ -154,14 +186,16 @@ def _get_expanded_tokens( ...@@ -154,14 +186,16 @@ def _get_expanded_tokens(
processor: Any, processor: Any,
model_path: str, model_path: str,
model_type: str, model_type: str,
pil_images: list[Any] | None = None,
) -> tuple[list[int], list[tuple[int, int]] | None]: ) -> tuple[list[int], list[tuple[int, int]] | None]:
"""Get tokens with visual expansion and find each image's token range.""" """Get tokens with visual expansion and find each image's token range."""
if processor is None: if processor is None:
return tokenizer.encode(prompt), None return tokenizer.encode(prompt), None
try: try:
# TODO @zdren: use async_load_image or batch load if pil_images is None:
pil_images = [load_image(url, format="pil") for url in image_urls] # TODO @zdren: use async_load_image or batch load
pil_images = [load_image(url, format="pil") for url in image_urls]
output = processor( output = processor(
text=[prompt], images=pil_images, return_tensors="pt", padding=True text=[prompt], images=pil_images, return_tensors="pt", padding=True
) )
...@@ -318,9 +352,34 @@ def _compute_mm_hashes(multi_modal_data: dict | None) -> list[int] | None: ...@@ -318,9 +352,34 @@ def _compute_mm_hashes(multi_modal_data: dict | None) -> list[int] | None:
if not multi_modal_data: if not multi_modal_data:
return None return None
try: try:
result = apply_mm_hashes(multi_modal_data) # TRT-LLM 1.3+ returns Tuple[Dict[str, List[str]], Optional[List[Optional[str]]]].
# This worker targets TRT-LLM >= 1.3.0rc5.
result = apply_mm_hashes(multi_modal_data)[0]
if "image" in result and result["image"]: if "image" in result and result["image"]:
return [int(h[:16], 16) for h in result["image"]] return [int(h[:16], 16) for h in result["image"]]
except Exception as e: except Exception as e:
logger.warning(f"Failed to compute mm_hashes: {e}") logger.warning(f"Failed to compute mm_hashes: {e}")
return None return None
def _extract_image_urls_from_request_mm_data(
request_multi_modal_data: dict | None,
) -> list[str] | None:
"""Extract image URLs from request multi_modal_data.image_url payload."""
if not isinstance(request_multi_modal_data, dict):
return None
image_items = request_multi_modal_data.get("image_url")
if not isinstance(image_items, list):
return None
urls: list[str] = []
for item in image_items:
if isinstance(item, dict):
url = item.get("Url")
if isinstance(url, str) and url:
urls.append(url)
elif isinstance(item, str) and item:
urls.append(item)
return urls if urls else None
...@@ -240,7 +240,8 @@ class ServiceAPI: ...@@ -240,7 +240,8 @@ class ServiceAPI:
modality=modality, modality=modality,
prompts=[prompt], prompts=[prompt],
media=[image_urls], media=[image_urls],
image_data_format="pt", # Align hash input type with backend multimodal processor path.
image_data_format="pil",
device="cuda", device="cuda",
) )
mm_input = inputs[0] mm_input = inputs[0]
...@@ -327,12 +328,22 @@ class ServiceAPI: ...@@ -327,12 +328,22 @@ class ServiceAPI:
if not multi_modal_data: if not multi_modal_data:
return None return None
mm_hashes_dict = apply_mm_hashes(multi_modal_data) # TRT-LLM 1.3 returns Tuple[Dict[str, List[str]], Optional[List[Optional[str]]]].
if "image" in mm_hashes_dict and mm_hashes_dict["image"]: mm_hashes_dict = apply_mm_hashes(multi_modal_data)[0]
# Convert each 256-bit hex digest to 64-bit int if not isinstance(mm_hashes_dict, dict) or not mm_hashes_dict:
mm_hashes = [ return None
int(hex_digest[:16], 16) for hex_digest in mm_hashes_dict["image"]
# Prefer image modality for stable behavior, but fall back to flattening
# all modality hashes to stay forward-compatible.
hash_hexes = mm_hashes_dict.get("image")
if not hash_hexes:
hash_hexes = [
h for hashes in mm_hashes_dict.values() for h in (hashes or [])
] ]
if hash_hexes:
# Convert each 256-bit hex digest to 64-bit int
mm_hashes = [int(hex_digest[:16], 16) for hex_digest in hash_hexes]
logger.debug(f"Computed mm_hashes for {len(mm_hashes)} images: {mm_hashes}") logger.debug(f"Computed mm_hashes for {len(mm_hashes)} images: {mm_hashes}")
return mm_hashes return mm_hashes
return None return None
......
...@@ -40,6 +40,7 @@ SINGLE_IMAGE_TOTAL_BLOCKS_RANGE = (20, 260) ...@@ -40,6 +40,7 @@ SINGLE_IMAGE_TOTAL_BLOCKS_RANGE = (20, 260)
pytestmark = [ pytestmark = [
pytest.mark.e2e, pytest.mark.e2e,
pytest.mark.pre_merge,
pytest.mark.trtllm, pytest.mark.trtllm,
pytest.mark.multimodal, pytest.mark.multimodal,
pytest.mark.gpu_1, pytest.mark.gpu_1,
...@@ -301,6 +302,13 @@ def _send_request_get_overlap( ...@@ -301,6 +302,13 @@ def _send_request_get_overlap(
timeout_s=120, timeout_s=120,
) )
print(f"[MM_ROUTER_E2E] {label}: current={overlap}/{total}") print(f"[MM_ROUTER_E2E] {label}: current={overlap}/{total}")
# Allow time for KV cache events to propagate from the TRT-LLM worker
# through the publisher to the router's indexer and radix tree. Without
# this, the next request may be routed before newly cached blocks are
# visible, causing spurious 0-overlap results.
time.sleep(2)
return overlap, total, segment return overlap, total, segment
...@@ -452,9 +460,11 @@ def test_trtllm_mm_overlap_repeated_two_identical_images( ...@@ -452,9 +460,11 @@ def test_trtllm_mm_overlap_repeated_two_identical_images(
overlap_1, total_1, _ = _send_request_get_overlap( overlap_1, total_1, _ = _send_request_get_overlap(
frontend_port, router_proc, payload, "same_two_identical_images_req1" frontend_port, router_proc, payload, "same_two_identical_images_req1"
) )
time.sleep(1)
overlap_2, total_2, _ = _send_request_get_overlap( overlap_2, total_2, _ = _send_request_get_overlap(
frontend_port, router_proc, payload, "same_two_identical_images_req2" frontend_port, router_proc, payload, "same_two_identical_images_req2"
) )
time.sleep(1)
overlap_3, total_3, segment_3 = _send_request_get_overlap( overlap_3, total_3, segment_3 = _send_request_get_overlap(
frontend_port, router_proc, payload, "same_two_identical_images_req3" frontend_port, router_proc, payload, "same_two_identical_images_req3"
) )
...@@ -507,13 +517,20 @@ def test_trtllm_mm_overlap_staircase_single_to_double_to_triple_identical_image( ...@@ -507,13 +517,20 @@ def test_trtllm_mm_overlap_staircase_single_to_double_to_triple_identical_image(
f"1x={overlap_1}/{total_1}, 2x={overlap_2}/{total_2}.\n" f"1x={overlap_1}/{total_1}, 2x={overlap_2}/{total_2}.\n"
f"Recent router logs:\n{segment_2[-4000:]}" f"Recent router logs:\n{segment_2[-4000:]}"
) )
assert abs(overlap_3 - overlap_2) <= 1, ( assert overlap_3 > overlap_2, (
"Expected first 3-image request overlap to stay near 2-image overlap " "Expected overlap to increase from 2 images to 3 images, got "
"(third-image suffix is cold on first 3-image request), got "
f"2x={overlap_2}/{total_2}, 3x={overlap_3}/{total_3}.\n" f"2x={overlap_2}/{total_2}, 3x={overlap_3}/{total_3}.\n"
f"Recent router logs:\n{segment_3[-4000:]}" f"Recent router logs:\n{segment_3[-4000:]}"
) )
delta21 = overlap_2 - overlap_1
delta32 = overlap_3 - overlap_2
assert abs(delta32 - delta21) <= 4, (
"Expected similar overlap increment per additional identical image, got "
f"step(1->2)={delta21}, step(2->3)={delta32}.\n"
f"Recent router logs:\n{segment_3[-4000:]}"
)
total_step_12 = total_2 - total_1 total_step_12 = total_2 - total_1
total_step_23 = total_3 - total_2 total_step_23 = total_3 - total_2
assert abs(total_step_12 - total_step_23) <= 4, ( assert abs(total_step_12 - total_step_23) <= 4, (
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment