Unverified Commit 2a0596bc authored by Cyrus Leung's avatar Cyrus Leung Committed by GitHub
Browse files

[VLM] Reorganize profiling/processing-related code (#11812)


Signed-off-by: default avatarDarkLight1337 <tlleungac@connect.ust.hk>
parent f1214117
......@@ -4,24 +4,17 @@ from functools import partial
import pytest
from PIL import Image
from pqdm.threads import pqdm
from transformers import AutoTokenizer
from vllm.inputs import InputProcessingContext
from vllm.multimodal import MULTIMODAL_REGISTRY
from vllm.multimodal.parse import ImageSize
from vllm.multimodal.processing import BaseMultiModalProcessor
from vllm.multimodal.utils import cached_get_tokenizer
from ....utils import build_model_context
# Fixtures lazy import to avoid initializing CUDA during test collection
@pytest.fixture()
def processor_for_llava_next():
from vllm.model_executor.models.llava_next import (
LlavaNextMultiModalProcessor)
return LlavaNextMultiModalProcessor
def _validate_image_prompt_replacements_one(
processor,
processor: BaseMultiModalProcessor,
num_imgs: int,
failed_size_excs: list[tuple[ImageSize, Exception]],
image_size: ImageSize,
......@@ -78,20 +71,17 @@ def _test_image_prompt_replacements(
@pytest.mark.parametrize("model_id", ["llava-hf/llava-v1.6-mistral-7b-hf"])
@pytest.mark.parametrize("num_imgs", [1, 2])
def test_processor_prompt_replacements_regression(
processor_for_llava_next,
model_id: str,
num_imgs: int,
):
def test_processor_prompt_replacements_regression(model_id, num_imgs):
ctx = build_model_context(
model_name=model_id,
tokenizer_name=model_id,
mm_processor_kwargs=None,
limit_mm_per_prompt={"image": num_imgs},
)
tokenizer = AutoTokenizer.from_pretrained(model_id, trust_remote_code=True)
ctx = InputProcessingContext(ctx.model_config, tokenizer)
processor = processor_for_llava_next(ctx)
processor = MULTIMODAL_REGISTRY.create_processor(
ctx.model_config,
tokenizer=cached_get_tokenizer(ctx.model_config.tokenizer),
)
image_ratios = [(171, 152), (184, 161), (198, 176), (333, 296), (369, 328),
(488, 183), (2560, 1669)]
......@@ -111,20 +101,17 @@ def test_processor_prompt_replacements_regression(
"Comment this out to run it manually.")
@pytest.mark.parametrize("model_id", ["llava-hf/llava-v1.6-mistral-7b-hf"])
@pytest.mark.parametrize("num_imgs", [1])
def test_processor_prompt_replacements_all(
processor_for_llava_next,
model_id: str,
num_imgs: int,
):
def test_processor_prompt_replacements_all(model_id, num_imgs):
ctx = build_model_context(
model_name=model_id,
tokenizer_name=model_id,
mm_processor_kwargs=None,
limit_mm_per_prompt={"image": num_imgs},
)
tokenizer = AutoTokenizer.from_pretrained(model_id, trust_remote_code=True)
ctx = InputProcessingContext(ctx.model_config, tokenizer)
processor = processor_for_llava_next(ctx)
processor = MULTIMODAL_REGISTRY.create_processor(
ctx.model_config,
tokenizer=cached_get_tokenizer(ctx.model_config.tokenizer),
)
seen_aspect_ratios = set[float]()
image_sizes = list[ImageSize]()
......
......@@ -4,24 +4,17 @@ from functools import partial
import pytest
from PIL import Image
from pqdm.threads import pqdm
from transformers import AutoTokenizer
from vllm.inputs import InputProcessingContext
from vllm.multimodal import MULTIMODAL_REGISTRY
from vllm.multimodal.parse import ImageSize
from vllm.multimodal.processing import BaseMultiModalProcessor
from vllm.multimodal.utils import cached_get_tokenizer
from ....utils import build_model_context
# Fixtures lazy import to avoid initializing CUDA during test collection
@pytest.fixture()
def processor_for_llava_onevision():
from vllm.model_executor.models.llava_onevision import (
LlavaOnevisionMultiModalProcessor)
return LlavaOnevisionMultiModalProcessor
def _validate_image_prompt_replacements_one(
processor,
processor: BaseMultiModalProcessor,
num_imgs: int,
failed_size_excs: list[tuple[ImageSize, Exception]],
image_size: ImageSize,
......@@ -77,20 +70,17 @@ def _test_image_prompt_replacements(
@pytest.mark.parametrize("model_id",
["llava-hf/llava-onevision-qwen2-0.5b-ov-hf"])
@pytest.mark.parametrize("num_imgs", [1, 2])
def test_processor_prompt_replacements_regression(
processor_for_llava_onevision,
model_id: str,
num_imgs: int,
):
def test_processor_prompt_replacements_regression(model_id, num_imgs):
ctx = build_model_context(
model_name=model_id,
tokenizer_name=model_id,
mm_processor_kwargs=None,
limit_mm_per_prompt={"image": num_imgs},
)
tokenizer = AutoTokenizer.from_pretrained(model_id, trust_remote_code=True)
ctx = InputProcessingContext(ctx.model_config, tokenizer)
processor = processor_for_llava_onevision(ctx)
processor = MULTIMODAL_REGISTRY.create_processor(
ctx.model_config,
tokenizer=cached_get_tokenizer(ctx.model_config.tokenizer),
)
image_ratios = [(171, 152), (184, 161), (198, 176), (333, 296), (369, 328),
(488, 183), (2560, 1669)]
......@@ -111,20 +101,17 @@ def test_processor_prompt_replacements_regression(
@pytest.mark.parametrize("model_id",
["llava-hf/llava-onevision-qwen2-0.5b-ov-hf"])
@pytest.mark.parametrize("num_imgs", [1])
def test_processor_prompt_replacements_all(
processor_for_llava_onevision,
model_id: str,
num_imgs: int,
):
def test_processor_prompt_replacements_all(model_id, num_imgs):
ctx = build_model_context(
model_name=model_id,
tokenizer_name=model_id,
mm_processor_kwargs=None,
limit_mm_per_prompt={"image": num_imgs},
)
tokenizer = AutoTokenizer.from_pretrained(model_id, trust_remote_code=True)
ctx = InputProcessingContext(ctx.model_config, tokenizer)
processor = processor_for_llava_onevision(ctx)
processor = MULTIMODAL_REGISTRY.create_processor(
ctx.model_config,
tokenizer=cached_get_tokenizer(ctx.model_config.tokenizer),
)
seen_aspect_ratios = set[float]()
image_sizes = list[ImageSize]()
......
"""Tests for phi3v's multimodal preprocessing kwargs."""
import pytest
from transformers import AutoTokenizer
from vllm.inputs import InputProcessingContext
from vllm.model_executor.models.phi3v import _IMAGE_TOKEN_ID
from vllm.multimodal import MULTIMODAL_REGISTRY
from vllm.multimodal.utils import cached_get_tokenizer
from .....conftest import _ImageAssets
from ....utils import build_model_context
# Wrap lazy imports to avoid initializing CUDA during test collection
@pytest.fixture()
def processor_for_phi3v():
from vllm.model_executor.models.phi3v import Phi3VMultiModalProcessor
return Phi3VMultiModalProcessor
@pytest.mark.parametrize("model_id", ["microsoft/Phi-3.5-vision-instruct"])
# yapf: disable
@pytest.mark.parametrize(
......@@ -29,7 +21,6 @@ def processor_for_phi3v():
# yapf: enable
@pytest.mark.parametrize("num_imgs", [1, 2])
def test_processor_override(
processor_for_phi3v,
image_assets: _ImageAssets,
model_id: str,
mm_processor_kwargs: dict[str, int],
......@@ -37,21 +28,26 @@ def test_processor_override(
num_imgs: int,
):
"""Ensure input_processor_for_phi3v handles num_crops properly."""
# Avoid initializing CUDA early
from vllm.model_executor.models.phi3v import _IMAGE_TOKEN_ID
ctx = build_model_context(
model_name=model_id,
tokenizer_name=model_id,
trust_remote_code=True,
limit_mm_per_prompt={"image": num_imgs},
)
tokenizer = AutoTokenizer.from_pretrained(model_id, trust_remote_code=True)
ctx = InputProcessingContext(ctx.model_config, tokenizer)
tokenizer = cached_get_tokenizer(ctx.model_config.tokenizer)
processor = MULTIMODAL_REGISTRY.create_processor(
ctx.model_config,
tokenizer=tokenizer,
)
# Build the image str / prompt based on the number of images we pass
img_str = "".join([f"<|image_{idx}|>\n" for idx in range(1, num_imgs + 1)])
prompt = f"<|user|>\n{img_str}<|end|>\n<|assistant|>\n"
mm_data = {"image": [image_assets[0].pil_image] * num_imgs}
processor = processor_for_phi3v(ctx)
processed_inputs = processor.apply(prompt, mm_data, mm_processor_kwargs)
# Ensure we have the right number of placeholders per num_crops size
......
import pytest
from transformers import AutoTokenizer
from vllm.inputs import InputProcessingContext
from vllm.multimodal import MULTIMODAL_REGISTRY
from vllm.multimodal.utils import cached_get_tokenizer
from .....conftest import _ImageAssets
from ....utils import build_model_context
# Fixtures lazy import to avoid initializing CUDA during test collection
@pytest.fixture()
def processor_for_qwen2_vl():
from vllm.model_executor.models.qwen2_vl import Qwen2VLMultiModalProcessor
return Qwen2VLMultiModalProcessor
@pytest.mark.parametrize("model_id", ["Qwen/Qwen2-VL-2B-Instruct"])
# yapf: disable
@pytest.mark.parametrize(
......@@ -24,7 +17,6 @@ def processor_for_qwen2_vl():
# yapf: enable
@pytest.mark.parametrize("num_imgs", [1, 2])
def test_processor_override(
processor_for_qwen2_vl,
image_assets: _ImageAssets,
model_id: str,
mm_processor_kwargs: dict[str, object],
......@@ -39,18 +31,20 @@ def test_processor_override(
mm_processor_kwargs=None,
limit_mm_per_prompt={"image": num_imgs},
)
tokenizer = AutoTokenizer.from_pretrained(model_id, trust_remote_code=True)
ctx = InputProcessingContext(ctx.model_config, tokenizer)
tokenizer = cached_get_tokenizer(ctx.model_config.tokenizer)
processor = MULTIMODAL_REGISTRY.create_processor(
ctx.model_config,
tokenizer=tokenizer,
)
# Build the image str / prompt based on the number of images we pass
prompt = "<|vision_start|><|image_pad|><|vision_end|>" * num_imgs
mm_data = {"image": [image_assets[0].pil_image] * num_imgs}
processor = processor_for_qwen2_vl(ctx)
processed_inputs = processor.apply(prompt, mm_data, mm_processor_kwargs)
# Ensure we have the right number of placeholders per num_crops size
hf_processor = processor._get_hf_processor(**mm_processor_kwargs)
hf_processor = processor.info.get_hf_processor(**mm_processor_kwargs)
image_token_id = tokenizer.convert_tokens_to_ids(hf_processor.image_token)
img_tok_count = processed_inputs["prompt_token_ids"].count(image_token_id)
pixel_shape = processed_inputs["mm_kwargs"]["pixel_values"].shape
......
......@@ -10,12 +10,17 @@ from PIL import Image
from vllm.config import ModelConfig
from vllm.inputs import InputProcessingContext
from vllm.multimodal import MULTIMODAL_REGISTRY
from vllm.multimodal.processing import (ProcessingCache, PromptReplacement,
_PlaceholderInfo, find_mm_placeholders,
# yapf conflicts with isort for this block
# yapf: disable
from vllm.multimodal.processing import (PlaceholderInfo, ProcessingCache,
PromptReplacement,
find_mm_placeholders,
find_text_matches, find_token_matches,
iter_token_matches,
replace_text_matches,
replace_token_matches)
# yapf: enable
from vllm.multimodal.profiling import MultiModalProfiler
from vllm.multimodal.utils import cached_get_tokenizer
from vllm.transformers_utils.tokenizer import AnyTokenizer
from vllm.utils import full_groupby
......@@ -431,7 +436,7 @@ def test_find_replace_tokens(
[1, 9833, 28747, 32000, 9833, 28747, 32000, 32000, 918],
{
"pattern_1": [
_PlaceholderInfo(
PlaceholderInfo(
modality="pattern_1",
item_idx=0,
start_idx=6,
......@@ -445,13 +450,13 @@ def test_find_replace_tokens(
[1, 32000, 32000, 9833, 28747, 32000, 32000, 1550, 918, 1550],
{
"pattern_1": [
_PlaceholderInfo(
PlaceholderInfo(
modality="pattern_1",
item_idx=0,
start_idx=1,
replacement=[32000, 32000],
),
_PlaceholderInfo(
PlaceholderInfo(
modality="pattern_1",
item_idx=1,
start_idx=5,
......@@ -459,7 +464,7 @@ def test_find_replace_tokens(
),
],
"pattern_3": [
_PlaceholderInfo(
PlaceholderInfo(
modality="pattern_3",
item_idx=0,
start_idx=7,
......@@ -472,13 +477,13 @@ def test_find_replace_tokens(
[1, 32000, 32000, 32000, 32000, 32000, 1550, 918, 1550],
{
"pattern_1": [
_PlaceholderInfo(
PlaceholderInfo(
modality="pattern_1",
item_idx=0,
start_idx=1,
replacement=[32000, 32000],
),
_PlaceholderInfo(
PlaceholderInfo(
modality="pattern_1",
item_idx=1,
start_idx=3,
......@@ -486,7 +491,7 @@ def test_find_replace_tokens(
),
],
"pattern_3": [
_PlaceholderInfo(
PlaceholderInfo(
modality="pattern_3",
item_idx=0,
start_idx=6,
......@@ -577,19 +582,15 @@ def test_limit_mm_per_prompt_dummy(model_id, limit, num_supported, is_valid):
revision=None,
limit_mm_per_prompt=limit_mm_per_prompt,
)
model_cls = MULTIMODAL_REGISTRY._get_model_cls(model_config)
processor_factory = MULTIMODAL_REGISTRY._processor_factories[model_cls]
ctx = InputProcessingContext(
processor = MULTIMODAL_REGISTRY.create_processor(
model_config,
tokenizer=cached_get_tokenizer(model_config.tokenizer),
)
processor = processor_factory(ctx, cache=None)
profiler = processor.profiling_info
profiler = MultiModalProfiler(processor)
mock_supported_mm_limits = MagicMock(return_value={"image": num_supported})
profiler.get_supported_mm_limits = mock_supported_mm_limits
processor.info.get_supported_mm_limits = mock_supported_mm_limits
if is_valid:
exc_ctx = nullcontext()
......@@ -597,7 +598,7 @@ def test_limit_mm_per_prompt_dummy(model_id, limit, num_supported, is_valid):
exc_ctx = pytest.raises(ValueError, match="this model only supports")
with exc_ctx:
profiler.get_mm_limits()
profiler.get_dummy_data(model_config.max_model_len)
@pytest.mark.parametrize("model_id", ["llava-hf/llava-v1.6-mistral-7b-hf"])
......@@ -620,16 +621,12 @@ def test_limit_mm_per_prompt_apply(model_id, num_images, limit, is_valid):
revision=None,
limit_mm_per_prompt=limit_mm_per_prompt,
)
model_cls = MULTIMODAL_REGISTRY._get_model_cls(model_config)
processor_factory = MULTIMODAL_REGISTRY._processor_factories[model_cls]
ctx = InputProcessingContext(
processor = MULTIMODAL_REGISTRY.create_processor(
model_config,
tokenizer=cached_get_tokenizer(model_config.tokenizer),
)
processor = processor_factory(ctx, cache=None)
rng = np.random.RandomState(0)
image = _rand_img(rng, min_wh=128, max_wh=256)
if num_images == 0:
......@@ -681,9 +678,9 @@ def _test_processing_cache_correctness(
hf_overrides=hf_overrides,
limit_mm_per_prompt=limit_mm_per_prompt,
)
model_cls = MULTIMODAL_REGISTRY._get_model_cls(model_config)
processor_factory = MULTIMODAL_REGISTRY._processor_factories[model_cls]
model_cls = MULTIMODAL_REGISTRY._get_model_cls(model_config)
factories = MULTIMODAL_REGISTRY._processor_factories[model_cls]
ctx = InputProcessingContext(
model_config,
tokenizer=cached_get_tokenizer(model_config.tokenizer),
......@@ -691,8 +688,9 @@ def _test_processing_cache_correctness(
# Ensure that it can fit all of the data
cache = ProcessingCache(capacity=1 << 30)
baseline_processor = processor_factory(ctx, cache=None)
cached_processor = processor_factory(ctx, cache=cache)
baseline_processor = factories.build_processor(ctx, cache=None)
cached_processor = factories.build_processor(ctx, cache=cache)
dummy_inputs = baseline_processor.dummy_inputs
rng = np.random.RandomState(0)
......@@ -724,7 +722,7 @@ def _test_processing_cache_correctness(
}
mm_counts = {k: len(vs) for k, vs in mm_data.items()}
prompt = baseline_processor.profiling_info.get_dummy_processor_inputs(
prompt = dummy_inputs.get_dummy_processor_inputs(
model_config.max_model_len,
mm_counts,
).prompt_text
......
......@@ -2,13 +2,17 @@ from typing import Optional
import torch
from vllm.model_executor.models.llava import (LlavaForConditionalGeneration,
LlavaMultiModalProcessor)
from vllm.model_executor.models.llava import (LlavaDummyInputsBuilder,
LlavaForConditionalGeneration,
LlavaMultiModalProcessor,
LlavaProcessingInfo)
from vllm.model_executor.sampling_metadata import SamplingMetadata
from vllm.multimodal import MULTIMODAL_REGISTRY
@MULTIMODAL_REGISTRY.register_processor(LlavaMultiModalProcessor)
@MULTIMODAL_REGISTRY.register_processor(LlavaMultiModalProcessor,
info=LlavaProcessingInfo,
dummy_inputs=LlavaDummyInputsBuilder)
class MyLlava(LlavaForConditionalGeneration):
def compute_logits(
......
......@@ -7,7 +7,7 @@ from vllm.config import ModelConfig
from vllm.logger import init_logger
from vllm.lora.request import LoRARequest
from vllm.multimodal import MULTIMODAL_REGISTRY, MultiModalRegistry
from vllm.multimodal.processing import MultiModalDataDict, MultiModalInputsV2
from vllm.multimodal.inputs import MultiModalDataDict, MultiModalInputsV2
from vllm.prompt_adapter.request import PromptAdapterRequest
from vllm.transformers_utils.tokenizer_group import BaseTokenizerGroup
from vllm.utils import print_info_once, print_warning_once
......
......@@ -323,6 +323,7 @@ class InputRegistry:
# Avoid circular import
from vllm.model_executor.model_loader import get_model_architecture
from vllm.multimodal import MultiModalKwargs
from vllm.multimodal.profiling import MultiModalProfiler
from vllm.multimodal.utils import cached_get_tokenizer
if mm_registry.has_processor(model_config):
......@@ -331,7 +332,8 @@ class InputRegistry:
trust_remote_code=model_config.trust_remote_code,
)
processor = mm_registry.create_processor(model_config, tokenizer)
dummy_data = processor.get_dummy_data(seq_len)
profiler = MultiModalProfiler(processor)
dummy_data = profiler.get_dummy_data(seq_len)
else:
model_cls, _ = get_model_architecture(model_config)
if is_encoder_data:
......
......@@ -23,10 +23,10 @@ from vllm.model_executor.model_loader.weight_utils import (
from vllm.multimodal import MULTIMODAL_REGISTRY
from vllm.multimodal.inputs import (MultiModalFieldConfig, MultiModalKwargs,
NestedTensors)
from vllm.multimodal.parse import MultiModalDataItems
from vllm.multimodal.processing import (BaseMultiModalProcessor,
MultiModalDataItems, ProcessingMixin,
PromptReplacement)
from vllm.multimodal.profiling import BaseProfilingInfo, ProcessorInputs
BaseProcessingInfo, PromptReplacement)
from vllm.multimodal.profiling import BaseDummyInputsBuilder, ProcessorInputs
from vllm.sequence import IntermediateTensors
from vllm.transformers_utils.configs.aria import (AriaMoELMConfig,
AriaVisionConfig)
......@@ -445,33 +445,33 @@ def build_mm_projector(config: PretrainedConfig):
)
class AriaProcessingMixin(ProcessingMixin):
class AriaProcessingInfo(BaseProcessingInfo):
def _get_hf_config(self):
def get_hf_config(self):
return self.ctx.get_hf_config()
def _get_vision_config(self) -> AriaVisionConfig:
return self._get_hf_config().vision_config
def _get_num_image_tokens(self) -> int:
hf_config = self._get_hf_config()
return max(hf_config.projector_patch_to_query_dict.values())
class AriaProfilingInfo(AriaProcessingMixin, BaseProfilingInfo):
def get_vision_config(self) -> AriaVisionConfig:
return self.get_hf_config().vision_config
def get_supported_mm_limits(self) -> Mapping[str, Optional[int]]:
return {"image": None}
def get_mm_max_tokens_per_item(self, seq_len: int) -> Mapping[str, int]:
return {"image": self._get_num_image_tokens()}
return {"image": self.get_num_image_tokens()}
def get_num_image_tokens(self) -> int:
hf_config = self.get_hf_config()
return max(hf_config.projector_patch_to_query_dict.values())
class AriaDummyInputsBuilder(BaseDummyInputsBuilder[AriaProcessingInfo]):
def get_dummy_processor_inputs(
self,
seq_len: int,
mm_counts: Mapping[str, int],
) -> ProcessorInputs:
vision_config = self._get_vision_config()
vision_config = self.info.get_vision_config()
max_image_size = vision_config.image_size
num_images = mm_counts.get("image", 0)
......@@ -483,7 +483,7 @@ class AriaProfilingInfo(AriaProcessingMixin, BaseProfilingInfo):
num_images=num_images)
}
hf_processor = self._get_hf_processor()
hf_processor = self.info.get_hf_processor()
image_token: str = hf_processor.image_token # type: ignore
return ProcessorInputs(
......@@ -492,10 +492,7 @@ class AriaProfilingInfo(AriaProcessingMixin, BaseProfilingInfo):
)
class AriaMultiModalProcessor(AriaProcessingMixin, BaseMultiModalProcessor):
def _get_profiling_info(self) -> BaseProfilingInfo:
return AriaProfilingInfo(self.ctx)
class AriaMultiModalProcessor(BaseMultiModalProcessor[AriaProcessingInfo]):
def _get_mm_fields_config(
self,
......@@ -513,10 +510,10 @@ class AriaMultiModalProcessor(AriaProcessingMixin, BaseMultiModalProcessor):
hf_processor_mm_kwargs: Mapping[str, object],
out_mm_kwargs: MultiModalKwargs,
) -> list[PromptReplacement]:
hf_config = self._get_hf_config()
hf_config = self.info.get_hf_config()
image_token_id = hf_config.image_token_index
num_image_tokens = self._get_num_image_tokens()
num_image_tokens = self.info.get_num_image_tokens()
return [
PromptReplacement(
......@@ -527,7 +524,9 @@ class AriaMultiModalProcessor(AriaProcessingMixin, BaseMultiModalProcessor):
]
@MULTIMODAL_REGISTRY.register_processor(AriaMultiModalProcessor)
@MULTIMODAL_REGISTRY.register_processor(AriaMultiModalProcessor,
info=AriaProcessingInfo,
dummy_inputs=AriaDummyInputsBuilder)
class AriaForConditionalGeneration(nn.Module, SupportsMultiModal):
"""
Aria model for conditional generation tasks.
......
......@@ -17,10 +17,10 @@ from vllm.multimodal import MULTIMODAL_REGISTRY
from vllm.multimodal.inputs import (MultiModalDataDict, MultiModalFieldConfig,
MultiModalInputsV2, MultiModalKwargs,
NestedTensors, PlaceholderRange)
from vllm.multimodal.parse import MultiModalDataItems
from vllm.multimodal.processing import (BaseMultiModalProcessor,
MultiModalDataItems, ProcessingMixin,
PromptReplacement)
from vllm.multimodal.profiling import BaseProfilingInfo, ProcessorInputs
BaseProcessingInfo, PromptReplacement)
from vllm.multimodal.profiling import BaseDummyInputsBuilder, ProcessorInputs
from vllm.sequence import IntermediateTensors
from .blip import BlipVisionModel
......@@ -397,30 +397,30 @@ class Blip2QFormerModel(nn.Module):
return sequence_output
class Blip2ProcessingMixin(ProcessingMixin):
class Blip2ProcessingInfo(BaseProcessingInfo):
def _get_hf_config(self):
def get_hf_config(self):
return self.ctx.get_hf_config(Blip2Config)
def _get_num_image_tokens(self) -> int:
hf_config = self._get_hf_config()
return hf_config.num_query_tokens
class Blip2ProfilingInfo(Blip2ProcessingMixin, BaseProfilingInfo):
def get_supported_mm_limits(self) -> Mapping[str, Optional[int]]:
return {"image": 1}
def get_mm_max_tokens_per_item(self, seq_len: int) -> Mapping[str, int]:
return {"image": self._get_num_image_tokens()}
return {"image": self.get_num_image_tokens()}
def get_num_image_tokens(self) -> int:
hf_config = self.get_hf_config()
return hf_config.num_query_tokens
class Blip2DummyInputsBuilder(BaseDummyInputsBuilder[Blip2ProcessingInfo]):
def get_dummy_processor_inputs(
self,
seq_len: int,
mm_counts: Mapping[str, int],
) -> ProcessorInputs:
hf_config = self._get_hf_config()
hf_config = self.info.get_hf_config()
vision_config = hf_config.vision_config
max_image_size = vision_config.image_size
......@@ -439,10 +439,7 @@ class Blip2ProfilingInfo(Blip2ProcessingMixin, BaseProfilingInfo):
)
class Blip2MultiModalProcessor(Blip2ProcessingMixin, BaseMultiModalProcessor):
def _get_profiling_info(self) -> BaseProfilingInfo:
return Blip2ProfilingInfo(self.ctx)
class Blip2MultiModalProcessor(BaseMultiModalProcessor[Blip2ProcessingInfo]):
def _get_mm_fields_config(
self,
......@@ -460,7 +457,7 @@ class Blip2MultiModalProcessor(Blip2ProcessingMixin, BaseMultiModalProcessor):
hf_processor_mm_kwargs: Mapping[str, object],
out_mm_kwargs: MultiModalKwargs,
) -> list[PromptReplacement]:
num_image_tokens = self._get_num_image_tokens()
num_image_tokens = self.info.get_num_image_tokens()
return [
PromptReplacement(
......@@ -491,7 +488,9 @@ class Blip2MultiModalProcessor(Blip2ProcessingMixin, BaseMultiModalProcessor):
return result
@MULTIMODAL_REGISTRY.register_processor(Blip2MultiModalProcessor)
@MULTIMODAL_REGISTRY.register_processor(Blip2MultiModalProcessor,
info=Blip2ProcessingInfo,
dummy_inputs=Blip2DummyInputsBuilder)
class Blip2ForConditionalGeneration(nn.Module, SupportsMultiModal, SupportsPP):
def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
......
......@@ -30,10 +30,10 @@ from vllm.multimodal import MULTIMODAL_REGISTRY
from vllm.multimodal.inputs import (MultiModalDataDict, MultiModalFieldConfig,
MultiModalInputsV2, MultiModalKwargs,
NestedTensors, PlaceholderRange)
from vllm.multimodal.parse import MultiModalDataItems
from vllm.multimodal.processing import (BaseMultiModalProcessor,
MultiModalDataItems, ProcessingMixin,
PromptReplacement)
from vllm.multimodal.profiling import BaseProfilingInfo, ProcessorInputs
BaseProcessingInfo, PromptReplacement)
from vllm.multimodal.profiling import BaseDummyInputsBuilder, ProcessorInputs
from vllm.sequence import IntermediateTensors
from vllm.utils import print_warning_once
......@@ -49,33 +49,34 @@ class ChameleonImagePixelInputs(TypedDict):
"""Shape: `(batch_size * num_images, num_channels, height, width)`"""
class ChameleonProcessingMixin(ProcessingMixin):
class ChameleonProcessingInfo(BaseProcessingInfo):
def _get_hf_config(self):
def get_hf_config(self):
return self.ctx.get_hf_config(ChameleonConfig)
def _get_hf_processor(self):
def get_hf_processor(self):
return self.ctx.get_hf_processor(ChameleonProcessor)
def _get_num_image_tokens(self) -> int:
processor = self._get_hf_processor()
return processor.image_seq_length
class ChameleonProfilingInfo(ChameleonProcessingMixin, BaseProfilingInfo):
def get_supported_mm_limits(self) -> Mapping[str, Optional[int]]:
return {"image": 1}
def get_mm_max_tokens_per_item(self, seq_len: int) -> Mapping[str, int]:
return {"image": self._get_num_image_tokens()}
return {"image": self.get_num_image_tokens()}
def get_num_image_tokens(self) -> int:
processor = self.get_hf_processor()
return processor.image_seq_length
class ChameleonDummyInputsBuilder(
BaseDummyInputsBuilder[ChameleonProcessingInfo]):
def get_dummy_processor_inputs(
self,
seq_len: int,
mm_counts: Mapping[str, int],
) -> ProcessorInputs:
config = self._get_hf_config()
config = self.info.get_hf_config()
width = height = config.vq_config.resolution
num_images = mm_counts.get("image", 0)
......@@ -93,11 +94,8 @@ class ChameleonProfilingInfo(ChameleonProcessingMixin, BaseProfilingInfo):
)
class ChameleonMultiModalProcessor(ChameleonProcessingMixin,
BaseMultiModalProcessor):
def _get_profiling_info(self) -> BaseProfilingInfo:
return ChameleonProfilingInfo(self.ctx)
class ChameleonMultiModalProcessor(
BaseMultiModalProcessor[ChameleonProcessingInfo]):
def _get_mm_fields_config(
self,
......@@ -112,7 +110,7 @@ class ChameleonMultiModalProcessor(ChameleonProcessingMixin,
hf_processor_mm_kwargs: Mapping[str, object],
out_mm_kwargs: MultiModalKwargs,
) -> list[PromptReplacement]:
processor = self._get_hf_processor(**hf_processor_mm_kwargs)
processor = self.info.get_hf_processor(**hf_processor_mm_kwargs)
return [
PromptReplacement(
......@@ -120,7 +118,7 @@ class ChameleonMultiModalProcessor(ChameleonProcessingMixin,
target="<image>",
replacement="".join([
processor.image_start_token,
processor.image_token * self._get_num_image_tokens(),
processor.image_token * self.info.get_num_image_tokens(),
processor.image_end_token,
]),
)
......@@ -916,7 +914,10 @@ class ChameleonModel(nn.Module):
return hidden_states
@MULTIMODAL_REGISTRY.register_processor(ChameleonMultiModalProcessor)
@MULTIMODAL_REGISTRY.register_processor(
ChameleonMultiModalProcessor,
info=ChameleonProcessingInfo,
dummy_inputs=ChameleonDummyInputsBuilder)
class ChameleonForConditionalGeneration(nn.Module, SupportsMultiModal,
SupportsPP):
......
......@@ -33,11 +33,11 @@ from vllm.multimodal import MULTIMODAL_REGISTRY
from vllm.multimodal.inputs import (MultiModalDataDict, MultiModalFieldConfig,
MultiModalInputsV2, MultiModalKwargs,
NestedTensors, PlaceholderRange)
from vllm.multimodal.parse import ImageProcessorItems, ImageSize
from vllm.multimodal.parse import (ImageProcessorItems, ImageSize,
MultiModalDataItems)
from vllm.multimodal.processing import (BaseMultiModalProcessor,
MultiModalDataItems, ProcessingMixin,
PromptReplacement)
from vllm.multimodal.profiling import BaseProfilingInfo, ProcessorInputs
BaseProcessingInfo, PromptReplacement)
from vllm.multimodal.profiling import BaseDummyInputsBuilder, ProcessorInputs
from vllm.sequence import IntermediateTensors
from .interfaces import SupportsMultiModal, SupportsPP
......@@ -64,24 +64,38 @@ class FuyuImagePatchInputs(TypedDict):
"""
class FuyuProcessingMixin(ProcessingMixin):
class FuyuProcessingInfo(BaseProcessingInfo):
def _get_hf_config(self):
def get_hf_config(self):
return self.ctx.get_hf_config(FuyuConfig)
def _get_hf_processor(self):
def get_hf_processor(self):
return self.ctx.get_hf_processor(FuyuProcessor)
def _get_image_processor(self) -> FuyuImageProcessor:
return self._get_hf_processor().image_processor
def get_image_processor(self) -> FuyuImageProcessor:
return self.get_hf_processor().image_processor
def _get_image_feature_grid_size(
def get_supported_mm_limits(self) -> Mapping[str, Optional[int]]:
return {"image": 1}
def get_mm_max_tokens_per_item(self, seq_len: int) -> Mapping[str, int]:
target_width, target_height = self.get_image_size_with_most_features()
max_ncols, max_nrows = self.get_image_feature_grid_size(
image_width=target_width,
image_height=target_height,
)
max_image_tokens = (max_ncols + 1) * max_nrows
return {"image": max_image_tokens}
def get_image_feature_grid_size(
self,
*,
image_width: int,
image_height: int,
) -> tuple[int, int]:
image_processor = self._get_image_processor()
image_processor = self.get_image_processor()
target_width = image_processor.size["width"]
target_height = image_processor.size["height"]
......@@ -97,34 +111,21 @@ class FuyuProcessingMixin(ProcessingMixin):
nrows = math.ceil(image_height / 30)
return ncols, nrows
class FuyuProfilingInfo(FuyuProcessingMixin, BaseProfilingInfo):
def get_supported_mm_limits(self) -> Mapping[str, Optional[int]]:
return {"image": 1}
def get_mm_max_tokens_per_item(self, seq_len: int) -> Mapping[str, int]:
target_width, target_height = self._get_image_size_with_most_features()
max_ncols, max_nrows = self._get_image_feature_grid_size(
image_width=target_width,
image_height=target_height,
)
max_image_tokens = (max_ncols + 1) * max_nrows
return {"image": max_image_tokens}
def _get_image_size_with_most_features(self) -> ImageSize:
image_processor = self._get_image_processor()
def get_image_size_with_most_features(self) -> ImageSize:
image_processor = self.get_image_processor()
return ImageSize(width=image_processor.size["width"],
height=image_processor.size["height"])
class FuyuDummyInputsBuilder(BaseDummyInputsBuilder[FuyuProcessingInfo]):
def get_dummy_processor_inputs(
self,
seq_len: int,
mm_counts: Mapping[str, int],
) -> ProcessorInputs:
target_width, target_height = self._get_image_size_with_most_features()
target_width, target_height = \
self.info.get_image_size_with_most_features()
num_images = mm_counts.get("image", 0)
mm_data = {
......@@ -140,10 +141,7 @@ class FuyuProfilingInfo(FuyuProcessingMixin, BaseProfilingInfo):
)
class FuyuMultiModalProcessor(FuyuProcessingMixin, BaseMultiModalProcessor):
def _get_profiling_info(self) -> BaseProfilingInfo:
return FuyuProfilingInfo(self.ctx)
class FuyuMultiModalProcessor(BaseMultiModalProcessor[FuyuProcessingInfo]):
def _call_hf_processor(
self,
......@@ -156,7 +154,7 @@ class FuyuMultiModalProcessor(FuyuProcessingMixin, BaseMultiModalProcessor):
# Avoid warning from HF logger for text-only input
# Input_ids format: bos_token_id + prompt_token_ids + boa_token_id
# Tokenizer won't add boa_token_id by default, we add it manually.
tokenizer = self._get_tokenizer()
tokenizer = self.info.get_tokenizer()
boa_token_id: int = tokenizer.vocab["<0x04>"] # type: ignore
prompt_ids = tokenizer.encode(prompt) + [boa_token_id]
return BatchFeature(dict(input_ids=[prompt_ids]), tensor_type="pt")
......@@ -196,10 +194,10 @@ class FuyuMultiModalProcessor(FuyuProcessingMixin, BaseMultiModalProcessor):
hf_processor_mm_kwargs: Mapping[str, object],
out_mm_kwargs: MultiModalKwargs,
) -> list[PromptReplacement]:
hf_config = self._get_hf_config()
hf_config = self.info.get_hf_config()
bos_token_id = hf_config.bos_token_id
tokenizer = self._get_tokenizer()
tokenizer = self.info.get_tokenizer()
eot_token_id = tokenizer.bos_token_id
assert isinstance(eot_token_id, int)
......@@ -207,7 +205,7 @@ class FuyuMultiModalProcessor(FuyuProcessingMixin, BaseMultiModalProcessor):
images = mm_items.get_items("image", ImageProcessorItems)
image_size = images.get_image_size(item_idx)
ncols, nrows = self._get_image_feature_grid_size(
ncols, nrows = self.info.get_image_feature_grid_size(
image_width=image_size.width,
image_height=image_size.height,
)
......@@ -244,7 +242,9 @@ class FuyuMultiModalProcessor(FuyuProcessingMixin, BaseMultiModalProcessor):
return result
@MULTIMODAL_REGISTRY.register_processor(FuyuMultiModalProcessor)
@MULTIMODAL_REGISTRY.register_processor(FuyuMultiModalProcessor,
info=FuyuProcessingInfo,
dummy_inputs=FuyuDummyInputsBuilder)
class FuyuForCausalLM(nn.Module, SupportsMultiModal, SupportsPP):
def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
......
from abc import ABC, abstractmethod
from abc import abstractmethod
from functools import cached_property
from typing import (Final, Iterable, List, Literal, Mapping, Optional,
Protocol, Set, Tuple, TypedDict, Union)
Protocol, Set, Tuple, TypedDict, TypeVar, Union)
import torch
import torch.nn as nn
......@@ -25,11 +25,11 @@ from vllm.multimodal.inputs import (MultiModalDataDict, MultiModalFieldConfig,
MultiModalInputsV2, MultiModalKwargs,
NestedTensors)
from vllm.multimodal.parse import (ImageEmbeddingItems, ImageProcessorItems,
ImageSize)
ImageSize, MultiModalDataItems)
from vllm.multimodal.processing import (BaseMultiModalProcessor,
MultiModalDataItems, ProcessingCache,
ProcessingMixin, PromptReplacement)
from vllm.multimodal.profiling import BaseProfilingInfo, ProcessorInputs
BaseProcessingInfo, ProcessingCache,
PromptReplacement)
from vllm.multimodal.profiling import BaseDummyInputsBuilder, ProcessorInputs
from vllm.sequence import IntermediateTensors
from .clip import CLIPVisionModel
......@@ -105,34 +105,23 @@ class LlavaLikeProcessor(Protocol):
image_token: Final[str]
class BaseLlavaProcessingMixin(ProcessingMixin, ABC):
class BaseLlavaProcessingInfo(BaseProcessingInfo):
def _get_hf_config(self) -> LlavaLikeConfig:
def get_hf_config(self) -> LlavaLikeConfig:
return self.ctx.get_hf_config(LlavaConfig)
def _get_vision_encoder_info(self):
return get_vision_encoder_info(self._get_hf_config())
def get_vision_encoder_info(self):
return get_vision_encoder_info(self.get_hf_config())
@abstractmethod
def _get_hf_processor(self) -> LlavaLikeProcessor:
def get_hf_processor(self) -> LlavaLikeProcessor:
raise NotImplementedError
def _get_num_image_tokens(
self,
*,
image_width: int,
image_height: int,
) -> int:
hf_config = self._get_hf_config()
vision_encoder_info = self._get_vision_encoder_info()
def get_supported_mm_limits(self) -> Mapping[str, Optional[int]]:
return {"image": None}
return self._apply_feature_select_strategy(
hf_config.vision_feature_select_strategy,
vision_encoder_info.get_num_image_tokens(
image_width=image_width,
image_height=image_height,
),
)
def get_mm_max_tokens_per_item(self, seq_len: int) -> Mapping[str, int]:
return {"image": self.get_max_image_tokens()}
def _apply_feature_select_strategy(
self,
......@@ -147,28 +136,42 @@ class BaseLlavaProcessingMixin(ProcessingMixin, ABC):
msg = f"Unexpected feature select strategy: {strategy!r}"
raise NotImplementedError(msg)
def get_num_image_tokens(
self,
*,
image_width: int,
image_height: int,
) -> int:
hf_config = self.get_hf_config()
vision_encoder_info = self.get_vision_encoder_info()
class BaseLlavaProfilingInfo(BaseLlavaProcessingMixin, BaseProfilingInfo):
def get_supported_mm_limits(self) -> Mapping[str, Optional[int]]:
return {"image": None}
def get_mm_max_tokens_per_item(self, seq_len: int) -> Mapping[str, int]:
return {"image": self._get_max_image_tokens()}
return self._apply_feature_select_strategy(
hf_config.vision_feature_select_strategy,
vision_encoder_info.get_num_image_tokens(
image_width=image_width,
image_height=image_height,
),
)
def _get_image_size_with_most_features(self) -> ImageSize:
vision_encoder_info = self._get_vision_encoder_info()
def get_image_size_with_most_features(self) -> ImageSize:
vision_encoder_info = self.get_vision_encoder_info()
width = height = vision_encoder_info.get_image_size()
return ImageSize(width=width, height=height)
def _get_max_image_tokens(self) -> int:
target_width, target_height = self._get_image_size_with_most_features()
def get_max_image_tokens(self) -> int:
target_width, target_height = self.get_image_size_with_most_features()
return self._get_num_image_tokens(
return self.get_num_image_tokens(
image_width=target_width,
image_height=target_height,
)
_I = TypeVar("_I", bound=BaseLlavaProcessingInfo)
class LlavaDummyInputsBuilder(BaseDummyInputsBuilder[_I]):
def get_dummy_processor_inputs(
self,
seq_len: int,
......@@ -176,9 +179,10 @@ class BaseLlavaProfilingInfo(BaseLlavaProcessingMixin, BaseProfilingInfo):
) -> ProcessorInputs:
num_images = mm_counts.get("image", 0)
processor = self._get_hf_processor()
processor = self.info.get_hf_processor()
image_token = processor.image_token
target_width, target_height = self._get_image_size_with_most_features()
target_width, target_height = \
self.info.get_image_size_with_most_features()
mm_data = {
"image":
......@@ -193,23 +197,13 @@ class BaseLlavaProfilingInfo(BaseLlavaProcessingMixin, BaseProfilingInfo):
)
class LlavaProcessingMixin(BaseLlavaProcessingMixin):
class LlavaProcessingInfo(BaseLlavaProcessingInfo):
def _get_hf_processor(self):
def get_hf_processor(self):
return self.ctx.get_hf_processor(LlavaProcessor)
class LlavaProfilingInfo(LlavaProcessingMixin, BaseLlavaProfilingInfo):
pass
class BaseLlavaMultiModalProcessor(LlavaProcessingMixin,
BaseMultiModalProcessor):
# Copied from BaseMultiModalProcessor
@abstractmethod
def _get_profiling_info(self) -> BaseProfilingInfo:
raise NotImplementedError
class BaseLlavaMultiModalProcessor(BaseMultiModalProcessor[_I]):
# Copied from BaseMultiModalProcessor
@abstractmethod
......@@ -226,7 +220,7 @@ class BaseLlavaMultiModalProcessor(LlavaProcessingMixin,
hf_processor_mm_kwargs: Mapping[str, object],
out_mm_kwargs: MultiModalKwargs,
) -> list[PromptReplacement]:
hf_config = self._get_hf_config()
hf_config = self.info.get_hf_config()
image_token_id = hf_config.image_token_index
def get_replacement(item_idx: int):
......@@ -237,7 +231,7 @@ class BaseLlavaMultiModalProcessor(LlavaProcessingMixin,
num_image_tokens = images.get_feature_size(item_idx)
else:
image_size = images.get_image_size(item_idx)
num_image_tokens = self._get_num_image_tokens(
num_image_tokens = self.info.get_num_image_tokens(
image_width=image_size.width,
image_height=image_size.height,
)
......@@ -253,10 +247,8 @@ class BaseLlavaMultiModalProcessor(LlavaProcessingMixin,
]
class LlavaMultiModalProcessor(BaseLlavaMultiModalProcessor):
def _get_profiling_info(self) -> BaseProfilingInfo:
return LlavaProfilingInfo(self.ctx)
class LlavaMultiModalProcessor(
BaseLlavaMultiModalProcessor[LlavaProcessingInfo]):
def _get_mm_fields_config(
self,
......@@ -269,21 +261,14 @@ class LlavaMultiModalProcessor(BaseLlavaMultiModalProcessor):
)
class PixtralHFProcessingMixin(BaseLlavaProcessingMixin):
class PixtralHFProcessingInfo(BaseLlavaProcessingInfo):
def _get_hf_processor(self):
def get_hf_processor(self):
return self.ctx.get_hf_processor(PixtralProcessor)
class PixtralHFProfilingInfo(PixtralHFProcessingMixin, BaseLlavaProfilingInfo):
pass
class PixtralHFMultiModalProcessor(PixtralHFProcessingMixin,
BaseMultiModalProcessor):
def _get_profiling_info(self) -> BaseProfilingInfo:
return PixtralHFProfilingInfo(self.ctx)
class PixtralHFMultiModalProcessor(
BaseMultiModalProcessor[PixtralHFProcessingInfo]):
def _call_hf_processor(
self,
......@@ -328,10 +313,10 @@ class PixtralHFMultiModalProcessor(PixtralHFProcessingMixin,
hf_processor_mm_kwargs: Mapping[str, object],
out_mm_kwargs: MultiModalKwargs,
) -> list[PromptReplacement]:
hf_config = self._get_hf_config()
hf_config = self.info.get_hf_config()
image_token_id = hf_config.image_token_index
processor = self._get_hf_processor()
processor = self.info.get_hf_processor()
image_token = processor.image_token
image_break_token = processor.image_break_token
image_end_token = processor.image_end_token
......@@ -363,27 +348,41 @@ class PixtralHFMultiModalProcessor(PixtralHFProcessingMixin,
]
def _build_llava_or_pixtral_hf_info(
ctx: InputProcessingContext, ) -> BaseLlavaProcessingInfo:
hf_config = ctx.get_hf_config(LlavaConfig)
if isinstance(hf_config.vision_config, PixtralVisionConfig):
return PixtralHFProcessingInfo(ctx)
return LlavaProcessingInfo(ctx)
def _build_llava_or_pixtral_hf_processor(
ctx: InputProcessingContext,
info: _I,
dummy_inputs: BaseDummyInputsBuilder[_I],
*,
cache: Optional[ProcessingCache] = None,
enable_sanity_checks: bool = True,
) -> BaseMultiModalProcessor:
hf_config = ctx.get_hf_config(LlavaConfig)
if isinstance(hf_config.vision_config, PixtralVisionConfig):
if isinstance(info, PixtralHFProcessingInfo):
return PixtralHFMultiModalProcessor(
ctx,
info,
dummy_inputs, # type: ignore
cache=cache,
enable_sanity_checks=enable_sanity_checks,
)
if isinstance(info, LlavaProcessingInfo):
return LlavaMultiModalProcessor(
ctx,
info,
dummy_inputs, # type: ignore
cache=cache,
enable_sanity_checks=enable_sanity_checks,
)
raise NotImplementedError(type(info))
def _get_num_hidden_layers(hf_config: LlavaLikeConfig) -> int:
"""Determine the number of hidden layers to initialize up to in the
......@@ -460,7 +459,9 @@ def init_vision_tower_for_llava(
raise NotImplementedError(msg)
@MULTIMODAL_REGISTRY.register_processor(_build_llava_or_pixtral_hf_processor)
@MULTIMODAL_REGISTRY.register_processor(_build_llava_or_pixtral_hf_processor,
info=_build_llava_or_pixtral_hf_info,
dummy_inputs=LlavaDummyInputsBuilder)
class LlavaForConditionalGeneration(nn.Module, SupportsMultiModal, SupportsPP):
# BitandBytes specific attributes
bitsandbytes_stacked_params_mapping = {
......@@ -727,11 +728,11 @@ class MantisMultiModalProcessor(LlavaMultiModalProcessor):
mm_data: MultiModalDataDict,
hf_processor_mm_kwargs: Mapping[str, object],
) -> MultiModalInputsV2:
hf_config = self._get_hf_config()
hf_config = self.info.get_hf_config()
image_token_id = hf_config.image_token_index
# Assume that it doesn't depend on the image size
num_image_tokens = self._get_num_image_tokens(
num_image_tokens = self.info.get_num_image_tokens(
image_width=-1,
image_height=-1,
)
......@@ -796,6 +797,8 @@ class MantisMultiModalProcessor(LlavaMultiModalProcessor):
# To use this model, please use
# `--hf_overrides '{"architectures": ["MantisForConditionalGeneration"]}'`
@MULTIMODAL_REGISTRY.register_processor(MantisMultiModalProcessor)
@MULTIMODAL_REGISTRY.register_processor(MantisMultiModalProcessor,
info=LlavaProcessingInfo,
dummy_inputs=LlavaDummyInputsBuilder)
class MantisForConditionalGeneration(LlavaForConditionalGeneration):
pass
from abc import abstractmethod
from functools import cached_property
from typing import (Final, Iterable, List, Literal, Mapping, Optional,
Protocol, Set, Tuple, TypedDict, Union)
Protocol, Set, Tuple, TypedDict, TypeVar, Union)
import torch
import torch.nn as nn
......@@ -16,13 +17,12 @@ from vllm.model_executor.sampling_metadata import SamplingMetadata
from vllm.multimodal import MULTIMODAL_REGISTRY
from vllm.multimodal.inputs import MultiModalFieldConfig, NestedTensors
from vllm.multimodal.parse import ImageSize
from vllm.multimodal.profiling import BaseProfilingInfo
from vllm.sequence import IntermediateTensors
from .clip import CLIPVisionModel
from .interfaces import SupportsMultiModal, SupportsPP
from .llava import (BaseLlavaMultiModalProcessor, BaseLlavaProcessingMixin,
BaseLlavaProfilingInfo, LlavaLikeConfig,
from .llava import (BaseLlavaMultiModalProcessor, BaseLlavaProcessingInfo,
LlavaDummyInputsBuilder, LlavaLikeConfig,
LlavaMultiModalProjector, init_vision_tower_for_llava)
from .siglip import SiglipVisionModel
from .utils import (AutoWeightsLoader, embed_multimodal, flatten_bn,
......@@ -65,23 +65,23 @@ class LlavaNextLikeConfig(LlavaLikeConfig, Protocol):
image_grid_pinpoints: Final[list[list[int]]]
class LlavaNextProcessingMixin(BaseLlavaProcessingMixin):
class LlavaNextProcessingInfo(BaseLlavaProcessingInfo):
def _get_hf_config(self) -> LlavaNextLikeConfig:
def get_hf_config(self) -> LlavaNextLikeConfig:
return self.ctx.get_hf_config(LlavaNextConfig)
def _get_hf_processor(self):
def get_hf_processor(self):
return self.ctx.get_hf_processor(LlavaNextProcessor)
# Based on: https://github.com/huggingface/text-generation-inference/blob/v3.0.1/server/text_generation_server/models/vlm_causal_lm.py#L113
def _get_num_image_tokens(
def get_num_image_tokens(
self,
*,
image_width: int,
image_height: int,
) -> int:
hf_config = self._get_hf_config()
vision_encoder_info = self._get_vision_encoder_info()
hf_config = self.get_hf_config()
vision_encoder_info = self.get_vision_encoder_info()
base_feature_size = self._apply_feature_select_strategy(
hf_config.vision_feature_select_strategy,
......@@ -140,15 +140,12 @@ class LlavaNextProcessingMixin(BaseLlavaProcessingMixin):
return (unpadded_features, newline_features)
class LlavaNextProfilingInfo(LlavaNextProcessingMixin, BaseLlavaProfilingInfo):
def _get_image_size_with_most_features(self) -> ImageSize:
hf_config = self._get_hf_config()
def get_image_size_with_most_features(self) -> ImageSize:
hf_config = self.get_hf_config()
largest_feature_size, largest_feature_pinpoint = 0, None
for (height, width) in hf_config.image_grid_pinpoints:
feat_size = self._get_num_image_tokens(image_width=width,
feat_size = self.get_num_image_tokens(image_width=width,
image_height=height)
if feat_size > largest_feature_size:
largest_feature_size = feat_size
......@@ -161,11 +158,23 @@ class LlavaNextProfilingInfo(LlavaNextProcessingMixin, BaseLlavaProfilingInfo):
return largest_feature_pinpoint
class LlavaNextMultiModalProcessor(LlavaNextProcessingMixin,
BaseLlavaMultiModalProcessor):
_I = TypeVar("_I", bound=LlavaNextProcessingInfo)
class BaseLlavaNextMultiModalProcessor(BaseLlavaMultiModalProcessor[_I]):
# Copied from BaseMultiModalProcessor
@abstractmethod
def _get_mm_fields_config(
self,
hf_inputs: BatchFeature,
hf_processor_mm_kwargs: Mapping[str, object],
) -> Mapping[str, MultiModalFieldConfig]:
raise NotImplementedError
def _get_profiling_info(self) -> BaseProfilingInfo:
return LlavaNextProfilingInfo(self.ctx)
class LlavaNextMultiModalProcessor(
BaseLlavaNextMultiModalProcessor[LlavaNextProcessingInfo]):
def _get_mm_fields_config(
self,
......@@ -179,7 +188,9 @@ class LlavaNextMultiModalProcessor(LlavaNextProcessingMixin,
)
@MULTIMODAL_REGISTRY.register_processor(LlavaNextMultiModalProcessor)
@MULTIMODAL_REGISTRY.register_processor(LlavaNextMultiModalProcessor,
info=LlavaNextProcessingInfo,
dummy_inputs=LlavaDummyInputsBuilder)
class LlavaNextForConditionalGeneration(nn.Module, SupportsMultiModal,
SupportsPP):
......
......@@ -17,12 +17,11 @@ from vllm.model_executor.sampling_metadata import SamplingMetadata
from vllm.multimodal import MULTIMODAL_REGISTRY
from vllm.multimodal.inputs import (MultiModalFieldConfig, MultiModalKwargs,
NestedTensors)
from vllm.multimodal.parse import (ImageSize, VideoEmbeddingItems,
VideoProcessorItems)
from vllm.multimodal.parse import (ImageSize, MultiModalDataItems,
VideoEmbeddingItems, VideoProcessorItems)
from vllm.multimodal.processing import (BaseMultiModalProcessor,
MultiModalDataItems, ProcessingMixin,
PromptReplacement)
from vllm.multimodal.profiling import BaseProfilingInfo, ProcessorInputs
BaseProcessingInfo, PromptReplacement)
from vllm.multimodal.profiling import BaseDummyInputsBuilder, ProcessorInputs
from vllm.sequence import IntermediateTensors
from vllm.utils import is_list_of
......@@ -47,33 +46,52 @@ class LlavaNextVideoPixelInputs(TypedDict):
"""
class LlavaNextVideoProcessingMixin(ProcessingMixin):
class LlavaNextVideoProcessingInfo(BaseProcessingInfo):
def _get_hf_config(self):
def get_hf_config(self):
return self.ctx.get_hf_config(LlavaNextVideoConfig)
def _get_vision_encoder_info(self):
return get_vision_encoder_info(self._get_hf_config())
def get_vision_encoder_info(self):
return get_vision_encoder_info(self.get_hf_config())
def _get_hf_processor(self):
def get_hf_processor(self):
return self.ctx.get_hf_processor(LlavaNextVideoProcessor)
def get_supported_mm_limits(self) -> Mapping[str, Optional[int]]:
return {"video": 1}
def get_mm_max_tokens_per_item(self, seq_len: int) -> Mapping[str, int]:
target_width, target_height = self.get_image_size_with_most_features()
max_video_tokens = self.get_num_video_tokens(
image_width=target_width,
image_height=target_height,
num_frames=self.get_num_frames_with_most_features(seq_len),
)
return {"video": max_video_tokens}
def get_image_size_with_most_features(self) -> ImageSize:
vision_encoder_info = self.get_vision_encoder_info()
width = height = vision_encoder_info.get_image_size()
return ImageSize(width=width, height=height)
def _get_num_frame_tokens(
self,
*,
image_width: int,
image_height: int,
) -> int:
hf_config = self._get_hf_config()
hf_config = self.get_hf_config()
spatial_pool_stride = hf_config.spatial_pool_stride
vision_encoder_info = self._get_vision_encoder_info()
vision_encoder_info = self.get_vision_encoder_info()
patch_grid_length = vision_encoder_info.get_patch_grid_length()
pooled_grid_length = math.ceil(patch_grid_length / spatial_pool_stride)
return pooled_grid_length * pooled_grid_length
def _get_num_video_tokens(
def get_num_video_tokens(
self,
*,
image_width: int,
......@@ -87,37 +105,14 @@ class LlavaNextVideoProcessingMixin(ProcessingMixin):
return num_frame_tokens * num_frames
class LlavaNextVideoProfilingInfo(LlavaNextVideoProcessingMixin,
BaseProfilingInfo):
def get_supported_mm_limits(self) -> Mapping[str, Optional[int]]:
return {"video": 1}
def get_mm_max_tokens_per_item(self, seq_len: int) -> Mapping[str, int]:
target_width, target_height = self._get_image_size_with_most_features()
max_video_tokens = self._get_num_video_tokens(
image_width=target_width,
image_height=target_height,
num_frames=self._get_dummy_num_frames(seq_len),
)
return {"video": max_video_tokens}
def _get_image_size_with_most_features(self) -> ImageSize:
vision_encoder_info = self._get_vision_encoder_info()
width = height = vision_encoder_info.get_image_size()
return ImageSize(width=width, height=height)
def _get_max_video_frames(self, max_tokens: int) -> int:
target_width, target_height = self._get_image_size_with_most_features()
target_width, target_height = self.get_image_size_with_most_features()
num_frames = 0
while True:
next_num_frames = num_frames + 1
next_max_tokens = self._get_num_video_tokens(
next_max_tokens = self.get_num_video_tokens(
image_width=target_width,
image_height=target_height,
num_frames=next_num_frames,
......@@ -130,7 +125,7 @@ class LlavaNextVideoProfilingInfo(LlavaNextVideoProcessingMixin,
return num_frames
def _get_dummy_num_frames(self, seq_len: int) -> int:
def get_num_frames_with_most_features(self, seq_len: int) -> int:
mm_config = self.ctx.get_mm_config()
max_videos = mm_config.limit_per_prompt.get("video", 1)
......@@ -138,6 +133,10 @@ class LlavaNextVideoProfilingInfo(LlavaNextVideoProcessingMixin,
return max(max_total_frames // max(max_videos, 1), 1)
class LlavaNextVideoDummyInputsBuilder(
BaseDummyInputsBuilder[LlavaNextVideoProcessingInfo]):
def get_dummy_processor_inputs(
self,
seq_len: int,
......@@ -145,16 +144,20 @@ class LlavaNextVideoProfilingInfo(LlavaNextVideoProcessingMixin,
) -> ProcessorInputs:
num_videos = mm_counts.get("video", 0)
processor = self._get_hf_processor()
processor = self.info.get_hf_processor()
video_token = processor.video_token
target_width, target_height = self._get_image_size_with_most_features()
target_width, target_height = \
self.info.get_image_size_with_most_features()
target_num_frames = \
self.info.get_num_frames_with_most_features(seq_len)
mm_data = {
"video":
self._get_dummy_videos(
width=target_width,
height=target_height,
num_frames=self._get_dummy_num_frames(seq_len),
num_frames=target_num_frames,
num_videos=num_videos,
)
}
......@@ -165,11 +168,8 @@ class LlavaNextVideoProfilingInfo(LlavaNextVideoProcessingMixin,
)
class LlavaNextVideoMultiModalProcessor(LlavaNextVideoProcessingMixin,
BaseMultiModalProcessor):
def _get_profiling_info(self) -> BaseProfilingInfo:
return LlavaNextVideoProfilingInfo(self.ctx)
class LlavaNextVideoMultiModalProcessor(
BaseMultiModalProcessor[LlavaNextVideoProcessingInfo]):
def _get_mm_fields_config(
self,
......@@ -184,7 +184,7 @@ class LlavaNextVideoMultiModalProcessor(LlavaNextVideoProcessingMixin,
hf_processor_mm_kwargs: Mapping[str, object],
out_mm_kwargs: MultiModalKwargs,
) -> list[PromptReplacement]:
hf_config = self._get_hf_config()
hf_config = self.info.get_hf_config()
video_token_id = hf_config.video_token_index
def get_replacement(item_idx: int):
......@@ -195,7 +195,7 @@ class LlavaNextVideoMultiModalProcessor(LlavaNextVideoProcessingMixin,
num_video_tokens = videos.get_feature_size(item_idx)
else:
image_size = videos.get_frame_size(item_idx)
num_video_tokens = self._get_num_video_tokens(
num_video_tokens = self.info.get_num_video_tokens(
image_width=image_size.width,
image_height=image_size.height,
num_frames=videos.get_num_frames(item_idx),
......@@ -269,7 +269,11 @@ class LlavaNextMultiModalProjector(nn.Module):
return hidden_states
@MULTIMODAL_REGISTRY.register_processor(LlavaNextVideoMultiModalProcessor)
@MULTIMODAL_REGISTRY.register_processor(
LlavaNextVideoMultiModalProcessor,
info=LlavaNextVideoProcessingInfo,
dummy_inputs=LlavaNextVideoDummyInputsBuilder,
)
class LlavaNextVideoForConditionalGeneration(nn.Module, SupportsMultiModal,
SupportsPP):
......
......@@ -17,19 +17,20 @@ from vllm.model_executor.layers.activation import get_act_fn
from vllm.model_executor.layers.sampler import SamplerOutput, get_sampler
from vllm.model_executor.sampling_metadata import SamplingMetadata
from vllm.multimodal import MULTIMODAL_REGISTRY
from vllm.multimodal.inputs import MultiModalKwargs, NestedTensors
from vllm.multimodal.parse import (ImageSize, MultiModalDataItems,
VideoEmbeddingItems, VideoProcessorItems)
from vllm.multimodal.processing import MultiModalFieldConfig, PromptReplacement
from vllm.multimodal.profiling import BaseProfilingInfo, ProcessorInputs
from vllm.multimodal.inputs import (MultiModalFieldConfig, MultiModalKwargs,
NestedTensors)
from vllm.multimodal.parse import (MultiModalDataItems, VideoEmbeddingItems,
VideoProcessorItems)
from vllm.multimodal.processing import PromptReplacement
from vllm.multimodal.profiling import ProcessorInputs
from vllm.sequence import IntermediateTensors
from vllm.utils import is_list_of
from .clip import CLIPVisionModel
from .interfaces import SupportsMultiModal, SupportsPP
from .llava import BaseLlavaProfilingInfo, init_vision_tower_for_llava
from .llava_next import (LlavaNextLikeConfig, LlavaNextMultiModalProcessor,
LlavaNextProcessingMixin)
from .llava import LlavaDummyInputsBuilder, init_vision_tower_for_llava
from .llava_next import (BaseLlavaNextMultiModalProcessor, LlavaNextLikeConfig,
LlavaNextProcessingInfo)
from .siglip import SiglipVisionModel
from .utils import (AutoWeightsLoader, flatten_bn, init_vllm_registered_model,
maybe_prefix, merge_multimodal_embeddings)
......@@ -89,14 +90,23 @@ class LlavaOnevisionLikeConfig(LlavaNextLikeConfig, Protocol):
video_token_index: Final[int]
class LlavaOnevisionProcessingMixin(LlavaNextProcessingMixin):
class LlavaOnevisionProcessingInfo(LlavaNextProcessingInfo):
def _get_hf_config(self) -> LlavaOnevisionLikeConfig:
def get_hf_config(self) -> LlavaOnevisionLikeConfig:
return self.ctx.get_hf_config(LlavaOnevisionConfig)
def _get_hf_processor(self):
def get_hf_processor(self):
return self.ctx.get_hf_processor(LlavaOnevisionProcessor)
def get_supported_mm_limits(self) -> Mapping[str, Optional[int]]:
return {"image": None, "video": None}
def get_mm_max_tokens_per_item(self, seq_len: int) -> Mapping[str, int]:
return {
"image": self.get_max_image_tokens(),
"video": self.get_max_video_tokens(seq_len),
}
# Based on: https://github.com/huggingface/text-generation-inference/blob/v3.0.1/server/text_generation_server/models/vlm_causal_lm.py#L86
# with additional logic afterwards taken from LlavaOnevisionProcessor
def _get_num_unpadded_features(
......@@ -141,16 +151,16 @@ class LlavaOnevisionProcessingMixin(LlavaNextProcessingMixin):
image_width: int,
image_height: int,
) -> int:
hf_config = self._get_hf_config()
hf_config = self.get_hf_config()
spatial_pool_stride = getattr(hf_config, "spatial_pool_stride", 2)
vision_encoder_info = self._get_vision_encoder_info()
vision_encoder_info = self.get_vision_encoder_info()
patch_grid_length = vision_encoder_info.get_patch_grid_length()
pooled_grid_length = math.ceil(patch_grid_length / spatial_pool_stride)
return pooled_grid_length * pooled_grid_length
def _get_num_video_tokens(
def get_num_video_tokens(
self,
*,
image_width: int,
......@@ -164,43 +174,14 @@ class LlavaOnevisionProcessingMixin(LlavaNextProcessingMixin):
return num_frame_tokens * num_frames + 1 # Newline token
class LlavaOnevisionProfilingInfo(LlavaOnevisionProcessingMixin,
BaseLlavaProfilingInfo):
def _get_image_size_with_most_features(self) -> ImageSize:
hf_config = self._get_hf_config()
largest_feature_size, largest_feature_pinpoint = 0, None
for (height, width) in hf_config.image_grid_pinpoints:
feat_size = self._get_num_image_tokens(image_width=width,
image_height=height)
if feat_size > largest_feature_size:
largest_feature_size = feat_size
largest_feature_pinpoint = ImageSize(width=width,
height=height)
if largest_feature_size == 0 or largest_feature_pinpoint is None:
raise ValueError("Cannot have a largest feature size of 0!")
return largest_feature_pinpoint
def get_supported_mm_limits(self) -> Mapping[str, Optional[int]]:
return {"image": None, "video": None}
def get_mm_max_tokens_per_item(self, seq_len: int) -> Mapping[str, int]:
return {
"image": self._get_max_image_tokens(),
"video": self._get_max_video_tokens(seq_len),
}
def _get_max_video_frames(self, max_tokens: int) -> int:
target_width, target_height = self._get_image_size_with_most_features()
target_width, target_height = self.get_image_size_with_most_features()
num_frames = 0
while True:
next_num_frames = num_frames + 1
next_max_tokens = self._get_num_video_tokens(
next_max_tokens = self.get_num_video_tokens(
image_width=target_width,
image_height=target_height,
num_frames=next_num_frames,
......@@ -213,12 +194,12 @@ class LlavaOnevisionProfilingInfo(LlavaOnevisionProcessingMixin,
return num_frames
def _get_dummy_num_frames(self, seq_len: int) -> int:
def get_num_frames_with_most_features(self, seq_len: int) -> int:
mm_config = self.ctx.get_mm_config()
max_images = mm_config.limit_per_prompt.get("image", 1)
max_videos = mm_config.limit_per_prompt.get("video", 1)
max_image_tokens = self._get_max_image_tokens() * max_images
max_image_tokens = self.get_max_image_tokens() * max_images
max_total_frames = self._get_max_video_frames(seq_len -
max_image_tokens)
max_frames_per_video = min(max_total_frames // max(max_videos, 1),
......@@ -226,15 +207,19 @@ class LlavaOnevisionProfilingInfo(LlavaOnevisionProcessingMixin,
return max(max_frames_per_video, 1)
def _get_max_video_tokens(self, seq_len: int) -> int:
target_width, target_height = self._get_image_size_with_most_features()
def get_max_video_tokens(self, seq_len: int) -> int:
target_width, target_height = self.get_image_size_with_most_features()
return self._get_num_video_tokens(
return self.get_num_video_tokens(
image_width=target_width,
image_height=target_height,
num_frames=self._get_dummy_num_frames(seq_len),
num_frames=self.get_num_frames_with_most_features(seq_len),
)
class LlavaOnevisionDummyInputsBuilder(
LlavaDummyInputsBuilder[LlavaOnevisionProcessingInfo]):
def get_dummy_processor_inputs(
self,
seq_len: int,
......@@ -243,10 +228,14 @@ class LlavaOnevisionProfilingInfo(LlavaOnevisionProcessingMixin,
num_images = mm_counts.get("image", 0)
num_videos = mm_counts.get("video", 0)
processor = self._get_hf_processor()
processor = self.info.get_hf_processor()
image_token = processor.image_token
video_token = processor.video_token
target_width, target_height = self._get_image_size_with_most_features()
target_width, target_height = \
self.info.get_image_size_with_most_features()
target_num_frames = \
self.info.get_num_frames_with_most_features(seq_len)
mm_data = {
"image":
......@@ -257,7 +246,7 @@ class LlavaOnevisionProfilingInfo(LlavaOnevisionProcessingMixin,
self._get_dummy_videos(
width=target_width,
height=target_height,
num_frames=self._get_dummy_num_frames(seq_len),
num_frames=target_num_frames,
num_videos=num_videos,
)
}
......@@ -268,11 +257,8 @@ class LlavaOnevisionProfilingInfo(LlavaOnevisionProcessingMixin,
)
class LlavaOnevisionMultiModalProcessor(LlavaOnevisionProcessingMixin,
LlavaNextMultiModalProcessor):
def _get_profiling_info(self) -> BaseProfilingInfo:
return LlavaOnevisionProfilingInfo(self.ctx)
class LlavaOnevisionMultiModalProcessor(
BaseLlavaNextMultiModalProcessor[LlavaOnevisionProcessingInfo]):
def _get_mm_fields_config(
self,
......@@ -303,7 +289,7 @@ class LlavaOnevisionMultiModalProcessor(LlavaOnevisionProcessingMixin,
mm_kwargs=mm_kwargs,
)
processor = self._get_hf_processor()
processor = self.info.get_hf_processor()
video_token = processor.video_token
# LLaVA-OneVision processor doesn't support multiple videos
......@@ -345,7 +331,7 @@ class LlavaOnevisionMultiModalProcessor(LlavaOnevisionProcessingMixin,
out_mm_kwargs=out_mm_kwargs,
)
hf_config = self._get_hf_config()
hf_config = self.info.get_hf_config()
video_token_id = hf_config.video_token_index
def get_video_replacement(item_idx: int):
......@@ -356,7 +342,7 @@ class LlavaOnevisionMultiModalProcessor(LlavaOnevisionProcessingMixin,
num_video_tokens = videos.get_feature_size(item_idx)
else:
image_size = videos.get_frame_size(item_idx)
num_video_tokens = self._get_num_video_tokens(
num_video_tokens = self.info.get_num_video_tokens(
image_width=image_size.width,
image_height=image_size.height,
num_frames=videos.get_num_frames(item_idx),
......@@ -393,7 +379,10 @@ class LlavaOnevisionMultiModalProjector(nn.Module):
return hidden_states
@MULTIMODAL_REGISTRY.register_processor(LlavaOnevisionMultiModalProcessor)
@MULTIMODAL_REGISTRY.register_processor(
LlavaOnevisionMultiModalProcessor,
info=LlavaOnevisionProcessingInfo,
dummy_inputs=LlavaOnevisionDummyInputsBuilder)
class LlavaOnevisionForConditionalGeneration(nn.Module, SupportsMultiModal,
SupportsPP):
......
......@@ -34,13 +34,12 @@ from vllm.multimodal.inputs import (MultiModalDataDict, MultiModalFieldConfig,
MultiModalInputsV2, MultiModalKwargs,
NestedTensors, PlaceholderRange)
from vllm.multimodal.parse import (ImageEmbeddingItems, ImageProcessorItems,
ImageSize)
ImageSize, MultiModalDataItems)
from vllm.multimodal.processing import (BaseMultiModalProcessor,
MultiModalDataItems, ProcessingMixin,
PromptReplacement,
_BoundPromptReplacement,
_PlaceholderInfo)
from vllm.multimodal.profiling import BaseProfilingInfo, ProcessorInputs
BaseProcessingInfo,
BoundPromptReplacement,
PlaceholderInfo, PromptReplacement)
from vllm.multimodal.profiling import BaseDummyInputsBuilder, ProcessorInputs
from vllm.sequence import IntermediateTensors
from vllm.utils import is_list_of
......@@ -302,9 +301,9 @@ class Phi3HDImageEmbedding(Phi3ImageEmbeddingBase):
return image_features_hd_newline
class Phi3VProcessingMixin(ProcessingMixin):
class Phi3VProcessingInfo(BaseProcessingInfo):
def _get_hf_processor(
def get_hf_processor(
self,
*,
num_crops: Optional[int] = None,
......@@ -314,39 +313,42 @@ class Phi3VProcessingMixin(ProcessingMixin):
return self.ctx.get_hf_processor()
def _get_num_image_tokens(
self,
*,
image_width: int,
image_height: int,
) -> int:
processor = self._get_hf_processor()
return processor.calc_num_image_tokens_from_image_size( # type: ignore
width=image_width,
height=image_height,
)
class Phi3VProfilingInfo(Phi3VProcessingMixin, BaseProfilingInfo):
def get_supported_mm_limits(self) -> Mapping[str, Optional[int]]:
return {"image": None}
def get_mm_max_tokens_per_item(self, seq_len: int) -> Mapping[str, int]:
target_width, target_height = self._get_image_size_with_most_features()
target_width, target_height = self.get_image_size_with_most_features()
max_image_tokens = self._get_num_image_tokens(
max_image_tokens = self.get_num_image_tokens(
image_width=target_width,
image_height=target_height,
processor=None,
)
return {"image": max_image_tokens}
def _get_image_size_with_most_features(self) -> ImageSize:
def get_num_image_tokens(
self,
*,
image_width: int,
image_height: int,
processor: Optional[ProcessorMixin],
) -> int:
if processor is None:
processor = self.get_hf_processor()
return processor.calc_num_image_tokens_from_image_size( # type: ignore
width=image_width,
height=image_height,
)
def get_image_size_with_most_features(self) -> ImageSize:
# Result in the max possible feature size (h:w = 16:1)
return ImageSize(height=8000, width=50)
class Phi3VDummyInputsBuilder(BaseDummyInputsBuilder[Phi3VProcessingInfo]):
def get_dummy_processor_inputs(
self,
seq_len: int,
......@@ -354,7 +356,8 @@ class Phi3VProfilingInfo(Phi3VProcessingMixin, BaseProfilingInfo):
) -> ProcessorInputs:
num_images = mm_counts.get("image", 0)
target_width, target_height = self._get_image_size_with_most_features()
target_width, target_height = \
self.info.get_image_size_with_most_features()
mm_data = {
"image":
......@@ -363,7 +366,7 @@ class Phi3VProfilingInfo(Phi3VProcessingMixin, BaseProfilingInfo):
num_images=num_images)
}
hf_processor = self._get_hf_processor()
hf_processor = self.info.get_hf_processor()
image_tokens: list[str] = hf_processor.img_tokens # type: ignore
return ProcessorInputs(
......@@ -372,10 +375,7 @@ class Phi3VProfilingInfo(Phi3VProcessingMixin, BaseProfilingInfo):
)
class Phi3VMultiModalProcessor(Phi3VProcessingMixin, BaseMultiModalProcessor):
def _get_profiling_info(self) -> BaseProfilingInfo:
return Phi3VProfilingInfo(self.ctx)
class Phi3VMultiModalProcessor(BaseMultiModalProcessor[Phi3VProcessingInfo]):
def _call_hf_processor(
self,
......@@ -416,10 +416,10 @@ class Phi3VMultiModalProcessor(Phi3VProcessingMixin, BaseMultiModalProcessor):
hf_processor_mm_kwargs: Mapping[str, Any],
out_mm_kwargs: MultiModalKwargs,
) -> list[PromptReplacement]:
hf_processor = self._get_hf_processor(**hf_processor_mm_kwargs)
hf_processor = self.info.get_hf_processor(**hf_processor_mm_kwargs)
image_tokens: list[str] = hf_processor.img_tokens # type: ignore
tokenizer = self._get_tokenizer()
tokenizer = self.info.get_tokenizer()
bos_token_id = tokenizer.bos_token_id
assert isinstance(bos_token_id, int)
......@@ -431,9 +431,10 @@ class Phi3VMultiModalProcessor(Phi3VProcessingMixin, BaseMultiModalProcessor):
num_image_tokens = images.get_feature_size(item_idx)
else:
image_size = images.get_image_size(item_idx)
num_image_tokens = self._get_num_image_tokens(
num_image_tokens = self.info.get_num_image_tokens(
image_width=image_size.width,
image_height=image_size.height,
processor=hf_processor,
)
return [_IMAGE_TOKEN_ID] * num_image_tokens + [bos_token_id]
......@@ -451,9 +452,9 @@ class Phi3VMultiModalProcessor(Phi3VProcessingMixin, BaseMultiModalProcessor):
def _apply_prompt_replacements(
self,
token_ids: list[int],
mm_prompt_repls: Mapping[str, Sequence[_BoundPromptReplacement]],
mm_prompt_repls: Mapping[str, Sequence[BoundPromptReplacement]],
mm_item_counts: Mapping[str, int],
) -> tuple[list[int], str, Mapping[str, list[_PlaceholderInfo]]]:
) -> tuple[list[int], str, Mapping[str, list[PlaceholderInfo]]]:
token_ids, text, placeholders = super()._apply_prompt_replacements(
token_ids=token_ids,
mm_prompt_repls=mm_prompt_repls,
......@@ -466,7 +467,7 @@ class Phi3VMultiModalProcessor(Phi3VProcessingMixin, BaseMultiModalProcessor):
token_ids = [token_ids[0], *token_ids[2:]]
placeholders = {
modality: [
_PlaceholderInfo(
PlaceholderInfo(
modality=p.modality,
item_idx=p.item_idx,
start_idx=p.start_idx - 1,
......@@ -499,7 +500,9 @@ class Phi3VMultiModalProcessor(Phi3VProcessingMixin, BaseMultiModalProcessor):
return result
@MULTIMODAL_REGISTRY.register_processor(Phi3VMultiModalProcessor)
@MULTIMODAL_REGISTRY.register_processor(Phi3VMultiModalProcessor,
info=Phi3VProcessingInfo,
dummy_inputs=Phi3VDummyInputsBuilder)
class Phi3VForCausalLM(nn.Module, SupportsMultiModal, SupportsPP):
hf_to_vllm_mapper = WeightsMapper(
orig_to_new_prefix={
......
......@@ -38,11 +38,11 @@ from vllm.model_executor.sampling_metadata import SamplingMetadata
from vllm.multimodal import MULTIMODAL_REGISTRY
from vllm.multimodal.inputs import (MultiModalFieldConfig, MultiModalKwargs,
NestedTensors)
from vllm.multimodal.parse import AudioProcessorItems, MultiModalDataParser
from vllm.multimodal.parse import (AudioProcessorItems, MultiModalDataItems,
MultiModalDataParser)
from vllm.multimodal.processing import (BaseMultiModalProcessor,
MultiModalDataItems, ProcessingMixin,
PromptReplacement)
from vllm.multimodal.profiling import BaseProfilingInfo, ProcessorInputs
BaseProcessingInfo, PromptReplacement)
from vllm.multimodal.profiling import BaseDummyInputsBuilder, ProcessorInputs
from vllm.sequence import IntermediateTensors
from .interfaces import SupportsMultiModal, SupportsPP
......@@ -80,12 +80,12 @@ def _get_feat_extract_output_lengths(input_lengths: torch.Tensor):
return feat_lengths, output_lengths
class Qwen2AudioProcessingMixin(ProcessingMixin):
class Qwen2AudioProcessingInfo(BaseProcessingInfo):
def _get_hf_config(self):
def get_hf_config(self):
return self.ctx.get_hf_config(Qwen2AudioConfig)
def _get_hf_processor(
def get_hf_processor(
self,
*,
# Ignored in initialization
......@@ -93,36 +93,37 @@ class Qwen2AudioProcessingMixin(ProcessingMixin):
) -> Qwen2AudioProcessor:
return self.ctx.get_hf_processor(Qwen2AudioProcessor)
def _get_feature_extractor(
def get_feature_extractor(
self,
*,
# Ignored in initialization
sampling_rate: Optional[int] = None,
) -> WhisperFeatureExtractor:
hf_processor = self._get_hf_processor(sampling_rate=sampling_rate)
hf_processor = self.get_hf_processor(sampling_rate=sampling_rate)
feature_extractor = hf_processor.feature_extractor # type: ignore
assert isinstance(feature_extractor, WhisperFeatureExtractor)
return feature_extractor
class Qwen2AudioProfilingInfo(Qwen2AudioProcessingMixin, BaseProfilingInfo):
def get_supported_mm_limits(self) -> Mapping[str, Optional[int]]:
return {"audio": None}
def get_mm_max_tokens_per_item(self, seq_len: int) -> Mapping[str, int]:
hf_config = self._get_hf_config()
hf_config = self.get_hf_config()
max_source_positions = hf_config.audio_config.max_source_positions
max_output_lengths = (max_source_positions - 2) // 2 + 1
return {"audio": max_output_lengths}
class Qwen2AudioDummyInputsBuilder(
BaseDummyInputsBuilder[Qwen2AudioProcessingInfo]):
def get_dummy_processor_inputs(
self,
seq_len: int,
mm_counts: Mapping[str, int],
) -> ProcessorInputs:
feature_extractor = self._get_feature_extractor()
feature_extractor = self.info.get_feature_extractor()
sampling_rate = feature_extractor.sampling_rate
audio_len = feature_extractor.chunk_length * sampling_rate
......@@ -139,14 +140,11 @@ class Qwen2AudioProfilingInfo(Qwen2AudioProcessingMixin, BaseProfilingInfo):
)
class Qwen2AudioMultiModalProcessor(Qwen2AudioProcessingMixin,
BaseMultiModalProcessor):
def _get_profiling_info(self) -> BaseProfilingInfo:
return Qwen2AudioProfilingInfo(self.ctx)
class Qwen2AudioMultiModalProcessor(
BaseMultiModalProcessor[Qwen2AudioProcessingInfo]):
def _get_data_parser(self) -> MultiModalDataParser:
feature_extractor = self._get_feature_extractor()
feature_extractor = self.info.get_feature_extractor()
return MultiModalDataParser(target_sr=feature_extractor.sampling_rate)
def _call_hf_processor(
......@@ -161,7 +159,7 @@ class Qwen2AudioMultiModalProcessor(Qwen2AudioProcessingMixin,
if audios:
mm_data["audios"] = audios
feature_extractor = self._get_feature_extractor(**mm_kwargs)
feature_extractor = self.info.get_feature_extractor(**mm_kwargs)
mm_kwargs = dict(
**mm_kwargs,
sampling_rate=feature_extractor.sampling_rate,
......@@ -194,7 +192,7 @@ class Qwen2AudioMultiModalProcessor(Qwen2AudioProcessingMixin,
hf_processor_mm_kwargs: Mapping[str, object],
out_mm_kwargs: MultiModalKwargs,
) -> list[PromptReplacement]:
hf_config = self._get_hf_config()
hf_config = self.info.get_hf_config()
placeholder = hf_config.audio_token_index
feature_attention_mask = out_mm_kwargs.get("feature_attention_mask")
......@@ -234,10 +232,13 @@ class Qwen2AudioMultiModalProcessor(Qwen2AudioProcessingMixin,
# has already performed processing for multi-audio input when the input
# audios are short (the corresponding placeholders may take up fewer
# tokens than the number of audio items)
return not hasattr(self._get_hf_processor(), "audio_token")
return not hasattr(self.info.get_hf_processor(), "audio_token")
@MULTIMODAL_REGISTRY.register_processor(Qwen2AudioMultiModalProcessor)
@MULTIMODAL_REGISTRY.register_processor(
Qwen2AudioMultiModalProcessor,
info=Qwen2AudioProcessingInfo,
dummy_inputs=Qwen2AudioDummyInputsBuilder)
class Qwen2AudioForConditionalGeneration(nn.Module, SupportsMultiModal,
SupportsPP):
......
......@@ -57,11 +57,10 @@ from vllm.multimodal.inputs import (ImageItem, ModalityData,
MultiModalFieldConfig, MultiModalKwargs,
NestedTensors, VideoItem)
from vllm.multimodal.parse import (ImageSize, ModalityDataItems,
MultiModalDataParser)
MultiModalDataItems, MultiModalDataParser)
from vllm.multimodal.processing import (BaseMultiModalProcessor,
MultiModalDataItems, ProcessingMixin,
PromptReplacement)
from vllm.multimodal.profiling import BaseProfilingInfo, ProcessorInputs
BaseProcessingInfo, PromptReplacement)
from vllm.multimodal.profiling import BaseDummyInputsBuilder, ProcessorInputs
from vllm.platforms import _Backend
from vllm.sequence import IntermediateTensors
from vllm.transformers_utils.config import uses_mrope
......@@ -709,12 +708,12 @@ class Qwen2MultiModalDataParser(MultiModalDataParser):
return super()._parse_video_data(data)
class Qwen2VLProcessingMixin(ProcessingMixin):
class Qwen2VLProcessingInfo(BaseProcessingInfo):
def _get_hf_config(self):
def get_hf_config(self):
return self.ctx.get_hf_config(Qwen2VLConfig)
def _get_hf_processor(
def get_hf_processor(
self,
*,
min_pixels: Optional[int] = None,
......@@ -736,18 +735,27 @@ class Qwen2VLProcessingMixin(ProcessingMixin):
return hf_processor
def _get_image_processor(
def get_image_processor(
self,
*,
min_pixels: Optional[int] = None,
max_pixels: Optional[int] = None,
):
hf_processor = self._get_hf_processor(min_pixels=min_pixels,
hf_processor = self.get_hf_processor(min_pixels=min_pixels,
max_pixels=max_pixels)
image_processor = hf_processor.image_processor # type: ignore
assert isinstance(image_processor, Qwen2VLImageProcessor)
return image_processor
def get_supported_mm_limits(self) -> Mapping[str, Optional[int]]:
return {"image": None, "video": None}
def get_mm_max_tokens_per_item(self, seq_len: int) -> Mapping[str, int]:
return {
"image": self.get_max_image_tokens(),
"video": self.get_max_video_tokens(seq_len),
}
def _get_vision_info(
self,
*,
......@@ -755,15 +763,17 @@ class Qwen2VLProcessingMixin(ProcessingMixin):
image_height: int,
num_frames: int = 1,
do_resize: bool = True,
image_processor: Optional[Qwen2VLImageProcessor],
) -> tuple[ImageSize, int]:
hf_config = self._get_hf_config()
if image_processor is None:
image_processor = self.get_image_processor()
hf_config = self.get_hf_config()
vision_config = hf_config.vision_config
patch_size = vision_config.patch_size
merge_size = vision_config.spatial_merge_size
temporal_patch_size = vision_config.temporal_patch_size
image_processor = self._get_image_processor()
if do_resize:
resized_height, resized_width = smart_resize(
height=image_height,
......@@ -787,70 +797,65 @@ class Qwen2VLProcessingMixin(ProcessingMixin):
return preprocessed_size, num_vision_tokens
def _get_num_image_tokens(
def get_num_image_tokens(
self,
*,
image_width: int,
image_height: int,
image_processor: Optional[Qwen2VLImageProcessor],
) -> int:
_, num_image_tokens = self._get_vision_info(
image_width=image_width,
image_height=image_height,
image_processor=image_processor,
)
return num_image_tokens
def _get_num_video_tokens(
def get_num_video_tokens(
self,
*,
image_width: int,
image_height: int,
num_frames: int,
image_processor: Optional[Qwen2VLImageProcessor],
) -> int:
_, num_video_tokens = self._get_vision_info(
image_width=image_width,
image_height=image_height,
num_frames=num_frames,
image_processor=image_processor,
)
return num_video_tokens
class Qwen2VLProfilingInfo(Qwen2VLProcessingMixin, BaseProfilingInfo):
def get_supported_mm_limits(self) -> Mapping[str, Optional[int]]:
return {"image": None, "video": None}
def get_mm_max_tokens_per_item(self, seq_len: int) -> Mapping[str, int]:
return {
"image": self._get_max_image_tokens(),
"video": self._get_max_video_tokens(seq_len),
}
def _get_image_size_with_most_features(self) -> ImageSize:
def get_image_size_with_most_features(self) -> ImageSize:
max_image_size, _ = self._get_vision_info(
image_width=9999999,
image_height=9999999,
image_processor=None,
)
return max_image_size
def _get_max_image_tokens(self) -> int:
target_width, target_height = self._get_image_size_with_most_features()
def get_max_image_tokens(self) -> int:
target_width, target_height = self.get_image_size_with_most_features()
return self._get_num_image_tokens(
return self.get_num_image_tokens(
image_width=target_width,
image_height=target_height,
image_processor=None,
)
def _get_max_video_frames(self, max_tokens: int) -> int:
target_width, target_height = self._get_image_size_with_most_features()
target_width, target_height = self.get_image_size_with_most_features()
num_frames = 0
while True:
next_num_frames = num_frames + 1
next_max_tokens = self._get_num_video_tokens(
next_max_tokens = self.get_num_video_tokens(
image_width=target_width,
image_height=target_height,
num_frames=next_num_frames,
image_processor=None,
)
if next_max_tokens > max_tokens:
......@@ -860,12 +865,12 @@ class Qwen2VLProfilingInfo(Qwen2VLProcessingMixin, BaseProfilingInfo):
return num_frames
def _get_dummy_num_frames(self, seq_len: int) -> int:
def get_num_frames_with_most_features(self, seq_len: int) -> int:
mm_config = self.ctx.get_mm_config()
max_images = mm_config.limit_per_prompt.get("image", 1)
max_videos = mm_config.limit_per_prompt.get("video", 1)
max_image_tokens = self._get_max_image_tokens() * max_images
max_image_tokens = self.get_max_image_tokens() * max_images
max_total_frames = self._get_max_video_frames(seq_len -
max_image_tokens)
......@@ -877,15 +882,19 @@ class Qwen2VLProfilingInfo(Qwen2VLProcessingMixin, BaseProfilingInfo):
return num_frames
def _get_max_video_tokens(self, seq_len: int) -> int:
target_width, target_height = self._get_image_size_with_most_features()
def get_max_video_tokens(self, seq_len: int) -> int:
target_width, target_height = self.get_image_size_with_most_features()
return self._get_num_video_tokens(
return self.get_num_video_tokens(
image_width=target_width,
image_height=target_height,
num_frames=self._get_dummy_num_frames(seq_len),
num_frames=self.get_num_frames_with_most_features(seq_len),
image_processor=None,
)
class Qwen2VLDummyInputsBuilder(BaseDummyInputsBuilder[Qwen2VLProcessingInfo]):
def get_dummy_processor_inputs(
self,
seq_len: int,
......@@ -894,10 +903,14 @@ class Qwen2VLProfilingInfo(Qwen2VLProcessingMixin, BaseProfilingInfo):
num_images = mm_counts.get("image", 0)
num_videos = mm_counts.get("video", 0)
hf_processor = self._get_hf_processor()
hf_processor = self.info.get_hf_processor()
image_token: str = hf_processor.image_token
video_token: str = hf_processor.video_token
target_width, target_height = self._get_image_size_with_most_features()
target_width, target_height = \
self.info.get_image_size_with_most_features()
target_num_frames = \
self.info.get_num_frames_with_most_features(seq_len)
mm_data = {
"image":
......@@ -908,7 +921,7 @@ class Qwen2VLProfilingInfo(Qwen2VLProcessingMixin, BaseProfilingInfo):
self._get_dummy_videos(
width=target_width,
height=target_height,
num_frames=self._get_dummy_num_frames(seq_len),
num_frames=target_num_frames,
num_videos=num_videos,
)
}
......@@ -919,11 +932,8 @@ class Qwen2VLProfilingInfo(Qwen2VLProcessingMixin, BaseProfilingInfo):
)
class Qwen2VLMultiModalProcessor(Qwen2VLProcessingMixin,
BaseMultiModalProcessor):
def _get_profiling_info(self) -> BaseProfilingInfo:
return Qwen2VLProfilingInfo(self.ctx)
class Qwen2VLMultiModalProcessor(BaseMultiModalProcessor[Qwen2VLProcessingInfo]
):
def _get_data_parser(self) -> MultiModalDataParser:
return Qwen2MultiModalDataParser()
......@@ -934,8 +944,9 @@ class Qwen2VLMultiModalProcessor(Qwen2VLProcessingMixin,
hf_processor_mm_kwargs: Mapping[str, Any],
out_mm_kwargs: MultiModalKwargs,
) -> list[PromptReplacement]:
hf_processor = self._get_hf_processor(**hf_processor_mm_kwargs)
image_processor = self._get_image_processor(**hf_processor_mm_kwargs)
hf_processor = self.info.get_hf_processor(**hf_processor_mm_kwargs)
image_processor = self.info.get_image_processor(
**hf_processor_mm_kwargs)
# NOTE: Only Qwen2VLProcessor in transformers 4.47.0 has
# image_token and video_token registered
......@@ -991,7 +1002,9 @@ class Qwen2VLMultiModalProcessor(Qwen2VLProcessingMixin,
)
@MULTIMODAL_REGISTRY.register_processor(Qwen2VLMultiModalProcessor)
@MULTIMODAL_REGISTRY.register_processor(Qwen2VLMultiModalProcessor,
info=Qwen2VLProcessingInfo,
dummy_inputs=Qwen2VLDummyInputsBuilder)
class Qwen2VLForConditionalGeneration(nn.Module, SupportsMultiModal,
SupportsLoRA, SupportsPP):
packed_modules_mapping = {
......
......@@ -24,11 +24,10 @@ from vllm.model_executor.sampling_metadata import SamplingMetadata
from vllm.multimodal import MULTIMODAL_REGISTRY
from vllm.multimodal.inputs import (MultiModalFieldConfig, MultiModalKwargs,
NestedTensors)
from vllm.multimodal.parse import MultiModalDataParser
from vllm.multimodal.parse import MultiModalDataItems, MultiModalDataParser
from vllm.multimodal.processing import (BaseMultiModalProcessor,
MultiModalDataItems, ProcessingMixin,
PromptReplacement)
from vllm.multimodal.profiling import BaseProfilingInfo, ProcessorInputs
BaseProcessingInfo, PromptReplacement)
from vllm.multimodal.profiling import BaseDummyInputsBuilder, ProcessorInputs
from vllm.sequence import IntermediateTensors
from vllm.transformers_utils.configs.ultravox import UltravoxConfig
......@@ -59,9 +58,9 @@ UltravoxAudioInputs = Union[UltravoxAudioFeatureInputs,
UltravoxAudioEmbeddingInputs]
class UltravoxProcessingMixin(ProcessingMixin):
class UltravoxProcessingInfo(BaseProcessingInfo):
def _get_hf_processor(
def get_hf_processor(
self,
*,
# Ignored in initialization
......@@ -76,37 +75,38 @@ class UltravoxProcessingMixin(ProcessingMixin):
hf_processor.audio_token_replacement = _AUDIO_PLACEHOLDER_OVERRIDE
return hf_processor
def _get_feature_extractor(
def get_feature_extractor(
self,
*,
# Ignored in initialization
sampling_rate: Optional[int] = None,
) -> WhisperFeatureExtractor:
hf_processor = self._get_hf_processor(sampling_rate=sampling_rate)
hf_processor = self.get_hf_processor(sampling_rate=sampling_rate)
audio_processor = hf_processor.audio_processor # type: ignore
feature_extractor = audio_processor.feature_extractor # type: ignore
assert isinstance(feature_extractor, WhisperFeatureExtractor)
return feature_extractor
class UltravoxProfilingInfo(UltravoxProcessingMixin, BaseProfilingInfo):
def get_supported_mm_limits(self) -> Mapping[str, Optional[int]]:
return {"audio": None}
def get_mm_max_tokens_per_item(self, seq_len: int) -> Mapping[str, int]:
feature_extractor = self._get_feature_extractor()
feature_extractor = self.get_feature_extractor()
max_audio_tokens = math.ceil(feature_extractor.chunk_length *
_AUDIO_TOKENS_PER_SECOND)
return {"audio": max_audio_tokens}
class UltravoxDummyInputsBuilder(BaseDummyInputsBuilder[UltravoxProcessingInfo]
):
def get_dummy_processor_inputs(
self,
seq_len: int,
mm_counts: Mapping[str, int],
) -> ProcessorInputs:
feature_extractor = self._get_feature_extractor()
feature_extractor = self.info.get_feature_extractor()
sampling_rate = feature_extractor.sampling_rate
audio_len = feature_extractor.chunk_length * sampling_rate
......@@ -123,14 +123,11 @@ class UltravoxProfilingInfo(UltravoxProcessingMixin, BaseProfilingInfo):
)
class UltravoxMultiModalProcessor(UltravoxProcessingMixin,
BaseMultiModalProcessor):
def _get_profiling_info(self) -> BaseProfilingInfo:
return UltravoxProfilingInfo(self.ctx)
class UltravoxMultiModalProcessor(
BaseMultiModalProcessor[UltravoxProcessingInfo]):
def _get_data_parser(self) -> MultiModalDataParser:
feature_extractor = self._get_feature_extractor()
feature_extractor = self.info.get_feature_extractor()
return MultiModalDataParser(target_sr=feature_extractor.sampling_rate)
def _call_hf_processor(
......@@ -141,7 +138,7 @@ class UltravoxMultiModalProcessor(UltravoxProcessingMixin,
) -> BatchFeature:
# Text-only input not supported in composite processor
if not mm_data:
tokenizer = self._get_tokenizer()
tokenizer = self.info.get_tokenizer()
prompt_ids = tokenizer.encode(
prompt,
......@@ -160,7 +157,7 @@ class UltravoxMultiModalProcessor(UltravoxProcessingMixin,
mm_kwargs=mm_kwargs,
)
feature_extractor = self._get_feature_extractor()
feature_extractor = self.info.get_feature_extractor()
mm_kwargs = dict(
**mm_kwargs,
sampling_rate=feature_extractor.sampling_rate,
......@@ -208,7 +205,7 @@ class UltravoxMultiModalProcessor(UltravoxProcessingMixin,
hf_processor_mm_kwargs: Mapping[str, Any],
out_mm_kwargs: MultiModalKwargs,
) -> list[PromptReplacement]:
hf_processor = self._get_hf_processor(**hf_processor_mm_kwargs)
hf_processor = self.info.get_hf_processor(**hf_processor_mm_kwargs)
placeholder = hf_processor.audio_token_replacement # type: ignore
def get_replacement_ultravox(item_idx: int):
......@@ -342,7 +339,10 @@ class ModifiedWhisperEncoder(WhisperEncoder):
return hidden_states
@MULTIMODAL_REGISTRY.register_processor(UltravoxMultiModalProcessor)
@MULTIMODAL_REGISTRY.register_processor(UltravoxMultiModalProcessor,
info=UltravoxProcessingInfo,
dummy_inputs=UltravoxDummyInputsBuilder
)
class UltravoxModel(nn.Module, SupportsMultiModal, SupportsPP):
hf_to_vllm_mapper = WeightsMapper(
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment