from __future__ import annotations

"""
Copyright 2023-2025 SGLang Team
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
    http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
"""

import logging
import math
import threading
import time
from queue import Empty, Full, PriorityQueue, Queue
from typing import TYPE_CHECKING, List, NamedTuple, Optional, Set, Tuple

import torch

from sglang.srt.mem_cache.hicache_storage import HiCacheStorageConfig

if TYPE_CHECKING:
    from sglang.srt.mem_cache.allocator import BaseTokenToKVPoolAllocator
    from sglang.srt.mem_cache.memory_pool_host import HostKVCache

from sglang.srt.distributed import (
    get_tensor_model_parallel_rank,
    get_tensor_model_parallel_world_size,
)
from sglang.srt.layers.dp_attention import (
    get_attention_dp_rank,
    get_attention_tp_rank,
    get_attention_tp_size,
    is_dp_attention_enabled,
)
from sglang.srt.mem_cache.memory_pool import MHATokenToKVPool, MLATokenToKVPool

logger = logging.getLogger(__name__)


class LayerLoadingEvent:
    def __init__(self, num_layers: int):
        self._num_layers = num_layers
        self.load_events = [torch.cuda.Event() for _ in range(num_layers)]
        self.start_event = torch.cuda.Event()  # start event on controller stream

    def complete(self, layer_index: int):
        assert 0 <= layer_index < self._num_layers
        self.load_events[layer_index].record()

    def wait(self, layer_index: int):
        torch.cuda.current_stream().wait_event(self.load_events[layer_index])

    @property
    def finish_event(self):
        return self.load_events[-1]


class LayerDoneCounter:
    def __init__(self, num_layers: int):
        self.num_layers = num_layers
        # extra producer and consumer counters for overlap mode
        self.num_counters = 3
        self.events = [LayerLoadingEvent(num_layers) for _ in range(self.num_counters)]
        self.producer_index = -1
        self.consumer_index = -1

    def update_producer(self):
        self.producer_index = (self.producer_index + 1) % self.num_counters
        assert self.events[
            self.producer_index
        ].finish_event.query(), (
            "Producer finish event should be ready before being reused."
        )
        return self.producer_index

    def set_consumer(self, index: int):
        self.consumer_index = index

    def wait_until(self, threshold: int):
        if self.consumer_index < 0:
            return
        self.events[self.consumer_index].wait(threshold)

    def reset(self):
        self.producer_index = -1
        self.consumer_index = -1
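
    # Illustrative sketch (comments only, not executed here) of how LayerLoadingEvent
    # and LayerDoneCounter cooperate, assuming a 32-layer model: the loading side
    # claims a producer slot and records per-layer events, while the forward pass
    # waits on the consumer slot before touching each layer's KV cache.
    #
    #   counter = LayerDoneCounter(num_layers=32)
    #   producer_id = counter.update_producer()    # claimed by the loading side
    #   counter.set_consumer(producer_id)          # scheduler binds the consumer
    #   # load stream: counter.events[producer_id].complete(i) after layer i is copied
    #   # forward pass: counter.wait_until(i) before layer i reads the loaded KV cache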


class CacheOperation:

    counter = 0

    def __init__(
        self,
        host_indices: torch.Tensor,
        device_indices: torch.Tensor,
        node_id: int,
        priority: Optional[int] = None,
    ):
        self.host_indices = host_indices
        self.device_indices = device_indices
        self.node_ids = [node_id]
        self.data = None

        self.id = CacheOperation.counter
        CacheOperation.counter += 1
        # default priority is the order of creation
        self.priority = priority if priority is not None else self.id

    @staticmethod
    def merge_ops(ops: List[CacheOperation]) -> CacheOperation:
        assert len(ops) > 0
        if len(ops) == 1:
            return ops[0]

        host_indices = torch.cat([op.host_indices for op in ops])
        device_indices = torch.cat([op.device_indices for op in ops])
        node_ids = []
        priority = min(op.priority for op in ops)
        for op in ops:
            node_ids.extend(op.node_ids)
        merged_op = CacheOperation(host_indices, device_indices, -1, priority)
        merged_op.node_ids = node_ids
        return merged_op

    def __lt__(self, other: CacheOperation):
        return self.priority < other.priority
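
    # Illustrative sketch (comments only, not executed here): merge_ops concatenates
    # the indices of several queued operations into one transfer while keeping every
    # node id, so a single launch can serve several cache nodes; `host_a`, `device_a`,
    # etc. are placeholder index tensors.
    #
    #   op_a = CacheOperation(host_a, device_a, node_id=1)
    #   op_b = CacheOperation(host_b, device_b, node_id=2)
    #   merged = CacheOperation.merge_ops([op_a, op_b])
    #   # merged.node_ids == [1, 2], merged.priority == min(op_a.priority, op_b.priority)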


class HiCacheAck(NamedTuple):
    start_event: torch.cuda.Event
    finish_event: torch.cuda.Event
    node_ids: List[int]


class TransferBuffer:
    """
    Overlapping buffer preparation and transfer operations to improve throughput.
    """

    def __init__(
        self, stop_event, buffer_count: int = 3, max_buffer_size: int = 1024
    ) -> None:
        self.stop_event = stop_event
        self.buffers = Queue(maxsize=buffer_count)
        # todo: adjust the buffer size based on throughput profile of the system
        self.max_buffer_size = max_buffer_size

    def full(self) -> bool:
        return self.buffers.full()

    def empty(self) -> bool:
        return self.buffers.empty()

    def put(self, item, block=True, timeout=1) -> None:
        while not self.stop_event.is_set():
            try:
                self.buffers.put(item, block=block, timeout=timeout)
                break
            except Full:
                if not block:
                    break
                continue
            except Exception as e:
                logger.error(e)

    def get(self, block=True, timeout=1) -> Optional[CacheOperation]:
        try:
            return self.buffers.get(block=block, timeout=timeout)
        except Empty:
            return None
        except Exception as e:
            logger.error(e)

    def clear(self):
        self.buffers.queue.clear()
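
    # Illustrative sketch (comments only, not executed here): TransferBuffer is a
    # bounded, stop-aware queue, so producer and transfer threads can exchange
    # batches without blocking forever across a shutdown.
    #
    #   stop = threading.Event()
    #   buffer = TransferBuffer(stop, buffer_count=3)
    #   buffer.put(op)            # retries while full, gives up once stop is set
    #   next_op = buffer.get()    # returns None if nothing arrives within the timeout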


class StorageOperation:
    counter = 0

    def __init__(
        self,
        host_indices: torch.Tensor,
        token_ids: List[int],
        last_hash: Optional[str] = None,
        hash_value: Optional[List[str]] = None,
    ):
        self.host_indices = host_indices
        self.token_ids = token_ids
        self.last_hash = last_hash
        self.completed_tokens = 0
        self.hash_value = hash_value if hash_value is not None else []

        self.id = StorageOperation.counter
        StorageOperation.counter += 1

    def __lt__(self, other: "StorageOperation"):
        return self.id < other.id


class PrefetchOperation(StorageOperation):
    def __init__(
        self,
        request_id: str,
        host_indices: torch.Tensor,
        token_ids: List[int],
        last_hash: Optional[str] = None,
    ):
        self.request_id = request_id

        self._lock = threading.Lock()
219
        self._terminated_flag = False
        self.start_time = time.monotonic()

        super().__init__(host_indices, token_ids, last_hash)

    def increment(self, num_tokens: int):
        with self._lock:
            if self._terminated_flag:
                return False
            self.completed_tokens += num_tokens
            return True

    def mark_terminate(self):
        with self._lock:
            self._terminated_flag = True

    def is_terminated(self) -> bool:
        return self._terminated_flag


class HiCacheController:

    def __init__(
        self,
        token_to_kv_pool_allocator: BaseTokenToKVPoolAllocator,
        mem_pool_host: HostKVCache,
        page_size: int,
        tp_group: torch.distributed.ProcessGroup,
        load_cache_event: threading.Event,
        write_policy: str = "write_through_selective",
        io_backend: str = "",
        storage_backend: Optional[str] = None,
        prefetch_threshold: int = 256,
        model_name: Optional[str] = None,
        storage_backend_extra_config: Optional[str] = None,
    ):
        self.mem_pool_device_allocator = token_to_kv_pool_allocator
        self.mem_pool_device = token_to_kv_pool_allocator.get_kvcache()
        self.mem_pool_host = mem_pool_host
        self.write_policy = write_policy
        self.page_size = page_size
        self.io_backend = io_backend
        self.enable_storage = False

        if storage_backend is not None:
            self.storage_backend_type = storage_backend
            from sglang.srt.mem_cache.hicache_storage import get_hash_str

            self.get_hash_str = get_hash_str
            self.storage_config = self._generate_storage_config(
                model_name, storage_backend_extra_config
            )
            # for MLA models, only one rank needs to backup the KV cache
            self.backup_skip = (
                self.storage_config.is_mla_model
                # todo: load balancing
                and self.storage_config.tp_rank != 0
            )

            # Use storage backend factory for dynamic backend creation
            from sglang.srt.mem_cache.storage import StorageBackendFactory

            try:
                self.storage_backend = StorageBackendFactory.create_backend(
                    storage_backend, self.storage_config, self.mem_pool_host
                )
            except ValueError as e:
                raise ValueError(f"Failed to create storage backend: {e}") from e

            self.storage_backend.register_mem_pool_host(self.mem_pool_host)

            self.enable_storage = True
            # todo: threshold policy for prefetching
            self.prefetch_threshold = max(prefetch_threshold, self.page_size)
            self.prefetch_capacity_limit = int(
                0.8 * (self.mem_pool_host.size - self.mem_pool_device.size)
            )
            # granularity of batch storage IO operations, in number of pages
            self.storage_batch_size = 128
            # tracking the number of tokens locked in prefetching, updated by the main scheduler thread
            self.prefetch_tokens_occupied = 0

            # create a new communication group for synchronizing storage operations across TP workers
            self.tp_world_size = torch.distributed.get_world_size(group=tp_group)
            if self.tp_world_size > 1:
                group_ranks = torch.distributed.get_process_group_ranks(tp_group)
                self.prefetch_tp_group = torch.distributed.new_group(
                    group_ranks, backend="gloo"
                )

            # Select the get and set functions
            self.page_get_func = self._generic_page_get
            self.page_set_func = self._generic_page_set

            if self.storage_backend_type in ["hf3fs", "mooncake"]:
                self.page_get_func = self._page_get_zero_copy
                self.page_set_func = self._page_set_zero_copy

        self.device = self.mem_pool_device.device
        self.layer_num = self.mem_pool_device.layer_num
        self.layer_done_counter = LayerDoneCounter(self.layer_num)
        self.mem_pool_device.register_layer_transfer_counter(self.layer_done_counter)

        if write_policy not in [
            "write_through",
            "write_through_selective",
            "write_back",
        ]:
            raise ValueError(f"Invalid write policy: {write_policy}")

        # self.write_queue = PriorityQueue[CacheOperation]()
        self.load_queue: List[CacheOperation] = []
        self.write_queue: List[CacheOperation] = []
        self.ack_load_queue: List[HiCacheAck] = []
        self.ack_write_queue: List[HiCacheAck] = []

        self.stop_event = threading.Event()
        self.write_buffer = TransferBuffer(self.stop_event)
        self.load_buffer = TransferBuffer(
            self.stop_event, buffer_count=10, max_buffer_size=100
        )

        self.write_stream = torch.cuda.Stream()
        self.load_stream = torch.cuda.Stream()

        if self.enable_storage:
            self.prefetch_thread = threading.Thread(
                target=self.prefetch_thread_func, daemon=True
            )
            self.backup_thread = threading.Thread(
                target=self.backup_thread_func, daemon=True
            )
            self.prefetch_queue = Queue()
            self.backup_queue = Queue()

            self.prefetch_revoke_queue = Queue()
            self.ack_backup_queue = Queue()
            self.host_mem_release_queue = Queue()

            self.prefetch_thread.start()
            self.backup_thread.start()

    def _generate_storage_config(
        self,
        model_name: Optional[str] = None,
        storage_backend_extra_config: Optional[str] = None,
    ):

        if is_dp_attention_enabled():
            self.tp_rank = get_attention_tp_rank()
            self.tp_size = get_attention_tp_size()
            self.dp_rank = get_attention_dp_rank()
        else:
            self.tp_rank = get_tensor_model_parallel_rank()
            self.tp_size = get_tensor_model_parallel_world_size()
            self.dp_rank = 0

        # Currently, AscendMLAPagedTokenToKVPool is the subclass of MLATokenToKVPool.
        is_mla_backend = isinstance(self.mem_pool_device, MLATokenToKVPool)

        # Parse extra config JSON if provided
        extra_config = None
        if storage_backend_extra_config:
            try:
                import json

                extra_config = json.loads(storage_backend_extra_config)
            except Exception as e:
                logger.error(f"Invalid backend extra config JSON: {e}")

        return HiCacheStorageConfig(
            tp_rank=self.tp_rank,
            tp_size=self.tp_size,
            is_mla_model=is_mla_backend,
            is_page_first_layout=self.mem_pool_host.layout == "page_first",
            model_name=model_name,
            extra_config=extra_config,
        )

    def reset(self):
        self.stop_event.set()

        self.write_queue.clear()
        self.load_queue.clear()
        self.write_buffer.clear()
        self.load_buffer.clear()
        self.ack_write_queue.clear()
        self.ack_load_queue.clear()
        if self.enable_storage:
            self.prefetch_thread.join()
            self.backup_thread.join()
            self.prefetch_queue.queue.clear()
            self.backup_queue.queue.clear()
            self.prefetch_revoke_queue.queue.clear()
            self.ack_backup_queue.queue.clear()

        self.stop_event.clear()

        if self.enable_storage:
            self.prefetch_thread = threading.Thread(
                target=self.prefetch_thread_func, daemon=True
            )
            self.backup_thread = threading.Thread(
                target=self.backup_thread_func, daemon=True
            )
            self.prefetch_thread.start()
            self.backup_thread.start()

    def write(
        self,
        device_indices: torch.Tensor,
        priority: Optional[int] = None,
        node_id: int = -1,
    ) -> Optional[torch.Tensor]:
        """
        Back up KV caches from device memory to host memory.
        """
        host_indices = self.mem_pool_host.alloc(len(device_indices))
        if host_indices is None:
            return None
        self.write_queue.append(
            CacheOperation(host_indices, device_indices, node_id, priority)
        )
        self.start_writing()
        return host_indices

    def start_writing(self) -> None:
        if len(self.write_queue) == 0:
            return

        op = CacheOperation.merge_ops(self.write_queue)
        host_indices, device_indices = self.move_indices(op)
        self.write_queue.clear()

        start_event = torch.cuda.Event()
        finish_event = torch.cuda.Event()

        start_event.record()
        with torch.cuda.stream(self.write_stream):
            start_event.wait(self.write_stream)
            self.mem_pool_host.backup_from_device_all_layer(
                self.mem_pool_device, host_indices, device_indices, self.io_backend
            )
            finish_event.record()
            # NOTE: We must save the host indices and device indices here,
            # this is because we need to guarantee that these tensors are
            # still alive when the write stream is executing.
            if host_indices.is_cuda:
                host_indices.record_stream(self.write_stream)
            if device_indices.is_cuda:
                device_indices.record_stream(self.write_stream)

        self.ack_write_queue.append(HiCacheAck(start_event, finish_event, op.node_ids))
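
    # Illustrative sketch (comments only, not executed here) of the device-to-host
    # write path, assuming `controller` is a HiCacheController and `device_indices`
    # belong to a cache node identified by `node_id`:
    #
    #   host_indices = controller.write(device_indices, node_id=node_id)
    #   if host_indices is None:
    #       ...  # host pool exhausted; the caller may evict host entries and retry
    #   # later, the caller drains controller.ack_write_queue and treats a node as
    #   # backed up once its HiCacheAck.finish_event has completed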

    def load(
        self,
        host_indices: torch.Tensor,
        priority: Optional[int] = None,
        node_id: int = -1,
    ) -> Optional[torch.Tensor]:
        """
        Load KV caches from host memory to device memory.
        """
        device_indices = self.mem_pool_device_allocator.alloc(len(host_indices))
        if device_indices is None:
            return None
        self.load_queue.append(
            CacheOperation(host_indices, device_indices, node_id, priority)
        )
        return device_indices

    def move_indices(self, op: CacheOperation):
        host_indices, device_indices = op.host_indices, op.device_indices
        # move indices to GPU if using kernels, to host if using direct indexing
        if self.io_backend == "kernel":
            if not host_indices.is_cuda:
                host_indices = host_indices.to(self.device, non_blocking=True)
            return host_indices, device_indices
        elif self.io_backend == "direct":
            if self.mem_pool_host.layout == "layer_first":
                device_indices = device_indices.cpu()
                host_indices, idx = host_indices.sort()
                return host_indices, device_indices.index_select(0, idx)
            elif self.mem_pool_host.layout == "page_first_direct":
                return host_indices, device_indices.cpu()
        else:
            raise ValueError(f"Unsupported io backend: {self.io_backend}")

    def start_loading(self) -> int:
        if len(self.load_queue) == 0:
            return -1

        producer_id = self.layer_done_counter.update_producer()
        op = CacheOperation.merge_ops(self.load_queue)
        host_indices, device_indices = self.move_indices(op)
        self.load_queue.clear()
        producer_event = self.layer_done_counter.events[producer_id]
        producer_event.start_event.record()

        with torch.cuda.stream(self.load_stream):
            producer_event.start_event.wait(self.load_stream)
            for i in range(self.layer_num):
                self.mem_pool_host.load_to_device_per_layer(
                    self.mem_pool_device,
                    host_indices,
                    device_indices,
                    i,
                    self.io_backend,
                )
                producer_event.complete(i)
            # NOTE: We must save the host indices and device indices here,
            # this is because we need to guarantee that these tensors are
            # still alive when the load stream is executing.
            if host_indices.is_cuda:
                host_indices.record_stream(self.load_stream)
            if device_indices.is_cuda:
                device_indices.record_stream(self.load_stream)

        self.ack_load_queue.append(
            HiCacheAck(
                start_event=producer_event.start_event,
                finish_event=producer_event.finish_event,
                node_ids=op.node_ids,
            )
        )
        return producer_id
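
    # Illustrative sketch (comments only, not executed here) of the host-to-device
    # load path with layer-wise overlap, assuming `controller` is a HiCacheController:
    #
    #   device_indices = controller.load(host_indices, node_id=node_id)
    #   producer_id = controller.start_loading()   # launches copies on load_stream
    #   controller.layer_done_counter.set_consumer(producer_id)
    #   # the forward pass then calls layer_done_counter.wait_until(i) before
    #   # layer i's attention kernels read the newly loaded KV cache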

    def evict_device(self, device_indices: torch.Tensor) -> int:
        self.mem_pool_device_allocator.free(device_indices)
        return len(device_indices)

    def evict_host(self, host_indices: torch.Tensor, backup_only: bool = True) -> int:
        if not backup_only:
            raise ValueError("Other eviction policies are not supported yet.")

        self.mem_pool_host.free(host_indices)
        return len(host_indices)

    def prefetch(
        self,
        request_id: str,
        host_indices: torch.Tensor,
        new_input_tokens: List[int],
        last_hash: Optional[str] = None,
    ) -> PrefetchOperation:
        """
        Prefetch KV caches from storage backend to host memory.
        """
        operation = PrefetchOperation(
            request_id, host_indices, new_input_tokens, last_hash
        )
        self.prefetch_queue.put(operation)
        return operation

    def terminate_prefetch(self, operation):
        operation.mark_terminate()
        return operation.completed_tokens, operation.hash_value
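
    # Illustrative sketch (comments only, not executed here) of the prefetch
    # lifecycle, assuming `controller` is a HiCacheController with a storage backend
    # enabled and `req_id` / `token_ids` are placeholders for the request being served:
    #
    #   op = controller.prefetch(req_id, host_indices, token_ids, last_hash)
    #   # the prefetch thread either revokes the request (prefetch_revoke_queue)
    #   # or starts pulling pages from storage into the reserved host indices
    #   completed_tokens, hash_value = controller.terminate_prefetch(op)
    #   # only the first `completed_tokens` tokens of host_indices hold valid data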

    def append_host_mem_release(self, host_indices: torch.Tensor):
        if host_indices.numel() == 0:
            return
        pages = host_indices.split(self.mem_pool_host.page_size)
        for page in pages:
            self.host_mem_release_queue.put(page)

    def _page_get_zero_copy(self, operation, hash_values, host_indices):
        results = self.storage_backend.batch_get_v1(hash_values, host_indices)
        inc = 0
        for i in range(len(hash_values)):
            if not results[i]:
                logger.warning(
                    f"Prefetch operation {operation.request_id} failed to retrieve page {hash_values[i]}."
                )
                break
            inc += self.page_size
        operation.increment(inc)

    # todo: deprecate
    def _generic_page_get(self, operation, hash_values, host_indices):
        dummy_page_dst = [
            self.mem_pool_host.get_dummy_flat_data_page() for _ in hash_values
        ]
        page_data = self.storage_backend.batch_get(hash_values, dummy_page_dst)
        if page_data is None:
            return
        for i in range(len(hash_values)):
            if page_data[i] is None:
                logger.warning(
                    f"Prefetch operation {operation.request_id} failed to retrieve page {hash_values[i]}."
                )
                break
            # Must set the data before increasing the completed tokens.
            # Otherwise this page may be read before being set.
            self.mem_pool_host.set_from_flat_data_page(
                host_indices[i * self.page_size],
                page_data[i],
            )
            if not operation.increment(self.page_size):
                break  # Operation terminated by controller

    def _page_transfer(self, operation):
        # Transfer batch by batch
        for i in range(0, len(operation.hash_value), self.storage_batch_size):
            batch_hashes = operation.hash_value[i : i + self.storage_batch_size]
            batch_host_indices = operation.host_indices[
                i * self.page_size : (i + len(batch_hashes)) * self.page_size
            ]
            prev_completed_tokens = operation.completed_tokens
            # Fetch one batch of pages and update completed_tokens on success
            self.page_get_func(operation, batch_hashes, batch_host_indices)
            # Check termination
            if (
                operation.completed_tokens
                != prev_completed_tokens + len(batch_hashes) * self.page_size
            ):
                operation.mark_terminate()
                break  # some pages failed, or the operation was terminated by the controller
        # release pre-allocated memory
        self.append_host_mem_release(
            operation.host_indices[operation.completed_tokens :]
        )

    def prefetch_io_aux_func(self):
        """
        Auxiliary function conducting IO operations for prefetching.
        """
        while not self.stop_event.is_set():
            try:
                operation = self.prefetch_buffer.get(block=True, timeout=1)
                self._page_transfer(operation)
                # operation terminated by controller, release pre-allocated memory
                self.append_host_mem_release(
                    operation.host_indices[operation.completed_tokens :]
                )
            except Empty:
                continue

    def prefetch_rate_limited(self) -> bool:
        """
        Rate limit the prefetching operations to avoid overwhelming the storage backend.
        """
        # cancel prefetch if too much memory is occupied
        if self.prefetch_tokens_occupied >= self.prefetch_capacity_limit:
            return True
        # todo: more sophisticated rate limiting based on storage backend performance
        return False

    def _storage_hit_query(self, operation) -> tuple[list[str], int]:
        last_hash = operation.last_hash
        tokens_to_fetch = operation.token_ids

        storage_query_count = 0
        hash_value = []

        for start in range(
            0, len(tokens_to_fetch), self.page_size * self.storage_batch_size
        ):
            end = min(
                start + self.page_size * self.storage_batch_size, len(tokens_to_fetch)
            )
            batch_tokens = tokens_to_fetch[start:end]
            batch_hashes = []
            for i in range(0, len(batch_tokens), self.page_size):
                last_hash = self.get_hash_str(
                    batch_tokens[i : i + self.page_size], last_hash
                )
                batch_hashes.append(last_hash)
            hit_page_num = self.storage_backend.batch_exists(batch_hashes)
            hash_value.extend(batch_hashes[:hit_page_num])
            storage_query_count += hit_page_num * self.page_size
            if hit_page_num < len(batch_hashes):
                break
        return hash_value, storage_query_count
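
    # Illustrative sketch (comments only, not executed here) of the page-hash
    # chaining used by _storage_hit_query: every page key depends on the previous
    # page's hash, so a storage lookup can only ever hit a prefix of the request.
    #
    #   last_hash = None
    #   for i in range(0, len(token_ids), page_size):
    #       last_hash = self.get_hash_str(token_ids[i : i + page_size], last_hash)
    #       # last_hash is now the storage key for page i // page_size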

    def prefetch_thread_func(self):
        """
        Manage prefetching operations from storage backend to host memory.
        """
        self.prefetch_buffer = Queue()
        aux_thread = threading.Thread(target=self.prefetch_io_aux_func, daemon=True)
        aux_thread.start()
        while (not self.stop_event.is_set()) or not self.prefetch_queue.empty():
            try:
                operation = self.prefetch_queue.get(block=True, timeout=1)
                if operation is None:
                    continue

                hash_value, storage_hit_count = self._storage_hit_query(operation)
                if self.tp_world_size > 1:
                    storage_hit_count_tensor = torch.tensor(
                        storage_hit_count, dtype=torch.int
                    )
                    torch.distributed.all_reduce(
                        storage_hit_count_tensor,
                        op=torch.distributed.ReduceOp.MIN,
                        group=self.prefetch_tp_group,
                    )
                    storage_hit_count = storage_hit_count_tensor.item()

                if storage_hit_count < self.prefetch_threshold:
                    # not to prefetch if not enough benefits
                    self.prefetch_revoke_queue.put(operation.request_id)
                    self.append_host_mem_release(operation.host_indices)
                    logger.debug(
                        f"Revoking prefetch for request {operation.request_id} due to insufficient hits ({storage_hit_count})."
                    )
                else:
                    operation.hash_value = hash_value[
                        : (storage_hit_count // self.page_size)
                    ]
                    # free the pre-allocated memory for pages that are not hit
                    self.append_host_mem_release(
                        operation.host_indices[storage_hit_count:]
                    )
                    operation.host_indices = operation.host_indices[:storage_hit_count]
                    logger.debug(
                        f"Prefetching {len(operation.hash_value)} pages for request {operation.request_id}."
                    )
                    self.prefetch_buffer.put(operation)

            except Empty:
                continue

    def write_storage(
        self,
        host_indices: torch.Tensor,
        token_ids: List[int],
        hash_value: Optional[List[str]] = None,
    ) -> int:
        """
        Write KV caches from host memory to storage backend.
        """
        operation = StorageOperation(host_indices, token_ids, hash_value=hash_value)
        self.backup_queue.put(operation)
        return operation.id
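
    # Illustrative sketch (comments only, not executed here) of the host-to-storage
    # backup path, assuming `controller` is a HiCacheController and `page_hashes`
    # are the keys computed for the pages being written:
    #
    #   op_id = controller.write_storage(host_indices, token_ids, hash_value=page_hashes)
    #   # the backup thread writes pages in batches of storage_batch_size and puts
    #   # the finished operation on controller.ack_backup_queue; the caller matches
    #   # it back to `op_id` (operation.id) before committing or freeing the node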

    # todo: deprecate
    def _generic_page_set(self, hash_values, host_indices) -> bool:
        data = [
            self.mem_pool_host.get_data_page(host_indices[i * self.page_size])
            for i in range(len(hash_values))
        ]
        return self.storage_backend.batch_set(hash_values, data)

    def _page_set_zero_copy(self, hash_values, host_indices) -> bool:
        return all(self.storage_backend.batch_set_v1(hash_values, host_indices))

    def _page_backup(self, operation):
        # Backup batch by batch
        for i in range(0, len(operation.hash_value), self.storage_batch_size):
            batch_hashes = operation.hash_value[i : i + self.storage_batch_size]
            batch_host_indices = operation.host_indices[
                i * self.page_size : (i + len(batch_hashes)) * self.page_size
            ]
            # Set one batch of pages and record whether it succeeded.
            # todo: allow partial success
            success = self.page_set_func(batch_hashes, batch_host_indices)
            if not success:
                logger.warning(
                    f"Failed to write {len(batch_hashes)} pages to the storage backend."
                )
                break
            operation.completed_tokens += self.page_size * len(batch_hashes)

    def backup_thread_func(self):
        """
        Manage backup operations from host memory to storage backend.
        """
        while not self.stop_event.is_set():
            try:
                operation = self.backup_queue.get(block=True, timeout=1)
                if operation is None:
                    continue

                if not self.backup_skip:
                    self._page_backup(operation)
                self.ack_backup_queue.put(operation)

            except Empty:
                continue