Unverified Commit cd7e32e2 authored by fzyzcjy, committed by GitHub

Optimize attention in llama4 (#5127)

parent 88799448
@@ -240,9 +240,13 @@ class Llama4Attention(nn.Module):
     def _get_attn_scale(self, positions: torch.Tensor) -> torch.Tensor:
         floor = torch.floor((positions + 1.0) / self.floor_scale)
         attn_scale = torch.log(floor + 1.0) * self.attn_scale + 1.0
         return attn_scale.unsqueeze(-1)
 
+    @torch.compile(dynamic=True, backend=get_compiler_backend())
+    def _mul_attn_scale(self, positions, q):
+        attn_scale = self._get_attn_scale(positions)
+        return (q * attn_scale).to(q.dtype)
+
     def forward(
         self,
         positions: torch.Tensor,
@@ -250,27 +254,29 @@ class Llama4Attention(nn.Module):
         forward_batch: ForwardBatch,
     ) -> torch.Tensor:
         qkv, _ = self.qkv_proj(hidden_states)
-        q, k, v = qkv.split([self.q_size, self.kv_size, self.kv_size], dim=-1)
+        qk, v = qkv.split([self.q_size + self.kv_size, self.kv_size], dim=-1)
 
         if self.rotary_emb is not None:
-            q, k = self.rotary_emb(positions, q, k)
+            q_view, k_view = qk.split([self.q_size, self.kv_size], dim=-1)
+            q_out_unused, k_out_unused = self.rotary_emb(positions, q_view, k_view)
+            assert (q_out_unused is q_view) and (k_out_unused is k_view)
+            del q_view, k_view, q_out_unused, k_out_unused
 
         if self.qk_norm is not None:
-            # TODO: support float
-            q = q.reshape(-1, self.head_dim).contiguous().bfloat16()
-            k = k.reshape(-1, self.head_dim).contiguous().bfloat16()
-            q = self.qk_norm(q).to(q.dtype)
-            k = self.qk_norm(k).to(k.dtype)
-            q = q.reshape(-1, self.q_size)
-            k = k.reshape(-1, self.kv_size)
+            # TODO there are still 2 redundant direct_copy_kernel_cuda for this `reshape` and (in attn backend) q.contiguous(), maybe we can fuse them later
+            qk = qk.reshape(-1, self.head_dim).contiguous().bfloat16()
+            qk = self.qk_norm(qk).to(torch.bfloat16)
+            qk = qk.reshape(-1, self.q_size + self.kv_size)
+
+        q, k = qk.split([self.q_size, self.kv_size], dim=-1)
 
         # We are applying temperature tuning (https://arxiv.org/abs/2501.19399) to NoPE layers, where
         # the inference-time temperature tuning function is customized to not affect short context
         # while working at very long context
         # https://arxiv.org/abs/2501.19399
         if self.attn_temperature_tuning and not self.use_rope:
-            attn_scale = self._get_attn_scale(positions)
-            q = (q * attn_scale).to(q.dtype)
+            q = self._mul_attn_scale(positions=positions, q=q)
 
         attn_output = self.attn(q, k, v, forward_batch)
         output, _ = self.o_proj(attn_output)
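
Note on the rewritten rotary path: it relies on `torch.split` returning views that alias the parent tensor, so a rotary embedding that mutates q/k in place also updates `qk`, letting the later qk_norm run once over the concatenated buffer. Below is a minimal sketch of that aliasing property, using toy sizes and a stand-in in-place rotation rather than the real `self.rotary_emb`:

import torch

# Toy shapes standing in for q_size / kv_size (assumptions, not the model's real sizes).
q_size, kv_size, num_tokens = 8, 4, 3
qk = torch.randn(num_tokens, q_size + kv_size)

# split() returns views that share storage with `qk`.
q_view, k_view = qk.split([q_size, kv_size], dim=-1)
assert q_view.data_ptr() == qk.data_ptr()

# A toy "rotary embedding" that mutates its inputs in place and returns them,
# mimicking the contract the diff asserts (outputs are the same objects as the inputs).
def fake_inplace_rotary(q, k):
    q.mul_(2.0)
    k.mul_(2.0)
    return q, k

before = qk.clone()
q_out, k_out = fake_inplace_rotary(q_view, k_view)
assert (q_out is q_view) and (k_out is k_view)

# Because the views alias `qk`, the parent tensor now holds the rotated values,
# so a single qk_norm / reshape pass over `qk` sees the updated q and k.
assert torch.allclose(qk, before * 2.0)

The `assert ... is ...` lines in the diff guard exactly this contract: if the rotary implementation ever returned fresh tensors instead of the same objects, the values later read out of `qk` would be stale.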
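Note on the temperature-tuning change: wrapping the scale multiply in `@torch.compile(dynamic=True, ...)` lets the elementwise chain (floor, log, multiply, dtype cast) be fused rather than dispatched as separate kernels, which is presumably where the savings come from. The snippet below recomputes the `_get_attn_scale` formula outside the model to show the behaviour the comment describes; the `floor_scale` and `attn_scale` values are illustrative assumptions, not read from any shipped config:

import math

# Illustrative hyperparameters (assumptions for this sketch only).
floor_scale = 8192.0
attn_scale = 0.1

def nope_attn_scale(pos: int) -> float:
    # Same arithmetic as _get_attn_scale, applied to a single position.
    floor = math.floor((pos + 1.0) / floor_scale)
    return math.log(floor + 1.0) * attn_scale + 1.0

# Stays at 1.0 for short contexts and grows logarithmically at very long context,
# matching the temperature-tuning comment in the diff.
for pos in (0, 1_000, 8_191, 8_192, 100_000, 1_000_000):
    print(f"position {pos:>9}: scale = {nope_attn_scale(pos):.4f}")

For positions below `floor_scale` the scale is exactly 1.0, so short-context behaviour is unchanged; only NoPE layers at long context get their query logits rescaled.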