models.py 32.8 KB
Newer Older
1
2
# SPDX-License-Identifier: Apache-2.0

3
4
5
6
import copy
import math
import os
import re
7
from dataclasses import dataclass, field
8
9
from typing import (Any, Callable, Dict, List, Optional, Sequence, Set, Type,
                    Union)
10
11
12
13
14

import safetensors.torch
import torch
from torch import nn

15
16
17
18
19
from vllm.adapter_commons.models import (AdapterLRUCache, AdapterModel,
                                         AdapterModelManager)
from vllm.adapter_commons.utils import (add_adapter, deactivate_adapter,
                                        get_adapter, list_adapters,
                                        remove_adapter, set_adapter_mapping)
20
from vllm.config import LoRAConfig
21
from vllm.logger import init_logger
22
from vllm.lora.layers import (BaseLayerWithLoRA,
23
                              LinearScalingRotaryEmbeddingWithLoRA,
24
                              LoRAMapping)
25
from vllm.lora.lora import LoRALayerWeights, PackedLoRALayerWeights
26
from vllm.lora.peft_helper import PEFTHelper
27
from vllm.lora.punica_wrapper import get_punica_wrapper
28
from vllm.lora.utils import (from_layer, from_layer_logits_processor,
29
                             get_supported_lora_modules,
30
                             is_regex_target_modules,
31
                             parse_fine_tuned_lora_name, replace_submodule)
32
from vllm.model_executor.models import SupportsLoRA, supports_multimodal
33
from vllm.model_executor.models.module_mapping import MultiModelKeys
34
from vllm.model_executor.models.utils import PPMissingLayer, WeightsMapper
35
from vllm.utils import is_pin_memory_available
36

37
logger = init_logger(__name__)
38
39
40
41

_GLOBAL_LORA_ID = 0


42
43
44
45
46
47
48
49
50
51
52
53
@dataclass
class LongContextLoRAContext:
    """Context for LoRA adapters that support long context.

    Bundles the rotary-embedding metadata used by long-context adapters with
    a per-adapter offset table that is updated while the manager runs.
    """
    # The scaling factors to support long context LoRA fine-tuned models.
    scaling_factors: List[float]
    # Dimension to apply rotary embedding.
    rot_dim: int
    # Offsets into the sin_cos_cache for each lora_id loaded.
    # This value is dynamically modified.
    offsets_by_lora_id: Dict[int, int] = field(default_factory=dict)


54
55
56
57
58
59
def get_lora_id():
    """Return the next LoRA id from the module-level counter.

    Increments ``_GLOBAL_LORA_ID`` by one and returns the incremented value.
    """
    global _GLOBAL_LORA_ID
    next_id = _GLOBAL_LORA_ID + 1
    _GLOBAL_LORA_ID = next_id
    return next_id


60
class LoRAModel(AdapterModel):
61
62
63
64
65
66
67
    """A LoRA fine-tuned model."""

    def __init__(
        self,
        lora_model_id: int,
        rank: int,
        loras: Dict[str, LoRALayerWeights],
68
        scaling_factor: Optional[float] = None,
69
    ) -> None:
70
71
72
73
74
75
76
77
        """
        Args:
            lora_model_id: The integer id for the lora model.
            rank: lora rank.
            loras: module name -> weights for lora-replaced layers.
            scaling_factor: Scaling factor to support long context lora model.
                None if the lora is not tuned for long context support.
        """
78
        self.id = lora_model_id
79
80
81
        # Scaling factor for long context lora model. None if it is not
        # fine tuned for the long context.
        self.scaling_factor = scaling_factor
82
83
84
        assert (
            lora_model_id
            > 0), f"a valid lora id should be greater than 0, got {self.id}"
85
86
87
        self.rank = rank
        self.loras: Dict[str, LoRALayerWeights] = loras

88
89
90
91
92
93
94
95
96
97
    def clone(self, lora_model_id: int) -> "LoRAModel":
        """Return a copy of the object with different ids.

        Will share the underlying tensors."""
        return self.__class__(
            lora_model_id,
            rank=self.rank,
            loras=self.loras.copy(),
        )

98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
    @property
    def extra_vocab_size(self) -> int:
        return max(lora.extra_vocab_size
                   for lora in self.loras.values()) if self.loras else 0

    def get_lora(self, module_name: str) -> Optional[LoRALayerWeights]:
        """Get LoRA for a given module by name"""
        return self.loras.get(module_name, None)

    # (yard1): TODO see if we can derive target_embedding_padding automatically
    @classmethod
    def from_lora_tensors(
        cls,
        lora_model_id: int,
        tensors: Dict[str, torch.Tensor],
        peft_helper: PEFTHelper,
        device: str = "cuda",
        dtype: Optional[torch.dtype] = None,
        embeddings: Optional[Dict[str, torch.Tensor]] = None,
        target_embedding_padding: Optional[int] = None,
        embedding_modules: Optional[Dict[str, str]] = None,
        embedding_padding_modules: Optional[List[str]] = None,
        weights_mapper: Optional[WeightsMapper] = None,
    ) -> "LoRAModel":
        """Create a LoRAModel from a dictionary of tensors.

        Args:
            lora_model_id: Id given to the created model.
            tensors: Checkpoint tensor name -> tensor. Each name is parsed
                with ``parse_fine_tuned_lora_name`` into a module name and
                flags telling whether it is a LoRA A matrix or a bias.
            peft_helper: Loaded lora configuration; supplies the rank and the
                long-context scaling factor of the returned model.
            device: Device every tensor is moved to.
            dtype: dtype every tensor is cast to.
            embeddings: Optional extra embedding tensors, looked up through
                ``embedding_modules``.
            target_embedding_padding: If given, LoRA B matrices of modules
                listed in ``embedding_padding_modules`` are right-padded
                along dim 1 up to this size.
            embedding_modules: Module-name fragment -> key in ``embeddings``.
                Must not be None when ``embeddings`` is non-empty.
            embedding_padding_modules: Module-name fragments whose LoRA B
                matrix is padded. Must not be None when a LoRA B tensor is
                present.
            weights_mapper: Forwarded to ``parse_fine_tuned_lora_name``.

        Returns:
            The LoRAModel built from ``tensors``.
        """
        # Pinning is only attempted for tensors that stay on the CPU.
        pin_memory = str(device) == "cpu" and is_pin_memory_available()

        def _maybe_pin(t: torch.Tensor) -> torch.Tensor:
            return t.pin_memory() if pin_memory else t

        loras: Dict[str, LoRALayerWeights] = {}
        for tensor_name, tensor in tensors.items():
            module_name, is_lora_a, is_bias = parse_fine_tuned_lora_name(
                tensor_name, weights_mapper)
            if module_name not in loras:
                # First tensor seen for this module: create its weight holder
                # and attach the matching extra embeddings, if any.
                lora_embeddings_tensor = None
                if embeddings:
                    assert embedding_modules is not None
                    embeddings_module = next(
                        (k for k in embedding_modules if k in module_name),
                        None)
                    if embeddings_module:
                        lora_embeddings_tensor = _maybe_pin(embeddings[
                            embedding_modules[embeddings_module]].to(
                                device=device, dtype=dtype))
                loras[module_name] = LoRALayerWeights.from_config(
                    module_name, peft_helper, lora_embeddings_tensor)

            # Move/cast each tensor exactly once; it is stored transposed.
            weight = tensor.to(device=device, dtype=dtype).t()
            if is_bias:
                loras[module_name].bias = _maybe_pin(weight)
            elif is_lora_a:
                loras[module_name].lora_a = _maybe_pin(weight)
            else:
                assert embedding_padding_modules is not None
                needs_padding = any(
                    name in module_name for name in embedding_padding_modules
                ) and target_embedding_padding is not None
                if needs_padding:
                    assert target_embedding_padding >= weight.shape[1]
                    addition = target_embedding_padding - weight.shape[1]
                    weight = torch.nn.functional.pad(weight, (0, addition))
                # Pin after padding so the stored tensor is the pinned one.
                loras[module_name].lora_b = _maybe_pin(weight)

        for lora in loras.values():
            lora.optimize()

        return cls(lora_model_id,
                   peft_helper.r,
                   loras,
                   scaling_factor=peft_helper.vllm_long_context_scaling_factor)
181
182
183

    @classmethod
    def from_local_checkpoint(
        cls,
        lora_dir: str,
        expected_lora_modules: List[str],
        peft_helper: PEFTHelper,
        *,
        lora_model_id: Optional[int] = None,
        device: str = "cuda",
        dtype: Optional[torch.dtype] = None,
        target_embedding_padding: Optional[int] = None,
        embedding_modules: Optional[Dict[str, str]] = None,
        embedding_padding_modules: Optional[List[str]] = None,
        weights_mapper: Optional[WeightsMapper] = None,
    ) -> "LoRAModel":
        """Create a LoRAModel from a local checkpoint.

        Looks for ``adapter_model.safetensors`` first and falls back to
        ``adapter_model.bin``. Extra embeddings are read from
        ``new_embeddings.safetensors`` or ``new_embeddings.bin`` if present.

        Args:
            lora_dir: The local path that has lora data.
            expected_lora_modules: Name of modules that are expected to be
                replaced by lora.
            peft_helper: Loaded lora configuration information.
            lora_model_id: LoRA model id. If not given, automatically set by
                a global counter.
            device: Device where the lora model is loaded.
            dtype: dtype of the lora model weights.
            target_embedding_padding: Forwarded to ``from_lora_tensors``.
            embedding_modules: Forwarded to ``from_lora_tensors``.
            embedding_padding_modules: Forwarded to ``from_lora_tensors``.
            weights_mapper: Used when parsing tensor names and forwarded to
                ``from_lora_tensors``.

        Returns:
            Loaded LoRA Model.

        Raises:
            ValueError: If ``lora_dir`` contains no adapter weights file, or
                if the adapter targets modules outside
                ``expected_lora_modules``.
        """
        lora_tensor_path = os.path.join(lora_dir, "adapter_model.safetensors")
        lora_bin_file_path = os.path.join(lora_dir, "adapter_model.bin")
        new_embeddings_tensor_path = os.path.join(
            lora_dir, "new_embeddings.safetensors")
        new_embeddings_bin_file_path = os.path.join(lora_dir,
                                                    "new_embeddings.bin")

        unexpected_modules: List[Union[list[str], str]]
        if os.path.isfile(lora_tensor_path):
            tensors: Dict[str, torch.Tensor] = {}
            # Find unexpected modules.
            # Use safetensor key as a source of truth to find expected modules.
            # in peft if you have target_modules A, B, C and C does not exist
            # in the model it won't error and model will be trained with A, B
            # loraified. C won't exist in the safetensor but it will exist in
            # the target_modules of the adapter_config.json.
            unexpected_modules = []
            with safetensors.safe_open(lora_tensor_path,
                                       framework="pt") as f:  # type: ignore
                for lora_module in f.keys():  # noqa
                    module_name, _, _ = parse_fine_tuned_lora_name(
                        lora_module, weights_mapper)
                    part_name = module_name.split(".")[-1]
                    if part_name not in expected_lora_modules:
                        unexpected_modules.append(module_name)
                if unexpected_modules:
                    raise ValueError(
                        f"While loading {lora_dir}, expected"
                        f" target modules in {expected_lora_modules}"
                        f" but received {unexpected_modules}."
                        f" Please verify that the loaded LoRA module is correct"
                    )
                # Load tensors if there are only expected modules.
                for module in f.keys():  # noqa
                    tensors[module] = f.get_tensor(module)
        elif os.path.isfile(lora_bin_file_path):
            # When a bin file is provided, we rely on config to find unexpected
            # modules.
            unexpected_modules = []
            target_modules = peft_helper.target_modules
            if not isinstance(target_modules, list):
                target_modules = [target_modules]
            for module in target_modules:
                # Compatible with more modules,
                # such as:layers.11.self_attn.k_proj
                part_name = module.split(".")[-1]
                if part_name not in expected_lora_modules:
                    unexpected_modules.append(module)
            # loaded lora's target modules must be a subset of
            # expected_lora_modules. It is not reliable. See
            # https://github.com/vllm-project/vllm/pull/5909. But there's no
            # other better mechanism.
            if unexpected_modules and not is_regex_target_modules(
                    peft_helper.target_modules, expected_lora_modules):
                raise ValueError(
                    f"While loading {lora_dir}, expected"
                    f" target modules in {expected_lora_modules}"
                    f" but received {unexpected_modules}."
                    f" Please verify that the loaded LoRA module is correct")
            # A .bin checkpoint is a pickle file: restrict unpickling to
            # tensors/primitive containers, exactly as done for the
            # embeddings file below, so a crafted adapter cannot execute
            # arbitrary code on load.
            tensors = torch.load(lora_bin_file_path,
                                 map_location=device,
                                 weights_only=True)
        else:
            raise ValueError(f"{lora_dir} doesn't contain tensors")

        embeddings = None
        if os.path.isfile(new_embeddings_tensor_path):
            embeddings = safetensors.torch.load_file(
                new_embeddings_tensor_path)
        elif os.path.isfile(new_embeddings_bin_file_path):
            embeddings = torch.load(new_embeddings_bin_file_path,
                                    map_location=device,
                                    weights_only=True)

        return cls.from_lora_tensors(
            lora_model_id=get_lora_id()
            if lora_model_id is None else lora_model_id,
            tensors=tensors,
            peft_helper=peft_helper,
            device=device,
            dtype=dtype,
            embeddings=embeddings,
            target_embedding_padding=target_embedding_padding,
            embedding_modules=embedding_modules,
            embedding_padding_modules=embedding_padding_modules,
            weights_mapper=weights_mapper)
296
297


298
class LoRAModelManager(AdapterModelManager):
299
300
301
302
    """A manager that manages multiple LoRA-fine-tuned models."""

    def __init__(
        self,
        model: SupportsLoRA,
        max_num_seqs: int,
        max_num_batched_tokens: int,
        vocab_size: int,
        lora_config: LoRAConfig,
        device: torch.device,
    ):
        """Create a LoRAModelManager and adapter for a given model.

        Args:
            model: the model to be adapted.
            max_num_seqs: the maximum number of sequences model can run in a
                single batch.
            max_num_batched_tokens: the maximum number of tokens model can run
                in a single batch.
            vocab_size: the vocab size of the model.
            lora_config: the LoRA configuration.
            device: the device handed to the punica wrapper.
        """
        self.lora_config = lora_config
        self.device = device
        self.max_num_seqs = max_num_seqs
        assert self.capacity >= self.lora_slots
        # Stored value is rounded up to the next multiple of 8.
        self.max_num_batched_tokens = math.ceil(max_num_batched_tokens / 8) * 8
        self.lora_index_to_id: List[Optional[int]] = [None] * self.lora_slots
        self.vocab_size = vocab_size
        self.long_lora_context: Optional[LongContextLoRAContext] = None
        self.punica_wrapper = get_punica_wrapper(
            max_num_batched_tokens,
            max_batches=self.max_num_seqs,
            device=self.device,
            max_loras=self.lora_config.max_loras)
        # Scaling factor -> offset to the sin_cos_cache to it.
        # Used for long context lora.
        self.scaling_factor_to_offset: Dict[float, int] = {}
        super().__init__(model)
        self.supported_lora_modules = get_supported_lora_modules(self.model)
        # The message must be a single parenthesised expression; otherwise
        # the second string literal is a separate no-op statement and the
        # model name never reaches the assertion message.
        assert self.supported_lora_modules, (
            "No supported LoRA modules found in "
            f"{self.model.__class__.__name__}.")
        if lora_config.long_lora_scaling_factors:
            # We need to replace rotary emb layer to do batch computation
            # for long lora.
            self.supported_lora_modules.append("rotary_emb")
        self.packed_modules_mapping = copy.deepcopy(
            self.model.packed_modules_mapping)
        # Used to indicate whether the model is a multimodal model
        self.supports_mm: bool = (
            supports_multimodal(self.model)
            # In case the model only supports LoRA for
            # text modules (e.g. ChatGLM)
            and hasattr(self.model, "get_mm_mapping"))
        self.packed_modules: Dict[str, List[str]] = {}
        self.modules: Dict[str, BaseLayerWithLoRA] = {}
        self._last_mapping: Optional[LoRAMapping] = None
        self._create_lora_modules()
        self.model.lora_manager = self
        self.adapter_type = 'LoRa'
360
361
362
363
364
365
366
367
368

    @property
    def capacity(self) -> int:
        return self.lora_config.max_cpu_loras

    @property
    def lora_slots(self) -> int:
        return self.lora_config.max_loras

369
370
371
    @property
    def adapter_slots(self) -> int:
        """Alias for ``lora_slots``."""
        return self.lora_slots
372

373
    def activate_adapter(
        self,
        lora_id: int,
    ) -> bool:
        """Move LoRA into a GPU buffer to be used in the forward pass.

        Takes the first free slot, marks the adapter active and, for every
        registered module, either installs the adapter's weights into that
        slot or resets the slot when the adapter has no weights for it.

        Returns:
            False if the adapter was already active, True otherwise.

        Raises:
            ValueError: If no slot is free, or if the adapter carries a bias
                while bias support is disabled in the LoRA config.
        """
        if lora_id in self._active_adapters:
            return False

        index = next((slot
                      for slot, occupant in enumerate(self.lora_index_to_id)
                      if occupant is None), None)
        if index is None:
            raise ValueError("No free lora slots")

        self._active_adapters[lora_id] = None
        lora_model = self._registered_adapters[lora_id]
        logger.debug("Activating LoRA. int id: %d, slot index: %d",
                     lora_model.id, index)
        self.lora_index_to_id[index] = lora_model.id

        for module_name, module in self.modules.items():
            module_lora = lora_model.get_lora(module_name)
            if not module_lora:
                module.reset_lora(index)
                continue

            module_lora.optimize()
            # Bias is not explicitly enabled with the flag enable_lora_bias.
            bias = module_lora.bias
            has_bias = torch.is_tensor(bias) or (
                isinstance(bias, Sequence)
                and any(b is not None for b in bias))
            if has_bias and not self.lora_config.bias_enabled:
                module_lora.bias = None
                raise ValueError(
                    f"Adapter bias cannot be used for {module_name}"
                    " without --enable-lora-bias.")
            module.set_lora(index, module_lora.lora_a, module_lora.lora_b,
                            module_lora.embeddings_tensor, module_lora.bias)
        return True

412
    def _deactivate_adapter(self, lora_id: int):
413
414
415
416
417
418
        try:
            index = self.lora_index_to_id.index(lora_id)
            self.lora_index_to_id[index] = None
        except ValueError:
            pass

419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
    def _set_long_lora_context(self, lora: LoRAModel):
        if self.long_lora_context is None:
            return

        if lora.scaling_factor is None:
            return

        if (lora.scaling_factor not in self.scaling_factor_to_offset):
            raise ValueError(f"Long LoRA scaling factor {lora.scaling_factor}"
                             " has not been initialized.")

        offsets = self.scaling_factor_to_offset.get(lora.scaling_factor)
        if offsets:
            self.long_lora_context.offsets_by_lora_id[lora.id] = offsets

434
    def _add_adapter(self, lora: LoRAModel):
435
        self._create_merged_loras_inplace(lora)
436
        self._registered_adapters[lora.id] = lora
437
        self._set_long_lora_context(lora)
438

439
    def pin_adapter(self, lora_id: int) -> bool:
440
441
        """Pin a LoRAModel in the manager cache."""
        raise NotImplementedError(
442
            "Pinning is not supported in LoRAModelManager. "
443
444
            "Use LRUCacheLoRAModelManager for pinning")  # type: ignore

445
    def _set_adapter_mapping(self, mapping: LoRAMapping) -> None:
        """Update the punica wrapper's metadata for a new mapping."""
        # update lora states
        # The third argument is one more than the number of slots.
        # NOTE(review): presumably the extra entry stands for "no LoRA" -
        # confirm against the punica wrapper.
        slot_count = self.lora_slots + 1
        extra_vocab_size = self.lora_config.lora_extra_vocab_size
        self.punica_wrapper.update_metadata(
            mapping,
            self.lora_index_to_id,
            slot_count,
            self.vocab_size,
            extra_vocab_size,
            self.long_lora_context,
        )
455

456
    def remove_all_adapters(self):
457
        """Remove all LoRAModels from the manager."""
458
        self._registered_adapters.clear()
459
        self.lora_index_to_id = [None] * self.lora_slots
460
        self._active_adapters.clear()
461
462

    def _create_lora_modules(self):
        """Replace every targeted submodule of the model with a LoRA layer.

        Walks ``self.model.named_modules`` and skips pipeline-parallel
        placeholders, modules that are not LoRA targets and multimodal
        modules that are filtered out. Every remaining module is swapped
        for the layer returned by ``from_layer``, registered with the
        manager and wired to the shared punica wrapper.
        """
        for module_name, module in self.model.named_modules(
                remove_duplicate=False):
            if (isinstance(module, PPMissingLayer)
                    or not self._match_target_modules(module_name)):
                continue
            # A temporary approach for multimodal models to support LoRA
            # TODO: Remove this restriction
            if self._filter_unsupported_mm_module(module_name):
                logger.warning(
                    "Regarding multimodal models, vLLM currently only supports "
                    "adding LoRA to language model, %s will be ignored.",
                    module_name,
                )
                continue

            leaf_name = module_name.split(".")[-1]
            packed_names = self.packed_modules_mapping.get(leaf_name, [])
            lora_layer = from_layer(module, self.lora_slots, self.lora_config,
                                    packed_names, self.model.config)
            new_module = replace_submodule(self.model, module_name,
                                           lora_layer)

            # LinearScalingRotaryEmbeddingWithLoRA is used to handle
            # long context lora. Register relevant metadata.
            if isinstance(new_module, LinearScalingRotaryEmbeddingWithLoRA):
                self.long_lora_context = LongContextLoRAContext(
                    new_module.scaling_factors, new_module.rotary_dim)
                self.scaling_factor_to_offset = (
                    new_module.scaling_factor_to_offset)

            # (yard1): TODO make this more robust
            if "lm_head" in module_name:
                # The logits processor is the module that actually gets
                # replaced (and later registered) for an lm_head target.
                logits_processor_module = self.model.get_submodule(
                    "logits_processor")
                lora_logits_processor = from_layer_logits_processor(
                    logits_processor_module, module, self.lora_slots,
                    self.lora_config, self.model.config)
                new_module = replace_submodule(self.model, "logits_processor",
                                               lora_logits_processor)

            # In some models, especially multimodal ones, layers with the same
            # name may have different types, such as nn.Linear and
            # ReplicatedLinear. The nn.Linear layers cannot be replaced with
            # LoRA layers, leading to assertion error. The following check
            # aims to prevent this error
            if self.supports_mm and not isinstance(new_module,
                                                   BaseLayerWithLoRA):
                continue

            self.register_module(module_name, new_module)
            self._register_packed_modules(module_name)
            # All lora layers share the same punica_wrapper based on reference.
            new_module.set_mapping(self.punica_wrapper)
515
516
517
518
519

    def register_module(self, module_name: str, module: "BaseLayerWithLoRA"):
        """Track ``module`` under ``module_name`` in ``self.modules``.

        Asserts that ``module`` is a ``BaseLayerWithLoRA``.
        """
        assert isinstance(module, BaseLayerWithLoRA)
        registered = self.modules
        registered[module_name] = module

Terry's avatar
Terry committed
520
521
522
523
    def create_dummy_lora(
            self,
            lora_id: int,
            rank: int,
            scaling_factor: Optional[float],
            embedding_modules: Optional[Dict[str, str]] = None) -> LoRAModel:
        """Create zero-initialized LoRAModel for warmup.

        One dummy weight set is created on the CPU for every LoRA layer of
        the model that is a supported target. Packed modules get one dummy
        per packed sub-module, combined with ``PackedLoRALayerWeights.pack``.

        Args:
            lora_id: Id of the dummy model.
            rank: LoRA rank of the dummy weights.
            scaling_factor: Long-context scaling factor of the dummy model.
            embedding_modules: Names of embedding modules; must not be None
                when the model has a non-packed LoRA layer.

        Returns:
            The dummy LoRAModel.
        """
        dummy_model = LoRAModel(lora_id, rank, {}, scaling_factor)
        bias_enabled = self.lora_config.bias_enabled
        for module_name, module in self.model.named_modules():
            if not self._match_target_modules(module_name):
                continue
            if not isinstance(module, BaseLayerWithLoRA):
                continue
            if isinstance(module, LinearScalingRotaryEmbeddingWithLoRA):
                continue
            if self._filter_unsupported_mm_module(module_name):
                continue

            leaf_name = module_name.split(".")[-1]
            if module_name in self.packed_modules:
                # Packed module: one dummy per packed sub-module.
                subloras: List[Optional[LoRALayerWeights]] = []
                sub_names = self.packed_modules_mapping[leaf_name]
                for i, sub_name in enumerate(sub_names):
                    sub_lora = LoRALayerWeights.create_dummy_lora_weights(
                        module_name + "." + sub_name,
                        module.lora_a_stacked[i].shape[-1],
                        module.lora_b_stacked[i].shape[-2],
                        rank,
                        module.lora_a_stacked[i].dtype,
                        "cpu",
                        bias_enabled=bias_enabled,
                    )
                    sub_lora.optimize()
                    subloras.append(sub_lora)
                lora = PackedLoRALayerWeights.pack(subloras)
            else:
                assert embedding_modules is not None
                if leaf_name in embedding_modules:
                    # Embedding layer: dimensions come from the base layer,
                    # falling back to the weight shape when the vocabulary /
                    # embedding attributes are missing.
                    base_layer = module.base_layer
                    if hasattr(base_layer, "org_vocab_size"):
                        input_dim = (base_layer.org_vocab_size +
                                     self.lora_config.lora_extra_vocab_size)
                    else:
                        input_dim = base_layer.weight.shape[1]
                    if hasattr(base_layer, "embedding_dim"):
                        output_dim = base_layer.embedding_dim
                        embeddings_tensor_dim = base_layer.embedding_dim
                    else:
                        output_dim = base_layer.weight.shape[0]
                        embeddings_tensor_dim = base_layer.weight.shape[1]
                    lora = LoRALayerWeights.create_dummy_lora_weights(
                        module_name,
                        input_dim,
                        output_dim,
                        rank,
                        module.lora_a_stacked[0].dtype,
                        "cpu",
                        embeddings_tensor_dim=embeddings_tensor_dim,
                        bias_enabled=bias_enabled)
                else:
                    lora = LoRALayerWeights.create_dummy_lora_weights(
                        module_name,
                        module.lora_a_stacked[0].shape[-1],
                        module.lora_b_stacked[0].shape[-2],
                        rank,
                        module.lora_a_stacked[0].dtype,
                        "cpu",
                        bias_enabled=bias_enabled,
                    )
                lora.optimize()
            dummy_model.loras[module_name] = lora
        return dummy_model

    def _match_target_modules(self, module_name: str):
        return any(
            re.match(
                r".*\.{target_module}$".format(target_module=target_module),
                module_name) or target_module == module_name
Terry's avatar
Terry committed
595
            for target_module in self.supported_lora_modules)
596

597
598
599
600
601
602
603
604
    def _filter_unsupported_mm_module(self, module_name: str) -> bool:
        """
        Regarding multimodal models, vLLM currently only supports adding LoRA to
        language model. LoRA for other modules, such as the vision tower, will 
        be filtered out.
        """
        if self.supports_mm:
            module_mapping: MultiModelKeys = self.model.get_mm_mapping()
605
606
607
            prefix_lst = module_mapping.connector + module_mapping.tower_model
            return any(
                [module_name.startswith(prefix) for prefix in prefix_lst])
608
609
        return False

610
611
612
    def _register_packed_modules(self, module_full_name: str) -> None:
        parts = module_full_name.split(".")
        module_name = parts[-1]
613
614
615
616
        replacements = self.packed_modules_mapping.get(module_name, [])
        # When replacements is less than or equal to 1, it indicates that this
        # module is not a packed module.
        if len(replacements) <= 1:
617
618
619
620
621
622
623
624
            return
        prefix = ".".join(parts[:-1])
        self.packed_modules[module_full_name] = [
            prefix + "." + r if prefix else r for r in replacements
        ]

    def _create_merged_loras_inplace(self, lora_model: LoRAModel) -> None:
        """Merge per-sub-module weights into packed weights, in place.

        For every packed module the manager knows about, the adapter's
        weights for its sub-modules are collected. If at least one exists,
        they are packed (missing ones as None) and stored under the packed
        module's name, and the individual sub-module entries are removed
        from ``lora_model.loras``.
        """
        for packed_name, sub_names in self.packed_modules.items():
            sub_loras = [lora_model.get_lora(name) for name in sub_names]
            present: Set[str] = {
                name
                for name, sub_lora in zip(sub_names, sub_loras) if sub_lora
            }
            if not present:
                continue
            # Anything falsy is passed to the packer as None.
            normalized: List[Optional[LoRALayerWeights]] = [
                sub_lora if sub_lora else None for sub_lora in sub_loras
            ]
            lora_model.loras[packed_name] = PackedLoRALayerWeights.pack(
                normalized)
            # Remove the modules that have been replaced.
            for name in present:
                lora_model.loras.pop(name, None)
645

646
647
648
649
650
651
652
653
654
655
656
657
    def deactivate_adapter(self, adapter_id: int) -> bool:
        """Deactivate an adapter through the shared ``deactivate_adapter`` helper.

        The helper receives the active-adapter collection and the callback
        that frees the adapter's slot; its result is returned unchanged.
        """
        active = self._active_adapters
        release_slot = self._deactivate_adapter
        return deactivate_adapter(adapter_id, active, release_slot)

    def add_adapter(self, adapter: LoRAModel) -> bool:
        """Register an adapter through the shared ``add_adapter`` helper.

        Logs the adapter id and scaling factor, then hands the adapter, the
        registry, the capacity and the registration callback to the helper
        and returns its result unchanged.
        """
        logger.debug(
            "Adding lora. Model id: %d, int id: %d, scaling factor: %s",
            adapter.id, adapter.id, adapter.scaling_factor)
        registry = self._registered_adapters
        return add_adapter(adapter, registry, self.capacity,
                           self._add_adapter)
658

659
660
661
662
663
664
665
666
667
668
669
670
671
672
673
674
    def set_adapter_mapping(self, mapping: LoRAMapping) -> None:
        """Apply a mapping through the shared ``set_adapter_mapping`` helper.

        Whatever the helper returns is remembered as the last mapping.
        """
        previous = self._last_mapping
        self._last_mapping = set_adapter_mapping(mapping, previous,
                                                 self._set_adapter_mapping)

    def remove_adapter(self, adapter_id: int) -> bool:
        """Remove an adapter through the shared ``remove_adapter`` helper.

        The helper receives the registry and this manager's
        ``deactivate_adapter``; its result is returned unchanged.
        """
        registry = self._registered_adapters
        return remove_adapter(adapter_id, registry, self.deactivate_adapter)

    def list_adapters(self) -> Dict[int, Any]:
        """Return the registered adapters as given by the ``list_adapters`` helper."""
        registry = self._registered_adapters
        return list_adapters(registry)

    def get_adapter(self, adapter_id: int) -> Optional[Any]:
        """Look up a registered adapter through the ``get_adapter`` helper."""
        registry = self._registered_adapters
        return get_adapter(adapter_id, registry)


class LoRALRUCache(AdapterLRUCache[LoRAModel]):
    """``AdapterLRUCache`` specialised to hold ``LoRAModel`` entries."""

    def __init__(self, capacity: int, deactivate_lora_fn: Callable[[int],
                                                                   bool]):
        """Initialise the cache.

        Args:
            capacity: Forwarded unchanged to ``AdapterLRUCache``.
            deactivate_lora_fn: Callable taking a LoRA id and returning a
                bool; forwarded unchanged to ``AdapterLRUCache``.
                NOTE(review): presumably invoked by the base cache when an
                entry is evicted -- confirm against ``AdapterLRUCache``.
        """
        super().__init__(capacity, deactivate_lora_fn)
679
680
681
682
683


class LRUCacheLoRAModelManager(LoRAModelManager):
    """A model manager that manages multiple LoRAs with LRU cache.

    After the parent constructor runs, both ``_registered_adapters`` and
    ``_active_adapters`` are set to ``LoRALRUCache`` instances, sized by
    ``self.capacity`` and ``self.lora_slots`` respectively.
    """

    def __init__(self, model: nn.Module, max_num_seqs: int,
                 max_num_batched_tokens: int, vocab_size: int,
                 lora_config: LoRAConfig, device: torch.device):
        """Build the manager; all arguments are forwarded to the parent."""
        super().__init__(model, max_num_seqs, max_num_batched_tokens,
                         vocab_size, lora_config, device)
        # Registered adapters use the public ``deactivate_adapter`` as their
        # callback, while active adapters use the private
        # ``_deactivate_adapter`` one.
        self._registered_adapters: LoRALRUCache = LoRALRUCache(
            self.capacity, self.deactivate_adapter)
        self._active_adapters: LoRALRUCache = LoRALRUCache(
            self.lora_slots, self._deactivate_adapter)

    def list_adapters(self) -> Dict[int, LoRAModel]:
        """List all registered LoRAModels."""
        # Return a fresh dict built from the cache's ``cache`` attribute.
        return dict(self._registered_adapters.cache)

    def add_adapter(self, lora: LoRAModel) -> bool:
        """Add a LoRAModel to the manager.

        Args:
            lora: The LoRA model to register.

        Returns:
            True if ``lora`` was newly added; False if its id was already
            registered, in which case only its LRU position is refreshed.
        """
        logger.debug(
            "Adding lora. Model id: %d, "
            "int id: %d, "
            "scaling factor: %s", lora.id, lora.id, lora.scaling_factor)
        if lora.id in self._registered_adapters:
            # We always touch to update the LRU cache order
            self._registered_adapters.touch(lora.id)
            return False
        self._add_adapter(lora)
        return True

    def activate_adapter(
        self,
        lora_id: int,
    ) -> bool:
        """Activate a LoRA, first dropping the oldest one if slots are full.

        Args:
            lora_id: Integer id of the LoRA to activate.

        Returns:
            The result of the parent class's ``activate_adapter``.
        """
        not_yet_active = lora_id not in self._active_adapters
        if not_yet_active and len(self._active_adapters) >= self.lora_slots:
            # Make room before delegating to the parent implementation.
            self._active_adapters.remove_oldest()
        result = super().activate_adapter(lora_id)
        # We always touch to update the LRU cache order
        self._active_adapters.touch(lora_id)
        return result

    def remove_oldest_adapter(self) -> bool:
        """Remove the oldest registered LoRA, if there is one.

        Returns:
            True if an entry was removed, False if nothing was registered.
        """
        if len(self._registered_adapters) > 0:
            self._registered_adapters.remove_oldest()
            return True
        return False

    def pin_adapter(self, lora_id: int) -> bool:
        """Pin a LoRAModel in the manager cache.

        The LoRA is pinned in the registered-adapter cache first and then
        in the active-adapter cache.

        Returns:
            Always True when no exception is raised.

        Raises:
            ValueError: If ``lora_id`` is not registered.
        """
        self._pin_lora_in_cpu_cache(lora_id)
        self._pin_lora_in_gpu_cache(lora_id)
        return True

    def _pin_lora_in_cpu_cache(self, lora_id: int):
        """Pin ``lora_id`` in the registered-adapter cache.

        Raises:
            ValueError: If the cache's ``pin`` raises ``ValueError``; it is
                re-raised with a message saying the LoRA is not registered.
        """
        try:
            self._registered_adapters.pin(lora_id)
        except ValueError as err:
            raise ValueError("Pinning failed. "
                             f"LoRA {lora_id} is not registered.") from err

    def _pin_lora_in_gpu_cache(self, lora_id: int):
        """Pin ``lora_id`` in the active-adapter cache, activating it first
        when it is not already active."""
        if lora_id not in self._active_adapters:
            # move lora to gpu if not already active
            self.activate_adapter(lora_id)

        self._active_adapters.pin(lora_id)
750

751
752
753
754
755
756
757

def create_lora_manager(
        model: nn.Module,
        max_num_seqs: int,
        max_num_batched_tokens: int,
        vocab_size: int,
        lora_config: LoRAConfig,
        device: torch.device,
        lora_manager_cls: Type[LoRAModelManager] = LoRAModelManager,
        **kwargs) -> LoRAModelManager:
    """Create a LoRA model manager for a given model.

    Args:
        model: The model to manage; it must expose a
            ``packed_modules_mapping`` attribute.
        max_num_seqs: Forwarded to the manager constructor.
        max_num_batched_tokens: Forwarded to the manager constructor.
        vocab_size: Forwarded to the manager constructor.
        lora_config: Forwarded to the manager constructor.
        device: Forwarded to the manager constructor.
        lora_manager_cls: Manager class to instantiate; defaults to
            ``LoRAModelManager``.
        **kwargs: Extra keyword arguments passed through to
            ``lora_manager_cls``.

    Returns:
        The newly constructed manager instance.

    Raises:
        ValueError: If ``model`` has no ``packed_modules_mapping`` attribute.
    """
    if not hasattr(model, "packed_modules_mapping"):
        raise ValueError(f"Model {type(model)} is not supported for LoRA.")
    return lora_manager_cls(
        model=model,
        max_num_seqs=max_num_seqs,
        max_num_batched_tokens=max_num_batched_tokens,
        vocab_size=vocab_size,
        lora_config=lora_config,
        device=device,
        **kwargs)