from __future__ import annotations

"""
Copyright 2023-2025 SGLang Team
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
    http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
"""

import logging
import math
import threading
import time
from queue import Empty, Full, PriorityQueue, Queue
from typing import TYPE_CHECKING, List, Optional

import torch

if TYPE_CHECKING:
    from sglang.srt.mem_cache.allocator import BaseTokenToKVPoolAllocator
    from sglang.srt.mem_cache.memory_pool_host import HostKVCache


logger = logging.getLogger(__name__)


class LayerDoneCounter:
    def __init__(self, num_layers):
        self.num_layers = num_layers
        # extra producer and consumer counters for overlap mode
        self.num_counters = 3
        self.counters = [num_layers] * self.num_counters
        self.conditions = [threading.Condition() for _ in range(self.num_counters)]
        self.producer_index = 0
        self.consumer_index = 0

    def next_producer(self):
        return (self.producer_index + 1) % self.num_counters

    def update_producer(self):
        self.producer_index = self.next_producer()
        return self.producer_index

    def set_consumer(self, index):
        self.consumer_index = index

    def increment(self):
        with self.conditions[self.producer_index]:
            self.counters[self.producer_index] += 1
            self.conditions[self.producer_index].notify_all()

    def wait_until(self, threshold):
        with self.conditions[self.consumer_index]:
            while self.counters[self.consumer_index] <= threshold:
                self.conditions[self.consumer_index].wait()

    def reset(self):
        with self.conditions[self.producer_index]:
            self.counters[self.producer_index] = 0
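    # Illustrative usage sketch (not part of the class API): a loader thread and a
    # compute thread typically pair up around one counter slot, assuming the loader
    # copies layer by layer while the forward pass consumes layers as they land:
    #
    #     counter = LayerDoneCounter(num_layers=2)
    #     idx = counter.update_producer()  # loader claims the next counter slot
    #     counter.set_consumer(idx)        # consumer follows the same slot
    #     counter.reset()                  # loader: no layers transferred yet
    #     counter.increment()              # loader: layer 0 is now on device
    #     counter.wait_until(0)            # consumer: returns once layer 0 is ready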


class CacheOperation:

    counter = 0

    def __init__(
        self,
        host_indices: torch.Tensor,
        device_indices: torch.Tensor,
        node_id: int,
        priority: Optional[int] = None,
    ):
        self.host_indices = host_indices
        self.device_indices = device_indices
        self.node_ids = [node_id]
        self.data = None

        self.id = CacheOperation.counter
        CacheOperation.counter += 1
        # default priority is the order of creation
        self.priority = priority if priority is not None else self.id

    def merge(self, other: "CacheOperation") -> None:
        # multiple operations can be merged into a single operation for batch processing
        self.host_indices = torch.cat([self.host_indices, other.host_indices])
        self.device_indices = torch.cat([self.device_indices, other.device_indices])
        self.priority = min(self.priority, other.priority)
        self.node_ids.extend(other.node_ids)

    def split(self, factor) -> List["CacheOperation"]:
        # split an operation into smaller operations to reduce the size of intermediate buffers
        if factor <= 1:
            return [self]

        chunk_size = math.ceil(len(self.host_indices) / factor)
        split_ops = []
        for i in range(0, len(self.host_indices), chunk_size):
            split_ops.append(
                CacheOperation(
                    host_indices=self.host_indices[i : i + chunk_size],
                    device_indices=self.device_indices[i : i + chunk_size],
                    node_id=0,
                )
            )
        # Inherit the node_ids on the final chunk
        if split_ops:
            split_ops[-1].node_ids = self.node_ids

        return split_ops
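    # Example of the intended split/merge round trip (hypothetical sizes): splitting a
    # 1000-token operation with factor=4 yields ceil(1000/4)=250-token chunks, and the
    # final chunk inherits the original node_ids so acknowledgements are only sent once
    # the whole transfer has finished. Merging chunks back concatenates the index
    # tensors and keeps the smallest (i.e. highest) priority.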

    def __lt__(self, other: "CacheOperation"):
        return self.priority < other.priority


class TransferBuffer:
    """
    Overlapping buffer preparation and transfer operations to improve throughput.
    """

    def __init__(
        self, stop_event, buffer_count: int = 3, max_buffer_size: int = 1024
    ) -> None:
        self.stop_event = stop_event
        self.buffers = Queue(maxsize=buffer_count)
        # todo: adjust the buffer size based on throughput profile of the system
        self.max_buffer_size = max_buffer_size

    def full(self) -> bool:
        return self.buffers.full()

    def empty(self) -> bool:
        return self.buffers.empty()

    def put(self, item, block=True, timeout=1) -> None:
        while not self.stop_event.is_set():
            try:
                self.buffers.put(item, block=block, timeout=timeout)
                break
            except Full:
                if not block:
                    break
                continue
            except Exception as e:
                logger.error(e)

    def get(self, block=True, timeout=1) -> Optional[CacheOperation]:
        try:
            return self.buffers.get(block=block, timeout=timeout)
        except Empty:
            return None
        except Exception as e:
            logger.error(e)

    def clear(self):
        self.buffers.queue.clear()
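    # Minimal usage sketch, assuming a producer thread and a consumer thread sharing
    # the same stop event (names below are illustrative only):
    #
    #     stop = threading.Event()
    #     buf = TransferBuffer(stop, buffer_count=3)
    #     buf.put(batch)    # producer: retries every second until space frees up
    #     item = buf.get()  # consumer: returns None if nothing arrives in time
    #     stop.set()        # unblocks a producer stuck waiting on a full queue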


class StorageOperation:
    counter = 0

    def __init__(
        self,
        host_indices: torch.Tensor,
        token_ids: List[int],
        last_hash: Optional[str] = None,
        hash_value: Optional[List[str]] = None,
    ):
        self.host_indices = host_indices
        self.token_ids = token_ids
        self.last_hash = last_hash
        self.completed_tokens = 0
        self.hash_value = hash_value if hash_value is not None else []

        self.id = StorageOperation.counter
        StorageOperation.counter += 1

    def __lt__(self, other: "StorageOperation"):
        return self.id < other.id


class PrefetchOperation(StorageOperation):
    def __init__(
        self,
        request_id: str,
        host_indices: torch.Tensor,
        token_ids: List[int],
        last_hash: Optional[str] = None,
    ):
        self.request_id = request_id

        self._done_flag = False
        self._lock = threading.Lock()

        self.start_time = time.monotonic()

        super().__init__(host_indices, token_ids, last_hash)

    def increment(self, num_tokens: int):
        with self._lock:
            if self._done_flag:
                return False
            self.completed_tokens += num_tokens
            return True

    def mark_done(self):
        with self._lock:
            self._done_flag = True

    def is_done(self) -> bool:
        return self._done_flag
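    # Note on the increment/mark_done handshake: increment() and mark_done() take the
    # same lock, so once the controller marks the operation done any in-flight transfer
    # thread sees increment() return False and stops counting tokens, e.g.:
    #
    #     if not operation.increment(page_size * len(page_hashes)):
    #         break  # controller terminated the prefetch; stop transferring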


class HiCacheController:

    def __init__(
        self,
        token_to_kv_pool_allocator: BaseTokenToKVPoolAllocator,
        mem_pool_host: HostKVCache,
        page_size: int,
        tp_group: torch.distributed.ProcessGroup,
        load_cache_event: Optional[threading.Event] = None,
        write_policy: str = "write_through_selective",
        io_backend: str = "",
        storage_backend: Optional[str] = None,
        prefetch_threshold: int = 256,
    ):
        self.mem_pool_device_allocator = token_to_kv_pool_allocator
        self.mem_pool_device = token_to_kv_pool_allocator.get_kvcache()
        self.mem_pool_host = mem_pool_host
        self.write_policy = write_policy
        self.page_size = page_size
        self.io_backend = io_backend

        self.enable_storage = False
        # todo: move backend initialization to storage backend module
        if storage_backend is not None:
            self.storage_backend_type = storage_backend
            from sglang.srt.mem_cache.hicache_storage import HiCacheFile, get_hash_str

            if storage_backend == "file":
                self.storage_backend = HiCacheFile()
                self.get_hash_str = get_hash_str
            elif storage_backend == "nixl":
                from sglang.srt.mem_cache.storage.nixl.hicache_nixl import HiCacheNixl

                self.storage_backend = HiCacheNixl()
                self.get_hash_str = get_hash_str
            elif storage_backend == "mooncake":
                from sglang.srt.mem_cache.storage.mooncake_store.mooncake_store import (
                    MooncakeStore,
                    get_hash_str_mooncake,
                )

                self.storage_backend = MooncakeStore()
                self.get_hash_str = get_hash_str_mooncake
                self.storage_backend.register_buffer(self.mem_pool_host.kv_buffer)
                assert self.mem_pool_host.layout == "page_first"
            elif storage_backend == "hf3fs":
                from sglang.srt.distributed import get_tensor_model_parallel_rank
                from sglang.srt.mem_cache.storage.hf3fs.storage_hf3fs import (
                    HiCacheHF3FS,
                )

                rank = get_tensor_model_parallel_rank()
                if self.mem_pool_host.layout == "page_first":
                    bytes_per_page = (
                        mem_pool_host.get_ksize_per_token() * mem_pool_host.page_size
                    )
                elif self.mem_pool_host.layout == "layer_first":
                    bytes_per_page = (
                        mem_pool_host.get_size_per_token() * mem_pool_host.page_size
                    )
                dtype = mem_pool_host.dtype
                self.storage_backend = HiCacheHF3FS.from_env_config(
                    rank, bytes_per_page, dtype
                )
                self.get_hash_str = get_hash_str
            else:
                raise NotImplementedError(
                    f"Unsupported storage backend: {storage_backend}"
                )
            self.enable_storage = True
            # todo: threshold policy for prefetching
            self.prefetch_threshold = max(prefetch_threshold, self.page_size)
            self.prefetch_capacity_limit = int(
                0.8 * (self.mem_pool_host.size - self.mem_pool_device.size)
            )
            # tracking the number of tokens locked in prefetching, updated by the main scheduler thread
            self.prefetch_tokens_occupied = 0

            # create a new communication group for synchronizing storage operations across TP workers
            self.tp_world_size = torch.distributed.get_world_size(group=tp_group)
            if self.tp_world_size > 1:
                group_ranks = torch.distributed.get_process_group_ranks(tp_group)
                self.prefetch_tp_group = torch.distributed.new_group(
                    group_ranks, backend="gloo"
                )
                self.prefetch_io_tp_group = torch.distributed.new_group(
                    group_ranks, backend="gloo"
                )
                self.backup_tp_group = torch.distributed.new_group(
                    group_ranks, backend="gloo"
                )

        self.load_cache_event = load_cache_event
        self.layer_done_counter = LayerDoneCounter(self.mem_pool_device.layer_num)
        self.mem_pool_device.register_layer_transfer_counter(self.layer_done_counter)

        if write_policy not in [
            "write_through",
            "write_through_selective",
            "write_back",
        ]:
            raise ValueError(f"Invalid write policy: {write_policy}")

        self.write_queue = PriorityQueue()
        self.load_queue = PriorityQueue()

        self.ack_write_queue = Queue()
        self.ack_load_queue = Queue()

        self.stop_event = threading.Event()
        self.write_buffer = TransferBuffer(self.stop_event)
        self.load_buffer = TransferBuffer(
            self.stop_event, buffer_count=10, max_buffer_size=100
        )

        self.write_stream = torch.cuda.Stream()
        self.load_stream = torch.cuda.Stream()

        self.write_thread = threading.Thread(
            target=self.write_thread_func_direct, daemon=True
        )
        self.load_thread = threading.Thread(
            target=self.load_thread_func_layer_by_layer, daemon=True
        )

        self.write_thread.start()
        self.load_thread.start()

        if self.enable_storage:
            self.prefetch_thread = threading.Thread(
                target=self.prefetch_thread_func, daemon=True
            )
            self.backup_thread = threading.Thread(
                target=self.backup_thread_func, daemon=True
            )
            self.prefetch_queue = Queue()
            self.backup_queue = Queue()

            self.prefetch_revoke_queue = Queue()
            self.ack_backup_queue = Queue()

            self.prefetch_thread.start()
            self.backup_thread.start()

    def reset(self):
        self.stop_event.set()
        self.write_thread.join()
        self.load_thread.join()

        self.write_queue.queue.clear()
        self.load_queue.queue.clear()
        self.write_buffer.clear()
        self.load_buffer.clear()
        self.ack_write_queue.queue.clear()
        self.ack_load_queue.queue.clear()
        if self.enable_storage:
            self.prefetch_thread.join()
            self.backup_thread.join()
            self.prefetch_queue.queue.clear()
            self.backup_queue.queue.clear()
            self.prefetch_revoke_queue.queue.clear()
            self.ack_backup_queue.queue.clear()

        self.write_thread = threading.Thread(
            target=self.write_thread_func_direct, daemon=True
        )
        self.load_thread = threading.Thread(
            target=self.load_thread_func_layer_by_layer, daemon=True
        )
        self.stop_event.clear()
        self.write_thread.start()
        self.load_thread.start()

        if self.enable_storage:
            self.prefetch_thread = threading.Thread(
                target=self.prefetch_thread_func, daemon=True
            )
            self.backup_thread = threading.Thread(
                target=self.backup_thread_func, daemon=True
            )
            self.prefetch_thread.start()
            self.backup_thread.start()

    def write(
        self,
        device_indices: torch.Tensor,
        priority: Optional[int] = None,
        node_id: int = 0,
    ) -> Optional[torch.Tensor]:
        """
        Back up KV caches from device memory to host memory.
        """
        host_indices = self.mem_pool_host.alloc(len(device_indices))
        if host_indices is None:
            return None
        self.mem_pool_host.protect_write(host_indices)
        torch.cuda.current_stream().synchronize()
        self.write_queue.put(
            CacheOperation(host_indices, device_indices, node_id, priority)
        )
        return host_indices

    def load(
        self,
        host_indices: torch.Tensor,
        priority: Optional[int] = None,
        node_id: int = 0,
    ) -> Optional[torch.Tensor]:
        """
        Load KV caches from host memory to device memory.
        """
        device_indices = self.mem_pool_device_allocator.alloc(len(host_indices))
        if device_indices is None:
            return None
        self.mem_pool_host.protect_load(host_indices)
        # to ensure the device indices are ready before accessed by another CUDA stream
        torch.cuda.current_stream().synchronize()
        self.load_queue.put(
            CacheOperation(host_indices, device_indices, node_id, priority)
        )
        return device_indices
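    # Illustrative write/load flow (a sketch, not a prescribed call sequence): the
    # scheduler typically backs a node up and later streams it back in, polling the
    # ack queues to learn when a node's transfer has completed. `controller` and
    # `node` below are hypothetical names for the surrounding scheduler state:
    #
    #     host_idx = controller.write(device_indices, node_id=node.id)  # device -> host
    #     dev_idx = controller.load(host_idx, node_id=node.id)          # host -> device
    #     controller.load_cache_event.set()                             # wake the load thread
    #     done_node = controller.ack_load_queue.get()                   # node.id once loaded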

    def move_indices(self, host_indices, device_indices):
        # move indices to GPU if using kernels, to host if using direct indexing
        if self.io_backend == "kernel":
            return host_indices.to(self.mem_pool_device.device), device_indices
        elif self.io_backend == "direct":
            device_indices = device_indices.cpu()
            host_indices, idx = host_indices.sort()
            return host_indices, device_indices.index_select(0, idx)
        else:
            raise ValueError(f"Unsupported io backend: {self.io_backend}")

    def write_thread_func_direct(self):
        """
        Directly write through KV caches to host memory without buffering.
        """
        torch.cuda.set_stream(self.write_stream)
        while not self.stop_event.is_set():
            try:
                operation = self.write_queue.get(block=True, timeout=1)
                host_indices, device_indices = self.move_indices(
                    operation.host_indices, operation.device_indices
                )
                self.mem_pool_host.backup_from_device_all_layer(
                    self.mem_pool_device, host_indices, device_indices, self.io_backend
                )
                self.write_stream.synchronize()
                self.mem_pool_host.complete_io(operation.host_indices)
                for node_id in operation.node_ids:
                    if node_id != 0:
                        self.ack_write_queue.put(node_id)
            except Empty:
                continue
            except Exception as e:
                logger.error(e)

    def load_thread_func_layer_by_layer(self):
        """
        Load KV caches from host memory to device memory layer by layer.
        """
        torch.cuda.set_stream(self.load_stream)
        while not self.stop_event.is_set():
            self.load_cache_event.wait(timeout=1)
            if not self.load_cache_event.is_set():
                continue
            self.load_cache_event.clear()
            self.layer_done_counter.update_producer()

            batch_operation = None
            while self.load_queue.qsize() > 0:
                op = self.load_queue.get(block=True)
                if batch_operation is None:
                    batch_operation = op
                else:
                    batch_operation.merge(op)
            if batch_operation is None:
                continue

            # start layer-wise KV cache transfer from CPU to GPU
            self.layer_done_counter.reset()
            host_indices, device_indices = self.move_indices(
                batch_operation.host_indices, batch_operation.device_indices
            )
            for i in range(self.mem_pool_host.layer_num):
                self.mem_pool_host.load_to_device_per_layer(
                    self.mem_pool_device,
                    host_indices,
                    device_indices,
                    i,
                    self.io_backend,
                )
                self.load_stream.synchronize()
                self.layer_done_counter.increment()

            self.mem_pool_host.complete_io(batch_operation.host_indices)
            for node_id in batch_operation.node_ids:
                if node_id != 0:
                    self.ack_load_queue.put(node_id)

    def evict_device(
        self, device_indices: torch.Tensor, host_indices: torch.Tensor
    ) -> int:
        if self.mem_pool_host.is_synced(host_indices):
            self.mem_pool_device_allocator.free(device_indices)
            self.mem_pool_host.update_backup(host_indices)
            return len(device_indices)
        else:
            raise ValueError(
                f"Inconsistent states: {self.mem_pool_host.get_state(host_indices)}"
            )

    def evict_host(self, host_indices: torch.Tensor, backup_only: bool = True) -> int:
        if not backup_only:
            raise ValueError("Other eviction policies are not supported yet.")

        if self.mem_pool_host.is_backup(host_indices):
            self.mem_pool_host.free(host_indices)
            return len(host_indices)
        else:
            raise ValueError(
                f"Inconsistent states: {self.mem_pool_host.get_state(host_indices)}"
            )

    def prefetch(
        self,
        request_id: str,
        host_indices: torch.Tensor,
        new_input_tokens: List[int],
        last_hash: Optional[str] = None,
    ) -> PrefetchOperation:
        """
        Prefetch KV caches from storage backend to host memory.
        """
        operation = PrefetchOperation(
            request_id, host_indices, new_input_tokens, last_hash
        )
        self.prefetch_queue.put(operation)
        return operation

    def terminate_prefetch(self, operation):
        operation.mark_done()
        return operation.completed_tokens, operation.hash_value
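    # Prefetch lifecycle sketch (assuming scheduler-side names, for illustration only):
    #
    #     op = controller.prefetch(req_id, host_indices, new_tokens, last_hash)
    #     ...  # later, when the request is scheduled
    #     done_tokens, hashes = controller.terminate_prefetch(op)
    #
    # terminate_prefetch() only marks the operation done; the IO thread stops at the
    # next batch boundary and frees host memory beyond op.completed_tokens.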

    def zerocopy_page_transfer(self, operation, batch_size=8):
        hashes, dsts = self.mem_pool_host.get_buffer_with_hash(
            operation.hash_value, operation.host_indices
        )
        for i in range(0, len(hashes), batch_size):
            page_hashes = hashes[i : i + batch_size]
            page_dsts = dsts[i : i + batch_size]
            page_data = self.storage_backend.batch_get(page_hashes, page_dsts)
            if page_data is None:
                logger.warning(
                    f"Prefetch operation {operation.request_id} failed to retrieve page {page_hashes}."
                )
                break
            if not operation.increment(self.page_size * len(page_hashes)):
                # operation has been terminated by the controller; stop copying
                break

    def generic_page_transfer(self, operation, batch_size=8):
        for i in range(0, len(operation.hash_value), batch_size):
            page_hashes = operation.hash_value[i : i + batch_size]
            # todo: zero copy
            dummy_page_dst = [
                self.mem_pool_host.get_dummy_flat_data_page()
                for _ in range(len(page_hashes))
            ]
            page_data = self.storage_backend.batch_get(page_hashes, dummy_page_dst)
            if page_data is None:
                logger.warning(
                    f"Prefetch operation {operation.request_id} failed to retrieve page {page_hashes}."
                )
                break
            completed_tokens = operation.completed_tokens
            if operation.increment(self.page_size * len(page_hashes)):
                # use a distinct loop variable so the outer chunk index `i` is not shadowed
                for j in range(len(page_hashes)):
                    self.mem_pool_host.set_from_flat_data_page(
                        operation.host_indices[completed_tokens],
                        page_data[j],
                    )
                    completed_tokens += self.page_size
            else:
                break

    def mooncake_page_transfer(self, operation):
        key_strs, buffer_ptrs, buffer_sizes = self.mem_pool_host.get_buffer_meta(
            operation.hash_value, operation.host_indices
        )
        self.storage_backend.batch_get(key_strs, buffer_ptrs, buffer_sizes)
        operation.increment(len(operation.hash_value) * self.page_size)

    def is_mooncake_backend(self):
        return self.storage_backend_type == "mooncake"

    def prefetch_io_aux_func(self):
        """
        Auxiliary function conducting IO operations for prefetching.
        """
        while not self.stop_event.is_set():
            try:
                operation = self.prefetch_buffer.get(block=True, timeout=1)
                if self.is_mooncake_backend():
                    self.mooncake_page_transfer(operation)
                elif self.storage_backend_type == "hf3fs":
                    if self.mem_pool_host.layout == "page_first":
                        self.zerocopy_page_transfer(operation, batch_size=128)
                    elif self.mem_pool_host.layout == "layer_first":
                        self.generic_page_transfer(operation, batch_size=128)
                else:
                    self.generic_page_transfer(operation)

                if self.tp_world_size > 1:
                    # to ensure all TP workers release the host memory at the same time
                    torch.distributed.barrier(group=self.prefetch_io_tp_group)
                # operation terminated by controller, release pre-allocated memory
                self.mem_pool_host.free(
                    operation.host_indices[operation.completed_tokens :]
                )
            except Empty:
                continue

    def prefetch_rate_limit_check(self) -> bool:
        """
        Rate limit the prefetching operations to avoid overwhelming the storage backend.
        """
        # cancel prefetch if too much memory is occupied
        if self.prefetch_tokens_occupied >= self.prefetch_capacity_limit:
            return False
        # todo: more sophisticated rate limiting based on storage backend performance
        return True
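    # Back-of-the-envelope example for the capacity limit (hypothetical sizes): with a
    # host pool of 1,000,000 tokens and a device pool of 200,000 tokens, the limit is
    # int(0.8 * (1,000,000 - 200,000)) = 640,000 tokens; a new prefetch is rejected
    # once prefetch_tokens_occupied reaches that figure.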

    def prefetch_thread_func(self):
        """
        Manage prefetching operations from storage backend to host memory.
        """
        self.prefetch_buffer = Queue()
        aux_thread = threading.Thread(target=self.prefetch_io_aux_func, daemon=True)
        aux_thread.start()
        while (not self.stop_event.is_set()) or not self.prefetch_queue.empty():
            try:
                operation = self.prefetch_queue.get(block=True, timeout=1)
                if operation is None:
                    continue

                storage_hit_count = 0
                if (
                    operation.host_indices is not None
                ) and self.prefetch_rate_limit_check():
                    last_hash = operation.last_hash
                    tokens_to_fetch = operation.token_ids

                    remaining_tokens = len(tokens_to_fetch)
                    hash_value = []
                    while remaining_tokens >= self.page_size:
                        last_hash = self.get_hash_str(
                            tokens_to_fetch[
                                storage_hit_count : storage_hit_count + self.page_size
                            ],
                            last_hash,
                        )

                        # todo, more unified interface
                        if not self.is_mooncake_backend():
                            if not self.storage_backend.exists(last_hash):
                                break
                        hash_value.append(last_hash)
                        storage_hit_count += self.page_size
                        remaining_tokens -= self.page_size

                    if self.is_mooncake_backend():
                        # deferring to batch exists for mooncake store
                        exist_result = self.storage_backend.exists(hash_value)
                        storage_hit_count = (
                            sum(1 for v in exist_result.values() if v != 0)
                            * self.page_size
                        )

                if self.tp_world_size > 1:
                    storage_hit_count_tensor = torch.tensor(
                        storage_hit_count, dtype=torch.int
                    )
                    torch.distributed.all_reduce(
                        storage_hit_count_tensor,
                        op=torch.distributed.ReduceOp.MIN,
                        group=self.prefetch_tp_group,
                    )
                    storage_hit_count = storage_hit_count_tensor.item()

                if storage_hit_count < self.prefetch_threshold:
                    # not to prefetch if not enough benefits
                    self.prefetch_revoke_queue.put(operation.request_id)
                    if operation.host_indices is not None:
                        self.mem_pool_host.free(operation.host_indices)
                    logger.debug(
                        f"Revoking prefetch for request {operation.request_id} due to insufficient hits ({storage_hit_count})."
                    )
                else:
                    operation.hash_value = hash_value[
                        : (storage_hit_count // self.page_size)
                    ]
                    # free the pre-allocated memory for pages that are not hit
                    self.mem_pool_host.free(operation.host_indices[storage_hit_count:])
                    operation.host_indices = operation.host_indices[:storage_hit_count]
                    logger.debug(
                        f"Prefetching {len(operation.hash_value)} pages for request {operation.request_id}."
                    )
                    self.prefetch_buffer.put(operation)

            except Empty:
                continue

    def write_storage(
        self,
        host_indices: torch.Tensor,
        token_ids: List[int],
        hash_value: Optional[List[str]] = None,
    ) -> int:
        """
        Write KV caches from host memory to storage backend.
        """
        operation = StorageOperation(host_indices, token_ids, hash_value=hash_value)
        self.backup_queue.put(operation)
        return operation.id
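    # Backup flow sketch (illustrative names): the caller enqueues a backup and later
    # matches the acknowledgement against the returned operation id:
    #
    #     op_id = controller.write_storage(host_indices, token_ids, hash_value=hashes)
    #     ack_id, completed = controller.ack_backup_queue.get()
    #     assert ack_id == op_id  # completed is the TP-wide minimum of stored tokens
    #                             # when tp_world_size > 1, else the local count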

    def zerocopy_page_backup(self, operation, batch_size=8):
        hashes, dsts = self.mem_pool_host.get_buffer_with_hash(
            operation.hash_value, operation.host_indices
        )
        for i in range(0, len(hashes), batch_size):
            page_hashes = hashes[i : i + batch_size]
            page_data = dsts[i : i + batch_size]
            success = self.storage_backend.batch_set(page_hashes, page_data)
            if not success:
                logger.warning(f"Failed to write page {page_hashes} to storage.")
                break
            operation.completed_tokens += self.page_size * len(page_hashes)

    def generic_page_backup(self, operation, batch_size=8):
        for i in range(0, len(operation.hash_value), batch_size):
            page_hashes = operation.hash_value[i : i + batch_size]
            page_data = [
                self.mem_pool_host.get_flat_data_page(
                    operation.host_indices[j * self.page_size]
                )
                for j in range(i, i + len(page_hashes))
            ]
            success = self.storage_backend.batch_set(page_hashes, page_data)
            if not success:
                logger.warning(f"Failed to write page {page_hashes} to storage.")
                break
            operation.completed_tokens += self.page_size * len(page_hashes)

    def mooncake_page_backup(self, operation):
        if len(operation.hash_value):
            exist_hashvalues = self.storage_backend.exists(operation.hash_value)
            indices = operation.host_indices.tolist()
            non_exist_keys = []
            non_exist_indices = []
            for i in range(len(operation.hash_value)):
                if not exist_hashvalues[operation.hash_value[i]]:
                    non_exist_keys.append(operation.hash_value[i])
                    non_exist_indices.extend(
                        indices[i * self.page_size : (i + 1) * self.page_size]
                    )
            if len(non_exist_keys) > 0:
                key_strs, buffer_ptrs, buffer_sizes = (
                    self.mem_pool_host.get_buffer_meta(
                        non_exist_keys, non_exist_indices
                    )
                )
                # TODO: check the return value of batch set to see how many tokens are set successfully
                self.storage_backend.batch_set(
                    key_strs,
                    target_location=buffer_ptrs,
                    target_sizes=buffer_sizes,
                )
        operation.completed_tokens += len(operation.hash_value) * self.page_size

    def backup_thread_func(self):
        """
        Manage backup operations from host memory to storage backend.
        """
        while not self.stop_event.is_set():
            try:
                operation = self.backup_queue.get(block=True, timeout=1)
                if operation is None:
                    continue

                if self.is_mooncake_backend():
                    self.mooncake_page_backup(operation)
                elif self.storage_backend_type == "hf3fs":
                    if self.mem_pool_host.layout == "page_first":
                        self.zerocopy_page_backup(operation, batch_size=128)
                    elif self.mem_pool_host.layout == "layer_first":
                        self.generic_page_backup(operation, batch_size=128)
                else:
                    self.generic_page_backup(operation)

                min_completed_tokens = operation.completed_tokens
                if self.tp_world_size > 1:
                    completed_tokens_tensor = torch.tensor(
                        min_completed_tokens, dtype=torch.int
                    )
                    torch.distributed.all_reduce(
                        completed_tokens_tensor,
                        op=torch.distributed.ReduceOp.MIN,
                        group=self.backup_tp_group,
                    )
                    min_completed_tokens = completed_tokens_tensor.item()

                self.ack_backup_queue.put(
                    (
                        operation.id,
                        min_completed_tokens,
                    )
                )

            except Empty:
                continue