change / sglang · Commits · 9d02bb3e

Unverified commit 9d02bb3e, authored Mar 17, 2025 by Mick, committed by GitHub Mar 16, 2025

Urgent model support: support gemma-3-it (#4424)
Parent: 402db5c5

Showing 20 changed files with 2543 additions and 86 deletions (+2543 −86)
docs/references/supported_models.md  (+1 −0)
python/sglang/lang/chat_template.py  (+8 −0)
python/sglang/srt/configs/__init__.py  (+3 −0)
python/sglang/srt/configs/gemma3.py  (+1086 −0)
python/sglang/srt/configs/model_config.py  (+7 −2)
python/sglang/srt/conversation.py  (+27 −0)
python/sglang/srt/hf_transformers_utils.py  (+4 −0)
python/sglang/srt/layers/attention/vision.py  (+2 −26)
python/sglang/srt/layers/layernorm.py  (+20 −0)
python/sglang/srt/layers/rotary_embedding.py  (+31 −0)
python/sglang/srt/managers/image_processors/base_image_processor.py  (+59 −48)
python/sglang/srt/managers/image_processors/gemma3.py  (+100 −0)
python/sglang/srt/managers/image_processors/janus_pro.py  (+4 −1)
python/sglang/srt/managers/image_processors/minicpmv.py  (+4 −1)
python/sglang/srt/managers/image_processors/qwen_vl.py  (+4 −4)
python/sglang/srt/managers/schedule_batch.py  (+4 −1)
python/sglang/srt/model_executor/forward_batch_info.py  (+27 −0)
python/sglang/srt/models/gemma3_causal.py  (+687 −0)
python/sglang/srt/models/gemma3_mm.py  (+462 −0)
python/sglang/srt/utils.py  (+3 −3)
docs/references/supported_models.md

@@ -32,6 +32,7 @@
 - Phi-3-Small
 - IBM Granite 3
 - Janus-Pro-1B / Janus-Pro-7B
+- Gemma 3 (it)

 ## Embedding Models
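With this entry, Gemma 3 instruct checkpoints load like any other supported model. A minimal serving sketch using sglang's standard entry point (the checkpoint name is an illustrative example, not taken from this commit):

python -m sglang.launch_server --model-path google/gemma-3-4b-it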
python/sglang/lang/chat_template.py

@@ -520,6 +520,14 @@ def match_granite_instruct(model_path: str):
     return get_chat_template("granite-3-instruct")


+@register_chat_template_matching_function
+def match_gemma3_instruct(model_path: str):
+    model_path = model_path.lower()
+    if "gemma-3" in model_path and "1b" not in model_path:
+        # gemma-3-1b-it is a completion model
+        return get_chat_template("gemma-it")
+
+
 if __name__ == "__main__":
     messages = [
         {"role": "system", "content": None},  # None means default
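To illustrate the matching rule (hypothetical model paths, shown only to trace the branch):

match_gemma3_instruct("google/gemma-3-4b-it")   # matches -> get_chat_template("gemma-it")
match_gemma3_instruct("google/gemma-3-1b-it")   # no match: the 1B variant is a completion model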
python/sglang/srt/configs/__init__.py

 from sglang.srt.configs.chatglm import ChatGLMConfig
 from sglang.srt.configs.dbrx import DbrxConfig
 from sglang.srt.configs.exaone import ExaoneConfig
+from sglang.srt.configs.gemma3 import Gemma3Config, Gemma3TextConfig
 from sglang.srt.configs.janus_pro import MultiModalityConfig
 from sglang.srt.configs.qwen2_5_vl_config import (
     Qwen2_5_VLConfig,

@@ -14,4 +15,6 @@ __all__ = [
     "Qwen2_5_VLConfig",
     "Qwen2_5_VLVisionConfig",
     "MultiModalityConfig",
+    "Gemma3Config",
+    "Gemma3TextConfig",
 ]
python/sglang/srt/configs/gemma3.py  (new file, +1086 −0)

This diff is collapsed in the source view.
python/sglang/srt/configs/model_config.py

@@ -391,9 +391,13 @@ def _get_and_verify_dtype(
     dtype = dtype.lower()
     if dtype == "auto":
         if config_dtype == torch.float32:
-            if config.model_type == "gemma2":
+            if config.model_type.startswith("gemma"):
+                if config.model_type == "gemma":
+                    gemma_version = ""
+                else:
+                    gemma_version = config.model_type[5]
                 logger.info(
-                    "For Gemma 2, we downcast float32 to bfloat16 instead "
+                    f"For Gemma {gemma_version}, we downcast float32 to bfloat16 instead "
                     "of float16 by default. Please specify `dtype` if you "
                     "want to use float16."
                 )

@@ -453,6 +457,7 @@ multimodal_model_archs = [
     "LlavaQwenForCausalLM",
     "LlavaMistralForCausalLM",
     "LlavaVidForCausalLM",
+    "Gemma3ForConditionalGeneration",
     "Grok1VForCausalLM",
     "Grok1AForCausalLM",
     "MllamaForConditionalGeneration",
python/sglang/srt/conversation.py

@@ -45,6 +45,7 @@ class SeparatorStyle(IntEnum):
     DEEPSEEK_CHAT = auto()
     METAMATH = auto()
     QWEN2_VL_EMBED = auto()
+    GEMMA3 = auto()


 @dataclasses.dataclass

@@ -285,6 +286,18 @@ class Conversation:
             else:
                 ret += role + ":"
             return ret
+        elif self.sep_style == SeparatorStyle.GEMMA3:
+            ret = system_prompt
+            for i, (role, message) in enumerate(self.messages):
+                if message:
+                    if i == 0:
+                        ret += message + self.sep
+                    else:
+                        ret += role + message + self.sep
+                else:
+                    ret += role
+            return ret
         else:
             raise ValueError(f"Invalid style: {self.sep_style}")

@@ -604,6 +617,20 @@ register_conv_template(
     )
 )

+# Reference: https://huggingface.co/google/gemma-3-4b-it/blob/main/config.json
+register_conv_template(
+    Conversation(
+        name="gemma-it",
+        system_message="You are a helpful assistant.",
+        system_template="<bos><start_of_turn>user{system_message}\n\n",
+        roles=("<start_of_turn>user\n", "<start_of_turn>model\n"),
+        sep="<end_of_turn>\n",
+        sep_style=SeparatorStyle.GEMMA3,
+        stop_str=["<end_of_turn>"],
+        image_token="<start_of_image>",
+    )
+)
+
 # Reference: https://huggingface.co/Alibaba-NLP/gme-Qwen2-VL-2B-Instruct#usage
 register_conv_template(
     Conversation(
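For reference, a single-turn conversation rendered with this template comes out roughly as follows (hand-expanded from the fields above, not program output):

<bos><start_of_turn>userYou are a helpful assistant.

What is in this image?<end_of_turn>
<start_of_turn>model

Note the first user message is appended without its role prefix, because the system_template already opens the user turn; the trailing model role primes generation.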
python/sglang/srt/hf_transformers_utils.py

@@ -34,6 +34,8 @@ from sglang.srt.configs import (
     ChatGLMConfig,
     DbrxConfig,
     ExaoneConfig,
+    Gemma3Config,
+    Gemma3TextConfig,
     MultiModalityConfig,
     Qwen2_5_VLConfig,
 )

@@ -46,6 +48,8 @@ _CONFIG_REGISTRY: Dict[str, Type[PretrainedConfig]] = {
     ExaoneConfig.model_type: ExaoneConfig,
     Qwen2_5_VLConfig.model_type: Qwen2_5_VLConfig,
     MultiModalityConfig.model_type: MultiModalityConfig,
+    Gemma3Config.model_type: Gemma3Config,
+    Gemma3TextConfig.model_type: Gemma3TextConfig,
 }

 for name, cls in _CONFIG_REGISTRY.items():
python/sglang/srt/layers/attention/vision.py

@@ -19,34 +19,10 @@ from sglang.srt.layers.linear import (
     RowParallelLinear,
 )
 from sglang.srt.layers.quantization import QuantizationConfig
+from sglang.srt.layers.rotary_embedding import apply_rotary_pos_emb, rotate_half
 from sglang.srt.utils import add_prefix


-# Copied from transformers, modeling_qwen2_vl.py
-def rotate_half(x):
-    """Rotates half the hidden dims of the input."""
-    x1 = x[..., : x.shape[-1] // 2]
-    x2 = x[..., x.shape[-1] // 2 :]
-    return torch.cat((-x2, x1), dim=-1)
-
-
-def apply_rotary_pos_emb_vision(
-    q: torch.Tensor, k: torch.Tensor, cos: torch.Tensor, sin: torch.Tensor
-) -> Tuple[torch.Tensor, torch.Tensor]:
-    orig_q_dtype = q.dtype
-    orig_k_dtype = k.dtype
-    q, k = q.float(), k.float()
-    cos, sin = cos.unsqueeze(-2).float(), sin.unsqueeze(-2).float()
-    q_embed = (q * cos) + (rotate_half(q) * sin)
-    k_embed = (k * cos) + (rotate_half(k) * sin)
-    q_embed = q_embed.to(orig_q_dtype)
-    k_embed = k_embed.to(orig_k_dtype)
-    return q_embed, k_embed
-
-
 class VisionAttention(nn.Module):
     r"""
     Multi-headed attention without any cache, mostly used for ViT.

@@ -168,7 +144,7 @@ class VisionAttention(nn.Module):
             cos, sin = position_embeddings
             original_shape = q.shape
             q, k = q.view(s, head, -1), k.view(s, head, -1)
-            q, k = apply_rotary_pos_emb_vision(q, k, cos, sin)
+            q, k = apply_rotary_pos_emb(q, k, cos, sin)
             q, k = q.reshape(original_shape), k.reshape(original_shape)

         if self.use_qkv_parallel:
python/sglang/srt/layers/layernorm.py

@@ -119,6 +119,26 @@ class GemmaRMSNorm(CustomOp):
         return out


+class Gemma3RMSNorm(nn.Module):
+    def __init__(self, dim: int, eps: float = 1e-6):
+        super().__init__()
+        self.eps = eps
+        self.weight = nn.Parameter(torch.zeros(dim))
+
+    def _norm(self, x):
+        return x * torch.rsqrt(x.pow(2).mean(-1, keepdim=True) + self.eps)
+
+    def forward(self, x):
+        output = self._norm(x.float())
+        # Llama does x.to(float16) * w whilst Gemma3 is (x * w).to(float16)
+        # See https://github.com/huggingface/transformers/pull/29402
+        output = output * (1.0 + self.weight.float())
+        return output.type_as(x)
+
+    def extra_repr(self):
+        return f"{tuple(self.weight.shape)}, eps={self.eps}"
+
+
 if not _is_cuda:
     logger.info(
         "sgl-kernel is not available on Non-NV platforms. Fallback to other kernel libraries."
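The casting-order comment is the whole point of the new class. A small standalone sketch of the difference (assumes only torch; values are arbitrary):

import torch

x = torch.randn(4, 8, dtype=torch.bfloat16)
w = torch.randn(8)
norm = x.float() * torch.rsqrt(x.float().pow(2).mean(-1, keepdim=True) + 1e-6)

llama_style = norm.to(torch.bfloat16) * (1.0 + w).to(torch.bfloat16)  # cast, then scale
gemma3_style = (norm * (1.0 + w.float())).to(torch.bfloat16)          # scale, then cast
print((llama_style - gemma3_style).abs().max())  # typically small but nonzero rounding drift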
python/sglang/srt/layers/rotary_embedding.py

@@ -1173,6 +1173,37 @@ def get_rope(
     return rotary_emb


+# Copied from transformers
+def rotate_half(x):
+    """Rotates half the hidden dims of the input."""
+    x1 = x[..., : x.shape[-1] // 2]
+    x2 = x[..., x.shape[-1] // 2 :]
+    return torch.cat((-x2, x1), dim=-1)
+
+
+def apply_rotary_pos_emb(
+    q: torch.Tensor,
+    k: torch.Tensor,
+    cos: torch.Tensor,
+    sin: torch.Tensor,
+    unsqueeze_dim=1,
+) -> Tuple[torch.Tensor, torch.Tensor]:
+    orig_q_dtype = q.dtype
+    orig_k_dtype = k.dtype
+    q, k = q.float(), k.float()
+
+    # embedding is performed in float
+    cos = cos.unsqueeze(unsqueeze_dim).float()
+    sin = sin.unsqueeze(unsqueeze_dim).float()
+
+    q_embed = (q * cos) + (rotate_half(q) * sin)
+    k_embed = (k * cos) + (rotate_half(k) * sin)
+
+    q_embed = q_embed.to(orig_q_dtype)
+    k_embed = k_embed.to(orig_k_dtype)
+
+    return q_embed, k_embed
+
+
 def get_rope_cpu(
     head_size: int,
     rotary_dim: int,
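A shape-level usage sketch for the shared helper (tensor sizes are illustrative):

import torch

seq, heads, dim = 16, 4, 64
q = torch.randn(seq, heads, dim, dtype=torch.bfloat16)
k = torch.randn(seq, heads, dim, dtype=torch.bfloat16)
cos = torch.randn(seq, dim)
sin = torch.randn(seq, dim)

# unsqueeze_dim=1 broadcasts cos/sin across the head dimension;
# the rotation runs in float32 and the outputs keep the input dtype.
q_rot, k_rot = apply_rotary_pos_emb(q, k, cos, sin, unsqueeze_dim=1)
assert q_rot.shape == q.shape and q_rot.dtype == torch.bfloat16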
python/sglang/srt/managers/image_processors/base_image_processor.py

@@ -111,7 +111,7 @@ class BaseImageProcessor(ABC):
     def load_images(
         self,
-        input_ids: list,
+        input_ids: list[int],
         image_data,
         image_token: str,
         max_req_input_len: int,

@@ -122,22 +122,21 @@ class BaseImageProcessor(ABC):
         Each frame of video/image will be replaced by a single image token

         Args:
            discard_alpha_channel: if True, discards the alpha channel in the returned images
         """
-        image_hashes, image_sizes = [], []
-        all_frames = []
-        new_text_parts = []
         if isinstance(input_ids, list) and return_text:
             assert len(input_ids) and isinstance(input_ids[0], int)
             input_text = self._processor.tokenizer.decode(input_ids)
         else:
             input_text = input_ids
         if return_text:
-            text_parts = input_text.split(image_token)
+            import re
+
+            pattern = "(" + "|".join(re.escape(sep) for sep in [image_token]) + ")"
+            # split text into list of normal text and special tokens
+            text_parts = re.split(pattern, input_text)

         # TODO(mick): load from server_args, env, or sampling_params
         MAX_NUM_FRAMES = 30

@@ -145,53 +144,65 @@ class BaseImageProcessor(ABC):
         total_frame_count = sum(estimated_frames_list)
         # a heuristic value, suggesting the maximum fraction of frames to embed from all visual inputs.
         # e.g., 0.1 suggests that 1 frame out of 10 input frames should be used
-        scaling_factor = min(1.0, MAX_NUM_FRAMES / total_frame_count)
+        scaling_factor = min(1.0, MAX_NUM_FRAMES / max(1, total_frame_count))

         assert len(image_data) == len(estimated_frames_list)

The old per-image loop (which indexed image_data directly, split text with str.split, and bailed out with a print and `return None` on FileNotFoundError) is replaced wholesale by a loop over the split text parts. The new body, reassembled from the side-by-side view:

        hashes, image_sizes, images, audios = [], [], [], []
        new_text = ""
        image_index, audio_index = 0, 0
        for index, text_part in enumerate(text_parts):
            try:
                if text_part == image_token:
                    # load as image
                    frames_to_process = estimated_frames_list[image_index]
                    if frames_to_process == 0:
                        frames = []
                    else:
                        image_file = image_data[image_index]
                        if isinstance(image_file, str) and image_file.startswith(
                            "video:"
                        ):
                            # video
                            path = image_file[len("video:") :]
                            frames = self.encode_video(
                                path, frame_count_limit=frames_to_process
                            )
                        else:
                            # image
                            raw_image, _size = load_image(image_file)
                            if discard_alpha_channel:
                                raw_image = raw_image.convert("RGB")
                            frames = [raw_image]
                    if len(frames) == 0:
                        continue

                    image_sizes += [frames[0].size] * len(frames)
                    hashes += [hash(image_file)] * len(frames)
                    images += frames
                    image_index += 1
                    if frames_to_process != 0:
                        new_text += image_token * len(frames)
                    assert frames_to_process == len(frames)
                else:
                    # TODO(mick): handle video
                    # normal text
                    new_text += text_part
            except Exception as e:
                import openai

                logger.error(f"An exception occurred while loading images: {e}")
                raise openai.BadRequestError(
                    f"An exception occurred while loading images: {e}"
                )

        return BaseImageProcessorOutput(
            image_hashes=hashes,
            image_sizes=image_sizes,
            all_frames=images,
            input_text=new_text,
        )
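The capturing group in re.split is what keeps the image tokens themselves in the output list, letting the loop interleave text and media in order. A minimal illustration:

import re

image_token = "<start_of_image>"
pattern = "(" + "|".join(re.escape(sep) for sep in [image_token]) + ")"
print(re.split(pattern, "describe <start_of_image> briefly"))
# ['describe ', '<start_of_image>', ' briefly']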
python/sglang/srt/managers/image_processors/gemma3.py  (new file, +100 −0)

import asyncio
from typing import List, Union

from transformers.utils import logging

from sglang.srt.managers.image_processor import (
    BaseImageProcessor as SGLangBaseImageProcessor,
)
from sglang.srt.managers.image_processors.base_image_processor import (
    get_global_processor,
)
from sglang.srt.models.gemma3_mm import Gemma3ForConditionalGeneration

# Copied from: https://github.com/huggingface/transformers/blob/main/src/transformers/models/gemma3/image_processing_gemma3_fast.py
# will be removed in the future

logger = logging.get_logger(__name__)


class Gemma3SGLangImageProcessor(SGLangBaseImageProcessor):
    def __init__(self, hf_config, server_args, _processor):
        super().__init__(hf_config, server_args, _processor)
        self.IMAGE_TOKEN = "<start_of_image>"
        self.IM_START_TOKEN_ID = hf_config.boi_token_index
        self.IM_END_TOKEN_ID = hf_config.eoi_token_index

    @staticmethod
    def _process_images_task(images, input_text, _hf_config):
        if isinstance(images, list) and len(images) == 0:
            images = None
        processor = get_global_processor()
        result = processor.__call__(
            text=[input_text],
            images=images,
            padding=True,
            return_tensors="pt",
            # if RGBA, this needs to be set
            # images_kwargs={
            #     "input_data_format": ChannelDimension.FIRST
            # }
        )
        pixel_values = getattr(result, "pixel_values", None)

        return {
            "input_ids": result.input_ids,
            "pixel_values": pixel_values,
        }

    async def _process_images(self, images, input_text) -> dict:
        if self.executor is not None:
            loop = asyncio.get_event_loop()
            return await loop.run_in_executor(
                self.executor,
                Gemma3SGLangImageProcessor._process_images_task,
                images,
                input_text,
                self.hf_config,
            )
        else:
            return self._process_images_task(images, input_text, self.hf_config)

    async def process_images_async(
        self,
        image_data: List[Union[str, bytes]],
        input_ids,
        request_obj,
        max_req_input_len,
        *args,
        **kwargs,
    ):
        if not image_data:
            return None
        if isinstance(image_data, str):
            image_data = [image_data]

        image_token = self.IMAGE_TOKEN
        base_output = self.load_images(
            input_ids=input_ids,
            image_data=image_data,
            image_token=image_token,
            max_req_input_len=max_req_input_len,
            discard_alpha_channel=True,
        )

        ret = await self._process_images(
            input_text=base_output.input_text, images=base_output.all_frames
        )
        return {
            "input_ids": ret["input_ids"].flatten().tolist(),
            "pixel_values": ret["pixel_values"],
            "image_hashes": base_output.image_hashes,
            "im_start_id": self.IM_START_TOKEN_ID,
            "im_end_id": self.IM_END_TOKEN_ID,
        }


ImageProcessorMapping = {
    Gemma3ForConditionalGeneration: Gemma3SGLangImageProcessor,
}
python/sglang/srt/managers/image_processors/janus_pro.py

@@ -60,7 +60,10 @@ class JanusProProcessor(SGLangBaseImageProcessor):
             image_data = [image_data]

         base_out = self.load_images(
-            input_ids, image_data, "<image_placeholder>", max_req_input_len
+            input_ids=input_ids,
+            image_data=image_data,
+            image_token="<image_placeholder>",
+            max_req_input_len=max_req_input_len,
         )
         images = base_out.all_frames
         res = await self._process_images(images=images, input_text=base_out.input_text)
python/sglang/srt/managers/image_processors/minicpmv.py

@@ -52,7 +52,10 @@ class MiniCPMVImageProcessor(BaseImageProcessor):
             image_data = [image_data]

         base_output = self.load_images(
-            input_ids, image_data, self.IMAGE_TOKEN, max_req_input_len
+            input_ids=input_ids,
+            image_data=image_data,
+            image_token=self.IMAGE_TOKEN,
+            max_req_input_len=max_req_input_len,
         )
         if base_output is None:
             return None
python/sglang/srt/managers/image_processors/qwen_vl.py

@@ -72,10 +72,10 @@ class Qwen2_5VLImageProcessor(BaseImageProcessor):
         image_token = self.IMAGE_TOKEN
         base_output = self.load_images(
-            input_ids,
-            image_data,
-            image_token,
-            max_req_input_len,
+            input_ids=input_ids,
+            image_data=image_data,
+            image_token=image_token,
+            max_req_input_len=max_req_input_len,
         )


 def smart_resize(
python/sglang/srt/managers/schedule_batch.py

@@ -49,7 +49,7 @@ from sglang.srt.model_executor.forward_batch_info import CaptureHiddenMode, Forw
 from sglang.srt.sampling.sampling_batch_info import SamplingBatchInfo
 from sglang.srt.sampling.sampling_params import SamplingParams
 from sglang.srt.server_args import ServerArgs
-from sglang.srt.utils import get_compiler_backend, next_power_of_2
+from sglang.srt.utils import get_compiler_backend

 if TYPE_CHECKING:
     from sglang.srt.speculative.eagle_utils import EagleDraftInput, EagleVerifyInput

@@ -207,6 +207,9 @@ class ImageInputs:
         return ret

     def merge(self, other):
+        """
+        merge image inputs when requests are being merged
+        """
         assert self.pixel_values.shape[1:] == other.pixel_values.shape[1:]
         self.pixel_values = np.concatenate([self.pixel_values, other.pixel_values])
python/sglang/srt/model_executor/forward_batch_info.py

@@ -33,6 +33,7 @@ from dataclasses import dataclass
 from enum import IntEnum, auto
 from typing import TYPE_CHECKING, List, Optional, Union

+import numpy as np
 import torch
 import triton
 import triton.language as tl

@@ -331,6 +332,32 @@ class ForwardBatch:
         return ret

+    def get_merged_image_inputs(self) -> Optional[ImageInputs]:
+        """
+        Merge all image inputs in the batch into a single ImageInputs object.
+
+        Returns:
+            if none, current batch contains no image input
+        """
+        if not self.image_inputs or all(x is None for x in self.image_inputs):
+            return None
+
+        # Filter out None values
+        valid_inputs = [x for x in self.image_inputs if x is not None]
+
+        # Start with the first valid image input
+        merged = valid_inputs[0]
+
+        # Merge remaining inputs
+        for img_input in valid_inputs[1:]:
+            merged.merge(img_input)
+
+        if isinstance(merged.pixel_values, np.ndarray):
+            merged.pixel_values = torch.from_numpy(merged.pixel_values)
+
+        return merged
+
     def _compute_mrope_positions(
         self, model_runner: ModelRunner, batch: ModelWorkerBatch
     ):
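The merge contract is plain batch-axis concatenation, which is why the trailing dimensions must match. A toy sketch of what get_merged_image_inputs produces for two requests (the 896x896 resolution is an illustrative Gemma 3 input size, not read from this diff):

import numpy as np
import torch

a = np.zeros((1, 3, 896, 896), dtype=np.float32)  # one image from request A
b = np.zeros((2, 3, 896, 896), dtype=np.float32)  # two images from request B
merged = np.concatenate([a, b])                   # batch axis grows -> (3, 3, 896, 896)
pixel_values = torch.from_numpy(merged)           # tensor handed to the model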
python/sglang/srt/models/gemma3_causal.py  (new file, +687 −0)

This diff is collapsed in the source view.
python/sglang/srt/models/gemma3_mm.py  (new file, +462 −0)

# Copyright 2025 SGLang Team
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================

# Adapted from:
# https://github.com/vllm-project/vllm/blob/main/vllm/model_executor/models/gemma3_mm.py

import logging
from functools import lru_cache
from typing import Dict, Iterable, List, Optional, Set, Tuple, TypedDict

import torch
from torch import nn
from transformers import AutoModel, PreTrainedModel

from sglang.srt.configs import Gemma3Config
from sglang.srt.hf_transformers_utils import get_processor
from sglang.srt.layers.layernorm import Gemma3RMSNorm
from sglang.srt.layers.logits_processor import LogitsProcessor
from sglang.srt.layers.quantization.base_config import QuantizationConfig
from sglang.srt.managers.multi_modality_padding import (
    MultiModalityDataPaddingPatternTokenPairs,
)
from sglang.srt.managers.schedule_batch import ImageInputs
from sglang.srt.model_executor.forward_batch_info import ForwardBatch
from sglang.srt.model_loader.weight_utils import (
    default_weight_loader,
    maybe_remap_kv_scale_name,
)
from sglang.srt.models.gemma3_causal import Gemma3ForCausalLM
from sglang.srt.utils import add_prefix

logger = logging.getLogger(__name__)

cached_get_processor = lru_cache(get_processor)


class Gemma3ImagePixelInputs(TypedDict):
    pixel_values: torch.Tensor
    """Shape: `(batch_size * num_images, num_channels, height, width)`"""


class Gemma3MultiModalProjector(nn.Module):
    """Projector for Gemma3 multimodal."""

    def __init__(self, config: Gemma3Config):
        super().__init__()
        self.mm_input_projection_weight = nn.Parameter(
            torch.zeros(
                config.vision_config.hidden_size, config.text_config.hidden_size
            )
        )
        self.mm_soft_emb_norm = Gemma3RMSNorm(
            config.vision_config.hidden_size, eps=config.vision_config.layer_norm_eps
        )
        self.patches_per_image = int(
            config.vision_config.image_size // config.vision_config.patch_size
        )
        self.tokens_per_side = int(config.mm_tokens_per_image**0.5)
        self.kernel_size = self.patches_per_image // self.tokens_per_side
        self.avg_pool = nn.AvgPool2d(
            kernel_size=self.kernel_size, stride=self.kernel_size
        )

    def forward(self, vision_outputs: torch.Tensor) -> torch.Tensor:
        batch_size, seq_length, hidden_size = vision_outputs.shape

        # Reshape for pooling
        reshaped_vision_outputs = vision_outputs.transpose(1, 2)
        reshaped_vision_outputs = reshaped_vision_outputs.reshape(
            batch_size, hidden_size, self.patches_per_image, self.patches_per_image
        )
        reshaped_vision_outputs = reshaped_vision_outputs.contiguous()

        # Apply pooling
        pooled_vision_outputs = self.avg_pool(reshaped_vision_outputs)
        pooled_vision_outputs = pooled_vision_outputs.flatten(2)
        pooled_vision_outputs = pooled_vision_outputs.transpose(1, 2)

        # Apply normalization
        normed_vision_outputs = self.mm_soft_emb_norm(pooled_vision_outputs)

        # Project to text embedding space
        projected_vision_outputs = torch.matmul(
            normed_vision_outputs, self.mm_input_projection_weight
        )
        return projected_vision_outputs.type_as(vision_outputs)


class Gemma3ForConditionalGeneration(PreTrainedModel):
    config_class = Gemma3Config
    """Gemma3 multimodal model for conditional generation."""

    # BitandBytes specific attributes
    default_bitsandbytes_target_modules = [
        ".gate_proj.",
        ".down_proj.",
        ".up_proj.",
        ".q_proj.",
        ".k_proj.",
        ".v_proj.",
        ".o_proj.",
    ]
    bitsandbytes_stacked_params_mapping = {
        # shard_name, weight_name, index
        "q_proj": ("qkv_proj", 0),
        "k_proj": ("qkv_proj", 1),
        "v_proj": ("qkv_proj", 2),
        "gate_proj": ("gate_up_proj", 0),
        "up_proj": ("gate_up_proj", 1),
    }

    packed_modules_mapping = {
        "qkv_proj": [
            "q_proj",
            "k_proj",
            "v_proj",
        ],
        "gate_up_proj": [
            "gate_proj",
            "up_proj",
        ],
    }

    # LoRA specific attributes
    supported_lora_modules = [
        "qkv_proj",
        "o_proj",
        "gate_up_proj",
        "down_proj",
    ]
    # Gemma does not apply LoRA to the embedding layer.
    embedding_modules = {}
    embedding_padding_modules = []
    supports_lora = True

    def __init__(
        self,
        config: Gemma3Config,
        quant_config: Optional[QuantizationConfig] = None,
        prefix: str = "",
    ) -> None:
        super().__init__(config=config)
        self.config = config
        self.quant_config = quant_config

        # Vision components
        # TODO: replace with vision attention
        # self.vision_tower = SiglipVisionModel(
        #     config.vision_config,
        #     quant_config,
        #     prefix=add_prefix("vision_tower", prefix),
        # )
        self.vision_tower = AutoModel.from_config(config=config.vision_config)
        self.multi_modal_projector = Gemma3MultiModalProjector(config)
        self.vocab_size = config.text_config.vocab_size

        # Text model
        self.language_model = Gemma3ForCausalLM(
            config.text_config, quant_config, prefix=add_prefix("model", prefix)
        )
        if self.language_model.logits_processor.logit_scale:
            logit_scale = getattr(config, "logit_scale", 1.0)
            self.language_model.logits_processor.logit_scale *= logit_scale
        self.post_init()

    def pad_input_ids(
        self, input_ids: List[int], image_inputs: ImageInputs
    ) -> List[int]:
        """Pad input IDs with image tokens."""
        # Get special token IDs
        im_start_id: int = image_inputs.im_start_id
        im_end_id: int = image_inputs.im_end_id

        media_token_pairs = [(im_start_id, im_end_id)]
        pattern = MultiModalityDataPaddingPatternTokenPairs(media_token_pairs)
        ids = pattern.pad_input_tokens(input_ids, image_inputs)
        return ids

    def prepare_attn_masks(
        self,
        input_ids: torch.Tensor,
        positions: torch.Tensor,
        mask_dtype: torch.dtype,
        **kwargs,
    ) -> Dict:
        """Prepare attention masks for multimodal inputs."""
        kwargs["has_images"] = True

        # Distinguish sequences by position id 0
        start_indices = (positions == 0).cpu().nonzero()
        num_seqs = len(start_indices)
        seq_lens = []
        for i in range(num_seqs):
            start_idx = start_indices[i].item()
            if i < num_seqs - 1:
                end_idx = start_indices[i + 1].item()
            else:
                end_idx = len(input_ids)
            seq_lens.append(end_idx - start_idx)
        kwargs["seq_lens"] = seq_lens

        # Create attention masks
        global_attn_masks = []
        local_attn_masks = []
        sliding_window = self.config.text_config.interleaved_sliding_window
        start_idx = 0
        for seq_len in seq_lens:
            end_idx = start_idx + seq_len
            input_token_ids = input_ids[start_idx:end_idx]
            start_idx = end_idx

            # Create global causal mask
            global_attn_mask = torch.empty(
                1,
                1,
                seq_len,
                seq_len,
                dtype=mask_dtype,
                device=input_ids.device,
            )
            global_attn_mask.fill_(float("-inf"))
            global_attn_mask = global_attn_mask.triu(diagonal=1)

            # Consider bidirectional attention between image tokens
            img_mask = torch.zeros_like(global_attn_mask)
            img_pos = input_token_ids == self.config.image_token_index
            img_mask[:, :, :, img_pos] += 1
            img_mask[:, :, img_pos, :] += 1
            global_attn_mask = torch.where(img_mask == 2, 0, global_attn_mask)
            global_attn_masks.append(global_attn_mask)

            # Create local causal mask with sliding window
            local_attn_mask = torch.ones_like(global_attn_mask)
            local_attn_mask = torch.tril(local_attn_mask, diagonal=-sliding_window)
            local_attn_mask = torch.where(
                local_attn_mask == 0, global_attn_mask, float("-inf")
            )
            local_attn_masks.append(local_attn_mask)

        kwargs["global_attn_masks"] = global_attn_masks
        kwargs["local_attn_masks"] = local_attn_masks
        return kwargs

    def get_input_embeddings(self):
        return self.language_model.get_input_embeddings()

    def get_image_features(self, pixel_values: torch.Tensor):
        """
        Projects the last hidden state from the vision model into language model space.

        Args:
            pixel_values (`torch.FloatTensor` of shape `(batch_size, channels, height, width)`):
                The tensors corresponding to the input images.

        Returns:
            image_features (`torch.Tensor`): Image feature tensor of shape
            `(num_images, image_length, embed_dim)`.
        """
        pixel_values = pixel_values.to("cuda")
        pixel_values = pixel_values.to(dtype=self.language_model.dtype())

        vision_outputs = self.vision_tower(pixel_values=pixel_values).last_hidden_state
        image_features = self.multi_modal_projector(vision_outputs)
        return image_features

    def embed_image_inputs(
        self,
        input_ids: torch.Tensor,
        forward_batch: ForwardBatch,
        image_input: ImageInputs,
    ) -> torch.Tensor:
        if input_ids is None:
            raise ValueError("Unimplemented")

        # boolean-masking image tokens
        special_image_mask = torch.isin(
            input_ids,
            torch.tensor(image_input.pad_values, device=input_ids.device),
        ).unsqueeze(-1)
        num_image_tokens_in_input_ids = special_image_mask.sum()

        inputs_embeds = None
        if num_image_tokens_in_input_ids == 0:
            inputs_embeds = self.get_input_embeddings()(input_ids)
            return inputs_embeds
        else:
            # print(f"image tokens from input_ids: {inputs_embeds[special_image_mask].numel()}")
            image_features = self.get_image_features(image_input.pixel_values)
            # print(f"image tokens from image embeddings: {image_features.numel()}")
            num_image_tokens_in_embedding = (
                image_features.shape[0] * image_features.shape[1]
            )

            if num_image_tokens_in_input_ids != num_image_tokens_in_embedding:
                num_image = num_image_tokens_in_input_ids // image_features.shape[1]
                image_features = image_features[:num_image, :]
                logger.warning(
                    f"Number of images does not match number of special image tokens in the input text. "
                    f"Got {num_image_tokens_in_input_ids} image tokens in the text but {num_image_tokens_in_embedding} "
                    "tokens from image embeddings."
                )

            # Important: clamp after extracting original image boundaries
            input_ids.clamp_(min=0, max=self.vocab_size - 1)

            inputs_embeds = self.get_input_embeddings()(input_ids)

            special_image_mask = special_image_mask.expand_as(inputs_embeds).to(
                inputs_embeds.device
            )
            image_features = image_features.to(
                inputs_embeds.device, inputs_embeds.dtype
            )
            inputs_embeds = inputs_embeds.masked_scatter(
                special_image_mask, image_features
            )
        return inputs_embeds

    @torch.no_grad()
    def forward(
        self,
        input_ids: torch.LongTensor,
        positions: torch.Tensor,
        forward_batch: ForwardBatch,
        input_embeds: torch.Tensor = None,
        **kwargs: object,
    ) -> LogitsProcessor:
        r"""
        labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
            Labels for computing the masked language modeling loss. Indices should either be in `[0, ...,
            config.text_config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100`
            are ignored (masked); the loss is only computed for the tokens with labels in
            `[0, ..., config.text_config.vocab_size]`.

        logits_to_keep (`int` or `torch.Tensor`, *optional*):
            If an `int`, compute logits for the last `logits_to_keep` tokens. If `0`, calculate logits for all
            `input_ids` (special case). Only last token logits are needed for generation, and calculating them only
            for that token can save memory, which becomes pretty significant for long sequences or large vocabulary
            size. If a `torch.Tensor`, must be 1D corresponding to the indices to keep in the sequence length
            dimension. This is useful when using packed tensor format (single dimension for batch and sequence
            length).

        Returns:

        Example:

        ```python
        >>> from PIL import Image
        >>> import requests
        >>> from transformers import AutoProcessor, Gemma3ForConditionalGeneration

        >>> model = Gemma3ForConditionalGeneration.from_pretrained("google/Gemma3-test-224px-hf")
        >>> processor = AutoProcessor.from_pretrained("google/Gemma3-test-224px-hf")

        >>> prompt = "answer en Where is the cow standing?"
        >>> url = "https://huggingface.co/gv-hf/Gemma3-test-224px-hf/resolve/main/cow_beach_1.png"
        >>> image = Image.open(requests.get(url, stream=True).raw)

        >>> inputs = processor(images=image, text=prompt, return_tensors="pt")

        >>> # Generate
        >>> generate_ids = model.generate(**inputs, max_length=30)
        >>> processor.batch_decode(generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0]
        "answer en Where is the cow standing?\nbeach"
        ```"""
        # Important: position_ids in Gemma3 are 1-indexed.
        # This really did cost me some time.
        positions += 1

        # Replace the image id with PAD if the image token is OOV, to avoid index errors
        if input_ids is not None and self.config.image_token_index >= self.vocab_size:
            special_image_mask = input_ids == self.config.image_token_index
            llm_input_ids = input_ids.clone()
            llm_input_ids[special_image_mask] = 0
        else:
            llm_input_ids = input_ids

        merged_image_input = forward_batch.get_merged_image_inputs()
        if (
            not forward_batch.forward_mode.is_decode()
            and merged_image_input is not None
        ):
            inputs_embeds = self.embed_image_inputs(
                input_ids=llm_input_ids,
                forward_batch=forward_batch,
                image_input=merged_image_input,
            )
        else:
            llm_input_ids.clamp_(min=0, max=self.vocab_size - 1)
            inputs_embeds = self.get_input_embeddings()(llm_input_ids)

        outputs = self.language_model(
            input_ids=None,
            positions=positions,
            forward_batch=forward_batch,
            input_embeds=inputs_embeds,
            **kwargs,
        )
        return outputs

    def tie_weights(self):
        return self.language_model.tie_weights()

    def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
        """Load weights for the model."""
        params_dict = dict(self.named_parameters())
        loaded_params: Set[str] = set()

        for name, loaded_weight in weights:
            if "language_model" in name:
                # Gemma3ForCausalLM.load_weights(self, [(name.replace("language_model.", ""), loaded_weight)])
                causal_loaded_params = Gemma3ForCausalLM.load_weights(
                    self, [(name, loaded_weight)]
                )
                loaded_params.update(causal_loaded_params)
                continue
            else:
                # Skip lm_head.weight as it's tied with embed_tokens
                if "lm_head.weight" in name:
                    continue
                # Skip loading extra bias for GPTQ models
                if name.endswith(".bias") and name not in params_dict:
                    continue
                # Remapping the name of FP8 kv-scale
                name = maybe_remap_kv_scale_name(name, params_dict)
                if name is None:
                    continue

                param = params_dict[name]
                weight_loader = getattr(param, "weight_loader", default_weight_loader)
                weight_loader(param, loaded_weight)
                loaded_params.add(name)
        unloaded_params = params_dict.keys() - loaded_params
        if unloaded_params:
            pass
            # raise RuntimeError(
            #     f"Some weights are not initialized from checkpoints: {unloaded_params}")
        return loaded_params


EntryClass = Gemma3ForConditionalGeneration
AutoModel.register(Gemma3Config, Gemma3ForConditionalGeneration, exist_ok=True)
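The projector's pooling arithmetic, worked through with the published Gemma 3 vision settings (image_size=896, patch_size=14, mm_tokens_per_image=256; these values come from the HF config and are assumed here, not read from this diff):

image_size, patch_size, mm_tokens_per_image = 896, 14, 256
patches_per_image = image_size // patch_size            # 64 patches per side
tokens_per_side = int(mm_tokens_per_image**0.5)         # 16
kernel_size = patches_per_image // tokens_per_side      # 4x4 average pooling
assert (patches_per_image // kernel_size) ** 2 == mm_tokens_per_image  # 256 image tokens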
python/sglang/srt/utils.py

@@ -41,7 +41,6 @@ from functools import lru_cache
 from importlib.metadata import PackageNotFoundError, version
 from importlib.util import find_spec
 from io import BytesIO
-from multiprocessing import Pool
 from multiprocessing.reduction import ForkingPickler
 from pathlib import Path
 from typing import Any, Callable, Dict, List, Optional, Protocol, Set, Tuple, Union

@@ -454,8 +453,9 @@ def load_image(image_file: Union[str, bytes]):
         image = Image.open(BytesIO(image_file))
     elif image_file.startswith("http://") or image_file.startswith("https://"):
         timeout = int(os.getenv("REQUEST_TIMEOUT", "3"))
-        response = requests.get(image_file, timeout=timeout)
-        image = Image.open(BytesIO(response.content))
+        response = requests.get(image_file, stream=True, timeout=timeout).raw
+        image = Image.open(response)
+        response.close()
     elif image_file.lower().endswith(("png", "jpg", "jpeg", "webp", "gif")):
         image = Image.open(image_file)
     elif image_file.startswith("data:"):
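The streamed variant hands PIL the raw urllib3 stream instead of buffering the whole body in memory. A standalone sketch of the same pattern (the URL is illustrative); note that PIL decodes lazily, so forcing load() before closing the stream is the safe ordering:

import requests
from PIL import Image

response = requests.get("https://example.com/cow.png", stream=True, timeout=3).raw
image = Image.open(response)
image.load()       # force decode while the stream is still open
response.close()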