change / sglang / Commits / 0bc0bf57

Commit 0bc0bf57 (unverified), authored Mar 27, 2025 by Juwan Yoo, committed via GitHub on Mar 27, 2025
gemma3: impl `get_attention_sliding_window_size` for attn init (#4823)
Parent: f60f2931

Showing 2 changed files with 18 additions and 2 deletions (+18 −2):
  python/sglang/srt/models/gemma3_causal.py  (+12 −2)
  python/sglang/srt/models/gemma3_mm.py      (+6 −0)
python/sglang/srt/models/gemma3_causal.py

@@ -47,6 +47,12 @@ from sglang.srt.model_loader.weight_utils import (
 from sglang.srt.utils import add_prefix, make_layers
 
 
+# Aligned with HF's implementation, using sliding window inclusive with the last token
+# SGLang assumes exclusive
+def get_attention_sliding_window_size(config):
+    return config.sliding_window - 1
+
+
 # Adapted from:
 # https://github.com/vllm-project/vllm/blob/main/vllm/model_executor/models/gemma3.py
 def extract_layer_index(prefix: str) -> int:
@@ -170,7 +176,7 @@ class Gemma3Attention(nn.Module):
             self.rope_scaling = {"rope_type": "default"}
             # FIXME(mick): idk why vllm does this
             # self.sliding_window = config.interleaved_sliding_window
-            self.sliding_window = config.sliding_window
+            self.sliding_window = get_attention_sliding_window_size(config)
         else:
             # Global attention. Use the values in config.json.
             self.rope_theta = config.rope_theta
@@ -184,6 +190,8 @@ class Gemma3Attention(nn.Module):
             num_kv_heads=self.num_kv_heads,
             layer_id=layer_id,
             logit_cap=getattr(self.config, "attn_logit_softcapping", None),
+            # Module must also define `get_attention_sliding_window_size` to correctly initialize
+            # attention backend in `ForwardBatch`.
             sliding_window_size=self.sliding_window,
             prefix=add_prefix("attn", prefix),
         )
@@ -609,6 +617,9 @@ class Gemma3ForCausalLM(PreTrainedModel):
     def get_input_embeddings(self) -> nn.Embedding:
         return self.model.embed_tokens
 
+    def get_attention_sliding_window_size(self):
+        return get_attention_sliding_window_size(self.config)
+
     def dtype(self) -> torch.dtype:
         return next(self.parameters()).dtype
 
@@ -621,7 +632,6 @@ class Gemma3ForCausalLM(PreTrainedModel):
         input_embeds: torch.Tensor = None,
         **kwargs,
     ) -> LogitsProcessor:
-
         hidden_states = self.model(
             input_ids, positions, forward_batch, input_embeds, **kwargs
         )
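The new helper encodes an off-by-one convention difference: the HF Gemma3 config counts sliding_window inclusive of the current token, while SGLang's attention backends assume an exclusive window, hence the "- 1". A minimal sketch of the conversion, using a SimpleNamespace as a stand-in for the HF config object (the example window value is illustrative, not taken from the commit):

from types import SimpleNamespace


def get_attention_sliding_window_size(config):
    # HF treats the window as inclusive of the last (current) token;
    # SGLang assumes an exclusive window, so subtract one.
    return config.sliding_window - 1


# Stand-in config: a local-attention layer with a 1024-token inclusive window
# maps to a 1023-token exclusive window on the SGLang side.
config = SimpleNamespace(sliding_window=1024)
assert get_attention_sliding_window_size(config) == 1023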
python/sglang/srt/models/gemma3_mm.py

@@ -268,6 +268,12 @@ class Gemma3ForConditionalGeneration(PreTrainedModel):
     def get_input_embeddings(self) -> nn.Embedding:
         return self.language_model.get_input_embeddings()
 
+    def get_attention_sliding_window_size(self):
+        """
+        This value is used to initialize attention backends in `ForwardBatch`.
+        """
+        return self.language_model.get_attention_sliding_window_size()
+
     def get_image_feature(self, image_input: MultimodalInputs):
         """
         Projects the last hidden state from the vision model into language model space.
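Both entry points now report the same exclusive window: the multimodal wrapper delegates to the wrapped causal LM, which delegates to the module-level helper. A rough sketch of how a caller that sets up the attention backend might probe a loaded model for this hook (the function name and fallback behavior are illustrative assumptions, not code from this commit or from SGLang's model runner):

from typing import Optional


def resolve_sliding_window_size(model) -> Optional[int]:
    # Hypothetical probe: use the hook added in this commit when the model
    # exposes it; otherwise fall back to "no sliding window".
    getter = getattr(model, "get_attention_sliding_window_size", None)
    return getter() if callable(getter) else None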