# model_manager.py (27.1 KB)
1
# SPDX-License-Identifier: Apache-2.0
2
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
3

4
import math
5
6
from collections.abc import Callable
from typing import TypeVar
7

8
import regex as re
9
10
11
import torch
from torch import nn

12
from vllm.config.lora import LoRAConfig
13
from vllm.logger import init_logger
14
from vllm.lora.layers import BaseLayerWithLoRA, FusedMoE3DWithLoRA, LoRAMapping
15
from vllm.lora.lora_model import LoRAModel
16
from vllm.lora.lora_weights import LoRALayerWeights, PackedLoRALayerWeights
17
from vllm.lora.punica_wrapper import get_punica_wrapper
18
19
20
21
from vllm.lora.utils import (
    from_layer,
    from_layer_logits_processor,
    get_supported_lora_modules,
22
    is_moe_model,
23
    process_packed_modules_mapping,
24
25
    replace_submodule,
)
26
from vllm.model_executor.layers.fused_moe import FusedMoE
27
from vllm.model_executor.models import SupportsLoRA, supports_multimodal
28
from vllm.model_executor.models.interfaces import is_pooling_model
29
from vllm.model_executor.models.module_mapping import MultiModelKeys
30
from vllm.model_executor.models.utils import PPMissingLayer
31
from vllm.utils.cache import LRUCache
32
from vllm.utils.platform_utils import is_pin_memory_available
33

34
# Module-scoped logger obtained from vLLM's logging helper.
logger = init_logger(__name__)

# Generic value type for the adapter cache defined below (keys are int ids).
T = TypeVar("T")


class AdapterLRUCache(LRUCache[int, T]):
    """LRU cache keyed by adapter int id.

    Whenever the base cache invokes ``_on_remove`` for an entry, the
    user-supplied ``deactivate_fn`` is called with that entry's id before
    the base-class bookkeeping runs.
    """

    def __init__(self, capacity: int, deactivate_fn: Callable[[int], object]):
        super().__init__(capacity)
        # Hook invoked with the adapter id from ``_on_remove``.
        self.deactivate_fn = deactivate_fn

    def _on_remove(self, key: int, value: T | None):
        logger.debug("Removing adapter int id: %d", key)
        hook = self.deactivate_fn
        hook(key)
        return super()._on_remove(key, value)


class LoRAModelManager:
    """A manager that manages multiple LoRA-fine-tuned models.

    Up to ``capacity`` (``lora_config.max_cpu_loras``) adapters can be
    registered; at most ``lora_slots`` (``lora_config.max_loras``) of them are
    active at once, each occupying one index of ``lora_index_to_id``.
    """

    def __init__(
        self,
        model: SupportsLoRA,
        max_num_seqs: int,
        max_num_batched_tokens: int,
        vocab_size: int,
        lora_config: LoRAConfig,
        device: torch.device,
    ):
        """Create a LoRAModelManager and adapter for a given model.

        Args:
            model: the model to be adapted.
            max_num_seqs: the maximum number of sequences model can run in a
                single batch.
            max_num_batched_tokens: the maximum number of tokens model can run
                in a single batch.
            vocab_size: the vocab size of the model.
            lora_config: the LoRA configuration.
            device: the device passed to the punica wrapper.
        """
        self.model: SupportsLoRA = model
        self._registered_adapters: dict[int, LoRAModel] = {}
        # Dict instead of a set for compatibility with LRUCache.
        self._active_adapters: dict[int, None] = {}
        self.adapter_type = "LoRA"
        self.lora_config = lora_config
        self.device = device
        self.max_num_seqs = max_num_seqs
        assert self.capacity >= self.lora_slots
        # Rounded up to a multiple of 8.
        self.max_num_batched_tokens = math.ceil(max_num_batched_tokens / 8) * 8
        self.lora_index_to_id: list[int | None] = [None] * self.lora_slots
        self.vocab_size = vocab_size
        self.punica_wrapper = get_punica_wrapper(
            max_num_batched_tokens,
            max_batches=self.max_num_seqs,
            device=self.device,
            max_loras=self.lora_config.max_loras,
        )

        self.supported_lora_modules = get_supported_lora_modules(self.model)
        # The message must be a single expression attached to the assert;
        # previously the f-string was a separate no-op statement.
        assert self.supported_lora_modules, (
            f"No supported LoRA modules found in {self.model.__class__.__name__}."
        )

        self.packed_modules_mapping = process_packed_modules_mapping(self.model)
        # Used to indicate whether the model is a multimodal model
        self.supports_mm: bool = (
            supports_multimodal(self.model)
            # In case the model only supports LoRA for
            # text modules (e.g. ChatGLM)
            and hasattr(self.model, "get_mm_mapping")
        )
        self.is_pooling_model = is_pooling_model(self.model)
        self.packed_modules: dict[str, list[str]] = {}
        self.modules: dict[str, BaseLayerWithLoRA] = {}
        self._last_mapping: LoRAMapping | None = None
        self._is_3d_moe_model = is_moe_model(self.model) and self.model.is_3d_moe_weight
        self._create_lora_modules()

        self.model.lora_manager = self

    def __len__(self) -> int:
        return len(self._registered_adapters)

    @property
    def capacity(self) -> int:
        """Maximum number of registered adapters (``max_cpu_loras``)."""
        return self.lora_config.max_cpu_loras

    @property
    def lora_slots(self) -> int:
        """Number of adapters that can be active at once (``max_loras``)."""
        return self.lora_config.max_loras

    @property
    def adapter_slots(self) -> int:
        return self.lora_slots

    def activate_adapter(
        self,
        lora_id: int,
    ) -> bool:
        """Move LoRA into a GPU buffer to be used in the forward pass.

        Returns:
            False if ``lora_id`` is already active, True once its weights have
            been written into a free slot of every registered LoRA module.

        Raises:
            ValueError: if no LoRA slot is free.
            KeyError: if ``lora_id`` is not registered (from the registry
                lookup); in that case the adapter is NOT marked active.
        """
        if lora_id in self._active_adapters:
            return False
        index = next(
            (i for i, slot_id in enumerate(self.lora_index_to_id) if slot_id is None),
            None,
        )
        if index is None:
            raise ValueError("No free lora slots")
        # Look the adapter up before marking it active so an unknown id does
        # not leave a stale entry (without a slot) in ``_active_adapters``.
        lora_model = self._registered_adapters[lora_id]
        self._active_adapters[lora_id] = None
        logger.debug(
            "Activating LoRA. int id: %d, slot index: %d", lora_model.id, index
        )
        self.lora_index_to_id[index] = lora_model.id
        for module_name, module in self.modules.items():
            module_lora = self._get_lora_layer_weights(lora_model, module_name)
            if not module_lora:
                module.reset_lora(index)
                continue
            # Note (gnovack) - If MOE lora weights are not split into
            # num_experts chunks, we split them here
            if isinstance(module, FusedMoE3DWithLoRA) and torch.is_tensor(
                module_lora.lora_a
            ):
                self._split_3d_moe_lora(lora_model, module_name, module_lora)
            module.set_lora(
                index,
                module_lora.lora_a,
                module_lora.lora_b,
            )
        return True

    def _split_3d_moe_lora(
        self,
        lora_model: LoRAModel,
        module_name: str,
        module_lora: LoRALayerWeights,
    ) -> None:
        """Replace ``module_lora``'s tensor weights with the lists consumed by
        ``FusedMoE3DWithLoRA.set_lora`` (mutates ``module_lora`` in place).

        Because the tensors are replaced by lists, a later activation of the
        same adapter skips this step (the caller checks ``torch.is_tensor``).
        """
        # Handle PEFT file format where experts.base_layer is the
        # gate_up_proj and experts is the down_proj
        gate_up_proj_lora = self._get_lora_layer_weights(
            lora_model, module_name + ".base_layer"
        )
        down_proj_lora = module_lora
        # FIXME Edge case where LoRA is not added to gate_up_proj
        # or down_proj
        assert gate_up_proj_lora is not None
        assert down_proj_lora is not None
        if self._is_3d_moe_model:
            module_lora.lora_a = [gate_up_proj_lora.lora_a, down_proj_lora.lora_a]
            module_lora.lora_b = [gate_up_proj_lora.lora_b, down_proj_lora.lora_b]
            return

        # Some 3D MoE models haven't added the `is_3d_moe_weight`
        # attribute yet, so fallback here
        num_experts = module_lora.lora_a.shape[0] // module_lora.rank

        gate_proj_a = gate_up_proj_lora.lora_a.chunk(num_experts, dim=0)
        up_proj_a = gate_up_proj_lora.lora_a.chunk(num_experts, dim=0)

        # Rows of lora_b alternate between the gate and up projections.
        gate_proj_b = gate_up_proj_lora.lora_b[::2, ...].chunk(num_experts, dim=-1)
        up_proj_b = gate_up_proj_lora.lora_b[1::2, ...].chunk(num_experts, dim=-1)

        down_proj_a = down_proj_lora.lora_a.chunk(num_experts, dim=0)
        down_proj_b = down_proj_lora.lora_b.chunk(num_experts, dim=-1)

        lora_a = []
        lora_b = []
        # Per expert the order is gate, down, up.
        for i in range(num_experts):
            lora_a.extend((gate_proj_a[i], down_proj_a[i], up_proj_a[i]))
            lora_b.extend((gate_proj_b[i], down_proj_b[i], up_proj_b[i]))

        module_lora.lora_a = lora_a
        module_lora.lora_b = lora_b

    def _deactivate_adapter(self, lora_id: int):
        """Free the slot held by ``lora_id``; a no-op if it holds none."""
        try:
            index = self.lora_index_to_id.index(lora_id)
        except ValueError:
            # The adapter does not occupy a slot; nothing to free.
            return
        self.lora_index_to_id[index] = None

    def _add_adapter(self, lora: LoRAModel):
        self._create_merged_loras_inplace(lora)
        self._registered_adapters[lora.id] = lora

    def pin_adapter(self, lora_id: int) -> bool:
        """Pin a LoRAModel in the manager cache."""
        raise NotImplementedError(
            "Pinning is not supported in LoRAModelManager. "
            "Use LRUCacheLoRAModelManager for pinning"
        )  # type: ignore

    def _set_adapter_mapping(self, mapping: LoRAMapping) -> None:
        # update lora states
        self.punica_wrapper.update_metadata(
            mapping,
            self.lora_index_to_id,
            self.lora_slots + 1,
            self.vocab_size,
        )

    def remove_all_adapters(self):
        """Remove all LoRAModels from the manager."""
        self._registered_adapters.clear()
        self.lora_index_to_id = [None] * self.lora_slots
        self._active_adapters.clear()

    def _create_lora_modules(self):
        """Wrap every supported target module of ``self.model`` with its LoRA
        layer (via ``replace_submodule``) and register the wrapper here."""

        def _parent_module(module_name: str) -> str:
            # module name is a dot separated name.
            # for example:
            #  - given an input 'x.y.z' return 'x.y'
            #  - given an input 'x' return ''
            return module_name.rpartition(".")[0]

        for module_name, module in self.model.named_modules(remove_duplicate=False):
            if isinstance(module, PPMissingLayer):
                continue

            if not self._match_target_modules(module_name):
                continue
            # A temporary approach for multimodal models to support LoRA
            # TODO: Remove this restriction
            if self._filter_unsupported_mm_module(module_name):
                logger.warning(
                    "Regarding multimodal models, vLLM currently only supports "
                    "adding LoRA to language model, %s will be ignored.",
                    module_name,
                )
                continue
            leaf_name = module_name.split(".")[-1]
            packed_moduled_lst = self.packed_modules_mapping.get(leaf_name, [])
            if isinstance(module, FusedMoE):
                # packed_moduled_lst is used here to just determine whether to
                # instantiate FusedMoE3DWithLoRA or FusedMoEWithLoRA, and the
                # difference between these two LoRA layers is whether the
                # LoRA weights of w1 and w3 have already been fused on disk.
                packed_moduled_lst = ["w13"] if self._is_3d_moe_model else ["w1", "w3"]
            new_module = replace_submodule(
                self.model,
                module_name,
                from_layer(
                    module,
                    self.lora_slots,
                    self.lora_config,
                    packed_moduled_lst,
                    self.model.config,
                ),
            )

            # (yard1): TODO make this more robust
            if "lm_head" in module_name:
                logits_processor_module_name = "logits_processor"
                parent_module = _parent_module(module_name)
                if parent_module:
                    logits_processor_module_name = (
                        f"{parent_module}.{logits_processor_module_name}"
                    )

                logits_processor_module = self.model.get_submodule(
                    logits_processor_module_name
                )

                new_module = replace_submodule(
                    self.model,
                    logits_processor_module_name,
                    from_layer_logits_processor(
                        logits_processor_module,
                        module,
                        self.lora_slots,
                        self.lora_config,
                        self.model.config,
                    ),
                )

            # In some models, especially multimodal ones, layers with the same
            # name may have different types, such as nn.Linear and
            # ReplicatedLinear. The nn.Linear layers cannot be replaced with
            # LoRA layers, leading to assertion error. The following check
            # aims to prevent this error
            if self.supports_mm and not isinstance(new_module, BaseLayerWithLoRA):
                continue
            self.register_module(module_name, new_module)
            self._register_packed_modules(module_name)
            # All lora layers share the same punica_wrapper based on reference.
            new_module.set_mapping(self.punica_wrapper)

    def register_module(self, module_name: str, module: "BaseLayerWithLoRA"):
        assert isinstance(module, BaseLayerWithLoRA), (
            f"Module {module_name} must be a BaseLayerWithLoRA instance, "
            f"got {type(module)}"
        )
        self.modules[module_name] = module

    def create_dummy_lora(
        self,
        lora_id: int,
        rank: int,
        embedding_modules: dict[str, str] | None = None,
    ) -> LoRAModel:
        """Create zero-initialized LoRAModel for warmup.

        Args:
            lora_id: id assigned to the dummy adapter.
            rank: LoRA rank of every dummy weight.
            embedding_modules: leaf names treated as embedding layers. Must be
                provided whenever a non-packed LoRA module is encountered.
        """
        model = LoRAModel(lora_id, rank, {})
        for module_name, module in self.model.named_modules():
            if (
                not self._match_target_modules(module_name)
                or not isinstance(module, BaseLayerWithLoRA)
                or self._filter_unsupported_mm_module(module_name)
            ):
                continue
            leaf_name = module_name.split(".")[-1]
            if module_name not in self.packed_modules:
                assert embedding_modules is not None
                if leaf_name in embedding_modules:
                    input_dim = (
                        module.base_layer.org_vocab_size
                        if hasattr(module.base_layer, "org_vocab_size")
                        else module.base_layer.weight.shape[1]
                    )
                    output_dim = (
                        module.base_layer.embedding_dim
                        if hasattr(module.base_layer, "embedding_dim")
                        else module.base_layer.weight.shape[0]
                    )
                    lora = LoRALayerWeights.create_dummy_lora_weights(
                        module_name,
                        input_dim,
                        output_dim,
                        rank,
                        module.lora_a_stacked[0].dtype,
                        "cpu",
                    )
                    model.loras[module_name] = lora
                elif module.__class__.__name__ == "FusedMoE3DWithLoRA":
                    # Case for 3D moe model
                    # w2
                    lora = LoRALayerWeights.create_dummy_lora_weights(
                        module_name,
                        module.w2_input_size,
                        module.w2_output_size,
                        rank * module.w2_lora_a_stacked[0].shape[1],  # rank*num_experts
                        module.w2_lora_a_stacked[0].dtype,
                        "cpu",
                    )
                    model.loras[module_name] = lora
                    # w13
                    lora = LoRALayerWeights.create_dummy_lora_weights(
                        module_name,
                        module.w13_input_size,
                        module.w13_output_size,
                        rank
                        * module.w13_lora_a_stacked[0].shape[1],  # rank*num_experts
                        module.w13_lora_a_stacked[0].dtype,
                        "cpu",
                    )
                    model.loras[module_name + ".base_layer"] = lora
                else:
                    lora = LoRALayerWeights.create_dummy_lora_weights(
                        module_name,
                        module.lora_a_stacked[0].shape[-1],
                        module.lora_b_stacked[0].shape[-2],
                        rank,
                        module.lora_a_stacked[0].dtype,
                        "cpu",
                    )
                    model.loras[module_name] = lora
            else:
                replacements = self.packed_modules_mapping[leaf_name]
                subloras: list[LoRALayerWeights | None] = []
                for i, r in enumerate(replacements):
                    lora = LoRALayerWeights.create_dummy_lora_weights(
                        module_name + "." + r,
                        module.lora_a_stacked[i].shape[-1],
                        module.lora_b_stacked[i].shape[-2],
                        rank,
                        module.lora_a_stacked[i].dtype,
                        "cpu",
                    )
                    subloras.append(lora)
                if module.__class__.__name__ == "FusedMoEWithLoRA":
                    lora = PackedLoRALayerWeights.pack_moe(subloras, module_name)
                else:
                    lora = PackedLoRALayerWeights.pack(subloras)
                model.loras[module_name] = lora
        return model

    def _match_target_modules(self, module_name: str):
        """True if ``module_name`` equals, or ends with ``.<name>`` of, any
        supported LoRA module name."""
        return any(
            re.match(
                r".*\.{target_module}$".format(target_module=target_module), module_name
            )
            or target_module == module_name
            for target_module in self.supported_lora_modules
        )

    def _filter_unsupported_mm_module(self, module_name: str) -> bool:
        """
        Regarding multimodal models, vLLM currently only supports adding LoRA to
        language model. LoRA for other modules, such as the vision tower, will
        be filtered out.
        """
        if self.supports_mm:
            module_mapping: MultiModelKeys = self.model.get_mm_mapping()
            prefix_lst = module_mapping.connector + module_mapping.tower_model
            return any(module_name.startswith(prefix) for prefix in prefix_lst)
        return False

    def _register_packed_modules(self, module_full_name: str) -> None:
        """Record the fully-qualified sub-module names of a packed module."""
        parts = module_full_name.split(".")
        module_name = parts[-1]
        replacements = self.packed_modules_mapping.get(module_name, [])
        # When replacements is less than or equal to 1, it indicates that this
        # module is not a packed module.
        if len(replacements) <= 1:
            return
        prefix = ".".join(parts[:-1])
        self.packed_modules[module_full_name] = [
            prefix + "." + r if prefix else r for r in replacements
        ]

    def _create_merged_loras_inplace(self, lora_model: LoRAModel) -> None:
        """Pack per-sub-module LoRA weights into their packed module in place,
        then ``optimize()`` every weight and, when the weights are on CPU and
        pinning is available, pin them."""
        for module_name, new_module_names in self.packed_modules.items():
            replacement_loras: list[LoRALayerWeights | None] = []
            replaced_module: set[str] = set()
            for r in new_module_names:
                lora = self._get_lora_layer_weights(lora_model, r)
                if lora:
                    replaced_module.add(r)
                    replacement_loras.append(lora)
                else:
                    # Keep a positional placeholder for missing sub-modules.
                    replacement_loras.append(None)
            if not replaced_module:
                continue
            # HACK Temporary solution for the pool model.
            if self.is_pooling_model and not lora_model.check_lora_name(module_name):
                replaced_module_name = module_name.replace("model.", "")
                # Check the *stripped* name; re-checking ``module_name`` here
                # was always False because of the guard above.
                if lora_model.check_lora_name(replaced_module_name):
                    module_name = replaced_module_name
            if module_name.endswith(".experts"):
                lora_model.loras[module_name] = PackedLoRALayerWeights.pack_moe(
                    replacement_loras, module_name
                )
            else:
                lora_model.loras[module_name] = PackedLoRALayerWeights.pack(
                    replacement_loras
                )
            # Remove the modules that have been replaced.
            for module in replaced_module:
                lora_model.loras.pop(module, None)

        for lora in lora_model.loras.values():
            lora.optimize()

        # An adapter without any weights has nothing to pin; previously this
        # raised StopIteration from ``next``.
        if not lora_model.loras:
            return

        first_lora: LoRALayerWeights = next(iter(lora_model.loras.values()))
        assert first_lora.lora_a is not None
        if isinstance(first_lora.lora_a, list):
            # Packed weights may hold ``None`` for sub-modules without LoRA;
            # use the device of the first real tensor (not the tensor itself,
            # whose str() never equals "cpu").
            lora_device = next(
                (t.device for t in first_lora.lora_a if t is not None), None
            )
        else:
            lora_device = first_lora.lora_a.device
        # Execute pin_memory after LoRA weight merging, mainly because:
        # 1. Some MoE models have a large number of LoRA weights. If we
        # perform pin_memory immediately after loading weights, the
        # overhead is significant.
        # 2. The weight packing above (e.g., pack_moe) may invalidate the
        # pin_memory allocation, so we execute it after packing.
        pin_memory = str(lora_device) == "cpu" and is_pin_memory_available()
        if not pin_memory:
            return
        for lora in lora_model.loras.values():
            if isinstance(lora.lora_a, list):
                for index in range(len(lora.lora_a)):
                    if lora.lora_a[index] is None:
                        continue
                    lora.lora_a[index] = lora.lora_a[index].pin_memory()
                    lora.lora_b[index] = lora.lora_b[index].pin_memory()
            else:
                lora.lora_a = lora.lora_a.pin_memory()
                lora.lora_b = lora.lora_b.pin_memory()

    def _get_lora_layer_weights(
        self, lora_model: LoRAModel, module_name: str
    ) -> LoRALayerWeights | None:
        """Return ``lora_model``'s weights for ``module_name``.

        For pooling models, if the name is not found, retry with every
        ``model.`` substring removed.
        """
        org_module_name = module_name
        if self.is_pooling_model and not lora_model.check_lora_name(module_name):
            # If it's a pool model, and the layer name is not found,
            # remove the prefix 'model.' and search again.
            module_name = module_name.replace("model.", "")
            if lora_model.check_lora_name(module_name):
                org_module_name = module_name
                logger.info_once(
                    "For the pool model, successfully loaded the LoRA weights "
                    "after removing the prefix 'model.'."
                )
        return lora_model.get_lora(org_module_name)

    def deactivate_adapter(self, adapter_id: int) -> bool:
        """Deactivate an adapter; returns False if it was not active."""
        if adapter_id not in self._active_adapters:
            return False
        self._deactivate_adapter(adapter_id)
        self._active_adapters.pop(adapter_id, None)
        return True

    def add_adapter(self, adapter: LoRAModel) -> bool:
        """Register an adapter; returns False if it is already registered.

        Raises:
            RuntimeError: if ``capacity`` adapters are already registered.
        """
        logger.debug("Adding lora. Model id: %d, int id: %d", adapter.id, adapter.id)
        if adapter.id in self._registered_adapters:
            return False
        if len(self._registered_adapters) >= self.capacity:
            raise RuntimeError("No free adapter slots.")
        self._add_adapter(adapter)
        return True

    def set_adapter_mapping(self, mapping: LoRAMapping) -> None:
        """Push ``mapping`` to the punica wrapper unless it is unchanged."""
        if self._last_mapping != mapping:
            self._set_adapter_mapping(mapping)
            self._last_mapping = mapping

    def remove_adapter(self, adapter_id: int) -> bool:
        """Deactivate and unregister an adapter; False if not registered."""
        self.deactivate_adapter(adapter_id)
        if adapter_id not in self._registered_adapters:
            return False
        self._registered_adapters.pop(adapter_id, None)
        return True

    def list_adapters(self) -> dict[int, LoRAModel]:
        return dict(self._registered_adapters)

    def get_adapter(self, adapter_id: int) -> LoRAModel | None:
        return self._registered_adapters.get(adapter_id)


class LoRALRUCache(AdapterLRUCache[LoRAModel]):
    """``AdapterLRUCache`` specialised to hold ``LoRAModel`` values."""

    def __init__(self, capacity: int, deactivate_lora_fn: Callable[[int], bool]):
        # Forward both arguments positionally to the generic adapter cache.
        super().__init__(capacity, deactivate_lora_fn)


class LRUCacheLoRAModelManager(LoRAModelManager):
    """A model manager that manages multiple LoRAs with LRU cache.

    Both the registry and the active set are ``LoRALRUCache`` instances, so
    entries are evicted in least-recently-used order and can be pinned.
    """

    def __init__(
        self,
        model: nn.Module,
        max_num_seqs: int,
        max_num_batched_tokens: int,
        vocab_size: int,
        lora_config: LoRAConfig,
        device: torch.device,
    ):
        super().__init__(
            model, max_num_seqs, max_num_batched_tokens, vocab_size, lora_config, device
        )
        # Replace the base-class dicts with LRU caches. Removing a registered
        # adapter calls the public ``deactivate_adapter``; removing an active
        # one only calls ``_deactivate_adapter`` to free its slot.
        self._registered_adapters: LoRALRUCache = LoRALRUCache(
            self.capacity, self.deactivate_adapter
        )
        self._active_adapters: LoRALRUCache = LoRALRUCache(
            self.lora_slots, self._deactivate_adapter
        )

    def list_adapters(self) -> dict[int, LoRAModel]:
        """List all registered LoRAModels."""
        return dict(self._registered_adapters.cache)

    def add_adapter(self, lora: LoRAModel) -> bool:
        """Add a LoRAModel to the manager.

        Returns True if the adapter was newly registered; otherwise it is only
        moved to the most-recently-used position and False is returned.
        """
        logger.debug("Adding lora. Model id: %d, int id: %d", lora.id, lora.id)
        if lora.id in self._registered_adapters:
            # We always touch to update the LRU cache order
            self._registered_adapters.touch(lora.id)
            return False
        self._add_adapter(lora)
        return True

    def activate_adapter(
        self,
        lora_id: int,
    ) -> bool:
        """Activate ``lora_id``, evicting the oldest active adapter first when
        every slot is taken by other adapters."""
        is_new = lora_id not in self._active_adapters
        if is_new and len(self._active_adapters) >= self.lora_slots:
            self._active_adapters.remove_oldest()
        activated = super().activate_adapter(lora_id)
        # We always touch to update the LRU cache order
        self._active_adapters.touch(lora_id)
        return activated

    def remove_oldest_adapter(self) -> bool:
        """Evict the least-recently-used registered adapter, if any."""
        if not len(self._registered_adapters):
            return False
        self._registered_adapters.remove_oldest()
        return True

    def pin_adapter(self, lora_id: int) -> bool:
        """Pin a LoRAModel in the manager cache."""
        self._pin_lora_in_cpu_cache(lora_id)
        self._pin_lora_in_gpu_cache(lora_id)
        return True

    def _pin_lora_in_cpu_cache(self, lora_id: int):
        try:
            self._registered_adapters.pin(lora_id)
        except ValueError as err:
            raise ValueError(
                f"Pinning failed. LoRA {lora_id} is not registered."
            ) from err

    def _pin_lora_in_gpu_cache(self, lora_id: int):
        already_active = lora_id in self._active_adapters
        if not already_active:
            # move lora to gpu if not already active
            self.activate_adapter(lora_id)
        self._active_adapters.pin(lora_id)

def create_lora_manager(
    model: nn.Module,
    max_num_seqs: int,
    max_num_batched_tokens: int,
    vocab_size: int,
    lora_config: LoRAConfig,
    device: torch.device,
    lora_manager_cls: type[LoRAModelManager] = LoRAModelManager,
    **kwargs,
) -> LoRAModelManager:
    """Create a LoRA adapter for a given model.

    ``lora_manager_cls`` is instantiated with the remaining arguments passed
    as keywords; any extra ``kwargs`` are forwarded unchanged.

    Raises:
        ValueError: if ``model`` is not an instance of ``SupportsLoRA``.
    """
    if not isinstance(model, SupportsLoRA):
        raise ValueError(f"Model {type(model)} is not supported for LoRA.")
    return lora_manager_cls(
        model=model,
        max_num_seqs=max_num_seqs,
        max_num_batched_tokens=max_num_batched_tokens,
        vocab_size=vocab_size,
        lora_config=lora_config,
        device=device,
        **kwargs,
    )