Unverified Commit 9ad7f89f authored by Isotr0py's avatar Isotr0py Committed by GitHub
Browse files

[Models]: Make Multimodal config implicit in ViT implementation (#31972)


Signed-off-by: default avatarIsotr0py <mozf@mail2.sysu.edu.cn>
parent 6450b536
......@@ -4,7 +4,6 @@
import torch
from vllm.config import MultiModalConfig
from vllm.logger import init_logger
from vllm.model_executor.custom_op import CustomOp
from vllm.model_executor.models.vision import get_vit_attn_backend
......@@ -32,7 +31,6 @@ class MMEncoderAttention(CustomOp):
scale: float | None = None,
num_kv_heads: int | None = None,
prefix: str = "",
multimodal_config: MultiModalConfig | None = None,
) -> None:
"""
Args:
......@@ -42,7 +40,6 @@ class MMEncoderAttention(CustomOp):
num_kv_heads: number of kv heads.
prefix: This has no effect, it is only here to make it easier to
swap between Attention and MultiHeadAttention
multimodal_config: configs for multi-modal.
"""
super().__init__()
......@@ -62,16 +59,10 @@ class MMEncoderAttention(CustomOp):
# weight and activation dtype.
dtype = torch.get_default_dtype()
# Try to get vision attention backend from multimodal_config.
attn_backend_override = None
if multimodal_config is not None:
attn_backend_override = multimodal_config.mm_encoder_attn_backend
# Get device-specific vision attention backend.
self.attn_backend = get_vit_attn_backend(
head_size=head_size,
dtype=dtype,
attn_backend_override=attn_backend_override,
)
self.is_flash_attn_backend = self.attn_backend in {
......
......@@ -16,7 +16,7 @@ from transformers import (
from vllm.attention.layer import Attention
from vllm.config import VllmConfig
from vllm.config.multimodal import BaseDummyOptions, MultiModalConfig
from vllm.config.multimodal import BaseDummyOptions
from vllm.distributed import divide, get_tensor_model_parallel_world_size
from vllm.model_executor.layers.activation import get_act_fn
from vllm.model_executor.layers.attention.mm_encoder_attention import MMEncoderAttention
......@@ -59,6 +59,7 @@ from .vision import (
VisionFeatureSelectStrategy,
VisionFeatureSelectStrategyStr,
get_num_selected_vision_tokens,
is_vit_use_data_parallel,
resolve_visual_encoder_outputs,
)
......@@ -353,7 +354,6 @@ class CLIPAttention(nn.Module):
self,
config: CLIPTextConfig | CLIPVisionConfig,
quant_config: QuantizationConfig | None = None,
multimodal_config: MultiModalConfig | None = None,
*,
prefix: str = "",
attn_cls: type[Attention] | type[MMEncoderAttention],
......@@ -372,11 +372,7 @@ class CLIPAttention(nn.Module):
)
self.scale = self.head_dim**-0.5
use_data_parallel = (
multimodal_config.mm_encoder_tp_mode == "data"
if multimodal_config
else False
)
use_data_parallel = is_vit_use_data_parallel()
self.qkv_proj = QKVParallelLinear(
hidden_size=self.embed_dim,
head_size=self.head_dim,
......@@ -405,7 +401,6 @@ class CLIPAttention(nn.Module):
self.head_dim,
self.scale,
prefix=f"{prefix}.attn",
multimodal_config=multimodal_config,
)
else:
self.attn = attn_cls(
......@@ -434,17 +429,12 @@ class CLIPMLP(nn.Module):
self,
config: CLIPTextConfig | CLIPVisionConfig,
quant_config: QuantizationConfig | None = None,
multimodal_config: MultiModalConfig | None = None,
prefix: str = "",
) -> None:
super().__init__()
self.config = config
use_data_parallel = (
multimodal_config.mm_encoder_tp_mode == "data"
if multimodal_config
else False
)
use_data_parallel = is_vit_use_data_parallel()
self.activation_fn = get_act_fn(config.hidden_act)
self.fc1 = ColumnParallelLinear(
......@@ -477,7 +467,6 @@ class CLIPEncoderLayer(nn.Module):
self,
config: CLIPTextConfig | CLIPVisionConfig,
quant_config: QuantizationConfig | None = None,
multimodal_config: MultiModalConfig | None = None,
*,
prefix: str = "",
attn_cls: type[Attention] | type[MMEncoderAttention],
......@@ -487,7 +476,6 @@ class CLIPEncoderLayer(nn.Module):
self.self_attn = CLIPAttention(
config,
quant_config=quant_config,
multimodal_config=multimodal_config,
prefix=f"{prefix}.self_attn",
attn_cls=attn_cls,
)
......@@ -495,7 +483,6 @@ class CLIPEncoderLayer(nn.Module):
self.mlp = CLIPMLP(
config,
quant_config=quant_config,
multimodal_config=multimodal_config,
prefix=f"{prefix}.mlp",
)
self.layer_norm2 = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
......@@ -528,7 +515,6 @@ class CLIPEncoder(nn.Module):
self,
config: CLIPTextConfig | CLIPVisionConfig,
quant_config: QuantizationConfig | None = None,
multimodal_config: MultiModalConfig | None = None,
num_hidden_layers_override: int | None = None,
*,
prefix: str = "",
......@@ -548,7 +534,6 @@ class CLIPEncoder(nn.Module):
CLIPEncoderLayer(
config=config,
quant_config=quant_config,
multimodal_config=multimodal_config,
prefix=f"{prefix}.layers.{layer_idx}",
attn_cls=attn_cls,
)
......@@ -658,7 +643,6 @@ class CLIPVisionTransformer(nn.Module):
self,
config: CLIPVisionConfig,
quant_config: QuantizationConfig | None = None,
multimodal_config: MultiModalConfig | None = None,
*,
num_hidden_layers_override: int | None = None,
require_post_norm: bool | None = None,
......@@ -678,7 +662,6 @@ class CLIPVisionTransformer(nn.Module):
self.encoder = CLIPEncoder(
config=config,
quant_config=quant_config,
multimodal_config=multimodal_config,
num_hidden_layers_override=num_hidden_layers_override,
prefix=f"{prefix}.encoder",
attn_cls=MMEncoderAttention,
......@@ -780,7 +763,6 @@ class CLIPVisionModel(nn.Module):
self,
config: CLIPVisionConfig,
quant_config: QuantizationConfig | None = None,
multimodal_config: MultiModalConfig | None = None,
*,
num_hidden_layers_override: int | None = None,
require_post_norm: bool | None = None,
......@@ -791,7 +773,6 @@ class CLIPVisionModel(nn.Module):
self.vision_model = CLIPVisionTransformer(
config=config,
quant_config=quant_config,
multimodal_config=multimodal_config,
num_hidden_layers_override=num_hidden_layers_override,
require_post_norm=require_post_norm,
prefix=f"{prefix}.vision_model",
......@@ -869,7 +850,6 @@ class CLIPEmbeddingModel(nn.Module, SupportsMultiModal, SupportsQuant):
self.vision_model = CLIPVisionTransformer(
vision_config,
quant_config=quant_config,
multimodal_config=multimodal_config,
prefix=maybe_prefix(prefix, "vision_model"),
)
self.visual_projection = nn.Linear(
......
......@@ -18,7 +18,6 @@ import torch.nn as nn
import torch.nn.functional as F
from transformers import CLIPVisionConfig
from vllm.config import MultiModalConfig
from vllm.model_executor.layers.attention.mm_encoder_attention import MMEncoderAttention
from vllm.model_executor.layers.conv import Conv2dLayer
from vllm.model_executor.layers.quantization import QuantizationConfig
......@@ -609,7 +608,6 @@ class DeepCLIPVisionTransformer(nn.Module):
self,
config: CLIPVisionConfig,
quant_config: QuantizationConfig | None = None,
multimodal_config: MultiModalConfig | None = None,
*,
num_hidden_layers_override: int | None = None,
prefix: str = "",
......@@ -628,7 +626,6 @@ class DeepCLIPVisionTransformer(nn.Module):
self.transformer = CLIPEncoder(
config=config,
quant_config=quant_config,
multimodal_config=multimodal_config,
num_hidden_layers_override=num_hidden_layers_override,
prefix=f"{prefix}.encoder",
attn_cls=MMEncoderAttention,
......
......@@ -398,7 +398,6 @@ class DeepseekOCRForCausalLM(nn.Module, SupportsMultiModal, SupportsPP, Supports
self.vision_model = DeepCLIPVisionTransformer(
config=clip_vision_config,
quant_config=quant_config,
multimodal_config=multimodal_config,
prefix=maybe_prefix(prefix, "vision_model"),
)
......
......@@ -8,7 +8,7 @@ import torch.nn as nn
from torch.nn import LayerNorm
from transformers.models.qwen2_vl import Qwen2VLProcessor
from vllm.config import MultiModalConfig, VllmConfig
from vllm.config import VllmConfig
from vllm.config.multimodal import BaseDummyOptions
from vllm.distributed import utils as dist_utils
from vllm.distributed.parallel_state import (
......@@ -60,7 +60,7 @@ from vllm.transformers_utils.configs.dotsocr import DotsOCRConfig, DotsVisionCon
from vllm.utils.tensor_schema import TensorSchema, TensorShape
from vllm.v1.attention.backends.registry import AttentionBackendEnum
from .vision import run_dp_sharded_mrope_vision_model
from .vision import is_vit_use_data_parallel, run_dp_sharded_mrope_vision_model
IMAGE_TOKEN = "<|imgpad|>"
......@@ -183,9 +183,9 @@ class PatchMerger(nn.Module):
spatial_merge_size: int = 2,
pre_norm="layernorm",
prefix: str = "",
use_data_parallel: bool = False,
) -> None:
super().__init__()
use_data_parallel = is_vit_use_data_parallel()
self.hidden_size = context_dim * (spatial_merge_size**2)
self.pre_norm = pre_norm
if self.pre_norm == "layernorm":
......@@ -230,15 +230,10 @@ class DotsVisionAttention(nn.Module):
bias: bool = True,
*,
quant_config: QuantizationConfig | None = None,
multimodal_config: MultiModalConfig | None = None,
prefix: str = "",
) -> None:
super().__init__()
use_data_parallel = (
multimodal_config.mm_encoder_tp_mode == "data"
if multimodal_config
else False
)
use_data_parallel = is_vit_use_data_parallel()
self.embed_dim = dim
self.tp_size = (
......@@ -272,7 +267,6 @@ class DotsVisionAttention(nn.Module):
num_heads=self.num_attention_heads_per_partition,
head_size=self.hidden_size_per_attention_head,
scale=self.hidden_size_per_attention_head**-0.5,
multimodal_config=multimodal_config,
prefix=f"{prefix}.attn",
)
......@@ -329,7 +323,6 @@ class DotsSwiGLUFFN(nn.Module):
config,
*,
quant_config: QuantizationConfig | None = None,
multimodal_config: MultiModalConfig | None = None,
prefix: str = "",
):
super().__init__()
......@@ -337,11 +330,7 @@ class DotsSwiGLUFFN(nn.Module):
in_features = config.embed_dim
bias = config.use_bias
use_data_parallel = (
multimodal_config.mm_encoder_tp_mode == "data"
if multimodal_config
else False
)
use_data_parallel = is_vit_use_data_parallel()
# Referenced aimv2.py AIMv2SwiGLUFFN
self.fc13 = MergedColumnParallelLinear(
in_features,
......@@ -447,7 +436,6 @@ class DotsVisionBlock(nn.Module):
config,
*,
quant_config: QuantizationConfig | None = None,
multimodal_config: MultiModalConfig | None = None,
prefix: str = "",
):
super().__init__()
......@@ -458,14 +446,12 @@ class DotsVisionBlock(nn.Module):
num_heads=config.num_attention_heads,
bias=config.use_bias,
quant_config=quant_config,
multimodal_config=multimodal_config,
prefix=f"{prefix}.attn",
)
self.norm1 = RMSNorm(config.embed_dim, eps=config.rms_norm_eps)
self.mlp = DotsSwiGLUFFN(
config,
quant_config=quant_config,
multimodal_config=multimodal_config,
prefix=f"{prefix}.mlp",
)
self.norm2 = RMSNorm(config.embed_dim, eps=config.rms_norm_eps)
......@@ -493,7 +479,6 @@ class DotsVisionTransformer(nn.Module):
self,
config: DotsVisionConfig,
quant_config: QuantizationConfig | None = None,
multimodal_config: MultiModalConfig | None = None,
*,
num_hidden_layers_override: int | None = None,
require_post_norm: bool | None = None,
......@@ -507,15 +492,9 @@ class DotsVisionTransformer(nn.Module):
head_dim = config.embed_dim // config.num_attention_heads
self.rotary_pos_emb = VisionRotaryEmbedding(head_dim // 2)
attn_backend_override = (
multimodal_config.mm_encoder_attn_backend
if multimodal_config is not None
else None
)
self.attn_backend = get_vit_attn_backend(
head_size=head_dim,
dtype=torch.get_default_dtype(),
attn_backend_override=attn_backend_override,
)
self.out_hidden_size = config.hidden_size
# Keep blocks for compatibility with other vision towers
......@@ -529,7 +508,6 @@ class DotsVisionTransformer(nn.Module):
DotsVisionBlock(
config,
quant_config=quant_config,
multimodal_config=multimodal_config,
prefix=f"{prefix}.blocks.{i}",
)
for i in range(num_layers)
......@@ -542,16 +520,10 @@ class DotsVisionTransformer(nn.Module):
else:
self.post_trunk_norm = None
use_data_parallel = (
multimodal_config.mm_encoder_tp_mode == "data"
if multimodal_config
else False
)
self.merger = PatchMerger(
dim=config.hidden_size,
context_dim=config.embed_dim,
spatial_merge_size=config.spatial_merge_size,
use_data_parallel=use_data_parallel,
)
@property
......@@ -693,7 +665,6 @@ class DotsOCRForCausalLM(nn.Module, SupportsMultiModal, SupportsPP, SupportsLoRA
self.vision_tower = DotsVisionTransformer(
vision_config,
quant_config=self.quant_config,
multimodal_config=multimodal_config,
prefix=maybe_prefix(prefix, "vision_tower"),
)
......
......@@ -270,7 +270,6 @@ class Eagle2_5_VLForConditionalGeneration(
return SiglipVisionModel(
vision_config,
quant_config=quant_config,
multimodal_config=self.multimodal_config,
num_hidden_layers_override=num_hidden_layers,
prefix=prefix,
)
......
......@@ -36,7 +36,7 @@ import torch.nn.functional as F
from einops import rearrange
from transformers import BatchFeature
from vllm.config import MultiModalConfig, VllmConfig
from vllm.config import VllmConfig
from vllm.config.multimodal import BaseDummyOptions, VideoDummyOptions
from vllm.distributed import parallel_state
from vllm.distributed import utils as dist_utils
......@@ -119,7 +119,6 @@ class Ernie4_5_VisionAttention(nn.Module):
num_heads: int,
projection_size: int,
quant_config: QuantizationConfig | None = None,
multimodal_config: MultiModalConfig | None = None,
prefix: str = "",
) -> None:
super().__init__()
......@@ -153,7 +152,6 @@ class Ernie4_5_VisionAttention(nn.Module):
num_heads=self.num_attention_heads_per_partition,
head_size=self.hidden_size_per_attention_head,
scale=self.hidden_size_per_attention_head**-0.5,
multimodal_config=multimodal_config,
prefix=f"{prefix}.attn",
)
......@@ -266,7 +264,6 @@ class Ernie4_5_VisionBlock(nn.Module):
act_layer: type[nn.Module] = QuickGELU,
norm_layer: Callable[[int], nn.Module] | None = None,
quant_config: QuantizationConfig | None = None,
multimodal_config: MultiModalConfig | None = None,
prefix: str = "",
) -> None:
super().__init__()
......@@ -282,7 +279,6 @@ class Ernie4_5_VisionBlock(nn.Module):
num_heads=num_heads,
projection_size=dim,
quant_config=quant_config,
multimodal_config=multimodal_config,
prefix=f"{prefix}.attn",
)
......@@ -357,7 +353,6 @@ class Ernie4_5_VisionTransformer(nn.Module):
vision_config,
norm_eps: float = 1e-6,
quant_config: QuantizationConfig | None = None,
multimodal_config: MultiModalConfig | None = None,
prefix: str = "",
) -> None:
super().__init__()
......@@ -393,7 +388,6 @@ class Ernie4_5_VisionTransformer(nn.Module):
mlp_ratio=mlp_ratio,
norm_layer=norm_layer,
quant_config=quant_config,
multimodal_config=multimodal_config,
prefix=f"{prefix}.blocks.{layer_idx}",
)
for layer_idx in range(depth)
......@@ -405,13 +399,9 @@ class Ernie4_5_VisionTransformer(nn.Module):
)
self.ln = nn.LayerNorm(hidden_size, eps=1e-6)
attn_backend_override = (
multimodal_config.mm_encoder_attn_backend if multimodal_config else None
)
self.attn_backend = get_vit_attn_backend(
head_size=head_dim,
dtype=torch.get_default_dtype(),
attn_backend_override=attn_backend_override,
)
@property
......@@ -1308,7 +1298,6 @@ class Ernie4_5_VLMoeForConditionalGeneration(
config.vision_config,
norm_eps=getattr(config, "rms_norm_eps", 1e-6),
quant_config=quant_config,
multimodal_config=multimodal_config,
prefix=maybe_prefix(prefix, "vision_model"),
)
self.resampler_model = VariableResolutionResamplerModel(
......
......@@ -46,7 +46,7 @@ from transformers.models.glm4v.image_processing_glm4v import (
from transformers.models.glm4v.video_processing_glm4v import Glm4vVideoProcessor
from transformers.video_utils import VideoMetadata
from vllm.config import MultiModalConfig, VllmConfig
from vllm.config import VllmConfig
from vllm.config.multimodal import BaseDummyOptions, VideoDummyOptions
from vllm.distributed import get_tensor_model_parallel_world_size, parallel_state
from vllm.distributed import utils as dist_utils
......@@ -107,6 +107,7 @@ from .utils import (
)
from .vision import (
get_vit_attn_backend,
is_vit_use_data_parallel,
run_dp_sharded_mrope_vision_model,
)
......@@ -196,15 +197,10 @@ class Glm4vVisionMLP(nn.Module):
hidden_features: int,
bias: bool = False,
quant_config: QuantizationConfig | None = None,
multimodal_config: MultiModalConfig | None = None,
prefix: str = "",
):
super().__init__()
use_data_parallel = (
multimodal_config.mm_encoder_tp_mode == "data"
if multimodal_config
else False
)
use_data_parallel = is_vit_use_data_parallel()
self.gate_up_proj = MergedColumnParallelLinear(
input_size=in_features,
output_sizes=[hidden_features] * 2,
......@@ -258,16 +254,11 @@ class Glm4vVisionAttention(nn.Module):
num_heads: int,
projection_size: int,
quant_config: QuantizationConfig | None = None,
multimodal_config: MultiModalConfig | None = None,
prefix: str = "",
) -> None:
super().__init__()
# Per attention head and per partition values.
use_data_parallel = (
multimodal_config.mm_encoder_tp_mode == "data"
if multimodal_config
else False
)
use_data_parallel = is_vit_use_data_parallel()
self.tp_size = (
1 if use_data_parallel else get_tensor_model_parallel_world_size()
)
......@@ -305,7 +296,6 @@ class Glm4vVisionAttention(nn.Module):
num_heads=self.num_attention_heads_per_partition,
head_size=self.hidden_size_per_attention_head,
scale=self.hidden_size_per_attention_head**-0.5,
multimodal_config=multimodal_config,
)
self.apply_rotary_emb = ApplyRotaryEmb(enforce_enable=True)
......@@ -373,7 +363,6 @@ class Glm4vVisionBlock(nn.Module):
mlp_hidden_dim: int,
norm_layer: Callable[[int], nn.Module] | None = None,
quant_config: QuantizationConfig | None = None,
multimodal_config: MultiModalConfig | None = None,
prefix: str = "",
) -> None:
super().__init__()
......@@ -386,7 +375,6 @@ class Glm4vVisionBlock(nn.Module):
num_heads=num_heads,
projection_size=dim,
quant_config=quant_config,
multimodal_config=multimodal_config,
prefix=f"{prefix}.attn",
)
self.mlp = Glm4vVisionMLP(
......@@ -394,7 +382,6 @@ class Glm4vVisionBlock(nn.Module):
mlp_hidden_dim,
bias=False,
quant_config=quant_config,
multimodal_config=multimodal_config,
prefix=f"{prefix}.mlp",
)
......@@ -454,16 +441,11 @@ class Glm4vPatchMerger(nn.Module):
d_model: int,
context_dim: int,
quant_config: QuantizationConfig | None = None,
multimodal_config: MultiModalConfig | None = None,
bias: bool = False,
prefix: str = "",
) -> None:
super().__init__()
use_data_parallel = (
multimodal_config.mm_encoder_tp_mode == "data"
if multimodal_config
else False
)
use_data_parallel = is_vit_use_data_parallel()
self.hidden_size = d_model
self.proj = ColumnParallelLinear(
self.hidden_size,
......@@ -619,13 +601,10 @@ class Glm4vVisionTransformer(nn.Module):
vision_config: Glm4vVisionConfig,
norm_eps: float = 1e-6,
quant_config: QuantizationConfig | None = None,
multimodal_config: MultiModalConfig | None = None,
prefix: str = "",
) -> None:
super().__init__()
assert multimodal_config is not None, "multimodal_config must be provided"
patch_size = vision_config.patch_size
temporal_patch_size = vision_config.temporal_patch_size
in_channels = vision_config.in_channels
......@@ -660,7 +639,6 @@ class Glm4vVisionTransformer(nn.Module):
mlp_hidden_dim=vision_config.out_hidden_size,
norm_layer=norm_layer,
quant_config=quant_config,
multimodal_config=multimodal_config,
prefix=f"{prefix}.blocks.{layer_idx}",
)
for layer_idx in range(depth)
......@@ -670,7 +648,6 @@ class Glm4vVisionTransformer(nn.Module):
d_model=vision_config.out_hidden_size,
context_dim=vision_config.intermediate_size,
quant_config=quant_config,
multimodal_config=multimodal_config,
bias=False,
prefix=f"{prefix}.merger",
)
......@@ -692,7 +669,6 @@ class Glm4vVisionTransformer(nn.Module):
self.attn_backend = get_vit_attn_backend(
head_size=head_dim,
dtype=torch.get_default_dtype(),
attn_backend_override=multimodal_config.mm_encoder_attn_backend,
)
@property
......@@ -1439,7 +1415,6 @@ class Glm4vForConditionalGeneration(
config.vision_config,
norm_eps=getattr(config, "rms_norm_eps", 1e-5),
quant_config=quant_config,
multimodal_config=multimodal_config,
prefix=maybe_prefix(prefix, "visual"),
)
......
......@@ -33,7 +33,7 @@ import torch.nn as nn
import torch.nn.functional as F
from transformers import BatchFeature
from vllm.config import MultiModalConfig, VllmConfig
from vllm.config import VllmConfig
from vllm.config.multimodal import BaseDummyOptions
from vllm.distributed import parallel_state
from vllm.distributed import utils as dist_utils
......@@ -80,7 +80,6 @@ from vllm.transformers_utils.configs.hunyuan_vl import (
from vllm.transformers_utils.processors.hunyuan_vl import HunYuanVLProcessor
from vllm.transformers_utils.processors.hunyuan_vl_image import smart_resize
from vllm.utils.tensor_schema import TensorSchema, TensorShape
from vllm.v1.attention.backends.registry import AttentionBackendEnum
from .interfaces import (
MultiModalEmbeddings,
......@@ -96,6 +95,7 @@ from .utils import (
init_vllm_registered_model,
maybe_prefix,
)
from .vision import is_vit_use_data_parallel
logger = init_logger(__name__)
......@@ -160,9 +160,9 @@ class HunYuanVisionMLP(nn.Module):
act_fn: Callable[[torch.Tensor], torch.Tensor] = F.gelu,
quant_config: QuantizationConfig | None = None,
prefix: str = "",
use_data_parallel: bool = False,
):
super().__init__()
use_data_parallel = is_vit_use_data_parallel()
self.dense_h_to_4h = ColumnParallelLinear(
in_features,
hidden_features,
......@@ -194,12 +194,11 @@ class HunYuanVisionAttention(nn.Module):
num_heads: int,
projection_size: int,
quant_config: QuantizationConfig | None = None,
multimodal_config: MultiModalConfig | None = None,
prefix: str = "",
use_data_parallel: bool = False,
) -> None:
super().__init__()
# Per attention head and per partition values.
use_data_parallel = is_vit_use_data_parallel()
self.tp_size = (
1
if use_data_parallel
......@@ -237,7 +236,6 @@ class HunYuanVisionAttention(nn.Module):
self.hidden_size_per_attention_head,
self.scale,
prefix=f"{prefix}.attn",
multimodal_config=multimodal_config,
)
def forward(
......@@ -260,9 +258,7 @@ class HunYuanVisionBlock(nn.Module):
act_fn: Callable[[torch.Tensor], torch.Tensor] = F.gelu,
norm_layer: Callable[[int], nn.Module] | None = None,
quant_config: QuantizationConfig | None = None,
multimodal_config: MultiModalConfig | None = None,
prefix: str = "",
use_data_parallel: bool = False,
) -> None:
super().__init__()
if norm_layer is None:
......@@ -274,9 +270,7 @@ class HunYuanVisionBlock(nn.Module):
num_heads=num_heads,
projection_size=dim,
quant_config=quant_config,
multimodal_config=multimodal_config,
prefix=f"{prefix}.self_attn",
use_data_parallel=use_data_parallel,
)
self.mlp = HunYuanVisionMLP(
dim,
......@@ -285,7 +279,6 @@ class HunYuanVisionBlock(nn.Module):
bias=True,
quant_config=quant_config,
prefix=f"{prefix}.mlp",
use_data_parallel=use_data_parallel,
)
def forward(
......@@ -439,9 +432,6 @@ class HunYuanVisionTransformer(nn.Module):
vision_config: HunYuanVLVisionConfig,
quant_config: QuantizationConfig | None = None,
prefix: str = "",
use_data_parallel: bool = False,
multimodal_config: MultiModalConfig | None = None,
attn_backend_override: AttentionBackendEnum | None = None,
) -> None:
super().__init__()
......@@ -467,9 +457,7 @@ class HunYuanVisionTransformer(nn.Module):
act_fn=get_act_fn(vision_config.hidden_act),
norm_layer=norm_layer,
quant_config=quant_config,
multimodal_config=multimodal_config,
prefix=f"{prefix}.layers.{layer_idx}",
use_data_parallel=use_data_parallel,
)
for layer_idx in range(num_hidden_layers)
]
......@@ -872,23 +860,14 @@ class HunYuanVLForConditionalGeneration(
def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
super().__init__()
config: HunYuanVLConfig = vllm_config.model_config.hf_config
multimodal_config = vllm_config.model_config.multimodal_config
self.config = config
self.multimodal_config = multimodal_config
with self._mark_tower_model(vllm_config, {"image"}):
attn_backend_override = (
multimodal_config.mm_encoder_attn_backend
if multimodal_config is not None
else None
)
self.visual = HunYuanVisionTransformer(
config.vision_config,
quant_config=self.quant_config,
quant_config=vllm_config.quant_config,
prefix=maybe_prefix(prefix, "visual"),
multimodal_config=multimodal_config,
attn_backend_override=attn_backend_override,
)
with self._mark_language_model(vllm_config):
......
......@@ -17,7 +17,7 @@ from timm.models.regnet import RegStage
from transformers import BatchFeature, CLIPVisionConfig, SiglipVisionConfig
from vllm.config import VllmConfig
from vllm.config.multimodal import BaseDummyOptions, MultiModalConfig
from vllm.config.multimodal import BaseDummyOptions
from vllm.model_executor.layers.quantization import QuantizationConfig
from vllm.multimodal import MULTIMODAL_REGISTRY
from vllm.multimodal.cache import BaseMultiModalProcessorCache
......@@ -360,7 +360,6 @@ def _build_hcxvision_hf_processor(
def init_vision_tower_for_hcxvision(
vision_config,
quant_config: QuantizationConfig | None,
multimodal_config: MultiModalConfig | None,
*,
use_nth_layer: int | None = None,
require_post_norm: bool | None = None,
......@@ -378,7 +377,6 @@ def init_vision_tower_for_hcxvision(
return CLIPVisionModel(
vision_config,
quant_config=quant_config,
multimodal_config=multimodal_config,
num_hidden_layers_override=num_hidden_layers,
require_post_norm=require_post_norm,
prefix=prefix,
......@@ -387,7 +385,6 @@ def init_vision_tower_for_hcxvision(
return SiglipVisionModel(
vision_config,
quant_config=quant_config,
multimodal_config=multimodal_config,
num_hidden_layers_override=num_hidden_layers,
require_post_norm=require_post_norm,
prefix=prefix,
......@@ -605,7 +602,6 @@ class HCXVisionForCausalLM(nn.Module, SupportsMultiModal, SupportsPP):
# init configs
config = vllm_config.model_config.hf_config
quant_config = vllm_config.quant_config
multimodal_config = vllm_config.model_config.multimodal_config
# text_config
text_config = config.text_config
if text_config.model_type in ["gpt2", "hyperclovax", "llama"]:
......@@ -628,7 +624,6 @@ class HCXVisionForCausalLM(nn.Module, SupportsMultiModal, SupportsPP):
self.vision_model = init_vision_tower_for_hcxvision(
vision_config,
quant_config=quant_config,
multimodal_config=multimodal_config,
use_nth_layer=getattr(config, "use_nth_layer", -1),
require_post_norm=False,
prefix=maybe_prefix(prefix, "vision_model"),
......
......@@ -16,7 +16,7 @@ from transformers.image_processing_utils import BatchFeature
from transformers.tokenization_utils import TensorType
from typing_extensions import TypedDict, Unpack
from vllm.config import MultiModalConfig, VllmConfig
from vllm.config import VllmConfig
from vllm.config.model import ModelConfig
from vllm.distributed import parallel_state
from vllm.distributed import utils as dist_utils
......@@ -72,6 +72,8 @@ from vllm.transformers_utils.configs import (
)
from vllm.utils.tensor_schema import TensorSchema, TensorShape
from .vision import is_vit_use_data_parallel
def create_cumulative_seq_lengths(
seq_sizes: torch.Tensor, device: torch.device
......@@ -942,15 +944,10 @@ class Siglip2VisionAttention(nn.Module):
quant_config: QuantizationConfig | None = None,
*,
prefix: str = "",
multimodal_config: MultiModalConfig | None = None,
) -> None:
super().__init__()
use_data_parallel = (
multimodal_config.mm_encoder_tp_mode == "data"
if multimodal_config
else False
)
use_data_parallel = is_vit_use_data_parallel()
self.tp_size = (
1
if use_data_parallel
......@@ -987,7 +984,6 @@ class Siglip2VisionAttention(nn.Module):
head_size=self.hidden_size_per_attention_head,
scale=self.hidden_size_per_attention_head**-0.5,
prefix=f"{prefix}.attn",
multimodal_config=multimodal_config,
)
def split_qkv(self, qkv: torch.Tensor) -> tuple[torch.Tensor, ...]:
......@@ -1038,7 +1034,6 @@ class Siglip2EncoderLayer(nn.Module):
quant_config: QuantizationConfig | None = None,
*,
prefix: str = "",
multimodal_config: MultiModalConfig | None = None,
) -> None:
super().__init__()
self.embed_dim = config.hidden_size
......@@ -1047,7 +1042,6 @@ class Siglip2EncoderLayer(nn.Module):
config,
quant_config=quant_config,
prefix=f"{prefix}.self_attn",
multimodal_config=multimodal_config,
)
self.layer_norm2 = nn.LayerNorm(self.embed_dim, eps=config.layer_norm_eps)
self.mlp = SiglipMLP(
......@@ -1088,7 +1082,6 @@ class Siglip2Encoder(nn.Module):
quant_config: QuantizationConfig | None = None,
*,
prefix: str = "",
multimodal_config: MultiModalConfig | None = None,
) -> None:
super().__init__()
self.config = config
......@@ -1098,7 +1091,6 @@ class Siglip2Encoder(nn.Module):
config,
quant_config=quant_config,
prefix=f"{prefix}.layers.{layer_idx}",
multimodal_config=multimodal_config,
)
for layer_idx in range(config.num_hidden_layers)
]
......@@ -1127,7 +1119,6 @@ class Siglip2VisionTransformer(nn.Module):
config: PixelShuffleSiglip2VisionConfig,
quant_config: QuantizationConfig | None = None,
prefix: str = "",
multimodal_config: MultiModalConfig | None = None,
):
super().__init__()
self.config = config
......@@ -1140,7 +1131,6 @@ class Siglip2VisionTransformer(nn.Module):
config,
quant_config=quant_config,
prefix=f"{prefix}.encoder",
multimodal_config=multimodal_config,
)
self.post_layernorm = nn.LayerNorm(embed_dim, eps=config.layer_norm_eps)
......@@ -1221,14 +1211,12 @@ class IsaacVisionEmbedding(nn.Module):
hidden_dim: int,
output_dim: int,
quant_config: QuantizationConfig | None = None,
multimodal_config: MultiModalConfig | None = None,
prefix: str = "",
):
super().__init__()
self.transformer = Siglip2VisionTransformer(
vision_cfg,
quant_config=quant_config,
multimodal_config=multimodal_config,
prefix=maybe_prefix(prefix, "0"),
)
self.linear_fc1 = ColumnParallelLinear(
......@@ -1309,7 +1297,6 @@ class IsaacForConditionalGeneration(
config: IsaacConfig = vllm_config.model_config.hf_config
quant_config = vllm_config.quant_config
self.config = config
self.multimodal_config = vllm_config.model_config.multimodal_config
head_dim = config.head_dim
calculated_mrope_section = [
......@@ -1373,7 +1360,6 @@ class IsaacForConditionalGeneration(
hidden_dim=hidden_dim,
output_dim=config.hidden_size,
quant_config=quant_config,
multimodal_config=self.multimodal_config,
prefix=maybe_prefix(prefix, "vision_embedding"),
)
......
......@@ -16,7 +16,7 @@ from transformers.feature_extraction_utils import BatchFeature
from transformers.modeling_outputs import BaseModelOutput, BaseModelOutputWithPooling
from transformers.utils import torch_int
from vllm.config import MultiModalConfig, VllmConfig
from vllm.config import VllmConfig
from vllm.config.multimodal import BaseDummyOptions
from vllm.distributed import get_tensor_model_parallel_world_size
from vllm.logger import init_logger
......@@ -80,6 +80,7 @@ from .utils import (
is_pp_missing_parameter,
maybe_prefix,
)
from .vision import is_vit_use_data_parallel
logger = init_logger(__name__)
......@@ -358,7 +359,6 @@ class KeyeSiglipAttention(nn.Module):
self,
config: PretrainedConfig,
quant_config: QuantizationConfig | None = None,
multimodal_config: MultiModalConfig | None = None,
prefix: str = "",
):
super().__init__()
......@@ -366,7 +366,8 @@ class KeyeSiglipAttention(nn.Module):
hidden_size = config.hidden_size
self.hidden_size = config.hidden_size
tp_size = get_tensor_model_parallel_world_size()
use_data_parallel = is_vit_use_data_parallel()
tp_size = 1 if use_data_parallel else get_tensor_model_parallel_world_size()
self.total_num_heads = config.num_attention_heads
assert self.total_num_heads % tp_size == 0
self.num_heads = self.total_num_heads // tp_size
......@@ -403,7 +404,6 @@ class KeyeSiglipAttention(nn.Module):
scale=self.scale,
num_kv_heads=self.num_kv_heads,
prefix=f"{prefix}.attn",
multimodal_config=multimodal_config,
)
self.apply_rotary_emb = ApplyRotaryEmb(
......@@ -497,7 +497,6 @@ class KeyeSiglipEncoderLayer(nn.Module):
self,
config: PretrainedConfig,
quant_config: QuantizationConfig | None = None,
multimodal_config: MultiModalConfig | None = None,
prefix: str = "",
):
super().__init__()
......@@ -506,14 +505,12 @@ class KeyeSiglipEncoderLayer(nn.Module):
self.self_attn = KeyeSiglipAttention(
config,
quant_config=quant_config,
multimodal_config=multimodal_config,
prefix=f"{prefix}.self_attn",
)
self.layer_norm2 = nn.LayerNorm(self.embed_dim, eps=config.layer_norm_eps)
self.mlp = SiglipMLP(
config,
quant_config=quant_config,
multimodal_config=multimodal_config,
prefix=f"{prefix}.mlp",
)
......@@ -552,7 +549,6 @@ class KeyeSiglipEncoder(nn.Module):
self,
config: PretrainedConfig,
quant_config: QuantizationConfig | None = None,
multimodal_config: MultiModalConfig | None = None,
prefix: str = "",
):
super().__init__()
......@@ -565,7 +561,6 @@ class KeyeSiglipEncoder(nn.Module):
KeyeSiglipEncoderLayer(
config,
quant_config=quant_config,
multimodal_config=multimodal_config,
prefix=f"{prefix}.layers.{layer_idx}",
)
for layer_idx in range(config.num_hidden_layers)
......@@ -647,7 +642,6 @@ class KeyeSiglipVisionTransformer(nn.Module):
self,
config: PretrainedConfig,
quant_config: QuantizationConfig | None = None,
multimodal_config: MultiModalConfig | None = None,
prefix: str = "",
):
super().__init__()
......@@ -658,7 +652,6 @@ class KeyeSiglipVisionTransformer(nn.Module):
self.encoder = KeyeSiglipEncoder(
config,
quant_config=quant_config,
multimodal_config=multimodal_config,
prefix=f"{prefix}.encoder",
)
self.post_layernorm = nn.LayerNorm(embed_dim, eps=config.layer_norm_eps)
......@@ -730,7 +723,6 @@ class KeyeSiglipVisionModel(nn.Module):
self,
config: PretrainedConfig,
quant_config: QuantizationConfig | None = None,
multimodal_config: MultiModalConfig | None = None,
prefix: str = "",
):
super().__init__()
......@@ -738,7 +730,6 @@ class KeyeSiglipVisionModel(nn.Module):
self.vision_model = KeyeSiglipVisionTransformer(
config,
quant_config=quant_config,
multimodal_config=multimodal_config,
prefix=f"{prefix}.vision_model",
)
self.quant_config = quant_config
......@@ -1275,16 +1266,13 @@ class BaseKeyeModule(nn.Module, SupportsMultiModal):
super().__init__()
config: PretrainedConfig = vllm_config.model_config.hf_config
quant_config = vllm_config.quant_config
multimodal_config = vllm_config.model_config.multimodal_config
self.config = config
self.multimodal_config = multimodal_config
with self._mark_tower_model(vllm_config, {"image", "video"}):
self.visual = KeyeSiglipVisionModel(
config.vision_config,
quant_config=quant_config,
multimodal_config=multimodal_config,
prefix=maybe_prefix(prefix, "visual"),
)
self.mlp_AR = self._build_projector(
......
......@@ -317,7 +317,6 @@ class KimiVLForConditionalGeneration(nn.Module, SupportsMultiModal, SupportsPP):
with self._mark_tower_model(vllm_config, "image"):
self.vision_tower = MoonVitPretrainedModel(
config.vision_config,
multimodal_config=model_config.multimodal_config,
prefix=maybe_prefix(prefix, "vision_tower"),
)
self.multi_modal_projector = KimiVLMultiModalProjector(
......
......@@ -11,7 +11,6 @@ from torch.nn import functional as F
from transformers import Siglip2VisionConfig
from vllm.compilation.decorators import support_torch_compile
from vllm.config import MultiModalConfig
from vllm.distributed import get_tensor_model_parallel_world_size
from vllm.model_executor.layers.activation import get_act_fn
from vllm.model_executor.layers.attention.mm_encoder_attention import MMEncoderAttention
......@@ -23,7 +22,7 @@ from vllm.model_executor.layers.linear import (
from vllm.model_executor.layers.quantization import QuantizationConfig
from vllm.model_executor.model_loader.weight_utils import default_weight_loader
from .vision import should_torch_compile_mm_vit
from .vision import is_vit_use_data_parallel, should_torch_compile_mm_vit
class Siglip2VisionEmbeddings(nn.Module):
......@@ -154,7 +153,6 @@ class Siglip2Attention(nn.Module):
self,
config: Siglip2VisionConfig,
quant_config: QuantizationConfig | None = None,
multimodal_config: MultiModalConfig | None = None,
prefix: str = "",
):
super().__init__()
......@@ -171,10 +169,7 @@ class Siglip2Attention(nn.Module):
self.scale = self.head_dim**-0.5
self.dropout = config.attention_dropout
use_data_parallel = (
multimodal_config is not None
and multimodal_config.mm_encoder_tp_mode == "data"
)
use_data_parallel = is_vit_use_data_parallel()
tp_size = 1 if use_data_parallel else get_tensor_model_parallel_world_size()
assert self.num_heads % tp_size == 0
self.num_heads_per_partition = self.num_heads // tp_size
......@@ -199,7 +194,6 @@ class Siglip2Attention(nn.Module):
head_size=self.head_dim,
scale=self.scale,
prefix=f"{prefix}.attn",
multimodal_config=multimodal_config,
)
def forward(
......@@ -241,16 +235,12 @@ class Siglip2MLP(nn.Module):
self,
config: Siglip2VisionConfig,
quant_config: QuantizationConfig | None = None,
multimodal_config: MultiModalConfig | None = None,
prefix: str = "",
):
super().__init__()
self.config = config
self.activation_fn = get_act_fn(config.hidden_act)
use_data_parallel = (
multimodal_config is not None
and multimodal_config.mm_encoder_tp_mode == "data"
)
use_data_parallel = is_vit_use_data_parallel()
self.fc1 = ColumnParallelLinear(
config.hidden_size,
config.intermediate_size,
......@@ -282,7 +272,6 @@ class Siglip2EncoderLayer(nn.Module):
self,
config: Siglip2VisionConfig,
quant_config: QuantizationConfig | None = None,
multimodal_config: MultiModalConfig | None = None,
prefix: str = "",
):
super().__init__()
......@@ -291,14 +280,12 @@ class Siglip2EncoderLayer(nn.Module):
self.self_attn = Siglip2Attention(
config,
quant_config=quant_config,
multimodal_config=multimodal_config,
prefix=f"{prefix}.self_attn",
)
self.layer_norm2 = nn.LayerNorm(self.embed_dim, eps=config.layer_norm_eps)
self.mlp = Siglip2MLP(
config,
quant_config=quant_config,
multimodal_config=multimodal_config,
prefix=f"{prefix}.mlp",
)
......@@ -344,7 +331,6 @@ class Siglip2Encoder(nn.Module):
self,
config: Siglip2VisionConfig,
quant_config: QuantizationConfig | None = None,
multimodal_config: MultiModalConfig | None = None,
prefix: str = "",
):
super().__init__()
......@@ -354,7 +340,6 @@ class Siglip2Encoder(nn.Module):
Siglip2EncoderLayer(
config=config,
quant_config=quant_config,
multimodal_config=multimodal_config,
prefix=f"{prefix}.layers.{idx}",
)
for idx in range(config.num_hidden_layers)
......@@ -383,7 +368,6 @@ class Siglip2VisionTransformer(nn.Module):
self,
config: Siglip2VisionConfig,
quant_config: QuantizationConfig | None = None,
multimodal_config: MultiModalConfig | None = None,
prefix: str = "",
):
super().__init__()
......@@ -397,7 +381,6 @@ class Siglip2VisionTransformer(nn.Module):
self.encoder = Siglip2Encoder(
config,
quant_config=quant_config,
multimodal_config=multimodal_config,
prefix=f"{prefix}.encoder",
)
num_hidden_layers = config.num_hidden_layers
......@@ -438,7 +421,6 @@ class Siglip2Model(torch.nn.Module):
self,
config: Siglip2VisionConfig,
quant_config: QuantizationConfig | None = None,
multimodal_config: MultiModalConfig | None = None,
prefix: str = "",
):
super().__init__()
......@@ -446,7 +428,6 @@ class Siglip2Model(torch.nn.Module):
self.vision_model = Siglip2VisionTransformer(
config,
quant_config=quant_config,
multimodal_config=multimodal_config,
prefix=f"{prefix}.vision_model",
)
......
......@@ -600,7 +600,6 @@ class Lfm2VLForConditionalGeneration(
self.vision_tower = Siglip2Model(
config=vision_config,
quant_config=quant_config,
multimodal_config=multimodal_config,
prefix=maybe_prefix(prefix, "vision_tower"),
)
else:
......
......@@ -166,7 +166,6 @@ class LightOnOCRForConditionalGeneration(Mistral3ForConditionalGeneration):
self.vision_tower = init_vision_tower_for_llava(
config,
quant_config=quant_config,
multimodal_config=multimodal_config,
require_post_norm=False,
prefix=maybe_prefix(prefix, "vision_tower"),
)
......
......@@ -19,7 +19,7 @@ from transformers.models.llava import LlavaProcessor
from transformers.models.pixtral import PixtralProcessor
from vllm.config import VllmConfig
from vllm.config.multimodal import BaseDummyOptions, MultiModalConfig
from vllm.config.multimodal import BaseDummyOptions
from vllm.model_executor.layers.activation import get_act_fn
from vllm.model_executor.layers.linear import ColumnParallelLinear, RowParallelLinear
from vllm.model_executor.layers.quantization import QuantizationConfig
......@@ -456,7 +456,6 @@ def _get_num_hidden_layers(hf_config: LlavaLikeConfig) -> int:
def init_vision_tower_for_llava(
hf_config: LlavaLikeConfig,
quant_config: QuantizationConfig | None,
multimodal_config: MultiModalConfig | None,
*,
require_post_norm: bool | None = None,
prefix: str = "",
......@@ -470,7 +469,6 @@ def init_vision_tower_for_llava(
return CLIPVisionModel(
vision_config,
quant_config=quant_config,
multimodal_config=multimodal_config,
num_hidden_layers_override=num_hidden_layers,
require_post_norm=require_post_norm,
prefix=prefix,
......@@ -479,7 +477,6 @@ def init_vision_tower_for_llava(
return SiglipVisionModel(
vision_config,
quant_config=quant_config,
multimodal_config=multimodal_config,
num_hidden_layers_override=num_hidden_layers,
require_post_norm=require_post_norm,
prefix=prefix,
......@@ -488,7 +485,6 @@ def init_vision_tower_for_llava(
return PixtralHFVisionModel(
vision_config,
quant_config=quant_config,
multimodal_config=multimodal_config,
num_hidden_layers_override=num_hidden_layers,
require_post_norm=require_post_norm,
prefix=prefix,
......@@ -562,7 +558,6 @@ class LlavaForConditionalGeneration(
self.vision_tower = init_vision_tower_for_llava(
config,
quant_config=quant_config,
multimodal_config=multimodal_config,
require_post_norm=False,
prefix=maybe_prefix(prefix, "vision_tower"),
)
......
......@@ -272,7 +272,6 @@ class LlavaNextForConditionalGeneration(nn.Module, SupportsMultiModal, SupportsP
self.vision_tower = init_vision_tower_for_llava(
config,
quant_config=quant_config,
multimodal_config=multimodal_config,
require_post_norm=False,
prefix=maybe_prefix(prefix, "vision_tower"),
)
......
......@@ -332,7 +332,6 @@ class LlavaNextVideoForConditionalGeneration(nn.Module, SupportsMultiModal, Supp
self.vision_tower = init_vision_tower_for_llava(
config,
quant_config=quant_config,
multimodal_config=multimodal_config,
require_post_norm=False,
prefix=maybe_prefix(prefix, "vision_tower"),
)
......
......@@ -513,7 +513,6 @@ class LlavaOnevisionForConditionalGeneration(nn.Module, SupportsMultiModal, Supp
self.vision_tower = init_vision_tower_for_llava(
config,
quant_config=quant_config,
multimodal_config=multimodal_config,
require_post_norm=False,
prefix=maybe_prefix(prefix, "vision_tower"),
)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment