Unverified commit 90db5b31, authored by Cyrus Leung, committed by GitHub
Browse files

[Refactor] Move top-level dummy data generation to registry (#32310)


Signed-off-by: DarkLight1337 <tlleungac@connect.ust.hk>
parent b8199f60
......@@ -24,32 +24,20 @@ def test_profiling(model_id: str, max_model_len: int):
limit_mm_per_prompt=mm_counts,
)
processor = MULTIMODAL_REGISTRY.create_processor(ctx.model_config)
decoder_dummy_data = processor.dummy_inputs.get_decoder_dummy_data(
processor,
max_model_len,
mm_counts=mm_counts,
)
dummy_mm_data = processor.dummy_inputs.get_dummy_processor_inputs(
max_model_len,
mm_inputs = MULTIMODAL_REGISTRY.get_dummy_mm_inputs(
ctx.model_config,
mm_counts=mm_counts,
)
hf_config = ctx.get_hf_config(Llama4Config)
mm_inputs = processor.apply(
prompt=dummy_mm_data.prompt,
mm_data=dummy_mm_data.mm_data,
hf_processor_mm_kwargs=dict(),
)
mm_data = mm_inputs["mm_kwargs"].get_data()
image_size = hf_config.vision_config.image_size
patch_size = hf_config.vision_config.patch_size
downsample_ratio = int(
round(1.0 / (hf_config.vision_config.pixel_shuffle_ratio**2))
)
tokens_per_patch = ((image_size // patch_size) ** 2) // downsample_ratio
mm_data = mm_inputs["mm_kwargs"].get_data()
chunks_per_image = prod(mm_data["patches_per_image"])
total_num_patches = chunks_per_image * tokens_per_patch
num_tiles = (
......@@ -63,6 +51,5 @@ def test_profiling(model_id: str, max_model_len: int):
item.get_num_embeds for item in mm_inputs["mm_placeholders"]["image"]
)
assert total_tokens == sum(
placeholder.length
for placeholder in decoder_dummy_data.multi_modal_placeholders["image"]
placeholder.length for placeholder in mm_inputs["mm_placeholders"]["image"]
)
......@@ -926,10 +926,10 @@ def test_limit_mm_per_prompt_dummy(model_id, limit, num_supported, is_valid):
exc_ctx = nullcontext() if is_valid else pytest.raises(ValueError, match="At most")
with exc_ctx:
processor.dummy_inputs.get_decoder_dummy_data(
processor,
model_config.max_model_len,
MULTIMODAL_REGISTRY.get_dummy_mm_inputs(
model_config,
mm_counts=limit_mm_per_prompt,
processor=processor,
)
......
......@@ -50,6 +50,7 @@ from .parse import (
MultiModalDataItems,
MultiModalDataParser,
)
from .profiling import BaseDummyInputsBuilder
if TYPE_CHECKING:
from transformers.configuration_utils import PretrainedConfig
......@@ -59,7 +60,6 @@ if TYPE_CHECKING:
from vllm.config import ModelConfig, ObservabilityConfig
from .cache import BaseMultiModalProcessorCache
from .profiling import BaseDummyInputsBuilder
else:
PretrainedConfig = object
BatchFeature = object
......
......@@ -3,7 +3,7 @@
from abc import ABC, abstractmethod
from collections.abc import Mapping
from dataclasses import dataclass, field
from typing import Generic, NamedTuple, TypeVar
from typing import TYPE_CHECKING, Generic
import numpy as np
import numpy.typing as npt
......@@ -17,16 +17,14 @@ from vllm.config.multimodal import (
)
from vllm.logger import init_logger
from .inputs import (
MultiModalDataDict,
MultiModalInputs,
MultiModalKwargsItems,
MultiModalPlaceholderDict,
)
from .processing import (
BaseMultiModalProcessor,
BaseProcessingInfo,
)
from .inputs import MultiModalDataDict
if TYPE_CHECKING:
from .processing import _I
else:
from typing import TypeVar
_I = TypeVar("_I")
logger = init_logger(__name__)
......@@ -44,23 +42,6 @@ class ProcessorInputs:
tokenization_kwargs: Mapping[str, object] = field(default_factory=dict)
class DummyEncoderData(NamedTuple):
"""Dummy data used for profiling."""
prompt_token_ids: list[int]
class DummyDecoderData(NamedTuple):
"""Dummy data used for profiling."""
prompt_token_ids: list[int]
multi_modal_data: MultiModalKwargsItems
multi_modal_placeholders: MultiModalPlaceholderDict
_I = TypeVar("_I", bound=BaseProcessingInfo)
class BaseDummyInputsBuilder(ABC, Generic[_I]):
"""
Abstract base class that constructs the dummy data to profile
......@@ -222,52 +203,3 @@ class BaseDummyInputsBuilder(ABC, Generic[_I]):
height = min(height, overrides.height)
video = np.full((num_frames, width, height, 3), 255, dtype=np.uint8)
return [video] * num_videos
def get_dummy_mm_inputs(
self,
processor: BaseMultiModalProcessor[_I],
seq_len: int,
mm_counts: Mapping[str, int] | None = None,
mm_options: Mapping[str, BaseDummyOptions] | None = None,
) -> MultiModalInputs:
if mm_counts is None:
mm_counts = processor.allowed_mm_limits
processor_inputs = self.get_dummy_processor_inputs(
seq_len,
mm_counts=mm_counts,
mm_options=mm_options,
)
return processor.apply(
prompt=processor_inputs.prompt,
mm_data=processor_inputs.mm_data,
hf_processor_mm_kwargs=processor_inputs.hf_processor_mm_kwargs,
tokenization_kwargs=processor_inputs.tokenization_kwargs,
)
def get_decoder_dummy_data(
self,
processor: BaseMultiModalProcessor[_I],
seq_len: int,
mm_counts: Mapping[str, int] | None = None,
mm_options: Mapping[str, BaseDummyOptions] | None = None,
) -> DummyDecoderData:
mm_inputs = self.get_dummy_mm_inputs(
processor,
seq_len,
mm_counts=mm_counts,
mm_options=mm_options,
)
prompt_token_ids = mm_inputs["prompt_token_ids"]
total_len = len(prompt_token_ids)
if total_len < seq_len:
prompt_token_ids.extend([0] * (seq_len - total_len))
return DummyDecoderData(
prompt_token_ids=prompt_token_ids,
multi_modal_data=mm_inputs["mm_kwargs"].require_data(),
multi_modal_placeholders=mm_inputs["mm_placeholders"],
)
......@@ -10,15 +10,13 @@ from vllm.logger import init_logger
from vllm.tokenizers import TokenizerLike, cached_tokenizer_from_config
from .cache import BaseMultiModalProcessorCache
from .inputs import MultiModalInputs
from .processing import (
BaseMultiModalProcessor,
BaseProcessingInfo,
InputProcessingContext,
)
from .profiling import (
BaseDummyInputsBuilder,
DummyDecoderData,
)
from .profiling import BaseDummyInputsBuilder
if TYPE_CHECKING:
from vllm.config import ModelConfig, ObservabilityConfig
......@@ -160,7 +158,6 @@ class MultiModalRegistry:
model_config, observability_config, cache=cache
)
seq_len = model_config.max_model_len
if profiler_limits is None:
profiler_limits = processor.allowed_mm_limits
......@@ -169,7 +166,7 @@ class MultiModalRegistry:
}
max_tokens_per_item = processor.info.get_mm_max_tokens_per_item(
seq_len=seq_len,
seq_len=model_config.max_model_len,
mm_counts=mm_counts,
)
if max_tokens_per_item is not None:
......@@ -179,11 +176,10 @@ class MultiModalRegistry:
if mm_counts.get(modality, 0) > 0
}
mm_inputs = processor.dummy_inputs.get_dummy_mm_inputs(
processor,
seq_len,
mm_inputs = self.get_dummy_mm_inputs(
model_config,
mm_counts=mm_counts,
mm_options=self._extract_mm_options(model_config),
processor=processor,
)
return {
......@@ -298,39 +294,47 @@ class MultiModalRegistry:
return factories.build_processor(ctx, cache=cache)
def get_decoder_dummy_data(
def get_dummy_mm_inputs(
self,
model_config: "ModelConfig",
seq_len: int,
mm_counts: Mapping[str, int] | None = None,
*,
cache: BaseMultiModalProcessorCache | None = None,
observability_config: ObservabilityConfig | None = None,
) -> DummyDecoderData:
processor: BaseMultiModalProcessor | None = None,
) -> MultiModalInputs:
"""
Create dummy data for profiling the memory usage of a model.
The model is identified by `model_config`.
"""
processor = self.create_processor(
model_config, observability_config, cache=cache
)
dummy_data = processor.dummy_inputs.get_decoder_dummy_data(
processor,
seq_len,
seq_len = model_config.max_model_len
if processor is None:
processor = self.create_processor(
model_config, observability_config, cache=cache
)
if mm_counts is None:
mm_counts = processor.allowed_mm_limits
processor_inputs = processor.dummy_inputs.get_dummy_processor_inputs(
seq_len=seq_len,
mm_counts=mm_counts,
mm_options=self._extract_mm_options(model_config),
)
mm_inputs = processor.apply(
prompt=processor_inputs.prompt,
mm_data=processor_inputs.mm_data,
hf_processor_mm_kwargs=processor_inputs.hf_processor_mm_kwargs,
tokenization_kwargs=processor_inputs.tokenization_kwargs,
)
# Having more tokens is over-conservative but otherwise fine
token_ids = dummy_data.prompt_token_ids
if len(token_ids) < seq_len:
raise AssertionError(
f"Expected at least {seq_len} dummy tokens for profiling, "
f"but found {len(token_ids)} tokens instead."
)
prompt_token_ids = mm_inputs["prompt_token_ids"]
total_len = len(prompt_token_ids)
if total_len < seq_len:
prompt_token_ids.extend([0] * (seq_len - total_len))
return dummy_data
return mm_inputs
def get_encdec_max_encoder_len(self, model_config: "ModelConfig") -> int:
"""
......
......@@ -4192,16 +4192,18 @@ class GPUModelRunner(
"""Dummy data for profiling and precompiling multimodal models."""
assert self.mm_budget is not None
dummy_decoder_data = self.mm_registry.get_decoder_dummy_data(
model_config=self.model_config,
seq_len=self.max_model_len,
# Don't use `max_items_per_batch` here to avoid redundant computation
dummy_mm_inputs = self.mm_registry.get_dummy_mm_inputs(
self.model_config,
mm_counts={modality: 1},
cache=self.mm_budget.cache,
)
dummy_mm_data = dummy_decoder_data.multi_modal_data
dummy_mm_item = dummy_mm_inputs["mm_kwargs"][modality][0]
# We use the cache so that the item is saved to the cache,
# but not read from the cache
assert dummy_mm_item is not None, "Item should not already be cached"
# Result in the maximum GPU consumption of the model
dummy_mm_item = dummy_mm_data[modality][0]
dummy_mm_items = [dummy_mm_item] * max_items_per_batch
return next(
......
Markdown is supported
0% Try again or attach a new file.
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment