Unverified Commit 91b361ae authored by Roger Wang's avatar Roger Wang Committed by GitHub
Browse files

[V1] Extend beyond image modality and support mixed-modality inference with...


[V1] Extend beyond image modality and support mixed-modality inference with Llava-OneVision (#11685)
Signed-off-by: default avatarRoger Wang <ywang@roblox.com>
Signed-off-by: default avatarDarkLight1337 <tlleungac@connect.ust.hk>
Co-authored-by: default avatarDarkLight1337 <tlleungac@connect.ust.hk>
parent e20c92bb
...@@ -647,7 +647,7 @@ See [this page](#generative-models) for more information on how to use generativ ...@@ -647,7 +647,7 @@ See [this page](#generative-models) for more information on how to use generativ
- `llava-hf/llava-onevision-qwen2-7b-ov-hf`, `llava-hf/llava-onevision-qwen2-0.5b-ov-hf`, etc. - `llava-hf/llava-onevision-qwen2-7b-ov-hf`, `llava-hf/llava-onevision-qwen2-0.5b-ov-hf`, etc.
- -
- ✅︎ - ✅︎
- - ✅︎
* - `MiniCPMV` * - `MiniCPMV`
- MiniCPM-V - MiniCPM-V
- T + I<sup>E+</sup> - T + I<sup>E+</sup>
......
...@@ -2,16 +2,22 @@ import base64 ...@@ -2,16 +2,22 @@ import base64
import mimetypes import mimetypes
import os import os
from tempfile import NamedTemporaryFile, TemporaryDirectory from tempfile import NamedTemporaryFile, TemporaryDirectory
from typing import Dict, Tuple from typing import TYPE_CHECKING, Dict, NamedTuple, Optional, Tuple
import numpy as np import numpy as np
import pytest import pytest
from PIL import Image, ImageChops from PIL import Image, ImageChops
from transformers import AutoConfig, AutoTokenizer from transformers import AutoConfig, AutoTokenizer
from vllm.multimodal.inputs import PlaceholderRange
from vllm.multimodal.utils import (MediaConnector, from vllm.multimodal.utils import (MediaConnector,
merge_and_sort_multimodal_metadata,
repeat_and_pad_placeholder_tokens) repeat_and_pad_placeholder_tokens)
if TYPE_CHECKING:
from vllm.multimodal.hasher import MultiModalHashDict
from vllm.multimodal.inputs import MultiModalPlaceholderDict
# Test different image extensions (JPG/PNG) and formats (gray/RGB/RGBA) # Test different image extensions (JPG/PNG) and formats (gray/RGB/RGBA)
TEST_IMAGE_URLS = [ TEST_IMAGE_URLS = [
"https://upload.wikimedia.org/wikipedia/commons/thumb/d/dd/Gfp-wisconsin-madison-the-nature-boardwalk.jpg/2560px-Gfp-wisconsin-madison-the-nature-boardwalk.jpg", "https://upload.wikimedia.org/wikipedia/commons/thumb/d/dd/Gfp-wisconsin-madison-the-nature-boardwalk.jpg/2560px-Gfp-wisconsin-madison-the-nature-boardwalk.jpg",
...@@ -191,3 +197,204 @@ def test_repeat_and_pad_placeholder_tokens(model): ...@@ -191,3 +197,204 @@ def test_repeat_and_pad_placeholder_tokens(model):
assert new_prompt == expected_prompt assert new_prompt == expected_prompt
assert new_token_ids == expected_token_ids assert new_token_ids == expected_token_ids
assert ranges == expected_ranges assert ranges == expected_ranges
# Used for the next two tests related to `merge_and_sort_multimodal_metadata`.
class TestCase(NamedTuple):
    """One scenario for the `merge_and_sort_multimodal_metadata` tests.

    The first two fields are the arguments handed to the function under
    test; the `expected_*` fields are what its three return values are
    compared against.
    """

    # The name matches pytest's default `Test*` class pattern, but this is a
    # plain data container (and NamedTuple defines `__new__`), so opt out of
    # collection explicitly instead of triggering a collection warning.
    __test__ = False

    # Placeholder ranges per modality, passed as the first argument.
    mm_positions: "MultiModalPlaceholderDict"
    # Hashes per modality, or None when no hashes are supplied.
    mm_hashes: Optional["MultiModalHashDict"]
    # Modalities in the order the function is expected to return them.
    expected_modalities: list[str]
    # Flattened placeholder ranges in the expected order.
    expected_ranges: list["PlaceholderRange"]
    # Flattened hashes in the expected order, or None.
    expected_hashes: Optional[list[str]]
def test_merge_and_sort_multimodal_metadata():
    """Check merging/sorting for prompts without interleaved modalities."""

    def span(offset, length):
        # Shorthand for building a placeholder range.
        return PlaceholderRange(offset=offset, length=length)

    scenarios = [
        # One modality with hashes: returned as is, just flattened.
        TestCase(
            mm_positions={"image": [span(0, 2), span(3, 2)]},
            mm_hashes={"image": ["hash1", "hash2"]},
            expected_modalities=["image"],
            expected_ranges=[span(0, 2), span(3, 2)],
            expected_hashes=["hash1", "hash2"],
        ),
        # One modality without hashes: the hash result is None.
        TestCase(
            mm_positions={"image": [span(0, 2), span(2, 2)]},
            mm_hashes=None,
            expected_modalities=["image"],
            expected_ranges=[span(0, 2), span(2, 2)],
            expected_hashes=None,
        ),
        # Two modalities with hashes: modalities are sorted, ranges and
        # hashes are flattened in that order.
        TestCase(
            mm_positions={
                "image": [span(7, 4), span(11, 5)],
                "audio": [span(0, 2), span(2, 3)],
            },
            mm_hashes={
                "image": ["image_hash1", "image_hash2"],
                "audio": ["audio_hash1", "audio_hash2"],
            },
            expected_modalities=["audio", "image"],
            expected_ranges=[
                span(0, 2),
                span(2, 3),
                span(7, 4),
                span(11, 5),
            ],
            expected_hashes=[
                "audio_hash1",
                "audio_hash2",
                "image_hash1",
                "image_hash2",
            ],
        ),
        # Two modalities without hashes: sorted modalities, flattened
        # ranges, and None for the hashes.
        TestCase(
            mm_positions={
                "image": [span(7, 4), span(11, 5)],
                "audio": [span(0, 2), span(2, 3)],
            },
            mm_hashes=None,
            expected_modalities=["audio", "image"],
            expected_ranges=[
                span(0, 2),
                span(2, 3),
                span(7, 4),
                span(11, 5),
            ],
            expected_hashes=None,
        ),
        # Three modalities.
        TestCase(
            mm_positions={
                "image": [span(15, 7), span(22, 8)],
                "audio": [span(0, 2)],
                "video": [span(3, 4), span(7, 5), span(12, 6)],
            },
            mm_hashes={
                "image": ["image_hash1", "image_hash2"],
                "audio": ["audio_hash1"],
                "video": ["video_hash1", "video_hash2", "video_hash3"],
            },
            expected_modalities=["audio", "video", "image"],
            expected_ranges=[
                span(0, 2),
                span(3, 4),
                span(7, 5),
                span(12, 6),
                span(15, 7),
                span(22, 8),
            ],
            expected_hashes=[
                "audio_hash1",
                "video_hash1",
                "video_hash2",
                "video_hash3",
                "image_hash1",
                "image_hash2",
            ],
        ),
    ]

    for scenario in scenarios:
        result = merge_and_sort_multimodal_metadata(scenario.mm_positions,
                                                    scenario.mm_hashes)
        modalities, ranges, hashes = result

        assert modalities == scenario.expected_modalities
        assert ranges == scenario.expected_ranges
        assert hashes == scenario.expected_hashes
def test_merge_and_sort_multimodal_metadata_with_interleaving():
    """Interleaved placeholders of different modalities must be rejected."""

    def span(offset, length):
        # Shorthand for building a placeholder range.
        return PlaceholderRange(offset=offset, length=length)

    # Prompt layout: <image> <audio> <image> <audio>
    alternating = TestCase(
        mm_positions={
            "image": [span(0, 4), span(8, 2)],
            "audio": [span(5, 2), span(11, 4)],
        },
        mm_hashes={
            "image": ["image_hash1", "image_hash2"],
            "audio": ["audio_hash1", "audio_hash2"],
        },
        expected_modalities=[],
        expected_ranges=[],
        expected_hashes=None,
    )

    # Prompt layout: <image> <image> <video> <audio> <image>
    trailing_image = TestCase(
        mm_positions={
            "image": [span(0, 2), span(2, 3), span(20, 4)],
            "audio": [span(5, 2)],
            "video": [span(8, 5)],
        },
        mm_hashes=None,
        expected_modalities=[],
        expected_ranges=[],
        expected_hashes=None,
    )

    # The expected_* fields are unused here: every scenario must raise.
    for scenario in (alternating, trailing_image):
        with pytest.raises(ValueError, match="Interleaved mixed-modality"):
            merge_and_sort_multimodal_metadata(scenario.mm_positions,
                                               scenario.mm_hashes)
import pytest import pytest
from vllm.inputs import token_inputs from vllm.multimodal.inputs import MultiModalKwargs
from vllm.sampling_params import SamplingParams from vllm.sampling_params import SamplingParams
from vllm.v1.core.kv_cache_utils import (BlockHashType, FreeKVCacheBlockQueue, from vllm.v1.core.kv_cache_utils import (BlockHashType, FreeKVCacheBlockQueue,
KVCacheBlock, KVCacheBlock,
...@@ -14,14 +14,18 @@ def make_request(request_id, ...@@ -14,14 +14,18 @@ def make_request(request_id,
prompt_token_ids, prompt_token_ids,
mm_positions=None, mm_positions=None,
mm_hashes=None): mm_hashes=None):
if mm_positions is None:
multi_modal_inputs = None
else:
multi_modal_inputs = [MultiModalKwargs({})] * len(mm_positions)
return Request( return Request(
request_id=request_id, request_id=request_id,
inputs=token_inputs( prompt=None,
prompt_token_ids=prompt_token_ids, prompt_token_ids=prompt_token_ids,
multi_modal_placeholders={"image": mm_positions} multi_modal_inputs=multi_modal_inputs,
if mm_positions else None, multi_modal_hashes=mm_hashes,
multi_modal_hashes=mm_hashes, multi_modal_placeholders=mm_positions,
),
sampling_params=SamplingParams(max_tokens=17), sampling_params=SamplingParams(max_tokens=17),
eos_token_id=100, eos_token_id=100,
arrival_time=0, arrival_time=0,
......
"""Compare the with and without prefix caching.""" """Compare the with and without prefix caching."""
import pytest import pytest
from vllm.inputs import token_inputs from vllm.multimodal.inputs import MultiModalKwargs, PlaceholderRange
from vllm.multimodal.inputs import PlaceholderRange
from vllm.sampling_params import SamplingParams from vllm.sampling_params import SamplingParams
from vllm.utils import cdiv from vllm.utils import cdiv
from vllm.v1.core.kv_cache_manager import KVCacheManager, Request from vllm.v1.core.kv_cache_manager import KVCacheManager, Request
...@@ -13,12 +12,18 @@ def make_request(request_id, ...@@ -13,12 +12,18 @@ def make_request(request_id,
prompt_token_ids, prompt_token_ids,
mm_positions=None, mm_positions=None,
mm_hashes=None): mm_hashes=None):
if mm_positions is None:
multi_modal_inputs = None
else:
multi_modal_inputs = [MultiModalKwargs({})] * len(mm_positions)
return Request( return Request(
request_id=request_id, request_id=request_id,
inputs=token_inputs(prompt_token_ids=prompt_token_ids, prompt=None,
multi_modal_placeholders={"image": mm_positions} prompt_token_ids=prompt_token_ids,
if mm_positions else None, multi_modal_inputs=multi_modal_inputs,
multi_modal_hashes=mm_hashes), multi_modal_hashes=mm_hashes,
multi_modal_placeholders=mm_positions,
sampling_params=SamplingParams(max_tokens=17), sampling_params=SamplingParams(max_tokens=17),
eos_token_id=100, eos_token_id=100,
arrival_time=0, arrival_time=0,
......
...@@ -39,8 +39,12 @@ class SupportsMultiModal(Protocol): ...@@ -39,8 +39,12 @@ class SupportsMultiModal(Protocol):
The output embeddings must be one of the following formats: The output embeddings must be one of the following formats:
- A list or tuple of 2D tensors, where each tensor corresponds to - A list or tuple of 2D tensors, where each tensor corresponds to
each input image. each input multimodal data item (e.g., image).
- A single 3D tensor, with the batch dimension grouping the 2D tensors. - A single 3D tensor, with the batch dimension grouping the 2D tensors.
NOTE: The returned multimodal embeddings must be in the same order as
the appearances of their corresponding multimodal data item in the
input prompt.
""" """
... ...
......
...@@ -35,6 +35,9 @@ from .siglip import SiglipVisionModel ...@@ -35,6 +35,9 @@ from .siglip import SiglipVisionModel
from .utils import (AutoWeightsLoader, flatten_bn, init_vllm_registered_model, from .utils import (AutoWeightsLoader, flatten_bn, init_vllm_registered_model,
maybe_prefix, merge_multimodal_embeddings) maybe_prefix, merge_multimodal_embeddings)
# For profile run
_MAX_FRAMES_PER_VIDEO = 16
class LlavaOnevisionVideoPixelInputs(TypedDict): class LlavaOnevisionVideoPixelInputs(TypedDict):
type: Literal["pixel_values_videos"] type: Literal["pixel_values_videos"]
...@@ -223,8 +226,10 @@ class LlavaOnevisionProfilingInfo(LlavaOnevisionProcessingMixin, ...@@ -223,8 +226,10 @@ class LlavaOnevisionProfilingInfo(LlavaOnevisionProcessingMixin,
max_image_tokens = self._get_max_image_tokens() * max_images max_image_tokens = self._get_max_image_tokens() * max_images
max_total_frames = self._get_max_video_frames(seq_len - max_total_frames = self._get_max_video_frames(seq_len -
max_image_tokens) max_image_tokens)
max_frames_per_video = min(max_total_frames // max(max_videos, 1),
_MAX_FRAMES_PER_VIDEO)
return max(max_total_frames // max(max_videos, 1), 1) return max(max_frames_per_video, 1)
def _get_max_video_tokens(self, seq_len: int) -> int: def _get_max_video_tokens(self, seq_len: int) -> int:
target_width, target_height = self._get_image_size_with_most_features() target_width, target_height = self._get_image_size_with_most_features()
...@@ -558,13 +563,15 @@ class LlavaOnevisionForConditionalGeneration(nn.Module, SupportsMultiModal, ...@@ -558,13 +563,15 @@ class LlavaOnevisionForConditionalGeneration(nn.Module, SupportsMultiModal,
def _parse_and_validate_multimodal_inputs(self, **kwargs: object) -> dict: def _parse_and_validate_multimodal_inputs(self, **kwargs: object) -> dict:
modalities = {} modalities = {}
if "pixel_values" in kwargs: # Preserve the order of modalities if there are multiple of them
modalities["images"] = self._parse_and_validate_image_input( # from the order of kwargs.
**kwargs) for input_key in kwargs:
if input_key == "pixel_values" and "images" not in modalities:
if "pixel_values_videos" in kwargs: modalities["images"] = self._parse_and_validate_image_input(
modalities["videos"] = self._parse_and_validate_video_input( **kwargs)
**kwargs) if input_key == "pixel_values_videos" and "videos" not in modalities: # noqa E501
modalities["videos"] = self._parse_and_validate_video_input(
**kwargs)
return modalities return modalities
...@@ -824,21 +831,21 @@ class LlavaOnevisionForConditionalGeneration(nn.Module, SupportsMultiModal, ...@@ -824,21 +831,21 @@ class LlavaOnevisionForConditionalGeneration(nn.Module, SupportsMultiModal,
if not modalities: if not modalities:
return None return None
# We make a tuple of each embedding with its modality string. This is a # The result multimodal_embeddings is tuple of tensors, with each
# temporary workaround for models to handle mixed modalities when # tensor corresponding to a multimodal data item (image or video).
# get_multimodal_embeddings and get_input_embeddings are called multimodal_embeddings: tuple[torch.Tensor, ...] = ()
# separately.
# TODO(ywang96): Add support for mixed-modality inference for v1. # NOTE: It is important to iterate over the keys in this dictionary
multimodal_embeddings: List[Tuple[NestedTensors, str]] = [] # to preserve the order of the modalities.
for modality in modalities:
if "images" in modalities: if modality == "images":
image_input = modalities["images"] image_input = modalities["images"]
vision_embeddings = self._process_image_input(image_input) vision_embeddings = self._process_image_input(image_input)
multimodal_embeddings.append((vision_embeddings, "image")) multimodal_embeddings += tuple(vision_embeddings)
if "videos" in modalities: if modality == "videos":
video_input = modalities["videos"] video_input = modalities["videos"]
video_embeddings = self._process_video_pixels(video_input) video_embeddings = self._process_video_pixels(video_input)
multimodal_embeddings.append((video_embeddings, "video")) multimodal_embeddings += tuple(video_embeddings)
return multimodal_embeddings return multimodal_embeddings
...@@ -850,15 +857,9 @@ class LlavaOnevisionForConditionalGeneration(nn.Module, SupportsMultiModal, ...@@ -850,15 +857,9 @@ class LlavaOnevisionForConditionalGeneration(nn.Module, SupportsMultiModal,
) -> torch.Tensor: ) -> torch.Tensor:
inputs_embeds = self.language_model.get_input_embeddings(input_ids) inputs_embeds = self.language_model.get_input_embeddings(input_ids)
if multimodal_embeddings is not None: if multimodal_embeddings is not None:
for embeddings, modality in multimodal_embeddings: inputs_embeds = merge_multimodal_embeddings(
if modality == "image": input_ids, inputs_embeds, multimodal_embeddings,
inputs_embeds = merge_multimodal_embeddings( [self.config.image_token_index, self.config.video_token_index])
input_ids, inputs_embeds, embeddings,
self.config.image_token_index)
if modality == "video":
inputs_embeds = merge_multimodal_embeddings(
input_ids, inputs_embeds, embeddings,
self.config.video_token_index)
return inputs_embeds return inputs_embeds
def forward( def forward(
......
...@@ -972,8 +972,6 @@ def image_input_mapper_for_molmo( ...@@ -972,8 +972,6 @@ def image_input_mapper_for_molmo(
assert len(data) == 1, "Molmo supports only one image per prompt." assert len(data) == 1, "Molmo supports only one image per prompt."
data = data[0] data = data[0]
# Remove unused dummy PIL image
data.pop('raw_mm_data', None)
return MultiModalKwargs(data) return MultiModalKwargs(data)
...@@ -1019,7 +1017,6 @@ def dummy_data_for_molmo(ctx: InputContext, seq_len: int, ...@@ -1019,7 +1017,6 @@ def dummy_data_for_molmo(ctx: InputContext, seq_len: int,
dummy_imgdata = { dummy_imgdata = {
"images": out["images"], "images": out["images"],
"image_input_idx": out["image_input_idx"], "image_input_idx": out["image_input_idx"],
"raw_mm_data": dummy_image,
} }
if "image_masks" in out: if "image_masks" in out:
dummy_imgdata["image_masks"] = out["image_masks"] dummy_imgdata["image_masks"] = out["image_masks"]
......
from .base import MultiModalPlaceholderMap, MultiModalPlugin from .base import MultiModalPlaceholderMap, MultiModalPlugin
from .hasher import MultiModalHashDict, MultiModalHasher
from .inputs import (BatchedTensorInputs, ModalityData, MultiModalDataBuiltins, from .inputs import (BatchedTensorInputs, ModalityData, MultiModalDataBuiltins,
MultiModalDataDict, MultiModalKwargs, MultiModalDataDict, MultiModalKwargs,
MultiModalPlaceholderDict, NestedTensors) MultiModalPlaceholderDict, NestedTensors)
...@@ -18,6 +19,8 @@ __all__ = [ ...@@ -18,6 +19,8 @@ __all__ = [
"ModalityData", "ModalityData",
"MultiModalDataBuiltins", "MultiModalDataBuiltins",
"MultiModalDataDict", "MultiModalDataDict",
"MultiModalHashDict",
"MultiModalHasher",
"MultiModalKwargs", "MultiModalKwargs",
"MultiModalPlaceholderDict", "MultiModalPlaceholderDict",
"MultiModalPlaceholderMap", "MultiModalPlaceholderMap",
......
import pickle
from typing import TYPE_CHECKING, Iterable, Mapping, Optional
import numpy as np
import torch
from blake3 import blake3
from PIL import Image
from vllm.logger import init_logger
if TYPE_CHECKING:
from vllm.inputs import TokensPrompt
logger = init_logger(__name__)
MultiModalHashDict = Mapping[str, list[str]]
"""
A dictionary containing hashes for items in each modality.
"""
class MultiModalHasher:
    """Class-method helpers that turn multi-modal inputs into hex digests."""

    @classmethod
    def serialize_item(cls, obj: object) -> bytes:
        """Convert a single leaf value into bytes for hashing.

        Strings are UTF-8 encoded and bytes are returned unchanged. Tensors,
        ints, floats and NumPy arrays are serialized through
        `ndarray.tobytes()`, and PIL images through `Image.tobytes()`.
        Anything else is pickled after logging a warning.
        """
        # Simple cases
        if isinstance(obj, str):
            return obj.encode("utf-8")
        if isinstance(obj, bytes):
            return obj

        # Convertible to NumPy arrays
        if isinstance(obj, torch.Tensor):
            # A bare `.numpy()` raises for tensors that require grad or live
            # on a non-CPU device; detach and copy to host first. For plain
            # CPU tensors this yields the same bytes as before.
            obj = obj.detach().cpu().numpy()
        if isinstance(obj, (int, float)):
            obj = np.array(obj)
        if isinstance(obj, np.ndarray):
            # NOTE(review): tobytes() emits only the raw element data, so
            # arrays with identical data but different shape/dtype serialize
            # to the same bytes - confirm this is acceptable for callers.
            return obj.tobytes()

        if isinstance(obj, Image.Image):
            return obj.tobytes()

        logger.warning(
            "No serialization method found for %s. "
            "Falling back to pickle.", type(obj))

        return pickle.dumps(obj)

    @classmethod
    def item_to_bytes(
        cls,
        key: str,
        obj: object,
    ) -> Iterable[tuple[bytes, bytes]]:
        """Yield `(key_bytes, value_bytes)` pairs for every leaf of `obj`.

        Lists, tuples and dicts are walked recursively; each nested element
        gets a dotted key built from `key` plus its index or dict key.
        """
        # Recursive cases
        if isinstance(obj, (list, tuple)):
            for i, elem in enumerate(obj):
                yield from cls.item_to_bytes(f"{key}.{i}", elem)
        elif isinstance(obj, dict):
            for k, v in obj.items():
                yield from cls.item_to_bytes(f"{key}.{k}", v)
        else:
            key_bytes = cls.serialize_item(key)
            value_bytes = cls.serialize_item(obj)
            yield key_bytes, value_bytes

    @classmethod
    def hash_kwargs(cls, **kwargs: object) -> str:
        """Return the blake3 hex digest of all keyword arguments.

        The digest is fed in keyword iteration order, so the same values
        passed in a different order produce a different hash.
        """
        hasher = blake3()

        for k, v in kwargs.items():
            for k_bytes, v_bytes in cls.item_to_bytes(k, v):
                hasher.update(k_bytes)
                hasher.update(v_bytes)

        return hasher.hexdigest()

    @classmethod
    def hash_prompt_mm_data(
            cls, prompt: "TokensPrompt") -> Optional["MultiModalHashDict"]:
        """Hash multimodal data in the user input prompt if they exist.

        Returns None when the prompt has no "multi_modal_data" entry or the
        entry is empty; otherwise a dict mapping each modality to one hash
        per item.
        """
        if "multi_modal_data" not in prompt:
            return None

        mm_data = prompt["multi_modal_data"]
        if not mm_data:
            # mm_data can be None or an empty dict.
            return None

        # Normalize every modality to a list so single items and lists of
        # items are hashed the same way.
        mm_items = {
            modality: items if isinstance(items, list) else [items]
            for modality, items in mm_data.items()
        }

        mm_hashes = {
            modality: [cls.hash_kwargs(**{modality: item}) for item in items]
            for modality, items in mm_items.items()
        }

        return mm_hashes
...@@ -2,8 +2,8 @@ from abc import ABC, abstractmethod ...@@ -2,8 +2,8 @@ from abc import ABC, abstractmethod
from collections import UserDict, defaultdict from collections import UserDict, defaultdict
from collections.abc import Mapping, Sequence from collections.abc import Mapping, Sequence
from dataclasses import dataclass from dataclasses import dataclass
from typing import (Any, Literal, Optional, TypedDict, TypeVar, Union, cast, from typing import (TYPE_CHECKING, Any, Literal, Optional, TypedDict, TypeVar,
final) Union, cast, final)
import numpy as np import numpy as np
import torch import torch
...@@ -14,6 +14,9 @@ from typing_extensions import NotRequired, TypeAlias ...@@ -14,6 +14,9 @@ from typing_extensions import NotRequired, TypeAlias
from vllm.utils import JSONTree, full_groupby, is_list_of, json_map_leaves from vllm.utils import JSONTree, full_groupby, is_list_of, json_map_leaves
if TYPE_CHECKING:
from .hasher import MultiModalHashDict
_T = TypeVar("_T") _T = TypeVar("_T")
HfImageItem: TypeAlias = Union[Image, np.ndarray, torch.Tensor] HfImageItem: TypeAlias = Union[Image, np.ndarray, torch.Tensor]
...@@ -513,7 +516,7 @@ class MultiModalInputsV2(TypedDict): ...@@ -513,7 +516,7 @@ class MultiModalInputsV2(TypedDict):
mm_kwargs: MultiModalKwargs mm_kwargs: MultiModalKwargs
"""Keyword arguments to be directly passed to the model after batching.""" """Keyword arguments to be directly passed to the model after batching."""
mm_hashes: NotRequired[list[str]] mm_hashes: NotRequired[Optional["MultiModalHashDict"]]
"""The hashes of the multi-modal data.""" """The hashes of the multi-modal data."""
mm_placeholders: MultiModalPlaceholderDict mm_placeholders: MultiModalPlaceholderDict
......
import pickle
import re import re
from abc import ABC, abstractmethod from abc import ABC, abstractmethod
from collections import defaultdict from collections import defaultdict
...@@ -7,18 +6,16 @@ from dataclasses import dataclass, field ...@@ -7,18 +6,16 @@ from dataclasses import dataclass, field
from functools import lru_cache from functools import lru_cache
from typing import Any, NamedTuple, Optional, Protocol, TypeVar, Union from typing import Any, NamedTuple, Optional, Protocol, TypeVar, Union
import numpy as np
import torch
from blake3 import blake3
from PIL import Image
from transformers import BatchFeature, PretrainedConfig, ProcessorMixin from transformers import BatchFeature, PretrainedConfig, ProcessorMixin
from vllm import envs
from vllm.inputs import DummyData, InputProcessingContext from vllm.inputs import DummyData, InputProcessingContext
from vllm.logger import init_logger from vllm.logger import init_logger
from vllm.transformers_utils.tokenizer import (AnyTokenizer, decode_tokens, from vllm.transformers_utils.tokenizer import (AnyTokenizer, decode_tokens,
encode_tokens) encode_tokens)
from vllm.utils import LRUCache, flatten_2d_lists, full_groupby from vllm.utils import LRUCache, flatten_2d_lists, full_groupby
from .hasher import MultiModalHasher
from .inputs import (MultiModalDataDict, MultiModalFieldConfig, from .inputs import (MultiModalDataDict, MultiModalFieldConfig,
MultiModalInputsV2, MultiModalKwargs, MultiModalInputsV2, MultiModalKwargs,
MultiModalKwargsItem, PlaceholderRange) MultiModalKwargsItem, PlaceholderRange)
...@@ -486,56 +483,6 @@ class ProcessingCache: ...@@ -486,56 +483,6 @@ class ProcessingCache:
logger.debug("ProcessingCache: hit_ratio = %.2f", logger.debug("ProcessingCache: hit_ratio = %.2f",
cache_stats.hit_ratio) cache_stats.hit_ratio)
def _serialize_item(self, obj: object) -> bytes:
# Simple cases
if isinstance(obj, str):
return obj.encode("utf-8")
if isinstance(obj, bytes):
return obj
if isinstance(obj, Image.Image):
return obj.tobytes()
# Convertible to NumPy arrays
if isinstance(obj, torch.Tensor):
obj = obj.numpy()
if isinstance(obj, (int, float)):
obj = np.array(obj)
if isinstance(obj, np.ndarray):
return obj.tobytes()
logger.warning(
"No serialization method found for %s. "
"Falling back to pickle.", type(obj))
return pickle.dumps(obj)
def _item_to_bytes(
self,
key: str,
obj: object,
) -> Iterable[tuple[bytes, bytes]]:
# Recursive cases
if isinstance(obj, (list, tuple)):
for i, elem in enumerate(obj):
yield from self._item_to_bytes(f"{key}.{i}", elem)
elif isinstance(obj, dict):
for k, v in obj.items():
yield from self._item_to_bytes(f"{key}.{k}", v)
else:
key_bytes = self._serialize_item(key)
value_bytes = self._serialize_item(obj)
yield key_bytes, value_bytes
def _hash_kwargs(self, **kwargs: object) -> str:
    """Return the blake3 hex digest of all keyword arguments.

    Each leaf's key bytes are fed to the hasher before its value bytes, in
    keyword iteration order.
    """
    digest = blake3()

    for name, value in kwargs.items():
        for pair in self._item_to_bytes(name, value):
            # `pair` is (key_bytes, value_bytes); feed both in that order.
            for chunk in pair:
                digest.update(chunk)

    return digest.hexdigest()
def get( def get(
self, self,
model_id: str, model_id: str,
...@@ -554,9 +501,9 @@ class ProcessingCache: ...@@ -554,9 +501,9 @@ class ProcessingCache:
""" """
self._maybe_log_cache_stats() self._maybe_log_cache_stats()
cache_key = self._hash_kwargs(model_id=model_id, cache_key = MultiModalHasher.hash_kwargs(model_id=model_id,
**{modality: input_item}, **{modality: input_item},
**input_kwargs) **input_kwargs)
return self._cache.get(cache_key) return self._cache.get(cache_key)
def put( def put(
...@@ -571,9 +518,9 @@ class ProcessingCache: ...@@ -571,9 +518,9 @@ class ProcessingCache:
Put a processed multi-modal item into the cache Put a processed multi-modal item into the cache
according to its dependencies (see :meth:`get`). according to its dependencies (see :meth:`get`).
""" """
cache_key = self._hash_kwargs(model_id=model_id, cache_key = MultiModalHasher.hash_kwargs(model_id=model_id,
**{modality: input_item}, **{modality: input_item},
**input_kwargs) **input_kwargs)
self._cache.put(cache_key, output_kwargs) self._cache.put(cache_key, output_kwargs)
...@@ -1049,6 +996,24 @@ class BaseMultiModalProcessor(ProcessingMixin, ABC): ...@@ -1049,6 +996,24 @@ class BaseMultiModalProcessor(ProcessingMixin, ABC):
""" """
mm_items = self._to_mm_items(mm_data) mm_items = self._to_mm_items(mm_data)
# Create MM hashes (only used in V1)
# TODO: Use these hash keys for caching operations in apply_hf_processor
# instead of rehashing.
if envs.VLLM_USE_V1:
model_id = self.ctx.model_config.model
mm_hashes = {
modality: [
MultiModalHasher.hash_kwargs(model_id=model_id,
**{modality: item},
**hf_processor_mm_kwargs)
for item in items
]
for modality, items in mm_items.items()
}
else:
mm_hashes = None
prompt_ids, mm_kwargs = self._cached_apply_hf_processor( prompt_ids, mm_kwargs = self._cached_apply_hf_processor(
prompt_text, prompt_text,
mm_items, mm_items,
...@@ -1122,6 +1087,7 @@ class BaseMultiModalProcessor(ProcessingMixin, ABC): ...@@ -1122,6 +1087,7 @@ class BaseMultiModalProcessor(ProcessingMixin, ABC):
prompt=prompt_text, prompt=prompt_text,
prompt_token_ids=prompt_ids, prompt_token_ids=prompt_ids,
mm_kwargs=mm_kwargs, mm_kwargs=mm_kwargs,
mm_hashes=mm_hashes,
mm_placeholders=mm_placeholder_ranges, mm_placeholders=mm_placeholder_ranges,
) )
...@@ -1174,7 +1140,9 @@ class BaseMultiModalProcessor(ProcessingMixin, ABC): ...@@ -1174,7 +1140,9 @@ class BaseMultiModalProcessor(ProcessingMixin, ABC):
"tokens.") "tokens.")
total_len = len(prompt_token_ids) total_len = len(prompt_token_ids)
if total_len > seq_len:
# V0 does not support chunked prefill.
if total_len > seq_len and not envs.VLLM_USE_V1:
logger.warning( logger.warning(
"The context length (%d) of the model is too short " "The context length (%d) of the model is too short "
"to hold the multi-modal embeddings in the worst case " "to hold the multi-modal embeddings in the worst case "
......
from functools import lru_cache from functools import lru_cache
from pathlib import Path from pathlib import Path
from typing import Optional, TypeVar, Union from typing import TYPE_CHECKING, Optional, TypeVar, Union
from urllib.parse import ParseResult, urlparse from urllib.parse import ParseResult, urlparse
import numpy as np import numpy as np
...@@ -25,6 +25,10 @@ cached_get_tokenizer = lru_cache(get_tokenizer) ...@@ -25,6 +25,10 @@ cached_get_tokenizer = lru_cache(get_tokenizer)
_M = TypeVar("_M") _M = TypeVar("_M")
if TYPE_CHECKING:
from .hasher import MultiModalHashDict
from .inputs import MultiModalPlaceholderDict
class MediaConnector: class MediaConnector:
...@@ -437,3 +441,83 @@ def consecutive_placeholder_ranges( ...@@ -437,3 +441,83 @@ def consecutive_placeholder_ranges(
PlaceholderRange(offset=initial_offset + i * item_size, PlaceholderRange(offset=initial_offset + i * item_size,
length=item_size) for i in range(num_items) length=item_size) for i in range(num_items)
] ]
def merge_and_sort_multimodal_metadata(
    mm_positions: "MultiModalPlaceholderDict",
    mm_hashes: Optional["MultiModalHashDict"],
) -> tuple[list[str], list["PlaceholderRange"], Optional[list[str]]]:
    """Given a MultiModalPlaceholderDict, merge all PlaceholderRange
    objects from all available modalities into a single list of
    PlaceholderRange, sorted by their offset (starting index in the input
    sequence) in the ascending order.

    Optionally if a MultiModalHashDict is given, same operation will be
    applied to the object and the sorted list of hashes will be returned.

    Modalities are ordered by the offset of their first placeholder; the
    placeholders within each modality are kept in the given order.

    Raises:
        ValueError: If the input prompt has interleaved placeholders from
            different modalities (e.g., "<image><audio><image> Describe the
            content.")
        KeyError: If mm_hashes is given but has no entry for a modality
            present in mm_positions.

    Returns:
        list[str]: Sorted list of involved modalities.
        list[PlaceholderRange]: Sorted list of all PlaceholderRanges from
            mm_positions.
        Optional[list[str]]: Sorted list of all hashes from mm_hashes if
            given, None otherwise.
    """

    modalities = list(mm_positions.keys())

    assert len(modalities) > 0, "No modalities found in the mm_positions."

    # For single modality, placeholder ranges and hashes are already sorted
    # so we can return the list directly.
    if len(modalities) == 1:
        only_modality = modalities[0]
        single_hashes = (None if mm_hashes is None else list(
            mm_hashes[only_modality]))
        return modalities, list(mm_positions[only_modality]), single_hashes

    # Order modalities by where their first placeholder starts. `sorted` is
    # stable, so modalities with equal first offsets keep their dict order.
    sorted_modalities = sorted(
        modalities,
        key=lambda modality: mm_positions[modality][0]['offset'])

    # Look hashes up by modality name so they can never be paired with the
    # wrong placeholders. A modality without hashes raises KeyError here
    # (as in the single-modality path) instead of being silently dropped.
    merged_hashes: Optional[list[str]] = None
    if mm_hashes is not None:
        merged_hashes = [
            mm_hash for modality in sorted_modalities
            for mm_hash in mm_hashes[modality]
        ]

    # Flatten sorted list of lists to a single list and verify there is no
    # interleaving of placeholders from different modalities.
    merged_placeholders: list["PlaceholderRange"] = []
    for modality in sorted_modalities:
        placeholder_list = mm_positions[modality]
        if merged_placeholders and placeholder_list[0][
                'offset'] < merged_placeholders[-1]['offset']:
            raise ValueError(
                "Interleaved mixed-modality inference is currently not "
                "supported.")
        merged_placeholders.extend(placeholder_list)

    return sorted_modalities, merged_placeholders, merged_hashes
import enum import enum
from dataclasses import dataclass from dataclasses import dataclass
from typing import List, Optional, Union from typing import TYPE_CHECKING, List, Optional, Union
import msgspec import msgspec
from vllm.lora.request import LoRARequest if TYPE_CHECKING:
from vllm.multimodal import MultiModalKwargs, MultiModalPlaceholderDict from vllm.lora.request import LoRARequest
from vllm.sampling_params import SamplingParams from vllm.multimodal import MultiModalKwargs
from vllm.multimodal.inputs import PlaceholderRange
from vllm.sampling_params import SamplingParams
@dataclass @dataclass
...@@ -21,13 +23,13 @@ class EngineCoreRequest: ...@@ -21,13 +23,13 @@ class EngineCoreRequest:
# always be tokenized? # always be tokenized?
prompt: Optional[str] prompt: Optional[str]
prompt_token_ids: List[int] prompt_token_ids: List[int]
mm_inputs: Optional[List[Optional[MultiModalKwargs]]] mm_inputs: Optional[List[Optional["MultiModalKwargs"]]]
mm_hashes: Optional[List[str]] mm_hashes: Optional[List[str]]
mm_placeholders: Optional[MultiModalPlaceholderDict] mm_placeholders: Optional[List["PlaceholderRange"]]
sampling_params: SamplingParams sampling_params: "SamplingParams"
eos_token_id: Optional[int] eos_token_id: Optional[int]
arrival_time: float arrival_time: float
lora_request: Optional[LoRARequest] lora_request: Optional["LoRARequest"]
class EngineCoreOutput( class EngineCoreOutput(
......
from typing import Any, Dict, List, Optional from typing import Any, Dict, List, Optional
import PIL
from blake3 import blake3
from vllm.config import ModelConfig from vllm.config import ModelConfig
from vllm.inputs import PromptType
from vllm.logger import init_logger from vllm.logger import init_logger
from vllm.multimodal import (MULTIMODAL_REGISTRY, MultiModalDataDict, from vllm.multimodal import (MULTIMODAL_REGISTRY, MultiModalDataDict,
MultiModalKwargs, MultiModalRegistry) MultiModalKwargs, MultiModalRegistry)
...@@ -144,66 +140,3 @@ class MMInputMapperServer: ...@@ -144,66 +140,3 @@ class MMInputMapperServer:
full_mm_inputs.append(mm_input) full_mm_inputs.append(mm_input)
return full_mm_inputs return full_mm_inputs
class MMHasher:
    """Produce content hashes for image inputs.

    Every hash is the blake3 hex digest of the raw bytes returned by
    ``PIL.Image.Image.tobytes()``. Only the ``"image"`` entry of the
    multimodal data is read by this class.
    """

    def __init__(self):
        pass

    def hash_dummy_mm_data(
            self,
            mm_data: Optional[MultiModalDataDict]) -> Optional[List[str]]:
        """Hash user-defined dummy multimodal data used for profiling.

        Returns ``None`` when ``mm_data`` is ``None``; otherwise returns
        the hashes of whatever is stored under ``mm_data['image']``.
        """
        if mm_data is None:
            return None

        images = mm_data['image']

        # Temporary workaround for models (e.g., Molmo) that process
        # multimodal data in the input processor, so the "image" entry is
        # a MultiModalKwargs-style dict instead of the raw input format.
        # The original image is expected under `raw_mm_data` in that case.
        if isinstance(images, dict):
            assert "raw_mm_data" in images and isinstance(
                images["raw_mm_data"], PIL.Image.Image)
            # NOTE: pop() also removes the key from the caller's dict.
            images = images.pop("raw_mm_data")

        return self.hash_images(images)

    def hash_prompt_mm_data(self, prompt: PromptType) -> Optional[List[str]]:
        """Hash the multimodal data of a user prompt, if there is any.

        Returns ``None`` when the prompt has no ``"multi_modal_data"``
        entry, or when that entry is ``None`` or empty.
        """
        has_mm_data = "multi_modal_data" in prompt
        mm_data = prompt["multi_modal_data"] if has_mm_data else None
        if not mm_data:
            # Covers a missing key, None, and an empty dict.
            return None

        return self.hash_images(mm_data["image"])

    def hash_images(self, image_inputs) -> Optional[List[str]]:
        """Hash PIL image objects to hex-digest strings.

        A single image is treated as a one-element list. The returned
        list holds one digest per image, in input order.
        """
        if isinstance(image_inputs, list):
            images = image_inputs
        else:
            images = [image_inputs]
        assert len(images) > 0

        digests = []
        for img in images:
            assert isinstance(img, PIL.Image.Image)
            # Hash the raw image bytes.
            digests.append(blake3(img.tobytes()).hexdigest())

        return digests
...@@ -7,14 +7,15 @@ from vllm.inputs import (INPUT_REGISTRY, InputRegistry, ProcessorInputs, ...@@ -7,14 +7,15 @@ from vllm.inputs import (INPUT_REGISTRY, InputRegistry, ProcessorInputs,
from vllm.inputs.parse import is_encoder_decoder_inputs from vllm.inputs.parse import is_encoder_decoder_inputs
from vllm.inputs.preprocess import InputPreprocessor from vllm.inputs.preprocess import InputPreprocessor
from vllm.lora.request import LoRARequest from vllm.lora.request import LoRARequest
from vllm.multimodal import (MULTIMODAL_REGISTRY, MultiModalKwargs, from vllm.multimodal import (MULTIMODAL_REGISTRY, MultiModalHasher,
MultiModalRegistry) MultiModalKwargs, MultiModalRegistry)
from vllm.multimodal.utils import merge_and_sort_multimodal_metadata
from vllm.pooling_params import PoolingParams from vllm.pooling_params import PoolingParams
from vllm.prompt_adapter.request import PromptAdapterRequest from vllm.prompt_adapter.request import PromptAdapterRequest
from vllm.sampling_params import SamplingParams from vllm.sampling_params import SamplingParams
from vllm.transformers_utils.tokenizer_group import BaseTokenizerGroup from vllm.transformers_utils.tokenizer_group import BaseTokenizerGroup
from vllm.v1.engine import EngineCoreRequest from vllm.v1.engine import EngineCoreRequest
from vllm.v1.engine.mm_input_mapper import MMHasher, MMInputMapperClient from vllm.v1.engine.mm_input_mapper import MMInputMapperClient
class Processor: class Processor:
...@@ -47,7 +48,6 @@ class Processor: ...@@ -47,7 +48,6 @@ class Processor:
# Multi-modal hasher (for images) # Multi-modal hasher (for images)
self.use_hash = (not model_config.disable_mm_preprocessor_cache) or \ self.use_hash = (not model_config.disable_mm_preprocessor_cache) or \
cache_config.enable_prefix_caching cache_config.enable_prefix_caching
self.mm_hasher = MMHasher()
def process_inputs( def process_inputs(
self, self,
...@@ -73,11 +73,6 @@ class Processor: ...@@ -73,11 +73,6 @@ class Processor:
assert priority == 0, "vLLM V1 does not support priority at the moment." assert priority == 0, "vLLM V1 does not support priority at the moment."
assert trace_headers is None, "vLLM V1 does not support tracing yet." assert trace_headers is None, "vLLM V1 does not support tracing yet."
# Compute MM hashes (if enabled)
mm_hashes = None
if self.use_hash:
mm_hashes = self.mm_hasher.hash_prompt_mm_data(prompt)
# Process inputs. # Process inputs.
preprocessed_inputs = self.input_preprocessor.preprocess( preprocessed_inputs = self.input_preprocessor.preprocess(
prompt, prompt,
...@@ -108,8 +103,20 @@ class Processor: ...@@ -108,8 +103,20 @@ class Processor:
sampling_params.update_from_generation_config( sampling_params.update_from_generation_config(
self.generation_config_fields, eos_token_id) self.generation_config_fields, eos_token_id)
# Multimodal related.
# Compute MM hashes (if enabled)
mm_hashes = None
if self.use_hash:
# Use mm_hashes from processed inputs if the model has merged
# input processor.
if decoder_inputs.multi_modal_hashes:
mm_hashes = decoder_inputs.multi_modal_hashes
# Fallback to using MultiModalHasher directly.
else:
mm_hashes = MultiModalHasher.hash_prompt_mm_data(prompt)
# For merged preprocessor, mm_data is already mm_inputs # For merged preprocessor, mm_data is already mm_inputs
precomputed_mm_inputs = None precomputed_mm_inputs: Optional[list[MultiModalKwargs]] = None
decoder_mm_data = decoder_inputs.multi_modal_data decoder_mm_data = decoder_inputs.multi_modal_data
if isinstance(decoder_mm_data, MultiModalKwargs): if isinstance(decoder_mm_data, MultiModalKwargs):
# The output of merged multi-modal processor (`decoder_mm_data`) # The output of merged multi-modal processor (`decoder_mm_data`)
...@@ -122,27 +129,67 @@ class Processor: ...@@ -122,27 +129,67 @@ class Processor:
for item in decoder_mm_data.get_items(modality) for item in decoder_mm_data.get_items(modality)
] ]
# Apply MM mapper mm_positions = decoder_inputs.multi_modal_placeholders
mm_inputs = None
if len(decoder_mm_data) > 0: # Last-mile processing of multimodal metadata and inputs.
mm_inputs = self.mm_input_mapper_client.process_inputs( if mm_positions:
decoder_mm_data,
# Merge and flatten multimodal placeholders, hashes and inputs
# from dictionaries to lists, and sort them by each item's position
# in the input sequence.
# NOTE: interleaved modalities are not supported.
(
sorted_modalities,
sorted_mm_positions,
sorted_mm_hashes,
) = merge_and_sort_multimodal_metadata(
mm_positions,
mm_hashes, mm_hashes,
decoder_inputs.mm_processor_kwargs,
precomputed_mm_inputs,
) )
# NOTE: Sort multimodal inputs/kwargs ONLY IF there are multiple
# modalities involved AND the model supports merged input processor.
if len(sorted_modalities) > 1 and precomputed_mm_inputs:
modality_order_dict = {
modality: order
for order, modality in enumerate(sorted_modalities)
}
# Sanity check to make sure each multimodal input has only one
# modality key.
for mm_input in precomputed_mm_inputs:
assert len(mm_input.modalities) == 1
# Sort MultiModalKwargs to match sorted_mm_positions
precomputed_mm_inputs = sorted(
precomputed_mm_inputs,
key=lambda mm_input: modality_order_dict[list(
mm_input.modalities)[0]])
# Apply mm input cache update (and input mapper if necessary).
sorted_mm_inputs = self.mm_input_mapper_client.process_inputs(
mm_data=decoder_mm_data,
mm_hashes=sorted_mm_hashes,
mm_processor_kwargs=decoder_inputs.mm_processor_kwargs,
precomputed_mm_inputs=precomputed_mm_inputs,
)
else:
sorted_mm_inputs = None
sorted_mm_hashes = None
sorted_mm_positions = None
return EngineCoreRequest( return EngineCoreRequest(
request_id, request_id=request_id,
decoder_inputs.prompt, prompt=decoder_inputs.prompt,
decoder_inputs.prompt_token_ids, prompt_token_ids=decoder_inputs.prompt_token_ids,
mm_inputs, mm_inputs=sorted_mm_inputs,
mm_hashes, mm_hashes=sorted_mm_hashes,
decoder_inputs.multi_modal_placeholders, mm_placeholders=sorted_mm_positions,
sampling_params, sampling_params=sampling_params,
eos_token_id, eos_token_id=eos_token_id,
arrival_time, arrival_time=arrival_time,
lora_request, lora_request=lora_request,
) )
def _validate_model_inputs(self, inputs: ProcessorInputs): def _validate_model_inputs(self, inputs: ProcessorInputs):
......
import enum import enum
from typing import TYPE_CHECKING, List, Optional, Union from typing import TYPE_CHECKING, List, Optional, Union
from vllm.inputs import DecoderOnlyInputs, SingletonInputsAdapter, token_inputs
from vllm.lora.request import LoRARequest from vllm.lora.request import LoRARequest
from vllm.multimodal import MultiModalKwargs
from vllm.sampling_params import SamplingParams from vllm.sampling_params import SamplingParams
from vllm.sequence import RequestMetrics from vllm.sequence import RequestMetrics
from vllm.v1.engine import EngineCoreRequest from vllm.v1.engine import EngineCoreRequest
from vllm.v1.utils import ConstantList from vllm.v1.utils import ConstantList
if TYPE_CHECKING: if TYPE_CHECKING:
from vllm.multimodal import MultiModalKwargs
from vllm.multimodal.inputs import PlaceholderRange
from vllm.v1.core.kv_cache_utils import BlockHashType from vllm.v1.core.kv_cache_utils import BlockHashType
...@@ -18,14 +18,17 @@ class Request: ...@@ -18,14 +18,17 @@ class Request:
def __init__( def __init__(
self, self,
request_id: str, request_id: str,
inputs: DecoderOnlyInputs, prompt: Optional[str],
prompt_token_ids: List[int],
multi_modal_inputs: Optional[List["MultiModalKwargs"]],
multi_modal_hashes: Optional[List[str]],
multi_modal_placeholders: Optional[List["PlaceholderRange"]],
sampling_params: SamplingParams, sampling_params: SamplingParams,
eos_token_id: Optional[int], eos_token_id: Optional[int],
arrival_time: float, arrival_time: float,
lora_request: Optional[LoRARequest] = None, lora_request: Optional[LoRARequest] = None,
) -> None: ) -> None:
self.request_id = request_id self.request_id = request_id
self.inputs = SingletonInputsAdapter(inputs)
self.sampling_params = sampling_params self.sampling_params = sampling_params
# Because of LoRA, the eos token id can be different for each request. # Because of LoRA, the eos token id can be different for each request.
self.eos_token_id = eos_token_id self.eos_token_id = eos_token_id
...@@ -41,26 +44,21 @@ class Request: ...@@ -41,26 +44,21 @@ class Request:
assert sampling_params.max_tokens is not None assert sampling_params.max_tokens is not None
self.max_tokens = sampling_params.max_tokens self.max_tokens = sampling_params.max_tokens
self.prompt = self.inputs.prompt self.prompt = prompt
self.prompt_token_ids = self.inputs.prompt_token_ids self.prompt_token_ids = prompt_token_ids
self.num_prompt_tokens = len(self.prompt_token_ids) self.num_prompt_tokens = len(self.prompt_token_ids)
self._output_token_ids: List[int] = [] self._output_token_ids: List[int] = []
self._all_token_ids: List[int] = self.prompt_token_ids.copy() self._all_token_ids: List[int] = self.prompt_token_ids.copy()
self.num_computed_tokens = 0 self.num_computed_tokens = 0
# Multi-modal input metadata. # Multi-modal related
mm_positions = self.inputs.multi_modal_placeholders self.mm_positions = multi_modal_placeholders or []
if mm_positions: self.mm_inputs = multi_modal_inputs or []
# FIXME(woosuk): Support other modalities. self.mm_hashes: List[str] = multi_modal_hashes or []
self.mm_positions = mm_positions.get("image", [])
else:
self.mm_positions = []
# Output of the mm input mapper (e.g., image tensors).
self.mm_inputs: List[MultiModalKwargs] = []
if self.inputs.multi_modal_inputs:
self.mm_inputs = self.inputs.multi_modal_inputs
self.mm_hashes: List[str] = self.inputs.multi_modal_hashes # Sanity check
assert len(self.mm_inputs) == len(self.mm_positions)
assert len(self.mm_inputs) == len(self.mm_hashes)
# Cache the computed kv block hashes of the request to avoid # Cache the computed kv block hashes of the request to avoid
# recomputing. # recomputing.
...@@ -70,15 +68,11 @@ class Request: ...@@ -70,15 +68,11 @@ class Request:
def from_engine_core_request(cls, request: EngineCoreRequest) -> "Request": def from_engine_core_request(cls, request: EngineCoreRequest) -> "Request":
return cls( return cls(
request_id=request.request_id, request_id=request.request_id,
inputs=token_inputs( prompt=request.prompt,
prompt_token_ids=request.prompt_token_ids, prompt_token_ids=request.prompt_token_ids,
prompt=request.prompt, multi_modal_inputs=request.mm_inputs,
multi_modal_data=None, multi_modal_hashes=request.mm_hashes,
multi_modal_inputs=request.mm_inputs, multi_modal_placeholders=request.mm_placeholders,
multi_modal_hashes=request.mm_hashes,
multi_modal_placeholders=request.mm_placeholders,
mm_processor_kwargs=None,
),
sampling_params=request.sampling_params, sampling_params=request.sampling_params,
eos_token_id=request.eos_token_id, eos_token_id=request.eos_token_id,
arrival_time=request.arrival_time, arrival_time=request.arrival_time,
......
...@@ -19,7 +19,7 @@ from vllm.utils import (STR_DTYPE_TO_TORCH_DTYPE, DeviceMemoryProfiler, ...@@ -19,7 +19,7 @@ from vllm.utils import (STR_DTYPE_TO_TORCH_DTYPE, DeviceMemoryProfiler,
LayerBlockType, cdiv, is_pin_memory_available) LayerBlockType, cdiv, is_pin_memory_available)
from vllm.v1.attention.backends.flash_attn import (FlashAttentionBackend, from vllm.v1.attention.backends.flash_attn import (FlashAttentionBackend,
FlashAttentionMetadata) FlashAttentionMetadata)
from vllm.v1.engine.mm_input_mapper import MMHasher, MMInputMapperClient from vllm.v1.engine.mm_input_mapper import MMInputMapperClient
from vllm.v1.outputs import ModelRunnerOutput from vllm.v1.outputs import ModelRunnerOutput
from vllm.v1.sample.metadata import SamplingMetadata from vllm.v1.sample.metadata import SamplingMetadata
from vllm.v1.worker.gpu_input_batch import CachedRequestState, InputBatch from vllm.v1.worker.gpu_input_batch import CachedRequestState, InputBatch
...@@ -82,12 +82,10 @@ class GPUModelRunner: ...@@ -82,12 +82,10 @@ class GPUModelRunner:
self.input_registry = INPUT_REGISTRY self.input_registry = INPUT_REGISTRY
self.mm_registry = MULTIMODAL_REGISTRY self.mm_registry = MULTIMODAL_REGISTRY
# NOTE: mm_input_mapper_client and mm_hasher are only used for memory # NOTE: Initialized input mapper is only used for processing dummy
# profiling. # multimodal data into multimodal kwargs for GPU memory profiling.
self.mm_input_mapper_client = MMInputMapperClient(self.model_config) self.mm_input_mapper_profiling = MMInputMapperClient(self.model_config)
self.mm_hasher = MMHasher() self.mm_input_mapper_profiling.use_cache = False
self.use_hash = (not model_config.disable_mm_preprocessor_cache) or \
cache_config.enable_prefix_caching
self.max_num_encoder_input_tokens = self.scheduler_config.max_num_encoder_input_tokens # noqa: E501 self.max_num_encoder_input_tokens = self.scheduler_config.max_num_encoder_input_tokens # noqa: E501
self.encoder_cache_size = self.scheduler_config.encoder_cache_size self.encoder_cache_size = self.scheduler_config.encoder_cache_size
...@@ -722,8 +720,6 @@ class GPUModelRunner: ...@@ -722,8 +720,6 @@ class GPUModelRunner:
] ]
# Profile with multimodal encoder & encoder cache. # Profile with multimodal encoder & encoder cache.
# TODO (ywang96): generalize this beyond image modality since
# mm_input_mapper only supports image inputs.
if self.is_multimodal_model: if self.is_multimodal_model:
# Create dummy batch of multimodal inputs. # Create dummy batch of multimodal inputs.
...@@ -735,15 +731,30 @@ class GPUModelRunner: ...@@ -735,15 +731,30 @@ class GPUModelRunner:
dummy_mm_data = dummy_request_data.multi_modal_data dummy_mm_data = dummy_request_data.multi_modal_data
# NOTE: Currently model is profiled with a single non-text # NOTE: Currently model is profiled with a single non-text
# modality even when it supports multiple. # modality with the max possible input tokens even when
max_tokens_per_mm_item = max( # it supports multiple.
self.mm_registry.get_max_tokens_per_item_by_modality( max_tokens_by_modality_dict = self.mm_registry.get_max_tokens_per_item_by_modality( # noqa: E501
self.model_config).values()) self.model_config)
max_num_mm_items_encoder_budget = min( dummy_data_modality, max_tokens_per_mm_item = max(
self.max_num_encoder_input_tokens, max_tokens_by_modality_dict.items(), key=lambda item: item[1])
self.encoder_cache_size) // max_tokens_per_mm_item
# Check how many items of this modality can be supported by
# the encoder cache budget.
encoder_cache_budget = min(self.max_num_encoder_input_tokens,
self.encoder_cache_size)
max_num_mm_items_encoder_budget = encoder_cache_budget // \
max_tokens_per_mm_item
# TODO: Allow users to set encoder_cache_budget in case this
# happens.
assert max_num_mm_items_encoder_budget > 0, (
f"Encoder cache budget={encoder_cache_budget} is too small to "
f"support the maximum possible size of multimodal embeddings"
f"={max_tokens_per_mm_item}.")
# Check how many items of this modality can be supported by
# the decoder budget.
max_mm_items_per_req = max( max_mm_items_per_req = max(
self.mm_registry.get_mm_limits_per_prompt( self.mm_registry.get_mm_limits_per_prompt(
self.model_config).values()) self.model_config).values())
...@@ -763,33 +774,24 @@ class GPUModelRunner: ...@@ -763,33 +774,24 @@ class GPUModelRunner:
# they are scheduled to be processed separately. # they are scheduled to be processed separately.
# Case when models have a merged processor, their dummy data is # Case when models have a merged processor, their dummy data is
# already batched `MultiModalKwargs`, therefore we need to "unbatch" # already batched `MultiModalKwargs`, therefore we take the first
# and take the first item in each batched tensor. # `MultiModalKwargsItem` from the desired modality to profile on.
# TODO (ywang96): This is somewhat hacky. Refactor this to be
# consistent with the other case.
if isinstance(dummy_mm_data, MultiModalKwargs): if isinstance(dummy_mm_data, MultiModalKwargs):
dummy_mm_kwargs = { dummy_mm_item = dummy_mm_data.get_item(
k: v[0].unsqueeze(0) modality=dummy_data_modality, item_index=0)
for k, v in dummy_mm_data.items() dummy_mm_kwargs = MultiModalKwargs.from_items([dummy_mm_item])
}
# Case when models have dummy data explicitly defined as # Case when models have dummy data explicitly defined as
# `MultiModalDataDict`, so they need to be processed through input # `MultiModalDataDict`, so they need to be processed through input
# mapper. # mapper.
# TODO (ywang96): deprecate this path once merged processor is
# supported on all models.
else: else:
# Compute MM hashes (if enabled) mm_kwargs_list = self.mm_input_mapper_profiling.process_inputs(
mm_hashes = None
if self.use_hash:
mm_hashes = self.mm_hasher.hash_dummy_mm_data(
dummy_mm_data)
mm_kwargs_list = self.mm_input_mapper_client.process_inputs(
mm_data=dummy_mm_data, mm_data=dummy_mm_data,
mm_hashes=mm_hashes, mm_hashes=None,
mm_processor_kwargs=None, mm_processor_kwargs=None,
precomputed_mm_inputs=None) precomputed_mm_inputs=None)
# Take the first `MultiModalKwargs`
dummy_mm_kwargs = mm_kwargs_list[0] dummy_mm_kwargs = mm_kwargs_list[0]
batched_dummy_mm_inputs = MultiModalKwargs.batch( batched_dummy_mm_inputs = MultiModalKwargs.batch(
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment