# test_flash_mla_qkvfp8_with_cat.py
# Correctness check and benchmark for flash_mla_with_kvcache_fp8_with_cat
# against a plain PyTorch reference attention.
import argparse
import math
import random

import torch
import triton

# Kernel entry points under test.
from flash_mla import flash_mla_with_kvcache_fp8_with_cat, get_mla_decoding_metadata_dense_fp8

# Print tensors with 4 digits of precision and without scientific notation.
torch.set_printoptions(precision=4, profile="default", sci_mode=False)

def scaled_dot_product_attention(query, key, value, h_q, h_kv, is_causal=False, k_scale=1.0):
    """Reference attention computed in float32.

    Args:
        query: Tensor whose leading dim is the query-head dim and whose last
            two dims are (s_q, head_dim).
        key: Tensor whose leading dim is the KV-head dim and whose last two
            dims are (s_k, head_dim).
        value: Tensor whose leading dim is the KV-head dim and whose last two
            dims are (s_k, value_dim).
        h_q: Number of query heads.
        h_kv: Number of KV heads; each KV head is repeated ``h_q // h_kv``
            times along dim 0 to line up with the query heads.
        is_causal: When True, query row ``i`` only attends to key columns
            ``j <= i + (s_k - s_q)`` (the queries are aligned to the end of
            the key sequence).
        k_scale: Multiplier applied to both ``key`` and ``value`` after they
            are converted to float32.

    Returns:
        ``(out, lse)`` where ``out`` is ``softmax(scores) @ value`` and
        ``lse`` is the log-sum-exp of the (masked) scores over the key dim.
    """
    query = query.float()
    key = key.float() * k_scale
    value = value.float() * k_scale

    group_size = h_q // h_kv
    key = key.repeat_interleave(group_size, dim=0)
    value = value.repeat_interleave(group_size, dim=0)

    attn_weight = query @ key.transpose(-2, -1) / math.sqrt(query.size(-1))
    if is_causal:
        s_q = query.shape[-2]
        s_k = key.shape[-2]
        # Build the mask on the query's device so the function does not depend
        # on the process-wide default device.
        visible = torch.ones(
            s_q, s_k, dtype=torch.bool, device=query.device
        ).tril(diagonal=s_k - s_q)
        attn_bias = torch.zeros(s_q, s_k, dtype=query.dtype, device=query.device)
        attn_bias.masked_fill_(visible.logical_not(), float("-inf"))
        attn_weight += attn_bias

    lse = attn_weight.logsumexp(dim=-1)
    attn_weight = torch.softmax(attn_weight, dim=-1, dtype=torch.float32)
    return attn_weight @ value, lse


def cal_diff(x: torch.Tensor, y: torch.Tensor, name: str, use_fp8: bool = False) -> None:
    """Print error metrics between ``x`` and ``y`` and assert they are close.

    The metrics are computed in float64:

    * ``cos_diff = 1 - 2 * sum(x * y) / sum(x * x + y * y)`` (0 when equal),
    * ``RMSE``: root of the mean squared difference,
    * ``amax_diff``: largest absolute element-wise difference.

    Only ``cos_diff`` is asserted on. The limit is ``1e-3`` when ``use_fp8``
    is set, otherwise ``1e-4`` if ``x`` is bfloat16 and ``1e-5`` for any
    other dtype.

    Args:
        x: Tensor under test; its dtype selects the non-FP8 tolerance.
        y: Reference tensor.
        name: Label used in the printed line and the assertion message.
        use_fp8: Select the looser FP8 tolerance.

    Raises:
        AssertionError: If ``cos_diff`` is not below the selected limit.
    """
    torch_dtype = x.dtype
    x, y = x.double(), y.double()
    RMSE = ((x - y) * (x - y)).mean().sqrt().item()
    denominator = (x * x + y * y).sum().item()
    if denominator == 0:
        # Both tensors are all zeros, hence identical. The clamped formula
        # below would report cos_diff == 1 and fail the check.
        cos_diff = 0.0
    else:
        cos_diff = 1 - 2 * (x * y).sum().item() / max(denominator, 1e-12)
    amax_diff = (x - y).abs().max().item()
    print(f"{name}: {cos_diff=}, {RMSE=}, {amax_diff=}")

    if use_fp8:
        limit = 1e-3
    elif torch_dtype == torch.bfloat16:
        limit = 1e-4
    else:
        limit = 1e-5
    assert cos_diff < limit, f"{name}: cos_diff {cos_diff} is not below {limit}"


@torch.inference_mode()
def test_flash_mla(b, s_q, mean_sk, h_q, h_kv, d, dv, causal, varlen, is_prof=False, torch_dtype=torch.float16, is_q_bf16=False):
    """Check ``flash_mla_with_kvcache_fp8_with_cat`` against a PyTorch reference.

    Random ``q`` and a paged KV cache are generated, the kernel and the
    reference are both run, and their outputs and log-sum-exp values are
    compared with ``cal_diff``. Unless ``is_prof`` is set, the kernel is then
    timed with ``triton.testing.do_bench`` and throughput is printed.

    Tensors are created without an explicit device or float dtype, so the
    caller is expected to have configured the torch defaults beforehand
    (presumably a CUDA device, since ``torch.cuda.synchronize`` is called).

    Args:
        b: Batch size (number of sequences).
        s_q: Query tokens per sequence.
        mean_sk: KV length of every sequence, or the mean of the sampled
            lengths when ``varlen`` is True.
        h_q: Number of query heads.
        h_kv: Number of KV heads.
        d: Head dim of ``q`` and of the cached keys.
        dv: Value dim; values are the first ``dv`` channels of the keys.
        causal: Forwarded to the kernel and to the reference.
        varlen: Sample a different KV length per sequence from a normal
            distribution with std ``mean_sk / 2``, floored at ``s_q``.
        is_prof: When True, return right after the correctness check and
            skip the benchmark.
        torch_dtype: ``torch.float8_e4m3fn`` switches ``q`` and the cache to
            FP8 with unit descale factors; the value is also used for the
            bandwidth estimate.
        is_q_bf16: Cast ``q_nope`` / ``q_pe`` to bfloat16 before the kernel
            call.

    Raises:
        AssertionError: From ``cal_diff`` when kernel and reference disagree.
    """
    print(
        f"{b=}, {s_q=}, {mean_sk=}, {h_q=}, {h_kv=}, {d=}, {dv=}, {causal=}, {varlen=}, {torch_dtype=}"
    )

    use_fp8 = torch_dtype == torch.float8_e4m3fn

    cache_seqlens = torch.full((b,), mean_sk, dtype=torch.int32)
    if varlen:
        for i in range(b):
            cache_seqlens[i] = max(random.normalvariate(mean_sk, mean_sk / 2), s_q)
    total_seqlens = cache_seqlens.sum().item()
    mean_seqlens = cache_seqlens.float().mean().int().item()
    max_seqlen = cache_seqlens.max().item()
    # Every sequence owns max_seqlen_pad cache slots (a multiple of 256).
    max_seqlen_pad = triton.cdiv(max_seqlen, 256) * 256
    print(f"{total_seqlens=}, {mean_seqlens=}, {max_seqlen=}, {max_seqlen_pad=}")

    q = torch.randn(b, s_q, h_q, d)

    block_size = 64
    blocks_per_seq = max_seqlen_pad // block_size
    # Sequence i uses the contiguous block ids [i * blocks_per_seq, (i + 1) * blocks_per_seq).
    block_table = torch.arange(
        b * max_seqlen_pad // block_size, dtype=torch.int32
    ).view(b, blocks_per_seq)

    blocked_k = torch.randn(block_table.numel(), block_size, h_kv, d)
    # Poison the slots past each sequence's length with NaN so that any read
    # of padding shows up in the comparison.
    for i in range(b):
        blocked_k.view(b, max_seqlen_pad, h_kv, d)[i, cache_seqlens[i].item():] = (
            float("nan")
        )
    blocked_v = blocked_k[..., :dv]

    tile_scheduler_metadata, num_splits = get_mla_decoding_metadata_dense_fp8(
        cache_seqlens, s_q * h_q // h_kv, h_kv
    )

    # Remember the pre-quantization dtype; the reference converts back to it.
    init_dtype = q.dtype
    descale_q = descale_k = None
    if use_fp8:
        fp8_dtype = torch.float8_e4m3fn
        descale_q = torch.ones(1, dtype=torch.float32)
        descale_k = torch.ones(1, dtype=torch.float32)
        q = q.to(fp8_dtype)
        blocked_k = blocked_k.to(fp8_dtype)
        blocked_v = blocked_k[..., :dv]

    # NOTE(review): the split point is hard-coded to 512, which equals dv for
    # the d=576 / dv=512 configurations used in this file; presumably it
    # should follow dv. Confirm against the kernel's contract.
    q_nope = q[:, :, :, :512].contiguous()
    q_pe = q[:, :, :, 512:].contiguous()
    if is_q_bf16:
        q_nope = q_nope.to(torch.bfloat16).contiguous()
        q_pe = q_pe.to(torch.bfloat16).contiguous()

    def flash_mla():
        return flash_mla_with_kvcache_fp8_with_cat(
            q_nope,
            q_pe,
            blocked_k,
            block_table,
            cache_seqlens,
            dv,
            tile_scheduler_metadata,
            num_splits,
            causal=causal,
            descale_q=descale_q,
            descale_k=descale_k,
        )

    def ref_mla():
        # The reference sees the same quantized values as the kernel: FP8
        # inputs are dequantized with their descale factors first.
        q_ = (q.to(torch.float) * descale_q).to(init_dtype) if use_fp8 else q
        blocked_k_ = (blocked_k.to(torch.float) * descale_k).to(init_dtype) if use_fp8 else blocked_k
        blocked_v_ = (blocked_v.to(torch.float) * descale_k).to(init_dtype) if use_fp8 else blocked_v
        out = torch.empty(b, s_q, h_q, dv, dtype=torch.float32)
        lse = torch.empty(b, h_q, s_q, dtype=torch.float32)
        for i in range(b):
            begin = i * max_seqlen_pad
            end = begin + cache_seqlens[i]
            O, LSE = scaled_dot_product_attention(
                q_[i].transpose(0, 1),
                blocked_k_.view(-1, h_kv, d)[begin:end].transpose(0, 1),
                blocked_v_.view(-1, h_kv, dv)[begin:end].transpose(0, 1),
                h_q=h_q,
                h_kv=h_kv,
                is_causal=causal,
            )
            out[i] = O.transpose(0, 1)
            lse[i] = LSE
        return out, lse

    torch.cuda.synchronize()
    out_flash, lse_flash = flash_mla()
    torch.cuda.synchronize()
    out_torch, lse_torch = ref_mla()

    print("out max_diff ", (out_flash - out_torch).abs().max())
    print("lse max_diff ", (lse_flash - lse_torch).abs().max())

    cal_diff(lse_flash, lse_torch, "lse")
    cal_diff(out_flash, out_torch, "out", use_fp8)

    if is_prof:
        return

    # do_bench reports milliseconds, hence the 10**9 / 10**6 divisors below.
    t = triton.testing.do_bench(flash_mla)
    FLOPS = s_q * total_seqlens * h_q * (d + dv) * 2
    input_bytes = (total_seqlens * h_kv * d + b * s_q * h_q * d) * (torch.finfo(torch_dtype).bits // 8)
    output_bytes = (b * s_q * h_q * dv) * (torch.finfo(init_dtype).bits // 8)
    total_bytes = input_bytes + output_bytes
    print(
        f"{t:.3f} ms, {FLOPS / 10 ** 9 / t:.0f} TFLOPS, {total_bytes / 10 ** 6 / t:.0f} GB/s"
    )


def main(torch_dtype, is_prof=False):
    """Run the correctness sweep followed by the benchmark sweep on ``cuda:0``.

    The torch default device is set to ``cuda:0`` and the default dtype to
    ``torch_dtype`` (or bfloat16 when ``torch_dtype`` is
    ``torch.float8_e4m3fn``). Both the torch and the ``random`` RNG are seeded
    with 0. Every configuration is run twice: once with ``q`` left as is and
    once with ``q`` cast to bfloat16 (``is_q_bf16=True``).

    Args:
        torch_dtype: Dtype forwarded to ``test_flash_mla``.
        is_prof: NOTE(review): accepted but currently unused. The first sweep
            always skips the benchmark and the second always runs it; confirm
            whether this flag was meant to control that.
    """
    from itertools import product

    device = torch.device("cuda:0")
    init_dtype = torch.bfloat16 if torch_dtype == torch.float8_e4m3fn else torch_dtype
    torch.set_default_dtype(init_dtype)
    torch.set_default_device(device)
    torch.cuda.set_device(device)
    torch.manual_seed(0)
    random.seed(0)

    h_kv = 1
    d, dv = 576, 512
    causal = True

    # Stress test: correctness only (is_prof=True skips the benchmark).
    # product() iterates in the same order as the equivalent nested loops
    # (batch outermost, varlen innermost), so RNG consumption is unchanged.
    stress_grid = product(
        [3, 6, 9, 12, 15, 18, 21, 24, 40, 41, 79, 80],  # b
        [111, 112, 123, 1234, 432, 4325, 4000, 8192, 12345, 45321],  # mean KV length
        [16, 64],  # h_q
        [1, 2, 3],  # s_q (query tokens per sequence)
        [False, True],  # varlen
    )
    for b, s, h_q, s_q, varlen in stress_grid:
        for is_q_bf16 in (False, True):
            test_flash_mla(b, s_q, s, h_q, h_kv, d, dv, causal, varlen, True, torch_dtype, is_q_bf16)

    # Benchmark sweep: correctness check plus timing (is_prof=False).
    bench_grid = product(
        [3, 6, 9, 12, 15, 18, 21, 24, 32, 64, 128, 256],  # b
        [4000],  # mean KV length
        [16, 64],  # h_q
        [1],  # s_q
        [False],  # varlen
    )
    for b, s, h_q, s_q, varlen in bench_grid:
        for is_q_bf16 in (False, True):
            test_flash_mla(b, s_q, s, h_q, h_kv, d, dv, causal, varlen, False, torch_dtype, is_q_bf16)


if __name__ == "__main__":
    # Command-line entry point: pick the test dtype and run every sweep.
    parser = argparse.ArgumentParser()
    parser.add_argument(
        "--dtype",
        type=str,
        choices=["bf16", "fp16", "e4m3"],
        # A bare invocation previously ended up on the FP8 path (the old
        # "bf16" default had no matching branch and fell through to FP8).
        # The default now says so, and "bf16" selects bfloat16 as requested.
        default="e4m3",
        help="Data type to use for testing (bf16/fp16/e4m3)",
    )
    parser.add_argument('--prof', default=False, action='store_true', help='prof or not')

    args = parser.parse_args()

    dtype_by_name = {
        "bf16": torch.bfloat16,
        "fp16": torch.float16,
        "e4m3": torch.float8_e4m3fn,
    }
    main(dtype_by_name[args.dtype], args.prof)