test_flash_attn.py 71.5 KB
Newer Older
Tri Dao's avatar
Tri Dao committed
1
2
import math

Tri Dao's avatar
Tri Dao committed
3
import pytest
Tri Dao's avatar
Tri Dao committed
4
5
6
import torch
import torch.nn.functional as F
from einops import rearrange, repeat
Tri Dao's avatar
Tri Dao committed
7
8
9
10
11
12
13
from flash_attn import (
    flash_attn_func,
    flash_attn_kvpacked_func,
    flash_attn_qkvpacked_func,
    flash_attn_varlen_func,
    flash_attn_varlen_kvpacked_func,
    flash_attn_varlen_qkvpacked_func,
Tri Dao's avatar
Tri Dao committed
14
    flash_attn_with_kvcache,
Tri Dao's avatar
Tri Dao committed
15
)
16
from flash_attn.bert_padding import pad_input, unpad_input
Tri Dao's avatar
Tri Dao committed
17
18
19
from flash_attn.flash_attn_interface import _get_block_size

# Head-dimension threshold used by the tests below: gradients are always checked when
# d <= MAX_HEADDIM_SM8x, and for larger d only when is_sm80 or is_sm90 is set.
MAX_HEADDIM_SM8x = 192
Tri Dao's avatar
Tri Dao committed
20

Tri Dao's avatar
Tri Dao committed
21

Tri Dao's avatar
Tri Dao committed
22
23
24
25
# Flags describing the compute capability (major, minor) of the "cuda" device. They are
# evaluated once at import time and used by the tests below to select dtypes and to decide
# whether backward-pass checks are run.
is_sm75 = torch.cuda.get_device_capability("cuda") == (7, 5)
is_sm8x = torch.cuda.get_device_capability("cuda")[0] == 8
is_sm80 = torch.cuda.get_device_capability("cuda") == (8, 0)
is_sm90 = torch.cuda.get_device_capability("cuda") == (9, 0)
Tri Dao's avatar
Tri Dao committed
26
27


Tri Dao's avatar
Tri Dao committed
28
29
30
def generate_random_padding_mask(max_seqlen, batch_size, device, mode="random"):
    """Build a boolean padding mask in which each row is a prefix of True values.

    Arguments:
        max_seqlen: int, number of positions per row.
        batch_size: int, number of rows.
        device: device on which the mask is created.
        mode: how the per-row length is chosen:
            "full": every row has length max_seqlen.
            "random": lengths drawn from [max(1, max_seqlen - 20), max_seqlen].
            "third": lengths drawn from [max_seqlen // 3, max_seqlen].
    Output:
        padding_mask: (batch_size, max_seqlen), bool. Position j of row i is True iff
            j < length of row i.
    """
    assert mode in ["full", "random", "third"]
    if mode == "full":
        lengths = torch.full((batch_size, 1), max_seqlen, device=device, dtype=torch.int32)
    elif mode == "random":
        lengths = torch.randint(
            max(1, max_seqlen - 20), max_seqlen + 1, (batch_size, 1), device=device
        )
    else:  # mode == "third"
        lengths = torch.randint(max_seqlen // 3, max_seqlen + 1, (batch_size, 1), device=device)
    # (1, max_seqlen) positions broadcast against (batch_size, 1) lengths, so the position
    # row does not have to be materialized once per batch element.
    positions = torch.arange(max_seqlen, device=device).unsqueeze(0)
    padding_mask = positions < lengths
    return padding_mask


Tri Dao's avatar
Tri Dao committed
44
45
46
def generate_qkv(
    q, k, v, query_padding_mask=None, key_padding_mask=None, kvpacked=False, qkvpacked=False
):
    """Produce padded and unpadded (flattened) copies of q/k/v plus re-padding helpers.

    Arguments:
        q: (batch_size, seqlen_q, nheads, d)
        k: (batch_size, seqlen_k, nheads_k, d)
        v: (batch_size, seqlen_k, nheads_k, d)
        query_padding_mask: (batch_size, seqlen), bool, or None. Passed to unpad_input for q.
        key_padding_mask: (batch_size, seqlen), bool, or None. Passed to unpad_input for k, v.
        kvpacked: if True, k and v are stacked into a single kv tensor.
        qkvpacked: if True, q, k and v are stacked into a single qkv tensor. Requires
            nheads == nheads_k and identical (or both absent) padding masks.
            Mutually exclusive with kvpacked.
    Output (all returned tensors are detached and have requires_grad set):
        qkvpacked: (qkv_unpad, cu_seqlens_q, max_seqlen_q, qkv, output_pad_fn, dqkv_pad_fn)
        kvpacked: (q_unpad, kv_unpad, cu_seqlens_q, cu_seqlens_k, max_seqlen_q, max_seqlen_k,
            q, kv, output_pad_fn, dq_pad_fn, dkv_pad_fn)
        otherwise: (q_unpad, k_unpad, v_unpad, cu_seqlens_q, cu_seqlens_k, max_seqlen_q,
            max_seqlen_k, q, k, v, output_pad_fn, dq_pad_fn, dk_pad_fn)
        The *_pad_fn callables map a tensor in the unpadded layout back to the padded
        (batch_size, seqlen, ...) layout.
    """
    assert not (kvpacked and qkvpacked)
    batch_size, seqlen_q, nheads, d = q.shape
    _, seqlen_k, nheads_k, _ = k.shape
    assert k.shape == (batch_size, seqlen_k, nheads_k, d)
    assert v.shape == (batch_size, seqlen_k, nheads_k, d)

    if query_padding_mask is not None:
        q_unpad, indices_q, cu_seqlens_q, max_seqlen_q = unpad_input(q, query_padding_mask)
        output_pad_fn = lambda output_unpad: pad_input(
            output_unpad, indices_q, batch_size, seqlen_q
        )
    else:
        # No mask: every sequence has full length, so the cumulative lengths are multiples
        # of seqlen_q.
        q_unpad = rearrange(q, "b s h d -> (b s) h d")
        cu_seqlens_q = torch.arange(
            0, (batch_size + 1) * seqlen_q, step=seqlen_q, dtype=torch.int32, device=q_unpad.device
        )
        max_seqlen_q = seqlen_q
        output_pad_fn = lambda output_unpad: rearrange(
            output_unpad, "(b s) h d -> b s h d", b=batch_size
        )

    if key_padding_mask is not None:
        k_unpad, indices_k, cu_seqlens_k, max_seqlen_k = unpad_input(k, key_padding_mask)
        v_unpad, _, _, _ = unpad_input(v, key_padding_mask)
    else:
        k_unpad = rearrange(k, "b s h d -> (b s) h d")
        v_unpad = rearrange(v, "b s h d -> (b s) h d")
        cu_seqlens_k = torch.arange(
            0, (batch_size + 1) * seqlen_k, step=seqlen_k, dtype=torch.int32, device=k_unpad.device
        )
        max_seqlen_k = seqlen_k

    if qkvpacked:
        # Comparing with `==` only works when both masks are tensors; `None == None` is a
        # plain bool without `.all()`, so the mask-free case is handled explicitly.
        if query_padding_mask is None or key_padding_mask is None:
            assert query_padding_mask is None and key_padding_mask is None
        else:
            assert (query_padding_mask == key_padding_mask).all()
        assert nheads == nheads_k
        qkv_unpad = torch.stack([q_unpad, k_unpad, v_unpad], dim=1)
        qkv = torch.stack([q, k, v], dim=2)
        if query_padding_mask is not None:
            dqkv_pad_fn = lambda dqkv_unpad: pad_input(dqkv_unpad, indices_q, batch_size, seqlen_q)
        else:
            dqkv_pad_fn = lambda dqkv_unpad: rearrange(
                dqkv_unpad, "(b s) t h d -> b s t h d", b=batch_size
            )
        return (
            qkv_unpad.detach().requires_grad_(),
            cu_seqlens_q,
            max_seqlen_q,
            qkv.detach().requires_grad_(),
            output_pad_fn,
            dqkv_pad_fn,
        )
    elif kvpacked:
        kv_unpad = torch.stack([k_unpad, v_unpad], dim=1)
        kv = torch.stack([k, v], dim=2)
        dq_pad_fn = output_pad_fn
        if key_padding_mask is not None:
            dkv_pad_fn = lambda dkv_unpad: pad_input(dkv_unpad, indices_k, batch_size, seqlen_k)
        else:
            dkv_pad_fn = lambda dkv_unpad: rearrange(
                dkv_unpad, "(b s) t h d -> b s t h d", b=batch_size
            )
        return (
            q_unpad.detach().requires_grad_(),
            kv_unpad.detach().requires_grad_(),
            cu_seqlens_q,
            cu_seqlens_k,
            max_seqlen_q,
            max_seqlen_k,
            q.detach().requires_grad_(),
            kv.detach().requires_grad_(),
            output_pad_fn,
            dq_pad_fn,
            dkv_pad_fn,
        )
    else:
        dq_pad_fn = output_pad_fn
        if key_padding_mask is not None:
            dk_pad_fn = lambda dk_unpad: pad_input(dk_unpad, indices_k, batch_size, seqlen_k)
        else:
            dk_pad_fn = lambda dk_unpad: rearrange(dk_unpad, "(b s) h d -> b s h d", b=batch_size)
        return (
            q_unpad.detach().requires_grad_(),
            k_unpad.detach().requires_grad_(),
            v_unpad.detach().requires_grad_(),
            cu_seqlens_q,
            cu_seqlens_k,
            max_seqlen_q,
            max_seqlen_k,
            q.detach().requires_grad_(),
            k.detach().requires_grad_(),
            v.detach().requires_grad_(),
            output_pad_fn,
            dq_pad_fn,
            dk_pad_fn,
        )
Tri Dao's avatar
Tri Dao committed
150
151


152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
def construct_causal_mask(seqlen_q, seqlen_k, query_padding_mask=None, key_padding_mask=None,
                          device=None):
    """Return a boolean mask that is True where a query must NOT attend to a key.

    Entry (row, col) is True when col > row + sk - sq, where sq / sk are the query / key
    lengths: seqlen_q / seqlen_k when the corresponding padding mask is None, otherwise
    the per-batch number of True entries of that mask.

    Output:
        (seqlen_q, seqlen_k) when both padding masks are None, otherwise a tensor that
        broadcasts to (batch_size, 1, seqlen_q, seqlen_k).
    """
    if query_padding_mask is None:
        q_len = seqlen_q
    else:
        q_len = rearrange(query_padding_mask.sum(-1), "b -> b 1 1 1")
    if key_padding_mask is None:
        k_len = seqlen_k
    else:
        k_len = rearrange(key_padding_mask.sum(-1), "b -> b 1 1 1")
    rows = rearrange(torch.arange(seqlen_q, device=device, dtype=torch.long), "s -> s 1")
    cols = torch.arange(seqlen_k, device=device, dtype=torch.long)
    return cols > rows + k_len - q_len


Tri Dao's avatar
Tri Dao committed
169
170
171
172
173
174
175
176
177
178
179
180
def attention_ref(
    q,
    k,
    v,
    query_padding_mask=None,
    key_padding_mask=None,
    dropout_p=0.0,
    dropout_mask=None,
    causal=False,
    upcast=True,
    reorder_ops=False,
):
    """Reference (non-fused) attention implemented with plain PyTorch ops.

    Arguments:
        q: (batch_size, seqlen_q, nheads, head_dim)
        k: (batch_size, seqlen_k, nheads_k, head_dim)
        v: (batch_size, seqlen_k, nheads_k, head_dim)
        query_padding_mask: (batch_size, seqlen_q)
        key_padding_mask: (batch_size, seqlen_k)
        dropout_p: float
        dropout_mask: (batch_size, nheads, seqlen_q, seqlen_k)
        causal: whether to apply the mask built by construct_causal_mask.
        upcast: whether to cast all inputs to fp32, do all computation in fp32, then cast
            output back to fp16/bf16.
        reorder_ops: whether to change the order of operations (scaling k instead of scaling q, etc.)
            without changing the math. This is to estimate the numerical error from operation
            reordering.
    Output:
        output: (batch_size, seqlen_q, nheads, head_dim)
        attention: (batch_size, nheads, seqlen_q, seqlen_k), softmax probabilities before
            dropout is applied (rows of padded queries are zeroed).
    """
    dtype_og = q.dtype
    if upcast:
        q, k, v = q.float(), k.float(), v.float()
    seqlen_q, seqlen_k = q.shape[1], k.shape[1]
    # Repeat each k/v head so that k and v end up with as many heads as q.
    k = repeat(k, "b s h d -> b s (h g) d", g=q.shape[2] // k.shape[2])
    v = repeat(v, "b s h d -> b s (h g) d", g=q.shape[2] // v.shape[2])
    d = q.shape[-1]
    if not reorder_ops:
        scores = torch.einsum("bthd,bshd->bhts", q / math.sqrt(d), k)
    else:
        scores = torch.einsum("bthd,bshd->bhts", q, k / math.sqrt(d))
    if key_padding_mask is not None:
        scores.masked_fill_(rearrange(~key_padding_mask, "b s -> b 1 1 s"), float("-inf"))
    if causal:
        causal_mask = construct_causal_mask(
            seqlen_q, seqlen_k, query_padding_mask, key_padding_mask, q.device
        )
        scores.masked_fill_(causal_mask, float("-inf"))
    attention = torch.softmax(scores, dim=-1)
    if causal:  # Some rows are completely masked out so we fill them with zero instead of NaN
        attention = attention.masked_fill(torch.all(causal_mask, dim=-1, keepdim=True), 0.0)
    dropout_scaling = 1.0 / (1 - dropout_p)
    if dropout_mask is not None:
        attention_drop = attention.masked_fill(~dropout_mask, 0.0)
    else:
        attention_drop = attention
    # The dropout rescaling is applied to v rather than to the attention matrix.
    output = torch.einsum("bhts,bshd->bthd", attention_drop, v * dropout_scaling)
    if query_padding_mask is not None:
        output.masked_fill_(rearrange(~query_padding_mask, "b s -> b s 1 1"), 0.0)
        attention = attention.masked_fill(rearrange(~query_padding_mask, "b s -> b 1 s 1"), 0.0)
    return output.to(dtype=dtype_og), attention.to(dtype=dtype_og)


Tri Dao's avatar
Tri Dao committed
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
def attention_kvpacked_ref(
    q,
    kv,
    query_padding_mask=None,
    key_padding_mask=None,
    dropout_p=0.0,
    dropout_mask=None,
    causal=False,
    upcast=True,
    reorder_ops=False,
):
    """Run attention_ref on a packed kv tensor.

    kv is indexed along dim 2: index 0 is used as k and index 1 as v. All other arguments
    are forwarded unchanged and the return value is that of attention_ref.
    """
    keys = kv.select(2, 0)
    values = kv.select(2, 1)
    return attention_ref(
        q,
        keys,
        values,
        query_padding_mask=query_padding_mask,
        key_padding_mask=key_padding_mask,
        dropout_p=dropout_p,
        dropout_mask=dropout_mask,
        causal=causal,
        upcast=upcast,
        reorder_ops=reorder_ops,
    )
Tri Dao's avatar
Tri Dao committed
260
261


Tri Dao's avatar
Tri Dao committed
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
def attention_qkvpacked_ref(
    qkv,
    key_padding_mask=None,
    dropout_p=0.0,
    dropout_mask=None,
    causal=False,
    upcast=True,
    reorder_ops=False,
):
    """Run attention_ref on a packed qkv tensor.

    qkv is indexed along dim 2: index 0 is used as q, 1 as k and 2 as v. The single
    key_padding_mask is used as both the query and the key padding mask. The return value
    is that of attention_ref.
    """
    queries = qkv.select(2, 0)
    keys = qkv.select(2, 1)
    values = qkv.select(2, 2)
    return attention_ref(
        queries,
        keys,
        values,
        query_padding_mask=key_padding_mask,
        key_padding_mask=key_padding_mask,
        dropout_p=dropout_p,
        dropout_mask=dropout_mask,
        causal=causal,
        upcast=upcast,
        reorder_ops=reorder_ops,
    )
Tri Dao's avatar
Tri Dao committed
283
284
285
286
287
288
289
290
291
292
293


def generate_sparsity_mask(seqlen, sparsity=0.3, device="cuda"):
    """Draw a random boolean block mask for a sequence of length seqlen.

    Arguments:
        seqlen: int, sequence length.
        sparsity: float, each entry is True independently when a uniform sample from
            [0, 1) is below this value.
        device: device on which the mask is created (defaults to "cuda").
    Output:
        mask: (seqlen // 16, seqlen // 256), bool.
    """
    nrow, ncol = seqlen // 16, seqlen // 256
    mask = torch.rand(nrow, ncol, device=device) < sparsity
    return mask


def attention_blocksparse_ref(qkv, blockmask, attn_mask, dropout_p, dropout_mask):
    """Reference block-sparse attention implemented with plain PyTorch ops (in fp32).

    Arguments:
        qkv: (batch_size, seqlen, 3, nheads, head_dim)
        blockmask: (seqlen / 16, seqlen / 256), bool. False blocks are masked out.
        attn_mask: (batch_size, seqlen), bool. False positions are masked out.
        dropout_p: float
        dropout_mask: (batch_size, nheads, seqlen, seqlen)
    Output:
        output: (batch_size, seqlen, nheads, head_dim)
        attention: softmax probabilities before dropout, with masked entries set to zero
    """
    q, k, v = qkv.float().unbind(dim=2)
    d = qkv.shape[-1]
    seqlen = qkv.shape[1]
    scores = torch.einsum("bthd,bshd->bhts", q / math.sqrt(d), k)
    scores.masked_fill_(rearrange(~attn_mask, "b s -> b 1 1 s"), float("-inf"))
    # Expand every block-mask entry to a 16 x 256 tile, then crop to the sequence length.
    blockmask = repeat(blockmask, "s_16 s_256 -> (s_16 16) (s_256 256)")
    blockmask = blockmask[:seqlen, :seqlen]
    scores.masked_fill_(rearrange(~blockmask, "t s -> 1 1 t s"), float("-inf"))
    attention = torch.softmax(scores, dim=-1)
    # Zero out the masked query rows and masked blocks of the softmax output.
    attention = attention.masked_fill(rearrange(~attn_mask, "b s -> b 1 s 1"), 0.0)
    attention = attention.masked_fill_(rearrange(~blockmask, "t s -> 1 1 t s"), 0.0)
    attention_drop = attention.masked_fill(~dropout_mask, 0.0) / (1 - dropout_p)
    output = torch.einsum("bhts,bshd->bthd", attention_drop, v)
    output.masked_fill_(rearrange(~attn_mask, "b s -> b s 1 1"), 0)
    return output.to(dtype=qkv.dtype), attention.to(dtype=qkv.dtype)


Tri Dao's avatar
Tri Dao committed
327
def convert_flash_attn_S_to_softmax(
    S, seqlen_q, seqlen_k, query_padding_mask, key_padding_mask, head_dim, is_dropout, causal=False
):
    """FlashAttention stores the S matrix in a different way.

    Rearranges S from its blocked storage order into a regular row/column layout, zeroes
    the entries excluded by the causal and padding masks, and crops to
    (seqlen_q, seqlen_k).

    Arguments:
        S: (batch_size, nheads, seqlen_q_rounded, seqlen_k_rounded)
        seqlen_q, seqlen_k: int, sizes of the returned last two dimensions.
        query_padding_mask: bool mask over query positions, or None. It is right-padded
            to seqlen_q_rounded before being applied.
        key_padding_mask: bool mask over key positions, or None. It is right-padded
            to seqlen_k_rounded before being applied.
        head_dim, is_dropout, causal: forwarded to _get_block_size to obtain the block sizes.
    Output:
        S_converted: (batch_size, nheads, seqlen_q, seqlen_k)
    """
    seqlen_q_rounded, seqlen_k_rounded = S.shape[-2:]
    warps_n = 4
    blocksize_m, blocksize_n = _get_block_size(S.device, head_dim, is_dropout, causal)
    mmas_n = (blocksize_n + 16 - 1) // 16
    # Split S into (blocksize_m x blocksize_n) blocks and flatten each block.
    S_flat = rearrange(
        S,
        "b h (nblocks_m blocksize_m) (nblocks_n blocksize_n) -> b h nblocks_m nblocks_n (blocksize_m blocksize_n)",
        blocksize_m=blocksize_m,
        blocksize_n=blocksize_n,
    )
    # Undo the within-block element ordering.
    S_converted = rearrange(
        S_flat,
        "b h nblocks_m nblocks_n (mmas_n mmas_m warps_n eight four c2 c1 c0) -> b h (nblocks_m mmas_m warps_n c1 eight) (nblocks_n mmas_n c2 four c0)",
        mmas_n=mmas_n,
        warps_n=warps_n,
        eight=8,
        c0=2,
        c1=2,
        c2=2,
        four=4,
    )

    if causal:
        causal_mask = construct_causal_mask(
            seqlen_q, seqlen_k, query_padding_mask, key_padding_mask, S.device
        )
        # Treat the rounded-up region beyond (seqlen_q, seqlen_k) as masked as well.
        causal_mask = F.pad(
            causal_mask,
            (0, seqlen_k_rounded - seqlen_k, 0, seqlen_q_rounded - seqlen_q),
            value=True,
        )
        S_converted.masked_fill_(causal_mask, 0.0)

    # Need to zero out things not in attention_mask in case S was initialized with random values
    # and some of those values aren't overwritten.
    seqlen_q_og = query_padding_mask.shape[-1] if query_padding_mask is not None else seqlen_q_rounded
    if query_padding_mask is not None:
        query_padding_mask = F.pad(query_padding_mask, (0, seqlen_q_rounded - seqlen_q_og))
        S_converted = S_converted.masked_fill(rearrange(~query_padding_mask, "b s -> b 1 s 1"), 0.0)
    seqlen_k_og = key_padding_mask.shape[-1] if key_padding_mask is not None else seqlen_k
    if key_padding_mask is not None:
        key_padding_mask = F.pad(key_padding_mask, (0, seqlen_k_rounded - seqlen_k_og))
        S_converted = S_converted.masked_fill(rearrange(~key_padding_mask, "b s -> b 1 1 s"), 0.0)
    # The pad amounts are <= 0 here, so these calls crop the rounded-up rows / columns.
    S_converted = F.pad(S_converted, (0, 0, 0, seqlen_q_og - seqlen_q_rounded))
    S_converted = F.pad(S_converted, (0, seqlen_k_og - seqlen_k_rounded))
    return S_converted[:, :, :seqlen_q, :seqlen_k]
Tri Dao's avatar
Tri Dao committed
383
384


Tri Dao's avatar
Tri Dao committed
385
386
387
388
389
390
391
392
393
394
def normalize_flash_attn_S(
    attn_unnorm,
    q,
    k,
    v,
    query_padding_mask=None,
    key_padding_mask=None,
    is_dropout=False,
    causal=False,
):
    """Rescale the unnormalized attention values into softmax probabilities.

    The scores are recomputed from q and k in fp32. attn_unnorm is split into column
    blocks of width block_size_n (from _get_block_size) and each block is multiplied by
    exp(m - lse), where lse is the row-wise log-sum-exp of the scores and m is the maximum
    score over that block and all blocks to its right.

    Arguments:
        attn_unnorm: unnormalized attention, split along its last dimension.
        q: (batch_size, seqlen_q, nheads, head_dim)
        k, v: (batch_size, seqlen_k, nheads, head_dim)
        query_padding_mask: (batch_size, seqlen_q)
        key_padding_mask: (batch_size, seqlen_k)
        is_dropout, causal: forwarded to _get_block_size; causal also enables the mask
            built by construct_causal_mask.
    Output:
        attn_norm: normalized attention, same dtype as attn_unnorm.
    """
    q, k, v = q.float(), k.float(), v.float()
    _, seqlen_q, _, head_dim = q.shape
    seqlen_k = k.shape[1]
    scores = torch.einsum("bthd,bshd->bhts", q / math.sqrt(head_dim), k)
    if key_padding_mask is not None:
        scores.masked_fill_(rearrange(~key_padding_mask, "b s -> b 1 1 s"), float("-inf"))
    if causal:
        causal_mask = construct_causal_mask(
            seqlen_q, seqlen_k, query_padding_mask, key_padding_mask, q.device
        )
        scores.masked_fill_(causal_mask, float("-inf"))
    _, block_size_n = _get_block_size(scores.device, head_dim, is_dropout, causal)
    scores_block = scores.split(block_size_n, dim=-1)
    lse_block = torch.stack([torch.logsumexp(s, dim=-1) for s in scores_block], dim=-1)
    lse = torch.logsumexp(lse_block, dim=-1)
    # lse could be -inf (i.e. all values in scores are -inf), and we want to set those to inf
    # so that when we do torch.exp(m - lse), we get 0.0 instead of NaN.
    lse[lse == float("-inf")] = float("inf")
    scores_max_block = torch.stack([torch.amax(s, dim=-1) for s in scores_block], dim=-1)
    # Flipping before and after cummax yields, for each block, the max over that block and
    # every block after it.
    cummax_block = torch.cummax(scores_max_block.flip(-1), dim=-1).values.flip(-1).unbind(dim=-1)
    attn_unnorm_block = attn_unnorm.split(block_size_n, dim=-1)
    attn_norm = torch.cat(
        [
            a * rearrange(torch.exp(m - lse), "b h s -> b h s 1")
            for a, m in zip(attn_unnorm_block, cummax_block)
        ],
        dim=-1,
    )
    if query_padding_mask is not None:
        attn_norm.masked_fill_(rearrange(~query_padding_mask, "b s -> b 1 s 1"), 0.0)
    return attn_norm.to(dtype=attn_unnorm.dtype)


Tri Dao's avatar
Tri Dao committed
440
441
442
def get_dropout_fraction(
    dropout_mask, query_padding_mask=None, key_padding_mask=None, causal=False
):
    """Measure the fraction of attention entries that were dropped.

    Entries excluded by the padding masks or by the causal mask are counted neither as
    dropped nor in the denominator.

    Arguments:
        dropout_mask: (batch_size, nheads, seqlen_q, seqlen_k), bool. True means keep, False means drop.
        query_padding_mask: (batch_size, seqlen_q)
        key_padding_mask: (batch_size, seqlen_k)
        causal: whether entries excluded by construct_causal_mask are ignored.
    Output:
        0-dim tensor: number of dropped entries / (number of counted entries * nheads).
    """
    batch_size, nheads, seqlen_q, seqlen_k = dropout_mask.shape
    # `~` allocates a new tensor, so the in-place fills below do not modify dropout_mask.
    dropped = ~dropout_mask
    if query_padding_mask is not None:
        dropped.masked_fill_(rearrange(~query_padding_mask, "b s -> b 1 s 1"), False)
    if key_padding_mask is not None:
        dropped.masked_fill_(rearrange(~key_padding_mask, "b s -> b 1 1 s"), False)
    if causal:
        causal_mask = construct_causal_mask(
            seqlen_q, seqlen_k, query_padding_mask, key_padding_mask, dropout_mask.device
        )
        dropped.masked_fill_(causal_mask, False)
    dropped_total = dropped.sum()
    query_lengths = (
        query_padding_mask.sum(dim=-1)
        if query_padding_mask is not None
        else torch.full((batch_size,), seqlen_q, device=dropout_mask.device)
    )
    key_lengths = (
        key_padding_mask.sum(dim=-1)
        if key_padding_mask is not None
        else torch.full((batch_size,), seqlen_k, device=dropout_mask.device)
    )
    if not causal:
        numel_per_batch = query_lengths * key_lengths
    else:
        # Number of (query, key) pairs left unmasked by col > row + sk - sq.
        numel_per_batch = torch.where(
            key_lengths <= query_lengths,
            key_lengths * (key_lengths + 1) / 2,
            query_lengths * key_lengths - (query_lengths * (query_lengths - 1) / 2),
        )
    return dropped_total / (numel_per_batch.sum() * nheads)


Tri Dao's avatar
Tri Dao committed
484
@pytest.mark.parametrize("dtype", ([torch.float16] if is_sm75 else [torch.float16, torch.bfloat16]))
# @pytest.mark.parametrize('dtype', [torch.float16])
@pytest.mark.parametrize("causal", [False, True])
# @pytest.mark.parametrize('causal', [True])
@pytest.mark.parametrize("d", [32, 40, 59, 64, 80, 96, 111, 128, 160, 192, 224, 256])
# @pytest.mark.parametrize('d', [32, 64, 96, 128, 160, 192, 224, 256])
# @pytest.mark.parametrize('d', [32, 64, 96, 128])
# @pytest.mark.parametrize('d', [64])
# @pytest.mark.parametrize('seqlen', [128, 256, 384, 512, 768, 1024, 2048])
@pytest.mark.parametrize("seqlen", [97, 128, 200, 256, 257, 384, 512, 768, 1024, 1025, 2048])
# @pytest.mark.parametrize('seqlen', [97])
@pytest.mark.parametrize("dropout_p", [0.0, 0.17])
# @pytest.mark.parametrize('dropout_p', [0.17])
def test_flash_attn_qkvpacked(seqlen, d, dropout_p, causal, dtype):
    """Compare flash_attn_qkvpacked_func against the PyTorch reference implementations.

    The fp32-upcast reference (attention_qkvpacked_ref with default arguments) is treated
    as ground truth. FlashAttention passes when its error against that reference is at
    most twice the error of the same-dtype, reordered PyTorch implementation. Output,
    attention probabilities and dropout fraction (when dropout is on) and gradients (when
    the head dimension / device allows) are all checked.
    """
    if seqlen >= 2048 and torch.cuda.get_device_properties("cuda").total_memory <= 16 * 2**30:
        pytest.skip()  # Reference implementation OOM
    device = "cuda"
    # set seed
    torch.random.manual_seed(0)
    batch_size = 16
    nheads = 9
    # Gradients are only compared for d <= MAX_HEADDIM_SM8x, or for any d on sm80 / sm90.
    check_backward = d <= MAX_HEADDIM_SM8x or (is_sm80 or is_sm90)
    qkv = torch.randn(
        batch_size, seqlen, 3, nheads, d, device=device, dtype=dtype, requires_grad=True
    )
    out, lse, S_dmask = flash_attn_qkvpacked_func(
        qkv, dropout_p, return_attn_probs=True, causal=causal
    )
    if dropout_p > 0.0:
        S_dmask_converted = convert_flash_attn_S_to_softmax(
            S_dmask, seqlen, seqlen, None, None, d, dropout_p > 0.0, causal=causal
        )
        # The sign of each entry carries the dropout decision (>= 0 means kept); the
        # magnitude is the unnormalized attention value.
        dropout_mask = S_dmask_converted >= 0
        attn_unnorm = S_dmask_converted.abs()
        attn = normalize_flash_attn_S(
            attn_unnorm,
            qkv[:, :, 0],
            qkv[:, :, 1],
            qkv[:, :, 2],
            None,
            None,
            dropout_p > 0.0,
            causal=causal,
        )
        dropout_fraction = get_dropout_fraction(dropout_mask, None, None, causal=causal).item()
        print(f"Actual dropout fraction: {dropout_fraction}")
    else:
        dropout_mask = None

    out_ref, attn_ref = attention_qkvpacked_ref(qkv, None, dropout_p, dropout_mask, causal=causal)
    out_pt, attn_pt = attention_qkvpacked_ref(
        qkv, None, dropout_p, dropout_mask, causal=causal, upcast=False, reorder_ops=True
    )

    print(f"Output max diff: {(out - out_ref).abs().max().item()}")
    print(f"Output mean diff: {(out - out_ref).abs().mean().item()}")
    print(f"Pytorch max diff: {(out_pt - out_ref).abs().max().item()}")
    print(f"Pytorch mean diff: {(out_pt - out_ref).abs().mean().item()}")
    if dropout_p > 0.0:
        print(f"Attention max diff: {(attn - attn_ref).abs().max().item()}")
        print(f"Attention Pytorch max diff: {(attn_pt - attn_ref).abs().max().item()}")

    g = torch.randn_like(out)
    if check_backward:
        (dqkv,) = torch.autograd.grad(out, qkv, g)
        (dqkv_ref,) = torch.autograd.grad(out_ref, qkv, g)
        (dqkv_pt,) = torch.autograd.grad(out_pt, qkv, g)
        print(f"dQ max diff: {(dqkv[:, :, 0] - dqkv_ref[:, :, 0]).abs().max().item()}")
        print(f"dK max diff: {(dqkv[:, :, 1] - dqkv_ref[:, :, 1]).abs().max().item()}")
        print(f"dV max diff: {(dqkv[:, :, 2] - dqkv_ref[:, :, 2]).abs().max().item()}")
        print(f"dQKV mean diff: {(dqkv - dqkv_ref).abs().mean().item()}")
        print(f"dQ Pytorch max diff: {(dqkv_pt[:, :, 0] - dqkv_ref[:, :, 0]).abs().max().item()}")
        print(f"dK Pytorch max diff: {(dqkv_pt[:, :, 1] - dqkv_ref[:, :, 1]).abs().max().item()}")
        print(f"dV Pytorch max diff: {(dqkv_pt[:, :, 2] - dqkv_ref[:, :, 2]).abs().max().item()}")
        print(f"dQKV Pytorch mean diff: {(dqkv_pt - dqkv_ref).abs().mean().item()}")

    # Check that FlashAttention's numerical error is at most twice the numerical error
    # of a Pytorch implementation.
    assert (out - out_ref).abs().max().item() <= 2 * (out_pt - out_ref).abs().max().item()

    if dropout_p > 0.0:
        assert (attn - attn_ref).abs().max().item() <= 2 * (attn_pt - attn_ref).abs().max().item()
        assert abs(dropout_fraction - dropout_p) <= 0.01

    if check_backward:
        assert (dqkv - dqkv_ref).abs().max().item() <= 2 * (dqkv_pt - dqkv_ref).abs().max().item()
Tri Dao's avatar
Tri Dao committed
589
590


Tri Dao's avatar
Tri Dao committed
591
@pytest.mark.parametrize("dtype", ([torch.float16] if is_sm75 else [torch.float16, torch.bfloat16]))
# @pytest.mark.parametrize('dtype', [torch.float16])
@pytest.mark.parametrize("causal", [False, True])
# @pytest.mark.parametrize('causal', [False])
@pytest.mark.parametrize("d", [32, 40, 59, 64, 80, 96, 111, 128, 160, 192, 224, 256])
# @pytest.mark.parametrize('d', [64])
@pytest.mark.parametrize("seqlen", [97, 128, 200, 256, 257, 384, 512, 768, 1024, 1025, 2048])
# @pytest.mark.parametrize('seqlen', [128])
@pytest.mark.parametrize("dropout_p", [0.0, 0.17])
# @pytest.mark.parametrize('dropout_p', [0.0])
def test_flash_attn_varlen_qkvpacked(seqlen, d, dropout_p, causal, dtype):
    """Compare ``flash_attn_varlen_qkvpacked_func`` with the reference attention.

    A random packed QKV tensor and a random key padding mask are generated,
    unpadded through ``generate_qkv`` and fed to the variable-length kernel.
    The padded output (and, when dropout is enabled, the recovered attention
    probabilities and the observed dropout fraction) is compared with
    ``attention_qkvpacked_ref``.  The error of FlashAttention relative to the
    reference must be at most twice the error of the second, non-upcast
    reference call (labelled "Pytorch" in the printed diagnostics).
    """
    if seqlen >= 2048 and torch.cuda.get_device_properties("cuda").total_memory <= 16 * 2**30:
        pytest.skip()  # Reference implementation OOM
    device = "cuda"
    # set seed
    torch.random.manual_seed(0)
    batch_size = 5
    nheads = 6
    # Gradients are only computed and checked when this gate holds; it is evaluated
    # once here instead of being repeated before the backward pass and the asserts.
    check_bwd = d <= MAX_HEADDIM_SM8x or (is_sm80 or is_sm90)

    qkv = torch.randn(
        batch_size, seqlen, 3, nheads, d, device=device, dtype=dtype, requires_grad=True
    )

    key_padding_mask = generate_random_padding_mask(seqlen, batch_size, device, mode="random")
    # key_padding_mask = generate_random_padding_mask(seqlen, batch_size, device, mode='full')

    # NOTE: ``qkv`` is rebound to the tensor returned by generate_qkv; the reference
    # calls and the reference gradients below use that returned tensor.
    qkv_unpad, cu_seqlens, max_seqlen, qkv, output_pad_fn, dqkv_pad_fn = generate_qkv(
        *qkv.unbind(dim=2), key_padding_mask, key_padding_mask, qkvpacked=True
    )

    out_unpad, sm_lse, S_dmask = flash_attn_varlen_qkvpacked_func(
        qkv_unpad, cu_seqlens, max_seqlen, dropout_p, return_attn_probs=True, causal=causal
    )
    out = output_pad_fn(out_unpad)
    if dropout_p > 0.0:
        S_dmask_converted = convert_flash_attn_S_to_softmax(
            S_dmask,
            seqlen,
            seqlen,
            key_padding_mask,
            key_padding_mask,
            d,
            dropout_p > 0.0,
            causal=causal,
        )
        # The sign of the converted tensor carries the dropout decision (>= 0 is
        # treated as kept below); its magnitude is the unnormalized attention.
        dropout_mask = S_dmask_converted >= 0
        attn_unnorm = S_dmask_converted.abs()
        attn = normalize_flash_attn_S(
            attn_unnorm,
            qkv[:, :, 0],
            qkv[:, :, 1],
            qkv[:, :, 2],
            key_padding_mask,
            key_padding_mask,
            dropout_p > 0.0,
            causal=causal,
        )
        dropout_fraction = get_dropout_fraction(
            dropout_mask, key_padding_mask, key_padding_mask, causal=causal
        ).item()
        print(f"Actual dropout fraction: {dropout_fraction}")
    else:
        dropout_mask = None

    out_ref, attn_ref = attention_qkvpacked_ref(
        qkv, key_padding_mask, dropout_p, dropout_mask, causal=causal
    )
    out_pt, attn_pt = attention_qkvpacked_ref(
        qkv,
        key_padding_mask,
        dropout_p,
        dropout_mask,
        causal=causal,
        upcast=False,
        reorder_ops=True,
    )
    print(f"Output max diff: {(out - out_ref).abs().max().item()}")
    print(f"Output mean diff: {(out - out_ref).abs().mean().item()}")
    print(f"Pytorch max diff: {(out_pt - out_ref).abs().max().item()}")
    print(f"Pytorch mean diff: {(out_pt - out_ref).abs().mean().item()}")
    if dropout_p > 0.0:
        print(f"Attention max diff: {(attn - attn_ref).abs().max().item()}")
        print(f"Attention Pytorch max diff: {(attn_pt - attn_ref).abs().max().item()}")

    g = torch.randn_like(out)
    if check_bwd:
        # FlashAttention gradients come back unpadded and are re-padded before
        # being compared with the gradients of the padded reference input.
        (dqkv_unpad,) = torch.autograd.grad(out, qkv_unpad, g)
        dqkv = dqkv_pad_fn(dqkv_unpad)
        (dqkv_ref,) = torch.autograd.grad(out_ref, qkv, g)
        (dqkv_pt,) = torch.autograd.grad(out_pt, qkv, g)
        print(f"dQ max diff: {(dqkv[:, :, 0] - dqkv_ref[:, :, 0]).abs().max().item()}")
        print(f"dK max diff: {(dqkv[:, :, 1] - dqkv_ref[:, :, 1]).abs().max().item()}")
        print(f"dV max diff: {(dqkv[:, :, 2] - dqkv_ref[:, :, 2]).abs().max().item()}")
        print(f"dQKV mean diff: {(dqkv - dqkv_ref).abs().mean().item()}")
        print(f"dQ Pytorch max diff: {(dqkv_pt[:, :, 0] - dqkv_ref[:, :, 0]).abs().max().item()}")
        print(f"dK Pytorch max diff: {(dqkv_pt[:, :, 1] - dqkv_ref[:, :, 1]).abs().max().item()}")
        print(f"dV Pytorch max diff: {(dqkv_pt[:, :, 2] - dqkv_ref[:, :, 2]).abs().max().item()}")
        print(f"dQKV Pytorch mean diff: {(dqkv_pt - dqkv_ref).abs().mean().item()}")

    # Check that FlashAttention's numerical error is at most twice the numerical error
    # of a Pytorch implementation.
    assert (out - out_ref).abs().max().item() <= 2 * (out_pt - out_ref).abs().max().item()

    if dropout_p > 0.0:
        assert (attn - attn_ref).abs().max().item() <= 2 * (attn_pt - attn_ref).abs().max().item()
        assert abs(dropout_fraction - dropout_p) <= 0.01

    if check_bwd:
        assert (dqkv - dqkv_ref).abs().max().item() <= 2 * (dqkv_pt - dqkv_ref).abs().max().item()
Tri Dao's avatar
Tri Dao committed
692
693


Tri Dao's avatar
Tri Dao committed
694
@pytest.mark.parametrize("kvpacked", [True, False])
# @pytest.mark.parametrize("kvpacked", [False])
@pytest.mark.parametrize("dtype", ([torch.float16] if is_sm75 else [torch.float16, torch.bfloat16]))
# @pytest.mark.parametrize("dtype", [torch.bfloat16])
@pytest.mark.parametrize("mha_type", ["mha", "mqa", "gqa"])
# @pytest.mark.parametrize("mha_type", ["mha"])
@pytest.mark.parametrize("causal", [False, True])
# @pytest.mark.parametrize("causal", [True])
@pytest.mark.parametrize("d", [32, 40, 59, 64, 80, 96, 111, 128, 160, 192, 224, 256])
# @pytest.mark.parametrize('d', [32, 64, 96, 128, 160, 192, 224, 256])
# @pytest.mark.parametrize('d', [32, 40, 64, 80, 96, 128, 160, 192])
# @pytest.mark.parametrize('d', [32, 64, 96, 128, 160, 192])
# @pytest.mark.parametrize('d', [56, 80])
# @pytest.mark.parametrize("d", [64])
@pytest.mark.parametrize(
    "seqlen_q,seqlen_k",
    [
        (113, 203),
        (128, 217),
        (113, 211),
        (108, 256),
        (256, 512),
        (512, 256),
        (1024, 1024),
        (1023, 1024),
        (1024, 1023),
        (2048, 2048),
    ],
)
# @pytest.mark.parametrize('seqlen_q,seqlen_k', [(256, 128)])
@pytest.mark.parametrize("dropout_p", [0.0, 0.17])
# @pytest.mark.parametrize("dropout_p", [0.17])
def test_flash_attn_output(seqlen_q, seqlen_k, d, dropout_p, causal, mha_type, dtype, kvpacked):
    """Compare ``flash_attn_func`` / ``flash_attn_kvpacked_func`` with the reference.

    Runs on fixed-length (unpadded) batches.  ``mha_type`` selects the number
    of K/V heads used against 9 query heads: "mha" -> 9, "mqa" -> 1,
    "gqa" -> 3.  ``kvpacked`` selects between a single packed KV tensor and
    separate K and V tensors.  Output, attention probabilities (dropout only),
    the observed dropout fraction and the input gradients are checked; the
    FlashAttention error relative to the reference must be at most twice the
    error of the second, non-upcast reference call (labelled "Pytorch" in the
    printed diagnostics).
    """
    if (
        max(seqlen_q, seqlen_k) >= 2048
        and torch.cuda.get_device_properties("cuda").total_memory <= 16 * 2**30
    ):
        pytest.skip()  # Reference implementation OOM
    device = "cuda"
    # set seed
    torch.random.manual_seed(0)
    batch_size = 16
    nheads = 9
    nheads_k = nheads if mha_type == "mha" else (1 if mha_type == "mqa" else 3)
    assert nheads % nheads_k == 0
    # Gradients are only computed and checked when this gate holds; it is evaluated
    # once here instead of being repeated before the backward pass and the asserts.
    check_bwd = d <= MAX_HEADDIM_SM8x or (is_sm80 or is_sm90)

    q = torch.randn(batch_size, seqlen_q, nheads, d, device=device, dtype=dtype, requires_grad=True)
    if kvpacked:
        kv = torch.randn(
            batch_size, seqlen_k, 2, nheads_k, d, device=device, dtype=dtype, requires_grad=True
        )
    else:
        k = torch.randn(
            batch_size, seqlen_k, nheads_k, d, device=device, dtype=dtype, requires_grad=True
        )
        v = torch.randn(
            batch_size, seqlen_k, nheads_k, d, device=device, dtype=dtype, requires_grad=True
        )

    if kvpacked:
        out, lse, S_dmask = flash_attn_kvpacked_func(
            q, kv, dropout_p, return_attn_probs=True, causal=causal
        )
    else:
        out, lse, S_dmask = flash_attn_func(
            q, k, v, dropout_p, return_attn_probs=True, causal=causal
        )
    if dropout_p > 0.0:
        S_dmask_converted = convert_flash_attn_S_to_softmax(
            S_dmask, seqlen_q, seqlen_k, None, None, d, dropout_p > 0.0, causal=causal
        )
        # The sign of the converted tensor carries the dropout decision (>= 0 is
        # treated as kept below); its magnitude is the unnormalized attention.
        dropout_mask = S_dmask_converted >= 0
        attn_unnorm = S_dmask_converted.abs()
        # Repeat each K/V head nheads // nheads_k times so the head dimension
        # matches q before normalizing the attention scores.
        if kvpacked:
            kv_rep = repeat(kv, "b s two h d -> b s two (h g) d", g=nheads // nheads_k)
            k_rep, v_rep = kv_rep.unbind(dim=2)
        else:
            k_rep = repeat(k, "b s h d -> b s (h g) d", g=nheads // nheads_k)
            v_rep = repeat(v, "b s h d -> b s (h g) d", g=nheads // nheads_k)
        attn = normalize_flash_attn_S(
            attn_unnorm, q, k_rep, v_rep, None, None, dropout_p > 0.0, causal=causal
        )
        dropout_fraction = get_dropout_fraction(dropout_mask, None, None, causal=causal).item()
        print(f"Actual dropout fraction: {dropout_fraction}")
    else:
        dropout_mask = None

    if kvpacked:
        out_ref, attn_ref = attention_kvpacked_ref(
            q, kv, None, None, dropout_p, dropout_mask, causal=causal
        )
        out_pt, attn_pt = attention_kvpacked_ref(
            q,
            kv,
            None,
            None,
            dropout_p,
            dropout_mask,
            causal=causal,
            upcast=False,
            reorder_ops=True,
        )
    else:
        out_ref, attn_ref = attention_ref(
            q, k, v, None, None, dropout_p, dropout_mask, causal=causal
        )
        out_pt, attn_pt = attention_ref(
            q,
            k,
            v,
            None,
            None,
            dropout_p,
            dropout_mask,
            causal=causal,
            upcast=False,
            reorder_ops=True,
        )

    print(f"Output max diff: {(out - out_ref).abs().max().item()}")
    print(f"Output mean diff: {(out - out_ref).abs().mean().item()}")
    print(f"Pytorch max diff: {(out_pt - out_ref).abs().max().item()}")
    print(f"Pytorch mean diff: {(out_pt - out_ref).abs().mean().item()}")
    if dropout_p > 0.0:
        print(f"Attention max diff: {(attn - attn_ref).abs().max().item()}")
        print(f"Attention Pytorch max diff: {(attn_pt - attn_ref).abs().max().item()}")

    # The same upstream gradient is fed to all three backward passes.
    g = torch.randn_like(out)
    if check_bwd:
        if kvpacked:
            dq, dkv = torch.autograd.grad(out, (q, kv), g)
            dk, dv = dkv.unbind(2)
            dq_ref, dkv_ref = torch.autograd.grad(out_ref, (q, kv), g)
            dk_ref, dv_ref = dkv_ref.unbind(2)
            dq_pt, dkv_pt = torch.autograd.grad(out_pt, (q, kv), g)
            dk_pt, dv_pt = dkv_pt.unbind(2)
        else:
            dq, dk, dv = torch.autograd.grad(out, (q, k, v), g)
            dq_ref, dk_ref, dv_ref = torch.autograd.grad(out_ref, (q, k, v), g)
            dq_pt, dk_pt, dv_pt = torch.autograd.grad(out_pt, (q, k, v), g)
        print(f"dQ max diff: {(dq - dq_ref).abs().max().item()}")
        print(f"dK max diff: {(dk - dk_ref).abs().max().item()}")
        print(f"dV max diff: {(dv - dv_ref).abs().max().item()}")
        print(f"dQ mean diff: {(dq - dq_ref).abs().mean().item()}")
        print(f"dK mean diff: {(dk - dk_ref).abs().mean().item()}")
        print(f"dV mean diff: {(dv - dv_ref).abs().mean().item()}")
        print(f"dQ Pytorch max diff: {(dq_pt - dq_ref).abs().max().item()}")
        print(f"dK Pytorch max diff: {(dk_pt - dk_ref).abs().max().item()}")
        print(f"dV Pytorch max diff: {(dv_pt - dv_ref).abs().max().item()}")
        print(f"dQ Pytorch mean diff: {(dq_pt - dq_ref).abs().mean().item()}")
        print(f"dK Pytorch mean diff: {(dk_pt - dk_ref).abs().mean().item()}")
        print(f"dV Pytorch mean diff: {(dv_pt - dv_ref).abs().mean().item()}")

    # Check that FlashAttention's numerical error is at most twice the numerical error
    # of a Pytorch implementation.
    assert (out - out_ref).abs().max().item() <= 2 * (out_pt - out_ref).abs().max().item()

    if dropout_p > 0.0:
        assert (attn - attn_ref).abs().max().item() <= 2 * (attn_pt - attn_ref).abs().max().item()
        assert abs(dropout_fraction - dropout_p) <= 0.01

    if check_bwd:
        assert (dq - dq_ref).abs().max().item() <= 2 * (dq_pt - dq_ref).abs().max().item()
        assert (dk - dk_ref).abs().max().item() <= 2 * (dk_pt - dk_ref).abs().max().item()
        assert (dv - dv_ref).abs().max().item() <= 2 * (dv_pt - dv_ref).abs().max().item()


Tri Dao's avatar
Tri Dao committed
882
@pytest.mark.parametrize("kvpacked", [True, False])
# @pytest.mark.parametrize('kvpacked', [False])
@pytest.mark.parametrize("dtype", ([torch.float16] if is_sm75 else [torch.float16, torch.bfloat16]))
# @pytest.mark.parametrize('dtype', [torch.float16])
@pytest.mark.parametrize("mha_type", ["mha", "mqa", "gqa"])
# @pytest.mark.parametrize('mha_type', ["mqa"])
@pytest.mark.parametrize("causal", [False, True])
# @pytest.mark.parametrize('causal', [True])
@pytest.mark.parametrize("d", [32, 40, 59, 64, 80, 96, 111, 128, 160, 192, 224, 256])
# @pytest.mark.parametrize('d', [32, 64, 96, 128, 160, 192, 224, 256])
# @pytest.mark.parametrize('d', [64])
@pytest.mark.parametrize(
    "seqlen_q,seqlen_k",
    [
        (113, 203),
        (128, 217),
        (113, 211),
        (108, 256),
        (256, 512),
        (512, 256),
        (1024, 1024),
        (1023, 1024),
        (1024, 1023),
        (2048, 2048),
    ],
)
# @pytest.mark.parametrize('seqlen_q,seqlen_k', [(128, 128)])
@pytest.mark.parametrize("dropout_p", [0.0, 0.17])
# @pytest.mark.parametrize('dropout_p', [0.0])
def test_flash_attn_varlen_output(
    seqlen_q, seqlen_k, d, dropout_p, causal, mha_type, dtype, kvpacked
):
    """Compare the variable-length FlashAttention entry points with the reference.

    Random query and key padding masks are generated, the inputs are unpadded
    through ``generate_qkv`` and passed to ``flash_attn_varlen_kvpacked_func``
    (``kvpacked=True``) or ``flash_attn_varlen_func`` (``kvpacked=False``).
    ``mha_type`` selects the number of K/V heads used against 9 query heads:
    "mha" -> 9, "mqa" -> 1, "gqa" -> 3.  Output, attention probabilities
    (dropout only), the observed dropout fraction and the re-padded input
    gradients are checked; the FlashAttention error relative to the reference
    must be at most twice the error of the second, non-upcast reference call
    (labelled "Pytorch" in the printed diagnostics).
    """
    if (
        max(seqlen_q, seqlen_k) >= 2048
        and torch.cuda.get_device_properties("cuda").total_memory <= 16 * 2**30
    ):
        pytest.skip()  # Reference implementation OOM
    device = "cuda"
    # set seed
    torch.random.manual_seed(0)
    batch_size = 16
    nheads = 9
    nheads_k = nheads if mha_type == "mha" else (1 if mha_type == "mqa" else 3)
    assert nheads % nheads_k == 0
    # Gradients are only computed and checked when this gate holds; it is evaluated
    # once here instead of being repeated before the backward pass and the asserts.
    check_bwd = d <= MAX_HEADDIM_SM8x or (is_sm80 or is_sm90)

    q = torch.randn(batch_size, seqlen_q, nheads, d, device=device, dtype=dtype, requires_grad=True)
    if kvpacked:
        kv = torch.randn(
            batch_size, seqlen_k, 2, nheads_k, d, device=device, dtype=dtype, requires_grad=True
        )
    else:
        k = torch.randn(
            batch_size, seqlen_k, nheads_k, d, device=device, dtype=dtype, requires_grad=True
        )
        v = torch.randn(
            batch_size, seqlen_k, nheads_k, d, device=device, dtype=dtype, requires_grad=True
        )

    query_padding_mask = generate_random_padding_mask(seqlen_q, batch_size, device, mode="random")
    key_padding_mask = generate_random_padding_mask(seqlen_k, batch_size, device, mode="random")
    # key_padding_mask = generate_random_padding_mask(seqlen_k, batch_size, device, mode='full')

    # NOTE: q / kv / k / v are rebound to the tensors returned by generate_qkv; the
    # reference calls and the reference gradients below use those returned tensors.
    if kvpacked:
        (
            q_unpad,
            kv_unpad,
            cu_seqlens_q,
            cu_seqlens_k,
            max_seqlen_q,
            max_seqlen_k,
            q,
            kv,
            output_pad_fn,
            dq_pad_fn,
            dkv_pad_fn,
        ) = generate_qkv(q, *kv.unbind(dim=2), query_padding_mask, key_padding_mask, kvpacked=True)
        out_unpad, sm_lse, S_dmask = flash_attn_varlen_kvpacked_func(
            q_unpad,
            kv_unpad,
            cu_seqlens_q,
            cu_seqlens_k,
            max_seqlen_q,
            max_seqlen_k,
            dropout_p,
            return_attn_probs=True,
            causal=causal,
        )
    else:
        (
            q_unpad,
            k_unpad,
            v_unpad,
            cu_seqlens_q,
            cu_seqlens_k,
            max_seqlen_q,
            max_seqlen_k,
            q,
            k,
            v,
            output_pad_fn,
            dq_pad_fn,
            dk_pad_fn,
        ) = generate_qkv(q, k, v, query_padding_mask, key_padding_mask, kvpacked=False)
        out_unpad, sm_lse, S_dmask = flash_attn_varlen_func(
            q_unpad,
            k_unpad,
            v_unpad,
            cu_seqlens_q,
            cu_seqlens_k,
            max_seqlen_q,
            max_seqlen_k,
            dropout_p,
            return_attn_probs=True,
            causal=causal,
        )
    out = output_pad_fn(out_unpad)
    if dropout_p > 0.0:
        S_dmask_converted = convert_flash_attn_S_to_softmax(
            S_dmask,
            seqlen_q,
            seqlen_k,
            query_padding_mask,
            key_padding_mask,
            d,
            dropout_p > 0.0,
            causal=causal,
        )
        # The sign of the converted tensor carries the dropout decision (>= 0 is
        # treated as kept below); its magnitude is the unnormalized attention.
        dropout_mask = S_dmask_converted >= 0
        attn_unnorm = S_dmask_converted.abs()
        # Repeat each K/V head nheads // nheads_k times so the head dimension
        # matches q before normalizing the attention scores.
        if kvpacked:
            kv_rep = repeat(kv, "b s two h d -> b s two (h g) d", g=nheads // nheads_k)
            k_rep, v_rep = kv_rep.unbind(dim=2)
        else:
            k_rep = repeat(k, "b s h d -> b s (h g) d", g=nheads // nheads_k)
            v_rep = repeat(v, "b s h d -> b s (h g) d", g=nheads // nheads_k)
        attn = normalize_flash_attn_S(
            attn_unnorm,
            q,
            k_rep,
            v_rep,
            query_padding_mask,
            key_padding_mask,
            dropout_p > 0.0,
            causal=causal,
        )
        dropout_fraction = get_dropout_fraction(
            dropout_mask, query_padding_mask, key_padding_mask, causal=causal
        ).item()
        print(f"Actual dropout fraction: {dropout_fraction}")
    else:
        dropout_mask = None

    if kvpacked:
        out_ref, attn_ref = attention_kvpacked_ref(
            q, kv, query_padding_mask, key_padding_mask, dropout_p, dropout_mask, causal=causal
        )
        out_pt, attn_pt = attention_kvpacked_ref(
            q,
            kv,
            query_padding_mask,
            key_padding_mask,
            dropout_p,
            dropout_mask,
            causal=causal,
            upcast=False,
            reorder_ops=True,
        )
    else:
        out_ref, attn_ref = attention_ref(
            q, k, v, query_padding_mask, key_padding_mask, dropout_p, dropout_mask, causal=causal
        )
        out_pt, attn_pt = attention_ref(
            q,
            k,
            v,
            query_padding_mask,
            key_padding_mask,
            dropout_p,
            dropout_mask,
            causal=causal,
            upcast=False,
            reorder_ops=True,
        )

    print(f"Output max diff: {(out - out_ref).abs().max().item()}")
    print(f"Output mean diff: {(out - out_ref).abs().mean().item()}")
    print(f"Pytorch max diff: {(out_pt - out_ref).abs().max().item()}")
    print(f"Pytorch mean diff: {(out_pt - out_ref).abs().mean().item()}")
    if dropout_p > 0.0:
        print(f"Attention max diff: {(attn - attn_ref).abs().max().item()}")
        print(f"Attention Pytorch max diff: {(attn_pt - attn_ref).abs().max().item()}")

    # The same upstream gradient is fed to all three backward passes.
    g = torch.randn_like(out)
    if check_bwd:
        # FlashAttention gradients come back unpadded and are re-padded before
        # being compared with the gradients of the padded reference inputs.
        if kvpacked:
            dq_unpad, dkv_unpad = torch.autograd.grad(out, (q_unpad, kv_unpad), g)
            dk, dv = dkv_pad_fn(dkv_unpad).unbind(2)
            dq_ref, dkv_ref = torch.autograd.grad(out_ref, (q, kv), g)
            dk_ref, dv_ref = dkv_ref.unbind(2)
            dq_pt, dkv_pt = torch.autograd.grad(out_pt, (q, kv), g)
            dk_pt, dv_pt = dkv_pt.unbind(2)
        else:
            dq_unpad, dk_unpad, dv_unpad = torch.autograd.grad(
                out, (q_unpad, k_unpad, v_unpad), g
            )
            dk = dk_pad_fn(dk_unpad)
            # generate_qkv returns a single K-side pad function in this mode, so it
            # is reused for dV (presumably valid because K and V share
            # key_padding_mask -- confirm against generate_qkv).
            dv = dk_pad_fn(dv_unpad)
            dq_ref, dk_ref, dv_ref = torch.autograd.grad(out_ref, (q, k, v), g)
            dq_pt, dk_pt, dv_pt = torch.autograd.grad(out_pt, (q, k, v), g)
        dq = dq_pad_fn(dq_unpad)
        print(f"dQ max diff: {(dq - dq_ref).abs().max().item()}")
        print(f"dK max diff: {(dk - dk_ref).abs().max().item()}")
        print(f"dV max diff: {(dv - dv_ref).abs().max().item()}")
        print(f"dQ mean diff: {(dq - dq_ref).abs().mean().item()}")
        print(f"dK mean diff: {(dk - dk_ref).abs().mean().item()}")
        print(f"dV mean diff: {(dv - dv_ref).abs().mean().item()}")
        print(f"dQ Pytorch max diff: {(dq_pt - dq_ref).abs().max().item()}")
        print(f"dK Pytorch max diff: {(dk_pt - dk_ref).abs().max().item()}")
        print(f"dV Pytorch max diff: {(dv_pt - dv_ref).abs().max().item()}")
        print(f"dQ Pytorch mean diff: {(dq_pt - dq_ref).abs().mean().item()}")
        print(f"dK Pytorch mean diff: {(dk_pt - dk_ref).abs().mean().item()}")
        print(f"dV Pytorch mean diff: {(dv_pt - dv_ref).abs().mean().item()}")

    # Check that FlashAttention's numerical error is at most twice the numerical error
    # of a Pytorch implementation.
    assert (out - out_ref).abs().max().item() <= 2 * (out_pt - out_ref).abs().max().item()

    if dropout_p > 0.0:
        assert (attn - attn_ref).abs().max().item() <= 2 * (attn_pt - attn_ref).abs().max().item()
        assert abs(dropout_fraction - dropout_p) <= 0.01

    if check_bwd:
        assert (dq - dq_ref).abs().max().item() <= 2 * (dq_pt - dq_ref).abs().max().item()
        assert (dk - dk_ref).abs().max().item() <= 2 * (dk_pt - dk_ref).abs().max().item()
        assert (dv - dv_ref).abs().max().item() <= 2 * (dv_pt - dv_ref).abs().max().item()
Tri Dao's avatar
Tri Dao committed
1128

1129

Tri Dao's avatar
Tri Dao committed
1130
@pytest.mark.parametrize("dtype", ([torch.float16] if is_sm75 else [torch.float16, torch.bfloat16]))
1131
1132
1133
1134
1135
1136
1137
1138
1139
1140
1141
1142
1143
1144
1145
1146
1147
1148
1149
1150
1151
1152
1153
1154
1155
1156
1157
1158
1159
1160
1161
1162
1163
1164
1165
1166
1167
1168
1169
1170
1171
1172
1173
1174
1175
1176
1177
1178
1179
1180
1181
1182
1183
1184
1185
1186
1187
1188
1189
1190
1191
1192
1193
1194
1195
1196
1197
1198
1199
1200
1201
1202
1203
1204
1205
1206
1207
1208
1209
1210
1211
1212
1213
1214
1215
1216
1217
1218
1219
1220
1221
1222
1223
1224
1225
1226
1227
1228
1229
1230
1231
1232
1233
1234
1235
1236
1237
1238
1239
1240
1241
1242
1243
1244
1245
1246
1247
1248
1249
1250
1251
1252
1253
1254
1255
1256
1257
1258
1259
1260
1261
1262
1263
1264
1265
1266
1267
1268
1269
1270
1271
1272
1273
1274
1275
1276
1277
1278
1279
1280
1281
1282
1283
1284
1285
1286
1287
1288
1289
1290
1291
1292
1293
1294
1295
1296
1297
1298
1299
1300
1301
1302
1303
1304
1305
1306
1307
1308
1309
1310
1311
1312
1313
1314
1315
1316
1317
1318
1319
1320
1321
1322
1323
1324
1325
1326
1327
1328
1329
1330
1331
1332
1333
1334
1335
1336
1337
1338
1339
1340
1341
1342
1343
1344
1345
1346
1347
1348
1349
1350
1351
1352
1353
1354
1355
1356
1357
1358
1359
1360
1361
1362
1363
1364
1365
1366
1367
1368
1369
1370
# @pytest.mark.parametrize("dtype", [torch.bfloat16])
@pytest.mark.parametrize("d", [32, 40, 59, 64, 80, 96, 111, 128, 160, 192, 224, 256])
# @pytest.mark.parametrize('d', [32, 64, 96, 128, 160, 192, 224, 256])
# @pytest.mark.parametrize('d', [32, 40, 64, 80, 96, 128, 160, 192])
# @pytest.mark.parametrize('d', [32, 64, 96, 128, 160, 192])
# @pytest.mark.parametrize('d', [56, 80])
# @pytest.mark.parametrize("d", [64, 128])
@pytest.mark.parametrize("swap_sq_sk", [False, True])
# @pytest.mark.parametrize("swap_sq_sk", [True])
@pytest.mark.parametrize(
    "seqlen_q,seqlen_k",
    [
        (1, 239),
        (3, 799),
        (127, 512),
        (127, 513),
        (113, 203),
        (128, 217),
        (113, 211),
        (108, 256),
        (256, 512),
        (1023, 1024),
    ],
)
# @pytest.mark.parametrize('seqlen_q,seqlen_k', [(256, 128)])
def test_flash_attn_causal(seqlen_q, seqlen_k, swap_sq_sk, d, dtype):
    """Compare causal ``flash_attn_func`` against ``attention_ref``.

    The forward output, and dQ/dK/dV when the backward pass is exercised, must be within
    twice the error (plus 1e-5) that the ``attention_ref(..., upcast=False,
    reorder_ops=True)`` baseline shows against the default ``attention_ref`` result.

    Args:
        seqlen_q: query sequence length (before the optional swap).
        seqlen_k: key/value sequence length (before the optional swap).
        swap_sq_sk: if True, exchange ``seqlen_q`` and ``seqlen_k``.
        d: head dimension.
        dtype: torch dtype used for q, k and v.
    """
    if (
        max(seqlen_q, seqlen_k) >= 2048
        and torch.cuda.get_device_properties("cuda").total_memory <= 16 * 2**30
    ):
        pytest.skip()  # Reference implementation OOM
    if swap_sq_sk:
        seqlen_q, seqlen_k = seqlen_k, seqlen_q
    device = "cuda"
    causal = True
    # Gradients are only compared for head dims up to MAX_HEADDIM_SM8x, or on sm80 / sm90.
    check_bwd = d <= MAX_HEADDIM_SM8x or (is_sm80 or is_sm90)
    # set seed
    torch.random.manual_seed(0)
    batch_size = 16
    nheads = 9
    q = torch.randn(batch_size, seqlen_q, nheads, d, device=device, dtype=dtype, requires_grad=True)
    k = torch.randn(batch_size, seqlen_k, nheads, d, device=device, dtype=dtype, requires_grad=True)
    v = torch.randn(batch_size, seqlen_k, nheads, d, device=device, dtype=dtype, requires_grad=True)
    out = flash_attn_func(q, k, v, 0.0, causal=causal)
    # The attention matrices returned by attention_ref are not needed here.
    out_ref, _ = attention_ref(q, k, v, None, None, 0.0, None, causal=causal)
    out_pt, _ = attention_ref(
        q,
        k,
        v,
        None,
        None,
        0.0,
        None,
        causal=causal,
        upcast=False,
        reorder_ops=True,
    )

    print(f"Output max diff: {(out - out_ref).abs().max().item()}")
    print(f"Output mean diff: {(out - out_ref).abs().mean().item()}")
    print(f"Pytorch max diff: {(out_pt - out_ref).abs().max().item()}")
    print(f"Pytorch mean diff: {(out_pt - out_ref).abs().mean().item()}")

    g = torch.randn_like(out)
    if check_bwd:
        dq, dk, dv = torch.autograd.grad(out, (q, k, v), g)
        dq_ref, dk_ref, dv_ref = torch.autograd.grad(out_ref, (q, k, v), g)
        dq_pt, dk_pt, dv_pt = torch.autograd.grad(out_pt, (q, k, v), g)
        grads = (
            ("dQ", dq, dq_ref, dq_pt),
            ("dK", dk, dk_ref, dk_pt),
            ("dV", dv, dv_ref, dv_pt),
        )
        for name, grad, ref, _ in grads:
            print(f"{name} max diff: {(grad - ref).abs().max().item()}")
        for name, grad, ref, _ in grads:
            print(f"{name} mean diff: {(grad - ref).abs().mean().item()}")
        for name, _, ref, pt in grads:
            print(f"{name} Pytorch max diff: {(pt - ref).abs().max().item()}")
        for name, _, ref, pt in grads:
            print(f"{name} Pytorch mean diff: {(pt - ref).abs().mean().item()}")

    # Check that FlashAttention's numerical error is at most twice the numerical error
    # of a Pytorch implementation.
    assert (out - out_ref).abs().max().item() <= 2 * (out_pt - out_ref).abs().max().item() + 1e-5

    if check_bwd:
        assert (dq - dq_ref).abs().max().item() <= 2 * (dq_pt - dq_ref).abs().max().item() + 1e-5
        assert (dk - dk_ref).abs().max().item() <= 2 * (dk_pt - dk_ref).abs().max().item() + 1e-5
        assert (dv - dv_ref).abs().max().item() <= 2 * (dv_pt - dv_ref).abs().max().item() + 1e-5


@pytest.mark.parametrize("dtype", ([torch.float16] if is_sm75 else [torch.float16, torch.bfloat16]))
# @pytest.mark.parametrize("dtype", [torch.bfloat16])
@pytest.mark.parametrize("d", [32, 40, 59, 64, 80, 96, 111, 128, 160, 192, 224, 256])
# @pytest.mark.parametrize('d', [32, 64, 96, 128, 160, 192, 224, 256])
# @pytest.mark.parametrize('d', [32, 40, 64, 80, 96, 128, 160, 192])
# @pytest.mark.parametrize('d', [32, 64, 96, 128, 160, 192])
# @pytest.mark.parametrize('d', [56, 80])
# @pytest.mark.parametrize("d", [128])
@pytest.mark.parametrize("swap_sq_sk", [False, True])
# @pytest.mark.parametrize("swap_sq_sk", [True])
@pytest.mark.parametrize(
    "seqlen_q,seqlen_k",
    [
        (1, 239),
        (3, 799),
        (127, 512),
        (127, 513),
        (113, 203),
        (128, 217),
        (113, 211),
        (108, 256),
        (256, 512),
        (1023, 1024),
    ],
)
# @pytest.mark.parametrize("seqlen_q,seqlen_k", [(256, 128)])
def test_flash_attn_varlen_causal(seqlen_q, seqlen_k, swap_sq_sk, d, dtype):
    """Compare causal ``flash_attn_varlen_func`` against ``attention_ref`` on padded batches.

    Random padding masks are drawn for queries and keys, the inputs are unpadded with
    ``generate_qkv``, and the re-padded output (and dQ/dK/dV when the backward pass is
    exercised) must be within twice the error (plus 1e-5) of the
    ``attention_ref(..., upcast=False, reorder_ops=True)`` baseline.

    Args:
        seqlen_q: maximum query sequence length (before the optional swap).
        seqlen_k: maximum key/value sequence length (before the optional swap).
        swap_sq_sk: if True, exchange ``seqlen_q`` and ``seqlen_k``.
        d: head dimension.
        dtype: torch dtype used for q, k and v.
    """
    if (
        max(seqlen_q, seqlen_k) >= 2048
        and torch.cuda.get_device_properties("cuda").total_memory <= 16 * 2**30
    ):
        pytest.skip()  # Reference implementation OOM
    if swap_sq_sk:
        seqlen_q, seqlen_k = seqlen_k, seqlen_q
    device = "cuda"
    causal = True
    # Gradients are only compared for head dims up to MAX_HEADDIM_SM8x, or on sm80 / sm90.
    check_bwd = d <= MAX_HEADDIM_SM8x or (is_sm80 or is_sm90)
    # set seed
    torch.random.manual_seed(0)
    batch_size = 16
    nheads = 9
    q = torch.randn(batch_size, seqlen_q, nheads, d, device=device, dtype=dtype, requires_grad=True)
    k = torch.randn(batch_size, seqlen_k, nheads, d, device=device, dtype=dtype, requires_grad=True)
    v = torch.randn(batch_size, seqlen_k, nheads, d, device=device, dtype=dtype, requires_grad=True)
    query_padding_mask = generate_random_padding_mask(seqlen_q, batch_size, device, mode="random")
    key_padding_mask = generate_random_padding_mask(seqlen_k, batch_size, device, mode="random")
    (
        q_unpad,
        k_unpad,
        v_unpad,
        cu_seqlens_q,
        cu_seqlens_k,
        max_seqlen_q,
        max_seqlen_k,
        q,
        k,
        v,
        output_pad_fn,
        dq_pad_fn,
        dk_pad_fn,
    ) = generate_qkv(q, k, v, query_padding_mask, key_padding_mask, kvpacked=False)
    out_unpad = flash_attn_varlen_func(
        q_unpad,
        k_unpad,
        v_unpad,
        cu_seqlens_q,
        cu_seqlens_k,
        max_seqlen_q,
        max_seqlen_k,
        0.0,
        causal=causal,
    )
    out = output_pad_fn(out_unpad)
    # The attention matrices returned by attention_ref are not needed here.
    out_ref, _ = attention_ref(
        q, k, v, query_padding_mask, key_padding_mask, 0.0, None, causal=causal
    )
    out_pt, _ = attention_ref(
        q,
        k,
        v,
        query_padding_mask,
        key_padding_mask,
        0.0,
        None,
        causal=causal,
        upcast=False,
        reorder_ops=True,
    )

    print(f"Output max diff: {(out - out_ref).abs().max().item()}")
    print(f"Output mean diff: {(out - out_ref).abs().mean().item()}")
    print(f"Pytorch max diff: {(out_pt - out_ref).abs().max().item()}")
    print(f"Pytorch mean diff: {(out_pt - out_ref).abs().mean().item()}")

    g = torch.randn_like(out)
    if check_bwd:
        dq_unpad, dk_unpad, dv_unpad = torch.autograd.grad(out, (q_unpad, k_unpad, v_unpad), g)
        dq = dq_pad_fn(dq_unpad)
        dk = dk_pad_fn(dk_unpad)
        # generate_qkv returns no separate pad function for dV; dk_pad_fn is reused because
        # K and V were unpadded with the same key_padding_mask.
        dv = dk_pad_fn(dv_unpad)
        dq_ref, dk_ref, dv_ref = torch.autograd.grad(out_ref, (q, k, v), g)
        dq_pt, dk_pt, dv_pt = torch.autograd.grad(out_pt, (q, k, v), g)
        grads = (
            ("dQ", dq, dq_ref, dq_pt),
            ("dK", dk, dk_ref, dk_pt),
            ("dV", dv, dv_ref, dv_pt),
        )
        for name, grad, ref, _ in grads:
            print(f"{name} max diff: {(grad - ref).abs().max().item()}")
        for name, grad, ref, _ in grads:
            print(f"{name} mean diff: {(grad - ref).abs().mean().item()}")
        for name, _, ref, pt in grads:
            print(f"{name} Pytorch max diff: {(pt - ref).abs().max().item()}")
        for name, _, ref, pt in grads:
            print(f"{name} Pytorch mean diff: {(pt - ref).abs().mean().item()}")

    # Check that FlashAttention's numerical error is at most twice the numerical error
    # of a Pytorch implementation.
    assert (out - out_ref).abs().max().item() <= 2 * (out_pt - out_ref).abs().max().item() + 1e-5

    if check_bwd:
        assert (dq - dq_ref).abs().max().item() <= 2 * (dq_pt - dq_ref).abs().max().item() + 1e-5
        assert (dk - dk_ref).abs().max().item() <= 2 * (dk_pt - dk_ref).abs().max().item() + 1e-5
        assert (dv - dv_ref).abs().max().item() <= 2 * (dv_pt - dv_ref).abs().max().item() + 1e-5


@pytest.mark.parametrize("dtype", ([torch.float16] if is_sm75 else [torch.float16, torch.bfloat16]))
# @pytest.mark.parametrize("dtype", [torch.float16])
@pytest.mark.parametrize("causal", [False, True])
# @pytest.mark.parametrize("causal", [True])
@pytest.mark.parametrize("d", [32, 40, 59, 64, 80, 96, 111, 128, 160, 192, 224, 256])
# @pytest.mark.parametrize('d', [32, 64, 96, 128, 160, 192, 224, 256])
# @pytest.mark.parametrize('d', [32, 40, 64, 80, 96, 128, 160, 192])
# @pytest.mark.parametrize('d', [32, 64, 96, 128, 160, 192])
# @pytest.mark.parametrize('d', [56, 80])
# @pytest.mark.parametrize("d", [64])
@pytest.mark.parametrize("swap_sq_sk", [False, True])
# @pytest.mark.parametrize("swap_sq_sk", [False])
@pytest.mark.parametrize(
    "seqlen_q,seqlen_k",
    [
        (3, 1024),
        (1, 339),
        (64, 800),
        (3, 799),
        (64, 2048),
        (16, 20000),
        (16, 100000),
        (128, 128),
        (256, 256),
    ],
)
# @pytest.mark.parametrize('seqlen_q,seqlen_k', [(256, 128)])
def test_flash_attn_splitkv(seqlen_q, seqlen_k, swap_sq_sk, d, causal, dtype):
    """Compare ``flash_attn_func`` against ``attention_ref`` with batch size 1.

    The shapes include very long key sequences paired with short query sequences (and
    the swapped case). The forward error must be within twice the
    ``attention_ref(..., upcast=False, reorder_ops=True)`` baseline error plus 1e-5; the
    gradient errors within twice the baseline plus 2e-4.

    Args:
        seqlen_q: query sequence length (before the optional swap).
        seqlen_k: key/value sequence length (before the optional swap).
        swap_sq_sk: if True, exchange ``seqlen_q`` and ``seqlen_k``.
        d: head dimension.
        causal: whether to apply causal masking.
        dtype: torch dtype used for q, k and v.
    """
    if swap_sq_sk:
        seqlen_q, seqlen_k = seqlen_k, seqlen_q
    device = "cuda"
    # Gradients are only compared for head dims up to MAX_HEADDIM_SM8x, or on sm80 / sm90.
    check_bwd = d <= MAX_HEADDIM_SM8x or (is_sm80 or is_sm90)
    # set seed
    torch.random.manual_seed(0)
    batch_size = 1
    nheads = 12
    q = torch.randn(batch_size, seqlen_q, nheads, d, device=device, dtype=dtype, requires_grad=True)
    k = torch.randn(batch_size, seqlen_k, nheads, d, device=device, dtype=dtype, requires_grad=True)
    v = torch.randn(batch_size, seqlen_k, nheads, d, device=device, dtype=dtype, requires_grad=True)
    # return_attn_probs=True is kept so the tuple-returning call path is exercised; only the
    # output is checked below.
    out, _, _ = flash_attn_func(q, k, v, 0.0, causal=causal, return_attn_probs=True)
    out_ref, _ = attention_ref(q, k, v, None, None, 0.0, None, causal=causal)
    out_pt, _ = attention_ref(
        q,
        k,
        v,
        None,
        None,
        0.0,
        None,
        causal=causal,
        upcast=False,
        reorder_ops=True,
    )

    print(f"Output max diff: {(out - out_ref).abs().max().item()}")
    print(f"Output mean diff: {(out - out_ref).abs().mean().item()}")
    print(f"Pytorch max diff: {(out_pt - out_ref).abs().max().item()}")
    print(f"Pytorch mean diff: {(out_pt - out_ref).abs().mean().item()}")

    g = torch.randn_like(out)
    if check_bwd:
        dq, dk, dv = torch.autograd.grad(out, (q, k, v), g)
        dq_ref, dk_ref, dv_ref = torch.autograd.grad(out_ref, (q, k, v), g)
        dq_pt, dk_pt, dv_pt = torch.autograd.grad(out_pt, (q, k, v), g)
        grads = (
            ("dQ", dq, dq_ref, dq_pt),
            ("dK", dk, dk_ref, dk_pt),
            ("dV", dv, dv_ref, dv_pt),
        )
        for name, grad, ref, _ in grads:
            print(f"{name} max diff: {(grad - ref).abs().max().item()}")
        for name, grad, ref, _ in grads:
            print(f"{name} mean diff: {(grad - ref).abs().mean().item()}")
        for name, _, ref, pt in grads:
            print(f"{name} Pytorch max diff: {(pt - ref).abs().max().item()}")
        for name, _, ref, pt in grads:
            print(f"{name} Pytorch mean diff: {(pt - ref).abs().mean().item()}")

    # Check that FlashAttention's numerical error is at most twice the numerical error
    # of a Pytorch implementation.
    assert (out - out_ref).abs().max().item() <= 2 * (out_pt - out_ref).abs().max().item() + 1e-5

    if check_bwd:
        assert (dq - dq_ref).abs().max().item() <= 2 * (dq_pt - dq_ref).abs().max().item() + 2e-4
        assert (dk - dk_ref).abs().max().item() <= 2 * (dk_pt - dk_ref).abs().max().item() + 2e-4
        assert (dv - dv_ref).abs().max().item() <= 2 * (dv_pt - dv_ref).abs().max().item() + 2e-4

@pytest.mark.parametrize("dtype", ([torch.float16] if is_sm75 else [torch.float16, torch.bfloat16]))
# @pytest.mark.parametrize("dtype", [torch.float16])
@pytest.mark.parametrize("num_splits", [1, 0])
# @pytest.mark.parametrize("num_splits", [0])
@pytest.mark.parametrize("mha_type", ["mha", "mqa", "gqa"])
# @pytest.mark.parametrize("mha_type", ["mqa"])
@pytest.mark.parametrize("new_kv", [False, True])
# @pytest.mark.parametrize("new_kv", [False])
@pytest.mark.parametrize("causal", [False, True])
# @pytest.mark.parametrize("causal", [True])
@pytest.mark.parametrize("d", [32, 40, 59, 64, 80, 96, 111, 128, 160, 192, 224, 256])
# @pytest.mark.parametrize('d', [32, 64, 96, 128, 160, 192, 224, 256])
# @pytest.mark.parametrize('d', [32, 40, 64, 80, 96, 128, 160, 192])
# @pytest.mark.parametrize('d', [32, 64, 96, 128, 160, 192])
# @pytest.mark.parametrize('d', [56, 80])
# @pytest.mark.parametrize("d", [64])
@pytest.mark.parametrize(
    "seqlen_q,seqlen_k",
    [
        (1, 128),
        (1, 339),
        (3, 1024),
        (64, 800),
        (64, 256),
        (3, 799),
        (64, 2048),
        (16, 20000),
        (1, 128 * 1024),
        (16, 128 * 1024),
        (128, 128),
    ],
)
# @pytest.mark.parametrize('seqlen_q,seqlen_k', [(256, 128)])
def test_flash_attn_kvcache(seqlen_q, seqlen_k, d, causal, new_kv, mha_type, num_splits, dtype):
    """Check ``flash_attn_with_kvcache`` against ``attention_ref`` on a reference cache.

    A reference copy of the KV cache is built before the call; when ``new_kv`` is set, the
    new keys/values are written into that copy at positions
    ``[cache_seqlens, cache_seqlens + seqlen_q)`` per batch element. The output is compared
    against attention over the reference cache, and with ``new_kv`` the caches passed to
    ``flash_attn_with_kvcache`` must equal the reference copies afterwards.

    Args:
        seqlen_q: query (and new key/value) sequence length.
        seqlen_k: cache sequence length.
        d: head dimension.
        causal: whether to apply causal masking.
        new_kv: whether new keys/values are passed alongside the cache.
        mha_type: "mha", "mqa" or "gqa"; selects the number of KV heads.
        num_splits: forwarded to ``flash_attn_with_kvcache``.
        dtype: torch dtype of all tensors.
    """
    if new_kv and seqlen_q > seqlen_k:
        pytest.skip()
    device = "cuda"
    # set seed
    torch.random.manual_seed(0)
    batch_size = 2
    nheads = 6
    if mha_type == "mha":
        nheads_k = nheads
    elif mha_type == "mqa":
        nheads_k = 1
    else:
        nheads_k = 3
    assert nheads % nheads_k == 0
    groups = nheads // nheads_k

    # Random draws stay in the order q, (k, v), k_cache, v_cache, cache_seqlens so the
    # generated data is unchanged for a given seed.
    q = torch.randn(batch_size, seqlen_q, nheads, d, device=device, dtype=dtype)
    k, v = None, None
    if new_kv:
        k = torch.randn(batch_size, seqlen_q, nheads_k, d, device=device, dtype=dtype)
        v = torch.randn(batch_size, seqlen_q, nheads_k, d, device=device, dtype=dtype)
    k_cache = torch.randn(batch_size, seqlen_k, nheads_k, d, device=device, dtype=dtype)
    v_cache = torch.randn(batch_size, seqlen_k, nheads_k, d, device=device, dtype=dtype)
    if new_kv:
        # Leave room in the cache for the seqlen_q new entries.
        seqlens_high = seqlen_k - seqlen_q + 1
    else:
        seqlens_high = seqlen_k + 1
    cache_seqlens = torch.randint(
        0, seqlens_high, (batch_size,), dtype=torch.int32, device=device
    )

    # Reference caches are cloned before the call under test receives k_cache / v_cache.
    k_cache_ref = k_cache.clone()
    v_cache_ref = v_cache.clone()
    positions = rearrange(torch.arange(seqlen_k, device=device), "s -> 1 s")
    seqlens_col = rearrange(cache_seqlens, "b -> b 1")
    if new_kv:
        update_mask = torch.logical_and(
            seqlens_col <= positions, positions < seqlens_col + seqlen_q
        )
        k_cache_ref[update_mask] = rearrange(k, "b s ... -> (b s) ...")
        v_cache_ref[update_mask] = rearrange(v, "b s ... -> (b s) ...")
    # Repeat each KV head so the reference sees as many KV heads as query heads.
    k_cache_rep = repeat(k_cache_ref, "b s h d -> b s (h g) d", g=groups)
    v_cache_rep = repeat(v_cache_ref, "b s h d -> b s (h g) d", g=groups)

    out = flash_attn_with_kvcache(
        q, k_cache, v_cache, k, v, cache_seqlens, causal=causal, num_splits=num_splits
    )

    # Keys at or beyond the (possibly extended) cache length are masked out in the reference.
    key_padding_mask = positions < seqlens_col + (seqlen_q if new_kv else 0)
    out_ref, _ = attention_ref(
        q, k_cache_rep, v_cache_rep, None, key_padding_mask, 0.0, None, causal=causal
    )
    out_pt, _ = attention_ref(
        q,
        k_cache_rep,
        v_cache_rep,
        None,
        key_padding_mask,
        0.0,
        None,
        causal=causal,
        upcast=False,
        reorder_ops=True,
    )
    print(f"Output max diff: {(out - out_ref).abs().max().item()}")
    print(f"Output mean diff: {(out - out_ref).abs().mean().item()}")
    print(f"Pytorch max diff: {(out_pt - out_ref).abs().max().item()}")
    print(f"Pytorch mean diff: {(out_pt - out_ref).abs().mean().item()}")

    # Check that FlashAttention's numerical error is at most three times the numerical
    # error of a Pytorch implementation.
    assert (out - out_ref).abs().max().item() <= 3 * (out_pt - out_ref).abs().max().item() + 1e-5
    if new_kv:
        assert torch.equal(k_cache, k_cache_ref)
        assert torch.equal(v_cache, v_cache_ref)

# @pytest.mark.parametrize("dtype", ([torch.float16] if is_sm75 else [torch.float16, torch.bfloat16]))
@pytest.mark.parametrize("dtype", [torch.float16])
@pytest.mark.parametrize("causal", [False, True])
# @pytest.mark.parametrize('causal', [True])
@pytest.mark.parametrize("d", [32, 40, 59, 64, 80, 96, 111, 128, 160, 192, 224, 256])
# @pytest.mark.parametrize('d', [32, 56, 64, 80, 96, 128])
# @pytest.mark.parametrize("d", [32, 64, 96, 128, 160, 192])
# @pytest.mark.parametrize('d', [128])
@pytest.mark.parametrize(
    "seqlen_q,seqlen_k",
    [
        (1, 239),
        (239, 1),
        (3, 799),
        (799, 3),
        (1024, 128),
        (97, 97),
        (128, 128),
        (200, 200),
        (256, 256),
        (257, 257),
        (384, 384),
        (512, 512),
        (768, 768),
        (1024, 1024),
    ],
)
@pytest.mark.parametrize('dropout_p', [0.0, 0.17])
# @pytest.mark.parametrize("dropout_p", [0.0])
def test_flash_attn_race_condition(seqlen_q, seqlen_k, d, dropout_p, causal, dtype):
    """Run the same forward/backward 250 times and require reproducible results.

    With the RNG re-seeded before every call, ``out`` and ``lse`` must be bit-identical to
    the first run. When the backward pass is exercised, dK and dV must also be
    bit-identical, while dQ is only required to match within ``dq_atol``.

    Args:
        seqlen_q: query sequence length.
        seqlen_k: key/value sequence length.
        d: head dimension.
        dropout_p: dropout probability passed to ``flash_attn_func``.
        causal: whether to apply causal masking.
        dtype: torch dtype used for q, k and v.
    """
    device = "cuda"
    # Gradients are only checked for head dims up to MAX_HEADDIM_SM8x, or on sm80 / sm90.
    check_bwd = d <= MAX_HEADDIM_SM8x or (is_sm80 or is_sm90)
    # set seed
    torch.random.manual_seed(0)
    batch_size = 60  # Sometimes we need large batch size for the race conditions to trigger
    nheads = 4
    q = torch.randn(batch_size, seqlen_q, nheads, d, device=device, dtype=dtype, requires_grad=True)
    k = torch.randn(batch_size, seqlen_k, nheads, d, device=device, dtype=dtype, requires_grad=True)
    v = torch.randn(batch_size, seqlen_k, nheads, d, device=device, dtype=dtype, requires_grad=True)
    # Re-seed before every flash_attn_func call so each run starts from the same RNG state.
    torch.random.manual_seed(42)
    out0, lse0, _ = flash_attn_func(q, k, v, dropout_p, causal=causal, return_attn_probs=True)
    g = torch.randn_like(out0)
    if check_bwd:
        dq0, dk0, dv0 = torch.autograd.grad(out0, (q, k, v), g)
        # Numerical error if we just do any arithmetic on dq
        dq_atol = 2 * ((dq0 + 0.3 - 0.3) - dq0).abs().max().item()

    for i in range(250):
        torch.random.manual_seed(42)
        out, lse, _ = flash_attn_func(q, k, v, dropout_p, causal=causal, return_attn_probs=True)
        assert torch.equal(out, out0)
        assert torch.equal(lse, lse0)

        if check_bwd:
            dq, dk, dv = torch.autograd.grad(out, (q, k, v), g)
            dq_equal = torch.allclose(dq, dq0, atol=dq_atol)
            if not dq_equal:
                print(f"Iter {i}, {dq_atol = }, dQ max diff: {(dq - dq0).abs().max().item()}")
            assert torch.equal(dv, dv0)
            assert torch.equal(dk, dk0)
            assert dq_equal


@pytest.mark.parametrize("dtype", [torch.float16])
@pytest.mark.parametrize("causal", [False, True])
# @pytest.mark.parametrize('causal', [False])
@pytest.mark.parametrize("d", [16, 32, 64])
# @pytest.mark.parametrize('d', [16])
@pytest.mark.parametrize("seqlen", [1, 2, 5, 17, 128])
# @pytest.mark.parametrize('seqlen', [2])
def test_flash_attn_bwd_overflow(seqlen, d, causal, dtype):
    """We previously had a bug where not masking elements beyond seqlen_k caused NaN in dQ,
    in the case where seqlen % 128 != 0.

    Inputs are scaled up (q by 5, k and v by 3) and the output and gradients are compared
    with ``attention_ref``: the output must be within twice the
    ``upcast=False, reorder_ops=True`` baseline error, the gradients within five times that
    baseline error plus 1e-3.
    """
    device = "cuda"
    # set seed
    torch.random.manual_seed(0)
    batch_size = 2
    nheads = 5
    shape = [batch_size, seqlen, nheads, d]
    q = torch.randn(shape, dtype=dtype, device=device) * 5
    k, v = [torch.randn(shape, dtype=dtype, device=device) * 3 for _ in range(2)]
    # Gradients are enabled only after scaling so q, k and v are leaf tensors.
    q.requires_grad_(True)
    k.requires_grad_(True)
    v.requires_grad_(True)
    out = flash_attn_func(q, k, v, causal=causal)
    g = torch.randn_like(out)
    out.backward(g)
    # Each reference run gets its own detached copies so the .grad fields stay separate.
    q_pt = q.detach().clone().requires_grad_(True)
    k_pt = k.detach().clone().requires_grad_(True)
    v_pt = v.detach().clone().requires_grad_(True)
    out_pt, _ = attention_ref(q_pt, k_pt, v_pt, causal=causal, upcast=False, reorder_ops=True)
    out_pt.backward(g)
    q_ref = q.detach().clone().requires_grad_(True)
    k_ref = k.detach().clone().requires_grad_(True)
    v_ref = v.detach().clone().requires_grad_(True)
    out_ref, _ = attention_ref(q_ref, k_ref, v_ref, causal=causal)
    out_ref.backward(g)
    print(f"dQ max diff: {(q.grad - q_ref.grad).abs().max().item()}")
    print(f"dK max diff: {(k.grad - k_ref.grad).abs().max().item()}")
    print(f"dV max diff: {(v.grad - v_ref.grad).abs().max().item()}")
    print(f"dQ Pytorch max diff: {(q_pt.grad - q_ref.grad).abs().max().item()}")
    print(f"dK Pytorch max diff: {(k_pt.grad - k_ref.grad).abs().max().item()}")
    print(f"dV Pytorch max diff: {(v_pt.grad - v_ref.grad).abs().max().item()}")
    assert (out - out_ref).abs().max().item() <= 2 * (out_pt - out_ref).abs().max().item()
    assert (q.grad - q_ref.grad).abs().max().item() <= 5 * (
        q_pt.grad - q_ref.grad
    ).abs().max().item() + 1e-3
    assert (k.grad - k_ref.grad).abs().max().item() <= 5 * (
        k_pt.grad - k_ref.grad
    ).abs().max().item() + 1e-3
    assert (v.grad - v_ref.grad).abs().max().item() <= 5 * (
        v_pt.grad - v_ref.grad
    ).abs().max().item() + 1e-3


@pytest.mark.parametrize("dtype", ([torch.float16] if is_sm75 else [torch.float16, torch.bfloat16]))
# @pytest.mark.parametrize('dtype', [torch.bfloat16])
@pytest.mark.parametrize("causal", [False, True])
# @pytest.mark.parametrize('causal', [False])
@pytest.mark.parametrize("d", [64, 128])
# @pytest.mark.parametrize('d', [64])
@pytest.mark.parametrize("seqlen", [97, 128, 200, 256])
# @pytest.mark.parametrize('seqlen', [128])
def test_flash_attn_bwd_transpose(seqlen, d, causal, dtype):
    """We previously had a bug where we were using the wrong strides of dout, which shows up
    when dout is not contiguous.

    The output is rearranged from batch-first to sequence-first and back-propagated with a
    strided (non-contiguous) ``g``; output and gradients must be within twice the
    ``attention_ref(..., upcast=False, reorder_ops=True)`` baseline error.
    """
    device = "cuda"
    # set seed
    torch.random.manual_seed(0)
    batch_size = 5
    nheads = 2
    q, k, v = [
        torch.randn([batch_size, seqlen, nheads, d], dtype=dtype, device=device, requires_grad=True)
        for _ in range(3)
    ]
    out = rearrange(flash_attn_func(q, k, v, causal=causal), "b s ... -> s b ...")
    # So g is not contiguous
    g = torch.randn(seqlen, 2 * batch_size, nheads, d, dtype=dtype, device=device)[:, ::2]
    out.backward(g)
    # Each reference run gets its own detached copies so the .grad fields stay separate.
    q_pt = q.detach().clone().requires_grad_(True)
    k_pt = k.detach().clone().requires_grad_(True)
    v_pt = v.detach().clone().requires_grad_(True)
    out_pt, _ = attention_ref(q_pt, k_pt, v_pt, causal=causal, upcast=False, reorder_ops=True)
    out_pt = rearrange(out_pt, "b s ... -> s b ...")
    out_pt.backward(g)
    q_ref = q.detach().clone().requires_grad_(True)
    k_ref = k.detach().clone().requires_grad_(True)
    v_ref = v.detach().clone().requires_grad_(True)
    out_ref, _ = attention_ref(q_ref, k_ref, v_ref, causal=causal)
    out_ref = rearrange(out_ref, "b s ... -> s b ...")
    out_ref.backward(g)
    print(f"dQ max diff: {(q.grad - q_ref.grad).abs().max().item()}")
    print(f"dK max diff: {(k.grad - k_ref.grad).abs().max().item()}")
    print(f"dV max diff: {(v.grad - v_ref.grad).abs().max().item()}")
    print(f"dQ Pytorch max diff: {(q_pt.grad - q_ref.grad).abs().max().item()}")
    print(f"dK Pytorch max diff: {(k_pt.grad - k_ref.grad).abs().max().item()}")
    print(f"dV Pytorch max diff: {(v_pt.grad - v_ref.grad).abs().max().item()}")
    assert (out - out_ref).abs().max().item() <= 2 * (out_pt - out_ref).abs().max().item()
    assert (q.grad - q_ref.grad).abs().max().item() <= 2 * (
        q_pt.grad - q_ref.grad
    ).abs().max().item()
    assert (k.grad - k_ref.grad).abs().max().item() <= 2 * (
        k_pt.grad - k_ref.grad
    ).abs().max().item()
    assert (v.grad - v_ref.grad).abs().max().item() <= 2 * (
        v_pt.grad - v_ref.grad
    ).abs().max().item()


@pytest.mark.parametrize("dtype", [torch.float16])
@pytest.mark.parametrize("causal", [False, True])
# @pytest.mark.parametrize('causal', [False])
@pytest.mark.parametrize("d", [16, 32, 64])
# @pytest.mark.parametrize('d', [16])
def test_flash_attn_bwd_varlen_overflow(d, causal, dtype):
    """We previously had a bug where not masking elements beyond seqlen_k caused NaN in dQ,
    in the case where seqlen % 128 != 0 or varlen.

    Three packed sequences are used: query lengths 76, 34 and 146 (256 tokens in total)
    against key lengths of 1 each. The test only checks that no gradient contains NaN.
    """
    device = "cuda"
    # set seed
    torch.random.manual_seed(0)
    nheads = 5
    # Cumulative sequence boundaries for the packed q and k/v tensors.
    q_cuseqlen = torch.tensor([0, 76, 110, 256], device=device, dtype=torch.int32)
    k_cuseqlen = torch.tensor([0, 1, 2, 3], device=device, dtype=torch.int32)
    # Total packed token counts; also passed as the max sequence lengths below.
    Mq = 256
    Mk = 3

    q = torch.randn([Mq, nheads, d], dtype=dtype, device=device) * 3
    k, v = [torch.randn([Mk, nheads, d], dtype=dtype, device=device) * 3 for _ in range(2)]
    # Gradients are enabled only after scaling so q, k and v are leaf tensors.
    q.requires_grad_(True)
    k.requires_grad_(True)
    v.requires_grad_(True)

    out = flash_attn_varlen_func(q, k, v, q_cuseqlen, k_cuseqlen, Mq, Mk, causal=causal)
    g = torch.randn_like(out)
    out.backward(g)

    assert not q.grad.isnan().any()
    assert not k.grad.isnan().any()
    assert not v.grad.isnan().any()