test_flash_attention.py

import random

import pytest
import torch
from einops import rearrange

from colossalai.kernel.cuda_native.mha.flash_attn_2 import HAS_FLASH_ATTN
from colossalai.kernel.cuda_native.mha.mem_eff_attn import HAS_MEM_EFF_ATTN
from colossalai.testing import clear_cache_before_run, parameterize

if HAS_MEM_EFF_ATTN or HAS_FLASH_ATTN:
    from colossalai.kernel.cuda_native import ColoAttention
    from colossalai.kernel.cuda_native.scaled_softmax import AttnMaskType

DTYPE = [torch.float16, torch.bfloat16, torch.float32]


def baseline_attention(Z, N_CTX, H, q, k, v, sm_scale):
    # naive reference: materialize the full [N_CTX, N_CTX] causal mask
    M = torch.tril(torch.ones((N_CTX, N_CTX), device="cuda"))
    p = torch.matmul(q, k.transpose(2, 3)) * sm_scale
    for z in range(Z):
        for h in range(H):
            # mask out future positions for each batch and head
            p[z, h, M == 0] = float("-inf")
    # softmax in fp32 for numerical stability, then cast back to the input dtype
    p = torch.softmax(p.float(), dim=-1).to(q.dtype)
    ref_out = torch.matmul(p, v)
    return ref_out


@pytest.mark.skipif(not HAS_MEM_EFF_ATTN and not HAS_FLASH_ATTN, reason="xformers is not available")
@clear_cache_before_run()
@parameterize('proj_shape', [(1, 8, 4, 16)])
@parameterize('dtype', DTYPE)
def test_attention_gpt(proj_shape, dtype):
    # TODO check output value
    (B, S, H, D_HEAD) = proj_shape
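    # B: batch size, S: sequence length, H: number of heads, D_HEAD: head dimension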
    D = H * D_HEAD

    c_attn = torch.nn.Linear(D, 3 * D, dtype=dtype, device="cuda")
    attn = ColoAttention(D, H, dropout=0.1)

    x = torch.randn((B, S, D), dtype=dtype, device="cuda")

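    # fused QKV projection, split into per-head q, k, v of shape [B, S, H, D_HEAD]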
    qkv = c_attn(x)
    q, k, v = rearrange(qkv, 'b s (n h d) -> n b s h d', n=3, h=H)

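    # attention mask of shape [B, S] with zero padding to max length S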
    mask = [torch.ones(S - i, dtype=dtype, device="cuda") for i in range(B)]
    mask = torch.nn.utils.rnn.pad_sequence(mask, batch_first=True)

    y = attn(q, k, v, attn_mask=mask, attn_mask_type=AttnMaskType.paddedcausal)

    assert list(y.shape) == [B, S, D]

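    # smoke-test the backward pass through the attention kernel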
    dy = torch.rand_like(y)
    y.backward(dy)


@pytest.mark.skipif(not HAS_MEM_EFF_ATTN and not HAS_FLASH_ATTN, reason="xformers is not available")
@clear_cache_before_run()
@parameterize('proj_shape', [(6, 8, 4, 16)])
@parameterize('dtype', DTYPE)
def test_attention_bert(proj_shape, dtype):
    (B, S, H, D_HEAD) = proj_shape
    D = H * D_HEAD

    c_attn = torch.nn.Linear(D, 3 * D, dtype=dtype, device="cuda")
    attn = ColoAttention(D, H, dropout=0.1)

    x = torch.randn((B, S, D), dtype=dtype, device="cuda")
    # attention mask of shape [B, S] with zero padding to max length S
    mask = [torch.ones(S - i, dtype=dtype, device="cuda") for i in range(B)]
    mask = torch.nn.utils.rnn.pad_sequence(mask, batch_first=True)

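    # split the fused QKV projection into q, k, v, each of shape [B, S, H, D_HEAD]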
    qkv = c_attn(x)
    q, k, v = rearrange(qkv, 'b s (n h d) -> b s n h d', n=3, h=H).unbind(dim=2)
    y = attn(q, k, v, attn_mask=mask, attn_mask_type=AttnMaskType.padding)

    assert list(y.shape) == [B, S, D]

    dy = torch.rand_like(y)
    y.backward(dy)


@pytest.mark.skipif(not HAS_MEM_EFF_ATTN and not HAS_FLASH_ATTN, reason="xformers is not available")
@clear_cache_before_run()
@parameterize('proj_shape', [(6, 8, 4, 16)])
@parameterize('dtype', DTYPE)
def test_attention_no_mask(proj_shape, dtype):
    (B, S, H, D_HEAD) = proj_shape
    D = H * D_HEAD

    c_attn = torch.nn.Linear(D, 3 * D, dtype=dtype, device="cuda")
    attn = ColoAttention(D, H, dropout=0.1)

    x = torch.randn((B, S, D), dtype=dtype, device="cuda")
    qkv = c_attn(x)
    q, k, v = rearrange(qkv, 'b s (n h d) -> b s n h d', n=3, h=H).unbind(dim=2)
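    # no attention mask: every position attends to all others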
    y = attn(q, k, v)

    assert list(y.shape) == [B, S, D]

    dy = torch.rand_like(y)
    y.backward(dy)


@pytest.mark.skipif(not HAS_MEM_EFF_ATTN and not HAS_FLASH_ATTN, reason="xformers is not available")
@clear_cache_before_run()
@parameterize('proj_shape', [(6, 24, 8, 4, 16)])
@parameterize('dtype', DTYPE)
def test_cross_attention(proj_shape, dtype):
    (B, S, T, H, D_HEAD) = proj_shape
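    # B: batch size, S: source length, T: target length, H: number of heads, D_HEAD: head dimension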
    D = H * D_HEAD

    q_attn = torch.nn.Linear(D, D, dtype=dtype, device="cuda")
    kv_attn = torch.nn.Linear(D, 2 * D, dtype=dtype, device="cuda")

    attn = ColoAttention(D, H, dropout=0.1)

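    # cross-attention: queries come from tgt, keys and values come from src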
    src = torch.randn((B, S, D), dtype=dtype, device="cuda")
    tgt = torch.randn((B, T, D), dtype=dtype, device="cuda")

    q = q_attn(tgt)
    kv = kv_attn(src)
    q = rearrange(q, 'b s (h d) -> b s h d', h=H)
    k, v = rearrange(kv, 'b s (n h d) -> b s n h d', n=2, h=H).unbind(dim=2)
    y = attn(q, k, v, attn_mask_type=AttnMaskType.causal)

    assert list(y.shape) == [B, T, D]

    dy = torch.rand_like(y)
    y.backward(dy)