"...AutoBuildImmortalWrt.git" did not exist on "7509ba9d61f037159b2fa171c79be016f2ecd7c4"
test_flash_attention.py 6.85 KB
Newer Older
import math

import pytest
import torch
from einops import rearrange

from colossalai.kernel.extensions.flash_attention import HAS_FLASH_ATTN, HAS_MEM_EFF_ATTN
from colossalai.testing import clear_cache_before_run, parameterize

if HAS_MEM_EFF_ATTN or HAS_FLASH_ATTN:
    from colossalai.nn.layer.colo_attention import AttnMaskType, ColoAttention

DTYPE = [torch.float16, torch.bfloat16, torch.float32]


def attention_ref(q, k, v, attn_mask=None, causal=False):
    """
    attention output of the control group
    """
    dtype_og = q.dtype
    seqlen_q, seqlen_k = q.shape[1], k.shape[1]
    d = q.shape[-1]
    scale = 1.0 / math.sqrt(d)
    # scores: [B, H, seqlen_q, seqlen_k]; q/k/v use (batch, seq, head, dim) layout
    scores = torch.einsum("bthd,bshd->bhts", q * scale, k)

    if attn_mask is not None:
        scores.masked_fill_(rearrange(~attn_mask, "b s -> b 1 1 s"), float("-inf"))
    if causal:
        causal_mask = torch.triu(torch.ones(seqlen_q, seqlen_k, dtype=torch.bool, device=q.device), 1)
        scores.masked_fill_(causal_mask, float("-inf"))
    attention = torch.softmax(scores, dim=-1)

    output = torch.einsum("bhts,bshd->bthd", attention, v)
    output = rearrange(output, "b s h d -> b s (h d)")

    # zero out the output at padded (masked) positions
    if attn_mask is not None:
        output.masked_fill_(rearrange(~attn_mask, "b s -> b s 1"), 0.0)

    return output.to(dtype=dtype_og)


@pytest.mark.skipif(not HAS_MEM_EFF_ATTN and not HAS_FLASH_ATTN, reason="xformers is not available")
@clear_cache_before_run()
@parameterize("proj_shape", [(6, 8, 4, 16)])
@parameterize("dtype", DTYPE)
@parameterize("dropout", [0.0])
def test_attention_gpt(proj_shape, dtype, dropout):
    (B, S, H, D_HEAD) = proj_shape
    D = H * D_HEAD

    q = torch.randn((B, S, H, D_HEAD), dtype=dtype, device="cuda", requires_grad=True)
    k = torch.randn((B, S, H, D_HEAD), dtype=dtype, device="cuda", requires_grad=True)
    v = torch.randn((B, S, H, D_HEAD), dtype=dtype, device="cuda", requires_grad=True)

    # padding mask: sample i keeps its first S - i tokens
    mask = [torch.ones(S - i, dtype=torch.bool, device="cuda") for i in range(B)]
    mask = torch.nn.utils.rnn.pad_sequence(mask, batch_first=True)

    attn = ColoAttention(D, H, dropout=dropout)
    y = attn(q, k, v, attn_mask=mask, attn_mask_type=AttnMaskType.paddedcausal)

    assert list(y.shape) == [B, S, D]

    out_ref = attention_ref(q, k, v, mask, causal=True)

    # check gradients
    dy = torch.rand_like(y)
    grad_q, grad_k, grad_v = torch.autograd.grad(y, (q, k, v), dy)
    grad_ref_q, grad_ref_k, grad_ref_v = torch.autograd.grad(out_ref, (q, k, v), dy)

    assert torch.allclose(y, out_ref, atol=1e-7), f"{(y - out_ref).abs().max()}"
    assert torch.allclose(grad_q, grad_ref_q, atol=1e-7), f"{(grad_q - grad_ref_q).abs().max()}"
    assert torch.allclose(grad_k, grad_ref_k, atol=1e-7), f"{(grad_k - grad_ref_k).abs().max()}"
    assert torch.allclose(grad_v, grad_ref_v, atol=1e-7), f"{(grad_v - grad_ref_v).abs().max()}"


@pytest.mark.skipif(not HAS_MEM_EFF_ATTN and not HAS_FLASH_ATTN, reason="xformers is not available")
@clear_cache_before_run()
@parameterize("proj_shape", [(6, 8, 4, 16)])
@parameterize("dtype", DTYPE)
@parameterize("dropout", [0.0])
def test_attention_bert(proj_shape, dtype, dropout):
    (B, S, H, D_HEAD) = proj_shape
    D = H * D_HEAD

    q = torch.randn((B, S, H, D_HEAD), dtype=dtype, device="cuda", requires_grad=True)
    k = torch.randn((B, S, H, D_HEAD), dtype=dtype, device="cuda", requires_grad=True)
    v = torch.randn((B, S, H, D_HEAD), dtype=dtype, device="cuda", requires_grad=True)

    # attention mask of shape [B, S] with zero padding to max length S
    mask = torch.randint(0, 2, (B, S), dtype=torch.bool, device="cuda")

    attn = ColoAttention(D, H, dropout=dropout)
    y = attn(q, k, v, attn_mask=mask, attn_mask_type=AttnMaskType.padding)

    assert list(y.shape) == [B, S, D]

    out_ref = attention_ref(q, k, v, mask, causal=False)

    dy = torch.rand_like(y)
    grad_q, grad_k, grad_v = torch.autograd.grad(y, (q, k, v), dy)
    grad_ref_q, grad_ref_k, grad_ref_v = torch.autograd.grad(out_ref, (q, k, v), dy)

    assert torch.allclose(y, out_ref, atol=1e-7), f"{(y - out_ref).abs().max()}"
    assert torch.allclose(grad_q, grad_ref_q, atol=1e-7), f"{(grad_q - grad_ref_q).abs().max()}"
    assert torch.allclose(grad_k, grad_ref_k, atol=1e-7), f"{(grad_k - grad_ref_k).abs().max()}"
    assert torch.allclose(grad_v, grad_ref_v, atol=1e-7), f"{(grad_v - grad_ref_v).abs().max()}"


@pytest.mark.skipif(not HAS_MEM_EFF_ATTN and not HAS_FLASH_ATTN, reason="xformers is not available")
@clear_cache_before_run()
@parameterize("proj_shape", [(6, 8, 4, 16)])
@parameterize("dtype", DTYPE)
@parameterize("dropout", [0.0])
def test_attention_no_mask(proj_shape, dtype, dropout):
    (B, S, H, D_HEAD) = proj_shape
    D = H * D_HEAD

    q = torch.randn((B, S, H, D_HEAD), dtype=dtype, device="cuda", requires_grad=True)
    k = torch.randn((B, S, H, D_HEAD), dtype=dtype, device="cuda", requires_grad=True)
    v = torch.randn((B, S, H, D_HEAD), dtype=dtype, device="cuda", requires_grad=True)

    attn = ColoAttention(D, H, dropout=dropout)
    y = attn(q, k, v)

    assert list(y.shape) == [B, S, D]

    out_ref = attention_ref(q, k, v, None, causal=False)

    dy = torch.rand_like(y)
    grad_q, grad_k, grad_v = torch.autograd.grad(y, (q, k, v), dy)
    grad_ref_q, grad_ref_k, grad_ref_v = torch.autograd.grad(out_ref, (q, k, v), dy)

    assert torch.allclose(y, out_ref, atol=1e-7), f"{(y - out_ref).abs().max()}"
    assert torch.allclose(grad_q, grad_ref_q, atol=1e-7), f"{(grad_q - grad_ref_q).abs().max()}"
    assert torch.allclose(grad_k, grad_ref_k, atol=1e-7), f"{(grad_k - grad_ref_k).abs().max()}"
    assert torch.allclose(grad_v, grad_ref_v, atol=1e-7), f"{(grad_v - grad_ref_v).abs().max()}"


@pytest.mark.skipif(not HAS_MEM_EFF_ATTN and not HAS_FLASH_ATTN, reason="xformers is not available")
@clear_cache_before_run()
@parameterize("proj_shape", [(6, 24, 8, 4, 16)])
@parameterize("dtype", DTYPE)
@parameterize("dropout", [0.0])
def test_cross_attention(proj_shape, dtype, dropout):
    (B, S, T, H, D_HEAD) = proj_shape
    D = H * D_HEAD

    # cross attention: queries of length T attend over keys/values of length S
    q = torch.randn((B, T, H, D_HEAD), dtype=dtype, device="cuda", requires_grad=True)
    k = torch.randn((B, S, H, D_HEAD), dtype=dtype, device="cuda", requires_grad=True)
    v = torch.randn((B, S, H, D_HEAD), dtype=dtype, device="cuda", requires_grad=True)

    attn = ColoAttention(D, H, dropout=dropout)
    y = attn(q, k, v, attn_mask_type=AttnMaskType.causal)

    assert list(y.shape) == [B, T, D]

    out_ref = attention_ref(q, k, v, None, causal=True)

    dy = torch.rand_like(y)
    grad_q, grad_k, grad_v = torch.autograd.grad(y, (q, k, v), dy)
    grad_ref_q, grad_ref_k, grad_ref_v = torch.autograd.grad(out_ref, (q, k, v), dy)

    assert torch.allclose(y, out_ref, atol=1e-7), f"{(y - out_ref).abs().max()}"
    assert torch.allclose(grad_q, grad_ref_q, atol=1e-7), f"{(grad_q - grad_ref_q).abs().max()}"
    assert torch.allclose(grad_k, grad_ref_k, atol=1e-7), f"{(grad_k - grad_ref_k).abs().max()}"
    assert torch.allclose(grad_v, grad_ref_v, atol=1e-7), f"{(grad_v - grad_ref_v).abs().max()}"
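

# Minimal direct-run sketch: assumes colossalai's @parameterize sweeps every listed value
# when the wrapped test is called with no arguments, and that a CUDA device with
# xformers or flash-attn installed is available.
if __name__ == "__main__":
    test_attention_gpt()
    test_attention_bert()
    test_attention_no_mask()
    test_cross_attention()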