models.py 34.3 KB
Newer Older
1
# SPDX-License-Identifier: Apache-2.0
2
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
3

4
5
import math
import os
6
from collections.abc import Sequence
7
from typing import Callable, Optional, TypeVar, Union
8

9
import regex as re
10
11
12
13
import safetensors.torch
import torch
from torch import nn

14
from vllm.config.lora import LoRAConfig
15
from vllm.logger import init_logger
16
from vllm.lora.layers import BaseLayerWithLoRA, LoRAMapping
17
from vllm.lora.lora_weights import LoRALayerWeights, PackedLoRALayerWeights
18
from vllm.lora.peft_helper import PEFTHelper
19
from vllm.lora.punica_wrapper import get_punica_wrapper
20
from vllm.lora.utils import (from_layer, from_layer_logits_processor,
21
                             get_supported_lora_modules,
22
                             is_regex_target_modules,
23
                             parse_fine_tuned_lora_name, replace_submodule)
24
from vllm.model_executor.layers.fused_moe import FusedMoE
25
from vllm.model_executor.model_loader.tensorizer import TensorizerConfig
26
from vllm.model_executor.models import SupportsLoRA, supports_multimodal
27
from vllm.model_executor.models.interfaces import is_pooling_model
28
from vllm.model_executor.models.module_mapping import MultiModelKeys
29
from vllm.model_executor.models.utils import PPMissingLayer, WeightsMapper
30
from vllm.model_executor.utils import get_packed_modules_mapping
31
from vllm.utils import LRUCache, is_pin_memory_available
32

33
logger = init_logger(__name__)
34

35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
T = TypeVar("T")


class AdapterLRUCache(LRUCache[int, T]):

    def __init__(self, capacity: int, deactivate_fn: Callable[[int], object]):
        super().__init__(capacity)
        self.deactivate_fn = deactivate_fn

    def _on_remove(self, key: int, value: Optional[T]):
        logger.debug("Removing adapter int id: %d", key)
        self.deactivate_fn(key)
        return super()._on_remove(key, value)


50
51
52
53
54
55
56
57
58
_GLOBAL_LORA_ID = 0


def get_lora_id():
    global _GLOBAL_LORA_ID
    _GLOBAL_LORA_ID += 1
    return _GLOBAL_LORA_ID


59
60
61
62
63
64
65
66
67
68
69
def is_moe_model(model: nn.Module) -> bool:
    """Checks if the model contains FusedMoE layers and warns the user."""
    if any(isinstance(module, FusedMoE) for module in model.modules()):
        logger.warning_once(
            "For MoE models, vLLM currently does not support fused MoE LoRA "
            "inference. Please ensure that the loaded LoRA model does not "
            "contain expert weights.")
        return True
    return False


70
class LoRAModel:
71
72
73
74
75
76
    """A LoRA fine-tuned model."""

    def __init__(
        self,
        lora_model_id: int,
        rank: int,
77
        loras: dict[str, LoRALayerWeights],
78
    ) -> None:
79
80
81
82
83
        """
        Args:
            lora_model_id: The integer id for the lora model.
            rank: lora rank.
            loras: module name -> weights for lora-replaced layers.
84

85
        """
86
        self.id = lora_model_id
87

88
89
90
        assert (
            lora_model_id
            > 0), f"a valid lora id should be greater than 0, got {self.id}"
91
        self.rank = rank
92
        self.loras: dict[str, LoRALayerWeights] = loras
93

94
95
96
97
98
99
100
101
102
103
    def clone(self, lora_model_id: int) -> "LoRAModel":
        """Return a copy of the object with different ids.

        Will share the underlying tensors."""
        return self.__class__(
            lora_model_id,
            rank=self.rank,
            loras=self.loras.copy(),
        )

104
105
106
107
108
109
110
111
112
    @property
    def extra_vocab_size(self) -> int:
        return max(lora.extra_vocab_size
                   for lora in self.loras.values()) if self.loras else 0

    def get_lora(self, module_name: str) -> Optional[LoRALayerWeights]:
        """Get LoRA for a given module by name"""
        return self.loras.get(module_name, None)

113
114
115
    def check_lora_name(self, lora_name: str) -> bool:
        return lora_name in self.loras

116
117
118
119
120
    # (yard1): TODO see if we can derive target_embedding_padding automatically
    @classmethod
    def from_lora_tensors(
        cls,
        lora_model_id: int,
121
        tensors: dict[str, torch.Tensor],
122
        peft_helper: PEFTHelper,
123
124
        device: str = "cuda",
        dtype: Optional[torch.dtype] = None,
125
        embeddings: Optional[dict[str, torch.Tensor]] = None,
126
        target_embedding_padding: Optional[int] = None,
127
128
        embedding_modules: Optional[dict[str, str]] = None,
        embedding_padding_modules: Optional[list[str]] = None,
129
        weights_mapper: Optional[WeightsMapper] = None,
130
131
    ) -> "LoRAModel":
        """Create a LoRAModel from a dictionary of tensors."""
132
        pin_memory = str(device) == "cpu" and is_pin_memory_available()
133
        loras: dict[str, LoRALayerWeights] = {}
134
        for tensor_name, tensor in tensors.items():
135
            module_name, is_lora_a, is_bias = parse_fine_tuned_lora_name(
136
                tensor_name, weights_mapper)
137
138
139
            if module_name not in loras:
                lora_embeddings_tensor = None
                if embeddings:
140
                    assert embedding_modules is not None
141
                    embeddings_module = next(
Terry's avatar
Terry committed
142
                        (k for k in embedding_modules if k in module_name),
143
144
145
                        None)
                    if embeddings_module:
                        lora_embeddings_tensor = embeddings[
Terry's avatar
Terry committed
146
                            embedding_modules[embeddings_module]].to(
147
148
149
150
                                device=device, dtype=dtype)
                        if pin_memory:
                            lora_embeddings_tensor = (
                                lora_embeddings_tensor.pin_memory())
151
152
153
                loras[module_name] = LoRALayerWeights.from_config(
                    module_name, peft_helper, lora_embeddings_tensor)

154
155
156
157
158
159
160
161
            if is_bias:
                loras[module_name].bias = tensor.to(device=device,
                                                    dtype=dtype).t()
                bias = tensor.to(device=device, dtype=dtype).t()
                if pin_memory:
                    bias = bias.pin_memory()
                loras[module_name].bias = bias
            elif is_lora_a:
162
163
164
165
166
167
168
169
                loras[module_name].lora_a = tensor.to(device=device,
                                                      dtype=dtype).t()
                if pin_memory:
                    loras[module_name].lora_a = loras[
                        module_name].lora_a.pin_memory()
            else:
                loras[module_name].lora_b = tensor.to(device=device,
                                                      dtype=dtype).t()
170
                assert embedding_padding_modules is not None
171
                if any(name in module_name
Terry's avatar
Terry committed
172
                       for name in embedding_padding_modules
173
174
175
176
177
178
179
180
181
182
183
184
                       ) and target_embedding_padding is not None:
                    lora_b = loras[module_name].lora_b
                    assert target_embedding_padding >= lora_b.shape[1]
                    addition = target_embedding_padding - lora_b.shape[1]
                    loras[module_name].lora_b = torch.nn.functional.pad(
                        lora_b, (0, addition))
                if pin_memory:
                    loras[module_name].lora_b = loras[
                        module_name].lora_b.pin_memory()

        for lora in loras.values():
            lora.optimize()
185

186
        return cls(lora_model_id, peft_helper.r, loras)
187
188
189

    @classmethod
    def from_local_checkpoint(
190
191
192
193
194
195
196
197
198
199
200
201
202
            cls,
            lora_dir: str,
            expected_lora_modules: list[str],
            peft_helper: PEFTHelper,
            *,
            lora_model_id: Optional[int] = None,
            device: str = "cuda",
            dtype: Optional[torch.dtype] = None,
            target_embedding_padding: Optional[int] = None,
            embedding_modules: Optional[dict[str, str]] = None,
            embedding_padding_modules: Optional[list[str]] = None,
            weights_mapper: Optional[WeightsMapper] = None,
            tensorizer_config_dict: Optional[dict] = None) -> "LoRAModel":
203
        """Create a LoRAModel from a local checkpoint.
204

205
206
207
208
        Args:
            lora_dir: The local path that has lora data.
            expected_lora_modules: Name of modules that are expected to be
                replaced by lora.
209
            peft_helper: Loaded lora configuration information.
210
            lora_model_id: LoRA model id. If not given, automatically set by
211
212
213
214
215
216
217
                a global counter.
            device: Device where the lora model is loaded.
            dtype: dtype of the lora model weights.

        Returns:
            Loaded LoRA Model.
        """
218
219
        lora_tensor_path = os.path.join(lora_dir, "adapter_model.safetensors")
        lora_bin_file_path = os.path.join(lora_dir, "adapter_model.bin")
220
        lora_pt_file_path = os.path.join(lora_dir, "adapter_model.pt")
221
222
223
224
        new_embeddings_tensor_path = os.path.join(
            lora_dir, "new_embeddings.safetensors")
        new_embeddings_bin_file_path = os.path.join(lora_dir,
                                                    "new_embeddings.bin")
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
        tensors: dict[str, torch.Tensor] = {}
        unexpected_modules: list[Union[list[str], str]] = []

        def check_unexpected_modules(modules: dict):
            for lora_module in modules.keys():  # noqa
                module_name, _, _ = parse_fine_tuned_lora_name(
                    lora_module, weights_mapper)
                part_name = module_name.split(".")[-1]
                if part_name not in expected_lora_modules:
                    unexpected_modules.append(module_name)
            if unexpected_modules:
                raise ValueError(
                    f"While loading {lora_dir}, expected"
                    f" target modules in {expected_lora_modules}"
                    f" but received {unexpected_modules}."
                    f" Please verify that the loaded LoRA module is correct")

        if tensorizer_config_dict:
            from tensorizer import TensorDeserializer

            tensorizer_config = TensorizerConfig(**tensorizer_config_dict)
            lora_tensor_path = os.path.join(tensorizer_config.tensorizer_dir,
                                            "adapter_model.tensors")
            tensorizer_args = tensorizer_config._construct_tensorizer_args()
249
250
251
252
            tensors = TensorDeserializer(
                lora_tensor_path,
                dtype=tensorizer_config.dtype,
                **tensorizer_args.deserialization_kwargs)
253
            check_unexpected_modules(tensors)
254

255
        elif os.path.isfile(lora_tensor_path):
256
257
258
259
260
261
262
263
264
265
            # Find unexpected modules.
            # Use safetensor key as a source of truth to find expected modules.
            # in peft if you have target_modules A, B, C and C does not exist
            # in the model it won’t error and model will be trained with A, B
            # loraified. C won’t exist in the safetensor but it will exist in
            # the target_modules of the adapter_config.json.
            unexpected_modules = []
            with safetensors.safe_open(lora_tensor_path,
                                       framework="pt") as f:  # type: ignore
                # Load tensors if there are only expected modules.
266
                check_unexpected_modules(f)
267
268
                for module in f.keys():  # noqa
                    tensors[module] = f.get_tensor(module)
269
270
271
272
        elif os.path.isfile(lora_bin_file_path) or os.path.isfile(
                lora_pt_file_path):
            # When a bin/pt file is provided, we rely on config to find
            # unexpected modules.
273
            unexpected_modules = []
274
            target_modules = peft_helper.target_modules
275
276
            if not isinstance(target_modules, list):
                target_modules = [target_modules]
277
278
279
280
281
282
283
284
285
286
            for module in target_modules:
                # Compatible with more modules,
                # such as:layers.11.self_attn.k_proj
                part_name = module.split(".")[-1]
                if part_name not in expected_lora_modules:
                    unexpected_modules.append(module)
            # loaded lora's target modules must be a subset of
            # expected_lora_modules. It is not reliable. See
            # https://github.com/vllm-project/vllm/pull/5909. But there's no
            # other better mechanism.
287
            if unexpected_modules and not is_regex_target_modules(
288
                    peft_helper.target_modules, expected_lora_modules):
289
290
291
292
293
                raise ValueError(
                    f"While loading {lora_dir}, expected"
                    f" target modules in {expected_lora_modules}"
                    f" but received {unexpected_modules}."
                    f" Please verify that the loaded LoRA module is correct")
294
295
296
297
            lora_file_path = (lora_bin_file_path
                              if os.path.isfile(lora_bin_file_path) else
                              lora_pt_file_path)
            tensors = torch.load(lora_file_path,
cyyever's avatar
cyyever committed
298
299
                                 map_location=device,
                                 weights_only=True)
300
301
302
303
304
305
306
307
        else:
            raise ValueError(f"{lora_dir} doesn't contain tensors")

        embeddings = None
        if os.path.isfile(new_embeddings_tensor_path):
            embeddings = safetensors.torch.load_file(
                new_embeddings_tensor_path)
        elif os.path.isfile(new_embeddings_bin_file_path):
308
            embeddings = torch.load(new_embeddings_bin_file_path,
309
310
                                    map_location=device,
                                    weights_only=True)
311
312
313
314
315

        return cls.from_lora_tensors(
            lora_model_id=get_lora_id()
            if lora_model_id is None else lora_model_id,
            tensors=tensors,
316
            peft_helper=peft_helper,
317
318
319
320
            device=device,
            dtype=dtype,
            embeddings=embeddings,
            target_embedding_padding=target_embedding_padding,
Terry's avatar
Terry committed
321
            embedding_modules=embedding_modules,
322
323
            embedding_padding_modules=embedding_padding_modules,
            weights_mapper=weights_mapper)
324
325


326
class LoRAModelManager:
327
328
329
330
    """A manager that manages multiple LoRA-fine-tuned models."""

    def __init__(
        self,
331
        model: SupportsLoRA,
332
333
334
335
        max_num_seqs: int,
        max_num_batched_tokens: int,
        vocab_size: int,
        lora_config: LoRAConfig,
336
        device: torch.device,
337
338
339
340
341
342
343
344
345
346
347
348
    ):
        """Create a LoRAModelManager and adapter for a given model.

        Args:
            model: the model to be adapted.
            max_num_seqs: the maximum number of sequences model can run in a
                single batch.
            max_num_batched_tokens: the maximum number of tokens model can run
                in a single batch.
            vocab_size: the vocab size of the model.
            lora_config: the LoRA configuration.
        """
349
350
351
352
353
        self.model: SupportsLoRA = model
        self._registered_adapters: dict[int, LoRAModel] = {}
        # Dict instead of a set for compatibility with LRUCache.
        self._active_adapters: dict[int, None] = {}
        self.adapter_type = "LoRA"
354
        self.lora_config = lora_config
355
        self.device = device
356
357
358
        self.max_num_seqs = max_num_seqs
        assert self.capacity >= self.lora_slots
        self.max_num_batched_tokens = math.ceil(max_num_batched_tokens / 8) * 8
359
        self.lora_index_to_id: list[Optional[int]] = [None] * self.lora_slots
360
        self.vocab_size = vocab_size
361
362
363
364
        self.punica_wrapper = get_punica_wrapper(
            max_num_batched_tokens,
            max_batches=self.max_num_seqs,
            device=self.device,
365
366
            max_loras=self.lora_config.max_loras,
        )
367

368
369
        self.supported_lora_modules = get_supported_lora_modules(self.model)
        assert self.supported_lora_modules, "No supported LoRA modules found in"
370
        f" {self.model.__class__.__name__}."
371
372

        self.packed_modules_mapping = get_packed_modules_mapping(self.model)
373
        # Used to indicate whether the model is a multimodal model
374
375
376
377
378
        self.supports_mm: bool = (
            supports_multimodal(self.model)
            # In case the model only supports LoRA for
            # text modules (e.g. ChatGLM)
            and hasattr(self.model, "get_mm_mapping"))
379
        self.is_pooling_model = is_pooling_model(self.model)
380
        self.is_moe_model = is_moe_model(self.model)
381
382
383
        self.packed_modules: dict[str, list[str]] = {}
        self.modules: dict[str, BaseLayerWithLoRA] = {}
        # Dict instead of a set for compatibility with LRUCache.
384
        self._last_mapping: Optional[LoRAMapping] = None
385
        self._create_lora_modules()
386
        self.model.lora_manager = self
387
388
389

    def __len__(self) -> int:
        return len(self._registered_adapters)
390
391
392
393
394
395
396
397
398

    @property
    def capacity(self) -> int:
        return self.lora_config.max_cpu_loras

    @property
    def lora_slots(self) -> int:
        return self.lora_config.max_loras

399
400
401
    @property
    def adapter_slots(self) -> int:
        return self.lora_slots
402

403
    def activate_adapter(
404
405
406
407
        self,
        lora_id: int,
    ) -> bool:
        """Move LoRA into a GPU buffer to be used in the forward pass."""
408
        if lora_id in self._active_adapters:
409
410
411
412
413
414
415
            return False
        first_free_slot = next(
            ((i, lora_id) for i, lora_id in enumerate(self.lora_index_to_id)
             if lora_id is None), None)
        if first_free_slot is None:
            raise ValueError("No free lora slots")
        index, _ = first_free_slot
416
417
        self._active_adapters[lora_id] = None
        lora_model = self._registered_adapters[lora_id]
418
419
        logger.debug("Activating LoRA. int id: %d, slot index: %d",
                     lora_model.id, index)
420
421
        self.lora_index_to_id[index] = lora_model.id
        for module_name, module in self.modules.items():
422
            module_lora = self._get_lora_layer_weights(lora_model, module_name)
423
424
            if module_lora:
                module_lora.optimize()
425
426
427
428
429
430
431
432
433
434
                # Bias is not explicitly enabled with the flag enable_lora_bias.
                bias = module_lora.bias
                if ((torch.is_tensor(bias) or
                     (isinstance(bias, Sequence) and any(b is not None
                                                         for b in bias)))
                        and not self.lora_config.bias_enabled):
                    module_lora.bias = None
                    raise ValueError(
                        f"Adapter bias cannot be used for {module_name}"
                        " without --enable-lora-bias.")
435
                module.set_lora(index, module_lora.lora_a, module_lora.lora_b,
436
437
                                module_lora.embeddings_tensor,
                                module_lora.bias)
438
439
440
441
            else:
                module.reset_lora(index)
        return True

442
    def _deactivate_adapter(self, lora_id: int):
443
444
445
446
447
448
        try:
            index = self.lora_index_to_id.index(lora_id)
            self.lora_index_to_id[index] = None
        except ValueError:
            pass

449
    def _add_adapter(self, lora: LoRAModel):
450
        self._create_merged_loras_inplace(lora)
451
        self._registered_adapters[lora.id] = lora
452

453
    def pin_adapter(self, lora_id: int) -> bool:
454
455
        """Pin a LoRAModel in the manager cache."""
        raise NotImplementedError(
456
            "Pinning is not supported in LoRAModelManager. "
457
458
            "Use LRUCacheLoRAModelManager for pinning")  # type: ignore

459
    def _set_adapter_mapping(self, mapping: LoRAMapping) -> None:
460
461
462
463
464
465
466
467
        # update lora states
        self.punica_wrapper.update_metadata(
            mapping,
            self.lora_index_to_id,
            self.lora_slots + 1,
            self.vocab_size,
            self.lora_config.lora_extra_vocab_size,
        )
468

469
    def remove_all_adapters(self):
470
        """Remove all LoRAModels from the manager."""
471
        self._registered_adapters.clear()
472
        self.lora_index_to_id = [None] * self.lora_slots
473
        self._active_adapters.clear()
474
475

    def _create_lora_modules(self):
476
477
478
479
480
481
482
483

        def _parent_module(module_name: str) -> str:
            # module name is a dot separated name.
            # for example:
            #  - given an input 'x.y.z' return 'x.y'
            #  - given an input 'x' return ''
            return module_name.rpartition('.')[0]

484
485
        for module_name, module in self.model.named_modules(
                remove_duplicate=False):
486
487
            if isinstance(module, PPMissingLayer):
                continue
488
489
            if not self._match_target_modules(module_name):
                continue
490
491
492
493
494
495
496
497
498
            # A temporary approach for multimodal models to support LoRA
            # TODO: Remove this restriction
            if self._filter_unsupported_mm_module(module_name):
                logger.warning(
                    "Regarding multimodal models, vLLM currently only supports "
                    "adding LoRA to language model, %s will be ignored.",
                    module_name,
                )
                continue
499
500
            parts = module_name.split(".")[-1]
            packed_moduled_lst = self.packed_modules_mapping.get(parts, [])
501
502
503
            new_module = replace_submodule(
                self.model, module_name,
                from_layer(module, self.lora_slots, self.lora_config,
504
                           packed_moduled_lst, self.model.config))
505

506
507
            # (yard1): TODO make this more robust
            if "lm_head" in module_name:
508
509
510
511
512
513
                logits_processor_module_name = 'logits_processor'
                parent_module = _parent_module(module_name)
                if parent_module:
                    logits_processor_module_name = (
                        f"{parent_module}.{logits_processor_module_name}")

514
                logits_processor_module = self.model.get_submodule(
515
516
                    logits_processor_module_name)

517
                new_module = replace_submodule(
518
                    self.model, logits_processor_module_name,
519
520
521
522
                    from_layer_logits_processor(logits_processor_module,
                                                module, self.lora_slots,
                                                self.lora_config,
                                                self.model.config))
523
524
525
526
527
528
529
530
531

            # In some models, especially multimodal ones, layers with the same
            # name may have different types, such as nn.Linear and
            # ReplicatedLinear. The nn.Linear layers cannot be replaced with
            # LoRA layers, leading to assertion error. The following check
            # aims to prevent this error
            if self.supports_mm and not isinstance(new_module,
                                                   BaseLayerWithLoRA):
                continue
532
533
            self.register_module(module_name, new_module)
            self._register_packed_modules(module_name)
534
535
            # All lora layers share the same punica_wrapper based on reference.
            new_module.set_mapping(self.punica_wrapper)
536
537
538
539
540

    def register_module(self, module_name: str, module: "BaseLayerWithLoRA"):
        assert isinstance(module, BaseLayerWithLoRA)
        self.modules[module_name] = module

Terry's avatar
Terry committed
541
542
543
544
    def create_dummy_lora(
            self,
            lora_id: int,
            rank: int,
545
            embedding_modules: Optional[dict[str, str]] = None) -> LoRAModel:
546
        """Create zero-initialized LoRAModel for warmup."""
547
        model = LoRAModel(lora_id, rank, {})
548
        for module_name, module in self.model.named_modules():
549
            bias_enabled = self.lora_config.bias_enabled
550
551
552
            if (not self._match_target_modules(module_name)
                    or not isinstance(module, BaseLayerWithLoRA)
                    or self._filter_unsupported_mm_module(module_name)):
553
554
555
                continue
            parts = module_name.split(".")
            if module_name not in self.packed_modules:
556
                assert embedding_modules is not None
Terry's avatar
Terry committed
557
                if parts[-1] in embedding_modules:
558
559
560
561
562
563
564
565
566
567
568
569
570
571
572
573
                    input_dim = (module.base_layer.org_vocab_size +
                                 self.lora_config.lora_extra_vocab_size if
                                 hasattr(module.base_layer, "org_vocab_size")
                                 else module.base_layer.weight.shape[1])
                    output_dim = module.base_layer.embedding_dim if hasattr(
                        module.base_layer,
                        "embedding_dim") else module.base_layer.weight.shape[0]
                    embeddings_tensor_dim = (module.base_layer.embedding_dim if
                                             hasattr(module.base_layer,
                                                     "embedding_dim") else
                                             module.base_layer.weight.shape[1])
                    lora = LoRALayerWeights.create_dummy_lora_weights(
                        module_name,
                        input_dim,
                        output_dim,
                        rank,
574
                        module.lora_a_stacked[0].dtype,
575
                        "cpu",
576
577
                        embeddings_tensor_dim=embeddings_tensor_dim,
                        bias_enabled=bias_enabled)
578
579
580
                else:
                    lora = LoRALayerWeights.create_dummy_lora_weights(
                        module_name,
581
582
                        module.lora_a_stacked[0].shape[-1],
                        module.lora_b_stacked[0].shape[-2],
583
                        rank,
584
                        module.lora_a_stacked[0].dtype,
585
                        "cpu",
586
                        bias_enabled=bias_enabled,
587
588
589
590
591
                    )
                lora.optimize()
            else:
                parts = module_name.split(".")
                replacements = self.packed_modules_mapping[parts[-1]]
592
                subloras: list[Optional[LoRALayerWeights]] = []
593
594
595
596
597
598
599
600
                for i, r in enumerate(replacements):
                    lora = LoRALayerWeights.create_dummy_lora_weights(
                        module_name + "." + r,
                        module.lora_a_stacked[i].shape[-1],
                        module.lora_b_stacked[i].shape[-2],
                        rank,
                        module.lora_a_stacked[i].dtype,
                        "cpu",
601
                        bias_enabled=bias_enabled,
602
603
604
605
606
607
608
609
610
611
612
613
                    )
                    lora.optimize()
                    subloras.append(lora)
                lora = PackedLoRALayerWeights.pack(subloras)
            model.loras[module_name] = lora
        return model

    def _match_target_modules(self, module_name: str):
        return any(
            re.match(
                r".*\.{target_module}$".format(target_module=target_module),
                module_name) or target_module == module_name
Terry's avatar
Terry committed
614
            for target_module in self.supported_lora_modules)
615

616
617
618
    def _filter_unsupported_mm_module(self, module_name: str) -> bool:
        """
        Regarding multimodal models, vLLM currently only supports adding LoRA to
619
        language model. LoRA for other modules, such as the vision tower, will
620
621
622
623
        be filtered out.
        """
        if self.supports_mm:
            module_mapping: MultiModelKeys = self.model.get_mm_mapping()
624
625
626
            prefix_lst = module_mapping.connector + module_mapping.tower_model
            return any(
                [module_name.startswith(prefix) for prefix in prefix_lst])
627
628
        return False

629
630
631
    def _register_packed_modules(self, module_full_name: str) -> None:
        parts = module_full_name.split(".")
        module_name = parts[-1]
632
633
634
635
        replacements = self.packed_modules_mapping.get(module_name, [])
        # When replacements is less than or equal to 1, it indicates that this
        # module is not a packed module.
        if len(replacements) <= 1:
636
637
638
639
640
641
642
643
            return
        prefix = ".".join(parts[:-1])
        self.packed_modules[module_full_name] = [
            prefix + "." + r if prefix else r for r in replacements
        ]

    def _create_merged_loras_inplace(self, lora_model: LoRAModel) -> None:
        for module_name, new_module_names in self.packed_modules.items():
644
645
            replacement_loras: list[Optional[LoRALayerWeights]] = []
            replaced_module: set[str] = set()
646
647
            has_replacement = False
            for r in new_module_names:
648
                lora = self._get_lora_layer_weights(lora_model, r)
649
650
651
                replacement_loras.append(lora)
                if lora:
                    has_replacement = True
652
                    replaced_module.add(r)
653
654
655
656
657
658
            if not has_replacement:
                continue
            for i in range(len(replacement_loras)):
                if replacement_loras[i]:
                    continue
                replacement_loras[i] = None
659
660
661
662
663
664
            # HACK Temporary solution for the pool model.
            if self.is_pooling_model and not lora_model.check_lora_name(
                    module_name):
                replaced_module_name = module_name.replace("model.", "")
                if lora_model.check_lora_name(module_name):
                    module_name = replaced_module_name
665
666
            lora_model.loras[module_name] = PackedLoRALayerWeights.pack(
                replacement_loras)
667
668
669
            # Remove the modules that have been replaced.
            for module in replaced_module:
                lora_model.loras.pop(module, None)
670

671
672
673
674
675
676
677
678
679
680
681
682
683
684
685
686
    def _get_lora_layer_weights(
            self, lora_model: LoRAModel,
            module_name: str) -> Optional[LoRALayerWeights]:
        org_module_name = module_name
        if self.is_pooling_model and not lora_model.check_lora_name(
                module_name):
            # If it's a pool model, and the layer name is not found,
            # remove the prefix 'model.' and search again.
            module_name = module_name.replace("model.", "")
            if lora_model.check_lora_name(module_name):
                org_module_name = module_name
                logger.info_once(
                    "For the pool model, successfully loaded the LoRA weights "
                    "after removing the prefix 'model.'.")
        return lora_model.get_lora(org_module_name)

687
    def deactivate_adapter(self, adapter_id: int) -> bool:
688
689
690
691
692
        if adapter_id not in self._active_adapters:
            return False
        self._deactivate_adapter(adapter_id)
        self._active_adapters.pop(adapter_id, None)
        return True
693
694

    def add_adapter(self, adapter: LoRAModel) -> bool:
695
696
        logger.debug("Adding lora. Model id: %d, "
                     "int id: %d", adapter.id, adapter.id)
697
698
699
700
701
702
        if adapter.id in self._registered_adapters:
            return False
        if len(self._registered_adapters) >= self.capacity:
            raise RuntimeError("No free adapter slots.")
        self._add_adapter(adapter)
        return True
703

704
    def set_adapter_mapping(self, mapping: LoRAMapping) -> None:
705
706
707
        if self._last_mapping != mapping:
            self._set_adapter_mapping(mapping)
            self._last_mapping = mapping
708
709

    def remove_adapter(self, adapter_id: int) -> bool:
710
711
712
713
714
        self.deactivate_adapter(adapter_id)
        if adapter_id not in self._registered_adapters:
            return False
        self._registered_adapters.pop(adapter_id, None)
        return True
715

716
717
    def list_adapters(self) -> dict[int, LoRAModel]:
        return dict(self._registered_adapters)
718

719
720
    def get_adapter(self, adapter_id: int) -> Optional[LoRAModel]:
        return self._registered_adapters.get(adapter_id)
721
722
723


class LoRALRUCache(AdapterLRUCache[LoRAModel]):
724

725
726
    def __init__(self, capacity: int, deactivate_lora_fn: Callable[[int],
                                                                   bool]):
727
        super().__init__(capacity, deactivate_lora_fn)
728
729
730
731
732


class LRUCacheLoRAModelManager(LoRAModelManager):
    """A model manager that manages multiple LoRAs with LRU cache."""

733
734
735
    def __init__(self, model: nn.Module, max_num_seqs: int,
                 max_num_batched_tokens: int, vocab_size: int,
                 lora_config: LoRAConfig, device: torch.device):
736
        super().__init__(model, max_num_seqs, max_num_batched_tokens,
737
                         vocab_size, lora_config, device)
738
739
740
741
        self._registered_adapters: LoRALRUCache = LoRALRUCache(
            self.capacity, self.deactivate_adapter)
        self._active_adapters: LoRALRUCache = LoRALRUCache(
            self.lora_slots, self._deactivate_adapter)
742

743
    def list_adapters(self) -> dict[int, LoRAModel]:
744
        """List all registered LoRAModels."""
745
        return dict(self._registered_adapters.cache)
746

747
    def add_adapter(self, lora: LoRAModel) -> bool:
748
        """Add a LoRAModel to the manager."""
749
750
        logger.debug("Adding lora. Model id: %d, "
                     "int id: %d", lora.id, lora.id)
751
752
        if lora.id not in self._registered_adapters:
            self._add_adapter(lora)
753
754
755
            was_added = True
        else:
            # We always touch to update the LRU cache order
756
            self._registered_adapters.touch(lora.id)
757
758
759
            was_added = False
        return was_added

760
    def activate_adapter(
761
762
763
        self,
        lora_id: int,
    ) -> bool:
764
765
766
767
        if lora_id not in self._active_adapters and len(
                self._active_adapters) >= self.lora_slots:
            self._active_adapters.remove_oldest()
        result = super().activate_adapter(lora_id)
768
        # We always touch to update the LRU cache order
769
        self._active_adapters.touch(lora_id)
770
771
        return result

772
773
774
    def remove_oldest_adapter(self) -> bool:
        if len(self._registered_adapters) > 0:
            self._registered_adapters.remove_oldest()
775
776
777
            return True
        return False

778
    def pin_adapter(self, lora_id: int) -> bool:
779
780
781
782
783
784
785
        """Pin a LoRAModel in the manager cache."""
        self._pin_lora_in_cpu_cache(lora_id)
        self._pin_lora_in_gpu_cache(lora_id)
        return True

    def _pin_lora_in_cpu_cache(self, lora_id: int):
        try:
786
            self._registered_adapters.pin(lora_id)
787
788
789
790
791
        except ValueError as err:
            raise ValueError("Pinning failed. "
                             f"LoRA {lora_id} is not registered.") from err

    def _pin_lora_in_gpu_cache(self, lora_id: int):
792
        if lora_id not in self._active_adapters:
793
            # move lora to gpu if not already active
794
            self.activate_adapter(lora_id)
795

796
        self._active_adapters.pin(lora_id)
797

798
799
800
801
802
803
804

def create_lora_manager(
        model: nn.Module,
        max_num_seqs: int,
        max_num_batched_tokens: int,
        vocab_size: int,
        lora_config: LoRAConfig,
805
        device: torch.device,
806
        lora_manager_cls: type[LoRAModelManager] = LoRAModelManager,
807
808
        **kwargs) -> LoRAModelManager:
    """Create a LoRA adapter for a given model."""
809
    if not isinstance(model, SupportsLoRA):
810
811
812
813
814
815
816
        raise ValueError(f"Model {type(model)} is not supported for LoRA.")
    lora_manager = lora_manager_cls(
        model=model,
        max_num_seqs=max_num_seqs,
        max_num_batched_tokens=max_num_batched_tokens,
        vocab_size=vocab_size,
        lora_config=lora_config,
817
        device=device,
818
819
        **kwargs)
    return lora_manager