Unverified commit eed11ebe, authored by Cyrus Leung, committed by GitHub
Browse files

[VLM] Merged multi-modal processors for LLaVA-NeXT-Video and LLaVA-OneVision (#11717)


Signed-off-by: DarkLight1337 <tlleungac@connect.ust.hk>
parent 300acb83
...@@ -323,7 +323,7 @@ class Phi3VMultiModalProcessor(BaseMultiModalProcessor): ...@@ -323,7 +323,7 @@ class Phi3VMultiModalProcessor(BaseMultiModalProcessor):
height=image_height, height=image_height,
) )
def get_mm_max_tokens_per_item(self) -> Mapping[str, int]: def get_mm_max_tokens_per_item(self, seq_len: int) -> Mapping[str, int]:
max_image_tokens = self._get_num_image_tokens( max_image_tokens = self._get_num_image_tokens(
image_width=MAX_IMAGE_FEATURE_SIZE_WIDTH, image_width=MAX_IMAGE_FEATURE_SIZE_WIDTH,
image_height=MAX_IMAGE_FEATURE_SIZE_HEIGHT, image_height=MAX_IMAGE_FEATURE_SIZE_HEIGHT,
...@@ -415,12 +415,12 @@ class Phi3VMultiModalProcessor(BaseMultiModalProcessor): ...@@ -415,12 +415,12 @@ class Phi3VMultiModalProcessor(BaseMultiModalProcessor):
def _apply_prompt_replacements( def _apply_prompt_replacements(
self, self,
token_ids: list[int], token_ids: list[int],
prompt_repls: Sequence[_BoundPromptReplacement], mm_prompt_repls: Mapping[str, Sequence[_BoundPromptReplacement]],
mm_item_counts: Mapping[str, int], mm_item_counts: Mapping[str, int],
) -> tuple[list[int], str, list[_PlaceholderInfo]]: ) -> tuple[list[int], str, Mapping[str, list[_PlaceholderInfo]]]:
token_ids, text, placeholders = super()._apply_prompt_replacements( token_ids, text, placeholders = super()._apply_prompt_replacements(
token_ids=token_ids, token_ids=token_ids,
prompt_repls=prompt_repls, mm_prompt_repls=mm_prompt_repls,
mm_item_counts=mm_item_counts, mm_item_counts=mm_item_counts,
) )
...@@ -428,15 +428,23 @@ class Phi3VMultiModalProcessor(BaseMultiModalProcessor): ...@@ -428,15 +428,23 @@ class Phi3VMultiModalProcessor(BaseMultiModalProcessor):
if text.startswith("<s> <|image|>"): if text.startswith("<s> <|image|>"):
text = text.replace("<s> <|image|>", "<s><|image|>", 1) text = text.replace("<s> <|image|>", "<s><|image|>", 1)
token_ids = [token_ids[0], *token_ids[2:]] token_ids = [token_ids[0], *token_ids[2:]]
placeholders = [ placeholders = {
_PlaceholderInfo(p.modality, p.start_idx - 1, p.replacement) modality: [
for p in placeholders _PlaceholderInfo(
] modality=p.modality,
item_idx=p.item_idx,
start_idx=p.start_idx - 1,
replacement=p.replacement,
) for p in ps
]
for modality, ps in placeholders.items()
}
return token_ids, text, placeholders return token_ids, text, placeholders
def _get_dummy_mm_inputs( def _get_dummy_processor_inputs(
self, self,
seq_len: int,
mm_counts: Mapping[str, int], mm_counts: Mapping[str, int],
) -> ProcessorInputs: ) -> ProcessorInputs:
num_images = mm_counts.get("image", 0) num_images = mm_counts.get("image", 0)
......
...@@ -780,15 +780,18 @@ class PixtralHFEncoderInfo(VisionEncoderInfo[PixtralVisionConfig]): ...@@ -780,15 +780,18 @@ class PixtralHFEncoderInfo(VisionEncoderInfo[PixtralVisionConfig]):
def get_max_image_tokens(self) -> int: def get_max_image_tokens(self) -> int:
return get_max_pixtral_hf_image_tokens(self.vision_config) return get_max_pixtral_hf_image_tokens(self.vision_config)
def get_num_patches(self) -> int: def get_image_size(self) -> int:
return self.vision_config.image_size
def get_patch_size(self) -> int:
return self.vision_config.patch_size
def get_patch_grid_length(self) -> int:
return get_pixtral_hf_patch_grid_length( return get_pixtral_hf_patch_grid_length(
image_size=self.vision_config.image_size, image_size=self.vision_config.image_size,
patch_size=self.vision_config.patch_size, patch_size=self.vision_config.patch_size,
) )
def get_image_size(self) -> int:
return self.vision_config.image_size
class PixtralHFMLP(nn.Module): class PixtralHFMLP(nn.Module):
......
...@@ -84,7 +84,7 @@ class Qwen2AudioMultiModalProcessor(BaseMultiModalProcessor): ...@@ -84,7 +84,7 @@ class Qwen2AudioMultiModalProcessor(BaseMultiModalProcessor):
def get_supported_mm_limits(self) -> Mapping[str, Optional[int]]: def get_supported_mm_limits(self) -> Mapping[str, Optional[int]]:
return {"audio": None} return {"audio": None}
def get_mm_max_tokens_per_item(self) -> Mapping[str, int]: def get_mm_max_tokens_per_item(self, seq_len: int) -> Mapping[str, int]:
hf_config = self.ctx.get_hf_config(Qwen2AudioConfig) hf_config = self.ctx.get_hf_config(Qwen2AudioConfig)
max_source_positions = hf_config.audio_config.max_source_positions max_source_positions = hf_config.audio_config.max_source_positions
max_output_lengths = (max_source_positions - 2) // 2 + 1 max_output_lengths = (max_source_positions - 2) // 2 + 1
...@@ -184,15 +184,16 @@ class Qwen2AudioMultiModalProcessor(BaseMultiModalProcessor): ...@@ -184,15 +184,16 @@ class Qwen2AudioMultiModalProcessor(BaseMultiModalProcessor):
] ]
def _always_apply_prompt_replacements(self) -> bool: def _always_apply_prompt_replacements(self) -> bool:
# HF never applies prompt replacements, so we have to do it ourselves # HF never applies prompt replacements, so we have to do it ourselves.
# _find_placeholders may incorrectly think that HF has already performed # NOTE: `_find_placeholders_by_modality` may incorrectly think that HF
# processing for multi-audio input when the input audios are short # has already performed processing for multi-audio input when the input
# (the corresponding placeholders may take up fewer tokens than # audios are short (the corresponding placeholders may take up fewer
# the number of audio items) # tokens than the number of audio items)
return True return True
def _get_dummy_mm_inputs( def _get_dummy_processor_inputs(
self, self,
seq_len: int,
mm_counts: Mapping[str, int], mm_counts: Mapping[str, int],
) -> ProcessorInputs: ) -> ProcessorInputs:
feature_extractor = self._get_feature_extractor() feature_extractor = self._get_feature_extractor()
......
...@@ -56,7 +56,8 @@ from vllm.multimodal import MULTIMODAL_REGISTRY ...@@ -56,7 +56,8 @@ from vllm.multimodal import MULTIMODAL_REGISTRY
from vllm.multimodal.inputs import (ImageItem, ModalityData, from vllm.multimodal.inputs import (ImageItem, ModalityData,
MultiModalFieldConfig, MultiModalKwargs, MultiModalFieldConfig, MultiModalKwargs,
NestedTensors, VideoItem) NestedTensors, VideoItem)
from vllm.multimodal.parse import ModalityDataItems, MultiModalDataParser from vllm.multimodal.parse import (ImageSize, ModalityDataItems,
MultiModalDataParser)
from vllm.multimodal.processing import (BaseMultiModalProcessor, from vllm.multimodal.processing import (BaseMultiModalProcessor,
MultiModalDataItems, ProcessorInputs, MultiModalDataItems, ProcessorInputs,
PromptReplacement) PromptReplacement)
...@@ -641,58 +642,6 @@ class Qwen2VisionTransformer(nn.Module): ...@@ -641,58 +642,6 @@ class Qwen2VisionTransformer(nn.Module):
return loaded_params return loaded_params
# === Vision input helpers === #
def _get_vision_info(
vision_config: Qwen2VLVisionConfig,
height: int,
width: int,
min_pixels: int,
max_pixels: int,
*,
do_resize: bool = True,
modality: str = "image",
mm_count: int = 1,
):
"""Get information (resized height / width and number of vision tokens)
of input image / video frame."""
patch_size = vision_config.patch_size
merge_size = vision_config.spatial_merge_size
temporal_patch_size = vision_config.temporal_patch_size
if do_resize:
resized_height, resized_width = smart_resize(
height=height,
width=width,
factor=patch_size * merge_size,
min_pixels=min_pixels,
max_pixels=max_pixels,
)
else:
resized_height, resized_width = height, width
if modality == "image":
grid_t = mm_count
elif modality == "video":
grid_t = max(mm_count // temporal_patch_size, 1)
else:
raise ValueError(f"Modality {modality} is not supported")
grid_h = resized_height // patch_size
grid_w = resized_width // patch_size
vision_tokens = grid_t * grid_h * grid_w
llm_num_vision_tokens = vision_tokens // (merge_size**2)
return resized_height, resized_width, llm_num_vision_tokens
def _get_image_processor(hf_processor: Qwen2VLProcessor):
image_processor = hf_processor.image_processor # type: ignore
assert isinstance(image_processor, Qwen2VLImageProcessor)
return image_processor
class Qwen2EmbeddingItems(ModalityDataItems[dict[str, torch.Tensor], class Qwen2EmbeddingItems(ModalityDataItems[dict[str, torch.Tensor],
dict[str, torch.Tensor]]): dict[str, torch.Tensor]]):
...@@ -764,32 +713,111 @@ class Qwen2VLMultiModalProcessor(BaseMultiModalProcessor): ...@@ -764,32 +713,111 @@ class Qwen2VLMultiModalProcessor(BaseMultiModalProcessor):
def get_supported_mm_limits(self) -> Mapping[str, Optional[int]]: def get_supported_mm_limits(self) -> Mapping[str, Optional[int]]:
return {"image": None, "video": None} return {"image": None, "video": None}
def _get_max_mm_tokens(self, modality: str) -> int: def _get_vision_info(
self,
*,
image_width: int,
image_height: int,
num_frames: int = 1,
do_resize: bool = True,
) -> tuple[ImageSize, int]:
hf_config = self.ctx.get_hf_config(Qwen2VLConfig) hf_config = self.ctx.get_hf_config(Qwen2VLConfig)
vision_config = hf_config.vision_config vision_config = hf_config.vision_config
patch_size = vision_config.patch_size
merge_size = vision_config.spatial_merge_size
temporal_patch_size = vision_config.temporal_patch_size
hf_processor = self._get_hf_processor() hf_processor = self._get_hf_processor()
image_processor = _get_image_processor(hf_processor) image_processor = self._get_image_processor(hf_processor)
_, _, max_llm_image_tokens = _get_vision_info( if do_resize:
vision_config, resized_height, resized_width = smart_resize(
height=9999999, height=image_height,
width=9999999, width=image_width,
min_pixels=image_processor.min_pixels, factor=patch_size * merge_size,
max_pixels=image_processor.max_pixels, min_pixels=image_processor.min_pixels,
modality=modality, max_pixels=image_processor.max_pixels,
)
preprocessed_size = ImageSize(width=resized_width,
height=resized_height)
else:
preprocessed_size = ImageSize(width=image_width,
height=image_height)
grid_t = max(num_frames // temporal_patch_size, 1)
grid_h = preprocessed_size.height // patch_size
grid_w = preprocessed_size.width // patch_size
num_patches = grid_t * grid_h * grid_w
num_vision_tokens = num_patches // (merge_size**2)
return preprocessed_size, num_vision_tokens
def _get_dummy_image_size(self) -> ImageSize:
max_image_size, _ = self._get_vision_info(
image_width=9999999,
image_height=9999999,
)
return max_image_size
def _get_max_image_tokens(self) -> int:
_, max_image_tokens = self._get_vision_info(
image_width=9999999,
image_height=9999999,
)
return max_image_tokens
def _get_max_video_tokens(self, num_frames: int) -> int:
_, max_video_tokens = self._get_vision_info(
image_width=9999999,
image_height=9999999,
num_frames=num_frames,
) )
return max_llm_image_tokens return max_video_tokens
def _get_max_video_frames(self, max_tokens: int) -> int:
num_frames = 0
while True:
next_num_frames = num_frames + 1
if self._get_max_video_tokens(next_num_frames) > max_tokens:
break
num_frames = next_num_frames
return num_frames
def _get_dummy_num_frames(self, seq_len: int) -> int:
mm_config = self.ctx.get_mm_config()
max_images = mm_config.limit_per_prompt.get("image", 1)
max_videos = mm_config.limit_per_prompt.get("video", 1)
max_image_tokens = self._get_max_image_tokens() * max_images
max_total_frames = self._get_max_video_frames(seq_len -
max_image_tokens)
return max(max_total_frames // max(max_videos, 1), 1)
def get_mm_max_tokens_per_item(self, seq_len: int) -> Mapping[str, int]:
max_image_tokens = self._get_max_image_tokens()
num_frames = self._get_dummy_num_frames(seq_len)
max_video_tokens = self._get_max_video_tokens(num_frames)
def get_mm_max_tokens_per_item(self) -> Mapping[str, int]:
return { return {
"image": self._get_max_mm_tokens("image"), "image": max_image_tokens,
"video": self._get_max_mm_tokens("video"), "video": max_video_tokens,
} }
def _get_data_parser(self) -> MultiModalDataParser: def _get_data_parser(self) -> MultiModalDataParser:
return Qwen2MultiModalDataParser() return Qwen2MultiModalDataParser()
def _get_image_processor(self, hf_processor: Qwen2VLProcessor):
image_processor = hf_processor.image_processor # type: ignore
assert isinstance(image_processor, Qwen2VLImageProcessor)
return image_processor
def _get_hf_processor( def _get_hf_processor(
self, self,
*, *,
...@@ -797,7 +825,7 @@ class Qwen2VLMultiModalProcessor(BaseMultiModalProcessor): ...@@ -797,7 +825,7 @@ class Qwen2VLMultiModalProcessor(BaseMultiModalProcessor):
max_pixels: Optional[int] = None, max_pixels: Optional[int] = None,
) -> Qwen2VLProcessor: ) -> Qwen2VLProcessor:
hf_processor = self.ctx.get_hf_processor(Qwen2VLProcessor) hf_processor = self.ctx.get_hf_processor(Qwen2VLProcessor)
image_processor = _get_image_processor(hf_processor) image_processor = self._get_image_processor(hf_processor)
if min_pixels: if min_pixels:
image_processor.min_pixels = min_pixels image_processor.min_pixels = min_pixels
...@@ -818,7 +846,7 @@ class Qwen2VLMultiModalProcessor(BaseMultiModalProcessor): ...@@ -818,7 +846,7 @@ class Qwen2VLMultiModalProcessor(BaseMultiModalProcessor):
out_mm_kwargs: MultiModalKwargs, out_mm_kwargs: MultiModalKwargs,
) -> list[PromptReplacement]: ) -> list[PromptReplacement]:
hf_processor = self._get_hf_processor() hf_processor = self._get_hf_processor()
image_processor = _get_image_processor(hf_processor) image_processor = self._get_image_processor(hf_processor)
# NOTE: Only Qwen2VLProcessor in transformers 4.47.0 has # NOTE: Only Qwen2VLProcessor in transformers 4.47.0 has
# image_token and video_token registered # image_token and video_token registered
...@@ -873,32 +901,35 @@ class Qwen2VLMultiModalProcessor(BaseMultiModalProcessor): ...@@ -873,32 +901,35 @@ class Qwen2VLMultiModalProcessor(BaseMultiModalProcessor):
video_grid_thw=MultiModalFieldConfig.batched("video"), video_grid_thw=MultiModalFieldConfig.batched("video"),
) )
def _get_dummy_mm_inputs( def _get_dummy_processor_inputs(
self, self,
seq_len: int,
mm_counts: Mapping[str, int], mm_counts: Mapping[str, int],
) -> ProcessorInputs: ) -> ProcessorInputs:
hf_processor = self._get_hf_processor() num_images = mm_counts.get("image", 0)
image_processor = _get_image_processor(hf_processor) num_videos = mm_counts.get("video", 0)
hf_processor = self._get_hf_processor()
image_token: str = hf_processor.image_token image_token: str = hf_processor.image_token
resized_height, resized_width = smart_resize( video_token: str = hf_processor.video_token
height=9999999, target_width, target_height = self._get_dummy_image_size()
width=9999999,
factor=image_processor.patch_size * image_processor.merge_size,
min_pixels=image_processor.min_pixels,
max_pixels=image_processor.max_pixels,
)
num_images = mm_counts.get("image", 0)
mm_data = { mm_data = {
"image": "image":
self._get_dummy_images(width=resized_width, self._get_dummy_images(width=target_width,
height=resized_height, height=target_height,
num_images=num_images) num_images=num_images),
"video":
self._get_dummy_videos(
width=target_width,
height=target_height,
num_frames=self._get_dummy_num_frames(seq_len),
num_videos=num_videos,
)
} }
return ProcessorInputs( return ProcessorInputs(
prompt_text=image_token * num_images, prompt_text=image_token * num_images + video_token * num_videos,
mm_data=mm_data, mm_data=mm_data,
) )
......
...@@ -171,15 +171,18 @@ class SiglipEncoderInfo(VisionEncoderInfo[SiglipVisionConfig]): ...@@ -171,15 +171,18 @@ class SiglipEncoderInfo(VisionEncoderInfo[SiglipVisionConfig]):
def get_max_image_tokens(self) -> int: def get_max_image_tokens(self) -> int:
return get_max_siglip_image_tokens(self.vision_config) return get_max_siglip_image_tokens(self.vision_config)
def get_num_patches(self) -> int: def get_image_size(self) -> int:
return self.vision_config.image_size
def get_patch_size(self) -> int:
return self.vision_config.patch_size
def get_patch_grid_length(self) -> int:
return get_siglip_patch_grid_length( return get_siglip_patch_grid_length(
image_size=self.vision_config.image_size, image_size=self.vision_config.image_size,
patch_size=self.vision_config.patch_size, patch_size=self.vision_config.patch_size,
) )
def get_image_size(self) -> int:
return self.vision_config.image_size
# Adapted from https://github.com/huggingface/transformers/blob/v4.43.3/src/transformers/models/siglip/modeling_siglip.py#L249 # noqa # Adapted from https://github.com/huggingface/transformers/blob/v4.43.3/src/transformers/models/siglip/modeling_siglip.py#L249 # noqa
class SiglipVisionEmbeddings(nn.Module): class SiglipVisionEmbeddings(nn.Module):
......
...@@ -6,7 +6,6 @@ from functools import cached_property ...@@ -6,7 +6,6 @@ from functools import cached_property
from typing import (Iterable, List, Literal, Mapping, Optional, Set, Tuple, from typing import (Iterable, List, Literal, Mapping, Optional, Set, Tuple,
TypedDict, Union) TypedDict, Union)
import numpy as np
import torch import torch
import torch.utils.checkpoint import torch.utils.checkpoint
from torch import nn from torch import nn
...@@ -31,7 +30,6 @@ from vllm.multimodal.processing import (BaseMultiModalProcessor, ...@@ -31,7 +30,6 @@ from vllm.multimodal.processing import (BaseMultiModalProcessor,
PromptReplacement) PromptReplacement)
from vllm.sequence import IntermediateTensors from vllm.sequence import IntermediateTensors
from vllm.transformers_utils.configs.ultravox import UltravoxConfig from vllm.transformers_utils.configs.ultravox import UltravoxConfig
from vllm.utils import is_list_of
from .interfaces import SupportsMultiModal, SupportsPP from .interfaces import SupportsMultiModal, SupportsPP
from .utils import (AutoWeightsLoader, WeightsMapper, flatten_bn, from .utils import (AutoWeightsLoader, WeightsMapper, flatten_bn,
...@@ -62,7 +60,7 @@ class UltravoxMultiModalProcessor(BaseMultiModalProcessor): ...@@ -62,7 +60,7 @@ class UltravoxMultiModalProcessor(BaseMultiModalProcessor):
def get_supported_mm_limits(self) -> Mapping[str, Optional[int]]: def get_supported_mm_limits(self) -> Mapping[str, Optional[int]]:
return {"audio": None} return {"audio": None}
def get_mm_max_tokens_per_item(self) -> Mapping[str, int]: def get_mm_max_tokens_per_item(self, seq_len: int) -> Mapping[str, int]:
feature_extractor = self._get_feature_extractor() feature_extractor = self._get_feature_extractor()
max_audio_tokens = math.ceil(feature_extractor.chunk_length * max_audio_tokens = math.ceil(feature_extractor.chunk_length *
_AUDIO_TOKENS_PER_SECOND) _AUDIO_TOKENS_PER_SECOND)
...@@ -103,6 +101,7 @@ class UltravoxMultiModalProcessor(BaseMultiModalProcessor): ...@@ -103,6 +101,7 @@ class UltravoxMultiModalProcessor(BaseMultiModalProcessor):
mm_data = dict(mm_data) mm_data = dict(mm_data)
audios = mm_data.pop("audios", []) audios = mm_data.pop("audios", [])
assert isinstance(audios, list)
if not audios: if not audios:
return super()._call_hf_processor( return super()._call_hf_processor(
...@@ -117,9 +116,6 @@ class UltravoxMultiModalProcessor(BaseMultiModalProcessor): ...@@ -117,9 +116,6 @@ class UltravoxMultiModalProcessor(BaseMultiModalProcessor):
sampling_rate=feature_extractor.sampling_rate, sampling_rate=feature_extractor.sampling_rate,
) )
# Already resampled by _get_hf_mm_data
assert is_list_of(audios, np.ndarray)
# Ultravox processor doesn't support multiple inputs, # Ultravox processor doesn't support multiple inputs,
# therefore we need to input text and audio one by one # therefore we need to input text and audio one by one
audio_features, audio_token_len = [], [] audio_features, audio_token_len = [], []
...@@ -177,8 +173,9 @@ class UltravoxMultiModalProcessor(BaseMultiModalProcessor): ...@@ -177,8 +173,9 @@ class UltravoxMultiModalProcessor(BaseMultiModalProcessor):
) )
] ]
def _get_dummy_mm_inputs( def _get_dummy_processor_inputs(
self, self,
seq_len: int,
mm_counts: Mapping[str, int], mm_counts: Mapping[str, int],
) -> ProcessorInputs: ) -> ProcessorInputs:
feature_extractor = self._get_feature_extractor() feature_extractor = self._get_feature_extractor()
......
from abc import ABC, abstractmethod from abc import ABC, abstractmethod
from typing import Generic, TypeVar from typing import Final, Generic, Optional, Protocol, TypeVar
from transformers import PretrainedConfig from transformers import PretrainedConfig
from vllm.multimodal.processing import (BaseMultiModalProcessor,
InputProcessingContext,
ProcessingCache)
_C = TypeVar("_C", bound=PretrainedConfig) _C = TypeVar("_C", bound=PretrainedConfig)
...@@ -27,11 +31,15 @@ class VisionEncoderInfo(ABC, Generic[_C]): ...@@ -27,11 +31,15 @@ class VisionEncoderInfo(ABC, Generic[_C]):
raise NotImplementedError raise NotImplementedError
@abstractmethod @abstractmethod
def get_num_patches(self) -> int: def get_image_size(self) -> int:
raise NotImplementedError raise NotImplementedError
@abstractmethod @abstractmethod
def get_image_size(self) -> int: def get_patch_size(self) -> int:
raise NotImplementedError
@abstractmethod
def get_patch_grid_length(self) -> int:
raise NotImplementedError raise NotImplementedError
...@@ -50,3 +58,26 @@ def vision_encoder_info(vision_config: PretrainedConfig) -> VisionEncoderInfo: ...@@ -50,3 +58,26 @@ def vision_encoder_info(vision_config: PretrainedConfig) -> VisionEncoderInfo:
msg = f"Unsupported vision config: {type(vision_config)}" msg = f"Unsupported vision config: {type(vision_config)}"
raise NotImplementedError(msg) raise NotImplementedError(msg)
class VisionLanguageConfig(Protocol):
vision_config: Final[PretrainedConfig]
class BaseVisionLanguageMultiModalProcessor(BaseMultiModalProcessor):
def __init__(self,
ctx: InputProcessingContext,
*,
cache: Optional[ProcessingCache] = None,
enable_sanity_checks: bool = True) -> None:
super().__init__(ctx,
cache=cache,
enable_sanity_checks=enable_sanity_checks)
vision_config = self._get_hf_config().vision_config
self._vision_encoder_info = vision_encoder_info(vision_config)
@abstractmethod
def _get_hf_config(self) -> VisionLanguageConfig:
raise NotImplementedError
...@@ -146,6 +146,20 @@ class VideoProcessorItems(ProcessorBatchItems[HfVideoItem]): ...@@ -146,6 +146,20 @@ class VideoProcessorItems(ProcessorBatchItems[HfVideoItem]):
def __init__(self, data: Sequence[HfVideoItem]) -> None: def __init__(self, data: Sequence[HfVideoItem]) -> None:
super().__init__(data, "video") super().__init__(data, "video")
def get_num_frames(self, item_idx: int) -> int:
return len(self.get(item_idx))
def get_frame_size(self, item_idx: int) -> ImageSize:
image = self.get(item_idx)[0] # Assume that the video isn't empty
if isinstance(image, Image):
return ImageSize(*image.size)
if isinstance(image, (np.ndarray, torch.Tensor)):
_, h, w = image.shape
return ImageSize(w, h)
assert_never(image)
class VideoEmbeddingItems(EmbeddingItems): class VideoEmbeddingItems(EmbeddingItems):
......
This diff is collapsed.
...@@ -223,7 +223,8 @@ class MultiModalRegistry: ...@@ -223,7 +223,8 @@ class MultiModalRegistry:
if self.has_processor(model_config): if self.has_processor(model_config):
tokenizer = cached_get_tokenizer(model_config.tokenizer) tokenizer = cached_get_tokenizer(model_config.tokenizer)
processor = self.create_processor(model_config, tokenizer) processor = self.create_processor(model_config, tokenizer)
return processor.get_mm_max_tokens_per_item() seq_len = model_config.max_model_len
return processor.get_mm_max_tokens_per_item(seq_len)
return { return {
key: plugin.get_max_multimodal_tokens(model_config) key: plugin.get_max_multimodal_tokens(model_config)
......
...@@ -21,6 +21,19 @@ AnyTokenizer = Union[PreTrainedTokenizer, PreTrainedTokenizerFast, ...@@ -21,6 +21,19 @@ AnyTokenizer = Union[PreTrainedTokenizer, PreTrainedTokenizerFast,
MistralTokenizer] MistralTokenizer]
def decode_tokens(
tokenizer: AnyTokenizer,
token_ids: list[int],
*,
skip_special_tokens: bool = False,
) -> str:
"""
Backend-agnostic equivalent of HF's
:code:`tokenizer.decode(token_ids, skip_special_tokens=...)`.
"""
return tokenizer.decode(token_ids, skip_special_tokens=skip_special_tokens)
def encode_tokens( def encode_tokens(
tokenizer: AnyTokenizer, tokenizer: AnyTokenizer,
text: str, text: str,
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment