test_flash_attention.py 6.35 KB
Newer Older
oahzxl's avatar
oahzxl committed
1
import pytest
2
import torch
oahzxl's avatar
oahzxl committed
3
from einops import rearrange
4

5
from colossalai.kernel.cuda_native.flash_attention import HAS_FLASH_ATTN, HAS_MEM_EFF_ATTN, HAS_TRITON
6
7

if HAS_FLASH_ATTN:
8
    from colossalai.kernel.cuda_native.flash_attention import (
zbian's avatar
zbian committed
9
10
11
12
13
        MaskedFlashAttention,
        flash_attention_q_k_v,
        flash_attention_q_kv,
        flash_attention_qkv,
    )
14
15
16

if HAS_TRITON:
    from colossalai.kernel.cuda_native.flash_attention import triton_flash_attention
oahzxl's avatar
oahzxl committed
17

18
19
20
if HAS_MEM_EFF_ATTN:
    from colossalai.kernel.cuda_native.flash_attention import LowerTriangularMask, MemoryEfficientAttention

oahzxl's avatar
oahzxl committed
21
22
23
24
25
26
27
28
29
30
31

def baseline_attention(Z, N_CTX, H, q, k, v, sm_scale):
    M = torch.tril(torch.ones((N_CTX, N_CTX), device="cuda"))
    p = torch.matmul(q, k.transpose(2, 3)) * sm_scale
    for z in range(Z):
        for h in range(H):
            p[:, :, M == 0] = float("-inf")
    p = torch.softmax(p.float(), dim=-1).half()
    ref_out = torch.matmul(p, v)
    return ref_out

32

33
34
@pytest.mark.skipif(HAS_TRITON == False, reason="triton is not available")
@pytest.mark.parametrize('Z, H, N_CTX, D_HEAD', [(3, 4, 2, 16)])
oahzxl's avatar
oahzxl committed
35
36
37
38
39
40
41
def test_triton_flash_attention(Z, H, N_CTX, D_HEAD, dtype=torch.float16):
    torch.manual_seed(20)
    q = torch.empty((Z, H, N_CTX, D_HEAD), dtype=dtype, device="cuda").normal_(mean=0, std=.5).requires_grad_()
    k = torch.empty((Z, H, N_CTX, D_HEAD), dtype=dtype, device="cuda").normal_(mean=0, std=.5).requires_grad_()
    v = torch.empty((Z, H, N_CTX, D_HEAD), dtype=dtype, device="cuda").normal_(mean=0, std=.5).requires_grad_()
    sm_scale = 0.3
    dout = torch.randn_like(q)
42

oahzxl's avatar
oahzxl committed
43
44
45
46
47
48
49
    ref_out = baseline_attention(Z, N_CTX, H, q, k, v, sm_scale)
    ref_out.backward(dout)
    ref_dv, v.grad = v.grad.clone(), None
    ref_dk, k.grad = k.grad.clone(), None
    ref_dq, q.grad = q.grad.clone(), None

    # triton implementation
50
51
52
53
54
55
56
57
58
59
    tri_out = triton_flash_attention(q, k, v, sm_scale)
    tri_out.backward(dout)
    tri_dv, v.grad = v.grad.clone(), None
    tri_dk, k.grad = k.grad.clone(), None
    tri_dq, q.grad = q.grad.clone(), None
    # compare
    assert torch.allclose(ref_out, tri_out, atol=1e-3)
    assert torch.allclose(ref_dv, tri_dv, atol=1e-3)
    assert torch.allclose(ref_dk, tri_dk, atol=1e-3)
    assert torch.allclose(ref_dq, tri_dq, atol=1e-3)
oahzxl's avatar
oahzxl committed
60
61


62
@pytest.mark.skipif(HAS_FLASH_ATTN == False, reason="flash is not available")
63
@pytest.mark.parametrize('Z, H, N_CTX, D_HEAD', [(3, 4, 2, 16)])
oahzxl's avatar
oahzxl committed
64
65
66
67
68
69
70
def test_flash_attention(Z, H, N_CTX, D_HEAD, dtype=torch.float16):
    torch.manual_seed(20)
    q = torch.randn((Z, H, N_CTX, D_HEAD), dtype=dtype, device="cuda").normal_(mean=0, std=.5).requires_grad_()
    k = torch.randn((Z, H, N_CTX, D_HEAD), dtype=dtype, device="cuda").normal_(mean=0, std=.5).requires_grad_()
    v = torch.randn((Z, H, N_CTX, D_HEAD), dtype=dtype, device="cuda").normal_(mean=0, std=.5).requires_grad_()
    sm_scale = 0.3
    dout = torch.randn_like(q)
71

oahzxl's avatar
oahzxl committed
72
73
74
75
76
77
    # reference implementation
    ref_out = baseline_attention(Z, N_CTX, H, q, k, v, sm_scale)
    ref_out.backward(dout)
    ref_dv, v.grad = v.grad.clone(), None
    ref_dk, k.grad = k.grad.clone(), None
    ref_dq, q.grad = q.grad.clone(), None
78

oahzxl's avatar
oahzxl committed
79
80
81
    # flash implementation
    q, k, v = map(lambda x: rearrange(x, 'z h n d -> (z n) h d'), [q, k, v])
    dout = rearrange(dout, 'z h n d -> (z n) h d').detach()
82
83
84
85
86
87
88
89
90
    for i in range(3):
        if i == 0:
            tri_out = flash_attention_q_k_v(q, k, v, sm_scale, Z, N_CTX, N_CTX, causal=True)
        elif i == 1:
            kv = torch.cat((k.unsqueeze(1), v.unsqueeze(1)), dim=1)
            tri_out = flash_attention_q_kv(q, kv, sm_scale, Z, N_CTX, N_CTX, causal=True)
        else:
            qkv = torch.cat((q.unsqueeze(1), k.unsqueeze(1), v.unsqueeze(1)), dim=1)
            tri_out = flash_attention_qkv(qkv, sm_scale, Z, N_CTX, causal=True)
oahzxl's avatar
oahzxl committed
91

92
93
94
95
96
        tri_out.backward(dout, retain_graph=True)

        if i == 0:
            tri_dq, tri_dk, tri_dv, = torch.autograd.grad(tri_out, (q, k, v), dout)
            tri_out, tri_dq, tri_dk, tri_dv = map(lambda x: rearrange(x, '(z n) h d -> z h n d', z=Z),
zbian's avatar
zbian committed
97
                                                  (tri_out, tri_dq, tri_dk, tri_dv))
98
99
100
101
        elif i == 1:
            tri_dq, tri_dkv, = torch.autograd.grad(tri_out, (q, kv), dout)
            tri_dk, tri_dv = torch.chunk(tri_dkv, 2, dim=1)
            tri_out, tri_dq, tri_dk, tri_dv = map(lambda x: rearrange(x, '(z n) h d -> z h n d', z=Z),
zbian's avatar
zbian committed
102
                                                  (tri_out, tri_dq, tri_dk.squeeze(1), tri_dv.squeeze(1)))
103
104
105
106
        else:
            tri_dqkv, = torch.autograd.grad(tri_out, (qkv), dout)
            tri_dq, tri_dk, tri_dv = torch.chunk(tri_dqkv, 3, dim=1)
            tri_out, tri_dq, tri_dk, tri_dv = map(lambda x: rearrange(x, '(z n) h d -> z h n d', z=Z),
zbian's avatar
zbian committed
107
                                                  (tri_out, tri_dq.squeeze(1), tri_dk.squeeze(1), tri_dv.squeeze(1)))
108
109
110
111
112
113
114
115

        # compare
        assert torch.allclose(ref_out, tri_out, atol=1e-3)
        assert torch.allclose(ref_dv, tri_dv, atol=1e-3)
        assert torch.allclose(ref_dk, tri_dk, atol=1e-3)
        assert torch.allclose(ref_dq, tri_dq, atol=1e-3)


zbian's avatar
zbian committed
116
117
118
119
120
121
122
123
124
125
126
127
128
129
@pytest.mark.skipif(HAS_FLASH_ATTN == False, reason="flash is not available")
@pytest.mark.parametrize('Z, H, N_CTX, D_HEAD', [(3, 4, 2, 16)])
def test_masked_flash_attention(Z, H, N_CTX, D_HEAD, dtype=torch.float16):
    attn = MaskedFlashAttention(N_CTX, D_HEAD, 0.1)

    qkv = torch.randn((Z, H, 3 * N_CTX * D_HEAD), dtype=dtype, device="cuda").normal_(mean=0, std=.5).requires_grad_()
    attention_mask = torch.randint(2, (Z, H)).cuda().bool()

    out = attn(qkv, attention_mask)

    dout = torch.rand_like(out)
    out.backward(dout)


130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
@pytest.mark.skipif(HAS_MEM_EFF_ATTN == False, reason="xformers is not available")
@pytest.mark.parametrize('Z, H, N_CTX, D_HEAD', [(6, 8, 4, 16)])
def test_memory_efficient_attention(Z, H, N_CTX, D_HEAD, dtype=torch.float16):
    attn = MemoryEfficientAttention(N_CTX * D_HEAD, N_CTX, 0.1)

    q = torch.empty((Z, H, N_CTX, D_HEAD), dtype=dtype, device="cuda").normal_(mean=0, std=.5).requires_grad_()
    k = torch.empty((Z, H, N_CTX, D_HEAD), dtype=dtype, device="cuda").normal_(mean=0, std=.5).requires_grad_()
    v = torch.empty((Z, H, N_CTX, D_HEAD), dtype=dtype, device="cuda").normal_(mean=0, std=.5).requires_grad_()

    out = attn(q, k, v, attention_mask=LowerTriangularMask())

    dout = torch.rand_like(out)
    out.backward(dout)


145
146
if __name__ == '__main__':
    test_flash_attention(3, 4, 2, 16)