# SPDX-License-Identifier: Apache-2.0

3
4
5
6
import copy
import math
import os
import re
7
from dataclasses import dataclass, field
8
from typing import Any, Callable, Dict, List, Optional, Sequence, Type, Union
9
10
11
12
13

import safetensors.torch
import torch
from torch import nn

14
15
16
17
18
from vllm.adapter_commons.models import (AdapterLRUCache, AdapterModel,
                                         AdapterModelManager)
from vllm.adapter_commons.utils import (add_adapter, deactivate_adapter,
                                        get_adapter, list_adapters,
                                        remove_adapter, set_adapter_mapping)
19
from vllm.config import LoRAConfig
20
from vllm.logger import init_logger
21
22
23
from vllm.lora.layers import (BaseLayerWithLoRA,
                              LinearScalingRotaryEmbeddingWithLora,
                              LoRAMapping)
24
from vllm.lora.lora import LoRALayerWeights, PackedLoRALayerWeights
25
from vllm.lora.peft_helper import PEFTHelper
26
from vllm.lora.punica_wrapper import get_punica_wrapper
27
from vllm.lora.utils import (from_layer, from_layer_logits_processor,
28
                             is_regex_target_modules,
29
                             parse_fine_tuned_lora_name, replace_submodule)
30
from vllm.model_executor.models import SupportsLoRA, supports_multimodal
31
from vllm.model_executor.models.module_mapping import MultiModelKeys
32
from vllm.model_executor.models.utils import PPMissingLayer, WeightsMapper
33
from vllm.utils import is_pin_memory_available
34

35
logger = init_logger(__name__)

# Module-level counter backing get_lora_id(); it is incremented before each
# id is returned, so the first id handed out is 1.
_GLOBAL_LORA_ID: int = 0


40
41
42
43
44
45
46
47
48
49
50
51
@dataclass
class LongContextLoRAContext:
    """Context for lora adapters that support long context.

    Attributes:
        scaling_factors: The scaling factors to support long context lora
            fine tuned models.
        rot_dim: Dimension to apply rotary embedding.
        offsets_by_lora_id: Offsets into the sin_cos_cache for each loaded
            lora_id. This mapping is modified dynamically; every instance
            gets its own fresh dict by default.
    """
    scaling_factors: List[float]
    rot_dim: int
    offsets_by_lora_id: Dict[int, int] = field(default_factory=dict)


52
53
54
55
56
57
def get_lora_id():
    """Return the next id from the module-level LoRA id counter.

    The counter is bumped before it is read, so successive calls yield
    1, 2, 3, ... starting from a fresh module.
    """
    global _GLOBAL_LORA_ID
    _GLOBAL_LORA_ID = _GLOBAL_LORA_ID + 1
    return _GLOBAL_LORA_ID


58
class LoRAModel(AdapterModel):
    """A LoRA fine-tuned model."""

    def __init__(
        self,
        lora_model_id: int,
        rank: int,
        loras: Dict[str, LoRALayerWeights],
        scaling_factor: Optional[float] = None,
    ) -> None:
        """
        Args:
            lora_model_id: The integer id for the lora model. Must be > 0.
            rank: lora rank.
            loras: module name -> weights for lora-replaced layers.
            scaling_factor: Scaling factor to support long context lora model.
                None if the lora is not tuned for long context support.
        """
        self.id = lora_model_id
        # Scaling factor for long context lora model. None if it is not
        # fine tuned for the long context.
        self.scaling_factor = scaling_factor
        assert (
            lora_model_id
            > 0), f"a valid lora id should be greater than 0, got {self.id}"
        self.rank = rank
        self.loras: Dict[str, LoRALayerWeights] = loras

    def clone(self, lora_model_id: int) -> "LoRAModel":
        """Return a copy of the object with a different id.

        Only the top-level ``loras`` dict is copied; the per-module weight
        objects (and their tensors) are shared with the original. The long
        context ``scaling_factor`` is carried over so the clone behaves like
        the original (it was previously dropped).
        """
        return self.__class__(
            lora_model_id,
            rank=self.rank,
            loras=self.loras.copy(),
            scaling_factor=self.scaling_factor,
        )

    @property
    def extra_vocab_size(self) -> int:
        """Largest ``extra_vocab_size`` over all layers, or 0 if empty."""
        return max(lora.extra_vocab_size
                   for lora in self.loras.values()) if self.loras else 0

    def get_lora(self, module_name: str) -> Optional[LoRALayerWeights]:
        """Get LoRA for a given module by name"""
        return self.loras.get(module_name, None)

    # (yard1): TODO see if we can derive target_embedding_padding automatically
    @classmethod
    def from_lora_tensors(
        cls,
        lora_model_id: int,
        tensors: Dict[str, torch.Tensor],
        peft_helper: PEFTHelper,
        device: str = "cuda",
        dtype: Optional[torch.dtype] = None,
        embeddings: Optional[Dict[str, torch.Tensor]] = None,
        target_embedding_padding: Optional[int] = None,
        embedding_modules: Optional[Dict[str, str]] = None,
        embedding_padding_modules: Optional[List[str]] = None,
        weights_mapper: Optional[WeightsMapper] = None,
    ) -> "LoRAModel":
        """Create a LoRAModel from a dictionary of tensors.

        Args:
            lora_model_id: Id of the resulting LoRAModel.
            tensors: Checkpoint tensor name -> tensor. Each name is parsed
                with ``parse_fine_tuned_lora_name`` into a module name and a
                lora_a / lora_b / bias role.
            peft_helper: Loaded lora configuration; supplies the rank and the
                long context scaling factor.
            device: Device the weights are moved to. On "cpu" the weights are
                pinned when pinned memory is available.
            dtype: dtype the weights are cast to (None keeps the source dtype).
            embeddings: Optional extra embedding tensors, looked up through
                ``embedding_modules``.
            target_embedding_padding: If given, lora_b of modules listed in
                ``embedding_padding_modules`` is right-padded along dim 1 to
                this size.
            embedding_modules: Module-name substring -> key in ``embeddings``.
                Required when ``embeddings`` is non-empty.
            embedding_padding_modules: Module-name substrings whose lora_b is
                padded. None is treated as an empty list.
            weights_mapper: Optional name remapping applied while parsing
                tensor names.

        Returns:
            The constructed LoRAModel.
        """
        pin_memory = str(device) == "cpu" and is_pin_memory_available()

        def _maybe_pin(t: torch.Tensor) -> torch.Tensor:
            # Pin host memory only when loading onto CPU and it is supported.
            return t.pin_memory() if pin_memory else t

        # Previously None tripped an assert on every lora_b tensor even though
        # None is the declared default; treat it as "no padding modules".
        padding_modules = embedding_padding_modules or []

        loras: Dict[str, LoRALayerWeights] = {}
        for tensor_name, tensor in tensors.items():
            module_name, is_lora_a, is_bias = parse_fine_tuned_lora_name(
                tensor_name, weights_mapper)

            if module_name not in loras:
                lora_embeddings_tensor = None
                if embeddings:
                    assert embedding_modules is not None
                    embeddings_module = next(
                        (k for k in embedding_modules if k in module_name),
                        None)
                    if embeddings_module:
                        lora_embeddings_tensor = _maybe_pin(embeddings[
                            embedding_modules[embeddings_module]].to(
                                device=device, dtype=dtype))
                loras[module_name] = LoRALayerWeights.from_config(
                    module_name, peft_helper, lora_embeddings_tensor)

            # Every checkpoint weight (bias, lora_a, lora_b) is moved/cast and
            # transposed the same way; convert once.
            weight = tensor.to(device=device, dtype=dtype).t()
            if is_bias:
                loras[module_name].bias = _maybe_pin(weight)
            elif is_lora_a:
                loras[module_name].lora_a = _maybe_pin(weight)
            else:
                if target_embedding_padding is not None and any(
                        name in module_name for name in padding_modules):
                    assert target_embedding_padding >= weight.shape[1]
                    addition = target_embedding_padding - weight.shape[1]
                    weight = torch.nn.functional.pad(weight, (0, addition))
                # Pin after padding: pad() allocates a new tensor.
                loras[module_name].lora_b = _maybe_pin(weight)

        for lora in loras.values():
            lora.optimize()

        return cls(lora_model_id,
                   peft_helper.r,
                   loras,
                   scaling_factor=peft_helper.vllm_long_context_scaling_factor)

    @classmethod
    def from_local_checkpoint(
        cls,
        lora_dir: str,
        expected_lora_modules: List[str],
        peft_helper: PEFTHelper,
        *,
        lora_model_id: Optional[int] = None,
        device: str = "cuda",
        dtype: Optional[torch.dtype] = None,
        target_embedding_padding: Optional[int] = None,
        embedding_modules: Optional[Dict[str, str]] = None,
        embedding_padding_modules: Optional[List[str]] = None,
        weights_mapper: Optional[WeightsMapper] = None,
    ) -> "LoRAModel":
        """Create a LoRAModel from a local checkpoint.

        Args:
            lora_dir: The local path that has lora data.
            expected_lora_modules: Name of modules that are expected to be
                replaced by lora.
            peft_helper: Loaded lora configuration information.
            lora_model_id: Lora model id. If not given, automatically set by
                a global counter.
            device: Device where the lora model is loaded.
            dtype: dtype of the lora model weights.
            target_embedding_padding: Forwarded to ``from_lora_tensors``.
            embedding_modules: Forwarded to ``from_lora_tensors``.
            embedding_padding_modules: Forwarded to ``from_lora_tensors``.
            weights_mapper: Used to parse tensor names; also forwarded to
                ``from_lora_tensors``.

        Returns:
            Loaded LoRA Model.

        Raises:
            ValueError: if ``lora_dir`` contains neither
                ``adapter_model.safetensors`` nor ``adapter_model.bin``, or if
                the checkpoint targets modules outside
                ``expected_lora_modules``.
        """
        lora_tensor_path = os.path.join(lora_dir, "adapter_model.safetensors")
        lora_bin_file_path = os.path.join(lora_dir, "adapter_model.bin")
        new_embeddings_tensor_path = os.path.join(
            lora_dir, "new_embeddings.safetensors")
        new_embeddings_bin_file_path = os.path.join(lora_dir,
                                                    "new_embeddings.bin")

        def _unexpected_modules_error(unexpected: List[Any]) -> ValueError:
            # Shared by the safetensors and the .bin code paths.
            return ValueError(
                f"While loading {lora_dir}, expected"
                f" target modules in {expected_lora_modules}"
                f" but received {unexpected}."
                f" Please verify that the loaded LoRA module is correct")

        unexpected_modules: List[Union[list[str], str]]
        if os.path.isfile(lora_tensor_path):
            tensors: Dict[str, torch.Tensor] = {}
            # Find unexpected modules.
            # Use safetensor key as a source of truth to find expected modules.
            # in peft if you have target_modules A, B, C and C does not exist
            # in the model it won't error and model will be trained with A, B
            # loraified. C won't exist in the safetensor but it will exist in
            # the target_modules of the adapter_config.json.
            unexpected_modules = []
            with safetensors.safe_open(lora_tensor_path,
                                       framework="pt") as f:  # type: ignore
                for lora_module in f.keys():  # noqa
                    module_name, _, _ = parse_fine_tuned_lora_name(
                        lora_module, weights_mapper)
                    part_name = module_name.split(".")[-1]
                    if part_name not in expected_lora_modules:
                        unexpected_modules.append(module_name)
                if unexpected_modules:
                    raise _unexpected_modules_error(unexpected_modules)
                # Load tensors if there are only expected modules.
                for module in f.keys():  # noqa
                    tensors[module] = f.get_tensor(module)
        elif os.path.isfile(lora_bin_file_path):
            # When a bin file is provided, we rely on config to find unexpected
            # modules.
            unexpected_modules = []
            target_modules = peft_helper.target_modules
            if not isinstance(target_modules, list):
                target_modules = [target_modules]
            for module in target_modules:
                # Compatible with more modules,
                # such as:layers.11.self_attn.k_proj
                part_name = module.split(".")[-1]
                if part_name not in expected_lora_modules:
                    unexpected_modules.append(module)
            # loaded lora's target modules must be a subset of
            # expected_lora_modules. It is not reliable. See
            # https://github.com/vllm-project/vllm/pull/5909. But there's no
            # other better mechanism.
            if unexpected_modules and not is_regex_target_modules(
                    peft_helper.target_modules, expected_lora_modules):
                raise _unexpected_modules_error(unexpected_modules)
            # weights_only=True restricts unpickling to tensors and plain
            # containers, so a crafted adapter_model.bin cannot run arbitrary
            # code on load (same as the new_embeddings.bin load below).
            tensors = torch.load(lora_bin_file_path,
                                 map_location=device,
                                 weights_only=True)
        else:
            raise ValueError(f"{lora_dir} doesn't contain tensors")

        embeddings = None
        if os.path.isfile(new_embeddings_tensor_path):
            embeddings = safetensors.torch.load_file(
                new_embeddings_tensor_path)
        elif os.path.isfile(new_embeddings_bin_file_path):
            embeddings = torch.load(new_embeddings_bin_file_path,
                                    map_location=device,
                                    weights_only=True)

        return cls.from_lora_tensors(
            lora_model_id=get_lora_id()
            if lora_model_id is None else lora_model_id,
            tensors=tensors,
            peft_helper=peft_helper,
            device=device,
            dtype=dtype,
            embeddings=embeddings,
            target_embedding_padding=target_embedding_padding,
            embedding_modules=embedding_modules,
            embedding_padding_modules=embedding_padding_modules,
            weights_mapper=weights_mapper)


296
class LoRAModelManager(AdapterModelManager):
    """A manager that manages multiple LoRA-fine-tuned models."""

    def __init__(
        self,
        model: SupportsLoRA,
        max_num_seqs: int,
        max_num_batched_tokens: int,
        vocab_size: int,
        lora_config: LoRAConfig,
        device: torch.device,
    ):
        """Create a LoRAModelManager and adapter for a given model.

        Args:
            model: the model to be adapted.
            max_num_seqs: the maximum number of sequences model can run in a
                single batch.
            max_num_batched_tokens: the maximum number of tokens model can run
                in a single batch.
            vocab_size: the vocab size of the model.
            lora_config: the LoRA configuration.
            device: device handed to the punica wrapper.
        """
        self.lora_config = lora_config
        self.device = device
        self.max_num_seqs = max_num_seqs
        assert self.capacity >= self.lora_slots
        # Stored rounded up to the next multiple of 8.
        self.max_num_batched_tokens = math.ceil(max_num_batched_tokens / 8) * 8
        # Slot index -> id of the LoRA occupying that slot (None when free).
        self.lora_index_to_id: List[Optional[int]] = [None] * self.lora_slots
        self.vocab_size = vocab_size
        self.long_lora_context: Optional[LongContextLoRAContext] = None
        # NOTE(review): the un-rounded max_num_batched_tokens is passed here
        # rather than self.max_num_batched_tokens; kept as-is — confirm.
        self.punica_wrapper = get_punica_wrapper(max_num_batched_tokens,
                                                 max_batches=self.max_num_seqs,
                                                 device=self.device)
        # Scaling factor -> offset to the sin_cos_cache to it.
        # Used for long context lora.
        self.scaling_factor_to_offset: Dict[float, int] = {}
        super().__init__(model)
        if hasattr(self.model, "supported_lora_modules"):
            self.supported_lora_modules = copy.deepcopy(
                self.model.supported_lora_modules)
            if lora_config.long_lora_scaling_factors:
                # We need to replace rotary emb layer to do batch computation
                # for long lora.
                self.supported_lora_modules.append("rotary_emb")
            self.packed_modules_mapping = copy.deepcopy(
                self.model.packed_modules_mapping)
        # Used to indicate whether the model is a multimodal model
        self.supports_mm: bool = (
            supports_multimodal(self.model)
            # In case the model only supports LoRA for
            # text modules (e.g. ChatGLM)
            and hasattr(self.model, "get_mm_mapping"))
        # Full name of a packed module -> full names of its sub-modules.
        self.packed_modules: Dict[str, List[str]] = {}
        # Full module name -> the LoRA-wrapped layer registered for it.
        self.modules: Dict[str, BaseLayerWithLoRA] = {}
        # Mapping cached by set_adapter_mapping(); None until first set.
        self._last_mapping: Optional[LoRAMapping] = None
        self._create_lora_modules()
        self.model.lora_manager = self
        self.adapter_type = 'LoRa'

    @property
    def capacity(self) -> int:
        """Maximum number of LoRAs kept registered (``max_cpu_loras``)."""
        return self.lora_config.max_cpu_loras

    @property
    def lora_slots(self) -> int:
        """Number of LoRAs that can be active at once (``max_loras``)."""
        return self.lora_config.max_loras

    @property
    def adapter_slots(self) -> int:
        """Alias of ``lora_slots``."""
        return self.lora_slots

    @staticmethod
    def _has_lora_bias(bias: Any) -> bool:
        """True if ``bias`` is a tensor or a sequence with a non-None entry."""
        return torch.is_tensor(bias) or (isinstance(bias, Sequence) and any(
            b is not None for b in bias))

    def activate_adapter(
        self,
        lora_id: int,
    ) -> bool:
        """Move LoRA into a GPU buffer to be used in the forward pass.

        Returns:
            False if ``lora_id`` is already active, True after its weights
            have been written into a free slot.

        Raises:
            ValueError: if no slot is free, or if the adapter carries a bias
                while ``lora_config.bias_enabled`` is False. Slot and
                active-adapter state are left untouched when this is raised.
        """
        if lora_id in self._active_adapters:
            return False
        index = next((slot for slot, slot_lora_id in enumerate(
            self.lora_index_to_id) if slot_lora_id is None), None)
        if index is None:
            raise ValueError("No free lora slots")
        lora_model = self._registered_adapters[lora_id]

        # Bias is not explicitly enabled with the flag enable_lora_bias.
        # Validate before mutating any state: previously the id was already
        # marked active and a slot claimed when this error was raised, which
        # left a half-loaded adapter that later activations treated as active.
        if not self.lora_config.bias_enabled:
            for module_name in self.modules:
                module_lora = lora_model.get_lora(module_name)
                if module_lora and self._has_lora_bias(module_lora.bias):
                    raise ValueError(
                        f"Adapter bias cannot be used for {module_name}"
                        " without --enable-lora-bias.")

        self._active_adapters[lora_id] = None
        logger.debug("Activating LoRA. int id: %d, slot index: %d",
                     lora_model.id, index)
        self.lora_index_to_id[index] = lora_model.id
        for module_name, module in self.modules.items():
            module_lora = lora_model.get_lora(module_name)
            if module_lora:
                module_lora.optimize()
                module.set_lora(index, module_lora.lora_a, module_lora.lora_b,
                                module_lora.embeddings_tensor,
                                module_lora.bias)
            else:
                # This LoRA does not touch the module: clear the slot.
                module.reset_lora(index)
        return True

    def _deactivate_adapter(self, lora_id: int):
        """Free the slot held by ``lora_id``; no-op if it holds none."""
        try:
            index = self.lora_index_to_id.index(lora_id)
        except ValueError:
            return
        self.lora_index_to_id[index] = None

    def _set_long_lora_context(self, lora: LoRAModel):
        """Record the sin/cos cache offset of a long-context LoRA.

        No-op when the model has no long-context rotary layer or the LoRA has
        no scaling factor.

        Raises:
            ValueError: if the LoRA's scaling factor has no registered offset.
        """
        if self.long_lora_context is None:
            return

        if lora.scaling_factor is None:
            return

        if lora.scaling_factor not in self.scaling_factor_to_offset:
            raise ValueError(f"Long LoRA scaling factor {lora.scaling_factor}"
                             " has not been initialized.")

        # Store the offset even when it is 0; a truthiness check used to skip
        # recording a zero offset.
        self.long_lora_context.offsets_by_lora_id[lora.id] = (
            self.scaling_factor_to_offset[lora.scaling_factor])

    def _add_adapter(self, lora: LoRAModel):
        """Pack sub-module LoRAs, register the model and its long-LoRA offset."""
        self._create_merged_loras_inplace(lora)
        self._registered_adapters[lora.id] = lora
        self._set_long_lora_context(lora)

    def pin_adapter(self, lora_id: int) -> bool:
        """Pin a LoRAModel in the manager cache."""
        raise NotImplementedError(
            "Pinning is not supported in LoRAModelManager. "
            "Use LRUCacheLoRAModelManager for pinning")  # type: ignore

    def _set_adapter_mapping(self, mapping: LoRAMapping) -> None:
        """Push ``mapping`` and current slot state into the punica wrapper."""
        # update lora states
        self.punica_wrapper.update_metadata(
            mapping,
            self.lora_index_to_id,
            self.lora_slots + 1,
            self.vocab_size,
            self.lora_config.lora_extra_vocab_size,
            self.long_lora_context,
        )

    def remove_all_adapters(self):
        """Remove all LoRAModels from the manager."""
        self._registered_adapters.clear()
        self.lora_index_to_id = [None] * self.lora_slots
        self._active_adapters.clear()

    def _create_lora_modules(self):
        """Replace each supported sub-module of ``self.model`` with its LoRA
        wrapper and register the wrapper with this manager."""
        for module_name, module in self.model.named_modules(
                remove_duplicate=False):
            if isinstance(module, PPMissingLayer):
                continue
            if not self._match_target_modules(module_name):
                continue
            # A temporary approach for multimodal models to support LoRA
            # TODO: Remove this restriction
            if self._filter_unsupported_mm_module(module_name):
                logger.warning(
                    "Regarding multimodal models, vLLM currently only supports "
                    "adding LoRA to language model, %s will be ignored.",
                    module_name,
                )
                continue
            leaf_name = module_name.split(".")[-1]
            packed_module_list = self.packed_modules_mapping.get(leaf_name, [])
            new_module = replace_submodule(
                self.model, module_name,
                from_layer(module, self.lora_slots, self.lora_config,
                           packed_module_list, self.model.config))

            # LinearScalingRotaryEmbeddingWithLora is used to handle
            # long context lora. Register relevant metadata.
            if isinstance(new_module, LinearScalingRotaryEmbeddingWithLora):
                self.long_lora_context = LongContextLoRAContext(
                    new_module.scaling_factors, new_module.rotary_dim)
                self.scaling_factor_to_offset = \
                    new_module.scaling_factor_to_offset
            # (yard1): TODO make this more robust
            if "lm_head" in module_name:
                logits_processor_module = self.model.get_submodule(
                    "logits_processor")
                new_module = replace_submodule(
                    self.model, "logits_processor",
                    from_layer_logits_processor(logits_processor_module,
                                                module, self.lora_slots,
                                                self.lora_config,
                                                self.model.config))

            # In some models, especially multimodal ones, layers with the same
            # name may have different types, such as nn.Linear and
            # ReplicatedLinear. The nn.Linear layers cannot be replaced with
            # LoRA layers, leading to assertion error. The following check
            # aims to prevent this error
            if self.supports_mm and not isinstance(new_module,
                                                   BaseLayerWithLoRA):
                continue
            self.register_module(module_name, new_module)
            self._register_packed_modules(module_name)
            # All lora layers share the same punica_wrapper based on reference.
            new_module.set_mapping(self.punica_wrapper)

    def register_module(self, module_name: str, module: "BaseLayerWithLoRA"):
        """Record ``module`` as the LoRA layer for ``module_name``."""
        assert isinstance(module, BaseLayerWithLoRA)
        self.modules[module_name] = module

    def create_dummy_lora(
            self,
            lora_id: int,
            rank: int,
            scaling_factor: Optional[float],
            embedding_modules: Optional[Dict[str, str]] = None) -> LoRAModel:
        """Create zero-initialized LoRAModel for warmup.

        Dummy weights are created on "cpu" for every registered LoRA layer
        except long-context rotary layers and filtered multimodal modules.
        ``embedding_modules`` is required whenever a non-packed module is
        visited.
        """
        model = LoRAModel(lora_id, rank, {}, scaling_factor)
        bias_enabled = self.lora_config.bias_enabled
        for module_name, module in self.model.named_modules():
            if (not self._match_target_modules(module_name)
                    or not isinstance(module, BaseLayerWithLoRA)
                    or isinstance(module, LinearScalingRotaryEmbeddingWithLora)
                    or self._filter_unsupported_mm_module(module_name)):
                continue
            leaf_name = module_name.split(".")[-1]
            if module_name not in self.packed_modules:
                assert embedding_modules is not None
                if leaf_name in embedding_modules:
                    base_layer = module.base_layer
                    input_dim = (base_layer.org_vocab_size +
                                 self.lora_config.lora_extra_vocab_size
                                 if hasattr(base_layer, "org_vocab_size") else
                                 base_layer.weight.shape[1])
                    output_dim = (base_layer.embedding_dim
                                  if hasattr(base_layer, "embedding_dim") else
                                  base_layer.weight.shape[0])
                    embeddings_tensor_dim = (base_layer.embedding_dim
                                             if hasattr(base_layer,
                                                        "embedding_dim") else
                                             base_layer.weight.shape[1])
                    lora = LoRALayerWeights.create_dummy_lora_weights(
                        module_name,
                        input_dim,
                        output_dim,
                        rank,
                        module.lora_a_stacked[0].dtype,
                        "cpu",
                        embeddings_tensor_dim=embeddings_tensor_dim,
                        bias_enabled=bias_enabled)
                else:
                    lora = LoRALayerWeights.create_dummy_lora_weights(
                        module_name,
                        module.lora_a_stacked[0].shape[-1],
                        module.lora_b_stacked[0].shape[-2],
                        rank,
                        module.lora_a_stacked[0].dtype,
                        "cpu",
                        bias_enabled=bias_enabled,
                    )
                lora.optimize()
            else:
                replacements = self.packed_modules_mapping[leaf_name]
                subloras: List[Optional[LoRALayerWeights]] = []
                for i, r in enumerate(replacements):
                    lora = LoRALayerWeights.create_dummy_lora_weights(
                        module_name + "." + r,
                        module.lora_a_stacked[i].shape[-1],
                        module.lora_b_stacked[i].shape[-2],
                        rank,
                        module.lora_a_stacked[i].dtype,
                        "cpu",
                        bias_enabled=bias_enabled,
                    )
                    lora.optimize()
                    subloras.append(lora)
                lora = PackedLoRALayerWeights.pack(subloras)
            model.loras[module_name] = lora
        return model

    def _match_target_modules(self, module_name: str):
        """True if ``module_name`` equals a supported module name or ends
        with ``.<name>`` for one.

        NOTE(review): names are interpolated into a regex unescaped, so any
        regex metacharacters in them act as patterns.
        """
        return any(
            re.match(rf".*\.{target_module}$", module_name)
            or target_module == module_name
            for target_module in self.supported_lora_modules)

    def _filter_unsupported_mm_module(self, module_name: str) -> bool:
        """
        Regarding multimodal models, vLLM currently only supports adding LoRA to
        language model. LoRA for other modules, such as the vision tower, will
        be filtered out.
        """
        if self.supports_mm:
            module_mapping: MultiModelKeys = self.model.get_mm_mapping()
            prefix_lst = module_mapping.connector + module_mapping.tower_model
            return any(module_name.startswith(prefix) for prefix in prefix_lst)
        return False

    def _register_packed_modules(self, module_full_name: str) -> None:
        """Record the full sub-module names of a packed module, if packed."""
        parts = module_full_name.split(".")
        module_name = parts[-1]
        replacements = self.packed_modules_mapping.get(module_name, [])
        # When replacements is less than or equal to 1, it indicates that this
        # module is not a packed module.
        if len(replacements) <= 1:
            return
        prefix = ".".join(parts[:-1])
        self.packed_modules[module_full_name] = [
            prefix + "." + r if prefix else r for r in replacements
        ]

    def _create_merged_loras_inplace(self, lora_model: LoRAModel) -> None:
        """Add a packed entry to ``lora_model`` for every packed module that
        has LoRA weights for at least one of its sub-modules; sub-modules
        without weights are represented as None in the pack."""
        for module_name, new_module_names in self.packed_modules.items():
            replacement_loras: List[Optional[LoRALayerWeights]] = [
                lora_model.get_lora(r) or None for r in new_module_names
            ]
            if not any(replacement_loras):
                continue
            lora_model.loras[module_name] = PackedLoRALayerWeights.pack(
                replacement_loras)

    def deactivate_adapter(self, adapter_id: int) -> bool:
        return deactivate_adapter(adapter_id, self._active_adapters,
                                  self._deactivate_adapter)

    def add_adapter(self, adapter: LoRAModel) -> bool:
        logger.debug(
            "Adding lora. Model id: %d, "
            "int id: %d, "
            "scaling factor: %s", adapter.id, adapter.id,
            adapter.scaling_factor)
        return add_adapter(adapter, self._registered_adapters, self.capacity,
                           self._add_adapter)

    def set_adapter_mapping(self, mapping: LoRAMapping) -> None:
        self._last_mapping = set_adapter_mapping(mapping, self._last_mapping,
                                                 self._set_adapter_mapping)

    def remove_adapter(self, adapter_id: int) -> bool:
        return remove_adapter(adapter_id, self._registered_adapters,
                              self.deactivate_adapter)

    def list_adapters(self) -> Dict[int, Any]:
        return list_adapters(self._registered_adapters)

    def get_adapter(self, adapter_id: int) -> Optional[Any]:
        return get_adapter(adapter_id, self._registered_adapters)


class LoRALRUCache(AdapterLRUCache[LoRAModel]):
    """``AdapterLRUCache`` specialised to hold ``LoRAModel`` values.

    Adds no behaviour of its own: it only fixes the value type and forwards
    its constructor arguments to the base class.
    """

    def __init__(self, capacity: int,
                 deactivate_lora_fn: Callable[[int], bool]):
        """Create the cache.

        Args:
            capacity: Cache capacity, forwarded to ``AdapterLRUCache``.
            deactivate_lora_fn: Callback taking an adapter id and returning a
                bool; forwarded to ``AdapterLRUCache``, which presumably
                invokes it when an entry is evicted (confirm in the base
                class).
        """
        super().__init__(capacity, deactivate_lora_fn)
670
671
672
673
674


class LRUCacheLoRAModelManager(LoRAModelManager):
    """A model manager that manages multiple LoRAs with LRU cache.

    Both the registered-adapter table (the "CPU cache") and the
    active-adapter table (the "GPU cache") are ``LoRALRUCache`` instances.
    Activating an adapter when all ``lora_slots`` are in use first evicts
    the least-recently-used active adapter.
    """

    def __init__(self, model: nn.Module, max_num_seqs: int,
                 max_num_batched_tokens: int, vocab_size: int,
                 lora_config: LoRAConfig, device: torch.device):
        super().__init__(model, max_num_seqs, max_num_batched_tokens,
                         vocab_size, lora_config, device)
        # Registered adapters: sized by ``capacity``; the public
        # ``deactivate_adapter`` is passed as the cache's callback.
        self._registered_adapters: LoRALRUCache = LoRALRUCache(
            self.capacity, self.deactivate_adapter)
        # Active adapters: sized by ``lora_slots``; the private
        # ``_deactivate_adapter`` hook is passed as the cache's callback.
        self._active_adapters: LoRALRUCache = LoRALRUCache(
            self.lora_slots, self._deactivate_adapter)

    def list_adapters(self) -> Dict[int, LoRAModel]:
        """List all registered LoRAModels."""
        return dict(self._registered_adapters.cache)

    def add_adapter(self, lora: LoRAModel) -> bool:
        """Add a LoRAModel to the manager.

        Returns:
            True if ``lora`` was newly registered, False if an adapter with
            the same id was already registered (its LRU position is then
            refreshed instead).
        """
        logger.debug(
            "Adding lora. Model id: %d, "
            "int id: %d, "
            "scaling factor: %s", lora.id, lora.id, lora.scaling_factor)
        if lora.id not in self._registered_adapters:
            self._add_adapter(lora)
            return True
        # We always touch to update the LRU cache order
        self._registered_adapters.touch(lora.id)
        return False

    def activate_adapter(
        self,
        lora_id: int,
    ) -> bool:
        """Activate ``lora_id``, evicting the LRU active adapter if full.

        Returns:
            Whatever ``LoRAModelManager.activate_adapter`` returns.
        """
        # Free a slot up front so the base implementation has room.
        if lora_id not in self._active_adapters and len(
                self._active_adapters) >= self.lora_slots:
            self._active_adapters.remove_oldest()
        result = super().activate_adapter(lora_id)
        # We always touch to update the LRU cache order
        self._active_adapters.touch(lora_id)
        return result

    def remove_oldest_adapter(self) -> bool:
        """Evict the least-recently-used registered adapter.

        Returns:
            True if an eviction was attempted, False if nothing is
            registered.
        """
        if len(self._registered_adapters) > 0:
            self._registered_adapters.remove_oldest()
            return True
        return False

    def pin_adapter(self, lora_id: int) -> bool:
        """Pin a LoRAModel in the manager cache.

        Pins ``lora_id`` in the registered (CPU) cache, then activates it if
        needed and pins it in the active (GPU) cache. Always returns True;
        failures surface as exceptions.
        """
        self._pin_lora_in_cpu_cache(lora_id)
        self._pin_lora_in_gpu_cache(lora_id)
        return True

    def _pin_lora_in_cpu_cache(self, lora_id: int) -> None:
        """Pin ``lora_id`` in the registered-adapter cache.

        Raises:
            ValueError: If ``lora_id`` is not registered (re-raised from the
                cache's own ValueError).
        """
        try:
            self._registered_adapters.pin(lora_id)
        except ValueError as err:
            raise ValueError("Pinning failed. "
                             f"LoRA {lora_id} is not registered.") from err

    def _pin_lora_in_gpu_cache(self, lora_id: int) -> None:
        """Ensure ``lora_id`` is active, then pin it in the active cache."""
        if lora_id not in self._active_adapters:
            # move lora to gpu if not already active
            self.activate_adapter(lora_id)

        self._active_adapters.pin(lora_id)
741

742
743
744
745
746
747
748

def create_lora_manager(
        model: nn.Module,
        max_num_seqs: int,
        max_num_batched_tokens: int,
        vocab_size: int,
        lora_config: LoRAConfig,
        device: torch.device,
        lora_manager_cls: Type[LoRAModelManager] = LoRAModelManager,
        **kwargs) -> LoRAModelManager:
    """Create a LoRA model manager for a given model.

    Args:
        model: Model whose LoRA adapters will be managed. It must expose a
            ``supported_lora_modules`` attribute.
        max_num_seqs: Forwarded to ``lora_manager_cls``.
        max_num_batched_tokens: Forwarded to ``lora_manager_cls``.
        vocab_size: Forwarded to ``lora_manager_cls``.
        lora_config: Forwarded to ``lora_manager_cls``.
        device: Forwarded to ``lora_manager_cls``.
        lora_manager_cls: Manager class to instantiate; defaults to
            ``LoRAModelManager``.
        **kwargs: Extra keyword arguments forwarded to ``lora_manager_cls``.

    Returns:
        The newly constructed ``lora_manager_cls`` instance.

    Raises:
        ValueError: If ``model`` has no ``supported_lora_modules`` attribute.
    """
    # Cheap capability gate: reject models that do not advertise any
    # LoRA-supported modules before constructing a manager.
    if not hasattr(model, "supported_lora_modules"):
        raise ValueError(f"Model {type(model)} is not supported for LoRA.")
    return lora_manager_cls(
        model=model,
        max_num_seqs=max_num_seqs,
        max_num_batched_tokens=max_num_batched_tokens,
        vocab_size=vocab_size,
        lora_config=lora_config,
        device=device,
        **kwargs)