# flash_attn.py
import torch
from loguru import logger

from lightx2v_platform.ops.attn.template import AttnWeightTemplate
from lightx2v_platform.registry_factory import PLATFORM_ATTN_WEIGHT_REGISTER

# Try to import Flash Attention (ROCm version 2.6.1)
try:
    from flash_attn import flash_attn_varlen_func

    FLASH_ATTN_AVAILABLE = True
    logger.info(f"Flash Attention (ROCm) is available")
except ImportError:
    logger.warning("Flash Attention not found. Will use PyTorch SDPA as fallback.")
    flash_attn_varlen_func = None
    FLASH_ATTN_AVAILABLE = False


@PLATFORM_ATTN_WEIGHT_REGISTER("flash_attn_dcu")
class FlashAttnDcu(AttnWeightTemplate):
    """
    DCU Flash Attention implementation.
    Uses AMD ROCm version of Flash Attention 2.6.1 when available.
    Falls back to PyTorch SDPA (Scaled Dot Product Attention) if Flash Attention is not installed.
    Tested Environment:
    - PyTorch: 2.7.1
    - Python: 3.10
    - Flash Attention: 2.6.1 (ROCm)
    Reference: https://developer.sourcefind.cn/codes/modelzoo/wan2.1_pytorch/-/blob/master/wan/modules/attention.py
    """

    def __init__(self, weight_name="flash_attn_dcu"):
        super().__init__(weight_name)
        self.use_flash_attn = FLASH_ATTN_AVAILABLE

        if self.use_flash_attn:
            logger.info("Flash Attention 2.6.1 (ROCm) is available and will be used.")
        else:
            logger.warning("Flash Attention not available. Using PyTorch SDPA fallback.")

    def apply(
        self,
        q,
        k,
        v,
        q_lens=None,
        k_lens=None,
        dropout_p=0.0,
        softmax_scale=None,
        causal=False,
        window_size=(-1, -1),
        deterministic=False,
    ):
        """
        Execute Flash Attention computation.
        Args:
            q: [B, Lq, Nq, C1] Query tensor
            k: [B, Lk, Nk, C1] Key tensor
            v: [B, Lk, Nk, C2] Value tensor
            q_lens: [B] Optional sequence lengths for queries
            k_lens: [B] Optional sequence lengths for keys
            dropout_p: Dropout probability
            softmax_scale: Scaling factor for QK^T before softmax
            causal: Whether to apply causal mask
            window_size: Sliding window size tuple (left, right)
            deterministic: Whether to use deterministic algorithm
        Returns:
            Output tensor: [B, Lq, Nq, C2]
        """
        if not self.use_flash_attn:
            # Fall back to PyTorch SDPA (note: the fallback only honors causal and dropout_p;
            # q_lens/k_lens, softmax_scale, window_size and deterministic are ignored)
            return self._sdpa_fallback(q, k, v, causal, dropout_p)

        # flash_attn kernels only support fp16/bf16; pick a half-precision dtype for the kernel
        # and remember the incoming dtype so the output can be cast back at the end
        half_dtypes = (torch.float16, torch.bfloat16)
        dtype = q.dtype if q.dtype in half_dtypes else torch.bfloat16
        out_dtype = q.dtype

        b, lq, lk = q.size(0), q.size(1), k.size(1)

        def half(x):
            return x if x.dtype in half_dtypes else x.to(dtype)

        # Preprocess query
        if q_lens is None:
            q_flat = half(q.flatten(0, 1))
            q_lens = torch.tensor([lq] * b, dtype=torch.int32, device=q.device)
        else:
            q_flat = half(torch.cat([seq[:n] for seq, n in zip(q, q_lens)]))

        # Preprocess key/value
        if k_lens is None:
            k_flat = half(k.flatten(0, 1))
            v_flat = half(v.flatten(0, 1))
            k_lens = torch.tensor([lk] * b, dtype=torch.int32, device=k.device)
        else:
            k_flat = half(torch.cat([seq[:n] for seq, n in zip(k, k_lens)]))
            v_flat = half(torch.cat([seq[:n] for seq, n in zip(v, k_lens)]))

        # Compute cumulative sequence lengths
        cu_seqlens_q = torch.cat([q_lens.new_zeros([1]), q_lens]).cumsum(0, dtype=torch.int32)
        cu_seqlens_k = torch.cat([k_lens.new_zeros([1]), k_lens]).cumsum(0, dtype=torch.int32)
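        # e.g. with q_lens = [3, 5] the cumsum above yields cu_seqlens_q = [0, 3, 8]: the int32
        # offsets into the packed (flattened) tensor that flash_attn_varlen_func expects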

        # Use Flash Attention 2.6.1 (ROCm version)
        output = flash_attn_varlen_func(
            q=q_flat,
            k=k_flat,
            v=v_flat,
            cu_seqlens_q=cu_seqlens_q,
            cu_seqlens_k=cu_seqlens_k,
            max_seqlen_q=lq,
            max_seqlen_k=lk,
            dropout_p=dropout_p,
            softmax_scale=softmax_scale,
            causal=causal,
            window_size=window_size,
            deterministic=deterministic,
        )

        # Reshape back to [B, Lq, ...]; this assumes every query sequence has length lq
        # (i.e. q_lens was None or uniform), since the packed output is unflattened to (b, lq)
        output = output.unflatten(0, (b, lq))
        return output.to(out_dtype)

    def _sdpa_fallback(self, q, k, v, causal=False, dropout_p=0.0):
        """
        Fallback to PyTorch Scaled Dot Product Attention.
        Args:
            q: [B, Lq, Nq, C] Query tensor
            k: [B, Lk, Nk, C] Key tensor
            v: [B, Lk, Nk, C] Value tensor
            causal: Whether to apply causal mask
            dropout_p: Dropout probability
        Returns:
            Output tensor: [B, Lq, Nq, C]
        """
        # Transpose to [B, Nq, Lq, C] for SDPA
        q = q.transpose(1, 2)
        k = k.transpose(1, 2)
        v = v.transpose(1, 2)

        out = torch.nn.functional.scaled_dot_product_attention(q, k, v, attn_mask=None, is_causal=causal, dropout_p=dropout_p)

        # Transpose back to [B, Lq, Nq, C]
        return out.transpose(1, 2).contiguous()
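

# --- Minimal usage sketch (illustrative only, not part of the original module) ---
# Assumes AttnWeightTemplate.__init__ only needs the weight name, and that a GPU is present
# when the ROCm flash_attn wheel is installed; on CPU the PyTorch SDPA fallback path runs.
# Shapes follow the apply() docstring: [B, L, N, C].
if __name__ == "__main__":
    attn = FlashAttnDcu()
    b, lq, lk, n_heads, head_dim = 2, 16, 16, 8, 64
    device = "cuda" if torch.cuda.is_available() else "cpu"
    dtype = torch.bfloat16 if device == "cuda" else torch.float32
    q = torch.randn(b, lq, n_heads, head_dim, device=device, dtype=dtype)
    k = torch.randn(b, lk, n_heads, head_dim, device=device, dtype=dtype)
    v = torch.randn(b, lk, n_heads, head_dim, device=device, dtype=dtype)
    out = attn.apply(q, k, v)
    print(out.shape)  # expected: torch.Size([2, 16, 8, 64])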