Unverified commit d1ca7df8, authored by Cyrus Leung, committed by GitHub
Browse files

[VLM] Merged multi-modal processor for InternVL-based models (#12553)


Signed-off-by: DarkLight1337 <tlleungac@connect.ust.hk>
Signed-off-by: Isotr0py <2037008807@qq.com>
Co-authored-by: Isotr0py <2037008807@qq.com>
parent 96b23621
...@@ -62,7 +62,11 @@ class LlavaNextVideoProcessingInfo(BaseProcessingInfo): ...@@ -62,7 +62,11 @@ class LlavaNextVideoProcessingInfo(BaseProcessingInfo):
def get_supported_mm_limits(self) -> Mapping[str, Optional[int]]: def get_supported_mm_limits(self) -> Mapping[str, Optional[int]]:
return {"video": 1} return {"video": 1}
def get_mm_max_tokens_per_item(self, seq_len: int) -> Mapping[str, int]: def get_mm_max_tokens_per_item(
self,
seq_len: int,
mm_counts: Mapping[str, int],
) -> Mapping[str, int]:
target_width, target_height = self.get_image_size_with_most_features() target_width, target_height = self.get_image_size_with_most_features()
max_video_tokens = self.get_num_video_tokens( max_video_tokens = self.get_num_video_tokens(
......
...@@ -103,7 +103,11 @@ class LlavaOnevisionProcessingInfo(LlavaNextProcessingInfo): ...@@ -103,7 +103,11 @@ class LlavaOnevisionProcessingInfo(LlavaNextProcessingInfo):
def get_supported_mm_limits(self) -> Mapping[str, Optional[int]]: def get_supported_mm_limits(self) -> Mapping[str, Optional[int]]:
return {"image": None, "video": None} return {"image": None, "video": None}
def get_mm_max_tokens_per_item(self, seq_len: int) -> Mapping[str, int]: def get_mm_max_tokens_per_item(
self,
seq_len: int,
mm_counts: Mapping[str, int],
) -> Mapping[str, int]:
return { return {
"image": self.get_max_image_tokens(), "image": self.get_max_image_tokens(),
"video": self.get_max_video_tokens(seq_len), "video": self.get_max_video_tokens(seq_len),
......
...@@ -23,7 +23,6 @@ ...@@ -23,7 +23,6 @@
# limitations under the License. # limitations under the License.
"""Inference-only MiniCPM-O model compatible with HuggingFace weights.""" """Inference-only MiniCPM-O model compatible with HuggingFace weights."""
from functools import partial from functools import partial
from itertools import accumulate
from typing import (Any, Dict, Iterable, List, Literal, Mapping, Optional, Set, from typing import (Any, Dict, Iterable, List, Literal, Mapping, Optional, Set,
Tuple, TypedDict, Union) Tuple, TypedDict, Union)
...@@ -138,11 +137,15 @@ class MiniCPMOProcessingInfo(MiniCPMVProcessingInfo): ...@@ -138,11 +137,15 @@ class MiniCPMOProcessingInfo(MiniCPMVProcessingInfo):
def get_supported_mm_limits(self) -> Mapping[str, Optional[int]]: def get_supported_mm_limits(self) -> Mapping[str, Optional[int]]:
return {"image": None, "video": None, "audio": None} return {"image": None, "video": None, "audio": None}
def get_mm_max_tokens_per_item(self, seq_len: int) -> Mapping[str, int]: def get_mm_max_tokens_per_item(
self,
seq_len: int,
mm_counts: Mapping[str, int],
) -> Mapping[str, int]:
return { return {
"image": self.get_max_image_tokens(), "image": self.get_max_image_tokens(),
"audio": self.get_max_audio_tokens(), "audio": self.get_max_audio_tokens(),
"video": self.get_max_video_tokens(seq_len) "video": self.get_max_video_tokens(seq_len),
} }
def get_default_audio_pool_step(self) -> int: def get_default_audio_pool_step(self) -> int:
...@@ -369,23 +372,18 @@ class MiniCPMOMultiModalProcessor( ...@@ -369,23 +372,18 @@ class MiniCPMOMultiModalProcessor(
hf_inputs, hf_inputs,
hf_processor_mm_kwargs: Mapping[str, object], hf_processor_mm_kwargs: Mapping[str, object],
) -> Mapping[str, MultiModalFieldConfig]: ) -> Mapping[str, MultiModalFieldConfig]:
audio_num_slices = hf_inputs.get("audio_num_slices", torch.empty(0))
def get_slices(num_slices: List[int]) -> List[int]:
slice_indices = [0] + list(accumulate(num_slices))
slices = [(slice_indices[i], slice_indices[i + 1])
for i in range(len(num_slices))]
return [slice(*slice_item) for slice_item in slices]
audio_slices = get_slices(
hf_inputs.get("audio_num_slices", torch.empty(0)))
return dict( return dict(
**super()._get_mm_fields_config(hf_inputs, hf_processor_mm_kwargs), **super()._get_mm_fields_config(hf_inputs, hf_processor_mm_kwargs),
audio_features=MultiModalFieldConfig.flat("audio", audio_slices), audio_features=MultiModalFieldConfig.flat_from_sizes(
audio_feature_lens=MultiModalFieldConfig.flat( "audio", audio_num_slices),
"audio", audio_slices), audio_feature_lens=MultiModalFieldConfig.flat_from_sizes(
"audio", audio_num_slices),
audio_num_slices=MultiModalFieldConfig.batched("audio"), audio_num_slices=MultiModalFieldConfig.batched("audio"),
audio_orders_in_mm_data=MultiModalFieldConfig.batched("audio"), audio_orders_in_mm_data=MultiModalFieldConfig.batched("audio"),
audio_embeds=MultiModalFieldConfig.flat("audio", audio_slices)) audio_embeds=MultiModalFieldConfig.flat_from_sizes(
"audio", audio_num_slices))
class MultiModalProjector(nn.Module): class MultiModalProjector(nn.Module):
......
...@@ -26,7 +26,6 @@ import math ...@@ -26,7 +26,6 @@ import math
import re import re
from collections import Counter from collections import Counter
from functools import cached_property, partial from functools import cached_property, partial
from itertools import accumulate
from typing import (Any, Callable, Dict, Iterable, List, Literal, Mapping, from typing import (Any, Callable, Dict, Iterable, List, Literal, Mapping,
Optional, Set, Tuple, TypedDict, Union) Optional, Set, Tuple, TypedDict, Union)
...@@ -365,7 +364,11 @@ class MiniCPMVProcessingInfo(BaseProcessingInfo): ...@@ -365,7 +364,11 @@ class MiniCPMVProcessingInfo(BaseProcessingInfo):
else: else:
return {"image": None} return {"image": None}
def get_mm_max_tokens_per_item(self, seq_len: int) -> Mapping[str, int]: def get_mm_max_tokens_per_item(
self,
seq_len: int,
mm_counts: Mapping[str, int],
) -> Mapping[str, int]:
mm_max_tokens = {"image": self.get_max_image_tokens()} mm_max_tokens = {"image": self.get_max_image_tokens()}
if self.get_model_version() == (2, 6): if self.get_model_version() == (2, 6):
mm_max_tokens["video"] = self.get_max_video_tokens(seq_len) mm_max_tokens["video"] = self.get_max_video_tokens(seq_len)
...@@ -761,30 +764,25 @@ class MiniCPMVMultiModalProcessor( ...@@ -761,30 +764,25 @@ class MiniCPMVMultiModalProcessor(
hf_inputs, hf_inputs,
hf_processor_mm_kwargs: Mapping[str, object], hf_processor_mm_kwargs: Mapping[str, object],
) -> Mapping[str, MultiModalFieldConfig]: ) -> Mapping[str, MultiModalFieldConfig]:
image_num_slices = hf_inputs.get("image_num_slices", torch.empty(0))
def get_slices(num_slices: List[int]) -> List[int]: video_num_slices = hf_inputs.get("video_num_slices", torch.empty(0))
slice_indices = [0] + list(accumulate(num_slices))
slices = [(slice_indices[i], slice_indices[i + 1]) return dict(pixel_values=MultiModalFieldConfig.flat_from_sizes(
for i in range(len(num_slices))] "image", image_num_slices),
return [slice(*slice_item) for slice_item in slices] image_sizes=MultiModalFieldConfig.batched("image"),
tgt_sizes=MultiModalFieldConfig.flat_from_sizes(
image_slices = get_slices( "image", image_num_slices),
hf_inputs.get("image_num_slices", torch.empty(0))) image_num_slices=MultiModalFieldConfig.batched("image"),
video_slices = get_slices( image_embeds=MultiModalFieldConfig.flat_from_sizes(
hf_inputs.get("video_num_slices", torch.empty(0))) "image", image_num_slices),
video_pixel_values=MultiModalFieldConfig.flat_from_sizes(
return dict( "video", video_num_slices),
pixel_values=MultiModalFieldConfig.flat("image", image_slices), video_image_sizes=MultiModalFieldConfig.batched("video"),
image_sizes=MultiModalFieldConfig.batched("image"), video_tgt_sizes=MultiModalFieldConfig.flat_from_sizes(
tgt_sizes=MultiModalFieldConfig.flat("image", image_slices), "video", video_num_slices),
image_num_slices=MultiModalFieldConfig.batched("image"), video_embeds=MultiModalFieldConfig.flat_from_sizes(
image_embeds=MultiModalFieldConfig.flat("image", image_slices), "video", video_num_slices),
video_pixel_values=MultiModalFieldConfig.flat( video_num_slices=MultiModalFieldConfig.batched("video"))
"video", video_slices),
video_image_sizes=MultiModalFieldConfig.batched("video"),
video_tgt_sizes=MultiModalFieldConfig.flat("video", video_slices),
video_embeds=MultiModalFieldConfig.flat("video", video_slices),
video_num_slices=MultiModalFieldConfig.batched("video"))
def apply( def apply(
self, self,
......
...@@ -6,44 +6,190 @@ ...@@ -6,44 +6,190 @@
# Copyright (c) 2024 NVIDIA # Copyright (c) 2024 NVIDIA
# Licensed under Apache 2.0 License [see LICENSE for details] # Licensed under Apache 2.0 License [see LICENSE for details]
# -------------------------------------------------------- # --------------------------------------------------------
from typing import Optional from typing import Mapping, Optional
import torch
import torch.nn as nn import torch.nn as nn
from transformers import PretrainedConfig from transformers import PretrainedConfig
from vllm.inputs import INPUT_REGISTRY
from vllm.model_executor.layers.quantization import QuantizationConfig from vllm.model_executor.layers.quantization import QuantizationConfig
from vllm.multimodal import MULTIMODAL_REGISTRY from vllm.multimodal import MULTIMODAL_REGISTRY
from vllm.multimodal.inputs import MultiModalKwargs
from vllm.multimodal.parse import (ImageEmbeddingItems, ImageProcessorItems,
MultiModalDataItems)
from vllm.multimodal.processing import (PromptReplacement,
PromptReplacementDetails)
from vllm.multimodal.profiling import ProcessorInputs
from .intern_vit import InternVisionModel from .intern_vit import InternVisionModel
from .internvl import (InternVLChatModel, InternVLInputPipeline, from .internvl import (BaseInternVLProcessingInfo, BaseInternVLProcessor,
get_max_internvl_image_tokens) InternVLChatModel, InternVLDummyInputsBuilder,
InternVLMultiModalProcessor)
IMG_START = '<|vision_start|>' IMG_PAD = "<|vision_pad|>"
IMG_END = '<|vision_end|>'
IMG_CONTEXT = '<|vision_pad|>'
class NVLMInputPipeline(InternVLInputPipeline): class NVLMProcessor(BaseInternVLProcessor):
@property
def image_token_id(self) -> int:
    """Return the id that the tokenizer's vocabulary assigns to ``IMG_PAD``."""
    vocab = self.tokenizer.get_vocab()
    return vocab[IMG_PAD]
def get_image_repl_features(
self,
feature_size: int,
num_patches: Optional[int],
) -> str:
if num_patches is None:
raise NotImplementedError("Embedding inputs are not supported")
tile_pos_identifiers = [f"<tile_{i}>" for i in range(1, num_patches)]
if self.use_thumbnail and num_patches != 1:
tile_pos_identifiers += ["<tile_global_thumbnail>"]
def _create_image_prompt(self, feature_size: int, num_patches: int) -> str:
tile_pos_identifiers = ([f"<tile_{i}>"
for i in range(1, num_patches)] +
["<tile_global_thumbnail>"])
context_size = feature_size // num_patches context_size = feature_size // num_patches
features = "".join(identifier + IMG_PAD * context_size
for identifier in tile_pos_identifiers)
# We include the start and end as well because "<Image><tile" is
# tokenized as ["<Image", "><", "tile"], resulting in assertion error
# when trying to find "<tile" as a subsequence of "<Image><tile"
return "<Image>" + features + "</Image>"
def get_image_repl_full(
self,
feature_size: int,
num_patches: Optional[int],
) -> str:
return self.get_image_repl_features(feature_size, num_patches)
class NVLMProcessingInfo(BaseInternVLProcessingInfo):
    """Processing info for NVLM.

    Supplies the NVLM-specific processor and an upper bound on the number
    of prompt tokens a single image can occupy.
    """

    def get_hf_processor(
        self,
        *,
        max_dynamic_patch: Optional[int] = None,
        dynamic_image_size: Optional[bool] = None,
    ) -> NVLMProcessor:
        """Build an ``NVLMProcessor`` from this model's config and tokenizer.

        Both keyword-only arguments are forwarded to the processor unchanged.
        """
        config = self.get_hf_config()
        tokenizer = self.get_tokenizer()
        return NVLMProcessor(
            config,
            tokenizer,
            max_dynamic_patch=max_dynamic_patch,
            dynamic_image_size=dynamic_image_size,
        )

    def get_max_image_tokens(self) -> int:
        """Return the maximum number of tokens one image can take up.

        This is the parent class's bound plus the tokens spent on the
        ``<Image>`` / ``</Image>`` wrapper and the tile position identifiers.
        """
        processor = self.get_hf_processor()
        tokenizer = processor.tokenizer
        patch_limit = processor.max_dynamic_patch

        # One "<tile_k>" identifier for k = 1..max_dynamic_patch.
        # max_dynamic_patch in the config doesn't include the thumbnail
        # patch, so the thumbnail identifier is appended separately.
        identifiers = [f"<tile_{k}>" for k in range(1, patch_limit + 1)]
        if processor.use_thumbnail and patch_limit != 1:
            identifiers.append("<tile_global_thumbnail>")

        # "<Image><tile" is tokenized as ["<Image", "><", "tile"], so the
        # first identifier is encoded together with the opening tag.
        opening = "<Image>" + identifiers[0]
        remaining = identifiers[1:]
        closing = "</Image>"

        opening_len = len(tokenizer.encode(opening))
        closing_len = len(tokenizer.encode(closing))
        tiles_len = 0
        for identifier in remaining:
            tiles_len += len(tokenizer.encode(identifier))

        extra_tokens = opening_len + closing_len + tiles_len
        return super().get_max_image_tokens() + extra_tokens
class NVLMDummyInputsBuilder(InternVLDummyInputsBuilder[NVLMProcessingInfo]):
    """Builds dummy processor inputs (images plus placeholder prompt) for NVLM."""

    def get_dummy_processor_inputs(
        self,
        seq_len: int,
        mm_counts: Mapping[str, int],
    ) -> ProcessorInputs:
        """Create dummy inputs containing ``mm_counts["image"]`` images.

        Images are sized by ``get_image_size_with_most_features()``; the
        prompt holds one ``<image>`` placeholder line per image. ``seq_len``
        is not used here.
        """
        width, height = self.info.get_image_size_with_most_features()
        image_count = mm_counts.get("image", 0)

        dummy_images = self._get_dummy_images(
            width=width,
            height=height,
            num_images=image_count,
        )

        # The newline is necessary to separate ">" of the current item
        # and "<" of the next item
        placeholder_prompt = "<image>\n" * image_count

        return ProcessorInputs(
            prompt_text=placeholder_prompt,
            mm_data={"image": dummy_images},
        )
class NVLMMultiModalProcessor(InternVLMultiModalProcessor[NVLMProcessingInfo]):
return '<Image>' + ''.join( def _get_prompt_replacements(
tile_pos_identifier + self.img_context_token * context_size self,
for tile_pos_identifier in tile_pos_identifiers) + '</Image>' mm_items: MultiModalDataItems,
hf_processor_mm_kwargs: Mapping[str, object],
out_mm_kwargs: MultiModalKwargs,
) -> list[PromptReplacement]:
hf_processor = self.info.get_hf_processor(**hf_processor_mm_kwargs)
if "image_num_patches" in out_mm_kwargs:
image_num_patches = out_mm_kwargs["image_num_patches"]
assert isinstance(image_num_patches, torch.Tensor)
image_num_patches = image_num_patches.tolist()
elif "image_embeds" in out_mm_kwargs:
# TODO: Use image size information in dictionary embedding inputs
# to compute num_patches (similar to Qwen2-VL)
image_num_patches = [None] * len(out_mm_kwargs["image_embeds"])
else:
image_num_patches = []
def get_replacement_nvlm(item_idx: int):
images = mm_items.get_items(
"image", (ImageEmbeddingItems, ImageProcessorItems))
if isinstance(images, ImageEmbeddingItems):
feature_size = images.get_feature_size(item_idx)
else:
image_size = images.get_image_size(item_idx)
feature_size = self.info.get_num_image_tokens(
image_width=image_size.width,
image_height=image_size.height,
processor=hf_processor,
)
num_patches = image_num_patches[item_idx]
if num_patches is not None:
assert isinstance(num_patches, int)
return PromptReplacementDetails(
full=hf_processor.get_image_repl_full(feature_size,
num_patches) + "\n",
features=hf_processor.get_image_repl_features(
feature_size, num_patches) + "\n",
)
input_pipeline = NVLMInputPipeline(IMG_START, IMG_END, IMG_CONTEXT) # See note in dummy data regarding why we have the extra newline
return [
PromptReplacement(
modality="image",
target="<image>\n",
replacement=get_replacement_nvlm,
)
]
@MULTIMODAL_REGISTRY.register_image_input_mapper(input_pipeline.input_mapper) @MULTIMODAL_REGISTRY.register_processor(NVLMMultiModalProcessor,
@MULTIMODAL_REGISTRY.register_max_image_tokens(get_max_internvl_image_tokens) info=NVLMProcessingInfo,
@INPUT_REGISTRY.register_dummy_data(input_pipeline.dummy_data) dummy_inputs=NVLMDummyInputsBuilder)
@INPUT_REGISTRY.register_input_processor(input_pipeline.input_processor)
class NVLM_D_Model(InternVLChatModel): class NVLM_D_Model(InternVLChatModel):
def _init_mlp1(self, config: PretrainedConfig) -> nn.Sequential: def _init_mlp1(self, config: PretrainedConfig) -> nn.Sequential:
......
...@@ -322,7 +322,11 @@ class Phi3VProcessingInfo(BaseProcessingInfo): ...@@ -322,7 +322,11 @@ class Phi3VProcessingInfo(BaseProcessingInfo):
def get_supported_mm_limits(self) -> Mapping[str, Optional[int]]: def get_supported_mm_limits(self) -> Mapping[str, Optional[int]]:
return {"image": None} return {"image": None}
def get_mm_max_tokens_per_item(self, seq_len: int) -> Mapping[str, int]: def get_mm_max_tokens_per_item(
self,
seq_len: int,
mm_counts: Mapping[str, int],
) -> Mapping[str, int]:
target_width, target_height = self.get_image_size_with_most_features() target_width, target_height = self.get_image_size_with_most_features()
max_image_tokens = self.get_num_image_tokens( max_image_tokens = self.get_num_image_tokens(
......
...@@ -779,7 +779,11 @@ class QWenVLProcessingInfo(BaseProcessingInfo): ...@@ -779,7 +779,11 @@ class QWenVLProcessingInfo(BaseProcessingInfo):
def get_supported_mm_limits(self) -> Mapping[str, Optional[int]]: def get_supported_mm_limits(self) -> Mapping[str, Optional[int]]:
return {"image": None} return {"image": None}
def get_mm_max_tokens_per_item(self, seq_len: int) -> Mapping[str, int]: def get_mm_max_tokens_per_item(
self,
seq_len: int,
mm_counts: Mapping[str, int],
) -> Mapping[str, int]:
return {"image": self.get_num_image_tokens()} return {"image": self.get_num_image_tokens()}
def get_num_image_tokens(self) -> int: def get_num_image_tokens(self) -> int:
...@@ -799,13 +803,13 @@ class QWenVLDummyInputsBuilder(BaseDummyInputsBuilder[QWenVLProcessingInfo]): ...@@ -799,13 +803,13 @@ class QWenVLDummyInputsBuilder(BaseDummyInputsBuilder[QWenVLProcessingInfo]):
vision_config = hf_config.visual vision_config = hf_config.visual
max_image_size = vision_config["image_size"] target_width = target_height = vision_config["image_size"]
num_images = mm_counts.get("image", 0) num_images = mm_counts.get("image", 0)
mm_data = { mm_data = {
"image": "image":
self._get_dummy_images(width=max_image_size, self._get_dummy_images(width=target_width,
height=max_image_size, height=target_height,
num_images=num_images) num_images=num_images)
} }
......
...@@ -110,7 +110,11 @@ class Qwen2AudioProcessingInfo(BaseProcessingInfo): ...@@ -110,7 +110,11 @@ class Qwen2AudioProcessingInfo(BaseProcessingInfo):
def get_supported_mm_limits(self) -> Mapping[str, Optional[int]]: def get_supported_mm_limits(self) -> Mapping[str, Optional[int]]:
return {"audio": None} return {"audio": None}
def get_mm_max_tokens_per_item(self, seq_len: int) -> Mapping[str, int]: def get_mm_max_tokens_per_item(
self,
seq_len: int,
mm_counts: Mapping[str, int],
) -> Mapping[str, int]:
hf_config = self.get_hf_config() hf_config = self.get_hf_config()
max_source_positions = hf_config.audio_config.max_source_positions max_source_positions = hf_config.audio_config.max_source_positions
max_output_lengths = (max_source_positions - 2) // 2 + 1 max_output_lengths = (max_source_positions - 2) // 2 + 1
......
...@@ -758,7 +758,11 @@ class Qwen2VLProcessingInfo(BaseProcessingInfo): ...@@ -758,7 +758,11 @@ class Qwen2VLProcessingInfo(BaseProcessingInfo):
def get_supported_mm_limits(self) -> Mapping[str, Optional[int]]: def get_supported_mm_limits(self) -> Mapping[str, Optional[int]]:
return {"image": None, "video": None} return {"image": None, "video": None}
def get_mm_max_tokens_per_item(self, seq_len: int) -> Mapping[str, int]: def get_mm_max_tokens_per_item(
self,
seq_len: int,
mm_counts: Mapping[str, int],
) -> Mapping[str, int]:
return { return {
"image": self.get_max_image_tokens(), "image": self.get_max_image_tokens(),
"video": self.get_max_video_tokens(seq_len), "video": self.get_max_video_tokens(seq_len),
...@@ -989,26 +993,21 @@ class Qwen2VLMultiModalProcessor(BaseMultiModalProcessor[Qwen2VLProcessingInfo] ...@@ -989,26 +993,21 @@ class Qwen2VLMultiModalProcessor(BaseMultiModalProcessor[Qwen2VLProcessingInfo]
hf_processor_mm_kwargs: Mapping[str, object], hf_processor_mm_kwargs: Mapping[str, object],
) -> Mapping[str, MultiModalFieldConfig]: ) -> Mapping[str, MultiModalFieldConfig]:
image_grid_thw = hf_inputs.get("image_grid_thw", torch.empty((0, 3))) image_grid_thw = hf_inputs.get("image_grid_thw", torch.empty((0, 3)))
image_slice_idxs = [0] + image_grid_thw.prod(-1).cumsum_(0).tolist() image_grid_sizes = image_grid_thw.prod(-1)
image_slices = [
slice(image_slice_idxs[i], image_slice_idxs[i + 1])
for i in range(len(image_grid_thw))
]
video_grid_thw = hf_inputs.get("video_grid_thw", torch.empty((0, 3))) video_grid_thw = hf_inputs.get("video_grid_thw", torch.empty((0, 3)))
video_slice_idxs = [0] + video_grid_thw.prod(-1).cumsum_(0).tolist() video_grid_sizes = video_grid_thw.prod(-1)
video_slices = [
slice(video_slice_idxs[i], video_slice_idxs[i + 1])
for i in range(len(video_grid_thw))
]
return dict( return dict(
pixel_values=MultiModalFieldConfig.flat("image", image_slices), pixel_values=MultiModalFieldConfig.flat_from_sizes(
image_embeds=MultiModalFieldConfig.flat("image", image_slices), "image", image_grid_sizes),
image_embeds=MultiModalFieldConfig.flat_from_sizes(
"image", image_grid_sizes),
image_grid_thw=MultiModalFieldConfig.batched("image"), image_grid_thw=MultiModalFieldConfig.batched("image"),
pixel_values_videos=MultiModalFieldConfig.flat( pixel_values_videos=MultiModalFieldConfig.flat_from_sizes(
"video", video_slices), "video", video_grid_sizes),
video_embeds=MultiModalFieldConfig.flat("video", video_slices), video_embeds=MultiModalFieldConfig.flat_from_sizes(
"video", video_grid_sizes),
video_grid_thw=MultiModalFieldConfig.batched("video"), video_grid_thw=MultiModalFieldConfig.batched("video"),
) )
......
...@@ -92,7 +92,11 @@ class UltravoxProcessingInfo(BaseProcessingInfo): ...@@ -92,7 +92,11 @@ class UltravoxProcessingInfo(BaseProcessingInfo):
def get_supported_mm_limits(self) -> Mapping[str, Optional[int]]: def get_supported_mm_limits(self) -> Mapping[str, Optional[int]]:
return {"audio": None} return {"audio": None}
def get_mm_max_tokens_per_item(self, seq_len: int) -> Mapping[str, int]: def get_mm_max_tokens_per_item(
self,
seq_len: int,
mm_counts: Mapping[str, int],
) -> Mapping[str, int]:
feature_extractor = self.get_feature_extractor() feature_extractor = self.get_feature_extractor()
max_audio_tokens = math.ceil(feature_extractor.chunk_length * max_audio_tokens = math.ceil(feature_extractor.chunk_length *
_AUDIO_TOKENS_PER_SECOND) _AUDIO_TOKENS_PER_SECOND)
......
...@@ -4,6 +4,7 @@ from abc import ABC, abstractmethod ...@@ -4,6 +4,7 @@ from abc import ABC, abstractmethod
from collections import UserDict, defaultdict from collections import UserDict, defaultdict
from collections.abc import Mapping, Sequence from collections.abc import Mapping, Sequence
from dataclasses import dataclass from dataclasses import dataclass
from itertools import accumulate
from typing import (TYPE_CHECKING, Any, Literal, Optional, TypedDict, TypeVar, from typing import (TYPE_CHECKING, Any, Literal, Optional, TypedDict, TypeVar,
Union, cast, final) Union, cast, final)
...@@ -258,6 +259,16 @@ class MultiModalFieldConfig: ...@@ -258,6 +259,16 @@ class MultiModalFieldConfig:
slices=slices, slices=slices,
) )
@staticmethod
def flat_from_sizes(modality: str, size_per_item: torch.Tensor):
    """Build a flat field config from the size of each item.

    The running totals of ``size_per_item`` are used as boundaries:
    item ``i`` is given the slice that starts where item ``i - 1`` ended
    (0 for the first item) and spans ``size_per_item[i]`` entries. The
    resulting slices are handed to ``MultiModalFieldConfig.flat``.
    """
    # Boundaries are [0, s0, s0 + s1, ...]; consecutive pairs delimit items.
    boundaries = [0]
    boundaries.extend(accumulate(size_per_item))

    item_slices = [
        slice(start, stop)
        for start, stop in zip(boundaries[:-1], boundaries[1:])
    ]

    return MultiModalFieldConfig.flat(modality, item_slices)
def __init__( def __init__(
self, self,
field_cls: type[BaseMultiModalField], field_cls: type[BaseMultiModalField],
......
...@@ -680,7 +680,11 @@ class BaseProcessingInfo: ...@@ -680,7 +680,11 @@ class BaseProcessingInfo:
raise NotImplementedError raise NotImplementedError
@abstractmethod @abstractmethod
def get_mm_max_tokens_per_item(self, seq_len: int) -> Mapping[str, int]: def get_mm_max_tokens_per_item(
self,
seq_len: int,
mm_counts: Mapping[str, int],
) -> Mapping[str, int]:
""" """
Get the maximum possible number of tokens per data item Get the maximum possible number of tokens per data item
for each modality. for each modality.
......
...@@ -151,7 +151,8 @@ class MultiModalProfiler(Generic[_I]): ...@@ -151,7 +151,8 @@ class MultiModalProfiler(Generic[_I]):
mm_counts = self.get_mm_limits() mm_counts = self.get_mm_limits()
info = self.processing_info info = self.processing_info
mm_max_tokens_per_item = info.get_mm_max_tokens_per_item(seq_len) mm_max_tokens_per_item = info.get_mm_max_tokens_per_item(
seq_len, mm_counts)
if mm_counts.keys() != mm_max_tokens_per_item.keys(): if mm_counts.keys() != mm_max_tokens_per_item.keys():
raise AssertionError( raise AssertionError(
......
...@@ -264,7 +264,9 @@ class MultiModalRegistry: ...@@ -264,7 +264,9 @@ class MultiModalRegistry:
) )
processor = self.create_processor(model_config, tokenizer) processor = self.create_processor(model_config, tokenizer)
seq_len = model_config.max_model_len seq_len = model_config.max_model_len
return processor.info.get_mm_max_tokens_per_item(seq_len) mm_limits = self.get_mm_limits_per_prompt(model_config)
return processor.info.get_mm_max_tokens_per_item(
seq_len, mm_limits)
return { return {
key: plugin.get_max_multimodal_tokens(model_config) key: plugin.get_max_multimodal_tokens(model_config)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment