models.py 36.2 KB
Newer Older
1
# SPDX-License-Identifier: Apache-2.0
2
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
3

4
5
import math
import os
6
7
from collections.abc import Callable
from typing import TypeVar
8

9
import regex as re
10
11
12
13
import safetensors.torch
import torch
from torch import nn

14
from vllm.config.lora import LoRAConfig
15
from vllm.logger import init_logger
16
from vllm.lora.layers import BaseLayerWithLoRA, FusedMoE3DWithLoRA, LoRAMapping
17
from vllm.lora.lora_weights import LoRALayerWeights, PackedLoRALayerWeights
18
from vllm.lora.peft_helper import PEFTHelper
19
from vllm.lora.punica_wrapper import get_punica_wrapper
20
21
22
23
from vllm.lora.utils import (
    from_layer,
    from_layer_logits_processor,
    get_supported_lora_modules,
24
    is_base_embeddding_weights,
25
    is_moe_model,
26
27
    is_regex_target_modules,
    parse_fine_tuned_lora_name,
28
    process_packed_modules_mapping,
29
30
    replace_submodule,
)
31
from vllm.model_executor.layers.fused_moe import FusedMoE
32
from vllm.model_executor.model_loader.tensorizer import TensorizerConfig
33
from vllm.model_executor.models import SupportsLoRA, supports_multimodal
34
from vllm.model_executor.models.interfaces import is_pooling_model
35
from vllm.model_executor.models.module_mapping import MultiModelKeys
36
from vllm.model_executor.models.utils import PPMissingLayer, WeightsMapper
37
from vllm.utils.cache import LRUCache
38
from vllm.utils.platform_utils import is_pin_memory_available
39

40
logger = init_logger(__name__)
41

42
43
44
45
46
47
48
49
T = TypeVar("T")


class AdapterLRUCache(LRUCache[int, T]):
    """LRU cache keyed by adapter int id that deactivates adapters on removal.

    Every time the base ``LRUCache`` calls its ``_on_remove`` hook for a key,
    this subclass first hands that key to ``deactivate_fn`` and then defers to
    the base implementation.
    """

    def __init__(self, capacity: int, deactivate_fn: Callable[[int], object]):
        super().__init__(capacity)
        # Callback receiving the adapter id being removed; its return value
        # is ignored.
        self.deactivate_fn = deactivate_fn

    def _on_remove(self, key: int, value: T | None):
        logger.debug("Removing adapter int id: %d", key)
        on_deactivate = self.deactivate_fn
        on_deactivate(key)
        return super()._on_remove(key, value)


_GLOBAL_LORA_ID = 0


def get_lora_id():
    """Return the next value of a process-wide counter (first call yields 1).

    Used to auto-assign LoRA model ids when the caller does not supply one.
    """
    global _GLOBAL_LORA_ID
    next_id = _GLOBAL_LORA_ID + 1
    _GLOBAL_LORA_ID = next_id
    return next_id


class LoRAModel:
    """A LoRA fine-tuned model.

    Holds an adapter id, its rank, and a mapping from module name to the
    ``LoRALayerWeights`` that should be applied to that module.
    """

    def __init__(
        self,
        lora_model_id: int,
        rank: int,
        loras: dict[str, LoRALayerWeights],
    ) -> None:
        """
        Args:
            lora_model_id: The integer id for the lora model. Must be > 0
                (enforced by an assertion).
            rank: lora rank.
            loras: module name -> weights for lora-replaced layers. Stored
                by reference, not copied.

        """
        self.id = lora_model_id

        assert lora_model_id > 0, (
            f"a valid lora id should be greater than 0, got {self.id}"
        )
        self.rank = rank
        self.loras: dict[str, LoRALayerWeights] = loras

    def clone(self, lora_model_id: int) -> "LoRAModel":
        """Return a copy of this object that carries a different id.

        The ``loras`` dict is shallow-copied, so the clone shares the
        underlying ``LoRALayerWeights`` objects (and their tensors)."""
        return self.__class__(
            lora_model_id,
            rank=self.rank,
            loras=self.loras.copy(),
        )

    def get_lora(self, module_name: str) -> LoRALayerWeights | None:
        """Get LoRA for a given module by name; None if the module has none."""
        return self.loras.get(module_name, None)

    def check_lora_name(self, lora_name: str) -> bool:
        """Return True if this model holds weights under ``lora_name``."""
        return lora_name in self.loras

    @classmethod
    def from_lora_tensors(
        cls,
        lora_model_id: int,
        tensors: dict[str, torch.Tensor],
        peft_helper: PEFTHelper,
        device: str = "cuda",
        dtype: torch.dtype | None = None,
        model_vocab_size: int | None = None,
        weights_mapper: WeightsMapper | None = None,
    ) -> "LoRAModel":
        """Create a LoRAModel from a dictionary of tensors.

        Each tensor name is parsed (via ``parse_fine_tuned_lora_name``, using
        ``weights_mapper``) into a module name and an A/B flag. The tensor is
        moved to ``device``/``dtype`` and stored as that module's ``lora_a``
        or ``lora_b``. Base-embedding weights are skipped.

        NOTE(review): ``from_local_checkpoint`` may pass a tensorizer
        ``TensorDeserializer`` here instead of a dict; only ``.items()`` is
        used on it.

        Raises:
            RuntimeError: if a ``lora_embedding_A`` tensor's second dimension
                differs from ``model_vocab_size`` (when that is given).
        """

        loras: dict[str, LoRALayerWeights] = {}
        for tensor_name, tensor in tensors.items():
            if is_base_embeddding_weights(tensor_name):
                continue
            module_name, is_lora_a = parse_fine_tuned_lora_name(
                tensor_name, weights_mapper
            )
            # Create the per-module weight container on first sight.
            if module_name not in loras:
                loras[module_name] = LoRALayerWeights.from_config(
                    module_name, peft_helper
                )

            if is_lora_a:
                if (
                    "lora_embedding_A" in tensor_name
                    and model_vocab_size is not None
                    and model_vocab_size != tensor.shape[1]
                ):
                    raise RuntimeError(
                        f"The embedding LoRA size({tensor.shape[1]}) must be consistent"
                        f" with the base model's vocabulary size({model_vocab_size})."
                    )
                loras[module_name].lora_a = tensor.to(device=device, dtype=dtype)
            else:
                loras[module_name].lora_b = tensor.to(device=device, dtype=dtype)
        return cls(lora_model_id, peft_helper.r, loras)

    @classmethod
    def from_local_checkpoint(
        cls,
        lora_dir: str,
        expected_lora_modules: set[str],
        peft_helper: PEFTHelper,
        *,
        lora_model_id: int | None = None,
        device: str = "cuda",
        dtype: torch.dtype | None = None,
        model_vocab_size: int | None = None,
        weights_mapper: WeightsMapper | None = None,
        tensorizer_config_dict: dict | None = None,
    ) -> "LoRAModel":
        """Create a LoRAModel from a local checkpoint.

        Sources are tried in this order:
          1. tensorizer, when ``tensorizer_config_dict`` is truthy (reads
             ``<tensorizer_dir>/adapter_model.tensors``);
          2. ``<lora_dir>/adapter_model.safetensors``;
          3. ``<lora_dir>/adapter_model.bin`` or ``adapter_model.pt``.

        Args:
            lora_dir: The local path that has lora data.
            expected_lora_modules: Name of modules that are expected to be
                replaced by lora.
            peft_helper: Loaded lora configuration information.
            lora_model_id: LoRA model id. If not given, automatically set by
                a global counter.
            device: Device where the lora model is loaded.
            dtype: dtype of the lora model weights.
            model_vocab_size: forwarded to ``from_lora_tensors`` to validate
                embedding LoRA sizes.
            weights_mapper: forwarded to ``parse_fine_tuned_lora_name`` to
                remap tensor names.
            tensorizer_config_dict: keyword arguments for ``TensorizerConfig``;
                selects the tensorizer loading path when truthy.

        Returns:
            Loaded LoRA Model.

        Raises:
            ValueError: if the checkpoint targets modules outside
                ``expected_lora_modules``, or if no supported weight file
                exists.
        """
        lora_tensor_path = os.path.join(lora_dir, "adapter_model.safetensors")
        lora_bin_file_path = os.path.join(lora_dir, "adapter_model.bin")
        lora_pt_file_path = os.path.join(lora_dir, "adapter_model.pt")

        tensors: dict[str, torch.Tensor] = {}
        # Shared with the closure below. The closure reads this name from the
        # enclosing scope at call time, so the ``unexpected_modules = []``
        # rebindings in the branches below are what it appends to.
        unexpected_modules: list[list[str] | str] = []

        def check_unexpected_modules(modules: dict):
            """Raise ValueError if any tensor key maps to a module outside
            ``expected_lora_modules``; base-embedding and ``base_layer`` keys
            are ignored."""
            for lora_module in modules.keys():  # noqa
                if is_base_embeddding_weights(lora_module):
                    continue
                # Handle PEFT file format where experts.base_layer is the
                # gate_up_proj and experts is the down_proj
                if "base_layer" in lora_module:
                    continue
                module_name, _ = parse_fine_tuned_lora_name(lora_module, weights_mapper)
                # Case for expert lora weights: compare the part starting at
                # "experts" rather than only the last dotted component.
                if ".experts" in module_name:
                    expert_idx = module_name.find(".experts")
                    expert_suffix = module_name[expert_idx + 1 :]
                    if expert_suffix not in expected_lora_modules:
                        unexpected_modules.append(module_name)

                elif module_name.rsplit(".", 1)[-1] not in expected_lora_modules:
                    unexpected_modules.append(module_name)

            if unexpected_modules:
                raise ValueError(
                    f"While loading {lora_dir}, expected"
                    f" target modules in {expected_lora_modules}"
                    f" but received {unexpected_modules}."
                    f" Please verify that the loaded LoRA module is correct"
                )

        if tensorizer_config_dict:
            # Optional dependency: only imported on the tensorizer path.
            from tensorizer import TensorDeserializer

            tensorizer_config = TensorizerConfig(**tensorizer_config_dict)
            lora_tensor_path = os.path.join(
                tensorizer_config.tensorizer_dir, "adapter_model.tensors"
            )
            tensorizer_args = tensorizer_config._construct_tensorizer_args()
            tensors = TensorDeserializer(
                lora_tensor_path,
                dtype=tensorizer_config.dtype,
                **tensorizer_args.deserialization_kwargs,
            )
            check_unexpected_modules(tensors)

        elif os.path.isfile(lora_tensor_path):
            # Find unexpected modules.
            # Use safetensor key as a source of truth to find expected modules.
            # in peft if you have target_modules A, B, C and C does not exist
            # in the model it won't error and model will be trained with A, B
            # loraified. C won't exist in the safetensor but it will exist in
            # the target_modules of the adapter_config.json.
            unexpected_modules = []
            with safetensors.safe_open(lora_tensor_path, framework="pt") as f:  # type: ignore
                # Load tensors if there are only expected modules.
                check_unexpected_modules(f)
                for module in f.keys():  # noqa
                    tensors[module] = f.get_tensor(module)
        elif os.path.isfile(lora_bin_file_path) or os.path.isfile(lora_pt_file_path):
            # When a bin/pt file is provided, we rely on config to find
            # unexpected modules.
            unexpected_modules = []
            target_modules = peft_helper.target_modules
            # A single (possibly regex) string is treated as a one-item list.
            if not isinstance(target_modules, list):
                target_modules = [target_modules]
            for module in target_modules:
                # Compatible with more modules,
                # such as:layers.11.self_attn.k_proj
                part_name = module.split(".")[-1]
                if part_name not in expected_lora_modules:
                    unexpected_modules.append(module)
            # loaded lora's target modules must be a subset of
            # expected_lora_modules. It is not reliable. See
            # https://github.com/vllm-project/vllm/pull/5909. But there's no
            # other better mechanism.
            if unexpected_modules and not is_regex_target_modules(
                peft_helper.target_modules, expected_lora_modules
            ):
                raise ValueError(
                    f"While loading {lora_dir}, expected"
                    f" target modules in {expected_lora_modules}"
                    f" but received {unexpected_modules}."
                    f" Please verify that the loaded LoRA module is correct"
                )
            # Prefer .bin over .pt when both exist.
            lora_file_path = (
                lora_bin_file_path
                if os.path.isfile(lora_bin_file_path)
                else lora_pt_file_path
            )
            # weights_only=True restricts unpickling to tensor data.
            tensors = torch.load(lora_file_path, map_location=device, weights_only=True)
        else:
            raise ValueError(f"{lora_dir} doesn't contain tensors")

        return cls.from_lora_tensors(
            lora_model_id=get_lora_id() if lora_model_id is None else lora_model_id,
            tensors=tensors,
            peft_helper=peft_helper,
            device=device,
            dtype=dtype,
            model_vocab_size=model_vocab_size,
            weights_mapper=weights_mapper,
        )


class LoRAModelManager:
284
285
286
287
    """A manager that manages multiple LoRA-fine-tuned models."""

    def __init__(
        self,
        model: SupportsLoRA,
        max_num_seqs: int,
        max_num_batched_tokens: int,
        vocab_size: int,
        lora_config: LoRAConfig,
        device: torch.device,
    ):
        """Create a LoRAModelManager and adapter for a given model.

        Wraps every supported target module of ``model`` with a LoRA layer and
        installs this manager as ``model.lora_manager``.

        Args:
            model: the model to be adapted.
            max_num_seqs: the maximum number of sequences model can run in a
                single batch.
            max_num_batched_tokens: the maximum number of tokens model can run
                in a single batch.
            vocab_size: the vocab size of the model.
            lora_config: the LoRA configuration.
            device: device handed to the punica wrapper and kept as
                ``self.device``.

        Raises:
            AssertionError: if ``max_cpu_loras < max_loras`` or the model
                exposes no supported LoRA modules.
        """
        self.model: SupportsLoRA = model
        self._registered_adapters: dict[int, LoRAModel] = {}
        # Dict instead of a set for compatibility with LRUCache.
        self._active_adapters: dict[int, None] = {}
        self.adapter_type = "LoRA"
        self.lora_config = lora_config
        self.device = device
        self.max_num_seqs = max_num_seqs
        assert self.capacity >= self.lora_slots
        # Stored rounded up to a multiple of 8; note that the punica wrapper
        # below receives the unrounded value.
        self.max_num_batched_tokens = math.ceil(max_num_batched_tokens / 8) * 8
        # Slot index -> adapter id currently occupying that slot (None = free).
        self.lora_index_to_id: list[int | None] = [None] * self.lora_slots
        self.vocab_size = vocab_size
        self.punica_wrapper = get_punica_wrapper(
            max_num_batched_tokens,
            max_batches=self.max_num_seqs,
            device=self.device,
            max_loras=self.lora_config.max_loras,
        )

        self.supported_lora_modules = get_supported_lora_modules(self.model)
        # Parenthesized so the model name is actually part of the assertion
        # message (previously the f-string was a separate, discarded
        # expression statement).
        assert self.supported_lora_modules, (
            "No supported LoRA modules found in"
            f" {self.model.__class__.__name__}."
        )

        self.packed_modules_mapping = process_packed_modules_mapping(self.model)
        # Used to indicate whether the model is a multimodal model
        self.supports_mm: bool = (
            supports_multimodal(self.model)
            # In case the model only supports LoRA for
            # text modules (e.g. ChatGLM)
            and hasattr(self.model, "get_mm_mapping")
        )
        self.is_pooling_model = is_pooling_model(self.model)
        self.packed_modules: dict[str, list[str]] = {}
        self.modules: dict[str, BaseLayerWithLoRA] = {}
        # NOTE(review): presumably the most recently applied LoRAMapping; it
        # is not updated anywhere in this constructor.
        self._last_mapping: LoRAMapping | None = None
        self._is_3d_moe_model = is_moe_model(self.model) and self.model.is_3d_moe_weight
        # Must run after the attributes above: it reads packed_modules_mapping,
        # supports_mm and _is_3d_moe_model.
        self._create_lora_modules()

        self.model.lora_manager = self

    def __len__(self) -> int:
        """Number of adapters currently registered with this manager."""
        registered = self._registered_adapters
        return len(registered)

    @property
    def capacity(self) -> int:
        """Adapter capacity, read from ``lora_config.max_cpu_loras``."""
        config = self.lora_config
        return config.max_cpu_loras

    @property
    def lora_slots(self) -> int:
        """Number of LoRA slots, read from ``lora_config.max_loras``."""
        config = self.lora_config
        return config.max_loras

    @property
    def adapter_slots(self) -> int:
        """Generic-adapter alias for :attr:`lora_slots`."""
        slots = self.lora_slots
        return slots

    def activate_adapter(
        self,
        lora_id: int,
    ) -> bool:
        """Move LoRA into a GPU buffer to be used in the forward pass.

        The adapter takes the first free entry of ``lora_index_to_id``. Its
        weights are then loaded into every LoRA-wrapped module at that slot
        index. Modules for which the adapter has no weights get the slot reset.

        Args:
            lora_id: id of an adapter previously registered with this manager.

        Returns:
            False if the adapter is already active, otherwise True after it
            has been loaded.

        Raises:
            ValueError: if no slot is free.
            KeyError: if ``lora_id`` is not registered. In that case the
                adapter is not marked active.
        """
        if lora_id in self._active_adapters:
            return False
        index = next(
            (
                slot_index
                for slot_index, slot_id in enumerate(self.lora_index_to_id)
                if slot_id is None
            ),
            None,
        )
        if index is None:
            raise ValueError("No free lora slots")
        # Resolve the adapter before marking it active. An unknown id then
        # raises KeyError without leaving a stale entry in _active_adapters,
        # which would otherwise make every later activation return False.
        lora_model = self._registered_adapters[lora_id]
        self._active_adapters[lora_id] = None
        logger.debug(
            "Activating LoRA. int id: %d, slot index: %d", lora_model.id, index
        )
        self.lora_index_to_id[index] = lora_model.id
        for module_name, module in self.modules.items():
            module_lora = self._get_lora_layer_weights(lora_model, module_name)
            if not module_lora:
                # The adapter has no weights for this module: clear the slot.
                module.reset_lora(index)
                continue
            # Note (gnovack) - If MOE lora weights are not split into
            # num_experts chunks, we split them here. After conversion,
            # lora_a/lora_b are lists, so the is_tensor check skips
            # re-splitting when the adapter is activated again.
            if isinstance(module, FusedMoE3DWithLoRA) and torch.is_tensor(
                module_lora.lora_a
            ):
                # Handle PEFT file format where experts.base_layer is the
                # gate_up_proj and experts is the down_proj
                gate_up_proj_lora = self._get_lora_layer_weights(
                    lora_model, module_name + ".base_layer"
                )
                down_proj_lora = module_lora
                # FIXME Edge case where LoRA is not added to gate_up_proj
                # or down_proj
                assert gate_up_proj_lora is not None
                assert down_proj_lora is not None
                if self._is_3d_moe_model:
                    module_lora.lora_a = [
                        gate_up_proj_lora.lora_a,
                        down_proj_lora.lora_a,
                    ]
                    module_lora.lora_b = [
                        gate_up_proj_lora.lora_b,
                        down_proj_lora.lora_b,
                    ]
                else:
                    # Some 3D MoE models haven't added the `is_3d_moe_weight`
                    # attribute yet, so fallback here
                    num_experts = module_lora.lora_a.shape[0] // module_lora.rank

                    gate_proj_a = gate_up_proj_lora.lora_a.chunk(num_experts, dim=0)
                    up_proj_a = gate_up_proj_lora.lora_a.chunk(num_experts, dim=0)

                    # gate/up rows of lora_b are interleaved (even/odd).
                    gate_proj_b = gate_up_proj_lora.lora_b[::2, ...].chunk(
                        num_experts, dim=-1
                    )
                    up_proj_b = gate_up_proj_lora.lora_b[1::2, ...].chunk(
                        num_experts, dim=-1
                    )

                    down_proj_a = down_proj_lora.lora_a.chunk(num_experts, dim=0)
                    down_proj_b = down_proj_lora.lora_b.chunk(num_experts, dim=-1)

                    # Per expert, emit (gate, down, up) in that order.
                    lora_a = []
                    lora_b = []
                    for i in range(num_experts):
                        lora_a.append(gate_proj_a[i])
                        lora_a.append(down_proj_a[i])
                        lora_a.append(up_proj_a[i])

                        lora_b.append(gate_proj_b[i])
                        lora_b.append(down_proj_b[i])
                        lora_b.append(up_proj_b[i])

                    module_lora.lora_a = lora_a
                    module_lora.lora_b = lora_b
            module.set_lora(
                index,
                module_lora.lora_a,
                module_lora.lora_b,
            )

        return True

    def _deactivate_adapter(self, lora_id: int):
        """Free the first entry of ``lora_index_to_id`` holding ``lora_id``.

        Silently does nothing when the id occupies no slot.
        """
        slot_table = self.lora_index_to_id
        if lora_id in slot_table:
            slot_table[slot_table.index(lora_id)] = None

    def _add_adapter(self, lora: LoRAModel):
        """Merge packed sub-module weights of ``lora`` in place, then register
        it under ``lora.id``."""
        self._create_merged_loras_inplace(lora)
        registry = self._registered_adapters
        registry[lora.id] = lora

    def pin_adapter(self, lora_id: int) -> bool:
        """Pin a LoRAModel in the manager cache.

        Not supported by this manager: always raises ``NotImplementedError``,
        whose message points to ``LRUCacheLoRAModelManager`` instead.
        """
        message = (
            "Pinning is not supported in LoRAModelManager. "
            "Use LRUCacheLoRAModelManager for pinning"
        )
        raise NotImplementedError(message)  # type: ignore

    def _set_adapter_mapping(self, mapping: LoRAMapping) -> None:
        """Forward ``mapping`` and the current slot table to the punica wrapper."""
        wrapper = self.punica_wrapper
        # NOTE(review): one more than lora_slots is passed; presumably an extra
        # slot for "no LoRA" — confirm against the punica wrapper.
        slot_count = self.lora_slots + 1
        wrapper.update_metadata(
            mapping,
            self.lora_index_to_id,
            slot_count,
            self.vocab_size,
        )

    def remove_all_adapters(self):
        """Remove all LoRAModels from the manager.

        Clears the registry, marks every slot free, then clears the active set.
        """
        self._registered_adapters.clear()
        self.lora_index_to_id = [None for _ in range(self.lora_slots)]
        self._active_adapters.clear()

    def _create_lora_modules(self):
        """Replace every eligible submodule of ``self.model`` with a LoRA layer.

        A submodule is skipped if it is a ``PPMissingLayer``, is not a target
        module, or belongs to an unsupported multimodal part. For ``lm_head``,
        the sibling ``logits_processor`` is also replaced, and that wrapper is
        what gets registered under the lm_head name. Every registered wrapper
        is bound to the shared ``punica_wrapper``.
        """
        for module_name, module in self.model.named_modules(remove_duplicate=False):
            if isinstance(module, PPMissingLayer):
                continue
            if not self._match_target_modules(module_name):
                continue
            # A temporary approach for multimodal models to support LoRA
            # TODO: Remove this restriction
            if self._filter_unsupported_mm_module(module_name):
                logger.warning(
                    "Regarding multimodal models, vLLM currently only supports "
                    "adding LoRA to language model, %s will be ignored.",
                    module_name,
                )
                continue

            leaf_name = module_name.rsplit(".", 1)[-1]
            if isinstance(module, FusedMoE):
                # Only used to pick FusedMoE3DWithLoRA vs FusedMoEWithLoRA:
                # whether w1/w3 LoRA weights are already fused on disk.
                packed_list = ["w13"] if self._is_3d_moe_model else ["w1", "w3"]
            else:
                packed_list = self.packed_modules_mapping.get(leaf_name, [])

            lora_layer = from_layer(
                module,
                self.lora_slots,
                self.lora_config,
                packed_list,
                self.model.config,
            )
            new_module = replace_submodule(self.model, module_name, lora_layer)

            # (yard1): TODO make this more robust
            if "lm_head" in module_name:
                parent_name = module_name.rpartition(".")[0]
                processor_name = (
                    f"{parent_name}.logits_processor"
                    if parent_name
                    else "logits_processor"
                )
                processor_module = self.model.get_submodule(processor_name)
                new_module = replace_submodule(
                    self.model,
                    processor_name,
                    from_layer_logits_processor(
                        processor_module,
                        module,
                        self.lora_slots,
                        self.lora_config,
                        self.model.config,
                    ),
                )

            # In some models, especially multimodal ones, layers with the same
            # name may have different types, such as nn.Linear and
            # ReplicatedLinear. The nn.Linear layers cannot be replaced with
            # LoRA layers, leading to assertion error. The following check
            # aims to prevent this error
            if self.supports_mm and not isinstance(new_module, BaseLayerWithLoRA):
                continue
            self.register_module(module_name, new_module)
            self._register_packed_modules(module_name)
            # All lora layers share the same punica_wrapper based on reference.
            new_module.set_mapping(self.punica_wrapper)

    def register_module(self, module_name: str, module: "BaseLayerWithLoRA"):
        """Store ``module`` in ``self.modules`` under ``module_name``.

        Only LoRA-wrapped layers are accepted; anything else fails the
        assertion.
        """
        assert isinstance(module, BaseLayerWithLoRA), (
            f"Module {module_name} must be a BaseLayerWithLoRA instance, "
            f"got {type(module)}"
        )
        registry = self.modules
        registry[module_name] = module

    def create_dummy_lora(
        self,
        lora_id: int,
        rank: int,
        embedding_modules: dict[str, str] | None = None,
    ) -> LoRAModel:
        """Create zero-initialized LoRAModel for warmup.

        Builds dummy CPU weights for every registered LoRA layer:
          - packed modules get one sub-LoRA per replacement, then packed;
          - embedding modules are sized from the base layer's vocab/embedding
            dims;
          - ``FusedMoE3DWithLoRA`` gets separate w2 and w13 entries, the latter
            stored under ``<name>.base_layer``;
          - everything else is sized from its stacked LoRA buffers.

        ``embedding_modules`` must be provided when any non-packed module is
        visited (asserted).
        """
        model = LoRAModel(lora_id, rank, {})
        for module_name, module in self.model.named_modules():
            wanted = (
                self._match_target_modules(module_name)
                and isinstance(module, BaseLayerWithLoRA)
                and not self._filter_unsupported_mm_module(module_name)
            )
            if not wanted:
                continue
            leaf_name = module_name.split(".")[-1]

            if module_name in self.packed_modules:
                replacements = self.packed_modules_mapping[leaf_name]
                subloras: list[LoRALayerWeights | None] = [
                    LoRALayerWeights.create_dummy_lora_weights(
                        module_name + "." + r,
                        module.lora_a_stacked[i].shape[-1],
                        module.lora_b_stacked[i].shape[-2],
                        rank,
                        module.lora_a_stacked[i].dtype,
                        "cpu",
                    )
                    for i, r in enumerate(replacements)
                ]
                if module.__class__.__name__ == "FusedMoEWithLoRA":
                    model.loras[module_name] = PackedLoRALayerWeights.pack_moe(
                        subloras, module_name
                    )
                else:
                    model.loras[module_name] = PackedLoRALayerWeights.pack(subloras)
                continue

            assert embedding_modules is not None
            if leaf_name in embedding_modules:
                base_layer = module.base_layer
                if hasattr(base_layer, "org_vocab_size"):
                    input_dim = base_layer.org_vocab_size
                else:
                    input_dim = base_layer.weight.shape[1]
                if hasattr(base_layer, "embedding_dim"):
                    output_dim = base_layer.embedding_dim
                else:
                    output_dim = base_layer.weight.shape[0]
                model.loras[module_name] = LoRALayerWeights.create_dummy_lora_weights(
                    module_name,
                    input_dim,
                    output_dim,
                    rank,
                    module.lora_a_stacked[0].dtype,
                    "cpu",
                )
            elif module.__class__.__name__ == "FusedMoE3DWithLoRA":
                # Case for 3D moe model. The rank passed in is scaled by dim 1
                # of the stacked lora_a buffer (rank*num_experts).
                # w2
                model.loras[module_name] = LoRALayerWeights.create_dummy_lora_weights(
                    module_name,
                    module.w2_input_size,
                    module.w2_output_size,
                    rank * module.w2_lora_a_stacked[0].shape[1],
                    module.w2_lora_a_stacked[0].dtype,
                    "cpu",
                )
                # w13
                model.loras[module_name + ".base_layer"] = (
                    LoRALayerWeights.create_dummy_lora_weights(
                        module_name,
                        module.w13_input_size,
                        module.w13_output_size,
                        rank * module.w13_lora_a_stacked[0].shape[1],
                        module.w13_lora_a_stacked[0].dtype,
                        "cpu",
                    )
                )
            else:
                model.loras[module_name] = LoRALayerWeights.create_dummy_lora_weights(
                    module_name,
                    module.lora_a_stacked[0].shape[-1],
                    module.lora_b_stacked[0].shape[-2],
                    rank,
                    module.lora_a_stacked[0].dtype,
                    "cpu",
                )
        return model

    def _match_target_modules(self, module_name: str):
        """Return True if ``module_name`` equals, or regex-matches as a
        ``.<target>`` suffix of, any entry in ``supported_lora_modules``.

        NOTE: target names are interpolated into the pattern unescaped, so any
        regex metacharacters in them are interpreted.
        """
        for target_module in self.supported_lora_modules:
            suffix_pattern = rf".*\.{target_module}$"
            if re.match(suffix_pattern, module_name) or target_module == module_name:
                return True
        return False

    def _filter_unsupported_mm_module(self, module_name: str) -> bool:
        """Return True if ``module_name`` must be skipped for LoRA.

        For multimodal models, LoRA is currently applied only to the language
        model. Modules whose name starts with a connector or tower-model prefix
        from ``get_mm_mapping()`` (e.g. the vision tower) are filtered out.
        Always False for non-multimodal models.
        """
        if not self.supports_mm:
            return False
        module_mapping: MultiModelKeys = self.model.get_mm_mapping()
        prefixes = module_mapping.connector + module_mapping.tower_model
        return any(module_name.startswith(prefix) for prefix in prefixes)

    def _register_packed_modules(self, module_full_name: str) -> None:
        """Record the fully-qualified sub-module names packed into
        ``module_full_name``.

        The leaf name is looked up in ``packed_modules_mapping``. Each
        replacement is re-qualified with the parent prefix and the list is
        stored in ``packed_modules``.
        """
        prefix, _, leaf_name = module_full_name.rpartition(".")
        replacements = self.packed_modules_mapping.get(leaf_name, [])
        # Zero or one replacement means the module is not actually packed.
        if len(replacements) <= 1:
            return
        self.packed_modules[module_full_name] = [
            prefix + "." + r if prefix else r for r in replacements
        ]

    def _create_merged_loras_inplace(self, lora_model: LoRAModel) -> None:
        """Fold sub-module LoRAs into packed-module LoRAs, then prepare them.

        The steps are:

        1. For every packed module in ``self.packed_modules``, collect the
           sub-module LoRAs found in ``lora_model``. Missing sub-modules
           are represented by ``None``.
        2. Store the ``PackedLoRALayerWeights`` built from them under the
           packed module's name. Modules ending in ``.experts`` use
           ``pack_moe``; all others use ``pack``.
        3. Remove the individual sub-module entries.
        4. Call ``optimize()`` on every LoRA.
        5. If the weights live on CPU and pinned memory is available, pin
           them.

        Args:
            lora_model: Adapter whose ``loras`` dict is modified in place.
        """
        for module_name, new_module_names in self.packed_modules.items():
            replacement_loras: list[LoRALayerWeights | None] = []
            replaced_module: set[str] = set()
            for r in new_module_names:
                lora = self._get_lora_layer_weights(lora_model, r)
                replacement_loras.append(lora)
                if lora is not None:
                    replaced_module.add(r)
            # None of the sub-modules carry LoRA weights: nothing to pack.
            if not replaced_module:
                continue
            # HACK: temporary solution for pooling models. Their checkpoints
            # may name modules without vLLM's leading "model." prefix; if the
            # checkpoint uses the stripped name, store the packed LoRA there.
            if self.is_pooling_model and not lora_model.check_lora_name(module_name):
                replaced_module_name = module_name.removeprefix("model.")
                if lora_model.check_lora_name(replaced_module_name):
                    module_name = replaced_module_name

            if module_name.endswith(".experts"):
                lora_model.loras[module_name] = PackedLoRALayerWeights.pack_moe(
                    replacement_loras, module_name
                )
            else:
                lora_model.loras[module_name] = PackedLoRALayerWeights.pack(
                    replacement_loras
                )
            # Remove the sub-module entries that were folded into the pack.
            for module in replaced_module:
                lora_model.loras.pop(module, None)

        for lora in lora_model.loras.values():
            lora.optimize()

        # An adapter without LoRA layers has nothing to pin; returning here
        # also keeps next() below from leaking StopIteration.
        if not lora_model.loras:
            return

        first_lora: LoRALayerWeights = next(iter(lora_model.loras.values()))
        assert first_lora.lora_a is not None
        if isinstance(first_lora.lora_a, list):
            # Packed LoRAs hold one tensor per sub-module (None when absent).
            # Use the device of the first present tensor; taking the first
            # list element itself would yield a tensor, not a device.
            lora_device = next(
                (t.device for t in first_lora.lora_a if t is not None), None
            )
        else:
            lora_device = first_lora.lora_a.device
        # Execute pin_memory after LoRA weight merging, mainly because:
        # 1. Some MoE models have a large number of LoRA weights. If we
        #    performed pin_memory immediately after loading weights, the
        #    overhead would be significant.
        # 2. The weight packing above (e.g., pack_moe) may invalidate the
        #    pin_memory allocation, so we execute it after packing.
        pin_memory = str(lora_device) == "cpu" and is_pin_memory_available()
        if not pin_memory:
            return
        for lora in lora_model.loras.values():
            if isinstance(lora.lora_a, list):
                for index, tensor_a in enumerate(lora.lora_a):
                    if tensor_a is None:
                        continue
                    lora.lora_a[index] = tensor_a.pin_memory()
                    lora.lora_b[index] = lora.lora_b[index].pin_memory()
            else:
                lora.lora_a = lora.lora_a.pin_memory()
                lora.lora_b = lora.lora_b.pin_memory()

765
    def _get_lora_layer_weights(
        self, lora_model: LoRAModel, module_name: str
    ) -> LoRALayerWeights | None:
        """Look up the LoRA weights for ``module_name`` in ``lora_model``.

        For pooling models, if ``module_name`` is not present, the lookup is
        retried with its leading ``"model."`` prefix removed. The stripped
        name is used only if ``lora_model`` has it. Otherwise the original
        name is looked up.

        Returns:
            The result of ``lora_model.get_lora`` for the chosen name. Per the
            annotation, this is the weights or ``None``.
        """
        lookup_name = module_name
        if self.is_pooling_model and not lora_model.check_lora_name(module_name):
            # Strip only the leading prefix. str.replace would also remove
            # "model." inside names such as "model.language_model.layers",
            # producing a bogus "language_layers" key.
            stripped_name = module_name.removeprefix("model.")
            if lora_model.check_lora_name(stripped_name):
                lookup_name = stripped_name
                logger.info_once(
                    "For the pool model, successfully loaded the LoRA weights "
                    "after removing the prefix 'model.'."
                )
        return lora_model.get_lora(lookup_name)

781
    def deactivate_adapter(self, adapter_id: int) -> bool:
        """Deactivate ``adapter_id`` if it is currently active.

        Returns:
            True if the adapter was active and has been deactivated and
            dropped from the active set. False if it was not active.
        """
        was_active = adapter_id in self._active_adapters
        if was_active:
            self._deactivate_adapter(adapter_id)
            self._active_adapters.pop(adapter_id, None)
        return was_active
787
788

    def add_adapter(self, adapter: LoRAModel) -> bool:
        """Register ``adapter`` with the manager.

        Returns:
            True if the adapter was newly registered. False if an adapter with
            the same id was already registered.

        Raises:
            RuntimeError: If the number of registered adapters has reached
                ``self.capacity``.
        """
        adapter_id = adapter.id
        logger.debug("Adding lora. Model id: %d, int id: %d", adapter_id, adapter_id)
        already_registered = adapter_id in self._registered_adapters
        if not already_registered:
            if len(self._registered_adapters) >= self.capacity:
                raise RuntimeError("No free adapter slots.")
            self._add_adapter(adapter)
        return not already_registered
796

797
    def set_adapter_mapping(self, mapping: LoRAMapping) -> None:
        """Apply ``mapping``, skipping the work if it equals the last one."""
        mapping_changed = self._last_mapping != mapping
        if not mapping_changed:
            return
        self._set_adapter_mapping(mapping)
        self._last_mapping = mapping
801
802

    def remove_adapter(self, adapter_id: int) -> bool:
        """Deactivate ``adapter_id`` and then unregister it.

        Deactivation is always attempted, even for unregistered ids.

        Returns:
            True if the adapter was registered and has been removed. False
            otherwise.
        """
        self.deactivate_adapter(adapter_id)
        is_registered = adapter_id in self._registered_adapters
        if is_registered:
            self._registered_adapters.pop(adapter_id, None)
        return is_registered
808

809
810
    def list_adapters(self) -> dict[int, LoRAModel]:
        """Return a shallow copy of the registered adapters, keyed by id."""
        return {**self._registered_adapters}
811

812
    def get_adapter(self, adapter_id: int) -> LoRAModel | None:
        """Return the registered adapter with ``adapter_id``, or ``None``."""
        registered = self._registered_adapters
        return registered.get(adapter_id)
814
815
816


class LoRALRUCache(AdapterLRUCache[LoRAModel]):
    """LRU cache of ``LoRAModel`` objects keyed by integer adapter id.

    ``deactivate_lora_fn`` becomes ``AdapterLRUCache.deactivate_fn``. That
    class's ``_on_remove`` hook calls it with the adapter id whenever an entry
    is removed.
    """

    def __init__(self, capacity: int, deactivate_lora_fn: Callable[[int], bool]):
        super().__init__(capacity, deactivate_fn=deactivate_lora_fn)
819
820
821
822
823


class LRUCacheLoRAModelManager(LoRAModelManager):
    """A model manager that manages multiple LoRAs with LRU cache.

    Two LRU caches are kept:

    * ``_registered_adapters`` (capacity ``self.capacity``) holds every
      adapter known to the manager. Removal calls ``deactivate_adapter``.
    * ``_active_adapters`` (capacity ``self.lora_slots``) holds the currently
      activated adapters. Removal calls ``_deactivate_adapter``.
    """

    def __init__(
        self,
        model: nn.Module,
        max_num_seqs: int,
        max_num_batched_tokens: int,
        vocab_size: int,
        lora_config: LoRAConfig,
        device: torch.device,
    ):
        super().__init__(
            model=model,
            max_num_seqs=max_num_seqs,
            max_num_batched_tokens=max_num_batched_tokens,
            vocab_size=vocab_size,
            lora_config=lora_config,
            device=device,
        )
        # Override the adapter containers set up by the base class with LRU
        # caches, so that eviction also deactivates the evicted adapter.
        self._registered_adapters: LoRALRUCache = LoRALRUCache(
            self.capacity, self.deactivate_adapter
        )
        self._active_adapters: LoRALRUCache = LoRALRUCache(
            self.lora_slots, self._deactivate_adapter
        )

    def list_adapters(self) -> dict[int, LoRAModel]:
        """List all registered LoRAModels."""
        registered = self._registered_adapters.cache
        return dict(registered)

    def add_adapter(self, lora: LoRAModel) -> bool:
        """Add a LoRAModel to the manager.

        Returns:
            True if ``lora`` was newly registered. False if it was already
            registered, in which case it is only touched to refresh its LRU
            order.
        """
        logger.debug("Adding lora. Model id: %d, int id: %d", lora.id, lora.id)
        if lora.id in self._registered_adapters:
            # Already known: just refresh its LRU position.
            self._registered_adapters.touch(lora.id)
            return False
        self._add_adapter(lora)
        return True

    def activate_adapter(
        self,
        lora_id: int,
    ) -> bool:
        """Activate ``lora_id``, evicting the oldest active adapter first if it
        is not yet active and every LoRA slot is already taken."""
        needs_slot = lora_id not in self._active_adapters
        if needs_slot and len(self._active_adapters) >= self.lora_slots:
            self._active_adapters.remove_oldest()
        activated = super().activate_adapter(lora_id)
        # Touch unconditionally so the LRU order reflects this use.
        self._active_adapters.touch(lora_id)
        return activated

    def remove_oldest_adapter(self) -> bool:
        """Evict the oldest registered adapter.

        Returns:
            False if no adapters are registered, True otherwise.
        """
        if not len(self._registered_adapters):
            return False
        self._registered_adapters.remove_oldest()
        return True

    def pin_adapter(self, lora_id: int) -> bool:
        """Pin a LoRAModel in the manager cache.

        The adapter is pinned in the registered-adapter cache first. It is
        then pinned in the active-adapter cache, being activated if needed.
        Always returns True; failures surface as exceptions.
        """
        for pin_step in (self._pin_lora_in_cpu_cache, self._pin_lora_in_gpu_cache):
            pin_step(lora_id)
        return True

    def _pin_lora_in_cpu_cache(self, lora_id: int):
        """Pin ``lora_id`` in the registered-adapter cache.

        Raises:
            ValueError: If the cache's ``pin`` raises ``ValueError``. It is
                re-raised with a message naming the unregistered adapter.
        """
        registered = self._registered_adapters
        try:
            registered.pin(lora_id)
        except ValueError as err:
            raise ValueError(
                f"Pinning failed. LoRA {lora_id} is not registered."
            ) from err

    def _pin_lora_in_gpu_cache(self, lora_id: int):
        """Pin ``lora_id`` in the active-adapter cache, activating it first
        (moving it to the GPU) when it is not already active."""
        active = self._active_adapters
        if lora_id not in active:
            self.activate_adapter(lora_id)
        self._active_adapters.pin(lora_id)
899

900
901

def create_lora_manager(
    model: nn.Module,
    max_num_seqs: int,
    max_num_batched_tokens: int,
    vocab_size: int,
    lora_config: LoRAConfig,
    device: torch.device,
    lora_manager_cls: type[LoRAModelManager] = LoRAModelManager,
    **kwargs,
) -> LoRAModelManager:
    """Create a LoRA manager for the given model.

    Args:
        model: Model to manage; must implement ``SupportsLoRA``.
        max_num_seqs: Forwarded to ``lora_manager_cls``.
        max_num_batched_tokens: Forwarded to ``lora_manager_cls``.
        vocab_size: Forwarded to ``lora_manager_cls``.
        lora_config: Forwarded to ``lora_manager_cls``.
        device: Forwarded to ``lora_manager_cls``.
        lora_manager_cls: Manager class to instantiate.
        **kwargs: Extra keyword arguments forwarded to ``lora_manager_cls``.

    Returns:
        The newly constructed manager.

    Raises:
        ValueError: If ``model`` is not an instance of ``SupportsLoRA``.
    """
    if isinstance(model, SupportsLoRA):
        return lora_manager_cls(
            model=model,
            max_num_seqs=max_num_seqs,
            max_num_batched_tokens=max_num_batched_tokens,
            vocab_size=vocab_size,
            lora_config=lora_config,
            device=device,
            **kwargs,
        )
    raise ValueError(f"Model {type(model)} is not supported for LoRA.")