Unverified Commit 7a1c4025 authored by Ignacio Sica's avatar Ignacio Sica Committed by GitHub
Browse files

[Kernel] [CPU] refactor `cpu_attn.py:_run_sdpa_forward` for better memory access (#24701)


Signed-off-by: default avatarignaciosica <mignacio.sica@gmail.com>
parent 60a09519
......@@ -641,10 +641,6 @@ class TorchSDPABackendImpl(AttentionImpl[TorchSDPAMetadata]):
attn_metadata: TorchSDPAMetadata,
attn_type: str = AttentionType.DECODER,
) -> None:
if self.num_kv_heads != self.num_heads:
key = key.repeat_interleave(self.num_queries_per_kv, dim=1)
value = value.repeat_interleave(self.num_queries_per_kv, dim=1)
attn_masks = attn_metadata.get_attn_bias(attn_type)
if attn_masks is None:
if self.alibi_slopes is not None:
......@@ -665,6 +661,10 @@ class TorchSDPABackendImpl(AttentionImpl[TorchSDPAMetadata]):
key = key.movedim(0, key.dim() - 2)
value = value.movedim(0, value.dim() - 2)
if self.num_kv_heads != self.num_heads:
key = key.repeat_interleave(self.num_queries_per_kv, dim=-3)
value = value.repeat_interleave(self.num_queries_per_kv, dim=-3)
causal_attn = (attn_type == AttentionType.DECODER)
seq_lens_q, seq_lens_kv = attn_metadata.get_seq_lens(attn_type)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment