sglang / Commits / 4024e1d2

Unverified commit 4024e1d2, authored May 20, 2025 by Jiajun Li, committed by GitHub on May 20, 2025.
Parent: 5c0b38f3

Implement Siglip Vision model, and support BNB quantization for gemma3-mm (#5339)

Showing 4 changed files with 353 additions and 29 deletions (+353, -29):

  python/sglang/srt/models/clip.py        +5   -1
  python/sglang/srt/models/gemma3_mm.py   +50  -28
  python/sglang/srt/models/siglip.py      +294 -0
  test/srt/test_bnb.py                    +4   -0
python/sglang/srt/models/clip.py

@@ -168,7 +168,7 @@ class CLIPEncoderLayer(nn.Module):
             softmax_in_single_precision=softmax_in_single_precision,
             flatten_batch=True,
             quant_config=quant_config,
-            prefix=add_prefix("attn", prefix),
+            prefix=add_prefix("self_attn", prefix),
         )
         self.mlp = CLIPMLP(
             config,
@@ -395,6 +395,10 @@ class CLIPVisionModel(nn.Module):
             config, quant_config, prefix=add_prefix("vision_model", prefix)
         )
 
+    @property
+    def device(self) -> torch.device:
+        return self.vision_model.device
+
     def forward(self, pixel_values: torch.Tensor):
         return self.vision_model(pixel_values)
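The only functional addition to clip.py is the `device` property. A brief, hedged sketch of the calling pattern it enables (variable names are placeholders; gemma3_mm.py below uses exactly this pattern on its vision tower):

    # Illustrative only: place image tensors wherever the vision tower's
    # weights live, without reaching into the module's internals.
    pixel_values = pixel_values.to(device=vision_tower.device)
    vision_outputs = vision_tower(pixel_values)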
python/sglang/srt/models/gemma3_mm.py

@@ -21,7 +21,7 @@ from typing import Dict, Iterable, List, Optional, Set, Tuple, TypedDict
 import torch
 from torch import nn
-from transformers import AutoModel, Gemma3Config, PreTrainedModel
+from transformers import Gemma3Config, PreTrainedModel
 
 from sglang.srt.hf_transformers_utils import get_processor
 from sglang.srt.layers.layernorm import Gemma3RMSNorm
@@ -42,6 +42,7 @@ from sglang.srt.model_loader.weight_utils import (
     maybe_remap_kv_scale_name,
 )
 from sglang.srt.models.gemma3_causal import Gemma3ForCausalLM
+from sglang.srt.models.siglip import SiglipVisionModel
 from sglang.srt.utils import add_prefix
 
 logger = logging.getLogger(__name__)
@@ -118,6 +119,7 @@ class Gemma3ForConditionalGeneration(PreTrainedModel):
         ".k_proj.",
         ".v_proj.",
         ".o_proj.",
+        ".out_proj.",
     ]
     bitsandbytes_stacked_params_mapping = {
         # shard_name, weight_name, index
@@ -126,6 +128,7 @@ class Gemma3ForConditionalGeneration(PreTrainedModel):
         "v_proj": ("qkv_proj", 2),
         "gate_proj": ("gate_up_proj", 0),
         "up_proj": ("gate_up_proj", 1),
+        "out_proj": ("proj", 0),
     }
 
     packed_modules_mapping = {
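The new `".out_proj."` target module and the `"out_proj": ("proj", 0)` entry are needed because sglang's `VisionAttention` stores the attention output projection under the name `proj`, while Siglip checkpoints (and their bitsandbytes quantization state) use `out_proj`. A rough sketch of how such a mapping is typically consulted when rewriting checkpoint names; this helper is hypothetical and only illustrates the `shard_name -> (weight_name, index)` convention noted in the comment above:

    # Hypothetical illustration of the mapping convention; not sglang's actual loader.
    def remap_bnb_name(name: str, mapping: dict):
        for shard_name, (weight_name, index) in mapping.items():
            if shard_name in name:
                # e.g. "...self_attn.out_proj.weight" -> ("...self_attn.proj.weight", 0)
                return name.replace(shard_name, weight_name), index
        return name, None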
@@ -161,20 +164,21 @@ class Gemma3ForConditionalGeneration(PreTrainedModel):
         super().__init__(config=config)
         self.config = config
         self.quant_config = quant_config
 
         # Vision components
-        # TODO: replace with vision attention
-        # self.vision_tower = SiglipVisionModel(
-        #     config.vision_config,
-        #     quant_config,
-        #     prefix=add_prefix("vision_tower", prefix),
-        # )
-        self.vision_tower = AutoModel.from_config(config=config.vision_config)
+        self.vision_tower = SiglipVisionModel(
+            config=config.vision_config,
+            quant_config=quant_config,
+            prefix=add_prefix("vision_tower", prefix),
+        )
         self.multi_modal_projector = Gemma3MultiModalProjector(config)
         self.vocab_size = config.text_config.vocab_size
 
         # Text model
         self.language_model = Gemma3ForCausalLM(
-            config.text_config, quant_config, prefix=add_prefix("model", prefix)
+            config.text_config,
+            quant_config,
+            prefix=add_prefix("language_model", prefix),
         )
         if self.language_model.logits_processor.logit_scale:
             logit_scale = getattr(config, "logit_scale", 1.0)
@@ -290,7 +294,7 @@ class Gemma3ForConditionalGeneration(PreTrainedModel):
         pixel_values = pixel_values.to(device=self.vision_tower.device)
         pixel_values = pixel_values.to(dtype=self.language_model.dtype())
 
-        vision_outputs = self.vision_tower(pixel_values=pixel_values).last_hidden_state
+        vision_outputs = self.vision_tower(pixel_values=pixel_values)
 
         image_features = self.multi_modal_projector(vision_outputs)
         return image_features
@@ -366,6 +370,14 @@ class Gemma3ForConditionalGeneration(PreTrainedModel):
         return self.language_model.tie_weights()
 
     def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
+        stacked_params_mapping = [
+            # (param_name, shard_name, shard_id)
+            (".qkv_proj", ".q_proj", "q"),
+            (".qkv_proj", ".k_proj", "k"),
+            (".qkv_proj", ".v_proj", "v"),
+            ("gate_up_proj", "up_proj", 1),
+            ("gate_up_proj", "gate_proj", 0),
+        ]
         """Load weights for the model."""
         params_dict = dict(self.named_parameters())
         loaded_params: Set[str] = set()
@@ -379,21 +391,33 @@ class Gemma3ForConditionalGeneration(PreTrainedModel):
                 loaded_params.update(causal_loaded_params)
                 continue
             else:
                 # Skip lm_head.weight as it's tied with embed_tokens
                 if "lm_head.weight" in name:
                     continue
                 # Skip loading extra bias for GPTQ models
                 if name.endswith(".bias") and name not in params_dict:
                     continue
-                # Remapping the name of FP8 kv-scale
-                name = maybe_remap_kv_scale_name(name, params_dict)
-                if name is None:
-                    continue
-                param = params_dict[name]
-                weight_loader = getattr(param, "weight_loader", default_weight_loader)
-                weight_loader(param, loaded_weight)
+                for param_name, weight_name, shard_id in stacked_params_mapping:
+                    if weight_name not in name:
+                        continue
+                    name = name.replace(weight_name, param_name)
+                    # Skip loading extra bias for GPTQ models.
+                    if name.endswith(".bias") and name not in params_dict:
+                        continue
+                    param = params_dict[name]
+                    weight_loader = param.weight_loader
+                    weight_loader(param, loaded_weight, shard_id)
+                    break
+                else:
+                    if "vision_model" in name:
+                        # adapt to VisionAttention
+                        name = name.replace(".self_attn.out_proj", ".self_attn.proj")
+                    # Remapping the name of FP8 kv-scale
+                    name = maybe_remap_kv_scale_name(name, params_dict)
+                    if name is None:
+                        continue
+                    param = params_dict[name]
+                    weight_loader = getattr(param, "weight_loader", default_weight_loader)
+                    weight_loader(param, loaded_weight)
                 loaded_params.add(name)
 
         unloaded_params = params_dict.keys() - loaded_params
         if unloaded_params:
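The restructured branch relies on Python's `for ... else`: the `else` suite runs only if the loop completed without `break`, i.e. the weight matched none of the stacked-parameter patterns and falls through to the default path (with the extra `out_proj -> proj` rename for vision weights). A minimal illustration of the idiom with a made-up name:

    name = "vision_model.encoder.layers.0.self_attn.out_proj.weight"
    for weight_name in (".q_proj", ".k_proj", ".v_proj"):
        if weight_name in name:
            break          # stacked-param path: loop exited early
    else:
        # no pattern matched -> default loading path runs here
        name = name.replace(".self_attn.out_proj", ".self_attn.proj")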
@@ -404,5 +428,3 @@ class Gemma3ForConditionalGeneration(PreTrainedModel):
 EntryClass = Gemma3ForConditionalGeneration
-
-AutoModel.register(Gemma3Config, Gemma3ForConditionalGeneration, exist_ok=True)
python/sglang/srt/models/siglip.py (new file, 0 → 100644)

# Adapted from
# https://github.com/huggingface/transformers/blob/af9b2eaa54c150741f298d6db939af6328e1dc38/src/transformers/models/siglip/modeling_siglip.py
from functools import partial
from typing import Optional, Type, Union

import torch
import torch.nn as nn
from transformers import SiglipVisionConfig

from sglang.srt.layers.activation import QuickGELU
from sglang.srt.layers.attention.vision import VisionAttention
from sglang.srt.layers.linear import ColumnParallelLinear, RowParallelLinear
from sglang.srt.layers.quantization.base_config import QuantizationConfig
from sglang.srt.layers.vocab_parallel_embedding import VocabParallelEmbedding
from sglang.srt.utils import add_prefix


# Adapted from transformers.models.siglip.modeling_siglip.SiglipVisionTransformer
class SiglipVisionEmbeddings(nn.Module):
    def __init__(self, config: SiglipVisionConfig):
        super().__init__()
        self.config = config
        self.embed_dim = config.hidden_size
        self.image_size = config.image_size
        self.patch_size = config.patch_size

        self.patch_embedding = nn.Conv2d(
            in_channels=config.num_channels,
            out_channels=self.embed_dim,
            kernel_size=self.patch_size,
            stride=self.patch_size,
            padding="valid",
        )

        self.num_patches = (self.image_size // self.patch_size) ** 2
        self.num_positions = self.num_patches
        self.position_embedding = VocabParallelEmbedding(
            self.num_positions, self.embed_dim
        )
        self.register_buffer(
            "position_ids",
            torch.arange(self.num_positions).expand((1, -1)),
            persistent=False,
        )

    def forward(self, pixel_values: torch.Tensor) -> torch.Tensor:
        target_dtype = self.patch_embedding.weight.dtype
        patch_embeds = self.patch_embedding(
            pixel_values.to(dtype=target_dtype)
        )  # shape = [*, width, grid, grid]
        embeddings = patch_embeds.flatten(2).transpose(1, 2)

        # interpolate_pos_encoding is never used in sglang
        embeddings = embeddings + self.position_embedding(self.position_ids)

        return embeddings


# Copied from sglang.srt.models.clip.CLIPMLP
class SiglipMLP(nn.Module):
    def __init__(
        self,
        config,
        act_layer: Type[nn.Module] = QuickGELU,
        quant_config: Optional[QuantizationConfig] = None,
        prefix: str = "",
    ):
        super().__init__()
        self.fc1 = ColumnParallelLinear(
            config.hidden_size,
            config.intermediate_size,
            quant_config=quant_config,
            prefix=add_prefix("fc1", prefix),
        )
        self.act = act_layer()
        self.fc2 = RowParallelLinear(
            config.intermediate_size,
            config.hidden_size,
            quant_config=quant_config,
            prefix=add_prefix("fc2", prefix),
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        x_parallel, _ = self.fc1(x)
        x_parallel = self.act(x_parallel)
        x, _ = self.fc2(x_parallel)
        return x


# Copied from sglang.srt.models.clip.CLIPEncoderLayer
class SiglipEncoderLayer(nn.Module):
    def __init__(
        self,
        config: SiglipVisionConfig,
        act_layer: Type[nn.Module] = QuickGELU,
        norm_layer: Type[nn.Module] = None,
        attn_implementation: Optional[str] = "sdpa",
        quant_config: Optional[QuantizationConfig] = None,
        prefix: str = "",
    ) -> None:
        super().__init__()
        if norm_layer is None:
            norm_layer = partial(nn.LayerNorm, eps=config.layer_norm_eps)
        self.layer_norm1 = norm_layer(config.hidden_size)
        self.layer_norm2 = norm_layer(config.hidden_size)
        if attn_implementation == "sdpa":
            qkv_backend = "sdpa"
            softmax_in_single_precision = False
        elif attn_implementation == "flash_attention_2":
            qkv_backend = "triton_attn"
            softmax_in_single_precision = False
        elif attn_implementation == "eager":
            qkv_backend = "sdpa"
            softmax_in_single_precision = True
        self.self_attn = VisionAttention(
            embed_dim=config.hidden_size,
            num_heads=config.num_attention_heads,
            projection_size=config.hidden_size,
            use_qkv_parallel=True,
            qkv_backend=qkv_backend,
            softmax_in_single_precision=softmax_in_single_precision,
            flatten_batch=True,
            quant_config=quant_config,
            prefix=add_prefix("self_attn", prefix),
        )
        self.mlp = SiglipMLP(
            config,
            act_layer=act_layer,
            quant_config=quant_config,
            prefix=add_prefix("mlp", prefix),
        )

    def forward(
        self,
        hidden_states: torch.Tensor,
        attention_mask: torch.Tensor,
        causal_attention_mask: torch.Tensor,
    ) -> torch.Tensor:
        residual = hidden_states
        hidden_states = self.layer_norm1(hidden_states)

        # Siglip text model uses both `causal_attention_mask` and `attention_mask`
        if attention_mask is not None and causal_attention_mask is not None:
            attn_mask = attention_mask + causal_attention_mask
        elif causal_attention_mask is not None:
            attn_mask = causal_attention_mask
        else:
            attn_mask = attention_mask
        hidden_states = self.self_attn(
            hidden_states,
            attention_mask=attn_mask,
            # causal_attention_mask=causal_attention_mask,
        )
        hidden_states = residual + hidden_states
        residual = hidden_states
        hidden_states = self.layer_norm2(hidden_states)
        hidden_states = self.mlp(hidden_states)
        hidden_states = residual + hidden_states
        return hidden_states


# Copied from sglang.srt.models.clip.CLIPEncoder
class SiglipEncoder(nn.Module):
    """
    Transformer encoder consisting of `config.num_hidden_layers` self
    attention layers. Each layer is a [`SiglipEncoderLayer`].

    Args:
        config: SiglipConfig
    """

    def __init__(
        self,
        config: SiglipVisionConfig,
        quant_config: Optional[QuantizationConfig] = None,
        prefix: str = "",
    ) -> None:
        super().__init__()

        self.config = config
        num_hidden_layers = config.num_hidden_layers
        norm_layer = partial(nn.LayerNorm, eps=config.layer_norm_eps)
        self.layers = nn.ModuleList(
            [
                SiglipEncoderLayer(
                    config=config,
                    norm_layer=norm_layer,
                    attn_implementation="sdpa",
                    quant_config=quant_config,
                    prefix=add_prefix(f"layers.{layer_idx}", prefix),
                )
                for layer_idx in range(num_hidden_layers)
            ]
        )

    def forward(
        self,
        inputs_embeds: torch.Tensor,
        attention_mask: torch.Tensor = None,
        causal_attention_mask: torch.Tensor = None,
        return_all_hidden_states: bool = False,
    ) -> Union[torch.Tensor, list[torch.Tensor]]:
        hidden_states_pool = [inputs_embeds]
        hidden_states = inputs_embeds

        for encoder_layer in self.layers:
            hidden_states = encoder_layer(
                hidden_states, attention_mask, causal_attention_mask
            )
            if return_all_hidden_states:
                hidden_states_pool.append(hidden_states)
        if return_all_hidden_states:
            return hidden_states_pool
        return hidden_states


# Adapted from transformers.models.siglip.modeling_siglip.SiglipVisionTransformer
class SiglipVisionTransformer(nn.Module):
    def __init__(
        self,
        config: SiglipVisionConfig,
        quant_config: Optional[QuantizationConfig] = None,
        prefix: str = "",
    ) -> None:
        super().__init__()

        self.config = config
        embed_dim = config.hidden_size

        self.embeddings = SiglipVisionEmbeddings(config)
        self.encoder = SiglipEncoder(
            config=config,
            quant_config=quant_config,
            prefix=add_prefix("encoder", prefix),
        )

        num_hidden_layers = config.num_hidden_layers
        if len(self.encoder.layers) > config.num_hidden_layers:
            raise ValueError(
                f"The original encoder only has {num_hidden_layers} "
                f"layers, but you requested {len(self.encoder.layers)} layers."
            )

        # VisionAttention in SiglipEncoderLayer is multihead attention
        self.post_layernorm = nn.LayerNorm(embed_dim, eps=config.layer_norm_eps)

    @property
    def device(self) -> torch.device:
        return self.encoder.layers[0].layer_norm1.weight.device

    def forward(
        self,
        pixel_values: torch.Tensor,
    ) -> torch.Tensor:
        hidden_states = self.embeddings(pixel_values.to(self.device))

        return_all_hidden_states = False

        last_hidden_state = self.encoder(
            inputs_embeds=hidden_states,
            return_all_hidden_states=return_all_hidden_states,
        )

        last_hidden_state = self.post_layernorm(last_hidden_state)

        return last_hidden_state


# Copied from sglang.srt.models.clip.CLIPVisionModel
class SiglipVisionModel(nn.Module):
    def __init__(
        self,
        config: SiglipVisionConfig,
        quant_config: Optional[QuantizationConfig] = None,
        prefix: str = "",
    ):
        super().__init__()
        self.vision_model = SiglipVisionTransformer(
            config, quant_config, prefix=add_prefix("vision_model", prefix)
        )

    @property
    def device(self) -> torch.device:
        return self.vision_model.device

    def forward(self, pixel_values: torch.Tensor):
        return self.vision_model(pixel_values)
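A hedged usage sketch of the new module. The parallel layers (`ColumnParallelLinear`, `VocabParallelEmbedding`, ...) assume sglang's tensor-parallel environment has been initialized, as it is inside the server, so treat this as a shape-level illustration rather than a copy-paste script; the config values are examples, not tied to any particular checkpoint:

    from transformers import SiglipVisionConfig

    from sglang.srt.models.siglip import SiglipVisionModel

    # Example config (values are illustrative).
    cfg = SiglipVisionConfig(hidden_size=1152, image_size=896, patch_size=14)
    tower = SiglipVisionModel(cfg)

    # pixel_values: [batch, 3, 896, 896]
    # forward returns the post-layernorm hidden states directly as a torch.Tensor
    # of shape [batch, (896 // 14) ** 2, 1152] -- which is why gemma3_mm.py above
    # dropped the `.last_hidden_state` access.
    # features = tower(pixel_values)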
test/srt/test_bnb.py

@@ -33,11 +33,14 @@ VISION_MODELS = [
     "unsloth/Qwen2-VL-7B-Instruct-bnb-4bit",
     "unsloth/Llama-3.2-11B-Vision-Instruct-bnb-4bit",
     "unsloth/Llama-3.2-11B-Vision-bnb-4bit",
+    "unsloth/gemma-3-4b-it-bnb-4bit",
+    "unsloth/gemma-3-4b-it-unsloth-bnb-4bit",
 ]
 LANGUAGE_MODELS = [
     "unsloth/Qwen2.5-7B-Instruct-bnb-4bit",
     "unsloth/Qwen2-7B-Instruct-bnb-4bit",
     "unsloth/Llama-3.2-3B-Instruct-bnb-4bit",
+    "unsloth/gemma-3-1b-it-bnb-4bit",
 ]
 
 # image
@@ -256,6 +259,7 @@ class TestVisionModel(CustomTestCase):
             "0.6",
             "--load-format",
             "bitsandbytes",
+            "--enable-multimodal",
         ]
         try:
             process = popen_launch_server_wrapper(
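The added `--enable-multimodal` flag, together with `--load-format bitsandbytes`, is what the test forwards to the server launcher. A hedged sketch of the equivalent manual launch for one of the newly listed checkpoints (written with `subprocess` so nothing is assumed about the test helpers; the model name is taken from the list above):

    import subprocess

    # Illustrative launch of a pre-quantized 4-bit gemma-3 multimodal checkpoint.
    subprocess.Popen(
        [
            "python", "-m", "sglang.launch_server",
            "--model-path", "unsloth/gemma-3-4b-it-bnb-4bit",
            "--load-format", "bitsandbytes",
            "--enable-multimodal",
        ]
    )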