models.py 34.8 KB
Newer Older
1
# SPDX-License-Identifier: Apache-2.0
2
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
3

4
5
import math
import os
6
7
from collections.abc import Callable
from typing import TypeVar
8

9
import regex as re
10
11
12
13
import safetensors.torch
import torch
from torch import nn

14
from vllm.config.lora import LoRAConfig
15
from vllm.logger import init_logger
16
from vllm.lora.layers import BaseLayerWithLoRA, FusedMoEWithLoRA, LoRAMapping
17
from vllm.lora.lora_weights import LoRALayerWeights, PackedLoRALayerWeights
18
from vllm.lora.peft_helper import PEFTHelper
19
from vllm.lora.punica_wrapper import get_punica_wrapper
20
21
22
23
24
25
from vllm.lora.utils import (
    from_layer,
    from_layer_logits_processor,
    get_supported_lora_modules,
    is_regex_target_modules,
    parse_fine_tuned_lora_name,
26
    process_packed_modules_mapping,
27
28
    replace_submodule,
)
29
from vllm.model_executor.model_loader.tensorizer import TensorizerConfig
30
from vllm.model_executor.models import SupportsLoRA, supports_multimodal
31
from vllm.model_executor.models.interfaces import is_pooling_model
32
from vllm.model_executor.models.module_mapping import MultiModelKeys
33
from vllm.model_executor.models.utils import PPMissingLayer, WeightsMapper
34
35
from vllm.utils import is_pin_memory_available
from vllm.utils.cache import LRUCache
36

37
logger = init_logger(__name__)
38

39
40
41
42
43
44
45
46
T = TypeVar("T")


class AdapterLRUCache(LRUCache[int, T]):
    """LRU cache keyed by adapter int id that deactivates adapters on removal.

    Whenever the base ``LRUCache`` routes a removal through ``_on_remove``,
    ``deactivate_fn`` is called with the removed adapter's id before the base
    class performs its own bookkeeping.
    """

    def __init__(self, capacity: int, deactivate_fn: Callable[[int], object]):
        super().__init__(capacity)
        # Callback invoked with the adapter id on every removal.
        self.deactivate_fn = deactivate_fn

    def _on_remove(self, key: int, value: T | None):
        logger.debug("Removing adapter int id: %d", key)
        callback = self.deactivate_fn
        callback(key)
        parent_result = super()._on_remove(key, value)
        return parent_result


53
54
55
56
57
58
59
60
61
_GLOBAL_LORA_ID = 0


def get_lora_id():
    """Return the next value of the module-level LoRA id counter.

    The counter starts at 0, so the first id handed out is 1.
    """
    global _GLOBAL_LORA_ID
    next_id = _GLOBAL_LORA_ID + 1
    _GLOBAL_LORA_ID = next_id
    return next_id


62
class LoRAModel:
    """A LoRA fine-tuned model.

    Holds the LoRA weights of a single adapter, keyed by the name of the
    module each set of weights applies to.
    """

    def __init__(
        self,
        lora_model_id: int,
        rank: int,
        loras: dict[str, LoRALayerWeights],
    ) -> None:
        """
        Args:
            lora_model_id: The integer id for the lora model; must be > 0.
            rank: lora rank.
            loras: module name -> weights for lora-replaced layers.
        """
        self.id = lora_model_id

        assert lora_model_id > 0, (
            f"a valid lora id should be greater than 0, got {self.id}"
        )
        self.rank = rank
        self.loras: dict[str, LoRALayerWeights] = loras

    def clone(self, lora_model_id: int) -> "LoRAModel":
        """Return a copy of the object with a different id.

        Only the ``loras`` dict is copied (shallowly), so the underlying
        tensors are shared with the original."""
        return self.__class__(
            lora_model_id,
            rank=self.rank,
            loras=self.loras.copy(),
        )

    @property
    def extra_vocab_size(self) -> int:
        """Largest ``extra_vocab_size`` over all module LoRAs (0 if none)."""
        return max(
            (lora.extra_vocab_size for lora in self.loras.values()), default=0
        )

    def get_lora(self, module_name: str) -> LoRALayerWeights | None:
        """Get LoRA for a given module by name, or None if it has none."""
        return self.loras.get(module_name, None)

    def check_lora_name(self, lora_name: str) -> bool:
        """Return True if weights are stored under ``lora_name``."""
        return lora_name in self.loras

    # (yard1): TODO see if we can derive target_embedding_padding automatically
    @classmethod
    def from_lora_tensors(
        cls,
        lora_model_id: int,
        tensors: dict[str, torch.Tensor],
        peft_helper: PEFTHelper,
        device: str = "cuda",
        dtype: torch.dtype | None = None,
        embeddings: dict[str, torch.Tensor] | None = None,
        target_embedding_padding: int | None = None,
        embedding_modules: dict[str, str] | None = None,
        embedding_padding_modules: list[str] | None = None,
        weights_mapper: WeightsMapper | None = None,
    ) -> "LoRAModel":
        """Create a LoRAModel from a dictionary of tensors.

        Args:
            lora_model_id: The integer id for the lora model.
            tensors: checkpoint tensor name -> tensor. Each name is split by
                ``parse_fine_tuned_lora_name`` into a module name and a flag
                telling whether it is the A (otherwise B) matrix.
            peft_helper: Loaded lora configuration; ``peft_helper.r`` becomes
                the rank of the returned model.
            device: Device the tensors are moved to. When it is "cpu" and
                pinned memory is available, tensors are pinned.
            dtype: dtype the tensors are cast to (``None`` keeps the dtype).
            embeddings: Optional extra-embedding tensors, looked up by the
                values of ``embedding_modules``.
            target_embedding_padding: If set, B matrices of modules whose name
                contains an entry of ``embedding_padding_modules`` are
                zero-padded along dim 0 up to this size.
            embedding_modules: Module-name substring -> key in ``embeddings``.
                Required when ``embeddings`` is non-empty.
            embedding_padding_modules: Module-name substrings whose B matrices
                may be padded. ``None`` means no module is padded.
            weights_mapper: Optional mapper used when parsing tensor names.

        Returns:
            The assembled LoRAModel, with ``optimize()`` called on every
            module LoRA.
        """
        pin_memory = str(device) == "cpu" and is_pin_memory_available()
        # ``None`` (the default) used to trip an assert as soon as any B matrix
        # was seen; treat it as "no module needs padding" instead.
        padding_modules = embedding_padding_modules or []
        loras: dict[str, LoRALayerWeights] = {}
        for tensor_name, tensor in tensors.items():
            module_name, is_lora_a = parse_fine_tuned_lora_name(
                tensor_name, weights_mapper
            )
            if module_name not in loras:
                lora_embeddings_tensor = None
                if embeddings:
                    assert embedding_modules is not None
                    # First embedding module whose name is a substring of
                    # this module's name, if any.
                    embeddings_module = next(
                        (k for k in embedding_modules if k in module_name), None
                    )
                    if embeddings_module:
                        lora_embeddings_tensor = embeddings[
                            embedding_modules[embeddings_module]
                        ].to(device=device, dtype=dtype)
                        if pin_memory:
                            lora_embeddings_tensor = lora_embeddings_tensor.pin_memory()
                loras[module_name] = LoRALayerWeights.from_config(
                    module_name, peft_helper, lora_embeddings_tensor
                )

            lora = loras[module_name]
            if is_lora_a:
                lora.lora_a = tensor.to(device=device, dtype=dtype)
                if pin_memory:
                    lora.lora_a = lora.lora_a.pin_memory()
            else:
                lora_b = tensor.to(device=device, dtype=dtype)
                if target_embedding_padding is not None and any(
                    name in module_name for name in padding_modules
                ):
                    assert target_embedding_padding >= lora_b.shape[0]
                    addition = target_embedding_padding - lora_b.shape[0]
                    # Pad (0, 0) on the last dim and (0, addition) on dim 0.
                    lora_b = torch.nn.functional.pad(lora_b, (0, 0, 0, addition))
                if pin_memory:
                    lora_b = lora_b.pin_memory()
                lora.lora_b = lora_b

        for lora in loras.values():
            lora.optimize()

        return cls(lora_model_id, peft_helper.r, loras)

    @classmethod
    def from_local_checkpoint(
        cls,
        lora_dir: str,
        expected_lora_modules: list[str],
        peft_helper: PEFTHelper,
        *,
        lora_model_id: int | None = None,
        device: str = "cuda",
        dtype: torch.dtype | None = None,
        target_embedding_padding: int | None = None,
        embedding_modules: dict[str, str] | None = None,
        embedding_padding_modules: list[str] | None = None,
        weights_mapper: WeightsMapper | None = None,
        tensorizer_config_dict: dict | None = None,
    ) -> "LoRAModel":
        """Create a LoRAModel from a local checkpoint.

        Weights are read from the first available source, in this order:
        a tensorizer ``adapter_model.tensors`` file (when
        ``tensorizer_config_dict`` is given), ``adapter_model.safetensors``,
        then ``adapter_model.bin`` / ``adapter_model.pt``. Optional extra
        embeddings are read from ``new_embeddings.safetensors`` or
        ``new_embeddings.bin``.

        Args:
            lora_dir: The local path that has lora data.
            expected_lora_modules: Name of modules that are expected to be
                replaced by lora.
            peft_helper: Loaded lora configuration information.
            lora_model_id: LoRA model id. If not given, automatically set by
                a global counter.
            device: Device where the lora model is loaded.
            dtype: dtype of the lora model weights.
            target_embedding_padding: forwarded to ``from_lora_tensors``.
            embedding_modules: forwarded to ``from_lora_tensors``.
            embedding_padding_modules: forwarded to ``from_lora_tensors``.
            weights_mapper: used to parse tensor names; also forwarded.
            tensorizer_config_dict: kwargs for ``TensorizerConfig``; when
                given, weights are deserialized with tensorizer.

        Returns:
            Loaded LoRA Model.

        Raises:
            ValueError: if the checkpoint targets modules outside
                ``expected_lora_modules``, or ``lora_dir`` contains no
                supported weight file.
        """
        lora_tensor_path = os.path.join(lora_dir, "adapter_model.safetensors")
        lora_bin_file_path = os.path.join(lora_dir, "adapter_model.bin")
        lora_pt_file_path = os.path.join(lora_dir, "adapter_model.pt")
        new_embeddings_tensor_path = os.path.join(
            lora_dir, "new_embeddings.safetensors"
        )
        new_embeddings_bin_file_path = os.path.join(lora_dir, "new_embeddings.bin")
        tensors: dict[str, torch.Tensor] = {}

        def raise_unexpected_modules(unexpected_modules: list[list[str] | str]):
            """Raise the shared 'unexpected target modules' error."""
            raise ValueError(
                f"While loading {lora_dir}, expected"
                f" target modules in {expected_lora_modules}"
                f" but received {unexpected_modules}."
                f" Please verify that the loaded LoRA module is correct"
            )

        def check_unexpected_modules(modules: dict):
            """Raise if any tensor key targets a module not expected."""
            unexpected_modules: list[list[str] | str] = []
            for lora_module in modules.keys():  # noqa
                module_name, _ = parse_fine_tuned_lora_name(lora_module, weights_mapper)
                # Handle FSDP file format where experts.base_layer is the
                # gate_up_proj and experts is the down_proj
                if "base_layer" in lora_module:
                    continue
                # Case for expert lora weights
                if ".experts" in module_name:
                    if not any(
                        module_name.endswith(ele) for ele in expected_lora_modules
                    ):
                        unexpected_modules.append(module_name)
                elif module_name.split(".")[-1] not in expected_lora_modules:
                    unexpected_modules.append(module_name)
            if unexpected_modules:
                raise_unexpected_modules(unexpected_modules)

        if tensorizer_config_dict:
            from tensorizer import TensorDeserializer

            tensorizer_config = TensorizerConfig(**tensorizer_config_dict)
            lora_tensor_path = os.path.join(
                tensorizer_config.tensorizer_dir, "adapter_model.tensors"
            )
            tensorizer_args = tensorizer_config._construct_tensorizer_args()
            tensors = TensorDeserializer(
                lora_tensor_path,
                dtype=tensorizer_config.dtype,
                **tensorizer_args.deserialization_kwargs,
            )
            check_unexpected_modules(tensors)

        elif os.path.isfile(lora_tensor_path):
            # Find unexpected modules.
            # Use safetensor key as a source of truth to find expected modules.
            # in peft if you have target_modules A, B, C and C does not exist
            # in the model it won't error and model will be trained with A, B
            # loraified. C won't exist in the safetensor but it will exist in
            # the target_modules of the adapter_config.json.
            with safetensors.safe_open(lora_tensor_path, framework="pt") as f:  # type: ignore
                # Load tensors if there are only expected modules.
                check_unexpected_modules(f)
                for module in f.keys():  # noqa
                    tensors[module] = f.get_tensor(module)
        elif os.path.isfile(lora_bin_file_path) or os.path.isfile(lora_pt_file_path):
            # When a bin/pt file is provided, we rely on config to find
            # unexpected modules.
            target_modules = peft_helper.target_modules
            if not isinstance(target_modules, list):
                target_modules = [target_modules]
            # Compatible with more modules, such as: layers.11.self_attn.k_proj
            unexpected_modules = [
                module
                for module in target_modules
                if module.split(".")[-1] not in expected_lora_modules
            ]
            # loaded lora's target modules must be a subset of
            # expected_lora_modules. It is not reliable. See
            # https://github.com/vllm-project/vllm/pull/5909. But there's no
            # other better mechanism.
            if unexpected_modules and not is_regex_target_modules(
                peft_helper.target_modules, expected_lora_modules
            ):
                raise_unexpected_modules(unexpected_modules)
            lora_file_path = (
                lora_bin_file_path
                if os.path.isfile(lora_bin_file_path)
                else lora_pt_file_path
            )
            tensors = torch.load(lora_file_path, map_location=device, weights_only=True)
        else:
            raise ValueError(f"{lora_dir} doesn't contain tensors")

        embeddings = None
        if os.path.isfile(new_embeddings_tensor_path):
            embeddings = safetensors.torch.load_file(new_embeddings_tensor_path)
        elif os.path.isfile(new_embeddings_bin_file_path):
            embeddings = torch.load(
                new_embeddings_bin_file_path, map_location=device, weights_only=True
            )

        return cls.from_lora_tensors(
            lora_model_id=get_lora_id() if lora_model_id is None else lora_model_id,
            tensors=tensors,
            peft_helper=peft_helper,
            device=device,
            dtype=dtype,
            embeddings=embeddings,
            target_embedding_padding=target_embedding_padding,
            embedding_modules=embedding_modules,
            embedding_padding_modules=embedding_padding_modules,
            weights_mapper=weights_mapper,
        )
323
324


325
class LoRAModelManager:
    """A manager that manages multiple LoRA-fine-tuned models.

    On construction it replaces every supported sub-module of ``model`` with a
    LoRA-aware layer. Adapters (``LoRAModel``) are registered with
    ``add_adapter``. ``activate_adapter`` then writes an adapter into one of
    ``lora_config.max_loras`` slots.
    """

    def __init__(
        self,
        model: SupportsLoRA,
        max_num_seqs: int,
        max_num_batched_tokens: int,
        vocab_size: int,
        lora_config: LoRAConfig,
        device: torch.device,
    ):
        """Create a LoRAModelManager and adapter for a given model.

        Args:
            model: the model to be adapted.
            max_num_seqs: the maximum number of sequences model can run in a
                single batch.
            max_num_batched_tokens: the maximum number of tokens model can run
                in a single batch.
            vocab_size: the vocab size of the model.
            lora_config: the LoRA configuration.
            device: device passed to the punica wrapper.
        """
        self.model: SupportsLoRA = model
        self._registered_adapters: dict[int, LoRAModel] = {}
        # Dict instead of a set for compatibility with LRUCache.
        self._active_adapters: dict[int, None] = {}
        self.adapter_type = "LoRA"
        self.lora_config = lora_config
        self.device = device
        self.max_num_seqs = max_num_seqs
        assert self.capacity >= self.lora_slots
        # Rounded up to the next multiple of 8.
        self.max_num_batched_tokens = math.ceil(max_num_batched_tokens / 8) * 8
        # Slot index -> id of the adapter occupying it (None = free).
        self.lora_index_to_id: list[int | None] = [None] * self.lora_slots
        self.vocab_size = vocab_size
        self.punica_wrapper = get_punica_wrapper(
            max_num_batched_tokens,
            max_batches=self.max_num_seqs,
            device=self.device,
            max_loras=self.lora_config.max_loras,
        )

        self.supported_lora_modules = get_supported_lora_modules(self.model)
        # Parenthesized so the model name is part of the assertion message.
        assert self.supported_lora_modules, (
            "No supported LoRA modules found in"
            f" {self.model.__class__.__name__}."
        )

        self.packed_modules_mapping = process_packed_modules_mapping(self.model)
        # Used to indicate whether the model is a multimodal model
        self.supports_mm: bool = (
            supports_multimodal(self.model)
            # In case the model only supports LoRA for
            # text modules (e.g. ChatGLM)
            and hasattr(self.model, "get_mm_mapping")
        )
        self.is_pooling_model = is_pooling_model(self.model)
        # Full packed-module name -> full names of its sub-modules.
        self.packed_modules: dict[str, list[str]] = {}
        self.modules: dict[str, BaseLayerWithLoRA] = {}
        self._last_mapping: LoRAMapping | None = None
        self._create_lora_modules()
        self.model.lora_manager = self

    def __len__(self) -> int:
        """Number of registered adapters."""
        return len(self._registered_adapters)

    @property
    def capacity(self) -> int:
        """Maximum number of registered adapters (``max_cpu_loras``)."""
        return self.lora_config.max_cpu_loras

    @property
    def lora_slots(self) -> int:
        """Number of slots available to active adapters (``max_loras``)."""
        return self.lora_config.max_loras

    @property
    def adapter_slots(self) -> int:
        """Alias of ``lora_slots``."""
        return self.lora_slots

    def activate_adapter(
        self,
        lora_id: int,
    ) -> bool:
        """Move LoRA into a GPU buffer to be used in the forward pass.

        Returns:
            False if the adapter was already active, True after it has been
            written into the first free slot.

        Raises:
            ValueError: if no slot is free.
            KeyError: if ``lora_id`` is not registered.
        """
        if lora_id in self._active_adapters:
            return False
        index = next(
            (
                slot
                for slot, occupant in enumerate(self.lora_index_to_id)
                if occupant is None
            ),
            None,
        )
        if index is None:
            raise ValueError("No free lora slots")
        # Look the adapter up before marking it active so an unknown id does
        # not leave a stale entry in ``_active_adapters``.
        lora_model = self._registered_adapters[lora_id]
        self._active_adapters[lora_id] = None
        logger.debug(
            "Activating LoRA. int id: %d, slot index: %d", lora_model.id, index
        )
        self.lora_index_to_id[index] = lora_model.id
        for module_name, module in self.modules.items():
            module_lora = self._get_lora_layer_weights(lora_model, module_name)
            if not module_lora:
                module.reset_lora(index)
                continue
            module_lora.optimize()
            # Note (gnovack) - If MOE lora weights are not split into
            # num_experts chunks, we split them here
            if isinstance(module, FusedMoEWithLoRA) and torch.is_tensor(
                module_lora.lora_a
            ):
                self._split_moe_lora_weights(lora_model, module_name, module_lora)
            module.set_lora(
                index,
                module_lora.lora_a,
                module_lora.lora_b,
                module_lora.embeddings_tensor,
            )
        return True

    def _split_moe_lora_weights(
        self,
        lora_model: LoRAModel,
        module_name: str,
        module_lora: LoRALayerWeights,
    ) -> None:
        """Replace stacked MoE A/B tensors of ``module_lora`` with per-expert lists.

        Handles the FSDP file format where ``experts.base_layer`` holds the
        fused gate_up_proj and ``experts`` holds the down_proj. The resulting
        lists hold, for each expert, [gate, down, up] in that order
        (presumably the order FusedMoEWithLoRA.set_lora expects — confirm).
        """
        gate_up_proj_lora = self._get_lora_layer_weights(
            lora_model, module_name + ".base_layer"
        )
        assert gate_up_proj_lora is not None

        down_proj_lora = module_lora
        num_experts = module_lora.lora_a.shape[0] // module_lora.rank

        # gate and up reuse the same chunks of the fused A matrix.
        gate_proj_a = gate_up_proj_lora.lora_a.chunk(num_experts, dim=0)
        up_proj_a = gate_up_proj_lora.lora_a.chunk(num_experts, dim=0)

        # Rows of the fused B matrix interleave gate (even) and up (odd).
        gate_proj_b = gate_up_proj_lora.lora_b[::2, ...].chunk(num_experts, dim=-1)
        up_proj_b = gate_up_proj_lora.lora_b[1::2, ...].chunk(num_experts, dim=-1)

        down_proj_a = down_proj_lora.lora_a.chunk(num_experts, dim=0)
        down_proj_b = down_proj_lora.lora_b.chunk(num_experts, dim=-1)

        lora_a = []
        lora_b = []
        for i in range(num_experts):
            lora_a.extend((gate_proj_a[i], down_proj_a[i], up_proj_a[i]))
            lora_b.extend((gate_proj_b[i], down_proj_b[i], up_proj_b[i]))

        module_lora.lora_a = lora_a
        module_lora.lora_b = lora_b

    def _deactivate_adapter(self, lora_id: int):
        """Free the slot held by ``lora_id``; no-op if it holds none."""
        try:
            index = self.lora_index_to_id.index(lora_id)
            self.lora_index_to_id[index] = None
        except ValueError:
            pass

    def _add_adapter(self, lora: LoRAModel):
        """Pack sub-module LoRAs, then store ``lora`` in the registry."""
        self._create_merged_loras_inplace(lora)
        self._registered_adapters[lora.id] = lora

    def pin_adapter(self, lora_id: int) -> bool:
        """Pin a LoRAModel in the manager cache."""
        raise NotImplementedError(
            "Pinning is not supported in LoRAModelManager. "
            "Use LRUCacheLoRAModelManager for pinning"
        )  # type: ignore

    def _set_adapter_mapping(self, mapping: LoRAMapping) -> None:
        """Push ``mapping`` and the slot table to the punica wrapper."""
        # update lora states
        self.punica_wrapper.update_metadata(
            mapping,
            self.lora_index_to_id,
            self.lora_slots + 1,
            self.vocab_size,
            self.lora_config.lora_extra_vocab_size,
        )

    def remove_all_adapters(self):
        """Remove all LoRAModels from the manager."""
        self._registered_adapters.clear()
        self.lora_index_to_id = [None] * self.lora_slots
        self._active_adapters.clear()

    def _create_lora_modules(self):
        """Replace every targeted sub-module of the model with a LoRA layer."""

        def _parent_module(module_name: str) -> str:
            # module name is a dot separated name.
            # for example:
            #  - given an input 'x.y.z' return 'x.y'
            #  - given an input 'x' return ''
            return module_name.rpartition(".")[0]

        for module_name, module in self.model.named_modules(remove_duplicate=False):
            if isinstance(module, PPMissingLayer):
                continue

            if not self._match_target_modules(module_name):
                continue
            # A temporary approach for multimodal models to support LoRA
            # TODO: Remove this restriction
            if self._filter_unsupported_mm_module(module_name):
                logger.warning(
                    "Regarding multimodal models, vLLM currently only supports "
                    "adding LoRA to language model, %s will be ignored.",
                    module_name,
                )
                continue
            parts = module_name.split(".")[-1]
            packed_moduled_lst = self.packed_modules_mapping.get(parts, [])
            new_module = replace_submodule(
                self.model,
                module_name,
                from_layer(
                    module,
                    self.lora_slots,
                    self.lora_config,
                    packed_moduled_lst,
                    self.model.config,
                ),
            )

            # (yard1): TODO make this more robust
            if "lm_head" in module_name:
                logits_processor_module_name = "logits_processor"
                parent_module = _parent_module(module_name)
                if parent_module:
                    logits_processor_module_name = (
                        f"{parent_module}.{logits_processor_module_name}"
                    )

                logits_processor_module = self.model.get_submodule(
                    logits_processor_module_name
                )

                new_module = replace_submodule(
                    self.model,
                    logits_processor_module_name,
                    from_layer_logits_processor(
                        logits_processor_module,
                        module,
                        self.lora_slots,
                        self.lora_config,
                        self.model.config,
                    ),
                )

            # In some models, especially multimodal ones, layers with the same
            # name may have different types, such as nn.Linear and
            # ReplicatedLinear. The nn.Linear layers cannot be replaced with
            # LoRA layers, leading to assertion error. The following check
            # aims to prevent this error
            if self.supports_mm and not isinstance(new_module, BaseLayerWithLoRA):
                continue
            self.register_module(module_name, new_module)
            self._register_packed_modules(module_name)
            # All lora layers share the same punica_wrapper based on reference.
            new_module.set_mapping(self.punica_wrapper)

    def register_module(self, module_name: str, module: "BaseLayerWithLoRA"):
        """Record ``module`` as the LoRA layer for ``module_name``."""
        # Parenthesized so the actual type is part of the assertion message.
        assert isinstance(module, BaseLayerWithLoRA), (
            f"Module {module_name} must be a BaseLayerWithLoRA instance,"
            f" got {type(module)}"
        )
        self.modules[module_name] = module

    def create_dummy_lora(
        self,
        lora_id: int,
        rank: int,
        embedding_modules: dict[str, str] | None = None,
    ) -> LoRAModel:
        """Create zero-initialized LoRAModel for warmup."""
        model = LoRAModel(lora_id, rank, {})
        for module_name, module in self.model.named_modules():
            if (
                not self._match_target_modules(module_name)
                or not isinstance(module, BaseLayerWithLoRA)
                or self._filter_unsupported_mm_module(module_name)
            ):
                continue
            parts = module_name.split(".")
            if module_name not in self.packed_modules:
                assert embedding_modules is not None
                if parts[-1] in embedding_modules:
                    input_dim = (
                        module.base_layer.org_vocab_size
                        + self.lora_config.lora_extra_vocab_size
                        if hasattr(module.base_layer, "org_vocab_size")
                        else module.base_layer.weight.shape[1]
                    )
                    output_dim = (
                        module.base_layer.embedding_dim
                        if hasattr(module.base_layer, "embedding_dim")
                        else module.base_layer.weight.shape[0]
                    )
                    embeddings_tensor_dim = (
                        module.base_layer.embedding_dim
                        if hasattr(module.base_layer, "embedding_dim")
                        else module.base_layer.weight.shape[1]
                    )
                    lora = LoRALayerWeights.create_dummy_lora_weights(
                        module_name,
                        input_dim,
                        output_dim,
                        rank,
                        module.lora_a_stacked[0].dtype,
                        "cpu",
                        embeddings_tensor_dim=embeddings_tensor_dim,
                    )
                else:
                    lora = LoRALayerWeights.create_dummy_lora_weights(
                        module_name,
                        module.lora_a_stacked[0].shape[-1],
                        module.lora_b_stacked[0].shape[-2],
                        rank,
                        module.lora_a_stacked[0].dtype,
                        "cpu",
                    )
            else:
                replacements = self.packed_modules_mapping[parts[-1]]
                subloras: list[LoRALayerWeights | None] = []
                for i, r in enumerate(replacements):
                    lora = LoRALayerWeights.create_dummy_lora_weights(
                        module_name + "." + r,
                        module.lora_a_stacked[i].shape[-1],
                        module.lora_b_stacked[i].shape[-2],
                        rank,
                        module.lora_a_stacked[i].dtype,
                        "cpu",
                    )
                    subloras.append(lora)
                lora = PackedLoRALayerWeights.pack(subloras)
            model.loras[module_name] = lora
        return model

    def _match_target_modules(self, module_name: str):
        """True if ``module_name`` equals, or ends in '.' + , a supported module."""
        return any(
            re.match(
                r".*\.{target_module}$".format(target_module=target_module), module_name
            )
            or target_module == module_name
            for target_module in self.supported_lora_modules
        )

    def _filter_unsupported_mm_module(self, module_name: str) -> bool:
        """
        Regarding multimodal models, vLLM currently only supports adding LoRA to
        language model. LoRA for other modules, such as the vision tower, will
        be filtered out.
        """
        if self.supports_mm:
            module_mapping: MultiModelKeys = self.model.get_mm_mapping()
            prefix_lst = module_mapping.connector + module_mapping.tower_model
            return any(module_name.startswith(prefix) for prefix in prefix_lst)
        return False

    def _register_packed_modules(self, module_full_name: str) -> None:
        """Record the full sub-module names of a packed module, if it is one."""
        parts = module_full_name.split(".")
        module_name = parts[-1]
        replacements = self.packed_modules_mapping.get(module_name, [])
        # When replacements is less than or equal to 1, it indicates that this
        # module is not a packed module.
        if len(replacements) <= 1:
            return
        prefix = ".".join(parts[:-1])
        self.packed_modules[module_full_name] = [
            prefix + "." + r if prefix else r for r in replacements
        ]

    def _create_merged_loras_inplace(self, lora_model: LoRAModel) -> None:
        """Combine sub-module LoRAs of every packed module into one entry.

        For each packed module with at least one sub-module LoRA present, the
        sub-module LoRAs (missing ones as ``None``) are packed with
        ``PackedLoRALayerWeights.pack`` and stored under the packed module's
        name. The sub-module entries that were found are then removed.
        """
        for module_name, new_module_names in self.packed_modules.items():
            found = [
                self._get_lora_layer_weights(lora_model, r) for r in new_module_names
            ]
            # Normalize falsy lookups to None so ``pack`` sees explicit gaps.
            replacement_loras: list[LoRALayerWeights | None] = [
                lora if lora else None for lora in found
            ]
            replaced_module: set[str] = {
                r for r, lora in zip(new_module_names, replacement_loras) if lora
            }
            if not replaced_module:
                continue
            # HACK Temporary solution for the pool model: mirror the fallback
            # in _get_lora_layer_weights and use the "model."-stripped name
            # when only that form exists. (Previously the inner check re-tested
            # the unstripped name, so this branch could never fire.)
            if self.is_pooling_model and not lora_model.check_lora_name(module_name):
                replaced_module_name = module_name.replace("model.", "")
                if lora_model.check_lora_name(replaced_module_name):
                    module_name = replaced_module_name
            lora_model.loras[module_name] = PackedLoRALayerWeights.pack(
                replacement_loras
            )
            # Remove the modules that have been replaced.
            for module in replaced_module:
                lora_model.loras.pop(module, None)

    def _get_lora_layer_weights(
        self, lora_model: LoRAModel, module_name: str
    ) -> LoRALayerWeights | None:
        """Look up ``module_name`` in ``lora_model``.

        For pooling models, a missing name is retried with every "model."
        substring removed.
        """
        org_module_name = module_name
        if self.is_pooling_model and not lora_model.check_lora_name(module_name):
            # If it's a pool model, and the layer name is not found,
            # remove the prefix 'model.' and search again.
            module_name = module_name.replace("model.", "")
            if lora_model.check_lora_name(module_name):
                org_module_name = module_name
                logger.info_once(
                    "For the pool model, successfully loaded the LoRA weights "
                    "after removing the prefix 'model.'."
                )
        return lora_model.get_lora(org_module_name)

    def deactivate_adapter(self, adapter_id: int) -> bool:
        """Deactivate ``adapter_id``; return False if it was not active."""
        if adapter_id not in self._active_adapters:
            return False
        self._deactivate_adapter(adapter_id)
        self._active_adapters.pop(adapter_id, None)
        return True

    def add_adapter(self, adapter: LoRAModel) -> bool:
        """Register ``adapter``.

        Returns:
            False if an adapter with this id is already registered, else True.

        Raises:
            RuntimeError: if ``capacity`` adapters are already registered.
        """
        logger.debug("Adding lora. Model id: %d, int id: %d", adapter.id, adapter.id)
        if adapter.id in self._registered_adapters:
            return False
        if len(self._registered_adapters) >= self.capacity:
            raise RuntimeError("No free adapter slots.")
        self._add_adapter(adapter)
        return True

    def set_adapter_mapping(self, mapping: LoRAMapping) -> None:
        """Apply ``mapping`` unless it equals the last applied mapping."""
        if self._last_mapping != mapping:
            self._set_adapter_mapping(mapping)
            self._last_mapping = mapping

    def remove_adapter(self, adapter_id: int) -> bool:
        """Deactivate and unregister ``adapter_id``; False if not registered."""
        self.deactivate_adapter(adapter_id)
        if adapter_id not in self._registered_adapters:
            return False
        self._registered_adapters.pop(adapter_id, None)
        return True

    def list_adapters(self) -> dict[int, LoRAModel]:
        """Return a copy of the id -> LoRAModel registry."""
        return dict(self._registered_adapters)

    def get_adapter(self, adapter_id: int) -> LoRAModel | None:
        """Return the registered adapter with this id, or None."""
        return self._registered_adapters.get(adapter_id)
782
783
784


class LoRALRUCache(AdapterLRUCache[LoRAModel]):
    """``AdapterLRUCache`` whose values are ``LoRAModel`` instances.

    ``deactivate_lora_fn`` is called with the key of every entry the cache
    removes (see ``AdapterLRUCache._on_remove``).
    """

    def __init__(self, capacity: int, deactivate_lora_fn: Callable[[int], bool]):
        super().__init__(capacity, deactivate_fn=deactivate_lora_fn)


class LRUCacheLoRAModelManager(LoRAModelManager):
    """A LoRA model manager that keeps adapters in two LRU caches.

    * ``_registered_adapters`` is sized by ``self.capacity``. Any id it
      removes is passed to ``self.deactivate_adapter``.
    * ``_active_adapters`` is sized by ``self.lora_slots``. Any id it removes
      is passed to ``self._deactivate_adapter``.

    The removal callbacks come from ``AdapterLRUCache._on_remove``.
    """

    def __init__(
        self,
        model: nn.Module,
        max_num_seqs: int,
        max_num_batched_tokens: int,
        vocab_size: int,
        lora_config: LoRAConfig,
        device: torch.device,
    ):
        super().__init__(
            model, max_num_seqs, max_num_batched_tokens, vocab_size, lora_config, device
        )
        # ``capacity`` / ``lora_slots`` are presumably set by the base __init__.
        registered = LoRALRUCache(self.capacity, self.deactivate_adapter)
        active = LoRALRUCache(self.lora_slots, self._deactivate_adapter)
        self._registered_adapters: LoRALRUCache = registered
        self._active_adapters: LoRALRUCache = active

    def list_adapters(self) -> dict[int, LoRAModel]:
        """Return a shallow copy of the registered adapters, keyed by id."""
        snapshot = dict(self._registered_adapters.cache)
        return snapshot

    def add_adapter(self, lora: LoRAModel) -> bool:
        """Register *lora*, or refresh its LRU position if already registered.

        Returns:
            True if *lora* was newly registered, False if it was already known.
        """
        logger.debug("Adding lora. Model id: %d, int id: %d", lora.id, lora.id)
        if lora.id in self._registered_adapters:
            # Known adapter: only mark it as most recently used.
            self._registered_adapters.touch(lora.id)
            return False
        self._add_adapter(lora)
        return True

    def activate_adapter(
        self,
        lora_id: int,
    ) -> bool:
        """Activate *lora_id*, evicting the oldest active adapter if needed.

        Returns:
            Whatever ``LoRAModelManager.activate_adapter`` returns.
        """
        needs_room = (
            lora_id not in self._active_adapters
            and len(self._active_adapters) >= self.lora_slots
        )
        if needs_room:
            self._active_adapters.remove_oldest()
        activated = super().activate_adapter(lora_id)
        # Refresh the LRU position even if the adapter was already active.
        self._active_adapters.touch(lora_id)
        return activated

    def remove_oldest_adapter(self) -> bool:
        """Evict the least recently used registered adapter.

        Returns:
            False if the registry is empty, True otherwise.
        """
        if len(self._registered_adapters) == 0:
            return False
        self._registered_adapters.remove_oldest()
        return True

    def pin_adapter(self, lora_id: int) -> bool:
        """Pin a LoRAModel in both the registry and the active cache.

        Returns:
            Always True. Failures surface as exceptions instead.
        """
        self._pin_lora_in_cpu_cache(lora_id)
        self._pin_lora_in_gpu_cache(lora_id)
        return True

    def _pin_lora_in_cpu_cache(self, lora_id: int):
        """Pin *lora_id* in ``_registered_adapters``.

        Raises:
            ValueError: If the cache's ``pin`` raises ValueError. It is
                re-raised with a "not registered" message, chained from the
                original error.
        """
        try:
            self._registered_adapters.pin(lora_id)
        except ValueError as err:
            raise ValueError(
                f"Pinning failed. LoRA {lora_id} is not registered."
            ) from err

    def _pin_lora_in_gpu_cache(self, lora_id: int):
        """Activate *lora_id* if it is not active, then pin it there."""
        is_active = lora_id in self._active_adapters
        if not is_active:
            self.activate_adapter(lora_id)
        self._active_adapters.pin(lora_id)


def create_lora_manager(
    model: nn.Module,
    max_num_seqs: int,
    max_num_batched_tokens: int,
    vocab_size: int,
    lora_config: LoRAConfig,
    device: torch.device,
    lora_manager_cls: type[LoRAModelManager] = LoRAModelManager,
    **kwargs,
) -> LoRAModelManager:
    """Build a LoRA model manager for *model*.

    Args:
        model: Model to manage. It must be an instance of ``SupportsLoRA``.
        max_num_seqs, max_num_batched_tokens, vocab_size, lora_config, device:
            Forwarded unchanged to ``lora_manager_cls`` as keyword arguments.
        lora_manager_cls: Manager class to instantiate.
        **kwargs: Extra keyword arguments forwarded to ``lora_manager_cls``.

    Returns:
        The newly constructed ``lora_manager_cls`` instance.

    Raises:
        ValueError: If *model* is not a ``SupportsLoRA`` instance.
    """
    if isinstance(model, SupportsLoRA):
        return lora_manager_cls(
            model=model,
            max_num_seqs=max_num_seqs,
            max_num_batched_tokens=max_num_batched_tokens,
            vocab_size=vocab_size,
            lora_config=lora_config,
            device=device,
            **kwargs,
        )
    raise ValueError(f"Model {type(model)} is not supported for LoRA.")