cache_controller.py 32.6 KB
Newer Older
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
from __future__ import annotations

"""
Copyright 2023-2025 SGLang Team
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
    http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
"""

import logging
17
import math
18
import threading
pansicheng's avatar
pansicheng committed
19
import time
20
from queue import Empty, Full, PriorityQueue, Queue
21
from typing import TYPE_CHECKING, List, Optional
22
23
24

import torch

25
26
from sglang.srt.mem_cache.hicache_storage import HiCacheStorageConfig

27
28
29
if TYPE_CHECKING:
    from sglang.srt.mem_cache.allocator import BaseTokenToKVPoolAllocator
    from sglang.srt.mem_cache.memory_pool_host import HostKVCache
30

31
32
33
34
35
36
37
38
39
from sglang.srt.distributed import (
    get_tensor_model_parallel_rank,
    get_tensor_model_parallel_world_size,
)
from sglang.srt.layers.dp_attention import (
    get_attention_tp_rank,
    get_attention_tp_size,
    is_dp_attention_enabled,
)
40
from sglang.srt.mem_cache.memory_pool import MHATokenToKVPool, MLATokenToKVPool
41

42
43
44
logger = logging.getLogger(__name__)


45
46
class LayerDoneCounter:
    """Counts completed layer transfers, with rotating counter slots so a
    producer (loader thread) and consumer (forward pass) can overlap."""

    def __init__(self, num_layers):
        self.num_layers = num_layers
        # Three slots: active producer, active consumer, and one spare so the
        # two roles can rotate without stepping on each other (overlap mode).
        self.num_counters = 3
        self.counters = [num_layers for _ in range(self.num_counters)]
        self.conditions = [threading.Condition() for _ in range(self.num_counters)]
        self.producer_index = 0
        self.consumer_index = 0

    def next_producer(self):
        """Index of the slot the producer would rotate into next."""
        return (self.producer_index + 1) % self.num_counters

    def update_producer(self):
        """Rotate the producer to the next slot and return its index."""
        self.producer_index = self.next_producer()
        return self.producer_index

    def set_consumer(self, index):
        """Point the consumer at the given counter slot."""
        self.consumer_index = index

    def increment(self):
        """Mark one more layer done on the producer slot and wake all waiters."""
        cond = self.conditions[self.producer_index]
        with cond:
            self.counters[self.producer_index] += 1
            cond.notify_all()

    def wait_until(self, threshold):
        """Block until the consumer slot's count exceeds ``threshold``."""
        cond = self.conditions[self.consumer_index]
        with cond:
            while not self.counters[self.consumer_index] > threshold:
                cond.wait()

    def reset(self):
        """Zero the producer slot's count before a new transfer round."""
        cond = self.conditions[self.producer_index]
        with cond:
            self.counters[self.producer_index] = 0
78
79


80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
class CacheOperation:
    """One batched KV-cache transfer between host and device memory."""

    # Monotonic creation counter; doubles as the default priority.
    counter = 0

    def __init__(
        self,
        host_indices: torch.Tensor,
        device_indices: torch.Tensor,
        node_id: int,
        priority: Optional[int] = None,
    ):
        self.host_indices = host_indices
        self.device_indices = device_indices
        self.node_ids = [node_id]
        self.data = None

        self.id = CacheOperation.counter
        CacheOperation.counter += 1
        # default priority is the order of creation
        self.priority = self.id if priority is None else priority

    def merge(self, other: "CacheOperation") -> None:
        # multiple operations can be merged into a single operation for batch processing
        self.host_indices = torch.cat((self.host_indices, other.host_indices))
        self.device_indices = torch.cat((self.device_indices, other.device_indices))
        self.priority = min(self.priority, other.priority)
        self.node_ids += other.node_ids

    def split(self, factor) -> List["CacheOperation"]:
        # split an operation into smaller operations to reduce the size of intermediate buffers
        if factor <= 1:
            return [self]

        total = len(self.host_indices)
        chunk = math.ceil(total / factor)
        pieces = [
            CacheOperation(
                host_indices=self.host_indices[start : start + chunk],
                device_indices=self.device_indices[start : start + chunk],
                node_id=0,
            )
            for start in range(0, total, chunk)
        ]
        # Inherit the node_ids on the final chunk
        if pieces:
            pieces[-1].node_ids = self.node_ids

        return pieces

    def __lt__(self, other: "CacheOperation"):
        return self.priority < other.priority


class TransferBuffer:
    """
    Overlapping buffer preparation and transfer operations to improve throughput.
    """

138
    def __init__(
139
        self, stop_event, buffer_count: int = 3, max_buffer_size: int = 1024
140
141
    ) -> None:
        self.stop_event = stop_event
142
143
144
145
146
147
148
149
150
151
        self.buffers = Queue(maxsize=buffer_count)
        # todo: adjust the buffer size based on throughput profile of the system
        self.max_buffer_size = max_buffer_size

    def full(self) -> bool:
        return self.buffers.full()

    def empty(self) -> bool:
        return self.buffers.empty()

152
153
154
155
156
157
158
159
160
161
162
    def put(self, item, block=True, timeout=1) -> None:
        while not self.stop_event.is_set():
            try:
                self.buffers.put(item, block=block, timeout=timeout)
                break
            except Full:
                if not block:
                    break
                continue
            except Exception as e:
                logger.error(e)
163

164
    def get(self, block=True, timeout=1) -> Optional[CacheOperation]:
165
        try:
166
167
168
            return self.buffers.get(block=block, timeout=timeout)
        except Empty:
            return None
169
170
171
        except Exception as e:
            logger.error(e)

172
173
174
    def clear(self):
        self.buffers.queue.clear()

175

176
177
178
179
180
181
182
183
class StorageOperation:
    """A unit of work exchanged with the external storage backend."""

    # Monotonic creation counter; provides FIFO ordering via __lt__.
    counter = 0

    def __init__(
        self,
        host_indices: torch.Tensor,
        token_ids: List[int],
        last_hash: Optional[str] = None,
        hash_value: Optional[List[str]] = None,
    ):
        self.host_indices = host_indices
        self.token_ids = token_ids
        self.last_hash = last_hash
        self.completed_tokens = 0
        # Page hashes covered by this operation; starts empty unless given.
        self.hash_value = [] if hash_value is None else hash_value

        self.id = StorageOperation.counter
        StorageOperation.counter += 1

    def __lt__(self, other: "StorageOperation"):
        # Earlier-created operations sort first.
        return self.id < other.id


class PrefetchOperation(StorageOperation):
    """A storage operation pulling KV pages into host memory for one request."""

    def __init__(
        self,
        request_id: str,
        host_indices: torch.Tensor,
        token_ids: List[int],
        last_hash: Optional[str] = None,
    ):
        self.request_id = request_id

        # Guarded by _lock so IO threads and the controller can race safely
        # on termination vs. progress updates.
        self._done_flag = False
        self._lock = threading.Lock()

        # Wall-clock start, for latency accounting.
        self.start_time = time.monotonic()

        super().__init__(host_indices, token_ids, last_hash)

    def increment(self, num_tokens: int):
        """Record fetched tokens; returns False once the operation is terminated."""
        with self._lock:
            if not self._done_flag:
                self.completed_tokens += num_tokens
                return True
            return False

    def mark_done(self):
        """Terminate the operation; subsequent increments become no-ops."""
        with self._lock:
            self._done_flag = True

    def is_done(self) -> bool:
        return self._done_flag


231
232
233
234
class HiCacheController:

    def __init__(
        self,
        token_to_kv_pool_allocator: BaseTokenToKVPoolAllocator,
        mem_pool_host: HostKVCache,
        page_size: int,
        tp_group: torch.distributed.ProcessGroup,
        load_cache_event: Optional[threading.Event] = None,
        write_policy: str = "write_through_selective",
        io_backend: str = "",
        storage_backend: Optional[str] = None,
        prefetch_threshold: int = 256,
        model_name: Optional[str] = None,
        storage_backend_extra_config: Optional[str] = None,
    ):
        """Coordinate KV-cache movement between device memory, host memory,
        and (optionally) an external storage backend.

        Args:
            token_to_kv_pool_allocator: allocator for the device KV pool.
            mem_pool_host: host-side KV cache pool.
            page_size: number of tokens per cache page.
            tp_group: process group used to synchronize storage operations
                across TP workers.
            load_cache_event: event signaled when queued load operations
                should start (consumed by the load thread).
            write_policy: one of "write_through", "write_through_selective",
                "write_back".
            io_backend: "kernel" or "direct"; selects the transfer path in
                move_indices.
            storage_backend: optional backend name ("file", "nixl",
                "mooncake", "hf3fs"); None disables storage tiering.
            prefetch_threshold: minimum storage-hit token count for a
                prefetch to be considered worthwhile.
            model_name: forwarded into the storage config.
            storage_backend_extra_config: JSON string of backend-specific
                options.
        """
        self.mem_pool_device_allocator = token_to_kv_pool_allocator
        self.mem_pool_device = token_to_kv_pool_allocator.get_kvcache()
        self.mem_pool_host = mem_pool_host
        self.write_policy = write_policy
        self.page_size = page_size
        self.io_backend = io_backend
        self.enable_storage = False

        if storage_backend is not None:
            self.storage_backend_type = storage_backend
            # Imported lazily so the storage stack is only pulled in when a
            # backend is actually configured.
            from sglang.srt.mem_cache.hicache_storage import get_hash_str

            self.get_hash_str = get_hash_str
            self.storage_config = self._generate_storage_config(
                model_name, storage_backend_extra_config
            )
            # for MLA models, only one rank needs to backup the KV cache
            self.backup_skip = (
                self.storage_config.is_mla_model
                # todo: load balancing
                and self.storage_config.tp_rank != 0
            )

            # Instantiate the concrete backend; each branch imports its own
            # dependencies lazily.
            if storage_backend == "file":
                from sglang.srt.mem_cache.hicache_storage import HiCacheFile

                self.storage_backend = HiCacheFile(self.storage_config)
            elif storage_backend == "nixl":
                from sglang.srt.mem_cache.storage.nixl.hicache_nixl import HiCacheNixl

                self.storage_backend = HiCacheNixl()
            elif storage_backend == "mooncake":
                from sglang.srt.mem_cache.storage.mooncake_store.mooncake_store import (
                    MooncakeStore,
                )

                self.storage_backend = MooncakeStore(self.storage_config)
                self.storage_backend.register_buffer(self.mem_pool_host.kv_buffer)
                # Mooncake transfers address the host pool directly, which
                # requires the page-contiguous layout.
                assert self.mem_pool_host.layout == "page_first"
            elif storage_backend == "hf3fs":
                from sglang.srt.mem_cache.storage.hf3fs.storage_hf3fs import (
                    HiCacheHF3FS,
                )

                # NOTE(review): page_first sizes pages by K-size only while
                # layer_first uses the full per-token size — confirm this
                # asymmetry is intended for the HF3FS layouts.
                if self.mem_pool_host.layout == "page_first":
                    bytes_per_page = (
                        mem_pool_host.get_ksize_per_token() * mem_pool_host.page_size
                    )
                elif self.mem_pool_host.layout == "layer_first":
                    bytes_per_page = (
                        mem_pool_host.get_size_per_token() * mem_pool_host.page_size
                    )
                dtype = mem_pool_host.dtype
                self.storage_backend = HiCacheHF3FS.from_env_config(
                    bytes_per_page, dtype, self.storage_config
                )
            else:
                raise NotImplementedError(
                    f"Unsupported storage backend: {storage_backend}"
                )

            self.enable_storage = True
            # todo: threshold policy for prefetching
            self.prefetch_threshold = max(prefetch_threshold, self.page_size)
            # Cap prefetch-locked host memory at 80% of the host pool surplus
            # over the device pool.
            self.prefetch_capacity_limit = int(
                0.8 * (self.mem_pool_host.size - self.mem_pool_device.size)
            )
            # granularity of batch storage IO operations, in number of pages
            self.storage_batch_size = 128
            # tracking the number of tokens locked in prefetching, updated by the main scheduler thread
            self.prefetch_tokens_occupied = 0

            # create a new communication group for synchronizing storage operations across TP workers
            self.tp_world_size = torch.distributed.get_world_size(group=tp_group)
            if self.tp_world_size > 1:
                group_ranks = torch.distributed.get_process_group_ranks(tp_group)
                self.prefetch_tp_group = torch.distributed.new_group(
                    group_ranks, backend="gloo"
                )

        self.load_cache_event = load_cache_event
        self.layer_done_counter = LayerDoneCounter(self.mem_pool_device.layer_num)
        self.mem_pool_device.register_layer_transfer_counter(self.layer_done_counter)

        if write_policy not in [
            "write_through",
            "write_through_selective",
            "write_back",
        ]:
            raise ValueError(f"Invalid write policy: {write_policy}")

        self.write_queue = PriorityQueue()
        self.load_queue = PriorityQueue()

        self.ack_write_queue = Queue()
        self.ack_load_queue = Queue()

        self.stop_event = threading.Event()
        self.write_buffer = TransferBuffer(self.stop_event)
        self.load_buffer = TransferBuffer(
            self.stop_event, buffer_count=10, max_buffer_size=100
        )

        # Dedicated CUDA streams so transfers overlap with compute.
        self.write_stream = torch.cuda.Stream()
        self.load_stream = torch.cuda.Stream()

        self.write_thread = threading.Thread(
            target=self.write_thread_func_direct, daemon=True
        )
        self.load_thread = threading.Thread(
            target=self.load_thread_func_layer_by_layer, daemon=True
        )

        self.write_thread.start()
        self.load_thread.start()

        if self.enable_storage:
            self.prefetch_thread = threading.Thread(
                target=self.prefetch_thread_func, daemon=True
            )
            self.backup_thread = threading.Thread(
                target=self.backup_thread_func, daemon=True
            )
            self.prefetch_queue = Queue()
            self.backup_queue = Queue()

            self.prefetch_revoke_queue = Queue()
            self.ack_backup_queue = Queue()
            self.host_mem_release_queue = Queue()

            self.prefetch_thread.start()
            self.backup_thread.start()

380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
    def _generate_storage_config(
        self,
        model_name: Optional[str] = None,
        storage_backend_extra_config: Optional[str] = None,
    ):
        """Build the HiCacheStorageConfig describing this rank's storage view."""
        # Under DP attention the effective TP topology differs from the plain
        # tensor-model-parallel one, so pick the matching rank/size pair.
        if is_dp_attention_enabled():
            self.tp_rank, self.tp_size = (
                get_attention_tp_rank(),
                get_attention_tp_size(),
            )
        else:
            self.tp_rank, self.tp_size = (
                get_tensor_model_parallel_rank(),
                get_tensor_model_parallel_world_size(),
            )

        # Currently, AscendMLAPagedTokenToKVPool is the subclass of MLATokenToKVPool.
        is_mla_backend = isinstance(self.mem_pool_device, MLATokenToKVPool)

        # Parse extra config JSON if provided; a malformed string is logged
        # and treated as absent.
        extra_config = None
        if storage_backend_extra_config:
            try:
                import json

                extra_config = json.loads(storage_backend_extra_config)
            except Exception as e:
                logger.error(f"Invalid backend extra config JSON: {e}")

        return HiCacheStorageConfig(
            tp_rank=self.tp_rank,
            tp_size=self.tp_size,
            is_mla_model=is_mla_backend,
            is_page_first_layout=self.mem_pool_host.layout == "page_first",
            model_name=model_name,
            extra_config=extra_config,
        )

415
416
417
418
419
420
421
422
423
424
425
    def reset(self):
        """Stop the worker threads, drop all queued work, and restart them.

        Brings the controller back to an idle state: pending write, load,
        prefetch, and backup operations are discarded.
        """
        # Signal every worker loop to exit, then wait for the transfer threads.
        self.stop_event.set()
        self.write_thread.join()
        self.load_thread.join()

        # Drop any queued or buffered transfer work and pending acks.
        self.write_queue.queue.clear()
        self.load_queue.queue.clear()
        self.write_buffer.clear()
        self.load_buffer.clear()
        self.ack_write_queue.queue.clear()
        self.ack_load_queue.queue.clear()
        if self.enable_storage:
            self.prefetch_thread.join()
            self.backup_thread.join()
            self.prefetch_queue.queue.clear()
            self.backup_queue.queue.clear()
            self.prefetch_revoke_queue.queue.clear()
            self.ack_backup_queue.queue.clear()
            # NOTE(review): host_mem_release_queue is not cleared here —
            # confirm stale release entries cannot carry across a reset.

        # Threads cannot be restarted after join(); recreate them.
        self.write_thread = threading.Thread(
            target=self.write_thread_func_direct, daemon=True
        )
        self.load_thread = threading.Thread(
            target=self.load_thread_func_layer_by_layer, daemon=True
        )
        self.stop_event.clear()
        self.write_thread.start()
        self.load_thread.start()

        if self.enable_storage:
            self.prefetch_thread = threading.Thread(
                target=self.prefetch_thread_func, daemon=True
            )
            self.backup_thread = threading.Thread(
                target=self.backup_thread_func, daemon=True
            )
            self.prefetch_thread.start()
            self.backup_thread.start()

454
455
456
457
458
459
460
461
462
463
464
465
    def write(
        self,
        device_indices: torch.Tensor,
        priority: Optional[int] = None,
        node_id: int = 0,
    ) -> Optional[torch.Tensor]:
        """
        Back up KV caches from device memory to host memory.

        Returns the allocated host indices, or None when the host pool
        cannot satisfy the allocation.
        """
        num_tokens = len(device_indices)
        host_indices = self.mem_pool_host.alloc(num_tokens)
        if host_indices is None:
            # Host pool exhausted; the caller decides how to proceed.
            return None
        self.mem_pool_host.protect_write(host_indices)
        # Ensure pending work on the current stream is finished before the
        # write stream starts consuming these indices.
        torch.cuda.current_stream().synchronize()
        operation = CacheOperation(host_indices, device_indices, node_id, priority)
        self.write_queue.put(operation)
        return host_indices

    def load(
        self,
        host_indices: torch.Tensor,
        priority: Optional[int] = None,
        node_id: int = 0,
    ) -> Optional[torch.Tensor]:
        """
        Load KV caches from host memory to device memory.

        Returns the allocated device indices, or None when the device pool
        cannot satisfy the allocation.
        """
        num_tokens = len(host_indices)
        device_indices = self.mem_pool_device_allocator.alloc(num_tokens)
        if device_indices is None:
            return None
        self.mem_pool_host.protect_load(host_indices)
        # to ensure the device indices are ready before accessed by another CUDA stream
        torch.cuda.current_stream().synchronize()
        operation = CacheOperation(host_indices, device_indices, node_id, priority)
        self.load_queue.put(operation)
        return device_indices

493
494
495
496
497
    def move_indices(self, host_indices, device_indices):
        # move indices to GPU if using kernels, to host if using direct indexing
        if self.io_backend == "kernel":
            return host_indices.to(self.mem_pool_device.device), device_indices
        elif self.io_backend == "direct":
498
499
500
            device_indices = device_indices.cpu()
            host_indices, idx = host_indices.sort()
            return host_indices, device_indices.index_select(0, idx)
501
502
503
        else:
            raise ValueError(f"Unsupported io backend")

504
505
506
507
    def write_thread_func_direct(self):
        """
        Directly write through KV caches to host memory without buffering.

        Runs on a dedicated thread bound to ``write_stream``; loops until
        ``stop_event`` is set, draining ``write_queue`` one operation at a
        time and acking completed node ids on ``ack_write_queue``.
        """
        torch.cuda.set_stream(self.write_stream)
        while not self.stop_event.is_set():
            try:
                # Timeout so the stop_event is re-checked at least every second.
                operation = self.write_queue.get(block=True, timeout=1)
                host_indices, device_indices = self.move_indices(
                    operation.host_indices, operation.device_indices
                )
                self.mem_pool_host.backup_from_device_all_layer(
                    self.mem_pool_device, host_indices, device_indices, self.io_backend
                )
                # Wait for the copy to finish before marking the host region
                # complete and acking waiters.
                self.write_stream.synchronize()
                self.mem_pool_host.complete_io(operation.host_indices)
                # node_id 0 is the sentinel for "no ack needed" (see split()).
                for node_id in operation.node_ids:
                    if node_id != 0:
                        self.ack_write_queue.put(node_id)
            except Empty:
                continue
            except Exception as e:
                logger.error(e)
527

528
529
530
531
    def load_thread_func_layer_by_layer(self):
        """
        Load KV caches from host memory to device memory layer by layer.

        Runs on a dedicated thread bound to ``load_stream``. Waits for the
        scheduler to signal ``load_cache_event``, merges every queued load
        operation into one batch, then transfers one layer at a time,
        bumping ``layer_done_counter`` after each layer so the forward pass
        can consume layers as they arrive.
        """
        torch.cuda.set_stream(self.load_stream)
        while not self.stop_event.is_set():
            # Timed wait so stop_event is re-checked at least every second.
            self.load_cache_event.wait(timeout=1)
            if not self.load_cache_event.is_set():
                continue
            self.load_cache_event.clear()
            # Rotate to a fresh counter slot for this round of transfers.
            self.layer_done_counter.update_producer()

            # Merge all currently queued operations into a single batch.
            batch_operation = None
            while self.load_queue.qsize() > 0:
                op = self.load_queue.get(block=True)
                if batch_operation is None:
                    batch_operation = op
                else:
                    batch_operation.merge(op)
            if batch_operation is None:
                continue

            # start layer-wise KV cache transfer from CPU to GPU
            self.layer_done_counter.reset()
            host_indices, device_indices = self.move_indices(
                batch_operation.host_indices, batch_operation.device_indices
            )
            for i in range(self.mem_pool_host.layer_num):
                self.mem_pool_host.load_to_device_per_layer(
                    self.mem_pool_device,
                    host_indices,
                    device_indices,
                    i,
                    self.io_backend,
                )
                # Only announce the layer after its copy has fully landed.
                self.load_stream.synchronize()
                self.layer_done_counter.increment()

            self.mem_pool_host.complete_io(batch_operation.host_indices)
            # node_id 0 is the sentinel for "no ack needed" (see split()).
            for node_id in batch_operation.node_ids:
                if node_id != 0:
                    self.ack_load_queue.put(node_id)
570

571
572
573
574
    def evict_device(
        self, device_indices: torch.Tensor, host_indices: torch.Tensor
    ) -> int:
        """Free device copies whose host backup is in sync.

        Returns the number of tokens freed; raises ValueError if the host
        entries are not in the synced state.
        """
        if not self.mem_pool_host.is_synced(host_indices):
            raise ValueError(
                f"Inconsistent states: {self.mem_pool_host.get_state(host_indices)}"
            )
        self.mem_pool_device_allocator.free(device_indices)
        # The host copy becomes the sole (backup) copy.
        self.mem_pool_host.update_backup(host_indices)
        return len(device_indices)

    def evict_host(self, host_indices: torch.Tensor, backup_only: bool = True) -> int:
        """Free host entries that are pure backups; returns tokens freed."""
        if not backup_only:
            raise ValueError("Other eviction policies are not supported yet.")

        if not self.mem_pool_host.is_backup(host_indices):
            raise ValueError(
                f"Inconsistent states: {self.mem_pool_host.get_state(host_indices)}"
            )
        self.mem_pool_host.free(host_indices)
        return len(host_indices)
594
595
596
597
598
599
600

    def prefetch(
        self,
        request_id: str,
        host_indices: torch.Tensor,
        new_input_tokens: List[int],
        last_hash: Optional[str] = None,
    ) -> PrefetchOperation:
        """
        Prefetch KV caches from storage backend to host memory.

        Returns the enqueued operation so the caller can track progress or
        terminate it via terminate_prefetch().
        """
        op = PrefetchOperation(request_id, host_indices, new_input_tokens, last_hash)
        self.prefetch_queue.put(op)
        return op

    def terminate_prefetch(self, operation):
        """Stop an in-flight prefetch; returns (completed_tokens, hash_value)."""
        operation.mark_done()
        return (operation.completed_tokens, operation.hash_value)

615
616
617
618
619
    def append_host_mem_release(self, host_indices: torch.Tensor):
        chunks = host_indices.split(self.mem_pool_host.page_size)
        for chunk in chunks:
            self.host_mem_release_queue.put(chunk)

620
    def _3fs_zero_copy_page_get(self, operation, hash_values, host_indices):
        """Fetch pages from HF3FS directly into host-pool buffers (zero copy)."""
        hashes, buffers = self.mem_pool_host.get_buffer_with_hash(
            hash_values, host_indices
        )
        if self.storage_backend.batch_get(hashes, buffers):
            # All pages of this batch landed; credit the whole batch at once.
            operation.increment(self.page_size * len(hashes))
        else:
            logger.warning(
                f"Prefetch operation {operation.request_id} failed to retrieve page {hashes}."
            )

    def _mooncake_page_get(self, operation, hash_values, host_indices):
        """Batch-fetch pages from the Mooncake store into host buffers."""
        key_strs, buffer_ptrs, buffer_sizes = self.mem_pool_host.get_buffer_meta(
            hash_values,
            host_indices,
            self.storage_config.tp_rank,
        )
        fetched = self.storage_backend.batch_get(
            key_strs,
            target_location=buffer_ptrs,
            target_sizes=buffer_sizes,
        )
        if fetched != len(hash_values):
            logger.warning(
                f"Prefetch operation {operation.request_id} failed or partially failed."
            )
        # Credit whatever subset of pages did arrive.
        if fetched != 0:
            operation.increment(fetched * self.page_size)

    def _generic_page_get(self, operation, hash_values, host_indices):
        """Fetch pages via the generic backend API, then copy them into the
        host pool one page at a time, stopping at the first failure."""
        scratch_pages = [
            self.mem_pool_host.get_dummy_flat_data_page() for _ in hash_values
        ]
        page_data = self.storage_backend.batch_get(hash_values, scratch_pages)
        if page_data is None:
            return
        for i in range(len(hash_values)):
            page = page_data[i]
            if page is None:
                logger.warning(
                    f"Prefetch operation {operation.request_id} failed to retrieve page {hash_values[i]}."
                )
                break
            # Must set the data before increasing the completed tokens.
            # Otherwise this page may be read before being set.
            self.mem_pool_host.set_from_flat_data_page(
                host_indices[i * self.page_size],
                page,
            )
            if not operation.increment(self.page_size):
                break  # Operation terminated by controller
671
672
673

    def _page_transfer(self, operation):
        """Run the storage-to-host transfer for one prefetch operation.

        Dispatches to the backend-specific page getter, transfers pages in
        batches of ``storage_batch_size`` pages, and stops early when a batch
        is incomplete (backend failure or termination by the controller).
        Pre-allocated host memory beyond the completed tokens is queued for
        release at the end.
        """
        # Select the get function and batch size
        if self.storage_backend_type == "mooncake":
            get_func = self._mooncake_page_get
        elif (
            self.storage_backend_type == "hf3fs"
            and self.mem_pool_host.layout == "page_first"
        ):
            # Only the page_first layout supports the zero-copy path.
            get_func = self._3fs_zero_copy_page_get
        else:
            get_func = self._generic_page_get

        # Transfer batch by batch
        for i in range(0, len(operation.hash_value), self.storage_batch_size):
            batch_hashes = operation.hash_value[i : i + self.storage_batch_size]
            # host_indices are token-granular: one page covers page_size tokens.
            batch_host_indices = operation.host_indices[
                i * self.page_size : (i + len(batch_hashes)) * self.page_size
            ]
            prev_completed_tokens = operation.completed_tokens
            # Get one batch token, and update the completed_tokens if succeed
            get_func(operation, batch_hashes, batch_host_indices)
            # Check termination
            if (
                operation.completed_tokens
                != prev_completed_tokens + len(batch_hashes) * self.page_size
            ):
                break  # Some operations fail or operation terminated by controller
        # release pre-allocated memory
        self.append_host_mem_release(
            operation.host_indices[operation.completed_tokens :]
        )
703

704
705
706
707
708
709
710
    def prefetch_io_aux_func(self):
        """
        Auxiliary function conducting IO operations for prefetching.

        Drains ``prefetch_buffer`` (filled by prefetch_thread_func) and runs
        the actual storage-to-host page transfers until ``stop_event`` is set.
        """
        while not self.stop_event.is_set():
            try:
                # Timeout so the stop_event is re-checked at least every second.
                operation = self.prefetch_buffer.get(block=True, timeout=1)
                # Fix: _page_transfer already queues the unfinished tail
                # (operation.host_indices[operation.completed_tokens:]) for
                # release; the previous version released the same slice a
                # second time here, enqueueing the same host indices for
                # release twice.
                self._page_transfer(operation)
            except Empty:
                continue

719
    def prefetch_rate_limited(self) -> bool:
pansicheng's avatar
pansicheng committed
720
721
722
723
724
        """
        Rate limit the prefetching operations to avoid overwhelming the storage backend.
        """
        # cancel prefetch if too much memory is occupied
        if self.prefetch_tokens_occupied >= self.prefetch_capacity_limit:
725
            return True
pansicheng's avatar
pansicheng committed
726
        # todo: more sophisticated rate limiting based on storage backend performance
727
        return False
pansicheng's avatar
pansicheng committed
728

729
    def _storage_hit_query(self, operation) -> tuple[list[str], int]:
        """Probe the storage backend for the longest prefix of cached pages.

        Hashes the operation's tokens page by page — each page hash chains on
        the previous one, starting from ``operation.last_hash`` — and asks the
        backend, batch by batch, how many of those pages are present. Stops at
        the first batch that is not fully present.

        Returns:
            (hash_value, storage_query_count): hashes of the hit pages and the
            number of tokens those pages cover.
        """
        last_hash = operation.last_hash
        tokens_to_fetch = operation.token_ids

        storage_query_count = 0
        hash_value = []

        # Query in chunks of storage_batch_size pages (page_size tokens each).
        for start in range(
            0, len(tokens_to_fetch), self.page_size * self.storage_batch_size
        ):
            end = min(
                start + self.page_size * self.storage_batch_size, len(tokens_to_fetch)
            )
            batch_tokens = tokens_to_fetch[start:end]
            batch_hashes = []
            for i in range(0, len(batch_tokens), self.page_size):
                last_hash = self.get_hash_str(
                    batch_tokens[i : i + self.page_size], last_hash
                )
                batch_hashes.append(last_hash)
            # batch_exists is used here as a prefix length — presumably the
            # count of leading pages present; verify against the backends.
            hit_page_num = self.storage_backend.batch_exists(batch_hashes)
            hash_value.extend(batch_hashes[:hit_page_num])
            storage_query_count += hit_page_num * self.page_size
            # A partial batch means the prefix ends here; stop probing.
            if hit_page_num < len(batch_hashes):
                break
        return hash_value, storage_query_count
755

756
757
758
759
760
761
762
763
764
765
766
767
768
    def prefetch_thread_func(self):
        """
        Manage prefetching operations from storage backend to host memory.

        Runs on a dedicated thread. Each operation pulled from
        ``prefetch_queue`` is first probed against the storage backend to
        count how many leading tokens are actually available; operations
        with too few hits are revoked, the rest are trimmed to the hit
        prefix and handed to the aux I/O thread via ``prefetch_buffer``.
        """
        self.prefetch_buffer = Queue()
        # The aux thread performs the actual storage -> host-memory page I/O
        # for operations placed on prefetch_buffer below.
        aux_thread = threading.Thread(target=self.prefetch_io_aux_func, daemon=True)
        aux_thread.start()
        # Keep draining until shutdown is requested AND the queue is empty,
        # so in-flight prefetch requests are not silently dropped.
        while (not self.stop_event.is_set()) or not self.prefetch_queue.empty():
            try:
                operation = self.prefetch_queue.get(block=True, timeout=1)
                if operation is None:
                    continue

                hash_value, storage_hit_count = self._storage_hit_query(operation)
                if self.tp_world_size > 1:
                    # Take the MIN hit count across TP ranks so every rank
                    # agrees on the same prefetch prefix length.
                    storage_hit_count_tensor = torch.tensor(
                        storage_hit_count, dtype=torch.int
                    )
                    torch.distributed.all_reduce(
                        storage_hit_count_tensor,
                        op=torch.distributed.ReduceOp.MIN,
                        group=self.prefetch_tp_group,
                    )
                    storage_hit_count = storage_hit_count_tensor.item()

                if storage_hit_count < self.prefetch_threshold:
                    # not to prefetch if not enough benefits
                    self.prefetch_revoke_queue.put(operation.request_id)
                    self.append_host_mem_release(operation.host_indices)
                    logger.debug(
                        f"Revoking prefetch for request {operation.request_id} due to insufficient hits ({storage_hit_count})."
                    )
                else:
                    # Trim the operation to whole hit pages only.
                    operation.hash_value = hash_value[
                        : (storage_hit_count // self.page_size)
                    ]
                    # free the pre-allocated memory for pages that are not hit
                    self.append_host_mem_release(
                        operation.host_indices[storage_hit_count:]
                    )
                    operation.host_indices = operation.host_indices[:storage_hit_count]
                    logger.debug(
                        f"Prefetching {len(operation.hash_value)} pages for request {operation.request_id}."
                    )
                    self.prefetch_buffer.put(operation)

            except Empty:
                continue
    def write_storage(
        self,
        host_indices: torch.Tensor,
        token_ids: List[int],
        hash_value: Optional[List[str]] = None,
    ) -> int:
        """
        Queue KV caches in host memory for backup to the storage backend.

        The write happens asynchronously on the backup thread; the returned
        operation id is later acknowledged via the ack queue.
        """
        backup_op = StorageOperation(host_indices, token_ids, hash_value=hash_value)
        self.backup_queue.put(backup_op)
        return backup_op.id

    # non-zero copy
    def _generic_page_set(self, hash_values, host_indices) -> bool:
        data = [
            self.mem_pool_host.get_flat_data_page(host_indices[i * self.page_size])
            for i in range(len(hash_values))
        ]
        return self.storage_backend.batch_set(hash_values, data)

    # zero copy
    def _mooncake_page_set(self, hash_values, host_indices) -> bool:
        key_strs, buffer_ptrs, buffer_sizes = self.mem_pool_host.get_buffer_meta(
            hash_values,
            host_indices,
831
            self.storage_config.tp_rank,
pansicheng's avatar
pansicheng committed
832
        )
833
834
835
836
837
838
        success = self.storage_backend.batch_set(
            key_strs,
            target_location=buffer_ptrs,
            target_sizes=buffer_sizes,
        )
        return success
pansicheng's avatar
pansicheng committed
839

840
841
842
843
844
845
846
847
848
849
    # zero copy
    def _3fs_zero_copy_page_set(self, hash_values, host_indices) -> bool:
        hashes, dsts = self.mem_pool_host.get_buffer_with_hash(
            hash_values, host_indices
        )
        return self.storage_backend.batch_set(hashes, dsts)

    # Backup batch by batch
    def _page_backup(self, operation):
        # Select the set function and batch size
850
        if self.storage_backend_type == "mooncake":
851
            backup_set_func = self._mooncake_page_set
852
853
854
855
856
        elif (
            self.storage_backend_type == "hf3fs"
            and self.mem_pool_host.layout == "page_first"
        ):
            backup_set_func = self._3fs_zero_copy_page_set
857
858
859
        else:
            backup_set_func = self._generic_page_set
        # Backup batch by batch
860
861
        for i in range(0, len(operation.hash_value), self.storage_batch_size):
            batch_hashes = operation.hash_value[i : i + self.storage_batch_size]
862
863
            batch_host_indices = operation.host_indices[
                i * self.page_size : (i + len(batch_hashes)) * self.page_size
864
            ]
865
866
867
            # Set one batch token, and record if success.
            # todo: allow partial success
            success = backup_set_func(batch_hashes, batch_host_indices)
868
            if not success:
869
870
                logger.warning(
                    f"Write page to storage: {len(batch_hashes)} pages failed."
871
                )
872
873
                break
            operation.completed_tokens += self.page_size * len(batch_hashes)
874

875
876
877
878
879
880
881
882
883
884
    def backup_thread_func(self):
        """
        Daemon loop that drains ``backup_queue``: each dequeued operation is
        written to the storage backend (unless ``backup_skip`` is set) and
        its id is acknowledged on ``ack_backup_queue``. Returns once
        ``stop_event`` is set.
        """
        while not self.stop_event.is_set():
            try:
                op = self.backup_queue.get(block=True, timeout=1)
                if op is None:
                    continue

                if not self.backup_skip:
                    self._page_backup(op)
                # Acknowledge even when skipped, so waiters are released.
                self.ack_backup_queue.put(op.id)

            except Empty:
                continue