import heapq
import logging
import threading
import time
from queue import Queue
from typing import List, Optional

import torch

from sglang.srt.managers.cache_controller import HiCacheController, PrefetchOperation
from sglang.srt.mem_cache.allocator import BaseTokenToKVPoolAllocator
from sglang.srt.mem_cache.base_prefix_cache import MatchResult
from sglang.srt.mem_cache.memory_pool import (
    MHATokenToKVPool,
    MLATokenToKVPool,
    ReqToTokenPool,
)
from sglang.srt.mem_cache.memory_pool_host import (
    MHATokenToKVPoolHost,
    MLATokenToKVPoolHost,
)
from sglang.srt.mem_cache.radix_cache import RadixCache, TreeNode

logger = logging.getLogger(__name__)


class HiRadixCache(RadixCache):
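    """Radix tree KV cache with a hierarchical backing store.

    Extends RadixCache with a host (CPU) memory pool and, optionally, an
    external storage backend. Device KV entries are written through (or back)
    to host memory via a HiCacheController, evicted device nodes can be
    loaded back on a prefix hit, and prefixes can be prefetched from the
    storage backend ahead of prefill.
    """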

    def __init__(
        self,
        req_to_token_pool: ReqToTokenPool,
        token_to_kv_pool_allocator: BaseTokenToKVPoolAllocator,
        tp_cache_group: torch.distributed.ProcessGroup,
        page_size: int,
        hicache_ratio: float,
        hicache_size: int,
        hicache_write_policy: str,
        hicache_io_backend: str,
        hicache_mem_layout: str,
        hicache_storage_backend: Optional[str] = None,
        hicache_storage_prefetch_policy: Optional[str] = "best_effort",
        model_name: Optional[str] = None,
        storage_backend_extra_config: Optional[str] = None,
    ):

        if hicache_io_backend == "direct":
            if hicache_mem_layout == "page_first":
                hicache_mem_layout = "layer_first"
                logger.warning(
                    "Page first layout is not supported with direct IO backend, switching to layer first layout"
                )

        self.kv_cache = token_to_kv_pool_allocator.get_kvcache()
        if isinstance(self.kv_cache, MHATokenToKVPool):
            self.token_to_kv_pool_host = MHATokenToKVPoolHost(
                self.kv_cache,
                hicache_ratio,
                hicache_size,
                page_size,
                hicache_mem_layout,
            )
        elif isinstance(self.kv_cache, MLATokenToKVPool):
            self.token_to_kv_pool_host = MLATokenToKVPoolHost(
                self.kv_cache,
                hicache_ratio,
                hicache_size,
                page_size,
                hicache_mem_layout,
            )
        else:
            raise ValueError("HiRadixCache only supports MHA and MLA KV caches")

        self.tp_group = tp_cache_group
        self.tp_world_size = torch.distributed.get_world_size(group=self.tp_group)
        self.enable_storage = hicache_storage_backend is not None
        # todo: customizable storage prefetch threshold and timeout
        self.prefetch_threshold = 256
        self.prefetch_timeout = 3  # seconds
        self.prefetch_stop_policy = hicache_storage_prefetch_policy

        self.load_cache_event = threading.Event()
        self.cache_controller = HiCacheController(
            token_to_kv_pool_allocator,
            self.token_to_kv_pool_host,
            page_size,
            self.tp_group,
            load_cache_event=self.load_cache_event,
            write_policy=hicache_write_policy,
            io_backend=hicache_io_backend,
            storage_backend=hicache_storage_backend,
            prefetch_threshold=self.prefetch_threshold,
            model_name=model_name,
            storage_backend_extra_config=storage_backend_extra_config,
        )

        # record the nodes with ongoing write through
        self.ongoing_write_through = {}
        # record the node segments with ongoing load back
        self.ongoing_load_back = {}
        # record the ongoing prefetch requests
        self.ongoing_prefetch = {}
        self.ongoing_backup = {}
        # todo: dynamically adjust the threshold
        self.write_through_threshold = (
            1 if hicache_write_policy == "write_through" else 3
        )
        self.write_through_threshold_storage = (
            1 if hicache_write_policy == "write_through" else 3
        )
        self.load_back_threshold = 10
        super().__init__(
            req_to_token_pool, token_to_kv_pool_allocator, page_size, disable=False
        )

    def reset(self):
        TreeNode.counter = 0
        self.cache_controller.reset()
        self.token_to_kv_pool_host.clear()
        super().reset()

    def get_height(self, node: TreeNode):
        height = 0
        while node != self.root_node:
            node = node.parent
            height += 1
        return height

    def write_backup(self, node: TreeNode, write_back=False):
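        """Write a node's device KV indices through to the host memory pool.

        Evicts host entries and retries once if the first allocation fails;
        returns the number of tokens written, or 0 if host memory is still
        insufficient.
        """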
        host_indices = self.cache_controller.write(
            device_indices=node.value,
            node_id=node.id,
        )
        if host_indices is None:
            self.evict_host(len(node.value))
            host_indices = self.cache_controller.write(
                device_indices=node.value,
                node_id=node.id,
            )
        if host_indices is not None:
            node.host_value = host_indices
            assert len(node.host_value) > 0
            self.ongoing_write_through[node.id] = node
            if not write_back:
                # no need to lock nodes if write back
                self.inc_lock_ref(node)
        else:
            return 0

        return len(host_indices)

    def write_backup_storage(self, node: TreeNode):
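        """Queue an asynchronous backup of a host-backed node to the storage backend."""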
        operation_id = self.cache_controller.write_storage(
            node.host_value, node.key, node.hash_value
        )
        self.ongoing_backup[operation_id] = node
        node.protect_host()

    def inc_hit_count(self, node: TreeNode):
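        """Count a hit on a node (no-op under the write_back policy) and trigger
        write-through to host, then to storage, once the thresholds are met."""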
        if self.cache_controller.write_policy == "write_back":
            return
        node.hit_count += 1

        if not node.backuped:
            if node.hit_count >= self.write_through_threshold:
                # write to host if the node is not backuped
                self.write_backup(node)
        else:
            if (
                self.enable_storage
                and (not node.backuped_storage)
                and node.hit_count >= self.write_through_threshold_storage
            ):
                # if the node is backuped on host memory but not on storage
                self.write_backup_storage(node)

    def writing_check(self, write_back=False):
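        """Consume write-through acknowledgements and release node locks.

        With write_back=True, block until every ongoing write completes;
        otherwise only drain the number of acks all TP workers agree on.
        """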
        if write_back:
            # blocking till all write back complete
            while len(self.ongoing_write_through) > 0:
                ack_id = self.cache_controller.ack_write_queue.get()
                del self.ongoing_write_through[ack_id]
            return
        queue_size = torch.tensor(
            self.cache_controller.ack_write_queue.qsize(), dtype=torch.int
        )
        if self.tp_world_size > 1:
            # synchronize TP workers to make the same update to radix cache
            torch.distributed.all_reduce(
                queue_size,
                op=torch.distributed.ReduceOp.MIN,
                group=self.tp_group,
            )
        for _ in range(queue_size.item()):
            ack_id = self.cache_controller.ack_write_queue.get()
            self.dec_lock_ref(self.ongoing_write_through[ack_id])
            del self.ongoing_write_through[ack_id]

    def loading_check(self):
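        """Consume load-back acknowledgements, unlocking nodes and clearing
        their loading flags once their KV data has arrived on device."""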
        while not self.cache_controller.ack_load_queue.empty():
            try:
                ack_id = self.cache_controller.ack_load_queue.get_nowait()
                start_node, end_node = self.ongoing_load_back[ack_id]
                self.dec_lock_ref(end_node)
                while end_node != start_node:
                    assert end_node.loading
                    end_node.loading = False
                    end_node = end_node.parent
                # clear the reference
                del self.ongoing_load_back[ack_id]
            except Exception:
                break

    def evictable_size(self):
        return self.evictable_size_

    def evict(self, num_tokens: int):
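        """Evict device KV cache from leaf nodes until num_tokens are freed.

        Nodes already backed up on host are released in place; nodes without
        a host copy are freed directly, except under the write_back policy,
        where they are first written to host and freed once the write lands.
        """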
        leaves = self._collect_leaves_device()
        heapq.heapify(leaves)

        num_evicted = 0
        write_back_nodes = []
        while num_evicted < num_tokens and len(leaves):
            x = heapq.heappop(leaves)

            if x.lock_ref > 0:
                continue

            if not x.backuped:
                if self.cache_controller.write_policy == "write_back":
                    # write to host if the node is not backuped
                    num_evicted += self.write_backup(x, write_back=True)
                    write_back_nodes.append(x)
                else:
                    num_evicted += self._evict_regular(x)
            else:
                num_evicted += self._evict_backuped(x)

            for child in x.parent.children.values():
                if child in write_back_nodes:
                    continue
                if not child.evicted:
                    break
            else:
                # all children are evicted or no children
                heapq.heappush(leaves, x.parent)

        if self.cache_controller.write_policy == "write_back":
            self.writing_check(write_back=True)
            for node in write_back_nodes:
                assert node.backuped
                self._evict_backuped(node)

    def _evict_backuped(self, node: TreeNode):
        # evict a node already written to host
        num_evicted = self.cache_controller.evict_device(node.value, node.host_value)
        assert num_evicted > 0
        self.evictable_size_ -= num_evicted
        node.value = None
        return num_evicted

    def _evict_regular(self, node: TreeNode):
        # evict a node whose host write was never initiated
        self.cache_controller.mem_pool_device_allocator.free(node.value)
        num_evicted = len(node.value)
        self._delete_leaf(node)
        return num_evicted

    def evict_host(self, num_tokens: int):
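        """Free host KV cache until num_tokens are released, considering only
        leaves whose device copy is already evicted and that have no ongoing
        prefetch or storage backup (host_ref_counter == 0)."""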
        leaves = self._collect_leaves()
        heapq.heapify(leaves)

        num_evicted = 0
        while num_evicted < num_tokens and len(leaves):
            x = heapq.heappop(leaves)
            if x == self.root_node:
                break
            # only evict the host value of evicted nodes
            if not x.evicted:
                continue

            # node is protected from eviction as it has ongoing prefetch or backup to storage
            if x.host_ref_counter > 0:
                continue

            num_evicted += self.cache_controller.evict_host(x.host_value)

            for k, v in x.parent.children.items():
                if v == x:
                    break
            del x.parent.children[k]

            if len(x.parent.children) == 0 and x.parent.evicted:
                heapq.heappush(leaves, x.parent)

    def load_back(
        self, node: TreeNode, mem_quota: Optional[int] = None
    ) -> Optional[torch.Tensor]:
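        """Load an evicted prefix from host memory back into device memory.

        Walks up from `node` to its first non-evicted ancestor, then loads the
        whole chain or nothing: loading is skipped when the chain is shorter
        than load_back_threshold, exceeds the memory quota, or device memory
        cannot be freed. Returns the device indices on success, otherwise None.
        """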
        # todo: more loading policies

        last_hit_node = node
        nodes_to_load = []
        while node.evicted:
            assert (
                node.backuped
            ), "No backup available on evicted nodes, should not happen"
            nodes_to_load.insert(0, node)
            node = node.parent
        else:
            ancester_node = node

        # protect the ancestor nodes from eviction
        delta = self.inc_lock_ref(ancester_node)

        # load it all or not at all
        host_indices = torch.cat([n.host_value for n in nodes_to_load])
        if len(host_indices) < self.load_back_threshold or (
            len(host_indices) > mem_quota + delta if mem_quota is not None else False
        ):
            # skip loading back if the total size is too small or exceeding the memory quota
            self.dec_lock_ref(ancester_node)
            return None

        device_indices = self.cache_controller.load(
            host_indices=host_indices, node_id=last_hit_node.id
        )
        if device_indices is None:
            self.evict(len(host_indices))
            device_indices = self.cache_controller.load(
                host_indices=host_indices, node_id=last_hit_node.id
            )
        self.dec_lock_ref(ancester_node)
        if device_indices is None:
            # no sufficient GPU memory to load back KV caches
            return None

        self.ongoing_load_back[last_hit_node.id] = (ancester_node, last_hit_node)
        offset = 0
        for node in nodes_to_load:
            node.value = device_indices[offset : offset + len(node.host_value)]
            offset += len(node.host_value)
            node.loading = True
        self.evictable_size_ += len(device_indices)
        self.inc_lock_ref(last_hit_node)

        return device_indices

    def init_load_back(
        self,
        last_node: TreeNode,
        host_hit_length: int,
        mem_quota: Optional[int] = None,
    ):
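        """Start loading an evicted `last_node` back to device memory.

        Returns (device_indices, last_node) when a load was issued; otherwise
        returns an empty tensor together with the deepest non-evicted ancestor.
        """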
        _ = host_hit_length  # unused, but kept for compatibility
        if last_node.evicted:
            loading_values = self.load_back(last_node, mem_quota)
            if loading_values is not None:
                logger.debug(
                    f"loading back {len(loading_values)} tokens for node {last_node.id}"
                )
                return loading_values, last_node

            while last_node.evicted:
                last_node = last_node.parent

        return (
            torch.empty((0,), dtype=torch.int64, device=self.device),
            last_node,
        )

    def ready_to_load_host_cache(self):
        producer_index = self.cache_controller.layer_done_counter.next_producer()
        self.load_cache_event.set()
        return producer_index

    def check_hicache_events(self):
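        """Poll controller queues for write/load acks and, when storage is
        enabled, for revoked prefetches and completed backups."""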
        self.writing_check()
        self.loading_check()
        if self.enable_storage:
            self.check_revoked_prefetch()
            self.check_backup_progress()

    def check_revoked_prefetch(self):
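        """Drop bookkeeping for prefetch operations the controller has revoked,
        releasing the host node and the occupied-token budget."""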
        queue_size = torch.tensor(
            self.cache_controller.prefetch_revoke_queue.qsize(), dtype=torch.int
        )
        if self.tp_world_size > 1:
            # synchronize TP workers to make the same update to hiradix cache
            torch.distributed.all_reduce(
                queue_size,
                op=torch.distributed.ReduceOp.MIN,
                group=self.tp_group,
            )
        for _ in range(queue_size.item()):
            req_id = self.cache_controller.prefetch_revoke_queue.get()
            if req_id in self.ongoing_prefetch:
                last_host_node, token_ids, _, _ = self.ongoing_prefetch[req_id]
                last_host_node.release_host()
                del self.ongoing_prefetch[req_id]
                self.cache_controller.prefetch_tokens_occupied -= len(token_ids)
            else:
                # the revoked operation already got terminated
                pass

    def check_backup_progress(self):
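        """Process storage-backup acks: mark fully or partially backed-up nodes
        (splitting on partial success) and release their host protection."""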
        queue_size = torch.tensor(
            self.cache_controller.ack_backup_queue.qsize(), dtype=torch.int
        )
        if self.tp_world_size > 1:
            # synchronize TP workers to make the same update to hiradix cache
            torch.distributed.all_reduce(
                queue_size,
                op=torch.distributed.ReduceOp.MIN,
                group=self.tp_group,
            )
        for _ in range(queue_size.item()):
            ack_id, completed_tokens = self.cache_controller.ack_backup_queue.get()
            host_node = self.ongoing_backup[ack_id]

            if completed_tokens > 0:
                if completed_tokens < len(host_node.key):
                    # backup is only partially successful, split the node
                    new_node = self._split_node(
                        host_node.key, host_node, completed_tokens
                    )
                    new_node.backuped_storage = True
                else:
                    host_node.backuped_storage = True
            host_node.release_host()
            del self.ongoing_backup[ack_id]

    def can_terminate_prefetch(self, operation: PrefetchOperation):
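        """Decide whether an ongoing prefetch may be terminated now.

        best_effort terminates immediately, wait_complete requires all pages
        fetched, and timeout allows termination after prefetch_timeout; the
        decision is reduced with MIN across TP workers so they stay in sync.
        """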
        can_terminate = True

        if self.prefetch_stop_policy == "best_effort":
            return can_terminate

        completed = (
            operation.completed_tokens == len(operation.hash_value) * self.page_size
        )

        if self.prefetch_stop_policy == "wait_complete":
            can_terminate = completed
        elif self.prefetch_stop_policy == "timeout":
            can_terminate = completed or (
                time.monotonic() - operation.start_time > self.prefetch_timeout
            )
        else:
            # unknown prefetch stop policy, just return True
            return True

        if self.tp_world_size > 1:
            can_terminate = torch.tensor(can_terminate, dtype=torch.int)
            torch.distributed.all_reduce(
                can_terminate,
                op=torch.distributed.ReduceOp.MIN,
                group=self.tp_group,
            )
            can_terminate = bool(can_terminate.item())

        return can_terminate

    def check_prefetch_progress(self, req_id: str) -> bool:
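        """Finalize the prefetch for req_id if it may terminate.

        Inserts the fetched prefix into the host radix tree, frees host
        indices that overlap an existing prefix or exceed the TP-agreed
        completed length, and returns True once the request has no
        outstanding prefetch (False while it must keep waiting).
        """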
        if req_id not in self.ongoing_prefetch:
            # there is no ongoing prefetch for this request or it has been revoked
            return True

        # todo: more policies for prefetch progress such as timeout
        # the current policy is to prefetch with best effort and terminate when queuing is over
        last_host_node, token_ids, host_indices, operation = self.ongoing_prefetch[
            req_id
        ]

        if operation.host_indices is None:
            # prefetch has not been issued due to insufficient host memory
            return True

        if not self.can_terminate_prefetch(operation):
            return False

        completed_tokens, hash_value = self.cache_controller.terminate_prefetch(
            operation
        )
        logger.debug(f"Prefetch {req_id} completed with {completed_tokens} tokens")

        min_completed_tokens = completed_tokens
        if self.tp_world_size > 1 and self.prefetch_stop_policy != "wait_complete":
            # synchronize TP workers to make the same update to hiradix cache
            completed_tokens_tensor = torch.tensor(
                min_completed_tokens, dtype=torch.int
            )
            torch.distributed.all_reduce(
                completed_tokens_tensor,
                op=torch.distributed.ReduceOp.MIN,
                group=self.tp_group,
            )
            min_completed_tokens = completed_tokens_tensor.item()
        fetched_token_ids = token_ids[:min_completed_tokens]
        written_indices = host_indices[:min_completed_tokens]
        matched_length = self._insert_helper_host(
            last_host_node,
            fetched_token_ids,
            written_indices,
            hash_value[: min_completed_tokens // self.page_size],
        )
        if len(written_indices):
            self.cache_controller.mem_pool_host.update_prefetch(written_indices)

        self.cache_controller.mem_pool_host.free(host_indices[:matched_length])
        self.cache_controller.mem_pool_host.free(
            host_indices[min_completed_tokens:completed_tokens]
        )
        last_host_node.release_host()
        del self.ongoing_prefetch[req_id]
        self.cache_controller.prefetch_tokens_occupied -= len(token_ids)

        return True

    def match_prefix(self, key: List[int], **kwargs):
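        """Match `key` against the device radix tree and report host-only hits.

        Returns a MatchResult whose host_hit_length counts the tokens that are
        only present in host memory beyond the last device-resident node.
        """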
        empty_value = torch.empty((0,), dtype=torch.int64, device=self.device)
        if self.disable or len(key) == 0:
            return MatchResult(
                device_indices=empty_value,
                last_device_node=self.root_node,
                last_host_node=self.root_node,
                host_hit_length=0,
            )

        if self.page_size != 1:
            page_aligned_len = len(key) // self.page_size * self.page_size
            key = key[:page_aligned_len]

        value, last_node = self._match_prefix_helper(self.root_node, key)
        if value:
            value = torch.cat(value)
        else:
            value = empty_value

        host_hit_length = 0
        last_host_node = last_node
        while last_node.evicted:
            host_hit_length += len(last_node.host_value)
            last_node = last_node.parent
        while not last_host_node.backuped:
            last_host_node = last_host_node.parent

        return MatchResult(
            device_indices=value,
            last_device_node=last_node,
            last_host_node=last_host_node,
            host_hit_length=host_hit_length,
        )

    def prefetch_from_storage(
        self,
        req_id: str,
        last_host_node: TreeNode,
        new_input_tokens: List[int],
        last_hash: Optional[str] = None,
    ):
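        """Start prefetching `new_input_tokens` KV data from the storage backend.

        The request is page-aligned and skipped entirely when storage is
        disabled or the prefix is shorter than prefetch_threshold; otherwise
        host memory is reserved when available and the operation is tracked in
        ongoing_prefetch until check_prefetch_progress finalizes it.
        """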
        # align the number of fetching tokens to the page size
        prefetch_length = len(new_input_tokens) - (
            len(new_input_tokens) % self.page_size
        )
        new_input_tokens = new_input_tokens[:prefetch_length]
        if not self.enable_storage or prefetch_length < self.prefetch_threshold:
            return

        last_host_node.protect_host()
        host_indices = self.cache_controller.mem_pool_host.alloc(prefetch_length)
        if host_indices is None:
            self.evict_host(prefetch_length)
            host_indices = self.cache_controller.mem_pool_host.alloc(prefetch_length)
        operation = self.cache_controller.prefetch(
            req_id, host_indices, new_input_tokens, last_hash
        )
        self.ongoing_prefetch[req_id] = (
            last_host_node,
            new_input_tokens,
            host_indices,
            operation,
        )
        self.cache_controller.prefetch_tokens_occupied += len(new_input_tokens)

    def _insert_helper_host(self, node: TreeNode, key: List, host_value, hash_value):
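        """Insert a prefetched prefix into the tree as host-only nodes.

        Walks existing children to find the already-present prefix and returns
        its length; any remaining suffix becomes a new node holding only
        host_value and hash_value (no device value).
        """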
        node.last_access_time = time.monotonic()
        if len(key) == 0:
            return 0

        child_key = self.get_child_key_fn(key)

        matched_length = 0
        while len(key) > 0 and child_key in node.children.keys():
            node = node.children[child_key]
            node.last_access_time = time.monotonic()
            prefix_len = self.key_match_fn(node.key, key)
            key = key[prefix_len:]
            host_value = host_value[prefix_len:]
            hash_value = hash_value[prefix_len // self.page_size :]
            matched_length += prefix_len

            if prefix_len < len(node.key):
                new_node = self._split_node(node.key, node, prefix_len)
                node = new_node

            if len(key):
                child_key = self.get_child_key_fn(key)

        if len(key):
            new_node = TreeNode()
            new_node.parent = node
            new_node.key = key
            new_node.value = None
            new_node.host_value = host_value
            new_node.hash_value = hash_value
            node.children[child_key] = new_node
        return matched_length

    def _match_prefix_helper(self, node: TreeNode, key: List):
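        """Walk the tree along `key`, collecting device indices of matched,
        non-evicted nodes; returns the collected values and the last node."""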
        node.last_access_time = time.monotonic()
        child_key = self.get_child_key_fn(key)
        value = []

        while len(key) > 0 and child_key in node.children.keys():
            child = node.children[child_key]
            child.last_access_time = time.monotonic()
            prefix_len = self.key_match_fn(child.key, key)
            if prefix_len < len(child.key):
                new_node = self._split_node(child.key, child, prefix_len)
                if not new_node.evicted:
                    value.append(new_node.value)
                node = new_node
                break
            else:
                if not child.evicted:
                    value.append(child.value)
                node = child
                key = key[prefix_len:]

                if len(key):
                    child_key = self.get_child_key_fn(key)

        return value, node

    def _split_node(self, key, child: TreeNode, split_len: int):
        # child node split into new_node -> child
        new_node = TreeNode()
        new_node.children = {self.get_child_key_fn(key[split_len:]): child}
        new_node.parent = child.parent
        new_node.lock_ref = child.lock_ref
        new_node.key = child.key[:split_len]
        new_node.loading = child.loading
        new_node.hit_count = child.hit_count

        # split value and host value if exists
        if child.evicted:
            new_node.value = None
        else:
            new_node.value = child.value[:split_len]
            child.value = child.value[split_len:]
        if child.backuped:
            new_node.host_value = child.host_value[:split_len]
            child.host_value = child.host_value[split_len:]

        if child.hash_value:
            new_node.hash_value = child.hash_value[: split_len // self.page_size]
            child.hash_value = child.hash_value[split_len // self.page_size :]
        child.parent = new_node
        child.key = child.key[split_len:]
        new_node.parent.children[self.get_child_key_fn(key)] = new_node
        return new_node

    def _insert_helper(self, node: TreeNode, key: List, value):
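        """Insert key/value into the tree, reviving evicted nodes in place.

        Evicted nodes on the path get their device value restored, other
        matches count as hits (possibly triggering write-through), and the
        remaining suffix becomes a new node; with storage enabled its page
        hashes are chained from the parent's last hash. Returns the length of
        the prefix that was already cached on device.
        """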
        node.last_access_time = time.monotonic()
        if len(key) == 0:
            return 0

        child_key = self.get_child_key_fn(key)
        total_prefix_length = 0

        while len(key) > 0 and child_key in node.children.keys():
            node = node.children[child_key]
            node.last_access_time = time.monotonic()
            prefix_len = self.key_match_fn(node.key, key)

            if prefix_len == len(node.key):
                if node.evicted:
                    # change the reference if the node is evicted
                    # this often happens in the case of KV cache recomputation
                    node.value = value[:prefix_len]
                    self.token_to_kv_pool_host.update_synced(node.host_value)
                    self.evictable_size_ += len(node.value)
                else:
                    self.inc_hit_count(node)
                    total_prefix_length += prefix_len
            else:
                # partial match, split the node
                new_node = self._split_node(node.key, node, prefix_len)
                if new_node.evicted:
                    new_node.value = value[:prefix_len]
                    self.token_to_kv_pool_host.update_synced(new_node.host_value)
                    self.evictable_size_ += len(new_node.value)
                else:
                    self.inc_hit_count(new_node)
                    total_prefix_length += prefix_len
                node = new_node

            key = key[prefix_len:]
            value = value[prefix_len:]

            if len(key):
                child_key = self.get_child_key_fn(key)

        if len(key):
            new_node = TreeNode()
            new_node.parent = node
            new_node.key = key
            new_node.value = value
            node.children[child_key] = new_node
            self.evictable_size_ += len(value)

            if self.enable_storage:
                last_hash = node.get_last_hash_value()
                assert (node == self.root_node) or (
                    last_hash is not None
                ), "Parent node must have a hash value with storage enabled"
                new_node.hash_value = []
                for idx in range(0, len(key), self.page_size):
                    new_node.hash_value.append(
                        self.cache_controller.get_hash_str(
                            key[idx : idx + self.page_size],
                            prior_hash=last_hash,
                        )
                    )
                    last_hash = new_node.hash_value[-1]

            if self.cache_controller.write_policy != "write_back":
                self.inc_hit_count(new_node)
        return total_prefix_length

    def _collect_leaves_device(self):
        def is_leaf(node):
            if node.evicted:
                return False
            if node == self.root_node:
                return False
            if len(node.children) == 0:
                return True
            for child in node.children.values():
                if not child.evicted:
                    return False
            return True

        ret_list = []
        stack = [self.root_node]
        while stack:
            cur_node = stack.pop()
            if is_leaf(cur_node):
                ret_list.append(cur_node)
            else:
                for cur_child in cur_node.children.values():
                    if not cur_child.evicted:
                        stack.append(cur_child)
        return ret_list