models.py 35.7 KB
Newer Older
1
# SPDX-License-Identifier: Apache-2.0
2
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
3

4
5
import math
import os
6
from collections.abc import Sequence
7
from dataclasses import dataclass, field
8
from typing import Any, Callable, Optional, Union
9

10
import regex as re
11
12
13
14
import safetensors.torch
import torch
from torch import nn

15
16
17
18
19
from vllm.adapter_commons.models import (AdapterLRUCache, AdapterModel,
                                         AdapterModelManager)
from vllm.adapter_commons.utils import (add_adapter, deactivate_adapter,
                                        get_adapter, list_adapters,
                                        remove_adapter, set_adapter_mapping)
20
from vllm.config import LoRAConfig
21
from vllm.logger import init_logger
22
from vllm.lora.layers import (BaseLayerWithLoRA,
23
                              LinearScalingRotaryEmbeddingWithLoRA,
24
                              LoRAMapping)
25
from vllm.lora.lora import LoRALayerWeights, PackedLoRALayerWeights
26
from vllm.lora.peft_helper import PEFTHelper
27
from vllm.lora.punica_wrapper import get_punica_wrapper
28
from vllm.lora.utils import (from_layer, from_layer_logits_processor,
29
                             get_supported_lora_modules,
30
                             is_regex_target_modules,
31
                             parse_fine_tuned_lora_name, replace_submodule)
32
from vllm.model_executor.layers.fused_moe import FusedMoE
33
from vllm.model_executor.model_loader.tensorizer import TensorizerConfig
34
from vllm.model_executor.models import SupportsLoRA, supports_multimodal
35
from vllm.model_executor.models.interfaces import is_pooling_model
36
from vllm.model_executor.models.module_mapping import MultiModelKeys
37
from vllm.model_executor.models.utils import PPMissingLayer, WeightsMapper
38
from vllm.model_executor.utils import get_packed_modules_mapping
39
from vllm.utils import is_pin_memory_available
40

41
# Module-level logger used by the LoRA model/manager code in this file.
logger = init_logger(__name__)
42
43
44
45

# Process-wide counter backing get_lora_id(). It is incremented before each
# id is returned, so the first generated id is 1 (LoRAModel requires id > 0).
_GLOBAL_LORA_ID = 0


46
47
48
49
@dataclass
class LongContextLoRAContext:
    """Context for lora adapters that support long context.

    Attributes are plain data; ``offsets_by_lora_id`` is mutated at runtime
    as adapters are registered.
    """
    # The scaling factors to support long context lora fine tuned models.
    scaling_factors: list[float]
    # Dimension to apply rotary embedding.
    rot_dim: int
    # Offsets to the sin_cos_cache for each lora_id loaded.
    # This value is dynamically modified; default_factory gives every
    # instance its own dict.
    offsets_by_lora_id: dict[int, int] = field(default_factory=dict)
56
57


58
59
60
61
62
63
def get_lora_id():
    """Hand out the next process-wide LoRA integer id.

    The module-level ``_GLOBAL_LORA_ID`` counter is bumped first and the new
    value is returned, so ids start at 1 and increase by one per call.
    """
    global _GLOBAL_LORA_ID
    new_id = _GLOBAL_LORA_ID + 1
    _GLOBAL_LORA_ID = new_id
    return new_id


64
65
66
67
68
69
70
71
72
73
74
def is_moe_model(model: nn.Module) -> bool:
    """Return True if ``model`` contains at least one FusedMoE layer.

    On a positive result a one-time warning is logged, since fused MoE LoRA
    inference is not supported and the adapter must not carry expert weights.
    """
    # Stops scanning at the first FusedMoE submodule found.
    fused_moe_found = next(
        (True for submodule in model.modules()
         if isinstance(submodule, FusedMoE)), False)
    if not fused_moe_found:
        return False
    logger.warning_once(
        "For MoE models, vLLM currently does not support fused MoE LoRA "
        "inference. Please ensure that the loaded LoRA model does not "
        "contain expert weights.")
    return True


75
class LoRAModel(AdapterModel):
    """A LoRA fine-tuned model."""

    def __init__(
        self,
        lora_model_id: int,
        rank: int,
        loras: dict[str, LoRALayerWeights],
        scaling_factor: Optional[float] = None,
    ) -> None:
        """
        Args:
            lora_model_id: The integer id for the lora model (must be > 0).
            rank: lora rank.
            loras: module name -> weights for lora-replaced layers.
            scaling_factor: Scaling factor to support long context lora model.
                None if the lora is not tuned for long context support.
        """
        self.id = lora_model_id
        # Scaling factor for long context lora model. None if it is not
        # fine tuned for the long context.
        self.scaling_factor = scaling_factor
        assert (
            lora_model_id
            > 0), f"a valid lora id should be greater than 0, got {self.id}"
        self.rank = rank
        self.loras: dict[str, LoRALayerWeights] = loras

    def clone(self, lora_model_id: int) -> "LoRAModel":
        """Return a copy of the object with a different id.

        Only the ``loras`` dict is shallow-copied, so the copy shares the
        underlying tensors. The long-context scaling factor is carried over so
        a clone of a long-context LoRA stays long-context.
        """
        return self.__class__(
            lora_model_id,
            rank=self.rank,
            loras=self.loras.copy(),
            scaling_factor=self.scaling_factor,
        )

    @property
    def extra_vocab_size(self) -> int:
        """Largest ``extra_vocab_size`` over all layer weights (0 if none)."""
        return max(lora.extra_vocab_size
                   for lora in self.loras.values()) if self.loras else 0

    def get_lora(self, module_name: str) -> Optional[LoRALayerWeights]:
        """Get LoRA for a given module by name"""
        return self.loras.get(module_name, None)

    def check_lora_name(self, lora_name: str) -> bool:
        """Return True if weights are registered under ``lora_name``."""
        return lora_name in self.loras

    # (yard1): TODO see if we can derive target_embedding_padding automatically
    @classmethod
    def from_lora_tensors(
        cls,
        lora_model_id: int,
        tensors: dict[str, torch.Tensor],
        peft_helper: PEFTHelper,
        device: str = "cuda",
        dtype: Optional[torch.dtype] = None,
        embeddings: Optional[dict[str, torch.Tensor]] = None,
        target_embedding_padding: Optional[int] = None,
        embedding_modules: Optional[dict[str, str]] = None,
        embedding_padding_modules: Optional[list[str]] = None,
        weights_mapper: Optional[WeightsMapper] = None,
    ) -> "LoRAModel":
        """Create a LoRAModel from a dictionary of tensors.

        Every tensor is moved to ``device``/``dtype`` and transposed with
        ``.t()`` before being stored. When loading to CPU and pinned memory is
        available, the stored tensors are pinned. lora_b tensors of modules
        listed in ``embedding_padding_modules`` are right-padded on dim 1 up
        to ``target_embedding_padding``.
        """
        pin_memory = str(device) == "cpu" and is_pin_memory_available()
        loras: dict[str, LoRALayerWeights] = {}
        for tensor_name, tensor in tensors.items():
            module_name, is_lora_a, is_bias = parse_fine_tuned_lora_name(
                tensor_name, weights_mapper)
            if module_name not in loras:
                lora_embeddings_tensor = None
                if embeddings:
                    assert embedding_modules is not None
                    # First embedding module whose name is a substring of
                    # this module's name.
                    embeddings_module = next(
                        (k for k in embedding_modules if k in module_name),
                        None)
                    if embeddings_module:
                        lora_embeddings_tensor = embeddings[
                            embedding_modules[embeddings_module]].to(
                                device=device, dtype=dtype)
                        if pin_memory:
                            lora_embeddings_tensor = (
                                lora_embeddings_tensor.pin_memory())
                loras[module_name] = LoRALayerWeights.from_config(
                    module_name, peft_helper, lora_embeddings_tensor)

            if is_bias:
                # Convert once and store the (optionally pinned) result.
                bias = tensor.to(device=device, dtype=dtype).t()
                if pin_memory:
                    bias = bias.pin_memory()
                loras[module_name].bias = bias
            elif is_lora_a:
                lora_a = tensor.to(device=device, dtype=dtype).t()
                if pin_memory:
                    lora_a = lora_a.pin_memory()
                loras[module_name].lora_a = lora_a
            else:
                lora_b = tensor.to(device=device, dtype=dtype).t()
                assert embedding_padding_modules is not None
                if any(name in module_name
                       for name in embedding_padding_modules
                       ) and target_embedding_padding is not None:
                    assert target_embedding_padding >= lora_b.shape[1]
                    addition = target_embedding_padding - lora_b.shape[1]
                    lora_b = torch.nn.functional.pad(lora_b, (0, addition))
                if pin_memory:
                    lora_b = lora_b.pin_memory()
                loras[module_name].lora_b = lora_b

        for lora in loras.values():
            lora.optimize()

        return cls(lora_model_id,
                   peft_helper.r,
                   loras,
                   scaling_factor=peft_helper.vllm_long_context_scaling_factor)

    @classmethod
    def from_local_checkpoint(
            cls,
            lora_dir: str,
            expected_lora_modules: list[str],
            peft_helper: PEFTHelper,
            *,
            lora_model_id: Optional[int] = None,
            device: str = "cuda",
            dtype: Optional[torch.dtype] = None,
            target_embedding_padding: Optional[int] = None,
            embedding_modules: Optional[dict[str, str]] = None,
            embedding_padding_modules: Optional[list[str]] = None,
            weights_mapper: Optional[WeightsMapper] = None,
            tensorizer_config_dict: Optional[dict] = None) -> "LoRAModel":
        """Create a LoRAModel from a local checkpoint.

        Weights are read from, in order of preference: tensorizer (when
        ``tensorizer_config_dict`` is given), ``adapter_model.safetensors``,
        then ``adapter_model.bin``. Optional extra embeddings are read from
        ``new_embeddings.safetensors`` or ``new_embeddings.bin``.

        Args:
            lora_dir: The local path that has lora data.
            expected_lora_modules: Name of modules that are expected to be
                replaced by lora.
            peft_helper: Loaded lora configuration information.
            lora_model_id: LoRA model id. If not given, automatically set by
                a global counter.
            device: Device where the lora model is loaded.
            dtype: dtype of the lora model weights.
            target_embedding_padding: Forwarded to ``from_lora_tensors``.
            embedding_modules: Forwarded to ``from_lora_tensors``.
            embedding_padding_modules: Forwarded to ``from_lora_tensors``.
            weights_mapper: Optional mapper applied when parsing tensor names.
            tensorizer_config_dict: kwargs for ``TensorizerConfig``; when set,
                weights are deserialized from ``adapter_model.tensors`` in the
                configured tensorizer directory.

        Returns:
            Loaded LoRA Model.

        Raises:
            ValueError: if the checkpoint targets modules outside
                ``expected_lora_modules`` or ``lora_dir`` has no weight file.
        """
        lora_tensor_path = os.path.join(lora_dir, "adapter_model.safetensors")
        lora_bin_file_path = os.path.join(lora_dir, "adapter_model.bin")
        new_embeddings_tensor_path = os.path.join(
            lora_dir, "new_embeddings.safetensors")
        new_embeddings_bin_file_path = os.path.join(lora_dir,
                                                    "new_embeddings.bin")

        tensors: dict[str, torch.Tensor] = {}
        unexpected_modules: list[Union[list[str], str]] = []

        def check_unexpected_modules(modules: dict):
            # Collect every tensor whose (mapped) module suffix is not in
            # expected_lora_modules and fail if any were found.
            for lora_module in modules.keys():  # noqa
                module_name, _, _ = parse_fine_tuned_lora_name(
                    lora_module, weights_mapper)
                part_name = module_name.split(".")[-1]
                if part_name not in expected_lora_modules:
                    unexpected_modules.append(module_name)
            if unexpected_modules:
                raise ValueError(
                    f"While loading {lora_dir}, expected"
                    f" target modules in {expected_lora_modules}"
                    f" but received {unexpected_modules}."
                    f" Please verify that the loaded LoRA module is correct")

        if tensorizer_config_dict:
            from tensorizer import TensorDeserializer

            tensorizer_config = TensorizerConfig(**tensorizer_config_dict)
            lora_tensor_path = os.path.join(tensorizer_config.tensorizer_dir,
                                            "adapter_model.tensors")
            tensorizer_args = tensorizer_config._construct_tensorizer_args()
            tensors = TensorDeserializer(
                lora_tensor_path,
                dtype=tensorizer_config.dtype,
                **tensorizer_args.deserialization_kwargs)
            check_unexpected_modules(tensors)

        elif os.path.isfile(lora_tensor_path):
            # Find unexpected modules.
            # Use safetensor key as a source of truth to find expected modules.
            # in peft if you have target_modules A, B, C and C does not exist
            # in the model it won’t error and model will be trained with A, B
            # loraified. C won’t exist in the safetensor but it will exist in
            # the target_modules of the adapter_config.json.
            with safetensors.safe_open(lora_tensor_path,
                                       framework="pt") as f:  # type: ignore
                # Load tensors if there are only expected modules.
                check_unexpected_modules(f)
                for module in f.keys():  # noqa
                    tensors[module] = f.get_tensor(module)
        elif os.path.isfile(lora_bin_file_path):
            # When a bin file is provided, we rely on config to find unexpected
            # modules.
            target_modules = peft_helper.target_modules
            if not isinstance(target_modules, list):
                target_modules = [target_modules]
            for module in target_modules:
                # Compatible with more modules,
                # such as:layers.11.self_attn.k_proj
                part_name = module.split(".")[-1]
                if part_name not in expected_lora_modules:
                    unexpected_modules.append(module)
            # loaded lora's target modules must be a subset of
            # expected_lora_modules. It is not reliable. See
            # https://github.com/vllm-project/vllm/pull/5909. But there's no
            # other better mechanism.
            if unexpected_modules and not is_regex_target_modules(
                    peft_helper.target_modules, expected_lora_modules):
                raise ValueError(
                    f"While loading {lora_dir}, expected"
                    f" target modules in {expected_lora_modules}"
                    f" but received {unexpected_modules}."
                    f" Please verify that the loaded LoRA module is correct")
            tensors = torch.load(lora_bin_file_path,
                                 map_location=device,
                                 weights_only=True)
        else:
            raise ValueError(f"{lora_dir} doesn't contain tensors")

        embeddings = None
        if os.path.isfile(new_embeddings_tensor_path):
            embeddings = safetensors.torch.load_file(
                new_embeddings_tensor_path)
        elif os.path.isfile(new_embeddings_bin_file_path):
            embeddings = torch.load(new_embeddings_bin_file_path,
                                    map_location=device,
                                    weights_only=True)

        return cls.from_lora_tensors(
            lora_model_id=get_lora_id()
            if lora_model_id is None else lora_model_id,
            tensors=tensors,
            peft_helper=peft_helper,
            device=device,
            dtype=dtype,
            embeddings=embeddings,
            target_embedding_padding=target_embedding_padding,
            embedding_modules=embedding_modules,
            embedding_padding_modules=embedding_padding_modules,
            weights_mapper=weights_mapper)
331
332


333
class LoRAModelManager(AdapterModelManager):
    """A manager that manages multiple LoRA-fine-tuned models."""

    def __init__(
        self,
        model: SupportsLoRA,
        max_num_seqs: int,
        max_num_batched_tokens: int,
        vocab_size: int,
        lora_config: LoRAConfig,
        device: torch.device,
    ):
        """Create a LoRAModelManager and adapter for a given model.

        Args:
            model: the model to be adapted.
            max_num_seqs: the maximum number of sequences model can run in a
                single batch.
            max_num_batched_tokens: the maximum number of tokens model can run
                in a single batch.
            vocab_size: the vocab size of the model.
            lora_config: the LoRA configuration.
            device: device passed to the punica wrapper.
        """
        self.lora_config = lora_config
        self.device = device
        self.max_num_seqs = max_num_seqs
        assert self.capacity >= self.lora_slots
        # The stored value is rounded up to a multiple of 8; the punica
        # wrapper below receives the unrounded argument.
        self.max_num_batched_tokens = math.ceil(max_num_batched_tokens / 8) * 8
        # Slot index -> LoRA id occupying that slot (None = free).
        self.lora_index_to_id: list[Optional[int]] = [None] * self.lora_slots
        self.vocab_size = vocab_size
        self.long_lora_context: Optional[LongContextLoRAContext] = None
        self.punica_wrapper = get_punica_wrapper(
            max_num_batched_tokens,
            max_batches=self.max_num_seqs,
            device=self.device,
            max_loras=self.lora_config.max_loras)
        # Scaling factor -> offset to the sin_cos_cache to it.
        # Used for long context lora.
        self.scaling_factor_to_offset: dict[float, int] = {}
        super().__init__(model)

        self.supported_lora_modules = get_supported_lora_modules(self.model)
        # Parenthesized so the whole f-string is part of the assert message.
        assert self.supported_lora_modules, (
            "No supported LoRA modules found in"
            f" {self.model.__class__.__name__}.")
        if lora_config.long_lora_scaling_factors:
            # We need to replace rotary emb layer to do batch computation
            # for long lora.
            self.supported_lora_modules.append("rotary_emb")

        self.packed_modules_mapping = get_packed_modules_mapping(self.model)
        # Used to indicate whether the model is a multimodal model
        self.supports_mm: bool = (
            supports_multimodal(self.model)
            # In case the model only supports LoRA for
            # text modules (e.g. ChatGLM)
            and hasattr(self.model, "get_mm_mapping"))
        self.is_pooling_model = is_pooling_model(self.model)
        self.is_moe_model = is_moe_model(self.model)
        # Full packed-module name -> full names of its sub-modules.
        self.packed_modules: dict[str, list[str]] = {}
        self.modules: dict[str, BaseLayerWithLoRA] = {}
        # Most recent mapping passed through set_adapter_mapping().
        self._last_mapping: Optional[LoRAMapping] = None
        self._create_lora_modules()
        self.model.lora_manager = self
        self.adapter_type = 'LoRA'

    @property
    def capacity(self) -> int:
        """Maximum number of registered adapters (``max_cpu_loras``)."""
        return self.lora_config.max_cpu_loras

    @property
    def lora_slots(self) -> int:
        """Number of simultaneously active adapter slots (``max_loras``)."""
        return self.lora_config.max_loras

    @property
    def adapter_slots(self) -> int:
        """Alias of ``lora_slots`` for the generic adapter interface."""
        return self.lora_slots

    def activate_adapter(
        self,
        lora_id: int,
    ) -> bool:
        """Move LoRA into a GPU buffer to be used in the forward pass.

        Returns:
            False if the adapter is already active; True once it has been
            written into the first free slot.

        Raises:
            ValueError: if no slot is free, or if the adapter carries bias
                weights while ``lora_config.bias_enabled`` is off.
        """
        if lora_id in self._active_adapters:
            return False
        index = next((slot for slot, slot_lora_id in enumerate(
            self.lora_index_to_id) if slot_lora_id is None), None)
        if index is None:
            raise ValueError("No free lora slots")
        self._active_adapters[lora_id] = None
        lora_model = self._registered_adapters[lora_id]
        logger.debug("Activating LoRA. int id: %d, slot index: %d",
                     lora_model.id, index)
        self.lora_index_to_id[index] = lora_model.id
        for module_name, module in self.modules.items():
            module_lora = self._get_lora_layer_weights(lora_model, module_name)
            if module_lora:
                module_lora.optimize()
                # Bias is not explicitly enabled with the flag enable_lora_bias.
                # NOTE(review): the slot and _active_adapters entry claimed
                # above are not released when this raises.
                bias = module_lora.bias
                if ((torch.is_tensor(bias) or
                     (isinstance(bias, Sequence) and any(b is not None
                                                         for b in bias)))
                        and not self.lora_config.bias_enabled):
                    module_lora.bias = None
                    raise ValueError(
                        f"Adapter bias cannot be used for {module_name}"
                        " without --enable-lora-bias.")
                module.set_lora(index, module_lora.lora_a, module_lora.lora_b,
                                module_lora.embeddings_tensor,
                                module_lora.bias)
            else:
                module.reset_lora(index)
        return True

    def _deactivate_adapter(self, lora_id: int):
        """Free the slot held by ``lora_id``; no-op if it holds none."""
        try:
            index = self.lora_index_to_id.index(lora_id)
            self.lora_index_to_id[index] = None
        except ValueError:
            # Deliberately best-effort: the id is not in any slot.
            pass

    def _set_long_lora_context(self, lora: LoRAModel):
        """Record the sin/cos-cache offset for a long-context LoRA.

        Raises:
            ValueError: if the LoRA's scaling factor has no known offset.
        """
        if self.long_lora_context is None:
            return

        if lora.scaling_factor is None:
            return

        if (lora.scaling_factor not in self.scaling_factor_to_offset):
            raise ValueError(f"Long LoRA scaling factor {lora.scaling_factor}"
                             " has not been initialized.")

        offsets = self.scaling_factor_to_offset.get(lora.scaling_factor)
        # NOTE(review): an offset of 0 is falsy and is not recorded;
        # presumably consumers treat a missing id as offset 0 — confirm.
        if offsets:
            self.long_lora_context.offsets_by_lora_id[lora.id] = offsets

    def _add_adapter(self, lora: LoRAModel):
        """Merge packed sub-LoRAs, register, and set long-context offsets."""
        self._create_merged_loras_inplace(lora)
        self._registered_adapters[lora.id] = lora
        self._set_long_lora_context(lora)

    def pin_adapter(self, lora_id: int) -> bool:
        """Pin a LoRAModel in the manager cache."""
        raise NotImplementedError(
            "Pinning is not supported in LoRAModelManager. "
            "Use LRUCacheLoRAModelManager for pinning")  # type: ignore

    def _set_adapter_mapping(self, mapping: LoRAMapping) -> None:
        # update lora states
        self.punica_wrapper.update_metadata(
            mapping,
            self.lora_index_to_id,
            self.lora_slots + 1,
            self.vocab_size,
            self.lora_config.lora_extra_vocab_size,
            self.long_lora_context,
        )

    def remove_all_adapters(self):
        """Remove all LoRAModels from the manager."""
        self._registered_adapters.clear()
        self.lora_index_to_id = [None] * self.lora_slots
        self._active_adapters.clear()

    def _create_lora_modules(self):
        """Replace every supported target submodule with its LoRA layer."""
        for module_name, module in self.model.named_modules(
                remove_duplicate=False):
            if isinstance(module, PPMissingLayer):
                continue
            if not self._match_target_modules(module_name):
                continue
            # A temporary approach for multimodal models to support LoRA
            # TODO: Remove this restriction
            if self._filter_unsupported_mm_module(module_name):
                logger.warning(
                    "Regarding multimodal models, vLLM currently only supports "
                    "adding LoRA to language model, %s will be ignored.",
                    module_name,
                )
                continue
            leaf_name = module_name.split(".")[-1]
            packed_modules_list = self.packed_modules_mapping.get(
                leaf_name, [])
            new_module = replace_submodule(
                self.model, module_name,
                from_layer(module, self.lora_slots, self.lora_config,
                           packed_modules_list, self.model.config))

            # LinearScalingRotaryEmbeddingWithLoRA is used to handle
            # long context lora. Register relevant metadata.
            if isinstance(new_module, LinearScalingRotaryEmbeddingWithLoRA):
                self.long_lora_context = LongContextLoRAContext(
                    new_module.scaling_factors, new_module.rotary_dim)
                self.scaling_factor_to_offset = \
                    new_module.scaling_factor_to_offset
            # (yard1): TODO make this more robust
            if "lm_head" in module_name:
                logits_processor_module = self.model.get_submodule(
                    "logits_processor")
                new_module = replace_submodule(
                    self.model, "logits_processor",
                    from_layer_logits_processor(logits_processor_module,
                                                module, self.lora_slots,
                                                self.lora_config,
                                                self.model.config))

            # In some models, especially multimodal ones, layers with the same
            # name may have different types, such as nn.Linear and
            # ReplicatedLinear. The nn.Linear layers cannot be replaced with
            # LoRA layers, leading to assertion error. The following check
            # aims to prevent this error
            if self.supports_mm and not isinstance(new_module,
                                                   BaseLayerWithLoRA):
                continue
            self.register_module(module_name, new_module)
            self._register_packed_modules(module_name)
            # All lora layers share the same punica_wrapper based on reference.
            new_module.set_mapping(self.punica_wrapper)

    def register_module(self, module_name: str, module: "BaseLayerWithLoRA"):
        """Track ``module`` under ``module_name`` (must be a LoRA layer)."""
        assert isinstance(module, BaseLayerWithLoRA)
        self.modules[module_name] = module

    def create_dummy_lora(
            self,
            lora_id: int,
            rank: int,
            scaling_factor: Optional[float],
            embedding_modules: Optional[dict[str, str]] = None) -> LoRAModel:
        """Create zero-initialized LoRAModel for warmup."""
        model = LoRAModel(lora_id, rank, {}, scaling_factor)
        bias_enabled = self.lora_config.bias_enabled
        for module_name, module in self.model.named_modules():
            if (not self._match_target_modules(module_name)
                    or not isinstance(module, BaseLayerWithLoRA)
                    or isinstance(module, LinearScalingRotaryEmbeddingWithLoRA)
                    or self._filter_unsupported_mm_module(module_name)):
                continue
            parts = module_name.split(".")
            if module_name not in self.packed_modules:
                assert embedding_modules is not None
                if parts[-1] in embedding_modules:
                    input_dim = (module.base_layer.org_vocab_size +
                                 self.lora_config.lora_extra_vocab_size if
                                 hasattr(module.base_layer, "org_vocab_size")
                                 else module.base_layer.weight.shape[1])
                    output_dim = module.base_layer.embedding_dim if hasattr(
                        module.base_layer,
                        "embedding_dim") else module.base_layer.weight.shape[0]
                    embeddings_tensor_dim = (module.base_layer.embedding_dim if
                                             hasattr(module.base_layer,
                                                     "embedding_dim") else
                                             module.base_layer.weight.shape[1])
                    lora = LoRALayerWeights.create_dummy_lora_weights(
                        module_name,
                        input_dim,
                        output_dim,
                        rank,
                        module.lora_a_stacked[0].dtype,
                        "cpu",
                        embeddings_tensor_dim=embeddings_tensor_dim,
                        bias_enabled=bias_enabled)
                else:
                    lora = LoRALayerWeights.create_dummy_lora_weights(
                        module_name,
                        module.lora_a_stacked[0].shape[-1],
                        module.lora_b_stacked[0].shape[-2],
                        rank,
                        module.lora_a_stacked[0].dtype,
                        "cpu",
                        bias_enabled=bias_enabled,
                    )
                lora.optimize()
            else:
                replacements = self.packed_modules_mapping[parts[-1]]
                subloras: list[Optional[LoRALayerWeights]] = []
                for i, r in enumerate(replacements):
                    lora = LoRALayerWeights.create_dummy_lora_weights(
                        module_name + "." + r,
                        module.lora_a_stacked[i].shape[-1],
                        module.lora_b_stacked[i].shape[-2],
                        rank,
                        module.lora_a_stacked[i].dtype,
                        "cpu",
                        bias_enabled=bias_enabled,
                    )
                    lora.optimize()
                    subloras.append(lora)
                lora = PackedLoRALayerWeights.pack(subloras)
            model.loras[module_name] = lora
        return model

    def _match_target_modules(self, module_name: str):
        """True if ``module_name`` equals or ends in ``.<supported module>``."""
        return any(
            re.match(
                r".*\.{target_module}$".format(target_module=target_module),
                module_name) or target_module == module_name
            for target_module in self.supported_lora_modules)

    def _filter_unsupported_mm_module(self, module_name: str) -> bool:
        """
        Regarding multimodal models, vLLM currently only supports adding LoRA to
        language model. LoRA for other modules, such as the vision tower, will
        be filtered out.
        """
        if self.supports_mm:
            module_mapping: MultiModelKeys = self.model.get_mm_mapping()
            prefix_lst = module_mapping.connector + module_mapping.tower_model
            return any(
                module_name.startswith(prefix) for prefix in prefix_lst)
        return False

    def _register_packed_modules(self, module_full_name: str) -> None:
        """Record the full sub-module names of a packed module, if any."""
        parts = module_full_name.split(".")
        module_name = parts[-1]
        replacements = self.packed_modules_mapping.get(module_name, [])
        # When replacements is less than or equal to 1, it indicates that this
        # module is not a packed module.
        if len(replacements) <= 1:
            return
        prefix = ".".join(parts[:-1])
        self.packed_modules[module_full_name] = [
            prefix + "." + r if prefix else r for r in replacements
        ]

    def _create_merged_loras_inplace(self, lora_model: LoRAModel) -> None:
        """Pack per-sub-module LoRAs into one entry per packed module.

        Sub-modules without weights contribute ``None`` to the pack; packed
        modules with no weights at all are skipped.
        """
        for module_name, new_module_names in self.packed_modules.items():
            replacement_loras: list[Optional[LoRALayerWeights]] = []
            replaced_module: set[str] = set()
            has_replacement = False
            for r in new_module_names:
                lora = self._get_lora_layer_weights(lora_model, r)
                if lora:
                    has_replacement = True
                    replaced_module.add(r)
                    replacement_loras.append(lora)
                else:
                    replacement_loras.append(None)
            if not has_replacement:
                continue

            # HACK Temporary solution for the pool model: mirror the
            # "model." prefix stripping done in _get_lora_layer_weights.
            if self.is_pooling_model and not lora_model.check_lora_name(
                    module_name):
                replaced_module_name = module_name.replace("model.", "")
                if lora_model.check_lora_name(replaced_module_name):
                    module_name = replaced_module_name
            lora_model.loras[module_name] = PackedLoRALayerWeights.pack(
                replacement_loras)
            # Remove the modules that have been replaced.
            for module in replaced_module:
                lora_model.loras.pop(module, None)

    def _get_lora_layer_weights(
            self, lora_model: LoRAModel,
            module_name: str) -> Optional[LoRALayerWeights]:
        """Look up weights for ``module_name``.

        For pooling models, if the exact name is missing, retry with every
        "model." substring removed.
        """
        org_module_name = module_name
        if self.is_pooling_model and not lora_model.check_lora_name(
                module_name):
            # If it's a pool model, and the layer name is not found,
            # remove the prefix 'model.' and search again.
            module_name = module_name.replace("model.", "")
            if lora_model.check_lora_name(module_name):
                org_module_name = module_name
                logger.info_once(
                    "For the pool model, successfully loaded the LoRA weights "
                    "after removing the prefix 'model.'.")
        return lora_model.get_lora(org_module_name)

    def deactivate_adapter(self, adapter_id: int) -> bool:
        return deactivate_adapter(adapter_id, self._active_adapters,
                                  self._deactivate_adapter)

    def add_adapter(self, adapter: LoRAModel) -> bool:
        logger.debug(
            "Adding lora. Model id: %d, "
            "int id: %d, "
            "scaling factor: %s", adapter.id, adapter.id,
            adapter.scaling_factor)
        return add_adapter(adapter, self._registered_adapters, self.capacity,
                           self._add_adapter)

    def set_adapter_mapping(self, mapping: LoRAMapping) -> None:
        self._last_mapping = set_adapter_mapping(mapping, self._last_mapping,
                                                 self._set_adapter_mapping)

    def remove_adapter(self, adapter_id: int) -> bool:
        return remove_adapter(adapter_id, self._registered_adapters,
                              self.deactivate_adapter)

    def list_adapters(self) -> dict[int, Any]:
        return list_adapters(self._registered_adapters)

    def get_adapter(self, adapter_id: int) -> Optional[Any]:
        return get_adapter(adapter_id, self._registered_adapters)


class LoRALRUCache(AdapterLRUCache[LoRAModel]):
    """LRU cache of ``LoRAModel`` objects, keyed by integer adapter id.

    A thin specialization of ``AdapterLRUCache`` that fixes the cached value
    type to ``LoRAModel``; all cache behavior comes from the base class.
    """

    def __init__(self, capacity: int, deactivate_lora_fn: Callable[[int],
                                                                   bool]):
        """Create the cache.

        Args:
            capacity: Size bound forwarded to ``AdapterLRUCache``.
            deactivate_lora_fn: Callback taking an adapter id and returning
                a bool, forwarded to ``AdapterLRUCache`` (presumably invoked
                when an entry is evicted — confirm against the base class).
        """
        super().__init__(capacity, deactivate_lora_fn)


class LRUCacheLoRAModelManager(LoRAModelManager):
    """A model manager that manages multiple LoRAs with LRU cache.

    Two ``LoRALRUCache`` instances replace the base manager's stores:
    ``_registered_adapters`` (the CPU-side registry, sized by
    ``self.capacity``) and ``_active_adapters`` (adapters occupying GPU
    slots, sized by ``self.lora_slots``). Adding or activating an adapter
    refreshes its position in the corresponding cache.
    """

    def __init__(self, model: nn.Module, max_num_seqs: int,
                 max_num_batched_tokens: int, vocab_size: int,
                 lora_config: LoRAConfig, device: torch.device):
        super().__init__(model, max_num_seqs, max_num_batched_tokens,
                         vocab_size, lora_config, device)
        # Registry of known adapters; its callback is the public
        # ``deactivate_adapter`` (presumably called on eviction).
        self._registered_adapters: LoRALRUCache = LoRALRUCache(
            self.capacity, self.deactivate_adapter)
        # Adapters currently in GPU slots; its callback is the low-level
        # ``_deactivate_adapter`` hook.
        self._active_adapters: LoRALRUCache = LoRALRUCache(
            self.lora_slots, self._deactivate_adapter)

    def list_adapters(self) -> dict[int, LoRAModel]:
        """List all registered LoRAModels.

        Returns a new dict copied from the registry's underlying cache, so
        callers may mutate it without affecting the manager.
        """
        return dict(self._registered_adapters.cache)

    def add_adapter(self, lora: LoRAModel) -> bool:
        """Add a LoRAModel to the manager.

        Returns:
            True if ``lora`` was newly registered; False if an adapter with
            the same id was already registered (it is then only touched to
            mark it as most recently used).
        """
        logger.debug(
            "Adding lora. Model id: %d, "
            "int id: %d, "
            "scaling factor: %s", lora.id, lora.id, lora.scaling_factor)
        if lora.id in self._registered_adapters:
            # We always touch to update the LRU cache order
            self._registered_adapters.touch(lora.id)
            return False
        self._add_adapter(lora)
        return True

    def activate_adapter(
        self,
        lora_id: int,
    ) -> bool:
        """Activate ``lora_id``, freeing the oldest active slot if full.

        Returns:
            The result of the base-class ``activate_adapter``.
        """
        if (lora_id not in self._active_adapters
                and len(self._active_adapters) >= self.lora_slots):
            # Every GPU slot is occupied by another adapter: evict the
            # oldest one before activating the new adapter.
            self._active_adapters.remove_oldest()
        result = super().activate_adapter(lora_id)
        # We always touch to update the LRU cache order
        self._active_adapters.touch(lora_id)
        return result

    def remove_oldest_adapter(self) -> bool:
        """Remove the oldest registered adapter, if any.

        Returns:
            True if an adapter was removed, False if none were registered.
        """
        if len(self._registered_adapters) > 0:
            self._registered_adapters.remove_oldest()
            return True
        return False

    def pin_adapter(self, lora_id: int) -> bool:
        """Pin a LoRAModel in the manager cache.

        The adapter is pinned in the registry (CPU) first and then in the
        active set (GPU), being activated first if it is not already active.

        Returns:
            Always True.

        Raises:
            ValueError: If ``lora_id`` is not registered.
        """
        self._pin_lora_in_cpu_cache(lora_id)
        self._pin_lora_in_gpu_cache(lora_id)
        return True

    def _pin_lora_in_cpu_cache(self, lora_id: int):
        """Pin ``lora_id`` in the registry, re-raising with a clearer error."""
        try:
            self._registered_adapters.pin(lora_id)
        except ValueError as err:
            raise ValueError("Pinning failed. "
                             f"LoRA {lora_id} is not registered.") from err

    def _pin_lora_in_gpu_cache(self, lora_id: int):
        """Ensure ``lora_id`` is active, then pin it in the active set."""
        if lora_id not in self._active_adapters:
            # move lora to gpu if not already active
            self.activate_adapter(lora_id)

        self._active_adapters.pin(lora_id)

811
812
813
814
815
816
817

def create_lora_manager(
        model: nn.Module,
        max_num_seqs: int,
        max_num_batched_tokens: int,
        vocab_size: int,
        lora_config: LoRAConfig,
        device: torch.device,
        lora_manager_cls: type[LoRAModelManager] = LoRAModelManager,
        **kwargs) -> LoRAModelManager:
    """Create a LoRA model manager for a given model.

    Args:
        model: The model to manage; must be an instance of ``SupportsLoRA``.
        max_num_seqs: Forwarded to the manager constructor.
        max_num_batched_tokens: Forwarded to the manager constructor.
        vocab_size: Forwarded to the manager constructor.
        lora_config: Forwarded to the manager constructor.
        device: Forwarded to the manager constructor.
        lora_manager_cls: Manager class to instantiate (defaults to
            ``LoRAModelManager``).
        **kwargs: Extra keyword arguments passed through to
            ``lora_manager_cls``.

    Returns:
        The newly constructed ``lora_manager_cls`` instance.

    Raises:
        ValueError: If ``model`` does not implement ``SupportsLoRA``.
    """
    if not isinstance(model, SupportsLoRA):
        raise ValueError(f"Model {type(model)} is not supported for LoRA.")
    lora_manager = lora_manager_cls(
        model=model,
        max_num_seqs=max_num_seqs,
        max_num_batched_tokens=max_num_batched_tokens,
        vocab_size=vocab_size,
        lora_config=lora_config,
        device=device,
        **kwargs)
    return lora_manager