Unverified commit cd7e32e2, authored by fzyzcjy, committed by GitHub

Optimize attention in llama4 (#5127)

parent 88799448
```diff
@@ -240,9 +240,13 @@ class Llama4Attention(nn.Module):
     def _get_attn_scale(self, positions: torch.Tensor) -> torch.Tensor:
         floor = torch.floor((positions + 1.0) / self.floor_scale)
         attn_scale = torch.log(floor + 1.0) * self.attn_scale + 1.0
         return attn_scale.unsqueeze(-1)
 
+    @torch.compile(dynamic=True, backend=get_compiler_backend())
+    def _mul_attn_scale(self, positions, q):
+        attn_scale = self._get_attn_scale(positions)
+        return (q * attn_scale).to(q.dtype)
+
     def forward(
         self,
         positions: torch.Tensor,
```
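For reference, the scale that `_get_attn_scale` produces (and that the new compiled `_mul_attn_scale` applies) is `log(floor((pos + 1) / floor_scale) + 1) * attn_scale + 1`. Below is a minimal sketch of how it behaves; the `floor_scale` / `attn_scale` values are illustrative stand-ins for the model config, not values taken from this commit.

```python
import torch

# Illustrative config values; the real floor_scale / attn_scale come from the
# Llama4 model config and are not part of this diff.
floor_scale = 8192.0
attn_scale = 0.1

def get_attn_scale(positions: torch.Tensor) -> torch.Tensor:
    # Same formula as Llama4Attention._get_attn_scale above.
    floor = torch.floor((positions + 1.0) / floor_scale)
    return (torch.log(floor + 1.0) * attn_scale + 1.0).unsqueeze(-1)

positions = torch.tensor([0, 1_000, 8_191, 100_000, 1_000_000])
print(get_attn_scale(positions).squeeze(-1))
# tensor([1.0000, 1.0000, 1.0693, 1.2565, 1.4812])
# Positions with (pos + 1) < floor_scale give floor == 0, so the scale stays at
# exactly 1.0 (short contexts untouched); beyond that it grows logarithmically,
# which is what the temperature-tuning comment in the hunk below refers to.
```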
```diff
@@ -250,27 +254,29 @@ class Llama4Attention(nn.Module):
         forward_batch: ForwardBatch,
     ) -> torch.Tensor:
         qkv, _ = self.qkv_proj(hidden_states)
-        q, k, v = qkv.split([self.q_size, self.kv_size, self.kv_size], dim=-1)
+        qk, v = qkv.split([self.q_size + self.kv_size, self.kv_size], dim=-1)
 
         if self.rotary_emb is not None:
-            q, k = self.rotary_emb(positions, q, k)
+            q_view, k_view = qk.split([self.q_size, self.kv_size], dim=-1)
+            q_out_unused, k_out_unused = self.rotary_emb(positions, q_view, k_view)
+            assert (q_out_unused is q_view) and (k_out_unused is k_view)
+            del q_view, k_view, q_out_unused, k_out_unused
 
         if self.qk_norm is not None:
-            # TODO: support float
-            q = q.reshape(-1, self.head_dim).contiguous().bfloat16()
-            k = k.reshape(-1, self.head_dim).contiguous().bfloat16()
-            q = self.qk_norm(q).to(q.dtype)
-            k = self.qk_norm(k).to(k.dtype)
-            q = q.reshape(-1, self.q_size)
-            k = k.reshape(-1, self.kv_size)
+            # TODO there are still 2 redundant direct_copy_kernel_cuda for this `reshape` and (in attn backend) q.contiguous(), maybe we can fuse them later
+            qk = qk.reshape(-1, self.head_dim).contiguous().bfloat16()
+            qk = self.qk_norm(qk).to(torch.bfloat16)
+            qk = qk.reshape(-1, self.q_size + self.kv_size)
+
+        q, k = qk.split([self.q_size, self.kv_size], dim=-1)
```
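Two properties make the fused `qk` path above equivalent to the old per-tensor path: `split()` returns views that share storage with `qk` (the `assert ... is ...` then checks that the rotary embedding really updated those views in place), and a per-head norm applied once to the fused buffer matches normalizing `q` and `k` separately. Here is a small self-contained sanity check with toy shapes, using a plain unit-weight RMSNorm as a stand-in for `self.qk_norm` (an assumption, not the layer used in sglang).

```python
import torch

# Toy sizes standing in for the model config (assumptions for this sketch).
tokens, head_dim = 4, 8
num_q_heads, num_kv_heads = 5, 1
q_size, kv_size = num_q_heads * head_dim, num_kv_heads * head_dim

def rms_norm(x: torch.Tensor, eps: float = 1e-6) -> torch.Tensor:
    # Plain unit-weight RMSNorm over the last dim, standing in for self.qk_norm.
    return x * torch.rsqrt(x.pow(2).mean(dim=-1, keepdim=True) + eps)

q = torch.randn(tokens, q_size)
k = torch.randn(tokens, kv_size)
qk = torch.cat([q, k], dim=-1)  # analogous to the fused [q_size + kv_size] split output

# 1) split() gives views into qk's storage, so an in-place write through the views
#    (which is what the rotary embedding is asserted to do) lands in the fused buffer.
q_view, k_view = qk.split([q_size, kv_size], dim=-1)
q_view.mul_(2.0)  # stand-in for an in-place rotary update
assert torch.equal(qk[:, :q_size], q * 2.0)

# 2) One per-head norm over the fused buffer equals normalizing q and k separately.
fused = rms_norm(qk.reshape(-1, head_dim)).reshape(tokens, q_size + kv_size)
separate = torch.cat(
    [
        rms_norm((q * 2.0).reshape(-1, head_dim)).reshape(tokens, q_size),
        rms_norm(k.reshape(-1, head_dim)).reshape(tokens, kv_size),
    ],
    dim=-1,
)
torch.testing.assert_close(fused, separate)
```

The remainder of the hunk then applies the compiled `_mul_attn_scale` helper: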
```diff
         # We are applying temperature tuning (https://arxiv.org/abs/2501.19399) to NoPE layers, where
         # the inference-time temperature tuning function is customized to not affect short context
         # while working at very long context
         # https://arxiv.org/abs/2501.19399
         if self.attn_temperature_tuning and not self.use_rope:
-            attn_scale = self._get_attn_scale(positions)
-            q = (q * attn_scale).to(q.dtype)
+            q = self._mul_attn_scale(positions=positions, q=q)
 
         attn_output = self.attn(q, k, v, forward_batch)
         output, _ = self.o_proj(attn_output)
```
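This last change moves the scale computation and multiply into the `@torch.compile(dynamic=True)`-decorated helper, so the floor/log/mul/cast chain is handed to the compiler backend as one unit instead of running as separate eager ops. A standalone sketch of the same pattern, checked against the old inline code path, is below; it uses the default backend rather than sglang's `get_compiler_backend()`, and the config values are again illustrative assumptions.

```python
import torch

floor_scale, attn_scale = 8192.0, 0.1  # illustrative values, as in the first sketch

@torch.compile(dynamic=True)  # the commit additionally passes backend=get_compiler_backend()
def mul_attn_scale(positions: torch.Tensor, q: torch.Tensor) -> torch.Tensor:
    floor = torch.floor((positions + 1.0) / floor_scale)
    scale = (torch.log(floor + 1.0) * attn_scale + 1.0).unsqueeze(-1)
    return (q * scale).to(q.dtype)

positions = torch.arange(16)
q = torch.randn(16, 128, dtype=torch.bfloat16)

# Eager reference matching the pre-commit inline code path.
floor = torch.floor((positions + 1.0) / floor_scale)
scale = (torch.log(floor + 1.0) * attn_scale + 1.0).unsqueeze(-1)
reference = (q * scale).to(q.dtype)

torch.testing.assert_close(mul_attn_scale(positions, q), reference)
```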