"lib/engines/sglang/Cargo.toml" did not exist on "03b0101e4d4013874e33f8144c9793567e762c9f"
test_flash_attention.py 5.55 KB
Newer Older
1
2
3
4
5
6
import math
from copy import copy

import torch
from torch.testing import assert_close

from colossalai.kernel.kernel_loader import FlashAttentionLoader, FlashAttentionWithCustomMaskLoader
from colossalai.shardformer.layer import AttnMaskType, ColoAttention
from colossalai.shardformer.layer.attn import invert_mask
from colossalai.testing import clear_cache_before_run, parameterize
from colossalai.utils import get_current_device, set_seed
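
# What this file tests: every available flash-attention extension (plus the ColoAttention
# baseline) is run against the unfused reference implementation below for several mask
# types, and both forward outputs and dQ/dK/dV gradients are compared within per-dtype
# tolerances.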

DTYPE = [torch.float16, torch.bfloat16]
# batch size, number of attention heads, sequence length, head dimension
B, N, S, D = 2, 8, 256, 32

# Per-dtype tolerances for assert_close; bfloat16 uses assert_close's dtype defaults,
# which are already loose enough for its reduced precision.
TOL_MAP = {
    torch.float16: {"atol": 5e-4, "rtol": 2e-3},
    torch.bfloat16: {},
}


def attention_ref(q: torch.Tensor, k: torch.Tensor, v: torch.Tensor, attn_mask=None, dropout_p=0.0):
    """Unfused reference attention: scaled QK^T, optional additive mask, fp32 softmax, dropout, then V."""
    head_dim = q.size(-1)
    attn_weights = torch.matmul(q, k.transpose(2, 3)) / math.sqrt(head_dim)
    if attn_mask is not None:
        attn_weights = attn_weights + attn_mask
    attn_weights = torch.softmax(attn_weights, dim=-1, dtype=torch.float).to(q.dtype)
    attn_weights = torch.dropout(attn_weights, p=dropout_p, train=True)
    attn_output = torch.matmul(attn_weights, v)
    return attn_output
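
# For context: with dropout_p=0.0 the reference above should be numerically close to
# PyTorch's fused kernel (available since PyTorch 2.0), e.g.
#
#     import torch.nn.functional as F
#     out = F.scaled_dot_product_attention(q, k, v, attn_mask=attn_mask)
#
# The unfused version is kept so that every flash-attention extension under test is
# compared against the same plain matmul/softmax baseline.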


def gen_padded_kwargs(dtype: torch.dtype):
    """Padding mask only: the first quarter of tokens in the first batch element is padding."""
    padding_mask = torch.ones((B, S), dtype=torch.int, device=get_current_device())
    padding_mask[0, : S // 4] = 0
    return (
        ColoAttention.prepare_attn_kwargs((B, 1, S, S), dtype, padding_mask.device, q_padding_mask=padding_mask),
        padding_mask,
    )


def gen_padded_causal_kwargs(dtype: torch.dtype):
    """Padding mask combined with a causal mask: the second half of the first batch element is padding."""
    padding_mask = torch.ones((B, S), dtype=torch.int, device=get_current_device())
    padding_mask[0, S // 2 :] = 0
    return (
        ColoAttention.prepare_attn_kwargs(
            (B, 1, S, S), dtype, padding_mask.device, q_padding_mask=padding_mask, is_causal=True
        ),
        padding_mask,
    )


def gen_causal_kwargs(dtype: torch.dtype):
    """Pure causal (lower-triangular) mask, no padding."""
    return ColoAttention.prepare_attn_kwargs((B, 1, S, S), dtype, get_current_device(), is_causal=True), None
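
# Illustrative only: the causal variant corresponds to an additive mask in which position i
# may attend to positions <= i, i.e. roughly
#
#     causal = torch.full((S, S), float("-inf")).triu(diagonal=1)
#
# prepare_attn_kwargs builds the exact mask/type pair expected by ColoAttention, so the
# helpers here never construct it by hand.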


def gen_custom_kwargs(dtype: torch.dtype):
    """Arbitrary block-structured mask passed directly as a dense additive `attention_mask`."""
    attn_mask = torch.ones((B, S, S), dtype=dtype, device=get_current_device())
    attn_mask[0, : S // 2, S // 2 :] = 0
    attn_mask[0, S // 2 :, : S // 2] = 0
    attn_mask[1, :, S // 4 :] = 0
    attn_mask = invert_mask(attn_mask).unsqueeze(1)
    # no row may be fully masked, otherwise the softmax over that row is undefined
    assert not torch.all(attn_mask != 0, dim=-1).any()
    return {"attention_mask": attn_mask}, None


def post_process_kwargs_for_raw_attn(attn_kwargs: dict):
    """Translate ColoAttention-style kwargs for a raw kernel: replace the AttnMaskType enum with an `is_causal` flag."""
    if "attention_mask_type" in attn_kwargs:
        attn_kwargs = copy(attn_kwargs)
        mask_type = attn_kwargs.pop("attention_mask_type")
        attn_kwargs["is_causal"] = mask_type in (AttnMaskType.CAUSAL, AttnMaskType.PADDED_CAUSAL)
    return attn_kwargs
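
# Illustrative only (exact keys depend on the ColossalAI version): prepare_attn_kwargs
# returns a dict along the lines of
#
#     {"attention_mask": <additive mask>, "attention_mask_type": AttnMaskType.CAUSAL, ...}
#
# ColoAttention.attention consumes it as-is, while a raw kernel loaded through
# FlashAttentionLoader does not know the AttnMaskType enum, hence the translation to a
# plain `is_causal` flag above.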


def check_attn_func(dtype: torch.dtype, attn_func, attn_kwargs: dict, padding_mask=None):
    """Run `attn_func` and the unfused reference on identical inputs and compare outputs and gradients."""
    tols = TOL_MAP[dtype]
    q = torch.rand((B, N, S, D), dtype=dtype, device=get_current_device(), requires_grad=True)
    k = torch.rand((B, N, S, D), dtype=dtype, device=get_current_device(), requires_grad=True)
    v = torch.rand((B, N, S, D), dtype=dtype, device=get_current_device(), requires_grad=True)
    q_flash = q.clone().detach().requires_grad_(True)
    k_flash = k.clone().detach().requires_grad_(True)
    v_flash = v.clone().detach().requires_grad_(True)
    attn_mask = attn_kwargs.get("attention_mask", None)
    ref_output = attention_ref(q, k, v, attn_mask)
    output = attn_func(q_flash, k_flash, v_flash, **attn_kwargs)
    if padding_mask is not None:
        # outputs at padded query positions are undefined, so zero them in both tensors
        # before comparing; [B, Sq] -> [B, 1, Sq, 1]
        padding_mask = padding_mask[:, None, :, None].logical_not()
        ref_output = ref_output.masked_fill(padding_mask, 0)
        output = output.masked_fill(padding_mask, 0)
    assert_close(output, ref_output, **tols)
    output.mean().backward()
    ref_output.mean().backward()
    assert_close(q.grad, q_flash.grad, **tols)
    assert_close(k.grad, k_flash.grad, **tols)
    assert_close(v.grad, v_flash.grad, **tols)


@clear_cache_before_run()
@parameterize("dtype", DTYPE)
def test_flash_attn_func(dtype: torch.dtype):
    torch.backends.cudnn.deterministic = True
    set_seed(0)
    # (func, name, need_postprocess)
    avail_attn_funcs = [(ColoAttention.attention, "coloattn", False)]
    avail_custom_mask_attn_funcs = [(ColoAttention.attention, "coloattn", False)]
    avail_padding_mask_attn_funcs = [(ColoAttention.attention, "coloattn", False)]
    for ext_cls in FlashAttentionLoader.REGISTRY:
        ext = ext_cls()
        if ext.is_available():
            ext.assert_compatible()
            avail_attn_funcs.append((ext.load(), ext.name, True))
    for ext_cls in FlashAttentionWithCustomMaskLoader.REGISTRY:
        ext = ext_cls()
        if ext.is_available():
            ext.assert_compatible()
            avail_custom_mask_attn_funcs.append((ext.load(), ext.name, True))

    test_sets = {
        "none": (lambda dtype: ({}, None), avail_attn_funcs),
        "padded": (gen_padded_kwargs, avail_padding_mask_attn_funcs),
        "padded_causal": (gen_padded_causal_kwargs, avail_padding_mask_attn_funcs),
        "causal": (gen_causal_kwargs, avail_attn_funcs),
        "custom": (gen_custom_kwargs, avail_custom_mask_attn_funcs),
    }

    for mask_type, (gen_kwargs_func, attn_funcs) in test_sets.items():
        attn_kwargs, padding_mask = gen_kwargs_func(dtype)
        for attn_func, name, need_postprocess in attn_funcs:
            print(f"{dtype}, {name}, {mask_type}")
            if need_postprocess:
                check_attn_func(dtype, attn_func, post_process_kwargs_for_raw_attn(attn_kwargs), padding_mask)
            else:
                check_attn_func(dtype, attn_func, attn_kwargs, padding_mask)


if __name__ == "__main__":
    test_flash_attn_func()