Unverified Commit eed11ebe authored by Cyrus Leung's avatar Cyrus Leung Committed by GitHub
Browse files

[VLM] Merged multi-modal processors for LLaVA-NeXT-Video and LLaVA-OneVision (#11717)


Signed-off-by: default avatarDarkLight1337 <tlleungac@connect.ust.hk>
parent 300acb83
......@@ -323,7 +323,7 @@ class Phi3VMultiModalProcessor(BaseMultiModalProcessor):
height=image_height,
)
def get_mm_max_tokens_per_item(self) -> Mapping[str, int]:
def get_mm_max_tokens_per_item(self, seq_len: int) -> Mapping[str, int]:
max_image_tokens = self._get_num_image_tokens(
image_width=MAX_IMAGE_FEATURE_SIZE_WIDTH,
image_height=MAX_IMAGE_FEATURE_SIZE_HEIGHT,
......@@ -415,12 +415,12 @@ class Phi3VMultiModalProcessor(BaseMultiModalProcessor):
def _apply_prompt_replacements(
self,
token_ids: list[int],
prompt_repls: Sequence[_BoundPromptReplacement],
mm_prompt_repls: Mapping[str, Sequence[_BoundPromptReplacement]],
mm_item_counts: Mapping[str, int],
) -> tuple[list[int], str, list[_PlaceholderInfo]]:
) -> tuple[list[int], str, Mapping[str, list[_PlaceholderInfo]]]:
token_ids, text, placeholders = super()._apply_prompt_replacements(
token_ids=token_ids,
prompt_repls=prompt_repls,
mm_prompt_repls=mm_prompt_repls,
mm_item_counts=mm_item_counts,
)
......@@ -428,15 +428,23 @@ class Phi3VMultiModalProcessor(BaseMultiModalProcessor):
if text.startswith("<s> <|image|>"):
text = text.replace("<s> <|image|>", "<s><|image|>", 1)
token_ids = [token_ids[0], *token_ids[2:]]
placeholders = [
_PlaceholderInfo(p.modality, p.start_idx - 1, p.replacement)
for p in placeholders
]
placeholders = {
modality: [
_PlaceholderInfo(
modality=p.modality,
item_idx=p.item_idx,
start_idx=p.start_idx - 1,
replacement=p.replacement,
) for p in ps
]
for modality, ps in placeholders.items()
}
return token_ids, text, placeholders
def _get_dummy_mm_inputs(
def _get_dummy_processor_inputs(
self,
seq_len: int,
mm_counts: Mapping[str, int],
) -> ProcessorInputs:
num_images = mm_counts.get("image", 0)
......
......@@ -780,15 +780,18 @@ class PixtralHFEncoderInfo(VisionEncoderInfo[PixtralVisionConfig]):
def get_max_image_tokens(self) -> int:
return get_max_pixtral_hf_image_tokens(self.vision_config)
def get_num_patches(self) -> int:
def get_image_size(self) -> int:
return self.vision_config.image_size
def get_patch_size(self) -> int:
return self.vision_config.patch_size
def get_patch_grid_length(self) -> int:
return get_pixtral_hf_patch_grid_length(
image_size=self.vision_config.image_size,
patch_size=self.vision_config.patch_size,
)
def get_image_size(self) -> int:
return self.vision_config.image_size
class PixtralHFMLP(nn.Module):
......
......@@ -84,7 +84,7 @@ class Qwen2AudioMultiModalProcessor(BaseMultiModalProcessor):
def get_supported_mm_limits(self) -> Mapping[str, Optional[int]]:
return {"audio": None}
def get_mm_max_tokens_per_item(self) -> Mapping[str, int]:
def get_mm_max_tokens_per_item(self, seq_len: int) -> Mapping[str, int]:
hf_config = self.ctx.get_hf_config(Qwen2AudioConfig)
max_source_positions = hf_config.audio_config.max_source_positions
max_output_lengths = (max_source_positions - 2) // 2 + 1
......@@ -184,15 +184,16 @@ class Qwen2AudioMultiModalProcessor(BaseMultiModalProcessor):
]
def _always_apply_prompt_replacements(self) -> bool:
# HF never applies prompt replacements, so we have to do it ourselves
# _find_placeholders may incorrectly think that HF has already performed
# processing for multi-audio input when the input audios are short
# (the corresponding placeholders may take up fewer tokens than
# the number of audio items)
# HF never applies prompt replacements, so we have to do it ourselves.
# NOTE: `_find_placeholders_by_modality` may incorrectly think that HF
# has already performed processing for multi-audio input when the input
# audios are short (the corresponding placeholders may take up fewer
# tokens than the number of audio items)
return True
def _get_dummy_mm_inputs(
def _get_dummy_processor_inputs(
self,
seq_len: int,
mm_counts: Mapping[str, int],
) -> ProcessorInputs:
feature_extractor = self._get_feature_extractor()
......
......@@ -56,7 +56,8 @@ from vllm.multimodal import MULTIMODAL_REGISTRY
from vllm.multimodal.inputs import (ImageItem, ModalityData,
MultiModalFieldConfig, MultiModalKwargs,
NestedTensors, VideoItem)
from vllm.multimodal.parse import ModalityDataItems, MultiModalDataParser
from vllm.multimodal.parse import (ImageSize, ModalityDataItems,
MultiModalDataParser)
from vllm.multimodal.processing import (BaseMultiModalProcessor,
MultiModalDataItems, ProcessorInputs,
PromptReplacement)
......@@ -641,58 +642,6 @@ class Qwen2VisionTransformer(nn.Module):
return loaded_params
# === Vision input helpers === #
def _get_vision_info(
vision_config: Qwen2VLVisionConfig,
height: int,
width: int,
min_pixels: int,
max_pixels: int,
*,
do_resize: bool = True,
modality: str = "image",
mm_count: int = 1,
):
"""Get information (resized height / width and number of vision tokens)
of input image / video frame."""
patch_size = vision_config.patch_size
merge_size = vision_config.spatial_merge_size
temporal_patch_size = vision_config.temporal_patch_size
if do_resize:
resized_height, resized_width = smart_resize(
height=height,
width=width,
factor=patch_size * merge_size,
min_pixels=min_pixels,
max_pixels=max_pixels,
)
else:
resized_height, resized_width = height, width
if modality == "image":
grid_t = mm_count
elif modality == "video":
grid_t = max(mm_count // temporal_patch_size, 1)
else:
raise ValueError(f"Modality {modality} is not supported")
grid_h = resized_height // patch_size
grid_w = resized_width // patch_size
vision_tokens = grid_t * grid_h * grid_w
llm_num_vision_tokens = vision_tokens // (merge_size**2)
return resized_height, resized_width, llm_num_vision_tokens
def _get_image_processor(hf_processor: Qwen2VLProcessor):
image_processor = hf_processor.image_processor # type: ignore
assert isinstance(image_processor, Qwen2VLImageProcessor)
return image_processor
class Qwen2EmbeddingItems(ModalityDataItems[dict[str, torch.Tensor],
dict[str, torch.Tensor]]):
......@@ -764,32 +713,111 @@ class Qwen2VLMultiModalProcessor(BaseMultiModalProcessor):
def get_supported_mm_limits(self) -> Mapping[str, Optional[int]]:
return {"image": None, "video": None}
def _get_max_mm_tokens(self, modality: str) -> int:
def _get_vision_info(
self,
*,
image_width: int,
image_height: int,
num_frames: int = 1,
do_resize: bool = True,
) -> tuple[ImageSize, int]:
hf_config = self.ctx.get_hf_config(Qwen2VLConfig)
vision_config = hf_config.vision_config
patch_size = vision_config.patch_size
merge_size = vision_config.spatial_merge_size
temporal_patch_size = vision_config.temporal_patch_size
hf_processor = self._get_hf_processor()
image_processor = _get_image_processor(hf_processor)
_, _, max_llm_image_tokens = _get_vision_info(
vision_config,
height=9999999,
width=9999999,
min_pixels=image_processor.min_pixels,
max_pixels=image_processor.max_pixels,
modality=modality,
image_processor = self._get_image_processor(hf_processor)
if do_resize:
resized_height, resized_width = smart_resize(
height=image_height,
width=image_width,
factor=patch_size * merge_size,
min_pixels=image_processor.min_pixels,
max_pixels=image_processor.max_pixels,
)
preprocessed_size = ImageSize(width=resized_width,
height=resized_height)
else:
preprocessed_size = ImageSize(width=image_width,
height=image_height)
grid_t = max(num_frames // temporal_patch_size, 1)
grid_h = preprocessed_size.height // patch_size
grid_w = preprocessed_size.width // patch_size
num_patches = grid_t * grid_h * grid_w
num_vision_tokens = num_patches // (merge_size**2)
return preprocessed_size, num_vision_tokens
def _get_dummy_image_size(self) -> ImageSize:
max_image_size, _ = self._get_vision_info(
image_width=9999999,
image_height=9999999,
)
return max_image_size
def _get_max_image_tokens(self) -> int:
_, max_image_tokens = self._get_vision_info(
image_width=9999999,
image_height=9999999,
)
return max_image_tokens
def _get_max_video_tokens(self, num_frames: int) -> int:
_, max_video_tokens = self._get_vision_info(
image_width=9999999,
image_height=9999999,
num_frames=num_frames,
)
return max_llm_image_tokens
return max_video_tokens
def _get_max_video_frames(self, max_tokens: int) -> int:
num_frames = 0
while True:
next_num_frames = num_frames + 1
if self._get_max_video_tokens(next_num_frames) > max_tokens:
break
num_frames = next_num_frames
return num_frames
def _get_dummy_num_frames(self, seq_len: int) -> int:
mm_config = self.ctx.get_mm_config()
max_images = mm_config.limit_per_prompt.get("image", 1)
max_videos = mm_config.limit_per_prompt.get("video", 1)
max_image_tokens = self._get_max_image_tokens() * max_images
max_total_frames = self._get_max_video_frames(seq_len -
max_image_tokens)
return max(max_total_frames // max(max_videos, 1), 1)
def get_mm_max_tokens_per_item(self, seq_len: int) -> Mapping[str, int]:
max_image_tokens = self._get_max_image_tokens()
num_frames = self._get_dummy_num_frames(seq_len)
max_video_tokens = self._get_max_video_tokens(num_frames)
def get_mm_max_tokens_per_item(self) -> Mapping[str, int]:
return {
"image": self._get_max_mm_tokens("image"),
"video": self._get_max_mm_tokens("video"),
"image": max_image_tokens,
"video": max_video_tokens,
}
def _get_data_parser(self) -> MultiModalDataParser:
return Qwen2MultiModalDataParser()
def _get_image_processor(self, hf_processor: Qwen2VLProcessor):
image_processor = hf_processor.image_processor # type: ignore
assert isinstance(image_processor, Qwen2VLImageProcessor)
return image_processor
def _get_hf_processor(
self,
*,
......@@ -797,7 +825,7 @@ class Qwen2VLMultiModalProcessor(BaseMultiModalProcessor):
max_pixels: Optional[int] = None,
) -> Qwen2VLProcessor:
hf_processor = self.ctx.get_hf_processor(Qwen2VLProcessor)
image_processor = _get_image_processor(hf_processor)
image_processor = self._get_image_processor(hf_processor)
if min_pixels:
image_processor.min_pixels = min_pixels
......@@ -818,7 +846,7 @@ class Qwen2VLMultiModalProcessor(BaseMultiModalProcessor):
out_mm_kwargs: MultiModalKwargs,
) -> list[PromptReplacement]:
hf_processor = self._get_hf_processor()
image_processor = _get_image_processor(hf_processor)
image_processor = self._get_image_processor(hf_processor)
# NOTE: Only Qwen2VLProcessor in transformers 4.47.0 has
# image_token and video_token registered
......@@ -873,32 +901,35 @@ class Qwen2VLMultiModalProcessor(BaseMultiModalProcessor):
video_grid_thw=MultiModalFieldConfig.batched("video"),
)
def _get_dummy_mm_inputs(
def _get_dummy_processor_inputs(
self,
seq_len: int,
mm_counts: Mapping[str, int],
) -> ProcessorInputs:
hf_processor = self._get_hf_processor()
image_processor = _get_image_processor(hf_processor)
num_images = mm_counts.get("image", 0)
num_videos = mm_counts.get("video", 0)
hf_processor = self._get_hf_processor()
image_token: str = hf_processor.image_token
resized_height, resized_width = smart_resize(
height=9999999,
width=9999999,
factor=image_processor.patch_size * image_processor.merge_size,
min_pixels=image_processor.min_pixels,
max_pixels=image_processor.max_pixels,
)
num_images = mm_counts.get("image", 0)
video_token: str = hf_processor.video_token
target_width, target_height = self._get_dummy_image_size()
mm_data = {
"image":
self._get_dummy_images(width=resized_width,
height=resized_height,
num_images=num_images)
self._get_dummy_images(width=target_width,
height=target_height,
num_images=num_images),
"video":
self._get_dummy_videos(
width=target_width,
height=target_height,
num_frames=self._get_dummy_num_frames(seq_len),
num_videos=num_videos,
)
}
return ProcessorInputs(
prompt_text=image_token * num_images,
prompt_text=image_token * num_images + video_token * num_videos,
mm_data=mm_data,
)
......
......@@ -171,15 +171,18 @@ class SiglipEncoderInfo(VisionEncoderInfo[SiglipVisionConfig]):
def get_max_image_tokens(self) -> int:
return get_max_siglip_image_tokens(self.vision_config)
def get_num_patches(self) -> int:
def get_image_size(self) -> int:
return self.vision_config.image_size
def get_patch_size(self) -> int:
return self.vision_config.patch_size
def get_patch_grid_length(self) -> int:
return get_siglip_patch_grid_length(
image_size=self.vision_config.image_size,
patch_size=self.vision_config.patch_size,
)
def get_image_size(self) -> int:
return self.vision_config.image_size
# Adapted from https://github.com/huggingface/transformers/blob/v4.43.3/src/transformers/models/siglip/modeling_siglip.py#L249 # noqa
class SiglipVisionEmbeddings(nn.Module):
......
......@@ -6,7 +6,6 @@ from functools import cached_property
from typing import (Iterable, List, Literal, Mapping, Optional, Set, Tuple,
TypedDict, Union)
import numpy as np
import torch
import torch.utils.checkpoint
from torch import nn
......@@ -31,7 +30,6 @@ from vllm.multimodal.processing import (BaseMultiModalProcessor,
PromptReplacement)
from vllm.sequence import IntermediateTensors
from vllm.transformers_utils.configs.ultravox import UltravoxConfig
from vllm.utils import is_list_of
from .interfaces import SupportsMultiModal, SupportsPP
from .utils import (AutoWeightsLoader, WeightsMapper, flatten_bn,
......@@ -62,7 +60,7 @@ class UltravoxMultiModalProcessor(BaseMultiModalProcessor):
def get_supported_mm_limits(self) -> Mapping[str, Optional[int]]:
return {"audio": None}
def get_mm_max_tokens_per_item(self) -> Mapping[str, int]:
def get_mm_max_tokens_per_item(self, seq_len: int) -> Mapping[str, int]:
feature_extractor = self._get_feature_extractor()
max_audio_tokens = math.ceil(feature_extractor.chunk_length *
_AUDIO_TOKENS_PER_SECOND)
......@@ -103,6 +101,7 @@ class UltravoxMultiModalProcessor(BaseMultiModalProcessor):
mm_data = dict(mm_data)
audios = mm_data.pop("audios", [])
assert isinstance(audios, list)
if not audios:
return super()._call_hf_processor(
......@@ -117,9 +116,6 @@ class UltravoxMultiModalProcessor(BaseMultiModalProcessor):
sampling_rate=feature_extractor.sampling_rate,
)
# Already resampled by _get_hf_mm_data
assert is_list_of(audios, np.ndarray)
# Ultravox processor doesn't support multiple inputs,
# therefore we need to input text and audio one by one
audio_features, audio_token_len = [], []
......@@ -177,8 +173,9 @@ class UltravoxMultiModalProcessor(BaseMultiModalProcessor):
)
]
def _get_dummy_mm_inputs(
def _get_dummy_processor_inputs(
self,
seq_len: int,
mm_counts: Mapping[str, int],
) -> ProcessorInputs:
feature_extractor = self._get_feature_extractor()
......
from abc import ABC, abstractmethod
from typing import Generic, TypeVar
from typing import Final, Generic, Optional, Protocol, TypeVar
from transformers import PretrainedConfig
from vllm.multimodal.processing import (BaseMultiModalProcessor,
InputProcessingContext,
ProcessingCache)
_C = TypeVar("_C", bound=PretrainedConfig)
......@@ -27,11 +31,15 @@ class VisionEncoderInfo(ABC, Generic[_C]):
raise NotImplementedError
@abstractmethod
def get_num_patches(self) -> int:
def get_image_size(self) -> int:
raise NotImplementedError
@abstractmethod
def get_image_size(self) -> int:
def get_patch_size(self) -> int:
raise NotImplementedError
@abstractmethod
def get_patch_grid_length(self) -> int:
raise NotImplementedError
......@@ -50,3 +58,26 @@ def vision_encoder_info(vision_config: PretrainedConfig) -> VisionEncoderInfo:
msg = f"Unsupported vision config: {type(vision_config)}"
raise NotImplementedError(msg)
class VisionLanguageConfig(Protocol):
vision_config: Final[PretrainedConfig]
class BaseVisionLanguageMultiModalProcessor(BaseMultiModalProcessor):
def __init__(self,
ctx: InputProcessingContext,
*,
cache: Optional[ProcessingCache] = None,
enable_sanity_checks: bool = True) -> None:
super().__init__(ctx,
cache=cache,
enable_sanity_checks=enable_sanity_checks)
vision_config = self._get_hf_config().vision_config
self._vision_encoder_info = vision_encoder_info(vision_config)
@abstractmethod
def _get_hf_config(self) -> VisionLanguageConfig:
raise NotImplementedError
......@@ -146,6 +146,20 @@ class VideoProcessorItems(ProcessorBatchItems[HfVideoItem]):
def __init__(self, data: Sequence[HfVideoItem]) -> None:
super().__init__(data, "video")
def get_num_frames(self, item_idx: int) -> int:
return len(self.get(item_idx))
def get_frame_size(self, item_idx: int) -> ImageSize:
image = self.get(item_idx)[0] # Assume that the video isn't empty
if isinstance(image, Image):
return ImageSize(*image.size)
if isinstance(image, (np.ndarray, torch.Tensor)):
_, h, w = image.shape
return ImageSize(w, h)
assert_never(image)
class VideoEmbeddingItems(EmbeddingItems):
......
......@@ -16,7 +16,8 @@ from transformers import BatchFeature, ProcessorMixin
from vllm.inputs import DummyData, InputProcessingContext
from vllm.logger import init_logger
from vllm.transformers_utils.tokenizer import AnyTokenizer, encode_tokens
from vllm.transformers_utils.tokenizer import (AnyTokenizer, decode_tokens,
encode_tokens)
from vllm.utils import LRUCache, flatten_2d_lists, full_groupby
from .inputs import (MultiModalDataDict, MultiModalFieldConfig,
......@@ -69,19 +70,6 @@ def _cached_encode(
add_special_tokens=add_special_tokens)
def _decode(
tokenizer: AnyTokenizer,
token_ids: list[int],
*,
skip_special_tokens: bool = False,
) -> str:
"""
Backend-agnostic equivalent of HF's
:code:`tokenizer.decode(token_ids, skip_special_tokens=...)`.
"""
return tokenizer.decode(token_ids, skip_special_tokens=skip_special_tokens)
@lru_cache(maxsize=2048)
def _cached_decode(
tokenizer: AnyTokenizer,
......@@ -89,9 +77,9 @@ def _cached_decode(
*,
skip_special_tokens: bool = False,
) -> str:
return _decode(tokenizer,
list(token_ids),
skip_special_tokens=skip_special_tokens)
return decode_tokens(tokenizer,
list(token_ids),
skip_special_tokens=skip_special_tokens)
class _HasModalityAttr(Protocol):
......@@ -269,8 +257,10 @@ class _PromptReplacementTextMatch(_PromptReplacementMatch):
return self.match.end()
class _PlaceholderInfo(NamedTuple):
@dataclass
class _PlaceholderInfo:
modality: str
item_idx: int
start_idx: int
replacement: list[int]
......@@ -311,12 +301,14 @@ def find_text_matches(
def _resolve_matches(
prompt: _PromptSeq,
matches: Sequence[_PromptReplacementMatch],
mm_matches: Mapping[str, Sequence[_PromptReplacementMatch]],
) -> list[_PromptReplacementMatch]:
"""
Resolve :code:`matches` to ensure that there are no overlapping matches,
Resolve :code:`mm_matches` to ensure that there are no overlapping matches,
and sort them such that earlier matches take priority over later ones.
"""
matches = [m for matches in mm_matches.values() for m in matches]
seen_matches: list[Optional[_PromptReplacementMatch]] = [None
] * len(prompt)
......@@ -334,14 +326,15 @@ def _resolve_matches(
def _replace_matches(
prompt: _S,
matches: Sequence[_PromptReplacementMatch],
mm_matches: Mapping[str, Sequence[_PromptReplacementMatch]],
mm_item_counts: Mapping[str, int],
) -> list[_S]:
"""Apply the replacements in :code:`mm_matches` to :code:`prompt`."""
out_seqs = list[_S]()
prev_end_idx = 0
next_idx_by_modality = defaultdict[str, int](lambda: 0)
for match in _resolve_matches(prompt, matches):
for match in _resolve_matches(prompt, mm_matches):
modality = match.modality
item_idx = next_idx_by_modality[modality]
......@@ -371,28 +364,28 @@ def _replace_matches(
def replace_token_matches(
prompt: list[int],
matches: Sequence[_PromptReplacementTokenMatch],
mm_matches: Mapping[str, Sequence[_PromptReplacementTokenMatch]],
mm_item_counts: Mapping[str, int],
) -> list[int]:
"""Apply :code:`prompt_repls` to :code:`prompt`."""
if not matches:
"""Apply the replacements in :code:`mm_matches` to :code:`prompt`."""
if not mm_matches:
return prompt
token_id_seqs = _replace_matches(prompt, matches, mm_item_counts)
token_id_seqs = _replace_matches(prompt, mm_matches, mm_item_counts)
return flatten_2d_lists(token_id_seqs)
def replace_text_matches(
prompt: str,
matches: Sequence[_PromptReplacementTextMatch],
mm_matches: Mapping[str, Sequence[_PromptReplacementTextMatch]],
mm_item_counts: Mapping[str, int],
) -> str:
"""Apply :code:`prompt_repls` to :code:`prompt`."""
if not matches:
"""Apply the replacements in :code:`mm_matches` to :code:`prompt`."""
if not mm_matches:
return prompt
texts = _replace_matches(prompt, matches, mm_item_counts)
texts = _replace_matches(prompt, mm_matches, mm_item_counts)
return "".join(texts)
......@@ -407,14 +400,14 @@ def _iter_modality_placeholders(
return
prompt_len = len(prompt)
item_index = 0
item_idx = 0
start_idx = 0
while start_idx < prompt_len:
found = False
for repl_info in modality_repls:
replacement = repl_info.get_replacement(item_index)
replacement = repl_info.get_replacement(item_idx)
repl_tokens = replacement.token_ids
repl_len = len(repl_tokens)
end_idx = start_idx + repl_len
......@@ -425,12 +418,13 @@ def _iter_modality_placeholders(
if prompt[start_idx:end_idx] == repl_tokens:
yield _PlaceholderInfo(
modality=modality,
item_idx=item_idx,
start_idx=start_idx,
replacement=repl_tokens,
)
item_index += 1
if item_index >= modal_item_count:
item_idx += 1
if item_idx >= modal_item_count:
return
# Exclude overlapping matches
......@@ -442,28 +436,36 @@ def _iter_modality_placeholders(
start_idx += 1
def iter_placeholders(
prompt_repls: Sequence[_BoundPromptReplacement],
def _iter_placeholders(
mm_prompt_repls: Mapping[str, Sequence[_BoundPromptReplacement]],
prompt: list[int],
mm_item_counts: Mapping[str, int],
) -> Iterable[_PlaceholderInfo]:
"""
Yield each set of placeholder tokens found in :code:`prompt`.
For each modality, yield each set of placeholder tokens found in
:code:`prompt`.
Note that empty matches are ignored.
"""
repls_by_modality = dict(full_groupby_modality(prompt_repls))
for modality, modal_item_count in mm_item_counts.items():
if modality in repls_by_modality:
if modality in mm_prompt_repls:
yield from _iter_modality_placeholders(
prompt,
modality,
repls_by_modality[modality],
mm_prompt_repls[modality],
modal_item_count,
)
def find_mm_placeholders(
mm_prompt_repls: Mapping[str, Sequence[_BoundPromptReplacement]],
prompt: list[int],
mm_item_counts: Mapping[str, int],
) -> Mapping[str, list[_PlaceholderInfo]]:
it = _iter_placeholders(mm_prompt_repls, prompt, mm_item_counts)
return dict(full_groupby_modality(it))
@dataclass
class ProcessorInputs:
"""Keyword arguments to :meth:`BaseMultiModalProcessor`."""
......@@ -620,7 +622,7 @@ class BaseMultiModalProcessor(ABC):
raise NotImplementedError
@abstractmethod
def get_mm_max_tokens_per_item(self) -> Mapping[str, int]:
def get_mm_max_tokens_per_item(self, seq_len: int) -> Mapping[str, int]:
"""
Get the maximum possible number of tokens per data item
for each modality.
......@@ -703,14 +705,14 @@ class BaseMultiModalProcessor(ABC):
"""
raise NotImplementedError
def _find_placeholders(
def _find_mm_placeholders(
self,
all_prompt_repls: Sequence[_BoundPromptReplacement],
mm_prompt_repls: Mapping[str, Sequence[_BoundPromptReplacement]],
new_token_ids: list[int],
mm_item_counts: Mapping[str, int],
) -> list[_PlaceholderInfo]:
return list(
iter_placeholders(all_prompt_repls, new_token_ids, mm_item_counts))
) -> Mapping[str, list[_PlaceholderInfo]]:
return find_mm_placeholders(mm_prompt_repls, new_token_ids,
mm_item_counts)
def _get_hf_mm_data(
self,
......@@ -797,7 +799,10 @@ class BaseMultiModalProcessor(ABC):
# Some HF processors (e.g. Qwen2-VL) expect corresponding
# multi-modal tokens to be in the prompt text
dummy_inputs = self._get_dummy_mm_inputs(mm_missing_counts)
dummy_inputs = self._get_dummy_processor_inputs(
self.ctx.model_config.max_model_len,
mm_missing_counts,
)
_, mm_missing_kwargs = self._apply_hf_processor(
prompt_text=dummy_inputs.prompt_text,
......@@ -889,50 +894,44 @@ class BaseMultiModalProcessor(ABC):
mm_kwargs = MultiModalKwargs.from_items(merged_kw_items)
if self.enable_sanity_checks:
mm_item_counts = mm_data_items.get_all_counts()
for modality, item_count in mm_item_counts.items():
for item_idx in range(item_count):
try:
mm_kwargs.get_item(modality, item_idx)
except Exception as e:
# Make it easy to set a breakpoint in the debugger
raise e
return prompt_ids, mm_kwargs
def _bind_prompt_replacements(
def _bind_and_group_repls(
self,
prompt_repls: list[PromptReplacement],
) -> list[_BoundPromptReplacement]:
) -> dict[str, list[_BoundPromptReplacement]]:
tokenizer = self._get_tokenizer()
return [prompt_repl.bind(tokenizer) for prompt_repl in prompt_repls]
it = (prompt_repl.bind(tokenizer) for prompt_repl in prompt_repls)
return dict(full_groupby_modality(it))
def _always_apply_prompt_replacements(self) -> bool:
"""
A flag which can be overridden so that
:meth:`_apply_prompt_replacements` is always called even if we
detect that HF has performed processing via :meth:`_find_placeholders`.
detect that HF has performed processing via
:meth:`_find_placeholders_by_modality`.
This is useful in cases where :meth:`_find_placeholders` cannot be
reliably used to detect whether HF has performed processing or not.
This is useful in cases where :meth:`_find_placeholders_by_modality`
cannot be reliably used to detect whether HF has performed processing.
"""
return False
def _apply_prompt_replacements(
self,
token_ids: list[int],
prompt_repls: Sequence[_BoundPromptReplacement],
mm_prompt_repls: Mapping[str, Sequence[_BoundPromptReplacement]],
mm_item_counts: Mapping[str, int],
) -> tuple[list[int], str, list[_PlaceholderInfo]]:
) -> tuple[list[int], str, Mapping[str, list[_PlaceholderInfo]]]:
tokenizer = self._get_tokenizer()
token_matches = find_token_matches(token_ids, prompt_repls)
mm_token_matches = {
modality: find_token_matches(token_ids, prompt_repls)
for modality, prompt_repls in mm_prompt_repls.items()
}
mm_match_counts = {
modality: len(matches)
for modality, matches in full_groupby_modality(token_matches)
for modality, matches in mm_token_matches.items()
}
# If the search text does not represent a special token,
......@@ -951,32 +950,92 @@ class BaseMultiModalProcessor(ABC):
): # yapf: disable
token_ids = replace_token_matches(
token_ids,
token_matches,
mm_token_matches,
mm_item_counts,
)
text = _decode(tokenizer, token_ids)
matched_repls = [match.prompt_repl for match in token_matches]
text = decode_tokens(tokenizer, token_ids)
matched_repls = {
modality: [match.prompt_repl for match in token_matches]
for modality, token_matches in mm_token_matches.items()
}
else:
text = _decode(tokenizer, token_ids)
text = decode_tokens(tokenizer, token_ids)
text_matches = find_text_matches(text, prompt_repls)
mm_text_matches = {
modality: find_text_matches(text, prompt_repls)
for modality, prompt_repls in mm_prompt_repls.items()
}
text = replace_text_matches(
text,
text_matches,
mm_text_matches,
mm_item_counts,
)
token_ids = encode_tokens(tokenizer,
text,
add_special_tokens=False)
matched_repls = [match.prompt_repl for match in text_matches]
placeholders = self._find_placeholders(matched_repls, token_ids,
mm_item_counts)
matched_repls = {
modality: [match.prompt_repl for match in token_matches]
for modality, token_matches in mm_text_matches.items()
}
placeholders = self._find_mm_placeholders(
matched_repls,
token_ids,
mm_item_counts,
)
return token_ids, text, placeholders
def _validate_mm_kwargs(
self,
mm_kwargs: MultiModalKwargs,
mm_item_counts: Mapping[str, int],
) -> None:
for modality, item_count in mm_item_counts.items():
if modality in mm_kwargs.modalities:
items = mm_kwargs.get_items(modality)
else:
items = []
if len(items) != item_count:
raise RuntimeError(
f"Expected there to be {item_count} {modality} items in "
f"keyword arguments corresponding to {item_count} "
f"{modality} data items, but only found {len(items)}! "
"There is likely a problem with your "
"implementation of merged multi-modal processor for this "
"model (usually arising from an inconsistency between "
"`_call_hf_processor` and `_get_mm_fields_config`).")
def _validate_mm_placeholders(
self,
mm_placeholders: Mapping[str, list[_PlaceholderInfo]],
mm_item_counts: Mapping[str, int],
*,
allow_missing: bool = False,
) -> Mapping[str, int]:
missing_repl_counts = dict[str, int]()
for modality, item_count in mm_item_counts.items():
placeholders = mm_placeholders.get(modality, [])
if len(placeholders) != item_count and not allow_missing:
raise RuntimeError(
f"Expected there to be {item_count} prompt replacements "
f"corresponding to {item_count} {modality} items, but only "
f"found {len(placeholders)} prompt replacements! Either "
"the prompt text has missing/incorrect tokens for "
"multi-modal inputs, or there is a problem with your "
"implementation of merged multi-modal processor for this "
"model (usually arising from an inconsistency between "
"`_call_hf_processor` and `_get_prompt_replacements`).")
missing_repl_counts[modality] = item_count - len(placeholders)
return missing_repl_counts
def apply(
self,
prompt_text: str,
......@@ -1009,56 +1068,69 @@ class BaseMultiModalProcessor(ABC):
hf_processor_mm_kwargs,
mm_kwargs,
)
prompt_repls = self._bind_prompt_replacements(unbound_prompt_repls)
mm_prompt_repls = self._bind_and_group_repls(unbound_prompt_repls)
# If HF processor already inserts placeholder tokens,
# there is no need for us to insert them
mm_item_counts = mm_items.get_all_counts()
all_placeholders = self._find_placeholders(prompt_repls, prompt_ids,
mm_item_counts)
self._validate_mm_kwargs(mm_kwargs, mm_item_counts)
hf_mm_placeholders = self._find_mm_placeholders(
mm_prompt_repls,
prompt_ids,
mm_item_counts,
)
if self._always_apply_prompt_replacements():
mm_missing_repl_counts = mm_item_counts
mm_missing_repls = dict(mm_prompt_repls)
else:
mm_missing_repl_counts = self._validate_mm_placeholders(
hf_mm_placeholders,
mm_item_counts,
allow_missing=True,
)
mm_missing_repls = dict[str, list[_BoundPromptReplacement]]()
for modality, missing_repl_count in mm_missing_repl_counts.items():
if missing_repl_count == 0:
mm_missing_repls[modality] = []
elif missing_repl_count == mm_item_counts.get(modality, 0):
mm_missing_repls[modality] = mm_prompt_repls[modality]
else:
raise ValueError("Partial prompt replacement within "
f"{modality=} is not supported")
if all_placeholders and not self._always_apply_prompt_replacements():
# If HF processor already inserts placeholder tokens,
# there is no need for us to insert them
if all(len(repls) == 0 for repls in mm_missing_repls.items()):
tokenizer = self._get_tokenizer()
prompt_text = _decode(tokenizer, prompt_ids)
prompt_text = decode_tokens(tokenizer, prompt_ids)
mm_placeholders = hf_mm_placeholders
else:
(
prompt_ids,
prompt_text,
all_placeholders,
missing_mm_placeholders,
) = self._apply_prompt_replacements(
prompt_ids,
prompt_repls,
mm_item_counts,
mm_missing_repls,
mm_missing_repl_counts,
)
mm_placeholders = dict[str, list[PlaceholderRange]]()
err_suffix = ("This suggests a problem with your implementation of "
"the merged multi-modal processor for this model, "
"particularly in the `_get_prompt_replacements` method.")
for modality, placeholders in full_groupby_modality(all_placeholders):
if modality not in mm_items:
raise AssertionError(
f"Expected no placeholders for {modality=}, "
f"but found {placeholders=}. Input items: {mm_items}"
f"\n{err_suffix}")
if len(placeholders) != len(mm_items[modality]):
raise AssertionError(
f"Expected length of {placeholders=} for {modality=} "
f"to equal that of input items: {mm_items[modality]}"
f"\n{err_suffix}")
mm_placeholders[modality] = [
item.to_range() for item in placeholders
]
mm_placeholders = {**hf_mm_placeholders, **missing_mm_placeholders}
self._validate_mm_placeholders(mm_placeholders, mm_item_counts)
mm_placeholder_ranges = {
modality: [item.to_range() for item in placeholders]
for modality, placeholders in mm_placeholders.items()
}
return MultiModalInputsV2(
type="multimodal",
prompt=prompt_text,
prompt_token_ids=prompt_ids,
mm_kwargs=mm_kwargs,
mm_placeholders=mm_placeholders,
mm_placeholders=mm_placeholder_ranges,
)
def _get_dummy_audios(
......@@ -1092,8 +1164,9 @@ class BaseMultiModalProcessor(ABC):
return [video] * num_videos
@abstractmethod
def _get_dummy_mm_inputs(
def _get_dummy_processor_inputs(
self,
seq_len: int,
mm_counts: Mapping[str, int],
) -> ProcessorInputs:
"""
......@@ -1121,12 +1194,25 @@ class BaseMultiModalProcessor(ABC):
return mm_limits
def _get_dummy_mm_inputs(
self,
seq_len: int,
mm_counts: Mapping[str, int],
) -> MultiModalInputsV2:
processor_inputs = self._get_dummy_processor_inputs(seq_len, mm_counts)
return self.apply(
prompt_text=processor_inputs.prompt_text,
mm_data=processor_inputs.mm_data,
hf_processor_mm_kwargs=processor_inputs.hf_processor_mm_kwargs,
)
def get_dummy_data(self, seq_len: int) -> DummyData:
# Avoid circular import
from vllm.sequence import SequenceData
mm_counts = self._get_and_validate_dummy_mm_counts()
mm_max_tokens_per_item = self.get_mm_max_tokens_per_item()
mm_max_tokens_per_item = self.get_mm_max_tokens_per_item(seq_len)
if mm_counts.keys() != mm_max_tokens_per_item.keys():
raise AssertionError(
"The keys returned by `get_supported_mm_limits`"
......@@ -1134,13 +1220,7 @@ class BaseMultiModalProcessor(ABC):
"returned by `get_mm_max_tokens_per_item` "
f"({set(mm_max_tokens_per_item.keys())})")
processor_inputs = self._get_dummy_mm_inputs(mm_counts)
mm_inputs = self.apply(
prompt_text=processor_inputs.prompt_text,
mm_data=processor_inputs.mm_data,
hf_processor_mm_kwargs=processor_inputs.hf_processor_mm_kwargs,
)
mm_inputs = self._get_dummy_mm_inputs(seq_len, mm_counts)
prompt_token_ids = mm_inputs["prompt_token_ids"]
placeholders_by_modality = mm_inputs["mm_placeholders"]
......@@ -1171,6 +1251,12 @@ class BaseMultiModalProcessor(ABC):
"reduce `max_num_seqs`, and/or reduce `mm_counts`.", seq_len,
total_len, total_placeholders_by_modality)
return DummyData(
seq_data=SequenceData.from_prompt_token_counts((0, seq_len)),
multi_modal_data=None,
multi_modal_placeholders=None,
)
prompt_token_ids.extend([0] * (seq_len - len(prompt_token_ids)))
return DummyData(
......
......@@ -223,7 +223,8 @@ class MultiModalRegistry:
if self.has_processor(model_config):
tokenizer = cached_get_tokenizer(model_config.tokenizer)
processor = self.create_processor(model_config, tokenizer)
return processor.get_mm_max_tokens_per_item()
seq_len = model_config.max_model_len
return processor.get_mm_max_tokens_per_item(seq_len)
return {
key: plugin.get_max_multimodal_tokens(model_config)
......
......@@ -21,6 +21,19 @@ AnyTokenizer = Union[PreTrainedTokenizer, PreTrainedTokenizerFast,
MistralTokenizer]
def decode_tokens(
tokenizer: AnyTokenizer,
token_ids: list[int],
*,
skip_special_tokens: bool = False,
) -> str:
"""
Backend-agnostic equivalent of HF's
:code:`tokenizer.decode(token_ids, skip_special_tokens=...)`.
"""
return tokenizer.decode(token_ids, skip_special_tokens=skip_special_tokens)
def encode_tokens(
tokenizer: AnyTokenizer,
text: str,
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment