# test_cutlass_mla.py
import pytest
import torch
import torch.nn.functional as F
from sgl_kernel import cutlass_mla_decode, cutlass_mla_get_workspace_size
from torch import Tensor

# Skip the whole module unless a CUDA device with compute capability >= 10.0
# is present. CUDA availability is checked first because
# torch.cuda.get_device_capability() raises (rather than returning) when no
# usable CUDA device exists, which would turn a skip into a collection error.
if not torch.cuda.is_available():
    pytest.skip(
        reason="Cutlass MLA requires a CUDA device.",
        allow_module_level=True,
    )
if torch.cuda.get_device_capability() < (10, 0):
    pytest.skip(
        reason="Cutlass MLA Requires compute capability of 10 or above.",
        allow_module_level=True,
    )


def ref_mla(
    out: Tensor,  # (bs, num_heads, v_head_dim)
    query: Tensor,  # (bs, num_heads, head_dim)
    kv_cache: Tensor,  # (num_blocks, block_size, head_dim)
    scale: float,
    block_tables: Tensor,  # (bs, max_num_blocks)
    seq_lens: Tensor,  # (bs,)
):
    bs, num_heads, v_head_dim = out.shape
    head_dim = query.shape[2]

    for i in range(bs):
        # gather and flatten KV-cache
        kv = kv_cache[block_tables[i]]  # (max_num_blocks, block_size, head_dim)
        kv = kv.view(1, -1, head_dim)[:, : seq_lens[i]]  # (1, seq_len, head_dim)
        v = kv[:, :, :v_head_dim]

        q = query[i].view(num_heads, 1, head_dim)
        o = F.scaled_dot_product_attention(q, kv, v, scale=scale, enable_gqa=True)
        out[i] = o.view(num_heads, v_head_dim)

    return out


@pytest.mark.parametrize("dtype", [torch.bfloat16, torch.float16])
@pytest.mark.parametrize("mean_seq_len", [128, 1024, 4096])
@pytest.mark.parametrize("bs", [1, 2, 4])
@pytest.mark.parametrize("varlen", [False, True])
@pytest.mark.parametrize("block_size", [1, 16, 64, 128])
@pytest.mark.parametrize("num_heads", [16, 32, 64, 128])
@pytest.mark.parametrize("num_kv_splits", [-1, 1])
def test_cutlass_mla_decode(
    dtype: torch.dtype,
    mean_seq_len: int,
    bs: int,
    varlen: bool,
    block_size: int,
    num_heads: int,
    num_kv_splits: int,
):
    """Compare ``cutlass_mla_decode`` against the ``ref_mla`` reference.

    Builds random sequence lengths, a random block table and a random paged
    KV cache. It then runs both the CUTLASS MLA decode kernel and the
    SDPA-based reference on the same inputs, and checks that the outputs
    agree within ``atol=rtol=1e-2``.

    torch's default dtype is set to ``dtype`` and its default device to CUDA
    for the duration of the test. Both are restored afterwards.
    """
    # The default dtype/device are process-global; remember them so they can
    # be restored and do not leak into other tests in the same session.
    prev_dtype = torch.get_default_dtype()
    prev_device = torch.get_default_device()
    torch.set_default_dtype(dtype)
    torch.set_default_device("cuda")
    try:
        torch.manual_seed(42)

        # q / kv rows are d wide. Below, q is split into its first dv columns
        # (q_nope) and the remaining d - dv = 64 columns (q_pe).
        d = 576
        dv = 512
        h_q = num_heads

        q_nope_dim = 128
        q_pe_dim = 64
        scale = (q_nope_dim + q_pe_dim) ** (-0.5)

        if varlen:
            # Random lengths around mean_seq_len, clipped to at least 2 tokens.
            seq_lens = torch.empty(bs).normal_(mean_seq_len, mean_seq_len / 2)
            seq_lens = seq_lens.clip(2).to(torch.int32)
        else:
            seq_lens = torch.full((bs,), mean_seq_len, dtype=torch.int32)
        max_seq_len = seq_lens.max().item()
        block_num = (max_seq_len + block_size - 1) // block_size

        # Pad block_num so that small blocks can be packed into full 128-sized
        # CUTLASS tiles. One 128-wide tile holds (128 // block_size) small
        # blocks; clamp to 1 so block sizes above 128 cannot divide by zero.
        pack_factor = max(1, 128 // block_size)
        block_num = ((block_num + pack_factor - 1) // pack_factor) * pack_factor

        # Larger q values to detect split-kv errors.
        q = torch.randn(bs, h_q, d) * 100.0
        block_table = torch.randint(
            0, bs * block_num, (bs, block_num), dtype=torch.int32
        )
        # One cache block per block-table slot, so every index drawn above
        # (range [0, bs * block_num)) is valid.
        kv_cache = torch.randn(block_table.numel(), block_size, d)

        workspace_size = cutlass_mla_get_workspace_size(
            block_num * block_size, bs, num_kv_splits=num_kv_splits
        )
        workspace = torch.empty(workspace_size, device="cuda", dtype=torch.uint8)

        # q_nope is allocated as (h_q, bs, dv) and viewed transposed, so the
        # kernel receives a non-contiguous tensor. q_pe is a contiguous copy.
        q_nope = torch.empty((h_q, bs, dv)).transpose(0, 1)
        q_nope.copy_(q[:, :, :dv])
        q_pe = q[:, :, dv:].clone()

        out_ref = q.new_zeros(bs, h_q, dv)
        ref_mla(out_ref, q, kv_cache, scale, block_table, seq_lens)
        out = cutlass_mla_decode(
            q_nope,
            q_pe,
            kv_cache,
            seq_lens,
            block_table,
            workspace,
            scale,
            num_kv_splits,
        )

        torch.testing.assert_close(out, out_ref, atol=1e-2, rtol=1e-2)
    finally:
        torch.set_default_dtype(prev_dtype)
        torch.set_default_device(prev_device)


if __name__ == "__main__":
    # Propagate pytest's exit status so running this file directly exits
    # non-zero when any test fails.
    raise SystemExit(pytest.main([__file__]))