import heapq
import logging
import threading
import time
from queue import Queue
from typing import List, Optional

import torch

from sglang.srt.managers.cache_controller import HiCacheController, PrefetchOperation
from sglang.srt.mem_cache.allocator import BaseTokenToKVPoolAllocator
from sglang.srt.mem_cache.base_prefix_cache import MatchResult
from sglang.srt.mem_cache.memory_pool import (
    MHATokenToKVPool,
    MLATokenToKVPool,
    ReqToTokenPool,
)
from sglang.srt.mem_cache.memory_pool_host import (
    MHATokenToKVPoolHost,
    MLATokenToKVPoolHost,
)
from sglang.srt.mem_cache.radix_cache import RadixCache, TreeNode
from sglang.srt.metrics.collector import StorageMetricsCollector

logger = logging.getLogger(__name__)


class HiRadixCache(RadixCache):

    def __init__(
        self,
        req_to_token_pool: ReqToTokenPool,
        token_to_kv_pool_allocator: BaseTokenToKVPoolAllocator,
        tp_cache_group: torch.distributed.ProcessGroup,
        page_size: int,
        hicache_ratio: float,
        hicache_size: int,
        hicache_write_policy: str,
        hicache_io_backend: str,
        hicache_mem_layout: str,
        enable_metrics: bool,
        hicache_storage_backend: Optional[str] = None,
        hicache_storage_prefetch_policy: Optional[str] = "best_effort",
        model_name: Optional[str] = None,
        storage_backend_extra_config: Optional[str] = None,
    ):

        if hicache_io_backend == "direct":
            if hicache_mem_layout == "page_first":
                hicache_mem_layout = "layer_first"
                logger.warning(
                    "Page first layout is not supported with direct IO backend, switching to layer first layout"
                )

        self.kv_cache = token_to_kv_pool_allocator.get_kvcache()
        if isinstance(self.kv_cache, MHATokenToKVPool):
            self.token_to_kv_pool_host = MHATokenToKVPoolHost(
                self.kv_cache,
                hicache_ratio,
                hicache_size,
                page_size,
                hicache_mem_layout,
            )
        elif isinstance(self.kv_cache, MLATokenToKVPool):
            self.token_to_kv_pool_host = MLATokenToKVPoolHost(
                self.kv_cache,
                hicache_ratio,
                hicache_size,
                page_size,
                hicache_mem_layout,
            )
        else:
            raise ValueError("HiRadixCache only supports MHA and MLA for now")

        self.tp_group = tp_cache_group
        self.tp_world_size = torch.distributed.get_world_size(group=self.tp_group)
        self.enable_storage = hicache_storage_backend is not None
        self.enable_storage_metrics = self.enable_storage and enable_metrics

        # todo: customizable storage prefetch threshold and timeout
        self.prefetch_threshold = 256
        self.prefetch_timeout = 3  # seconds
        self.prefetch_stop_policy = hicache_storage_prefetch_policy

        self.load_cache_event = threading.Event()
        self.cache_controller = HiCacheController(
            token_to_kv_pool_allocator,
            self.token_to_kv_pool_host,
            page_size,
            self.tp_group,
            load_cache_event=self.load_cache_event,
            write_policy=hicache_write_policy,
            io_backend=hicache_io_backend,
            storage_backend=hicache_storage_backend,
            prefetch_threshold=self.prefetch_threshold,
            model_name=model_name,
            storage_backend_extra_config=storage_backend_extra_config,
        )
        if self.enable_storage_metrics:
            # TODO: support pp
            labels = {
                "storage_backend": hicache_storage_backend,
                "tp_rank": self.cache_controller.tp_rank,
                "dp_rank": self.cache_controller.dp_rank,
            }
            self.metrics_collector = StorageMetricsCollector(labels=labels)

        # record the nodes with ongoing write through
        self.ongoing_write_through = {}
        # record the node segments with ongoing load back
        self.ongoing_load_back = {}
        # record the ongoing prefetch requests
        self.ongoing_prefetch = {}
        self.ongoing_backup = {}
        # todo: dynamically adjust the threshold
        self.write_through_threshold = (
            1 if hicache_write_policy == "write_through" else 2
        )
        self.load_back_threshold = 10
        super().__init__(
            req_to_token_pool, token_to_kv_pool_allocator, page_size, disable=False
        )

    def reset(self):
        TreeNode.counter = 0
        self.cache_controller.reset()
        self.token_to_kv_pool_host.clear()
        super().reset()

    def get_height(self, node: TreeNode):
        height = 0
        while node != self.root_node:
            node = node.parent
            height += 1
        return height

    def clear_storage_backend(self) -> bool:
        if self.enable_storage:
            try:
                # Check if the storage backend has a clear method (for nixl backends)
                if hasattr(self.cache_controller.storage_backend, "clear"):
                    self.cache_controller.storage_backend.clear()
                    logger.info(
                        "Hierarchical cache storage backend cleared successfully!"
                    )
                    return True
                else:
                    logger.warning(
                        f"Storage backend {type(self.cache_controller.storage_backend).__name__} does not support clear operation."
                    )
                    return False
            except Exception as e:
                logger.error(f"Failed to clear hierarchical cache storage backend: {e}")
                return False
        else:
            logger.warning("Hierarchical cache storage backend is not enabled.")
            return False

    def write_backup(self, node: TreeNode, write_back=False):
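        """Initiate an asynchronous write-through of this node's device KV to host.

        If host allocation fails once, evict host memory and retry. On success the
        node is registered in ``ongoing_write_through`` and, unless ``write_back``
        is set, lock-protected until the write is acknowledged. Returns the number
        of tokens queued for writing, or 0 if no host memory could be allocated.
        """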
        host_indices = self.cache_controller.write(
            device_indices=node.value,
            node_id=node.id,
        )
        if host_indices is None:
            self.evict_host(len(node.value))
            host_indices = self.cache_controller.write(
                device_indices=node.value,
                node_id=node.id,
            )
        if host_indices is not None:
            node.host_value = host_indices
            assert len(node.host_value) > 0
            self.ongoing_write_through[node.id] = node
            if not write_back:
                # no need to lock nodes if write back
                self.inc_lock_ref(node)
        else:
            return 0

        return len(host_indices)

    def write_backup_storage(self, node: TreeNode):
        operation_id = self.cache_controller.write_storage(
            node.host_value, node.key, node.hash_value
        )
        self.ongoing_backup[operation_id] = node
        node.protect_host()

    def _inc_hit_count(self, node: TreeNode, chunked=False):
        # skip the hit count update for chunked requests
        if self.cache_controller.write_policy == "write_back" or chunked:
            return
        node.hit_count += 1

        if not node.backuped:
            if node.hit_count >= self.write_through_threshold:
                # write to host if the node is not backuped
                self.write_backup(node)

    def writing_check(self, write_back=False):
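        """Acknowledge completed write-through operations.

        With ``write_back=True``, block until every ongoing write-through is acked.
        Otherwise drain only the minimum number of acks available across all TP
        workers, so every rank applies identical updates to the radix tree, then
        unlock each acked node and back it up to storage if storage is enabled.
        """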
        if write_back:
            # block until all write-back operations complete
            while len(self.ongoing_write_through) > 0:
                ack_id = self.cache_controller.ack_write_queue.get()
                del self.ongoing_write_through[ack_id]
            return
        queue_size = torch.tensor(
            self.cache_controller.ack_write_queue.qsize(), dtype=torch.int
        )
        if self.tp_world_size > 1:
            # synchronize TP workers to make the same update to radix cache
            torch.distributed.all_reduce(
                queue_size,
                op=torch.distributed.ReduceOp.MIN,
                group=self.tp_group,
            )
        for _ in range(queue_size.item()):
            ack_id = self.cache_controller.ack_write_queue.get()
            backuped_node = self.ongoing_write_through[ack_id]
            self.dec_lock_ref(backuped_node)
            del self.ongoing_write_through[ack_id]
            if self.enable_storage:
                self.write_backup_storage(backuped_node)

    def loading_check(self):
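        """Finalize acknowledged host-to-device loads by unlocking the loaded
        nodes, clearing their loading flags, and dropping the bookkeeping."""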
        while not self.cache_controller.ack_load_queue.empty():
            try:
                ack_id = self.cache_controller.ack_load_queue.get_nowait()
                start_node, end_node = self.ongoing_load_back[ack_id]
                self.dec_lock_ref(end_node)
                while end_node != start_node:
                    assert end_node.loading
                    end_node.loading = False
                    end_node = end_node.parent
                # clear the reference
                del self.ongoing_load_back[ack_id]
            except Exception:
                break

    def evictable_size(self):
        return self.evictable_size_

    def evict(self, num_tokens: int):
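        """Evict at least ``num_tokens`` tokens of device KV cache.

        Pops unlocked leaves from a min-heap of device-resident leaves; nodes
        already backed up on host only drop their device copy, while un-backed
        nodes are freed directly or, under the write_back policy, written to host
        first and evicted once the write completes.
        """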
        leaves = self._collect_leaves_device()
        heapq.heapify(leaves)

        num_evicted = 0
        write_back_nodes = []
        while num_evicted < num_tokens and len(leaves):
            x = heapq.heappop(leaves)

            if x.lock_ref > 0:
                continue

            if not x.backuped:
                if self.cache_controller.write_policy == "write_back":
                    # write to host if the node is not backuped
                    num_evicted += self.write_backup(x, write_back=True)
                    write_back_nodes.append(x)
                else:
                    num_evicted += self._evict_regular(x)
            else:
                num_evicted += self._evict_backuped(x)

            for child in x.parent.children.values():
                if child in write_back_nodes:
                    continue
                if not child.evicted:
                    break
            else:
                # all children are evicted or no children
                heapq.heappush(leaves, x.parent)

        if self.cache_controller.write_policy == "write_back":
            self.writing_check(write_back=True)
            for node in write_back_nodes:
                assert node.backuped
                self._evict_backuped(node)

    def _evict_backuped(self, node: TreeNode):
        # evict a node already written to host
        num_evicted = self.cache_controller.evict_device(node.value, node.host_value)
        assert num_evicted > 0
        self.evictable_size_ -= num_evicted
        node.value = None
        return num_evicted

    def _evict_regular(self, node: TreeNode):
        # evict a node that has not been written to host
        self.cache_controller.mem_pool_device_allocator.free(node.value)
        num_evicted = len(node.value)
        self._delete_leaf(node)
        return num_evicted

    def evict_host(self, num_tokens: int):
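        """Evict at least ``num_tokens`` tokens of host KV cache, considering only
        nodes whose device copy is already evicted and whose host memory is not
        protected by an ongoing prefetch or backup."""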
        leaves = self._collect_leaves()
        heapq.heapify(leaves)

        num_evicted = 0
        while num_evicted < num_tokens and len(leaves):
            x = heapq.heappop(leaves)
            if x == self.root_node:
                break
            # only evict the host value of evicted nodes
            if not x.evicted:
                continue

            # node is protected from eviction as it has ongoing prefetch or backup to storage
            if x.host_ref_counter > 0:
                continue

            num_evicted += self.cache_controller.evict_host(x.host_value)

            for k, v in x.parent.children.items():
                if v == x:
                    break
            del x.parent.children[k]

            if len(x.parent.children) == 0 and x.parent.evicted:
                heapq.heappush(leaves, x.parent)

    def load_back(
        self, node: TreeNode, mem_quota: Optional[int] = None
    ) -> Optional[torch.Tensor]:
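        """Load the evicted suffix ending at ``node`` from host back to device.

        The span of evicted ancestors is loaded in one transfer or not at all;
        loading is skipped when the span is shorter than ``load_back_threshold``
        or would exceed ``mem_quota``. Returns the allocated device indices, or
        None if nothing was loaded.
        """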
        # todo: more loading policies

        last_hit_node = node
        nodes_to_load = []
        while node.evicted:
            assert (
                node.backuped
            ), "No backup available on evicted nodes, should not happen"
            nodes_to_load.insert(0, node)
            node = node.parent
        else:
            ancestor_node = node

        # protect the ancestor nodes from eviction
        delta = self.inc_lock_ref(ancestor_node)

        # load it all or not at all
        host_indices = torch.cat([n.host_value for n in nodes_to_load])
        if len(host_indices) < self.load_back_threshold or (
            len(host_indices) > mem_quota + delta if mem_quota is not None else False
        ):
            # skip loading back if the total size is too small or exceeds the memory quota
            self.dec_lock_ref(ancestor_node)
            return None

        device_indices = self.cache_controller.load(
            host_indices=host_indices, node_id=last_hit_node.id
        )
        if device_indices is None:
            self.evict(len(host_indices))
            device_indices = self.cache_controller.load(
                host_indices=host_indices, node_id=last_hit_node.id
            )
        self.dec_lock_ref(ancestor_node)
        if device_indices is None:
            # not enough GPU memory to load back KV caches
            return None

        self.ongoing_load_back[last_hit_node.id] = (ancestor_node, last_hit_node)
        offset = 0
        for node in nodes_to_load:
            node.value = device_indices[offset : offset + len(node.host_value)]
            offset += len(node.host_value)
            node.loading = True
        self.evictable_size_ += len(device_indices)
        self.inc_lock_ref(last_hit_node)

        return device_indices

    def init_load_back(
        self,
        last_node: TreeNode,
        host_hit_length: int,
        mem_quota: Optional[int] = None,
    ):
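        """If ``last_node`` is evicted from device, try to load it and its evicted
        ancestors back; otherwise fall back to its deepest device-resident
        ancestor. Returns the device indices being loaded and the resulting node.
        """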
        _ = host_hit_length  # unused, but kept for compatibility
        if last_node.evicted:
            loading_values = self.load_back(last_node, mem_quota)
            if loading_values is not None:
                logger.debug(
                    f"loading back {len(loading_values)} tokens for node {last_node.id}"
                )
                return loading_values, last_node

            while last_node.evicted:
                last_node = last_node.parent

        return (
            torch.empty((0,), dtype=torch.int64, device=self.device),
            last_node,
        )

    def ready_to_load_host_cache(self):
        producer_index = self.cache_controller.layer_done_counter.next_producer()
        self.load_cache_event.set()
        return producer_index

    def check_hicache_events(self):
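        """Poll all asynchronous hierarchical-cache events: write-through acks,
        load-back acks, and, when storage is enabled, the storage control queues
        and storage metrics."""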
        self.writing_check()
        self.loading_check()
        if self.enable_storage:
            self.drain_storage_control_queues()
        if self.enable_storage_metrics:
            self.metrics_collector.log_storage_metrics(
                self.cache_controller.storage_backend.get_stats()
            )

    def drain_storage_control_queues(self):
        """
        Combine prefetch revoke, backup ack, and host mem release checks
        to minimize TP synchronization and Python overhead.
        """
        cc = self.cache_controller

        qsizes = torch.tensor(
            [
                cc.prefetch_revoke_queue.qsize(),
                cc.ack_backup_queue.qsize(),
                cc.host_mem_release_queue.qsize(),
            ],
            dtype=torch.int,
        )
        if self.tp_world_size > 1:
            torch.distributed.all_reduce(
                qsizes, op=torch.distributed.ReduceOp.MIN, group=self.tp_group
            )

        n_revoke, n_backup, n_release = map(int, qsizes.tolist())

        # process prefetch revokes
        for _ in range(n_revoke):
            req_id = cc.prefetch_revoke_queue.get()
            info = self.ongoing_prefetch.pop(req_id, None)
            if info is not None:
                last_host_node, token_ids, _, _ = info
                last_host_node.release_host()
                cc.prefetch_tokens_occupied -= len(token_ids)
            # else: the revoked operation already got terminated, nothing to do

        # process backup acks
        for _ in range(n_backup):
            operation = cc.ack_backup_queue.get()
            ack_id = operation.id
            entry = self.ongoing_backup.pop(ack_id, None)
            if entry is not None:
                entry.release_host()
            if self.enable_storage_metrics:
                self.metrics_collector.log_backuped_tokens(operation.completed_tokens)

        # release host memory
        host_indices_list = []
        for _ in range(n_release):
            host_indices_list.append(cc.host_mem_release_queue.get())
        if host_indices_list:
            host_indices = torch.cat(host_indices_list, dim=0)
            cc.mem_pool_host.free(host_indices)

    def can_terminate_prefetch(self, operation: PrefetchOperation):
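        """Decide whether a storage prefetch operation may be terminated now.

        "best_effort" terminates immediately, "wait_complete" waits until every
        page is fetched, and "timeout" waits for completion or ``prefetch_timeout``
        seconds. With TP > 1 the decision is synchronized: terminate only when all
        ranks agree, or when the operation is already terminated on any rank.
        """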
        can_terminate = True

        if self.prefetch_stop_policy == "best_effort":
            return can_terminate

        if len(operation.hash_value) == 0:
            completed = False
        else:
            completed = (
                operation.completed_tokens == len(operation.hash_value) * self.page_size
            )

        if self.prefetch_stop_policy == "wait_complete":
            can_terminate = completed
        elif self.prefetch_stop_policy == "timeout":
            can_terminate = completed or (
                time.monotonic() - operation.start_time > self.prefetch_timeout
            )
        else:
            # unknown prefetch stop policy, just return True
            return True

        operation_terminated = operation.is_terminated()
        if self.tp_world_size > 1:
            states = torch.tensor(
                [1 - int(can_terminate), int(operation_terminated)],
                dtype=torch.int,
            )
            torch.distributed.all_reduce(
                states,
                op=torch.distributed.ReduceOp.MAX,
                group=self.tp_group,
            )
            can_terminate = states[0].item() == 0
            operation_terminated = states[1].item() == 1
        # the operation should be terminated if it is already terminated on any TP worker
        # or it meets the termination condition on all TP workers
        can_terminate = can_terminate or operation_terminated
        return can_terminate

    def check_prefetch_progress(self, req_id: str) -> bool:
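        """Try to finalize the storage prefetch associated with ``req_id``.

        Returns True when there is nothing left to wait for: no ongoing prefetch,
        a prefetch that was never issued, or one that has now been terminated and
        whose fetched tokens were inserted into the host radix tree. Returns False
        if the prefetch must keep running.
        """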
        if req_id not in self.ongoing_prefetch:
            # there is no ongoing prefetch for this request or it has been revoked
            return True

        # todo: more policies for prefetch progress such as timeout
        # the current policy is to prefetch with best effort and terminate when queuing is over
        last_host_node, token_ids, host_indices, operation = self.ongoing_prefetch[
            req_id
        ]

        if operation.host_indices is None:
            # prefetch has not been issued due to insufficient host memory
            return True

        if not self.can_terminate_prefetch(operation):
            return False

        completed_tokens, hash_value = self.cache_controller.terminate_prefetch(
            operation
        )
        logger.debug(f"Prefetch {req_id} completed with {completed_tokens} tokens")

        min_completed_tokens = completed_tokens
        if self.tp_world_size > 1:
            # synchronize TP workers to make the same update to hiradix cache
            completed_tokens_tensor = torch.tensor(
                min_completed_tokens, dtype=torch.int
            )
            torch.distributed.all_reduce(
                completed_tokens_tensor,
                op=torch.distributed.ReduceOp.MIN,
                group=self.tp_group,
            )
            min_completed_tokens = completed_tokens_tensor.item()
        fetched_token_ids = token_ids[:min_completed_tokens]
        written_indices = host_indices[:min_completed_tokens]
        matched_length = self._insert_helper_host(
            last_host_node,
            fetched_token_ids,
            written_indices,
            hash_value[: min_completed_tokens // self.page_size],
        )
        if len(written_indices):
            self.cache_controller.mem_pool_host.update_prefetch(written_indices)

        self.cache_controller.mem_pool_host.free(host_indices[:matched_length])
        self.cache_controller.append_host_mem_release(
            host_indices[min_completed_tokens:completed_tokens]
        )
        last_host_node.release_host()
        del self.ongoing_prefetch[req_id]
        self.cache_controller.prefetch_tokens_occupied -= len(token_ids)

        if self.enable_storage_metrics:
            self.metrics_collector.log_prefetched_tokens(
                min_completed_tokens - matched_length
            )

        return True

    def match_prefix(self, key: List[int], **kwargs):
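        """Match ``key`` against the radix tree.

        Returns the device indices of the matched prefix, the deepest node still
        resident on device, the deepest host-backed node on the matched path, and
        the number of additional tokens available only in host memory.
        """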
        empty_value = torch.empty((0,), dtype=torch.int64, device=self.device)
        if self.disable or len(key) == 0:
            return MatchResult(
                device_indices=empty_value,
                last_device_node=self.root_node,
                last_host_node=self.root_node,
                host_hit_length=0,
            )

        if self.page_size != 1:
            page_aligned_len = len(key) // self.page_size * self.page_size
            key = key[:page_aligned_len]

        value, last_node = self._match_prefix_helper(self.root_node, key)
        if value:
            value = torch.cat(value)
        else:
            value = empty_value

        host_hit_length = 0
        last_host_node = last_node
        while last_node.evicted:
            host_hit_length += len(last_node.host_value)
            last_node = last_node.parent
        while not last_host_node.backuped:
            last_host_node = last_host_node.parent

        return MatchResult(
            device_indices=value,
            last_device_node=last_node,
            last_host_node=last_host_node,
            host_hit_length=host_hit_length,
        )

    def prefetch_from_storage(
        self,
        req_id: str,
        last_host_node: TreeNode,
        new_input_tokens: List[int],
        last_hash: Optional[str] = None,
    ):
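        """Start an asynchronous prefetch of ``new_input_tokens`` from the storage
        backend into host memory and register it in ``ongoing_prefetch``. Skipped
        when storage is disabled, the span is shorter than ``prefetch_threshold``,
        the prefetch rate limit is reached, or host memory cannot be allocated."""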
        # align the number of fetching tokens to the page size
        prefetch_length = len(new_input_tokens) - (
            len(new_input_tokens) % self.page_size
        )
        new_input_tokens = new_input_tokens[:prefetch_length]
        if (
            not self.enable_storage
            or prefetch_length < self.prefetch_threshold
            or self.cache_controller.prefetch_rate_limited()
        ):
            return

        last_host_node.protect_host()
        host_indices = self.cache_controller.mem_pool_host.alloc(prefetch_length)
        if host_indices is None:
            self.evict_host(prefetch_length)
            host_indices = self.cache_controller.mem_pool_host.alloc(prefetch_length)
        if host_indices is None:
            last_host_node.release_host()
            # not enough host memory for prefetch
            return
        operation = self.cache_controller.prefetch(
            req_id, host_indices, new_input_tokens, last_hash
        )
        self.ongoing_prefetch[req_id] = (
            last_host_node,
            new_input_tokens,
            host_indices,
            operation,
        )
        self.cache_controller.prefetch_tokens_occupied += len(new_input_tokens)

    def _insert_helper_host(self, node: TreeNode, key: List, host_value, hash_value):
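        """Insert host-only KV indices and their page hashes for ``key`` under
        ``node``, splitting on partial matches. Returns the number of tokens that
        already had nodes in the tree (their freshly written host copies are
        redundant and can be freed by the caller)."""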
        node.last_access_time = time.monotonic()
        if len(key) == 0:
            return 0

        child_key = self.get_child_key_fn(key)

        matched_length = 0
        while len(key) > 0 and child_key in node.children.keys():
            node = node.children[child_key]
            node.last_access_time = time.monotonic()
            prefix_len = self.key_match_fn(node.key, key)
            key = key[prefix_len:]
            host_value = host_value[prefix_len:]
            hash_value = hash_value[prefix_len // self.page_size :]
            matched_length += prefix_len

            if prefix_len < len(node.key):
                new_node = self._split_node(node.key, node, prefix_len)
                node = new_node

            if len(key):
                child_key = self.get_child_key_fn(key)

        if len(key):
            new_node = TreeNode()
            new_node.parent = node
            new_node.key = key
            new_node.value = None
            new_node.host_value = host_value
            new_node.hash_value = hash_value
            node.children[child_key] = new_node
        return matched_length

    def _match_prefix_helper(self, node: TreeNode, key: List):
        node.last_access_time = time.monotonic()
        child_key = self.get_child_key_fn(key)
        value = []

        while len(key) > 0 and child_key in node.children.keys():
            child = node.children[child_key]
            child.last_access_time = time.monotonic()
            prefix_len = self.key_match_fn(child.key, key)
            if prefix_len < len(child.key):
                new_node = self._split_node(child.key, child, prefix_len)
                if not new_node.evicted:
                    value.append(new_node.value)
                node = new_node
                break
            else:
                if not child.evicted:
                    value.append(child.value)
                node = child
                key = key[prefix_len:]

                if len(key):
                    child_key = self.get_child_key_fn(key)

        return value, node

    def _split_node(self, key, child: TreeNode, split_len: int):
        # child node split into new_node -> child
        new_node = TreeNode()
        new_node.children = {self.get_child_key_fn(key[split_len:]): child}
        new_node.parent = child.parent
        new_node.lock_ref = child.lock_ref
        new_node.key = child.key[:split_len]
        new_node.loading = child.loading
        new_node.hit_count = child.hit_count

        # split the value and host value if they exist
        if child.evicted:
            new_node.value = None
        else:
            new_node.value = child.value[:split_len]
            child.value = child.value[split_len:]
        if child.backuped:
            new_node.host_value = child.host_value[:split_len]
            child.host_value = child.host_value[split_len:]

        if child.hash_value:
            new_node.hash_value = child.hash_value[: split_len // self.page_size]
            child.hash_value = child.hash_value[split_len // self.page_size :]
        child.parent = new_node
        child.key = child.key[split_len:]
        new_node.parent.children[self.get_child_key_fn(key)] = new_node
        return new_node

    def insert(self, key: List, value, chunked=False):
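        """Insert ``key``/``value`` into the radix tree: revive evicted nodes with
        the new device indices, split nodes on partial matches, update hit counts,
        and, with storage enabled, compute page hashes for newly created nodes.
        Returns the length of the matched prefix that is still resident on device.
        """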
        if len(key) == 0:
            return 0

        node = self.root_node
        child_key = self.get_child_key_fn(key)
        total_prefix_length = 0

        while len(key) > 0 and child_key in node.children.keys():
            node = node.children[child_key]
            node.last_access_time = time.monotonic()
            prefix_len = self.key_match_fn(node.key, key)

            if prefix_len == len(node.key):
                if node.evicted:
                    # change the reference if the node is evicted
                    # this often happens in the case of KV cache recomputation
                    node.value = value[:prefix_len]
                    self.token_to_kv_pool_host.update_synced(node.host_value)
                    self.evictable_size_ += len(node.value)
                else:
                    self._inc_hit_count(node, chunked)
                    total_prefix_length += prefix_len
            else:
                # partial match, split the node
                new_node = self._split_node(node.key, node, prefix_len)
                if new_node.evicted:
                    new_node.value = value[:prefix_len]
                    self.token_to_kv_pool_host.update_synced(new_node.host_value)
                    self.evictable_size_ += len(new_node.value)
                else:
                    self._inc_hit_count(new_node, chunked)
                    total_prefix_length += prefix_len
                node = new_node

            key = key[prefix_len:]
            value = value[prefix_len:]

            if len(key):
                child_key = self.get_child_key_fn(key)

        if len(key):
            new_node = TreeNode()
            new_node.parent = node
            new_node.key = key
            new_node.value = value
            node.children[child_key] = new_node
            self.evictable_size_ += len(value)

            if self.enable_storage:
                last_hash = node.get_last_hash_value()
                assert (node == self.root_node) or (
                    last_hash is not None
                ), "Parent node must have a hash value with storage enabled"
                new_node.hash_value = []
                for idx in range(0, len(key), self.page_size):
                    new_node.hash_value.append(
                        self.cache_controller.get_hash_str(
                            key[idx : idx + self.page_size],
                            prior_hash=last_hash,
                        )
                    )
                    last_hash = new_node.hash_value[-1]

            if self.cache_controller.write_policy != "write_back":
                self._inc_hit_count(new_node, chunked)
        return total_prefix_length

    def _collect_leaves_device(self):
        def is_leaf(node):
            if node.evicted:
                return False
            if node == self.root_node:
                return False
            if len(node.children) == 0:
                return True
            for child in node.children.values():
                if not child.evicted:
                    return False
            return True

        ret_list = []
        stack = [self.root_node]
        while stack:
            cur_node = stack.pop()
            if is_leaf(cur_node):
                ret_list.append(cur_node)
            else:
                for cur_child in cur_node.children.values():
                    if not cur_child.evicted:
                        stack.append(cur_child)
        return ret_list

    def release_aborted_request(self, rid: str):
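        """Clean up the ongoing prefetch of an aborted request: terminate the
        operation, queue its completed host indices for release, and drop its
        bookkeeping entries."""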
        if rid not in self.ongoing_prefetch:
            return

        last_host_node, token_ids, host_indices, operation = self.ongoing_prefetch[rid]
        if operation.host_indices is None:
            return

        completed_tokens, _ = self.cache_controller.terminate_prefetch(operation)
        if self.tp_world_size > 1:
            torch.distributed.barrier(group=self.tp_group)
        last_host_node.release_host()
        del self.ongoing_prefetch[rid]
        self.cache_controller.append_host_mem_release(host_indices[:completed_tokens])
        self.cache_controller.prefetch_tokens_occupied -= len(token_ids)