Unverified Commit 93abf23a authored by Cyrus Leung's avatar Cyrus Leung Committed by GitHub
Browse files

[VLM] Fully dynamic prompt replacement in merged input processor (#11199)


Signed-off-by: default avatarDarkLight1337 <tlleungac@connect.ust.hk>
parent 9c3dadd1
...@@ -97,9 +97,6 @@ def run_phi3v(question: str, modality: str): ...@@ -97,9 +97,6 @@ def run_phi3v(question: str, modality: str):
# max_model_len (128k) for this model may cause OOM. # max_model_len (128k) for this model may cause OOM.
# You may lower either to run this example on lower-end GPUs. # You may lower either to run this example on lower-end GPUs.
# In this example, we override max_num_seqs to 5 while
# keeping the original context length of 128k.
# num_crops is an override kwarg to the multimodal image processor; # num_crops is an override kwarg to the multimodal image processor;
# For some models, e.g., Phi-3.5-vision-instruct, it is recommended # For some models, e.g., Phi-3.5-vision-instruct, it is recommended
# to use 16 for single frame scenarios, and 4 for multi-frame. # to use 16 for single frame scenarios, and 4 for multi-frame.
...@@ -113,7 +110,7 @@ def run_phi3v(question: str, modality: str): ...@@ -113,7 +110,7 @@ def run_phi3v(question: str, modality: str):
# https://huggingface.co/microsoft/Phi-3.5-vision-instruct#loading-the-model-locally # https://huggingface.co/microsoft/Phi-3.5-vision-instruct#loading-the-model-locally
# https://huggingface.co/microsoft/Phi-3.5-vision-instruct/blob/main/processing_phi3_v.py#L194 # https://huggingface.co/microsoft/Phi-3.5-vision-instruct/blob/main/processing_phi3_v.py#L194
llm = LLM( llm = LLM(
model="microsoft/Phi-3-vision-128k-instruct", model="microsoft/Phi-3.5-vision-instruct",
trust_remote_code=True, trust_remote_code=True,
max_model_len=4096, max_model_len=4096,
max_num_seqs=2, max_num_seqs=2,
......
...@@ -16,8 +16,8 @@ models = ["microsoft/Phi-3.5-vision-instruct"] ...@@ -16,8 +16,8 @@ models = ["microsoft/Phi-3.5-vision-instruct"]
# Wrap lazy imports to avoid initializing CUDA during test collection # Wrap lazy imports to avoid initializing CUDA during test collection
@pytest.fixture() @pytest.fixture()
def processor_for_phi3v(): def processor_for_phi3v():
from vllm.model_executor.models.phi3v import Phi3VProcessor from vllm.model_executor.models.phi3v import Phi3VMultiModalProcessor
return Phi3VProcessor return Phi3VMultiModalProcessor
@pytest.fixture() @pytest.fixture()
......
from typing import cast from typing import cast
import pytest import pytest
from transformers import BatchFeature
from vllm.multimodal.processing import (PromptReplacement, _PlaceholderInfo, from vllm.multimodal.processing import (MultiModalDataItems, PromptReplacement,
find_text_matches, find_token_matches, _PlaceholderInfo, find_text_matches,
iter_placeholders, iter_token_matches, find_token_matches, iter_placeholders,
iter_token_matches,
replace_text_matches, replace_text_matches,
replace_token_matches) replace_token_matches)
from vllm.transformers_utils.tokenizer import AnyTokenizer from vllm.transformers_utils.tokenizer import AnyTokenizer
...@@ -16,7 +16,7 @@ from vllm.utils import full_groupby ...@@ -16,7 +16,7 @@ from vllm.utils import full_groupby
@pytest.mark.parametrize( @pytest.mark.parametrize(
("token_ids", "match_ids", "expected"), ("token_ids", "match_ids", "expected"),
[ [
([], [], [{ "start_idx": 0, "end_idx": 0 }]), ([], [], []),
([], [32000], []), ([], [32000], []),
( (
[32000, 32000, 32000], [32000, 32000, 32000],
...@@ -83,7 +83,7 @@ def test_iter_token_matches(token_ids, match_ids, expected): ...@@ -83,7 +83,7 @@ def test_iter_token_matches(token_ids, match_ids, expected):
"pattern_2": [32000], "pattern_2": [32000],
}, },
{ {
"pattern_1": [{ "start_idx": 0, "end_idx": 0 }], "pattern_1": [],
"pattern_2": [], "pattern_2": [],
} }
), ),
...@@ -136,7 +136,7 @@ def test_find_token_matches(prompt, target_by_key, expected_by_key): ...@@ -136,7 +136,7 @@ def test_find_token_matches(prompt, target_by_key, expected_by_key):
mock_tokenizer = cast(AnyTokenizer, object()) mock_tokenizer = cast(AnyTokenizer, object())
prompt_repls = [ prompt_repls = [
PromptReplacement(target, [], 0).bind(key, mock_tokenizer) PromptReplacement(key, target, []).bind(mock_tokenizer)
for key, target in target_by_key.items() for key, target in target_by_key.items()
] ]
result = find_token_matches(prompt, prompt_repls) result = find_token_matches(prompt, prompt_repls)
...@@ -243,7 +243,7 @@ def test_find_text_matches(prompt, target_by_key, expected_by_key): ...@@ -243,7 +243,7 @@ def test_find_text_matches(prompt, target_by_key, expected_by_key):
mock_tokenizer = cast(AnyTokenizer, object()) mock_tokenizer = cast(AnyTokenizer, object())
prompt_repls = [ prompt_repls = [
PromptReplacement(target, [], 0).bind(key, mock_tokenizer) PromptReplacement(key, target, []).bind(mock_tokenizer)
for key, target in target_by_key.items() for key, target in target_by_key.items()
] ]
result = find_text_matches(prompt, prompt_repls) result = find_text_matches(prompt, prompt_repls)
...@@ -276,12 +276,12 @@ def test_find_text_matches(prompt, target_by_key, expected_by_key): ...@@ -276,12 +276,12 @@ def test_find_text_matches(prompt, target_by_key, expected_by_key):
"pattern_3": "!", "pattern_3": "!",
}, },
{ {
# Test whether target is confused with repl_unit # Test whether target is confused with replacement
"pattern_1": ("<image><image>", 1), "pattern_1": "<image><image>",
# Test empty repl_unit # Test empty replacement
"pattern_2": ("", 1), "pattern_2": "",
# Test multiple repl_count # Test dynamic replacement (beyond the form of `unit * count`)
"pattern_3": ("?", 2), "pattern_3": "?!?",
}, },
), ),
] ]
...@@ -290,8 +290,8 @@ def test_find_text_matches(prompt, target_by_key, expected_by_key): ...@@ -290,8 +290,8 @@ def test_find_text_matches(prompt, target_by_key, expected_by_key):
("mm_count", "expected"), ("mm_count", "expected"),
[ [
(0, "Image:<image>Image:<image><image>!"), (0, "Image:<image>Image:<image><image>!"),
(1, "<image><image>Image:<image><image>??"), (1, "<image><image>Image:<image><image>?!?"),
(2, "<image><image><image><image><image>??"), (2, "<image><image><image><image><image>?!?"),
] ]
) )
# yapf: enable # yapf: enable
...@@ -306,7 +306,7 @@ def test_find_replace_text( ...@@ -306,7 +306,7 @@ def test_find_replace_text(
mock_tokenizer = cast(AnyTokenizer, object()) mock_tokenizer = cast(AnyTokenizer, object())
prompt_repls = [ prompt_repls = [
PromptReplacement(target, *repl_by_key[key]).bind(key, mock_tokenizer) PromptReplacement(key, target, repl_by_key[key]).bind(mock_tokenizer)
for key, target in target_by_key.items() for key, target in target_by_key.items()
] ]
matches = find_text_matches(prompt, prompt_repls) matches = find_text_matches(prompt, prompt_repls)
...@@ -314,9 +314,8 @@ def test_find_replace_text( ...@@ -314,9 +314,8 @@ def test_find_replace_text(
result = replace_text_matches( result = replace_text_matches(
prompt, prompt,
matches, matches,
{key: list(range(mm_count)) MultiModalDataItems({key: [None] * mm_count
for key in repl_by_key}, for key in repl_by_key}),
BatchFeature(),
) )
# Only displayed on error # Only displayed on error
...@@ -343,12 +342,12 @@ def test_find_replace_text( ...@@ -343,12 +342,12 @@ def test_find_replace_text(
"pattern_3": [918], "pattern_3": [918],
}, },
{ {
# Test whether target is confused with repl_unit # Test whether target is confused with replacement
"pattern_1": ([32000, 32000], 1), "pattern_1": [32000, 32000],
# Test empty repl_unit # Test empty replacement
"pattern_2": ([], 1), "pattern_2": [],
# Test multiple repl_count # Test dynamic replacement (beyond the form of `unit * count`)
"pattern_3": ([1550], 2), "pattern_3": [1550, 918, 1550],
}, },
), ),
] ]
...@@ -357,8 +356,8 @@ def test_find_replace_text( ...@@ -357,8 +356,8 @@ def test_find_replace_text(
("mm_count", "expected"), ("mm_count", "expected"),
[ [
(0, [1, 9833, 28747, 32000, 9833, 28747, 32000, 32000, 918]), (0, [1, 9833, 28747, 32000, 9833, 28747, 32000, 32000, 918]),
(1, [1, 32000, 32000, 9833, 28747, 32000, 32000, 1550, 1550]), (1, [1, 32000, 32000, 9833, 28747, 32000, 32000, 1550, 918, 1550]),
(2, [1, 32000, 32000, 32000, 32000, 32000, 1550, 1550]), (2, [1, 32000, 32000, 32000, 32000, 32000, 1550, 918, 1550]),
] ]
) )
# yapf: enable # yapf: enable
...@@ -373,7 +372,7 @@ def test_find_replace_tokens( ...@@ -373,7 +372,7 @@ def test_find_replace_tokens(
mock_tokenizer = cast(AnyTokenizer, object()) mock_tokenizer = cast(AnyTokenizer, object())
prompt_repls = [ prompt_repls = [
PromptReplacement(target, *repl_by_key[key]).bind(key, mock_tokenizer) PromptReplacement(key, target, repl_by_key[key]).bind(mock_tokenizer)
for key, target in target_by_key.items() for key, target in target_by_key.items()
] ]
matches = find_token_matches(prompt, prompt_repls) matches = find_token_matches(prompt, prompt_repls)
...@@ -381,9 +380,8 @@ def test_find_replace_tokens( ...@@ -381,9 +380,8 @@ def test_find_replace_tokens(
result = replace_token_matches( result = replace_token_matches(
prompt, prompt,
matches, matches,
{key: list(range(mm_count)) MultiModalDataItems({key: [None] * mm_count
for key in repl_by_key}, for key in repl_by_key}),
BatchFeature(),
) )
# Only displayed on error # Only displayed on error
...@@ -399,9 +397,9 @@ def test_find_replace_tokens( ...@@ -399,9 +397,9 @@ def test_find_replace_tokens(
"repl_by_key", "repl_by_key",
[ [
{ {
"pattern_1": ([32000, 32000], 1), "pattern_1": [32000, 32000],
"pattern_2": ([], 1), "pattern_2": [],
"pattern_3": ([1550], 2), "pattern_3": [1550, 918, 1550],
}, },
], ],
) )
...@@ -414,48 +412,47 @@ def test_find_replace_tokens( ...@@ -414,48 +412,47 @@ def test_find_replace_tokens(
_PlaceholderInfo( _PlaceholderInfo(
modality="pattern_1", modality="pattern_1",
start_idx=6, start_idx=6,
unit=[32000, 32000], replacement=[32000, 32000],
unit_count=1,
), ),
], ],
), ),
( (
[1, 32000, 32000, 9833, 28747, 32000, 32000, 1550, 1550], [1, 32000, 32000, 9833, 28747, 32000, 32000, 1550, 918, 1550],
[ [
_PlaceholderInfo( _PlaceholderInfo(
modality="pattern_1", modality="pattern_1",
start_idx=1, start_idx=1,
unit=[32000, 32000], replacement=[32000, 32000],
unit_count=1,
), ),
_PlaceholderInfo( _PlaceholderInfo(
modality="pattern_1", modality="pattern_1",
start_idx=5, start_idx=5,
unit=[32000, 32000], replacement=[32000, 32000],
unit_count=1,
), ),
_PlaceholderInfo( _PlaceholderInfo(
modality="pattern_3", modality="pattern_3",
start_idx=7, start_idx=7,
unit=[1550], replacement=[1550, 918, 1550],
unit_count=2,
), ),
], ],
), ),
( (
[1, 32000, 32000, 32000, 32000, 32000, 1550, 1550], [1, 32000, 32000, 32000, 32000, 32000, 1550, 918, 1550],
[ [
_PlaceholderInfo( _PlaceholderInfo(
modality="pattern_1", modality="pattern_1",
start_idx=1, start_idx=1,
unit=[32000, 32000], replacement=[32000, 32000],
unit_count=2, ),
_PlaceholderInfo(
modality="pattern_1",
start_idx=3,
replacement=[32000, 32000],
), ),
_PlaceholderInfo( _PlaceholderInfo(
modality="pattern_3", modality="pattern_3",
start_idx=6, start_idx=6,
unit=[1550], replacement=[1550, 918, 1550],
unit_count=2,
), ),
], ],
), ),
...@@ -470,11 +467,17 @@ def test_iter_placeholders( ...@@ -470,11 +467,17 @@ def test_iter_placeholders(
mock_tokenizer = cast(AnyTokenizer, object()) mock_tokenizer = cast(AnyTokenizer, object())
prompt_repls = [ prompt_repls = [
PromptReplacement([], *repl).bind(key, mock_tokenizer) PromptReplacement(key, [], repl).bind(mock_tokenizer)
for key, repl in repl_by_key.items() for key, repl in repl_by_key.items()
] ]
result = list(iter_placeholders(prompt_repls, prompt)) result = list(
iter_placeholders(
prompt_repls,
prompt,
# Effectively match all occurrences in the prompt
MultiModalDataItems({key: [None] * 3 for key in repl_by_key}),
))
# Only displayed on error # Only displayed on error
print("result:", result) print("result:", result)
......
...@@ -3,14 +3,14 @@ from typing import Optional ...@@ -3,14 +3,14 @@ from typing import Optional
import torch import torch
from vllm.model_executor.models.llava import (LlavaForConditionalGeneration, from vllm.model_executor.models.llava import (LlavaForConditionalGeneration,
LlavaProcessor, LlavaMultiModalProcessor,
get_max_llava_image_tokens) get_max_llava_image_tokens)
from vllm.model_executor.sampling_metadata import SamplingMetadata from vllm.model_executor.sampling_metadata import SamplingMetadata
from vllm.multimodal import MULTIMODAL_REGISTRY from vllm.multimodal import MULTIMODAL_REGISTRY
@MULTIMODAL_REGISTRY.register_max_image_tokens(get_max_llava_image_tokens) @MULTIMODAL_REGISTRY.register_max_image_tokens(get_max_llava_image_tokens)
@MULTIMODAL_REGISTRY.register_processor(LlavaProcessor) @MULTIMODAL_REGISTRY.register_processor(LlavaMultiModalProcessor)
class MyLlava(LlavaForConditionalGeneration): class MyLlava(LlavaForConditionalGeneration):
def compute_logits( def compute_logits(
......
...@@ -2,7 +2,7 @@ import functools ...@@ -2,7 +2,7 @@ import functools
from collections import UserDict from collections import UserDict
from dataclasses import dataclass from dataclasses import dataclass
from typing import (TYPE_CHECKING, Any, Callable, Dict, Mapping, NamedTuple, from typing import (TYPE_CHECKING, Any, Callable, Dict, Mapping, NamedTuple,
Optional, Protocol, Type, cast) Optional, Protocol, Type)
from torch import nn from torch import nn
from transformers import PretrainedConfig, ProcessorMixin from transformers import PretrainedConfig, ProcessorMixin
...@@ -47,7 +47,6 @@ class InputContext: ...@@ -47,7 +47,6 @@ class InputContext:
Raises: Raises:
TypeError: If the model is not of the specified type. TypeError: If the model is not of the specified type.
""" """
hf_config = self.model_config.hf_config hf_config = self.model_config.hf_config
if not isinstance(hf_config, hf_config_type): if not isinstance(hf_config, hf_config_type):
raise TypeError("Invalid type of HuggingFace config. " raise TypeError("Invalid type of HuggingFace config. "
...@@ -60,21 +59,70 @@ class InputContext: ...@@ -60,21 +59,70 @@ class InputContext:
""" """
Get the HuggingFace image processor configuration of the model. Get the HuggingFace image processor configuration of the model.
""" """
return self.model_config.hf_image_processor_config return self.model_config.hf_image_processor_config
def get_mm_config(self):
"""
Get the multimodal config of the model.
Raises:
RuntimeError: If the model is not a multimodal model.
"""
mm_config = self.model_config.multimodal_config
if mm_config is None:
raise RuntimeError("Not a multimodal model")
return mm_config
def get_hf_processor(self, **kwargs: object) -> ProcessorMixin:
base_kwargs = self.model_config.mm_processor_kwargs
if base_kwargs is None:
base_kwargs = {}
merged_kwargs = {**base_kwargs, **kwargs}
return cached_get_processor(
self.model_config.model,
trust_remote_code=self.model_config.trust_remote_code,
**merged_kwargs,
)
@dataclass(frozen=True) @dataclass(frozen=True)
class InputProcessingContext(InputContext): class InputProcessingContext(InputContext):
tokenizer: AnyTokenizer tokenizer: AnyTokenizer
"""The tokenizer used to tokenize the inputs.""" """The tokenizer used to tokenize the inputs."""
def get_hf_processor(self, **kwargs) -> ProcessorMixin: def get_hf_processor(self, **kwargs: object) -> ProcessorMixin:
base_kwargs = self.model_config.mm_processor_kwargs
if base_kwargs is None:
base_kwargs = {}
merged_kwargs = {**base_kwargs, **kwargs}
return cached_get_processor( return cached_get_processor(
self.model_config.tokenizer, self.model_config.model,
tokenizer=self.tokenizer, # Override the tokenizer with ours tokenizer=self.tokenizer, # Override the tokenizer with ours
trust_remote_code=self.model_config.trust_remote_code, trust_remote_code=self.model_config.trust_remote_code,
**kwargs) **merged_kwargs,
)
def resolve_hf_processor_call_kwargs(
self,
hf_processor: ProcessorMixin,
inference_kwargs: Mapping[str, object],
) -> Mapping[str, object]:
assert callable(hf_processor)
base_kwargs = self.model_config.mm_processor_kwargs
if base_kwargs is None:
base_kwargs = {}
return resolve_mm_processor_kwargs(
base_kwargs,
inference_kwargs,
hf_processor,
)
N = TypeVar("N", bound=Type[nn.Module]) N = TypeVar("N", bound=Type[nn.Module])
...@@ -171,7 +219,8 @@ class InputRegistry: ...@@ -171,7 +219,8 @@ class InputRegistry:
""" """
def wrapper(model_cls: N) -> N: def wrapper(model_cls: N) -> N:
if model_cls in self._dummy_factories_by_model_type: if self._dummy_factories_by_model_type.contains(model_cls,
strict=True):
logger.warning( logger.warning(
"Model class %s already has dummy data " "Model class %s already has dummy data "
"registered to %s. It is overwritten by the new one.", "registered to %s. It is overwritten by the new one.",
...@@ -195,7 +244,8 @@ class InputRegistry: ...@@ -195,7 +244,8 @@ class InputRegistry:
""" """
def wrapper(model_cls: N) -> N: def wrapper(model_cls: N) -> N:
if model_cls in self._dummy_encoder_factories_by_model_type: if self._dummy_encoder_factories_by_model_type.contains(
model_cls, strict=True):
logger.warning( logger.warning(
"Model class %s already has dummy encoder data " "Model class %s already has dummy encoder data "
"registered to %s. It is overwritten by the new one.", "registered to %s. It is overwritten by the new one.",
...@@ -305,7 +355,8 @@ class InputRegistry: ...@@ -305,7 +355,8 @@ class InputRegistry:
""" """
def wrapper(model_cls: N) -> N: def wrapper(model_cls: N) -> N:
if model_cls in self._input_processors_by_model_type: if self._input_processors_by_model_type.contains(model_cls,
strict=True):
logger.warning( logger.warning(
"Model class %s already has input processor " "Model class %s already has input processor "
"registered to %s. It is overwritten by the new one.", "registered to %s. It is overwritten by the new one.",
...@@ -357,7 +408,7 @@ class InputRegistry: ...@@ -357,7 +408,7 @@ class InputRegistry:
# If it's empty, it'll fall back to the default kwarg values # If it's empty, it'll fall back to the default kwarg values
mm_processor_kwargs = resolve_mm_processor_kwargs( mm_processor_kwargs = resolve_mm_processor_kwargs(
model_config.mm_processor_kwargs, model_config.mm_processor_kwargs,
cast(Dict[str, Any], inputs.get("mm_processor_kwargs")), inputs.get("mm_processor_kwargs", {}), # type: ignore
processor, processor,
) )
......
...@@ -5,10 +5,10 @@ from typing import (Iterable, List, Literal, Mapping, Optional, Protocol, Set, ...@@ -5,10 +5,10 @@ from typing import (Iterable, List, Literal, Mapping, Optional, Protocol, Set,
import torch import torch
import torch.nn as nn import torch.nn as nn
from PIL.Image import Image
from transformers import (BatchFeature, CLIPVisionConfig, LlavaConfig, from transformers import (BatchFeature, CLIPVisionConfig, LlavaConfig,
PixtralVisionConfig, PretrainedConfig, PixtralVisionConfig, PretrainedConfig,
ProcessorMixin, SiglipVisionConfig) ProcessorMixin, SiglipVisionConfig)
from transformers.models.llava import LlavaProcessor
from transformers.models.pixtral import PixtralProcessor from transformers.models.pixtral import PixtralProcessor
from vllm.attention import AttentionMetadata from vllm.attention import AttentionMetadata
...@@ -21,11 +21,9 @@ from vllm.model_executor.layers.quantization import QuantizationConfig ...@@ -21,11 +21,9 @@ from vllm.model_executor.layers.quantization import QuantizationConfig
from vllm.model_executor.layers.sampler import SamplerOutput, get_sampler from vllm.model_executor.layers.sampler import SamplerOutput, get_sampler
from vllm.model_executor.sampling_metadata import SamplingMetadata from vllm.model_executor.sampling_metadata import SamplingMetadata
from vllm.multimodal import MULTIMODAL_REGISTRY from vllm.multimodal import MULTIMODAL_REGISTRY
from vllm.multimodal.inputs import MultiModalKwargs, NestedTensors from vllm.multimodal.inputs import NestedTensors
from vllm.multimodal.processing import (BaseMultiModalProcessor, from vllm.multimodal.processing import (BaseMultiModalProcessor,
InputProcessingContext, MultiModalDataItems, ProcessorInputs,
ModalityProcessingMetadata,
MultiModalProcessingMetadata,
PromptReplacement) PromptReplacement)
from vllm.sequence import IntermediateTensors from vllm.sequence import IntermediateTensors
...@@ -33,7 +31,8 @@ from .clip import (CLIPVisionModel, dummy_image_for_clip, ...@@ -33,7 +31,8 @@ from .clip import (CLIPVisionModel, dummy_image_for_clip,
get_max_clip_image_tokens) get_max_clip_image_tokens)
from .interfaces import SupportsMultiModal, SupportsPP from .interfaces import SupportsMultiModal, SupportsPP
from .pixtral import (PixtralHFVisionModel, dummy_image_for_pixtral_hf, from .pixtral import (PixtralHFVisionModel, dummy_image_for_pixtral_hf,
get_max_pixtral_hf_image_tokens) get_max_pixtral_hf_image_tokens,
get_pixtral_hf_image_feature_size)
from .siglip import (SiglipVisionModel, dummy_image_for_siglip, from .siglip import (SiglipVisionModel, dummy_image_for_siglip,
get_max_siglip_image_tokens) get_max_siglip_image_tokens)
from .utils import (AutoWeightsLoader, flatten_bn, init_vllm_registered_model, from .utils import (AutoWeightsLoader, flatten_bn, init_vllm_registered_model,
...@@ -115,62 +114,7 @@ def get_max_llava_image_tokens(ctx: InputContext): ...@@ -115,62 +114,7 @@ def get_max_llava_image_tokens(ctx: InputContext):
raise ValueError(f"Unexpected select feature strategy: {strategy}") raise ValueError(f"Unexpected select feature strategy: {strategy}")
def dummy_mm_kwargs_for_llava(ctx: InputProcessingContext, class LlavaMultiModalProcessor(BaseMultiModalProcessor):
mm_counts: Mapping[str, int]):
hf_config = ctx.get_hf_config(LlavaConfig)
vision_config = hf_config.vision_config
num_images = mm_counts["image"]
if isinstance(vision_config, CLIPVisionConfig):
data = dummy_image_for_clip(vision_config, num_images)
elif isinstance(vision_config, SiglipVisionConfig):
data = dummy_image_for_siglip(vision_config, num_images)
elif isinstance(vision_config, PixtralVisionConfig):
data = dummy_image_for_pixtral_hf(vision_config, num_images)
else:
msg = f"Unsupported vision config: {type(vision_config)}"
raise NotImplementedError(msg)
hf_processor = ctx.get_hf_processor()
image_processor = hf_processor.image_processor # type: ignore
hf_inputs = image_processor.preprocess(data['image'], return_tensors="pt")
is_pixtral = isinstance(hf_processor, PixtralProcessor)
return MultiModalKwargs(
**hf_inputs,
is_pixtral=torch.tensor(is_pixtral),
)
def create_metadata_for_llava(
ctx: InputProcessingContext) -> MultiModalProcessingMetadata:
hf_config = ctx.get_hf_config(LlavaConfig)
image_token_id = hf_config.image_token_index
def get_repl_count(
mm_items: list[Image],
hf_inputs: BatchFeature,
item_idx: int,
) -> int:
return get_max_llava_image_tokens(ctx)
return {
"image":
ModalityProcessingMetadata(prompt_repls=[
PromptReplacement(target=[image_token_id],
repl_unit=[image_token_id],
repl_count=get_repl_count),
]),
}
class LlavaProcessor(BaseMultiModalProcessor):
def __init__(self, ctx: InputProcessingContext) -> None:
super().__init__(
ctx=ctx,
metadata=create_metadata_for_llava(ctx),
)
def _patch_pixtral_processor(self, hf_processor: PixtralProcessor): def _patch_pixtral_processor(self, hf_processor: PixtralProcessor):
if getattr(hf_processor, "__is_patched__", False): if getattr(hf_processor, "__is_patched__", False):
...@@ -188,18 +132,72 @@ class LlavaProcessor(BaseMultiModalProcessor): ...@@ -188,18 +132,72 @@ class LlavaProcessor(BaseMultiModalProcessor):
hf_processor.__is_patched__ = True # type: ignore hf_processor.__is_patched__ = True # type: ignore
def _get_hf_processor(self) -> ProcessorMixin: def _get_hf_processor(self) -> Union[LlavaProcessor, PixtralProcessor]:
hf_processor = self.ctx.get_hf_processor() hf_processor = self.ctx.get_hf_processor()
assert isinstance(hf_processor, (LlavaProcessor, PixtralProcessor))
if isinstance(hf_processor, PixtralProcessor): if isinstance(hf_processor, PixtralProcessor):
self._patch_pixtral_processor(hf_processor) self._patch_pixtral_processor(hf_processor)
return hf_processor return hf_processor
def _get_dummy_mm_kwargs( def _get_prompt_replacements(
self,
mm_items: MultiModalDataItems,
hf_inputs: BatchFeature,
mm_processor_kwargs: Mapping[str, object],
) -> list[PromptReplacement]:
hf_config = self.ctx.get_hf_config(LlavaConfig)
image_token_id = hf_config.image_token_index
processor = self._get_hf_processor()
if isinstance(processor, PixtralProcessor):
image_token = processor.image_token
image_break_token = processor.image_break_token
image_end_token = processor.image_end_token
vision_config = hf_config.vision_config
assert isinstance(vision_config, PixtralVisionConfig)
def get_replacement_pixtral(item_idx: int):
image_size = mm_items.get_image_size(item_idx)
(
num_width_tokens,
num_height_tokens,
) = get_pixtral_hf_image_feature_size(
vision_config,
image_width=image_size.width,
image_height=image_size.height,
)
tokens = ([image_token] * num_width_tokens +
[image_break_token]) * num_height_tokens
tokens[-1] = image_end_token
return "".join(tokens)
return [
PromptReplacement(
modality="image",
target=[image_token_id],
replacement=get_replacement_pixtral,
),
]
max_image_tokens = get_max_llava_image_tokens(self.ctx)
return [
PromptReplacement(
modality="image",
target=[image_token_id],
replacement=[image_token_id] * max_image_tokens,
)
]
def _get_dummy_mm_inputs(
self, self,
mm_counts: Mapping[str, int], mm_counts: Mapping[str, int],
) -> MultiModalKwargs: ) -> ProcessorInputs:
hf_config = self.ctx.get_hf_config(LlavaConfig) hf_config = self.ctx.get_hf_config(LlavaConfig)
vision_config = hf_config.vision_config vision_config = hf_config.vision_config
num_images = mm_counts["image"] num_images = mm_counts["image"]
...@@ -215,11 +213,13 @@ class LlavaProcessor(BaseMultiModalProcessor): ...@@ -215,11 +213,13 @@ class LlavaProcessor(BaseMultiModalProcessor):
raise NotImplementedError(msg) raise NotImplementedError(msg)
hf_processor = self._get_hf_processor() hf_processor = self._get_hf_processor()
image_processor = hf_processor.image_processor # type: ignore image_token = hf_processor.image_token
hf_inputs = image_processor.preprocess(data['image'],
return_tensors="pt")
return MultiModalKwargs(**hf_inputs) return ProcessorInputs(
prompt_text=image_token * num_images,
mm_data=data,
mm_processor_kwargs={},
)
class LlavaLikeConfig(Protocol): class LlavaLikeConfig(Protocol):
...@@ -303,7 +303,7 @@ def init_vision_tower_for_llava( ...@@ -303,7 +303,7 @@ def init_vision_tower_for_llava(
@MULTIMODAL_REGISTRY.register_max_image_tokens(get_max_llava_image_tokens) @MULTIMODAL_REGISTRY.register_max_image_tokens(get_max_llava_image_tokens)
@MULTIMODAL_REGISTRY.register_processor(LlavaProcessor) @MULTIMODAL_REGISTRY.register_processor(LlavaMultiModalProcessor)
class LlavaForConditionalGeneration(nn.Module, SupportsMultiModal, SupportsPP): class LlavaForConditionalGeneration(nn.Module, SupportsMultiModal, SupportsPP):
# BitandBytes specific attributes # BitandBytes specific attributes
bitsandbytes_stacked_params_mapping = { bitsandbytes_stacked_params_mapping = {
...@@ -584,7 +584,7 @@ class LlavaForConditionalGeneration(nn.Module, SupportsMultiModal, SupportsPP): ...@@ -584,7 +584,7 @@ class LlavaForConditionalGeneration(nn.Module, SupportsMultiModal, SupportsPP):
return loader.load_weights(weights) return loader.load_weights(weights)
class MantisProcessor(LlavaProcessor): class MantisMultiModalProcessor(LlavaMultiModalProcessor):
def _get_hf_processor(self) -> ProcessorMixin: def _get_hf_processor(self) -> ProcessorMixin:
try: try:
...@@ -604,6 +604,6 @@ class MantisProcessor(LlavaProcessor): ...@@ -604,6 +604,6 @@ class MantisProcessor(LlavaProcessor):
# To use this model, please use # To use this model, please use
# `--hf_overrides '{"architectures": ["MantisForConditionalGeneration"]}'` # `--hf_overrides '{"architectures": ["MantisForConditionalGeneration"]}'`
@MULTIMODAL_REGISTRY.register_max_image_tokens(get_max_llava_image_tokens) @MULTIMODAL_REGISTRY.register_max_image_tokens(get_max_llava_image_tokens)
@MULTIMODAL_REGISTRY.register_processor(MantisProcessor) @MULTIMODAL_REGISTRY.register_processor(MantisMultiModalProcessor)
class MantisForConditionalGeneration(LlavaForConditionalGeneration): class MantisForConditionalGeneration(LlavaForConditionalGeneration):
pass pass
...@@ -32,13 +32,10 @@ from vllm.model_executor.layers.vocab_parallel_embedding import ( ...@@ -32,13 +32,10 @@ from vllm.model_executor.layers.vocab_parallel_embedding import (
from vllm.model_executor.models.clip import CLIPVisionModel from vllm.model_executor.models.clip import CLIPVisionModel
from vllm.model_executor.sampling_metadata import SamplingMetadata from vllm.model_executor.sampling_metadata import SamplingMetadata
from vllm.multimodal import MULTIMODAL_REGISTRY from vllm.multimodal import MULTIMODAL_REGISTRY
from vllm.multimodal.image import cached_get_image_processor from vllm.multimodal.inputs import NestedTensors
from vllm.multimodal.inputs import MultiModalKwargs, NestedTensors
from vllm.multimodal.processing import (BaseMultiModalProcessor, from vllm.multimodal.processing import (BaseMultiModalProcessor,
InputProcessingContext,
ModalityProcessingMetadata,
MultiModalDataDict, MultiModalDataDict,
MultiModalProcessingMetadata, MultiModalDataItems, ProcessorInputs,
PromptReplacement) PromptReplacement)
from vllm.sequence import IntermediateTensors from vllm.sequence import IntermediateTensors
from vllm.utils import is_list_of from vllm.utils import is_list_of
...@@ -305,64 +302,17 @@ class Phi3HDImageEmbedding(Phi3ImageEmbeddingBase): ...@@ -305,64 +302,17 @@ class Phi3HDImageEmbedding(Phi3ImageEmbeddingBase):
return image_features_hd_newline return image_features_hd_newline
def get_max_phi3v_image_tokens(ctx: InputContext, def get_max_phi3v_image_tokens(ctx: InputContext) -> int:
*, processor = ctx.get_hf_processor()
num_crops: Optional[int] = None): image_processor = processor.image_processor # type: ignore
mm_processor_kwargs = {}
if num_crops is not None:
mm_processor_kwargs["num_crops"] = num_crops
model_config = ctx.model_config
image_processor = cached_get_image_processor(
model_config.model,
trust_remote_code=model_config.trust_remote_code,
**mm_processor_kwargs,
)
num_tokens = image_processor.calc_num_image_tokens_from_image_size( return image_processor.calc_num_image_tokens_from_image_size(
width=MAX_IMAGE_FEATURE_SIZE_WIDTH, width=MAX_IMAGE_FEATURE_SIZE_WIDTH,
height=MAX_IMAGE_FEATURE_SIZE_HEIGHT, height=MAX_IMAGE_FEATURE_SIZE_HEIGHT,
) )
return num_tokens
def dummy_mm_kwargs_for_phi3v(ctx: InputProcessingContext,
mm_counts: Mapping[str, int]):
num_images = mm_counts["image"]
data = dummy_image_for_clip(
CLIP_VIT_LARGE_PATCH14_336_CONFIG,
num_images,
image_width_override=MAX_IMAGE_FEATURE_SIZE_WIDTH,
image_height_override=MAX_IMAGE_FEATURE_SIZE_HEIGHT,
)
hf_processor = ctx.get_hf_processor()
image_processor = hf_processor.image_processor # type: ignore
hf_inputs = image_processor.preprocess(data['image'], return_tensors="pt")
return MultiModalKwargs(**hf_inputs)
class Phi3VMultiModalProcessor(BaseMultiModalProcessor):
def create_metadata_for_phi3v(
ctx: InputProcessingContext) -> MultiModalProcessingMetadata:
return {
"image":
ModalityProcessingMetadata(prompt_repls=[
PromptReplacement(target=[_IMAGE_TOKEN_ID],
repl_unit=[_IMAGE_TOKEN_ID],
repl_count=get_max_phi3v_image_tokens(ctx)),
]),
}
class Phi3VProcessor(BaseMultiModalProcessor):
def __init__(self, ctx: InputProcessingContext) -> None:
super().__init__(
ctx=ctx,
metadata=create_metadata_for_phi3v(ctx),
)
def _get_hf_processor( def _get_hf_processor(
self, self,
...@@ -389,15 +339,61 @@ class Phi3VProcessor(BaseMultiModalProcessor): ...@@ -389,15 +339,61 @@ class Phi3VProcessor(BaseMultiModalProcessor):
processed_outputs['input_ids'] = token_ids processed_outputs['input_ids'] = token_ids
return processed_outputs return processed_outputs
def _get_dummy_mm_kwargs( def _get_prompt_replacements(
self,
mm_items: MultiModalDataItems,
hf_inputs: BatchFeature,
mm_processor_kwargs: Mapping[str, object],
) -> list[PromptReplacement]:
hf_processor = self._get_hf_processor()
image_tokens: list[str] = hf_processor.img_tokens # type: ignore
image_processor = hf_processor.image_processor # type: ignore
mm_config = self.ctx.get_mm_config()
max_images = mm_config.limit_per_prompt.get("image", 1)
def get_replacement_phi3v(item_idx: int):
image_size = mm_items.get_image_size(item_idx)
num_tokens = image_processor.calc_num_image_tokens_from_image_size(
width=image_size.width,
height=image_size.height,
)
return [_IMAGE_TOKEN_ID] * num_tokens
return [
PromptReplacement(
modality="image",
target=image_token,
replacement=get_replacement_phi3v,
) for image_token in image_tokens[:max_images]
]
def _get_dummy_mm_inputs(
self, self,
mm_counts: Mapping[str, int], mm_counts: Mapping[str, int],
) -> MultiModalKwargs: ) -> ProcessorInputs:
return dummy_mm_kwargs_for_phi3v(self.ctx, mm_counts) num_images = mm_counts["image"]
data = dummy_image_for_clip(
CLIP_VIT_LARGE_PATCH14_336_CONFIG,
num_images,
image_width_override=MAX_IMAGE_FEATURE_SIZE_WIDTH,
image_height_override=MAX_IMAGE_FEATURE_SIZE_HEIGHT,
)
hf_processor = self._get_hf_processor()
image_tokens: list[str] = hf_processor.img_tokens # type: ignore
return ProcessorInputs(
prompt_text="".join(image_tokens[:num_images]),
mm_data=data,
mm_processor_kwargs={},
)
@MULTIMODAL_REGISTRY.register_max_image_tokens(get_max_phi3v_image_tokens) @MULTIMODAL_REGISTRY.register_max_image_tokens(get_max_phi3v_image_tokens)
@MULTIMODAL_REGISTRY.register_processor(Phi3VProcessor) @MULTIMODAL_REGISTRY.register_processor(Phi3VMultiModalProcessor)
class Phi3VForCausalLM(nn.Module, SupportsMultiModal, SupportsPP): class Phi3VForCausalLM(nn.Module, SupportsMultiModal, SupportsPP):
def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
......
...@@ -72,7 +72,7 @@ def dummy_data_for_pixtral(ctx: InputContext, seq_len: int, ...@@ -72,7 +72,7 @@ def dummy_data_for_pixtral(ctx: InputContext, seq_len: int,
mm_encoder = tokenizer.mistral.instruct_tokenizer.mm_encoder mm_encoder = tokenizer.mistral.instruct_tokenizer.mm_encoder
image_token_id = mm_encoder.special_ids.img image_token_id = mm_encoder.special_ids.img
mm_config = ctx.model_config.multimodal_config mm_config = ctx.get_mm_config()
num_images = mm_config.limit_per_prompt.get("image", 1) num_images = mm_config.limit_per_prompt.get("image", 1)
# dummy size # dummy size
......
...@@ -99,7 +99,7 @@ class MultiModalPlugin(ABC): ...@@ -99,7 +99,7 @@ class MultiModalPlugin(ABC):
""" """
def wrapper(model_cls: N) -> N: def wrapper(model_cls: N) -> N:
if model_cls in self._input_mappers: if self._input_mappers.contains(model_cls, strict=True):
logger.warning( logger.warning(
"Model class %s already has an input mapper " "Model class %s already has an input mapper "
"registered to %s. It is overwritten by the new one.", "registered to %s. It is overwritten by the new one.",
...@@ -194,7 +194,7 @@ class MultiModalPlugin(ABC): ...@@ -194,7 +194,7 @@ class MultiModalPlugin(ABC):
""" """
def wrapper(model_cls: N) -> N: def wrapper(model_cls: N) -> N:
if model_cls in self._max_mm_tokens: if self._max_mm_tokens.contains(model_cls, strict=True):
logger.warning( logger.warning(
"Model class %s already calculates maximum number of " "Model class %s already calculates maximum number of "
"tokens in %s. It is overwritten by the new one.", "tokens in %s. It is overwritten by the new one.",
......
import re import re
from abc import ABC, abstractmethod from abc import ABC, abstractmethod
from collections import UserDict
from collections.abc import Callable, ItemsView, Iterable, Mapping, Sequence from collections.abc import Callable, ItemsView, Iterable, Mapping, Sequence
from dataclasses import dataclass from dataclasses import dataclass, field
from functools import lru_cache from functools import lru_cache
from typing import (Any, Dict, Generic, NamedTuple, Optional, Protocol, from typing import Any, NamedTuple, Optional, Protocol, TypeVar, Union
TypeVar, Union, cast)
import numpy as np
import torch import torch
from PIL.Image import Image
from transformers import BatchFeature, ProcessorMixin from transformers import BatchFeature, ProcessorMixin
from typing_extensions import TypeAlias, TypedDict from typing_extensions import assert_never
from vllm.inputs import DummyData, InputProcessingContext from vllm.inputs import DummyData, InputProcessingContext
from vllm.logger import init_logger
from vllm.transformers_utils.tokenizer import AnyTokenizer, MistralTokenizer from vllm.transformers_utils.tokenizer import AnyTokenizer, MistralTokenizer
from vllm.utils import (flatten_2d_lists, full_groupby, is_list_of, from vllm.utils import flatten_2d_lists, full_groupby, is_list_of
resolve_mm_processor_kwargs)
from .inputs import (AudioItem, ImageItem, MultiModalDataDict, from .inputs import (AudioItem, ImageItem, MultiModalDataDict,
MultiModalInputsV2, MultiModalKwargs, PlaceholderRange, MultiModalInputsV2, MultiModalKwargs, PlaceholderRange,
VideoItem) VideoItem)
logger = init_logger(__name__)
def bind_prompt_sequence(
seq: Union[str, list[int]],
tokenizer: AnyTokenizer,
) -> "_BoundPromptSequence":
"""
Bind a text or token sequence to a tokenizer so that it can be
lazily converted into the other format on demand.
"""
return _BoundPromptSequence(
tokenizer=tokenizer,
_text=seq if isinstance(seq, str) else None,
_token_ids=seq if isinstance(seq, list) else None,
)
_T = TypeVar("_T")
_S = TypeVar("_S", str, list[int]) _S = TypeVar("_S", str, list[int])
_PromptSeq = Union[str, list[int]]
@dataclass @dataclass
class PromptReplacement(Generic[_S, _T]): class PromptReplacement:
target: _S modality: str
"""The text or token sequence to find and replace.""" """The modality for which the replacement is made"""
repl_unit: _S
"""
The unit making up the replacement text or token sequence.
See :code:`repl_count` for more details. target: _PromptSeq
""" """The text or token sequence to find and replace."""
repl_count: Union[Callable[[list[_T], BatchFeature, int], int], int] replacement: Union[Callable[[int], _PromptSeq],
_PromptSeq] = field(repr=False)
""" """
Given the original multi-modal items for this modality, HF-processed data, Given the index of the processed item within :attr:`modality`, output the
and index of the processed item, output the number of repetitions of replacement text or token sequence.
:code:`repl_unit` to build up the replacement text or token sequence.
For convenience, you can pass in an integer if the number of repetitions is For convenience, you can pass in the replacement instead of a function
a constant. if it does not depend on the input.
""" """
def __repr__(self) -> str: def bind(self, tokenizer: AnyTokenizer) -> "_BoundPromptReplacement":
return (f"{type(self).__name__}(target={self.target!r}, "
f"repl_unit={self.repl_unit!r})")
def bind(
self,
modality: str,
tokenizer: AnyTokenizer,
) -> "_BoundPromptReplacement[_T]":
return _BoundPromptReplacement( return _BoundPromptReplacement(
modality=modality, tokenizer=tokenizer,
target=bind_prompt_sequence(self.target, tokenizer), modality=self.modality,
repl_unit=bind_prompt_sequence(self.repl_unit, tokenizer), _target=self.target,
repl_count=self.repl_count, _replacement=self.replacement,
) )
@dataclass
class ModalityProcessingMetadata(Generic[_T]):
prompt_repls: Sequence[Union[PromptReplacement[str, _T],
PromptReplacement[list[int], _T]]]
"""
Defines each text or token sequence to replace in the HF-processed prompt.
This is skipped if the HF-processed prompt is found to already contain
the replacement prompts.
"""
class MultiModalProcessingMetadataBuiltins(TypedDict, total=False):
"""Type annotations for modality types predefined by vLLM."""
image: ModalityProcessingMetadata[ImageItem]
video: ModalityProcessingMetadata[VideoItem]
audio: ModalityProcessingMetadata[AudioItem]
MultiModalProcessingMetadata: TypeAlias = \
Mapping[str, ModalityProcessingMetadata[Any]]
"""
A dictionary containing an entry for each modality type to process.
Note:
This dictionary also accepts modality keys defined outside
:class:`MultiModalProcessingMetadataBuiltins` as long as a customized plugin
is registered through the :class:`~vllm.multimodal.MULTIMODAL_REGISTRY`.
Read more on that :ref:`here <adding_multimodal_plugin>`.
"""
def _encode( def _encode(
tokenizer: AnyTokenizer, tokenizer: AnyTokenizer,
text: str, text: str,
...@@ -185,7 +128,8 @@ def full_groupby_modality(values: Iterable[_M]) -> ItemsView[str, list[_M]]: ...@@ -185,7 +128,8 @@ def full_groupby_modality(values: Iterable[_M]) -> ItemsView[str, list[_M]]:
@dataclass @dataclass
class _BoundPromptSequence: class _BoundPromptSequence:
tokenizer: AnyTokenizer tokenizer: AnyTokenizer = field(repr=False)
_text: Optional[str] _text: Optional[str]
_token_ids: Optional[list[int]] _token_ids: Optional[list[int]]
...@@ -210,38 +154,92 @@ class _BoundPromptSequence: ...@@ -210,38 +154,92 @@ class _BoundPromptSequence:
return self._token_ids return self._token_ids
def __repr__(self) -> str:
return (f"{type(self).__name__}(_text={self._text!r}, "
f"_token_ids={self._token_ids!r})")
@dataclass @dataclass
class _BoundPromptReplacement(Generic[_T]): class _BoundPromptReplacement:
tokenizer: AnyTokenizer = field(repr=False)
modality: str modality: str
target: _BoundPromptSequence
repl_unit: _BoundPromptSequence
repl_count: Union[Callable[[list[_T], BatchFeature, int], int], int]
def get_count( _target: _PromptSeq
self, _replacement: Union[Callable[[int], _PromptSeq],
mm_items: list[_T], _PromptSeq] = field(repr=False)
hf_inputs: BatchFeature,
item_idx: int,
) -> int:
repl_count = self.repl_count
if isinstance(repl_count, int):
return repl_count
return repl_count(mm_items, hf_inputs, item_idx) def __post_init__(self) -> None:
self._replacement_cache = dict[int, _BoundPromptSequence]()
@property
def target(self) -> _BoundPromptSequence:
target = self._target
def to_multi_format(data: MultiModalDataDict) -> dict[str, list[Any]]: return _BoundPromptSequence(
tokenizer=self.tokenizer,
_text=target if isinstance(target, str) else None,
_token_ids=target if isinstance(target, list) else None,
)
def get_replacement(self, item_idx: int) -> _BoundPromptSequence:
replacement = self._replacement
if callable(replacement):
cache_key = item_idx
if cache_key in self._replacement_cache:
return self._replacement_cache[cache_key]
replacement = replacement(item_idx)
else:
cache_key = None
bound_replacement = _BoundPromptSequence(
tokenizer=self.tokenizer,
_text=replacement if isinstance(replacement, str) else None,
_token_ids=replacement if isinstance(replacement, list) else None,
)
if cache_key is not None:
self._replacement_cache[cache_key] = bound_replacement
return bound_replacement
class ImageSize(NamedTuple):
width: int
height: int
class MultiModalDataItems(UserDict[str, list[Any]]):
""" """
Convert a :class:`MultiModalDataDict` containing single data items As :class:`MultiModalDataDict`, but normalized such that each entry
to a :class:`MultiModalMultiDataDict` containing multiple data items corresponds to a list.
per entry.
""" """
multi_data = dict[str, list[Any]]()
@property
def image(self) -> list[ImageItem]:
return self["image"]
@property
def video(self) -> list[VideoItem]:
return self["video"]
@property
def audio(self) -> list[AudioItem]:
return self["audio"]
def get_image_size(self, item_idx: int) -> ImageSize:
image = self.image[item_idx]
if isinstance(image, Image):
return ImageSize(*image.size)
if isinstance(image, (np.ndarray, torch.Tensor)):
_, h, w = image.shape
return ImageSize(w, h)
assert_never(image)
def to_multi_format(data: MultiModalDataDict) -> MultiModalDataItems:
"""
Normalize :class:`MultiModalDataDict` to :class:`MultiModalDataItems`.
"""
multi_data = MultiModalDataItems()
for k, v in data.items(): for k, v in data.items():
# yapf: disable # yapf: disable
...@@ -266,22 +264,33 @@ def iter_token_matches( ...@@ -266,22 +264,33 @@ def iter_token_matches(
token_ids: list[int], token_ids: list[int],
match_ids: list[int], match_ids: list[int],
) -> Iterable[_TokenMatch]: ) -> Iterable[_TokenMatch]:
"""Yield each occurrence of :code:`match_ids` in :code:`token_ids`.""" """
Yield each occurrence of :code:`match_ids` in :code:`token_ids`.
Note that empty matches are ignored.
"""
prompt_len = len(token_ids)
match_len = len(match_ids) match_len = len(match_ids)
last_end_idx = 0 if match_len == 0:
for start_idx in range(len(token_ids) - match_len + 1): return
if start_idx < last_end_idx:
continue # Exclude overlapping matches
start_idx = 0
while start_idx < prompt_len - match_len + 1:
end_idx = start_idx + match_len end_idx = start_idx + match_len
if token_ids[start_idx:end_idx] == match_ids: if token_ids[start_idx:end_idx] == match_ids:
yield _TokenMatch(start_idx=start_idx, end_idx=end_idx) yield _TokenMatch(start_idx=start_idx, end_idx=end_idx)
last_end_idx = end_idx
# Exclude overlapping matches
start_idx = end_idx
else:
start_idx += 1
class _PromptReplacementMatch(ABC, Generic[_T, _S]): @dataclass(repr=False)
prompt_repl: _BoundPromptReplacement[_T] class _PromptReplacementMatch(ABC):
prompt_repl: _BoundPromptReplacement
@property @property
def modality(self) -> str: def modality(self) -> str:
...@@ -297,19 +306,13 @@ class _PromptReplacementMatch(ABC, Generic[_T, _S]): ...@@ -297,19 +306,13 @@ class _PromptReplacementMatch(ABC, Generic[_T, _S]):
def end_idx(self) -> int: def end_idx(self) -> int:
raise NotImplementedError raise NotImplementedError
@property
@abstractmethod
def repl_unit(self) -> _S:
raise NotImplementedError
def __repr__(self) -> str: def __repr__(self) -> str:
return (f"{type(self).__name__}(modality={self.modality!r}, " return (f"{type(self).__name__}(modality={self.modality!r}, "
f"start_idx={self.start_idx!r}, end_idx={self.end_idx!r})") f"start_idx={self.start_idx!r}, end_idx={self.end_idx!r})")
@dataclass(repr=False) @dataclass(repr=False)
class _PromptReplacementTokenMatch(_PromptReplacementMatch[_T, list[int]]): class _PromptReplacementTokenMatch(_PromptReplacementMatch):
prompt_repl: _BoundPromptReplacement[_T]
match: _TokenMatch match: _TokenMatch
@property @property
...@@ -320,14 +323,9 @@ class _PromptReplacementTokenMatch(_PromptReplacementMatch[_T, list[int]]): ...@@ -320,14 +323,9 @@ class _PromptReplacementTokenMatch(_PromptReplacementMatch[_T, list[int]]):
def end_idx(self) -> int: def end_idx(self) -> int:
return self.match.end_idx return self.match.end_idx
@property
def repl_unit(self) -> list[int]:
return self.prompt_repl.repl_unit.token_ids
@dataclass(repr=False) @dataclass(repr=False)
class _PromptReplacementTextMatch(_PromptReplacementMatch[_T, str]): class _PromptReplacementTextMatch(_PromptReplacementMatch):
prompt_repl: _BoundPromptReplacement[_T]
match: re.Match[str] match: re.Match[str]
@property @property
...@@ -338,20 +336,15 @@ class _PromptReplacementTextMatch(_PromptReplacementMatch[_T, str]): ...@@ -338,20 +336,15 @@ class _PromptReplacementTextMatch(_PromptReplacementMatch[_T, str]):
def end_idx(self) -> int: def end_idx(self) -> int:
return self.match.end() return self.match.end()
@property
def repl_unit(self) -> str:
return self.prompt_repl.repl_unit.text
class _PlaceholderInfo(NamedTuple): class _PlaceholderInfo(NamedTuple):
modality: str modality: str
start_idx: int start_idx: int
unit: list[int] replacement: list[int]
unit_count: int
@property @property
def length(self) -> int: def length(self) -> int:
return len(self.unit) * self.unit_count return len(self.replacement)
def to_range(self) -> PlaceholderRange: def to_range(self) -> PlaceholderRange:
return PlaceholderRange( return PlaceholderRange(
...@@ -362,8 +355,8 @@ class _PlaceholderInfo(NamedTuple): ...@@ -362,8 +355,8 @@ class _PlaceholderInfo(NamedTuple):
def find_token_matches( def find_token_matches(
prompt: list[int], prompt: list[int],
prompt_repls: Sequence[_BoundPromptReplacement[_T]], prompt_repls: Sequence[_BoundPromptReplacement],
) -> list[_PromptReplacementTokenMatch[_T]]: ) -> list[_PromptReplacementTokenMatch]:
"""Return each target of :code:`prompt_repls` found in :code:`prompt`.""" """Return each target of :code:`prompt_repls` found in :code:`prompt`."""
return [ return [
_PromptReplacementTokenMatch(prompt_repl, match) _PromptReplacementTokenMatch(prompt_repl, match)
...@@ -374,8 +367,8 @@ def find_token_matches( ...@@ -374,8 +367,8 @@ def find_token_matches(
def find_text_matches( def find_text_matches(
prompt: str, prompt: str,
prompt_repls: Sequence[_BoundPromptReplacement[_T]], prompt_repls: Sequence[_BoundPromptReplacement],
) -> list[_PromptReplacementTextMatch[_T]]: ) -> list[_PromptReplacementTextMatch]:
"""Return each target of :code:`prompt_repls` found in :code:`prompt`.""" """Return each target of :code:`prompt_repls` found in :code:`prompt`."""
return [ return [
_PromptReplacementTextMatch(prompt_repl, match) _PromptReplacementTextMatch(prompt_repl, match)
...@@ -385,15 +378,15 @@ def find_text_matches( ...@@ -385,15 +378,15 @@ def find_text_matches(
def _resolve_matches( def _resolve_matches(
prompt: _S, prompt: _PromptSeq,
matches: Sequence[_PromptReplacementMatch[_T, _S]], matches: Sequence[_PromptReplacementMatch],
) -> list[_PromptReplacementMatch[_T, _S]]: ) -> list[_PromptReplacementMatch]:
""" """
Resolve :code:`matches` to ensure that there are no overlapping matches, Resolve :code:`matches` to ensure that there are no overlapping matches,
and sort them such that earlier matches take priority over later ones. and sort them such that earlier matches take priority over later ones.
""" """
seen_matches: list[Optional[_PromptReplacementMatch[_T, _S]]] \ seen_matches: list[Optional[_PromptReplacementMatch]] = [None
= [None] * len(prompt) ] * len(prompt)
for match in matches: for match in matches:
for idx in range(match.start_idx, match.end_idx): for idx in range(match.start_idx, match.end_idx):
...@@ -409,30 +402,34 @@ def _resolve_matches( ...@@ -409,30 +402,34 @@ def _resolve_matches(
def _replace_matches( def _replace_matches(
prompt: _S, prompt: _S,
matches: Sequence[_PromptReplacementMatch[_T, _S]], matches: Sequence[_PromptReplacementMatch],
mm_items_by_modality: Mapping[str, list[_T]], mm_items: MultiModalDataItems,
hf_inputs: BatchFeature,
) -> list[_S]: ) -> list[_S]:
out_seqs = list[_S]() out_seqs = list[_S]()
prev_end_idx = 0 prev_end_idx = 0
next_idx_by_modality = {modality: 0 for modality in mm_items_by_modality} next_idx_by_modality = {modality: 0 for modality in mm_items}
for match in _resolve_matches(prompt, matches): for match in _resolve_matches(prompt, matches):
modality = match.modality modality = match.modality
mm_items = mm_items_by_modality[modality] modal_items = mm_items[modality]
item_idx = next_idx_by_modality[modality] item_idx = next_idx_by_modality[modality]
if item_idx >= len(mm_items): if item_idx >= len(modal_items):
continue continue
start_idx = match.start_idx start_idx = match.start_idx
end_idx = match.end_idx end_idx = match.end_idx
repl_unit = match.repl_unit
repl_info = match.prompt_repl repl_info = match.prompt_repl
repl_count = repl_info.get_count(mm_items, hf_inputs, item_idx) replacement = repl_info.get_replacement(item_idx)
if isinstance(prompt, str):
repl_seq = replacement.text
out_seqs.append(prompt[prev_end_idx:start_idx] + repl_seq)
else:
repl_seq = replacement.token_ids
out_seqs.append(prompt[prev_end_idx:start_idx] + repl_seq)
out_seqs.append(prompt[prev_end_idx:start_idx] +
repl_unit * repl_count)
prev_end_idx = end_idx prev_end_idx = end_idx
next_idx_by_modality[modality] += 1 next_idx_by_modality[modality] += 1
...@@ -443,92 +440,104 @@ def _replace_matches( ...@@ -443,92 +440,104 @@ def _replace_matches(
def replace_token_matches( def replace_token_matches(
prompt: list[int], prompt: list[int],
matches: Sequence[_PromptReplacementMatch[_T, list[int]]], matches: Sequence[_PromptReplacementTokenMatch],
mm_items_by_modality: Mapping[str, list[_T]], mm_items: MultiModalDataItems,
hf_inputs: BatchFeature,
) -> list[int]: ) -> list[int]:
"""Apply :code:`prompt_repls` to :code:`prompt`.""" """Apply :code:`prompt_repls` to :code:`prompt`."""
if not matches: if not matches:
return prompt return prompt
token_id_seqs = _replace_matches( token_id_seqs = _replace_matches(prompt, matches, mm_items)
prompt,
matches,
mm_items_by_modality,
hf_inputs,
)
return flatten_2d_lists(token_id_seqs) return flatten_2d_lists(token_id_seqs)
def replace_text_matches( def replace_text_matches(
prompt: str, prompt: str,
matches: Sequence[_PromptReplacementMatch[_T, str]], matches: Sequence[_PromptReplacementTextMatch],
mm_items_by_modality: Mapping[str, list[_T]], mm_items: MultiModalDataItems,
hf_inputs: BatchFeature,
) -> str: ) -> str:
"""Apply :code:`prompt_repls` to :code:`prompt`.""" """Apply :code:`prompt_repls` to :code:`prompt`."""
if not matches: if not matches:
return prompt return prompt
texts = _replace_matches( texts = _replace_matches(prompt, matches, mm_items)
prompt,
matches,
mm_items_by_modality,
hf_inputs,
)
return "".join(texts) return "".join(texts)
def _merge_placeholder_matches( def _iter_modality_placeholders(
matches: Iterable[_PromptReplacementTokenMatch], prompt: list[int],
) -> Iterable[_PromptReplacementTokenMatch]: modality: str,
current_match = None modality_repls: Sequence[_BoundPromptReplacement],
modal_items: list[Any],
) -> Iterable[_PlaceholderInfo]:
if len(modal_items) == 0:
return
prompt_len = len(prompt)
item_index = 0
start_idx = 0
while start_idx < prompt_len:
found = False
for match in sorted(matches, key=lambda x: x.start_idx): for repl_info in modality_repls:
if current_match is None: replacement = repl_info.get_replacement(item_index)
current_match = match repl_tokens = replacement.token_ids
elif (current_match.prompt_repl == match.prompt_repl repl_len = len(repl_tokens)
and current_match.end_idx == match.start_idx): end_idx = start_idx + repl_len
current_match = _PromptReplacementTokenMatch(
current_match.prompt_repl, if repl_len == 0 or end_idx > prompt_len:
match=_TokenMatch(current_match.start_idx, match.end_idx), continue
if prompt[start_idx:end_idx] == repl_tokens:
yield _PlaceholderInfo(
modality=modality,
start_idx=start_idx,
replacement=repl_tokens,
) )
else:
yield current_match
current_match = match
if current_match is not None: item_index += 1
yield current_match if item_index >= len(modal_items):
return
# Exclude overlapping matches
start_idx = end_idx
found = True
break
if not found:
start_idx += 1
def iter_placeholders( def iter_placeholders(
prompt_repls: Sequence[_BoundPromptReplacement[Any]], prompt_repls: Sequence[_BoundPromptReplacement],
prompt: list[int], prompt: list[int],
*, mm_items: MultiModalDataItems,
min_unit_count: int = 1,
) -> Iterable[_PlaceholderInfo]: ) -> Iterable[_PlaceholderInfo]:
"""Yield each set of placeholder tokens found in :code:`token_ids`.""" """
if min_unit_count <= 0: Yield each set of placeholder tokens found in :code:`prompt`.
raise ValueError("`min_unit_count` must be a positive integer")
matches = (_PromptReplacementTokenMatch(prompt_repl, match) Note that empty matches are ignored.
for prompt_repl in prompt_repls """
if len(repl_unit := prompt_repl.repl_unit.token_ids) > 0 repls_by_modality = dict(full_groupby_modality(prompt_repls))
for match in iter_token_matches(prompt, repl_unit))
for modality, modal_items in mm_items.items():
for match in _merge_placeholder_matches(matches): if modality in repls_by_modality:
unit = match.repl_unit yield from _iter_modality_placeholders(
placeholder = _PlaceholderInfo( prompt,
modality=match.modality, modality,
start_idx=match.start_idx, repls_by_modality[modality],
unit=unit, modal_items,
unit_count=(match.end_idx - match.start_idx) // len(unit),
) )
if placeholder.unit_count >= min_unit_count:
yield placeholder class ProcessorInputs(NamedTuple):
"""Keyword arguments to :meth:`BaseMultiModalProcessor`"""
prompt_text: str
mm_data: MultiModalDataDict
mm_processor_kwargs: Mapping[str, object]
class BaseMultiModalProcessor(ABC): class BaseMultiModalProcessor(ABC):
...@@ -536,52 +545,55 @@ class BaseMultiModalProcessor(ABC): ...@@ -536,52 +545,55 @@ class BaseMultiModalProcessor(ABC):
Abstract base class to process multi-modal inputs to be used in vLLM. Abstract base class to process multi-modal inputs to be used in vLLM.
""" """
def __init__( def __init__(self, ctx: InputProcessingContext) -> None:
self,
ctx: InputProcessingContext,
metadata: MultiModalProcessingMetadata,
) -> None:
super().__init__() super().__init__()
self.ctx = ctx self.ctx = ctx
self.metadata = metadata
self.init_mm_processor_kwargs = (ctx.model_config.mm_processor_kwargs
or {})
def _get_hf_processor( def __call__(
self, self,
**mm_processor_kwargs: Mapping[str, object], prompt: str,
) -> ProcessorMixin: mm_data: MultiModalDataDict,
# by default, we won't pass any kwargs to the processor initialization mm_processor_kwargs: Mapping[str, object],
) -> MultiModalInputsV2:
return self.apply(prompt, mm_data, mm_processor_kwargs)
def _get_hf_processor(self) -> ProcessorMixin:
"""
Subclasses can add keyword arguments to this method to accept
additional kwargs from model config or user inputs.
"""
return self.ctx.get_hf_processor() return self.ctx.get_hf_processor()
def _get_tokenizer(self) -> AnyTokenizer: def _get_tokenizer(self) -> AnyTokenizer:
return self.ctx.tokenizer return self.ctx.tokenizer
def __call__( @abstractmethod
def _get_prompt_replacements(
self, self,
prompt: str, mm_items: MultiModalDataItems,
mm_data: MultiModalDataDict, hf_inputs: BatchFeature,
mm_processor_kwargs: Mapping[str, object], mm_processor_kwargs: Mapping[str, object],
) -> MultiModalInputsV2: ) -> list[PromptReplacement]:
return self.apply(prompt, mm_data, mm_processor_kwargs) """
Given the original multi-modal items for this modality
and HF-processed data, output the replacements to perform.
Note:
Even when the HF processor already performs replacement for us,
we still use this replacement information to determine
the placeholder token positions for each multi-modal item.
"""
raise NotImplementedError
def _find_placeholders( def _find_placeholders(
self, self,
all_prompt_repls: Sequence[_BoundPromptReplacement[Any]], all_prompt_repls: Sequence[_BoundPromptReplacement],
new_token_ids: list[int], new_token_ids: list[int],
*, mm_items: MultiModalDataItems,
# To avoid false positives from multi-input when detecting
# whether placeholder tokens have been inserted, in case
# the target sequence is a subset of the replacement tokens
min_unit_count: int = 16,
) -> list[_PlaceholderInfo]: ) -> list[_PlaceholderInfo]:
return list( return list(
iter_placeholders( iter_placeholders(all_prompt_repls, new_token_ids, mm_items))
all_prompt_repls,
new_token_ids,
min_unit_count=min_unit_count,
))
def _apply_hf_processor( def _apply_hf_processor(
self, self,
...@@ -589,13 +601,7 @@ class BaseMultiModalProcessor(ABC): ...@@ -589,13 +601,7 @@ class BaseMultiModalProcessor(ABC):
mm_data: MultiModalDataDict, mm_data: MultiModalDataDict,
mm_processor_kwargs: Mapping[str, object], mm_processor_kwargs: Mapping[str, object],
) -> BatchFeature: ) -> BatchFeature:
# some mm_processor_kwargs may be used in processor initialization hf_processor = self._get_hf_processor(**mm_processor_kwargs)
# instead of processor call
processor_init_kwargs = {
**self.init_mm_processor_kwargs,
**mm_processor_kwargs,
}
hf_processor = self._get_hf_processor(**processor_init_kwargs)
processor_data = dict[str, Any]() processor_data = dict[str, Any]()
passthrough_data = dict[str, Any]() passthrough_data = dict[str, Any]()
...@@ -615,11 +621,10 @@ class BaseMultiModalProcessor(ABC): ...@@ -615,11 +621,10 @@ class BaseMultiModalProcessor(ABC):
else: else:
processor_data[k] = v processor_data[k] = v
# filter mm_processor_kwargs used in processor call assert callable(hf_processor)
mm_processor_kwargs = resolve_mm_processor_kwargs( mm_processor_kwargs = self.ctx.resolve_hf_processor_call_kwargs(
self.init_mm_processor_kwargs,
cast(Dict[str, Any], mm_processor_kwargs),
hf_processor, hf_processor,
mm_processor_kwargs,
) )
try: try:
...@@ -642,26 +647,21 @@ class BaseMultiModalProcessor(ABC): ...@@ -642,26 +647,21 @@ class BaseMultiModalProcessor(ABC):
def _bind_prompt_replacements( def _bind_prompt_replacements(
self, self,
mm_data: MultiModalDataDict, prompt_repls: list[PromptReplacement],
) -> list[_BoundPromptReplacement[Any]]: ) -> list[_BoundPromptReplacement]:
tokenizer = self._get_tokenizer() tokenizer = self._get_tokenizer()
return [ return [prompt_repl.bind(tokenizer) for prompt_repl in prompt_repls]
prompt_repl.bind(modality, tokenizer)
for modality, metadata in self.metadata.items()
if modality in mm_data for prompt_repl in metadata.prompt_repls
]
def _apply_prompt_replacements( def _apply_prompt_replacements(
self, self,
mm_data: MultiModalDataDict, mm_items: MultiModalDataItems,
hf_inputs: BatchFeature, hf_inputs: BatchFeature,
token_ids: list[int], token_ids: list[int],
prompt_repls: Sequence[_BoundPromptReplacement[Any]], prompt_repls: Sequence[_BoundPromptReplacement],
) -> tuple[list[int], str, list[_PlaceholderInfo]]: ) -> tuple[list[int], str, list[_PlaceholderInfo]]:
tokenizer = self._get_tokenizer() tokenizer = self._get_tokenizer()
mm_items = to_multi_format(mm_data)
token_matches = find_token_matches(token_ids, prompt_repls) token_matches = find_token_matches(token_ids, prompt_repls)
# If the search text does not represent a special token, # If the search text does not represent a special token,
...@@ -682,7 +682,6 @@ class BaseMultiModalProcessor(ABC): ...@@ -682,7 +682,6 @@ class BaseMultiModalProcessor(ABC):
token_ids, token_ids,
token_matches, token_matches,
mm_items, mm_items,
hf_inputs,
) )
text = _decode(tokenizer, token_ids) text = _decode(tokenizer, token_ids)
...@@ -695,13 +694,13 @@ class BaseMultiModalProcessor(ABC): ...@@ -695,13 +694,13 @@ class BaseMultiModalProcessor(ABC):
text, text,
text_matches, text_matches,
mm_items, mm_items,
hf_inputs,
) )
token_ids = _encode(tokenizer, text) token_ids = _encode(tokenizer, text)
matched_repls = [match.prompt_repl for match in text_matches] matched_repls = [match.prompt_repl for match in text_matches]
placeholders = self._find_placeholders(matched_repls, token_ids) placeholders = self._find_placeholders(matched_repls, token_ids,
mm_items)
return token_ids, text, placeholders return token_ids, text, placeholders
...@@ -731,12 +730,16 @@ class BaseMultiModalProcessor(ABC): ...@@ -731,12 +730,16 @@ class BaseMultiModalProcessor(ABC):
prompt_ids, = hf_inputs.pop("input_ids").tolist() prompt_ids, = hf_inputs.pop("input_ids").tolist()
mm_kwargs = MultiModalKwargs(hf_inputs) mm_kwargs = MultiModalKwargs(hf_inputs)
all_prompt_repls = self._bind_prompt_replacements(mm_data) mm_items = to_multi_format(mm_data)
prompt_repls = self._get_prompt_replacements(mm_items, hf_inputs,
mm_processor_kwargs)
all_prompt_repls = self._bind_prompt_replacements(prompt_repls)
# If HF processor already inserts placeholder tokens, # If HF processor already inserts placeholder tokens,
# there is no need for us to insert them # there is no need for us to insert them
all_placeholders = self._find_placeholders(all_prompt_repls, all_placeholders = self._find_placeholders(all_prompt_repls,
prompt_ids) prompt_ids, mm_items)
if all_placeholders: if all_placeholders:
prompt_text = _decode(tokenizer, prompt_ids) prompt_text = _decode(tokenizer, prompt_ids)
else: else:
...@@ -745,7 +748,7 @@ class BaseMultiModalProcessor(ABC): ...@@ -745,7 +748,7 @@ class BaseMultiModalProcessor(ABC):
prompt_text, prompt_text,
all_placeholders, all_placeholders,
) = self._apply_prompt_replacements( ) = self._apply_prompt_replacements(
mm_data, mm_items,
hf_inputs, hf_inputs,
prompt_ids, prompt_ids,
all_prompt_repls, all_prompt_repls,
...@@ -765,13 +768,13 @@ class BaseMultiModalProcessor(ABC): ...@@ -765,13 +768,13 @@ class BaseMultiModalProcessor(ABC):
) )
@abstractmethod @abstractmethod
def _get_dummy_mm_kwargs( def _get_dummy_mm_inputs(
self, self,
mm_counts: Mapping[str, int], mm_counts: Mapping[str, int],
) -> MultiModalKwargs: ) -> ProcessorInputs:
""" """
Build the input that corresponds to `mm_max_tokens` in Build the multi-modal portion of the input which, after processing,
:meth:`get_dummy_data`. results in `mm_max_tokens` in :meth:`get_dummy_data`.
""" """
raise NotImplementedError raise NotImplementedError
...@@ -784,38 +787,41 @@ class BaseMultiModalProcessor(ABC): ...@@ -784,38 +787,41 @@ class BaseMultiModalProcessor(ABC):
# Avoid circular import # Avoid circular import
from vllm.sequence import SequenceData from vllm.sequence import SequenceData
tokenizer = self._get_tokenizer() processor_inputs = self._get_dummy_mm_inputs(mm_counts)
mm_inputs = self.apply(*processor_inputs)
mm_placeholders = dict[str, _PlaceholderInfo]()
offset = 0 prompt_token_ids = mm_inputs["prompt_token_ids"]
placeholders_by_modality = mm_inputs["mm_placeholders"]
for modality, max_tokens in mm_max_tokens.items():
if max_tokens == 0: total_placeholders_by_modality = dict[str, int]()
continue for modality, placeholders in placeholders_by_modality.items():
num_placeholders = sum(item["length"] for item in placeholders)
metadata = self.metadata[modality] max_tokens = mm_max_tokens[modality]
repl = metadata.prompt_repls[0].bind(modality, tokenizer)
repl_token_ids = repl.repl_unit.token_ids if num_placeholders != max_tokens:
logger.warning(
placeholders = _PlaceholderInfo( "The processed dummy data has a total of %d placeholder "
modality=modality, "tokens for the '%s' modality, which is not the expected "
start_idx=offset, "%d tokens.", num_placeholders, modality, max_tokens)
unit=repl_token_ids,
unit_count=max_tokens // len(repl_token_ids), total_placeholders_by_modality[modality] = num_placeholders
)
total_len = len(prompt_token_ids)
mm_placeholders[modality] = placeholders if total_len > seq_len:
offset += placeholders.length logger.warning(
"The context length (%d) of the model is too short "
"to hold the multi-modal embeddings in the worst case "
"(%d tokens in total, out of which %s are reserved for "
"multi-modal embeddings). This may cause certain multi-modal "
"inputs to fail during inference, even when the input text is "
"short. To avoid this, you should increase `max_model_len`, "
"reduce `max_num_seqs`, and/or reduce `mm_counts`.", seq_len,
total_len, total_placeholders_by_modality)
prompt_token_ids = flatten_2d_lists(
[p.unit * p.unit_count for p in mm_placeholders.values()])
prompt_token_ids.extend([0] * (seq_len - len(prompt_token_ids))) prompt_token_ids.extend([0] * (seq_len - len(prompt_token_ids)))
return DummyData( return DummyData(
seq_data=SequenceData.from_seqs(prompt_token_ids), seq_data=SequenceData.from_seqs(prompt_token_ids),
multi_modal_data=self._get_dummy_mm_kwargs(mm_counts), multi_modal_data=mm_inputs["mm_kwargs"],
multi_modal_placeholders={ multi_modal_placeholders=placeholders_by_modality,
modality: [p.to_range()]
for modality, p in mm_placeholders.items()
},
) )
...@@ -299,9 +299,9 @@ class MultiModalRegistry: ...@@ -299,9 +299,9 @@ class MultiModalRegistry:
""" """
def wrapper(model_cls: N) -> N: def wrapper(model_cls: N) -> N:
if model_cls in self._processor_factories: if self._processor_factories.contains(model_cls, strict=True):
logger.warning( logger.warning(
"Model class %s already has an input mapper " "Model class %s already has a multi-modal processor "
"registered to %s. It is overwritten by the new one.", "registered to %s. It is overwritten by the new one.",
model_cls, self) model_cls, self)
......
...@@ -1370,8 +1370,8 @@ def supports_kw( ...@@ -1370,8 +1370,8 @@ def supports_kw(
def resolve_mm_processor_kwargs( def resolve_mm_processor_kwargs(
init_kwargs: Optional[Dict[str, Any]], init_kwargs: Optional[Mapping[str, object]],
inference_kwargs: Optional[Dict[str, Any]], inference_kwargs: Optional[Mapping[str, object]],
callable: Callable[..., object], callable: Callable[..., object],
allow_var_kwargs: bool = False, allow_var_kwargs: bool = False,
) -> Dict[str, Any]: ) -> Dict[str, Any]:
...@@ -1405,7 +1405,7 @@ def resolve_mm_processor_kwargs( ...@@ -1405,7 +1405,7 @@ def resolve_mm_processor_kwargs(
def get_allowed_kwarg_only_overrides( def get_allowed_kwarg_only_overrides(
callable: Callable[..., object], callable: Callable[..., object],
overrides: Optional[Dict[str, Any]], overrides: Optional[Mapping[str, object]],
allow_var_kwargs: bool = False, allow_var_kwargs: bool = False,
) -> Dict[str, Any]: ) -> Dict[str, Any]:
""" """
...@@ -1524,9 +1524,15 @@ class ClassRegistry(UserDict[Type[T], _V]): ...@@ -1524,9 +1524,15 @@ class ClassRegistry(UserDict[Type[T], _V]):
raise KeyError(key) raise KeyError(key)
def __contains__(self, key: object) -> bool: def __contains__(self, key: object) -> bool:
return self.contains(key)
def contains(self, key: object, *, strict: bool = False) -> bool:
if not isinstance(key, type): if not isinstance(key, type):
return False return False
if strict:
return key in self.data
return any(cls in self.data for cls in key.mro()) return any(cls in self.data for cls in key.mro())
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment