Unverified commit 56d4aefa, authored by Cyrus Leung and committed by GitHub
Browse files

[VLM] Avoid unnecessary dummy multimodal data during processing (#16416)


Signed-off-by: DarkLight1337 <tlleungac@connect.ust.hk>
parent dd143ef5
......@@ -15,12 +15,11 @@ from transformers import PretrainedConfig
from vllm.model_executor.layers.quantization import QuantizationConfig
from vllm.multimodal import MULTIMODAL_REGISTRY
from vllm.multimodal.inputs import MultiModalKwargs
from vllm.multimodal.inputs import MultiModalDataDict, MultiModalKwargs
from vllm.multimodal.parse import (ImageEmbeddingItems, ImageProcessorItems,
MultiModalDataItems)
from vllm.multimodal.processing import (PromptReplacement, PromptUpdate,
PromptUpdateDetails)
from vllm.multimodal.profiling import ProcessorInputs
from .intern_vit import InternVisionModel
from .internvl import (BaseInternVLProcessingInfo, BaseInternVLProcessor,
......@@ -87,29 +86,29 @@ class NVLMProcessingInfo(BaseInternVLProcessingInfo):
class NVLMDummyInputsBuilder(InternVLDummyInputsBuilder[NVLMProcessingInfo]):
def get_dummy_processor_inputs(
def get_dummy_text(self, mm_counts: Mapping[str, int]) -> str:
num_images = mm_counts.get("image", 0)
# The newline is necessary to separate ">" of the current item
# and "<" of the next item
return "<image>\n" * num_images
def get_dummy_mm_data(
self,
seq_len: int,
mm_counts: Mapping[str, int],
) -> ProcessorInputs:
) -> MultiModalDataDict:
target_width, target_height = \
self.info.get_image_size_with_most_features()
num_images = mm_counts.get("image", 0)
mm_data = {
return {
"image":
self._get_dummy_images(width=target_width,
height=target_height,
num_images=num_images)
}
return ProcessorInputs(
# The newline is necessary to separate ">" of the current item
# and "<" of the next item
prompt_text="<image>\n" * num_images,
mm_data=mm_data,
)
class NVLMMultiModalProcessor(InternVLMultiModalProcessor[NVLMProcessingInfo]):
......
......@@ -19,7 +19,7 @@ from vllm.multimodal.processing import (BaseMultiModalProcessor,
BaseProcessingInfo, PromptIndexTargets,
PromptInsertion, PromptUpdate,
PromptUpdateDetails)
from vllm.multimodal.profiling import BaseDummyInputsBuilder, ProcessorInputs
from vllm.multimodal.profiling import BaseDummyInputsBuilder
from vllm.sequence import IntermediateTensors
from .interfaces import MultiModalEmbeddings, SupportsMultiModal, SupportsPP
......@@ -90,29 +90,27 @@ class PaliGemmaProcessingInfo(BaseProcessingInfo):
class PaliGemmaDummyInputsBuilder(
BaseDummyInputsBuilder[PaliGemmaProcessingInfo]):
def get_dummy_processor_inputs(
def get_dummy_text(self, mm_counts: Mapping[str, int]) -> str:
return ""
def get_dummy_mm_data(
self,
seq_len: int,
mm_counts: Mapping[str, int],
) -> ProcessorInputs:
) -> MultiModalDataDict:
hf_config = self.info.get_hf_config()
vision_config = hf_config.vision_config
max_image_size = vision_config.image_size
num_images = mm_counts.get("image", 0)
mm_data = {
return {
"image":
self._get_dummy_images(width=max_image_size,
height=max_image_size,
num_images=num_images)
}
return ProcessorInputs(
prompt_text="",
mm_data=mm_data,
)
class PaliGemmaMultiModalProcessor(
BaseMultiModalProcessor[PaliGemmaProcessingInfo]):
......
......@@ -32,7 +32,8 @@ from vllm.model_executor.layers.vocab_parallel_embedding import (
VocabParallelEmbedding)
from vllm.model_executor.sampling_metadata import SamplingMetadata
from vllm.multimodal import MULTIMODAL_REGISTRY
from vllm.multimodal.inputs import MultiModalFieldConfig, MultiModalKwargs
from vllm.multimodal.inputs import (MultiModalDataDict, MultiModalFieldConfig,
MultiModalKwargs)
from vllm.multimodal.parse import (ImageEmbeddingItems, ImageProcessorItems,
ImageSize, MultiModalDataItems)
# yapf conflicts with isort for this block
......@@ -42,7 +43,7 @@ from vllm.multimodal.processing import (BaseMultiModalProcessor,
PlaceholderFeaturesInfo,
PromptReplacement, PromptUpdate)
# yapf: enable
from vllm.multimodal.profiling import BaseDummyInputsBuilder, ProcessorInputs
from vllm.multimodal.profiling import BaseDummyInputsBuilder
from vllm.sequence import IntermediateTensors
from vllm.utils import is_list_of
......@@ -343,31 +344,31 @@ class Phi3VProcessingInfo(BaseProcessingInfo):
class Phi3VDummyInputsBuilder(BaseDummyInputsBuilder[Phi3VProcessingInfo]):
def get_dummy_processor_inputs(
def get_dummy_text(self, mm_counts: Mapping[str, int]) -> str:
num_images = mm_counts.get("image", 0)
hf_processor = self.info.get_hf_processor()
image_tokens: list[str] = hf_processor.img_tokens # type: ignore
return "".join(image_tokens[:num_images])
def get_dummy_mm_data(
self,
seq_len: int,
mm_counts: Mapping[str, int],
) -> ProcessorInputs:
) -> MultiModalDataDict:
num_images = mm_counts.get("image", 0)
target_width, target_height = \
self.info.get_image_size_with_most_features()
mm_data = {
return {
"image":
self._get_dummy_images(width=target_width,
height=target_height,
num_images=num_images)
}
hf_processor = self.info.get_hf_processor()
image_tokens: list[str] = hf_processor.img_tokens # type: ignore
return ProcessorInputs(
prompt_text="".join(image_tokens[:num_images]),
mm_data=mm_data,
)
class Phi3VMultiModalProcessor(BaseMultiModalProcessor[Phi3VProcessingInfo]):
......
......@@ -32,13 +32,14 @@ from vllm.model_executor.layers.sampler import SamplerOutput, get_sampler
from vllm.model_executor.model_loader.weight_utils import default_weight_loader
from vllm.model_executor.sampling_metadata import SamplingMetadata
from vllm.multimodal import MULTIMODAL_REGISTRY, MultiModalKwargs
from vllm.multimodal.inputs import MultiModalFieldConfig, NestedTensors
from vllm.multimodal.inputs import (MultiModalDataDict, MultiModalFieldConfig,
NestedTensors)
from vllm.multimodal.parse import (ImageProcessorItems, ImageSize,
MultiModalDataItems)
from vllm.multimodal.processing import (BaseMultiModalProcessor,
BaseProcessingInfo, PromptReplacement,
PromptUpdate, PromptUpdateDetails)
from vllm.multimodal.profiling import BaseDummyInputsBuilder, ProcessorInputs
from vllm.multimodal.profiling import BaseDummyInputsBuilder
from vllm.sequence import IntermediateTensors
from vllm.transformers_utils.tokenizer import (MistralTokenizer,
cached_tokenizer_from_config)
......@@ -203,28 +204,26 @@ class PixtralProcessingInfo(BaseProcessingInfo):
class PixtralDummyInputsBuilder(BaseDummyInputsBuilder[PixtralProcessingInfo]):
def get_dummy_processor_inputs(
def get_dummy_text(self, mm_counts: Mapping[str, int]) -> str:
return ""
def get_dummy_mm_data(
self,
seq_len: int,
mm_counts: Mapping[str, int],
) -> ProcessorInputs:
) -> MultiModalDataDict:
num_images = mm_counts.get("image", 0)
target_width, target_height = \
self.info.get_image_size_with_most_features()
mm_data = {
return {
"image":
self._get_dummy_images(width=target_width,
height=target_height,
num_images=num_images)
}
return ProcessorInputs(
prompt_text="",
mm_data=mm_data,
)
class PixtralMultiModalProcessor(BaseMultiModalProcessor[PixtralProcessingInfo]
):
......
......@@ -35,7 +35,7 @@ from vllm.multimodal.inputs import (MultiModalDataDict, MultiModalFieldConfig,
from vllm.multimodal.parse import MultiModalDataItems
from vllm.multimodal.processing import (BaseMultiModalProcessor,
BaseProcessingInfo, PromptUpdate)
from vllm.multimodal.profiling import BaseDummyInputsBuilder, ProcessorInputs
from vllm.multimodal.profiling import BaseDummyInputsBuilder
from vllm.sequence import (IntermediateTensors, PoolerOutput,
PoolingSequenceGroupOutput)
......@@ -49,20 +49,21 @@ class PrithviGeoSpatialMAEProcessingInfo(BaseProcessingInfo):
class PrithviGeoSpatialMAEInputBuilder(
BaseDummyInputsBuilder[PrithviGeoSpatialMAEProcessingInfo]):
def get_dummy_processor_inputs(
def get_dummy_text(self, mm_counts: Mapping[str, int]) -> str:
return ""
def get_dummy_mm_data(
self,
seq_len: int,
mm_counts: Mapping[str, int],
) -> ProcessorInputs:
return ProcessorInputs(
prompt_text="",
# This model input is fixed and is in the form of a torch Tensor.
# The size of pixel_values might change in the cases where we resize
# the input but never exceeds the dimensions below.
mm_data={
"pixel_values": torch.full((1, 6, 512, 512), 1.0),
"location_coords": torch.full((1, 2), 1.0)
})
) -> MultiModalDataDict:
# This model input is fixed and is in the form of a torch Tensor.
# The size of pixel_values might change in the cases where we resize
# the input but never exceeds the dimensions below.
return {
"pixel_values": torch.full((1, 6, 512, 512), 1.0),
"location_coords": torch.full((1, 2), 1.0),
}
class PrithviGeoSpatialMAEMultiModalProcessor(BaseMultiModalProcessor):
......
......@@ -37,13 +37,14 @@ from vllm.config import VllmConfig
from vllm.model_executor.layers.sampler import SamplerOutput, get_sampler
from vllm.model_executor.sampling_metadata import SamplingMetadata
from vllm.multimodal import MULTIMODAL_REGISTRY
from vllm.multimodal.inputs import MultiModalFieldConfig, MultiModalKwargs
from vllm.multimodal.inputs import (MultiModalDataDict, MultiModalFieldConfig,
MultiModalKwargs)
from vllm.multimodal.parse import (AudioProcessorItems, MultiModalDataItems,
MultiModalDataParser)
from vllm.multimodal.processing import (BaseMultiModalProcessor,
BaseProcessingInfo, PromptReplacement,
PromptUpdate, PromptUpdateDetails)
from vllm.multimodal.profiling import BaseDummyInputsBuilder, ProcessorInputs
from vllm.multimodal.profiling import BaseDummyInputsBuilder
from vllm.sequence import IntermediateTensors
from .interfaces import MultiModalEmbeddings, SupportsMultiModal, SupportsPP
......@@ -113,27 +114,30 @@ class Qwen2AudioProcessingInfo(BaseProcessingInfo):
class Qwen2AudioDummyInputsBuilder(
BaseDummyInputsBuilder[Qwen2AudioProcessingInfo]):
def get_dummy_processor_inputs(
def get_dummy_text(self, mm_counts: Mapping[str, int]) -> str:
num_audios = mm_counts.get("audio", 0)
hf_processor = self.info.get_hf_processor()
audio_token = hf_processor.audio_token
return audio_token * num_audios
def get_dummy_mm_data(
self,
seq_len: int,
mm_counts: Mapping[str, int],
) -> ProcessorInputs:
) -> MultiModalDataDict:
feature_extractor = self.info.get_feature_extractor()
sampling_rate = feature_extractor.sampling_rate
audio_len = feature_extractor.chunk_length * sampling_rate
num_audios = mm_counts.get("audio", 0)
mm_data = {
return {
"audio":
self._get_dummy_audios(length=audio_len, num_audios=num_audios)
}
return ProcessorInputs(
prompt_text="<|AUDIO|>" * num_audios,
mm_data=mm_data,
)
class Qwen2AudioMultiModalProcessor(
BaseMultiModalProcessor[Qwen2AudioProcessingInfo]):
......
......@@ -56,15 +56,15 @@ from vllm.model_executor.model_loader.weight_utils import default_weight_loader
from vllm.model_executor.models.module_mapping import MultiModelKeys
from vllm.multimodal import MULTIMODAL_REGISTRY
from vllm.multimodal.inputs import (ImageItem, ModalityData,
MultiModalFieldConfig, MultiModalKwargs,
VideoItem)
MultiModalDataDict, MultiModalFieldConfig,
MultiModalKwargs, VideoItem)
from vllm.multimodal.parse import (DictEmbeddingItems, ImageSize,
ModalityDataItems, MultiModalDataItems,
MultiModalDataParser)
from vllm.multimodal.processing import (BaseMultiModalProcessor,
BaseProcessingInfo, PromptReplacement,
PromptUpdate)
from vllm.multimodal.profiling import BaseDummyInputsBuilder, ProcessorInputs
from vllm.multimodal.profiling import BaseDummyInputsBuilder
from vllm.platforms import _Backend
from vllm.sequence import IntermediateTensors
from vllm.transformers_utils.config import uses_mrope
......@@ -965,11 +965,7 @@ class Qwen2VLProcessingInfo(BaseProcessingInfo):
class Qwen2VLDummyInputsBuilder(BaseDummyInputsBuilder[Qwen2VLProcessingInfo]):
def get_dummy_processor_inputs(
self,
seq_len: int,
mm_counts: Mapping[str, int],
) -> ProcessorInputs:
def get_dummy_text(self, mm_counts: Mapping[str, int]) -> str:
num_images = mm_counts.get("image", 0)
num_videos = mm_counts.get("video", 0)
......@@ -977,12 +973,22 @@ class Qwen2VLDummyInputsBuilder(BaseDummyInputsBuilder[Qwen2VLProcessingInfo]):
image_token: str = hf_processor.image_token
video_token: str = hf_processor.video_token
return image_token * num_images + video_token * num_videos
def get_dummy_mm_data(
self,
seq_len: int,
mm_counts: Mapping[str, int],
) -> MultiModalDataDict:
num_images = mm_counts.get("image", 0)
num_videos = mm_counts.get("video", 0)
target_width, target_height = \
self.info.get_image_size_with_most_features()
target_num_frames = \
self.info.get_num_frames_with_most_features(seq_len, mm_counts)
mm_data = {
return {
"image":
self._get_dummy_images(width=target_width,
height=target_height,
......@@ -996,11 +1002,6 @@ class Qwen2VLDummyInputsBuilder(BaseDummyInputsBuilder[Qwen2VLProcessingInfo]):
)
}
return ProcessorInputs(
prompt_text=image_token * num_images + video_token * num_videos,
mm_data=mm_data,
)
class Qwen2VLMultiModalProcessor(BaseMultiModalProcessor[Qwen2VLProcessingInfo]
):
......
......@@ -32,12 +32,13 @@ from vllm.model_executor.layers.quantization import QuantizationConfig
from vllm.model_executor.layers.resampler import Resampler2, get_abs_pos
from vllm.model_executor.models.module_mapping import MultiModelKeys
from vllm.multimodal import MULTIMODAL_REGISTRY
from vllm.multimodal.inputs import MultiModalFieldConfig, MultiModalKwargs
from vllm.multimodal.inputs import (MultiModalDataDict, MultiModalFieldConfig,
MultiModalKwargs)
from vllm.multimodal.parse import MultiModalDataItems
from vllm.multimodal.processing import (BaseMultiModalProcessor,
BaseProcessingInfo, PromptReplacement,
PromptUpdate, PromptUpdateDetails)
from vllm.multimodal.profiling import BaseDummyInputsBuilder, ProcessorInputs
from vllm.multimodal.profiling import BaseDummyInputsBuilder
from vllm.sequence import IntermediateTensors
from .interfaces import (MultiModalEmbeddings, SupportsLoRA,
......@@ -542,34 +543,34 @@ class QwenVLProcessingInfo(BaseProcessingInfo):
class QwenVLDummyInputsBuilder(BaseDummyInputsBuilder[QwenVLProcessingInfo]):
def get_dummy_processor_inputs(
def get_dummy_text(self, mm_counts: Mapping[str, int]) -> str:
num_images = mm_counts.get("image", 0)
hf_processor = self.info.get_hf_processor()
img_start = hf_processor.image_start_tag
img_end = hf_processor.image_end_tag
return "".join(f"Picture {i}: {img_start}{img_end}\n"
for i in range(1, num_images + 1))
def get_dummy_mm_data(
self,
seq_len: int,
mm_counts: Mapping[str, int],
) -> ProcessorInputs:
) -> MultiModalDataDict:
hf_config = self.info.get_hf_config()
vision_config = hf_config.visual
processor = self.info.get_hf_processor()
img_start = processor.image_start_tag
img_end = processor.image_end_tag
target_width = target_height = vision_config["image_size"]
num_images = mm_counts.get("image", 0)
mm_data = {
return {
"image":
self._get_dummy_images(width=target_width,
height=target_height,
num_images=num_images)
}
return ProcessorInputs(
prompt_text="".join(f"Picture {i}: {img_start}{img_end}\n"
for i in range(1, num_images + 1)),
mm_data=mm_data,
)
class QwenVLMultiModalProcessor(BaseMultiModalProcessor[QwenVLProcessingInfo]):
......
......@@ -26,14 +26,14 @@ from vllm.model_executor.models.intern_vit import (InternVisionModel,
InternVisionPatchModel)
from vllm.model_executor.sampling_metadata import SamplingMetadata
from vllm.multimodal import MULTIMODAL_REGISTRY
from vllm.multimodal.inputs import (MultiModalFieldConfig, MultiModalKwargs,
NestedTensors)
from vllm.multimodal.inputs import (MultiModalDataDict, MultiModalFieldConfig,
MultiModalKwargs, NestedTensors)
from vllm.multimodal.parse import (ImageEmbeddingItems, ImageProcessorItems,
ImageSize, MultiModalDataItems)
from vllm.multimodal.processing import (BaseMultiModalProcessor,
BaseProcessingInfo, PromptReplacement,
PromptUpdate, PromptUpdateDetails)
from vllm.multimodal.profiling import BaseDummyInputsBuilder, ProcessorInputs
from vllm.multimodal.profiling import BaseDummyInputsBuilder
from vllm.sequence import IntermediateTensors
from vllm.transformers_utils.tokenizer import AnyTokenizer
......@@ -505,27 +505,27 @@ _I = TypeVar("_I", bound=BaseSkyworkR1VProcessingInfo)
class SkyworkR1VDummyInputsBuilder(BaseDummyInputsBuilder[_I]):
def get_dummy_processor_inputs(
def get_dummy_text(self, mm_counts: Mapping[str, int]) -> str:
num_images = mm_counts.get("image", 0)
return "<image>" * num_images
def get_dummy_mm_data(
self,
seq_len: int,
mm_counts: Mapping[str, int],
) -> ProcessorInputs:
) -> MultiModalDataDict:
target_width, target_height = \
self.info.get_image_size_with_most_features()
num_images = mm_counts.get("image", 0)
mm_data = {
return {
"image":
self._get_dummy_images(width=target_width,
height=target_height,
num_images=num_images)
}
return ProcessorInputs(
prompt_text="<image>" * num_images,
mm_data=mm_data,
)
class SkyworkR1VMultiModalProcessor(BaseMultiModalProcessor[_I]):
......
......@@ -23,13 +23,13 @@ from vllm.model_executor.model_loader.loader import DefaultModelLoader
from vllm.model_executor.models.module_mapping import MultiModelKeys
from vllm.model_executor.sampling_metadata import SamplingMetadata
from vllm.multimodal import MULTIMODAL_REGISTRY
from vllm.multimodal.inputs import (MultiModalFieldConfig, MultiModalKwargs,
NestedTensors)
from vllm.multimodal.inputs import (MultiModalDataDict, MultiModalFieldConfig,
MultiModalKwargs, NestedTensors)
from vllm.multimodal.parse import MultiModalDataItems, MultiModalDataParser
from vllm.multimodal.processing import (BaseMultiModalProcessor,
BaseProcessingInfo, PromptReplacement,
PromptUpdate)
from vllm.multimodal.profiling import BaseDummyInputsBuilder, ProcessorInputs
from vllm.multimodal.profiling import BaseDummyInputsBuilder
from vllm.sequence import IntermediateTensors
from vllm.transformers_utils.configs.ultravox import UltravoxConfig
......@@ -110,11 +110,16 @@ class UltravoxProcessingInfo(BaseProcessingInfo):
class UltravoxDummyInputsBuilder(BaseDummyInputsBuilder[UltravoxProcessingInfo]
):
def get_dummy_processor_inputs(
def get_dummy_text(self, mm_counts: Mapping[str, int]) -> str:
num_audios = mm_counts.get("audio", 0)
return "<|audio|>" * num_audios
def get_dummy_mm_data(
self,
seq_len: int,
mm_counts: Mapping[str, int],
) -> ProcessorInputs:
) -> MultiModalDataDict:
feature_extractor = self.info.get_feature_extractor()
sampling_rate = feature_extractor.sampling_rate
......@@ -122,16 +127,11 @@ class UltravoxDummyInputsBuilder(BaseDummyInputsBuilder[UltravoxProcessingInfo]
_MAX_ENCODER_BATCH_SIZE)
num_audios = mm_counts.get("audio", 0)
mm_data = {
return {
"audio":
self._get_dummy_audios(length=audio_len, num_audios=num_audios)
}
return ProcessorInputs(
prompt_text="<|audio|>" * num_audios,
mm_data=mm_data,
)
class UltravoxMultiModalProcessor(
BaseMultiModalProcessor[UltravoxProcessingInfo]):
......
......@@ -26,13 +26,13 @@ from vllm.model_executor.layers.vocab_parallel_embedding import ParallelLMHead
from vllm.model_executor.model_loader.weight_utils import default_weight_loader
from vllm.model_executor.sampling_metadata import SamplingMetadata
from vllm.multimodal import MULTIMODAL_REGISTRY, NestedTensors
from vllm.multimodal.inputs import MultiModalFieldConfig, MultiModalKwargs
from vllm.multimodal.parse import (MultiModalDataDict, MultiModalDataItems,
MultiModalDataParser)
from vllm.multimodal.inputs import (MultiModalDataDict, MultiModalFieldConfig,
MultiModalKwargs)
from vllm.multimodal.parse import MultiModalDataItems, MultiModalDataParser
from vllm.multimodal.processing import (BaseProcessingInfo,
EncDecMultiModalProcessor,
PromptReplacement, PromptUpdate)
from vllm.multimodal.profiling import BaseDummyInputsBuilder, ProcessorInputs
from vllm.multimodal.profiling import BaseDummyInputsBuilder
from .interfaces import (MultiModalEmbeddings, SupportsMultiModal,
SupportsTranscription, SupportsV0Only)
......@@ -544,27 +544,27 @@ class WhisperProcessingInfo(BaseProcessingInfo):
class WhisperDummyInputsBuilder(BaseDummyInputsBuilder[WhisperProcessingInfo]):
def get_dummy_processor_inputs(
def get_dummy_text(self, mm_counts: Mapping[str, int]) -> str:
num_audios = mm_counts.get("audio", 0)
return "<|startoftranscript|>" * num_audios
def get_dummy_mm_data(
self,
seq_len: int,
mm_counts: Mapping[str, int],
) -> ProcessorInputs:
) -> MultiModalDataDict:
feature_extractor = self.info.get_feature_extractor()
sampling_rate = feature_extractor.sampling_rate
audio_len = feature_extractor.chunk_length * sampling_rate
num_audios = mm_counts.get("audio", 0)
mm_data = {
return {
"audio":
self._get_dummy_audios(length=audio_len, num_audios=num_audios)
}
return ProcessorInputs(
prompt_text="<|startoftranscript|>" * num_audios,
mm_data=mm_data,
)
class WhisperMultiModalProcessor(
EncDecMultiModalProcessor[WhisperProcessingInfo]):
......
......@@ -1051,12 +1051,6 @@ class BaseMultiModalProcessor(ABC, Generic[_I]):
*,
cache: Optional[ProcessingCache] = None,
enable_sanity_checks: bool = True) -> None:
if get_repls := getattr(self, "_get_prompt_replacements", None):
logger.warning_once("`_get_prompt_replacements` has been renamed "
"to `_get_prompt_updates`. The old name will "
"be removed in an upcoming release.")
self._get_prompt_updates = get_repls # type: ignore[method-assign]
super().__init__()
self.info = info
......@@ -1274,13 +1268,8 @@ class BaseMultiModalProcessor(ABC, Generic[_I]):
"""
mm_counts = mm_items.get_all_counts()
dummy_inputs = self.dummy_inputs.get_dummy_processor_inputs(
self.info.ctx.model_config.max_model_len,
mm_counts,
)
_, mm_kwargs, _ = self._apply_hf_processor_text_mm(
prompt_text=dummy_inputs.prompt_text,
prompt_text=self.dummy_inputs.get_dummy_text(mm_counts),
mm_items=mm_items,
hf_processor_mm_kwargs=hf_processor_mm_kwargs,
)
......
# SPDX-License-Identifier: Apache-2.0
from abc import ABC, abstractmethod
from abc import ABC
from collections.abc import Mapping
from dataclasses import dataclass, field
from typing import Generic, NamedTuple, Optional, TypeVar, cast
......@@ -60,7 +60,35 @@ class BaseDummyInputsBuilder(ABC, Generic[_I]):
self.info = info
@abstractmethod
# TODO: @abstractmethod after transition
def get_dummy_text(self, mm_counts: Mapping[str, int]) -> str:
"""
Build the text input corresponding to :code:`mm_counts`.
"""
if (type(self).get_dummy_processor_inputs ==
BaseDummyInputsBuilder.get_dummy_processor_inputs):
raise NotImplementedError
logger.warning_once("`get_dummy_processor_inputs` has been split up "
"into `get_dummy_text` and `get_dummy_mm_data`. "
"These two methods will be marked as abstract "
"in an upcoming release.")
seq_len = self.info.ctx.model_config.max_model_len
return self.get_dummy_processor_inputs(seq_len, mm_counts).prompt_text
# TODO: @abstractmethod after transition
def get_dummy_mm_data(
self,
seq_len: int,
mm_counts: Mapping[str, int],
) -> MultiModalDataDict:
"""
Build the multimodal input which, after processing, results in
the maximum possible number of placeholder tokens.
"""
raise NotImplementedError
def get_dummy_processor_inputs(
self,
seq_len: int,
......@@ -70,7 +98,10 @@ class BaseDummyInputsBuilder(ABC, Generic[_I]):
Build the input which, after processing, results in
the maximum possible number of placeholder tokens.
"""
raise NotImplementedError
dummy_text = self.get_dummy_text(mm_counts)
dummy_mm_data = self.get_dummy_mm_data(seq_len, mm_counts)
return ProcessorInputs(prompt_text=dummy_text, mm_data=dummy_mm_data)
def _get_dummy_audios(
self,
......
Markdown is supported
0% Try again or attach a new file.
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment