"""
Copyright 2023-2024 SGLang Team
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
"""

"""Memory pool."""

import logging
from abc import ABC, abstractmethod
from typing import List, Tuple, Union

import torch

logger = logging.getLogger(__name__)


class ReqToTokenPool:
    """A memory pool that maps a request to its token locations."""

    def __init__(self, size: int, max_context_len: int):
        self.size = size
        # Free-list of request slot indices; popped from the front on alloc.
        self.free_slots = list(range(size))
        # Per-request row of token locations, one row per slot.
        self.req_to_token = torch.empty(
            (size, max_context_len), dtype=torch.int32, device="cuda"
        )

    def alloc(self, need_size: int) -> List[int]:
        """Take `need_size` request slots from the pool.

        Returns the list of slot indices, or None when the pool cannot
        satisfy the request.
        """
        if need_size > len(self.free_slots):
            return None

        allocated, self.free_slots = (
            self.free_slots[:need_size],
            self.free_slots[need_size:],
        )
        return allocated

    def free(self, free_index: Union[int, List[int]]):
        """Return a single slot index (int) or several (list) to the pool."""
        if isinstance(free_index, int):
            self.free_slots.append(free_index)
        else:
            self.free_slots.extend(free_index)

    def clear(self):
        """Reset the pool so every slot is free again."""
        self.free_slots = list(range(self.size))
class BaseTokenToKVPool(ABC):
    """A memory pool that maps a token to its kv cache locations"""

    def __init__(
        self,
        size: int,
        dtype: torch.dtype,
    ):
        self.size = size

        self.dtype = dtype
        # NOTE: Store as torch.uint8 because Tensor index_put is not implemented for torch.float8_e5m2
        self.store_dtype = torch.uint8 if dtype == torch.float8_e5m2 else dtype

        # True = slot is free. One extra slot (index 0, reserved in clear())
        # is used for writing dummy output from padded tokens.
        self.mem_state = torch.ones((self.size + 1,), dtype=torch.bool, device="cuda")

        # Slots grabbed ahead of time so most alloc() calls avoid a
        # full nonzero() scan over mem_state.
        self.prefetch_buffer = torch.empty(0, device="cuda", dtype=torch.int32)
        self.prefetch_chunk_size = 512

        self.can_use_mem_size = self.size
        self.clear()

    def available_size(self):
        """Number of token slots currently obtainable via alloc()."""
        return self.can_use_mem_size + len(self.prefetch_buffer)

    def alloc(self, need_size: int):
        """Take `need_size` token slots; returns an int32 index tensor or None.

        Serves from the prefetch buffer when possible; otherwise scans
        mem_state for at least a chunk of free slots and refills the buffer.
        """
        buffered = len(self.prefetch_buffer)
        if need_size <= buffered:
            # Fast path: the prefetch buffer already holds enough slots.
            out, self.prefetch_buffer = (
                self.prefetch_buffer[:need_size],
                self.prefetch_buffer[need_size:],
            )
            return out

        # Slow path: grab at least a full chunk to amortize the scan cost.
        shortfall = need_size - buffered
        grab = max(shortfall, self.prefetch_chunk_size)

        candidates = torch.nonzero(self.mem_state).squeeze(1)[:grab].to(torch.int32)
        if candidates.shape[0] < shortfall:
            # Out of memory: not enough free slots to cover the request.
            return None

        self.mem_state[candidates] = False
        self.can_use_mem_size -= len(candidates)

        self.prefetch_buffer = torch.cat((self.prefetch_buffer, candidates))
        out, self.prefetch_buffer = (
            self.prefetch_buffer[:need_size],
            self.prefetch_buffer[need_size:],
        )
        return out

    def free(self, free_index: torch.Tensor):
        """Return the slots in `free_index` (an index tensor) to the pool."""
        self.mem_state[free_index] = True
        self.can_use_mem_size += len(free_index)

    def clear(self):
        """Reset all allocation state; every slot except slot 0 becomes free."""
        self.prefetch_buffer = torch.empty(0, device="cuda", dtype=torch.int32)

        self.mem_state.fill_(True)
        self.can_use_mem_size = self.size

        # We also add one slot. This slot is used for writing dummy output from padded tokens.
        self.mem_state[0] = False

    @abstractmethod
    def get_key_buffer(self, layer_id: int) -> torch.Tensor:
        raise NotImplementedError()

    @abstractmethod
    def get_value_buffer(self, layer_id: int) -> torch.Tensor:
        raise NotImplementedError()

    @abstractmethod
    def get_kv_buffer(self, layer_id: int) -> Tuple[torch.Tensor, torch.Tensor]:
        raise NotImplementedError()

    @abstractmethod
    def set_kv_buffer(
        self,
        layer_id: int,
        loc: torch.Tensor,
        cache_k: torch.Tensor,
        cache_v: torch.Tensor,
    ) -> None:
        raise NotImplementedError()
class MHATokenToKVPool(BaseTokenToKVPool):
    """KV pool for standard multi-head attention: separate K and V buffers."""

    def __init__(
        self,
        size: int,
        dtype: torch.dtype,
        head_num: int,
        head_dim: int,
        layer_num: int,
    ):
        super().__init__(size, dtype)

        # One [size + 1, head_num, head_dim] tensor per layer; the extra
        # slot matches the dummy slot reserved by the base pool.
        def _layer_buffer():
            return torch.empty(
                (size + 1, head_num, head_dim), dtype=self.store_dtype, device="cuda"
            )

        self.k_buffer = [_layer_buffer() for _ in range(layer_num)]
        self.v_buffer = [_layer_buffer() for _ in range(layer_num)]

    def get_key_buffer(self, layer_id: int):
        """Key buffer for `layer_id`, reinterpreted back to the logical dtype."""
        buf = self.k_buffer[layer_id]
        return buf.view(self.dtype) if self.store_dtype != self.dtype else buf

    def get_value_buffer(self, layer_id: int):
        """Value buffer for `layer_id`, reinterpreted back to the logical dtype."""
        buf = self.v_buffer[layer_id]
        return buf.view(self.dtype) if self.store_dtype != self.dtype else buf

    def get_kv_buffer(self, layer_id: int):
        """Both key and value buffers for `layer_id` as a tuple."""
        return self.get_key_buffer(layer_id), self.get_value_buffer(layer_id)

    def set_kv_buffer(
        self,
        layer_id: int,
        loc: torch.Tensor,
        cache_k: torch.Tensor,
        cache_v: torch.Tensor,
    ):
        """Write K/V entries for the token locations `loc` in `layer_id`."""
        if cache_k.dtype != self.dtype:
            cache_k = cache_k.to(self.dtype)
        if cache_v.dtype != self.dtype:
            cache_v = cache_v.to(self.dtype)
        if self.store_dtype != self.dtype:
            # Reinterpret bits so index_put works on the storage dtype.
            cache_k = cache_k.view(self.store_dtype)
            cache_v = cache_v.view(self.store_dtype)
        self.k_buffer[layer_id][loc] = cache_k
        self.v_buffer[layer_id][loc] = cache_v
class MLATokenToKVPool(BaseTokenToKVPool):
    """KV pool for multi-head latent attention: a single fused kv buffer.

    Each entry packs the latent (lora-rank) part and the rope part together;
    the value view is the leading `kv_lora_rank` slice of the key buffer.
    """

    def __init__(
        self,
        size: int,
        dtype: torch.dtype,
        kv_lora_rank: int,
        qk_rope_head_dim: int,
        layer_num: int,
    ):
        super().__init__(size, dtype)

        self.kv_lora_rank = kv_lora_rank
        # One [size + 1, 1, kv_lora_rank + qk_rope_head_dim] tensor per layer;
        # the extra slot matches the dummy slot reserved by the base pool.
        self.kv_buffer = [
            torch.empty(
                (size + 1, 1, kv_lora_rank + qk_rope_head_dim),
                dtype=self.store_dtype,
                device="cuda",
            )
            for _ in range(layer_num)
        ]

    def get_key_buffer(self, layer_id: int):
        """Full fused buffer for `layer_id` in the logical dtype."""
        buf = self.kv_buffer[layer_id]
        return buf.view(self.dtype) if self.store_dtype != self.dtype else buf

    def get_value_buffer(self, layer_id: int):
        """Latent-value slice (first `kv_lora_rank` dims) for `layer_id`."""
        values = self.kv_buffer[layer_id][..., : self.kv_lora_rank]
        return values.view(self.dtype) if self.store_dtype != self.dtype else values

    def get_kv_buffer(self, layer_id: int):
        """Key and value views for `layer_id` as a tuple."""
        return self.get_key_buffer(layer_id), self.get_value_buffer(layer_id)

    def set_kv_buffer(
        self,
        layer_id: int,
        loc: torch.Tensor,
        cache_k: torch.Tensor,
        cache_v: torch.Tensor,
    ):
        """Write fused kv entries at `loc`; `cache_v` is unused (fused in k)."""
        if cache_k.dtype != self.dtype:
            cache_k = cache_k.to(self.dtype)
        if self.store_dtype != self.dtype:
            # Reinterpret bits so index_put works on the storage dtype.
            cache_k = cache_k.view(self.store_dtype)
        self.kv_buffer[layer_id][loc] = cache_k