models.py 35.1 KB
Newer Older
1
2
# SPDX-License-Identifier: Apache-2.0

3
4
import math
import os
5
from collections.abc import Sequence
6
from dataclasses import dataclass, field
7
from typing import Any, Callable, Optional, Union
8

9
import regex as re
10
11
12
13
import safetensors.torch
import torch
from torch import nn

14
15
16
17
18
from vllm.adapter_commons.models import (AdapterLRUCache, AdapterModel,
                                         AdapterModelManager)
from vllm.adapter_commons.utils import (add_adapter, deactivate_adapter,
                                        get_adapter, list_adapters,
                                        remove_adapter, set_adapter_mapping)
19
from vllm.config import LoRAConfig
20
from vllm.logger import init_logger
21
from vllm.lora.layers import (BaseLayerWithLoRA,
22
                              LinearScalingRotaryEmbeddingWithLoRA,
23
                              LoRAMapping)
24
from vllm.lora.lora import LoRALayerWeights, PackedLoRALayerWeights
25
from vllm.lora.peft_helper import PEFTHelper
26
from vllm.lora.punica_wrapper import get_punica_wrapper
27
from vllm.lora.utils import (from_layer, from_layer_logits_processor,
28
                             get_supported_lora_modules,
29
                             is_regex_target_modules,
30
                             parse_fine_tuned_lora_name, replace_submodule)
31
from vllm.model_executor.model_loader.tensorizer import TensorizerConfig
32
from vllm.model_executor.models import SupportsLoRA, supports_multimodal
33
from vllm.model_executor.models.interfaces import is_pooling_model
34
from vllm.model_executor.models.module_mapping import MultiModelKeys
35
from vllm.model_executor.models.utils import PPMissingLayer, WeightsMapper
36
from vllm.model_executor.utils import get_packed_modules_mapping
37
from vllm.utils import is_pin_memory_available
38

39
logger = init_logger(__name__)

# Process-wide counter behind get_lora_id(). It is incremented before each
# use, so the first id handed out is 1.
_GLOBAL_LORA_ID = 0


44
45
46
47
@dataclass
class LongContextLoRAContext:
    """Context for LoRA adapters that support long context.

    Bundles the scaling factors of long-context fine-tuned adapters, the
    rotary dimension, and a mutable per-adapter offset map into the sin/cos
    cache.
    """
    # Scaling factors that long-context LoRA fine-tuned models rely on.
    scaling_factors: list[float]
    # Number of dimensions the rotary embedding is applied to.
    rot_dim: int
    # lora_id -> offset into the sin_cos_cache for that adapter. Each
    # instance gets its own dict, which is filled in at runtime.
    offsets_by_lora_id: dict[int, int] = field(default_factory=dict)
54
55


56
57
58
59
60
61
def get_lora_id():
    """Return the next process-wide LoRA id.

    The module-level counter is bumped before it is returned, so starting
    from 0 the ids handed out are 1, 2, 3, ... (always positive).
    """
    global _GLOBAL_LORA_ID
    _GLOBAL_LORA_ID = _GLOBAL_LORA_ID + 1
    return _GLOBAL_LORA_ID


62
class LoRAModel(AdapterModel):
    """A LoRA fine-tuned model."""

    def __init__(
        self,
        lora_model_id: int,
        rank: int,
        loras: dict[str, LoRALayerWeights],
        scaling_factor: Optional[float] = None,
    ) -> None:
        """
        Args:
            lora_model_id: The integer id for the lora model. Must be > 0.
            rank: lora rank.
            loras: module name -> weights for lora-replaced layers.
            scaling_factor: Scaling factor to support long context lora model.
                None if the lora is not tuned for long context support.
        """
        self.id = lora_model_id
        # Scaling factor for long context lora model. None if it is not
        # fine tuned for the long context.
        self.scaling_factor = scaling_factor
        assert (
            lora_model_id
            > 0), f"a valid lora id should be greater than 0, got {self.id}"
        self.rank = rank
        self.loras: dict[str, LoRALayerWeights] = loras

    def clone(self, lora_model_id: int) -> "LoRAModel":
        """Return a copy of the object with a different id.

        Only the module-name -> weights dict is shallow-copied, so the copy
        shares the underlying tensors. The long-context scaling factor is
        carried over so the clone behaves like the original.
        """
        return self.__class__(
            lora_model_id,
            rank=self.rank,
            loras=self.loras.copy(),
            scaling_factor=self.scaling_factor,
        )

    @property
    def extra_vocab_size(self) -> int:
        """Largest extra_vocab_size over all layers (0 if there are none)."""
        return max(lora.extra_vocab_size
                   for lora in self.loras.values()) if self.loras else 0

    def get_lora(self, module_name: str) -> Optional[LoRALayerWeights]:
        """Get LoRA for a given module by name"""
        return self.loras.get(module_name, None)

    def check_lora_name(self, lora_name: str) -> bool:
        """Return True if weights are stored under exactly ``lora_name``."""
        return lora_name in self.loras

    # (yard1): TODO see if we can derive target_embedding_padding automatically
    @classmethod
    def from_lora_tensors(
        cls,
        lora_model_id: int,
        tensors: dict[str, torch.Tensor],
        peft_helper: PEFTHelper,
        device: str = "cuda",
        dtype: Optional[torch.dtype] = None,
        embeddings: Optional[dict[str, torch.Tensor]] = None,
        target_embedding_padding: Optional[int] = None,
        embedding_modules: Optional[dict[str, str]] = None,
        embedding_padding_modules: Optional[list[str]] = None,
        weights_mapper: Optional[WeightsMapper] = None,
    ) -> "LoRAModel":
        """Create a LoRAModel from a dictionary of tensors.

        Args:
            lora_model_id: id for the new model.
            tensors: checkpoint tensor name -> tensor (LoRA A, LoRA B or
                bias, as classified by parse_fine_tuned_lora_name).
            peft_helper: parsed adapter config. Supplies the rank and the
                long-context scaling factor.
            device: device the weights are moved to. When it is "cpu" and
                pinned memory is available, the weights are also pinned.
            dtype: dtype the weights are cast to.
            embeddings: extra embedding tensors, keyed by the values of
                ``embedding_modules``.
            target_embedding_padding: if set, lora_b of modules matching
                ``embedding_padding_modules`` is padded up to this width.
            embedding_modules: module-name substring -> key in ``embeddings``.
            embedding_padding_modules: module-name substrings whose lora_b
                may be padded. Required whenever a lora_b tensor is present.
            weights_mapper: optional mapper applied to tensor names.
        """
        pin_memory = str(device) == "cpu" and is_pin_memory_available()
        loras: dict[str, LoRALayerWeights] = {}
        for tensor_name, tensor in tensors.items():
            module_name, is_lora_a, is_bias = parse_fine_tuned_lora_name(
                tensor_name, weights_mapper)
            if module_name not in loras:
                lora_embeddings_tensor = None
                if embeddings:
                    assert embedding_modules is not None
                    embeddings_module = next(
                        (k for k in embedding_modules if k in module_name),
                        None)
                    if embeddings_module:
                        lora_embeddings_tensor = embeddings[
                            embedding_modules[embeddings_module]].to(
                                device=device, dtype=dtype)
                        if pin_memory:
                            lora_embeddings_tensor = (
                                lora_embeddings_tensor.pin_memory())
                loras[module_name] = LoRALayerWeights.from_config(
                    module_name, peft_helper, lora_embeddings_tensor)

            if is_bias:
                # Convert (and optionally pin) a single copy of the bias.
                bias = tensor.to(device=device, dtype=dtype).t()
                if pin_memory:
                    bias = bias.pin_memory()
                loras[module_name].bias = bias
            elif is_lora_a:
                lora_a = tensor.to(device=device, dtype=dtype).t()
                if pin_memory:
                    lora_a = lora_a.pin_memory()
                loras[module_name].lora_a = lora_a
            else:
                lora_b = tensor.to(device=device, dtype=dtype).t()
                assert embedding_padding_modules is not None
                if any(name in module_name
                       for name in embedding_padding_modules
                       ) and target_embedding_padding is not None:
                    # Right-pad the last dimension up to the target width.
                    assert target_embedding_padding >= lora_b.shape[1]
                    addition = target_embedding_padding - lora_b.shape[1]
                    lora_b = torch.nn.functional.pad(lora_b, (0, addition))
                if pin_memory:
                    lora_b = lora_b.pin_memory()
                loras[module_name].lora_b = lora_b

        for lora in loras.values():
            lora.optimize()

        return cls(lora_model_id,
                   peft_helper.r,
                   loras,
                   scaling_factor=peft_helper.vllm_long_context_scaling_factor)

    @classmethod
    def from_local_checkpoint(
            cls,
            lora_dir: str,
            expected_lora_modules: list[str],
            peft_helper: PEFTHelper,
            *,
            lora_model_id: Optional[int] = None,
            device: str = "cuda",
            dtype: Optional[torch.dtype] = None,
            target_embedding_padding: Optional[int] = None,
            embedding_modules: Optional[dict[str, str]] = None,
            embedding_padding_modules: Optional[list[str]] = None,
            weights_mapper: Optional[WeightsMapper] = None,
            tensorizer_config_dict: Optional[dict] = None) -> "LoRAModel":
        """Create a LoRAModel from a local checkpoint.

        Weights are read from the first source that exists: tensorizer (if
        ``tensorizer_config_dict`` is given), then
        ``adapter_model.safetensors``, then ``adapter_model.bin``. Optional
        ``new_embeddings.safetensors`` / ``new_embeddings.bin`` files are
        loaded as extra embeddings.

        Args:
            lora_dir: The local path that has lora data.
            expected_lora_modules: Name of modules that are expected to be
                replaced by lora.
            peft_helper: Loaded lora configuration information.
            lora_model_id: LoRA model id. If not given, automatically set by
                a global counter.
            device: Device where the lora model is loaded.
            dtype: dtype of the lora model weights.
            target_embedding_padding: forwarded to from_lora_tensors.
            embedding_modules: forwarded to from_lora_tensors.
            embedding_padding_modules: forwarded to from_lora_tensors.
            weights_mapper: applied to tensor names when checking for
                unexpected modules; also forwarded to from_lora_tensors.
            tensorizer_config_dict: kwargs for TensorizerConfig. When given,
                weights are deserialized from ``adapter_model.tensors`` in the
                configured tensorizer dir.

        Returns:
            Loaded LoRA Model.

        Raises:
            ValueError: if the checkpoint targets modules that are not in
                ``expected_lora_modules``, or ``lora_dir`` has no tensors.
        """
        lora_tensor_path = os.path.join(lora_dir, "adapter_model.safetensors")
        lora_bin_file_path = os.path.join(lora_dir, "adapter_model.bin")
        new_embeddings_tensor_path = os.path.join(
            lora_dir, "new_embeddings.safetensors")
        new_embeddings_bin_file_path = os.path.join(lora_dir,
                                                    "new_embeddings.bin")
        tensors: dict[str, torch.Tensor] = {}
        unexpected_modules: list[Union[list[str], str]] = []

        def check_unexpected_modules(modules: dict):
            for lora_module in modules.keys():  # noqa
                module_name, _, _ = parse_fine_tuned_lora_name(
                    lora_module, weights_mapper)
                part_name = module_name.split(".")[-1]
                if part_name not in expected_lora_modules:
                    unexpected_modules.append(module_name)
            if unexpected_modules:
                raise ValueError(
                    f"While loading {lora_dir}, expected"
                    f" target modules in {expected_lora_modules}"
                    f" but received {unexpected_modules}."
                    f" Please verify that the loaded LoRA module is correct")

        if tensorizer_config_dict:
            from tensorizer import TensorDeserializer

            tensorizer_config = TensorizerConfig(**tensorizer_config_dict)
            lora_tensor_path = os.path.join(tensorizer_config.tensorizer_dir,
                                            "adapter_model.tensors")
            tensorizer_args = tensorizer_config._construct_tensorizer_args()
            tensors = TensorDeserializer(lora_tensor_path,
                                         dtype=tensorizer_config.dtype,
                                         **tensorizer_args.deserializer_params)
            check_unexpected_modules(tensors)

        elif os.path.isfile(lora_tensor_path):
            # Find unexpected modules.
            # Use safetensor key as a source of truth to find expected modules.
            # in peft if you have target_modules A, B, C and C does not exist
            # in the model it won't error and model will be trained with A, B
            # loraified. C won't exist in the safetensor but it will exist in
            # the target_modules of the adapter_config.json.
            unexpected_modules = []
            with safetensors.safe_open(lora_tensor_path,
                                       framework="pt") as f:  # type: ignore
                # Load tensors if there are only expected modules.
                check_unexpected_modules(f)
                for module in f.keys():  # noqa
                    tensors[module] = f.get_tensor(module)
        elif os.path.isfile(lora_bin_file_path):
            # When a bin file is provided, we rely on config to find unexpected
            # modules.
            unexpected_modules = []
            target_modules = peft_helper.target_modules
            if not isinstance(target_modules, list):
                target_modules = [target_modules]
            for module in target_modules:
                # Compatible with more modules,
                # such as:layers.11.self_attn.k_proj
                part_name = module.split(".")[-1]
                if part_name not in expected_lora_modules:
                    unexpected_modules.append(module)
            # loaded lora's target modules must be a subset of
            # expected_lora_modules. It is not reliable. See
            # https://github.com/vllm-project/vllm/pull/5909. But there's no
            # other better mechanism.
            if unexpected_modules and not is_regex_target_modules(
                    peft_helper.target_modules, expected_lora_modules):
                raise ValueError(
                    f"While loading {lora_dir}, expected"
                    f" target modules in {expected_lora_modules}"
                    f" but received {unexpected_modules}."
                    f" Please verify that the loaded LoRA module is correct")
            tensors = torch.load(lora_bin_file_path,
                                 map_location=device,
                                 weights_only=True)
        else:
            raise ValueError(f"{lora_dir} doesn't contain tensors")

        embeddings = None
        if os.path.isfile(new_embeddings_tensor_path):
            embeddings = safetensors.torch.load_file(
                new_embeddings_tensor_path)
        elif os.path.isfile(new_embeddings_bin_file_path):
            embeddings = torch.load(new_embeddings_bin_file_path,
                                    map_location=device,
                                    weights_only=True)

        return cls.from_lora_tensors(
            lora_model_id=get_lora_id()
            if lora_model_id is None else lora_model_id,
            tensors=tensors,
            peft_helper=peft_helper,
            device=device,
            dtype=dtype,
            embeddings=embeddings,
            target_embedding_padding=target_embedding_padding,
            embedding_modules=embedding_modules,
            embedding_padding_modules=embedding_padding_modules,
            weights_mapper=weights_mapper)
317
318


319
class LoRAModelManager(AdapterModelManager):
    """A manager that manages multiple LoRA-fine-tuned models."""

    def __init__(
        self,
        model: SupportsLoRA,
        max_num_seqs: int,
        max_num_batched_tokens: int,
        vocab_size: int,
        lora_config: LoRAConfig,
        device: torch.device,
    ):
        """Create a LoRAModelManager and adapter for a given model.

        Args:
            model: the model to be adapted.
            max_num_seqs: the maximum number of sequences model can run in a
                single batch.
            max_num_batched_tokens: the maximum number of tokens model can run
                in a single batch.
            vocab_size: the vocab size of the model.
            lora_config: the LoRA configuration.
            device: device handed to the punica wrapper.
        """
        self.lora_config = lora_config
        self.device = device
        self.max_num_seqs = max_num_seqs
        assert self.capacity >= self.lora_slots
        # Rounded up to a multiple of 8.
        self.max_num_batched_tokens = math.ceil(max_num_batched_tokens / 8) * 8
        self.lora_index_to_id: list[Optional[int]] = [None] * self.lora_slots
        self.vocab_size = vocab_size
        self.long_lora_context: Optional[LongContextLoRAContext] = None
        # NOTE(review): the punica wrapper receives the un-rounded token
        # budget, not self.max_num_batched_tokens — confirm this is intended.
        self.punica_wrapper = get_punica_wrapper(
            max_num_batched_tokens,
            max_batches=self.max_num_seqs,
            device=self.device,
            max_loras=self.lora_config.max_loras)
        # Scaling factor -> offset to the sin_cos_cache to it.
        # Used for long context lora.
        self.scaling_factor_to_offset: dict[float, int] = {}
        super().__init__(model)

        self.supported_lora_modules = get_supported_lora_modules(self.model)
        # The message parts must be one parenthesized expression; otherwise
        # the f-string becomes a separate, discarded statement.
        assert self.supported_lora_modules, (
            "No supported LoRA modules found in"
            f" {self.model.__class__.__name__}.")
        if lora_config.long_lora_scaling_factors:
            # We need to replace rotary emb layer to do batch computation
            # for long lora.
            self.supported_lora_modules.append("rotary_emb")

        self.packed_modules_mapping = get_packed_modules_mapping(self.model)
        # Used to indicate whether the model is a multimodal model
        self.supports_mm: bool = (
            supports_multimodal(self.model)
            # In case the model only supports LoRA for
            # text modules (e.g. ChatGLM)
            and hasattr(self.model, "get_mm_mapping"))
        self.is_pooling_model = is_pooling_model(self.model)
        self.packed_modules: dict[str, list[str]] = {}
        self.modules: dict[str, BaseLayerWithLoRA] = {}
        # Last mapping applied through set_adapter_mapping().
        self._last_mapping: Optional[LoRAMapping] = None
        self._create_lora_modules()
        self.model.lora_manager = self
        self.adapter_type = 'LoRA'

    @property
    def capacity(self) -> int:
        """Maximum number of registered adapters (max_cpu_loras)."""
        return self.lora_config.max_cpu_loras

    @property
    def lora_slots(self) -> int:
        """Number of adapter slots usable in a forward pass (max_loras)."""
        return self.lora_config.max_loras

    @property
    def adapter_slots(self) -> int:
        return self.lora_slots

    @staticmethod
    def _has_lora_bias(bias: Any) -> bool:
        """Return True if ``bias`` is a tensor, or a sequence holding at least
        one non-None entry."""
        return torch.is_tensor(bias) or (isinstance(bias, Sequence) and any(
            b is not None for b in bias))

    def activate_adapter(
        self,
        lora_id: int,
    ) -> bool:
        """Move LoRA into a GPU buffer to be used in the forward pass.

        Returns:
            False if the adapter is already active, True once it has been
            loaded into a free slot.

        Raises:
            ValueError: if no slot is free, or if the adapter has bias weights
                while bias is not enabled in the LoRA config. Neither error
                modifies the manager's state.
        """
        if lora_id in self._active_adapters:
            return False
        index = next((slot for slot, slot_lora_id in enumerate(
            self.lora_index_to_id) if slot_lora_id is None), None)
        if index is None:
            raise ValueError("No free lora slots")
        lora_model = self._registered_adapters[lora_id]

        # Validate before claiming the slot or touching any module, so a
        # rejected adapter never leaves a half-loaded slot behind.
        # Bias is not explicitly enabled with the flag enable_lora_bias.
        if not self.lora_config.bias_enabled:
            for module_name in self.modules:
                module_lora = self._get_lora_layer_weights(
                    lora_model, module_name)
                if module_lora and self._has_lora_bias(module_lora.bias):
                    raise ValueError(
                        f"Adapter bias cannot be used for {module_name}"
                        " without --enable-lora-bias.")

        # Dict instead of a set for compatibility with LRUCache.
        self._active_adapters[lora_id] = None
        logger.debug("Activating LoRA. int id: %d, slot index: %d",
                     lora_model.id, index)
        self.lora_index_to_id[index] = lora_model.id
        for module_name, module in self.modules.items():
            module_lora = self._get_lora_layer_weights(lora_model, module_name)
            if module_lora:
                module_lora.optimize()
                module.set_lora(index, module_lora.lora_a, module_lora.lora_b,
                                module_lora.embeddings_tensor,
                                module_lora.bias)
            else:
                module.reset_lora(index)
        return True

    def _deactivate_adapter(self, lora_id: int):
        """Free the slot held by ``lora_id``; no-op if it holds none."""
        try:
            index = self.lora_index_to_id.index(lora_id)
        except ValueError:
            return
        self.lora_index_to_id[index] = None

    def _set_long_lora_context(self, lora: LoRAModel):
        """Record the sin/cos cache offset of a long-context adapter.

        No-op when the model has no long-context rotary layer or the adapter
        has no scaling factor.

        Raises:
            ValueError: if the adapter's scaling factor is not among the ones
                registered from the rotary embedding layer.
        """
        if self.long_lora_context is None:
            return

        if lora.scaling_factor is None:
            return

        if (lora.scaling_factor not in self.scaling_factor_to_offset):
            raise ValueError(f"Long LoRA scaling factor {lora.scaling_factor}"
                             " has not been initialized.")

        offsets = self.scaling_factor_to_offset.get(lora.scaling_factor)
        # Truthiness check: an offset of 0 is not recorded.
        if offsets:
            self.long_lora_context.offsets_by_lora_id[lora.id] = offsets

    def _add_adapter(self, lora: LoRAModel):
        """Pack sub-LoRAs, register the adapter and set its long-context
        offset."""
        self._create_merged_loras_inplace(lora)
        self._registered_adapters[lora.id] = lora
        self._set_long_lora_context(lora)

    def pin_adapter(self, lora_id: int) -> bool:
        """Pin a LoRAModel in the manager cache."""
        raise NotImplementedError(
            "Pinning is not supported in LoRAModelManager. "
            "Use LRUCacheLoRAModelManager for pinning")  # type: ignore

    def _set_adapter_mapping(self, mapping: LoRAMapping) -> None:
        """Push ``mapping`` and the current slot assignment to punica."""
        # update lora states
        self.punica_wrapper.update_metadata(
            mapping,
            self.lora_index_to_id,
            self.lora_slots + 1,
            self.vocab_size,
            self.lora_config.lora_extra_vocab_size,
            self.long_lora_context,
        )

    def remove_all_adapters(self):
        """Remove all LoRAModels from the manager."""
        self._registered_adapters.clear()
        self.lora_index_to_id = [None] * self.lora_slots
        self._active_adapters.clear()

    def _create_lora_modules(self):
        """Replace each targeted submodule of the model with its LoRA-enabled
        counterpart, then register it and attach the shared punica wrapper."""
        for module_name, module in self.model.named_modules(
                remove_duplicate=False):
            if isinstance(module, PPMissingLayer):
                continue
            if not self._match_target_modules(module_name):
                continue
            # A temporary approach for multimodal models to support LoRA
            # TODO: Remove this restriction
            if self._filter_unsupported_mm_module(module_name):
                logger.warning(
                    "Regarding multimodal models, vLLM currently only supports "
                    "adding LoRA to language model, %s will be ignored.",
                    module_name,
                )
                continue
            parts = module_name.split(".")[-1]
            packed_modules_lst = self.packed_modules_mapping.get(parts, [])
            new_module = replace_submodule(
                self.model, module_name,
                from_layer(module, self.lora_slots, self.lora_config,
                           packed_modules_lst, self.model.config))

            # LinearScalingRotaryEmbeddingWithLoRA is used to handle
            # long context lora. Register relevant metadata.
            if isinstance(new_module, LinearScalingRotaryEmbeddingWithLoRA):
                self.long_lora_context = LongContextLoRAContext(
                    new_module.scaling_factors, new_module.rotary_dim)
                self.scaling_factor_to_offset = \
                    new_module.scaling_factor_to_offset
            # (yard1): TODO make this more robust
            if "lm_head" in module_name:
                logits_processor_module = self.model.get_submodule(
                    "logits_processor")
                new_module = replace_submodule(
                    self.model, "logits_processor",
                    from_layer_logits_processor(logits_processor_module,
                                                module, self.lora_slots,
                                                self.lora_config,
                                                self.model.config))

            # In some models, especially multimodal ones, layers with the same
            # name may have different types, such as nn.Linear and
            # ReplicatedLinear. The nn.Linear layers cannot be replaced with
            # LoRA layers, leading to assertion error. The following check
            # aims to prevent this error
            if self.supports_mm and not isinstance(new_module,
                                                   BaseLayerWithLoRA):
                continue
            self.register_module(module_name, new_module)
            self._register_packed_modules(module_name)
            # All lora layers share the same punica_wrapper based on reference.
            new_module.set_mapping(self.punica_wrapper)

    def register_module(self, module_name: str, module: "BaseLayerWithLoRA"):
        """Track ``module`` under ``module_name``; it must be a LoRA layer."""
        assert isinstance(module, BaseLayerWithLoRA)
        self.modules[module_name] = module

    def create_dummy_lora(
            self,
            lora_id: int,
            rank: int,
            scaling_factor: Optional[float],
            embedding_modules: Optional[dict[str, str]] = None) -> LoRAModel:
        """Create zero-initialized LoRAModel for warmup.

        Dummy weights are created on CPU for every LoRA layer the model
        exposes. Weights for packed modules are built per sub-module and then
        packed.
        """
        model = LoRAModel(lora_id, rank, {}, scaling_factor)
        bias_enabled = self.lora_config.bias_enabled
        for module_name, module in self.model.named_modules():
            if (not self._match_target_modules(module_name)
                    or not isinstance(module, BaseLayerWithLoRA)
                    or isinstance(module, LinearScalingRotaryEmbeddingWithLoRA)
                    or self._filter_unsupported_mm_module(module_name)):
                continue
            parts = module_name.split(".")
            if module_name not in self.packed_modules:
                assert embedding_modules is not None
                if parts[-1] in embedding_modules:
                    input_dim = (module.base_layer.org_vocab_size +
                                 self.lora_config.lora_extra_vocab_size if
                                 hasattr(module.base_layer, "org_vocab_size")
                                 else module.base_layer.weight.shape[1])
                    output_dim = module.base_layer.embedding_dim if hasattr(
                        module.base_layer,
                        "embedding_dim") else module.base_layer.weight.shape[0]
                    embeddings_tensor_dim = (module.base_layer.embedding_dim if
                                             hasattr(module.base_layer,
                                                     "embedding_dim") else
                                             module.base_layer.weight.shape[1])
                    lora = LoRALayerWeights.create_dummy_lora_weights(
                        module_name,
                        input_dim,
                        output_dim,
                        rank,
                        module.lora_a_stacked[0].dtype,
                        "cpu",
                        embeddings_tensor_dim=embeddings_tensor_dim,
                        bias_enabled=bias_enabled)
                else:
                    lora = LoRALayerWeights.create_dummy_lora_weights(
                        module_name,
                        module.lora_a_stacked[0].shape[-1],
                        module.lora_b_stacked[0].shape[-2],
                        rank,
                        module.lora_a_stacked[0].dtype,
                        "cpu",
                        bias_enabled=bias_enabled,
                    )
                lora.optimize()
            else:
                replacements = self.packed_modules_mapping[parts[-1]]
                subloras: list[Optional[LoRALayerWeights]] = []
                for i, r in enumerate(replacements):
                    lora = LoRALayerWeights.create_dummy_lora_weights(
                        module_name + "." + r,
                        module.lora_a_stacked[i].shape[-1],
                        module.lora_b_stacked[i].shape[-2],
                        rank,
                        module.lora_a_stacked[i].dtype,
                        "cpu",
                        bias_enabled=bias_enabled,
                    )
                    lora.optimize()
                    subloras.append(lora)
                lora = PackedLoRALayerWeights.pack(subloras)
            model.loras[module_name] = lora
        return model

    def _match_target_modules(self, module_name: str):
        """True if ``module_name`` equals, or ends with ``.<name>`` for, any
        supported LoRA module name (the name is used as a regex fragment)."""
        return any(
            re.match(
                r".*\.{target_module}$".format(target_module=target_module),
                module_name) or target_module == module_name
            for target_module in self.supported_lora_modules)

    def _filter_unsupported_mm_module(self, module_name: str) -> bool:
        """
        Regarding multimodal models, vLLM currently only supports adding LoRA to
        language model. LoRA for other modules, such as the vision tower, will
        be filtered out.
        """
        if self.supports_mm:
            module_mapping: MultiModelKeys = self.model.get_mm_mapping()
            prefix_lst = module_mapping.connector + module_mapping.tower_model
            return any(
                module_name.startswith(prefix) for prefix in prefix_lst)
        return False

    def _register_packed_modules(self, module_full_name: str) -> None:
        """Record the fully qualified sub-module names of a packed module."""
        parts = module_full_name.split(".")
        module_name = parts[-1]
        replacements = self.packed_modules_mapping.get(module_name, [])
        # When replacements is less than or equal to 1, it indicates that this
        # module is not a packed module.
        if len(replacements) <= 1:
            return
        prefix = ".".join(parts[:-1])
        self.packed_modules[module_full_name] = [
            prefix + "." + r if prefix else r for r in replacements
        ]

    def _create_merged_loras_inplace(self, lora_model: LoRAModel) -> None:
        """Merge per-sub-module LoRAs of each packed module into one entry.

        For every packed module that has at least one sub-LoRA in
        ``lora_model``, the sub-LoRAs (with None for missing ones) are packed
        and stored under the packed module's name. The sub-module entries
        are then removed.
        """
        for module_name, new_module_names in self.packed_modules.items():
            replacement_loras: list[Optional[LoRALayerWeights]] = []
            replaced_module: set[str] = set()
            for r in new_module_names:
                lora = self._get_lora_layer_weights(lora_model, r)
                # Normalize any falsy lookup to an explicit None gap.
                replacement_loras.append(lora if lora else None)
                if lora:
                    replaced_module.add(r)
            if not replaced_module:
                continue

            # HACK Temporary solution for the pool model.
            # NOTE(review): the inner check repeats the outer (negated)
            # condition on module_name, so it can never be true and this
            # branch is dead. replaced_module_name was probably intended —
            # confirm before changing.
            if self.is_pooling_model and not lora_model.check_lora_name(
                    module_name):
                replaced_module_name = module_name.replace("model.", "")
                if lora_model.check_lora_name(module_name):
                    module_name = replaced_module_name
            lora_model.loras[module_name] = PackedLoRALayerWeights.pack(
                replacement_loras)
            # Remove the modules that have been replaced.
            for module in replaced_module:
                lora_model.loras.pop(module, None)

    def _get_lora_layer_weights(
            self, lora_model: LoRAModel,
            module_name: str) -> Optional[LoRALayerWeights]:
        """Look up the weights for ``module_name`` in ``lora_model``.

        For pooling models, if the exact name is missing, the lookup is
        retried with every "model." substring removed.
        """
        org_module_name = module_name
        if self.is_pooling_model and not lora_model.check_lora_name(
                module_name):
            # If it's a pool model, and the layer name is not found,
            # remove the prefix 'model.' and search again.
            # NOTE(review): str.replace removes every occurrence of "model.",
            # not just a leading prefix — confirm this is intended.
            module_name = module_name.replace("model.", "")
            if lora_model.check_lora_name(module_name):
                org_module_name = module_name
                logger.info_once(
                    "For the pool model, successfully loaded the LoRA weights "
                    "after removing the prefix 'model.'.")
        return lora_model.get_lora(org_module_name)

    def deactivate_adapter(self, adapter_id: int) -> bool:
        return deactivate_adapter(adapter_id, self._active_adapters,
                                  self._deactivate_adapter)

    def add_adapter(self, adapter: LoRAModel) -> bool:
        logger.debug(
            "Adding lora. Model id: %d, "
            "int id: %d, "
            "scaling factor: %s", adapter.id, adapter.id,
            adapter.scaling_factor)
        return add_adapter(adapter, self._registered_adapters, self.capacity,
                           self._add_adapter)

    def set_adapter_mapping(self, mapping: LoRAMapping) -> None:
        self._last_mapping = set_adapter_mapping(mapping, self._last_mapping,
                                                 self._set_adapter_mapping)

    def remove_adapter(self, adapter_id: int) -> bool:
        return remove_adapter(adapter_id, self._registered_adapters,
                              self.deactivate_adapter)

    def list_adapters(self) -> dict[int, Any]:
        return list_adapters(self._registered_adapters)

    def get_adapter(self, adapter_id: int) -> Optional[Any]:
        return get_adapter(adapter_id, self._registered_adapters)


class LoRALRUCache(AdapterLRUCache[LoRAModel]):
    """LRU cache holding ``LoRAModel`` objects, keyed by adapter id.

    This is a thin specialisation of ``AdapterLRUCache`` for LoRA models.
    ``deactivate_lora_fn`` is forwarded unchanged to the base class.
    Presumably it is invoked with the id of an entry being evicted; see
    ``vllm.adapter_commons.models`` to confirm.
    """

    def __init__(self, capacity: int, deactivate_lora_fn: Callable[[int],
                                                                   bool]):
        super().__init__(capacity, deactivate_lora_fn)


class LRUCacheLoRAModelManager(LoRAModelManager):
    """A model manager that manages multiple LoRAs with LRU cache.

    After the base constructor runs, both adapter containers are replaced
    with ``LoRALRUCache`` instances:

    * ``_registered_adapters`` is built with ``self.capacity`` and the
      public ``deactivate_adapter`` method as its callback.
    * ``_active_adapters`` is built with ``self.lora_slots`` and the
      private ``_deactivate_adapter`` hook as its callback.
    """

    def __init__(self, model: nn.Module, max_num_seqs: int,
                 max_num_batched_tokens: int, vocab_size: int,
                 lora_config: LoRAConfig, device: torch.device):
        super().__init__(model, max_num_seqs, max_num_batched_tokens,
                         vocab_size, lora_config, device)
        # Swap in LRU caches for the registered (capacity-sized) and the
        # active (lora_slots-sized) adapter containers.
        self._registered_adapters: LoRALRUCache = LoRALRUCache(
            self.capacity, self.deactivate_adapter)
        self._active_adapters: LoRALRUCache = LoRALRUCache(
            self.lora_slots, self._deactivate_adapter)

    def list_adapters(self) -> dict[int, LoRAModel]:
        """List all registered LoRAModels.

        Returns:
            A new dict built from the registered-adapter cache. Changing
            the returned dict does not alter the cache itself.
        """
        return dict(self._registered_adapters.cache)

    def add_adapter(self, lora: LoRAModel) -> bool:
        """Add a LoRAModel to the manager.

        Unlike the base-class version, this method performs no explicit
        capacity check. The registered-adapter LRU cache is presumably
        relied on to make room.

        Returns:
            ``True`` if ``lora`` was newly registered. ``False`` if it was
            already present, in which case only its LRU position is
            refreshed.
        """
        logger.debug(
            "Adding lora. Model id: %d, "
            "int id: %d, "
            "scaling factor: %s", lora.id, lora.id, lora.scaling_factor)
        if lora.id not in self._registered_adapters:
            self._add_adapter(lora)
            was_added = True
        else:
            # We always touch to update the LRU cache order
            self._registered_adapters.touch(lora.id)
            was_added = False
        return was_added

    def activate_adapter(
        self,
        lora_id: int,
    ) -> bool:
        """Activate ``lora_id``, evicting the oldest active LoRA if needed.

        Returns:
            The result of the base-class ``activate_adapter``.
        """
        # Free a slot only when the adapter is not already active and
        # every slot is occupied.
        if lora_id not in self._active_adapters and len(
                self._active_adapters) >= self.lora_slots:
            self._active_adapters.remove_oldest()
        result = super().activate_adapter(lora_id)
        # We always touch to update the LRU cache order
        self._active_adapters.touch(lora_id)
        return result

    def remove_oldest_adapter(self) -> bool:
        """Evict the least-recently-used registered adapter, if any.

        Returns:
            ``True`` if an adapter was evicted, ``False`` if none were
            registered.
        """
        if len(self._registered_adapters) > 0:
            self._registered_adapters.remove_oldest()
            return True
        return False

    def pin_adapter(self, lora_id: int) -> bool:
        """Pin a LoRAModel in the manager cache.

        The adapter is pinned in both the registered (CPU) cache and the
        active (GPU) cache.

        Raises:
            ValueError: If ``lora_id`` is not registered.
        """
        self._pin_lora_in_cpu_cache(lora_id)
        self._pin_lora_in_gpu_cache(lora_id)
        return True

    def _pin_lora_in_cpu_cache(self, lora_id: int):
        """Pin ``lora_id`` in the registered-adapter cache."""
        try:
            self._registered_adapters.pin(lora_id)
        except ValueError as err:
            raise ValueError("Pinning failed. "
                             f"LoRA {lora_id} is not registered.") from err

    def _pin_lora_in_gpu_cache(self, lora_id: int):
        """Activate ``lora_id`` if needed, then pin it in the active cache."""
        if lora_id not in self._active_adapters:
            # move lora to gpu if not already active
            self.activate_adapter(lora_id)

        self._active_adapters.pin(lora_id)


def create_lora_manager(
        model: nn.Module,
        max_num_seqs: int,
        max_num_batched_tokens: int,
        vocab_size: int,
        lora_config: LoRAConfig,
        device: torch.device,
        lora_manager_cls: type[LoRAModelManager] = LoRAModelManager,
        **kwargs) -> LoRAModelManager:
    """Create a LoRA model manager for a given model.

    Args:
        model: Model whose LoRA adapters will be managed. It must define a
            ``packed_modules_mapping`` attribute.
        max_num_seqs: Forwarded to ``lora_manager_cls``.
        max_num_batched_tokens: Forwarded to ``lora_manager_cls``.
        vocab_size: Forwarded to ``lora_manager_cls``.
        lora_config: Forwarded to ``lora_manager_cls``.
        device: Forwarded to ``lora_manager_cls``.
        lora_manager_cls: Manager class to instantiate. Defaults to
            ``LoRAModelManager``; ``LRUCacheLoRAModelManager`` is another
            option defined in this module.
        **kwargs: Extra keyword arguments passed to ``lora_manager_cls``.

    Returns:
        The newly constructed manager instance.

    Raises:
        ValueError: If ``model`` has no ``packed_modules_mapping`` attribute.
    """
    if not hasattr(model, "packed_modules_mapping"):
        raise ValueError(f"Model {type(model)} is not supported for LoRA.")
    lora_manager = lora_manager_cls(
        model=model,
        max_num_seqs=max_num_seqs,
        max_num_batched_tokens=max_num_batched_tokens,
        vocab_size=vocab_size,
        lora_config=lora_config,
        device=device,
        **kwargs)
    return lora_manager