models.py 34.2 KB
Newer Older
1
# SPDX-License-Identifier: Apache-2.0
2
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
3

4
5
import math
import os
6
from collections.abc import Sequence
7
from typing import Callable, Optional, TypeVar, Union
8

9
import regex as re
10
11
12
13
import safetensors.torch
import torch
from torch import nn

14
from vllm.config.lora import LoRAConfig
15
from vllm.logger import init_logger
16
from vllm.lora.layers import BaseLayerWithLoRA, LoRAMapping
17
from vllm.lora.lora_weights import LoRALayerWeights, PackedLoRALayerWeights
18
from vllm.lora.peft_helper import PEFTHelper
19
from vllm.lora.punica_wrapper import get_punica_wrapper
20
from vllm.lora.utils import (from_layer, from_layer_logits_processor,
21
                             get_supported_lora_modules,
22
                             is_regex_target_modules,
23
                             parse_fine_tuned_lora_name, replace_submodule)
24
from vllm.model_executor.layers.fused_moe import FusedMoE
25
from vllm.model_executor.model_loader.tensorizer import TensorizerConfig
26
from vllm.model_executor.models import SupportsLoRA, supports_multimodal
27
from vllm.model_executor.models.interfaces import is_pooling_model
28
from vllm.model_executor.models.module_mapping import MultiModelKeys
29
from vllm.model_executor.models.utils import PPMissingLayer, WeightsMapper
30
from vllm.model_executor.utils import get_packed_modules_mapping
31
from vllm.utils import LRUCache, is_pin_memory_available
32

33
logger = init_logger(__name__)
34

35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
T = TypeVar("T")


class AdapterLRUCache(LRUCache[int, T]):
    """LRU cache keyed by adapter int id that deactivates removed entries.

    Whenever the base ``LRUCache`` reports a removal through ``_on_remove``,
    the adapter id is first passed to ``deactivate_fn`` and then the base
    class hook runs. (Which removal paths invoke ``_on_remove`` is defined by
    ``vllm.utils.LRUCache`` — presumably eviction and explicit pops; verify
    there.)
    """

    def __init__(self, capacity: int, deactivate_fn: Callable[[int], object]):
        super().__init__(capacity)
        # Callback invoked with the adapter id whenever an entry is removed.
        self.deactivate_fn = deactivate_fn

    def _on_remove(self, key: int, value: Optional[T]):
        logger.debug("Removing adapter int id: %d", key)
        on_deactivate = self.deactivate_fn
        on_deactivate(key)
        return super()._on_remove(key, value)


50
51
52
53
54
55
56
57
58
_GLOBAL_LORA_ID = 0


def get_lora_id():
    """Return the next LoRA int id from a module-level counter.

    Ids start at 1 and increase by one per call. The counter update is not
    guarded by any lock.
    """
    global _GLOBAL_LORA_ID
    next_id = _GLOBAL_LORA_ID + 1
    _GLOBAL_LORA_ID = next_id
    return next_id


59
60
61
62
63
64
65
66
67
68
69
def is_moe_model(model: nn.Module) -> bool:
    """Return True if ``model`` contains at least one ``FusedMoE`` submodule.

    When it does, a one-time warning is logged telling the user that fused
    MoE LoRA inference is unsupported, so the adapter must not carry expert
    weights.
    """
    contains_fused_moe = any(
        isinstance(submodule, FusedMoE) for submodule in model.modules())
    if not contains_fused_moe:
        return False
    logger.warning_once(
        "For MoE models, vLLM currently does not support fused MoE LoRA "
        "inference. Please ensure that the loaded LoRA model does not "
        "contain expert weights.")
    return True


70
class LoRAModel:
    """A LoRA fine-tuned model.

    Holds one adapter's integer id, its rank and a mapping from module name
    to the per-layer LoRA weights.
    """

    def __init__(
        self,
        lora_model_id: int,
        rank: int,
        loras: dict[str, LoRALayerWeights],
    ) -> None:
        """
        Args:
            lora_model_id: The integer id for the lora model. Must be > 0.
            rank: lora rank.
            loras: module name -> weights for lora-replaced layers.
        """
        self.id = lora_model_id

        assert (
            lora_model_id
            > 0), f"a valid lora id should be greater than 0, got {self.id}"
        self.rank = rank
        self.loras: dict[str, LoRALayerWeights] = loras

    def clone(self, lora_model_id: int) -> "LoRAModel":
        """Return a copy of the object with different ids.

        Will share the underlying tensors: only the ``loras`` dict itself is
        copied (shallowly); the weight objects inside it are shared."""
        return self.__class__(
            lora_model_id,
            rank=self.rank,
            loras=self.loras.copy(),
        )

    @property
    def extra_vocab_size(self) -> int:
        """Largest ``extra_vocab_size`` over all layers, or 0 if none."""
        return max(lora.extra_vocab_size
                   for lora in self.loras.values()) if self.loras else 0

    def get_lora(self, module_name: str) -> Optional[LoRALayerWeights]:
        """Get LoRA for a given module by name, or None if it is absent."""
        return self.loras.get(module_name, None)

    def check_lora_name(self, lora_name: str) -> bool:
        """Return True if weights are stored under ``lora_name``."""
        return lora_name in self.loras

    # (yard1): TODO see if we can derive target_embedding_padding automatically
    @classmethod
    def from_lora_tensors(
        cls,
        lora_model_id: int,
        tensors: dict[str, torch.Tensor],
        peft_helper: PEFTHelper,
        device: str = "cuda",
        dtype: Optional[torch.dtype] = None,
        embeddings: Optional[dict[str, torch.Tensor]] = None,
        target_embedding_padding: Optional[int] = None,
        embedding_modules: Optional[dict[str, str]] = None,
        embedding_padding_modules: Optional[list[str]] = None,
        weights_mapper: Optional[WeightsMapper] = None,
    ) -> "LoRAModel":
        """Create a LoRAModel from a dictionary of tensors.

        Args:
            lora_model_id: Integer id of the resulting model.
            tensors: Checkpoint tensor name -> tensor. Each name is parsed by
                ``parse_fine_tuned_lora_name`` into (module name, is_lora_a,
                is_bias); tensors that are neither A nor bias are treated as B.
            peft_helper: Adapter config. It is used to build each module's
                weights, and ``peft_helper.r`` becomes the model rank.
            device: Every loaded tensor is moved here. When this is "cpu" and
                pinned memory is available, the tensors are also pinned.
            dtype: Every loaded tensor is cast to this dtype.
            embeddings: Optional extra embedding tensors. A module gets the
                entry for the first ``embedding_modules`` key that is a
                substring of its name.
            target_embedding_padding: If set, the ``lora_b`` of modules
                matching ``embedding_padding_modules`` is zero-padded at the
                end of its first dimension up to this size (assumes a 2-D
                ``lora_b`` — TODO confirm).
            embedding_modules: Module-name fragment -> key into
                ``embeddings``. Required when ``embeddings`` is given.
            embedding_padding_modules: Module-name fragments whose
                ``lora_b`` may be padded. Required whenever a B tensor is
                present.
            weights_mapper: Passed to ``parse_fine_tuned_lora_name``.

        Returns:
            A new LoRAModel. ``optimize()`` has been called on every layer.
        """
        pin_memory = str(device) == "cpu" and is_pin_memory_available()
        loras: dict[str, LoRALayerWeights] = {}
        for tensor_name, tensor in tensors.items():
            module_name, is_lora_a, is_bias = parse_fine_tuned_lora_name(
                tensor_name, weights_mapper)
            if module_name not in loras:
                lora_embeddings_tensor = None
                if embeddings:
                    assert embedding_modules is not None
                    embeddings_module = next(
                        (k for k in embedding_modules if k in module_name),
                        None)
                    if embeddings_module:
                        lora_embeddings_tensor = embeddings[
                            embedding_modules[embeddings_module]].to(
                                device=device, dtype=dtype)
                        if pin_memory:
                            lora_embeddings_tensor = (
                                lora_embeddings_tensor.pin_memory())
                loras[module_name] = LoRALayerWeights.from_config(
                    module_name, peft_helper, lora_embeddings_tensor)

            if is_bias:
                # Convert exactly once; a second `.to()` would allocate a
                # redundant device copy that is immediately discarded.
                bias = tensor.to(device=device, dtype=dtype)
                if pin_memory:
                    bias = bias.pin_memory()
                loras[module_name].bias = bias
            elif is_lora_a:
                lora_a = tensor.to(device=device, dtype=dtype)
                if pin_memory:
                    lora_a = lora_a.pin_memory()
                loras[module_name].lora_a = lora_a
            else:
                lora_b = tensor.to(device=device, dtype=dtype)
                assert embedding_padding_modules is not None
                if any(name in module_name
                       for name in embedding_padding_modules
                       ) and target_embedding_padding is not None:
                    assert target_embedding_padding >= lora_b.shape[0]
                    addition = target_embedding_padding - lora_b.shape[0]
                    lora_b = torch.nn.functional.pad(lora_b,
                                                     (0, 0, 0, addition))
                # Pin after padding: `pad` returns a new, unpinned tensor.
                if pin_memory:
                    lora_b = lora_b.pin_memory()
                loras[module_name].lora_b = lora_b

        for lora in loras.values():
            lora.optimize()

        return cls(lora_model_id, peft_helper.r, loras)

    @classmethod
    def from_local_checkpoint(
            cls,
            lora_dir: str,
            expected_lora_modules: list[str],
            peft_helper: PEFTHelper,
            *,
            lora_model_id: Optional[int] = None,
            device: str = "cuda",
            dtype: Optional[torch.dtype] = None,
            target_embedding_padding: Optional[int] = None,
            embedding_modules: Optional[dict[str, str]] = None,
            embedding_padding_modules: Optional[list[str]] = None,
            weights_mapper: Optional[WeightsMapper] = None,
            tensorizer_config_dict: Optional[dict] = None) -> "LoRAModel":
        """Create a LoRAModel from a local checkpoint.

        The adapter tensors are read from the first source that applies:
        tensorizer (when ``tensorizer_config_dict`` is given), then
        ``adapter_model.safetensors``, then ``adapter_model.bin`` /
        ``adapter_model.pt``. Optional ``new_embeddings.safetensors`` /
        ``new_embeddings.bin`` files in ``lora_dir`` are loaded as extra
        embeddings.

        Args:
            lora_dir: The local path that has lora data.
            expected_lora_modules: Name of modules that are expected to be
                replaced by lora.
            peft_helper: Loaded lora configuration information.
            lora_model_id: LoRA model id. If not given, automatically set by
                a global counter.
            device: Device where the lora model is loaded.
            dtype: dtype of the lora model weights.
            target_embedding_padding: Forwarded to ``from_lora_tensors``.
            embedding_modules: Forwarded to ``from_lora_tensors``.
            embedding_padding_modules: Forwarded to ``from_lora_tensors``.
            weights_mapper: Used when parsing tensor names, here and in
                ``from_lora_tensors``.
            tensorizer_config_dict: If given, the tensors are deserialized
                from ``<tensorizer_dir>/adapter_model.tensors`` with
                tensorizer instead of being read from ``lora_dir``.

        Returns:
            Loaded LoRA Model.

        Raises:
            ValueError: If the checkpoint targets unexpected modules or
                ``lora_dir`` contains no adapter tensors.
        """
        lora_tensor_path = os.path.join(lora_dir, "adapter_model.safetensors")
        lora_bin_file_path = os.path.join(lora_dir, "adapter_model.bin")
        lora_pt_file_path = os.path.join(lora_dir, "adapter_model.pt")
        new_embeddings_tensor_path = os.path.join(
            lora_dir, "new_embeddings.safetensors")
        new_embeddings_bin_file_path = os.path.join(lora_dir,
                                                    "new_embeddings.bin")
        tensors: dict[str, torch.Tensor] = {}
        # Shared by every branch below; only one branch ever fills it.
        unexpected_modules: list[Union[list[str], str]] = []

        def check_unexpected_modules(modules: dict):
            """Raise if any tensor key maps to a module not expected."""
            for lora_module in modules.keys():  # noqa
                module_name, _, _ = parse_fine_tuned_lora_name(
                    lora_module, weights_mapper)
                part_name = module_name.split(".")[-1]
                if part_name not in expected_lora_modules:
                    unexpected_modules.append(module_name)
            if unexpected_modules:
                raise ValueError(
                    f"While loading {lora_dir}, expected"
                    f" target modules in {expected_lora_modules}"
                    f" but received {unexpected_modules}."
                    f" Please verify that the loaded LoRA module is correct")

        if tensorizer_config_dict:
            from tensorizer import TensorDeserializer

            tensorizer_config = TensorizerConfig(**tensorizer_config_dict)
            lora_tensor_path = os.path.join(tensorizer_config.tensorizer_dir,
                                            "adapter_model.tensors")
            tensorizer_args = tensorizer_config._construct_tensorizer_args()
            tensors = TensorDeserializer(
                lora_tensor_path,
                dtype=tensorizer_config.dtype,
                **tensorizer_args.deserialization_kwargs)
            check_unexpected_modules(tensors)

        elif os.path.isfile(lora_tensor_path):
            # Find unexpected modules.
            # Use safetensor key as a source of truth to find expected modules.
            # in peft if you have target_modules A, B, C and C does not exist
            # in the model it won't error and model will be trained with A, B
            # loraified. C won't exist in the safetensor but it will exist in
            # the target_modules of the adapter_config.json.
            with safetensors.safe_open(lora_tensor_path,
                                       framework="pt") as f:  # type: ignore
                # Load tensors if there are only expected modules.
                check_unexpected_modules(f)
                for module in f.keys():  # noqa
                    tensors[module] = f.get_tensor(module)
        elif os.path.isfile(lora_bin_file_path) or os.path.isfile(
                lora_pt_file_path):
            # When a bin/pt file is provided, we rely on config to find
            # unexpected modules.
            target_modules = peft_helper.target_modules
            if not isinstance(target_modules, list):
                target_modules = [target_modules]
            for module in target_modules:
                # Compatible with more modules,
                # such as:layers.11.self_attn.k_proj
                part_name = module.split(".")[-1]
                if part_name not in expected_lora_modules:
                    unexpected_modules.append(module)
            # loaded lora's target modules must be a subset of
            # expected_lora_modules. It is not reliable. See
            # https://github.com/vllm-project/vllm/pull/5909. But there's no
            # other better mechanism.
            if unexpected_modules and not is_regex_target_modules(
                    peft_helper.target_modules, expected_lora_modules):
                raise ValueError(
                    f"While loading {lora_dir}, expected"
                    f" target modules in {expected_lora_modules}"
                    f" but received {unexpected_modules}."
                    f" Please verify that the loaded LoRA module is correct")
            lora_file_path = (lora_bin_file_path
                              if os.path.isfile(lora_bin_file_path) else
                              lora_pt_file_path)
            tensors = torch.load(lora_file_path,
                                 map_location=device,
                                 weights_only=True)
        else:
            raise ValueError(f"{lora_dir} doesn't contain tensors")

        embeddings = None
        if os.path.isfile(new_embeddings_tensor_path):
            embeddings = safetensors.torch.load_file(
                new_embeddings_tensor_path)
        elif os.path.isfile(new_embeddings_bin_file_path):
            embeddings = torch.load(new_embeddings_bin_file_path,
                                    map_location=device,
                                    weights_only=True)

        return cls.from_lora_tensors(
            lora_model_id=get_lora_id()
            if lora_model_id is None else lora_model_id,
            tensors=tensors,
            peft_helper=peft_helper,
            device=device,
            dtype=dtype,
            embeddings=embeddings,
            target_embedding_padding=target_embedding_padding,
            embedding_modules=embedding_modules,
            embedding_padding_modules=embedding_padding_modules,
            weights_mapper=weights_mapper)
323
324


325
class LoRAModelManager:
    """A manager that manages multiple LoRA-fine-tuned models.

    On construction, the supported target layers of ``model`` are replaced
    with LoRA-capable layers. The manager keeps a registry of
    :class:`LoRAModel` objects (up to ``capacity``) and assigns active
    adapters to ``lora_slots`` slots tracked in ``lora_index_to_id``.
    """

    def __init__(
        self,
        model: SupportsLoRA,
        max_num_seqs: int,
        max_num_batched_tokens: int,
        vocab_size: int,
        lora_config: LoRAConfig,
        device: torch.device,
    ):
        """Create a LoRAModelManager and adapter for a given model.

        Args:
            model: the model to be adapted.
            max_num_seqs: the maximum number of sequences model can run in a
                single batch.
            max_num_batched_tokens: the maximum number of tokens model can run
                in a single batch.
            vocab_size: the vocab size of the model.
            lora_config: the LoRA configuration.
            device: device passed to the punica wrapper.
        """
        self.model: SupportsLoRA = model
        self._registered_adapters: dict[int, LoRAModel] = {}
        # Dict instead of a set for compatibility with LRUCache.
        self._active_adapters: dict[int, None] = {}
        self.adapter_type = "LoRA"
        self.lora_config = lora_config
        self.device = device
        self.max_num_seqs = max_num_seqs
        assert self.capacity >= self.lora_slots
        # Stored value is rounded up to a multiple of 8; the punica wrapper
        # below receives the caller's original value.
        self.max_num_batched_tokens = math.ceil(max_num_batched_tokens / 8) * 8
        self.lora_index_to_id: list[Optional[int]] = [None] * self.lora_slots
        self.vocab_size = vocab_size
        self.punica_wrapper = get_punica_wrapper(
            max_num_batched_tokens,
            max_batches=self.max_num_seqs,
            device=self.device,
            max_loras=self.lora_config.max_loras,
        )

        self.supported_lora_modules = get_supported_lora_modules(self.model)
        # The message must be one parenthesized expression; otherwise the
        # f-string becomes a separate no-op statement and is never shown.
        assert self.supported_lora_modules, (
            "No supported LoRA modules found in"
            f" {self.model.__class__.__name__}.")

        self.packed_modules_mapping = get_packed_modules_mapping(self.model)
        # Used to indicate whether the model is a multimodal model
        self.supports_mm: bool = (
            supports_multimodal(self.model)
            # In case the model only supports LoRA for
            # text modules (e.g. ChatGLM)
            and hasattr(self.model, "get_mm_mapping"))
        self.is_pooling_model = is_pooling_model(self.model)
        self.is_moe_model = is_moe_model(self.model)
        self.packed_modules: dict[str, list[str]] = {}
        self.modules: dict[str, BaseLayerWithLoRA] = {}
        self._last_mapping: Optional[LoRAMapping] = None
        self._create_lora_modules()
        self.model.lora_manager = self

    def __len__(self) -> int:
        return len(self._registered_adapters)

    @property
    def capacity(self) -> int:
        """Maximum number of registered adapters (``max_cpu_loras``)."""
        return self.lora_config.max_cpu_loras

    @property
    def lora_slots(self) -> int:
        """Number of adapters that can be active at once (``max_loras``)."""
        return self.lora_config.max_loras

    @property
    def adapter_slots(self) -> int:
        return self.lora_slots

    def activate_adapter(
        self,
        lora_id: int,
    ) -> bool:
        """Move LoRA into a GPU buffer to be used in the forward pass.

        Returns:
            False if the adapter is already active, True once activated.

        Raises:
            KeyError: If ``lora_id`` is not registered.
            ValueError: If no slot is free, or the adapter has a bias while
                ``lora_config.bias_enabled`` is off. On failure, the slot and
                active-set bookkeeping are rolled back, so a retry raises
                again instead of reporting "already active".
        """
        if lora_id in self._active_adapters:
            return False
        index = next((i for i, slot_id in enumerate(self.lora_index_to_id)
                      if slot_id is None), None)
        if index is None:
            raise ValueError("No free lora slots")
        lora_model = self._registered_adapters[lora_id]
        self._active_adapters[lora_id] = None
        logger.debug("Activating LoRA. int id: %d, slot index: %d",
                     lora_model.id, index)
        self.lora_index_to_id[index] = lora_model.id
        try:
            for module_name, module in self.modules.items():
                module_lora = self._get_lora_layer_weights(
                    lora_model, module_name)
                if not module_lora:
                    module.reset_lora(index)
                    continue
                module_lora.optimize()
                # Bias is not explicitly enabled with the flag
                # enable_lora_bias.
                bias = module_lora.bias
                if ((torch.is_tensor(bias) or
                     (isinstance(bias, Sequence) and any(b is not None
                                                         for b in bias)))
                        and not self.lora_config.bias_enabled):
                    raise ValueError(
                        f"Adapter bias cannot be used for {module_name}"
                        " without --enable-lora-bias.")
                module.set_lora(index, module_lora.lora_a, module_lora.lora_b,
                                module_lora.embeddings_tensor,
                                module_lora.bias)
        except Exception:
            # Undo the bookkeeping so the adapter is not left half-active.
            # Layers already written for this slot are overwritten by the
            # next activation that uses it (set_lora/reset_lora on all).
            self.lora_index_to_id[index] = None
            self._active_adapters.pop(lora_id, None)
            raise
        return True

    def _deactivate_adapter(self, lora_id: int):
        """Free the slot holding ``lora_id``; no-op if it holds no slot."""
        try:
            index = self.lora_index_to_id.index(lora_id)
            self.lora_index_to_id[index] = None
        except ValueError:
            pass

    def _add_adapter(self, lora: LoRAModel):
        self._create_merged_loras_inplace(lora)
        self._registered_adapters[lora.id] = lora

    def pin_adapter(self, lora_id: int) -> bool:
        """Pin a LoRAModel in the manager cache."""
        raise NotImplementedError(
            "Pinning is not supported in LoRAModelManager. "
            "Use LRUCacheLoRAModelManager for pinning")  # type: ignore

    def _set_adapter_mapping(self, mapping: LoRAMapping) -> None:
        # update lora states
        self.punica_wrapper.update_metadata(
            mapping,
            self.lora_index_to_id,
            self.lora_slots + 1,
            self.vocab_size,
            self.lora_config.lora_extra_vocab_size,
        )

    def remove_all_adapters(self):
        """Remove all LoRAModels from the manager."""
        self._registered_adapters.clear()
        self.lora_index_to_id = [None] * self.lora_slots
        self._active_adapters.clear()

    def _create_lora_modules(self):
        """Replace every supported target layer with its LoRA wrapper."""

        def _parent_module(module_name: str) -> str:
            # module name is a dot separated name.
            # for example:
            #  - given an input 'x.y.z' return 'x.y'
            #  - given an input 'x' return ''
            return module_name.rpartition('.')[0]

        for module_name, module in self.model.named_modules(
                remove_duplicate=False):
            if isinstance(module, PPMissingLayer):
                continue
            if not self._match_target_modules(module_name):
                continue
            # A temporary approach for multimodal models to support LoRA
            # TODO: Remove this restriction
            if self._filter_unsupported_mm_module(module_name):
                logger.warning(
                    "Regarding multimodal models, vLLM currently only supports "
                    "adding LoRA to language model, %s will be ignored.",
                    module_name,
                )
                continue
            leaf_name = module_name.split(".")[-1]
            packed_modules_lst = self.packed_modules_mapping.get(leaf_name, [])
            new_module = replace_submodule(
                self.model, module_name,
                from_layer(module, self.lora_slots, self.lora_config,
                           packed_modules_lst, self.model.config))

            # (yard1): TODO make this more robust
            if "lm_head" in module_name:
                logits_processor_module_name = 'logits_processor'
                parent_module = _parent_module(module_name)
                if parent_module:
                    logits_processor_module_name = (
                        f"{parent_module}.{logits_processor_module_name}")

                logits_processor_module = self.model.get_submodule(
                    logits_processor_module_name)

                new_module = replace_submodule(
                    self.model, logits_processor_module_name,
                    from_layer_logits_processor(logits_processor_module,
                                                module, self.lora_slots,
                                                self.lora_config,
                                                self.model.config))

            # In some models, especially multimodal ones, layers with the same
            # name may have different types, such as nn.Linear and
            # ReplicatedLinear. The nn.Linear layers cannot be replaced with
            # LoRA layers, leading to assertion error. The following check
            # aims to prevent this error
            if self.supports_mm and not isinstance(new_module,
                                                   BaseLayerWithLoRA):
                continue
            self.register_module(module_name, new_module)
            self._register_packed_modules(module_name)
            # All lora layers share the same punica_wrapper based on reference.
            new_module.set_mapping(self.punica_wrapper)

    def register_module(self, module_name: str, module: "BaseLayerWithLoRA"):
        assert isinstance(module, BaseLayerWithLoRA)
        self.modules[module_name] = module

    def create_dummy_lora(
            self,
            lora_id: int,
            rank: int,
            embedding_modules: Optional[dict[str, str]] = None) -> LoRAModel:
        """Create zero-initialized LoRAModel for warmup.

        Dummy weights (on "cpu") are created for every LoRA-wrapped target
        module that is not filtered out. Packed modules get one dummy per
        sub-module, combined with ``PackedLoRALayerWeights.pack``.
        """
        model = LoRAModel(lora_id, rank, {})
        bias_enabled = self.lora_config.bias_enabled
        for module_name, module in self.model.named_modules():
            if (not self._match_target_modules(module_name)
                    or not isinstance(module, BaseLayerWithLoRA)
                    or self._filter_unsupported_mm_module(module_name)):
                continue
            leaf_name = module_name.split(".")[-1]
            if module_name not in self.packed_modules:
                assert embedding_modules is not None
                if leaf_name in embedding_modules:
                    input_dim = (module.base_layer.org_vocab_size +
                                 self.lora_config.lora_extra_vocab_size if
                                 hasattr(module.base_layer, "org_vocab_size")
                                 else module.base_layer.weight.shape[1])
                    output_dim = module.base_layer.embedding_dim if hasattr(
                        module.base_layer,
                        "embedding_dim") else module.base_layer.weight.shape[0]
                    embeddings_tensor_dim = (module.base_layer.embedding_dim if
                                             hasattr(module.base_layer,
                                                     "embedding_dim") else
                                             module.base_layer.weight.shape[1])
                    lora = LoRALayerWeights.create_dummy_lora_weights(
                        module_name,
                        input_dim,
                        output_dim,
                        rank,
                        module.lora_a_stacked[0].dtype,
                        "cpu",
                        embeddings_tensor_dim=embeddings_tensor_dim,
                        bias_enabled=bias_enabled)
                else:
                    lora = LoRALayerWeights.create_dummy_lora_weights(
                        module_name,
                        module.lora_a_stacked[0].shape[-1],
                        module.lora_b_stacked[0].shape[-2],
                        rank,
                        module.lora_a_stacked[0].dtype,
                        "cpu",
                        bias_enabled=bias_enabled,
                    )
            else:
                replacements = self.packed_modules_mapping[leaf_name]
                subloras: list[Optional[LoRALayerWeights]] = []
                for i, r in enumerate(replacements):
                    lora = LoRALayerWeights.create_dummy_lora_weights(
                        module_name + "." + r,
                        module.lora_a_stacked[i].shape[-1],
                        module.lora_b_stacked[i].shape[-2],
                        rank,
                        module.lora_a_stacked[i].dtype,
                        "cpu",
                        bias_enabled=bias_enabled,
                    )
                    subloras.append(lora)
                lora = PackedLoRALayerWeights.pack(subloras)
            model.loras[module_name] = lora
        return model

    def _match_target_modules(self, module_name: str):
        """True if ``module_name`` is, or ends with '.<m>' for, a supported
        LoRA module name ``m``."""
        return any(
            re.match(
                r".*\.{target_module}$".format(target_module=target_module),
                module_name) or target_module == module_name
            for target_module in self.supported_lora_modules)

    def _filter_unsupported_mm_module(self, module_name: str) -> bool:
        """
        Regarding multimodal models, vLLM currently only supports adding LoRA to
        language model. LoRA for other modules, such as the vision tower, will
        be filtered out.
        """
        if not self.supports_mm:
            return False
        module_mapping: MultiModelKeys = self.model.get_mm_mapping()
        prefix_lst = module_mapping.connector + module_mapping.tower_model
        return any(module_name.startswith(prefix) for prefix in prefix_lst)

    def _register_packed_modules(self, module_full_name: str) -> None:
        """Record the full sub-module names of a packed module, if any."""
        parts = module_full_name.split(".")
        module_name = parts[-1]
        replacements = self.packed_modules_mapping.get(module_name, [])
        # When replacements is less than or equal to 1, it indicates that this
        # module is not a packed module.
        if len(replacements) <= 1:
            return
        prefix = ".".join(parts[:-1])
        self.packed_modules[module_full_name] = [
            prefix + "." + r if prefix else r for r in replacements
        ]

    def _create_merged_loras_inplace(self, lora_model: LoRAModel) -> None:
        """Combine the sub-module LoRAs of each packed module into one entry.

        For every packed module, the LoRAs of its sub-modules are looked up
        in ``lora_model``. If at least one exists, they are packed (with
        missing ones as ``None``) and stored under the packed module's name,
        and the consumed sub-module entries are removed.
        """
        for module_name, new_module_names in self.packed_modules.items():
            replacement_loras: list[Optional[LoRALayerWeights]] = []
            # Keys actually present in `lora_model.loras`; for pooling models
            # these may lack the "model." prefix of the module name.
            replaced_module: set[str] = set()
            for r in new_module_names:
                lora_key = self._resolve_lora_module_name(lora_model, r)
                lora = lora_model.get_lora(lora_key)
                if lora:
                    replaced_module.add(lora_key)
                    replacement_loras.append(lora)
                else:
                    replacement_loras.append(None)
            if not replaced_module:
                continue
            # HACK Temporary solution for the pool model: if the packed name
            # only exists without its "model." prefix, store under that key.
            module_name = self._resolve_lora_module_name(
                lora_model, module_name)
            lora_model.loras[module_name] = PackedLoRALayerWeights.pack(
                replacement_loras)
            # Remove the modules that have been replaced.
            for module in replaced_module:
                lora_model.loras.pop(module, None)

    def _resolve_lora_module_name(self, lora_model: LoRAModel,
                                  module_name: str) -> str:
        """Return the key under which ``module_name`` is stored in
        ``lora_model``.

        For pooling models whose adapter lacks ``module_name``, the name with
        every "model." removed is tried. If neither is present,
        ``module_name`` is returned unchanged.
        """
        if not self.is_pooling_model or lora_model.check_lora_name(
                module_name):
            return module_name
        stripped_name = module_name.replace("model.", "")
        if lora_model.check_lora_name(stripped_name):
            logger.info_once(
                "For the pool model, successfully loaded the LoRA weights "
                "after removing the prefix 'model.'.")
            return stripped_name
        return module_name

    def _get_lora_layer_weights(
            self, lora_model: LoRAModel,
            module_name: str) -> Optional[LoRALayerWeights]:
        """Look up a module's LoRA, retrying without 'model.' for pooling
        models."""
        return lora_model.get_lora(
            self._resolve_lora_module_name(lora_model, module_name))

    def deactivate_adapter(self, adapter_id: int) -> bool:
        """Deactivate an adapter; return False if it was not active."""
        if adapter_id not in self._active_adapters:
            return False
        self._deactivate_adapter(adapter_id)
        self._active_adapters.pop(adapter_id, None)
        return True

    def add_adapter(self, adapter: LoRAModel) -> bool:
        """Register an adapter; False if already registered.

        Raises:
            RuntimeError: If ``capacity`` adapters are already registered.
        """
        logger.debug("Adding lora. Model id: %d, "
                     "int id: %d", adapter.id, adapter.id)
        if adapter.id in self._registered_adapters:
            return False
        if len(self._registered_adapters) >= self.capacity:
            raise RuntimeError("No free adapter slots.")
        self._add_adapter(adapter)
        return True

    def set_adapter_mapping(self, mapping: LoRAMapping) -> None:
        """Push ``mapping`` to the punica wrapper only when it changed."""
        if self._last_mapping != mapping:
            self._set_adapter_mapping(mapping)
            self._last_mapping = mapping

    def remove_adapter(self, adapter_id: int) -> bool:
        """Deactivate and unregister an adapter; False if not registered."""
        self.deactivate_adapter(adapter_id)
        if adapter_id not in self._registered_adapters:
            return False
        self._registered_adapters.pop(adapter_id, None)
        return True

    def list_adapters(self) -> dict[int, LoRAModel]:
        return dict(self._registered_adapters)

    def get_adapter(self, adapter_id: int) -> Optional[LoRAModel]:
        return self._registered_adapters.get(adapter_id)
718
719
720


class LoRALRUCache(AdapterLRUCache[LoRAModel]):
    """LRU cache of ``LoRAModel`` values keyed by integer adapter id.

    A thin specialization of ``AdapterLRUCache``. Whenever the base
    cache's ``_on_remove`` hook fires for an id, ``deactivate_lora_fn`` is
    called with that id.
    """

    def __init__(self, capacity: int,
                 deactivate_lora_fn: Callable[[int], bool]):
        super().__init__(capacity=capacity, deactivate_fn=deactivate_lora_fn)


class LRUCacheLoRAModelManager(LoRAModelManager):
    """A model manager that manages multiple LoRAs with LRU cache.

    Two ``LoRALRUCache`` instances back the manager:

    * ``_registered_adapters`` (size ``self.capacity``): every known
      adapter; removing an entry calls ``deactivate_adapter``.
    * ``_active_adapters`` (size ``self.lora_slots``): adapters currently
      occupying a slot; removing an entry calls ``_deactivate_adapter``.
    """

    def __init__(self, model: nn.Module, max_num_seqs: int,
                 max_num_batched_tokens: int, vocab_size: int,
                 lora_config: LoRAConfig, device: torch.device):
        super().__init__(model, max_num_seqs, max_num_batched_tokens,
                         vocab_size, lora_config, device)
        # Back both collections with LRU caches; each cache invokes its
        # callback when an entry is removed (AdapterLRUCache._on_remove).
        registered_cache = LoRALRUCache(self.capacity,
                                        self.deactivate_adapter)
        self._registered_adapters: LoRALRUCache = registered_cache
        active_cache = LoRALRUCache(self.lora_slots,
                                    self._deactivate_adapter)
        self._active_adapters: LoRALRUCache = active_cache

    def list_adapters(self) -> dict[int, LoRAModel]:
        """List all registered LoRAModels as a new dict keyed by id."""
        entries = self._registered_adapters.cache
        return dict(entries)

    def add_adapter(self, lora: LoRAModel) -> bool:
        """Add a LoRAModel to the manager.

        Returns ``True`` when ``lora`` is newly registered. If its id is
        already present, the entry is only touched (marked most recently
        used) and ``False`` is returned. No capacity check is done here;
        the LRU cache handles overflow.
        """
        logger.debug("Adding lora. Model id: %d, "
                     "int id: %d", lora.id, lora.id)
        if lora.id in self._registered_adapters:
            # We always touch to update the LRU cache order
            self._registered_adapters.touch(lora.id)
            return False
        self._add_adapter(lora)
        return True

    def activate_adapter(
        self,
        lora_id: int,
    ) -> bool:
        """Activate ``lora_id``, first evicting the oldest active adapter
        when it is not yet active and every slot is in use.

        Returns the result of the base-class ``activate_adapter``.
        """
        needs_free_slot = lora_id not in self._active_adapters
        if needs_free_slot and (len(self._active_adapters)
                                >= self.lora_slots):
            self._active_adapters.remove_oldest()
        activated = super().activate_adapter(lora_id)
        # We always touch to update the LRU cache order
        self._active_adapters.touch(lora_id)
        return activated

    def remove_oldest_adapter(self) -> bool:
        """Evict the oldest registered adapter.

        Returns ``False`` if nothing is registered, ``True`` otherwise.
        """
        if len(self._registered_adapters) == 0:
            return False
        self._registered_adapters.remove_oldest()
        return True

    def pin_adapter(self, lora_id: int) -> bool:
        """Pin a LoRAModel in the manager cache.

        The adapter is pinned in the registered-adapter cache first, which
        raises ``ValueError`` if it is not registered. It is then activated
        if necessary and pinned in the active-adapter cache. Always returns
        ``True`` when no exception is raised.
        """
        self._pin_lora_in_cpu_cache(lora_id)
        self._pin_lora_in_gpu_cache(lora_id)
        return True

    def _pin_lora_in_cpu_cache(self, lora_id: int):
        """Pin ``lora_id`` in ``_registered_adapters``.

        Raises:
            ValueError: If the underlying cache rejects the pin (re-raised
                with a message stating the LoRA is not registered).
        """
        try:
            self._registered_adapters.pin(lora_id)
        except ValueError as err:
            message = f"Pinning failed. LoRA {lora_id} is not registered."
            raise ValueError(message) from err

    def _pin_lora_in_gpu_cache(self, lora_id: int):
        """Make ``lora_id`` active if needed, then pin it there."""
        already_active = lora_id in self._active_adapters
        if not already_active:
            # move lora to gpu if not already active
            self.activate_adapter(lora_id)
        self._active_adapters.pin(lora_id)

def create_lora_manager(
        model: nn.Module,
        max_num_seqs: int,
        max_num_batched_tokens: int,
        vocab_size: int,
        lora_config: LoRAConfig,
        device: torch.device,
        lora_manager_cls: type[LoRAModelManager] = LoRAModelManager,
        **kwargs) -> LoRAModelManager:
    """Create a LoRA model manager for ``model``.

    Args:
        model: Model to manage; must be an instance of ``SupportsLoRA``.
        max_num_seqs: Forwarded to ``lora_manager_cls`` as a keyword.
        max_num_batched_tokens: Forwarded to ``lora_manager_cls``.
        vocab_size: Forwarded to ``lora_manager_cls``.
        lora_config: Forwarded to ``lora_manager_cls``.
        device: Forwarded to ``lora_manager_cls``.
        lora_manager_cls: Manager class to instantiate.
        **kwargs: Extra keyword arguments passed through to
            ``lora_manager_cls``.

    Returns:
        The newly constructed manager instance.

    Raises:
        ValueError: If ``model`` is not an instance of ``SupportsLoRA``.
    """
    if isinstance(model, SupportsLoRA):
        return lora_manager_cls(
            model=model,
            max_num_seqs=max_num_seqs,
            max_num_batched_tokens=max_num_batched_tokens,
            vocab_size=vocab_size,
            lora_config=lora_config,
            device=device,
            **kwargs)
    raise ValueError(f"Model {type(model)} is not supported for LoRA.")