Unverified commit 9266d980, authored by Yong Hoon Shin, committed by GitHub
Browse files

[BugFix] Fix interleaved sliding window not set for Gemma3n (#21863)


Signed-off-by: Yong Hoon Shin <yhshin@meta.com>
parent 176bbce1
......@@ -723,11 +723,16 @@ class ModelConfig:
)
# Workaround for Gemma 2 which uses interleaved sliding window
# attention, but it's not specified in its config. TODO: remove this
# when Gemma 2 is fixed in Transformers.
# attention, but it's not specified in its config.
# TODO: remove this when Gemma 2 config updated in HuggingFace.
if self.hf_text_config.model_type == "gemma2":
self.hf_text_config.sliding_window_pattern = 2
# TODO: remove this when Gemma 3n config updated in HuggingFace.
if self.hf_text_config.model_type == "gemma3n_text":
# 4 sliding window attention followed by 1 full attention
self.hf_text_config.sliding_window_pattern = "LLLLG"
sliding_window = getattr(self.hf_text_config, "sliding_window", None)
sliding_window_pattern = getattr(self.hf_text_config,
"sliding_window_pattern", None)
......
......@@ -297,8 +297,13 @@ class Gemma3nAttention(nn.Module):
has_weight=False)
layer_idx = extract_layer_index(prefix)
if config.layer_types[layer_idx] == "sliding_attention":
self.sliding_window = config.sliding_window
is_sliding_window = (
getattr(config, "interleaved_sliding_window", None) is not None
and config.layer_types[layer_idx] == "sliding_attention")
if is_sliding_window:
self.sliding_window = config.interleaved_sliding_window
rope_theta = config.rope_local_base_freq
rope_scaling = {"rope_type": "default"}
else:
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment