# Copyright 2023-2024 SGLang Team
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================

# Integrates "S-LoRA: Serving Thousands of Concurrent LoRA Adapters"
# and "Punica: Multi-Tenant LoRA Serving"

import logging
import re

import torch

from sglang.srt.lora.lora import LoRAAdapter, get_lora_layer
from sglang.srt.lora.lora_config import LoRAConfig
from sglang.srt.model_executor.forward_batch_info import ForwardBatch
from sglang.srt.utils import is_flashinfer_available, replace_submodule

logger = logging.getLogger(__name__)

if is_flashinfer_available():
    from flashinfer import SegmentGEMMWrapper


def get_module_name(name):
    # Fallback mapping from config module names to the module names used in the model class.
    # Please check that it matches your base model; if it does not, implement
    # get_module_name() in the model class (see llama.py for a reference).
    params_mapping = {
        "q_proj": "qkv_proj",
        "k_proj": "qkv_proj",
        "v_proj": "qkv_proj",
        "gate_proj": "gate_up_proj",
        "up_proj": "gate_up_proj",
    }
    return params_mapping.get(name, name)


def get_hidden_dim(module_name, config):
    # Fallback get_hidden_dim() for the different modules.
    # Please check that it matches your base model; if it does not, implement
    # get_hidden_dim() in the model class (see llama.py for a reference).
    if module_name in ["q_proj", "o_proj", "qkv_proj"]:
        return config.hidden_size, config.hidden_size
    elif module_name in ["kv_proj"]:
        return config.hidden_size, config.hidden_size // (
            config.num_attention_heads // config.num_key_value_heads
        )
    elif module_name == "gate_up_proj":
        return config.hidden_size, config.intermediate_size
    elif module_name == "down_proj":
        return config.intermediate_size, config.hidden_size
    else:
        raise NotImplementedError()


def get_stacked_name(name):
    # original module name -> (stacked name used for lora_A, stacked name used for lora_B)
    params_mapping = {
        "q_proj": ("qkv_proj", "q_proj"),
        "k_proj": ("qkv_proj", "kv_proj"),
        "v_proj": ("qkv_proj", "kv_proj"),
        "gate_proj": ("gate_up_proj", "gate_up_proj"),
        "up_proj": ("gate_up_proj", "gate_up_proj"),
    }
    return params_mapping.get(name, (name, name))


def get_layer_id(name):
    match = re.search(r"layers\.(\d+)\.", name)
    if match is None:
        return None
    return int(match.group(1))


class LoRAManager:
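    # Manages multiple LoRA adapters on top of a single base model: loads adapter
    # weights to CPU, keeps a GPU memory pool of the adapters active in the current
    # batch, and patches the target modules so the LoRA layers can apply per-request
    # adapters via segmented GEMMs.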
    def __init__(
        self,
        base_model,
        lora_paths,
        base_hf_config,
        max_loras_per_batch,
        load_config,
        dtype,
    ):
        self.base_model = base_model
        self.lora_paths = lora_paths
        self.base_hf_config = base_hf_config
        self.max_loras_per_batch = max_loras_per_batch
        self.load_config = load_config
        self.dtype = dtype

        workspace_buffer = torch.empty(1 * 1024 * 1024, dtype=torch.int8, device="cuda")
        self.segment_gemm = SegmentGEMMWrapper(workspace_buffer)

        self.init_loras()
        self.init_lora_memory_pool()
        self.init_lora_batch()

    def match_target_modules(self, module_name):
        for target_module in self.target_modules:
            if module_name.split(".")[-1] == target_module:
                return True
        return False

    def get_target_modules(self):
        modules = []
        for module_name, module in self.base_model.named_modules():
            if self.match_target_modules(module_name):
                modules.append((module_name, module))
        return modules

    def set_lora_module(self, module_name, module):
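        # Replace the base module with its LoRA-enabled counterpart.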
        lora_module = get_lora_layer(
            module, self.segment_gemm, self.max_lora_dim, self.scaling
        )
        replace_submodule(self.base_model, module_name, lora_module)
        return lora_module

    def init_loras(self):
        # get configs and target modules
        self.configs = {}
        self.origin_target_modules = set()
        for name, path in self.lora_paths.items():
            self.configs[name] = LoRAConfig(path)
            self.origin_target_modules = set(self.origin_target_modules) | set(
                self.configs[name].target_modules
            )

        if hasattr(self.base_model, "get_module_name"):
            self.target_modules = {
                self.base_model.get_module_name(module)
                for module in self.origin_target_modules
            }
        else:
            logger.warning(
                "get_module_name() is not defined, "
                "which is used to map a module name in the config to the module name in the model class. "
                "Falling back to the default mapping; please check that it is correct for your model."
            )
            self.target_modules = {
                get_module_name(module) for module in self.origin_target_modules
            }
        self.target_weights = set(
            [get_stacked_name(module) for module in self.origin_target_modules]
        )

        # load all weights to cpu
        self.loras = []
        self.lora_id = {}
        for name in self.lora_paths.keys():
            self.lora_id[name] = len(self.loras)
            self.loras.append(
                LoRAAdapter(
                    name, self.configs[name], self.base_hf_config, self.load_config
                )
            )
            self.loras[-1].initialize_weights()

        # misc lora configs
        self.max_lora_dim = max([x.hf_config["r"] for x in self.configs.values()])
        self.scaling = self.loras[0].scaling
        # FIXME remove the restrictions
        assert all(x.hf_config["r"] == self.max_lora_dim for x in self.configs.values())
        assert all(x.scaling == self.scaling for x in self.loras)

        # monkey patch to use the LoRA version
        self.lora_modules = []
        for module_name, module in self.get_target_modules():
            self.lora_modules.append(
                (module_name, self.set_lora_module(module_name, module))
            )

    def init_lora_memory_pool(self):
        # preallocate lora memory pool
        self.A_buffer = {}
        self.B_buffer = {}
        num_layer = self.base_hf_config.num_hidden_layers
        for module_A, module_B in self.target_weights:
            # init A tensor, column_major=True
            if hasattr(self.base_model, "get_hidden_dim"):
                hidden_dim_A, _ = self.base_model.get_hidden_dim(module_A)
            else:
                logger.warning(
                    "get_hidden_dim() is not defined, "
                    "which is used to get the hidden dims of the different lora modules. "
                    "Falling back to the default implementation; please check that it is correct for your model."
                )
                hidden_dim_A, _ = get_hidden_dim(module_A, self.base_hf_config)
            c = self.loras[-1].get_stacked_multiply(module_A)
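            # One tensor per layer, shape (max_loras_per_batch, max_lora_dim * c, hidden_dim_A).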
            if module_A not in self.A_buffer:
                self.A_buffer[module_A] = [
                    torch.empty(
                        (
                            self.max_loras_per_batch,
                            self.max_lora_dim * c,
                            hidden_dim_A,
                        ),
                        dtype=self.dtype,
                        device="cuda",
                    )
                    for i in range(num_layer)
                ]
            # init B tensor, column_major=True
            if hasattr(self.base_model, "get_hidden_dim"):
                _, hidden_dim_B = self.base_model.get_hidden_dim(module_B)
            else:
                logger.warning(
                    "get_hidden_dim() is not defined, "
                    "which is used to get the hidden dims of the different lora modules. "
                    "Falling back to the default implementation; please check that it is correct for your model."
                )
                _, hidden_dim_B = get_hidden_dim(module_B, self.base_hf_config)
            c = self.loras[-1].get_stacked_multiply(module_B)
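            # One tensor per layer, shape (max_loras_per_batch, hidden_dim_B * c, max_lora_dim).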
            if module_B not in self.B_buffer:
                self.B_buffer[module_B] = [
                    torch.empty(
                        (
                            self.max_loras_per_batch,
                            hidden_dim_B * c,
                            self.max_lora_dim,
                        ),
                        dtype=self.dtype,
                        device="cuda",
                    )
                    for i in range(num_layer)
                ]

    def init_lora_batch(self):
        self.active_uids = set()  # set of active loras
        self.buffer_id = {}  # lora uid -> idx in memory pool

    def get_weight_name(self, name, idx):
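        # idx 0 selects the stacked name used for the A buffers, idx 1 the one used for
        # the B buffers; returns None if `name` matches no target weight.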
        for target_weight_name in self.target_weights:
            if target_weight_name[idx] in name:
                return target_weight_name[idx]

    def load_lora(self, uid, buffer_id):
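        # Copy the weights of adapter `uid` into slot `buffer_id` of the memory pool.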
        num_layer = self.base_hf_config.num_hidden_layers
        if uid is None:
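            # No adapter: zero the A slots so the LoRA delta for this request is zero.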
            for i in range(num_layer):
                for k in self.A_buffer.keys():
                    self.A_buffer[k][i][buffer_id] *= 0
            return

        for i in range(num_layer):
            layer_weights = self.loras[self.lora_id[uid]].layers[i].weights
            for name, weights in layer_weights.items():
                if "lora_A" in name:
                    lora_weight_name = self.get_weight_name(name, 0)
                    if lora_weight_name:
                        self.A_buffer[lora_weight_name][i][buffer_id].copy_(weights)
                else:
                    lora_weight_name = self.get_weight_name(name, 1)
                    if lora_weight_name:
                        self.B_buffer[lora_weight_name][i][buffer_id].copy_(weights)

    def prepare_lora_batch(self, forward_batch: ForwardBatch):
        # load active loras into lora memory pool
        cur_uids = set(forward_batch.lora_paths)
        assert len(cur_uids) <= self.max_loras_per_batch
        i = 0  # next eviction candidate in evictable_uids
        j = len(self.active_uids)  # next free slot in the memory pool
        evictable_uids = list(self.active_uids)
        for uid in cur_uids:
            if uid not in self.active_uids:
                if j < self.max_loras_per_batch:
                    # there is still a free slot in the pool
                    index = j
                    j += 1
                else:
                    # evict an active adapter that is not needed by the current batch
                    while i < len(evictable_uids) and evictable_uids[i] in cur_uids:
                        i += 1
                    assert i < len(evictable_uids)
                    self.active_uids.remove(evictable_uids[i])
                    self.buffer_id.pop(evictable_uids[i])
                    index = i
                    i += 1
                self.load_lora(uid, index)
                self.active_uids.add(uid)
                self.buffer_id[uid] = index

        # No request in this batch uses LoRA; skip setting up the forward metadata.
        if cur_uids == set([None]):
            return

        # setup lora in forward modules
        bs = forward_batch.batch_size
        seg_lens = (
            forward_batch.extend_seq_lens
            if forward_batch.forward_mode.is_extend()
            else torch.ones(bs, device="cuda")
        )
        # FIXME: reuse the data rather than recompute
        seg_indptr = torch.zeros((bs + 1,), dtype=torch.int32, device="cuda")
        seg_indptr[1:] = torch.cumsum(seg_lens, dim=0)
        weight_indices = torch.empty((bs,), dtype=torch.int64, device="cuda")
        for i, lora_path in enumerate(forward_batch.lora_paths):
            weight_indices[i] = self.buffer_id[lora_path]

        for module_name, module in self.lora_modules:
            layer_id = get_layer_id(module_name)

            if "qkv_proj" not in module_name:
                weight_name = self.get_weight_name(module_name, 0)
                module.set_lora_info(
                    self.A_buffer[weight_name][layer_id],
                    self.B_buffer[weight_name][layer_id],
                    bs,
                    seg_indptr,
                    weight_indices,
                )
            else:
                module.set_lora_info(
                    self.A_buffer["qkv_proj"][layer_id],
                    self.B_buffer["q_proj"][layer_id],
                    self.B_buffer["kv_proj"][layer_id],
                    bs,
                    seg_indptr,
                    weight_indices,
                )