# flash_attn.py
import torch
from loguru import logger

from lightx2v_platform.ops.attn.template import AttnWeightTemplate
from lightx2v_platform.registry_factory import PLATFORM_ATTN_WEIGHT_REGISTER

# Try to import Flash Attention (ROCm version 2.6.1)
try:
    from flash_attn import flash_attn_varlen_func

    FLASH_ATTN_AVAILABLE = True
    logger.info(f"Flash Attention (ROCm) is available")
except ImportError:
    logger.warning("Flash Attention not found. Will use PyTorch SDPA as fallback.")
    flash_attn_varlen_func = None
    FLASH_ATTN_AVAILABLE = False
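
# flash_attn_varlen_func expects "packed" (unpadded) inputs: sequences from all batch
# elements are concatenated along dim 0 and addressed via cumulative-length offsets
# (cu_seqlens_q / cu_seqlens_k), so no compute is spent on padding tokens.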


@PLATFORM_ATTN_WEIGHT_REGISTER("flash_attn_dcu")
class FlashAttnDcu(AttnWeightTemplate):
    """
    DCU Flash Attention implementation.

    Uses AMD ROCm version of Flash Attention 2.6.1 when available.
    Falls back to PyTorch SDPA (Scaled Dot Product Attention) if Flash Attention is not installed.

    Tested Environment:
    - PyTorch: 2.7.1
    - Python: 3.10
    - Flash Attention: 2.6.1 (ROCm)
    Reference: https://developer.sourcefind.cn/codes/modelzoo/wan2.1_pytorch/-/blob/master/wan/modules/attention.py
    """

    def __init__(self, weight_name="flash_attn_dcu"):
        super().__init__(weight_name)
        self.use_flash_attn = FLASH_ATTN_AVAILABLE

        if self.use_flash_attn:
            logger.info("Flash Attention 2.6.1 (ROCm) is available and will be used.")
        else:
            logger.warning("Flash Attention not available. Using PyTorch SDPA fallback.")

    def apply(
        self,
        q,
        k,
        v,
        q_lens=None,
        k_lens=None,
        dropout_p=0.0,
        softmax_scale=None,
        causal=False,
        window_size=(-1, -1),
        deterministic=False,
    ):
        """
        Execute Flash Attention computation.
        Args:
            q: [B, Lq, Nq, C1] Query tensor
            k: [B, Lk, Nk, C1] Key tensor
            v: [B, Lk, Nk, C2] Value tensor
            q_lens: [B] Optional sequence lengths for queries
            k_lens: [B] Optional sequence lengths for keys
            dropout_p: Dropout probability
            softmax_scale: Scaling factor for QK^T before softmax
            causal: Whether to apply causal mask
            window_size: Sliding window size tuple (left, right)
            deterministic: Whether to use deterministic algorithm
        Returns:
            Output tensor: [B, Lq, Nq, C2]
        """
        if not self.use_flash_attn:
            # Fallback to PyTorch SDPA
            return self._sdpa_fallback(q, k, v, causal, dropout_p)

        # Ensure data types are half precision
        half_dtypes = (torch.float16, torch.bfloat16)
        dtype = q.dtype if q.dtype in half_dtypes else torch.bfloat16
        out_dtype = q.dtype

        b, lq, lk = q.size(0), q.size(1), k.size(1)

        def half(x):
            return x if x.dtype in half_dtypes else x.to(dtype)

        # Preprocess query
        if q_lens is None:
            q_flat = half(q.flatten(0, 1))
            q_lens = torch.tensor([lq] * b, dtype=torch.int32, device=q.device)
        else:
            q_flat = half(torch.cat([u[:length] for u, length in zip(q, q_lens)]))

        # Preprocess key/value
        if k_lens is None:
            k_flat = half(k.flatten(0, 1))
            v_flat = half(v.flatten(0, 1))
            k_lens = torch.tensor([lk] * b, dtype=torch.int32, device=k.device)
        else:
            k_flat = half(torch.cat([u[:length] for u, length in zip(k, k_lens)]))
            v_flat = half(torch.cat([u[:length] for u, length in zip(v, k_lens)]))

        # Compute cumulative sequence lengths
        cu_seqlens_q = torch.cat([q_lens.new_zeros([1]), q_lens]).cumsum(0, dtype=torch.int32)
        cu_seqlens_k = torch.cat([k_lens.new_zeros([1]), k_lens]).cumsum(0, dtype=torch.int32)
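        # Example (illustrative): with b=2 and k_lens=[3, 5], cu_seqlens_k becomes
        # tensor([0, 3, 8], dtype=torch.int32); the varlen kernel uses these offsets
        # to locate each sequence inside the packed [total_tokens, heads, head_dim] tensors.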

        # Use Flash Attention 2.6.1 (ROCm version)
        output = flash_attn_varlen_func(
            q=q_flat,
            k=k_flat,
            v=v_flat,
            cu_seqlens_q=cu_seqlens_q,
            cu_seqlens_k=cu_seqlens_k,
            max_seqlen_q=lq,
            max_seqlen_k=lk,
            dropout_p=dropout_p,
            softmax_scale=softmax_scale,
            causal=causal,
            window_size=window_size,
            deterministic=deterministic,
        )
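        # flash_attn_varlen_func returns a packed tensor of shape [total_q_tokens, Nq, C2];
        # the unflatten below assumes uniform query lengths (total_q_tokens == b * lq).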

        # Reshape back to batch dimension
        output = output.unflatten(0, (b, lq))
        return output.to(out_dtype)

    def _sdpa_fallback(self, q, k, v, causal=False, dropout_p=0.0):
        """
        Fallback to PyTorch Scaled Dot Product Attention.
        Args:
            q: [B, Lq, Nq, C] Query tensor
            k: [B, Lk, Nk, C] Key tensor
            v: [B, Lk, Nk, C] Value tensor
            causal: Whether to apply causal mask
            dropout_p: Dropout probability
        Returns:
            Output tensor: [B, Lq, Nq, C]
        """
        # Transpose to [B, Nq, Lq, C] for SDPA
        q = q.transpose(1, 2)
        k = k.transpose(1, 2)
        v = v.transpose(1, 2)

        out = torch.nn.functional.scaled_dot_product_attention(q, k, v, attn_mask=None, is_causal=causal, dropout_p=dropout_p)

        # Transpose back to [B, Lq, Nq, C]
        return out.transpose(1, 2).contiguous()
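

# Illustrative usage sketch (not part of the module): it assumes the AttnWeightTemplate
# base class needs no setup beyond the weight name, that the registry decorator returns
# the class unchanged, and that a GPU is present whenever the Flash Attention path is
# taken; on CPU the SDPA fallback runs in float32.
if __name__ == "__main__":
    device = "cuda" if torch.cuda.is_available() else "cpu"
    dtype = torch.bfloat16 if device == "cuda" else torch.float32

    attn = FlashAttnDcu()

    # [B, L, heads, head_dim]: batch 2, sequence length 16, 4 heads, head dim 64
    q = torch.randn(2, 16, 4, 64, dtype=dtype, device=device)
    k = torch.randn(2, 16, 4, 64, dtype=dtype, device=device)
    v = torch.randn(2, 16, 4, 64, dtype=dtype, device=device)

    out = attn.apply(q, k, v, causal=False)
    logger.info(f"output shape: {tuple(out.shape)}")  # expected: (2, 16, 4, 64)

    # Variable key lengths (Flash Attention path only): pass padded k/v plus the true lengths.
    if attn.use_flash_attn and device == "cuda":
        k_lens = torch.tensor([10, 16], dtype=torch.int32, device=device)
        out_var = attn.apply(q, k, v, k_lens=k_lens)
        logger.info(f"varlen output shape: {tuple(out_var.shape)}")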