Unverified Commit 9ad7f89f authored by Isotr0py's avatar Isotr0py Committed by GitHub
Browse files

[Models]: Make Multimodal config implicit in ViT implementation (#31972)


Signed-off-by: default avatarIsotr0py <mozf@mail2.sysu.edu.cn>
parent 6450b536
......@@ -205,7 +205,6 @@ class MiniMaxVL01ForConditionalGeneration(nn.Module, SupportsMultiModal, Support
self.vision_tower = init_vision_tower_for_llava(
config,
quant_config=quant_config,
multimodal_config=multimodal_config,
require_post_norm=False,
prefix=maybe_prefix(prefix, "vision_tower"),
)
......
......@@ -16,7 +16,7 @@ from transformers import (
from transformers.models.pixtral import PixtralProcessor
from vllm.config import VllmConfig
from vllm.config.multimodal import BaseDummyOptions, MultiModalConfig
from vllm.config.multimodal import BaseDummyOptions
from vllm.model_executor.layers.activation import get_act_fn
from vllm.model_executor.layers.layernorm import RMSNorm
from vllm.model_executor.layers.linear import ColumnParallelLinear, RowParallelLinear
......@@ -382,7 +382,6 @@ def _get_num_hidden_layers(hf_config: LlavaLikeConfig) -> int:
def init_vision_tower_for_llava(
hf_config: LlavaLikeConfig,
quant_config: QuantizationConfig | None,
multimodal_config: MultiModalConfig | None,
*,
require_post_norm: bool | None = None,
prefix: str = "",
......@@ -397,7 +396,6 @@ def init_vision_tower_for_llava(
return PixtralHFVisionModel(
vision_config,
quant_config=quant_config,
multimodal_config=multimodal_config,
num_hidden_layers_override=num_hidden_layers,
require_post_norm=require_post_norm,
prefix=prefix,
......@@ -461,7 +459,6 @@ class Mistral3ForConditionalGeneration(
self.vision_tower = init_vision_tower_for_llava(
config,
quant_config=quant_config,
multimodal_config=multimodal_config,
require_post_norm=False,
prefix=maybe_prefix(prefix, "vision_tower"),
)
......
......@@ -52,7 +52,6 @@ import torch.nn.functional as F
from transformers.activations import ACT2FN
from transformers.modeling_utils import PreTrainedModel
from vllm.config import MultiModalConfig
from vllm.distributed import divide, get_tensor_model_parallel_world_size
from vllm.model_executor.layers.attention.mm_encoder_attention import MMEncoderAttention
from vllm.model_executor.layers.conv import Conv2dLayer
......@@ -62,6 +61,7 @@ from vllm.model_executor.layers.linear import (
RowParallelLinear,
)
from vllm.model_executor.models.utils import maybe_prefix
from vllm.model_executor.models.vision import is_vit_use_data_parallel
from vllm.platforms import current_platform
from vllm.transformers_utils.configs.moonvit import MoonViTConfig
......@@ -308,11 +308,10 @@ class MLP2(nn.Module):
activation,
bias: bool = True,
prefix: str = "",
use_data_parallel: bool = False,
):
super().__init__()
assert len(dims) == 3
self.use_data_parallel = use_data_parallel
self.use_data_parallel = is_vit_use_data_parallel()
self.fc0 = ColumnParallelLinear(
dims[0],
dims[1],
......@@ -343,17 +342,12 @@ class MoonVitEncoderLayer(nn.Module):
hidden_dim: int,
mlp_dim: int,
prefix: str = "",
multimodal_config: MultiModalConfig | None = None,
*,
activation=F.gelu,
attn_bias: bool = False,
):
super().__init__()
self.use_data_parallel = (
multimodal_config.mm_encoder_tp_mode == "data"
if multimodal_config
else False
)
self.use_data_parallel = is_vit_use_data_parallel()
self.num_heads = num_heads
self.hidden_dim = hidden_dim
......@@ -369,7 +363,6 @@ class MoonVitEncoderLayer(nn.Module):
[hidden_dim, mlp_dim, hidden_dim],
activation,
prefix=f"{prefix}.mlp",
use_data_parallel=self.use_data_parallel,
)
self.wqkv = QKVParallelLinear(
hidden_size=hidden_dim,
......@@ -391,7 +384,6 @@ class MoonVitEncoderLayer(nn.Module):
num_heads=self.num_attention_heads_per_partition,
head_size=self.hidden_size_per_attention_head,
scale=self.hidden_size_per_attention_head**-0.5,
multimodal_config=multimodal_config,
prefix=f"{prefix}.attn",
)
......@@ -469,7 +461,6 @@ class MoonVitEncoder(nn.Module):
num_layers: int,
block_cfg: dict,
prefix: str = "",
multimodal_config: MultiModalConfig | None = None,
) -> None:
super().__init__()
......@@ -479,7 +470,6 @@ class MoonVitEncoder(nn.Module):
self.blocks = nn.ModuleList(
[
MoonVitEncoderLayer(
multimodal_config=multimodal_config,
prefix=f"{prefix}.blocks.{layer_idx}",
**block_cfg,
)
......@@ -550,7 +540,6 @@ class MoonVitPretrainedModel(PreTrainedModel):
def __init__(
self,
config: MoonViTConfig,
multimodal_config: MultiModalConfig | None = None,
prefix: str = "",
*inputs,
**kwargs,
......@@ -579,7 +568,6 @@ class MoonVitPretrainedModel(PreTrainedModel):
"attn_bias": True,
},
prefix=f"{prefix}.encoder",
multimodal_config=multimodal_config,
)
def forward(
......
......@@ -244,7 +244,6 @@ class OpenCUAForConditionalGeneration(Qwen2_5_VLForConditionalGeneration):
vision_config=config.vision_config,
norm_eps=getattr(config, "rms_norm_eps", 1e-6),
quant_config=self.quant_config,
multimodal_config=self.multimodal_config,
prefix=maybe_prefix(prefix, "visual"),
)
......
......@@ -10,7 +10,7 @@ import torch
import torch.nn as nn
from transformers import BaseImageProcessor, BatchFeature, PretrainedConfig
from vllm.config import MultiModalConfig, VllmConfig
from vllm.config import VllmConfig
from vllm.config.multimodal import BaseDummyOptions
from vllm.model_executor.layers.linear import ReplicatedLinear
from vllm.model_executor.layers.quantization import QuantizationConfig
......@@ -103,7 +103,6 @@ class VisualTokenizer(torch.nn.Module):
config: PretrainedConfig,
visual_vocab_size: int,
quant_config: QuantizationConfig | None = None,
multimodal_config: MultiModalConfig | None = None,
prefix: str = "",
):
super().__init__()
......@@ -111,7 +110,6 @@ class VisualTokenizer(torch.nn.Module):
self.vit = self._init_backbone(
config=config,
quant_config=quant_config,
multimodal_config=multimodal_config,
prefix=f"{prefix}.vit",
)
# reserved tokens for INDICATOR_IDS
......@@ -130,7 +128,6 @@ class VisualTokenizer(torch.nn.Module):
self,
config: PretrainedConfig,
quant_config: QuantizationConfig | None = None,
multimodal_config: QuantizationConfig | None = None,
prefix: str = "",
):
model_type = config.model_type
......@@ -138,7 +135,6 @@ class VisualTokenizer(torch.nn.Module):
return Siglip2NavitModel(
config=config,
quant_config=quant_config,
multimodal_config=multimodal_config,
prefix=prefix,
)
raise ValueError(f"Unsupported visual tokenizer model_type: {model_type}")
......@@ -464,7 +460,6 @@ class Ovis2_5(nn.Module, SupportsMultiModal, SupportsPP):
super().__init__()
config = vllm_config.model_config.hf_config
quant_config = vllm_config.quant_config
multimodal_config = vllm_config.model_config.multimodal_config
self.config: PretrainedConfig = config
......@@ -478,7 +473,6 @@ class Ovis2_5(nn.Module, SupportsMultiModal, SupportsPP):
self.visual_tokenizer = VisualTokenizer(
config=config.vit_config,
visual_vocab_size=config.visual_vocab_size,
multimodal_config=multimodal_config,
quant_config=quant_config,
prefix=f"{prefix}.visual_tokenizer",
)
......
......@@ -30,7 +30,7 @@ from transformers.modeling_outputs import (
)
from transformers.utils import torch_int
from vllm.config import MultiModalConfig, VllmConfig
from vllm.config import VllmConfig
from vllm.config.multimodal import BaseDummyOptions
from vllm.distributed import parallel_state
from vllm.distributed import utils as dist_utils
......@@ -532,7 +532,6 @@ class SiglipAttention(nn.Module):
num_heads: int,
projection_size: int,
quant_config: QuantizationConfig | None = None,
multimodal_config: MultiModalConfig | None = None,
prefix: str = "",
) -> None:
super().__init__()
......@@ -565,7 +564,6 @@ class SiglipAttention(nn.Module):
num_heads=self.num_attention_heads_per_partition,
head_size=self.hidden_size_per_attention_head,
scale=self.hidden_size_per_attention_head**-0.5,
multimodal_config=multimodal_config,
prefix=f"{prefix}.attn",
)
self.apply_rotary_emb = ApplyRotaryEmb(
......@@ -662,7 +660,6 @@ class SiglipEncoderLayer(nn.Module):
self,
config: PretrainedConfig,
quant_config: QuantizationConfig | None = None,
multimodal_config: MultiModalConfig | None = None,
prefix: str = "",
):
super().__init__()
......@@ -673,14 +670,12 @@ class SiglipEncoderLayer(nn.Module):
num_heads=config.num_attention_heads,
projection_size=config.hidden_size,
quant_config=quant_config,
multimodal_config=multimodal_config,
prefix=f"{prefix}.self_attn",
)
self.layer_norm2 = nn.LayerNorm(self.embed_dim, eps=config.layer_norm_eps)
self.mlp = SiglipMLP(
config,
quant_config=quant_config,
multimodal_config=multimodal_config,
prefix=f"{prefix}.mlp",
)
......@@ -718,7 +713,6 @@ class SiglipEncoder(nn.Module):
self,
config: PretrainedConfig,
quant_config: QuantizationConfig | None = None,
multimodal_config: MultiModalConfig | None = None,
prefix: str = "",
):
super().__init__()
......@@ -727,13 +721,9 @@ class SiglipEncoder(nn.Module):
num_heads = config.num_attention_heads
head_dim = embed_dim // num_heads
attn_backend_override = (
multimodal_config.mm_encoder_attn_backend if multimodal_config else None
)
self.attn_backend = get_vit_attn_backend(
head_size=head_dim,
dtype=torch.get_default_dtype(),
attn_backend_override=attn_backend_override,
)
if self.attn_backend not in {
AttentionBackendEnum.FLASH_ATTN,
......@@ -748,7 +738,6 @@ class SiglipEncoder(nn.Module):
SiglipEncoderLayer(
config,
quant_config=quant_config,
multimodal_config=multimodal_config,
prefix=f"{prefix}.layers.{layer_idx}",
)
for layer_idx in range(config.num_hidden_layers)
......@@ -830,7 +819,6 @@ class SiglipVisionTransformer(nn.Module):
self,
config: PretrainedConfig,
quant_config: QuantizationConfig | None = None,
multimodal_config: MultiModalConfig | None = None,
prefix: str = "",
):
super().__init__()
......@@ -841,7 +829,6 @@ class SiglipVisionTransformer(nn.Module):
self.encoder = SiglipEncoder(
config,
quant_config=quant_config,
multimodal_config=multimodal_config,
prefix=f"{prefix}.encoder",
)
self.post_layernorm = nn.LayerNorm(embed_dim, eps=config.layer_norm_eps)
......@@ -880,7 +867,6 @@ class SiglipVisionModel(nn.Module):
self,
config,
quant_config: QuantizationConfig | None = None,
multimodal_config: MultiModalConfig | None = None,
prefix: str = "",
):
super().__init__()
......@@ -888,7 +874,6 @@ class SiglipVisionModel(nn.Module):
self.vision_model = SiglipVisionTransformer(
config,
quant_config=quant_config,
multimodal_config=multimodal_config,
prefix=f"{prefix}.vision_model",
)
self.quant_config = quant_config
......@@ -1010,16 +995,13 @@ class PaddleOCRVLForConditionalGeneration(nn.Module, SupportsMultiModal, Support
super().__init__()
config = vllm_config.model_config.hf_config
quant_config = vllm_config.quant_config
multimodal_config = vllm_config.model_config.multimodal_config
self.config = config
self.multimodal_config = multimodal_config
with self._mark_tower_model(vllm_config, "image"):
self.visual = SiglipVisionModel(
config=config.vision_config,
quant_config=quant_config,
multimodal_config=multimodal_config,
prefix=maybe_prefix(prefix, "visual"),
)
self.mlp_AR = Projector(config, config.vision_config)
......
......@@ -29,7 +29,7 @@ from transformers import (
)
from vllm.config import VllmConfig
from vllm.config.multimodal import BaseDummyOptions, MultiModalConfig
from vllm.config.multimodal import BaseDummyOptions
from vllm.logger import init_logger
from vllm.model_executor.layers.quantization import QuantizationConfig
from vllm.model_executor.layers.vocab_parallel_embedding import VocabParallelEmbedding
......@@ -96,7 +96,6 @@ CLIP_VIT_LARGE_PATCH14_336_CONFIG = CLIPVisionConfig(
def _init_img_processor(
hf_config: PretrainedConfig,
quant_config: QuantizationConfig | None,
multimodal_config: MultiModalConfig | None,
prefix: str = "",
) -> CLIPVisionModel:
clip_config = CLIP_VIT_LARGE_PATCH14_336_CONFIG
......@@ -111,7 +110,6 @@ def _init_img_processor(
img_processor = CLIPVisionModel(
clip_config,
quant_config=quant_config,
multimodal_config=multimodal_config,
num_hidden_layers_override=num_hidden_layers,
prefix=prefix,
)
......@@ -170,7 +168,6 @@ class Phi3HDImageEmbedding(nn.Module):
self,
config: PretrainedConfig,
quant_config: QuantizationConfig | None,
multimodal_config: MultiModalConfig | None,
prefix: str = "",
) -> None:
super().__init__()
......@@ -181,7 +178,6 @@ class Phi3HDImageEmbedding(nn.Module):
self.img_processor = _init_img_processor(
config,
quant_config=quant_config,
multimodal_config=multimodal_config,
prefix=f"{prefix}.img_processor",
)
......@@ -596,7 +592,6 @@ class Phi3VForCausalLM(nn.Module, SupportsMultiModal, SupportsPP, SupportsQuant)
self.vision_embed_tokens = Phi3HDImageEmbedding(
config,
quant_config=quant_config,
multimodal_config=multimodal_config,
prefix=maybe_prefix(prefix, "model.vision_embed_tokens"),
)
......
......@@ -28,7 +28,7 @@ from transformers.models.pixtral.modeling_pixtral import (
from transformers.tokenization_utils_base import TextInput
from vllm.config import VllmConfig
from vllm.config.multimodal import BaseDummyOptions, MultiModalConfig
from vllm.config.multimodal import BaseDummyOptions
from vllm.distributed import divide, get_tensor_model_parallel_world_size
from vllm.model_executor.layers.activation import get_act_and_mul_fn
from vllm.model_executor.layers.conv import Conv2dLayer
......@@ -74,6 +74,7 @@ from .utils import init_vllm_registered_model, maybe_prefix
from .vision import (
VisionEncoderInfo,
VisionFeatureSelectStrategy,
is_vit_use_data_parallel,
resolve_visual_encoder_outputs,
)
......@@ -1065,17 +1066,12 @@ class PixtralHFMLP(nn.Module):
self,
config: PixtralVisionConfig,
quant_config: QuantizationConfig | None = None,
multimodal_config: MultiModalConfig | None = None,
*,
prefix: str = "",
) -> None:
super().__init__()
use_data_parallel = (
multimodal_config.mm_encoder_tp_mode == "data"
if multimodal_config
else False
)
use_data_parallel = is_vit_use_data_parallel()
assert config.intermediate_size is not None
self.gate_up_proj = MergedColumnParallelLinear(
......@@ -1108,7 +1104,6 @@ class PixtralHFAttention(nn.Module):
self,
config: PixtralVisionConfig,
quant_config: QuantizationConfig | None = None,
multimodal_config: MultiModalConfig | None = None,
*,
prefix: str = "",
) -> None:
......@@ -1120,11 +1115,7 @@ class PixtralHFAttention(nn.Module):
self.head_dim = config.hidden_size // config.num_attention_heads
assert self.total_num_heads * self.head_dim == config.hidden_size
use_data_parallel = (
multimodal_config.mm_encoder_tp_mode == "data"
if multimodal_config
else False
)
use_data_parallel = is_vit_use_data_parallel()
self.qkv_proj = QKVParallelLinear(
hidden_size=config.hidden_size,
head_size=self.head_dim,
......@@ -1189,7 +1180,6 @@ class PixtralHFTransformerBlock(nn.Module):
self,
config: PixtralVisionConfig,
quant_config: QuantizationConfig | None = None,
multimodal_config: MultiModalConfig | None = None,
*,
prefix: str = "",
) -> None:
......@@ -1199,13 +1189,11 @@ class PixtralHFTransformerBlock(nn.Module):
self.attention = PixtralHFAttention(
config,
quant_config=quant_config,
multimodal_config=multimodal_config,
prefix=f"{prefix}.attention",
)
self.feed_forward = PixtralHFMLP(
config,
quant_config=quant_config,
multimodal_config=multimodal_config,
prefix=f"{prefix}.feed_forward",
)
self.ffn_norm = RMSNorm(config.hidden_size, eps=1e-5)
......@@ -1232,7 +1220,6 @@ class PixtralHFTransformer(nn.Module):
self,
config: PixtralVisionConfig,
quant_config: QuantizationConfig | None = None,
multimodal_config: MultiModalConfig | None = None,
*,
num_hidden_layers_override: int | None = None,
prefix: str = "",
......@@ -1249,7 +1236,6 @@ class PixtralHFTransformer(nn.Module):
PixtralHFTransformerBlock(
config=config,
quant_config=quant_config,
multimodal_config=multimodal_config,
prefix=f"{prefix}.layers.{layer_idx}",
)
for layer_idx in range(num_hidden_layers)
......@@ -1281,7 +1267,6 @@ class PixtralHFVisionModel(nn.Module):
self,
config: PixtralVisionConfig,
quant_config: QuantizationConfig | None = None,
multimodal_config: MultiModalConfig | None = None,
*,
num_hidden_layers_override: int | None = None,
require_post_norm: bool | None = None,
......@@ -1302,7 +1287,6 @@ class PixtralHFVisionModel(nn.Module):
self.transformer = PixtralHFTransformer(
config,
quant_config=quant_config,
multimodal_config=multimodal_config,
num_hidden_layers_override=num_hidden_layers_override,
prefix=f"{prefix}.transformer",
)
......
......@@ -846,7 +846,6 @@ class Qwen2_5OmniThinkerForConditionalGeneration(
norm_eps=getattr(thinker_config.text_config, "rms_norm_eps", 1e-6),
quant_config=quant_config,
prefix=maybe_prefix(prefix, "visual"),
multimodal_config=multimodal_config,
)
with self._mark_language_model(vllm_config):
......
......@@ -43,7 +43,7 @@ from transformers.models.qwen2_5_vl.configuration_qwen2_5_vl import (
)
from vllm.compilation.decorators import support_torch_compile
from vllm.config import MultiModalConfig, VllmConfig
from vllm.config import VllmConfig
from vllm.distributed import parallel_state
from vllm.distributed import utils as dist_utils
from vllm.forward_context import set_forward_context
......@@ -109,6 +109,7 @@ from .utils import (
)
from .vision import (
get_vit_attn_backend,
is_vit_use_data_parallel,
run_dp_sharded_mrope_vision_model,
)
......@@ -266,15 +267,10 @@ class Qwen2_5_VisionMLP(nn.Module):
bias: bool = False,
act_fn: Callable[[torch.Tensor], torch.Tensor] = F.silu,
quant_config: QuantizationConfig | None = None,
multimodal_config: MultiModalConfig | None = None,
prefix: str = "",
):
super().__init__()
use_data_parallel = (
multimodal_config.mm_encoder_tp_mode == "data"
if multimodal_config
else False
)
use_data_parallel = is_vit_use_data_parallel()
self.gate_up_proj = MergedColumnParallelLinear(
input_size=in_features,
output_sizes=[hidden_features] * 2, # [gate_proj, up_proj]
......@@ -308,16 +304,11 @@ class Qwen2_5_VisionAttention(nn.Module):
num_heads: int,
projection_size: int,
quant_config: QuantizationConfig | None = None,
multimodal_config: MultiModalConfig | None = None,
prefix: str = "",
) -> None:
super().__init__()
# Per attention head and per partition values.
use_data_parallel = (
multimodal_config.mm_encoder_tp_mode == "data"
if multimodal_config
else False
)
use_data_parallel = is_vit_use_data_parallel()
self.tp_size = (
1
if use_data_parallel
......@@ -354,7 +345,6 @@ class Qwen2_5_VisionAttention(nn.Module):
num_heads=self.num_attention_heads_per_partition,
head_size=self.hidden_size_per_attention_head,
scale=self.hidden_size_per_attention_head**-0.5,
multimodal_config=multimodal_config,
)
self.apply_rotary_emb = ApplyRotaryEmb(enforce_enable=True)
......@@ -435,7 +425,6 @@ class Qwen2_5_VisionBlock(nn.Module):
act_fn: Callable[[torch.Tensor], torch.Tensor] = F.silu,
norm_layer: Callable[[int], nn.Module] | None = None,
quant_config: QuantizationConfig | None = None,
multimodal_config: MultiModalConfig | None = None,
prefix: str = "",
) -> None:
super().__init__()
......@@ -448,7 +437,6 @@ class Qwen2_5_VisionBlock(nn.Module):
num_heads=num_heads,
projection_size=dim,
quant_config=quant_config,
multimodal_config=multimodal_config,
prefix=f"{prefix}.attn",
)
self.mlp = Qwen2_5_VisionMLP(
......@@ -457,7 +445,6 @@ class Qwen2_5_VisionBlock(nn.Module):
act_fn=act_fn,
bias=True,
quant_config=quant_config,
multimodal_config=multimodal_config,
prefix=f"{prefix}.mlp",
)
......@@ -530,15 +517,10 @@ class Qwen2_5_VisionPatchMerger(nn.Module):
norm_layer: Callable[[int], nn.Module] | None = None,
spatial_merge_size: int = 2,
quant_config: QuantizationConfig | None = None,
multimodal_config: MultiModalConfig | None = None,
prefix: str = "",
) -> None:
super().__init__()
use_data_parallel = (
multimodal_config.mm_encoder_tp_mode == "data"
if multimodal_config
else False
)
use_data_parallel = is_vit_use_data_parallel()
self.hidden_size = context_dim * (spatial_merge_size**2)
if norm_layer is None:
norm_layer = partial(nn.LayerNorm, eps=1e-6)
......@@ -579,7 +561,6 @@ class Qwen2_5_VisionTransformer(nn.Module):
vision_config: Qwen2_5_VLVisionConfig,
norm_eps: float = 1e-6,
quant_config: QuantizationConfig | None = None,
multimodal_config: MultiModalConfig | None = None,
prefix: str = "",
) -> None:
super().__init__()
......@@ -620,15 +601,9 @@ class Qwen2_5_VisionTransformer(nn.Module):
rope_parameters={"partial_rotary_factor": 0.5},
)
attn_backend_override = (
multimodal_config.mm_encoder_attn_backend
if multimodal_config is not None
else None
)
self.attn_backend = get_vit_attn_backend(
head_size=head_dim,
dtype=torch.get_default_dtype(),
attn_backend_override=attn_backend_override,
)
if self.attn_backend not in {
......@@ -650,7 +625,6 @@ class Qwen2_5_VisionTransformer(nn.Module):
act_fn=get_act_and_mul_fn(vision_config.hidden_act),
norm_layer=norm_layer,
quant_config=quant_config,
multimodal_config=multimodal_config,
prefix=f"{prefix}.blocks.{layer_idx}",
)
for layer_idx in range(depth)
......@@ -664,7 +638,6 @@ class Qwen2_5_VisionTransformer(nn.Module):
norm_layer=norm_layer,
spatial_merge_size=self.spatial_merge_size,
quant_config=quant_config,
multimodal_config=multimodal_config,
prefix=f"{prefix}.merger",
)
......@@ -1152,7 +1125,6 @@ class Qwen2_5_VLForConditionalGeneration(
norm_eps=getattr(config, "rms_norm_eps", 1e-6),
quant_config=self.quant_config,
prefix=maybe_prefix(prefix, "visual"),
multimodal_config=multimodal_config,
)
with self._mark_language_model(vllm_config):
......
......@@ -43,7 +43,7 @@ from transformers.models.qwen2_vl.configuration_qwen2_vl import (
from transformers.models.qwen2_vl.image_processing_qwen2_vl import smart_resize
from transformers.models.qwen2_vl.video_processing_qwen2_vl import Qwen2VLVideoProcessor
from vllm.config import MultiModalConfig, VllmConfig
from vllm.config import VllmConfig
from vllm.config.multimodal import BaseDummyOptions
from vllm.distributed import parallel_state, tensor_model_parallel_all_gather
from vllm.distributed import utils as dist_utils
......@@ -106,6 +106,7 @@ from .utils import (
)
from .vision import (
get_vit_attn_backend,
is_vit_use_data_parallel,
run_dp_sharded_mrope_vision_model,
)
......@@ -247,15 +248,10 @@ class Qwen2VisionMLP(nn.Module):
hidden_features: int,
act_layer: type[nn.Module] = QuickGELU,
quant_config: QuantizationConfig | None = None,
multimodal_config: MultiModalConfig | None = None,
prefix: str = "",
):
super().__init__()
use_data_parallel = (
multimodal_config.mm_encoder_tp_mode == "data"
if multimodal_config
else False
)
use_data_parallel = is_vit_use_data_parallel()
self.fc1 = ColumnParallelLinear(
in_features,
hidden_features,
......@@ -286,16 +282,11 @@ class Qwen2VisionAttention(nn.Module):
num_heads: int,
projection_size: int,
quant_config: QuantizationConfig | None = None,
multimodal_config: MultiModalConfig | None = None,
prefix: str = "",
) -> None:
super().__init__()
# Per attention head and per partition values.
use_data_parallel = (
multimodal_config.mm_encoder_tp_mode == "data"
if multimodal_config
else False
)
use_data_parallel = is_vit_use_data_parallel()
self.tp_size = (
1
if use_data_parallel
......@@ -328,7 +319,6 @@ class Qwen2VisionAttention(nn.Module):
num_heads=self.num_attention_heads_per_partition,
head_size=self.hidden_size_per_attention_head,
scale=self.hidden_size_per_attention_head**-0.5,
multimodal_config=multimodal_config,
)
self.apply_rotary_emb = ApplyRotaryEmb(enforce_enable=True)
......@@ -409,7 +399,6 @@ class Qwen2VisionBlock(nn.Module):
act_layer: type[nn.Module] = QuickGELU,
norm_layer: Callable[[int], nn.Module] | None = None,
quant_config: QuantizationConfig | None = None,
multimodal_config: MultiModalConfig | None = None,
prefix: str = "",
) -> None:
super().__init__()
......@@ -424,7 +413,6 @@ class Qwen2VisionBlock(nn.Module):
num_heads=num_heads,
projection_size=dim,
quant_config=quant_config,
multimodal_config=multimodal_config,
prefix=f"{prefix}.attn",
)
self.mlp = Qwen2VisionMLP(
......@@ -432,7 +420,6 @@ class Qwen2VisionBlock(nn.Module):
mlp_hidden_dim,
act_layer=act_layer,
quant_config=quant_config,
multimodal_config=multimodal_config,
prefix=f"{prefix}.mlp",
)
......@@ -493,15 +480,10 @@ class Qwen2VisionPatchMerger(nn.Module):
norm_layer: Callable[[int], nn.Module] | None = None,
spatial_merge_size: int = 2,
quant_config: QuantizationConfig | None = None,
multimodal_config: MultiModalConfig | None = None,
prefix: str = "",
) -> None:
super().__init__()
use_data_parallel = (
multimodal_config.mm_encoder_tp_mode == "data"
if multimodal_config
else False
)
use_data_parallel = is_vit_use_data_parallel()
self.hidden_size = context_dim * (spatial_merge_size**2)
if norm_layer is None:
norm_layer = partial(nn.LayerNorm, eps=1e-6)
......@@ -545,7 +527,6 @@ class Qwen2VisionTransformer(nn.Module):
vision_config: Qwen2VLVisionConfig,
norm_eps: float = 1e-6,
quant_config: QuantizationConfig | None = None,
multimodal_config: MultiModalConfig | None = None,
prefix: str = "",
) -> None:
super().__init__()
......@@ -560,11 +541,7 @@ class Qwen2VisionTransformer(nn.Module):
num_heads = vision_config.num_heads
mlp_ratio = vision_config.mlp_ratio
self.use_data_parallel = (
multimodal_config.mm_encoder_tp_mode == "data"
if multimodal_config
else False
)
self.use_data_parallel = is_vit_use_data_parallel()
self.out_hidden_size = vision_config.hidden_size
self.spatial_merge_size = spatial_merge_size
......@@ -596,7 +573,6 @@ class Qwen2VisionTransformer(nn.Module):
norm_layer=norm_layer,
quant_config=quant_config,
prefix=f"{prefix}.blocks.{layer_idx}",
multimodal_config=multimodal_config,
)
for layer_idx in range(depth)
]
......@@ -607,15 +583,10 @@ class Qwen2VisionTransformer(nn.Module):
norm_layer=norm_layer,
quant_config=quant_config,
prefix=f"{prefix}.merger",
multimodal_config=multimodal_config,
)
attn_backend_override = (
multimodal_config.mm_encoder_attn_backend if multimodal_config else None
)
self.attn_backend = get_vit_attn_backend(
head_size=head_dim,
dtype=torch.get_default_dtype(),
attn_backend_override=attn_backend_override,
)
@property
......@@ -1238,7 +1209,6 @@ class Qwen2VLForConditionalGeneration(
config.vision_config,
norm_eps=getattr(config, "rms_norm_eps", 1e-6),
quant_config=quant_config,
multimodal_config=multimodal_config,
prefix=maybe_prefix(prefix, "visual"),
)
......
......@@ -48,7 +48,7 @@ from transformers import __version__ as TRANSFORMERS_VERSION
# isort: on
from vllm.compilation.decorators import support_torch_compile
from vllm.config import MultiModalConfig, VllmConfig
from vllm.config import VllmConfig
from vllm.distributed import get_pp_group, get_tensor_model_parallel_world_size
from vllm.logger import init_logger
from vllm.model_executor.layers.activation import _ACTIVATION_REGISTRY
......@@ -160,7 +160,6 @@ class Qwen3OmniMoeAudioAttention(nn.Module):
def __init__(
self,
config: Qwen3OmniMoeAudioEncoderConfig,
multimodal_config: MultiModalConfig | None = None,
prefix: str = "",
):
super().__init__()
......@@ -198,7 +197,6 @@ class Qwen3OmniMoeAudioAttention(nn.Module):
num_heads=self.num_local_heads,
head_size=self.head_dim,
scale=self.scaling,
multimodal_config=multimodal_config,
)
def forward(
......@@ -233,13 +231,12 @@ class Qwen3OmniMoeAudioEncoderLayer(nn.Module):
def __init__(
self,
config: Qwen3OmniMoeAudioEncoderConfig,
multimodal_config: MultiModalConfig | None = None,
prefix: str = "",
):
super().__init__()
self.embed_dim = config.d_model
self.self_attn = Qwen3OmniMoeAudioAttention(
config, multimodal_config=multimodal_config, prefix=f"{prefix}.self_attn"
config, prefix=f"{prefix}.self_attn"
)
self.self_attn_layer_norm = nn.LayerNorm(self.embed_dim)
self.activation_fn = _ACTIVATION_REGISTRY[config.activation_function]
......@@ -301,7 +298,6 @@ class Qwen3OmniMoeAudioEncoder(nn.Module):
def __init__(
self,
config: Qwen3OmniMoeAudioEncoderConfig,
multimodal_config: MultiModalConfig | None = None,
prefix: str = "",
):
super().__init__()
......@@ -345,7 +341,6 @@ class Qwen3OmniMoeAudioEncoder(nn.Module):
[
Qwen3OmniMoeAudioEncoderLayer(
config,
multimodal_config=multimodal_config,
prefix=f"{prefix}.layers.{i}",
)
for i in range(config.encoder_layers)
......@@ -359,15 +354,9 @@ class Qwen3OmniMoeAudioEncoder(nn.Module):
self.proj2 = nn.Linear(config.d_model, config.output_dim)
# Get attention backend
attn_backend_override = (
multimodal_config.mm_encoder_attn_backend
if multimodal_config is not None
else None
)
self.attn_backend = get_vit_attn_backend(
head_size=config.d_model // config.encoder_attention_heads,
dtype=torch.get_default_dtype(),
attn_backend_override=attn_backend_override,
)
def compute_attn_mask_seqlen(self, cu_seqlens: torch.Tensor) -> torch.Tensor | None:
......@@ -601,7 +590,6 @@ class Qwen3_VisionBlock(nn.Module):
mlp_hidden_dim: int,
act_fn: Callable[[torch.Tensor], torch.Tensor] = F.silu,
norm_layer: Callable[[int], nn.Module] | None = None,
multimodal_config: MultiModalConfig | None = None,
quant_config: QuantizationConfig | None = None,
prefix: str = "",
) -> None:
......@@ -615,7 +603,6 @@ class Qwen3_VisionBlock(nn.Module):
num_heads=num_heads,
projection_size=dim,
quant_config=quant_config,
multimodal_config=multimodal_config,
prefix=f"{prefix}.attn",
)
self.mlp = Qwen3_VisionMLP(
......@@ -710,7 +697,6 @@ class Qwen3Omni_VisionTransformer(nn.Module):
vision_config,
norm_eps: float = 1e-6,
quant_config: QuantizationConfig | None = None,
multimodal_config: MultiModalConfig | None = None,
prefix: str = "",
) -> None:
super().__init__()
......@@ -758,7 +744,6 @@ class Qwen3Omni_VisionTransformer(nn.Module):
act_fn=_ACTIVATION_REGISTRY[vision_config.hidden_act],
norm_layer=norm_layer,
quant_config=quant_config,
multimodal_config=multimodal_config,
prefix=f"{prefix}.blocks.{layer_idx}",
)
for layer_idx in range(vision_config.depth)
......@@ -788,16 +773,9 @@ class Qwen3Omni_VisionTransformer(nn.Module):
]
)
attn_backend_override = (
multimodal_config.mm_encoder_attn_backend
if multimodal_config is not None
else None
)
self.attn_backend = get_vit_attn_backend(
head_size=head_dim,
dtype=torch.get_default_dtype(),
attn_backend_override=attn_backend_override,
)
@property
......@@ -1617,7 +1595,6 @@ class Qwen3OmniMoeThinkerForConditionalGeneration(
with self._mark_tower_model(vllm_config, "audio"):
self.audio_tower = Qwen3OmniMoeAudioEncoder(
thinker_config.audio_config,
multimodal_config=multimodal_config,
prefix=maybe_prefix(prefix, "audio_tower"),
)
......@@ -1638,7 +1615,6 @@ class Qwen3OmniMoeThinkerForConditionalGeneration(
norm_eps=getattr(thinker_config.text_config, "rms_norm_eps", 1e-6),
quant_config=quant_config,
prefix=maybe_prefix(prefix, "visual"),
multimodal_config=multimodal_config,
)
# register buffer for deepstack
......
......@@ -49,7 +49,7 @@ from transformers.models.qwen3_vl.video_processing_qwen3_vl import (
from transformers.video_utils import VideoMetadata
from vllm.compilation.decorators import support_torch_compile
from vllm.config import MultiModalConfig, VllmConfig
from vllm.config import VllmConfig
from vllm.config.multimodal import BaseDummyOptions, VideoDummyOptions
from vllm.distributed import get_pp_group
from vllm.logger import init_logger
......@@ -123,6 +123,7 @@ from .utils import (
)
from .vision import (
get_vit_attn_backend,
is_vit_use_data_parallel,
run_dp_sharded_mrope_vision_model,
)
......@@ -169,15 +170,10 @@ class Qwen3_VisionMLP(nn.Module):
bias: bool = False,
act_fn: Callable[[torch.Tensor], torch.Tensor] = F.silu,
quant_config: QuantizationConfig | None = None,
multimodal_config: MultiModalConfig | None = None,
prefix: str = "",
):
super().__init__()
use_data_parallel = (
multimodal_config.mm_encoder_tp_mode == "data"
if multimodal_config
else False
)
use_data_parallel = is_vit_use_data_parallel()
self.linear_fc1 = ColumnParallelLinear(
in_features,
hidden_features,
......@@ -211,7 +207,6 @@ class Qwen3_VisionBlock(nn.Module):
mlp_hidden_dim: int,
act_fn: Callable[[torch.Tensor], torch.Tensor] = F.silu,
norm_layer: Callable[[int], nn.Module] | None = None,
multimodal_config: MultiModalConfig | None = None,
quant_config: QuantizationConfig | None = None,
prefix: str = "",
) -> None:
......@@ -225,7 +220,6 @@ class Qwen3_VisionBlock(nn.Module):
num_heads=num_heads,
projection_size=dim,
quant_config=quant_config,
multimodal_config=multimodal_config,
prefix=f"{prefix}.attn",
)
self.mlp = Qwen3_VisionMLP(
......@@ -234,7 +228,6 @@ class Qwen3_VisionBlock(nn.Module):
act_fn=act_fn,
bias=True,
quant_config=quant_config,
multimodal_config=multimodal_config,
prefix=f"{prefix}.mlp",
)
......@@ -267,15 +260,10 @@ class Qwen3_VisionPatchMerger(nn.Module):
spatial_merge_size: int = 2,
use_postshuffle_norm: bool = False,
quant_config: QuantizationConfig | None = None,
multimodal_config: MultiModalConfig | None = None,
prefix: str = "",
) -> None:
super().__init__()
use_data_parallel = (
multimodal_config.mm_encoder_tp_mode == "data"
if multimodal_config
else False
)
use_data_parallel = is_vit_use_data_parallel()
self.hidden_size = context_dim * (spatial_merge_size**2)
self.use_postshuffle_norm = use_postshuffle_norm
......@@ -321,7 +309,6 @@ class Qwen3_VisionTransformer(nn.Module):
vision_config: Qwen3VLVisionConfig,
norm_eps: float = 1e-6,
quant_config: QuantizationConfig | None = None,
multimodal_config: MultiModalConfig | None = None,
prefix: str = "",
) -> None:
super().__init__()
......@@ -365,7 +352,6 @@ class Qwen3_VisionTransformer(nn.Module):
norm_layer=norm_layer,
spatial_merge_size=self.spatial_merge_size,
quant_config=quant_config,
multimodal_config=multimodal_config,
prefix=f"{prefix}.merger",
)
......@@ -378,20 +364,15 @@ class Qwen3_VisionTransformer(nn.Module):
use_postshuffle_norm=True,
norm_layer=norm_layer,
quant_config=quant_config,
multimodal_config=multimodal_config,
prefix=f"{prefix}.deepstack_merger_list.{layer_idx}",
)
for layer_idx in range(len(self.deepstack_visual_indexes))
]
)
attn_backend_override = (
multimodal_config.mm_encoder_attn_backend if multimodal_config else None
)
self.attn_backend = get_vit_attn_backend(
head_size=head_dim,
dtype=torch.get_default_dtype(),
attn_backend_override=attn_backend_override,
)
if self.attn_backend not in {
......@@ -411,7 +392,6 @@ class Qwen3_VisionTransformer(nn.Module):
act_fn=_ACTIVATION_REGISTRY[vision_config.hidden_act],
norm_layer=norm_layer,
quant_config=quant_config,
multimodal_config=multimodal_config,
prefix=f"{prefix}.blocks.{layer_idx}",
)
for layer_idx in range(vision_config.depth)
......@@ -1291,7 +1271,6 @@ class Qwen3VLForConditionalGeneration(
config.vision_config,
norm_eps=getattr(config, "rms_norm_eps", 1e-6),
quant_config=quant_config,
multimodal_config=multimodal_config,
prefix=maybe_prefix(prefix, "visual"),
)
......
......@@ -446,7 +446,6 @@ class Qwen3VLMoeForConditionalGeneration(
config.vision_config,
norm_eps=getattr(config, "rms_norm_eps", 1e-6),
quant_config=quant_config,
multimodal_config=multimodal_config,
prefix=maybe_prefix(prefix, "visual"),
)
......
......@@ -16,7 +16,7 @@ from transformers import (
)
from vllm.config import VllmConfig
from vllm.config.multimodal import BaseDummyOptions, MultiModalConfig
from vllm.config.multimodal import BaseDummyOptions
from vllm.distributed import divide, get_tensor_model_parallel_world_size
from vllm.model_executor.layers.activation import get_act_fn
from vllm.model_executor.layers.attention.encoder_only_attention import (
......@@ -64,6 +64,7 @@ from .vision import (
VisionFeatureSelectStrategy,
VisionFeatureSelectStrategyStr,
get_num_selected_vision_tokens,
is_vit_use_data_parallel,
resolve_visual_encoder_outputs,
)
......@@ -356,7 +357,6 @@ class SiglipAttention(nn.Module):
self,
config: SiglipVisionConfig | SiglipTextConfig,
quant_config: QuantizationConfig | None = None,
multimodal_config: MultiModalConfig | None = None,
*,
prefix: str = "",
attn_cls: type[EncoderOnlyAttention] | type[MMEncoderAttention],
......@@ -376,11 +376,7 @@ class SiglipAttention(nn.Module):
self.scale = self.head_dim**-0.5
use_data_parallel = (
multimodal_config.mm_encoder_tp_mode == "data"
if multimodal_config
else False
)
use_data_parallel = is_vit_use_data_parallel()
self.qkv_proj = QKVParallelLinear(
hidden_size=self.embed_dim,
head_size=self.head_dim,
......@@ -409,7 +405,6 @@ class SiglipAttention(nn.Module):
self.head_dim,
self.scale,
prefix=f"{prefix}.attn",
multimodal_config=multimodal_config,
)
else:
self.attn = attn_cls(
......@@ -437,17 +432,12 @@ class SiglipMLP(nn.Module):
self,
config: SiglipVisionConfig | SiglipTextConfig,
quant_config: QuantizationConfig | None = None,
multimodal_config: MultiModalConfig | None = None,
prefix: str = "",
) -> None:
super().__init__()
self.config = config
use_data_parallel = (
multimodal_config.mm_encoder_tp_mode == "data"
if multimodal_config
else False
)
use_data_parallel = is_vit_use_data_parallel()
self.activation_fn = get_act_fn(config.hidden_act)
# Special handling for BNB and torchao quantization
......@@ -487,7 +477,6 @@ class SiglipEncoderLayer(nn.Module):
self,
config: SiglipVisionConfig | SiglipTextConfig,
quant_config: QuantizationConfig | None = None,
multimodal_config: MultiModalConfig | None = None,
*,
prefix: str = "",
attn_cls: type[EncoderOnlyAttention] | type[MMEncoderAttention],
......@@ -499,7 +488,6 @@ class SiglipEncoderLayer(nn.Module):
self.self_attn = SiglipAttention(
config,
quant_config=quant_config,
multimodal_config=multimodal_config,
prefix=f"{prefix}.self_attn",
attn_cls=attn_cls,
)
......@@ -507,7 +495,6 @@ class SiglipEncoderLayer(nn.Module):
self.mlp = SiglipMLP(
config,
quant_config=quant_config,
multimodal_config=multimodal_config,
prefix=f"{prefix}.mlp",
)
self.layer_norm2 = nn.LayerNorm(self.embed_dim, eps=config.layer_norm_eps)
......@@ -535,7 +522,6 @@ class SiglipEncoder(nn.Module):
self,
config: SiglipVisionConfig | SiglipTextConfig,
quant_config: QuantizationConfig | None = None,
multimodal_config: MultiModalConfig | None = None,
num_hidden_layers_override: int | None = None,
*,
prefix: str = "",
......@@ -555,7 +541,6 @@ class SiglipEncoder(nn.Module):
SiglipEncoderLayer(
config,
quant_config=quant_config,
multimodal_config=multimodal_config,
prefix=f"{prefix}.layers.{layer_idx}",
attn_cls=attn_cls,
)
......@@ -660,7 +645,6 @@ class SiglipMultiheadAttentionPoolingHead(nn.Module):
self,
config: SiglipVisionConfig,
quant_config: QuantizationConfig | None = None,
multimodal_config: MultiModalConfig | None = None,
prefix: str = "",
) -> None:
super().__init__()
......@@ -674,7 +658,6 @@ class SiglipMultiheadAttentionPoolingHead(nn.Module):
self.mlp = SiglipMLP(
config=config,
quant_config=quant_config,
multimodal_config=multimodal_config,
prefix=f"{prefix}.mlp",
)
......@@ -700,7 +683,6 @@ class SiglipVisionTransformer(nn.Module):
self,
config: SiglipVisionConfig,
quant_config: QuantizationConfig | None = None,
multimodal_config: MultiModalConfig | None = None,
*,
num_hidden_layers_override: int | None = None,
require_post_norm: bool | None = None,
......@@ -717,7 +699,6 @@ class SiglipVisionTransformer(nn.Module):
self.encoder = SiglipEncoder(
config,
quant_config=quant_config,
multimodal_config=multimodal_config,
num_hidden_layers_override=num_hidden_layers_override,
prefix=f"{prefix}.encoder",
attn_cls=MMEncoderAttention,
......@@ -756,7 +737,6 @@ class SiglipVisionTransformer(nn.Module):
SiglipMultiheadAttentionPoolingHead(
config=config,
quant_config=quant_config,
multimodal_config=multimodal_config,
prefix=f"{prefix}.head",
)
if self.use_head
......@@ -870,7 +850,6 @@ class SiglipVisionModel(nn.Module):
self,
config: SiglipVisionConfig,
quant_config: QuantizationConfig | None = None,
multimodal_config: MultiModalConfig | None = None,
*,
num_hidden_layers_override: int | None = None,
require_post_norm: bool | None = None,
......@@ -883,7 +862,6 @@ class SiglipVisionModel(nn.Module):
self.vision_model = SiglipVisionTransformer(
config,
quant_config=quant_config,
multimodal_config=multimodal_config,
num_hidden_layers_override=num_hidden_layers_override,
require_post_norm=require_post_norm,
prefix=f"{prefix}.vision_model",
......@@ -1062,9 +1040,7 @@ class SiglipEmbeddingModel(nn.Module, SupportsMultiModal, SupportsQuant):
config: SiglipConfig = vllm_config.model_config.hf_config
quant_config = vllm_config.quant_config
multimodal_config = vllm_config.model_config.multimodal_config
self.config = config
self.multimodal_config = multimodal_config
if hasattr(config, "num_labels"):
config.num_labels = 0
......@@ -1087,7 +1063,6 @@ class SiglipEmbeddingModel(nn.Module, SupportsMultiModal, SupportsQuant):
self.vision_model = SiglipVisionTransformer(
vision_config,
quant_config=quant_config,
multimodal_config=multimodal_config,
prefix=maybe_prefix(prefix, "vision_model"),
use_head=None, # Allows potential pooling head
)
......
......@@ -11,7 +11,6 @@ from torch.nn import functional as F
from transformers import Siglip2VisionConfig
from transformers.configuration_utils import PretrainedConfig
from vllm.config import MultiModalConfig
from vllm.distributed import divide, get_tensor_model_parallel_world_size
from vllm.model_executor.layers.activation import get_act_fn
from vllm.model_executor.layers.attention.mm_encoder_attention import MMEncoderAttention
......@@ -30,6 +29,8 @@ from vllm.model_executor.layers.rotary_embedding.common import (
from vllm.model_executor.model_loader.weight_utils import default_weight_loader
from vllm.platforms import current_platform
from .vision import is_vit_use_data_parallel
class VisionRotaryEmbedding(nn.Module):
def __init__(self, dim: int, theta: float = 10000.0) -> None:
......@@ -178,9 +179,7 @@ class Siglip2Attention(nn.Module):
self,
config: Siglip2VisionConfig,
quant_config: QuantizationConfig | None = None,
multimodal_config: MultiModalConfig | None = None,
prefix: str = "",
use_data_parallel: bool = False,
):
super().__init__()
self.config = config
......@@ -196,11 +195,7 @@ class Siglip2Attention(nn.Module):
self.scale = self.head_dim**-0.5
self.dropout = config.attention_dropout
use_data_parallel = (
multimodal_config.mm_encoder_tp_mode == "data"
if multimodal_config
else False
)
use_data_parallel = is_vit_use_data_parallel()
self.qkv_proj = QKVParallelLinear(
hidden_size=self.embed_dim,
head_size=self.head_dim,
......@@ -228,7 +223,6 @@ class Siglip2Attention(nn.Module):
head_size=self.head_dim,
scale=self.scale,
prefix=f"{prefix}.attn",
multimodal_config=multimodal_config,
)
self.apply_rotary_emb = ApplyRotaryEmb(
......@@ -287,16 +281,11 @@ class Siglip2MLP(nn.Module):
self,
config: Siglip2VisionConfig,
quant_config: QuantizationConfig | None = None,
multimodal_config: MultiModalConfig | None = None,
prefix: str = "",
):
super().__init__()
self.config = config
use_data_parallel = (
multimodal_config.mm_encoder_tp_mode == "data"
if multimodal_config
else False
)
use_data_parallel = is_vit_use_data_parallel()
self.activation_fn = get_act_fn(config.hidden_act)
self.fc1 = ColumnParallelLinear(
config.hidden_size,
......@@ -325,7 +314,6 @@ class Siglip2EncoderLayer(nn.Module):
self,
config: Siglip2VisionConfig,
quant_config: QuantizationConfig | None = None,
multimodal_config: MultiModalConfig | None = None,
prefix: str = "",
):
super().__init__()
......@@ -334,14 +322,12 @@ class Siglip2EncoderLayer(nn.Module):
self.self_attn = Siglip2Attention(
config,
quant_config=quant_config,
multimodal_config=multimodal_config,
prefix=f"{prefix}.self_attn",
)
self.layer_norm2 = nn.LayerNorm(self.embed_dim, eps=config.layer_norm_eps)
self.mlp = Siglip2MLP(
config,
quant_config=quant_config,
multimodal_config=multimodal_config,
prefix=f"{prefix}.mlp",
)
......@@ -387,7 +373,6 @@ class Siglip2Encoder(nn.Module):
self,
config: Siglip2VisionConfig,
quant_config: QuantizationConfig | None = None,
multimodal_config: MultiModalConfig | None = None,
prefix: str = "",
):
super().__init__()
......@@ -397,7 +382,6 @@ class Siglip2Encoder(nn.Module):
Siglip2EncoderLayer(
config,
quant_config=quant_config,
multimodal_config=multimodal_config,
prefix=f"{prefix}.layers.{idx}",
)
for idx in range(config.num_hidden_layers)
......@@ -571,7 +555,6 @@ class Siglip2VisionTransformer(nn.Module):
self,
config: Siglip2VisionConfig,
quant_config: QuantizationConfig | None = None,
multimodal_config: MultiModalConfig | None = None,
prefix: str = "",
):
super().__init__()
......@@ -582,7 +565,6 @@ class Siglip2VisionTransformer(nn.Module):
self.encoder = Siglip2Encoder(
config,
quant_config=quant_config,
multimodal_config=multimodal_config,
prefix=f"{prefix}.encoder",
)
self.post_layernorm = nn.LayerNorm(embed_dim, eps=config.layer_norm_eps)
......@@ -610,7 +592,6 @@ class Siglip2NavitModel(torch.nn.Module):
self,
config: Siglip2VisionConfig,
quant_config: QuantizationConfig | None = None,
multimodal_config: MultiModalConfig | None = None,
prefix: str = "",
):
super().__init__()
......@@ -618,7 +599,6 @@ class Siglip2NavitModel(torch.nn.Module):
self.vision_model = Siglip2VisionTransformer(
config,
quant_config=quant_config,
multimodal_config=multimodal_config,
prefix=f"{prefix}.vision_model",
)
......
......@@ -19,7 +19,7 @@ from transformers.models.llava import LlavaProcessor
from transformers.processing_utils import ProcessingKwargs, Unpack
from transformers.tokenization_utils_base import PreTokenizedInput, TextInput
from vllm.config import MultiModalConfig, VllmConfig
from vllm.config import VllmConfig
from vllm.model_executor.layers.activation import get_act_fn
from vllm.model_executor.layers.linear import ColumnParallelLinear, RowParallelLinear
from vllm.model_executor.layers.quantization import QuantizationConfig
......@@ -351,7 +351,6 @@ def _build_tarsier_hf_processor(
def init_vision_tower_for_tarsier(
hf_config: TarsierHfConfig, # Use the Tarsier specific config protocol
quant_config: QuantizationConfig | None,
multimodal_config: MultiModalConfig | None,
*,
require_post_norm: bool | None = None,
prefix: str = "",
......@@ -378,7 +377,6 @@ def init_vision_tower_for_tarsier(
return CLIPVisionModel(
vision_config,
quant_config=quant_config,
multimodal_config=multimodal_config,
num_hidden_layers_override=num_hidden_layers_to_init,
require_post_norm=require_post_norm,
prefix=prefix,
......@@ -387,7 +385,6 @@ def init_vision_tower_for_tarsier(
return SiglipVisionModel(
vision_config,
quant_config=quant_config,
multimodal_config=multimodal_config,
num_hidden_layers_override=num_hidden_layers_to_init,
require_post_norm=require_post_norm,
prefix=prefix,
......@@ -420,7 +417,6 @@ class TarsierForConditionalGeneration(nn.Module, SupportsMultiModal, SupportsPP)
config: TarsierHfConfig = vllm_config.model_config.hf_config
quant_config = vllm_config.quant_config
multimodal_config = vllm_config.model_config.multimodal_config
self.config = config # Storing the Tarsier-specific HF config
......@@ -428,7 +424,6 @@ class TarsierForConditionalGeneration(nn.Module, SupportsMultiModal, SupportsPP)
self.vision_tower = init_vision_tower_for_tarsier(
config,
quant_config=quant_config,
multimodal_config=multimodal_config,
require_post_norm=False,
prefix=maybe_prefix(prefix, "vision_tower"),
)
......
......@@ -10,7 +10,7 @@ from typing import Final, Generic, Literal, Protocol, TypeAlias, TypeVar
import torch
from transformers import PretrainedConfig
from vllm.config import VllmConfig
from vllm.config import MultiModalConfig, VllmConfig, get_current_vllm_config
from vllm.distributed import (
get_tensor_model_parallel_rank,
get_tensor_model_parallel_world_size,
......@@ -79,7 +79,7 @@ def get_vision_encoder_info(hf_config: VisionLanguageConfig) -> VisionEncoderInf
raise NotImplementedError(msg)
def get_vit_attn_backend(
def _get_vit_attn_backend(
head_size: int,
dtype: torch.dtype,
*,
......@@ -95,6 +95,52 @@ def get_vit_attn_backend(
)
def get_vit_attn_backend(
head_size: int,
dtype: torch.dtype,
) -> AttentionBackendEnum:
"""
Get the attention backend for Vision Transformer.
"""
try:
vllm_config: VllmConfig = get_current_vllm_config()
multimodal_config: MultiModalConfig | None = (
vllm_config.model_config.multimodal_config
)
except AssertionError:
multimodal_config = None
attn_backend_override = (
multimodal_config.mm_encoder_attn_backend
if multimodal_config is not None
else None
)
attn_backend = _get_vit_attn_backend(
head_size,
dtype,
attn_backend_override=attn_backend_override,
)
return attn_backend
def is_vit_use_data_parallel():
"""
Get the tensor parallel type for Vision Transformer.
"""
try:
vllm_config: VllmConfig = get_current_vllm_config()
multimodal_config: MultiModalConfig | None = (
vllm_config.model_config.multimodal_config
)
except AssertionError:
multimodal_config = None
mm_encoder_tp_mode = (
multimodal_config.mm_encoder_tp_mode if multimodal_config is not None else None
)
return mm_encoder_tp_mode == "data"
def should_torch_compile_mm_vit(vllm_config: VllmConfig) -> bool:
"""Callable to be passed to `@support_torch_compile`'s `enable_if` argument."""
return vllm_config.compilation_config.compile_mm_encoder
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment