"git@developer.sourcefind.cn:OpenDAS/torch-sparce.git" did not exist on "de5288312d7f428a3bd795a2f0299cd4003f7480"
Unverified commit 4024e1d2, authored by Jiajun Li, committed by GitHub

Implement Siglip Vision model, and support BNB quantization for gemma3-mm (#5339)

parent 5c0b38f3
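
The regression test updated at the bottom of this diff exercises the new path by launching the server with `--load-format bitsandbytes` and `--enable-multimodal`. Below is a rough, hedged sketch of an equivalent manual launch; the `sglang.launch_server` entry point and the `--model-path` / `--mem-fraction-static` flags are assumed from standard sglang usage and do not appear in this diff.

# Hedged sketch only: launch an sglang server for a BNB-quantized gemma-3
# multimodal checkpoint, mirroring the flags used in the test file below.
# Flag names not shown in this diff (--model-path, --mem-fraction-static) are assumptions.
import subprocess

cmd = [
    "python", "-m", "sglang.launch_server",
    "--model-path", "unsloth/gemma-3-4b-it-bnb-4bit",  # model added to VISION_MODELS below
    "--load-format", "bitsandbytes",                    # BNB 4-bit weight loading
    "--enable-multimodal",                              # enable the vision tower path
    "--mem-fraction-static", "0.6",
]
subprocess.run(cmd, check=True)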
@@ -168,7 +168,7 @@ class CLIPEncoderLayer(nn.Module):
             softmax_in_single_precision=softmax_in_single_precision,
             flatten_batch=True,
             quant_config=quant_config,
-            prefix=add_prefix("attn", prefix),
+            prefix=add_prefix("self_attn", prefix),
         )
         self.mlp = CLIPMLP(
             config,
@@ -395,6 +395,10 @@ class CLIPVisionModel(nn.Module):
             config, quant_config, prefix=add_prefix("vision_model", prefix)
         )
 
+    @property
+    def device(self) -> torch.device:
+        return self.vision_model.device
+
     def forward(self, pixel_values: torch.Tensor):
         return self.vision_model(pixel_values)
...
@@ -21,7 +21,7 @@ from typing import Dict, Iterable, List, Optional, Set, Tuple, TypedDict
 
 import torch
 from torch import nn
-from transformers import AutoModel, Gemma3Config, PreTrainedModel
+from transformers import Gemma3Config, PreTrainedModel
 
 from sglang.srt.hf_transformers_utils import get_processor
 from sglang.srt.layers.layernorm import Gemma3RMSNorm
@@ -42,6 +42,7 @@ from sglang.srt.model_loader.weight_utils import (
     maybe_remap_kv_scale_name,
 )
 from sglang.srt.models.gemma3_causal import Gemma3ForCausalLM
+from sglang.srt.models.siglip import SiglipVisionModel
 from sglang.srt.utils import add_prefix
 
 logger = logging.getLogger(__name__)
@@ -118,6 +119,7 @@ class Gemma3ForConditionalGeneration(PreTrainedModel):
         ".k_proj.",
         ".v_proj.",
         ".o_proj.",
+        ".out_proj.",
     ]
     bitsandbytes_stacked_params_mapping = {
         # shard_name, weight_name, index
@@ -126,6 +128,7 @@ class Gemma3ForConditionalGeneration(PreTrainedModel):
         "v_proj": ("qkv_proj", 2),
         "gate_proj": ("gate_up_proj", 0),
         "up_proj": ("gate_up_proj", 1),
+        "out_proj": ("proj", 0),
     }
     packed_modules_mapping = {
@@ -161,20 +164,21 @@ class Gemma3ForConditionalGeneration(PreTrainedModel):
         super().__init__(config=config)
         self.config = config
         self.quant_config = quant_config
-        # TODO: replace with vision attention
-        # self.vision_tower = SiglipVisionModel(
-        #     config.vision_config,
-        #     quant_config,
-        #     prefix=add_prefix("vision_tower", prefix),
-        # )
-        self.vision_tower = AutoModel.from_config(config=config.vision_config)
+        # Vision components
+        self.vision_tower = SiglipVisionModel(
+            config=config.vision_config,
+            quant_config=quant_config,
+            prefix=add_prefix("vision_tower", prefix),
+        )
 
         self.multi_modal_projector = Gemma3MultiModalProjector(config)
         self.vocab_size = config.text_config.vocab_size
 
         # Text model
         self.language_model = Gemma3ForCausalLM(
-            config.text_config, quant_config, prefix=add_prefix("model", prefix)
+            config.text_config,
+            quant_config,
+            prefix=add_prefix("language_model", prefix),
         )
         if self.language_model.logits_processor.logit_scale:
             logit_scale = getattr(config, "logit_scale", 1.0)
@@ -290,7 +294,7 @@ class Gemma3ForConditionalGeneration(PreTrainedModel):
         pixel_values = pixel_values.to(device=self.vision_tower.device)
         pixel_values = pixel_values.to(dtype=self.language_model.dtype())
 
-        vision_outputs = self.vision_tower(pixel_values=pixel_values).last_hidden_state
+        vision_outputs = self.vision_tower(pixel_values=pixel_values)
         image_features = self.multi_modal_projector(vision_outputs)
         return image_features
@@ -366,6 +370,14 @@ class Gemma3ForConditionalGeneration(PreTrainedModel):
         return self.language_model.tie_weights()
 
     def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
+        stacked_params_mapping = [
+            # (param_name, shard_name, shard_id)
+            (".qkv_proj", ".q_proj", "q"),
+            (".qkv_proj", ".k_proj", "k"),
+            (".qkv_proj", ".v_proj", "v"),
+            ("gate_up_proj", "up_proj", 1),
+            ("gate_up_proj", "gate_proj", 0),
+        ]
         """Load weights for the model."""
         params_dict = dict(self.named_parameters())
         loaded_params: Set[str] = set()
@@ -379,21 +391,33 @@ class Gemma3ForConditionalGeneration(PreTrainedModel):
                 loaded_params.update(causal_loaded_params)
                 continue
             else:
-                # Skip lm_head.weight as it's tied with embed_tokens
-                if "lm_head.weight" in name:
-                    continue
-                # Skip loading extra bias for GPTQ models
-                if name.endswith(".bias") and name not in params_dict:
-                    continue
-                # Remapping the name of FP8 kv-scale
-                name = maybe_remap_kv_scale_name(name, params_dict)
-                if name is None:
-                    continue
-                param = params_dict[name]
-                weight_loader = getattr(param, "weight_loader", default_weight_loader)
-                weight_loader(param, loaded_weight)
+                for param_name, weight_name, shard_id in stacked_params_mapping:
+                    if weight_name not in name:
+                        continue
+                    name = name.replace(weight_name, param_name)
+                    # Skip loading extra bias for GPTQ models.
+                    if name.endswith(".bias") and name not in params_dict:
+                        continue
+                    param = params_dict[name]
+                    weight_loader = param.weight_loader
+                    weight_loader(param, loaded_weight, shard_id)
+                    break
+                else:
+                    if "vision_model" in name:
+                        # adapt to VisionAttention
+                        name = name.replace(".self_attn.out_proj", ".self_attn.proj")
+                    # Skip loading extra bias for GPTQ models
+                    if name.endswith(".bias") and name not in params_dict:
+                        continue
+                    # Remapping the name of FP8 kv-scale
+                    name = maybe_remap_kv_scale_name(name, params_dict)
+                    if name is None:
+                        continue
+                    param = params_dict[name]
+                    weight_loader = getattr(
+                        param, "weight_loader", default_weight_loader
+                    )
+                    weight_loader(param, loaded_weight)
                 loaded_params.add(name)
 
         unloaded_params = params_dict.keys() - loaded_params
         if unloaded_params:
@@ -404,5 +428,3 @@ class Gemma3ForConditionalGeneration(PreTrainedModel):
 
 EntryClass = Gemma3ForConditionalGeneration
-
-AutoModel.register(Gemma3Config, Gemma3ForConditionalGeneration, exist_ok=True)
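
To make the weight-name handling above concrete, here is a small standalone sketch of how the `stacked_params_mapping` list and the vision-tower `out_proj` rename rewrite checkpoint names during `load_weights`. The example weight names are hypothetical and chosen only for illustration.

# Illustrative only: shows how the mappings above rewrite checkpoint names.
stacked_params_mapping = [
    (".qkv_proj", ".q_proj", "q"),
    (".qkv_proj", ".k_proj", "k"),
    (".qkv_proj", ".v_proj", "v"),
    ("gate_up_proj", "up_proj", 1),
    ("gate_up_proj", "gate_proj", 0),
]

def remap(name: str):
    """Return (new_name, shard_id); shard_id is None for non-stacked params."""
    for param_name, weight_name, shard_id in stacked_params_mapping:
        if weight_name in name:
            return name.replace(weight_name, param_name), shard_id
    if "vision_model" in name:
        # adapt to VisionAttention, whose output projection is named "proj"
        name = name.replace(".self_attn.out_proj", ".self_attn.proj")
    return name, None

print(remap("vision_tower.vision_model.encoder.layers.0.self_attn.k_proj.weight"))
# -> ('vision_tower.vision_model.encoder.layers.0.self_attn.qkv_proj.weight', 'k')
print(remap("vision_tower.vision_model.encoder.layers.0.self_attn.out_proj.weight"))
# -> ('vision_tower.vision_model.encoder.layers.0.self_attn.proj.weight', None)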
# Adapted from
# https://github.com/huggingface/transformers/blob/af9b2eaa54c150741f298d6db939af6328e1dc38/src/transformers/models/siglip/modeling_siglip.py

from functools import partial
from typing import Optional, Type, Union

import torch
import torch.nn as nn
from transformers import SiglipVisionConfig

from sglang.srt.layers.activation import QuickGELU
from sglang.srt.layers.attention.vision import VisionAttention
from sglang.srt.layers.linear import ColumnParallelLinear, RowParallelLinear
from sglang.srt.layers.quantization.base_config import QuantizationConfig
from sglang.srt.layers.vocab_parallel_embedding import VocabParallelEmbedding
from sglang.srt.utils import add_prefix


# Adapted from transformers.models.siglip.modeling_siglip.SiglipVisionTransformer
class SiglipVisionEmbeddings(nn.Module):
    def __init__(self, config: SiglipVisionConfig):
        super().__init__()
        self.config = config
        self.embed_dim = config.hidden_size
        self.image_size = config.image_size
        self.patch_size = config.patch_size

        self.patch_embedding = nn.Conv2d(
            in_channels=config.num_channels,
            out_channels=self.embed_dim,
            kernel_size=self.patch_size,
            stride=self.patch_size,
            padding="valid",
        )

        self.num_patches = (self.image_size // self.patch_size) ** 2
        self.num_positions = self.num_patches
        self.position_embedding = VocabParallelEmbedding(
            self.num_positions, self.embed_dim
        )
        self.register_buffer(
            "position_ids",
            torch.arange(self.num_positions).expand((1, -1)),
            persistent=False,
        )

    def forward(self, pixel_values: torch.Tensor) -> torch.Tensor:
        target_dtype = self.patch_embedding.weight.dtype
        patch_embeds = self.patch_embedding(
            pixel_values.to(dtype=target_dtype)
        )  # shape = [*, width, grid, grid]
        embeddings = patch_embeds.flatten(2).transpose(1, 2)

        # interpolate_pos_encoding is never used in sglang
        embeddings = embeddings + self.position_embedding(self.position_ids)
        return embeddings
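
As a quick sanity check of the patch arithmetic in SiglipVisionEmbeddings, here is a small sketch; the config numbers below are illustrative, not necessarily those shipped with gemma-3.

# Illustrative numbers only: a square image cut into non-overlapping patches.
image_size, patch_size, hidden_size = 224, 14, 768   # hypothetical config values
num_patches = (image_size // patch_size) ** 2        # 16 x 16 = 256 patch tokens
# patch_embedding maps [B, 3, 224, 224] -> [B, hidden_size, 16, 16];
# flatten(2).transpose(1, 2) gives [B, 256, hidden_size], and position_embedding
# adds one learned vector per patch index (num_positions == num_patches).
print(num_patches)  # 256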
# Copied from sglang.srt.models.clip.CLIPMLP
class SiglipMLP(nn.Module):
    def __init__(
        self,
        config,
        act_layer: Type[nn.Module] = QuickGELU,
        quant_config: Optional[QuantizationConfig] = None,
        prefix: str = "",
    ):
        super().__init__()
        self.fc1 = ColumnParallelLinear(
            config.hidden_size,
            config.intermediate_size,
            quant_config=quant_config,
            prefix=add_prefix("fc1", prefix),
        )
        self.act = act_layer()
        self.fc2 = RowParallelLinear(
            config.intermediate_size,
            config.hidden_size,
            quant_config=quant_config,
            prefix=add_prefix("fc2", prefix),
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        x_parallel, _ = self.fc1(x)
        x_parallel = self.act(x_parallel)
        x, _ = self.fc2(x_parallel)
        return x
# Copied from sglang.srt.models.clip.CLIPEncoderLayer
class SiglipEncoderLayer(nn.Module):
    def __init__(
        self,
        config: SiglipVisionConfig,
        act_layer: Type[nn.Module] = QuickGELU,
        norm_layer: Type[nn.Module] = None,
        attn_implementation: Optional[str] = "sdpa",
        quant_config: Optional[QuantizationConfig] = None,
        prefix: str = "",
    ) -> None:
        super().__init__()
        if norm_layer is None:
            norm_layer = partial(nn.LayerNorm, eps=config.layer_norm_eps)
        self.layer_norm1 = norm_layer(config.hidden_size)
        self.layer_norm2 = norm_layer(config.hidden_size)

        if attn_implementation == "sdpa":
            qkv_backend = "sdpa"
            softmax_in_single_precision = False
        elif attn_implementation == "flash_attention_2":
            qkv_backend = "triton_attn"
            softmax_in_single_precision = False
        elif attn_implementation == "eager":
            qkv_backend = "sdpa"
            softmax_in_single_precision = True

        self.self_attn = VisionAttention(
            embed_dim=config.hidden_size,
            num_heads=config.num_attention_heads,
            projection_size=config.hidden_size,
            use_qkv_parallel=True,
            qkv_backend=qkv_backend,
            softmax_in_single_precision=softmax_in_single_precision,
            flatten_batch=True,
            quant_config=quant_config,
            prefix=add_prefix("self_attn", prefix),
        )
        self.mlp = SiglipMLP(
            config,
            act_layer=act_layer,
            quant_config=quant_config,
            prefix=add_prefix("mlp", prefix),
        )

    def forward(
        self,
        hidden_states: torch.Tensor,
        attention_mask: torch.Tensor,
        causal_attention_mask: torch.Tensor,
    ) -> torch.Tensor:
        residual = hidden_states
        hidden_states = self.layer_norm1(hidden_states)

        # Siglip text model uses both `causal_attention_mask` and `attention_mask`
        if attention_mask is not None and causal_attention_mask is not None:
            attn_mask = attention_mask + causal_attention_mask
        elif causal_attention_mask is not None:
            attn_mask = causal_attention_mask
        else:
            attn_mask = attention_mask

        hidden_states = self.self_attn(
            hidden_states,
            attention_mask=attn_mask,
            # causal_attention_mask=causal_attention_mask,
        )

        hidden_states = residual + hidden_states
        residual = hidden_states
        hidden_states = self.layer_norm2(hidden_states)
        hidden_states = self.mlp(hidden_states)
        hidden_states = residual + hidden_states
        return hidden_states
# Copied from sglang.srt.models.clip.CLIPEncoder
class SiglipEncoder(nn.Module):
    """
    Transformer encoder consisting of `config.num_hidden_layers` self
    attention layers. Each layer is a [`SiglipEncoderLayer`].

    Args:
        config: SiglipConfig
    """

    def __init__(
        self,
        config: SiglipVisionConfig,
        quant_config: Optional[QuantizationConfig] = None,
        prefix: str = "",
    ) -> None:
        super().__init__()
        self.config = config
        num_hidden_layers = config.num_hidden_layers
        norm_layer = partial(nn.LayerNorm, eps=config.layer_norm_eps)
        self.layers = nn.ModuleList(
            [
                SiglipEncoderLayer(
                    config=config,
                    norm_layer=norm_layer,
                    attn_implementation="sdpa",
                    quant_config=quant_config,
                    prefix=add_prefix(f"layers.{layer_idx}", prefix),
                )
                for layer_idx in range(num_hidden_layers)
            ]
        )

    def forward(
        self,
        inputs_embeds: torch.Tensor,
        attention_mask: torch.Tensor = None,
        causal_attention_mask: torch.Tensor = None,
        return_all_hidden_states: bool = False,
    ) -> Union[torch.Tensor, list[torch.Tensor]]:
        hidden_states_pool = [inputs_embeds]
        hidden_states = inputs_embeds

        for encoder_layer in self.layers:
            hidden_states = encoder_layer(
                hidden_states, attention_mask, causal_attention_mask
            )
            if return_all_hidden_states:
                hidden_states_pool.append(hidden_states)
        if return_all_hidden_states:
            return hidden_states_pool
        return hidden_states
# Adapted from transformers.models.siglip.modeling_siglip.SiglipVisionTransformer
class SiglipVisionTransformer(nn.Module):
    def __init__(
        self,
        config: SiglipVisionConfig,
        quant_config: Optional[QuantizationConfig] = None,
        prefix: str = "",
    ) -> None:
        super().__init__()
        self.config = config
        embed_dim = config.hidden_size

        self.embeddings = SiglipVisionEmbeddings(config)
        self.encoder = SiglipEncoder(
            config=config,
            quant_config=quant_config,
            prefix=add_prefix("encoder", prefix),
        )

        num_hidden_layers = config.num_hidden_layers
        if len(self.encoder.layers) > config.num_hidden_layers:
            raise ValueError(
                f"The original encoder only has {num_hidden_layers} "
                f"layers, but you requested {len(self.encoder.layers)} layers."
            )

        # VisionAttention in SiglipEncoderLayer is multihead attention
        self.post_layernorm = nn.LayerNorm(embed_dim, eps=config.layer_norm_eps)

    @property
    def device(self) -> torch.device:
        return self.encoder.layers[0].layer_norm1.weight.device

    def forward(
        self,
        pixel_values: torch.Tensor,
    ) -> torch.Tensor:
        hidden_states = self.embeddings(pixel_values.to(self.device))

        return_all_hidden_states = False
        last_hidden_state = self.encoder(
            inputs_embeds=hidden_states,
            return_all_hidden_states=return_all_hidden_states,
        )

        last_hidden_state = self.post_layernorm(last_hidden_state)
        return last_hidden_state
# Copied from sglang.srt.models.clip.CLIPVisionModel
class SiglipVisionModel(nn.Module):
    def __init__(
        self,
        config: SiglipVisionConfig,
        quant_config: Optional[QuantizationConfig] = None,
        prefix: str = "",
    ):
        super().__init__()
        self.vision_model = SiglipVisionTransformer(
            config, quant_config, prefix=add_prefix("vision_model", prefix)
        )

    @property
    def device(self) -> torch.device:
        return self.vision_model.device

    def forward(self, pixel_values: torch.Tensor):
        return self.vision_model(pixel_values)
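
A minimal usage sketch of the input/output contract, assuming the default SiglipVisionConfig from transformers (image_size=224, patch_size=16, hidden_size=768). Constructing SiglipVisionModel requires sglang's tensor-parallel groups to be initialized, so the model lines are left as comments.

# Hedged sketch: shape contract of SiglipVisionModel.forward.
import torch
from transformers import SiglipVisionConfig

config = SiglipVisionConfig()  # defaults: image_size=224, patch_size=16, hidden_size=768
pixel_values = torch.randn(2, config.num_channels, config.image_size, config.image_size)

# Inside an initialized sglang runtime (tensor-parallel groups set up):
# model = SiglipVisionModel(config)
# features = model(pixel_values)
# features.shape == (2, (config.image_size // config.patch_size) ** 2, config.hidden_size)
#                == (2, 196, 768)   # per-patch features after post_layernorm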
@@ -33,11 +33,14 @@ VISION_MODELS = [
     "unsloth/Qwen2-VL-7B-Instruct-bnb-4bit",
     "unsloth/Llama-3.2-11B-Vision-Instruct-bnb-4bit",
     "unsloth/Llama-3.2-11B-Vision-bnb-4bit",
+    "unsloth/gemma-3-4b-it-bnb-4bit",
+    "unsloth/gemma-3-4b-it-unsloth-bnb-4bit",
 ]
 
 LANGUAGE_MODELS = [
     "unsloth/Qwen2.5-7B-Instruct-bnb-4bit",
     "unsloth/Qwen2-7B-Instruct-bnb-4bit",
     "unsloth/Llama-3.2-3B-Instruct-bnb-4bit",
+    "unsloth/gemma-3-1b-it-bnb-4bit",
 ]
 
 # image
@@ -256,6 +259,7 @@ class TestVisionModel(CustomTestCase):
             "0.6",
             "--load-format",
             "bitsandbytes",
+            "--enable-multimodal",
         ]
         try:
             process = popen_launch_server_wrapper(
...