models.py 31.4 KB
Newer Older
1
2
3
4
5
import copy
import json
import math
import os
import re
6
from dataclasses import dataclass, field
7
from typing import Any, Callable, Dict, List, Optional, Type
8
9
10
11
12

import safetensors.torch
import torch
from torch import nn

13
14
15
16
17
from vllm.adapter_commons.models import (AdapterLRUCache, AdapterModel,
                                         AdapterModelManager)
from vllm.adapter_commons.utils import (add_adapter, deactivate_adapter,
                                        get_adapter, list_adapters,
                                        remove_adapter, set_adapter_mapping)
18
from vllm.config import LoRAConfig
19
from vllm.logger import init_logger
20
21
22
from vllm.lora.layers import (BaseLayerWithLoRA,
                              LinearScalingRotaryEmbeddingWithLora,
                              LoRAMapping)
23
from vllm.lora.lora import LoRALayerWeights, PackedLoRALayerWeights
24
from vllm.lora.punica import PunicaWrapper
25
from vllm.lora.utils import (from_layer, from_layer_logits_processor,
26
                             is_regex_target_modules,
27
                             parse_fine_tuned_lora_name, replace_submodule)
28
from vllm.model_executor.models import SupportsLoRA, supports_multimodal
29
from vllm.model_executor.models.module_mapping import MultiModelKeys
30
from vllm.model_executor.models.utils import PPMissingLayer
31
from vllm.utils import is_pin_memory_available
32

33
logger = init_logger(__name__)

# Module-wide counter read and bumped by get_lora_id(). It starts at 0 and is
# incremented before each id is returned, so auto-assigned ids begin at 1.
_GLOBAL_LORA_ID = 0


38
39
40
41
42
43
44
45
46
47
48
49
@dataclass
class LongContextLoRAContext:
    """Context for lora adapters that support long context.

    Attributes:
        scaling_factors: scaling factors of the long-context fine-tuned LoRAs.
        rot_dim: dimension the rotary embedding is applied to.
        offsets_by_lora_id: lora id -> offset into the sin/cos cache. Starts
            empty (a fresh dict per instance) and is filled in at runtime as
            adapters are added.
    """
    scaling_factors: List[float]
    rot_dim: int
    offsets_by_lora_id: Dict[int, int] = field(default_factory=dict)


50
51
52
53
54
55
def get_lora_id():
    """Hand out the next integer LoRA id.

    Bumps the module-level ``_GLOBAL_LORA_ID`` counter by one and returns the
    new value, so consecutive calls yield consecutive integers.
    """
    global _GLOBAL_LORA_ID
    next_id = _GLOBAL_LORA_ID + 1
    _GLOBAL_LORA_ID = next_id
    return next_id


56
class LoRAModel(AdapterModel):
    """A LoRA fine-tuned model."""

    def __init__(
        self,
        lora_model_id: int,
        rank: int,
        loras: Dict[str, LoRALayerWeights],
        scaling_factor: Optional[float] = None,
    ) -> None:
        """
        Args:
            lora_model_id: The integer id for the lora model. Must be > 0.
            rank: lora rank.
            loras: module name -> weights for lora-replaced layers.
            scaling_factor: Scaling factor to support long context lora model.
                None if the lora is not tuned for long context support.
        """
        self.id = lora_model_id
        # Scaling factor for long context lora model. None if it is not
        # fine tuned for the long context.
        self.scaling_factor = scaling_factor
        assert (lora_model_id >
                0), f"a valid lora id should be greater than 0, got {self.id}"
        self.rank = rank
        self.loras: Dict[str, LoRALayerWeights] = loras

    def clone(self, lora_model_id: int) -> "LoRAModel":
        """Return a copy of the object with a different id.

        The ``loras`` dict is shallow-copied, so the clone shares the
        underlying tensors. The long-context ``scaling_factor`` is carried
        over as well; previously it was dropped, so a cloned long-context LoRA
        silently became a regular one.
        """
        return self.__class__(
            lora_model_id,
            rank=self.rank,
            loras=self.loras.copy(),
            scaling_factor=self.scaling_factor,
        )

    @property
    def extra_vocab_size(self) -> int:
        """Largest ``extra_vocab_size`` over all layers, or 0 if no layers."""
        return max(lora.extra_vocab_size
                   for lora in self.loras.values()) if self.loras else 0

    def get_lora(self, module_name: str) -> Optional[LoRALayerWeights]:
        """Get LoRA for a given module by name (None if absent)."""
        return self.loras.get(module_name, None)

    # (yard1): TODO see if we can derive target_embedding_padding automatically
    @classmethod
    def from_lora_tensors(
        cls,
        lora_model_id: int,
        rank: int,
        lora_alpha: int,
        tensors: Dict[str, torch.Tensor],
        device: str = "cuda",
        dtype: Optional[torch.dtype] = None,
        embeddings: Optional[Dict[str, torch.Tensor]] = None,
        target_embedding_padding: Optional[int] = None,
        scaling_factor: Optional[float] = None,
        embedding_modules: Optional[Dict[str, str]] = None,
        embedding_padding_modules: Optional[List[str]] = None,
    ) -> "LoRAModel":
        """Create a LoRAModel from a dictionary of tensors.

        Args:
            lora_model_id: The integer id for the lora model.
            rank: lora rank.
            lora_alpha: lora alpha, forwarded to every LoRALayerWeights.
            tensors: checkpoint tensor name -> tensor. Each name is split by
                ``parse_fine_tuned_lora_name`` into a module name and a flag
                telling whether it is the A or the B matrix.
            device: Device the tensors are moved to.
            dtype: dtype the tensors are cast to.
            embeddings: Optional extra embedding tensors, looked up through
                ``embedding_modules``.
            target_embedding_padding: If set, lora_b of modules whose name
                contains an entry of ``embedding_padding_modules`` is
                right-padded along its last dimension to this size.
            scaling_factor: Long context scaling factor stored on the model.
            embedding_modules: key that must appear as a substring of the
                module name -> key into ``embeddings``. Required (asserted)
                when ``embeddings`` is non-empty.
            embedding_padding_modules: Module-name substrings whose lora_b may
                need padding. Required (asserted) whenever a B tensor is seen.

        Returns:
            The assembled LoRAModel; ``optimize()`` is called on every layer.
        """
        # Pinned host memory is only meaningful for CPU-resident tensors.
        pin_memory = str(device) == "cpu" and is_pin_memory_available()
        loras: Dict[str, LoRALayerWeights] = {}
        for tensor_name, tensor in tensors.items():
            module_name, is_lora_a = parse_fine_tuned_lora_name(tensor_name)
            if module_name not in loras:
                lora_embeddings_tensor = None
                if embeddings:
                    assert embedding_modules is not None
                    # First embedding-module key contained in this module name.
                    embeddings_module = next(
                        (k for k in embedding_modules if k in module_name),
                        None)
                    if embeddings_module:
                        lora_embeddings_tensor = embeddings[
                            embedding_modules[embeddings_module]].to(
                                device=device, dtype=dtype)
                        if pin_memory:
                            lora_embeddings_tensor = (
                                lora_embeddings_tensor.pin_memory())
                # A and B are filled in below as their tensors are visited.
                loras[module_name] = LoRALayerWeights(module_name, rank,
                                                      lora_alpha, None, None,
                                                      lora_embeddings_tensor)
            if is_lora_a:
                loras[module_name].lora_a = tensor.to(device=device,
                                                      dtype=dtype).t()
                if pin_memory:
                    loras[module_name].lora_a = loras[
                        module_name].lora_a.pin_memory()
            else:
                loras[module_name].lora_b = tensor.to(device=device,
                                                      dtype=dtype).t()
                assert embedding_padding_modules is not None
                if any(name in module_name
                       for name in embedding_padding_modules
                       ) and target_embedding_padding is not None:
                    lora_b = loras[module_name].lora_b
                    assert target_embedding_padding >= lora_b.shape[1]
                    addition = target_embedding_padding - lora_b.shape[1]
                    loras[module_name].lora_b = torch.nn.functional.pad(
                        lora_b, (0, addition))
                if pin_memory:
                    loras[module_name].lora_b = loras[
                        module_name].lora_b.pin_memory()

        for lora in loras.values():
            lora.optimize()
        return cls(lora_model_id, rank, loras, scaling_factor=scaling_factor)

    @classmethod
    def from_local_checkpoint(
        cls,
        lora_dir: str,
        expected_lora_modules: List[str],
        *,
        max_position_embeddings: Optional[int] = None,
        lora_model_id: Optional[int] = None,
        device: str = "cuda",
        dtype: Optional[torch.dtype] = None,
        target_embedding_padding: Optional[int] = None,
        embedding_modules: Optional[Dict[str, str]] = None,
        embedding_padding_modules: Optional[List[str]] = None,
    ) -> "LoRAModel":
        """Create a LoRAModel from a local checkpoint.

        Args:
            lora_dir: The local path that has lora data.
            expected_lora_modules: Name of modules that are expected to be
                replaced by lora.
            max_position_embeddings: Max position embedding length. Used to
                scale the largest context length. If None, the lora model's
                context length is not scaled.
            lora_model_id: Lora model id. If not given, automatically set by
                a global counter.
            device: Device where the lora model is loaded.
            dtype: dtype of the lora model weights.
            target_embedding_padding: Forwarded to ``from_lora_tensors``.
            embedding_modules: Forwarded to ``from_lora_tensors``.
            embedding_padding_modules: Forwarded to ``from_lora_tensors``.

        Returns:
            Loaded LoRA Model.

        Raises:
            ValueError: if the checkpoint targets modules outside
                ``expected_lora_modules``, or if ``lora_dir`` contains neither
                ``adapter_model.safetensors`` nor ``adapter_model.bin``.
        """
        lora_config_path = os.path.join(lora_dir, "adapter_config.json")
        lora_tensor_path = os.path.join(lora_dir, "adapter_model.safetensors")
        lora_bin_file_path = os.path.join(lora_dir, "adapter_model.bin")
        new_embeddings_tensor_path = os.path.join(
            lora_dir, "new_embeddings.safetensors")
        new_embeddings_bin_file_path = os.path.join(lora_dir,
                                                    "new_embeddings.bin")
        # JSON is UTF-8 by specification; don't depend on the platform locale.
        with open(lora_config_path, encoding="utf-8") as f:
            config = json.load(f)

        if os.path.isfile(lora_tensor_path):
            tensors: Dict[str, torch.Tensor] = {}
            # Find unexpected modules.
            # Use safetensor key as a source of truth to find expected modules.
            # in peft if you have target_modules A, B, C and C does not exist
            # in the model it won't error and model will be trained with A, B
            # loraified. C won't exist in the safetensor but it will exist in
            # the target_modules of the adapter_config.json.
            unexpected_modules = []
            with safetensors.safe_open(lora_tensor_path,
                                       framework="pt") as st_file:  # type: ignore
                for lora_module in st_file.keys():  # noqa
                    module_name, _ = parse_fine_tuned_lora_name(lora_module)
                    part_name = module_name.split(".")[-1]
                    if part_name not in expected_lora_modules:
                        unexpected_modules.append(module_name)
                if unexpected_modules:
                    raise ValueError(
                        f"While loading {lora_dir}, expected"
                        f" target modules in {expected_lora_modules}"
                        f" but received {unexpected_modules}."
                        f" Please verify that the loaded LoRA module is correct"
                    )
                # Load tensors if there are only expected modules.
                for module in st_file.keys():  # noqa
                    tensors[module] = st_file.get_tensor(module)
        elif os.path.isfile(lora_bin_file_path):
            # When a bin file is provided, we rely on config to find unexpected
            # modules.
            unexpected_modules = []
            target_modules = config["target_modules"]
            if not isinstance(target_modules, list):
                target_modules = [target_modules]
            for module in target_modules:
                # Compatible with more modules,
                # such as:layers.11.self_attn.k_proj
                part_name = module.split(".")[-1]
                if part_name not in expected_lora_modules:
                    unexpected_modules.append(module)
            # loaded lora's target modules must be a subset of
            # expected_lora_modules. It is not reliable. See
            # https://github.com/vllm-project/vllm/pull/5909. But there's no
            # other better mechanism.
            if unexpected_modules and not is_regex_target_modules(
                    config["target_modules"], expected_lora_modules):
                raise ValueError(
                    f"While loading {lora_dir}, expected"
                    f" target modules in {expected_lora_modules}"
                    f" but received {unexpected_modules}."
                    f" Please verify that the loaded LoRA module is correct")
            # SECURITY NOTE(review): torch.load unpickles the file and can
            # execute arbitrary code if lora_dir is untrusted. Consider
            # weights_only=True once the supported torch versions allow it.
            tensors = torch.load(lora_bin_file_path, map_location=device)
        else:
            raise ValueError(f"{lora_dir} doesn't contain tensors")

        embeddings = None
        if os.path.isfile(new_embeddings_tensor_path):
            embeddings = safetensors.torch.load_file(
                new_embeddings_tensor_path)
        elif os.path.isfile(new_embeddings_bin_file_path):
            # SECURITY NOTE(review): same pickle caveat as above.
            embeddings = torch.load(new_embeddings_bin_file_path,
                                    map_location=device)

        rank = config["r"]
        lora_alpha = config["lora_alpha"]
        context_length = config.get("context_length", None)
        scaling_factor = None
        if context_length:
            if max_position_embeddings is None:
                max_position_embeddings = context_length
            # Integer ratio (rounded up) of the adapter's trained context
            # length to the base model's, stored as a float.
            scaling_factor = float(
                math.ceil(context_length / max_position_embeddings))

        return cls.from_lora_tensors(
            lora_model_id=get_lora_id()
            if lora_model_id is None else lora_model_id,
            rank=rank,
            lora_alpha=lora_alpha,
            tensors=tensors,
            device=device,
            dtype=dtype,
            embeddings=embeddings,
            target_embedding_padding=target_embedding_padding,
            scaling_factor=scaling_factor,
            embedding_modules=embedding_modules,
            embedding_padding_modules=embedding_padding_modules,
        )


294
class LoRAModelManager(AdapterModelManager):
    """A manager that manages multiple LoRA-fine-tuned models."""

    def __init__(
        self,
        model: SupportsLoRA,
        max_num_seqs: int,
        max_num_batched_tokens: int,
        vocab_size: int,
        lora_config: LoRAConfig,
    ):
        """Create a LoRAModelManager and adapter for a given model.

        Args:
            model: the model to be adapted.
            max_num_seqs: the maximum number of sequences model can run in a
                single batch.
            max_num_batched_tokens: the maximum number of tokens model can run
                in a single batch.
            vocab_size: the vocab size of the model.
            lora_config: the LoRA configuration.
        """
        self.lora_config = lora_config
        self.max_num_seqs = max_num_seqs
        assert self.capacity >= self.lora_slots
        # Rounded up to a multiple of 8. Note that the PunicaWrapper below is
        # constructed with the unrounded value.
        self.max_num_batched_tokens = math.ceil(max_num_batched_tokens / 8) * 8
        # GPU slot index -> id of the LoRA occupying it (None means free).
        self.lora_index_to_id: List[Optional[int]] = [None] * self.lora_slots
        self.vocab_size = vocab_size
        self.long_lora_context: Optional[LongContextLoRAContext] = None
        self.punica_wrapper = PunicaWrapper(max_num_batched_tokens,
                                            max_batches=self.max_num_seqs,
                                            device="cuda")
        # Scaling factor -> offset to the sin_cos_cache to it.
        # Used for long context lora.
        self.scaling_factor_to_offset: Dict[float, int] = {}
        super().__init__(model)
        if hasattr(self.model, "supported_lora_modules"):
            self.supported_lora_modules = copy.deepcopy(
                self.model.supported_lora_modules)
            if lora_config.long_lora_scaling_factors:
                # We need to replace rotary emb layer to do batch computation
                # for long lora.
                self.supported_lora_modules.append("rotary_emb")
            self.packed_modules_mapping = copy.deepcopy(
                self.model.packed_modules_mapping)
        # Used to indicate whether the model is a multimodal model
        self.supports_mm: bool = supports_multimodal(self.model)
        # Full name of a packed module -> full names of its sub-modules.
        self.packed_modules: Dict[str, List[str]] = {}
        # Full module name -> LoRA-wrapped layer.
        self.modules: Dict[str, "BaseLayerWithLoRA"] = {}
        # Last mapping handed to set_adapter_mapping.
        self._last_mapping: Optional[LoRAMapping] = None
        self._create_lora_modules()
        self.model.lora_manager = self
        self.adapter_type = 'LoRa'

    @property
    def capacity(self) -> int:
        """Maximum number of LoRAs registered (kept on CPU) at once."""
        return self.lora_config.max_cpu_loras

    @property
    def lora_slots(self) -> int:
        """Number of GPU slots, i.e. LoRAs that can be active at once."""
        return self.lora_config.max_loras

    @property
    def adapter_slots(self) -> int:
        """Alias of ``lora_slots`` for the generic adapter interface."""
        return self.lora_slots

    def activate_adapter(
        self,
        lora_id: int,
    ) -> bool:
        """Move LoRA into a GPU buffer to be used in the forward pass.

        Returns:
            False if the LoRA was already active, True once it is activated.

        Raises:
            ValueError: if every GPU slot is occupied.
            KeyError: (from the registry lookup) if ``lora_id`` was never
                registered. The manager state is left untouched in this case.
        """
        if lora_id in self._active_adapters:
            return False
        index = next((slot for slot, slot_id in enumerate(
            self.lora_index_to_id) if slot_id is None), None)
        if index is None:
            raise ValueError("No free lora slots")
        # Look the model up *before* marking it active so that an unknown id
        # does not leave a stale entry in _active_adapters with no slot.
        lora_model = self._registered_adapters[lora_id]
        # _active_adapters is used as an ordered set: values are always None.
        self._active_adapters[lora_id] = None

        logger.debug("Activating LoRA. int id: %d, slot index: %d",
                     lora_model.id, index)
        self.lora_index_to_id[index] = lora_model.id
        for module_name, module in self.modules.items():
            module_lora = lora_model.get_lora(module_name)
            if module_lora:
                module_lora.optimize()
                module.set_lora(index, module_lora.lora_a, module_lora.lora_b,
                                module_lora.embeddings_tensor)
            else:
                # Clear whatever a previous occupant left in this slot.
                module.reset_lora(index)
        return True

    def _deactivate_adapter(self, lora_id: int):
        """Free the GPU slot held by ``lora_id``; no-op if it holds none."""
        try:
            index = self.lora_index_to_id.index(lora_id)
            self.lora_index_to_id[index] = None
        except ValueError:
            pass

    def _set_long_lora_context(self, lora: LoRAModel):
        """Record the sin/cos cache offset for a long-context LoRA.

        Raises:
            ValueError: if the LoRA's scaling factor has no precomputed offset.
        """
        if self.long_lora_context is None:
            return

        if lora.scaling_factor is None:
            return

        if (lora.scaling_factor not in self.scaling_factor_to_offset):
            raise ValueError(f"Long LoRA scaling factor {lora.scaling_factor}"
                             " has not been initialized.")

        offsets = self.scaling_factor_to_offset.get(lora.scaling_factor)
        # NOTE(review): an offset of 0 is falsy and therefore not recorded;
        # presumably consumers treat a missing id as offset 0 - confirm.
        if offsets:
            self.long_lora_context.offsets_by_lora_id[lora.id] = offsets

    def _add_adapter(self, lora: LoRAModel):
        """Merge packed sub-LoRAs, register ``lora`` and set its context."""
        self._create_merged_loras_inplace(lora)
        self._registered_adapters[lora.id] = lora
        self._set_long_lora_context(lora)

    def pin_adapter(self, lora_id: int) -> bool:
        """Pin a LoRAModel in the manager cache.

        Raises:
            NotImplementedError: always; use LRUCacheLoRAModelManager.
        """
        raise NotImplementedError(
            "Pinning is not supported in LoRAModelManager. "
            "Use LRUCacheLoRAModelManager for pinning")  # type: ignore

    def _set_adapter_mapping(self, mapping: LoRAMapping) -> None:
        """Push a new request -> LoRA mapping into the shared punica wrapper."""
        # update lora states
        # NOTE(review): ``lora_slots + 1`` presumably reserves an extra index
        # for "no LoRA" - confirm against PunicaWrapper.update_metadata.
        self.punica_wrapper.update_metadata(
            mapping,
            self.lora_index_to_id,
            self.lora_slots + 1,
            self.vocab_size,
            self.lora_config.lora_extra_vocab_size,
            self.long_lora_context,
        )

    def remove_all_adapters(self):
        """Remove all LoRAModels from the manager."""
        self._registered_adapters.clear()
        self.lora_index_to_id = [None] * self.lora_slots
        self._active_adapters.clear()

    def _create_lora_modules(self):
        """Replace every targeted sub-module of the model with a LoRA layer."""
        for module_name, module in self.model.named_modules(
                remove_duplicate=False):
            if isinstance(module, PPMissingLayer):
                continue
            if not self._match_target_modules(module_name):
                continue
            # A temporary approach for multimodal models to support LoRA
            # TODO: Remove this restriction
            if self._filter_unsupported_mm_module(module_name):
                logger.warning(
                    "Regarding multimodal models, vLLM currently only supports "
                    "adding LoRA to language model, %s will be ignored.",
                    module_name,
                )
                continue
            last_part = module_name.split(".")[-1]
            packed_modules_list = self.packed_modules_mapping.get(
                last_part, [])
            new_module = replace_submodule(
                self.model, module_name,
                from_layer(module, self.lora_slots, self.lora_config,
                           packed_modules_list, self.model.config))

            # LinearScalingRotaryEmbeddingWithLora is used to handle
            # long context lora. Register relevant metadata.
            if isinstance(new_module, LinearScalingRotaryEmbeddingWithLora):
                self.long_lora_context = LongContextLoRAContext(
                    new_module.scaling_factors, new_module.rotary_dim)
                self.scaling_factor_to_offset = \
                    new_module.scaling_factor_to_offset
            # (yard1): TODO make this more robust
            if "lm_head" in module_name:
                logits_processor_module = self.model.get_submodule(
                    "logits_processor")
                new_module = replace_submodule(
                    self.model, "logits_processor",
                    from_layer_logits_processor(logits_processor_module,
                                                module, self.lora_slots,
                                                self.lora_config,
                                                self.model.config))

            # In some models, especially multimodal ones, layers with the same
            # name may have different types, such as nn.Linear and
            # ReplicatedLinear. The nn.Linear layers cannot be replaced with
            # LoRA layers, leading to assertion error. The following check
            # aims to prevent this error
            if self.supports_mm and not isinstance(new_module,
                                                   BaseLayerWithLoRA):
                continue
            self.register_module(module_name, new_module)
            self._register_packed_modules(module_name)
            # All lora layers share the same punica_wrapper based on reference.
            new_module.set_mapping(self.punica_wrapper)

    def register_module(self, module_name: str, module: "BaseLayerWithLoRA"):
        """Record ``module`` (which must be a LoRA layer) under its name."""
        assert isinstance(module, BaseLayerWithLoRA)
        self.modules[module_name] = module

    def create_dummy_lora(
            self,
            lora_id: int,
            rank: int,
            scaling_factor: Optional[float],
            embedding_modules: Optional[Dict[str, str]] = None) -> LoRAModel:
        """Create zero-initialized LoRAModel for warmup.

        One dummy layer is built for every registered LoRA layer, except
        long-context rotary layers and filtered multimodal modules. Packed
        modules get one dummy per sub-module, packed together.
        """
        model = LoRAModel(lora_id, rank, {}, scaling_factor)
        for module_name, module in self.model.named_modules():
            if (not self._match_target_modules(module_name)
                    or not isinstance(module, BaseLayerWithLoRA)
                    or isinstance(module, LinearScalingRotaryEmbeddingWithLora)
                    or self._filter_unsupported_mm_module(module_name)):
                continue
            parts = module_name.split(".")
            if module_name not in self.packed_modules:
                assert embedding_modules is not None
                if parts[-1] in embedding_modules:
                    input_dim = (module.base_layer.org_vocab_size +
                                 self.lora_config.lora_extra_vocab_size if
                                 hasattr(module.base_layer, "org_vocab_size")
                                 else module.base_layer.weight.shape[1])
                    output_dim = module.base_layer.embedding_dim if hasattr(
                        module.base_layer,
                        "embedding_dim") else module.base_layer.weight.shape[0]
                    embeddings_tensor_dim = (module.base_layer.embedding_dim if
                                             hasattr(module.base_layer,
                                                     "embedding_dim") else
                                             module.base_layer.weight.shape[1])
                    lora = LoRALayerWeights.create_dummy_lora_weights(
                        module_name,
                        input_dim,
                        output_dim,
                        rank,
                        module.lora_a_stacked.dtype,
                        "cpu",
                        embeddings_tensor_dim=embeddings_tensor_dim)
                else:
                    lora = LoRALayerWeights.create_dummy_lora_weights(
                        module_name,
                        module.lora_a_stacked.shape[-1],
                        module.lora_b_stacked.shape[-2],
                        rank,
                        module.lora_a_stacked.dtype,
                        "cpu",
                    )
                lora.optimize()
            else:
                replacements = self.packed_modules_mapping[parts[-1]]
                subloras: List[Optional["LoRALayerWeights"]] = []
                for i, r in enumerate(replacements):
                    lora = LoRALayerWeights.create_dummy_lora_weights(
                        module_name + "." + r,
                        module.lora_a_stacked[i].shape[-1],
                        module.lora_b_stacked[i].shape[-2],
                        rank,
                        module.lora_a_stacked[i].dtype,
                        "cpu",
                    )
                    lora.optimize()
                    subloras.append(lora)
                lora = PackedLoRALayerWeights.pack(subloras)
            model.loras[module_name] = lora
        return model

    def _match_target_modules(self, module_name: str):
        """True if ``module_name`` equals, or ends with "." plus, a supported
        LoRA module name."""
        return any(
            re.match(
                r".*\.{target_module}$".format(target_module=target_module),
                module_name) or target_module == module_name
            for target_module in self.supported_lora_modules)

    def _filter_unsupported_mm_module(self, module_name: str) -> bool:
        """
        Regarding multimodal models, vLLM currently only supports adding LoRA to
        language model. LoRA for other modules, such as the vision tower, will
        be filtered out.
        """
        if self.supports_mm:
            prefix = module_name.split(".")[0]
            module_mapping: MultiModelKeys = self.model.get_mm_mapping()
            return (prefix in module_mapping.connector
                    or prefix in module_mapping.tower_model)
        return False

    def _register_packed_modules(self, module_full_name: str) -> None:
        """Record the full sub-module names of a packed module, if packed."""
        parts = module_full_name.split(".")
        module_name = parts[-1]
        replacements = self.packed_modules_mapping.get(module_name, [])
        # When replacements is less than or equal to 1, it indicates that this
        # module is not a packed module.
        if len(replacements) <= 1:
            return
        prefix = ".".join(parts[:-1])
        self.packed_modules[module_full_name] = [
            prefix + "." + r if prefix else r for r in replacements
        ]

    def _create_merged_loras_inplace(self, lora_model: LoRAModel) -> None:
        """Pack per-sub-module LoRAs into one entry per packed module.

        A packed module is skipped when none of its sub-modules has weights.
        Missing sub-modules are represented as None in the packed entry.
        """
        for module_name, new_module_names in self.packed_modules.items():
            replacement_loras: List[Optional[LoRALayerWeights]] = [
                lora_model.get_lora(r) or None for r in new_module_names
            ]
            if not any(replacement_loras):
                continue
            lora_model.loras[module_name] = PackedLoRALayerWeights.pack(
                replacement_loras)

    def deactivate_adapter(self, adapter_id: int) -> bool:
        """Deactivate ``adapter_id`` via the shared adapter_commons helper."""
        return deactivate_adapter(adapter_id, self._active_adapters,
                                  self._deactivate_adapter)

    def add_adapter(self, adapter: LoRAModel) -> bool:
        """Register ``adapter`` via the shared adapter_commons helper."""
        logger.debug(
            "Adding lora. Model id: %d, "
            "int id: %d, "
            "scaling factor: %s", adapter.id, adapter.id,
            adapter.scaling_factor)
        return add_adapter(adapter, self._registered_adapters, self.capacity,
                           self._add_adapter)

    def set_adapter_mapping(self, mapping: LoRAMapping) -> None:
        """Apply ``mapping`` and remember the helper's result as last mapping."""
        self._last_mapping = set_adapter_mapping(mapping, self._last_mapping,
                                                 self._set_adapter_mapping)

    def remove_adapter(self, adapter_id: int) -> bool:
        """Unregister ``adapter_id`` via the shared adapter_commons helper."""
        return remove_adapter(adapter_id, self._registered_adapters,
                              self.deactivate_adapter)

    def list_adapters(self) -> Dict[int, Any]:
        """Return the registered adapters keyed by id."""
        return list_adapters(self._registered_adapters)

    def get_adapter(self, adapter_id: int) -> Optional[Any]:
        """Return the registered adapter with ``adapter_id``, if any."""
        return get_adapter(adapter_id, self._registered_adapters)


class LoRALRUCache(AdapterLRUCache[LoRAModel]):
    """LRU cache of LoRAModel objects keyed by integer LoRA id.

    Thin specialisation of AdapterLRUCache. ``deactivate_lora_fn`` is passed
    through to the base class unchanged; presumably it is invoked with the id
    of an entry leaving the cache - confirm against AdapterLRUCache.
    """

    def __init__(
        self,
        capacity: int,
        deactivate_lora_fn: Callable[[int], bool],
    ):
        super().__init__(capacity, deactivate_lora_fn)
647
648
649
650
651
652
653
654
655
656
657
658
659
660


class LRUCacheLoRAModelManager(LoRAModelManager):
    """A LoRAModelManager whose registries are LRU caches.

    Registered (CPU) adapters are bounded by ``capacity`` and active (GPU)
    adapters by ``lora_slots``. When either is full, the least recently used
    entry is evicted. Entries can be pinned to prevent that.
    """

    def __init__(
        self,
        model: nn.Module,
        max_num_seqs: int,
        max_num_batched_tokens: int,
        vocab_size: int,
        lora_config: LoRAConfig,
    ):
        super().__init__(model, max_num_seqs, max_num_batched_tokens,
                         vocab_size, lora_config)
        # Replace the registries used by the base class with LRU caches.
        # Evicting a registered adapter runs the public deactivate_adapter.
        # Evicting an active one only frees its GPU slot.
        self._registered_adapters: LoRALRUCache = LoRALRUCache(
            self.capacity, self.deactivate_adapter)
        self._active_adapters: LoRALRUCache = LoRALRUCache(
            self.lora_slots, self._deactivate_adapter)

    def list_adapters(self) -> Dict[int, LoRAModel]:
        """Snapshot of every registered LoRAModel keyed by id."""
        return dict(self._registered_adapters.cache)

    def add_adapter(self, lora: LoRAModel) -> bool:
        """Register ``lora``; return False if it was already registered.

        Re-adding a known id only refreshes its position in the LRU order.
        """
        logger.debug(
            "Adding lora. Model id: %d, "
            "int id: %d, "
            "scaling factor: %s", lora.id, lora.id, lora.scaling_factor)
        is_new = lora.id not in self._registered_adapters
        if is_new:
            self._add_adapter(lora)
        else:
            self._registered_adapters.touch(lora.id)
        return is_new

    def activate_adapter(
        self,
        lora_id: int,
    ) -> bool:
        """Activate ``lora_id``, evicting the LRU active LoRA if slots are full.

        The id is always touched afterwards so it becomes most recently used.
        """
        needs_slot = lora_id not in self._active_adapters
        if needs_slot and len(self._active_adapters) >= self.lora_slots:
            self._active_adapters.remove_oldest()
        activated = super().activate_adapter(lora_id)
        self._active_adapters.touch(lora_id)
        return activated

    def remove_oldest_adapter(self) -> bool:
        """Evict the least recently used registered LoRA, if there is one."""
        if len(self._registered_adapters) == 0:
            return False
        self._registered_adapters.remove_oldest()
        return True

    def pin_adapter(self, lora_id: int) -> bool:
        """Pin ``lora_id`` in both the CPU and the GPU cache."""
        self._pin_lora_in_cpu_cache(lora_id)
        self._pin_lora_in_gpu_cache(lora_id)
        return True

    def _pin_lora_in_cpu_cache(self, lora_id: int):
        """Pin in the registered cache; ValueError if not registered."""
        try:
            self._registered_adapters.pin(lora_id)
        except ValueError as err:
            raise ValueError("Pinning failed. "
                             f"LoRA {lora_id} is not registered.") from err

    def _pin_lora_in_gpu_cache(self, lora_id: int):
        """Activate ``lora_id`` if needed, then pin it in the active cache."""
        if lora_id not in self._active_adapters:
            self.activate_adapter(lora_id)
        self._active_adapters.pin(lora_id)
723

724
725
726
727
728
729
730
731
732
733

def create_lora_manager(
        model: nn.Module,
        max_num_seqs: int,
        max_num_batched_tokens: int,
        vocab_size: int,
        lora_config: LoRAConfig,
        lora_manager_cls: Type[LoRAModelManager] = LoRAModelManager,
        **kwargs) -> LoRAModelManager:
    """Create a LoRA adapter for a given model.

    Args:
        model: model to wrap; must expose ``supported_lora_modules``.
        max_num_seqs: forwarded to ``lora_manager_cls``.
        max_num_batched_tokens: forwarded to ``lora_manager_cls``.
        vocab_size: forwarded to ``lora_manager_cls``.
        lora_config: forwarded to ``lora_manager_cls``.
        lora_manager_cls: manager class to instantiate.
        **kwargs: extra keyword arguments for ``lora_manager_cls``.

    Returns:
        The constructed manager.

    Raises:
        ValueError: if ``model`` has no ``supported_lora_modules`` attribute.
    """
    if not hasattr(model, "supported_lora_modules"):
        raise ValueError(f"Model {type(model)} is not supported for LoRA.")
    return lora_manager_cls(model=model,
                            max_num_seqs=max_num_seqs,
                            max_num_batched_tokens=max_num_batched_tokens,
                            vocab_size=vocab_size,
                            lora_config=lora_config,
                            **kwargs)