"vscode:/vscode.git/clone" did not exist on "3650a74ed8fb27d4d53199969f265e426c22891b"
models.py 29.4 KB
Newer Older
1
2
3
4
5
import copy
import json
import math
import os
import re
6
from dataclasses import dataclass, field
7
from typing import Any, Callable, Dict, List, Optional, Type
8
9
10
11
12

import safetensors.torch
import torch
from torch import nn

13
14
15
16
17
from vllm.adapter_commons.models import (AdapterLRUCache, AdapterModel,
                                         AdapterModelManager)
from vllm.adapter_commons.utils import (add_adapter, deactivate_adapter,
                                        get_adapter, list_adapters,
                                        remove_adapter, set_adapter_mapping)
18
from vllm.config import LoRAConfig
19
from vllm.logger import init_logger
20
21
22
from vllm.lora.layers import (BaseLayerWithLoRA,
                              LinearScalingRotaryEmbeddingWithLora,
                              LoRAMapping)
23
from vllm.lora.lora import LoRALayerWeights, PackedLoRALayerWeights
24
from vllm.lora.punica import PunicaWrapper
25
26
from vllm.lora.utils import (from_layer, from_layer_logits_processor,
                             parse_fine_tuned_lora_name, replace_submodule)
27
from vllm.model_executor.models.interfaces import SupportsLoRA
28
from vllm.model_executor.models.utils import PPMissingLayer
29
from vllm.utils import is_pin_memory_available
30

31
logger = init_logger(__name__)
32
33
34
35

_GLOBAL_LORA_ID = 0


36
37
38
39
40
41
42
43
44
45
46
47
@dataclass
class LongContextLoRAContext:
    """Shared state for LoRA adapters that were tuned for long context.

    The manager creates one instance when it installs a
    LinearScalingRotaryEmbeddingWithLora layer and then fills in
    ``offsets_by_lora_id`` as long-context adapters are added.
    """
    # Scaling factors supported for long-context LoRA fine-tuned models
    # (taken from the rotary embedding layer's ``scaling_factors``).
    scaling_factors: List[float]
    # Dimension over which the rotary embedding is applied
    # (taken from the rotary embedding layer's ``rotary_dim``).
    rot_dim: int
    # LoRA id -> offset into the sin_cos_cache for that adapter.
    # Mutated at runtime as adapters are registered.
    offsets_by_lora_id: Dict[int, int] = field(default_factory=dict)


48
49
50
51
52
53
def get_lora_id():
    """Return a fresh LoRA id from the module-wide counter.

    Bumps ``_GLOBAL_LORA_ID`` by one and returns the new value, so the
    first id handed out is one greater than the counter's initial value.
    """
    global _GLOBAL_LORA_ID
    next_id = _GLOBAL_LORA_ID + 1
    _GLOBAL_LORA_ID = next_id
    return next_id


54
class LoRAModel(AdapterModel):
55
56
57
58
59
60
61
    """A LoRA fine-tuned model."""

    def __init__(
        self,
        lora_model_id: int,
        rank: int,
        loras: Dict[str, LoRALayerWeights],
        scaling_factor: Optional[float] = None,
    ) -> None:
        """
        Args:
            lora_model_id: The integer id for the lora model. Must be
                greater than 0.
            rank: lora rank.
            loras: module name -> weights for lora-replaced layers.
            scaling_factor: Scaling factor to support long context lora model.
                None if the lora is not tuned for long context support.
        """
        self.id = lora_model_id
        # Scaling factor for long context lora model. None if it is not
        # fine tuned for the long context.
        self.scaling_factor = scaling_factor
        assert (lora_model_id >
                0), f"a valid lora id should be greater than 0, got {self.id}"
        self.rank = rank
        self.loras: Dict[str, LoRALayerWeights] = loras

81
82
83
84
85
86
87
88
89
90
    def clone(self, lora_model_id: int) -> "LoRAModel":
        """Return a copy of the object with a different id.

        The ``loras`` dict is shallow-copied, so the clone shares the
        underlying weight objects (and their tensors) with the original.

        Args:
            lora_model_id: The id to give the clone.

        Returns:
            A new instance of the same class with the same rank, the same
            layer weights and the same long-context scaling factor.
        """
        return self.__class__(
            lora_model_id,
            rank=self.rank,
            loras=self.loras.copy(),
            # Carry the scaling factor over: without it the clone of a
            # long-context LoRA would be treated as a regular adapter.
            scaling_factor=self.scaling_factor,
        )

91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
    @property
    def extra_vocab_size(self) -> int:
        """Largest ``extra_vocab_size`` over all layer weights (0 if none)."""
        if not self.loras:
            return 0
        sizes = [weights.extra_vocab_size for weights in self.loras.values()]
        return max(sizes)

    def get_lora(self, module_name: str) -> Optional[LoRALayerWeights]:
        """Return the LoRA weights stored for ``module_name``, else None."""
        weights = self.loras.get(module_name)
        return weights

    # (yard1): TODO see if we can derive target_embedding_padding automatically
    @classmethod
    def from_lora_tensors(
        cls,
        lora_model_id: int,
        rank: int,
        lora_alpha: int,
        tensors: Dict[str, torch.Tensor],
        device: str = "cuda",
        dtype: Optional[torch.dtype] = None,
        embeddings: Optional[Dict[str, torch.Tensor]] = None,
        target_embedding_padding: Optional[int] = None,
        scaling_factor: Optional[float] = None,
        embedding_modules: Optional[Dict[str, str]] = None,
        embedding_padding_modules: Optional[List[str]] = None,
    ) -> "LoRAModel":
        """Create a LoRAModel from a dictionary of tensors.

        Args:
            lora_model_id: Id given to the resulting model.
            rank: lora rank.
            lora_alpha: lora alpha, forwarded to each LoRALayerWeights.
            tensors: Checkpoint tensor name -> tensor. Each name is parsed
                with parse_fine_tuned_lora_name into a module name and an
                "is lora A" flag.
            device: Device the tensors are moved to.
            dtype: dtype the tensors are cast to.
            embeddings: Optional extra embedding tensors. When given,
                ``embedding_modules`` must not be None.
            target_embedding_padding: If not None, lora B of modules listed
                in ``embedding_padding_modules`` is right-padded along its
                last dimension up to this size.
            scaling_factor: Long-context scaling factor stored on the model.
            embedding_modules: Module-name fragment -> key in ``embeddings``.
            embedding_padding_modules: Module-name fragments whose lora B is
                padded. Must not be None if any lora B tensor is present.

        Returns:
            The assembled LoRAModel.
        """
        # Pinning only applies when the weights are kept on the CPU.
        pin_memory = str(device) == "cpu" and is_pin_memory_available()
        loras: Dict[str, LoRALayerWeights] = {}
        for tensor_name, tensor in tensors.items():
            module_name, is_lora_a = parse_fine_tuned_lora_name(tensor_name)
            if module_name not in loras:
                # First tensor seen for this module: create its (still
                # empty) weight holder, attaching an embeddings tensor when
                # the module matches one of the embedding modules.
                lora_embeddings_tensor = None
                if embeddings:
                    assert embedding_modules is not None
                    embeddings_module = next(
                        (k for k in embedding_modules if k in module_name),
                        None)
                    if embeddings_module:
                        lora_embeddings_tensor = embeddings[
                            embedding_modules[embeddings_module]].to(
                                device=device, dtype=dtype)
                        if pin_memory:
                            lora_embeddings_tensor = (
                                lora_embeddings_tensor.pin_memory())
                loras[module_name] = LoRALayerWeights(module_name, rank,
                                                      lora_alpha, None, None,
                                                      lora_embeddings_tensor)
            layer = loras[module_name]
            if is_lora_a:
                layer.lora_a = tensor.to(device=device, dtype=dtype).t()
                if pin_memory:
                    layer.lora_a = layer.lora_a.pin_memory()
            else:
                layer.lora_b = tensor.to(device=device, dtype=dtype).t()
                assert embedding_padding_modules is not None
                if any(name in module_name
                       for name in embedding_padding_modules
                       ) and target_embedding_padding is not None:
                    lora_b = layer.lora_b
                    assert target_embedding_padding >= lora_b.shape[1]
                    addition = target_embedding_padding - lora_b.shape[1]
                    # Pad only the right side of the last dimension.
                    layer.lora_b = torch.nn.functional.pad(
                        lora_b, (0, addition))
                if pin_memory:
                    layer.lora_b = layer.lora_b.pin_memory()

        for lora in loras.values():
            lora.optimize()
        return cls(lora_model_id, rank, loras, scaling_factor=scaling_factor)
163
164
165

    @classmethod
    def from_local_checkpoint(
        cls,
        lora_dir: str,
        expected_lora_modules: List[str],
        *,
        max_position_embeddings: Optional[int] = None,
        lora_model_id: Optional[int] = None,
        device: str = "cuda",
        dtype: Optional[torch.dtype] = None,
        target_embedding_padding: Optional[int] = None,
        embedding_modules: Optional[Dict[str, str]] = None,
        embedding_padding_modules: Optional[List[str]] = None,
    ) -> "LoRAModel":
        """Create a LoRAModel from a local checkpoint.

        Args:
            lora_dir: The local path that has lora data.
            expected_lora_modules: Name of modules that are expected to be
                replaced by lora.
            max_position_embeddings: Max position embedding length. Used to
                scale the largest context length. If None, the lora model's
                context length is not scaled.
            lora_model_id: Lora model id. If not given, automatically set by
                a global counter.
            device: Device where the lora model is loaded.
            dtype: dtype of the lora model weights.
            target_embedding_padding: Forwarded to from_lora_tensors.
            embedding_modules: Forwarded to from_lora_tensors.
            embedding_padding_modules: Forwarded to from_lora_tensors.

        Returns:
            Loaded LoRA Model.

        Raises:
            ValueError: If the checkpoint targets modules that are not in
                ``expected_lora_modules``, or if ``lora_dir`` contains
                neither ``adapter_model.safetensors`` nor
                ``adapter_model.bin``.
        """
        lora_config_path = os.path.join(lora_dir, "adapter_config.json")
        lora_tensor_path = os.path.join(lora_dir, "adapter_model.safetensors")
        lora_bin_file_path = os.path.join(lora_dir, "adapter_model.bin")
        new_embeddings_tensor_path = os.path.join(
            lora_dir, "new_embeddings.safetensors")
        new_embeddings_bin_file_path = os.path.join(lora_dir,
                                                    "new_embeddings.bin")
        with open(lora_config_path) as f:
            config = json.load(f)
        if os.path.isfile(lora_tensor_path):
            tensors: Dict[str, torch.Tensor] = {}
            # Find unexpected modules.
            # Use safetensor key as a source of truth to find expected modules.
            # in peft if you have target_modules A, B, C and C does not exist
            # in the model it won't error and model will be trained with A, B
            # loraified. C won't exist in the safetensor but it will exist in
            # the target_modules of the adapter_config.json.
            unexpected_modules = []
            with safetensors.safe_open(lora_tensor_path,
                                       framework="pt") as f:  # type: ignore
                for lora_module in f.keys():  # noqa
                    module_name, _ = parse_fine_tuned_lora_name(lora_module)
                    part_name = module_name.split(".")[-1]
                    if part_name not in expected_lora_modules:
                        unexpected_modules.append(module_name)
                if unexpected_modules:
                    raise ValueError(
                        f"While loading {lora_dir}, expected"
                        f" target modules in {expected_lora_modules}"
                        f" but received {unexpected_modules}."
                        f" Please verify that the loaded LoRA module is correct"
                    )
                # Load tensors if there are only expected modules.
                for module in f.keys():  # noqa
                    tensors[module] = f.get_tensor(module)
        elif os.path.isfile(lora_bin_file_path):
            # When a bin file is provided, we rely on config to find unexpected
            # modules.
            unexpected_modules = []
            target_modules = config["target_modules"]
            for module in target_modules:
                # Compatible with more modules,
                # such as:layers.11.self_attn.k_proj
                part_name = module.split(".")[-1]
                if part_name not in expected_lora_modules:
                    unexpected_modules.append(module)
            # loaded lora's target modules must be a subset of
            # expected_lora_modules. It is not reliable. See
            # https://github.com/vllm-project/vllm/pull/5909. But there's no
            # other better mechanism.
            if unexpected_modules:
                raise ValueError(
                    f"While loading {lora_dir}, expected"
                    f" target modules in {expected_lora_modules}"
                    f" but received {unexpected_modules}."
                    f" Please verify that the loaded LoRA module is correct")
            # NOTE(review): torch.load unpickles the file; only load .bin
            # checkpoints from trusted sources.
            tensors = torch.load(lora_bin_file_path, map_location=device)
        else:
            raise ValueError(f"{lora_dir} doesn't contain tensors")

        embeddings = None
        if os.path.isfile(new_embeddings_tensor_path):
            embeddings = safetensors.torch.load_file(
                new_embeddings_tensor_path)
        elif os.path.isfile(new_embeddings_bin_file_path):
            # NOTE(review): same pickle caveat as for adapter_model.bin.
            embeddings = torch.load(new_embeddings_bin_file_path,
                                    map_location=device)

        rank = config["r"]
        lora_alpha = config["lora_alpha"]
        context_length = config.get("context_length", None)
        scaling_factor = None
        if context_length:
            # Without a base length the ratio is 1, i.e. no scaling.
            if max_position_embeddings is None:
                max_position_embeddings = context_length
            scaling_factor = float(
                math.ceil(context_length / max_position_embeddings))

        return cls.from_lora_tensors(
            lora_model_id=get_lora_id()
            if lora_model_id is None else lora_model_id,
            rank=rank,
            lora_alpha=lora_alpha,
            tensors=tensors,
            device=device,
            dtype=dtype,
            embeddings=embeddings,
            target_embedding_padding=target_embedding_padding,
            scaling_factor=scaling_factor,
            embedding_modules=embedding_modules,
            embedding_padding_modules=embedding_padding_modules,
        )


290
class LoRAModelManager(AdapterModelManager):
291
292
293
294
    """A manager that manages multiple LoRA-fine-tuned models."""

    def __init__(
        self,
        model: SupportsLoRA,
        max_num_seqs: int,
        max_num_batched_tokens: int,
        vocab_size: int,
        lora_config: LoRAConfig,
    ):
        """Create a LoRAModelManager and adapter for a given model.

        Args:
            model: the model to be adapted.
            max_num_seqs: the maximum number of sequences model can run in a
                single batch.
            max_num_batched_tokens: the maximum number of tokens model can run
                in a single batch.
            vocab_size: the vocab size of the model.
            lora_config: the LoRA configuration.
        """
        self.lora_config = lora_config
        self.max_num_seqs = max_num_seqs
        assert self.capacity >= self.lora_slots
        # Stored value is rounded up to the next multiple of 8.
        self.max_num_batched_tokens = math.ceil(max_num_batched_tokens / 8) * 8
        # Slot index -> id of the LoRA loaded in that slot (None if free).
        self.lora_index_to_id: List[Optional[int]] = [None] * self.lora_slots
        self.vocab_size = vocab_size
        self.long_lora_context: Optional[LongContextLoRAContext] = None
        self.punica_wrapper = PunicaWrapper(max_num_batched_tokens,
                                            max_batches=self.max_num_seqs,
                                            device="cuda")
        # Scaling factor -> offset to the sin_cos_cache to it.
        # Used for long context lora.
        self.scaling_factor_to_offset: Dict[float, int] = {}
        super().__init__(model)
        if hasattr(self.model, "supported_lora_modules"):
            self.supported_lora_modules = copy.deepcopy(
                self.model.supported_lora_modules)
            if lora_config.long_lora_scaling_factors:
                # We need to replace rotary emb layer to do batch computation
                # for long lora.
                self.supported_lora_modules.append("rotary_emb")
            self.packed_modules_mapping = copy.deepcopy(
                self.model.packed_modules_mapping)
        self.packed_modules: Dict[str, List[str]] = {}
        self.modules: Dict[str, "BaseLayerWithLoRA"] = {}
        self._last_mapping: Optional[LoRAMapping] = None
        self._create_lora_modules()
        self.model.lora_manager = self
        self.adapter_type = 'LoRa'
342
343
344
345
346
347
348
349
350

    @property
    def capacity(self) -> int:
        """Value of ``lora_config.max_cpu_loras``."""
        config = self.lora_config
        return config.max_cpu_loras

    @property
    def lora_slots(self) -> int:
        """Value of ``lora_config.max_loras``."""
        config = self.lora_config
        return config.max_loras

351
352
353
    @property
    def adapter_slots(self) -> int:
        """Alias for ``lora_slots``."""
        slots = self.lora_slots
        return slots
354

355
    def activate_adapter(
        self,
        lora_id: int,
    ) -> bool:
        """Move LoRA into a GPU buffer to be used in the forward pass.

        Args:
            lora_id: Id of a registered LoRA.

        Returns:
            False if the LoRA was already active, True once it has been
            loaded into a free slot.

        Raises:
            ValueError: If every slot is occupied.
            KeyError: If ``lora_id`` is not registered.
        """
        if lora_id in self._active_adapters:
            return False
        index = next((slot
                      for slot, occupant in enumerate(self.lora_index_to_id)
                      if occupant is None), None)
        if index is None:
            raise ValueError("No free lora slots")
        # Resolve the model before recording it as active, so that an
        # unknown id does not leave a stale entry in the active set.
        lora_model = self._registered_adapters[lora_id]
        self._active_adapters[lora_id] = None
        logger.debug("Activating LoRA. int id: %d, slot index: %d",
                     lora_model.id, index)
        self.lora_index_to_id[index] = lora_model.id
        for module_name, module in self.modules.items():
            module_lora = lora_model.get_lora(module_name)
            if module_lora:
                module_lora.optimize()
                module.set_lora(index, module_lora.lora_a, module_lora.lora_b,
                                module_lora.embeddings_tensor)
            else:
                # This adapter has no weights for the layer: clear the slot.
                module.reset_lora(index)
        return True

383
    def _deactivate_adapter(self, lora_id: int):
        """Free the slot occupied by ``lora_id``; no-op if it holds none."""
        try:
            index = self.lora_index_to_id.index(lora_id)
        except ValueError:
            # The id is not loaded in any slot, so there is nothing to free.
            return
        self.lora_index_to_id[index] = None

390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
    def _set_long_lora_context(self, lora: LoRAModel):
        """Record the sin/cos cache offset for a long-context LoRA.

        Does nothing when the manager has no long-context state or when the
        LoRA has no scaling factor.

        Raises:
            ValueError: If the LoRA's scaling factor has no known offset.
        """
        if self.long_lora_context is None:
            return

        if lora.scaling_factor is None:
            return

        if lora.scaling_factor not in self.scaling_factor_to_offset:
            raise ValueError(f"Long LoRA scaling factor {lora.scaling_factor}"
                             " has not been initialized.")

        offset = self.scaling_factor_to_offset[lora.scaling_factor]
        # Compare against None rather than truthiness: an offset of 0 is a
        # legitimate value and must still be recorded.
        if offset is not None:
            self.long_lora_context.offsets_by_lora_id[lora.id] = offset

405
    def _add_adapter(self, lora: LoRAModel):
        """Register ``lora`` under its id.

        Packed-module weights are merged into the model first, and the
        long-context offset is recorded after registration.
        """
        self._create_merged_loras_inplace(lora)
        self._registered_adapters[lora.id] = lora
        self._set_long_lora_context(lora)
409

410
    def pin_adapter(self, lora_id: int) -> bool:
        """Pin a LoRAModel in the manager cache.

        Raises:
            NotImplementedError: Always; this manager does not support
                pinning.
        """
        # The two sentences are joined by implicit string concatenation, so
        # the separating space has to be part of the first literal.
        raise NotImplementedError(
            "Pinning is not supported in LoRAModelManager. "
            "Use LRUCacheLoRAModelManager for pinning")  # type: ignore

416
    def _set_adapter_mapping(self, mapping: LoRAMapping) -> None:
        """Push ``mapping`` and the current slot state to the punica wrapper."""
        # update lora states
        self.punica_wrapper.update_metadata(
            mapping,
            self.lora_index_to_id,
            self.lora_slots + 1,
            self.vocab_size,
            self.lora_config.lora_extra_vocab_size,
            self.long_lora_context,
        )
426

427
    def remove_all_adapters(self):
        """Remove all LoRAModels from the manager."""
        self._registered_adapters.clear()
        # Mark every slot as free again.
        self.lora_index_to_id = [None] * self.lora_slots
        self._active_adapters.clear()
432
433

    def _create_lora_modules(self):
        """Swap every supported submodule of the model for its LoRA layer.

        Each replacement is registered with the manager and pointed at the
        shared punica wrapper.
        """
        for module_name, module in self.model.named_modules(
                remove_duplicate=False):
            if isinstance(module, PPMissingLayer):
                continue
            if not self._match_target_modules(module_name):
                continue
            leaf_name = module_name.split(".")[-1]
            packed_moduled_lst = self.packed_modules_mapping.get(
                leaf_name, [])
            new_module = replace_submodule(
                self.model, module_name,
                from_layer(module, self.lora_slots, self.lora_config,
                           packed_moduled_lst, self.model.config))
            # LinearScalingRotaryEmbeddingWithLora is used to handle
            # long context lora. Register relevant metadata.
            if isinstance(new_module, LinearScalingRotaryEmbeddingWithLora):
                self.long_lora_context = LongContextLoRAContext(
                    new_module.scaling_factors, new_module.rotary_dim)
                self.scaling_factor_to_offset = \
                    new_module.scaling_factor_to_offset
            # (yard1): TODO make this more robust
            if "lm_head" in module_name:
                # For lm_head the logits processor is wrapped as well, and
                # that wrapper is what gets registered under this name.
                logits_processor_module = self.model.get_submodule(
                    "logits_processor")
                new_module = replace_submodule(
                    self.model, "logits_processor",
                    from_layer_logits_processor(logits_processor_module,
                                                module, self.lora_slots,
                                                self.lora_config,
                                                self.model.config))
            self.register_module(module_name, new_module)
            self._register_packed_modules(module_name)
            # All lora layers share the same punica_wrapper based on reference.
            new_module.set_mapping(self.punica_wrapper)
467
468
469
470
471

    def register_module(self, module_name: str, module: "BaseLayerWithLoRA"):
        """Store ``module`` under ``module_name``; it must be a LoRA layer."""
        is_lora_layer = isinstance(module, BaseLayerWithLoRA)
        assert is_lora_layer
        self.modules[module_name] = module

Terry's avatar
Terry committed
472
473
474
475
    def create_dummy_lora(
            self,
            lora_id: int,
            rank: int,
            scaling_factor: Optional[float],
            embedding_modules: Optional[Dict[str, str]] = None) -> LoRAModel:
        """Create zero-initialized LoRAModel for warmup.

        Args:
            lora_id: Id given to the dummy model.
            rank: lora rank of the dummy weights.
            scaling_factor: Long-context scaling factor stored on the model.
            embedding_modules: Names of embedding modules. Must not be None
                if the model has any non-packed LoRA layer.

        Returns:
            A LoRAModel holding dummy weights for every matching LoRA layer
            except LinearScalingRotaryEmbeddingWithLora.
        """
        model = LoRAModel(lora_id, rank, {}, scaling_factor)
        for module_name, module in self.model.named_modules():
            if not self._match_target_modules(module_name) or not isinstance(
                    module, BaseLayerWithLoRA) or isinstance(
                        module, LinearScalingRotaryEmbeddingWithLora):
                continue
            parts = module_name.split(".")
            if module_name not in self.packed_modules:
                assert embedding_modules is not None
                if parts[-1] in embedding_modules:
                    base = module.base_layer
                    # Prefer the layer's own vocab/embedding attributes and
                    # fall back to the weight shape when they are absent.
                    input_dim = (base.org_vocab_size +
                                 self.lora_config.lora_extra_vocab_size
                                 if hasattr(base, "org_vocab_size") else
                                 base.weight.shape[1])
                    output_dim = (base.embedding_dim if hasattr(
                        base, "embedding_dim") else base.weight.shape[0])
                    embeddings_tensor_dim = (base.embedding_dim if hasattr(
                        base, "embedding_dim") else base.weight.shape[1])
                    lora = LoRALayerWeights.create_dummy_lora_weights(
                        module_name,
                        input_dim,
                        output_dim,
                        rank,
                        module.lora_a_stacked.dtype,
                        "cpu",
                        embeddings_tensor_dim=embeddings_tensor_dim)
                else:
                    lora = LoRALayerWeights.create_dummy_lora_weights(
                        module_name,
                        module.lora_a_stacked.shape[-1],
                        module.lora_b_stacked.shape[-2],
                        rank,
                        module.lora_a_stacked.dtype,
                        "cpu",
                    )
                lora.optimize()
            else:
                # Packed module: build one dummy per sub-module and pack them.
                replacements = self.packed_modules_mapping[parts[-1]]
                subloras: List[Optional["LoRALayerWeights"]] = []
                for i, r in enumerate(replacements):
                    lora = LoRALayerWeights.create_dummy_lora_weights(
                        module_name + "." + r,
                        module.lora_a_stacked[i].shape[-1],
                        module.lora_b_stacked[i].shape[-2],
                        rank,
                        module.lora_a_stacked[i].dtype,
                        "cpu",
                    )
                    lora.optimize()
                    subloras.append(lora)
                lora = PackedLoRALayerWeights.pack(subloras)
            model.loras[module_name] = lora
        return model

    def _match_target_modules(self, module_name: str):
        """Return True if ``module_name`` is one of the supported modules.

        A module matches when its full name equals a supported module name
        or ends with ``"." + name``. The comparison is literal, so
        characters that are special in regular expressions carry no special
        meaning in a supported module name.
        """
        return any(
            module_name == target_module
            or module_name.endswith("." + target_module)
            for target_module in self.supported_lora_modules)
543
544
545
546

    def _register_packed_modules(self, module_full_name: str) -> None:
        """Record the sub-module names that ``module_full_name`` packs.

        Looks up the last path component in ``packed_modules_mapping`` and,
        for a real packed module, stores the fully qualified names of its
        replacements in ``packed_modules``.
        """
        prefix, _, module_name = module_full_name.rpartition(".")
        replacements = self.packed_modules_mapping.get(module_name, [])
        # When replacements is less than or equal to 1, it indicates that this
        # module is not a packed module.
        if len(replacements) <= 1:
            return
        self.packed_modules[module_full_name] = [
            prefix + "." + r if prefix else r for r in replacements
        ]

    def _create_merged_loras_inplace(self, lora_model: LoRAModel) -> None:
        """Add packed weights to ``lora_model`` for every packed module.

        For each packed module the weights of its sub-modules are collected
        (None where the model has none) and, if at least one is present,
        packed and stored in ``lora_model.loras`` under the packed module's
        name. Packed modules without any sub-module weights are skipped.
        """
        for module_name, new_module_names in self.packed_modules.items():
            # ``or None`` keeps the list limited to real weights and None.
            replacement_loras: List[Optional[LoRALayerWeights]] = [
                lora_model.get_lora(name) or None for name in new_module_names
            ]
            if not any(replacement_loras):
                continue
            lora_model.loras[module_name] = PackedLoRALayerWeights.pack(
                replacement_loras)

575
576
577
578
579
580
581
582
583
584
585
586
    def deactivate_adapter(self, adapter_id: int) -> bool:
        """Deactivate ``adapter_id`` through the shared adapter helper."""
        was_deactivated = deactivate_adapter(adapter_id,
                                             self._active_adapters,
                                             self._deactivate_adapter)
        return was_deactivated

    def add_adapter(self, adapter: LoRAModel) -> bool:
        """Log and register ``adapter`` through the shared adapter helper."""
        logger.debug(
            "Adding lora. Model id: %d, "
            "int id: %d, "
            "scaling factor: %s", adapter.id, adapter.id,
            adapter.scaling_factor)
        was_added = add_adapter(adapter, self._registered_adapters,
                                self.capacity, self._add_adapter)
        return was_added
587

588
589
590
591
592
593
594
595
596
597
598
599
600
601
602
603
    def set_adapter_mapping(self, mapping: LoRAMapping) -> None:
        """Apply ``mapping`` via the shared helper and remember the result."""
        latest = set_adapter_mapping(mapping, self._last_mapping,
                                     self._set_adapter_mapping)
        self._last_mapping = latest

    def remove_adapter(self, adapter_id: int) -> bool:
        """Remove ``adapter_id`` through the shared adapter helper."""
        was_removed = remove_adapter(adapter_id, self._registered_adapters,
                                     self.deactivate_adapter)
        return was_removed

    def list_adapters(self) -> Dict[int, Any]:
        """List the registered adapters through the shared adapter helper."""
        registered = list_adapters(self._registered_adapters)
        return registered

    def get_adapter(self, adapter_id: int) -> Optional[Any]:
        """Look up ``adapter_id`` through the shared adapter helper."""
        found = get_adapter(adapter_id, self._registered_adapters)
        return found


class LoRALRUCache(AdapterLRUCache[LoRAModel]):
    """AdapterLRUCache specialised for LoRAModel values."""

    def __init__(self, capacity: int,
                 deactivate_lora_fn: Callable[[int], bool]):
        """Forward the capacity and the deactivation callback to the base."""
        super().__init__(capacity, deactivate_lora_fn)
608
609
610
611
612
613
614
615
616
617
618
619
620
621


class LRUCacheLoRAModelManager(LoRAModelManager):
    """A model manager that manages multiple LoRAs with LRU cache."""

    def __init__(
        self,
        model: nn.Module,
        max_num_seqs: int,
        max_num_batched_tokens: int,
        vocab_size: int,
        lora_config: LoRAConfig,
    ):
        """Build the base manager, then switch both registries to LRU caches.

        The registered-adapter cache is sized by ``capacity`` and the
        active-adapter cache by ``lora_slots``.
        """
        super().__init__(model, max_num_seqs, max_num_batched_tokens,
                         vocab_size, lora_config)
        self._registered_adapters: LoRALRUCache = LoRALRUCache(
            self.capacity, self.deactivate_adapter)
        self._active_adapters: LoRALRUCache = LoRALRUCache(
            self.lora_slots, self._deactivate_adapter)
627

628
    def list_adapters(self) -> Dict[int, LoRAModel]:
        """List all registered LoRAModels.

        Returns a new dict built from the cache contents, so mutating the
        result does not affect the cache.
        """
        return dict(self._registered_adapters.cache)
631

632
    def add_adapter(self, lora: LoRAModel) -> bool:
        """Add a LoRAModel to the manager.

        Returns:
            True if the model was newly registered, False if it was already
            registered (in which case only its LRU position is refreshed).
        """
        logger.debug(
            "Adding lora. Model id: %d, "
            "int id: %d, "
            "scaling factor: %s", lora.id, lora.id, lora.scaling_factor)
        if lora.id not in self._registered_adapters:
            self._add_adapter(lora)
            return True
        # We always touch to update the LRU cache order
        self._registered_adapters.touch(lora.id)
        return False

647
    def activate_adapter(
        self,
        lora_id: int,
    ) -> bool:
        """Activate ``lora_id``, evicting the oldest active LoRA if full.

        Returns:
            The result of the base class's ``activate_adapter``.
        """
        if lora_id not in self._active_adapters and len(
                self._active_adapters) >= self.lora_slots:
            # Every slot is in use: make room before activating.
            self._active_adapters.remove_oldest()
        result = super().activate_adapter(lora_id)
        # We always touch to update the LRU cache order
        self._active_adapters.touch(lora_id)
        return result

659
660
661
    def remove_oldest_adapter(self) -> bool:
        """Evict the oldest registered adapter.

        Returns:
            True if an adapter was evicted, False if none were registered.
        """
        if len(self._registered_adapters) > 0:
            self._registered_adapters.remove_oldest()
            return True
        return False

665
    def pin_adapter(self, lora_id: int) -> bool:
        """Pin a LoRAModel in the manager cache.

        The LoRA is pinned in the registered-adapter cache first and then in
        the active-adapter cache. Always returns True; failures surface as
        exceptions from the two helpers.
        """
        self._pin_lora_in_cpu_cache(lora_id)
        self._pin_lora_in_gpu_cache(lora_id)
        return True

    def _pin_lora_in_cpu_cache(self, lora_id: int):
        """Pin ``lora_id`` in the registered-adapter cache.

        Raises:
            ValueError: If the cache rejects the pin with a ValueError; the
                error is re-raised with a message naming the LoRA.
        """
        try:
            self._registered_adapters.pin(lora_id)
        except ValueError as err:
            raise ValueError("Pinning failed. "
                             f"LoRA {lora_id} is not registered.") from err

    def _pin_lora_in_gpu_cache(self, lora_id: int):
        """Pin ``lora_id`` in the active-adapter cache, activating it first
        if it is not active yet."""
        if lora_id not in self._active_adapters:
            # move lora to gpu if not already active
            self.activate_adapter(lora_id)

        self._active_adapters.pin(lora_id)
684

685
686
687
688
689
690
691
692
693
694

def create_lora_manager(
        model: nn.Module,
        max_num_seqs: int,
        max_num_batched_tokens: int,
        vocab_size: int,
        lora_config: LoRAConfig,
        lora_manager_cls: Type[LoRAModelManager] = LoRAModelManager,
        **kwargs) -> LoRAModelManager:
    """Create a LoRA adapter for a given model.

    Args:
        model: The model to adapt. Must expose ``supported_lora_modules``.
        max_num_seqs: Forwarded to the manager constructor.
        max_num_batched_tokens: Forwarded to the manager constructor.
        vocab_size: Forwarded to the manager constructor.
        lora_config: Forwarded to the manager constructor.
        lora_manager_cls: Manager class to instantiate.
        **kwargs: Extra keyword arguments for the manager constructor.

    Returns:
        The constructed manager.

    Raises:
        ValueError: If the model has no ``supported_lora_modules`` attribute.
    """
    if not hasattr(model, "supported_lora_modules"):
        raise ValueError(f"Model {type(model)} is not supported for LoRA.")
    lora_manager = lora_manager_cls(
        model=model,
        max_num_seqs=max_num_seqs,
        max_num_batched_tokens=max_num_batched_tokens,
        vocab_size=vocab_size,
        lora_config=lora_config,
        **kwargs)
    return lora_manager