Unverified commit 512f5221, authored by Luciano Martins, committed by GitHub
Browse files

[Model] Gemma4: add bidirectional vision attention for sliding layers with window guard (#40534)


Signed-off-by: Luciano Martins <lucianommartins@users.noreply.github.com>
Signed-off-by: Luciano Martins <lucianomartins@google.com>
Signed-off-by: Isotr0py <mozf@mail2.sysu.edu.cn>
Co-authored-by: Luciano Martins <lucianommartins@users.noreply.github.com>
Co-authored-by: Isotr0py <2037008807@qq.com>
Co-authored-by: Isotr0py <mozf@mail2.sysu.edu.cn>
parent 4c34b2f6
......@@ -969,6 +969,16 @@ class Gemma4ForConditionalGeneration(
self.language_model.make_empty_intermediate_tensors
)
# --- Precompute full-attention layer indices for bidi clearing ---
self._full_attn_layer_idxs: frozenset[int] = frozenset()
text_config = config.text_config
if getattr(text_config, "use_bidirectional_attention", None) == "vision":
layer_types = getattr(text_config, "layer_types", None)
if layer_types:
self._full_attn_layer_idxs = frozenset(
i for i, lt in enumerate(layer_types) if lt != "sliding_attention"
)
# --- MixtureOfExperts delegation to language_model ---
self.expert_weights = self.language_model.expert_weights
self.moe_layers = self.language_model.moe_layers
......@@ -1310,6 +1320,12 @@ class Gemma4ForConditionalGeneration(
else None
)
# Gemma4 bidi: clear mm_prefix_range for full_attention layers.
# Must run here (outside @support_torch_compile boundary) because
# _run_decoder_layers is inside a compiled graph where Python
# side effects are eliminated.
self._clear_mm_prefix_for_full_attn_layers()
hidden_states = self.language_model.model(
input_ids,
positions,
......@@ -1327,6 +1343,49 @@ class Gemma4ForConditionalGeneration(
) -> torch.Tensor | None:
return self.language_model.compute_logits(hidden_states)
# ------------------------------------------------------------------ #
# Bidirectional attention helpers
# ------------------------------------------------------------------ #
def _clear_mm_prefix_for_full_attn_layers(self) -> None:
"""Clear mm_prefix_range for non-sliding layers.
Gemma4 with use_bidirectional_attention='vision' applies
bidirectional attention only to sliding_attention layers.
Full attention layers use plain causal masking.
Uses _full_attn_layer_idxs (precomputed in __init__) for O(1)
lookup instead of per-call regex parsing.
"""
if not self._full_attn_layer_idxs:
return
from vllm.forward_context import get_forward_context
attn_metadata = get_forward_context().attn_metadata
if attn_metadata is None:
return
def _process(metadata_dict: dict) -> None:
for layer_name, metadata in metadata_dict.items():
if ".layers." not in layer_name:
continue
try:
layer_idx = int(layer_name.split(".layers.")[1].split(".")[0])
except (ValueError, IndexError):
continue
if layer_idx in self._full_attn_layer_idxs:
if hasattr(metadata, "mm_prefix_range"):
metadata.mm_prefix_range = None
if hasattr(metadata, "mm_prefix_range_tensor"):
metadata.mm_prefix_range_tensor = None
if isinstance(attn_metadata, list):
for ub_metadata in attn_metadata:
_process(ub_metadata)
elif isinstance(attn_metadata, dict):
_process(attn_metadata)
# ------------------------------------------------------------------ #
# Weight loading
# ------------------------------------------------------------------ #
......
......@@ -2314,13 +2314,26 @@ class GPUModelRunner(
if self.is_mm_prefix_lm:
req_doc_ranges = {}
# Gemma4 bidi: skip ranges that exceed the sliding
# window. When image tokens > sliding_window, bidi causes
# early image tokens to attend to the entire image
# (e.g. 6 → 1092 targets), degrading spatial precision.
# Per-range filtering keeps bidi for small images/video
# frames while skipping oversized images.
hf_text_config = self.model_config.hf_text_config
_bidi_sw = getattr(hf_text_config, "sliding_window", None)
for req_id in self.input_batch.req_ids:
image_doc_ranges = []
req_state = self.requests[req_id]
for mm_feature in req_state.mm_features:
pos_info = mm_feature.mm_position
img_doc_range = pos_info.extract_embeds_range()
image_doc_ranges.extend(img_doc_range)
for r in img_doc_range:
if _bidi_sw is not None and (r[1] - r[0] + 1) > _bidi_sw:
continue
image_doc_ranges.append(r)
req_idx = self.input_batch.req_id_to_index[req_id]
req_doc_ranges[req_idx] = image_doc_ranges
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment