Unverified commit 827fad66, authored by Leo Jiang and committed by GitHub

Improve performance of NPU FA (#12260)

Transpose query/key/value from BSND to BNSD before calling npu_fusion_attention and transpose the output back afterwards; the fused NPU attention kernel performs better with the BNSD input layout.

Co-authored-by: J石页 <jiangshuo9@h-partners.com>
Co-authored-by: Aryan <aryan@huggingface.co>
parent 9b721db2
@@ -955,12 +955,13 @@ def _native_npu_attention(
     dropout_p: float = 0.0,
     scale: Optional[float] = None,
 ) -> torch.Tensor:
-    return npu_fusion_attention(
+    query, key, value = (x.transpose(1, 2).contiguous() for x in (query, key, value))
+    out = npu_fusion_attention(
         query,
         key,
         value,
-        query.size(2),  # num_heads
-        input_layout="BSND",
+        query.size(1),  # num_heads
+        input_layout="BNSD",
         pse=None,
         scale=1.0 / math.sqrt(query.shape[-1]) if scale is None else scale,
         pre_tockens=65536,
@@ -969,6 +970,8 @@ def _native_npu_attention(
         sync=False,
         inner_precise=0,
     )[0]
+    out = out.transpose(1, 2).contiguous()
+    return out
 
 
 # Reference: https://github.com/pytorch/xla/blob/06c5533de6588f6b90aa1655d9850bcf733b90b4/torch_xla/experimental/custom_kernel.py#L853
...
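
For context, here is a minimal sketch of how `_native_npu_attention` reads after this patch. The diff above elides the argument lines between the two hunks; in this sketch `next_tockens` and `keep_prob` are filled in from the public `torch_npu.npu_fusion_attention` signature and should be treated as assumptions, not the file's verbatim contents.

```python
# Sketch of _native_npu_attention after this patch (not the verbatim file).
# Assumes torch_npu is installed and exposes npu_fusion_attention.
import math
from typing import Optional

import torch
from torch_npu import npu_fusion_attention


def _native_npu_attention(
    query: torch.Tensor,
    key: torch.Tensor,
    value: torch.Tensor,
    dropout_p: float = 0.0,
    scale: Optional[float] = None,
) -> torch.Tensor:
    # Inputs arrive as BSND: [batch, seq, num_heads, head_dim]. The fused
    # NPU attention kernel is faster with BNSD, so transpose to
    # [batch, num_heads, seq, head_dim] before the call.
    query, key, value = (x.transpose(1, 2).contiguous() for x in (query, key, value))
    out = npu_fusion_attention(
        query,
        key,
        value,
        query.size(1),  # num_heads: dim 1 after the transpose
        input_layout="BNSD",
        pse=None,
        scale=1.0 / math.sqrt(query.shape[-1]) if scale is None else scale,
        pre_tockens=65536,
        next_tockens=65536,  # assumption: elided between the diff hunks
        keep_prob=1.0 - dropout_p,  # assumption: elided between the diff hunks
        sync=False,
        inner_precise=0,
    )[0]
    # Transpose back so callers still receive the original BSND layout.
    out = out.transpose(1, 2).contiguous()
    return out
```

Note that `query.size(2)` becomes `query.size(1)` in the patch because the heads dimension moves from position 2 to position 1 after the transpose; the two `contiguous()` calls ensure the fused op receives dense BNSD buffers, and the transposes are cheap relative to the attention kernel itself.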