models.py 35.6 KB
Newer Older
1
# SPDX-License-Identifier: Apache-2.0
2
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
3

4
5
import math
import os
6
7
from collections.abc import Callable
from typing import TypeVar
8

9
import regex as re
10
11
12
13
import safetensors.torch
import torch
from torch import nn

14
from vllm.config.lora import LoRAConfig
15
from vllm.logger import init_logger
16
from vllm.lora.layers import BaseLayerWithLoRA, FusedMoEWithLoRA, LoRAMapping
17
from vllm.lora.lora_weights import LoRALayerWeights, PackedLoRALayerWeights
18
from vllm.lora.peft_helper import PEFTHelper
19
from vllm.lora.punica_wrapper import get_punica_wrapper
20
21
22
23
from vllm.lora.utils import (
    from_layer,
    from_layer_logits_processor,
    get_supported_lora_modules,
24
    is_base_embeddding_weights,
25
    is_moe_model,
26
27
    is_regex_target_modules,
    parse_fine_tuned_lora_name,
28
    process_packed_modules_mapping,
29
30
    replace_submodule,
)
31
from vllm.model_executor.layers.fused_moe import FusedMoE
32
from vllm.model_executor.model_loader.tensorizer import TensorizerConfig
33
from vllm.model_executor.models import SupportsLoRA, supports_multimodal
34
from vllm.model_executor.models.interfaces import is_pooling_model
35
from vllm.model_executor.models.module_mapping import MultiModelKeys
36
from vllm.model_executor.models.utils import PPMissingLayer, WeightsMapper
37
from vllm.utils.cache import LRUCache
38
from vllm.utils.platform_utils import is_pin_memory_available
39

40
logger = init_logger(__name__)
41

42
43
44
45
46
47
48
49
T = TypeVar("T")


class AdapterLRUCache(LRUCache[int, T]):
    """LRU cache keyed by integer adapter id that runs a hook on removal."""

    def __init__(self, capacity: int, deactivate_fn: Callable[[int], object]):
        """
        Args:
            capacity: capacity forwarded to the underlying ``LRUCache``.
            deactivate_fn: callable invoked with the adapter id whenever an
                entry is removed from the cache.
        """
        super().__init__(capacity)
        self.deactivate_fn = deactivate_fn

    def _on_remove(self, key: int, value: T | None):
        """Log the removal, run the deactivation hook, then defer to the base."""
        logger.debug("Removing adapter int id: %d", key)
        self.deactivate_fn(key)
        base_result = super()._on_remove(key, value)
        return base_result


56
57
58
59
60
61
62
63
64
# Last LoRA id handed out by get_lora_id(); 0 means none has been issued yet.
_GLOBAL_LORA_ID = 0


def get_lora_id():
    """Increment the module-level counter and return the new id (first is 1)."""
    global _GLOBAL_LORA_ID
    next_id = _GLOBAL_LORA_ID + 1
    _GLOBAL_LORA_ID = next_id
    return next_id


65
class LoRAModel:
    """A LoRA fine-tuned model."""

    def __init__(
        self,
        lora_model_id: int,
        rank: int,
        loras: dict[str, LoRALayerWeights],
    ) -> None:
        """
        Args:
            lora_model_id: The integer id for the lora model.
            rank: lora rank.
            loras: module name -> weights for lora-replaced layers.

        """
        self.id = lora_model_id

        assert lora_model_id > 0, (
            f"a valid lora id should be greater than 0, got {self.id}"
        )
        self.rank = rank
        self.loras: dict[str, LoRALayerWeights] = loras

    def clone(self, lora_model_id: int) -> "LoRAModel":
        """Return a copy of the object with different ids.

        Will share the underlying tensors."""
        return self.__class__(
            lora_model_id,
            rank=self.rank,
            # Shallow copy: a new dict holding the same weight objects.
            loras=self.loras.copy(),
        )

    def get_lora(self, module_name: str) -> LoRALayerWeights | None:
        """Get LoRA for a given module by name"""
        return self.loras.get(module_name, None)

    def check_lora_name(self, lora_name: str) -> bool:
        """Return whether weights are stored under ``lora_name``."""
        return lora_name in self.loras

    # (yard1): TODO see if we can derive target_embedding_padding automatically
    @classmethod
    def from_lora_tensors(
        cls,
        lora_model_id: int,
        tensors: dict[str, torch.Tensor],
        peft_helper: PEFTHelper,
        device: str = "cuda",
        dtype: torch.dtype | None = None,
        target_embedding_padding: int | None = None,
        embedding_modules: dict[str, str] | None = None,
        embedding_padding_modules: list[str] | None = None,
        weights_mapper: WeightsMapper | None = None,
    ) -> "LoRAModel":
        """Create a LoRAModel from a dictionary of tensors.

        Args:
            lora_model_id: The integer id for the lora model.
            tensors: tensor name -> tensor. Names for which
                ``is_base_embeddding_weights`` is true are skipped.
            peft_helper: Loaded lora configuration information; its ``r`` is
                used as the rank of the returned model.
            device: Device the tensors are moved to.
            dtype: dtype the tensors are converted to.
            target_embedding_padding: If given, the ``lora_b`` tensor of a
                module whose name contains an entry of
                ``embedding_padding_modules`` is zero-padded at the end of
                dim 0 up to this size.
            embedding_modules: Accepted for interface compatibility; not read
                by this method.
            embedding_padding_modules: Name fragments selecting the modules to
                pad. Only required when ``target_embedding_padding`` is set.
            weights_mapper: Forwarded to ``parse_fine_tuned_lora_name``.

        Raises:
            ValueError: if ``target_embedding_padding`` is set, a ``lora_b``
                tensor is encountered and ``embedding_padding_modules`` is
                None.
        """
        pin_memory = str(device) == "cpu" and is_pin_memory_available()
        loras: dict[str, LoRALayerWeights] = {}
        for tensor_name, tensor in tensors.items():
            if is_base_embeddding_weights(tensor_name):
                continue
            module_name, is_lora_a = parse_fine_tuned_lora_name(
                tensor_name, weights_mapper
            )
            if module_name not in loras:
                loras[module_name] = LoRALayerWeights.from_config(
                    module_name, peft_helper
                )
            layer = loras[module_name]

            if is_lora_a:
                layer.lora_a = tensor.to(device=device, dtype=dtype)
                if pin_memory:
                    layer.lora_a = layer.lora_a.pin_memory()
                continue

            layer.lora_b = tensor.to(device=device, dtype=dtype)
            if target_embedding_padding is not None:
                # The module list is only meaningful when padding was
                # requested, so it is validated here rather than asserted
                # unconditionally (an assert would also vanish under -O).
                if embedding_padding_modules is None:
                    raise ValueError(
                        "embedding_padding_modules must be provided when "
                        "target_embedding_padding is set"
                    )
                if any(name in module_name for name in embedding_padding_modules):
                    lora_b = layer.lora_b
                    assert target_embedding_padding >= lora_b.shape[0]
                    addition = target_embedding_padding - lora_b.shape[0]
                    layer.lora_b = torch.nn.functional.pad(
                        lora_b, (0, 0, 0, addition)
                    )
            if pin_memory:
                layer.lora_b = layer.lora_b.pin_memory()

        for lora in loras.values():
            lora.optimize()

        return cls(lora_model_id, peft_helper.r, loras)

    @classmethod
    def from_local_checkpoint(
        cls,
        lora_dir: str,
        expected_lora_modules: list[str],
        peft_helper: PEFTHelper,
        *,
        lora_model_id: int | None = None,
        device: str = "cuda",
        dtype: torch.dtype | None = None,
        target_embedding_padding: int | None = None,
        embedding_modules: dict[str, str] | None = None,
        embedding_padding_modules: list[str] | None = None,
        weights_mapper: WeightsMapper | None = None,
        tensorizer_config_dict: dict | None = None,
    ) -> "LoRAModel":
        """Create a LoRAModel from a local checkpoint.

        The weights are read from the first available source, in this order:
        a tensorizer archive (when ``tensorizer_config_dict`` is given),
        ``adapter_model.safetensors``, ``adapter_model.bin``,
        ``adapter_model.pt``.

        Args:
            lora_dir: The local path that has lora data.
            expected_lora_modules: Name of modules that are expected to be
                replaced by lora.
            peft_helper: Loaded lora configuration information.
            lora_model_id: LoRA model id. If not given, automatically set by
                a global counter.
            device: Device where the lora model is loaded.
            dtype: dtype of the lora model weights.
            target_embedding_padding: Forwarded to ``from_lora_tensors``.
            embedding_modules: Forwarded to ``from_lora_tensors``.
            embedding_padding_modules: Forwarded to ``from_lora_tensors``.
            weights_mapper: Used when parsing tensor names and forwarded to
                ``from_lora_tensors``.
            tensorizer_config_dict: Keyword arguments for
                ``TensorizerConfig``; when truthy the adapter is read from
                ``adapter_model.tensors`` in the configured tensorizer dir.

        Returns:
            Loaded LoRA Model.

        Raises:
            ValueError: if the checkpoint targets modules outside
                ``expected_lora_modules``, or if ``lora_dir`` contains none
                of the supported weight files.
        """
        lora_tensor_path = os.path.join(lora_dir, "adapter_model.safetensors")
        lora_bin_file_path = os.path.join(lora_dir, "adapter_model.bin")
        lora_pt_file_path = os.path.join(lora_dir, "adapter_model.pt")
        tensors: dict[str, torch.Tensor] = {}
        unexpected_modules: list[list[str] | str] = []

        def check_unexpected_modules(modules: dict):
            # Appends to the enclosing ``unexpected_modules`` list and raises
            # if anything was collected.
            for lora_module in modules.keys():  # noqa
                if is_base_embeddding_weights(lora_module):
                    continue
                module_name, _ = parse_fine_tuned_lora_name(lora_module, weights_mapper)
                # Handle FSDP file format where experts.base_layer is the
                # gate_up_proj and experts is the down_proj
                if "base_layer" in lora_module:
                    continue
                # Case for expert lora weights
                if ".experts" in module_name:
                    if not any(
                        module_name.endswith(ele) for ele in expected_lora_modules
                    ):
                        unexpected_modules.append(module_name)
                elif module_name.split(".")[-1] not in expected_lora_modules:
                    unexpected_modules.append(module_name)

            if unexpected_modules:
                raise ValueError(
                    f"While loading {lora_dir}, expected"
                    f" target modules in {expected_lora_modules}"
                    f" but received {unexpected_modules}."
                    f" Please verify that the loaded LoRA module is correct"
                )

        if tensorizer_config_dict:
            from tensorizer import TensorDeserializer

            tensorizer_config = TensorizerConfig(**tensorizer_config_dict)
            lora_tensor_path = os.path.join(
                tensorizer_config.tensorizer_dir, "adapter_model.tensors"
            )
            tensorizer_args = tensorizer_config._construct_tensorizer_args()
            tensors = TensorDeserializer(
                lora_tensor_path,
                dtype=tensorizer_config.dtype,
                **tensorizer_args.deserialization_kwargs,
            )
            check_unexpected_modules(tensors)

        elif os.path.isfile(lora_tensor_path):
            # Find unexpected modules.
            # Use safetensor key as a source of truth to find expected modules.
            # in peft if you have target_modules A, B, C and C does not exist
            # in the model it won't error and model will be trained with A, B
            # loraified. C won't exist in the safetensor but it will exist in
            # the target_modules of the adapter_config.json.
            unexpected_modules = []
            with safetensors.safe_open(lora_tensor_path, framework="pt") as f:  # type: ignore
                # Load tensors if there are only expected modules.
                check_unexpected_modules(f)
                for module in f.keys():  # noqa
                    tensors[module] = f.get_tensor(module)
        elif os.path.isfile(lora_bin_file_path) or os.path.isfile(lora_pt_file_path):
            # When a bin/pt file is provided, we rely on config to find
            # unexpected modules.
            unexpected_modules = []
            target_modules = peft_helper.target_modules
            if not isinstance(target_modules, list):
                target_modules = [target_modules]
            for module in target_modules:
                # Compatible with more modules,
                # such as:layers.11.self_attn.k_proj
                part_name = module.split(".")[-1]
                if part_name not in expected_lora_modules:
                    unexpected_modules.append(module)
            # loaded lora's target modules must be a subset of
            # expected_lora_modules. It is not reliable. See
            # https://github.com/vllm-project/vllm/pull/5909. But there's no
            # other better mechanism.
            if unexpected_modules and not is_regex_target_modules(
                peft_helper.target_modules, expected_lora_modules
            ):
                raise ValueError(
                    f"While loading {lora_dir}, expected"
                    f" target modules in {expected_lora_modules}"
                    f" but received {unexpected_modules}."
                    f" Please verify that the loaded LoRA module is correct"
                )
            # Prefer the .bin file when both .bin and .pt exist.
            lora_file_path = (
                lora_bin_file_path
                if os.path.isfile(lora_bin_file_path)
                else lora_pt_file_path
            )
            tensors = torch.load(lora_file_path, map_location=device, weights_only=True)
        else:
            raise ValueError(f"{lora_dir} doesn't contain tensors")

        return cls.from_lora_tensors(
            lora_model_id=get_lora_id() if lora_model_id is None else lora_model_id,
            tensors=tensors,
            peft_helper=peft_helper,
            device=device,
            dtype=dtype,
            target_embedding_padding=target_embedding_padding,
            embedding_modules=embedding_modules,
            embedding_padding_modules=embedding_padding_modules,
            weights_mapper=weights_mapper,
        )
300
301


302
class LoRAModelManager:
303
304
305
306
    """A manager that manages multiple LoRA-fine-tuned models."""

    def __init__(
        self,
        model: SupportsLoRA,
        max_num_seqs: int,
        max_num_batched_tokens: int,
        vocab_size: int,
        lora_config: LoRAConfig,
        device: torch.device,
    ):
        """Create a LoRAModelManager and adapter for a given model.

        Besides storing its arguments, this replaces the model's supported
        submodules with LoRA-enabled layers (via ``_create_lora_modules``)
        and sets ``model.lora_manager`` to this manager.

        Args:
            model: the model to be adapted.
            max_num_seqs: the maximum number of sequences model can run in a
                single batch.
            max_num_batched_tokens: the maximum number of tokens model can run
                in a single batch.
            vocab_size: the vocab size of the model.
            lora_config: the LoRA configuration.
            device: the device passed to the punica wrapper.
        """
        self.model: SupportsLoRA = model
        self._registered_adapters: dict[int, LoRAModel] = {}
        # Dict instead of a set for compatibility with LRUCache.
        self._active_adapters: dict[int, None] = {}
        self.adapter_type = "LoRA"
        self.lora_config = lora_config
        self.device = device
        self.max_num_seqs = max_num_seqs
        assert self.capacity >= self.lora_slots
        # Stored value is rounded up to a multiple of 8; the punica wrapper
        # below still receives the caller's unrounded value.
        self.max_num_batched_tokens = math.ceil(max_num_batched_tokens / 8) * 8
        self.lora_index_to_id: list[int | None] = [None] * self.lora_slots
        self.vocab_size = vocab_size
        self.punica_wrapper = get_punica_wrapper(
            max_num_batched_tokens,
            max_batches=self.max_num_seqs,
            device=self.device,
            max_loras=self.lora_config.max_loras,
        )

        self.supported_lora_modules = get_supported_lora_modules(self.model)
        # Both message fragments must sit inside one parenthesised expression;
        # otherwise the second is a separate no-op statement and the model
        # name never reaches the assertion message.
        assert self.supported_lora_modules, (
            "No supported LoRA modules found in"
            f" {self.model.__class__.__name__}."
        )

        self.packed_modules_mapping = process_packed_modules_mapping(self.model)
        # Used to indicate whether the model is a multimodal model
        self.supports_mm: bool = (
            supports_multimodal(self.model)
            # In case the model only supports LoRA for
            # text modules (e.g. ChatGLM)
            and hasattr(self.model, "get_mm_mapping")
        )
        self.is_pooling_model = is_pooling_model(self.model)
        self.packed_modules: dict[str, list[str]] = {}
        self.modules: dict[str, BaseLayerWithLoRA] = {}
        self._last_mapping: LoRAMapping | None = None
        self._is_3d_moe_model = is_moe_model(self.model) and hasattr(
            self.model, "is_3d_moe_weight"
        )
        self._create_lora_modules()

        self.model.lora_manager = self
367
368
369

    def __len__(self) -> int:
        return len(self._registered_adapters)
370
371
372
373
374
375
376
377
378

    @property
    def capacity(self) -> int:
        return self.lora_config.max_cpu_loras

    @property
    def lora_slots(self) -> int:
        return self.lora_config.max_loras

379
380
381
    @property
    def adapter_slots(self) -> int:
        return self.lora_slots
382

383
    def activate_adapter(
384
385
386
387
        self,
        lora_id: int,
    ) -> bool:
        """Move LoRA into a GPU buffer to be used in the forward pass."""
388
        if lora_id in self._active_adapters:
389
390
            return False
        first_free_slot = next(
391
392
393
394
395
396
397
            (
                (i, lora_id)
                for i, lora_id in enumerate(self.lora_index_to_id)
                if lora_id is None
            ),
            None,
        )
398
399
400
        if first_free_slot is None:
            raise ValueError("No free lora slots")
        index, _ = first_free_slot
401
402
        self._active_adapters[lora_id] = None
        lora_model = self._registered_adapters[lora_id]
403
404
405
        logger.debug(
            "Activating LoRA. int id: %d, slot index: %d", lora_model.id, index
        )
406
407
        self.lora_index_to_id[index] = lora_model.id
        for module_name, module in self.modules.items():
408
            module_lora = self._get_lora_layer_weights(lora_model, module_name)
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
            if not module_lora:
                module.reset_lora(index)
                continue
            # Note (gnovack) - If MOE lora weights are not split into
            # num_experts chunks, we split them here
            if isinstance(module, FusedMoEWithLoRA) and torch.is_tensor(
                module_lora.lora_a
            ):
                # Handle PEFT file format where experts.base_layer is the
                # gate_up_proj and experts is the down_proj
                gate_up_proj_lora = self._get_lora_layer_weights(
                    lora_model, module_name + ".base_layer"
                )
                down_proj_lora = module_lora
                # FIXME Edge case where LoRA is not added to gate_up_proj
                # or down_proj
                assert gate_up_proj_lora is not None
                assert down_proj_lora is not None
                if self._is_3d_moe_model:
                    module_lora.lora_a = [
                        gate_up_proj_lora.lora_a,
                        down_proj_lora.lora_a,
                    ]
                    module_lora.lora_b = [
                        gate_up_proj_lora.lora_b,
                        down_proj_lora.lora_b,
                    ]
                else:
                    # Some 3D MoE models haven't added the `is_3d_moe_weight`
                    # attribute yet, so fallback here
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
                    num_experts = module_lora.lora_a.shape[0] // module_lora.rank

                    gate_proj_a = gate_up_proj_lora.lora_a.chunk(num_experts, dim=0)
                    up_proj_a = gate_up_proj_lora.lora_a.chunk(num_experts, dim=0)

                    gate_proj_b = gate_up_proj_lora.lora_b[::2, ...].chunk(
                        num_experts, dim=-1
                    )
                    up_proj_b = gate_up_proj_lora.lora_b[1::2, ...].chunk(
                        num_experts, dim=-1
                    )

                    down_proj_a = down_proj_lora.lora_a.chunk(num_experts, dim=0)
                    down_proj_b = down_proj_lora.lora_b.chunk(num_experts, dim=-1)

                    lora_a = []
                    lora_b = []
                    for i in range(num_experts):
                        lora_a.append(gate_proj_a[i])
                        lora_a.append(down_proj_a[i])
                        lora_a.append(up_proj_a[i])

                        lora_b.append(gate_proj_b[i])
                        lora_b.append(down_proj_b[i])
                        lora_b.append(up_proj_b[i])

                    module_lora.lora_a = lora_a
                    module_lora.lora_b = lora_b
467
468
469
470
471
            module.set_lora(
                index,
                module_lora.lora_a,
                module_lora.lora_b,
            )
472

473
474
        return True

475
    def _deactivate_adapter(self, lora_id: int):
476
477
478
479
480
481
        try:
            index = self.lora_index_to_id.index(lora_id)
            self.lora_index_to_id[index] = None
        except ValueError:
            pass

482
    def _add_adapter(self, lora: LoRAModel):
        """Merge packed sub-module weights in place, then register the model."""
        # Packing mutates ``lora.loras`` and must happen before registration.
        self._create_merged_loras_inplace(lora)
        adapter_id = lora.id
        self._registered_adapters[adapter_id] = lora
485

486
    def pin_adapter(self, lora_id: int) -> bool:
487
488
        """Pin a LoRAModel in the manager cache."""
        raise NotImplementedError(
489
            "Pinning is not supported in LoRAModelManager. "
490
491
            "Use LRUCacheLoRAModelManager for pinning"
        )  # type: ignore
492

493
    def _set_adapter_mapping(self, mapping: LoRAMapping) -> None:
494
495
496
497
498
499
500
        # update lora states
        self.punica_wrapper.update_metadata(
            mapping,
            self.lora_index_to_id,
            self.lora_slots + 1,
            self.vocab_size,
        )
501

502
    def remove_all_adapters(self):
503
        """Remove all LoRAModels from the manager."""
504
        self._registered_adapters.clear()
505
        self.lora_index_to_id = [None] * self.lora_slots
506
        self._active_adapters.clear()
507
508

    def _create_lora_modules(self):
        """Swap every supported submodule of the model for its LoRA layer.

        Each replaced module is recorded through ``register_module`` and
        ``_register_packed_modules`` and is handed the shared punica wrapper.
        """
        for module_name, module in self.model.named_modules(remove_duplicate=False):
            if isinstance(module, PPMissingLayer) or not self._match_target_modules(
                module_name
            ):
                continue

            # A temporary approach for multimodal models to support LoRA
            # TODO: Remove this restriction
            if self._filter_unsupported_mm_module(module_name):
                logger.warning(
                    "Regarding multimodal models, vLLM currently only supports "
                    "adding LoRA to language model, %s will be ignored.",
                    module_name,
                )
                continue

            # Module names are dot separated: 'x.y.z' -> parent 'x.y', leaf
            # 'z'; a bare 'x' has the empty string as parent.
            parent_name, _, leaf_name = module_name.rpartition(".")

            if isinstance(module, FusedMoE):
                # packed_moduled_lst is used here to just determine whether to
                # instantiate FusedMoE3DWithLoRA or FusedMoEWithLoRA, and the
                # difference between these two LoRA layers is whether the
                # LoRA weights of w1 and w3 have already been fused on disk.
                packed_moduled_lst = ["w13"] if self._is_3d_moe_model else ["w1", "w3"]
            else:
                packed_moduled_lst = self.packed_modules_mapping.get(leaf_name, [])

            lora_layer = from_layer(
                module,
                self.lora_slots,
                self.lora_config,
                packed_moduled_lst,
                self.model.config,
            )
            new_module = replace_submodule(self.model, module_name, lora_layer)

            # (yard1): TODO make this more robust
            if "lm_head" in module_name:
                # The logits processor is looked up as a sibling of lm_head.
                if parent_name:
                    logits_processor_module_name = f"{parent_name}.logits_processor"
                else:
                    logits_processor_module_name = "logits_processor"

                logits_processor_module = self.model.get_submodule(
                    logits_processor_module_name
                )
                lora_logits_processor = from_layer_logits_processor(
                    logits_processor_module,
                    module,
                    self.lora_slots,
                    self.lora_config,
                    self.model.config,
                )
                # From here on new_module is the replaced logits processor; it
                # is what gets registered under the lm_head module name below.
                new_module = replace_submodule(
                    self.model,
                    logits_processor_module_name,
                    lora_logits_processor,
                )

            # In some models, especially multimodal ones, layers with the same
            # name may have different types, such as nn.Linear and
            # ReplicatedLinear. The nn.Linear layers cannot be replaced with
            # LoRA layers, leading to assertion error. The following check
            # aims to prevent this error
            if self.supports_mm and not isinstance(new_module, BaseLayerWithLoRA):
                continue

            self.register_module(module_name, new_module)
            self._register_packed_modules(module_name)
            # All lora layers share the same punica_wrapper based on reference.
            new_module.set_mapping(self.punica_wrapper)
589
590

    def register_module(self, module_name: str, module: "BaseLayerWithLoRA"):
        """Record ``module`` under ``module_name`` in ``self.modules``.

        Args:
            module_name: key for the module.
            module: the LoRA layer; must be a ``BaseLayerWithLoRA`` instance.
        """
        # Both fragments belong to the assertion message. Outside the
        # parentheses the second one is a separate no-op statement and the
        # actual type never gets reported.
        assert isinstance(module, BaseLayerWithLoRA), (
            f"Module {module_name} must be a BaseLayerWithLoRA instance,"
            f" got {type(module)}"
        )
        self.modules[module_name] = module

Terry's avatar
Terry committed
597
    def create_dummy_lora(
        self,
        lora_id: int,
        rank: int,
        embedding_modules: dict[str, str] | None = None,
    ) -> LoRAModel:
        """Create zero-initialized LoRAModel for warmup.

        Args:
            lora_id: id given to the returned model.
            rank: lora rank used for every dummy weight.
            embedding_modules: names of embedding modules; only membership of
                a module's last name component is checked. Must not be None
                when the model has any non-packed LoRA module.

        Returns:
            A ``LoRAModel`` with one dummy entry per eligible LoRA layer; all
            dummy weights are created on ``"cpu"``.
        """
        dummy_device = "cpu"
        model = LoRAModel(lora_id, rank, {})
        for module_name, module in self.model.named_modules():
            eligible = (
                self._match_target_modules(module_name)
                and isinstance(module, BaseLayerWithLoRA)
                and not self._filter_unsupported_mm_module(module_name)
            )
            if not eligible:
                continue
            leaf_name = module_name.split(".")[-1]

            if module_name in self.packed_modules:
                # Packed module: one dummy weight per sub-module, then pack.
                replacements = self.packed_modules_mapping[leaf_name]
                subloras: list[LoRALayerWeights | None] = [
                    LoRALayerWeights.create_dummy_lora_weights(
                        module_name + "." + r,
                        module.lora_a_stacked[i].shape[-1],
                        module.lora_b_stacked[i].shape[-2],
                        rank,
                        module.lora_a_stacked[i].dtype,
                        dummy_device,
                    )
                    for i, r in enumerate(replacements)
                ]
                model.loras[module_name] = PackedLoRALayerWeights.pack(subloras)
                continue

            assert embedding_modules is not None
            if leaf_name in embedding_modules:
                base_layer = module.base_layer
                # Prefer the dedicated attributes; otherwise fall back to the
                # base layer's weight shape.
                if hasattr(base_layer, "org_vocab_size"):
                    input_dim = base_layer.org_vocab_size
                else:
                    input_dim = base_layer.weight.shape[1]
                if hasattr(base_layer, "embedding_dim"):
                    output_dim = base_layer.embedding_dim
                else:
                    output_dim = base_layer.weight.shape[0]
                model.loras[module_name] = LoRALayerWeights.create_dummy_lora_weights(
                    module_name,
                    input_dim,
                    output_dim,
                    rank,
                    module.lora_a_stacked[0].dtype,
                    dummy_device,
                )
            elif module.__class__.__name__ == "FusedMoE3DWithLoRA":
                # Case for 3D moe model
                # w2
                w2_a = module.w2_lora_a_stacked[0]
                model.loras[module_name] = LoRALayerWeights.create_dummy_lora_weights(
                    module_name,
                    module.w2_input_size,
                    module.w2_output_size,
                    rank * w2_a.shape[1],  # rank*num_experts
                    w2_a.dtype,
                    dummy_device,
                )
                # w13, stored under the ".base_layer" suffix
                w13_a = module.w13_lora_a_stacked[0]
                model.loras[module_name + ".base_layer"] = (
                    LoRALayerWeights.create_dummy_lora_weights(
                        module_name,
                        module.w13_input_size,
                        module.w13_output_size,
                        rank * w13_a.shape[1],  # rank*num_experts
                        w13_a.dtype,
                        dummy_device,
                    )
                )
            else:
                first_a = module.lora_a_stacked[0]
                model.loras[module_name] = LoRALayerWeights.create_dummy_lora_weights(
                    module_name,
                    first_a.shape[-1],
                    module.lora_b_stacked[0].shape[-2],
                    rank,
                    first_a.dtype,
                    dummy_device,
                )
        return model

    def _match_target_modules(self, module_name: str):
        return any(
            re.match(
689
690
691
692
693
                r".*\.{target_module}$".format(target_module=target_module), module_name
            )
            or target_module == module_name
            for target_module in self.supported_lora_modules
        )
694

695
696
697
    def _filter_unsupported_mm_module(self, module_name: str) -> bool:
        """
        Regarding multimodal models, vLLM currently only supports adding LoRA to
698
        language model. LoRA for other modules, such as the vision tower, will
699
700
701
702
        be filtered out.
        """
        if self.supports_mm:
            module_mapping: MultiModelKeys = self.model.get_mm_mapping()
703
            prefix_lst = module_mapping.connector + module_mapping.tower_model
704
            return any([module_name.startswith(prefix) for prefix in prefix_lst])
705
706
        return False

707
708
709
    def _register_packed_modules(self, module_full_name: str) -> None:
        """Record which sub-module names a packed module stands for.

        The last dot-separated component of ``module_full_name`` is looked up
        in ``self.packed_modules_mapping``. When it maps to more than one
        replacement, ``self.packed_modules[module_full_name]`` is set to the
        replacement names, each qualified with the module's parent path (or
        left bare when the module has no parent path).
        """
        parent_path, _, leaf_name = module_full_name.rpartition(".")
        sub_modules = self.packed_modules_mapping.get(leaf_name, [])
        # A mapping with at most one entry means this is not a packed module.
        if len(sub_modules) <= 1:
            return
        if parent_path:
            qualified = [parent_path + "." + name for name in sub_modules]
        else:
            qualified = list(sub_modules)
        self.packed_modules[module_full_name] = qualified

    def _create_merged_loras_inplace(self, lora_model: LoRAModel) -> None:
        """Merge sub-module LoRA weights into their packed modules in place.

        For every entry of ``self.packed_modules`` the weights of its
        sub-modules are fetched from ``lora_model``. If at least one
        sub-module has weights, they are combined with
        ``PackedLoRALayerWeights.pack`` (sub-modules without weights
        contribute ``None``), stored in ``lora_model.loras`` under the packed
        module name, and the sub-module entries that were found are removed
        from ``lora_model.loras``.
        """
        for packed_name, sub_module_names in self.packed_modules.items():
            sub_loras: list[LoRALayerWeights | None] = []
            found_names: set[str] = set()
            for sub_name in sub_module_names:
                sub_lora = self._get_lora_layer_weights(lora_model, sub_name)
                if sub_lora:
                    sub_loras.append(sub_lora)
                    found_names.add(sub_name)
                else:
                    # Keep the slot so positions still line up when packing.
                    sub_loras.append(None)
            # Nothing to pack when no sub-module carries LoRA weights.
            if not found_names:
                continue

            target_name = packed_name
            # HACK Temporary solution for the pool model.
            if self.is_pooling_model and not lora_model.check_lora_name(packed_name):
                stripped_name = packed_name.replace("model.", "")
                # NOTE(review): this repeats the lookup the enclosing
                # condition just found to be false, so the rename below looks
                # unreachable - confirm whether `stripped_name` was meant to
                # be checked here.
                if lora_model.check_lora_name(packed_name):
                    target_name = stripped_name

            lora_model.loras[target_name] = PackedLoRALayerWeights.pack(sub_loras)
            # Remove the modules that have been replaced.
            for sub_name in found_names:
                lora_model.loras.pop(sub_name, None)
748

749
    def _get_lora_layer_weights(
        self, lora_model: LoRAModel, module_name: str
    ) -> LoRALayerWeights | None:
        """Look up the LoRA weights of ``module_name`` in ``lora_model``.

        For pooling models, when the name is not present in ``lora_model``,
        the lookup is retried with ``"model."`` removed from the name; the
        shortened name is used only if ``lora_model`` knows it. In every
        other case the original name is passed to ``lora_model.get_lora``.
        """
        if not self.is_pooling_model or lora_model.check_lora_name(module_name):
            return lora_model.get_lora(module_name)

        # If it's a pool model, and the layer name is not found,
        # remove the prefix 'model.' and search again.
        # NOTE(review): str.replace drops every "model." occurrence, not only
        # a leading one - confirm this is intended.
        stripped_name = module_name.replace("model.", "")
        if not lora_model.check_lora_name(stripped_name):
            return lora_model.get_lora(module_name)

        logger.info_once(
            "For the pool model, successfully loaded the LoRA weights "
            "after removing the prefix 'model.'."
        )
        return lora_model.get_lora(stripped_name)

765
    def deactivate_adapter(self, adapter_id: int) -> bool:
        """Deactivate an adapter if it is currently active.

        Returns:
            True if the adapter was active and has been deactivated, False if
            it was not in ``self._active_adapters``.
        """
        was_active = adapter_id in self._active_adapters
        if was_active:
            self._deactivate_adapter(adapter_id)
            self._active_adapters.pop(adapter_id, None)
        return was_active
771
772

    def add_adapter(self, adapter: LoRAModel) -> bool:
        """Register an adapter with the manager.

        Returns:
            True if the adapter was newly registered, False if an adapter with
            the same id is already registered.

        Raises:
            RuntimeError: If the number of registered adapters has already
                reached ``self.capacity``.
        """
        adapter_id = adapter.id
        logger.debug(
            "Adding lora. Model id: %d, int id: %d", adapter_id, adapter_id
        )
        registered = self._registered_adapters
        if adapter_id in registered:
            return False
        if len(registered) >= self.capacity:
            raise RuntimeError("No free adapter slots.")
        self._add_adapter(adapter)
        return True
780

781
    def set_adapter_mapping(self, mapping: LoRAMapping) -> None:
        """Apply ``mapping`` unless it compares equal to the last one applied.

        When the mapping differs from ``self._last_mapping`` it is passed to
        ``self._set_adapter_mapping`` and then remembered as the last mapping.
        """
        mapping_changed = self._last_mapping != mapping
        if not mapping_changed:
            return
        self._set_adapter_mapping(mapping)
        self._last_mapping = mapping
785
786

    def remove_adapter(self, adapter_id: int) -> bool:
        """Deactivate an adapter and drop it from the registered adapters.

        Deactivation is attempted first, whether or not the adapter is
        registered.

        Returns:
            True if the adapter was registered and has been removed, False
            otherwise.
        """
        self.deactivate_adapter(adapter_id)
        was_registered = adapter_id in self._registered_adapters
        if was_registered:
            self._registered_adapters.pop(adapter_id, None)
        return was_registered
792

793
794
    def list_adapters(self) -> dict[int, LoRAModel]:
        """Return a new dict built from the registered adapters."""
        registered = self._registered_adapters
        return dict(registered)
795

796
    def get_adapter(self, adapter_id: int) -> LoRAModel | None:
        """Return the registered adapter for ``adapter_id`` via ``.get``."""
        registered = self._registered_adapters
        return registered.get(adapter_id)
798
799
800


class LoRALRUCache(AdapterLRUCache[LoRAModel]):
    """An ``AdapterLRUCache`` specialised for ``LoRAModel`` values.

    ``deactivate_lora_fn`` is handed to ``AdapterLRUCache`` as its
    ``deactivate_fn``, which that class calls with the key of each entry it
    removes.
    """

    def __init__(self, capacity: int, deactivate_lora_fn: Callable[[int], bool]):
        super().__init__(capacity=capacity, deactivate_fn=deactivate_lora_fn)
803
804
805
806
807


class LRUCacheLoRAModelManager(LoRAModelManager):
    """A model manager that manages multiple LoRAs with LRU cache."""

    def __init__(
        self,
        model: nn.Module,
        max_num_seqs: int,
        max_num_batched_tokens: int,
        vocab_size: int,
        lora_config: LoRAConfig,
        device: torch.device,
    ):
        """Initialise the base manager, then install the two LRU caches.

        Registered adapters are kept in a cache of size ``self.capacity``
        whose removal callback is ``self.deactivate_adapter``; active adapters
        are kept in a cache of size ``self.lora_slots`` whose removal callback
        is ``self._deactivate_adapter``.
        """
        super().__init__(
            model, max_num_seqs, max_num_batched_tokens, vocab_size, lora_config, device
        )
        registered_cache = LoRALRUCache(self.capacity, self.deactivate_adapter)
        self._registered_adapters: LoRALRUCache = registered_cache
        active_cache = LoRALRUCache(self.lora_slots, self._deactivate_adapter)
        self._active_adapters: LoRALRUCache = active_cache

    def list_adapters(self) -> dict[int, LoRAModel]:
        """List all registered LoRAModels."""
        registered = self._registered_adapters.cache
        return dict(registered)

    def add_adapter(self, lora: LoRAModel) -> bool:
        """Add a LoRAModel to the manager.

        Returns True when the adapter was newly added. When it is already
        registered, only its position in the LRU order is refreshed and False
        is returned.
        """
        logger.debug("Adding lora. Model id: %d, int id: %d", lora.id, lora.id)
        if lora.id in self._registered_adapters:
            # We always touch to update the LRU cache order
            self._registered_adapters.touch(lora.id)
            return False
        self._add_adapter(lora)
        return True

    def activate_adapter(
        self,
        lora_id: int,
    ) -> bool:
        """Activate an adapter, removing the oldest active one if needed.

        When the adapter is not active yet and the active cache already holds
        at least ``self.lora_slots`` entries, the oldest active entry is
        removed before the base class activation runs.
        """
        already_active = lora_id in self._active_adapters
        if not already_active and len(self._active_adapters) >= self.lora_slots:
            self._active_adapters.remove_oldest()
        activated = super().activate_adapter(lora_id)
        # We always touch to update the LRU cache order
        self._active_adapters.touch(lora_id)
        return activated

    def remove_oldest_adapter(self) -> bool:
        """Remove the oldest registered adapter; return False if none exist."""
        if len(self._registered_adapters) == 0:
            return False
        self._registered_adapters.remove_oldest()
        return True

    def pin_adapter(self, lora_id: int) -> bool:
        """Pin a LoRAModel in the manager cache."""
        # Pin in the registered cache first, then in the active cache.
        self._pin_lora_in_cpu_cache(lora_id)
        self._pin_lora_in_gpu_cache(lora_id)
        return True

    def _pin_lora_in_cpu_cache(self, lora_id: int):
        """Pin the adapter in the registered-adapter cache.

        Raises:
            ValueError: If the cache's ``pin`` raises ``ValueError``; it is
                re-raised with a message naming the LoRA id.
        """
        registered = self._registered_adapters
        try:
            registered.pin(lora_id)
        except ValueError as err:
            raise ValueError(
                f"Pinning failed. LoRA {lora_id} is not registered."
            ) from err

    def _pin_lora_in_gpu_cache(self, lora_id: int):
        """Pin the adapter in the active-adapter cache, activating it first."""
        is_active = lora_id in self._active_adapters
        if not is_active:
            # move lora to gpu if not already active
            self.activate_adapter(lora_id)

        self._active_adapters.pin(lora_id)
883

884
885

def create_lora_manager(
    model: nn.Module,
    max_num_seqs: int,
    max_num_batched_tokens: int,
    vocab_size: int,
    lora_config: LoRAConfig,
    device: torch.device,
    lora_manager_cls: type[LoRAModelManager] = LoRAModelManager,
    **kwargs,
) -> LoRAModelManager:
    """Build a LoRA manager of type ``lora_manager_cls`` for ``model``.

    All arguments, plus any extra ``kwargs``, are forwarded by keyword to
    ``lora_manager_cls``.

    Raises:
        ValueError: If ``model`` is not an instance of ``SupportsLoRA``.
    """
    if isinstance(model, SupportsLoRA):
        return lora_manager_cls(
            model=model,
            max_num_seqs=max_num_seqs,
            max_num_batched_tokens=max_num_batched_tokens,
            vocab_size=vocab_size,
            lora_config=lora_config,
            device=device,
            **kwargs,
        )
    raise ValueError(f"Model {type(model)} is not supported for LoRA.")