sglang · Unverified commit 0bc0bf57, authored Mar 27, 2025 by Juwan Yoo, committed by GitHub on Mar 27, 2025
gemma3: impl `get_attention_sliding_window_size` for attn init (#4823)
Parent: f60f2931

Showing 2 changed files with 18 additions and 2 deletions:

    python/sglang/srt/models/gemma3_causal.py   +12 -2
    python/sglang/srt/models/gemma3_mm.py       +6  -0
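The change gives the attention backend a way to learn the model's sliding-window size at the moment it is initialized for a `ForwardBatch`. As a minimal, hypothetical sketch of that hand-off (the `AttentionBackend` class and `init_attention_backend` function below are illustrative stand-ins, not sglang's actual API), a runner could probe the model for the new hook and fall back to global attention when the hook is absent:

from typing import Optional


class AttentionBackend:
    """Stand-in backend that needs the window size at construction time."""

    def __init__(self, sliding_window_size: Optional[int]):
        # None means full (global) attention; an int means local attention
        # limited to the preceding `sliding_window_size` tokens.
        self.sliding_window_size = sliding_window_size


def init_attention_backend(model) -> AttentionBackend:
    # Models that use local attention expose the hook added in this commit;
    # models without it fall back to global attention.
    getter = getattr(model, "get_attention_sliding_window_size", None)
    window = getter() if callable(getter) else None
    return AttentionBackend(sliding_window_size=window)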
python/sglang/srt/models/gemma3_causal.py
...
@@ -47,6 +47,12 @@ from sglang.srt.model_loader.weight_utils import (
 from sglang.srt.utils import add_prefix, make_layers
 
 
+# Aligned with HF's implementation, using sliding window inclusive with the last token
+# SGLang assumes exclusive
+def get_attention_sliding_window_size(config):
+    return config.sliding_window - 1
+
+
 # Adapted from:
 # https://github.com/vllm-project/vllm/blob/main/vllm/model_executor/models/gemma3.py
 def extract_layer_index(prefix: str) -> int:
...
@@ -170,7 +176,7 @@ class Gemma3Attention(nn.Module):
             self.rope_scaling = {"rope_type": "default"}
             # FIXME(mick): idk why vllm does this
             # self.sliding_window = config.interleaved_sliding_window
-            self.sliding_window = config.sliding_window
+            self.sliding_window = get_attention_sliding_window_size(config)
         else:
             # Global attention. Use the values in config.json.
             self.rope_theta = config.rope_theta
...
@@ -184,6 +190,8 @@ class Gemma3Attention(nn.Module):
             num_kv_heads=self.num_kv_heads,
             layer_id=layer_id,
             logit_cap=getattr(self.config, "attn_logit_softcapping", None),
+            # Module must also define `get_attention_sliding_window_size` to correctly initialize
+            # attention backend in `ForwardBatch`.
             sliding_window_size=self.sliding_window,
             prefix=add_prefix("attn", prefix),
         )
...
@@ -609,6 +617,9 @@ class Gemma3ForCausalLM(PreTrainedModel):
     def get_input_embeddings(self) -> nn.Embedding:
         return self.model.embed_tokens
 
+    def get_attention_sliding_window_size(self):
+        return get_attention_sliding_window_size(self.config)
+
     def dtype(self) -> torch.dtype:
         return next(self.parameters()).dtype
...
@@ -621,7 +632,6 @@ class Gemma3ForCausalLM(PreTrainedModel):
         input_embeds: torch.Tensor = None,
         **kwargs,
     ) -> LogitsProcessor:
         hidden_states = self.model(
             input_ids, positions, forward_batch, input_embeds, **kwargs
         )
...
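The `- 1` in the new helper bridges two window conventions. Per the comment in the diff, HF treats `config.sliding_window` as inclusive of the current token, while SGLang's `sliding_window_size` counts only the tokens before it. A toy mask comparison (my reading of that comment, not sglang's actual masking code) shows the two agree exactly when the SGLang value is one less than the HF value:

import torch


def hf_mask(seq_len: int, window: int) -> torch.Tensor:
    # Token i may attend to positions {i - window + 1, ..., i}:
    # `window` tokens in total, the current token included.
    i = torch.arange(seq_len).unsqueeze(1)
    j = torch.arange(seq_len).unsqueeze(0)
    return (j <= i) & (j > i - window)


def sglang_mask(seq_len: int, sliding_window_size: int) -> torch.Tensor:
    # Token i may attend to itself plus the previous `sliding_window_size`
    # positions: {i - sliding_window_size, ..., i}.
    i = torch.arange(seq_len).unsqueeze(1)
    j = torch.arange(seq_len).unsqueeze(0)
    return (j <= i) & (j >= i - sliding_window_size)


if __name__ == "__main__":
    W = 4  # toy stand-in for config.sliding_window
    assert torch.equal(hf_mask(16, W), sglang_mask(16, W - 1))
    print("HF window", W, "matches SGLang window", W - 1)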
python/sglang/srt/models/gemma3_mm.py
...
@@ -268,6 +268,12 @@ class Gemma3ForConditionalGeneration(PreTrainedModel):
     def get_input_embeddings(self) -> nn.Embedding:
         return self.language_model.get_input_embeddings()
 
+    def get_attention_sliding_window_size(self):
+        """
+        This value is used to initialize attention backends in `ForwardBatch`.
+        """
+        return self.language_model.get_attention_sliding_window_size()
+
     def get_image_feature(self, image_input: MultimodalInputs):
         """
         Projects the last hidden state from the vision model into language model space.
...
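On the multimodal side the wrapper simply defers to its language model, so both entry points report the same adjusted value. A minimal stand-in of that delegation chain (toy classes and a `SimpleNamespace` config rather than the real sglang modules; 1024 is a hypothetical `sliding_window` value):

from types import SimpleNamespace


def get_attention_sliding_window_size(config):
    # Same conversion as the module-level helper added in gemma3_causal.py.
    return config.sliding_window - 1


class ToyCausalLM:
    def __init__(self, config):
        self.config = config

    def get_attention_sliding_window_size(self):
        return get_attention_sliding_window_size(self.config)


class ToyConditionalGeneration:
    def __init__(self, language_model):
        self.language_model = language_model

    def get_attention_sliding_window_size(self):
        # Mirrors the delegation added in gemma3_mm.py.
        return self.language_model.get_attention_sliding_window_size()


config = SimpleNamespace(sliding_window=1024)  # hypothetical value
mm_model = ToyConditionalGeneration(ToyCausalLM(config))
print(mm_model.get_attention_sliding_window_size())  # -> 1023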