Unverified Commit 93abf23a authored by Cyrus Leung's avatar Cyrus Leung Committed by GitHub
Browse files

[VLM] Fully dynamic prompt replacement in merged input processor (#11199)


Signed-off-by: default avatarDarkLight1337 <tlleungac@connect.ust.hk>
parent 9c3dadd1
...@@ -97,9 +97,6 @@ def run_phi3v(question: str, modality: str): ...@@ -97,9 +97,6 @@ def run_phi3v(question: str, modality: str):
# max_model_len (128k) for this model may cause OOM. # max_model_len (128k) for this model may cause OOM.
# You may lower either to run this example on lower-end GPUs. # You may lower either to run this example on lower-end GPUs.
# In this example, we override max_num_seqs to 5 while
# keeping the original context length of 128k.
# num_crops is an override kwarg to the multimodal image processor; # num_crops is an override kwarg to the multimodal image processor;
# For some models, e.g., Phi-3.5-vision-instruct, it is recommended # For some models, e.g., Phi-3.5-vision-instruct, it is recommended
# to use 16 for single frame scenarios, and 4 for multi-frame. # to use 16 for single frame scenarios, and 4 for multi-frame.
...@@ -113,7 +110,7 @@ def run_phi3v(question: str, modality: str): ...@@ -113,7 +110,7 @@ def run_phi3v(question: str, modality: str):
# https://huggingface.co/microsoft/Phi-3.5-vision-instruct#loading-the-model-locally # https://huggingface.co/microsoft/Phi-3.5-vision-instruct#loading-the-model-locally
# https://huggingface.co/microsoft/Phi-3.5-vision-instruct/blob/main/processing_phi3_v.py#L194 # https://huggingface.co/microsoft/Phi-3.5-vision-instruct/blob/main/processing_phi3_v.py#L194
llm = LLM( llm = LLM(
model="microsoft/Phi-3-vision-128k-instruct", model="microsoft/Phi-3.5-vision-instruct",
trust_remote_code=True, trust_remote_code=True,
max_model_len=4096, max_model_len=4096,
max_num_seqs=2, max_num_seqs=2,
......
...@@ -16,8 +16,8 @@ models = ["microsoft/Phi-3.5-vision-instruct"] ...@@ -16,8 +16,8 @@ models = ["microsoft/Phi-3.5-vision-instruct"]
# Wrap lazy imports to avoid initializing CUDA during test collection # Wrap lazy imports to avoid initializing CUDA during test collection
@pytest.fixture() @pytest.fixture()
def processor_for_phi3v(): def processor_for_phi3v():
from vllm.model_executor.models.phi3v import Phi3VProcessor from vllm.model_executor.models.phi3v import Phi3VMultiModalProcessor
return Phi3VProcessor return Phi3VMultiModalProcessor
@pytest.fixture() @pytest.fixture()
......
from typing import cast from typing import cast
import pytest import pytest
from transformers import BatchFeature
from vllm.multimodal.processing import (PromptReplacement, _PlaceholderInfo, from vllm.multimodal.processing import (MultiModalDataItems, PromptReplacement,
find_text_matches, find_token_matches, _PlaceholderInfo, find_text_matches,
iter_placeholders, iter_token_matches, find_token_matches, iter_placeholders,
iter_token_matches,
replace_text_matches, replace_text_matches,
replace_token_matches) replace_token_matches)
from vllm.transformers_utils.tokenizer import AnyTokenizer from vllm.transformers_utils.tokenizer import AnyTokenizer
...@@ -16,7 +16,7 @@ from vllm.utils import full_groupby ...@@ -16,7 +16,7 @@ from vllm.utils import full_groupby
@pytest.mark.parametrize( @pytest.mark.parametrize(
("token_ids", "match_ids", "expected"), ("token_ids", "match_ids", "expected"),
[ [
([], [], [{ "start_idx": 0, "end_idx": 0 }]), ([], [], []),
([], [32000], []), ([], [32000], []),
( (
[32000, 32000, 32000], [32000, 32000, 32000],
...@@ -83,7 +83,7 @@ def test_iter_token_matches(token_ids, match_ids, expected): ...@@ -83,7 +83,7 @@ def test_iter_token_matches(token_ids, match_ids, expected):
"pattern_2": [32000], "pattern_2": [32000],
}, },
{ {
"pattern_1": [{ "start_idx": 0, "end_idx": 0 }], "pattern_1": [],
"pattern_2": [], "pattern_2": [],
} }
), ),
...@@ -136,7 +136,7 @@ def test_find_token_matches(prompt, target_by_key, expected_by_key): ...@@ -136,7 +136,7 @@ def test_find_token_matches(prompt, target_by_key, expected_by_key):
mock_tokenizer = cast(AnyTokenizer, object()) mock_tokenizer = cast(AnyTokenizer, object())
prompt_repls = [ prompt_repls = [
PromptReplacement(target, [], 0).bind(key, mock_tokenizer) PromptReplacement(key, target, []).bind(mock_tokenizer)
for key, target in target_by_key.items() for key, target in target_by_key.items()
] ]
result = find_token_matches(prompt, prompt_repls) result = find_token_matches(prompt, prompt_repls)
...@@ -243,7 +243,7 @@ def test_find_text_matches(prompt, target_by_key, expected_by_key): ...@@ -243,7 +243,7 @@ def test_find_text_matches(prompt, target_by_key, expected_by_key):
mock_tokenizer = cast(AnyTokenizer, object()) mock_tokenizer = cast(AnyTokenizer, object())
prompt_repls = [ prompt_repls = [
PromptReplacement(target, [], 0).bind(key, mock_tokenizer) PromptReplacement(key, target, []).bind(mock_tokenizer)
for key, target in target_by_key.items() for key, target in target_by_key.items()
] ]
result = find_text_matches(prompt, prompt_repls) result = find_text_matches(prompt, prompt_repls)
...@@ -276,12 +276,12 @@ def test_find_text_matches(prompt, target_by_key, expected_by_key): ...@@ -276,12 +276,12 @@ def test_find_text_matches(prompt, target_by_key, expected_by_key):
"pattern_3": "!", "pattern_3": "!",
}, },
{ {
# Test whether target is confused with repl_unit # Test whether target is confused with replacement
"pattern_1": ("<image><image>", 1), "pattern_1": "<image><image>",
# Test empty repl_unit # Test empty replacement
"pattern_2": ("", 1), "pattern_2": "",
# Test multiple repl_count # Test dynamic replacement (beyond the form of `unit * count`)
"pattern_3": ("?", 2), "pattern_3": "?!?",
}, },
), ),
] ]
...@@ -290,8 +290,8 @@ def test_find_text_matches(prompt, target_by_key, expected_by_key): ...@@ -290,8 +290,8 @@ def test_find_text_matches(prompt, target_by_key, expected_by_key):
("mm_count", "expected"), ("mm_count", "expected"),
[ [
(0, "Image:<image>Image:<image><image>!"), (0, "Image:<image>Image:<image><image>!"),
(1, "<image><image>Image:<image><image>??"), (1, "<image><image>Image:<image><image>?!?"),
(2, "<image><image><image><image><image>??"), (2, "<image><image><image><image><image>?!?"),
] ]
) )
# yapf: enable # yapf: enable
...@@ -306,7 +306,7 @@ def test_find_replace_text( ...@@ -306,7 +306,7 @@ def test_find_replace_text(
mock_tokenizer = cast(AnyTokenizer, object()) mock_tokenizer = cast(AnyTokenizer, object())
prompt_repls = [ prompt_repls = [
PromptReplacement(target, *repl_by_key[key]).bind(key, mock_tokenizer) PromptReplacement(key, target, repl_by_key[key]).bind(mock_tokenizer)
for key, target in target_by_key.items() for key, target in target_by_key.items()
] ]
matches = find_text_matches(prompt, prompt_repls) matches = find_text_matches(prompt, prompt_repls)
...@@ -314,9 +314,8 @@ def test_find_replace_text( ...@@ -314,9 +314,8 @@ def test_find_replace_text(
result = replace_text_matches( result = replace_text_matches(
prompt, prompt,
matches, matches,
{key: list(range(mm_count)) MultiModalDataItems({key: [None] * mm_count
for key in repl_by_key}, for key in repl_by_key}),
BatchFeature(),
) )
# Only displayed on error # Only displayed on error
...@@ -343,12 +342,12 @@ def test_find_replace_text( ...@@ -343,12 +342,12 @@ def test_find_replace_text(
"pattern_3": [918], "pattern_3": [918],
}, },
{ {
# Test whether target is confused with repl_unit # Test whether target is confused with replacement
"pattern_1": ([32000, 32000], 1), "pattern_1": [32000, 32000],
# Test empty repl_unit # Test empty replacement
"pattern_2": ([], 1), "pattern_2": [],
# Test multiple repl_count # Test dynamic replacement (beyond the form of `unit * count`)
"pattern_3": ([1550], 2), "pattern_3": [1550, 918, 1550],
}, },
), ),
] ]
...@@ -357,8 +356,8 @@ def test_find_replace_text( ...@@ -357,8 +356,8 @@ def test_find_replace_text(
("mm_count", "expected"), ("mm_count", "expected"),
[ [
(0, [1, 9833, 28747, 32000, 9833, 28747, 32000, 32000, 918]), (0, [1, 9833, 28747, 32000, 9833, 28747, 32000, 32000, 918]),
(1, [1, 32000, 32000, 9833, 28747, 32000, 32000, 1550, 1550]), (1, [1, 32000, 32000, 9833, 28747, 32000, 32000, 1550, 918, 1550]),
(2, [1, 32000, 32000, 32000, 32000, 32000, 1550, 1550]), (2, [1, 32000, 32000, 32000, 32000, 32000, 1550, 918, 1550]),
] ]
) )
# yapf: enable # yapf: enable
...@@ -373,7 +372,7 @@ def test_find_replace_tokens( ...@@ -373,7 +372,7 @@ def test_find_replace_tokens(
mock_tokenizer = cast(AnyTokenizer, object()) mock_tokenizer = cast(AnyTokenizer, object())
prompt_repls = [ prompt_repls = [
PromptReplacement(target, *repl_by_key[key]).bind(key, mock_tokenizer) PromptReplacement(key, target, repl_by_key[key]).bind(mock_tokenizer)
for key, target in target_by_key.items() for key, target in target_by_key.items()
] ]
matches = find_token_matches(prompt, prompt_repls) matches = find_token_matches(prompt, prompt_repls)
...@@ -381,9 +380,8 @@ def test_find_replace_tokens( ...@@ -381,9 +380,8 @@ def test_find_replace_tokens(
result = replace_token_matches( result = replace_token_matches(
prompt, prompt,
matches, matches,
{key: list(range(mm_count)) MultiModalDataItems({key: [None] * mm_count
for key in repl_by_key}, for key in repl_by_key}),
BatchFeature(),
) )
# Only displayed on error # Only displayed on error
...@@ -399,9 +397,9 @@ def test_find_replace_tokens( ...@@ -399,9 +397,9 @@ def test_find_replace_tokens(
"repl_by_key", "repl_by_key",
[ [
{ {
"pattern_1": ([32000, 32000], 1), "pattern_1": [32000, 32000],
"pattern_2": ([], 1), "pattern_2": [],
"pattern_3": ([1550], 2), "pattern_3": [1550, 918, 1550],
}, },
], ],
) )
...@@ -414,48 +412,47 @@ def test_find_replace_tokens( ...@@ -414,48 +412,47 @@ def test_find_replace_tokens(
_PlaceholderInfo( _PlaceholderInfo(
modality="pattern_1", modality="pattern_1",
start_idx=6, start_idx=6,
unit=[32000, 32000], replacement=[32000, 32000],
unit_count=1,
), ),
], ],
), ),
( (
[1, 32000, 32000, 9833, 28747, 32000, 32000, 1550, 1550], [1, 32000, 32000, 9833, 28747, 32000, 32000, 1550, 918, 1550],
[ [
_PlaceholderInfo( _PlaceholderInfo(
modality="pattern_1", modality="pattern_1",
start_idx=1, start_idx=1,
unit=[32000, 32000], replacement=[32000, 32000],
unit_count=1,
), ),
_PlaceholderInfo( _PlaceholderInfo(
modality="pattern_1", modality="pattern_1",
start_idx=5, start_idx=5,
unit=[32000, 32000], replacement=[32000, 32000],
unit_count=1,
), ),
_PlaceholderInfo( _PlaceholderInfo(
modality="pattern_3", modality="pattern_3",
start_idx=7, start_idx=7,
unit=[1550], replacement=[1550, 918, 1550],
unit_count=2,
), ),
], ],
), ),
( (
[1, 32000, 32000, 32000, 32000, 32000, 1550, 1550], [1, 32000, 32000, 32000, 32000, 32000, 1550, 918, 1550],
[ [
_PlaceholderInfo( _PlaceholderInfo(
modality="pattern_1", modality="pattern_1",
start_idx=1, start_idx=1,
unit=[32000, 32000], replacement=[32000, 32000],
unit_count=2, ),
_PlaceholderInfo(
modality="pattern_1",
start_idx=3,
replacement=[32000, 32000],
), ),
_PlaceholderInfo( _PlaceholderInfo(
modality="pattern_3", modality="pattern_3",
start_idx=6, start_idx=6,
unit=[1550], replacement=[1550, 918, 1550],
unit_count=2,
), ),
], ],
), ),
...@@ -470,11 +467,17 @@ def test_iter_placeholders( ...@@ -470,11 +467,17 @@ def test_iter_placeholders(
mock_tokenizer = cast(AnyTokenizer, object()) mock_tokenizer = cast(AnyTokenizer, object())
prompt_repls = [ prompt_repls = [
PromptReplacement([], *repl).bind(key, mock_tokenizer) PromptReplacement(key, [], repl).bind(mock_tokenizer)
for key, repl in repl_by_key.items() for key, repl in repl_by_key.items()
] ]
result = list(iter_placeholders(prompt_repls, prompt)) result = list(
iter_placeholders(
prompt_repls,
prompt,
# Effectively match all occurrences in the prompt
MultiModalDataItems({key: [None] * 3 for key in repl_by_key}),
))
# Only displayed on error # Only displayed on error
print("result:", result) print("result:", result)
......
...@@ -3,14 +3,14 @@ from typing import Optional ...@@ -3,14 +3,14 @@ from typing import Optional
import torch import torch
from vllm.model_executor.models.llava import (LlavaForConditionalGeneration, from vllm.model_executor.models.llava import (LlavaForConditionalGeneration,
LlavaProcessor, LlavaMultiModalProcessor,
get_max_llava_image_tokens) get_max_llava_image_tokens)
from vllm.model_executor.sampling_metadata import SamplingMetadata from vllm.model_executor.sampling_metadata import SamplingMetadata
from vllm.multimodal import MULTIMODAL_REGISTRY from vllm.multimodal import MULTIMODAL_REGISTRY
@MULTIMODAL_REGISTRY.register_max_image_tokens(get_max_llava_image_tokens) @MULTIMODAL_REGISTRY.register_max_image_tokens(get_max_llava_image_tokens)
@MULTIMODAL_REGISTRY.register_processor(LlavaProcessor) @MULTIMODAL_REGISTRY.register_processor(LlavaMultiModalProcessor)
class MyLlava(LlavaForConditionalGeneration): class MyLlava(LlavaForConditionalGeneration):
def compute_logits( def compute_logits(
......
...@@ -2,7 +2,7 @@ import functools ...@@ -2,7 +2,7 @@ import functools
from collections import UserDict from collections import UserDict
from dataclasses import dataclass from dataclasses import dataclass
from typing import (TYPE_CHECKING, Any, Callable, Dict, Mapping, NamedTuple, from typing import (TYPE_CHECKING, Any, Callable, Dict, Mapping, NamedTuple,
Optional, Protocol, Type, cast) Optional, Protocol, Type)
from torch import nn from torch import nn
from transformers import PretrainedConfig, ProcessorMixin from transformers import PretrainedConfig, ProcessorMixin
...@@ -47,7 +47,6 @@ class InputContext: ...@@ -47,7 +47,6 @@ class InputContext:
Raises: Raises:
TypeError: If the model is not of the specified type. TypeError: If the model is not of the specified type.
""" """
hf_config = self.model_config.hf_config hf_config = self.model_config.hf_config
if not isinstance(hf_config, hf_config_type): if not isinstance(hf_config, hf_config_type):
raise TypeError("Invalid type of HuggingFace config. " raise TypeError("Invalid type of HuggingFace config. "
...@@ -60,21 +59,70 @@ class InputContext: ...@@ -60,21 +59,70 @@ class InputContext:
""" """
Get the HuggingFace image processor configuration of the model. Get the HuggingFace image processor configuration of the model.
""" """
return self.model_config.hf_image_processor_config return self.model_config.hf_image_processor_config
def get_mm_config(self):
"""
Get the multimodal config of the model.
Raises:
RuntimeError: If the model is not a multimodal model.
"""
mm_config = self.model_config.multimodal_config
if mm_config is None:
raise RuntimeError("Not a multimodal model")
return mm_config
def get_hf_processor(self, **kwargs: object) -> ProcessorMixin:
base_kwargs = self.model_config.mm_processor_kwargs
if base_kwargs is None:
base_kwargs = {}
merged_kwargs = {**base_kwargs, **kwargs}
return cached_get_processor(
self.model_config.model,
trust_remote_code=self.model_config.trust_remote_code,
**merged_kwargs,
)
@dataclass(frozen=True) @dataclass(frozen=True)
class InputProcessingContext(InputContext): class InputProcessingContext(InputContext):
tokenizer: AnyTokenizer tokenizer: AnyTokenizer
"""The tokenizer used to tokenize the inputs.""" """The tokenizer used to tokenize the inputs."""
def get_hf_processor(self, **kwargs) -> ProcessorMixin: def get_hf_processor(self, **kwargs: object) -> ProcessorMixin:
base_kwargs = self.model_config.mm_processor_kwargs
if base_kwargs is None:
base_kwargs = {}
merged_kwargs = {**base_kwargs, **kwargs}
return cached_get_processor( return cached_get_processor(
self.model_config.tokenizer, self.model_config.model,
tokenizer=self.tokenizer, # Override the tokenizer with ours tokenizer=self.tokenizer, # Override the tokenizer with ours
trust_remote_code=self.model_config.trust_remote_code, trust_remote_code=self.model_config.trust_remote_code,
**kwargs) **merged_kwargs,
)
def resolve_hf_processor_call_kwargs(
self,
hf_processor: ProcessorMixin,
inference_kwargs: Mapping[str, object],
) -> Mapping[str, object]:
assert callable(hf_processor)
base_kwargs = self.model_config.mm_processor_kwargs
if base_kwargs is None:
base_kwargs = {}
return resolve_mm_processor_kwargs(
base_kwargs,
inference_kwargs,
hf_processor,
)
N = TypeVar("N", bound=Type[nn.Module]) N = TypeVar("N", bound=Type[nn.Module])
...@@ -171,7 +219,8 @@ class InputRegistry: ...@@ -171,7 +219,8 @@ class InputRegistry:
""" """
def wrapper(model_cls: N) -> N: def wrapper(model_cls: N) -> N:
if model_cls in self._dummy_factories_by_model_type: if self._dummy_factories_by_model_type.contains(model_cls,
strict=True):
logger.warning( logger.warning(
"Model class %s already has dummy data " "Model class %s already has dummy data "
"registered to %s. It is overwritten by the new one.", "registered to %s. It is overwritten by the new one.",
...@@ -195,7 +244,8 @@ class InputRegistry: ...@@ -195,7 +244,8 @@ class InputRegistry:
""" """
def wrapper(model_cls: N) -> N: def wrapper(model_cls: N) -> N:
if model_cls in self._dummy_encoder_factories_by_model_type: if self._dummy_encoder_factories_by_model_type.contains(
model_cls, strict=True):
logger.warning( logger.warning(
"Model class %s already has dummy encoder data " "Model class %s already has dummy encoder data "
"registered to %s. It is overwritten by the new one.", "registered to %s. It is overwritten by the new one.",
...@@ -305,7 +355,8 @@ class InputRegistry: ...@@ -305,7 +355,8 @@ class InputRegistry:
""" """
def wrapper(model_cls: N) -> N: def wrapper(model_cls: N) -> N:
if model_cls in self._input_processors_by_model_type: if self._input_processors_by_model_type.contains(model_cls,
strict=True):
logger.warning( logger.warning(
"Model class %s already has input processor " "Model class %s already has input processor "
"registered to %s. It is overwritten by the new one.", "registered to %s. It is overwritten by the new one.",
...@@ -357,7 +408,7 @@ class InputRegistry: ...@@ -357,7 +408,7 @@ class InputRegistry:
# If it's empty, it'll fall back to the default kwarg values # If it's empty, it'll fall back to the default kwarg values
mm_processor_kwargs = resolve_mm_processor_kwargs( mm_processor_kwargs = resolve_mm_processor_kwargs(
model_config.mm_processor_kwargs, model_config.mm_processor_kwargs,
cast(Dict[str, Any], inputs.get("mm_processor_kwargs")), inputs.get("mm_processor_kwargs", {}), # type: ignore
processor, processor,
) )
......
...@@ -5,10 +5,10 @@ from typing import (Iterable, List, Literal, Mapping, Optional, Protocol, Set, ...@@ -5,10 +5,10 @@ from typing import (Iterable, List, Literal, Mapping, Optional, Protocol, Set,
import torch import torch
import torch.nn as nn import torch.nn as nn
from PIL.Image import Image
from transformers import (BatchFeature, CLIPVisionConfig, LlavaConfig, from transformers import (BatchFeature, CLIPVisionConfig, LlavaConfig,
PixtralVisionConfig, PretrainedConfig, PixtralVisionConfig, PretrainedConfig,
ProcessorMixin, SiglipVisionConfig) ProcessorMixin, SiglipVisionConfig)
from transformers.models.llava import LlavaProcessor
from transformers.models.pixtral import PixtralProcessor from transformers.models.pixtral import PixtralProcessor
from vllm.attention import AttentionMetadata from vllm.attention import AttentionMetadata
...@@ -21,11 +21,9 @@ from vllm.model_executor.layers.quantization import QuantizationConfig ...@@ -21,11 +21,9 @@ from vllm.model_executor.layers.quantization import QuantizationConfig
from vllm.model_executor.layers.sampler import SamplerOutput, get_sampler from vllm.model_executor.layers.sampler import SamplerOutput, get_sampler
from vllm.model_executor.sampling_metadata import SamplingMetadata from vllm.model_executor.sampling_metadata import SamplingMetadata
from vllm.multimodal import MULTIMODAL_REGISTRY from vllm.multimodal import MULTIMODAL_REGISTRY
from vllm.multimodal.inputs import MultiModalKwargs, NestedTensors from vllm.multimodal.inputs import NestedTensors
from vllm.multimodal.processing import (BaseMultiModalProcessor, from vllm.multimodal.processing import (BaseMultiModalProcessor,
InputProcessingContext, MultiModalDataItems, ProcessorInputs,
ModalityProcessingMetadata,
MultiModalProcessingMetadata,
PromptReplacement) PromptReplacement)
from vllm.sequence import IntermediateTensors from vllm.sequence import IntermediateTensors
...@@ -33,7 +31,8 @@ from .clip import (CLIPVisionModel, dummy_image_for_clip, ...@@ -33,7 +31,8 @@ from .clip import (CLIPVisionModel, dummy_image_for_clip,
get_max_clip_image_tokens) get_max_clip_image_tokens)
from .interfaces import SupportsMultiModal, SupportsPP from .interfaces import SupportsMultiModal, SupportsPP
from .pixtral import (PixtralHFVisionModel, dummy_image_for_pixtral_hf, from .pixtral import (PixtralHFVisionModel, dummy_image_for_pixtral_hf,
get_max_pixtral_hf_image_tokens) get_max_pixtral_hf_image_tokens,
get_pixtral_hf_image_feature_size)
from .siglip import (SiglipVisionModel, dummy_image_for_siglip, from .siglip import (SiglipVisionModel, dummy_image_for_siglip,
get_max_siglip_image_tokens) get_max_siglip_image_tokens)
from .utils import (AutoWeightsLoader, flatten_bn, init_vllm_registered_model, from .utils import (AutoWeightsLoader, flatten_bn, init_vllm_registered_model,
...@@ -115,62 +114,7 @@ def get_max_llava_image_tokens(ctx: InputContext): ...@@ -115,62 +114,7 @@ def get_max_llava_image_tokens(ctx: InputContext):
raise ValueError(f"Unexpected select feature strategy: {strategy}") raise ValueError(f"Unexpected select feature strategy: {strategy}")
def dummy_mm_kwargs_for_llava(ctx: InputProcessingContext, class LlavaMultiModalProcessor(BaseMultiModalProcessor):
mm_counts: Mapping[str, int]):
hf_config = ctx.get_hf_config(LlavaConfig)
vision_config = hf_config.vision_config
num_images = mm_counts["image"]
if isinstance(vision_config, CLIPVisionConfig):
data = dummy_image_for_clip(vision_config, num_images)
elif isinstance(vision_config, SiglipVisionConfig):
data = dummy_image_for_siglip(vision_config, num_images)
elif isinstance(vision_config, PixtralVisionConfig):
data = dummy_image_for_pixtral_hf(vision_config, num_images)
else:
msg = f"Unsupported vision config: {type(vision_config)}"
raise NotImplementedError(msg)
hf_processor = ctx.get_hf_processor()
image_processor = hf_processor.image_processor # type: ignore
hf_inputs = image_processor.preprocess(data['image'], return_tensors="pt")
is_pixtral = isinstance(hf_processor, PixtralProcessor)
return MultiModalKwargs(
**hf_inputs,
is_pixtral=torch.tensor(is_pixtral),
)
def create_metadata_for_llava(
ctx: InputProcessingContext) -> MultiModalProcessingMetadata:
hf_config = ctx.get_hf_config(LlavaConfig)
image_token_id = hf_config.image_token_index
def get_repl_count(
mm_items: list[Image],
hf_inputs: BatchFeature,
item_idx: int,
) -> int:
return get_max_llava_image_tokens(ctx)
return {
"image":
ModalityProcessingMetadata(prompt_repls=[
PromptReplacement(target=[image_token_id],
repl_unit=[image_token_id],
repl_count=get_repl_count),
]),
}
class LlavaProcessor(BaseMultiModalProcessor):
def __init__(self, ctx: InputProcessingContext) -> None:
super().__init__(
ctx=ctx,
metadata=create_metadata_for_llava(ctx),
)
def _patch_pixtral_processor(self, hf_processor: PixtralProcessor): def _patch_pixtral_processor(self, hf_processor: PixtralProcessor):
if getattr(hf_processor, "__is_patched__", False): if getattr(hf_processor, "__is_patched__", False):
...@@ -188,18 +132,72 @@ class LlavaProcessor(BaseMultiModalProcessor): ...@@ -188,18 +132,72 @@ class LlavaProcessor(BaseMultiModalProcessor):
hf_processor.__is_patched__ = True # type: ignore hf_processor.__is_patched__ = True # type: ignore
def _get_hf_processor(self) -> ProcessorMixin: def _get_hf_processor(self) -> Union[LlavaProcessor, PixtralProcessor]:
hf_processor = self.ctx.get_hf_processor() hf_processor = self.ctx.get_hf_processor()
assert isinstance(hf_processor, (LlavaProcessor, PixtralProcessor))
if isinstance(hf_processor, PixtralProcessor): if isinstance(hf_processor, PixtralProcessor):
self._patch_pixtral_processor(hf_processor) self._patch_pixtral_processor(hf_processor)
return hf_processor return hf_processor
def _get_dummy_mm_kwargs( def _get_prompt_replacements(
self,
mm_items: MultiModalDataItems,
hf_inputs: BatchFeature,
mm_processor_kwargs: Mapping[str, object],
) -> list[PromptReplacement]:
hf_config = self.ctx.get_hf_config(LlavaConfig)
image_token_id = hf_config.image_token_index
processor = self._get_hf_processor()
if isinstance(processor, PixtralProcessor):
image_token = processor.image_token
image_break_token = processor.image_break_token
image_end_token = processor.image_end_token
vision_config = hf_config.vision_config
assert isinstance(vision_config, PixtralVisionConfig)
def get_replacement_pixtral(item_idx: int):
image_size = mm_items.get_image_size(item_idx)
(
num_width_tokens,
num_height_tokens,
) = get_pixtral_hf_image_feature_size(
vision_config,
image_width=image_size.width,
image_height=image_size.height,
)
tokens = ([image_token] * num_width_tokens +
[image_break_token]) * num_height_tokens
tokens[-1] = image_end_token
return "".join(tokens)
return [
PromptReplacement(
modality="image",
target=[image_token_id],
replacement=get_replacement_pixtral,
),
]
max_image_tokens = get_max_llava_image_tokens(self.ctx)
return [
PromptReplacement(
modality="image",
target=[image_token_id],
replacement=[image_token_id] * max_image_tokens,
)
]
def _get_dummy_mm_inputs(
self, self,
mm_counts: Mapping[str, int], mm_counts: Mapping[str, int],
) -> MultiModalKwargs: ) -> ProcessorInputs:
hf_config = self.ctx.get_hf_config(LlavaConfig) hf_config = self.ctx.get_hf_config(LlavaConfig)
vision_config = hf_config.vision_config vision_config = hf_config.vision_config
num_images = mm_counts["image"] num_images = mm_counts["image"]
...@@ -215,11 +213,13 @@ class LlavaProcessor(BaseMultiModalProcessor): ...@@ -215,11 +213,13 @@ class LlavaProcessor(BaseMultiModalProcessor):
raise NotImplementedError(msg) raise NotImplementedError(msg)
hf_processor = self._get_hf_processor() hf_processor = self._get_hf_processor()
image_processor = hf_processor.image_processor # type: ignore image_token = hf_processor.image_token
hf_inputs = image_processor.preprocess(data['image'],
return_tensors="pt")
return MultiModalKwargs(**hf_inputs) return ProcessorInputs(
prompt_text=image_token * num_images,
mm_data=data,
mm_processor_kwargs={},
)
class LlavaLikeConfig(Protocol): class LlavaLikeConfig(Protocol):
...@@ -303,7 +303,7 @@ def init_vision_tower_for_llava( ...@@ -303,7 +303,7 @@ def init_vision_tower_for_llava(
@MULTIMODAL_REGISTRY.register_max_image_tokens(get_max_llava_image_tokens) @MULTIMODAL_REGISTRY.register_max_image_tokens(get_max_llava_image_tokens)
@MULTIMODAL_REGISTRY.register_processor(LlavaProcessor) @MULTIMODAL_REGISTRY.register_processor(LlavaMultiModalProcessor)
class LlavaForConditionalGeneration(nn.Module, SupportsMultiModal, SupportsPP): class LlavaForConditionalGeneration(nn.Module, SupportsMultiModal, SupportsPP):
# BitandBytes specific attributes # BitandBytes specific attributes
bitsandbytes_stacked_params_mapping = { bitsandbytes_stacked_params_mapping = {
...@@ -584,7 +584,7 @@ class LlavaForConditionalGeneration(nn.Module, SupportsMultiModal, SupportsPP): ...@@ -584,7 +584,7 @@ class LlavaForConditionalGeneration(nn.Module, SupportsMultiModal, SupportsPP):
return loader.load_weights(weights) return loader.load_weights(weights)
class MantisProcessor(LlavaProcessor): class MantisMultiModalProcessor(LlavaMultiModalProcessor):
def _get_hf_processor(self) -> ProcessorMixin: def _get_hf_processor(self) -> ProcessorMixin:
try: try:
...@@ -604,6 +604,6 @@ class MantisProcessor(LlavaProcessor): ...@@ -604,6 +604,6 @@ class MantisProcessor(LlavaProcessor):
# To use this model, please use # To use this model, please use
# `--hf_overrides '{"architectures": ["MantisForConditionalGeneration"]}'` # `--hf_overrides '{"architectures": ["MantisForConditionalGeneration"]}'`
@MULTIMODAL_REGISTRY.register_max_image_tokens(get_max_llava_image_tokens) @MULTIMODAL_REGISTRY.register_max_image_tokens(get_max_llava_image_tokens)
@MULTIMODAL_REGISTRY.register_processor(MantisProcessor) @MULTIMODAL_REGISTRY.register_processor(MantisMultiModalProcessor)
class MantisForConditionalGeneration(LlavaForConditionalGeneration): class MantisForConditionalGeneration(LlavaForConditionalGeneration):
pass pass
...@@ -32,13 +32,10 @@ from vllm.model_executor.layers.vocab_parallel_embedding import ( ...@@ -32,13 +32,10 @@ from vllm.model_executor.layers.vocab_parallel_embedding import (
from vllm.model_executor.models.clip import CLIPVisionModel from vllm.model_executor.models.clip import CLIPVisionModel
from vllm.model_executor.sampling_metadata import SamplingMetadata from vllm.model_executor.sampling_metadata import SamplingMetadata
from vllm.multimodal import MULTIMODAL_REGISTRY from vllm.multimodal import MULTIMODAL_REGISTRY
from vllm.multimodal.image import cached_get_image_processor from vllm.multimodal.inputs import NestedTensors
from vllm.multimodal.inputs import MultiModalKwargs, NestedTensors
from vllm.multimodal.processing import (BaseMultiModalProcessor, from vllm.multimodal.processing import (BaseMultiModalProcessor,
InputProcessingContext,
ModalityProcessingMetadata,
MultiModalDataDict, MultiModalDataDict,
MultiModalProcessingMetadata, MultiModalDataItems, ProcessorInputs,
PromptReplacement) PromptReplacement)
from vllm.sequence import IntermediateTensors from vllm.sequence import IntermediateTensors
from vllm.utils import is_list_of from vllm.utils import is_list_of
...@@ -305,64 +302,17 @@ class Phi3HDImageEmbedding(Phi3ImageEmbeddingBase): ...@@ -305,64 +302,17 @@ class Phi3HDImageEmbedding(Phi3ImageEmbeddingBase):
return image_features_hd_newline return image_features_hd_newline
def get_max_phi3v_image_tokens(ctx: InputContext, def get_max_phi3v_image_tokens(ctx: InputContext) -> int:
*, processor = ctx.get_hf_processor()
num_crops: Optional[int] = None): image_processor = processor.image_processor # type: ignore
mm_processor_kwargs = {}
if num_crops is not None:
mm_processor_kwargs["num_crops"] = num_crops
model_config = ctx.model_config return image_processor.calc_num_image_tokens_from_image_size(
image_processor = cached_get_image_processor(
model_config.model,
trust_remote_code=model_config.trust_remote_code,
**mm_processor_kwargs,
)
num_tokens = image_processor.calc_num_image_tokens_from_image_size(
width=MAX_IMAGE_FEATURE_SIZE_WIDTH, width=MAX_IMAGE_FEATURE_SIZE_WIDTH,
height=MAX_IMAGE_FEATURE_SIZE_HEIGHT, height=MAX_IMAGE_FEATURE_SIZE_HEIGHT,
) )
return num_tokens
def dummy_mm_kwargs_for_phi3v(ctx: InputProcessingContext,
mm_counts: Mapping[str, int]):
num_images = mm_counts["image"]
data = dummy_image_for_clip(
CLIP_VIT_LARGE_PATCH14_336_CONFIG,
num_images,
image_width_override=MAX_IMAGE_FEATURE_SIZE_WIDTH,
image_height_override=MAX_IMAGE_FEATURE_SIZE_HEIGHT,
)
hf_processor = ctx.get_hf_processor()
image_processor = hf_processor.image_processor # type: ignore
hf_inputs = image_processor.preprocess(data['image'], return_tensors="pt")
return MultiModalKwargs(**hf_inputs)
def create_metadata_for_phi3v(
ctx: InputProcessingContext) -> MultiModalProcessingMetadata:
return {
"image":
ModalityProcessingMetadata(prompt_repls=[
PromptReplacement(target=[_IMAGE_TOKEN_ID],
repl_unit=[_IMAGE_TOKEN_ID],
repl_count=get_max_phi3v_image_tokens(ctx)),
]),
}
class Phi3VProcessor(BaseMultiModalProcessor):
def __init__(self, ctx: InputProcessingContext) -> None: class Phi3VMultiModalProcessor(BaseMultiModalProcessor):
super().__init__(
ctx=ctx,
metadata=create_metadata_for_phi3v(ctx),
)
def _get_hf_processor( def _get_hf_processor(
self, self,
...@@ -389,15 +339,61 @@ class Phi3VProcessor(BaseMultiModalProcessor): ...@@ -389,15 +339,61 @@ class Phi3VProcessor(BaseMultiModalProcessor):
processed_outputs['input_ids'] = token_ids processed_outputs['input_ids'] = token_ids
return processed_outputs return processed_outputs
def _get_dummy_mm_kwargs( def _get_prompt_replacements(
self,
mm_items: MultiModalDataItems,
hf_inputs: BatchFeature,
mm_processor_kwargs: Mapping[str, object],
) -> list[PromptReplacement]:
hf_processor = self._get_hf_processor()
image_tokens: list[str] = hf_processor.img_tokens # type: ignore
image_processor = hf_processor.image_processor # type: ignore
mm_config = self.ctx.get_mm_config()
max_images = mm_config.limit_per_prompt.get("image", 1)
def get_replacement_phi3v(item_idx: int):
image_size = mm_items.get_image_size(item_idx)
num_tokens = image_processor.calc_num_image_tokens_from_image_size(
width=image_size.width,
height=image_size.height,
)
return [_IMAGE_TOKEN_ID] * num_tokens
return [
PromptReplacement(
modality="image",
target=image_token,
replacement=get_replacement_phi3v,
) for image_token in image_tokens[:max_images]
]
def _get_dummy_mm_inputs(
self, self,
mm_counts: Mapping[str, int], mm_counts: Mapping[str, int],
) -> MultiModalKwargs: ) -> ProcessorInputs:
return dummy_mm_kwargs_for_phi3v(self.ctx, mm_counts) num_images = mm_counts["image"]
data = dummy_image_for_clip(
CLIP_VIT_LARGE_PATCH14_336_CONFIG,
num_images,
image_width_override=MAX_IMAGE_FEATURE_SIZE_WIDTH,
image_height_override=MAX_IMAGE_FEATURE_SIZE_HEIGHT,
)
hf_processor = self._get_hf_processor()
image_tokens: list[str] = hf_processor.img_tokens # type: ignore
return ProcessorInputs(
prompt_text="".join(image_tokens[:num_images]),
mm_data=data,
mm_processor_kwargs={},
)
@MULTIMODAL_REGISTRY.register_max_image_tokens(get_max_phi3v_image_tokens) @MULTIMODAL_REGISTRY.register_max_image_tokens(get_max_phi3v_image_tokens)
@MULTIMODAL_REGISTRY.register_processor(Phi3VProcessor) @MULTIMODAL_REGISTRY.register_processor(Phi3VMultiModalProcessor)
class Phi3VForCausalLM(nn.Module, SupportsMultiModal, SupportsPP): class Phi3VForCausalLM(nn.Module, SupportsMultiModal, SupportsPP):
def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
......
...@@ -72,7 +72,7 @@ def dummy_data_for_pixtral(ctx: InputContext, seq_len: int, ...@@ -72,7 +72,7 @@ def dummy_data_for_pixtral(ctx: InputContext, seq_len: int,
mm_encoder = tokenizer.mistral.instruct_tokenizer.mm_encoder mm_encoder = tokenizer.mistral.instruct_tokenizer.mm_encoder
image_token_id = mm_encoder.special_ids.img image_token_id = mm_encoder.special_ids.img
mm_config = ctx.model_config.multimodal_config mm_config = ctx.get_mm_config()
num_images = mm_config.limit_per_prompt.get("image", 1) num_images = mm_config.limit_per_prompt.get("image", 1)
# dummy size # dummy size
......
...@@ -99,7 +99,7 @@ class MultiModalPlugin(ABC): ...@@ -99,7 +99,7 @@ class MultiModalPlugin(ABC):
""" """
def wrapper(model_cls: N) -> N: def wrapper(model_cls: N) -> N:
if model_cls in self._input_mappers: if self._input_mappers.contains(model_cls, strict=True):
logger.warning( logger.warning(
"Model class %s already has an input mapper " "Model class %s already has an input mapper "
"registered to %s. It is overwritten by the new one.", "registered to %s. It is overwritten by the new one.",
...@@ -194,7 +194,7 @@ class MultiModalPlugin(ABC): ...@@ -194,7 +194,7 @@ class MultiModalPlugin(ABC):
""" """
def wrapper(model_cls: N) -> N: def wrapper(model_cls: N) -> N:
if model_cls in self._max_mm_tokens: if self._max_mm_tokens.contains(model_cls, strict=True):
logger.warning( logger.warning(
"Model class %s already calculates maximum number of " "Model class %s already calculates maximum number of "
"tokens in %s. It is overwritten by the new one.", "tokens in %s. It is overwritten by the new one.",
......
This diff is collapsed.
...@@ -299,9 +299,9 @@ class MultiModalRegistry: ...@@ -299,9 +299,9 @@ class MultiModalRegistry:
""" """
def wrapper(model_cls: N) -> N: def wrapper(model_cls: N) -> N:
if model_cls in self._processor_factories: if self._processor_factories.contains(model_cls, strict=True):
logger.warning( logger.warning(
"Model class %s already has an input mapper " "Model class %s already has a multi-modal processor "
"registered to %s. It is overwritten by the new one.", "registered to %s. It is overwritten by the new one.",
model_cls, self) model_cls, self)
......
...@@ -1370,8 +1370,8 @@ def supports_kw( ...@@ -1370,8 +1370,8 @@ def supports_kw(
def resolve_mm_processor_kwargs( def resolve_mm_processor_kwargs(
init_kwargs: Optional[Dict[str, Any]], init_kwargs: Optional[Mapping[str, object]],
inference_kwargs: Optional[Dict[str, Any]], inference_kwargs: Optional[Mapping[str, object]],
callable: Callable[..., object], callable: Callable[..., object],
allow_var_kwargs: bool = False, allow_var_kwargs: bool = False,
) -> Dict[str, Any]: ) -> Dict[str, Any]:
...@@ -1405,7 +1405,7 @@ def resolve_mm_processor_kwargs( ...@@ -1405,7 +1405,7 @@ def resolve_mm_processor_kwargs(
def get_allowed_kwarg_only_overrides( def get_allowed_kwarg_only_overrides(
callable: Callable[..., object], callable: Callable[..., object],
overrides: Optional[Dict[str, Any]], overrides: Optional[Mapping[str, object]],
allow_var_kwargs: bool = False, allow_var_kwargs: bool = False,
) -> Dict[str, Any]: ) -> Dict[str, Any]:
""" """
...@@ -1524,9 +1524,15 @@ class ClassRegistry(UserDict[Type[T], _V]): ...@@ -1524,9 +1524,15 @@ class ClassRegistry(UserDict[Type[T], _V]):
raise KeyError(key) raise KeyError(key)
def __contains__(self, key: object) -> bool: def __contains__(self, key: object) -> bool:
return self.contains(key)
def contains(self, key: object, *, strict: bool = False) -> bool:
if not isinstance(key, type): if not isinstance(key, type):
return False return False
if strict:
return key in self.data
return any(cls in self.data for cls in key.mro()) return any(cls in self.data for cls in key.mro())
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment