import math

import pytest
import torch
import torch.nn.functional as F
from einops import rearrange, repeat
from flash_attn import (
    flash_attn_func,
    flash_attn_kvpacked_func,
    flash_attn_qkvpacked_func,
    flash_attn_varlen_func,
    flash_attn_varlen_kvpacked_func,
    flash_attn_varlen_qkvpacked_func,
    flash_attn_with_kvcache,
)
from flash_attn.bert_padding import pad_input, unpad_input
from flash_attn.flash_attn_interface import _get_block_size_n
from flash_attn.layers.rotary import apply_rotary_emb

MAX_HEADDIM_SM8x = 192


is_sm75 = torch.cuda.get_device_capability("cuda") == (7, 5)
is_sm8x = torch.cuda.get_device_capability("cuda")[0] == 8
is_sm80 = torch.cuda.get_device_capability("cuda") == (8, 0)
is_sm90 = torch.cuda.get_device_capability("cuda") == (9, 0)


def attn_bias_from_alibi_slopes(
    slopes, seqlen_q, seqlen_k, query_padding_mask=None, key_padding_mask=None, causal=False
):
    batch, nheads = slopes.shape
    device = slopes.device
    slopes = rearrange(slopes, "b h -> b h 1 1")
    if causal:
        return torch.arange(-seqlen_k + 1, 1, device=device, dtype=torch.float32) * slopes
    else:
        row_idx = rearrange(torch.arange(seqlen_q, device=device, dtype=torch.long), "s -> s 1")
        col_idx = torch.arange(seqlen_k, device=device, dtype=torch.long)
        sk = (
            seqlen_k
            if key_padding_mask is None
            else rearrange(key_padding_mask.sum(-1), "b -> b 1 1 1")
        )
        sq = (
            seqlen_q
            if query_padding_mask is None
            else rearrange(query_padding_mask.sum(-1), "b -> b 1 1 1")
        )
        relative_pos = torch.abs(row_idx + sk - sq - col_idx)
        return -slopes * relative_pos.to(dtype=slopes.dtype)
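

def _example_alibi_bias_shape():
    # Illustrative sketch only (not collected by pytest): build an ALiBi bias on CPU and
    # check that it broadcasts against a (batch, nheads, seqlen_q, seqlen_k) score tensor.
    # The slope values here are arbitrary; real ALiBi slopes follow a geometric sequence.
    batch, nheads, seqlen_q, seqlen_k = 2, 4, 5, 7
    slopes = torch.rand(batch, nheads, dtype=torch.float32) * 0.3
    bias = attn_bias_from_alibi_slopes(slopes, seqlen_q, seqlen_k, causal=False)
    scores = torch.zeros(batch, nheads, seqlen_q, seqlen_k)
    assert (scores + bias).shape == (batch, nheads, seqlen_q, seqlen_k)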


def generate_random_padding_mask(max_seqlen, batch_size, device, mode="random"):
    assert mode in ["full", "random", "third"]
    if mode == "full":
        lengths = torch.full((batch_size, 1), max_seqlen, device=device, dtype=torch.int32)
    elif mode == "random":
        lengths = torch.randint(
            max(1, max_seqlen - 20), max_seqlen + 1, (batch_size, 1), device=device
        )
    elif mode == "third":
        lengths = torch.randint(max_seqlen // 3, max_seqlen + 1, (batch_size, 1), device=device)
    padding_mask = (
        repeat(torch.arange(max_seqlen, device=device), "s -> b s", b=batch_size) < lengths
    )
    return padding_mask
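

def _example_padding_mask():
    # Illustrative sketch only (not collected by pytest): the padding mask marks valid
    # positions with True; mode="random" draws per-sequence lengths near max_seqlen.
    mask = generate_random_padding_mask(16, batch_size=2, device="cpu", mode="random")
    assert mask.shape == (2, 16) and mask.dtype == torch.bool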


def generate_qkv(
    q, k, v, query_padding_mask=None, key_padding_mask=None, kvpacked=False, qkvpacked=False
):
    """
    Arguments:
        q: (batch_size, seqlen_q, nheads, d)
        k: (batch_size, seqlen_k, nheads_k, d)
        v: (batch_size, seqlen_k, nheads_k, d)
        query_padding_mask: (batch_size, seqlen), bool
        key_padding_mask: (batch_size, seqlen), bool
    """
    assert not (kvpacked and qkvpacked)
    batch_size, seqlen_q, nheads, d = q.shape
    _, seqlen_k, nheads_k, _ = k.shape
    assert k.shape == (batch_size, seqlen_k, nheads_k, d)
    assert v.shape == (batch_size, seqlen_k, nheads_k, d)

    if query_padding_mask is not None:
        q_unpad, indices_q, cu_seqlens_q, max_seqlen_q = unpad_input(q, query_padding_mask)
        output_pad_fn = lambda output_unpad: pad_input(
            output_unpad, indices_q, batch_size, seqlen_q
        )
    else:
        q_unpad = rearrange(q, "b s h d -> (b s) h d")
        cu_seqlens_q = torch.arange(
            0, (batch_size + 1) * seqlen_q, step=seqlen_q, dtype=torch.int32, device=q_unpad.device
        )
        max_seqlen_q = seqlen_q
        output_pad_fn = lambda output_unpad: rearrange(
            output_unpad, "(b s) h d -> b s h d", b=batch_size
        )

    if key_padding_mask is not None:
        k_unpad, indices_k, cu_seqlens_k, max_seqlen_k = unpad_input(k, key_padding_mask)
        v_unpad, _, _, _ = unpad_input(v, key_padding_mask)
    else:
        k_unpad = rearrange(k, "b s h d -> (b s) h d")
        v_unpad = rearrange(v, "b s h d -> (b s) h d")
        cu_seqlens_k = torch.arange(
            0, (batch_size + 1) * seqlen_k, step=seqlen_k, dtype=torch.int32, device=k_unpad.device
        )
        max_seqlen_k = seqlen_k

    if qkvpacked:
        assert (query_padding_mask == key_padding_mask).all()
        assert nheads == nheads_k
        qkv_unpad = torch.stack([q_unpad, k_unpad, v_unpad], dim=1)
        qkv = torch.stack([q, k, v], dim=2)
        if query_padding_mask is not None:
            dqkv_pad_fn = lambda dqkv_unpad: pad_input(dqkv_unpad, indices_q, batch_size, seqlen_q)
        else:
            dqkv_pad_fn = lambda dqkv_unpad: rearrange(
                dqkv_unpad, "(b s) t h d -> b s t h d", b=batch_size
            )
        return (
            qkv_unpad.detach().requires_grad_(),
            cu_seqlens_q,
            max_seqlen_q,
            qkv.detach().requires_grad_(),
            output_pad_fn,
            dqkv_pad_fn,
        )
    elif kvpacked:
        kv_unpad = torch.stack([k_unpad, v_unpad], dim=1)
        kv = torch.stack([k, v], dim=2)
        dq_pad_fn = output_pad_fn
        if key_padding_mask is not None:
            dkv_pad_fn = lambda dkv_unpad: pad_input(dkv_unpad, indices_k, batch_size, seqlen_k)
        else:
            dkv_pad_fn = lambda dkv_unpad: rearrange(
                dkv_unpad, "(b s) t h d -> b s t h d", b=batch_size
            )
        return (
            q_unpad.detach().requires_grad_(),
            kv_unpad.detach().requires_grad_(),
            cu_seqlens_q,
            cu_seqlens_k,
            max_seqlen_q,
            max_seqlen_k,
            q.detach().requires_grad_(),
            kv.detach().requires_grad_(),
            output_pad_fn,
            dq_pad_fn,
            dkv_pad_fn,
        )
    else:
        dq_pad_fn = output_pad_fn
        if key_padding_mask is not None:
            dk_pad_fn = lambda dk_unpad: pad_input(dk_unpad, indices_k, batch_size, seqlen_k)
        else:
            dk_pad_fn = lambda dk_unpad: rearrange(dk_unpad, "(b s) h d -> b s h d", b=batch_size)
        return (
            q_unpad.detach().requires_grad_(),
            k_unpad.detach().requires_grad_(),
            v_unpad.detach().requires_grad_(),
            cu_seqlens_q,
            cu_seqlens_k,
            max_seqlen_q,
            max_seqlen_k,
            q.detach().requires_grad_(),
            k.detach().requires_grad_(),
            v.detach().requires_grad_(),
            output_pad_fn,
            dq_pad_fn,
            dk_pad_fn,
        )
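

def _example_generate_qkv_varlen_inputs():
    # Illustrative sketch only (not collected by pytest): with no padding masks,
    # generate_qkv flattens batch and sequence dims and builds the cu_seqlens tensors
    # that the varlen API expects.
    b, s, h, d = 2, 8, 3, 16
    q, k, v = [torch.randn(b, s, h, d) for _ in range(3)]
    out = generate_qkv(q, k, v)
    q_unpad, cu_seqlens_q, max_seqlen_q = out[0], out[3], out[5]
    assert q_unpad.shape == (b * s, h, d)
    assert cu_seqlens_q.tolist() == [0, s, 2 * s] and max_seqlen_q == s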


def construct_local_mask(
    seqlen_q,
    seqlen_k,
    window_size=(-1, -1),  # -1 means infinite window size
    query_padding_mask=None,
    key_padding_mask=None,
    device=None,
):
    row_idx = rearrange(torch.arange(seqlen_q, device=device, dtype=torch.long), "s -> s 1")
    col_idx = torch.arange(seqlen_k, device=device, dtype=torch.long)
    sk = (
        seqlen_k
        if key_padding_mask is None
        else rearrange(key_padding_mask.sum(-1), "b -> b 1 1 1")
    )
    sq = (
        seqlen_q
        if query_padding_mask is None
        else rearrange(query_padding_mask.sum(-1), "b -> b 1 1 1")
    )
    if window_size[0] < 0:
        return col_idx > row_idx + sk - sq + window_size[1]
    else:
        sk = torch.full_like(col_idx, seqlen_k) if key_padding_mask is None else sk
        return torch.logical_or(
            col_idx > torch.minimum(row_idx + sk - sq + window_size[1], sk),
            col_idx < row_idx + sk - sq - window_size[0],
        )
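

def _example_local_mask_reduces_to_causal():
    # Illustrative sketch only (not collected by pytest): with window_size=(-1, 0) and
    # equal seqlens, the local mask is the usual causal mask (True marks masked-out keys).
    mask = construct_local_mask(4, 4, window_size=(-1, 0), device="cpu")
    expected = torch.triu(torch.ones(4, 4, dtype=torch.bool), diagonal=1)
    assert torch.equal(mask, expected)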


def attention_ref(
    q,
    k,
    v,
    query_padding_mask=None,
    key_padding_mask=None,
    attn_bias=None,
    dropout_p=0.0,
    dropout_mask=None,
    causal=False,
    window_size=(-1, -1),  # -1 means infinite window size
    softcap=0.0,
    upcast=True,
    reorder_ops=False,
):
    """
    Arguments:
        q: (batch_size, seqlen_q, nheads, head_dim)
        k: (batch_size, seqlen_k, nheads_k, head_dim)
        v: (batch_size, seqlen_k, nheads_k, head_dim)
        query_padding_mask: (batch_size, seqlen_q)
        key_padding_mask: (batch_size, seqlen_k)
        attn_bias: broadcastable to (batch_size, nheads, seqlen_q, seqlen_k)
        dropout_p: float
        dropout_mask: (batch_size, nheads, seqlen_q, seqlen_k)
        causal: whether to apply causal masking
        window_size: (int, int), left and right window size
        upcast: whether to cast all inputs to fp32, do all computation in fp32, then cast
            output back to fp16/bf16.
        reorder_ops: whether to change the order of operations (scaling k instead of scaling q, etc.)
            without changing the math. This is to estimate the numerical error from operation
            reordering.
    Output:
        output: (batch_size, seqlen_q, nheads, head_dim)
        attention: (batch_size, nheads, seqlen_q, seqlen_k), softmax without dropout applied
    """
    if causal:
        window_size = (window_size[0], 0)
    dtype_og = q.dtype
    if upcast:
        q, k, v = q.float(), k.float(), v.float()
    seqlen_q, seqlen_k = q.shape[1], k.shape[1]
    k = repeat(k, "b s h d -> b s (h g) d", g=q.shape[2] // k.shape[2])
    v = repeat(v, "b s h d -> b s (h g) d", g=q.shape[2] // v.shape[2])
    d = q.shape[-1]
    if not reorder_ops:
        scores = torch.einsum("bthd,bshd->bhts", q / math.sqrt(d), k)
    else:
        scores = torch.einsum("bthd,bshd->bhts", q, k / math.sqrt(d))
    if softcap > 0:
        scores /= softcap
        scores = scores.tanh()
        scores *= softcap
    if key_padding_mask is not None:
        scores.masked_fill_(rearrange(~key_padding_mask, "b s -> b 1 1 s"), float("-inf"))
    if window_size[0] >= 0 or window_size[1] >= 0:
        local_mask = construct_local_mask(
            seqlen_q,
            seqlen_k,
            window_size,
            query_padding_mask,
            key_padding_mask,
            q.device,
        )
        scores.masked_fill_(local_mask, float("-inf"))
    if attn_bias is not None:
        scores = scores + attn_bias
    attention = torch.softmax(scores, dim=-1).to(v.dtype)
    # Some rows might be completely masked out so we fill them with zero instead of NaN
    if window_size[0] >= 0 or window_size[1] >= 0:
        attention = attention.masked_fill(torch.all(local_mask, dim=-1, keepdim=True), 0.0)
    # We want to mask here so that the attention matrix doesn't have any NaNs
    # Otherwise we'll get NaN in dV
    if query_padding_mask is not None:
        attention = attention.masked_fill(rearrange(~query_padding_mask, "b s -> b 1 s 1"), 0.0)
    dropout_scaling = 1.0 / (1 - dropout_p)
    # attention_drop = attention.masked_fill(~dropout_mask, 0.0) * dropout_scaling
    # output = torch.einsum('bhts,bshd->bthd', attention_drop , v)
    if dropout_mask is not None:
        attention_drop = attention.masked_fill(~dropout_mask, 0.0)
    else:
        attention_drop = attention
    output = torch.einsum("bhts,bshd->bthd", attention_drop, v * dropout_scaling)
    if query_padding_mask is not None:
        output.masked_fill_(rearrange(~query_padding_mask, "b s -> b s 1 1"), 0.0)
    return output.to(dtype=dtype_og), attention.to(dtype=dtype_og)
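

def _example_attention_ref_matches_sdpa():
    # Illustrative sketch only (not collected by pytest): with no masking, bias, dropout,
    # or windowing, attention_ref should agree with PyTorch's scaled_dot_product_attention
    # up to fp32 round-off. Layout differs: attention_ref uses (b, s, h, d), SDPA (b, h, s, d).
    b, s, h, d = 2, 16, 3, 32
    q, k, v = [torch.randn(b, s, h, d) for _ in range(3)]
    out, _ = attention_ref(q, k, v)
    out_sdpa = F.scaled_dot_product_attention(
        q.transpose(1, 2), k.transpose(1, 2), v.transpose(1, 2)
    ).transpose(1, 2)
    assert torch.allclose(out, out_sdpa, atol=1e-5)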


def attention_kvpacked_ref(
    q,
    kv,
    query_padding_mask=None,
    key_padding_mask=None,
    attn_bias=None,
    dropout_p=0.0,
    dropout_mask=None,
    causal=False,
    window_size=(-1, -1),  # -1 means infinite window size
    upcast=True,
    reorder_ops=False,
):
    return attention_ref(
        q,
        kv[:, :, 0],
        kv[:, :, 1],
        query_padding_mask,
        key_padding_mask,
        attn_bias,
        dropout_p,
        dropout_mask,
        upcast=upcast,
        causal=causal,
        window_size=window_size,
        reorder_ops=reorder_ops,
    )


def attention_qkvpacked_ref(
    qkv,
    key_padding_mask=None,
    attn_bias=None,
    dropout_p=0.0,
    dropout_mask=None,
    causal=False,
    window_size=(-1, -1),  # -1 means infinite window size
    upcast=True,
    reorder_ops=False,
):
    return attention_ref(
        qkv[:, :, 0],
        qkv[:, :, 1],
        qkv[:, :, 2],
        key_padding_mask,
        key_padding_mask,
        attn_bias,
        dropout_p,
        dropout_mask,
        upcast=upcast,
        causal=causal,
        window_size=window_size,
        reorder_ops=reorder_ops,
    )


def generate_sparsity_mask(seqlen, sparsity=0.3):
    repeats = seqlen // 16 // 2
    # mask = torch.stack([torch.tensor([1, 0] * repeats, dtype=torch.bool, device='cuda'),
    #                     torch.tensor([0, 1] * repeats, dtype=torch.bool, device='cuda')], dim=-1)
    # mask = torch.stack([torch.tensor([1, 1] * repeats, dtype=torch.bool, device='cuda'),
    #                     torch.tensor([1, 1] * repeats, dtype=torch.bool, device='cuda')], dim=-1)
    # mask = torch.stack([torch.tensor([1, 1] * repeats, dtype=torch.bool, device='cuda')], dim=-1)
    # mask = torch.stack([torch.tensor([1, 0] * repeats, dtype=torch.bool, device='cuda')], dim=-1)
    nrow, ncol = seqlen // 16, seqlen // 256
    mask = torch.rand(nrow, ncol, device="cuda") < sparsity
    return mask


def attention_blocksparse_ref(qkv, blockmask, attn_mask, dropout_p, dropout_mask):
    """
    Arguments:
        qkv: (batch_size, seqlen, 3, nheads, head_dim)
        blockmask: (seqlen / 16, seqlen / 256)
        attn_mask: (batch_size, seqlen)
        dropout_p: float
        dropout_mask: (batch_size, nheads, seqlen, seqlen)
    Output:
        output: (batch_size, seqlen, nheads, head_dim)
        attention: softmax after dropout
    """
    q, k, v = qkv.float().unbind(dim=2)
    d = qkv.shape[-1]
    seqlen = qkv.shape[1]
    scores = torch.einsum("bthd,bshd->bhts", q / math.sqrt(d), k)
    scores.masked_fill_(rearrange(~attn_mask, "b s -> b 1 1 s"), float("-inf"))
    blockmask = repeat(blockmask, "s_16 s_256 -> (s_16 16) (s_256 256)")
    blockmask = blockmask[:seqlen, :seqlen]
    scores.masked_fill_(rearrange(~blockmask, "t s -> 1 1 t s"), float("-inf"))
    attention = torch.softmax(scores, dim=-1)
    attention = attention.masked_fill(rearrange(~attn_mask, "b s -> b 1 s 1"), 0.0)
    attention = attention.masked_fill_(rearrange(~blockmask, "t s -> 1 1 t s"), 0.0)
    attention_drop = attention.masked_fill(~dropout_mask, 0.0) / (1 - dropout_p)
    output = torch.einsum("bhts,bshd->bthd", attention_drop, v)
    output.masked_fill_(rearrange(~attn_mask, "b s -> b s 1 1"), 0)
    return output.to(dtype=qkv.dtype), attention.to(dtype=qkv.dtype)


def convert_flash_attn_S_to_softmax(
    S,
    seqlen_q,
    seqlen_k,
    query_padding_mask,
    key_padding_mask,
    head_dim,
    is_dropout,
    causal=False,
    window_size=(-1, -1),  # -1 means infinite window size
):
    """FlashAttention stores the S matrix in a different way.
    Arguments:
        S: (batch_size, nheads, seqlen_q_rounded, seqlen_k_rounded)
        query_padding_mask: (batch_size, seqlen_q)
        key_padding_mask: (batch_size, seqlen_k)
    """
    if causal:
        window_size = (window_size[0], 0)
    seqlen_q_rounded, seqlen_k_rounded = S.shape[-2:]
    S_converted = S
    if window_size[0] >= 0 or window_size[1] >= 0:
        local_mask = construct_local_mask(
            seqlen_q,
            seqlen_k,
            window_size,
            query_padding_mask,
            key_padding_mask,
            S.device,
        )
        local_mask = F.pad(
            local_mask,
            (0, seqlen_k_rounded - seqlen_k, 0, seqlen_q_rounded - seqlen_q),
            value=True,
        )
        S_converted = S_converted.masked_fill(local_mask, 0.0)

    # Need to zero out things not in attention_mask in case S was initialized with random values
    # and some of those values aren't overwritten.
    seqlen_q_og = (
        query_padding_mask.shape[-1] if query_padding_mask is not None else seqlen_q_rounded
    )
    if query_padding_mask is not None:
        query_padding_mask = F.pad(query_padding_mask, (0, seqlen_q_rounded - seqlen_q_og))
        S_converted = S_converted.masked_fill(rearrange(~query_padding_mask, "b s -> b 1 s 1"), 0.0)
    seqlen_k_og = key_padding_mask.shape[-1] if key_padding_mask is not None else seqlen_k
    if key_padding_mask is not None:
        key_padding_mask = F.pad(key_padding_mask, (0, seqlen_k_rounded - seqlen_k_og))
        S_converted = S_converted.masked_fill(rearrange(~key_padding_mask, "b s -> b 1 1 s"), 0.0)
    S_converted = F.pad(S_converted, (0, 0, 0, seqlen_q_og - seqlen_q_rounded))
    S_converted = F.pad(S_converted, (0, seqlen_k_og - seqlen_k_rounded))
    return S_converted[:, :, :seqlen_q, :seqlen_k]


def normalize_flash_attn_S(
    attn_unnorm,
    q,
    k,
    v,
    query_padding_mask=None,
    key_padding_mask=None,
    attn_bias=None,
    is_dropout=False,
    causal=False,
    window_size=(-1, -1),  # -1 means infinite window size
):
    """
    Arguments:
        q: (batch_size, seqlen_q, nheads, head_dim)
        k, v: (batch_size, seqlen_k, nheads, head_dim)
        key_padding_mask: (batch_size, seqlen_k)
        attn_bias: broadcastable to (batch_size, nheads, seqlen_q, seqlen_k)
    Output:
        attn_norm: (batch_size, nheads, seqlen_q, seqlen_k), normalized attention probabilities
    """
    if causal:
        window_size = (window_size[0], 0)
    q, k, v = q.float(), k.float(), v.float()
    _, seqlen_q, _, head_dim = q.shape
    seqlen_k = k.shape[1]
    scores = torch.einsum("bthd,bshd->bhts", q / math.sqrt(head_dim), k)
    if key_padding_mask is not None:
        scores.masked_fill_(rearrange(~key_padding_mask, "b s -> b 1 1 s"), float("-inf"))
    if window_size[0] >= 0 or window_size[1] >= 0:
        local_mask = construct_local_mask(
            seqlen_q,
            seqlen_k,
            window_size,
            query_padding_mask,
            key_padding_mask,
            q.device,
        )
        scores.masked_fill_(local_mask, float("-inf"))
    if attn_bias is not None:
        scores = scores + attn_bias.to(dtype=scores.dtype)
    block_size_n = _get_block_size_n(scores.device, head_dim, is_dropout, causal)
    scores_block = scores.split(block_size_n, dim=-1)
    lse_block = torch.stack([torch.logsumexp(s, dim=-1) for s in scores_block], dim=-1)
    lse = torch.logsumexp(lse_block, dim=-1)
    # lse could be -inf (i.e. all values in scores are -inf), and we want to set those to inf
    # so that when we do torch.exp(m - lse), we get 0.0 instead of NaN.
    lse[lse == float("-inf")] = float("inf")
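    # Each key block of S holds unnormalized values of the form exp(s - m_block), where
    # m_block is the block-wise running row max computed below; multiplying by
    # exp(m_block - lse) turns them into the normalized softmax values exp(s - lse).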
    scores_max_block = torch.stack([torch.amax(s, dim=-1) for s in scores_block], dim=-1)
    cummax_block = torch.cummax(scores_max_block.flip(-1), dim=-1).values.flip(-1).unbind(dim=-1)
    attn_unnorm_block = attn_unnorm.split(block_size_n, dim=-1)
    attn_norm = torch.cat(
        [
            a * rearrange(torch.exp(m - lse), "b h s -> b h s 1")
            for a, m in zip(attn_unnorm_block, cummax_block)
        ],
        dim=-1,
    )
    if query_padding_mask is not None:
        attn_norm.masked_fill_(rearrange(~query_padding_mask, "b s -> b 1 s 1"), 0.0)
    return attn_norm.to(dtype=attn_unnorm.dtype)


def get_dropout_fraction(
    dropout_mask,
    query_padding_mask=None,
    key_padding_mask=None,
    causal=False,
    window_size=(-1, -1),  # -1 means infinite window size
):
    """
    dropout_mask: (batch_size, nheads, seqlen_q, seqlen_k), bool. True means keep, False means drop.
    query_padding_mask: (batch_size, seqlen_q)
    key_padding_mask: (batch_size, seqlen_k)
    """
    if causal:
        window_size = (window_size[0], 0)
    batch_size, nheads, seqlen_q, seqlen_k = dropout_mask.shape
    dropped = ~dropout_mask
    valid = torch.ones_like(dropout_mask)
    if query_padding_mask is not None:
        dropped.masked_fill_(rearrange(~query_padding_mask, "b s -> b 1 s 1"), False)
        valid.masked_fill_(rearrange(~query_padding_mask, "b s -> b 1 s 1"), False)
    if key_padding_mask is not None:
        dropped.masked_fill_(rearrange(~key_padding_mask, "b s -> b 1 1 s"), False)
        valid.masked_fill_(rearrange(~key_padding_mask, "b s -> b 1 1 s"), False)
    if window_size[0] >= 0 or window_size[1] >= 0:
        local_mask = construct_local_mask(
            seqlen_q,
            seqlen_k,
            window_size,
            query_padding_mask,
            key_padding_mask,
            dropout_mask.device,
        )
        dropped.masked_fill_(local_mask, False)
        valid.masked_fill_(local_mask, False)
    dropped_total = dropped.sum()
    return dropped_total / valid.sum()
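

def _example_dropout_fraction():
    # Illustrative sketch only (not collected by pytest): with no padding and no
    # causal/local masking, the dropout fraction is just the share of dropped (False)
    # entries in the keep-mask.
    keep = torch.tensor([[[[True, False], [False, False]]]])  # (1, 1, 2, 2)
    assert abs(get_dropout_fraction(keep).item() - 0.75) < 1e-6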


@pytest.mark.parametrize("dtype", ([torch.float16] if is_sm75 else [torch.float16, torch.bfloat16]))
# @pytest.mark.parametrize("dtype", [torch.float16])
@pytest.mark.parametrize("deterministic", [False, True])
# @pytest.mark.parametrize("deterministic", [False])
@pytest.mark.parametrize("alibi", [False, True])
# @pytest.mark.parametrize("alibi", [False])
@pytest.mark.parametrize("local", [False, True])
# @pytest.mark.parametrize("local", [False])
@pytest.mark.parametrize("causal", [False, True])
# @pytest.mark.parametrize("causal", [False])
@pytest.mark.parametrize("d", [32, 40, 59, 64, 80, 96, 111, 128, 160, 192, 224, 256])
# @pytest.mark.parametrize("d", [32, 64, 96, 128, 160, 192, 224, 256])
# @pytest.mark.parametrize('d', [32, 64, 96, 128])
# @pytest.mark.parametrize("d", [64])
# @pytest.mark.parametrize('seqlen', [128, 256, 384, 512, 768, 1024, 2048])
@pytest.mark.parametrize("seqlen", [97, 128, 200, 384, 768, 1024, 1025, 2048])
# @pytest.mark.parametrize("seqlen", [512])
@pytest.mark.parametrize("dropout_p", [0.0, 0.17])
# @pytest.mark.parametrize("dropout_p", [0.0])
def test_flash_attn_qkvpacked(seqlen, d, dropout_p, causal, local, alibi, deterministic, dtype):
    if seqlen >= 2048 and torch.cuda.get_device_properties("cuda").total_memory <= 16 * 2**30:
        pytest.skip()  # Reference implementation OOM
    device = "cuda"
    # set seed
    torch.random.manual_seed(0)
    batch_size = 4
    nheads = 9
    window_size = (-1, -1) if not local else torch.randint(0, seqlen, (2,))
    qkv = torch.randn(
        batch_size, seqlen, 3, nheads, d, device=device, dtype=dtype, requires_grad=True
    )
    if alibi:
        alibi_slopes = torch.rand(batch_size, nheads, device=device, dtype=torch.float32) * 0.3
        attn_bias = attn_bias_from_alibi_slopes(alibi_slopes, seqlen, seqlen, causal=causal)
    else:
        alibi_slopes, attn_bias = None, None
    out, lse, S_dmask = flash_attn_qkvpacked_func(
        qkv,
        dropout_p,
        causal=causal,
        window_size=window_size,
        alibi_slopes=alibi_slopes,
        deterministic=deterministic,
        return_attn_probs=True,
    )
    if dropout_p > 0.0:
        S_dmask_converted = convert_flash_attn_S_to_softmax(
            S_dmask,
            seqlen,
            seqlen,
            None,
            None,
            d,
            dropout_p > 0.0,
            causal=causal,
            window_size=window_size,
        )
        dropout_mask = S_dmask_converted >= 0
        attn_unnorm = S_dmask_converted.abs()
        attn = normalize_flash_attn_S(
            attn_unnorm,
            qkv[:, :, 0],
            qkv[:, :, 1],
            qkv[:, :, 2],
            None,
            None,
            attn_bias,
            dropout_p > 0.0,
            causal=causal,
            window_size=window_size,
        )
        dropout_fraction = get_dropout_fraction(
            dropout_mask, None, None, causal=causal, window_size=window_size
        ).item()
        print(f"Actual dropout fraction: {dropout_fraction}")
    else:
        dropout_mask = None

    out_ref, attn_ref = attention_qkvpacked_ref(
        qkv, None, attn_bias, dropout_p, dropout_mask, causal=causal, window_size=window_size
    )
    out_pt, attn_pt = attention_qkvpacked_ref(
        qkv,
        None,
        attn_bias,
        dropout_p,
        dropout_mask,
        causal=causal,
        window_size=window_size,
        upcast=False,
        reorder_ops=True,
    )
    # v = qkv[:, :, 2].float()
    # qk = torch.einsum('bshd,bthd->bhst', qkv[:, :, 0], qkv[:, :, 1]).float()
    # if causal:
    #     causal_mask = torch.triu(torch.ones(seqlen, seqlen, dtype=torch.bool, device=qkv.device), 1)
    #     qk.masked_fill_(causal_mask, float('-inf'))
    # m = qk.amax(-1, keepdim=True)
    # s_tmp = torch.exp((qk - m) / math.sqrt(d))
    # p_tmp = torch.softmax(qk / math.sqrt(d), -1)
    # p_dropped = p_tmp if dropout_mask is None else p_tmp.masked_fill(~dropout_mask, 0)
    # lse_ref = torch.logsumexp(qk / math.sqrt(d), -1)
    # qk_max1 = torch.max(qk[:, :, 128:, 192:], -1, keepdim=True).values
    # qk_max2 = torch.max(qk[:, :, 128:, 128:], -1, keepdim=True).values
    # qk_max3 = torch.max(qk[:, :, 128:, 64:], -1, keepdim=True).values
    # qk_max4 = torch.max(qk[:, :, 128:, :], -1, keepdim=True).values
    # o1 = torch.einsum('bhst,bthd->bshd', torch.exp((qk[:, :, 128:, 192:] - qk_max1) / math.sqrt(d)), v[:, 192:])
    # o2 = torch.einsum('bhst,bthd->bshd', torch.exp((qk[:, :, 128:, 128:] - qk_max2) / math.sqrt(d)), v[:, 128:])
    # o3 = torch.einsum('bhst,bthd->bshd', torch.exp((qk[:, :, 128:, 64:] - qk_max3) / math.sqrt(d)), v[:, 64:])
    # o4 = torch.einsum('bhst,bthd->bshd', torch.exp((qk[:, :, 128:, :] - qk_max4) / math.sqrt(d)), v[:, :])
    print(f"Output max diff: {(out - out_ref).abs().max().item()}")
    print(f"Output mean diff: {(out - out_ref).abs().mean().item()}")
    print(f"Pytorch max diff: {(out_pt - out_ref).abs().max().item()}")
    print(f"Pytorch mean diff: {(out_pt - out_ref).abs().mean().item()}")
    if dropout_p > 0.0:
        print(f"Attention max diff: {(attn - attn_ref).abs().max().item()}")
        print(f"Attention Pytorch max diff: {(attn_pt - attn_ref).abs().max().item()}")

    g = torch.randn_like(out)
    # do_o = (g.float() * out.float()).sum(-1)
    # dv_tmp = torch.einsum('bhts,bthd->bshd', attn_pt[:, :, :64], g[:, :64])
    # dv_tmp1 = torch.einsum('bhts,bthd->bshd', attn_pt[:, :, 64:], g[:, 64:])
    if (d <= MAX_HEADDIM_SM8x or (d > 224 and dropout_p == 0)) or (is_sm80 or is_sm90):
        (dqkv,) = torch.autograd.grad(out, qkv, g)
        (dqkv_ref,) = torch.autograd.grad(out_ref, qkv, g)
        (dqkv_pt,) = torch.autograd.grad(out_pt, qkv, g)
        print(f"dQ max diff: {(dqkv[:, :, 0] - dqkv_ref[:, :, 0]).abs().max().item()}")
        print(f"dK max diff: {(dqkv[:, :, 1] - dqkv_ref[:, :, 1]).abs().max().item()}")
        print(f"dV max diff: {(dqkv[:, :, 2] - dqkv_ref[:, :, 2]).abs().max().item()}")
        print(f"dQKV mean diff: {(dqkv - dqkv_ref).abs().mean().item()}")
        print(f"dQ Pytorch max diff: {(dqkv_pt[:, :, 0] - dqkv_ref[:, :, 0]).abs().max().item()}")
        print(f"dK Pytorch max diff: {(dqkv_pt[:, :, 1] - dqkv_ref[:, :, 1]).abs().max().item()}")
        print(f"dV Pytorch max diff: {(dqkv_pt[:, :, 2] - dqkv_ref[:, :, 2]).abs().max().item()}")
        print(f"dQKV Pytorch mean diff: {(dqkv_pt - dqkv_ref).abs().mean().item()}")

    # Check that FlashAttention's numerical error is at most twice the numerical error
    # of a Pytorch implementation.
    assert (out - out_ref).abs().max().item() <= 2 * (out_pt - out_ref).abs().max().item()

    if dropout_p > 0.0:
        assert (attn - attn_ref).abs().max().item() <= 2 * (attn_pt - attn_ref).abs().max().item()
        # With alibi, many of the prob values are 0.0 & -0.0 so dropout_fraction isn't accurate
        if not alibi:
            assert abs(dropout_fraction - dropout_p) <= (0.01 if not local else 0.025)

    if (d <= MAX_HEADDIM_SM8x or (d > 224 and dropout_p == 0)) or (is_sm80 or is_sm90):
        assert (dqkv - dqkv_ref).abs().max().item() <= 2 * (dqkv_pt - dqkv_ref).abs().max().item()


@pytest.mark.parametrize("dtype", ([torch.float16] if is_sm75 else [torch.float16, torch.bfloat16]))
# @pytest.mark.parametrize('dtype', [torch.float16])
@pytest.mark.parametrize("deterministic", [False, True])
# @pytest.mark.parametrize("deterministic", [True])
@pytest.mark.parametrize("alibi", [False, True])
# @pytest.mark.parametrize("alibi", [True])
@pytest.mark.parametrize("local", [False, True])
# @pytest.mark.parametrize("local", [True])
@pytest.mark.parametrize("causal", [False, True])
# @pytest.mark.parametrize('causal', [False])
@pytest.mark.parametrize("d", [32, 59, 64, 80, 96, 128, 160, 192, 224, 256])
# @pytest.mark.parametrize("d", [32, 64, 96, 128, 160, 192, 224, 256])
# @pytest.mark.parametrize('d', [64])
@pytest.mark.parametrize("seqlen", [97, 128, 200, 257, 384, 512, 768, 1025, 2048])
# @pytest.mark.parametrize('seqlen', [128])
@pytest.mark.parametrize("dropout_p", [0.0, 0.17])
# @pytest.mark.parametrize('dropout_p', [0.0])
def test_flash_attn_varlen_qkvpacked(
    seqlen, d, dropout_p, causal, local, alibi, deterministic, dtype
):
    if seqlen >= 2048 and torch.cuda.get_device_properties("cuda").total_memory <= 16 * 2**30:
        pytest.skip()  # Reference implementation OOM
    device = "cuda"
    # set seed
    torch.random.manual_seed(0)
    batch_size = 5
    nheads = 6
    window_size = (-1, -1) if not local else torch.randint(0, seqlen, (2,))
    qkv = torch.randn(
        batch_size, seqlen, 3, nheads, d, device=device, dtype=dtype, requires_grad=True
    )

    key_padding_mask = generate_random_padding_mask(seqlen, batch_size, device, mode="random")
    # key_padding_mask = generate_random_padding_mask(seqlen, batch_size, device, mode='full')
    if alibi:
        alibi_slopes = torch.rand(batch_size, nheads, device=device, dtype=torch.float32) * 0.3
        attn_bias = attn_bias_from_alibi_slopes(
            alibi_slopes, seqlen, seqlen, key_padding_mask, key_padding_mask, causal=causal
        )
    else:
        alibi_slopes, attn_bias = None, None

    qkv_unpad, cu_seqlens, max_seqlen, qkv, output_pad_fn, dqkv_pad_fn = generate_qkv(
        *qkv.unbind(dim=2), key_padding_mask, key_padding_mask, qkvpacked=True
    )

    out_unpad, sm_lse, S_dmask = flash_attn_varlen_qkvpacked_func(
        qkv_unpad,
        cu_seqlens,
        max_seqlen,
        dropout_p,
        causal=causal,
        window_size=window_size,
        alibi_slopes=alibi_slopes,
        deterministic=deterministic,
        return_attn_probs=True,
    )
    out = output_pad_fn(out_unpad)
    if dropout_p > 0.0:
        S_dmask_converted = convert_flash_attn_S_to_softmax(
            S_dmask,
            seqlen,
            seqlen,
            key_padding_mask,
            key_padding_mask,
            d,
            dropout_p > 0.0,
            causal=causal,
            window_size=window_size,
        )
        dropout_mask = S_dmask_converted >= 0
        attn_unnorm = S_dmask_converted.abs()
        attn = normalize_flash_attn_S(
            attn_unnorm,
            qkv[:, :, 0],
            qkv[:, :, 1],
            qkv[:, :, 2],
            key_padding_mask,
            key_padding_mask,
            attn_bias,
            dropout_p > 0.0,
            causal=causal,
            window_size=window_size,
        )
        dropout_fraction = get_dropout_fraction(
            dropout_mask, key_padding_mask, key_padding_mask, causal=causal, window_size=window_size
        ).item()
        print(f"Actual dropout fraction: {dropout_fraction}")
    else:
        dropout_mask = None

    out_ref, attn_ref = attention_qkvpacked_ref(
        qkv,
        key_padding_mask,
        attn_bias,
        dropout_p,
        dropout_mask,
        causal=causal,
        window_size=window_size,
    )
    out_pt, attn_pt = attention_qkvpacked_ref(
        qkv,
        key_padding_mask,
        attn_bias,
        dropout_p,
        dropout_mask,
        causal=causal,
        window_size=window_size,
        upcast=False,
        reorder_ops=True,
    )
    print(f"Output max diff: {(out - out_ref).abs().max().item()}")
    print(f"Output mean diff: {(out - out_ref).abs().mean().item()}")
    print(f"Pytorch max diff: {(out_pt - out_ref).abs().max().item()}")
    print(f"Pytorch mean diff: {(out_pt - out_ref).abs().mean().item()}")
    if dropout_p > 0.0:
        print(f"Attention max diff: {(attn - attn_ref).abs().max().item()}")
        print(f"Attention Pytorch max diff: {(attn_pt - attn_ref).abs().max().item()}")

    g = torch.randn_like(out)
    if (d <= MAX_HEADDIM_SM8x or (d > 224 and dropout_p == 0)) or (is_sm80 or is_sm90):
        (dqkv_unpad,) = torch.autograd.grad(out, qkv_unpad, g)
        dqkv = dqkv_pad_fn(dqkv_unpad)
        (dqkv_ref,) = torch.autograd.grad(out_ref, qkv, g)
        (dqkv_pt,) = torch.autograd.grad(out_pt, qkv, g)
        print(f"dQ max diff: {(dqkv[:, :, 0] - dqkv_ref[:, :, 0]).abs().max().item()}")
        print(f"dK max diff: {(dqkv[:, :, 1] - dqkv_ref[:, :, 1]).abs().max().item()}")
        print(f"dV max diff: {(dqkv[:, :, 2] - dqkv_ref[:, :, 2]).abs().max().item()}")
        print(f"dQKV mean diff: {(dqkv - dqkv_ref).abs().mean().item()}")
        print(f"dQ Pytorch max diff: {(dqkv_pt[:, :, 0] - dqkv_ref[:, :, 0]).abs().max().item()}")
        print(f"dK Pytorch max diff: {(dqkv_pt[:, :, 1] - dqkv_ref[:, :, 1]).abs().max().item()}")
        print(f"dV Pytorch max diff: {(dqkv_pt[:, :, 2] - dqkv_ref[:, :, 2]).abs().max().item()}")
        print(f"dQKV Pytorch mean diff: {(dqkv_pt - dqkv_ref).abs().mean().item()}")

    # Check that FlashAttention's numerical error is at most twice the numerical error
    # of a Pytorch implementation.
    assert (out - out_ref).abs().max().item() <= 2 * (out_pt - out_ref).abs().max().item()

    if dropout_p > 0.0:
        assert (attn - attn_ref).abs().max().item() <= 2 * (attn_pt - attn_ref).abs().max().item()
        # With alibi, many of the prob values are 0.0 & -0.0 so dropout_fraction isn't accurate
        if not alibi:
            assert abs(dropout_fraction - dropout_p) <= (0.01 if not local else 0.025)

    if (d <= MAX_HEADDIM_SM8x or (d > 224 and dropout_p == 0)) or (is_sm80 or is_sm90):
        assert (dqkv - dqkv_ref).abs().max().item() <= 2 * (dqkv_pt - dqkv_ref).abs().max().item()


@pytest.mark.parametrize("kvpacked", [True, False])
# @pytest.mark.parametrize("kvpacked", [False])
@pytest.mark.parametrize("dtype", ([torch.float16] if is_sm75 else [torch.float16, torch.bfloat16]))
# @pytest.mark.parametrize("dtype", [torch.bfloat16])
@pytest.mark.parametrize("mha_type", ["mha", "mqa", "gqa"])
# @pytest.mark.parametrize("mha_type", ["mha"])
@pytest.mark.parametrize("deterministic", [False, True])
# @pytest.mark.parametrize("deterministic", [True])
@pytest.mark.parametrize("alibi", [False, True])
# @pytest.mark.parametrize("alibi", [True])
@pytest.mark.parametrize("local", [False, True])
# @pytest.mark.parametrize("local", [True])
@pytest.mark.parametrize("causal", [False, True])
# @pytest.mark.parametrize("causal", [True])
@pytest.mark.parametrize("d", [32, 40, 59, 64, 96, 111, 128, 160, 192, 224, 256])
# @pytest.mark.parametrize("d", [32, 64, 96, 128, 160, 192, 224, 256])
# @pytest.mark.parametrize('d', [32, 40, 64, 80, 96, 128, 160, 192])
# @pytest.mark.parametrize('d', [32, 64, 96, 128, 160, 192])
# @pytest.mark.parametrize('d', [56, 80])
# @pytest.mark.parametrize("d", [64])
@pytest.mark.parametrize(
    "seqlen_q,seqlen_k",
    [
        (113, 203),
        (128, 217),
        (113, 211),
        (108, 256),
        (256, 512),
        (512, 256),
        (1024, 1024),
        (1023, 1024),
        (1024, 1023),
        (2048, 2048),
    ],
)
# @pytest.mark.parametrize('seqlen_q,seqlen_k', [(256, 128)])
@pytest.mark.parametrize("dropout_p", [0.0, 0.17])
# @pytest.mark.parametrize("dropout_p", [0.17])
@pytest.mark.parametrize("softcap", [0.0, 50.0])
def test_flash_attn_output(
    seqlen_q, seqlen_k, d, dropout_p, causal, local, alibi, deterministic, mha_type, dtype, kvpacked, softcap
):
    if (
        max(seqlen_q, seqlen_k) >= 2048
        and torch.cuda.get_device_properties("cuda").total_memory <= 16 * 2**30
    ):
        pytest.skip()  # Reference implementation OOM
    device = "cuda"
    # set seed
    torch.random.manual_seed(0)
    batch_size = 4
    nheads = 9
    nheads_k = nheads if mha_type == "mha" else (1 if mha_type == "mqa" else 3)
    assert nheads % nheads_k == 0
    window_size = (-1, -1) if not local else torch.randint(0, seqlen_k, (2,))
    q = torch.randn(batch_size, seqlen_q, nheads, d, device=device, dtype=dtype, requires_grad=True)
    if softcap > 0:
        # Ensure the values of qk are at least within softcap range.
        q = q * softcap
    if kvpacked:
        kv = torch.randn(
            batch_size, seqlen_k, 2, nheads_k, d, device=device, dtype=dtype, requires_grad=True
        )
    else:
        k = torch.randn(
            batch_size, seqlen_k, nheads_k, d, device=device, dtype=dtype, requires_grad=True
        )
        v = torch.randn(
            batch_size, seqlen_k, nheads_k, d, device=device, dtype=dtype, requires_grad=True
        )
    if alibi:
        alibi_slopes = torch.rand(batch_size, nheads, device=device, dtype=torch.float32) * 0.3
        attn_bias = attn_bias_from_alibi_slopes(alibi_slopes, seqlen_q, seqlen_k, causal=causal)
    else:
        alibi_slopes, attn_bias = None, None

    if kvpacked:
        out, lse, S_dmask = flash_attn_kvpacked_func(
            q,
            kv,
            dropout_p,
            causal=causal,
            window_size=window_size,
            softcap=softcap,
            alibi_slopes=alibi_slopes,
            deterministic=deterministic,
            return_attn_probs=True,
        )
    else:
        out, lse, S_dmask = flash_attn_func(
            q,
            k,
            v,
            dropout_p,
            causal=causal,
            window_size=window_size,
            softcap=softcap,
            alibi_slopes=alibi_slopes,
            deterministic=deterministic,
            return_attn_probs=True,
        )
    if dropout_p > 0.0:
        S_dmask_converted = convert_flash_attn_S_to_softmax(
            S_dmask,
            seqlen_q,
            seqlen_k,
            None,
            None,
            d,
            dropout_p > 0.0,
            causal=causal,
            window_size=window_size,
        )
        dropout_mask = S_dmask_converted >= 0
        attn_unnorm = S_dmask_converted.abs()
        if kvpacked:
            kv_rep = repeat(kv, "b s two h d -> b s two (h g) d", g=nheads // nheads_k)
            k_rep, v_rep = kv_rep.unbind(dim=2)
        else:
            k_rep = repeat(k, "b s h d -> b s (h g) d", g=nheads // nheads_k)
            v_rep = repeat(v, "b s h d -> b s (h g) d", g=nheads // nheads_k)
        attn = normalize_flash_attn_S(
            attn_unnorm,
            q,
            k_rep,
            v_rep,
            None,
            None,
            attn_bias,
            dropout_p > 0.0,
            causal=causal,
            window_size=window_size,
        )
        dropout_fraction = get_dropout_fraction(
            dropout_mask, None, None, causal=causal, window_size=window_size
        ).item()
        print(f"Actual dropout fraction: {dropout_fraction}")
    else:
        dropout_mask = None

    if kvpacked:
        out_ref, attn_ref = attention_kvpacked_ref(
            q,
            kv,
            None,
            None,
            attn_bias,
            dropout_p,
            dropout_mask,
            causal=causal,
            window_size=window_size,
            softcap=softcap,
        )
        out_pt, attn_pt = attention_kvpacked_ref(
            q,
            kv,
            None,
            None,
            attn_bias,
            dropout_p,
            dropout_mask,
            causal=causal,
            window_size=window_size,
            softcap=softcap,
            upcast=False,
            reorder_ops=True,
        )
    else:
        out_ref, attn_ref = attention_ref(
            q,
            k,
            v,
            None,
            None,
            attn_bias,
            dropout_p,
            dropout_mask,
            causal=causal,
            window_size=window_size,
            softcap=softcap,
        )
        out_pt, attn_pt = attention_ref(
            q,
            k,
            v,
            None,
            None,
            attn_bias,
            dropout_p,
            dropout_mask,
            causal=causal,
            window_size=window_size,
            softcap=softcap,
            upcast=False,
            reorder_ops=True,
        )

    print(f"Output max diff: {(out - out_ref).abs().max().item()}")
    print(f"Output mean diff: {(out - out_ref).abs().mean().item()}")
    print(f"Pytorch max diff: {(out_pt - out_ref).abs().max().item()}")
    print(f"Pytorch mean diff: {(out_pt - out_ref).abs().mean().item()}")
    if dropout_p > 0.0:
        print(f"Attention max diff: {(attn - attn_ref).abs().max().item()}")
        print(f"Attention Pytorch max diff: {(attn_pt - attn_ref).abs().max().item()}")

    g = torch.randn_like(out)
    do_o = (g.float() * out.float()).sum(-1)
    if (
        (d <= MAX_HEADDIM_SM8x or (d > 224 and dropout_p == 0)) or (is_sm80 or is_sm90)
    ) and softcap == 0.0:
        if kvpacked:
            (
                dq,
                dkv,
            ) = torch.autograd.grad(out, (q, kv), g)
            dk, dv = dkv.unbind(2)
            (
                dq_ref,
                dkv_ref,
            ) = torch.autograd.grad(out_ref, (q, kv), g)
            dk_ref, dv_ref = dkv_ref.unbind(2)
            (
                dq_pt,
                dkv_pt,
            ) = torch.autograd.grad(out_pt, (q, kv), g)
            dk_pt, dv_pt = dkv_pt.unbind(2)
        else:
            (
                dq,
                dk,
                dv,
            ) = torch.autograd.grad(out, (q, k, v), g)
            (
                dq_ref,
                dk_ref,
                dv_ref,
            ) = torch.autograd.grad(out_ref, (q, k, v), g)
            (
                dq_pt,
                dk_pt,
                dv_pt,
            ) = torch.autograd.grad(out_pt, (q, k, v), g)
        print(f"dQ max diff: {(dq - dq_ref).abs().max().item()}")
        print(f"dK max diff: {(dk - dk_ref).abs().max().item()}")
        print(f"dV max diff: {(dv - dv_ref).abs().max().item()}")
        print(f"dQ mean diff: {(dq - dq_ref).abs().mean().item()}")
        print(f"dK mean diff: {(dk - dk_ref).abs().mean().item()}")
        print(f"dV mean diff: {(dv - dv_ref).abs().mean().item()}")
        print(f"dQ Pytorch max diff: {(dq_pt - dq_ref).abs().max().item()}")
        print(f"dK Pytorch max diff: {(dk_pt - dk_ref).abs().max().item()}")
        print(f"dV Pytorch max diff: {(dv_pt - dv_ref).abs().max().item()}")
        print(f"dQ Pytorch mean diff: {(dq_pt - dq_ref).abs().mean().item()}")
        print(f"dK Pytorch mean diff: {(dk_pt - dk_ref).abs().mean().item()}")
        print(f"dV Pytorch mean diff: {(dv_pt - dv_ref).abs().mean().item()}")

    # Check that FlashAttention's numerical error is at most twice the numerical error
    # of a Pytorch implementation.
    assert (out - out_ref).abs().max().item() <= 2 * (out_pt - out_ref).abs().max().item()

    if dropout_p > 0.0:
        assert (attn - attn_ref).abs().max().item() <= 2 * (attn_pt - attn_ref).abs().max().item()
        # With alibi, many of the prob values are 0.0 & -0.0 so dropout_fraction isn't accurate
        if not alibi:
            assert abs(dropout_fraction - dropout_p) <= (0.01 if not local else 0.025)

    if (
        (d <= MAX_HEADDIM_SM8x or (d > 224 and dropout_p == 0)) or (is_sm80 or is_sm90)
    ) and softcap == 0.0:
        assert (dq - dq_ref).abs().max().item() <= 2 * (dq_pt - dq_ref).abs().max().item()
        assert (dk - dk_ref).abs().max().item() <= 2 * (dk_pt - dk_ref).abs().max().item()
        assert (dv - dv_ref).abs().max().item() <= 2 * (dv_pt - dv_ref).abs().max().item()


@pytest.mark.parametrize("kvpacked", [True, False])
# @pytest.mark.parametrize('kvpacked', [False])
@pytest.mark.parametrize("dtype", ([torch.float16] if is_sm75 else [torch.float16, torch.bfloat16]))
# @pytest.mark.parametrize('dtype', [torch.float16])
@pytest.mark.parametrize("mha_type", ["mha", "mqa", "gqa"])
# @pytest.mark.parametrize('mha_type', ["mqa"])
@pytest.mark.parametrize("deterministic", [False, True])
# @pytest.mark.parametrize("deterministic", [True])
@pytest.mark.parametrize("alibi", [False, True])
# @pytest.mark.parametrize("alibi", [True])
@pytest.mark.parametrize("local", [False, True])
# @pytest.mark.parametrize("local", [True])
@pytest.mark.parametrize("causal", [False, True])
# @pytest.mark.parametrize('causal', [True])
@pytest.mark.parametrize("d", [32, 59, 64, 80, 96, 111, 128, 160, 192, 224, 256])
# @pytest.mark.parametrize("d", [32, 64, 96, 128, 160, 192, 224, 256])
# @pytest.mark.parametrize('d', [64])
@pytest.mark.parametrize(
    "seqlen_q,seqlen_k",
    [
        (1, 147),
        (113, 203),
        (128, 217),
        (113, 211),
        (108, 256),
        (256, 512),
        (512, 256),
        (1024, 1024),
        (1023, 1024),
        (1024, 1023),
        (2048, 2048),
    ],
)
# @pytest.mark.parametrize('seqlen_q,seqlen_k', [(128, 128)])
@pytest.mark.parametrize("dropout_p", [0.0, 0.17])
@pytest.mark.parametrize("softcap", [0.0, 50.0])
# @pytest.mark.parametrize('dropout_p', [0.0])
def test_flash_attn_varlen_output(
    seqlen_q, seqlen_k, d, dropout_p, causal, local, alibi, deterministic, mha_type, dtype,
    kvpacked, softcap
):
    if (
        max(seqlen_q, seqlen_k) >= 2048
        and torch.cuda.get_device_properties("cuda").total_memory <= 16 * 2**30
    ):
        pytest.skip()  # Reference implementation OOM
    device = "cuda"
    # set seed
    torch.random.manual_seed(0)
    batch_size = 4
    nheads = 9
    nheads_k = nheads if mha_type == "mha" else (1 if mha_type == "mqa" else 3)
    assert nheads % nheads_k == 0
    window_size = (-1, -1) if not local else torch.randint(0, seqlen_k, (2,))
    q = torch.randn(batch_size, seqlen_q, nheads, d, device=device, dtype=dtype, requires_grad=True)
    if softcap > 0:
        # Ensure the values of qk are at least within softcap range.
        q = q * softcap

    if kvpacked:
        kv = torch.randn(
            batch_size, seqlen_k, 2, nheads_k, d, device=device, dtype=dtype, requires_grad=True
        )
    else:
        k = torch.randn(
            batch_size, seqlen_k, nheads_k, d, device=device, dtype=dtype, requires_grad=True
        )
        v = torch.randn(
            batch_size, seqlen_k, nheads_k, d, device=device, dtype=dtype, requires_grad=True
        )

    query_padding_mask = generate_random_padding_mask(seqlen_q, batch_size, device, mode="random")
    key_padding_mask = generate_random_padding_mask(seqlen_k, batch_size, device, mode="random")
    # key_padding_mask = generate_random_padding_mask(seqlen_k, batch_size, device, mode='full')
    if alibi:
        alibi_slopes = torch.rand(batch_size, nheads, device=device, dtype=torch.float32) * 0.3
        attn_bias = attn_bias_from_alibi_slopes(
            alibi_slopes, seqlen_q, seqlen_k, query_padding_mask, key_padding_mask, causal=causal
        )
    else:
        alibi_slopes, attn_bias = None, None

    if kvpacked:
        (
            q_unpad,
            kv_unpad,
            cu_seqlens_q,
            cu_seqlens_k,
            max_seqlen_q,
            max_seqlen_k,
            q,
            kv,
            output_pad_fn,
            dq_pad_fn,
            dkv_pad_fn,
        ) = generate_qkv(q, *kv.unbind(dim=2), query_padding_mask, key_padding_mask, kvpacked=True)
        out_unpad, sm_lse, S_dmask = flash_attn_varlen_kvpacked_func(
            q_unpad,
            kv_unpad,
            cu_seqlens_q,
            cu_seqlens_k,
            max_seqlen_q,
            max_seqlen_k,
            dropout_p,
            causal=causal,
            window_size=window_size,
            softcap=softcap,
            alibi_slopes=alibi_slopes,
            deterministic=deterministic,
            return_attn_probs=True,
        )
    else:
        (
            q_unpad,
            k_unpad,
            v_unpad,
            cu_seqlens_q,
            cu_seqlens_k,
            max_seqlen_q,
            max_seqlen_k,
            q,
            k,
            v,
            output_pad_fn,
            dq_pad_fn,
            dk_pad_fn,
        ) = generate_qkv(q, k, v, query_padding_mask, key_padding_mask, kvpacked=False)
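        # generate_qkv strips padding according to the masks and returns the unpadded tensors,
        # the cu_seqlens / max_seqlen metadata, and pad-back functions used below to compare
        # outputs and gradients in the padded layout.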
        out_unpad, sm_lse, S_dmask = flash_attn_varlen_func(
            q_unpad,
            k_unpad,
            v_unpad,
            cu_seqlens_q,
            cu_seqlens_k,
            max_seqlen_q,
            max_seqlen_k,
            dropout_p,
            causal=causal,
            window_size=window_size,
            softcap=softcap,
            alibi_slopes=alibi_slopes,
            deterministic=deterministic,
            return_attn_probs=True,
        )
    out = output_pad_fn(out_unpad)
    if dropout_p > 0.0:
        S_dmask_converted = convert_flash_attn_S_to_softmax(
            S_dmask,
            seqlen_q,
            seqlen_k,
            query_padding_mask,
            key_padding_mask,
            d,
            dropout_p > 0.0,
            causal=causal,
            window_size=window_size,
        )
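        # The converted S_dmask encodes the dropout keep/drop decision in its sign and the
        # unnormalized attention score in its magnitude, so both are recovered below.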
        dropout_mask = S_dmask_converted >= 0
        attn_unnorm = S_dmask_converted.abs()
        if kvpacked:
            kv_rep = repeat(kv, "b s two h d -> b s two (h g) d", g=nheads // nheads_k)
            k_rep, v_rep = kv_rep.unbind(dim=2)
        else:
            k_rep = repeat(k, "b s h d -> b s (h g) d", g=nheads // nheads_k)
            v_rep = repeat(v, "b s h d -> b s (h g) d", g=nheads // nheads_k)
        attn = normalize_flash_attn_S(
            attn_unnorm,
            q,
            k_rep,
            v_rep,
            query_padding_mask,
            key_padding_mask,
            attn_bias,
            dropout_p > 0.0,
            causal=causal,
            window_size=window_size,
        )
        dropout_fraction = get_dropout_fraction(
            dropout_mask,
            query_padding_mask,
            key_padding_mask,
            causal=causal,
            window_size=window_size,
        ).item()
        print(f"Actual dropout fraction: {dropout_fraction}")
    else:
        dropout_mask = None

    if kvpacked:
        out_ref, attn_ref = attention_kvpacked_ref(
            q,
            kv,
            query_padding_mask,
            key_padding_mask,
            attn_bias,
            dropout_p,
            dropout_mask,
            causal=causal,
            window_size=window_size,
            softcap=softcap,
        )
        out_pt, attn_pt = attention_kvpacked_ref(
            q,
            kv,
            query_padding_mask,
            key_padding_mask,
            attn_bias,
            dropout_p,
            dropout_mask,
            causal=causal,
            window_size=window_size,
            softcap=softcap,
            upcast=False,
            reorder_ops=True,
        )
    else:
        out_ref, attn_ref = attention_ref(
            q,
            k,
            v,
            query_padding_mask,
            key_padding_mask,
            attn_bias,
            dropout_p,
            dropout_mask,
            causal=causal,
            window_size=window_size,
            softcap=softcap,
        )
        out_pt, attn_pt = attention_ref(
            q,
            k,
            v,
            query_padding_mask,
            key_padding_mask,
            attn_bias,
            dropout_p,
            dropout_mask,
            causal=causal,
            window_size=window_size,
            softcap=softcap,
            upcast=False,
            reorder_ops=True,
        )

    print(f"Output max diff: {(out - out_ref).abs().max().item()}")
    print(f"Output mean diff: {(out - out_ref).abs().mean().item()}")
    print(f"Pytorch max diff: {(out_pt - out_ref).abs().max().item()}")
    print(f"Pytorch mean diff: {(out_pt - out_ref).abs().mean().item()}")
    if dropout_p > 0.0:
        print(f"Attention max diff: {(attn - attn_ref).abs().max().item()}")
        print(f"Attention Pytorch max diff: {(attn_pt - attn_ref).abs().max().item()}")

    g = torch.randn_like(out)
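    # The backward pass is only checked for head-dim / architecture combinations that support
    # it, and only when softcap == 0.0.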
    if (
        (d <= MAX_HEADDIM_SM8x or (d > 224 and dropout_p == 0)) or (is_sm80 or is_sm90)
    ) and softcap == 0.0:
        if kvpacked:
            (
                dq_unpad,
                dkv_unpad,
            ) = torch.autograd.grad(out, (q_unpad, kv_unpad), g)
            dk, dv = dkv_pad_fn(dkv_unpad).unbind(2)
            (
                dq_ref,
                dkv_ref,
            ) = torch.autograd.grad(out_ref, (q, kv), g)
            dk_ref, dv_ref = dkv_ref.unbind(2)
            (
                dq_pt,
                dkv_pt,
            ) = torch.autograd.grad(out_pt, (q, kv), g)
            dk_pt, dv_pt = dkv_pt.unbind(2)
        else:
            (
                dq_unpad,
                dk_unpad,
                dv_unpad,
            ) = torch.autograd.grad(out, (q_unpad, k_unpad, v_unpad), g)
            dk = dk_pad_fn(dk_unpad)
            dv = dk_pad_fn(dv_unpad)
            (
                dq_ref,
                dk_ref,
                dv_ref,
            ) = torch.autograd.grad(out_ref, (q, k, v), g)
            (
                dq_pt,
                dk_pt,
                dv_pt,
            ) = torch.autograd.grad(out_pt, (q, k, v), g)
        dq = dq_pad_fn(dq_unpad)
        print(f"dQ max diff: {(dq - dq_ref).abs().max().item()}")
        print(f"dK max diff: {(dk - dk_ref).abs().max().item()}")
        print(f"dV max diff: {(dv - dv_ref).abs().max().item()}")
        print(f"dQ mean diff: {(dq - dq_ref).abs().mean().item()}")
        print(f"dK mean diff: {(dk - dk_ref).abs().mean().item()}")
        print(f"dV mean diff: {(dv - dv_ref).abs().mean().item()}")
        print(f"dQ Pytorch max diff: {(dq_pt - dq_ref).abs().max().item()}")
        print(f"dK Pytorch max diff: {(dk_pt - dk_ref).abs().max().item()}")
        print(f"dV Pytorch max diff: {(dv_pt - dv_ref).abs().max().item()}")
        print(f"dQ Pytorch mean diff: {(dq_pt - dq_ref).abs().mean().item()}")
        print(f"dK Pytorch mean diff: {(dk_pt - dk_ref).abs().mean().item()}")
        print(f"dV Pytorch mean diff: {(dv_pt - dv_ref).abs().mean().item()}")

    # Check that FlashAttention's numerical error is at most twice the numerical error
    # of a Pytorch implementation.
    assert (out - out_ref).abs().max().item() <= 2 * (out_pt - out_ref).abs().max().item()

    if dropout_p > 0.0:
        assert (attn - attn_ref).abs().max().item() <= 2 * (attn_pt - attn_ref).abs().max().item()
        # With alibi, many of the prob values are 0.0 & -0.0 so dropout_fraction isn't accurate
        if not alibi:
            assert abs(dropout_fraction - dropout_p) <= (0.01 if not local else 0.025)

    if (
        (d <= MAX_HEADDIM_SM8x or (d > 224 and dropout_p == 0)) or (is_sm80 or is_sm90)
    ) and softcap == 0.0:
        assert (dq - dq_ref).abs().max().item() <= 3 * (dq_pt - dq_ref).abs().max().item()
        assert (dk - dk_ref).abs().max().item() <= 3 * (dk_pt - dk_ref).abs().max().item()
        assert (dv - dv_ref).abs().max().item() <= 3 * (dv_pt - dv_ref).abs().max().item()


@pytest.mark.parametrize("dtype", ([torch.float16] if is_sm75 else [torch.float16, torch.bfloat16]))
# @pytest.mark.parametrize("dtype", [torch.bfloat16])
@pytest.mark.parametrize("local", [False, True])
# @pytest.mark.parametrize("local", [True])
@pytest.mark.parametrize("d", [32, 40, 59, 64, 80, 96, 111, 128, 160, 192, 224, 256])
# @pytest.mark.parametrize("d", [32, 64, 96, 128, 160, 192, 224, 256])
# @pytest.mark.parametrize('d', [32, 40, 64, 80, 96, 128, 160, 192])
# @pytest.mark.parametrize('d', [32, 64, 96, 128, 160, 192])
# @pytest.mark.parametrize('d', [56, 80])
# @pytest.mark.parametrize("d", [64, 128])
@pytest.mark.parametrize("swap_sq_sk", [False, True])
# @pytest.mark.parametrize("swap_sq_sk", [True])
@pytest.mark.parametrize(
    "seqlen_q,seqlen_k",
    [
        (1, 239),
        (3, 799),
        (127, 512),
        (127, 513),
        (113, 203),
        (128, 217),
        (113, 211),
        (108, 256),
        (256, 512),
        (1023, 1024),
    ],
)
# @pytest.mark.parametrize('seqlen_q,seqlen_k', [(256, 128)])
def test_flash_attn_causal(seqlen_q, seqlen_k, swap_sq_sk, d, local, dtype):
    if (
        max(seqlen_q, seqlen_k) >= 2048
        and torch.cuda.get_device_properties("cuda").total_memory <= 16 * 2**30
    ):
        pytest.skip()  # Reference implementation OOM
    if swap_sq_sk:
        seqlen_q, seqlen_k = seqlen_k, seqlen_q
    device = "cuda"
    causal = True
    # set seed
    torch.random.manual_seed(0)
    batch_size = 8
    nheads = 9
    window_size = (-1, -1) if not local else torch.randint(0, seqlen_k, (2,))
    q = torch.randn(batch_size, seqlen_q, nheads, d, device=device, dtype=dtype, requires_grad=True)
    k = torch.randn(batch_size, seqlen_k, nheads, d, device=device, dtype=dtype, requires_grad=True)
    v = torch.randn(batch_size, seqlen_k, nheads, d, device=device, dtype=dtype, requires_grad=True)
    out = flash_attn_func(q, k, v, 0.0, causal=causal, window_size=window_size)
    out_ref, attn_ref = attention_ref(
        q, k, v, None, None, None, 0.0, None, causal=causal, window_size=window_size
    )
    out_pt, attn_pt = attention_ref(
        q,
        k,
        v,
        None,
        None,
        None,
        0.0,
        None,
        causal=causal,
        window_size=window_size,
        upcast=False,
        reorder_ops=True,
    )

    print(f"Output max diff: {(out - out_ref).abs().max().item()}")
    print(f"Output mean diff: {(out - out_ref).abs().mean().item()}")
    print(f"Pytorch max diff: {(out_pt - out_ref).abs().max().item()}")
    print(f"Pytorch mean diff: {(out_pt - out_ref).abs().mean().item()}")

    g = torch.randn_like(out)
    do_o = (g.float() * out.float()).sum(-1)
    if (d <= MAX_HEADDIM_SM8x or d > 224) or (is_sm80 or is_sm90):
        (
            dq,
            dk,
            dv,
        ) = torch.autograd.grad(out, (q, k, v), g)
        (
            dq_ref,
            dk_ref,
            dv_ref,
        ) = torch.autograd.grad(out_ref, (q, k, v), g)
        (
            dq_pt,
            dk_pt,
            dv_pt,
        ) = torch.autograd.grad(out_pt, (q, k, v), g)
        print(f"dQ max diff: {(dq - dq_ref).abs().max().item()}")
        print(f"dK max diff: {(dk - dk_ref).abs().max().item()}")
        print(f"dV max diff: {(dv - dv_ref).abs().max().item()}")
        print(f"dQ mean diff: {(dq - dq_ref).abs().mean().item()}")
        print(f"dK mean diff: {(dk - dk_ref).abs().mean().item()}")
        print(f"dV mean diff: {(dv - dv_ref).abs().mean().item()}")
        print(f"dQ Pytorch max diff: {(dq_pt - dq_ref).abs().max().item()}")
        print(f"dK Pytorch max diff: {(dk_pt - dk_ref).abs().max().item()}")
        print(f"dV Pytorch max diff: {(dv_pt - dv_ref).abs().max().item()}")
        print(f"dQ Pytorch mean diff: {(dq_pt - dq_ref).abs().mean().item()}")
        print(f"dK Pytorch mean diff: {(dk_pt - dk_ref).abs().mean().item()}")
        print(f"dV Pytorch mean diff: {(dv_pt - dv_ref).abs().mean().item()}")

    # Check that FlashAttention's numerical error is at most twice the numerical error
    # of a Pytorch implementation.
    assert (out - out_ref).abs().max().item() <= 2 * (out_pt - out_ref).abs().max().item() + 1e-5

    if (d <= MAX_HEADDIM_SM8x or d > 224) or (is_sm80 or is_sm90):
        assert (dq - dq_ref).abs().max().item() <= 2 * (dq_pt - dq_ref).abs().max().item() + 1e-5
        assert (dk - dk_ref).abs().max().item() <= 2 * (dk_pt - dk_ref).abs().max().item() + 1e-5
        assert (dv - dv_ref).abs().max().item() <= 2 * (dv_pt - dv_ref).abs().max().item() + 1e-5


@pytest.mark.parametrize("dtype", ([torch.float16] if is_sm75 else [torch.float16, torch.bfloat16]))
# @pytest.mark.parametrize("dtype", [torch.bfloat16])
@pytest.mark.parametrize("local", [False, True])
# @pytest.mark.parametrize("local", [True])
@pytest.mark.parametrize("d", [32, 40, 59, 64, 80, 96, 111, 128, 160, 192, 224, 256])
# @pytest.mark.parametrize("d", [32, 64, 96, 128, 160, 192, 224, 256])
# @pytest.mark.parametrize('d', [32, 40, 64, 80, 96, 128, 160, 192])
# @pytest.mark.parametrize('d', [32, 64, 96, 128, 160, 192])
# @pytest.mark.parametrize('d', [56, 80])
# @pytest.mark.parametrize("d", [64])
@pytest.mark.parametrize("swap_sq_sk", [False, True])
# @pytest.mark.parametrize("swap_sq_sk", [True])
@pytest.mark.parametrize(
    "seqlen_q,seqlen_k",
    [
        (1, 239),
        (3, 799),
        (127, 512),
        (127, 513),
        (113, 203),
        (128, 217),
        (113, 211),
        (108, 256),
        (256, 512),
        (1023, 1024),
    ],
)
# TODO: add smaller page sizes when https://github.com/Dao-AILab/flash-attention/pull/824 is merged
@pytest.mark.parametrize("paged_kv_block_size", [None, 256, 512])
# @pytest.mark.parametrize("seqlen_q,seqlen_k", [(256, 128)])
def test_flash_attn_varlen_causal(
    seqlen_q, seqlen_k, swap_sq_sk, d, local, paged_kv_block_size, dtype
):
    if (
        max(seqlen_q, seqlen_k) >= 2048
        and torch.cuda.get_device_properties("cuda").total_memory <= 16 * 2**30
    ):
        pytest.skip()  # Reference implementation OOM
    if swap_sq_sk:
        seqlen_q, seqlen_k = seqlen_k, seqlen_q
    device = "cuda"
    causal = True
    # set seed
    torch.random.manual_seed(0)
    batch_size = 8
    nheads = 9
    window_size = (-1, -1) if not local else torch.randint(0, seqlen_k, (2,))
    q = torch.randn(batch_size, seqlen_q, nheads, d, device=device, dtype=dtype, requires_grad=True)

    if paged_kv_block_size is None:
        k = torch.randn(
            batch_size, seqlen_k, nheads, d, device=device, dtype=dtype, requires_grad=True
        )
        v = torch.randn(
            batch_size, seqlen_k, nheads, d, device=device, dtype=dtype, requires_grad=True
        )
        block_table = None
    else:
        k, v, block_table, k_cache_paged, v_cache_paged, num_blocks = _generate_block_kvcache(
            seqlen_k, paged_kv_block_size, batch_size, nheads, d, device, dtype
        )
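    # With paged KV, the kernel reads k_cache_paged / v_cache_paged through block_table, while
    # the contiguous k / v gathered above are only used by the reference implementation.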
    query_padding_mask = generate_random_padding_mask(seqlen_q, batch_size, device, mode="random")
    key_padding_mask = generate_random_padding_mask(seqlen_k, batch_size, device, mode="random")
    (
        q_unpad,
        k_unpad,
        v_unpad,
        cu_seqlens_q,
        cu_seqlens_k,
        max_seqlen_q,
        max_seqlen_k,
        q,
        k,
        v,
        output_pad_fn,
        dq_pad_fn,
        dk_pad_fn,
    ) = generate_qkv(q, k, v, query_padding_mask, key_padding_mask, kvpacked=False)
    out_unpad = flash_attn_varlen_func(
        q_unpad,
        k_unpad if paged_kv_block_size is None else k_cache_paged,
        v_unpad if paged_kv_block_size is None else v_cache_paged,
        cu_seqlens_q,
        cu_seqlens_k,
        max_seqlen_q,
        max_seqlen_k,
        0.0,
        causal=causal,
        window_size=window_size,
        block_table=block_table,
    )
    out = output_pad_fn(out_unpad)
    out_ref, attn_ref = attention_ref(
        q,
        k,
        v,
        query_padding_mask,
        key_padding_mask,
        None,
        0.0,
        None,
        causal=causal,
        window_size=window_size,
    )
    out_pt, attn_pt = attention_ref(
        q,
        k,
        v,
        query_padding_mask,
        key_padding_mask,
        None,
        0.0,
        None,
        causal=causal,
        window_size=window_size,
        upcast=False,
        reorder_ops=True,
    )

    print(f"Output max diff: {(out - out_ref).abs().max().item()}")
    print(f"Output mean diff: {(out - out_ref).abs().mean().item()}")
    print(f"Pytorch max diff: {(out_pt - out_ref).abs().max().item()}")
    print(f"Pytorch mean diff: {(out_pt - out_ref).abs().mean().item()}")

    g = torch.randn_like(out)
    do_o = (g.float() * out.float()).sum(-1)
    test_backward = (
        d <= MAX_HEADDIM_SM8x or d > 224 or is_sm80 or is_sm90
    ) and block_table is None
    if test_backward:
        (
            dq_unpad,
            dk_unpad,
            dv_unpad,
        ) = torch.autograd.grad(out, (q_unpad, k_unpad, v_unpad), g)
        dq = dq_pad_fn(dq_unpad)
        dk = dk_pad_fn(dk_unpad)
        dv = dk_pad_fn(dv_unpad)
        (
            dq_ref,
            dk_ref,
            dv_ref,
        ) = torch.autograd.grad(out_ref, (q, k, v), g)
        (
            dq_pt,
            dk_pt,
            dv_pt,
        ) = torch.autograd.grad(out_pt, (q, k, v), g)
        print(f"dQ max diff: {(dq - dq_ref).abs().max().item()}")
        print(f"dK max diff: {(dk - dk_ref).abs().max().item()}")
        print(f"dV max diff: {(dv - dv_ref).abs().max().item()}")
        print(f"dQ mean diff: {(dq - dq_ref).abs().mean().item()}")
        print(f"dK mean diff: {(dk - dk_ref).abs().mean().item()}")
        print(f"dV mean diff: {(dv - dv_ref).abs().mean().item()}")
        print(f"dQ Pytorch max diff: {(dq_pt - dq_ref).abs().max().item()}")
        print(f"dK Pytorch max diff: {(dk_pt - dk_ref).abs().max().item()}")
        print(f"dV Pytorch max diff: {(dv_pt - dv_ref).abs().max().item()}")
        print(f"dQ Pytorch mean diff: {(dq_pt - dq_ref).abs().mean().item()}")
        print(f"dK Pytorch mean diff: {(dk_pt - dk_ref).abs().mean().item()}")
        print(f"dV Pytorch mean diff: {(dv_pt - dv_ref).abs().mean().item()}")

    # Check that FlashAttention's numerical error is at most twice the numerical error
    # of a Pytorch implementation.
    assert (out - out_ref).abs().max().item() <= 2 * (out_pt - out_ref).abs().max().item() + 1e-5

    if test_backward:
        assert (dq - dq_ref).abs().max().item() <= 2 * (dq_pt - dq_ref).abs().max().item() + 1e-5
        assert (dk - dk_ref).abs().max().item() <= 2 * (dk_pt - dk_ref).abs().max().item() + 1e-5
        assert (dv - dv_ref).abs().max().item() <= 2 * (dv_pt - dv_ref).abs().max().item() + 1e-5


@pytest.mark.parametrize("dtype", ([torch.float16] if is_sm75 else [torch.float16, torch.bfloat16]))
# @pytest.mark.parametrize("dtype", [torch.float16])
@pytest.mark.parametrize("deterministic", [False, True])
# @pytest.mark.parametrize("deterministic", [True])
@pytest.mark.parametrize("alibi", [False, True])
# @pytest.mark.parametrize("alibi", [True])
@pytest.mark.parametrize("local", [False, True])
# @pytest.mark.parametrize("local", [False])
@pytest.mark.parametrize("causal", [False, True])
# @pytest.mark.parametrize("causal", [True])
@pytest.mark.parametrize("d", [32, 40, 59, 64, 80, 96, 111, 128, 160, 192, 224, 256])
# @pytest.mark.parametrize('d', [32, 64, 96, 128, 160, 192, 224, 256])
# @pytest.mark.parametrize('d', [32, 40, 64, 80, 96, 128, 160, 192])
# @pytest.mark.parametrize('d', [32, 64, 96, 128, 160, 192])
# @pytest.mark.parametrize('d', [56, 80])
# @pytest.mark.parametrize("d", [64])
@pytest.mark.parametrize("swap_sq_sk", [False, True])
# @pytest.mark.parametrize("swap_sq_sk", [False])
@pytest.mark.parametrize(
    "seqlen_q,seqlen_k",
    [
        (3, 1024),
        (1, 339),
        (64, 800),
        (3, 799),
        (64, 2048),
        (16, 20000),
        (16, 100000),
        (128, 128),
        (256, 256),
    ],
)
# @pytest.mark.parametrize('seqlen_q,seqlen_k', [(256, 128)])
def test_flash_attn_splitkv(
    seqlen_q, seqlen_k, swap_sq_sk, d, causal, local, alibi, deterministic, dtype
):
    if swap_sq_sk:
        seqlen_q, seqlen_k = seqlen_k, seqlen_q
    device = "cuda"
    # set seed
    torch.random.manual_seed(0)
    batch_size = 1
    nheads = 12
    window_size = (-1, -1) if not local else torch.randint(0, seqlen_k, (2,))
    q = torch.randn(batch_size, seqlen_q, nheads, d, device=device, dtype=dtype, requires_grad=True)
    k = torch.randn(batch_size, seqlen_k, nheads, d, device=device, dtype=dtype, requires_grad=True)
    v = torch.randn(batch_size, seqlen_k, nheads, d, device=device, dtype=dtype, requires_grad=True)
    if alibi:
        alibi_slopes = torch.rand(batch_size, nheads, device=device, dtype=torch.float32) * 0.3
        attn_bias = attn_bias_from_alibi_slopes(alibi_slopes, seqlen_q, seqlen_k, causal=causal)
    else:
        alibi_slopes, attn_bias = None, None
    out, lse, _ = flash_attn_func(
        q,
        k,
        v,
        0.0,
        causal=causal,
        window_size=window_size,
        alibi_slopes=alibi_slopes,
        deterministic=deterministic,
        return_attn_probs=True,
    )
    out_ref, attn_ref = attention_ref(
        q, k, v, None, None, attn_bias, 0.0, None, causal=causal, window_size=window_size
    )
    out_pt, attn_pt = attention_ref(
        q,
        k,
        v,
        None,
        None,
        attn_bias,
        0.0,
        None,
        causal=causal,
        window_size=window_size,
        upcast=False,
        reorder_ops=True,
    )

    print(f"Output max diff: {(out - out_ref).abs().max().item()}")
    print(f"Output mean diff: {(out - out_ref).abs().mean().item()}")
    print(f"Pytorch max diff: {(out_pt - out_ref).abs().max().item()}")
    print(f"Pytorch mean diff: {(out_pt - out_ref).abs().mean().item()}")

    g = torch.randn_like(out)
    do_o = (g.float() * out.float()).sum(-1)
    if (d <= MAX_HEADDIM_SM8x or d > 224) or (is_sm80 or is_sm90):
        (
            dq,
            dk,
            dv,
        ) = torch.autograd.grad(out, (q, k, v), g)
        (
            dq_ref,
            dk_ref,
            dv_ref,
        ) = torch.autograd.grad(out_ref, (q, k, v), g)
        (
            dq_pt,
            dk_pt,
            dv_pt,
        ) = torch.autograd.grad(out_pt, (q, k, v), g)
        print(f"dQ max diff: {(dq - dq_ref).abs().max().item()}")
        print(f"dK max diff: {(dk - dk_ref).abs().max().item()}")
        print(f"dV max diff: {(dv - dv_ref).abs().max().item()}")
        print(f"dQ mean diff: {(dq - dq_ref).abs().mean().item()}")
        print(f"dK mean diff: {(dk - dk_ref).abs().mean().item()}")
        print(f"dV mean diff: {(dv - dv_ref).abs().mean().item()}")
        print(f"dQ Pytorch max diff: {(dq_pt - dq_ref).abs().max().item()}")
        print(f"dK Pytorch max diff: {(dk_pt - dk_ref).abs().max().item()}")
        print(f"dV Pytorch max diff: {(dv_pt - dv_ref).abs().max().item()}")
        print(f"dQ Pytorch mean diff: {(dq_pt - dq_ref).abs().mean().item()}")
        print(f"dK Pytorch mean diff: {(dk_pt - dk_ref).abs().mean().item()}")
        print(f"dV Pytorch mean diff: {(dv_pt - dv_ref).abs().mean().item()}")

    # Check that FlashAttention's numerical error is at most twice the numerical error
    # of a Pytorch implementation.
    assert (out - out_ref).abs().max().item() <= 2 * (out_pt - out_ref).abs().max().item() + 1e-5

    mult = 2 if not alibi else 8
    if (d <= MAX_HEADDIM_SM8x or d > 224) or (is_sm80 or is_sm90):
        assert (dq - dq_ref).abs().max().item() <= mult * (dq_pt - dq_ref).abs().max().item() + 2e-4
        assert (dk - dk_ref).abs().max().item() <= mult * (dk_pt - dk_ref).abs().max().item() + 2e-4
        assert (dv - dv_ref).abs().max().item() <= mult * (dv_pt - dv_ref).abs().max().item() + 2e-4


# @pytest.mark.parametrize("dtype", ([torch.float16] if is_sm75 else [torch.float16, torch.bfloat16]))
@pytest.mark.parametrize("dtype", [torch.float16])
@pytest.mark.parametrize("num_splits", [1, 0])
# @pytest.mark.parametrize("num_splits", [1])
@pytest.mark.parametrize("mha_type", ["mha", "mqa", "gqa"])
# @pytest.mark.parametrize("mha_type", ["mha"])
@pytest.mark.parametrize("new_kv", [False, True])
# @pytest.mark.parametrize("new_kv", [False])
@pytest.mark.parametrize("alibi", [False, True])
# @pytest.mark.parametrize("alibi", [False])
@pytest.mark.parametrize("local", [False, True])
# @pytest.mark.parametrize("local", [False])
@pytest.mark.parametrize("causal", [False, True])
# @pytest.mark.parametrize("causal", [False])
@pytest.mark.parametrize("seqlen_new_eq_seqlen_q", [True, False])
# @pytest.mark.parametrize("seqlen_new_eq_seqlen_q", [True])
@pytest.mark.parametrize("rotary_interleaved", [False, True])
# @pytest.mark.parametrize("rotary_interleaved", [False])
@pytest.mark.parametrize("rotary_fraction", [0.0, 0.5, 1.0])
# @pytest.mark.parametrize("rotary_fraction", [0.0])
@pytest.mark.parametrize("paged_kv_block_size", [None, 256])
# @pytest.mark.parametrize("paged_kv_block_size", [256, 512])
# @pytest.mark.parametrize("paged_kv_block_size", [256])
@pytest.mark.parametrize("has_batch_idx", [False, True])
# @pytest.mark.parametrize("has_batch_idx", [False])
@pytest.mark.parametrize("d", [32, 59, 64, 80, 128, 256])
# @pytest.mark.parametrize("d", [32, 64, 96, 128, 160, 192, 224, 256])
# @pytest.mark.parametrize('d', [32, 40, 64, 80, 96, 128, 160, 192])
# @pytest.mark.parametrize('d', [56, 80])
# @pytest.mark.parametrize("d", [128])
@pytest.mark.parametrize(
    "seqlen_q,seqlen_k",
    [
        (1, 128),
        (1, 339),
        (3, 1024),
        (64, 800),
        (64, 256),
        (3, 799),
        (64, 2048),
        (16, 20000),
        (1, 128 * 1024),
        (16, 128 * 1024),
        (128, 128),
    ],
)
# @pytest.mark.parametrize('seqlen_q,seqlen_k', [(256, 128)])
def test_flash_attn_kvcache(
    seqlen_q,
    seqlen_k,
    d,
    has_batch_idx,
    paged_kv_block_size,
    rotary_fraction,
    rotary_interleaved,
    seqlen_new_eq_seqlen_q,
    causal,
    local,
    alibi,
    new_kv,
    mha_type,
    num_splits,
    dtype,
):
    if seqlen_q > seqlen_k and new_kv:
        pytest.skip()
    if not new_kv and rotary_fraction > 0.0:
        pytest.skip()
    if has_batch_idx and paged_kv_block_size is not None:
        pytest.skip()
    device = "cuda"
    # set seed
    torch.random.manual_seed(0)
    batch_size = 2
    batch_size_cache = batch_size if not has_batch_idx else batch_size * 2
    nheads = 6
    # rotary_dim must be a multiple of 16, and must be <= d
    rotary_dim = math.floor(int(rotary_fraction * d) / 16) * 16
    nheads_k = nheads if mha_type == "mha" else (1 if mha_type == "mqa" else 3)
    assert nheads % nheads_k == 0
    window_size = (-1, -1) if not local else torch.randint(0, seqlen_k, (2,))
    q = torch.randn(batch_size, seqlen_q, nheads, d, device=device, dtype=dtype)
    seqlen_new = seqlen_q if seqlen_new_eq_seqlen_q else torch.randint(1, seqlen_q + 1, (1,)).item()
    if new_kv:
        k = torch.randn(batch_size, seqlen_new, nheads_k, d, device=device, dtype=dtype)
        v = torch.randn(batch_size, seqlen_new, nheads_k, d, device=device, dtype=dtype)
    else:
        k, v = None, None
    if paged_kv_block_size is None:
        k_cache = torch.randn(batch_size_cache, seqlen_k, nheads_k, d, device=device, dtype=dtype)
        v_cache = torch.randn(batch_size_cache, seqlen_k, nheads_k, d, device=device, dtype=dtype)
        block_table = None
    else:
        (
            k_cache,
            v_cache,
            block_table,
            k_cache_paged,
            v_cache_paged,
            num_blocks,
        ) = _generate_block_kvcache(
            seqlen_k, paged_kv_block_size, batch_size, nheads_k, d, device, dtype
        )
    cache_seqlens = torch.randint(
        0 if new_kv else 1,
        # If we don't use seqlen_q in the case of causal and rotary, cos/sin won't be long enough
        (
            (seqlen_k - (seqlen_q if (causal or local) and rotary_dim > 1 else seqlen_new) + 1)
            if new_kv
            else (seqlen_k + 1)
        ),
        (batch_size,),
        dtype=torch.int32,
        device=device,
    )
    arange = rearrange(torch.arange(seqlen_k, device=device), "s -> 1 s")
    cache_seqlens_expanded = rearrange(cache_seqlens, "b -> b 1")
    key_padding_mask = arange < cache_seqlens_expanded + (seqlen_new if new_kv else 0)
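    # Valid key positions for the reference are the existing cache entries plus, when new_kv
    # is set, the seqlen_new tokens that flash_attn_with_kvcache appends to the cache in-place.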
    if has_batch_idx:
        cache_batch_idx = torch.randperm(batch_size_cache, dtype=torch.int32, device=device)[
            :batch_size
        ]
    else:
        cache_batch_idx = None
    if alibi:
        alibi_slopes = torch.rand(batch_size, nheads, device=device, dtype=torch.float32) * 0.3
        attn_bias = attn_bias_from_alibi_slopes(
            alibi_slopes, seqlen_q, seqlen_k, None, key_padding_mask, causal=causal
        )
    else:
        alibi_slopes, attn_bias = None, None
    # cache_seqlens = torch.tensor([64], dtype=torch.int32, device=device)
    if rotary_dim > 0:
        angle = (
            torch.rand(
                seqlen_k if paged_kv_block_size is None else num_blocks * paged_kv_block_size,
                rotary_dim // 2,
                device=device,
            )
            * 2
            * math.pi
        )
        cos = torch.cos(angle).to(dtype=dtype)
        sin = torch.sin(angle).to(dtype=dtype)
        if causal or local:
            q_ro = apply_rotary_emb(
                q, cos, sin, seqlen_offsets=cache_seqlens, interleaved=rotary_interleaved
            )
        else:
            q_ro = rearrange(
                apply_rotary_emb(
                    rearrange(q, "b s h d -> b 1 (s h) d"),
                    cos,
                    sin,
                    seqlen_offsets=cache_seqlens,
                    interleaved=rotary_interleaved,
                ),
                "b 1 (s h) d -> b s h d",
                s=seqlen_q,
            )
        # q_ro = q
        k_ro = apply_rotary_emb(
            k, cos, sin, seqlen_offsets=cache_seqlens, interleaved=rotary_interleaved
        )
    else:
        cos, sin = None, None
        q_ro, k_ro = q, k
    # k_cache[:, 64:] = -1
    k_cache_ref = (
        k_cache if not has_batch_idx else k_cache[cache_batch_idx.to(dtype=torch.long)]
    ).clone()
    v_cache_ref = (
        v_cache if not has_batch_idx else v_cache[cache_batch_idx.to(dtype=torch.long)]
    ).clone()
    if new_kv:
        update_mask = torch.logical_and(
            cache_seqlens_expanded <= arange, arange < cache_seqlens_expanded + seqlen_new
        )
        k_cache_ref[update_mask] = rearrange(k_ro, "b s ... -> (b s) ...")
        v_cache_ref[update_mask] = rearrange(v, "b s ... -> (b s) ...")
    k_cache_rep = repeat(k_cache_ref, "b s h d -> b s (h g) d", g=nheads // nheads_k)
    v_cache_rep = repeat(v_cache_ref, "b s h d -> b s (h g) d", g=nheads // nheads_k)
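    # The reference attends over the updated cache with KV heads replicated for MQA/GQA; the
    # kernel below receives the original (or paged) cache plus the new k / v to append.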
    out = flash_attn_with_kvcache(
        q,
        k_cache if paged_kv_block_size is None else k_cache_paged,
        v_cache if paged_kv_block_size is None else v_cache_paged,
        k,
        v,
        rotary_cos=cos,
        rotary_sin=sin,
        cache_seqlens=cache_seqlens,
        cache_batch_idx=cache_batch_idx,
        block_table=block_table,
        causal=causal,
        window_size=window_size,
        rotary_interleaved=rotary_interleaved,
        alibi_slopes=alibi_slopes,
        num_splits=num_splits,
    )
    # out = flash_attn_with_kvcache(
    #     q, k_cache, v_cache, cache_seqlens=cache_seqlens, causal=causal, window_size=window_size
    # )
    # out = flash_attn_with_kvcache(q, k_cache, v_cache, causal=causal, window_size=window_size)
    # qk = torch.einsum("bqhd,bkhd->bhqk", q, k_cache_ref)
    # m = qk.amax(-1, keepdim=True)
    # s_tmp = torch.exp((qk - m) / math.sqrt(d))
    # o1 = torch.einsum('bhst,bthd->bshd', s_tmp, v_cache_ref)
    # lse_ref = torch.logsumexp(qk / math.sqrt(d), -1)
    # probs = torch.softmax(qk, dim=-1)
    out_ref, _ = attention_ref(
        q_ro,
        k_cache_rep,
        v_cache_rep,
        None,
        key_padding_mask,
        attn_bias,
        0.0,
        None,
        causal=causal,
        window_size=window_size,
    )
    out_pt, _ = attention_ref(
        q_ro,
        k_cache_rep,
        v_cache_rep,
        None,
        key_padding_mask,
        attn_bias,
        0.0,
        None,
        causal=causal,
        window_size=window_size,
        upcast=False,
        reorder_ops=True,
    )
    print(f"Output max diff: {(out - out_ref).abs().max().item()}")
    print(f"Output mean diff: {(out - out_ref).abs().mean().item()}")
    print(f"Pytorch max diff: {(out_pt - out_ref).abs().max().item()}")
    print(f"Pytorch mean diff: {(out_pt - out_ref).abs().mean().item()}")

    # Check that FlashAttention's numerical error is at most twice the numerical error
    # of a Pytorch implementation.
    if new_kv:
        if paged_kv_block_size is None:
            k_cache_select = (
                k_cache if not has_batch_idx else k_cache[cache_batch_idx.to(dtype=torch.long)]
            )
            v_cache_select = (
                v_cache if not has_batch_idx else v_cache[cache_batch_idx.to(dtype=torch.long)]
            )
        else:
            k_cache_select = rearrange(
                k_cache_paged[block_table.to(dtype=torch.long).flatten()],
                "(b nblocks) block_size ... -> b (nblocks block_size) ...",
                b=batch_size,
            )[:, :seqlen_k]
            v_cache_select = rearrange(
                v_cache_paged[block_table.to(dtype=torch.long).flatten()],
                "(b nblocks) block_size ... -> b (nblocks block_size) ...",
                b=batch_size,
            )[:, :seqlen_k]
        assert torch.allclose(k_cache_select, k_cache_ref, rtol=1e-3, atol=1e-3)
        assert torch.equal(v_cache_select, v_cache_ref)
    mult = 3 if not alibi else 5
    assert (out - out_ref).abs().max().item() <= mult * (out_pt - out_ref).abs().max().item() + 1e-5


def _generate_block_kvcache(seqlen_k, paged_kv_block_size, batch_size, nheads_k, d, device, dtype):
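    # Over-provision 3x more blocks than needed, hand them out through a shuffled block_table,
    # and also return a contiguous (batch, seqlen_k) gather of the same blocks so the
    # reference implementation can work with ordinary k_cache / v_cache tensors.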
    num_blocks = math.ceil(seqlen_k / paged_kv_block_size) * batch_size * 3
    k_cache_paged = torch.randn(
        num_blocks, paged_kv_block_size, nheads_k, d, device=device, dtype=dtype
    )
    v_cache_paged = torch.randn(
        num_blocks, paged_kv_block_size, nheads_k, d, device=device, dtype=dtype
    )
    block_table = rearrange(
        torch.randperm(num_blocks, dtype=torch.int32, device=device),
        "(b nblocks) -> b nblocks",
        b=batch_size,
    )
    k_cache = rearrange(
        # pytorch 1.12 doesn't have indexing with int32
        k_cache_paged[block_table.to(dtype=torch.long).flatten()],
        "(b nblocks) block_size ... -> b (nblocks block_size) ...",
        b=batch_size,
    )[:, :seqlen_k]
    v_cache = rearrange(
        v_cache_paged[block_table.to(dtype=torch.long).flatten()],
        "(b nblocks) block_size ... -> b (nblocks block_size) ...",
        b=batch_size,
    )[:, :seqlen_k]
    return k_cache, v_cache, block_table, k_cache_paged, v_cache_paged, num_blocks


# @pytest.mark.parametrize("dtype", ([torch.float16] if is_sm75 else [torch.float16, torch.bfloat16]))
@pytest.mark.parametrize("dtype", [torch.float16])
@pytest.mark.parametrize("causal", [False, True])
# @pytest.mark.parametrize('causal', [True])
@pytest.mark.parametrize("d", [32, 40, 59, 64, 80, 96, 111, 128, 160, 192, 224, 256])
# @pytest.mark.parametrize('d', [32, 56, 64, 80, 96, 128])
# @pytest.mark.parametrize("d", [32, 64, 96, 128, 160, 192])
# @pytest.mark.parametrize('d', [128])
@pytest.mark.parametrize(
    "seqlen_q,seqlen_k",
    [
        (1, 239),
        (239, 1),
        (3, 799),
        (799, 3),
        (1024, 128),
        (97, 97),
        (128, 128),
        (200, 200),
        (256, 256),
        (257, 257),
        (384, 384),
        (512, 512),
        (768, 768),
        (1024, 1024),
    ],
)
@pytest.mark.parametrize("dropout_p", [0.0, 0.17])
# @pytest.mark.parametrize("dropout_p", [0.0])
def test_flash_attn_race_condition(seqlen_q, seqlen_k, d, dropout_p, causal, dtype):
    device = "cuda"
    # set seed
    torch.random.manual_seed(0)
    batch_size = 60  # Sometimes we need large batch size for the race conditions to trigger
    nheads = 4
    q = torch.randn(batch_size, seqlen_q, nheads, d, device=device, dtype=dtype, requires_grad=True)
    k = torch.randn(batch_size, seqlen_k, nheads, d, device=device, dtype=dtype, requires_grad=True)
    v = torch.randn(batch_size, seqlen_k, nheads, d, device=device, dtype=dtype, requires_grad=True)
    torch.random.manual_seed(42)
    out0, lse0, _ = flash_attn_func(q, k, v, dropout_p, causal=causal, return_attn_probs=True)
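    # The first call (with a fixed seed) is the reference; later calls on identical inputs
    # must reproduce it exactly, otherwise a race condition in the kernel is indicated.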
    g = torch.randn_like(out0)
    if (d <= MAX_HEADDIM_SM8x or (d > 224 and dropout_p == 0)) or (is_sm80 or is_sm90):
        (
            dq0,
            dk0,
            dv0,
        ) = torch.autograd.grad(out0, (q, k, v), g)
        # Numerical error if we just do any arithmetic on dq
        dq_atol = 2 * ((dq0 + 0.3 - 0.3) - dq0).abs().max().item()

    for i in range(250):
        torch.random.manual_seed(42)
        out, lse, _ = flash_attn_func(q, k, v, dropout_p, causal=causal, return_attn_probs=True)
        assert torch.equal(out, out0)
        assert torch.equal(lse, lse0)

        if (d <= MAX_HEADDIM_SM8x or (d > 224 and dropout_p == 0)) or (is_sm80 or is_sm90):
            (
                dq,
                dk,
                dv,
            ) = torch.autograd.grad(out, (q, k, v), g)
            dq_equal = torch.allclose(dq, dq0, atol=dq_atol)
            if not dq_equal:
                print(f"Iter {i}, {dq_atol = }, dQ max diff: {(dq - dq0).abs().max().item()}")
            assert torch.equal(dv, dv0)
            assert torch.equal(dk, dk0)
            assert dq_equal


@pytest.mark.parametrize("dtype", [torch.float16])
@pytest.mark.parametrize("causal", [False, True])
# @pytest.mark.parametrize('causal', [False])
@pytest.mark.parametrize("d", [16, 32, 64])
# @pytest.mark.parametrize('d', [16])
@pytest.mark.parametrize("seqlen", [1, 2, 5, 17, 128])
# @pytest.mark.parametrize('seqlen', [2])
def test_flash_attn_bwd_overflow(seqlen, d, causal, dtype):
    """We previously had a bug where not masking elements beyond seqlen_k caused NaN in dQ,
    in the case where seqlen % 128 != 0.
    """
    device = "cuda"
    # set seed
    torch.random.manual_seed(0)
    batch_size = 2
    nheads = 5
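    # Scale q (and k, v below) up so the attention logits are large, stressing the numeric
    # range of the backward pass.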
    q = torch.randn([batch_size, seqlen, nheads, d], dtype=dtype, device="cuda") * 5
    k, v = [
        torch.randn([batch_size, seqlen, nheads, d], dtype=dtype, device="cuda") * 3
        for _ in range(2)
    ]
    q.requires_grad_(True)
    k.requires_grad_(True)
    v.requires_grad_(True)
    out = flash_attn_func(q, k, v, causal=causal)
    g = torch.randn_like(out)
    out.backward(g)
    q_pt = q.detach().clone().requires_grad_(True)
    k_pt = k.detach().clone().requires_grad_(True)
    v_pt = v.detach().clone().requires_grad_(True)
    out_pt, _ = attention_ref(q_pt, k_pt, v_pt, causal=causal, upcast=False, reorder_ops=True)
    out_pt.backward(g)
    q_ref = q.detach().clone().requires_grad_(True)
    k_ref = k.detach().clone().requires_grad_(True)
    v_ref = v.detach().clone().requires_grad_(True)
    out_ref, attn_ref = attention_ref(q_ref, k_ref, v_ref, causal=causal)
    out_ref.backward(g)
    print(f"dQ max diff: {(q.grad - q_ref.grad).abs().max().item()}")
    print(f"dK max diff: {(k.grad - k_ref.grad).abs().max().item()}")
    print(f"dV max diff: {(v.grad - v_ref.grad).abs().max().item()}")
    print(f"dQ Pytorch max diff: {(q_pt.grad - q_ref.grad).abs().max().item()}")
    print(f"dK Pytorch max diff: {(k_pt.grad - k_ref.grad).abs().max().item()}")
    print(f"dV Pytorch max diff: {(v_pt.grad - v_ref.grad).abs().max().item()}")
    assert (out - out_ref).abs().max().item() <= 2 * (out_pt - out_ref).abs().max().item()
    assert (q.grad - q_ref.grad).abs().max().item() <= 5 * (
        q_pt.grad - q_ref.grad
    ).abs().max().item() + 1e-3
    assert (k.grad - k_ref.grad).abs().max().item() <= 5 * (
        k_pt.grad - k_ref.grad
    ).abs().max().item() + 1e-3
    assert (v.grad - v_ref.grad).abs().max().item() <= 5 * (
        v_pt.grad - v_ref.grad
    ).abs().max().item() + 1e-3


@pytest.mark.parametrize("dtype", ([torch.float16] if is_sm75 else [torch.float16, torch.bfloat16]))
# @pytest.mark.parametrize('dtype', [torch.bfloat16])
@pytest.mark.parametrize("causal", [False, True])
# @pytest.mark.parametrize('causal', [False])
@pytest.mark.parametrize("d", [64, 128])
# @pytest.mark.parametrize('d', [64])
@pytest.mark.parametrize("seqlen", [97, 128, 200, 256])
# @pytest.mark.parametrize('seqlen', [128])
def test_flash_attn_bwd_transpose(seqlen, d, causal, dtype):
    """We previously had a bug where we were using the wrong strides of dout, which shows up
    when dout is not contiguous.
    """
    device = "cuda"
    # set seed
    torch.random.manual_seed(0)
    batch_size = 5
    nheads = 2
    q, k, v = [
        torch.randn([batch_size, seqlen, nheads, d], dtype=dtype, device="cuda", requires_grad=True)
        for _ in range(3)
    ]
    out = rearrange(flash_attn_func(q, k, v, causal=causal), "b s ... -> s b ...")
    # So g is not contiguous
    g = torch.randn(seqlen, 2 * batch_size, nheads, d, dtype=dtype, device="cuda")[:, ::2]
    out.backward(g)
    q_pt = q.detach().clone().requires_grad_(True)
    k_pt = k.detach().clone().requires_grad_(True)
    v_pt = v.detach().clone().requires_grad_(True)
    out_pt, attn_pt = attention_ref(q_pt, k_pt, v_pt, causal=causal, upcast=False, reorder_ops=True)
    out_pt = rearrange(out_pt, "b s ... -> s b ...")
    out_pt.backward(g)
    q_ref = q.detach().clone().requires_grad_(True)
    k_ref = k.detach().clone().requires_grad_(True)
    v_ref = v.detach().clone().requires_grad_(True)
    out_ref, attn_ref = attention_ref(q_ref, k_ref, v_ref, causal=causal)
    out_ref = rearrange(out_ref, "b s ... -> s b ...")
    out_ref.backward(g)
    print(f"dQ max diff: {(q.grad - q_ref.grad).abs().max().item()}")
    print(f"dK max diff: {(k.grad - k_ref.grad).abs().max().item()}")
    print(f"dV max diff: {(v.grad - v_ref.grad).abs().max().item()}")
    print(f"dQ Pytorch max diff: {(q_pt.grad - q_ref.grad).abs().max().item()}")
    print(f"dK Pytorch max diff: {(k_pt.grad - k_ref.grad).abs().max().item()}")
    print(f"dV Pytorch max diff: {(v_pt.grad - v_ref.grad).abs().max().item()}")
    assert (out - out_ref).abs().max().item() <= 2 * (out_pt - out_ref).abs().max().item()
    assert (q.grad - q_ref.grad).abs().max().item() <= 2 * (
        q_pt.grad - q_ref.grad
    ).abs().max().item()
    assert (k.grad - k_ref.grad).abs().max().item() <= 2 * (
        k_pt.grad - k_ref.grad
    ).abs().max().item()
    assert (v.grad - v_ref.grad).abs().max().item() <= 2 * (
        v_pt.grad - v_ref.grad
    ).abs().max().item()


@pytest.mark.parametrize("dtype", [torch.float16])
@pytest.mark.parametrize("causal", [False, True])
# @pytest.mark.parametrize('causal', [False])
@pytest.mark.parametrize("d", [16, 32, 64])
# @pytest.mark.parametrize('d', [16])
def test_flash_attn_bwd_varlen_overflow(d, causal, dtype):
    """We previously had a bug where not masking elements beyond seqlen_k caused NaN in dQ,
    in the case where seqlen % 128 != 0 or varlen.
    """
    device = "cuda"
    # set seed
    torch.random.manual_seed(0)
    nheads = 5
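    # cu_seqlens below describe 3 query sequences of uneven length (76, 34, 146) but only one
    # key/value token per sequence, so neither seqlen_q nor seqlen_k is a multiple of 128.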
    q_cuseqlen = torch.tensor([0, 76, 110, 256], device=device, dtype=torch.int32)
    k_cuseqlen = torch.tensor([0, 1, 2, 3], device=device, dtype=torch.int32)
    Mq = 256
    Mk = 3

    q = torch.randn([Mq, nheads, d], dtype=dtype, device=device) * 3
    k, v = [torch.randn([Mk, nheads, d], dtype=dtype, device=device) * 3 for _ in range(2)]
    q.requires_grad_(True)
    k.requires_grad_(True)
    v.requires_grad_(True)

    out = flash_attn_varlen_func(q, k, v, q_cuseqlen, k_cuseqlen, Mq, Mk, causal=causal)
    g = torch.randn_like(out)
    out.backward(g)

    assert not q.grad.isnan().any()
    assert not k.grad.isnan().any()
    assert not v.grad.isnan().any()


@pytest.mark.parametrize("dtype", ([torch.float16] if is_sm75 else [torch.float16, torch.bfloat16]))
# @pytest.mark.parametrize("dtype", [torch.bfloat16])
@pytest.mark.parametrize("local", [False, True])
# @pytest.mark.parametrize("local", [True])
@pytest.mark.parametrize("causal", [False, True])
# @pytest.mark.parametrize("causal", [True])
@pytest.mark.parametrize("d", [32, 40, 59, 64, 80, 96, 111, 128, 160, 192, 224, 256])
# @pytest.mark.parametrize("d", [32, 64, 96, 128, 160, 192, 224, 256])
# @pytest.mark.parametrize('d', [32, 40, 64, 80, 96, 128, 160, 192])
# @pytest.mark.parametrize('d', [32, 64, 96, 128, 160, 192])
# @pytest.mark.parametrize('d', [56, 80])
# @pytest.mark.parametrize("d", [64])
@pytest.mark.parametrize("swap_sq_sk", [False, True])
# @pytest.mark.parametrize("swap_sq_sk", [False])
@pytest.mark.parametrize(
    "seqlen_q,seqlen_k",
    [
        (1, 239),
        (3, 799),
        (127, 512),
        (127, 513),
        (113, 203),
        (128, 217),
        (113, 211),
        (108, 256),
        (256, 512),
        (1023, 1024),
    ],
)
# @pytest.mark.parametrize('seqlen_q,seqlen_k', [(256, 128)])
def test_flash_attn_deterministic(seqlen_q, seqlen_k, swap_sq_sk, d, causal, local, dtype):
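    """With deterministic=True, repeated backward passes on the same graph must produce
    bitwise-identical dq, dk, and dv.
    """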
    if (
        max(seqlen_q, seqlen_k) >= 2048
        and torch.cuda.get_device_properties("cuda").total_memory <= 16 * 2**30
    ):
        pytest.skip()  # Reference implementation OOM
    if swap_sq_sk:
        seqlen_q, seqlen_k = seqlen_k, seqlen_q
    device = "cuda"
    # set seed
    torch.random.manual_seed(0)
    batch_size = 4
    nheads = 9
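    # window_size is (left, right); (-1, -1) disables the sliding-window (local) attention path.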
    window_size = (-1, -1) if not local else torch.randint(0, seqlen_k, (2,))
    q = torch.randn(batch_size, seqlen_q, nheads, d, device=device, dtype=dtype, requires_grad=True)
    k = torch.randn(batch_size, seqlen_k, nheads, d, device=device, dtype=dtype, requires_grad=True)
    v = torch.randn(batch_size, seqlen_k, nheads, d, device=device, dtype=dtype, requires_grad=True)
    out = flash_attn_func(q, k, v, 0.0, causal=causal, window_size=window_size, deterministic=True)

    g = torch.randn_like(out)
    if (d <= MAX_HEADDIM_SM8x or d > 224) or (is_sm80 or is_sm90):
        dq0, dk0, dv0 = torch.autograd.grad(out, (q, k, v), g, retain_graph=True)
        for _ in range(50):
            dq, dk, dv = torch.autograd.grad(out, (q, k, v), g, retain_graph=True)
            assert torch.equal(dv, dv0)
            assert torch.equal(dk, dk0)
            assert torch.equal(dq, dq0)


@pytest.mark.parametrize("dtype", ([torch.float16] if is_sm75 else [torch.float16, torch.bfloat16]))
# @pytest.mark.parametrize("dtype", [torch.bfloat16])
@pytest.mark.parametrize("local", [False, True])
# @pytest.mark.parametrize("local", [True])
@pytest.mark.parametrize("causal", [False, True])
# @pytest.mark.parametrize("causal", [True])
@pytest.mark.parametrize("d", [32, 40, 59, 64, 80, 96, 111, 128, 160, 192, 224, 256])
# @pytest.mark.parametrize("d", [32, 64, 96, 128, 160, 192, 224, 256])
# @pytest.mark.parametrize('d', [32, 40, 64, 80, 96, 128, 160, 192])
# @pytest.mark.parametrize('d', [32, 64, 96, 128, 160, 192])
# @pytest.mark.parametrize('d', [56, 80])
# @pytest.mark.parametrize("d", [64])
@pytest.mark.parametrize("swap_sq_sk", [False, True])
# @pytest.mark.parametrize("swap_sq_sk", [True])
@pytest.mark.parametrize(
    "seqlen_q,seqlen_k",
    [
        (1, 239),
        (3, 799),
        (127, 512),
        (127, 513),
        (113, 203),
        (128, 217),
        (113, 211),
        (108, 256),
        (256, 512),
        (1023, 1024),
    ],
)
# @pytest.mark.parametrize("seqlen_q,seqlen_k", [(256, 128)])
def test_flash_attn_varlen_deterministic(seqlen_q, seqlen_k, swap_sq_sk, d, causal, local, dtype):
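    """Same as test_flash_attn_deterministic, but through the varlen (unpadded) interface."""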
    if (
        max(seqlen_q, seqlen_k) >= 2048
        and torch.cuda.get_device_properties("cuda").total_memory <= 16 * 2**30
    ):
        pytest.skip()  # Reference implementation OOM
    if swap_sq_sk:
        seqlen_q, seqlen_k = seqlen_k, seqlen_q
    device = "cuda"
    # set seed
    torch.random.manual_seed(0)
    batch_size = 2
    nheads = 9
    window_size = (-1, -1) if not local else torch.randint(0, seqlen_k, (2,))
    q = torch.randn(batch_size, seqlen_q, nheads, d, device=device, dtype=dtype, requires_grad=True)
    k = torch.randn(batch_size, seqlen_k, nheads, d, device=device, dtype=dtype, requires_grad=True)
    v = torch.randn(batch_size, seqlen_k, nheads, d, device=device, dtype=dtype, requires_grad=True)
    query_padding_mask = generate_random_padding_mask(seqlen_q, batch_size, device, mode="random")
    key_padding_mask = generate_random_padding_mask(seqlen_k, batch_size, device, mode="random")
    (
        q_unpad,
        k_unpad,
        v_unpad,
        cu_seqlens_q,
        cu_seqlens_k,
        max_seqlen_q,
        max_seqlen_k,
        q,
        k,
        v,
        output_pad_fn,
        dq_pad_fn,
        dk_pad_fn,
    ) = generate_qkv(q, k, v, query_padding_mask, key_padding_mask, kvpacked=False)
    out = flash_attn_varlen_func(
        q_unpad,
        k_unpad,
        v_unpad,
        cu_seqlens_q,
        cu_seqlens_k,
        max_seqlen_q,
        max_seqlen_k,
        0.0,
        causal=causal,
        window_size=window_size,
        deterministic=True,
    )

    g = torch.randn_like(out)
    if (d <= MAX_HEADDIM_SM8x or d > 224) or (is_sm80 or is_sm90):
        dq0, dk0, dv0 = torch.autograd.grad(out, (q_unpad, k_unpad, v_unpad), g, retain_graph=True)
        for _ in range(50):
            dq, dk, dv = torch.autograd.grad(out, (q_unpad, k_unpad, v_unpad), g, retain_graph=True)
            assert torch.equal(dv, dv0)
            assert torch.equal(dk, dk0)
            assert torch.equal(dq, dq0)