test_flashmla_sparse.py 4.26 KB
Newer Older
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import pytest
import torch


def _cuda_sm90_available() -> bool:
    """Report whether the current CUDA device has compute capability 9.x.

    Returns ``False`` when CUDA is not available at all. Otherwise only the
    major component of ``torch.cuda.get_device_capability()`` is compared,
    so any 9.x device counts.
    """
    has_cuda = torch.cuda.is_available()
    if has_cuda:
        major_cc, _minor_cc = torch.cuda.get_device_capability()
        return major_cc == 9
    return False


def test_sparse_flashmla_metadata_smoke():
    """Smoke-test ``get_mla_metadata`` on the sparse FP8 KV-cache path.

    Skips unless FlashMLA reports itself supported and the current CUDA
    device is SM90. Only the dtypes of the two returned tensors are checked.
    """
    import vllm.attention.ops.flashmla as fm
    supported, why_unsupported = fm.is_flashmla_supported()
    if not (supported and _cuda_sm90_available()):
        pytest.skip(why_unsupported or "SM90 not available")

    dev = torch.device("cuda")
    # One sequence with a single query token; 128 query heads share one
    # KV head.
    n_seqs, q_len, heads_q, heads_k, topk = 1, 1, 128, 1, 128
    q_tokens_per_head_k = q_len * heads_q // heads_k

    seq_lens = torch.zeros(n_seqs, dtype=torch.int32, device=dev)

    tile_md, num_splits = fm.get_mla_metadata(
        seq_lens,
        q_tokens_per_head_k,
        heads_k,
        num_heads_q=heads_q,
        topk=topk,
        is_fp8_kvcache=True,
    )
    for returned in (tile_md, num_splits):
        assert returned.dtype == torch.int32


def test_sparse_flashmla_decode_smoke():
    """Smoke-test sparse FlashMLA decode over an FP8 paged KV cache.

    Builds all-zero inputs (one sequence, one query token, a single KV-cache
    page, all-zero sparse ``indices``), computes scheduler metadata, runs
    ``flash_mla_with_kvcache`` with ``is_fp8_kvcache=True``, and checks only
    the batch and value-head dimensions of the outputs. Skips unless FlashMLA
    reports itself supported and the current CUDA device is SM90.
    """
    import vllm.attention.ops.flashmla as fm
    ok, reason = fm.is_flashmla_supported()
    if not ok or not _cuda_sm90_available():
        pytest.skip(reason or "SM90 not available")

    device = torch.device("cuda")
    batch_size = 1
    seqlen_q = 1
    # NOTE(review): the prefill smoke test states its kernel expects a
    # multiple of 64 query heads; confirm sparse decode accepts 1 here.
    num_heads_q = 1
    head_dim_k = 576
    head_dim_v = 512
    num_heads_k = 1
    num_blocks = 1  # number of physical pages allocated in k_cache
    page_block_size = 64
    # Bytes per cached token in the FP8 layout; presumably 512 fp8 NoPE
    # bytes + 16 bytes of fp32 scales + 128 bytes of bf16 RoPE -- confirm
    # against the FlashMLA FP8 KV-cache format documentation.
    bytes_per_token = 656
    topk = 128
    max_blocks_per_seq = 128  # width of the block table

    # Metadata
    q_seq_per_hk = seqlen_q * num_heads_q // num_heads_k
    cache_seqlens = torch.zeros(batch_size, dtype=torch.int32, device=device)
    tile_md, num_splits = fm.get_mla_metadata(cache_seqlens,
                                              q_seq_per_hk,
                                              num_heads_k,
                                              num_heads_q=num_heads_q,
                                              topk=topk,
                                              is_fp8_kvcache=True)

    # Inputs
    q = torch.zeros((batch_size, seqlen_q, num_heads_q, head_dim_k),
                    dtype=torch.bfloat16,
                    device=device)
    # The FP8 cache is handed to the kernel as raw uint8 bytes per token.
    k_cache = torch.zeros(
        (num_blocks, page_block_size, num_heads_k, bytes_per_token),
        dtype=torch.uint8,
        device=device)
    indices = torch.zeros((batch_size, seqlen_q, topk),
                          dtype=torch.int32,
                          device=device)
    # An all-zero block table maps every logical block to physical page 0,
    # which is the only page allocated in k_cache above.
    block_table = torch.zeros((batch_size, max_blocks_per_seq),
                              dtype=torch.int32,
                              device=device)
    out, lse = fm.flash_mla_with_kvcache(q,
                                         k_cache,
                                         block_table,
                                         cache_seqlens,
                                         head_dim_v,
                                         tile_md,
                                         num_splits,
                                         indices=indices,
                                         is_fp8_kvcache=True)
    assert out.shape[0] == batch_size
    assert out.shape[-1] == head_dim_v
    assert lse.shape[0] == batch_size


def test_sparse_flashmla_prefill_smoke():
    """Smoke-test ``flash_mla_sparse_prefill`` on all-zero BF16 inputs.

    Skips unless FlashMLA reports itself supported and the current CUDA
    device is SM90. Otherwise it checks the shapes of the returned output,
    max-logits and log-sum-exp tensors.
    """
    import vllm.attention.ops.flashmla as fm
    supported, why_unsupported = fm.is_flashmla_supported()
    if not (supported and _cuda_sm90_available()):
        pytest.skip(why_unsupported or "SM90 not available")

    dev = torch.device("cuda")
    num_q_tokens, num_kv_tokens = 1, 1
    heads_q = 64  # the kernel expects a multiple of 64 query heads
    heads_kv = 1
    qk_dim, v_dim = 576, 512
    topk = 128

    def _zeros(shape, dtype):
        return torch.zeros(shape, dtype=dtype, device=dev)

    q = _zeros((num_q_tokens, heads_q, qk_dim), torch.bfloat16)
    kv = _zeros((num_kv_tokens, heads_kv, qk_dim), torch.bfloat16)
    indices = _zeros((num_q_tokens, heads_kv, topk), torch.int32)

    sm_scale = 1.0
    out, max_logits, lse = fm.flash_mla_sparse_prefill(
        q, kv, indices, sm_scale, v_dim)

    assert out.shape == (num_q_tokens, heads_q, v_dim)
    per_head_shape = (num_q_tokens, heads_q)
    assert max_logits.shape == per_head_shape
    assert lse.shape == per_head_shape