"""
Copyright 2023-2024 SGLang Team
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
"""

"""
Memory pool.

SGLang has two levels of memory pools:
ReqToTokenPool maps a request to its token locations.
TokenToKVPoolAllocator manages the indices to the KV cache data.
KVCache actually holds the physical KV cache.
"""
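
# A minimal illustrative sketch of how these pieces are typically wired together
# (hypothetical sizes and a CUDA device; not executed and not part of this module's API):
#
#     req_to_token = ReqToTokenPool(
#         size=32, max_context_len=2048, device="cuda", enable_memory_saver=False
#     )
#     kv_cache = MHATokenToKVPool(
#         size=4096, page_size=1, dtype=torch.float16, head_num=8, head_dim=128,
#         layer_num=16, device="cuda", enable_memory_saver=False,
#     )
#     allocator = TokenToKVPoolAllocator(
#         size=4096, dtype=torch.float16, device="cuda", kvcache=kv_cache
#     )
#     req_idx = req_to_token.alloc(1)[0]              # one row per request
#     out_locs = allocator.alloc(16)                  # KV cache slots for 16 tokens
#     req_to_token.write((req_idx, slice(0, 16)), out_locs)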

import abc
import logging
import threading
from enum import IntEnum
from functools import wraps
from typing import List, Optional, Tuple, Union

import numpy as np
import psutil
import torch

from sglang.srt.layers.radix_attention import RadixAttention
from sglang.srt.torch_memory_saver_adapter import TorchMemorySaverAdapter
from sglang.srt.utils import debug_timing, get_compiler_backend

logger = logging.getLogger(__name__)

GB = 1024 * 1024 * 1024


class ReqToTokenPool:
    """A memory pool that maps a request to its token locations."""

    def __init__(
        self,
        size: int,
        max_context_len: int,
        device: str,
        enable_memory_saver: bool,
    ):
        memory_saver_adapter = TorchMemorySaverAdapter.create(
            enable=enable_memory_saver
        )

        self.size = size
        self.max_context_len = max_context_len
        self.device = device
        with memory_saver_adapter.region():
            self.req_to_token = torch.zeros(
                (size, max_context_len), dtype=torch.int32, device=device
            )
        self.free_slots = list(range(size))

    def write(self, indices, values):
        self.req_to_token[indices] = values

    def available_size(self):
        return len(self.free_slots)

    def alloc(self, need_size: int) -> List[int]:
        if need_size > len(self.free_slots):
            return None

        select_index = self.free_slots[:need_size]
        self.free_slots = self.free_slots[need_size:]

        return select_index

    def free(self, free_index: Union[int, List[int]]):
        if isinstance(free_index, (int,)):
            self.free_slots.append(free_index)
        else:
            self.free_slots.extend(free_index)

    def clear(self):
        self.free_slots = list(range(self.size))


class KVCache(abc.ABC):

    @abc.abstractmethod
    def get_key_buffer(self, layer_id: int) -> torch.Tensor:
        raise NotImplementedError()

    @abc.abstractmethod
    def get_value_buffer(self, layer_id: int) -> torch.Tensor:
        raise NotImplementedError()

    @abc.abstractmethod
    def get_kv_buffer(self, layer_id: int) -> Tuple[torch.Tensor, torch.Tensor]:
        raise NotImplementedError()

    @abc.abstractmethod
    def set_kv_buffer(
        self,
        layer: RadixAttention,
        loc: torch.Tensor,
        cache_k: torch.Tensor,
        cache_v: torch.Tensor,
    ) -> None:
        raise NotImplementedError()

    @abc.abstractmethod
    def get_flat_data(self, indices):
        raise NotImplementedError()

    @abc.abstractmethod
    def transfer(self, indices, flat_data):
        raise NotImplementedError()

    @abc.abstractmethod
    def transfer_per_layer(self, indices, flat_data, layer_id):
        raise NotImplementedError()

    def register_layer_transfer_counter(self, layer_transfer_counter):
        self.layer_transfer_counter = layer_transfer_counter


class TokenToKVPoolAllocator:
    """An allocator managing the indices to kv cache data."""

    def __init__(
        self,
        size: int,
        dtype: torch.dtype,
        device: str,
        kvcache: KVCache,
    ):
        self.size = size
        self.dtype = dtype
        self.device = device
        self.page_size = 1

        self.free_slots = None
        self.is_not_in_free_group = True
        self.free_group = []
        self.clear()

        self._kvcache = kvcache

    def available_size(self):
        return len(self.free_slots)

    def get_kvcache(self):
        return self._kvcache

    def alloc(self, need_size: int):
        if need_size > len(self.free_slots):
            return None

        select_index = self.free_slots[:need_size]
        self.free_slots = self.free_slots[need_size:]
        return select_index

    def free(self, free_index: torch.Tensor):
        if free_index.numel() == 0:
            return

        if self.is_not_in_free_group:
            self.free_slots = torch.cat((self.free_slots, free_index))
        else:
            self.free_group.append(free_index)

    def free_group_begin(self):
        self.is_not_in_free_group = False
        self.free_group = []

    def free_group_end(self):
        self.is_not_in_free_group = True
        if self.free_group:
            self.free(torch.cat(self.free_group))

    def clear(self):
        # The padded slot 0 is used for writing dummy outputs from padded tokens.
        self.free_slots = torch.arange(
            1, self.size + 1, dtype=torch.int64, device=self.device
        )
        self.is_not_in_free_group = True
        self.free_group = []


class MHATokenToKVPool(KVCache):

    def __init__(
        self,
        size: int,
        page_size: int,
        dtype: torch.dtype,
        head_num: int,
        head_dim: int,
        layer_num: int,
        device: str,
        enable_memory_saver: bool,
    ):
        self.size = size
        self.page_size = page_size
        self.dtype = dtype
        self.device = device
        if dtype in (torch.float8_e5m2, torch.float8_e4m3fn):
            # NOTE: Store as torch.uint8 because Tensor.index_put is not implemented for torch.float8_e5m2
            self.store_dtype = torch.uint8
        else:
            self.store_dtype = dtype
        self.memory_saver_adapter = TorchMemorySaverAdapter.create(
            enable=enable_memory_saver
        )

        self.head_num = head_num
        self.head_dim = head_dim
        self.layer_num = layer_num
        self._create_buffers()

        self.layer_transfer_counter = None
        self.capture_mode = False
        self.device_module = torch.get_device_module(self.device)
        self.alt_stream = self.device_module.Stream()

        k_size, v_size = self.get_kv_size_bytes()
        logger.info(
            f"KV Cache is allocated. #tokens: {size}, K size: {k_size / GB:.2f} GB, V size: {v_size / GB:.2f} GB"
        )

    def _create_buffers(self):
        with self.memory_saver_adapter.region():
            # [size, head_num, head_dim] for each layer
            # The padded slot 0 is used for writing dummy outputs from padded tokens.
            self.k_buffer = [
                torch.zeros(
                    (self.size + self.page_size, self.head_num, self.head_dim),
                    dtype=self.store_dtype,
                    device=self.device,
                )
                for _ in range(self.layer_num)
            ]
            self.v_buffer = [
                torch.zeros(
                    (self.size + self.page_size, self.head_num, self.head_dim),
                    dtype=self.store_dtype,
                    device=self.device,
                )
                for _ in range(self.layer_num)
            ]

    def _clear_buffers(self):
        del self.k_buffer
        del self.v_buffer

    def get_kv_size_bytes(self):
        assert hasattr(self, "k_buffer")
        assert hasattr(self, "v_buffer")
        k_size_bytes = 0
        for k_cache in self.k_buffer:
            k_size_bytes += np.prod(k_cache.shape) * k_cache.dtype.itemsize
        v_size_bytes = 0
        for v_cache in self.v_buffer:
            v_size_bytes += np.prod(v_cache.shape) * v_cache.dtype.itemsize
        return k_size_bytes, v_size_bytes

    # Todo: different memory layout
    def get_flat_data(self, indices):
        # prepare a large chunk of contiguous data for efficient transfer
        flatten = torch.stack(
            [
                torch.stack([self.k_buffer[i][indices] for i in range(self.layer_num)]),
                torch.stack([self.v_buffer[i][indices] for i in range(self.layer_num)]),
            ]
        )
        return flatten

    @debug_timing
    def transfer(self, indices, flat_data):
        # transfer prepared data from host to device
        flat_data = flat_data.to(device=self.device, non_blocking=False)
        k_data, v_data = flat_data[0], flat_data[1]
        for i in range(self.layer_num):
            self.k_buffer[i][indices] = k_data[i]
            self.v_buffer[i][indices] = v_data[i]

    def transfer_per_layer(self, indices, flat_data, layer_id):
        # transfer prepared data from host to device
        flat_data = flat_data.to(device=self.device, non_blocking=False)
        k_data, v_data = flat_data[0], flat_data[1]
        self.k_buffer[layer_id][indices] = k_data
        self.v_buffer[layer_id][indices] = v_data

    def get_key_buffer(self, layer_id: int):
        if self.layer_transfer_counter is not None:
            self.layer_transfer_counter.wait_until(layer_id)

        if self.store_dtype != self.dtype:
            return self.k_buffer[layer_id].view(self.dtype)
        return self.k_buffer[layer_id]

    def get_value_buffer(self, layer_id: int):
        if self.layer_transfer_counter is not None:
            self.layer_transfer_counter.wait_until(layer_id)

        if self.store_dtype != self.dtype:
            return self.v_buffer[layer_id].view(self.dtype)
        return self.v_buffer[layer_id]

    def get_kv_buffer(self, layer_id: int):
        return self.get_key_buffer(layer_id), self.get_value_buffer(layer_id)

    def set_kv_buffer(
        self,
        layer: RadixAttention,
        loc: torch.Tensor,
        cache_k: torch.Tensor,
        cache_v: torch.Tensor,
        k_scale: Optional[float] = None,
        v_scale: Optional[float] = None,
    ):
        layer_id = layer.layer_id
        if cache_k.dtype != self.dtype:
            if k_scale is not None:
                cache_k.div_(k_scale)
            if v_scale is not None:
                cache_v.div_(v_scale)
            cache_k = cache_k.to(self.dtype)
            cache_v = cache_v.to(self.dtype)

        if self.store_dtype != self.dtype:
            cache_k = cache_k.view(self.store_dtype)
            cache_v = cache_v.view(self.store_dtype)

        if self.capture_mode and cache_k.shape[0] < 4:
            # Overlap the copy of K and V cache for small batch size
            current_stream = self.device_module.current_stream()
            self.alt_stream.wait_stream(current_stream)
            with self.device_module.stream(self.alt_stream):
                self.k_buffer[layer_id][loc] = cache_k
            self.v_buffer[layer_id][loc] = cache_v
            current_stream.wait_stream(self.alt_stream)
        else:
            self.k_buffer[layer_id][loc] = cache_k
            self.v_buffer[layer_id][loc] = cache_v


@torch.compile
def fused_downcast(
    cache_k: torch.Tensor,
    cache_v: torch.Tensor,
    k_scale: torch.Tensor,
    v_scale: torch.Tensor,
    dtype: torch.dtype,
    store_dtype: torch.dtype,
    max_fp8: float,
    min_fp8: float,
):
    cache_k = cache_k / k_scale
    cache_k = torch.clamp(cache_k, min_fp8, max_fp8)
    cache_v = cache_v / v_scale
    cache_v = torch.clamp(cache_v, min_fp8, max_fp8)
    cache_k = cache_k.to(dtype)
    cache_v = cache_v.to(dtype)
    cache_k = cache_k.view(store_dtype)
    cache_v = cache_v.view(store_dtype)
    return cache_k, cache_v


# This compiled version is slower in the unit test
# python3 -m unittest test_bench_serving.TestBenchServing.test_offline_throughput_non_stream_small_batch_size
@torch.compile(dynamic=True, backend=get_compiler_backend())
def copy_two_array(loc, dst_1, src_1, dst_2, src_2, dtype, store_dtype):
    dst_1[loc] = src_1.to(dtype).view(store_dtype)
    dst_2[loc] = src_2.to(dtype).view(store_dtype)


class MLATokenToKVPool(KVCache):
    def __init__(
        self,
        size: int,
        page_size: int,
        dtype: torch.dtype,
        kv_lora_rank: int,
        qk_rope_head_dim: int,
        layer_num: int,
        device: str,
        enable_memory_saver: bool,
    ):
        self.size = size
        self.dtype = dtype
        self.device = device
        if dtype in (torch.float8_e5m2, torch.float8_e4m3fn):
            # NOTE: Store as torch.uint8 because Tensor.index_put is not implemented for torch.float8_e5m2
            self.store_dtype = torch.uint8
        else:
            self.store_dtype = dtype
        self.kv_lora_rank = kv_lora_rank
        self.qk_rope_head_dim = qk_rope_head_dim
        self.layer_num = layer_num

        memory_saver_adapter = TorchMemorySaverAdapter.create(
            enable=enable_memory_saver
        )

        with memory_saver_adapter.region():
            # The padded slot 0 is used for writing dummy outputs from padded tokens.
            self.kv_buffer = [
                torch.zeros(
                    (size + page_size, 1, kv_lora_rank + qk_rope_head_dim),
                    dtype=self.store_dtype,
                    device=device,
                )
                for _ in range(layer_num)
            ]

        self.layer_transfer_counter = None

    def get_key_buffer(self, layer_id: int):
        if self.layer_transfer_counter is not None:
            self.layer_transfer_counter.wait_until(layer_id)

        if self.store_dtype != self.dtype:
            return self.kv_buffer[layer_id].view(self.dtype)
        return self.kv_buffer[layer_id]

    def get_value_buffer(self, layer_id: int):
        if self.layer_transfer_counter is not None:
            self.layer_transfer_counter.wait_until(layer_id)

        if self.store_dtype != self.dtype:
            return self.kv_buffer[layer_id][..., : self.kv_lora_rank].view(self.dtype)
        return self.kv_buffer[layer_id][..., : self.kv_lora_rank]

    def get_kv_buffer(self, layer_id: int):
        return self.get_key_buffer(layer_id), self.get_value_buffer(layer_id)

    def set_kv_buffer(
        self,
        layer: RadixAttention,
        loc: torch.Tensor,
        cache_k: torch.Tensor,
        cache_v: torch.Tensor,
    ):
        layer_id = layer.layer_id
        if cache_k.dtype != self.dtype:
            cache_k = cache_k.to(self.dtype)
        if self.store_dtype != self.dtype:
            self.kv_buffer[layer_id][loc] = cache_k.view(self.store_dtype)
        else:
            self.kv_buffer[layer_id][loc] = cache_k

    def get_flat_data(self, indices):
        # prepare a large chunk of contiguous data for efficient transfer
        return torch.stack([self.kv_buffer[i][indices] for i in range(self.layer_num)])

    @debug_timing
    def transfer(self, indices, flat_data):
        # transfer prepared data from host to device
        flat_data = flat_data.to(device=self.device, non_blocking=False)
        for i in range(self.layer_num):
            self.kv_buffer[i][indices] = flat_data[i]

    def transfer_per_layer(self, indices, flat_data, layer_id):
        # transfer prepared data from host to device
        flat_data = flat_data.to(device=self.device, non_blocking=False)
        self.kv_buffer[layer_id][indices] = flat_data


class DoubleSparseTokenToKVPool(KVCache):
    def __init__(
        self,
        size: int,
        page_size: int,
        dtype: torch.dtype,
        head_num: int,
        head_dim: int,
        layer_num: int,
        device: str,
        heavy_channel_num: int,
        enable_memory_saver: bool,
    ):
        self.size = size
        self.page_size = page_size
        self.dtype = dtype
        self.device = device
        if dtype in (torch.float8_e5m2, torch.float8_e4m3fn):
            # NOTE: Store as torch.uint8 because Tensor.index_put is not implemented for torch.float8_e5m2
            self.store_dtype = torch.uint8
        else:
            self.store_dtype = dtype
        memory_saver_adapter = TorchMemorySaverAdapter.create(
            enable=enable_memory_saver
        )

        with memory_saver_adapter.region():
            # [size, head_num, head_dim] for each layer
            self.k_buffer = [
                torch.zeros(
                    (size + page_size, head_num, head_dim), dtype=dtype, device=device
                )
                for _ in range(layer_num)
            ]
            self.v_buffer = [
                torch.zeros(
                    (size + page_size, head_num, head_dim), dtype=dtype, device=device
                )
                for _ in range(layer_num)
            ]

            # [size, head_num, heavy_channel_num] for each layer
            self.label_buffer = [
                torch.zeros(
                    (size + 1, head_num, heavy_channel_num), dtype=dtype, device=device
                )
                for _ in range(layer_num)
            ]

    def get_key_buffer(self, layer_id: int):
        return self.k_buffer[layer_id]

    def get_value_buffer(self, layer_id: int):
        return self.v_buffer[layer_id]

    def get_label_buffer(self, layer_id: int):
        return self.label_buffer[layer_id]

    def get_kv_buffer(self, layer_id: int):
        return self.k_buffer[layer_id], self.v_buffer[layer_id]

    def set_kv_buffer(
        self,
        layer: RadixAttention,
        loc: torch.Tensor,
        cache_k: torch.Tensor,
        cache_v: torch.Tensor,
        cache_label: torch.Tensor,
    ):
        # NOTE(Andy): ignore the dtype check
        layer_id = layer.layer_id
        self.k_buffer[layer_id][loc] = cache_k
        self.v_buffer[layer_id][loc] = cache_v
        self.label_buffer[layer_id][loc] = cache_label

    def get_flat_data(self, indices):
        pass

    def transfer(self, indices, flat_data):
        pass

    def transfer_per_layer(self, indices, flat_data, layer_id):
        pass


class MemoryStateInt(IntEnum):
    """State of a host KV cache slot: IDLE -> RESERVED (alloc) -> PROTECTED
    (protect_write) -> SYNCED (complete_io) -> BACKUP (update_backup);
    protect_load takes BACKUP -> PROTECTED, and free() returns slots to IDLE."""

    IDLE = 0
    RESERVED = 1
    PROTECTED = 2
    SYNCED = 3
    BACKUP = 4


def synchronized(func):
    @wraps(func)
    def wrapper(self, *args, **kwargs):
        with self.lock:
            return func(self, *args, **kwargs)

    return wrapper


class HostKVCache(abc.ABC):

    def __init__(
        self,
        device_pool: MHATokenToKVPool,
        host_to_device_ratio: float,
        pin_memory: bool = False,  # no need to use pin memory with the double buffering
        device: str = "cpu",
    ):
        assert (
            host_to_device_ratio >= 1
        ), "The host memory should be larger than the device memory with the current protocol"
        # todo, other ways of configuring the size

        self.device_pool = device_pool
        self.host_to_device_ratio = host_to_device_ratio
        self.pin_memory = pin_memory
        self.device = device

        self.size = int(device_pool.size * host_to_device_ratio)
        self.dtype = device_pool.store_dtype
        self.size_per_token = self.get_size_per_token()

        # Verify there is enough available host memory.
        host_mem = psutil.virtual_memory()
        requested_bytes = self.size * self.size_per_token
        # preserve at least 10GB for other usage
        ten_gb = 10 * (1024**3)
        if requested_bytes > host_mem.available - ten_gb:
            raise ValueError(
                f"Not enough host memory available. Requesting "
                f"{requested_bytes / 1e9:.2f} GB but only have "
                f"{host_mem.available / 1e9:.2f} GB free. Please reduce the "
                f"size of the hierarchical cache."
            )
        else:
            logger.info(
                f"Allocating {requested_bytes / 1e9:.2f} GB host memory for hierarchical KV cache."
            )

        self.kv_buffer = self.init_kv_buffer()

        # Initialize memory states and tracking structures.
        self.mem_state = torch.zeros(
            (self.size,), dtype=torch.uint8, device=self.device
        )

        # A lock for synchronized operations on memory allocation and state transitions.
        self.lock = threading.RLock()
        self.clear()

    @abc.abstractmethod
    def get_size_per_token(self):
        raise NotImplementedError()

    @abc.abstractmethod
    def init_kv_buffer(self):
        raise NotImplementedError()

    @abc.abstractmethod
    def transfer(self, indices, flat_data):
        raise NotImplementedError()

    @abc.abstractmethod
    def get_flat_data(self, indices):
        raise NotImplementedError()

    @abc.abstractmethod
    def get_flat_data_by_layer(self, indices, layer_id):
        raise NotImplementedError()

    @abc.abstractmethod
    def assign_flat_data(self, indices, flat_data):
        raise NotImplementedError()

    @synchronized
    def clear(self):
        self.mem_state.fill_(0)
        self.can_use_mem_size = self.size
        self.free_slots = torch.arange(self.size, dtype=torch.int64)

    @synchronized
    def get_state(self, indices: torch.Tensor) -> MemoryStateInt:
        assert len(indices) > 0, "The indices should not be empty"
        states = self.mem_state[indices]
        assert (
            states == states[0]
        ).all(), "The memory slots should have the same state {}".format(states)
        return MemoryStateInt(states[0].item())

    @synchronized
    def alloc(self, need_size: int) -> torch.Tensor:
        if need_size > self.can_use_mem_size:
            return None

        # todo: de-fragementation
        select_index = self.free_slots[:need_size]
        self.free_slots = self.free_slots[need_size:]

        self.mem_state[select_index] = MemoryStateInt.RESERVED
        self.can_use_mem_size -= need_size

        return select_index

    @synchronized
    def is_reserved(self, indices: torch.Tensor) -> bool:
        return self.get_state(indices) == MemoryStateInt.RESERVED

    @synchronized
    def is_protected(self, indices: torch.Tensor) -> bool:
        return self.get_state(indices) == MemoryStateInt.PROTECTED

    @synchronized
    def is_synced(self, indices: torch.Tensor) -> bool:
        return self.get_state(indices) == MemoryStateInt.SYNCED

    @synchronized
    def is_backup(self, indices: torch.Tensor) -> bool:
        return self.get_state(indices) == MemoryStateInt.BACKUP

    @synchronized
    def update_backup(self, indices: torch.Tensor):
        assert self.is_synced(indices), (
            f"The host memory slots should be in SYNCED state before turning into BACKUP. "
            f"Current state: {self.get_state(indices)}"
        )
        self.mem_state[indices] = MemoryStateInt.BACKUP

    @synchronized
    def update_synced(self, indices: torch.Tensor):
        self.mem_state[indices] = MemoryStateInt.SYNCED

    @synchronized
    def protect_write(self, indices: torch.Tensor):
        assert self.is_reserved(indices), (
            f"The host memory slots should be RESERVED before write operations. "
            f"Current state: {self.get_state(indices)}"
        )
        self.mem_state[indices] = MemoryStateInt.PROTECTED

    @synchronized
    def protect_load(self, indices: torch.Tensor):
        assert self.is_backup(indices), (
            f"The host memory slots should be in BACKUP state before load operations. "
            f"Current state: {self.get_state(indices)}"
        )
        self.mem_state[indices] = MemoryStateInt.PROTECTED

    @synchronized
    def complete_io(self, indices: torch.Tensor):
        assert self.is_protected(indices), (
            f"The host memory slots should be PROTECTED during I/O operations. "
            f"Current state: {self.get_state(indices)}"
        )
        self.mem_state[indices] = MemoryStateInt.SYNCED

    def available_size(self):
        return len(self.free_slots)

    @synchronized
    def free(self, indices: torch.Tensor) -> int:
        self.mem_state[indices] = MemoryStateInt.IDLE
        self.free_slots = torch.cat([self.free_slots, indices])
        self.can_use_mem_size += len(indices)
        return len(indices)


class MHATokenToKVPoolHost(HostKVCache):
    def __init__(
        self,
        device_pool: MHATokenToKVPool,
        host_to_device_ratio: float,
        pin_memory: bool = False,  # no need to use pin memory with the double buffering
        device: str = "cpu",
    ):
        super().__init__(device_pool, host_to_device_ratio, pin_memory, device)

    def get_size_per_token(self):
        self.head_num = self.device_pool.head_num
        self.head_dim = self.device_pool.head_dim
        self.layer_num = self.device_pool.layer_num

        return self.head_dim * self.head_num * self.layer_num * self.dtype.itemsize * 2

    def init_kv_buffer(self):
        return torch.empty(
            (2, self.layer_num, self.size, self.head_num, self.head_dim),
            dtype=self.dtype,
            device=self.device,
            pin_memory=self.pin_memory,
        )

    @debug_timing
    def transfer(self, indices, flat_data):
        # backup prepared data from device to host
        self.kv_buffer[:, :, indices] = flat_data.to(
            device=self.device, non_blocking=False
        )

    def get_flat_data(self, indices):
        return self.kv_buffer[:, :, indices]

    def get_flat_data_by_layer(self, indices, layer_id):
        return self.kv_buffer[:, layer_id, indices]

    def assign_flat_data(self, indices, flat_data):
        self.kv_buffer[:, :, indices] = flat_data


class MLATokenToKVPoolHost(HostKVCache):
    def __init__(
        self,
        device_pool: MLATokenToKVPool,
        host_to_device_ratio: float,
        pin_memory: bool = False,  # no need to use pin memory with the double buffering
        device: str = "cpu",
    ):
        super().__init__(device_pool, host_to_device_ratio, pin_memory, device)

    def get_size_per_token(self):
        self.kv_lora_rank = self.device_pool.kv_lora_rank
        self.qk_rope_head_dim = self.device_pool.qk_rope_head_dim
        self.layer_num = self.device_pool.layer_num

        return (self.kv_lora_rank + self.qk_rope_head_dim) * 1 * self.dtype.itemsize

    def init_kv_buffer(self):
        return torch.empty(
            (
                self.layer_num,
                self.size,
                1,
                self.kv_lora_rank + self.qk_rope_head_dim,
            ),
            dtype=self.dtype,
            device=self.device,
            pin_memory=self.pin_memory,
        )

    @debug_timing
    def transfer(self, indices, flat_data):
        # backup prepared data from device to host
        self.kv_buffer[:, indices] = flat_data.to(
            device=self.device, non_blocking=False
        )

    def get_flat_data(self, indices):
        return self.kv_buffer[:, indices]

    def get_flat_data_by_layer(self, indices, layer_id):
        return self.kv_buffer[layer_id, indices]

    def assign_flat_data(self, indices, flat_data):
        self.kv_buffer[:, indices] = flat_data