# flash_attention_npu.py
from ..base_extension import _Extension


class FlashAttentionNpuExtension(_Extension):
    def __init__(self):
        super().__init__(name="flash_attention_npu", support_aot=False, support_jit=False)

    def is_hardware_available(self) -> bool:
        try:
            # The NPU backend is considered available as long as torch_npu can be imported.
            import torch_npu  # noqa

            return True
        except ImportError:
            return False

    def assert_hardware_compatible(self) -> None:
        # No compatibility check beyond torch_npu being importable.
        pass

    def build_aot(self) -> None:
        raise NotImplementedError(
            "Flash Attention NPU does not require ahead-of-time compilation. Please use it by installing torch_npu."
        )

    def build_jit(self) -> None:
        raise NotImplementedError(
            "Flash Attention NPU does not require just-in-time compilation. Please use it by installing torch_npu."
        )

    def load(self):
        import torch
        from einops import rearrange

        def npu_sdpa_attention(
            q: torch.Tensor,
            k: torch.Tensor,
            v: torch.Tensor,
            seq_len_info_q=None,
            seq_len_info_kv=None,
            origin_attn_mask: torch.Tensor = None,
            dropout_p: float = 0.0,
            scale: float = 1.0,
            causal=None,
            padded=None,
        ):
            """
            The scaled dot product attention.

            Arguments:
                q: (batch, q_seqlen, nheads, headdim)
                k: (batch, kv_seqlen, nheads, headdim)
                v: (batch, kv_seqlen, nheads, headdim)
                batch_size: int.
                seq_len: int.
                dropout_p: float. Dropout probability.
                scale: float. The scaling of QK^T before applying softmax.
                    Default to 1.
            Return:
                attn_out: (batch, q_seqlen, nheads, headdim).
            """
            q, k, v = [rearrange(x, "b s h d -> b h s d").contiguous() for x in (q, k, v)]
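            # Computes softmax(QK^T * scale + mask) @ V with dropout on the attention
            # weights; causal masking is used only when no explicit mask is given.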
            output = torch.nn.functional.scaled_dot_product_attention(
                q,
                k,
                v,
                attn_mask=origin_attn_mask,
                dropout_p=dropout_p,
                is_causal=origin_attn_mask is None,
                scale=scale,
            )
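            # Merge heads: (batch, nheads, q_seqlen, headdim) -> (batch, q_seqlen, nheads * headdim).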
            output = rearrange(output, "b h s d -> b s (h d)")
            return output

        return npu_sdpa_attention
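

# ---------------------------------------------------------------------------
# Illustrative usage sketch (not part of the original module). The variable
# names, shapes, and device placement below are assumptions for demonstration;
# real q/k/v tensors would live on an Ascend NPU device via torch_npu.
#
#     ext = FlashAttentionNpuExtension()
#     if ext.is_hardware_available():
#         attention = ext.load()
#         # q, k, v: (batch, q_seqlen, nheads, headdim)
#         out = attention(q, k, v, dropout_p=0.0, scale=q.shape[-1] ** -0.5)
#         # out: (batch, q_seqlen, nheads * headdim)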