Unverified commit eed11ebe, authored by Cyrus Leung, committed by GitHub
Browse files

[VLM] Merged multi-modal processors for LLaVA-NeXT-Video and LLaVA-OneVision (#11717)


Signed-off-by: DarkLight1337 <tlleungac@connect.ust.hk>
parent 300acb83
......@@ -323,7 +323,7 @@ class Phi3VMultiModalProcessor(BaseMultiModalProcessor):
height=image_height,
)
def get_mm_max_tokens_per_item(self) -> Mapping[str, int]:
def get_mm_max_tokens_per_item(self, seq_len: int) -> Mapping[str, int]:
max_image_tokens = self._get_num_image_tokens(
image_width=MAX_IMAGE_FEATURE_SIZE_WIDTH,
image_height=MAX_IMAGE_FEATURE_SIZE_HEIGHT,
......@@ -415,12 +415,12 @@ class Phi3VMultiModalProcessor(BaseMultiModalProcessor):
def _apply_prompt_replacements(
self,
token_ids: list[int],
prompt_repls: Sequence[_BoundPromptReplacement],
mm_prompt_repls: Mapping[str, Sequence[_BoundPromptReplacement]],
mm_item_counts: Mapping[str, int],
) -> tuple[list[int], str, list[_PlaceholderInfo]]:
) -> tuple[list[int], str, Mapping[str, list[_PlaceholderInfo]]]:
token_ids, text, placeholders = super()._apply_prompt_replacements(
token_ids=token_ids,
prompt_repls=prompt_repls,
mm_prompt_repls=mm_prompt_repls,
mm_item_counts=mm_item_counts,
)
......@@ -428,15 +428,23 @@ class Phi3VMultiModalProcessor(BaseMultiModalProcessor):
if text.startswith("<s> <|image|>"):
text = text.replace("<s> <|image|>", "<s><|image|>", 1)
token_ids = [token_ids[0], *token_ids[2:]]
placeholders = [
_PlaceholderInfo(p.modality, p.start_idx - 1, p.replacement)
for p in placeholders
]
placeholders = {
modality: [
_PlaceholderInfo(
modality=p.modality,
item_idx=p.item_idx,
start_idx=p.start_idx - 1,
replacement=p.replacement,
) for p in ps
]
for modality, ps in placeholders.items()
}
return token_ids, text, placeholders
def _get_dummy_mm_inputs(
def _get_dummy_processor_inputs(
self,
seq_len: int,
mm_counts: Mapping[str, int],
) -> ProcessorInputs:
num_images = mm_counts.get("image", 0)
......
......@@ -780,15 +780,18 @@ class PixtralHFEncoderInfo(VisionEncoderInfo[PixtralVisionConfig]):
def get_max_image_tokens(self) -> int:
    """Return the maximum image token count for this encoder.

    Delegates to ``get_max_pixtral_hf_image_tokens`` with the encoder's
    ``vision_config``.
    """
    vision_config = self.vision_config
    return get_max_pixtral_hf_image_tokens(vision_config)
def get_num_patches(self) -> int:
def get_image_size(self) -> int:
return self.vision_config.image_size
def get_patch_size(self) -> int:
return self.vision_config.patch_size
def get_patch_grid_length(self) -> int:
return get_pixtral_hf_patch_grid_length(
image_size=self.vision_config.image_size,
patch_size=self.vision_config.patch_size,
)
def get_image_size(self) -> int:
return self.vision_config.image_size
class PixtralHFMLP(nn.Module):
......
......@@ -84,7 +84,7 @@ class Qwen2AudioMultiModalProcessor(BaseMultiModalProcessor):
def get_supported_mm_limits(self) -> Mapping[str, Optional[int]]:
    """Return the per-modality item limits declared by this processor.

    Only ``"audio"`` is declared, mapped to ``None``
    (presumably meaning "no fixed cap" -- confirm against the base class).
    """
    return dict(audio=None)
def get_mm_max_tokens_per_item(self) -> Mapping[str, int]:
def get_mm_max_tokens_per_item(self, seq_len: int) -> Mapping[str, int]:
hf_config = self.ctx.get_hf_config(Qwen2AudioConfig)
max_source_positions = hf_config.audio_config.max_source_positions
max_output_lengths = (max_source_positions - 2) // 2 + 1
......@@ -184,15 +184,16 @@ class Qwen2AudioMultiModalProcessor(BaseMultiModalProcessor):
]
def _always_apply_prompt_replacements(self) -> bool:
# HF never applies prompt replacements, so we have to do it ourselves
# _find_placeholders may incorrectly think that HF has already performed
# processing for multi-audio input when the input audios are short
# (the corresponding placeholders may take up fewer tokens than
# the number of audio items)
# HF never applies prompt replacements, so we have to do it ourselves.
# NOTE: `_find_placeholders_by_modality` may incorrectly think that HF
# has already performed processing for multi-audio input when the input
# audios are short (the corresponding placeholders may take up fewer
# tokens than the number of audio items)
return True
def _get_dummy_mm_inputs(
def _get_dummy_processor_inputs(
self,
seq_len: int,
mm_counts: Mapping[str, int],
) -> ProcessorInputs:
feature_extractor = self._get_feature_extractor()
......
......@@ -56,7 +56,8 @@ from vllm.multimodal import MULTIMODAL_REGISTRY
from vllm.multimodal.inputs import (ImageItem, ModalityData,
MultiModalFieldConfig, MultiModalKwargs,
NestedTensors, VideoItem)
from vllm.multimodal.parse import ModalityDataItems, MultiModalDataParser
from vllm.multimodal.parse import (ImageSize, ModalityDataItems,
MultiModalDataParser)
from vllm.multimodal.processing import (BaseMultiModalProcessor,
MultiModalDataItems, ProcessorInputs,
PromptReplacement)
......@@ -641,58 +642,6 @@ class Qwen2VisionTransformer(nn.Module):
return loaded_params
# === Vision input helpers === #
def _get_vision_info(
vision_config: Qwen2VLVisionConfig,
height: int,
width: int,
min_pixels: int,
max_pixels: int,
*,
do_resize: bool = True,
modality: str = "image",
mm_count: int = 1,
):
"""Get information (resized height / width and number of vision tokens)
of input image / video frame."""
patch_size = vision_config.patch_size
merge_size = vision_config.spatial_merge_size
temporal_patch_size = vision_config.temporal_patch_size
if do_resize:
resized_height, resized_width = smart_resize(
height=height,
width=width,
factor=patch_size * merge_size,
min_pixels=min_pixels,
max_pixels=max_pixels,
)
else:
resized_height, resized_width = height, width
if modality == "image":
grid_t = mm_count
elif modality == "video":
grid_t = max(mm_count // temporal_patch_size, 1)
else:
raise ValueError(f"Modality {modality} is not supported")
grid_h = resized_height // patch_size
grid_w = resized_width // patch_size
vision_tokens = grid_t * grid_h * grid_w
llm_num_vision_tokens = vision_tokens // (merge_size**2)
return resized_height, resized_width, llm_num_vision_tokens
def _get_image_processor(hf_processor: Qwen2VLProcessor):
    """Return ``hf_processor.image_processor``.

    Asserts that the attribute is a ``Qwen2VLImageProcessor`` before
    returning it.
    """
    processor = hf_processor.image_processor  # type: ignore
    assert isinstance(processor, Qwen2VLImageProcessor)
    return processor
class Qwen2EmbeddingItems(ModalityDataItems[dict[str, torch.Tensor],
dict[str, torch.Tensor]]):
......@@ -764,32 +713,111 @@ class Qwen2VLMultiModalProcessor(BaseMultiModalProcessor):
def get_supported_mm_limits(self) -> Mapping[str, Optional[int]]:
    """Return the per-modality item limits declared by this processor.

    Both ``"image"`` and ``"video"`` are declared, each mapped to ``None``
    (presumably meaning "no fixed cap" -- confirm against the base class).
    """
    return dict(image=None, video=None)
def _get_max_mm_tokens(self, modality: str) -> int:
def _get_vision_info(
self,
*,
image_width: int,
image_height: int,
num_frames: int = 1,
do_resize: bool = True,
) -> tuple[ImageSize, int]:
hf_config = self.ctx.get_hf_config(Qwen2VLConfig)
vision_config = hf_config.vision_config
patch_size = vision_config.patch_size
merge_size = vision_config.spatial_merge_size
temporal_patch_size = vision_config.temporal_patch_size
hf_processor = self._get_hf_processor()
image_processor = _get_image_processor(hf_processor)
_, _, max_llm_image_tokens = _get_vision_info(
vision_config,
height=9999999,
width=9999999,
min_pixels=image_processor.min_pixels,
max_pixels=image_processor.max_pixels,
modality=modality,
image_processor = self._get_image_processor(hf_processor)
if do_resize:
resized_height, resized_width = smart_resize(
height=image_height,
width=image_width,
factor=patch_size * merge_size,
min_pixels=image_processor.min_pixels,
max_pixels=image_processor.max_pixels,
)
preprocessed_size = ImageSize(width=resized_width,
height=resized_height)
else:
preprocessed_size = ImageSize(width=image_width,
height=image_height)
grid_t = max(num_frames // temporal_patch_size, 1)
grid_h = preprocessed_size.height // patch_size
grid_w = preprocessed_size.width // patch_size
num_patches = grid_t * grid_h * grid_w
num_vision_tokens = num_patches // (merge_size**2)
return preprocessed_size, num_vision_tokens
def _get_dummy_image_size(self) -> ImageSize:
max_image_size, _ = self._get_vision_info(
image_width=9999999,
image_height=9999999,
)
return max_image_size
def _get_max_image_tokens(self) -> int:
_, max_image_tokens = self._get_vision_info(
image_width=9999999,
image_height=9999999,
)
return max_image_tokens
def _get_max_video_tokens(self, num_frames: int) -> int:
_, max_video_tokens = self._get_vision_info(
image_width=9999999,
image_height=9999999,
num_frames=num_frames,
)
return max_llm_image_tokens
return max_video_tokens
def _get_max_video_frames(self, max_tokens: int) -> int:
num_frames = 0
while True:
next_num_frames = num_frames + 1
if self._get_max_video_tokens(next_num_frames) > max_tokens:
break
num_frames = next_num_frames
return num_frames
def _get_dummy_num_frames(self, seq_len: int) -> int:
mm_config = self.ctx.get_mm_config()
max_images = mm_config.limit_per_prompt.get("image", 1)
max_videos = mm_config.limit_per_prompt.get("video", 1)
max_image_tokens = self._get_max_image_tokens() * max_images
max_total_frames = self._get_max_video_frames(seq_len -
max_image_tokens)
return max(max_total_frames // max(max_videos, 1), 1)
def get_mm_max_tokens_per_item(self, seq_len: int) -> Mapping[str, int]:
max_image_tokens = self._get_max_image_tokens()
num_frames = self._get_dummy_num_frames(seq_len)
max_video_tokens = self._get_max_video_tokens(num_frames)
def get_mm_max_tokens_per_item(self) -> Mapping[str, int]:
return {
"image": self._get_max_mm_tokens("image"),
"video": self._get_max_mm_tokens("video"),
"image": max_image_tokens,
"video": max_video_tokens,
}
def _get_data_parser(self) -> MultiModalDataParser:
    """Return a new ``Qwen2MultiModalDataParser`` instance."""
    parser = Qwen2MultiModalDataParser()
    return parser
def _get_image_processor(self, hf_processor: Qwen2VLProcessor):
    """Return ``hf_processor.image_processor``.

    Asserts that the attribute is a ``Qwen2VLImageProcessor`` before
    returning it.
    """
    processor = hf_processor.image_processor  # type: ignore
    assert isinstance(processor, Qwen2VLImageProcessor)
    return processor
def _get_hf_processor(
self,
*,
......@@ -797,7 +825,7 @@ class Qwen2VLMultiModalProcessor(BaseMultiModalProcessor):
max_pixels: Optional[int] = None,
) -> Qwen2VLProcessor:
hf_processor = self.ctx.get_hf_processor(Qwen2VLProcessor)
image_processor = _get_image_processor(hf_processor)
image_processor = self._get_image_processor(hf_processor)
if min_pixels:
image_processor.min_pixels = min_pixels
......@@ -818,7 +846,7 @@ class Qwen2VLMultiModalProcessor(BaseMultiModalProcessor):
out_mm_kwargs: MultiModalKwargs,
) -> list[PromptReplacement]:
hf_processor = self._get_hf_processor()
image_processor = _get_image_processor(hf_processor)
image_processor = self._get_image_processor(hf_processor)
# NOTE: Only Qwen2VLProcessor in transformers 4.47.0 has
# image_token and video_token registered
......@@ -873,32 +901,35 @@ class Qwen2VLMultiModalProcessor(BaseMultiModalProcessor):
video_grid_thw=MultiModalFieldConfig.batched("video"),
)
def _get_dummy_mm_inputs(
def _get_dummy_processor_inputs(
self,
seq_len: int,
mm_counts: Mapping[str, int],
) -> ProcessorInputs:
hf_processor = self._get_hf_processor()
image_processor = _get_image_processor(hf_processor)
num_images = mm_counts.get("image", 0)
num_videos = mm_counts.get("video", 0)
hf_processor = self._get_hf_processor()
image_token: str = hf_processor.image_token
resized_height, resized_width = smart_resize(
height=9999999,
width=9999999,
factor=image_processor.patch_size * image_processor.merge_size,
min_pixels=image_processor.min_pixels,
max_pixels=image_processor.max_pixels,
)
num_images = mm_counts.get("image", 0)
video_token: str = hf_processor.video_token
target_width, target_height = self._get_dummy_image_size()
mm_data = {
"image":
self._get_dummy_images(width=resized_width,
height=resized_height,
num_images=num_images)
self._get_dummy_images(width=target_width,
height=target_height,
num_images=num_images),
"video":
self._get_dummy_videos(
width=target_width,
height=target_height,
num_frames=self._get_dummy_num_frames(seq_len),
num_videos=num_videos,
)
}
return ProcessorInputs(
prompt_text=image_token * num_images,
prompt_text=image_token * num_images + video_token * num_videos,
mm_data=mm_data,
)
......
......@@ -171,15 +171,18 @@ class SiglipEncoderInfo(VisionEncoderInfo[SiglipVisionConfig]):
def get_max_image_tokens(self) -> int:
    """Return the maximum image token count for this encoder.

    Delegates to ``get_max_siglip_image_tokens`` with the encoder's
    ``vision_config``.
    """
    vision_config = self.vision_config
    return get_max_siglip_image_tokens(vision_config)
def get_num_patches(self) -> int:
def get_image_size(self) -> int:
return self.vision_config.image_size
def get_patch_size(self) -> int:
return self.vision_config.patch_size
def get_patch_grid_length(self) -> int:
return get_siglip_patch_grid_length(
image_size=self.vision_config.image_size,
patch_size=self.vision_config.patch_size,
)
def get_image_size(self) -> int:
return self.vision_config.image_size
# Adapted from https://github.com/huggingface/transformers/blob/v4.43.3/src/transformers/models/siglip/modeling_siglip.py#L249 # noqa
class SiglipVisionEmbeddings(nn.Module):
......
......@@ -6,7 +6,6 @@ from functools import cached_property
from typing import (Iterable, List, Literal, Mapping, Optional, Set, Tuple,
TypedDict, Union)
import numpy as np
import torch
import torch.utils.checkpoint
from torch import nn
......@@ -31,7 +30,6 @@ from vllm.multimodal.processing import (BaseMultiModalProcessor,
PromptReplacement)
from vllm.sequence import IntermediateTensors
from vllm.transformers_utils.configs.ultravox import UltravoxConfig
from vllm.utils import is_list_of
from .interfaces import SupportsMultiModal, SupportsPP
from .utils import (AutoWeightsLoader, WeightsMapper, flatten_bn,
......@@ -62,7 +60,7 @@ class UltravoxMultiModalProcessor(BaseMultiModalProcessor):
def get_supported_mm_limits(self) -> Mapping[str, Optional[int]]:
    """Return the per-modality item limits declared by this processor.

    Only ``"audio"`` is declared, mapped to ``None``
    (presumably meaning "no fixed cap" -- confirm against the base class).
    """
    return dict(audio=None)
def get_mm_max_tokens_per_item(self) -> Mapping[str, int]:
def get_mm_max_tokens_per_item(self, seq_len: int) -> Mapping[str, int]:
feature_extractor = self._get_feature_extractor()
max_audio_tokens = math.ceil(feature_extractor.chunk_length *
_AUDIO_TOKENS_PER_SECOND)
......@@ -103,6 +101,7 @@ class UltravoxMultiModalProcessor(BaseMultiModalProcessor):
mm_data = dict(mm_data)
audios = mm_data.pop("audios", [])
assert isinstance(audios, list)
if not audios:
return super()._call_hf_processor(
......@@ -117,9 +116,6 @@ class UltravoxMultiModalProcessor(BaseMultiModalProcessor):
sampling_rate=feature_extractor.sampling_rate,
)
# Already resampled by _get_hf_mm_data
assert is_list_of(audios, np.ndarray)
# Ultravox processor doesn't support multiple inputs,
# therefore we need to input text and audio one by one
audio_features, audio_token_len = [], []
......@@ -177,8 +173,9 @@ class UltravoxMultiModalProcessor(BaseMultiModalProcessor):
)
]
def _get_dummy_mm_inputs(
def _get_dummy_processor_inputs(
self,
seq_len: int,
mm_counts: Mapping[str, int],
) -> ProcessorInputs:
feature_extractor = self._get_feature_extractor()
......
from abc import ABC, abstractmethod
from typing import Generic, TypeVar
from typing import Final, Generic, Optional, Protocol, TypeVar
from transformers import PretrainedConfig
from vllm.multimodal.processing import (BaseMultiModalProcessor,
InputProcessingContext,
ProcessingCache)
_C = TypeVar("_C", bound=PretrainedConfig)
......@@ -27,11 +31,15 @@ class VisionEncoderInfo(ABC, Generic[_C]):
raise NotImplementedError
@abstractmethod
def get_num_patches(self) -> int:
def get_image_size(self) -> int:
raise NotImplementedError
@abstractmethod
def get_image_size(self) -> int:
def get_patch_size(self) -> int:
raise NotImplementedError
@abstractmethod
def get_patch_grid_length(self) -> int:
    """Abstract hook; concrete encoder-info classes must override it.

    Raises:
        NotImplementedError: Always, in this base declaration.
    """
    raise NotImplementedError()
......@@ -50,3 +58,26 @@ def vision_encoder_info(vision_config: PretrainedConfig) -> VisionEncoderInfo:
msg = f"Unsupported vision config: {type(vision_config)}"
raise NotImplementedError(msg)
class VisionLanguageConfig(Protocol):
    """Structural type for configs that expose a ``vision_config``."""

    # The nested config object; declared ``Final``, i.e. not reassigned.
    vision_config: Final[PretrainedConfig]
class BaseVisionLanguageMultiModalProcessor(BaseMultiModalProcessor):
    """Multi-modal processor base that resolves its vision encoder info.

    At construction time, after the base class is initialised, the
    ``vision_config`` of the config returned by ``_get_hf_config`` is
    passed to ``vision_encoder_info`` and the result is stored on
    ``self._vision_encoder_info``.
    """

    def __init__(
        self,
        ctx: InputProcessingContext,
        *,
        cache: Optional[ProcessingCache] = None,
        enable_sanity_checks: bool = True,
    ) -> None:
        """Initialise the base processor, then look up the encoder info."""
        super().__init__(
            ctx,
            cache=cache,
            enable_sanity_checks=enable_sanity_checks,
        )

        # `_get_hf_config` is invoked here, so subclasses must make it
        # usable as soon as the base initialiser has finished.
        hf_config = self._get_hf_config()
        self._vision_encoder_info = vision_encoder_info(
            hf_config.vision_config)

    @abstractmethod
    def _get_hf_config(self) -> VisionLanguageConfig:
        """Return the config object that carries ``vision_config``."""
        raise NotImplementedError
......@@ -146,6 +146,20 @@ class VideoProcessorItems(ProcessorBatchItems[HfVideoItem]):
def __init__(self, data: Sequence[HfVideoItem]) -> None:
    """Forward ``data`` to the base initialiser.

    The literal ``"video"`` is passed as the second positional argument
    (presumably the modality name -- confirm against the base class).
    """
    super().__init__(data, "video")
def get_num_frames(self, item_idx: int) -> int:
    """Return the length of the video item at ``item_idx``.

    Presumably each item is a sequence of frames, making this the frame
    count -- confirm against ``HfVideoItem``.
    """
    video = self.get(item_idx)
    return len(video)
def get_frame_size(self, item_idx: int) -> ImageSize:
    """Return the size of the first frame of the video at ``item_idx``.

    ``Image`` frames report their ``size`` attribute. Array / tensor
    frames have their shape unpacked into three values, with the last two
    taken as height and width.
    """
    first_frame = self.get(item_idx)[0]  # Assume that the video isn't empty

    if isinstance(first_frame, Image):
        return ImageSize(*first_frame.size)

    if isinstance(first_frame, (np.ndarray, torch.Tensor)):
        # NOTE(review): the unpacking requires exactly three axes and
        # treats the last two as (height, width) -- confirm frame layout.
        _, frame_h, frame_w = first_frame.shape
        return ImageSize(frame_w, frame_h)

    assert_never(first_frame)
class VideoEmbeddingItems(EmbeddingItems):
......
This diff is collapsed.
......@@ -223,7 +223,8 @@ class MultiModalRegistry:
if self.has_processor(model_config):
tokenizer = cached_get_tokenizer(model_config.tokenizer)
processor = self.create_processor(model_config, tokenizer)
return processor.get_mm_max_tokens_per_item()
seq_len = model_config.max_model_len
return processor.get_mm_max_tokens_per_item(seq_len)
return {
key: plugin.get_max_multimodal_tokens(model_config)
......
......@@ -21,6 +21,19 @@ AnyTokenizer = Union[PreTrainedTokenizer, PreTrainedTokenizerFast,
MistralTokenizer]
def decode_tokens(
    tokenizer: AnyTokenizer,
    token_ids: list[int],
    *,
    skip_special_tokens: bool = False,
) -> str:
    """
    Backend-agnostic equivalent of HF's
    :code:`tokenizer.decode(token_ids, skip_special_tokens=...)`.

    Args:
        tokenizer: Object whose ``decode`` method is called.
        token_ids: Token IDs passed to ``decode`` unchanged.
        skip_special_tokens: Forwarded to ``decode`` as a keyword argument.

    Returns:
        Whatever ``tokenizer.decode`` returns.
    """
    decoded = tokenizer.decode(
        token_ids,
        skip_special_tokens=skip_special_tokens,
    )
    return decoded
def encode_tokens(
tokenizer: AnyTokenizer,
text: str,
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment