import heapq
import logging
import threading
import time
from typing import List, Optional

import torch

from sglang.srt.managers.cache_controller import HiCacheController
from sglang.srt.mem_cache.memory_pool import (
    MHATokenToKVPool,
    MHATokenToKVPoolHost,
    MLATokenToKVPool,
    MLATokenToKVPoolHost,
    ReqToTokenPool,
    TokenToKVPoolAllocator,
)
from sglang.srt.mem_cache.radix_cache import RadixCache, TreeNode

logger = logging.getLogger(__name__)


class HiRadixCache(RadixCache):
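    """Radix cache with a hierarchical (device + host) KV cache backend.

    Device (GPU) KV entries are backed up to a host (CPU) pool through a
    HiCacheController, and evicted prefixes can be loaded back from host
    memory on later cache hits.
    """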

    def __init__(
        self,
        req_to_token_pool: ReqToTokenPool,
        token_to_kv_pool_allocator: TokenToKVPoolAllocator,
        tp_cache_group: torch.distributed.ProcessGroup,
        page_size: int,
        hicache_ratio: float,
    ):
        self.kv_cache = token_to_kv_pool_allocator.get_kvcache()
        if isinstance(self.kv_cache, MHATokenToKVPool):
            self.token_to_kv_pool_host = MHATokenToKVPoolHost(
                self.kv_cache, hicache_ratio, page_size
            )
        elif isinstance(self.kv_cache, MLATokenToKVPool):
            self.token_to_kv_pool_host = MLATokenToKVPoolHost(
                self.kv_cache, hicache_ratio, page_size
            )
        else:
            raise ValueError("HiRadixCache only supports MHA and MLA for now")

        self.tp_group = tp_cache_group

        self.load_cache_event = threading.Event()
        self.cache_controller = HiCacheController(
            token_to_kv_pool_allocator,
            self.token_to_kv_pool_host,
            page_size,
            load_cache_event=self.load_cache_event,
        )

        # record the nodes with ongoing write through
        self.ongoing_write_through = {}
        # record the node segments with ongoing load back
        self.ongoing_load_back = {}
        # todo: dynamically adjust the threshold
        self.write_through_threshold = 1
        self.load_back_threshold = 10
        super().__init__(
            req_to_token_pool, token_to_kv_pool_allocator, page_size, disable=False
        )

    def reset(self):
        TreeNode.counter = 0
        self.cache_controller.reset()
        self.token_to_kv_pool_host.clear()
        super().reset()

    def get_height(self, node: TreeNode):
        height = 0
        while node != self.root_node:
            node = node.parent
            height += 1
        return height

    def write_backup(self, node: TreeNode):
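        # copy this node's device KV to the host pool; if host allocation
        # fails, evict host entries to make room and retry once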
        host_indices = self.cache_controller.write(
            device_indices=node.value,
            node_id=node.id,
        )
        if host_indices is None:
            self.evict_host(len(node.value))
            host_indices = self.cache_controller.write(
                device_indices=node.value,
                node_id=node.id,
            )
        if host_indices is not None:
            node.host_value = host_indices
            self.ongoing_write_through[node.id] = node
            self.inc_lock_ref(node)
        else:
            return None

        return len(host_indices)

    def inc_hit_count(self, node: TreeNode):
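        # under the write_through_selective policy, a node is backed up to
        # host only after it has been hit more than write_through_threshold times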
        if self.cache_controller.write_policy != "write_through_selective":
            return
        node.hit_count += 1
        if node.host_value is None and node.hit_count > self.write_through_threshold:
            self.write_backup(node)
            node.hit_count = 0

    def writing_check(self):
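        # complete only as many write-through acks as every TP rank has
        # observed, so all ranks apply the same lock releases to their trees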
        queue_size = torch.tensor(
            self.cache_controller.ack_write_queue.qsize(), dtype=torch.int
        )
        if torch.distributed.get_world_size(group=self.tp_group) > 1:
            # synchronize TP workers to make the same update to radix cache
            torch.distributed.all_reduce(
                queue_size,
                op=torch.distributed.ReduceOp.MIN,
                group=self.tp_group,
            )
        for _ in range(queue_size.item()):
            ack_id = self.cache_controller.ack_write_queue.get()
            self.dec_lock_ref(self.ongoing_write_through[ack_id])
            del self.ongoing_write_through[ack_id]

    def loading_check(self):
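        # non-blocking poll of finished host-to-device loads: clear the
        # loading flags along the loaded path and release the lock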
        while not self.cache_controller.ack_load_queue.empty():
            try:
                ack_id = self.cache_controller.ack_load_queue.get_nowait()
                start_node, end_node = self.ongoing_load_back[ack_id]
                self.dec_lock_ref(end_node)
                while end_node != start_node:
                    assert end_node.loading
                    end_node.loading = False
                    end_node = end_node.parent
                # clear the reference
                del self.ongoing_load_back[ack_id]
            except Exception:
                break

    def evictable_size(self):
        return self.evictable_size_

    def evict(self, num_tokens: int):
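        # evict device KV from the least-recently-used leaves; nodes without a
        # host backup are handled according to the controller's write policy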
        leaves = self._collect_leaves_device()
        heapq.heapify(leaves)

        num_evicted = 0
        pending_nodes = []
        while num_evicted < num_tokens and len(leaves):
            x = heapq.heappop(leaves)

            if x.lock_ref > 0:
                continue

            if x.host_value is None:
                if self.cache_controller.write_policy == "write_back":
                    num_evicted += self.write_backup(x)
                    # the device value is freed only after the write completes
                    pending_nodes.append(x)
                elif self.cache_controller.write_policy == "write_through_selective":
                    num_evicted += self._evict_write_through_selective(x)
                else:
                    assert (
                        self.cache_controller.write_policy != "write_through"
                    ), "write_through should be inclusive"
                    raise NotImplementedError
            else:
                num_evicted += self._evict_write_through(x)

            for child in x.parent.children.values():
                if child in pending_nodes:
                    continue
                if not child.evicted:
                    break
            else:
                # all children are evicted or no children
                heapq.heappush(leaves, x.parent)

        if self.cache_controller.write_policy == "write_back":
            # blocking till all write back complete
            while len(self.ongoing_write_through) > 0:
                self.writing_check()
                time.sleep(0.1)
            # now that the host copies are synced, free the device copies
            for node in pending_nodes:
                assert node.host_value is not None
                self._evict_write_through(node)

    def _evict_write_through(self, node: TreeNode):
        # evict a node already written to host
        num_evicted = self.cache_controller.evict_device(node.value, node.host_value)
        assert num_evicted > 0
        self.evictable_size_ -= num_evicted
        node.value = None
        return num_evicted

    def _evict_write_through_selective(self, node: TreeNode):
        # evict a node that has not initiated a write through to host
        self.cache_controller.mem_pool_device_allocator.free(node.value)
        num_evicted = len(node.value)
        self._delete_leaf(node)
        return num_evicted

    def evict_host(self, num_tokens: int):
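        # free host backups of already device-evicted leaves and detach them
        # from the tree entirely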
        leaves = self._collect_leaves()
        heapq.heapify(leaves)

        num_evicted = 0
        while num_evicted < num_tokens and len(leaves):
            x = heapq.heappop(leaves)
            if x == self.root_node:
                break
            # only evict the host value of evicted nodes
            if not x.evicted:
                continue

            num_evicted += self.cache_controller.evict_host(x.host_value)

            for k, v in x.parent.children.items():
                if v == x:
                    break
            del x.parent.children[k]

            if len(x.parent.children) == 0 and x.parent.evicted:
                heapq.heappush(leaves, x.parent)

    def load_back(
        self, node: TreeNode, mem_quota: Optional[int] = None
    ) -> Optional[torch.Tensor]:
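        # collect the evicted suffix ending at `node` and load it back from
        # host to device in a single transfer, all or nothing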
        # todo: more loading policies

        last_hit_node = node
        nodes_to_load = []
        while node.evicted:
            assert (
                node.backuped
            ), "No backup available on evicted nodes, should not happen"
            nodes_to_load.insert(0, node)
            node = node.parent
        ancestor_node = node

        # protect the ancestor nodes from eviction
        delta = self.inc_lock_ref(ancestor_node)

        # load it all or not at all
        host_indices = torch.cat([n.host_value for n in nodes_to_load])
        exceeds_quota = (
            mem_quota is not None and len(host_indices) > mem_quota + delta
        )
        if len(host_indices) < self.load_back_threshold or exceeds_quota:
            # skip loading back if the total size is too small or exceeds the memory quota
            self.dec_lock_ref(ancestor_node)
            return None

        device_indices = self.cache_controller.load(
            host_indices=host_indices, node_id=last_hit_node.id
        )
        if device_indices is None:
            self.evict(len(host_indices))
            device_indices = self.cache_controller.load(
                host_indices=host_indices, node_id=last_hit_node.id
            )
        self.dec_lock_ref(ancestor_node)
        if device_indices is None:
            # no sufficient GPU memory to load back KV caches
            return None

        self.ongoing_load_back[last_hit_node.id] = (ancestor_node, last_hit_node)
        offset = 0
        for node in nodes_to_load:
            node.value = device_indices[offset : offset + len(node.host_value)]
            offset += len(node.host_value)
            node.loading = True
        self.evictable_size_ += len(device_indices)
        self.inc_lock_ref(last_hit_node)

        return device_indices

    def init_load_back(
        self,
        last_node: TreeNode,
        prefix_indices: torch.Tensor,
        mem_quota: Optional[int] = None,
    ):
        assert (
            len(prefix_indices) == 0 or prefix_indices.is_cuda
        ), "indices of device KV caches should be on GPU"
        if last_node.evicted:
            loading_values = self.load_back(last_node, mem_quota)
            if loading_values is not None:
                prefix_indices = (
                    loading_values
                    if len(prefix_indices) == 0
                    else torch.cat([prefix_indices, loading_values])
                )
                logger.debug(
                    f"loading back {len(loading_values)} tokens for node {last_node.id}"
                )

            while last_node.evicted:
                last_node = last_node.parent

        return last_node, prefix_indices

    def ready_to_load_cache(self):
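        # wake the cache controller to begin any scheduled host-to-device
        # loads (assumption: the controller's loading thread waits on this event)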
        self.load_cache_event.set()

    def match_prefix(self, key: List[int], include_evicted=False, **kwargs):
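        # same interface as RadixCache.match_prefix; with include_evicted=True
        # it additionally returns the deepest match even if that node's device
        # value has been evicted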
        empty_value = torch.empty((0,), dtype=torch.int64, device=self.device)
        if self.disable or len(key) == 0:
            if include_evicted:
                return empty_value, self.root_node, self.root_node
            else:
                return empty_value, self.root_node

        if self.page_size != 1:
            page_aligned_len = len(key) // self.page_size * self.page_size
            key = key[:page_aligned_len]

        value, last_node = self._match_prefix_helper(self.root_node, key)
        if value:
            value = torch.cat(value)
        else:
            value = empty_value

        last_node_global = last_node
        while last_node.evicted:
            last_node = last_node.parent

        if include_evicted:
            return value, last_node, last_node_global
        else:
            return value, last_node

    def _match_prefix_helper(self, node: TreeNode, key: List):
        node.last_access_time = time.time()
        child_key = self.get_child_key_fn(key)
        value = []

        while len(key) > 0 and child_key in node.children.keys():
            child = node.children[child_key]
            child.last_access_time = time.time()
            prefix_len = self.key_match_fn(child.key, key)
            if prefix_len < len(child.key):
                new_node = self._split_node(child.key, child, prefix_len)
                if not new_node.evicted:
                    value.append(new_node.value)
                node = new_node
                break
            else:
                if not child.evicted:
                    value.append(child.value)
                node = child
                key = key[prefix_len:]

                if len(key):
                    child_key = self.get_child_key_fn(key)

        return value, node

    def _split_node(self, key, child: TreeNode, split_len: int):
        # child node split into new_node -> child
        new_node = TreeNode()
        new_node.children = {self.get_child_key_fn(key[split_len:]): child}
        new_node.parent = child.parent
        new_node.lock_ref = child.lock_ref
        new_node.key = child.key[:split_len]
        new_node.loading = child.loading

        # split value and host value if exists
        if child.evicted:
            new_node.value = None
        else:
            new_node.value = child.value[:split_len]
            child.value = child.value[split_len:]
        if child.host_value is not None:
            new_node.host_value = child.host_value[:split_len]
            child.host_value = child.host_value[split_len:]
        child.parent = new_node
        child.key = child.key[split_len:]
        new_node.parent.children[self.get_child_key_fn(key)] = new_node

        return new_node

    def _insert_helper(self, node: TreeNode, key: List, value):
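        # like RadixCache._insert_helper, but revives evicted nodes with the
        # freshly computed device values and triggers host backups according
        # to the write policy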
        node.last_access_time = time.time()
        if len(key) == 0:
            return 0

        child_key = self.get_child_key_fn(key)
        total_prefix_length = 0

        while len(key) > 0 and child_key in node.children.keys():
            node = node.children[child_key]
            node.last_access_time = time.time()
            prefix_len = self.key_match_fn(node.key, key)

            if prefix_len == len(node.key):
                if node.evicted:
                    # change the reference if the node is evicted
                    # this often happens in the case of KV cache recomputation
                    node.value = value[:prefix_len]
                    self.token_to_kv_pool_host.update_synced(node.host_value)
                    self.evictable_size_ += len(node.value)
                else:
                    self.inc_hit_count(node)
                    total_prefix_length += prefix_len
            else:
                # partial match, split the node
                new_node = self._split_node(node.key, node, prefix_len)
                if new_node.evicted:
                    new_node.value = value[:prefix_len]
                    self.token_to_kv_pool_host.update_synced(new_node.host_value)
                    self.evictable_size_ += len(new_node.value)
                else:
                    self.inc_hit_count(new_node)
                    total_prefix_length += prefix_len
                node = new_node

            key = key[prefix_len:]
            value = value[prefix_len:]

            if len(key):
                child_key = self.get_child_key_fn(key)

        if len(key):
            new_node = TreeNode()
            new_node.parent = node
            new_node.key = key
            new_node.value = value
            node.children[child_key] = new_node
            self.evictable_size_ += len(value)

            if self.cache_controller.write_policy == "write_through":
                self.write_backup(new_node)
        return total_prefix_length

    def _collect_leaves_device(self):
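        # a leaf for device eviction is a non-evicted, non-root node none of
        # whose children still hold device values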
        def is_leaf(node):
            if node.evicted:
                return False
            if node == self.root_node:
                return False
            if len(node.children) == 0:
                return True
            for child in node.children.values():
                if not child.evicted:
                    return False
            return True

        ret_list = []
        stack = [self.root_node]
        while stack:
            cur_node = stack.pop()
            if is_leaf(cur_node):
                ret_list.append(cur_node)
            else:
                for cur_child in cur_node.children.values():
                    if not cur_child.evicted:
                        stack.append(cur_child)
        return ret_list