Unverified Commit 5103efdb authored by Neelay Shah's avatar Neelay Shah Committed by GitHub
Browse files

fix(vllm): include image geometry in multimodal hash preimage (#8341)


Co-authored-by: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
parent ca15a79d
...@@ -123,8 +123,10 @@ def _compute_mm_uuids( ...@@ -123,8 +123,10 @@ def _compute_mm_uuids(
""" """
Compute multi_modal_uuids from multi_modal_data. Compute multi_modal_uuids from multi_modal_data.
Each image gets a SHA256 hex digest as its UUID, ensuring consistent Each image gets a blake3 hex digest as its UUID (computed by
hashing across the MM Router, vLLM handler, and Rust KV publisher. compute_mm_uuids_from_images over a fixed-length header + pixel
preimage), ensuring consistent hashing across the MM Router, vLLM
handler, and Rust KV publisher.
""" """
if not multi_modal_data or "image" not in multi_modal_data: if not multi_modal_data or "image" not in multi_modal_data:
return None return None
...@@ -1379,8 +1381,14 @@ class BaseWorkerHandler(ABC, Generic[RequestT, ResponseT]): ...@@ -1379,8 +1381,14 @@ class BaseWorkerHandler(ABC, Generic[RequestT, ResponseT]):
"token_ids": [], "token_ids": [],
}, },
) )
# Normal path: use token IDs # Normal path: use token IDs.
mm_uuids = _compute_mm_uuids(multi_modal_data) # In EPD mode multi_modal_data carries pre-computed embeddings from the
# encode worker, not raw images — skip UUID production here; raw-image
# identity lives upstream at the Router / URL-keyed encoder cache.
if self.embedding_loader is None:
mm_uuids = _compute_mm_uuids(multi_modal_data)
else:
mm_uuids = None
prompt_kwargs = dict[str, Any]( prompt_kwargs = dict[str, Any](
prompt_token_ids=request["token_ids"], prompt_token_ids=request["token_ids"],
multi_modal_data=multi_modal_data, multi_modal_data=multi_modal_data,
......
...@@ -2,38 +2,94 @@ ...@@ -2,38 +2,94 @@
# SPDX-License-Identifier: Apache-2.0 # SPDX-License-Identifier: Apache-2.0
import logging import logging
from typing import Any, Sequence import struct
from typing import Sequence, Union
import blake3 import blake3
import numpy as np import numpy as np
import torch from PIL import Image
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
ImageInput = Union[Image.Image, np.ndarray]
def image_to_bytes(img: Any) -> bytes: # Preimage layout (12 bytes, fixed-length) || pixel bytes.
"""Convert a supported image object to PNG bytes for hashing.""" # offset bytes field
from PIL import Image # 0 1 version_byte currently 0x01; bump to rotate preimage format
# 1 1 mode_tag 0x00 = RGB
# 2 1 dtype_tag 0x00 = uint8
# 3 1 channels currently 3
# 4 4 height u32 little-endian
# 8 4 width u32 little-endian
#
# The fixed-length header prevents delimiter-ambiguity collisions: no two
# distinct (header, pixel) pairs can produce the same preimage.
#
# struct format "<BBBBII":
# < little-endian, no padding
# BBBB four uint8 fields (version, mode, dtype, channels)
# II two uint32 fields (height, width)
_HEADER_STRUCT = struct.Struct("<BBBBII")
_VERSION_BYTE = 0x01
_MODE_RGB = 0x00
_DTYPE_UINT8 = 0x00
_CHANNELS_RGB = 3
if isinstance(img, bytes):
return img
def _header(height: int, width: int) -> bytes:
    """Pack the fixed-length preimage header for an RGB uint8 image.

    The header is laid out by ``_HEADER_STRUCT`` (format ``"<BBBBII"``): four
    single-byte tags (version, mode, dtype, channels) followed by ``height``
    and ``width`` as little-endian unsigned 32-bit integers.

    Args:
        height: Image height in pixels.
        width: Image width in pixels.

    Returns:
        The packed header bytes that precede the pixel bytes in the preimage.
    """
    fields = (
        _VERSION_BYTE,
        _MODE_RGB,
        _DTYPE_UINT8,
        _CHANNELS_RGB,
        height,
        width,
    )
    return _HEADER_STRUCT.pack(*fields)
if isinstance(img, torch.Tensor):
# Make sure the bytes are on the CPU
return img.cpu().numpy().tobytes()
def _image_preimage_parts(img: ImageInput) -> tuple[bytes, bytes]:
    """Split a canonical RGB uint8 image into its two hash-preimage parts.

    The result is ``(header_bytes, pixel_bytes)``. Because the header encodes
    height and width, two images with identical pixel bytes but different
    (H, W) yield different headers, and therefore different digests once both
    parts are fed to a hasher.

    Args:
        img: Either a ``PIL.Image.Image`` in ``"RGB"`` mode or an
            ``np.ndarray`` of dtype ``uint8`` and shape ``(H, W, 3)``.

    Returns:
        A ``(header_bytes, pixel_bytes)`` tuple.

    Raises:
        TypeError: ``img`` is neither a ``PIL.Image.Image`` nor an
            ``np.ndarray``.
        ValueError: the image mode, dtype, or shape violates the RGB uint8
            contract.
    """
    # Reject unsupported container types up front.
    if not isinstance(img, (Image.Image, np.ndarray)):
        raise TypeError(
            "compute_mm_uuids_from_images expected PIL.Image.Image or np.ndarray, "
            f"got {type(img).__name__}"
        )

    if isinstance(img, Image.Image):
        if img.mode != "RGB":
            raise ValueError(
                f"compute_mm_uuids_from_images expected RGB mode, got {img.mode!r}"
            )
        # PIL reports size as (width, height); the header wants (height, width).
        cols, rows = img.size
        return _header(rows, cols), img.tobytes()

    # ndarray path. The `and` chain short-circuits, so shape[2] is only read
    # once ndim == 3 has been established.
    is_rgb_uint8 = (
        img.dtype == np.uint8 and img.ndim == 3 and img.shape[2] == _CHANNELS_RGB
    )
    if not is_rgb_uint8:
        raise ValueError(
            "compute_mm_uuids_from_images expected dtype=uint8 and shape "
            f"(H, W, 3), got dtype={img.dtype} shape={img.shape}"
        )
    # Normalize the memory layout before serializing the pixels.
    packed = np.ascontiguousarray(img)
    rows, cols = packed.shape[:2]
    return _header(rows, cols), packed.tobytes()
def compute_mm_uuids_from_images(images: Sequence[ImageInput]) -> list[str]:
    """Return one blake3 hex digest per input image, in input order.

    The preimage for each image is a fixed-length header (version, mode,
    dtype, channels, height, width) followed by the raw RGB uint8 pixel
    bytes. Hashing the geometry along with the pixels keeps two
    different-shape images with the same pixel bytes from sharing a UUID.

    Args:
        images: Images to hash; each must satisfy the contract enforced by
            ``_image_preimage_parts``.

    Returns:
        A list of hex digest strings, one per image.
    """

    def _digest(image: ImageInput) -> str:
        # Feed header then pixels to the hasher; this avoids concatenating
        # them into one large temporary buffer.
        header_bytes, pixel_bytes = _image_preimage_parts(image)
        hasher = blake3.blake3(header_bytes)
        hasher.update(pixel_bytes)
        return hasher.hexdigest()

    return [_digest(image) for image in images]
# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
"""Unit tests for dynamo.vllm.multimodal_utils.hash_utils.
The hash preimage must include image geometry; otherwise two RGB images with
different (W, H) but equal pixel count produce identical cache keys. These
tests also pin the RGB uint8 canonicalization contract and the on-disk
preimage format via a known-digest stability anchor.
"""
import numpy as np
import pytest
from PIL import Image
from dynamo.vllm.multimodal_utils.hash_utils import compute_mm_uuids_from_images
# Module-level marker list: pytest applies every marker in ``pytestmark`` to
# all tests collected from this file.
# NOTE(review): the meaning of each marker (e.g. gpu_0 presumably selects runs
# that need no GPU) is defined by the project's pytest configuration — confirm
# there.
pytestmark = [
    pytest.mark.pre_merge,
    pytest.mark.vllm,
    pytest.mark.gpu_0,
    pytest.mark.multimodal,
]
def _pil_from_buffer(h, w, buf):
    """Build an RGB PIL image of size (w, h) from a flat pixel buffer."""
    return Image.frombytes("RGB", (w, h), buf)


def _ndarray_from_buffer(h, w, buf):
    """View a flat pixel buffer as an (h, w, 3) uint8 array."""
    return np.frombuffer(buf, dtype=np.uint8).reshape(h, w, 3)


@pytest.mark.parametrize(
    "make_image",
    [
        pytest.param(_pil_from_buffer, id="pil"),
        pytest.param(_ndarray_from_buffer, id="ndarray"),
    ],
)
def test_dimension_swap_no_collision(make_image):
    """Swapping (W, H) over the same flat pixel buffer must change the UUID.

    Raw pixel bytes carry no geometry, so the preimage has to include the
    dimensions explicitly. Parametrized over both the PIL input path (URL
    decode) and the ndarray path (NIXL / Rust decoder).
    """
    n_bytes = 30 * 150 * 3
    # Repeat a 0..255 ramp until it covers the image, then trim to size.
    ramp = bytes(range(256))
    buf = (ramp * (n_bytes // 256 + 1))[:n_bytes]

    wide = make_image(30, 150, buf)
    tall = make_image(150, 30, buf)

    (wide_uuid,) = compute_mm_uuids_from_images([wide])
    (tall_uuid,) = compute_mm_uuids_from_images([tall])

    assert wide_uuid != tall_uuid
def _grayscale_pil():
    """PIL image in mode "L" — wrong mode for the RGB contract."""
    return Image.new("L", (8, 8))


def _float32_rgb_array():
    """(8, 8, 3) array with the wrong dtype."""
    return np.zeros((8, 8, 3), dtype=np.float32)


def _four_channel_array():
    """uint8 array with four channels instead of three."""
    return np.zeros((8, 8, 4), dtype=np.uint8)


def _raw_bytes():
    """Plain bytes object — not an accepted image container."""
    return b"\x00" * (8 * 8 * 3)


@pytest.mark.parametrize(
    "bad_input, exc",
    [
        pytest.param(_grayscale_pil, ValueError, id="pil_mode_L"),
        pytest.param(_float32_rgb_array, ValueError, id="ndarray_dtype_float32"),
        pytest.param(_four_channel_array, ValueError, id="ndarray_shape_4ch"),
        pytest.param(_raw_bytes, TypeError, id="bytes"),
    ],
)
def test_rejects_invalid_input(bad_input, exc):
    """Anything outside the RGB uint8 (H, W, 3) contract must raise.

    Failing loudly before any hashing happens is preferable to silently
    producing a colliding key.
    """
    # The factory is invoked inside the context manager, as in the original.
    with pytest.raises(exc):
        compute_mm_uuids_from_images([bad_input()])
def test_known_digest_stability():
    """Pin the digest of a fixed 8x4 RGB gradient.

    An unintentional change to the preimage layout makes this test fail. An
    intentional change should bump the preimage version byte and update the
    pinned digest in the same commit.
    """
    rows, cols = 4, 8
    gradient = np.zeros((rows, cols, 3), dtype=np.uint8)
    # Fill pixel by pixel in row-major order; all channel values stay < 256.
    for y, x in np.ndindex(rows, cols):
        gradient[y, x] = (x * 16, y * 32, (x + y) * 8)

    (digest,) = compute_mm_uuids_from_images([gradient])

    expected = "1a53ddd0d1539154841e71befde56e9d90661e41b2256223f9ab9ed3fc7c02d5"
    assert digest == expected
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment