# test_flash_mla_dense_decoding.py
1
2
3
4
import argparse
import math
import random
import dataclasses
5
from typing import Tuple
6
7
8

import torch

9
import kernelkit as kk
10
11
12
13
import flash_mla

@dataclasses.dataclass
class TestParam:
    """
    Configuration of a single dense MLA decoding test case.

    The first five fields are required; all others have defaults.
    """
    b: int    # Batch size
    s_q: int  # Number of queries for one request
    s_k: int  # Seq len, or mean seq len if is_varlen == True
    is_varlen: bool                   # Draw a random KV length per request instead of using s_k everywhere
    is_causal: bool                   # Whether causal masking is applied
    test_performance: bool = True     # Benchmark the kernel after the correctness check
    have_zero_seqlen_k: bool = False  # Randomly force the KV length of some requests to 0
    block_size: int = 64              # Number of token slots per KV cache block
    h_q: int = 128    # Number of q heads
    h_kv: int = 1     # Number of kv heads
    d: int = 576      # Q/K head dim (= dv + RoPE dim)
    dv: int = 512     # V head dim
    seed: int = 0     # Seed used for the Python and PyTorch random generators


29
def generate_test_data(t: TestParam) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
    """
    Generate test data from a given configuration

    Return: (cache_seqlens, q, block_table, blocked_k)
      - cache_seqlens: [b], int32, moved to CUDA
      - q:             [b, s_q, h_q, d]
      - block_table:   [b, max_seqlen_pad // block_size], int32
      - blocked_k:     [b * max_seqlen_pad // block_size, block_size, h_kv, d]
    q, block_table and blocked_k are created on the current default device;
    q and blocked_k use the current default dtype.

    Pay attention: This function changes the random seed (Python `random`,
    torch CPU and CUDA generators) and enables cudnn deterministic mode.
    """
    random.seed(t.seed)
    torch.manual_seed(t.seed)
    torch.cuda.manual_seed(t.seed)
    torch.backends.cudnn.deterministic = True

    assert t.h_q % t.h_kv == 0

    # Per-request KV lengths: constant s_k, or (varlen) drawn from a normal
    # distribution N(s_k, s_k / 2) and clamped from below by s_q.
    cache_seqlens_cpu = torch.full((t.b,), t.s_k, dtype=torch.int32, device='cpu')
    if t.is_varlen:
        for i in range(t.b):
            cache_seqlens_cpu[i] = max(random.normalvariate(t.s_k, t.s_k / 2), t.s_q)

    if t.have_zero_seqlen_k:
        # Zero the length of every request whose standard-normal sample is positive
        zeros_mask = torch.randn(t.b, dtype=torch.float32, device='cpu') > 0
        cache_seqlens_cpu[zeros_mask] = 0

    max_seqlen = int(cache_seqlens_cpu.max().item())
    # Padded to a multiple of 256 (assuming kk.cdiv is ceiling division)
    max_seqlen_pad = kk.cdiv(max_seqlen, 256) * 256
    cache_seqlens = cache_seqlens_cpu.cuda()

    # Values are scaled down by 10 and clamped into [-1, 1]
    q = torch.randn(t.b, t.s_q, t.h_q, t.d) / 10
    q.clamp_(min=-1.0, max=1.0)

    # Every request owns max_seqlen_pad // block_size blocks; the physical block
    # ids are shuffled across the whole table.
    block_table = torch.arange(t.b * max_seqlen_pad // t.block_size, dtype=torch.int32).view(t.b, max_seqlen_pad // t.block_size)
    block_table = block_table.view(-1)[torch.randperm(block_table.numel())].view(t.b, -1)
    blocked_k = torch.randn(block_table.numel(), t.block_size, t.h_kv, t.d) / 10
    blocked_k.clamp_(min=-1.0, max=1.0)

    for i in range(t.b):
        cur_len = int(cache_seqlens_cpu[i].item())
        cur_num_blocks = kk.cdiv(cur_len, t.block_size)
        # Fill every slot beyond the valid KV length with NaN: first the blocks
        # that are not used at all, then the tail of the last, partially used block.
        # NOTE(review): presumably this makes any read past cache_seqlens visible
        # in the kernel output - confirm.
        blocked_k[block_table[i][cur_num_blocks:]] = float("nan")
        if cur_len % t.block_size != 0:
            blocked_k[block_table[i][cur_num_blocks - 1]][cur_len % t.block_size:] = float("nan")
        # Must happen after the NaN fill above, which still needs the real block ids.
        # NOTE(review): 2147480000 looks like an intentionally invalid block id
        # (close to the int32 maximum) - confirm.
        block_table[i][cur_num_blocks:] = 2147480000
    return cache_seqlens, q, block_table, blocked_k
71
72
73
74
75
76
77
78
79
80
81
82
83


def reference_torch(
    cache_seqlens: torch.Tensor,    # [batch_size]
    block_table: torch.Tensor,      # [batch_size, ?]
    q: torch.Tensor,    # [batch_size, s_q, h_q, d]
    blocked_k: torch.Tensor,    # [?, block_size, h_kv, d]
    dv: int,
    is_causal: bool,
) -> Tuple[torch.Tensor, torch.Tensor]:
    """
    A reference implementation in PyTorch

    For every request the valid KV tokens are gathered from the paged cache
    through block_table, and attention is computed in float32. The first dv
    channels of each KV token are used as V.

    Return: (out_ref, lse_ref)
      - out_ref: [batch_size, s_q, h_q, dv], cast to q.dtype
      - lse_ref: [batch_size, h_q, s_q], float32 log-sum-exp of the scaled scores
    For queries that cannot attend to any key, the output is 0 and lse is +inf.
    """

    def scaled_dot_product_attention(
        query: torch.Tensor,    # [h_q, s_q, d]
        kv: torch.Tensor,      # [h_kv, s_k, d]
        dv: int,
        is_causal: bool,
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        h_q = query.size(0)
        h_kv = kv.size(0)
        s_q = query.shape[-2]
        s_k = kv.shape[-2]
        query = query.float()
        kv = kv.float()
        if h_kv != 1:
            # Give every q head its own copy of the kv head it maps to
            kv = kv.repeat_interleave(h_q // h_kv, dim=0)
        # NaN is the only value that differs from itself: zero out any NaN entries
        kv[kv != kv] = 0.0

        attn_weight = query @ kv.transpose(-2, -1)  # [h_q, s_q, s_k]
        if is_causal and s_q > 1:
            # Bottom-right aligned causal mask: the last query sees all s_k keys
            mask = torch.ones(s_q, s_k, dtype=torch.bool, device=query.device).tril(diagonal=s_k - s_q)
            attn_bias = torch.zeros(s_q, s_k, dtype=torch.float, device=query.device)
            attn_bias.masked_fill_(mask.logical_not(), float("-inf"))
            attn_weight += attn_bias
        attn_weight /= math.sqrt(query.size(-1))

        lse = attn_weight.logsumexp(dim=-1)  # [h_q, s_q]
        attn_weight = torch.softmax(attn_weight, dim=-1, dtype=torch.float32)
        output = attn_weight @ kv[..., :dv]    # [h_q, s_q, dv]
        # Correct for q tokens which has no attendable k
        lonely_q_mask = (lse == float("-inf"))
        output[lonely_q_mask.unsqueeze(-1).broadcast_to(h_q, s_q, dv)] = 0.0
        lse[lonely_q_mask] = float("+inf")

        return output, lse

    b, s_q, h_q, d = q.size()
    block_size = blocked_k.size(1)
    h_kv = blocked_k.size(2)
    cache_seqlens_cpu = cache_seqlens.cpu()
    out_ref = torch.empty(b, s_q, h_q, dv, dtype=torch.float32, device=q.device)
    lse_ref = torch.empty(b, h_q, s_q, dtype=torch.float32, device=q.device)
    for i in range(b):
        cur_len = int(cache_seqlens_cpu[i].item())
        cur_num_blocks = kk.cdiv(cur_len, block_size)
        # Gather the blocks of this request and flatten them into a token sequence,
        # dropping the unused tail of the last block.
        cur_block_indices = block_table[i][0: cur_num_blocks]
        cur_kv = blocked_k[cur_block_indices].view(-1, h_kv, d)[:cur_len, ...]
        cur_out, cur_lse = scaled_dot_product_attention(
            q[i].transpose(0, 1),
            cur_kv.transpose(0, 1),
            dv,
            is_causal
        )
        out_ref[i] = cur_out.transpose(0, 1)
        lse_ref[i] = cur_lse
    out_ref = out_ref.to(q.dtype)
    return out_ref, lse_ref
142

143
144
145
146
147
148
149
150

@torch.inference_mode()
def test_flash_mla(t: TestParam):
    """
    Run one test case: check the kernel against the PyTorch reference and,
    if t.test_performance is set, benchmark it and print the achieved
    throughput.

    Raises AssertionError if the kernel output or lse deviates from the
    reference beyond the tolerances below.
    """
    print('-------------------------------')
    print(f"Running on {t}...")

    # Generating test data
    torch.cuda.synchronize()
    cache_seqlens, q, block_table, blocked_k = generate_test_data(t)

    tile_scheduler_metadata, num_splits = flash_mla.get_mla_metadata()

    def run_flash_mla():
        return flash_mla.flash_mla_with_kvcache(
            q,
            blocked_k,
            block_table,
            cache_seqlens,
            t.dv,
            tile_scheduler_metadata,
            num_splits,
            causal=t.is_causal
        )

    out_ans, lse_ans = run_flash_mla()
    out_ref, lse_ref = reference_torch(cache_seqlens, block_table, q, blocked_k, t.dv, t.is_causal)
    is_correct = True
    is_correct &= kk.check_is_allclose("out", out_ans, out_ref, abs_tol=8e-4, rel_tol=2.01 / 128, cos_diff_tol=5e-6)
    is_correct &= kk.check_is_allclose("lse", lse_ans, lse_ref, abs_tol=1e-6, rel_tol=8.01 / 65536)
    assert is_correct, f"Correctness check failed for {t}"

    if t.test_performance:
        # NOTE(review): time_usage is treated as seconds below (it is printed
        # as milliseconds after multiplying by 1000) - confirm against kk.bench_kineto.
        time_usage = kk.bench_kineto(run_flash_mla, 10).get_kernel_time("flash_fwd_splitkv_mla")

        mean_attended_seqlens = cache_seqlens.float().mean().item()
        compute_volume_flop = t.b * t.h_q * t.s_q * sum([
            2 * t.d * mean_attended_seqlens,   # Q * K^T
            2 * mean_attended_seqlens * t.dv,  # attention * V
        ])
        # Use the element sizes of the actual tensors so the traffic estimate
        # follows whichever dtype the test was launched with.
        q_elem_size = q.element_size()
        kv_token_size = t.d * blocked_k.element_size()
        memory_volume_B = t.b * sum([
            t.s_q * t.h_q * (t.d * q_elem_size),    # Q
            mean_attended_seqlens * t.h_kv * kv_token_size,    # K/V
            t.s_q * t.h_q * (t.dv * q_elem_size),   # Output
        ])
        achieved_tflops = compute_volume_flop / time_usage / 1e12
        achieved_gBps = memory_volume_B / time_usage / 1e9

        print(f"{time_usage * 1000:.3f} ms, {achieved_tflops:.0f} TFLOPS, {achieved_gBps:.0f} GB/s")
193
194
195
196
197
198
199
200


def main(torch_dtype):
    """
    Run all correctness, corner-case and performance test cases.

    torch_dtype becomes the default torch dtype, and "cuda:0" becomes the
    default and current device, before any test data is generated.

    Raises AssertionError if the device's major compute capability is not 9,
    or if any test case fails its correctness check.
    """
    device = torch.device("cuda:0")
    torch.set_default_dtype(torch_dtype)
    torch.set_default_device(device)
    torch.cuda.set_device(device)

    cc_major, _ = torch.cuda.get_device_capability()
    assert cc_major == 9, "Dense MLA decoding is only supported on sm90 (Hopper) currently."

    # Head-count combinations shared by the correctness and corner cases;
    # only pairs where h_kv divides h_q are used.
    h_q_choices = [1, 3, 9, 63, 64, 126, 128]
    h_kv_choices = [1, 2, 3, 8]

    correctness_cases = [
        TestParam(b, s_q, s_k, is_varlen, is_causal, test_performance=False, have_zero_seqlen_k=False, block_size=64, h_q=h_q, h_kv=h_kv)
        for b in [1, 2, 6, 64]
        for s_q in [1, 2, 4]
        for s_k in [20, 140, 4096]
        for h_q in h_q_choices
        for h_kv in h_kv_choices
        for is_varlen in [False, True]
        for is_causal in [False, True]
        if h_q % h_kv == 0
    ]

    corner_cases = [
        # Cases where some kv cache have zero length
        TestParam(128, 2, 4096, is_varlen=True, is_causal=is_causal, test_performance=False, have_zero_seqlen_k=True, h_q=h_q, h_kv=h_kv)
        for h_q in h_q_choices
        for h_kv in h_kv_choices
        for is_causal in [False, True]
        if h_q % h_kv == 0
    ]

    performance_cases = [
        TestParam(128, s_q, s_k, is_varlen=True, is_causal=is_causal, h_q=h_q, test_performance=True)
        for is_causal in [False, True]
        for s_q in [1, 2]
        for h_q in [16, 64, 128]
        for s_k in [4096, 8192, 16384, 32768]
    ]

    testcases = correctness_cases + corner_cases + performance_cases

    for testcase in testcases:
        test_flash_mla(testcase)


if __name__ == "__main__":
    # Command-line entry point: pick the tensor dtype and run every test case.
    parser = argparse.ArgumentParser()
    parser.add_argument(
        "--dtype",
        type=str,
        choices=["bf16", "fp16"],
        default="bf16",
        help="Data type to use for testing (bf16 or fp16)",
    )

    args = parser.parse_args()

    # argparse restricts --dtype to "bf16" or "fp16", so anything that is not
    # "fp16" is bf16.
    torch_dtype = torch.float16 if args.dtype == "fp16" else torch.bfloat16

    main(torch_dtype)