Unverified Commit 5576227b authored by Cyrus Leung's avatar Cyrus Leung Committed by GitHub
Browse files

[Model] Standardize common vision encoders (#31947)


Signed-off-by: default avatarDarkLight1337 <tlleungac@connect.ust.hk>
parent d1b6fe00
...@@ -17,7 +17,7 @@ from transformers import ( ...@@ -17,7 +17,7 @@ from transformers import (
from vllm.attention.layer import Attention from vllm.attention.layer import Attention
from vllm.attention.layers.mm_encoder_attention import MMEncoderAttention from vllm.attention.layers.mm_encoder_attention import MMEncoderAttention
from vllm.config import VllmConfig from vllm.config import VllmConfig
from vllm.config.multimodal import BaseDummyOptions from vllm.config.multimodal import BaseDummyOptions, MultiModalConfig
from vllm.distributed import divide, get_tensor_model_parallel_world_size from vllm.distributed import divide, get_tensor_model_parallel_world_size
from vllm.model_executor.layers.activation import get_act_fn from vllm.model_executor.layers.activation import get_act_fn
from vllm.model_executor.layers.conv import Conv2dLayer from vllm.model_executor.layers.conv import Conv2dLayer
...@@ -353,6 +353,7 @@ class CLIPAttention(nn.Module): ...@@ -353,6 +353,7 @@ class CLIPAttention(nn.Module):
self, self,
config: CLIPTextConfig | CLIPVisionConfig, config: CLIPTextConfig | CLIPVisionConfig,
quant_config: QuantizationConfig | None = None, quant_config: QuantizationConfig | None = None,
multimodal_config: MultiModalConfig | None = None,
*, *,
prefix: str = "", prefix: str = "",
attn_cls: type[Attention] | type[MMEncoderAttention], attn_cls: type[Attention] | type[MMEncoderAttention],
...@@ -365,18 +366,24 @@ class CLIPAttention(nn.Module): ...@@ -365,18 +366,24 @@ class CLIPAttention(nn.Module):
self.head_dim = self.embed_dim // self.num_heads self.head_dim = self.embed_dim // self.num_heads
if self.head_dim * self.num_heads != self.embed_dim: if self.head_dim * self.num_heads != self.embed_dim:
raise ValueError( raise ValueError(
"embed_dim must be divisible by num_heads " f"embed_dim must be divisible by num_heads "
f"(got `embed_dim`: {self.embed_dim} and `num_heads`:" f"(got `embed_dim`: {self.embed_dim} and "
f" {self.num_heads})." f"`num_heads`: {self.num_heads})."
) )
self.scale = self.head_dim**-0.5 self.scale = self.head_dim**-0.5
use_data_parallel = (
multimodal_config.mm_encoder_tp_mode == "data"
if multimodal_config
else False
)
self.qkv_proj = QKVParallelLinear( self.qkv_proj = QKVParallelLinear(
hidden_size=self.embed_dim, hidden_size=self.embed_dim,
head_size=self.head_dim, head_size=self.head_dim,
total_num_heads=self.num_heads, total_num_heads=self.num_heads,
quant_config=quant_config, quant_config=quant_config,
prefix=f"{prefix}.qkv_proj", prefix=f"{prefix}.qkv_proj",
disable_tp=use_data_parallel,
) )
self.out_proj = RowParallelLinear( self.out_proj = RowParallelLinear(
...@@ -384,17 +391,29 @@ class CLIPAttention(nn.Module): ...@@ -384,17 +391,29 @@ class CLIPAttention(nn.Module):
output_size=self.embed_dim, output_size=self.embed_dim,
quant_config=quant_config, quant_config=quant_config,
prefix=f"{prefix}.out_proj", prefix=f"{prefix}.out_proj",
disable_tp=use_data_parallel,
) )
self.tp_size = get_tensor_model_parallel_world_size() self.tp_size = (
1 if use_data_parallel else get_tensor_model_parallel_world_size()
)
self.num_heads_per_partition = divide(self.num_heads, self.tp_size) self.num_heads_per_partition = divide(self.num_heads, self.tp_size)
self.attn = attn_cls( if attn_cls == MMEncoderAttention:
self.num_heads_per_partition, self.attn = attn_cls(
self.head_dim, self.num_heads_per_partition,
self.scale, self.head_dim,
prefix=f"{prefix}.attn", self.scale,
) prefix=f"{prefix}.attn",
multimodal_config=multimodal_config,
)
else:
self.attn = attn_cls(
self.num_heads_per_partition,
self.head_dim,
self.scale,
prefix=f"{prefix}.attn",
)
def forward( def forward(
self, self,
...@@ -415,17 +434,26 @@ class CLIPMLP(nn.Module): ...@@ -415,17 +434,26 @@ class CLIPMLP(nn.Module):
self, self,
config: CLIPTextConfig | CLIPVisionConfig, config: CLIPTextConfig | CLIPVisionConfig,
quant_config: QuantizationConfig | None = None, quant_config: QuantizationConfig | None = None,
multimodal_config: MultiModalConfig | None = None,
prefix: str = "", prefix: str = "",
) -> None: ) -> None:
super().__init__() super().__init__()
self.config = config self.config = config
use_data_parallel = (
multimodal_config.mm_encoder_tp_mode == "data"
if multimodal_config
else False
)
self.activation_fn = get_act_fn(config.hidden_act) self.activation_fn = get_act_fn(config.hidden_act)
self.fc1 = ColumnParallelLinear( self.fc1 = ColumnParallelLinear(
config.hidden_size, config.hidden_size,
config.intermediate_size, config.intermediate_size,
bias=True, bias=True,
quant_config=quant_config, quant_config=quant_config,
prefix=f"{prefix}.fc1", prefix=f"{prefix}.fc1",
disable_tp=use_data_parallel,
) )
self.fc2 = RowParallelLinear( self.fc2 = RowParallelLinear(
config.intermediate_size, config.intermediate_size,
...@@ -433,6 +461,7 @@ class CLIPMLP(nn.Module): ...@@ -433,6 +461,7 @@ class CLIPMLP(nn.Module):
bias=True, bias=True,
quant_config=quant_config, quant_config=quant_config,
prefix=f"{prefix}.fc2", prefix=f"{prefix}.fc2",
disable_tp=use_data_parallel,
) )
def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
...@@ -448,19 +477,27 @@ class CLIPEncoderLayer(nn.Module): ...@@ -448,19 +477,27 @@ class CLIPEncoderLayer(nn.Module):
self, self,
config: CLIPTextConfig | CLIPVisionConfig, config: CLIPTextConfig | CLIPVisionConfig,
quant_config: QuantizationConfig | None = None, quant_config: QuantizationConfig | None = None,
multimodal_config: MultiModalConfig | None = None,
*, *,
prefix: str = "", prefix: str = "",
attn_cls: type[Attention] | type[MMEncoderAttention], attn_cls: type[Attention] | type[MMEncoderAttention],
) -> None: ) -> None:
super().__init__() super().__init__()
self.self_attn = CLIPAttention( self.self_attn = CLIPAttention(
config, config,
quant_config=quant_config, quant_config=quant_config,
multimodal_config=multimodal_config,
prefix=f"{prefix}.self_attn", prefix=f"{prefix}.self_attn",
attn_cls=attn_cls, attn_cls=attn_cls,
) )
self.layer_norm1 = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps) self.layer_norm1 = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
self.mlp = CLIPMLP(config, quant_config=quant_config, prefix=f"{prefix}.mlp") self.mlp = CLIPMLP(
config,
quant_config=quant_config,
multimodal_config=multimodal_config,
prefix=f"{prefix}.mlp",
)
self.layer_norm2 = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps) self.layer_norm2 = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
...@@ -491,6 +528,7 @@ class CLIPEncoder(nn.Module): ...@@ -491,6 +528,7 @@ class CLIPEncoder(nn.Module):
self, self,
config: CLIPTextConfig | CLIPVisionConfig, config: CLIPTextConfig | CLIPVisionConfig,
quant_config: QuantizationConfig | None = None, quant_config: QuantizationConfig | None = None,
multimodal_config: MultiModalConfig | None = None,
num_hidden_layers_override: int | None = None, num_hidden_layers_override: int | None = None,
*, *,
prefix: str = "", prefix: str = "",
...@@ -504,11 +542,13 @@ class CLIPEncoder(nn.Module): ...@@ -504,11 +542,13 @@ class CLIPEncoder(nn.Module):
num_hidden_layers = config.num_hidden_layers num_hidden_layers = config.num_hidden_layers
else: else:
num_hidden_layers = num_hidden_layers_override num_hidden_layers = num_hidden_layers_override
self.layers = nn.ModuleList( self.layers = nn.ModuleList(
[ [
CLIPEncoderLayer( CLIPEncoderLayer(
config=config, config=config,
quant_config=quant_config, quant_config=quant_config,
multimodal_config=multimodal_config,
prefix=f"{prefix}.layers.{layer_idx}", prefix=f"{prefix}.layers.{layer_idx}",
attn_cls=attn_cls, attn_cls=attn_cls,
) )
...@@ -618,6 +658,7 @@ class CLIPVisionTransformer(nn.Module): ...@@ -618,6 +658,7 @@ class CLIPVisionTransformer(nn.Module):
self, self,
config: CLIPVisionConfig, config: CLIPVisionConfig,
quant_config: QuantizationConfig | None = None, quant_config: QuantizationConfig | None = None,
multimodal_config: MultiModalConfig | None = None,
*, *,
num_hidden_layers_override: int | None = None, num_hidden_layers_override: int | None = None,
require_post_norm: bool | None = None, require_post_norm: bool | None = None,
...@@ -637,6 +678,7 @@ class CLIPVisionTransformer(nn.Module): ...@@ -637,6 +678,7 @@ class CLIPVisionTransformer(nn.Module):
self.encoder = CLIPEncoder( self.encoder = CLIPEncoder(
config=config, config=config,
quant_config=quant_config, quant_config=quant_config,
multimodal_config=multimodal_config,
num_hidden_layers_override=num_hidden_layers_override, num_hidden_layers_override=num_hidden_layers_override,
prefix=f"{prefix}.encoder", prefix=f"{prefix}.encoder",
attn_cls=MMEncoderAttention, attn_cls=MMEncoderAttention,
...@@ -738,6 +780,7 @@ class CLIPVisionModel(nn.Module): ...@@ -738,6 +780,7 @@ class CLIPVisionModel(nn.Module):
self, self,
config: CLIPVisionConfig, config: CLIPVisionConfig,
quant_config: QuantizationConfig | None = None, quant_config: QuantizationConfig | None = None,
multimodal_config: MultiModalConfig | None = None,
*, *,
num_hidden_layers_override: int | None = None, num_hidden_layers_override: int | None = None,
require_post_norm: bool | None = None, require_post_norm: bool | None = None,
...@@ -748,6 +791,7 @@ class CLIPVisionModel(nn.Module): ...@@ -748,6 +791,7 @@ class CLIPVisionModel(nn.Module):
self.vision_model = CLIPVisionTransformer( self.vision_model = CLIPVisionTransformer(
config=config, config=config,
quant_config=quant_config, quant_config=quant_config,
multimodal_config=multimodal_config,
num_hidden_layers_override=num_hidden_layers_override, num_hidden_layers_override=num_hidden_layers_override,
require_post_norm=require_post_norm, require_post_norm=require_post_norm,
prefix=f"{prefix}.vision_model", prefix=f"{prefix}.vision_model",
...@@ -817,6 +861,7 @@ class CLIPEmbeddingModel(nn.Module, SupportsMultiModal, SupportsQuant): ...@@ -817,6 +861,7 @@ class CLIPEmbeddingModel(nn.Module, SupportsMultiModal, SupportsQuant):
self.vision_model = CLIPVisionTransformer( self.vision_model = CLIPVisionTransformer(
vision_config, vision_config,
quant_config=quant_config, quant_config=quant_config,
multimodal_config=multimodal_config,
prefix=maybe_prefix(prefix, "vision_model"), prefix=maybe_prefix(prefix, "vision_model"),
) )
......
...@@ -19,6 +19,7 @@ import torch.nn.functional as F ...@@ -19,6 +19,7 @@ import torch.nn.functional as F
from transformers import CLIPVisionConfig from transformers import CLIPVisionConfig
from vllm.attention.layers.mm_encoder_attention import MMEncoderAttention from vllm.attention.layers.mm_encoder_attention import MMEncoderAttention
from vllm.config import MultiModalConfig
from vllm.model_executor.layers.conv import Conv2dLayer from vllm.model_executor.layers.conv import Conv2dLayer
from vllm.model_executor.layers.quantization import QuantizationConfig from vllm.model_executor.layers.quantization import QuantizationConfig
from vllm.model_executor.model_loader.weight_utils import default_weight_loader from vllm.model_executor.model_loader.weight_utils import default_weight_loader
...@@ -608,6 +609,7 @@ class DeepCLIPVisionTransformer(nn.Module): ...@@ -608,6 +609,7 @@ class DeepCLIPVisionTransformer(nn.Module):
self, self,
config: CLIPVisionConfig, config: CLIPVisionConfig,
quant_config: QuantizationConfig | None = None, quant_config: QuantizationConfig | None = None,
multimodal_config: MultiModalConfig | None = None,
*, *,
num_hidden_layers_override: int | None = None, num_hidden_layers_override: int | None = None,
prefix: str = "", prefix: str = "",
...@@ -626,6 +628,7 @@ class DeepCLIPVisionTransformer(nn.Module): ...@@ -626,6 +628,7 @@ class DeepCLIPVisionTransformer(nn.Module):
self.transformer = CLIPEncoder( self.transformer = CLIPEncoder(
config=config, config=config,
quant_config=quant_config, quant_config=quant_config,
multimodal_config=multimodal_config,
num_hidden_layers_override=num_hidden_layers_override, num_hidden_layers_override=num_hidden_layers_override,
prefix=f"{prefix}.encoder", prefix=f"{prefix}.encoder",
attn_cls=MMEncoderAttention, attn_cls=MMEncoderAttention,
......
...@@ -397,6 +397,7 @@ class DeepseekOCRForCausalLM(nn.Module, SupportsMultiModal, SupportsPP, Supports ...@@ -397,6 +397,7 @@ class DeepseekOCRForCausalLM(nn.Module, SupportsMultiModal, SupportsPP, Supports
self.vision_model = DeepCLIPVisionTransformer( self.vision_model = DeepCLIPVisionTransformer(
config=clip_vision_config, config=clip_vision_config,
quant_config=quant_config, quant_config=quant_config,
multimodal_config=multimodal_config,
prefix=maybe_prefix(prefix, "vision_model"), prefix=maybe_prefix(prefix, "vision_model"),
) )
......
...@@ -6,7 +6,7 @@ from collections import defaultdict ...@@ -6,7 +6,7 @@ from collections import defaultdict
from collections.abc import Iterable, Mapping, Sequence from collections.abc import Iterable, Mapping, Sequence
from functools import partial from functools import partial
from itertools import accumulate from itertools import accumulate
from typing import Annotated, Any, Literal from typing import Annotated, Literal
import numpy as np import numpy as np
import torch import torch
...@@ -18,7 +18,7 @@ from transformers import BatchFeature, CLIPVisionConfig, SiglipVisionConfig ...@@ -18,7 +18,7 @@ from transformers import BatchFeature, CLIPVisionConfig, SiglipVisionConfig
from transformers.modeling_utils import no_init_weights from transformers.modeling_utils import no_init_weights
from vllm.config import VllmConfig from vllm.config import VllmConfig
from vllm.config.multimodal import BaseDummyOptions from vllm.config.multimodal import BaseDummyOptions, MultiModalConfig
from vllm.model_executor.layers.quantization import QuantizationConfig from vllm.model_executor.layers.quantization import QuantizationConfig
from vllm.multimodal import MULTIMODAL_REGISTRY from vllm.multimodal import MULTIMODAL_REGISTRY
from vllm.multimodal.cache import BaseMultiModalProcessorCache from vllm.multimodal.cache import BaseMultiModalProcessorCache
...@@ -361,6 +361,7 @@ def _build_hcxvision_hf_processor( ...@@ -361,6 +361,7 @@ def _build_hcxvision_hf_processor(
def init_vision_tower_for_hcxvision( def init_vision_tower_for_hcxvision(
vision_config, vision_config,
quant_config: QuantizationConfig | None, quant_config: QuantizationConfig | None,
multimodal_config: MultiModalConfig | None,
*, *,
use_nth_layer: int | None = None, use_nth_layer: int | None = None,
require_post_norm: bool | None = None, require_post_norm: bool | None = None,
...@@ -378,6 +379,7 @@ def init_vision_tower_for_hcxvision( ...@@ -378,6 +379,7 @@ def init_vision_tower_for_hcxvision(
return CLIPVisionModel( return CLIPVisionModel(
vision_config, vision_config,
quant_config=quant_config, quant_config=quant_config,
multimodal_config=multimodal_config,
num_hidden_layers_override=num_hidden_layers, num_hidden_layers_override=num_hidden_layers,
require_post_norm=require_post_norm, require_post_norm=require_post_norm,
prefix=prefix, prefix=prefix,
...@@ -386,6 +388,7 @@ def init_vision_tower_for_hcxvision( ...@@ -386,6 +388,7 @@ def init_vision_tower_for_hcxvision(
return SiglipVisionModel( return SiglipVisionModel(
vision_config, vision_config,
quant_config=quant_config, quant_config=quant_config,
multimodal_config=multimodal_config,
num_hidden_layers_override=num_hidden_layers, num_hidden_layers_override=num_hidden_layers,
require_post_norm=require_post_norm, require_post_norm=require_post_norm,
prefix=prefix, prefix=prefix,
...@@ -597,18 +600,13 @@ class HCXVisionForCausalLM(nn.Module, SupportsMultiModal, SupportsPP): ...@@ -597,18 +600,13 @@ class HCXVisionForCausalLM(nn.Module, SupportsMultiModal, SupportsPP):
"gate_up_proj": ["gate_proj", "up_proj"], "gate_up_proj": ["gate_proj", "up_proj"],
} }
def __init__( def __init__(self, *, vllm_config: VllmConfig, prefix: str = "") -> None:
self,
*,
vllm_config: VllmConfig,
prefix: str = "",
**kwargs: Any | None,
) -> None:
super().__init__() super().__init__()
# init configs # init configs
config = vllm_config.model_config.hf_config config = vllm_config.model_config.hf_config
quant_config = vllm_config.quant_config quant_config = vllm_config.quant_config
multimodal_config = vllm_config.model_config.multimodal_config
# text_config # text_config
text_config = config.text_config text_config = config.text_config
if text_config.model_type in ["gpt2", "hyperclovax", "llama"]: if text_config.model_type in ["gpt2", "hyperclovax", "llama"]:
...@@ -631,7 +629,8 @@ class HCXVisionForCausalLM(nn.Module, SupportsMultiModal, SupportsPP): ...@@ -631,7 +629,8 @@ class HCXVisionForCausalLM(nn.Module, SupportsMultiModal, SupportsPP):
with no_init_weights(): # weight will be loaded in from_pretrained with no_init_weights(): # weight will be loaded in from_pretrained
self.vision_model = init_vision_tower_for_hcxvision( self.vision_model = init_vision_tower_for_hcxvision(
vision_config, vision_config,
quant_config, quant_config=quant_config,
multimodal_config=multimodal_config,
use_nth_layer=getattr(config, "use_nth_layer", -1), use_nth_layer=getattr(config, "use_nth_layer", -1),
require_post_norm=False, require_post_norm=False,
prefix=maybe_prefix(prefix, "vision_model"), prefix=maybe_prefix(prefix, "vision_model"),
......
...@@ -1226,8 +1226,8 @@ class IsaacVisionEmbedding(nn.Module): ...@@ -1226,8 +1226,8 @@ class IsaacVisionEmbedding(nn.Module):
self.transformer = Siglip2VisionTransformer( self.transformer = Siglip2VisionTransformer(
vision_cfg, vision_cfg,
quant_config=quant_config, quant_config=quant_config,
prefix=maybe_prefix(prefix, "0"),
multimodal_config=multimodal_config, multimodal_config=multimodal_config,
prefix=maybe_prefix(prefix, "0"),
) )
self.linear_fc1 = ColumnParallelLinear( self.linear_fc1 = ColumnParallelLinear(
hidden_dim, hidden_dim,
......
...@@ -404,6 +404,7 @@ class KeyeSiglipAttention(nn.Module): ...@@ -404,6 +404,7 @@ class KeyeSiglipAttention(nn.Module):
self.attn = MMEncoderAttention( self.attn = MMEncoderAttention(
num_heads=self.num_heads, num_heads=self.num_heads,
head_size=self.head_dim, head_size=self.head_dim,
scale=self.scale,
num_kv_heads=self.num_kv_heads, num_kv_heads=self.num_kv_heads,
prefix=f"{prefix}.attn", prefix=f"{prefix}.attn",
multimodal_config=multimodal_config, multimodal_config=multimodal_config,
...@@ -511,6 +512,7 @@ class KeyeSiglipEncoderLayer(nn.Module): ...@@ -511,6 +512,7 @@ class KeyeSiglipEncoderLayer(nn.Module):
self.mlp = SiglipMLP( self.mlp = SiglipMLP(
config, config,
quant_config=quant_config, quant_config=quant_config,
multimodal_config=multimodal_config,
prefix=f"{prefix}.mlp", prefix=f"{prefix}.mlp",
) )
......
...@@ -155,6 +155,7 @@ class LightOnOCRForConditionalGeneration(Mistral3ForConditionalGeneration): ...@@ -155,6 +155,7 @@ class LightOnOCRForConditionalGeneration(Mistral3ForConditionalGeneration):
def __init__(self, *, vllm_config: VllmConfig, prefix: str = "") -> None: def __init__(self, *, vllm_config: VllmConfig, prefix: str = "") -> None:
nn.Module.__init__(self) nn.Module.__init__(self)
config = vllm_config.model_config.hf_config config = vllm_config.model_config.hf_config
quant_config = vllm_config.quant_config quant_config = vllm_config.quant_config
multimodal_config = vllm_config.model_config.multimodal_config multimodal_config = vllm_config.model_config.multimodal_config
...@@ -164,7 +165,8 @@ class LightOnOCRForConditionalGeneration(Mistral3ForConditionalGeneration): ...@@ -164,7 +165,8 @@ class LightOnOCRForConditionalGeneration(Mistral3ForConditionalGeneration):
self.vision_tower = init_vision_tower_for_llava( self.vision_tower = init_vision_tower_for_llava(
config, config,
quant_config, quant_config=quant_config,
multimodal_config=multimodal_config,
require_post_norm=False, require_post_norm=False,
prefix=maybe_prefix(prefix, "vision_tower"), prefix=maybe_prefix(prefix, "vision_tower"),
) )
......
...@@ -19,7 +19,7 @@ from transformers.models.llava import LlavaProcessor ...@@ -19,7 +19,7 @@ from transformers.models.llava import LlavaProcessor
from transformers.models.pixtral import PixtralProcessor from transformers.models.pixtral import PixtralProcessor
from vllm.config import VllmConfig from vllm.config import VllmConfig
from vllm.config.multimodal import BaseDummyOptions from vllm.config.multimodal import BaseDummyOptions, MultiModalConfig
from vllm.model_executor.layers.activation import get_act_fn from vllm.model_executor.layers.activation import get_act_fn
from vllm.model_executor.layers.linear import ColumnParallelLinear, RowParallelLinear from vllm.model_executor.layers.linear import ColumnParallelLinear, RowParallelLinear
from vllm.model_executor.layers.quantization import QuantizationConfig from vllm.model_executor.layers.quantization import QuantizationConfig
...@@ -468,6 +468,7 @@ def _get_layer_index(feature_layer_index: int, num_hidden_layers: int) -> int: ...@@ -468,6 +468,7 @@ def _get_layer_index(feature_layer_index: int, num_hidden_layers: int) -> int:
def init_vision_tower_for_llava( def init_vision_tower_for_llava(
hf_config: LlavaLikeConfig, hf_config: LlavaLikeConfig,
quant_config: QuantizationConfig | None, quant_config: QuantizationConfig | None,
multimodal_config: MultiModalConfig | None,
*, *,
require_post_norm: bool | None = None, require_post_norm: bool | None = None,
prefix: str = "", prefix: str = "",
...@@ -481,6 +482,7 @@ def init_vision_tower_for_llava( ...@@ -481,6 +482,7 @@ def init_vision_tower_for_llava(
return CLIPVisionModel( return CLIPVisionModel(
vision_config, vision_config,
quant_config=quant_config, quant_config=quant_config,
multimodal_config=multimodal_config,
num_hidden_layers_override=num_hidden_layers, num_hidden_layers_override=num_hidden_layers,
require_post_norm=require_post_norm, require_post_norm=require_post_norm,
prefix=prefix, prefix=prefix,
...@@ -489,6 +491,7 @@ def init_vision_tower_for_llava( ...@@ -489,6 +491,7 @@ def init_vision_tower_for_llava(
return SiglipVisionModel( return SiglipVisionModel(
vision_config, vision_config,
quant_config=quant_config, quant_config=quant_config,
multimodal_config=multimodal_config,
num_hidden_layers_override=num_hidden_layers, num_hidden_layers_override=num_hidden_layers,
require_post_norm=require_post_norm, require_post_norm=require_post_norm,
prefix=prefix, prefix=prefix,
...@@ -497,6 +500,7 @@ def init_vision_tower_for_llava( ...@@ -497,6 +500,7 @@ def init_vision_tower_for_llava(
return PixtralHFVisionModel( return PixtralHFVisionModel(
vision_config, vision_config,
quant_config=quant_config, quant_config=quant_config,
multimodal_config=multimodal_config,
num_hidden_layers_override=num_hidden_layers, num_hidden_layers_override=num_hidden_layers,
require_post_norm=require_post_norm, require_post_norm=require_post_norm,
prefix=prefix, prefix=prefix,
...@@ -563,7 +567,8 @@ class LlavaForConditionalGeneration( ...@@ -563,7 +567,8 @@ class LlavaForConditionalGeneration(
if multimodal_config.get_limit_per_prompt("image"): if multimodal_config.get_limit_per_prompt("image"):
self.vision_tower = init_vision_tower_for_llava( self.vision_tower = init_vision_tower_for_llava(
config, config,
quant_config, quant_config=quant_config,
multimodal_config=multimodal_config,
require_post_norm=False, require_post_norm=False,
prefix=maybe_prefix(prefix, "vision_tower"), prefix=maybe_prefix(prefix, "vision_tower"),
) )
......
...@@ -243,6 +243,7 @@ class LlavaNextForConditionalGeneration(nn.Module, SupportsMultiModal, SupportsP ...@@ -243,6 +243,7 @@ class LlavaNextForConditionalGeneration(nn.Module, SupportsMultiModal, SupportsP
def __init__(self, *, vllm_config: VllmConfig, prefix: str = "") -> None: def __init__(self, *, vllm_config: VllmConfig, prefix: str = "") -> None:
super().__init__() super().__init__()
config = vllm_config.model_config.hf_config config = vllm_config.model_config.hf_config
quant_config = vllm_config.quant_config quant_config = vllm_config.quant_config
multimodal_config = vllm_config.model_config.multimodal_config multimodal_config = vllm_config.model_config.multimodal_config
...@@ -270,7 +271,8 @@ class LlavaNextForConditionalGeneration(nn.Module, SupportsMultiModal, SupportsP ...@@ -270,7 +271,8 @@ class LlavaNextForConditionalGeneration(nn.Module, SupportsMultiModal, SupportsP
# TODO: Optionally initializes this for supporting embeddings. # TODO: Optionally initializes this for supporting embeddings.
self.vision_tower = init_vision_tower_for_llava( self.vision_tower = init_vision_tower_for_llava(
config, config,
quant_config, quant_config=quant_config,
multimodal_config=multimodal_config,
require_post_norm=False, require_post_norm=False,
prefix=maybe_prefix(prefix, "vision_tower"), prefix=maybe_prefix(prefix, "vision_tower"),
) )
......
...@@ -321,6 +321,7 @@ class LlavaNextVideoForConditionalGeneration(nn.Module, SupportsMultiModal, Supp ...@@ -321,6 +321,7 @@ class LlavaNextVideoForConditionalGeneration(nn.Module, SupportsMultiModal, Supp
def __init__(self, *, vllm_config: VllmConfig, prefix: str = "") -> None: def __init__(self, *, vllm_config: VllmConfig, prefix: str = "") -> None:
super().__init__() super().__init__()
config = vllm_config.model_config.hf_config config = vllm_config.model_config.hf_config
quant_config = vllm_config.quant_config quant_config = vllm_config.quant_config
multimodal_config = vllm_config.model_config.multimodal_config multimodal_config = vllm_config.model_config.multimodal_config
...@@ -331,7 +332,8 @@ class LlavaNextVideoForConditionalGeneration(nn.Module, SupportsMultiModal, Supp ...@@ -331,7 +332,8 @@ class LlavaNextVideoForConditionalGeneration(nn.Module, SupportsMultiModal, Supp
# Initialize the vision tower only up to the required feature layer # Initialize the vision tower only up to the required feature layer
self.vision_tower = init_vision_tower_for_llava( self.vision_tower = init_vision_tower_for_llava(
config, config,
quant_config, quant_config=quant_config,
multimodal_config=multimodal_config,
require_post_norm=False, require_post_norm=False,
prefix=maybe_prefix(prefix, "vision_tower"), prefix=maybe_prefix(prefix, "vision_tower"),
) )
......
...@@ -511,7 +511,8 @@ class LlavaOnevisionForConditionalGeneration(nn.Module, SupportsMultiModal, Supp ...@@ -511,7 +511,8 @@ class LlavaOnevisionForConditionalGeneration(nn.Module, SupportsMultiModal, Supp
# Initialize the vision tower only up to the required feature layer # Initialize the vision tower only up to the required feature layer
self.vision_tower = init_vision_tower_for_llava( self.vision_tower = init_vision_tower_for_llava(
config, config,
quant_config, quant_config=quant_config,
multimodal_config=multimodal_config,
require_post_norm=False, require_post_norm=False,
prefix=maybe_prefix(prefix, "vision_tower"), prefix=maybe_prefix(prefix, "vision_tower"),
) )
......
...@@ -204,7 +204,8 @@ class MiniMaxVL01ForConditionalGeneration(nn.Module, SupportsMultiModal, Support ...@@ -204,7 +204,8 @@ class MiniMaxVL01ForConditionalGeneration(nn.Module, SupportsMultiModal, Support
# TODO: Optionally initializes this for supporting embeddings. # TODO: Optionally initializes this for supporting embeddings.
self.vision_tower = init_vision_tower_for_llava( self.vision_tower = init_vision_tower_for_llava(
config, config,
quant_config, quant_config=quant_config,
multimodal_config=multimodal_config,
require_post_norm=False, require_post_norm=False,
prefix=maybe_prefix(prefix, "vision_tower"), prefix=maybe_prefix(prefix, "vision_tower"),
) )
......
...@@ -16,7 +16,7 @@ from transformers import ( ...@@ -16,7 +16,7 @@ from transformers import (
from transformers.models.pixtral import PixtralProcessor from transformers.models.pixtral import PixtralProcessor
from vllm.config import VllmConfig from vllm.config import VllmConfig
from vllm.config.multimodal import BaseDummyOptions from vllm.config.multimodal import BaseDummyOptions, MultiModalConfig
from vllm.model_executor.layers.activation import get_act_fn from vllm.model_executor.layers.activation import get_act_fn
from vllm.model_executor.layers.layernorm import RMSNorm from vllm.model_executor.layers.layernorm import RMSNorm
from vllm.model_executor.layers.linear import ColumnParallelLinear, RowParallelLinear from vllm.model_executor.layers.linear import ColumnParallelLinear, RowParallelLinear
...@@ -395,6 +395,7 @@ def _get_layer_index(feature_layer_index: int, num_hidden_layers: int) -> int: ...@@ -395,6 +395,7 @@ def _get_layer_index(feature_layer_index: int, num_hidden_layers: int) -> int:
def init_vision_tower_for_llava( def init_vision_tower_for_llava(
hf_config: LlavaLikeConfig, hf_config: LlavaLikeConfig,
quant_config: QuantizationConfig | None, quant_config: QuantizationConfig | None,
multimodal_config: MultiModalConfig | None,
*, *,
require_post_norm: bool | None = None, require_post_norm: bool | None = None,
prefix: str = "", prefix: str = "",
...@@ -409,6 +410,7 @@ def init_vision_tower_for_llava( ...@@ -409,6 +410,7 @@ def init_vision_tower_for_llava(
return PixtralHFVisionModel( return PixtralHFVisionModel(
vision_config, vision_config,
quant_config=quant_config, quant_config=quant_config,
multimodal_config=multimodal_config,
num_hidden_layers_override=num_hidden_layers, num_hidden_layers_override=num_hidden_layers,
require_post_norm=require_post_norm, require_post_norm=require_post_norm,
prefix=prefix, prefix=prefix,
...@@ -472,7 +474,8 @@ class Mistral3ForConditionalGeneration( ...@@ -472,7 +474,8 @@ class Mistral3ForConditionalGeneration(
if multimodal_config.get_limit_per_prompt("image"): if multimodal_config.get_limit_per_prompt("image"):
self.vision_tower = init_vision_tower_for_llava( self.vision_tower = init_vision_tower_for_llava(
config, config,
quant_config, quant_config=quant_config,
multimodal_config=multimodal_config,
require_post_norm=False, require_post_norm=False,
prefix=maybe_prefix(prefix, "vision_tower"), prefix=maybe_prefix(prefix, "vision_tower"),
) )
......
...@@ -38,10 +38,8 @@ from vllm.config import MultiModalConfig, VllmConfig ...@@ -38,10 +38,8 @@ from vllm.config import MultiModalConfig, VllmConfig
from vllm.config.multimodal import BaseDummyOptions from vllm.config.multimodal import BaseDummyOptions
from vllm.distributed import parallel_state from vllm.distributed import parallel_state
from vllm.distributed import utils as dist_utils from vllm.distributed import utils as dist_utils
from vllm.model_executor.layers.activation import get_act_fn
from vllm.model_executor.layers.conv import Conv2dLayer from vllm.model_executor.layers.conv import Conv2dLayer
from vllm.model_executor.layers.linear import ( from vllm.model_executor.layers.linear import (
ColumnParallelLinear,
QKVParallelLinear, QKVParallelLinear,
RowParallelLinear, RowParallelLinear,
) )
...@@ -77,6 +75,7 @@ from vllm.utils.tensor_schema import TensorSchema, TensorShape ...@@ -77,6 +75,7 @@ from vllm.utils.tensor_schema import TensorSchema, TensorShape
from .ernie45 import Ernie4_5ForCausalLM from .ernie45 import Ernie4_5ForCausalLM
from .interfaces import MultiModalEmbeddings, SupportsMRoPE, SupportsMultiModal from .interfaces import MultiModalEmbeddings, SupportsMRoPE, SupportsMultiModal
from .siglip import SiglipMLP
from .utils import ( from .utils import (
AutoWeightsLoader, AutoWeightsLoader,
PPMissingLayer, PPMissingLayer,
...@@ -657,46 +656,6 @@ class SigLIPRotaryEmbedding(nn.Module): ...@@ -657,46 +656,6 @@ class SigLIPRotaryEmbedding(nn.Module):
return freqs return freqs
class SiglipMLP(nn.Module):
def __init__(
self,
config: PretrainedConfig,
quant_config: QuantizationConfig | None = None,
prefix: str = "",
) -> None:
super().__init__()
self.config = config
self.activation_fn = get_act_fn(config.hidden_act)
# Special handling for BNB and torchao quantization
if quant_config and quant_config.get_name() in ["bitsandbytes", "torchao"]:
quantizable = True
else:
# For other quantization, we require the hidden size to be a
# multiple of 64
quantizable = (
config.hidden_size % 64 == 0 and config.intermediate_size % 64 == 0
)
self.fc1 = ColumnParallelLinear(
config.hidden_size,
config.intermediate_size,
quant_config=quant_config if quantizable else None,
prefix=f"{prefix}.fc1",
)
self.fc2 = RowParallelLinear(
config.intermediate_size,
config.hidden_size,
quant_config=quant_config if quantizable else None,
prefix=f"{prefix}.fc2",
)
def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
hidden_states, _ = self.fc1(hidden_states)
hidden_states = self.activation_fn(hidden_states)
hidden_states, _ = self.fc2(hidden_states)
return hidden_states
class SiglipEncoderLayer(nn.Module): class SiglipEncoderLayer(nn.Module):
def __init__( def __init__(
self, self,
...@@ -720,6 +679,7 @@ class SiglipEncoderLayer(nn.Module): ...@@ -720,6 +679,7 @@ class SiglipEncoderLayer(nn.Module):
self.mlp = SiglipMLP( self.mlp = SiglipMLP(
config, config,
quant_config=quant_config, quant_config=quant_config,
multimodal_config=multimodal_config,
prefix=f"{prefix}.mlp", prefix=f"{prefix}.mlp",
) )
......
...@@ -29,7 +29,7 @@ from transformers import ( ...@@ -29,7 +29,7 @@ from transformers import (
) )
from vllm.config import VllmConfig from vllm.config import VllmConfig
from vllm.config.multimodal import BaseDummyOptions from vllm.config.multimodal import BaseDummyOptions, MultiModalConfig
from vllm.logger import init_logger from vllm.logger import init_logger
from vllm.model_executor.layers.quantization import QuantizationConfig from vllm.model_executor.layers.quantization import QuantizationConfig
from vllm.model_executor.layers.vocab_parallel_embedding import VocabParallelEmbedding from vllm.model_executor.layers.vocab_parallel_embedding import VocabParallelEmbedding
...@@ -96,6 +96,7 @@ CLIP_VIT_LARGE_PATCH14_336_CONFIG = CLIPVisionConfig( ...@@ -96,6 +96,7 @@ CLIP_VIT_LARGE_PATCH14_336_CONFIG = CLIPVisionConfig(
def _init_img_processor( def _init_img_processor(
hf_config: PretrainedConfig, hf_config: PretrainedConfig,
quant_config: QuantizationConfig | None, quant_config: QuantizationConfig | None,
multimodal_config: MultiModalConfig | None,
prefix: str = "", prefix: str = "",
) -> CLIPVisionModel: ) -> CLIPVisionModel:
clip_config = CLIP_VIT_LARGE_PATCH14_336_CONFIG clip_config = CLIP_VIT_LARGE_PATCH14_336_CONFIG
...@@ -109,7 +110,8 @@ def _init_img_processor( ...@@ -109,7 +110,8 @@ def _init_img_processor(
img_processor = CLIPVisionModel( img_processor = CLIPVisionModel(
clip_config, clip_config,
quant_config, quant_config=quant_config,
multimodal_config=multimodal_config,
num_hidden_layers_override=num_hidden_layers, num_hidden_layers_override=num_hidden_layers,
prefix=prefix, prefix=prefix,
) )
...@@ -160,38 +162,15 @@ class Phi3VImageEmbeddingInputs(TensorSchema): ...@@ -160,38 +162,15 @@ class Phi3VImageEmbeddingInputs(TensorSchema):
Phi3VImageInputs: TypeAlias = Phi3VImagePixelInputs | Phi3VImageEmbeddingInputs Phi3VImageInputs: TypeAlias = Phi3VImagePixelInputs | Phi3VImageEmbeddingInputs
class Phi3ImageEmbeddingBase(nn.Module):
def __init__(self) -> None:
super().__init__()
self.layer_idx: int
self.type_feature: str
self.img_processor: CLIPVisionModel
def get_img_features(self, img_embeds: torch.FloatTensor) -> torch.FloatTensor:
TYPE_FEATURE = self.type_feature
# NOTE: we skip the step to select the vision feature layer since
# this is already done inside the img_processor
img_feature = self.img_processor(img_embeds)
if TYPE_FEATURE == "patch":
patch_feature = img_feature[:, 1:]
return patch_feature
if TYPE_FEATURE == "cls_patch":
return img_feature
raise NotImplementedError
# adapted from https://huggingface.co/microsoft/Phi-3-vision-128k-instruct/blob/main/image_embedding_phi3_v.py # adapted from https://huggingface.co/microsoft/Phi-3-vision-128k-instruct/blob/main/image_embedding_phi3_v.py
class Phi3HDImageEmbedding(Phi3ImageEmbeddingBase): class Phi3HDImageEmbedding(nn.Module):
"""Phi3 Image embedding with HD transform.""" """Phi3 Image embedding with HD transform."""
def __init__( def __init__(
self, self,
config: PretrainedConfig, config: PretrainedConfig,
quant_config: QuantizationConfig | None, quant_config: QuantizationConfig | None,
multimodal_config: MultiModalConfig | None,
prefix: str = "", prefix: str = "",
) -> None: ) -> None:
super().__init__() super().__init__()
...@@ -200,7 +179,10 @@ class Phi3HDImageEmbedding(Phi3ImageEmbeddingBase): ...@@ -200,7 +179,10 @@ class Phi3HDImageEmbedding(Phi3ImageEmbeddingBase):
hidden_size = config.n_embd if hasattr(config, "n_embd") else config.hidden_size hidden_size = config.n_embd if hasattr(config, "n_embd") else config.hidden_size
self.img_processor = _init_img_processor( self.img_processor = _init_img_processor(
config, quant_config, prefix=f"{prefix}.img_processor" config,
quant_config=quant_config,
multimodal_config=multimodal_config,
prefix=f"{prefix}.img_processor",
) )
image_dim_out = config.img_processor["image_dim_out"] image_dim_out = config.img_processor["image_dim_out"]
...@@ -223,13 +205,29 @@ class Phi3HDImageEmbedding(Phi3ImageEmbeddingBase): ...@@ -223,13 +205,29 @@ class Phi3HDImageEmbedding(Phi3ImageEmbeddingBase):
dim_projection = hidden_size dim_projection = hidden_size
depth = 2 depth = 2
layers = [nn.Linear(image_dim_out * 4, dim_projection)] layers: list[nn.Module] = [nn.Linear(image_dim_out * 4, dim_projection)]
for _ in range(1, depth): for _ in range(1, depth):
layers.extend([nn.GELU(), nn.Linear(dim_projection, dim_projection)]) layers.extend([nn.GELU(), nn.Linear(dim_projection, dim_projection)])
self.img_projection = nn.Sequential(*layers) self.img_projection = nn.Sequential(*layers)
self.type_feature = config.img_processor.get("type_feature", "patch") self.type_feature = config.img_processor.get("type_feature", "patch")
def get_img_features(self, img_embeds: torch.FloatTensor) -> torch.FloatTensor:
type_feature = self.type_feature
# NOTE: we skip the step to select the vision feature layer since
# this is already done inside the img_processor
img_feature = self.img_processor(img_embeds)
if type_feature == "patch":
patch_feature = img_feature[:, 1:]
return patch_feature
if type_feature == "cls_patch":
return img_feature
raise NotImplementedError(type_feature)
def forward( def forward(
self, pixel_values: torch.FloatTensor, image_sizes: torch.Tensor self, pixel_values: torch.FloatTensor, image_sizes: torch.Tensor
) -> torch.FloatTensor: ) -> torch.FloatTensor:
...@@ -582,6 +580,7 @@ class Phi3VForCausalLM(nn.Module, SupportsMultiModal, SupportsPP, SupportsQuant) ...@@ -582,6 +580,7 @@ class Phi3VForCausalLM(nn.Module, SupportsMultiModal, SupportsPP, SupportsQuant)
def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
super().__init__() super().__init__()
config = vllm_config.model_config.hf_config config = vllm_config.model_config.hf_config
quant_config = vllm_config.quant_config
multimodal_config = vllm_config.model_config.multimodal_config multimodal_config = vllm_config.model_config.multimodal_config
self.config = config self.config = config
self.multimodal_config = multimodal_config self.multimodal_config = multimodal_config
...@@ -590,14 +589,15 @@ class Phi3VForCausalLM(nn.Module, SupportsMultiModal, SupportsPP, SupportsQuant) ...@@ -590,14 +589,15 @@ class Phi3VForCausalLM(nn.Module, SupportsMultiModal, SupportsPP, SupportsQuant)
self.embed_tokens = VocabParallelEmbedding( self.embed_tokens = VocabParallelEmbedding(
config.vocab_size, config.vocab_size,
config.hidden_size, config.hidden_size,
quant_config=self.quant_config, quant_config=quant_config,
prefix=maybe_prefix(prefix, "model.embed_tokens"), prefix=maybe_prefix(prefix, "model.embed_tokens"),
) )
# TODO: Optionally initializes this for supporting input embeddings. # TODO: Optionally initializes this for supporting input embeddings.
self.vision_embed_tokens = Phi3HDImageEmbedding( self.vision_embed_tokens = Phi3HDImageEmbedding(
config, config,
self.quant_config, quant_config=quant_config,
multimodal_config=multimodal_config,
prefix=maybe_prefix(prefix, "model.vision_embed_tokens"), prefix=maybe_prefix(prefix, "model.vision_embed_tokens"),
) )
......
...@@ -28,7 +28,7 @@ from transformers.models.pixtral.modeling_pixtral import ( ...@@ -28,7 +28,7 @@ from transformers.models.pixtral.modeling_pixtral import (
from transformers.tokenization_utils_base import TextInput from transformers.tokenization_utils_base import TextInput
from vllm.config import VllmConfig from vllm.config import VllmConfig
from vllm.config.multimodal import BaseDummyOptions from vllm.config.multimodal import BaseDummyOptions, MultiModalConfig
from vllm.distributed import divide, get_tensor_model_parallel_world_size from vllm.distributed import divide, get_tensor_model_parallel_world_size
from vllm.model_executor.layers.activation import get_act_and_mul_fn from vllm.model_executor.layers.activation import get_act_and_mul_fn
from vllm.model_executor.layers.conv import Conv2dLayer from vllm.model_executor.layers.conv import Conv2dLayer
...@@ -1043,11 +1043,18 @@ class PixtralHFMLP(nn.Module): ...@@ -1043,11 +1043,18 @@ class PixtralHFMLP(nn.Module):
self, self,
config: PixtralVisionConfig, config: PixtralVisionConfig,
quant_config: QuantizationConfig | None = None, quant_config: QuantizationConfig | None = None,
multimodal_config: MultiModalConfig | None = None,
*, *,
prefix: str = "", prefix: str = "",
) -> None: ) -> None:
super().__init__() super().__init__()
use_data_parallel = (
multimodal_config.mm_encoder_tp_mode == "data"
if multimodal_config
else False
)
assert config.intermediate_size is not None assert config.intermediate_size is not None
self.gate_up_proj = MergedColumnParallelLinear( self.gate_up_proj = MergedColumnParallelLinear(
input_size=config.hidden_size, input_size=config.hidden_size,
...@@ -1055,6 +1062,7 @@ class PixtralHFMLP(nn.Module): ...@@ -1055,6 +1062,7 @@ class PixtralHFMLP(nn.Module):
bias=False, bias=False,
quant_config=quant_config, quant_config=quant_config,
prefix=f"{prefix}.gate_up_proj", prefix=f"{prefix}.gate_up_proj",
disable_tp=use_data_parallel,
) )
self.down_proj = RowParallelLinear( self.down_proj = RowParallelLinear(
input_size=config.intermediate_size, input_size=config.intermediate_size,
...@@ -1062,6 +1070,7 @@ class PixtralHFMLP(nn.Module): ...@@ -1062,6 +1070,7 @@ class PixtralHFMLP(nn.Module):
bias=False, bias=False,
quant_config=quant_config, quant_config=quant_config,
prefix=f"{prefix}.down_proj", prefix=f"{prefix}.down_proj",
disable_tp=use_data_parallel,
) )
self.act_and_mul = get_act_and_mul_fn(config.hidden_act) self.act_and_mul = get_act_and_mul_fn(config.hidden_act)
...@@ -1077,6 +1086,7 @@ class PixtralHFAttention(nn.Module): ...@@ -1077,6 +1086,7 @@ class PixtralHFAttention(nn.Module):
self, self,
config: PixtralVisionConfig, config: PixtralVisionConfig,
quant_config: QuantizationConfig | None = None, quant_config: QuantizationConfig | None = None,
multimodal_config: MultiModalConfig | None = None,
*, *,
prefix: str = "", prefix: str = "",
) -> None: ) -> None:
...@@ -1085,10 +1095,14 @@ class PixtralHFAttention(nn.Module): ...@@ -1085,10 +1095,14 @@ class PixtralHFAttention(nn.Module):
self.config = config self.config = config
assert not config.hidden_size % config.num_attention_heads assert not config.hidden_size % config.num_attention_heads
self.total_num_heads = config.num_attention_heads self.total_num_heads = config.num_attention_heads
tp_size = get_tensor_model_parallel_world_size()
self.n_heads = divide(config.num_attention_heads, tp_size)
self.head_dim = config.hidden_size // config.num_attention_heads self.head_dim = config.hidden_size // config.num_attention_heads
assert self.total_num_heads * self.head_dim == config.hidden_size
use_data_parallel = (
multimodal_config.mm_encoder_tp_mode == "data"
if multimodal_config
else False
)
self.qkv_proj = QKVParallelLinear( self.qkv_proj = QKVParallelLinear(
hidden_size=config.hidden_size, hidden_size=config.hidden_size,
head_size=self.head_dim, head_size=self.head_dim,
...@@ -1096,15 +1110,21 @@ class PixtralHFAttention(nn.Module): ...@@ -1096,15 +1110,21 @@ class PixtralHFAttention(nn.Module):
bias=False, bias=False,
quant_config=quant_config, quant_config=quant_config,
prefix=f"{prefix}.qkv_proj", prefix=f"{prefix}.qkv_proj",
disable_tp=use_data_parallel,
) )
assert self.total_num_heads * self.head_dim == config.hidden_size
self.o_proj = RowParallelLinear( self.o_proj = RowParallelLinear(
input_size=config.hidden_size, input_size=config.hidden_size,
output_size=config.hidden_size, output_size=config.hidden_size,
bias=False, bias=False,
quant_config=quant_config, quant_config=quant_config,
prefix=f"{prefix}.o_proj", prefix=f"{prefix}.o_proj",
disable_tp=use_data_parallel,
)
self.tp_size = (
1 if use_data_parallel else get_tensor_model_parallel_world_size()
) )
self.n_heads = divide(config.num_attention_heads, self.tp_size)
def forward( def forward(
self, self,
...@@ -1147,6 +1167,7 @@ class PixtralHFTransformerBlock(nn.Module): ...@@ -1147,6 +1167,7 @@ class PixtralHFTransformerBlock(nn.Module):
self, self,
config: PixtralVisionConfig, config: PixtralVisionConfig,
quant_config: QuantizationConfig | None = None, quant_config: QuantizationConfig | None = None,
multimodal_config: MultiModalConfig | None = None,
*, *,
prefix: str = "", prefix: str = "",
) -> None: ) -> None:
...@@ -1154,10 +1175,16 @@ class PixtralHFTransformerBlock(nn.Module): ...@@ -1154,10 +1175,16 @@ class PixtralHFTransformerBlock(nn.Module):
self.attention_norm = RMSNorm(config.hidden_size, eps=1e-5) self.attention_norm = RMSNorm(config.hidden_size, eps=1e-5)
self.attention = PixtralHFAttention( self.attention = PixtralHFAttention(
config, quant_config=quant_config, prefix=f"{prefix}.attention" config,
quant_config=quant_config,
multimodal_config=multimodal_config,
prefix=f"{prefix}.attention",
) )
self.feed_forward = PixtralHFMLP( self.feed_forward = PixtralHFMLP(
config, quant_config=quant_config, prefix=f"{prefix}.feed_forward" config,
quant_config=quant_config,
multimodal_config=multimodal_config,
prefix=f"{prefix}.feed_forward",
) )
self.ffn_norm = RMSNorm(config.hidden_size, eps=1e-5) self.ffn_norm = RMSNorm(config.hidden_size, eps=1e-5)
...@@ -1183,6 +1210,7 @@ class PixtralHFTransformer(nn.Module): ...@@ -1183,6 +1210,7 @@ class PixtralHFTransformer(nn.Module):
self, self,
config: PixtralVisionConfig, config: PixtralVisionConfig,
quant_config: QuantizationConfig | None = None, quant_config: QuantizationConfig | None = None,
multimodal_config: MultiModalConfig | None = None,
*, *,
num_hidden_layers_override: int | None = None, num_hidden_layers_override: int | None = None,
prefix: str = "", prefix: str = "",
...@@ -1199,6 +1227,7 @@ class PixtralHFTransformer(nn.Module): ...@@ -1199,6 +1227,7 @@ class PixtralHFTransformer(nn.Module):
PixtralHFTransformerBlock( PixtralHFTransformerBlock(
config=config, config=config,
quant_config=quant_config, quant_config=quant_config,
multimodal_config=multimodal_config,
prefix=f"{prefix}.layers.{layer_idx}", prefix=f"{prefix}.layers.{layer_idx}",
) )
for layer_idx in range(num_hidden_layers) for layer_idx in range(num_hidden_layers)
...@@ -1230,6 +1259,7 @@ class PixtralHFVisionModel(nn.Module): ...@@ -1230,6 +1259,7 @@ class PixtralHFVisionModel(nn.Module):
self, self,
config: PixtralVisionConfig, config: PixtralVisionConfig,
quant_config: QuantizationConfig | None = None, quant_config: QuantizationConfig | None = None,
multimodal_config: MultiModalConfig | None = None,
*, *,
num_hidden_layers_override: int | None = None, num_hidden_layers_override: int | None = None,
require_post_norm: bool | None = None, require_post_norm: bool | None = None,
...@@ -1249,7 +1279,8 @@ class PixtralHFVisionModel(nn.Module): ...@@ -1249,7 +1279,8 @@ class PixtralHFVisionModel(nn.Module):
self.ln_pre = RMSNorm(config.hidden_size, eps=1e-5) self.ln_pre = RMSNorm(config.hidden_size, eps=1e-5)
self.transformer = PixtralHFTransformer( self.transformer = PixtralHFTransformer(
config, config,
quant_config, quant_config=quant_config,
multimodal_config=multimodal_config,
num_hidden_layers_override=num_hidden_layers_override, num_hidden_layers_override=num_hidden_layers_override,
prefix=f"{prefix}.transformer", prefix=f"{prefix}.transformer",
) )
......
# SPDX-License-Identifier: Apache-2.0 # SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project # SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import math
from collections.abc import Callable, Iterable, Mapping from collections.abc import Callable, Iterable, Mapping
from functools import cached_property from functools import cached_property
from typing import Annotated, Literal from typing import Annotated, Literal
...@@ -19,7 +18,7 @@ from transformers import ( ...@@ -19,7 +18,7 @@ from transformers import (
from vllm.attention.layers.encoder_only_attention import EncoderOnlyAttention from vllm.attention.layers.encoder_only_attention import EncoderOnlyAttention
from vllm.attention.layers.mm_encoder_attention import MMEncoderAttention from vllm.attention.layers.mm_encoder_attention import MMEncoderAttention
from vllm.config import VllmConfig from vllm.config import VllmConfig
from vllm.config.multimodal import BaseDummyOptions from vllm.config.multimodal import BaseDummyOptions, MultiModalConfig
from vllm.distributed import divide, get_tensor_model_parallel_world_size from vllm.distributed import divide, get_tensor_model_parallel_world_size
from vllm.model_executor.layers.activation import get_act_fn from vllm.model_executor.layers.activation import get_act_fn
from vllm.model_executor.layers.conv import Conv2dLayer from vllm.model_executor.layers.conv import Conv2dLayer
...@@ -276,7 +275,7 @@ class SiglipEncoderInfo(VisionEncoderInfo[SiglipVisionConfig]): ...@@ -276,7 +275,7 @@ class SiglipEncoderInfo(VisionEncoderInfo[SiglipVisionConfig]):
return image_size // patch_size return image_size // patch_size
# Adapted from https://github.com/huggingface/transformers/blob/v4.43.3/src/transformers/models/siglip/modeling_siglip.py#L249 # noqa # Adapted from https://github.com/huggingface/transformers/blob/v4.57.3/src/transformers/models/siglip/modeling_siglip.py#L216
class SiglipVisionEmbeddings(nn.Module): class SiglipVisionEmbeddings(nn.Module):
def __init__(self, config: SiglipVisionConfig): def __init__(self, config: SiglipVisionConfig):
super().__init__() super().__init__()
...@@ -295,9 +294,7 @@ class SiglipVisionEmbeddings(nn.Module): ...@@ -295,9 +294,7 @@ class SiglipVisionEmbeddings(nn.Module):
self.num_patches = (self.image_size // self.patch_size) ** 2 self.num_patches = (self.image_size // self.patch_size) ** 2
self.num_positions = self.num_patches self.num_positions = self.num_patches
self.position_embedding = VocabParallelEmbedding( self.position_embedding = nn.Embedding(self.num_positions, self.embed_dim)
self.num_positions, self.embed_dim
)
self.register_buffer( self.register_buffer(
"position_ids", "position_ids",
torch.arange(self.num_positions, dtype=torch.int64).expand((1, -1)), torch.arange(self.num_positions, dtype=torch.int64).expand((1, -1)),
...@@ -307,50 +304,30 @@ class SiglipVisionEmbeddings(nn.Module): ...@@ -307,50 +304,30 @@ class SiglipVisionEmbeddings(nn.Module):
def interpolate_pos_encoding( def interpolate_pos_encoding(
self, embeddings: torch.Tensor, height: int, width: int self, embeddings: torch.Tensor, height: int, width: int
) -> torch.Tensor: ) -> torch.Tensor:
"""
This method is an adapted method for SigLIP (due to SigLIP not having
class embedding unlike other ViTs) that allows the model to interpolate
the pre-trained position encodings such that it can be usable on higher
resolution images.
Source:
https://github.com/facebookresearch/dino/blob/de9ee3df6cf39fac952ab558447af1fa1365362a/vision_transformer.py#L174
"""
position_embeddings = self.position_embedding.weight.unsqueeze(0)
num_patches = embeddings.shape[1] num_patches = embeddings.shape[1]
num_positions = position_embeddings.shape[1] num_positions = self.position_embedding.weight.shape[1]
if num_patches == num_positions and height == width: if num_patches == num_positions and height == width:
return position_embeddings return self.position_embedding(self.position_ids)
patch_pos_embed = self.position_embedding.weight.unsqueeze(0)
dim = embeddings.shape[-1] dim = embeddings.shape[-1]
height = height // self.patch_size
width = width // self.patch_size new_height = height // self.patch_size
# we add a small number to avoid floating point error new_width = width // self.patch_size
# in the interpolation
# see discussion at https://github.com/facebookresearch/dino/issues/8 sqrt_num_positions = int(num_positions**0.5)
height, width = height + 0.1, width + 0.1 patch_pos_embed = patch_pos_embed.reshape(
1, sqrt_num_positions, sqrt_num_positions, dim
patch_pos_embed = position_embeddings.reshape(
1, int(math.sqrt(num_positions)), int(math.sqrt(num_positions)), dim
) )
patch_pos_embed = patch_pos_embed.permute(0, 3, 1, 2) patch_pos_embed = patch_pos_embed.permute(0, 3, 1, 2)
patch_pos_embed = nn.functional.interpolate( patch_pos_embed = nn.functional.interpolate(
patch_pos_embed, patch_pos_embed,
scale_factor=( size=(new_height, new_width),
height / math.sqrt(num_positions),
width / math.sqrt(num_positions),
),
mode="bicubic", mode="bicubic",
align_corners=False, align_corners=False,
) )
if (
int(height) != patch_pos_embed.shape[-2]
or int(width) != patch_pos_embed.shape[-1]
):
raise ValueError(
"Width or height does not match with "
"the interpolated position embeddings"
)
patch_pos_embed = patch_pos_embed.permute(0, 2, 3, 1).view(1, -1, dim) patch_pos_embed = patch_pos_embed.permute(0, 2, 3, 1).view(1, -1, dim)
return patch_pos_embed return patch_pos_embed
...@@ -377,6 +354,7 @@ class SiglipAttention(nn.Module): ...@@ -377,6 +354,7 @@ class SiglipAttention(nn.Module):
self, self,
config: SiglipVisionConfig | SiglipTextConfig, config: SiglipVisionConfig | SiglipTextConfig,
quant_config: QuantizationConfig | None = None, quant_config: QuantizationConfig | None = None,
multimodal_config: MultiModalConfig | None = None,
*, *,
prefix: str = "", prefix: str = "",
attn_cls: type[EncoderOnlyAttention] | type[MMEncoderAttention], attn_cls: type[EncoderOnlyAttention] | type[MMEncoderAttention],
...@@ -389,19 +367,25 @@ class SiglipAttention(nn.Module): ...@@ -389,19 +367,25 @@ class SiglipAttention(nn.Module):
self.head_dim = self.embed_dim // self.num_heads self.head_dim = self.embed_dim // self.num_heads
if self.head_dim * self.num_heads != self.embed_dim: if self.head_dim * self.num_heads != self.embed_dim:
raise ValueError( raise ValueError(
f"embed_dim must be divisible by num_heads (got " f"embed_dim must be divisible by num_heads "
"`embed_dim`: {self.embed_dim} and `num_heads`:" f"(got `embed_dim`: {self.embed_dim} and "
f" {self.num_heads})." f"`num_heads`: {self.num_heads})."
) )
self.scale = self.head_dim**-0.5 self.scale = self.head_dim**-0.5
self.dropout = config.attention_dropout
use_data_parallel = (
multimodal_config.mm_encoder_tp_mode == "data"
if multimodal_config
else False
)
self.qkv_proj = QKVParallelLinear( self.qkv_proj = QKVParallelLinear(
hidden_size=self.embed_dim, hidden_size=self.embed_dim,
head_size=self.head_dim, head_size=self.head_dim,
total_num_heads=self.num_heads, total_num_heads=self.num_heads,
quant_config=quant_config, quant_config=quant_config,
prefix=f"{prefix}.qkv_proj", prefix=f"{prefix}.qkv_proj",
disable_tp=use_data_parallel,
) )
self.out_proj = RowParallelLinear( self.out_proj = RowParallelLinear(
...@@ -409,17 +393,29 @@ class SiglipAttention(nn.Module): ...@@ -409,17 +393,29 @@ class SiglipAttention(nn.Module):
output_size=self.embed_dim, output_size=self.embed_dim,
quant_config=quant_config, quant_config=quant_config,
prefix=f"{prefix}.out_proj", prefix=f"{prefix}.out_proj",
disable_tp=use_data_parallel,
) )
self.tp_size = get_tensor_model_parallel_world_size() self.tp_size = (
1 if use_data_parallel else get_tensor_model_parallel_world_size()
)
self.num_heads_per_partition = divide(self.num_heads, self.tp_size) self.num_heads_per_partition = divide(self.num_heads, self.tp_size)
self.attn = attn_cls( if attn_cls == MMEncoderAttention:
self.num_heads_per_partition, self.attn = attn_cls(
self.head_dim, self.num_heads_per_partition,
self.scale, self.head_dim,
prefix=f"{prefix}.attn", self.scale,
) prefix=f"{prefix}.attn",
multimodal_config=multimodal_config,
)
else:
self.attn = attn_cls(
self.num_heads_per_partition,
self.head_dim,
self.scale,
prefix=f"{prefix}.attn",
)
def forward( def forward(
self, self,
...@@ -439,12 +435,19 @@ class SiglipMLP(nn.Module): ...@@ -439,12 +435,19 @@ class SiglipMLP(nn.Module):
self, self,
config: SiglipVisionConfig | SiglipTextConfig, config: SiglipVisionConfig | SiglipTextConfig,
quant_config: QuantizationConfig | None = None, quant_config: QuantizationConfig | None = None,
multimodal_config: MultiModalConfig | None = None,
prefix: str = "", prefix: str = "",
) -> None: ) -> None:
super().__init__() super().__init__()
self.config = config self.config = config
use_data_parallel = (
multimodal_config.mm_encoder_tp_mode == "data"
if multimodal_config
else False
)
self.activation_fn = get_act_fn(config.hidden_act) self.activation_fn = get_act_fn(config.hidden_act)
# Special handling for BNB and torchao quantization # Special handling for BNB and torchao quantization
if quant_config and quant_config.get_name() in ["bitsandbytes", "torchao"]: if quant_config and quant_config.get_name() in ["bitsandbytes", "torchao"]:
quantizable = True quantizable = True
...@@ -454,17 +457,20 @@ class SiglipMLP(nn.Module): ...@@ -454,17 +457,20 @@ class SiglipMLP(nn.Module):
quantizable = ( quantizable = (
config.hidden_size % 64 == 0 and config.intermediate_size % 64 == 0 config.hidden_size % 64 == 0 and config.intermediate_size % 64 == 0
) )
self.fc1 = ColumnParallelLinear( self.fc1 = ColumnParallelLinear(
config.hidden_size, config.hidden_size,
config.intermediate_size, config.intermediate_size,
quant_config=quant_config if quantizable else None, quant_config=quant_config if quantizable else None,
prefix=f"{prefix}.fc1", prefix=f"{prefix}.fc1",
disable_tp=use_data_parallel,
) )
self.fc2 = RowParallelLinear( self.fc2 = RowParallelLinear(
config.intermediate_size, config.intermediate_size,
config.hidden_size, config.hidden_size,
quant_config=quant_config if quantizable else None, quant_config=quant_config if quantizable else None,
prefix=f"{prefix}.fc2", prefix=f"{prefix}.fc2",
disable_tp=use_data_parallel,
) )
def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
...@@ -479,6 +485,7 @@ class SiglipEncoderLayer(nn.Module): ...@@ -479,6 +485,7 @@ class SiglipEncoderLayer(nn.Module):
self, self,
config: SiglipVisionConfig | SiglipTextConfig, config: SiglipVisionConfig | SiglipTextConfig,
quant_config: QuantizationConfig | None = None, quant_config: QuantizationConfig | None = None,
multimodal_config: MultiModalConfig | None = None,
*, *,
prefix: str = "", prefix: str = "",
attn_cls: type[EncoderOnlyAttention] | type[MMEncoderAttention], attn_cls: type[EncoderOnlyAttention] | type[MMEncoderAttention],
...@@ -490,6 +497,7 @@ class SiglipEncoderLayer(nn.Module): ...@@ -490,6 +497,7 @@ class SiglipEncoderLayer(nn.Module):
self.self_attn = SiglipAttention( self.self_attn = SiglipAttention(
config, config,
quant_config=quant_config, quant_config=quant_config,
multimodal_config=multimodal_config,
prefix=f"{prefix}.self_attn", prefix=f"{prefix}.self_attn",
attn_cls=attn_cls, attn_cls=attn_cls,
) )
...@@ -497,6 +505,7 @@ class SiglipEncoderLayer(nn.Module): ...@@ -497,6 +505,7 @@ class SiglipEncoderLayer(nn.Module):
self.mlp = SiglipMLP( self.mlp = SiglipMLP(
config, config,
quant_config=quant_config, quant_config=quant_config,
multimodal_config=multimodal_config,
prefix=f"{prefix}.mlp", prefix=f"{prefix}.mlp",
) )
self.layer_norm2 = nn.LayerNorm(self.embed_dim, eps=config.layer_norm_eps) self.layer_norm2 = nn.LayerNorm(self.embed_dim, eps=config.layer_norm_eps)
...@@ -524,6 +533,7 @@ class SiglipEncoder(nn.Module): ...@@ -524,6 +533,7 @@ class SiglipEncoder(nn.Module):
self, self,
config: SiglipVisionConfig | SiglipTextConfig, config: SiglipVisionConfig | SiglipTextConfig,
quant_config: QuantizationConfig | None = None, quant_config: QuantizationConfig | None = None,
multimodal_config: MultiModalConfig | None = None,
num_hidden_layers_override: int | None = None, num_hidden_layers_override: int | None = None,
*, *,
prefix: str = "", prefix: str = "",
...@@ -543,6 +553,7 @@ class SiglipEncoder(nn.Module): ...@@ -543,6 +553,7 @@ class SiglipEncoder(nn.Module):
SiglipEncoderLayer( SiglipEncoderLayer(
config, config,
quant_config=quant_config, quant_config=quant_config,
multimodal_config=multimodal_config,
prefix=f"{prefix}.layers.{layer_idx}", prefix=f"{prefix}.layers.{layer_idx}",
attn_cls=attn_cls, attn_cls=attn_cls,
) )
...@@ -647,6 +658,7 @@ class SiglipMultiheadAttentionPoolingHead(nn.Module): ...@@ -647,6 +658,7 @@ class SiglipMultiheadAttentionPoolingHead(nn.Module):
self, self,
config: SiglipVisionConfig, config: SiglipVisionConfig,
quant_config: QuantizationConfig | None = None, quant_config: QuantizationConfig | None = None,
multimodal_config: MultiModalConfig | None = None,
prefix: str = "", prefix: str = "",
) -> None: ) -> None:
super().__init__() super().__init__()
...@@ -658,7 +670,10 @@ class SiglipMultiheadAttentionPoolingHead(nn.Module): ...@@ -658,7 +670,10 @@ class SiglipMultiheadAttentionPoolingHead(nn.Module):
) )
self.layernorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps) self.layernorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
self.mlp = SiglipMLP( self.mlp = SiglipMLP(
config=config, quant_config=quant_config, prefix=f"{prefix}.mlp" config=config,
quant_config=quant_config,
multimodal_config=multimodal_config,
prefix=f"{prefix}.mlp",
) )
def forward(self, hidden_state: torch.Tensor) -> torch.Tensor: def forward(self, hidden_state: torch.Tensor) -> torch.Tensor:
...@@ -683,6 +698,7 @@ class SiglipVisionTransformer(nn.Module): ...@@ -683,6 +698,7 @@ class SiglipVisionTransformer(nn.Module):
self, self,
config: SiglipVisionConfig, config: SiglipVisionConfig,
quant_config: QuantizationConfig | None = None, quant_config: QuantizationConfig | None = None,
multimodal_config: MultiModalConfig | None = None,
*, *,
num_hidden_layers_override: int | None = None, num_hidden_layers_override: int | None = None,
require_post_norm: bool | None = None, require_post_norm: bool | None = None,
...@@ -698,6 +714,7 @@ class SiglipVisionTransformer(nn.Module): ...@@ -698,6 +714,7 @@ class SiglipVisionTransformer(nn.Module):
self.encoder = SiglipEncoder( self.encoder = SiglipEncoder(
config, config,
quant_config=quant_config, quant_config=quant_config,
multimodal_config=multimodal_config,
num_hidden_layers_override=num_hidden_layers_override, num_hidden_layers_override=num_hidden_layers_override,
prefix=f"{prefix}.encoder", prefix=f"{prefix}.encoder",
attn_cls=MMEncoderAttention, attn_cls=MMEncoderAttention,
...@@ -726,6 +743,7 @@ class SiglipVisionTransformer(nn.Module): ...@@ -726,6 +743,7 @@ class SiglipVisionTransformer(nn.Module):
self.head = SiglipMultiheadAttentionPoolingHead( self.head = SiglipMultiheadAttentionPoolingHead(
config=config, config=config,
quant_config=quant_config, quant_config=quant_config,
multimodal_config=multimodal_config,
prefix=f"{prefix}.head", prefix=f"{prefix}.head",
) )
...@@ -812,13 +830,11 @@ class SiglipVisionTransformer(nn.Module): ...@@ -812,13 +830,11 @@ class SiglipVisionTransformer(nn.Module):
class SiglipVisionModel(nn.Module): class SiglipVisionModel(nn.Module):
config_class = SiglipVisionConfig
main_input_name = "pixel_values"
def __init__( def __init__(
self, self,
config: SiglipVisionConfig, config: SiglipVisionConfig,
quant_config: QuantizationConfig | None = None, quant_config: QuantizationConfig | None = None,
multimodal_config: MultiModalConfig | None = None,
*, *,
num_hidden_layers_override: int | None = None, num_hidden_layers_override: int | None = None,
require_post_norm: bool | None = None, require_post_norm: bool | None = None,
...@@ -829,7 +845,8 @@ class SiglipVisionModel(nn.Module): ...@@ -829,7 +845,8 @@ class SiglipVisionModel(nn.Module):
self.quant_config = quant_config self.quant_config = quant_config
self.vision_model = SiglipVisionTransformer( self.vision_model = SiglipVisionTransformer(
config, config,
quant_config, quant_config=quant_config,
multimodal_config=multimodal_config,
num_hidden_layers_override=num_hidden_layers_override, num_hidden_layers_override=num_hidden_layers_override,
require_post_norm=require_post_norm, require_post_norm=require_post_norm,
prefix=f"{prefix}.vision_model", prefix=f"{prefix}.vision_model",
...@@ -1023,6 +1040,7 @@ class SiglipEmbeddingModel(nn.Module, SupportsMultiModal, SupportsQuant): ...@@ -1023,6 +1040,7 @@ class SiglipEmbeddingModel(nn.Module, SupportsMultiModal, SupportsQuant):
self.vision_model = SiglipVisionTransformer( self.vision_model = SiglipVisionTransformer(
vision_config, vision_config,
quant_config=quant_config, quant_config=quant_config,
multimodal_config=multimodal_config,
prefix=maybe_prefix(prefix, "vision_model"), prefix=maybe_prefix(prefix, "vision_model"),
) )
......
...@@ -11,7 +11,6 @@ from torch.nn import functional as F ...@@ -11,7 +11,6 @@ from torch.nn import functional as F
from transformers import Siglip2VisionConfig from transformers import Siglip2VisionConfig
from transformers.configuration_utils import PretrainedConfig from transformers.configuration_utils import PretrainedConfig
from vllm.attention.backends.registry import AttentionBackendEnum
from vllm.attention.layers.mm_encoder_attention import MMEncoderAttention from vllm.attention.layers.mm_encoder_attention import MMEncoderAttention
from vllm.config import MultiModalConfig from vllm.config import MultiModalConfig
from vllm.distributed import divide, get_tensor_model_parallel_world_size from vllm.distributed import divide, get_tensor_model_parallel_world_size
...@@ -186,7 +185,6 @@ class Siglip2Attention(nn.Module): ...@@ -186,7 +185,6 @@ class Siglip2Attention(nn.Module):
multimodal_config: MultiModalConfig | None = None, multimodal_config: MultiModalConfig | None = None,
prefix: str = "", prefix: str = "",
use_data_parallel: bool = False, use_data_parallel: bool = False,
attn_backend_override: AttentionBackendEnum | None = None,
): ):
super().__init__() super().__init__()
self.config = config self.config = config
...@@ -196,12 +194,11 @@ class Siglip2Attention(nn.Module): ...@@ -196,12 +194,11 @@ class Siglip2Attention(nn.Module):
if self.head_dim * self.num_heads != self.embed_dim: if self.head_dim * self.num_heads != self.embed_dim:
raise ValueError( raise ValueError(
f"embed_dim must be divisible by num_heads " f"embed_dim must be divisible by num_heads "
f"(got `embed_dim`: {self.embed_dim} and `num_heads`:" f"(got `embed_dim`: {self.embed_dim} and "
f" {self.num_heads})." f"`num_heads`: {self.num_heads})."
) )
self.scale = self.head_dim**-0.5 self.scale = self.head_dim**-0.5
self.dropout = config.attention_dropout self.dropout = config.attention_dropout
self.is_causal = False
use_data_parallel = ( use_data_parallel = (
multimodal_config.mm_encoder_tp_mode == "data" multimodal_config.mm_encoder_tp_mode == "data"
...@@ -233,6 +230,7 @@ class Siglip2Attention(nn.Module): ...@@ -233,6 +230,7 @@ class Siglip2Attention(nn.Module):
self.attn = MMEncoderAttention( self.attn = MMEncoderAttention(
num_heads=self.num_heads_per_partition, num_heads=self.num_heads_per_partition,
head_size=self.head_dim, head_size=self.head_dim,
scale=self.scale,
prefix=f"{prefix}.attn", prefix=f"{prefix}.attn",
multimodal_config=multimodal_config, multimodal_config=multimodal_config,
) )
......
...@@ -19,7 +19,7 @@ from transformers.models.llava import LlavaProcessor ...@@ -19,7 +19,7 @@ from transformers.models.llava import LlavaProcessor
from transformers.processing_utils import ProcessingKwargs, Unpack from transformers.processing_utils import ProcessingKwargs, Unpack
from transformers.tokenization_utils_base import PreTokenizedInput, TextInput from transformers.tokenization_utils_base import PreTokenizedInput, TextInput
from vllm.config import VllmConfig from vllm.config import MultiModalConfig, VllmConfig
from vllm.model_executor.layers.activation import get_act_fn from vllm.model_executor.layers.activation import get_act_fn
from vllm.model_executor.layers.linear import ColumnParallelLinear, RowParallelLinear from vllm.model_executor.layers.linear import ColumnParallelLinear, RowParallelLinear
from vllm.model_executor.layers.quantization import QuantizationConfig from vllm.model_executor.layers.quantization import QuantizationConfig
...@@ -346,6 +346,7 @@ def _build_tarsier_hf_processor( ...@@ -346,6 +346,7 @@ def _build_tarsier_hf_processor(
def init_vision_tower_for_tarsier( def init_vision_tower_for_tarsier(
hf_config: TarsierHfConfig, # Use the Tarsier specific config protocol hf_config: TarsierHfConfig, # Use the Tarsier specific config protocol
quant_config: QuantizationConfig | None, quant_config: QuantizationConfig | None,
multimodal_config: MultiModalConfig | None,
*, *,
require_post_norm: bool | None = None, require_post_norm: bool | None = None,
prefix: str = "", prefix: str = "",
...@@ -377,6 +378,7 @@ def init_vision_tower_for_tarsier( ...@@ -377,6 +378,7 @@ def init_vision_tower_for_tarsier(
return CLIPVisionModel( return CLIPVisionModel(
vision_config, vision_config,
quant_config=quant_config, quant_config=quant_config,
multimodal_config=multimodal_config,
num_hidden_layers_override=num_hidden_layers_to_init, num_hidden_layers_override=num_hidden_layers_to_init,
require_post_norm=require_post_norm, require_post_norm=require_post_norm,
prefix=prefix, prefix=prefix,
...@@ -385,6 +387,7 @@ def init_vision_tower_for_tarsier( ...@@ -385,6 +387,7 @@ def init_vision_tower_for_tarsier(
return SiglipVisionModel( return SiglipVisionModel(
vision_config, vision_config,
quant_config=quant_config, quant_config=quant_config,
multimodal_config=multimodal_config,
num_hidden_layers_override=num_hidden_layers_to_init, num_hidden_layers_override=num_hidden_layers_to_init,
require_post_norm=require_post_norm, require_post_norm=require_post_norm,
prefix=prefix, prefix=prefix,
...@@ -414,12 +417,16 @@ class TarsierForConditionalGeneration(nn.Module, SupportsMultiModal, SupportsPP) ...@@ -414,12 +417,16 @@ class TarsierForConditionalGeneration(nn.Module, SupportsMultiModal, SupportsPP)
def __init__(self, *, vllm_config: VllmConfig, prefix: str = "") -> None: def __init__(self, *, vllm_config: VllmConfig, prefix: str = "") -> None:
super().__init__() super().__init__()
config: TarsierHfConfig = vllm_config.model_config.hf_config config: TarsierHfConfig = vllm_config.model_config.hf_config
quant_config = vllm_config.quant_config quant_config = vllm_config.quant_config
multimodal_config = vllm_config.model_config.multimodal_config
self.config = config # Storing the Tarsier-specific HF config self.config = config # Storing the Tarsier-specific HF config
self.vision_tower = init_vision_tower_for_tarsier( self.vision_tower = init_vision_tower_for_tarsier(
config, config,
quant_config, quant_config=quant_config,
multimodal_config=multimodal_config,
require_post_norm=False, require_post_norm=False,
prefix=maybe_prefix(prefix, "vision_tower"), prefix=maybe_prefix(prefix, "vision_tower"),
) )
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment