Unverified commit 93abf23a, authored by Cyrus Leung, committed by GitHub
Browse files

[VLM] Fully dynamic prompt replacement in merged input processor (#11199)


Signed-off-by: DarkLight1337 <tlleungac@connect.ust.hk>
parent 9c3dadd1
......@@ -97,9 +97,6 @@ def run_phi3v(question: str, modality: str):
# max_model_len (128k) for this model may cause OOM.
# You may lower either to run this example on lower-end GPUs.
# In this example, we override max_num_seqs to 5 while
# keeping the original context length of 128k.
# num_crops is an override kwarg to the multimodal image processor;
# For some models, e.g., Phi-3.5-vision-instruct, it is recommended
# to use 16 for single frame scenarios, and 4 for multi-frame.
......@@ -113,7 +110,7 @@ def run_phi3v(question: str, modality: str):
# https://huggingface.co/microsoft/Phi-3.5-vision-instruct#loading-the-model-locally
# https://huggingface.co/microsoft/Phi-3.5-vision-instruct/blob/main/processing_phi3_v.py#L194
llm = LLM(
model="microsoft/Phi-3-vision-128k-instruct",
model="microsoft/Phi-3.5-vision-instruct",
trust_remote_code=True,
max_model_len=4096,
max_num_seqs=2,
......
......@@ -16,8 +16,8 @@ models = ["microsoft/Phi-3.5-vision-instruct"]
# Wrap lazy imports to avoid initializing CUDA during test collection
@pytest.fixture()
def processor_for_phi3v():
from vllm.model_executor.models.phi3v import Phi3VProcessor
return Phi3VProcessor
from vllm.model_executor.models.phi3v import Phi3VMultiModalProcessor
return Phi3VMultiModalProcessor
@pytest.fixture()
......
from typing import cast
import pytest
from transformers import BatchFeature
from vllm.multimodal.processing import (PromptReplacement, _PlaceholderInfo,
find_text_matches, find_token_matches,
iter_placeholders, iter_token_matches,
from vllm.multimodal.processing import (MultiModalDataItems, PromptReplacement,
_PlaceholderInfo, find_text_matches,
find_token_matches, iter_placeholders,
iter_token_matches,
replace_text_matches,
replace_token_matches)
from vllm.transformers_utils.tokenizer import AnyTokenizer
......@@ -16,7 +16,7 @@ from vllm.utils import full_groupby
@pytest.mark.parametrize(
("token_ids", "match_ids", "expected"),
[
([], [], [{ "start_idx": 0, "end_idx": 0 }]),
([], [], []),
([], [32000], []),
(
[32000, 32000, 32000],
......@@ -83,7 +83,7 @@ def test_iter_token_matches(token_ids, match_ids, expected):
"pattern_2": [32000],
},
{
"pattern_1": [{ "start_idx": 0, "end_idx": 0 }],
"pattern_1": [],
"pattern_2": [],
}
),
......@@ -136,7 +136,7 @@ def test_find_token_matches(prompt, target_by_key, expected_by_key):
mock_tokenizer = cast(AnyTokenizer, object())
prompt_repls = [
PromptReplacement(target, [], 0).bind(key, mock_tokenizer)
PromptReplacement(key, target, []).bind(mock_tokenizer)
for key, target in target_by_key.items()
]
result = find_token_matches(prompt, prompt_repls)
......@@ -243,7 +243,7 @@ def test_find_text_matches(prompt, target_by_key, expected_by_key):
mock_tokenizer = cast(AnyTokenizer, object())
prompt_repls = [
PromptReplacement(target, [], 0).bind(key, mock_tokenizer)
PromptReplacement(key, target, []).bind(mock_tokenizer)
for key, target in target_by_key.items()
]
result = find_text_matches(prompt, prompt_repls)
......@@ -276,12 +276,12 @@ def test_find_text_matches(prompt, target_by_key, expected_by_key):
"pattern_3": "!",
},
{
# Test whether target is confused with repl_unit
"pattern_1": ("<image><image>", 1),
# Test empty repl_unit
"pattern_2": ("", 1),
# Test multiple repl_count
"pattern_3": ("?", 2),
# Test whether target is confused with replacement
"pattern_1": "<image><image>",
# Test empty replacement
"pattern_2": "",
# Test dynamic replacement (beyond the form of `unit * count`)
"pattern_3": "?!?",
},
),
]
......@@ -290,8 +290,8 @@ def test_find_text_matches(prompt, target_by_key, expected_by_key):
("mm_count", "expected"),
[
(0, "Image:<image>Image:<image><image>!"),
(1, "<image><image>Image:<image><image>??"),
(2, "<image><image><image><image><image>??"),
(1, "<image><image>Image:<image><image>?!?"),
(2, "<image><image><image><image><image>?!?"),
]
)
# yapf: enable
......@@ -306,7 +306,7 @@ def test_find_replace_text(
mock_tokenizer = cast(AnyTokenizer, object())
prompt_repls = [
PromptReplacement(target, *repl_by_key[key]).bind(key, mock_tokenizer)
PromptReplacement(key, target, repl_by_key[key]).bind(mock_tokenizer)
for key, target in target_by_key.items()
]
matches = find_text_matches(prompt, prompt_repls)
......@@ -314,9 +314,8 @@ def test_find_replace_text(
result = replace_text_matches(
prompt,
matches,
{key: list(range(mm_count))
for key in repl_by_key},
BatchFeature(),
MultiModalDataItems({key: [None] * mm_count
for key in repl_by_key}),
)
# Only displayed on error
......@@ -343,12 +342,12 @@ def test_find_replace_text(
"pattern_3": [918],
},
{
# Test whether target is confused with repl_unit
"pattern_1": ([32000, 32000], 1),
# Test empty repl_unit
"pattern_2": ([], 1),
# Test multiple repl_count
"pattern_3": ([1550], 2),
# Test whether target is confused with replacement
"pattern_1": [32000, 32000],
# Test empty replacement
"pattern_2": [],
# Test dynamic replacement (beyond the form of `unit * count`)
"pattern_3": [1550, 918, 1550],
},
),
]
......@@ -357,8 +356,8 @@ def test_find_replace_text(
("mm_count", "expected"),
[
(0, [1, 9833, 28747, 32000, 9833, 28747, 32000, 32000, 918]),
(1, [1, 32000, 32000, 9833, 28747, 32000, 32000, 1550, 1550]),
(2, [1, 32000, 32000, 32000, 32000, 32000, 1550, 1550]),
(1, [1, 32000, 32000, 9833, 28747, 32000, 32000, 1550, 918, 1550]),
(2, [1, 32000, 32000, 32000, 32000, 32000, 1550, 918, 1550]),
]
)
# yapf: enable
......@@ -373,7 +372,7 @@ def test_find_replace_tokens(
mock_tokenizer = cast(AnyTokenizer, object())
prompt_repls = [
PromptReplacement(target, *repl_by_key[key]).bind(key, mock_tokenizer)
PromptReplacement(key, target, repl_by_key[key]).bind(mock_tokenizer)
for key, target in target_by_key.items()
]
matches = find_token_matches(prompt, prompt_repls)
......@@ -381,9 +380,8 @@ def test_find_replace_tokens(
result = replace_token_matches(
prompt,
matches,
{key: list(range(mm_count))
for key in repl_by_key},
BatchFeature(),
MultiModalDataItems({key: [None] * mm_count
for key in repl_by_key}),
)
# Only displayed on error
......@@ -399,9 +397,9 @@ def test_find_replace_tokens(
"repl_by_key",
[
{
"pattern_1": ([32000, 32000], 1),
"pattern_2": ([], 1),
"pattern_3": ([1550], 2),
"pattern_1": [32000, 32000],
"pattern_2": [],
"pattern_3": [1550, 918, 1550],
},
],
)
......@@ -414,48 +412,47 @@ def test_find_replace_tokens(
_PlaceholderInfo(
modality="pattern_1",
start_idx=6,
unit=[32000, 32000],
unit_count=1,
replacement=[32000, 32000],
),
],
),
(
[1, 32000, 32000, 9833, 28747, 32000, 32000, 1550, 1550],
[1, 32000, 32000, 9833, 28747, 32000, 32000, 1550, 918, 1550],
[
_PlaceholderInfo(
modality="pattern_1",
start_idx=1,
unit=[32000, 32000],
unit_count=1,
replacement=[32000, 32000],
),
_PlaceholderInfo(
modality="pattern_1",
start_idx=5,
unit=[32000, 32000],
unit_count=1,
replacement=[32000, 32000],
),
_PlaceholderInfo(
modality="pattern_3",
start_idx=7,
unit=[1550],
unit_count=2,
replacement=[1550, 918, 1550],
),
],
),
(
[1, 32000, 32000, 32000, 32000, 32000, 1550, 1550],
[1, 32000, 32000, 32000, 32000, 32000, 1550, 918, 1550],
[
_PlaceholderInfo(
modality="pattern_1",
start_idx=1,
unit=[32000, 32000],
unit_count=2,
replacement=[32000, 32000],
),
_PlaceholderInfo(
modality="pattern_1",
start_idx=3,
replacement=[32000, 32000],
),
_PlaceholderInfo(
modality="pattern_3",
start_idx=6,
unit=[1550],
unit_count=2,
replacement=[1550, 918, 1550],
),
],
),
......@@ -470,11 +467,17 @@ def test_iter_placeholders(
mock_tokenizer = cast(AnyTokenizer, object())
prompt_repls = [
PromptReplacement([], *repl).bind(key, mock_tokenizer)
PromptReplacement(key, [], repl).bind(mock_tokenizer)
for key, repl in repl_by_key.items()
]
result = list(iter_placeholders(prompt_repls, prompt))
result = list(
iter_placeholders(
prompt_repls,
prompt,
# Effectively match all occurrences in the prompt
MultiModalDataItems({key: [None] * 3 for key in repl_by_key}),
))
# Only displayed on error
print("result:", result)
......
......@@ -3,14 +3,14 @@ from typing import Optional
import torch
from vllm.model_executor.models.llava import (LlavaForConditionalGeneration,
LlavaProcessor,
LlavaMultiModalProcessor,
get_max_llava_image_tokens)
from vllm.model_executor.sampling_metadata import SamplingMetadata
from vllm.multimodal import MULTIMODAL_REGISTRY
@MULTIMODAL_REGISTRY.register_max_image_tokens(get_max_llava_image_tokens)
@MULTIMODAL_REGISTRY.register_processor(LlavaProcessor)
@MULTIMODAL_REGISTRY.register_processor(LlavaMultiModalProcessor)
class MyLlava(LlavaForConditionalGeneration):
def compute_logits(
......
......@@ -2,7 +2,7 @@ import functools
from collections import UserDict
from dataclasses import dataclass
from typing import (TYPE_CHECKING, Any, Callable, Dict, Mapping, NamedTuple,
Optional, Protocol, Type, cast)
Optional, Protocol, Type)
from torch import nn
from transformers import PretrainedConfig, ProcessorMixin
......@@ -47,7 +47,6 @@ class InputContext:
Raises:
TypeError: If the model is not of the specified type.
"""
hf_config = self.model_config.hf_config
if not isinstance(hf_config, hf_config_type):
raise TypeError("Invalid type of HuggingFace config. "
......@@ -60,21 +59,70 @@ class InputContext:
"""
Get the HuggingFace image processor configuration of the model.
"""
return self.model_config.hf_image_processor_config
def get_mm_config(self):
    """
    Return the multimodal config stored on the model config.

    Raises:
        RuntimeError: If ``model_config.multimodal_config`` is ``None``,
            i.e. the model is not a multimodal model.
    """
    config = self.model_config.multimodal_config
    if config is not None:
        return config

    raise RuntimeError("Not a multimodal model")
def get_hf_processor(self, **kwargs: object) -> ProcessorMixin:
    """
    Fetch the HuggingFace processor for this model via
    ``cached_get_processor``.

    The keyword arguments passed on are the model config's
    ``mm_processor_kwargs`` (treated as empty when ``None``) overlaid
    with ``kwargs``; on a key clash the value from ``kwargs`` wins.
    """
    model_config = self.model_config

    configured_kwargs = model_config.mm_processor_kwargs
    merged_kwargs = {} if configured_kwargs is None else dict(
        configured_kwargs)
    # Applied last so that call-site values override the configured ones.
    merged_kwargs.update(kwargs)

    return cached_get_processor(
        model_config.model,
        trust_remote_code=model_config.trust_remote_code,
        **merged_kwargs,
    )
@dataclass(frozen=True)
class InputProcessingContext(InputContext):
tokenizer: AnyTokenizer
"""The tokenizer used to tokenize the inputs."""
def get_hf_processor(self, **kwargs) -> ProcessorMixin:
def get_hf_processor(self, **kwargs: object) -> ProcessorMixin:
base_kwargs = self.model_config.mm_processor_kwargs
if base_kwargs is None:
base_kwargs = {}
merged_kwargs = {**base_kwargs, **kwargs}
return cached_get_processor(
self.model_config.tokenizer,
self.model_config.model,
tokenizer=self.tokenizer, # Override the tokenizer with ours
trust_remote_code=self.model_config.trust_remote_code,
**kwargs)
**merged_kwargs,
)
def resolve_hf_processor_call_kwargs(
    self,
    hf_processor: ProcessorMixin,
    inference_kwargs: Mapping[str, object],
) -> Mapping[str, object]:
    """
    Resolve the keyword arguments to use when calling ``hf_processor``.

    Delegates to ``resolve_mm_processor_kwargs`` with the model config's
    ``mm_processor_kwargs`` (an empty dict when ``None``) as the init
    kwargs, ``inference_kwargs`` as the inference kwargs, and
    ``hf_processor`` itself as the callable.

    Args:
        hf_processor: The HuggingFace processor; it must be callable.
        inference_kwargs: Keyword arguments supplied at inference time.

    Returns:
        Whatever ``resolve_mm_processor_kwargs`` returns for these inputs.
    """
    # The processor is handed over as the callable, so check that first.
    assert callable(hf_processor)

    configured_kwargs = self.model_config.mm_processor_kwargs
    init_kwargs = {} if configured_kwargs is None else configured_kwargs

    return resolve_mm_processor_kwargs(
        init_kwargs,
        inference_kwargs,
        hf_processor,
    )
N = TypeVar("N", bound=Type[nn.Module])
......@@ -171,7 +219,8 @@ class InputRegistry:
"""
def wrapper(model_cls: N) -> N:
if model_cls in self._dummy_factories_by_model_type:
if self._dummy_factories_by_model_type.contains(model_cls,
strict=True):
logger.warning(
"Model class %s already has dummy data "
"registered to %s. It is overwritten by the new one.",
......@@ -195,7 +244,8 @@ class InputRegistry:
"""
def wrapper(model_cls: N) -> N:
if model_cls in self._dummy_encoder_factories_by_model_type:
if self._dummy_encoder_factories_by_model_type.contains(
model_cls, strict=True):
logger.warning(
"Model class %s already has dummy encoder data "
"registered to %s. It is overwritten by the new one.",
......@@ -305,7 +355,8 @@ class InputRegistry:
"""
def wrapper(model_cls: N) -> N:
if model_cls in self._input_processors_by_model_type:
if self._input_processors_by_model_type.contains(model_cls,
strict=True):
logger.warning(
"Model class %s already has input processor "
"registered to %s. It is overwritten by the new one.",
......@@ -357,7 +408,7 @@ class InputRegistry:
# If it's empty, it'll fall back to the default kwarg values
mm_processor_kwargs = resolve_mm_processor_kwargs(
model_config.mm_processor_kwargs,
cast(Dict[str, Any], inputs.get("mm_processor_kwargs")),
inputs.get("mm_processor_kwargs", {}), # type: ignore
processor,
)
......
......@@ -5,10 +5,10 @@ from typing import (Iterable, List, Literal, Mapping, Optional, Protocol, Set,
import torch
import torch.nn as nn
from PIL.Image import Image
from transformers import (BatchFeature, CLIPVisionConfig, LlavaConfig,
PixtralVisionConfig, PretrainedConfig,
ProcessorMixin, SiglipVisionConfig)
from transformers.models.llava import LlavaProcessor
from transformers.models.pixtral import PixtralProcessor
from vllm.attention import AttentionMetadata
......@@ -21,11 +21,9 @@ from vllm.model_executor.layers.quantization import QuantizationConfig
from vllm.model_executor.layers.sampler import SamplerOutput, get_sampler
from vllm.model_executor.sampling_metadata import SamplingMetadata
from vllm.multimodal import MULTIMODAL_REGISTRY
from vllm.multimodal.inputs import MultiModalKwargs, NestedTensors
from vllm.multimodal.inputs import NestedTensors
from vllm.multimodal.processing import (BaseMultiModalProcessor,
InputProcessingContext,
ModalityProcessingMetadata,
MultiModalProcessingMetadata,
MultiModalDataItems, ProcessorInputs,
PromptReplacement)
from vllm.sequence import IntermediateTensors
......@@ -33,7 +31,8 @@ from .clip import (CLIPVisionModel, dummy_image_for_clip,
get_max_clip_image_tokens)
from .interfaces import SupportsMultiModal, SupportsPP
from .pixtral import (PixtralHFVisionModel, dummy_image_for_pixtral_hf,
get_max_pixtral_hf_image_tokens)
get_max_pixtral_hf_image_tokens,
get_pixtral_hf_image_feature_size)
from .siglip import (SiglipVisionModel, dummy_image_for_siglip,
get_max_siglip_image_tokens)
from .utils import (AutoWeightsLoader, flatten_bn, init_vllm_registered_model,
......@@ -115,62 +114,7 @@ def get_max_llava_image_tokens(ctx: InputContext):
raise ValueError(f"Unexpected select feature strategy: {strategy}")
def dummy_mm_kwargs_for_llava(ctx: InputProcessingContext,
mm_counts: Mapping[str, int]):
hf_config = ctx.get_hf_config(LlavaConfig)
vision_config = hf_config.vision_config
num_images = mm_counts["image"]
if isinstance(vision_config, CLIPVisionConfig):
data = dummy_image_for_clip(vision_config, num_images)
elif isinstance(vision_config, SiglipVisionConfig):
data = dummy_image_for_siglip(vision_config, num_images)
elif isinstance(vision_config, PixtralVisionConfig):
data = dummy_image_for_pixtral_hf(vision_config, num_images)
else:
msg = f"Unsupported vision config: {type(vision_config)}"
raise NotImplementedError(msg)
hf_processor = ctx.get_hf_processor()
image_processor = hf_processor.image_processor # type: ignore
hf_inputs = image_processor.preprocess(data['image'], return_tensors="pt")
is_pixtral = isinstance(hf_processor, PixtralProcessor)
return MultiModalKwargs(
**hf_inputs,
is_pixtral=torch.tensor(is_pixtral),
)
def create_metadata_for_llava(
ctx: InputProcessingContext) -> MultiModalProcessingMetadata:
hf_config = ctx.get_hf_config(LlavaConfig)
image_token_id = hf_config.image_token_index
def get_repl_count(
mm_items: list[Image],
hf_inputs: BatchFeature,
item_idx: int,
) -> int:
return get_max_llava_image_tokens(ctx)
return {
"image":
ModalityProcessingMetadata(prompt_repls=[
PromptReplacement(target=[image_token_id],
repl_unit=[image_token_id],
repl_count=get_repl_count),
]),
}
class LlavaProcessor(BaseMultiModalProcessor):
def __init__(self, ctx: InputProcessingContext) -> None:
super().__init__(
ctx=ctx,
metadata=create_metadata_for_llava(ctx),
)
class LlavaMultiModalProcessor(BaseMultiModalProcessor):
def _patch_pixtral_processor(self, hf_processor: PixtralProcessor):
if getattr(hf_processor, "__is_patched__", False):
......@@ -188,18 +132,72 @@ class LlavaProcessor(BaseMultiModalProcessor):
hf_processor.__is_patched__ = True # type: ignore
def _get_hf_processor(self) -> ProcessorMixin:
def _get_hf_processor(self) -> Union[LlavaProcessor, PixtralProcessor]:
hf_processor = self.ctx.get_hf_processor()
assert isinstance(hf_processor, (LlavaProcessor, PixtralProcessor))
if isinstance(hf_processor, PixtralProcessor):
self._patch_pixtral_processor(hf_processor)
return hf_processor
def _get_dummy_mm_kwargs(
def _get_prompt_replacements(
    self,
    mm_items: MultiModalDataItems,
    hf_inputs: BatchFeature,
    mm_processor_kwargs: Mapping[str, object],
) -> list[PromptReplacement]:
    """
    Build the prompt replacements for the ``"image"`` modality.

    Each replacement targets the single image token ID read from the
    HuggingFace config (``image_token_index``).

    - With a ``PixtralProcessor``, the replacement is a callable that is
      given an item index and builds a string from that item's image size.
    - Otherwise, the replacement is the image token ID repeated
      ``get_max_llava_image_tokens(self.ctx)`` times.

    ``hf_inputs`` and ``mm_processor_kwargs`` are accepted but not read
    in this method.
    """
    llava_config = self.ctx.get_hf_config(LlavaConfig)
    image_token_id = llava_config.image_token_index

    hf_processor = self._get_hf_processor()

    if not isinstance(hf_processor, PixtralProcessor):
        repeated_ids = [image_token_id] * get_max_llava_image_tokens(
            self.ctx)
        return [
            PromptReplacement(
                modality="image",
                target=[image_token_id],
                replacement=repeated_ids,
            )
        ]

    image_token = hf_processor.image_token
    image_break_token = hf_processor.image_break_token
    image_end_token = hf_processor.image_end_token

    vision_config = llava_config.vision_config
    assert isinstance(vision_config, PixtralVisionConfig)

    def get_replacement_pixtral(item_idx: int):
        size = mm_items.get_image_size(item_idx)
        feature_size = get_pixtral_hf_image_feature_size(
            vision_config,
            image_width=size.width,
            image_height=size.height,
        )
        num_width_tokens, num_height_tokens = feature_size

        # One row: the image tokens followed by a single break token.
        row = [image_token] * num_width_tokens + [image_break_token]
        tokens = row * num_height_tokens
        # The trailing break token is swapped for the end token.
        tokens[-1] = image_end_token

        return "".join(tokens)

    return [
        PromptReplacement(
            modality="image",
            target=[image_token_id],
            replacement=get_replacement_pixtral,
        ),
    ]
def _get_dummy_mm_inputs(
self,
mm_counts: Mapping[str, int],
) -> MultiModalKwargs:
) -> ProcessorInputs:
hf_config = self.ctx.get_hf_config(LlavaConfig)
vision_config = hf_config.vision_config
num_images = mm_counts["image"]
......@@ -215,11 +213,13 @@ class LlavaProcessor(BaseMultiModalProcessor):
raise NotImplementedError(msg)
hf_processor = self._get_hf_processor()
image_processor = hf_processor.image_processor # type: ignore
hf_inputs = image_processor.preprocess(data['image'],
return_tensors="pt")
image_token = hf_processor.image_token
return MultiModalKwargs(**hf_inputs)
return ProcessorInputs(
prompt_text=image_token * num_images,
mm_data=data,
mm_processor_kwargs={},
)
class LlavaLikeConfig(Protocol):
......@@ -303,7 +303,7 @@ def init_vision_tower_for_llava(
@MULTIMODAL_REGISTRY.register_max_image_tokens(get_max_llava_image_tokens)
@MULTIMODAL_REGISTRY.register_processor(LlavaProcessor)
@MULTIMODAL_REGISTRY.register_processor(LlavaMultiModalProcessor)
class LlavaForConditionalGeneration(nn.Module, SupportsMultiModal, SupportsPP):
# BitandBytes specific attributes
bitsandbytes_stacked_params_mapping = {
......@@ -584,7 +584,7 @@ class LlavaForConditionalGeneration(nn.Module, SupportsMultiModal, SupportsPP):
return loader.load_weights(weights)
class MantisProcessor(LlavaProcessor):
class MantisMultiModalProcessor(LlavaMultiModalProcessor):
def _get_hf_processor(self) -> ProcessorMixin:
try:
......@@ -604,6 +604,6 @@ class MantisProcessor(LlavaProcessor):
# To use this model, please use
# `--hf_overrides '{"architectures": ["MantisForConditionalGeneration"]}'`
@MULTIMODAL_REGISTRY.register_max_image_tokens(get_max_llava_image_tokens)
@MULTIMODAL_REGISTRY.register_processor(MantisProcessor)
@MULTIMODAL_REGISTRY.register_processor(MantisMultiModalProcessor)
class MantisForConditionalGeneration(LlavaForConditionalGeneration):
pass
......@@ -32,13 +32,10 @@ from vllm.model_executor.layers.vocab_parallel_embedding import (
from vllm.model_executor.models.clip import CLIPVisionModel
from vllm.model_executor.sampling_metadata import SamplingMetadata
from vllm.multimodal import MULTIMODAL_REGISTRY
from vllm.multimodal.image import cached_get_image_processor
from vllm.multimodal.inputs import MultiModalKwargs, NestedTensors
from vllm.multimodal.inputs import NestedTensors
from vllm.multimodal.processing import (BaseMultiModalProcessor,
InputProcessingContext,
ModalityProcessingMetadata,
MultiModalDataDict,
MultiModalProcessingMetadata,
MultiModalDataItems, ProcessorInputs,
PromptReplacement)
from vllm.sequence import IntermediateTensors
from vllm.utils import is_list_of
......@@ -305,64 +302,17 @@ class Phi3HDImageEmbedding(Phi3ImageEmbeddingBase):
return image_features_hd_newline
def get_max_phi3v_image_tokens(ctx: InputContext,
*,
num_crops: Optional[int] = None):
mm_processor_kwargs = {}
if num_crops is not None:
mm_processor_kwargs["num_crops"] = num_crops
model_config = ctx.model_config
image_processor = cached_get_image_processor(
model_config.model,
trust_remote_code=model_config.trust_remote_code,
**mm_processor_kwargs,
)
def get_max_phi3v_image_tokens(ctx: InputContext) -> int:
processor = ctx.get_hf_processor()
image_processor = processor.image_processor # type: ignore
num_tokens = image_processor.calc_num_image_tokens_from_image_size(
return image_processor.calc_num_image_tokens_from_image_size(
width=MAX_IMAGE_FEATURE_SIZE_WIDTH,
height=MAX_IMAGE_FEATURE_SIZE_HEIGHT,
)
return num_tokens
def dummy_mm_kwargs_for_phi3v(ctx: InputProcessingContext,
mm_counts: Mapping[str, int]):
num_images = mm_counts["image"]
data = dummy_image_for_clip(
CLIP_VIT_LARGE_PATCH14_336_CONFIG,
num_images,
image_width_override=MAX_IMAGE_FEATURE_SIZE_WIDTH,
image_height_override=MAX_IMAGE_FEATURE_SIZE_HEIGHT,
)
hf_processor = ctx.get_hf_processor()
image_processor = hf_processor.image_processor # type: ignore
hf_inputs = image_processor.preprocess(data['image'], return_tensors="pt")
return MultiModalKwargs(**hf_inputs)
def create_metadata_for_phi3v(
ctx: InputProcessingContext) -> MultiModalProcessingMetadata:
return {
"image":
ModalityProcessingMetadata(prompt_repls=[
PromptReplacement(target=[_IMAGE_TOKEN_ID],
repl_unit=[_IMAGE_TOKEN_ID],
repl_count=get_max_phi3v_image_tokens(ctx)),
]),
}
class Phi3VProcessor(BaseMultiModalProcessor):
def __init__(self, ctx: InputProcessingContext) -> None:
super().__init__(
ctx=ctx,
metadata=create_metadata_for_phi3v(ctx),
)
class Phi3VMultiModalProcessor(BaseMultiModalProcessor):
def _get_hf_processor(
self,
......@@ -389,15 +339,61 @@ class Phi3VProcessor(BaseMultiModalProcessor):
processed_outputs['input_ids'] = token_ids
return processed_outputs
def _get_dummy_mm_kwargs(
def _get_prompt_replacements(
    self,
    mm_items: MultiModalDataItems,
    hf_inputs: BatchFeature,
    mm_processor_kwargs: Mapping[str, object],
) -> list[PromptReplacement]:
    """
    Build the prompt replacements for the ``"image"`` modality.

    One replacement is created for each of the first ``max_images``
    entries of the HuggingFace processor's ``img_tokens``, where
    ``max_images`` is the ``"image"`` entry of the multimodal config's
    ``limit_per_prompt`` (1 when absent). Every replacement expands to
    ``_IMAGE_TOKEN_ID`` repeated as many times as the image processor
    computes from the matched item's image size.

    ``hf_inputs`` and ``mm_processor_kwargs`` are accepted but not read
    in this method.
    """
    hf_processor = self._get_hf_processor()
    image_tokens: list[str] = hf_processor.img_tokens  # type: ignore
    image_processor = hf_processor.image_processor  # type: ignore

    max_images = self.ctx.get_mm_config().limit_per_prompt.get("image", 1)

    def get_replacement_phi3v(item_idx: int):
        size = mm_items.get_image_size(item_idx)
        token_count = image_processor.calc_num_image_tokens_from_image_size(
            width=size.width,
            height=size.height,
        )
        return [_IMAGE_TOKEN_ID] * token_count

    def make_replacement(image_token: str) -> PromptReplacement:
        # All placeholder strings share the same size-dependent expansion.
        return PromptReplacement(
            modality="image",
            target=image_token,
            replacement=get_replacement_phi3v,
        )

    return [make_replacement(token) for token in image_tokens[:max_images]]
def _get_dummy_mm_inputs(
self,
mm_counts: Mapping[str, int],
) -> MultiModalKwargs:
return dummy_mm_kwargs_for_phi3v(self.ctx, mm_counts)
) -> ProcessorInputs:
num_images = mm_counts["image"]
data = dummy_image_for_clip(
CLIP_VIT_LARGE_PATCH14_336_CONFIG,
num_images,
image_width_override=MAX_IMAGE_FEATURE_SIZE_WIDTH,
image_height_override=MAX_IMAGE_FEATURE_SIZE_HEIGHT,
)
hf_processor = self._get_hf_processor()
image_tokens: list[str] = hf_processor.img_tokens # type: ignore
return ProcessorInputs(
prompt_text="".join(image_tokens[:num_images]),
mm_data=data,
mm_processor_kwargs={},
)
@MULTIMODAL_REGISTRY.register_max_image_tokens(get_max_phi3v_image_tokens)
@MULTIMODAL_REGISTRY.register_processor(Phi3VProcessor)
@MULTIMODAL_REGISTRY.register_processor(Phi3VMultiModalProcessor)
class Phi3VForCausalLM(nn.Module, SupportsMultiModal, SupportsPP):
def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
......
......@@ -72,7 +72,7 @@ def dummy_data_for_pixtral(ctx: InputContext, seq_len: int,
mm_encoder = tokenizer.mistral.instruct_tokenizer.mm_encoder
image_token_id = mm_encoder.special_ids.img
mm_config = ctx.model_config.multimodal_config
mm_config = ctx.get_mm_config()
num_images = mm_config.limit_per_prompt.get("image", 1)
# dummy size
......
......@@ -99,7 +99,7 @@ class MultiModalPlugin(ABC):
"""
def wrapper(model_cls: N) -> N:
if model_cls in self._input_mappers:
if self._input_mappers.contains(model_cls, strict=True):
logger.warning(
"Model class %s already has an input mapper "
"registered to %s. It is overwritten by the new one.",
......@@ -194,7 +194,7 @@ class MultiModalPlugin(ABC):
"""
def wrapper(model_cls: N) -> N:
if model_cls in self._max_mm_tokens:
if self._max_mm_tokens.contains(model_cls, strict=True):
logger.warning(
"Model class %s already calculates maximum number of "
"tokens in %s. It is overwritten by the new one.",
......
This diff is collapsed.
......@@ -299,9 +299,9 @@ class MultiModalRegistry:
"""
def wrapper(model_cls: N) -> N:
if model_cls in self._processor_factories:
if self._processor_factories.contains(model_cls, strict=True):
logger.warning(
"Model class %s already has an input mapper "
"Model class %s already has a multi-modal processor "
"registered to %s. It is overwritten by the new one.",
model_cls, self)
......
......@@ -1370,8 +1370,8 @@ def supports_kw(
def resolve_mm_processor_kwargs(
init_kwargs: Optional[Dict[str, Any]],
inference_kwargs: Optional[Dict[str, Any]],
init_kwargs: Optional[Mapping[str, object]],
inference_kwargs: Optional[Mapping[str, object]],
callable: Callable[..., object],
allow_var_kwargs: bool = False,
) -> Dict[str, Any]:
......@@ -1405,7 +1405,7 @@ def resolve_mm_processor_kwargs(
def get_allowed_kwarg_only_overrides(
callable: Callable[..., object],
overrides: Optional[Dict[str, Any]],
overrides: Optional[Mapping[str, object]],
allow_var_kwargs: bool = False,
) -> Dict[str, Any]:
"""
......@@ -1524,9 +1524,15 @@ class ClassRegistry(UserDict[Type[T], _V]):
raise KeyError(key)
def __contains__(self, key: object) -> bool:
    """Support ``key in registry`` using the non-strict lookup."""
    return self.contains(key)


def contains(self, key: object, *, strict: bool = False) -> bool:
    """
    Check whether a class is registered in ``self.data``.

    Args:
        key: The candidate key. Anything that is not a class is never
            contained.
        strict: If ``True``, only ``key`` itself is looked up. If
            ``False``, ``key`` also counts as contained when any class
            in its method resolution order is a key of ``self.data``.

    Returns:
        ``True`` if a matching class is found, ``False`` otherwise.
    """
    if not isinstance(key, type):
        return False

    if strict:
        return key in self.data

    # Read `__mro__` rather than calling `key.mro()`: when `key` is a
    # metaclass (e.g. `type` itself), `key.mro` is an unbound method and
    # calling it without an argument raises TypeError.
    return any(cls in self.data for cls in key.__mro__)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment