from ..base_extension import _Extension


class FlashAttentionNpuExtension(_Extension):
    def __init__(self):
        super().__init__(name="flash_attention_npu", support_aot=False, support_jit=False)

    def is_available(self) -> bool:
        try:
            import torch_npu

            return hasattr(torch_npu, "npu_fusion_attention")
        except Exception:
            # torch_npu is missing or failed to load, so the fused kernel is unavailable
            return False

    def assert_compatible(self) -> bool:
        # no compatibility assertions are performed for the NPU kernel
        pass

    def build_aot(self) -> None:
        raise NotImplementedError(
            "Flash Attention NPU does not require ahead-of-time compilation. Please use it by installing torch_npu."
        )

    def build_jit(self) -> None:
        raise NotImplementedError(
            "Flash Attention NPU does not require just-in-time compilation. Please use it by installing torch_npu."
        )

    def load(self):
        from typing import Optional

        import torch
        import torch_npu

        # Wrapper around torch_npu.npu_fusion_attention; the is_causal, cu_seqlens_*,
        # max_seqlen_* and *_indices arguments are accepted but not forwarded to the
        # fused kernel call below.
        def flash_attention(
            q: torch.Tensor,
            k: torch.Tensor,
            v: torch.Tensor,
            dropout_p: float = 0.0,
            scale: Optional[float] = None,
            attention_mask: Optional[torch.Tensor] = None,
            is_causal: bool = False,
            cu_seqlens_q: Optional[torch.Tensor] = None,
            cu_seqlens_kv: Optional[torch.Tensor] = None,
            max_seqlen_q: Optional[int] = None,
            max_seqlen_kv: Optional[int] = None,
            q_indices: Optional[torch.Tensor] = None,
            kv_indices: Optional[torch.Tensor] = None,
        ):
            # q, k and v follow the "BNSD" layout (batch, num_heads, seq_len, head_dim),
            # so the head count is read from dim 1
            num_heads = q.size(1)
            return torch_npu.npu_fusion_attention(
                q,
                k,
                v,
                num_heads,
                "BNSD",
                # convert the mask to bool for npu_fusion_attention, guarding against the None default
                atten_mask=attention_mask.bool() if attention_mask is not None else None,
                scale=scale,
                keep_prob=1 - dropout_p,  # keep_prob is the complement of the dropout probability
            )[0]  # the fused kernel returns a tuple; the first element is the attention output

        return flash_attention
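

# Usage sketch (illustrative, not part of the original module): on an Ascend machine with
# torch_npu installed, the kernel can be obtained and called roughly as follows, with q, k
# and v shaped [batch, num_heads, seq_len, head_dim] to match the "BNSD" layout. The tensor
# names and the explicit scale value here are assumptions for the example.
#
#     ext = FlashAttentionNpuExtension()
#     if ext.is_available():
#         flash_attention = ext.load()
#         out = flash_attention(q, k, v, attention_mask=mask, dropout_p=0.0, scale=q.size(-1) ** -0.5)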