"vscode:/vscode.git/clone" did not exist on "0d243f2a54fbd1c56da8a571f0899c30b6aba5d9"
Unverified commit 90db5b31, authored by Cyrus Leung, committed by GitHub
Browse files

[Refactor] Move top-level dummy data generation to registry (#32310)


Signed-off-by: DarkLight1337 <tlleungac@connect.ust.hk>
parent b8199f60
...@@ -24,32 +24,20 @@ def test_profiling(model_id: str, max_model_len: int): ...@@ -24,32 +24,20 @@ def test_profiling(model_id: str, max_model_len: int):
limit_mm_per_prompt=mm_counts, limit_mm_per_prompt=mm_counts,
) )
processor = MULTIMODAL_REGISTRY.create_processor(ctx.model_config) mm_inputs = MULTIMODAL_REGISTRY.get_dummy_mm_inputs(
decoder_dummy_data = processor.dummy_inputs.get_decoder_dummy_data( ctx.model_config,
processor,
max_model_len,
mm_counts=mm_counts,
)
dummy_mm_data = processor.dummy_inputs.get_dummy_processor_inputs(
max_model_len,
mm_counts=mm_counts, mm_counts=mm_counts,
) )
hf_config = ctx.get_hf_config(Llama4Config) hf_config = ctx.get_hf_config(Llama4Config)
mm_inputs = processor.apply(
prompt=dummy_mm_data.prompt,
mm_data=dummy_mm_data.mm_data,
hf_processor_mm_kwargs=dict(),
)
mm_data = mm_inputs["mm_kwargs"].get_data()
image_size = hf_config.vision_config.image_size image_size = hf_config.vision_config.image_size
patch_size = hf_config.vision_config.patch_size patch_size = hf_config.vision_config.patch_size
downsample_ratio = int( downsample_ratio = int(
round(1.0 / (hf_config.vision_config.pixel_shuffle_ratio**2)) round(1.0 / (hf_config.vision_config.pixel_shuffle_ratio**2))
) )
tokens_per_patch = ((image_size // patch_size) ** 2) // downsample_ratio tokens_per_patch = ((image_size // patch_size) ** 2) // downsample_ratio
mm_data = mm_inputs["mm_kwargs"].get_data()
chunks_per_image = prod(mm_data["patches_per_image"]) chunks_per_image = prod(mm_data["patches_per_image"])
total_num_patches = chunks_per_image * tokens_per_patch total_num_patches = chunks_per_image * tokens_per_patch
num_tiles = ( num_tiles = (
...@@ -63,6 +51,5 @@ def test_profiling(model_id: str, max_model_len: int): ...@@ -63,6 +51,5 @@ def test_profiling(model_id: str, max_model_len: int):
item.get_num_embeds for item in mm_inputs["mm_placeholders"]["image"] item.get_num_embeds for item in mm_inputs["mm_placeholders"]["image"]
) )
assert total_tokens == sum( assert total_tokens == sum(
placeholder.length placeholder.length for placeholder in mm_inputs["mm_placeholders"]["image"]
for placeholder in decoder_dummy_data.multi_modal_placeholders["image"]
) )
...@@ -926,10 +926,10 @@ def test_limit_mm_per_prompt_dummy(model_id, limit, num_supported, is_valid): ...@@ -926,10 +926,10 @@ def test_limit_mm_per_prompt_dummy(model_id, limit, num_supported, is_valid):
exc_ctx = nullcontext() if is_valid else pytest.raises(ValueError, match="At most") exc_ctx = nullcontext() if is_valid else pytest.raises(ValueError, match="At most")
with exc_ctx: with exc_ctx:
processor.dummy_inputs.get_decoder_dummy_data( MULTIMODAL_REGISTRY.get_dummy_mm_inputs(
processor, model_config,
model_config.max_model_len,
mm_counts=limit_mm_per_prompt, mm_counts=limit_mm_per_prompt,
processor=processor,
) )
......
...@@ -50,6 +50,7 @@ from .parse import ( ...@@ -50,6 +50,7 @@ from .parse import (
MultiModalDataItems, MultiModalDataItems,
MultiModalDataParser, MultiModalDataParser,
) )
from .profiling import BaseDummyInputsBuilder
if TYPE_CHECKING: if TYPE_CHECKING:
from transformers.configuration_utils import PretrainedConfig from transformers.configuration_utils import PretrainedConfig
...@@ -59,7 +60,6 @@ if TYPE_CHECKING: ...@@ -59,7 +60,6 @@ if TYPE_CHECKING:
from vllm.config import ModelConfig, ObservabilityConfig from vllm.config import ModelConfig, ObservabilityConfig
from .cache import BaseMultiModalProcessorCache from .cache import BaseMultiModalProcessorCache
from .profiling import BaseDummyInputsBuilder
else: else:
PretrainedConfig = object PretrainedConfig = object
BatchFeature = object BatchFeature = object
......
...@@ -3,7 +3,7 @@ ...@@ -3,7 +3,7 @@
from abc import ABC, abstractmethod from abc import ABC, abstractmethod
from collections.abc import Mapping from collections.abc import Mapping
from dataclasses import dataclass, field from dataclasses import dataclass, field
from typing import Generic, NamedTuple, TypeVar from typing import TYPE_CHECKING, Generic
import numpy as np import numpy as np
import numpy.typing as npt import numpy.typing as npt
...@@ -17,16 +17,14 @@ from vllm.config.multimodal import ( ...@@ -17,16 +17,14 @@ from vllm.config.multimodal import (
) )
from vllm.logger import init_logger from vllm.logger import init_logger
from .inputs import ( from .inputs import MultiModalDataDict
MultiModalDataDict,
MultiModalInputs, if TYPE_CHECKING:
MultiModalKwargsItems, from .processing import _I
MultiModalPlaceholderDict, else:
) from typing import TypeVar
from .processing import (
BaseMultiModalProcessor, _I = TypeVar("_I")
BaseProcessingInfo,
)
logger = init_logger(__name__) logger = init_logger(__name__)
...@@ -44,23 +42,6 @@ class ProcessorInputs: ...@@ -44,23 +42,6 @@ class ProcessorInputs:
tokenization_kwargs: Mapping[str, object] = field(default_factory=dict) tokenization_kwargs: Mapping[str, object] = field(default_factory=dict)
class DummyEncoderData(NamedTuple):
"""Dummy data used for profiling."""
prompt_token_ids: list[int]
class DummyDecoderData(NamedTuple):
"""Dummy data used for profiling."""
prompt_token_ids: list[int]
multi_modal_data: MultiModalKwargsItems
multi_modal_placeholders: MultiModalPlaceholderDict
_I = TypeVar("_I", bound=BaseProcessingInfo)
class BaseDummyInputsBuilder(ABC, Generic[_I]): class BaseDummyInputsBuilder(ABC, Generic[_I]):
""" """
Abstract base class that constructs the dummy data to profile Abstract base class that constructs the dummy data to profile
...@@ -222,52 +203,3 @@ class BaseDummyInputsBuilder(ABC, Generic[_I]): ...@@ -222,52 +203,3 @@ class BaseDummyInputsBuilder(ABC, Generic[_I]):
height = min(height, overrides.height) height = min(height, overrides.height)
video = np.full((num_frames, width, height, 3), 255, dtype=np.uint8) video = np.full((num_frames, width, height, 3), 255, dtype=np.uint8)
return [video] * num_videos return [video] * num_videos
def get_dummy_mm_inputs(
self,
processor: BaseMultiModalProcessor[_I],
seq_len: int,
mm_counts: Mapping[str, int] | None = None,
mm_options: Mapping[str, BaseDummyOptions] | None = None,
) -> MultiModalInputs:
if mm_counts is None:
mm_counts = processor.allowed_mm_limits
processor_inputs = self.get_dummy_processor_inputs(
seq_len,
mm_counts=mm_counts,
mm_options=mm_options,
)
return processor.apply(
prompt=processor_inputs.prompt,
mm_data=processor_inputs.mm_data,
hf_processor_mm_kwargs=processor_inputs.hf_processor_mm_kwargs,
tokenization_kwargs=processor_inputs.tokenization_kwargs,
)
def get_decoder_dummy_data(
self,
processor: BaseMultiModalProcessor[_I],
seq_len: int,
mm_counts: Mapping[str, int] | None = None,
mm_options: Mapping[str, BaseDummyOptions] | None = None,
) -> DummyDecoderData:
mm_inputs = self.get_dummy_mm_inputs(
processor,
seq_len,
mm_counts=mm_counts,
mm_options=mm_options,
)
prompt_token_ids = mm_inputs["prompt_token_ids"]
total_len = len(prompt_token_ids)
if total_len < seq_len:
prompt_token_ids.extend([0] * (seq_len - total_len))
return DummyDecoderData(
prompt_token_ids=prompt_token_ids,
multi_modal_data=mm_inputs["mm_kwargs"].require_data(),
multi_modal_placeholders=mm_inputs["mm_placeholders"],
)
...@@ -10,15 +10,13 @@ from vllm.logger import init_logger ...@@ -10,15 +10,13 @@ from vllm.logger import init_logger
from vllm.tokenizers import TokenizerLike, cached_tokenizer_from_config from vllm.tokenizers import TokenizerLike, cached_tokenizer_from_config
from .cache import BaseMultiModalProcessorCache from .cache import BaseMultiModalProcessorCache
from .inputs import MultiModalInputs
from .processing import ( from .processing import (
BaseMultiModalProcessor, BaseMultiModalProcessor,
BaseProcessingInfo, BaseProcessingInfo,
InputProcessingContext, InputProcessingContext,
) )
from .profiling import ( from .profiling import BaseDummyInputsBuilder
BaseDummyInputsBuilder,
DummyDecoderData,
)
if TYPE_CHECKING: if TYPE_CHECKING:
from vllm.config import ModelConfig, ObservabilityConfig from vllm.config import ModelConfig, ObservabilityConfig
...@@ -160,7 +158,6 @@ class MultiModalRegistry: ...@@ -160,7 +158,6 @@ class MultiModalRegistry:
model_config, observability_config, cache=cache model_config, observability_config, cache=cache
) )
seq_len = model_config.max_model_len
if profiler_limits is None: if profiler_limits is None:
profiler_limits = processor.allowed_mm_limits profiler_limits = processor.allowed_mm_limits
...@@ -169,7 +166,7 @@ class MultiModalRegistry: ...@@ -169,7 +166,7 @@ class MultiModalRegistry:
} }
max_tokens_per_item = processor.info.get_mm_max_tokens_per_item( max_tokens_per_item = processor.info.get_mm_max_tokens_per_item(
seq_len=seq_len, seq_len=model_config.max_model_len,
mm_counts=mm_counts, mm_counts=mm_counts,
) )
if max_tokens_per_item is not None: if max_tokens_per_item is not None:
...@@ -179,11 +176,10 @@ class MultiModalRegistry: ...@@ -179,11 +176,10 @@ class MultiModalRegistry:
if mm_counts.get(modality, 0) > 0 if mm_counts.get(modality, 0) > 0
} }
mm_inputs = processor.dummy_inputs.get_dummy_mm_inputs( mm_inputs = self.get_dummy_mm_inputs(
processor, model_config,
seq_len,
mm_counts=mm_counts, mm_counts=mm_counts,
mm_options=self._extract_mm_options(model_config), processor=processor,
) )
return { return {
...@@ -298,39 +294,47 @@ class MultiModalRegistry: ...@@ -298,39 +294,47 @@ class MultiModalRegistry:
return factories.build_processor(ctx, cache=cache) return factories.build_processor(ctx, cache=cache)
def get_decoder_dummy_data( def get_dummy_mm_inputs(
self, self,
model_config: "ModelConfig", model_config: "ModelConfig",
seq_len: int,
mm_counts: Mapping[str, int] | None = None, mm_counts: Mapping[str, int] | None = None,
*, *,
cache: BaseMultiModalProcessorCache | None = None, cache: BaseMultiModalProcessorCache | None = None,
observability_config: ObservabilityConfig | None = None, observability_config: ObservabilityConfig | None = None,
) -> DummyDecoderData: processor: BaseMultiModalProcessor | None = None,
) -> MultiModalInputs:
""" """
Create dummy data for profiling the memory usage of a model. Create dummy data for profiling the memory usage of a model.
The model is identified by `model_config`. The model is identified by `model_config`.
""" """
seq_len = model_config.max_model_len
if processor is None:
processor = self.create_processor( processor = self.create_processor(
model_config, observability_config, cache=cache model_config, observability_config, cache=cache
) )
dummy_data = processor.dummy_inputs.get_decoder_dummy_data( if mm_counts is None:
processor, mm_counts = processor.allowed_mm_limits
seq_len,
processor_inputs = processor.dummy_inputs.get_dummy_processor_inputs(
seq_len=seq_len,
mm_counts=mm_counts, mm_counts=mm_counts,
mm_options=self._extract_mm_options(model_config), mm_options=self._extract_mm_options(model_config),
) )
mm_inputs = processor.apply(
# Having more tokens is over-conservative but otherwise fine prompt=processor_inputs.prompt,
token_ids = dummy_data.prompt_token_ids mm_data=processor_inputs.mm_data,
if len(token_ids) < seq_len: hf_processor_mm_kwargs=processor_inputs.hf_processor_mm_kwargs,
raise AssertionError( tokenization_kwargs=processor_inputs.tokenization_kwargs,
f"Expected at least {seq_len} dummy tokens for profiling, "
f"but found {len(token_ids)} tokens instead."
) )
return dummy_data prompt_token_ids = mm_inputs["prompt_token_ids"]
total_len = len(prompt_token_ids)
if total_len < seq_len:
prompt_token_ids.extend([0] * (seq_len - total_len))
return mm_inputs
def get_encdec_max_encoder_len(self, model_config: "ModelConfig") -> int: def get_encdec_max_encoder_len(self, model_config: "ModelConfig") -> int:
""" """
......
...@@ -4192,16 +4192,18 @@ class GPUModelRunner( ...@@ -4192,16 +4192,18 @@ class GPUModelRunner(
"""Dummy data for profiling and precompiling multimodal models.""" """Dummy data for profiling and precompiling multimodal models."""
assert self.mm_budget is not None assert self.mm_budget is not None
dummy_decoder_data = self.mm_registry.get_decoder_dummy_data( # Don't use `max_items_per_batch` here to avoid redundant computation
model_config=self.model_config, dummy_mm_inputs = self.mm_registry.get_dummy_mm_inputs(
seq_len=self.max_model_len, self.model_config,
mm_counts={modality: 1}, mm_counts={modality: 1},
cache=self.mm_budget.cache, cache=self.mm_budget.cache,
) )
dummy_mm_data = dummy_decoder_data.multi_modal_data dummy_mm_item = dummy_mm_inputs["mm_kwargs"][modality][0]
# We use the cache so that the item is saved to the cache,
# but not read from the cache
assert dummy_mm_item is not None, "Item should not already be cached"
# Result in the maximum GPU consumption of the model
dummy_mm_item = dummy_mm_data[modality][0]
dummy_mm_items = [dummy_mm_item] * max_items_per_batch dummy_mm_items = [dummy_mm_item] * max_items_per_batch
return next( return next(
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment