"""
Copyright 2023-2024 SGLang Team
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
"""

"""
Memory pool.

SGLang has two levels of memory pool:
ReqToTokenPool maps a request to its token locations.
TokenToKVPoolAllocator manages the indices into the KV cache data, while
KVCache holds the physical KV cache itself.
"""

import abc
import logging
from typing import List, Optional, Tuple, Union

import numpy as np
import torch
import triton
import triton.language as tl

from sglang.srt.layers.radix_attention import RadixAttention
from sglang.srt.torch_memory_saver_adapter import TorchMemorySaverAdapter
from sglang.srt.utils import debug_timing, is_cuda, next_power_of_2

logger = logging.getLogger(__name__)

GB = 1024 * 1024 * 1024
_is_cuda = is_cuda()
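# Minimal wiring sketch of the two pool levels (illustrative sizes, not the
# real server defaults):
#
#   kv_cache = MHATokenToKVPool(
#       size=4096, page_size=1, dtype=torch.bfloat16, head_num=8, head_dim=128,
#       layer_num=32, device="cuda", enable_memory_saver=False,
#   )
#   token_allocator = TokenToKVPoolAllocator(
#       size=4096, dtype=torch.bfloat16, device="cuda", kvcache=kv_cache
#   )
#   req_to_token_pool = ReqToTokenPool(
#       size=256, max_context_len=4096, device="cuda", enable_memory_saver=False
#   )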


class ReqToTokenPool:
    """A memory pool that maps a request to its token locations."""

    def __init__(
        self,
        size: int,
        max_context_len: int,
        device: str,
        enable_memory_saver: bool,
    ):
        memory_saver_adapter = TorchMemorySaverAdapter.create(
            enable=enable_memory_saver
        )

        self.size = size
        self.max_context_len = max_context_len
        self.device = device
        with memory_saver_adapter.region():
            self.req_to_token = torch.zeros(
                (size, max_context_len), dtype=torch.int32, device=device
            )
        self.free_slots = list(range(size))

    def write(self, indices, values):
        self.req_to_token[indices] = values

    def available_size(self):
        return len(self.free_slots)

    def alloc(self, need_size: int) -> Optional[List[int]]:
        if need_size > len(self.free_slots):
            return None

        select_index = self.free_slots[:need_size]
        self.free_slots = self.free_slots[need_size:]

        return select_index

    def free(self, free_index: Union[int, List[int]]):
        if isinstance(free_index, int):
            self.free_slots.append(free_index)
        else:
            self.free_slots.extend(free_index)

    def clear(self):
        self.free_slots = list(range(self.size))


class KVCache(abc.ABC):
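    """Abstract base class for the physical KV cache storage of a model
    (one set of buffers per layer)."""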
    @abc.abstractmethod
    def __init__(
        self,
        size: int,
        page_size: int,
        dtype: torch.dtype,
        layer_num: int,
        device: str,
        enable_memory_saver: bool,
        start_layer: Optional[int] = None,
        end_layer: Optional[int] = None,
    ):
        self.size = size
        self.page_size = page_size
        self.dtype = dtype
        self.device = device
        if dtype in (torch.float8_e5m2, torch.float8_e4m3fn):
            # NOTE: Store as torch.uint8 because Tensor.index_put is not implemented for torch.float8_e5m2
            self.store_dtype = torch.uint8
        else:
            self.store_dtype = dtype
        self.layer_num = layer_num
        self.start_layer = start_layer or 0
        self.end_layer = end_layer or layer_num - 1
        self.memory_saver_adapter = TorchMemorySaverAdapter.create(
            enable=enable_memory_saver
        )

    @abc.abstractmethod
    def get_key_buffer(self, layer_id: int) -> torch.Tensor:
        raise NotImplementedError()

    @abc.abstractmethod
    def get_value_buffer(self, layer_id: int) -> torch.Tensor:
        raise NotImplementedError()

    @abc.abstractmethod
    def get_kv_buffer(self, layer_id: int) -> Tuple[torch.Tensor, torch.Tensor]:
        raise NotImplementedError()

    @abc.abstractmethod
    def set_kv_buffer(
        self,
        layer: RadixAttention,
        loc: torch.Tensor,
        cache_k: torch.Tensor,
        cache_v: torch.Tensor,
    ) -> None:
        raise NotImplementedError()

    def get_flat_data(self, indices):
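        """Gather the KV entries at `indices` into one contiguous tensor
        (used to prepare data for host/device transfer)."""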
        raise NotImplementedError()

    def transfer(self, indices, flat_data):
        raise NotImplementedError()

    def transfer_per_layer(self, indices, flat_data, layer_id):
        raise NotImplementedError()

    def register_layer_transfer_counter(self, layer_transfer_counter):
        self.layer_transfer_counter = layer_transfer_counter


class TokenToKVPoolAllocator:
    """An allocator managing the indices to kv cache data."""

    def __init__(
        self,
        size: int,
        dtype: torch.dtype,
        device: str,
        kvcache: KVCache,
    ):
        self.size = size
        self.dtype = dtype
        self.device = device
        self.page_size = 1

        self.free_slots = None
        self.is_not_in_free_group = True
        self.free_group = []
        self.clear()

        self._kvcache = kvcache

    def available_size(self):
        return len(self.free_slots)

    def debug_print(self) -> str:
        return ""

    def get_kvcache(self):
        return self._kvcache

    def alloc(self, need_size: int):
        if need_size > len(self.free_slots):
            return None

        select_index = self.free_slots[:need_size]
        self.free_slots = self.free_slots[need_size:]
        return select_index

    def free(self, free_index: torch.Tensor):
        if free_index.numel() == 0:
            return

        if self.is_not_in_free_group:
            self.free_slots = torch.cat((self.free_slots, free_index))
        else:
            self.free_group.append(free_index)

    def free_group_begin(self):
        self.is_not_in_free_group = False
        self.free_group = []

    def free_group_end(self):
        self.is_not_in_free_group = True
        if self.free_group:
            self.free(torch.cat(self.free_group))
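    # Batched-free usage sketch (hypothetical caller code):
    #   allocator.free_group_begin()
    #   for kv_indices in finished_request_indices:  # hypothetical iterable
    #       allocator.free(kv_indices)
    #   allocator.free_group_end()  # returns everything with a single torch.cat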

    def backup_state(self):
        return self.free_slots

    def restore_state(self, free_slots):
        self.free_slots = free_slots

    def clear(self):
        # The padded slot 0 is used for writing dummy outputs from padded tokens.
        self.free_slots = torch.arange(
            1, self.size + 1, dtype=torch.int64, device=self.device
        )
        self.is_not_in_free_group = True
        self.free_group = []

    def get_cpu_copy(self, indices):
        return self._kvcache.get_cpu_copy(indices)

    def load_cpu_copy(self, kv_cache_cpu, indices):
        return self._kvcache.load_cpu_copy(kv_cache_cpu, indices)


class MHATokenToKVPool(KVCache):
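    """KV cache for standard multi-head attention (MHA):
    one K buffer and one V buffer per layer."""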

    def __init__(
        self,
        size: int,
        page_size: int,
        dtype: torch.dtype,
        head_num: int,
        head_dim: int,
        layer_num: int,
        device: str,
        enable_memory_saver: bool,
        start_layer: Optional[int] = None,
        end_layer: Optional[int] = None,
    ):
        super().__init__(
            size,
            page_size,
            dtype,
            layer_num,
            device,
            enable_memory_saver,
            start_layer,
            end_layer,
        )

        self.head_num = head_num
        self.head_dim = head_dim
        self._create_buffers()

        # used for chunked cpu-offloading
        self.chunk_size = 8192
        self.layer_transfer_counter = None
        self.device_module = torch.get_device_module(self.device)
        self.alt_stream = self.device_module.Stream() if _is_cuda else None

        k_size, v_size = self.get_kv_size_bytes()
        logger.info(
            f"KV Cache is allocated. #tokens: {size}, K size: {k_size / GB:.2f} GB, V size: {v_size / GB:.2f} GB"
        )

    def _create_buffers(self):
        with self.memory_saver_adapter.region():
            # [size, head_num, head_dim] for each layer
            # The padded slot 0 is used for writing dummy outputs from padded tokens.
            self.k_buffer = [
                torch.zeros(
                    (self.size + self.page_size, self.head_num, self.head_dim),
                    dtype=self.store_dtype,
                    device=self.device,
                )
                for _ in range(self.layer_num)
            ]
            self.v_buffer = [
                torch.zeros(
                    (self.size + self.page_size, self.head_num, self.head_dim),
                    dtype=self.store_dtype,
                    device=self.device,
                )
                for _ in range(self.layer_num)
            ]

    def _clear_buffers(self):
        del self.k_buffer
        del self.v_buffer

    def get_kv_size_bytes(self):
        assert hasattr(self, "k_buffer")
        assert hasattr(self, "v_buffer")
        k_size_bytes = 0
        for k_cache in self.k_buffer:
            k_size_bytes += np.prod(k_cache.shape) * k_cache.dtype.itemsize
        v_size_bytes = 0
        for v_cache in self.v_buffer:
            v_size_bytes += np.prod(v_cache.shape) * v_cache.dtype.itemsize
        return k_size_bytes, v_size_bytes

    # For disaggregation (KV cache transfer)
    def get_contiguous_buf_infos(self):
        # layer_num x [seq_len, head_num, head_dim]
        # layer_num x [page_num, page_size, head_num, head_dim]
        kv_data_ptrs = [
            self.get_key_buffer(i).data_ptr()
            for i in range(self.start_layer, self.start_layer + self.layer_num)
        ] + [
            self.get_value_buffer(i).data_ptr()
            for i in range(self.start_layer, self.start_layer + self.layer_num)
        ]
        kv_data_lens = [
            self.get_key_buffer(i).nbytes
            for i in range(self.start_layer, self.start_layer + self.layer_num)
        ] + [
            self.get_value_buffer(i).nbytes
            for i in range(self.start_layer, self.start_layer + self.layer_num)
        ]
        kv_item_lens = [
            self.get_key_buffer(i)[0].nbytes * self.page_size
            for i in range(self.start_layer, self.start_layer + self.layer_num)
        ] + [
            self.get_value_buffer(i)[0].nbytes * self.page_size
            for i in range(self.start_layer, self.start_layer + self.layer_num)
        ]
        return kv_data_ptrs, kv_data_lens, kv_item_lens

    def get_cpu_copy(self, indices):
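        """Copy the K/V entries at `indices` to CPU tensors, `chunk_size` tokens
        at a time (the chunked CPU-offloading path)."""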
        torch.cuda.synchronize()
        kv_cache_cpu = []
        for layer_id in range(self.layer_num):
            kv_cache_cpu.append([])
            for i in range(0, len(indices), self.chunk_size):
                chunk_indices = indices[i : i + self.chunk_size]
                k_cpu = self.k_buffer[layer_id][chunk_indices].to(
                    "cpu", non_blocking=True
                )
                v_cpu = self.v_buffer[layer_id][chunk_indices].to(
                    "cpu", non_blocking=True
                )
                kv_cache_cpu[-1].append([k_cpu, v_cpu])
        torch.cuda.synchronize()
        return kv_cache_cpu

    def load_cpu_copy(self, kv_cache_cpu, indices):
        torch.cuda.synchronize()
        for layer_id in range(self.layer_num):
            for i in range(0, len(indices), self.chunk_size):
                chunk_indices = indices[i : i + self.chunk_size]
                k_cpu, v_cpu = (
                    kv_cache_cpu[layer_id][i // self.chunk_size][0],
                    kv_cache_cpu[layer_id][i // self.chunk_size][1],
                )
                assert k_cpu.shape[0] == v_cpu.shape[0] == len(chunk_indices)
                k_chunk = k_cpu.to(self.k_buffer[0].device, non_blocking=True)
                v_chunk = v_cpu.to(self.v_buffer[0].device, non_blocking=True)
                self.k_buffer[layer_id][chunk_indices] = k_chunk
                self.v_buffer[layer_id][chunk_indices] = v_chunk
        torch.cuda.synchronize()

    # Todo: different memory layout
    def get_flat_data(self, indices):
        # prepare a large chunk of contiguous data for efficient transfer
        flatten = torch.stack(
            [
                torch.stack([self.k_buffer[i][indices] for i in range(self.layer_num)]),
                torch.stack([self.v_buffer[i][indices] for i in range(self.layer_num)]),
            ]
        )
        return flatten

    @debug_timing
    def transfer(self, indices, flat_data):
        # transfer prepared data from host to device
        flat_data = flat_data.to(device=self.device, non_blocking=False)
        k_data, v_data = flat_data[0], flat_data[1]
        for i in range(self.layer_num):
            self.k_buffer[i][indices] = k_data[i]
            self.v_buffer[i][indices] = v_data[i]

    def transfer_per_layer(self, indices, flat_data, layer_id):
        # transfer prepared data from host to device
        flat_data = flat_data.to(device=self.device, non_blocking=False)
        k_data, v_data = flat_data[0], flat_data[1]
        self.k_buffer[layer_id - self.start_layer][indices] = k_data
        self.v_buffer[layer_id - self.start_layer][indices] = v_data

    def get_key_buffer(self, layer_id: int):
        if self.layer_transfer_counter is not None:
            self.layer_transfer_counter.wait_until(layer_id - self.start_layer)

        if self.store_dtype != self.dtype:
            return self.k_buffer[layer_id - self.start_layer].view(self.dtype)
        return self.k_buffer[layer_id - self.start_layer]

    def get_value_buffer(self, layer_id: int):
        if self.layer_transfer_counter is not None:
            self.layer_transfer_counter.wait_until(layer_id - self.start_layer)

        if self.store_dtype != self.dtype:
            return self.v_buffer[layer_id - self.start_layer].view(self.dtype)
        return self.v_buffer[layer_id - self.start_layer]

    def get_kv_buffer(self, layer_id: int):
        return self.get_key_buffer(layer_id), self.get_value_buffer(layer_id)

    def set_kv_buffer(
        self,
        layer: RadixAttention,
        loc: torch.Tensor,
        cache_k: torch.Tensor,
        cache_v: torch.Tensor,
        k_scale: Optional[float] = None,
        v_scale: Optional[float] = None,
    ):
        from sglang.srt.model_executor.cuda_graph_runner import get_is_capture_mode

        layer_id = layer.layer_id
        if cache_k.dtype != self.dtype:
            if k_scale is not None:
                cache_k.div_(k_scale)
            if v_scale is not None:
                cache_v.div_(v_scale)
            cache_k = cache_k.to(self.dtype)
            cache_v = cache_v.to(self.dtype)

        if self.store_dtype != self.dtype:
            cache_k = cache_k.view(self.store_dtype)
            cache_v = cache_v.view(self.store_dtype)

        if get_is_capture_mode() and self.alt_stream is not None:
            # Overlap the copy of K and V cache for small batch size
            current_stream = self.device_module.current_stream()
            self.alt_stream.wait_stream(current_stream)
            self.k_buffer[layer_id - self.start_layer][loc] = cache_k
            with self.device_module.stream(self.alt_stream):
                self.v_buffer[layer_id - self.start_layer][loc] = cache_v
            current_stream.wait_stream(self.alt_stream)
        else:
            self.k_buffer[layer_id - self.start_layer][loc] = cache_k
            self.v_buffer[layer_id - self.start_layer][loc] = cache_v


@triton.jit
def set_mla_kv_buffer_kernel(
    kv_buffer_ptr,
    cache_k_nope_ptr,
    cache_k_rope_ptr,
    loc_ptr,
    buffer_stride: tl.constexpr,
    nope_stride: tl.constexpr,
    rope_stride: tl.constexpr,
    nope_dim: tl.constexpr,
    rope_dim: tl.constexpr,
    BLOCK: tl.constexpr,
):
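    # Grid: (num_tokens, cdiv(nope_dim + rope_dim, BLOCK)). Axis 0 picks the
    # destination row via `loc`; axis 1 covers one BLOCK-wide slice of the
    # concatenated nope+rope head dimension.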
    pid_loc = tl.program_id(0)
    pid_blk = tl.program_id(1)

    base = pid_blk * BLOCK
    offs = base + tl.arange(0, BLOCK)
    total_dim = nope_dim + rope_dim
    mask = offs < total_dim

    loc = tl.load(loc_ptr + pid_loc)
    dst_ptr = kv_buffer_ptr + loc * buffer_stride + offs

    if base + BLOCK <= nope_dim:
        src = tl.load(
            cache_k_nope_ptr + pid_loc * nope_stride + offs,
            mask=mask,
        )
    else:
        offs_rope = offs - nope_dim
        src = tl.load(
            cache_k_rope_ptr + pid_loc * rope_stride + offs_rope,
            mask=mask,
        )

    tl.store(dst_ptr, src, mask=mask)


def set_mla_kv_buffer_triton(
    kv_buffer: torch.Tensor,
    loc: torch.Tensor,
    cache_k_nope: torch.Tensor,
    cache_k_rope: torch.Tensor,
):
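    """Scatter the concatenated (nope, rope) key entries into `kv_buffer`
    at the rows given by `loc`."""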
    nope_dim = cache_k_nope.shape[-1]
    rope_dim = cache_k_rope.shape[-1]
    total_dim = nope_dim + rope_dim
    BLOCK = 128
    n_loc = loc.numel()
    grid = (n_loc, triton.cdiv(total_dim, BLOCK))

    set_mla_kv_buffer_kernel[grid](
        kv_buffer,
        cache_k_nope,
        cache_k_rope,
        loc,
        kv_buffer.stride(0),
        cache_k_nope.stride(0),
        cache_k_rope.stride(0),
        nope_dim,
        rope_dim,
        BLOCK=BLOCK,
    )


class MLATokenToKVPool(KVCache):
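    """KV cache for multi-head latent attention (MLA): a single compressed buffer
    per layer that stores the latent part (`kv_lora_rank`) concatenated with the
    rope part (`qk_rope_head_dim`)."""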
    def __init__(
        self,
        size: int,
        page_size: int,
        dtype: torch.dtype,
        kv_lora_rank: int,
        qk_rope_head_dim: int,
        layer_num: int,
        device: str,
        enable_memory_saver: bool,
        start_layer: Optional[int] = None,
        end_layer: Optional[int] = None,
    ):
        super().__init__(
            size,
            page_size,
            dtype,
            layer_num,
            device,
            enable_memory_saver,
            start_layer,
            end_layer,
        )

        self.kv_lora_rank = kv_lora_rank
        self.qk_rope_head_dim = qk_rope_head_dim

        with self.memory_saver_adapter.region():
            # The padded slot 0 is used for writing dummy outputs from padded tokens.
            self.kv_buffer = [
                torch.zeros(
                    (size + page_size, 1, kv_lora_rank + qk_rope_head_dim),
                    dtype=self.store_dtype,
                    device=device,
                )
                for _ in range(layer_num)
            ]

        self.layer_transfer_counter = None

        kv_size = self.get_kv_size_bytes()
        logger.info(
            f"KV Cache is allocated. #tokens: {size}, KV size: {kv_size / GB:.2f} GB"
        )

    def get_kv_size_bytes(self):
        assert hasattr(self, "kv_buffer")
        kv_size_bytes = 0
        for kv_cache in self.kv_buffer:
            kv_size_bytes += np.prod(kv_cache.shape) * kv_cache.dtype.itemsize
        return kv_size_bytes

    # For disaggregation (KV cache transfer)
    def get_contiguous_buf_infos(self):
        # MLA has only one kv_buffer, so only the information of this buffer needs to be returned.
        kv_data_ptrs = [self.kv_buffer[i].data_ptr() for i in range(self.layer_num)]
        kv_data_lens = [self.kv_buffer[i].nbytes for i in range(self.layer_num)]
        kv_item_lens = [
            self.kv_buffer[i][0].nbytes * self.page_size for i in range(self.layer_num)
        ]
        return kv_data_ptrs, kv_data_lens, kv_item_lens

    def get_key_buffer(self, layer_id: int):
        if self.layer_transfer_counter is not None:
            self.layer_transfer_counter.wait_until(layer_id - self.start_layer)

        if self.store_dtype != self.dtype:
            return self.kv_buffer[layer_id - self.start_layer].view(self.dtype)
        return self.kv_buffer[layer_id - self.start_layer]

    def get_value_buffer(self, layer_id: int):
        if self.layer_transfer_counter is not None:
            self.layer_transfer_counter.wait_until(layer_id - self.start_layer)

        if self.store_dtype != self.dtype:
            return self.kv_buffer[layer_id - self.start_layer][
                ..., : self.kv_lora_rank
            ].view(self.dtype)
        return self.kv_buffer[layer_id - self.start_layer][..., : self.kv_lora_rank]

    def get_kv_buffer(self, layer_id: int):
        return self.get_key_buffer(layer_id), self.get_value_buffer(layer_id)

    def set_kv_buffer(
        self,
        layer: RadixAttention,
        loc: torch.Tensor,
        cache_k: torch.Tensor,
        cache_v: torch.Tensor,
    ):
        layer_id = layer.layer_id
        if cache_k.dtype != self.dtype:
            cache_k = cache_k.to(self.dtype)
        if self.store_dtype != self.dtype:
            self.kv_buffer[layer_id - self.start_layer][loc] = cache_k.view(
                self.store_dtype
            )
        else:
            self.kv_buffer[layer_id - self.start_layer][loc] = cache_k

    def set_mla_kv_buffer(
        self,
        layer: RadixAttention,
        loc: torch.Tensor,
        cache_k_nope: torch.Tensor,
        cache_k_rope: torch.Tensor,
    ):
        layer_id = layer.layer_id
        if cache_k_nope.dtype != self.dtype:
            cache_k_nope = cache_k_nope.to(self.dtype)
            cache_k_rope = cache_k_rope.to(self.dtype)
        if self.store_dtype != self.dtype:
            cache_k_nope = cache_k_nope.view(self.store_dtype)
            cache_k_rope = cache_k_rope.view(self.store_dtype)

        set_mla_kv_buffer_triton(
            self.kv_buffer[layer_id], loc, cache_k_nope, cache_k_rope
        )

    def get_flat_data(self, indices):
        # prepare a large chunk of contiguous data for efficient transfer
        return torch.stack([self.kv_buffer[i][indices] for i in range(self.layer_num)])

    @debug_timing
    def transfer(self, indices, flat_data):
        # transfer prepared data from host to device
        flat_data = flat_data.to(device=self.device, non_blocking=False)
        for i in range(self.layer_num):
            self.kv_buffer[i][indices] = flat_data[i]

    def transfer_per_layer(self, indices, flat_data, layer_id):
        # transfer prepared data from host to device
        flat_data = flat_data.to(device=self.device, non_blocking=False)
        self.kv_buffer[layer_id - self.start_layer][indices] = flat_data


class DoubleSparseTokenToKVPool(KVCache):
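    """KV cache for Double Sparsity attention: regular K/V buffers plus a
    per-token label buffer over the `heavy_channel_num` heavy channels."""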
    def __init__(
        self,
        size: int,
        page_size: int,
        dtype: torch.dtype,
        head_num: int,
        head_dim: int,
        layer_num: int,
        device: str,
        heavy_channel_num: int,
        enable_memory_saver: bool,
        start_layer: Optional[int] = None,
        end_layer: Optional[int] = None,
    ):
        super().__init__(
            size,
            page_size,
            dtype,
            layer_num,
            device,
            enable_memory_saver,
            start_layer,
            end_layer,
        )

        with self.memory_saver_adapter.region():
            # [size, head_num, head_dim] for each layer
            self.k_buffer = [
                torch.zeros(
                    (size + page_size, head_num, head_dim), dtype=dtype, device=device
                )
                for _ in range(layer_num)
            ]
            self.v_buffer = [
                torch.zeros(
                    (size + page_size, head_num, head_dim), dtype=dtype, device=device
                )
                for _ in range(layer_num)
            ]

            # [size, head_num, heavy_channel_num] for each layer
            self.label_buffer = [
                torch.zeros(
                    (size + 1, head_num, heavy_channel_num), dtype=dtype, device=device
                )
                for _ in range(layer_num)
            ]

    def get_key_buffer(self, layer_id: int):
        return self.k_buffer[layer_id - self.start_layer]

    def get_value_buffer(self, layer_id: int):
        return self.v_buffer[layer_id - self.start_layer]

    def get_label_buffer(self, layer_id: int):
        return self.label_buffer[layer_id - self.start_layer]

    def get_kv_buffer(self, layer_id: int):
        return (
            self.k_buffer[layer_id - self.start_layer],
            self.v_buffer[layer_id - self.start_layer],
        )

    def set_kv_buffer(
        self,
        layer: RadixAttention,
        loc: torch.Tensor,
        cache_k: torch.Tensor,
        cache_v: torch.Tensor,
        cache_label: torch.Tensor,
    ):
        # NOTE(Andy): ignore the dtype check
        layer_id = layer.layer_id
        self.k_buffer[layer_id - self.start_layer][loc] = cache_k
        self.v_buffer[layer_id - self.start_layer][loc] = cache_v
        self.label_buffer[layer_id - self.start_layer][loc] = cache_label

    def get_flat_data(self, indices):
        pass

    def transfer(self, indices, flat_data):
        pass

    def transfer_per_layer(self, indices, flat_data, layer_id):
        pass