hiradix_cache.py 24.5 KB
Newer Older
1
2
import heapq
import logging
3
import threading
4
5
6
7
8
9
import time
from typing import List, Optional

import torch

from sglang.srt.managers.cache_controller import HiCacheController
10
from sglang.srt.mem_cache.allocator import BaseTokenToKVPoolAllocator
11
from sglang.srt.mem_cache.base_prefix_cache import MatchResult
12
from sglang.srt.mem_cache.memory_pool import (
13
14
    MHATokenToKVPool,
    MLATokenToKVPool,
15
16
    ReqToTokenPool,
)
17
18
19
20
from sglang.srt.mem_cache.memory_pool_host import (
    MHATokenToKVPoolHost,
    MLATokenToKVPoolHost,
)
Lianmin Zheng's avatar
Lianmin Zheng committed
21
from sglang.srt.mem_cache.radix_cache import RadixCache, TreeNode
22
23
24
25
26
27
28
29
30

logger = logging.getLogger(__name__)


class HiRadixCache(RadixCache):

    def __init__(
        self,
        req_to_token_pool: ReqToTokenPool,
        token_to_kv_pool_allocator: BaseTokenToKVPoolAllocator,
        tp_cache_group: torch.distributed.ProcessGroup,
        page_size: int,
        hicache_ratio: float,
        hicache_size: int,
        hicache_write_policy: str,
        hicache_io_backend: str,
        hicache_storage_backend: Optional[str] = None,
    ):
        """Set up the host KV pool, the cache controller and the bookkeeping tables.

        Args:
            req_to_token_pool: forwarded unchanged to the parent constructor.
            token_to_kv_pool_allocator: allocator whose ``get_kvcache()`` result
                selects the host pool flavour (MHA or MLA); also forwarded to
                the controller and the parent constructor.
            tp_cache_group: process group used for the world-size query and
                handed to the controller.
            page_size: forwarded to the host pool, controller and parent.
            hicache_ratio: forwarded to the host pool constructor.
            hicache_size: forwarded to the host pool constructor.
            hicache_write_policy: forwarded to the controller; the value
                ``"write_through"`` lowers both write-through thresholds to 1.
            hicache_io_backend: forwarded to the controller.
            hicache_storage_backend: forwarded to the controller; storage
                handling is enabled exactly when this is not ``None``.

        Raises:
            ValueError: if the device KV cache is neither an MHA nor an MLA pool.
        """
        self.kv_cache = token_to_kv_pool_allocator.get_kvcache()

        # Pair each supported device pool type with its host-side counterpart;
        # the order mirrors the precedence of the type checks.
        supported_pools = (
            (MHATokenToKVPool, MHATokenToKVPoolHost),
            (MLATokenToKVPool, MLATokenToKVPoolHost),
        )
        for device_pool_cls, host_pool_cls in supported_pools:
            if isinstance(self.kv_cache, device_pool_cls):
                self.token_to_kv_pool_host = host_pool_cls(
                    self.kv_cache, hicache_ratio, hicache_size, page_size
                )
                break
        else:
            raise ValueError("HiRadixCache only supports MHA and MLA yet")

        self.tp_group = tp_cache_group
        self.tp_world_size = torch.distributed.get_world_size(group=self.tp_group)

        self.enable_storage = hicache_storage_backend is not None
        # todo: customizable storage prefetch threshold
        self.prefetch_threshold = 256

        self.load_cache_event = threading.Event()
        self.cache_controller = HiCacheController(
            token_to_kv_pool_allocator,
            self.token_to_kv_pool_host,
            page_size,
            self.tp_group,
            load_cache_event=self.load_cache_event,
            write_policy=hicache_write_policy,
            io_backend=hicache_io_backend,
            storage_backend=hicache_storage_backend,
            prefetch_threshold=self.prefetch_threshold,
        )

        # node id -> node with a host write still in flight
        self.ongoing_write_through = {}
        # node id -> (start node, end node) of a load-back still in flight
        self.ongoing_load_back = {}
        # request id -> prefetch bookkeeping tuple
        self.ongoing_prefetch = {}
        # operation id -> node being backed up to storage
        self.ongoing_backup = {}

        # todo: dynamically adjust the threshold
        eager_write = hicache_write_policy == "write_through"
        self.write_through_threshold = 1 if eager_write else 3
        self.write_through_threshold_storage = 1 if eager_write else 3
        self.load_back_threshold = 10

        # The parent constructor runs last, after every attribute above exists.
        super().__init__(
            req_to_token_pool, token_to_kv_pool_allocator, page_size, disable=False
        )
89
90
91
92
93
94
95
96
97
98
99
100
101
102

    def reset(self):
        """Reset the cache, its controller and the host KV pool.

        The steps run in this fixed order: ``TreeNode.counter`` is set back to
        0, the cache controller is reset, the host pool is cleared, and finally
        the parent class ``reset`` is invoked.
        """
        # restart the class-level node counter
        TreeNode.counter = 0
        self.cache_controller.reset()
        self.token_to_kv_pool_host.clear()
        super().reset()

    def get_height(self, node: TreeNode):
        height = 0
        while node != self.root_node:
            node = node.parent
            height += 1
        return height

Zhiqiang Xie's avatar
Zhiqiang Xie committed
103
    def write_backup(self, node: TreeNode, write_back=False):
104
105
106
107
108
109
110
111
112
113
114
115
        host_indices = self.cache_controller.write(
            device_indices=node.value,
            node_id=node.id,
        )
        if host_indices is None:
            self.evict_host(len(node.value))
            host_indices = self.cache_controller.write(
                device_indices=node.value,
                node_id=node.id,
            )
        if host_indices is not None:
            node.host_value = host_indices
116
            assert len(node.host_value) > 0
117
            self.ongoing_write_through[node.id] = node
Zhiqiang Xie's avatar
Zhiqiang Xie committed
118
119
120
            if not write_back:
                # no need to lock nodes if write back
                self.inc_lock_ref(node)
121
        else:
Zhiqiang Xie's avatar
Zhiqiang Xie committed
122
            return 0
123
124
125

        return len(host_indices)

126
127
128
129
130
131
132
    def write_backup_storage(self, node: TreeNode):
        operation_id = self.cache_controller.write_storage(
            node.host_value, node.key, node.parent.get_last_hash_value()
        )
        self.ongoing_backup[operation_id] = node
        node.protect_host()

133
    def inc_hit_count(self, node: TreeNode):
134
        if self.cache_controller.write_policy == "write_back":
135
136
            return
        node.hit_count += 1
137
138
139
140
141
142
143
144
145
146
147
148
149

        if not node.backuped:
            if node.hit_count >= self.write_through_threshold:
                # write to host if the node is not backuped
                self.write_backup(node)
        else:
            if (
                self.enable_storage
                and (not node.backuped_storage)
                and node.hit_count >= self.write_through_threshold_storage
            ):
                # if the node is backuped on host memory but not on storage
                self.write_backup_storage(node)
150

Zhiqiang Xie's avatar
Zhiqiang Xie committed
151
152
153
154
155
156
157
    def writing_check(self, write_back=False):
        if write_back:
            # blocking till all write back complete
            while len(self.ongoing_write_through) > 0:
                ack_id = self.cache_controller.ack_write_queue.get()
                del self.ongoing_write_through[ack_id]
            return
158
159
160
        queue_size = torch.tensor(
            self.cache_controller.ack_write_queue.qsize(), dtype=torch.int
        )
161
        if self.tp_world_size > 1:
162
163
164
165
166
167
168
169
170
171
            # synchrnoize TP workers to make the same update to radix cache
            torch.distributed.all_reduce(
                queue_size,
                op=torch.distributed.ReduceOp.MIN,
                group=self.tp_group,
            )
        for _ in range(queue_size.item()):
            ack_id = self.cache_controller.ack_write_queue.get()
            self.dec_lock_ref(self.ongoing_write_through[ack_id])
            del self.ongoing_write_through[ack_id]
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190

    def loading_check(self):
        while not self.cache_controller.ack_load_queue.empty():
            try:
                ack_id = self.cache_controller.ack_load_queue.get_nowait()
                start_node, end_node = self.ongoing_load_back[ack_id]
                self.dec_lock_ref(end_node)
                while end_node != start_node:
                    assert end_node.loading
                    end_node.loading = False
                    end_node = end_node.parent
                # clear the reference
                del self.ongoing_load_back[ack_id]
            except Exception:
                break

    def evictable_size(self):
        return self.evictable_size_

Lianmin Zheng's avatar
Lianmin Zheng committed
191
    def evict(self, num_tokens: int):
192
193
194
195
        leaves = self._collect_leaves_device()
        heapq.heapify(leaves)

        num_evicted = 0
Zhiqiang Xie's avatar
Zhiqiang Xie committed
196
        write_back_nodes = []
197
198
199
200
201
202
        while num_evicted < num_tokens and len(leaves):
            x = heapq.heappop(leaves)

            if x.lock_ref > 0:
                continue

Zhiqiang Xie's avatar
Zhiqiang Xie committed
203
            if not x.backuped:
204
                if self.cache_controller.write_policy == "write_back":
Zhiqiang Xie's avatar
Zhiqiang Xie committed
205
206
207
                    # write to host if the node is not backuped
                    num_evicted += self.write_backup(x, write_back=True)
                    write_back_nodes.append(x)
208
                else:
Zhiqiang Xie's avatar
Zhiqiang Xie committed
209
                    num_evicted += self._evict_regular(x)
210
            else:
Zhiqiang Xie's avatar
Zhiqiang Xie committed
211
                num_evicted += self._evict_backuped(x)
212
213

            for child in x.parent.children.values():
Zhiqiang Xie's avatar
Zhiqiang Xie committed
214
                if child in write_back_nodes:
215
216
217
218
219
220
221
222
                    continue
                if not child.evicted:
                    break
            else:
                # all children are evicted or no children
                heapq.heappush(leaves, x.parent)

        if self.cache_controller.write_policy == "write_back":
Zhiqiang Xie's avatar
Zhiqiang Xie committed
223
224
225
226
            self.writing_check(write_back=True)
            for node in write_back_nodes:
                assert node.backuped
                self._evict_backuped(node)
227

Zhiqiang Xie's avatar
Zhiqiang Xie committed
228
    def _evict_backuped(self, node: TreeNode):
229
230
231
232
233
234
235
        # evict a node already written to host
        num_evicted = self.cache_controller.evict_device(node.value, node.host_value)
        assert num_evicted > 0
        self.evictable_size_ -= num_evicted
        node.value = None
        return num_evicted

Zhiqiang Xie's avatar
Zhiqiang Xie committed
236
    def _evict_regular(self, node: TreeNode):
237
        # evict a node not initiated write to host
238
        self.cache_controller.mem_pool_device_allocator.free(node.value)
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
        num_evicted = len(node.value)
        self._delete_leaf(node)
        return num_evicted

    def evict_host(self, num_tokens: int):
        leaves = self._collect_leaves()
        heapq.heapify(leaves)

        num_evicted = 0
        while num_evicted < num_tokens and len(leaves):
            x = heapq.heappop(leaves)
            if x == self.root_node:
                break
            # only evict the host value of evicted nodes
            if not x.evicted:
                continue

256
257
258
259
            # node is protected from eviction as it has ongoing prefetch or backup to storage
            if x.host_ref_counter > 0:
                continue

260
261
            num_evicted += self.cache_controller.evict_host(x.host_value)

262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
            for k, v in x.parent.children.items():
                if v == x:
                    break
            del x.parent.children[k]

            if len(x.parent.children) == 0 and x.parent.evicted:
                heapq.heappush(leaves, x.parent)

    def load_back(
        self, node: TreeNode, mem_quota: Optional[int] = None
    ) -> Optional[torch.Tensor]:
        # todo: more loading policies

        last_hit_node = node
        nodes_to_load = []
        while node.evicted:
            assert (
                node.backuped
            ), "No backup available on evicted nodes, should not happen"
            nodes_to_load.insert(0, node)
            node = node.parent
        else:
            ancester_node = node

        # protect the ancestor nodes from eviction
        delta = self.inc_lock_ref(ancester_node)

        # load it all or not at all
        host_indices = torch.cat([n.host_value for n in nodes_to_load])
        if len(host_indices) < self.load_back_threshold or (
            len(host_indices) > mem_quota + delta if mem_quota is not None else False
        ):
            # skip loading back if the total size is too small or exceeding the memory quota
            self.dec_lock_ref(ancester_node)
            return None

        device_indices = self.cache_controller.load(
            host_indices=host_indices, node_id=last_hit_node.id
        )
        if device_indices is None:
            self.evict(len(host_indices))
            device_indices = self.cache_controller.load(
                host_indices=host_indices, node_id=last_hit_node.id
            )
        self.dec_lock_ref(ancester_node)
        if device_indices is None:
            # no sufficient GPU memory to load back KV caches
            return None

        self.ongoing_load_back[last_hit_node.id] = (ancester_node, last_hit_node)
        offset = 0
        for node in nodes_to_load:
            node.value = device_indices[offset : offset + len(node.host_value)]
            offset += len(node.host_value)
            node.loading = True
        self.evictable_size_ += len(device_indices)
        self.inc_lock_ref(last_hit_node)

        return device_indices

    def init_load_back(
        self,
        last_node: TreeNode,
325
        host_hit_length: int,
326
327
        mem_quota: Optional[int] = None,
    ):
328
        _ = host_hit_length  # unused, but kept for compatibility
329
330
331
332
333
334
        if last_node.evicted:
            loading_values = self.load_back(last_node, mem_quota)
            if loading_values is not None:
                logger.debug(
                    f"loading back {len(loading_values)} tokens for node {last_node.id}"
                )
335
                return loading_values, last_node
336
337
338
339

            while last_node.evicted:
                last_node = last_node.parent

340
341
342
343
        return (
            torch.empty((0,), dtype=torch.int64, device=self.device),
            last_node,
        )
344

345
    def ready_to_load_host_cache(self):
346
        producer_index = self.cache_controller.layer_done_counter.next_producer()
347
        self.load_cache_event.set()
348
        return producer_index
349

350
351
352
    def check_hicache_events(self):
        self.writing_check()
        self.loading_check()
353
354
355
356
357
358
359
360
        if self.enable_storage:
            self.check_revoked_prefetch()
            self.check_backup_progress()

    def check_revoked_prefetch(self):
        queue_size = torch.tensor(
            self.cache_controller.prefetch_revoke_queue.qsize(), dtype=torch.int
        )
361
        if self.tp_world_size > 1:
362
363
364
365
366
367
368
369
370
            # synchrnoize TP workers to make the same update to hiradix cache
            torch.distributed.all_reduce(
                queue_size,
                op=torch.distributed.ReduceOp.MIN,
                group=self.tp_group,
            )
        for _ in range(queue_size.item()):
            req_id = self.cache_controller.prefetch_revoke_queue.get()
            if req_id in self.ongoing_prefetch:
371
                last_host_node, _, _, _ = self.ongoing_prefetch[req_id]
372
373
                last_host_node.release_host()
                del self.ongoing_prefetch[req_id]
374
375
376
            else:
                # the revoked operation already got terminated
                pass
377
378
379
380
381

    def check_backup_progress(self):
        queue_size = torch.tensor(
            self.cache_controller.ack_backup_queue.qsize(), dtype=torch.int
        )
382
        if self.tp_world_size > 1:
383
384
385
386
387
388
389
            # synchrnoize TP workers to make the same update to hiradix cache
            torch.distributed.all_reduce(
                queue_size,
                op=torch.distributed.ReduceOp.MIN,
                group=self.tp_group,
            )
        for _ in range(queue_size.item()):
390
391
392
393
            ack_id, hash_value, completed_tokens = (
                self.cache_controller.ack_backup_queue.get()
            )
            host_node = self.ongoing_backup[ack_id]
394
395
396
            if completed_tokens == 0:
                host_node.hash_value = None
            elif completed_tokens < len(host_node.key):
397
398
399
                # backup is only partially successful, split the node
                new_node = self._split_node(host_node.key, host_node, completed_tokens)
                new_node.hash_value = hash_value
400
401
            else:
                host_node.hash_value = hash_value
402
            host_node.release_host()
403
404
405
406
407
408
409
410
411
412
413
414
            del self.ongoing_backup[ack_id]

    def check_prefetch_progress(self, req_id: str):
        if req_id not in self.ongoing_prefetch:
            # there is no ongoing prefetch for this request or it has been revoked
            return

        # todo: more policies for prefetch progress such as timeout
        # the current policy is to prefetch with best effort and terminate when queuing is over
        last_host_node, token_ids, host_indices, operation = self.ongoing_prefetch[
            req_id
        ]
415

416
417
418
419
420
        completed_tokens, hash_value = self.cache_controller.terminate_prefetch(
            operation
        )
        logger.debug(f"Prefetch {req_id} completed with {completed_tokens} tokens")

421
422
        min_completed_tokens = completed_tokens
        if self.tp_world_size > 1:
423
            # synchrnoize TP workers to make the same update to hiradix cache
424
425
426
            completed_tokens_tensor = torch.tensor(
                min_completed_tokens, dtype=torch.int
            )
427
            torch.distributed.all_reduce(
428
                completed_tokens_tensor,
429
430
431
                op=torch.distributed.ReduceOp.MIN,
                group=self.tp_group,
            )
432
            min_completed_tokens = completed_tokens_tensor.item()
433
434
435
436
437
438
439
440
        fetched_token_ids = token_ids[:min_completed_tokens]
        written_indices = host_indices[:min_completed_tokens]
        matched_length = self._insert_helper_host(
            last_host_node,
            fetched_token_ids,
            written_indices,
            hash_value[:min_completed_tokens],
        )
441
442
        if len(written_indices):
            self.cache_controller.mem_pool_host.update_prefetch(written_indices)
443
444
445
446
447
448
449

        self.cache_controller.mem_pool_host.free(host_indices[:matched_length])
        self.cache_controller.mem_pool_host.free(
            host_indices[min_completed_tokens:completed_tokens]
        )
        last_host_node.release_host()
        del self.ongoing_prefetch[req_id]
450
451

    def match_prefix(self, key: List[int], **kwargs):
        """Match ``key`` against the tree and report device and host hits.

        Args:
            key: token ids to match; truncated to a multiple of ``page_size``
                when ``page_size`` is not 1.
            **kwargs: accepted but not used here.

        Returns:
            A ``MatchResult`` holding the concatenated device indices of the
            matched non-evicted nodes, the last non-evicted node on the matched
            path, the last matched node overall, and the summed host-value
            length of the evicted nodes in between. When the cache is disabled
            or ``key`` is empty, the result is empty and points at the root.
        """
        empty_value = torch.empty((0,), dtype=torch.int64, device=self.device)
        if self.disable or len(key) == 0:
            return MatchResult(
                device_indices=empty_value,
                last_device_node=self.root_node,
                last_host_node=self.root_node,
                host_hit_length=0,
            )

        if self.page_size != 1:
            # drop the trailing tokens that do not fill a whole page
            key = key[: len(key) - len(key) % self.page_size]

        matched_values, last_node = self._match_prefix_helper(self.root_node, key)
        device_indices = torch.cat(matched_values) if matched_values else empty_value

        # climb over evicted nodes, accumulating what is available on the host
        last_host_node = last_node
        host_hit_length = 0
        while last_node.evicted:
            host_hit_length += len(last_node.host_value)
            last_node = last_node.parent

        return MatchResult(
            device_indices=device_indices,
            last_device_node=last_node,
            last_host_node=last_host_node,
            host_hit_length=host_hit_length,
        )
483

484
485
486
487
488
489
490
    def prefetch_from_storage(
        self,
        req_id: str,
        last_host_node: TreeNode,
        new_input_tokens: List[int],
        last_hash: Optional[str] = None,
    ):
491
492
493
494
495
496
        # align the number of fetching tokens to the page size
        prefetch_length = len(new_input_tokens) - (
            len(new_input_tokens) % self.page_size
        )
        new_input_tokens = new_input_tokens[:prefetch_length]
        if not self.enable_storage or prefetch_length < self.prefetch_threshold:
497
498
499
            return

        last_host_node.protect_host()
500
        host_indices = self.cache_controller.mem_pool_host.alloc(prefetch_length)
501
        if host_indices is None:
502
503
            self.evict_host(prefetch_length)
            host_indices = self.cache_controller.mem_pool_host.alloc(prefetch_length)
504
505
506
507
508
509
510
511
512
513
514
515
516
517
518
519
520
521
522
523
524
525
526
527
528
529
530
531
532
533
534
535
536
537
538
539
540
541
542
543
544
545
546
547
548
549
550
551
        if host_indices is None:
            last_host_node.release_host()
            # no sufficient host memory to prefetch
            return
        operation = self.cache_controller.prefetch(
            req_id, host_indices, new_input_tokens, last_hash
        )
        self.ongoing_prefetch[req_id] = (
            last_host_node,
            new_input_tokens,
            host_indices,
            operation,
        )

    def _insert_helper_host(self, node: TreeNode, key: List, host_value, hash_value):
        node.last_access_time = time.monotonic()
        if len(key) == 0:
            return 0

        child_key = self.get_child_key_fn(key)

        matched_length = 0
        while len(key) > 0 and child_key in node.children.keys():
            node = node.children[child_key]
            node.last_access_time = time.monotonic()
            prefix_len = self.key_match_fn(node.key, key)
            key = key[prefix_len:]
            host_value = host_value[prefix_len:]
            hash_value = hash_value[prefix_len:]
            matched_length += prefix_len

            if prefix_len < len(node.key):
                new_node = self._split_node(node.key, node, prefix_len)
                node = new_node

            if len(key):
                child_key = self.get_child_key_fn(key)

        if len(key):
            new_node = TreeNode()
            new_node.parent = node
            new_node.key = key
            new_node.value = None
            new_node.host_value = host_value
            new_node.hash_value = hash_value
            node.children[child_key] = new_node
        return matched_length

552
    def _match_prefix_helper(self, node: TreeNode, key: List):
553
        node.last_access_time = time.monotonic()
554
        child_key = self.get_child_key_fn(key)
555
        value = []
556
557
558

        while len(key) > 0 and child_key in node.children.keys():
            child = node.children[child_key]
559
            child.last_access_time = time.monotonic()
560
            prefix_len = self.key_match_fn(child.key, key)
561
562
563
564
            if prefix_len < len(child.key):
                new_node = self._split_node(child.key, child, prefix_len)
                if not new_node.evicted:
                    value.append(new_node.value)
565
566
                node = new_node
                break
567
568
569
            else:
                if not child.evicted:
                    value.append(child.value)
570
571
                node = child
                key = key[prefix_len:]
572
573
574
575

                if len(key):
                    child_key = self.get_child_key_fn(key)

576
        return value, node
577
578
579
580

    def _split_node(self, key, child: TreeNode, split_len: int):
        """Split ``child`` at ``split_len`` and return the new prefix node.

        Afterwards the tree reads ``parent -> new node -> child``: the new node
        takes the first ``split_len`` elements of the key (and of the device
        and host values where they exist), and ``child`` keeps the rest. Lock
        count, loading flag and hit count are copied from ``child``.

        Args:
            key: key used to compute the child-map entries; callers in this
                class pass the key of ``child``.
            child: the node to split.
            split_len: number of leading elements moved into the new node.
        """
        # child node split into prefix_node -> child
        grandparent = child.parent
        prefix_node = TreeNode()
        prefix_node.children = {self.get_child_key_fn(key[split_len:]): child}
        prefix_node.parent = grandparent
        prefix_node.key = child.key[:split_len]
        prefix_node.lock_ref = child.lock_ref
        prefix_node.loading = child.loading
        prefix_node.hit_count = child.hit_count

        # split value and host value if exists
        if child.evicted:
            prefix_node.value = None
        else:
            device_value = child.value
            prefix_node.value = device_value[:split_len]
            child.value = device_value[split_len:]
        if child.backuped:
            host_value = child.host_value
            prefix_node.host_value = host_value[:split_len]
            child.host_value = host_value[split_len:]

        child.parent = prefix_node
        child.key = child.key[split_len:]
        # the prefix node replaces the child in the grandparent's child map
        grandparent.children[self.get_child_key_fn(key)] = prefix_node
        return prefix_node

    def _insert_helper(self, node: TreeNode, key: List, value):
603
        node.last_access_time = time.monotonic()
604
605
606
        if len(key) == 0:
            return 0

607
608
609
610
611
        child_key = self.get_child_key_fn(key)
        total_prefix_length = 0

        while len(key) > 0 and child_key in node.children.keys():
            node = node.children[child_key]
612
            node.last_access_time = time.monotonic()
613
            prefix_len = self.key_match_fn(node.key, key)
614

615
616
            if prefix_len == len(node.key):
                if node.evicted:
617
618
                    # change the reference if the node is evicted
                    # this often happens in the case of KV cache recomputation
619
620
621
                    node.value = value[:prefix_len]
                    self.token_to_kv_pool_host.update_synced(node.host_value)
                    self.evictable_size_ += len(node.value)
622
                else:
623
624
                    self.inc_hit_count(node)
                    total_prefix_length += prefix_len
625
            else:
626
627
628
629
630
631
632
633
634
635
636
637
638
639
640
641
                # partial match, split the node
                new_node = self._split_node(node.key, node, prefix_len)
                if new_node.evicted:
                    new_node.value = value[:prefix_len]
                    self.token_to_kv_pool_host.update_synced(new_node.host_value)
                    self.evictable_size_ += len(new_node.value)
                else:
                    self.inc_hit_count(new_node)
                    total_prefix_length += prefix_len
                node = new_node

            key = key[prefix_len:]
            value = value[prefix_len:]

            if len(key):
                child_key = self.get_child_key_fn(key)
642
643
644
645
646
647

        if len(key):
            new_node = TreeNode()
            new_node.parent = node
            new_node.key = key
            new_node.value = value
648
            node.children[child_key] = new_node
649
650
            self.evictable_size_ += len(value)

Zhiqiang Xie's avatar
Zhiqiang Xie committed
651
652
            if self.cache_controller.write_policy != "write_back":
                self.inc_hit_count(new_node)
653
        return total_prefix_length
654
655
656
657
658
659
660
661
662
663
664
665
666
667
668
669
670
671
672
673
674
675
676
677
678

    def _collect_leaves_device(self):
        def is_leaf(node):
            if node.evicted:
                return False
            if node == self.root_node:
                return False
            if len(node.children) == 0:
                return True
            for child in node.children.values():
                if not child.evicted:
                    return False
            return True

        ret_list = []
        stack = [self.root_node]
        while stack:
            cur_node = stack.pop()
            if is_leaf(cur_node):
                ret_list.append(cur_node)
            else:
                for cur_child in cur_node.children.values():
                    if not cur_child.evicted:
                        stack.append(cur_child)
        return ret_list