test_flashinfer.py 7.1 KB
Newer Older
Lianmin Zheng's avatar
Lianmin Zheng committed
1
2
import pytest
import torch
Liangsheng Yin's avatar
Liangsheng Yin committed
3
4
5
6
7
from flashinfer import (
    BatchDecodeWithPagedKVCacheWrapper,
    BatchPrefillWithPagedKVCacheWrapper,
)
from flashinfer.decode import _grouped_size_compiled_for_decode_kernels
Liangsheng Yin's avatar
Liangsheng Yin committed
8

9
10
from sglang.srt.layers.attention.triton_ops.decode_attention import decode_attention_fwd
from sglang.srt.layers.attention.triton_ops.extend_attention import (
11
12
13
    extend_attention_fwd,
    redundant_attention,
)
14
from sglang.srt.utils import should_use_tensor_core
Lianmin Zheng's avatar
Lianmin Zheng committed
15

Liangsheng Yin's avatar
Liangsheng Yin committed
16
17
18
# Module-level FlashInfer wrapper handles. Both start unset and are rebound by
# init_flashinfer(), which every test in this file calls before using them.
flashinfer_prefill_wrapper = None
flashinfer_decode_wrapper = None

Lianmin Zheng's avatar
Lianmin Zheng committed
19
20
21
22
23

@pytest.mark.parametrize("batch_size", [12, 37, 67])
@pytest.mark.parametrize("kv_len", [54, 97])
@pytest.mark.parametrize("qo_len", [37, 17])
@pytest.mark.parametrize("num_kv_heads", [4])
@pytest.mark.parametrize("num_qo_heads", [32, 4])
@pytest.mark.parametrize("head_dim", [128])
def test_batch_prefill_with_paged_kv_cache(
    batch_size,
    kv_len,
    qo_len,
    num_kv_heads,
    num_qo_heads,
    head_dim,
):
    """Cross-check three prefill/extend attention implementations.

    The triton ``extend_attention_fwd`` output is compared against
    ``redundant_attention`` and against the FlashInfer paged-KV prefill
    wrapper, all fed the same random fp16 inputs on CUDA device 0.
    Every request in the batch has ``kv_len`` cached tokens, of which the
    last ``qo_len`` are the ones being extended.
    """
    init_flashinfer(num_qo_heads, num_kv_heads)

    device = 0
    total_tokens = batch_size * kv_len

    # Random draws stay in the original order: queries first, then KV.
    q = torch.randn(batch_size * qo_len, num_qo_heads, head_dim).to(device).half()
    kv_data = torch.randn(total_tokens, 2, num_kv_heads, head_dim).to(device).half()

    # FlashInfer index tensors. Every request has the same length, so the
    # indptr arrays are just multiples of the per-request length.
    request_bounds = torch.arange(batch_size + 1).to(device).int()
    qo_indptr = request_bounds * qo_len
    kv_indptr = request_bounds * kv_len
    kv_indices = torch.arange(total_tokens).to(device).int()
    kv_last_page_len = torch.ones(batch_size, dtype=torch.int32).to(device)

    # Triton kernel inputs: the trailing qo_len tokens of each request are the
    # "extend" part; index 0 / 1 on the size-2 axis selects K / V.
    extend_slice = kv_data.view(batch_size, kv_len, 2, -1)[:, -qo_len:]
    k_extend = extend_slice[:, :, 0].contiguous().view(-1, num_kv_heads, head_dim)
    v_extend = extend_slice[:, :, 1].contiguous().view(-1, num_kv_heads, head_dim)
    o_triton = torch.empty_like(q)
    # kv_data[:, i] already has shape (total_tokens, num_kv_heads, head_dim),
    # so only the contiguous copy is needed.
    k_buffer = kv_data[:, 0].contiguous()
    v_buffer = kv_data[:, 1].contiguous()
    req_to_token = (
        torch.arange(total_tokens).to(device).int().view(batch_size, kv_len)
    )
    b_req_idx = torch.arange(batch_size).to(device).int()
    b_seq_len = torch.full((batch_size,), kv_len, dtype=torch.int32).to(device)
    b_start_loc_extend = b_req_idx * qo_len
    b_seq_len_extend = torch.full((batch_size,), qo_len, dtype=torch.int32).to(device)
    max_len_in_batch = kv_len
    max_len_extend = qo_len

    extend_attention_fwd(
        q,
        k_extend,
        v_extend,
        o_triton,
        k_buffer,
        v_buffer,
        req_to_token,
        b_req_idx,
        None,  # b_start_loc = None
        b_seq_len,
        None,  # b_seq_len_prefix = None
        b_start_loc_extend,
        b_seq_len_extend,
        max_len_in_batch,
        max_len_extend,
    )

    # Second implementation: redundant_attention, which needs explicit
    # per-request start offsets and prefix lengths.
    o_redundant = torch.empty_like(q)
    b_start_loc = torch.zeros(batch_size, dtype=torch.int32).to(device)
    b_start_loc[1:] = b_seq_len[:-1].cumsum(dim=0)
    b_seq_len_prefix = b_seq_len - b_seq_len_extend

    redundant_attention(
        q,
        k_extend,
        v_extend,
        o_redundant,
        k_buffer,
        v_buffer,
        req_to_token,
        b_req_idx,
        b_start_loc,
        b_seq_len,
        b_seq_len_prefix,
        max_len_in_batch,
    )
    redundant_err = (o_redundant - o_triton).abs()
    print("Mean: ", redundant_err.mean())
    print("Max: ", redundant_err.max())
    assert torch.allclose(o_redundant, o_triton, rtol=1e-2, atol=1e-3)

    # Third implementation: FlashInfer paged-KV prefill.
    flashinfer_prefill_wrapper.end_forward()

    flashinfer_prefill_wrapper.begin_forward(
        qo_indptr,
        kv_indptr,
        kv_indices,
        kv_last_page_len,
        num_qo_heads,
        num_kv_heads,
        head_dim,
        1,  # presumably the page size (one token per page) -- confirm
    )
    o = flashinfer_prefill_wrapper.forward(
        q.contiguous().view(-1, num_qo_heads, head_dim), kv_data
    )

    flashinfer_err = (o - o_triton).abs()
    print("Mean: ", flashinfer_err.mean())
    print("Max: ", flashinfer_err.max())
    assert torch.allclose(o, o_triton, rtol=1e-2, atol=1e-3)


@pytest.mark.parametrize("batch_size", [12, 17, 37])
@pytest.mark.parametrize("kv_len", [54, 127, 537])
@pytest.mark.parametrize("num_kv_heads", [32])
@pytest.mark.parametrize("num_qo_heads", [32])
@pytest.mark.parametrize("head_dim", [128])
def test_batch_decode_with_paged_kv_cache(
    batch_size,
    kv_len,
    num_kv_heads,
    num_qo_heads,
    head_dim,
):
    """Check triton ``decode_attention_fwd`` against FlashInfer paged decode.

    One random fp16 query per request and ``kv_len`` cached tokens per
    request are built on CUDA device 0; both implementations must agree
    within ``rtol=1e-2, atol=2e-3``.
    """
    # note(lsyin): under pytest the number of heads cannot vary, because the
    # triton kernel has a cache. To test other decode shapes, change the
    # parameters under __main__ and run decode only once.
    init_flashinfer(num_qo_heads, num_kv_heads)

    device = 0
    total_tokens = batch_size * kv_len

    # Random draws stay in the original order: queries first, then KV.
    q = torch.randn(batch_size, num_qo_heads, head_dim).to(device).half()
    kv_data = torch.randn(total_tokens, 2, num_kv_heads, head_dim).to(device).half()

    # FlashInfer index tensors; all requests share the same KV length.
    kv_indptr = torch.arange(batch_size + 1).to(device).int() * kv_len
    kv_indices = torch.arange(total_tokens).to(device).int()
    kv_last_page_len = torch.ones(batch_size, dtype=torch.int32).to(device)

    # Triton kernel inputs; index 0 / 1 on the size-2 axis selects K / V.
    # kv_data[:, i] already has shape (total_tokens, num_kv_heads, head_dim),
    # so only the contiguous copy is needed.
    k_buffer = kv_data[:, 0].contiguous()
    v_buffer = kv_data[:, 1].contiguous()
    o_triton = torch.empty_like(q)
    req_to_token = (
        torch.arange(total_tokens).to(device).int().view(batch_size, kv_len)
    )
    b_req_idx = torch.arange(batch_size).to(device).int()
    b_start_loc = b_req_idx * kv_len
    b_seq_len = torch.full((batch_size,), kv_len, dtype=torch.int32).to(device)
    max_len_in_batch = kv_len
    other_kv_index = 0

    decode_attention_fwd(
        q,
        k_buffer,
        v_buffer,
        o_triton,
        req_to_token,
        b_req_idx,
        b_start_loc,
        b_seq_len,
        max_len_in_batch,
        other_kv_index,
        total_tokens,
    )

    # FlashInfer paged-KV decode over the same data.
    flashinfer_decode_wrapper.end_forward()
    flashinfer_decode_wrapper.begin_forward(
        kv_indptr,
        kv_indices,
        kv_last_page_len,
        num_qo_heads,
        num_kv_heads,
        head_dim,
        1,  # presumably the page size (one token per page) -- confirm
        pos_encoding_mode="NONE",
        data_type="float16",
    )
    o = flashinfer_decode_wrapper.forward(
        q.contiguous().view(-1, num_qo_heads, head_dim), kv_data
    )

    abs_err = (o - o_triton).abs()
    print("Mean: ", abs_err.mean())
    print("Max: ", abs_err.max())
    assert torch.allclose(o, o_triton, rtol=1e-2, atol=2e-3)


Liangsheng Yin's avatar
Liangsheng Yin committed
198
def init_flashinfer(num_attention_heads, num_kv_heads):
    """Create fresh FlashInfer wrappers and publish them as module globals.

    Rebinds ``flashinfer_prefill_wrapper`` and ``flashinfer_decode_wrapper``.
    Both wrappers are built on the same newly allocated CUDA workspace
    buffer with the "NHD" layout; the decode wrapper's ``use_tensor_cores``
    flag comes from ``should_use_tensor_core`` for fp16 and the given head
    counts.
    """
    global flashinfer_prefill_wrapper, flashinfer_decode_wrapper

    tensor_cores = should_use_tensor_core(
        torch.half, num_attention_heads, num_kv_heads
    )

    # 128 MiB of int8 scratch space handed to both wrappers.
    workspace_bytes = 128 * 1024 * 1024
    workspace_buffer = torch.empty(workspace_bytes, dtype=torch.int8, device="cuda")

    kv_layout = "NHD"
    flashinfer_prefill_wrapper = BatchPrefillWithPagedKVCacheWrapper(
        workspace_buffer, kv_layout
    )
    flashinfer_decode_wrapper = BatchDecodeWithPagedKVCacheWrapper(
        workspace_buffer, kv_layout, use_tensor_cores=tensor_cores
    )


Lianmin Zheng's avatar
Lianmin Zheng committed
215
if __name__ == "__main__":
    # Ad-hoc shapes for running this file directly instead of through pytest.
    # Prefill args: (batch_size, kv_len, qo_len, num_kv_heads, num_qo_heads, head_dim)
    prefill_cases = (
        (12, 54, 37, 8, 8, 128),
        (37, 1111, 456, 32, 32, 128),
    )
    for case in prefill_cases:
        test_batch_prefill_with_paged_kv_cache(*case)

    # Decode args: (batch_size, kv_len, num_kv_heads, num_qo_heads, head_dim)
    test_batch_decode_with_paged_kv_cache(12, 54, 4, 32, 128)