"ci/install_dependencies_python2.sh" did not exist on "2ad4126fcba61ed039a21c19e73713d33716243f"
Unverified Commit 0bc0bf57 authored by Juwan Yoo's avatar Juwan Yoo Committed by GitHub
Browse files

gemma3: impl `get_attention_sliding_window_size` for attn init (#4823)

parent f60f2931
......@@ -47,6 +47,12 @@ from sglang.srt.model_loader.weight_utils import (
from sglang.srt.utils import add_prefix, make_layers
# Aligned with HF's implementation, using sliding window inclusive with the last token
# SGLang assumes exclusive
def get_attention_sliding_window_size(config):
return config.sliding_window - 1
# Adapted from:
# https://github.com/vllm-project/vllm/blob/main/vllm/model_executor/models/gemma3.py
def extract_layer_index(prefix: str) -> int:
......@@ -170,7 +176,7 @@ class Gemma3Attention(nn.Module):
self.rope_scaling = {"rope_type": "default"}
# FIXME(mick): idk why vllm does this
# self.sliding_window = config.interleaved_sliding_window
self.sliding_window = config.sliding_window
self.sliding_window = get_attention_sliding_window_size(config)
else:
# Global attention. Use the values in config.json.
self.rope_theta = config.rope_theta
......@@ -184,6 +190,8 @@ class Gemma3Attention(nn.Module):
num_kv_heads=self.num_kv_heads,
layer_id=layer_id,
logit_cap=getattr(self.config, "attn_logit_softcapping", None),
# Module must also define `get_attention_sliding_window_size` to correctly initialize
# attention backend in `ForwardBatch`.
sliding_window_size=self.sliding_window,
prefix=add_prefix("attn", prefix),
)
......@@ -609,6 +617,9 @@ class Gemma3ForCausalLM(PreTrainedModel):
def get_input_embeddings(self) -> nn.Embedding:
return self.model.embed_tokens
def get_attention_sliding_window_size(self):
return get_attention_sliding_window_size(self.config)
def dtype(self) -> torch.dtype:
return next(self.parameters()).dtype
......@@ -621,7 +632,6 @@ class Gemma3ForCausalLM(PreTrainedModel):
input_embeds: torch.Tensor = None,
**kwargs,
) -> LogitsProcessor:
hidden_states = self.model(
input_ids, positions, forward_batch, input_embeds, **kwargs
)
......
......@@ -268,6 +268,12 @@ class Gemma3ForConditionalGeneration(PreTrainedModel):
def get_input_embeddings(self) -> nn.Embedding:
return self.language_model.get_input_embeddings()
def get_attention_sliding_window_size(self):
"""
This value is used to initialize attention backends in `ForwardBatch`.
"""
return self.language_model.get_attention_sliding_window_size()
def get_image_feature(self, image_input: MultimodalInputs):
"""
Projects the last hidden state from the vision model into language model space.
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment