"tests_mpi/test_internode.py" did not exist on "e57e9270aa73eddb733e6cc8c2a14ab6a378626c"
test_flash_mla_sparse_prefill.py 6.08 KB
Newer Older
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
import time
import sys

import torch
import kernelkit as kk

from lib import TestParam
import lib
import ref

_counter = kk.Counter()

@torch.inference_mode()
def run_test(p: TestParam) -> bool:
    """Run a single sparse-prefill test case.

    Steps:
      1. If ``p.seed`` is -1, replace it in place with the next value from the
         module-level ``_counter``.
      2. Generate inputs with ``lib.generate_testcase`` and run the kernel once
         through ``lib.run_flash_mla_sparse_fwd``.
      3. If ``p.num_runs > 0``, time the ``"sparse_attn_fwd"`` kernel with
         ``kk.bench_kineto`` and print latency, TFLOP/s and TB/s.
      4. If ``p.check_correctness`` is set, compare ``out``, ``max_logits`` and
         ``lse`` against ``ref.ref_sparse_attn_fwd``.

    Returns:
        ``True`` when correctness checking is disabled. Otherwise, the
        ``&``-accumulated results of the three ``kk.check_is_allclose`` calls.
    """
    if p.seed == -1:
        # ``_counter`` is only read here, so no ``global`` declaration is needed.
        p.seed = _counter.next()

    print("================")
    print(f"Running on {p}")
    torch.cuda.empty_cache()

    case = lib.generate_testcase(p)
    torch.cuda.synchronize()

    def launch_prefill():
        return lib.run_flash_mla_sparse_fwd(p, case, False)

    out, max_logits, lse = launch_prefill()
    torch.cuda.synchronize()

    if p.num_runs > 0:
        cost = lib.count_flop_and_mem_vol(p, case)
        seconds = kk.bench_kineto(launch_prefill, num_tests=p.num_runs).get_kernel_time("sparse_attn_fwd")
        tflops = cost.fwd_flop / seconds / 1e12
        tbps = cost.fwd_mem_vol / seconds / 1e12
        print(f"Prefill:  {seconds*1e6:4.0f} us, {tflops:6.1f} TFlops, {tbps:4.2f} TBps")

    if not p.check_correctness:
        return True

    torch.cuda.synchronize()
    _, ref_out_fp32, ref_max_logits, ref_lse = ref.ref_sparse_attn_fwd(p, case)
    # Reference LSE entries equal to -inf are rewritten to +inf before comparison.
    # NOTE(review): presumably this matches the kernel's convention for rows with
    # no valid keys — confirm against the kernel.
    ref_lse[ref_lse == float("-inf")] = float("+inf")
    torch.cuda.synchronize()

    # All three checks always run (no short-circuit), so every mismatch is reported.
    passed = True
    passed &= kk.check_is_allclose("out", out.float(), ref_out_fp32, abs_tol=8e-4, rel_tol=3.01/128, cos_diff_tol=7e-6)
    passed &= kk.check_is_allclose("max_logits", max_logits, ref_max_logits, abs_tol=1e-6, rel_tol=2.01/65536)
    passed &= kk.check_is_allclose("lse", lse, ref_lse, abs_tol=1e-6, rel_tol=2.01/65536)
    return passed

zhanghj2's avatar
zhanghj2 committed
54
55
56
def get_gcn_arch_name() -> str:
    """Return the ``gcnArchName`` of the ``"cuda"`` device, truncated at the first ``:``.

    Everything from the first ``:`` onward is dropped, so a value like
    ``"gfxNNN:feature+"`` yields ``"gfxNNN"``. A value with no ``:`` is
    returned unchanged.
    """
    props = torch.cuda.get_device_properties("cuda")
    arch_name, _sep, _features = props.gcnArchName.partition(':')
    return arch_name
57
58

if __name__ == '__main__':
    # Script entry point: build every test case, run them in order through
    # run_test, and exit with status 1 if any case reports a failure.

    # The whole suite is skipped on gfx928.
    if get_gcn_arch_name() == "gfx928":
        print("[WARNING] gfx928 architecture is not supported.")
        # sys.exit instead of the site-provided ``exit`` builtin, which is not
        # available when Python runs with ``-S`` or in embedded interpreters.
        sys.exit(0)

    device = torch.device("cuda:0")
    torch.set_default_dtype(torch.bfloat16)
    torch.set_default_device(device)
    torch.cuda.set_device(device)
    torch.set_float32_matmul_precision('high')

    # Correctness-only cases: num_runs=0 skips benchmarking in run_test.
    correctness_cases = [
        # Regular shapes
        TestParam(s_q, s_kv, topk, h_q=h_q, num_runs=0, d_qk=d_qk)
        for d_qk in [512, 576]
        for h_q in [16, 128, 64]
        for s_kv, topk in [
            # Regular shapes
            (128, 128),
            (256, 256),
            (512, 512),

            # Irregular shapes
            (592, 128),
            (1840, 256),
            (1592, 384),
            (1521, 512),
            (3000, 2048),

            # Irregular shapes with OOB TopK (s_kv < topk)
            (95, 128),
            (153, 256),
            (114, 384),
        ]
        for s_q in [1, 62, 213]
    ]

    # Correctness cases that toggle the attention-sink / topk-length features.
    correctness_cases_with_features = [
        TestParam(s_q, s_kv, topk, h_q=h_q, num_runs=0, have_attn_sink=have_attn_sink, have_topk_length=have_topk_length, d_qk=d_qk)
        for d_qk in [512, 576]
        for h_q in [16, 128, 64]
        for s_kv, topk in [
            (592, 128),
            (1840, 256),
            (1592, 384),
            (1521, 512),

            (95, 128),
            (153, 256),
            (114, 384),
        ]
        for s_q in [62, 213]
        # NOTE(review): ``have_sink_lse`` is never passed to TestParam, so this
        # loop only emits every configuration twice (if seeds are auto-assigned,
        # each copy gets a different one). Forward it if TestParam supports it,
        # or drop the loop.
        for have_sink_lse in [False, True]
        for have_attn_sink in [False, True]
        for have_topk_length in [False, True]
    ]

    corner_cases = [
        # Every topk index is invalid.
        TestParam(s_q, s_kv, topk, h_q=h_q, is_all_indices_invalid=True, num_runs=0, have_attn_sink=True, have_topk_length=True, d_qk=d_qk)
        for d_qk in [512, 576]
        for h_q in [16, 128, 64]
        for s_q, s_kv, topk in [
            (1, 128, 128),
            (1, 256, 256),
            (1234, 4321, 4096),
            (4096, 2048, 2048)
        ]
    ] + [
        # In these cases, some blocks may not have any valid topk indices
        TestParam(s_q, s_kv, topk, h_q=h_q, is_all_indices_invalid=False, num_runs=0, have_attn_sink=True, have_topk_length=True, d_qk=d_qk)
        for d_qk in [512, 576]
        for h_q in [128, 64]
        for s_kv, topk in [
            (32, 2048),
            (64, 8192)
        ]
        for s_q in [1, 1024]
    ] + [
        # In this testcase, s_q is really large, so we cannot put it on the second dimension of grid shape.
        # Correctness checking is disabled; the case only has to run.
        TestParam(70000, 256, 256, h_q=h_q, check_correctness=False, num_runs=0, have_attn_sink=True, have_topk_length=True, d_qk=d_qk)
        for d_qk in [512, 576]
        for h_q in [128, 64]
    ]

    # Each template is (d_qk, h_q, topk, s_kv_list).
    performance_case_templates = [
        # V3.2
        (576, 128, 2048, [8192, 32768, 65536, 98304, 131072]),
        (576, 64, 2048, [8192, 32768, 65536, 98304, 131072]),
        # MODEL1 CONFIG1
        (512, 64, 512, [8192, 32768, 49152, 65536]),
        # MODEL1 CONFIG2
        (512, 128, 1024, [8192, 32768, 49152, 65536]),
        (512, 16, 1024, [8192, 32768, 49152, 65536]),
    ]

    # Benchmark cases: num_runs is left at TestParam's default (presumably > 0,
    # so run_test benchmarks them — confirm in lib.TestParam).
    performance_cases = [
        TestParam(s_q, s_kv, topk, h_q=h_q, d_qk=d_qk, have_attn_sink=have_attn_sink, have_topk_length=have_topk_length)
        for (d_qk, h_q, topk, s_kv_list) in performance_case_templates
        for s_q in [4096]
        for have_attn_sink in [False, True]
        for have_topk_length in [False, True]
        for s_kv in s_kv_list
    ]

    testcases = correctness_cases + correctness_cases_with_features + corner_cases + performance_cases

    is_no_cooldown = lib.is_no_cooldown()
    failed_cases = []
    for idx, test in enumerate(testcases):
        # Sleep 0.3 s before every benchmarked case except the first, unless
        # lib.is_no_cooldown() says otherwise (presumably to let the GPU cool
        # down between timed runs). The position check replaces the previous
        # ``test != testcases[0]``, which depended on TestParam equality.
        if idx > 0 and test.num_runs > 0 and not is_no_cooldown:
            time.sleep(0.3)
        if not run_test(test):
            failed_cases.append(test)

    if failed_cases:
        print(f"\033[31m\033[1m{len(failed_cases)} / {len(testcases)} cases failed:\033[0m")
        for case in failed_cases:
            print(f"    {case}")
        sys.exit(1)
    else:
        print(f"\033[32m\033[1mAll {len(testcases)} cases passed!\033[0m")