# Tests for the flash-attention and Triton attention kernels in colossalai.kernel.cuda_native.
import pytest
import torch
from einops import rearrange

from colossalai.kernel.cuda_native.flash_attention import HAS_FLASH_ATTN, HAS_TRITON

if HAS_FLASH_ATTN:
    from colossalai.kernel.cuda_native.flash_attention import flash_attention

if HAS_TRITON:
    from colossalai.kernel.cuda_native.flash_attention import triton_flash_attention


def baseline_attention(Z, N_CTX, H, q, k, v, sm_scale):
    """Reference causal attention used to check the fused kernels.

    Computes ``softmax(causal_mask(q @ k^T * sm_scale)) @ v`` in plain PyTorch,
    so autograd can supply reference gradients.

    Args:
        Z: Batch size. Kept for interface compatibility; the mask is broadcast
            over all leading dimensions, so it is not needed.
        N_CTX: Sequence length; the causal mask is ``(N_CTX, N_CTX)``.
        H: Number of heads. Kept for interface compatibility (not needed).
        q, k, v: Tensors laid out as ``(Z, H, N_CTX, D_HEAD)`` (the last two
            dims are used for the matmuls).
        sm_scale: Scale applied to the attention logits before the softmax.

    Returns:
        The attention output, same shape and dtype as ``q``.
    """
    # Lower-triangular "keep" mask, built on q's device rather than a
    # hard-coded "cuda".
    causal = torch.tril(torch.ones((N_CTX, N_CTX), dtype=torch.bool, device=q.device))
    scores = torch.matmul(q, k.transpose(2, 3)) * sm_scale
    # One broadcast fill over all batches/heads: query i may only see keys <= i.
    scores = scores.masked_fill(~causal, float("-inf"))
    # Softmax in fp32 for stability, then cast back to the input dtype so the
    # final matmul matches v (fp16 inputs behave exactly as before).
    probs = torch.softmax(scores.float(), dim=-1).to(q.dtype)
    return torch.matmul(probs, v)


@pytest.mark.skipif(not torch.cuda.is_available() or not (HAS_TRITON or HAS_FLASH_ATTN),
                    reason="CUDA with triton or flash attention is required")
@pytest.mark.parametrize('Z, H, N_CTX, D_HEAD', [(3, 2, 16, 8)])
def test_triton_flash_attention(Z, H, N_CTX, D_HEAD, dtype=torch.float16):
    """Check the Triton kernel's output and q/k/v grads against ``baseline_attention``.

    When Triton is not available, instead check that ``flash_attention`` raises
    ``RuntimeError`` for the ``(Z, H, N_CTX, D_HEAD)`` inputs built here.
    """
    torch.manual_seed(20)
    q = torch.empty((Z, H, N_CTX, D_HEAD), dtype=dtype, device="cuda").normal_(mean=0, std=.5).requires_grad_()
    k = torch.empty((Z, H, N_CTX, D_HEAD), dtype=dtype, device="cuda").normal_(mean=0, std=.5).requires_grad_()
    v = torch.empty((Z, H, N_CTX, D_HEAD), dtype=dtype, device="cuda").normal_(mean=0, std=.5).requires_grad_()
    sm_scale = 0.3
    dout = torch.randn_like(q)

    # reference implementation
    ref_out = baseline_attention(Z, N_CTX, H, q, k, v, sm_scale)
    ref_out.backward(dout)
    # Snapshot the reference grads and reset .grad so the next backward
    # does not accumulate into them.
    ref_dv, v.grad = v.grad.clone(), None
    ref_dk, k.grad = k.grad.clone(), None
    ref_dq, q.grad = q.grad.clone(), None

    if not HAS_TRITON:
        # Only flash_attention is importable here (guaranteed by the skipif).
        # NOTE(review): presumably this asserts that flash_attention rejects
        # inputs not packed as (Z * N_CTX, H, D_HEAD) — confirm intent.
        with pytest.raises(RuntimeError):
            flash_attention(q, k, v, sm_scale, Z, N_CTX)
        return

    # triton implementation
    tri_out = triton_flash_attention(q, k, v, sm_scale)
    tri_out.backward(dout)
    tri_dv, v.grad = v.grad.clone(), None
    tri_dk, k.grad = k.grad.clone(), None
    tri_dq, q.grad = q.grad.clone(), None

    # compare
    assert torch.allclose(ref_out, tri_out, atol=1e-3)
    assert torch.allclose(ref_dv, tri_dv, atol=1e-3)
    assert torch.allclose(ref_dk, tri_dk, atol=1e-3)
    assert torch.allclose(ref_dq, tri_dq, atol=1e-3)


@pytest.mark.skipif(not torch.cuda.is_available() or not HAS_FLASH_ATTN, reason="flash is not available")
@pytest.mark.parametrize('Z, H, N_CTX, D_HEAD', [(3, 2, 16, 8)])
def test_flash_attention(Z, H, N_CTX, D_HEAD, dtype=torch.float16):
    """Check ``flash_attention`` output and q/k/v grads against ``baseline_attention``.

    ``flash_attention`` is fed inputs packed as ``(Z * N_CTX, H, D_HEAD)``; its
    results are unpacked back to ``(Z, H, N_CTX, D_HEAD)`` before comparing.
    """
    torch.manual_seed(20)
    # torch.empty(...).normal_() draws the values once (randn(...).normal_()
    # generated values only to overwrite them).
    q = torch.empty((Z, H, N_CTX, D_HEAD), dtype=dtype, device="cuda").normal_(mean=0, std=.5).requires_grad_()
    k = torch.empty((Z, H, N_CTX, D_HEAD), dtype=dtype, device="cuda").normal_(mean=0, std=.5).requires_grad_()
    v = torch.empty((Z, H, N_CTX, D_HEAD), dtype=dtype, device="cuda").normal_(mean=0, std=.5).requires_grad_()
    sm_scale = 0.3
    dout = torch.randn_like(q)

    # reference implementation
    ref_out = baseline_attention(Z, N_CTX, H, q, k, v, sm_scale)
    ref_out.backward(dout)
    ref_dv, v.grad = v.grad.clone(), None
    ref_dk, k.grad = k.grad.clone(), None
    ref_dq, q.grad = q.grad.clone(), None

    # flash implementation on the packed (Z * N_CTX, H, D_HEAD) layout
    q, k, v = (rearrange(x, 'z h n d -> (z n) h d') for x in (q, k, v))
    flash_out = flash_attention(q, k, v, sm_scale, Z, N_CTX)
    dout = rearrange(dout, 'z h n d -> (z n) h d').detach()
    # autograd.grad returns the grads w.r.t. the packed q/k/v directly, so no
    # separate backward() pass (and no retain_graph) is needed.
    flash_dq, flash_dk, flash_dv = torch.autograd.grad(flash_out, (q, k, v), dout)

    # unpack back to (Z, H, N_CTX, D_HEAD) to line up with the reference
    flash_out, flash_dq, flash_dk, flash_dv = (rearrange(x, '(z n) h d -> z h n d', z=Z)
                                               for x in (flash_out, flash_dq, flash_dk, flash_dv))

    # compare
    assert torch.allclose(ref_out, flash_out, atol=1e-3)
    assert torch.allclose(ref_dv, flash_dv, atol=1e-3)
    assert torch.allclose(ref_dk, flash_dk, atol=1e-3)
    assert torch.allclose(ref_dq, flash_dq, atol=1e-3)