Unverified commit 803d5c35, authored by Cyrus Leung, committed by GitHub
Browse files

[V1] Override `mm_counts` for dummy data creation (#15703)


Signed-off-by: DarkLight1337 <tlleungac@connect.ust.hk>
parent 7fd8c0f8
...@@ -385,18 +385,7 @@ VLM_TEST_SETTINGS = { ...@@ -385,18 +385,7 @@ VLM_TEST_SETTINGS = {
), ),
"minicpmo_26": VLMTestInfo( "minicpmo_26": VLMTestInfo(
models=["openbmb/MiniCPM-o-2_6"], models=["openbmb/MiniCPM-o-2_6"],
test_type=(VLMTestType.IMAGE), test_type=(VLMTestType.IMAGE, VLMTestType.MULTI_IMAGE),
prompt_formatter=lambda img_prompt: f"<|begin_of_text|><|start_header_id|>user<|end_header_id|>\n\n{img_prompt}<|eot_id|><|start_header_id|>assistant<|end_header_id|>\n\n", # noqa: E501
img_idx_to_prompt=lambda idx: "(<image>./</image>)\n",
max_model_len=4096,
max_num_seqs=2,
get_stop_token_ids=lambda tok: tok.convert_tokens_to_ids(['<|im_end|>', '<|endoftext|>']), # noqa: E501
hf_output_post_proc=model_utils.minicpmv_trunc_hf_output,
patch_hf_runner=model_utils.minicpmo_26_patch_hf_runner,
),
"minicpmo_26_multi_image": VLMTestInfo(
models=["openbmb/MiniCPM-o-2_6"],
test_type=(VLMTestType.MULTI_IMAGE),
prompt_formatter=lambda img_prompt: f"<|begin_of_text|><|start_header_id|>user<|end_header_id|>\n\n{img_prompt}<|eot_id|><|start_header_id|>assistant<|end_header_id|>\n\n", # noqa: E501 prompt_formatter=lambda img_prompt: f"<|begin_of_text|><|start_header_id|>user<|end_header_id|>\n\n{img_prompt}<|eot_id|><|start_header_id|>assistant<|end_header_id|>\n\n", # noqa: E501
img_idx_to_prompt=lambda idx: "(<image>./</image>)\n", img_idx_to_prompt=lambda idx: "(<image>./</image>)\n",
max_model_len=4096, max_model_len=4096,
...@@ -404,22 +393,10 @@ VLM_TEST_SETTINGS = { ...@@ -404,22 +393,10 @@ VLM_TEST_SETTINGS = {
get_stop_token_ids=lambda tok: tok.convert_tokens_to_ids(['<|im_end|>', '<|endoftext|>']), # noqa: E501 get_stop_token_ids=lambda tok: tok.convert_tokens_to_ids(['<|im_end|>', '<|endoftext|>']), # noqa: E501
hf_output_post_proc=model_utils.minicpmv_trunc_hf_output, hf_output_post_proc=model_utils.minicpmv_trunc_hf_output,
patch_hf_runner=model_utils.minicpmo_26_patch_hf_runner, patch_hf_runner=model_utils.minicpmo_26_patch_hf_runner,
marks=[large_gpu_mark(min_gb=32)],
), ),
"minicpmv_26": VLMTestInfo( "minicpmv_26": VLMTestInfo(
models=["openbmb/MiniCPM-V-2_6"], models=["openbmb/MiniCPM-V-2_6"],
test_type=(VLMTestType.IMAGE), test_type=(VLMTestType.IMAGE, VLMTestType.MULTI_IMAGE),
prompt_formatter=lambda img_prompt: f"<|begin_of_text|><|start_header_id|>user<|end_header_id|>\n\n{img_prompt}<|eot_id|><|start_header_id|>assistant<|end_header_id|>\n\n", # noqa: E501
img_idx_to_prompt=lambda idx: "(<image>./</image>)\n",
max_model_len=4096,
max_num_seqs=2,
get_stop_token_ids=lambda tok: tok.convert_tokens_to_ids(['<|im_end|>', '<|endoftext|>']), # noqa: E501
hf_output_post_proc=model_utils.minicpmv_trunc_hf_output,
patch_hf_runner=model_utils.minicpmv_26_patch_hf_runner,
),
"minicpmv_26_multi_image": VLMTestInfo(
models=["openbmb/MiniCPM-V-2_6"],
test_type=(VLMTestType.MULTI_IMAGE),
prompt_formatter=lambda img_prompt: f"<|begin_of_text|><|start_header_id|>user<|end_header_id|>\n\n{img_prompt}<|eot_id|><|start_header_id|>assistant<|end_header_id|>\n\n", # noqa: E501 prompt_formatter=lambda img_prompt: f"<|begin_of_text|><|start_header_id|>user<|end_header_id|>\n\n{img_prompt}<|eot_id|><|start_header_id|>assistant<|end_header_id|>\n\n", # noqa: E501
img_idx_to_prompt=lambda idx: "(<image>./</image>)\n", img_idx_to_prompt=lambda idx: "(<image>./</image>)\n",
max_model_len=4096, max_model_len=4096,
...@@ -427,7 +404,6 @@ VLM_TEST_SETTINGS = { ...@@ -427,7 +404,6 @@ VLM_TEST_SETTINGS = {
get_stop_token_ids=lambda tok: tok.convert_tokens_to_ids(['<|im_end|>', '<|endoftext|>']), # noqa: E501 get_stop_token_ids=lambda tok: tok.convert_tokens_to_ids(['<|im_end|>', '<|endoftext|>']), # noqa: E501
hf_output_post_proc=model_utils.minicpmv_trunc_hf_output, hf_output_post_proc=model_utils.minicpmv_trunc_hf_output,
patch_hf_runner=model_utils.minicpmv_26_patch_hf_runner, patch_hf_runner=model_utils.minicpmv_26_patch_hf_runner,
marks=[large_gpu_mark(min_gb=32)],
), ),
"molmo": VLMTestInfo( "molmo": VLMTestInfo(
models=["allenai/Molmo-7B-D-0924"], models=["allenai/Molmo-7B-D-0924"],
......
...@@ -71,7 +71,8 @@ class LlavaNextVideoProcessingInfo(BaseProcessingInfo): ...@@ -71,7 +71,8 @@ class LlavaNextVideoProcessingInfo(BaseProcessingInfo):
max_video_tokens = self.get_num_video_tokens( max_video_tokens = self.get_num_video_tokens(
image_width=target_width, image_width=target_width,
image_height=target_height, image_height=target_height,
num_frames=self.get_num_frames_with_most_features(seq_len), num_frames=self.get_num_frames_with_most_features(
seq_len, mm_counts),
) )
return {"video": max_video_tokens} return {"video": max_video_tokens}
...@@ -130,9 +131,12 @@ class LlavaNextVideoProcessingInfo(BaseProcessingInfo): ...@@ -130,9 +131,12 @@ class LlavaNextVideoProcessingInfo(BaseProcessingInfo):
return num_frames return num_frames
def get_num_frames_with_most_features(self, seq_len: int) -> int: def get_num_frames_with_most_features(
mm_config = self.ctx.get_mm_config() self,
max_videos = mm_config.get_limit_per_prompt("video") seq_len: int,
mm_counts: Mapping[str, int],
) -> int:
max_videos = mm_counts.get("video", 0)
max_total_frames = self._get_max_video_frames(seq_len) max_total_frames = self._get_max_video_frames(seq_len)
...@@ -155,7 +159,7 @@ class LlavaNextVideoDummyInputsBuilder( ...@@ -155,7 +159,7 @@ class LlavaNextVideoDummyInputsBuilder(
target_width, target_height = \ target_width, target_height = \
self.info.get_image_size_with_most_features() self.info.get_image_size_with_most_features()
target_num_frames = \ target_num_frames = \
self.info.get_num_frames_with_most_features(seq_len) self.info.get_num_frames_with_most_features(seq_len, mm_counts)
mm_data = { mm_data = {
"video": "video":
......
...@@ -108,7 +108,7 @@ class LlavaOnevisionProcessingInfo(LlavaNextProcessingInfo): ...@@ -108,7 +108,7 @@ class LlavaOnevisionProcessingInfo(LlavaNextProcessingInfo):
) -> Mapping[str, int]: ) -> Mapping[str, int]:
return { return {
"image": self.get_max_image_tokens(), "image": self.get_max_image_tokens(),
"video": self.get_max_video_tokens(seq_len), "video": self.get_max_video_tokens(seq_len, mm_counts),
} }
# Based on: https://github.com/huggingface/text-generation-inference/blob/v3.0.1/server/text_generation_server/models/vlm_causal_lm.py#L86 # Based on: https://github.com/huggingface/text-generation-inference/blob/v3.0.1/server/text_generation_server/models/vlm_causal_lm.py#L86
...@@ -202,10 +202,13 @@ class LlavaOnevisionProcessingInfo(LlavaNextProcessingInfo): ...@@ -202,10 +202,13 @@ class LlavaOnevisionProcessingInfo(LlavaNextProcessingInfo):
return num_frames return num_frames
def get_num_frames_with_most_features(self, seq_len: int) -> int: def get_num_frames_with_most_features(
mm_config = self.ctx.get_mm_config() self,
max_images = mm_config.get_limit_per_prompt("image") seq_len: int,
max_videos = mm_config.get_limit_per_prompt("video") mm_counts: Mapping[str, int],
) -> int:
max_images = mm_counts.get("image", 0)
max_videos = mm_counts.get("video", 0)
max_image_tokens = self.get_max_image_tokens() * max_images max_image_tokens = self.get_max_image_tokens() * max_images
max_total_frames = self._get_max_video_frames(seq_len - max_total_frames = self._get_max_video_frames(seq_len -
...@@ -215,13 +218,18 @@ class LlavaOnevisionProcessingInfo(LlavaNextProcessingInfo): ...@@ -215,13 +218,18 @@ class LlavaOnevisionProcessingInfo(LlavaNextProcessingInfo):
return max(max_frames_per_video, 1) return max(max_frames_per_video, 1)
def get_max_video_tokens(self, seq_len: int) -> int: def get_max_video_tokens(
self,
seq_len: int,
mm_counts: Mapping[str, int],
) -> int:
target_width, target_height = self.get_image_size_with_most_features() target_width, target_height = self.get_image_size_with_most_features()
return self.get_num_video_tokens( return self.get_num_video_tokens(
image_width=target_width, image_width=target_width,
image_height=target_height, image_height=target_height,
num_frames=self.get_num_frames_with_most_features(seq_len), num_frames=self.get_num_frames_with_most_features(
seq_len, mm_counts),
) )
...@@ -243,7 +251,8 @@ class LlavaOnevisionDummyInputsBuilder( ...@@ -243,7 +251,8 @@ class LlavaOnevisionDummyInputsBuilder(
target_width, target_height = \ target_width, target_height = \
self.info.get_image_size_with_most_features() self.info.get_image_size_with_most_features()
target_num_frames = \ target_num_frames = \
self.info.get_num_frames_with_most_features(seq_len) self.info.get_num_frames_with_most_features(seq_len,
mm_counts)
mm_data = { mm_data = {
"image": "image":
......
...@@ -43,7 +43,8 @@ from vllm.multimodal.parse import (AudioItem, AudioProcessorItems, ...@@ -43,7 +43,8 @@ from vllm.multimodal.parse import (AudioItem, AudioProcessorItems,
from vllm.multimodal.processing import PromptReplacement, PromptUpdate from vllm.multimodal.processing import PromptReplacement, PromptUpdate
from vllm.multimodal.profiling import ProcessorInputs from vllm.multimodal.profiling import ProcessorInputs
from .minicpmv import (MiniCPMV2_6, MiniCPMVDummyInputsBuilder, from .minicpmv import (_MAX_FRAMES_PER_VIDEO, MiniCPMV2_6,
MiniCPMVDummyInputsBuilder,
MiniCPMVMultiModalDataParser, MiniCPMVMultiModalDataParser,
MiniCPMVMultiModalProcessor, MiniCPMVProcessingInfo, MiniCPMVMultiModalProcessor, MiniCPMVProcessingInfo,
_minicpmv_field_config) _minicpmv_field_config)
...@@ -203,8 +204,8 @@ class MiniCPMOProcessingInfo(MiniCPMVProcessingInfo): ...@@ -203,8 +204,8 @@ class MiniCPMOProcessingInfo(MiniCPMVProcessingInfo):
return 30 return 30
def get_max_audio_tokens(self) -> int: def get_max_audio_tokens(self) -> int:
return self.get_max_audio_tokens_per_chunk( num_chunks = self.get_max_audio_chunks_with_most_features()
) * self.get_max_audio_chunks_with_most_features() return self.get_max_audio_tokens_per_chunk() * num_chunks
def get_audio_len_by_num_chunks(self, num_chunks: int) -> int: def get_audio_len_by_num_chunks(self, num_chunks: int) -> int:
sampling_rate = self.get_default_audio_sampling_rate() sampling_rate = self.get_default_audio_sampling_rate()
...@@ -212,21 +213,24 @@ class MiniCPMOProcessingInfo(MiniCPMVProcessingInfo): ...@@ -212,21 +213,24 @@ class MiniCPMOProcessingInfo(MiniCPMVProcessingInfo):
num_tokens_per_chunk = self.get_max_audio_tokens_per_chunk() - 2 num_tokens_per_chunk = self.get_max_audio_tokens_per_chunk() - 2
return int(num_chunks * sampling_rate / num_tokens_per_chunk) + 1 return int(num_chunks * sampling_rate / num_tokens_per_chunk) + 1
def get_num_frames_with_most_features(self, seq_len: int) -> int: def get_num_frames_with_most_features(
mm_config = self.ctx.get_mm_config() self,
max_images = mm_config.get_limit_per_prompt("image") seq_len: int,
max_videos = mm_config.get_limit_per_prompt("video") mm_counts: Mapping[str, int],
max_audios = mm_config.get_limit_per_prompt("audio") ) -> int:
max_images = mm_counts.get("image", 0)
max_videos = mm_counts.get("video", 0)
max_audios = mm_counts.get("audio", 0)
max_image_tokens = self.get_max_image_tokens() * max_images max_image_tokens = self.get_max_image_tokens() * max_images
max_audio_tokens = self.get_max_audio_tokens() * max_audios max_audio_tokens = self.get_max_audio_tokens() * max_audios
max_total_frames = self.get_max_video_frames(seq_len - max_total_frames = self.get_max_video_frames(seq_len -
max_image_tokens - max_image_tokens -
max_audio_tokens) max_audio_tokens)
max_frames_per_video = min(max_total_frames // max(max_videos, 1),
_MAX_FRAMES_PER_VIDEO)
num_frames = max(max_total_frames // max(max_videos, 1), 1) return max(max_frames_per_video, 1)
return num_frames
class MiniCPMODummyInputsBuilder( class MiniCPMODummyInputsBuilder(
......
...@@ -69,6 +69,9 @@ from .utils import (AutoWeightsLoader, flatten_bn, maybe_prefix, ...@@ -69,6 +69,9 @@ from .utils import (AutoWeightsLoader, flatten_bn, maybe_prefix,
merge_multimodal_embeddings) merge_multimodal_embeddings)
from .vision import scatter_patch_features, select_patch_features from .vision import scatter_patch_features, select_patch_features
# For profile run
_MAX_FRAMES_PER_VIDEO = 16
class MiniCPMVImagePixelInputs(TypedDict): class MiniCPMVImagePixelInputs(TypedDict):
type: Literal["pixel_values"] type: Literal["pixel_values"]
...@@ -369,7 +372,8 @@ class MiniCPMVProcessingInfo(BaseProcessingInfo): ...@@ -369,7 +372,8 @@ class MiniCPMVProcessingInfo(BaseProcessingInfo):
) -> Mapping[str, int]: ) -> Mapping[str, int]:
mm_max_tokens = {"image": self.get_max_image_tokens()} mm_max_tokens = {"image": self.get_max_image_tokens()}
if self.get_model_version() == (2, 6): if self.get_model_version() == (2, 6):
mm_max_tokens["video"] = self.get_max_video_tokens(seq_len) mm_max_tokens["video"] = self.get_max_video_tokens(
seq_len, mm_counts)
return mm_max_tokens return mm_max_tokens
...@@ -432,9 +436,14 @@ class MiniCPMVProcessingInfo(BaseProcessingInfo): ...@@ -432,9 +436,14 @@ class MiniCPMVProcessingInfo(BaseProcessingInfo):
use_image_id=False, use_image_id=False,
) )
def get_max_video_tokens(self, seq_len: int) -> int: def get_max_video_tokens(
return self.get_max_video_frame_tokens( self,
) * self.get_num_frames_with_most_features(seq_len) seq_len: int,
mm_counts: Mapping[str, int],
) -> int:
num_frames = self.get_num_frames_with_most_features(seq_len, mm_counts)
num_video_tokens_total = self.get_max_video_frame_tokens() * num_frames
return num_video_tokens_total
def get_video_max_slice_num(self) -> int: def get_video_max_slice_num(self) -> int:
return 1 return 1
...@@ -449,18 +458,21 @@ class MiniCPMVProcessingInfo(BaseProcessingInfo): ...@@ -449,18 +458,21 @@ class MiniCPMVProcessingInfo(BaseProcessingInfo):
num_frames = max_tokens // num_frame_tokens num_frames = max_tokens // num_frame_tokens
return num_frames return num_frames
def get_num_frames_with_most_features(self, seq_len: int) -> int: def get_num_frames_with_most_features(
mm_config = self.ctx.get_mm_config() self,
max_images = mm_config.get_limit_per_prompt("image") seq_len: int,
max_videos = mm_config.get_limit_per_prompt("video") mm_counts: Mapping[str, int],
) -> int:
max_images = mm_counts.get("image", 0)
max_videos = mm_counts.get("video", 0)
max_image_tokens = self.get_max_image_tokens() * max_images max_image_tokens = self.get_max_image_tokens() * max_images
max_total_frames = self.get_max_video_frames(seq_len - max_total_frames = self.get_max_video_frames(seq_len -
max_image_tokens) max_image_tokens)
max_frames_per_video = min(max_total_frames // max(max_videos, 1),
_MAX_FRAMES_PER_VIDEO)
num_frames = max(max_total_frames // max(max_videos, 1), 1) return max(max_frames_per_video, 1)
return num_frames
_I = TypeVar("_I", _I = TypeVar("_I",
...@@ -483,7 +495,7 @@ class MiniCPMVDummyInputsBuilder(BaseDummyInputsBuilder[_I]): ...@@ -483,7 +495,7 @@ class MiniCPMVDummyInputsBuilder(BaseDummyInputsBuilder[_I]):
video_width, video_height = \ video_width, video_height = \
self.info.get_video_frame_size_with_most_features() self.info.get_video_frame_size_with_most_features()
num_video_frames = \ num_video_frames = \
self.info.get_num_frames_with_most_features(seq_len) self.info.get_num_frames_with_most_features(seq_len, mm_counts)
mm_data = { mm_data = {
"image": "image":
......
...@@ -806,7 +806,7 @@ class Qwen2VLProcessingInfo(BaseProcessingInfo): ...@@ -806,7 +806,7 @@ class Qwen2VLProcessingInfo(BaseProcessingInfo):
max_pixels: Optional[int] = None, max_pixels: Optional[int] = None,
size: Optional[dict[str, int]] = None, size: Optional[dict[str, int]] = None,
**kwargs: object, **kwargs: object,
): ) -> Qwen2VLImageProcessor:
return cached_image_processor_from_config( return cached_image_processor_from_config(
self.ctx.model_config, self.ctx.model_config,
**self._get_image_processor_kwargs(min_pixels=min_pixels, **self._get_image_processor_kwargs(min_pixels=min_pixels,
...@@ -825,7 +825,7 @@ class Qwen2VLProcessingInfo(BaseProcessingInfo): ...@@ -825,7 +825,7 @@ class Qwen2VLProcessingInfo(BaseProcessingInfo):
) -> Mapping[str, int]: ) -> Mapping[str, int]:
return { return {
"image": self.get_max_image_tokens(), "image": self.get_max_image_tokens(),
"video": self.get_max_video_tokens(seq_len), "video": self.get_max_video_tokens(seq_len, mm_counts),
} }
def _get_vision_info( def _get_vision_info(
...@@ -941,10 +941,13 @@ class Qwen2VLProcessingInfo(BaseProcessingInfo): ...@@ -941,10 +941,13 @@ class Qwen2VLProcessingInfo(BaseProcessingInfo):
return num_frames return num_frames
def get_num_frames_with_most_features(self, seq_len: int) -> int: def get_num_frames_with_most_features(
mm_config = self.ctx.get_mm_config() self,
max_images = mm_config.get_limit_per_prompt("image") seq_len: int,
max_videos = mm_config.get_limit_per_prompt("video") mm_counts: Mapping[str, int],
) -> int:
max_images = mm_counts.get("image", 0)
max_videos = mm_counts.get("video", 0)
max_image_tokens = self.get_max_image_tokens() * max_images max_image_tokens = self.get_max_image_tokens() * max_images
max_total_frames = self._get_max_video_frames(seq_len - max_total_frames = self._get_max_video_frames(seq_len -
...@@ -954,13 +957,18 @@ class Qwen2VLProcessingInfo(BaseProcessingInfo): ...@@ -954,13 +957,18 @@ class Qwen2VLProcessingInfo(BaseProcessingInfo):
return max(max_frames_per_video, 1) return max(max_frames_per_video, 1)
def get_max_video_tokens(self, seq_len: int) -> int: def get_max_video_tokens(
self,
seq_len: int,
mm_counts: Mapping[str, int],
) -> int:
target_width, target_height = self.get_image_size_with_most_features() target_width, target_height = self.get_image_size_with_most_features()
return self.get_num_video_tokens( return self.get_num_video_tokens(
image_width=target_width, image_width=target_width,
image_height=target_height, image_height=target_height,
num_frames=self.get_num_frames_with_most_features(seq_len), num_frames=self.get_num_frames_with_most_features(
seq_len, mm_counts),
image_processor=None, image_processor=None,
) )
...@@ -982,7 +990,7 @@ class Qwen2VLDummyInputsBuilder(BaseDummyInputsBuilder[Qwen2VLProcessingInfo]): ...@@ -982,7 +990,7 @@ class Qwen2VLDummyInputsBuilder(BaseDummyInputsBuilder[Qwen2VLProcessingInfo]):
target_width, target_height = \ target_width, target_height = \
self.info.get_image_size_with_most_features() self.info.get_image_size_with_most_features()
target_num_frames = \ target_num_frames = \
self.info.get_num_frames_with_most_features(seq_len) self.info.get_num_frames_with_most_features(seq_len, mm_counts)
mm_data = { mm_data = {
"image": "image":
......
...@@ -3,7 +3,7 @@ ...@@ -3,7 +3,7 @@
from abc import ABC, abstractmethod from abc import ABC, abstractmethod
from collections.abc import Mapping from collections.abc import Mapping
from dataclasses import dataclass, field from dataclasses import dataclass, field
from typing import Generic, NamedTuple, TypeVar, cast from typing import Generic, NamedTuple, Optional, TypeVar, cast
import numpy as np import numpy as np
import numpy.typing as npt import numpy.typing as npt
...@@ -160,17 +160,19 @@ class MultiModalProfiler(Generic[_I]): ...@@ -160,17 +160,19 @@ class MultiModalProfiler(Generic[_I]):
def get_and_validate_mm_inputs( def get_and_validate_mm_inputs(
self, self,
seq_len: int, seq_len: int,
mm_counts: Optional[Mapping[str, int]] = None,
) -> tuple[MultiModalInputs, Mapping[str, int]]: ) -> tuple[MultiModalInputs, Mapping[str, int]]:
mm_counts = self.get_mm_limits() if mm_counts is None:
mm_counts = self.get_mm_limits()
info = self.processing_info info = self.processing_info
mm_max_tokens_per_item = info.get_mm_max_tokens_per_item( mm_max_tokens_per_item = info.get_mm_max_tokens_per_item(
seq_len, mm_counts) seq_len, mm_counts)
if mm_counts.keys() != mm_max_tokens_per_item.keys(): if mm_counts.keys() - mm_max_tokens_per_item.keys():
raise AssertionError( raise AssertionError(
"The keys returned by `get_supported_mm_limits` " "The keys returned by `get_supported_mm_limits` "
f"({set(mm_counts.keys())}) should be the same as those " f"({set(mm_counts.keys())}) should be a subset of those "
"returned by `get_mm_max_tokens_per_item` " "returned by `get_mm_max_tokens_per_item` "
f"({set(mm_max_tokens_per_item.keys())})") f"({set(mm_max_tokens_per_item.keys())})")
...@@ -193,8 +195,12 @@ class MultiModalProfiler(Generic[_I]): ...@@ -193,8 +195,12 @@ class MultiModalProfiler(Generic[_I]):
"tokens.") "tokens.")
return mm_inputs, total_placeholders_by_modality return mm_inputs, total_placeholders_by_modality
def get_encoder_dummy_data(self, seq_len: int) -> DummyEncoderData: def get_encoder_dummy_data(
mm_inputs, _ = self.get_and_validate_mm_inputs(seq_len) self,
seq_len: int,
mm_counts: Optional[Mapping[str, int]] = None,
) -> DummyEncoderData:
mm_inputs, _ = self.get_and_validate_mm_inputs(seq_len, mm_counts)
mm_inputs = cast(MultiModalEncDecInputs, mm_inputs) mm_inputs = cast(MultiModalEncDecInputs, mm_inputs)
# For encoder-decoder models, use encoder prompt token ids instead of # For encoder-decoder models, use encoder prompt token ids instead of
...@@ -207,9 +213,15 @@ class MultiModalProfiler(Generic[_I]): ...@@ -207,9 +213,15 @@ class MultiModalProfiler(Generic[_I]):
return DummyEncoderData(encoder_prompt_token_ids) return DummyEncoderData(encoder_prompt_token_ids)
def get_decoder_dummy_data(self, seq_len: int) -> DummyDecoderData: def get_decoder_dummy_data(
(mm_inputs, total_placeholders_by_modality self,
) = self.get_and_validate_mm_inputs(seq_len) seq_len: int,
mm_counts: Optional[Mapping[str, int]] = None,
) -> DummyDecoderData:
(
mm_inputs,
total_placeholders_by_modality,
) = self.get_and_validate_mm_inputs(seq_len, mm_counts)
prompt_token_ids = mm_inputs["prompt_token_ids"] prompt_token_ids = mm_inputs["prompt_token_ids"]
total_len = len(prompt_token_ids) total_len = len(prompt_token_ids)
......
...@@ -458,6 +458,7 @@ class MultiModalRegistry: ...@@ -458,6 +458,7 @@ class MultiModalRegistry:
self, self,
model_config: "ModelConfig", model_config: "ModelConfig",
seq_len: int, seq_len: int,
mm_counts: Optional[Mapping[str, int]] = None,
) -> DummyDecoderData: ) -> DummyDecoderData:
""" """
Create dummy data for profiling the memory usage of a model. Create dummy data for profiling the memory usage of a model.
...@@ -466,7 +467,7 @@ class MultiModalRegistry: ...@@ -466,7 +467,7 @@ class MultiModalRegistry:
""" """
processor = self.create_processor(model_config, disable_cache=True) processor = self.create_processor(model_config, disable_cache=True)
profiler = MultiModalProfiler(processor) profiler = MultiModalProfiler(processor)
dummy_data = profiler.get_decoder_dummy_data(seq_len) dummy_data = profiler.get_decoder_dummy_data(seq_len, mm_counts)
# Having more tokens is over-conservative but otherwise fine # Having more tokens is over-conservative but otherwise fine
token_ids = dummy_data.prompt_token_ids token_ids = dummy_data.prompt_token_ids
...@@ -481,6 +482,7 @@ class MultiModalRegistry: ...@@ -481,6 +482,7 @@ class MultiModalRegistry:
self, self,
model_config: "ModelConfig", model_config: "ModelConfig",
seq_len: int, seq_len: int,
mm_counts: Optional[Mapping[str, int]] = None,
) -> DummyEncoderData: ) -> DummyEncoderData:
""" """
Create dummy data for profiling the memory usage of a model. Create dummy data for profiling the memory usage of a model.
...@@ -489,7 +491,7 @@ class MultiModalRegistry: ...@@ -489,7 +491,7 @@ class MultiModalRegistry:
""" """
processor = self.create_processor(model_config, disable_cache=True) processor = self.create_processor(model_config, disable_cache=True)
profiler = MultiModalProfiler(processor) profiler = MultiModalProfiler(processor)
dummy_data = profiler.get_encoder_dummy_data(seq_len) dummy_data = profiler.get_encoder_dummy_data(seq_len, mm_counts)
# Having more tokens is over-conservative but otherwise fine # Having more tokens is over-conservative but otherwise fine
token_ids = dummy_data.prompt_token_ids token_ids = dummy_data.prompt_token_ids
......
...@@ -1470,19 +1470,13 @@ class GPUModelRunner(LoRAModelRunnerMixin): ...@@ -1470,19 +1470,13 @@ class GPUModelRunner(LoRAModelRunnerMixin):
encoder_budget, max_num_mm_items, dummy_data_modality) encoder_budget, max_num_mm_items, dummy_data_modality)
# Create dummy batch of multimodal inputs. # Create dummy batch of multimodal inputs.
dummy_request_data = self.mm_registry.get_decoder_dummy_data( dummy_mm_kwargs = self.mm_registry.get_decoder_dummy_data(
model_config=self.model_config, model_config=self.model_config,
seq_len=self.max_num_tokens, seq_len=self.max_num_tokens,
) mm_counts={
dummy_mm_data = dummy_request_data.multi_modal_data dummy_data_modality: 1
},
# Dummy data definition may contain multiple multimodal items ).multi_modal_data
# (e.g, multiple images) for a single request, therefore here we
# always replicate first item by max_num_mm_items times since in V1
# they are scheduled to be processed separately.
dummy_mm_item = dummy_mm_data.get_item(
modality=dummy_data_modality, item_index=0)
dummy_mm_kwargs = MultiModalKwargs.from_items([dummy_mm_item])
batched_dummy_mm_inputs = MultiModalKwargs.batch( batched_dummy_mm_inputs = MultiModalKwargs.batch(
[dummy_mm_kwargs] * max_num_mm_items) [dummy_mm_kwargs] * max_num_mm_items)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment