Unverified commit 4da1f667, authored by Cyrus Leung and committed by GitHub
Browse files

[VLM] Keep track of whether prompt replacements have been applied (#13215)

parent 556ef7f7
...@@ -484,6 +484,14 @@ class GLM4VDummyInputsBuilder(BaseDummyInputsBuilder[GLM4VProcessingInfo]): ...@@ -484,6 +484,14 @@ class GLM4VDummyInputsBuilder(BaseDummyInputsBuilder[GLM4VProcessingInfo]):
class GLM4VMultiModalProcessor(BaseMultiModalProcessor[GLM4VProcessingInfo]): class GLM4VMultiModalProcessor(BaseMultiModalProcessor[GLM4VProcessingInfo]):
def _hf_processor_applies_repl(
    self,
    prompt_text: str,
    mm_items: MultiModalDataItems,
    hf_processor_mm_kwargs: Mapping[str, object],
) -> bool:
    """Report whether the HF processor applies prompt replacements.

    This override ignores all of its arguments and answers ``False``
    unconditionally.

    Args:
        prompt_text: The prompt text (unused here).
        mm_items: The parsed multi-modal items (unused here).
        hf_processor_mm_kwargs: Extra processor keyword arguments
            (unused here).

    Returns:
        Always ``False``.
    """
    # NOTE(review): presumably this makes the caller apply the prompt
    # replacements itself -- confirm against the base class contract.
    return False
def _get_mm_fields_config( def _get_mm_fields_config(
self, self,
hf_inputs: BatchFeature, hf_inputs: BatchFeature,
......
...@@ -294,7 +294,7 @@ class PixtralHFMultiModalProcessor( ...@@ -294,7 +294,7 @@ class PixtralHFMultiModalProcessor(
pixel_values = processed_outputs.get("pixel_values") pixel_values = processed_outputs.get("pixel_values")
if pixel_values is not None: if pixel_values is not None:
# Before/after https://github.com/huggingface/transformers/pull/35122 # Before/after https://github.com/huggingface/transformers/pull/35122
if Version(TRANSFORMERS_VERSION) <= Version("4.48.2"): if Version(TRANSFORMERS_VERSION) <= Version("4.48.3"):
images = mm_data["images"] images = mm_data["images"]
assert isinstance(images, list) assert isinstance(images, list)
...@@ -819,7 +819,6 @@ class MantisMultiModalProcessor(LlavaMultiModalProcessor): ...@@ -819,7 +819,6 @@ class MantisMultiModalProcessor(LlavaMultiModalProcessor):
prompt_ids, prompt_ids,
mm_item_counts, mm_item_counts,
) )
self._validate_mm_placeholders(mm_placeholders, mm_item_counts) self._validate_mm_placeholders(mm_placeholders, mm_item_counts)
mm_placeholder_ranges = { mm_placeholder_ranges = {
......
...@@ -299,36 +299,69 @@ class LlavaOnevisionMultiModalProcessor( ...@@ -299,36 +299,69 @@ class LlavaOnevisionMultiModalProcessor(
mm_kwargs=mm_kwargs, mm_kwargs=mm_kwargs,
) )
# LLaVA-OneVision processor doesn't support multiple videos
# with different sizes when converting back to tensors
# So, we process each component separately
# NOTE: No prompt replacement is applied in this case
processor = self.info.get_hf_processor() processor = self.info.get_hf_processor()
image_token = processor.image_token
video_token = processor.video_token video_token = processor.video_token
# LLaVA-OneVision processor doesn't support multiple videos text_outputs = super()._call_hf_processor(
# with different sizes when converting back to tensors
text_image_outputs = super()._call_hf_processor(
prompt=prompt, prompt=prompt,
mm_data=mm_data, mm_data={},
mm_kwargs=mm_kwargs, mm_kwargs=mm_kwargs,
) )
images = mm_data.pop("images", [])
assert isinstance(images, list)
if images:
processor_outputs = super()._call_hf_processor(
prompt=image_token * len(images),
mm_data={"images": images},
mm_kwargs=mm_kwargs,
)
image_outputs = {
k: v
for k, v in processor_outputs.items()
if k in ("pixel_values", "image_sizes")
}
else:
image_outputs = {}
pixel_values_videos = [] pixel_values_videos = []
for video in videos: for video in videos:
item_processor_data = dict(prompt=video_token, videos=video)
item_outputs = super()._call_hf_processor( item_outputs = super()._call_hf_processor(
prompt=prompt, prompt=video_token,
mm_data=item_processor_data, mm_data={"videos": video},
mm_kwargs=mm_kwargs, mm_kwargs=mm_kwargs,
) )
pixel_values_videos.append( pixel_values_videos.append(item_outputs["pixel_values_videos"][0])
item_outputs.pop("pixel_values_videos")[0])
video_outputs = {"pixel_values_videos": pixel_values_videos}
combined_outputs = dict( combined_outputs = dict(
**text_image_outputs, text_outputs,
pixel_values_videos=pixel_values_videos, **image_outputs,
**video_outputs,
) )
return BatchFeature(combined_outputs) return BatchFeature(combined_outputs)
def _hf_processor_applies_repl(
    self,
    prompt_text: str,
    mm_items: MultiModalDataItems,
    hf_processor_mm_kwargs: Mapping[str, object],
) -> bool:
    """Report whether the HF processor applies prompt replacements.

    The parent class is consulted first. Its answer is only kept as a
    positive one when the request carries no video items.

    Args:
        prompt_text: The prompt text, forwarded to the parent class.
        mm_items: The parsed multi-modal items; its ``"video"`` count is
            inspected with ``strict=False``.
        hf_processor_mm_kwargs: Extra processor keyword arguments,
            forwarded to the parent class.

    Returns:
        The parent's (falsy) result when the parent says replacements are
        not applied; otherwise ``True`` only if there are zero video items.
    """
    applied_by_parent = super()._hf_processor_applies_repl(
        prompt_text=prompt_text,
        mm_items=mm_items,
        hf_processor_mm_kwargs=hf_processor_mm_kwargs,
    )
    if not applied_by_parent:
        return applied_by_parent

    # Any video item means the answer is "not applied".
    video_count = mm_items.get_count("video", strict=False)
    return video_count == 0
def _get_prompt_replacements( def _get_prompt_replacements(
self, self,
mm_items: MultiModalDataItems, mm_items: MultiModalDataItems,
......
...@@ -27,8 +27,8 @@ from typing import (Any, Dict, Iterable, List, Literal, Mapping, Optional, Set, ...@@ -27,8 +27,8 @@ from typing import (Any, Dict, Iterable, List, Literal, Mapping, Optional, Set,
Tuple, TypedDict, Union) Tuple, TypedDict, Union)
import torch import torch
import torch.types
from torch import nn from torch import nn
from transformers import BatchFeature
from transformers.modeling_outputs import BaseModelOutputWithPast from transformers.modeling_outputs import BaseModelOutputWithPast
from transformers.models.whisper.modeling_whisper import ( from transformers.models.whisper.modeling_whisper import (
ACT2FN, WHISPER_ATTENTION_CLASSES, WhisperConfig, WhisperEncoder) ACT2FN, WHISPER_ATTENTION_CLASSES, WhisperConfig, WhisperEncoder)
...@@ -37,23 +37,21 @@ from vllm.attention import AttentionMetadata ...@@ -37,23 +37,21 @@ from vllm.attention import AttentionMetadata
from vllm.config import VllmConfig from vllm.config import VllmConfig
from vllm.multimodal import MULTIMODAL_REGISTRY, MultiModalKwargs from vllm.multimodal import MULTIMODAL_REGISTRY, MultiModalKwargs
from vllm.multimodal.inputs import MultiModalFieldConfig from vllm.multimodal.inputs import MultiModalFieldConfig
from vllm.multimodal.parse import (ModalityData, ModalityDataItems, from vllm.multimodal.parse import (AudioItem, DictEmbeddingItems, ModalityData,
MultiModalDataItems, MultiModalDataParser, ModalityDataItems, MultiModalDataItems,
VideoItem) MultiModalDataParser)
from vllm.multimodal.processing import (BaseMultiModalProcessor, from vllm.multimodal.processing import PromptReplacement
PromptReplacement)
from vllm.multimodal.profiling import ProcessorInputs from vllm.multimodal.profiling import ProcessorInputs
from vllm.sequence import IntermediateTensors from vllm.sequence import IntermediateTensors
from .minicpmv import (MiniCPMV2_6, MiniCPMVDummyInputsBuilder, from .minicpmv import (MiniCPMV2_6, MiniCPMVDummyInputsBuilder,
MiniCPMVEmbeddingItems, MiniCPMVMultiModalDataParser, MiniCPMVMultiModalDataParser,
MiniCPMVMultiModalProcessor, MiniCPMVProcessingInfo) MiniCPMVMultiModalProcessor, MiniCPMVProcessingInfo,
_minicpmv_field_config)
from .utils import AutoWeightsLoader, maybe_prefix from .utils import AutoWeightsLoader, maybe_prefix
CPU_DEVICE = torch.device("cpu") CPU_DEVICE = torch.device("cpu")
MiniCPMOEmbeddingItems = MiniCPMVEmbeddingItems
class MiniCPMOAudioFeatureInputs(TypedDict): class MiniCPMOAudioFeatureInputs(TypedDict):
type: Literal["audio_features"] type: Literal["audio_features"]
...@@ -103,28 +101,49 @@ MiniCPMOAudioInputs = Union[MiniCPMOAudioFeatureInputs, ...@@ -103,28 +101,49 @@ MiniCPMOAudioInputs = Union[MiniCPMOAudioFeatureInputs,
MiniCPMOAudioEmbeddingInputs] MiniCPMOAudioEmbeddingInputs]
class MiniCPMOAudioEmbeddingItems(MiniCPMOEmbeddingItems): def _minicpmo_field_config(hf_inputs: Mapping[str, torch.Tensor]):
audio_num_slices = hf_inputs.get("audio_num_slices", torch.empty(0))
def __init__(self, data: Dict) -> None: return dict(
super().__init__(data, "audio") **_minicpmv_field_config(hf_inputs),
audio_embeds = self.data.get("audio_embeds", None) audio_features=MultiModalFieldConfig.flat_from_sizes(
if audio_embeds is None: "audio", audio_num_slices),
raise ValueError("Incorrect type of video_embeds", audio_feature_lens=MultiModalFieldConfig.flat_from_sizes(
"Got type: None") "audio", audio_num_slices),
self.data["audio_embeds"] = audio_embeds audio_num_slices=MultiModalFieldConfig.batched("audio"),
audio_orders_in_mm_data=MultiModalFieldConfig.batched("audio"),
audio_embeds=MultiModalFieldConfig.flat_from_sizes(
"audio", audio_num_slices),
)
def get(self, index: int) -> object:
return self.data["audio_embeds"][index] class MiniCPMOAudioEmbeddingItems(DictEmbeddingItems):
def __init__(
    self,
    data: Mapping[str, torch.Tensor],
    fields_config: Mapping[str, MultiModalFieldConfig],
) -> None:
    """Wrap a dict of pre-computed audio embeddings.

    Args:
        data: Mapping of field name to tensor; ``"audio_embeds"`` is
            passed to the base class as the only required field.
        fields_config: Field configuration for the keys in ``data``,
            forwarded unchanged to the base class.
    """
    super().__init__(
        data,
        # These are audio embeddings, so register them under the "audio"
        # modality (previously "image"). This matches the modality that
        # `_minicpmo_field_config` assigns to `audio_embeds`.
        modality="audio",
        fields_config=fields_config,
        required_fields={"audio_embeds"},
    )
class MiniCPMOMultiModalDataParser(MiniCPMVMultiModalDataParser): class MiniCPMOMultiModalDataParser(MiniCPMVMultiModalDataParser):
def _parse_audio_data( def _parse_audio_data(
self, self,
data: Union[dict[str, torch.Tensor], ModalityData[VideoItem]], data: Union[dict[str, torch.Tensor], ModalityData[AudioItem]],
) -> ModalityDataItems[Any, Any]: ) -> ModalityDataItems[Any, Any]:
if isinstance(data, dict): if isinstance(data, dict):
return MiniCPMOAudioEmbeddingItems(data) return MiniCPMOAudioEmbeddingItems(
data,
fields_config=_minicpmo_field_config(data),
)
return super()._parse_audio_data(data) return super()._parse_audio_data(data)
...@@ -167,6 +186,10 @@ class MiniCPMOProcessingInfo(MiniCPMVProcessingInfo): ...@@ -167,6 +186,10 @@ class MiniCPMOProcessingInfo(MiniCPMVProcessingInfo):
def get_max_audio_chunks_with_most_features(self) -> int: def get_max_audio_chunks_with_most_features(self) -> int:
return 30 return 30
def get_max_audio_tokens(self) -> int:
    """Return the maximum number of audio tokens.

    Computed as the per-chunk token maximum multiplied by the maximum
    number of audio chunks.
    """
    tokens_per_chunk = self.get_max_audio_tokens_per_chunk()
    max_chunks = self.get_max_audio_chunks_with_most_features()
    return tokens_per_chunk * max_chunks
def get_audio_len_by_num_chunks(self, num_chunks: int) -> int: def get_audio_len_by_num_chunks(self, num_chunks: int) -> int:
sampling_rate = self.get_default_audio_sampling_rate() sampling_rate = self.get_default_audio_sampling_rate()
# exclude <audio> </audio> # exclude <audio> </audio>
...@@ -194,7 +217,8 @@ class MiniCPMOProcessingInfo(MiniCPMVProcessingInfo): ...@@ -194,7 +217,8 @@ class MiniCPMOProcessingInfo(MiniCPMVProcessingInfo):
return num_frames return num_frames
class MiniCPMODummyInputsBuilder(MiniCPMVDummyInputsBuilder): class MiniCPMODummyInputsBuilder(
MiniCPMVDummyInputsBuilder[MiniCPMOProcessingInfo]):
def get_dummy_processor_inputs( def get_dummy_processor_inputs(
self, seq_len: int, mm_counts: Mapping[str, self, seq_len: int, mm_counts: Mapping[str,
...@@ -222,8 +246,7 @@ class MiniCPMODummyInputsBuilder(MiniCPMVDummyInputsBuilder): ...@@ -222,8 +246,7 @@ class MiniCPMODummyInputsBuilder(MiniCPMVDummyInputsBuilder):
class MiniCPMOMultiModalProcessor( class MiniCPMOMultiModalProcessor(
MiniCPMVMultiModalProcessor, MiniCPMVMultiModalProcessor[MiniCPMOProcessingInfo]):
BaseMultiModalProcessor[MiniCPMOProcessingInfo]):
def _get_data_parser(self) -> MultiModalDataParser: def _get_data_parser(self) -> MultiModalDataParser:
return MiniCPMOMultiModalDataParser( return MiniCPMOMultiModalDataParser(
...@@ -369,21 +392,10 @@ class MiniCPMOMultiModalProcessor( ...@@ -369,21 +392,10 @@ class MiniCPMOMultiModalProcessor(
def _get_mm_fields_config( def _get_mm_fields_config(
self, self,
hf_inputs, hf_inputs: BatchFeature,
hf_processor_mm_kwargs: Mapping[str, object], hf_processor_mm_kwargs: Mapping[str, object],
) -> Mapping[str, MultiModalFieldConfig]: ) -> Mapping[str, MultiModalFieldConfig]:
audio_num_slices = hf_inputs.get("audio_num_slices", torch.empty(0)) return _minicpmo_field_config(hf_inputs)
return dict(
**super()._get_mm_fields_config(hf_inputs, hf_processor_mm_kwargs),
audio_features=MultiModalFieldConfig.flat_from_sizes(
"audio", audio_num_slices),
audio_feature_lens=MultiModalFieldConfig.flat_from_sizes(
"audio", audio_num_slices),
audio_num_slices=MultiModalFieldConfig.batched("audio"),
audio_orders_in_mm_data=MultiModalFieldConfig.batched("audio"),
audio_embeds=MultiModalFieldConfig.flat_from_sizes(
"audio", audio_num_slices))
class MultiModalProjector(nn.Module): class MultiModalProjector(nn.Module):
...@@ -406,7 +418,7 @@ class MultiModalProjector(nn.Module): ...@@ -406,7 +418,7 @@ class MultiModalProjector(nn.Module):
class MiniCPMWhisperEncoderLayer(nn.Module): class MiniCPMWhisperEncoderLayer(nn.Module):
def __init__(self, config: WhisperConfig, layer_idx: int = None): def __init__(self, config: WhisperConfig, layer_idx: int):
super().__init__() super().__init__()
self.embed_dim = config.d_model self.embed_dim = config.d_model
self.self_attn = WHISPER_ATTENTION_CLASSES[ self.self_attn = WHISPER_ATTENTION_CLASSES[
......
...@@ -35,6 +35,7 @@ import torch.types ...@@ -35,6 +35,7 @@ import torch.types
from PIL import Image from PIL import Image
from torch import nn from torch import nn
from transformers import BatchFeature, PretrainedConfig from transformers import BatchFeature, PretrainedConfig
from typing_extensions import TypeVar
from vllm.attention import AttentionMetadata from vllm.attention import AttentionMetadata
from vllm.config import VllmConfig from vllm.config import VllmConfig
...@@ -51,9 +52,10 @@ from vllm.model_executor.sampling_metadata import SamplingMetadata ...@@ -51,9 +52,10 @@ from vllm.model_executor.sampling_metadata import SamplingMetadata
from vllm.multimodal import MULTIMODAL_REGISTRY, MultiModalKwargs from vllm.multimodal import MULTIMODAL_REGISTRY, MultiModalKwargs
from vllm.multimodal.inputs import (MultiModalDataDict, MultiModalFieldConfig, from vllm.multimodal.inputs import (MultiModalDataDict, MultiModalFieldConfig,
MultiModalInputs, PlaceholderRange) MultiModalInputs, PlaceholderRange)
from vllm.multimodal.parse import (ImageItem, ImageSize, ModalityData, from vllm.multimodal.parse import (DictEmbeddingItems, ImageItem, ImageSize,
ModalityDataItems, MultiModalDataItems, ModalityData, ModalityDataItems,
MultiModalDataParser, VideoItem) MultiModalDataItems, MultiModalDataParser,
VideoItem)
from vllm.multimodal.processing import (BaseMultiModalProcessor, from vllm.multimodal.processing import (BaseMultiModalProcessor,
BaseProcessingInfo, PromptReplacement) BaseProcessingInfo, PromptReplacement)
from vllm.multimodal.profiling import BaseDummyInputsBuilder, ProcessorInputs from vllm.multimodal.profiling import BaseDummyInputsBuilder, ProcessorInputs
...@@ -115,93 +117,6 @@ class MiniCPMVImageEmbeddingInputs(TypedDict): ...@@ -115,93 +117,6 @@ class MiniCPMVImageEmbeddingInputs(TypedDict):
MiniCPMVImageInputs = Union[MiniCPMVImagePixelInputs, MiniCPMVImageInputs = Union[MiniCPMVImagePixelInputs,
MiniCPMVImageEmbeddingInputs] MiniCPMVImageEmbeddingInputs]
class MiniCPMVEmbeddingItems(ModalityDataItems[dict[str, torch.Tensor],
dict[str, torch.Tensor]]):
def __init__(self, data: Dict, modality: str) -> None:
super().__init__(data, modality)
def get_processor_data(self) -> Mapping[str, object]:
return self.data
def get_passthrough_data(self) -> Mapping[str, object]:
return {}
def get_count(self) -> int:
return len(self.data[f"{self.modality}_embeds"])
def get(self, index: int) -> Dict[str, torch.Tensor]:
out = {}
for k, v in self.data.items():
out[k] = v[index]
return out
class MiniCPMVImageEmbeddingItems(MiniCPMVEmbeddingItems):
def __init__(self, data: Dict) -> None:
super().__init__(data, "image")
image_embeds = self.data.get("image_embeds", None)
image_sizes = self.data.get("image_sizes", None)
if image_embeds is None:
raise ValueError("In correct type of image_embeds",
"Got type: None")
if not isinstance(image_embeds[0], torch.Tensor):
raise ValueError("In correct type of image_embeds",
f"Got type: {type(image_embeds[0])}")
if image_sizes is None:
raise ValueError(
"In correct type of image_sizes", "Got type: None."
"If you're using `image_size_list`, "
"please rename it to `image_sizes`")
if len(image_embeds[0].shape) == 2:
image_embeds = [image_embeds]
image_sizes = [image_sizes]
self.data["image_embeds"] = image_embeds
self.data["image_sizes"] = image_sizes
def get_image_size(self, index: int) -> ImageSize:
image_size = self.data["image_sizes"][index]
return ImageSize(width=image_size[0], height=image_size[1])
class MiniCPMVVideoEmbeddingItems(MiniCPMVEmbeddingItems):
def __init__(self, data: Dict) -> None:
super().__init__(data, "video")
video_embeds = self.data.get("video_embeds", None)
image_sizes = self.data.get("image_sizes", None)
num_frames = self.data.get("num_frames", None)
if video_embeds is None:
raise ValueError("In correct type of video_embeds",
"Got type: None")
if not isinstance(video_embeds[0], torch.Tensor):
raise ValueError("In correct type of video_embeds",
f"Got type: {type(video_embeds[0])}")
if image_sizes is None:
raise ValueError(
"In correct type of image_sizes", "Got type: None."
"If you're using `image_size_list`, "
"please rename it to `image_sizes`")
if num_frames is None:
raise ValueError("In correct type of numframes", "Got type: None")
if len(video_embeds[0].shape) == 2:
video_embeds = [video_embeds]
image_sizes = [image_sizes]
num_frames = [num_frames]
self.data["video_embeds"] = video_embeds
self.data["image_sizes"] = image_sizes
self.data["num_frames"] = num_frames
def get_frame_size(self, index: int) -> ImageSize:
frame_size = self.data["image_sizes"][index]
return ImageSize(width=frame_size[0], height=frame_size[1])
def get_num_frames(self, index: int) -> int:
return self.data["num_frames"][index]
DEFAULT_LN = partial(nn.LayerNorm, eps=1e-6) DEFAULT_LN = partial(nn.LayerNorm, eps=1e-6)
...@@ -311,6 +226,71 @@ def get_version_by_config(config: PretrainedConfig) -> Tuple[int, ...]: ...@@ -311,6 +226,71 @@ def get_version_by_config(config: PretrainedConfig) -> Tuple[int, ...]:
return tuple(int(x) for x in version_str.split(".")) return tuple(int(x) for x in version_str.split("."))
def _minicpmv_field_config(hf_inputs: Mapping[str, torch.Tensor]):
    """Build the multi-modal field configuration for MiniCPM-V inputs.

    Fields are either flattened according to the per-item slice counts
    found in ``hf_inputs`` or batched one entry per item.

    Args:
        hf_inputs: Mapping of field name to tensor. ``"image_num_slices"``
            and ``"video_num_slices"`` are read when present; an empty
            tensor is used for whichever is missing.

    Returns:
        A dict mapping each known field name to its
        ``MultiModalFieldConfig``.
    """
    flat = MultiModalFieldConfig.flat_from_sizes
    batched = MultiModalFieldConfig.batched

    image_slices = hf_inputs.get("image_num_slices", torch.empty(0))
    video_slices = hf_inputs.get("video_num_slices", torch.empty(0))

    return {
        # Image fields
        "pixel_values": flat("image", image_slices),
        "image_sizes": batched("image"),
        "tgt_sizes": flat("image", image_slices),
        "image_num_slices": batched("image"),
        "image_embeds": flat("image", image_slices),
        # Video fields
        "video_pixel_values": flat("video", video_slices),
        "video_image_sizes": batched("video"),
        "video_tgt_sizes": flat("video", video_slices),
        "video_embeds": flat("video", video_slices),
        "video_num_slices": batched("video"),
    }
class MiniCPMVImageEmbeddingItems(DictEmbeddingItems):
    """Dict-based image embedding items for MiniCPM-V.

    Registers itself under the ``"image"`` modality and declares
    ``image_embeds`` and ``image_sizes`` as required fields.
    """

    def __init__(
        self,
        data: Mapping[str, torch.Tensor],
        fields_config: Mapping[str, MultiModalFieldConfig],
    ) -> None:
        """Forward the data and field configuration to the base class.

        Args:
            data: Mapping of field name to tensor.
            fields_config: Field configuration for the keys in ``data``.
        """
        required = {"image_embeds", "image_sizes"}
        super().__init__(
            data,
            modality="image",
            fields_config=fields_config,
            required_fields=required,
        )

    def get_image_size(self, index: int) -> ImageSize:
        """Return the size of the image at ``index``.

        The item's ``image_sizes`` entry is converted to a list; position
        0 is used as the width and position 1 as the height.
        """
        dims = self.get(index)["image_sizes"].tolist()
        return ImageSize(width=dims[0], height=dims[1])
class MiniCPMVVideoEmbeddingItems(DictEmbeddingItems):
    """Dict-based video embedding items for MiniCPM-V.

    Registers itself under the ``"video"`` modality and declares
    ``video_embeds`` and ``video_image_sizes`` as required fields.
    """

    def __init__(
        self,
        data: Mapping[str, torch.Tensor],
        fields_config: Mapping[str, MultiModalFieldConfig],
    ) -> None:
        """Forward the data and field configuration to the base class.

        Args:
            data: Mapping of field name to tensor.
            fields_config: Field configuration for the keys in ``data``.
        """
        required = {"video_embeds", "video_image_sizes"}
        super().__init__(
            data,
            modality="video",
            fields_config=fields_config,
            required_fields=required,
        )

    def get_frame_size(self, index: int) -> ImageSize:
        """Return the frame size of the video at ``index``.

        The item's ``video_image_sizes`` entry is converted to a list;
        position 0 is used as the width and position 1 as the height.
        """
        # NOTE(review): `get_num_frames` takes len() of this same entry,
        # which suggests one row per frame -- confirm positions 0 and 1
        # really are width and height for the expected tensor shape.
        dims = self.get(index)["video_image_sizes"].tolist()
        return ImageSize(width=dims[0], height=dims[1])

    def get_num_frames(self, index: int) -> int:
        """Return the length of the ``video_image_sizes`` entry at ``index``."""
        frame_sizes = self.get(index)["video_image_sizes"]
        return len(frame_sizes)
class MiniCPMVMultiModalDataParser(MultiModalDataParser): class MiniCPMVMultiModalDataParser(MultiModalDataParser):
def _parse_image_data( def _parse_image_data(
...@@ -318,7 +298,11 @@ class MiniCPMVMultiModalDataParser(MultiModalDataParser): ...@@ -318,7 +298,11 @@ class MiniCPMVMultiModalDataParser(MultiModalDataParser):
data: Union[dict[str, torch.Tensor], ModalityData[ImageItem]], data: Union[dict[str, torch.Tensor], ModalityData[ImageItem]],
) -> ModalityDataItems[Any, Any]: ) -> ModalityDataItems[Any, Any]:
if isinstance(data, dict): if isinstance(data, dict):
return MiniCPMVImageEmbeddingItems(data) return MiniCPMVImageEmbeddingItems(
data,
fields_config=_minicpmv_field_config(data),
)
return super()._parse_image_data(data) return super()._parse_image_data(data)
def _parse_video_data( def _parse_video_data(
...@@ -326,7 +310,11 @@ class MiniCPMVMultiModalDataParser(MultiModalDataParser): ...@@ -326,7 +310,11 @@ class MiniCPMVMultiModalDataParser(MultiModalDataParser):
data: Union[dict[str, torch.Tensor], ModalityData[VideoItem]], data: Union[dict[str, torch.Tensor], ModalityData[VideoItem]],
) -> ModalityDataItems[Any, Any]: ) -> ModalityDataItems[Any, Any]:
if isinstance(data, dict): if isinstance(data, dict):
return MiniCPMVVideoEmbeddingItems(data) return MiniCPMVVideoEmbeddingItems(
data,
fields_config=_minicpmv_field_config(data),
)
return super()._parse_video_data(data) return super()._parse_video_data(data)
...@@ -392,10 +380,6 @@ class MiniCPMVProcessingInfo(BaseProcessingInfo): ...@@ -392,10 +380,6 @@ class MiniCPMVProcessingInfo(BaseProcessingInfo):
return self.get_max_video_frame_tokens( return self.get_max_video_frame_tokens(
) * self.get_num_frames_with_most_features(seq_len) ) * self.get_num_frames_with_most_features(seq_len)
def get_max_audio_tokens(self) -> int:
return self.get_max_audio_tokens_per_chunk(
) * self.get_max_audio_chunks_with_most_features()
def get_slice_query_num(self) -> int: def get_slice_query_num(self) -> int:
hf_config = self.get_hf_config() hf_config = self.get_hf_config()
query_num = getattr(hf_config, "query_num", 64) query_num = getattr(hf_config, "query_num", 64)
...@@ -476,8 +460,12 @@ class MiniCPMVProcessingInfo(BaseProcessingInfo): ...@@ -476,8 +460,12 @@ class MiniCPMVProcessingInfo(BaseProcessingInfo):
return ImageSize(width=image_size, height=image_size * num_slices) return ImageSize(width=image_size, height=image_size * num_slices)
class MiniCPMVDummyInputsBuilder(BaseDummyInputsBuilder[MiniCPMVProcessingInfo] _I = TypeVar("_I",
): bound=MiniCPMVProcessingInfo,
default=MiniCPMVProcessingInfo)
class MiniCPMVDummyInputsBuilder(BaseDummyInputsBuilder[_I]):
def get_dummy_processor_inputs( def get_dummy_processor_inputs(
self, self,
...@@ -514,8 +502,7 @@ class MiniCPMVDummyInputsBuilder(BaseDummyInputsBuilder[MiniCPMVProcessingInfo] ...@@ -514,8 +502,7 @@ class MiniCPMVDummyInputsBuilder(BaseDummyInputsBuilder[MiniCPMVProcessingInfo]
mm_data=mm_data) mm_data=mm_data)
class MiniCPMVMultiModalProcessor( class MiniCPMVMultiModalProcessor(BaseMultiModalProcessor[_I]):
BaseMultiModalProcessor[MiniCPMVProcessingInfo]):
def _get_data_parser(self) -> MultiModalDataParser: def _get_data_parser(self) -> MultiModalDataParser:
return MiniCPMVMultiModalDataParser() return MiniCPMVMultiModalDataParser()
...@@ -675,7 +662,7 @@ class MiniCPMVMultiModalProcessor( ...@@ -675,7 +662,7 @@ class MiniCPMVMultiModalProcessor(
self.info.get_video_max_slice_num() self.info.get_video_max_slice_num()
) * inputs[modality]["num_frames"][index] ) * inputs[modality]["num_frames"][index]
else: else:
raise ValueError(f"UnExpected modality: {modality}") raise ValueError(f"Unexpected modality: {modality}")
def check_mm_inputs(self, inputs: Dict[str, object], def check_mm_inputs(self, inputs: Dict[str, object],
matches: List[str]) -> None: matches: List[str]) -> None:
...@@ -700,7 +687,7 @@ class MiniCPMVMultiModalProcessor( ...@@ -700,7 +687,7 @@ class MiniCPMVMultiModalProcessor(
inputs["video"]["video_image_sizes"][index], inputs["video"]["video_image_sizes"][index],
inputs["video"]["num_frames"][index]) inputs["video"]["num_frames"][index])
else: else:
raise ValueError(f"UnExpected modality: {modality}") raise ValueError(f"Unexpected modality: {modality}")
def call_base_hf_processor( def call_base_hf_processor(
self, self,
...@@ -742,6 +729,14 @@ class MiniCPMVMultiModalProcessor( ...@@ -742,6 +729,14 @@ class MiniCPMVMultiModalProcessor(
} }
} }
def _hf_processor_applies_repl(
    self,
    prompt_text: str,
    mm_items: MultiModalDataItems,
    hf_processor_mm_kwargs: Mapping[str, object],
) -> bool:
    """Report whether the HF processor applies prompt replacements.

    This override ignores all of its arguments and answers ``False``
    unconditionally.

    Args:
        prompt_text: The prompt text (unused here).
        mm_items: The parsed multi-modal items (unused here).
        hf_processor_mm_kwargs: Extra processor keyword arguments
            (unused here).

    Returns:
        Always ``False``.
    """
    # NOTE(review): presumably this makes the caller apply the prompt
    # replacements itself -- confirm against the base class contract.
    return False
def _get_prompt_replacements( def _get_prompt_replacements(
self, mm_items: MultiModalDataItems, self, mm_items: MultiModalDataItems,
hf_processor_mm_kwargs: Mapping[str, Any], hf_processor_mm_kwargs: Mapping[str, Any],
...@@ -770,28 +765,10 @@ class MiniCPMVMultiModalProcessor( ...@@ -770,28 +765,10 @@ class MiniCPMVMultiModalProcessor(
def _get_mm_fields_config( def _get_mm_fields_config(
self, self,
hf_inputs, hf_inputs: BatchFeature,
hf_processor_mm_kwargs: Mapping[str, object], hf_processor_mm_kwargs: Mapping[str, object],
) -> Mapping[str, MultiModalFieldConfig]: ) -> Mapping[str, MultiModalFieldConfig]:
image_num_slices = hf_inputs.get("image_num_slices", torch.empty(0)) return _minicpmv_field_config(hf_inputs)
video_num_slices = hf_inputs.get("video_num_slices", torch.empty(0))
return dict(pixel_values=MultiModalFieldConfig.flat_from_sizes(
"image", image_num_slices),
image_sizes=MultiModalFieldConfig.batched("image"),
tgt_sizes=MultiModalFieldConfig.flat_from_sizes(
"image", image_num_slices),
image_num_slices=MultiModalFieldConfig.batched("image"),
image_embeds=MultiModalFieldConfig.flat_from_sizes(
"image", image_num_slices),
video_pixel_values=MultiModalFieldConfig.flat_from_sizes(
"video", video_num_slices),
video_image_sizes=MultiModalFieldConfig.batched("video"),
video_tgt_sizes=MultiModalFieldConfig.flat_from_sizes(
"video", video_num_slices),
video_embeds=MultiModalFieldConfig.flat_from_sizes(
"video", video_num_slices),
video_num_slices=MultiModalFieldConfig.batched("video"))
def apply( def apply(
self, self,
......
...@@ -243,16 +243,6 @@ class Qwen2AudioMultiModalProcessor( ...@@ -243,16 +243,6 @@ class Qwen2AudioMultiModalProcessor(
) )
] ]
def _always_apply_prompt_replacements(self) -> bool:
# Qwen2-Audio processor will start inserting placeholder tokens
# in an upcoming release:
# https://github.com/huggingface/transformers/pull/35534
# NOTE: `_find_placeholders_by_modality` may incorrectly think that HF
# has already performed processing for multi-audio input when the input
# audios are short (the corresponding placeholders may take up fewer
# tokens than the number of audio items)
return not hasattr(self.info.get_hf_processor(), "audio_token")
@MULTIMODAL_REGISTRY.register_processor( @MULTIMODAL_REGISTRY.register_processor(
Qwen2AudioMultiModalProcessor, Qwen2AudioMultiModalProcessor,
......
...@@ -58,8 +58,9 @@ from vllm.multimodal import MULTIMODAL_REGISTRY ...@@ -58,8 +58,9 @@ from vllm.multimodal import MULTIMODAL_REGISTRY
from vllm.multimodal.inputs import (ImageItem, ModalityData, from vllm.multimodal.inputs import (ImageItem, ModalityData,
MultiModalFieldConfig, MultiModalKwargs, MultiModalFieldConfig, MultiModalKwargs,
VideoItem) VideoItem)
from vllm.multimodal.parse import (ImageSize, ModalityDataItems, from vllm.multimodal.parse import (DictEmbeddingItems, ImageSize,
MultiModalDataItems, MultiModalDataParser) ModalityDataItems, MultiModalDataItems,
MultiModalDataParser)
from vllm.multimodal.processing import (BaseMultiModalProcessor, from vllm.multimodal.processing import (BaseMultiModalProcessor,
BaseProcessingInfo, PromptReplacement) BaseProcessingInfo, PromptReplacement)
from vllm.multimodal.profiling import BaseDummyInputsBuilder, ProcessorInputs from vllm.multimodal.profiling import BaseDummyInputsBuilder, ProcessorInputs
...@@ -657,49 +658,25 @@ class Qwen2VisionTransformer(nn.Module): ...@@ -657,49 +658,25 @@ class Qwen2VisionTransformer(nn.Module):
return loaded_params return loaded_params
class Qwen2VLEmbeddingItems(ModalityDataItems[dict[str, torch.Tensor], def _qwen2vl_field_config(hf_inputs: Mapping[str, torch.Tensor]):
dict[str, torch.Tensor]]): image_grid_thw = hf_inputs.get("image_grid_thw", torch.empty((0, 3)))
image_grid_sizes = image_grid_thw.prod(-1)
def __init__(self, data: dict, modality: str) -> None:
super().__init__(data, modality)
grid_thw = data[f"{modality}_grid_thw"]
slice_idxs = [0] + grid_thw.prod(-1).cumsum_(0).tolist()
self._slices = [
slice(slice_idxs[i], slice_idxs[i + 1])
for i in range(len(grid_thw))
]
def get_count(self) -> int:
return len(self.data[f"{self.modality}_grid_thw"])
def get(self, index: int) -> dict[str, torch.Tensor]:
out = {}
for k, v in self.data.items():
if v != f"{self.modality}_grid_thw":
v = v[self._slices[index]]
out[k] = v
return out
def get_processor_data(self) -> Mapping[str, object]:
return {}
def get_passthrough_data(self) -> Mapping[str, object]:
return self.data
class Qwen2VLImageEmbeddingItems(Qwen2VLEmbeddingItems):
def __init__(self, data: dict) -> None:
super().__init__(data, "image")
class Qwen2VLVideoEmbeddingItems(Qwen2VLEmbeddingItems): video_grid_thw = hf_inputs.get("video_grid_thw", torch.empty((0, 3)))
video_grid_sizes = video_grid_thw.prod(-1)
def __init__(self, data: dict) -> None: return dict(
super().__init__(data, "video") pixel_values=MultiModalFieldConfig.flat_from_sizes(
"image", image_grid_sizes),
image_embeds=MultiModalFieldConfig.flat_from_sizes(
"image", image_grid_sizes),
image_grid_thw=MultiModalFieldConfig.batched("image"),
pixel_values_videos=MultiModalFieldConfig.flat_from_sizes(
"video", video_grid_sizes),
video_embeds=MultiModalFieldConfig.flat_from_sizes(
"video", video_grid_sizes),
video_grid_thw=MultiModalFieldConfig.batched("video"),
)
class Qwen2VLMultiModalDataParser(MultiModalDataParser): class Qwen2VLMultiModalDataParser(MultiModalDataParser):
...@@ -709,7 +686,12 @@ class Qwen2VLMultiModalDataParser(MultiModalDataParser): ...@@ -709,7 +686,12 @@ class Qwen2VLMultiModalDataParser(MultiModalDataParser):
data: Union[dict[str, torch.Tensor], ModalityData[ImageItem]], data: Union[dict[str, torch.Tensor], ModalityData[ImageItem]],
) -> ModalityDataItems[Any, Any]: ) -> ModalityDataItems[Any, Any]:
if isinstance(data, dict): if isinstance(data, dict):
return Qwen2VLEmbeddingItems(data, modality="image") return DictEmbeddingItems(
data,
modality="image",
fields_config=_qwen2vl_field_config(data),
required_fields={"image_embeds", "image_grid_thw"},
)
return super()._parse_image_data(data) return super()._parse_image_data(data)
...@@ -718,7 +700,12 @@ class Qwen2VLMultiModalDataParser(MultiModalDataParser): ...@@ -718,7 +700,12 @@ class Qwen2VLMultiModalDataParser(MultiModalDataParser):
data: Union[dict[str, torch.Tensor], ModalityData[VideoItem]], data: Union[dict[str, torch.Tensor], ModalityData[VideoItem]],
) -> ModalityDataItems[Any, Any]: ) -> ModalityDataItems[Any, Any]:
if isinstance(data, dict): if isinstance(data, dict):
return Qwen2VLEmbeddingItems(data, modality="video") return DictEmbeddingItems(
data,
modality="video",
fields_config=_qwen2vl_field_config(data),
required_fields={"video_embeds", "video_grid_thw"},
)
return super()._parse_video_data(data) return super()._parse_video_data(data)
...@@ -999,24 +986,7 @@ class Qwen2VLMultiModalProcessor(BaseMultiModalProcessor[Qwen2VLProcessingInfo] ...@@ -999,24 +986,7 @@ class Qwen2VLMultiModalProcessor(BaseMultiModalProcessor[Qwen2VLProcessingInfo]
hf_inputs: BatchFeature, hf_inputs: BatchFeature,
hf_processor_mm_kwargs: Mapping[str, object], hf_processor_mm_kwargs: Mapping[str, object],
) -> Mapping[str, MultiModalFieldConfig]: ) -> Mapping[str, MultiModalFieldConfig]:
image_grid_thw = hf_inputs.get("image_grid_thw", torch.empty((0, 3))) return _qwen2vl_field_config(hf_inputs)
image_grid_sizes = image_grid_thw.prod(-1)
video_grid_thw = hf_inputs.get("video_grid_thw", torch.empty((0, 3)))
video_grid_sizes = video_grid_thw.prod(-1)
return dict(
pixel_values=MultiModalFieldConfig.flat_from_sizes(
"image", image_grid_sizes),
image_embeds=MultiModalFieldConfig.flat_from_sizes(
"image", image_grid_sizes),
image_grid_thw=MultiModalFieldConfig.batched("image"),
pixel_values_videos=MultiModalFieldConfig.flat_from_sizes(
"video", video_grid_sizes),
video_embeds=MultiModalFieldConfig.flat_from_sizes(
"video", video_grid_sizes),
video_grid_thw=MultiModalFieldConfig.batched("video"),
)
@MULTIMODAL_REGISTRY.register_processor(Qwen2VLMultiModalProcessor, @MULTIMODAL_REGISTRY.register_processor(Qwen2VLMultiModalProcessor,
......
...@@ -520,10 +520,7 @@ class QwenVLProcessingInfo(BaseProcessingInfo): ...@@ -520,10 +520,7 @@ class QwenVLProcessingInfo(BaseProcessingInfo):
return _get_tokenizer_without_image_pad(tokenizer) return _get_tokenizer_without_image_pad(tokenizer)
def get_hf_processor(self) -> QwenVLProcessor:
    """
    Build the :class:`QwenVLProcessor` used for this model.

    The processor is constructed from the HF config and from the tokenizer
    returned by :meth:`get_tokenizer`.
    """
    # NOTE(review): this keeps the post-change side of the diff; the earlier
    # version read `self.ctx.tokenizer` directly and asserted its type.
    return QwenVLProcessor(self.get_hf_config(), self.get_tokenizer())
def get_supported_mm_limits(self) -> Mapping[str, Optional[int]]:
    """
    Return the supported multi-modal limits, keyed by modality.

    Only the ``"image"`` modality is listed, and it maps to ``None``.
    """
    # NOTE(review): presumably `None` means "no fixed upper bound" -- confirm
    # against the base processing-info contract.
    return {"image": None}
...@@ -605,6 +602,14 @@ class QwenVLMultiModalProcessor(BaseMultiModalProcessor[QwenVLProcessingInfo]): ...@@ -605,6 +602,14 @@ class QwenVLMultiModalProcessor(BaseMultiModalProcessor[QwenVLProcessingInfo]):
mm_kwargs=mm_kwargs, mm_kwargs=mm_kwargs,
) )
def _hf_processor_applies_repl(
self,
prompt_text: str,
mm_items: MultiModalDataItems,
hf_processor_mm_kwargs: Mapping[str, object],
) -> bool:
return False
def _get_mm_fields_config( def _get_mm_fields_config(
self, self,
hf_inputs: BatchFeature, hf_inputs: BatchFeature,
......
...@@ -9,13 +9,15 @@ from typing import (TYPE_CHECKING, Any, Generic, NamedTuple, Optional, TypeVar, ...@@ -9,13 +9,15 @@ from typing import (TYPE_CHECKING, Any, Generic, NamedTuple, Optional, TypeVar,
import numpy as np import numpy as np
import torch import torch
from PIL.Image import Image from PIL.Image import Image
from transformers import BatchFeature
from typing_extensions import TypeAlias, TypeGuard, assert_never from typing_extensions import TypeAlias, TypeGuard, assert_never
from vllm.utils import is_list_of from vllm.utils import is_list_of
from .audio import resample_audio from .audio import resample_audio
from .inputs import (AudioItem, HfAudioItem, HfImageItem, HfVideoItem, from .inputs import (AudioItem, HfAudioItem, HfImageItem, HfVideoItem,
ImageItem, ModalityData, MultiModalDataDict, VideoItem) ImageItem, ModalityData, MultiModalDataDict,
MultiModalFieldConfig, MultiModalKwargs, VideoItem)
_T = TypeVar("_T") _T = TypeVar("_T")
_I = TypeVar("_I") _I = TypeVar("_I")
...@@ -111,6 +113,60 @@ class EmbeddingItems(ModalityDataItems[Union[torch.Tensor, list[torch.Tensor]], ...@@ -111,6 +113,60 @@ class EmbeddingItems(ModalityDataItems[Union[torch.Tensor, list[torch.Tensor]],
return len(self.get(item_idx)) return len(self.get(item_idx))
class DictEmbeddingItems(ModalityDataItems[Mapping[str, torch.Tensor],
                                           Mapping[str, torch.Tensor]]):
    """
    Base class for data items that are expressed as a dictionary of tensors.

    Usually, the dictionary keys correspond to the outputs of HF processor.
    """

    def __init__(
        self,
        data: Mapping[str, torch.Tensor],
        modality: str,
        fields_config: Mapping[str, MultiModalFieldConfig],
        required_fields: set[str],
    ) -> None:
        """
        Validate the inputs and pre-compute the per-item keyword arguments.

        Args:
            data: The dictionary of tensors for this modality.
            modality: The modality name, forwarded to the parent class.
            fields_config: The field configuration used to build
                ``MultiModalKwargs`` from ``data``.
            required_fields: Field names that must appear both in
                ``fields_config`` and in ``data``.

        Raises:
            ValueError: If any required field is missing from
                ``fields_config`` or from ``data``.
        """
        super().__init__(data, modality)

        # Every required field must be described by the field config...
        missing_required_fields = required_fields - fields_config.keys()
        if missing_required_fields:
            fields = set(fields_config.keys())
            msg = f"{required_fields=} should be a subset of {fields=}"
            raise ValueError(msg)

        # ...and must actually be present in the supplied data.
        missing_required_data_keys = required_fields - data.keys()
        if missing_required_data_keys:
            data_keys = set(data.keys())
            msg = (f"The data should contain the fields: {required_fields}, "
                   f"but only found the following keys: {data_keys}")
            raise ValueError(msg)

        self.fields_config = fields_config
        self.required_fields = required_fields

        # `data` is copied into a plain dict before being wrapped, so the
        # caller's mapping object itself is not handed to `BatchFeature`.
        self._kwargs = MultiModalKwargs.from_hf_inputs(
            BatchFeature(dict(data)),
            fields_config,
        )

    def get_count(self) -> int:
        """Return the item count reported by the kwargs for this modality."""
        return self._kwargs.get_item_count(self.modality)

    def get(self, index: int) -> Mapping[str, torch.Tensor]:
        """Return the ``.data`` of each field of the item at ``index``."""
        return {
            k: v.data
            for k, v in self._kwargs.get_item(self.modality, index).items()
        }

    def get_processor_data(self) -> Mapping[str, object]:
        """Return an empty mapping: nothing is given to the HF processor."""
        return {}

    def get_passthrough_data(self) -> Mapping[str, object]:
        """Return the stored data unchanged as passthrough data."""
        return self.data
class AudioProcessorItems(ProcessorBatchItems[HfAudioItem]): class AudioProcessorItems(ProcessorBatchItems[HfAudioItem]):
def __init__(self, data: Sequence[HfAudioItem]) -> None: def __init__(self, data: Sequence[HfAudioItem]) -> None:
......
...@@ -23,7 +23,8 @@ from .hasher import MultiModalHasher ...@@ -23,7 +23,8 @@ from .hasher import MultiModalHasher
from .inputs import (MultiModalDataDict, MultiModalEncDecInputs, from .inputs import (MultiModalDataDict, MultiModalEncDecInputs,
MultiModalFieldConfig, MultiModalInputs, MultiModalKwargs, MultiModalFieldConfig, MultiModalInputs, MultiModalKwargs,
MultiModalKwargsItem, PlaceholderRange) MultiModalKwargsItem, PlaceholderRange)
from .parse import MultiModalDataItems, MultiModalDataParser from .parse import (DictEmbeddingItems, EmbeddingItems, MultiModalDataItems,
MultiModalDataParser)
if TYPE_CHECKING: if TYPE_CHECKING:
from .profiling import BaseDummyInputsBuilder from .profiling import BaseDummyInputsBuilder
...@@ -830,15 +831,34 @@ class BaseMultiModalProcessor(ABC, Generic[_I]): ...@@ -830,15 +831,34 @@ class BaseMultiModalProcessor(ABC, Generic[_I]):
mm_kwargs, mm_kwargs,
) )
def _hf_processor_applies_repl(
self,
prompt_text: str,
mm_items: MultiModalDataItems,
hf_processor_mm_kwargs: Mapping[str, object],
) -> bool:
"""
Return whether the HF processor applies prompt replacements.
For most HF processors, this should be :code:`True` when multi-modal
data items are passed, but :code:`False` when multi-modal embeddings
are passed.
"""
return not any(
isinstance(items, (EmbeddingItems, DictEmbeddingItems))
for items in mm_items.values())
def _apply_hf_processor_text_mm( def _apply_hf_processor_text_mm(
self, self,
prompt_text: str, prompt_text: str,
mm_items: MultiModalDataItems, mm_items: MultiModalDataItems,
hf_processor_mm_kwargs: Mapping[str, object], hf_processor_mm_kwargs: Mapping[str, object],
) -> tuple[list[int], MultiModalKwargs]: ) -> tuple[list[int], MultiModalKwargs, bool]:
""" """
Apply the HF processor on the prompt text and multi-modal data Apply the HF processor on the prompt text and multi-modal data
together. together.
In addition, return whether prompt replacements have been applied.
""" """
processor_data, passthrough_data = self._get_hf_mm_data(mm_items) processor_data, passthrough_data = self._get_hf_mm_data(mm_items)
...@@ -856,7 +876,13 @@ class BaseMultiModalProcessor(ABC, Generic[_I]): ...@@ -856,7 +876,13 @@ class BaseMultiModalProcessor(ABC, Generic[_I]):
self._get_mm_fields_config(processed_data, hf_processor_mm_kwargs), self._get_mm_fields_config(processed_data, hf_processor_mm_kwargs),
) )
return prompt_ids, mm_kwargs is_repl_applied = self._hf_processor_applies_repl(
prompt_text=prompt_text,
mm_items=mm_items,
hf_processor_mm_kwargs=hf_processor_mm_kwargs,
)
return prompt_ids, mm_kwargs, is_repl_applied
def _apply_hf_processor_text_only(self, prompt_text: str) -> list[int]: def _apply_hf_processor_text_only(self, prompt_text: str) -> list[int]:
""" """
...@@ -866,7 +892,7 @@ class BaseMultiModalProcessor(ABC, Generic[_I]): ...@@ -866,7 +892,7 @@ class BaseMultiModalProcessor(ABC, Generic[_I]):
correspond to each other, we create dummy multi-modal items correspond to each other, we create dummy multi-modal items
to go along with the text. to go along with the text.
""" """
prompt_ids, _ = self._apply_hf_processor_text_mm( prompt_ids, _, _ = self._apply_hf_processor_text_mm(
prompt_text=prompt_text, prompt_text=prompt_text,
mm_items=MultiModalDataItems({}), mm_items=MultiModalDataItems({}),
hf_processor_mm_kwargs={}, hf_processor_mm_kwargs={},
...@@ -908,7 +934,7 @@ class BaseMultiModalProcessor(ABC, Generic[_I]): ...@@ -908,7 +934,7 @@ class BaseMultiModalProcessor(ABC, Generic[_I]):
mm_counts, mm_counts,
) )
_, mm_kwargs = self._apply_hf_processor_text_mm( _, mm_kwargs, _ = self._apply_hf_processor_text_mm(
prompt_text=dummy_inputs.prompt_text, prompt_text=dummy_inputs.prompt_text,
mm_items=mm_items, mm_items=mm_items,
hf_processor_mm_kwargs=hf_processor_mm_kwargs, hf_processor_mm_kwargs=hf_processor_mm_kwargs,
...@@ -923,13 +949,17 @@ class BaseMultiModalProcessor(ABC, Generic[_I]): ...@@ -923,13 +949,17 @@ class BaseMultiModalProcessor(ABC, Generic[_I]):
hf_processor_mm_kwargs: Mapping[str, object], hf_processor_mm_kwargs: Mapping[str, object],
*, *,
enable_hf_prompt_replacement: bool, enable_hf_prompt_replacement: bool,
) -> tuple[list[int], MultiModalKwargs]: ) -> tuple[list[int], MultiModalKwargs, bool]:
""" """
Apply the HF processor on the prompt text and multi-modal data. Apply the HF processor on the prompt text and multi-modal data.
In addition, return whether prompt replacements have been applied
(for most HF processors, this should be :code:`True`).
Note: Note:
If :code:`enable_hf_prompt_replacement=False`, the prompt should If :code:`enable_hf_prompt_replacement=False`, we use HF processor
correspond to the multi-modal items. to perform prompt replacement if available; HF processor requires
that the prompt corresponds to multi-modal items.
""" """
if isinstance(prompt, str): if isinstance(prompt, str):
if enable_hf_prompt_replacement: if enable_hf_prompt_replacement:
...@@ -943,19 +973,19 @@ class BaseMultiModalProcessor(ABC, Generic[_I]): ...@@ -943,19 +973,19 @@ class BaseMultiModalProcessor(ABC, Generic[_I]):
else: else:
prompt_ids = self._apply_hf_processor_tokens_only(prompt) prompt_ids = self._apply_hf_processor_tokens_only(prompt)
mm_missing_kwargs = self._apply_hf_processor_mm_only( mm_kwargs = self._apply_hf_processor_mm_only(
mm_items=mm_items, mm_items=mm_items,
hf_processor_mm_kwargs=hf_processor_mm_kwargs, hf_processor_mm_kwargs=hf_processor_mm_kwargs,
) )
return prompt_ids, mm_missing_kwargs return prompt_ids, mm_kwargs, False
def _cached_apply_hf_processor( def _cached_apply_hf_processor(
self, self,
prompt: Union[str, list[int]], prompt: Union[str, list[int]],
mm_data_items: MultiModalDataItems, mm_data_items: MultiModalDataItems,
hf_processor_mm_kwargs: Mapping[str, object], hf_processor_mm_kwargs: Mapping[str, object],
) -> tuple[list[int], MultiModalKwargs]: ) -> tuple[list[int], MultiModalKwargs, bool]:
""" """
Apply the HF processor on the full prompt text, Apply the HF processor on the full prompt text,
caching the results and reusing cached results. caching the results and reusing cached results.
...@@ -992,8 +1022,13 @@ class BaseMultiModalProcessor(ABC, Generic[_I]): ...@@ -992,8 +1022,13 @@ class BaseMultiModalProcessor(ABC, Generic[_I]):
mm_missing_data_items = self._to_mm_items(mm_missing_data) mm_missing_data_items = self._to_mm_items(mm_missing_data)
# NOTE: `prompt` does not correspond to `mm_missing_data_items`, # NOTE: `prompt` does not correspond to `mm_missing_data_items`,
# so we need to pass `enable_hf_prompt_replacement=False` # so we can't apply prompt replacements until the new multimodal
prompt_ids, mm_missing_kwargs = self._apply_hf_processor_main( # items are combined with the cached multimodal items
(
prompt_ids,
mm_missing_kwargs,
is_repl_applied,
) = self._apply_hf_processor_main(
prompt=prompt, prompt=prompt,
mm_items=mm_missing_data_items, mm_items=mm_missing_data_items,
hf_processor_mm_kwargs=hf_processor_mm_kwargs, hf_processor_mm_kwargs=hf_processor_mm_kwargs,
...@@ -1036,7 +1071,7 @@ class BaseMultiModalProcessor(ABC, Generic[_I]): ...@@ -1036,7 +1071,7 @@ class BaseMultiModalProcessor(ABC, Generic[_I]):
mm_kwargs = MultiModalKwargs.from_items(merged_kw_items) mm_kwargs = MultiModalKwargs.from_items(merged_kw_items)
return prompt_ids, mm_kwargs return prompt_ids, mm_kwargs, is_repl_applied
def _bind_and_group_repls( def _bind_and_group_repls(
self, self,
...@@ -1047,18 +1082,6 @@ class BaseMultiModalProcessor(ABC, Generic[_I]): ...@@ -1047,18 +1082,6 @@ class BaseMultiModalProcessor(ABC, Generic[_I]):
it = (prompt_repl.bind(tokenizer) for prompt_repl in prompt_repls) it = (prompt_repl.bind(tokenizer) for prompt_repl in prompt_repls)
return dict(full_groupby_modality(it)) return dict(full_groupby_modality(it))
def _always_apply_prompt_replacements(self) -> bool:
"""
A flag which can be overridden so that
:meth:`_apply_prompt_replacements` is always called even if we
detect that HF has performed processing via
:meth:`_find_placeholders_by_modality`.
This is useful in cases where :meth:`_find_placeholders_by_modality`
cannot be reliably used to detect whether HF has performed processing.
"""
return False
def _apply_prompt_replacements( def _apply_prompt_replacements(
self, self,
token_ids: list[int], token_ids: list[int],
...@@ -1155,29 +1178,21 @@ class BaseMultiModalProcessor(ABC, Generic[_I]): ...@@ -1155,29 +1178,21 @@ class BaseMultiModalProcessor(ABC, Generic[_I]):
self, self,
mm_placeholders: Mapping[str, list[PlaceholderFeaturesInfo]], mm_placeholders: Mapping[str, list[PlaceholderFeaturesInfo]],
mm_item_counts: Mapping[str, int], mm_item_counts: Mapping[str, int],
*, ) -> None:
allow_missing: bool = False,
) -> Mapping[str, int]:
missing_repl_counts = dict[str, int]()
for modality, item_count in mm_item_counts.items(): for modality, item_count in mm_item_counts.items():
placeholders = mm_placeholders.get(modality, []) placeholders = mm_placeholders.get(modality, [])
if len(placeholders) != item_count and not allow_missing: if len(placeholders) != item_count:
raise RuntimeError( raise RuntimeError(
f"Expected there to be {item_count} prompt replacements " f"Expected there to be {item_count} prompt replacements "
f"corresponding to {item_count} {modality} items, but only " f"corresponding to {item_count} {modality} items, but "
f"found {len(placeholders)} prompt replacements! Either " f"instead found {len(placeholders)} prompt replacements! "
"the prompt text has missing/incorrect tokens for " "Either the prompt text has missing/incorrect tokens for "
"multi-modal inputs, or there is a problem with your " "multi-modal inputs, or there is a problem with your "
"implementation of merged multi-modal processor for this " "implementation of merged multi-modal processor for this "
"model (usually arising from an inconsistency between " "model (usually arising from an inconsistency between "
"`_call_hf_processor` and `_get_prompt_replacements`).") "`_call_hf_processor` and `_get_prompt_replacements`).")
missing_repl_counts[modality] = item_count - len(placeholders)
return missing_repl_counts
def apply( def apply(
self, self,
prompt: Union[str, list[int]], prompt: Union[str, list[int]],
...@@ -1217,7 +1232,11 @@ class BaseMultiModalProcessor(ABC, Generic[_I]): ...@@ -1217,7 +1232,11 @@ class BaseMultiModalProcessor(ABC, Generic[_I]):
else: else:
mm_hashes = None mm_hashes = None
prompt_ids, mm_kwargs = self._cached_apply_hf_processor( (
prompt_ids,
mm_kwargs,
is_repl_applied,
) = self._cached_apply_hf_processor(
prompt, prompt,
mm_items, mm_items,
hf_processor_mm_kwargs, hf_processor_mm_kwargs,
...@@ -1233,51 +1252,26 @@ class BaseMultiModalProcessor(ABC, Generic[_I]): ...@@ -1233,51 +1252,26 @@ class BaseMultiModalProcessor(ABC, Generic[_I]):
mm_item_counts = mm_items.get_all_counts() mm_item_counts = mm_items.get_all_counts()
self._validate_mm_kwargs(mm_kwargs, mm_item_counts) self._validate_mm_kwargs(mm_kwargs, mm_item_counts)
hf_mm_placeholders = self._find_mm_placeholders( if is_repl_applied:
mm_placeholders = self._find_mm_placeholders(
mm_prompt_repls, mm_prompt_repls,
prompt_ids, prompt_ids,
mm_item_counts, mm_item_counts,
) )
self._validate_mm_placeholders(mm_placeholders, mm_item_counts)
if self._always_apply_prompt_replacements():
mm_missing_repl_counts = mm_item_counts
mm_missing_repls = dict(mm_prompt_repls)
else:
mm_missing_repl_counts = self._validate_mm_placeholders(
hf_mm_placeholders,
mm_item_counts,
allow_missing=True,
)
mm_missing_repls = dict[str, list[BoundPromptReplacement]]()
for modality, missing_repl_count in mm_missing_repl_counts.items():
if missing_repl_count == 0:
mm_missing_repls[modality] = []
elif missing_repl_count == mm_item_counts.get(modality, 0):
mm_missing_repls[modality] = mm_prompt_repls[modality]
else:
raise ValueError("Partial prompt replacement within "
f"{modality=} is not supported")
# If HF processor already inserts placeholder tokens,
# there is no need for us to insert them
if all(len(repls) == 0 for repls in mm_missing_repls.values()):
tokenizer = self.info.get_tokenizer() tokenizer = self.info.get_tokenizer()
prompt = decode_tokens(tokenizer, prompt_ids) prompt = decode_tokens(tokenizer, prompt_ids)
mm_placeholders = hf_mm_placeholders
else: else:
( (
prompt_ids, prompt_ids,
prompt, prompt,
missing_mm_placeholders, mm_placeholders,
) = self._apply_prompt_replacements( ) = self._apply_prompt_replacements(
prompt_ids, prompt_ids,
mm_missing_repls, mm_prompt_repls,
mm_missing_repl_counts, mm_item_counts,
) )
mm_placeholders = {**hf_mm_placeholders, **missing_mm_placeholders}
self._validate_mm_placeholders(mm_placeholders, mm_item_counts) self._validate_mm_placeholders(mm_placeholders, mm_item_counts)
mm_placeholder_ranges = { mm_placeholder_ranges = {
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment