# test_swa_unittest.py: unit tests for the SWA KV pool, token-to-KV allocator, and SWA radix cache.

import unittest

import torch

from sglang.srt.mem_cache.allocator import SWAKVPool, SWATokenToKVPoolAllocator
from sglang.srt.mem_cache.memory_pool import ReqToTokenPool
from sglang.srt.mem_cache.radix_cache import RadixKey
from sglang.srt.mem_cache.swa_radix_cache import SWARadixCache


class TestSWA(unittest.TestCase):
    """Unit tests for the sliding-window-attention (SWA) KV cache components.

    Covers the bookkeeping of ``SWAKVPool`` + ``SWATokenToKVPoolAllocator`` and
    the insert / evict / match_prefix behaviour of ``SWARadixCache``. Every
    pool is created with ``device="cuda"``, so these tests need a CUDA device.
    """

    @classmethod
    def setUpClass(cls):
        pass

    @classmethod
    def tearDownClass(cls):
        pass

    @staticmethod
    def _split_layer_ids(num_layers, global_interval):
        """Partition ``range(num_layers)`` into full-attention and SWA layer ids.

        Layer ids ``0, global_interval, 2 * global_interval, ...`` are returned
        as full-attention layers; every remaining id below ``num_layers`` is
        returned as an SWA layer.

        Returns:
            ``(full_attention_layer_ids, swa_attention_layer_ids)`` as lists.
        """
        full_attention_layer_ids = list(range(0, num_layers, global_interval))
        full_attention_layer_ids_set = set(full_attention_layer_ids)
        swa_attention_layer_ids = [
            i for i in range(num_layers) if i not in full_attention_layer_ids_set
        ]
        return full_attention_layer_ids, swa_attention_layer_ids

    def _alloc_and_insert(self, tree, allocator, name, token_ids):
        """Allocate KV slots for ``token_ids``, insert them into ``tree``, and log.

        Args:
            tree: the ``SWARadixCache`` under test.
            allocator: the allocator backing ``tree``.
            name: label (e.g. ``"req1"``) used only to prefix printed diagnostics.
            token_ids: token ids of the request to insert.

        Returns:
            The value returned by ``tree.insert`` (printed as ``prefix_len``).
        """
        kv_indices = allocator.alloc(len(token_ids))
        # Fail with a clear message if alloc() returned None (presumably its
        # out-of-memory signal -- verify) instead of a TypeError from len().
        self.assertIsNotNone(kv_indices, f"{name}: allocator.alloc returned None")
        self.assertEqual(len(token_ids), len(kv_indices))
        print(
            f"{name}: inserting, {name}_token_ids: {token_ids}, {name}_kv_indices: {kv_indices}"
        )
        prefix_len = tree.insert(RadixKey(token_ids), kv_indices)
        print(
            f"{name}: prefix_len: {prefix_len}, allocator swa available size: {allocator.swa_available_size()}, full available size: {allocator.full_available_size()}"
        )
        return prefix_len

    def test_swa_memory_pool(self):
        size = 16
        size_swa = 16
        head_num = 8
        head_dim = 128
        num_layers = 48
        global_interval = 4
        dtype = torch.bfloat16
        device = "cuda"
        full_attention_layer_ids, swa_attention_layer_ids = self._split_layer_ids(
            num_layers, global_interval
        )
        pool = SWAKVPool(
            size=size,
            size_swa=size_swa,
            dtype=dtype,
            head_num=head_num,
            head_dim=head_dim,
            swa_attention_layer_ids=swa_attention_layer_ids,
            full_attention_layer_ids=full_attention_layer_ids,
            enable_kvcache_transpose=False,
            device=device,
        )
        alloc = SWATokenToKVPoolAllocator(
            size=size,
            size_swa=size_swa,
            dtype=dtype,
            device=device,
            kvcache=pool,
            need_sort=False,
        )
        self.assertEqual(alloc.available_size(), size + size_swa)
        index = alloc.alloc(1)
        # Expected total drops by 2 for one allocated token (presumably one
        # full slot + one SWA slot). Compare against ``size + size_swa`` -- the
        # same baseline as above -- rather than ``size_swa + size_swa``, which
        # was only correct because both sizes happen to be equal.
        self.assertEqual(alloc.available_size(), size + size_swa - 2)
        alloc.free_swa(index)
        result = alloc.translate_loc_from_full_to_swa(index)
        print(result)

    def test_swa_radix_cache_1(self):
        # args
        req_size = 10
        max_context_len = 128
        kv_size = 128
        kv_size_swa = 64
        sliding_window_size = 4
        head_num = 8
        head_dim = 128
        num_layers = 48
        global_interval = 4
        dtype = torch.bfloat16
        device = "cuda"
        full_attention_layer_ids, swa_attention_layer_ids = self._split_layer_ids(
            num_layers, global_interval
        )
        # setup req to token pool
        req_to_token_pool = ReqToTokenPool(
            size=req_size,
            max_context_len=max_context_len,
            device=device,
            enable_memory_saver=False,
        )
        # setup kv pool
        kv_pool = SWAKVPool(
            size=kv_size,
            size_swa=kv_size_swa,
            dtype=dtype,
            head_num=head_num,
            head_dim=head_dim,
            swa_attention_layer_ids=swa_attention_layer_ids,
            full_attention_layer_ids=full_attention_layer_ids,
            enable_kvcache_transpose=False,
            device=device,
        )
        # setup token to kv pool allocator
        allocator = SWATokenToKVPoolAllocator(
            size=kv_size,
            size_swa=kv_size_swa,
            dtype=dtype,
            device=device,
            kvcache=kv_pool,
            need_sort=False,
        )
        # setup radix cache
        tree = SWARadixCache(
            req_to_token_pool=req_to_token_pool,
            token_to_kv_pool_allocator=allocator,
            sliding_window_size=sliding_window_size,
            page_size=1,
            disable=False,
        )

        # test
        print(
            f"[Start] allocator swa available size: {allocator.swa_available_size()}, full available size: {allocator.full_available_size()}"
        )
        # req2 and req4 share the prefix [1, 2, 3, 4, 5]; req1 is a prefix of
        # both; req3 is disjoint from the others.
        self._alloc_and_insert(tree, allocator, "req1", [1, 2, 3])
        self._alloc_and_insert(tree, allocator, "req2", [1, 2, 3, 4, 5, 6, 7])
        self._alloc_and_insert(tree, allocator, "req3", [10, 11, 12])
        self._alloc_and_insert(tree, allocator, "req4", [1, 2, 3, 4, 5, 60, 70])

        tree.pretty_print()
        # Three eviction rounds, given as (full_num_tokens, swa_num_tokens).
        for full_num_tokens, swa_num_tokens in ((1, 0), (0, 1), (1, 2)):
            print(
                f"evicting {full_num_tokens} full token and {swa_num_tokens} swa token"
            )
            tree.evict(full_num_tokens=full_num_tokens, swa_num_tokens=swa_num_tokens)
            tree.pretty_print()

        # After the evictions, the 5-token prefix alone is expected to match nothing.
        req5_token_ids = [1, 2, 3, 4, 5]
        result = tree.match_prefix(RadixKey(req5_token_ids))
        kv_indices, last_node = result.device_indices, result.last_device_node
        print(
            f"req5: token_ids: {req5_token_ids}, matched kv_indices: {kv_indices}, last_node.key: {last_node.key}"
        )
        self.assertEqual(len(kv_indices), 0)

        # The full req4 sequence is expected to match all 7 tokens, ending on
        # the node whose key is [60, 70].
        req6_token_ids = [1, 2, 3, 4, 5, 60, 70]
        result = tree.match_prefix(RadixKey(req6_token_ids))
        kv_indices, last_node = result.device_indices, result.last_device_node
        print(
            f"req6: token_ids: {req6_token_ids}, matched kv_indices: {kv_indices}, last_node.key: {last_node.key}"
        )
        self.assertEqual(len(kv_indices), 7)
        self.assertEqual(len(last_node.key), 2)
        self.assertEqual(last_node.key.token_ids[0], 60)
        self.assertEqual(last_node.key.token_ids[1], 70)


if __name__ == "__main__":
    # Run this module's TestCase classes when the file is executed directly;
    # verbosity=1 is unittest.main's default, spelled out for clarity.
    unittest.main(verbosity=1)