import unittest

import numpy as np
import torch

from sglang.srt.layers.attention.utils import (
    create_flashinfer_kv_indices_triton,
    create_flashmla_kv_indices_triton,
)
from sglang.test.test_utils import CustomTestCase


class TestCreateKvIndices(CustomTestCase):
    @classmethod
    def setUpClass(cls):
        if not torch.cuda.is_available():
            raise unittest.SkipTest("CUDA is not available")
        torch.set_default_device("cuda")

    def _run_test(self, batch, max_batch, max_context_len, page_size):
        np.random.seed(9)
        PAGE_SIZE = page_size
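        # Fake req_to_token pool: row i holds consecutive KV cache locations for
        # request slot i, so the expected page indices are easy to derive.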
        req_to_token = torch.arange(
            max_batch * max_context_len, dtype=torch.int32, device="cuda"
        ).reshape((max_batch, max_context_len))

        # Randomly pick which pool rows (block-table rows) the batch occupies
        # and a sequence length for each request.
        req_pool_indices = torch.tensor(
            np.random.choice(range(max_batch), size=batch, replace=False),
            dtype=torch.int32,
            device="cuda",
        )
        seq_lens = torch.tensor(
            np.random.choice(range(max_context_len), size=batch, replace=False),
            dtype=torch.int32,
            device="cuda",
        )
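        # CSR-style layout: kv_indptr[i] is the offset of request i's first page
        # index in the packed kv_indices array; kv_indptr[-1] is the total page count.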
        num_pages_per_req = (seq_lens + PAGE_SIZE - 1) // PAGE_SIZE
        kv_indptr = torch.zeros((batch + 1,), dtype=torch.int32, device="cuda")
        kv_indptr[1:] = torch.cumsum(num_pages_per_req, dim=0)

        # Reference: build the expected packed page indices per request.
        kv_indices_ref = torch.empty(kv_indptr[-1], dtype=torch.int32, device="cuda")
        req_pool_indices_cpu = req_pool_indices.cpu().numpy()
        seq_lens_cpu = seq_lens.cpu().numpy()
        for i in range(batch):
            kv_indptr_req = kv_indptr[i]
            num_toks_seq = seq_lens_cpu[i]
            curr_req_pool = req_pool_indices_cpu[i]
            curr_num_pages = num_pages_per_req[i]
            curr_token_ids = req_to_token[curr_req_pool]
            # Pages covered by this request's tokens; unique() returns them sorted,
            # which matches the ascending token ids in req_to_token.
            curr_pages = (curr_token_ids[:num_toks_seq] // PAGE_SIZE).unique()
            assert (
                len(curr_pages) == curr_num_pages
            ), f"req {i} expected {curr_num_pages} pages, but got {len(curr_pages)}"
            kv_indices_ref[kv_indptr_req : kv_indptr_req + curr_num_pages] = curr_pages

        # Triton kernels under test: flashinfer-style packed kv_indices first.
        kv_indices_triton = torch.empty(kv_indptr[-1], dtype=torch.int32, device="cuda")
        create_flashinfer_kv_indices_triton[(batch,)](
            req_to_token,
            req_pool_indices,
            seq_lens,
            kv_indptr,
            None,
            kv_indices_triton,
            req_to_token.size(1),
            PAGE_SIZE,
        )
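        # FlashMLA-style kernel writes into a dense (batch, max_pages) table
        # rather than a packed array.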
        max_pages = max_context_len // PAGE_SIZE
        kv_indices_flashmla = torch.empty(
            batch, max_pages, dtype=torch.int32, device="cuda"
        )

        create_flashmla_kv_indices_triton[(batch,)](
            req_to_token,
            req_pool_indices,
            seq_lens,
            None,
            kv_indices_flashmla,
            req_to_token.size(1),
            max_pages,
            PAGE_SIZE,
        )

        # Check: the packed flashinfer-style indices must match the reference.
        self.assertTrue(torch.equal(kv_indices_ref, kv_indices_triton))
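        # Note: kv_indices_flashmla uses the dense (batch, max_pages) layout and is
        # exercised above, but only the packed flashinfer output is verified here.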

    def test_create_kvindices(self):
        BATCH = [4, 37, 512, 1786]
        MAX_BATCH = 4096
        MAX_CONTEXT_LEN = 4096
        PAGE_SIZE = [1, 2, 16, 64]
        # for debug
        # BATCH = [4]
        # MAX_BATCH = 4
        # MAX_CONTEXT_LEN = 10
        # Test for small batch size
        for page_size in PAGE_SIZE[:1]:
            print(f"Running test for page size: {page_size} and batch size: {BATCH[0]}")
            self._run_test(BATCH[0], MAX_BATCH, MAX_CONTEXT_LEN, page_size)

        # Test for larger batch size
        for batch in BATCH[1:]:
            for page_size in PAGE_SIZE:
                print(
                    f"Running test for batch size: {batch} and page size: {page_size}"
                )
                self._run_test(batch, MAX_BATCH, MAX_CONTEXT_LEN, page_size)


if __name__ == "__main__":
    unittest.main()