sglang, commit 9d02bb3e (unverified)

Urgent model support: support gemma-3-it (#4424)

Authored by Mick on Mar 17, 2025; committed via GitHub on Mar 16, 2025.
Parent: 402db5c5

Showing 20 changed files with 2543 additions and 86 deletions (+2543 −86).
docs/references/supported_models.md                                    +1    −0
python/sglang/lang/chat_template.py                                    +8    −0
python/sglang/srt/configs/__init__.py                                  +3    −0
python/sglang/srt/configs/gemma3.py                                    +1086 −0
python/sglang/srt/configs/model_config.py                              +7    −2
python/sglang/srt/conversation.py                                      +27   −0
python/sglang/srt/hf_transformers_utils.py                             +4    −0
python/sglang/srt/layers/attention/vision.py                           +2    −26
python/sglang/srt/layers/layernorm.py                                  +20   −0
python/sglang/srt/layers/rotary_embedding.py                           +31   −0
python/sglang/srt/managers/image_processors/base_image_processor.py    +59   −48
python/sglang/srt/managers/image_processors/gemma3.py                  +100  −0
python/sglang/srt/managers/image_processors/janus_pro.py               +4    −1
python/sglang/srt/managers/image_processors/minicpmv.py                +4    −1
python/sglang/srt/managers/image_processors/qwen_vl.py                 +4    −4
python/sglang/srt/managers/schedule_batch.py                           +4    −1
python/sglang/srt/model_executor/forward_batch_info.py                 +27   −0
python/sglang/srt/models/gemma3_causal.py                              +687  −0
python/sglang/srt/models/gemma3_mm.py                                  +462  −0
python/sglang/srt/utils.py                                             +3    −3
docs/references/supported_models.md  (view @ 9d02bb3e)

@@ -32,6 +32,7 @@
 - Phi-3-Small
 - IBM Granite 3
 - Janus-Pro-1B / Janus-Pro-7B
+- Gemma 3 (it)

 ## Embedding Models
python/sglang/lang/chat_template.py  (view @ 9d02bb3e)

@@ -520,6 +520,14 @@ def match_granite_instruct(model_path: str):
     return get_chat_template("granite-3-instruct")


+@register_chat_template_matching_function
+def match_gemma3_instruct(model_path: str):
+    model_path = model_path.lower()
+    if "gemma-3" in model_path and "1b" not in model_path:
+        # gemma-3-1b-it is completion model
+        return get_chat_template("gemma-it")
+
+
 if __name__ == "__main__":
     messages = [
         {"role": "system", "content": None},  # None means default
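For context, the matcher above only routes Gemma 3 instruction checkpoints larger than 1B to the "gemma-it" template. Below is a minimal standalone sketch of the same string test, re-implemented here for illustration rather than imported from sglang:

# Hypothetical standalone re-implementation of the matching rule above;
# it returns the template name that would be selected, or None.
def sketch_match_gemma3_instruct(model_path: str):
    model_path = model_path.lower()
    if "gemma-3" in model_path and "1b" not in model_path:
        # gemma-3-1b-it is a completion model, so it stays unmatched here
        return "gemma-it"
    return None

assert sketch_match_gemma3_instruct("google/gemma-3-4b-it") == "gemma-it"
assert sketch_match_gemma3_instruct("google/gemma-3-1b-it") is None
assert sketch_match_gemma3_instruct("meta-llama/Llama-3-8B") is None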
python/sglang/srt/configs/__init__.py  (view @ 9d02bb3e)

 from sglang.srt.configs.chatglm import ChatGLMConfig
 from sglang.srt.configs.dbrx import DbrxConfig
 from sglang.srt.configs.exaone import ExaoneConfig
+from sglang.srt.configs.gemma3 import Gemma3Config, Gemma3TextConfig
 from sglang.srt.configs.janus_pro import MultiModalityConfig
 from sglang.srt.configs.qwen2_5_vl_config import (
     Qwen2_5_VLConfig,

@@ -14,4 +15,6 @@ __all__ = [
     "Qwen2_5_VLConfig",
     "Qwen2_5_VLVisionConfig",
     "MultiModalityConfig",
+    "Gemma3Config",
+    "Gemma3TextConfig",
 ]
python/sglang/srt/configs/gemma3.py  (new file, mode 100644, view @ 9d02bb3e)

# 🚨🚨🚨 This file was automatically generated from src/transformers/models/gemma3/modular_gemma3.py.
# Do NOT edit this file manually as any edits will be overwritten by the generation of
# the file from the modular. If any change should be done, please apply the change to the
# modular_gemma3.py file directly. One of our CI enforces this. 🚨🚨🚨
# coding=utf-8
# Copyright 2025 Google Inc. HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import itertools
import logging
import math
import re
from typing import Dict, Iterable, List, Optional, Union

import numpy as np
import PIL
import transformers
from torch import TensorType
from transformers import (
    AutoImageProcessor,
    AutoProcessor,
    BatchFeature,
    PretrainedConfig,
    SiglipVisionConfig,
)
from transformers.image_processing_utils import BaseImageProcessor, get_size_dict
from transformers.image_transforms import (
    convert_to_rgb,
    resize,
    to_channel_dimension_format,
)
from transformers.image_utils import (
    ChannelDimension,
    ImageInput,
    PILImageResampling,
    get_image_size,
    infer_channel_dimension_format,
    is_pil_image,
    is_scaled_image,
    is_valid_image,
    to_numpy_array,
    valid_images,
    validate_preprocess_arguments,
)
from transformers.modeling_rope_utils import rope_config_validation
from transformers.processing_utils import (
    ImagesKwargs,
    ProcessingKwargs,
    ProcessorMixin,
    Unpack,
)
from transformers.tokenization_utils_base import PreTokenizedInput, TextInput
from transformers.utils import (
    IMAGENET_STANDARD_MEAN,
    IMAGENET_STANDARD_STD,
    filter_out_non_signature_kwargs,
    to_py_obj,
)

logger = logging.getLogger(__name__)


def is_valid_list_of_images(images: List):
    return images and all(is_valid_image(image) for image in images)


# copied from transformer
def make_nested_list_of_images(
    images: Union[List[ImageInput], ImageInput],
) -> ImageInput:
    """
    Ensure that the output is a nested list of images.

    Args:
        images (`Union[List[ImageInput], ImageInput]`):
            The input image.

    Returns:
        list: A list of list of images or a list of 4d array of images.
    """
    # If it's a list of batches, it's already in the right format
    if (
        isinstance(images, (list, tuple))
        and all(isinstance(images_i, (list, tuple)) for images_i in images)
        and all(is_valid_list_of_images(images_i) for images_i in images)
    ):
        return images

    # If it's a list of images, it's a single batch, so convert it to a list of lists
    if isinstance(images, (list, tuple)) and is_valid_list_of_images(images):
        if is_pil_image(images[0]) or images[0].ndim == 3:
            return [images]
        if images[0].ndim == 4:
            return [list(image) for image in images]

    # If it's a single image, convert it to a list of lists
    if is_valid_image(images):
        if is_pil_image(images) or images.ndim == 3:
            return [[images]]
        if images.ndim == 4:
            return [list(images)]

    raise ValueError(
        "Invalid input type. Must be a single image, a list of images, or a list of batches of images."
    )


def rescale(
    image: np.ndarray,
    scale: float,
    data_format: Optional[Union[str, ChannelDimension]] = None,
    input_data_format: Optional[Union[str, ChannelDimension]] = None,
    **kwargs,
) -> np.ndarray:
    """
    Rescale an image by a scale factor: image = image * scale.

    Args:
        image (`np.ndarray`): Image to rescale.
        scale (`float`): The scaling factor to rescale pixel values by.
        data_format (`str` or `ChannelDimension`, *optional*):
            The channel dimension format for the output image (`"channels_first"`/`ChannelDimension.FIRST`
            or `"channels_last"`/`ChannelDimension.LAST`). If unset, the format of the input image is used.
        input_data_format (`ChannelDimension` or `str`, *optional*):
            The channel dimension format for the input image. If unset, it is inferred from the input image.

    Returns:
        `np.ndarray`: The rescaled image.
    """
    return transformers.image_transforms.rescale(
        image,
        scale=scale,
        data_format=data_format,
        input_data_format=input_data_format,
        **kwargs,
    )


def normalize(
    image: np.ndarray,
    mean: Union[float, Iterable[float]],
    std: Union[float, Iterable[float]],
    data_format: Optional[Union[str, ChannelDimension]] = None,
    input_data_format: Optional[Union[str, ChannelDimension]] = None,
    **kwargs,
) -> np.ndarray:
    """
    Normalize an image: image = (image - image_mean) / image_std.

    Args:
        image (`np.ndarray`): Image to normalize.
        mean (`float` or `Iterable[float]`): Image mean to use for normalization.
        std (`float` or `Iterable[float]`): Image standard deviation to use for normalization.
        data_format (`str` or `ChannelDimension`, *optional*):
            The channel dimension format for the output image. If unset, the format of the input image is used.
        input_data_format (`ChannelDimension` or `str`, *optional*):
            The channel dimension format for the input image. If unset, it is inferred from the input image.

    Returns:
        `np.ndarray`: The normalized image.
    """
    return transformers.image_transforms.normalize(
        image,
        mean=mean,
        std=std,
        data_format=data_format,
        input_data_format=input_data_format,
        **kwargs,
    )


class Gemma3ImagesKwargs(ImagesKwargs):
    do_pan_and_scan: Optional[bool]
    pan_and_scan_min_crop_size: Optional[int]
    pan_and_scan_max_num_crops: Optional[int]
    pan_and_scan_min_ratio_to_activate: Optional[float]
    do_convert_rgb: Optional[bool]


class Gemma3ProcessorKwargs(ProcessingKwargs, total=False):
    images_kwargs: Gemma3ImagesKwargs
    _defaults = {
        "text_kwargs": {
            "padding": False,
        },
        "images_kwargs": {
            "do_pan_and_scan": False,
            "pan_and_scan_min_crop_size": 256,
            "pan_and_scan_max_num_crops": 4,
            "pan_and_scan_min_ratio_to_activate": 1.2,
        },
    }


class Gemma3Processor(ProcessorMixin):
    attributes = ["image_processor", "tokenizer"]
    valid_kwargs = ["chat_template", "image_seq_length"]
    image_processor_class = "AutoImageProcessor"
    tokenizer_class = "AutoTokenizer"

    def __init__(
        self,
        image_processor,
        tokenizer,
        chat_template=None,
        image_seq_length: int = 256,
        **kwargs,
    ):
        self.image_seq_length = image_seq_length
        self.image_token_id = tokenizer.image_token_id
        self.boi_token = tokenizer.boi_token
        image_tokens_expanded = "".join([tokenizer.image_token] * image_seq_length)
        self.full_image_sequence = (
            f"\n\n{tokenizer.boi_token}{image_tokens_expanded}{tokenizer.eoi_token}\n\n"
        )

        super().__init__(
            image_processor=image_processor,
            tokenizer=tokenizer,
            chat_template=chat_template,
            **kwargs,
        )
        # TODO: if transformers is updated, the chat_template needs to be adjusted
        self.tokenizer.add_bos_token = False

    def __call__(
        self,
        images: ImageInput = None,
        text: Union[TextInput, PreTokenizedInput, List[TextInput], List[PreTokenizedInput]] = None,
        videos=None,
        audio=None,
        **kwargs: Unpack[Gemma3ProcessorKwargs],
    ) -> BatchFeature:
        if text is None and images is None:
            raise ValueError("Provide at least one of `text` or `images`.")

        # print(f"processing, text:{text}")

        output_kwargs = self._merge_kwargs(
            Gemma3ProcessorKwargs,
            tokenizer_init_kwargs=self.tokenizer.init_kwargs,
            **kwargs,
        )

        if isinstance(text, str):
            text = [text]
        elif not isinstance(text, list) and not isinstance(text[0], str):
            raise ValueError("Invalid input text. Please provide a string, or a list of strings")

        image_inputs = {}
        if images is not None:
            batched_images = make_nested_list_of_images(images)
            image_inputs = self.image_processor(batched_images, **output_kwargs["images_kwargs"])

            # Create empty text to be replaced with placeholders
            if not text:
                text = [" ".join([self.boi_token] * len(images)) for images in batched_images]

            if len(batched_images) != len(text):
                raise ValueError(
                    f"Received inconsistently sized batches of images ({len(batched_images)}) and text ({len(text)})."
                )

            # Replace image tokens by the full expanded sequence
            batch_num_crops = to_py_obj(image_inputs.pop("num_crops"))
            text_with_crops = text
            for batch_idx, (prompt, images, num_crops) in enumerate(
                zip(text, batched_images, batch_num_crops)
            ):
                image_indexes = [m.start() for m in re.finditer(self.boi_token, prompt)]

                if len(images) != len(image_indexes):
                    raise ValueError(
                        f"Prompt contained {len(image_indexes)} image tokens but received {len(images)} images."
                    )

                # Insert additional image tokens for Pan-and-Scan crops
                for num, idx in reversed(list(zip(num_crops, image_indexes))):
                    if num:
                        formatted_image_text = (
                            f"Here is the original image {self.boi_token} and here are some crops to help you see better "
                            + " ".join([self.boi_token] * num)
                        )
                        prompt = (
                            prompt[:idx]
                            + formatted_image_text
                            + prompt[idx + len(self.boi_token) :]
                        )
                        text_with_crops[batch_idx] = prompt

        # Expand placeholder image tokens to the full image token sequence
        text = [prompt.replace(self.boi_token, self.full_image_sequence) for prompt in text]

        return_tensors = output_kwargs["text_kwargs"].pop("return_tensors", None)
        text_inputs = self.tokenizer(text=text, **output_kwargs["text_kwargs"], return_tensors="np")
        # print(f"processing, text_inputs:{text_inputs}")

        # Add token type ids manually, as tokenizer can't do arbitrary position token types
        array_ids = np.array(text_inputs["input_ids"])
        mm_token_type_ids = np.zeros_like(text_inputs["input_ids"])
        mm_token_type_ids[array_ids == self.image_token_id] = 1

        text_inputs = {k: v.tolist() for k, v in text_inputs.items()}  # in case user requested list inputs
        text_inputs["token_type_ids"] = mm_token_type_ids.tolist()
        return BatchFeature(data={**text_inputs, **image_inputs}, tensor_type=return_tensors)

    # Copied from transformers.models.clip.processing_clip.CLIPProcessor.batch_decode with CLIP->Gemma
    def batch_decode(self, *args, **kwargs):
        """
        This method forwards all its arguments to GemmaTokenizerFast's [`~PreTrainedTokenizer.batch_decode`].
        Please refer to the docstring of this method for more information.
        """
        return self.tokenizer.batch_decode(*args, **kwargs)

    # Copied from transformers.models.clip.processing_clip.CLIPProcessor.decode with CLIP->Gemma
    def decode(self, *args, **kwargs):
        """
        This method forwards all its arguments to GemmaTokenizerFast's [`~PreTrainedTokenizer.decode`].
        Please refer to the docstring of this method for more information.
        """
        return self.tokenizer.decode(*args, **kwargs)

    @property
    def model_input_names(self):
        tokenizer_input_names = self.tokenizer.model_input_names + ["token_type_ids"]
        image_processor_input_names = self.image_processor.model_input_names
        return list(dict.fromkeys(tokenizer_input_names + image_processor_input_names))
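To make the token expansion above concrete, here is a small standalone sketch of how a single image placeholder in a prompt becomes the full image sequence (begin-of-image token, image_seq_length image tokens, end-of-image token). The token strings and the small image_seq_length are illustrative placeholders, not values read from a real Gemma 3 tokenizer:

# Illustrative placeholder tokens; the real strings come from the tokenizer.
boi_token = "<start_of_image>"
eoi_token = "<end_of_image>"
image_token = "<img>"
image_seq_length = 4  # the processor default is 256

full_image_sequence = f"\n\n{boi_token}{image_token * image_seq_length}{eoi_token}\n\n"
prompt = f"Describe this picture: {boi_token}"
expanded = prompt.replace(boi_token, full_image_sequence)
# The single placeholder is now a boi/eoi-wrapped run of image tokens.
assert expanded.count(image_token) == image_seq_length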
class Gemma3ImageProcessor(BaseImageProcessor):
    r"""
    Constructs a SigLIP image processor.

    Args:
        do_resize (`bool`, *optional*, defaults to `True`):
            Whether to resize the image's (height, width) dimensions to the specified `size`. Can be
            overridden by `do_resize` in the `preprocess` method.
        size (`Dict[str, int]`, *optional*, defaults to `{"height": 224, "width": 224}`):
            Size of the image after resizing. Can be overridden by `size` in the `preprocess` method.
        resample (`PILImageResampling`, *optional*, defaults to `Resampling.BILINEAR`):
            Resampling filter to use if resizing the image. Can be overridden by `resample` in the
            `preprocess` method.
        do_rescale (`bool`, *optional*, defaults to `True`):
            Whether to rescale the image by the specified scale `rescale_factor`. Can be overridden by
            `do_rescale` in the `preprocess` method.
        rescale_factor (`int` or `float`, *optional*, defaults to `1/255`):
            Scale factor to use if rescaling the image. Can be overridden by `rescale_factor` in the
            `preprocess` method.
        do_normalize (`bool`, *optional*, defaults to `True`):
            Whether to normalize the image by the specified mean and standard deviation. Can be overridden
            by `do_normalize` in the `preprocess` method.
        image_mean (`float` or `List[float]`, *optional*, defaults to `[0.5, 0.5, 0.5]`):
            Mean to use if normalizing the image. This is a float or list of floats the length of the number
            of channels in the image. Can be overridden by the `image_mean` parameter in the `preprocess` method.
        image_std (`float` or `List[float]`, *optional*, defaults to `[0.5, 0.5, 0.5]`):
            Standard deviation to use if normalizing the image. This is a float or list of floats the length
            of the number of channels in the image. Can be overridden by the `image_std` parameter in the
            `preprocess` method.
        do_convert_rgb (`bool`, *optional*, defaults to `True`):
            Whether to convert the image to RGB.
        do_pan_and_scan (`bool`, *optional*):
            Whether to apply `pan_and_scan` to images.
        pan_and_scan_min_crop_size (`int`, *optional*):
            Minimum size of each crop in pan and scan.
        pan_and_scan_max_num_crops (`int`, *optional*):
            Maximum number of crops per image in pan and scan.
        pan_and_scan_min_ratio_to_activate (`float`, *optional*):
            Minimum aspect ratio to activate pan and scan.
    """

    model_input_names = ["pixel_values", "num_crops"]

    def __init__(
        self,
        do_resize: bool = True,
        size: Dict[str, int] = None,
        resample: PILImageResampling = PILImageResampling.BILINEAR,
        do_rescale: bool = True,
        rescale_factor: Union[int, float] = 1 / 255,
        do_normalize: bool = True,
        image_mean: Optional[Union[float, List[float]]] = None,
        image_std: Optional[Union[float, List[float]]] = None,
        do_convert_rgb: bool = None,
        do_pan_and_scan: bool = None,
        pan_and_scan_min_crop_size: int = None,
        pan_and_scan_max_num_crops: int = None,
        pan_and_scan_min_ratio_to_activate: float = None,
        **kwargs,
    ) -> None:
        super().__init__(**kwargs)
        size = size if size is not None else {"height": 224, "width": 224}
        size = get_size_dict(size, default_to_square=True)
        image_mean = image_mean if image_mean is not None else IMAGENET_STANDARD_MEAN
        image_std = image_std if image_std is not None else IMAGENET_STANDARD_STD

        self.do_resize = do_resize
        self.size = size
        self.resample = resample
        self.do_rescale = do_rescale
        self.rescale_factor = rescale_factor
        self.do_normalize = do_normalize
        self.image_mean = image_mean
        self.image_std = image_std
        self.do_convert_rgb = do_convert_rgb
        self.do_pan_and_scan = do_pan_and_scan
        self.pan_and_scan_min_crop_size = pan_and_scan_min_crop_size
        self.pan_and_scan_max_num_crops = pan_and_scan_max_num_crops
        self.pan_and_scan_min_ratio_to_activate = pan_and_scan_min_ratio_to_activate

    def pan_and_scan(
        self,
        image: np.ndarray,
        pan_and_scan_min_crop_size: int,
        pan_and_scan_max_num_crops: int,
        pan_and_scan_min_ratio_to_activate: float,
        data_format: Optional[Union[str, ChannelDimension]] = None,
        input_data_format: Optional[Union[str, ChannelDimension]] = None,
    ):
        """
        Pan and Scan an image by cropping it into smaller images when the aspect ratio exceeds the
        minimum allowed ratio.

        Args:
            image (`np.ndarray`): Image to resize.
            pan_and_scan_min_crop_size (`int`, *optional*): Minimum size of each crop in pan and scan.
            pan_and_scan_max_num_crops (`int`, *optional*): Maximum number of crops per image in pan and scan.
            pan_and_scan_min_ratio_to_activate (`float`, *optional*): Minimum aspect ratio to activate pan and scan.
            data_format (`str` or `ChannelDimension`, *optional*):
                The channel dimension format of the image. If not provided, it will be the same as the input image.
            input_data_format (`ChannelDimension` or `str`, *optional*):
                The channel dimension format of the input image. If not provided, it will be inferred.
        """
        height, width = get_image_size(image)

        # Square or landscape image.
        if width >= height:
            # Only apply PaS if the image is sufficiently exaggerated
            if width / height < pan_and_scan_min_ratio_to_activate:
                return []

            # Select ideal number of crops close to the image aspect ratio and such that crop_size > min_crop_size.
            num_crops_w = int(math.floor(width / height + 0.5))  # Half round up rounding.
            num_crops_w = min(int(math.floor(width / pan_and_scan_min_crop_size)), num_crops_w)

            # Make sure the number of crops is in range [2, pan_and_scan_max_num_crops].
            num_crops_w = max(2, num_crops_w)
            num_crops_w = min(pan_and_scan_max_num_crops, num_crops_w)
            num_crops_h = 1

        # Portrait image.
        else:
            # Only apply PaS if the image is sufficiently exaggerated
            if height / width < pan_and_scan_min_ratio_to_activate:
                return []

            # Select ideal number of crops close to the image aspect ratio and such that crop_size > min_crop_size.
            num_crops_h = int(math.floor(height / width + 0.5))
            num_crops_h = min(int(math.floor(height / pan_and_scan_min_crop_size)), num_crops_h)

            # Make sure the number of crops is in range [2, pan_and_scan_max_num_crops].
            num_crops_h = max(2, num_crops_h)
            num_crops_h = min(pan_and_scan_max_num_crops, num_crops_h)
            num_crops_w = 1

        crop_size_w = int(math.ceil(width / num_crops_w))
        crop_size_h = int(math.ceil(height / num_crops_h))

        # Don't apply PaS if crop size is too small.
        if min(crop_size_w, crop_size_h) < pan_and_scan_min_crop_size:
            return []

        crop_positions_w = [crop_size_w * i for i in range(num_crops_w)]
        crop_positions_h = [crop_size_h * i for i in range(num_crops_h)]

        if input_data_format == ChannelDimension.LAST:
            image_crops = [
                image[pos_h : pos_h + crop_size_h, pos_w : pos_w + crop_size_w]
                for pos_h, pos_w in itertools.product(crop_positions_h, crop_positions_w)
            ]
        else:
            image_crops = [
                image[:, pos_h : pos_h + crop_size_h, pos_w : pos_w + crop_size_w]
                for pos_h, pos_w in itertools.product(crop_positions_h, crop_positions_w)
            ]

        return image_crops
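As a concrete check of the crop-count arithmetic above, here is a small standalone sketch mirroring the landscape branch (it does not call the class) for a 512x1024 image with the processor defaults of min crop size 256, max 4 crops, and activation ratio 1.2:

import math

height, width = 512, 1024                                      # aspect ratio 2.0 >= 1.2, so PaS activates
num_crops_w = int(math.floor(width / height + 0.5))            # round(2.0) -> 2
num_crops_w = min(int(math.floor(width / 256)), num_crops_w)   # min(4, 2) -> 2
num_crops_w = min(4, max(2, num_crops_w))                      # clamp to [2, 4] -> 2
num_crops_h = 1
crop_size_w = int(math.ceil(width / num_crops_w))              # 512
crop_size_h = int(math.ceil(height / num_crops_h))             # 512
assert min(crop_size_w, crop_size_h) >= 256                    # crops are large enough, so they are kept
print(num_crops_w * num_crops_h)                               # 2 crops in addition to the original image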
    def _process_images_for_pan_and_scan(
        self,
        images: List[np.ndarray],
        do_pan_and_scan: bool,
        pan_and_scan_min_crop_size: int,
        pan_and_scan_max_num_crops: int,
        pan_and_scan_min_ratio_to_activate: float,
        data_format: Optional[Union[str, ChannelDimension]] = None,
        input_data_format: Optional[Union[str, ChannelDimension]] = None,
    ):
        pas_images_list = []
        num_crops = []
        for image in images:
            pas_images = self.pan_and_scan(
                image=image,
                pan_and_scan_min_crop_size=pan_and_scan_min_crop_size,
                pan_and_scan_max_num_crops=pan_and_scan_max_num_crops,
                pan_and_scan_min_ratio_to_activate=pan_and_scan_min_ratio_to_activate,
                data_format=data_format,
                input_data_format=input_data_format,
            )
            pas_images_list.extend([image] + pas_images)
            num_crops.append(len(pas_images))
        return pas_images_list, num_crops

    @filter_out_non_signature_kwargs()
    def preprocess(
        self,
        images: ImageInput,
        do_resize: bool = None,
        size: Dict[str, int] = None,
        resample: PILImageResampling = None,
        do_rescale: bool = None,
        rescale_factor: float = None,
        do_normalize: bool = None,
        image_mean: Optional[Union[float, List[float]]] = None,
        image_std: Optional[Union[float, List[float]]] = None,
        return_tensors: Optional[Union[str, TensorType]] = None,
        data_format: Optional[ChannelDimension] = ChannelDimension.FIRST,
        input_data_format: Optional[Union[str, ChannelDimension]] = None,
        do_convert_rgb: bool = None,
        do_pan_and_scan: bool = None,
        pan_and_scan_min_crop_size: int = None,
        pan_and_scan_max_num_crops: int = None,
        pan_and_scan_min_ratio_to_activate: float = None,
    ) -> PIL.Image.Image:
        """
        Preprocess an image or batch of images.

        Args:
            images (`ImageInput`):
                Image to preprocess. Expects a single or batch of images with pixel values ranging from 0 to 255.
                If passing in images with pixel values between 0 and 1, set `do_rescale=False`.
            do_resize (`bool`, *optional*, defaults to `self.do_resize`): Whether to resize the image.
            size (`Dict[str, int]`, *optional*, defaults to `self.size`): Size of the image after resizing.
            resample (`int`, *optional*, defaults to `self.resample`):
                Resampling filter to use if resizing the image, one of the enum `PILImageResampling`. Only has
                an effect if `do_resize` is set to `True`.
            do_rescale (`bool`, *optional*, defaults to `self.do_rescale`): Whether to rescale the image.
            rescale_factor (`float`, *optional*, defaults to `self.rescale_factor`):
                Rescale factor to rescale the image by if `do_rescale` is set to `True`.
            do_normalize (`bool`, *optional*, defaults to `self.do_normalize`): Whether to normalize the image.
            image_mean (`float` or `List[float]`, *optional*, defaults to `self.image_mean`):
                Image mean to use for normalization. Only has an effect if `do_normalize` is set to `True`.
            image_std (`float` or `List[float]`, *optional*, defaults to `self.image_std`):
                Image standard deviation to use for normalization. Only has an effect if `do_normalize` is `True`.
            return_tensors (`str` or `TensorType`, *optional*):
                The type of tensors to return. Unset returns a list of `np.ndarray`; `'tf'`, `'pt'`, `'np'` and
                `'jax'` return the corresponding framework batch type.
            data_format (`ChannelDimension` or `str`, *optional*, defaults to `ChannelDimension.FIRST`):
                The channel dimension format for the output image (`"channels_first"` or `"channels_last"`).
                Unset uses the channel dimension format of the input image.
            input_data_format (`ChannelDimension` or `str`, *optional*):
                The channel dimension format for the input image. If unset, it is inferred from the input image.
            do_convert_rgb (`bool`, *optional*, defaults to `self.do_convert_rgb`):
                Whether to convert the image to RGB.
            do_pan_and_scan (`bool`, *optional*, defaults to `self.do_pan_and_scan`):
                Whether to apply `pan_and_scan` to images.
            pan_and_scan_min_crop_size (`int`, *optional*, defaults to `self.pan_and_scan_min_crop_size`):
                Minimum size of each crop in pan and scan.
            pan_and_scan_max_num_crops (`int`, *optional*, defaults to `self.pan_and_scan_max_num_crops`):
                Maximum number of crops per image in pan and scan.
            pan_and_scan_min_ratio_to_activate (`float`, *optional*, defaults to `self.pan_and_scan_min_ratio_to_activate`):
                Minimum aspect ratio to activate pan and scan.
        """
        do_resize = do_resize if do_resize is not None else self.do_resize
        size = size if size is not None else self.size
        size = get_size_dict(size, param_name="size", default_to_square=False)
        resample = resample if resample is not None else self.resample
        do_rescale = do_rescale if do_rescale is not None else self.do_rescale
        rescale_factor = (
            rescale_factor if rescale_factor is not None else self.rescale_factor
        )
        do_normalize = do_normalize if do_normalize is not None else self.do_normalize
        image_mean = image_mean if image_mean is not None else self.image_mean
        image_std = image_std if image_std is not None else self.image_std
        do_convert_rgb = (
            do_convert_rgb if do_convert_rgb is not None else self.do_convert_rgb
        )
        do_pan_and_scan = (
            do_pan_and_scan if do_pan_and_scan is not None else self.do_pan_and_scan
        )
        pan_and_scan_min_crop_size = (
            pan_and_scan_min_crop_size
            if pan_and_scan_min_crop_size is not None
            else self.pan_and_scan_min_crop_size
        )
        pan_and_scan_max_num_crops = (
            pan_and_scan_max_num_crops
            if pan_and_scan_max_num_crops is not None
            else self.pan_and_scan_max_num_crops
        )
        pan_and_scan_min_ratio_to_activate = (
            pan_and_scan_min_ratio_to_activate
            if pan_and_scan_min_ratio_to_activate is not None
            else self.pan_and_scan_min_ratio_to_activate
        )

        images_list = make_nested_list_of_images(images)

        if not valid_images(images_list[0]):
            raise ValueError(
                "Invalid image type. Must be of type PIL.Image.Image, numpy.ndarray, "
                "torch.Tensor, tf.Tensor or jax.ndarray."
            )

        validate_preprocess_arguments(
            do_rescale=do_rescale,
            rescale_factor=rescale_factor,
            do_normalize=do_normalize,
            image_mean=image_mean,
            image_std=image_std,
            do_resize=do_resize,
            size=size,
            resample=resample,
        )

        if do_convert_rgb:
            images_list = [[convert_to_rgb(image) for image in images] for images in images_list]

        # All transformations expect numpy arrays.
        images_list = [[to_numpy_array(image) for image in images] for images in images_list]

        if do_rescale and is_scaled_image(images_list[0][0]):
            logger.warning(
                "It looks like you are trying to rescale already rescaled images. If the input"
                " images have pixel values between 0 and 1, set `do_rescale=False` to avoid rescaling them again."
            )

        if input_data_format is None:
            # We assume that all images have the same channel dimension format.
            input_data_format = infer_channel_dimension_format(images_list[0][0])

        if do_pan_and_scan:
            images_list_and_num_crops = [
                self._process_images_for_pan_and_scan(
                    images=images,
                    do_pan_and_scan=do_pan_and_scan,
                    pan_and_scan_min_crop_size=pan_and_scan_min_crop_size,
                    pan_and_scan_max_num_crops=pan_and_scan_max_num_crops,
                    pan_and_scan_min_ratio_to_activate=pan_and_scan_min_ratio_to_activate,
                    data_format=data_format,
                    input_data_format=input_data_format,
                )
                for images in images_list
            ]
            images_list = [images for images, _ in images_list_and_num_crops]
            num_crops = [num_crops for _, num_crops in images_list_and_num_crops]
        else:
            num_crops = [[0] for images in images_list]

        processed_images = []
        for images in images_list:
            for image in images:
                if do_resize:
                    height, width = size["height"], size["width"]
                    image = resize(
                        image=image,
                        size=(height, width),
                        resample=resample,
                        input_data_format=input_data_format,
                    )

                if do_rescale:
                    image = rescale(
                        image=image,
                        scale=rescale_factor,
                        input_data_format=input_data_format,
                    )

                if do_normalize:
                    image = normalize(
                        image=image,
                        mean=image_mean,
                        std=image_std,
                        input_data_format=input_data_format,
                    )

                image = to_channel_dimension_format(image, data_format, input_channel_dim=input_data_format)
                processed_images.append(image)

        data = {"pixel_values": processed_images, "num_crops": num_crops}
        return BatchFeature(data=data, tensor_type=return_tensors)
class Gemma3TextConfig(PretrainedConfig):
    r"""
    This is the configuration class to store the configuration of a [`Gemma3TextModel`]. It is used to
    instantiate a Gemma3Text model according to the specified arguments, defining the model architecture.
    Instantiating a configuration with the defaults will yield a similar configuration to that of the
    Gemma3Text-7B, e.g. [google/gemma3_text-7b](https://huggingface.co/google/gemma3_text-7b).

    Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs.
    Read the documentation from [`PretrainedConfig`] for more information.

    Args:
        vocab_size (`int`, *optional*, defaults to 262208):
            Vocabulary size of the Gemma3Text model; the number of different tokens that can be represented
            by the `inputs_ids` passed when calling [`Gemma3TextModel`].
        hidden_size (`int`, *optional*, defaults to 2304):
            Dimension of the hidden representations.
        intermediate_size (`int`, *optional*, defaults to 9216):
            Dimension of the MLP representations.
        num_hidden_layers (`int`, *optional*, defaults to 26):
            Number of hidden layers in the Transformer decoder.
        num_attention_heads (`int`, *optional*, defaults to 8):
            Number of attention heads for each attention layer in the Transformer decoder.
        num_key_value_heads (`int`, *optional*, defaults to 4):
            Number of key/value heads used to implement Grouped Query Attention. If equal to
            `num_attention_heads`, the model uses MHA; if 1, MQA; otherwise GQA. When converting a multi-head
            checkpoint to a GQA checkpoint, each group's key and value head should be constructed by
            mean-pooling the original heads within that group (see https://arxiv.org/pdf/2305.13245.pdf).
            Defaults to `num_attention_heads` if not specified.
        head_dim (`int`, *optional*, defaults to 256):
            The attention head dimension.
        hidden_activation (`str` or `function`, *optional*, defaults to `"gelu_pytorch_tanh"`):
            The non-linear activation function in the decoder; `"gelu_pytorch_tanh"` uses an approximation of
            the `"gelu"` activation function.
        max_position_embeddings (`int`, *optional*, defaults to 131072):
            The maximum sequence length that this model might ever be used with.
        initializer_range (`float`, *optional*, defaults to 0.02):
            The standard deviation of the truncated_normal_initializer for initializing all weight matrices.
        rms_norm_eps (`float`, *optional*, defaults to 1e-06):
            The epsilon used by the rms normalization layers.
        use_cache (`bool`, *optional*, defaults to `True`):
            Whether or not the model should return the last key/values attentions (not used by all models).
            Only relevant if `config.is_decoder=True`.
        pad_token_id (`int`, *optional*, defaults to 0): Padding token id.
        eos_token_id (`int`, *optional*, defaults to 1): End of stream token id.
        bos_token_id (`int`, *optional*, defaults to 2): Beginning of stream token id.
        tie_word_embeddings (`bool`, *optional*, defaults to `True`): Whether to tie weight embeddings.
        rope_theta (`float`, *optional*, defaults to 1000000.0): The base period of the RoPE embeddings.
        attention_bias (`bool`, *optional*, defaults to `False`):
            Whether to use a bias in the query, key, value and output projection layers during self-attention.
        attention_dropout (`float`, *optional*, defaults to 0.0):
            The dropout ratio for the attention probabilities.
        query_pre_attn_scalar (`float`, *optional*, defaults to 256):
            Scaling factor used on the attention scores.
        sliding_window (`int`, *optional*, defaults to 4096):
            In Gemma3Text, every other layer uses sliding window attention. This is the size of the sliding window.
        final_logit_softcapping (`float`, *optional*):
            Scaling factor when applying tanh softcapping on the logits.
        attn_logit_softcapping (`float`, *optional*):
            Scaling factor when applying tanh softcapping on the attention scores.
        cache_implementation (`str`, *optional*, defaults to `"hybrid"`):
            The cache type to be used with `generate`.
        rope_scaling (`Dict`, *optional*):
            Dictionary containing the scaling configuration for the RoPE embeddings used in global attention.
            `rope_type` can be one of ['default', 'linear', 'dynamic', 'yarn', 'longrope', 'llama3'], with
            type-specific fields such as `factor`, `original_max_position_embeddings`, `attention_factor`,
            `beta_fast`, `beta_slow`, `short_factor`, `long_factor`, `low_freq_factor` and `high_freq_factor`.
            NOTE: if you apply a new rope type and expect the model to work on a longer
            `max_position_embeddings`, update this value accordingly.
        rope_local_base_freq (float, *optional*, defaults to 10000.0):
            The base period of the RoPE embeddings for local attention.
        sliding_window_pattern (`int`, *optional*, defaults to 6):
            Pattern for the sliding window attention.

    ```python
    >>> from transformers import Gemma3TextModel, Gemma3TextConfig

    >>> # Initializing a Gemma3Text gemma3_text-7b style configuration
    >>> configuration = Gemma3TextConfig()

    >>> # Initializing a model from the gemma3_text-7b style configuration
    >>> model = Gemma3TextModel(configuration)

    >>> # Accessing the model configuration
    >>> configuration = model.config
    ```
    """

    model_type = "gemma3_text"
    keys_to_ignore_at_inference = ["past_key_values"]
    base_model_tp_plan = {
        "layers.*.self_attn.q_proj": "colwise",
        "layers.*.self_attn.k_proj": "colwise",
        "layers.*.self_attn.v_proj": "colwise",
        "layers.*.self_attn.o_proj": "rowwise",
        "layers.*.mlp.gate_proj": "colwise",
        "layers.*.mlp.up_proj": "colwise",
        "layers.*.mlp.down_proj": "rowwise",
    }
    base_model_pp_plan = {
        "embed_tokens": (["input_ids"], ["inputs_embeds"]),
        "layers": (["hidden_states", "attention_mask"], ["hidden_states"]),
        "norm": (["hidden_states"], ["hidden_states"]),
    }

    def __init__(
        self,
        vocab_size=262_208,
        hidden_size=2304,
        intermediate_size=9216,
        num_hidden_layers=26,
        num_attention_heads=8,
        num_key_value_heads=4,
        head_dim=256,
        hidden_activation="gelu_pytorch_tanh",
        max_position_embeddings=131_072,
        initializer_range=0.02,
        rms_norm_eps=1e-6,
        use_cache=True,
        pad_token_id=0,
        eos_token_id=1,
        bos_token_id=2,
        tie_word_embeddings=True,
        rope_theta=1_000_000.0,
        attention_bias=False,
        attention_dropout=0.0,
        query_pre_attn_scalar=256,
        sliding_window=4096,
        final_logit_softcapping=None,
        attn_logit_softcapping=None,
        cache_implementation="hybrid",
        rope_scaling=None,
        rope_local_base_freq=10_000.0,
        sliding_window_pattern=6,
        **kwargs,
    ):
        super().__init__(
            pad_token_id=pad_token_id,
            bos_token_id=bos_token_id,
            eos_token_id=eos_token_id,
            tie_word_embeddings=tie_word_embeddings,
            **kwargs,
        )
        self.vocab_size = vocab_size
        self.max_position_embeddings = max_position_embeddings
        self.hidden_size = hidden_size
        self.intermediate_size = intermediate_size
        self.num_hidden_layers = num_hidden_layers
        self.num_attention_heads = num_attention_heads
        self.head_dim = head_dim
        self.num_key_value_heads = num_key_value_heads
        self.initializer_range = initializer_range
        self.rms_norm_eps = rms_norm_eps
        self.use_cache = use_cache
        self.rope_theta = rope_theta
        self.attention_bias = attention_bias
        self.attention_dropout = attention_dropout
        self.hidden_activation = hidden_activation
        self.query_pre_attn_scalar = query_pre_attn_scalar
        self.sliding_window = sliding_window
        self.final_logit_softcapping = final_logit_softcapping
        self.attn_logit_softcapping = attn_logit_softcapping
        self.cache_implementation = cache_implementation

        self.rope_local_base_freq = rope_local_base_freq
        # For configuring HybridCache to work with 5:1 attention pattern
        self.sliding_window_pattern = sliding_window_pattern
        self.rope_scaling = rope_scaling
        rope_config_validation(self)
class Gemma3Config(PretrainedConfig):
    r"""
    This is the configuration class to store the configuration of a [`Gemma3ForConditionalGeneration`]. It is
    used to instantiate a Gemma3ForConditionalGeneration according to the specified arguments, defining the
    model architecture. Instantiating a configuration with the defaults will yield a similar configuration to
    that of the PaliGemma-2B, e.g. [google/gemma-3-4b](https://huggingface.co/google/gemma-3-4b).

    Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs.
    Read the documentation from [`PretrainedConfig`] for more information.

    Args:
        text_config (`Union[Gemma3TextConfig, dict]`, *optional*):
            The config object of the text backbone.
        vision_config (`Union[AutoConfig, dict]`, *optional*):
            Custom vision config or dict.
        mm_tokens_per_image (`int`, *optional*, defaults to 256):
            The number of tokens per image embedding.
        boi_token_index (`int`, *optional*, defaults to 255999):
            The begin-of-image token index to wrap the image prompt.
        eoi_token_index (`int`, *optional*, defaults to 256000):
            The end-of-image token index to wrap the image prompt.
        image_token_index (`int`, *optional*, defaults to 262144):
            The image token index to encode the image prompt.
        initializer_range (`float`, *optional*, defaults to 0.02):
            The standard deviation of the truncated_normal_initializer for initializing all weight matrices.

    Example:

    ```python
    >>> from transformers import Gemma3ForConditionalGeneration, Gemma3Config, SiglipVisionConfig, Gemma3TextConfig

    >>> # Initializing a Siglip-like vision config
    >>> vision_config = SiglipVisionConfig()

    >>> # Initializing a Gemma3 Text config
    >>> text_config = Gemma3TextConfig()

    >>> # Initializing a Gemma3 gemma-3-4b style configuration
    >>> configuration = Gemma3Config(vision_config, text_config)

    >>> # Initializing a model from the gemma-3-4b style configuration
    >>> model = Gemma3ForConditionalGeneration(configuration)

    >>> # Accessing the model configuration
    >>> configuration = model.config
    ```"""

    model_type = "gemma3"
    sub_configs = {
        "text_config": Gemma3TextConfig,
        "vision_config": SiglipVisionConfig,
    }

    def __init__(
        self,
        text_config: Optional[Gemma3TextConfig] = None,
        vision_config: Optional[SiglipVisionConfig] = None,
        mm_tokens_per_image: int = 256,
        boi_token_index: int = 255_999,
        eoi_token_index: int = 256_000,
        image_token_index: int = 262_144,
        initializer_range: float = 0.02,
        **kwargs,
    ):
        if text_config is None:
            text_config = Gemma3TextConfig()
            # logger.info(
            #     "text_config is None, using default Gemma3TextConfig config."
            # )
        elif isinstance(text_config, dict):
            text_config = Gemma3TextConfig(**text_config)

        if isinstance(vision_config, dict):
            vision_config = SiglipVisionConfig(**vision_config)
        elif isinstance(vision_config, SiglipVisionConfig):
            pass
        else:
            # logger.info(
            #     "vision_config is None or incompatible with Gemma3VisionConfig initialization. Gemma3 will be"
            #     " limited to text tasks."
            # )
            # logger.info(f"vision_config: {vision_config}")
            vision_config = SiglipVisionConfig()

        self.text_config = text_config
        self.vision_config = vision_config
        self.mm_tokens_per_image = mm_tokens_per_image
        self.boi_token_index = boi_token_index
        self.eoi_token_index = eoi_token_index
        self.image_token_index = image_token_index
        self.initializer_range = initializer_range

        super().__init__(**kwargs)


AutoProcessor.register(config_class=Gemma3Config, processor_class=Gemma3Processor, exist_ok=True)
AutoImageProcessor.register(
    config_class=Gemma3Config,
    image_processor_class=None,
    slow_image_processor_class=Gemma3ImageProcessor,
    fast_image_processor_class=None,
    exist_ok=True,
)
python/sglang/srt/configs/model_config.py  (view @ 9d02bb3e)

@@ -391,9 +391,13 @@ def _get_and_verify_dtype(
     dtype = dtype.lower()
     if dtype == "auto":
         if config_dtype == torch.float32:
-            if config.model_type == "gemma2":
+            if config.model_type.startswith("gemma"):
+                if config.model_type == "gemma":
+                    gemma_version = ""
+                else:
+                    gemma_version = config.model_type[5]
                 logger.info(
-                    "For Gemma 2, we downcast float32 to bfloat16 instead "
+                    f"For Gemma {gemma_version}, we downcast float32 to bfloat16 instead "
                     "of float16 by default. Please specify `dtype` if you "
                     "want to use float16."
                 )

@@ -453,6 +457,7 @@ multimodal_model_archs = [
     "LlavaQwenForCausalLM",
     "LlavaMistralForCausalLM",
     "LlavaVidForCausalLM",
+    "Gemma3ForConditionalGeneration",
     "Grok1VForCausalLM",
     "Grok1AForCausalLM",
     "MllamaForConditionalGeneration",
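The version string in the log message above is just the character after the "gemma" prefix of model_type; a minimal sketch of that indexing:

# config.model_type[5] is the character right after the five-letter "gemma" prefix.
for model_type in ["gemma", "gemma2", "gemma3"]:
    gemma_version = "" if model_type == "gemma" else model_type[5]
    print(f"For Gemma {gemma_version}, we downcast float32 to bfloat16 instead of float16 by default.")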
python/sglang/srt/conversation.py  (view @ 9d02bb3e)

@@ -45,6 +45,7 @@ class SeparatorStyle(IntEnum):
     DEEPSEEK_CHAT = auto()
     METAMATH = auto()
     QWEN2_VL_EMBED = auto()
+    GEMMA3 = auto()


 @dataclasses.dataclass

@@ -285,6 +286,18 @@ class Conversation:
             else:
                 ret += role + ":"
             return ret
+        elif self.sep_style == SeparatorStyle.GEMMA3:
+            ret = system_prompt
+            for i, (role, message) in enumerate(self.messages):
+                if message:
+                    if i == 0:
+                        ret += message + self.sep
+                    else:
+                        ret += role + message + self.sep
+                else:
+                    ret += role
+            return ret
         else:
             raise ValueError(f"Invalid style: {self.sep_style}")

@@ -604,6 +617,20 @@ register_conv_template(
     )
 )

+# Reference: https://huggingface.co/google/gemma-3-4b-it/blob/main/config.json
+register_conv_template(
+    Conversation(
+        name="gemma-it",
+        system_message="You are a helpful assistant.",
+        system_template="<bos><start_of_turn>user{system_message}\n\n",
+        roles=("<start_of_turn>user\n", "<start_of_turn>model\n"),
+        sep="<end_of_turn>\n",
+        sep_style=SeparatorStyle.GEMMA3,
+        stop_str=["<end_of_turn>"],
+        image_token="<start_of_image>",
+    )
+)
+
 # Reference: https://huggingface.co/Alibaba-NLP/gme-Qwen2-VL-2B-Instruct#usage
 register_conv_template(
     Conversation(
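To see what the gemma-it template produces, here is a standalone sketch that mirrors the SeparatorStyle.GEMMA3 branch above with the template strings from this registration. It does not import sglang, and the way the system prompt is assembled from system_template is an assumption inferred from the fields shown here:

system_template = "<bos><start_of_turn>user{system_message}\n\n"
roles = ("<start_of_turn>user\n", "<start_of_turn>model\n")
sep = "<end_of_turn>\n"

def render(messages, system_message="You are a helpful assistant."):
    # Mirrors the GEMMA3 branch: the first message is appended directly to the system
    # block (which already opened the user turn); later turns add their role tag.
    ret = system_template.format(system_message=system_message)
    for i, (role, message) in enumerate(messages):
        if message:
            ret += (message if i == 0 else role + message) + sep
        else:
            ret += role
    return ret

print(render([(roles[0], "What is in this image? <start_of_image>"), (roles[1], None)]))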
python/sglang/srt/hf_transformers_utils.py  (view @ 9d02bb3e)

@@ -34,6 +34,8 @@ from sglang.srt.configs import (
     ChatGLMConfig,
     DbrxConfig,
     ExaoneConfig,
+    Gemma3Config,
+    Gemma3TextConfig,
     MultiModalityConfig,
     Qwen2_5_VLConfig,
 )

@@ -46,6 +48,8 @@ _CONFIG_REGISTRY: Dict[str, Type[PretrainedConfig]] = {
     ExaoneConfig.model_type: ExaoneConfig,
     Qwen2_5_VLConfig.model_type: Qwen2_5_VLConfig,
     MultiModalityConfig.model_type: MultiModalityConfig,
+    Gemma3Config.model_type: Gemma3Config,
+    Gemma3TextConfig.model_type: Gemma3TextConfig,
 }

 for name, cls in _CONFIG_REGISTRY.items():
python/sglang/srt/layers/attention/vision.py  (view @ 9d02bb3e)

@@ -19,34 +19,10 @@ from sglang.srt.layers.linear import (
     RowParallelLinear,
 )
 from sglang.srt.layers.quantization import QuantizationConfig
+from sglang.srt.layers.rotary_embedding import apply_rotary_pos_emb, rotate_half
 from sglang.srt.utils import add_prefix


-# Copied from transformers, modeling_qwen2_vl.py
-def rotate_half(x):
-    """Rotates half the hidden dims of the input."""
-    x1 = x[..., : x.shape[-1] // 2]
-    x2 = x[..., x.shape[-1] // 2 :]
-    return torch.cat((-x2, x1), dim=-1)
-
-
-def apply_rotary_pos_emb_vision(
-    q: torch.Tensor, k: torch.Tensor, cos: torch.Tensor, sin: torch.Tensor
-) -> Tuple[torch.Tensor, torch.Tensor]:
-    orig_q_dtype = q.dtype
-    orig_k_dtype = k.dtype
-    q, k = q.float(), k.float()
-    cos, sin = cos.unsqueeze(-2).float(), sin.unsqueeze(-2).float()
-    q_embed = (q * cos) + (rotate_half(q) * sin)
-    k_embed = (k * cos) + (rotate_half(k) * sin)
-    q_embed = q_embed.to(orig_q_dtype)
-    k_embed = k_embed.to(orig_k_dtype)
-    return q_embed, k_embed
-
-
 class VisionAttention(nn.Module):
     r"""
     Multi-headed attention without any cache, mostly used for ViT.

@@ -168,7 +144,7 @@ class VisionAttention(nn.Module):
             cos, sin = position_embeddings
             original_shape = q.shape
             q, k = q.view(s, head, -1), k.view(s, head, -1)
-            q, k = apply_rotary_pos_emb_vision(q, k, cos, sin)
+            q, k = apply_rotary_pos_emb(q, k, cos, sin)
             q, k = q.reshape(original_shape), k.reshape(original_shape)

         if self.use_qkv_parallel:
python/sglang/srt/layers/layernorm.py  (view @ 9d02bb3e)

@@ -119,6 +119,26 @@ class GemmaRMSNorm(CustomOp):
         return out


+class Gemma3RMSNorm(nn.Module):
+    def __init__(self, dim: int, eps: float = 1e-6):
+        super().__init__()
+        self.eps = eps
+        self.weight = nn.Parameter(torch.zeros(dim))
+
+    def _norm(self, x):
+        return x * torch.rsqrt(x.pow(2).mean(-1, keepdim=True) + self.eps)
+
+    def forward(self, x):
+        output = self._norm(x.float())
+        # Llama does x.to(float16) * w whilst Gemma3 is (x * w).to(float16)
+        # See https://github.com/huggingface/transformers/pull/29402
+        output = output * (1.0 + self.weight.float())
+        return output.type_as(x)
+
+    def extra_repr(self):
+        return f"{tuple(self.weight.shape)}, eps={self.eps}"
+
+
 if not _is_cuda:
     logger.info(
         "sgl-kernel is not available on Non-NV platforms. Fallback to other kernel libraries."
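The comment in Gemma3RMSNorm.forward is the whole point of the class: the weight multiply happens in float32 and only the result is cast back to the input dtype. A small sketch, with illustrative shapes only, isolates how the two cast orderings can diverge in half precision:

import torch

x = torch.randn(4, 8, dtype=torch.bfloat16)
w = torch.randn(8)
rms = torch.rsqrt(x.float().pow(2).mean(-1, keepdim=True) + 1e-6)
normed = x.float() * rms

# Gemma3 ordering: multiply in float32, then downcast the product.
gemma_style = (normed * (1.0 + w.float())).to(torch.bfloat16)
# Llama-like ordering: downcast first, then multiply in bfloat16.
llama_style = normed.to(torch.bfloat16) * (1.0 + w).to(torch.bfloat16)

print((gemma_style - llama_style).abs().max())  # usually a small but nonzero rounding difference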
python/sglang/srt/layers/rotary_embedding.py  (view @ 9d02bb3e)

@@ -1173,6 +1173,37 @@ def get_rope(
     return rotary_emb


+# Copied from transformers
+def rotate_half(x):
+    """Rotates half the hidden dims of the input."""
+    x1 = x[..., : x.shape[-1] // 2]
+    x2 = x[..., x.shape[-1] // 2 :]
+    return torch.cat((-x2, x1), dim=-1)
+
+
+def apply_rotary_pos_emb(
+    q: torch.Tensor,
+    k: torch.Tensor,
+    cos: torch.Tensor,
+    sin: torch.Tensor,
+    unsqueeze_dim=1,
+) -> Tuple[torch.Tensor, torch.Tensor]:
+    orig_q_dtype = q.dtype
+    orig_k_dtype = k.dtype
+    q, k = q.float(), k.float()
+
+    # embedding is performed in float
+    cos = cos.unsqueeze(unsqueeze_dim).float()
+    sin = sin.unsqueeze(unsqueeze_dim).float()
+
+    q_embed = (q * cos) + (rotate_half(q) * sin)
+    k_embed = (k * cos) + (rotate_half(k) * sin)
+
+    q_embed = q_embed.to(orig_q_dtype)
+    k_embed = k_embed.to(orig_k_dtype)
+
+    return q_embed, k_embed
+
+
 def get_rope_cpu(
     head_size: int,
     rotary_dim: int,
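A quick shape check for the helper added above. The (batch, heads, seq, head_dim) layout and the way cos/sin are built here are illustrative (they follow the usual transformers convention), not necessarily how sglang constructs them internally:

import torch
# Assumes this sglang revision, where apply_rotary_pos_emb lives in rotary_embedding.
from sglang.srt.layers.rotary_embedding import apply_rotary_pos_emb

batch, heads, seq, head_dim = 2, 4, 16, 64
q = torch.randn(batch, heads, seq, head_dim, dtype=torch.bfloat16)
k = torch.randn(batch, heads, seq, head_dim, dtype=torch.bfloat16)

# cos/sin come from the rotary embedding; unsqueeze_dim=1 broadcasts them over the head axis.
angles = torch.randn(batch, seq, head_dim)
cos, sin = angles.cos(), angles.sin()

q_embed, k_embed = apply_rotary_pos_emb(q, k, cos, sin, unsqueeze_dim=1)
assert q_embed.shape == q.shape and q_embed.dtype == q.dtype  # output keeps the input dtype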
python/sglang/srt/managers/image_processors/base_image_processor.py  (view @ 9d02bb3e)

@@ -111,7 +111,7 @@ class BaseImageProcessor(ABC):
     def load_images(
         self,
-        input_ids: list,
+        input_ids: list[int],
         image_data,
         image_token: str,
         max_req_input_len: int,

@@ -122,22 +122,21 @@
         Each frame of video/image will be replaced by a single image token

         Args:
             discard_alpha_channel: if True, discards the alpha channel in the returned images
         """
-        image_hashes, image_sizes = [], []
-        all_frames = []
-        new_text_parts = []
         if isinstance(input_ids, list) and return_text:
             assert len(input_ids) and isinstance(input_ids[0], int)
             input_text = self._processor.tokenizer.decode(input_ids)
         else:
             input_text = input_ids
-        if return_text:
-            text_parts = input_text.split(image_token)
+        import re
+
+        pattern = "(" + "|".join(re.escape(sep) for sep in [image_token]) + ")"
+        # split text into list of normal text and special tokens
+        text_parts = re.split(pattern, input_text)

         # TODO(mick): load from server_args, env, or sampling_params
         MAX_NUM_FRAMES = 30

@@ -145,53 +144,65 @@
         total_frame_count = sum(estimated_frames_list)
         # a heuristic value, suggesting the maximum fraction of frames to embed from all visual inputs.
         # e.g., 0.1 suggests that 1 frame out of 10 input frames should be used
-        scaling_factor = min(1.0, MAX_NUM_FRAMES / total_frame_count)
+        scaling_factor = min(1.0, MAX_NUM_FRAMES / max(1, total_frame_count))

         assert len(image_data) == len(estimated_frames_list)

-        # Process each input with allocated frames
-        for image_index, (image, estimated_frames) in enumerate(
-            zip(image_data, estimated_frames_list)
-        ):
-            if len(all_frames) >= MAX_NUM_FRAMES:
-                max_frames_to_process = 0
-            else:
-                max_frames_to_process = max(1, int(estimated_frames * scaling_factor))
-            if max_frames_to_process == 0:
-                frames = []
-            else:
-                try:
-                    if isinstance(image, str) and image.startswith("video:"):
-                        path = image[len("video:") :]
-                        frames = BaseImageProcessor.encode_video(
-                            path, frame_count_limit=max_frames_to_process
-                        )
-                    else:
-                        raw_image, _size = load_image(image)
-                        if discard_alpha_channel:
-                            raw_image = raw_image.convert("RGB")
-                        frames = [raw_image]
-                    assert len(frames) != 0
-                except FileNotFoundError as e:
-                    print(e)
-                    return None
-            image_sizes += [frames[0].size] * len(frames)
-            image_hashes += [hash(image)] * len(frames)
-            all_frames += frames
-            if return_text:
-                new_text_parts.append(text_parts[image_index])
-                if max_frames_to_process != 0:
-                    new_text_parts.append(image_token * len(frames))
-            assert max_frames_to_process >= len(frames)
-        if return_text:
-            new_text_parts.append(text_parts[-1])
-        input_text = "".join(new_text_parts)
+        image_index, audio_index = 0, 0
+        hashes, image_sizes, images, audios = [], [], [], []
+        new_text = ""
+        for index, text_part in enumerate(text_parts):
+            try:
+                if text_part == image_token:
+                    # load as image
+                    frames_to_process = estimated_frames_list[image_index]
+                    if frames_to_process == 0:
+                        frames = []
+                    else:
+                        image_file = image_data[image_index]
+                        if isinstance(image_file, str) and image_file.startswith("video:"):
+                            # video
+                            path = image_file[len("video:") :]
+                            frames = self.encode_video(path, frame_count_limit=frames_to_process)
+                        else:
+                            # image
+                            raw_image, _size = load_image(image_file)
+                            if discard_alpha_channel:
+                                raw_image = raw_image.convert("RGB")
+                            frames = [raw_image]
+                        if len(frames) == 0:
+                            continue
+                    image_sizes += frames[0].size * len(frames)
+                    hashes += [hash(image_file)] * len(frames)
+                    images += frames
+                    image_index += 1
+                    if frames_to_process != 0:
+                        new_text += image_token * len(frames)
+                    assert frames_to_process == len(frames)
+                else:
+                    # TODO(mick): handle video
+                    # normal text
+                    new_text += text_part
+            except Exception as e:
+                import openai
+
+                logger.error(f"An exception occurred while loading images: {e}")
+                raise BadRequestError(f"An exception occurred while loading images: {e}")
+                continue
         return BaseImageProcessorOutput(
-            image_hashes, image_sizes, all_frames, input_text
+            image_hashes=hashes,
+            image_sizes=image_sizes,
+            all_frames=images,
+            input_text=new_text,
         )
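The regex split introduced above keeps the image token itself as a list element (unlike str.split), which is what lets the new loop interleave normal text and image placeholders. A standalone check:

import re

image_token = "<image_placeholder>"
pattern = "(" + "|".join(re.escape(sep) for sep in [image_token]) + ")"
text = "Compare <image_placeholder> with <image_placeholder> please."
print(re.split(pattern, text))
# ['Compare ', '<image_placeholder>', ' with ', '<image_placeholder>', ' please.']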
python/sglang/srt/managers/image_processors/gemma3.py
0 → 100644
View file @
9d02bb3e
import asyncio
from typing import List, Union

from transformers.utils import logging

from sglang.srt.managers.image_processor import (
    BaseImageProcessor as SGLangBaseImageProcessor,
)
from sglang.srt.managers.image_processors.base_image_processor import (
    get_global_processor,
)
from sglang.srt.models.gemma3_mm import Gemma3ForConditionalGeneration

# Copied from: https://github.com/huggingface/transformers/blob/main/src/transformers/models/gemma3/image_processing_gemma3_fast.py
# will be removed in the future

logger = logging.get_logger(__name__)


class Gemma3SGLangImageProcessor(SGLangBaseImageProcessor):
    def __init__(self, hf_config, server_args, _processor):
        super().__init__(hf_config, server_args, _processor)
        self.IMAGE_TOKEN = "<start_of_image>"
        self.IM_START_TOKEN_ID = hf_config.boi_token_index
        self.IM_END_TOKEN_ID = hf_config.eoi_token_index

    @staticmethod
    def _process_images_task(images, input_text, _hf_config):
        if isinstance(images, list) and len(images) == 0:
            images = None
        processor = get_global_processor()
        result = processor.__call__(
            text=[input_text],
            images=images,
            padding=True,
            return_tensors="pt",
            # if RGBA, this needs to be set
            # images_kwargs={
            #     "input_data_format": ChannelDimension.FIRST
            # }
        )
        pixel_values = getattr(result, "pixel_values", None)

        return {
            "input_ids": result.input_ids,
            "pixel_values": pixel_values,
        }

    async def _process_images(self, images, input_text) -> dict:
        if self.executor is not None:
            loop = asyncio.get_event_loop()
            return await loop.run_in_executor(
                self.executor,
                Gemma3SGLangImageProcessor._process_images_task,
                images,
                input_text,
                self.hf_config,
            )
        else:
            return self._process_images_task(images, input_text, self.hf_config)

    async def process_images_async(
        self,
        image_data: List[Union[str, bytes]],
        input_ids,
        request_obj,
        max_req_input_len,
        *args,
        **kwargs,
    ):
        if not image_data:
            return None
        if isinstance(image_data, str):
            image_data = [image_data]

        image_token = self.IMAGE_TOKEN
        base_output = self.load_images(
            input_ids=input_ids,
            image_data=image_data,
            image_token=image_token,
            max_req_input_len=max_req_input_len,
            discard_alpha_channel=True,
        )

        ret = await self._process_images(
            input_text=base_output.input_text, images=base_output.all_frames
        )

        return {
            "input_ids": ret["input_ids"].flatten().tolist(),
            "pixel_values": ret["pixel_values"],
            "image_hashes": base_output.image_hashes,
            "im_start_id": self.IM_START_TOKEN_ID,
            "im_end_id": self.IM_END_TOKEN_ID,
        }


ImageProcessorMapping = {
    Gemma3ForConditionalGeneration: Gemma3SGLangImageProcessor,
}
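For reference, `_process_images_task` above simply forwards the prompt (with `<start_of_image>` placeholders already expanded by `load_images`) and the decoded frames to the shared HuggingFace processor. A rough offline sketch of that call, with an illustrative checkpoint name and image file; the server itself obtains the processor through `get_global_processor()`:

from PIL import Image
from transformers import AutoProcessor

# Example checkpoint and image path, not taken from this diff.
processor = AutoProcessor.from_pretrained("google/gemma-3-4b-it")
image = Image.open("cat.png").convert("RGB")
text = "Describe this picture: <start_of_image>"

result = processor(text=[text], images=[image], padding=True, return_tensors="pt")
print(result.input_ids.shape)     # flattened into a plain token list by the wrapper above
print(result.pixel_values.shape)  # handed to the vision tower at forward time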
python/sglang/srt/managers/image_processors/janus_pro.py
View file @
9d02bb3e
...
...
@@ -60,7 +60,10 @@ class JanusProProcessor(SGLangBaseImageProcessor):
            image_data = [image_data]

        base_out = self.load_images(
            input_ids, image_data, "<image_placeholder>", max_req_input_len
            input_ids=input_ids,
            image_data=image_data,
            image_token="<image_placeholder>",
            max_req_input_len=max_req_input_len,
        )
        images = base_out.all_frames
        res = await self._process_images(images=images, input_text=base_out.input_text)
...
...
python/sglang/srt/managers/image_processors/minicpmv.py
View file @
9d02bb3e
...
...
@@ -52,7 +52,10 @@ class MiniCPMVImageProcessor(BaseImageProcessor):
            image_data = [image_data]

        base_output = self.load_images(
            input_ids, image_data, self.IMAGE_TOKEN, max_req_input_len
            input_ids=input_ids,
            image_data=image_data,
            image_token=self.IMAGE_TOKEN,
            max_req_input_len=max_req_input_len,
        )
        if base_output is None:
            return None
...
...
python/sglang/srt/managers/image_processors/qwen_vl.py
View file @
9d02bb3e
...
...
@@ -72,10 +72,10 @@ class Qwen2_5VLImageProcessor(BaseImageProcessor):
        image_token = self.IMAGE_TOKEN
        base_output = self.load_images(
            input_ids,
            image_data,
            image_token,
            max_req_input_len,
            input_ids=input_ids,
            image_data=image_data,
            image_token=image_token,
            max_req_input_len=max_req_input_len,
        )

def smart_resize(
...
...
python/sglang/srt/managers/schedule_batch.py
View file @
9d02bb3e
...
...
@@ -49,7 +49,7 @@ from sglang.srt.model_executor.forward_batch_info import CaptureHiddenMode, Forw
from sglang.srt.sampling.sampling_batch_info import SamplingBatchInfo
from sglang.srt.sampling.sampling_params import SamplingParams
from sglang.srt.server_args import ServerArgs
from sglang.srt.utils import get_compiler_backend, next_power_of_2
from sglang.srt.utils import get_compiler_backend

if TYPE_CHECKING:
    from sglang.srt.speculative.eagle_utils import EagleDraftInput, EagleVerifyInput
...
...
@@ -207,6 +207,9 @@ class ImageInputs:
        return ret

    def merge(self, other):
        """
        merge image inputs when requests are being merged
        """
        assert self.pixel_values.shape[1:] == other.pixel_values.shape[1:]
        self.pixel_values = np.concatenate([self.pixel_values, other.pixel_values])
...
...
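`ImageInputs.merge` only concatenates `pixel_values` along the batch axis, so the per-image dimensions (everything after axis 0) of the two requests must already agree. A tiny numpy illustration of that contract, with an example resolution:

import numpy as np

a = np.zeros((2, 3, 896, 896))  # request 1: two images (resolution chosen for illustration)
b = np.zeros((1, 3, 896, 896))  # request 2: one image
assert a.shape[1:] == b.shape[1:]
merged = np.concatenate([a, b])
print(merged.shape)             # (3, 3, 896, 896)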
python/sglang/srt/model_executor/forward_batch_info.py
View file @
9d02bb3e
...
...
@@ -33,6 +33,7 @@ from dataclasses import dataclass
from enum import IntEnum, auto
from typing import TYPE_CHECKING, List, Optional, Union

import numpy as np
import torch
import triton
import triton.language as tl
...
...
@@ -331,6 +332,32 @@ class ForwardBatch:
        return ret

    def get_merged_image_inputs(self) -> Optional[ImageInputs]:
        """
        Merge all image inputs in the batch into a single ImageInputs object.

        Returns:
            if none, current batch contains no image input
        """
        if not self.image_inputs or all(x is None for x in self.image_inputs):
            return None

        # Filter out None values
        valid_inputs = [x for x in self.image_inputs if x is not None]

        # Start with the first valid image input
        merged = valid_inputs[0]

        # Merge remaining inputs
        for img_input in valid_inputs[1:]:
            merged.merge(img_input)

        if isinstance(merged.pixel_values, np.ndarray):
            merged.pixel_values = torch.from_numpy(merged.pixel_values)

        return merged

    def _compute_mrope_positions(
        self, model_runner: ModelRunner, batch: ModelWorkerBatch
    ):
...
...
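`get_merged_image_inputs` builds on the per-request `merge` above: it drops requests without images, folds the remaining `ImageInputs` into the first one, and converts the stacked numpy array to a torch tensor. A self-contained sketch of that fold, using a stand-in class rather than the real `ImageInputs`:

from dataclasses import dataclass
from typing import List, Optional

import numpy as np
import torch


@dataclass
class FakeImageInputs:
    pixel_values: np.ndarray

    def merge(self, other: "FakeImageInputs") -> None:
        # stack the other request's images after ours
        self.pixel_values = np.concatenate([self.pixel_values, other.pixel_values])


def merge_batch(image_inputs: List[Optional[FakeImageInputs]]):
    if not image_inputs or all(x is None for x in image_inputs):
        return None
    valid = [x for x in image_inputs if x is not None]
    merged = valid[0]
    for item in valid[1:]:
        merged.merge(item)
    if isinstance(merged.pixel_values, np.ndarray):
        merged.pixel_values = torch.from_numpy(merged.pixel_values)
    return merged


batch = [FakeImageInputs(np.zeros((1, 3, 4, 4))), None, FakeImageInputs(np.zeros((2, 3, 4, 4)))]
print(merge_batch(batch).pixel_values.shape)  # torch.Size([3, 3, 4, 4])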
python/sglang/srt/models/gemma3_causal.py
0 → 100644
View file @
9d02bb3e
# Copyright 2025 SGLang Team
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
import copy
from typing import Iterable, Optional, Set, Tuple

import einops
import torch
import torch.nn.functional as F
from torch import nn
from transformers import (
    ROPE_INIT_FUNCTIONS,
    AutoModel,
    PretrainedConfig,
    PreTrainedModel,
)

from sglang.srt.configs.gemma3 import Gemma3TextConfig
from sglang.srt.distributed import get_tensor_model_parallel_world_size
from sglang.srt.layers.activation import GeluAndMul
from sglang.srt.layers.layernorm import Gemma3RMSNorm
from sglang.srt.layers.linear import (
    MergedColumnParallelLinear,
    QKVParallelLinear,
    RowParallelLinear,
)
from sglang.srt.layers.logits_processor import LogitsProcessor
from sglang.srt.layers.quantization.base_config import QuantizationConfig
from sglang.srt.layers.radix_attention import RadixAttention
from sglang.srt.layers.rotary_embedding import apply_rotary_pos_emb, get_rope
from sglang.srt.layers.vocab_parallel_embedding import (
    ParallelLMHead,
    VocabParallelEmbedding,
)
from sglang.srt.model_executor.forward_batch_info import ForwardBatch
from sglang.srt.model_loader.weight_utils import (
    default_weight_loader,
    maybe_remap_kv_scale_name,
)
from sglang.srt.utils import add_prefix, make_layers


# Adapted from:
# https://github.com/vllm-project/vllm/blob/main/vllm/model_executor/models/gemma3.py
def extract_layer_index(prefix: str) -> int:
    """Extract the layer index from a prefix string."""
    parts = prefix.split(".")
    for part in parts:
        if part.startswith("layers."):
            layer_str = part.split(".")[-1]
            try:
                return int(layer_str)
            except ValueError:
                continue
    return -1


class Gemma3MLP(nn.Module):
    def __init__(
        self,
        hidden_size: int,
        intermediate_size: int,
        hidden_activation: str,
        quant_config: Optional[QuantizationConfig] = None,
        prefix: str = "",
    ) -> None:
        super().__init__()
        self.gate_up_proj = MergedColumnParallelLinear(
            hidden_size,
            [intermediate_size] * 2,
            bias=False,
            quant_config=quant_config,
            prefix=add_prefix("gate_up_proj", prefix),
        )
        self.down_proj = RowParallelLinear(
            intermediate_size,
            hidden_size,
            bias=False,
            quant_config=quant_config,
            prefix=add_prefix("down_proj", prefix),
        )
        if hidden_activation != "gelu_pytorch_tanh":
            raise ValueError(
                "Gemma3 uses `gelu_pytorch_tanh` as the hidden activation "
                "function. Please set `hidden_activation` to "
                "`gelu_pytorch_tanh`."
            )
        self.act_fn = GeluAndMul()

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        gate_up, _ = self.gate_up_proj(x)
        x = self.act_fn(gate_up)
        x, _ = self.down_proj(x)
        return x


class Gemma3Attention(nn.Module):
    def __init__(
        self,
        layer_id: int,
        config: Gemma3TextConfig,
        max_position_embeddings: int,
        quant_config: Optional[QuantizationConfig] = None,
        prefix: str = "",
    ) -> None:
        super().__init__()
        self.layer_id = layer_id
        self.config = config
        tp_size = get_tensor_model_parallel_world_size()
        self.total_num_heads = config.num_attention_heads
        assert self.total_num_heads % tp_size == 0
        self.num_heads = self.total_num_heads // tp_size
        self.total_num_kv_heads = config.num_key_value_heads
        self.num_kv_heads = max(1, self.total_num_kv_heads // tp_size)
        if self.total_num_kv_heads >= tp_size:
            # Number of KV heads is greater than TP size, so we partition
            # the KV heads across multiple tensor parallel GPUs.
            assert self.total_num_kv_heads % tp_size == 0
        else:
            # Number of KV heads is less than TP size, so we replicate
            # the KV heads across multiple tensor parallel GPUs.
            assert tp_size % self.total_num_kv_heads == 0
        hidden_size = config.hidden_size
        head_dim = getattr(
            config, "head_dim", hidden_size // config.num_attention_heads
        )
        self.head_dim = head_dim
        self.q_size = self.num_heads * self.head_dim
        self.kv_size = self.num_kv_heads * self.head_dim
        self.scaling = config.query_pre_attn_scalar**-0.5

        self.qkv_proj = QKVParallelLinear(
            hidden_size,
            self.head_dim,
            self.total_num_heads,
            self.total_num_kv_heads,
            bias=config.attention_bias,
            quant_config=quant_config,
            prefix=add_prefix("qkv_proj", prefix),
        )
        self.o_proj = RowParallelLinear(
            self.total_num_heads * self.head_dim,
            hidden_size,
            bias=config.attention_bias,
            quant_config=quant_config,
            prefix=add_prefix("o_proj", prefix),
        )

        # Determine if layer uses sliding window based on pattern
        self.is_sliding = bool((layer_id + 1) % config.sliding_window_pattern)

        # Initialize the rotary embedding.
        if self.is_sliding:
            # Local attention. Override the values in config.json.
            self.rope_theta = config.rope_local_base_freq
            self.rope_scaling = {"rope_type": "default"}
            # FIXME(mick): idk why vllm does this
            # self.sliding_window = config.interleaved_sliding_window
            self.sliding_window = config.sliding_window
        else:
            # Global attention. Use the values in config.json.
            self.rope_theta = config.rope_theta
            self.rope_scaling = config.rope_scaling
            self.sliding_window = None

        self.attn = RadixAttention(
            self.num_heads,
            self.head_dim,
            self.scaling,
            num_kv_heads=self.num_kv_heads,
            layer_id=layer_id,
            logit_cap=getattr(self.config, "attn_logit_softcapping", None),
            sliding_window_size=self.sliding_window,
            prefix=add_prefix("attn", prefix),
        )

        # Gemma3 adds normalization for q and k
        self.q_norm = Gemma3RMSNorm(dim=config.head_dim, eps=config.rms_norm_eps)
        self.k_norm = Gemma3RMSNorm(dim=config.head_dim, eps=config.rms_norm_eps)

    def naive_attn_with_masks(
        self,
        q: torch.Tensor,
        k: torch.Tensor,
        v: torch.Tensor,
        out: torch.Tensor,
        **kwargs,
    ) -> torch.Tensor:
        q = q.view(-1, self.num_heads, self.head_dim)
        # Expand the key and value to handle GQA.
        num_queries_per_kv = self.num_heads // self.num_kv_heads
        k = k.view(-1, self.num_kv_heads, self.head_dim)
        k = k.repeat_interleave(num_queries_per_kv, dim=-2)
        v = v.view(-1, self.num_kv_heads, self.head_dim)
        v = v.repeat_interleave(num_queries_per_kv, dim=-2)

        if self.is_sliding:
            attn_masks = kwargs["local_attn_masks"]
        else:
            attn_masks = kwargs["global_attn_masks"]

        seq_lens = kwargs["seq_lens"]
        start_idx = 0
        for seq_len, attn_mask in zip(seq_lens, attn_masks):
            end_idx = start_idx + seq_len
            query = q[start_idx:end_idx].unsqueeze(0)
            key = k[start_idx:end_idx].unsqueeze(0)
            value = v[start_idx:end_idx].unsqueeze(0)

            # Transpose.
            query = query.transpose(1, 2)
            key = key.transpose(1, 2)
            value = value.transpose(1, 2)

            output = F.scaled_dot_product_attention(
                query,
                key,
                value,
                attn_mask,
                self.scaling,
            )
            output = output.transpose(1, 2).flatten(-2, -1)
            out[start_idx:end_idx] = output
            start_idx = end_idx
        return out

    def forward(
        self,
        hidden_states: torch.Tensor,
        position_embeddings: Tuple[torch.Tensor, torch.Tensor],
        forward_batch: ForwardBatch,
        **kwargs,
    ) -> torch.Tensor:
        qkv, _ = self.qkv_proj(hidden_states)
        # [s, h * head_dim]
        q, k, v = qkv.split([self.q_size, self.kv_size, self.kv_size], dim=-1)

        # [s, h, head_dim]
        q = q.unflatten(-1, (self.num_heads, self.head_dim))
        # -> [h, s, head_dim]
        q = q.transpose(0, 1).unsqueeze(0)
        q = self.q_norm(q)

        k = k.unflatten(-1, (self.num_kv_heads, self.head_dim))
        # -> [h, s, head_dim]
        k = k.transpose(0, 1).unsqueeze(0)
        k = self.k_norm(k)

        # q, k = self.rotary_emb(positions, q, k)
        cos, sin = position_embeddings
        q, k = apply_rotary_pos_emb(q, k, cos, sin)

        # [b, h, s, head_dim] -> [b, s, h, head_dim]
        q = q.permute(0, 2, 1, 3)
        k = k.permute(0, 2, 1, 3)

        attn_output = self.attn(q, k, v, forward_batch=forward_batch)
        output, _ = self.o_proj(attn_output)
        return output


class Gemma3DecoderLayer(nn.Module):
    def __init__(
        self,
        layer_id: int,
        config: PretrainedConfig,
        quant_config: Optional[QuantizationConfig] = None,
        prefix: str = "",
    ) -> None:
        super().__init__()
        self.hidden_size = config.hidden_size
        self.self_attn = Gemma3Attention(
            layer_id=layer_id,
            config=config,
            max_position_embeddings=config.max_position_embeddings,
            quant_config=quant_config,
            prefix=add_prefix("self_attn", prefix),
        )
        self.hidden_size = config.hidden_size
        self.mlp = Gemma3MLP(
            hidden_size=self.hidden_size,
            intermediate_size=config.intermediate_size,
            hidden_activation=config.hidden_activation,
            quant_config=quant_config,
            prefix=add_prefix("mlp", prefix),
        )
        self.input_layernorm = Gemma3RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
        self.post_attention_layernorm = Gemma3RMSNorm(
            config.hidden_size, eps=config.rms_norm_eps
        )
        self.pre_feedforward_layernorm = Gemma3RMSNorm(
            config.hidden_size, eps=config.rms_norm_eps
        )
        self.post_feedforward_layernorm = Gemma3RMSNorm(
            config.hidden_size, eps=config.rms_norm_eps
        )
        self.is_sliding = self.self_attn.is_sliding
        self.layer_id = layer_id

    def forward(
        self,
        positions: torch.Tensor,
        hidden_states: torch.Tensor,
        position_embeddings_global: torch.Tensor,
        position_embeddings_local: torch.Tensor,
        forward_batch: ForwardBatch,
        **kwargs,
    ) -> tuple[torch.FloatTensor, Optional[tuple[torch.FloatTensor, torch.FloatTensor]]]:
        residual = hidden_states
        hidden_states = self.input_layernorm(hidden_states)

        # apply global RoPE to non-sliding layer only
        if self.self_attn.is_sliding:
            position_embeddings = position_embeddings_local
        else:
            position_embeddings = position_embeddings_global

        hidden_states = self.self_attn(
            positions=positions,
            hidden_states=hidden_states,
            position_embeddings=position_embeddings,
            forward_batch=forward_batch,
            **kwargs,
        )
        hidden_states = self.post_attention_layernorm(hidden_states)
        hidden_states = residual + hidden_states

        residual = hidden_states
        hidden_states = self.pre_feedforward_layernorm(hidden_states)
        hidden_states = self.mlp(hidden_states)
        hidden_states = self.post_feedforward_layernorm(hidden_states)
        hidden_states = residual + hidden_states

        outputs = (hidden_states,)
        return outputs


class Gemma3RotaryEmbedding(nn.Module):
    def __init__(self, config: Gemma3TextConfig, device=None):
        super().__init__()
        # BC: "rope_type" was originally "type"
        if hasattr(config, "rope_scaling") and config.rope_scaling is not None:
            self.rope_type = config.rope_scaling.get(
                "rope_type", config.rope_scaling.get("type")
            )
        else:
            self.rope_type = "default"
        self.max_seq_len_cached = config.max_position_embeddings
        self.original_max_seq_len = config.max_position_embeddings

        self.config = config
        self.rope_init_fn = ROPE_INIT_FUNCTIONS[self.rope_type]

        inv_freq, self.attention_scaling = self.rope_init_fn(self.config, device)
        self.register_buffer("inv_freq", inv_freq, persistent=False)
        self.original_inv_freq = self.inv_freq

    def _dynamic_frequency_update(self, position_ids, device):
        """
        dynamic RoPE layers should recompute `inv_freq` in the following situations:
        1 - growing beyond the cached sequence length (allow scaling)
        2 - the current sequence length is in the original scale (avoid losing precision with small sequences)
        """
        seq_len = torch.max(position_ids) + 1
        if seq_len > self.max_seq_len_cached:  # growth
            inv_freq, self.attention_scaling = self.rope_init_fn(
                self.config, device, seq_len=seq_len
            )
            self.register_buffer(
                "inv_freq", inv_freq, persistent=False
            )  # TODO joao: may break with compilation
            self.max_seq_len_cached = seq_len

        if (
            seq_len < self.original_max_seq_len
            and self.max_seq_len_cached > self.original_max_seq_len
        ):  # reset
            # This .to() is needed if the model has been moved to a device after being initialized (because
            # the buffer is automatically moved, but not the original copy)
            self.original_inv_freq = self.original_inv_freq.to(device)
            self.register_buffer("inv_freq", self.original_inv_freq, persistent=False)
            self.max_seq_len_cached = self.original_max_seq_len

    @torch.no_grad()
    def forward(self, x, position_ids):
        if "dynamic" in self.rope_type:
            self._dynamic_frequency_update(position_ids, device=x.device)

        # Core RoPE block
        inv_freq_expanded = (
            self.inv_freq[None, :, None].float().expand(position_ids.shape[0], -1, 1)
        )
        position_ids_expanded = position_ids[:, None, :].float()
        # Force float32 (see https://github.com/huggingface/transformers/pull/29285)
        device_type = x.device.type
        device_type = (
            device_type if isinstance(device_type, str) and device_type != "mps" else "cpu"
        )
        with torch.autocast(device_type=device_type, enabled=False):
            freqs = (
                inv_freq_expanded.float().to(x.device) @ position_ids_expanded.float()
            ).transpose(1, 2)
            emb = torch.cat((freqs, freqs), dim=-1)
            cos = emb.cos()
            sin = emb.sin()

        # Advanced RoPE types (e.g. yarn) apply a post-processing scaling factor, equivalent to scaling attention
        cos = cos * self.attention_scaling
        sin = sin * self.attention_scaling

        return cos.to(dtype=x.dtype), sin.to(dtype=x.dtype)


class Gemma3TextScaledWordEmbedding(nn.Embedding):
    """
    This module overrides nn.Embeddings' forward by multiplying with embeddings scale.
    """

    def __init__(
        self,
        num_embeddings: int,
        embedding_dim: int,
        padding_idx: int,
        embed_scale: Optional[float] = 1.0,
    ):
        super().__init__(num_embeddings, embedding_dim, padding_idx)
        self.embed_scale = embed_scale

    def forward(self, input_ids: torch.Tensor):
        return super().forward(input_ids) * self.embed_scale


class Gemma3TextModel(PreTrainedModel):
    def __init__(
        self,
        config: Gemma3TextConfig,
        quant_config: Optional[QuantizationConfig] = None,
        prefix: str = "",
    ) -> None:
        super().__init__(config=config)
        self.config = config
        self.quant_config = quant_config

        self.padding_idx = config.pad_token_id
        self.vocab_size = config.vocab_size

        # Gemma3 downcasts the below to float16, causing sqrt(3072)=55.4256 to become 55.5. See https://github.com/huggingface/transformers/pull/29402
        self.embed_tokens = Gemma3TextScaledWordEmbedding(
            config.vocab_size,
            config.hidden_size,
            self.padding_idx,
            embed_scale=self.config.hidden_size**0.5,
        )
        self.norm = Gemma3RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
        self.rotary_emb = Gemma3RotaryEmbedding(config=config)
        self.gradient_checkpointing = False

        # when we want to create a local RoPE layer. Config defaults should hold values for global RoPE
        config = copy.deepcopy(config)
        config.rope_theta = config.rope_local_base_freq
        config.rope_scaling = {"rope_type": "default"}
        self.rotary_emb_local = Gemma3RotaryEmbedding(config=config)

        self.layers = make_layers(
            config.num_hidden_layers,
            lambda idx, prefix: Gemma3DecoderLayer(
                layer_id=idx,
                config=config,
                quant_config=quant_config,
                prefix=prefix,
            ),
            prefix=add_prefix("layers", prefix),
        )
        self.norm = Gemma3RMSNorm(config.hidden_size, eps=config.rms_norm_eps)

        self.post_init()

    def forward(
        self,
        input_ids: torch.Tensor,
        positions: torch.Tensor,
        forward_batch: ForwardBatch,
        input_embeds: torch.Tensor = None,
        **kwargs,
    ) -> torch.Tensor:
        if input_embeds is None:
            hidden_states = self.embed_tokens(input_ids)
        else:
            hidden_states = input_embeds

        if len(positions.shape) == 1:
            positions = einops.rearrange(positions, "s -> 1 s")
        position_embeddings_global = self.rotary_emb(hidden_states, positions)
        position_embeddings_local = self.rotary_emb_local(hidden_states, positions)

        for layer in self.layers:
            layer_outputs = layer(
                positions=positions,
                position_embeddings_global=position_embeddings_global,
                position_embeddings_local=position_embeddings_local,
                hidden_states=hidden_states,
                forward_batch=forward_batch,
                **kwargs,
            )
            hidden_states = layer_outputs[0]

        hidden_states = self.norm(hidden_states)
        return hidden_states


class Gemma3ForCausalLM(PreTrainedModel):
    config_class = Gemma3TextConfig

    _tied_weights_keys = ["lm_head.weight"]
    _tp_plan = {"lm_head": "colwise_rep"}
    _pp_plan = {"lm_head": (["hidden_states"], ["logits"])}

    config_class = Gemma3TextConfig
    base_model_prefix = "language_model"

    # BitandBytes specific attributes
    default_bitsandbytes_target_modules = [
        ".gate_proj.",
        ".down_proj.",
        ".up_proj.",
        ".q_proj.",
        ".k_proj.",
        ".v_proj.",
        ".o_proj.",
    ]
    bitsandbytes_stacked_params_mapping = {
        # shard_name, weight_name, index
        "q_proj": ("qkv_proj", 0),
        "k_proj": ("qkv_proj", 1),
        "v_proj": ("qkv_proj", 2),
        "gate_proj": ("gate_up_proj", 0),
        "up_proj": ("gate_up_proj", 1),
    }

    packed_modules_mapping = {
        "qkv_proj": [
            "q_proj",
            "k_proj",
            "v_proj",
        ],
        "gate_up_proj": [
            "gate_proj",
            "up_proj",
        ],
    }

    # LoRA specific attributes
    supported_lora_modules = [
        "qkv_proj",
        "o_proj",
        "gate_up_proj",
        "down_proj",
    ]
    # Gemma does not apply LoRA to the embedding layer.
    embedding_modules = {}
    embedding_padding_modules = []
    supports_lora = True

    def __init__(
        self,
        config: Gemma3TextConfig,
        quant_config: Optional[QuantizationConfig] = None,
        prefix: str = "",
    ) -> None:
        super().__init__(config=config)
        self.config = config
        self.quant_config = quant_config
        self.model = Gemma3TextModel(
            config, quant_config, prefix=add_prefix("model", prefix)
        )
        self.logits_processor = LogitsProcessor(config)

        if self.config.tie_word_embeddings:
            self.lm_head = self.model.embed_tokens
        else:
            self.lm_head = ParallelLMHead(
                config.vocab_size,
                config.hidden_size,
                quant_config=quant_config,
                prefix=add_prefix("lm_head", prefix),
            )
        self.post_init()

    def get_input_embeddings(self):
        return self.model.embed_tokens

    def dtype(self) -> torch.dtype:
        return self.model.layers[0].mlp.gate_up_proj.weight.dtype

    @torch.no_grad()
    def forward(
        self,
        input_ids: torch.Tensor,
        positions: torch.Tensor,
        forward_batch: ForwardBatch,
        input_embeds: torch.Tensor = None,
        **kwargs,
    ) -> LogitsProcessor:
        hidden_states = self.model(
            input_ids, positions, forward_batch, input_embeds, **kwargs
        )
        return self.logits_processor(
            input_ids, hidden_states, self.model.embed_tokens, forward_batch
        )

    def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
        stacked_params_mapping = [
            # (param_name, shard_name, shard_id)
            ("qkv_proj", "q_proj", "q"),
            ("qkv_proj", "k_proj", "k"),
            ("qkv_proj", "v_proj", "v"),
            ("gate_up_proj", "gate_proj", 0),
            ("gate_up_proj", "up_proj", 1),
        ]
        params_dict = dict(self.named_parameters())
        loaded_params: Set[str] = set()
        for name, loaded_weight in weights:
            for param_name, shard_name, shard_id in stacked_params_mapping:
                # if param_name in name:
                #     print(f"{param_name} is already in {name}")
                if shard_name not in name:
                    continue
                name = name.replace(shard_name, param_name)
                # Skip loading extra bias for GPTQ models.
                if name.endswith(".bias") and name not in params_dict:
                    continue
                param = params_dict[name]
                weight_loader = param.weight_loader
                weight_loader(param, loaded_weight, shard_id)
                break
            else:
                # lm_head is not used in vllm as it is tied with embed_token.
                # To prevent errors, skip loading lm_head.weight.
                if "lm_head.weight" in name:
                    continue
                # Skip loading extra bias for GPTQ models.
                if name.endswith(".bias") and name not in params_dict:
                    continue
                # Remapping the name of FP8 kv-scale.
                name = maybe_remap_kv_scale_name(name, params_dict)
                if name is None:
                    continue
                param = params_dict[name]
                weight_loader = getattr(param, "weight_loader", default_weight_loader)
                weight_loader(param, loaded_weight)
            loaded_params.add(name)

        # unloaded_params = params_dict.keys() - loaded_params
        # if unloaded_params:
        #     logger.warning(
        #         "Some weights are not initialized from checkpoints: %s", unloaded_params
        #     )
        return loaded_params


EntryClass = Gemma3ForCausalLM
AutoModel.register(Gemma3TextConfig, Gemma3ForCausalLM, exist_ok=True)
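One detail worth calling out from `Gemma3Attention`: `is_sliding = bool((layer_id + 1) % config.sliding_window_pattern)` interleaves local sliding-window layers (which also use the local RoPE base frequency) with full global-attention layers. A quick sketch of which layers come out global, using 6 only as an example pattern value:

sliding_window_pattern = 6  # example value, not read from this diff

for layer_id in range(12):
    is_sliding = bool((layer_id + 1) % sliding_window_pattern)
    kind = "local (sliding window)" if is_sliding else "global"
    print(f"layer {layer_id:2d}: {kind}")
# layers 5 and 11 come out global; all other layers use local attention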
python/sglang/srt/models/gemma3_mm.py
0 → 100644
View file @
9d02bb3e
# Copyright 2025 SGLang Team
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
# Adapted from:
# https://github.com/vllm-project/vllm/blob/main/vllm/model_executor/models/gemma3_mm.py
import logging
from functools import lru_cache
from typing import Dict, Iterable, List, Optional, Set, Tuple, TypedDict

import torch
from torch import nn
from transformers import AutoModel, PreTrainedModel

from sglang.srt.configs import Gemma3Config
from sglang.srt.hf_transformers_utils import get_processor
from sglang.srt.layers.layernorm import Gemma3RMSNorm
from sglang.srt.layers.logits_processor import LogitsProcessor
from sglang.srt.layers.quantization.base_config import QuantizationConfig
from sglang.srt.managers.multi_modality_padding import (
    MultiModalityDataPaddingPatternTokenPairs,
)
from sglang.srt.managers.schedule_batch import ImageInputs
from sglang.srt.model_executor.forward_batch_info import ForwardBatch
from sglang.srt.model_loader.weight_utils import (
    default_weight_loader,
    maybe_remap_kv_scale_name,
)
from sglang.srt.models.gemma3_causal import Gemma3ForCausalLM
from sglang.srt.utils import add_prefix

logger = logging.getLogger(__name__)

cached_get_processor = lru_cache(get_processor)


class Gemma3ImagePixelInputs(TypedDict):
    pixel_values: torch.Tensor
    """Shape: `(batch_size * num_images, num_channels, height, width)`"""


class Gemma3MultiModalProjector(nn.Module):
    """Projector for Gemma3 multimodal."""

    def __init__(self, config: Gemma3Config):
        super().__init__()
        self.mm_input_projection_weight = nn.Parameter(
            torch.zeros(config.vision_config.hidden_size, config.text_config.hidden_size)
        )

        self.mm_soft_emb_norm = Gemma3RMSNorm(
            config.vision_config.hidden_size, eps=config.vision_config.layer_norm_eps
        )

        self.patches_per_image = int(
            config.vision_config.image_size // config.vision_config.patch_size
        )
        self.tokens_per_side = int(config.mm_tokens_per_image**0.5)
        self.kernel_size = self.patches_per_image // self.tokens_per_side
        self.avg_pool = nn.AvgPool2d(kernel_size=self.kernel_size, stride=self.kernel_size)

    def forward(self, vision_outputs: torch.Tensor) -> torch.Tensor:
        batch_size, seq_length, hidden_size = vision_outputs.shape

        # Reshape for pooling
        reshaped_vision_outputs = vision_outputs.transpose(1, 2)
        reshaped_vision_outputs = reshaped_vision_outputs.reshape(
            batch_size, hidden_size, self.patches_per_image, self.patches_per_image
        )
        reshaped_vision_outputs = reshaped_vision_outputs.contiguous()

        # Apply pooling
        pooled_vision_outputs = self.avg_pool(reshaped_vision_outputs)
        pooled_vision_outputs = pooled_vision_outputs.flatten(2)
        pooled_vision_outputs = pooled_vision_outputs.transpose(1, 2)

        # Apply normalization
        normed_vision_outputs = self.mm_soft_emb_norm(pooled_vision_outputs)

        # Project to text embedding space
        projected_vision_outputs = torch.matmul(
            normed_vision_outputs, self.mm_input_projection_weight
        )
        return projected_vision_outputs.type_as(vision_outputs)


class Gemma3ForConditionalGeneration(PreTrainedModel):
    config_class = Gemma3Config
    """Gemma3 multimodal model for conditional generation."""

    # BitandBytes specific attributes
    default_bitsandbytes_target_modules = [
        ".gate_proj.",
        ".down_proj.",
        ".up_proj.",
        ".q_proj.",
        ".k_proj.",
        ".v_proj.",
        ".o_proj.",
    ]
    bitsandbytes_stacked_params_mapping = {
        # shard_name, weight_name, index
        "q_proj": ("qkv_proj", 0),
        "k_proj": ("qkv_proj", 1),
        "v_proj": ("qkv_proj", 2),
        "gate_proj": ("gate_up_proj", 0),
        "up_proj": ("gate_up_proj", 1),
    }

    packed_modules_mapping = {
        "qkv_proj": [
            "q_proj",
            "k_proj",
            "v_proj",
        ],
        "gate_up_proj": [
            "gate_proj",
            "up_proj",
        ],
    }

    # LoRA specific attributes
    supported_lora_modules = [
        "qkv_proj",
        "o_proj",
        "gate_up_proj",
        "down_proj",
    ]
    # Gemma does not apply LoRA to the embedding layer.
    embedding_modules = {}
    embedding_padding_modules = []
    supports_lora = True

    def __init__(
        self,
        config: Gemma3Config,
        quant_config: Optional[QuantizationConfig] = None,
        prefix: str = "",
    ) -> None:
        super().__init__(config=config)
        self.config = config
        self.quant_config = quant_config

        # Vision components
        # TODO: replace with vision attention
        # self.vision_tower = SiglipVisionModel(
        #     config.vision_config,
        #     quant_config,
        #     prefix=add_prefix("vision_tower", prefix),
        # )
        self.vision_tower = AutoModel.from_config(config=config.vision_config)
        self.multi_modal_projector = Gemma3MultiModalProjector(config)
        self.vocab_size = config.text_config.vocab_size

        # Text model
        self.language_model = Gemma3ForCausalLM(
            config.text_config, quant_config, prefix=add_prefix("model", prefix)
        )
        if self.language_model.logits_processor.logit_scale:
            logit_scale = getattr(config, "logit_scale", 1.0)
            self.language_model.logits_processor.logit_scale *= logit_scale
        self.post_init()

    def pad_input_ids(self, input_ids: List[int], image_inputs: ImageInputs) -> List[int]:
        """Pad input IDs with image tokens."""
        # Get special token IDs
        im_start_id: int = image_inputs.im_start_id
        im_end_id: int = image_inputs.im_end_id

        media_token_pairs = [(im_start_id, im_end_id)]
        pattern = MultiModalityDataPaddingPatternTokenPairs(media_token_pairs)
        ids = pattern.pad_input_tokens(input_ids, image_inputs)
        return ids

    def prepare_attn_masks(
        self,
        input_ids: torch.Tensor,
        positions: torch.Tensor,
        mask_dtype: torch.dtype,
        **kwargs,
    ) -> Dict:
        """Prepare attention masks for multimodal inputs."""
        kwargs["has_images"] = True

        # Distinguish sequences by position id 0
        start_indices = (positions == 0).cpu().nonzero()
        num_seqs = len(start_indices)
        seq_lens = []

        for i in range(num_seqs):
            start_idx = start_indices[i].item()
            if i < num_seqs - 1:
                end_idx = start_indices[i + 1].item()
            else:
                end_idx = len(input_ids)
            seq_lens.append(end_idx - start_idx)

        kwargs["seq_lens"] = seq_lens

        # Create attention masks
        global_attn_masks = []
        local_attn_masks = []
        sliding_window = self.config.text_config.interleaved_sliding_window

        start_idx = 0
        for seq_len in seq_lens:
            end_idx = start_idx + seq_len
            input_token_ids = input_ids[start_idx:end_idx]
            start_idx = end_idx

            # Create global causal mask
            global_attn_mask = torch.empty(
                1,
                1,
                seq_len,
                seq_len,
                dtype=mask_dtype,
                device=input_ids.device,
            )
            global_attn_mask.fill_(float("-inf"))
            global_attn_mask = global_attn_mask.triu(diagonal=1)

            # Consider bidirectional attention between image tokens
            img_mask = torch.zeros_like(global_attn_mask)
            img_pos = input_token_ids == self.config.image_token_index
            img_mask[:, :, :, img_pos] += 1
            img_mask[:, :, img_pos, :] += 1
            global_attn_mask = torch.where(img_mask == 2, 0, global_attn_mask)
            global_attn_masks.append(global_attn_mask)

            # Create local causal mask with sliding window
            local_attn_mask = torch.ones_like(global_attn_mask)
            local_attn_mask = torch.tril(local_attn_mask, diagonal=-sliding_window)
            local_attn_mask = torch.where(
                local_attn_mask == 0, global_attn_mask, float("-inf")
            )
            local_attn_masks.append(local_attn_mask)

        kwargs["global_attn_masks"] = global_attn_masks
        kwargs["local_attn_masks"] = local_attn_masks
        return kwargs

    def get_input_embeddings(self):
        return self.language_model.get_input_embeddings()

    def get_image_features(self, pixel_values: torch.Tensor):
        """
        Projects the last hidden state from the vision model into language model space.

        Args:
            pixel_values (`torch.FloatTensor]` of shape `(batch_size, channels, height, width)`)
               The tensors corresponding to the input images.
        Returns:
            image_features (`torch.Tensor`): Image feature tensor of shape `(num_images, image_length, embed_dim)`).
        """
        pixel_values = pixel_values.to("cuda")
        pixel_values = pixel_values.to(dtype=self.language_model.dtype())

        vision_outputs = self.vision_tower(pixel_values=pixel_values).last_hidden_state
        image_features = self.multi_modal_projector(vision_outputs)
        return image_features

    def embed_image_inputs(
        self,
        input_ids: torch.Tensor,
        forward_batch: ForwardBatch,
        image_input: ImageInputs,
    ) -> torch.Tensor:
        if input_ids is None:
            raise ValueError("Unimplemented")
        # boolean-masking image tokens
        special_image_mask = torch.isin(
            input_ids,
            torch.tensor(image_input.pad_values, device=input_ids.device),
        ).unsqueeze(-1)
        num_image_tokens_in_input_ids = special_image_mask.sum()
        inputs_embeds = None
        if num_image_tokens_in_input_ids == 0:
            inputs_embeds = self.get_input_embeddings()(input_ids)
            return inputs_embeds
        else:
            # print(f"image tokens from input_ids: {inputs_embeds[special_image_mask].numel()}")
            image_features = self.get_image_features(image_input.pixel_values)

            # print(f"image tokens from image embeddings: {image_features.numel()}")
            num_image_tokens_in_embedding = (
                image_features.shape[0] * image_features.shape[1]
            )

            if num_image_tokens_in_input_ids != num_image_tokens_in_embedding:
                num_image = num_image_tokens_in_input_ids // image_features.shape[1]
                image_features = image_features[:num_image, :]
                logger.warning(
                    f"Number of images does not match number of special image tokens in the input text. "
                    f"Got {num_image_tokens_in_input_ids} image tokens in the text but {num_image_tokens_in_embedding} "
                    "tokens from image embeddings."
                )

            # Important: clamp after extracting original image boundaries
            input_ids.clamp_(min=0, max=self.vocab_size - 1)

            inputs_embeds = self.get_input_embeddings()(input_ids)

            special_image_mask = special_image_mask.expand_as(inputs_embeds).to(
                inputs_embeds.device
            )

            image_features = image_features.to(inputs_embeds.device, inputs_embeds.dtype)
            inputs_embeds = inputs_embeds.masked_scatter(
                special_image_mask, image_features
            )
        return inputs_embeds

    @torch.no_grad()
    def forward(
        self,
        input_ids: torch.LongTensor,
        positions: torch.Tensor,
        forward_batch: ForwardBatch,
        input_embeds: torch.Tensor = None,
        **kwargs: object,
    ) -> LogitsProcessor:
        r"""
        labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
            Labels for computing the masked language modeling loss. Indices should either be in `[0, ...,
            config.text_config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored
            (masked), the loss is only computed for the tokens with labels in `[0, ..., config.text_config.vocab_size]`.

        logits_to_keep (`int` or `torch.Tensor`, *optional*):
            If an `int`, compute logits for the last `logits_to_keep` tokens. If `0`, calculate logits for all
            `input_ids` (special case). Only last token logits are needed for generation, and calculating them only for that
            token can save memory, which becomes pretty significant for long sequences or large vocabulary size.
            If a `torch.Tensor`, must be 1D corresponding to the indices to keep in the sequence length dimension.
            This is useful when using packed tensor format (single dimension for batch and sequence length).

        Returns:

        Example:

        ```python
        >>> from PIL import Image
        >>> import requests
        >>> from transformers import AutoProcessor, Gemma3ForConditionalGeneration

        >>> model = Gemma3ForConditionalGeneration.from_pretrained("google/Gemma3-test-224px-hf")
        >>> processor = AutoProcessor.from_pretrained("google/Gemma3-test-224px-hf")

        >>> prompt = "answer en Where is the cow standing?"
        >>> url = "https://huggingface.co/gv-hf/Gemma3-test-224px-hf/resolve/main/cow_beach_1.png"
        >>> image = Image.open(requests.get(url, stream=True).raw)

        >>> inputs = processor(images=image, text=prompt, return_tensors="pt")

        >>> # Generate
        >>> generate_ids = model.generate(**inputs, max_length=30)
        >>> processor.batch_decode(generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0]
        "answer en Where is the cow standing?\nbeach"
        ```"""
        # Important: position_ids in Gemma3 are 1-indexed
        # This really does cost me sometime
        positions += 1

        # Replace image id with PAD if the image token if OOV, to avoid index-errors
        if input_ids is not None and self.config.image_token_index >= self.vocab_size:
            special_image_mask = input_ids == self.config.image_token_index
            llm_input_ids = input_ids.clone()
            llm_input_ids[special_image_mask] = 0
        else:
            llm_input_ids = input_ids

        merged_image_input = forward_batch.get_merged_image_inputs()
        if (
            not forward_batch.forward_mode.is_decode()
            and merged_image_input is not None
        ):
            inputs_embeds = self.embed_image_inputs(
                input_ids=llm_input_ids,
                forward_batch=forward_batch,
                image_input=merged_image_input,
            )
        else:
            llm_input_ids.clamp_(min=0, max=self.vocab_size - 1)
            inputs_embeds = self.get_input_embeddings()(llm_input_ids)

        outputs = self.language_model(
            input_ids=None,
            positions=positions,
            forward_batch=forward_batch,
            input_embeds=inputs_embeds,
            **kwargs,
        )

        return outputs

    def tie_weights(self):
        return self.language_model.tie_weights()

    def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
        """Load weights for the model."""
        params_dict = dict(self.named_parameters())
        loaded_params: Set[str] = set()

        for name, loaded_weight in weights:
            if "language_model" in name:
                # Gemma3ForCausalLM.load_weights(self, [(name.replace("language_model.", ""), loaded_weight)])
                causal_loaded_params = Gemma3ForCausalLM.load_weights(
                    self, [(name, loaded_weight)]
                )
                loaded_params.update(causal_loaded_params)
                continue
            else:
                # Skip lm_head.weight as it's tied with embed_tokens
                if "lm_head.weight" in name:
                    continue
                # Skip loading extra bias for GPTQ models
                if name.endswith(".bias") and name not in params_dict:
                    continue
                # Remapping the name of FP8 kv-scale
                name = maybe_remap_kv_scale_name(name, params_dict)
                if name is None:
                    continue

                param = params_dict[name]
                weight_loader = getattr(param, "weight_loader", default_weight_loader)
                weight_loader(param, loaded_weight)
                loaded_params.add(name)
        unloaded_params = params_dict.keys() - loaded_params
        if unloaded_params:
            pass
            # raise RuntimeError(
            #     f"Some weights are not initialized from checkpoints: {unloaded_params}")
        return loaded_params


EntryClass = Gemma3ForConditionalGeneration
AutoModel.register(Gemma3Config, Gemma3ForConditionalGeneration, exist_ok=True)
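The pooling arithmetic in `Gemma3MultiModalProjector` is easier to follow with concrete numbers. The values below (`image_size=896`, `patch_size=14`, `mm_tokens_per_image=256`) are illustrative assumptions, not read from this diff:

# Example config values, assumed for illustration only.
image_size, patch_size, mm_tokens_per_image = 896, 14, 256

patches_per_image = image_size // patch_size        # 64 patches per side
tokens_per_side = int(mm_tokens_per_image ** 0.5)   # 16
kernel_size = patches_per_image // tokens_per_side  # 4x4 average pooling

print(patches_per_image, tokens_per_side, kernel_size)  # 64 16 4
# Under these assumptions, 64*64 = 4096 vision patches are average-pooled down to
# 256 soft tokens per image, then RMS-normalized and projected into the text
# embedding space by mm_input_projection_weight.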
python/sglang/srt/utils.py
View file @
9d02bb3e
...
...
@@ -41,7 +41,6 @@ from functools import lru_cache
from importlib.metadata import PackageNotFoundError, version
from importlib.util import find_spec
from io import BytesIO
from multiprocessing import Pool
from multiprocessing.reduction import ForkingPickler
from pathlib import Path
from typing import Any, Callable, Dict, List, Optional, Protocol, Set, Tuple, Union
...
...
@@ -454,8 +453,9 @@ def load_image(image_file: Union[str, bytes]):
        image = Image.open(BytesIO(image_file))
    elif image_file.startswith("http://") or image_file.startswith("https://"):
        timeout = int(os.getenv("REQUEST_TIMEOUT", "3"))
        response = requests.get(image_file, timeout=timeout)
        image = Image.open(BytesIO(response.content))
        response = requests.get(image_file, stream=True, timeout=timeout).raw
        image = Image.open(response)
        response.close()
    elif image_file.lower().endswith(("png", "jpg", "jpeg", "webp", "gif")):
        image = Image.open(image_file)
    elif image_file.startswith("data:"):
...
...
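The `load_image` change above switches the HTTP branch from buffering `response.content` to streaming the body straight into PIL and closing it afterwards. The same pattern in isolation, with a placeholder URL and an explicit `image.load()` added here so the pixels are decoded before the stream is closed:

import requests
from PIL import Image

url = "https://example.com/cat.png"  # placeholder URL
timeout = 3

response = requests.get(url, stream=True, timeout=timeout).raw
image = Image.open(response)
image.load()      # force PIL to read the bytes before the stream is closed
response.close()
print(image.size)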