# Copyright 2023-2024 SGLang Team
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================

# Integrates "S-LoRA: Serving Thousands of Concurrent LoRA Adapters"
# and "Punica: Multi-Tenant LoRA Serving"

import logging
from typing import Dict, Iterable, List, Optional, Set, Tuple

import torch

from sglang.srt.configs.load_config import LoadConfig
from sglang.srt.hf_transformers_utils import AutoConfig
from sglang.srt.lora.backend.base_backend import BaseLoRABackend, get_backend_from_name
from sglang.srt.lora.layers import BaseLayerWithLoRA, get_lora_layer
from sglang.srt.lora.lora import LoRAAdapter
from sglang.srt.lora.lora_config import LoRAConfig
from sglang.srt.lora.lora_registry import LoRARef
from sglang.srt.lora.mem_pool import LoRAMemoryPool
from sglang.srt.lora.utils import (
    LoRABatchInfo,
    LoRAType,
    get_layer_id,
    get_normalized_target_modules,
    get_target_module_name,
)
from sglang.srt.managers.io_struct import LoRAUpdateResult
from sglang.srt.model_executor.forward_batch_info import ForwardBatch
from sglang.srt.utils import replace_submodule

logger = logging.getLogger(__name__)


class LoRAManager:
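    """
    Manage LoRA adapters for a base model: load/unload adapter configs and weights,
    own the GPU LoRA memory pool, wrap target submodules with LoRA-enabled layers,
    and prepare per-batch LoRA metadata for the backend kernels.

    Illustrative sketch only (argument values below are placeholders, not a tested recipe):

        manager = LoRAManager(
            base_model=model,
            base_hf_config=hf_config,
            max_loras_per_batch=4,
            load_config=load_config,
            dtype=torch.bfloat16,
            lora_paths=[...],  # optional list of LoRARef objects
        )
        manager.load_lora_adapter(lora_ref)   # returns a LoRAUpdateResult
        manager.prepare_lora_batch(forward_batch)
    """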
    def __init__(
        self,
        base_model: torch.nn.Module,
        base_hf_config: AutoConfig,
        max_loras_per_batch: int,
        load_config: LoadConfig,
        dtype: torch.dtype,
        lora_backend: str = "triton",
        tp_size: int = 1,
        tp_rank: int = 0,
        max_lora_rank: Optional[int] = None,
        target_modules: Optional[Iterable[str]] = None,
        lora_paths: Optional[List[LoRARef]] = None,
    ):
        self.base_model: torch.nn.Module = base_model
        self.base_hf_config: AutoConfig = base_hf_config
        self.max_loras_per_batch: int = max_loras_per_batch
        self.load_config: LoadConfig = load_config
        self.dtype: torch.dtype = dtype
        self.device: torch.device = next(self.base_model.parameters()).device
        self.tp_size: int = tp_size
        self.tp_rank: int = tp_rank

        # LoRA backend for running sgemm kernels
        logger.info(f"Using {lora_backend} as backend of LoRA kernels.")
        backend_type = get_backend_from_name(lora_backend)
        self.lora_backend: BaseLoRABackend = backend_type(lora_backend)

        # Initialize mutable internal state of the LoRAManager.
        self.init_state(
            max_lora_rank=max_lora_rank,
            target_modules=target_modules,
            lora_paths=lora_paths,
        )

    def init_cuda_graph_batch_info(self, max_bs_in_cuda_graph: int):
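        """
        Pre-allocate a fixed-size LoRABatchInfo on the CUDA device so that CUDA graph
        capture/replay can reuse the same buffers; `prepare_lora_batch` later fills
        them in place for graph-compatible batches.
        """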
        self.max_bs_in_cuda_graph = max_bs_in_cuda_graph
        with torch.device("cuda"):
            self.cuda_graph_batch_info = LoRABatchInfo(
                bs=self.max_bs_in_cuda_graph,
                seg_lens=torch.zeros(self.max_bs_in_cuda_graph, dtype=torch.int32),
                seg_indptr=torch.zeros(
                    self.max_bs_in_cuda_graph + 1, dtype=torch.int32
                ),
                max_len=1,
                weight_indices=torch.zeros(
                    self.max_bs_in_cuda_graph, dtype=torch.int32
                ),
                lora_ranks=torch.zeros(self.max_loras_per_batch, dtype=torch.int32),
                scalings=torch.zeros(self.max_loras_per_batch, dtype=torch.float),
            )

            # Initialize seg_lens and seg_indptr for CUDA graph as they remain constant
            # across batches.
            self.cuda_graph_batch_info.seg_lens[: self.max_bs_in_cuda_graph].fill_(1)
            torch.cumsum(
                self.cuda_graph_batch_info.seg_lens[: self.max_bs_in_cuda_graph],
                dim=0,
                out=self.cuda_graph_batch_info.seg_indptr[
                    1 : self.max_bs_in_cuda_graph + 1
                ],
            )

    def create_lora_update_result(
        self, success: bool, error_message: str = ""
    ) -> LoRAUpdateResult:
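        """Build a LoRAUpdateResult reporting success/failure plus a snapshot of the currently loaded adapters (name -> path)."""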
        return LoRAUpdateResult(
            success=success,
            error_message=error_message,
            loaded_adapters={
                lora_ref.lora_name: lora_ref.lora_path
                for lora_ref in self.lora_refs.values()
            },
        )

    def load_lora_adapter(self, lora_ref: LoRARef) -> LoRAUpdateResult:
        """
        Load a single LoRA adapter from the specified path.

        Args:
            lora_ref (LoRARef): The LoRARef object containing the LoRA name, path, and ID.
        """
        assert (
            lora_ref.lora_name is not None and lora_ref.lora_path is not None
        ), "LoRARef must have both lora_name and lora_path set for loading."
        assert (
            lora_ref.lora_id not in self.loras
        ), f"LoRA adapter with ID {lora_ref.lora_id} is already loaded. This should have been verified before the request was sent to the backend."

        try:
            # Load and validate the adapter config.
            new_adapter = LoRAConfig(lora_ref.lora_path)
            self.validate_new_adapter(new_adapter, lora_ref)
            self.configs[lora_ref.lora_id] = new_adapter

            # Load the adapter weights into CPU memory.
            self.load_lora_weights(lora_ref)

            # Keep metadata for displayed messages.
            self.lora_refs[lora_ref.lora_id] = lora_ref
            self.num_pinned_loras += int(lora_ref.pinned)
        except Exception as e:
            return self.create_lora_update_result(
                success=False,
                error_message=str(e),
            )

        return self.create_lora_update_result(success=True)

    def validate_new_adapter(self, lora_config: LoRAConfig, lora_ref: LoRARef):
        """
        Validate whether an adapter can be loaded into the current LoRA memory pool and raise an error if it is incompatible.
        """

        # Check if the LoRA adapter shape is compatible with the current LoRA memory pool configuration.
        memory_pool = getattr(self, "memory_pool", None)
        incompatible = memory_pool and not memory_pool.can_support(lora_config)
        if incompatible:
            raise ValueError(
                f"LoRA adapter {lora_ref.lora_name} with rank {lora_config.r} is incompatible with the current "
                "LoRA memory pool configuration. Please ensure that the LoRA adapter's rank is within the configured "
                "`--max-lora-rank` and that the target modules are included in `--lora-target-modules`."
            )

        # Ensure pinned LoRA adapters do not exceed the maximum limit or cause starvation.
        if lora_ref.pinned and self.num_pinned_loras >= self.max_loras_per_batch - 1:
            raise ValueError(
                f"Failed to load LoRA adapter {lora_ref.lora_name} as a pinned adapter. It is not allowed to pin all slots "
                "in the LoRA memory pool, to avoid starvation of unpinned adapters and the base model. Please increase "
                "`--max-loras-per-batch` or load the adapter as unpinned."
            )

    def unload_lora_adapter(self, lora_ref: LoRARef) -> LoRAUpdateResult:
        """
        Unload a LoRA adapter by its reference. This removes the adapter from the memory pool and
        deletes the corresponding LoRA modules.
        """

        adapter = self.configs.get(lora_ref.lora_id)
        # Look up the stored ref under a separate name so the assert message still works
        # when the lookup fails (avoid shadowing the parameter).
        existing_ref = self.lora_refs.get(lora_ref.lora_id)
        assert (
            adapter is not None and existing_ref is not None
        ), f"LoRA adapter with ID {lora_ref.lora_id} is not loaded. This should have been verified before the request was sent to the backend."

        try:
            del self.configs[lora_ref.lora_id]
            del self.loras[lora_ref.lora_id]
            del self.lora_refs[lora_ref.lora_id]
            self.num_pinned_loras -= int(existing_ref.pinned)
        except Exception as e:
            return self.create_lora_update_result(
                success=False,
                error_message=str(e),
            )

        return self.create_lora_update_result(success=True)

    def validate_lora_batch(self, lora_ids: set[str]) -> bool:
        """
        Validate if the LoRA IDs in the batch can be loaded into the current LoRA memory pool.
        """
        if len(lora_ids) > self.max_loras_per_batch:
            return False

        # skip pinned LoRA check if no pinned LoRA adapters are loaded.
        if self.num_pinned_loras == 0:
            return True

        # counting the number of pinned LoRA adapters in the batch.
        pinned_loras_in_batch = 0
        for lora_id in lora_ids:
            if lora_id is not None:
                lora_ref = self.lora_refs.get(lora_id)
                assert (
                    lora_ref is not None
                ), f"LoRA ID {lora_id} not found in lora_refs."
                pinned_loras_in_batch += int(lora_ref.pinned)

        assert pinned_loras_in_batch <= self.num_pinned_loras, (
            f"Number of pinned LoRA adapters in the batch ({pinned_loras_in_batch}) exceeds the total number of pinned adapters "
            f"({self.num_pinned_loras}). This indicates a bug in the LoRA loading logic."
        )

        required_slots = len(lora_ids) - pinned_loras_in_batch
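        # Only slots not permanently held by pinned adapters are available to the rest of the batch.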
        mem_pool_vacancy = self.memory_pool.max_loras_per_batch - self.num_pinned_loras

        return required_slots <= mem_pool_vacancy

    def prepare_lora_batch(self, forward_batch: ForwardBatch):
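        """
        Stage everything the LoRA kernels need for this forward batch: ensure the active
        adapters are resident in the GPU memory pool, then build (or update in place, for
        CUDA-graph-compatible batches) the LoRABatchInfo handed to the backend.
        """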
        # Load active loras into lora memory pool
        cur_uids = set(forward_batch.lora_ids)

        assert len(cur_uids) <= self.max_loras_per_batch
        self.memory_pool.prepare_lora_batch(
            cur_uids=cur_uids,
            lora_adapters=self.loras,
            lora_modules=self.lora_modules,
            lora_refs=self.lora_refs.copy(),  # copy snapshot of current lora_refs to avoid mutation during the batch preparation.
        )

        # set up batch info shared by all lora modules
        bs = forward_batch.batch_size

        def transfer_adapter_info(
            weight_indices_out: torch.Tensor,
            lora_ranks_out: torch.Tensor,
            scalings_out: torch.Tensor,
        ):
            """
            Transfer adapter metadata (weight indices, LoRA rank, scalings) from host
            to device (CUDA) asynchronously.
            """
            weight_indices = [0] * len(forward_batch.lora_ids)
            lora_ranks = [0] * self.max_loras_per_batch
            scalings = [0] * self.max_loras_per_batch
            for i, uid in enumerate(forward_batch.lora_ids):
                weight_indices[i] = self.memory_pool.get_buffer_id(uid)
                if uid is not None:
                    lora = self.loras[uid]
                    lora_ranks[weight_indices[i]] = lora.config.r
                    scalings[weight_indices[i]] = lora.scaling

            # Use pinned memory to avoid synchronizations during host-to-device transfer
            weight_indices_tensor = torch.tensor(
                weight_indices, dtype=torch.int32, pin_memory=True, device="cpu"
            )
            lora_ranks_tensor = torch.tensor(
                lora_ranks, dtype=torch.int32, pin_memory=True, device="cpu"
            )
            scalings_tensor = torch.tensor(
                scalings, dtype=torch.float, pin_memory=True, device="cpu"
            )

            # Copy to device tensors asynchronously
            weight_indices_out[:bs].copy_(weight_indices_tensor, non_blocking=True)
            lora_ranks_out[: self.max_loras_per_batch].copy_(
                lora_ranks_tensor, non_blocking=True
            )
            scalings_out[: self.max_loras_per_batch].copy_(
                scalings_tensor, non_blocking=True
            )

        if (
            hasattr(self, "max_bs_in_cuda_graph")
            and bs <= self.max_bs_in_cuda_graph
            and forward_batch.forward_mode.is_cuda_graph()
        ):
            # Do in-place updates when CUDA graph is enabled and the batch forward mode
            # could use CUDA graph.

            transfer_adapter_info(
                self.cuda_graph_batch_info.weight_indices,
                self.cuda_graph_batch_info.lora_ranks,
                self.cuda_graph_batch_info.scalings,
            )

            self.cuda_graph_batch_info.bs = bs
            self.cuda_graph_batch_info.max_len = 1
            batch_info = self.cuda_graph_batch_info
        else:
            weight_indices = torch.empty((bs,), dtype=torch.int32, device=self.device)
            lora_ranks = torch.zeros(
                (self.max_loras_per_batch,), dtype=torch.int64, device=self.device
            )
            scalings = torch.zeros(
                (self.max_loras_per_batch,), dtype=torch.float, device=self.device
            )
            transfer_adapter_info(
                weight_indices,
                lora_ranks,
                scalings,
            )

            seg_lens = (
                forward_batch.extend_seq_lens
                if forward_batch.forward_mode.is_extend()
                else torch.ones(bs, device=self.device)
            )

            max_len = (
                # Calculate max_len from the CPU copy to avoid D2H transfer.
                max(forward_batch.extend_seq_lens_cpu)
                if forward_batch.forward_mode.is_extend()
                else 1
            )

            seg_indptr = torch.zeros((bs + 1,), dtype=torch.int32, device=self.device)
            seg_indptr[1:] = torch.cumsum(seg_lens, dim=0)

            batch_info = LoRABatchInfo(
                bs=bs,
                seg_lens=seg_lens,
                seg_indptr=seg_indptr,
                max_len=max_len,
                weight_indices=weight_indices,
                lora_ranks=lora_ranks,
                scalings=scalings,
            )
        self.lora_backend.set_batch_info(batch_info)

    def update_lora_info(self):
        """
        Update all LoRA modules to associate them with the latest memory buffer.
        """
        for layer_id, layer_modules in enumerate(self.lora_modules):
            for module_name, module in layer_modules.items():
                target_module = get_target_module_name(
                    module_name, self.memory_pool.target_modules
                )
                module.set_lora_info(
                    self.memory_pool.get_tensor(
                        target_module=target_module,
                        layer_id=layer_id,
                        lora_type=LoRAType.LORA_A,
                    ),
                    self.memory_pool.get_tensor(
                        target_module=target_module,
                        layer_id=layer_id,
                        lora_type=LoRAType.LORA_B,
                    ),
                )

    def init_state(
        self,
        max_lora_rank: Optional[int] = None,
        target_modules: Optional[Iterable[str]] = None,
        lora_paths: Optional[List[LoRARef]] = None,
    ):
        """
        Initialize the internal (mutable) state of the LoRAManager.

        When `lora_paths` is provided and not empty, it might be used for inferring LoRA shape info such as
        the target modules and max_lora_rank.
        """

        assert lora_paths or (
            max_lora_rank is not None and target_modules is not None
        ), "When no initial --lora-paths is provided, you need to specify both --max-lora-rank and --lora-target-modules for LoRA initialization."

        self.init_lora_adapters(lora_paths)
        self.init_lora_shapes(
            max_lora_rank=max_lora_rank,
            target_modules=target_modules,
        )
        self.init_lora_modules()
        self.init_memory_pool()
        self.update_lora_info()

    def init_lora_adapters(self, lora_paths: Optional[List[LoRARef]] = None):
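        """Reset adapter bookkeeping and eagerly load any adapters provided at startup."""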
        # Configs of all active LoRA adapters, indexed by LoRA ID.
        self.configs: Dict[str, LoRAConfig] = {}

        # LoRA adapter weights cached in CPU memory, indexed by LoRA ID.
        self.loras: Dict[str, LoRAAdapter] = {}

        # Mapping from LoRA ID to LoRARef object.
        self.lora_refs: Dict[str, LoRARef] = {}

        # Count of pinned LoRA adapters.
        self.num_pinned_loras: int = 0

        if lora_paths:
            for lora_ref in lora_paths:
                result = self.load_lora_adapter(lora_ref)
                if not result.success:
                    raise RuntimeError(
                        f"Failed to load LoRA adapter {lora_ref.lora_name}: {result.error_message}"
                    )

    def init_lora_shapes(
        self,
        max_lora_rank: Optional[int] = None,
        target_modules: Optional[Iterable[str]] = None,
    ):
        """Infer LoRA target modules and max_lora_rank from loaded adapters if not provided."""

        self.target_modules = (
            get_normalized_target_modules(target_modules) if target_modules else set()
        )

        for lora_id, config in self.configs.items():
            if not isinstance(config.target_modules, list):
                raise ValueError(
                    f"SGLang currently only supports inferring LoRA target modules when a list of "
                    "suffixes is provided in the `target_modules` field of the PEFT config. Please explicitly "
                    "specify `--lora-target-modules` during server startup. You can specify `all` to "
                    "enable all supported module types."
                )

            adapter_target_modules = get_normalized_target_modules(
                config.target_modules
            )

            if target_modules is not None:
                # When `--lora-target-modules` is provided, validate adapter target modules is a subset of the specified target modules.
                if not adapter_target_modules.issubset(self.target_modules):
                    unsupported_modules = adapter_target_modules - self.target_modules
                    lora_name = self.lora_refs[lora_id].lora_name
                    raise ValueError(
                        f"LoRA adapter '{lora_name}' contains target modules {sorted(unsupported_modules)} "
                        f"that are not included in the specified --lora-target-modules {sorted(self.target_modules)}. "
                        f"Please update --lora-target-modules to include all required modules: "
                        f"{sorted(self.target_modules | adapter_target_modules)}, or use 'all' to enable all supported modules."
                    )
            else:
                # Otherwise, infer target_modules from adapter configs.
                self.target_modules.update(adapter_target_modules)

        if max_lora_rank is not None:
            self.max_lora_rank = max_lora_rank
        else:
            self.max_lora_rank = max(
                [x.r for x in self.configs.values()],
                default=0,
            )

    def load_lora_weights(self, lora_ref: LoRARef):
        """
        Load the weights of a LoRA adapter into CPU memory and conduct post-loading validation.
        """
        lora_adapter = LoRAAdapter(
            lora_ref.lora_id,
            self.configs[lora_ref.lora_id],
            self.base_hf_config,
            self.load_config,
            self.lora_backend,
        )
        lora_adapter.initialize_weights()
        self.loras[lora_ref.lora_id] = lora_adapter

    def init_memory_pool(self):
        """(Re)initialize the LoRA memory pool based on the current configurations."""
        self.memory_pool = LoRAMemoryPool(
            base_hf_config=self.base_hf_config,
            max_loras_per_batch=self.max_loras_per_batch,
            dtype=self.dtype,
            tp_size=self.tp_size,
            tp_rank=self.tp_rank,
            max_lora_rank=self.max_lora_rank,
            target_modules=self.target_modules,
            base_model=self.base_model,
        )

    def set_lora_module(self, module_name, module):
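        """Swap the given base-model submodule for its LoRA-enabled wrapper and return the wrapper."""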
        lora_module = get_lora_layer(module, self.lora_backend)
        replace_submodule(self.base_model, module_name, lora_module)
        return lora_module

    def init_lora_modules(self):
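        """Replace the LoRA-targeted submodules of the base model with LoRA-enabled layers, indexed by layer."""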
        # Look-up table that essentially maps (layer_index, module_name) to the corresponding LoRA module.
        self.lora_modules: List[Dict[str, BaseLayerWithLoRA]] = [
            {} for _ in range(self.base_hf_config.num_hidden_layers)
        ]

        for module_name, module in self.base_model.named_modules():
            # TODO (lifuhuang): in the future, we should consider generalizing the
            # should_apply_lora function to support mapping by full module name instead
            # of just the last part (e.g., "qkv_proj") to support scenarios with multiple
            # attention stacks (e.g., multimodal models).
            # See: https://github.com/sgl-project/sglang/issues/6608
            if getattr(
                self.base_model, "should_apply_lora", None
            ) and not self.base_model.should_apply_lora(module_name):
                continue

            # The module should be converted if it is included in target_modules
            if module_name.split(".")[-1] in self.target_modules:
                layer_id = get_layer_id(module_name)
                self.lora_modules[layer_id][module_name] = self.set_lora_module(
                    module_name, module
                )