Unverified Commit 4024e1d2 authored by Jiajun Li, committed by GitHub

Implement Siglip Vision model, and support BNB quantization for gemma3-mm (#5339)

parent 5c0b38f3
@@ -168,7 +168,7 @@ class CLIPEncoderLayer(nn.Module):
softmax_in_single_precision=softmax_in_single_precision,
flatten_batch=True,
quant_config=quant_config,
prefix=add_prefix("attn", prefix),
prefix=add_prefix("self_attn", prefix),
)
self.mlp = CLIPMLP(
config,
@@ -395,6 +395,10 @@ class CLIPVisionModel(nn.Module):
config, quant_config, prefix=add_prefix("vision_model", prefix)
)
@property
def device(self) -> torch.device:
return self.vision_model.device
def forward(self, pixel_values: torch.Tensor):
return self.vision_model(pixel_values)
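The `device` property added above lets callers move pixel inputs onto the vision tower's device before running the forward pass, which is how the gemma3 code further down uses it (`pixel_values.to(device=self.vision_tower.device)`). A minimal standalone sketch of that pattern, using a made-up TinyVisionModel rather than the sglang classes:

import torch
import torch.nn as nn

class TinyVisionModel(nn.Module):
    # Illustrative stand-in for CLIPVisionModel / SiglipVisionModel.
    def __init__(self):
        super().__init__()
        self.proj = nn.Linear(16, 16)

    @property
    def device(self) -> torch.device:
        # Report the device of a representative parameter, as the diff does
        # via an encoder-layer weight.
        return self.proj.weight.device

    def forward(self, pixel_values: torch.Tensor) -> torch.Tensor:
        return self.proj(pixel_values)

model = TinyVisionModel()
pixel_values = torch.randn(2, 16)
# Caller-side pattern from the gemma3 diff: move inputs to the tower's device first.
out = model(pixel_values.to(device=model.device))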
@@ -21,7 +21,7 @@ from typing import Dict, Iterable, List, Optional, Set, Tuple, TypedDict
import torch
from torch import nn
from transformers import AutoModel, Gemma3Config, PreTrainedModel
from transformers import Gemma3Config, PreTrainedModel
from sglang.srt.hf_transformers_utils import get_processor
from sglang.srt.layers.layernorm import Gemma3RMSNorm
@@ -42,6 +42,7 @@ from sglang.srt.model_loader.weight_utils import (
maybe_remap_kv_scale_name,
)
from sglang.srt.models.gemma3_causal import Gemma3ForCausalLM
from sglang.srt.models.siglip import SiglipVisionModel
from sglang.srt.utils import add_prefix
logger = logging.getLogger(__name__)
@@ -118,6 +119,7 @@ class Gemma3ForConditionalGeneration(PreTrainedModel):
".k_proj.",
".v_proj.",
".o_proj.",
".out_proj.",
]
bitsandbytes_stacked_params_mapping = {
# shard_name, weight_name, index
@@ -126,6 +128,7 @@ class Gemma3ForConditionalGeneration(PreTrainedModel):
"v_proj": ("qkv_proj", 2),
"gate_proj": ("gate_up_proj", 0),
"up_proj": ("gate_up_proj", 1),
"out_proj": ("proj", 0),
}
packed_modules_mapping = {
@@ -161,20 +164,21 @@ class Gemma3ForConditionalGeneration(PreTrainedModel):
super().__init__(config=config)
self.config = config
self.quant_config = quant_config
# Vision components
# TODO: replace with vision attention
# self.vision_tower = SiglipVisionModel(
# config.vision_config,
# quant_config,
# prefix=add_prefix("vision_tower", prefix),
# )
self.vision_tower = AutoModel.from_config(config=config.vision_config)
self.vision_tower = SiglipVisionModel(
config=config.vision_config,
quant_config=quant_config,
prefix=add_prefix("vision_tower", prefix),
)
self.multi_modal_projector = Gemma3MultiModalProjector(config)
self.vocab_size = config.text_config.vocab_size
# Text model
self.language_model = Gemma3ForCausalLM(
config.text_config, quant_config, prefix=add_prefix("model", prefix)
config.text_config,
quant_config,
prefix=add_prefix("language_model", prefix),
)
if self.language_model.logits_processor.logit_scale:
logit_scale = getattr(config, "logit_scale", 1.0)
@@ -290,7 +294,7 @@ class Gemma3ForConditionalGeneration(PreTrainedModel):
pixel_values = pixel_values.to(device=self.vision_tower.device)
pixel_values = pixel_values.to(dtype=self.language_model.dtype())
vision_outputs = self.vision_tower(pixel_values=pixel_values).last_hidden_state
vision_outputs = self.vision_tower(pixel_values=pixel_values)
image_features = self.multi_modal_projector(vision_outputs)
return image_features
@@ -366,6 +370,14 @@ class Gemma3ForConditionalGeneration(PreTrainedModel):
return self.language_model.tie_weights()
def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
stacked_params_mapping = [
# (param_name, shard_name, shard_id)
(".qkv_proj", ".q_proj", "q"),
(".qkv_proj", ".k_proj", "k"),
(".qkv_proj", ".v_proj", "v"),
("gate_up_proj", "up_proj", 1),
("gate_up_proj", "gate_proj", 0),
]
"""Load weights for the model."""
params_dict = dict(self.named_parameters())
loaded_params: Set[str] = set()
@@ -379,21 +391,33 @@ class Gemma3ForConditionalGeneration(PreTrainedModel):
loaded_params.update(causal_loaded_params)
continue
else:
# Skip lm_head.weight as it's tied with embed_tokens
if "lm_head.weight" in name:
continue
# Skip loading extra bias for GPTQ models
if name.endswith(".bias") and name not in params_dict:
continue
# Remapping the name of FP8 kv-scale
name = maybe_remap_kv_scale_name(name, params_dict)
if name is None:
continue
param = params_dict[name]
weight_loader = getattr(param, "weight_loader", default_weight_loader)
weight_loader(param, loaded_weight)
for param_name, weight_name, shard_id in stacked_params_mapping:
if weight_name not in name:
continue
name = name.replace(weight_name, param_name)
# Skip loading extra bias for GPTQ models.
if name.endswith(".bias") and name not in params_dict:
continue
param = params_dict[name]
weight_loader = param.weight_loader
weight_loader(param, loaded_weight, shard_id)
break
else:
if "vision_model" in name:
# adapt to VisionAttention
name = name.replace(".self_attn.out_proj", ".self_attn.proj")
# Skip loading extra bias for GPTQ models
if name.endswith(".bias") and name not in params_dict:
continue
# Remapping the name of FP8 kv-scale
name = maybe_remap_kv_scale_name(name, params_dict)
if name is None:
continue
param = params_dict[name]
weight_loader = getattr(
param, "weight_loader", default_weight_loader
)
weight_loader(param, loaded_weight)
loaded_params.add(name)
unloaded_params = params_dict.keys() - loaded_params
if unloaded_params:
@@ -404,5 +428,3 @@ class Gemma3ForConditionalGeneration(PreTrainedModel):
EntryClass = Gemma3ForConditionalGeneration
AutoModel.register(Gemma3Config, Gemma3ForConditionalGeneration, exist_ok=True)
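The load_weights changes above fold per-projection checkpoint names into sglang's fused parameters (q_proj/k_proj/v_proj into qkv_proj, gate_proj/up_proj into gate_up_proj) and rename the vision attention output projection to match VisionAttention. Below is a standalone sketch of just that name remapping, with illustrative checkpoint names; the full loader above additionally skips the tied lm_head weight and extra GPTQ biases, remaps FP8 kv-scales, and passes the shard id to the parameter's weight_loader.

# Name-remapping sketch mirroring the stacked_params_mapping and the
# ".self_attn.out_proj" -> ".self_attn.proj" rename from the diff above.
stacked_params_mapping = [
    # (param_name, shard_name, shard_id)
    (".qkv_proj", ".q_proj", "q"),
    (".qkv_proj", ".k_proj", "k"),
    (".qkv_proj", ".v_proj", "v"),
    ("gate_up_proj", "up_proj", 1),
    ("gate_up_proj", "gate_proj", 0),
]

def remap(name: str):
    """Return (target_param_name, shard_id) for a checkpoint weight name."""
    for param_name, weight_name, shard_id in stacked_params_mapping:
        if weight_name in name:
            return name.replace(weight_name, param_name), shard_id
    if "vision_model" in name:
        # VisionAttention names its output projection "proj", not "out_proj".
        name = name.replace(".self_attn.out_proj", ".self_attn.proj")
    return name, None

# Illustrative names, not taken from a real checkpoint dump:
print(remap("vision_tower.vision_model.encoder.layers.0.self_attn.k_proj.weight"))
# ('vision_tower.vision_model.encoder.layers.0.self_attn.qkv_proj.weight', 'k')
print(remap("vision_tower.vision_model.encoder.layers.0.self_attn.out_proj.weight"))
# ('vision_tower.vision_model.encoder.layers.0.self_attn.proj.weight', None)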
# Adapted from
# https://github.com/huggingface/transformers/blob/af9b2eaa54c150741f298d6db939af6328e1dc38/src/transformers/models/siglip/modeling_siglip.py
from functools import partial
from typing import Optional, Type, Union
import torch
import torch.nn as nn
from transformers import SiglipVisionConfig
from sglang.srt.layers.activation import QuickGELU
from sglang.srt.layers.attention.vision import VisionAttention
from sglang.srt.layers.linear import ColumnParallelLinear, RowParallelLinear
from sglang.srt.layers.quantization.base_config import QuantizationConfig
from sglang.srt.layers.vocab_parallel_embedding import VocabParallelEmbedding
from sglang.srt.utils import add_prefix
# Adapted from transformers.models.siglip.modeling_siglip.SiglipVisionTransformer
class SiglipVisionEmbeddings(nn.Module):
def __init__(self, config: SiglipVisionConfig):
super().__init__()
self.config = config
self.embed_dim = config.hidden_size
self.image_size = config.image_size
self.patch_size = config.patch_size
self.patch_embedding = nn.Conv2d(
in_channels=config.num_channels,
out_channels=self.embed_dim,
kernel_size=self.patch_size,
stride=self.patch_size,
padding="valid",
)
self.num_patches = (self.image_size // self.patch_size) ** 2
self.num_positions = self.num_patches
self.position_embedding = VocabParallelEmbedding(
self.num_positions, self.embed_dim
)
self.register_buffer(
"position_ids",
torch.arange(self.num_positions).expand((1, -1)),
persistent=False,
)
def forward(self, pixel_values: torch.Tensor) -> torch.Tensor:
target_dtype = self.patch_embedding.weight.dtype
patch_embeds = self.patch_embedding(
pixel_values.to(dtype=target_dtype)
) # shape = [*, width, grid, grid]
embeddings = patch_embeds.flatten(2).transpose(1, 2)
# interpolate_pos_encoding is never used in sglang
embeddings = embeddings + self.position_embedding(self.position_ids)
return embeddings
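A minimal shape check for the patch embedding above, using illustrative sizes (image_size=224, patch_size=14, hidden_size=768 are stand-ins, not gemma3's actual vision config): the Conv2d with kernel_size equal to stride cuts the image into non-overlapping patches, and flatten(2).transpose(1, 2) turns the [B, C, H/P, W/P] feature map into a [B, num_patches, hidden] sequence that the position embedding is added to.

import torch
import torch.nn as nn

image_size, patch_size, embed_dim = 224, 14, 768  # illustrative values
patch_embedding = nn.Conv2d(
    3, embed_dim, kernel_size=patch_size, stride=patch_size, padding="valid"
)

pixel_values = torch.randn(1, 3, image_size, image_size)
patch_embeds = patch_embedding(pixel_values)           # [1, 768, 16, 16]
embeddings = patch_embeds.flatten(2).transpose(1, 2)   # [1, 256, 768]

num_positions = (image_size // patch_size) ** 2        # one position id per patch
assert embeddings.shape == (1, num_positions, embed_dim)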
# Copied from sglang.srt.models.clip.CLIPMLP
class SiglipMLP(nn.Module):
def __init__(
self,
config,
act_layer: Type[nn.Module] = QuickGELU,
quant_config: Optional[QuantizationConfig] = None,
prefix: str = "",
):
super().__init__()
self.fc1 = ColumnParallelLinear(
config.hidden_size,
config.intermediate_size,
quant_config=quant_config,
prefix=add_prefix("fc1", prefix),
)
self.act = act_layer()
self.fc2 = RowParallelLinear(
config.intermediate_size,
config.hidden_size,
quant_config=quant_config,
prefix=add_prefix("fc2", prefix),
)
def forward(self, x: torch.Tensor) -> torch.Tensor:
x_parallel, _ = self.fc1(x)
x_parallel = self.act(x_parallel)
x, _ = self.fc2(x_parallel)
return x
# Copied from sglang.srt.models.clip.CLIPEncoderLayer
class SiglipEncoderLayer(nn.Module):
def __init__(
self,
config: SiglipVisionConfig,
act_layer: Type[nn.Module] = QuickGELU,
norm_layer: Type[nn.Module] = None,
attn_implementation: Optional[str] = "sdpa",
quant_config: Optional[QuantizationConfig] = None,
prefix: str = "",
) -> None:
super().__init__()
if norm_layer is None:
norm_layer = partial(nn.LayerNorm, eps=config.layer_norm_eps)
self.layer_norm1 = norm_layer(config.hidden_size)
self.layer_norm2 = norm_layer(config.hidden_size)
if attn_implementation == "sdpa":
qkv_backend = "sdpa"
softmax_in_single_precision = False
elif attn_implementation == "flash_attention_2":
qkv_backend = "triton_attn"
softmax_in_single_precision = False
elif attn_implementation == "eager":
qkv_backend = "sdpa"
softmax_in_single_precision = True
self.self_attn = VisionAttention(
embed_dim=config.hidden_size,
num_heads=config.num_attention_heads,
projection_size=config.hidden_size,
use_qkv_parallel=True,
qkv_backend=qkv_backend,
softmax_in_single_precision=softmax_in_single_precision,
flatten_batch=True,
quant_config=quant_config,
prefix=add_prefix("self_attn", prefix),
)
self.mlp = SiglipMLP(
config,
act_layer=act_layer,
quant_config=quant_config,
prefix=add_prefix("mlp", prefix),
)
def forward(
self,
hidden_states: torch.Tensor,
attention_mask: torch.Tensor,
causal_attention_mask: torch.Tensor,
) -> torch.Tensor:
residual = hidden_states
hidden_states = self.layer_norm1(hidden_states)
# Siglip text model uses both `causal_attention_mask` and `attention_mask`
if attention_mask is not None and causal_attention_mask is not None:
attn_mask = attention_mask + causal_attention_mask
elif causal_attention_mask is not None:
attn_mask = causal_attention_mask
else:
attn_mask = attention_mask
hidden_states = self.self_attn(
hidden_states,
attention_mask=attn_mask,
# causal_attention_mask=causal_attention_mask,
)
hidden_states = residual + hidden_states
residual = hidden_states
hidden_states = self.layer_norm2(hidden_states)
hidden_states = self.mlp(hidden_states)
hidden_states = residual + hidden_states
return hidden_states
# Copied from sglang.srt.models.clip.CLIPEncoder
class SiglipEncoder(nn.Module):
"""
Transformer encoder consisting of `config.num_hidden_layers` self
attention layers. Each layer is a [`SiglipEncoderLayer`].
Args:
config: SiglipConfig
"""
def __init__(
self,
config: SiglipVisionConfig,
quant_config: Optional[QuantizationConfig] = None,
prefix: str = "",
) -> None:
super().__init__()
self.config = config
num_hidden_layers = config.num_hidden_layers
norm_layer = partial(nn.LayerNorm, eps=config.layer_norm_eps)
self.layers = nn.ModuleList(
[
SiglipEncoderLayer(
config=config,
norm_layer=norm_layer,
attn_implementation="sdpa",
quant_config=quant_config,
prefix=add_prefix(f"layers.{layer_idx}", prefix),
)
for layer_idx in range(num_hidden_layers)
]
)
def forward(
self,
inputs_embeds: torch.Tensor,
attention_mask: torch.Tensor = None,
causal_attention_mask: torch.Tensor = None,
return_all_hidden_states: bool = False,
) -> Union[torch.Tensor, list[torch.Tensor]]:
hidden_states_pool = [inputs_embeds]
hidden_states = inputs_embeds
for encoder_layer in self.layers:
hidden_states = encoder_layer(
hidden_states, attention_mask, causal_attention_mask
)
if return_all_hidden_states:
hidden_states_pool.append(hidden_states)
if return_all_hidden_states:
return hidden_states_pool
return hidden_states
# Adapted from transformers.models.siglip.modeling_siglip.SiglipVisionTransformer
class SiglipVisionTransformer(nn.Module):
def __init__(
self,
config: SiglipVisionConfig,
quant_config: Optional[QuantizationConfig] = None,
prefix: str = "",
) -> None:
super().__init__()
self.config = config
embed_dim = config.hidden_size
self.embeddings = SiglipVisionEmbeddings(config)
self.encoder = SiglipEncoder(
config=config,
quant_config=quant_config,
prefix=add_prefix("encoder", prefix),
)
num_hidden_layers = config.num_hidden_layers
if len(self.encoder.layers) > config.num_hidden_layers:
raise ValueError(
f"The original encoder only has {num_hidden_layers} "
f"layers, but you requested {len(self.encoder.layers)} layers."
)
# VisionAttention in SiglipEncoderLayer is multihead attention
self.post_layernorm = nn.LayerNorm(embed_dim, eps=config.layer_norm_eps)
@property
def device(self) -> torch.device:
return self.encoder.layers[0].layer_norm1.weight.device
def forward(
self,
pixel_values: torch.Tensor,
) -> torch.Tensor:
hidden_states = self.embeddings(pixel_values.to(self.device))
return_all_hidden_states = False
last_hidden_state = self.encoder(
inputs_embeds=hidden_states,
return_all_hidden_states=return_all_hidden_states,
)
last_hidden_state = self.post_layernorm(last_hidden_state)
return last_hidden_state
# Copied from sglang.srt.models.clip.CLIPVisionModel
class SiglipVisionModel(nn.Module):
def __init__(
self,
config: SiglipVisionConfig,
quant_config: Optional[QuantizationConfig] = None,
prefix: str = "",
):
super().__init__()
self.vision_model = SiglipVisionTransformer(
config, quant_config, prefix=add_prefix("vision_model", prefix)
)
@property
def device(self) -> torch.device:
return self.vision_model.device
def forward(self, pixel_values: torch.Tensor):
return self.vision_model(pixel_values)
@@ -33,11 +33,14 @@ VISION_MODELS = [
"unsloth/Qwen2-VL-7B-Instruct-bnb-4bit",
"unsloth/Llama-3.2-11B-Vision-Instruct-bnb-4bit",
"unsloth/Llama-3.2-11B-Vision-bnb-4bit",
"unsloth/gemma-3-4b-it-bnb-4bit",
"unsloth/gemma-3-4b-it-unsloth-bnb-4bit",
]
LANGUAGE_MODELS = [
"unsloth/Qwen2.5-7B-Instruct-bnb-4bit",
"unsloth/Qwen2-7B-Instruct-bnb-4bit",
"unsloth/Llama-3.2-3B-Instruct-bnb-4bit",
"unsloth/gemma-3-1b-it-bnb-4bit",
]
# image
@@ -256,6 +259,7 @@ class TestVisionModel(CustomTestCase):
"0.6",
"--load-format",
"bitsandbytes",
"--enable-multimodal",
]
try:
process = popen_launch_server_wrapper(
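For reference, the new test arguments above correspond to launching the server with the bitsandbytes load format and multimodal support enabled. A hedged sketch of an equivalent manual launch; the sglang.launch_server entry point and --model-path flag are standard sglang usage rather than part of this diff, and the model name is one of the checkpoints added to VISION_MODELS above:

import subprocess

# Launch sketch mirroring the test's extra arguments; flags other than
# --load-format and --enable-multimodal are assumptions, not taken from the diff.
subprocess.run([
    "python", "-m", "sglang.launch_server",
    "--model-path", "unsloth/gemma-3-4b-it-bnb-4bit",
    "--load-format", "bitsandbytes",
    "--enable-multimodal",
])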