# test_swa_unittest.py
import unittest

import torch

from sglang.srt.mem_cache.allocator import SWAKVPool, SWATokenToKVPoolAllocator
from sglang.srt.mem_cache.memory_pool import ReqToTokenPool
from sglang.srt.mem_cache.radix_cache import RadixKey
from sglang.srt.mem_cache.swa_radix_cache import SWARadixCache


class TestSWA(unittest.TestCase):
    """Unit tests for the SWA (sliding-window attention) KV cache stack:

    - ``SWAKVPool`` / ``SWATokenToKVPoolAllocator``: a hybrid pool where every
      allocated token consumes one slot in the full-attention pool and one in
      the SWA pool.
    - ``SWARadixCache``: a radix-tree prefix cache that is aware of the
      sliding window (and, optionally, EAGLE-style paired-token keys).

    NOTE(review): these tests construct pools on ``device="cuda"`` and are
    expected to run on a CUDA-capable machine.
    """

    @classmethod
    def setUpClass(cls):
        pass

    @classmethod
    def tearDownClass(cls):
        pass

    @staticmethod
    def _split_layer_ids(num_layers, global_interval):
        """Return ``(swa_layer_ids, full_layer_ids)``.

        Every ``global_interval``-th layer (starting at 0) uses full
        attention; all remaining layers use sliding-window attention.
        """
        full_attention_layer_ids = list(range(0, num_layers, global_interval))
        full_set = set(full_attention_layer_ids)
        swa_attention_layer_ids = [
            i for i in range(num_layers) if i not in full_set
        ]
        return swa_attention_layer_ids, full_attention_layer_ids

    def _build_swa_radix_cache(self, is_eagle=False):
        """Build a ``SWARadixCache`` with the standard test configuration.

        Shared setup for the radix-cache tests (previously duplicated in each
        test). Returns ``(tree, allocator)``.
        """
        req_size = 10
        max_context_len = 128
        kv_size = 128
        kv_size_swa = 64
        sliding_window_size = 4
        head_num = 8
        head_dim = 128
        num_layers = 48
        global_interval = 4
        dtype = torch.bfloat16
        device = "cuda"
        swa_attention_layer_ids, full_attention_layer_ids = self._split_layer_ids(
            num_layers, global_interval
        )
        # Request-to-token mapping pool.
        req_to_token_pool = ReqToTokenPool(
            size=req_size,
            max_context_len=max_context_len,
            device=device,
            enable_memory_saver=False,
        )
        # Hybrid full/SWA KV pool.
        kv_pool = SWAKVPool(
            size=kv_size,
            size_swa=kv_size_swa,
            dtype=dtype,
            head_num=head_num,
            head_dim=head_dim,
            swa_attention_layer_ids=swa_attention_layer_ids,
            full_attention_layer_ids=full_attention_layer_ids,
            enable_kvcache_transpose=False,
            device=device,
        )
        # Token-to-KV allocator over the hybrid pool.
        allocator = SWATokenToKVPoolAllocator(
            size=kv_size,
            size_swa=kv_size_swa,
            dtype=dtype,
            device=device,
            kvcache=kv_pool,
            need_sort=False,
        )
        tree = SWARadixCache(
            req_to_token_pool=req_to_token_pool,
            token_to_kv_pool_allocator=allocator,
            sliding_window_size=sliding_window_size,
            page_size=1,
            disable=False,
            is_eagle=is_eagle,
        )
        return tree, allocator

    def test_swa_memory_pool(self):
        """Allocating one token consumes one slot in each sub-pool; freeing
        via ``free_swa`` releases only the SWA half."""
        size = 16
        size_swa = 16
        head_num = 8
        head_dim = 128
        num_layers = 48
        global_interval = 4
        dtype = torch.bfloat16
        device = "cuda"
        swa_attention_layer_ids, full_attention_layer_ids = self._split_layer_ids(
            num_layers, global_interval
        )
        pool = SWAKVPool(
            size=size,
            size_swa=size_swa,
            dtype=dtype,
            head_num=head_num,
            head_dim=head_dim,
            swa_attention_layer_ids=swa_attention_layer_ids,
            full_attention_layer_ids=full_attention_layer_ids,
            enable_kvcache_transpose=False,
            device=device,
        )
        alloc = SWATokenToKVPoolAllocator(
            size=size,
            size_swa=size_swa,
            dtype=dtype,
            device=device,
            kvcache=pool,
            need_sort=False,
        )
        self.assertEqual(
            alloc.full_available_size() + alloc.swa_available_size(), size + size_swa
        )
        index = alloc.alloc(1)
        # BUGFIX: this expectation was written as `size_swa + size_swa - 2`,
        # which only passed because size == size_swa in this configuration.
        # One allocation takes one slot from the full pool and one from the
        # SWA pool, so the combined availability drops by 2 from the total.
        self.assertEqual(
            alloc.full_available_size() + alloc.swa_available_size(),
            size + size_swa - 2,
        )
        alloc.free_swa(index)
        result = alloc.translate_loc_from_full_to_swa(index)
        print(result)

    def test_swa_radix_cache_1(self):
        """Insert / evict / match-prefix on a SWARadixCache with plain
        token-level keys (no EAGLE)."""
        tree, allocator = self._build_swa_radix_cache(is_eagle=False)

        # test
        print(
            f"[Start] allocator swa available size: {allocator.swa_available_size()}, full available size: {allocator.full_available_size()}"
        )
        req1_token_ids, req1_kv_indices = [1, 2, 3], allocator.alloc(3)
        self.assertEqual(len(req1_token_ids), len(req1_kv_indices))
        print(
            f"req1: inserting, req1_token_ids: {req1_token_ids}, req1_kv_indices: {req1_kv_indices}"
        )
        prefix_len = tree.insert(RadixKey(req1_token_ids), req1_kv_indices)
        print(
            f"req1: prefix_len: {prefix_len}, allocator swa available size: {allocator.swa_available_size()}, full available size: {allocator.full_available_size()}"
        )
        req2_token_ids, req2_kv_indices = [1, 2, 3, 4, 5, 6, 7], allocator.alloc(7)
        self.assertEqual(len(req2_token_ids), len(req2_kv_indices))
        print(
            f"req2: inserting, req2_token_ids: {req2_token_ids}, req2_kv_indices: {req2_kv_indices}"
        )
        prefix_len = tree.insert(RadixKey(req2_token_ids), req2_kv_indices)
        print(
            f"req2: prefix_len: {prefix_len}, allocator swa available size: {allocator.swa_available_size()}, full available size: {allocator.full_available_size()}"
        )
        req3_token_ids, req3_kv_indices = [10, 11, 12], allocator.alloc(3)
        self.assertEqual(len(req3_token_ids), len(req3_kv_indices))
        print(
            f"req3: inserting, req3_token_ids: {req3_token_ids}, req3_kv_indices: {req3_kv_indices}"
        )
        prefix_len = tree.insert(RadixKey(req3_token_ids), req3_kv_indices)
        print(
            f"req3: prefix_len: {prefix_len}, allocator swa available size: {allocator.swa_available_size()}, full available size: {allocator.full_available_size()}"
        )
        req4_token_ids, req4_kv_indices = [1, 2, 3, 4, 5, 60, 70], allocator.alloc(7)
        self.assertEqual(len(req4_token_ids), len(req4_kv_indices))
        print(
            f"req4: inserting, req4_token_ids: {req4_token_ids}, req4_kv_indices: {req4_kv_indices}"
        )
        prefix_len = tree.insert(RadixKey(req4_token_ids), req4_kv_indices)
        print(
            f"req4: prefix_len: {prefix_len}, allocator swa available size: {allocator.swa_available_size()}, full available size: {allocator.full_available_size()}"
        )

        tree.pretty_print()
        # Evict in three rounds with different full/SWA token budgets.
        for full_num_tokens, swa_num_tokens in ((1, 0), (0, 1), (1, 2)):
            print(
                f"evicting {full_num_tokens} full token and {swa_num_tokens} swa token"
            )
            tree.evict(full_num_tokens=full_num_tokens, swa_num_tokens=swa_num_tokens)
            tree.pretty_print()

        # The [1..5] branch was (partially) evicted: no prefix should match.
        req5_token_ids = [1, 2, 3, 4, 5]
        result = tree.match_prefix(RadixKey(req5_token_ids))
        kv_indices, last_node = result.device_indices, result.last_device_node
        print(
            f"req5: token_ids: {req5_token_ids}, matched kv_indices: {kv_indices}, last_node.key: {last_node.key}"
        )
        self.assertEqual(len(kv_indices), 0)

        # req4's branch survives eviction: full 7-token match, ending at the
        # [60, 70] leaf.
        req6_token_ids = [1, 2, 3, 4, 5, 60, 70]
        result = tree.match_prefix(RadixKey(req6_token_ids))
        kv_indices, last_node = result.device_indices, result.last_device_node
        print(
            f"req6: token_ids: {req6_token_ids}, matched kv_indices: {kv_indices}, last_node.key: {last_node.key}"
        )
        self.assertEqual(len(kv_indices), 7)
        self.assertEqual(len(last_node.key), 2)
        self.assertEqual(last_node.key.token_ids[0], 60)
        self.assertEqual(last_node.key.token_ids[1], 70)

    def test_swa_radix_cache_eagle(self):
        """Same scenario as test_swa_radix_cache_1 but with ``is_eagle=True``,
        where keys are paired tokens and matched prefixes differ accordingly."""
        tree, allocator = self._build_swa_radix_cache(is_eagle=True)

        # test
        print(
            f"[Start] allocator swa available size: {allocator.swa_available_size()}, full available size: {allocator.full_available_size()}"
        )
        req1_token_ids, req1_kv_indices = [1, 2, 3], allocator.alloc(3)
        self.assertEqual(len(req1_token_ids), len(req1_kv_indices))
        print(
            f"req1: inserting, req1_token_ids: {req1_token_ids}, req1_kv_indices: {req1_kv_indices}"
        )
        prefix_len = tree.insert(RadixKey(req1_token_ids), req1_kv_indices)
        self.assertEqual(prefix_len, 0)
        print(
            f"req1: prefix_len: {prefix_len}, allocator swa available size: {allocator.swa_available_size()}, full available size: {allocator.full_available_size()}"
        )
        req2_token_ids, req2_kv_indices = [1, 2, 3, 4, 5, 6, 7], allocator.alloc(7)
        self.assertEqual(len(req2_token_ids), len(req2_kv_indices))
        print(
            f"req2: inserting, req2_token_ids: {req2_token_ids}, req2_kv_indices: {req2_kv_indices}"
        )
        prefix_len = tree.insert(RadixKey(req2_token_ids), req2_kv_indices)
        self.assertEqual(prefix_len, 2)
        print(
            f"req2: prefix_len: {prefix_len}, allocator swa available size: {allocator.swa_available_size()}, full available size: {allocator.full_available_size()}"
        )
        req3_token_ids, req3_kv_indices = [10, 11, 12], allocator.alloc(3)
        self.assertEqual(len(req3_token_ids), len(req3_kv_indices))
        print(
            f"req3: inserting, req3_token_ids: {req3_token_ids}, req3_kv_indices: {req3_kv_indices}"
        )
        prefix_len = tree.insert(RadixKey(req3_token_ids), req3_kv_indices)
        self.assertEqual(prefix_len, 0)
        print(
            f"req3: prefix_len: {prefix_len}, allocator swa available size: {allocator.swa_available_size()}, full available size: {allocator.full_available_size()}"
        )
        req4_token_ids, req4_kv_indices = [1, 2, 3, 4, 5, 60, 70], allocator.alloc(7)
        self.assertEqual(len(req4_token_ids), len(req4_kv_indices))
        print(
            f"req4: inserting, req4_token_ids: {req4_token_ids}, req4_kv_indices: {req4_kv_indices}"
        )
        prefix_len = tree.insert(RadixKey(req4_token_ids), req4_kv_indices)
        self.assertEqual(prefix_len, 4)
        print(
            f"req4: prefix_len: {prefix_len}, allocator swa available size: {allocator.swa_available_size()}, full available size: {allocator.full_available_size()}"
        )

        tree.pretty_print()
        # Evict in three rounds with different full/SWA token budgets.
        for full_num_tokens, swa_num_tokens in ((1, 0), (0, 1), (1, 2)):
            print(
                f"evicting {full_num_tokens} full token and {swa_num_tokens} swa token"
            )
            tree.evict(full_num_tokens=full_num_tokens, swa_num_tokens=swa_num_tokens)
            tree.pretty_print()

        req5_token_ids = [1, 2, 3, 4, 5]
        result = tree.match_prefix(RadixKey(req5_token_ids))
        kv_indices, last_node = result.device_indices, result.last_device_node
        print(
            f"req5: token_ids: {req5_token_ids}, matched kv_indices: {kv_indices}, last_node.key: {last_node.key}"
        )
        self.assertEqual(len(kv_indices), 0)  # no swa prefix matched

        # In eagle mode keys are token pairs, so the 7-token request matches
        # 6 paired keys and the leaf holds the pairs (5, 60) and (60, 70).
        req6_token_ids = [1, 2, 3, 4, 5, 60, 70]
        result = tree.match_prefix(RadixKey(req6_token_ids))
        kv_indices, last_node = result.device_indices, result.last_device_node
        print(
            f"req6: token_ids: {req6_token_ids}, matched kv_indices: {kv_indices}, last_node.key: {last_node.key}"
        )
        self.assertEqual(len(kv_indices), 6)
        self.assertEqual(len(last_node.key), 2)
        self.assertEqual(last_node.key.token_ids[0], (5, 60))
        self.assertEqual(last_node.key.token_ids[1], (60, 70))


# Run the SWA test suite when executed as a script.
if __name__ == "__main__":
    unittest.main()