Unverified Commit 0bc0bf57 authored by Juwan Yoo, committed by GitHub

gemma3: impl `get_attention_sliding_window_size` for attn init (#4823)

parent f60f2931
@@ -47,6 +47,12 @@ from sglang.srt.model_loader.weight_utils import (
 from sglang.srt.utils import add_prefix, make_layers
 
 
+# Aligned with HF's implementation, using sliding window inclusive with the last token
+# SGLang assumes exclusive
+def get_attention_sliding_window_size(config):
+    return config.sliding_window - 1
+
+
 # Adapted from:
 # https://github.com/vllm-project/vllm/blob/main/vllm/model_executor/models/gemma3.py
 def extract_layer_index(prefix: str) -> int:
@@ -170,7 +176,7 @@ class Gemma3Attention(nn.Module):
             self.rope_scaling = {"rope_type": "default"}
             # FIXME(mick): idk why vllm does this
             # self.sliding_window = config.interleaved_sliding_window
-            self.sliding_window = config.sliding_window
+            self.sliding_window = get_attention_sliding_window_size(config)
         else:
             # Global attention. Use the values in config.json.
             self.rope_theta = config.rope_theta
@@ -184,6 +190,8 @@ class Gemma3Attention(nn.Module):
             num_kv_heads=self.num_kv_heads,
             layer_id=layer_id,
             logit_cap=getattr(self.config, "attn_logit_softcapping", None),
+            # Module must also define `get_attention_sliding_window_size` to correctly initialize
+            # attention backend in `ForwardBatch`.
             sliding_window_size=self.sliding_window,
             prefix=add_prefix("attn", prefix),
         )
@@ -609,6 +617,9 @@ class Gemma3ForCausalLM(PreTrainedModel):
     def get_input_embeddings(self) -> nn.Embedding:
         return self.model.embed_tokens
 
+    def get_attention_sliding_window_size(self):
+        return get_attention_sliding_window_size(self.config)
+
     def dtype(self) -> torch.dtype:
         return next(self.parameters()).dtype
@@ -621,7 +632,6 @@ class Gemma3ForCausalLM(PreTrainedModel):
         input_embeds: torch.Tensor = None,
         **kwargs,
     ) -> LogitsProcessor:
-
         hidden_states = self.model(
             input_ids, positions, forward_batch, input_embeds, **kwargs
         )
...
@@ -268,6 +268,12 @@ class Gemma3ForConditionalGeneration(PreTrainedModel):
     def get_input_embeddings(self) -> nn.Embedding:
         return self.language_model.get_input_embeddings()
 
+    def get_attention_sliding_window_size(self):
+        """
+        This value is used to initialize attention backends in `ForwardBatch`.
+        """
+        return self.language_model.get_attention_sliding_window_size()
+
     def get_image_feature(self, image_input: MultimodalInputs):
         """
         Projects the last hidden state from the vision model into language model space.
...
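The multimodal wrapper simply delegates to its language model, so callers see the same value whether they hold a `Gemma3ForCausalLM` or a `Gemma3ForConditionalGeneration`. A brief illustrative usage, assuming `model` is an already-loaded instance of either class (not code from this diff):

# `model` is a loaded Gemma3ForCausalLM or Gemma3ForConditionalGeneration instance.
window_size = model.get_attention_sliding_window_size()
print(f"attention backend sliding window (exclusive of the current token): {window_size}")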