Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
change
sglang
Commits
04f2abcb
Unverified
Commit
04f2abcb
authored
Apr 22, 2025
by
Yineng Zhang
Committed by
GitHub
Apr 22, 2025
Browse files
fix: gemma 3 not use softcap (#5622)
parent
506be6b8
Changes
4
Hide whitespace changes
Inline
Side-by-side
Showing
4 changed files
with
17 additions
and
2 deletions
+17
-2
python/sglang/srt/configs/model_config.py
python/sglang/srt/configs/model_config.py
+5
-0
python/sglang/srt/models/gemma3_causal.py
python/sglang/srt/models/gemma3_causal.py
+1
-1
python/sglang/srt/server_args.py
python/sglang/srt/server_args.py
+10
-1
python/sglang/srt/utils.py
python/sglang/srt/utils.py
+1
-0
No files found.
python/sglang/srt/configs/model_config.py
View file @
04f2abcb
...
...
@@ -78,6 +78,11 @@ class ModelConfig:
             logger.info(
                 "Multimodal is disabled for Llama4. To enable it, set --enable-llama4-multimodal."
             )
+        elif self.hf_config.architectures[0] == "Gemma3ForConditionalGeneration":
+            enable_multimodal = False
+            logger.info(
+                "Multimodal is disabled for Gemma3. To enable it, set --enable-gemma3-multimodal."
+            )
         else:
             enable_multimodal = True
...
...
python/sglang/srt/models/gemma3_causal.py
View file @
04f2abcb
...
...
@@ -189,7 +189,7 @@ class Gemma3Attention(nn.Module):
             self.scaling,
             num_kv_heads=self.num_kv_heads,
             layer_id=layer_id,
-            logit_cap=getattr(self.config, "attn_logit_softcapping", None),
+            logit_cap=0.0,
             # Module must also define `get_attention_sliding_window_size` to correctly initialize
             # attention backend in `ForwardBatch`.
             sliding_window_size=self.sliding_window,
...
...
python/sglang/srt/server_args.py
View file @
04f2abcb
...
...
@@ -154,6 +154,7 @@ class ServerArgs:
     disable_outlines_disk_cache: bool = False
     disable_custom_all_reduce: bool = False
     enable_llama4_multimodal: Optional[bool] = None
+    enable_gemma3_multimodal: Optional[bool] = None
     disable_overlap_schedule: bool = False
     enable_mixed_chunk: bool = False
     enable_dp_attention: bool = False
...
...
@@ -285,7 +286,9 @@ class ServerArgs:
         if self.grammar_backend is None:
             self.grammar_backend = "xgrammar"

-        self.enable_multimodal: Optional[bool] = self.enable_llama4_multimodal
+        self.enable_multimodal: Optional[bool] = (
+            self.enable_llama4_multimodal or self.enable_gemma3_multimodal
+        )

         # Data parallelism attention
         if self.enable_dp_attention:
...
...
@@ -984,6 +987,12 @@ class ServerArgs:
             action="store_true",
             help="Enable the multimodal functionality for Llama-4.",
         )
+        parser.add_argument(
+            "--enable-gemma3-multimodal",
+            default=ServerArgs.enable_gemma3_multimodal,
+            action="store_true",
+            help="Enable the multimodal functionality for Gemma-3.",
+        )
         parser.add_argument(
             "--disable-overlap-schedule",
             action="store_true",
...
...
python/sglang/srt/utils.py
View file @
04f2abcb
...
...
@@ -1971,6 +1971,7 @@ def is_fa3_default_architecture(hf_config):
         "LlamaForCausalLM",
         "MistralForCausalLM",
         "Gemma2ForCausalLM",
+        "Gemma3ForConditionalGeneration",
     }
     return architectures[0] in default_archs
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment