# test_flash_attn.py
import math

import pytest
import torch
import torch.nn.functional as F

from einops import rearrange, repeat
from flash_attn_interface import flash_attn_func, flash_attn_varlen_func
from tests.test_util import generate_random_padding_mask, generate_qkv, construct_local_mask, attention_ref

ABS_TOL = 5e-3
REL_TOL = 1e-1

def print_diffs(out, out_ref):
    out_1d = out.flatten()
    out_ref_1d = out_ref.flatten()
    for idx, (e_o, e_o_ref) in enumerate(zip(out_1d, out_ref_1d)):
        diff = e_o - e_o_ref
        abs_diff = abs(diff)
        abs_ref = abs(e_o_ref + 1e-5)
        relative_diff = abs_diff / abs_ref
        if abs_diff > ABS_TOL or relative_diff > REL_TOL:
            print(f"==== diff ==== {idx}, test: {e_o}, ref: {e_o_ref}")


@pytest.mark.parametrize("dtype", [torch.float16, torch.bfloat16])
# @pytest.mark.parametrize("dtype", [torch.float8_e4m3fn])
@pytest.mark.parametrize("mha_type", ["mha", "mqa", "gqa"])
# @pytest.mark.parametrize("mha_type", ["mha"])
@pytest.mark.parametrize("causal", [False, True])
# @pytest.mark.parametrize("causal", [True])
# @pytest.mark.parametrize("d", [32, 64, 96, 128, 160, 192, 224, 256])
# @pytest.mark.parametrize('d', [32, 40, 64, 80, 96, 128, 160, 192])
# @pytest.mark.parametrize('d', [32, 64, 96, 128, 160, 192])
# @pytest.mark.parametrize('d', [56, 80])
@pytest.mark.parametrize("d", [64, 128, 256])
# @pytest.mark.parametrize("d", [128])
@pytest.mark.parametrize(
    "seqlen_q,seqlen_k",
    [
        (1, 1),
        (257, 1),
        (64, 128),
        (128, 128),
        (256, 256),
        (113, 203),
        (128, 217),
        (113, 211),
        (108, 256),
        (256, 512),
        (384, 256),
        (640, 128),
        (512, 256),
        (1024, 1024),
        (1023, 1024),
        (1024, 1023),
        (4096, 4096),
    ],
)
# @pytest.mark.parametrize('seqlen_q,seqlen_k', [(128, 128)])
def test_flash_attn_output(
    seqlen_q, seqlen_k, d, causal, mha_type, dtype,
):
    device = "cuda"
    if dtype == torch.float8_e4m3fn:
        dtype_init = torch.float16
    else:
        dtype_init = dtype
    print(dtype)
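    # fp8 path: inputs are created in fp16 (dtype_init), cast to fp8 just before the
    # kernel call, and cast back to fp16 before the reference computation below.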
    # set seed
    torch.random.manual_seed(0)
    # batch_size = 40
    # nheads = 16
    batch_size = 4
    nheads = 6
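    # KV head count per variant: MHA uses one KV head per query head, GQA shares
    # 2 KV heads among the 6 query heads, and MQA uses a single shared KV head.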
    nheads_kv = 6 if mha_type == "mha" else (2 if mha_type == "gqa" else 1)
    # nheads_kv = 2
    # batch_size = 9
    # nheads = 6
    q = torch.randn(batch_size, seqlen_q, nheads, d, device=device, dtype=dtype_init, requires_grad=True)
    k = torch.randn(batch_size, seqlen_k, nheads_kv, d, device=device, dtype=dtype_init, requires_grad=True)
    v = torch.randn(batch_size, seqlen_k, nheads_kv, d, device=device, dtype=dtype_init, requires_grad=True)

    q = q.to(dtype)
    k = k.to(dtype)
    v = v.to(dtype)
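    # flash_attn_func returns the attention output along with the log-sum-exp of the
    # attention scores (lse); the commented-out lse_ref code below shows the
    # reference it could be checked against.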

    out, lse = flash_attn_func(q, k, v, causal=causal)

    q = q.to(dtype_init)
    k = k.to(dtype_init)
    v = v.to(dtype_init)
    
    out_ref, attn_ref = attention_ref(
        q,
        k,
        v,
        None,
        None,
        causal=causal,
    )
    out_pt, attn_pt = attention_ref(
        q,
        k,
        v,
        None,
        None,
        causal=causal,
        upcast=False,
        reorder_ops=True,
    )

    # qk = torch.einsum('bshd,bthd->bhst', q, k).float()
    # m = qk.amax(-1, keepdim=True)
    # s_tmp = torch.exp((qk - m) / math.sqrt(d))
    # exp_sum = s_tmp.sum(-1)
    # qk = torch.einsum('bthd,bshd->bhts', q.float() / math.sqrt(d), k.float())
    # lse_ref = torch.logsumexp(qk, dim=-1)

    print(f"Output max diff: {(out - out_ref).abs().max().item()}")
    print(f"Output mean diff: {(out - out_ref).abs().mean().item()}")
    print(f"Pytorch max diff: {(out_pt - out_ref).abs().max().item()}")
    print(f"Pytorch mean diff: {(out_pt - out_ref).abs().mean().item()}")
    
    # if not causal:
    #     print(f"LSE max diff: {(lse - lse_ref).abs().max().item()}")                
    # breakpoint()

    # if d <= 128:
    #     g = torch.randn_like(out)
    #     do_o = (g.float() * out.float()).sum(-1)
    #     dq, dk, dv = torch.autograd.grad(out, (q, k, v), g)
    #     dq_ref, dk_ref, dv_ref = torch.autograd.grad(out_ref, (q, k, v), g)
    #     dq_pt, dk_pt, dv_pt = torch.autograd.grad(out_pt, (q, k, v), g)
    #     print(f"dQ max diff: {(dq - dq_ref).abs().max().item()}")
    #     print(f"dK max diff: {(dk - dk_ref).abs().max().item()}")
    #     print(f"dV max diff: {(dv - dv_ref).abs().max().item()}")
    #     print(f"dQ mean diff: {(dq - dq_ref).abs().mean().item()}")
    #     print(f"dK mean diff: {(dk - dk_ref).abs().mean().item()}")
    #     print(f"dV mean diff: {(dv - dv_ref).abs().mean().item()}")
    #     print(f"dQ Pytorch max diff: {(dq_pt - dq_ref).abs().max().item()}")
    #     print(f"dK Pytorch max diff: {(dk_pt - dk_ref).abs().max().item()}")
    #     print(f"dV Pytorch max diff: {(dv_pt - dv_ref).abs().max().item()}")
    #     print(f"dQ Pytorch mean diff: {(dq_pt - dq_ref).abs().mean().item()}")
    #     print(f"dK Pytorch mean diff: {(dk_pt - dk_ref).abs().mean().item()}")
    #     print(f"dV Pytorch mean diff: {(dv_pt - dv_ref).abs().mean().item()}")

    # dS = torch.einsum('bthd,bshd->bhts', g.float(), v.float())
    # P = torch.softmax(qk, -1)
    # dP = P * (dS - do_o.unsqueeze(1))
    # dQ = torch.einsum('bhts,bshd->bthd', dP, k.float())
    # dV = torch.einsum('bhts,bthd->bshd', P, g.float())
    # dK = torch.einsum('bhts,bthd->bshd', dP, q.float())
    # breakpoint()

    # Check that FlashAttention's numerical error is at most twice the numerical error
    # of a Pytorch implementation.
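    # out_pt runs in the same low precision as the kernel but with reordered ops, so
    # it measures the rounding error a plain PyTorch implementation already incurs,
    # while out_ref is the fp32-upcast reference.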
    # breakpoint()
    if dtype != torch.float8_e4m3fn:
        assert (out - out_ref).abs().max().item() <= 2 * (out_pt - out_ref).abs().max().item()
    else:
        # Only test basic correctness of the fp8 kernel, without additional
        # quantization techniques, hence the much looser error bound.
        assert (out - out_ref).abs().max().item() <= 40 * (out_pt - out_ref).abs().max().item()

    # if d <= 128:
    #     assert (dq - dq_ref).abs().max().item() <= 2 * (dq_pt - dq_ref).abs().max().item()
    #     assert (dk - dk_ref).abs().max().item() <= 2 * (dk_pt - dk_ref).abs().max().item()
    #     assert (dv - dv_ref).abs().max().item() <= 2 * (dv_pt - dv_ref).abs().max().item()


@pytest.mark.parametrize("dtype", [torch.float16])
@pytest.mark.parametrize("causal", [False, True])
@pytest.mark.parametrize("mha_type", ["mha", "mqa", "gqa"])
# @pytest.mark.parametrize('causal', [True])
# @pytest.mark.parametrize("d", [32, 59, 64, 80, 96, 111, 128, 160, 192, 224, 256])
# @pytest.mark.parametrize("d", [32, 64, 96, 128, 160, 192, 224, 256])
# @pytest.mark.parametrize('d', [128])
@pytest.mark.parametrize("d", [64, 128, 256])
@pytest.mark.parametrize(
    "seqlen_q,seqlen_k",
    [
        (1, 1),
        (1, 3),
        (2, 1),
        (511, 1),
        (3, 513),
        (64, 128),
        (113, 203),
        (128, 128),
        (128, 217),
        (113, 211),
        (108, 256),
        (256, 512),
        (384, 256),
        (512, 256),
        (640, 128),
        (1024, 1024),
        (1023, 1024),
        (1024, 1023),
        (2048, 2048),
    ],
)
# @pytest.mark.parametrize('seqlen_q,seqlen_k', [(128, 128)])
def test_flash_attn_varlen_output(
    seqlen_q, seqlen_k, d, causal, mha_type, dtype
):
    if (
        max(seqlen_q, seqlen_k) >= 2048
        and torch.cuda.get_device_properties("cuda").total_memory <= 16 * 2**30
    ):
        pytest.skip()  # Reference implementation OOM
    device = "cuda"
    # set seed
    torch.random.manual_seed(0)
    # batch_size = 1
    # nheads = 1
    batch_size = 9
    nheads = 6
    nheads_kv = 6 if mha_type == "mha" else (2 if mha_type == "gqa" else 1)
 
    q = torch.randn(batch_size, seqlen_q, nheads, d, device=device, dtype=dtype, requires_grad=True)
    k = torch.randn(
        batch_size, seqlen_k, nheads_kv, d, device=device, dtype=dtype, requires_grad=True
    )
    v = torch.randn(
        batch_size, seqlen_k, nheads_kv, d, device=device, dtype=dtype, requires_grad=True
    )

    query_padding_mask = generate_random_padding_mask(seqlen_q, batch_size, device, mode="random")
    key_padding_mask = generate_random_padding_mask(seqlen_k, batch_size, device, mode="random")
    # key_padding_mask = generate_random_padding_mask(seqlen_k, batch_size, device, mode='full')
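    # mode="random" gives each batch element its own effective sequence length; the
    # commented-out mode="full" variant would leave every key position unpadded.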

    (
        q_unpad,
        k_unpad,
        v_unpad,
        cu_seqlens_q,
        cu_seqlens_k,
        max_seqlen_q,
        max_seqlen_k,
        q,
        k,
        v,
        output_pad_fn,
        dq_pad_fn,
        dk_pad_fn,
    ) = generate_qkv(q, k, v, query_padding_mask, key_padding_mask, kvpacked=False)
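    # generate_qkv removes the padded positions from q/k/v, computes the cumulative
    # and maximum sequence lengths the varlen kernel expects, and returns pad-back
    # functions used to re-pad the output and gradients for comparison.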
    # print("cu_seqlens_q: ", cu_seqlens_q)
    # print("cu_seqlens_k: ", cu_seqlens_k)
    # print("q_unpad, shape: ", q_unpad.shape)
    # print("k_unpad, shape: ", k_unpad.shape)
    # print("v_unpad, shape: ", v_unpad.shape)
    out_unpad, sm_lse = flash_attn_varlen_func(
        q_unpad,
        k_unpad,
        v_unpad,
        cu_seqlens_q,
        cu_seqlens_k,
        max_seqlen_q,
        max_seqlen_k,
        causal=causal,
    )
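    # Re-pad the varlen output to (batch_size, seqlen_q, nheads, d) so it can be
    # compared element-wise against the padded references below.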
    out = output_pad_fn(out_unpad)
    dropout_mask = None

    out_ref, attn_ref = attention_ref(
        q,
        k,
        v,
        query_padding_mask,
        key_padding_mask,
        causal=causal,
    )
    out_pt, attn_pt = attention_ref(
        q,
        k,
        v,
        query_padding_mask,
        key_padding_mask,
        causal=causal,
        upcast=False,
        reorder_ops=True,
    )

    print(f"Output max diff: {(out - out_ref).abs().max().item()}")
    print(f"Output mean diff: {(out - out_ref).abs().mean().item()}")
    print(f"Pytorch max diff: {(out_pt - out_ref).abs().max().item()}")
    print(f"Pytorch mean diff: {(out_pt - out_ref).abs().mean().item()}")

    # g = torch.randn_like(out)
    # if d <= 128:
    #     (
    #         dq_unpad,
    #         dk_unpad,
    #         dv_unpad,
    #     ) = torch.autograd.grad(out, (q_unpad, k_unpad, v_unpad), g)
    #     dk = dk_pad_fn(dk_unpad)
    #     dv = dk_pad_fn(dv_unpad)
    #     (
    #         dq_ref,
    #         dk_ref,
    #         dv_ref,
    #     ) = torch.autograd.grad(out_ref, (q, k, v), g)
    #     (
    #         dq_pt,
    #         dk_pt,
    #         dv_pt,
    #     ) = torch.autograd.grad(out_pt, (q, k, v), g)
    #     dq = dq_pad_fn(dq_unpad)
    #     print(f"dQ max diff: {(dq - dq_ref).abs().max().item()}")
    #     print(f"dK max diff: {(dk - dk_ref).abs().max().item()}")
    #     print(f"dV max diff: {(dv - dv_ref).abs().max().item()}")
    #     print(f"dQ mean diff: {(dq - dq_ref).abs().mean().item()}")
    #     print(f"dK mean diff: {(dk - dk_ref).abs().mean().item()}")
    #     print(f"dV mean diff: {(dv - dv_ref).abs().mean().item()}")
    #     print(f"dQ Pytorch max diff: {(dq_pt - dq_ref).abs().max().item()}")
    #     print(f"dK Pytorch max diff: {(dk_pt - dk_ref).abs().max().item()}")
    #     print(f"dV Pytorch max diff: {(dv_pt - dv_ref).abs().max().item()}")
    #     print(f"dQ Pytorch mean diff: {(dq_pt - dq_ref).abs().mean().item()}")
    #     print(f"dK Pytorch mean diff: {(dk_pt - dk_ref).abs().mean().item()}")
    #     print(f"dV Pytorch mean diff: {(dv_pt - dv_ref).abs().mean().item()}")

    # Check that FlashAttention's numerical error is at most twice the numerical error
    # of a Pytorch implementation.
    assert (out - out_ref).abs().max().item() <= 2 * (out_pt - out_ref).abs().max().item()

    # if d <= 128:
    #     assert (dq - dq_ref).abs().max().item() < 1e-4 or (dq - dq_ref).abs().max().item() <= 3 * (dq_pt - dq_ref).abs().max().item()
    #     assert (dk - dk_ref).abs().max().item() < 1e-4 or (dk - dk_ref).abs().max().item() <= 3 * (dk_pt - dk_ref).abs().max().item()
    #     assert (dv - dv_ref).abs().max().item() < 1e-4 or (dv - dv_ref).abs().max().item() <= 3 * (dv_pt - dv_ref).abs().max().item()