Unverified Commit 3f674a49 authored by Cyrus Leung's avatar Cyrus Leung Committed by GitHub
Browse files

[VLM][Core] Support profiling with multiple multi-modal inputs per prompt (#7126)

parent 70b746ef
......@@ -17,4 +17,4 @@ Input Processing Pipeline
6. If the data contains multi-modal data, convert it into keyword arguments using :meth:`MULTIMODAL_REGISTRY.map_input <vllm.multimodal.MultiModalRegistry.map_input>`.
- For example, convert a :class:`PIL.Image.Image` input to its pixel values for a vision language model.
- For example, convert a :class:`PIL.Image.Image` input to its pixel values for a vision model.
......@@ -15,6 +15,9 @@ by following :ref:`this guide <adding_multimodal_plugin>`.
Looking to add your own multi-modal model? Please follow the instructions listed :ref:`here <enabling_multimodal_inputs>`.
..
TODO: Add usage of --limit-mm-per-prompt when multi-image input is officially supported
Guides
++++++
......
......@@ -66,7 +66,7 @@ A default mapper is available for each modality in the core vLLM library. This i
3. Register maximum number of multi-modal tokens
------------------------------------------------
For each modality type that the model accepts as input, calculate the maximum possible number of tokens
For each modality type that the model accepts as input, calculate the maximum possible number of tokens per data instance
and register it via :meth:`INPUT_REGISTRY.register_dummy_data <vllm.inputs.registry.InputRegistry.register_max_multimodal_tokens>`.
.. code-block:: diff
......
import pytest
from vllm.engine.arg_utils import EngineArgs
from vllm.utils import FlexibleArgumentParser
@pytest.mark.parametrize(("arg", "expected"), [
(None, None),
("image=16", {
"image": 16
}),
("image=16,video=2", {
"image": 16,
"video": 2
}),
])
def test_limit_mm_per_prompt_parser(arg, expected):
parser = EngineArgs.add_cli_args(FlexibleArgumentParser())
if arg is None:
args = parser.parse_args([])
else:
args = parser.parse_args(["--limit-mm-per-prompt", arg])
assert args.limit_mm_per_prompt == expected
......@@ -59,7 +59,7 @@ def test_models(hf_runner, vllm_runner, image_assets, model, size_factors,
All the image fixtures for the test is under tests/images.
For huggingface runner, we provide the PIL images as input.
For vllm runner, we provide MultiModalData objects and corresponding
vision language config as input.
MultiModalConfig as input.
Note, the text input is also adjusted to abide by vllm contract.
The text output is sanitized to be able to compare with hf.
"""
......
......@@ -49,7 +49,7 @@ def run_test(
All the image fixtures for the test is under tests/images.
For huggingface runner, we provide the PIL images as input.
For vllm runner, we provide MultiModalDataDict objects
and corresponding vision language config as input.
and corresponding MultiModalConfig as input.
Note, the text input is also adjusted to abide by vllm contract.
The text output is sanitized to be able to compare with hf.
"""
......
......@@ -117,7 +117,7 @@ def run_test(
All the image fixtures for the test is under tests/images.
For huggingface runner, we provide the PIL images as input.
For vllm runner, we provide MultiModalDataDict objects
and corresponding vision language config as input.
and corresponding MultiModalConfig as input.
Note, the text input is also adjusted to abide by vllm contract.
The text output is sanitized to be able to compare with hf.
"""
......
......@@ -69,7 +69,7 @@ def run_test(
All the image fixtures for the test is under tests/images.
For huggingface runner, we provide the PIL images as input.
For vllm runner, we provide MultiModalDataDict objects
and corresponding vision language config as input.
and corresponding MultiModalConfig as input.
Note, the text input is also adjusted to abide by vllm contract.
The text output is sanitized to be able to compare with hf.
"""
......
......@@ -177,7 +177,7 @@ def test_models(hf_runner, vllm_runner, image_assets, model, size_factors,
All the image fixtures for the test is under tests/images.
For huggingface runner, we provide the PIL images as input.
For vllm runner, we provide MultiModalDataDict objects
and corresponding vision language config as input.
and corresponding MultiModalConfig as input.
Note, the text input is also adjusted to abide by vllm contract.
The text output is sanitized to be able to compare with hf.
"""
......
......@@ -61,7 +61,7 @@ def run_test(
All the image fixtures for the test is under tests/images.
For huggingface runner, we provide the PIL images as input.
For vllm runner, we provide MultiModalDataDict objects
and corresponding vision language config as input.
and corresponding MultiModalConfig as input.
Note, the text input is also adjusted to abide by vllm contract.
The text output is sanitized to be able to compare with hf.
"""
......@@ -176,7 +176,7 @@ def run_multi_image_test(
All the image fixtures for the test is under tests/images.
For huggingface runner, we provide the PIL images as input.
For vllm runner, we provide MultiModalDataDict objects
and corresponding vision language config as input.
and corresponding MultiModalConfig as input.
Note, the text input is also adjusted to abide by vllm contract.
The text output is sanitized to be able to compare with hf.
"""
......@@ -197,6 +197,7 @@ def run_multi_image_test(
with vllm_runner(model,
max_model_len=4096,
max_num_seqs=1,
limit_mm_per_prompt={"image": len(images)},
dtype=dtype,
tensor_parallel_size=tensor_parallel_size,
distributed_executor_backend=distributed_executor_backend,
......
......@@ -72,7 +72,7 @@ def run_test(
All the image fixtures for the test is under tests/images.
For huggingface runner, we provide the PIL images as input.
For vllm runner, we provide MultiModalDataDict objects
and corresponding vision language config as input.
and corresponding MultiModalConfig as input.
Note, the text input is also adjusted to abide by vllm contract.
The text output is sanitized to be able to compare with hf.
"""
......
......@@ -73,7 +73,7 @@ def run_test(
All the image fixtures for the test is under tests/images.
For huggingface runner, we provide the PIL images as input.
For vllm runner, we provide MultiModalDataDict objects
and corresponding vision language config as input.
and corresponding MultiModalConfig as input.
Note, the text input is also adjusted to abide by vllm contract.
The text output is sanitized to be able to compare with hf.
"""
......
from contextlib import nullcontext
import numpy as np
import pytest
from transformers import CLIPImageProcessor, LlavaNextImageProcessor
from vllm.config import ModelConfig
from vllm.multimodal import MULTIMODAL_REGISTRY
from vllm.config import ModelConfig, MultiModalConfig
from vllm.multimodal import MultiModalRegistry
from vllm.multimodal.utils import rescale_image_size
@pytest.fixture
def mm_registry():
return MultiModalRegistry()
@pytest.mark.parametrize("dtype", ["half", "float"])
@pytest.mark.parametrize("size_factor", [0.25, 0.5, 1.0])
def test_clip_image_processor(image_assets, dtype, size_factor):
def test_clip_image_processor(image_assets, mm_registry, dtype, size_factor):
MODEL_NAME = "llava-hf/llava-1.5-7b-hf"
hf_processor = CLIPImageProcessor.from_pretrained(MODEL_NAME)
......@@ -24,6 +31,9 @@ def test_clip_image_processor(image_assets, dtype, size_factor):
dtype=dtype,
revision=None,
)
mm_config = MultiModalConfig(limit_per_prompt={"image": 1})
mm_registry.init_mm_limits_per_prompt(model_config, mm_config)
for asset in image_assets:
image = rescale_image_size(asset.pil_image, size_factor)
......@@ -32,7 +42,7 @@ def test_clip_image_processor(image_assets, dtype, size_factor):
image,
return_tensors="pt",
)
vllm_result = MULTIMODAL_REGISTRY.map_input(
vllm_result = mm_registry.map_input(
model_config,
{"image": image},
)
......@@ -48,7 +58,8 @@ def test_clip_image_processor(image_assets, dtype, size_factor):
@pytest.mark.parametrize("dtype", ["half", "float"])
@pytest.mark.parametrize("size_factor", [0.25, 0.5, 1.0])
def test_llava_next_image_processor(image_assets, dtype, size_factor):
def test_llava_next_image_processor(image_assets, mm_registry, dtype,
size_factor):
MODEL_NAME = "llava-hf/llava-v1.6-vicuna-7b-hf"
hf_processor = LlavaNextImageProcessor.from_pretrained(MODEL_NAME)
......@@ -63,6 +74,9 @@ def test_llava_next_image_processor(image_assets, dtype, size_factor):
dtype=dtype,
revision=None,
)
mm_config = MultiModalConfig(limit_per_prompt={"image": 1})
mm_registry.init_mm_limits_per_prompt(model_config, mm_config)
for asset in image_assets:
image = rescale_image_size(asset.pil_image, size_factor)
......@@ -71,7 +85,7 @@ def test_llava_next_image_processor(image_assets, dtype, size_factor):
image,
return_tensors="pt",
)
vllm_result = MULTIMODAL_REGISTRY.map_input(
vllm_result = mm_registry.map_input(
model_config,
{"image": image},
)
......@@ -83,3 +97,61 @@ def test_llava_next_image_processor(image_assets, dtype, size_factor):
assert hf_arr.shape == vllm_arr.shape, f"Failed for key={key}"
assert np.allclose(hf_arr, vllm_arr), f"Failed for key={key}"
@pytest.mark.parametrize(
("num_images", "limit", "is_valid"),
[(0, 0, True), (0, 1, True), (1, 0, False), (1, 1, True), (1, 2, True),
(2, 1, False), (2, 2, True)],
)
def test_mm_limits(image_assets, mm_registry, num_images, limit, is_valid):
MODEL_NAME = "llava-hf/llava-1.5-7b-hf"
model_config = ModelConfig(
model=MODEL_NAME,
tokenizer=MODEL_NAME,
tokenizer_mode="auto",
trust_remote_code=False,
seed=0,
dtype="half",
revision=None,
)
mm_config = MultiModalConfig(limit_per_prompt={"image": limit})
mm_registry.init_mm_limits_per_prompt(model_config, mm_config)
image = image_assets[0].pil_image
if num_images == 0:
mm_inputs = {}
elif num_images == 1:
mm_inputs = {"image": image}
else:
mm_inputs = {"image": [image] * num_images}
with nullcontext() if is_valid else pytest.raises(ValueError):
mm_registry.map_input(model_config, mm_inputs)
# NOTE: We don't test zero images since the HF processor doesn't support it
@pytest.mark.parametrize("num_images", [1, 2])
def test_image_mapper_multi(image_assets, mm_registry, num_images):
MODEL_NAME = "llava-hf/llava-1.5-7b-hf"
model_config = ModelConfig(
model=MODEL_NAME,
tokenizer=MODEL_NAME,
tokenizer_mode="auto",
trust_remote_code=False,
seed=0,
dtype="half",
revision=None,
)
mm_config = MultiModalConfig(limit_per_prompt={"image": num_images})
mm_registry.init_mm_limits_per_prompt(model_config, mm_config)
image = image_assets[0].pil_image
mm_inputs = {"image": [image] * num_images}
mapped_inputs = mm_registry.map_input(model_config, mm_inputs)
assert len(mapped_inputs["pixel_values"]) == num_images
import enum
import json
from dataclasses import dataclass, field, fields
from typing import TYPE_CHECKING, ClassVar, List, Optional, Tuple, Type, Union
from typing import (TYPE_CHECKING, ClassVar, List, Mapping, Optional, Tuple,
Type, Union)
import torch
from transformers import PretrainedConfig
......@@ -1429,10 +1430,15 @@ class PromptAdapterConfig:
@dataclass
class MultiModalConfig:
"""Configs the input data format and how models should run for
multimodal models."""
"""Controls the behavior of multimodal models."""
limit_per_prompt: Mapping[str, int]
"""
The maximum number of multi-modal input instances allowed per prompt
for each :class:`~vllm.multimodal.MultiModalPlugin`.
"""
# TODO: Add configs to init vision tower or not.
pass
_STR_DTYPE_TO_TORCH_DTYPE = {
......
......@@ -2,7 +2,8 @@ import argparse
import dataclasses
import json
from dataclasses import dataclass
from typing import TYPE_CHECKING, List, Optional, Tuple, Type, Union
from typing import (TYPE_CHECKING, Dict, List, Mapping, Optional, Tuple, Type,
Union)
from vllm.config import (CacheConfig, DecodingConfig, DeviceConfig,
EngineConfig, LoadConfig, LoRAConfig, ModelConfig,
......@@ -15,8 +16,7 @@ from vllm.model_executor.layers.quantization import QUANTIZATION_METHODS
from vllm.utils import FlexibleArgumentParser
if TYPE_CHECKING:
from vllm.transformers_utils.tokenizer_group.base_tokenizer_group import (
BaseTokenizerGroup)
from vllm.transformers_utils.tokenizer_group import BaseTokenizerGroup
logger = init_logger(__name__)
......@@ -29,11 +29,32 @@ def nullable_str(val: str):
return val
def nullable_kvs(val: str) -> Optional[Mapping[str, int]]:
if len(val) == 0:
return None
out_dict: Dict[str, int] = {}
for item in val.split(","):
try:
key, value = item.split("=")
except TypeError as exc:
msg = "Each item should be in the form KEY=VALUE"
raise ValueError(msg) from exc
try:
out_dict[key] = int(value)
except ValueError as exc:
msg = f"Failed to parse value of item {key}={value}"
raise ValueError(msg) from exc
return out_dict
@dataclass
class EngineArgs:
"""Arguments for vLLM engine."""
model: str = 'facebook/opt-125m'
served_model_name: Optional[Union[List[str]]] = None
served_model_name: Optional[Union[str, List[str]]] = None
tokenizer: Optional[str] = None
skip_tokenizer_init: bool = False
tokenizer_mode: str = 'auto'
......@@ -81,6 +102,7 @@ class EngineArgs:
# notice.
tokenizer_pool_type: Union[str, Type["BaseTokenizerGroup"]] = "ray"
tokenizer_pool_extra_config: Optional[dict] = None
limit_mm_per_prompt: Optional[Mapping[str, int]] = None
enable_lora: bool = False
max_loras: int = 1
max_lora_rank: int = 16
......@@ -435,6 +457,21 @@ class EngineArgs:
'This should be a JSON string that will be '
'parsed into a dictionary. Ignored if '
'tokenizer_pool_size is 0.')
# Multimodal related configs
parser.add_argument(
'--limit-mm-per-prompt',
type=nullable_kvs,
default=EngineArgs.limit_mm_per_prompt,
# The default value is given in
# MultiModalRegistry.init_mm_limits_per_prompt
help=('For each multimodal plugin, limit how many '
'input instances to allow for each prompt. '
'Expects a comma-separated list of items, '
'e.g.: `image=16,video=2` allows a maximum of 16 '
'images and 2 videos per prompt. Defaults to 1 for '
'each modality.'))
# LoRA related configs
parser.add_argument('--enable-lora',
action='store_true',
......@@ -709,7 +746,8 @@ class EngineArgs:
"CPU offload space must be non-negative"
f", but got {self.cpu_offload_gb}")
multimodal_config = MultiModalConfig()
multimodal_config = MultiModalConfig(
limit_per_prompt=self.limit_mm_per_prompt or {})
device_config = DeviceConfig(device=self.device)
model_config = ModelConfig(
......
......@@ -24,8 +24,9 @@ from vllm.engine.output_processor.stop_checker import StopChecker
from vllm.engine.output_processor.util import create_output_by_sequence_group
from vllm.executor.executor_base import ExecutorBase
from vllm.executor.ray_utils import initialize_ray_cluster
from vllm.inputs import (INPUT_REGISTRY, EncoderDecoderLLMInputs, LLMInputs,
PromptInputs, SingletonPromptInputs)
from vllm.inputs import (INPUT_REGISTRY, EncoderDecoderLLMInputs,
InputRegistry, LLMInputs, PromptInputs,
SingletonPromptInputs)
from vllm.inputs.parse import is_explicit_encoder_decoder_prompt
from vllm.logger import init_logger
from vllm.lora.request import LoRARequest
......@@ -180,6 +181,7 @@ class LLMEngine:
log_stats: bool,
usage_context: UsageContext = UsageContext.ENGINE_CONTEXT,
stat_loggers: Optional[Dict[str, StatLoggerBase]] = None,
input_registry: InputRegistry = INPUT_REGISTRY,
) -> None:
logger.info(
"Initializing an LLM engine (v%s) with config: "
......@@ -265,8 +267,9 @@ class LLMEngine:
self.generation_config_fields = _load_generation_config_dict(
model_config)
self.input_processor = INPUT_REGISTRY.create_input_processor(
self.model_config)
self.input_registry = input_registry
self.input_processor = input_registry.create_input_processor(
model_config)
self.model_executor = executor_class(
model_config=model_config,
......
import functools
from collections import UserDict
from dataclasses import dataclass
from typing import TYPE_CHECKING, Callable, Dict, Optional, Tuple, Type
from typing import (TYPE_CHECKING, Callable, Dict, Mapping, Optional, Protocol,
Tuple, Type)
from torch import nn
from transformers import PretrainedConfig
......@@ -12,7 +14,7 @@ from .data import LLMInputs
if TYPE_CHECKING:
from vllm.config import ModelConfig, MultiModalConfig
from vllm.multimodal import MultiModalDataDict
from vllm.multimodal import MultiModalDataDict, MultiModalRegistry
from vllm.sequence import SequenceData
logger = init_logger(__name__)
......@@ -65,15 +67,38 @@ class InputContext:
N = TypeVar("N", bound=Type[nn.Module])
DummyDataFactory = Callable[[InputContext, int],
Tuple["SequenceData",
Optional["MultiModalDataDict"]]]
"""
Create dummy data to be inputted into the model.
Note:
class DummyDataFactory(Protocol):
def __call__(
self,
ctx: InputContext,
seq_len: int,
mm_counts: Mapping[str, int],
) -> Tuple["SequenceData", Optional["MultiModalDataDict"]]:
"""
Create dummy data to be inputted into the model.
Note:
:data:`InputProcessor` is not applied to the dummy data.
"""
"""
...
class _MultiModalCounts(UserDict):
"""
Wraps `mm_counts` for a more informative error message
when attempting to access a plugin that does not exist.
"""
def __getitem__(self, key: str) -> int:
try:
return super().__getitem__(key)
except KeyError as exc:
msg = (f"There is no multi-modal plugin with the key: {key}. "
f"Available keys: {set(self.keys())}")
raise KeyError(msg) from exc
InputProcessor = Callable[[InputContext, LLMInputs], LLMInputs]
"""Preprocess the inputs to the model."""
......@@ -95,6 +120,7 @@ class InputRegistry:
self,
ctx: InputContext,
seq_len: int,
mm_counts: Mapping[str, int],
) -> Tuple["SequenceData", Optional["MultiModalDataDict"]]:
"""
The default dummy data factory represents the longest possible text
......@@ -133,8 +159,12 @@ class InputRegistry:
return wrapper
def dummy_data_for_profiling(self, model_config: "ModelConfig",
seq_len: int):
def dummy_data_for_profiling(
self,
model_config: "ModelConfig",
seq_len: int,
mm_registry: "MultiModalRegistry",
) -> Tuple["SequenceData", Optional["MultiModalDataDict"]]:
"""
Create dummy data for profiling the memory usage of a model.
......@@ -142,6 +172,10 @@ class InputRegistry:
See also:
:ref:`enabling_multimodal_inputs`
Note:
This should be called after
:meth:`~MultiModalRegistry.init_mm_limits_per_prompt`.
"""
# Avoid circular import
from vllm.model_executor.model_loader import get_model_architecture
......@@ -149,8 +183,29 @@ class InputRegistry:
model_cls, _ = get_model_architecture(model_config)
dummy_factory = self._dummy_factories_by_model_type \
.get(model_cls, self._default_dummy_data_factory)
return dummy_factory(InputContext(model_config), seq_len)
mm_counts = mm_registry.get_mm_limits_per_prompt(model_config)
seq_data, mm_data = dummy_factory(
InputContext(model_config),
seq_len,
_MultiModalCounts(mm_counts),
)
# Having more tokens is over-conservative but otherwise fine
num_tokens = seq_data.prompt_token_ids
assert len(num_tokens) >= seq_len, (
f"Expected at least {seq_len} dummy tokens for profiling, "
f"but found {len(num_tokens)} tokens instead.")
if mm_data is not None:
for k, v in mm_data.items():
num_items = len(v) if isinstance(v, list) else 1
num_expected = mm_counts[k]
assert num_items >= num_expected, (
f"Expected at least {num_expected} dummy '{k}' instances "
f"for profiling, but found {num_items} instances instead.")
return seq_data, mm_data
def _default_input_processor(self, ctx: InputContext,
inputs: LLMInputs) -> LLMInputs:
......
......@@ -133,7 +133,7 @@ def _get_model_initialization_kwargs(
if supports_multimodal(model_class):
if multimodal_config is None:
raise ValueError("Provide vision related configurations "
raise ValueError("Provide multi-modal related configurations "
"through LLM entrypoint or engine arguments.")
extra_kwargs["multimodal_config"] = multimodal_config
......
......@@ -31,13 +31,13 @@ def get_blip_num_patches(*, image_size: int, patch_size: int) -> int:
def get_blip_image_feature_size(
hf_config: Union[BlipVisionConfig, Blip2VisionConfig], ) -> int:
hf_config: Union[BlipVisionConfig, Blip2VisionConfig]) -> int:
return get_blip_num_patches(image_size=hf_config.image_size,
patch_size=hf_config.patch_size)
def get_max_blip_image_tokens(
hf_config: Union[BlipVisionConfig, Blip2VisionConfig], ) -> int:
hf_config: Union[BlipVisionConfig, Blip2VisionConfig]) -> int:
return get_blip_image_feature_size(hf_config)
......@@ -60,6 +60,7 @@ def dummy_seq_data_for_blip(
def dummy_image_for_blip(
hf_config: Union[BlipVisionConfig, Blip2VisionConfig],
num_images: int,
*,
image_width_override: Optional[int] = None,
image_height_override: Optional[int] = None,
......@@ -71,7 +72,7 @@ def dummy_image_for_blip(
height = image_height_override
image = Image.new("RGB", (width, height), color=0)
return {"image": image}
return {"image": image if num_images == 1 else [image] * num_images}
def input_processor_for_blip(
......
from typing import Iterable, List, Literal, Optional, Tuple, TypedDict, Union
from typing import (Iterable, List, Literal, Mapping, Optional, Tuple,
TypedDict, Union)
import torch
import torch.nn as nn
......@@ -413,17 +414,39 @@ def get_max_blip2_image_tokens(ctx: InputContext):
raise NotImplementedError(msg)
def dummy_data_for_blip2(ctx: InputContext, seq_len: int):
def dummy_seq_data_for_blip2(
hf_config: Blip2Config,
seq_len: int,
num_images: int,
*,
image_token_id: int,
image_feature_size_override: Optional[int] = None,
):
if image_feature_size_override is None:
image_feature_size = get_blip2_image_feature_size(hf_config)
else:
image_feature_size = image_feature_size_override
token_ids = [image_token_id] * image_feature_size * num_images
token_ids += [0] * (seq_len - image_feature_size * num_images)
return SequenceData(token_ids)
def dummy_data_for_blip2(ctx: InputContext, seq_len: int,
mm_counts: Mapping[str, int]):
hf_config = ctx.get_hf_config(Blip2Config)
vision_config = hf_config.vision_config
num_images = mm_counts["image"]
image_feature_size = get_blip2_image_feature_size(hf_config)
token_ids = [BLIP2_IMAGE_TOKEN_ID] * image_feature_size
token_ids += [0] * (seq_len - image_feature_size)
seq_data = SequenceData(token_ids)
seq_data = dummy_seq_data_for_blip2(
hf_config,
seq_len,
num_images,
image_token_id=BLIP2_IMAGE_TOKEN_ID,
)
if isinstance(vision_config, Blip2VisionConfig):
mm_data = dummy_image_for_blip(vision_config)
mm_data = dummy_image_for_blip(vision_config, num_images)
return seq_data, mm_data
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment