Unverified commit 9ad7f89f, authored by Isotr0py and committed by GitHub
Browse files

[Models]: Make Multimodal config implicit in ViT implementation (#31972)


Signed-off-by: Isotr0py <mozf@mail2.sysu.edu.cn>
parent 6450b536
...@@ -4,7 +4,6 @@ ...@@ -4,7 +4,6 @@
import torch import torch
from vllm.config import MultiModalConfig
from vllm.logger import init_logger from vllm.logger import init_logger
from vllm.model_executor.custom_op import CustomOp from vllm.model_executor.custom_op import CustomOp
from vllm.model_executor.models.vision import get_vit_attn_backend from vllm.model_executor.models.vision import get_vit_attn_backend
...@@ -32,7 +31,6 @@ class MMEncoderAttention(CustomOp): ...@@ -32,7 +31,6 @@ class MMEncoderAttention(CustomOp):
scale: float | None = None, scale: float | None = None,
num_kv_heads: int | None = None, num_kv_heads: int | None = None,
prefix: str = "", prefix: str = "",
multimodal_config: MultiModalConfig | None = None,
) -> None: ) -> None:
""" """
Args: Args:
...@@ -42,7 +40,6 @@ class MMEncoderAttention(CustomOp): ...@@ -42,7 +40,6 @@ class MMEncoderAttention(CustomOp):
num_kv_heads: number of kv heads. num_kv_heads: number of kv heads.
prefix: This has no effect, it is only here to make it easier to prefix: This has no effect, it is only here to make it easier to
swap between Attention and MultiHeadAttention swap between Attention and MultiHeadAttention
multimodal_config: configs for multi-modal.
""" """
super().__init__() super().__init__()
...@@ -62,16 +59,10 @@ class MMEncoderAttention(CustomOp): ...@@ -62,16 +59,10 @@ class MMEncoderAttention(CustomOp):
# weight and activation dtype. # weight and activation dtype.
dtype = torch.get_default_dtype() dtype = torch.get_default_dtype()
# Try to get vision attention backend from multimodal_config.
attn_backend_override = None
if multimodal_config is not None:
attn_backend_override = multimodal_config.mm_encoder_attn_backend
# Get device-specific vision attention backend. # Get device-specific vision attention backend.
self.attn_backend = get_vit_attn_backend( self.attn_backend = get_vit_attn_backend(
head_size=head_size, head_size=head_size,
dtype=dtype, dtype=dtype,
attn_backend_override=attn_backend_override,
) )
self.is_flash_attn_backend = self.attn_backend in { self.is_flash_attn_backend = self.attn_backend in {
......
...@@ -16,7 +16,7 @@ from transformers import ( ...@@ -16,7 +16,7 @@ from transformers import (
from vllm.attention.layer import Attention from vllm.attention.layer import Attention
from vllm.config import VllmConfig from vllm.config import VllmConfig
from vllm.config.multimodal import BaseDummyOptions, MultiModalConfig from vllm.config.multimodal import BaseDummyOptions
from vllm.distributed import divide, get_tensor_model_parallel_world_size from vllm.distributed import divide, get_tensor_model_parallel_world_size
from vllm.model_executor.layers.activation import get_act_fn from vllm.model_executor.layers.activation import get_act_fn
from vllm.model_executor.layers.attention.mm_encoder_attention import MMEncoderAttention from vllm.model_executor.layers.attention.mm_encoder_attention import MMEncoderAttention
...@@ -59,6 +59,7 @@ from .vision import ( ...@@ -59,6 +59,7 @@ from .vision import (
VisionFeatureSelectStrategy, VisionFeatureSelectStrategy,
VisionFeatureSelectStrategyStr, VisionFeatureSelectStrategyStr,
get_num_selected_vision_tokens, get_num_selected_vision_tokens,
is_vit_use_data_parallel,
resolve_visual_encoder_outputs, resolve_visual_encoder_outputs,
) )
...@@ -353,7 +354,6 @@ class CLIPAttention(nn.Module): ...@@ -353,7 +354,6 @@ class CLIPAttention(nn.Module):
self, self,
config: CLIPTextConfig | CLIPVisionConfig, config: CLIPTextConfig | CLIPVisionConfig,
quant_config: QuantizationConfig | None = None, quant_config: QuantizationConfig | None = None,
multimodal_config: MultiModalConfig | None = None,
*, *,
prefix: str = "", prefix: str = "",
attn_cls: type[Attention] | type[MMEncoderAttention], attn_cls: type[Attention] | type[MMEncoderAttention],
...@@ -372,11 +372,7 @@ class CLIPAttention(nn.Module): ...@@ -372,11 +372,7 @@ class CLIPAttention(nn.Module):
) )
self.scale = self.head_dim**-0.5 self.scale = self.head_dim**-0.5
use_data_parallel = ( use_data_parallel = is_vit_use_data_parallel()
multimodal_config.mm_encoder_tp_mode == "data"
if multimodal_config
else False
)
self.qkv_proj = QKVParallelLinear( self.qkv_proj = QKVParallelLinear(
hidden_size=self.embed_dim, hidden_size=self.embed_dim,
head_size=self.head_dim, head_size=self.head_dim,
...@@ -405,7 +401,6 @@ class CLIPAttention(nn.Module): ...@@ -405,7 +401,6 @@ class CLIPAttention(nn.Module):
self.head_dim, self.head_dim,
self.scale, self.scale,
prefix=f"{prefix}.attn", prefix=f"{prefix}.attn",
multimodal_config=multimodal_config,
) )
else: else:
self.attn = attn_cls( self.attn = attn_cls(
...@@ -434,17 +429,12 @@ class CLIPMLP(nn.Module): ...@@ -434,17 +429,12 @@ class CLIPMLP(nn.Module):
self, self,
config: CLIPTextConfig | CLIPVisionConfig, config: CLIPTextConfig | CLIPVisionConfig,
quant_config: QuantizationConfig | None = None, quant_config: QuantizationConfig | None = None,
multimodal_config: MultiModalConfig | None = None,
prefix: str = "", prefix: str = "",
) -> None: ) -> None:
super().__init__() super().__init__()
self.config = config self.config = config
use_data_parallel = ( use_data_parallel = is_vit_use_data_parallel()
multimodal_config.mm_encoder_tp_mode == "data"
if multimodal_config
else False
)
self.activation_fn = get_act_fn(config.hidden_act) self.activation_fn = get_act_fn(config.hidden_act)
self.fc1 = ColumnParallelLinear( self.fc1 = ColumnParallelLinear(
...@@ -477,7 +467,6 @@ class CLIPEncoderLayer(nn.Module): ...@@ -477,7 +467,6 @@ class CLIPEncoderLayer(nn.Module):
self, self,
config: CLIPTextConfig | CLIPVisionConfig, config: CLIPTextConfig | CLIPVisionConfig,
quant_config: QuantizationConfig | None = None, quant_config: QuantizationConfig | None = None,
multimodal_config: MultiModalConfig | None = None,
*, *,
prefix: str = "", prefix: str = "",
attn_cls: type[Attention] | type[MMEncoderAttention], attn_cls: type[Attention] | type[MMEncoderAttention],
...@@ -487,7 +476,6 @@ class CLIPEncoderLayer(nn.Module): ...@@ -487,7 +476,6 @@ class CLIPEncoderLayer(nn.Module):
self.self_attn = CLIPAttention( self.self_attn = CLIPAttention(
config, config,
quant_config=quant_config, quant_config=quant_config,
multimodal_config=multimodal_config,
prefix=f"{prefix}.self_attn", prefix=f"{prefix}.self_attn",
attn_cls=attn_cls, attn_cls=attn_cls,
) )
...@@ -495,7 +483,6 @@ class CLIPEncoderLayer(nn.Module): ...@@ -495,7 +483,6 @@ class CLIPEncoderLayer(nn.Module):
self.mlp = CLIPMLP( self.mlp = CLIPMLP(
config, config,
quant_config=quant_config, quant_config=quant_config,
multimodal_config=multimodal_config,
prefix=f"{prefix}.mlp", prefix=f"{prefix}.mlp",
) )
self.layer_norm2 = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps) self.layer_norm2 = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
...@@ -528,7 +515,6 @@ class CLIPEncoder(nn.Module): ...@@ -528,7 +515,6 @@ class CLIPEncoder(nn.Module):
self, self,
config: CLIPTextConfig | CLIPVisionConfig, config: CLIPTextConfig | CLIPVisionConfig,
quant_config: QuantizationConfig | None = None, quant_config: QuantizationConfig | None = None,
multimodal_config: MultiModalConfig | None = None,
num_hidden_layers_override: int | None = None, num_hidden_layers_override: int | None = None,
*, *,
prefix: str = "", prefix: str = "",
...@@ -548,7 +534,6 @@ class CLIPEncoder(nn.Module): ...@@ -548,7 +534,6 @@ class CLIPEncoder(nn.Module):
CLIPEncoderLayer( CLIPEncoderLayer(
config=config, config=config,
quant_config=quant_config, quant_config=quant_config,
multimodal_config=multimodal_config,
prefix=f"{prefix}.layers.{layer_idx}", prefix=f"{prefix}.layers.{layer_idx}",
attn_cls=attn_cls, attn_cls=attn_cls,
) )
...@@ -658,7 +643,6 @@ class CLIPVisionTransformer(nn.Module): ...@@ -658,7 +643,6 @@ class CLIPVisionTransformer(nn.Module):
self, self,
config: CLIPVisionConfig, config: CLIPVisionConfig,
quant_config: QuantizationConfig | None = None, quant_config: QuantizationConfig | None = None,
multimodal_config: MultiModalConfig | None = None,
*, *,
num_hidden_layers_override: int | None = None, num_hidden_layers_override: int | None = None,
require_post_norm: bool | None = None, require_post_norm: bool | None = None,
...@@ -678,7 +662,6 @@ class CLIPVisionTransformer(nn.Module): ...@@ -678,7 +662,6 @@ class CLIPVisionTransformer(nn.Module):
self.encoder = CLIPEncoder( self.encoder = CLIPEncoder(
config=config, config=config,
quant_config=quant_config, quant_config=quant_config,
multimodal_config=multimodal_config,
num_hidden_layers_override=num_hidden_layers_override, num_hidden_layers_override=num_hidden_layers_override,
prefix=f"{prefix}.encoder", prefix=f"{prefix}.encoder",
attn_cls=MMEncoderAttention, attn_cls=MMEncoderAttention,
...@@ -780,7 +763,6 @@ class CLIPVisionModel(nn.Module): ...@@ -780,7 +763,6 @@ class CLIPVisionModel(nn.Module):
self, self,
config: CLIPVisionConfig, config: CLIPVisionConfig,
quant_config: QuantizationConfig | None = None, quant_config: QuantizationConfig | None = None,
multimodal_config: MultiModalConfig | None = None,
*, *,
num_hidden_layers_override: int | None = None, num_hidden_layers_override: int | None = None,
require_post_norm: bool | None = None, require_post_norm: bool | None = None,
...@@ -791,7 +773,6 @@ class CLIPVisionModel(nn.Module): ...@@ -791,7 +773,6 @@ class CLIPVisionModel(nn.Module):
self.vision_model = CLIPVisionTransformer( self.vision_model = CLIPVisionTransformer(
config=config, config=config,
quant_config=quant_config, quant_config=quant_config,
multimodal_config=multimodal_config,
num_hidden_layers_override=num_hidden_layers_override, num_hidden_layers_override=num_hidden_layers_override,
require_post_norm=require_post_norm, require_post_norm=require_post_norm,
prefix=f"{prefix}.vision_model", prefix=f"{prefix}.vision_model",
...@@ -869,7 +850,6 @@ class CLIPEmbeddingModel(nn.Module, SupportsMultiModal, SupportsQuant): ...@@ -869,7 +850,6 @@ class CLIPEmbeddingModel(nn.Module, SupportsMultiModal, SupportsQuant):
self.vision_model = CLIPVisionTransformer( self.vision_model = CLIPVisionTransformer(
vision_config, vision_config,
quant_config=quant_config, quant_config=quant_config,
multimodal_config=multimodal_config,
prefix=maybe_prefix(prefix, "vision_model"), prefix=maybe_prefix(prefix, "vision_model"),
) )
self.visual_projection = nn.Linear( self.visual_projection = nn.Linear(
......
...@@ -18,7 +18,6 @@ import torch.nn as nn ...@@ -18,7 +18,6 @@ import torch.nn as nn
import torch.nn.functional as F import torch.nn.functional as F
from transformers import CLIPVisionConfig from transformers import CLIPVisionConfig
from vllm.config import MultiModalConfig
from vllm.model_executor.layers.attention.mm_encoder_attention import MMEncoderAttention from vllm.model_executor.layers.attention.mm_encoder_attention import MMEncoderAttention
from vllm.model_executor.layers.conv import Conv2dLayer from vllm.model_executor.layers.conv import Conv2dLayer
from vllm.model_executor.layers.quantization import QuantizationConfig from vllm.model_executor.layers.quantization import QuantizationConfig
...@@ -609,7 +608,6 @@ class DeepCLIPVisionTransformer(nn.Module): ...@@ -609,7 +608,6 @@ class DeepCLIPVisionTransformer(nn.Module):
self, self,
config: CLIPVisionConfig, config: CLIPVisionConfig,
quant_config: QuantizationConfig | None = None, quant_config: QuantizationConfig | None = None,
multimodal_config: MultiModalConfig | None = None,
*, *,
num_hidden_layers_override: int | None = None, num_hidden_layers_override: int | None = None,
prefix: str = "", prefix: str = "",
...@@ -628,7 +626,6 @@ class DeepCLIPVisionTransformer(nn.Module): ...@@ -628,7 +626,6 @@ class DeepCLIPVisionTransformer(nn.Module):
self.transformer = CLIPEncoder( self.transformer = CLIPEncoder(
config=config, config=config,
quant_config=quant_config, quant_config=quant_config,
multimodal_config=multimodal_config,
num_hidden_layers_override=num_hidden_layers_override, num_hidden_layers_override=num_hidden_layers_override,
prefix=f"{prefix}.encoder", prefix=f"{prefix}.encoder",
attn_cls=MMEncoderAttention, attn_cls=MMEncoderAttention,
......
...@@ -398,7 +398,6 @@ class DeepseekOCRForCausalLM(nn.Module, SupportsMultiModal, SupportsPP, Supports ...@@ -398,7 +398,6 @@ class DeepseekOCRForCausalLM(nn.Module, SupportsMultiModal, SupportsPP, Supports
self.vision_model = DeepCLIPVisionTransformer( self.vision_model = DeepCLIPVisionTransformer(
config=clip_vision_config, config=clip_vision_config,
quant_config=quant_config, quant_config=quant_config,
multimodal_config=multimodal_config,
prefix=maybe_prefix(prefix, "vision_model"), prefix=maybe_prefix(prefix, "vision_model"),
) )
......
...@@ -8,7 +8,7 @@ import torch.nn as nn ...@@ -8,7 +8,7 @@ import torch.nn as nn
from torch.nn import LayerNorm from torch.nn import LayerNorm
from transformers.models.qwen2_vl import Qwen2VLProcessor from transformers.models.qwen2_vl import Qwen2VLProcessor
from vllm.config import MultiModalConfig, VllmConfig from vllm.config import VllmConfig
from vllm.config.multimodal import BaseDummyOptions from vllm.config.multimodal import BaseDummyOptions
from vllm.distributed import utils as dist_utils from vllm.distributed import utils as dist_utils
from vllm.distributed.parallel_state import ( from vllm.distributed.parallel_state import (
...@@ -60,7 +60,7 @@ from vllm.transformers_utils.configs.dotsocr import DotsOCRConfig, DotsVisionCon ...@@ -60,7 +60,7 @@ from vllm.transformers_utils.configs.dotsocr import DotsOCRConfig, DotsVisionCon
from vllm.utils.tensor_schema import TensorSchema, TensorShape from vllm.utils.tensor_schema import TensorSchema, TensorShape
from vllm.v1.attention.backends.registry import AttentionBackendEnum from vllm.v1.attention.backends.registry import AttentionBackendEnum
from .vision import run_dp_sharded_mrope_vision_model from .vision import is_vit_use_data_parallel, run_dp_sharded_mrope_vision_model
IMAGE_TOKEN = "<|imgpad|>" IMAGE_TOKEN = "<|imgpad|>"
...@@ -183,9 +183,9 @@ class PatchMerger(nn.Module): ...@@ -183,9 +183,9 @@ class PatchMerger(nn.Module):
spatial_merge_size: int = 2, spatial_merge_size: int = 2,
pre_norm="layernorm", pre_norm="layernorm",
prefix: str = "", prefix: str = "",
use_data_parallel: bool = False,
) -> None: ) -> None:
super().__init__() super().__init__()
use_data_parallel = is_vit_use_data_parallel()
self.hidden_size = context_dim * (spatial_merge_size**2) self.hidden_size = context_dim * (spatial_merge_size**2)
self.pre_norm = pre_norm self.pre_norm = pre_norm
if self.pre_norm == "layernorm": if self.pre_norm == "layernorm":
...@@ -230,15 +230,10 @@ class DotsVisionAttention(nn.Module): ...@@ -230,15 +230,10 @@ class DotsVisionAttention(nn.Module):
bias: bool = True, bias: bool = True,
*, *,
quant_config: QuantizationConfig | None = None, quant_config: QuantizationConfig | None = None,
multimodal_config: MultiModalConfig | None = None,
prefix: str = "", prefix: str = "",
) -> None: ) -> None:
super().__init__() super().__init__()
use_data_parallel = ( use_data_parallel = is_vit_use_data_parallel()
multimodal_config.mm_encoder_tp_mode == "data"
if multimodal_config
else False
)
self.embed_dim = dim self.embed_dim = dim
self.tp_size = ( self.tp_size = (
...@@ -272,7 +267,6 @@ class DotsVisionAttention(nn.Module): ...@@ -272,7 +267,6 @@ class DotsVisionAttention(nn.Module):
num_heads=self.num_attention_heads_per_partition, num_heads=self.num_attention_heads_per_partition,
head_size=self.hidden_size_per_attention_head, head_size=self.hidden_size_per_attention_head,
scale=self.hidden_size_per_attention_head**-0.5, scale=self.hidden_size_per_attention_head**-0.5,
multimodal_config=multimodal_config,
prefix=f"{prefix}.attn", prefix=f"{prefix}.attn",
) )
...@@ -329,7 +323,6 @@ class DotsSwiGLUFFN(nn.Module): ...@@ -329,7 +323,6 @@ class DotsSwiGLUFFN(nn.Module):
config, config,
*, *,
quant_config: QuantizationConfig | None = None, quant_config: QuantizationConfig | None = None,
multimodal_config: MultiModalConfig | None = None,
prefix: str = "", prefix: str = "",
): ):
super().__init__() super().__init__()
...@@ -337,11 +330,7 @@ class DotsSwiGLUFFN(nn.Module): ...@@ -337,11 +330,7 @@ class DotsSwiGLUFFN(nn.Module):
in_features = config.embed_dim in_features = config.embed_dim
bias = config.use_bias bias = config.use_bias
use_data_parallel = ( use_data_parallel = is_vit_use_data_parallel()
multimodal_config.mm_encoder_tp_mode == "data"
if multimodal_config
else False
)
# Referenced aimv2.py AIMv2SwiGLUFFN # Referenced aimv2.py AIMv2SwiGLUFFN
self.fc13 = MergedColumnParallelLinear( self.fc13 = MergedColumnParallelLinear(
in_features, in_features,
...@@ -447,7 +436,6 @@ class DotsVisionBlock(nn.Module): ...@@ -447,7 +436,6 @@ class DotsVisionBlock(nn.Module):
config, config,
*, *,
quant_config: QuantizationConfig | None = None, quant_config: QuantizationConfig | None = None,
multimodal_config: MultiModalConfig | None = None,
prefix: str = "", prefix: str = "",
): ):
super().__init__() super().__init__()
...@@ -458,14 +446,12 @@ class DotsVisionBlock(nn.Module): ...@@ -458,14 +446,12 @@ class DotsVisionBlock(nn.Module):
num_heads=config.num_attention_heads, num_heads=config.num_attention_heads,
bias=config.use_bias, bias=config.use_bias,
quant_config=quant_config, quant_config=quant_config,
multimodal_config=multimodal_config,
prefix=f"{prefix}.attn", prefix=f"{prefix}.attn",
) )
self.norm1 = RMSNorm(config.embed_dim, eps=config.rms_norm_eps) self.norm1 = RMSNorm(config.embed_dim, eps=config.rms_norm_eps)
self.mlp = DotsSwiGLUFFN( self.mlp = DotsSwiGLUFFN(
config, config,
quant_config=quant_config, quant_config=quant_config,
multimodal_config=multimodal_config,
prefix=f"{prefix}.mlp", prefix=f"{prefix}.mlp",
) )
self.norm2 = RMSNorm(config.embed_dim, eps=config.rms_norm_eps) self.norm2 = RMSNorm(config.embed_dim, eps=config.rms_norm_eps)
...@@ -493,7 +479,6 @@ class DotsVisionTransformer(nn.Module): ...@@ -493,7 +479,6 @@ class DotsVisionTransformer(nn.Module):
self, self,
config: DotsVisionConfig, config: DotsVisionConfig,
quant_config: QuantizationConfig | None = None, quant_config: QuantizationConfig | None = None,
multimodal_config: MultiModalConfig | None = None,
*, *,
num_hidden_layers_override: int | None = None, num_hidden_layers_override: int | None = None,
require_post_norm: bool | None = None, require_post_norm: bool | None = None,
...@@ -507,15 +492,9 @@ class DotsVisionTransformer(nn.Module): ...@@ -507,15 +492,9 @@ class DotsVisionTransformer(nn.Module):
head_dim = config.embed_dim // config.num_attention_heads head_dim = config.embed_dim // config.num_attention_heads
self.rotary_pos_emb = VisionRotaryEmbedding(head_dim // 2) self.rotary_pos_emb = VisionRotaryEmbedding(head_dim // 2)
attn_backend_override = (
multimodal_config.mm_encoder_attn_backend
if multimodal_config is not None
else None
)
self.attn_backend = get_vit_attn_backend( self.attn_backend = get_vit_attn_backend(
head_size=head_dim, head_size=head_dim,
dtype=torch.get_default_dtype(), dtype=torch.get_default_dtype(),
attn_backend_override=attn_backend_override,
) )
self.out_hidden_size = config.hidden_size self.out_hidden_size = config.hidden_size
# Keep blocks for compatibility with other vision towers # Keep blocks for compatibility with other vision towers
...@@ -529,7 +508,6 @@ class DotsVisionTransformer(nn.Module): ...@@ -529,7 +508,6 @@ class DotsVisionTransformer(nn.Module):
DotsVisionBlock( DotsVisionBlock(
config, config,
quant_config=quant_config, quant_config=quant_config,
multimodal_config=multimodal_config,
prefix=f"{prefix}.blocks.{i}", prefix=f"{prefix}.blocks.{i}",
) )
for i in range(num_layers) for i in range(num_layers)
...@@ -542,16 +520,10 @@ class DotsVisionTransformer(nn.Module): ...@@ -542,16 +520,10 @@ class DotsVisionTransformer(nn.Module):
else: else:
self.post_trunk_norm = None self.post_trunk_norm = None
use_data_parallel = (
multimodal_config.mm_encoder_tp_mode == "data"
if multimodal_config
else False
)
self.merger = PatchMerger( self.merger = PatchMerger(
dim=config.hidden_size, dim=config.hidden_size,
context_dim=config.embed_dim, context_dim=config.embed_dim,
spatial_merge_size=config.spatial_merge_size, spatial_merge_size=config.spatial_merge_size,
use_data_parallel=use_data_parallel,
) )
@property @property
...@@ -693,7 +665,6 @@ class DotsOCRForCausalLM(nn.Module, SupportsMultiModal, SupportsPP, SupportsLoRA ...@@ -693,7 +665,6 @@ class DotsOCRForCausalLM(nn.Module, SupportsMultiModal, SupportsPP, SupportsLoRA
self.vision_tower = DotsVisionTransformer( self.vision_tower = DotsVisionTransformer(
vision_config, vision_config,
quant_config=self.quant_config, quant_config=self.quant_config,
multimodal_config=multimodal_config,
prefix=maybe_prefix(prefix, "vision_tower"), prefix=maybe_prefix(prefix, "vision_tower"),
) )
......
...@@ -270,7 +270,6 @@ class Eagle2_5_VLForConditionalGeneration( ...@@ -270,7 +270,6 @@ class Eagle2_5_VLForConditionalGeneration(
return SiglipVisionModel( return SiglipVisionModel(
vision_config, vision_config,
quant_config=quant_config, quant_config=quant_config,
multimodal_config=self.multimodal_config,
num_hidden_layers_override=num_hidden_layers, num_hidden_layers_override=num_hidden_layers,
prefix=prefix, prefix=prefix,
) )
......
...@@ -36,7 +36,7 @@ import torch.nn.functional as F ...@@ -36,7 +36,7 @@ import torch.nn.functional as F
from einops import rearrange from einops import rearrange
from transformers import BatchFeature from transformers import BatchFeature
from vllm.config import MultiModalConfig, VllmConfig from vllm.config import VllmConfig
from vllm.config.multimodal import BaseDummyOptions, VideoDummyOptions from vllm.config.multimodal import BaseDummyOptions, VideoDummyOptions
from vllm.distributed import parallel_state from vllm.distributed import parallel_state
from vllm.distributed import utils as dist_utils from vllm.distributed import utils as dist_utils
...@@ -119,7 +119,6 @@ class Ernie4_5_VisionAttention(nn.Module): ...@@ -119,7 +119,6 @@ class Ernie4_5_VisionAttention(nn.Module):
num_heads: int, num_heads: int,
projection_size: int, projection_size: int,
quant_config: QuantizationConfig | None = None, quant_config: QuantizationConfig | None = None,
multimodal_config: MultiModalConfig | None = None,
prefix: str = "", prefix: str = "",
) -> None: ) -> None:
super().__init__() super().__init__()
...@@ -153,7 +152,6 @@ class Ernie4_5_VisionAttention(nn.Module): ...@@ -153,7 +152,6 @@ class Ernie4_5_VisionAttention(nn.Module):
num_heads=self.num_attention_heads_per_partition, num_heads=self.num_attention_heads_per_partition,
head_size=self.hidden_size_per_attention_head, head_size=self.hidden_size_per_attention_head,
scale=self.hidden_size_per_attention_head**-0.5, scale=self.hidden_size_per_attention_head**-0.5,
multimodal_config=multimodal_config,
prefix=f"{prefix}.attn", prefix=f"{prefix}.attn",
) )
...@@ -266,7 +264,6 @@ class Ernie4_5_VisionBlock(nn.Module): ...@@ -266,7 +264,6 @@ class Ernie4_5_VisionBlock(nn.Module):
act_layer: type[nn.Module] = QuickGELU, act_layer: type[nn.Module] = QuickGELU,
norm_layer: Callable[[int], nn.Module] | None = None, norm_layer: Callable[[int], nn.Module] | None = None,
quant_config: QuantizationConfig | None = None, quant_config: QuantizationConfig | None = None,
multimodal_config: MultiModalConfig | None = None,
prefix: str = "", prefix: str = "",
) -> None: ) -> None:
super().__init__() super().__init__()
...@@ -282,7 +279,6 @@ class Ernie4_5_VisionBlock(nn.Module): ...@@ -282,7 +279,6 @@ class Ernie4_5_VisionBlock(nn.Module):
num_heads=num_heads, num_heads=num_heads,
projection_size=dim, projection_size=dim,
quant_config=quant_config, quant_config=quant_config,
multimodal_config=multimodal_config,
prefix=f"{prefix}.attn", prefix=f"{prefix}.attn",
) )
...@@ -357,7 +353,6 @@ class Ernie4_5_VisionTransformer(nn.Module): ...@@ -357,7 +353,6 @@ class Ernie4_5_VisionTransformer(nn.Module):
vision_config, vision_config,
norm_eps: float = 1e-6, norm_eps: float = 1e-6,
quant_config: QuantizationConfig | None = None, quant_config: QuantizationConfig | None = None,
multimodal_config: MultiModalConfig | None = None,
prefix: str = "", prefix: str = "",
) -> None: ) -> None:
super().__init__() super().__init__()
...@@ -393,7 +388,6 @@ class Ernie4_5_VisionTransformer(nn.Module): ...@@ -393,7 +388,6 @@ class Ernie4_5_VisionTransformer(nn.Module):
mlp_ratio=mlp_ratio, mlp_ratio=mlp_ratio,
norm_layer=norm_layer, norm_layer=norm_layer,
quant_config=quant_config, quant_config=quant_config,
multimodal_config=multimodal_config,
prefix=f"{prefix}.blocks.{layer_idx}", prefix=f"{prefix}.blocks.{layer_idx}",
) )
for layer_idx in range(depth) for layer_idx in range(depth)
...@@ -405,13 +399,9 @@ class Ernie4_5_VisionTransformer(nn.Module): ...@@ -405,13 +399,9 @@ class Ernie4_5_VisionTransformer(nn.Module):
) )
self.ln = nn.LayerNorm(hidden_size, eps=1e-6) self.ln = nn.LayerNorm(hidden_size, eps=1e-6)
attn_backend_override = (
multimodal_config.mm_encoder_attn_backend if multimodal_config else None
)
self.attn_backend = get_vit_attn_backend( self.attn_backend = get_vit_attn_backend(
head_size=head_dim, head_size=head_dim,
dtype=torch.get_default_dtype(), dtype=torch.get_default_dtype(),
attn_backend_override=attn_backend_override,
) )
@property @property
...@@ -1308,7 +1298,6 @@ class Ernie4_5_VLMoeForConditionalGeneration( ...@@ -1308,7 +1298,6 @@ class Ernie4_5_VLMoeForConditionalGeneration(
config.vision_config, config.vision_config,
norm_eps=getattr(config, "rms_norm_eps", 1e-6), norm_eps=getattr(config, "rms_norm_eps", 1e-6),
quant_config=quant_config, quant_config=quant_config,
multimodal_config=multimodal_config,
prefix=maybe_prefix(prefix, "vision_model"), prefix=maybe_prefix(prefix, "vision_model"),
) )
self.resampler_model = VariableResolutionResamplerModel( self.resampler_model = VariableResolutionResamplerModel(
......
...@@ -46,7 +46,7 @@ from transformers.models.glm4v.image_processing_glm4v import ( ...@@ -46,7 +46,7 @@ from transformers.models.glm4v.image_processing_glm4v import (
from transformers.models.glm4v.video_processing_glm4v import Glm4vVideoProcessor from transformers.models.glm4v.video_processing_glm4v import Glm4vVideoProcessor
from transformers.video_utils import VideoMetadata from transformers.video_utils import VideoMetadata
from vllm.config import MultiModalConfig, VllmConfig from vllm.config import VllmConfig
from vllm.config.multimodal import BaseDummyOptions, VideoDummyOptions from vllm.config.multimodal import BaseDummyOptions, VideoDummyOptions
from vllm.distributed import get_tensor_model_parallel_world_size, parallel_state from vllm.distributed import get_tensor_model_parallel_world_size, parallel_state
from vllm.distributed import utils as dist_utils from vllm.distributed import utils as dist_utils
...@@ -107,6 +107,7 @@ from .utils import ( ...@@ -107,6 +107,7 @@ from .utils import (
) )
from .vision import ( from .vision import (
get_vit_attn_backend, get_vit_attn_backend,
is_vit_use_data_parallel,
run_dp_sharded_mrope_vision_model, run_dp_sharded_mrope_vision_model,
) )
...@@ -196,15 +197,10 @@ class Glm4vVisionMLP(nn.Module): ...@@ -196,15 +197,10 @@ class Glm4vVisionMLP(nn.Module):
hidden_features: int, hidden_features: int,
bias: bool = False, bias: bool = False,
quant_config: QuantizationConfig | None = None, quant_config: QuantizationConfig | None = None,
multimodal_config: MultiModalConfig | None = None,
prefix: str = "", prefix: str = "",
): ):
super().__init__() super().__init__()
use_data_parallel = ( use_data_parallel = is_vit_use_data_parallel()
multimodal_config.mm_encoder_tp_mode == "data"
if multimodal_config
else False
)
self.gate_up_proj = MergedColumnParallelLinear( self.gate_up_proj = MergedColumnParallelLinear(
input_size=in_features, input_size=in_features,
output_sizes=[hidden_features] * 2, output_sizes=[hidden_features] * 2,
...@@ -258,16 +254,11 @@ class Glm4vVisionAttention(nn.Module): ...@@ -258,16 +254,11 @@ class Glm4vVisionAttention(nn.Module):
num_heads: int, num_heads: int,
projection_size: int, projection_size: int,
quant_config: QuantizationConfig | None = None, quant_config: QuantizationConfig | None = None,
multimodal_config: MultiModalConfig | None = None,
prefix: str = "", prefix: str = "",
) -> None: ) -> None:
super().__init__() super().__init__()
# Per attention head and per partition values. # Per attention head and per partition values.
use_data_parallel = ( use_data_parallel = is_vit_use_data_parallel()
multimodal_config.mm_encoder_tp_mode == "data"
if multimodal_config
else False
)
self.tp_size = ( self.tp_size = (
1 if use_data_parallel else get_tensor_model_parallel_world_size() 1 if use_data_parallel else get_tensor_model_parallel_world_size()
) )
...@@ -305,7 +296,6 @@ class Glm4vVisionAttention(nn.Module): ...@@ -305,7 +296,6 @@ class Glm4vVisionAttention(nn.Module):
num_heads=self.num_attention_heads_per_partition, num_heads=self.num_attention_heads_per_partition,
head_size=self.hidden_size_per_attention_head, head_size=self.hidden_size_per_attention_head,
scale=self.hidden_size_per_attention_head**-0.5, scale=self.hidden_size_per_attention_head**-0.5,
multimodal_config=multimodal_config,
) )
self.apply_rotary_emb = ApplyRotaryEmb(enforce_enable=True) self.apply_rotary_emb = ApplyRotaryEmb(enforce_enable=True)
...@@ -373,7 +363,6 @@ class Glm4vVisionBlock(nn.Module): ...@@ -373,7 +363,6 @@ class Glm4vVisionBlock(nn.Module):
mlp_hidden_dim: int, mlp_hidden_dim: int,
norm_layer: Callable[[int], nn.Module] | None = None, norm_layer: Callable[[int], nn.Module] | None = None,
quant_config: QuantizationConfig | None = None, quant_config: QuantizationConfig | None = None,
multimodal_config: MultiModalConfig | None = None,
prefix: str = "", prefix: str = "",
) -> None: ) -> None:
super().__init__() super().__init__()
...@@ -386,7 +375,6 @@ class Glm4vVisionBlock(nn.Module): ...@@ -386,7 +375,6 @@ class Glm4vVisionBlock(nn.Module):
num_heads=num_heads, num_heads=num_heads,
projection_size=dim, projection_size=dim,
quant_config=quant_config, quant_config=quant_config,
multimodal_config=multimodal_config,
prefix=f"{prefix}.attn", prefix=f"{prefix}.attn",
) )
self.mlp = Glm4vVisionMLP( self.mlp = Glm4vVisionMLP(
...@@ -394,7 +382,6 @@ class Glm4vVisionBlock(nn.Module): ...@@ -394,7 +382,6 @@ class Glm4vVisionBlock(nn.Module):
mlp_hidden_dim, mlp_hidden_dim,
bias=False, bias=False,
quant_config=quant_config, quant_config=quant_config,
multimodal_config=multimodal_config,
prefix=f"{prefix}.mlp", prefix=f"{prefix}.mlp",
) )
...@@ -454,16 +441,11 @@ class Glm4vPatchMerger(nn.Module): ...@@ -454,16 +441,11 @@ class Glm4vPatchMerger(nn.Module):
d_model: int, d_model: int,
context_dim: int, context_dim: int,
quant_config: QuantizationConfig | None = None, quant_config: QuantizationConfig | None = None,
multimodal_config: MultiModalConfig | None = None,
bias: bool = False, bias: bool = False,
prefix: str = "", prefix: str = "",
) -> None: ) -> None:
super().__init__() super().__init__()
use_data_parallel = ( use_data_parallel = is_vit_use_data_parallel()
multimodal_config.mm_encoder_tp_mode == "data"
if multimodal_config
else False
)
self.hidden_size = d_model self.hidden_size = d_model
self.proj = ColumnParallelLinear( self.proj = ColumnParallelLinear(
self.hidden_size, self.hidden_size,
...@@ -619,13 +601,10 @@ class Glm4vVisionTransformer(nn.Module): ...@@ -619,13 +601,10 @@ class Glm4vVisionTransformer(nn.Module):
vision_config: Glm4vVisionConfig, vision_config: Glm4vVisionConfig,
norm_eps: float = 1e-6, norm_eps: float = 1e-6,
quant_config: QuantizationConfig | None = None, quant_config: QuantizationConfig | None = None,
multimodal_config: MultiModalConfig | None = None,
prefix: str = "", prefix: str = "",
) -> None: ) -> None:
super().__init__() super().__init__()
assert multimodal_config is not None, "multimodal_config must be provided"
patch_size = vision_config.patch_size patch_size = vision_config.patch_size
temporal_patch_size = vision_config.temporal_patch_size temporal_patch_size = vision_config.temporal_patch_size
in_channels = vision_config.in_channels in_channels = vision_config.in_channels
...@@ -660,7 +639,6 @@ class Glm4vVisionTransformer(nn.Module): ...@@ -660,7 +639,6 @@ class Glm4vVisionTransformer(nn.Module):
mlp_hidden_dim=vision_config.out_hidden_size, mlp_hidden_dim=vision_config.out_hidden_size,
norm_layer=norm_layer, norm_layer=norm_layer,
quant_config=quant_config, quant_config=quant_config,
multimodal_config=multimodal_config,
prefix=f"{prefix}.blocks.{layer_idx}", prefix=f"{prefix}.blocks.{layer_idx}",
) )
for layer_idx in range(depth) for layer_idx in range(depth)
...@@ -670,7 +648,6 @@ class Glm4vVisionTransformer(nn.Module): ...@@ -670,7 +648,6 @@ class Glm4vVisionTransformer(nn.Module):
d_model=vision_config.out_hidden_size, d_model=vision_config.out_hidden_size,
context_dim=vision_config.intermediate_size, context_dim=vision_config.intermediate_size,
quant_config=quant_config, quant_config=quant_config,
multimodal_config=multimodal_config,
bias=False, bias=False,
prefix=f"{prefix}.merger", prefix=f"{prefix}.merger",
) )
...@@ -692,7 +669,6 @@ class Glm4vVisionTransformer(nn.Module): ...@@ -692,7 +669,6 @@ class Glm4vVisionTransformer(nn.Module):
self.attn_backend = get_vit_attn_backend( self.attn_backend = get_vit_attn_backend(
head_size=head_dim, head_size=head_dim,
dtype=torch.get_default_dtype(), dtype=torch.get_default_dtype(),
attn_backend_override=multimodal_config.mm_encoder_attn_backend,
) )
@property @property
...@@ -1439,7 +1415,6 @@ class Glm4vForConditionalGeneration( ...@@ -1439,7 +1415,6 @@ class Glm4vForConditionalGeneration(
config.vision_config, config.vision_config,
norm_eps=getattr(config, "rms_norm_eps", 1e-5), norm_eps=getattr(config, "rms_norm_eps", 1e-5),
quant_config=quant_config, quant_config=quant_config,
multimodal_config=multimodal_config,
prefix=maybe_prefix(prefix, "visual"), prefix=maybe_prefix(prefix, "visual"),
) )
......
...@@ -33,7 +33,7 @@ import torch.nn as nn ...@@ -33,7 +33,7 @@ import torch.nn as nn
import torch.nn.functional as F import torch.nn.functional as F
from transformers import BatchFeature from transformers import BatchFeature
from vllm.config import MultiModalConfig, VllmConfig from vllm.config import VllmConfig
from vllm.config.multimodal import BaseDummyOptions from vllm.config.multimodal import BaseDummyOptions
from vllm.distributed import parallel_state from vllm.distributed import parallel_state
from vllm.distributed import utils as dist_utils from vllm.distributed import utils as dist_utils
...@@ -80,7 +80,6 @@ from vllm.transformers_utils.configs.hunyuan_vl import ( ...@@ -80,7 +80,6 @@ from vllm.transformers_utils.configs.hunyuan_vl import (
from vllm.transformers_utils.processors.hunyuan_vl import HunYuanVLProcessor from vllm.transformers_utils.processors.hunyuan_vl import HunYuanVLProcessor
from vllm.transformers_utils.processors.hunyuan_vl_image import smart_resize from vllm.transformers_utils.processors.hunyuan_vl_image import smart_resize
from vllm.utils.tensor_schema import TensorSchema, TensorShape from vllm.utils.tensor_schema import TensorSchema, TensorShape
from vllm.v1.attention.backends.registry import AttentionBackendEnum
from .interfaces import ( from .interfaces import (
MultiModalEmbeddings, MultiModalEmbeddings,
...@@ -96,6 +95,7 @@ from .utils import ( ...@@ -96,6 +95,7 @@ from .utils import (
init_vllm_registered_model, init_vllm_registered_model,
maybe_prefix, maybe_prefix,
) )
from .vision import is_vit_use_data_parallel
logger = init_logger(__name__) logger = init_logger(__name__)
...@@ -160,9 +160,9 @@ class HunYuanVisionMLP(nn.Module): ...@@ -160,9 +160,9 @@ class HunYuanVisionMLP(nn.Module):
act_fn: Callable[[torch.Tensor], torch.Tensor] = F.gelu, act_fn: Callable[[torch.Tensor], torch.Tensor] = F.gelu,
quant_config: QuantizationConfig | None = None, quant_config: QuantizationConfig | None = None,
prefix: str = "", prefix: str = "",
use_data_parallel: bool = False,
): ):
super().__init__() super().__init__()
use_data_parallel = is_vit_use_data_parallel()
self.dense_h_to_4h = ColumnParallelLinear( self.dense_h_to_4h = ColumnParallelLinear(
in_features, in_features,
hidden_features, hidden_features,
...@@ -194,12 +194,11 @@ class HunYuanVisionAttention(nn.Module): ...@@ -194,12 +194,11 @@ class HunYuanVisionAttention(nn.Module):
num_heads: int, num_heads: int,
projection_size: int, projection_size: int,
quant_config: QuantizationConfig | None = None, quant_config: QuantizationConfig | None = None,
multimodal_config: MultiModalConfig | None = None,
prefix: str = "", prefix: str = "",
use_data_parallel: bool = False,
) -> None: ) -> None:
super().__init__() super().__init__()
# Per attention head and per partition values. # Per attention head and per partition values.
use_data_parallel = is_vit_use_data_parallel()
self.tp_size = ( self.tp_size = (
1 1
if use_data_parallel if use_data_parallel
...@@ -237,7 +236,6 @@ class HunYuanVisionAttention(nn.Module): ...@@ -237,7 +236,6 @@ class HunYuanVisionAttention(nn.Module):
self.hidden_size_per_attention_head, self.hidden_size_per_attention_head,
self.scale, self.scale,
prefix=f"{prefix}.attn", prefix=f"{prefix}.attn",
multimodal_config=multimodal_config,
) )
def forward( def forward(
...@@ -260,9 +258,7 @@ class HunYuanVisionBlock(nn.Module): ...@@ -260,9 +258,7 @@ class HunYuanVisionBlock(nn.Module):
act_fn: Callable[[torch.Tensor], torch.Tensor] = F.gelu, act_fn: Callable[[torch.Tensor], torch.Tensor] = F.gelu,
norm_layer: Callable[[int], nn.Module] | None = None, norm_layer: Callable[[int], nn.Module] | None = None,
quant_config: QuantizationConfig | None = None, quant_config: QuantizationConfig | None = None,
multimodal_config: MultiModalConfig | None = None,
prefix: str = "", prefix: str = "",
use_data_parallel: bool = False,
) -> None: ) -> None:
super().__init__() super().__init__()
if norm_layer is None: if norm_layer is None:
...@@ -274,9 +270,7 @@ class HunYuanVisionBlock(nn.Module): ...@@ -274,9 +270,7 @@ class HunYuanVisionBlock(nn.Module):
num_heads=num_heads, num_heads=num_heads,
projection_size=dim, projection_size=dim,
quant_config=quant_config, quant_config=quant_config,
multimodal_config=multimodal_config,
prefix=f"{prefix}.self_attn", prefix=f"{prefix}.self_attn",
use_data_parallel=use_data_parallel,
) )
self.mlp = HunYuanVisionMLP( self.mlp = HunYuanVisionMLP(
dim, dim,
...@@ -285,7 +279,6 @@ class HunYuanVisionBlock(nn.Module): ...@@ -285,7 +279,6 @@ class HunYuanVisionBlock(nn.Module):
bias=True, bias=True,
quant_config=quant_config, quant_config=quant_config,
prefix=f"{prefix}.mlp", prefix=f"{prefix}.mlp",
use_data_parallel=use_data_parallel,
) )
def forward( def forward(
...@@ -439,9 +432,6 @@ class HunYuanVisionTransformer(nn.Module): ...@@ -439,9 +432,6 @@ class HunYuanVisionTransformer(nn.Module):
vision_config: HunYuanVLVisionConfig, vision_config: HunYuanVLVisionConfig,
quant_config: QuantizationConfig | None = None, quant_config: QuantizationConfig | None = None,
prefix: str = "", prefix: str = "",
use_data_parallel: bool = False,
multimodal_config: MultiModalConfig | None = None,
attn_backend_override: AttentionBackendEnum | None = None,
) -> None: ) -> None:
super().__init__() super().__init__()
...@@ -467,9 +457,7 @@ class HunYuanVisionTransformer(nn.Module): ...@@ -467,9 +457,7 @@ class HunYuanVisionTransformer(nn.Module):
act_fn=get_act_fn(vision_config.hidden_act), act_fn=get_act_fn(vision_config.hidden_act),
norm_layer=norm_layer, norm_layer=norm_layer,
quant_config=quant_config, quant_config=quant_config,
multimodal_config=multimodal_config,
prefix=f"{prefix}.layers.{layer_idx}", prefix=f"{prefix}.layers.{layer_idx}",
use_data_parallel=use_data_parallel,
) )
for layer_idx in range(num_hidden_layers) for layer_idx in range(num_hidden_layers)
] ]
...@@ -872,23 +860,14 @@ class HunYuanVLForConditionalGeneration( ...@@ -872,23 +860,14 @@ class HunYuanVLForConditionalGeneration(
def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
super().__init__() super().__init__()
config: HunYuanVLConfig = vllm_config.model_config.hf_config config: HunYuanVLConfig = vllm_config.model_config.hf_config
multimodal_config = vllm_config.model_config.multimodal_config
self.config = config self.config = config
self.multimodal_config = multimodal_config
with self._mark_tower_model(vllm_config, {"image"}): with self._mark_tower_model(vllm_config, {"image"}):
attn_backend_override = (
multimodal_config.mm_encoder_attn_backend
if multimodal_config is not None
else None
)
self.visual = HunYuanVisionTransformer( self.visual = HunYuanVisionTransformer(
config.vision_config, config.vision_config,
quant_config=self.quant_config, quant_config=vllm_config.quant_config,
prefix=maybe_prefix(prefix, "visual"), prefix=maybe_prefix(prefix, "visual"),
multimodal_config=multimodal_config,
attn_backend_override=attn_backend_override,
) )
with self._mark_language_model(vllm_config): with self._mark_language_model(vllm_config):
......
...@@ -17,7 +17,7 @@ from timm.models.regnet import RegStage ...@@ -17,7 +17,7 @@ from timm.models.regnet import RegStage
from transformers import BatchFeature, CLIPVisionConfig, SiglipVisionConfig from transformers import BatchFeature, CLIPVisionConfig, SiglipVisionConfig
from vllm.config import VllmConfig from vllm.config import VllmConfig
from vllm.config.multimodal import BaseDummyOptions, MultiModalConfig from vllm.config.multimodal import BaseDummyOptions
from vllm.model_executor.layers.quantization import QuantizationConfig from vllm.model_executor.layers.quantization import QuantizationConfig
from vllm.multimodal import MULTIMODAL_REGISTRY from vllm.multimodal import MULTIMODAL_REGISTRY
from vllm.multimodal.cache import BaseMultiModalProcessorCache from vllm.multimodal.cache import BaseMultiModalProcessorCache
...@@ -360,7 +360,6 @@ def _build_hcxvision_hf_processor( ...@@ -360,7 +360,6 @@ def _build_hcxvision_hf_processor(
def init_vision_tower_for_hcxvision( def init_vision_tower_for_hcxvision(
vision_config, vision_config,
quant_config: QuantizationConfig | None, quant_config: QuantizationConfig | None,
multimodal_config: MultiModalConfig | None,
*, *,
use_nth_layer: int | None = None, use_nth_layer: int | None = None,
require_post_norm: bool | None = None, require_post_norm: bool | None = None,
...@@ -378,7 +377,6 @@ def init_vision_tower_for_hcxvision( ...@@ -378,7 +377,6 @@ def init_vision_tower_for_hcxvision(
return CLIPVisionModel( return CLIPVisionModel(
vision_config, vision_config,
quant_config=quant_config, quant_config=quant_config,
multimodal_config=multimodal_config,
num_hidden_layers_override=num_hidden_layers, num_hidden_layers_override=num_hidden_layers,
require_post_norm=require_post_norm, require_post_norm=require_post_norm,
prefix=prefix, prefix=prefix,
...@@ -387,7 +385,6 @@ def init_vision_tower_for_hcxvision( ...@@ -387,7 +385,6 @@ def init_vision_tower_for_hcxvision(
return SiglipVisionModel( return SiglipVisionModel(
vision_config, vision_config,
quant_config=quant_config, quant_config=quant_config,
multimodal_config=multimodal_config,
num_hidden_layers_override=num_hidden_layers, num_hidden_layers_override=num_hidden_layers,
require_post_norm=require_post_norm, require_post_norm=require_post_norm,
prefix=prefix, prefix=prefix,
...@@ -605,7 +602,6 @@ class HCXVisionForCausalLM(nn.Module, SupportsMultiModal, SupportsPP): ...@@ -605,7 +602,6 @@ class HCXVisionForCausalLM(nn.Module, SupportsMultiModal, SupportsPP):
# init configs # init configs
config = vllm_config.model_config.hf_config config = vllm_config.model_config.hf_config
quant_config = vllm_config.quant_config quant_config = vllm_config.quant_config
multimodal_config = vllm_config.model_config.multimodal_config
# text_config # text_config
text_config = config.text_config text_config = config.text_config
if text_config.model_type in ["gpt2", "hyperclovax", "llama"]: if text_config.model_type in ["gpt2", "hyperclovax", "llama"]:
...@@ -628,7 +624,6 @@ class HCXVisionForCausalLM(nn.Module, SupportsMultiModal, SupportsPP): ...@@ -628,7 +624,6 @@ class HCXVisionForCausalLM(nn.Module, SupportsMultiModal, SupportsPP):
self.vision_model = init_vision_tower_for_hcxvision( self.vision_model = init_vision_tower_for_hcxvision(
vision_config, vision_config,
quant_config=quant_config, quant_config=quant_config,
multimodal_config=multimodal_config,
use_nth_layer=getattr(config, "use_nth_layer", -1), use_nth_layer=getattr(config, "use_nth_layer", -1),
require_post_norm=False, require_post_norm=False,
prefix=maybe_prefix(prefix, "vision_model"), prefix=maybe_prefix(prefix, "vision_model"),
......
...@@ -16,7 +16,7 @@ from transformers.image_processing_utils import BatchFeature ...@@ -16,7 +16,7 @@ from transformers.image_processing_utils import BatchFeature
from transformers.tokenization_utils import TensorType from transformers.tokenization_utils import TensorType
from typing_extensions import TypedDict, Unpack from typing_extensions import TypedDict, Unpack
from vllm.config import MultiModalConfig, VllmConfig from vllm.config import VllmConfig
from vllm.config.model import ModelConfig from vllm.config.model import ModelConfig
from vllm.distributed import parallel_state from vllm.distributed import parallel_state
from vllm.distributed import utils as dist_utils from vllm.distributed import utils as dist_utils
...@@ -72,6 +72,8 @@ from vllm.transformers_utils.configs import ( ...@@ -72,6 +72,8 @@ from vllm.transformers_utils.configs import (
) )
from vllm.utils.tensor_schema import TensorSchema, TensorShape from vllm.utils.tensor_schema import TensorSchema, TensorShape
from .vision import is_vit_use_data_parallel
def create_cumulative_seq_lengths( def create_cumulative_seq_lengths(
seq_sizes: torch.Tensor, device: torch.device seq_sizes: torch.Tensor, device: torch.device
...@@ -942,15 +944,10 @@ class Siglip2VisionAttention(nn.Module): ...@@ -942,15 +944,10 @@ class Siglip2VisionAttention(nn.Module):
quant_config: QuantizationConfig | None = None, quant_config: QuantizationConfig | None = None,
*, *,
prefix: str = "", prefix: str = "",
multimodal_config: MultiModalConfig | None = None,
) -> None: ) -> None:
super().__init__() super().__init__()
use_data_parallel = ( use_data_parallel = is_vit_use_data_parallel()
multimodal_config.mm_encoder_tp_mode == "data"
if multimodal_config
else False
)
self.tp_size = ( self.tp_size = (
1 1
if use_data_parallel if use_data_parallel
...@@ -987,7 +984,6 @@ class Siglip2VisionAttention(nn.Module): ...@@ -987,7 +984,6 @@ class Siglip2VisionAttention(nn.Module):
head_size=self.hidden_size_per_attention_head, head_size=self.hidden_size_per_attention_head,
scale=self.hidden_size_per_attention_head**-0.5, scale=self.hidden_size_per_attention_head**-0.5,
prefix=f"{prefix}.attn", prefix=f"{prefix}.attn",
multimodal_config=multimodal_config,
) )
def split_qkv(self, qkv: torch.Tensor) -> tuple[torch.Tensor, ...]: def split_qkv(self, qkv: torch.Tensor) -> tuple[torch.Tensor, ...]:
...@@ -1038,7 +1034,6 @@ class Siglip2EncoderLayer(nn.Module): ...@@ -1038,7 +1034,6 @@ class Siglip2EncoderLayer(nn.Module):
quant_config: QuantizationConfig | None = None, quant_config: QuantizationConfig | None = None,
*, *,
prefix: str = "", prefix: str = "",
multimodal_config: MultiModalConfig | None = None,
) -> None: ) -> None:
super().__init__() super().__init__()
self.embed_dim = config.hidden_size self.embed_dim = config.hidden_size
...@@ -1047,7 +1042,6 @@ class Siglip2EncoderLayer(nn.Module): ...@@ -1047,7 +1042,6 @@ class Siglip2EncoderLayer(nn.Module):
config, config,
quant_config=quant_config, quant_config=quant_config,
prefix=f"{prefix}.self_attn", prefix=f"{prefix}.self_attn",
multimodal_config=multimodal_config,
) )
self.layer_norm2 = nn.LayerNorm(self.embed_dim, eps=config.layer_norm_eps) self.layer_norm2 = nn.LayerNorm(self.embed_dim, eps=config.layer_norm_eps)
self.mlp = SiglipMLP( self.mlp = SiglipMLP(
...@@ -1088,7 +1082,6 @@ class Siglip2Encoder(nn.Module): ...@@ -1088,7 +1082,6 @@ class Siglip2Encoder(nn.Module):
quant_config: QuantizationConfig | None = None, quant_config: QuantizationConfig | None = None,
*, *,
prefix: str = "", prefix: str = "",
multimodal_config: MultiModalConfig | None = None,
) -> None: ) -> None:
super().__init__() super().__init__()
self.config = config self.config = config
...@@ -1098,7 +1091,6 @@ class Siglip2Encoder(nn.Module): ...@@ -1098,7 +1091,6 @@ class Siglip2Encoder(nn.Module):
config, config,
quant_config=quant_config, quant_config=quant_config,
prefix=f"{prefix}.layers.{layer_idx}", prefix=f"{prefix}.layers.{layer_idx}",
multimodal_config=multimodal_config,
) )
for layer_idx in range(config.num_hidden_layers) for layer_idx in range(config.num_hidden_layers)
] ]
...@@ -1127,7 +1119,6 @@ class Siglip2VisionTransformer(nn.Module): ...@@ -1127,7 +1119,6 @@ class Siglip2VisionTransformer(nn.Module):
config: PixelShuffleSiglip2VisionConfig, config: PixelShuffleSiglip2VisionConfig,
quant_config: QuantizationConfig | None = None, quant_config: QuantizationConfig | None = None,
prefix: str = "", prefix: str = "",
multimodal_config: MultiModalConfig | None = None,
): ):
super().__init__() super().__init__()
self.config = config self.config = config
...@@ -1140,7 +1131,6 @@ class Siglip2VisionTransformer(nn.Module): ...@@ -1140,7 +1131,6 @@ class Siglip2VisionTransformer(nn.Module):
config, config,
quant_config=quant_config, quant_config=quant_config,
prefix=f"{prefix}.encoder", prefix=f"{prefix}.encoder",
multimodal_config=multimodal_config,
) )
self.post_layernorm = nn.LayerNorm(embed_dim, eps=config.layer_norm_eps) self.post_layernorm = nn.LayerNorm(embed_dim, eps=config.layer_norm_eps)
...@@ -1221,14 +1211,12 @@ class IsaacVisionEmbedding(nn.Module): ...@@ -1221,14 +1211,12 @@ class IsaacVisionEmbedding(nn.Module):
hidden_dim: int, hidden_dim: int,
output_dim: int, output_dim: int,
quant_config: QuantizationConfig | None = None, quant_config: QuantizationConfig | None = None,
multimodal_config: MultiModalConfig | None = None,
prefix: str = "", prefix: str = "",
): ):
super().__init__() super().__init__()
self.transformer = Siglip2VisionTransformer( self.transformer = Siglip2VisionTransformer(
vision_cfg, vision_cfg,
quant_config=quant_config, quant_config=quant_config,
multimodal_config=multimodal_config,
prefix=maybe_prefix(prefix, "0"), prefix=maybe_prefix(prefix, "0"),
) )
self.linear_fc1 = ColumnParallelLinear( self.linear_fc1 = ColumnParallelLinear(
...@@ -1309,7 +1297,6 @@ class IsaacForConditionalGeneration( ...@@ -1309,7 +1297,6 @@ class IsaacForConditionalGeneration(
config: IsaacConfig = vllm_config.model_config.hf_config config: IsaacConfig = vllm_config.model_config.hf_config
quant_config = vllm_config.quant_config quant_config = vllm_config.quant_config
self.config = config self.config = config
self.multimodal_config = vllm_config.model_config.multimodal_config
head_dim = config.head_dim head_dim = config.head_dim
calculated_mrope_section = [ calculated_mrope_section = [
...@@ -1373,7 +1360,6 @@ class IsaacForConditionalGeneration( ...@@ -1373,7 +1360,6 @@ class IsaacForConditionalGeneration(
hidden_dim=hidden_dim, hidden_dim=hidden_dim,
output_dim=config.hidden_size, output_dim=config.hidden_size,
quant_config=quant_config, quant_config=quant_config,
multimodal_config=self.multimodal_config,
prefix=maybe_prefix(prefix, "vision_embedding"), prefix=maybe_prefix(prefix, "vision_embedding"),
) )
......
...@@ -16,7 +16,7 @@ from transformers.feature_extraction_utils import BatchFeature ...@@ -16,7 +16,7 @@ from transformers.feature_extraction_utils import BatchFeature
from transformers.modeling_outputs import BaseModelOutput, BaseModelOutputWithPooling from transformers.modeling_outputs import BaseModelOutput, BaseModelOutputWithPooling
from transformers.utils import torch_int from transformers.utils import torch_int
from vllm.config import MultiModalConfig, VllmConfig from vllm.config import VllmConfig
from vllm.config.multimodal import BaseDummyOptions from vllm.config.multimodal import BaseDummyOptions
from vllm.distributed import get_tensor_model_parallel_world_size from vllm.distributed import get_tensor_model_parallel_world_size
from vllm.logger import init_logger from vllm.logger import init_logger
...@@ -80,6 +80,7 @@ from .utils import ( ...@@ -80,6 +80,7 @@ from .utils import (
is_pp_missing_parameter, is_pp_missing_parameter,
maybe_prefix, maybe_prefix,
) )
from .vision import is_vit_use_data_parallel
logger = init_logger(__name__) logger = init_logger(__name__)
...@@ -358,7 +359,6 @@ class KeyeSiglipAttention(nn.Module): ...@@ -358,7 +359,6 @@ class KeyeSiglipAttention(nn.Module):
self, self,
config: PretrainedConfig, config: PretrainedConfig,
quant_config: QuantizationConfig | None = None, quant_config: QuantizationConfig | None = None,
multimodal_config: MultiModalConfig | None = None,
prefix: str = "", prefix: str = "",
): ):
super().__init__() super().__init__()
...@@ -366,7 +366,8 @@ class KeyeSiglipAttention(nn.Module): ...@@ -366,7 +366,8 @@ class KeyeSiglipAttention(nn.Module):
hidden_size = config.hidden_size hidden_size = config.hidden_size
self.hidden_size = config.hidden_size self.hidden_size = config.hidden_size
tp_size = get_tensor_model_parallel_world_size() use_data_parallel = is_vit_use_data_parallel()
tp_size = 1 if use_data_parallel else get_tensor_model_parallel_world_size()
self.total_num_heads = config.num_attention_heads self.total_num_heads = config.num_attention_heads
assert self.total_num_heads % tp_size == 0 assert self.total_num_heads % tp_size == 0
self.num_heads = self.total_num_heads // tp_size self.num_heads = self.total_num_heads // tp_size
...@@ -403,7 +404,6 @@ class KeyeSiglipAttention(nn.Module): ...@@ -403,7 +404,6 @@ class KeyeSiglipAttention(nn.Module):
scale=self.scale, scale=self.scale,
num_kv_heads=self.num_kv_heads, num_kv_heads=self.num_kv_heads,
prefix=f"{prefix}.attn", prefix=f"{prefix}.attn",
multimodal_config=multimodal_config,
) )
self.apply_rotary_emb = ApplyRotaryEmb( self.apply_rotary_emb = ApplyRotaryEmb(
...@@ -497,7 +497,6 @@ class KeyeSiglipEncoderLayer(nn.Module): ...@@ -497,7 +497,6 @@ class KeyeSiglipEncoderLayer(nn.Module):
self, self,
config: PretrainedConfig, config: PretrainedConfig,
quant_config: QuantizationConfig | None = None, quant_config: QuantizationConfig | None = None,
multimodal_config: MultiModalConfig | None = None,
prefix: str = "", prefix: str = "",
): ):
super().__init__() super().__init__()
...@@ -506,14 +505,12 @@ class KeyeSiglipEncoderLayer(nn.Module): ...@@ -506,14 +505,12 @@ class KeyeSiglipEncoderLayer(nn.Module):
self.self_attn = KeyeSiglipAttention( self.self_attn = KeyeSiglipAttention(
config, config,
quant_config=quant_config, quant_config=quant_config,
multimodal_config=multimodal_config,
prefix=f"{prefix}.self_attn", prefix=f"{prefix}.self_attn",
) )
self.layer_norm2 = nn.LayerNorm(self.embed_dim, eps=config.layer_norm_eps) self.layer_norm2 = nn.LayerNorm(self.embed_dim, eps=config.layer_norm_eps)
self.mlp = SiglipMLP( self.mlp = SiglipMLP(
config, config,
quant_config=quant_config, quant_config=quant_config,
multimodal_config=multimodal_config,
prefix=f"{prefix}.mlp", prefix=f"{prefix}.mlp",
) )
...@@ -552,7 +549,6 @@ class KeyeSiglipEncoder(nn.Module): ...@@ -552,7 +549,6 @@ class KeyeSiglipEncoder(nn.Module):
self, self,
config: PretrainedConfig, config: PretrainedConfig,
quant_config: QuantizationConfig | None = None, quant_config: QuantizationConfig | None = None,
multimodal_config: MultiModalConfig | None = None,
prefix: str = "", prefix: str = "",
): ):
super().__init__() super().__init__()
...@@ -565,7 +561,6 @@ class KeyeSiglipEncoder(nn.Module): ...@@ -565,7 +561,6 @@ class KeyeSiglipEncoder(nn.Module):
KeyeSiglipEncoderLayer( KeyeSiglipEncoderLayer(
config, config,
quant_config=quant_config, quant_config=quant_config,
multimodal_config=multimodal_config,
prefix=f"{prefix}.layers.{layer_idx}", prefix=f"{prefix}.layers.{layer_idx}",
) )
for layer_idx in range(config.num_hidden_layers) for layer_idx in range(config.num_hidden_layers)
...@@ -647,7 +642,6 @@ class KeyeSiglipVisionTransformer(nn.Module): ...@@ -647,7 +642,6 @@ class KeyeSiglipVisionTransformer(nn.Module):
self, self,
config: PretrainedConfig, config: PretrainedConfig,
quant_config: QuantizationConfig | None = None, quant_config: QuantizationConfig | None = None,
multimodal_config: MultiModalConfig | None = None,
prefix: str = "", prefix: str = "",
): ):
super().__init__() super().__init__()
...@@ -658,7 +652,6 @@ class KeyeSiglipVisionTransformer(nn.Module): ...@@ -658,7 +652,6 @@ class KeyeSiglipVisionTransformer(nn.Module):
self.encoder = KeyeSiglipEncoder( self.encoder = KeyeSiglipEncoder(
config, config,
quant_config=quant_config, quant_config=quant_config,
multimodal_config=multimodal_config,
prefix=f"{prefix}.encoder", prefix=f"{prefix}.encoder",
) )
self.post_layernorm = nn.LayerNorm(embed_dim, eps=config.layer_norm_eps) self.post_layernorm = nn.LayerNorm(embed_dim, eps=config.layer_norm_eps)
...@@ -730,7 +723,6 @@ class KeyeSiglipVisionModel(nn.Module): ...@@ -730,7 +723,6 @@ class KeyeSiglipVisionModel(nn.Module):
self, self,
config: PretrainedConfig, config: PretrainedConfig,
quant_config: QuantizationConfig | None = None, quant_config: QuantizationConfig | None = None,
multimodal_config: MultiModalConfig | None = None,
prefix: str = "", prefix: str = "",
): ):
super().__init__() super().__init__()
...@@ -738,7 +730,6 @@ class KeyeSiglipVisionModel(nn.Module): ...@@ -738,7 +730,6 @@ class KeyeSiglipVisionModel(nn.Module):
self.vision_model = KeyeSiglipVisionTransformer( self.vision_model = KeyeSiglipVisionTransformer(
config, config,
quant_config=quant_config, quant_config=quant_config,
multimodal_config=multimodal_config,
prefix=f"{prefix}.vision_model", prefix=f"{prefix}.vision_model",
) )
self.quant_config = quant_config self.quant_config = quant_config
...@@ -1275,16 +1266,13 @@ class BaseKeyeModule(nn.Module, SupportsMultiModal): ...@@ -1275,16 +1266,13 @@ class BaseKeyeModule(nn.Module, SupportsMultiModal):
super().__init__() super().__init__()
config: PretrainedConfig = vllm_config.model_config.hf_config config: PretrainedConfig = vllm_config.model_config.hf_config
quant_config = vllm_config.quant_config quant_config = vllm_config.quant_config
multimodal_config = vllm_config.model_config.multimodal_config
self.config = config self.config = config
self.multimodal_config = multimodal_config
with self._mark_tower_model(vllm_config, {"image", "video"}): with self._mark_tower_model(vllm_config, {"image", "video"}):
self.visual = KeyeSiglipVisionModel( self.visual = KeyeSiglipVisionModel(
config.vision_config, config.vision_config,
quant_config=quant_config, quant_config=quant_config,
multimodal_config=multimodal_config,
prefix=maybe_prefix(prefix, "visual"), prefix=maybe_prefix(prefix, "visual"),
) )
self.mlp_AR = self._build_projector( self.mlp_AR = self._build_projector(
......
...@@ -317,7 +317,6 @@ class KimiVLForConditionalGeneration(nn.Module, SupportsMultiModal, SupportsPP): ...@@ -317,7 +317,6 @@ class KimiVLForConditionalGeneration(nn.Module, SupportsMultiModal, SupportsPP):
with self._mark_tower_model(vllm_config, "image"): with self._mark_tower_model(vllm_config, "image"):
self.vision_tower = MoonVitPretrainedModel( self.vision_tower = MoonVitPretrainedModel(
config.vision_config, config.vision_config,
multimodal_config=model_config.multimodal_config,
prefix=maybe_prefix(prefix, "vision_tower"), prefix=maybe_prefix(prefix, "vision_tower"),
) )
self.multi_modal_projector = KimiVLMultiModalProjector( self.multi_modal_projector = KimiVLMultiModalProjector(
......
...@@ -11,7 +11,6 @@ from torch.nn import functional as F ...@@ -11,7 +11,6 @@ from torch.nn import functional as F
from transformers import Siglip2VisionConfig from transformers import Siglip2VisionConfig
from vllm.compilation.decorators import support_torch_compile from vllm.compilation.decorators import support_torch_compile
from vllm.config import MultiModalConfig
from vllm.distributed import get_tensor_model_parallel_world_size from vllm.distributed import get_tensor_model_parallel_world_size
from vllm.model_executor.layers.activation import get_act_fn from vllm.model_executor.layers.activation import get_act_fn
from vllm.model_executor.layers.attention.mm_encoder_attention import MMEncoderAttention from vllm.model_executor.layers.attention.mm_encoder_attention import MMEncoderAttention
...@@ -23,7 +22,7 @@ from vllm.model_executor.layers.linear import ( ...@@ -23,7 +22,7 @@ from vllm.model_executor.layers.linear import (
from vllm.model_executor.layers.quantization import QuantizationConfig from vllm.model_executor.layers.quantization import QuantizationConfig
from vllm.model_executor.model_loader.weight_utils import default_weight_loader from vllm.model_executor.model_loader.weight_utils import default_weight_loader
from .vision import should_torch_compile_mm_vit from .vision import is_vit_use_data_parallel, should_torch_compile_mm_vit
class Siglip2VisionEmbeddings(nn.Module): class Siglip2VisionEmbeddings(nn.Module):
...@@ -154,7 +153,6 @@ class Siglip2Attention(nn.Module): ...@@ -154,7 +153,6 @@ class Siglip2Attention(nn.Module):
self, self,
config: Siglip2VisionConfig, config: Siglip2VisionConfig,
quant_config: QuantizationConfig | None = None, quant_config: QuantizationConfig | None = None,
multimodal_config: MultiModalConfig | None = None,
prefix: str = "", prefix: str = "",
): ):
super().__init__() super().__init__()
...@@ -171,10 +169,7 @@ class Siglip2Attention(nn.Module): ...@@ -171,10 +169,7 @@ class Siglip2Attention(nn.Module):
self.scale = self.head_dim**-0.5 self.scale = self.head_dim**-0.5
self.dropout = config.attention_dropout self.dropout = config.attention_dropout
use_data_parallel = ( use_data_parallel = is_vit_use_data_parallel()
multimodal_config is not None
and multimodal_config.mm_encoder_tp_mode == "data"
)
tp_size = 1 if use_data_parallel else get_tensor_model_parallel_world_size() tp_size = 1 if use_data_parallel else get_tensor_model_parallel_world_size()
assert self.num_heads % tp_size == 0 assert self.num_heads % tp_size == 0
self.num_heads_per_partition = self.num_heads // tp_size self.num_heads_per_partition = self.num_heads // tp_size
...@@ -199,7 +194,6 @@ class Siglip2Attention(nn.Module): ...@@ -199,7 +194,6 @@ class Siglip2Attention(nn.Module):
head_size=self.head_dim, head_size=self.head_dim,
scale=self.scale, scale=self.scale,
prefix=f"{prefix}.attn", prefix=f"{prefix}.attn",
multimodal_config=multimodal_config,
) )
def forward( def forward(
...@@ -241,16 +235,12 @@ class Siglip2MLP(nn.Module): ...@@ -241,16 +235,12 @@ class Siglip2MLP(nn.Module):
self, self,
config: Siglip2VisionConfig, config: Siglip2VisionConfig,
quant_config: QuantizationConfig | None = None, quant_config: QuantizationConfig | None = None,
multimodal_config: MultiModalConfig | None = None,
prefix: str = "", prefix: str = "",
): ):
super().__init__() super().__init__()
self.config = config self.config = config
self.activation_fn = get_act_fn(config.hidden_act) self.activation_fn = get_act_fn(config.hidden_act)
use_data_parallel = ( use_data_parallel = is_vit_use_data_parallel()
multimodal_config is not None
and multimodal_config.mm_encoder_tp_mode == "data"
)
self.fc1 = ColumnParallelLinear( self.fc1 = ColumnParallelLinear(
config.hidden_size, config.hidden_size,
config.intermediate_size, config.intermediate_size,
...@@ -282,7 +272,6 @@ class Siglip2EncoderLayer(nn.Module): ...@@ -282,7 +272,6 @@ class Siglip2EncoderLayer(nn.Module):
self, self,
config: Siglip2VisionConfig, config: Siglip2VisionConfig,
quant_config: QuantizationConfig | None = None, quant_config: QuantizationConfig | None = None,
multimodal_config: MultiModalConfig | None = None,
prefix: str = "", prefix: str = "",
): ):
super().__init__() super().__init__()
...@@ -291,14 +280,12 @@ class Siglip2EncoderLayer(nn.Module): ...@@ -291,14 +280,12 @@ class Siglip2EncoderLayer(nn.Module):
self.self_attn = Siglip2Attention( self.self_attn = Siglip2Attention(
config, config,
quant_config=quant_config, quant_config=quant_config,
multimodal_config=multimodal_config,
prefix=f"{prefix}.self_attn", prefix=f"{prefix}.self_attn",
) )
self.layer_norm2 = nn.LayerNorm(self.embed_dim, eps=config.layer_norm_eps) self.layer_norm2 = nn.LayerNorm(self.embed_dim, eps=config.layer_norm_eps)
self.mlp = Siglip2MLP( self.mlp = Siglip2MLP(
config, config,
quant_config=quant_config, quant_config=quant_config,
multimodal_config=multimodal_config,
prefix=f"{prefix}.mlp", prefix=f"{prefix}.mlp",
) )
...@@ -344,7 +331,6 @@ class Siglip2Encoder(nn.Module): ...@@ -344,7 +331,6 @@ class Siglip2Encoder(nn.Module):
self, self,
config: Siglip2VisionConfig, config: Siglip2VisionConfig,
quant_config: QuantizationConfig | None = None, quant_config: QuantizationConfig | None = None,
multimodal_config: MultiModalConfig | None = None,
prefix: str = "", prefix: str = "",
): ):
super().__init__() super().__init__()
...@@ -354,7 +340,6 @@ class Siglip2Encoder(nn.Module): ...@@ -354,7 +340,6 @@ class Siglip2Encoder(nn.Module):
Siglip2EncoderLayer( Siglip2EncoderLayer(
config=config, config=config,
quant_config=quant_config, quant_config=quant_config,
multimodal_config=multimodal_config,
prefix=f"{prefix}.layers.{idx}", prefix=f"{prefix}.layers.{idx}",
) )
for idx in range(config.num_hidden_layers) for idx in range(config.num_hidden_layers)
...@@ -383,7 +368,6 @@ class Siglip2VisionTransformer(nn.Module): ...@@ -383,7 +368,6 @@ class Siglip2VisionTransformer(nn.Module):
self, self,
config: Siglip2VisionConfig, config: Siglip2VisionConfig,
quant_config: QuantizationConfig | None = None, quant_config: QuantizationConfig | None = None,
multimodal_config: MultiModalConfig | None = None,
prefix: str = "", prefix: str = "",
): ):
super().__init__() super().__init__()
...@@ -397,7 +381,6 @@ class Siglip2VisionTransformer(nn.Module): ...@@ -397,7 +381,6 @@ class Siglip2VisionTransformer(nn.Module):
self.encoder = Siglip2Encoder( self.encoder = Siglip2Encoder(
config, config,
quant_config=quant_config, quant_config=quant_config,
multimodal_config=multimodal_config,
prefix=f"{prefix}.encoder", prefix=f"{prefix}.encoder",
) )
num_hidden_layers = config.num_hidden_layers num_hidden_layers = config.num_hidden_layers
...@@ -438,7 +421,6 @@ class Siglip2Model(torch.nn.Module): ...@@ -438,7 +421,6 @@ class Siglip2Model(torch.nn.Module):
self, self,
config: Siglip2VisionConfig, config: Siglip2VisionConfig,
quant_config: QuantizationConfig | None = None, quant_config: QuantizationConfig | None = None,
multimodal_config: MultiModalConfig | None = None,
prefix: str = "", prefix: str = "",
): ):
super().__init__() super().__init__()
...@@ -446,7 +428,6 @@ class Siglip2Model(torch.nn.Module): ...@@ -446,7 +428,6 @@ class Siglip2Model(torch.nn.Module):
self.vision_model = Siglip2VisionTransformer( self.vision_model = Siglip2VisionTransformer(
config, config,
quant_config=quant_config, quant_config=quant_config,
multimodal_config=multimodal_config,
prefix=f"{prefix}.vision_model", prefix=f"{prefix}.vision_model",
) )
......
...@@ -600,7 +600,6 @@ class Lfm2VLForConditionalGeneration( ...@@ -600,7 +600,6 @@ class Lfm2VLForConditionalGeneration(
self.vision_tower = Siglip2Model( self.vision_tower = Siglip2Model(
config=vision_config, config=vision_config,
quant_config=quant_config, quant_config=quant_config,
multimodal_config=multimodal_config,
prefix=maybe_prefix(prefix, "vision_tower"), prefix=maybe_prefix(prefix, "vision_tower"),
) )
else: else:
......
...@@ -166,7 +166,6 @@ class LightOnOCRForConditionalGeneration(Mistral3ForConditionalGeneration): ...@@ -166,7 +166,6 @@ class LightOnOCRForConditionalGeneration(Mistral3ForConditionalGeneration):
self.vision_tower = init_vision_tower_for_llava( self.vision_tower = init_vision_tower_for_llava(
config, config,
quant_config=quant_config, quant_config=quant_config,
multimodal_config=multimodal_config,
require_post_norm=False, require_post_norm=False,
prefix=maybe_prefix(prefix, "vision_tower"), prefix=maybe_prefix(prefix, "vision_tower"),
) )
......
...@@ -19,7 +19,7 @@ from transformers.models.llava import LlavaProcessor ...@@ -19,7 +19,7 @@ from transformers.models.llava import LlavaProcessor
from transformers.models.pixtral import PixtralProcessor from transformers.models.pixtral import PixtralProcessor
from vllm.config import VllmConfig from vllm.config import VllmConfig
from vllm.config.multimodal import BaseDummyOptions, MultiModalConfig from vllm.config.multimodal import BaseDummyOptions
from vllm.model_executor.layers.activation import get_act_fn from vllm.model_executor.layers.activation import get_act_fn
from vllm.model_executor.layers.linear import ColumnParallelLinear, RowParallelLinear from vllm.model_executor.layers.linear import ColumnParallelLinear, RowParallelLinear
from vllm.model_executor.layers.quantization import QuantizationConfig from vllm.model_executor.layers.quantization import QuantizationConfig
...@@ -456,7 +456,6 @@ def _get_num_hidden_layers(hf_config: LlavaLikeConfig) -> int: ...@@ -456,7 +456,6 @@ def _get_num_hidden_layers(hf_config: LlavaLikeConfig) -> int:
def init_vision_tower_for_llava( def init_vision_tower_for_llava(
hf_config: LlavaLikeConfig, hf_config: LlavaLikeConfig,
quant_config: QuantizationConfig | None, quant_config: QuantizationConfig | None,
multimodal_config: MultiModalConfig | None,
*, *,
require_post_norm: bool | None = None, require_post_norm: bool | None = None,
prefix: str = "", prefix: str = "",
...@@ -470,7 +469,6 @@ def init_vision_tower_for_llava( ...@@ -470,7 +469,6 @@ def init_vision_tower_for_llava(
return CLIPVisionModel( return CLIPVisionModel(
vision_config, vision_config,
quant_config=quant_config, quant_config=quant_config,
multimodal_config=multimodal_config,
num_hidden_layers_override=num_hidden_layers, num_hidden_layers_override=num_hidden_layers,
require_post_norm=require_post_norm, require_post_norm=require_post_norm,
prefix=prefix, prefix=prefix,
...@@ -479,7 +477,6 @@ def init_vision_tower_for_llava( ...@@ -479,7 +477,6 @@ def init_vision_tower_for_llava(
return SiglipVisionModel( return SiglipVisionModel(
vision_config, vision_config,
quant_config=quant_config, quant_config=quant_config,
multimodal_config=multimodal_config,
num_hidden_layers_override=num_hidden_layers, num_hidden_layers_override=num_hidden_layers,
require_post_norm=require_post_norm, require_post_norm=require_post_norm,
prefix=prefix, prefix=prefix,
...@@ -488,7 +485,6 @@ def init_vision_tower_for_llava( ...@@ -488,7 +485,6 @@ def init_vision_tower_for_llava(
return PixtralHFVisionModel( return PixtralHFVisionModel(
vision_config, vision_config,
quant_config=quant_config, quant_config=quant_config,
multimodal_config=multimodal_config,
num_hidden_layers_override=num_hidden_layers, num_hidden_layers_override=num_hidden_layers,
require_post_norm=require_post_norm, require_post_norm=require_post_norm,
prefix=prefix, prefix=prefix,
...@@ -562,7 +558,6 @@ class LlavaForConditionalGeneration( ...@@ -562,7 +558,6 @@ class LlavaForConditionalGeneration(
self.vision_tower = init_vision_tower_for_llava( self.vision_tower = init_vision_tower_for_llava(
config, config,
quant_config=quant_config, quant_config=quant_config,
multimodal_config=multimodal_config,
require_post_norm=False, require_post_norm=False,
prefix=maybe_prefix(prefix, "vision_tower"), prefix=maybe_prefix(prefix, "vision_tower"),
) )
......
...@@ -272,7 +272,6 @@ class LlavaNextForConditionalGeneration(nn.Module, SupportsMultiModal, SupportsP ...@@ -272,7 +272,6 @@ class LlavaNextForConditionalGeneration(nn.Module, SupportsMultiModal, SupportsP
self.vision_tower = init_vision_tower_for_llava( self.vision_tower = init_vision_tower_for_llava(
config, config,
quant_config=quant_config, quant_config=quant_config,
multimodal_config=multimodal_config,
require_post_norm=False, require_post_norm=False,
prefix=maybe_prefix(prefix, "vision_tower"), prefix=maybe_prefix(prefix, "vision_tower"),
) )
......
...@@ -332,7 +332,6 @@ class LlavaNextVideoForConditionalGeneration(nn.Module, SupportsMultiModal, Supp ...@@ -332,7 +332,6 @@ class LlavaNextVideoForConditionalGeneration(nn.Module, SupportsMultiModal, Supp
self.vision_tower = init_vision_tower_for_llava( self.vision_tower = init_vision_tower_for_llava(
config, config,
quant_config=quant_config, quant_config=quant_config,
multimodal_config=multimodal_config,
require_post_norm=False, require_post_norm=False,
prefix=maybe_prefix(prefix, "vision_tower"), prefix=maybe_prefix(prefix, "vision_tower"),
) )
......
...@@ -513,7 +513,6 @@ class LlavaOnevisionForConditionalGeneration(nn.Module, SupportsMultiModal, Supp ...@@ -513,7 +513,6 @@ class LlavaOnevisionForConditionalGeneration(nn.Module, SupportsMultiModal, Supp
self.vision_tower = init_vision_tower_for_llava( self.vision_tower = init_vision_tower_for_llava(
config, config,
quant_config=quant_config, quant_config=quant_config,
multimodal_config=multimodal_config,
require_post_norm=False, require_post_norm=False,
prefix=maybe_prefix(prefix, "vision_tower"), prefix=maybe_prefix(prefix, "vision_tower"),
) )
......
Markdown is supported
0% Try again or attach a new file.
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment