"vscode:/vscode.git/clone" did not exist on "12c6eb265d2bd75947a8253f45bb60c89e6e71b8"
cache_controller.py 30.3 KB
Newer Older
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
from __future__ import annotations

"""
Copyright 2023-2025 SGLang Team
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
    http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
"""

import logging
import math
import threading
import time
from queue import Empty, Full, PriorityQueue, Queue
from typing import TYPE_CHECKING, List, Optional

import torch

if TYPE_CHECKING:
    from sglang.srt.mem_cache.allocator import BaseTokenToKVPoolAllocator
    from sglang.srt.mem_cache.memory_pool_host import HostKVCache

logger = logging.getLogger(__name__)


class LayerDoneCounter:
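    """
    Tracks how many layers of a transfer batch have completed so that consumers
    can wait until a specific layer is ready. Several counter/condition pairs are
    kept so producers and consumers can work on different batches in overlap mode.
    """
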
    def __init__(self, num_layers):
        self.num_layers = num_layers
        # extra producer and consumer counters for overlap mode
        self.num_counters = 3
        self.counters = [num_layers] * self.num_counters
        self.conditions = [threading.Condition() for _ in range(self.num_counters)]
        self.producer_index = 0
        self.consumer_index = 0

    def next_producer(self):
        return (self.producer_index + 1) % self.num_counters

    def update_producer(self):
        self.producer_index = self.next_producer()
        return self.producer_index

    def set_consumer(self, index):
        self.consumer_index = index

    def increment(self):
        with self.conditions[self.producer_index]:
            self.counters[self.producer_index] += 1
            self.conditions[self.producer_index].notify_all()

    def wait_until(self, threshold):
        with self.conditions[self.consumer_index]:
            while self.counters[self.consumer_index] <= threshold:
                self.conditions[self.consumer_index].wait()

    def reset(self):
        with self.conditions[self.producer_index]:
            self.counters[self.producer_index] = 0


class CacheOperation:
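    """
    A host/device KV cache transfer request over matching host and device indices.
    Operations can be merged for batched execution or split into smaller chunks,
    and are ordered by priority (lower first, defaulting to creation order).
    """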

    counter = 0

    def __init__(
        self,
        host_indices: torch.Tensor,
        device_indices: torch.Tensor,
        node_id: int,
        priority: Optional[int] = None,
    ):
        self.host_indices = host_indices
        self.device_indices = device_indices
        self.node_ids = [node_id]
        self.data = None

        self.id = CacheOperation.counter
        CacheOperation.counter += 1
        # default priority is the order of creation
        self.priority = priority if priority is not None else self.id

    def merge(self, other: "CacheOperation") -> None:
        # multiple operations can be merged into a single operation for batch processing
        self.host_indices = torch.cat([self.host_indices, other.host_indices])
        self.device_indices = torch.cat([self.device_indices, other.device_indices])
        self.priority = min(self.priority, other.priority)
        self.node_ids.extend(other.node_ids)

    def split(self, factor) -> List["CacheOperation"]:
        # split an operation into smaller operations to reduce the size of intermediate buffers
        if factor <= 1:
            return [self]

        chunk_size = math.ceil(len(self.host_indices) / factor)
        split_ops = []
        for i in range(0, len(self.host_indices), chunk_size):
            split_ops.append(
                CacheOperation(
                    host_indices=self.host_indices[i : i + chunk_size],
                    device_indices=self.device_indices[i : i + chunk_size],
                    node_id=0,
                )
            )
        # Inherit the node_ids on the final chunk
        if split_ops:
            split_ops[-1].node_ids = self.node_ids

        return split_ops

    def __lt__(self, other: "CacheOperation"):
        return self.priority < other.priority


class TransferBuffer:
    """
    Overlapping buffer preparation and transfer operations to improve throughput.
    """

    def __init__(
        self, stop_event, buffer_count: int = 3, max_buffer_size: int = 1024
    ) -> None:
        self.stop_event = stop_event
        self.buffers = Queue(maxsize=buffer_count)
        # todo: adjust the buffer size based on throughput profile of the system
        self.max_buffer_size = max_buffer_size

    def full(self) -> bool:
        return self.buffers.full()

    def empty(self) -> bool:
        return self.buffers.empty()

    def put(self, item, block=True, timeout=1) -> None:
        while not self.stop_event.is_set():
            try:
                self.buffers.put(item, block=block, timeout=timeout)
                break
            except Full:
                if not block:
                    break
                continue
            except Exception as e:
                logger.error(e)

    def get(self, block=True, timeout=1) -> Optional[CacheOperation]:
        try:
            return self.buffers.get(block=block, timeout=timeout)
        except Empty:
            return None
        except Exception as e:
            logger.error(e)

    def clear(self):
        self.buffers.queue.clear()


class StorageOperation:
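    """
    A request to back up the KV cache pages covered by the given host indices and
    token ids from host memory to the storage backend. Operations are ordered by
    creation id.
    """
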
    counter = 0

    def __init__(
        self,
        host_indices: torch.Tensor,
        token_ids: List[int],
        last_hash: Optional[str] = None,
    ):
        self.host_indices = host_indices
        self.token_ids = token_ids
        self.last_hash = last_hash
        self.completed_tokens = 0
        self.hash_value = []

        self.id = StorageOperation.counter
        StorageOperation.counter += 1

    def __lt__(self, other: "StorageOperation"):
        return self.id < other.id


class PrefetchOperation(StorageOperation):
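    """
    A request to fetch KV cache pages from the storage backend into host memory.
    `increment` and `mark_done` are lock-protected so the controller can terminate
    a prefetch while the IO thread is still filling pages.
    """
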
    def __init__(
        self,
        request_id: str,
        host_indices: torch.Tensor,
        token_ids: List[int],
        last_hash: Optional[str] = None,
    ):
        self.request_id = request_id

        self._done_flag = False
        self._lock = threading.Lock()

        self.start_time = time.monotonic()

        super().__init__(host_indices, token_ids, last_hash)

    def increment(self, num_tokens: int):
        with self._lock:
            if self._done_flag:
                return False
            self.completed_tokens += num_tokens
            return True

    def mark_done(self):
        with self._lock:
            self._done_flag = True

    def is_done(self) -> bool:
        return self._done_flag


class HiCacheController:
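    """
    Coordinates KV cache movement between device memory, host memory, and an
    optional storage backend, using dedicated daemon threads for device-host
    write/load traffic and, when storage is enabled, prefetch and backup traffic.
    """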

    def __init__(
        self,
        token_to_kv_pool_allocator: BaseTokenToKVPoolAllocator,
        mem_pool_host: HostKVCache,
        page_size: int,
        tp_group: torch.distributed.ProcessGroup,
        load_cache_event: Optional[threading.Event] = None,
        write_policy: str = "write_through_selective",
        io_backend: str = "",
        storage_backend: Optional[str] = None,
        prefetch_threshold: int = 256,
    ):
        self.mem_pool_device_allocator = token_to_kv_pool_allocator
        self.mem_pool_device = token_to_kv_pool_allocator.get_kvcache()
        self.mem_pool_host = mem_pool_host
        self.write_policy = write_policy
        self.page_size = page_size
        self.io_backend = io_backend

        self.enable_storage = False
        # todo: move backend initialization to storage backend module
        if storage_backend is not None:
            self.storage_backend_type = storage_backend
            from sglang.srt.mem_cache.hicache_storage import HiCacheFile, get_hash_str

            if storage_backend == "file":
                self.storage_backend = HiCacheFile()
                self.get_hash_str = get_hash_str
            elif storage_backend == "nixl":
                from sglang.srt.mem_cache.storage.nixl.hicache_nixl import HiCacheNixl

                self.storage_backend = HiCacheNixl()
                self.get_hash_str = get_hash_str
            elif storage_backend == "mooncake":
                from sglang.srt.mem_cache.storage.mooncake_store.mooncake_store import (
                    MooncakeStore,
                    get_hash_str_mooncake,
                )

                self.storage_backend = MooncakeStore()
                self.get_hash_str = get_hash_str_mooncake
                self.storage_backend.register_buffer(self.mem_pool_host.kv_buffer)
            elif storage_backend == "hf3fs":
                from sglang.srt.distributed import get_tensor_model_parallel_rank
                from sglang.srt.mem_cache.storage.hf3fs.storage_hf3fs import (
                    HiCacheHF3FS,
                )

                rank = get_tensor_model_parallel_rank()
                bytes_per_page = (
                    mem_pool_host.get_size_per_token() * mem_pool_host.page_size
                )
                dtype = mem_pool_host.dtype
                self.storage_backend = HiCacheHF3FS.from_env_config(
                    rank, bytes_per_page, dtype
                )
                self.get_hash_str = get_hash_str
            else:
                raise NotImplementedError(
                    f"Unsupported storage backend: {storage_backend}"
                )
            self.enable_storage = True
            # todo: threshold policy for prefetching
            self.prefetch_threshold = max(prefetch_threshold, self.page_size)
            self.prefetch_capacity_limit = int(
                0.8 * (self.mem_pool_host.size - self.mem_pool_device.size)
            )
            # tracking the number of tokens locked in prefetching, updated by the main scheduler thread
            self.prefetch_tokens_occupied = 0

            # create a new communication group for synchronizing storage operations across TP workers
            self.tp_world_size = torch.distributed.get_world_size(group=tp_group)
            if self.tp_world_size > 1:
                group_ranks = torch.distributed.get_process_group_ranks(tp_group)
                self.prefetch_tp_group = torch.distributed.new_group(
                    group_ranks, backend="gloo"
                )
                self.backup_tp_group = torch.distributed.new_group(
                    group_ranks, backend="gloo"
                )

        self.load_cache_event = load_cache_event
        self.layer_done_counter = LayerDoneCounter(self.mem_pool_device.layer_num)
        self.mem_pool_device.register_layer_transfer_counter(self.layer_done_counter)

        if write_policy not in [
            "write_through",
            "write_through_selective",
            "write_back",
        ]:
            raise ValueError(f"Invalid write policy: {write_policy}")

        self.write_queue = PriorityQueue()
        self.load_queue = PriorityQueue()

        self.ack_write_queue = Queue()
        self.ack_load_queue = Queue()

        self.stop_event = threading.Event()
        self.write_buffer = TransferBuffer(self.stop_event)
        self.load_buffer = TransferBuffer(
            self.stop_event, buffer_count=10, max_buffer_size=100
        )

        self.write_stream = torch.cuda.Stream()
        self.load_stream = torch.cuda.Stream()

        self.write_thread = threading.Thread(
            target=self.write_thread_func_direct, daemon=True
        )
        self.load_thread = threading.Thread(
            target=self.load_thread_func_layer_by_layer, daemon=True
        )

        self.write_thread.start()
        self.load_thread.start()

        if self.enable_storage:
            self.prefetch_thread = threading.Thread(
                target=self.prefetch_thread_func, daemon=True
            )
            self.backup_thread = threading.Thread(
                target=self.backup_thread_func, daemon=True
            )
            self.prefetch_queue = Queue()
            self.backup_queue = Queue()

            self.prefetch_revoke_queue = Queue()
            self.ack_backup_queue = Queue()

            self.prefetch_thread.start()
            self.backup_thread.start()

    def reset(self):
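        """
        Stop the worker threads, clear all queues and buffers, then restart the
        threads so the controller returns to a clean state.
        """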
        self.stop_event.set()
        self.write_thread.join()
        self.load_thread.join()

        self.write_queue.queue.clear()
        self.load_queue.queue.clear()
        self.write_buffer.clear()
        self.load_buffer.clear()
        self.ack_write_queue.queue.clear()
        self.ack_load_queue.queue.clear()
        if self.enable_storage:
            self.prefetch_thread.join()
            self.backup_thread.join()
            self.prefetch_queue.queue.clear()
            self.backup_queue.queue.clear()
            self.prefetch_revoke_queue.queue.clear()
            self.ack_backup_queue.queue.clear()

        self.write_thread = threading.Thread(
            target=self.write_thread_func_direct, daemon=True
        )
        self.load_thread = threading.Thread(
            target=self.load_thread_func_layer_by_layer, daemon=True
        )
        self.stop_event.clear()
        self.write_thread.start()
        self.load_thread.start()

        if self.enable_storage:
            self.prefetch_thread = threading.Thread(
                target=self.prefetch_thread_func, daemon=True
            )
            self.backup_thread = threading.Thread(
                target=self.backup_thread_func, daemon=True
            )
            self.prefetch_thread.start()
            self.backup_thread.start()

    def write(
        self,
        device_indices: torch.Tensor,
        priority: Optional[int] = None,
        node_id: int = 0,
    ) -> Optional[torch.Tensor]:
        """
        Back up KV caches from device memory to host memory.
        """
        host_indices = self.mem_pool_host.alloc(len(device_indices))
        if host_indices is None:
            return None
        self.mem_pool_host.protect_write(host_indices)
        torch.cuda.current_stream().synchronize()
        self.write_queue.put(
            CacheOperation(host_indices, device_indices, node_id, priority)
        )
        return host_indices

    def load(
        self,
        host_indices: torch.Tensor,
        priority: Optional[int] = None,
        node_id: int = 0,
    ) -> Optional[torch.Tensor]:
        """
        Load KV caches from host memory to device memory.
        """
        device_indices = self.mem_pool_device_allocator.alloc(len(host_indices))
        if device_indices is None:
            return None
        self.mem_pool_host.protect_load(host_indices)
        # to ensure the device indices are ready before accessed by another CUDA stream
        torch.cuda.current_stream().synchronize()
        self.load_queue.put(
            CacheOperation(host_indices, device_indices, node_id, priority)
        )
        return device_indices

    def move_indices(self, host_indices, device_indices):
        # move indices to GPU if using kernels, to host if using direct indexing
        if self.io_backend == "kernel":
            return host_indices.to(self.mem_pool_device.device), device_indices
        elif self.io_backend == "direct":
            device_indices = device_indices.cpu()
            host_indices, idx = host_indices.sort()
            return host_indices, device_indices.index_select(0, idx)
        else:
            raise ValueError(f"Unsupported io backend: {self.io_backend}")

    def write_thread_func_direct(self):
        """
        Directly write through KV caches to host memory without buffering.
        """
        torch.cuda.set_stream(self.write_stream)
        while not self.stop_event.is_set():
            try:
                operation = self.write_queue.get(block=True, timeout=1)
                host_indices, device_indices = self.move_indices(
                    operation.host_indices, operation.device_indices
                )
                self.mem_pool_host.backup_from_device_all_layer(
                    self.mem_pool_device, host_indices, device_indices, self.io_backend
                )
                self.write_stream.synchronize()
                self.mem_pool_host.complete_io(operation.host_indices)
                for node_id in operation.node_ids:
                    if node_id != 0:
                        self.ack_write_queue.put(node_id)
            except Empty:
                continue
            except Exception as e:
                logger.error(e)

    def load_thread_func_layer_by_layer(self):
        """
        Load KV caches from host memory to device memory layer by layer.
        """
        torch.cuda.set_stream(self.load_stream)
        while not self.stop_event.is_set():
            self.load_cache_event.wait(timeout=1)
            if not self.load_cache_event.is_set():
                continue
            self.load_cache_event.clear()
            self.layer_done_counter.update_producer()

            batch_operation = None
            while self.load_queue.qsize() > 0:
                op = self.load_queue.get(block=True)
                if batch_operation is None:
                    batch_operation = op
                else:
                    batch_operation.merge(op)
            if batch_operation is None:
                continue

            # start layer-wise KV cache transfer from CPU to GPU
            self.layer_done_counter.reset()
            host_indices, device_indices = self.move_indices(
                batch_operation.host_indices, batch_operation.device_indices
            )
            for i in range(self.mem_pool_host.layer_num):
                self.mem_pool_host.load_to_device_per_layer(
                    self.mem_pool_device,
                    host_indices,
                    device_indices,
                    i,
                    self.io_backend,
                )
                self.load_stream.synchronize()
                self.layer_done_counter.increment()

            self.mem_pool_host.complete_io(batch_operation.host_indices)
            for node_id in batch_operation.node_ids:
                if node_id != 0:
                    self.ack_load_queue.put(node_id)

    def evict_device(
        self, device_indices: torch.Tensor, host_indices: torch.Tensor
    ) -> int:
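        """
        Free device memory for entries whose host copy is synced, marking the host
        copy as the backup. Returns the number of evicted tokens.
        """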
        if self.mem_pool_host.is_synced(host_indices):
            self.mem_pool_device_allocator.free(device_indices)
            self.mem_pool_host.update_backup(host_indices)
            return len(device_indices)
        else:
            raise ValueError(
                f"Inconsistent states: {self.mem_pool_host.get_state(host_indices)}"
            )

    def evict_host(self, host_indices: torch.Tensor, backup_only: bool = True) -> int:
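        """
        Free host memory for entries that only hold a backup copy. Returns the
        number of evicted tokens.
        """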
        if not backup_only:
            raise ValueError("Other eviction policies are not supported yet.")

        if self.mem_pool_host.is_backup(host_indices):
            self.mem_pool_host.free(host_indices)
            return len(host_indices)
        else:
            raise ValueError(
                f"Inconsistent states: {self.mem_pool_host.get_state(host_indices)}"
            )

    def prefetch(
        self,
        request_id: str,
        host_indices: torch.Tensor,
        new_input_tokens: List[int],
        last_hash: Optional[str] = None,
    ) -> PrefetchOperation:
        """
        Prefetch KV caches from storage backend to host memory.
        """
        operation = PrefetchOperation(
            request_id, host_indices, new_input_tokens, last_hash
        )
        self.prefetch_queue.put(operation)
        return operation

    def terminate_prefetch(self, operation):
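        """
        Mark a prefetch operation as done and return the number of tokens fetched
        so far along with their page hashes.
        """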
        operation.mark_done()
        return operation.completed_tokens, operation.hash_value

    def generic_page_transfer(self, operation, batch_size=8):
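        """
        Fetch pages from the storage backend in batches and copy them into the
        pre-allocated host indices, stopping early if a batch cannot be retrieved
        or the operation has been terminated.
        """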
        for i in range(0, len(operation.hash_value), batch_size):
            page_hashes = operation.hash_value[i : i + batch_size]
            # todo: zero copy
            dummy_page_dst = [self.mem_pool_host.get_dummy_flat_data_page()] * len(
                page_hashes
            )
            page_data = self.storage_backend.batch_get(page_hashes, dummy_page_dst)
            if page_data is None:
                logger.warning(
                    f"Prefetch operation {operation.request_id} failed to retrieve page {page_hashes}."
                )
                break
            completed_tokens = operation.completed_tokens
            if operation.increment(self.page_size * len(page_hashes)):
                for j in range(len(page_hashes)):
                    self.mem_pool_host.set_from_flat_data_page(
                        operation.host_indices[completed_tokens],
                        page_data[j],
                    )
                    completed_tokens += self.page_size
            else:
                break

    def mooncake_page_transfer(self, operation):
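        """
        Fetch pages for the Mooncake backend by passing host buffer pointers and
        sizes to batch_get, so the data lands directly in the host KV buffers.
        """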
        key_strs, buffer_ptrs, buffer_sizes = self.mem_pool_host.get_buffer_meta(
            operation.hash_value, operation.host_indices
        )
        self.storage_backend.batch_get(key_strs, buffer_ptrs, buffer_sizes)
        operation.increment(len(operation.hash_value) * self.page_size)

    def is_mooncake_backend(self):
        return self.storage_backend_type == "mooncake"

    def prefetch_io_aux_func(self):
        """
        Auxiliary function conducting IO operations for prefetching.
        """
        while not self.stop_event.is_set():
            try:
                operation = self.prefetch_buffer.get(block=True, timeout=1)
                if self.is_mooncake_backend():
                    self.mooncake_page_transfer(operation)
                elif self.storage_backend_type == "hf3fs":
                    self.generic_page_transfer(operation, batch_size=128)
                else:
                    self.generic_page_transfer(operation)

                if self.tp_world_size > 1:
                    # to ensure all TP workers release the host memory at the same time
                    torch.distributed.barrier(group=self.prefetch_tp_group)
                # operation terminated by controller, release pre-allocated memory
                self.mem_pool_host.free(
                    operation.host_indices[operation.completed_tokens :]
                )
            except Empty:
                continue

    def prefetch_rate_limit_check(self) -> bool:
        """
        Rate limit the prefetching operations to avoid overwhelming the storage backend.
        """
        # cancel prefetch if too much memory is occupied
        if self.prefetch_tokens_occupied >= self.prefetch_capacity_limit:
            return False
        # todo: more sophisticated rate limiting based on storage backend performance
        return True

    def prefetch_thread_func(self):
        """
        Manage prefetching operations from storage backend to host memory.
        """
        self.prefetch_buffer = Queue()
        aux_thread = threading.Thread(target=self.prefetch_io_aux_func, daemon=True)
        aux_thread.start()
        while (not self.stop_event.is_set()) or not self.prefetch_queue.empty():
            try:
                operation = self.prefetch_queue.get(block=True, timeout=1)
                if operation is None:
                    continue

                storage_hit_count = 0
                if (
                    operation.host_indices is not None
                ) and self.prefetch_rate_limit_check():
                    last_hash = operation.last_hash
                    tokens_to_fetch = operation.token_ids

                    remaining_tokens = len(tokens_to_fetch)
                    hash_value = []
                    while remaining_tokens >= self.page_size:
                        last_hash = self.get_hash_str(
                            tokens_to_fetch[
                                storage_hit_count : storage_hit_count + self.page_size
                            ],
                            last_hash,
                        )

                        # todo, more unified interface
                        if not self.is_mooncake_backend():
                            if not self.storage_backend.exists(last_hash):
                                break
                        hash_value.append(last_hash)
                        storage_hit_count += self.page_size
                        remaining_tokens -= self.page_size

                    if self.is_mooncake_backend():
                        # deferring to batch exists for mooncake store
                        exist_result = self.storage_backend.exists(hash_value)
                        storage_hit_count = (
                            sum(1 for v in exist_result.values() if v != 0)
                            * self.page_size
                        )

                if self.tp_world_size > 1:
                    storage_hit_count_tensor = torch.tensor(
                        storage_hit_count, dtype=torch.int
                    )
                    torch.distributed.all_reduce(
                        storage_hit_count_tensor,
                        op=torch.distributed.ReduceOp.MIN,
                        group=self.prefetch_tp_group,
                    )
                    storage_hit_count = storage_hit_count_tensor.item()

                if storage_hit_count < self.prefetch_threshold:
                    # not to prefetch if not enough benefits
                    self.prefetch_revoke_queue.put(operation.request_id)
                    if operation.host_indices is not None:
                        self.mem_pool_host.free(operation.host_indices)
                    logger.debug(
                        f"Revoking prefetch for request {operation.request_id} due to insufficient hits ({storage_hit_count})."
                    )
                else:
                    operation.hash_value = hash_value[
                        : (storage_hit_count // self.page_size)
                    ]
                    # free the pre-allocated memory for pages that are not hit
                    self.mem_pool_host.free(operation.host_indices[storage_hit_count:])
                    operation.host_indices = operation.host_indices[:storage_hit_count]
                    logger.debug(
                        f"Prefetching {len(operation.hash_value)} pages for request {operation.request_id}."
                    )
                    self.prefetch_buffer.put(operation)

            except Empty:
                continue

    def write_storage(
        self,
        host_indices: torch.Tensor,
        token_ids: List[int],
        last_hash: Optional[str] = None,
    ) -> int:
        """
        Write KV caches from host memory to storage backend.
        """
        operation = StorageOperation(host_indices, token_ids, last_hash)
        self.backup_queue.put(operation)
        return operation.id

    def generic_page_backup(self, operation, batch_size=8):
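        """
        Write pages from host memory to the storage backend in batches, stopping
        at the first batch that fails to be stored.
        """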
        for i in range(0, len(operation.hash_value), batch_size):
            page_hashes = operation.hash_value[i : i + batch_size]
            page_data = [
                self.mem_pool_host.get_flat_data_page(
                    operation.host_indices[j * self.page_size]
                )
                for j in range(i, i + len(page_hashes))
            ]
            success = self.storage_backend.batch_set(page_hashes, page_data)
            if not success:
                logger.warning(f"Failed to write page {page_hashes} to storage.")
                break
            operation.completed_tokens += self.page_size * len(page_hashes)

    def mooncake_page_backup(self, operation):
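        """
        Back up pages to the Mooncake backend, checking which hashes already exist
        and writing only the missing ones via host buffer pointers.
        """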
        if len(operation.hash_value):
            exist_hashvalues = self.storage_backend.exists(operation.hash_value)
            indices = operation.host_indices.tolist()
            non_exist_keys = []
            non_exist_indices = []
            for i in range(len(operation.hash_value)):
                if not exist_hashvalues[operation.hash_value[i]]:
                    non_exist_keys.append(operation.hash_value[i])
                    non_exist_indices.extend(
                        indices[i * self.page_size : (i + 1) * self.page_size]
                    )
            if len(non_exist_keys) > 0:
                key_strs, buffer_ptrs, buffer_sizes = (
                    self.mem_pool_host.get_buffer_meta(
                        non_exist_keys, non_exist_indices
                    )
                )
                # TODO: check the return value of batch set to see how many tokens are set successfully
                self.storage_backend.batch_set(
                    key_strs,
                    target_location=buffer_ptrs,
                    target_sizes=buffer_sizes,
                )
        operation.completed_tokens += len(operation.hash_value) * self.page_size

    def backup_thread_func(self):
        """
        Manage backup operations from host memory to storage backend.
        """
        while not self.stop_event.is_set():
            try:
                operation = self.backup_queue.get(block=True, timeout=1)
                if operation is None:
                    continue

                last_hash = operation.last_hash
                tokens_to_backup = operation.token_ids

                backup_hit_count = 0
                remaining_tokens = len(tokens_to_backup)
                hash_value = []
                while remaining_tokens >= self.page_size:
                    last_hash = self.get_hash_str(
                        tokens_to_backup[
                            backup_hit_count : backup_hit_count + self.page_size
                        ],
                        last_hash,
                    )
                    backup_hit_count += self.page_size
                    hash_value.append(last_hash)
                    remaining_tokens -= self.page_size
                operation.hash_value = hash_value

                if self.is_mooncake_backend():
                    self.mooncake_page_backup(operation)
                elif self.storage_backend_type == "hf3fs":
                    self.generic_page_backup(operation, batch_size=128)
                else:
                    self.generic_page_backup(operation)

                min_completed_tokens = operation.completed_tokens
                if self.tp_world_size > 1:
                    completed_tokens_tensor = torch.tensor(
                        min_completed_tokens, dtype=torch.int
                    )
                    torch.distributed.all_reduce(
                        completed_tokens_tensor,
                        op=torch.distributed.ReduceOp.MIN,
                        group=self.backup_tp_group,
                    )
                    min_completed_tokens = completed_tokens_tensor.item()

                self.ack_backup_queue.put(
                    (
                        operation.id,
                        operation.hash_value[: min_completed_tokens // self.page_size],
                        min_completed_tokens,
                    )
                )

            except Empty:
                continue