import heapq
import logging
import threading
import time
from queue import Queue
from typing import List, Optional

import torch

from sglang.srt.managers.cache_controller import HiCacheController, PrefetchOperation
from sglang.srt.mem_cache.allocator import BaseTokenToKVPoolAllocator
from sglang.srt.mem_cache.base_prefix_cache import MatchResult
from sglang.srt.mem_cache.memory_pool import (
    MHATokenToKVPool,
    MLATokenToKVPool,
    ReqToTokenPool,
)
from sglang.srt.mem_cache.memory_pool_host import (
    MHATokenToKVPoolHost,
    MLATokenToKVPoolHost,
)
from sglang.srt.mem_cache.radix_cache import RadixCache, TreeNode

logger = logging.getLogger(__name__)


class HiRadixCache(RadixCache):
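    """Radix cache with a hierarchical KV backing store.

    Device KV indices are managed by the base ``RadixCache``; evicted or
    write-through entries are mirrored into a host (CPU) KV pool, and an
    optional storage backend can be used for prefetching and backup. All
    data movement between the tiers is delegated to ``HiCacheController``.
    """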

    def __init__(
        self,
        req_to_token_pool: ReqToTokenPool,
        token_to_kv_pool_allocator: BaseTokenToKVPoolAllocator,
        tp_cache_group: torch.distributed.ProcessGroup,
        page_size: int,
        hicache_ratio: float,
        hicache_size: int,
        hicache_write_policy: str,
        hicache_io_backend: str,
        hicache_mem_layout: str,
        hicache_storage_backend: Optional[str] = None,
        hicache_storage_prefetch_policy: Optional[str] = "best_effort",
        model_name: Optional[str] = None,
        storage_backend_extra_config: Optional[str] = None,
    ):

        if hicache_io_backend == "direct":
            if hicache_mem_layout == "page_first":
                hicache_mem_layout = "layer_first"
                logger.warning(
                    "Page first layout is not supported with direct IO backend, switching to layer first layout"
                )

        self.kv_cache = token_to_kv_pool_allocator.get_kvcache()
        if isinstance(self.kv_cache, MHATokenToKVPool):
            self.token_to_kv_pool_host = MHATokenToKVPoolHost(
                self.kv_cache,
                hicache_ratio,
                hicache_size,
                page_size,
                hicache_mem_layout,
            )
        elif isinstance(self.kv_cache, MLATokenToKVPool):
            self.token_to_kv_pool_host = MLATokenToKVPoolHost(
                self.kv_cache,
                hicache_ratio,
                hicache_size,
                page_size,
                hicache_mem_layout,
            )
        else:
            raise ValueError("HiRadixCache only supports MHA and MLA KV caches for now")

        self.tp_group = tp_cache_group
        self.tp_world_size = torch.distributed.get_world_size(group=self.tp_group)
        self.enable_storage = hicache_storage_backend is not None
        # todo: customizable storage prefetch threshold and timeout
        self.prefetch_threshold = 256
        self.prefetch_timeout = 3  # seconds
        self.prefetch_stop_policy = hicache_storage_prefetch_policy

        self.load_cache_event = threading.Event()
        self.cache_controller = HiCacheController(
            token_to_kv_pool_allocator,
            self.token_to_kv_pool_host,
            page_size,
            self.tp_group,
            load_cache_event=self.load_cache_event,
            write_policy=hicache_write_policy,
            io_backend=hicache_io_backend,
            storage_backend=hicache_storage_backend,
            prefetch_threshold=self.prefetch_threshold,
            model_name=model_name,
            storage_backend_extra_config=storage_backend_extra_config,
        )

        # record the nodes with ongoing write through
        self.ongoing_write_through = {}
        # record the node segments with ongoing load back
        self.ongoing_load_back = {}
        # record the ongoing prefetch requests
        self.ongoing_prefetch = {}
        self.ongoing_backup = {}
        # todo: dynamically adjust the threshold
        self.write_through_threshold = (
            1 if hicache_write_policy == "write_through" else 2
        )
        self.load_back_threshold = 10
        super().__init__(
            req_to_token_pool, token_to_kv_pool_allocator, page_size, disable=False
        )

    def reset(self):
        TreeNode.counter = 0
        self.cache_controller.reset()
        self.token_to_kv_pool_host.clear()
        super().reset()

    def get_height(self, node: TreeNode):
        height = 0
        while node != self.root_node:
            node = node.parent
            height += 1
        return height

    def clear_storage_backend(self):
        if self.enable_storage:
            self.cache_controller.storage_backend.clear()
            logger.info("Hierarchical cache storage backend cleared successfully!")
            return True
        else:
            logger.warning("Hierarchical cache storage backend is not enabled.")
            return False

    def write_backup(self, node: TreeNode, write_back=False):
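        """Schedule a write-through of ``node``'s device KV entries to the host pool.

        Retries once after evicting host memory if the initial allocation fails.
        Returns the number of tokens scheduled, or 0 if no host memory is available.
        """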
        host_indices = self.cache_controller.write(
            device_indices=node.value,
            node_id=node.id,
        )
        if host_indices is None:
            self.evict_host(len(node.value))
            host_indices = self.cache_controller.write(
                device_indices=node.value,
                node_id=node.id,
            )
        if host_indices is not None:
            node.host_value = host_indices
            assert len(node.host_value) > 0
            self.ongoing_write_through[node.id] = node
            if not write_back:
                # no need to lock nodes if write back
                self.inc_lock_ref(node)
        else:
            return 0

        return len(host_indices)

    def write_backup_storage(self, node: TreeNode):
        operation_id = self.cache_controller.write_storage(
            node.host_value, node.key, node.hash_value
        )
        self.ongoing_backup[operation_id] = node
        node.protect_host()

    def _inc_hit_count(self, node: TreeNode, chunked=False):
        # skip the hit count update for chunked requests
        if self.cache_controller.write_policy == "write_back" or chunked:
            return
        node.hit_count += 1

        if not node.backuped:
            if node.hit_count >= self.write_through_threshold:
                # write to host if the node is not backuped
                self.write_backup(node)

    def writing_check(self, write_back=False):
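        """Consume write-through acknowledgements from the cache controller.

        With ``write_back=True``, block until every ongoing write completes.
        Otherwise, drain only as many acks as all TP ranks agree on (min over
        the group), unlock the acknowledged nodes, and back them up to storage
        when a storage backend is enabled.
        """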
        if write_back:
            # block until all write-backs complete
            while len(self.ongoing_write_through) > 0:
                ack_id = self.cache_controller.ack_write_queue.get()
                del self.ongoing_write_through[ack_id]
            return
        queue_size = torch.tensor(
            self.cache_controller.ack_write_queue.qsize(), dtype=torch.int
        )
        if self.tp_world_size > 1:
            # synchronize TP workers to make the same update to radix cache
            torch.distributed.all_reduce(
                queue_size,
                op=torch.distributed.ReduceOp.MIN,
                group=self.tp_group,
            )
        for _ in range(queue_size.item()):
            ack_id = self.cache_controller.ack_write_queue.get()
            backuped_node = self.ongoing_write_through[ack_id]
            self.dec_lock_ref(backuped_node)
            del self.ongoing_write_through[ack_id]
            if self.enable_storage:
                self.write_backup_storage(backuped_node)

    def loading_check(self):
        while not self.cache_controller.ack_load_queue.empty():
            try:
                ack_id = self.cache_controller.ack_load_queue.get_nowait()
                start_node, end_node = self.ongoing_load_back[ack_id]
                self.dec_lock_ref(end_node)
                while end_node != start_node:
                    assert end_node.loading
                    end_node.loading = False
                    end_node = end_node.parent
                # clear the reference
                del self.ongoing_load_back[ack_id]
            except Exception:
                break

    def evictable_size(self):
        return self.evictable_size_

    def evict(self, num_tokens: int):
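        """Evict at least ``num_tokens`` tokens of KV cache from the device.

        Pops evictable leaves from a heap; nodes already backed up on host are
        simply dropped from the device, while un-backed nodes are either freed
        directly or, under the ``write_back`` policy, written to host first.
        """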
        leaves = self._collect_leaves_device()
        heapq.heapify(leaves)

        num_evicted = 0
        write_back_nodes = []
        while num_evicted < num_tokens and len(leaves):
            x = heapq.heappop(leaves)

            if x.lock_ref > 0:
                continue

            if not x.backuped:
                if self.cache_controller.write_policy == "write_back":
                    # write to host if the node is not backuped
                    num_evicted += self.write_backup(x, write_back=True)
                    write_back_nodes.append(x)
                else:
                    num_evicted += self._evict_regular(x)
            else:
                num_evicted += self._evict_backuped(x)

            for child in x.parent.children.values():
                if child in write_back_nodes:
                    continue
                if not child.evicted:
                    break
            else:
                # all children are evicted or no children
                heapq.heappush(leaves, x.parent)

        if self.cache_controller.write_policy == "write_back":
            self.writing_check(write_back=True)
            for node in write_back_nodes:
                assert node.backuped
                self._evict_backuped(node)

    def _evict_backuped(self, node: TreeNode):
        # evict a node already written to host
        num_evicted = self.cache_controller.evict_device(node.value, node.host_value)
        assert num_evicted > 0
        self.evictable_size_ -= num_evicted
        node.value = None
        return num_evicted

    def _evict_regular(self, node: TreeNode):
        # evict a node that has not been written to host
        self.cache_controller.mem_pool_device_allocator.free(node.value)
        num_evicted = len(node.value)
        self._delete_leaf(node)
        return num_evicted

    def evict_host(self, num_tokens: int):
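        """Evict at least ``num_tokens`` tokens from the host KV pool.

        Only nodes whose device copy is already evicted and that have no
        ongoing prefetch or backup (``host_ref_counter == 0``) are eligible;
        evicted nodes are removed from the radix tree entirely.
        """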
        leaves = self._collect_leaves()
        heapq.heapify(leaves)

        num_evicted = 0
        while num_evicted < num_tokens and len(leaves):
            x = heapq.heappop(leaves)
            if x == self.root_node:
                break
            # only evict the host value of evicted nodes
            if not x.evicted:
                continue

            # node is protected from eviction as it has ongoing prefetch or backup to storage
            if x.host_ref_counter > 0:
                continue

            num_evicted += self.cache_controller.evict_host(x.host_value)

            for k, v in x.parent.children.items():
                if v == x:
                    break
            del x.parent.children[k]

            if len(x.parent.children) == 0 and x.parent.evicted:
                heapq.heappush(leaves, x.parent)

    def load_back(
        self, node: TreeNode, mem_quota: Optional[int] = None
    ) -> Optional[torch.Tensor]:
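        """Load the evicted prefix ending at ``node`` from host back to device.

        Loading is all-or-nothing: it is skipped when the segment is shorter
        than ``load_back_threshold`` or would exceed ``mem_quota``, and device
        eviction is attempted once if the allocation fails. Returns the new
        device indices, or ``None`` if nothing was loaded.
        """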
        # todo: more loading policies

        last_hit_node = node
        nodes_to_load = []
        while node.evicted:
            assert (
                node.backuped
            ), "No backup available on evicted nodes, should not happen"
            nodes_to_load.insert(0, node)
            node = node.parent
        # the deepest non-evicted ancestor of the segment to load
        ancestor_node = node

        # protect the ancestor nodes from eviction
        delta = self.inc_lock_ref(ancestor_node)

        # load it all or not at all
        host_indices = torch.cat([n.host_value for n in nodes_to_load])
        if len(host_indices) < self.load_back_threshold or (
            len(host_indices) > mem_quota + delta if mem_quota is not None else False
        ):
            # skip loading back if the total size is too small or exceeding the memory quota
            self.dec_lock_ref(ancestor_node)
            return None

        device_indices = self.cache_controller.load(
            host_indices=host_indices, node_id=last_hit_node.id
        )
        if device_indices is None:
            self.evict(len(host_indices))
            device_indices = self.cache_controller.load(
                host_indices=host_indices, node_id=last_hit_node.id
            )
        self.dec_lock_ref(ancestor_node)
        if device_indices is None:
            # insufficient GPU memory to load back KV caches
            return None

        self.ongoing_load_back[last_hit_node.id] = (ancestor_node, last_hit_node)
        offset = 0
        for node in nodes_to_load:
            node.value = device_indices[offset : offset + len(node.host_value)]
            offset += len(node.host_value)
            node.loading = True
        self.evictable_size_ += len(device_indices)
        self.inc_lock_ref(last_hit_node)

        return device_indices

    def init_load_back(
        self,
        last_node: TreeNode,
        host_hit_length: int,
        mem_quota: Optional[int] = None,
    ):
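        """Initiate loading of host-only KV for ``last_node`` if it is evicted.

        Returns the device indices being loaded together with ``last_node``,
        or an empty tensor and the deepest non-evicted ancestor when nothing
        is loaded.
        """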
        _ = host_hit_length  # unused, but kept for compatibility
        if last_node.evicted:
            loading_values = self.load_back(last_node, mem_quota)
            if loading_values is not None:
                logger.debug(
                    f"loading back {len(loading_values)} tokens for node {last_node.id}"
                )
                return loading_values, last_node

            while last_node.evicted:
                last_node = last_node.parent

        return (
            torch.empty((0,), dtype=torch.int64, device=self.device),
            last_node,
        )

    def ready_to_load_host_cache(self):
        producer_index = self.cache_controller.layer_done_counter.next_producer()
        self.load_cache_event.set()
        return producer_index

    def check_hicache_events(self):
        self.writing_check()
        self.loading_check()
        if self.enable_storage:
            self.drain_storage_control_queues()

    def drain_storage_control_queues(self):
        """
        Combine prefetch revoke, backup ack, and host mem release checks
        to minimize TP synchronization and Python overhead.
        """
        cc = self.cache_controller

        qsizes = torch.tensor(
            [
                cc.prefetch_revoke_queue.qsize(),
                cc.ack_backup_queue.qsize(),
                cc.host_mem_release_queue.qsize(),
            ],
            dtype=torch.int,
        )
        if self.tp_world_size > 1:
            torch.distributed.all_reduce(
                qsizes, op=torch.distributed.ReduceOp.MIN, group=self.tp_group
            )

        n_revoke, n_backup, n_release = map(int, qsizes.tolist())

        # process prefetch revokes
        for _ in range(n_revoke):
            req_id = cc.prefetch_revoke_queue.get()
            info = self.ongoing_prefetch.pop(req_id, None)
            if info is not None:
                last_host_node, token_ids, _, _ = info
                last_host_node.release_host()
                cc.prefetch_tokens_occupied -= len(token_ids)
            # else: the revoked operation already got terminated, nothing to do

        # process backup acks
        for _ in range(n_backup):
            ack_id = cc.ack_backup_queue.get()
            entry = self.ongoing_backup.pop(ack_id, None)
            if entry is not None:
                entry.release_host()

        # release host memory
        host_indices_list = []
        for _ in range(n_release):
            host_indices_list.append(cc.host_mem_release_queue.get())
        if host_indices_list:
            host_indices = torch.cat(host_indices_list, dim=0)
            cc.mem_pool_host.free(host_indices)

    def can_terminate_prefetch(self, operation: PrefetchOperation):
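        """Decide whether an ongoing storage prefetch may be terminated now.

        ``best_effort`` always allows termination; ``wait_complete`` requires
        every page to be fetched; ``timeout`` also allows termination once
        ``prefetch_timeout`` seconds have elapsed. For the latter two, the
        decision is all-reduced (min) across TP ranks so all workers agree.
        """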
        can_terminate = True

        if self.prefetch_stop_policy == "best_effort":
            return can_terminate

        if len(operation.hash_value) == 0:
            completed = False
        else:
            completed = (
                operation.completed_tokens == len(operation.hash_value) * self.page_size
            )

        if self.prefetch_stop_policy == "wait_complete":
            can_terminate = completed
        elif self.prefetch_stop_policy == "timeout":
            can_terminate = completed or (
                time.monotonic() - operation.start_time > self.prefetch_timeout
            )
        else:
            # unknown prefetch stop policy, just return True
            return True

        if self.tp_world_size > 1:
            can_terminate = torch.tensor(can_terminate, dtype=torch.int)
            torch.distributed.all_reduce(
                can_terminate,
                op=torch.distributed.ReduceOp.MIN,
                group=self.tp_group,
            )
            can_terminate = bool(can_terminate.item())

        return can_terminate

    def check_prefetch_progress(self, req_id: str) -> bool:
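        """Finalize the storage prefetch for ``req_id`` if it is allowed to stop.

        Returns ``True`` when there is nothing left to wait for: no ongoing
        prefetch, or the prefetch was terminated and its completed pages were
        inserted into the host radix tree. Returns ``False`` while the
        prefetch must keep running.
        """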
        if req_id not in self.ongoing_prefetch:
            # there is no ongoing prefetch for this request or it has been revoked
            return True

        # todo: more policies for prefetch progress such as timeout
        # the current policy is to prefetch with best effort and terminate when queuing is over
        last_host_node, token_ids, host_indices, operation = self.ongoing_prefetch[
            req_id
        ]

        if operation.host_indices is None:
            # prefetch has not been issued due to insufficient host memory
            return True

        if not self.can_terminate_prefetch(operation):
            return False

        completed_tokens, hash_value = self.cache_controller.terminate_prefetch(
            operation
        )
        logger.debug(f"Prefetch {req_id} completed with {completed_tokens} tokens")

        min_completed_tokens = completed_tokens
        if self.tp_world_size > 1 and self.prefetch_stop_policy != "wait_complete":
            # synchronize TP workers to make the same update to hiradix cache
            completed_tokens_tensor = torch.tensor(
                min_completed_tokens, dtype=torch.int
            )
            torch.distributed.all_reduce(
                completed_tokens_tensor,
                op=torch.distributed.ReduceOp.MIN,
                group=self.tp_group,
            )
            min_completed_tokens = completed_tokens_tensor.item()
        fetched_token_ids = token_ids[:min_completed_tokens]
        written_indices = host_indices[:min_completed_tokens]
        matched_length = self._insert_helper_host(
            last_host_node,
            fetched_token_ids,
            written_indices,
            hash_value[: min_completed_tokens // self.page_size],
        )
        if len(written_indices):
            self.cache_controller.mem_pool_host.update_prefetch(written_indices)

        self.cache_controller.mem_pool_host.free(host_indices[:matched_length])
        self.cache_controller.append_host_mem_release(
            host_indices[min_completed_tokens:completed_tokens]
        )
        last_host_node.release_host()
        del self.ongoing_prefetch[req_id]
        self.cache_controller.prefetch_tokens_occupied -= len(token_ids)

        return True

    def match_prefix(self, key: List[int], **kwargs):
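        """Match ``key`` against the radix tree across device and host tiers.

        Returns a ``MatchResult`` carrying the device-resident prefix indices,
        the deepest node whose KV is still on device, the deepest node backed
        up on host, and the number of matched tokens that live only on host.
        """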
        empty_value = torch.empty((0,), dtype=torch.int64, device=self.device)
        if self.disable or len(key) == 0:
            return MatchResult(
                device_indices=empty_value,
                last_device_node=self.root_node,
                last_host_node=self.root_node,
                host_hit_length=0,
            )

        if self.page_size != 1:
            page_aligned_len = len(key) // self.page_size * self.page_size
            key = key[:page_aligned_len]

        value, last_node = self._match_prefix_helper(self.root_node, key)
        if value:
            value = torch.cat(value)
        else:
            value = empty_value

        host_hit_length = 0
        last_host_node = last_node
        while last_node.evicted:
            host_hit_length += len(last_node.host_value)
            last_node = last_node.parent
        while not last_host_node.backuped:
            last_host_node = last_host_node.parent

        return MatchResult(
            device_indices=value,
            last_device_node=last_node,
            last_host_node=last_host_node,
            host_hit_length=host_hit_length,
        )

    def prefetch_from_storage(
        self,
        req_id: str,
        last_host_node: TreeNode,
        new_input_tokens: List[int],
        last_hash: Optional[str] = None,
    ):
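        """Start prefetching KV for ``new_input_tokens`` from the storage backend.

        Skipped when storage is disabled, the page-aligned length is below
        ``prefetch_threshold``, the prefetcher is rate limited, or no host
        memory can be allocated even after host eviction.
        """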
        # align the number of tokens to fetch to the page size
        prefetch_length = len(new_input_tokens) - (
            len(new_input_tokens) % self.page_size
        )
        new_input_tokens = new_input_tokens[:prefetch_length]
        if (
            not self.enable_storage
            or prefetch_length < self.prefetch_threshold
            or self.cache_controller.prefetch_rate_limited()
        ):
            return

        last_host_node.protect_host()
        host_indices = self.cache_controller.mem_pool_host.alloc(prefetch_length)
        if host_indices is None:
            self.evict_host(prefetch_length)
            host_indices = self.cache_controller.mem_pool_host.alloc(prefetch_length)
        if host_indices is None:
            last_host_node.release_host()
            # insufficient host memory for prefetch
            return
        operation = self.cache_controller.prefetch(
            req_id, host_indices, new_input_tokens, last_hash
        )
        self.ongoing_prefetch[req_id] = (
            last_host_node,
            new_input_tokens,
            host_indices,
            operation,
        )
        self.cache_controller.prefetch_tokens_occupied += len(new_input_tokens)

    def _insert_helper_host(self, node: TreeNode, key: List, host_value, hash_value):
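        """Insert a host-only (device-evicted) segment under ``node``.

        Walks existing children, splitting nodes where necessary, and returns
        the number of leading tokens already present in the tree; the
        remaining suffix is attached as a new node carrying only host indices
        and page hashes.
        """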
        node.last_access_time = time.monotonic()
        if len(key) == 0:
            return 0

        child_key = self.get_child_key_fn(key)

        matched_length = 0
        while len(key) > 0 and child_key in node.children.keys():
            node = node.children[child_key]
            node.last_access_time = time.monotonic()
            prefix_len = self.key_match_fn(node.key, key)
            key = key[prefix_len:]
            host_value = host_value[prefix_len:]
            hash_value = hash_value[prefix_len // self.page_size :]
            matched_length += prefix_len

            if prefix_len < len(node.key):
                new_node = self._split_node(node.key, node, prefix_len)
                node = new_node

            if len(key):
                child_key = self.get_child_key_fn(key)

        if len(key):
            new_node = TreeNode()
            new_node.parent = node
            new_node.key = key
            new_node.value = None
            new_node.host_value = host_value
            new_node.hash_value = hash_value
            node.children[child_key] = new_node
        return matched_length

    def _match_prefix_helper(self, node: TreeNode, key: List):
        node.last_access_time = time.monotonic()
        child_key = self.get_child_key_fn(key)
        value = []

        while len(key) > 0 and child_key in node.children.keys():
            child = node.children[child_key]
            child.last_access_time = time.monotonic()
            prefix_len = self.key_match_fn(child.key, key)
            if prefix_len < len(child.key):
                new_node = self._split_node(child.key, child, prefix_len)
                if not new_node.evicted:
                    value.append(new_node.value)
                node = new_node
                break
            else:
                if not child.evicted:
                    value.append(child.value)
                node = child
                key = key[prefix_len:]

                if len(key):
                    child_key = self.get_child_key_fn(key)

        return value, node

    def _split_node(self, key, child: TreeNode, split_len: int):
        # child node split into new_node -> child
        new_node = TreeNode()
        new_node.children = {self.get_child_key_fn(key[split_len:]): child}
        new_node.parent = child.parent
        new_node.lock_ref = child.lock_ref
        new_node.key = child.key[:split_len]
        new_node.loading = child.loading
        new_node.hit_count = child.hit_count

        # split value and host value if they exist
        if child.evicted:
            new_node.value = None
        else:
            new_node.value = child.value[:split_len]
            child.value = child.value[split_len:]
        if child.backuped:
            new_node.host_value = child.host_value[:split_len]
            child.host_value = child.host_value[split_len:]

        if child.hash_value:
            new_node.hash_value = child.hash_value[: split_len // self.page_size]
            child.hash_value = child.hash_value[split_len // self.page_size :]
        child.parent = new_node
        child.key = child.key[split_len:]
        new_node.parent.children[self.get_child_key_fn(key)] = new_node
        return new_node

    def insert(self, key: List, value, chunked=False):
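        """Insert device KV indices for ``key`` into the radix tree.

        Re-attaches device values to evicted nodes, bumps hit counts that
        drive host write-through, computes per-page hashes for new nodes when
        a storage backend is enabled, and returns the length of the prefix
        that already held device values.
        """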
        if len(key) == 0:
            return 0

        node = self.root_node
        child_key = self.get_child_key_fn(key)
        total_prefix_length = 0

        while len(key) > 0 and child_key in node.children.keys():
            node = node.children[child_key]
            node.last_access_time = time.monotonic()
            prefix_len = self.key_match_fn(node.key, key)

            if prefix_len == len(node.key):
                if node.evicted:
                    # change the reference if the node is evicted
                    # this often happens in the case of KV cache recomputation
                    node.value = value[:prefix_len]
                    self.token_to_kv_pool_host.update_synced(node.host_value)
                    self.evictable_size_ += len(node.value)
                else:
                    self._inc_hit_count(node, chunked)
                    total_prefix_length += prefix_len
            else:
                # partial match, split the node
                new_node = self._split_node(node.key, node, prefix_len)
                if new_node.evicted:
                    new_node.value = value[:prefix_len]
                    self.token_to_kv_pool_host.update_synced(new_node.host_value)
                    self.evictable_size_ += len(new_node.value)
                else:
                    self._inc_hit_count(new_node, chunked)
                    total_prefix_length += prefix_len
                node = new_node

            key = key[prefix_len:]
            value = value[prefix_len:]

            if len(key):
                child_key = self.get_child_key_fn(key)

        if len(key):
            new_node = TreeNode()
            new_node.parent = node
            new_node.key = key
            new_node.value = value
            node.children[child_key] = new_node
            self.evictable_size_ += len(value)

            if self.enable_storage:
                last_hash = node.get_last_hash_value()
                assert (node == self.root_node) or (
                    last_hash is not None
                ), "Parent node must have a hash value with storage enabled"
                new_node.hash_value = []
                for idx in range(0, len(key), self.page_size):
                    new_node.hash_value.append(
                        self.cache_controller.get_hash_str(
                            key[idx : idx + self.page_size],
                            prior_hash=last_hash,
                        )
                    )
                    last_hash = new_node.hash_value[-1]

            if self.cache_controller.write_policy != "write_back":
                self._inc_hit_count(new_node, chunked)
        return total_prefix_length

    def _collect_leaves_device(self):
        def is_leaf(node):
            if node.evicted:
                return False
            if node == self.root_node:
                return False
            if len(node.children) == 0:
                return True
            for child in node.children.values():
                if not child.evicted:
                    return False
            return True

        ret_list = []
        stack = [self.root_node]
        while stack:
            cur_node = stack.pop()
            if is_leaf(cur_node):
                ret_list.append(cur_node)
            else:
                for cur_child in cur_node.children.values():
                    if not cur_child.evicted:
                        stack.append(cur_child)
        return ret_list

    def release_aborted_request(self, rid: str):
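        """Terminate and clean up the ongoing storage prefetch of an aborted request."""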
        if rid not in self.ongoing_prefetch:
            return

        last_host_node, token_ids, host_indices, operation = self.ongoing_prefetch[rid]
        if operation.host_indices is None:
            return

        completed_tokens, _ = self.cache_controller.terminate_prefetch(operation)
        if self.tp_world_size > 1:
            torch.distributed.barrier(group=self.tp_group)
        last_host_node.release_host()
        del self.ongoing_prefetch[rid]
        self.cache_controller.append_host_mem_release(host_indices[:completed_tokens])
        self.cache_controller.prefetch_tokens_occupied -= len(token_ids)