Unverified commit b024a42e, authored by Cyrus Leung and committed by GitHub
Browse files

[Core] Move multimodal placeholder from chat utils to model definition (#20355)


Signed-off-by: DarkLight1337 <tlleungac@connect.ust.hk>
parent cb97f2bf
...@@ -10,6 +10,22 @@ This document walks you through the steps to extend a basic model so that it acc ...@@ -10,6 +10,22 @@ This document walks you through the steps to extend a basic model so that it acc
It is assumed that you have already implemented the model in vLLM according to [these steps][new-model-basic]. It is assumed that you have already implemented the model in vLLM according to [these steps][new-model-basic].
Further update the model as follows: Further update the model as follows:
- Implement [get_placeholder_str][vllm.model_executor.models.interfaces.SupportsMultiModal.get_placeholder_str] to define the placeholder string which is used to represent the multi-modal item in the text prompt. This should be consistent with the chat template of the model.
??? Code
```python
class YourModelForImage2Seq(nn.Module):
...
@classmethod
def get_placeholder_str(cls, modality: str, i: int) -> Optional[str]:
if modality.startswith("image"):
return "<image>"
raise ValueError("Only image modality is supported")
```
- Reserve a keyword parameter in [forward][torch.nn.Module.forward] for each input tensor that corresponds to a multi-modal input, as shown in the following example: - Reserve a keyword parameter in [forward][torch.nn.Module.forward] for each input tensor that corresponds to a multi-modal input, as shown in the following example:
```diff ```diff
......
...@@ -33,7 +33,6 @@ class RequestOutput: ...@@ -33,7 +33,6 @@ class RequestOutput:
class MockModelConfig: class MockModelConfig:
use_async_output_proc = True use_async_output_proc = True
media_io_kwargs: dict[str, dict[str, Any]] = field(default_factory=dict) media_io_kwargs: dict[str, dict[str, Any]] = field(default_factory=dict)
mm_placeholder_str_override: dict[str, str] = field(default_factory=dict)
class MockEngine: class MockEngine:
......
...@@ -263,26 +263,6 @@ def test_media_io_kwargs_parser(arg, expected): ...@@ -263,26 +263,6 @@ def test_media_io_kwargs_parser(arg, expected):
assert args.media_io_kwargs == expected assert args.media_io_kwargs == expected
@pytest.mark.parametrize(("arg", "expected"), [
(None, dict()),
('{"video":"<|video_placeholder|>"}', {
"video": "<|video_placeholder|>"
}),
('{"video":"<|video_placeholder|>", "image": "<|image_placeholder|>"}', {
"video": "<|video_placeholder|>",
"image": "<|image_placeholder|>"
}),
])
def test_mm_placeholder_str_override_parser(arg, expected):
parser = EngineArgs.add_cli_args(FlexibleArgumentParser())
if arg is None:
args = parser.parse_args([])
else:
args = parser.parse_args(["--mm-placeholder-str-override", arg])
assert args.mm_placeholder_str_override == expected
def test_compilation_config(): def test_compilation_config():
parser = EngineArgs.add_cli_args(FlexibleArgumentParser()) parser = EngineArgs.add_cli_args(FlexibleArgumentParser())
......
...@@ -41,7 +41,6 @@ class MockModelConfig: ...@@ -41,7 +41,6 @@ class MockModelConfig:
encoder_config = None encoder_config = None
generation_config: str = "auto" generation_config: str = "auto"
media_io_kwargs: dict[str, dict[str, Any]] = field(default_factory=dict) media_io_kwargs: dict[str, dict[str, Any]] = field(default_factory=dict)
mm_placeholder_str_override: dict[str, str] = field(default_factory=dict)
def get_diff_sampling_param(self): def get_diff_sampling_param(self):
return self.diff_sampling_param or {} return self.diff_sampling_param or {}
......
...@@ -350,8 +350,6 @@ class ModelConfig: ...@@ -350,8 +350,6 @@ class ModelConfig:
"""Additional args passed to process media inputs, keyed by modalities. """Additional args passed to process media inputs, keyed by modalities.
For example, to set num_frames for video, set For example, to set num_frames for video, set
`--media-io-kwargs '{"video": {"num_frames": 40} }'` """ `--media-io-kwargs '{"video": {"num_frames": 40} }'` """
mm_placeholder_str_override: dict[str, str] = field(default_factory=dict)
"""Optionally override placeholder string for given modalities."""
use_async_output_proc: bool = True use_async_output_proc: bool = True
"""Whether to use async output processor.""" """Whether to use async output processor."""
config_format: Union[str, ConfigFormat] = ConfigFormat.AUTO.value config_format: Union[str, ConfigFormat] = ConfigFormat.AUTO.value
...@@ -661,7 +659,7 @@ class ModelConfig: ...@@ -661,7 +659,7 @@ class ModelConfig:
return self._architecture return self._architecture
@property @property
def model_info(self) -> dict[str, Any]: def model_info(self):
return self._model_info return self._model_info
def maybe_pull_model_tokenizer_for_s3(self, model: str, def maybe_pull_model_tokenizer_for_s3(self, model: str,
...@@ -701,7 +699,6 @@ class ModelConfig: ...@@ -701,7 +699,6 @@ class ModelConfig:
return MultiModalConfig( return MultiModalConfig(
limit_per_prompt=self.limit_mm_per_prompt, limit_per_prompt=self.limit_mm_per_prompt,
media_io_kwargs=self.media_io_kwargs, media_io_kwargs=self.media_io_kwargs,
mm_placeholder_str_override=self.mm_placeholder_str_override,
mm_processor_kwargs=self.mm_processor_kwargs, mm_processor_kwargs=self.mm_processor_kwargs,
disable_mm_preprocessor_cache=self. disable_mm_preprocessor_cache=self.
disable_mm_preprocessor_cache) disable_mm_preprocessor_cache)
...@@ -3096,9 +3093,6 @@ class MultiModalConfig: ...@@ -3096,9 +3093,6 @@ class MultiModalConfig:
For example, to set num_frames for video, set For example, to set num_frames for video, set
`--media-io-kwargs '{"video": {"num_frames": 40} }'` """ `--media-io-kwargs '{"video": {"num_frames": 40} }'` """
mm_placeholder_str_override: dict[str, str] = field(default_factory=dict)
"""Optionally override placeholder string for given modalities."""
mm_processor_kwargs: Optional[dict[str, object]] = None mm_processor_kwargs: Optional[dict[str, object]] = None
""" """
Overrides for the multi-modal processor obtained from Overrides for the multi-modal processor obtained from
......
...@@ -373,8 +373,6 @@ class EngineArgs: ...@@ -373,8 +373,6 @@ class EngineArgs:
media_io_kwargs: dict[str, dict[str, media_io_kwargs: dict[str, dict[str,
Any]] = get_field(MultiModalConfig, Any]] = get_field(MultiModalConfig,
"media_io_kwargs") "media_io_kwargs")
mm_placeholder_str_override: dict[str, str] = \
get_field(MultiModalConfig, "mm_placeholder_str_override")
mm_processor_kwargs: Optional[Dict[str, Any]] = \ mm_processor_kwargs: Optional[Dict[str, Any]] = \
MultiModalConfig.mm_processor_kwargs MultiModalConfig.mm_processor_kwargs
disable_mm_preprocessor_cache: bool = \ disable_mm_preprocessor_cache: bool = \
...@@ -759,9 +757,6 @@ class EngineArgs: ...@@ -759,9 +757,6 @@ class EngineArgs:
**multimodal_kwargs["limit_per_prompt"]) **multimodal_kwargs["limit_per_prompt"])
multimodal_group.add_argument("--media-io-kwargs", multimodal_group.add_argument("--media-io-kwargs",
**multimodal_kwargs["media_io_kwargs"]) **multimodal_kwargs["media_io_kwargs"])
multimodal_group.add_argument(
"--mm-placeholder-str-override",
**multimodal_kwargs["mm_placeholder_str_override"])
multimodal_group.add_argument( multimodal_group.add_argument(
"--mm-processor-kwargs", "--mm-processor-kwargs",
**multimodal_kwargs["mm_processor_kwargs"]) **multimodal_kwargs["mm_processor_kwargs"])
...@@ -987,7 +982,6 @@ class EngineArgs: ...@@ -987,7 +982,6 @@ class EngineArgs:
served_model_name=self.served_model_name, served_model_name=self.served_model_name,
limit_mm_per_prompt=self.limit_mm_per_prompt, limit_mm_per_prompt=self.limit_mm_per_prompt,
media_io_kwargs=self.media_io_kwargs, media_io_kwargs=self.media_io_kwargs,
mm_placeholder_str_override=self.mm_placeholder_str_override,
use_async_output_proc=not self.disable_async_output_proc, use_async_output_proc=not self.disable_async_output_proc,
config_format=self.config_format, config_format=self.config_format,
mm_processor_kwargs=self.mm_processor_kwargs, mm_processor_kwargs=self.mm_processor_kwargs,
......
...@@ -6,7 +6,7 @@ import json ...@@ -6,7 +6,7 @@ import json
from abc import ABC, abstractmethod from abc import ABC, abstractmethod
from collections import defaultdict, deque from collections import defaultdict, deque
from collections.abc import Awaitable, Iterable from collections.abc import Awaitable, Iterable
from functools import cache, lru_cache, partial from functools import cached_property, lru_cache, partial
from pathlib import Path from pathlib import Path
from typing import (Any, Callable, Generic, Literal, Optional, TypeVar, Union, from typing import (Any, Callable, Generic, Literal, Optional, TypeVar, Union,
cast) cast)
...@@ -37,6 +37,8 @@ from typing_extensions import Required, TypeAlias, TypedDict ...@@ -37,6 +37,8 @@ from typing_extensions import Required, TypeAlias, TypedDict
from vllm.config import ModelConfig from vllm.config import ModelConfig
from vllm.logger import init_logger from vllm.logger import init_logger
from vllm.model_executor.model_loader import get_model_cls
from vllm.model_executor.models import SupportsMultiModal
from vllm.multimodal import MULTIMODAL_REGISTRY, MultiModalDataDict from vllm.multimodal import MULTIMODAL_REGISTRY, MultiModalDataDict
from vllm.multimodal.utils import MediaConnector from vllm.multimodal.utils import MediaConnector
# yapf: disable # yapf: disable
...@@ -492,6 +494,10 @@ class BaseMultiModalItemTracker(ABC, Generic[_T]): ...@@ -492,6 +494,10 @@ class BaseMultiModalItemTracker(ABC, Generic[_T]):
def model_config(self) -> ModelConfig: def model_config(self) -> ModelConfig:
return self._model_config return self._model_config
@cached_property
def model_cls(self):
return get_model_cls(self.model_config)
@property @property
def allowed_local_media_path(self): def allowed_local_media_path(self):
return self._model_config.allowed_local_media_path return self._model_config.allowed_local_media_path
...@@ -500,89 +506,6 @@ class BaseMultiModalItemTracker(ABC, Generic[_T]): ...@@ -500,89 +506,6 @@ class BaseMultiModalItemTracker(ABC, Generic[_T]):
def mm_registry(self): def mm_registry(self):
return MULTIMODAL_REGISTRY return MULTIMODAL_REGISTRY
@staticmethod
@cache
def _cached_token_str(tokenizer: AnyTokenizer, token_index: int) -> str:
return tokenizer.decode(token_index)
def _placeholder_str(self, modality: ModalityStr,
current_count: int) -> Optional[str]:
if modality in self._model_config.mm_placeholder_str_override:
return self._model_config.mm_placeholder_str_override[modality]
# TODO: Let user specify how to insert image tokens into prompt
# (similar to chat template)
hf_config = self._model_config.hf_config
model_type = hf_config.model_type
if modality in ("image", "image_embeds"):
if model_type == "chatglm":
return "<|begin_of_image|><|endoftext|><|end_of_image|>"
if model_type == "glm4v":
return "<|begin_of_image|><|image|><|end_of_image|>"
if model_type in ("phi3_v", "phi4mm"):
return f"<|image_{current_count}|>"
if model_type in ("minicpmo", "minicpmv"):
return "(<image>./</image>)"
if model_type in ("blip-2", "florence2", "fuyu", "paligemma",
"pixtral", "mistral3"):
# These models do not use image tokens in the prompt
return None
if model_type == "qwen":
return f"Picture {current_count}: <img></img>"
if model_type.startswith("llava"):
return self._cached_token_str(self._tokenizer,
hf_config.image_token_index)
if model_type in ("aya_vision", "chameleon", "deepseek_vl_v2",
"internvl_chat", "ovis", "skywork_chat",
"NVLM_D", "h2ovl_chat", "idefics3", "smolvlm"):
return "<image>"
if model_type in ("mllama", "llama4"):
return "<|image|>"
if model_type in ("qwen2_vl", "qwen2_5_vl", "keye", "Keye"):
return "<|vision_start|><|image_pad|><|vision_end|>"
if model_type == "qwen2_5_omni":
return "<|vision_start|><|IMAGE|><|vision_end|>"
if model_type == "molmo":
return ""
if model_type == "aria":
return "<|fim_prefix|><|img|><|fim_suffix|>"
if model_type == "gemma3":
return "<start_of_image>"
if model_type == "kimi_vl":
return "<|media_start|>image<|media_content|><|media_pad|><|media_end|>" # noqa: E501
raise TypeError(f"Unknown {modality} model type: {model_type}")
elif modality == "audio":
if model_type in ("ultravox", "granite_speech"):
return "<|audio|>"
if model_type == "phi4mm":
return f"<|audio_{current_count}|>"
if model_type in ("qwen2_audio", "qwen2_5_omni"):
return (f"Audio {current_count}: "
f"<|audio_bos|><|AUDIO|><|audio_eos|>")
if model_type == "minicpmo":
return "(<audio>./</audio>)"
raise TypeError(f"Unknown model type: {model_type}")
elif modality == "video":
if model_type == "internvl_chat":
return "<video>"
if model_type == "glm4v":
return "<|begin_of_video|><|video|><|end_of_video|>"
if model_type in ("qwen2_vl", "qwen2_5_vl", "keye", "Keye"):
return "<|vision_start|><|video_pad|><|vision_end|>"
if model_type == "qwen2_5_omni":
return "<|vision_start|><|VIDEO|><|vision_end|>"
if model_type in ("minicpmo", "minicpmv"):
return "(<video>./</video>)"
if model_type.startswith("llava"):
return self._cached_token_str(self._tokenizer,
hf_config.video_token_index)
raise TypeError(f"Unknown {modality} model type: {model_type}")
else:
raise TypeError(f"Unknown modality: {modality}")
def add(self, modality: ModalityStr, item: _T) -> Optional[str]: def add(self, modality: ModalityStr, item: _T) -> Optional[str]:
""" """
Add a multi-modal item to the current prompt and returns the Add a multi-modal item to the current prompt and returns the
...@@ -590,6 +513,7 @@ class BaseMultiModalItemTracker(ABC, Generic[_T]): ...@@ -590,6 +513,7 @@ class BaseMultiModalItemTracker(ABC, Generic[_T]):
""" """
mm_registry = self.mm_registry mm_registry = self.mm_registry
model_config = self.model_config model_config = self.model_config
model_cls = cast(SupportsMultiModal, self.model_cls)
input_modality = modality.replace("_embeds", "") input_modality = modality.replace("_embeds", "")
...@@ -614,7 +538,7 @@ class BaseMultiModalItemTracker(ABC, Generic[_T]): ...@@ -614,7 +538,7 @@ class BaseMultiModalItemTracker(ABC, Generic[_T]):
self._items_by_modality[modality].append(item) self._items_by_modality[modality].append(item)
return self._placeholder_str(modality, current_count) return model_cls.get_placeholder_str(modality, current_count)
@abstractmethod @abstractmethod
def create_parser(self) -> "BaseMultiModalContentParser": def create_parser(self) -> "BaseMultiModalContentParser":
......
...@@ -5,6 +5,7 @@ import io ...@@ -5,6 +5,7 @@ import io
import math import math
import time import time
from collections.abc import AsyncGenerator from collections.abc import AsyncGenerator
from functools import cached_property
from math import ceil from math import ceil
from typing import Callable, Literal, Optional, TypeVar, Union, cast from typing import Callable, Literal, Optional, TypeVar, Union, cast
...@@ -24,7 +25,8 @@ from vllm.entrypoints.openai.serving_engine import (OpenAIServing, ...@@ -24,7 +25,8 @@ from vllm.entrypoints.openai.serving_engine import (OpenAIServing,
from vllm.entrypoints.openai.serving_models import OpenAIServingModels from vllm.entrypoints.openai.serving_models import OpenAIServingModels
from vllm.inputs.data import PromptType from vllm.inputs.data import PromptType
from vllm.logger import init_logger from vllm.logger import init_logger
from vllm.model_executor.model_loader.utils import get_model_architecture from vllm.model_executor.model_loader import get_model_cls
from vllm.model_executor.models import SupportsTranscription
from vllm.outputs import RequestOutput from vllm.outputs import RequestOutput
from vllm.transformers_utils.processor import cached_get_processor from vllm.transformers_utils.processor import cached_get_processor
from vllm.utils import PlaceholderModule from vllm.utils import PlaceholderModule
...@@ -76,24 +78,29 @@ class OpenAISpeechToText(OpenAIServing): ...@@ -76,24 +78,29 @@ class OpenAISpeechToText(OpenAIServing):
self.model_sr = processor.feature_extractor.sampling_rate self.model_sr = processor.feature_extractor.sampling_rate
self.hop_length = processor.feature_extractor.hop_length self.hop_length = processor.feature_extractor.hop_length
self.task_type = task_type self.task_type = task_type
self.model_cls, _ = get_model_architecture(model_config)
if self.default_sampling_params: if self.default_sampling_params:
logger.info( logger.info(
"Overwriting default completion sampling param with: %s", "Overwriting default completion sampling param with: %s",
self.default_sampling_params) self.default_sampling_params)
@cached_property
def model_cls(self):
return get_model_cls(self.model_config)
async def _preprocess_speech_to_text( async def _preprocess_speech_to_text(
self, self,
request: SpeechToTextRequest, request: SpeechToTextRequest,
audio_data: bytes, audio_data: bytes,
) -> tuple[list[PromptType], float]: ) -> tuple[list[PromptType], float]:
model_cls = cast(SupportsTranscription, self.model_cls)
# Validate request # Validate request
# TODO language should be optional and can be guessed. # TODO language should be optional and can be guessed.
# For now we default to en. See # For now we default to en. See
# https://github.com/huggingface/transformers/blob/main/src/transformers/models/whisper/generation_whisper.py#L1520 # https://github.com/huggingface/transformers/blob/main/src/transformers/models/whisper/generation_whisper.py#L1520
lang = request.language or "en" lang = request.language or "en"
self.model_cls.validate_language(lang) # type: ignore[attr-defined] model_cls.validate_language(lang)
if len(audio_data) / 1024**2 > MAX_AUDIO_CLIP_FILESIZE_MB: if len(audio_data) / 1024**2 > MAX_AUDIO_CLIP_FILESIZE_MB:
raise ValueError("Maximum file size exceeded.") raise ValueError("Maximum file size exceeded.")
...@@ -117,9 +124,8 @@ class OpenAISpeechToText(OpenAIServing): ...@@ -117,9 +124,8 @@ class OpenAISpeechToText(OpenAIServing):
}, },
}, },
"decoder_prompt": "decoder_prompt":
self.model_cls. model_cls.get_decoder_prompt(lang, self.task_type,
get_decoder_prompt( # type: ignore[attr-defined] request.prompt)
lang, self.task_type, request.prompt)
} }
prompts.append(cast(PromptType, prompt)) prompts.append(cast(PromptType, prompt))
return prompts, duration return prompts, duration
......
...@@ -18,7 +18,7 @@ from vllm.model_executor.model_loader.sharded_state_loader import ( ...@@ -18,7 +18,7 @@ from vllm.model_executor.model_loader.sharded_state_loader import (
ShardedStateLoader) ShardedStateLoader)
from vllm.model_executor.model_loader.tensorizer_loader import TensorizerLoader from vllm.model_executor.model_loader.tensorizer_loader import TensorizerLoader
from vllm.model_executor.model_loader.utils import ( from vllm.model_executor.model_loader.utils import (
get_architecture_class_name, get_model_architecture) get_architecture_class_name, get_model_architecture, get_model_cls)
def get_model_loader(load_config: LoadConfig) -> BaseModelLoader: def get_model_loader(load_config: LoadConfig) -> BaseModelLoader:
...@@ -65,6 +65,7 @@ __all__ = [ ...@@ -65,6 +65,7 @@ __all__ = [
"get_model_loader", "get_model_loader",
"get_architecture_class_name", "get_architecture_class_name",
"get_model_architecture", "get_model_architecture",
"get_model_cls",
"BaseModelLoader", "BaseModelLoader",
"BitsAndBytesModelLoader", "BitsAndBytesModelLoader",
"GGUFModelLoader", "GGUFModelLoader",
......
...@@ -13,7 +13,7 @@ import time ...@@ -13,7 +13,7 @@ import time
from collections.abc import Generator from collections.abc import Generator
from dataclasses import dataclass from dataclasses import dataclass
from functools import partial from functools import partial
from typing import Any, BinaryIO, Optional, Union from typing import TYPE_CHECKING, Any, BinaryIO, Optional, Union
import regex as re import regex as re
import torch import torch
...@@ -24,12 +24,14 @@ from transformers import PretrainedConfig ...@@ -24,12 +24,14 @@ from transformers import PretrainedConfig
import vllm.envs as envs import vllm.envs as envs
from vllm.config import (ModelConfig, ParallelConfig, VllmConfig, from vllm.config import (ModelConfig, ParallelConfig, VllmConfig,
set_current_vllm_config) set_current_vllm_config)
from vllm.engine.arg_utils import EngineArgs
from vllm.logger import init_logger from vllm.logger import init_logger
from vllm.model_executor.layers.vocab_parallel_embedding import ( from vllm.model_executor.layers.vocab_parallel_embedding import (
VocabParallelEmbedding) VocabParallelEmbedding)
from vllm.utils import FlexibleArgumentParser, PlaceholderModule from vllm.utils import FlexibleArgumentParser, PlaceholderModule
if TYPE_CHECKING:
from vllm.engine.arg_utils import EngineArgs
try: try:
from tensorizer import (DecryptionParams, EncryptionParams, from tensorizer import (DecryptionParams, EncryptionParams,
TensorDeserializer, TensorSerializer) TensorDeserializer, TensorSerializer)
...@@ -503,7 +505,7 @@ def serialize_vllm_model( ...@@ -503,7 +505,7 @@ def serialize_vllm_model(
return model return model
def tensorize_vllm_model(engine_args: EngineArgs, def tensorize_vllm_model(engine_args: "EngineArgs",
tensorizer_config: TensorizerConfig, tensorizer_config: TensorizerConfig,
generate_keyfile: bool = True): generate_keyfile: bool = True):
"""Utility to load a model and then serialize it with Tensorizer """Utility to load a model and then serialize it with Tensorizer
......
...@@ -253,6 +253,10 @@ def get_model_architecture( ...@@ -253,6 +253,10 @@ def get_model_architecture(
return model_cls, arch return model_cls, arch
def get_model_cls(model_config: ModelConfig) -> type[nn.Module]:
return get_model_architecture(model_config)[0]
def get_architecture_class_name(model_config: ModelConfig) -> str: def get_architecture_class_name(model_config: ModelConfig) -> str:
return get_model_architecture(model_config)[1] return get_model_architecture(model_config)[1]
......
...@@ -2,9 +2,9 @@ ...@@ -2,9 +2,9 @@
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project # SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from .interfaces import (HasInnerState, SupportsLoRA, SupportsMultiModal, from .interfaces import (HasInnerState, SupportsLoRA, SupportsMultiModal,
SupportsPP, SupportsV0Only, has_inner_state, SupportsPP, SupportsTranscription, SupportsV0Only,
supports_lora, supports_multimodal, supports_pp, has_inner_state, supports_lora, supports_multimodal,
supports_v0_only) supports_pp, supports_transcription, supports_v0_only)
from .interfaces_base import (VllmModelForPooling, VllmModelForTextGeneration, from .interfaces_base import (VllmModelForPooling, VllmModelForTextGeneration,
is_pooling_model, is_text_generation_model) is_pooling_model, is_text_generation_model)
from .registry import ModelRegistry from .registry import ModelRegistry
...@@ -23,6 +23,8 @@ __all__ = [ ...@@ -23,6 +23,8 @@ __all__ = [
"supports_multimodal", "supports_multimodal",
"SupportsPP", "SupportsPP",
"supports_pp", "supports_pp",
"SupportsTranscription",
"supports_transcription",
"SupportsV0Only", "SupportsV0Only",
"supports_v0_only", "supports_v0_only",
] ]
...@@ -499,6 +499,13 @@ class AriaForConditionalGeneration(nn.Module, SupportsMultiModal): ...@@ -499,6 +499,13 @@ class AriaForConditionalGeneration(nn.Module, SupportsMultiModal):
}, },
) )
@classmethod
def get_placeholder_str(cls, modality: str, i: int) -> Optional[str]:
if modality.startswith("image"):
return "<|fim_prefix|><|img|><|fim_suffix|>"
raise ValueError("Only image modality is supported")
def __init__( def __init__(
self, self,
vllm_config: VllmConfig, vllm_config: VllmConfig,
......
...@@ -304,6 +304,13 @@ class AyaVisionForConditionalGeneration(nn.Module, SupportsMultiModal, ...@@ -304,6 +304,13 @@ class AyaVisionForConditionalGeneration(nn.Module, SupportsMultiModal,
"lm_head.": "language_model.lm_head.", "lm_head.": "language_model.lm_head.",
}) })
@classmethod
def get_placeholder_str(cls, modality: str, i: int) -> Optional[str]:
if modality.startswith("image"):
return "<image>"
raise ValueError("Only image modality is supported")
def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
super().__init__() super().__init__()
config: AyaVisionConfig = vllm_config.model_config.hf_config config: AyaVisionConfig = vllm_config.model_config.hf_config
......
...@@ -507,6 +507,13 @@ class Blip2MultiModalProcessor(BaseMultiModalProcessor[Blip2ProcessingInfo]): ...@@ -507,6 +507,13 @@ class Blip2MultiModalProcessor(BaseMultiModalProcessor[Blip2ProcessingInfo]):
class Blip2ForConditionalGeneration(nn.Module, SupportsMultiModal, SupportsPP, class Blip2ForConditionalGeneration(nn.Module, SupportsMultiModal, SupportsPP,
SupportsQuant): SupportsQuant):
@classmethod
def get_placeholder_str(cls, modality: str, i: int) -> Optional[str]:
if modality.startswith("image"):
return None
raise ValueError("Only image modality is supported")
def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
super().__init__() super().__init__()
......
...@@ -933,6 +933,13 @@ class ChameleonForConditionalGeneration(nn.Module, SupportsMultiModal, ...@@ -933,6 +933,13 @@ class ChameleonForConditionalGeneration(nn.Module, SupportsMultiModal,
"gate_up_proj": ["gate_proj", "up_proj"] "gate_up_proj": ["gate_proj", "up_proj"]
} }
@classmethod
def get_placeholder_str(cls, modality: str, i: int) -> Optional[str]:
if modality.startswith("image"):
return "<image>"
raise ValueError("Only image modality is supported")
def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
super().__init__() super().__init__()
config = vllm_config.model_config.hf_config config = vllm_config.model_config.hf_config
......
...@@ -315,6 +315,13 @@ class DeepseekVLV2ForCausalLM(nn.Module, SupportsMultiModal, SupportsPP): ...@@ -315,6 +315,13 @@ class DeepseekVLV2ForCausalLM(nn.Module, SupportsMultiModal, SupportsPP):
"language.": "language_model.", "language.": "language_model.",
}) })
@classmethod
def get_placeholder_str(cls, modality: str, i: int) -> Optional[str]:
if modality.startswith("image"):
return "<image>"
raise ValueError("Only image modality is supported")
def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
super().__init__() super().__init__()
config: DeepseekVLV2Config = vllm_config.model_config.hf_config config: DeepseekVLV2Config = vllm_config.model_config.hf_config
......
...@@ -877,6 +877,13 @@ class Florence2MultiModalProcessor( ...@@ -877,6 +877,13 @@ class Florence2MultiModalProcessor(
class Florence2ForConditionalGeneration(nn.Module, SupportsMultiModal, class Florence2ForConditionalGeneration(nn.Module, SupportsMultiModal,
SupportsV0Only): SupportsV0Only):
@classmethod
def get_placeholder_str(cls, modality: str, i: int) -> Optional[str]:
if modality.startswith("image"):
return None
raise ValueError("Only image modality is supported")
def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
super().__init__() super().__init__()
config = vllm_config.model_config.hf_config config = vllm_config.model_config.hf_config
......
...@@ -254,6 +254,13 @@ class FuyuForCausalLM(nn.Module, SupportsMultiModal, SupportsPP): ...@@ -254,6 +254,13 @@ class FuyuForCausalLM(nn.Module, SupportsMultiModal, SupportsPP):
"lm_head.": "language_model.lm_head.", "lm_head.": "language_model.lm_head.",
}) })
@classmethod
def get_placeholder_str(cls, modality: str, i: int) -> Optional[str]:
if modality.startswith("image"):
return None
raise ValueError("Only image modality is supported")
def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
super().__init__() super().__init__()
config = vllm_config.model_config.hf_config config = vllm_config.model_config.hf_config
......
...@@ -483,6 +483,13 @@ class Gemma3ForConditionalGeneration(nn.Module, SupportsMultiModal, SupportsPP, ...@@ -483,6 +483,13 @@ class Gemma3ForConditionalGeneration(nn.Module, SupportsMultiModal, SupportsPP,
"lm_head.": "language_model.lm_head.", "lm_head.": "language_model.lm_head.",
}) })
@classmethod
def get_placeholder_str(cls, modality: str, i: int) -> Optional[str]:
if modality.startswith("image"):
return "<start_of_image>"
raise ValueError("Only image modality is supported")
def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
super().__init__() super().__init__()
config = vllm_config.model_config.hf_config config = vllm_config.model_config.hf_config
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment