flash_attn.py 8.03 KB
Newer Older
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
from functools import lru_cache
from typing import Optional, Union

import torch
import torch.nn as nn

# try:
#     from sgl_kernel import flash_ops
# except:
#     raise ImportError("Can not import sgl_kernel. Please check your installation.")

try:
    from ._fa4_interface import flash_attn_varlen_func as flash_attn_varlen_func_v4
except ImportError:
    flash_attn_varlen_func_v4 = None


@lru_cache(maxsize=1)
def is_fa3_supported(device=None) -> bool:
    """Return whether the FA3 kernels can be used on ``device``.

    FA3 is reported as supported when the CUDA version PyTorch was built
    against is at least 12.3 and the device's compute-capability major
    version is 8 or 9.

    Background notes:
      * FA3 can fail when there is not enough shared memory for some shapes,
        such as a larger hidden_dim or other special cases.
      * FA3 is supported for sm80/sm87 and sm86/sm89. The main difference
        between sm80/sm87 and sm86/sm89 is the shared memory size; see
        https://docs.nvidia.com/cuda/cuda-c-programming-guide/#shared-memory-8-x
      * sgl-kernel can currently build FA3 for sm80/sm86/sm89/sm90a, i.e.
        A100/A*0/L20/L40/L40s/4090 can use FA3.

    Args:
        device: Device passed to ``torch.cuda.get_device_capability``;
            ``None`` selects the current device.

    Returns:
        ``True`` if FA3 can be used, ``False`` otherwise (including builds
        without CUDA, e.g. ROCm or CPU-only, and hosts without a CUDA device).
    """
    cuda_version = torch.version.cuda
    # ``torch.version.cuda`` is None on non-CUDA builds; there is also nothing
    # to query when no CUDA device is present.
    if cuda_version is None or not torch.cuda.is_available():
        return False

    # Compare numerically: as strings, "12.10" would sort before "12.3".
    try:
        major_minor = tuple(int(part) for part in cuda_version.split(".")[:2])
    except ValueError:
        return False
    if major_minor < (12, 3):
        return False

    capability_major = torch.cuda.get_device_capability(device)[0]
    return capability_major in (8, 9)


def maybe_contiguous(x):
    """Return ``x`` with a unit-stride last dimension, copying only if needed.

    ``None`` and tensors whose last-dimension stride is already 1 are returned
    unchanged; anything else is replaced by ``x.contiguous()``.
    """
    if x is None:
        return x
    if x.stride(-1) == 1:
        return x
    return x.contiguous()



def flash_attn_with_kvcache(
    q,
    k_cache,
    v_cache,
    k=None,
    v=None,
    qv=None,
    rotary_cos=None,
    rotary_sin=None,
    cache_seqlens: Optional[Union[int, torch.Tensor]] = None,
    cache_batch_idx: Optional[torch.Tensor] = None,
    cache_leftpad: Optional[torch.Tensor] = None,
    page_table: Optional[torch.Tensor] = None,
    cu_seqlens_q: Optional[torch.Tensor] = None,
    cu_seqlens_k_new: Optional[torch.Tensor] = None,
    max_seqlen_q: Optional[int] = None,
    rotary_seqlens: Optional[torch.Tensor] = None,
    q_descale: Optional[torch.Tensor] = None,
    k_descale: Optional[torch.Tensor] = None,
    v_descale: Optional[torch.Tensor] = None,
    softmax_scale=None,
    causal=False,
    window_size=(-1, -1),  # -1 means infinite context window
    softcap=0.0,  # 0.0 means deactivated
    rotary_interleaved=True,
    scheduler_metadata=None,
    num_splits=0,  # Can be tuned for speed
    pack_gqa=None,  # Can be tuned for speed
    sm_margin=0,  # Can be tuned if some SMs are used for communication
    return_softmax_lse=False,
    sinks=None,
    ver=3,
):
    """Run attention of ``q`` against a key/value cache.

    On CUDA builds the arguments are made contiguous where needed and
    forwarded to ``torch.ops.sgl_kernel.fwd.default``. On HIP (ROCm) builds a
    plain PyTorch fallback is used instead.

    NOTE(review): the HIP fallback is a simplified dot-product attention. It
    only reads ``q``, ``k_cache``, ``v_cache``, ``softmax_scale`` and
    ``return_softmax_lse``; every other argument (new ``k``/``v``, ``qv``,
    rotary embeddings, ``cache_seqlens``, ``page_table``, ``causal``,
    ``window_size``, ``softcap``, ``sinks``, ...) is ignored and attention is
    taken over every entry of the cache. It appears to have been written for
    ``q`` of shape ``[1, 4, 256]`` and caches of shape ``[N, 1, 1, 256]`` --
    confirm against callers before relying on it for other shapes.

    Args:
        q: Query tensor.
        k_cache: Key cache; its last dimension must have stride 1.
        v_cache: Value cache; its last dimension must have stride 1.
        cache_seqlens: Either a tensor or an ``int``; an ``int`` is broadcast
            to an ``int32`` tensor with one entry per ``k_cache.shape[0]``.
        softmax_scale: Scale applied to the attention scores. Defaults to
            ``(q.shape[-1] + qv.shape[-1]) ** -0.5`` (``qv`` term only when
            ``qv`` is given; the HIP fallback uses ``q.shape[-1]`` only).
        window_size: ``(left, right)`` pair passed to the kernel as two
            separate arguments.
        return_softmax_lse: When true, also return the softmax log-sum-exp.
        ver: Kernel generation; only ``3`` is implemented here.
        k, v, qv, rotary_cos, rotary_sin, cache_batch_idx, cache_leftpad,
        page_table, cu_seqlens_q, cu_seqlens_k_new, max_seqlen_q,
        rotary_seqlens, q_descale, k_descale, v_descale, causal, softcap,
        rotary_interleaved, scheduler_metadata, num_splits, pack_gqa,
        sm_margin, sinks: Forwarded to the sgl_kernel op.

    Returns:
        ``out`` or, when ``return_softmax_lse`` is true, ``(out, softmax_lse)``.

    Raises:
        NotImplementedError: If ``ver == 4``.
        AssertionError: If a cache's last dimension is not contiguous
            (sgl_kernel path only).
    """
    if ver == 4:
        raise NotImplementedError("haven't implemented flash_attn_with_kvcache for fa4")

    # HIP (ROCm) detection and fallback.
    if getattr(torch.version, "hip", None) is not None:
        # Simple PyTorch fallback written against the observed tensor shapes:
        # q: [1, 4, 256], k_cache: [411528, 1, 1, 256], v_cache: [411528, 1, 1, 256]
        if softmax_scale is None:
            softmax_scale = (q.shape[-1]) ** (-0.5)

        # Drop the two singleton middle dimensions of the caches.
        k_flat = k_cache.squeeze(1).squeeze(1)  # [411528, 256]
        v_flat = v_cache.squeeze(1).squeeze(1)  # [411528, 256]

        # Plain dot-product attention over the whole cache.
        scores = torch.matmul(q, k_flat.T) * softmax_scale  # [1, 4, 411528]
        attn_weights = torch.softmax(scores, dim=-1)
        out = torch.matmul(attn_weights, v_flat)  # [1, 4, 256]

        if return_softmax_lse:
            # Real log-sum-exp of the scaled scores (float32, trailing
            # singleton dimension) instead of a fixed-shape tensor of zeros.
            softmax_lse = torch.logsumexp(scores.float(), dim=-1, keepdim=True)
            return out, softmax_lse
        return out

    # Original sgl_kernel implementation.
    assert k_cache.stride(-1) == 1, "k_cache must have contiguous last dimension"
    assert v_cache.stride(-1) == 1, "v_cache must have contiguous last dimension"
    if softmax_scale is None:
        softmax_scale = (q.shape[-1] + (qv.shape[-1] if qv is not None else 0)) ** (-0.5)
    if cache_seqlens is not None and isinstance(cache_seqlens, int):
        # Broadcast a scalar length to one int32 entry per cache row.
        cache_seqlens = torch.full(
            (k_cache.shape[0],), cache_seqlens, dtype=torch.int32, device=k_cache.device
        )
        cache_seqlens = maybe_contiguous(cache_seqlens)

    q, k_cache, k, v = [maybe_contiguous(x) for x in (q, k_cache, k, v)]
    v_cache = (
        v_cache.contiguous()
        if v_cache.stride(-1) != 1 and v_cache.stride(-3) != 1
        else v_cache
    )
    cu_seqlens_q, cu_seqlens_k_new = [
        maybe_contiguous(x) for x in (cu_seqlens_q, cu_seqlens_k_new)
    ]
    page_table, cache_batch_idx, cache_leftpad = [
        maybe_contiguous(x) for x in (page_table, cache_batch_idx, cache_leftpad)
    ]
    rotary_cos, rotary_sin = [maybe_contiguous(x) for x in (rotary_cos, rotary_sin)]
    rotary_seqlens = maybe_contiguous(rotary_seqlens)

    out, softmax_lse, *rest = torch.ops.sgl_kernel.fwd.default(
        q,
        k_cache,
        v_cache,
        k,
        v,
        qv,
        None,  # out
        cu_seqlens_q,
        None,  # cu_seqlens_k
        cu_seqlens_k_new,
        None,  # seqused_q
        cache_seqlens,
        max_seqlen_q,
        None,  # max_seqlen_k
        page_table,
        cache_batch_idx,
        cache_leftpad,
        rotary_cos,
        rotary_sin,
        rotary_seqlens,
        q_descale,
        k_descale,
        v_descale,
        softmax_scale,
        causal,
        window_size[0],
        window_size[1],
        softcap,
        rotary_interleaved,
        scheduler_metadata,
        num_splits,
        pack_gqa,
        sm_margin,
        sinks,
    )
    return (out, softmax_lse) if return_softmax_lse else out


def flash_attn_varlen_func(
    q,
    k,
    v,
    cu_seqlens_q,
    cu_seqlens_k,
    max_seqlen_q,
    max_seqlen_k,
    seqused_q=None,
    seqused_k=None,
    softmax_scale=None,
    causal=False,
    qv=None,
    q_descale=None,
    k_descale=None,
    v_descale=None,
    window_size=(-1, -1),
    softcap=0.0,
    num_splits=1,
    pack_gqa=None,
    sm_margin=0,
    return_softmax_lse=False,
    sinks=None,
    ver=3,
):
    """Run variable-length attention through the FA3 or FA4 kernels.

    With ``ver == 4`` the call is delegated to ``flash_attn_varlen_func_v4``.
    Otherwise the arguments are forwarded to
    ``torch.ops.sgl_kernel.fwd.default`` after checking ``is_fa3_supported()``.

    Args:
        q, k, v: Query, key and value tensors.
        cu_seqlens_q, cu_seqlens_k: Forwarded to the kernel as the cumulative
            sequence-length arguments.
        max_seqlen_q, max_seqlen_k: Maximum sequence lengths (FA3 path only).
        seqused_q, seqused_k: Forwarded unchanged to the kernel.
        softmax_scale: Score scale. On the FA3 path it defaults to
            ``(q.shape[-1] + qv.shape[-1]) ** -0.5`` (``qv`` term only when
            ``qv`` is given); on the FA4 path it is passed through as is.
        window_size: ``(left, right)`` pair; ``(-1, -1)`` means no sliding
            window.
        return_softmax_lse: When true, also return the softmax log-sum-exp.
        sinks: Passed as ``sinks`` (FA3) or ``learnable_sink`` (FA4).
        ver: Kernel generation, ``3`` (default) or ``4``.
        causal, qv, q_descale, k_descale, v_descale, softcap, num_splits,
        pack_gqa, sm_margin: Forwarded to the kernel (see the FA4 note below).

    Returns:
        FA4: whatever ``flash_attn_varlen_func_v4`` returns.
        FA3: ``out``, or ``(out, softmax_lse, *rest)`` when
        ``return_softmax_lse`` is true, where ``rest`` holds any additional
        outputs of the op.

    Raises:
        AssertionError: If ``ver == 4`` but the FA4 interface failed to import.
        NotImplementedError: If FA3 is not supported on the current setup.
    """
    if ver == 4:
        assert (
            flash_attn_varlen_func_v4 is not None
        ), "FA4 is not available, please check your installation."
        # Using `(-1, -1)` as no sliding window causes correctness issues for FA4.
        # Normalise through tuple() so an equivalent list is caught as well.
        if window_size is not None and tuple(window_size) == (-1, -1):
            window_size = (None, None)
        # max_seqlen_q, max_seqlen_k, qv, q_descale, k_descale, v_descale,
        # num_splits and sm_margin are not forwarded to the FA4 entry point.
        return flash_attn_varlen_func_v4(
            q,
            k,
            v,
            cu_seqlens_q,
            cu_seqlens_k,
            seqused_q=seqused_q,
            seqused_k=seqused_k,
            softmax_scale=softmax_scale,
            causal=causal,
            window_size=window_size,
            softcap=softcap,
            pack_gqa=pack_gqa,
            return_softmax_lse=return_softmax_lse,
            learnable_sink=sinks,
        )

    if not is_fa3_supported():
        # Wording matches is_fa3_supported(), which accepts compute-capability
        # major 8 or 9 together with CUDA >= 12.3.
        raise NotImplementedError(
            "flash_attn at sgl-kernel is only supported on sm8x/sm9x GPUs "
            "with CUDA >= 12.3"
        )

    if softmax_scale is None:
        softmax_scale = (q.shape[-1] + (qv.shape[-1] if qv is not None else 0)) ** (
            -0.5
        )

    out, softmax_lse, *rest = torch.ops.sgl_kernel.fwd.default(
        q,
        k,
        v,
        None,  # k_new
        None,  # v_new
        qv,  # qv
        None,  # out
        cu_seqlens_q,
        cu_seqlens_k,
        None,  # cu_seqlens_k_new
        seqused_q,
        seqused_k,
        max_seqlen_q,
        max_seqlen_k,
        None,  # page_table,
        None,  # kv_batch_idx
        None,  # leftpad_k
        None,  # rotary cos
        None,  # rotary sin
        None,  # seqlens_rotary
        q_descale,
        k_descale,
        v_descale,
        softmax_scale,
        causal,
        window_size[0],
        window_size[1],
        softcap,
        is_rotary_interleaved=False,
        scheduler_metadata=None,
        num_splits=num_splits,
        pack_gqa=pack_gqa,
        sm_margin=sm_margin,
        sinks=sinks,
    )

    return (out, softmax_lse, *rest) if return_softmax_lse else out