"examples/community/img2img_inpainting.py" did not exist on "98c42134a5615e1c26f2cca70ff9a4c142850f65"
test_flash_attn.py 48.5 KB
Newer Older
Tri Dao's avatar
Tri Dao committed
1
2
3
4
5
6
7
8
9
import math

import torch
import torch.nn.functional as F

import pytest

from einops import rearrange, repeat

Tri Dao's avatar
Tri Dao committed
10
11
12
13
from flash_attn import flash_attn_func, flash_attn_kvpacked_func, flash_attn_qkvpacked_func
from flash_attn import flash_attn_varlen_qkvpacked_func, flash_attn_varlen_kvpacked_func
from flash_attn import flash_attn_varlen_func
from flash_attn.flash_attn_interface import _get_block_size
Tri Dao's avatar
Tri Dao committed
14
15
from flash_attn.bert_padding import unpad_input, pad_input, index_first_axis

Tri Dao's avatar
Tri Dao committed
16
17

# Head dims above this only get their backward pass checked on sm80 / sm90 (see the
# `d <= MAX_HEADDIM_SM8x or (is_sm80 or is_sm90)` guards in the tests below).
MAX_HEADDIM_SM8x = 192

# Compute capability of the current CUDA device, queried once. Without CUDA we fall back
# to (0, 0) so the module can still be imported (e.g. for test collection); every test
# below still needs a GPU to actually run.
_device_capability = (
    torch.cuda.get_device_capability('cuda') if torch.cuda.is_available() else (0, 0)
)
is_sm75 = _device_capability == (7, 5)
is_sm8x = _device_capability[0] == 8
is_sm80 = _device_capability == (8, 0)
is_sm90 = _device_capability == (9, 0)


def generate_random_padding_mask(max_seqlen, batch_size, device, mode='random'):
    """Build a boolean padding mask whose rows are each True on a prefix.

    Arguments:
        max_seqlen: int, padded sequence length (number of mask columns).
        batch_size: int, number of mask rows.
        device: torch device on which the mask is created.
        mode: how the per-row valid lengths are drawn:
            'full'   - every row has length max_seqlen (no padding);
            'random' - lengths drawn uniformly from [max(1, max_seqlen - 20), max_seqlen)
                       (torch.randint needs a non-empty range, so max_seqlen >= 2);
            'third'  - lengths drawn uniformly from [max_seqlen // 3, max_seqlen).
    Output:
        padding_mask: (batch_size, max_seqlen), bool. True marks a valid (non-padded) position.
    """
    assert mode in ['full', 'random', 'third']
    if mode == 'full':
        lengths = torch.full((batch_size, 1), max_seqlen, device=device, dtype=torch.int32)
    elif mode == 'random':
        lengths = torch.randint(max(1, max_seqlen - 20), max_seqlen, (batch_size, 1), device=device)
    else:  # mode == 'third'
        lengths = torch.randint(max_seqlen // 3, max_seqlen, (batch_size, 1), device=device)
    # Positions (max_seqlen,) compared with lengths (batch_size, 1) broadcast to
    # (batch_size, max_seqlen): row b is True exactly on its first lengths[b] columns.
    padding_mask = torch.arange(max_seqlen, device=device) < lengths
    return padding_mask


def generate_qkv(q, k, v, query_padding_mask=None, key_padding_mask=None,
                 kvpacked=False, qkvpacked=False):
    """
    Build the unpadded (varlen) and padded inputs used to test the FlashAttention interfaces.

    Arguments:
        q: (batch_size, seqlen_q, nheads, d)
        k: (batch_size, seqlen_k, nheads_k, d)
        v: (batch_size, seqlen_k, nheads_k, d)
        query_padding_mask: (batch_size, seqlen_q), bool, or None for no padding.
        key_padding_mask: (batch_size, seqlen_k), bool, or None for no padding.
        kvpacked: also stack k and v into kv (dim 2 of the padded tensor, dim 1 unpadded).
        qkvpacked: stack q, k and v into qkv. Requires nheads == nheads_k and the same
            padding for queries and keys (both masks None, or both given and equal).
    Output (depending on the packing flags):
        qkvpacked: (qkv_unpad, cu_seqlens_q, max_seqlen_q, qkv, output_pad_fn, dqkv_pad_fn)
        kvpacked:  (q_unpad, kv_unpad, cu_seqlens_q, cu_seqlens_k, max_seqlen_q, max_seqlen_k,
                    q, kv, output_pad_fn, dq_pad_fn, dkv_pad_fn)
        otherwise: (q_unpad, k_unpad, v_unpad, cu_seqlens_q, cu_seqlens_k, max_seqlen_q,
                    max_seqlen_k, q, k, v, output_pad_fn, dq_pad_fn, dk_pad_fn)
        All returned tensors are detached leaves with requires_grad=True. cu_seqlens_* come
        from unpad_input when a mask is given, else they are evenly spaced int32 offsets.
        The *_pad_fn callables map an unpadded (total_tokens, ...) tensor back to the padded
        (batch_size, seqlen, ...) layout.
    """
    assert not (kvpacked and qkvpacked)
    batch_size, seqlen_q, nheads, d = q.shape
    _, seqlen_k, nheads_k, _ = k.shape
    assert k.shape == (batch_size, seqlen_k, nheads_k, d)
    assert v.shape == (batch_size, seqlen_k, nheads_k, d)

    if query_padding_mask is not None:
        q_unpad, indices_q, cu_seqlens_q, max_seqlen_q = unpad_input(q, query_padding_mask)
        output_pad_fn = lambda output_unpad: pad_input(output_unpad, indices_q, batch_size, seqlen_q)
    else:
        q_unpad = rearrange(q, 'b s h d -> (b s) h d')
        cu_seqlens_q = torch.arange(0, (batch_size + 1) * seqlen_q, step=seqlen_q, dtype=torch.int32,
                                    device=q_unpad.device)
        max_seqlen_q = seqlen_q
        output_pad_fn = lambda output_unpad: rearrange(output_unpad, '(b s) h d -> b s h d', b=batch_size)

    if key_padding_mask is not None:
        k_unpad, indices_k, cu_seqlens_k, max_seqlen_k = unpad_input(k, key_padding_mask)
        v_unpad, _, _, _ = unpad_input(v, key_padding_mask)
    else:
        k_unpad = rearrange(k, 'b s h d -> (b s) h d')
        v_unpad = rearrange(v, 'b s h d -> (b s) h d')
        cu_seqlens_k = torch.arange(0, (batch_size + 1) * seqlen_k, step=seqlen_k, dtype=torch.int32,
                                    device=k_unpad.device)
        max_seqlen_k = seqlen_k

    if qkvpacked:
        # Packed qkv shares one padding for q, k and v. Compare with None explicitly:
        # `tensor == None` and `None == None` both give a plain bool, which has no .all().
        if query_padding_mask is None or key_padding_mask is None:
            assert query_padding_mask is None and key_padding_mask is None
        else:
            assert (query_padding_mask == key_padding_mask).all()
        assert nheads == nheads_k
        qkv_unpad = torch.stack([q_unpad, k_unpad, v_unpad], dim=1)
        qkv = torch.stack([q, k, v], dim=2)
        if query_padding_mask is not None:
            dqkv_pad_fn = lambda dqkv_unpad: pad_input(dqkv_unpad, indices_q, batch_size, seqlen_q)
        else:
            dqkv_pad_fn = lambda dqkv_unpad: rearrange(dqkv_unpad, '(b s) t h d -> b s t h d', b=batch_size)
        return (qkv_unpad.detach().requires_grad_(), cu_seqlens_q, max_seqlen_q,
                qkv.detach().requires_grad_(), output_pad_fn, dqkv_pad_fn)
    elif kvpacked:
        kv_unpad = torch.stack([k_unpad, v_unpad], dim=1)
        kv = torch.stack([k, v], dim=2)
        dq_pad_fn = output_pad_fn
        if key_padding_mask is not None:
            dkv_pad_fn = lambda dkv_unpad: pad_input(dkv_unpad, indices_k, batch_size, seqlen_k)
        else:
            dkv_pad_fn = lambda dkv_unpad: rearrange(dkv_unpad, '(b s) t h d -> b s t h d', b=batch_size)
        return (q_unpad.detach().requires_grad_(), kv_unpad.detach().requires_grad_(),
                cu_seqlens_q, cu_seqlens_k, max_seqlen_q, max_seqlen_k,
                q.detach().requires_grad_(), kv.detach().requires_grad_(),
                output_pad_fn, dq_pad_fn, dkv_pad_fn)
    else:
        dq_pad_fn = output_pad_fn
        if key_padding_mask is not None:
            dk_pad_fn = lambda dk_unpad: pad_input(dk_unpad, indices_k, batch_size, seqlen_k)
        else:
            dk_pad_fn = lambda dk_unpad: rearrange(dk_unpad, '(b s) h d -> b s h d', b=batch_size)
        return (q_unpad.detach().requires_grad_(), k_unpad.detach().requires_grad_(),
                v_unpad.detach().requires_grad_(),
                cu_seqlens_q, cu_seqlens_k, max_seqlen_q, max_seqlen_k,
                q.detach().requires_grad_(), k.detach().requires_grad_(),
                v.detach().requires_grad_(),
                output_pad_fn, dq_pad_fn, dk_pad_fn)


def attention_ref(q, k, v, query_padding_mask=None, key_padding_mask=None, dropout_p=0.0,
                  dropout_mask=None, causal=False, upcast=True, reorder_ops=False):
    """
    Reference (plain PyTorch) attention used to check FlashAttention results.

    Arguments:
        q: (batch_size, seqlen_q, nheads, head_dim)
        k: (batch_size, seqlen_k, nheads_k, head_dim)
        v: (batch_size, seqlen_k, nheads_k, head_dim)
        query_padding_mask: (batch_size, seqlen_q), bool, True = valid token
        key_padding_mask: (batch_size, seqlen_k), bool, True = valid token
        dropout_p: float; the output is rescaled by 1 / (1 - dropout_p)
        dropout_mask: (batch_size, nheads, seqlen_q, seqlen_k), bool, True = keep
        causal: whether query i may only attend to keys j <= i (mask aligned to the
            top-left corner of the seqlen_q x seqlen_k score matrix).
        upcast: whether to cast all inputs to fp32, do all computation in fp32, then cast
            output back to fp16/bf16.
        reorder_ops: whether to change the order of operations (scaling k instead of scaling q, etc.)
            without changing the math. This is to estimate the numerical error from operation
            reordering.
    Output:
        output: (batch_size, seqlen_q, nheads, head_dim), zero at padded query positions
        attention: (batch_size, nheads, seqlen_q, seqlen_k), softmax *before* dropout, with
            padded query rows zeroed
    """
    dtype_og = q.dtype
    if upcast:
        q, k, v = q.float(), k.float(), v.float()
    seqlen_q, seqlen_k = q.shape[1], k.shape[1]
    # MQA / GQA: each k/v head is shared by nheads // nheads_k consecutive query heads.
    k = repeat(k, "b s h d -> b s (h g) d", g=q.shape[2] // k.shape[2])
    v = repeat(v, "b s h d -> b s (h g) d", g=q.shape[2] // v.shape[2])
    d = q.shape[-1]
    if not reorder_ops:
        scores = torch.einsum('bthd,bshd->bhts', q / math.sqrt(d), k)
    else:
        scores = torch.einsum('bthd,bshd->bhts', q, k / math.sqrt(d))
    if key_padding_mask is not None:
        scores.masked_fill_(rearrange(~key_padding_mask, 'b s -> b 1 1 s'), float('-inf'))
    if causal:
        causal_mask = torch.triu(torch.ones(seqlen_q, seqlen_k, dtype=torch.bool, device=q.device), 1)
        scores.masked_fill_(causal_mask, float('-inf'))
    attention = torch.softmax(scores, dim=-1)
    dropout_scaling = 1.0 / (1 - dropout_p)
    if dropout_mask is not None:
        attention_drop = attention.masked_fill(~dropout_mask, 0.0)
    else:
        attention_drop = attention
    # The dropout rescaling is applied to v rather than to the attention matrix; the
    # product is mathematically the same.
    output = torch.einsum('bhts,bshd->bthd', attention_drop, v * dropout_scaling)
    if query_padding_mask is not None:
        output.masked_fill_(rearrange(~query_padding_mask, 'b s -> b s 1 1'), 0.0)
        attention = attention.masked_fill(rearrange(~query_padding_mask, 'b s -> b 1 s 1'), 0.0)
    return output.to(dtype=dtype_og), attention.to(dtype=dtype_og)


def attention_kvpacked_ref(q, kv, query_padding_mask=None, key_padding_mask=None, dropout_p=0.0,
                           dropout_mask=None, causal=False, upcast=True, reorder_ops=False):
    """attention_ref for a packed key/value tensor.

    Arguments:
        kv: (batch_size, seqlen_k, 2, nheads_k, head_dim); kv[:, :, 0] holds the keys and
            kv[:, :, 1] the values.
    All other arguments and both outputs are exactly those of attention_ref.
    """
    keys = kv[:, :, 0]
    values = kv[:, :, 1]
    return attention_ref(
        q,
        keys,
        values,
        query_padding_mask=query_padding_mask,
        key_padding_mask=key_padding_mask,
        dropout_p=dropout_p,
        dropout_mask=dropout_mask,
        causal=causal,
        upcast=upcast,
        reorder_ops=reorder_ops,
    )


def attention_qkvpacked_ref(qkv, key_padding_mask=None, dropout_p=0.0,
                            dropout_mask=None, causal=False, upcast=True, reorder_ops=False):
    """attention_ref for a packed query/key/value tensor (self-attention).

    Arguments:
        qkv: (batch_size, seqlen, 3, nheads, head_dim); qkv[:, :, 0|1|2] are q, k, v.
        key_padding_mask: (batch_size, seqlen); used as both the query and the key mask,
            since queries and keys are the same tokens.
    All other arguments and both outputs are exactly those of attention_ref.
    """
    queries, keys, values = qkv[:, :, 0], qkv[:, :, 1], qkv[:, :, 2]
    shared_mask = key_padding_mask
    return attention_ref(
        queries,
        keys,
        values,
        query_padding_mask=shared_mask,
        key_padding_mask=shared_mask,
        dropout_p=dropout_p,
        dropout_mask=dropout_mask,
        causal=causal,
        upcast=upcast,
        reorder_ops=reorder_ops,
    )


def generate_sparsity_mask(seqlen, sparsity=0.3, device='cuda'):
    """Draw a random block mask for attention_blocksparse_ref.

    Arguments:
        seqlen: int, sequence length.
        sparsity: float. Despite the name, this is the probability that a block is True
            (kept / attended); a block is kept when a uniform [0, 1) draw is below it.
        device: device on which to create the mask (defaults to 'cuda', as before).
    Output:
        mask: (seqlen // 16, seqlen // 256), bool. Entry [i, j] stands for the 16 x 256 tile
            of (query, key) positions [16 * i, 16 * (i + 1)) x [256 * j, 256 * (j + 1)).
    """
    nrow, ncol = seqlen // 16, seqlen // 256
    mask = torch.rand(nrow, ncol, device=device) < sparsity
    return mask


def attention_blocksparse_ref(qkv, blockmask, attn_mask, dropout_p, dropout_mask):
    """
    Reference (plain PyTorch, fp32) block-sparse self-attention.

    Arguments:
        qkv: (batch_size, seqlen, 3, nheads, head_dim)
        blockmask: (seqlen / 16, seqlen / 256), bool. Each entry covers a 16 x 256 tile of the
            (query, key) score matrix; True keeps the tile, False masks it out.
        attn_mask: (batch_size, seqlen), bool padding mask, True = valid token
        dropout_p: float; kept probabilities are rescaled by 1 / (1 - dropout_p)
        dropout_mask: (batch_size, nheads, seqlen, seqlen), bool, True = keep (must not be None)
    Output:
        output: (batch_size, seqlen, nheads, head_dim), zero at padded query positions
        attention: (batch_size, nheads, seqlen, seqlen), softmax *before* dropout, with padded
            query rows and masked-out tiles zeroed
    """
    q, k, v = qkv.float().unbind(dim=2)
    d = qkv.shape[-1]
    seqlen = qkv.shape[1]
    scores = torch.einsum('bthd,bshd->bhts', q / math.sqrt(d), k)
    scores.masked_fill_(rearrange(~attn_mask, 'b s -> b 1 1 s'), float('-inf'))
    # Expand each block entry to its 16 x 256 tile, then crop to the actual sequence length.
    blockmask = repeat(blockmask, 's_16 s_256 -> (s_16 16) (s_256 256)')
    blockmask = blockmask[:seqlen, :seqlen]
    scores.masked_fill_(rearrange(~blockmask, 't s -> 1 1 t s'), float('-inf'))
    attention = torch.softmax(scores, dim=-1)
    attention = attention.masked_fill(rearrange(~attn_mask, 'b s -> b 1 s 1'), 0.0)
    attention = attention.masked_fill_(rearrange(~blockmask, 't s -> 1 1 t s'), 0.0)
    attention_drop = attention.masked_fill(~dropout_mask, 0.0) / (1 - dropout_p)
    output = torch.einsum('bhts,bshd->bthd', attention_drop, v)
    output.masked_fill_(rearrange(~attn_mask, 'b s -> b s 1 1'), 0)
    return output.to(dtype=qkv.dtype), attention.to(dtype=qkv.dtype)


def convert_flash_attn_S_to_softmax(S, query_padding_mask, key_padding_mask, head_dim, is_dropout,
                                    causal=False):
    """FlashAttention stores the S matrix in a different way; convert it to a plain
    (query, key) layout.

    Arguments:
        S: (batch_size, nheads, seqlen_q_rounded, seqlen_k_rounded)
        query_padding_mask: (batch_size, seqlen_q), or None
        key_padding_mask: (batch_size, seqlen_k), or None
        head_dim: int; together with is_dropout and causal, selects the tile size
            (blocksize_m, blocksize_n) via _get_block_size.
        is_dropout: bool
        causal: bool; if True, entries above the diagonal are zeroed.
    Output:
        S_converted: (batch_size, nheads, seqlen_q, seqlen_k). The lengths are the mask
            lengths when masks are given, else the rounded lengths of S. Entries outside the
            masks are zeroed, and rows/columns beyond S's extent are zero-padded.
    """
    seqlen_q, seqlen_k = S.shape[-2:]
    warps_n = 4
    blocksize_m, blocksize_n = _get_block_size(S.device, head_dim, is_dropout, causal)
    mmas_n = (blocksize_n + 16 - 1) // 16
    # Split S into (blocksize_m x blocksize_n) tiles, then undo the within-tile element order.
    # NOTE(review): this order presumably mirrors how the kernel writes its MMA fragments;
    # the rearrange pattern below is the only definition of it here.
    S_flat = rearrange(S, 'b h (nblocks_m blocksize_m) (nblocks_n blocksize_n) -> b h nblocks_m nblocks_n (blocksize_m blocksize_n)',
                       blocksize_m=blocksize_m, blocksize_n=blocksize_n)
    S_converted = rearrange(S_flat, 'b h nblocks_m nblocks_n (mmas_n mmas_m warps_n eight four c2 c1 c0) -> b h (nblocks_m mmas_m warps_n c1 eight) (nblocks_n mmas_n c2 four c0)',
                            mmas_n=mmas_n, warps_n=warps_n, eight=8, c0=2, c1=2, c2=2, four=4)
    if causal:
        causal_mask = torch.triu(torch.ones(seqlen_q, seqlen_k, dtype=torch.bool, device=S.device), 1)
        S_converted.masked_fill_(causal_mask, 0.0)

    # Need to zero out things not in attention_mask in case S was initialized with random values
    # and some of those values aren't overwritten.
    seqlen_q_og = query_padding_mask.shape[-1] if query_padding_mask is not None else seqlen_q
    if query_padding_mask is not None:
        if seqlen_q_og < seqlen_q:
            query_padding_mask = F.pad(query_padding_mask, (0, seqlen_q - seqlen_q_og))
        else:
            query_padding_mask = query_padding_mask[:, :seqlen_q]
        S_converted = S_converted.masked_fill(rearrange(~query_padding_mask, 'b s -> b 1 s 1'), 0.0)
    seqlen_k_og = key_padding_mask.shape[-1] if key_padding_mask is not None else seqlen_k
    if key_padding_mask is not None:
        if seqlen_k_og < seqlen_k:
            key_padding_mask = F.pad(key_padding_mask, (0, seqlen_k - seqlen_k_og))
        else:
            key_padding_mask = key_padding_mask[:, :seqlen_k]
        S_converted = S_converted.masked_fill(rearrange(~key_padding_mask, 'b s -> b 1 1 s'), 0.0)
    # Crop the rounded lengths down to the mask lengths, or zero-pad up to them.
    if seqlen_q_og < seqlen_q:
        S_converted = S_converted[:, :, :seqlen_q_og, :]
    else:
        S_converted = F.pad(S_converted, (0, 0, 0, seqlen_q_og - seqlen_q))
    if seqlen_k_og < seqlen_k:
        S_converted = S_converted[:, :, :, :seqlen_k_og]
    else:
        S_converted = F.pad(S_converted, (0, seqlen_k_og - seqlen_k))
    return S_converted


def normalize_flash_attn_S(attn_unnorm, q, k, v, query_padding_mask=None, key_padding_mask=None,
                           is_dropout=False, causal=False):
    """
    Turn FlashAttention's unnormalized attention probabilities into softmax probabilities.

    Arguments:
        attn_unnorm: (batch_size, nheads, seqlen_q, seqlen_k), unnormalized probabilities
            (in the tests: the absolute value of the converted S matrix).
        q: (batch_size, seqlen_q, nheads, head_dim)
        k, v: (batch_size, seqlen_k, nheads, head_dim); v is not used.
        query_padding_mask: (batch_size, seqlen_q), or None
        key_padding_mask: (batch_size, seqlen_k), or None
        is_dropout, causal: select the key-block size via _get_block_size; causal also
            applies a causal mask to the recomputed scores.
    Output:
        attn_norm: (batch_size, nheads, seqlen_q, seqlen_k), dtype of attn_unnorm, with padded
            query rows zeroed.
    """
    q, k = q.float(), k.float()
    _, seqlen_q, _, head_dim = q.shape
    seqlen_k = k.shape[1]
    scores = torch.einsum('bthd,bshd->bhts', q / math.sqrt(head_dim), k)
    if key_padding_mask is not None:
        scores.masked_fill_(rearrange(~key_padding_mask, 'b s -> b 1 1 s'), float('-inf'))
    if causal:
        causal_mask = torch.triu(torch.ones(seqlen_q, seqlen_k, dtype=torch.bool, device=q.device), 1)
        scores.masked_fill_(causal_mask, float('-inf'))
    _, block_size_n = _get_block_size(scores.device, head_dim, is_dropout, causal)
    scores_block = scores.split(block_size_n, dim=-1)
    lse_block = torch.stack([torch.logsumexp(s, dim=-1) for s in scores_block], dim=-1)
    lse = torch.logsumexp(lse_block, dim=-1)
    scores_max_block = torch.stack([torch.amax(s, dim=-1) for s in scores_block], dim=-1)
    # For key block j: the max score over blocks j, j+1, ..., last (suffix max via flip/cummax).
    cummax_block = torch.cummax(scores_max_block.flip(-1), dim=-1).values.flip(-1).unbind(dim=-1)
    attn_unnorm_block = attn_unnorm.split(block_size_n, dim=-1)
    # Dividing by exp(lse - m) yields softmax probabilities exp(s - lse) under the assumption
    # that block j of attn_unnorm holds exp(s - m_j), with m_j that suffix max.
    attn_norm = torch.cat([a / rearrange(torch.exp(lse - m), 'b h s -> b h s 1')
                           for a, m in zip(attn_unnorm_block, cummax_block)], dim=-1)
    if query_padding_mask is not None:
        attn_norm.masked_fill_(rearrange(~query_padding_mask, 'b s -> b 1 s 1'), 0.0)
    return attn_norm.to(dtype=attn_unnorm.dtype)


def get_dropout_fraction(dropout_mask, query_padding_mask=None, key_padding_mask=None, causal=False):
    """Measure the fraction of attendable (query, key) entries that dropout removed.

    dropout_mask: (batch_size, nheads, seqlen_q, seqlen_k), bool. True means keep, False means drop.
    query_padding_mask: (batch_size, seqlen_q), bool, True = valid token, or None
    key_padding_mask: (batch_size, seqlen_k), bool, True = valid token, or None
    causal: also exclude entries above the (top-left aligned) diagonal.

    Returns a 0-dim tensor: dropped entries / attendable entries, summed over batch and
    heads. The attendable count assumes every padding mask is True on a prefix of its row.
    """
    batch_size, nheads, seqlen_q, seqlen_k = dropout_mask.shape
    device = dropout_mask.device

    # Entries excluded by padding or causality do not count as dropped.
    dropped = ~dropout_mask
    if query_padding_mask is not None:
        dropped.masked_fill_(~query_padding_mask[:, None, :, None], False)
    if key_padding_mask is not None:
        dropped.masked_fill_(~key_padding_mask[:, None, None, :], False)
    if causal:
        future = torch.triu(torch.ones(seqlen_q, seqlen_k, dtype=torch.bool, device=device), 1)
        dropped.masked_fill_(future, False)
    n_dropped = dropped.sum()

    if query_padding_mask is None:
        q_lens = torch.full((batch_size,), seqlen_q, device=device)
    else:
        q_lens = query_padding_mask.sum(dim=-1)
    if key_padding_mask is None:
        k_lens = torch.full((batch_size,), seqlen_k, device=device)
    else:
        k_lens = key_padding_mask.sum(dim=-1)

    if causal:
        # Number of (i, j) with i < q_len, j < k_len and j <= i.
        n_attendable = torch.where(
            q_lens <= k_lens,
            q_lens * (q_lens + 1) / 2,
            q_lens * k_lens - (k_lens * (k_lens - 1) / 2)
        )
    else:
        n_attendable = q_lens * k_lens
    return n_dropped / (n_attendable.sum() * nheads)


@pytest.mark.parametrize('dtype', ([torch.float16] if is_sm75 else [torch.float16, torch.bfloat16]))
# @pytest.mark.parametrize('dtype', [torch.float16])
@pytest.mark.parametrize('causal', [False, True])
# @pytest.mark.parametrize('causal', [True])
@pytest.mark.parametrize('d', [32, 40, 59, 64, 80, 96, 111, 128, 160, 192, 224, 256])
# @pytest.mark.parametrize('d', [64])
@pytest.mark.parametrize('seqlen', [97, 128, 200, 256, 257, 384, 512, 768, 1024, 1025, 2048])
# @pytest.mark.parametrize('seqlen', [97])
@pytest.mark.parametrize('dropout_p', [0.0, 0.17])
# @pytest.mark.parametrize('dropout_p', [0.17])
def test_flash_attn_qkvpacked(seqlen, d, dropout_p, causal, dtype):
    """Check flash_attn_qkvpacked_func against the PyTorch reference (no padding).

    FlashAttention's max error against an fp32 reference must be at most twice that of a
    low-precision PyTorch implementation. This is checked for the output, for the attention
    probabilities and dropout rate (when dropout_p > 0), and for dqkv when the guard
    `d <= MAX_HEADDIM_SM8x or (is_sm80 or is_sm90)` holds.
    """
    if seqlen >= 2048 and torch.cuda.get_device_properties('cuda').total_memory <= 16 * 2**30:
        pytest.skip()  # Reference implementation OOM
    device = 'cuda'
    # set seed
    torch.random.manual_seed(0)
    batch_size = 16
    nheads = 9
    qkv = torch.randn(batch_size, seqlen, 3, nheads, d, device=device, dtype=dtype,
                      requires_grad=True)
    out, lse, S_dmask = flash_attn_qkvpacked_func(
        qkv, dropout_p, return_attn_probs=True, causal=causal
    )
    if dropout_p > 0.0:
        S_dmask_converted = convert_flash_attn_S_to_softmax(
            S_dmask, None, None, d, dropout_p > 0.0, causal=causal
        )[:, :, :seqlen, :seqlen]
        # The sign of each entry encodes the dropout decision (negative = dropped).
        dropout_mask = S_dmask_converted >= 0
        attn_unnorm = S_dmask_converted.abs()
        attn = normalize_flash_attn_S(attn_unnorm, qkv[:, :, 0], qkv[:, :, 1], qkv[:, :, 2],
                                      None, None, dropout_p > 0.0, causal=causal)
        dropout_fraction = get_dropout_fraction(dropout_mask, None, None, causal=causal).item()
        print(f'Actual dropout fraction: {dropout_fraction}')
    else:
        dropout_mask = None

    out_ref, attn_ref = attention_qkvpacked_ref(qkv, None, dropout_p, dropout_mask, causal=causal)
    out_pt, attn_pt = attention_qkvpacked_ref(qkv, None, dropout_p, dropout_mask, causal=causal,
                                              upcast=False, reorder_ops=True)
    print(f'Output max diff: {(out - out_ref).abs().max().item()}')
    print(f'Output mean diff: {(out - out_ref).abs().mean().item()}')
    print(f'Pytorch max diff: {(out_pt - out_ref).abs().max().item()}')
    print(f'Pytorch mean diff: {(out_pt - out_ref).abs().mean().item()}')
    if dropout_p > 0.0:
        print(f'Attention max diff: {(attn - attn_ref).abs().max().item()}')
        print(f'Attention Pytorch max diff: {(attn_pt - attn_ref).abs().max().item()}')

    g = torch.randn_like(out)
    if d <= MAX_HEADDIM_SM8x or (is_sm80 or is_sm90):
        dqkv, = torch.autograd.grad(out, qkv, g)
        dqkv_ref, = torch.autograd.grad(out_ref, qkv, g)
        dqkv_pt, = torch.autograd.grad(out_pt, qkv, g)
        print(f'dQ max diff: {(dqkv[:, :, 0] - dqkv_ref[:, :, 0]).abs().max().item()}')
        print(f'dK max diff: {(dqkv[:, :, 1] - dqkv_ref[:, :, 1]).abs().max().item()}')
        print(f'dV max diff: {(dqkv[:, :, 2] - dqkv_ref[:, :, 2]).abs().max().item()}')
        print(f'dQKV mean diff: {(dqkv - dqkv_ref).abs().mean().item()}')
        print(f'dQ Pytorch max diff: {(dqkv_pt[:, :, 0] - dqkv_ref[:, :, 0]).abs().max().item()}')
        print(f'dK Pytorch max diff: {(dqkv_pt[:, :, 1] - dqkv_ref[:, :, 1]).abs().max().item()}')
        print(f'dV Pytorch max diff: {(dqkv_pt[:, :, 2] - dqkv_ref[:, :, 2]).abs().max().item()}')
        print(f'dQKV Pytorch mean diff: {(dqkv_pt - dqkv_ref).abs().mean().item()}')

    # Check that FlashAttention's numerical error is at most twice the numerical error
    # of a Pytorch implementation.
    assert (out - out_ref).abs().max().item() <= 2 * (out_pt - out_ref).abs().max().item()

    if dropout_p > 0.0:
        assert (attn - attn_ref).abs().max().item() <= 2 * (attn_pt - attn_ref).abs().max().item()
        assert abs(dropout_fraction - dropout_p) <= 0.01

    if d <= MAX_HEADDIM_SM8x or (is_sm80 or is_sm90):
        assert (dqkv - dqkv_ref).abs().max().item() <= 2 * (dqkv_pt - dqkv_ref).abs().max().item()



@pytest.mark.parametrize('dtype', ([torch.float16] if is_sm75 else [torch.float16, torch.bfloat16]))
# @pytest.mark.parametrize('dtype', [torch.float16])
@pytest.mark.parametrize('causal', [False, True])
# @pytest.mark.parametrize('causal', [False])
@pytest.mark.parametrize('d', [32, 40, 59, 64, 80, 96, 111, 128, 160, 192, 224, 256])
# @pytest.mark.parametrize('d', [64])
@pytest.mark.parametrize('seqlen', [97, 128, 200, 256, 257, 384, 512, 768, 1024, 1025, 2048])
# @pytest.mark.parametrize('seqlen', [128])
@pytest.mark.parametrize('dropout_p', [0.0, 0.17])
# @pytest.mark.parametrize('dropout_p', [0.0])
def test_flash_attn_varlen_qkvpacked(seqlen, d, dropout_p, causal, dtype):
    """Check flash_attn_varlen_qkvpacked_func against the PyTorch reference, with padding.

    Inputs are padded with a random key padding mask (shared by queries and keys) and
    unpadded with generate_qkv. FlashAttention's max error against an fp32 reference must
    be at most twice that of a low-precision PyTorch implementation. This is checked for the
    output, for the attention probabilities and dropout rate (when dropout_p > 0), and for
    dqkv when the guard `d <= MAX_HEADDIM_SM8x or (is_sm80 or is_sm90)` holds.
    """
    if seqlen >= 2048 and torch.cuda.get_device_properties('cuda').total_memory <= 16 * 2**30:
        pytest.skip()  # Reference implementation OOM
    device = 'cuda'
    # set seed
    torch.random.manual_seed(0)
    batch_size = 5
    nheads = 6
    qkv = torch.randn(batch_size, seqlen, 3, nheads, d, device=device, dtype=dtype,
                      requires_grad=True)

    key_padding_mask = generate_random_padding_mask(seqlen, batch_size, device, mode='random')
    # key_padding_mask = generate_random_padding_mask(seqlen, batch_size, device, mode='full')

    qkv_unpad, cu_seqlens, max_seqlen, qkv, output_pad_fn, dqkv_pad_fn = generate_qkv(
        *qkv.unbind(dim=2), key_padding_mask, key_padding_mask, qkvpacked=True
    )

    out_unpad, sm_lse, S_dmask = flash_attn_varlen_qkvpacked_func(
        qkv_unpad, cu_seqlens, max_seqlen, dropout_p, return_attn_probs=True, causal=causal
    )
    out = output_pad_fn(out_unpad)
    if dropout_p > 0.0:
        S_dmask_converted = convert_flash_attn_S_to_softmax(
            S_dmask, key_padding_mask, key_padding_mask, d, dropout_p > 0.0, causal=causal
        )[:, :, :seqlen, :seqlen]
        # The sign of each entry encodes the dropout decision (negative = dropped).
        dropout_mask = S_dmask_converted >= 0
        attn_unnorm = S_dmask_converted.abs()
        attn = normalize_flash_attn_S(attn_unnorm, qkv[:, :, 0], qkv[:, :, 1], qkv[:, :, 2],
                                      key_padding_mask, key_padding_mask, dropout_p > 0.0,
                                      causal=causal)
        dropout_fraction = get_dropout_fraction(dropout_mask, key_padding_mask, key_padding_mask,
                                                causal=causal).item()
        print(f'Actual dropout fraction: {dropout_fraction}')
    else:
        dropout_mask = None

    out_ref, attn_ref = attention_qkvpacked_ref(qkv, key_padding_mask, dropout_p, dropout_mask,
                                                causal=causal)
    out_pt, attn_pt = attention_qkvpacked_ref(qkv, key_padding_mask, dropout_p, dropout_mask,
                                              causal=causal, upcast=False, reorder_ops=True)
    print(f'Output max diff: {(out - out_ref).abs().max().item()}')
    print(f'Output mean diff: {(out - out_ref).abs().mean().item()}')
    print(f'Pytorch max diff: {(out_pt - out_ref).abs().max().item()}')
    print(f'Pytorch mean diff: {(out_pt - out_ref).abs().mean().item()}')
    if dropout_p > 0.0:
        print(f'Attention max diff: {(attn - attn_ref).abs().max().item()}')
        print(f'Attention Pytorch max diff: {(attn_pt - attn_ref).abs().max().item()}')

    g = torch.randn_like(out)
    if d <= MAX_HEADDIM_SM8x or (is_sm80 or is_sm90):
        dqkv_unpad, = torch.autograd.grad(out, qkv_unpad, g)
        dqkv = dqkv_pad_fn(dqkv_unpad)
        dqkv_ref, = torch.autograd.grad(out_ref, qkv, g)
        dqkv_pt, = torch.autograd.grad(out_pt, qkv, g)
        print(f'dQ max diff: {(dqkv[:, :, 0] - dqkv_ref[:, :, 0]).abs().max().item()}')
        print(f'dK max diff: {(dqkv[:, :, 1] - dqkv_ref[:, :, 1]).abs().max().item()}')
        print(f'dV max diff: {(dqkv[:, :, 2] - dqkv_ref[:, :, 2]).abs().max().item()}')
        print(f'dQKV mean diff: {(dqkv - dqkv_ref).abs().mean().item()}')
        print(f'dQ Pytorch max diff: {(dqkv_pt[:, :, 0] - dqkv_ref[:, :, 0]).abs().max().item()}')
        print(f'dK Pytorch max diff: {(dqkv_pt[:, :, 1] - dqkv_ref[:, :, 1]).abs().max().item()}')
        print(f'dV Pytorch max diff: {(dqkv_pt[:, :, 2] - dqkv_ref[:, :, 2]).abs().max().item()}')
        print(f'dQKV Pytorch mean diff: {(dqkv_pt - dqkv_ref).abs().mean().item()}')

    # Check that FlashAttention's numerical error is at most twice the numerical error
    # of a Pytorch implementation.
    assert (out - out_ref).abs().max().item() <= 2 * (out_pt - out_ref).abs().max().item()

    if dropout_p > 0.0:
        assert (attn - attn_ref).abs().max().item() <= 2 * (attn_pt - attn_ref).abs().max().item()
        assert abs(dropout_fraction - dropout_p) <= 0.01

    if d <= MAX_HEADDIM_SM8x or (is_sm80 or is_sm90):
        assert (dqkv - dqkv_ref).abs().max().item() <= 2 * (dqkv_pt - dqkv_ref).abs().max().item()


Tri Dao's avatar
Tri Dao committed
517
518
@pytest.mark.parametrize('kvpacked', [True, False])
# @pytest.mark.parametrize('kvpacked', [False])
Tri Dao's avatar
Tri Dao committed
519
@pytest.mark.parametrize('dtype', ([torch.float16] if is_sm75 else [torch.float16, torch.bfloat16]))
Tri Dao's avatar
Tri Dao committed
520
521
522
# @pytest.mark.parametrize('dtype', [torch.bfloat16])
@pytest.mark.parametrize('mha_type', ["mha", "mqa", "gqa"])
# @pytest.mark.parametrize('mha_type', ["mha"])
Tri Dao's avatar
Tri Dao committed
523
@pytest.mark.parametrize('causal', [False, True])
Tri Dao's avatar
Tri Dao committed
524
525
526
527
528
529
# @pytest.mark.parametrize('causal', [False])
@pytest.mark.parametrize('d', [32, 40, 59, 64, 80, 96, 111, 128, 160, 192, 224, 256])
# @pytest.mark.parametrize('d', [32, 64, 96, 128, 160, 192, 224, 256])
# @pytest.mark.parametrize('d', [32, 40, 64, 80, 96, 128, 160, 192])
# @pytest.mark.parametrize('d', [32, 64, 96, 128, 160, 192])
# @pytest.mark.parametrize('d', [56, 80])
Tri Dao's avatar
Tri Dao committed
530
# @pytest.mark.parametrize('d', [64])
Tri Dao's avatar
Tri Dao committed
531
532
@pytest.mark.parametrize('seqlen_q,seqlen_k', [(113, 203), (128, 217), (113, 211), (108, 256), (256, 512), (512, 256), (1024, 1024), (1023, 1024), (1024, 1023), (2048, 2048)])
# @pytest.mark.parametrize('seqlen_q,seqlen_k', [(128, 128)])
Tri Dao's avatar
Tri Dao committed
533
534
@pytest.mark.parametrize('dropout_p', [0.0, 0.17])
# @pytest.mark.parametrize('dropout_p', [0.0])
Tri Dao's avatar
Tri Dao committed
535
536
def test_flash_attn_output(seqlen_q, seqlen_k, d, dropout_p, causal, mha_type, dtype, kvpacked):
    if max(seqlen_q, seqlen_k) >= 2048 and torch.cuda.get_device_properties('cuda').total_memory <= 16 * 2**30:
Tri Dao's avatar
Tri Dao committed
537
538
539
540
        pytest.skip()  # Reference implementation OOM
    device = 'cuda'
    # set seed
    torch.random.manual_seed(0)
Tri Dao's avatar
Tri Dao committed
541
542
543
544
545
546
547
548
549
550
551
552
553
554
555
556
557
558
559
560
561
562
563
564
565
566
567
568
569
570
571
572
573
574
575
576
577
578
579
580
    batch_size = 16
    nheads = 9
    nheads_k = nheads if mha_type == "mha" else (1 if mha_type == "mqa" else 3)
    assert nheads % nheads_k == 0
    q = torch.randn(batch_size, seqlen_q, nheads, d, device=device, dtype=dtype, requires_grad=True)
    if kvpacked:
        kv = torch.randn(batch_size, seqlen_k, 2, nheads_k, d, device=device, dtype=dtype,
                         requires_grad=True)
    else:
        k = torch.randn(batch_size, seqlen_k, nheads_k, d, device=device, dtype=dtype,
                        requires_grad=True)
        v = torch.randn(batch_size, seqlen_k, nheads_k, d, device=device, dtype=dtype,
                        requires_grad=True)

    if kvpacked:
        out, lse, S_dmask = flash_attn_kvpacked_func(
            q, kv, dropout_p, return_attn_probs=True, causal=causal
        )
    else:
        out, lse, S_dmask = flash_attn_func(
            q, k, v, dropout_p, return_attn_probs=True, causal=causal
        )
    if dropout_p > 0.0:
        S_dmask_converted = convert_flash_attn_S_to_softmax(
            S_dmask, None, None, d, dropout_p > 0.0, causal=causal
        )[:, :, :seqlen_q, :seqlen_k]
        dropout_mask = S_dmask_converted >= 0
        attn_unnorm = S_dmask_converted.abs()
        if kvpacked:
            kv_rep = repeat(kv, "b s two h d -> b s two (h g) d", g=nheads // nheads_k)
            k_rep, v_rep = kv_rep.unbind(dim=2)
        else:
            k_rep = repeat(k, "b s h d -> b s (h g) d", g=nheads // nheads_k)
            v_rep = repeat(v, "b s h d -> b s (h g) d", g=nheads // nheads_k)
        attn = normalize_flash_attn_S(attn_unnorm, q, k_rep, v_rep,
                                      None, None, dropout_p > 0.0, causal=causal)
        dropout_fraction = get_dropout_fraction(dropout_mask, None, None, causal=causal).item()
        print(f'Actual dropout fraction: {dropout_fraction}')
    else:
        dropout_mask = None
Tri Dao's avatar
Tri Dao committed
581

Tri Dao's avatar
Tri Dao committed
582
583
584
585
586
587
588
589
590
591
592
593
594
595
596
597
598
599
600
601
602
603
604
605
606
607
608
609
610
611
612
613
614
    if kvpacked:
        out_ref, attn_ref = attention_kvpacked_ref(q, kv, None, None, dropout_p, dropout_mask,
                                                causal=causal)
        out_pt, attn_pt = attention_kvpacked_ref(q, kv, None, None, dropout_p, dropout_mask,
                                                causal=causal, upcast=False, reorder_ops=True)
    else:
        out_ref, attn_ref = attention_ref(q, k, v, None, None, dropout_p, dropout_mask,
                                          causal=causal)
        out_pt, attn_pt = attention_ref(q, k, v, None, None, dropout_p, dropout_mask,
                                        causal=causal, upcast=False, reorder_ops=True)

    print(f'Output max diff: {(out - out_ref).abs().max().item()}')
    print(f'Output mean diff: {(out - out_ref).abs().mean().item()}')
    print(f'Pytorch max diff: {(out_pt - out_ref).abs().max().item()}')
    print(f'Pytorch mean diff: {(out_pt - out_ref).abs().mean().item()}')
    if dropout_p > 0.0:
        print(f'Attention max diff: {(attn - attn_ref).abs().max().item()}')
        print(f'Attention Pytorch max diff: {(attn_pt - attn_ref).abs().max().item()}')

    g = torch.randn_like(out)
    do_o = (g.float() * out.float()).sum(-1)
    if d <= MAX_HEADDIM_SM8x or (is_sm80 or is_sm90):
        if kvpacked:
            dq, dkv, = torch.autograd.grad(out, (q, kv), g)
            dk, dv = dkv.unbind(2)
            dq_ref, dkv_ref, = torch.autograd.grad(out_ref, (q, kv), g)
            dk_ref, dv_ref = dkv_ref.unbind(2)
            dq_pt, dkv_pt, = torch.autograd.grad(out_pt, (q, kv), g)
            dk_pt, dv_pt = dkv_pt.unbind(2)
        else:
            dq, dk, dv, = torch.autograd.grad(out, (q, k, v), g)
            dq_ref, dk_ref, dv_ref, = torch.autograd.grad(out_ref, (q, k, v), g)
            dq_pt, dk_pt, dv_pt, = torch.autograd.grad(out_pt, (q, k, v), g)
Tri Dao's avatar
Tri Dao committed
615
616
617
        print(f'dQ max diff: {(dq - dq_ref).abs().max().item()}')
        print(f'dK max diff: {(dk - dk_ref).abs().max().item()}')
        print(f'dV max diff: {(dv - dv_ref).abs().max().item()}')
Tri Dao's avatar
Tri Dao committed
618
619
620
        print(f'dQ mean diff: {(dq - dq_ref).abs().mean().item()}')
        print(f'dK mean diff: {(dk - dk_ref).abs().mean().item()}')
        print(f'dV mean diff: {(dv - dv_ref).abs().mean().item()}')
Tri Dao's avatar
Tri Dao committed
621
622
623
        print(f'dQ Pytorch max diff: {(dq_pt - dq_ref).abs().max().item()}')
        print(f'dK Pytorch max diff: {(dk_pt - dk_ref).abs().max().item()}')
        print(f'dV Pytorch max diff: {(dv_pt - dv_ref).abs().max().item()}')
Tri Dao's avatar
Tri Dao committed
624
625
626
        print(f'dQ Pytorch mean diff: {(dq_pt - dq_ref).abs().mean().item()}')
        print(f'dK Pytorch mean diff: {(dk_pt - dk_ref).abs().mean().item()}')
        print(f'dV Pytorch mean diff: {(dv_pt - dv_ref).abs().mean().item()}')
Tri Dao's avatar
Tri Dao committed
627
628
629

    # Check that FlashAttention's numerical error is at most twice the numerical error
    # of a Pytorch implementation.
Tri Dao's avatar
Tri Dao committed
630
631
632
633
634
    assert (out - out_ref).abs().max().item() <= 2 * (out_pt - out_ref).abs().max().item()

    if dropout_p > 0.0:
        assert (attn - attn_ref).abs().max().item() <= 2 * (attn_pt - attn_ref).abs().max().item()
        assert abs(dropout_fraction - dropout_p) <= 0.01
Tri Dao's avatar
Tri Dao committed
635

Tri Dao's avatar
Tri Dao committed
636
    if d <= MAX_HEADDIM_SM8x or (is_sm80 or is_sm90):
Tri Dao's avatar
Tri Dao committed
637
638
639
640
641
        assert (dq - dq_ref).abs().max().item() <= 2 * (dq_pt - dq_ref).abs().max().item()
        assert (dk - dk_ref).abs().max().item() <= 2 * (dk_pt - dk_ref).abs().max().item()
        assert (dv - dv_ref).abs().max().item() <= 2 * (dv_pt - dv_ref).abs().max().item()


@pytest.mark.parametrize('kvpacked', [True, False])
# @pytest.mark.parametrize('kvpacked', [False])
@pytest.mark.parametrize('dtype', ([torch.float16] if is_sm75 else [torch.float16, torch.bfloat16]))
# @pytest.mark.parametrize('dtype', [torch.float16])
@pytest.mark.parametrize('mha_type', ["mha", "mqa", "gqa"])
# @pytest.mark.parametrize('mha_type', ["mqa"])
@pytest.mark.parametrize('causal', [False, True])
# @pytest.mark.parametrize('causal', [True])
@pytest.mark.parametrize('d', [32, 40, 59, 64, 80, 96, 111, 128, 160, 192, 224, 256])
# @pytest.mark.parametrize('d', [32, 64, 96, 128, 160, 192, 224, 256])
# @pytest.mark.parametrize('d', [64])
@pytest.mark.parametrize('seqlen_q,seqlen_k', [(113, 203), (128, 217), (113, 211), (108, 256), (256, 512), (512, 256), (1024, 1024), (1023, 1024), (1024, 1023), (2048, 2048)])
# @pytest.mark.parametrize('seqlen_q,seqlen_k', [(128, 128)])
@pytest.mark.parametrize('dropout_p', [0.0, 0.17])
# @pytest.mark.parametrize('dropout_p', [0.0])
def test_flash_attn_varlen_output(seqlen_q, seqlen_k, d, dropout_p, causal, mha_type, dtype,
                                  kvpacked):
    """Compare the variable-length FlashAttention API against reference attention.

    The test proceeds in four steps:

    1. Builds random query/key padding masks and unpads the inputs with ``generate_qkv``.
    2. Runs ``flash_attn_varlen_kvpacked_func`` or ``flash_attn_varlen_func``.
    3. Pads the results back.
    4. Checks that FlashAttention's max abs error vs. ``attention_ref`` is at most twice the
       error of a PyTorch implementation run with ``upcast=False, reorder_ops=True``.

    The check covers the output, the attention probabilities and dropout fraction (when
    ``dropout_p > 0``), and dQ/dK/dV (when the backward is exercised for this head dim / GPU).
    """
    if max(seqlen_q, seqlen_k) >= 2048 and torch.cuda.get_device_properties('cuda').total_memory <= 16 * 2**30:
        pytest.skip()  # Reference implementation OOM
    device = 'cuda'
    # set seed
    torch.random.manual_seed(0)
    batch_size = 16
    nheads = 9
    nheads_k = nheads if mha_type == "mha" else (1 if mha_type == "mqa" else 3)
    assert nheads % nheads_k == 0
    # Gradients are only checked for head dims <= MAX_HEADDIM_SM8x, or on sm80 / sm90.
    # NOTE(review): presumably the backward for larger head dims is unsupported on other
    # SM8x GPUs - confirm against the kernel dispatch.
    run_bwd = d <= MAX_HEADDIM_SM8x or (is_sm80 or is_sm90)
    q = torch.randn(batch_size, seqlen_q, nheads, d, device=device, dtype=dtype, requires_grad=True)
    if kvpacked:
        kv = torch.randn(batch_size, seqlen_k, 2, nheads_k, d, device=device, dtype=dtype,
                         requires_grad=True)
    else:
        k = torch.randn(batch_size, seqlen_k, nheads_k, d, device=device, dtype=dtype,
                        requires_grad=True)
        v = torch.randn(batch_size, seqlen_k, nheads_k, d, device=device, dtype=dtype,
                        requires_grad=True)

    query_padding_mask = generate_random_padding_mask(seqlen_q, batch_size, device, mode='random')
    key_padding_mask = generate_random_padding_mask(seqlen_k, batch_size, device, mode='random')
    # key_padding_mask = generate_random_padding_mask(seqlen_k, batch_size, device, mode='full')

    # generate_qkv returns the unpadded tensors fed to the kernel, the cu_seqlens / max_seqlen
    # metadata, padded tensors (rebound to q/kv or q/k/v and used by the references), and
    # functions that pad kernel outputs/gradients back to batch layout.
    if kvpacked:
        (q_unpad, kv_unpad, cu_seqlens_q, cu_seqlens_k, max_seqlen_q, max_seqlen_k, q, kv,
         output_pad_fn, dq_pad_fn, dkv_pad_fn) = generate_qkv(
            q, *kv.unbind(dim=2), query_padding_mask, key_padding_mask, kvpacked=True
        )
        out_unpad, sm_lse, S_dmask = flash_attn_varlen_kvpacked_func(
            q_unpad, kv_unpad, cu_seqlens_q, cu_seqlens_k, max_seqlen_q, max_seqlen_k,
            dropout_p, return_attn_probs=True, causal=causal
        )
    else:
        (q_unpad, k_unpad, v_unpad, cu_seqlens_q, cu_seqlens_k, max_seqlen_q, max_seqlen_k, q, k, v,
         output_pad_fn, dq_pad_fn, dk_pad_fn) = generate_qkv(
            q, k, v, query_padding_mask, key_padding_mask, kvpacked=False
        )
        out_unpad, sm_lse, S_dmask = flash_attn_varlen_func(
            q_unpad, k_unpad, v_unpad, cu_seqlens_q, cu_seqlens_k, max_seqlen_q, max_seqlen_k,
            dropout_p, return_attn_probs=True, causal=causal
        )
    out = output_pad_fn(out_unpad)
    if dropout_p > 0.0:
        # Convert the kernel's S/dropout-mask output into a dense (.., seqlen_q, seqlen_k)
        # tensor. Entries >= 0 form the dropout mask (presumably the kept positions), and
        # abs() gives the unnormalized attention.
        S_dmask_converted = convert_flash_attn_S_to_softmax(
            S_dmask, query_padding_mask, key_padding_mask, d, dropout_p > 0.0, causal=causal
        )[:, :, :seqlen_q, :seqlen_k]
        dropout_mask = S_dmask_converted >= 0
        attn_unnorm = S_dmask_converted.abs()
        # Expand K/V heads to the number of Q heads (MQA/GQA) before normalizing.
        if kvpacked:
            kv_rep = repeat(kv, "b s two h d -> b s two (h g) d", g=nheads // nheads_k)
            k_rep, v_rep = kv_rep.unbind(dim=2)
        else:
            k_rep = repeat(k, "b s h d -> b s (h g) d", g=nheads // nheads_k)
            v_rep = repeat(v, "b s h d -> b s (h g) d", g=nheads // nheads_k)
        attn = normalize_flash_attn_S(attn_unnorm, q, k_rep, v_rep,
                                      query_padding_mask, key_padding_mask,
                                      dropout_p > 0.0, causal=causal)
        dropout_fraction = get_dropout_fraction(dropout_mask, query_padding_mask,
                                                key_padding_mask, causal=causal).item()
        print(f'Actual dropout fraction: {dropout_fraction}')
    else:
        dropout_mask = None

    # out_ref is the reference; out_pt is the PyTorch implementation without upcasting and
    # with reordered ops. Its error vs. out_ref sets the tolerance for FlashAttention below.
    if kvpacked:
        out_ref, attn_ref = attention_kvpacked_ref(q, kv, query_padding_mask, key_padding_mask,
                                                   dropout_p, dropout_mask, causal=causal)
        out_pt, attn_pt = attention_kvpacked_ref(q, kv, query_padding_mask, key_padding_mask,
                                                 dropout_p, dropout_mask,
                                                 causal=causal, upcast=False, reorder_ops=True)
    else:
        out_ref, attn_ref = attention_ref(q, k, v, query_padding_mask, key_padding_mask,
                                          dropout_p, dropout_mask, causal=causal)
        out_pt, attn_pt = attention_ref(q, k, v, query_padding_mask, key_padding_mask,
                                        dropout_p, dropout_mask,
                                        causal=causal, upcast=False, reorder_ops=True)

    print(f'Output max diff: {(out - out_ref).abs().max().item()}')
    print(f'Output mean diff: {(out - out_ref).abs().mean().item()}')
    print(f'Pytorch max diff: {(out_pt - out_ref).abs().max().item()}')
    print(f'Pytorch mean diff: {(out_pt - out_ref).abs().mean().item()}')
    if dropout_p > 0.0:
        print(f'Attention max diff: {(attn - attn_ref).abs().max().item()}')
        print(f'Attention Pytorch max diff: {(attn_pt - attn_ref).abs().max().item()}')

    g = torch.randn_like(out)
    if run_bwd:
        if kvpacked:
            dq_unpad, dkv_unpad = torch.autograd.grad(out, (q_unpad, kv_unpad), g)
            dk, dv = dkv_pad_fn(dkv_unpad).unbind(2)
            dq_ref, dkv_ref = torch.autograd.grad(out_ref, (q, kv), g)
            dk_ref, dv_ref = dkv_ref.unbind(2)
            dq_pt, dkv_pt = torch.autograd.grad(out_pt, (q, kv), g)
            dk_pt, dv_pt = dkv_pt.unbind(2)
        else:
            dq_unpad, dk_unpad, dv_unpad = torch.autograd.grad(out, (q_unpad, k_unpad, v_unpad), g)
            dk = dk_pad_fn(dk_unpad)
            # generate_qkv returns no separate dv pad function here. V is unpadded with the
            # same key_padding_mask as K, so dk_pad_fn is presumably also valid for dV
            # - confirm against generate_qkv.
            dv = dk_pad_fn(dv_unpad)
            dq_ref, dk_ref, dv_ref = torch.autograd.grad(out_ref, (q, k, v), g)
            dq_pt, dk_pt, dv_pt = torch.autograd.grad(out_pt, (q, k, v), g)
        dq = dq_pad_fn(dq_unpad)
        print(f'dQ max diff: {(dq - dq_ref).abs().max().item()}')
        print(f'dK max diff: {(dk - dk_ref).abs().max().item()}')
        print(f'dV max diff: {(dv - dv_ref).abs().max().item()}')
        print(f'dQ mean diff: {(dq - dq_ref).abs().mean().item()}')
        print(f'dK mean diff: {(dk - dk_ref).abs().mean().item()}')
        print(f'dV mean diff: {(dv - dv_ref).abs().mean().item()}')
        print(f'dQ Pytorch max diff: {(dq_pt - dq_ref).abs().max().item()}')
        print(f'dK Pytorch max diff: {(dk_pt - dk_ref).abs().max().item()}')
        print(f'dV Pytorch max diff: {(dv_pt - dv_ref).abs().max().item()}')
        print(f'dQ Pytorch mean diff: {(dq_pt - dq_ref).abs().mean().item()}')
        print(f'dK Pytorch mean diff: {(dk_pt - dk_ref).abs().mean().item()}')
        print(f'dV Pytorch mean diff: {(dv_pt - dv_ref).abs().mean().item()}')

    # Check that FlashAttention's numerical error is at most twice the numerical error
    # of a Pytorch implementation.
    assert (out - out_ref).abs().max().item() <= 2 * (out_pt - out_ref).abs().max().item()

    if dropout_p > 0.0:
        assert (attn - attn_ref).abs().max().item() <= 2 * (attn_pt - attn_ref).abs().max().item()
        assert abs(dropout_fraction - dropout_p) <= 0.01

    if run_bwd:
        assert (dq - dq_ref).abs().max().item() <= 2 * (dq_pt - dq_ref).abs().max().item()
        assert (dk - dk_ref).abs().max().item() <= 2 * (dk_pt - dk_ref).abs().max().item()
        assert (dv - dv_ref).abs().max().item() <= 2 * (dv_pt - dv_ref).abs().max().item()


# @pytest.mark.parametrize('dtype', ([torch.float16] if is_sm75 else [torch.float16, torch.bfloat16]))
@pytest.mark.parametrize('dtype', [torch.float16])
@pytest.mark.parametrize('causal', [False, True])
# @pytest.mark.parametrize('causal', [True])
# @pytest.mark.parametrize('d', [32, 56, 64, 80, 96, 128])
# @pytest.mark.parametrize('d', [32, 64, 96, 128, 160, 192])
@pytest.mark.parametrize('d', [64])
# @pytest.mark.parametrize('seqlen', [97, 128, 200, 256, 257, 384, 512, 768, 1024, 1025, 2048])
@pytest.mark.parametrize('seqlen', [128, 256, 384, 512, 768, 1024, 2048])
# @pytest.mark.parametrize('seqlen', [193])
# @pytest.mark.parametrize('dropout_p', [0.0, 0.17])
@pytest.mark.parametrize('dropout_p', [0.0])
def test_flash_attn_race_condition(seqlen, d, dropout_p, causal, dtype):
    """Check that repeated FlashAttention forward/backward calls are bitwise identical.

    One call to ``flash_attn_qkvpacked_func`` produces the reference output, LSE and dQKV.
    The same computation is then repeated 200 times, requiring exact equality
    (``torch.equal``) every time. Any mismatch indicates nondeterministic kernel results.
    """
    if seqlen >= 2048 and torch.cuda.get_device_properties('cuda').total_memory <= 16 * 2**30:
        pytest.skip()  # Reference implementation OOM
    device = 'cuda'
    # set seed
    torch.random.manual_seed(0)
    batch_size = 32
    nheads = 4
    qkv = torch.randn(batch_size, seqlen, 3, nheads, d, device=device, dtype=dtype, requires_grad=True)
    out0, lse0, _ = flash_attn_qkvpacked_func(
        qkv, dropout_p, return_attn_probs=True, causal=causal
    )
    g = torch.randn_like(out0)
    dqkv0, = torch.autograd.grad(out0, qkv, g)
    # NOTE(review): the backward comparison is skipped for sm75 with d == 128; presumably
    # that configuration's backward is unsupported or not deterministic - confirm.
    check_bwd = not (is_sm75 and d == 128)

    for _ in range(200):
        # NOTE(review): the first call above ran after qkv was drawn from the RNG, so this
        # re-seed does not reproduce its RNG state. That is harmless with dropout_p == 0.0
        # (the only value tested); re-check before enabling dropout here.
        torch.random.manual_seed(0)
        out, lse, _ = flash_attn_qkvpacked_func(
            qkv, dropout_p, return_attn_probs=True, causal=causal
        )
        assert torch.equal(out, out0)
        assert torch.equal(lse, lse0)
        # NOTE(review): the S/dmask output is not compared. An older note (referring to a
        # since-renamed `sm_lse`) said parts of it were left uninitialized by torch.empty.

        if check_bwd:
            dqkv, = torch.autograd.grad(out, qkv, g)
            # Compare Q, K and V gradients separately so a failure points at the slice.
            assert torch.equal(dqkv[:, :, 0], dqkv0[:, :, 0])
            assert torch.equal(dqkv[:, :, 1], dqkv0[:, :, 1])
            assert torch.equal(dqkv[:, :, 2], dqkv0[:, :, 2])


@pytest.mark.parametrize('dtype', [torch.float16])
@pytest.mark.parametrize('causal', [False, True])
# @pytest.mark.parametrize('causal', [False])
@pytest.mark.parametrize('d', [16, 32, 64])
# @pytest.mark.parametrize('d', [16])
@pytest.mark.parametrize('seqlen', [1, 2, 5, 17, 128])
# @pytest.mark.parametrize('seqlen', [2])
def test_flash_attn_bwd_overflow(seqlen, d, causal, dtype):
    """ We previously had a bug where not masking elements beyond seqlen_k caused NaN in dQ,
    in the case where seqlen % 128 != 0.

    Inputs are scaled up (q by 5, k and v by 3). Gradients must stay within 5x the PyTorch
    (no-upcast) error plus 1e-3, a looser bound than the 2x used elsewhere in this file.
    """
    device = 'cuda'
    # set seed
    torch.random.manual_seed(0)
    batch_size = 2
    nheads = 5
    q = torch.randn([batch_size, seqlen, nheads, d], dtype=dtype, device=device) * 5
    k, v = [torch.randn([batch_size, seqlen, nheads, d], dtype=dtype, device=device) * 3
            for _ in range(2)]
    q.requires_grad_(True)
    k.requires_grad_(True)
    v.requires_grad_(True)
    out = flash_attn_func(q, k, v, causal=causal)
    g = torch.randn_like(out)
    out.backward(g)
    # Independent leaf copies so each implementation accumulates its own .grad.
    q_pt = q.detach().clone().requires_grad_(True)
    k_pt = k.detach().clone().requires_grad_(True)
    v_pt = v.detach().clone().requires_grad_(True)
    out_pt, _ = attention_ref(q_pt, k_pt, v_pt, causal=causal, upcast=False, reorder_ops=True)
    out_pt.backward(g)
    q_ref = q.detach().clone().requires_grad_(True)
    k_ref = k.detach().clone().requires_grad_(True)
    v_ref = v.detach().clone().requires_grad_(True)
    out_ref, _ = attention_ref(q_ref, k_ref, v_ref, causal=causal)
    out_ref.backward(g)
    print(f'dQ max diff: {(q.grad - q_ref.grad).abs().max().item()}')
    print(f'dK max diff: {(k.grad - k_ref.grad).abs().max().item()}')
    print(f'dV max diff: {(v.grad - v_ref.grad).abs().max().item()}')
    print(f'dQ Pytorch max diff: {(q_pt.grad - q_ref.grad).abs().max().item()}')
    print(f'dK Pytorch max diff: {(k_pt.grad - k_ref.grad).abs().max().item()}')
    print(f'dV Pytorch max diff: {(v_pt.grad - v_ref.grad).abs().max().item()}')
    assert (out - out_ref).abs().max().item() <= 2 * (out_pt - out_ref).abs().max().item()
    assert (q.grad - q_ref.grad).abs().max().item() <= 5 * (q_pt.grad - q_ref.grad).abs().max().item() + 1e-3
    assert (k.grad - k_ref.grad).abs().max().item() <= 5 * (k_pt.grad - k_ref.grad).abs().max().item() + 1e-3
    assert (v.grad - v_ref.grad).abs().max().item() <= 5 * (v_pt.grad - v_ref.grad).abs().max().item() + 1e-3


@pytest.mark.parametrize('dtype', ([torch.float16] if is_sm75 else [torch.float16, torch.bfloat16]))
# @pytest.mark.parametrize('dtype', [torch.bfloat16])
@pytest.mark.parametrize('causal', [False, True])
# @pytest.mark.parametrize('causal', [False])
@pytest.mark.parametrize('d', [64, 128])
# @pytest.mark.parametrize('d', [64])
@pytest.mark.parametrize('seqlen', [97, 128, 200, 256])
# @pytest.mark.parametrize('seqlen', [128])
def test_flash_attn_bwd_transpose(seqlen, d, causal, dtype):
    """Regression test: the backward pass must honor the strides of a non-contiguous dout.

    A previous bug used the wrong dout strides. It only shows up when dout is not contiguous,
    so the outputs are transposed to sequence-major layout and the incoming gradient is a
    strided slice.
    """
    device = 'cuda'
    torch.random.manual_seed(0)  # set seed
    batch_size, nheads = 5, 2
    shape = [batch_size, seqlen, nheads, d]
    q, k, v = [torch.randn(shape, dtype=dtype, device=device, requires_grad=True)
               for _ in range(3)]

    def seq_major(t):
        # (batch, seq, ...) -> (seq, batch, ...)
        return rearrange(t, "b s ... -> s b ...")

    def fresh_leaf(t):
        # Independent leaf copy so each implementation accumulates its own .grad.
        return t.detach().clone().requires_grad_(True)

    out = seq_major(flash_attn_func(q, k, v, causal=causal))
    # Taking every other batch entry of a double-width tensor makes g non-contiguous.
    g = torch.randn(seqlen, 2 * batch_size, nheads, d, dtype=dtype, device=device)[:, ::2]
    out.backward(g)

    q_pt, k_pt, v_pt = (fresh_leaf(t) for t in (q, k, v))
    out_pt, _ = attention_ref(q_pt, k_pt, v_pt, causal=causal, upcast=False, reorder_ops=True)
    out_pt = seq_major(out_pt)
    out_pt.backward(g)

    q_ref, k_ref, v_ref = (fresh_leaf(t) for t in (q, k, v))
    out_ref, _ = attention_ref(q_ref, k_ref, v_ref, causal=causal)
    out_ref = seq_major(out_ref)
    out_ref.backward(g)

    grads = (("Q", q, q_pt, q_ref), ("K", k, k_pt, k_ref), ("V", v, v_pt, v_ref))
    for name, flash, _, ref in grads:
        print(f'd{name} max diff: {(flash.grad - ref.grad).abs().max().item()}')
    for name, _, torch_impl, ref in grads:
        print(f'd{name} Pytorch max diff: {(torch_impl.grad - ref.grad).abs().max().item()}')
    assert (out - out_ref).abs().max().item() <= 2 * (out_pt - out_ref).abs().max().item()
    for _, flash, torch_impl, ref in grads:
        assert (flash.grad - ref.grad).abs().max().item() <= 2 * (torch_impl.grad - ref.grad).abs().max().item()