"""
Copyright 2023-2024 SGLang Team
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
"""

"""Memory pool."""

import logging
from typing import List, Optional, Tuple, Union

import torch

logger = logging.getLogger(__name__)


class ReqToTokenPool:
    """A memory pool that maps a request to its token locations."""

    def __init__(self, size: int, max_context_len: int, device: str):
        self.size = size
        self.max_context_len = max_context_len
        self.device = device
        self.req_to_token = torch.empty(
            (size, max_context_len), dtype=torch.int32, device=device
        )
        self.free_slots = list(range(size))

    def available_size(self):
        return len(self.free_slots)

    def alloc(self, need_size: int) -> Optional[List[int]]:
        if need_size > len(self.free_slots):
            return None

        select_index = self.free_slots[:need_size]
        self.free_slots = self.free_slots[need_size:]

        return select_index

    def free(self, free_index: Union[int, List[int]]):
        if isinstance(free_index, (int,)):
            self.free_slots.append(free_index)
        else:
            self.free_slots.extend(free_index)

    def clear(self):
        self.free_slots = list(range(self.size))
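
# Illustrative usage of ReqToTokenPool (a sketch only, not executed anywhere in
# this module; the sizes, the CPU device, and `prefill_locs` are assumptions):
#
#   pool = ReqToTokenPool(size=4, max_context_len=2048, device="cpu")
#   slots = pool.alloc(1)                     # e.g. [0]; None if the pool is full
#   prefill_locs = torch.arange(8, dtype=torch.int32)
#   pool.req_to_token[slots[0], :8] = prefill_locs   # token -> KV cache locations
#   pool.free(slots)                          # the request slot becomes reusable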


class BaseTokenToKVPool:
    """A memory pool that maps a token to its kv cache locations"""

    def __init__(
        self,
        size: int,
        dtype: torch.dtype,
        device: str,
    ):
        self.size = size
        self.dtype = dtype
        self.device = device
        if dtype == torch.float8_e5m2:
            # NOTE: Store as torch.uint8 because Tensor index_put is not implemented for torch.float8_e5m2
            self.store_dtype = torch.uint8
        else:
            self.store_dtype = dtype

        self.free_slots = None
        self.is_not_in_free_group = True
        self.free_group = []
        self.clear()

    def available_size(self):
        return len(self.free_slots)

    def alloc(self, need_size: int):
        if need_size > len(self.free_slots):
            return None

        select_index = self.free_slots[:need_size]
        self.free_slots = self.free_slots[need_size:]

        return select_index.to(self.device, non_blocking=True)

    def free(self, free_index: torch.Tensor):
        if self.is_not_in_free_group:
            self.free_slots = torch.concat((self.free_slots, free_index.cpu()))
        else:
            self.free_group.append(free_index)

    def free_group_begin(self):
        # Start buffering free() calls; they are applied together in free_group_end().
        self.is_not_in_free_group = False
        self.free_group = []

    def free_group_end(self):
        # Flush all buffered free() calls with a single concatenation.
        self.is_not_in_free_group = True
        if self.free_group:
            self.free(torch.concat(self.free_group))

    def clear(self):
        # The padded slot 0 is used for writing dummy outputs from padded tokens.
        self.free_slots = torch.arange(1, self.size + 1, dtype=torch.int32)
        self.is_not_in_free_group = True
        self.free_group = []

    def get_key_buffer(self, layer_id: int) -> torch.Tensor:
        raise NotImplementedError()

    def get_value_buffer(self, layer_id: int) -> torch.Tensor:
        raise NotImplementedError()

    def get_kv_buffer(self, layer_id: int) -> Tuple[torch.Tensor, torch.Tensor]:
        raise NotImplementedError()

    def set_kv_buffer(
        self,
        layer_id: int,
        loc: torch.Tensor,
        cache_k: torch.Tensor,
        cache_v: torch.Tensor,
    ) -> None:
        raise NotImplementedError()
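
# The free_group_* methods let a caller batch many small free() calls into one
# tensor concatenation. A sketch of the intended pattern (`pool`, `finished_reqs`,
# and `kv_indices` are illustrative names, not part of this module):
#
#   pool.free_group_begin()
#   for req in finished_reqs:
#       pool.free(req.kv_indices)   # buffered in pool.free_group, not applied yet
#   pool.free_group_end()           # concatenated and freed in a single call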


class MHATokenToKVPool(BaseTokenToKVPool):

    def __init__(
        self,
        size: int,
        dtype: torch.dtype,
        head_num: int,
        head_dim: int,
        layer_num: int,
        device: str,
    ):
        super().__init__(size, dtype, device)

        # [size, head_num, head_dim] for each layer
        # The padded slot 0 is used for writing dummy outputs from padded tokens.
        self.k_buffer = [
            torch.empty(
                (size + 1, head_num, head_dim),
                dtype=self.store_dtype,
                device=device,
            )
            for _ in range(layer_num)
        ]
        self.v_buffer = [
            torch.empty(
                (size + 1, head_num, head_dim),
                dtype=self.store_dtype,
                device=device,
            )
            for _ in range(layer_num)
        ]

    def get_key_buffer(self, layer_id: int):
        if self.store_dtype != self.dtype:
            return self.k_buffer[layer_id].view(self.dtype)
        return self.k_buffer[layer_id]

    def get_value_buffer(self, layer_id: int):
        if self.store_dtype != self.dtype:
            return self.v_buffer[layer_id].view(self.dtype)
        return self.v_buffer[layer_id]

    def get_kv_buffer(self, layer_id: int):
        return self.get_key_buffer(layer_id), self.get_value_buffer(layer_id)

    def set_kv_buffer(
        self,
        layer_id: int,
        loc: torch.Tensor,
        cache_k: torch.Tensor,
        cache_v: torch.Tensor,
    ):
        if cache_k.dtype != self.dtype:
            cache_k = cache_k.to(self.dtype)
        if cache_v.dtype != self.dtype:
            cache_v = cache_v.to(self.dtype)
        if self.store_dtype != self.dtype:
            self.k_buffer[layer_id][loc] = cache_k.view(self.store_dtype)
            self.v_buffer[layer_id][loc] = cache_v.view(self.store_dtype)
        else:
            self.k_buffer[layer_id][loc] = cache_k
            self.v_buffer[layer_id][loc] = cache_v
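
# Illustrative per-layer write path (a sketch; the pool configuration and the
# `k`, `v` tensors of shape [16, head_num, head_dim] are assumptions):
#
#   pool = MHATokenToKVPool(size=1024, dtype=torch.float16, head_num=8,
#                           head_dim=128, layer_num=32, device="cuda")
#   loc = pool.alloc(16)                            # KV slots for 16 new tokens
#   pool.set_kv_buffer(layer_id=0, loc=loc, cache_k=k, cache_v=v)
#   k0, v0 = pool.get_kv_buffer(0)                  # full per-layer K/V buffers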


class MLATokenToKVPool(BaseTokenToKVPool):

    def __init__(
        self,
        size: int,
        dtype: torch.dtype,
        kv_lora_rank: int,
        qk_rope_head_dim: int,
        layer_num: int,
        device: str,
    ):
        super().__init__(size, dtype, device)

        self.kv_lora_rank = kv_lora_rank
        # The padded slot 0 is used for writing dummy outputs from padded tokens.
        self.kv_buffer = [
            torch.empty(
                (size + 1, 1, kv_lora_rank + qk_rope_head_dim),
                dtype=self.store_dtype,
                device=device,
            )
            for _ in range(layer_num)
        ]

    def get_key_buffer(self, layer_id: int):
        if self.store_dtype != self.dtype:
            return self.kv_buffer[layer_id].view(self.dtype)
        return self.kv_buffer[layer_id]

    def get_value_buffer(self, layer_id: int):
        if self.store_dtype != self.dtype:
            return self.kv_buffer[layer_id][..., : self.kv_lora_rank].view(self.dtype)
        return self.kv_buffer[layer_id][..., : self.kv_lora_rank]

    def get_kv_buffer(self, layer_id: int):
        return self.get_key_buffer(layer_id), self.get_value_buffer(layer_id)

    def set_kv_buffer(
        self,
        layer_id: int,
        loc: torch.Tensor,
        cache_k: torch.Tensor,
        cache_v: torch.Tensor,
    ):
        if cache_k.dtype != self.dtype:
            cache_k = cache_k.to(self.dtype)
        if self.store_dtype != self.dtype:
            self.kv_buffer[layer_id][loc] = cache_k.view(self.store_dtype)
        else:
            self.kv_buffer[layer_id][loc] = cache_k
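
# Note on the MLA layout: each token stores one packed vector whose first
# `kv_lora_rank` channels are the compressed KV state and whose last
# `qk_rope_head_dim` channels are the rotary key part. set_kv_buffer therefore
# writes only `cache_k` (already packed by the caller) and ignores `cache_v`;
# get_value_buffer slices the compressed part back out. Sketch (names are
# illustrative assumptions):
#
#   packed_k = torch.empty(num_new_tokens, 1, kv_lora_rank + qk_rope_head_dim)
#   pool.set_kv_buffer(layer_id, loc, packed_k, packed_k)   # cache_v is unused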


class DoubleSparseTokenToKVPool(BaseTokenToKVPool):

    def __init__(
        self,
        size: int,
        dtype: torch.dtype,
        head_num: int,
        head_dim: int,
        layer_num: int,
        device: str,
        heavy_channel_num: int,
    ):
        super().__init__(size, dtype, device)

        # [size, head_num, head_dim] for each layer
        self.k_buffer = [
            torch.empty((size + 1, head_num, head_dim), dtype=dtype, device=device)
            for _ in range(layer_num)
        ]
        self.v_buffer = [
            torch.empty((size + 1, head_num, head_dim), dtype=dtype, device=device)
            for _ in range(layer_num)
        ]

        # [size, head_num, heavy_channel_num] for each layer
        self.label_buffer = [
            torch.empty(
                (size + 1, head_num, heavy_channel_num), dtype=dtype, device=device
            )
            for _ in range(layer_num)
        ]

    def get_key_buffer(self, layer_id: int):
        return self.k_buffer[layer_id]

    def get_value_buffer(self, layer_id: int):
        return self.v_buffer[layer_id]

    def get_label_buffer(self, layer_id: int):
        return self.label_buffer[layer_id]

    def get_kv_buffer(self, layer_id: int):
        return self.k_buffer[layer_id], self.v_buffer[layer_id]

    def set_kv_buffer(
        self,
        layer_id: int,
        loc: torch.Tensor,
        cache_k: torch.Tensor,
        cache_v: torch.Tensor,
        cache_label: torch.Tensor,
    ):
        # NOTE(Andy): ignore the dtype check
        self.k_buffer[layer_id][loc] = cache_k
        self.v_buffer[layer_id][loc] = cache_v
        self.label_buffer[layer_id][loc] = cache_label
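

# A minimal smoke test (an illustrative sketch; the sizes, dtype, and CPU device
# below are arbitrary assumptions and nothing else in this module relies on it):
if __name__ == "__main__":
    req_pool = ReqToTokenPool(size=4, max_context_len=16, device="cpu")
    kv_pool = MHATokenToKVPool(
        size=32, dtype=torch.float32, head_num=2, head_dim=8, layer_num=1, device="cpu"
    )

    req_slots = req_pool.alloc(1)      # one request slot, e.g. [0]
    token_locs = kv_pool.alloc(4)      # KV cache slots for 4 tokens
    req_pool.req_to_token[req_slots[0], : len(token_locs)] = token_locs

    cache_k = torch.randn(4, 2, 8)
    cache_v = torch.randn(4, 2, 8)
    kv_pool.set_kv_buffer(0, token_locs, cache_k, cache_v)

    kv_pool.free(token_locs)
    req_pool.free(req_slots)
    print(f"KV slots available after free: {kv_pool.available_size()}")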