models.py 34.7 KB
Newer Older
1
# SPDX-License-Identifier: Apache-2.0
2
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
3

4
5
import math
import os
6
7
from collections.abc import Callable
from typing import TypeVar
8

9
import regex as re
10
11
12
13
import safetensors.torch
import torch
from torch import nn

14
from vllm.config.lora import LoRAConfig
15
from vllm.logger import init_logger
16
from vllm.lora.layers import BaseLayerWithLoRA, FusedMoEWithLoRA, LoRAMapping
17
from vllm.lora.lora_weights import LoRALayerWeights, PackedLoRALayerWeights
18
from vllm.lora.peft_helper import PEFTHelper
19
from vllm.lora.punica_wrapper import get_punica_wrapper
20
21
22
23
24
25
from vllm.lora.utils import (
    from_layer,
    from_layer_logits_processor,
    get_supported_lora_modules,
    is_regex_target_modules,
    parse_fine_tuned_lora_name,
26
    process_packed_modules_mapping,
27
28
    replace_submodule,
)
29
from vllm.model_executor.model_loader.tensorizer import TensorizerConfig
30
from vllm.model_executor.models import SupportsLoRA, supports_multimodal
31
from vllm.model_executor.models.interfaces import is_pooling_model
32
from vllm.model_executor.models.module_mapping import MultiModelKeys
33
from vllm.model_executor.models.utils import PPMissingLayer, WeightsMapper
34
35
from vllm.utils import is_pin_memory_available
from vllm.utils.cache import LRUCache
36

37
logger = init_logger(__name__)
38

39
40
41
42
43
44
45
46
T = TypeVar("T")


class AdapterLRUCache(LRUCache[int, T]):
    """LRU cache keyed by integer adapter id that deactivates evicted adapters.

    Whenever an entry is removed from the cache, ``deactivate_fn`` is invoked
    with the removed key before the base-class removal hook runs.
    """

    def __init__(self, capacity: int, deactivate_fn: Callable[[int], object]):
        """
        Args:
            capacity: Capacity forwarded to the underlying ``LRUCache``.
            deactivate_fn: Callback invoked with the adapter id each time an
                entry is removed; its return value is ignored.
        """
        super().__init__(capacity)
        self.deactivate_fn = deactivate_fn

    def _on_remove(self, key: int, value: T | None):
        """Deactivate the adapter ``key`` and then run the base removal hook."""
        logger.debug("Removing adapter int id: %d", key)
        self.deactivate_fn(key)
        return super()._on_remove(key, value)


53
54
55
56
57
58
59
60
61
# Module-wide counter backing get_lora_id(); 0 means no id has been issued.
_GLOBAL_LORA_ID = 0


def get_lora_id():
    """Return the next LoRA id from the module-wide counter (first id is 1)."""
    global _GLOBAL_LORA_ID
    next_id = _GLOBAL_LORA_ID + 1
    _GLOBAL_LORA_ID = next_id
    return next_id


62
class LoRAModel:
    """A LoRA fine-tuned model.

    Bundles the per-module LoRA weights of one adapter with the adapter's
    integer id and rank.
    """

    def __init__(
        self,
        lora_model_id: int,
        rank: int,
        loras: dict[str, LoRALayerWeights],
    ) -> None:
        """
        Args:
            lora_model_id: The integer id for the lora model. Must be > 0.
            rank: lora rank.
            loras: module name -> weights for lora-replaced layers.

        """
        self.id = lora_model_id

        assert lora_model_id > 0, (
            f"a valid lora id should be greater than 0, got {self.id}"
        )
        self.rank = rank
        self.loras: dict[str, LoRALayerWeights] = loras

    def clone(self, lora_model_id: int) -> "LoRAModel":
        """Return a copy of the object with different ids.

        Will share the underlying tensors: only the ``loras`` dict itself is
        (shallow-)copied."""
        return self.__class__(
            lora_model_id,
            rank=self.rank,
            loras=self.loras.copy(),
        )

    @property
    def extra_vocab_size(self) -> int:
        """Largest ``extra_vocab_size`` over all layer weights, 0 if empty."""
        return max(
            (lora.extra_vocab_size for lora in self.loras.values()), default=0
        )

    def get_lora(self, module_name: str) -> LoRALayerWeights | None:
        """Get LoRA for a given module by name"""
        return self.loras.get(module_name)

    def check_lora_name(self, lora_name: str) -> bool:
        """Return whether weights are stored under exactly ``lora_name``."""
        return lora_name in self.loras

    # (yard1): TODO see if we can derive target_embedding_padding automatically
    @classmethod
    def from_lora_tensors(
        cls,
        lora_model_id: int,
        tensors: dict[str, torch.Tensor],
        peft_helper: PEFTHelper,
        device: str = "cuda",
        dtype: torch.dtype | None = None,
        embeddings: dict[str, torch.Tensor] | None = None,
        target_embedding_padding: int | None = None,
        embedding_modules: dict[str, str] | None = None,
        embedding_padding_modules: list[str] | None = None,
        weights_mapper: WeightsMapper | None = None,
    ) -> "LoRAModel":
        """Create a LoRAModel from a dictionary of tensors.

        Args:
            lora_model_id: Id assigned to the created model.
            tensors: Checkpoint tensor name -> tensor. Each name is resolved
                to ``(module_name, is_lora_a)`` by
                ``parse_fine_tuned_lora_name``.
            peft_helper: Loaded lora configuration; supplies the rank and is
                forwarded to ``LoRALayerWeights.from_config``.
            device: Device the tensors are moved to. When it is ``"cpu"`` and
                pinned memory is available, the tensors are also pinned.
            dtype: dtype the tensors are converted to.
            embeddings: Optional extra embedding tensors, keyed by the values
                of ``embedding_modules``.
            target_embedding_padding: If given, ``lora_b`` of every module
                matching ``embedding_padding_modules`` is zero-padded along
                dim 0 up to this size.
            embedding_modules: Module-name fragment -> key into
                ``embeddings``. Must not be None when ``embeddings`` is
                non-empty.
            embedding_padding_modules: Module-name fragments whose ``lora_b``
                may be padded. Must not be None as soon as any ``lora_b``
                tensor is present.
            weights_mapper: Forwarded to ``parse_fine_tuned_lora_name``.

        Returns:
            The assembled LoRAModel; ``optimize()`` has been called on every
            layer's weights.
        """
        pin_memory = str(device) == "cpu" and is_pin_memory_available()
        loras: dict[str, LoRALayerWeights] = {}
        for tensor_name, tensor in tensors.items():
            module_name, is_lora_a = parse_fine_tuned_lora_name(
                tensor_name, weights_mapper
            )
            if module_name not in loras:
                # First tensor seen for this module: create its weight
                # container, attaching the matching extra-embedding tensor
                # (if any).
                lora_embeddings_tensor = None
                if embeddings:
                    assert embedding_modules is not None
                    embeddings_module = next(
                        (k for k in embedding_modules if k in module_name), None
                    )
                    if embeddings_module:
                        lora_embeddings_tensor = embeddings[
                            embedding_modules[embeddings_module]
                        ].to(device=device, dtype=dtype)
                        if pin_memory:
                            lora_embeddings_tensor = lora_embeddings_tensor.pin_memory()
                loras[module_name] = LoRALayerWeights.from_config(
                    module_name, peft_helper, lora_embeddings_tensor
                )

            layer_weights = loras[module_name]
            if is_lora_a:
                lora_a = tensor.to(device=device, dtype=dtype)
                if pin_memory:
                    lora_a = lora_a.pin_memory()
                layer_weights.lora_a = lora_a
            else:
                lora_b = tensor.to(device=device, dtype=dtype)
                assert embedding_padding_modules is not None
                if target_embedding_padding is not None and any(
                    name in module_name for name in embedding_padding_modules
                ):
                    assert target_embedding_padding >= lora_b.shape[0]
                    addition = target_embedding_padding - lora_b.shape[0]
                    # Append `addition` zero rows along dim 0 (pad spec is
                    # given last-dim first).
                    lora_b = torch.nn.functional.pad(lora_b, (0, 0, 0, addition))
                # Pin only after padding so the final tensor is the pinned one.
                if pin_memory:
                    lora_b = lora_b.pin_memory()
                layer_weights.lora_b = lora_b

        for lora in loras.values():
            lora.optimize()

        return cls(lora_model_id, peft_helper.r, loras)

    @classmethod
    def from_local_checkpoint(
        cls,
        lora_dir: str,
        expected_lora_modules: list[str],
        peft_helper: PEFTHelper,
        *,
        lora_model_id: int | None = None,
        device: str = "cuda",
        dtype: torch.dtype | None = None,
        target_embedding_padding: int | None = None,
        embedding_modules: dict[str, str] | None = None,
        embedding_padding_modules: list[str] | None = None,
        weights_mapper: WeightsMapper | None = None,
        tensorizer_config_dict: dict | None = None,
    ) -> "LoRAModel":
        """Create a LoRAModel from a local checkpoint.

        The adapter weights are read from the first source that applies:
        a tensorizer archive (when ``tensorizer_config_dict`` is given),
        ``adapter_model.safetensors``, then ``adapter_model.bin`` /
        ``adapter_model.pt``. Extra embeddings are read from
        ``new_embeddings.safetensors`` or ``new_embeddings.bin`` if present.

        Args:
            lora_dir: The local path that has lora data.
            expected_lora_modules: Name of modules that are expected to be
                replaced by lora.
            peft_helper: Loaded lora configuration information.
            lora_model_id: LoRA model id. If not given, automatically set by
                a global counter.
            device: Device where the lora model is loaded.
            dtype: dtype of the lora model weights.
            target_embedding_padding: Forwarded to ``from_lora_tensors``.
            embedding_modules: Forwarded to ``from_lora_tensors``.
            embedding_padding_modules: Forwarded to ``from_lora_tensors``.
            weights_mapper: Used when parsing tensor names and forwarded to
                ``from_lora_tensors``.
            tensorizer_config_dict: Keyword arguments for ``TensorizerConfig``.
                When truthy, weights are deserialized from
                ``<tensorizer_dir>/adapter_model.tensors``.

        Returns:
            Loaded LoRA Model.

        Raises:
            ValueError: If the checkpoint targets modules outside
                ``expected_lora_modules``, or if no weights file is found.
        """
        lora_tensor_path = os.path.join(lora_dir, "adapter_model.safetensors")
        lora_bin_file_path = os.path.join(lora_dir, "adapter_model.bin")
        lora_pt_file_path = os.path.join(lora_dir, "adapter_model.pt")
        new_embeddings_tensor_path = os.path.join(
            lora_dir, "new_embeddings.safetensors"
        )
        new_embeddings_bin_file_path = os.path.join(lora_dir, "new_embeddings.bin")
        tensors: dict[str, torch.Tensor] = {}
        unexpected_modules: list[list[str] | str] = []

        def check_unexpected_modules(modules: dict):
            # Validates the tensor names actually stored in the checkpoint
            # against expected_lora_modules.
            for lora_module in modules.keys():  # noqa
                module_name, _ = parse_fine_tuned_lora_name(lora_module, weights_mapper)
                # Handle FSDP file format where experts.base_layer is the
                # gate_up_proj and experts is the down_proj
                if "base_layer" in lora_module:
                    continue
                # Case for expert lora weights
                if ".experts" in module_name:
                    if not any(
                        module_name.endswith(ele) for ele in expected_lora_modules
                    ):
                        unexpected_modules.append(module_name)
                elif module_name.split(".")[-1] not in expected_lora_modules:
                    unexpected_modules.append(module_name)

            if unexpected_modules:
                raise ValueError(
                    f"While loading {lora_dir}, expected"
                    f" target modules in {expected_lora_modules}"
                    f" but received {unexpected_modules}."
                    f" Please verify that the loaded LoRA module is correct"
                )

        if tensorizer_config_dict:
            # Imported lazily so tensorizer is only required on this path.
            from tensorizer import TensorDeserializer

            tensorizer_config = TensorizerConfig(**tensorizer_config_dict)
            lora_tensor_path = os.path.join(
                tensorizer_config.tensorizer_dir, "adapter_model.tensors"
            )
            tensorizer_args = tensorizer_config._construct_tensorizer_args()
            tensors = TensorDeserializer(
                lora_tensor_path,
                dtype=tensorizer_config.dtype,
                **tensorizer_args.deserialization_kwargs,
            )
            check_unexpected_modules(tensors)

        elif os.path.isfile(lora_tensor_path):
            # Find unexpected modules.
            # Use safetensor key as a source of truth to find expected modules.
            # in peft if you have target_modules A, B, C and C does not exist
            # in the model it won't error and model will be trained with A, B
            # loraified. C won't exist in the safetensor but it will exist in
            # the target_modules of the adapter_config.json.
            unexpected_modules = []
            with safetensors.safe_open(lora_tensor_path, framework="pt") as f:  # type: ignore
                # Load tensors if there are only expected modules.
                check_unexpected_modules(f)
                for module in f.keys():  # noqa
                    tensors[module] = f.get_tensor(module)
        elif os.path.isfile(lora_bin_file_path) or os.path.isfile(lora_pt_file_path):
            # When a bin/pt file is provided, we rely on config to find
            # unexpected modules.
            unexpected_modules = []
            target_modules = peft_helper.target_modules
            if not isinstance(target_modules, list):
                target_modules = [target_modules]
            for module in target_modules:
                # Compatible with more modules,
                # such as:layers.11.self_attn.k_proj
                part_name = module.split(".")[-1]
                if part_name not in expected_lora_modules:
                    unexpected_modules.append(module)
            # loaded lora's target modules must be a subset of
            # expected_lora_modules. It is not reliable. See
            # https://github.com/vllm-project/vllm/pull/5909. But there's no
            # other better mechanism.
            if unexpected_modules and not is_regex_target_modules(
                peft_helper.target_modules, expected_lora_modules
            ):
                raise ValueError(
                    f"While loading {lora_dir}, expected"
                    f" target modules in {expected_lora_modules}"
                    f" but received {unexpected_modules}."
                    f" Please verify that the loaded LoRA module is correct"
                )
            # Prefer the .bin file when both .bin and .pt exist.
            lora_file_path = (
                lora_bin_file_path
                if os.path.isfile(lora_bin_file_path)
                else lora_pt_file_path
            )
            tensors = torch.load(lora_file_path, map_location=device, weights_only=True)
        else:
            raise ValueError(f"{lora_dir} doesn't contain tensors")

        embeddings = None
        if os.path.isfile(new_embeddings_tensor_path):
            embeddings = safetensors.torch.load_file(new_embeddings_tensor_path)
        elif os.path.isfile(new_embeddings_bin_file_path):
            embeddings = torch.load(
                new_embeddings_bin_file_path, map_location=device, weights_only=True
            )

        return cls.from_lora_tensors(
            lora_model_id=get_lora_id() if lora_model_id is None else lora_model_id,
            tensors=tensors,
            peft_helper=peft_helper,
            device=device,
            dtype=dtype,
            embeddings=embeddings,
            target_embedding_padding=target_embedding_padding,
            embedding_modules=embedding_modules,
            embedding_padding_modules=embedding_padding_modules,
            weights_mapper=weights_mapper,
        )
323
324


325
class LoRAModelManager:
    """A manager that manages multiple LoRA-fine-tuned models."""

    def __init__(
        self,
        model: SupportsLoRA,
        max_num_seqs: int,
        max_num_batched_tokens: int,
        vocab_size: int,
        lora_config: LoRAConfig,
        device: torch.device,
    ):
        """Create a LoRAModelManager and adapter for a given model.

        Replaces the supported submodules of ``model`` with LoRA-capable
        layers and stores this manager on ``model.lora_manager``.

        Args:
            model: the model to be adapted.
            max_num_seqs: the maximum number of sequences model can run in a
                single batch.
            max_num_batched_tokens: the maximum number of tokens model can run
                in a single batch.
            vocab_size: the vocab size of the model.
            lora_config: the LoRA configuration.
            device: the device passed to the punica wrapper.
        """
        self.model: SupportsLoRA = model
        self._registered_adapters: dict[int, LoRAModel] = {}
        # Dict instead of a set for compatibility with LRUCache.
        self._active_adapters: dict[int, None] = {}
        self.adapter_type = "LoRA"
        self.lora_config = lora_config
        self.device = device
        self.max_num_seqs = max_num_seqs
        assert self.capacity >= self.lora_slots
        # Stored value is rounded up to a multiple of 8; the punica wrapper
        # below receives the unrounded argument.
        self.max_num_batched_tokens = math.ceil(max_num_batched_tokens / 8) * 8
        self.lora_index_to_id: list[int | None] = [None] * self.lora_slots
        self.vocab_size = vocab_size
        self.punica_wrapper = get_punica_wrapper(
            max_num_batched_tokens,
            max_batches=self.max_num_seqs,
            device=self.device,
            max_loras=self.lora_config.max_loras,
        )

        self.supported_lora_modules = get_supported_lora_modules(self.model)
        # The message must be one parenthesized expression; otherwise the
        # second literal is a separate no-op statement and the text is cut.
        assert self.supported_lora_modules, (
            "No supported LoRA modules found in"
            f" {self.model.__class__.__name__}."
        )

        self.packed_modules_mapping = process_packed_modules_mapping(self.model)
        # Used to indicate whether the model is a multimodal model
        self.supports_mm: bool = (
            supports_multimodal(self.model)
            # In case the model only supports LoRA for
            # text modules (e.g. ChatGLM)
            and hasattr(self.model, "get_mm_mapping")
        )
        self.is_pooling_model = is_pooling_model(self.model)
        self.packed_modules: dict[str, list[str]] = {}
        self.modules: dict[str, BaseLayerWithLoRA] = {}
        self._last_mapping: LoRAMapping | None = None
        self._create_lora_modules()
        self.model.lora_manager = self

    def __len__(self) -> int:
        """Number of registered adapters."""
        return len(self._registered_adapters)

    @property
    def capacity(self) -> int:
        """Maximum number of registered adapters (``max_cpu_loras``)."""
        return self.lora_config.max_cpu_loras

    @property
    def lora_slots(self) -> int:
        """Number of slots for simultaneously active adapters (``max_loras``)."""
        return self.lora_config.max_loras

    @property
    def adapter_slots(self) -> int:
        """Alias of ``lora_slots``."""
        return self.lora_slots

    def activate_adapter(
        self,
        lora_id: int,
    ) -> bool:
        """Move LoRA into a GPU buffer to be used in the forward pass.

        Args:
            lora_id: Id of a registered adapter.

        Returns:
            False if the adapter is already active, True once its weights
            have been written into a free slot of every managed module.

        Raises:
            ValueError: If no slot is free.
            KeyError: If ``lora_id`` is not registered (with the plain dict
                registry used by this class).
        """
        if lora_id in self._active_adapters:
            return False
        index = next(
            (
                slot
                for slot, slot_lora_id in enumerate(self.lora_index_to_id)
                if slot_lora_id is None
            ),
            None,
        )
        if index is None:
            raise ValueError("No free lora slots")
        # Resolve the adapter before marking it active, so that a failed
        # lookup does not leave a stale id in _active_adapters.
        lora_model = self._registered_adapters[lora_id]
        self._active_adapters[lora_id] = None
        logger.debug(
            "Activating LoRA. int id: %d, slot index: %d", lora_model.id, index
        )
        self.lora_index_to_id[index] = lora_model.id
        for module_name, module in self.modules.items():
            module_lora = self._get_lora_layer_weights(lora_model, module_name)
            if module_lora:
                # Note (gnovack) - If MOE lora weights are not split into
                # num_experts chunks, we split them here
                if isinstance(module, FusedMoEWithLoRA) and torch.is_tensor(
                    module_lora.lora_a
                ):
                    # Handle FSDP file format where experts.base_layer is the
                    # gate_up_proj and experts is the down_proj
                    gate_up_proj_lora = self._get_lora_layer_weights(
                        lora_model, module_name + ".base_layer"
                    )

                    assert gate_up_proj_lora is not None
                    assert module_lora is not None

                    down_proj_lora = module_lora
                    num_experts = module_lora.lora_a.shape[0] // module_lora.rank

                    gate_proj_a = gate_up_proj_lora.lora_a.chunk(num_experts, dim=0)
                    up_proj_a = gate_up_proj_lora.lora_a.chunk(num_experts, dim=0)

                    gate_proj_b = gate_up_proj_lora.lora_b[::2, ...].chunk(
                        num_experts, dim=-1
                    )
                    up_proj_b = gate_up_proj_lora.lora_b[1::2, ...].chunk(
                        num_experts, dim=-1
                    )

                    down_proj_a = down_proj_lora.lora_a.chunk(num_experts, dim=0)
                    down_proj_b = down_proj_lora.lora_b.chunk(num_experts, dim=-1)

                    # Per expert, entries are appended in the order
                    # gate, down, up.
                    lora_a = []
                    lora_b = []
                    for i in range(num_experts):
                        lora_a.append(gate_proj_a[i])
                        lora_a.append(down_proj_a[i])
                        lora_a.append(up_proj_a[i])

                        lora_b.append(gate_proj_b[i])
                        lora_b.append(down_proj_b[i])
                        lora_b.append(up_proj_b[i])

                    # The split lists replace the tensors on the stored
                    # weights, so the is_tensor() check above is False on a
                    # later activation and the split is not repeated.
                    module_lora.lora_a = lora_a
                    module_lora.lora_b = lora_b

                module.set_lora(
                    index,
                    module_lora.lora_a,
                    module_lora.lora_b,
                    module_lora.embeddings_tensor,
                )
            else:
                module.reset_lora(index)
        return True

    def _deactivate_adapter(self, lora_id: int):
        """Free the slot held by ``lora_id``; no-op if it holds none."""
        try:
            index = self.lora_index_to_id.index(lora_id)
        except ValueError:
            # The adapter does not occupy a slot: nothing to free.
            return
        self.lora_index_to_id[index] = None

    def _add_adapter(self, lora: LoRAModel):
        """Merge packed sub-module weights of ``lora`` and register it."""
        self._create_merged_loras_inplace(lora)
        self._registered_adapters[lora.id] = lora

    def pin_adapter(self, lora_id: int) -> bool:
        """Pin a LoRAModel in the manager cache."""
        raise NotImplementedError(
            "Pinning is not supported in LoRAModelManager. "
            "Use LRUCacheLoRAModelManager for pinning"
        )  # type: ignore

    def _set_adapter_mapping(self, mapping: LoRAMapping) -> None:
        """Push ``mapping`` and the current slot table to the punica wrapper."""
        # update lora states
        self.punica_wrapper.update_metadata(
            mapping,
            self.lora_index_to_id,
            self.lora_slots + 1,
            self.vocab_size,
            self.lora_config.lora_extra_vocab_size,
        )

    def remove_all_adapters(self):
        """Remove all LoRAModels from the manager."""
        self._registered_adapters.clear()
        self.lora_index_to_id = [None] * self.lora_slots
        self._active_adapters.clear()

    def _create_lora_modules(self):
        """Swap every supported submodule of the model for its LoRA layer.

        Each replacement is registered in ``self.modules`` and wired to the
        shared punica wrapper.
        """

        def _parent_module(module_name: str) -> str:
            # module name is a dot separated name.
            # for example:
            #  - given an input 'x.y.z' return 'x.y'
            #  - given an input 'x' return ''
            return module_name.rpartition(".")[0]

        for module_name, module in self.model.named_modules(remove_duplicate=False):
            if isinstance(module, PPMissingLayer):
                continue

            if not self._match_target_modules(module_name):
                continue
            # A temporary approach for multimodal models to support LoRA
            # TODO: Remove this restriction
            if self._filter_unsupported_mm_module(module_name):
                logger.warning(
                    "Regarding multimodal models, vLLM currently only supports "
                    "adding LoRA to language model, %s will be ignored.",
                    module_name,
                )
                continue
            leaf_name = module_name.split(".")[-1]
            packed_moduled_lst = self.packed_modules_mapping.get(leaf_name, [])
            new_module = replace_submodule(
                self.model,
                module_name,
                from_layer(
                    module,
                    self.lora_slots,
                    self.lora_config,
                    packed_moduled_lst,
                    self.model.config,
                ),
            )

            # (yard1): TODO make this more robust
            if "lm_head" in module_name:
                logits_processor_module_name = "logits_processor"
                parent_module = _parent_module(module_name)
                if parent_module:
                    logits_processor_module_name = (
                        f"{parent_module}.{logits_processor_module_name}"
                    )

                logits_processor_module = self.model.get_submodule(
                    logits_processor_module_name
                )

                # new_module is rebound here, so the wrapped logits processor
                # is what gets registered under the lm_head module name below.
                new_module = replace_submodule(
                    self.model,
                    logits_processor_module_name,
                    from_layer_logits_processor(
                        logits_processor_module,
                        module,
                        self.lora_slots,
                        self.lora_config,
                        self.model.config,
                    ),
                )

            # In some models, especially multimodal ones, layers with the same
            # name may have different types, such as nn.Linear and
            # ReplicatedLinear. The nn.Linear layers cannot be replaced with
            # LoRA layers, leading to assertion error. The following check
            # aims to prevent this error
            if self.supports_mm and not isinstance(new_module, BaseLayerWithLoRA):
                continue
            self.register_module(module_name, new_module)
            self._register_packed_modules(module_name)
            # All lora layers share the same punica_wrapper based on reference.
            new_module.set_mapping(self.punica_wrapper)

    def register_module(self, module_name: str, module: "BaseLayerWithLoRA"):
        """Record ``module`` under ``module_name`` in ``self.modules``.

        ``module`` is asserted to be a ``BaseLayerWithLoRA`` instance.
        """
        # Both message parts live inside the parentheses so that the actual
        # type is reported when the assertion fails.
        assert isinstance(module, BaseLayerWithLoRA), (
            f"Module {module_name} must be a BaseLayerWithLoRA instance,"
            f" got {type(module)}"
        )
        self.modules[module_name] = module

    def create_dummy_lora(
        self,
        lora_id: int,
        rank: int,
        embedding_modules: dict[str, str] | None = None,
    ) -> LoRAModel:
        """Create zero-initialized LoRAModel for warmup.

        Args:
            lora_id: Id of the dummy model.
            rank: LoRA rank of the dummy weights.
            embedding_modules: Module leaf names treated as embedding
                modules. Must not be None as soon as a non-packed LoRA
                module is encountered.

        Returns:
            A LoRAModel with dummy weights (created on "cpu") for every
            LoRA-wrapped target module of the model.
        """
        model = LoRAModel(lora_id, rank, {})
        for module_name, module in self.model.named_modules():
            if (
                not self._match_target_modules(module_name)
                or not isinstance(module, BaseLayerWithLoRA)
                or self._filter_unsupported_mm_module(module_name)
            ):
                continue
            parts = module_name.split(".")
            if module_name not in self.packed_modules:
                assert embedding_modules is not None
                if parts[-1] in embedding_modules:
                    # Embedding-like module: dimensions are taken from the
                    # base layer's attributes, falling back to its weight
                    # shape when they are absent.
                    input_dim = (
                        module.base_layer.org_vocab_size
                        + self.lora_config.lora_extra_vocab_size
                        if hasattr(module.base_layer, "org_vocab_size")
                        else module.base_layer.weight.shape[1]
                    )
                    output_dim = (
                        module.base_layer.embedding_dim
                        if hasattr(module.base_layer, "embedding_dim")
                        else module.base_layer.weight.shape[0]
                    )
                    embeddings_tensor_dim = (
                        module.base_layer.embedding_dim
                        if hasattr(module.base_layer, "embedding_dim")
                        else module.base_layer.weight.shape[1]
                    )
                    lora = LoRALayerWeights.create_dummy_lora_weights(
                        module_name,
                        input_dim,
                        output_dim,
                        rank,
                        module.lora_a_stacked[0].dtype,
                        "cpu",
                        embeddings_tensor_dim=embeddings_tensor_dim,
                    )
                else:
                    lora = LoRALayerWeights.create_dummy_lora_weights(
                        module_name,
                        module.lora_a_stacked[0].shape[-1],
                        module.lora_b_stacked[0].shape[-2],
                        rank,
                        module.lora_a_stacked[0].dtype,
                        "cpu",
                    )
            else:
                # Packed module: one dummy per packed sub-module, then pack.
                replacements = self.packed_modules_mapping[parts[-1]]
                subloras: list[LoRALayerWeights | None] = []
                for i, r in enumerate(replacements):
                    lora = LoRALayerWeights.create_dummy_lora_weights(
                        module_name + "." + r,
                        module.lora_a_stacked[i].shape[-1],
                        module.lora_b_stacked[i].shape[-2],
                        rank,
                        module.lora_a_stacked[i].dtype,
                        "cpu",
                    )
                    subloras.append(lora)
                lora = PackedLoRALayerWeights.pack(subloras)
            model.loras[module_name] = lora
        return model

    def _match_target_modules(self, module_name: str):
        """Return whether ``module_name`` is, or ends in ``.<name>`` of, a
        supported LoRA module."""
        return any(
            re.match(
                r".*\.{target_module}$".format(target_module=target_module), module_name
            )
            or target_module == module_name
            for target_module in self.supported_lora_modules
        )

    def _filter_unsupported_mm_module(self, module_name: str) -> bool:
        """
        Regarding multimodal models, vLLM currently only supports adding LoRA to
        language model. LoRA for other modules, such as the vision tower, will
        be filtered out.

        Returns True when ``module_name`` starts with one of the model's
        connector or tower-model prefixes; always False for models that are
        not multimodal.
        """
        if self.supports_mm:
            module_mapping: MultiModelKeys = self.model.get_mm_mapping()
            prefix_lst = module_mapping.connector + module_mapping.tower_model
            return any(module_name.startswith(prefix) for prefix in prefix_lst)
        return False

    def _register_packed_modules(self, module_full_name: str) -> None:
        """Record the fully-qualified sub-module names of a packed module."""
        parts = module_full_name.split(".")
        module_name = parts[-1]
        replacements = self.packed_modules_mapping.get(module_name, [])
        # When replacements is less than or equal to 1, it indicates that this
        # module is not a packed module.
        if len(replacements) <= 1:
            return
        prefix = ".".join(parts[:-1])
        self.packed_modules[module_full_name] = [
            prefix + "." + r if prefix else r for r in replacements
        ]

    def _create_merged_loras_inplace(self, lora_model: LoRAModel) -> None:
        """Pack per-sub-module weights of ``lora_model`` into packed entries.

        For every packed module with at least one sub-module weight present,
        a ``PackedLoRALayerWeights`` is stored under the packed module's name
        (missing sub-modules contribute ``None``) and the consumed sub-module
        entries are popped from ``lora_model.loras``.
        """
        for module_name, new_module_names in self.packed_modules.items():
            replacement_loras: list[LoRALayerWeights | None] = []
            replaced_module: set[str] = set()
            for sub_name in new_module_names:
                lora = self._get_lora_layer_weights(lora_model, sub_name)
                if lora:
                    replacement_loras.append(lora)
                    replaced_module.add(sub_name)
                else:
                    # Keep the position so the packed order matches
                    # new_module_names.
                    replacement_loras.append(None)
            if not replaced_module:
                continue
            # HACK Temporary solution for the pool model.
            if self.is_pooling_model and not lora_model.check_lora_name(module_name):
                replaced_module_name = module_name.replace("model.", "")
                # NOTE(review): the condition below re-tests `module_name`,
                # which the enclosing `if` has just established is absent, so
                # this rename never fires. Possibly `replaced_module_name`
                # was intended - confirm before changing.
                if lora_model.check_lora_name(module_name):
                    module_name = replaced_module_name
            lora_model.loras[module_name] = PackedLoRALayerWeights.pack(
                replacement_loras
            )
            # Remove the modules that have been replaced.
            for module in replaced_module:
                lora_model.loras.pop(module, None)

    def _get_lora_layer_weights(
        self, lora_model: LoRAModel, module_name: str
    ) -> LoRALayerWeights | None:
        """Look up the weights of ``module_name`` in ``lora_model``.

        For pooling models a miss is retried with every ``"model."``
        occurrence removed from the name.
        """
        if self.is_pooling_model and not lora_model.check_lora_name(module_name):
            # If it's a pool model, and the layer name is not found,
            # remove the prefix 'model.' and search again.
            stripped_name = module_name.replace("model.", "")
            if lora_model.check_lora_name(stripped_name):
                logger.info_once(
                    "For the pool model, successfully loaded the LoRA weights "
                    "after removing the prefix 'model.'."
                )
                return lora_model.get_lora(stripped_name)
        return lora_model.get_lora(module_name)

    def deactivate_adapter(self, adapter_id: int) -> bool:
        """Deactivate an active adapter; returns False if it was not active."""
        if adapter_id not in self._active_adapters:
            return False
        self._deactivate_adapter(adapter_id)
        self._active_adapters.pop(adapter_id, None)
        return True

    def add_adapter(self, adapter: LoRAModel) -> bool:
        """Register ``adapter``; returns False if its id is already registered.

        Raises:
            RuntimeError: If the manager already holds ``capacity`` adapters.
        """
        logger.debug("Adding lora. Model id: %d, int id: %d", adapter.id, adapter.id)
        if adapter.id in self._registered_adapters:
            return False
        if len(self._registered_adapters) >= self.capacity:
            raise RuntimeError("No free adapter slots.")
        self._add_adapter(adapter)
        return True

    def set_adapter_mapping(self, mapping: LoRAMapping) -> None:
        """Apply ``mapping`` unless it equals the last applied mapping."""
        if self._last_mapping != mapping:
            self._set_adapter_mapping(mapping)
            self._last_mapping = mapping

    def remove_adapter(self, adapter_id: int) -> bool:
        """Deactivate and unregister an adapter; False if not registered."""
        self.deactivate_adapter(adapter_id)
        if adapter_id not in self._registered_adapters:
            return False
        self._registered_adapters.pop(adapter_id, None)
        return True

    def list_adapters(self) -> dict[int, LoRAModel]:
        """Return a copy of the id -> LoRAModel registry."""
        return dict(self._registered_adapters)

    def get_adapter(self, adapter_id: int) -> LoRAModel | None:
        """Return the registered adapter with this id, or None."""
        return self._registered_adapters.get(adapter_id)
781
782
783


class LoRALRUCache(AdapterLRUCache[LoRAModel]):
    """AdapterLRUCache specialised to LoRAModel values."""

    def __init__(self, capacity: int, deactivate_lora_fn: Callable[[int], bool]):
        """Create the cache.

        Args:
            capacity: Capacity forwarded to the parent cache.
            deactivate_lora_fn: Callback forwarded to the parent cache as
                its deactivation function; it receives an integer LoRA id.
        """
        super().__init__(capacity=capacity, deactivate_fn=deactivate_lora_fn)
786
787
788
789
790


class LRUCacheLoRAModelManager(LoRAModelManager):
    """LoRA model manager whose registered and active sets are LRU caches."""

    def __init__(
        self,
        model: nn.Module,
        max_num_seqs: int,
        max_num_batched_tokens: int,
        vocab_size: int,
        lora_config: LoRAConfig,
        device: torch.device,
    ):
        """Initialise the base manager, then install the two LRU caches.

        All arguments are forwarded unchanged to the parent constructor.
        """
        super().__init__(
            model, max_num_seqs, max_num_batched_tokens, vocab_size, lora_config, device
        )
        # Removing an entry from the registered cache goes through the public
        # deactivate_adapter(); removing one from the active cache goes
        # through the private _deactivate_adapter().
        self._registered_adapters: LoRALRUCache = LoRALRUCache(
            capacity=self.capacity,
            deactivate_lora_fn=self.deactivate_adapter,
        )
        self._active_adapters: LoRALRUCache = LoRALRUCache(
            capacity=self.lora_slots,
            deactivate_lora_fn=self._deactivate_adapter,
        )

    def list_adapters(self) -> dict[int, LoRAModel]:
        """Return a new dict built from the registered cache's contents."""
        registered_cache = self._registered_adapters.cache
        return dict(registered_cache)

    def add_adapter(self, lora: LoRAModel) -> bool:
        """Register a LoRAModel, or refresh its LRU position if present.

        Returns:
            True if the LoRA was newly added; False if it was already
            registered (in which case it is only touched).
        """
        logger.debug("Adding lora. Model id: %d, int id: %d", lora.id, lora.id)
        if lora.id in self._registered_adapters:
            # Already known: just bump it in the LRU order.
            self._registered_adapters.touch(lora.id)
            return False
        self._add_adapter(lora)
        return True

    def activate_adapter(
        self,
        lora_id: int,
    ) -> bool:
        """Activate a LoRA, evicting the oldest active one if slots are full.

        Returns:
            Whatever the parent class's ``activate_adapter`` returns.
        """
        if lora_id not in self._active_adapters:
            # Only a LoRA that is not yet active can require freeing a slot.
            if len(self._active_adapters) >= self.lora_slots:
                self._active_adapters.remove_oldest()
        activation_result = super().activate_adapter(lora_id)
        # Always touch so the LRU order reflects this use.
        self._active_adapters.touch(lora_id)
        return activation_result

    def remove_oldest_adapter(self) -> bool:
        """Remove the oldest registered LoRA.

        Returns:
            True if an entry was removed; False if nothing was registered.
        """
        if len(self._registered_adapters) <= 0:
            return False
        self._registered_adapters.remove_oldest()
        return True

    def pin_adapter(self, lora_id: int) -> bool:
        """Pin a LoRA in the registered cache, then in the active cache.

        Returns:
            True once both pin steps have completed.

        Raises:
            ValueError: If pinning in the registered cache fails.
        """
        self._pin_lora_in_cpu_cache(lora_id)
        self._pin_lora_in_gpu_cache(lora_id)
        return True

    def _pin_lora_in_cpu_cache(self, lora_id: int):
        """Pin ``lora_id`` in the registered cache.

        A ValueError from the cache is re-raised as a ValueError with a
        pinning-specific message, chained to the original.
        """
        try:
            self._registered_adapters.pin(lora_id)
        except ValueError as err:
            message = f"Pinning failed. LoRA {lora_id} is not registered."
            raise ValueError(message) from err

    def _pin_lora_in_gpu_cache(self, lora_id: int):
        """Pin ``lora_id`` in the active cache, activating it first if needed."""
        already_active = lora_id in self._active_adapters
        if not already_active:
            # Activate first so there is an active-cache entry to pin.
            self.activate_adapter(lora_id)

        self._active_adapters.pin(lora_id)
866

867
868

def create_lora_manager(
    model: nn.Module,
    max_num_seqs: int,
    max_num_batched_tokens: int,
    vocab_size: int,
    lora_config: LoRAConfig,
    device: torch.device,
    lora_manager_cls: type[LoRAModelManager] = LoRAModelManager,
    **kwargs,
) -> LoRAModelManager:
    """Create a LoRA model manager for a given model.

    Args:
        model: The model to manage; must be an instance of SupportsLoRA.
        max_num_seqs: Forwarded to the manager constructor.
        max_num_batched_tokens: Forwarded to the manager constructor.
        vocab_size: Forwarded to the manager constructor.
        lora_config: Forwarded to the manager constructor.
        device: Forwarded to the manager constructor.
        lora_manager_cls: Manager class to instantiate.
        **kwargs: Extra keyword arguments forwarded to the manager
            constructor.

    Returns:
        The newly constructed manager instance.

    Raises:
        ValueError: If ``model`` is not an instance of SupportsLoRA.
    """
    if isinstance(model, SupportsLoRA):
        return lora_manager_cls(
            model=model,
            max_num_seqs=max_num_seqs,
            max_num_batched_tokens=max_num_batched_tokens,
            vocab_size=vocab_size,
            lora_config=lora_config,
            device=device,
            **kwargs,
        )
    raise ValueError(f"Model {type(model)} is not supported for LoRA.")