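"""Accuracy test for the sgl_kernel CUTLASS MLA decode kernel.

Compares cutlass_mla_decode against a pure-PyTorch reference (ref_mla)
across dtypes, batch sizes, sequence lengths, paged-KV block sizes,
query-head counts, and num_kv_splits settings.
"""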
import pytest
import torch
import torch.nn.functional as F
from sgl_kernel import cutlass_mla_decode, cutlass_mla_get_workspace_size
from torch import Tensor

# Only run on SM100 (compute capability 10.0); SM103 stays disabled until
# its accuracy issues are fixed.
if torch.cuda.get_device_capability() != (10, 0):
    pytest.skip(
        reason="Cutlass MLA Requires compute capability of 10.",
        allow_module_level=True,
    )


def ref_mla(
    out: Tensor,  # (bs, num_heads, v_head_dim)
    query: Tensor,  # (bs, num_heads, head_dim)
    kv_cache: Tensor,  # (num_blocks, block_size, head_dim)
    scale: float,
    block_tables: Tensor,  # (bs, max_num_blocks)
    seq_lens: Tensor,  # (bs,)
):
    bs, num_heads, v_head_dim = out.shape
    head_dim = query.shape[2]

    for i in range(bs):
        # gather and flatten KV-cache
        kv = kv_cache[block_tables[i]]  # (max_num_blocks, block_size, head_dim)
        kv = kv.view(1, -1, head_dim)[:, : seq_lens[i]]  # (1, seq_len, head_dim)
        v = kv[:, :, :v_head_dim]

        q = query[i].view(num_heads, 1, head_dim)
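        # The cache holds a single shared KV head; enable_gqa has SDPA
        # broadcast it across all num_heads query heads (MQA-style),
        # matching MLA's shared latent cache.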
        o = F.scaled_dot_product_attention(q, kv, v, scale=scale, enable_gqa=True)
        out[i] = o.view(num_heads, v_head_dim)

    return out


@pytest.mark.parametrize("dtype", [torch.bfloat16, torch.float16])
@pytest.mark.parametrize("mean_seq_len", [128, 1024, 4096])
@pytest.mark.parametrize("bs", [1, 2, 4])
@pytest.mark.parametrize("varlen", [False, True])
@pytest.mark.parametrize("block_size", [1, 16, 64, 128])
@pytest.mark.parametrize("num_heads", [16, 32, 64, 128])
@pytest.mark.parametrize("num_kv_splits", [-1, 1])
def test_cutlass_mla_decode(
    dtype: torch.dtype,
    mean_seq_len: int,
    bs: int,
    varlen: bool,
    block_size: int,
    num_heads: int,
    num_kv_splits: int,
):
    torch.set_default_dtype(dtype)
    torch.set_default_device("cuda")
    torch.manual_seed(42)

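    # DeepSeek-style MLA dims: each cached token is d = 576 wide, i.e. a
    # 512-dim compressed KV latent (dv) plus a 64-dim RoPE component.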
    d = 576
    h_q = num_heads
    dv = 512

    q_nope_dim = 128
    q_pe_dim = 64
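    # The softmax scale comes from the original 192-dim query head
    # (128 nope + 64 rope), not from the 576-dim latent the kernel reads.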
    scale = (q_nope_dim + q_pe_dim) ** (-0.5)
    if varlen:
        seq_lens = torch.empty(bs).normal_(mean_seq_len, mean_seq_len / 2)
        seq_lens = seq_lens.clip(2).to(torch.int32)
    else:
        seq_lens = torch.full((bs,), mean_seq_len, dtype=torch.int32)
    max_seq_len = seq_lens.max().item()
    block_num = (max_seq_len + block_size - 1) // block_size

    # Pad block_num so that small blocks can be packed into full 128-sized CUTLASS tiles.
    # One 128-wide tile can hold (128 // block_size) small blocks.
    pack_factor = 128 // block_size
    block_num = ((block_num + pack_factor - 1) // pack_factor) * pack_factor
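    # e.g. block_size=16 gives pack_factor=8, so a raw block_num of 3 is
    # padded to 8 and the table covers a whole number of 128-token tiles.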

    # Larger q values to detect split-KV errors
    q = torch.randn(bs, h_q, d) * 100.0
    block_table = torch.randint(0, bs * block_num, (bs, block_num), dtype=torch.int32)

    kv_cache = torch.randn(block_table.numel(), block_size, d)
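    # Every table entry is drawn from [0, bs * block_num) and the cache holds
    # exactly that many blocks, so all indices are valid; sequences may alias
    # blocks, which is harmless for this read-only decode path.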

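    # Size the workspace for the padded worst case: block_num * block_size
    # tokens per sequence, an upper bound on every seq_lens entry here.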
    workspace_size = cutlass_mla_get_workspace_size(
        block_num * block_size, bs, num_kv_splits=num_kv_splits
    )
    workspace = torch.empty(workspace_size, device="cuda", dtype=torch.uint8)

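    # Allocate q_nope head-major and transpose so the kernel is also
    # exercised on a non-contiguous query layout; q_pe is a contiguous copy.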
    q_nope = torch.empty((h_q, bs, dv)).transpose(0, 1)
    q_nope.copy_(q[:, :, :dv])
    q_pe = q[:, :, dv:].clone()

    out_ref = q.new_zeros(bs, h_q, dv)
    ref_mla(out_ref, q, kv_cache, scale, block_table, seq_lens)
    out = cutlass_mla_decode(
        q_nope, q_pe, kv_cache, seq_lens, block_table, workspace, scale, num_kv_splits
    )

    torch.testing.assert_close(out, out_ref, atol=1e-2, rtol=1e-2)


if __name__ == "__main__":
    pytest.main([__file__])