models.py 33.6 KB
Newer Older
1
# SPDX-License-Identifier: Apache-2.0
2
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
3

4
5
import math
import os
6
7
from collections.abc import Sequence
from typing import Any, Callable, Optional, Union
8

9
import regex as re
10
11
12
13
import safetensors.torch
import torch
from torch import nn

14
15
16
17
18
from vllm.adapter_commons.models import (AdapterLRUCache, AdapterModel,
                                         AdapterModelManager)
from vllm.adapter_commons.utils import (add_adapter, deactivate_adapter,
                                        get_adapter, list_adapters,
                                        remove_adapter, set_adapter_mapping)
19
from vllm.config import LoRAConfig
20
from vllm.logger import init_logger
21
from vllm.lora.layers import BaseLayerWithLoRA, LoRAMapping
22
from vllm.lora.lora import LoRALayerWeights, PackedLoRALayerWeights
23
from vllm.lora.peft_helper import PEFTHelper
24
from vllm.lora.punica_wrapper import get_punica_wrapper
25
from vllm.lora.utils import (from_layer, from_layer_logits_processor,
26
                             get_supported_lora_modules,
27
                             is_regex_target_modules,
28
                             parse_fine_tuned_lora_name, replace_submodule)
29
from vllm.model_executor.layers.fused_moe import FusedMoE
30
from vllm.model_executor.model_loader.tensorizer import TensorizerConfig
31
from vllm.model_executor.models import SupportsLoRA, supports_multimodal
32
from vllm.model_executor.models.interfaces import is_pooling_model
33
from vllm.model_executor.models.module_mapping import MultiModelKeys
34
from vllm.model_executor.models.utils import PPMissingLayer, WeightsMapper
35
from vllm.model_executor.utils import get_packed_modules_mapping
36
from vllm.utils import is_pin_memory_available
37

38
logger = init_logger(__name__)
39
40
41
42
43
44
45
46
47
48

_GLOBAL_LORA_ID = 0


def get_lora_id():
    """Return the next integer from a module-level counter (1, 2, 3, ...).

    The counter lives in ``_GLOBAL_LORA_ID``; no locking is performed.
    """
    global _GLOBAL_LORA_ID
    next_id = _GLOBAL_LORA_ID + 1
    _GLOBAL_LORA_ID = next_id
    return next_id


49
50
51
52
53
54
55
56
57
58
59
def is_moe_model(model: nn.Module) -> bool:
    """Return True when ``model`` contains at least one FusedMoE layer.

    If such a layer is found, a one-time warning is logged explaining that
    fused MoE LoRA inference is unsupported, so the adapter must not carry
    expert weights.
    """
    has_fused_moe = False
    for submodule in model.modules():
        if isinstance(submodule, FusedMoE):
            has_fused_moe = True
            break
    if has_fused_moe:
        logger.warning_once(
            "For MoE models, vLLM currently does not support fused MoE LoRA "
            "inference. Please ensure that the loaded LoRA model does not "
            "contain expert weights.")
    return has_fused_moe


60
class LoRAModel(AdapterModel):
    """A LoRA fine-tuned model."""

    def __init__(
        self,
        lora_model_id: int,
        rank: int,
        loras: dict[str, LoRALayerWeights],
    ) -> None:
        """
        Args:
            lora_model_id: The integer id for the lora model; must be > 0.
            rank: lora rank.
            loras: module name -> weights for lora-replaced layers.
        """
        self.id = lora_model_id

        assert (
            lora_model_id
            > 0), f"a valid lora id should be greater than 0, got {self.id}"
        self.rank = rank
        self.loras: dict[str, LoRALayerWeights] = loras

    def clone(self, lora_model_id: int) -> "LoRAModel":
        """Return a copy of the object with different ids.

        Will share the underlying tensors (only the dict is copied)."""
        return self.__class__(
            lora_model_id,
            rank=self.rank,
            loras=self.loras.copy(),
        )

    @property
    def extra_vocab_size(self) -> int:
        """Largest ``extra_vocab_size`` among the layer weights (0 if none)."""
        return max(lora.extra_vocab_size
                   for lora in self.loras.values()) if self.loras else 0

    def get_lora(self, module_name: str) -> Optional[LoRALayerWeights]:
        """Get LoRA for a given module by name"""
        return self.loras.get(module_name, None)

    def check_lora_name(self, lora_name: str) -> bool:
        """Return True if weights are stored under exactly ``lora_name``."""
        return lora_name in self.loras

    # (yard1): TODO see if we can derive target_embedding_padding automatically
    @classmethod
    def from_lora_tensors(
        cls,
        lora_model_id: int,
        tensors: dict[str, torch.Tensor],
        peft_helper: PEFTHelper,
        device: str = "cuda",
        dtype: Optional[torch.dtype] = None,
        embeddings: Optional[dict[str, torch.Tensor]] = None,
        target_embedding_padding: Optional[int] = None,
        embedding_modules: Optional[dict[str, str]] = None,
        embedding_padding_modules: Optional[list[str]] = None,
        weights_mapper: Optional[WeightsMapper] = None,
    ) -> "LoRAModel":
        """Create a LoRAModel from a dictionary of tensors.

        Args:
            lora_model_id: id given to the resulting model.
            tensors: checkpoint tensor name -> tensor. Each name is split by
                ``parse_fine_tuned_lora_name`` into a module name plus flags
                telling whether it is a lora_a, lora_b or bias tensor.
            peft_helper: adapter configuration; its ``r`` becomes the model
                rank and it is passed to ``LoRALayerWeights.from_config``.
            device: device the tensors are moved to. When it is ``"cpu"`` and
                pinned memory is available, the tensors are also pinned.
            dtype: dtype the tensors are cast to.
            embeddings: optional extra embedding tensors, looked up through
                ``embedding_modules``.
            target_embedding_padding: when set, lora_b tensors of modules
                matching ``embedding_padding_modules`` are zero-padded on the
                right so that ``shape[1]`` equals this value.
            embedding_modules: module-name substring -> key in ``embeddings``.
                Required whenever ``embeddings`` is non-empty.
            embedding_padding_modules: module-name substrings whose lora_b is
                padded. ``None`` means no module is padded.
            weights_mapper: optional name mapper used while parsing names.

        Returns:
            The assembled LoRAModel (``optimize()`` already called on each
            layer's weights).
        """
        pin_memory = str(device) == "cpu" and is_pin_memory_available()
        # A missing list simply means that no module needs vocab padding.
        padding_modules = embedding_padding_modules or []
        loras: dict[str, LoRALayerWeights] = {}
        for tensor_name, tensor in tensors.items():
            module_name, is_lora_a, is_bias = parse_fine_tuned_lora_name(
                tensor_name, weights_mapper)
            if module_name not in loras:
                lora_embeddings_tensor = None
                if embeddings:
                    assert embedding_modules is not None
                    embeddings_module = next(
                        (k for k in embedding_modules if k in module_name),
                        None)
                    if embeddings_module:
                        lora_embeddings_tensor = embeddings[
                            embedding_modules[embeddings_module]].to(
                                device=device, dtype=dtype)
                        if pin_memory:
                            lora_embeddings_tensor = (
                                lora_embeddings_tensor.pin_memory())
                loras[module_name] = LoRALayerWeights.from_config(
                    module_name, peft_helper, lora_embeddings_tensor)

            # Every checkpoint tensor is cast, moved and transposed exactly
            # once (the bias used to be converted twice).
            weight = tensor.to(device=device, dtype=dtype).t()
            is_lora_b = not (is_bias or is_lora_a)
            if (is_lora_b and target_embedding_padding is not None
                    and any(name in module_name for name in padding_modules)):
                assert target_embedding_padding >= weight.shape[1]
                addition = target_embedding_padding - weight.shape[1]
                weight = torch.nn.functional.pad(weight, (0, addition))
            # Pin last so that padding does not produce an unpinned copy.
            if pin_memory:
                weight = weight.pin_memory()

            if is_bias:
                loras[module_name].bias = weight
            elif is_lora_a:
                loras[module_name].lora_a = weight
            else:
                loras[module_name].lora_b = weight

        for lora in loras.values():
            lora.optimize()

        return cls(lora_model_id, peft_helper.r, loras)

    @classmethod
    def from_local_checkpoint(
            cls,
            lora_dir: str,
            expected_lora_modules: list[str],
            peft_helper: PEFTHelper,
            *,
            lora_model_id: Optional[int] = None,
            device: str = "cuda",
            dtype: Optional[torch.dtype] = None,
            target_embedding_padding: Optional[int] = None,
            embedding_modules: Optional[dict[str, str]] = None,
            embedding_padding_modules: Optional[list[str]] = None,
            weights_mapper: Optional[WeightsMapper] = None,
            tensorizer_config_dict: Optional[dict] = None) -> "LoRAModel":
        """Create a LoRAModel from a local checkpoint.

        Weights are read from a tensorizer file when
        ``tensorizer_config_dict`` is given. Otherwise they come from
        ``adapter_model.safetensors`` or, failing that, ``adapter_model.bin``
        in ``lora_dir``. Optional extra embeddings are read from
        ``new_embeddings.safetensors`` or ``new_embeddings.bin``.

        Args:
            lora_dir: The local path that has lora data.
            expected_lora_modules: Name of modules that are expected to be
                replaced by lora.
            peft_helper: Loaded lora configuration information.
            lora_model_id: LoRA model id. If not given, automatically set by
                a global counter.
            device: Device where the lora model is loaded.
            dtype: dtype of the lora model weights.
            target_embedding_padding: forwarded to ``from_lora_tensors``.
            embedding_modules: forwarded to ``from_lora_tensors``.
            embedding_padding_modules: forwarded to ``from_lora_tensors``.
            weights_mapper: name mapper used by the unexpected-module check
                and forwarded to ``from_lora_tensors``.
            tensorizer_config_dict: keyword arguments for
                ``TensorizerConfig``. When set, weights are deserialized from
                ``<tensorizer_dir>/adapter_model.tensors``.

        Returns:
            Loaded LoRA Model.

        Raises:
            ValueError: if the checkpoint targets modules outside
                ``expected_lora_modules``, or no weights file is found.
        """
        lora_tensor_path = os.path.join(lora_dir, "adapter_model.safetensors")
        lora_bin_file_path = os.path.join(lora_dir, "adapter_model.bin")
        new_embeddings_tensor_path = os.path.join(
            lora_dir, "new_embeddings.safetensors")
        new_embeddings_bin_file_path = os.path.join(lora_dir,
                                                    "new_embeddings.bin")

        tensors: dict[str, torch.Tensor] = {}

        def raise_unexpected_modules(
                unexpected_modules: list[Union[list[str], str]]) -> None:
            # Single source for the error shared by every loading path.
            raise ValueError(
                f"While loading {lora_dir}, expected"
                f" target modules in {expected_lora_modules}"
                f" but received {unexpected_modules}."
                f" Please verify that the loaded LoRA module is correct")

        def check_unexpected_modules(modules: dict):
            # Compare the last dotted component of every stored module name
            # against ``expected_lora_modules``.
            unexpected_modules: list[Union[list[str], str]] = []
            for lora_module in modules.keys():  # noqa
                module_name, _, _ = parse_fine_tuned_lora_name(
                    lora_module, weights_mapper)
                part_name = module_name.split(".")[-1]
                if part_name not in expected_lora_modules:
                    unexpected_modules.append(module_name)
            if unexpected_modules:
                raise_unexpected_modules(unexpected_modules)

        if tensorizer_config_dict:
            from tensorizer import TensorDeserializer

            tensorizer_config = TensorizerConfig(**tensorizer_config_dict)
            lora_tensor_path = os.path.join(tensorizer_config.tensorizer_dir,
                                            "adapter_model.tensors")
            tensorizer_args = tensorizer_config._construct_tensorizer_args()
            tensors = TensorDeserializer(
                lora_tensor_path,
                dtype=tensorizer_config.dtype,
                **tensorizer_args.deserialization_kwargs)
            check_unexpected_modules(tensors)

        elif os.path.isfile(lora_tensor_path):
            # Find unexpected modules.
            # Use safetensor key as a source of truth to find expected modules.
            # in peft if you have target_modules A, B, C and C does not exist
            # in the model it won't error and model will be trained with A, B
            # loraified. C won't exist in the safetensor but it will exist in
            # the target_modules of the adapter_config.json.
            with safetensors.safe_open(lora_tensor_path,
                                       framework="pt") as f:  # type: ignore
                # Load tensors if there are only expected modules.
                check_unexpected_modules(f)
                for module in f.keys():  # noqa
                    tensors[module] = f.get_tensor(module)
        elif os.path.isfile(lora_bin_file_path):
            # When a bin file is provided, we rely on config to find unexpected
            # modules.
            target_modules = peft_helper.target_modules
            if not isinstance(target_modules, list):
                target_modules = [target_modules]
            # Compare only the last component so that names such as
            # layers.11.self_attn.k_proj are accepted.
            unexpected_modules: list[Union[list[str], str]] = [
                module for module in target_modules
                if module.split(".")[-1] not in expected_lora_modules
            ]
            # loaded lora's target modules must be a subset of
            # expected_lora_modules. It is not reliable. See
            # https://github.com/vllm-project/vllm/pull/5909. But there's no
            # other better mechanism.
            if unexpected_modules and not is_regex_target_modules(
                    peft_helper.target_modules, expected_lora_modules):
                raise_unexpected_modules(unexpected_modules)
            tensors = torch.load(lora_bin_file_path,
                                 map_location=device,
                                 weights_only=True)
        else:
            raise ValueError(f"{lora_dir} doesn't contain tensors")

        embeddings = None
        if os.path.isfile(new_embeddings_tensor_path):
            embeddings = safetensors.torch.load_file(
                new_embeddings_tensor_path)
        elif os.path.isfile(new_embeddings_bin_file_path):
            embeddings = torch.load(new_embeddings_bin_file_path,
                                    map_location=device,
                                    weights_only=True)

        return cls.from_lora_tensors(
            lora_model_id=get_lora_id()
            if lora_model_id is None else lora_model_id,
            tensors=tensors,
            peft_helper=peft_helper,
            device=device,
            dtype=dtype,
            embeddings=embeddings,
            target_embedding_padding=target_embedding_padding,
            embedding_modules=embedding_modules,
            embedding_padding_modules=embedding_padding_modules,
            weights_mapper=weights_mapper)
309
310


311
class LoRAModelManager(AdapterModelManager):
    """A manager that manages multiple LoRA-fine-tuned models."""

    def __init__(
        self,
        model: SupportsLoRA,
        max_num_seqs: int,
        max_num_batched_tokens: int,
        vocab_size: int,
        lora_config: LoRAConfig,
        device: torch.device,
    ):
        """Create a LoRAModelManager and adapter for a given model.

        Args:
            model: the model to be adapted.
            max_num_seqs: the maximum number of sequences model can run in a
                single batch.
            max_num_batched_tokens: the maximum number of tokens model can run
                in a single batch.
            vocab_size: the vocab size of the model.
            lora_config: the LoRA configuration.
            device: device handed to the punica wrapper.
        """
        self.lora_config = lora_config
        self.device = device
        self.max_num_seqs = max_num_seqs
        assert self.capacity >= self.lora_slots
        # Rounded up to a multiple of 8. Note that the punica wrapper below
        # receives the unrounded value.
        self.max_num_batched_tokens = math.ceil(max_num_batched_tokens / 8) * 8
        # GPU slot index -> id of the LoRA occupying it (None = free).
        self.lora_index_to_id: list[Optional[int]] = [None] * self.lora_slots
        self.vocab_size = vocab_size
        self.punica_wrapper = get_punica_wrapper(
            max_num_batched_tokens,
            max_batches=self.max_num_seqs,
            device=self.device,
            max_loras=self.lora_config.max_loras)

        super().__init__(model)

        self.supported_lora_modules = get_supported_lora_modules(self.model)
        # The message must be one parenthesised expression: previously the
        # f-string sat on its own line and was silently discarded.
        assert self.supported_lora_modules, (
            "No supported LoRA modules found in"
            f" {self.model.__class__.__name__}.")
        if lora_config.lora_target_modules is not None:
            self.supported_lora_modules = lora_config.lora_target_modules

        self.packed_modules_mapping = get_packed_modules_mapping(self.model)
        # Used to indicate whether the model is a multimodal model
        self.supports_mm: bool = (
            supports_multimodal(self.model)
            # In case the model only supports LoRA for
            # text modules (e.g. ChatGLM)
            and hasattr(self.model, "get_mm_mapping"))
        self.is_pooling_model = is_pooling_model(self.model)
        self.is_moe_model = is_moe_model(self.model)
        # Full name of a packed module -> full names of its sub-modules.
        self.packed_modules: dict[str, list[str]] = {}
        # Full module name -> LoRA-wrapped layer installed in the model.
        self.modules: dict[str, BaseLayerWithLoRA] = {}
        # Most recent mapping passed through set_adapter_mapping.
        self._last_mapping: Optional[LoRAMapping] = None
        self._create_lora_modules()
        self.model.lora_manager = self
        self.adapter_type = 'LoRA'

    @property
    def capacity(self) -> int:
        """Maximum number of registered LoRAs (``max_cpu_loras``)."""
        return self.lora_config.max_cpu_loras

    @property
    def lora_slots(self) -> int:
        """Number of simultaneously active LoRA slots (``max_loras``)."""
        return self.lora_config.max_loras

    @property
    def adapter_slots(self) -> int:
        """Alias of ``lora_slots`` for the generic adapter interface."""
        return self.lora_slots

    def activate_adapter(
        self,
        lora_id: int,
    ) -> bool:
        """Move LoRA into a GPU buffer to be used in the forward pass.

        Returns:
            False if ``lora_id`` is already active, True once it has been
            written into a free slot.

        Raises:
            ValueError: if no slot is free, or if the adapter carries bias
                weights while ``lora_config.bias_enabled`` is off. Both checks
                run before the manager records the adapter as active, so a
                rejected adapter does not leave a half-populated slot behind.
        """
        if lora_id in self._active_adapters:
            return False
        index = next((i for i, slot_id in enumerate(self.lora_index_to_id)
                      if slot_id is None), None)
        if index is None:
            raise ValueError("No free lora slots")
        lora_model = self._registered_adapters[lora_id]

        # Resolve and validate every module's weights before touching state.
        module_loras: list[tuple[BaseLayerWithLoRA,
                                 Optional[LoRALayerWeights]]] = []
        for module_name, module in self.modules.items():
            module_lora = self._get_lora_layer_weights(lora_model, module_name)
            if module_lora:
                # Bias is not explicitly enabled with the flag enable_lora_bias.
                bias = module_lora.bias
                if ((torch.is_tensor(bias) or
                     (isinstance(bias, Sequence) and any(b is not None
                                                         for b in bias)))
                        and not self.lora_config.bias_enabled):
                    module_lora.bias = None
                    raise ValueError(
                        f"Adapter bias cannot be used for {module_name}"
                        " without --enable-lora-bias.")
            module_loras.append((module, module_lora))

        self._active_adapters[lora_id] = None
        logger.debug("Activating LoRA. int id: %d, slot index: %d",
                     lora_model.id, index)
        self.lora_index_to_id[index] = lora_model.id
        for module, module_lora in module_loras:
            if module_lora:
                module_lora.optimize()
                module.set_lora(index, module_lora.lora_a, module_lora.lora_b,
                                module_lora.embeddings_tensor,
                                module_lora.bias)
            else:
                module.reset_lora(index)
        return True

    def _deactivate_adapter(self, lora_id: int):
        """Release the slot held by ``lora_id``; no-op if it holds none."""
        try:
            index = self.lora_index_to_id.index(lora_id)
        except ValueError:
            return
        self.lora_index_to_id[index] = None

    def _add_adapter(self, lora: LoRAModel):
        """Pack sub-module weights in place, then register the adapter."""
        self._create_merged_loras_inplace(lora)
        self._registered_adapters[lora.id] = lora

    def pin_adapter(self, lora_id: int) -> bool:
        """Pin a LoRAModel in the manager cache."""
        raise NotImplementedError(
            "Pinning is not supported in LoRAModelManager. "
            "Use LRUCacheLoRAModelManager for pinning")  # type: ignore

    def _set_adapter_mapping(self, mapping: LoRAMapping) -> None:
        """Push a new request -> LoRA mapping into the punica wrapper."""
        # update lora states
        self.punica_wrapper.update_metadata(
            mapping,
            self.lora_index_to_id,
            self.lora_slots + 1,
            self.vocab_size,
            self.lora_config.lora_extra_vocab_size,
        )

    def remove_all_adapters(self):
        """Remove all LoRAModels from the manager."""
        self._registered_adapters.clear()
        self.lora_index_to_id = [None] * self.lora_slots
        self._active_adapters.clear()

    def _create_lora_modules(self):
        """Replace every targeted layer of the model with its LoRA wrapper."""

        def _parent_module(module_name: str) -> str:
            # module name is a dot separated name.
            # for example:
            #  - given an input 'x.y.z' return 'x.y'
            #  - given an input 'x' return ''
            return module_name.rpartition('.')[0]

        for module_name, module in self.model.named_modules(
                remove_duplicate=False):
            if isinstance(module, PPMissingLayer):
                continue
            if not self._match_target_modules(module_name):
                continue
            # A temporary approach for multimodal models to support LoRA
            # TODO: Remove this restriction
            if self._filter_unsupported_mm_module(module_name):
                logger.warning(
                    "Regarding multimodal models, vLLM currently only supports "
                    "adding LoRA to language model, %s will be ignored.",
                    module_name,
                )
                continue
            last_part = module_name.split(".")[-1]
            packed_modules_lst = self.packed_modules_mapping.get(last_part, [])
            new_module = replace_submodule(
                self.model, module_name,
                from_layer(module, self.lora_slots, self.lora_config,
                           packed_modules_lst, self.model.config))

            # (yard1): TODO make this more robust
            if "lm_head" in module_name:
                logits_processor_module_name = 'logits_processor'
                parent_module = _parent_module(module_name)
                if parent_module:
                    logits_processor_module_name = (
                        f"{parent_module}.{logits_processor_module_name}")

                logits_processor_module = self.model.get_submodule(
                    logits_processor_module_name)

                new_module = replace_submodule(
                    self.model, logits_processor_module_name,
                    from_layer_logits_processor(logits_processor_module,
                                                module, self.lora_slots,
                                                self.lora_config,
                                                self.model.config))

            # In some models, especially multimodal ones, layers with the same
            # name may have different types, such as nn.Linear and
            # ReplicatedLinear. The nn.Linear layers cannot be replaced with
            # LoRA layers, leading to assertion error. The following check
            # aims to prevent this error
            if self.supports_mm and not isinstance(new_module,
                                                   BaseLayerWithLoRA):
                continue
            self.register_module(module_name, new_module)
            self._register_packed_modules(module_name)
            # All lora layers share the same punica_wrapper based on reference.
            new_module.set_mapping(self.punica_wrapper)

    def register_module(self, module_name: str, module: "BaseLayerWithLoRA"):
        """Record a LoRA-wrapped layer under its full module name."""
        assert isinstance(module, BaseLayerWithLoRA)
        self.modules[module_name] = module

    def create_dummy_lora(
            self,
            lora_id: int,
            rank: int,
            embedding_modules: Optional[dict[str, str]] = None) -> LoRAModel:
        """Create zero-initialized LoRAModel for warmup.

        Args:
            lora_id: id of the dummy model.
            rank: LoRA rank of the dummy weights.
            embedding_modules: names whose last component marks an embedding
                layer. Must not be None when any non-packed LoRA layer
                exists.
        """
        model = LoRAModel(lora_id, rank, {})
        bias_enabled = self.lora_config.bias_enabled
        for module_name, module in self.model.named_modules():
            if (not self._match_target_modules(module_name)
                    or not isinstance(module, BaseLayerWithLoRA)
                    or self._filter_unsupported_mm_module(module_name)):
                continue
            parts = module_name.split(".")
            if module_name not in self.packed_modules:
                assert embedding_modules is not None
                if parts[-1] in embedding_modules:
                    input_dim = (module.base_layer.org_vocab_size +
                                 self.lora_config.lora_extra_vocab_size if
                                 hasattr(module.base_layer, "org_vocab_size")
                                 else module.base_layer.weight.shape[1])
                    output_dim = module.base_layer.embedding_dim if hasattr(
                        module.base_layer,
                        "embedding_dim") else module.base_layer.weight.shape[0]
                    embeddings_tensor_dim = (module.base_layer.embedding_dim if
                                             hasattr(module.base_layer,
                                                     "embedding_dim") else
                                             module.base_layer.weight.shape[1])
                    lora = LoRALayerWeights.create_dummy_lora_weights(
                        module_name,
                        input_dim,
                        output_dim,
                        rank,
                        module.lora_a_stacked[0].dtype,
                        "cpu",
                        embeddings_tensor_dim=embeddings_tensor_dim,
                        bias_enabled=bias_enabled)
                else:
                    lora = LoRALayerWeights.create_dummy_lora_weights(
                        module_name,
                        module.lora_a_stacked[0].shape[-1],
                        module.lora_b_stacked[0].shape[-2],
                        rank,
                        module.lora_a_stacked[0].dtype,
                        "cpu",
                        bias_enabled=bias_enabled,
                    )
                lora.optimize()
            else:
                # Packed layer: one dummy sub-LoRA per stacked slice.
                replacements = self.packed_modules_mapping[parts[-1]]
                subloras: list[Optional[LoRALayerWeights]] = []
                for i, r in enumerate(replacements):
                    lora = LoRALayerWeights.create_dummy_lora_weights(
                        module_name + "." + r,
                        module.lora_a_stacked[i].shape[-1],
                        module.lora_b_stacked[i].shape[-2],
                        rank,
                        module.lora_a_stacked[i].dtype,
                        "cpu",
                        bias_enabled=bias_enabled,
                    )
                    lora.optimize()
                    subloras.append(lora)
                lora = PackedLoRALayerWeights.pack(subloras)
            model.loras[module_name] = lora
        return model

    def _match_target_modules(self, module_name: str):
        """True if ``module_name`` equals a target or ends in ``.<target>``.

        Note: targets are interpolated into a regex unescaped.
        """
        return any(
            re.match(
                r".*\.{target_module}$".format(target_module=target_module),
                module_name) or target_module == module_name
            for target_module in self.supported_lora_modules)

    def _filter_unsupported_mm_module(self, module_name: str) -> bool:
        """
        Regarding multimodal models, vLLM currently only supports adding LoRA to
        language model. LoRA for other modules, such as the vision tower, will
        be filtered out.
        """
        if self.supports_mm:
            module_mapping: MultiModelKeys = self.model.get_mm_mapping()
            prefix_lst = module_mapping.connector + module_mapping.tower_model
            return any(
                module_name.startswith(prefix) for prefix in prefix_lst)
        return False

    def _register_packed_modules(self, module_full_name: str) -> None:
        """Record the full sub-module names of a packed module, if packed."""
        parts = module_full_name.split(".")
        module_name = parts[-1]
        replacements = self.packed_modules_mapping.get(module_name, [])
        # When replacements is less than or equal to 1, it indicates that this
        # module is not a packed module.
        if len(replacements) <= 1:
            return
        prefix = ".".join(parts[:-1])
        self.packed_modules[module_full_name] = [
            prefix + "." + r if prefix else r for r in replacements
        ]

    def _create_merged_loras_inplace(self, lora_model: LoRAModel) -> None:
        """Merge sub-module weights of each packed module into one entry.

        For every packed module, the adapter's weights for its sub-modules
        are combined with ``PackedLoRALayerWeights.pack`` (``None`` where a
        sub-module has none) and stored under the packed module's name. The
        sub-module entries that were found are then removed. Packed modules
        for which no sub-module has weights are left untouched.
        """
        for module_name, new_module_names in self.packed_modules.items():
            replacement_loras: list[Optional[LoRALayerWeights]] = []
            replaced_module: set[str] = set()
            for r in new_module_names:
                lora = self._get_lora_layer_weights(lora_model, r)
                # Normalise any falsy lookup result to None for pack().
                replacement_loras.append(lora if lora else None)
                if lora:
                    replaced_module.add(r)
            if not replaced_module:
                continue

            # HACK Temporary solution for the pool model.
            # NOTE(review): the inner check re-tests ``module_name``, which the
            # outer condition has just found absent, so the rename never
            # happens; ``replaced_module_name`` may have been intended.
            if self.is_pooling_model and not lora_model.check_lora_name(
                    module_name):
                replaced_module_name = module_name.replace("model.", "")
                if lora_model.check_lora_name(module_name):
                    module_name = replaced_module_name
            lora_model.loras[module_name] = PackedLoRALayerWeights.pack(
                replacement_loras)
            # Remove the modules that have been replaced.
            for module in replaced_module:
                lora_model.loras.pop(module, None)

    def _get_lora_layer_weights(
            self, lora_model: LoRAModel,
            module_name: str) -> Optional[LoRALayerWeights]:
        """Look up weights for ``module_name``.

        For pooling models, a missing name is retried without the
        ``model.`` prefix.
        """
        org_module_name = module_name
        if self.is_pooling_model and not lora_model.check_lora_name(
                module_name):
            # If it's a pool model, and the layer name is not found,
            # remove the prefix 'model.' and search again.
            module_name = module_name.replace("model.", "")
            if lora_model.check_lora_name(module_name):
                org_module_name = module_name
                logger.info_once(
                    "For the pool model, successfully loaded the LoRA weights "
                    "after removing the prefix 'model.'.")
        return lora_model.get_lora(org_module_name)

    def deactivate_adapter(self, adapter_id: int) -> bool:
        """Deactivate an adapter through the shared adapter_commons helper."""
        return deactivate_adapter(adapter_id, self._active_adapters,
                                  self._deactivate_adapter)

    def add_adapter(self, adapter: LoRAModel) -> bool:
        """Register an adapter through the shared adapter_commons helper."""
        logger.debug("Adding lora. Model id: %d, "
                     "int id: %d", adapter.id, adapter.id)
        return add_adapter(adapter, self._registered_adapters, self.capacity,
                           self._add_adapter)

    def set_adapter_mapping(self, mapping: LoRAMapping) -> None:
        """Apply ``mapping``, remembering it as the last mapping used."""
        self._last_mapping = set_adapter_mapping(mapping, self._last_mapping,
                                                 self._set_adapter_mapping)

    def remove_adapter(self, adapter_id: int) -> bool:
        """Remove an adapter through the shared adapter_commons helper."""
        return remove_adapter(adapter_id, self._registered_adapters,
                              self.deactivate_adapter)

    def list_adapters(self) -> dict[int, Any]:
        """Return the registered adapters via the adapter_commons helper."""
        return list_adapters(self._registered_adapters)

    def get_adapter(self, adapter_id: int) -> Optional[Any]:
        """Return the adapter registered under ``adapter_id``, if any."""
        return get_adapter(adapter_id, self._registered_adapters)


class LoRALRUCache(AdapterLRUCache[LoRAModel]):
    """LRU cache specialised to hold ``LoRAModel`` objects.

    Adds no behaviour of its own: ``capacity`` and ``deactivate_lora_fn``
    are forwarded unchanged to ``AdapterLRUCache``.
    """

    def __init__(self, capacity: int,
                 deactivate_lora_fn: Callable[[int], bool]):
        """Create the cache.

        Args:
            capacity: Maximum number of entries, passed to the base class.
            deactivate_lora_fn: Callback taking an adapter id and returning
                a bool; passed to the base class (presumably invoked when
                an entry is evicted or removed; see ``AdapterLRUCache``).
        """
        super().__init__(capacity, deactivate_lora_fn)


class LRUCacheLoRAModelManager(LoRAModelManager):
    """A model manager that manages multiple LoRAs with LRU cache.

    The registered-adapter collection (bounded by ``self.capacity``) and
    the active-adapter collection (bounded by ``self.lora_slots``) are both
    ``LoRALRUCache`` instances. Adding an already-registered adapter, or
    activating an adapter, refreshes its position in LRU order.
    """

    def __init__(self, model: nn.Module, max_num_seqs: int,
                 max_num_batched_tokens: int, vocab_size: int,
                 lora_config: LoRAConfig, device: torch.device):
        super().__init__(model, max_num_seqs, max_num_batched_tokens,
                         vocab_size, lora_config, device)
        # Install LRU caches for both adapter collections. The registered
        # cache is given the public ``deactivate_adapter`` callback, the
        # active cache the low-level ``_deactivate_adapter`` hook
        # (presumably called by AdapterLRUCache on eviction/removal).
        self._registered_adapters: LoRALRUCache = LoRALRUCache(
            self.capacity, self.deactivate_adapter)
        self._active_adapters: LoRALRUCache = LoRALRUCache(
            self.lora_slots, self._deactivate_adapter)

    def list_adapters(self) -> dict[int, LoRAModel]:
        """List all registered LoRAModels.

        Returns a new dict built from the registered cache's ``cache``
        mapping, so mutating the result does not touch the cache.
        """
        return dict(self._registered_adapters.cache)

    def add_adapter(self, lora: LoRAModel) -> bool:
        """Add a LoRAModel to the manager.

        Returns:
            True if ``lora`` was newly registered via ``_add_adapter``;
            False if its id was already registered (in which case it is
            only touched to refresh its LRU position).
        """
        logger.debug("Adding lora. Model id: %d, "
                     "int id: %d", lora.id, lora.id)
        if lora.id not in self._registered_adapters:
            self._add_adapter(lora)
            was_added = True
        else:
            # We always touch to update the LRU cache order
            self._registered_adapters.touch(lora.id)
            was_added = False
        return was_added

    def activate_adapter(
        self,
        lora_id: int,
    ) -> bool:
        """Activate ``lora_id``, evicting the oldest active adapter if full.

        Eviction happens only when ``lora_id`` is not already active and
        the active cache already holds ``self.lora_slots`` entries.

        Returns:
            The result of the base-class ``activate_adapter``.
        """
        if lora_id not in self._active_adapters and len(
                self._active_adapters) >= self.lora_slots:
            self._active_adapters.remove_oldest()
        result = super().activate_adapter(lora_id)
        # We always touch to update the LRU cache order
        self._active_adapters.touch(lora_id)
        return result

    def remove_oldest_adapter(self) -> bool:
        """Evict the least-recently-used registered adapter.

        Returns:
            True if an adapter was removed, False if none were registered.
        """
        if len(self._registered_adapters) > 0:
            self._registered_adapters.remove_oldest()
            return True
        return False

    def pin_adapter(self, lora_id: int) -> bool:
        """Pin a LoRAModel in the manager cache.

        Pins ``lora_id`` in the registered ("CPU") cache and then in the
        active ("GPU") cache, activating it first if needed.

        Raises:
            ValueError: If ``lora_id`` is not registered.

        Returns:
            Always True when no exception is raised.
        """
        self._pin_lora_in_cpu_cache(lora_id)
        self._pin_lora_in_gpu_cache(lora_id)
        return True

    def _pin_lora_in_cpu_cache(self, lora_id: int):
        """Pin ``lora_id`` in the registered cache.

        A ValueError from the cache's ``pin`` is re-raised with a clearer
        message, chained to the original error.
        """
        try:
            self._registered_adapters.pin(lora_id)
        except ValueError as err:
            raise ValueError("Pinning failed. "
                             f"LoRA {lora_id} is not registered.") from err

    def _pin_lora_in_gpu_cache(self, lora_id: int):
        """Pin ``lora_id`` in the active cache, activating it first."""
        if lora_id not in self._active_adapters:
            # move lora to gpu if not already active
            self.activate_adapter(lora_id)

        self._active_adapters.pin(lora_id)

def create_lora_manager(
        model: nn.Module,
        max_num_seqs: int,
        max_num_batched_tokens: int,
        vocab_size: int,
        lora_config: LoRAConfig,
        device: torch.device,
        lora_manager_cls: type[LoRAModelManager] = LoRAModelManager,
        **kwargs) -> LoRAModelManager:
    """Create a LoRA model manager for a given model.

    Args:
        model: The model to manage; must be an instance of
            ``SupportsLoRA``.
        max_num_seqs: Forwarded to the manager constructor.
        max_num_batched_tokens: Forwarded to the manager constructor.
        vocab_size: Forwarded to the manager constructor.
        lora_config: LoRA configuration, forwarded to the constructor.
        device: Device, forwarded to the manager constructor.
        lora_manager_cls: Manager class to instantiate; defaults to
            ``LoRAModelManager``.
        **kwargs: Extra keyword arguments forwarded to ``lora_manager_cls``.

    Returns:
        The newly constructed ``lora_manager_cls`` instance.

    Raises:
        ValueError: If ``model`` is not an instance of ``SupportsLoRA``.
    """
    if not isinstance(model, SupportsLoRA):
        raise ValueError(f"Model {type(model)} is not supported for LoRA.")
    return lora_manager_cls(
        model=model,
        max_num_seqs=max_num_seqs,
        max_num_batched_tokens=max_num_batched_tokens,
        vocab_size=vocab_size,
        lora_config=lora_config,
        device=device,
        **kwargs)