"vllm/model_executor/models/opt.py" did not exist on "2f49f155858faaf82bfd076a821497e41e961658"
models.py 34 KB
Newer Older
1
# SPDX-License-Identifier: Apache-2.0
2
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
3

4
5
import math
import os
6
from collections.abc import Sequence
7
from typing import Callable, Optional, TypeVar, Union
8

9
import regex as re
10
11
12
13
import safetensors.torch
import torch
from torch import nn

14
from vllm.config.lora import LoRAConfig
15
from vllm.logger import init_logger
16
from vllm.lora.layers import BaseLayerWithLoRA, LoRAMapping
17
from vllm.lora.lora_weights import LoRALayerWeights, PackedLoRALayerWeights
18
from vllm.lora.peft_helper import PEFTHelper
19
from vllm.lora.punica_wrapper import get_punica_wrapper
20
21
22
23
24
25
26
27
from vllm.lora.utils import (
    from_layer,
    from_layer_logits_processor,
    get_supported_lora_modules,
    is_regex_target_modules,
    parse_fine_tuned_lora_name,
    replace_submodule,
)
28
from vllm.model_executor.layers.fused_moe import FusedMoE
29
from vllm.model_executor.model_loader.tensorizer import TensorizerConfig
30
from vllm.model_executor.models import SupportsLoRA, supports_multimodal
31
from vllm.model_executor.models.interfaces import is_pooling_model
32
from vllm.model_executor.models.module_mapping import MultiModelKeys
33
from vllm.model_executor.models.utils import PPMissingLayer, WeightsMapper
34
from vllm.model_executor.utils import get_packed_modules_mapping
35
from vllm.utils import LRUCache, is_pin_memory_available
36

37
logger = init_logger(__name__)
38

39
40
41
42
43
44
45
46
47
48
49
50
51
52
T = TypeVar("T")


class AdapterLRUCache(LRUCache[int, T]):
    """LRUCache keyed by integer adapter id that deactivates removed adapters.

    Every removal reported to ``_on_remove`` first calls ``deactivate_fn``
    with the removed key, then defers to the base-class hook.
    """

    def __init__(self, capacity: int, deactivate_fn: Callable[[int], object]):
        super().__init__(capacity)
        # Callback invoked with the adapter id of each removed entry.
        self.deactivate_fn = deactivate_fn

    def _on_remove(self, key: int, value: Optional[T]):
        logger.debug("Removing adapter int id: %d", key)
        deactivate = self.deactivate_fn
        deactivate(key)
        return super()._on_remove(key, value)


53
54
55
56
57
58
59
60
61
# Process-wide counter backing get_lora_id(); 0 means no id handed out yet.
_GLOBAL_LORA_ID = 0


def get_lora_id():
    """Return a fresh integer LoRA id, one greater than the previous one."""
    global _GLOBAL_LORA_ID
    next_id = _GLOBAL_LORA_ID + 1
    _GLOBAL_LORA_ID = next_id
    return next_id


62
63
64
65
66
67
def is_moe_model(model: nn.Module) -> bool:
    """Return True if ``model`` contains any FusedMoE submodule.

    When it does, a one-time warning is logged stating that fused-MoE LoRA
    inference is unsupported, so the adapter must not contain expert weights.
    """
    has_fused_moe = any(isinstance(m, FusedMoE) for m in model.modules())
    if not has_fused_moe:
        return False
    logger.warning_once(
        "For MoE models, vLLM currently does not support fused MoE LoRA "
        "inference. Please ensure that the loaded LoRA model does not "
        "contain expert weights."
    )
    return True


74
class LoRAModel:
    """A LoRA fine-tuned model."""

    def __init__(
        self,
        lora_model_id: int,
        rank: int,
        loras: dict[str, LoRALayerWeights],
    ) -> None:
        """
        Args:
            lora_model_id: The integer id for the lora model (must be > 0).
            rank: lora rank.
            loras: module name -> weights for lora-replaced layers.
        """
        self.id = lora_model_id

        assert lora_model_id > 0, (
            f"a valid lora id should be greater than 0, got {self.id}"
        )
        self.rank = rank
        self.loras: dict[str, LoRALayerWeights] = loras

    def clone(self, lora_model_id: int) -> "LoRAModel":
        """Return a copy of the object with different ids.

        Will share the underlying tensors (only the dict is copied)."""
        return self.__class__(
            lora_model_id,
            rank=self.rank,
            loras=self.loras.copy(),
        )

    @property
    def extra_vocab_size(self) -> int:
        """Largest ``extra_vocab_size`` over all layer weights (0 if none)."""
        return (
            max(lora.extra_vocab_size for lora in self.loras.values())
            if self.loras
            else 0
        )

    def get_lora(self, module_name: str) -> Optional[LoRALayerWeights]:
        """Get LoRA for a given module by name (None if absent)."""
        return self.loras.get(module_name, None)

    def check_lora_name(self, lora_name: str) -> bool:
        """Return True if weights are registered under ``lora_name``."""
        return lora_name in self.loras

    # (yard1): TODO see if we can derive target_embedding_padding automatically
    @classmethod
    def from_lora_tensors(
        cls,
        lora_model_id: int,
        tensors: dict[str, torch.Tensor],
        peft_helper: PEFTHelper,
        device: str = "cuda",
        dtype: Optional[torch.dtype] = None,
        embeddings: Optional[dict[str, torch.Tensor]] = None,
        target_embedding_padding: Optional[int] = None,
        embedding_modules: Optional[dict[str, str]] = None,
        embedding_padding_modules: Optional[list[str]] = None,
        weights_mapper: Optional[WeightsMapper] = None,
    ) -> "LoRAModel":
        """Create a LoRAModel from a dictionary of tensors.

        Args:
            lora_model_id: Integer id for the new model (must be > 0).
            tensors: Checkpoint tensor name -> tensor. Each name is parsed by
                ``parse_fine_tuned_lora_name`` into a module name plus flags
                telling whether it is a LoRA A, LoRA B, or bias tensor.
            peft_helper: Adapter configuration; ``peft_helper.r`` becomes the
                rank of the returned model.
            device: Device every tensor is moved to.
            dtype: dtype every tensor is cast to.
            embeddings: Optional extra-vocab embedding tensors, selected per
                module through ``embedding_modules``.
            target_embedding_padding: If set, LoRA B tensors of modules whose
                name contains an entry of ``embedding_padding_modules`` are
                zero-padded along dim 0 up to this size.
            embedding_modules: Module-name fragment -> key into ``embeddings``.
                Required when ``embeddings`` is non-empty.
            embedding_padding_modules: Module-name fragments eligible for LoRA
                B padding. Required whenever a LoRA B tensor is present.
            weights_mapper: Optional name remapping used while parsing names.

        Returns:
            The constructed LoRAModel; ``optimize()`` has been called on every
            layer's weights.
        """
        # Pinned host memory is only requested when the target is the CPU.
        pin_memory = str(device) == "cpu" and is_pin_memory_available()

        def _pin(t: torch.Tensor) -> torch.Tensor:
            # Pin only when requested; otherwise hand the tensor back as-is.
            return t.pin_memory() if pin_memory else t

        loras: dict[str, LoRALayerWeights] = {}
        for tensor_name, tensor in tensors.items():
            module_name, is_lora_a, is_bias = parse_fine_tuned_lora_name(
                tensor_name, weights_mapper
            )
            if module_name not in loras:
                lora_embeddings_tensor = None
                if embeddings:
                    assert embedding_modules is not None
                    embeddings_module = next(
                        (k for k in embedding_modules if k in module_name), None
                    )
                    if embeddings_module:
                        lora_embeddings_tensor = _pin(
                            embeddings[embedding_modules[embeddings_module]].to(
                                device=device, dtype=dtype
                            )
                        )
                loras[module_name] = LoRALayerWeights.from_config(
                    module_name, peft_helper, lora_embeddings_tensor
                )

            if is_bias:
                # Convert (and optionally pin) the bias exactly once.
                loras[module_name].bias = _pin(tensor.to(device=device, dtype=dtype))
            elif is_lora_a:
                loras[module_name].lora_a = _pin(tensor.to(device=device, dtype=dtype))
            else:
                assert embedding_padding_modules is not None
                lora_b = tensor.to(device=device, dtype=dtype)
                if (
                    any(name in module_name for name in embedding_padding_modules)
                    and target_embedding_padding is not None
                ):
                    assert target_embedding_padding >= lora_b.shape[0]
                    addition = target_embedding_padding - lora_b.shape[0]
                    # Pad rows (dim 0) at the end; columns are left untouched.
                    lora_b = torch.nn.functional.pad(lora_b, (0, 0, 0, addition))
                # Pin after padding so the final tensor is the pinned one.
                loras[module_name].lora_b = _pin(lora_b)

        for lora in loras.values():
            lora.optimize()

        return cls(lora_model_id, peft_helper.r, loras)

    @classmethod
    def from_local_checkpoint(
        cls,
        lora_dir: str,
        expected_lora_modules: list[str],
        peft_helper: PEFTHelper,
        *,
        lora_model_id: Optional[int] = None,
        device: str = "cuda",
        dtype: Optional[torch.dtype] = None,
        target_embedding_padding: Optional[int] = None,
        embedding_modules: Optional[dict[str, str]] = None,
        embedding_padding_modules: Optional[list[str]] = None,
        weights_mapper: Optional[WeightsMapper] = None,
        tensorizer_config_dict: Optional[dict] = None,
    ) -> "LoRAModel":
        """Create a LoRAModel from a local checkpoint.

        The weights are read, in order of preference, from a tensorizer
        artifact (when ``tensorizer_config_dict`` is given), from
        ``adapter_model.safetensors``, or from ``adapter_model.bin`` /
        ``adapter_model.pt``. Optional extra-vocab embeddings are read from
        ``new_embeddings.safetensors`` or ``new_embeddings.bin``.

        Args:
            lora_dir: The local path that has lora data.
            expected_lora_modules: Name of modules that are expected to be
                replaced by lora.
            peft_helper: Loaded lora configuration information.
            lora_model_id: LoRA model id. If not given, automatically set by
                a global counter.
            device: Device where the lora model is loaded.
            dtype: dtype of the lora model weights.

        Returns:
            Loaded LoRA Model.

        Raises:
            ValueError: if the checkpoint targets modules outside
                ``expected_lora_modules``, or ``lora_dir`` has no tensors.
        """
        lora_tensor_path = os.path.join(lora_dir, "adapter_model.safetensors")
        lora_bin_file_path = os.path.join(lora_dir, "adapter_model.bin")
        lora_pt_file_path = os.path.join(lora_dir, "adapter_model.pt")
        new_embeddings_tensor_path = os.path.join(
            lora_dir, "new_embeddings.safetensors"
        )
        new_embeddings_bin_file_path = os.path.join(lora_dir, "new_embeddings.bin")
        tensors: dict[str, torch.Tensor] = {}

        def _raise_unexpected(
            unexpected_modules: list[Union[list[str], str]],
        ) -> None:
            raise ValueError(
                f"While loading {lora_dir}, expected"
                f" target modules in {expected_lora_modules}"
                f" but received {unexpected_modules}."
                f" Please verify that the loaded LoRA module is correct"
            )

        def check_unexpected_modules(modules: dict):
            # Collect into a fresh list on every call instead of mutating a
            # shared closure variable that callers had to reset beforehand.
            unexpected_modules: list[Union[list[str], str]] = []
            for lora_module in modules.keys():  # noqa
                module_name, _, _ = parse_fine_tuned_lora_name(
                    lora_module, weights_mapper
                )
                part_name = module_name.split(".")[-1]
                if part_name not in expected_lora_modules:
                    unexpected_modules.append(module_name)
            if unexpected_modules:
                _raise_unexpected(unexpected_modules)

        if tensorizer_config_dict:
            from tensorizer import TensorDeserializer

            tensorizer_config = TensorizerConfig(**tensorizer_config_dict)
            lora_tensor_path = os.path.join(
                tensorizer_config.tensorizer_dir, "adapter_model.tensors"
            )
            tensorizer_args = tensorizer_config._construct_tensorizer_args()
            tensors = TensorDeserializer(
                lora_tensor_path,
                dtype=tensorizer_config.dtype,
                **tensorizer_args.deserialization_kwargs,
            )
            check_unexpected_modules(tensors)

        elif os.path.isfile(lora_tensor_path):
            # Find unexpected modules.
            # Use safetensor key as a source of truth to find expected modules.
            # in peft if you have target_modules A, B, C and C does not exist
            # in the model it won't error and model will be trained with A, B
            # loraified. C won't exist in the safetensor but it will exist in
            # the target_modules of the adapter_config.json.
            with safetensors.safe_open(lora_tensor_path, framework="pt") as f:  # type: ignore
                # Load tensors if there are only expected modules.
                check_unexpected_modules(f)
                for module in f.keys():  # noqa
                    tensors[module] = f.get_tensor(module)
        elif os.path.isfile(lora_bin_file_path) or os.path.isfile(lora_pt_file_path):
            # When a bin/pt file is provided, we rely on config to find
            # unexpected modules.
            unexpected_modules: list[Union[list[str], str]] = []
            target_modules = peft_helper.target_modules
            if not isinstance(target_modules, list):
                target_modules = [target_modules]
            for module in target_modules:
                # Compatible with more modules,
                # such as:layers.11.self_attn.k_proj
                part_name = module.split(".")[-1]
                if part_name not in expected_lora_modules:
                    unexpected_modules.append(module)
            # loaded lora's target modules must be a subset of
            # expected_lora_modules. It is not reliable. See
            # https://github.com/vllm-project/vllm/pull/5909. But there's no
            # other better mechanism.
            if unexpected_modules and not is_regex_target_modules(
                peft_helper.target_modules, expected_lora_modules
            ):
                _raise_unexpected(unexpected_modules)
            lora_file_path = (
                lora_bin_file_path
                if os.path.isfile(lora_bin_file_path)
                else lora_pt_file_path
            )
            tensors = torch.load(lora_file_path, map_location=device, weights_only=True)
        else:
            raise ValueError(f"{lora_dir} doesn't contain tensors")

        embeddings = None
        if os.path.isfile(new_embeddings_tensor_path):
            embeddings = safetensors.torch.load_file(new_embeddings_tensor_path)
        elif os.path.isfile(new_embeddings_bin_file_path):
            embeddings = torch.load(
                new_embeddings_bin_file_path, map_location=device, weights_only=True
            )

        return cls.from_lora_tensors(
            lora_model_id=get_lora_id() if lora_model_id is None else lora_model_id,
            tensors=tensors,
            peft_helper=peft_helper,
            device=device,
            dtype=dtype,
            embeddings=embeddings,
            target_embedding_padding=target_embedding_padding,
            embedding_modules=embedding_modules,
            embedding_padding_modules=embedding_padding_modules,
            weights_mapper=weights_mapper,
        )
333
334


335
class LoRAModelManager:
    """A manager that manages multiple LoRA-fine-tuned models."""

    def __init__(
        self,
        model: SupportsLoRA,
        max_num_seqs: int,
        max_num_batched_tokens: int,
        vocab_size: int,
        lora_config: LoRAConfig,
        device: torch.device,
    ):
        """Create a LoRAModelManager and adapter for a given model.

        Args:
            model: the model to be adapted.
            max_num_seqs: the maximum number of sequences model can run in a
                single batch.
            max_num_batched_tokens: the maximum number of tokens model can run
                in a single batch.
            vocab_size: the vocab size of the model.
            lora_config: the LoRA configuration.
            device: device handed to the punica wrapper.
        """
        self.model: SupportsLoRA = model
        self._registered_adapters: dict[int, LoRAModel] = {}
        # Dict instead of a set for compatibility with LRUCache.
        self._active_adapters: dict[int, None] = {}
        self.adapter_type = "LoRA"
        self.lora_config = lora_config
        self.device = device
        self.max_num_seqs = max_num_seqs
        assert self.capacity >= self.lora_slots
        # Rounded up to a multiple of 8. NOTE(review): the punica wrapper
        # below is built from the unrounded value — confirm this is intended.
        self.max_num_batched_tokens = math.ceil(max_num_batched_tokens / 8) * 8
        self.lora_index_to_id: list[Optional[int]] = [None] * self.lora_slots
        self.vocab_size = vocab_size
        self.punica_wrapper = get_punica_wrapper(
            max_num_batched_tokens,
            max_batches=self.max_num_seqs,
            device=self.device,
            max_loras=self.lora_config.max_loras,
        )

        self.supported_lora_modules = get_supported_lora_modules(self.model)
        # The message is one parenthesized expression so the model class name
        # is actually part of the assertion message.
        assert self.supported_lora_modules, (
            "No supported LoRA modules found in"
            f" {self.model.__class__.__name__}."
        )

        self.packed_modules_mapping = get_packed_modules_mapping(self.model)
        # Used to indicate whether the model is a multimodal model
        self.supports_mm: bool = (
            supports_multimodal(self.model)
            # In case the model only supports LoRA for
            # text modules (e.g. ChatGLM)
            and hasattr(self.model, "get_mm_mapping")
        )
        self.is_pooling_model = is_pooling_model(self.model)
        self.is_moe_model = is_moe_model(self.model)
        # Packed module full name -> full names of its sub-modules.
        self.packed_modules: dict[str, list[str]] = {}
        # Module full name -> LoRA-wrapped layer.
        self.modules: dict[str, BaseLayerWithLoRA] = {}
        self._last_mapping: Optional[LoRAMapping] = None
        self._create_lora_modules()
        self.model.lora_manager = self

    def __len__(self) -> int:
        return len(self._registered_adapters)

    @property
    def capacity(self) -> int:
        """Maximum number of registered (CPU-side) adapters."""
        return self.lora_config.max_cpu_loras

    @property
    def lora_slots(self) -> int:
        """Maximum number of simultaneously active adapters."""
        return self.lora_config.max_loras

    @property
    def adapter_slots(self) -> int:
        return self.lora_slots

    def activate_adapter(
        self,
        lora_id: int,
    ) -> bool:
        """Move LoRA into a GPU buffer to be used in the forward pass.

        Returns:
            False if the adapter is already active, True once it has been
            loaded into a free slot.

        Raises:
            ValueError: if no slot is free, or if the adapter carries bias
                weights while ``lora_config.bias_enabled`` is False. Both
                checks happen before the slot is claimed or any layer is
                updated, so a rejected adapter is never left marked active.
        """
        if lora_id in self._active_adapters:
            return False
        index = next(
            (i for i, slot_id in enumerate(self.lora_index_to_id) if slot_id is None),
            None,
        )
        if index is None:
            raise ValueError("No free lora slots")
        lora_model = self._registered_adapters[lora_id]

        # Resolve and validate every module's weights first.
        module_loras: list[tuple[BaseLayerWithLoRA, Optional[LoRALayerWeights]]] = []
        for module_name, module in self.modules.items():
            module_lora = self._get_lora_layer_weights(lora_model, module_name)
            if module_lora:
                module_lora.optimize()
                # Bias is not explicitly enabled with the flag enable_lora_bias.
                bias = module_lora.bias
                if (
                    torch.is_tensor(bias)
                    or (isinstance(bias, Sequence) and any(b is not None for b in bias))
                ) and not self.lora_config.bias_enabled:
                    # NOTE(review): the offending bias is cleared before
                    # raising; kept as-is, but confirm it is still wanted.
                    module_lora.bias = None
                    raise ValueError(
                        f"Adapter bias cannot be used for {module_name}"
                        " without --enable-lora-bias."
                    )
            module_loras.append((module, module_lora))

        self._active_adapters[lora_id] = None
        logger.debug(
            "Activating LoRA. int id: %d, slot index: %d", lora_model.id, index
        )
        self.lora_index_to_id[index] = lora_model.id
        for module, module_lora in module_loras:
            if module_lora:
                module.set_lora(
                    index,
                    module_lora.lora_a,
                    module_lora.lora_b,
                    module_lora.embeddings_tensor,
                    module_lora.bias,
                )
            else:
                module.reset_lora(index)
        return True

    def _deactivate_adapter(self, lora_id: int):
        """Free the slot held by ``lora_id``; no-op if it holds none."""
        try:
            index = self.lora_index_to_id.index(lora_id)
        except ValueError:
            return
        self.lora_index_to_id[index] = None

    def _add_adapter(self, lora: LoRAModel):
        self._create_merged_loras_inplace(lora)
        self._registered_adapters[lora.id] = lora

    def pin_adapter(self, lora_id: int) -> bool:
        """Pin a LoRAModel in the manager cache."""
        raise NotImplementedError(
            "Pinning is not supported in LoRAModelManager. "
            "Use LRUCacheLoRAModelManager for pinning"
        )  # type: ignore

    def _set_adapter_mapping(self, mapping: LoRAMapping) -> None:
        # update lora states
        self.punica_wrapper.update_metadata(
            mapping,
            self.lora_index_to_id,
            self.lora_slots + 1,
            self.vocab_size,
            self.lora_config.lora_extra_vocab_size,
        )

    def remove_all_adapters(self):
        """Remove all LoRAModels from the manager."""
        self._registered_adapters.clear()
        self.lora_index_to_id = [None] * self.lora_slots
        self._active_adapters.clear()

    def _create_lora_modules(self):
        """Wrap every supported target layer of the model in its LoRA layer.

        ``PPMissingLayer`` modules, modules not matching the supported target
        modules, and (for multimodal models) tower/connector modules are
        skipped. For an ``lm_head`` match the sibling ``logits_processor`` is
        replaced too, and that replacement is what gets registered.
        """

        def _parent_module(module_name: str) -> str:
            # module name is a dot separated name.
            # for example:
            #  - given an input 'x.y.z' return 'x.y'
            #  - given an input 'x' return ''
            return module_name.rpartition(".")[0]

        for module_name, module in self.model.named_modules(remove_duplicate=False):
            if isinstance(module, PPMissingLayer):
                continue
            if not self._match_target_modules(module_name):
                continue
            # A temporary approach for multimodal models to support LoRA
            # TODO: Remove this restriction
            if self._filter_unsupported_mm_module(module_name):
                logger.warning(
                    "Regarding multimodal models, vLLM currently only supports "
                    "adding LoRA to language model, %s will be ignored.",
                    module_name,
                )
                continue
            parts = module_name.split(".")[-1]
            packed_moduled_lst = self.packed_modules_mapping.get(parts, [])
            new_module = replace_submodule(
                self.model,
                module_name,
                from_layer(
                    module,
                    self.lora_slots,
                    self.lora_config,
                    packed_moduled_lst,
                    self.model.config,
                ),
            )

            # (yard1): TODO make this more robust
            if "lm_head" in module_name:
                logits_processor_module_name = "logits_processor"
                parent_module = _parent_module(module_name)
                if parent_module:
                    logits_processor_module_name = (
                        f"{parent_module}.{logits_processor_module_name}"
                    )

                logits_processor_module = self.model.get_submodule(
                    logits_processor_module_name
                )

                new_module = replace_submodule(
                    self.model,
                    logits_processor_module_name,
                    from_layer_logits_processor(
                        logits_processor_module,
                        module,
                        self.lora_slots,
                        self.lora_config,
                        self.model.config,
                    ),
                )

            # In some models, especially multimodal ones, layers with the same
            # name may have different types, such as nn.Linear and
            # ReplicatedLinear. The nn.Linear layers cannot be replaced with
            # LoRA layers, leading to assertion error. The following check
            # aims to prevent this error
            if self.supports_mm and not isinstance(new_module, BaseLayerWithLoRA):
                continue
            self.register_module(module_name, new_module)
            self._register_packed_modules(module_name)
            # All lora layers share the same punica_wrapper based on reference.
            new_module.set_mapping(self.punica_wrapper)

    def register_module(self, module_name: str, module: "BaseLayerWithLoRA"):
        assert isinstance(module, BaseLayerWithLoRA)
        self.modules[module_name] = module

    def create_dummy_lora(
        self,
        lora_id: int,
        rank: int,
        embedding_modules: Optional[dict[str, str]] = None,
    ) -> LoRAModel:
        """Create zero-initialized LoRAModel for warmup.

        One dummy weight set (on CPU) is created for each LoRA-wrapped module
        that matches the supported target modules and is not filtered out as
        a multimodal tower/connector module. Packed modules receive one dummy
        sub-weight per entry of their packed-modules mapping.

        Args:
            lora_id: id of the dummy model.
            rank: LoRA rank of every dummy weight.
            embedding_modules: Required for non-packed modules; leaf names
                present in it are sized as embedding layers.
        """
        model = LoRAModel(lora_id, rank, {})
        bias_enabled = self.lora_config.bias_enabled
        for module_name, module in self.model.named_modules():
            if (
                not self._match_target_modules(module_name)
                or not isinstance(module, BaseLayerWithLoRA)
                or self._filter_unsupported_mm_module(module_name)
            ):
                continue
            parts = module_name.split(".")
            if module_name not in self.packed_modules:
                assert embedding_modules is not None
                if parts[-1] in embedding_modules:
                    input_dim = (
                        module.base_layer.org_vocab_size
                        + self.lora_config.lora_extra_vocab_size
                        if hasattr(module.base_layer, "org_vocab_size")
                        else module.base_layer.weight.shape[1]
                    )
                    output_dim = (
                        module.base_layer.embedding_dim
                        if hasattr(module.base_layer, "embedding_dim")
                        else module.base_layer.weight.shape[0]
                    )
                    embeddings_tensor_dim = (
                        module.base_layer.embedding_dim
                        if hasattr(module.base_layer, "embedding_dim")
                        else module.base_layer.weight.shape[1]
                    )
                    lora = LoRALayerWeights.create_dummy_lora_weights(
                        module_name,
                        input_dim,
                        output_dim,
                        rank,
                        module.lora_a_stacked[0].dtype,
                        "cpu",
                        embeddings_tensor_dim=embeddings_tensor_dim,
                        bias_enabled=bias_enabled,
                    )
                else:
                    lora = LoRALayerWeights.create_dummy_lora_weights(
                        module_name,
                        module.lora_a_stacked[0].shape[-1],
                        module.lora_b_stacked[0].shape[-2],
                        rank,
                        module.lora_a_stacked[0].dtype,
                        "cpu",
                        bias_enabled=bias_enabled,
                    )
            else:
                replacements = self.packed_modules_mapping[parts[-1]]
                subloras: list[Optional[LoRALayerWeights]] = []
                for i, r in enumerate(replacements):
                    lora = LoRALayerWeights.create_dummy_lora_weights(
                        module_name + "." + r,
                        module.lora_a_stacked[i].shape[-1],
                        module.lora_b_stacked[i].shape[-2],
                        rank,
                        module.lora_a_stacked[i].dtype,
                        "cpu",
                        bias_enabled=bias_enabled,
                    )
                    subloras.append(lora)
                lora = PackedLoRALayerWeights.pack(subloras)
            model.loras[module_name] = lora
        return model

    def _match_target_modules(self, module_name: str):
        """True if ``module_name`` equals, or ends with ``.<name>`` for, any
        supported target module (target names are used as regex fragments)."""
        return any(
            re.match(
                r".*\.{target_module}$".format(target_module=target_module), module_name
            )
            or target_module == module_name
            for target_module in self.supported_lora_modules
        )

    def _filter_unsupported_mm_module(self, module_name: str) -> bool:
        """
        Regarding multimodal models, vLLM currently only supports adding LoRA to
        language model. LoRA for other modules, such as the vision tower, will
        be filtered out.
        """
        if not self.supports_mm:
            return False
        module_mapping: MultiModelKeys = self.model.get_mm_mapping()
        prefix_lst = module_mapping.connector + module_mapping.tower_model
        return any(module_name.startswith(prefix) for prefix in prefix_lst)

    def _register_packed_modules(self, module_full_name: str) -> None:
        """Record the sub-module full names of a packed module, if any."""
        parts = module_full_name.split(".")
        module_name = parts[-1]
        replacements = self.packed_modules_mapping.get(module_name, [])
        # When replacements is less than or equal to 1, it indicates that this
        # module is not a packed module.
        if len(replacements) <= 1:
            return
        prefix = ".".join(parts[:-1])
        self.packed_modules[module_full_name] = [
            prefix + "." + r if prefix else r for r in replacements
        ]

    def _create_merged_loras_inplace(self, lora_model: LoRAModel) -> None:
        """Pack per-sub-module LoRA weights into packed-module weights in place.

        For each packed module, the sub-module weights present in
        ``lora_model`` are combined with ``PackedLoRALayerWeights.pack``
        (missing sub-modules become ``None`` placeholders), stored under the
        packed module's name, and the consumed sub-module entries removed.
        Packed modules with no sub-module weights are left untouched.
        """
        for module_name, new_module_names in self.packed_modules.items():
            replacement_loras: list[Optional[LoRALayerWeights]] = []
            replaced_module: set[str] = set()
            for r in new_module_names:
                lora = self._get_lora_layer_weights(lora_model, r)
                # Keep positional alignment with new_module_names.
                replacement_loras.append(lora if lora else None)
                if lora:
                    replaced_module.add(r)
            if not replaced_module:
                continue
            # HACK Temporary solution for the pool model: store the packed
            # weights under the 'model.'-stripped name when the adapter
            # already uses that naming.
            if self.is_pooling_model and not lora_model.check_lora_name(module_name):
                replaced_module_name = module_name.replace("model.", "")
                if lora_model.check_lora_name(replaced_module_name):
                    module_name = replaced_module_name
            lora_model.loras[module_name] = PackedLoRALayerWeights.pack(
                replacement_loras
            )
            # Remove the modules that have been replaced.
            for module in replaced_module:
                lora_model.loras.pop(module, None)

    def _get_lora_layer_weights(
        self, lora_model: LoRAModel, module_name: str
    ) -> Optional[LoRALayerWeights]:
        """Look up ``module_name`` in ``lora_model``.

        For pooling models, a miss is retried with every 'model.' removed
        from the name.
        """
        if self.is_pooling_model and not lora_model.check_lora_name(module_name):
            stripped_name = module_name.replace("model.", "")
            if lora_model.check_lora_name(stripped_name):
                logger.info_once(
                    "For the pool model, successfully loaded the LoRA weights "
                    "after removing the prefix 'model.'."
                )
                return lora_model.get_lora(stripped_name)
        return lora_model.get_lora(module_name)

    def deactivate_adapter(self, adapter_id: int) -> bool:
        if adapter_id not in self._active_adapters:
            return False
        self._deactivate_adapter(adapter_id)
        self._active_adapters.pop(adapter_id, None)
        return True

    def add_adapter(self, adapter: LoRAModel) -> bool:
        logger.debug("Adding lora. Model id: %d, int id: %d", adapter.id, adapter.id)
        if adapter.id in self._registered_adapters:
            return False
        if len(self._registered_adapters) >= self.capacity:
            raise RuntimeError("No free adapter slots.")
        self._add_adapter(adapter)
        return True

    def set_adapter_mapping(self, mapping: LoRAMapping) -> None:
        # Only push metadata to the punica wrapper when the mapping changed.
        if self._last_mapping != mapping:
            self._set_adapter_mapping(mapping)
        self._last_mapping = mapping

    def remove_adapter(self, adapter_id: int) -> bool:
        self.deactivate_adapter(adapter_id)
        if adapter_id not in self._registered_adapters:
            return False
        self._registered_adapters.pop(adapter_id, None)
        return True

    def list_adapters(self) -> dict[int, LoRAModel]:
        return dict(self._registered_adapters)

    def get_adapter(self, adapter_id: int) -> Optional[LoRAModel]:
        return self._registered_adapters.get(adapter_id)
761
762
763


class LoRALRUCache(AdapterLRUCache[LoRAModel]):
    """LRU cache of ``LoRAModel`` objects keyed by integer adapter id.

    When an entry is removed, ``AdapterLRUCache._on_remove`` calls
    ``deactivate_lora_fn(key)``. The callback's return value is discarded,
    so any callable taking an int is accepted. The annotation therefore
    matches the base class (``Callable[[int], object]``) rather than
    requiring a ``bool`` return.
    """

    def __init__(self, capacity: int, deactivate_lora_fn: Callable[[int], object]):
        super().__init__(capacity, deactivate_lora_fn)


class LRUCacheLoRAModelManager(LoRAModelManager):
    """LoRA model manager whose adapter registries are LRU caches.

    Two ``LoRALRUCache`` instances are used:

    * ``_registered_adapters``, sized ``capacity``: its removal hook calls
      ``deactivate_adapter`` for the removed id.
    * ``_active_adapters``, sized ``lora_slots``: its removal hook calls
      ``_deactivate_adapter`` for the removed id.

    Adapters can be pinned in both caches through ``pin_adapter``.
    NOTE(review): the exact effect of ``pin`` is defined by the underlying
    LRU cache (presumably exempting the entry from eviction) -- confirm there.
    """

    def __init__(
        self,
        model: nn.Module,
        max_num_seqs: int,
        max_num_batched_tokens: int,
        vocab_size: int,
        lora_config: LoRAConfig,
        device: torch.device,
    ):
        super().__init__(
            model,
            max_num_seqs,
            max_num_batched_tokens,
            vocab_size,
            lora_config,
            device,
        )
        # CPU-side registry: bounded by ``capacity``.
        self._registered_adapters: LoRALRUCache = LoRALRUCache(
            capacity=self.capacity,
            deactivate_lora_fn=self.deactivate_adapter,
        )
        # GPU-side active set: bounded by the number of LoRA slots.
        self._active_adapters: LoRALRUCache = LoRALRUCache(
            capacity=self.lora_slots,
            deactivate_lora_fn=self._deactivate_adapter,
        )

    def list_adapters(self) -> dict[int, LoRAModel]:
        """Return a plain-dict copy of the registered LoRAModels."""
        registered = self._registered_adapters.cache
        return dict(registered)

    def add_adapter(self, lora: LoRAModel) -> bool:
        """Register ``lora``, or refresh its recency if already registered.

        No explicit capacity check happens here; ``_registered_adapters`` is
        sized to ``capacity``. NOTE(review): presumably the LRU cache makes
        room on overflow -- confirm in the cache implementation.

        Returns:
            True if the adapter was newly added, False if it was already
            registered (in which case it is only touched).
        """
        logger.debug("Adding lora. Model id: %d, int id: %d", lora.id, lora.id)
        if lora.id in self._registered_adapters:
            # Known adapter: only bump its position in the LRU order.
            self._registered_adapters.touch(lora.id)
            return False
        self._add_adapter(lora)
        return True

    def activate_adapter(self, lora_id: int) -> bool:
        """Activate ``lora_id``, evicting the oldest active adapter if needed.

        If ``lora_id`` is not active yet and all ``lora_slots`` are taken, the
        oldest entry of ``_active_adapters`` is removed first. The base class
        performs the activation, and the id is then touched so it becomes the
        most recent entry. Returns the base-class result.
        """
        must_evict = (
            lora_id not in self._active_adapters
            and len(self._active_adapters) >= self.lora_slots
        )
        if must_evict:
            self._active_adapters.remove_oldest()
        activated = super().activate_adapter(lora_id)
        # Touch unconditionally so the LRU order reflects this access.
        self._active_adapters.touch(lora_id)
        return activated

    def remove_oldest_adapter(self) -> bool:
        """Remove the oldest registered adapter; return False if there is none."""
        if len(self._registered_adapters) == 0:
            return False
        self._registered_adapters.remove_oldest()
        return True

    def pin_adapter(self, lora_id: int) -> bool:
        """Pin ``lora_id`` in the registered (CPU) and active (GPU) caches.

        Raises:
            ValueError: if pinning in the registered cache fails.

        Returns:
            Always True when no exception is raised.
        """
        self._pin_lora_in_cpu_cache(lora_id)
        self._pin_lora_in_gpu_cache(lora_id)
        return True

    def _pin_lora_in_cpu_cache(self, lora_id: int):
        """Pin ``lora_id`` in ``_registered_adapters``.

        A ValueError raised by the cache's ``pin`` is re-raised as a
        ValueError reporting the LoRA as not registered, chained to the
        original error.
        """
        try:
            self._registered_adapters.pin(lora_id)
        except ValueError as exc:
            message = f"Pinning failed. LoRA {lora_id} is not registered."
            raise ValueError(message) from exc

    def _pin_lora_in_gpu_cache(self, lora_id: int):
        """Make sure ``lora_id`` is active, then pin it in ``_active_adapters``."""
        is_active = lora_id in self._active_adapters
        if not is_active:
            # Load the adapter onto the GPU before pinning it there.
            self.activate_adapter(lora_id)
        self._active_adapters.pin(lora_id)

def create_lora_manager(
    model: nn.Module,
    max_num_seqs: int,
    max_num_batched_tokens: int,
    vocab_size: int,
    lora_config: LoRAConfig,
    device: torch.device,
    lora_manager_cls: type[LoRAModelManager] = LoRAModelManager,
    **kwargs,
) -> LoRAModelManager:
    """Create a LoRA model manager for a given model.

    Every argument is forwarded by keyword to ``lora_manager_cls``.

    Args:
        model: Model to manage; must be an instance of ``SupportsLoRA``.
        max_num_seqs: Forwarded to the manager constructor.
        max_num_batched_tokens: Forwarded to the manager constructor.
        vocab_size: Forwarded to the manager constructor.
        lora_config: LoRA configuration forwarded to the manager.
        device: Device forwarded to the manager.
        lora_manager_cls: Manager class to instantiate; ``LoRAModelManager``
            by default, or a subclass such as ``LRUCacheLoRAModelManager``.
        **kwargs: Extra keyword arguments forwarded to ``lora_manager_cls``.

    Returns:
        The newly constructed ``lora_manager_cls`` instance.

    Raises:
        ValueError: If ``model`` is not an instance of ``SupportsLoRA``.
    """
    if not isinstance(model, SupportsLoRA):
        raise ValueError(f"Model {type(model)} is not supported for LoRA.")
    return lora_manager_cls(
        model=model,
        max_num_seqs=max_num_seqs,
        max_num_batched_tokens=max_num_batched_tokens,
        vocab_size=vocab_size,
        lora_config=lora_config,
        device=device,
        **kwargs,
    )