Unverified commit 5bfe0dea, authored by qizixi, committed by GitHub
Browse files

[bug fix] Fix llama4 spec decoding (#22691)


Signed-off-by: qizixi <qizixi@meta.com>
Co-authored-by: Lu Fang <30275821+houseroad@users.noreply.github.com>
parent 31fd3265
...@@ -195,7 +195,9 @@ class Llama4Attention(nn.Module): ...@@ -195,7 +195,9 @@ class Llama4Attention(nn.Module):
is_neox_style=is_neox_style, is_neox_style=is_neox_style,
) if not self.nope else None ) if not self.nope else None
attn_cls = Attention if self.nope else ChunkedLocalAttention use_chunked_local_attn = not self.nope and config.attention_chunk_size
attn_cls = (ChunkedLocalAttention
if use_chunked_local_attn else Attention)
self.attn = attn_cls( self.attn = attn_cls(
self.num_heads, self.num_heads,
self.head_dim, self.head_dim,
...@@ -206,7 +208,7 @@ class Llama4Attention(nn.Module): ...@@ -206,7 +208,7 @@ class Llama4Attention(nn.Module):
prefix=f"{prefix}.attn", prefix=f"{prefix}.attn",
**({ **({
"attention_chunk_size": config.attention_chunk_size "attention_chunk_size": config.attention_chunk_size
} if not self.nope else {})) } if use_chunked_local_attn else {}))
def _get_attn_scale(self, positions: torch.Tensor) -> torch.Tensor: def _get_attn_scale(self, positions: torch.Tensor) -> torch.Tensor:
floor = torch.floor((positions + 1.0) / self.floor_scale) floor = torch.floor((positions + 1.0) / self.floor_scale)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment