"requirements.txt" did not exist on "0024a5c66f90c7d3d02f7ef08a773aace6deb155"
lora_manager.py 19.7 KB
Newer Older
1
2
3
4
5
6
7
8
9
10
11
12
13
# Copyright 2023-2024 SGLang Team
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================

# Integrates "S-LoRA: Serving Thousands of Concurrent LoRA Adapters"
# and "Punica: Multi-Tenant LoRA Serving"

import logging
from typing import Dict, Iterable, List, Optional, Set, Tuple

import torch

from sglang.srt.configs.load_config import LoadConfig
from sglang.srt.hf_transformers_utils import AutoConfig
from sglang.srt.lora.backend.base_backend import BaseLoRABackend, get_backend_from_name
from sglang.srt.lora.layers import BaseLayerWithLoRA, get_lora_layer
from sglang.srt.lora.lora import LoRAAdapter
from sglang.srt.lora.lora_config import LoRAConfig
from sglang.srt.lora.lora_registry import LoRARef
from sglang.srt.lora.mem_pool import LoRAMemoryPool
from sglang.srt.lora.utils import (
    LoRABatchInfo,
    LoRAType,
    get_customized_names_from_hf_names,
    get_layer_id,
    get_normalized_lora_weight_names,
    get_weight_name,
)
from sglang.srt.managers.io_struct import LoRAUpdateResult
from sglang.srt.model_executor.forward_batch_info import ForwardBatch
from sglang.srt.utils import replace_submodule

logger = logging.getLogger(__name__)

class LoRAManager:
    def __init__(
        self,
        base_model: torch.nn.Module,
        base_hf_config: AutoConfig,
        max_loras_per_batch: int,
        load_config: LoadConfig,
        dtype: torch.dtype,
        lora_backend: str = "triton",
        tp_size: int = 1,
        tp_rank: int = 0,
        max_lora_rank: Optional[int] = None,
        target_modules: Optional[Iterable[str]] = None,
        lora_paths: Optional[Dict[str, LoRARef]] = None,
    ):
        self.base_model: torch.nn.Module = base_model
        self.base_hf_config: AutoConfig = base_hf_config
        self.max_loras_per_batch: int = max_loras_per_batch
        self.load_config: LoadConfig = load_config
        self.dtype: torch.dtype = dtype
        self.device: torch.device = next(self.base_model.parameters()).device
        self.tp_size: int = tp_size
        self.tp_rank: int = tp_rank

        # LoRA backend for running sgemm kernels
        logger.info(f"Using {lora_backend} as backend of LoRA kernels.")
        backend_type = get_backend_from_name(lora_backend)
        self.lora_backend: BaseLoRABackend = backend_type(lora_backend)

        # Initialize mutable internal state of the LoRAManager.
        self.init_state(
            max_lora_rank=max_lora_rank,
            target_modules=target_modules,
            lora_paths=lora_paths,
        )

    def init_cuda_graph_batch_info(self, max_bs_in_cuda_graph: int):
        self.max_bs_in_cuda_graph = max_bs_in_cuda_graph
        with torch.device("cuda"):
            self.cuda_graph_batch_info = LoRABatchInfo(
                bs=self.max_bs_in_cuda_graph,
                seg_lens=torch.zeros(self.max_bs_in_cuda_graph, dtype=torch.int32),
                seg_indptr=torch.zeros(
                    self.max_bs_in_cuda_graph + 1, dtype=torch.int32
                ),
                max_len=1,
                weight_indices=torch.zeros(
                    self.max_bs_in_cuda_graph, dtype=torch.int32
                ),
                lora_ranks=torch.zeros(self.max_loras_per_batch, dtype=torch.int32),
                scalings=torch.zeros(self.max_loras_per_batch, dtype=torch.float),
            )

            # Initialize seg_lens and seg_indptr for CUDA graph as they remain constant
            # across batches.
            self.cuda_graph_batch_info.seg_lens[: self.max_bs_in_cuda_graph].fill_(1)
            torch.cumsum(
                self.cuda_graph_batch_info.seg_lens[: self.max_bs_in_cuda_graph],
                dim=0,
                out=self.cuda_graph_batch_info.seg_indptr[
                    1 : self.max_bs_in_cuda_graph + 1
                ],
            )
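            # With every segment length fixed at 1 (decode-style batches), the cumsum
            # above leaves seg_indptr[i] == i; e.g., for max_bs_in_cuda_graph == 4:
            #   seg_lens   = [1, 1, 1, 1]
            #   seg_indptr = [0, 1, 2, 3, 4]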

    def create_lora_update_result(
        self, success: bool, error_message: str = ""
    ) -> LoRAUpdateResult:
        return LoRAUpdateResult(
            success=success,
            error_message=error_message,
            loaded_adapters={
                lora_ref.lora_name: lora_ref.lora_path
                for lora_ref in self.lora_refs.values()
            },
        )

    def load_lora_adapter(self, lora_ref: LoRARef) -> LoRAUpdateResult:
        """
        Load a single LoRA adapter from the specified path.

        Args:
            lora_ref (LoRARef): The LoRARef object containing the LoRA name, path, and ID.
        """
        assert (
            lora_ref.lora_name is not None and lora_ref.lora_path is not None
        ), "LoRARef must have both lora_name and lora_path set for loading."
        assert (
            lora_ref.lora_id not in self.loras
        ), f"LoRA adapter with ID {lora_ref.lora_id} is already loaded. This should have been verified before request is sent to the backend."

        try:
            # load configs
            new_adapter = LoRAConfig(lora_ref.lora_path)
            self.validate_new_adapter(new_adapter, lora_ref)
            self.configs[lora_ref.lora_id] = new_adapter

            # load weights
            self.load_lora_weights(lora_ref)

            # keep metadata for displayed messages
            self.lora_refs[lora_ref.lora_id] = lora_ref
        except Exception as e:
            return self.create_lora_update_result(
                success=False,
                error_message=str(e),
            )

        return self.create_lora_update_result(success=True)

    def validate_new_adapter(self, lora_config: LoRAConfig, lora_ref: LoRARef):
        """
        Validate whether the adapter can be loaded into the current LoRA memory pool and raise an error if it is incompatible.
        """

        memory_pool = getattr(self, "memory_pool", None)
        incompatible = memory_pool and not memory_pool.can_support(lora_config)
        if incompatible:
            raise ValueError(
                f"LoRA adapter {lora_ref.lora_name} with rank {lora_config.r} is incompatible with the current LoRA memory pool configuration. "
                "Please ensure that the LoRA adapter's rank is within the configured `--max_lora_rank` and that the target modules are "
                "included in `--enable_lora_modules`."
            )

    def unload_lora_adapter(self, lora_ref: LoRARef) -> LoRAUpdateResult:
        """
        Unload LoRA adapters by their names. This will remove the adapters from the memory pool and
        delete the corresponding LoRA modules.
        """

        adapter = self.configs.get(lora_ref.lora_id, None)
        assert (
            adapter is not None
        ), f"LoRA adapter with ID {lora_ref.lora_id} is not loaded. This should have been verified before request is sent to the backend."

        try:
            del self.configs[lora_ref.lora_id]
            del self.loras[lora_ref.lora_id]
            del self.lora_refs[lora_ref.lora_id]
        except Exception as e:
            return self.create_lora_update_result(
                success=False,
                error_message=str(e),
            )

        return self.create_lora_update_result(success=True)

    def prepare_lora_batch(self, forward_batch: ForwardBatch):
        # Load active loras into lora memory pool
        # TODO (lifuhuang): The naming of `forward_batch.lora_paths` is confusing. It actually contains a set of unique
        # LoRA IDs, not LoRA paths. While unfortunately we cannot change the name in API for backward compatibility, we
        # should consider (1) renaming the incorrect usage within the system, and (2) deprecating the parameter name in
        # the current API schema and introducing a better request schema in the future (e.g., use `model_name`).
        cur_uids = set(forward_batch.lora_paths)
        assert len(cur_uids) <= self.max_loras_per_batch
        self.memory_pool.prepare_lora_batch(cur_uids, self.loras, self.lora_modules)

        # set up batch info shared by all lora modules
        bs = forward_batch.batch_size

        def transfer_adapter_info(
            weight_indices_out: torch.Tensor,
            lora_ranks_out: torch.Tensor,
            scalings_out: torch.Tensor,
        ):
            """
            Transfer adapter metadata (weight indices, LoRA rank, scalings) from host
            to device (CUDA) asynchronously.
            """
            weight_indices = [0] * len(forward_batch.lora_paths)
            lora_ranks = [0] * self.max_loras_per_batch
            scalings = [0] * self.max_loras_per_batch
            for i, uid in enumerate(forward_batch.lora_paths):
                weight_indices[i] = self.memory_pool.get_buffer_id(uid)
                if uid is not None:
                    lora = self.loras[uid]
                    lora_ranks[weight_indices[i]] = lora.config.r
                    scalings[weight_indices[i]] = lora.scaling

            # Use pinned memory to avoid synchronizations during host-to-device transfer
            weight_indices_tensor = torch.tensor(
                weight_indices, dtype=torch.int32, pin_memory=True, device="cpu"
            )
            lora_ranks_tensor = torch.tensor(
                lora_ranks, dtype=torch.int32, pin_memory=True, device="cpu"
            )
            scalings_tensor = torch.tensor(
                scalings, dtype=torch.float, pin_memory=True, device="cpu"
            )

            # Copy to device tensors asynchronously
            weight_indices_out[:bs].copy_(weight_indices_tensor, non_blocking=True)
            lora_ranks_out[: self.max_loras_per_batch].copy_(
                lora_ranks_tensor, non_blocking=True
            )
            scalings_out[: self.max_loras_per_batch].copy_(
                scalings_tensor, non_blocking=True
            )
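            # Note: these copies are only truly asynchronous because the source tensors
            # above live in pinned (page-locked) host memory; with pageable memory,
            # non_blocking=True would fall back to a synchronizing transfer.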

        if (
            hasattr(self, "max_bs_in_cuda_graph")
            and bs <= self.max_bs_in_cuda_graph
            and forward_batch.forward_mode.is_cuda_graph()
        ):
            # Do in-place updates when CUDA graph is enabled and the batch forward mode
            # could use CUDA graph.
            transfer_adapter_info(
                self.cuda_graph_batch_info.weight_indices,
                self.cuda_graph_batch_info.lora_ranks,
                self.cuda_graph_batch_info.scalings,
            )

            self.cuda_graph_batch_info.bs = bs
            self.cuda_graph_batch_info.max_len = 1
            batch_info = self.cuda_graph_batch_info
        else:
            weight_indices = torch.empty((bs,), dtype=torch.int32, device=self.device)
            lora_ranks = torch.zeros(
                (self.max_loras_per_batch,), dtype=torch.int64, device=self.device
            )
            scalings = torch.zeros(
                (self.max_loras_per_batch,), dtype=torch.float, device=self.device
            )
            transfer_adapter_info(
                weight_indices,
                lora_ranks,
                scalings,
            )

            seg_lens = (
                forward_batch.extend_seq_lens
                if forward_batch.forward_mode.is_extend()
                else torch.ones(bs, device=self.device)
            )

            max_len = (
                # Calculate max_len from the CPU copy to avoid D2H transfer.
                max(forward_batch.extend_seq_lens_cpu)
                if forward_batch.forward_mode.is_extend()
                else 1
            )

            seg_indptr = torch.zeros((bs + 1,), dtype=torch.int32, device=self.device)
            seg_indptr[1:] = torch.cumsum(seg_lens, dim=0)

            batch_info = LoRABatchInfo(
                bs=bs,
                seg_lens=seg_lens,
                seg_indptr=seg_indptr,
                max_len=max_len,
                weight_indices=weight_indices,
                lora_ranks=lora_ranks,
                scalings=scalings,
            )
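            # Shape sketch (hypothetical extend batch): with bs == 2 and
            # extend_seq_lens == [3, 5], this yields seg_indptr == [0, 3, 8] and
            # max_len == 5; weight_indices holds one memory-pool buffer id per request.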
        self.lora_backend.set_batch_info(batch_info)

        # TODO (lifuhuang): one potential perf optimization that is worth considering is to see if we can call
        # this method only when loading/unloading LoRA adapters, instead of calling it for every micro-batch.
        self.update_lora_info()

    def update_lora_info(self):
        """
        Update all LoRA modules to associate them with the latest memory buffer.
        """
        for layer_id, layer_modules in enumerate(self.lora_modules):
            for module_name, module in layer_modules.items():
                if "qkv_proj" in module_name:
                    module.set_lora_info(
                        self.memory_pool.get_tensor(
                            "qkv_proj", layer_id, LoRAType.LORA_A
                        ),
                        self.memory_pool.get_tensor(
                            "q_proj", layer_id, LoRAType.LORA_B
                        ),
                        self.memory_pool.get_tensor(
                            "kv_proj", layer_id, LoRAType.LORA_B
                        ),
                    )
                else:
                    weight_name = get_weight_name(
                        module_name, self.memory_pool.lora_weight_names, LoRAType.LORA_A
                    )
                    module.set_lora_info(
                        self.memory_pool.get_tensor(
                            weight_name, layer_id, LoRAType.LORA_A
                        ),
                        self.memory_pool.get_tensor(
                            weight_name, layer_id, LoRAType.LORA_B
                        ),
                    )

    def init_state(
        self,
        max_lora_rank: Optional[int] = None,
        target_modules: Optional[Iterable[str]] = None,
        lora_paths: Optional[Dict[str, LoRARef]] = None,
    ):
        """
        Initialize the internal (mutable) state of the LoRAManager.

        When `lora_paths` is provided and not empty, it may also be used to infer LoRA shape information such as
        the target modules and max_lora_rank.
        """

        assert lora_paths or (
            max_lora_rank is not None and target_modules is not None
        ), "When no initial --lora-paths is provided, you need to specify both --max-lora-rank and --lora-target-modules for LoRA initialization."

        self.init_lora_adapters(lora_paths)
        self.init_lora_shapes(
            max_lora_rank=max_lora_rank,
            target_modules=target_modules,
        )
        self.init_lora_weight_names()
        self.init_lora_modules()
        self.init_memory_pool()

    def init_lora_adapters(self, lora_paths: Optional[Dict[str, LoRARef]] = None):
        # Configs of all active LoRA adapters, indexed by LoRA ID.
        self.configs: Dict[str, LoRAConfig] = {}

        # LoRA adapter weights cached in CPU memory, indexed by LoRA ID.
        self.loras: Dict[str, LoRAAdapter] = {}

        # Mapping from LoRA ID to LoRARef object.
        self.lora_refs: Dict[str, LoRARef] = {}

        if lora_paths:
            for lora_ref in lora_paths.values():
                result = self.load_lora_adapter(lora_ref)
                if not result.success:
                    raise RuntimeError(
                        f"Failed to load LoRA adapter {lora_ref.lora_name}: {result.error_message}"
                    )

    def init_lora_shapes(
        self,
        max_lora_rank: Optional[int] = None,
        target_modules: Optional[Iterable[str]] = None,
    ):
        """Infer LoRA target modules and max_lora_rank from loaded adapters if not provided."""

        if target_modules is not None:
            self.target_modules = set(target_modules)
        else:
            self.target_modules = set()
            for config in self.configs.values():
                self.target_modules.update(config.target_modules)

        if max_lora_rank is not None:
            self.max_lora_rank = max_lora_rank
        else:
            self.max_lora_rank = max(
                [x.hf_config["r"] for x in self.configs.values()],
                default=0,
            )

    def init_lora_weight_names(self):
        """
        Add new LoRA weight names if needed based on the current `self.configs`.
        """

        # Target lora weight names for lora_a and lora_b modules respectively.
        lora_A, lora_B = get_normalized_lora_weight_names(self.target_modules)
        self.lora_weight_names: Tuple[Set[str]] = (set(lora_A), set(lora_B))
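        # The two sets can differ: fused modules keep a single lora_A name (e.g.,
        # "qkv_proj") but separate lora_B names (e.g., "q_proj" and "kv_proj"), which
        # is how update_lora_info() looks them up in the memory pool.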

    def load_lora_weights(self, lora_ref: LoRARef):
        """
        Load the weights of a LoRA adapter into CPU memory and conduct post-loading validation.
        """
        lora_adapter = LoRAAdapter(
            lora_ref.lora_id,
            self.configs[lora_ref.lora_id],
            self.base_hf_config,
            self.load_config,
            self.lora_backend,
        )
        lora_adapter.initialize_weights()
        self.loras[lora_ref.lora_id] = lora_adapter

        # Additional checks for flashinfer backend
        # FIXME remove the restrictions after supporting multi-rank for flashinfer backend
        if self.lora_backend.name == "flashinfer":
            lora_dims = set(x.r for x in self.configs.values())
            scalings = set(x.scaling for x in self.loras.values())
            assert (
                len(lora_dims) == 1 and len(scalings) == 1
            ), "Flashinfer backend currently only supports single LoRA rank and scaling across all adapters. "

    def init_memory_pool(self):
        """(Re)initialize the LoRA memory pool based on the current configurations."""
        self.memory_pool = LoRAMemoryPool(
            base_hf_config=self.base_hf_config,
            max_loras_per_batch=self.max_loras_per_batch,
            dtype=self.dtype,
            tp_size=self.tp_size,
            tp_rank=self.tp_rank,
            max_lora_rank=self.max_lora_rank,
            lora_weight_names=self.lora_weight_names,
            base_model=self.base_model,
        )

    def set_lora_module(self, module_name, module):
        lora_module = get_lora_layer(module, self.lora_backend)
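        # Swap the original submodule inside base_model for its LoRA-wrapped
        # counterpart in place, so subsequent forward passes route through the wrapper.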
        replace_submodule(self.base_model, module_name, lora_module)
        return lora_module

    def init_lora_modules(self):
        # Look-up table that essentially maps (layer_index, module_name) to the corresponding LoRA module.
        self.lora_modules: List[Dict[str, BaseLayerWithLoRA]] = [
            {} for _ in range(self.base_hf_config.num_hidden_layers)
        ]

        # Target module names of customized layers defined in python/sglang/srt/layers
        # e.g., {"qkv_proj", "o_proj"}
        customized_target_names = get_customized_names_from_hf_names(
            self.target_modules, self.base_model
        )

        for module_name, module in self.base_model.named_modules():
            # TODO (lifuhuang): in the future, we should consider generalizing the
            # should_apply_lora function to support mapping by full module name instead
            # of just the last part (e.g., "qkv_proj") to support scenarios with multiple
            # attention stacks (e.g., multimodal models).
            # See: https://github.com/sgl-project/sglang/issues/6608
            if getattr(
                self.base_model, "should_apply_lora", None
            ) and not self.base_model.should_apply_lora(module_name):
                continue

            # The module should be converted if it is included in target_names
            if module_name.split(".")[-1] in customized_target_names:
                layer_id = get_layer_id(module_name)
                self.lora_modules[layer_id][module_name] = self.set_lora_module(
                    module_name, module
                )