Unverified commit eed11ebe, authored by Cyrus Leung, committed by GitHub
Browse files

[VLM] Merged multi-modal processors for LLaVA-NeXT-Video and LLaVA-OneVision (#11717)


Signed-off-by: DarkLight1337 <tlleungac@connect.ust.hk>
parent 300acb83
...@@ -323,7 +323,7 @@ class Phi3VMultiModalProcessor(BaseMultiModalProcessor): ...@@ -323,7 +323,7 @@ class Phi3VMultiModalProcessor(BaseMultiModalProcessor):
height=image_height, height=image_height,
) )
def get_mm_max_tokens_per_item(self) -> Mapping[str, int]: def get_mm_max_tokens_per_item(self, seq_len: int) -> Mapping[str, int]:
max_image_tokens = self._get_num_image_tokens( max_image_tokens = self._get_num_image_tokens(
image_width=MAX_IMAGE_FEATURE_SIZE_WIDTH, image_width=MAX_IMAGE_FEATURE_SIZE_WIDTH,
image_height=MAX_IMAGE_FEATURE_SIZE_HEIGHT, image_height=MAX_IMAGE_FEATURE_SIZE_HEIGHT,
...@@ -415,12 +415,12 @@ class Phi3VMultiModalProcessor(BaseMultiModalProcessor): ...@@ -415,12 +415,12 @@ class Phi3VMultiModalProcessor(BaseMultiModalProcessor):
def _apply_prompt_replacements( def _apply_prompt_replacements(
self, self,
token_ids: list[int], token_ids: list[int],
prompt_repls: Sequence[_BoundPromptReplacement], mm_prompt_repls: Mapping[str, Sequence[_BoundPromptReplacement]],
mm_item_counts: Mapping[str, int], mm_item_counts: Mapping[str, int],
) -> tuple[list[int], str, list[_PlaceholderInfo]]: ) -> tuple[list[int], str, Mapping[str, list[_PlaceholderInfo]]]:
token_ids, text, placeholders = super()._apply_prompt_replacements( token_ids, text, placeholders = super()._apply_prompt_replacements(
token_ids=token_ids, token_ids=token_ids,
prompt_repls=prompt_repls, mm_prompt_repls=mm_prompt_repls,
mm_item_counts=mm_item_counts, mm_item_counts=mm_item_counts,
) )
...@@ -428,15 +428,23 @@ class Phi3VMultiModalProcessor(BaseMultiModalProcessor): ...@@ -428,15 +428,23 @@ class Phi3VMultiModalProcessor(BaseMultiModalProcessor):
if text.startswith("<s> <|image|>"): if text.startswith("<s> <|image|>"):
text = text.replace("<s> <|image|>", "<s><|image|>", 1) text = text.replace("<s> <|image|>", "<s><|image|>", 1)
token_ids = [token_ids[0], *token_ids[2:]] token_ids = [token_ids[0], *token_ids[2:]]
placeholders = [ placeholders = {
_PlaceholderInfo(p.modality, p.start_idx - 1, p.replacement) modality: [
for p in placeholders _PlaceholderInfo(
] modality=p.modality,
item_idx=p.item_idx,
start_idx=p.start_idx - 1,
replacement=p.replacement,
) for p in ps
]
for modality, ps in placeholders.items()
}
return token_ids, text, placeholders return token_ids, text, placeholders
def _get_dummy_mm_inputs( def _get_dummy_processor_inputs(
self, self,
seq_len: int,
mm_counts: Mapping[str, int], mm_counts: Mapping[str, int],
) -> ProcessorInputs: ) -> ProcessorInputs:
num_images = mm_counts.get("image", 0) num_images = mm_counts.get("image", 0)
......
...@@ -780,15 +780,18 @@ class PixtralHFEncoderInfo(VisionEncoderInfo[PixtralVisionConfig]): ...@@ -780,15 +780,18 @@ class PixtralHFEncoderInfo(VisionEncoderInfo[PixtralVisionConfig]):
def get_max_image_tokens(self) -> int: def get_max_image_tokens(self) -> int:
return get_max_pixtral_hf_image_tokens(self.vision_config) return get_max_pixtral_hf_image_tokens(self.vision_config)
def get_num_patches(self) -> int: def get_image_size(self) -> int:
return self.vision_config.image_size
def get_patch_size(self) -> int:
return self.vision_config.patch_size
def get_patch_grid_length(self) -> int:
return get_pixtral_hf_patch_grid_length( return get_pixtral_hf_patch_grid_length(
image_size=self.vision_config.image_size, image_size=self.vision_config.image_size,
patch_size=self.vision_config.patch_size, patch_size=self.vision_config.patch_size,
) )
def get_image_size(self) -> int:
return self.vision_config.image_size
class PixtralHFMLP(nn.Module): class PixtralHFMLP(nn.Module):
......
...@@ -84,7 +84,7 @@ class Qwen2AudioMultiModalProcessor(BaseMultiModalProcessor): ...@@ -84,7 +84,7 @@ class Qwen2AudioMultiModalProcessor(BaseMultiModalProcessor):
def get_supported_mm_limits(self) -> Mapping[str, Optional[int]]: def get_supported_mm_limits(self) -> Mapping[str, Optional[int]]:
return {"audio": None} return {"audio": None}
def get_mm_max_tokens_per_item(self) -> Mapping[str, int]: def get_mm_max_tokens_per_item(self, seq_len: int) -> Mapping[str, int]:
hf_config = self.ctx.get_hf_config(Qwen2AudioConfig) hf_config = self.ctx.get_hf_config(Qwen2AudioConfig)
max_source_positions = hf_config.audio_config.max_source_positions max_source_positions = hf_config.audio_config.max_source_positions
max_output_lengths = (max_source_positions - 2) // 2 + 1 max_output_lengths = (max_source_positions - 2) // 2 + 1
...@@ -184,15 +184,16 @@ class Qwen2AudioMultiModalProcessor(BaseMultiModalProcessor): ...@@ -184,15 +184,16 @@ class Qwen2AudioMultiModalProcessor(BaseMultiModalProcessor):
] ]
def _always_apply_prompt_replacements(self) -> bool: def _always_apply_prompt_replacements(self) -> bool:
# HF never applies prompt replacements, so we have to do it ourselves # HF never applies prompt replacements, so we have to do it ourselves.
# _find_placeholders may incorrectly think that HF has already performed # NOTE: `_find_placeholders_by_modality` may incorrectly think that HF
# processing for multi-audio input when the input audios are short # has already performed processing for multi-audio input when the input
# (the corresponding placeholders may take up fewer tokens than # audios are short (the corresponding placeholders may take up fewer
# the number of audio items) # tokens than the number of audio items)
return True return True
def _get_dummy_mm_inputs( def _get_dummy_processor_inputs(
self, self,
seq_len: int,
mm_counts: Mapping[str, int], mm_counts: Mapping[str, int],
) -> ProcessorInputs: ) -> ProcessorInputs:
feature_extractor = self._get_feature_extractor() feature_extractor = self._get_feature_extractor()
......
...@@ -56,7 +56,8 @@ from vllm.multimodal import MULTIMODAL_REGISTRY ...@@ -56,7 +56,8 @@ from vllm.multimodal import MULTIMODAL_REGISTRY
from vllm.multimodal.inputs import (ImageItem, ModalityData, from vllm.multimodal.inputs import (ImageItem, ModalityData,
MultiModalFieldConfig, MultiModalKwargs, MultiModalFieldConfig, MultiModalKwargs,
NestedTensors, VideoItem) NestedTensors, VideoItem)
from vllm.multimodal.parse import ModalityDataItems, MultiModalDataParser from vllm.multimodal.parse import (ImageSize, ModalityDataItems,
MultiModalDataParser)
from vllm.multimodal.processing import (BaseMultiModalProcessor, from vllm.multimodal.processing import (BaseMultiModalProcessor,
MultiModalDataItems, ProcessorInputs, MultiModalDataItems, ProcessorInputs,
PromptReplacement) PromptReplacement)
...@@ -641,58 +642,6 @@ class Qwen2VisionTransformer(nn.Module): ...@@ -641,58 +642,6 @@ class Qwen2VisionTransformer(nn.Module):
return loaded_params return loaded_params
# === Vision input helpers === #
def _get_vision_info(
vision_config: Qwen2VLVisionConfig,
height: int,
width: int,
min_pixels: int,
max_pixels: int,
*,
do_resize: bool = True,
modality: str = "image",
mm_count: int = 1,
):
"""Get information (resized height / width and number of vision tokens)
of input image / video frame."""
patch_size = vision_config.patch_size
merge_size = vision_config.spatial_merge_size
temporal_patch_size = vision_config.temporal_patch_size
if do_resize:
resized_height, resized_width = smart_resize(
height=height,
width=width,
factor=patch_size * merge_size,
min_pixels=min_pixels,
max_pixels=max_pixels,
)
else:
resized_height, resized_width = height, width
if modality == "image":
grid_t = mm_count
elif modality == "video":
grid_t = max(mm_count // temporal_patch_size, 1)
else:
raise ValueError(f"Modality {modality} is not supported")
grid_h = resized_height // patch_size
grid_w = resized_width // patch_size
vision_tokens = grid_t * grid_h * grid_w
llm_num_vision_tokens = vision_tokens // (merge_size**2)
return resized_height, resized_width, llm_num_vision_tokens
def _get_image_processor(hf_processor: Qwen2VLProcessor):
image_processor = hf_processor.image_processor # type: ignore
assert isinstance(image_processor, Qwen2VLImageProcessor)
return image_processor
class Qwen2EmbeddingItems(ModalityDataItems[dict[str, torch.Tensor], class Qwen2EmbeddingItems(ModalityDataItems[dict[str, torch.Tensor],
dict[str, torch.Tensor]]): dict[str, torch.Tensor]]):
...@@ -764,32 +713,111 @@ class Qwen2VLMultiModalProcessor(BaseMultiModalProcessor): ...@@ -764,32 +713,111 @@ class Qwen2VLMultiModalProcessor(BaseMultiModalProcessor):
def get_supported_mm_limits(self) -> Mapping[str, Optional[int]]: def get_supported_mm_limits(self) -> Mapping[str, Optional[int]]:
return {"image": None, "video": None} return {"image": None, "video": None}
def _get_max_mm_tokens(self, modality: str) -> int: def _get_vision_info(
self,
*,
image_width: int,
image_height: int,
num_frames: int = 1,
do_resize: bool = True,
) -> tuple[ImageSize, int]:
hf_config = self.ctx.get_hf_config(Qwen2VLConfig) hf_config = self.ctx.get_hf_config(Qwen2VLConfig)
vision_config = hf_config.vision_config vision_config = hf_config.vision_config
patch_size = vision_config.patch_size
merge_size = vision_config.spatial_merge_size
temporal_patch_size = vision_config.temporal_patch_size
hf_processor = self._get_hf_processor() hf_processor = self._get_hf_processor()
image_processor = _get_image_processor(hf_processor) image_processor = self._get_image_processor(hf_processor)
_, _, max_llm_image_tokens = _get_vision_info( if do_resize:
vision_config, resized_height, resized_width = smart_resize(
height=9999999, height=image_height,
width=9999999, width=image_width,
min_pixels=image_processor.min_pixels, factor=patch_size * merge_size,
max_pixels=image_processor.max_pixels, min_pixels=image_processor.min_pixels,
modality=modality, max_pixels=image_processor.max_pixels,
)
preprocessed_size = ImageSize(width=resized_width,
height=resized_height)
else:
preprocessed_size = ImageSize(width=image_width,
height=image_height)
grid_t = max(num_frames // temporal_patch_size, 1)
grid_h = preprocessed_size.height // patch_size
grid_w = preprocessed_size.width // patch_size
num_patches = grid_t * grid_h * grid_w
num_vision_tokens = num_patches // (merge_size**2)
return preprocessed_size, num_vision_tokens
def _get_dummy_image_size(self) -> ImageSize:
max_image_size, _ = self._get_vision_info(
image_width=9999999,
image_height=9999999,
)
return max_image_size
def _get_max_image_tokens(self) -> int:
_, max_image_tokens = self._get_vision_info(
image_width=9999999,
image_height=9999999,
)
return max_image_tokens
def _get_max_video_tokens(self, num_frames: int) -> int:
_, max_video_tokens = self._get_vision_info(
image_width=9999999,
image_height=9999999,
num_frames=num_frames,
) )
return max_llm_image_tokens return max_video_tokens
def _get_max_video_frames(self, max_tokens: int) -> int:
num_frames = 0
while True:
next_num_frames = num_frames + 1
if self._get_max_video_tokens(next_num_frames) > max_tokens:
break
num_frames = next_num_frames
return num_frames
def _get_dummy_num_frames(self, seq_len: int) -> int:
mm_config = self.ctx.get_mm_config()
max_images = mm_config.limit_per_prompt.get("image", 1)
max_videos = mm_config.limit_per_prompt.get("video", 1)
max_image_tokens = self._get_max_image_tokens() * max_images
max_total_frames = self._get_max_video_frames(seq_len -
max_image_tokens)
return max(max_total_frames // max(max_videos, 1), 1)
def get_mm_max_tokens_per_item(self, seq_len: int) -> Mapping[str, int]:
max_image_tokens = self._get_max_image_tokens()
num_frames = self._get_dummy_num_frames(seq_len)
max_video_tokens = self._get_max_video_tokens(num_frames)
def get_mm_max_tokens_per_item(self) -> Mapping[str, int]:
return { return {
"image": self._get_max_mm_tokens("image"), "image": max_image_tokens,
"video": self._get_max_mm_tokens("video"), "video": max_video_tokens,
} }
def _get_data_parser(self) -> MultiModalDataParser: def _get_data_parser(self) -> MultiModalDataParser:
return Qwen2MultiModalDataParser() return Qwen2MultiModalDataParser()
def _get_image_processor(self, hf_processor: Qwen2VLProcessor):
image_processor = hf_processor.image_processor # type: ignore
assert isinstance(image_processor, Qwen2VLImageProcessor)
return image_processor
def _get_hf_processor( def _get_hf_processor(
self, self,
*, *,
...@@ -797,7 +825,7 @@ class Qwen2VLMultiModalProcessor(BaseMultiModalProcessor): ...@@ -797,7 +825,7 @@ class Qwen2VLMultiModalProcessor(BaseMultiModalProcessor):
max_pixels: Optional[int] = None, max_pixels: Optional[int] = None,
) -> Qwen2VLProcessor: ) -> Qwen2VLProcessor:
hf_processor = self.ctx.get_hf_processor(Qwen2VLProcessor) hf_processor = self.ctx.get_hf_processor(Qwen2VLProcessor)
image_processor = _get_image_processor(hf_processor) image_processor = self._get_image_processor(hf_processor)
if min_pixels: if min_pixels:
image_processor.min_pixels = min_pixels image_processor.min_pixels = min_pixels
...@@ -818,7 +846,7 @@ class Qwen2VLMultiModalProcessor(BaseMultiModalProcessor): ...@@ -818,7 +846,7 @@ class Qwen2VLMultiModalProcessor(BaseMultiModalProcessor):
out_mm_kwargs: MultiModalKwargs, out_mm_kwargs: MultiModalKwargs,
) -> list[PromptReplacement]: ) -> list[PromptReplacement]:
hf_processor = self._get_hf_processor() hf_processor = self._get_hf_processor()
image_processor = _get_image_processor(hf_processor) image_processor = self._get_image_processor(hf_processor)
# NOTE: Only Qwen2VLProcessor in transformers 4.47.0 has # NOTE: Only Qwen2VLProcessor in transformers 4.47.0 has
# image_token and video_token registered # image_token and video_token registered
...@@ -873,32 +901,35 @@ class Qwen2VLMultiModalProcessor(BaseMultiModalProcessor): ...@@ -873,32 +901,35 @@ class Qwen2VLMultiModalProcessor(BaseMultiModalProcessor):
video_grid_thw=MultiModalFieldConfig.batched("video"), video_grid_thw=MultiModalFieldConfig.batched("video"),
) )
def _get_dummy_mm_inputs( def _get_dummy_processor_inputs(
self, self,
seq_len: int,
mm_counts: Mapping[str, int], mm_counts: Mapping[str, int],
) -> ProcessorInputs: ) -> ProcessorInputs:
hf_processor = self._get_hf_processor() num_images = mm_counts.get("image", 0)
image_processor = _get_image_processor(hf_processor) num_videos = mm_counts.get("video", 0)
hf_processor = self._get_hf_processor()
image_token: str = hf_processor.image_token image_token: str = hf_processor.image_token
resized_height, resized_width = smart_resize( video_token: str = hf_processor.video_token
height=9999999, target_width, target_height = self._get_dummy_image_size()
width=9999999,
factor=image_processor.patch_size * image_processor.merge_size,
min_pixels=image_processor.min_pixels,
max_pixels=image_processor.max_pixels,
)
num_images = mm_counts.get("image", 0)
mm_data = { mm_data = {
"image": "image":
self._get_dummy_images(width=resized_width, self._get_dummy_images(width=target_width,
height=resized_height, height=target_height,
num_images=num_images) num_images=num_images),
"video":
self._get_dummy_videos(
width=target_width,
height=target_height,
num_frames=self._get_dummy_num_frames(seq_len),
num_videos=num_videos,
)
} }
return ProcessorInputs( return ProcessorInputs(
prompt_text=image_token * num_images, prompt_text=image_token * num_images + video_token * num_videos,
mm_data=mm_data, mm_data=mm_data,
) )
......
...@@ -171,15 +171,18 @@ class SiglipEncoderInfo(VisionEncoderInfo[SiglipVisionConfig]): ...@@ -171,15 +171,18 @@ class SiglipEncoderInfo(VisionEncoderInfo[SiglipVisionConfig]):
def get_max_image_tokens(self) -> int: def get_max_image_tokens(self) -> int:
return get_max_siglip_image_tokens(self.vision_config) return get_max_siglip_image_tokens(self.vision_config)
def get_num_patches(self) -> int: def get_image_size(self) -> int:
return self.vision_config.image_size
def get_patch_size(self) -> int:
return self.vision_config.patch_size
def get_patch_grid_length(self) -> int:
return get_siglip_patch_grid_length( return get_siglip_patch_grid_length(
image_size=self.vision_config.image_size, image_size=self.vision_config.image_size,
patch_size=self.vision_config.patch_size, patch_size=self.vision_config.patch_size,
) )
def get_image_size(self) -> int:
return self.vision_config.image_size
# Adapted from https://github.com/huggingface/transformers/blob/v4.43.3/src/transformers/models/siglip/modeling_siglip.py#L249 # noqa # Adapted from https://github.com/huggingface/transformers/blob/v4.43.3/src/transformers/models/siglip/modeling_siglip.py#L249 # noqa
class SiglipVisionEmbeddings(nn.Module): class SiglipVisionEmbeddings(nn.Module):
......
...@@ -6,7 +6,6 @@ from functools import cached_property ...@@ -6,7 +6,6 @@ from functools import cached_property
from typing import (Iterable, List, Literal, Mapping, Optional, Set, Tuple, from typing import (Iterable, List, Literal, Mapping, Optional, Set, Tuple,
TypedDict, Union) TypedDict, Union)
import numpy as np
import torch import torch
import torch.utils.checkpoint import torch.utils.checkpoint
from torch import nn from torch import nn
...@@ -31,7 +30,6 @@ from vllm.multimodal.processing import (BaseMultiModalProcessor, ...@@ -31,7 +30,6 @@ from vllm.multimodal.processing import (BaseMultiModalProcessor,
PromptReplacement) PromptReplacement)
from vllm.sequence import IntermediateTensors from vllm.sequence import IntermediateTensors
from vllm.transformers_utils.configs.ultravox import UltravoxConfig from vllm.transformers_utils.configs.ultravox import UltravoxConfig
from vllm.utils import is_list_of
from .interfaces import SupportsMultiModal, SupportsPP from .interfaces import SupportsMultiModal, SupportsPP
from .utils import (AutoWeightsLoader, WeightsMapper, flatten_bn, from .utils import (AutoWeightsLoader, WeightsMapper, flatten_bn,
...@@ -62,7 +60,7 @@ class UltravoxMultiModalProcessor(BaseMultiModalProcessor): ...@@ -62,7 +60,7 @@ class UltravoxMultiModalProcessor(BaseMultiModalProcessor):
def get_supported_mm_limits(self) -> Mapping[str, Optional[int]]: def get_supported_mm_limits(self) -> Mapping[str, Optional[int]]:
return {"audio": None} return {"audio": None}
def get_mm_max_tokens_per_item(self) -> Mapping[str, int]: def get_mm_max_tokens_per_item(self, seq_len: int) -> Mapping[str, int]:
feature_extractor = self._get_feature_extractor() feature_extractor = self._get_feature_extractor()
max_audio_tokens = math.ceil(feature_extractor.chunk_length * max_audio_tokens = math.ceil(feature_extractor.chunk_length *
_AUDIO_TOKENS_PER_SECOND) _AUDIO_TOKENS_PER_SECOND)
...@@ -103,6 +101,7 @@ class UltravoxMultiModalProcessor(BaseMultiModalProcessor): ...@@ -103,6 +101,7 @@ class UltravoxMultiModalProcessor(BaseMultiModalProcessor):
mm_data = dict(mm_data) mm_data = dict(mm_data)
audios = mm_data.pop("audios", []) audios = mm_data.pop("audios", [])
assert isinstance(audios, list)
if not audios: if not audios:
return super()._call_hf_processor( return super()._call_hf_processor(
...@@ -117,9 +116,6 @@ class UltravoxMultiModalProcessor(BaseMultiModalProcessor): ...@@ -117,9 +116,6 @@ class UltravoxMultiModalProcessor(BaseMultiModalProcessor):
sampling_rate=feature_extractor.sampling_rate, sampling_rate=feature_extractor.sampling_rate,
) )
# Already resampled by _get_hf_mm_data
assert is_list_of(audios, np.ndarray)
# Ultravox processor doesn't support multiple inputs, # Ultravox processor doesn't support multiple inputs,
# therefore we need to input text and audio one by one # therefore we need to input text and audio one by one
audio_features, audio_token_len = [], [] audio_features, audio_token_len = [], []
...@@ -177,8 +173,9 @@ class UltravoxMultiModalProcessor(BaseMultiModalProcessor): ...@@ -177,8 +173,9 @@ class UltravoxMultiModalProcessor(BaseMultiModalProcessor):
) )
] ]
def _get_dummy_mm_inputs( def _get_dummy_processor_inputs(
self, self,
seq_len: int,
mm_counts: Mapping[str, int], mm_counts: Mapping[str, int],
) -> ProcessorInputs: ) -> ProcessorInputs:
feature_extractor = self._get_feature_extractor() feature_extractor = self._get_feature_extractor()
......
from abc import ABC, abstractmethod from abc import ABC, abstractmethod
from typing import Generic, TypeVar from typing import Final, Generic, Optional, Protocol, TypeVar
from transformers import PretrainedConfig from transformers import PretrainedConfig
from vllm.multimodal.processing import (BaseMultiModalProcessor,
InputProcessingContext,
ProcessingCache)
_C = TypeVar("_C", bound=PretrainedConfig) _C = TypeVar("_C", bound=PretrainedConfig)
...@@ -27,11 +31,15 @@ class VisionEncoderInfo(ABC, Generic[_C]): ...@@ -27,11 +31,15 @@ class VisionEncoderInfo(ABC, Generic[_C]):
raise NotImplementedError raise NotImplementedError
@abstractmethod @abstractmethod
def get_num_patches(self) -> int: def get_image_size(self) -> int:
raise NotImplementedError raise NotImplementedError
@abstractmethod @abstractmethod
def get_image_size(self) -> int: def get_patch_size(self) -> int:
raise NotImplementedError
@abstractmethod
def get_patch_grid_length(self) -> int:
raise NotImplementedError raise NotImplementedError
...@@ -50,3 +58,26 @@ def vision_encoder_info(vision_config: PretrainedConfig) -> VisionEncoderInfo: ...@@ -50,3 +58,26 @@ def vision_encoder_info(vision_config: PretrainedConfig) -> VisionEncoderInfo:
msg = f"Unsupported vision config: {type(vision_config)}" msg = f"Unsupported vision config: {type(vision_config)}"
raise NotImplementedError(msg) raise NotImplementedError(msg)
class VisionLanguageConfig(Protocol):
vision_config: Final[PretrainedConfig]
class BaseVisionLanguageMultiModalProcessor(BaseMultiModalProcessor):
def __init__(self,
ctx: InputProcessingContext,
*,
cache: Optional[ProcessingCache] = None,
enable_sanity_checks: bool = True) -> None:
super().__init__(ctx,
cache=cache,
enable_sanity_checks=enable_sanity_checks)
vision_config = self._get_hf_config().vision_config
self._vision_encoder_info = vision_encoder_info(vision_config)
@abstractmethod
def _get_hf_config(self) -> VisionLanguageConfig:
raise NotImplementedError
...@@ -146,6 +146,20 @@ class VideoProcessorItems(ProcessorBatchItems[HfVideoItem]): ...@@ -146,6 +146,20 @@ class VideoProcessorItems(ProcessorBatchItems[HfVideoItem]):
def __init__(self, data: Sequence[HfVideoItem]) -> None: def __init__(self, data: Sequence[HfVideoItem]) -> None:
super().__init__(data, "video") super().__init__(data, "video")
def get_num_frames(self, item_idx: int) -> int:
return len(self.get(item_idx))
def get_frame_size(self, item_idx: int) -> ImageSize:
image = self.get(item_idx)[0] # Assume that the video isn't empty
if isinstance(image, Image):
return ImageSize(*image.size)
if isinstance(image, (np.ndarray, torch.Tensor)):
_, h, w = image.shape
return ImageSize(w, h)
assert_never(image)
class VideoEmbeddingItems(EmbeddingItems): class VideoEmbeddingItems(EmbeddingItems):
......
...@@ -16,7 +16,8 @@ from transformers import BatchFeature, ProcessorMixin ...@@ -16,7 +16,8 @@ from transformers import BatchFeature, ProcessorMixin
from vllm.inputs import DummyData, InputProcessingContext from vllm.inputs import DummyData, InputProcessingContext
from vllm.logger import init_logger from vllm.logger import init_logger
from vllm.transformers_utils.tokenizer import AnyTokenizer, encode_tokens from vllm.transformers_utils.tokenizer import (AnyTokenizer, decode_tokens,
encode_tokens)
from vllm.utils import LRUCache, flatten_2d_lists, full_groupby from vllm.utils import LRUCache, flatten_2d_lists, full_groupby
from .inputs import (MultiModalDataDict, MultiModalFieldConfig, from .inputs import (MultiModalDataDict, MultiModalFieldConfig,
...@@ -69,19 +70,6 @@ def _cached_encode( ...@@ -69,19 +70,6 @@ def _cached_encode(
add_special_tokens=add_special_tokens) add_special_tokens=add_special_tokens)
def _decode(
tokenizer: AnyTokenizer,
token_ids: list[int],
*,
skip_special_tokens: bool = False,
) -> str:
"""
Backend-agnostic equivalent of HF's
:code:`tokenizer.decode(token_ids, skip_special_tokens=...)`.
"""
return tokenizer.decode(token_ids, skip_special_tokens=skip_special_tokens)
@lru_cache(maxsize=2048) @lru_cache(maxsize=2048)
def _cached_decode( def _cached_decode(
tokenizer: AnyTokenizer, tokenizer: AnyTokenizer,
...@@ -89,9 +77,9 @@ def _cached_decode( ...@@ -89,9 +77,9 @@ def _cached_decode(
*, *,
skip_special_tokens: bool = False, skip_special_tokens: bool = False,
) -> str: ) -> str:
return _decode(tokenizer, return decode_tokens(tokenizer,
list(token_ids), list(token_ids),
skip_special_tokens=skip_special_tokens) skip_special_tokens=skip_special_tokens)
class _HasModalityAttr(Protocol): class _HasModalityAttr(Protocol):
...@@ -269,8 +257,10 @@ class _PromptReplacementTextMatch(_PromptReplacementMatch): ...@@ -269,8 +257,10 @@ class _PromptReplacementTextMatch(_PromptReplacementMatch):
return self.match.end() return self.match.end()
class _PlaceholderInfo(NamedTuple): @dataclass
class _PlaceholderInfo:
modality: str modality: str
item_idx: int
start_idx: int start_idx: int
replacement: list[int] replacement: list[int]
...@@ -311,12 +301,14 @@ def find_text_matches( ...@@ -311,12 +301,14 @@ def find_text_matches(
def _resolve_matches( def _resolve_matches(
prompt: _PromptSeq, prompt: _PromptSeq,
matches: Sequence[_PromptReplacementMatch], mm_matches: Mapping[str, Sequence[_PromptReplacementMatch]],
) -> list[_PromptReplacementMatch]: ) -> list[_PromptReplacementMatch]:
""" """
Resolve :code:`matches` to ensure that there are no overlapping matches, Resolve :code:`mm_matches` to ensure that there are no overlapping matches,
and sort them such that earlier matches take priority over later ones. and sort them such that earlier matches take priority over later ones.
""" """
matches = [m for matches in mm_matches.values() for m in matches]
seen_matches: list[Optional[_PromptReplacementMatch]] = [None seen_matches: list[Optional[_PromptReplacementMatch]] = [None
] * len(prompt) ] * len(prompt)
...@@ -334,14 +326,15 @@ def _resolve_matches( ...@@ -334,14 +326,15 @@ def _resolve_matches(
def _replace_matches( def _replace_matches(
prompt: _S, prompt: _S,
matches: Sequence[_PromptReplacementMatch], mm_matches: Mapping[str, Sequence[_PromptReplacementMatch]],
mm_item_counts: Mapping[str, int], mm_item_counts: Mapping[str, int],
) -> list[_S]: ) -> list[_S]:
"""Apply the replacements in :code:`mm_matches` to :code:`prompt`."""
out_seqs = list[_S]() out_seqs = list[_S]()
prev_end_idx = 0 prev_end_idx = 0
next_idx_by_modality = defaultdict[str, int](lambda: 0) next_idx_by_modality = defaultdict[str, int](lambda: 0)
for match in _resolve_matches(prompt, matches): for match in _resolve_matches(prompt, mm_matches):
modality = match.modality modality = match.modality
item_idx = next_idx_by_modality[modality] item_idx = next_idx_by_modality[modality]
...@@ -371,28 +364,28 @@ def _replace_matches( ...@@ -371,28 +364,28 @@ def _replace_matches(
def replace_token_matches( def replace_token_matches(
prompt: list[int], prompt: list[int],
matches: Sequence[_PromptReplacementTokenMatch], mm_matches: Mapping[str, Sequence[_PromptReplacementTokenMatch]],
mm_item_counts: Mapping[str, int], mm_item_counts: Mapping[str, int],
) -> list[int]: ) -> list[int]:
"""Apply :code:`prompt_repls` to :code:`prompt`.""" """Apply the replacements in :code:`mm_matches` to :code:`prompt`."""
if not matches: if not mm_matches:
return prompt return prompt
token_id_seqs = _replace_matches(prompt, matches, mm_item_counts) token_id_seqs = _replace_matches(prompt, mm_matches, mm_item_counts)
return flatten_2d_lists(token_id_seqs) return flatten_2d_lists(token_id_seqs)
def replace_text_matches( def replace_text_matches(
prompt: str, prompt: str,
matches: Sequence[_PromptReplacementTextMatch], mm_matches: Mapping[str, Sequence[_PromptReplacementTextMatch]],
mm_item_counts: Mapping[str, int], mm_item_counts: Mapping[str, int],
) -> str: ) -> str:
"""Apply :code:`prompt_repls` to :code:`prompt`.""" """Apply the replacements in :code:`mm_matches` to :code:`prompt`."""
if not matches: if not mm_matches:
return prompt return prompt
texts = _replace_matches(prompt, matches, mm_item_counts) texts = _replace_matches(prompt, mm_matches, mm_item_counts)
return "".join(texts) return "".join(texts)
...@@ -407,14 +400,14 @@ def _iter_modality_placeholders( ...@@ -407,14 +400,14 @@ def _iter_modality_placeholders(
return return
prompt_len = len(prompt) prompt_len = len(prompt)
item_index = 0 item_idx = 0
start_idx = 0 start_idx = 0
while start_idx < prompt_len: while start_idx < prompt_len:
found = False found = False
for repl_info in modality_repls: for repl_info in modality_repls:
replacement = repl_info.get_replacement(item_index) replacement = repl_info.get_replacement(item_idx)
repl_tokens = replacement.token_ids repl_tokens = replacement.token_ids
repl_len = len(repl_tokens) repl_len = len(repl_tokens)
end_idx = start_idx + repl_len end_idx = start_idx + repl_len
...@@ -425,12 +418,13 @@ def _iter_modality_placeholders( ...@@ -425,12 +418,13 @@ def _iter_modality_placeholders(
if prompt[start_idx:end_idx] == repl_tokens: if prompt[start_idx:end_idx] == repl_tokens:
yield _PlaceholderInfo( yield _PlaceholderInfo(
modality=modality, modality=modality,
item_idx=item_idx,
start_idx=start_idx, start_idx=start_idx,
replacement=repl_tokens, replacement=repl_tokens,
) )
item_index += 1 item_idx += 1
if item_index >= modal_item_count: if item_idx >= modal_item_count:
return return
# Exclude overlapping matches # Exclude overlapping matches
...@@ -442,28 +436,36 @@ def _iter_modality_placeholders( ...@@ -442,28 +436,36 @@ def _iter_modality_placeholders(
start_idx += 1 start_idx += 1
def iter_placeholders( def _iter_placeholders(
prompt_repls: Sequence[_BoundPromptReplacement], mm_prompt_repls: Mapping[str, Sequence[_BoundPromptReplacement]],
prompt: list[int], prompt: list[int],
mm_item_counts: Mapping[str, int], mm_item_counts: Mapping[str, int],
) -> Iterable[_PlaceholderInfo]: ) -> Iterable[_PlaceholderInfo]:
""" """
Yield each set of placeholder tokens found in :code:`prompt`. For each modality, yield each set of placeholder tokens found in
:code:`prompt`.
Note that empty matches are ignored. Note that empty matches are ignored.
""" """
repls_by_modality = dict(full_groupby_modality(prompt_repls))
for modality, modal_item_count in mm_item_counts.items(): for modality, modal_item_count in mm_item_counts.items():
if modality in repls_by_modality: if modality in mm_prompt_repls:
yield from _iter_modality_placeholders( yield from _iter_modality_placeholders(
prompt, prompt,
modality, modality,
repls_by_modality[modality], mm_prompt_repls[modality],
modal_item_count, modal_item_count,
) )
def find_mm_placeholders(
mm_prompt_repls: Mapping[str, Sequence[_BoundPromptReplacement]],
prompt: list[int],
mm_item_counts: Mapping[str, int],
) -> Mapping[str, list[_PlaceholderInfo]]:
it = _iter_placeholders(mm_prompt_repls, prompt, mm_item_counts)
return dict(full_groupby_modality(it))
@dataclass @dataclass
class ProcessorInputs: class ProcessorInputs:
"""Keyword arguments to :meth:`BaseMultiModalProcessor`.""" """Keyword arguments to :meth:`BaseMultiModalProcessor`."""
...@@ -620,7 +622,7 @@ class BaseMultiModalProcessor(ABC): ...@@ -620,7 +622,7 @@ class BaseMultiModalProcessor(ABC):
raise NotImplementedError raise NotImplementedError
@abstractmethod @abstractmethod
def get_mm_max_tokens_per_item(self) -> Mapping[str, int]: def get_mm_max_tokens_per_item(self, seq_len: int) -> Mapping[str, int]:
""" """
Get the maximum possible number of tokens per data item Get the maximum possible number of tokens per data item
for each modality. for each modality.
...@@ -703,14 +705,14 @@ class BaseMultiModalProcessor(ABC): ...@@ -703,14 +705,14 @@ class BaseMultiModalProcessor(ABC):
""" """
raise NotImplementedError raise NotImplementedError
def _find_placeholders( def _find_mm_placeholders(
self, self,
all_prompt_repls: Sequence[_BoundPromptReplacement], mm_prompt_repls: Mapping[str, Sequence[_BoundPromptReplacement]],
new_token_ids: list[int], new_token_ids: list[int],
mm_item_counts: Mapping[str, int], mm_item_counts: Mapping[str, int],
) -> list[_PlaceholderInfo]: ) -> Mapping[str, list[_PlaceholderInfo]]:
return list( return find_mm_placeholders(mm_prompt_repls, new_token_ids,
iter_placeholders(all_prompt_repls, new_token_ids, mm_item_counts)) mm_item_counts)
def _get_hf_mm_data( def _get_hf_mm_data(
self, self,
...@@ -797,7 +799,10 @@ class BaseMultiModalProcessor(ABC): ...@@ -797,7 +799,10 @@ class BaseMultiModalProcessor(ABC):
# Some HF processors (e.g. Qwen2-VL) expect corresponding # Some HF processors (e.g. Qwen2-VL) expect corresponding
# multi-modal tokens to be in the prompt text # multi-modal tokens to be in the prompt text
dummy_inputs = self._get_dummy_mm_inputs(mm_missing_counts) dummy_inputs = self._get_dummy_processor_inputs(
self.ctx.model_config.max_model_len,
mm_missing_counts,
)
_, mm_missing_kwargs = self._apply_hf_processor( _, mm_missing_kwargs = self._apply_hf_processor(
prompt_text=dummy_inputs.prompt_text, prompt_text=dummy_inputs.prompt_text,
...@@ -889,50 +894,44 @@ class BaseMultiModalProcessor(ABC): ...@@ -889,50 +894,44 @@ class BaseMultiModalProcessor(ABC):
mm_kwargs = MultiModalKwargs.from_items(merged_kw_items) mm_kwargs = MultiModalKwargs.from_items(merged_kw_items)
if self.enable_sanity_checks:
mm_item_counts = mm_data_items.get_all_counts()
for modality, item_count in mm_item_counts.items():
for item_idx in range(item_count):
try:
mm_kwargs.get_item(modality, item_idx)
except Exception as e:
# Make it easy to set a breakpoint in the debugger
raise e
return prompt_ids, mm_kwargs return prompt_ids, mm_kwargs
def _bind_and_group_repls(
    self,
    prompt_repls: list[PromptReplacement],
) -> dict[str, list[_BoundPromptReplacement]]:
    """
    Bind each prompt replacement to the tokenizer and group the results.

    Args:
        prompt_repls: The unbound prompt replacements.

    Returns:
        A dict built from the groups that ``full_groupby_modality``
        produces over the bound replacements.
    """
    tokenizer = self._get_tokenizer()

    # Bind lazily; the grouping helper consumes the generator.
    it = (prompt_repl.bind(tokenizer) for prompt_repl in prompt_repls)
    return dict(full_groupby_modality(it))
def _always_apply_prompt_replacements(self) -> bool: def _always_apply_prompt_replacements(self) -> bool:
""" """
A flag which can be overridden so that A flag which can be overridden so that
:meth:`_apply_prompt_replacements` is always called even if we :meth:`_apply_prompt_replacements` is always called even if we
detect that HF has performed processing via :meth:`_find_placeholders`. detect that HF has performed processing via
:meth:`_find_placeholders_by_modality`.
This is useful in cases where :meth:`_find_placeholders` cannot be This is useful in cases where :meth:`_find_placeholders_by_modality`
reliably used to detect whether HF has performed processing or not. cannot be reliably used to detect whether HF has performed processing.
""" """
return False return False
def _apply_prompt_replacements( def _apply_prompt_replacements(
self, self,
token_ids: list[int], token_ids: list[int],
prompt_repls: Sequence[_BoundPromptReplacement], mm_prompt_repls: Mapping[str, Sequence[_BoundPromptReplacement]],
mm_item_counts: Mapping[str, int], mm_item_counts: Mapping[str, int],
) -> tuple[list[int], str, list[_PlaceholderInfo]]: ) -> tuple[list[int], str, Mapping[str, list[_PlaceholderInfo]]]:
tokenizer = self._get_tokenizer() tokenizer = self._get_tokenizer()
token_matches = find_token_matches(token_ids, prompt_repls) mm_token_matches = {
modality: find_token_matches(token_ids, prompt_repls)
for modality, prompt_repls in mm_prompt_repls.items()
}
mm_match_counts = { mm_match_counts = {
modality: len(matches) modality: len(matches)
for modality, matches in full_groupby_modality(token_matches) for modality, matches in mm_token_matches.items()
} }
# If the search text does not represent a special token, # If the search text does not represent a special token,
...@@ -951,32 +950,92 @@ class BaseMultiModalProcessor(ABC): ...@@ -951,32 +950,92 @@ class BaseMultiModalProcessor(ABC):
): # yapf: disable ): # yapf: disable
token_ids = replace_token_matches( token_ids = replace_token_matches(
token_ids, token_ids,
token_matches, mm_token_matches,
mm_item_counts, mm_item_counts,
) )
text = _decode(tokenizer, token_ids) text = decode_tokens(tokenizer, token_ids)
matched_repls = [match.prompt_repl for match in token_matches] matched_repls = {
modality: [match.prompt_repl for match in token_matches]
for modality, token_matches in mm_token_matches.items()
}
else: else:
text = _decode(tokenizer, token_ids) text = decode_tokens(tokenizer, token_ids)
text_matches = find_text_matches(text, prompt_repls) mm_text_matches = {
modality: find_text_matches(text, prompt_repls)
for modality, prompt_repls in mm_prompt_repls.items()
}
text = replace_text_matches( text = replace_text_matches(
text, text,
text_matches, mm_text_matches,
mm_item_counts, mm_item_counts,
) )
token_ids = encode_tokens(tokenizer, token_ids = encode_tokens(tokenizer,
text, text,
add_special_tokens=False) add_special_tokens=False)
matched_repls = [match.prompt_repl for match in text_matches] matched_repls = {
modality: [match.prompt_repl for match in token_matches]
placeholders = self._find_placeholders(matched_repls, token_ids, for modality, token_matches in mm_text_matches.items()
mm_item_counts) }
placeholders = self._find_mm_placeholders(
matched_repls,
token_ids,
mm_item_counts,
)
return token_ids, text, placeholders return token_ids, text, placeholders
def _validate_mm_kwargs(
self,
mm_kwargs: MultiModalKwargs,
mm_item_counts: Mapping[str, int],
) -> None:
for modality, item_count in mm_item_counts.items():
if modality in mm_kwargs.modalities:
items = mm_kwargs.get_items(modality)
else:
items = []
if len(items) != item_count:
raise RuntimeError(
f"Expected there to be {item_count} {modality} items in "
f"keyword arguments corresponding to {item_count} "
f"{modality} data items, but only found {len(items)}! "
"There is likely a problem with your "
"implementation of merged multi-modal processor for this "
"model (usually arising from an inconsistency between "
"`_call_hf_processor` and `_get_mm_fields_config`).")
def _validate_mm_placeholders(
self,
mm_placeholders: Mapping[str, list[_PlaceholderInfo]],
mm_item_counts: Mapping[str, int],
*,
allow_missing: bool = False,
) -> Mapping[str, int]:
missing_repl_counts = dict[str, int]()
for modality, item_count in mm_item_counts.items():
placeholders = mm_placeholders.get(modality, [])
if len(placeholders) != item_count and not allow_missing:
raise RuntimeError(
f"Expected there to be {item_count} prompt replacements "
f"corresponding to {item_count} {modality} items, but only "
f"found {len(placeholders)} prompt replacements! Either "
"the prompt text has missing/incorrect tokens for "
"multi-modal inputs, or there is a problem with your "
"implementation of merged multi-modal processor for this "
"model (usually arising from an inconsistency between "
"`_call_hf_processor` and `_get_prompt_replacements`).")
missing_repl_counts[modality] = item_count - len(placeholders)
return missing_repl_counts
def apply( def apply(
self, self,
prompt_text: str, prompt_text: str,
...@@ -1009,56 +1068,69 @@ class BaseMultiModalProcessor(ABC): ...@@ -1009,56 +1068,69 @@ class BaseMultiModalProcessor(ABC):
hf_processor_mm_kwargs, hf_processor_mm_kwargs,
mm_kwargs, mm_kwargs,
) )
prompt_repls = self._bind_prompt_replacements(unbound_prompt_repls) mm_prompt_repls = self._bind_and_group_repls(unbound_prompt_repls)
# If HF processor already inserts placeholder tokens,
# there is no need for us to insert them
mm_item_counts = mm_items.get_all_counts() mm_item_counts = mm_items.get_all_counts()
all_placeholders = self._find_placeholders(prompt_repls, prompt_ids, self._validate_mm_kwargs(mm_kwargs, mm_item_counts)
mm_item_counts)
hf_mm_placeholders = self._find_mm_placeholders(
mm_prompt_repls,
prompt_ids,
mm_item_counts,
)
if self._always_apply_prompt_replacements():
mm_missing_repl_counts = mm_item_counts
mm_missing_repls = dict(mm_prompt_repls)
else:
mm_missing_repl_counts = self._validate_mm_placeholders(
hf_mm_placeholders,
mm_item_counts,
allow_missing=True,
)
mm_missing_repls = dict[str, list[_BoundPromptReplacement]]()
for modality, missing_repl_count in mm_missing_repl_counts.items():
if missing_repl_count == 0:
mm_missing_repls[modality] = []
elif missing_repl_count == mm_item_counts.get(modality, 0):
mm_missing_repls[modality] = mm_prompt_repls[modality]
else:
raise ValueError("Partial prompt replacement within "
f"{modality=} is not supported")
if all_placeholders and not self._always_apply_prompt_replacements(): # If HF processor already inserts placeholder tokens,
# there is no need for us to insert them
if all(len(repls) == 0 for repls in mm_missing_repls.items()):
tokenizer = self._get_tokenizer() tokenizer = self._get_tokenizer()
prompt_text = _decode(tokenizer, prompt_ids) prompt_text = decode_tokens(tokenizer, prompt_ids)
mm_placeholders = hf_mm_placeholders
else: else:
( (
prompt_ids, prompt_ids,
prompt_text, prompt_text,
all_placeholders, missing_mm_placeholders,
) = self._apply_prompt_replacements( ) = self._apply_prompt_replacements(
prompt_ids, prompt_ids,
prompt_repls, mm_missing_repls,
mm_item_counts, mm_missing_repl_counts,
) )
mm_placeholders = dict[str, list[PlaceholderRange]]() mm_placeholders = {**hf_mm_placeholders, **missing_mm_placeholders}
err_suffix = ("This suggests a problem with your implementation of "
"the merged multi-modal processor for this model, " self._validate_mm_placeholders(mm_placeholders, mm_item_counts)
"particularly in the `_get_prompt_replacements` method.")
mm_placeholder_ranges = {
for modality, placeholders in full_groupby_modality(all_placeholders): modality: [item.to_range() for item in placeholders]
if modality not in mm_items: for modality, placeholders in mm_placeholders.items()
raise AssertionError( }
f"Expected no placeholders for {modality=}, "
f"but found {placeholders=}. Input items: {mm_items}"
f"\n{err_suffix}")
if len(placeholders) != len(mm_items[modality]):
raise AssertionError(
f"Expected length of {placeholders=} for {modality=} "
f"to equal that of input items: {mm_items[modality]}"
f"\n{err_suffix}")
mm_placeholders[modality] = [
item.to_range() for item in placeholders
]
return MultiModalInputsV2( return MultiModalInputsV2(
type="multimodal", type="multimodal",
prompt=prompt_text, prompt=prompt_text,
prompt_token_ids=prompt_ids, prompt_token_ids=prompt_ids,
mm_kwargs=mm_kwargs, mm_kwargs=mm_kwargs,
mm_placeholders=mm_placeholders, mm_placeholders=mm_placeholder_ranges,
) )
def _get_dummy_audios( def _get_dummy_audios(
...@@ -1092,8 +1164,9 @@ class BaseMultiModalProcessor(ABC): ...@@ -1092,8 +1164,9 @@ class BaseMultiModalProcessor(ABC):
return [video] * num_videos return [video] * num_videos
@abstractmethod @abstractmethod
def _get_dummy_mm_inputs( def _get_dummy_processor_inputs(
self, self,
seq_len: int,
mm_counts: Mapping[str, int], mm_counts: Mapping[str, int],
) -> ProcessorInputs: ) -> ProcessorInputs:
""" """
...@@ -1121,12 +1194,25 @@ class BaseMultiModalProcessor(ABC): ...@@ -1121,12 +1194,25 @@ class BaseMultiModalProcessor(ABC):
return mm_limits return mm_limits
def _get_dummy_mm_inputs(
self,
seq_len: int,
mm_counts: Mapping[str, int],
) -> MultiModalInputsV2:
processor_inputs = self._get_dummy_processor_inputs(seq_len, mm_counts)
return self.apply(
prompt_text=processor_inputs.prompt_text,
mm_data=processor_inputs.mm_data,
hf_processor_mm_kwargs=processor_inputs.hf_processor_mm_kwargs,
)
def get_dummy_data(self, seq_len: int) -> DummyData: def get_dummy_data(self, seq_len: int) -> DummyData:
# Avoid circular import # Avoid circular import
from vllm.sequence import SequenceData from vllm.sequence import SequenceData
mm_counts = self._get_and_validate_dummy_mm_counts() mm_counts = self._get_and_validate_dummy_mm_counts()
mm_max_tokens_per_item = self.get_mm_max_tokens_per_item() mm_max_tokens_per_item = self.get_mm_max_tokens_per_item(seq_len)
if mm_counts.keys() != mm_max_tokens_per_item.keys(): if mm_counts.keys() != mm_max_tokens_per_item.keys():
raise AssertionError( raise AssertionError(
"The keys returned by `get_supported_mm_limits`" "The keys returned by `get_supported_mm_limits`"
...@@ -1134,13 +1220,7 @@ class BaseMultiModalProcessor(ABC): ...@@ -1134,13 +1220,7 @@ class BaseMultiModalProcessor(ABC):
"returned by `get_mm_max_tokens_per_item` " "returned by `get_mm_max_tokens_per_item` "
f"({set(mm_max_tokens_per_item.keys())})") f"({set(mm_max_tokens_per_item.keys())})")
processor_inputs = self._get_dummy_mm_inputs(mm_counts) mm_inputs = self._get_dummy_mm_inputs(seq_len, mm_counts)
mm_inputs = self.apply(
prompt_text=processor_inputs.prompt_text,
mm_data=processor_inputs.mm_data,
hf_processor_mm_kwargs=processor_inputs.hf_processor_mm_kwargs,
)
prompt_token_ids = mm_inputs["prompt_token_ids"] prompt_token_ids = mm_inputs["prompt_token_ids"]
placeholders_by_modality = mm_inputs["mm_placeholders"] placeholders_by_modality = mm_inputs["mm_placeholders"]
...@@ -1171,6 +1251,12 @@ class BaseMultiModalProcessor(ABC): ...@@ -1171,6 +1251,12 @@ class BaseMultiModalProcessor(ABC):
"reduce `max_num_seqs`, and/or reduce `mm_counts`.", seq_len, "reduce `max_num_seqs`, and/or reduce `mm_counts`.", seq_len,
total_len, total_placeholders_by_modality) total_len, total_placeholders_by_modality)
return DummyData(
seq_data=SequenceData.from_prompt_token_counts((0, seq_len)),
multi_modal_data=None,
multi_modal_placeholders=None,
)
prompt_token_ids.extend([0] * (seq_len - len(prompt_token_ids))) prompt_token_ids.extend([0] * (seq_len - len(prompt_token_ids)))
return DummyData( return DummyData(
......
...@@ -223,7 +223,8 @@ class MultiModalRegistry: ...@@ -223,7 +223,8 @@ class MultiModalRegistry:
if self.has_processor(model_config): if self.has_processor(model_config):
tokenizer = cached_get_tokenizer(model_config.tokenizer) tokenizer = cached_get_tokenizer(model_config.tokenizer)
processor = self.create_processor(model_config, tokenizer) processor = self.create_processor(model_config, tokenizer)
return processor.get_mm_max_tokens_per_item() seq_len = model_config.max_model_len
return processor.get_mm_max_tokens_per_item(seq_len)
return { return {
key: plugin.get_max_multimodal_tokens(model_config) key: plugin.get_max_multimodal_tokens(model_config)
......
...@@ -21,6 +21,19 @@ AnyTokenizer = Union[PreTrainedTokenizer, PreTrainedTokenizerFast, ...@@ -21,6 +21,19 @@ AnyTokenizer = Union[PreTrainedTokenizer, PreTrainedTokenizerFast,
MistralTokenizer] MistralTokenizer]
def decode_tokens(
    tokenizer: AnyTokenizer,
    token_ids: list[int],
    *,
    skip_special_tokens: bool = False,
) -> str:
    """
    Backend-agnostic equivalent of HF's
    :code:`tokenizer.decode(token_ids, skip_special_tokens=...)`.

    Args:
        tokenizer: The tokenizer whose ``decode`` method is called.
        token_ids: The token IDs to decode.
        skip_special_tokens: Forwarded to ``tokenizer.decode``.

    Returns:
        Whatever ``tokenizer.decode`` returns for the given arguments.
    """
    return tokenizer.decode(token_ids, skip_special_tokens=skip_special_tokens)
def encode_tokens( def encode_tokens(
tokenizer: AnyTokenizer, tokenizer: AnyTokenizer,
text: str, text: str,
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment