test_flash_attn.py 45.4 KB
Newer Older
Tri Dao's avatar
Tri Dao committed
1
2
3
4
5
6
7
8
9
10
import math

import torch
import torch.nn.functional as F

import pytest

from einops import rearrange, repeat

from flash_attn.flash_attn_interface import flash_attn_func, flash_attn_unpadded_qkvpacked_func, _get_block_size, flash_attn_unpadded_kvpacked_func, flash_attn_unpadded_func
11
from flash_attn.flash_attn_interface import flash_attn_unpadded_qkvpacked_split_func
Tri Dao's avatar
Tri Dao committed
12
13
14
15
from flash_attn.bert_padding import unpad_input, pad_input, index_first_axis


# Compute capability of the CUDA device, queried once. The flags are used below to
# restrict the dtype parametrization (sm75) and the d=128 backward checks (sm80).
_device_capability = torch.cuda.get_device_capability('cuda')
is_sm75 = _device_capability == (7, 5)
is_sm80 = _device_capability == (8, 0)
Tri Dao's avatar
Tri Dao committed
17
18
19


def generate_random_padding_mask(max_seqlen, batch_size, device, mode='random'):
    """Build a boolean padding mask with one (possibly random) valid length per sequence.

    Arguments:
        max_seqlen: int, width of the mask.
        batch_size: int, number of rows in the mask.
        device: device on which the lengths and the mask are created.
        mode: how the per-sequence lengths are chosen:
            'full':   every sequence has length max_seqlen.
            'random': uniform in [max(1, max_seqlen - 20), max_seqlen].
            'third':  uniform in [max_seqlen // 3, max_seqlen].
            'split':  the first batch_size // 4 * 3 sequences are drawn from
                      [min(128, max_seqlen), max_seqlen], the remaining ones from
                      [min(max(1, max_seqlen - 20), 128), min(max_seqlen, 128)].
    Output:
        padding_mask: (batch_size, max_seqlen), bool. In each row the first `length`
            positions are True and the rest are False.
    """
    assert mode in ['full', 'random', 'third', 'split']
    if mode == 'full':
        lengths = torch.full((batch_size, 1), max_seqlen, device=device, dtype=torch.int32)
    elif mode == 'random':
        lengths = torch.randint(max(1, max_seqlen - 20), max_seqlen + 1, (batch_size, 1),
                                device=device)
    elif mode == 'third':
        lengths = torch.randint(max_seqlen // 3, max_seqlen + 1, (batch_size, 1), device=device)
    elif mode == 'split':
        # Three quarters of the batch get "long" sequences, the rest get "short" ones
        # (torch.randint's upper bound is exclusive, hence the + 1).
        n_long = batch_size // 4 * 3
        lengths0 = torch.randint(min(128, max_seqlen), max_seqlen + 1,
                                 (n_long, 1), device=device)
        lengths1 = torch.randint(min(max(1, max_seqlen - 20), 128), min(max_seqlen, 128) + 1,
                                 (batch_size - n_long, 1), device=device)
        lengths = torch.cat([lengths0, lengths1], dim=0)
    # (max_seqlen,) compared against (batch_size, 1) broadcasts to (batch_size, max_seqlen).
    positions = torch.arange(max_seqlen, device=device)
    padding_mask = positions < lengths
    return padding_mask


def generate_qkv(x, Wqkv, nheads, query_padding_mask=None, key_padding_mask=None,
                 kvpacked=False, qkvpacked=False):
    """Project x to q, k, v and return them in both unpadded and padded layouts.

    Arguments:
        x: (batch_size, seqlen, nheads * d)
        Wqkv: nn.Linear(nheads * d, 3 * nheads * d)
        nheads: number of attention heads
        query_padding_mask: (batch_size, seqlen), bool
        key_padding_mask: (batch_size, seqlen), bool
        kvpacked: if True, k and v are stacked into a single kv tensor.
        qkvpacked: if True, q, k and v are stacked into a single qkv tensor. The query and
            key padding masks must then be identical (or both None). Mutually exclusive
            with kvpacked.
    Output (the returned q/k/v tensors are detached and have requires_grad set):
        qkvpacked: (qkv_unpad, cu_seqlens_q, max_seqlen_q, qkv, output_pad_fn, dqkv_pad_fn)
        kvpacked: (q_unpad, kv_unpad, cu_seqlens_q, cu_seqlens_k, max_seqlen_q, max_seqlen_k,
                   q, kv, output_pad_fn, dq_pad_fn, dkv_pad_fn)
        otherwise: (q_unpad, k_unpad, v_unpad, cu_seqlens_q, cu_seqlens_k, max_seqlen_q,
                    max_seqlen_k, q, k, v, output_pad_fn, dq_pad_fn, dk_pad_fn)
        The *_pad_fn callables map a tensor in the unpadded layout back to the padded
        (batch_size, seqlen, ...) layout.
    """
    assert not (kvpacked and qkvpacked)
    batch_size, seqlen, dim = x.shape
    q, k, v = Wqkv(x).chunk(3, dim=-1)

    if query_padding_mask is not None:
        q_unpad, indices_q, cu_seqlens_q, max_seqlen_q = unpad_input(q, query_padding_mask)
        q_unpad = rearrange(q_unpad, 'nnz (h d) -> nnz h d', h=nheads)
        output_pad_fn = lambda output_unpad: rearrange(
            pad_input(rearrange(output_unpad, 'nnz h d -> nnz (h d)'), indices_q, batch_size, seqlen),
            'b s (h d) -> b s h d', h=nheads
        )
    else:
        # No padding: every sequence has the full length, so the cumulative sequence
        # lengths are just multiples of seqlen.
        q_unpad = rearrange(q, 'b s (h d) -> (b s) h d', h=nheads)
        cu_seqlens_q = torch.arange(0, (batch_size + 1) * seqlen, step=seqlen, dtype=torch.int32,
                                    device=q_unpad.device)
        max_seqlen_q = seqlen
        output_pad_fn = lambda output_unpad: rearrange(output_unpad, '(b s) h d -> b s h d', b=batch_size)

    if key_padding_mask is not None:
        k_unpad, indices_k, cu_seqlens_k, max_seqlen_k = unpad_input(k, key_padding_mask)
        k_unpad = rearrange(k_unpad, 'nnz (h d) -> nnz h d', h=nheads)
        v_unpad, _, _, _ = unpad_input(v, key_padding_mask)
        v_unpad = rearrange(v_unpad, 'nnz (h d) -> nnz h d', h=nheads)
    else:
        k_unpad = rearrange(k, 'b s (h d) -> (b s) h d', h=nheads)
        v_unpad = rearrange(v, 'b s (h d) -> (b s) h d', h=nheads)
        cu_seqlens_k = torch.arange(0, (batch_size + 1) * seqlen, step=seqlen, dtype=torch.int32,
                                    device=q_unpad.device)
        max_seqlen_k = seqlen

    if qkvpacked:
        # q, k and v are stacked row by row, so they must share one padding pattern.
        # Comparing None with None (or with a tensor) does not yield a tensor, so the
        # None case has to be checked separately before calling .all().
        if query_padding_mask is None or key_padding_mask is None:
            assert query_padding_mask is None and key_padding_mask is None
        else:
            assert (query_padding_mask == key_padding_mask).all()
        qkv_unpad = torch.stack([q_unpad, k_unpad, v_unpad], dim=1)
        qkv = rearrange(torch.stack([q, k, v], dim=2), 'b s t (h d) -> b s t h d', h=nheads)
        if query_padding_mask is not None:
            dqkv_pad_fn = lambda dqkv_unpad: rearrange(
                pad_input(rearrange(dqkv_unpad, 'nnz t h d -> nnz (t h d)'), indices_q, batch_size, seqlen),
                'b s (t h d) -> b s t h d', t=3, h=nheads
            )
        else:
            dqkv_pad_fn = lambda dqkv_unpad: rearrange(dqkv_unpad, '(b s) t h d -> b s t h d', b=batch_size)
        return (qkv_unpad.detach().requires_grad_(), cu_seqlens_q, max_seqlen_q,
                qkv.detach().requires_grad_(), output_pad_fn, dqkv_pad_fn)
    elif kvpacked:
        kv_unpad = torch.stack([k_unpad, v_unpad], dim=1)
        q = rearrange(q, 'b s (h d) -> b s h d', h=nheads)
        kv = rearrange(torch.stack([k, v], dim=2), 'b s t (h d) -> b s t h d', h=nheads)
        # dq has the same layout as the attention output, so the same pad function applies.
        dq_pad_fn = output_pad_fn
        if key_padding_mask is not None:
            dkv_pad_fn = lambda dkv_unpad: rearrange(
                pad_input(rearrange(dkv_unpad, 'nnz t h d -> nnz (t h d)'), indices_k, batch_size, seqlen),
                'b s (t h d) -> b s t h d', t=2, h=nheads
            )
        else:
            dkv_pad_fn = lambda dkv_unpad: rearrange(dkv_unpad, '(b s) t h d -> b s t h d', b=batch_size)
        return (q_unpad.detach().requires_grad_(), kv_unpad.detach().requires_grad_(),
                cu_seqlens_q, cu_seqlens_k, max_seqlen_q, max_seqlen_k,
                q.detach().requires_grad_(), kv.detach().requires_grad_(),
                output_pad_fn, dq_pad_fn, dkv_pad_fn)
    else:
        q, k, v = [rearrange(z, 'b s (h d) -> b s h d', h=nheads).detach().requires_grad_()
                   for z in [q, k, v]]
        dq_pad_fn = output_pad_fn
        if key_padding_mask is not None:
            dk_pad_fn = lambda dk_unpad: rearrange(
                pad_input(rearrange(dk_unpad, 'nnz h d -> nnz (h d)'), indices_k, batch_size, seqlen),
                'b s (h d) -> b s h d', h=nheads
            )
        else:
            dk_pad_fn = lambda dk_unpad: rearrange(dk_unpad, '(b s) h d -> b s h d', b=batch_size)
        return (q_unpad.detach().requires_grad_(), k_unpad.detach().requires_grad_(),
                v_unpad.detach().requires_grad_(),
                cu_seqlens_q, cu_seqlens_k, max_seqlen_q, max_seqlen_k,
                q, k, v,
                output_pad_fn, dq_pad_fn, dk_pad_fn)


def attention_ref(q, k, v, query_padding_mask=None, key_padding_mask=None, dropout_p=0.0,
                  dropout_mask=None, causal=False, upcast=True, reorder_ops=False):
    """Reference (non-fused) scaled dot-product attention.

    Arguments:
        q: (batch_size, seqlen_q, nheads, head_dim)
        k: (batch_size, seqlen_k, nheads, head_dim)
        v: (batch_size, seqlen_k, nheads, head_dim)
        query_padding_mask: (batch_size, seqlen_q)
        key_padding_mask: (batch_size, seqlen_k)
        dropout_p: float
        dropout_mask: (batch_size, nheads, seqlen_q, seqlen_k), bool. Entries that are False
            are zeroed in the attention matrix. If None, no entries are dropped.
        causal: whether to mask out keys at positions after the query position.
        upcast: whether to cast all inputs to fp32, do all computation in fp32, then cast
            output back to fp16/bf16.
        reorder_ops: whether to change the order of operations (scaling k instead of scaling q, etc.)
            without changing the math. This is to estimate the numerical error from operation
            reordering.
    Output:
        output: (batch_size, seqlen_q, nheads, head_dim)
        attention: (batch_size, nheads, seqlen_q, seqlen_k), softmax probabilities. Dropout
            is applied only when computing output, not to this tensor.
    """
    dtype_og = q.dtype
    if upcast:
        q, k, v = q.float(), k.float(), v.float()
    seqlen_q, seqlen_k = q.shape[1], k.shape[1]
    d = q.shape[-1]
    if not reorder_ops:
        scores = torch.einsum('bthd,bshd->bhts', q / math.sqrt(d), k)
    else:
        scores = torch.einsum('bthd,bshd->bhts', q, k / math.sqrt(d))
    if key_padding_mask is not None:
        scores.masked_fill_(rearrange(~key_padding_mask, 'b s -> b 1 1 s'), float('-inf'))
    if causal:
        causal_mask = torch.triu(torch.ones(seqlen_q, seqlen_k, dtype=torch.bool, device=q.device), 1)
        scores.masked_fill_(causal_mask, float('-inf'))
    attention = torch.softmax(scores, dim=-1)
    dropout_scaling = 1.0 / (1 - dropout_p)
    if dropout_mask is not None:
        attention_drop = attention.masked_fill(~dropout_mask, 0.0)
    else:
        # Without a mask nothing is dropped; previously this path left attention_drop unbound.
        attention_drop = attention
    # The dropout rescaling is applied to v rather than to the attention matrix.
    output = torch.einsum('bhts,bshd->bthd', attention_drop, v * dropout_scaling)
    if query_padding_mask is not None:
        output.masked_fill_(rearrange(~query_padding_mask, 'b s -> b s 1 1'), 0.0)
        attention = attention.masked_fill(rearrange(~query_padding_mask, 'b s -> b 1 s 1'), 0.0)
    return output.to(dtype=dtype_og), attention.to(dtype=dtype_og)


def attention_kvpacked_ref(q, kv, query_padding_mask=None, key_padding_mask=None, dropout_p=0.0,
                           dropout_mask=None, causal=False, upcast=True, reorder_ops=False):
    """Reference attention for a packed kv tensor.

    k and v are taken from index 0 and 1 of dimension 2 of kv; everything else is
    forwarded unchanged to attention_ref, whose outputs are returned as is.
    """
    k = kv[:, :, 0]
    v = kv[:, :, 1]
    return attention_ref(
        q, k, v, query_padding_mask, key_padding_mask, dropout_p, dropout_mask,
        causal=causal, upcast=upcast, reorder_ops=reorder_ops,
    )


def attention_qkvpacked_ref(qkv, key_padding_mask=None, dropout_p=0.0,
                            dropout_mask=None, causal=False, upcast=True, reorder_ops=False):
    """Reference attention for a packed qkv tensor.

    q, k and v are taken from index 0, 1 and 2 of dimension 2 of qkv. The same
    key_padding_mask is used as both the query and the key padding mask. Everything
    else is forwarded unchanged to attention_ref, whose outputs are returned as is.
    """
    q = qkv[:, :, 0]
    k = qkv[:, :, 1]
    v = qkv[:, :, 2]
    return attention_ref(
        q, k, v, key_padding_mask, key_padding_mask, dropout_p, dropout_mask,
        causal=causal, upcast=upcast, reorder_ops=reorder_ops,
    )


def generate_sparsity_mask(seqlen, sparsity=0.3, device='cuda'):
    """Generate a random boolean block mask.

    Arguments:
        seqlen: int, sequence length the mask is built for.
        sparsity: float, probability that an entry of the mask is True.
        device: device on which the mask is created (defaults to 'cuda').
    Output:
        mask: (seqlen // 16, seqlen // 256), bool
    """
    nrow, ncol = seqlen // 16, seqlen // 256
    # torch.rand samples uniformly from [0, 1), so each entry is True with probability `sparsity`.
    return torch.rand(nrow, ncol, device=device) < sparsity


def attention_blocksparse_ref(qkv, blockmask, attn_mask, dropout_p, dropout_mask):
    """Reference block-sparse attention, computed in fp32.

    Arguments:
        qkv: (batch_size, seqlen, 3, nheads, head_dim)
        blockmask: (seqlen / 16, seqlen / 256)
        attn_mask: (batch_size, seqlen)
        dropout_p: float
        dropout_mask: (batch_size, nheads, seqlen, seqlen)
    Output:
        output: (batch_size, seqlen, nheads, head_dim)
        attention: softmax probabilities with padded query rows and masked-out blocks
            set to zero (dropout is only applied when computing output)
    """
    seqlen, head_dim = qkv.shape[1], qkv.shape[-1]
    q, k, v = qkv.float().unbind(dim=2)

    # Expand each block-mask entry to a 16 x 256 tile and crop to the sequence length.
    dense_blockmask = repeat(blockmask, 's_16 s_256 -> (s_16 16) (s_256 256)')[:seqlen, :seqlen]
    blocked = rearrange(~dense_blockmask, 't s -> 1 1 t s')
    padded_keys = rearrange(~attn_mask, 'b s -> b 1 1 s')
    padded_queries = rearrange(~attn_mask, 'b s -> b 1 s 1')

    scores = torch.einsum('bthd,bshd->bhts', q / math.sqrt(head_dim), k)
    scores = scores.masked_fill(padded_keys, float('-inf'))
    scores = scores.masked_fill(blocked, float('-inf'))

    attention = torch.softmax(scores, dim=-1)
    attention = attention.masked_fill(padded_queries, 0.0)
    attention = attention.masked_fill(blocked, 0.0)

    attention_drop = attention.masked_fill(~dropout_mask, 0.0) / (1 - dropout_p)
    output = torch.einsum('bhts,bshd->bthd', attention_drop, v)
    output = output.masked_fill(rearrange(~attn_mask, 'b s -> b s 1 1'), 0)
    return output.to(dtype=qkv.dtype), attention.to(dtype=qkv.dtype)


def convert_flash_attn_S_to_softmax(S, query_padding_mask, key_padding_mask, head_dim, is_dropout,
                                    causal=False):
    """FlashAttention stores the S matrix in a different way.

    Reorders the elements of S into a (query, key) layout, zeroes every entry that is
    masked out (padding and, optionally, causal), and crops / zero-pads the result to the
    widths of the two padding masks.

    Arguments:
        S: (batch_size, nheads, seqlen_q, seqlen_k)
        query_padding_mask: (batch_size, seqlen_q)
        key_padding_mask: (batch_size, seqlen_k)
        head_dim: int, passed to _get_block_size
        is_dropout: bool, passed to _get_block_size
        causal: whether to also zero the entries above the diagonal
    Output:
        S_converted: (batch_size, nheads, query_padding_mask.shape[-1], key_padding_mask.shape[-1])
    """
    S_flat = rearrange(S, 'b h t s -> b h (t s)')
    seqlen_q, seqlen_k = S.shape[-2:]
    block_size = _get_block_size(S.device, head_dim, is_dropout)
    # Number of key blocks, rounded up.
    loop_steps = (seqlen_k + block_size - 1) // block_size
    warps_n = 4
    # NOTE(review): presumably the number of 16-wide tiles each warp handles along the key
    # dimension within one block -- confirm against the kernel.
    mmas_n = (seqlen_k // warps_n // 16) if seqlen_k <= block_size else (block_size // warps_n // 16)
    # Regroup the flat buffer so that rows index queries (nsteps r eight) and columns index
    # keys (loop mmas_n warps_n c0 t c1).
    S_converted = rearrange(S_flat, 'b h (loop nsteps mmas_n warps_n eight t r c0 c1) -> b h (nsteps r eight) (loop mmas_n warps_n c0 t c1)',
                            loop=loop_steps, nsteps=seqlen_q // 16, mmas_n=mmas_n, warps_n=warps_n, eight=8, t=4,
                            r=2, c0=2, c1=2)

    # Need to zero out things not in attention_mask in case S was initialized with random values
    # and some of those values aren't overwritten.
    # The masks may be narrower or wider than S: bring them to S's size first (padding with
    # False, or cropping), apply them, then bring the result back to the masks' original size.
    seqlen_q_og = query_padding_mask.shape[-1]
    if seqlen_q_og < seqlen_q:
        query_padding_mask = F.pad(query_padding_mask, (0, seqlen_q - seqlen_q_og))
    else:
        query_padding_mask = query_padding_mask[:, :seqlen_q]
    S_converted = S_converted.masked_fill(rearrange(~query_padding_mask, 'b s -> b 1 s 1'), 0.0)
    seqlen_k_og = key_padding_mask.shape[-1]
    if seqlen_k_og < seqlen_k:
        key_padding_mask = F.pad(key_padding_mask, (0, seqlen_k - seqlen_k_og))
    else:
        key_padding_mask = key_padding_mask[:, :seqlen_k]
    S_converted = S_converted.masked_fill(rearrange(~key_padding_mask, 'b s -> b 1 1 s'), 0.0)
    if causal:
        # True strictly above the diagonal, i.e. keys that come after the query.
        causal_mask = torch.triu(torch.ones(seqlen_q, seqlen_k, dtype=torch.bool, device=S.device), 1)
        S_converted.masked_fill_(causal_mask, 0.0)
    # Restore the original query length: crop extra rows, or append zero rows at the end.
    if seqlen_q_og < seqlen_q:
        S_converted = S_converted[:, :, :seqlen_q_og, :]
    else:
        S_converted = F.pad(S_converted, (0, 0, 0, seqlen_q_og - seqlen_q))
    # Restore the original key length: crop extra columns, or append zero columns at the end.
    if seqlen_k_og < seqlen_k:
        S_converted = S_converted[:, :, :, :seqlen_k_og]
    else:
        S_converted = F.pad(S_converted, (0, seqlen_k_og - seqlen_k))
    return S_converted


def normalize_flash_attn_S(attn_unnorm, q, k, v, query_padding_mask=None, key_padding_mask=None,
                           is_dropout=False, causal=False):
    """Rescale the unnormalized attention values so they can be compared to a softmax.

    The scores are recomputed from q and k in fp32, and each key block of attn_unnorm is
    divided by exp(lse - m), where lse is the row-wise logsumexp over all keys and m is a
    per-block running row maximum (see scores_max_block below).

    Arguments:
        attn_unnorm: (batch_size, nheads, seqlen_q, seqlen_k)
            # NOTE(review): presumably exp(score - m) as stored by the kernel -- confirm.
        q: (batch_size, seqlen_q, nheads, head_dim)
        k, v: (batch_size, seqlen_k, nheads, head_dim)
        query_padding_mask: (batch_size, seqlen_q)
        key_padding_mask: (batch_size, seqlen_k)
        is_dropout: bool, passed to _get_block_size
        causal: whether keys after the query position are masked out of the scores
    Output:
        attn_norm: same shape and dtype as attn_unnorm
    """
    q, k, v = q.float(), k.float(), v.float()
    _, seqlen_q, _, head_dim = q.shape
    seqlen_k = k.shape[1]
    scores = torch.einsum('bthd,bshd->bhts', q / math.sqrt(head_dim), k)
    if key_padding_mask is not None:
        scores.masked_fill_(rearrange(~key_padding_mask, 'b s -> b 1 1 s'), float('-inf'))
    if causal:
        causal_mask = torch.triu(torch.ones(seqlen_q, seqlen_k, dtype=torch.bool, device=q.device), 1)
        scores.masked_fill_(causal_mask, float('-inf'))
    block_size = _get_block_size(scores.device, head_dim, is_dropout)
    # Split the key dimension into blocks of block_size columns.
    scores_block = scores.split(block_size, dim=-1)
    # Row-wise logsumexp of each block, and its running (cumulative) logsumexp over blocks;
    # lcse_block[-1] is therefore the logsumexp over all keys.
    lse_block = torch.stack([torch.logsumexp(s, dim=-1) for s in scores_block], dim=-1)
    lcse_block = torch.logcumsumexp(lse_block, dim=-1).unbind(dim=-1)
    # Per-block reference value m: the row max for the first block; for later blocks the
    # larger of the block's row max and the cumulative logsumexp of all previous blocks.
    scores_max_block = ([torch.amax(scores_block[0], dim=-1)]
                        + [torch.maximum(torch.amax(s, dim=-1), lcse)
                           for s, lcse in zip(scores_block[1:], lcse_block[:-1])])
    attn_unnorm_block = attn_unnorm.split(block_size, dim=-1)
    attn_norm = torch.cat([a / rearrange(torch.exp(lcse_block[-1] - m), 'b h s -> b h s 1')
                           for a, m in zip(attn_unnorm_block, scores_max_block)], dim=-1)
    if query_padding_mask is not None:
        # Rows of padded queries are set to zero.
        attn_norm.masked_fill_(rearrange(~query_padding_mask, 'b s -> b 1 s 1'), 0.0)
    return attn_norm.to(dtype=attn_unnorm.dtype)


def get_dropout_fraction(dropout_mask, query_padding_mask=None, key_padding_mask=None, causal=False):
    """Fraction of the attended (query, key) entries that were dropped.

    dropout_mask: (batch_size, nheads, seqlen_q, seqlen_k), bool. True means keep, False means drop.
    query_padding_mask: (batch_size, seqlen_q)
    key_padding_mask: (batch_size, seqlen_k)
    causal: whether entries above the diagonal are excluded from the count.
    Returns a scalar tensor: number of dropped entries divided by the number of entries
    that take part in attention (summed over the batch, times nheads).
    """
    batch_size, nheads, seqlen_q, seqlen_k = dropout_mask.shape
    device = dropout_mask.device

    # Entries outside the padding / causal masks never count as dropped.
    dropped = ~dropout_mask
    if query_padding_mask is not None:
        dropped = dropped.masked_fill(rearrange(~query_padding_mask, 'b s -> b 1 s 1'), False)
    if key_padding_mask is not None:
        dropped = dropped.masked_fill(rearrange(~key_padding_mask, 'b s -> b 1 1 s'), False)
    if causal:
        above_diagonal = torch.triu(
            torch.ones(seqlen_q, seqlen_k, dtype=torch.bool, device=device), 1
        )
        dropped = dropped.masked_fill(above_diagonal, False)

    # Per-sequence lengths: from the masks when given, otherwise the full length.
    if query_padding_mask is None:
        query_lengths = torch.full((batch_size,), seqlen_q, device=device)
    else:
        query_lengths = query_padding_mask.sum(dim=-1)
    if key_padding_mask is None:
        key_lengths = torch.full((batch_size,), seqlen_k, device=device)
    else:
        key_lengths = key_padding_mask.sum(dim=-1)

    # Number of attended entries per batch element (for a single head).
    if causal:
        numel_per_batch = torch.where(
            query_lengths <= key_lengths,
            query_lengths * (query_lengths + 1) / 2,
            query_lengths * key_lengths - (key_lengths * (key_lengths - 1) / 2)
        )
    else:
        numel_per_batch = query_lengths * key_lengths

    return dropped.sum() / (numel_per_batch.sum() * nheads)


@pytest.mark.parametrize('dtype', ([torch.float16] if is_sm75 else [torch.float16, torch.bfloat16]))
# @pytest.mark.parametrize('dtype', [torch.float16])
@pytest.mark.parametrize('causal', [False, True])
# @pytest.mark.parametrize('causal', [False])
@pytest.mark.parametrize('d', [128, 64, 32, 16])
# @pytest.mark.parametrize('d', [64])
@pytest.mark.parametrize('seqlen', [97, 128, 200, 256, 257, 384, 512, 768, 1024, 1025, 2048])
# @pytest.mark.parametrize('seqlen', [128])
@pytest.mark.parametrize('dropout_p', [0.0, 0.17])
# @pytest.mark.parametrize('dropout_p', [0.0])
def test_flash_attn_unpadded_qkvpacked(seqlen, d, dropout_p, causal, dtype):
    """Compare flash_attn_unpadded_qkvpacked_func against the reference attention.

    The output, the attention probabilities and (when the backward pass is run) the
    gradient w.r.t. qkv must each be within twice the error of a plain PyTorch
    implementation run in the same dtype, both measured against the fp32 reference.
    The observed dropout rate must be within 1% of dropout_p.
    """
    if seqlen >= 2048 and torch.cuda.get_device_properties('cuda').total_memory <= 16 * 2**30:
        pytest.skip()  # Reference implementation OOM
    device = 'cuda'
    # if dtype == torch.float16:
    #     rtol, atol = (1e-3, 3e-4) if not causal else (1e-3, 1e-3)
    # else:  # torch.bfloat16
    #     rtol, atol = (3e-3, 3e-3) if not causal else (1e-3, 1e-3)
    # set seed
    torch.random.manual_seed(0)
    batch_size = 32
    nheads = 4
    x = torch.randn(batch_size, seqlen, nheads * d, device=device, dtype=dtype, requires_grad=True)
    Wqkv = torch.nn.Linear(nheads * d, 3 * nheads * d, device=device, dtype=dtype)

    key_padding_mask = generate_random_padding_mask(seqlen, batch_size, device, mode='random')
    # key_padding_mask = generate_random_padding_mask(seqlen, batch_size, device, mode='full')

    qkv_unpad, cu_seqlens, max_seqlen, qkv, output_pad_fn, dqkv_pad_fn = generate_qkv(
        x, Wqkv, nheads, key_padding_mask, key_padding_mask, qkvpacked=True
    )

    output_unpad, sm_lse, S_dmask = flash_attn_unpadded_qkvpacked_func(
        qkv_unpad, cu_seqlens, max_seqlen, dropout_p, return_attn_probs=True, causal=causal
    )
    output = output_pad_fn(output_unpad)
    S_dmask_converted = convert_flash_attn_S_to_softmax(
        S_dmask, key_padding_mask, key_padding_mask, d, dropout_p > 0.0, causal=causal
    )
    # The sign of S_dmask carries the dropout decision (negative = dropped) and its
    # magnitude carries the unnormalized attention value.
    dropout_mask = S_dmask_converted >= 0
    attn_unnorm = S_dmask_converted.abs()
    attn = normalize_flash_attn_S(attn_unnorm, qkv[:, :, 0], qkv[:, :, 1], qkv[:, :, 2],
                                  key_padding_mask, key_padding_mask, dropout_p > 0.0, causal=causal)
    dropout_fraction = get_dropout_fraction(dropout_mask, key_padding_mask, key_padding_mask,
                                            causal=causal).item()

    # *_ref: fp32 reference. *_pt: same math in the test dtype with reordered ops, used as
    # the baseline for how much numerical error is acceptable.
    output_ref, attn_ref = attention_qkvpacked_ref(qkv, key_padding_mask, dropout_p, dropout_mask,
                                                   causal=causal)
    output_pt, attn_pt = attention_qkvpacked_ref(qkv, key_padding_mask, dropout_p, dropout_mask,
                                                 causal=causal, upcast=False, reorder_ops=True)
    print(f'Actual dropout fraction: {dropout_fraction}')
    print(f'Output max diff: {(output - output_ref).abs().max().item()}')
    print(f'Output mean diff: {(output - output_ref).abs().mean().item()}')
    print(f'Pytorch max diff: {(output_pt - output_ref).abs().max().item()}')
    print(f'Pytorch mean diff: {(output_pt - output_ref).abs().mean().item()}')
    print(f'Attention max diff: {(attn - attn_ref).abs().max().item()}')
    print(f'Attention Pytorch max diff: {(attn_pt - attn_ref).abs().max().item()}')

    if is_sm80 or d < 128:  # Only run backward for d=128 on A100
        g = torch.randn_like(output)
        dqkv_unpad, = torch.autograd.grad(output, qkv_unpad, g)
        dqkv = dqkv_pad_fn(dqkv_unpad)
        dqkv_ref, = torch.autograd.grad(output_ref, qkv, g)
        dqkv_pt, = torch.autograd.grad(output_pt, qkv, g)
        print(f'dQ max diff: {(dqkv[:, :, 0] - dqkv_ref[:, :, 0]).abs().max().item()}')
        print(f'dK max diff: {(dqkv[:, :, 1] - dqkv_ref[:, :, 1]).abs().max().item()}')
        print(f'dV max diff: {(dqkv[:, :, 2] - dqkv_ref[:, :, 2]).abs().max().item()}')
        print(f'dQKV mean diff: {(dqkv - dqkv_ref).abs().mean().item()}')
        print(f'dQ Pytorch max diff: {(dqkv_pt[:, :, 0] - dqkv_ref[:, :, 0]).abs().max().item()}')
        print(f'dK Pytorch max diff: {(dqkv_pt[:, :, 1] - dqkv_ref[:, :, 1]).abs().max().item()}')
        print(f'dV Pytorch max diff: {(dqkv_pt[:, :, 2] - dqkv_ref[:, :, 2]).abs().max().item()}')
        print(f'dQKV Pytorch mean diff: {(dqkv_pt - dqkv_ref).abs().mean().item()}')

    # Check that FlashAttention's numerical error is at most twice the numerical error
    # of a Pytorch implementation.
    assert (output - output_ref).abs().max().item() <= 2 * (output_pt - output_ref).abs().max().item()
    # assert torch.allclose(output, output_ref, rtol=rtol, atol=atol)
    assert (attn - attn_ref).abs().max().item() <= 2 * (attn_pt - attn_ref).abs().max().item()
    # assert torch.allclose(attn, attn_ref, rtol=rtol, atol=atol)
    if dropout_p == 0.0:
        assert dropout_mask.all()
    else:
        assert 0.99 <= dropout_fraction / dropout_p <= 1.01

    if is_sm80 or d < 128:  # Only run backward for d=128 on A100
        assert (dqkv - dqkv_ref).abs().max().item() <= 2 * (dqkv_pt - dqkv_ref).abs().max().item()
        # assert torch.allclose(dqkv, dqkv_ref, rtol=rtol, atol=atol)


@pytest.mark.parametrize('dtype', ([torch.float16] if is_sm75 else [torch.float16, torch.bfloat16]))
# @pytest.mark.parametrize('dtype', [torch.float16])
@pytest.mark.parametrize('causal', [False, True])
@pytest.mark.parametrize('d', [128, 64, 32, 16])
# @pytest.mark.parametrize('d', [64])
@pytest.mark.parametrize('seqlen', [97, 128, 200, 256, 257, 384, 512, 768, 1024, 1025, 2048])
# @pytest.mark.parametrize('seqlen', [128])
@pytest.mark.parametrize('dropout_p', [0.0, 0.17])
# @pytest.mark.parametrize('dropout_p', [0.0])
def test_flash_attn_unpadded_kvpacked(seqlen, d, dropout_p, causal, dtype):
    """Compare flash_attn_unpadded_kvpacked_func against the reference attention.

    Queries and keys get independent random padding masks. The output, the attention
    probabilities and (when the backward pass is run) the gradients w.r.t. q and kv must
    each be within twice the error of a plain PyTorch implementation run in the same
    dtype, both measured against the fp32 reference. The observed dropout rate must be
    within 1% of dropout_p.
    """
    if seqlen >= 2048 and torch.cuda.get_device_properties('cuda').total_memory <= 16 * 2**30:
        pytest.skip()  # Reference implementation OOM
    device = 'cuda'
    # if dtype == torch.float16:
    #     rtol, atol = (1e-3, 3e-4) if not causal else (1e-3, 1e-3)
    # else:  # torch.bfloat16
    #     rtol, atol = (3e-3, 3e-3) if not causal else (1e-3, 1e-3)
    # set seed
    torch.random.manual_seed(0)
    batch_size = 32
    nheads = 4
    x = torch.randn(batch_size, seqlen, nheads * d, device=device, dtype=dtype, requires_grad=True)
    Wqkv = torch.nn.Linear(nheads * d, 3 * nheads * d, device=device, dtype=dtype)

    query_padding_mask = generate_random_padding_mask(seqlen, batch_size, device, mode='random')
    key_padding_mask = generate_random_padding_mask(seqlen, batch_size, device, mode='random')

    (q_unpad, kv_unpad, cu_seqlens_q, cu_seqlens_k, max_seqlen_q, max_seqlen_k, q, kv,
     output_pad_fn, dq_pad_fn, dkv_pad_fn) = generate_qkv(
         x, Wqkv, nheads, query_padding_mask, key_padding_mask, kvpacked=True
     )

    output_unpad, sm_lse, S_dmask = flash_attn_unpadded_kvpacked_func(
        q_unpad, kv_unpad, cu_seqlens_q, cu_seqlens_k, max_seqlen_q, max_seqlen_k,
        dropout_p, return_attn_probs=True, causal=causal
    )
    output = output_pad_fn(output_unpad)
    S_dmask_converted = convert_flash_attn_S_to_softmax(
        S_dmask, query_padding_mask, key_padding_mask, d, dropout_p > 0.0, causal=causal
    )
    # The sign of S_dmask carries the dropout decision (negative = dropped) and its
    # magnitude carries the unnormalized attention value.
    dropout_mask = S_dmask_converted >= 0
    attn_unnorm = S_dmask_converted.abs()
    attn = normalize_flash_attn_S(attn_unnorm, q, kv[:, :, 0], kv[:, :, 1],
                                  query_padding_mask, key_padding_mask, dropout_p > 0.0, causal=causal)
    # .item() so that a plain float is printed and compared, as in the qkvpacked test.
    dropout_fraction = get_dropout_fraction(dropout_mask, query_padding_mask, key_padding_mask,
                                            causal=causal).item()

    # *_ref: fp32 reference. *_pt: same math in the test dtype with reordered ops, used as
    # the baseline for how much numerical error is acceptable.
    output_ref, attn_ref = attention_kvpacked_ref(q, kv, query_padding_mask, key_padding_mask,
                                                  dropout_p, dropout_mask, causal=causal)
    output_pt, attn_pt = attention_kvpacked_ref(q, kv, query_padding_mask, key_padding_mask,
                                                dropout_p, dropout_mask, causal=causal,
                                                upcast=False, reorder_ops=True)
    print(f'Actual dropout fraction: {dropout_fraction}')
    print(f'Output max diff: {(output - output_ref).abs().max().item()}')
    print(f'Output mean diff: {(output - output_ref).abs().mean().item()}')
    print(f'Pytorch max diff: {(output_pt - output_ref).abs().max().item()}')
    print(f'Pytorch mean diff: {(output_pt - output_ref).abs().mean().item()}')
    print(f'Attention max diff: {(attn - attn_ref).abs().max().item()}')
    print(f'Attention Pytorch max diff: {(attn_pt - attn_ref).abs().max().item()}')

    if is_sm80 or d < 128:  # Only run backward for d=128 on A100
        g = torch.randn_like(output)
        dq_unpad, dkv_unpad, = torch.autograd.grad(output, (q_unpad, kv_unpad), g)
        dq = dq_pad_fn(dq_unpad)
        dkv = dkv_pad_fn(dkv_unpad)
        dq_ref, dkv_ref, = torch.autograd.grad(output_ref, (q, kv), g)
        dq_pt, dkv_pt = torch.autograd.grad(output_pt, (q, kv), g)
        print(f'dQ max diff: {(dq - dq_ref).abs().max().item()}')
        print(f'dK max diff: {(dkv[:, :, 0] - dkv_ref[:, :, 0]).abs().max().item()}')
        print(f'dV max diff: {(dkv[:, :, 1] - dkv_ref[:, :, 1]).abs().max().item()}')
        print(f'dQ Pytorch max diff: {(dq_pt - dq_ref).abs().max().item()}')
        print(f'dK Pytorch max diff: {(dkv_pt[:, :, 0] - dkv_ref[:, :, 0]).abs().max().item()}')
        print(f'dV Pytorch max diff: {(dkv_pt[:, :, 1] - dkv_ref[:, :, 1]).abs().max().item()}')

    # Check that FlashAttention's numerical error is at most twice the numerical error
    # of a Pytorch implementation.
    assert (output - output_ref).abs().max().item() <= 2 * (output_pt - output_ref).abs().max().item()
    # assert torch.allclose(output, output_ref, rtol=rtol, atol=atol)
    assert (attn - attn_ref).abs().max().item() <= 2 * (attn_pt - attn_ref).abs().max().item()
    # assert torch.allclose(attn, attn_ref, rtol=rtol, atol=atol)
    if dropout_p == 0.0:
        assert dropout_mask.all()
    else:
        assert 0.99 <= dropout_fraction / dropout_p <= 1.01

    if is_sm80 or d < 128:  # Only run backward for d=128 on A100
        assert (dq - dq_ref).abs().max().item() <= 2 * (dq_pt - dq_ref).abs().max().item()
        assert (dkv - dkv_ref).abs().max().item() <= 2 * (dkv_pt - dkv_ref).abs().max().item()
        # assert torch.allclose(dq, dq_ref, rtol=rtol, atol=atol)
        # assert torch.allclose(dkv, dkv_ref, rtol=rtol, atol=atol)


@pytest.mark.parametrize('dtype', ([torch.float16] if is_sm75 else [torch.float16, torch.bfloat16]))
# @pytest.mark.parametrize('dtype', [torch.float16])
@pytest.mark.parametrize('causal', [False, True])
@pytest.mark.parametrize('d', [128, 64, 32, 16])
# @pytest.mark.parametrize('d', [64])
@pytest.mark.parametrize('seqlen', [97, 128, 200, 256, 257, 384, 512, 768, 1024, 1025, 2048])
# @pytest.mark.parametrize('seqlen', [128])
@pytest.mark.parametrize('dropout_p', [0.0, 0.17])
# @pytest.mark.parametrize('dropout_p', [0.0])
def test_flash_attn_unpadded(seqlen, d, dropout_p, causal, dtype):
    """Compare flash_attn_unpadded_func (separate q, k, v) with the reference attention.

    Queries and keys get independently drawn random padding masks. The test
    checks the forward output, the attention probabilities recovered from the
    returned S_dmask, the observed dropout fraction and, when the backward pass
    is enabled for this configuration, the gradients w.r.t. q, k and v.

    Accuracy is judged relatively: the max error of FlashAttention against
    ``attention_ref`` must be at most twice the max error of the
    ``upcast=False, reorder_ops=True`` run of ``attention_ref`` against that
    same reference.
    """
    if seqlen >= 2048 and torch.cuda.get_device_properties('cuda').total_memory <= 16 * 2**30:
        pytest.skip('Reference implementation OOM')
    device = 'cuda'
    # if dtype == torch.float16:
    #     rtol, atol = (1e-3, 3e-4) if not causal else (1e-3, 1e-3)
    # else:  # torch.bfloat16
    #     rtol, atol = (3e-3, 3e-3) if not causal else (1e-3, 1e-3)
    # set seed
    torch.random.manual_seed(0)
    batch_size = 32
    nheads = 4
    # Only run backward for d=128 on A100
    run_backward = is_sm80 or d < 128
    x = torch.randn(batch_size, seqlen, nheads * d, device=device, dtype=dtype, requires_grad=True)
    Wqkv = torch.nn.Linear(nheads * d, 3 * nheads * d, device=device, dtype=dtype)

    query_padding_mask = generate_random_padding_mask(seqlen, batch_size, device, mode='random')
    key_padding_mask = generate_random_padding_mask(seqlen, batch_size, device, mode='random')

    (q_unpad, k_unpad, v_unpad, cu_seqlens_q, cu_seqlens_k, max_seqlen_q, max_seqlen_k, q, k, v,
     output_pad_fn, dq_pad_fn, dk_pad_fn) = generate_qkv(
         x, Wqkv, nheads, query_padding_mask, key_padding_mask
     )

    output_unpad, sm_lse, S_dmask = flash_attn_unpadded_func(
        q_unpad, k_unpad, v_unpad, cu_seqlens_q, cu_seqlens_k, max_seqlen_q, max_seqlen_k,
        dropout_p, return_attn_probs=True, causal=causal
    )
    output = output_pad_fn(output_unpad)
    S_dmask_converted = convert_flash_attn_S_to_softmax(
        S_dmask, query_padding_mask, key_padding_mask, d, dropout_p > 0.0, causal=causal
    )
    # Non-negative entries are treated as kept by dropout; the magnitude is the
    # unnormalized attention value.
    dropout_mask = S_dmask_converted >= 0
    attn_unnorm = S_dmask_converted.abs()
    attn = normalize_flash_attn_S(attn_unnorm, q, k, v, query_padding_mask, key_padding_mask,
                                  dropout_p > 0.0, causal=causal)
    # .item() gives a plain Python number, matching the qkv-packed tests in this file.
    dropout_fraction = get_dropout_fraction(dropout_mask, query_padding_mask, key_padding_mask,
                                            causal=causal).item()

    # The reference runs reuse FlashAttention's dropout mask so the results are comparable.
    output_ref, attn_ref = attention_ref(q, k, v, query_padding_mask, key_padding_mask,
                                         dropout_p, dropout_mask, causal=causal)
    output_pt, attn_pt = attention_ref(q, k, v, query_padding_mask, key_padding_mask,
                                       dropout_p, dropout_mask, causal=causal,
                                       upcast=False, reorder_ops=True)
    print(f'Actual dropout fraction: {dropout_fraction}')
    print(f'Output max diff: {(output - output_ref).abs().max().item()}')
    print(f'Output mean diff: {(output - output_ref).abs().mean().item()}')
    print(f'Pytorch max diff: {(output_pt - output_ref).abs().max().item()}')
    print(f'Pytorch mean diff: {(output_pt - output_ref).abs().mean().item()}')
    print(f'Attention max diff: {(attn - attn_ref).abs().max().item()}')
    print(f'Attention Pytorch max diff: {(attn_pt - attn_ref).abs().max().item()}')

    if run_backward:
        # The same upstream gradient g is fed to all three implementations.
        g = torch.randn_like(output)
        dq_unpad, dk_unpad, dv_unpad, = torch.autograd.grad(output, (q_unpad, k_unpad, v_unpad), g)
        dq = dq_pad_fn(dq_unpad)
        dk = dk_pad_fn(dk_unpad)
        # dV is re-padded with the same function as dK.
        dv = dk_pad_fn(dv_unpad)
        dq_ref, dk_ref, dv_ref, = torch.autograd.grad(output_ref, (q, k, v), g)
        dq_pt, dk_pt, dv_pt, = torch.autograd.grad(output_pt, (q, k, v), g)
        print(f'dQ max diff: {(dq - dq_ref).abs().max().item()}')
        print(f'dK max diff: {(dk - dk_ref).abs().max().item()}')
        print(f'dV max diff: {(dv - dv_ref).abs().max().item()}')
        print(f'dQ Pytorch max diff: {(dq_pt - dq_ref).abs().max().item()}')
        print(f'dK Pytorch max diff: {(dk_pt - dk_ref).abs().max().item()}')
        print(f'dV Pytorch max diff: {(dv_pt - dv_ref).abs().max().item()}')

    # Check that FlashAttention's numerical error is at most twice the numerical error
    # of a Pytorch implementation.
    assert (output - output_ref).abs().max().item() <= 2 * (output_pt - output_ref).abs().max().item()
    # assert torch.allclose(output, output_ref, rtol=rtol, atol=atol)
    assert (attn - attn_ref).abs().max().item() <= 2 * (attn_pt - attn_ref).abs().max().item()
    # assert torch.allclose(attn, attn_ref, rtol=rtol, atol=atol)
    if dropout_p == 0.0:
        assert dropout_mask.all()
    else:
        # The observed dropout rate must be within 1% of the requested one.
        assert 0.99 <= dropout_fraction / dropout_p <= 1.01

    if run_backward:
        assert (dq - dq_ref).abs().max().item() <= 2 * (dq_pt - dq_ref).abs().max().item()
        assert (dk - dk_ref).abs().max().item() <= 2 * (dk_pt - dk_ref).abs().max().item()
        assert (dv - dv_ref).abs().max().item() <= 2 * (dv_pt - dv_ref).abs().max().item()
        # assert torch.allclose(dq, dq_ref, rtol=rtol, atol=atol)
        # assert torch.allclose(dk, dk_ref, rtol=rtol, atol=atol)
        # assert torch.allclose(dv, dv_ref, rtol=rtol, atol=atol)


@pytest.mark.parametrize('dtype', ([torch.float16] if is_sm75 else [torch.float16, torch.bfloat16]))
# @pytest.mark.parametrize('dtype', [torch.float16])
@pytest.mark.parametrize('causal', [False, True])
# @pytest.mark.parametrize('causal', [False])
@pytest.mark.parametrize('d', [128, 64, 32, 16])
# @pytest.mark.parametrize('d', [64])
@pytest.mark.parametrize('seqlen', [512])
@pytest.mark.parametrize('dropout_p', [0.0, 0.17])
# @pytest.mark.parametrize('dropout_p', [0.0])
def test_flash_attn_split(seqlen, d, dropout_p, causal, dtype):
    """Compare flash_attn_unpadded_qkvpacked_split_func with the reference attention.

    The batch is generated with the 'split' padding mode: the first
    ``batch_size // 4 * 3`` sequences have length at least ``min(128, seqlen)``
    and the remaining ones have length at most 128. The split function returns
    one S_dmask per group; both are converted, the second is zero-padded to the
    shape of the first, and they are concatenated along the batch dimension so
    the usual whole-batch checks can be applied.

    Accuracy is judged relatively: the max error of FlashAttention against
    ``attention_qkvpacked_ref`` must be at most twice the max error of the
    ``upcast=False, reorder_ops=True`` run against that same reference.
    """
    if seqlen >= 2048 and torch.cuda.get_device_properties('cuda').total_memory <= 16 * 2**30:
        pytest.skip('Reference implementation OOM')
    device = 'cuda'
    # if dtype == torch.float16:
    #     rtol, atol = (1e-3, 3e-4) if not causal else (1e-3, 1e-3)
    # else:  # torch.bfloat16
    #     rtol, atol = (3e-3, 3e-3) if not causal else (1e-3, 1e-3)
    # set seed
    torch.random.manual_seed(0)
    batch_size = 32
    nheads = 4
    # Only run backward for d=128 on A100
    run_backward = is_sm80 or d < 128
    x = torch.randn(batch_size, seqlen, nheads * d, device=device, dtype=dtype, requires_grad=True)
    Wqkv = torch.nn.Linear(nheads * d, 3 * nheads * d, device=device, dtype=dtype)

    key_padding_mask = generate_random_padding_mask(seqlen, batch_size, device, mode='split')
    batch_size0 = batch_size // 4 * 3  # this must match what's in generate_random_padding_mask
    # key_padding_mask = generate_random_padding_mask(seqlen, batch_size, device, mode='full')

    qkv_unpad, cu_seqlens, max_seqlen0, qkv, output_pad_fn, dqkv_pad_fn = generate_qkv(
        x, Wqkv, nheads, key_padding_mask, key_padding_mask, qkvpacked=True
    )
    # Length bound for the second group; 'split' mode never draws a longer
    # sequence than 128 for the last batch_size - batch_size0 entries.
    max_seqlen1 = 128

    output_unpad, sm_lse, S_dmask0, S_dmask1 = flash_attn_unpadded_qkvpacked_split_func(
        qkv_unpad, cu_seqlens, max_seqlen0, max_seqlen1, batch_size0, dropout_p,
        return_attn_probs=True, causal=causal
    )
    output = output_pad_fn(output_unpad)
    S_dmask0_converted = convert_flash_attn_S_to_softmax(
        S_dmask0, key_padding_mask[:batch_size0], key_padding_mask[:batch_size0], d, dropout_p > 0.0, causal=causal
    )
    S_dmask1_converted = convert_flash_attn_S_to_softmax(
        S_dmask1, key_padding_mask[batch_size0:, :max_seqlen1], key_padding_mask[batch_size0:, :max_seqlen1], d, dropout_p > 0.0, causal=causal
    )
    # Zero-pad the second group's last two dimensions (at the end of each) up to
    # the first group's sizes so both can be stacked along the batch dimension.
    padding = (S_dmask0_converted.shape[-1] - S_dmask1_converted.shape[-1],
               S_dmask0_converted.shape[-2] - S_dmask1_converted.shape[-2])
    S_dmask_converted = torch.cat([S_dmask0_converted,
                                   F.pad(S_dmask1_converted, (0, padding[0], 0, padding[1]))], dim=0)
    # Non-negative entries are treated as kept by dropout; the magnitude is the
    # unnormalized attention value.
    dropout_mask = S_dmask_converted >= 0
    attn_unnorm = S_dmask_converted.abs()
    attn = normalize_flash_attn_S(attn_unnorm, qkv[:, :, 0], qkv[:, :, 1], qkv[:, :, 2],
                                  key_padding_mask, key_padding_mask, dropout_p > 0.0, causal=causal)
    dropout_fraction = get_dropout_fraction(dropout_mask, key_padding_mask, key_padding_mask,
                                            causal=causal).item()

    # The reference runs reuse FlashAttention's dropout mask so the results are comparable.
    output_ref, attn_ref = attention_qkvpacked_ref(qkv, key_padding_mask, dropout_p, dropout_mask,
                                                   causal=causal)
    output_pt, attn_pt = attention_qkvpacked_ref(qkv, key_padding_mask, dropout_p, dropout_mask,
                                                 causal=causal, upcast=False, reorder_ops=True)
    print(f'Actual dropout fraction: {dropout_fraction}')
    print(f'Output max diff: {(output - output_ref).abs().max().item()}')
    print(f'Output mean diff: {(output - output_ref).abs().mean().item()}')
    print(f'Pytorch max diff: {(output_pt - output_ref).abs().max().item()}')
    print(f'Pytorch mean diff: {(output_pt - output_ref).abs().mean().item()}')
    print(f'Attention max diff: {(attn - attn_ref).abs().max().item()}')
    print(f'Attention Pytorch max diff: {(attn_pt - attn_ref).abs().max().item()}')

    if run_backward:
        # The same upstream gradient g is fed to all three implementations.
        g = torch.randn_like(output)
        dqkv_unpad, = torch.autograd.grad(output, qkv_unpad, g)
        dqkv = dqkv_pad_fn(dqkv_unpad)
        dqkv_ref, = torch.autograd.grad(output_ref, qkv, g)
        dqkv_pt, = torch.autograd.grad(output_pt, qkv, g)
        print(f'dQ max diff: {(dqkv[:, :, 0] - dqkv_ref[:, :, 0]).abs().max().item()}')
        print(f'dK max diff: {(dqkv[:, :, 1] - dqkv_ref[:, :, 1]).abs().max().item()}')
        print(f'dV max diff: {(dqkv[:, :, 2] - dqkv_ref[:, :, 2]).abs().max().item()}')
        print(f'dQKV mean diff: {(dqkv - dqkv_ref).abs().mean().item()}')
        print(f'dQ Pytorch max diff: {(dqkv_pt[:, :, 0] - dqkv_ref[:, :, 0]).abs().max().item()}')
        print(f'dK Pytorch max diff: {(dqkv_pt[:, :, 1] - dqkv_ref[:, :, 1]).abs().max().item()}')
        print(f'dV Pytorch max diff: {(dqkv_pt[:, :, 2] - dqkv_ref[:, :, 2]).abs().max().item()}')
        print(f'dQKV Pytorch mean diff: {(dqkv_pt - dqkv_ref).abs().mean().item()}')

    # Check that FlashAttention's numerical error is at most twice the numerical error
    # of a Pytorch implementation.
    assert (output - output_ref).abs().max().item() <= 2 * (output_pt - output_ref).abs().max().item()
    # assert torch.allclose(output, output_ref, rtol=rtol, atol=atol)
    assert (attn - attn_ref).abs().max().item() <= 2 * (attn_pt - attn_ref).abs().max().item()
    # assert torch.allclose(attn, attn_ref, rtol=rtol, atol=atol)
    if dropout_p == 0.0:
        assert dropout_mask.all()
    else:
        # The observed dropout rate must be within 1% of the requested one.
        assert 0.99 <= dropout_fraction / dropout_p <= 1.01

    if run_backward:
        assert (dqkv - dqkv_ref).abs().max().item() <= 2 * (dqkv_pt - dqkv_ref).abs().max().item()
        # assert torch.allclose(dqkv, dqkv_ref, rtol=rtol, atol=atol)


@pytest.mark.parametrize('dtype', ([torch.float16] if is_sm75 else [torch.float16, torch.bfloat16]))
# @pytest.mark.parametrize('dtype', [torch.float16])
@pytest.mark.parametrize('causal', [False, True])
@pytest.mark.parametrize('d', [128, 64, 32, 16])
# @pytest.mark.parametrize('d', [64])
@pytest.mark.parametrize('seqlen', [97, 128, 200, 256, 257, 384, 512, 768, 1024, 1025, 2048])
# @pytest.mark.parametrize('seqlen', [128])
@pytest.mark.parametrize('dropout_p', [0.0, 0.17])
# @pytest.mark.parametrize('dropout_p', [0.0])
def test_flash_attn_race_condition(seqlen, d, dropout_p, causal, dtype):
    """Check that repeated runs of flash_attn_unpadded_func are bitwise identical.

    A baseline forward pass (and backward pass, when enabled for this
    configuration) is computed once. The same call is then repeated 10 times
    with the RNG re-seeded to the same value before each run, and the output,
    the converted S_dmask and the gradients are required to match the baseline
    exactly via ``torch.equal``. No reference implementation is involved.
    """
    if seqlen >= 2048 and torch.cuda.get_device_properties('cuda').total_memory <= 16 * 2**30:
        pytest.skip('Reference implementation OOM')
    device = 'cuda'
    # set seed
    torch.random.manual_seed(0)
    batch_size = 32
    nheads = 4
    # Only run backward for d=128 on A100
    run_backward = is_sm80 or d < 128
    x = torch.randn(batch_size, seqlen, nheads * d, device=device, dtype=dtype, requires_grad=True)
    Wqkv = torch.nn.Linear(nheads * d, 3 * nheads * d, device=device, dtype=dtype)

    query_padding_mask = generate_random_padding_mask(seqlen, batch_size, device, mode='random')
    key_padding_mask = generate_random_padding_mask(seqlen, batch_size, device, mode='random')

    # Only the unpadded tensors and sequence-length metadata are needed here;
    # the padded q/k/v and the re-padding helpers returned after them are unused.
    (q_unpad, k_unpad, v_unpad, cu_seqlens_q, cu_seqlens_k, max_seqlen_q, max_seqlen_k,
     *_) = generate_qkv(
         x, Wqkv, nheads, query_padding_mask, key_padding_mask
     )

    # Baseline run. The seed is reset right before the call so that every later
    # run starts from the same RNG state.
    torch.random.manual_seed(0)
    output_unpad_0, sm_lse_0, S_dmask_0 = flash_attn_unpadded_func(
        q_unpad, k_unpad, v_unpad, cu_seqlens_q, cu_seqlens_k, max_seqlen_q, max_seqlen_k,
        dropout_p, return_attn_probs=True, causal=causal
    )
    S_dmask_converted_0 = convert_flash_attn_S_to_softmax(
        S_dmask_0, query_padding_mask, key_padding_mask, d, dropout_p > 0.0, causal=causal
    )

    if run_backward:
        # g is reused as the upstream gradient for every repetition below.
        g = torch.randn_like(output_unpad_0)
        dq_unpad_0, dk_unpad_0, dv_unpad_0, = torch.autograd.grad(output_unpad_0,
                                                                  (q_unpad, k_unpad, v_unpad), g)

    for _ in range(10):
        torch.random.manual_seed(0)
        output_unpad, sm_lse, S_dmask = flash_attn_unpadded_func(
            q_unpad, k_unpad, v_unpad, cu_seqlens_q, cu_seqlens_k, max_seqlen_q, max_seqlen_k,
            dropout_p, return_attn_probs=True, causal=causal
        )
        S_dmask_converted = convert_flash_attn_S_to_softmax(
            S_dmask, query_padding_mask, key_padding_mask, d, dropout_p > 0.0, causal=causal
        )
        assert torch.equal(output_unpad, output_unpad_0)
        # sm_lse has some parts that are uninitialized from torch.empty
        # assert torch.equal(sm_lse, sm_lse_0)
        assert torch.equal(S_dmask_converted, S_dmask_converted_0)

        if run_backward:
            dq_unpad, dk_unpad, dv_unpad, = torch.autograd.grad(output_unpad,
                                                                (q_unpad, k_unpad, v_unpad), g)
            assert torch.equal(dq_unpad, dq_unpad_0)
            assert torch.equal(dk_unpad, dk_unpad_0)
            assert torch.equal(dv_unpad, dv_unpad_0)


@pytest.mark.skipif(torch.cuda.device_count() < 2, reason='requires multiple GPUs')
def test_flash_attn_multigpu():
    """Run one fixed qkv-packed configuration on the second GPU ('cuda:1').

    Uses seqlen=256, d=64, fp16, no dropout and no causal mask. The forward
    output, the recovered attention probabilities and the qkv gradient of
    flash_attn_unpadded_qkvpacked_func must each have a max error against
    attention_qkvpacked_ref of at most twice the max error of the
    ``upcast=False, reorder_ops=True`` run against that same reference.
    """

    def max_abs_diff(a, b):
        # Largest element-wise absolute difference, as a Python number.
        return (a - b).abs().max().item()

    seqlen, d = 256, 64
    dropout_p, causal = 0.0, False
    dtype, device = torch.float16, 'cuda:1'
    torch.random.manual_seed(0)
    batch_size, nheads = 32, 4
    hidden = torch.randn(batch_size, seqlen, nheads * d,
                         device=device, dtype=dtype, requires_grad=True)
    qkv_proj = torch.nn.Linear(nheads * d, 3 * nheads * d, device=device, dtype=dtype)

    # Use mode='full' instead to run without any padding.
    pad_mask = generate_random_padding_mask(seqlen, batch_size, device, mode='random')

    qkv_unpad, cu_seqlens, max_seqlen, qkv, output_pad_fn, dqkv_pad_fn = generate_qkv(
        hidden, qkv_proj, nheads, pad_mask, pad_mask, qkvpacked=True
    )

    out_unpad, _lse, s_dmask = flash_attn_unpadded_qkvpacked_func(
        qkv_unpad, cu_seqlens, max_seqlen, dropout_p, return_attn_probs=True, causal=causal
    )
    output = output_pad_fn(out_unpad)
    s_converted = convert_flash_attn_S_to_softmax(
        s_dmask, pad_mask, pad_mask, d, dropout_p > 0.0, causal=causal
    )
    # Non-negative entries are treated as kept by dropout; the magnitude is the
    # unnormalized attention value.
    dropout_mask = s_converted >= 0
    attn = normalize_flash_attn_S(
        s_converted.abs(), qkv[:, :, 0], qkv[:, :, 1], qkv[:, :, 2],
        pad_mask, pad_mask, dropout_p > 0.0, causal=causal
    )
    dropout_fraction = get_dropout_fraction(
        dropout_mask, pad_mask, pad_mask, causal=causal
    ).item()

    # Both reference runs reuse FlashAttention's dropout mask.
    output_ref, attn_ref = attention_qkvpacked_ref(
        qkv, pad_mask, dropout_p, dropout_mask, causal=causal
    )
    output_pt, attn_pt = attention_qkvpacked_ref(
        qkv, pad_mask, dropout_p, dropout_mask, causal=causal, upcast=False, reorder_ops=True
    )
    print(f'Actual dropout fraction: {dropout_fraction}')
    print(f'Output max diff: {(output - output_ref).abs().max().item()}')
    print(f'Output mean diff: {(output - output_ref).abs().mean().item()}')
    print(f'Pytorch max diff: {(output_pt - output_ref).abs().max().item()}')
    print(f'Pytorch mean diff: {(output_pt - output_ref).abs().mean().item()}')
    print(f'Attention max diff: {(attn - attn_ref).abs().max().item()}')
    print(f'Attention Pytorch max diff: {(attn_pt - attn_ref).abs().max().item()}')

    # One shared upstream gradient for all three implementations.
    grad_out = torch.randn_like(output)
    dqkv_flat, = torch.autograd.grad(output, qkv_unpad, grad_out)
    dqkv = dqkv_pad_fn(dqkv_flat)
    dqkv_ref, = torch.autograd.grad(output_ref, qkv, grad_out)
    dqkv_pt, = torch.autograd.grad(output_pt, qkv, grad_out)
    print(f'dQ max diff: {(dqkv[:, :, 0] - dqkv_ref[:, :, 0]).abs().max().item()}')
    print(f'dK max diff: {(dqkv[:, :, 1] - dqkv_ref[:, :, 1]).abs().max().item()}')
    print(f'dV max diff: {(dqkv[:, :, 2] - dqkv_ref[:, :, 2]).abs().max().item()}')
    print(f'dQKV mean diff: {(dqkv - dqkv_ref).abs().mean().item()}')
    print(f'dQ Pytorch max diff: {(dqkv_pt[:, :, 0] - dqkv_ref[:, :, 0]).abs().max().item()}')
    print(f'dK Pytorch max diff: {(dqkv_pt[:, :, 1] - dqkv_ref[:, :, 1]).abs().max().item()}')
    print(f'dV Pytorch max diff: {(dqkv_pt[:, :, 2] - dqkv_ref[:, :, 2]).abs().max().item()}')
    print(f'dQKV Pytorch mean diff: {(dqkv_pt - dqkv_ref).abs().mean().item()}')

    # FlashAttention may be at most twice as far from the reference as the
    # Pytorch implementation is.
    assert max_abs_diff(output, output_ref) <= 2 * max_abs_diff(output_pt, output_ref)
    assert max_abs_diff(attn, attn_ref) <= 2 * max_abs_diff(attn_pt, attn_ref)
    if dropout_p != 0.0:
        assert 0.99 <= dropout_fraction / dropout_p <= 1.01
    else:
        assert dropout_mask.all()

    assert max_abs_diff(dqkv, dqkv_ref) <= 2 * max_abs_diff(dqkv_pt, dqkv_ref)