models.py 35.2 KB
Newer Older
1
# SPDX-License-Identifier: Apache-2.0
2
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
3

4
5
import math
import os
6
7
from collections.abc import Callable
from typing import TypeVar
8

9
import regex as re
10
11
12
13
import safetensors.torch
import torch
from torch import nn

14
from vllm.config.lora import LoRAConfig
15
from vllm.logger import init_logger
16
from vllm.lora.layers import BaseLayerWithLoRA, FusedMoE3DWithLoRA, LoRAMapping
17
from vllm.lora.lora_weights import LoRALayerWeights, PackedLoRALayerWeights
18
from vllm.lora.peft_helper import PEFTHelper
19
from vllm.lora.punica_wrapper import get_punica_wrapper
20
21
22
23
from vllm.lora.utils import (
    from_layer,
    from_layer_logits_processor,
    get_supported_lora_modules,
24
    is_base_embeddding_weights,
25
    is_moe_model,
26
27
    is_regex_target_modules,
    parse_fine_tuned_lora_name,
28
    process_packed_modules_mapping,
29
30
    replace_submodule,
)
31
from vllm.model_executor.layers.fused_moe import FusedMoE
32
from vllm.model_executor.model_loader.tensorizer import TensorizerConfig
33
from vllm.model_executor.models import SupportsLoRA, supports_multimodal
34
from vllm.model_executor.models.interfaces import is_pooling_model
35
from vllm.model_executor.models.module_mapping import MultiModelKeys
36
from vllm.model_executor.models.utils import PPMissingLayer, WeightsMapper
37
from vllm.utils.cache import LRUCache
38
from vllm.utils.platform_utils import is_pin_memory_available
39

40
logger = init_logger(__name__)
41

42
43
44
45
46
47
48
49
T = TypeVar("T")


class AdapterLRUCache(LRUCache[int, T]):
    """LRU cache keyed by integer adapter id that runs a callback on removal.

    The ``_on_remove`` override calls ``deactivate_fn`` with the key being
    removed before delegating to the base-class hook.
    """

    def __init__(self, capacity: int, deactivate_fn: Callable[[int], object]):
        """
        Args:
            capacity: capacity forwarded to the base ``LRUCache``.
            deactivate_fn: callable invoked with the adapter id each time
                ``_on_remove`` runs; its return value is ignored.
        """
        super().__init__(capacity)
        self.deactivate_fn = deactivate_fn

    def _on_remove(self, key: int, value: T | None):
        """Log the removal, deactivate the adapter, then run the base hook."""
        logger.debug("Removing adapter int id: %d", key)
        self.deactivate_fn(key)
        return super()._on_remove(key, value)


56
57
58
59
60
61
62
63
64
# Last id handed out by get_lora_id(); 0 means none has been issued yet.
_GLOBAL_LORA_ID = 0


def get_lora_id():
    """Return the next LoRA id from the module-level counter.

    Each call increments the counter by one, so the first id returned is 1.
    """
    global _GLOBAL_LORA_ID
    next_id = _GLOBAL_LORA_ID + 1
    _GLOBAL_LORA_ID = next_id
    return next_id


65
class LoRAModel:
    """A LoRA fine-tuned model."""

    def __init__(
        self,
        lora_model_id: int,
        rank: int,
        loras: dict[str, LoRALayerWeights],
    ) -> None:
        """
        Args:
            lora_model_id: The integer id for the lora model.
            rank: lora rank.
            loras: module name -> weights for lora-replaced layers.

        """
        self.id = lora_model_id

        assert lora_model_id > 0, (
            f"a valid lora id should be greater than 0, got {self.id}"
        )
        self.rank = rank
        self.loras: dict[str, LoRALayerWeights] = loras

    def clone(self, lora_model_id: int) -> "LoRAModel":
        """Return a copy of the object with different ids.

        Will share the underlying tensors: only the ``loras`` dict itself is
        copied, not the weight objects it maps to."""
        return self.__class__(
            lora_model_id,
            rank=self.rank,
            loras=self.loras.copy(),
        )

    def get_lora(self, module_name: str) -> LoRALayerWeights | None:
        """Get LoRA for a given module by name, or None if there is none."""
        return self.loras.get(module_name, None)

    def check_lora_name(self, lora_name: str) -> bool:
        """Return True if weights are stored under ``lora_name``."""
        return lora_name in self.loras

    @classmethod
    def from_lora_tensors(
        cls,
        lora_model_id: int,
        tensors: dict[str, torch.Tensor],
        peft_helper: PEFTHelper,
        device: str = "cuda",
        dtype: torch.dtype | None = None,
        model_vocab_size: int | None = None,
        weights_mapper: WeightsMapper | None = None,
    ) -> "LoRAModel":
        """Create a LoRAModel from a dictionary of tensors.

        Args:
            lora_model_id: id given to the created model.
            tensors: checkpoint tensor name -> tensor.
            peft_helper: adapter configuration; supplies the rank (``r``) and
                is passed to ``LoRALayerWeights.from_config``.
            device: device the tensors are moved to.
            dtype: dtype the tensors are converted to.
            model_vocab_size: when given, a ``lora_embedding_A`` tensor whose
                second dimension differs from it is rejected.
            weights_mapper: forwarded to ``parse_fine_tuned_lora_name``.

        Raises:
            RuntimeError: an embedding LoRA tensor does not match
                ``model_vocab_size``.
        """
        pin_memory = str(device) == "cpu" and is_pin_memory_available()
        loras: dict[str, LoRALayerWeights] = {}
        for tensor_name, tensor in tensors.items():
            if is_base_embeddding_weights(tensor_name):
                continue
            module_name, is_lora_a = parse_fine_tuned_lora_name(
                tensor_name, weights_mapper
            )
            if module_name not in loras:
                loras[module_name] = LoRALayerWeights.from_config(
                    module_name, peft_helper
                )

            # Validate before moving anything so a bad checkpoint fails early.
            if (
                is_lora_a
                and "lora_embedding_A" in tensor_name
                and model_vocab_size is not None
                and model_vocab_size != tensor.shape[1]
            ):
                raise RuntimeError(
                    f"The embedding LoRA size({tensor.shape[1]}) must be consistent"
                    f" with the base model's vocabulary size({model_vocab_size})."
                )

            placed = tensor.to(device=device, dtype=dtype)
            if pin_memory:
                placed = placed.pin_memory()
            if is_lora_a:
                loras[module_name].lora_a = placed
            else:
                loras[module_name].lora_b = placed

        return cls(lora_model_id, peft_helper.r, loras)

    @classmethod
    def from_local_checkpoint(
        cls,
        lora_dir: str,
        expected_lora_modules: set[str],
        peft_helper: PEFTHelper,
        *,
        lora_model_id: int | None = None,
        device: str = "cuda",
        dtype: torch.dtype | None = None,
        model_vocab_size: int | None = None,
        weights_mapper: WeightsMapper | None = None,
        tensorizer_config_dict: dict | None = None,
    ) -> "LoRAModel":
        """Create a LoRAModel from a local checkpoint.

        Args:
            lora_dir: The local path that has lora data.
            expected_lora_modules: Name of modules that are expected to be
                replaced by lora.
            peft_helper: Loaded lora configuration information.
            lora_model_id: LoRA model id. If not given, automatically set by
                a global counter.
            device: Device where the lora model is loaded.
            dtype: dtype of the lora model weights.
            model_vocab_size: forwarded to ``from_lora_tensors``.
            weights_mapper: forwarded to ``parse_fine_tuned_lora_name`` and
                ``from_lora_tensors``.
            tensorizer_config_dict: keyword arguments for
                ``TensorizerConfig``. When truthy, tensors are read from
                ``adapter_model.tensors`` under the config's
                ``tensorizer_dir`` instead of from ``lora_dir``.

        Returns:
            Loaded LoRA Model.

        Raises:
            ValueError: the checkpoint targets modules outside
                ``expected_lora_modules``, or ``lora_dir`` has none of the
                supported weight files.
        """
        lora_tensor_path = os.path.join(lora_dir, "adapter_model.safetensors")
        lora_bin_file_path = os.path.join(lora_dir, "adapter_model.bin")
        lora_pt_file_path = os.path.join(lora_dir, "adapter_model.pt")

        tensors: dict[str, torch.Tensor] = {}

        def unexpected_modules_error(
            unexpected: list[list[str] | str],
        ) -> ValueError:
            # Single place that builds the message shared by every branch.
            return ValueError(
                f"While loading {lora_dir}, expected"
                f" target modules in {expected_lora_modules}"
                f" but received {unexpected}."
                f" Please verify that the loaded LoRA module is correct"
            )

        def check_unexpected_modules(modules: dict):
            # Collect into a list owned by this call rather than relying on
            # the caller to reset shared state first.
            unexpected: list[list[str] | str] = []
            for lora_module in modules.keys():  # noqa
                if is_base_embeddding_weights(lora_module):
                    continue
                # Handle PEFT file format where experts.base_layer is the
                # gate_up_proj and experts is the down_proj
                if "base_layer" in lora_module:
                    continue
                module_name, _ = parse_fine_tuned_lora_name(lora_module, weights_mapper)
                # Case for expert lora weights
                if ".experts" in module_name:
                    expert_idx = module_name.find(".experts")
                    expert_suffix = module_name[expert_idx + 1 :]
                    if expert_suffix not in expected_lora_modules:
                        unexpected.append(module_name)

                elif module_name.rsplit(".", 1)[-1] not in expected_lora_modules:
                    unexpected.append(module_name)

            if unexpected:
                raise unexpected_modules_error(unexpected)

        if tensorizer_config_dict:
            from tensorizer import TensorDeserializer

            tensorizer_config = TensorizerConfig(**tensorizer_config_dict)
            lora_tensor_path = os.path.join(
                tensorizer_config.tensorizer_dir, "adapter_model.tensors"
            )
            tensorizer_args = tensorizer_config._construct_tensorizer_args()
            tensors = TensorDeserializer(
                lora_tensor_path,
                dtype=tensorizer_config.dtype,
                **tensorizer_args.deserialization_kwargs,
            )
            check_unexpected_modules(tensors)

        elif os.path.isfile(lora_tensor_path):
            # Find unexpected modules.
            # Use safetensor key as a source of truth to find expected modules.
            # in peft if you have target_modules A, B, C and C does not exist
            # in the model it won't error and model will be trained with A, B
            # loraified. C won't exist in the safetensor but it will exist in
            # the target_modules of the adapter_config.json.
            with safetensors.safe_open(lora_tensor_path, framework="pt") as f:  # type: ignore
                # Load tensors if there are only expected modules.
                check_unexpected_modules(f)
                for module in f.keys():  # noqa
                    tensors[module] = f.get_tensor(module)
        elif os.path.isfile(lora_bin_file_path) or os.path.isfile(lora_pt_file_path):
            # When a bin/pt file is provided, we rely on config to find
            # unexpected modules.
            target_modules = peft_helper.target_modules
            if not isinstance(target_modules, list):
                target_modules = [target_modules]
            # Compatible with more modules,
            # such as:layers.11.self_attn.k_proj
            unexpected_modules: list[list[str] | str] = [
                module
                for module in target_modules
                if module.split(".")[-1] not in expected_lora_modules
            ]
            # loaded lora's target modules must be a subset of
            # expected_lora_modules. It is not reliable. See
            # https://github.com/vllm-project/vllm/pull/5909. But there's no
            # other better mechanism.
            if unexpected_modules and not is_regex_target_modules(
                peft_helper.target_modules, expected_lora_modules
            ):
                raise unexpected_modules_error(unexpected_modules)
            lora_file_path = (
                lora_bin_file_path
                if os.path.isfile(lora_bin_file_path)
                else lora_pt_file_path
            )
            tensors = torch.load(lora_file_path, map_location=device, weights_only=True)
        else:
            raise ValueError(f"{lora_dir} doesn't contain tensors")

        return cls.from_lora_tensors(
            lora_model_id=get_lora_id() if lora_model_id is None else lora_model_id,
            tensors=tensors,
            peft_helper=peft_helper,
            device=device,
            dtype=dtype,
            model_vocab_size=model_vocab_size,
            weights_mapper=weights_mapper,
        )
287
288


289
class LoRAModelManager:
    """A manager that manages multiple LoRA-fine-tuned models."""

    def __init__(
        self,
        model: SupportsLoRA,
        max_num_seqs: int,
        max_num_batched_tokens: int,
        vocab_size: int,
        lora_config: LoRAConfig,
        device: torch.device,
    ):
        """Create a LoRAModelManager and adapter for a given model.

        Args:
            model: the model to be adapted.
            max_num_seqs: the maximum number of sequences model can run in a
                single batch.
            max_num_batched_tokens: the maximum number of tokens model can run
                in a single batch.
            vocab_size: the vocab size of the model.
            lora_config: the LoRA configuration.
            device: the device handed to the punica wrapper.
        """
        self.model: SupportsLoRA = model
        self._registered_adapters: dict[int, LoRAModel] = {}
        # Dict instead of a set for compatibility with LRUCache.
        self._active_adapters: dict[int, None] = {}
        self.adapter_type = "LoRA"
        self.lora_config = lora_config
        self.device = device
        self.max_num_seqs = max_num_seqs
        assert self.capacity >= self.lora_slots
        # Stored value is rounded up to a multiple of 8; the punica wrapper
        # below is still built from the caller-supplied value.
        self.max_num_batched_tokens = math.ceil(max_num_batched_tokens / 8) * 8
        self.lora_index_to_id: list[int | None] = [None] * self.lora_slots
        self.vocab_size = vocab_size
        self.punica_wrapper = get_punica_wrapper(
            max_num_batched_tokens,
            max_batches=self.max_num_seqs,
            device=self.device,
            max_loras=self.lora_config.max_loras,
        )

        self.supported_lora_modules = get_supported_lora_modules(self.model)
        # The message must be one parenthesised expression; otherwise the
        # second literal is a separate statement and is never reported.
        assert self.supported_lora_modules, (
            "No supported LoRA modules found in"
            f" {self.model.__class__.__name__}."
        )

        self.packed_modules_mapping = process_packed_modules_mapping(self.model)
        # Used to indicate whether the model is a multimodal model
        self.supports_mm: bool = (
            supports_multimodal(self.model)
            # In case the model only supports LoRA for
            # text modules (e.g. ChatGLM)
            and hasattr(self.model, "get_mm_mapping")
        )
        self.is_pooling_model = is_pooling_model(self.model)
        self.packed_modules: dict[str, list[str]] = {}
        self.modules: dict[str, BaseLayerWithLoRA] = {}
        self._last_mapping: LoRAMapping | None = None
        self._is_3d_moe_model = is_moe_model(self.model) and self.model.is_3d_moe_weight
        self._create_lora_modules()

        self.model.lora_manager = self

    def __len__(self) -> int:
        """Number of registered adapters."""
        return len(self._registered_adapters)

    @property
    def capacity(self) -> int:
        """Maximum number of registered adapters (``max_cpu_loras``)."""
        return self.lora_config.max_cpu_loras

    @property
    def lora_slots(self) -> int:
        """Number of slots for active adapters (``max_loras``)."""
        return self.lora_config.max_loras

    @property
    def adapter_slots(self) -> int:
        """Alias of ``lora_slots``."""
        return self.lora_slots

    def activate_adapter(
        self,
        lora_id: int,
    ) -> bool:
        """Move LoRA into a GPU buffer to be used in the forward pass.

        Returns:
            False if the adapter is already active, True once it has been
            assigned a slot and its weights handed to every LoRA module.

        Raises:
            ValueError: every slot is occupied.
            KeyError: ``lora_id`` is not a registered adapter.
        """
        if lora_id in self._active_adapters:
            return False
        index = next(
            (
                slot
                for slot, occupant in enumerate(self.lora_index_to_id)
                if occupant is None
            ),
            None,
        )
        if index is None:
            raise ValueError("No free lora slots")
        # Resolve the adapter first so an unknown id fails before it is
        # recorded as active.
        lora_model = self._registered_adapters[lora_id]
        self._active_adapters[lora_id] = None
        logger.debug(
            "Activating LoRA. int id: %d, slot index: %d", lora_model.id, index
        )
        self.lora_index_to_id[index] = lora_model.id
        for module_name, module in self.modules.items():
            module_lora = self._get_lora_layer_weights(lora_model, module_name)
            if not module_lora:
                module.reset_lora(index)
                continue
            # Note (gnovack) - If MOE lora weights are not split into
            # num_experts chunks, we split them here
            if isinstance(module, FusedMoE3DWithLoRA) and torch.is_tensor(
                module_lora.lora_a
            ):
                # Handle PEFT file format where experts.base_layer is the
                # gate_up_proj and experts is the down_proj
                gate_up_proj_lora = self._get_lora_layer_weights(
                    lora_model, module_name + ".base_layer"
                )
                down_proj_lora = module_lora
                # FIXME Edge case where LoRA is not added to gate_up_proj
                # or down_proj
                assert gate_up_proj_lora is not None
                assert down_proj_lora is not None
                if self._is_3d_moe_model:
                    module_lora.lora_a = [
                        gate_up_proj_lora.lora_a,
                        down_proj_lora.lora_a,
                    ]
                    module_lora.lora_b = [
                        gate_up_proj_lora.lora_b,
                        down_proj_lora.lora_b,
                    ]
                else:
                    # Some 3D MoE models haven't added the `is_3d_moe_weight`
                    # attribute yet, so fallback here
                    num_experts = module_lora.lora_a.shape[0] // module_lora.rank

                    # gate and up take their A chunks from the same tensor;
                    # their B chunks come from the even / odd rows.
                    gate_proj_a = gate_up_proj_lora.lora_a.chunk(num_experts, dim=0)
                    up_proj_a = gate_up_proj_lora.lora_a.chunk(num_experts, dim=0)

                    gate_proj_b = gate_up_proj_lora.lora_b[::2, ...].chunk(
                        num_experts, dim=-1
                    )
                    up_proj_b = gate_up_proj_lora.lora_b[1::2, ...].chunk(
                        num_experts, dim=-1
                    )

                    down_proj_a = down_proj_lora.lora_a.chunk(num_experts, dim=0)
                    down_proj_b = down_proj_lora.lora_b.chunk(num_experts, dim=-1)

                    # Per expert the order is gate, down, up.
                    lora_a = []
                    lora_b = []
                    for i in range(num_experts):
                        lora_a.extend((gate_proj_a[i], down_proj_a[i], up_proj_a[i]))
                        lora_b.extend((gate_proj_b[i], down_proj_b[i], up_proj_b[i]))

                    module_lora.lora_a = lora_a
                    module_lora.lora_b = lora_b
            module.set_lora(
                index,
                module_lora.lora_a,
                module_lora.lora_b,
            )

        return True

    def _deactivate_adapter(self, lora_id: int):
        """Free the slot held by ``lora_id``; do nothing if it holds none."""
        try:
            index = self.lora_index_to_id.index(lora_id)
            self.lora_index_to_id[index] = None
        except ValueError:
            # list.index raises when the id occupies no slot.
            pass

    def _add_adapter(self, lora: LoRAModel):
        """Merge packed-module weights in place, then register the adapter."""
        self._create_merged_loras_inplace(lora)
        self._registered_adapters[lora.id] = lora

    def pin_adapter(self, lora_id: int) -> bool:
        """Pin a LoRAModel in the manager cache."""
        raise NotImplementedError(
            "Pinning is not supported in LoRAModelManager. "
            "Use LRUCacheLoRAModelManager for pinning"
        )  # type: ignore

    def _set_adapter_mapping(self, mapping: LoRAMapping) -> None:
        """Push the mapping and current slot table to the punica wrapper."""
        # update lora states
        self.punica_wrapper.update_metadata(
            mapping,
            self.lora_index_to_id,
            self.lora_slots + 1,
            self.vocab_size,
        )

    def remove_all_adapters(self):
        """Remove all LoRAModels from the manager."""
        self._registered_adapters.clear()
        self.lora_index_to_id = [None] * self.lora_slots
        self._active_adapters.clear()

    def _create_lora_modules(self):
        """Swap every supported module of the model for its LoRA wrapper."""

        def _parent_module(module_name: str) -> str:
            # module name is a dot separated name.
            # for example:
            #  - given an input 'x.y.z' return 'x.y'
            #  - given an input 'x' return ''
            return module_name.rpartition(".")[0]

        for module_name, module in self.model.named_modules(remove_duplicate=False):
            if isinstance(module, PPMissingLayer):
                continue

            if not self._match_target_modules(module_name):
                continue
            # A temporary approach for multimodal models to support LoRA
            # TODO: Remove this restriction
            if self._filter_unsupported_mm_module(module_name):
                logger.warning(
                    "Regarding multimodal models, vLLM currently only supports "
                    "adding LoRA to language model, %s will be ignored.",
                    module_name,
                )
                continue
            parts = module_name.split(".")[-1]
            packed_moduled_lst = self.packed_modules_mapping.get(parts, [])
            if isinstance(module, FusedMoE):
                # packed_moduled_lst is used here to just determine whether to
                # instantiate FusedMoE3DWithLoRA or FusedMoEWithLoRA, and the
                # difference between these two LoRA layers is whether the
                # LoRA weights of w1 and w3 have already been fused on disk.

                packed_moduled_lst = ["w13"] if self._is_3d_moe_model else ["w1", "w3"]
            new_module = replace_submodule(
                self.model,
                module_name,
                from_layer(
                    module,
                    self.lora_slots,
                    self.lora_config,
                    packed_moduled_lst,
                    self.model.config,
                ),
            )

            # (yard1): TODO make this more robust
            if "lm_head" in module_name:
                logits_processor_module_name = "logits_processor"
                parent_module = _parent_module(module_name)
                if parent_module:
                    logits_processor_module_name = (
                        f"{parent_module}.{logits_processor_module_name}"
                    )

                logits_processor_module = self.model.get_submodule(
                    logits_processor_module_name
                )

                new_module = replace_submodule(
                    self.model,
                    logits_processor_module_name,
                    from_layer_logits_processor(
                        logits_processor_module,
                        module,
                        self.lora_slots,
                        self.lora_config,
                        self.model.config,
                    ),
                )

            # In some models, especially multimodal ones, layers with the same
            # name may have different types, such as nn.Linear and
            # ReplicatedLinear. The nn.Linear layers cannot be replaced with
            # LoRA layers, leading to assertion error. The following check
            # aims to prevent this error
            if self.supports_mm and not isinstance(new_module, BaseLayerWithLoRA):
                continue
            self.register_module(module_name, new_module)
            self._register_packed_modules(module_name)
            # All lora layers share the same punica_wrapper based on reference.
            new_module.set_mapping(self.punica_wrapper)

    def register_module(self, module_name: str, module: "BaseLayerWithLoRA"):
        """Record ``module`` under ``module_name`` in ``self.modules``."""
        # One parenthesised message so the offending type is reported too.
        assert isinstance(module, BaseLayerWithLoRA), (
            f"Module {module_name} must be a BaseLayerWithLoRA instance,"
            f" got {type(module)}"
        )
        self.modules[module_name] = module

    def create_dummy_lora(
        self,
        lora_id: int,
        rank: int,
        embedding_modules: dict[str, str] | None = None,
    ) -> LoRAModel:
        """Create zero-initialized LoRAModel for warmup.

        Args:
            lora_id: id of the dummy model.
            rank: rank used for every dummy weight.
            embedding_modules: names treated as embedding layers; must not be
                None when a matching non-packed module is encountered.
        """
        model = LoRAModel(lora_id, rank, {})
        make_dummy = LoRALayerWeights.create_dummy_lora_weights
        for module_name, module in self.model.named_modules():
            if (
                not self._match_target_modules(module_name)
                or not isinstance(module, BaseLayerWithLoRA)
                or self._filter_unsupported_mm_module(module_name)
            ):
                continue
            parts = module_name.split(".")
            if module_name not in self.packed_modules:
                assert embedding_modules is not None
                if parts[-1] in embedding_modules:
                    input_dim = (
                        module.base_layer.org_vocab_size
                        if hasattr(module.base_layer, "org_vocab_size")
                        else module.base_layer.weight.shape[1]
                    )
                    output_dim = (
                        module.base_layer.embedding_dim
                        if hasattr(module.base_layer, "embedding_dim")
                        else module.base_layer.weight.shape[0]
                    )
                    lora = make_dummy(
                        module_name,
                        input_dim,
                        output_dim,
                        rank,
                        module.lora_a_stacked[0].dtype,
                        "cpu",
                    )
                    model.loras[module_name] = lora
                elif module.__class__.__name__ == "FusedMoE3DWithLoRA":
                    # Case for 3D moe model
                    # w2
                    lora = make_dummy(
                        module_name,
                        module.w2_input_size,
                        module.w2_output_size,
                        rank * module.w2_lora_a_stacked[0].shape[1],  # rank*num_experts
                        module.w2_lora_a_stacked[0].dtype,
                        "cpu",
                    )
                    model.loras[module_name] = lora
                    # w13
                    lora = make_dummy(
                        module_name,
                        module.w13_input_size,
                        module.w13_output_size,
                        rank
                        * module.w13_lora_a_stacked[0].shape[1],  # rank*num_experts
                        module.w13_lora_a_stacked[0].dtype,
                        "cpu",
                    )
                    model.loras[module_name + ".base_layer"] = lora
                else:
                    lora = make_dummy(
                        module_name,
                        module.lora_a_stacked[0].shape[-1],
                        module.lora_b_stacked[0].shape[-2],
                        rank,
                        module.lora_a_stacked[0].dtype,
                        "cpu",
                    )
                    model.loras[module_name] = lora
            else:
                replacements = self.packed_modules_mapping[parts[-1]]
                subloras: list[LoRALayerWeights | None] = []
                for i, r in enumerate(replacements):
                    lora = make_dummy(
                        module_name + "." + r,
                        module.lora_a_stacked[i].shape[-1],
                        module.lora_b_stacked[i].shape[-2],
                        rank,
                        module.lora_a_stacked[i].dtype,
                        "cpu",
                    )
                    subloras.append(lora)
                if module.__class__.__name__ == "FusedMoEWithLoRA":
                    lora = PackedLoRALayerWeights.pack_moe(subloras, module_name)
                else:
                    lora = PackedLoRALayerWeights.pack(subloras)
                model.loras[module_name] = lora
        return model

    def _match_target_modules(self, module_name: str) -> bool:
        """Return True if ``module_name`` is, or ends in ``.<name>`` for, a
        supported LoRA module name."""
        return any(
            re.match(
                r".*\.{target_module}$".format(target_module=target_module), module_name
            )
            or target_module == module_name
            for target_module in self.supported_lora_modules
        )

    def _filter_unsupported_mm_module(self, module_name: str) -> bool:
        """
        Regarding multimodal models, vLLM currently only supports adding LoRA to
        language model. LoRA for other modules, such as the vision tower, will
        be filtered out.

        Returns True when ``module_name`` starts with one of the connector or
        tower-model prefixes of the model's multimodal mapping.
        """
        if self.supports_mm:
            module_mapping: MultiModelKeys = self.model.get_mm_mapping()
            prefix_lst = module_mapping.connector + module_mapping.tower_model
            return any(module_name.startswith(prefix) for prefix in prefix_lst)
        return False

    def _register_packed_modules(self, module_full_name: str) -> None:
        """Record the sub-module names a packed module stands for."""
        parts = module_full_name.split(".")
        module_name = parts[-1]
        replacements = self.packed_modules_mapping.get(module_name, [])
        # When replacements is less than or equal to 1, it indicates that this
        # module is not a packed module.
        if len(replacements) <= 1:
            return
        prefix = ".".join(parts[:-1])
        self.packed_modules[module_full_name] = [
            prefix + "." + r if prefix else r for r in replacements
        ]

    def _create_merged_loras_inplace(self, lora_model: LoRAModel) -> None:
        """Pack per-sub-module weights into one entry per packed module.

        Mutates ``lora_model.loras`` and finally calls ``optimize()`` on every
        remaining entry.
        """
        for module_name, new_module_names in self.packed_modules.items():
            found = [
                self._get_lora_layer_weights(lora_model, r) for r in new_module_names
            ]
            replaced_module: set[str] = {
                r for r, lora in zip(new_module_names, found) if lora
            }
            if not replaced_module:
                continue
            # Sub-modules without weights are passed to the packer as None.
            replacement_loras: list[LoRALayerWeights | None] = [
                lora if lora else None for lora in found
            ]
            # HACK Temporary solution for the pool model.
            if self.is_pooling_model and not lora_model.check_lora_name(module_name):
                replaced_module_name = module_name.replace("model.", "")
                # NOTE(review): this repeats the lookup that just failed, so
                # the rename never happens; it presumably was meant to test
                # replaced_module_name. Left unchanged pending confirmation.
                if lora_model.check_lora_name(module_name):
                    module_name = replaced_module_name
            if module_name.endswith(".experts"):
                lora_model.loras[module_name] = PackedLoRALayerWeights.pack_moe(
                    replacement_loras, module_name
                )
            else:
                lora_model.loras[module_name] = PackedLoRALayerWeights.pack(
                    replacement_loras
                )
            # Remove the modules that have been replaced.
            for module in replaced_module:
                lora_model.loras.pop(module, None)

        for lora in lora_model.loras.values():
            lora.optimize()

    def _get_lora_layer_weights(
        self, lora_model: LoRAModel, module_name: str
    ) -> LoRALayerWeights | None:
        """Look up the weights for ``module_name`` in ``lora_model``.

        For pooling models a failed lookup is retried with ``"model."``
        removed from the name.
        """
        org_module_name = module_name
        if self.is_pooling_model and not lora_model.check_lora_name(module_name):
            # If it's a pool model, and the layer name is not found,
            # remove the prefix 'model.' and search again.
            # NOTE(review): str.replace drops every "model." occurrence, not
            # only a leading one - confirm that is intended.
            module_name = module_name.replace("model.", "")
            if lora_model.check_lora_name(module_name):
                org_module_name = module_name
                logger.info_once(
                    "For the pool model, successfully loaded the LoRA weights "
                    "after removing the prefix 'model.'."
                )
        return lora_model.get_lora(org_module_name)

    def deactivate_adapter(self, adapter_id: int) -> bool:
        """Deactivate an adapter; return False if it was not active."""
        if adapter_id not in self._active_adapters:
            return False
        self._deactivate_adapter(adapter_id)
        self._active_adapters.pop(adapter_id, None)
        return True

    def add_adapter(self, adapter: LoRAModel) -> bool:
        """Register an adapter; return False if its id is already registered.

        Raises:
            RuntimeError: the manager already holds ``capacity`` adapters.
        """
        logger.debug("Adding lora. Model id: %d, int id: %d", adapter.id, adapter.id)
        if adapter.id in self._registered_adapters:
            return False
        if len(self._registered_adapters) >= self.capacity:
            raise RuntimeError("No free adapter slots.")
        self._add_adapter(adapter)
        return True

    def set_adapter_mapping(self, mapping: LoRAMapping) -> None:
        """Apply ``mapping`` unless it equals the last one applied."""
        if self._last_mapping != mapping:
            self._set_adapter_mapping(mapping)
            self._last_mapping = mapping

    def remove_adapter(self, adapter_id: int) -> bool:
        """Deactivate and unregister an adapter; False if not registered."""
        self.deactivate_adapter(adapter_id)
        if adapter_id not in self._registered_adapters:
            return False
        self._registered_adapters.pop(adapter_id, None)
        return True

    def list_adapters(self) -> dict[int, LoRAModel]:
        """Return a new dict of the registered adapters keyed by id."""
        return dict(self._registered_adapters)

    def get_adapter(self, adapter_id: int) -> LoRAModel | None:
        """Return the registered adapter with this id, or None."""
        return self._registered_adapters.get(adapter_id)
794
795
796


class LoRALRUCache(AdapterLRUCache[LoRAModel]):
    """``AdapterLRUCache`` specialised to hold ``LoRAModel`` values."""

    def __init__(self, capacity: int, deactivate_lora_fn: Callable[[int], bool]):
        """Create the cache.

        Args:
            capacity: Capacity forwarded to ``AdapterLRUCache``.
            deactivate_lora_fn: Callback taking an adapter id; forwarded to
                ``AdapterLRUCache`` as its ``deactivate_fn``.
        """
        super().__init__(capacity, deactivate_lora_fn)
799
800
801
802
803


class LRUCacheLoRAModelManager(LoRAModelManager):
    """A model manager that manages multiple LoRAs with LRU cache."""

    def __init__(
        self,
        model: nn.Module,
        max_num_seqs: int,
        max_num_batched_tokens: int,
        vocab_size: int,
        lora_config: LoRAConfig,
        device: torch.device,
    ):
        """Initialise the base manager, then swap in LRU-backed containers.

        All arguments are forwarded unchanged to ``LoRAModelManager``.
        """
        super().__init__(
            model, max_num_seqs, max_num_batched_tokens, vocab_size, lora_config, device
        )
        # Registered adapters: bounded by ``self.capacity``; removal from
        # this cache goes through the public ``deactivate_adapter``.
        self._registered_adapters: LoRALRUCache = LoRALRUCache(
            self.capacity, self.deactivate_adapter
        )
        # Active adapters: bounded by ``self.lora_slots``; removal from this
        # cache calls ``_deactivate_adapter`` directly.
        self._active_adapters: LoRALRUCache = LoRALRUCache(
            self.lora_slots, self._deactivate_adapter
        )

    def list_adapters(self) -> dict[int, LoRAModel]:
        """List all registered LoRAModels."""
        return dict(self._registered_adapters.cache)

    def add_adapter(self, lora: LoRAModel) -> bool:
        """Add a LoRAModel to the manager.

        Returns:
            True if the adapter was newly added, False if it was already
            registered (in which case only its LRU position is refreshed).
        """
        logger.debug("Adding lora. Model id: %d, int id: %d", lora.id, lora.id)
        if lora.id not in self._registered_adapters:
            self._add_adapter(lora)
            was_added = True
        else:
            # We always touch to update the LRU cache order
            self._registered_adapters.touch(lora.id)
            was_added = False
        return was_added

    def activate_adapter(
        self,
        lora_id: int,
    ) -> bool:
        """Activate an adapter, evicting the oldest active one if full.

        Returns:
            Whatever the base class ``activate_adapter`` returns.
        """
        if (
            lora_id not in self._active_adapters
            and len(self._active_adapters) >= self.lora_slots
        ):
            # Make room before delegating to the base implementation.
            self._active_adapters.remove_oldest()
        result = super().activate_adapter(lora_id)
        # We always touch to update the LRU cache order
        self._active_adapters.touch(lora_id)
        return result

    def remove_oldest_adapter(self) -> bool:
        """Remove the oldest registered adapter.

        Returns:
            True if an adapter was removed, False if none were registered.
        """
        if len(self._registered_adapters) > 0:
            self._registered_adapters.remove_oldest()
            return True
        return False

    def pin_adapter(self, lora_id: int) -> bool:
        """Pin a LoRAModel in the manager cache."""
        self._pin_lora_in_cpu_cache(lora_id)
        self._pin_lora_in_gpu_cache(lora_id)
        return True

    def _pin_lora_in_cpu_cache(self, lora_id: int):
        """Pin ``lora_id`` in the registered-adapter cache.

        Raises:
            ValueError: If pinning fails with ``ValueError``; re-raised with
                a message saying the LoRA is not registered.
        """
        try:
            self._registered_adapters.pin(lora_id)
        except ValueError as err:
            raise ValueError(
                f"Pinning failed. LoRA {lora_id} is not registered."
            ) from err

    def _pin_lora_in_gpu_cache(self, lora_id: int):
        """Pin ``lora_id`` in the active-adapter cache, activating it first
        if it is not already active."""
        if lora_id not in self._active_adapters:
            # move lora to gpu if not already active
            self.activate_adapter(lora_id)

        self._active_adapters.pin(lora_id)
879

880
881

def create_lora_manager(
    model: nn.Module,
    max_num_seqs: int,
    max_num_batched_tokens: int,
    vocab_size: int,
    lora_config: LoRAConfig,
    device: torch.device,
    lora_manager_cls: type[LoRAModelManager] = LoRAModelManager,
    **kwargs,
) -> LoRAModelManager:
    """Create a LoRA model manager for a given model.

    Args:
        model: The model to manage; must satisfy ``SupportsLoRA``.
        max_num_seqs: Forwarded to the manager constructor.
        max_num_batched_tokens: Forwarded to the manager constructor.
        vocab_size: Forwarded to the manager constructor.
        lora_config: Forwarded to the manager constructor.
        device: Forwarded to the manager constructor.
        lora_manager_cls: Manager class to instantiate.
        **kwargs: Extra keyword arguments forwarded to ``lora_manager_cls``.

    Returns:
        The newly constructed manager instance.

    Raises:
        ValueError: If ``model`` is not an instance of ``SupportsLoRA``.
    """
    if not isinstance(model, SupportsLoRA):
        raise ValueError(f"Model {type(model)} is not supported for LoRA.")
    lora_manager = lora_manager_cls(
        model=model,
        max_num_seqs=max_num_seqs,
        max_num_batched_tokens=max_num_batched_tokens,
        vocab_size=vocab_size,
        lora_config=lora_config,
        device=device,
        **kwargs,
    )
    return lora_manager