test_sparse_sla_attn.py 5.23 KB
Newer Older
zhangshao's avatar
zhangshao committed
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
import pytest
import torch
import torch.nn.functional as F
import argparse

import pdb

from flash_attn import (
    flash_attn_func,
    sparse_attn_func,
    sparse_attn_with_sla,
)
from flash_attn.utils.sparse_utils import (
    get_block_map_meansim,
    hyperparameter_check,
    block_map_to_block_offset,
    block_map_lut,
    block_map_to_block_offset_triton,
    block_map_lut_triton,
)

# Every test in this module allocates tensors on the GPU, so skip the whole
# module when no CUDA device is visible.
pytestmark = pytest.mark.skipif(
    not torch.cuda.is_available(), reason="Sparse attention tests require CUDA."
)

# Device on which all test tensors are created.
DEVICE = "cuda"
# Sequence block size used to build block_count / block_offset below
# (assumed to match the kernel's KV block size — TODO confirm).
BLOCK_K = 64
# Sentinel written into block_offset slots that hold no selected block.
INVALID_OFFSET = 10_000_000


def _default_dtype():
    if torch.cuda.is_available() and getattr(torch.cuda, "is_bf16_supported", lambda: False)():
        return torch.bfloat16
    return torch.float16


DTYPE = _default_dtype()


def precision_metric(out1, out2):
    """Compare ``out1`` against the reference ``out2`` and return error statistics.

    Both inputs are upcast to float32 before comparison.

    Returns:
        dict with keys
          - ``cos_sim``: cosine similarity of the two flattened tensors,
          - ``l1``: ``sum|out1 - out2| / (sum|out2| + 1e-8)`` (relative L1 error),
          - ``rmse``: root-mean-square of the elementwise difference,
          - ``max_diff``: largest absolute elementwise difference.
    """
    actual = out1.float()
    reference = out2.float()
    delta = actual - reference
    abs_delta = delta.abs()

    similarity = F.cosine_similarity(
        actual.reshape(1, -1), reference.reshape(1, -1)
    ).item()
    # The epsilon keeps the ratio finite when the reference is all zeros.
    relative_l1 = (abs_delta.sum() / (reference.abs().sum() + 1e-8)).item()
    root_mse = delta.pow(2).mean().sqrt().item()
    largest = abs_delta.max().item()

    return dict(cos_sim=similarity, l1=relative_l1, rmse=root_mse, max_diff=largest)


def _column_buffers(batch, heads, num_q_blocks):
    """Return zero-filled ``(column_count, column_index)`` int32 tensors on DEVICE.

    Shapes are ``(batch, heads, num_q_blocks)`` and
    ``(batch, heads, num_q_blocks, 1)``. The all-zero count presumably
    disables the per-column path of ``sparse_attn_func`` — TODO confirm.
    """
    count_shape = (batch, heads, num_q_blocks)
    index_shape = (*count_shape, 1)
    tensor_opts = dict(dtype=torch.int32, device=DEVICE)
    return torch.zeros(count_shape, **tensor_opts), torch.zeros(index_shape, **tensor_opts)


def test_discrete_block_selection_fixed_matches_manual():
    """Check sparse_attn_func against dense attention over two fixed KV blocks.

    Every query block of every (batch, head) is given the same two key/value
    blocks, starting at 0 and ``2 * BLOCK_K``. The sparse kernel should
    therefore match ordinary softmax attention computed over just those keys.
    The reference is built for all batch elements and heads, and each batch
    element is checked separately so a failure in one cannot be masked by
    the others.
    """
    torch.manual_seed(42)

    batch, seqlen, heads, headdim = 2, 256, 2, 128
    q = torch.randn(batch, seqlen, heads, headdim, device=DEVICE, dtype=DTYPE)
    k = torch.randn(batch, seqlen, heads, headdim, device=DEVICE, dtype=DTYPE)
    v = torch.randn(batch, seqlen, heads, headdim, device=DEVICE, dtype=DTYPE)

    num_q_blocks = (seqlen + BLOCK_K - 1) // BLOCK_K
    # Start positions of the KV blocks every query block attends to.
    selected_starts = (0, 2 * BLOCK_K)
    block_count = torch.full(
        (batch, heads, num_q_blocks), len(selected_starts), dtype=torch.int32, device=DEVICE
    )
    # Unused slots keep the INVALID_OFFSET sentinel.
    block_offset = torch.full(
        (batch, heads, num_q_blocks, num_q_blocks), INVALID_OFFSET, dtype=torch.int32, device=DEVICE
    )
    for slot, start in enumerate(selected_starts):
        block_offset[:, :, :, slot] = start
    column_count, column_index = _column_buffers(batch, heads, num_q_blocks)

    out_discrete = sparse_attn_func(
        q,
        k,
        v,
        block_count=block_count,
        block_offset=block_offset,
        column_count=column_count,
        column_index=column_index,
        causal=False,
    )
    # The comparison below indexes out_discrete with q's (batch, seq, heads, dim) layout.
    assert out_discrete.shape == q.shape

    # Dense reference over only the selected keys, for all batches and heads at once:
    # (batch, seqlen, heads, headdim) -> (batch, heads, seqlen, headdim).
    q_ref = q.float().transpose(1, 2)
    k_ref = torch.cat([k[:, s : s + BLOCK_K] for s in selected_starts], dim=1).float().transpose(1, 2)
    v_ref = torch.cat([v[:, s : s + BLOCK_K] for s in selected_starts], dim=1).float().transpose(1, 2)

    scores = torch.matmul(q_ref, k_ref.transpose(-2, -1)) * (headdim ** -0.5)
    attn = torch.softmax(scores, dim=-1)
    # Round the reference to DTYPE so both sides carry the same output precision.
    out_manual = torch.matmul(attn, v_ref).transpose(1, 2).to(DTYPE)

    for b in range(batch):
        metrics = precision_metric(out_discrete[b], out_manual[b])
        assert metrics["cos_sim"] > 0.99, f"batch {b}: {metrics}"

def test_sparse_linear_attention_check(dtype=torch.bfloat16, headdim=128):
    """Smoke-test sparse_attn_with_sla on a long random sequence.

    Args:
        dtype: Precision under test.
            - ``torch.bfloat16`` and ``torch.float16`` are used directly as the
              q/k/v dtype and select the matching kernel flag.
            - ``torch.float8_e4m3fn`` sets ``use_fp8=True`` while the inputs
              stay in the module default ``DTYPE``, because ``torch.randn``
              does not sample float8 dtypes.
        headdim: Per-head feature dimension.

    Only checks that the call succeeds and returns a finite tensor shaped like
    ``q``. It does not compare against dense attention.
    """
    torch.manual_seed(42)

    use_bf16 = dtype == torch.bfloat16
    use_fp8 = dtype == torch.float8_e4m3fn
    # Previously q/k/v were always created in DTYPE, so --dtype fp16 still fed
    # bf16 tensors while passing use_bf16=False. Honour the requested 16-bit
    # dtype; for fp8 keep DTYPE (NOTE(review): confirm the kernel's expected
    # fp8 input dtype).
    input_dtype = dtype if dtype in (torch.bfloat16, torch.float16) else DTYPE

    batch, seqlen, heads = 1, 19440, 5
    q = torch.randn(batch, seqlen, heads, headdim, device=DEVICE, dtype=input_dtype)
    k = torch.randn(batch, seqlen, heads, headdim, device=DEVICE, dtype=input_dtype)
    v = torch.randn(batch, seqlen, heads, headdim, device=DEVICE, dtype=input_dtype)

    sparsity = 0.3
    # The API takes the kept fraction (topk) rather than the sparsity.
    topk = 1.0 - sparsity

    out_sla = sparse_attn_with_sla(
        q,
        k,
        v,
        topk=topk,
        feature_map="softmax",
        use_bf16=use_bf16,
        use_fp8=use_fp8,
        return_sparsity=False,
    )

    assert out_sla.shape == q.shape
    assert torch.isfinite(out_sla.float()).all()



if __name__ == "__main__":
    cli = argparse.ArgumentParser()
    cli.add_argument(
        "--dtype",
        type=str,
        choices=["bf16", "fp16", "fp8"],
        default="bf16",
        help="Data type to use for testing (bf16, fp16 or fp8)",
    )
    cli.add_argument(
        "--dim",
        type=int,
        choices=[64, 128],
        default=128,
        help="Dim to use for testing (64, 128)",
    )
    # NOTE(review): --prof is parsed but nothing below reads it yet.
    cli.add_argument("--prof", default=False, action="store_true", help="prof or not")

    cli_args = cli.parse_args()

    # Translate the CLI spelling into the torch dtype passed to the test.
    dtype_by_name = {
        "bf16": torch.bfloat16,
        "fp16": torch.float16,
        "fp8": torch.float8_e4m3fn,
    }
    torch_dtype = dtype_by_name[cli_args.dtype]
    print("dtype:", torch_dtype)
    test_sparse_linear_attention_check(torch_dtype, cli_args.dim)
    print("Test passed.")