import heapq
import logging
import threading
import time
from queue import Queue
from typing import List, Optional

import torch

from sglang.srt.managers.cache_controller import HiCacheController, PrefetchOperation
from sglang.srt.mem_cache.allocator import BaseTokenToKVPoolAllocator
from sglang.srt.mem_cache.base_prefix_cache import MatchResult
from sglang.srt.mem_cache.memory_pool import (
    MHATokenToKVPool,
    MLATokenToKVPool,
    ReqToTokenPool,
)
from sglang.srt.mem_cache.memory_pool_host import (
    MHATokenToKVPoolHost,
    MLATokenToKVPoolHost,
)
from sglang.srt.mem_cache.radix_cache import RadixCache, TreeNode

logger = logging.getLogger(__name__)


class HiRadixCache(RadixCache):
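    """A hierarchical radix cache extending RadixCache with a host-memory tier
    and an optional storage tier, coordinated by HiCacheController: device KV
    can be written through or back to host, evicted, loaded back on a hit, or
    prefetched from storage.
    """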

    def __init__(
        self,
        req_to_token_pool: ReqToTokenPool,
        token_to_kv_pool_allocator: BaseTokenToKVPoolAllocator,
        tp_cache_group: torch.distributed.ProcessGroup,
        page_size: int,
        hicache_ratio: float,
        hicache_size: int,
        hicache_write_policy: str,
        hicache_io_backend: str,
        hicache_mem_layout: str,
        hicache_storage_backend: Optional[str] = None,
        hicache_storage_prefetch_policy: Optional[str] = "best_effort",
    ):

        if hicache_io_backend == "direct":
            if hicache_mem_layout == "page_first":
                hicache_mem_layout = "layer_first"
                logger.warning(
                    "Page first layout is not supported with direct IO backend, switching to layer first layout"
                )

        self.kv_cache = token_to_kv_pool_allocator.get_kvcache()
        if isinstance(self.kv_cache, MHATokenToKVPool):
            self.token_to_kv_pool_host = MHATokenToKVPoolHost(
                self.kv_cache,
                hicache_ratio,
                hicache_size,
                page_size,
                hicache_mem_layout,
            )
        elif isinstance(self.kv_cache, MLATokenToKVPool):
            self.token_to_kv_pool_host = MLATokenToKVPoolHost(
                self.kv_cache,
                hicache_ratio,
                hicache_size,
                page_size,
                hicache_mem_layout,
            )
        else:
            raise ValueError(
                "HiRadixCache only supports MHA and MLA KV caches for now"
            )

        self.tp_group = tp_cache_group
        self.tp_world_size = torch.distributed.get_world_size(group=self.tp_group)
        self.enable_storage = hicache_storage_backend is not None
        # todo: customizable storage prefetch threshold
        self.prefetch_threshold = 256

        self.load_cache_event = threading.Event()
        self.cache_controller = HiCacheController(
            token_to_kv_pool_allocator,
            self.token_to_kv_pool_host,
            page_size,
            self.tp_group,
            load_cache_event=self.load_cache_event,
            write_policy=hicache_write_policy,
            io_backend=hicache_io_backend,
            storage_backend=hicache_storage_backend,
            prefetch_threshold=self.prefetch_threshold,
        )

        self.prefetch_stop_policy = hicache_storage_prefetch_policy
        # todo: customizable storage prefetch timeout
        self.prefetch_timeout = 3  # seconds
        logger.info(
            f"HiCache storage prefetch policy: {hicache_storage_prefetch_policy}"
        )

        # record the nodes with ongoing write through
        self.ongoing_write_through = {}
        # record the node segments with ongoing load back
        self.ongoing_load_back = {}
        # record the ongoing prefetch requests
        self.ongoing_prefetch = {}
        self.ongoing_backup = {}
        # todo: dynamically adjust the threshold
        self.write_through_threshold = (
            1 if hicache_write_policy == "write_through" else 3
        )
        self.write_through_threshold_storage = (
            1 if hicache_write_policy == "write_through" else 3
        )
        self.load_back_threshold = 10
        super().__init__(
            req_to_token_pool, token_to_kv_pool_allocator, page_size, disable=False
        )

    def reset(self):
        TreeNode.counter = 0
        self.cache_controller.reset()
        self.token_to_kv_pool_host.clear()
        super().reset()

    def get_height(self, node: TreeNode):
        height = 0
        while node != self.root_node:
            node = node.parent
            height += 1
        return height

    def write_backup(self, node: TreeNode, write_back=False):
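        """Write the node's device KV to the host pool (evicting host entries
        if needed) and record the in-flight write; returns the number of
        tokens queued for backup, or 0 if host allocation failed.
        """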
        host_indices = self.cache_controller.write(
            device_indices=node.value,
            node_id=node.id,
        )
        if host_indices is None:
            self.evict_host(len(node.value))
            host_indices = self.cache_controller.write(
                device_indices=node.value,
                node_id=node.id,
            )
        if host_indices is not None:
            node.host_value = host_indices
            assert len(node.host_value) > 0
            self.ongoing_write_through[node.id] = node
            if not write_back:
                # no need to lock nodes if write back
                self.inc_lock_ref(node)
        else:
            return 0

        return len(host_indices)

    def write_backup_storage(self, node: TreeNode):
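        """Back up a host-resident node to the storage backend and protect its
        host memory until the backup is acknowledged.
        """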
        operation_id = self.cache_controller.write_storage(
            node.host_value, node.key, node.parent.get_last_hash_value()
        )
        self.ongoing_backup[operation_id] = node
        node.protect_host()

    def inc_hit_count(self, node: TreeNode):
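        """Count a hit on the node and trigger write-through to host (and then
        to storage, if enabled) once the corresponding thresholds are reached;
        no-op under the write_back policy.
        """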
        if self.cache_controller.write_policy == "write_back":
            return
        node.hit_count += 1

        if not node.backuped:
            if node.hit_count >= self.write_through_threshold:
                # write to host if the node is not backuped
                self.write_backup(node)
        else:
            if (
                self.enable_storage
                and (not node.backuped_storage)
                and node.hit_count >= self.write_through_threshold_storage
            ):
                # if the node is backuped on host memory but not on storage
                self.write_backup_storage(node)

    def writing_check(self, write_back=False):
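        """Consume acknowledgements of completed host write-throughs; with
        write_back=True, block until all pending writes finish, otherwise
        process the ack count agreed upon across TP workers.
        """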
        if write_back:
            # blocking till all write back complete
            while len(self.ongoing_write_through) > 0:
                ack_id = self.cache_controller.ack_write_queue.get()
                del self.ongoing_write_through[ack_id]
            return
        queue_size = torch.tensor(
            self.cache_controller.ack_write_queue.qsize(), dtype=torch.int
        )
        if self.tp_world_size > 1:
            # synchronize TP workers to make the same update to radix cache
            torch.distributed.all_reduce(
                queue_size,
                op=torch.distributed.ReduceOp.MIN,
                group=self.tp_group,
            )
        for _ in range(queue_size.item()):
            ack_id = self.cache_controller.ack_write_queue.get()
            self.dec_lock_ref(self.ongoing_write_through[ack_id])
            del self.ongoing_write_through[ack_id]

    def loading_check(self):
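        """Consume acknowledgements of completed host-to-device loads and
        clear the loading flags of the affected nodes.
        """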
        while not self.cache_controller.ack_load_queue.empty():
            try:
                ack_id = self.cache_controller.ack_load_queue.get_nowait()
                start_node, end_node = self.ongoing_load_back[ack_id]
                self.dec_lock_ref(end_node)
                while end_node != start_node:
                    assert end_node.loading
                    end_node.loading = False
                    end_node = end_node.parent
                # clear the reference
                del self.ongoing_load_back[ack_id]
            except Exception:
                break

    def evictable_size(self):
        return self.evictable_size_

    def evict(self, num_tokens: int):
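        """Evict device KV from least-recently-used unlocked leaves until
        num_tokens tokens are reclaimed or no candidates remain; nodes without
        a host backup are written back first (write_back policy) or dropped.
        """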
        leaves = self._collect_leaves_device()
        heapq.heapify(leaves)

        num_evicted = 0
        write_back_nodes = []
        while num_evicted < num_tokens and len(leaves):
            x = heapq.heappop(leaves)

            if x.lock_ref > 0:
                continue

            if not x.backuped:
                if self.cache_controller.write_policy == "write_back":
                    # write to host if the node is not backuped
                    num_evicted += self.write_backup(x, write_back=True)
                    write_back_nodes.append(x)
                else:
                    num_evicted += self._evict_regular(x)
            else:
                num_evicted += self._evict_backuped(x)

            for child in x.parent.children.values():
                if child in write_back_nodes:
                    continue
                if not child.evicted:
                    break
            else:
                # all children are evicted or no children
                heapq.heappush(leaves, x.parent)

        if self.cache_controller.write_policy == "write_back":
            self.writing_check(write_back=True)
            for node in write_back_nodes:
                assert node.backuped
                self._evict_backuped(node)

    def _evict_backuped(self, node: TreeNode):
        # evict a node already written to host
        num_evicted = self.cache_controller.evict_device(node.value, node.host_value)
        assert num_evicted > 0
        self.evictable_size_ -= num_evicted
        node.value = None
        return num_evicted

    def _evict_regular(self, node: TreeNode):
        # evict a node that has not initiated a write to host
        self.cache_controller.mem_pool_device_allocator.free(node.value)
        num_evicted = len(node.value)
        self._delete_leaf(node)
        return num_evicted

    def evict_host(self, num_tokens: int):
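        """Evict host KV of already device-evicted, unprotected nodes until
        num_tokens tokens are reclaimed or no candidates remain, detaching the
        nodes from the tree.
        """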
        leaves = self._collect_leaves()
        heapq.heapify(leaves)

        num_evicted = 0
        while num_evicted < num_tokens and len(leaves):
            x = heapq.heappop(leaves)
            if x == self.root_node:
                break
            # only evict the host value of evicted nodes
            if not x.evicted:
                continue

            # node is protected from eviction as it has ongoing prefetch or backup to storage
            if x.host_ref_counter > 0:
                continue

            num_evicted += self.cache_controller.evict_host(x.host_value)

            for k, v in x.parent.children.items():
                if v == x:
                    break
            del x.parent.children[k]

            if len(x.parent.children) == 0 and x.parent.evicted:
                heapq.heappush(leaves, x.parent)

    def load_back(
        self, node: TreeNode, mem_quota: Optional[int] = None
    ) -> Optional[torch.Tensor]:
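        """Load the evicted chain ending at `node` back from host to device in
        one batch; returns the new device indices, or None if loading is
        skipped (too few tokens, quota exceeded, or no GPU memory).
        """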
        # todo: more loading policies

        last_hit_node = node
        nodes_to_load = []
        while node.evicted:
            assert (
                node.backuped
            ), "No backup available on evicted nodes, should not happen"
            nodes_to_load.insert(0, node)
            node = node.parent
        else:
            ancester_node = node

        # protect the ancestor nodes from eviction
        delta = self.inc_lock_ref(ancester_node)

        # load it all or not at all
        host_indices = torch.cat([n.host_value for n in nodes_to_load])
        if len(host_indices) < self.load_back_threshold or (
            len(host_indices) > mem_quota + delta if mem_quota is not None else False
        ):
            # skip loading back if the total size is too small or exceeding the memory quota
            self.dec_lock_ref(ancester_node)
            return None

        device_indices = self.cache_controller.load(
            host_indices=host_indices, node_id=last_hit_node.id
        )
        if device_indices is None:
            self.evict(len(host_indices))
            device_indices = self.cache_controller.load(
                host_indices=host_indices, node_id=last_hit_node.id
            )
        self.dec_lock_ref(ancester_node)
        if device_indices is None:
            # no sufficient GPU memory to load back KV caches
            return None

        self.ongoing_load_back[last_hit_node.id] = (ancester_node, last_hit_node)
        offset = 0
        for node in nodes_to_load:
            node.value = device_indices[offset : offset + len(node.host_value)]
            offset += len(node.host_value)
            node.loading = True
        self.evictable_size_ += len(device_indices)
        self.inc_lock_ref(last_hit_node)

        return device_indices

    def init_load_back(
        self,
        last_node: TreeNode,
        host_hit_length: int,
        mem_quota: Optional[int] = None,
    ):
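        """On a prefix hit, try to load an evicted `last_node` back from host
        memory; returns the loading device indices (or an empty tensor) along
        with the deepest node that is resident on device.
        """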
        _ = host_hit_length  # unused, but kept for compatibility
        if last_node.evicted:
            loading_values = self.load_back(last_node, mem_quota)
            if loading_values is not None:
                logger.debug(
                    f"loading back {len(loading_values)} tokens for node {last_node.id}"
                )
                return loading_values, last_node

            while last_node.evicted:
                last_node = last_node.parent

        return (
            torch.empty((0,), dtype=torch.int64, device=self.device),
            last_node,
        )

    def ready_to_load_host_cache(self):
        producer_index = self.cache_controller.layer_done_counter.next_producer()
        self.load_cache_event.set()
        return producer_index

    def check_hicache_events(self):
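        """Poll cache controller queues: completed write-throughs, completed
        loads, and, when storage is enabled, revoked prefetches and finished
        backups.
        """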
        self.writing_check()
        self.loading_check()
        if self.enable_storage:
            self.check_revoked_prefetch()
            self.check_backup_progress()

    def check_revoked_prefetch(self):
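        """Release bookkeeping for prefetch operations revoked by the cache
        controller, agreeing on the number processed across TP workers.
        """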
        queue_size = torch.tensor(
            self.cache_controller.prefetch_revoke_queue.qsize(), dtype=torch.int
        )
        if self.tp_world_size > 1:
            # synchronize TP workers to make the same update to hiradix cache
            torch.distributed.all_reduce(
                queue_size,
                op=torch.distributed.ReduceOp.MIN,
                group=self.tp_group,
            )
        for _ in range(queue_size.item()):
            req_id = self.cache_controller.prefetch_revoke_queue.get()
            if req_id in self.ongoing_prefetch:
                last_host_node, token_ids, _, _ = self.ongoing_prefetch[req_id]
                last_host_node.release_host()
                del self.ongoing_prefetch[req_id]
                self.cache_controller.prefetch_tokens_occupied -= len(token_ids)
            else:
                # the revoked operation already got terminated
                pass

    def check_backup_progress(self):
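        """Process storage backup acknowledgements: record hash values of
        backed-up nodes (splitting a node on partial success) and release
        their host protection.
        """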
        queue_size = torch.tensor(
            self.cache_controller.ack_backup_queue.qsize(), dtype=torch.int
        )
        if self.tp_world_size > 1:
            # synchronize TP workers to make the same update to hiradix cache
            torch.distributed.all_reduce(
                queue_size,
                op=torch.distributed.ReduceOp.MIN,
                group=self.tp_group,
            )
        for _ in range(queue_size.item()):
            ack_id, hash_value, completed_tokens = (
                self.cache_controller.ack_backup_queue.get()
            )
            host_node = self.ongoing_backup[ack_id]
            if completed_tokens == 0:
                host_node.hash_value = None
            elif completed_tokens < len(host_node.key):
                # backup is only partially successful, split the node
                new_node = self._split_node(host_node.key, host_node, completed_tokens)
                new_node.hash_value = hash_value
            else:
                host_node.hash_value = hash_value
            host_node.release_host()
            del self.ongoing_backup[ack_id]

    def can_terminate_prefetch(self, operation: PrefetchOperation):
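        """Decide whether an ongoing prefetch may be terminated now according
        to the configured stop policy (best_effort, wait_complete, timeout),
        agreeing on the decision across TP workers.
        """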
        can_terminate = True

        if self.prefetch_stop_policy == "best_effort":
            return can_terminate

        completed = (
            operation.completed_tokens == len(operation.hash_value) * self.page_size
        )

        if self.prefetch_stop_policy == "wait_complete":
            can_terminate = completed
        elif self.prefetch_stop_policy == "timeout":
            can_terminate = completed or (
                time.monotonic() - operation.start_time > self.prefetch_timeout
            )
        else:
            # unknown prefetch stop policy, just return True
            return True

        if self.tp_world_size > 1:
            can_terminate = torch.tensor(can_terminate, dtype=torch.int)
            torch.distributed.all_reduce(
                can_terminate,
                op=torch.distributed.ReduceOp.MIN,
                group=self.tp_group,
            )
            can_terminate = bool(can_terminate.item())

        return can_terminate

    def check_prefetch_progress(self, req_id: str) -> bool:
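        """Finalize the prefetch for req_id once it may terminate: insert the
        fetched tokens into the host radix tree, free unused host memory, and
        release the protected ancestor; returns True when no prefetch remains
        ongoing for this request.
        """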
        if req_id not in self.ongoing_prefetch:
            # there is no ongoing prefetch for this request or it has been revoked
            return True

        # todo: more policies for prefetch progress such as timeout
        # the current policy is to prefetch with best effort and terminate when queuing is over
        last_host_node, token_ids, host_indices, operation = self.ongoing_prefetch[
            req_id
        ]

        if operation.host_indices is None:
            # prefetch has not been issued due to insufficient host memory
            return True

        if not self.can_terminate_prefetch(operation):
            return False

        completed_tokens, hash_value = self.cache_controller.terminate_prefetch(
            operation
        )
        logger.debug(f"Prefetch {req_id} completed with {completed_tokens} tokens")

        min_completed_tokens = completed_tokens
        if self.tp_world_size > 1 and self.prefetch_stop_policy != "wait_complete":
            # synchronize TP workers to make the same update to hiradix cache
            completed_tokens_tensor = torch.tensor(
                min_completed_tokens, dtype=torch.int
            )
            torch.distributed.all_reduce(
                completed_tokens_tensor,
                op=torch.distributed.ReduceOp.MIN,
                group=self.tp_group,
            )
            min_completed_tokens = completed_tokens_tensor.item()
        fetched_token_ids = token_ids[:min_completed_tokens]
        written_indices = host_indices[:min_completed_tokens]
        matched_length = self._insert_helper_host(
            last_host_node,
            fetched_token_ids,
            written_indices,
            hash_value[: min_completed_tokens // self.page_size],
        )
        if len(written_indices):
            self.cache_controller.mem_pool_host.update_prefetch(written_indices)

        self.cache_controller.mem_pool_host.free(host_indices[:matched_length])
        self.cache_controller.mem_pool_host.free(
            host_indices[min_completed_tokens:completed_tokens]
        )
        last_host_node.release_host()
        del self.ongoing_prefetch[req_id]
        self.cache_controller.prefetch_tokens_occupied -= len(token_ids)

        return True

    def match_prefix(self, key: List[int], **kwargs):
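        """Match `key` against the radix tree, returning the device-resident
        prefix indices, the deepest device-resident node, the deepest matched
        node overall, and the host-only hit length in between.
        """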
        empty_value = torch.empty((0,), dtype=torch.int64, device=self.device)
        if self.disable or len(key) == 0:
            return MatchResult(
                device_indices=empty_value,
                last_device_node=self.root_node,
                last_host_node=self.root_node,
                host_hit_length=0,
            )

        if self.page_size != 1:
            page_aligned_len = len(key) // self.page_size * self.page_size
            key = key[:page_aligned_len]

        value, last_node = self._match_prefix_helper(self.root_node, key)
        if value:
            value = torch.cat(value)
        else:
            value = empty_value

        host_hit_length = 0
        last_host_node = last_node
        while last_node.evicted:
            host_hit_length += len(last_node.host_value)
            last_node = last_node.parent

        return MatchResult(
            device_indices=value,
            last_device_node=last_node,
            last_host_node=last_host_node,
            host_hit_length=host_hit_length,
        )

    def prefetch_from_storage(
        self,
        req_id: str,
        last_host_node: TreeNode,
        new_input_tokens: List[int],
        last_hash: Optional[str] = None,
    ):
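        """Start prefetching `new_input_tokens` from the storage backend into
        host memory for this request, provided storage is enabled and the
        page-aligned length reaches the prefetch threshold.
        """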
        # align the number of fetching tokens to the page size
        prefetch_length = len(new_input_tokens) - (
            len(new_input_tokens) % self.page_size
        )
        new_input_tokens = new_input_tokens[:prefetch_length]
        if not self.enable_storage or prefetch_length < self.prefetch_threshold:
            return

        last_host_node.protect_host()
        host_indices = self.cache_controller.mem_pool_host.alloc(prefetch_length)
        if host_indices is None:
            self.evict_host(prefetch_length)
            host_indices = self.cache_controller.mem_pool_host.alloc(prefetch_length)
        operation = self.cache_controller.prefetch(
            req_id, host_indices, new_input_tokens, last_hash
        )
        self.ongoing_prefetch[req_id] = (
            last_host_node,
            new_input_tokens,
            host_indices,
            operation,
        )
        self.cache_controller.prefetch_tokens_occupied += len(new_input_tokens)

    def _insert_helper_host(self, node: TreeNode, key: List, host_value, hash_value):
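        """Insert a prefetched key segment below `node` as host-only entries
        (no device value); returns the length of the prefix that was already
        present in the tree.
        """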
        node.last_access_time = time.monotonic()
        if len(key) == 0:
            return 0

        child_key = self.get_child_key_fn(key)

        matched_length = 0
        while len(key) > 0 and child_key in node.children.keys():
            node = node.children[child_key]
            node.last_access_time = time.monotonic()
            prefix_len = self.key_match_fn(node.key, key)
            key = key[prefix_len:]
            host_value = host_value[prefix_len:]
            hash_value = hash_value[prefix_len // self.page_size :]
            matched_length += prefix_len

            if prefix_len < len(node.key):
                new_node = self._split_node(node.key, node, prefix_len)
                node = new_node

            if len(key):
                child_key = self.get_child_key_fn(key)

        if len(key):
            new_node = TreeNode()
            new_node.parent = node
            new_node.key = key
            new_node.value = None
            new_node.host_value = host_value
            new_node.hash_value = hash_value
            node.children[child_key] = new_node
        return matched_length

    def _match_prefix_helper(self, node: TreeNode, key: List):
        node.last_access_time = time.monotonic()
        child_key = self.get_child_key_fn(key)
        value = []

        while len(key) > 0 and child_key in node.children.keys():
            child = node.children[child_key]
            child.last_access_time = time.monotonic()
            prefix_len = self.key_match_fn(child.key, key)
            if prefix_len < len(child.key):
                new_node = self._split_node(child.key, child, prefix_len)
                if not new_node.evicted:
                    value.append(new_node.value)
                node = new_node
                break
            else:
                if not child.evicted:
                    value.append(child.value)
                node = child
                key = key[prefix_len:]

                if len(key):
                    child_key = self.get_child_key_fn(key)

        return value, node

    def _split_node(self, key, child: TreeNode, split_len: int):
        # child node split into new_node -> child
        new_node = TreeNode()
        new_node.children = {self.get_child_key_fn(key[split_len:]): child}
        new_node.parent = child.parent
        new_node.lock_ref = child.lock_ref
        new_node.key = child.key[:split_len]
        new_node.loading = child.loading
        new_node.hit_count = child.hit_count

        # split the value and host value if they exist
        if child.evicted:
            new_node.value = None
        else:
            new_node.value = child.value[:split_len]
            child.value = child.value[split_len:]
        if child.backuped:
            new_node.host_value = child.host_value[:split_len]
            child.host_value = child.host_value[split_len:]

        if child.hash_value:
            new_node.hash_value = child.hash_value[: split_len // self.page_size]
            child.hash_value = child.hash_value[split_len // self.page_size :]
        child.parent = new_node
        child.key = child.key[split_len:]
        new_node.parent.children[self.get_child_key_fn(key)] = new_node
        return new_node

    def _insert_helper(self, node: TreeNode, key: List, value):
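        """Insert `key`/`value` below `node`, re-materializing evicted nodes
        along the way; returns the total length of the prefix that was already
        cached on device.
        """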
        node.last_access_time = time.monotonic()
        if len(key) == 0:
            return 0

        child_key = self.get_child_key_fn(key)
        total_prefix_length = 0

        while len(key) > 0 and child_key in node.children.keys():
            node = node.children[child_key]
            node.last_access_time = time.monotonic()
            prefix_len = self.key_match_fn(node.key, key)

            if prefix_len == len(node.key):
                if node.evicted:
                    # change the reference if the node is evicted
                    # this often happens in the case of KV cache recomputation
                    node.value = value[:prefix_len]
                    self.token_to_kv_pool_host.update_synced(node.host_value)
                    self.evictable_size_ += len(node.value)
                else:
                    self.inc_hit_count(node)
                    total_prefix_length += prefix_len
            else:
                # partial match, split the node
                new_node = self._split_node(node.key, node, prefix_len)
                if new_node.evicted:
                    new_node.value = value[:prefix_len]
                    self.token_to_kv_pool_host.update_synced(new_node.host_value)
                    self.evictable_size_ += len(new_node.value)
                else:
                    self.inc_hit_count(new_node)
                    total_prefix_length += prefix_len
                node = new_node

            key = key[prefix_len:]
            value = value[prefix_len:]

            if len(key):
                child_key = self.get_child_key_fn(key)

        if len(key):
            new_node = TreeNode()
            new_node.parent = node
            new_node.key = key
            new_node.value = value
            node.children[child_key] = new_node
            self.evictable_size_ += len(value)

            if self.cache_controller.write_policy != "write_back":
                self.inc_hit_count(new_node)
        return total_prefix_length

    def _collect_leaves_device(self):
        def is_leaf(node):
            if node.evicted:
                return False
            if node == self.root_node:
                return False
            if len(node.children) == 0:
                return True
            for child in node.children.values():
                if not child.evicted:
                    return False
            return True

        ret_list = []
        stack = [self.root_node]
        while stack:
            cur_node = stack.pop()
            if is_leaf(cur_node):
                ret_list.append(cur_node)
            else:
                for cur_child in cur_node.children.values():
                    if not cur_child.evicted:
                        stack.append(cur_child)
        return ret_list