Unverified Commit 9ad7f89f authored by Isotr0py's avatar Isotr0py Committed by GitHub
Browse files

[Models]: Make Multimodal config implicit in ViT implementation (#31972)


Signed-off-by: default avatarIsotr0py <mozf@mail2.sysu.edu.cn>
parent 6450b536
...@@ -205,7 +205,6 @@ class MiniMaxVL01ForConditionalGeneration(nn.Module, SupportsMultiModal, Support ...@@ -205,7 +205,6 @@ class MiniMaxVL01ForConditionalGeneration(nn.Module, SupportsMultiModal, Support
self.vision_tower = init_vision_tower_for_llava( self.vision_tower = init_vision_tower_for_llava(
config, config,
quant_config=quant_config, quant_config=quant_config,
multimodal_config=multimodal_config,
require_post_norm=False, require_post_norm=False,
prefix=maybe_prefix(prefix, "vision_tower"), prefix=maybe_prefix(prefix, "vision_tower"),
) )
......
...@@ -16,7 +16,7 @@ from transformers import ( ...@@ -16,7 +16,7 @@ from transformers import (
from transformers.models.pixtral import PixtralProcessor from transformers.models.pixtral import PixtralProcessor
from vllm.config import VllmConfig from vllm.config import VllmConfig
from vllm.config.multimodal import BaseDummyOptions, MultiModalConfig from vllm.config.multimodal import BaseDummyOptions
from vllm.model_executor.layers.activation import get_act_fn from vllm.model_executor.layers.activation import get_act_fn
from vllm.model_executor.layers.layernorm import RMSNorm from vllm.model_executor.layers.layernorm import RMSNorm
from vllm.model_executor.layers.linear import ColumnParallelLinear, RowParallelLinear from vllm.model_executor.layers.linear import ColumnParallelLinear, RowParallelLinear
...@@ -382,7 +382,6 @@ def _get_num_hidden_layers(hf_config: LlavaLikeConfig) -> int: ...@@ -382,7 +382,6 @@ def _get_num_hidden_layers(hf_config: LlavaLikeConfig) -> int:
def init_vision_tower_for_llava( def init_vision_tower_for_llava(
hf_config: LlavaLikeConfig, hf_config: LlavaLikeConfig,
quant_config: QuantizationConfig | None, quant_config: QuantizationConfig | None,
multimodal_config: MultiModalConfig | None,
*, *,
require_post_norm: bool | None = None, require_post_norm: bool | None = None,
prefix: str = "", prefix: str = "",
...@@ -397,7 +396,6 @@ def init_vision_tower_for_llava( ...@@ -397,7 +396,6 @@ def init_vision_tower_for_llava(
return PixtralHFVisionModel( return PixtralHFVisionModel(
vision_config, vision_config,
quant_config=quant_config, quant_config=quant_config,
multimodal_config=multimodal_config,
num_hidden_layers_override=num_hidden_layers, num_hidden_layers_override=num_hidden_layers,
require_post_norm=require_post_norm, require_post_norm=require_post_norm,
prefix=prefix, prefix=prefix,
...@@ -461,7 +459,6 @@ class Mistral3ForConditionalGeneration( ...@@ -461,7 +459,6 @@ class Mistral3ForConditionalGeneration(
self.vision_tower = init_vision_tower_for_llava( self.vision_tower = init_vision_tower_for_llava(
config, config,
quant_config=quant_config, quant_config=quant_config,
multimodal_config=multimodal_config,
require_post_norm=False, require_post_norm=False,
prefix=maybe_prefix(prefix, "vision_tower"), prefix=maybe_prefix(prefix, "vision_tower"),
) )
......
...@@ -52,7 +52,6 @@ import torch.nn.functional as F ...@@ -52,7 +52,6 @@ import torch.nn.functional as F
from transformers.activations import ACT2FN from transformers.activations import ACT2FN
from transformers.modeling_utils import PreTrainedModel from transformers.modeling_utils import PreTrainedModel
from vllm.config import MultiModalConfig
from vllm.distributed import divide, get_tensor_model_parallel_world_size from vllm.distributed import divide, get_tensor_model_parallel_world_size
from vllm.model_executor.layers.attention.mm_encoder_attention import MMEncoderAttention from vllm.model_executor.layers.attention.mm_encoder_attention import MMEncoderAttention
from vllm.model_executor.layers.conv import Conv2dLayer from vllm.model_executor.layers.conv import Conv2dLayer
...@@ -62,6 +61,7 @@ from vllm.model_executor.layers.linear import ( ...@@ -62,6 +61,7 @@ from vllm.model_executor.layers.linear import (
RowParallelLinear, RowParallelLinear,
) )
from vllm.model_executor.models.utils import maybe_prefix from vllm.model_executor.models.utils import maybe_prefix
from vllm.model_executor.models.vision import is_vit_use_data_parallel
from vllm.platforms import current_platform from vllm.platforms import current_platform
from vllm.transformers_utils.configs.moonvit import MoonViTConfig from vllm.transformers_utils.configs.moonvit import MoonViTConfig
...@@ -308,11 +308,10 @@ class MLP2(nn.Module): ...@@ -308,11 +308,10 @@ class MLP2(nn.Module):
activation, activation,
bias: bool = True, bias: bool = True,
prefix: str = "", prefix: str = "",
use_data_parallel: bool = False,
): ):
super().__init__() super().__init__()
assert len(dims) == 3 assert len(dims) == 3
self.use_data_parallel = use_data_parallel self.use_data_parallel = is_vit_use_data_parallel()
self.fc0 = ColumnParallelLinear( self.fc0 = ColumnParallelLinear(
dims[0], dims[0],
dims[1], dims[1],
...@@ -343,17 +342,12 @@ class MoonVitEncoderLayer(nn.Module): ...@@ -343,17 +342,12 @@ class MoonVitEncoderLayer(nn.Module):
hidden_dim: int, hidden_dim: int,
mlp_dim: int, mlp_dim: int,
prefix: str = "", prefix: str = "",
multimodal_config: MultiModalConfig | None = None,
*, *,
activation=F.gelu, activation=F.gelu,
attn_bias: bool = False, attn_bias: bool = False,
): ):
super().__init__() super().__init__()
self.use_data_parallel = ( self.use_data_parallel = is_vit_use_data_parallel()
multimodal_config.mm_encoder_tp_mode == "data"
if multimodal_config
else False
)
self.num_heads = num_heads self.num_heads = num_heads
self.hidden_dim = hidden_dim self.hidden_dim = hidden_dim
...@@ -369,7 +363,6 @@ class MoonVitEncoderLayer(nn.Module): ...@@ -369,7 +363,6 @@ class MoonVitEncoderLayer(nn.Module):
[hidden_dim, mlp_dim, hidden_dim], [hidden_dim, mlp_dim, hidden_dim],
activation, activation,
prefix=f"{prefix}.mlp", prefix=f"{prefix}.mlp",
use_data_parallel=self.use_data_parallel,
) )
self.wqkv = QKVParallelLinear( self.wqkv = QKVParallelLinear(
hidden_size=hidden_dim, hidden_size=hidden_dim,
...@@ -391,7 +384,6 @@ class MoonVitEncoderLayer(nn.Module): ...@@ -391,7 +384,6 @@ class MoonVitEncoderLayer(nn.Module):
num_heads=self.num_attention_heads_per_partition, num_heads=self.num_attention_heads_per_partition,
head_size=self.hidden_size_per_attention_head, head_size=self.hidden_size_per_attention_head,
scale=self.hidden_size_per_attention_head**-0.5, scale=self.hidden_size_per_attention_head**-0.5,
multimodal_config=multimodal_config,
prefix=f"{prefix}.attn", prefix=f"{prefix}.attn",
) )
...@@ -469,7 +461,6 @@ class MoonVitEncoder(nn.Module): ...@@ -469,7 +461,6 @@ class MoonVitEncoder(nn.Module):
num_layers: int, num_layers: int,
block_cfg: dict, block_cfg: dict,
prefix: str = "", prefix: str = "",
multimodal_config: MultiModalConfig | None = None,
) -> None: ) -> None:
super().__init__() super().__init__()
...@@ -479,7 +470,6 @@ class MoonVitEncoder(nn.Module): ...@@ -479,7 +470,6 @@ class MoonVitEncoder(nn.Module):
self.blocks = nn.ModuleList( self.blocks = nn.ModuleList(
[ [
MoonVitEncoderLayer( MoonVitEncoderLayer(
multimodal_config=multimodal_config,
prefix=f"{prefix}.blocks.{layer_idx}", prefix=f"{prefix}.blocks.{layer_idx}",
**block_cfg, **block_cfg,
) )
...@@ -550,7 +540,6 @@ class MoonVitPretrainedModel(PreTrainedModel): ...@@ -550,7 +540,6 @@ class MoonVitPretrainedModel(PreTrainedModel):
def __init__( def __init__(
self, self,
config: MoonViTConfig, config: MoonViTConfig,
multimodal_config: MultiModalConfig | None = None,
prefix: str = "", prefix: str = "",
*inputs, *inputs,
**kwargs, **kwargs,
...@@ -579,7 +568,6 @@ class MoonVitPretrainedModel(PreTrainedModel): ...@@ -579,7 +568,6 @@ class MoonVitPretrainedModel(PreTrainedModel):
"attn_bias": True, "attn_bias": True,
}, },
prefix=f"{prefix}.encoder", prefix=f"{prefix}.encoder",
multimodal_config=multimodal_config,
) )
def forward( def forward(
......
...@@ -244,7 +244,6 @@ class OpenCUAForConditionalGeneration(Qwen2_5_VLForConditionalGeneration): ...@@ -244,7 +244,6 @@ class OpenCUAForConditionalGeneration(Qwen2_5_VLForConditionalGeneration):
vision_config=config.vision_config, vision_config=config.vision_config,
norm_eps=getattr(config, "rms_norm_eps", 1e-6), norm_eps=getattr(config, "rms_norm_eps", 1e-6),
quant_config=self.quant_config, quant_config=self.quant_config,
multimodal_config=self.multimodal_config,
prefix=maybe_prefix(prefix, "visual"), prefix=maybe_prefix(prefix, "visual"),
) )
......
...@@ -10,7 +10,7 @@ import torch ...@@ -10,7 +10,7 @@ import torch
import torch.nn as nn import torch.nn as nn
from transformers import BaseImageProcessor, BatchFeature, PretrainedConfig from transformers import BaseImageProcessor, BatchFeature, PretrainedConfig
from vllm.config import MultiModalConfig, VllmConfig from vllm.config import VllmConfig
from vllm.config.multimodal import BaseDummyOptions from vllm.config.multimodal import BaseDummyOptions
from vllm.model_executor.layers.linear import ReplicatedLinear from vllm.model_executor.layers.linear import ReplicatedLinear
from vllm.model_executor.layers.quantization import QuantizationConfig from vllm.model_executor.layers.quantization import QuantizationConfig
...@@ -103,7 +103,6 @@ class VisualTokenizer(torch.nn.Module): ...@@ -103,7 +103,6 @@ class VisualTokenizer(torch.nn.Module):
config: PretrainedConfig, config: PretrainedConfig,
visual_vocab_size: int, visual_vocab_size: int,
quant_config: QuantizationConfig | None = None, quant_config: QuantizationConfig | None = None,
multimodal_config: MultiModalConfig | None = None,
prefix: str = "", prefix: str = "",
): ):
super().__init__() super().__init__()
...@@ -111,7 +110,6 @@ class VisualTokenizer(torch.nn.Module): ...@@ -111,7 +110,6 @@ class VisualTokenizer(torch.nn.Module):
self.vit = self._init_backbone( self.vit = self._init_backbone(
config=config, config=config,
quant_config=quant_config, quant_config=quant_config,
multimodal_config=multimodal_config,
prefix=f"{prefix}.vit", prefix=f"{prefix}.vit",
) )
# reserved tokens for INDICATOR_IDS # reserved tokens for INDICATOR_IDS
...@@ -130,7 +128,6 @@ class VisualTokenizer(torch.nn.Module): ...@@ -130,7 +128,6 @@ class VisualTokenizer(torch.nn.Module):
self, self,
config: PretrainedConfig, config: PretrainedConfig,
quant_config: QuantizationConfig | None = None, quant_config: QuantizationConfig | None = None,
multimodal_config: QuantizationConfig | None = None,
prefix: str = "", prefix: str = "",
): ):
model_type = config.model_type model_type = config.model_type
...@@ -138,7 +135,6 @@ class VisualTokenizer(torch.nn.Module): ...@@ -138,7 +135,6 @@ class VisualTokenizer(torch.nn.Module):
return Siglip2NavitModel( return Siglip2NavitModel(
config=config, config=config,
quant_config=quant_config, quant_config=quant_config,
multimodal_config=multimodal_config,
prefix=prefix, prefix=prefix,
) )
raise ValueError(f"Unsupported visual tokenizer model_type: {model_type}") raise ValueError(f"Unsupported visual tokenizer model_type: {model_type}")
...@@ -464,7 +460,6 @@ class Ovis2_5(nn.Module, SupportsMultiModal, SupportsPP): ...@@ -464,7 +460,6 @@ class Ovis2_5(nn.Module, SupportsMultiModal, SupportsPP):
super().__init__() super().__init__()
config = vllm_config.model_config.hf_config config = vllm_config.model_config.hf_config
quant_config = vllm_config.quant_config quant_config = vllm_config.quant_config
multimodal_config = vllm_config.model_config.multimodal_config
self.config: PretrainedConfig = config self.config: PretrainedConfig = config
...@@ -478,7 +473,6 @@ class Ovis2_5(nn.Module, SupportsMultiModal, SupportsPP): ...@@ -478,7 +473,6 @@ class Ovis2_5(nn.Module, SupportsMultiModal, SupportsPP):
self.visual_tokenizer = VisualTokenizer( self.visual_tokenizer = VisualTokenizer(
config=config.vit_config, config=config.vit_config,
visual_vocab_size=config.visual_vocab_size, visual_vocab_size=config.visual_vocab_size,
multimodal_config=multimodal_config,
quant_config=quant_config, quant_config=quant_config,
prefix=f"{prefix}.visual_tokenizer", prefix=f"{prefix}.visual_tokenizer",
) )
......
...@@ -30,7 +30,7 @@ from transformers.modeling_outputs import ( ...@@ -30,7 +30,7 @@ from transformers.modeling_outputs import (
) )
from transformers.utils import torch_int from transformers.utils import torch_int
from vllm.config import MultiModalConfig, VllmConfig from vllm.config import VllmConfig
from vllm.config.multimodal import BaseDummyOptions from vllm.config.multimodal import BaseDummyOptions
from vllm.distributed import parallel_state from vllm.distributed import parallel_state
from vllm.distributed import utils as dist_utils from vllm.distributed import utils as dist_utils
...@@ -532,7 +532,6 @@ class SiglipAttention(nn.Module): ...@@ -532,7 +532,6 @@ class SiglipAttention(nn.Module):
num_heads: int, num_heads: int,
projection_size: int, projection_size: int,
quant_config: QuantizationConfig | None = None, quant_config: QuantizationConfig | None = None,
multimodal_config: MultiModalConfig | None = None,
prefix: str = "", prefix: str = "",
) -> None: ) -> None:
super().__init__() super().__init__()
...@@ -565,7 +564,6 @@ class SiglipAttention(nn.Module): ...@@ -565,7 +564,6 @@ class SiglipAttention(nn.Module):
num_heads=self.num_attention_heads_per_partition, num_heads=self.num_attention_heads_per_partition,
head_size=self.hidden_size_per_attention_head, head_size=self.hidden_size_per_attention_head,
scale=self.hidden_size_per_attention_head**-0.5, scale=self.hidden_size_per_attention_head**-0.5,
multimodal_config=multimodal_config,
prefix=f"{prefix}.attn", prefix=f"{prefix}.attn",
) )
self.apply_rotary_emb = ApplyRotaryEmb( self.apply_rotary_emb = ApplyRotaryEmb(
...@@ -662,7 +660,6 @@ class SiglipEncoderLayer(nn.Module): ...@@ -662,7 +660,6 @@ class SiglipEncoderLayer(nn.Module):
self, self,
config: PretrainedConfig, config: PretrainedConfig,
quant_config: QuantizationConfig | None = None, quant_config: QuantizationConfig | None = None,
multimodal_config: MultiModalConfig | None = None,
prefix: str = "", prefix: str = "",
): ):
super().__init__() super().__init__()
...@@ -673,14 +670,12 @@ class SiglipEncoderLayer(nn.Module): ...@@ -673,14 +670,12 @@ class SiglipEncoderLayer(nn.Module):
num_heads=config.num_attention_heads, num_heads=config.num_attention_heads,
projection_size=config.hidden_size, projection_size=config.hidden_size,
quant_config=quant_config, quant_config=quant_config,
multimodal_config=multimodal_config,
prefix=f"{prefix}.self_attn", prefix=f"{prefix}.self_attn",
) )
self.layer_norm2 = nn.LayerNorm(self.embed_dim, eps=config.layer_norm_eps) self.layer_norm2 = nn.LayerNorm(self.embed_dim, eps=config.layer_norm_eps)
self.mlp = SiglipMLP( self.mlp = SiglipMLP(
config, config,
quant_config=quant_config, quant_config=quant_config,
multimodal_config=multimodal_config,
prefix=f"{prefix}.mlp", prefix=f"{prefix}.mlp",
) )
...@@ -718,7 +713,6 @@ class SiglipEncoder(nn.Module): ...@@ -718,7 +713,6 @@ class SiglipEncoder(nn.Module):
self, self,
config: PretrainedConfig, config: PretrainedConfig,
quant_config: QuantizationConfig | None = None, quant_config: QuantizationConfig | None = None,
multimodal_config: MultiModalConfig | None = None,
prefix: str = "", prefix: str = "",
): ):
super().__init__() super().__init__()
...@@ -727,13 +721,9 @@ class SiglipEncoder(nn.Module): ...@@ -727,13 +721,9 @@ class SiglipEncoder(nn.Module):
num_heads = config.num_attention_heads num_heads = config.num_attention_heads
head_dim = embed_dim // num_heads head_dim = embed_dim // num_heads
attn_backend_override = (
multimodal_config.mm_encoder_attn_backend if multimodal_config else None
)
self.attn_backend = get_vit_attn_backend( self.attn_backend = get_vit_attn_backend(
head_size=head_dim, head_size=head_dim,
dtype=torch.get_default_dtype(), dtype=torch.get_default_dtype(),
attn_backend_override=attn_backend_override,
) )
if self.attn_backend not in { if self.attn_backend not in {
AttentionBackendEnum.FLASH_ATTN, AttentionBackendEnum.FLASH_ATTN,
...@@ -748,7 +738,6 @@ class SiglipEncoder(nn.Module): ...@@ -748,7 +738,6 @@ class SiglipEncoder(nn.Module):
SiglipEncoderLayer( SiglipEncoderLayer(
config, config,
quant_config=quant_config, quant_config=quant_config,
multimodal_config=multimodal_config,
prefix=f"{prefix}.layers.{layer_idx}", prefix=f"{prefix}.layers.{layer_idx}",
) )
for layer_idx in range(config.num_hidden_layers) for layer_idx in range(config.num_hidden_layers)
...@@ -830,7 +819,6 @@ class SiglipVisionTransformer(nn.Module): ...@@ -830,7 +819,6 @@ class SiglipVisionTransformer(nn.Module):
self, self,
config: PretrainedConfig, config: PretrainedConfig,
quant_config: QuantizationConfig | None = None, quant_config: QuantizationConfig | None = None,
multimodal_config: MultiModalConfig | None = None,
prefix: str = "", prefix: str = "",
): ):
super().__init__() super().__init__()
...@@ -841,7 +829,6 @@ class SiglipVisionTransformer(nn.Module): ...@@ -841,7 +829,6 @@ class SiglipVisionTransformer(nn.Module):
self.encoder = SiglipEncoder( self.encoder = SiglipEncoder(
config, config,
quant_config=quant_config, quant_config=quant_config,
multimodal_config=multimodal_config,
prefix=f"{prefix}.encoder", prefix=f"{prefix}.encoder",
) )
self.post_layernorm = nn.LayerNorm(embed_dim, eps=config.layer_norm_eps) self.post_layernorm = nn.LayerNorm(embed_dim, eps=config.layer_norm_eps)
...@@ -880,7 +867,6 @@ class SiglipVisionModel(nn.Module): ...@@ -880,7 +867,6 @@ class SiglipVisionModel(nn.Module):
self, self,
config, config,
quant_config: QuantizationConfig | None = None, quant_config: QuantizationConfig | None = None,
multimodal_config: MultiModalConfig | None = None,
prefix: str = "", prefix: str = "",
): ):
super().__init__() super().__init__()
...@@ -888,7 +874,6 @@ class SiglipVisionModel(nn.Module): ...@@ -888,7 +874,6 @@ class SiglipVisionModel(nn.Module):
self.vision_model = SiglipVisionTransformer( self.vision_model = SiglipVisionTransformer(
config, config,
quant_config=quant_config, quant_config=quant_config,
multimodal_config=multimodal_config,
prefix=f"{prefix}.vision_model", prefix=f"{prefix}.vision_model",
) )
self.quant_config = quant_config self.quant_config = quant_config
...@@ -1010,16 +995,13 @@ class PaddleOCRVLForConditionalGeneration(nn.Module, SupportsMultiModal, Support ...@@ -1010,16 +995,13 @@ class PaddleOCRVLForConditionalGeneration(nn.Module, SupportsMultiModal, Support
super().__init__() super().__init__()
config = vllm_config.model_config.hf_config config = vllm_config.model_config.hf_config
quant_config = vllm_config.quant_config quant_config = vllm_config.quant_config
multimodal_config = vllm_config.model_config.multimodal_config
self.config = config self.config = config
self.multimodal_config = multimodal_config
with self._mark_tower_model(vllm_config, "image"): with self._mark_tower_model(vllm_config, "image"):
self.visual = SiglipVisionModel( self.visual = SiglipVisionModel(
config=config.vision_config, config=config.vision_config,
quant_config=quant_config, quant_config=quant_config,
multimodal_config=multimodal_config,
prefix=maybe_prefix(prefix, "visual"), prefix=maybe_prefix(prefix, "visual"),
) )
self.mlp_AR = Projector(config, config.vision_config) self.mlp_AR = Projector(config, config.vision_config)
......
...@@ -29,7 +29,7 @@ from transformers import ( ...@@ -29,7 +29,7 @@ from transformers import (
) )
from vllm.config import VllmConfig from vllm.config import VllmConfig
from vllm.config.multimodal import BaseDummyOptions, MultiModalConfig from vllm.config.multimodal import BaseDummyOptions
from vllm.logger import init_logger from vllm.logger import init_logger
from vllm.model_executor.layers.quantization import QuantizationConfig from vllm.model_executor.layers.quantization import QuantizationConfig
from vllm.model_executor.layers.vocab_parallel_embedding import VocabParallelEmbedding from vllm.model_executor.layers.vocab_parallel_embedding import VocabParallelEmbedding
...@@ -96,7 +96,6 @@ CLIP_VIT_LARGE_PATCH14_336_CONFIG = CLIPVisionConfig( ...@@ -96,7 +96,6 @@ CLIP_VIT_LARGE_PATCH14_336_CONFIG = CLIPVisionConfig(
def _init_img_processor( def _init_img_processor(
hf_config: PretrainedConfig, hf_config: PretrainedConfig,
quant_config: QuantizationConfig | None, quant_config: QuantizationConfig | None,
multimodal_config: MultiModalConfig | None,
prefix: str = "", prefix: str = "",
) -> CLIPVisionModel: ) -> CLIPVisionModel:
clip_config = CLIP_VIT_LARGE_PATCH14_336_CONFIG clip_config = CLIP_VIT_LARGE_PATCH14_336_CONFIG
...@@ -111,7 +110,6 @@ def _init_img_processor( ...@@ -111,7 +110,6 @@ def _init_img_processor(
img_processor = CLIPVisionModel( img_processor = CLIPVisionModel(
clip_config, clip_config,
quant_config=quant_config, quant_config=quant_config,
multimodal_config=multimodal_config,
num_hidden_layers_override=num_hidden_layers, num_hidden_layers_override=num_hidden_layers,
prefix=prefix, prefix=prefix,
) )
...@@ -170,7 +168,6 @@ class Phi3HDImageEmbedding(nn.Module): ...@@ -170,7 +168,6 @@ class Phi3HDImageEmbedding(nn.Module):
self, self,
config: PretrainedConfig, config: PretrainedConfig,
quant_config: QuantizationConfig | None, quant_config: QuantizationConfig | None,
multimodal_config: MultiModalConfig | None,
prefix: str = "", prefix: str = "",
) -> None: ) -> None:
super().__init__() super().__init__()
...@@ -181,7 +178,6 @@ class Phi3HDImageEmbedding(nn.Module): ...@@ -181,7 +178,6 @@ class Phi3HDImageEmbedding(nn.Module):
self.img_processor = _init_img_processor( self.img_processor = _init_img_processor(
config, config,
quant_config=quant_config, quant_config=quant_config,
multimodal_config=multimodal_config,
prefix=f"{prefix}.img_processor", prefix=f"{prefix}.img_processor",
) )
...@@ -596,7 +592,6 @@ class Phi3VForCausalLM(nn.Module, SupportsMultiModal, SupportsPP, SupportsQuant) ...@@ -596,7 +592,6 @@ class Phi3VForCausalLM(nn.Module, SupportsMultiModal, SupportsPP, SupportsQuant)
self.vision_embed_tokens = Phi3HDImageEmbedding( self.vision_embed_tokens = Phi3HDImageEmbedding(
config, config,
quant_config=quant_config, quant_config=quant_config,
multimodal_config=multimodal_config,
prefix=maybe_prefix(prefix, "model.vision_embed_tokens"), prefix=maybe_prefix(prefix, "model.vision_embed_tokens"),
) )
......
...@@ -28,7 +28,7 @@ from transformers.models.pixtral.modeling_pixtral import ( ...@@ -28,7 +28,7 @@ from transformers.models.pixtral.modeling_pixtral import (
from transformers.tokenization_utils_base import TextInput from transformers.tokenization_utils_base import TextInput
from vllm.config import VllmConfig from vllm.config import VllmConfig
from vllm.config.multimodal import BaseDummyOptions, MultiModalConfig from vllm.config.multimodal import BaseDummyOptions
from vllm.distributed import divide, get_tensor_model_parallel_world_size from vllm.distributed import divide, get_tensor_model_parallel_world_size
from vllm.model_executor.layers.activation import get_act_and_mul_fn from vllm.model_executor.layers.activation import get_act_and_mul_fn
from vllm.model_executor.layers.conv import Conv2dLayer from vllm.model_executor.layers.conv import Conv2dLayer
...@@ -74,6 +74,7 @@ from .utils import init_vllm_registered_model, maybe_prefix ...@@ -74,6 +74,7 @@ from .utils import init_vllm_registered_model, maybe_prefix
from .vision import ( from .vision import (
VisionEncoderInfo, VisionEncoderInfo,
VisionFeatureSelectStrategy, VisionFeatureSelectStrategy,
is_vit_use_data_parallel,
resolve_visual_encoder_outputs, resolve_visual_encoder_outputs,
) )
...@@ -1065,17 +1066,12 @@ class PixtralHFMLP(nn.Module): ...@@ -1065,17 +1066,12 @@ class PixtralHFMLP(nn.Module):
self, self,
config: PixtralVisionConfig, config: PixtralVisionConfig,
quant_config: QuantizationConfig | None = None, quant_config: QuantizationConfig | None = None,
multimodal_config: MultiModalConfig | None = None,
*, *,
prefix: str = "", prefix: str = "",
) -> None: ) -> None:
super().__init__() super().__init__()
use_data_parallel = ( use_data_parallel = is_vit_use_data_parallel()
multimodal_config.mm_encoder_tp_mode == "data"
if multimodal_config
else False
)
assert config.intermediate_size is not None assert config.intermediate_size is not None
self.gate_up_proj = MergedColumnParallelLinear( self.gate_up_proj = MergedColumnParallelLinear(
...@@ -1108,7 +1104,6 @@ class PixtralHFAttention(nn.Module): ...@@ -1108,7 +1104,6 @@ class PixtralHFAttention(nn.Module):
self, self,
config: PixtralVisionConfig, config: PixtralVisionConfig,
quant_config: QuantizationConfig | None = None, quant_config: QuantizationConfig | None = None,
multimodal_config: MultiModalConfig | None = None,
*, *,
prefix: str = "", prefix: str = "",
) -> None: ) -> None:
...@@ -1120,11 +1115,7 @@ class PixtralHFAttention(nn.Module): ...@@ -1120,11 +1115,7 @@ class PixtralHFAttention(nn.Module):
self.head_dim = config.hidden_size // config.num_attention_heads self.head_dim = config.hidden_size // config.num_attention_heads
assert self.total_num_heads * self.head_dim == config.hidden_size assert self.total_num_heads * self.head_dim == config.hidden_size
use_data_parallel = ( use_data_parallel = is_vit_use_data_parallel()
multimodal_config.mm_encoder_tp_mode == "data"
if multimodal_config
else False
)
self.qkv_proj = QKVParallelLinear( self.qkv_proj = QKVParallelLinear(
hidden_size=config.hidden_size, hidden_size=config.hidden_size,
head_size=self.head_dim, head_size=self.head_dim,
...@@ -1189,7 +1180,6 @@ class PixtralHFTransformerBlock(nn.Module): ...@@ -1189,7 +1180,6 @@ class PixtralHFTransformerBlock(nn.Module):
self, self,
config: PixtralVisionConfig, config: PixtralVisionConfig,
quant_config: QuantizationConfig | None = None, quant_config: QuantizationConfig | None = None,
multimodal_config: MultiModalConfig | None = None,
*, *,
prefix: str = "", prefix: str = "",
) -> None: ) -> None:
...@@ -1199,13 +1189,11 @@ class PixtralHFTransformerBlock(nn.Module): ...@@ -1199,13 +1189,11 @@ class PixtralHFTransformerBlock(nn.Module):
self.attention = PixtralHFAttention( self.attention = PixtralHFAttention(
config, config,
quant_config=quant_config, quant_config=quant_config,
multimodal_config=multimodal_config,
prefix=f"{prefix}.attention", prefix=f"{prefix}.attention",
) )
self.feed_forward = PixtralHFMLP( self.feed_forward = PixtralHFMLP(
config, config,
quant_config=quant_config, quant_config=quant_config,
multimodal_config=multimodal_config,
prefix=f"{prefix}.feed_forward", prefix=f"{prefix}.feed_forward",
) )
self.ffn_norm = RMSNorm(config.hidden_size, eps=1e-5) self.ffn_norm = RMSNorm(config.hidden_size, eps=1e-5)
...@@ -1232,7 +1220,6 @@ class PixtralHFTransformer(nn.Module): ...@@ -1232,7 +1220,6 @@ class PixtralHFTransformer(nn.Module):
self, self,
config: PixtralVisionConfig, config: PixtralVisionConfig,
quant_config: QuantizationConfig | None = None, quant_config: QuantizationConfig | None = None,
multimodal_config: MultiModalConfig | None = None,
*, *,
num_hidden_layers_override: int | None = None, num_hidden_layers_override: int | None = None,
prefix: str = "", prefix: str = "",
...@@ -1249,7 +1236,6 @@ class PixtralHFTransformer(nn.Module): ...@@ -1249,7 +1236,6 @@ class PixtralHFTransformer(nn.Module):
PixtralHFTransformerBlock( PixtralHFTransformerBlock(
config=config, config=config,
quant_config=quant_config, quant_config=quant_config,
multimodal_config=multimodal_config,
prefix=f"{prefix}.layers.{layer_idx}", prefix=f"{prefix}.layers.{layer_idx}",
) )
for layer_idx in range(num_hidden_layers) for layer_idx in range(num_hidden_layers)
...@@ -1281,7 +1267,6 @@ class PixtralHFVisionModel(nn.Module): ...@@ -1281,7 +1267,6 @@ class PixtralHFVisionModel(nn.Module):
self, self,
config: PixtralVisionConfig, config: PixtralVisionConfig,
quant_config: QuantizationConfig | None = None, quant_config: QuantizationConfig | None = None,
multimodal_config: MultiModalConfig | None = None,
*, *,
num_hidden_layers_override: int | None = None, num_hidden_layers_override: int | None = None,
require_post_norm: bool | None = None, require_post_norm: bool | None = None,
...@@ -1302,7 +1287,6 @@ class PixtralHFVisionModel(nn.Module): ...@@ -1302,7 +1287,6 @@ class PixtralHFVisionModel(nn.Module):
self.transformer = PixtralHFTransformer( self.transformer = PixtralHFTransformer(
config, config,
quant_config=quant_config, quant_config=quant_config,
multimodal_config=multimodal_config,
num_hidden_layers_override=num_hidden_layers_override, num_hidden_layers_override=num_hidden_layers_override,
prefix=f"{prefix}.transformer", prefix=f"{prefix}.transformer",
) )
......
...@@ -846,7 +846,6 @@ class Qwen2_5OmniThinkerForConditionalGeneration( ...@@ -846,7 +846,6 @@ class Qwen2_5OmniThinkerForConditionalGeneration(
norm_eps=getattr(thinker_config.text_config, "rms_norm_eps", 1e-6), norm_eps=getattr(thinker_config.text_config, "rms_norm_eps", 1e-6),
quant_config=quant_config, quant_config=quant_config,
prefix=maybe_prefix(prefix, "visual"), prefix=maybe_prefix(prefix, "visual"),
multimodal_config=multimodal_config,
) )
with self._mark_language_model(vllm_config): with self._mark_language_model(vllm_config):
......
...@@ -43,7 +43,7 @@ from transformers.models.qwen2_5_vl.configuration_qwen2_5_vl import ( ...@@ -43,7 +43,7 @@ from transformers.models.qwen2_5_vl.configuration_qwen2_5_vl import (
) )
from vllm.compilation.decorators import support_torch_compile from vllm.compilation.decorators import support_torch_compile
from vllm.config import MultiModalConfig, VllmConfig from vllm.config import VllmConfig
from vllm.distributed import parallel_state from vllm.distributed import parallel_state
from vllm.distributed import utils as dist_utils from vllm.distributed import utils as dist_utils
from vllm.forward_context import set_forward_context from vllm.forward_context import set_forward_context
...@@ -109,6 +109,7 @@ from .utils import ( ...@@ -109,6 +109,7 @@ from .utils import (
) )
from .vision import ( from .vision import (
get_vit_attn_backend, get_vit_attn_backend,
is_vit_use_data_parallel,
run_dp_sharded_mrope_vision_model, run_dp_sharded_mrope_vision_model,
) )
...@@ -266,15 +267,10 @@ class Qwen2_5_VisionMLP(nn.Module): ...@@ -266,15 +267,10 @@ class Qwen2_5_VisionMLP(nn.Module):
bias: bool = False, bias: bool = False,
act_fn: Callable[[torch.Tensor], torch.Tensor] = F.silu, act_fn: Callable[[torch.Tensor], torch.Tensor] = F.silu,
quant_config: QuantizationConfig | None = None, quant_config: QuantizationConfig | None = None,
multimodal_config: MultiModalConfig | None = None,
prefix: str = "", prefix: str = "",
): ):
super().__init__() super().__init__()
use_data_parallel = ( use_data_parallel = is_vit_use_data_parallel()
multimodal_config.mm_encoder_tp_mode == "data"
if multimodal_config
else False
)
self.gate_up_proj = MergedColumnParallelLinear( self.gate_up_proj = MergedColumnParallelLinear(
input_size=in_features, input_size=in_features,
output_sizes=[hidden_features] * 2, # [gate_proj, up_proj] output_sizes=[hidden_features] * 2, # [gate_proj, up_proj]
...@@ -308,16 +304,11 @@ class Qwen2_5_VisionAttention(nn.Module): ...@@ -308,16 +304,11 @@ class Qwen2_5_VisionAttention(nn.Module):
num_heads: int, num_heads: int,
projection_size: int, projection_size: int,
quant_config: QuantizationConfig | None = None, quant_config: QuantizationConfig | None = None,
multimodal_config: MultiModalConfig | None = None,
prefix: str = "", prefix: str = "",
) -> None: ) -> None:
super().__init__() super().__init__()
# Per attention head and per partition values. # Per attention head and per partition values.
use_data_parallel = ( use_data_parallel = is_vit_use_data_parallel()
multimodal_config.mm_encoder_tp_mode == "data"
if multimodal_config
else False
)
self.tp_size = ( self.tp_size = (
1 1
if use_data_parallel if use_data_parallel
...@@ -354,7 +345,6 @@ class Qwen2_5_VisionAttention(nn.Module): ...@@ -354,7 +345,6 @@ class Qwen2_5_VisionAttention(nn.Module):
num_heads=self.num_attention_heads_per_partition, num_heads=self.num_attention_heads_per_partition,
head_size=self.hidden_size_per_attention_head, head_size=self.hidden_size_per_attention_head,
scale=self.hidden_size_per_attention_head**-0.5, scale=self.hidden_size_per_attention_head**-0.5,
multimodal_config=multimodal_config,
) )
self.apply_rotary_emb = ApplyRotaryEmb(enforce_enable=True) self.apply_rotary_emb = ApplyRotaryEmb(enforce_enable=True)
...@@ -435,7 +425,6 @@ class Qwen2_5_VisionBlock(nn.Module): ...@@ -435,7 +425,6 @@ class Qwen2_5_VisionBlock(nn.Module):
act_fn: Callable[[torch.Tensor], torch.Tensor] = F.silu, act_fn: Callable[[torch.Tensor], torch.Tensor] = F.silu,
norm_layer: Callable[[int], nn.Module] | None = None, norm_layer: Callable[[int], nn.Module] | None = None,
quant_config: QuantizationConfig | None = None, quant_config: QuantizationConfig | None = None,
multimodal_config: MultiModalConfig | None = None,
prefix: str = "", prefix: str = "",
) -> None: ) -> None:
super().__init__() super().__init__()
...@@ -448,7 +437,6 @@ class Qwen2_5_VisionBlock(nn.Module): ...@@ -448,7 +437,6 @@ class Qwen2_5_VisionBlock(nn.Module):
num_heads=num_heads, num_heads=num_heads,
projection_size=dim, projection_size=dim,
quant_config=quant_config, quant_config=quant_config,
multimodal_config=multimodal_config,
prefix=f"{prefix}.attn", prefix=f"{prefix}.attn",
) )
self.mlp = Qwen2_5_VisionMLP( self.mlp = Qwen2_5_VisionMLP(
...@@ -457,7 +445,6 @@ class Qwen2_5_VisionBlock(nn.Module): ...@@ -457,7 +445,6 @@ class Qwen2_5_VisionBlock(nn.Module):
act_fn=act_fn, act_fn=act_fn,
bias=True, bias=True,
quant_config=quant_config, quant_config=quant_config,
multimodal_config=multimodal_config,
prefix=f"{prefix}.mlp", prefix=f"{prefix}.mlp",
) )
...@@ -530,15 +517,10 @@ class Qwen2_5_VisionPatchMerger(nn.Module): ...@@ -530,15 +517,10 @@ class Qwen2_5_VisionPatchMerger(nn.Module):
norm_layer: Callable[[int], nn.Module] | None = None, norm_layer: Callable[[int], nn.Module] | None = None,
spatial_merge_size: int = 2, spatial_merge_size: int = 2,
quant_config: QuantizationConfig | None = None, quant_config: QuantizationConfig | None = None,
multimodal_config: MultiModalConfig | None = None,
prefix: str = "", prefix: str = "",
) -> None: ) -> None:
super().__init__() super().__init__()
use_data_parallel = ( use_data_parallel = is_vit_use_data_parallel()
multimodal_config.mm_encoder_tp_mode == "data"
if multimodal_config
else False
)
self.hidden_size = context_dim * (spatial_merge_size**2) self.hidden_size = context_dim * (spatial_merge_size**2)
if norm_layer is None: if norm_layer is None:
norm_layer = partial(nn.LayerNorm, eps=1e-6) norm_layer = partial(nn.LayerNorm, eps=1e-6)
...@@ -579,7 +561,6 @@ class Qwen2_5_VisionTransformer(nn.Module): ...@@ -579,7 +561,6 @@ class Qwen2_5_VisionTransformer(nn.Module):
vision_config: Qwen2_5_VLVisionConfig, vision_config: Qwen2_5_VLVisionConfig,
norm_eps: float = 1e-6, norm_eps: float = 1e-6,
quant_config: QuantizationConfig | None = None, quant_config: QuantizationConfig | None = None,
multimodal_config: MultiModalConfig | None = None,
prefix: str = "", prefix: str = "",
) -> None: ) -> None:
super().__init__() super().__init__()
...@@ -620,15 +601,9 @@ class Qwen2_5_VisionTransformer(nn.Module): ...@@ -620,15 +601,9 @@ class Qwen2_5_VisionTransformer(nn.Module):
rope_parameters={"partial_rotary_factor": 0.5}, rope_parameters={"partial_rotary_factor": 0.5},
) )
attn_backend_override = (
multimodal_config.mm_encoder_attn_backend
if multimodal_config is not None
else None
)
self.attn_backend = get_vit_attn_backend( self.attn_backend = get_vit_attn_backend(
head_size=head_dim, head_size=head_dim,
dtype=torch.get_default_dtype(), dtype=torch.get_default_dtype(),
attn_backend_override=attn_backend_override,
) )
if self.attn_backend not in { if self.attn_backend not in {
...@@ -650,7 +625,6 @@ class Qwen2_5_VisionTransformer(nn.Module): ...@@ -650,7 +625,6 @@ class Qwen2_5_VisionTransformer(nn.Module):
act_fn=get_act_and_mul_fn(vision_config.hidden_act), act_fn=get_act_and_mul_fn(vision_config.hidden_act),
norm_layer=norm_layer, norm_layer=norm_layer,
quant_config=quant_config, quant_config=quant_config,
multimodal_config=multimodal_config,
prefix=f"{prefix}.blocks.{layer_idx}", prefix=f"{prefix}.blocks.{layer_idx}",
) )
for layer_idx in range(depth) for layer_idx in range(depth)
...@@ -664,7 +638,6 @@ class Qwen2_5_VisionTransformer(nn.Module): ...@@ -664,7 +638,6 @@ class Qwen2_5_VisionTransformer(nn.Module):
norm_layer=norm_layer, norm_layer=norm_layer,
spatial_merge_size=self.spatial_merge_size, spatial_merge_size=self.spatial_merge_size,
quant_config=quant_config, quant_config=quant_config,
multimodal_config=multimodal_config,
prefix=f"{prefix}.merger", prefix=f"{prefix}.merger",
) )
...@@ -1152,7 +1125,6 @@ class Qwen2_5_VLForConditionalGeneration( ...@@ -1152,7 +1125,6 @@ class Qwen2_5_VLForConditionalGeneration(
norm_eps=getattr(config, "rms_norm_eps", 1e-6), norm_eps=getattr(config, "rms_norm_eps", 1e-6),
quant_config=self.quant_config, quant_config=self.quant_config,
prefix=maybe_prefix(prefix, "visual"), prefix=maybe_prefix(prefix, "visual"),
multimodal_config=multimodal_config,
) )
with self._mark_language_model(vllm_config): with self._mark_language_model(vllm_config):
......
...@@ -43,7 +43,7 @@ from transformers.models.qwen2_vl.configuration_qwen2_vl import ( ...@@ -43,7 +43,7 @@ from transformers.models.qwen2_vl.configuration_qwen2_vl import (
from transformers.models.qwen2_vl.image_processing_qwen2_vl import smart_resize from transformers.models.qwen2_vl.image_processing_qwen2_vl import smart_resize
from transformers.models.qwen2_vl.video_processing_qwen2_vl import Qwen2VLVideoProcessor from transformers.models.qwen2_vl.video_processing_qwen2_vl import Qwen2VLVideoProcessor
from vllm.config import MultiModalConfig, VllmConfig from vllm.config import VllmConfig
from vllm.config.multimodal import BaseDummyOptions from vllm.config.multimodal import BaseDummyOptions
from vllm.distributed import parallel_state, tensor_model_parallel_all_gather from vllm.distributed import parallel_state, tensor_model_parallel_all_gather
from vllm.distributed import utils as dist_utils from vllm.distributed import utils as dist_utils
...@@ -106,6 +106,7 @@ from .utils import ( ...@@ -106,6 +106,7 @@ from .utils import (
) )
from .vision import ( from .vision import (
get_vit_attn_backend, get_vit_attn_backend,
is_vit_use_data_parallel,
run_dp_sharded_mrope_vision_model, run_dp_sharded_mrope_vision_model,
) )
...@@ -247,15 +248,10 @@ class Qwen2VisionMLP(nn.Module): ...@@ -247,15 +248,10 @@ class Qwen2VisionMLP(nn.Module):
hidden_features: int, hidden_features: int,
act_layer: type[nn.Module] = QuickGELU, act_layer: type[nn.Module] = QuickGELU,
quant_config: QuantizationConfig | None = None, quant_config: QuantizationConfig | None = None,
multimodal_config: MultiModalConfig | None = None,
prefix: str = "", prefix: str = "",
): ):
super().__init__() super().__init__()
use_data_parallel = ( use_data_parallel = is_vit_use_data_parallel()
multimodal_config.mm_encoder_tp_mode == "data"
if multimodal_config
else False
)
self.fc1 = ColumnParallelLinear( self.fc1 = ColumnParallelLinear(
in_features, in_features,
hidden_features, hidden_features,
...@@ -286,16 +282,11 @@ class Qwen2VisionAttention(nn.Module): ...@@ -286,16 +282,11 @@ class Qwen2VisionAttention(nn.Module):
num_heads: int, num_heads: int,
projection_size: int, projection_size: int,
quant_config: QuantizationConfig | None = None, quant_config: QuantizationConfig | None = None,
multimodal_config: MultiModalConfig | None = None,
prefix: str = "", prefix: str = "",
) -> None: ) -> None:
super().__init__() super().__init__()
# Per attention head and per partition values. # Per attention head and per partition values.
use_data_parallel = ( use_data_parallel = is_vit_use_data_parallel()
multimodal_config.mm_encoder_tp_mode == "data"
if multimodal_config
else False
)
self.tp_size = ( self.tp_size = (
1 1
if use_data_parallel if use_data_parallel
...@@ -328,7 +319,6 @@ class Qwen2VisionAttention(nn.Module): ...@@ -328,7 +319,6 @@ class Qwen2VisionAttention(nn.Module):
num_heads=self.num_attention_heads_per_partition, num_heads=self.num_attention_heads_per_partition,
head_size=self.hidden_size_per_attention_head, head_size=self.hidden_size_per_attention_head,
scale=self.hidden_size_per_attention_head**-0.5, scale=self.hidden_size_per_attention_head**-0.5,
multimodal_config=multimodal_config,
) )
self.apply_rotary_emb = ApplyRotaryEmb(enforce_enable=True) self.apply_rotary_emb = ApplyRotaryEmb(enforce_enable=True)
...@@ -409,7 +399,6 @@ class Qwen2VisionBlock(nn.Module): ...@@ -409,7 +399,6 @@ class Qwen2VisionBlock(nn.Module):
act_layer: type[nn.Module] = QuickGELU, act_layer: type[nn.Module] = QuickGELU,
norm_layer: Callable[[int], nn.Module] | None = None, norm_layer: Callable[[int], nn.Module] | None = None,
quant_config: QuantizationConfig | None = None, quant_config: QuantizationConfig | None = None,
multimodal_config: MultiModalConfig | None = None,
prefix: str = "", prefix: str = "",
) -> None: ) -> None:
super().__init__() super().__init__()
...@@ -424,7 +413,6 @@ class Qwen2VisionBlock(nn.Module): ...@@ -424,7 +413,6 @@ class Qwen2VisionBlock(nn.Module):
num_heads=num_heads, num_heads=num_heads,
projection_size=dim, projection_size=dim,
quant_config=quant_config, quant_config=quant_config,
multimodal_config=multimodal_config,
prefix=f"{prefix}.attn", prefix=f"{prefix}.attn",
) )
self.mlp = Qwen2VisionMLP( self.mlp = Qwen2VisionMLP(
...@@ -432,7 +420,6 @@ class Qwen2VisionBlock(nn.Module): ...@@ -432,7 +420,6 @@ class Qwen2VisionBlock(nn.Module):
mlp_hidden_dim, mlp_hidden_dim,
act_layer=act_layer, act_layer=act_layer,
quant_config=quant_config, quant_config=quant_config,
multimodal_config=multimodal_config,
prefix=f"{prefix}.mlp", prefix=f"{prefix}.mlp",
) )
...@@ -493,15 +480,10 @@ class Qwen2VisionPatchMerger(nn.Module): ...@@ -493,15 +480,10 @@ class Qwen2VisionPatchMerger(nn.Module):
norm_layer: Callable[[int], nn.Module] | None = None, norm_layer: Callable[[int], nn.Module] | None = None,
spatial_merge_size: int = 2, spatial_merge_size: int = 2,
quant_config: QuantizationConfig | None = None, quant_config: QuantizationConfig | None = None,
multimodal_config: MultiModalConfig | None = None,
prefix: str = "", prefix: str = "",
) -> None: ) -> None:
super().__init__() super().__init__()
use_data_parallel = ( use_data_parallel = is_vit_use_data_parallel()
multimodal_config.mm_encoder_tp_mode == "data"
if multimodal_config
else False
)
self.hidden_size = context_dim * (spatial_merge_size**2) self.hidden_size = context_dim * (spatial_merge_size**2)
if norm_layer is None: if norm_layer is None:
norm_layer = partial(nn.LayerNorm, eps=1e-6) norm_layer = partial(nn.LayerNorm, eps=1e-6)
...@@ -545,7 +527,6 @@ class Qwen2VisionTransformer(nn.Module): ...@@ -545,7 +527,6 @@ class Qwen2VisionTransformer(nn.Module):
vision_config: Qwen2VLVisionConfig, vision_config: Qwen2VLVisionConfig,
norm_eps: float = 1e-6, norm_eps: float = 1e-6,
quant_config: QuantizationConfig | None = None, quant_config: QuantizationConfig | None = None,
multimodal_config: MultiModalConfig | None = None,
prefix: str = "", prefix: str = "",
) -> None: ) -> None:
super().__init__() super().__init__()
...@@ -560,11 +541,7 @@ class Qwen2VisionTransformer(nn.Module): ...@@ -560,11 +541,7 @@ class Qwen2VisionTransformer(nn.Module):
num_heads = vision_config.num_heads num_heads = vision_config.num_heads
mlp_ratio = vision_config.mlp_ratio mlp_ratio = vision_config.mlp_ratio
self.use_data_parallel = ( self.use_data_parallel = is_vit_use_data_parallel()
multimodal_config.mm_encoder_tp_mode == "data"
if multimodal_config
else False
)
self.out_hidden_size = vision_config.hidden_size self.out_hidden_size = vision_config.hidden_size
self.spatial_merge_size = spatial_merge_size self.spatial_merge_size = spatial_merge_size
...@@ -596,7 +573,6 @@ class Qwen2VisionTransformer(nn.Module): ...@@ -596,7 +573,6 @@ class Qwen2VisionTransformer(nn.Module):
norm_layer=norm_layer, norm_layer=norm_layer,
quant_config=quant_config, quant_config=quant_config,
prefix=f"{prefix}.blocks.{layer_idx}", prefix=f"{prefix}.blocks.{layer_idx}",
multimodal_config=multimodal_config,
) )
for layer_idx in range(depth) for layer_idx in range(depth)
] ]
...@@ -607,15 +583,10 @@ class Qwen2VisionTransformer(nn.Module): ...@@ -607,15 +583,10 @@ class Qwen2VisionTransformer(nn.Module):
norm_layer=norm_layer, norm_layer=norm_layer,
quant_config=quant_config, quant_config=quant_config,
prefix=f"{prefix}.merger", prefix=f"{prefix}.merger",
multimodal_config=multimodal_config,
)
attn_backend_override = (
multimodal_config.mm_encoder_attn_backend if multimodal_config else None
) )
self.attn_backend = get_vit_attn_backend( self.attn_backend = get_vit_attn_backend(
head_size=head_dim, head_size=head_dim,
dtype=torch.get_default_dtype(), dtype=torch.get_default_dtype(),
attn_backend_override=attn_backend_override,
) )
@property @property
...@@ -1238,7 +1209,6 @@ class Qwen2VLForConditionalGeneration( ...@@ -1238,7 +1209,6 @@ class Qwen2VLForConditionalGeneration(
config.vision_config, config.vision_config,
norm_eps=getattr(config, "rms_norm_eps", 1e-6), norm_eps=getattr(config, "rms_norm_eps", 1e-6),
quant_config=quant_config, quant_config=quant_config,
multimodal_config=multimodal_config,
prefix=maybe_prefix(prefix, "visual"), prefix=maybe_prefix(prefix, "visual"),
) )
......
...@@ -48,7 +48,7 @@ from transformers import __version__ as TRANSFORMERS_VERSION ...@@ -48,7 +48,7 @@ from transformers import __version__ as TRANSFORMERS_VERSION
# isort: on # isort: on
from vllm.compilation.decorators import support_torch_compile from vllm.compilation.decorators import support_torch_compile
from vllm.config import MultiModalConfig, VllmConfig from vllm.config import VllmConfig
from vllm.distributed import get_pp_group, get_tensor_model_parallel_world_size from vllm.distributed import get_pp_group, get_tensor_model_parallel_world_size
from vllm.logger import init_logger from vllm.logger import init_logger
from vllm.model_executor.layers.activation import _ACTIVATION_REGISTRY from vllm.model_executor.layers.activation import _ACTIVATION_REGISTRY
...@@ -160,7 +160,6 @@ class Qwen3OmniMoeAudioAttention(nn.Module): ...@@ -160,7 +160,6 @@ class Qwen3OmniMoeAudioAttention(nn.Module):
def __init__( def __init__(
self, self,
config: Qwen3OmniMoeAudioEncoderConfig, config: Qwen3OmniMoeAudioEncoderConfig,
multimodal_config: MultiModalConfig | None = None,
prefix: str = "", prefix: str = "",
): ):
super().__init__() super().__init__()
...@@ -198,7 +197,6 @@ class Qwen3OmniMoeAudioAttention(nn.Module): ...@@ -198,7 +197,6 @@ class Qwen3OmniMoeAudioAttention(nn.Module):
num_heads=self.num_local_heads, num_heads=self.num_local_heads,
head_size=self.head_dim, head_size=self.head_dim,
scale=self.scaling, scale=self.scaling,
multimodal_config=multimodal_config,
) )
def forward( def forward(
...@@ -233,13 +231,12 @@ class Qwen3OmniMoeAudioEncoderLayer(nn.Module): ...@@ -233,13 +231,12 @@ class Qwen3OmniMoeAudioEncoderLayer(nn.Module):
def __init__( def __init__(
self, self,
config: Qwen3OmniMoeAudioEncoderConfig, config: Qwen3OmniMoeAudioEncoderConfig,
multimodal_config: MultiModalConfig | None = None,
prefix: str = "", prefix: str = "",
): ):
super().__init__() super().__init__()
self.embed_dim = config.d_model self.embed_dim = config.d_model
self.self_attn = Qwen3OmniMoeAudioAttention( self.self_attn = Qwen3OmniMoeAudioAttention(
config, multimodal_config=multimodal_config, prefix=f"{prefix}.self_attn" config, prefix=f"{prefix}.self_attn"
) )
self.self_attn_layer_norm = nn.LayerNorm(self.embed_dim) self.self_attn_layer_norm = nn.LayerNorm(self.embed_dim)
self.activation_fn = _ACTIVATION_REGISTRY[config.activation_function] self.activation_fn = _ACTIVATION_REGISTRY[config.activation_function]
...@@ -301,7 +298,6 @@ class Qwen3OmniMoeAudioEncoder(nn.Module): ...@@ -301,7 +298,6 @@ class Qwen3OmniMoeAudioEncoder(nn.Module):
def __init__( def __init__(
self, self,
config: Qwen3OmniMoeAudioEncoderConfig, config: Qwen3OmniMoeAudioEncoderConfig,
multimodal_config: MultiModalConfig | None = None,
prefix: str = "", prefix: str = "",
): ):
super().__init__() super().__init__()
...@@ -345,7 +341,6 @@ class Qwen3OmniMoeAudioEncoder(nn.Module): ...@@ -345,7 +341,6 @@ class Qwen3OmniMoeAudioEncoder(nn.Module):
[ [
Qwen3OmniMoeAudioEncoderLayer( Qwen3OmniMoeAudioEncoderLayer(
config, config,
multimodal_config=multimodal_config,
prefix=f"{prefix}.layers.{i}", prefix=f"{prefix}.layers.{i}",
) )
for i in range(config.encoder_layers) for i in range(config.encoder_layers)
...@@ -359,15 +354,9 @@ class Qwen3OmniMoeAudioEncoder(nn.Module): ...@@ -359,15 +354,9 @@ class Qwen3OmniMoeAudioEncoder(nn.Module):
self.proj2 = nn.Linear(config.d_model, config.output_dim) self.proj2 = nn.Linear(config.d_model, config.output_dim)
# Get attention backend # Get attention backend
attn_backend_override = (
multimodal_config.mm_encoder_attn_backend
if multimodal_config is not None
else None
)
self.attn_backend = get_vit_attn_backend( self.attn_backend = get_vit_attn_backend(
head_size=config.d_model // config.encoder_attention_heads, head_size=config.d_model // config.encoder_attention_heads,
dtype=torch.get_default_dtype(), dtype=torch.get_default_dtype(),
attn_backend_override=attn_backend_override,
) )
def compute_attn_mask_seqlen(self, cu_seqlens: torch.Tensor) -> torch.Tensor | None: def compute_attn_mask_seqlen(self, cu_seqlens: torch.Tensor) -> torch.Tensor | None:
...@@ -601,7 +590,6 @@ class Qwen3_VisionBlock(nn.Module): ...@@ -601,7 +590,6 @@ class Qwen3_VisionBlock(nn.Module):
mlp_hidden_dim: int, mlp_hidden_dim: int,
act_fn: Callable[[torch.Tensor], torch.Tensor] = F.silu, act_fn: Callable[[torch.Tensor], torch.Tensor] = F.silu,
norm_layer: Callable[[int], nn.Module] | None = None, norm_layer: Callable[[int], nn.Module] | None = None,
multimodal_config: MultiModalConfig | None = None,
quant_config: QuantizationConfig | None = None, quant_config: QuantizationConfig | None = None,
prefix: str = "", prefix: str = "",
) -> None: ) -> None:
...@@ -615,7 +603,6 @@ class Qwen3_VisionBlock(nn.Module): ...@@ -615,7 +603,6 @@ class Qwen3_VisionBlock(nn.Module):
num_heads=num_heads, num_heads=num_heads,
projection_size=dim, projection_size=dim,
quant_config=quant_config, quant_config=quant_config,
multimodal_config=multimodal_config,
prefix=f"{prefix}.attn", prefix=f"{prefix}.attn",
) )
self.mlp = Qwen3_VisionMLP( self.mlp = Qwen3_VisionMLP(
...@@ -710,7 +697,6 @@ class Qwen3Omni_VisionTransformer(nn.Module): ...@@ -710,7 +697,6 @@ class Qwen3Omni_VisionTransformer(nn.Module):
vision_config, vision_config,
norm_eps: float = 1e-6, norm_eps: float = 1e-6,
quant_config: QuantizationConfig | None = None, quant_config: QuantizationConfig | None = None,
multimodal_config: MultiModalConfig | None = None,
prefix: str = "", prefix: str = "",
) -> None: ) -> None:
super().__init__() super().__init__()
...@@ -758,7 +744,6 @@ class Qwen3Omni_VisionTransformer(nn.Module): ...@@ -758,7 +744,6 @@ class Qwen3Omni_VisionTransformer(nn.Module):
act_fn=_ACTIVATION_REGISTRY[vision_config.hidden_act], act_fn=_ACTIVATION_REGISTRY[vision_config.hidden_act],
norm_layer=norm_layer, norm_layer=norm_layer,
quant_config=quant_config, quant_config=quant_config,
multimodal_config=multimodal_config,
prefix=f"{prefix}.blocks.{layer_idx}", prefix=f"{prefix}.blocks.{layer_idx}",
) )
for layer_idx in range(vision_config.depth) for layer_idx in range(vision_config.depth)
...@@ -788,16 +773,9 @@ class Qwen3Omni_VisionTransformer(nn.Module): ...@@ -788,16 +773,9 @@ class Qwen3Omni_VisionTransformer(nn.Module):
] ]
) )
attn_backend_override = (
multimodal_config.mm_encoder_attn_backend
if multimodal_config is not None
else None
)
self.attn_backend = get_vit_attn_backend( self.attn_backend = get_vit_attn_backend(
head_size=head_dim, head_size=head_dim,
dtype=torch.get_default_dtype(), dtype=torch.get_default_dtype(),
attn_backend_override=attn_backend_override,
) )
@property @property
...@@ -1617,7 +1595,6 @@ class Qwen3OmniMoeThinkerForConditionalGeneration( ...@@ -1617,7 +1595,6 @@ class Qwen3OmniMoeThinkerForConditionalGeneration(
with self._mark_tower_model(vllm_config, "audio"): with self._mark_tower_model(vllm_config, "audio"):
self.audio_tower = Qwen3OmniMoeAudioEncoder( self.audio_tower = Qwen3OmniMoeAudioEncoder(
thinker_config.audio_config, thinker_config.audio_config,
multimodal_config=multimodal_config,
prefix=maybe_prefix(prefix, "audio_tower"), prefix=maybe_prefix(prefix, "audio_tower"),
) )
...@@ -1638,7 +1615,6 @@ class Qwen3OmniMoeThinkerForConditionalGeneration( ...@@ -1638,7 +1615,6 @@ class Qwen3OmniMoeThinkerForConditionalGeneration(
norm_eps=getattr(thinker_config.text_config, "rms_norm_eps", 1e-6), norm_eps=getattr(thinker_config.text_config, "rms_norm_eps", 1e-6),
quant_config=quant_config, quant_config=quant_config,
prefix=maybe_prefix(prefix, "visual"), prefix=maybe_prefix(prefix, "visual"),
multimodal_config=multimodal_config,
) )
# register buffer for deepstack # register buffer for deepstack
......
...@@ -49,7 +49,7 @@ from transformers.models.qwen3_vl.video_processing_qwen3_vl import ( ...@@ -49,7 +49,7 @@ from transformers.models.qwen3_vl.video_processing_qwen3_vl import (
from transformers.video_utils import VideoMetadata from transformers.video_utils import VideoMetadata
from vllm.compilation.decorators import support_torch_compile from vllm.compilation.decorators import support_torch_compile
from vllm.config import MultiModalConfig, VllmConfig from vllm.config import VllmConfig
from vllm.config.multimodal import BaseDummyOptions, VideoDummyOptions from vllm.config.multimodal import BaseDummyOptions, VideoDummyOptions
from vllm.distributed import get_pp_group from vllm.distributed import get_pp_group
from vllm.logger import init_logger from vllm.logger import init_logger
...@@ -123,6 +123,7 @@ from .utils import ( ...@@ -123,6 +123,7 @@ from .utils import (
) )
from .vision import ( from .vision import (
get_vit_attn_backend, get_vit_attn_backend,
is_vit_use_data_parallel,
run_dp_sharded_mrope_vision_model, run_dp_sharded_mrope_vision_model,
) )
...@@ -169,15 +170,10 @@ class Qwen3_VisionMLP(nn.Module): ...@@ -169,15 +170,10 @@ class Qwen3_VisionMLP(nn.Module):
bias: bool = False, bias: bool = False,
act_fn: Callable[[torch.Tensor], torch.Tensor] = F.silu, act_fn: Callable[[torch.Tensor], torch.Tensor] = F.silu,
quant_config: QuantizationConfig | None = None, quant_config: QuantizationConfig | None = None,
multimodal_config: MultiModalConfig | None = None,
prefix: str = "", prefix: str = "",
): ):
super().__init__() super().__init__()
use_data_parallel = ( use_data_parallel = is_vit_use_data_parallel()
multimodal_config.mm_encoder_tp_mode == "data"
if multimodal_config
else False
)
self.linear_fc1 = ColumnParallelLinear( self.linear_fc1 = ColumnParallelLinear(
in_features, in_features,
hidden_features, hidden_features,
...@@ -211,7 +207,6 @@ class Qwen3_VisionBlock(nn.Module): ...@@ -211,7 +207,6 @@ class Qwen3_VisionBlock(nn.Module):
mlp_hidden_dim: int, mlp_hidden_dim: int,
act_fn: Callable[[torch.Tensor], torch.Tensor] = F.silu, act_fn: Callable[[torch.Tensor], torch.Tensor] = F.silu,
norm_layer: Callable[[int], nn.Module] | None = None, norm_layer: Callable[[int], nn.Module] | None = None,
multimodal_config: MultiModalConfig | None = None,
quant_config: QuantizationConfig | None = None, quant_config: QuantizationConfig | None = None,
prefix: str = "", prefix: str = "",
) -> None: ) -> None:
...@@ -225,7 +220,6 @@ class Qwen3_VisionBlock(nn.Module): ...@@ -225,7 +220,6 @@ class Qwen3_VisionBlock(nn.Module):
num_heads=num_heads, num_heads=num_heads,
projection_size=dim, projection_size=dim,
quant_config=quant_config, quant_config=quant_config,
multimodal_config=multimodal_config,
prefix=f"{prefix}.attn", prefix=f"{prefix}.attn",
) )
self.mlp = Qwen3_VisionMLP( self.mlp = Qwen3_VisionMLP(
...@@ -234,7 +228,6 @@ class Qwen3_VisionBlock(nn.Module): ...@@ -234,7 +228,6 @@ class Qwen3_VisionBlock(nn.Module):
act_fn=act_fn, act_fn=act_fn,
bias=True, bias=True,
quant_config=quant_config, quant_config=quant_config,
multimodal_config=multimodal_config,
prefix=f"{prefix}.mlp", prefix=f"{prefix}.mlp",
) )
...@@ -267,15 +260,10 @@ class Qwen3_VisionPatchMerger(nn.Module): ...@@ -267,15 +260,10 @@ class Qwen3_VisionPatchMerger(nn.Module):
spatial_merge_size: int = 2, spatial_merge_size: int = 2,
use_postshuffle_norm: bool = False, use_postshuffle_norm: bool = False,
quant_config: QuantizationConfig | None = None, quant_config: QuantizationConfig | None = None,
multimodal_config: MultiModalConfig | None = None,
prefix: str = "", prefix: str = "",
) -> None: ) -> None:
super().__init__() super().__init__()
use_data_parallel = ( use_data_parallel = is_vit_use_data_parallel()
multimodal_config.mm_encoder_tp_mode == "data"
if multimodal_config
else False
)
self.hidden_size = context_dim * (spatial_merge_size**2) self.hidden_size = context_dim * (spatial_merge_size**2)
self.use_postshuffle_norm = use_postshuffle_norm self.use_postshuffle_norm = use_postshuffle_norm
...@@ -321,7 +309,6 @@ class Qwen3_VisionTransformer(nn.Module): ...@@ -321,7 +309,6 @@ class Qwen3_VisionTransformer(nn.Module):
vision_config: Qwen3VLVisionConfig, vision_config: Qwen3VLVisionConfig,
norm_eps: float = 1e-6, norm_eps: float = 1e-6,
quant_config: QuantizationConfig | None = None, quant_config: QuantizationConfig | None = None,
multimodal_config: MultiModalConfig | None = None,
prefix: str = "", prefix: str = "",
) -> None: ) -> None:
super().__init__() super().__init__()
...@@ -365,7 +352,6 @@ class Qwen3_VisionTransformer(nn.Module): ...@@ -365,7 +352,6 @@ class Qwen3_VisionTransformer(nn.Module):
norm_layer=norm_layer, norm_layer=norm_layer,
spatial_merge_size=self.spatial_merge_size, spatial_merge_size=self.spatial_merge_size,
quant_config=quant_config, quant_config=quant_config,
multimodal_config=multimodal_config,
prefix=f"{prefix}.merger", prefix=f"{prefix}.merger",
) )
...@@ -378,20 +364,15 @@ class Qwen3_VisionTransformer(nn.Module): ...@@ -378,20 +364,15 @@ class Qwen3_VisionTransformer(nn.Module):
use_postshuffle_norm=True, use_postshuffle_norm=True,
norm_layer=norm_layer, norm_layer=norm_layer,
quant_config=quant_config, quant_config=quant_config,
multimodal_config=multimodal_config,
prefix=f"{prefix}.deepstack_merger_list.{layer_idx}", prefix=f"{prefix}.deepstack_merger_list.{layer_idx}",
) )
for layer_idx in range(len(self.deepstack_visual_indexes)) for layer_idx in range(len(self.deepstack_visual_indexes))
] ]
) )
attn_backend_override = (
multimodal_config.mm_encoder_attn_backend if multimodal_config else None
)
self.attn_backend = get_vit_attn_backend( self.attn_backend = get_vit_attn_backend(
head_size=head_dim, head_size=head_dim,
dtype=torch.get_default_dtype(), dtype=torch.get_default_dtype(),
attn_backend_override=attn_backend_override,
) )
if self.attn_backend not in { if self.attn_backend not in {
...@@ -411,7 +392,6 @@ class Qwen3_VisionTransformer(nn.Module): ...@@ -411,7 +392,6 @@ class Qwen3_VisionTransformer(nn.Module):
act_fn=_ACTIVATION_REGISTRY[vision_config.hidden_act], act_fn=_ACTIVATION_REGISTRY[vision_config.hidden_act],
norm_layer=norm_layer, norm_layer=norm_layer,
quant_config=quant_config, quant_config=quant_config,
multimodal_config=multimodal_config,
prefix=f"{prefix}.blocks.{layer_idx}", prefix=f"{prefix}.blocks.{layer_idx}",
) )
for layer_idx in range(vision_config.depth) for layer_idx in range(vision_config.depth)
...@@ -1291,7 +1271,6 @@ class Qwen3VLForConditionalGeneration( ...@@ -1291,7 +1271,6 @@ class Qwen3VLForConditionalGeneration(
config.vision_config, config.vision_config,
norm_eps=getattr(config, "rms_norm_eps", 1e-6), norm_eps=getattr(config, "rms_norm_eps", 1e-6),
quant_config=quant_config, quant_config=quant_config,
multimodal_config=multimodal_config,
prefix=maybe_prefix(prefix, "visual"), prefix=maybe_prefix(prefix, "visual"),
) )
......
...@@ -446,7 +446,6 @@ class Qwen3VLMoeForConditionalGeneration( ...@@ -446,7 +446,6 @@ class Qwen3VLMoeForConditionalGeneration(
config.vision_config, config.vision_config,
norm_eps=getattr(config, "rms_norm_eps", 1e-6), norm_eps=getattr(config, "rms_norm_eps", 1e-6),
quant_config=quant_config, quant_config=quant_config,
multimodal_config=multimodal_config,
prefix=maybe_prefix(prefix, "visual"), prefix=maybe_prefix(prefix, "visual"),
) )
......
...@@ -16,7 +16,7 @@ from transformers import ( ...@@ -16,7 +16,7 @@ from transformers import (
) )
from vllm.config import VllmConfig from vllm.config import VllmConfig
from vllm.config.multimodal import BaseDummyOptions, MultiModalConfig from vllm.config.multimodal import BaseDummyOptions
from vllm.distributed import divide, get_tensor_model_parallel_world_size from vllm.distributed import divide, get_tensor_model_parallel_world_size
from vllm.model_executor.layers.activation import get_act_fn from vllm.model_executor.layers.activation import get_act_fn
from vllm.model_executor.layers.attention.encoder_only_attention import ( from vllm.model_executor.layers.attention.encoder_only_attention import (
...@@ -64,6 +64,7 @@ from .vision import ( ...@@ -64,6 +64,7 @@ from .vision import (
VisionFeatureSelectStrategy, VisionFeatureSelectStrategy,
VisionFeatureSelectStrategyStr, VisionFeatureSelectStrategyStr,
get_num_selected_vision_tokens, get_num_selected_vision_tokens,
is_vit_use_data_parallel,
resolve_visual_encoder_outputs, resolve_visual_encoder_outputs,
) )
...@@ -356,7 +357,6 @@ class SiglipAttention(nn.Module): ...@@ -356,7 +357,6 @@ class SiglipAttention(nn.Module):
self, self,
config: SiglipVisionConfig | SiglipTextConfig, config: SiglipVisionConfig | SiglipTextConfig,
quant_config: QuantizationConfig | None = None, quant_config: QuantizationConfig | None = None,
multimodal_config: MultiModalConfig | None = None,
*, *,
prefix: str = "", prefix: str = "",
attn_cls: type[EncoderOnlyAttention] | type[MMEncoderAttention], attn_cls: type[EncoderOnlyAttention] | type[MMEncoderAttention],
...@@ -376,11 +376,7 @@ class SiglipAttention(nn.Module): ...@@ -376,11 +376,7 @@ class SiglipAttention(nn.Module):
self.scale = self.head_dim**-0.5 self.scale = self.head_dim**-0.5
use_data_parallel = ( use_data_parallel = is_vit_use_data_parallel()
multimodal_config.mm_encoder_tp_mode == "data"
if multimodal_config
else False
)
self.qkv_proj = QKVParallelLinear( self.qkv_proj = QKVParallelLinear(
hidden_size=self.embed_dim, hidden_size=self.embed_dim,
head_size=self.head_dim, head_size=self.head_dim,
...@@ -409,7 +405,6 @@ class SiglipAttention(nn.Module): ...@@ -409,7 +405,6 @@ class SiglipAttention(nn.Module):
self.head_dim, self.head_dim,
self.scale, self.scale,
prefix=f"{prefix}.attn", prefix=f"{prefix}.attn",
multimodal_config=multimodal_config,
) )
else: else:
self.attn = attn_cls( self.attn = attn_cls(
...@@ -437,17 +432,12 @@ class SiglipMLP(nn.Module): ...@@ -437,17 +432,12 @@ class SiglipMLP(nn.Module):
self, self,
config: SiglipVisionConfig | SiglipTextConfig, config: SiglipVisionConfig | SiglipTextConfig,
quant_config: QuantizationConfig | None = None, quant_config: QuantizationConfig | None = None,
multimodal_config: MultiModalConfig | None = None,
prefix: str = "", prefix: str = "",
) -> None: ) -> None:
super().__init__() super().__init__()
self.config = config self.config = config
use_data_parallel = ( use_data_parallel = is_vit_use_data_parallel()
multimodal_config.mm_encoder_tp_mode == "data"
if multimodal_config
else False
)
self.activation_fn = get_act_fn(config.hidden_act) self.activation_fn = get_act_fn(config.hidden_act)
# Special handling for BNB and torchao quantization # Special handling for BNB and torchao quantization
...@@ -487,7 +477,6 @@ class SiglipEncoderLayer(nn.Module): ...@@ -487,7 +477,6 @@ class SiglipEncoderLayer(nn.Module):
self, self,
config: SiglipVisionConfig | SiglipTextConfig, config: SiglipVisionConfig | SiglipTextConfig,
quant_config: QuantizationConfig | None = None, quant_config: QuantizationConfig | None = None,
multimodal_config: MultiModalConfig | None = None,
*, *,
prefix: str = "", prefix: str = "",
attn_cls: type[EncoderOnlyAttention] | type[MMEncoderAttention], attn_cls: type[EncoderOnlyAttention] | type[MMEncoderAttention],
...@@ -499,7 +488,6 @@ class SiglipEncoderLayer(nn.Module): ...@@ -499,7 +488,6 @@ class SiglipEncoderLayer(nn.Module):
self.self_attn = SiglipAttention( self.self_attn = SiglipAttention(
config, config,
quant_config=quant_config, quant_config=quant_config,
multimodal_config=multimodal_config,
prefix=f"{prefix}.self_attn", prefix=f"{prefix}.self_attn",
attn_cls=attn_cls, attn_cls=attn_cls,
) )
...@@ -507,7 +495,6 @@ class SiglipEncoderLayer(nn.Module): ...@@ -507,7 +495,6 @@ class SiglipEncoderLayer(nn.Module):
self.mlp = SiglipMLP( self.mlp = SiglipMLP(
config, config,
quant_config=quant_config, quant_config=quant_config,
multimodal_config=multimodal_config,
prefix=f"{prefix}.mlp", prefix=f"{prefix}.mlp",
) )
self.layer_norm2 = nn.LayerNorm(self.embed_dim, eps=config.layer_norm_eps) self.layer_norm2 = nn.LayerNorm(self.embed_dim, eps=config.layer_norm_eps)
...@@ -535,7 +522,6 @@ class SiglipEncoder(nn.Module): ...@@ -535,7 +522,6 @@ class SiglipEncoder(nn.Module):
self, self,
config: SiglipVisionConfig | SiglipTextConfig, config: SiglipVisionConfig | SiglipTextConfig,
quant_config: QuantizationConfig | None = None, quant_config: QuantizationConfig | None = None,
multimodal_config: MultiModalConfig | None = None,
num_hidden_layers_override: int | None = None, num_hidden_layers_override: int | None = None,
*, *,
prefix: str = "", prefix: str = "",
...@@ -555,7 +541,6 @@ class SiglipEncoder(nn.Module): ...@@ -555,7 +541,6 @@ class SiglipEncoder(nn.Module):
SiglipEncoderLayer( SiglipEncoderLayer(
config, config,
quant_config=quant_config, quant_config=quant_config,
multimodal_config=multimodal_config,
prefix=f"{prefix}.layers.{layer_idx}", prefix=f"{prefix}.layers.{layer_idx}",
attn_cls=attn_cls, attn_cls=attn_cls,
) )
...@@ -660,7 +645,6 @@ class SiglipMultiheadAttentionPoolingHead(nn.Module): ...@@ -660,7 +645,6 @@ class SiglipMultiheadAttentionPoolingHead(nn.Module):
self, self,
config: SiglipVisionConfig, config: SiglipVisionConfig,
quant_config: QuantizationConfig | None = None, quant_config: QuantizationConfig | None = None,
multimodal_config: MultiModalConfig | None = None,
prefix: str = "", prefix: str = "",
) -> None: ) -> None:
super().__init__() super().__init__()
...@@ -674,7 +658,6 @@ class SiglipMultiheadAttentionPoolingHead(nn.Module): ...@@ -674,7 +658,6 @@ class SiglipMultiheadAttentionPoolingHead(nn.Module):
self.mlp = SiglipMLP( self.mlp = SiglipMLP(
config=config, config=config,
quant_config=quant_config, quant_config=quant_config,
multimodal_config=multimodal_config,
prefix=f"{prefix}.mlp", prefix=f"{prefix}.mlp",
) )
...@@ -700,7 +683,6 @@ class SiglipVisionTransformer(nn.Module): ...@@ -700,7 +683,6 @@ class SiglipVisionTransformer(nn.Module):
self, self,
config: SiglipVisionConfig, config: SiglipVisionConfig,
quant_config: QuantizationConfig | None = None, quant_config: QuantizationConfig | None = None,
multimodal_config: MultiModalConfig | None = None,
*, *,
num_hidden_layers_override: int | None = None, num_hidden_layers_override: int | None = None,
require_post_norm: bool | None = None, require_post_norm: bool | None = None,
...@@ -717,7 +699,6 @@ class SiglipVisionTransformer(nn.Module): ...@@ -717,7 +699,6 @@ class SiglipVisionTransformer(nn.Module):
self.encoder = SiglipEncoder( self.encoder = SiglipEncoder(
config, config,
quant_config=quant_config, quant_config=quant_config,
multimodal_config=multimodal_config,
num_hidden_layers_override=num_hidden_layers_override, num_hidden_layers_override=num_hidden_layers_override,
prefix=f"{prefix}.encoder", prefix=f"{prefix}.encoder",
attn_cls=MMEncoderAttention, attn_cls=MMEncoderAttention,
...@@ -756,7 +737,6 @@ class SiglipVisionTransformer(nn.Module): ...@@ -756,7 +737,6 @@ class SiglipVisionTransformer(nn.Module):
SiglipMultiheadAttentionPoolingHead( SiglipMultiheadAttentionPoolingHead(
config=config, config=config,
quant_config=quant_config, quant_config=quant_config,
multimodal_config=multimodal_config,
prefix=f"{prefix}.head", prefix=f"{prefix}.head",
) )
if self.use_head if self.use_head
...@@ -870,7 +850,6 @@ class SiglipVisionModel(nn.Module): ...@@ -870,7 +850,6 @@ class SiglipVisionModel(nn.Module):
self, self,
config: SiglipVisionConfig, config: SiglipVisionConfig,
quant_config: QuantizationConfig | None = None, quant_config: QuantizationConfig | None = None,
multimodal_config: MultiModalConfig | None = None,
*, *,
num_hidden_layers_override: int | None = None, num_hidden_layers_override: int | None = None,
require_post_norm: bool | None = None, require_post_norm: bool | None = None,
...@@ -883,7 +862,6 @@ class SiglipVisionModel(nn.Module): ...@@ -883,7 +862,6 @@ class SiglipVisionModel(nn.Module):
self.vision_model = SiglipVisionTransformer( self.vision_model = SiglipVisionTransformer(
config, config,
quant_config=quant_config, quant_config=quant_config,
multimodal_config=multimodal_config,
num_hidden_layers_override=num_hidden_layers_override, num_hidden_layers_override=num_hidden_layers_override,
require_post_norm=require_post_norm, require_post_norm=require_post_norm,
prefix=f"{prefix}.vision_model", prefix=f"{prefix}.vision_model",
...@@ -1062,9 +1040,7 @@ class SiglipEmbeddingModel(nn.Module, SupportsMultiModal, SupportsQuant): ...@@ -1062,9 +1040,7 @@ class SiglipEmbeddingModel(nn.Module, SupportsMultiModal, SupportsQuant):
config: SiglipConfig = vllm_config.model_config.hf_config config: SiglipConfig = vllm_config.model_config.hf_config
quant_config = vllm_config.quant_config quant_config = vllm_config.quant_config
multimodal_config = vllm_config.model_config.multimodal_config
self.config = config self.config = config
self.multimodal_config = multimodal_config
if hasattr(config, "num_labels"): if hasattr(config, "num_labels"):
config.num_labels = 0 config.num_labels = 0
...@@ -1087,7 +1063,6 @@ class SiglipEmbeddingModel(nn.Module, SupportsMultiModal, SupportsQuant): ...@@ -1087,7 +1063,6 @@ class SiglipEmbeddingModel(nn.Module, SupportsMultiModal, SupportsQuant):
self.vision_model = SiglipVisionTransformer( self.vision_model = SiglipVisionTransformer(
vision_config, vision_config,
quant_config=quant_config, quant_config=quant_config,
multimodal_config=multimodal_config,
prefix=maybe_prefix(prefix, "vision_model"), prefix=maybe_prefix(prefix, "vision_model"),
use_head=None, # Allows potential pooling head use_head=None, # Allows potential pooling head
) )
......
...@@ -11,7 +11,6 @@ from torch.nn import functional as F ...@@ -11,7 +11,6 @@ from torch.nn import functional as F
from transformers import Siglip2VisionConfig from transformers import Siglip2VisionConfig
from transformers.configuration_utils import PretrainedConfig from transformers.configuration_utils import PretrainedConfig
from vllm.config import MultiModalConfig
from vllm.distributed import divide, get_tensor_model_parallel_world_size from vllm.distributed import divide, get_tensor_model_parallel_world_size
from vllm.model_executor.layers.activation import get_act_fn from vllm.model_executor.layers.activation import get_act_fn
from vllm.model_executor.layers.attention.mm_encoder_attention import MMEncoderAttention from vllm.model_executor.layers.attention.mm_encoder_attention import MMEncoderAttention
...@@ -30,6 +29,8 @@ from vllm.model_executor.layers.rotary_embedding.common import ( ...@@ -30,6 +29,8 @@ from vllm.model_executor.layers.rotary_embedding.common import (
from vllm.model_executor.model_loader.weight_utils import default_weight_loader from vllm.model_executor.model_loader.weight_utils import default_weight_loader
from vllm.platforms import current_platform from vllm.platforms import current_platform
from .vision import is_vit_use_data_parallel
class VisionRotaryEmbedding(nn.Module): class VisionRotaryEmbedding(nn.Module):
def __init__(self, dim: int, theta: float = 10000.0) -> None: def __init__(self, dim: int, theta: float = 10000.0) -> None:
...@@ -178,9 +179,7 @@ class Siglip2Attention(nn.Module): ...@@ -178,9 +179,7 @@ class Siglip2Attention(nn.Module):
self, self,
config: Siglip2VisionConfig, config: Siglip2VisionConfig,
quant_config: QuantizationConfig | None = None, quant_config: QuantizationConfig | None = None,
multimodal_config: MultiModalConfig | None = None,
prefix: str = "", prefix: str = "",
use_data_parallel: bool = False,
): ):
super().__init__() super().__init__()
self.config = config self.config = config
...@@ -196,11 +195,7 @@ class Siglip2Attention(nn.Module): ...@@ -196,11 +195,7 @@ class Siglip2Attention(nn.Module):
self.scale = self.head_dim**-0.5 self.scale = self.head_dim**-0.5
self.dropout = config.attention_dropout self.dropout = config.attention_dropout
use_data_parallel = ( use_data_parallel = is_vit_use_data_parallel()
multimodal_config.mm_encoder_tp_mode == "data"
if multimodal_config
else False
)
self.qkv_proj = QKVParallelLinear( self.qkv_proj = QKVParallelLinear(
hidden_size=self.embed_dim, hidden_size=self.embed_dim,
head_size=self.head_dim, head_size=self.head_dim,
...@@ -228,7 +223,6 @@ class Siglip2Attention(nn.Module): ...@@ -228,7 +223,6 @@ class Siglip2Attention(nn.Module):
head_size=self.head_dim, head_size=self.head_dim,
scale=self.scale, scale=self.scale,
prefix=f"{prefix}.attn", prefix=f"{prefix}.attn",
multimodal_config=multimodal_config,
) )
self.apply_rotary_emb = ApplyRotaryEmb( self.apply_rotary_emb = ApplyRotaryEmb(
...@@ -287,16 +281,11 @@ class Siglip2MLP(nn.Module): ...@@ -287,16 +281,11 @@ class Siglip2MLP(nn.Module):
self, self,
config: Siglip2VisionConfig, config: Siglip2VisionConfig,
quant_config: QuantizationConfig | None = None, quant_config: QuantizationConfig | None = None,
multimodal_config: MultiModalConfig | None = None,
prefix: str = "", prefix: str = "",
): ):
super().__init__() super().__init__()
self.config = config self.config = config
use_data_parallel = ( use_data_parallel = is_vit_use_data_parallel()
multimodal_config.mm_encoder_tp_mode == "data"
if multimodal_config
else False
)
self.activation_fn = get_act_fn(config.hidden_act) self.activation_fn = get_act_fn(config.hidden_act)
self.fc1 = ColumnParallelLinear( self.fc1 = ColumnParallelLinear(
config.hidden_size, config.hidden_size,
...@@ -325,7 +314,6 @@ class Siglip2EncoderLayer(nn.Module): ...@@ -325,7 +314,6 @@ class Siglip2EncoderLayer(nn.Module):
self, self,
config: Siglip2VisionConfig, config: Siglip2VisionConfig,
quant_config: QuantizationConfig | None = None, quant_config: QuantizationConfig | None = None,
multimodal_config: MultiModalConfig | None = None,
prefix: str = "", prefix: str = "",
): ):
super().__init__() super().__init__()
...@@ -334,14 +322,12 @@ class Siglip2EncoderLayer(nn.Module): ...@@ -334,14 +322,12 @@ class Siglip2EncoderLayer(nn.Module):
self.self_attn = Siglip2Attention( self.self_attn = Siglip2Attention(
config, config,
quant_config=quant_config, quant_config=quant_config,
multimodal_config=multimodal_config,
prefix=f"{prefix}.self_attn", prefix=f"{prefix}.self_attn",
) )
self.layer_norm2 = nn.LayerNorm(self.embed_dim, eps=config.layer_norm_eps) self.layer_norm2 = nn.LayerNorm(self.embed_dim, eps=config.layer_norm_eps)
self.mlp = Siglip2MLP( self.mlp = Siglip2MLP(
config, config,
quant_config=quant_config, quant_config=quant_config,
multimodal_config=multimodal_config,
prefix=f"{prefix}.mlp", prefix=f"{prefix}.mlp",
) )
...@@ -387,7 +373,6 @@ class Siglip2Encoder(nn.Module): ...@@ -387,7 +373,6 @@ class Siglip2Encoder(nn.Module):
self, self,
config: Siglip2VisionConfig, config: Siglip2VisionConfig,
quant_config: QuantizationConfig | None = None, quant_config: QuantizationConfig | None = None,
multimodal_config: MultiModalConfig | None = None,
prefix: str = "", prefix: str = "",
): ):
super().__init__() super().__init__()
...@@ -397,7 +382,6 @@ class Siglip2Encoder(nn.Module): ...@@ -397,7 +382,6 @@ class Siglip2Encoder(nn.Module):
Siglip2EncoderLayer( Siglip2EncoderLayer(
config, config,
quant_config=quant_config, quant_config=quant_config,
multimodal_config=multimodal_config,
prefix=f"{prefix}.layers.{idx}", prefix=f"{prefix}.layers.{idx}",
) )
for idx in range(config.num_hidden_layers) for idx in range(config.num_hidden_layers)
...@@ -571,7 +555,6 @@ class Siglip2VisionTransformer(nn.Module): ...@@ -571,7 +555,6 @@ class Siglip2VisionTransformer(nn.Module):
self, self,
config: Siglip2VisionConfig, config: Siglip2VisionConfig,
quant_config: QuantizationConfig | None = None, quant_config: QuantizationConfig | None = None,
multimodal_config: MultiModalConfig | None = None,
prefix: str = "", prefix: str = "",
): ):
super().__init__() super().__init__()
...@@ -582,7 +565,6 @@ class Siglip2VisionTransformer(nn.Module): ...@@ -582,7 +565,6 @@ class Siglip2VisionTransformer(nn.Module):
self.encoder = Siglip2Encoder( self.encoder = Siglip2Encoder(
config, config,
quant_config=quant_config, quant_config=quant_config,
multimodal_config=multimodal_config,
prefix=f"{prefix}.encoder", prefix=f"{prefix}.encoder",
) )
self.post_layernorm = nn.LayerNorm(embed_dim, eps=config.layer_norm_eps) self.post_layernorm = nn.LayerNorm(embed_dim, eps=config.layer_norm_eps)
...@@ -610,7 +592,6 @@ class Siglip2NavitModel(torch.nn.Module): ...@@ -610,7 +592,6 @@ class Siglip2NavitModel(torch.nn.Module):
self, self,
config: Siglip2VisionConfig, config: Siglip2VisionConfig,
quant_config: QuantizationConfig | None = None, quant_config: QuantizationConfig | None = None,
multimodal_config: MultiModalConfig | None = None,
prefix: str = "", prefix: str = "",
): ):
super().__init__() super().__init__()
...@@ -618,7 +599,6 @@ class Siglip2NavitModel(torch.nn.Module): ...@@ -618,7 +599,6 @@ class Siglip2NavitModel(torch.nn.Module):
self.vision_model = Siglip2VisionTransformer( self.vision_model = Siglip2VisionTransformer(
config, config,
quant_config=quant_config, quant_config=quant_config,
multimodal_config=multimodal_config,
prefix=f"{prefix}.vision_model", prefix=f"{prefix}.vision_model",
) )
......
...@@ -19,7 +19,7 @@ from transformers.models.llava import LlavaProcessor ...@@ -19,7 +19,7 @@ from transformers.models.llava import LlavaProcessor
from transformers.processing_utils import ProcessingKwargs, Unpack from transformers.processing_utils import ProcessingKwargs, Unpack
from transformers.tokenization_utils_base import PreTokenizedInput, TextInput from transformers.tokenization_utils_base import PreTokenizedInput, TextInput
from vllm.config import MultiModalConfig, VllmConfig from vllm.config import VllmConfig
from vllm.model_executor.layers.activation import get_act_fn from vllm.model_executor.layers.activation import get_act_fn
from vllm.model_executor.layers.linear import ColumnParallelLinear, RowParallelLinear from vllm.model_executor.layers.linear import ColumnParallelLinear, RowParallelLinear
from vllm.model_executor.layers.quantization import QuantizationConfig from vllm.model_executor.layers.quantization import QuantizationConfig
...@@ -351,7 +351,6 @@ def _build_tarsier_hf_processor( ...@@ -351,7 +351,6 @@ def _build_tarsier_hf_processor(
def init_vision_tower_for_tarsier( def init_vision_tower_for_tarsier(
hf_config: TarsierHfConfig, # Use the Tarsier specific config protocol hf_config: TarsierHfConfig, # Use the Tarsier specific config protocol
quant_config: QuantizationConfig | None, quant_config: QuantizationConfig | None,
multimodal_config: MultiModalConfig | None,
*, *,
require_post_norm: bool | None = None, require_post_norm: bool | None = None,
prefix: str = "", prefix: str = "",
...@@ -378,7 +377,6 @@ def init_vision_tower_for_tarsier( ...@@ -378,7 +377,6 @@ def init_vision_tower_for_tarsier(
return CLIPVisionModel( return CLIPVisionModel(
vision_config, vision_config,
quant_config=quant_config, quant_config=quant_config,
multimodal_config=multimodal_config,
num_hidden_layers_override=num_hidden_layers_to_init, num_hidden_layers_override=num_hidden_layers_to_init,
require_post_norm=require_post_norm, require_post_norm=require_post_norm,
prefix=prefix, prefix=prefix,
...@@ -387,7 +385,6 @@ def init_vision_tower_for_tarsier( ...@@ -387,7 +385,6 @@ def init_vision_tower_for_tarsier(
return SiglipVisionModel( return SiglipVisionModel(
vision_config, vision_config,
quant_config=quant_config, quant_config=quant_config,
multimodal_config=multimodal_config,
num_hidden_layers_override=num_hidden_layers_to_init, num_hidden_layers_override=num_hidden_layers_to_init,
require_post_norm=require_post_norm, require_post_norm=require_post_norm,
prefix=prefix, prefix=prefix,
...@@ -420,7 +417,6 @@ class TarsierForConditionalGeneration(nn.Module, SupportsMultiModal, SupportsPP) ...@@ -420,7 +417,6 @@ class TarsierForConditionalGeneration(nn.Module, SupportsMultiModal, SupportsPP)
config: TarsierHfConfig = vllm_config.model_config.hf_config config: TarsierHfConfig = vllm_config.model_config.hf_config
quant_config = vllm_config.quant_config quant_config = vllm_config.quant_config
multimodal_config = vllm_config.model_config.multimodal_config
self.config = config # Storing the Tarsier-specific HF config self.config = config # Storing the Tarsier-specific HF config
...@@ -428,7 +424,6 @@ class TarsierForConditionalGeneration(nn.Module, SupportsMultiModal, SupportsPP) ...@@ -428,7 +424,6 @@ class TarsierForConditionalGeneration(nn.Module, SupportsMultiModal, SupportsPP)
self.vision_tower = init_vision_tower_for_tarsier( self.vision_tower = init_vision_tower_for_tarsier(
config, config,
quant_config=quant_config, quant_config=quant_config,
multimodal_config=multimodal_config,
require_post_norm=False, require_post_norm=False,
prefix=maybe_prefix(prefix, "vision_tower"), prefix=maybe_prefix(prefix, "vision_tower"),
) )
......
...@@ -10,7 +10,7 @@ from typing import Final, Generic, Literal, Protocol, TypeAlias, TypeVar ...@@ -10,7 +10,7 @@ from typing import Final, Generic, Literal, Protocol, TypeAlias, TypeVar
import torch import torch
from transformers import PretrainedConfig from transformers import PretrainedConfig
from vllm.config import VllmConfig from vllm.config import MultiModalConfig, VllmConfig, get_current_vllm_config
from vllm.distributed import ( from vllm.distributed import (
get_tensor_model_parallel_rank, get_tensor_model_parallel_rank,
get_tensor_model_parallel_world_size, get_tensor_model_parallel_world_size,
...@@ -79,7 +79,7 @@ def get_vision_encoder_info(hf_config: VisionLanguageConfig) -> VisionEncoderInf ...@@ -79,7 +79,7 @@ def get_vision_encoder_info(hf_config: VisionLanguageConfig) -> VisionEncoderInf
raise NotImplementedError(msg) raise NotImplementedError(msg)
def get_vit_attn_backend( def _get_vit_attn_backend(
head_size: int, head_size: int,
dtype: torch.dtype, dtype: torch.dtype,
*, *,
...@@ -95,6 +95,52 @@ def get_vit_attn_backend( ...@@ -95,6 +95,52 @@ def get_vit_attn_backend(
) )
def get_vit_attn_backend(
head_size: int,
dtype: torch.dtype,
) -> AttentionBackendEnum:
"""
Get the attention backend for Vision Transformer.
"""
try:
vllm_config: VllmConfig = get_current_vllm_config()
multimodal_config: MultiModalConfig | None = (
vllm_config.model_config.multimodal_config
)
except AssertionError:
multimodal_config = None
attn_backend_override = (
multimodal_config.mm_encoder_attn_backend
if multimodal_config is not None
else None
)
attn_backend = _get_vit_attn_backend(
head_size,
dtype,
attn_backend_override=attn_backend_override,
)
return attn_backend
def is_vit_use_data_parallel():
"""
Get the tensor parallel type for Vision Transformer.
"""
try:
vllm_config: VllmConfig = get_current_vllm_config()
multimodal_config: MultiModalConfig | None = (
vllm_config.model_config.multimodal_config
)
except AssertionError:
multimodal_config = None
mm_encoder_tp_mode = (
multimodal_config.mm_encoder_tp_mode if multimodal_config is not None else None
)
return mm_encoder_tp_mode == "data"
def should_torch_compile_mm_vit(vllm_config: VllmConfig) -> bool: def should_torch_compile_mm_vit(vllm_config: VllmConfig) -> bool:
"""Callable to be passed to `@support_torch_compile`'s `enable_if` argument.""" """Callable to be passed to `@support_torch_compile`'s `enable_if` argument."""
return vllm_config.compilation_config.compile_mm_encoder return vllm_config.compilation_config.compile_mm_encoder
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment