models.py 35.8 KB
Newer Older
1
# SPDX-License-Identifier: Apache-2.0
2
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
3

4
5
import math
import os
6
7
from collections.abc import Callable
from typing import TypeVar
8

9
import regex as re
10
11
12
13
import safetensors.torch
import torch
from torch import nn

14
from vllm.config.lora import LoRAConfig
15
from vllm.logger import init_logger
16
from vllm.lora.layers import BaseLayerWithLoRA, FusedMoE3DWithLoRA, LoRAMapping
17
from vllm.lora.lora_weights import LoRALayerWeights, PackedLoRALayerWeights
18
from vllm.lora.peft_helper import PEFTHelper
19
from vllm.lora.punica_wrapper import get_punica_wrapper
20
21
22
23
from vllm.lora.utils import (
    from_layer,
    from_layer_logits_processor,
    get_supported_lora_modules,
24
    is_base_embeddding_weights,
25
    is_moe_model,
26
27
    is_regex_target_modules,
    parse_fine_tuned_lora_name,
28
    process_packed_modules_mapping,
29
30
    replace_submodule,
)
31
from vllm.model_executor.layers.fused_moe import FusedMoE
32
from vllm.model_executor.model_loader.tensorizer import TensorizerConfig
33
from vllm.model_executor.models import SupportsLoRA, supports_multimodal
34
from vllm.model_executor.models.interfaces import is_pooling_model
35
from vllm.model_executor.models.module_mapping import MultiModelKeys
36
from vllm.model_executor.models.utils import PPMissingLayer, WeightsMapper
37
from vllm.utils.cache import LRUCache
38
from vllm.utils.platform_utils import is_pin_memory_available
39

40
# Module-level logger used by the adapter cache and LoRA manager below.
logger = init_logger(__name__)
41

42
43
44
45
46
47
48
49
T = TypeVar("T")


class AdapterLRUCache(LRUCache[int, T]):
    """LRU cache of adapters keyed by their integer id.

    Every time an entry is removed, ``deactivate_fn`` is called with the
    removed key before the base-class removal hook runs.
    """

    def __init__(self, capacity: int, deactivate_fn: Callable[[int], object]):
        """
        Args:
            capacity: Forwarded unchanged to the ``LRUCache`` constructor.
            deactivate_fn: Callback invoked with the adapter id on removal.
        """
        super().__init__(capacity)
        self.deactivate_fn = deactivate_fn

    def _on_remove(self, key: int, value: T | None):
        """Log the removal, deactivate the adapter, then run the base hook."""
        logger.debug("Removing adapter int id: %d", key)
        self.deactivate_fn(key)
        return super()._on_remove(key, value)


56
57
58
59
60
61
62
63
64
_GLOBAL_LORA_ID = 0


def get_lora_id():
    """Return the next value of the module-wide LoRA id counter.

    The counter starts at 0 and is incremented before being returned, so the
    first id handed out is 1.
    """
    global _GLOBAL_LORA_ID
    next_id = _GLOBAL_LORA_ID + 1
    _GLOBAL_LORA_ID = next_id
    return next_id


65
class LoRAModel:
    """A LoRA fine-tuned model."""

    def __init__(
        self,
        lora_model_id: int,
        rank: int,
        loras: dict[str, LoRALayerWeights],
    ) -> None:
        """
        Args:
            lora_model_id: The integer id for the lora model.
            rank: lora rank.
            loras: module name -> weights for lora-replaced layers.

        """
        self.id = lora_model_id

        assert lora_model_id > 0, (
            f"a valid lora id should be greater than 0, got {self.id}"
        )
        self.rank = rank
        self.loras: dict[str, LoRALayerWeights] = loras

    def clone(self, lora_model_id: int) -> "LoRAModel":
        """Return a copy of the object with different ids.

        Will share the underlying tensors: only the name -> weights dict is
        copied (shallowly), not the weight objects it refers to.
        """
        return self.__class__(
            lora_model_id,
            rank=self.rank,
            loras=self.loras.copy(),
        )

    def get_lora(self, module_name: str) -> LoRALayerWeights | None:
        """Get LoRA for a given module by name, or None if there is none."""
        return self.loras.get(module_name, None)

    def check_lora_name(self, lora_name: str) -> bool:
        """Return whether weights are stored under ``lora_name``."""
        return lora_name in self.loras

    # (yard1): TODO see if we can derive target_embedding_padding automatically
    @classmethod
    def from_lora_tensors(
        cls,
        lora_model_id: int,
        tensors: dict[str, torch.Tensor],
        peft_helper: PEFTHelper,
        device: str = "cuda",
        dtype: torch.dtype | None = None,
        target_embedding_padding: int | None = None,
        embedding_modules: dict[str, str] | None = None,
        embedding_padding_modules: list[str] | None = None,
        weights_mapper: WeightsMapper | None = None,
    ) -> "LoRAModel":
        """Create a LoRAModel from a dictionary of tensors.

        Args:
            lora_model_id: Id given to the resulting model.
            tensors: Checkpoint tensor name -> tensor. Names recognised by
                ``is_base_embeddding_weights`` are skipped.
            peft_helper: Loaded lora configuration; supplies the rank and the
                per-layer configuration.
            device: Device every tensor is moved to.
            dtype: dtype every tensor is converted to (None keeps it).
            target_embedding_padding: If not None, ``lora_b`` tensors of
                modules matching ``embedding_padding_modules`` are padded at
                the end of their second-to-last dimension up to this size.
            embedding_modules: Not read by this method.
            embedding_padding_modules: Substrings identifying the modules
                whose ``lora_b`` is padded. Must not be None whenever a
                ``lora_b`` tensor is encountered.
            weights_mapper: Forwarded to ``parse_fine_tuned_lora_name``.

        Returns:
            The assembled LoRAModel.
        """
        # Pinning is only attempted for tensors that stay on the CPU.
        pin_memory = str(device) == "cpu" and is_pin_memory_available()
        loras: dict[str, LoRALayerWeights] = {}
        for tensor_name, tensor in tensors.items():
            if is_base_embeddding_weights(tensor_name):
                continue
            module_name, is_lora_a = parse_fine_tuned_lora_name(
                tensor_name, weights_mapper
            )
            if module_name not in loras:
                loras[module_name] = LoRALayerWeights.from_config(
                    module_name, peft_helper
                )
            layer = loras[module_name]

            weight = tensor.to(device=device, dtype=dtype)
            if not is_lora_a:
                assert embedding_padding_modules is not None
                if target_embedding_padding is not None and any(
                    name in module_name for name in embedding_padding_modules
                ):
                    assert target_embedding_padding >= weight.shape[0]
                    addition = target_embedding_padding - weight.shape[0]
                    weight = torch.nn.functional.pad(weight, (0, 0, 0, addition))
            if pin_memory:
                weight = weight.pin_memory()

            if is_lora_a:
                layer.lora_a = weight
            else:
                layer.lora_b = weight

        return cls(lora_model_id, peft_helper.r, loras)

    @classmethod
    def from_local_checkpoint(
        cls,
        lora_dir: str,
        expected_lora_modules: set[str],
        peft_helper: PEFTHelper,
        *,
        lora_model_id: int | None = None,
        device: str = "cuda",
        dtype: torch.dtype | None = None,
        target_embedding_padding: int | None = None,
        embedding_modules: dict[str, str] | None = None,
        embedding_padding_modules: list[str] | None = None,
        weights_mapper: WeightsMapper | None = None,
        tensorizer_config_dict: dict | None = None,
    ) -> "LoRAModel":
        """Create a LoRAModel from a local checkpoint.

        The weights are looked up in this order: a tensorizer archive (when
        ``tensorizer_config_dict`` is given), ``adapter_model.safetensors``,
        then ``adapter_model.bin`` / ``adapter_model.pt``.

        Args:
            lora_dir: The local path that has lora data.
            expected_lora_modules: Name of modules that are expected to be
                replaced by lora.
            peft_helper: Loaded lora configuration information.
            lora_model_id: LoRA model id. If not given, automatically set by
                a global counter.
            device: Device where the lora model is loaded.
            dtype: dtype of the lora model weights.
            target_embedding_padding: Forwarded to ``from_lora_tensors``.
            embedding_modules: Forwarded to ``from_lora_tensors``.
            embedding_padding_modules: Forwarded to ``from_lora_tensors``.
            weights_mapper: Used when parsing tensor names, and forwarded to
                ``from_lora_tensors``.
            tensorizer_config_dict: Keyword arguments for ``TensorizerConfig``.
                When truthy, tensors are read from ``adapter_model.tensors``
                inside the configured ``tensorizer_dir``.

        Returns:
            Loaded LoRA Model.

        Raises:
            ValueError: If the checkpoint targets modules outside
                ``expected_lora_modules``, or if no tensor file is found.
        """
        lora_tensor_path = os.path.join(lora_dir, "adapter_model.safetensors")
        lora_bin_file_path = os.path.join(lora_dir, "adapter_model.bin")
        lora_pt_file_path = os.path.join(lora_dir, "adapter_model.pt")

        tensors: dict[str, torch.Tensor] = {}

        def unexpected_modules_error(unexpected: list) -> ValueError:
            # Single place that builds the error shared by all loading paths.
            return ValueError(
                f"While loading {lora_dir}, expected"
                f" target modules in {expected_lora_modules}"
                f" but received {unexpected}."
                f" Please verify that the loaded LoRA module is correct"
            )

        def check_unexpected_modules(modules) -> None:
            # `modules` is any object exposing `.keys()` with tensor names.
            unexpected: list[str] = []
            for lora_module in modules.keys():  # noqa
                if is_base_embeddding_weights(lora_module):
                    continue
                # Handle PEFT file format where experts.base_layer is the
                # gate_up_proj and experts is the down_proj
                if "base_layer" in lora_module:
                    continue
                module_name, _ = parse_fine_tuned_lora_name(lora_module, weights_mapper)
                # Case for expert lora weights
                if ".experts" in module_name:
                    expert_idx = module_name.find(".experts")
                    expert_suffix = module_name[expert_idx + 1 :]
                    if expert_suffix not in expected_lora_modules:
                        unexpected.append(module_name)
                elif module_name.rsplit(".", 1)[-1] not in expected_lora_modules:
                    unexpected.append(module_name)

            if unexpected:
                raise unexpected_modules_error(unexpected)

        if tensorizer_config_dict:
            from tensorizer import TensorDeserializer

            tensorizer_config = TensorizerConfig(**tensorizer_config_dict)
            lora_tensor_path = os.path.join(
                tensorizer_config.tensorizer_dir, "adapter_model.tensors"
            )
            tensorizer_args = tensorizer_config._construct_tensorizer_args()
            tensors = TensorDeserializer(
                lora_tensor_path,
                dtype=tensorizer_config.dtype,
                **tensorizer_args.deserialization_kwargs,
            )
            check_unexpected_modules(tensors)

        elif os.path.isfile(lora_tensor_path):
            # Find unexpected modules.
            # Use safetensor key as a source of truth to find expected modules.
            # in peft if you have target_modules A, B, C and C does not exist
            # in the model it won't error and model will be trained with A, B
            # loraified. C won't exist in the safetensor but it will exist in
            # the target_modules of the adapter_config.json.
            with safetensors.safe_open(lora_tensor_path, framework="pt") as f:  # type: ignore
                # Load tensors if there are only expected modules.
                check_unexpected_modules(f)
                for module in f.keys():  # noqa
                    tensors[module] = f.get_tensor(module)
        elif os.path.isfile(lora_bin_file_path) or os.path.isfile(lora_pt_file_path):
            # When a bin/pt file is provided, we rely on config to find
            # unexpected modules.
            target_modules = peft_helper.target_modules
            if not isinstance(target_modules, list):
                target_modules = [target_modules]
            # Compatible with more modules,
            # such as:layers.11.self_attn.k_proj
            unexpected_modules = [
                module
                for module in target_modules
                if module.split(".")[-1] not in expected_lora_modules
            ]
            # loaded lora's target modules must be a subset of
            # expected_lora_modules. It is not reliable. See
            # https://github.com/vllm-project/vllm/pull/5909. But there's no
            # other better mechanism.
            if unexpected_modules and not is_regex_target_modules(
                peft_helper.target_modules, expected_lora_modules
            ):
                raise unexpected_modules_error(unexpected_modules)
            lora_file_path = (
                lora_bin_file_path
                if os.path.isfile(lora_bin_file_path)
                else lora_pt_file_path
            )
            tensors = torch.load(lora_file_path, map_location=device, weights_only=True)
        else:
            raise ValueError(f"{lora_dir} doesn't contain tensors")

        return cls.from_lora_tensors(
            lora_model_id=get_lora_id() if lora_model_id is None else lora_model_id,
            tensors=tensors,
            peft_helper=peft_helper,
            device=device,
            dtype=dtype,
            target_embedding_padding=target_embedding_padding,
            embedding_modules=embedding_modules,
            embedding_padding_modules=embedding_padding_modules,
            weights_mapper=weights_mapper,
        )
295
296


297
class LoRAModelManager:
298
299
300
301
    """A manager that manages multiple LoRA-fine-tuned models."""

    def __init__(
        self,
        model: SupportsLoRA,
        max_num_seqs: int,
        max_num_batched_tokens: int,
        vocab_size: int,
        lora_config: LoRAConfig,
        device: torch.device,
    ):
        """Create a LoRAModelManager and adapter for a given model.

        Args:
            model: the model to be adapted.
            max_num_seqs: the maximum number of sequences model can run in a
                single batch.
            max_num_batched_tokens: the maximum number of tokens model can run
                in a single batch. Stored rounded up to a multiple of 8; the
                punica wrapper receives the value as passed in.
            vocab_size: the vocab size of the model.
            lora_config: the LoRA configuration.
            device: the device handed to the punica wrapper.
        """
        self.model: SupportsLoRA = model
        self._registered_adapters: dict[int, LoRAModel] = {}
        # Dict instead of a set for compatibility with LRUCache.
        self._active_adapters: dict[int, None] = {}
        self.adapter_type = "LoRA"
        self.lora_config = lora_config
        self.device = device
        self.max_num_seqs = max_num_seqs
        assert self.capacity >= self.lora_slots, (
            f"max_cpu_loras ({self.capacity}) must be >= "
            f"max_loras ({self.lora_slots})"
        )
        self.max_num_batched_tokens = math.ceil(max_num_batched_tokens / 8) * 8
        self.lora_index_to_id: list[int | None] = [None] * self.lora_slots
        self.vocab_size = vocab_size
        self.punica_wrapper = get_punica_wrapper(
            max_num_batched_tokens,
            max_batches=self.max_num_seqs,
            device=self.device,
            max_loras=self.lora_config.max_loras,
        )

        self.supported_lora_modules = get_supported_lora_modules(self.model)
        # Both message parts must live inside the parentheses; otherwise the
        # second literal is a separate no-op statement and is never shown.
        assert self.supported_lora_modules, (
            "No supported LoRA modules found in"
            f" {self.model.__class__.__name__}."
        )

        self.packed_modules_mapping = process_packed_modules_mapping(self.model)
        # Used to indicate whether the model is a multimodal model
        self.supports_mm: bool = (
            supports_multimodal(self.model)
            # In case the model only supports LoRA for
            # text modules (e.g. ChatGLM)
            and hasattr(self.model, "get_mm_mapping")
        )
        self.is_pooling_model = is_pooling_model(self.model)
        self.packed_modules: dict[str, list[str]] = {}
        self.modules: dict[str, BaseLayerWithLoRA] = {}
        self._last_mapping: LoRAMapping | None = None
        self._is_3d_moe_model = is_moe_model(self.model) and self.model.is_3d_moe_weight
        # Must run after every attribute above is set: module creation reads
        # punica_wrapper, packed_modules_mapping, supports_mm and friends.
        self._create_lora_modules()

        self.model.lora_manager = self
360
361
362

    def __len__(self) -> int:
        return len(self._registered_adapters)
363
364
365
366
367
368
369
370
371

    @property
    def capacity(self) -> int:
        return self.lora_config.max_cpu_loras

    @property
    def lora_slots(self) -> int:
        return self.lora_config.max_loras

372
373
374
    @property
    def adapter_slots(self) -> int:
        return self.lora_slots
375

376
    def activate_adapter(
377
378
379
380
        self,
        lora_id: int,
    ) -> bool:
        """Move LoRA into a GPU buffer to be used in the forward pass."""
381
        if lora_id in self._active_adapters:
382
383
            return False
        first_free_slot = next(
384
385
386
387
388
389
390
            (
                (i, lora_id)
                for i, lora_id in enumerate(self.lora_index_to_id)
                if lora_id is None
            ),
            None,
        )
391
392
393
        if first_free_slot is None:
            raise ValueError("No free lora slots")
        index, _ = first_free_slot
394
395
        self._active_adapters[lora_id] = None
        lora_model = self._registered_adapters[lora_id]
396
397
398
        logger.debug(
            "Activating LoRA. int id: %d, slot index: %d", lora_model.id, index
        )
399
400
        self.lora_index_to_id[index] = lora_model.id
        for module_name, module in self.modules.items():
401
            module_lora = self._get_lora_layer_weights(lora_model, module_name)
402
403
404
405
406
            if not module_lora:
                module.reset_lora(index)
                continue
            # Note (gnovack) - If MOE lora weights are not split into
            # num_experts chunks, we split them here
407
            if isinstance(module, FusedMoE3DWithLoRA) and torch.is_tensor(
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
                module_lora.lora_a
            ):
                # Handle PEFT file format where experts.base_layer is the
                # gate_up_proj and experts is the down_proj
                gate_up_proj_lora = self._get_lora_layer_weights(
                    lora_model, module_name + ".base_layer"
                )
                down_proj_lora = module_lora
                # FIXME Edge case where LoRA is not added to gate_up_proj
                # or down_proj
                assert gate_up_proj_lora is not None
                assert down_proj_lora is not None
                if self._is_3d_moe_model:
                    module_lora.lora_a = [
                        gate_up_proj_lora.lora_a,
                        down_proj_lora.lora_a,
                    ]
                    module_lora.lora_b = [
                        gate_up_proj_lora.lora_b,
                        down_proj_lora.lora_b,
                    ]
                else:
                    # Some 3D MoE models haven't added the `is_3d_moe_weight`
                    # attribute yet, so fallback here
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
                    num_experts = module_lora.lora_a.shape[0] // module_lora.rank

                    gate_proj_a = gate_up_proj_lora.lora_a.chunk(num_experts, dim=0)
                    up_proj_a = gate_up_proj_lora.lora_a.chunk(num_experts, dim=0)

                    gate_proj_b = gate_up_proj_lora.lora_b[::2, ...].chunk(
                        num_experts, dim=-1
                    )
                    up_proj_b = gate_up_proj_lora.lora_b[1::2, ...].chunk(
                        num_experts, dim=-1
                    )

                    down_proj_a = down_proj_lora.lora_a.chunk(num_experts, dim=0)
                    down_proj_b = down_proj_lora.lora_b.chunk(num_experts, dim=-1)

                    lora_a = []
                    lora_b = []
                    for i in range(num_experts):
                        lora_a.append(gate_proj_a[i])
                        lora_a.append(down_proj_a[i])
                        lora_a.append(up_proj_a[i])

                        lora_b.append(gate_proj_b[i])
                        lora_b.append(down_proj_b[i])
                        lora_b.append(up_proj_b[i])

                    module_lora.lora_a = lora_a
                    module_lora.lora_b = lora_b
460
461
462
463
464
            module.set_lora(
                index,
                module_lora.lora_a,
                module_lora.lora_b,
            )
465

466
467
        return True

468
    def _deactivate_adapter(self, lora_id: int):
469
470
471
472
473
474
        try:
            index = self.lora_index_to_id.index(lora_id)
            self.lora_index_to_id[index] = None
        except ValueError:
            pass

475
    def _add_adapter(self, lora: LoRAModel):
476
        self._create_merged_loras_inplace(lora)
477
        self._registered_adapters[lora.id] = lora
478

479
    def pin_adapter(self, lora_id: int) -> bool:
480
481
        """Pin a LoRAModel in the manager cache."""
        raise NotImplementedError(
482
            "Pinning is not supported in LoRAModelManager. "
483
484
            "Use LRUCacheLoRAModelManager for pinning"
        )  # type: ignore
485

486
    def _set_adapter_mapping(self, mapping: LoRAMapping) -> None:
487
488
489
490
491
492
493
        # update lora states
        self.punica_wrapper.update_metadata(
            mapping,
            self.lora_index_to_id,
            self.lora_slots + 1,
            self.vocab_size,
        )
494

495
    def remove_all_adapters(self):
496
        """Remove all LoRAModels from the manager."""
497
        self._registered_adapters.clear()
498
        self.lora_index_to_id = [None] * self.lora_slots
499
        self._active_adapters.clear()
500
501

    def _create_lora_modules(self):
        """Swap each supported target module of the model for its LoRA layer.

        Walks ``self.model.named_modules``, replaces every matching submodule
        with the result of ``from_layer``, registers it in ``self.modules``,
        records packed modules, and attaches the shared punica wrapper.
        """

        def _parent_module(module_name: str) -> str:
            # module name is a dot separated name.
            # for example:
            #  - given an input 'x.y.z' return 'x.y'
            #  - given an input 'x' return ''
            return module_name.rpartition(".")[0]

        for module_name, module in self.model.named_modules(remove_duplicate=False):
            if isinstance(module, PPMissingLayer):
                continue

            if not self._match_target_modules(module_name):
                continue
            # A temporary approach for multimodal models to support LoRA
            # TODO: Remove this restriction
            if self._filter_unsupported_mm_module(module_name):
                logger.warning(
                    "Regarding multimodal models, vLLM currently only supports "
                    "adding LoRA to language model, %s will be ignored.",
                    module_name,
                )
                continue
            leaf_name = module_name.split(".")[-1]
            packed_moduled_lst = self.packed_modules_mapping.get(leaf_name, [])
            if isinstance(module, FusedMoE):
                # packed_moduled_lst is used here to just determine whether to
                # instantiate FusedMoE3DWithLoRA or FusedMoEWithLoRA, and the
                # difference between these two LoRA layers is whether the
                # LoRA weights of w1 and w3 have already been fused on disk.

                packed_moduled_lst = ["w13"] if self._is_3d_moe_model else ["w1", "w3"]
            new_module = replace_submodule(
                self.model,
                module_name,
                from_layer(
                    module,
                    self.lora_slots,
                    self.lora_config,
                    packed_moduled_lst,
                    self.model.config,
                ),
            )

            # (yard1): TODO make this more robust
            if "lm_head" in module_name:
                # The logits processor is looked up as a sibling of lm_head.
                logits_processor_module_name = "logits_processor"
                parent_module = _parent_module(module_name)
                if parent_module:
                    logits_processor_module_name = (
                        f"{parent_module}.{logits_processor_module_name}"
                    )

                logits_processor_module = self.model.get_submodule(
                    logits_processor_module_name
                )

                # From here on `new_module` is the replaced logits processor;
                # it is what gets registered under the lm_head module name.
                new_module = replace_submodule(
                    self.model,
                    logits_processor_module_name,
                    from_layer_logits_processor(
                        logits_processor_module,
                        module,
                        self.lora_slots,
                        self.lora_config,
                        self.model.config,
                    ),
                )

            # In some models, especially multimodal ones, layers with the same
            # name may have different types, such as nn.Linear and
            # ReplicatedLinear. The nn.Linear layers cannot be replaced with
            # LoRA layers, leading to assertion error. The following check
            # aims to prevent this error
            if self.supports_mm and not isinstance(new_module, BaseLayerWithLoRA):
                continue
            self.register_module(module_name, new_module)
            self._register_packed_modules(module_name)
            # All lora layers share the same punica_wrapper based on reference.
            new_module.set_mapping(self.punica_wrapper)
582
583

    def register_module(self, module_name: str, module: "BaseLayerWithLoRA"):
        """Record ``module`` under ``module_name`` in ``self.modules``.

        Args:
            module_name: Key used in ``self.modules``.
            module: The LoRA layer; must be a ``BaseLayerWithLoRA`` instance.
        """
        # Both message parts must be inside the parentheses; a literal placed
        # after them is a separate no-op statement and is never reported.
        assert isinstance(module, BaseLayerWithLoRA), (
            f"Module {module_name} must be a BaseLayerWithLoRA instance,"
            f" got {type(module)}"
        )
        self.modules[module_name] = module

Terry's avatar
Terry committed
590
    def create_dummy_lora(
        self,
        lora_id: int,
        rank: int,
        embedding_modules: dict[str, str] | None = None,
    ) -> LoRAModel:
        """Create zero-initialized LoRAModel for warmup.

        Args:
            lora_id: Id given to the dummy model.
            rank: LoRA rank of the dummy weights.
            embedding_modules: Names of embedding modules; looked up by the
                last component of each module name. Must not be None when
                the model contains a non-packed LoRA module.

        Returns:
            A LoRAModel holding dummy weights, created on ``"cpu"``, for every
            supported LoRA module of the model.
        """
        model = LoRAModel(lora_id, rank, {})
        for module_name, module in self.model.named_modules():
            if (
                not self._match_target_modules(module_name)
                or not isinstance(module, BaseLayerWithLoRA)
                or self._filter_unsupported_mm_module(module_name)
            ):
                continue
            leaf_name = module_name.split(".")[-1]
            if module_name not in self.packed_modules:
                assert embedding_modules is not None
                if leaf_name in embedding_modules:
                    base_layer = module.base_layer
                    input_dim = (
                        base_layer.org_vocab_size
                        if hasattr(base_layer, "org_vocab_size")
                        else base_layer.weight.shape[1]
                    )
                    output_dim = (
                        base_layer.embedding_dim
                        if hasattr(base_layer, "embedding_dim")
                        else base_layer.weight.shape[0]
                    )
                    lora = LoRALayerWeights.create_dummy_lora_weights(
                        module_name,
                        input_dim,
                        output_dim,
                        rank,
                        module.lora_a_stacked[0].dtype,
                        "cpu",
                    )
                    model.loras[module_name] = lora
                elif module.__class__.__name__ == "FusedMoE3DWithLoRA":
                    # Case for 3D moe model
                    # w2
                    lora = LoRALayerWeights.create_dummy_lora_weights(
                        module_name,
                        module.w2_input_size,
                        module.w2_output_size,
                        rank * module.w2_lora_a_stacked[0].shape[1],  # rank*num_experts
                        module.w2_lora_a_stacked[0].dtype,
                        "cpu",
                    )
                    model.loras[module_name] = lora
                    # w13
                    lora = LoRALayerWeights.create_dummy_lora_weights(
                        module_name,
                        module.w13_input_size,
                        module.w13_output_size,
                        rank
                        * module.w13_lora_a_stacked[0].shape[1],  # rank*num_experts
                        module.w13_lora_a_stacked[0].dtype,
                        "cpu",
                    )
                    model.loras[module_name + ".base_layer"] = lora
                else:
                    lora = LoRALayerWeights.create_dummy_lora_weights(
                        module_name,
                        module.lora_a_stacked[0].shape[-1],
                        module.lora_b_stacked[0].shape[-2],
                        rank,
                        module.lora_a_stacked[0].dtype,
                        "cpu",
                    )
                    model.loras[module_name] = lora
            else:
                # Packed module: one dummy weight per replacement, then pack.
                replacements = self.packed_modules_mapping[leaf_name]
                subloras: list[LoRALayerWeights | None] = []
                for i, r in enumerate(replacements):
                    lora = LoRALayerWeights.create_dummy_lora_weights(
                        module_name + "." + r,
                        module.lora_a_stacked[i].shape[-1],
                        module.lora_b_stacked[i].shape[-2],
                        rank,
                        module.lora_a_stacked[i].dtype,
                        "cpu",
                    )
                    subloras.append(lora)
                if module.__class__.__name__ == "FusedMoEWithLoRA":
                    lora = PackedLoRALayerWeights.pack_moe(subloras, module_name)
                else:
                    lora = PackedLoRALayerWeights.pack(subloras)
                model.loras[module_name] = lora
        return model

    def _match_target_modules(self, module_name: str):
        return any(
            re.match(
685
686
687
688
689
                r".*\.{target_module}$".format(target_module=target_module), module_name
            )
            or target_module == module_name
            for target_module in self.supported_lora_modules
        )
690

691
692
693
    def _filter_unsupported_mm_module(self, module_name: str) -> bool:
        """
        Regarding multimodal models, vLLM currently only supports adding LoRA to
694
        language model. LoRA for other modules, such as the vision tower, will
695
696
697
698
        be filtered out.
        """
        if self.supports_mm:
            module_mapping: MultiModelKeys = self.model.get_mm_mapping()
699
            prefix_lst = module_mapping.connector + module_mapping.tower_model
700
            return any([module_name.startswith(prefix) for prefix in prefix_lst])
701
702
        return False

703
704
705
    def _register_packed_modules(self, module_full_name: str) -> None:
        parts = module_full_name.split(".")
        module_name = parts[-1]
706
707
708
709
        replacements = self.packed_modules_mapping.get(module_name, [])
        # When replacements is less than or equal to 1, it indicates that this
        # module is not a packed module.
        if len(replacements) <= 1:
710
711
712
713
714
715
716
717
            return
        prefix = ".".join(parts[:-1])
        self.packed_modules[module_full_name] = [
            prefix + "." + r if prefix else r for r in replacements
        ]

    def _create_merged_loras_inplace(self, lora_model: LoRAModel) -> None:
        """Merge sub-module LoRA weights into packed-module entries, in place.

        For each packed module in ``self.packed_modules`` the LoRA weights of
        its sub-modules are looked up in ``lora_model``. If at least one is
        present, they are packed into a single entry stored in
        ``lora_model.loras`` under the packed module's name, and the entries
        of the sub-modules that were found are removed. Afterwards
        ``optimize()`` is called on every entry of ``lora_model.loras``.

        Args:
            lora_model: The LoRA model whose ``loras`` mapping is rewritten.
        """
        for module_name, new_module_names in self.packed_modules.items():
            replacement_loras: list[LoRALayerWeights | None] = []
            replaced_module: set[str] = set()
            for r in new_module_names:
                lora = self._get_lora_layer_weights(lora_model, r)
                if lora:
                    replacement_loras.append(lora)
                    replaced_module.add(r)
                else:
                    # One slot per sub-module; missing weights are stored
                    # as None.
                    replacement_loras.append(None)
            if not replaced_module:
                # None of the sub-modules has LoRA weights: nothing to pack.
                continue
            # HACK Temporary solution for the pool model.
            if self.is_pooling_model and not lora_model.check_lora_name(module_name):
                replaced_module_name = module_name.replace("model.", "")
                # NOTE(review): this re-checks `module_name`, which the
                # enclosing condition has just found missing, so unless
                # `check_lora_name` is stateful the rename never happens.
                # Checking `replaced_module_name` may have been intended;
                # confirm before changing.
                if lora_model.check_lora_name(module_name):
                    module_name = replaced_module_name
            if module_name.endswith(".experts"):
                lora_model.loras[module_name] = PackedLoRALayerWeights.pack_moe(
                    replacement_loras, module_name
                )
            else:
                lora_model.loras[module_name] = PackedLoRALayerWeights.pack(
                    replacement_loras
                )
            # Remove the modules that have been replaced.
            for module in replaced_module:
                lora_model.loras.pop(module, None)

        for lora in lora_model.loras.values():
            lora.optimize()

753
    def _get_lora_layer_weights(
        self, lora_model: LoRAModel, module_name: str
    ) -> LoRALayerWeights | None:
        """Look up the LoRA weights stored in ``lora_model`` for a module.

        For pooling models, when ``module_name`` is not known to
        ``lora_model``, the lookup is retried with ``"model."`` removed from
        the name.

        Args:
            lora_model: The LoRA model to query.
            module_name: Dotted name of the module.

        Returns:
            Whatever ``lora_model.get_lora`` returns for the resolved name.
        """
        if self.is_pooling_model and not lora_model.check_lora_name(module_name):
            # If it's a pool model, and the layer name is not found,
            # remove the prefix 'model.' and search again.
            # NOTE(review): str.replace drops every "model." occurrence, not
            # only a leading one - confirm this is intended.
            stripped_name = module_name.replace("model.", "")
            if lora_model.check_lora_name(stripped_name):
                logger.info_once(
                    "For the pool model, successfully loaded the LoRA weights "
                    "after removing the prefix 'model.'."
                )
                return lora_model.get_lora(stripped_name)
        return lora_model.get_lora(module_name)

769
    def deactivate_adapter(self, adapter_id: int) -> bool:
        """Deactivate the adapter with the given id, if it is active.

        Args:
            adapter_id: Id of the adapter to deactivate.

        Returns:
            True if the adapter was active and has been deactivated, False if
            it was not in ``self._active_adapters``.
        """
        if adapter_id not in self._active_adapters:
            return False
        self._deactivate_adapter(adapter_id)
        self._active_adapters.pop(adapter_id, None)
        return True
775
776

    def add_adapter(self, adapter: LoRAModel) -> bool:
        """Register ``adapter`` with the manager.

        Args:
            adapter: The LoRA model to register.

        Returns:
            True if the adapter was registered, False if an adapter with the
            same id is already registered.

        Raises:
            RuntimeError: If the number of registered adapters has already
                reached ``self.capacity``.
        """
        logger.debug("Adding lora. Model id: %d, int id: %d", adapter.id, adapter.id)
        if adapter.id in self._registered_adapters:
            return False
        if len(self._registered_adapters) >= self.capacity:
            raise RuntimeError("No free adapter slots.")
        self._add_adapter(adapter)
        return True
784

785
    def set_adapter_mapping(self, mapping: LoRAMapping) -> None:
        """Apply ``mapping`` unless it equals the last mapping that was applied.

        Args:
            mapping: The LoRA mapping to apply.
        """
        if self._last_mapping != mapping:
            self._set_adapter_mapping(mapping)
            # Remember the mapping so that an identical follow-up call is
            # skipped.
            self._last_mapping = mapping
789
790

    def remove_adapter(self, adapter_id: int) -> bool:
        """Deactivate and unregister the adapter with the given id.

        ``deactivate_adapter`` is always called first, even when the adapter
        turns out not to be registered.

        Args:
            adapter_id: Id of the adapter to remove.

        Returns:
            True if the adapter was registered and has been removed, False
            otherwise.
        """
        self.deactivate_adapter(adapter_id)
        if adapter_id not in self._registered_adapters:
            return False
        self._registered_adapters.pop(adapter_id, None)
        return True
796

797
798
    def list_adapters(self) -> dict[int, LoRAModel]:
        """Return a new dict built from the registered adapters."""
        registered = dict(self._registered_adapters)
        return registered
799

800
    def get_adapter(self, adapter_id: int) -> LoRAModel | None:
        """Return the result of looking up ``adapter_id`` among the registered adapters.

        Args:
            adapter_id: Id of the adapter to look up.

        Returns:
            The value of ``self._registered_adapters.get(adapter_id)``.
        """
        return self._registered_adapters.get(adapter_id)
802
803
804


class LoRALRUCache(AdapterLRUCache[LoRAModel]):
    """An ``AdapterLRUCache`` specialised for ``LoRAModel`` values."""

    def __init__(self, capacity: int, deactivate_lora_fn: Callable[[int], bool]):
        """Create the cache.

        Args:
            capacity: Capacity forwarded to ``AdapterLRUCache``.
            deactivate_lora_fn: Function forwarded to ``AdapterLRUCache`` as
                its deactivation function.
        """
        super().__init__(capacity, deactivate_lora_fn)
807
808
809
810
811


class LRUCacheLoRAModelManager(LoRAModelManager):
    """A model manager that manages multiple LoRAs with LRU cache."""

    def __init__(
        self,
        model: nn.Module,
        max_num_seqs: int,
        max_num_batched_tokens: int,
        vocab_size: int,
        lora_config: LoRAConfig,
        device: torch.device,
    ):
        """Initialise the base manager, then install LRU caches.

        All arguments are forwarded unchanged to ``LoRAModelManager``.
        """
        super().__init__(
            model, max_num_seqs, max_num_batched_tokens, vocab_size, lora_config, device
        )
        # Both containers are LRU caches. Each one is given the function that
        # AdapterLRUCache calls from its `_on_remove` hook.
        self._registered_adapters: LoRALRUCache = LoRALRUCache(
            self.capacity, self.deactivate_adapter
        )
        self._active_adapters: LoRALRUCache = LoRALRUCache(
            self.lora_slots, self._deactivate_adapter
        )

    def list_adapters(self) -> dict[int, LoRAModel]:
        """List all registered LoRAModels."""
        return dict(self._registered_adapters.cache)

    def add_adapter(self, lora: LoRAModel) -> bool:
        """Add a LoRAModel to the manager.

        Returns:
            True if the adapter was newly added, False if its id was already
            registered (in which case only its LRU position is refreshed).
        """
        logger.debug("Adding lora. Model id: %d, int id: %d", lora.id, lora.id)
        if lora.id in self._registered_adapters:
            # We always touch to update the LRU cache order
            self._registered_adapters.touch(lora.id)
            return False
        self._add_adapter(lora)
        return True

    def activate_adapter(
        self,
        lora_id: int,
    ) -> bool:
        """Activate the adapter, evicting the oldest active one if needed.

        Returns:
            The result of the base class ``activate_adapter``.
        """
        if (
            lora_id not in self._active_adapters
            and len(self._active_adapters) >= self.lora_slots
        ):
            # All slots are taken: drop the oldest active adapter first.
            self._active_adapters.remove_oldest()
        result = super().activate_adapter(lora_id)
        # We always touch to update the LRU cache order
        self._active_adapters.touch(lora_id)
        return result

    def remove_oldest_adapter(self) -> bool:
        """Remove the oldest registered adapter.

        Returns:
            True if an adapter was removed, False if none is registered.
        """
        if len(self._registered_adapters) > 0:
            self._registered_adapters.remove_oldest()
            return True
        return False

    def pin_adapter(self, lora_id: int) -> bool:
        """Pin a LoRAModel in the manager cache."""
        self._pin_lora_in_cpu_cache(lora_id)
        self._pin_lora_in_gpu_cache(lora_id)
        return True

    def _pin_lora_in_cpu_cache(self, lora_id: int):
        """Pin ``lora_id`` in the registered-adapter cache.

        Raises:
            ValueError: If pinning fails with a ``ValueError``; re-raised with
                a message stating that the LoRA is not registered.
        """
        try:
            self._registered_adapters.pin(lora_id)
        except ValueError as err:
            raise ValueError(
                f"Pinning failed. LoRA {lora_id} is not registered."
            ) from err

    def _pin_lora_in_gpu_cache(self, lora_id: int):
        """Pin ``lora_id`` in the active-adapter cache, activating it first if needed."""
        if lora_id not in self._active_adapters:
            # move lora to gpu if not already active
            self.activate_adapter(lora_id)

        self._active_adapters.pin(lora_id)
887

888
889

def create_lora_manager(
    model: nn.Module,
    max_num_seqs: int,
    max_num_batched_tokens: int,
    vocab_size: int,
    lora_config: LoRAConfig,
    device: torch.device,
    lora_manager_cls: type[LoRAModelManager] = LoRAModelManager,
    **kwargs,
) -> LoRAModelManager:
    """Create a LoRA model manager for a given model.

    Args:
        model: The model the manager is created for; must be an instance of
            ``SupportsLoRA``.
        max_num_seqs: Forwarded to the manager constructor.
        max_num_batched_tokens: Forwarded to the manager constructor.
        vocab_size: Forwarded to the manager constructor.
        lora_config: Forwarded to the manager constructor.
        device: Forwarded to the manager constructor.
        lora_manager_cls: The manager class to instantiate.
        **kwargs: Extra keyword arguments forwarded to the manager
            constructor.

    Returns:
        The newly constructed ``lora_manager_cls`` instance.

    Raises:
        ValueError: If ``model`` is not an instance of ``SupportsLoRA``.
    """
    if not isinstance(model, SupportsLoRA):
        raise ValueError(f"Model {type(model)} is not supported for LoRA.")
    return lora_manager_cls(
        model=model,
        max_num_seqs=max_num_seqs,
        max_num_batched_tokens=max_num_batched_tokens,
        vocab_size=vocab_size,
        lora_config=lora_config,
        device=device,
        **kwargs,
    )