allocator.py 19.4 KB
Newer Older
1
2
from __future__ import annotations

Lianmin Zheng's avatar
Lianmin Zheng committed
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
"""
Copyright 2025 SGLang Team
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
"""

"""
Page-aligned memory pool.
"""

22
23
24
import abc
from typing import TYPE_CHECKING

Lianmin Zheng's avatar
Lianmin Zheng committed
25
26
27
28
import torch
import triton
import triton.language as tl

tarinkk's avatar
tarinkk committed
29
from sglang.srt.mem_cache.memory_pool import SWAKVPool
30
from sglang.srt.utils import get_bool_env_var, get_num_new_pages, next_power_of_2
liucong8560's avatar
liucong8560 committed
31
from sgl_kernel.kvcacheio import dcu_alloc_decode_kernel, dcu_alloc_extend_kernel
Lianmin Zheng's avatar
Lianmin Zheng committed
32

33
34
35
36
37
38
39
40
41
42
43
44
45
if TYPE_CHECKING:
    from sglang.srt.mem_cache.memory_pool import KVCache


class BaseTokenToKVPoolAllocator(abc.ABC):
    """Common state and helpers shared by the token-to-KV-pool allocators.

    Free space is tracked as page indices in two tensors:

    * ``free_pages``: pages that can be handed out right away.
    * ``release_pages``: returned pages parked until ``merge_and_sort_free``
      folds them back into ``free_pages``.

    Frees issued between ``free_group_begin`` and ``free_group_end`` are
    buffered and released together by a single ``free`` call.
    """

    @abc.abstractmethod
    def __init__(
        self,
        size: int,
        page_size: int,
        dtype: torch.dtype,
        device: str,
        kvcache: KVCache,
        need_sort: bool,
    ):
        self.size = size
        self.page_size = page_size
        self.dtype = dtype
        self.device = device
        self._kvcache = kvcache
        self.need_sort = need_sort

        # Set up by concrete subclasses (e.g. in their ``clear``).
        self.free_pages = None
        self.release_pages = None

        # Free-group bookkeeping: while grouping, ``free`` buffers its input in
        # ``free_group`` instead of releasing it.
        self.is_not_in_free_group = True
        self.free_group = []

    def debug_print(self) -> str:
        """Return a human-readable summary of allocator state (empty here)."""
        return ""

    def available_size(self):
        """Return the number of free token slots, including parked pages."""
        return (len(self.free_pages) + len(self.release_pages)) * self.page_size

    def get_kvcache(self):
        """Return the KV cache this allocator hands out indices for."""
        return self._kvcache

    def restore_state(self, state):
        """Restore the free lists from a snapshot taken by ``backup_state``."""
        self.free_pages, self.release_pages = state

    def backup_state(self):
        """Snapshot the free lists as ``(free_pages, release_pages)``.

        The tensors are returned by reference, not copied.
        """
        return (self.free_pages, self.release_pages)

    def free_group_begin(self):
        """Start buffering ``free`` calls until ``free_group_end``."""
        self.is_not_in_free_group = False
        self.free_group = []

    def free_group_end(self):
        """Stop buffering and release everything freed since ``free_group_begin``.

        The buffer is detached before flushing, so the buffered tensors are not
        kept alive afterwards. It also means a repeated ``free_group_end`` cannot
        release the same indices twice.
        """
        self.is_not_in_free_group = True
        pending, self.free_group = self.free_group, []
        if pending:
            self.free(torch.cat(pending))

    def merge_and_sort_free(self):
        """Move all parked pages into ``free_pages`` and sort it ascending.

        ``release_pages`` is left empty. This is a no-op when nothing is parked.
        """
        if len(self.release_pages) > 0:
            self.free_pages = torch.cat((self.free_pages, self.release_pages))
            self.free_pages, _ = torch.sort(self.free_pages)
            self.release_pages = torch.empty(
                (0,), dtype=self.release_pages.dtype, device=self.device
            )

    def get_cpu_copy(self, *args, **kwargs):
        # FIXME: reuse the get_cpu_copy after paged allocator is implemented
        raise NotImplementedError()

    def load_cpu_copy(self, *args, **kwargs):
        # FIXME: reuse the load_cpu_copy after paged allocator is implemented
        raise NotImplementedError()

    def alloc_extend(self, *args, **kwargs):
        """Page-aware extend allocation; only the paged allocator implements it."""
        raise NotImplementedError("alloc_extend is only for paged allocator")

    def alloc_decode(self, *args, **kwargs):
        """Page-aware decode allocation; only the paged allocator implements it."""
        raise NotImplementedError("alloc_decode is only for paged allocator")

    @abc.abstractmethod
    def clear(self):
        """Reset the allocator so every usable slot is free."""
        raise NotImplementedError()

    @abc.abstractmethod
    def alloc(self, need_size: int):
        """Allocate ``need_size`` token slots; return their indices or ``None``."""
        raise NotImplementedError()

    @abc.abstractmethod
    def free(self, free_index: torch.Tensor):
        """Return the token slots in ``free_index`` to the allocator."""
        raise NotImplementedError()


class TokenToKVPoolAllocator(BaseTokenToKVPoolAllocator):
    """Token-granular allocator (page size 1) over the indices of a KV cache.

    Usable slots are ``1..size``; slot 0 is never handed out.
    """

    def __init__(
        self,
        size: int,
        dtype: torch.dtype,
        device: str,
        kvcache: KVCache,
        need_sort: bool,
    ):
        # Each "page" is exactly one token slot.
        super().__init__(size, 1, dtype, device, kvcache, need_sort)
        self.clear()

    def clear(self):
        """Mark every slot except the reserved slot 0 as free."""
        # Slot 0 stays reserved as the write target for dummy outputs of padded tokens.
        first_usable, stop = 1, self.size + 1
        self.free_pages = torch.arange(
            first_usable, stop, dtype=torch.int64, device=self.device
        )
        self.is_not_in_free_group = True
        self.free_group = []
        self.release_pages = torch.empty((0,), dtype=torch.int64, device=self.device)

    def available_size(self):
        """Return the number of free slots (page size is 1, so no scaling)."""
        parked = len(self.release_pages)
        return len(self.free_pages) + parked

    def alloc(self, need_size: int):
        """Take ``need_size`` slots from the front of the free list.

        Returns the slot indices, or ``None`` if not enough slots are free.
        """
        if need_size > len(self.free_pages) and self.need_sort:
            # Pull parked slots back in before deciding we are out of space.
            self.merge_and_sort_free()

        if len(self.free_pages) < need_size:
            return None

        taken, self.free_pages = self.free_pages[:need_size], self.free_pages[need_size:]
        return taken

    def free(self, free_index: torch.Tensor):
        """Return slots to the allocator, or buffer them inside a free group."""
        if free_index.numel() == 0:
            return

        if not self.is_not_in_free_group:
            # Between free_group_begin/free_group_end: defer the release.
            self.free_group.append(free_index)
        elif self.need_sort:
            # Parked until the next merge_and_sort_free.
            self.release_pages = torch.cat((self.release_pages, free_index))
        else:
            self.free_pages = torch.cat((self.free_pages, free_index))

    def get_cpu_copy(self, indices):
        """Delegate to the backing KV cache's ``get_cpu_copy``."""
        cache = self._kvcache
        return cache.get_cpu_copy(indices)

    def load_cpu_copy(self, kv_cache_cpu, indices):
        """Delegate to the backing KV cache's ``load_cpu_copy``."""
        cache = self._kvcache
        return cache.load_cpu_copy(kv_cache_cpu, indices)

Lianmin Zheng's avatar
Lianmin Zheng committed
175

tarinkk's avatar
tarinkk committed
176
177
178
179
180
181
182
183
184
185
class SWATokenToKVPoolAllocator(BaseTokenToKVPoolAllocator):
    """Allocator for SWA hybrid KV cache.

    Owns two token-level sub-allocators: one over the full-attention pool and
    one over the sliding-window (SWA) pool. Every allocation takes the same
    number of slots from both. ``full_to_swa_index_mapping[full_slot]`` records
    the SWA slot paired with each full-attention slot; the entry is 0 after the
    slot is released or the allocator is cleared.
    """

    def __init__(
        self,
        size: int,
        size_swa: int,
        dtype: torch.dtype,
        device: str,
        kvcache: SWAKVPool,
        need_sort: bool,
    ):
        super().__init__(size, 1, dtype, device, kvcache, need_sort)
        assert isinstance(kvcache, SWAKVPool)
        self._size_full = size
        self._size_swa = size_swa

        def sub_allocator(num_slots, pool):
            # Both pools share the same token-level allocation settings.
            return TokenToKVPoolAllocator(num_slots, dtype, device, pool, need_sort)

        self.full_attn_allocator = sub_allocator(size, kvcache.full_kv_pool)
        self.swa_attn_allocator = sub_allocator(size_swa, kvcache.swa_kv_pool)

        # Indexed by full-attention slot index; holds the paired SWA slot index.
        self.full_to_swa_index_mapping = torch.empty(
            size + size_swa + 1, dtype=torch.int64, device=device
        )
        self.clear()

        # The KV pool object receives a reference to the same mapping tensor.
        self._kvcache.full_to_swa_index_mapping = self.full_to_swa_index_mapping

    def available_size(self):
        # No single combined figure is defined for the two pools; query
        # full_available_size() / swa_available_size() instead.
        raise NotImplementedError()

    def full_available_size(self):
        """Number of free slots in the full-attention pool."""
        allocator = self.full_attn_allocator
        return allocator.available_size()

    def swa_available_size(self):
        """Number of free slots in the SWA pool."""
        allocator = self.swa_attn_allocator
        return allocator.available_size()

    @property
    def size_full(self):
        """Capacity of the full-attention pool."""
        return self._size_full

    @property
    def size_swa(self):
        """Capacity of the SWA pool."""
        return self._size_swa

    def debug_print(self) -> str:
        """Report the free-slot counts of both pools."""
        swa_free = self.swa_attn_allocator.available_size()
        full_free = self.full_attn_allocator.available_size()
        return (
            f"#swa-available-size: {swa_free}, "
            f"#full-attn-available-size: {full_free}, "
        )

    def get_kvcache(self):
        """Return the hybrid KV pool."""
        return self._kvcache

    def translate_loc_from_full_to_swa(self, kv_indices: torch.Tensor):
        """Map full-attention slot indices to their SWA slot indices (as int32)."""
        assert self.full_to_swa_index_mapping is not None
        swa_loc = self.full_to_swa_index_mapping[kv_indices]
        return swa_loc.to(torch.int32)

    def alloc(self, need_size: int):
        """Allocate ``need_size`` slots in both pools and record their pairing.

        Returns the full-attention indices, or ``None`` if either pool is short.
        """
        full_alloc = self.full_attn_allocator
        swa_alloc = self.swa_attn_allocator
        if (
            need_size > full_alloc.available_size()
            or need_size > swa_alloc.available_size()
        ):
            return None

        full_indices = full_alloc.alloc(need_size)
        swa_indices = swa_alloc.alloc(need_size)
        self.full_to_swa_index_mapping[full_indices] = swa_indices
        return full_indices

    def free(self, free_index: torch.Tensor):
        """Release full-attention slots and their paired SWA slots.

        Inside a free group the indices are buffered instead.
        """
        if free_index.numel() == 0:
            return

        if not self.is_not_in_free_group:
            self.free_group.append(free_index)
        else:
            self.full_attn_allocator.free(free_index)
            self.free_swa(free_index)

        full, swa = self.full_attn_allocator, self.swa_attn_allocator
        assert full.available_size() <= full.size
        assert swa.available_size() <= swa.size

    def free_swa(self, free_index: torch.Tensor):
        """Release only the SWA slots paired with ``free_index`` and unmap them."""
        mapped = self.full_to_swa_index_mapping[free_index]
        # Entries of 0 have no live SWA slot.
        self.swa_attn_allocator.free(mapped[mapped > 0])
        self.full_to_swa_index_mapping[free_index] = 0

    def backup_state(self):
        """Snapshot both sub-allocators as ``[full_state, swa_state]``."""
        full_state = self.full_attn_allocator.backup_state()
        swa_state = self.swa_attn_allocator.backup_state()
        return [full_state, swa_state]

    def restore_state(self, state):
        """Restore both sub-allocators from a ``backup_state`` snapshot."""
        assert len(state) == 2
        full_state, swa_state = state[0], state[1]
        self.full_attn_allocator.restore_state(full_state)
        self.swa_attn_allocator.restore_state(swa_state)

    def clear(self):
        """Reset both pools, the index mapping and the free-group state."""
        for allocator in (self.swa_attn_allocator, self.full_attn_allocator):
            allocator.clear()
        self.full_to_swa_index_mapping.fill_(0)
        self.is_not_in_free_group = True
        self.free_group = []


Lianmin Zheng's avatar
Lianmin Zheng committed
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
# Triton kernel: compute the KV token indices for an extend (prefill) batch.
#
# Launched with one program per request (grid ``(bs,)`` at the call sites in
# this file). Request ``pid`` grows from ``pre_lens[pid]`` to ``seq_lens[pid]``
# tokens. It writes its ``seq_len - pre_len`` new token indices into
# ``out_indices``, immediately after the indices of requests ``0..pid-1``.
#
# New tokens are placed in up to three parts:
#   1. the remaining slots of the request's current, partially filled page,
#      continuing at ``last_loc + 1``;
#   2. whole new pages read from ``free_page_ptr``;
#   3. a trailing, partially filled new page, also from ``free_page_ptr``.
# This request's new pages start in ``free_page_ptr`` at the total number of
# new pages needed by requests ``0..pid-1``. The kernel only reads
# ``free_page_ptr``; removing the consumed pages is left to the caller.
#
# Args:
#   pre_lens_ptr, seq_lens_ptr: per-request token counts before / after the extend.
#   last_loc_ptr: per-request KV index of the request's last existing token;
#       part 1 continues at ``last_loc + 1``.
#   free_page_ptr: free page indices, consumed in order.
#   out_indices: output buffer, one entry per new token across the batch.
#   bs_upper: ``tl.arange`` width covering the batch (the callers in this file
#       pass ``next_power_of_2(bs)``).
#   page_size: tokens per page. Also used as a ``tl.arange`` width, so
#       presumably it must be a power of two (Triton requirement) — confirm.
#   max_num_extend_tokens: ``tl.arange`` width for part 2. It must be at least
#       the number of part-2 tokens of any request (the callers pass a
#       power-of-two running maximum of ``extend_num_tokens``).
@triton.jit
def alloc_extend_kernel(
    pre_lens_ptr,
    seq_lens_ptr,
    last_loc_ptr,
    free_page_ptr,
    out_indices,
    bs_upper: tl.constexpr,
    page_size: tl.constexpr,
    max_num_extend_tokens: tl.constexpr,
):
    pid = tl.program_id(0)

    # Lanes 0..pid cover this request and all earlier ones; they feed the
    # exclusive prefix sums below.
    # NOTE(review): the masked loads pass no `other`. The sums assume the
    # masked-out lanes (requests after pid) contribute zero — confirm against
    # tl.load semantics.
    load_offset = tl.arange(0, bs_upper)
    seq_lens = tl.load(seq_lens_ptr + load_offset, mask=load_offset <= pid)
    pre_lens = tl.load(pre_lens_ptr + load_offset, mask=load_offset <= pid)
    extend_lens = seq_lens - pre_lens

    seq_len = tl.load(seq_lens_ptr + pid)
    pre_len = tl.load(pre_lens_ptr + pid)
    extend_len = seq_len - pre_len

    # Exclusive prefix sum: offset of this request's first new token in out_indices.
    sum_extend_lens = tl.sum(extend_lens)
    output_start_loc = sum_extend_lens - extend_len

    num_pages_after = (seq_lens + page_size - 1) // page_size
    num_pages_before = (pre_lens + page_size - 1) // page_size
    num_new_pages = num_pages_after - num_pages_before

    # Despite its name, this is the number of new pages *this* request needs.
    num_page_start_loc_self = (seq_len + page_size - 1) // page_size - (
        pre_len + page_size - 1
    ) // page_size
    # Exclusive prefix sum: index in free_page_ptr of this request's first new page.
    sum_num_new_pages = tl.sum(num_new_pages)
    new_page_start_loc = sum_num_new_pages - num_page_start_loc_self

    # Part 1: fill the old partial page
    # num_part1 = tokens that still fit in the current page; it is 0 when
    # pre_len is page-aligned.
    last_loc = tl.load(last_loc_ptr + pid)
    num_part1 = (
        min(seq_len, (pre_len + page_size - 1) // page_size * page_size) - pre_len
    )
    offset_one_page = tl.arange(0, page_size)
    tl.store(
        out_indices + output_start_loc + offset_one_page,
        last_loc + 1 + offset_one_page,
        mask=offset_one_page < num_part1,
    )
    if pre_len + num_part1 == seq_len:
        return

    # Part 2: fill the new full pages
    # num_part2 = tokens from the first page boundary >= pre_len up to the last
    # page boundary <= seq_len.
    num_part2 = (
        seq_len // page_size * page_size
        - (pre_len + page_size - 1) // page_size * page_size
    )

    # Per-token lookup of the page each token falls into, then page-relative slot.
    offset_many_page = tl.arange(0, max_num_extend_tokens)
    page_start = tl.load(
        free_page_ptr + new_page_start_loc + offset_many_page // page_size,
        mask=offset_many_page < num_part2,
    )
    tl.store(
        out_indices + output_start_loc + num_part1 + offset_many_page,
        page_start * page_size + offset_many_page % page_size,
        mask=offset_many_page < num_part2,
    )
    if pre_len + num_part1 + num_part2 == seq_len:
        return

    # Part 3: fill the new partial page
    # num_part3 = seq_len % page_size tokens, placed in the last of this
    # request's new pages.
    num_part3 = seq_len - seq_len // page_size * page_size
    start_loc = tl.load(
        free_page_ptr + new_page_start_loc + num_page_start_loc_self - 1
    )
    tl.store(
        out_indices + output_start_loc + num_part1 + num_part2 + offset_one_page,
        start_loc * page_size + offset_one_page,
        mask=offset_one_page < num_part3,
    )


# Triton kernel: compute one KV token index per request for a decode step.
#
# Launched with one program per request (grid ``(bs,)`` at the call site in
# this file). Request ``pid`` grows from ``seq_len - 1`` to ``seq_len`` tokens.
# If the new token still fits in the request's current page, it goes to
# ``last_loc + 1``. Otherwise the request takes the next page from
# ``free_page_ptr`` and uses that page's first slot; the page is indexed by how
# many earlier requests also need a new page. The kernel only reads
# ``free_page_ptr``; removing the consumed pages is left to the caller.
#
# Args:
#   seq_lens_ptr: per-request token counts including the new token.
#   last_loc_ptr: per-request KV index of the request's last existing token.
#   free_page_ptr: free page indices, consumed in order.
#   out_indices: output buffer with one entry per request.
#   bs_upper: ``tl.arange`` width covering the batch (the caller in this file
#       passes ``next_power_of_2(bs)``).
#   page_size: tokens per page.
@triton.jit
def alloc_decode_kernel(
    seq_lens_ptr,
    last_loc_ptr,
    free_page_ptr,
    out_indices,
    bs_upper: tl.constexpr,
    page_size: tl.constexpr,
):
    pid = tl.program_id(0)

    load_offset = tl.arange(0, bs_upper)
    seq_lens = tl.load(seq_lens_ptr + load_offset, mask=load_offset <= pid)
    # Lanes after pid get pre_lens == seq_lens, so they count zero new pages
    # whatever value the masked load produced.
    pre_lens = tl.where(load_offset <= pid, seq_lens - 1, seq_lens)

    seq_len = tl.load(seq_lens_ptr + pid)
    pre_len = seq_len - 1

    num_pages_after = (seq_lens + page_size - 1) // page_size
    num_pages_before = (pre_lens + page_size - 1) // page_size
    num_new_pages = num_pages_after - num_pages_before

    # Despite its name, this is 1 if this request crosses into a new page, else 0.
    num_page_start_loc_self = (seq_len + page_size - 1) // page_size - (
        pre_len + page_size - 1
    ) // page_size
    # Number of new pages needed by requests 0..pid-1, i.e. this request's
    # index into free_page_ptr.
    sum_num_new_pages = tl.sum(num_new_pages)
    new_page_start_loc = sum_num_new_pages - num_page_start_loc_self

    if num_page_start_loc_self == 0:
        # Same page: the next slot after the last existing token.
        last_loc = tl.load(last_loc_ptr + pid)
        tl.store(out_indices + pid, last_loc + 1)
    else:
        # New page: its first slot.
        page = tl.load(free_page_ptr + new_page_start_loc)
        tl.store(out_indices + pid, page * page_size)


412
class PagedTokenToKVPoolAllocator(BaseTokenToKVPoolAllocator):
    """
    An allocator managing the indices to kv cache data.

    This class has the same interface as `TokenToKVPoolAllocator` but the output
    of one request is always page-aligned.

    Free space is tracked as page indices: token index ``t`` belongs to page
    ``t // page_size``. Page 0 is never handed out (see ``clear``).

    TODO: fuse last_loc into the kernel.
    """

    def __init__(
        self,
        size: int,
        page_size: int,
        dtype: torch.dtype,
        device: str,
        kvcache: KVCache,
        need_sort: bool,
    ):
        super().__init__(size, page_size, dtype, device, kvcache, need_sort)
        self.num_pages = size // page_size
        self.debug_mode = get_bool_env_var("SGLANG_DEBUG_MEMORY_POOL")
        # Opt-in switch for the sgl_kernel DCU allocation kernels: used for all
        # decode allocations and for extend batches with bs < 3.
        self.sglang_kvalloc_kernel = get_bool_env_var("SGLANG_KVALLOC_KERNEL")
        # Passed as a tl.constexpr to alloc_extend_kernel. It only ever grows,
        # so the Triton kernel is re-specialized at most once per power of two.
        self.seen_max_num_extend_tokens_next_power_of_2 = 1
        self.clear()

    def alloc(self, need_size: int):
        """Allocate ``need_size`` token slots as whole pages.

        ``need_size`` is expected to be a multiple of ``page_size``; this is
        only checked in debug mode.

        Returns:
            The token indices of the taken pages, page by page, or ``None`` if
            there are not enough free pages.
        """
        # page-aligned allocation, returning contiguous indices of pages
        if self.debug_mode:
            assert (
                need_size % self.page_size == 0
            ), "The allocation size should be page-aligned"

        num_pages = need_size // self.page_size
        if self.need_sort and num_pages > len(self.free_pages):
            self.merge_and_sort_free()
        if num_pages > len(self.free_pages):
            return None

        out_pages = self.free_pages[:num_pages]
        self.free_pages = self.free_pages[num_pages:]

        # Expand each page p into tokens p*page_size .. p*page_size + page_size - 1.
        out_indices = (
            out_pages[:, None] * self.page_size
            + torch.arange(self.page_size, device=self.device)
        ).reshape(-1)

        return out_indices

    def alloc_extend(
        self,
        prefix_lens: torch.Tensor,
        prefix_lens_cpu: torch.Tensor,
        seq_lens: torch.Tensor,
        seq_lens_cpu: torch.Tensor,
        last_loc: torch.Tensor,
        extend_num_tokens: int,
    ):
        """Allocate KV slots to extend each request from its prefix to its full length.

        Args:
            prefix_lens / prefix_lens_cpu: tokens each request already has. The
                device copy feeds the kernel; the CPU copy is used for page
                accounting.
            seq_lens / seq_lens_cpu: per-request lengths after the extend.
            last_loc: per-request KV index of the last existing token. Its
                partially filled page is continued at ``last_loc + 1``.
            extend_num_tokens: total number of new tokens. It is expected to
                equal ``sum(seq_lens - prefix_lens)`` because it sizes the
                output buffer.

        Returns:
            ``extend_num_tokens`` token indices laid out request by request, or
            ``None`` if there are not enough free pages.
        """
        if self.debug_mode:
            assert torch.all(
                (last_loc + 1) % self.page_size == prefix_lens % self.page_size
            )

        self.seen_max_num_extend_tokens_next_power_of_2 = max(
            self.seen_max_num_extend_tokens_next_power_of_2,
            next_power_of_2(extend_num_tokens),
        )

        bs = len(prefix_lens)
        # Conservative estimate of the pages this call may need; merge parked
        # pages first if the ready list might be too short.
        if self.need_sort and extend_num_tokens // self.page_size + bs + 1 > len(
            self.free_pages
        ):
            self.merge_and_sort_free()

        # Check capacity on the host BEFORE launching. The Triton kernel indexes
        # free_pages by the running count of new pages, so launching when that
        # count exceeds len(free_pages) reads past the end of the tensor. In
        # debug mode the garbage output could also trip the uniqueness assert
        # instead of returning None.
        num_new_pages = get_num_new_pages(
            seq_lens=seq_lens_cpu,
            page_size=self.page_size,
            prefix_lens=prefix_lens_cpu,
        )
        if num_new_pages > len(self.free_pages):
            return None

        out_indices = torch.empty(
            (extend_num_tokens,), dtype=torch.int64, device=self.device
        )
        if self.sglang_kvalloc_kernel and bs < 3:
            dcu_alloc_extend_kernel(
                pre_lens_ptr=prefix_lens,
                seq_lens_ptr=seq_lens,
                last_loc_ptr=last_loc,
                free_page_ptr=self.free_pages,
                out_indices=out_indices,
                bs=bs,
                bs_upper=next_power_of_2(bs),
                page_size=self.page_size,
                max_num_extend_tokens=self.seen_max_num_extend_tokens_next_power_of_2,
            )
        else:
            alloc_extend_kernel[(bs,)](
                prefix_lens,
                seq_lens,
                last_loc,
                self.free_pages,
                out_indices,
                next_power_of_2(bs),
                self.page_size,
                self.seen_max_num_extend_tokens_next_power_of_2,
            )

        if self.debug_mode:
            assert len(torch.unique(out_indices)) == len(out_indices)

        # The kernel only reads free_pages; consume the pages it used here.
        self.free_pages = self.free_pages[num_new_pages:]
        return out_indices

    def alloc_decode(
        self,
        seq_lens: torch.Tensor,
        seq_lens_cpu: torch.Tensor,
        last_loc: torch.Tensor,
    ):
        """Allocate one KV slot per request for a decode step.

        Args:
            seq_lens / seq_lens_cpu: per-request lengths including the new
                token. The device copy feeds the kernel; the CPU copy is used
                for page accounting.
            last_loc: per-request KV index of the last existing token.

        Returns:
            One token index per request, or ``None`` if there are not enough
            free pages.
        """
        if self.debug_mode:
            assert torch.all(
                (last_loc + 2) % self.page_size == seq_lens % self.page_size
            )

        bs = len(seq_lens)
        # At most one new page per request.
        if self.need_sort and bs > len(self.free_pages):
            self.merge_and_sort_free()

        # Check capacity on the host BEFORE launching, so the kernel never
        # indexes free_pages past its end when the allocation would fail.
        num_new_pages = get_num_new_pages(
            seq_lens=seq_lens_cpu,
            page_size=self.page_size,
            decode=True,
        )
        if num_new_pages > len(self.free_pages):
            return None

        out_indices = torch.empty((bs,), dtype=torch.int64, device=self.device)
        if self.sglang_kvalloc_kernel:
            dcu_alloc_decode_kernel(
                seq_lens_ptr=seq_lens,
                last_loc_ptr=last_loc,
                free_page_ptr=self.free_pages,
                out_indices=out_indices,
                bs=bs,
                bs_upper=next_power_of_2(bs),
                page_size=self.page_size,
            )
        else:
            alloc_decode_kernel[(bs,)](
                seq_lens,
                last_loc,
                self.free_pages,
                out_indices,
                next_power_of_2(bs),
                self.page_size,
            )

        if self.debug_mode:
            assert len(torch.unique(out_indices)) == len(out_indices)

        # The kernel only reads free_pages; consume the pages it used here.
        self.free_pages = self.free_pages[num_new_pages:]
        return out_indices

    def free(self, free_index: torch.Tensor):
        """Return the pages containing ``free_index`` to the allocator.

        Token indices are collapsed to their deduplicated page indices, so
        whole pages are released. Freed pages are prepended to the list they
        join. Inside a free group the token indices are buffered instead.
        """
        if free_index.numel() == 0:
            return

        if self.is_not_in_free_group:
            free_page_indices = torch.unique(free_index // self.page_size)
            if self.need_sort:
                self.release_pages = torch.cat((free_page_indices, self.release_pages))
            else:
                self.free_pages = torch.cat((free_page_indices, self.free_pages))
        else:
            self.free_group.append(free_index)

        if self.debug_mode:
            assert len(torch.unique(self.free_pages)) == len(self.free_pages)

    def clear(self):
        """Mark pages ``1..num_pages`` free and reset the free-group state."""
        # The padded slot 0 is used for writing dummy outputs from padded tokens.
        self.free_pages = torch.arange(
            1, self.num_pages + 1, dtype=torch.int64, device=self.device
        )
        self.is_not_in_free_group = True
        self.free_group = []
        self.release_pages = torch.empty((0,), dtype=torch.int64, device=self.device)

    def get_cpu_copy(self, indices):
        """Delegate to the backing KV cache's ``get_cpu_copy``."""
        return self._kvcache.get_cpu_copy(indices)

    def load_cpu_copy(self, kv_cache_cpu, indices):
        """Delegate to the backing KV cache's ``load_cpu_copy``."""
        return self._kvcache.load_cpu_copy(kv_cache_cpu, indices)