models.py 35.2 KB
Newer Older
1
# SPDX-License-Identifier: Apache-2.0
2
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
3

4
5
import math
import os
6
from collections.abc import Sequence
7
from dataclasses import dataclass, field
8
from typing import Any, Callable, Optional, Union
9

10
import regex as re
11
12
13
14
import safetensors.torch
import torch
from torch import nn

15
16
17
18
19
from vllm.adapter_commons.models import (AdapterLRUCache, AdapterModel,
                                         AdapterModelManager)
from vllm.adapter_commons.utils import (add_adapter, deactivate_adapter,
                                        get_adapter, list_adapters,
                                        remove_adapter, set_adapter_mapping)
20
from vllm.config import LoRAConfig
21
from vllm.logger import init_logger
22
from vllm.lora.layers import (BaseLayerWithLoRA,
23
                              LinearScalingRotaryEmbeddingWithLoRA,
24
                              LoRAMapping)
25
from vllm.lora.lora import LoRALayerWeights, PackedLoRALayerWeights
26
from vllm.lora.peft_helper import PEFTHelper
27
from vllm.lora.punica_wrapper import get_punica_wrapper
28
from vllm.lora.utils import (from_layer, from_layer_logits_processor,
29
                             get_supported_lora_modules,
30
                             is_regex_target_modules,
31
                             parse_fine_tuned_lora_name, replace_submodule)
32
from vllm.model_executor.model_loader.tensorizer import TensorizerConfig
33
from vllm.model_executor.models import SupportsLoRA, supports_multimodal
34
from vllm.model_executor.models.interfaces import is_pooling_model
35
from vllm.model_executor.models.module_mapping import MultiModelKeys
36
from vllm.model_executor.models.utils import PPMissingLayer, WeightsMapper
37
from vllm.model_executor.utils import get_packed_modules_mapping
38
from vllm.utils import is_pin_memory_available
39

logger = init_logger(__name__)

# Module-level counter used by get_lora_id(). It holds the last id handed
# out; 0 means none has been issued yet, so the first generated id is 1.
_GLOBAL_LORA_ID = 0


@dataclass
class LongContextLoRAContext:
    """Context for lora adapters that support long context."""
    # The scaling factors to support long context lora fine tuned models.
    scaling_factors: list[float]
    # dimension to apply rotary embedding.
    rot_dim: int
    # offsets to the sin_cos_cache for each lora_id loaded.
    # This value is dynamically modified. default_factory gives every
    # instance its own dict instead of one shared mutable default.
    offsets_by_lora_id: dict[int, int] = field(default_factory=dict)


def get_lora_id():
    """Return the next id from the module-level LoRA id counter.

    The counter is bumped before the value is returned, so successive calls
    yield consecutive integers, each one larger than the last.
    """
    global _GLOBAL_LORA_ID
    _GLOBAL_LORA_ID = _GLOBAL_LORA_ID + 1
    issued_id = _GLOBAL_LORA_ID
    return issued_id


class LoRAModel(AdapterModel):
    """A LoRA fine-tuned model."""

    def __init__(
        self,
        lora_model_id: int,
        rank: int,
        loras: dict[str, LoRALayerWeights],
        scaling_factor: Optional[float] = None,
    ) -> None:
        """
        Args:
            lora_model_id: The integer id for the lora model.
            rank: lora rank.
            loras: module name -> weights for lora-replaced layers.
            scaling_factor: Scaling factor to support long context lora model.
                None if the lora is not tuned for long context support.
        """
        self.id = lora_model_id
        # Scaling factor for long context lora model. None if it is not
        # fine tuned for the long context.
        self.scaling_factor = scaling_factor
        assert (
            lora_model_id
            > 0), f"a valid lora id should be greater than 0, got {self.id}"
        self.rank = rank
        self.loras: dict[str, LoRALayerWeights] = loras

    def clone(self, lora_model_id: int) -> "LoRAModel":
        """Return a copy of the object with different ids.

        Will share the underlying tensors. The long-context scaling factor is
        carried over so the clone behaves like the original adapter."""
        return self.__class__(
            lora_model_id,
            rank=self.rank,
            loras=self.loras.copy(),
            scaling_factor=self.scaling_factor,
        )

    @property
    def extra_vocab_size(self) -> int:
        """Largest ``extra_vocab_size`` over all layer weights (0 if none)."""
        return max(lora.extra_vocab_size
                   for lora in self.loras.values()) if self.loras else 0

    def get_lora(self, module_name: str) -> Optional[LoRALayerWeights]:
        """Get LoRA for a given module by name"""
        return self.loras.get(module_name, None)

    def check_lora_name(self, lora_name: str) -> bool:
        """Return True if weights are stored under ``lora_name``."""
        return lora_name in self.loras

    # (yard1): TODO see if we can derive target_embedding_padding automatically
    @classmethod
    def from_lora_tensors(
        cls,
        lora_model_id: int,
        tensors: dict[str, torch.Tensor],
        peft_helper: PEFTHelper,
        device: str = "cuda",
        dtype: Optional[torch.dtype] = None,
        embeddings: Optional[dict[str, torch.Tensor]] = None,
        target_embedding_padding: Optional[int] = None,
        embedding_modules: Optional[dict[str, str]] = None,
        embedding_padding_modules: Optional[list[str]] = None,
        weights_mapper: Optional[WeightsMapper] = None,
    ) -> "LoRAModel":
        """Create a LoRAModel from a dictionary of tensors.

        Args:
            lora_model_id: id of the resulting LoRAModel.
            tensors: checkpoint tensor name -> tensor. Each name is split by
                ``parse_fine_tuned_lora_name`` into a module name and flags
                telling whether it is a lora_a, lora_b or bias tensor.
            peft_helper: adapter configuration; provides the rank and the
                long-context scaling factor.
            device: device the tensors are moved to. When it is "cpu" and
                pinned memory is available, the tensors are also pinned.
            dtype: dtype the tensors are cast to.
            embeddings: optional extra embedding tensors, looked up through
                ``embedding_modules``.
            target_embedding_padding: when set, lora_b of modules matching
                ``embedding_padding_modules`` is zero-padded on its last
                dimension up to this size.
            embedding_modules: module-name substring -> key in ``embeddings``.
            embedding_padding_modules: module-name substrings whose lora_b
                may need padding.
            weights_mapper: optional mapper applied while parsing names.
        """
        pin_memory = str(device) == "cpu" and is_pin_memory_available()
        loras: dict[str, LoRALayerWeights] = {}
        for tensor_name, tensor in tensors.items():
            module_name, is_lora_a, is_bias = parse_fine_tuned_lora_name(
                tensor_name, weights_mapper)
            if module_name not in loras:
                lora_embeddings_tensor = None
                if embeddings:
                    assert embedding_modules is not None
                    embeddings_module = next(
                        (k for k in embedding_modules if k in module_name),
                        None)
                    if embeddings_module:
                        lora_embeddings_tensor = embeddings[
                            embedding_modules[embeddings_module]].to(
                                device=device, dtype=dtype)
                        if pin_memory:
                            lora_embeddings_tensor = (
                                lora_embeddings_tensor.pin_memory())
                loras[module_name] = LoRALayerWeights.from_config(
                    module_name, peft_helper, lora_embeddings_tensor)

            # Checkpoint weights are stored transposed relative to what the
            # layer weights expect, hence the ``.t()`` on every tensor.
            if is_bias:
                # Convert exactly once (the original converted twice and
                # discarded the first result).
                bias = tensor.to(device=device, dtype=dtype).t()
                if pin_memory:
                    bias = bias.pin_memory()
                loras[module_name].bias = bias
            elif is_lora_a:
                lora_a = tensor.to(device=device, dtype=dtype).t()
                if pin_memory:
                    lora_a = lora_a.pin_memory()
                loras[module_name].lora_a = lora_a
            else:
                lora_b = tensor.to(device=device, dtype=dtype).t()
                assert embedding_padding_modules is not None
                if any(name in module_name
                       for name in embedding_padding_modules
                       ) and target_embedding_padding is not None:
                    assert target_embedding_padding >= lora_b.shape[1]
                    addition = target_embedding_padding - lora_b.shape[1]
                    lora_b = torch.nn.functional.pad(lora_b, (0, addition))
                # Pin after padding: pad() returns a new, unpinned tensor.
                if pin_memory:
                    lora_b = lora_b.pin_memory()
                loras[module_name].lora_b = lora_b

        for lora in loras.values():
            lora.optimize()

        return cls(lora_model_id,
                   peft_helper.r,
                   loras,
                   scaling_factor=peft_helper.vllm_long_context_scaling_factor)

    @classmethod
    def from_local_checkpoint(
            cls,
            lora_dir: str,
            expected_lora_modules: list[str],
            peft_helper: PEFTHelper,
            *,
            lora_model_id: Optional[int] = None,
            device: str = "cuda",
            dtype: Optional[torch.dtype] = None,
            target_embedding_padding: Optional[int] = None,
            embedding_modules: Optional[dict[str, str]] = None,
            embedding_padding_modules: Optional[list[str]] = None,
            weights_mapper: Optional[WeightsMapper] = None,
            tensorizer_config_dict: Optional[dict] = None) -> "LoRAModel":
        """Create a LoRAModel from a local checkpoint.

        The weights source is chosen in this order:
        1. a tensorizer file, when ``tensorizer_config_dict`` is given;
        2. ``adapter_model.safetensors``;
        3. ``adapter_model.bin``.
        Optional new embeddings are read from ``new_embeddings.safetensors``
        or, failing that, ``new_embeddings.bin``.

        Args:
            lora_dir: The local path that has lora data.
            expected_lora_modules: Name of modules that are expected to be
                replaced by lora.
            peft_helper: Loaded lora configuration information.
            lora_model_id: LoRA model id. If not given, automatically set by
                a global counter.
            device: Device where the lora model is loaded.
            dtype: dtype of the lora model weights.
            target_embedding_padding: forwarded to ``from_lora_tensors``.
            embedding_modules: forwarded to ``from_lora_tensors``.
            embedding_padding_modules: forwarded to ``from_lora_tensors``.
            weights_mapper: used when parsing tensor names; forwarded to
                ``from_lora_tensors``.
            tensorizer_config_dict: keyword arguments for ``TensorizerConfig``;
                when given, weights are deserialized with tensorizer.

        Returns:
            Loaded LoRA Model.

        Raises:
            ValueError: if the checkpoint targets modules outside
                ``expected_lora_modules``, or if ``lora_dir`` contains no
                weight file.
        """
        lora_tensor_path = os.path.join(lora_dir, "adapter_model.safetensors")
        lora_bin_file_path = os.path.join(lora_dir, "adapter_model.bin")
        new_embeddings_tensor_path = os.path.join(
            lora_dir, "new_embeddings.safetensors")
        new_embeddings_bin_file_path = os.path.join(lora_dir,
                                                    "new_embeddings.bin")
        tensors: dict[str, torch.Tensor] = {}
        # Only one loading branch below runs, so this single list is always
        # empty when that branch starts filling it.
        unexpected_modules: list[Union[list[str], str]] = []

        def check_unexpected_modules(modules: dict):
            # Collect every tensor whose module leaf name is not expected and
            # report them all in one error.
            for lora_module in modules.keys():  # noqa
                module_name, _, _ = parse_fine_tuned_lora_name(
                    lora_module, weights_mapper)
                part_name = module_name.split(".")[-1]
                if part_name not in expected_lora_modules:
                    unexpected_modules.append(module_name)
            if unexpected_modules:
                raise ValueError(
                    f"While loading {lora_dir}, expected"
                    f" target modules in {expected_lora_modules}"
                    f" but received {unexpected_modules}."
                    f" Please verify that the loaded LoRA module is correct")

        if tensorizer_config_dict:
            from tensorizer import TensorDeserializer

            tensorizer_config = TensorizerConfig(**tensorizer_config_dict)
            lora_tensor_path = os.path.join(tensorizer_config.tensorizer_dir,
                                            "adapter_model.tensors")
            tensorizer_args = tensorizer_config._construct_tensorizer_args()
            tensors = TensorDeserializer(
                lora_tensor_path,
                dtype=tensorizer_config.dtype,
                **tensorizer_args.deserialization_kwargs)
            check_unexpected_modules(tensors)

        elif os.path.isfile(lora_tensor_path):
            # Find unexpected modules.
            # Use safetensor key as a source of truth to find expected modules.
            # in peft if you have target_modules A, B, C and C does not exist
            # in the model it won't error and model will be trained with A, B
            # loraified. C won't exist in the safetensor but it will exist in
            # the target_modules of the adapter_config.json.
            with safetensors.safe_open(lora_tensor_path,
                                       framework="pt") as f:  # type: ignore
                # Load tensors if there are only expected modules.
                check_unexpected_modules(f)
                for module in f.keys():  # noqa
                    tensors[module] = f.get_tensor(module)
        elif os.path.isfile(lora_bin_file_path):
            # When a bin file is provided, we rely on config to find unexpected
            # modules.
            target_modules = peft_helper.target_modules
            if not isinstance(target_modules, list):
                target_modules = [target_modules]
            for module in target_modules:
                # Compatible with more modules,
                # such as:layers.11.self_attn.k_proj
                part_name = module.split(".")[-1]
                if part_name not in expected_lora_modules:
                    unexpected_modules.append(module)
            # loaded lora's target modules must be a subset of
            # expected_lora_modules. It is not reliable. See
            # https://github.com/vllm-project/vllm/pull/5909. But there's no
            # other better mechanism.
            if unexpected_modules and not is_regex_target_modules(
                    peft_helper.target_modules, expected_lora_modules):
                raise ValueError(
                    f"While loading {lora_dir}, expected"
                    f" target modules in {expected_lora_modules}"
                    f" but received {unexpected_modules}."
                    f" Please verify that the loaded LoRA module is correct")
            tensors = torch.load(lora_bin_file_path,
                                 map_location=device,
                                 weights_only=True)
        else:
            raise ValueError(f"{lora_dir} doesn't contain tensors")

        embeddings = None
        if os.path.isfile(new_embeddings_tensor_path):
            embeddings = safetensors.torch.load_file(
                new_embeddings_tensor_path)
        elif os.path.isfile(new_embeddings_bin_file_path):
            embeddings = torch.load(new_embeddings_bin_file_path,
                                    map_location=device,
                                    weights_only=True)

        return cls.from_lora_tensors(
            lora_model_id=get_lora_id()
            if lora_model_id is None else lora_model_id,
            tensors=tensors,
            peft_helper=peft_helper,
            device=device,
            dtype=dtype,
            embeddings=embeddings,
            target_embedding_padding=target_embedding_padding,
            embedding_modules=embedding_modules,
            embedding_padding_modules=embedding_padding_modules,
            weights_mapper=weights_mapper)


class LoRAModelManager(AdapterModelManager):
    """A manager that manages multiple LoRA-fine-tuned models."""

    def __init__(
        self,
        model: SupportsLoRA,
        max_num_seqs: int,
        max_num_batched_tokens: int,
        vocab_size: int,
        lora_config: LoRAConfig,
        device: torch.device,
    ):
        """Create a LoRAModelManager and adapter for a given model.

        Args:
            model: the model to be adapted.
            max_num_seqs: the maximum number of sequences model can run in a
                single batch.
            max_num_batched_tokens: the maximum number of tokens model can run
                in a single batch.
            vocab_size: the vocab size of the model.
            lora_config: the LoRA configuration.
            device: device passed to the punica wrapper.
        """
        self.lora_config = lora_config
        self.device = device
        self.max_num_seqs = max_num_seqs
        assert self.capacity >= self.lora_slots
        # Round the token budget up to a multiple of 8.
        self.max_num_batched_tokens = math.ceil(max_num_batched_tokens / 8) * 8
        # Slot index -> id of the adapter occupying it (None = free slot).
        self.lora_index_to_id: list[Optional[int]] = [None] * self.lora_slots
        self.vocab_size = vocab_size
        self.long_lora_context: Optional[LongContextLoRAContext] = None
        self.punica_wrapper = get_punica_wrapper(
            max_num_batched_tokens,
            max_batches=self.max_num_seqs,
            device=self.device,
            max_loras=self.lora_config.max_loras)
        # Scaling factor -> offset to the sin_cos_cache to it.
        # Used for long context lora.
        self.scaling_factor_to_offset: dict[float, int] = {}
        super().__init__(model)

        self.supported_lora_modules = get_supported_lora_modules(self.model)
        # Parenthesized so the model name is actually part of the message
        # (previously the f-string was a separate, discarded statement).
        assert self.supported_lora_modules, (
            "No supported LoRA modules found in"
            f" {self.model.__class__.__name__}.")
        if lora_config.long_lora_scaling_factors:
            # We need to replace rotary emb layer to do batch computation
            # for long lora.
            self.supported_lora_modules.append("rotary_emb")

        self.packed_modules_mapping = get_packed_modules_mapping(self.model)
        # Used to indicate whether the model is a multimodal model
        self.supports_mm: bool = (
            supports_multimodal(self.model)
            # In case the model only supports LoRA for
            # text modules (e.g. ChatGLM)
            and hasattr(self.model, "get_mm_mapping"))
        self.is_pooling_model = is_pooling_model(self.model)
        # Packed module full name -> full names of its sub-modules.
        self.packed_modules: dict[str, list[str]] = {}
        # Module full name -> LoRA-wrapped layer registered by this manager.
        self.modules: dict[str, BaseLayerWithLoRA] = {}
        # Mapping most recently handed to set_adapter_mapping; presumably
        # used there to skip re-applying an unchanged mapping.
        self._last_mapping: Optional[LoRAMapping] = None
        self._create_lora_modules()
        self.model.lora_manager = self
        self.adapter_type = 'LoRA'

    @property
    def capacity(self) -> int:
        """Maximum number of registered adapters (``max_cpu_loras``)."""
        return self.lora_config.max_cpu_loras

    @property
    def lora_slots(self) -> int:
        """Number of slots adapters can be activated into (``max_loras``)."""
        return self.lora_config.max_loras

    @property
    def adapter_slots(self) -> int:
        return self.lora_slots

    def activate_adapter(
        self,
        lora_id: int,
    ) -> bool:
        """Move LoRA into a GPU buffer to be used in the forward pass.

        Returns:
            False if the adapter is already active; True once its weights
            have been written into a free slot.

        Raises:
            ValueError: if no slot is free, or if the adapter carries bias
                weights while ``lora_config.bias_enabled`` is off.
        """
        if lora_id in self._active_adapters:
            return False
        index = next((slot for slot, slot_lora_id in enumerate(
            self.lora_index_to_id) if slot_lora_id is None), None)
        if index is None:
            raise ValueError("No free lora slots")
        lora_model = self._registered_adapters[lora_id]
        # Validate before mutating any manager or layer state, so a rejected
        # adapter is neither marked active nor left half-written in a slot.
        self._validate_adapter_bias(lora_model)

        self._active_adapters[lora_id] = None
        logger.debug("Activating LoRA. int id: %d, slot index: %d",
                     lora_model.id, index)
        self.lora_index_to_id[index] = lora_model.id
        for module_name, module in self.modules.items():
            module_lora = self._get_lora_layer_weights(lora_model, module_name)
            if module_lora:
                module_lora.optimize()
                module.set_lora(index, module_lora.lora_a, module_lora.lora_b,
                                module_lora.embeddings_tensor,
                                module_lora.bias)
            else:
                module.reset_lora(index)
        return True

    def _validate_adapter_bias(self, lora_model: LoRAModel) -> None:
        """Raise ValueError if ``lora_model`` has bias weights for any managed
        module while bias support is disabled in the LoRA config."""
        # Bias is not explicitly enabled with the flag enable_lora_bias.
        if self.lora_config.bias_enabled:
            return
        for module_name in self.modules:
            module_lora = self._get_lora_layer_weights(lora_model, module_name)
            if not module_lora:
                continue
            bias = module_lora.bias
            if (torch.is_tensor(bias)
                    or (isinstance(bias, Sequence)
                        and any(b is not None for b in bias))):
                # NOTE(review): kept from the original behavior - the bias is
                # stripped from the registered weights before raising, so this
                # module will not trigger the error again on a retry.
                module_lora.bias = None
                raise ValueError(
                    f"Adapter bias cannot be used for {module_name}"
                    " without --enable-lora-bias.")

    def _deactivate_adapter(self, lora_id: int):
        """Free the slot holding ``lora_id``; no-op if it holds no slot."""
        try:
            index = self.lora_index_to_id.index(lora_id)
        except ValueError:
            return
        self.lora_index_to_id[index] = None

    def _set_long_lora_context(self, lora: LoRAModel):
        """Record the sin/cos-cache offset for ``lora``'s scaling factor.

        No-op when long-context LoRA is not in use or the adapter has no
        scaling factor.

        Raises:
            ValueError: if the scaling factor has no registered offset.
        """
        if self.long_lora_context is None:
            return

        if lora.scaling_factor is None:
            return

        if lora.scaling_factor not in self.scaling_factor_to_offset:
            raise ValueError(f"Long LoRA scaling factor {lora.scaling_factor}"
                             " has not been initialized.")

        # Record the offset unconditionally: 0 is a valid offset and used to
        # be skipped by a truthiness test.
        self.long_lora_context.offsets_by_lora_id[lora.id] = (
            self.scaling_factor_to_offset[lora.scaling_factor])

    def _add_adapter(self, lora: LoRAModel):
        """Pack sub-module weights, register ``lora`` and set its offsets."""
        self._create_merged_loras_inplace(lora)
        self._registered_adapters[lora.id] = lora
        self._set_long_lora_context(lora)

    def pin_adapter(self, lora_id: int) -> bool:
        """Pin a LoRAModel in the manager cache."""
        raise NotImplementedError(
            "Pinning is not supported in LoRAModelManager. "
            "Use LRUCacheLoRAModelManager for pinning")  # type: ignore

    def _set_adapter_mapping(self, mapping: LoRAMapping) -> None:
        """Forward ``mapping`` and the current slot table to punica."""
        # update lora states
        self.punica_wrapper.update_metadata(
            mapping,
            self.lora_index_to_id,
            self.lora_slots + 1,
            self.vocab_size,
            self.lora_config.lora_extra_vocab_size,
            self.long_lora_context,
        )

    def remove_all_adapters(self):
        """Remove all LoRAModels from the manager."""
        self._registered_adapters.clear()
        self.lora_index_to_id = [None] * self.lora_slots
        self._active_adapters.clear()

    def _create_lora_modules(self):
        """Replace each supported target submodule of ``self.model`` with its
        LoRA-enabled counterpart and register it with this manager."""
        for module_name, module in self.model.named_modules(
                remove_duplicate=False):
            if isinstance(module, PPMissingLayer):
                continue
            if not self._match_target_modules(module_name):
                continue
            # A temporary approach for multimodal models to support LoRA
            # TODO: Remove this restriction
            if self._filter_unsupported_mm_module(module_name):
                logger.warning(
                    "Regarding multimodal models, vLLM currently only supports "
                    "adding LoRA to language model, %s will be ignored.",
                    module_name,
                )
                continue
            leaf_name = module_name.split(".")[-1]
            packed_modules_lst = self.packed_modules_mapping.get(leaf_name, [])
            new_module = replace_submodule(
                self.model, module_name,
                from_layer(module, self.lora_slots, self.lora_config,
                           packed_modules_lst, self.model.config))

            # LinearScalingRotaryEmbeddingWithLoRA is used to handle
            # long context lora. Register relevant metadata.
            if isinstance(new_module, LinearScalingRotaryEmbeddingWithLoRA):
                self.long_lora_context = LongContextLoRAContext(
                    new_module.scaling_factors, new_module.rotary_dim)
                self.scaling_factor_to_offset = \
                    new_module.scaling_factor_to_offset
            # (yard1): TODO make this more robust
            if "lm_head" in module_name:
                # The LoRA-wrapped logits processor is what gets registered
                # under the lm_head name below.
                logits_processor_module = self.model.get_submodule(
                    "logits_processor")
                new_module = replace_submodule(
                    self.model, "logits_processor",
                    from_layer_logits_processor(logits_processor_module,
                                                module, self.lora_slots,
                                                self.lora_config,
                                                self.model.config))

            # In some models, especially multimodal ones, layers with the same
            # name may have different types, such as nn.Linear and
            # ReplicatedLinear. The nn.Linear layers cannot be replaced with
            # LoRA layers, leading to assertion error. The following check
            # aims to prevent this error
            if self.supports_mm and not isinstance(new_module,
                                                   BaseLayerWithLoRA):
                continue
            self.register_module(module_name, new_module)
            self._register_packed_modules(module_name)
            # All lora layers share the same punica_wrapper based on reference.
            new_module.set_mapping(self.punica_wrapper)

    def register_module(self, module_name: str, module: "BaseLayerWithLoRA"):
        """Record ``module`` as the LoRA layer for ``module_name``."""
        assert isinstance(module, BaseLayerWithLoRA)
        self.modules[module_name] = module

    def create_dummy_lora(
            self,
            lora_id: int,
            rank: int,
            scaling_factor: Optional[float],
            embedding_modules: Optional[dict[str, str]] = None) -> LoRAModel:
        """Create zero-initialized LoRAModel for warmup.

        Args:
            lora_id: id of the dummy adapter.
            rank: LoRA rank of the dummy weights.
            scaling_factor: long-context scaling factor, or None.
            embedding_modules: leaf module names treated as embeddings; must
                not be None when any non-packed module is visited.
        """
        model = LoRAModel(lora_id, rank, {}, scaling_factor)
        bias_enabled = self.lora_config.bias_enabled
        for module_name, module in self.model.named_modules():
            if (not self._match_target_modules(module_name)
                    or not isinstance(module, BaseLayerWithLoRA)
                    or isinstance(module, LinearScalingRotaryEmbeddingWithLoRA)
                    or self._filter_unsupported_mm_module(module_name)):
                continue
            leaf_name = module_name.split(".")[-1]
            if module_name not in self.packed_modules:
                assert embedding_modules is not None
                if leaf_name in embedding_modules:
                    base_layer = module.base_layer
                    input_dim = (base_layer.org_vocab_size +
                                 self.lora_config.lora_extra_vocab_size if
                                 hasattr(base_layer, "org_vocab_size")
                                 else base_layer.weight.shape[1])
                    output_dim = base_layer.embedding_dim if hasattr(
                        base_layer,
                        "embedding_dim") else base_layer.weight.shape[0]
                    embeddings_tensor_dim = (base_layer.embedding_dim if
                                             hasattr(base_layer,
                                                     "embedding_dim") else
                                             base_layer.weight.shape[1])
                    lora = LoRALayerWeights.create_dummy_lora_weights(
                        module_name,
                        input_dim,
                        output_dim,
                        rank,
                        module.lora_a_stacked[0].dtype,
                        "cpu",
                        embeddings_tensor_dim=embeddings_tensor_dim,
                        bias_enabled=bias_enabled)
                else:
                    lora = LoRALayerWeights.create_dummy_lora_weights(
                        module_name,
                        module.lora_a_stacked[0].shape[-1],
                        module.lora_b_stacked[0].shape[-2],
                        rank,
                        module.lora_a_stacked[0].dtype,
                        "cpu",
                        bias_enabled=bias_enabled,
                    )
                lora.optimize()
            else:
                replacements = self.packed_modules_mapping[leaf_name]
                subloras: list[Optional[LoRALayerWeights]] = []
                for i, r in enumerate(replacements):
                    lora = LoRALayerWeights.create_dummy_lora_weights(
                        module_name + "." + r,
                        module.lora_a_stacked[i].shape[-1],
                        module.lora_b_stacked[i].shape[-2],
                        rank,
                        module.lora_a_stacked[i].dtype,
                        "cpu",
                        bias_enabled=bias_enabled,
                    )
                    lora.optimize()
                    subloras.append(lora)
                lora = PackedLoRALayerWeights.pack(subloras)
            model.loras[module_name] = lora
        return model

    def _match_target_modules(self, module_name: str):
        """True if ``module_name`` equals, or ends with ``.<target>`` for, any
        supported LoRA module. Targets are interpolated into the pattern
        unescaped, so they act as regular expressions."""
        return any(
            re.match(
                r".*\.{target_module}$".format(target_module=target_module),
                module_name) or target_module == module_name
            for target_module in self.supported_lora_modules)

    def _filter_unsupported_mm_module(self, module_name: str) -> bool:
        """
        Regarding multimodal models, vLLM currently only supports adding LoRA to
        language model. LoRA for other modules, such as the vision tower, will
        be filtered out.
        """
        if self.supports_mm:
            module_mapping: MultiModelKeys = self.model.get_mm_mapping()
            prefix_lst = module_mapping.connector + module_mapping.tower_model
            return any(module_name.startswith(prefix) for prefix in prefix_lst)
        return False

    def _register_packed_modules(self, module_full_name: str) -> None:
        """Record the full names of the sub-modules packed into
        ``module_full_name``, if it is a packed module."""
        parts = module_full_name.split(".")
        module_name = parts[-1]
        replacements = self.packed_modules_mapping.get(module_name, [])
        # When replacements is less than or equal to 1, it indicates that this
        # module is not a packed module.
        if len(replacements) <= 1:
            return
        prefix = ".".join(parts[:-1])
        self.packed_modules[module_full_name] = [
            prefix + "." + r if prefix else r for r in replacements
        ]

    def _create_merged_loras_inplace(self, lora_model: LoRAModel) -> None:
        """For each packed module, pack the adapter's sub-module weights into
        one PackedLoRALayerWeights (None for missing parts) stored under the
        packed module name, then drop the per-sub-module entries."""
        for module_name, new_module_names in self.packed_modules.items():
            replacement_loras: list[Optional[LoRALayerWeights]] = []
            replaced_module: set[str] = set()
            for r in new_module_names:
                lora = self._get_lora_layer_weights(lora_model, r)
                if lora:
                    replacement_loras.append(lora)
                    replaced_module.add(r)
                else:
                    replacement_loras.append(None)
            if not replaced_module:
                continue

            # HACK Temporary solution for the pool model.
            if self.is_pooling_model and not lora_model.check_lora_name(
                    module_name):
                replaced_module_name = module_name.replace("model.", "")
                # Check the stripped name (the original re-checked
                # module_name here, which made this branch unreachable).
                if lora_model.check_lora_name(replaced_module_name):
                    module_name = replaced_module_name
            lora_model.loras[module_name] = PackedLoRALayerWeights.pack(
                replacement_loras)
            # Remove the modules that have been replaced.
            for module in replaced_module:
                lora_model.loras.pop(module, None)

    def _get_lora_layer_weights(
            self, lora_model: LoRAModel,
            module_name: str) -> Optional[LoRALayerWeights]:
        """Look up weights for ``module_name``; for pooling models, retry
        with every ``model.`` occurrence removed when the name is missing."""
        org_module_name = module_name
        if self.is_pooling_model and not lora_model.check_lora_name(
                module_name):
            # If it's a pool model, and the layer name is not found,
            # remove the prefix 'model.' and search again.
            module_name = module_name.replace("model.", "")
            if lora_model.check_lora_name(module_name):
                org_module_name = module_name
                logger.info_once(
                    "For the pool model, successfully loaded the LoRA weights "
                    "after removing the prefix 'model.'.")
        return lora_model.get_lora(org_module_name)

    def deactivate_adapter(self, adapter_id: int) -> bool:
        return deactivate_adapter(adapter_id, self._active_adapters,
                                  self._deactivate_adapter)

    def add_adapter(self, adapter: LoRAModel) -> bool:
        logger.debug(
            "Adding lora. Model id: %d, "
            "int id: %d, "
            "scaling factor: %s", adapter.id, adapter.id,
            adapter.scaling_factor)
        return add_adapter(adapter, self._registered_adapters, self.capacity,
                           self._add_adapter)

    def set_adapter_mapping(self, mapping: LoRAMapping) -> None:
        self._last_mapping = set_adapter_mapping(mapping, self._last_mapping,
                                                 self._set_adapter_mapping)

    def remove_adapter(self, adapter_id: int) -> bool:
        return remove_adapter(adapter_id, self._registered_adapters,
                              self.deactivate_adapter)

    def list_adapters(self) -> dict[int, Any]:
        return list_adapters(self._registered_adapters)

    def get_adapter(self, adapter_id: int) -> Optional[Any]:
        return get_adapter(adapter_id, self._registered_adapters)


class LoRALRUCache(AdapterLRUCache[LoRAModel]):
    """LRU cache of ``LoRAModel`` objects, keyed by LoRA id."""

    def __init__(self, capacity: int,
                 deactivate_lora_fn: Callable[[int], bool]):
        """Create the cache.

        Args:
            capacity: Capacity forwarded unchanged to ``AdapterLRUCache``.
            deactivate_lora_fn: Callback taking a LoRA id and returning a
                bool, forwarded unchanged to ``AdapterLRUCache``. It is
                presumably invoked when an entry is evicted — confirm in
                ``AdapterLRUCache``.
        """
        super().__init__(capacity, deactivate_lora_fn)


class LRUCacheLoRAModelManager(LoRAModelManager):
    """A model manager that manages multiple LoRAs with LRU cache.

    Two ``LoRALRUCache`` instances are used:

    * ``_registered_adapters``: created with capacity ``self.capacity``; its
      callback is the public :meth:`deactivate_adapter`.
    * ``_active_adapters``: created with capacity ``self.lora_slots``; its
      callback is ``self._deactivate_adapter``.

    NOTE(review): the callbacks are presumably invoked by the cache when an
    entry is evicted — confirm against ``AdapterLRUCache``.
    """

    def __init__(self, model: nn.Module, max_num_seqs: int,
                 max_num_batched_tokens: int, vocab_size: int,
                 lora_config: LoRAConfig, device: torch.device):
        super().__init__(model, max_num_seqs, max_num_batched_tokens,
                         vocab_size, lora_config, device)
        # Use LRU caches for both the registered and the active adapter sets.
        # ``self.capacity`` and ``self.lora_slots`` are read after the base
        # initializer has run.
        self._registered_adapters: LoRALRUCache = LoRALRUCache(
            self.capacity, self.deactivate_adapter)
        self._active_adapters: LoRALRUCache = LoRALRUCache(
            self.lora_slots, self._deactivate_adapter)

    def list_adapters(self) -> dict[int, LoRAModel]:
        """List all registered LoRAModels.

        Returns:
            A new dict (shallow copy of the registered cache's ``cache``
            mapping) from LoRA id to ``LoRAModel``.
        """
        return dict(self._registered_adapters.cache)

    def add_adapter(self, lora: LoRAModel) -> bool:
        """Add a LoRAModel to the manager.

        No explicit capacity check is made here. The registered-adapter
        store is an LRU cache created with capacity ``self.capacity``.

        Returns:
            True if ``lora`` was newly registered. False if it was already
            registered; in that case it is only marked as recently used.
        """
        logger.debug(
            "Adding lora. Model id: %d, "
            "int id: %d, "
            "scaling factor: %s", lora.id, lora.id, lora.scaling_factor)
        if lora.id not in self._registered_adapters:
            self._add_adapter(lora)
            was_added = True
        else:
            # We always touch to update the LRU cache order
            self._registered_adapters.touch(lora.id)
            was_added = False
        return was_added

    def activate_adapter(
        self,
        lora_id: int,
    ) -> bool:
        """Activate ``lora_id`` and mark it as most recently used.

        Suppose ``lora_id`` is not already active and every slot
        (``self.lora_slots``) is occupied. In that case the oldest active
        LoRA is evicted first.

        Returns:
            The result of the base-class ``activate_adapter``.
        """
        if lora_id not in self._active_adapters and len(
                self._active_adapters) >= self.lora_slots:
            # Free a slot before the base class activates the new LoRA.
            self._active_adapters.remove_oldest()
        result = super().activate_adapter(lora_id)
        # We always touch to update the LRU cache order
        self._active_adapters.touch(lora_id)
        return result

    def remove_oldest_adapter(self) -> bool:
        """Evict the oldest registered LoRA, if any are registered.

        Returns:
            True if the registered cache was non-empty (and ``remove_oldest``
            was called on it), False otherwise.
        """
        if len(self._registered_adapters) > 0:
            self._registered_adapters.remove_oldest()
            return True
        return False

    def pin_adapter(self, lora_id: int) -> bool:
        """Pin a LoRAModel in the manager cache.

        The LoRA is pinned in the registered (CPU-side) cache first, then in
        the active (GPU-side) cache, activating it if needed.

        Returns:
            Always True.

        Raises:
            ValueError: If ``lora_id`` is not registered.
        """
        self._pin_lora_in_cpu_cache(lora_id)
        self._pin_lora_in_gpu_cache(lora_id)
        return True

    def _pin_lora_in_cpu_cache(self, lora_id: int):
        """Pin ``lora_id`` in the registered-adapter cache.

        Raises:
            ValueError: If the cache rejects the pin, reported as the LoRA
                not being registered.
        """
        try:
            self._registered_adapters.pin(lora_id)
        except ValueError as err:
            raise ValueError("Pinning failed. "
                             f"LoRA {lora_id} is not registered.") from err

    def _pin_lora_in_gpu_cache(self, lora_id: int):
        """Ensure ``lora_id`` is active, then pin it in the active cache."""
        if lora_id not in self._active_adapters:
            # move lora to gpu if not already active
            self.activate_adapter(lora_id)

        self._active_adapters.pin(lora_id)


def create_lora_manager(
        model: nn.Module,
        max_num_seqs: int,
        max_num_batched_tokens: int,
        vocab_size: int,
        lora_config: LoRAConfig,
        device: torch.device,
        lora_manager_cls: type[LoRAModelManager] = LoRAModelManager,
        **kwargs) -> LoRAModelManager:
    """Create a LoRA manager for a given model.

    Args:
        model: The model whose LoRAs will be managed. It must be an instance
            of ``SupportsLoRA``.
        max_num_seqs: Forwarded to the manager constructor.
        max_num_batched_tokens: Forwarded to the manager constructor.
        vocab_size: Forwarded to the manager constructor.
        lora_config: Forwarded to the manager constructor.
        device: Forwarded to the manager constructor.
        lora_manager_cls: The manager class to instantiate. Defaults to
            ``LoRAModelManager``.
        **kwargs: Extra keyword arguments forwarded to ``lora_manager_cls``.

    Returns:
        The newly constructed ``lora_manager_cls`` instance.

    Raises:
        ValueError: If ``model`` is not an instance of ``SupportsLoRA``.
    """
    if not isinstance(model, SupportsLoRA):
        raise ValueError(f"Model {type(model)} is not supported for LoRA.")
    lora_manager = lora_manager_cls(
        model=model,
        max_num_seqs=max_num_seqs,
        max_num_batched_tokens=max_num_batched_tokens,
        vocab_size=vocab_size,
        lora_config=lora_config,
        device=device,
        **kwargs)
    return lora_manager