models.py 33.2 KB
Newer Older
1
# SPDX-License-Identifier: Apache-2.0
2
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
3

4
5
import math
import os
6
7
from collections.abc import Callable
from typing import TypeVar
8

9
import regex as re
10
11
12
13
import safetensors.torch
import torch
from torch import nn

14
from vllm.config.lora import LoRAConfig
15
from vllm.logger import init_logger
16
from vllm.lora.layers import BaseLayerWithLoRA, FusedMoEWithLoRA, LoRAMapping
17
from vllm.lora.lora_weights import LoRALayerWeights, PackedLoRALayerWeights
18
from vllm.lora.peft_helper import PEFTHelper
19
from vllm.lora.punica_wrapper import get_punica_wrapper
20
21
22
23
from vllm.lora.utils import (
    from_layer,
    from_layer_logits_processor,
    get_supported_lora_modules,
24
    is_base_embeddding_weights,
25
26
    is_regex_target_modules,
    parse_fine_tuned_lora_name,
27
    process_packed_modules_mapping,
28
29
    replace_submodule,
)
30
from vllm.model_executor.model_loader.tensorizer import TensorizerConfig
31
from vllm.model_executor.models import SupportsLoRA, supports_multimodal
32
from vllm.model_executor.models.interfaces import is_pooling_model
33
from vllm.model_executor.models.module_mapping import MultiModelKeys
34
from vllm.model_executor.models.utils import PPMissingLayer, WeightsMapper
35
from vllm.utils.cache import LRUCache
36
from vllm.utils.platform_utils import is_pin_memory_available
37

38
logger = init_logger(__name__)

# Value type stored in an ``AdapterLRUCache`` (for example a LoRA model).
T = TypeVar("T")


class AdapterLRUCache(LRUCache[int, T]):
    """LRU cache keyed by adapter int id that deactivates removed adapters.

    ``_on_remove`` is the removal hook of the base ``LRUCache``; here it first
    hands the removed key to ``deactivate_fn`` and then defers to the base
    implementation.
    """

    def __init__(self, capacity: int, deactivate_fn: Callable[[int], object]):
        super().__init__(capacity)
        # Callback receiving the int id of each adapter leaving the cache.
        self.deactivate_fn = deactivate_fn

    def _on_remove(self, key: int, value: T | None):
        """Deactivate adapter ``key``, then run the base-class removal hook."""
        logger.debug("Removing adapter int id: %d", key)
        deactivate = self.deactivate_fn
        deactivate(key)
        return super()._on_remove(key, value)


_GLOBAL_LORA_ID = 0


def get_lora_id():
    """Allocate and return the next LoRA integer id for this process.

    Ids start at 1 and grow by one on every call. The counter is a plain
    module global updated without any locking.
    """
    global _GLOBAL_LORA_ID
    next_id = _GLOBAL_LORA_ID + 1
    _GLOBAL_LORA_ID = next_id
    return next_id


class LoRAModel:
    """A LoRA fine-tuned model."""

    def __init__(
        self,
        lora_model_id: int,
        rank: int,
        loras: dict[str, LoRALayerWeights],
    ) -> None:
        """
        Args:
            lora_model_id: The integer id for the lora model. Must be > 0.
            rank: lora rank.
            loras: module name -> weights for lora-replaced layers.
        """
        self.id = lora_model_id
        assert lora_model_id > 0, (
            f"a valid lora id should be greater than 0, got {self.id}"
        )
        self.rank = rank
        self.loras: dict[str, LoRALayerWeights] = loras

    def clone(self, lora_model_id: int) -> "LoRAModel":
        """Return a copy of the object with different ids.

        Will share the underlying tensors."""
        return self.__class__(
            lora_model_id,
            rank=self.rank,
            loras=self.loras.copy(),
        )

    def get_lora(self, module_name: str) -> LoRALayerWeights | None:
        """Get LoRA for a given module by name"""
        return self.loras.get(module_name, None)

    def check_lora_name(self, lora_name: str) -> bool:
        """Return True if this model holds weights for ``lora_name``."""
        return lora_name in self.loras

    # (yard1): TODO see if we can derive target_embedding_padding automatically
    @classmethod
    def from_lora_tensors(
        cls,
        lora_model_id: int,
        tensors: dict[str, torch.Tensor],
        peft_helper: PEFTHelper,
        device: str = "cuda",
        dtype: torch.dtype | None = None,
        target_embedding_padding: int | None = None,
        embedding_modules: dict[str, str] | None = None,
        embedding_padding_modules: list[str] | None = None,
        weights_mapper: WeightsMapper | None = None,
    ) -> "LoRAModel":
        """Create a LoRAModel from a dictionary of tensors.

        Args:
            lora_model_id: Integer id for the resulting model (must be > 0).
            tensors: Checkpoint tensor name -> tensor. Names for which
                ``is_base_embeddding_weights`` is true are skipped; every other
                name is split into a module name and an A/B flag by
                ``parse_fine_tuned_lora_name``.
            peft_helper: Adapter configuration; supplies the rank and is used
                to build each ``LoRALayerWeights`` via ``from_config``.
            device: Device the LoRA tensors are moved to. When it is ``"cpu"``
                and ``is_pin_memory_available()`` is true, tensors are pinned.
            dtype: Target dtype for the tensors (``None`` keeps the original).
            target_embedding_padding: If given, the ``lora_b`` of any module
                matching ``embedding_padding_modules`` is zero-padded at the
                end of dim 0 up to this size.
            embedding_modules: Not used by this method.
            embedding_padding_modules: Substrings identifying modules whose
                ``lora_b`` may be padded. ``None`` means no module is padded.
            weights_mapper: Optional name remapping applied while parsing
                tensor names.

        Returns:
            The assembled LoRAModel.
        """
        pin_memory = str(device) == "cpu" and is_pin_memory_available()
        # A None value used to trip an assert on the first lora_b tensor; treat
        # it as "no modules need embedding padding" instead.
        padding_modules = embedding_padding_modules or []
        loras: dict[str, LoRALayerWeights] = {}
        for tensor_name, tensor in tensors.items():
            if is_base_embeddding_weights(tensor_name):
                continue
            module_name, is_lora_a = parse_fine_tuned_lora_name(
                tensor_name, weights_mapper
            )
            if module_name not in loras:
                loras[module_name] = LoRALayerWeights.from_config(
                    module_name, peft_helper
                )
            layer = loras[module_name]

            if is_lora_a:
                layer.lora_a = tensor.to(device=device, dtype=dtype)
                if pin_memory:
                    layer.lora_a = layer.lora_a.pin_memory()
                continue

            layer.lora_b = tensor.to(device=device, dtype=dtype)
            if target_embedding_padding is not None and any(
                name in module_name for name in padding_modules
            ):
                lora_b = layer.lora_b
                assert target_embedding_padding >= lora_b.shape[0]
                addition = target_embedding_padding - lora_b.shape[0]
                # (0, 0, 0, addition): no padding on the last dim, `addition`
                # rows appended after dim 0.
                layer.lora_b = torch.nn.functional.pad(lora_b, (0, 0, 0, addition))
            if pin_memory:
                layer.lora_b = layer.lora_b.pin_memory()

        for lora in loras.values():
            lora.optimize()

        return cls(lora_model_id, peft_helper.r, loras)

    @classmethod
    def from_local_checkpoint(
        cls,
        lora_dir: str,
        expected_lora_modules: list[str],
        peft_helper: PEFTHelper,
        *,
        lora_model_id: int | None = None,
        device: str = "cuda",
        dtype: torch.dtype | None = None,
        target_embedding_padding: int | None = None,
        embedding_modules: dict[str, str] | None = None,
        embedding_padding_modules: list[str] | None = None,
        weights_mapper: WeightsMapper | None = None,
        tensorizer_config_dict: dict | None = None,
    ) -> "LoRAModel":
        """Create a LoRAModel from a local checkpoint.

        Tensors come from the first applicable source: a tensorizer
        ``adapter_model.tensors`` file (when ``tensorizer_config_dict`` is
        given), then ``adapter_model.safetensors``, then ``adapter_model.bin``
        or ``adapter_model.pt`` inside ``lora_dir``.

        Args:
            lora_dir: The local path that has lora data.
            expected_lora_modules: Name of modules that are expected to be
                replaced by lora.
            peft_helper: Loaded lora configuration information.
            lora_model_id: LoRA model id. If not given, automatically set by
                a global counter.
            device: Device where the lora model is loaded.
            dtype: dtype of the lora model weights.
            target_embedding_padding: Forwarded to ``from_lora_tensors``.
            embedding_modules: Forwarded to ``from_lora_tensors``.
            embedding_padding_modules: Forwarded to ``from_lora_tensors``.
            weights_mapper: Optional tensor-name remapping, used both for the
                module check and when building the model.
            tensorizer_config_dict: Keyword arguments for ``TensorizerConfig``.
                When set, weights are deserialized from
                ``<tensorizer_dir>/adapter_model.tensors``.

        Returns:
            Loaded LoRA Model.

        Raises:
            ValueError: If the checkpoint targets modules outside
                ``expected_lora_modules``, or ``lora_dir`` holds no tensors.
        """
        lora_tensor_path = os.path.join(lora_dir, "adapter_model.safetensors")
        lora_bin_file_path = os.path.join(lora_dir, "adapter_model.bin")
        lora_pt_file_path = os.path.join(lora_dir, "adapter_model.pt")
        tensors: dict[str, torch.Tensor] = {}

        def raise_unexpected_modules(unexpected_modules: list[str]) -> None:
            # Single source for the error raised by both validation paths.
            raise ValueError(
                f"While loading {lora_dir}, expected"
                f" target modules in {expected_lora_modules}"
                f" but received {unexpected_modules}."
                " Please verify that the loaded LoRA module is correct"
            )

        def check_unexpected_modules(modules: dict):
            """Raise if any stored LoRA tensor targets an unexpected module."""
            unexpected_modules: list[str] = []
            for lora_module in modules.keys():  # noqa
                if is_base_embeddding_weights(lora_module):
                    continue
                module_name, _ = parse_fine_tuned_lora_name(lora_module, weights_mapper)
                # Handle FSDP file format where experts.base_layer is the
                # gate_up_proj and experts is the down_proj
                if "base_layer" in lora_module:
                    continue
                if ".experts" in module_name:
                    # Case for expert lora weights: match on the name suffix.
                    if not any(
                        module_name.endswith(ele) for ele in expected_lora_modules
                    ):
                        unexpected_modules.append(module_name)
                elif module_name.split(".")[-1] not in expected_lora_modules:
                    unexpected_modules.append(module_name)
            if unexpected_modules:
                raise_unexpected_modules(unexpected_modules)

        if tensorizer_config_dict:
            from tensorizer import TensorDeserializer

            tensorizer_config = TensorizerConfig(**tensorizer_config_dict)
            lora_tensor_path = os.path.join(
                tensorizer_config.tensorizer_dir, "adapter_model.tensors"
            )
            tensorizer_args = tensorizer_config._construct_tensorizer_args()
            tensors = TensorDeserializer(
                lora_tensor_path,
                dtype=tensorizer_config.dtype,
                **tensorizer_args.deserialization_kwargs,
            )
            check_unexpected_modules(tensors)

        elif os.path.isfile(lora_tensor_path):
            # Use safetensor key as a source of truth to find expected modules.
            # in peft if you have target_modules A, B, C and C does not exist
            # in the model it won't error and model will be trained with A, B
            # loraified. C won't exist in the safetensor but it will exist in
            # the target_modules of the adapter_config.json.
            with safetensors.safe_open(lora_tensor_path, framework="pt") as f:  # type: ignore
                # Load tensors only if there are only expected modules.
                check_unexpected_modules(f)
                for module in f.keys():  # noqa
                    tensors[module] = f.get_tensor(module)
        elif os.path.isfile(lora_bin_file_path) or os.path.isfile(lora_pt_file_path):
            # When a bin/pt file is provided, we rely on config to find
            # unexpected modules.
            target_modules = peft_helper.target_modules
            if not isinstance(target_modules, list):
                target_modules = [target_modules]
            # Compare only the last name component so entries such as
            # "layers.11.self_attn.k_proj" are accepted too.
            unexpected_modules = [
                module
                for module in target_modules
                if module.split(".")[-1] not in expected_lora_modules
            ]
            # loaded lora's target modules must be a subset of
            # expected_lora_modules. It is not reliable. See
            # https://github.com/vllm-project/vllm/pull/5909. But there's no
            # other better mechanism.
            if unexpected_modules and not is_regex_target_modules(
                peft_helper.target_modules, expected_lora_modules
            ):
                raise_unexpected_modules(unexpected_modules)
            lora_file_path = (
                lora_bin_file_path
                if os.path.isfile(lora_bin_file_path)
                else lora_pt_file_path
            )
            tensors = torch.load(lora_file_path, map_location=device, weights_only=True)
        else:
            raise ValueError(f"{lora_dir} doesn't contain tensors")

        return cls.from_lora_tensors(
            lora_model_id=get_lora_id() if lora_model_id is None else lora_model_id,
            tensors=tensors,
            peft_helper=peft_helper,
            device=device,
            dtype=dtype,
            target_embedding_padding=target_embedding_padding,
            embedding_modules=embedding_modules,
            embedding_padding_modules=embedding_padding_modules,
            weights_mapper=weights_mapper,
        )


class LoRAModelManager:
    """A manager that manages multiple LoRA-fine-tuned models."""

    def __init__(
        self,
        model: SupportsLoRA,
        max_num_seqs: int,
        max_num_batched_tokens: int,
        vocab_size: int,
        lora_config: LoRAConfig,
        device: torch.device,
    ):
        """Create a LoRAModelManager and adapter for a given model.

        Args:
            model: the model to be adapted.
            max_num_seqs: the maximum number of sequences model can run in a
                single batch.
            max_num_batched_tokens: the maximum number of tokens model can run
                in a single batch.
            vocab_size: the vocab size of the model.
            lora_config: the LoRA configuration.
            device: device handed to the punica wrapper.

        Side effects: supported submodules of ``model`` are replaced in place
        with LoRA-enabled layers, and ``model.lora_manager`` is set to this
        manager.
        """
        self.model: SupportsLoRA = model
        self._registered_adapters: dict[int, LoRAModel] = {}
        # Dict instead of a set for compatibility with LRUCache.
        self._active_adapters: dict[int, None] = {}
        self.adapter_type = "LoRA"
        self.lora_config = lora_config
        self.device = device
        self.max_num_seqs = max_num_seqs
        assert self.capacity >= self.lora_slots
        # Rounded up to the next multiple of 8.
        self.max_num_batched_tokens = math.ceil(max_num_batched_tokens / 8) * 8
        # Slot index -> id of the LoRA occupying that slot (None when free).
        self.lora_index_to_id: list[int | None] = [None] * self.lora_slots
        self.vocab_size = vocab_size
        self.punica_wrapper = get_punica_wrapper(
            max_num_batched_tokens,
            max_batches=self.max_num_seqs,
            device=self.device,
            max_loras=self.lora_config.max_loras,
        )

        self.supported_lora_modules = get_supported_lora_modules(self.model)
        assert self.supported_lora_modules, (
            "No supported LoRA modules found in"
            f" {self.model.__class__.__name__}."
        )

        self.packed_modules_mapping = process_packed_modules_mapping(self.model)
        # Used to indicate whether the model is a multimodal model
        self.supports_mm: bool = (
            supports_multimodal(self.model)
            # In case the model only supports LoRA for
            # text modules (e.g. ChatGLM)
            and hasattr(self.model, "get_mm_mapping")
        )
        self.is_pooling_model = is_pooling_model(self.model)
        # Packed module full name -> full names of its sub-modules.
        self.packed_modules: dict[str, list[str]] = {}
        self.modules: dict[str, BaseLayerWithLoRA] = {}
        self._last_mapping: LoRAMapping | None = None
        self._create_lora_modules()
        self.model.lora_manager = self

    def __len__(self) -> int:
        return len(self._registered_adapters)

    @property
    def capacity(self) -> int:
        return self.lora_config.max_cpu_loras

    @property
    def lora_slots(self) -> int:
        return self.lora_config.max_loras

    @property
    def adapter_slots(self) -> int:
        return self.lora_slots

    def activate_adapter(
        self,
        lora_id: int,
    ) -> bool:
        """Move LoRA into a GPU buffer to be used in the forward pass.

        Returns:
            False if the LoRA is already active, True once it has been written
            into the first free slot.

        Raises:
            ValueError: If no slot is free.
        """
        if lora_id in self._active_adapters:
            return False
        index = next(
            (i for i, slot_id in enumerate(self.lora_index_to_id) if slot_id is None),
            None,
        )
        if index is None:
            raise ValueError("No free lora slots")
        self._active_adapters[lora_id] = None
        lora_model = self._registered_adapters[lora_id]
        logger.debug(
            "Activating LoRA. int id: %d, slot index: %d", lora_model.id, index
        )
        self.lora_index_to_id[index] = lora_model.id
        for module_name, module in self.modules.items():
            module_lora = self._get_lora_layer_weights(lora_model, module_name)
            if not module_lora:
                module.reset_lora(index)
                continue
            # Note (gnovack) - If MOE lora weights are not split into
            # num_experts chunks, we split them here
            if isinstance(module, FusedMoEWithLoRA) and torch.is_tensor(
                module_lora.lora_a
            ):
                self._split_fused_moe_lora(lora_model, module_name, module_lora)
            module.set_lora(
                index,
                module_lora.lora_a,
                module_lora.lora_b,
            )
        return True

    def _split_fused_moe_lora(
        self,
        lora_model: LoRAModel,
        module_name: str,
        module_lora: LoRALayerWeights,
    ) -> None:
        """Split stacked MoE LoRA tensors into per-expert lists, in place.

        Handles the FSDP file format where ``experts.base_layer`` holds the
        gate_up_proj weights and ``experts`` holds the down_proj weights.
        Afterwards ``module_lora.lora_a`` / ``lora_b`` are lists ordered
        ``[gate_0, down_0, up_0, gate_1, down_1, up_1, ...]``.
        """
        gate_up_proj_lora = self._get_lora_layer_weights(
            lora_model, module_name + ".base_layer"
        )
        assert gate_up_proj_lora is not None

        down_proj_lora = module_lora
        num_experts = module_lora.lora_a.shape[0] // module_lora.rank

        gate_proj_a = gate_up_proj_lora.lora_a.chunk(num_experts, dim=0)
        # NOTE(review): gate and up reuse the same A chunks; confirm this
        # matches the expected checkpoint layout.
        up_proj_a = gate_up_proj_lora.lora_a.chunk(num_experts, dim=0)

        # lora_b rows alternate: even rows -> gate, odd rows -> up.
        gate_proj_b = gate_up_proj_lora.lora_b[::2, ...].chunk(num_experts, dim=-1)
        up_proj_b = gate_up_proj_lora.lora_b[1::2, ...].chunk(num_experts, dim=-1)

        down_proj_a = down_proj_lora.lora_a.chunk(num_experts, dim=0)
        down_proj_b = down_proj_lora.lora_b.chunk(num_experts, dim=-1)

        lora_a: list[torch.Tensor] = []
        lora_b: list[torch.Tensor] = []
        for i in range(num_experts):
            lora_a.extend((gate_proj_a[i], down_proj_a[i], up_proj_a[i]))
            lora_b.extend((gate_proj_b[i], down_proj_b[i], up_proj_b[i]))

        module_lora.lora_a = lora_a
        module_lora.lora_b = lora_b

    def _deactivate_adapter(self, lora_id: int):
        """Free the slot held by ``lora_id``; no-op if it holds none."""
        try:
            index = self.lora_index_to_id.index(lora_id)
        except ValueError:
            # Not resident in any slot: nothing to free.
            return
        self.lora_index_to_id[index] = None

    def _add_adapter(self, lora: LoRAModel):
        self._create_merged_loras_inplace(lora)
        self._registered_adapters[lora.id] = lora

    def pin_adapter(self, lora_id: int) -> bool:
        """Pin a LoRAModel in the manager cache."""
        raise NotImplementedError(
            "Pinning is not supported in LoRAModelManager. "
            "Use LRUCacheLoRAModelManager for pinning"
        )  # type: ignore

    def _set_adapter_mapping(self, mapping: LoRAMapping) -> None:
        # update lora states
        self.punica_wrapper.update_metadata(
            mapping,
            self.lora_index_to_id,
            self.lora_slots + 1,
            self.vocab_size,
        )

    def remove_all_adapters(self):
        """Remove all LoRAModels from the manager."""
        self._registered_adapters.clear()
        self.lora_index_to_id = [None] * self.lora_slots
        self._active_adapters.clear()

    def _create_lora_modules(self):
        """Swap every supported submodule of the model for its LoRA layer."""

        def _parent_module(module_name: str) -> str:
            # module name is a dot separated name.
            # for example:
            #  - given an input 'x.y.z' return 'x.y'
            #  - given an input 'x' return ''
            return module_name.rpartition(".")[0]

        for module_name, module in self.model.named_modules(remove_duplicate=False):
            if isinstance(module, PPMissingLayer):
                continue

            if not self._match_target_modules(module_name):
                continue
            # A temporary approach for multimodal models to support LoRA
            # TODO: Remove this restriction
            if self._filter_unsupported_mm_module(module_name):
                logger.warning(
                    "Regarding multimodal models, vLLM currently only supports "
                    "adding LoRA to language model, %s will be ignored.",
                    module_name,
                )
                continue
            leaf_name = module_name.split(".")[-1]
            packed_moduled_lst = self.packed_modules_mapping.get(leaf_name, [])
            new_module = replace_submodule(
                self.model,
                module_name,
                from_layer(
                    module,
                    self.lora_slots,
                    self.lora_config,
                    packed_moduled_lst,
                    self.model.config,
                ),
            )

            # (yard1): TODO make this more robust
            if "lm_head" in module_name:
                logits_processor_module_name = "logits_processor"
                parent_module = _parent_module(module_name)
                if parent_module:
                    logits_processor_module_name = (
                        f"{parent_module}.{logits_processor_module_name}"
                    )

                logits_processor_module = self.model.get_submodule(
                    logits_processor_module_name
                )

                new_module = replace_submodule(
                    self.model,
                    logits_processor_module_name,
                    from_layer_logits_processor(
                        logits_processor_module,
                        module,
                        self.lora_slots,
                        self.lora_config,
                        self.model.config,
                    ),
                )

            # In some models, especially multimodal ones, layers with the same
            # name may have different types, such as nn.Linear and
            # ReplicatedLinear. The nn.Linear layers cannot be replaced with
            # LoRA layers, leading to assertion error. The following check
            # aims to prevent this error
            if self.supports_mm and not isinstance(new_module, BaseLayerWithLoRA):
                continue
            self.register_module(module_name, new_module)
            self._register_packed_modules(module_name)
            # All lora layers share the same punica_wrapper based on reference.
            new_module.set_mapping(self.punica_wrapper)

    def register_module(self, module_name: str, module: "BaseLayerWithLoRA"):
        """Record ``module`` as the LoRA layer for ``module_name``."""
        assert isinstance(module, BaseLayerWithLoRA), (
            f"Module {module_name} must be a BaseLayerWithLoRA instance,"
            f" got {type(module)}"
        )
        self.modules[module_name] = module

    def create_dummy_lora(
        self,
        lora_id: int,
        rank: int,
        embedding_modules: dict[str, str] | None = None,
    ) -> LoRAModel:
        """Create zero-initialized LoRAModel for warmup.

        Args:
            lora_id: Integer id for the dummy model.
            rank: LoRA rank of every dummy layer.
            embedding_modules: Names (last path component) of embedding
                modules. Their dims are taken from the base layer instead of
                the stacked LoRA buffers. ``None`` means no embedding modules.

        Returns:
            A LoRAModel with one dummy CPU weight entry per LoRA layer.
        """
        if embedding_modules is None:
            embedding_modules = {}
        model = LoRAModel(lora_id, rank, {})
        for module_name, module in self.model.named_modules():
            if (
                not self._match_target_modules(module_name)
                or not isinstance(module, BaseLayerWithLoRA)
                or self._filter_unsupported_mm_module(module_name)
            ):
                continue
            leaf_name = module_name.split(".")[-1]
            if module_name in self.packed_modules:
                replacements = self.packed_modules_mapping[leaf_name]
                subloras: list[LoRALayerWeights | None] = [
                    LoRALayerWeights.create_dummy_lora_weights(
                        module_name + "." + r,
                        module.lora_a_stacked[i].shape[-1],
                        module.lora_b_stacked[i].shape[-2],
                        rank,
                        module.lora_a_stacked[i].dtype,
                        "cpu",
                    )
                    for i, r in enumerate(replacements)
                ]
                lora = PackedLoRALayerWeights.pack(subloras)
            elif leaf_name in embedding_modules:
                base = module.base_layer
                input_dim = (
                    base.org_vocab_size
                    if hasattr(base, "org_vocab_size")
                    else base.weight.shape[1]
                )
                output_dim = (
                    base.embedding_dim
                    if hasattr(base, "embedding_dim")
                    else base.weight.shape[0]
                )
                lora = LoRALayerWeights.create_dummy_lora_weights(
                    module_name,
                    input_dim,
                    output_dim,
                    rank,
                    module.lora_a_stacked[0].dtype,
                    "cpu",
                )
            else:
                lora = LoRALayerWeights.create_dummy_lora_weights(
                    module_name,
                    module.lora_a_stacked[0].shape[-1],
                    module.lora_b_stacked[0].shape[-2],
                    rank,
                    module.lora_a_stacked[0].dtype,
                    "cpu",
                )
            model.loras[module_name] = lora
        return model

    def _match_target_modules(self, module_name: str):
        """Return True if ``module_name`` equals, or ends in ``.<name>`` for,
        a supported LoRA module name."""
        return any(
            # Escape the name: it is a literal module name, not a pattern.
            re.match(rf".*\.{re.escape(target_module)}$", module_name)
            or target_module == module_name
            for target_module in self.supported_lora_modules
        )

    def _filter_unsupported_mm_module(self, module_name: str) -> bool:
        """
        Regarding multimodal models, vLLM currently only supports adding LoRA to
        language model. LoRA for other modules, such as the vision tower, will
        be filtered out.
        """
        if self.supports_mm:
            module_mapping: MultiModelKeys = self.model.get_mm_mapping()
            prefix_lst = module_mapping.connector + module_mapping.tower_model
            return any(module_name.startswith(prefix) for prefix in prefix_lst)
        return False

    def _register_packed_modules(self, module_full_name: str) -> None:
        """Record the full sub-module names of ``module_full_name`` if packed."""
        prefix, _, module_name = module_full_name.rpartition(".")
        replacements = self.packed_modules_mapping.get(module_name, [])
        # When replacements is less than or equal to 1, it indicates that this
        # module is not a packed module.
        if len(replacements) <= 1:
            return
        self.packed_modules[module_full_name] = [
            prefix + "." + r if prefix else r for r in replacements
        ]

    def _create_merged_loras_inplace(self, lora_model: LoRAModel) -> None:
        """Combine sub-module LoRAs of each packed module into one entry.

        For every packed module with at least one sub-module LoRA present, the
        sub-module weights (``None`` where missing) are packed with
        ``PackedLoRALayerWeights.pack`` and stored under the packed name. The
        individual sub-module entries that were found are then removed.
        """
        for module_name, new_module_names in self.packed_modules.items():
            replacement_loras: list[LoRALayerWeights | None] = []
            replaced_module: set[str] = set()
            for r in new_module_names:
                lora = self._get_lora_layer_weights(lora_model, r)
                # Normalise missing/falsy entries to None for pack().
                replacement_loras.append(lora if lora else None)
                if lora:
                    replaced_module.add(r)
            if not replaced_module:
                continue
            # HACK Temporary solution for the pool model: mirror the lookup in
            # _get_lora_layer_weights and use the prefix-stripped name when the
            # adapter stores its weights under it. (This previously re-checked
            # the unstripped name, so the branch could never fire.)
            if self.is_pooling_model and not lora_model.check_lora_name(module_name):
                replaced_module_name = module_name.replace("model.", "")
                if lora_model.check_lora_name(replaced_module_name):
                    module_name = replaced_module_name
            lora_model.loras[module_name] = PackedLoRALayerWeights.pack(
                replacement_loras
            )
            # Remove the modules that have been replaced.
            for module in replaced_module:
                lora_model.loras.pop(module, None)

    def _get_lora_layer_weights(
        self, lora_model: LoRAModel, module_name: str
    ) -> LoRALayerWeights | None:
        """Look up ``module_name`` in ``lora_model``.

        For pooling models, if the exact name is absent, the lookup is retried
        with every ``"model."`` substring removed.
        """
        if self.is_pooling_model and not lora_model.check_lora_name(module_name):
            stripped_name = module_name.replace("model.", "")
            if lora_model.check_lora_name(stripped_name):
                logger.info_once(
                    "For the pool model, successfully loaded the LoRA weights "
                    "after removing the prefix 'model.'."
                )
                return lora_model.get_lora(stripped_name)
        return lora_model.get_lora(module_name)

    def deactivate_adapter(self, adapter_id: int) -> bool:
        if adapter_id not in self._active_adapters:
            return False
        self._deactivate_adapter(adapter_id)
        self._active_adapters.pop(adapter_id, None)
        return True

    def add_adapter(self, adapter: LoRAModel) -> bool:
        logger.debug("Adding lora. Model id: %d, int id: %d", adapter.id, adapter.id)
        if adapter.id in self._registered_adapters:
            return False
        if len(self._registered_adapters) >= self.capacity:
            raise RuntimeError("No free adapter slots.")
        self._add_adapter(adapter)
        return True

    def set_adapter_mapping(self, mapping: LoRAMapping) -> None:
        # Skip the metadata update when the mapping has not changed.
        if self._last_mapping != mapping:
            self._set_adapter_mapping(mapping)
            self._last_mapping = mapping

    def remove_adapter(self, adapter_id: int) -> bool:
        self.deactivate_adapter(adapter_id)
        if adapter_id not in self._registered_adapters:
            return False
        self._registered_adapters.pop(adapter_id, None)
        return True

    def list_adapters(self) -> dict[int, LoRAModel]:
        return dict(self._registered_adapters)

    def get_adapter(self, adapter_id: int) -> LoRAModel | None:
        return self._registered_adapters.get(adapter_id)


class LoRALRUCache(AdapterLRUCache[LoRAModel]):
    """LRU cache of LoRAModels keyed by adapter int id.

    ``deactivate_lora_fn`` is passed to ``AdapterLRUCache`` as its
    ``deactivate_fn``. That class invokes it with the key of each entry from
    its ``_on_remove`` hook.
    """

    def __init__(self, capacity: int, deactivate_lora_fn: Callable[[int], bool]):
        super().__init__(capacity, deactivate_lora_fn)


class LRUCacheLoRAModelManager(LoRAModelManager):
    """A model manager that manages multiple LoRAs with LRU cache.

    Both adapter registries are ``LoRALRUCache`` instances:

    * ``_registered_adapters`` (capacity ``self.capacity``) holds every known
      LoRAModel. Removing an entry calls ``deactivate_adapter`` for its id.
    * ``_active_adapters`` (capacity ``self.lora_slots``) tracks activated
      adapters. Removing an entry calls ``_deactivate_adapter`` for its id.

    ``pin_adapter`` pins an adapter in both caches.
    """

    def __init__(
        self,
        model: nn.Module,
        max_num_seqs: int,
        max_num_batched_tokens: int,
        vocab_size: int,
        lora_config: LoRAConfig,
        device: torch.device,
    ):
        super().__init__(
            model, max_num_seqs, max_num_batched_tokens, vocab_size, lora_config, device
        )
        # Dropping a registered adapter also deactivates it (no-op if inactive).
        self._registered_adapters: LoRALRUCache = LoRALRUCache(
            self.capacity, self.deactivate_adapter
        )
        # Use the low-level hook here: the cache itself is already removing
        # the entry, so the public deactivate_adapter's own pop is not needed.
        self._active_adapters: LoRALRUCache = LoRALRUCache(
            self.lora_slots, self._deactivate_adapter
        )

    def list_adapters(self) -> dict[int, LoRAModel]:
        """List all registered LoRAModels."""
        return dict(self._registered_adapters.cache)

    def add_adapter(self, lora: LoRAModel) -> bool:
        """Add a LoRAModel to the manager.

        Unlike the base manager there is no explicit capacity check here;
        presumably the LRU cache makes room itself when it is full.

        Returns:
            True if ``lora`` was newly added; False if an adapter with the
            same id was already registered (it is then only marked as most
            recently used).
        """
        logger.debug("Adding lora. Model id: %d, int id: %d", lora.id, lora.id)
        if lora.id in self._registered_adapters:
            # We always touch to update the LRU cache order
            self._registered_adapters.touch(lora.id)
            return False
        self._add_adapter(lora)
        return True

    def activate_adapter(
        self,
        lora_id: int,
    ) -> bool:
        """Activate ``lora_id``, evicting the oldest active adapter if needed.

        An eviction happens only when ``lora_id`` is not yet active and all
        ``self.lora_slots`` active slots are taken. Afterwards ``lora_id`` is
        always marked as most recently used.

        Returns:
            Whatever the base class ``activate_adapter`` returns.
        """
        if (
            lora_id not in self._active_adapters
            and len(self._active_adapters) >= self.lora_slots
        ):
            self._active_adapters.remove_oldest()
        result = super().activate_adapter(lora_id)
        # We always touch to update the LRU cache order
        self._active_adapters.touch(lora_id)
        return result

    def remove_oldest_adapter(self) -> bool:
        """Remove the oldest registered adapter.

        Returns:
            True if an adapter was removed; False if none are registered.
        """
        if len(self._registered_adapters) > 0:
            self._registered_adapters.remove_oldest()
            return True
        return False

    def pin_adapter(self, lora_id: int) -> bool:
        """Pin a LoRAModel in the manager cache."""
        self._pin_lora_in_cpu_cache(lora_id)
        self._pin_lora_in_gpu_cache(lora_id)
        return True

    def _pin_lora_in_cpu_cache(self, lora_id: int):
        """Pin ``lora_id`` in the registered-adapter cache.

        Raises:
            ValueError: If ``lora_id`` is not registered.
        """
        try:
            self._registered_adapters.pin(lora_id)
        except ValueError as err:
            raise ValueError(
                f"Pinning failed. LoRA {lora_id} is not registered."
            ) from err

    def _pin_lora_in_gpu_cache(self, lora_id: int):
        """Activate ``lora_id`` if necessary, then pin it in the active cache."""
        if lora_id not in self._active_adapters:
            # move lora to gpu if not already active
            self.activate_adapter(lora_id)

        self._active_adapters.pin(lora_id)

def create_lora_manager(
    model: nn.Module,
    max_num_seqs: int,
    max_num_batched_tokens: int,
    vocab_size: int,
    lora_config: LoRAConfig,
    device: torch.device,
    lora_manager_cls: type[LoRAModelManager] = LoRAModelManager,
    **kwargs,
) -> LoRAModelManager:
    """Create a LoRA manager for a given model.

    Args:
        model: The model whose LoRA adapters will be managed. It must be an
            instance of ``SupportsLoRA``.
        max_num_seqs: Forwarded to the manager constructor.
        max_num_batched_tokens: Forwarded to the manager constructor.
        vocab_size: Forwarded to the manager constructor.
        lora_config: LoRA configuration forwarded to the manager.
        device: Device forwarded to the manager.
        lora_manager_cls: Manager class to instantiate; defaults to
            ``LoRAModelManager``.
        **kwargs: Extra keyword arguments forwarded to ``lora_manager_cls``.

    Returns:
        The newly constructed ``lora_manager_cls`` instance.

    Raises:
        ValueError: If ``model`` is not an instance of ``SupportsLoRA``.
    """
    if not isinstance(model, SupportsLoRA):
        raise ValueError(f"Model {type(model)} is not supported for LoRA.")
    return lora_manager_cls(
        model=model,
        max_num_seqs=max_num_seqs,
        max_num_batched_tokens=max_num_batched_tokens,
        vocab_size=vocab_size,
        lora_config=lora_config,
        device=device,
        **kwargs,
    )