# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project

import math
import os
from collections.abc import Callable
from typing import TypeVar

import regex as re
import safetensors.torch
import torch
from torch import nn

from vllm.config.lora import LoRAConfig
from vllm.logger import init_logger
from vllm.lora.layers import BaseLayerWithLoRA, LoRAMapping
from vllm.lora.lora_weights import LoRALayerWeights, PackedLoRALayerWeights
from vllm.lora.peft_helper import PEFTHelper
from vllm.lora.punica_wrapper import get_punica_wrapper
from vllm.lora.utils import (
    from_layer,
    from_layer_logits_processor,
    get_supported_lora_modules,
    is_regex_target_modules,
    parse_fine_tuned_lora_name,
    replace_submodule,
)
from vllm.model_executor.layers.fused_moe import FusedMoE
from vllm.model_executor.model_loader.tensorizer import TensorizerConfig
from vllm.model_executor.models import SupportsLoRA, supports_multimodal
from vllm.model_executor.models.interfaces import is_pooling_model
from vllm.model_executor.models.module_mapping import MultiModelKeys
from vllm.model_executor.models.utils import PPMissingLayer, WeightsMapper
from vllm.model_executor.utils import get_packed_modules_mapping
from vllm.utils import is_pin_memory_available
from vllm.utils.cache import LRUCache

# Module-level logger shared by every class and helper below.
logger = init_logger(__name__)

# Value type stored in AdapterLRUCache (keys are always int adapter ids).
T = TypeVar("T")


class AdapterLRUCache(LRUCache[int, T]):
    """LRU cache keyed by integer adapter id.

    Whenever the base ``LRUCache`` removes an entry through its
    ``_on_remove`` hook, ``deactivate_fn`` is first called with that
    entry's key.
    """

    def __init__(self, capacity: int, deactivate_fn: Callable[[int], object]):
        super().__init__(capacity)
        # Callback run with the adapter id of every removed entry.
        self.deactivate_fn = deactivate_fn

    def _on_remove(self, key: int, value: T | None):
        logger.debug("Removing adapter int id: %d", key)
        deactivate = self.deactivate_fn
        deactivate(key)
        # Let the base cache perform its own bookkeeping afterwards.
        return super()._on_remove(key, value)


# Last id handed out by get_lora_id(); 0 means none has been issued yet.
_GLOBAL_LORA_ID = 0


def get_lora_id():
    """Return a new process-wide LoRA id.

    Ids are consecutive integers starting at 1, so every returned value
    satisfies ``LoRAModel``'s ``id > 0`` requirement.
    """
    global _GLOBAL_LORA_ID
    next_id = _GLOBAL_LORA_ID + 1
    _GLOBAL_LORA_ID = next_id
    return next_id


def is_moe_model(model: nn.Module) -> bool:
    """Return True if ``model`` contains any FusedMoE layer.

    For such models a one-time warning is logged, because fused-MoE LoRA
    inference is not supported and the adapter must not carry expert
    weights.
    """
    has_fused_moe = any(isinstance(sub, FusedMoE) for sub in model.modules())
    if not has_fused_moe:
        return False
    logger.warning_once(
        "For MoE models, vLLM currently does not support fused MoE LoRA "
        "inference. Please ensure that the loaded LoRA model does not "
        "contain expert weights."
    )
    return True


class LoRAModel:
    """A LoRA fine-tuned model.

    Holds the LoRA weights of a single adapter, keyed by the name of the
    module each set of weights applies to.
    """

    def __init__(
        self,
        lora_model_id: int,
        rank: int,
        loras: dict[str, LoRALayerWeights],
    ) -> None:
        """
        Args:
            lora_model_id: The integer id for the lora model; must be > 0.
            rank: lora rank.
            loras: module name -> weights for lora-replaced layers.
        """
        self.id = lora_model_id

        assert lora_model_id > 0, (
            f"a valid lora id should be greater than 0, got {self.id}"
        )
        self.rank = rank
        self.loras: dict[str, LoRALayerWeights] = loras

    def clone(self, lora_model_id: int) -> "LoRAModel":
        """Return a copy of the object with a different id.

        Only the ``loras`` dict is copied (shallowly); the clone shares the
        underlying weight objects and tensors with this instance."""
        return self.__class__(
            lora_model_id,
            rank=self.rank,
            loras=self.loras.copy(),
        )

    @property
    def extra_vocab_size(self) -> int:
        """Largest ``extra_vocab_size`` over all layers, or 0 if there are none."""
        return max(
            (lora.extra_vocab_size for lora in self.loras.values()), default=0
        )

    def get_lora(self, module_name: str) -> LoRALayerWeights | None:
        """Get LoRA for a given module by name, or None if it has no weights."""
        return self.loras.get(module_name, None)

    def check_lora_name(self, lora_name: str) -> bool:
        """Return True if this adapter has weights for ``lora_name``."""
        return lora_name in self.loras

    # (yard1): TODO see if we can derive target_embedding_padding automatically
    @classmethod
    def from_lora_tensors(
        cls,
        lora_model_id: int,
        tensors: dict[str, torch.Tensor],
        peft_helper: PEFTHelper,
        device: str = "cuda",
        dtype: torch.dtype | None = None,
        embeddings: dict[str, torch.Tensor] | None = None,
        target_embedding_padding: int | None = None,
        embedding_modules: dict[str, str] | None = None,
        embedding_padding_modules: list[str] | None = None,
        weights_mapper: WeightsMapper | None = None,
    ) -> "LoRAModel":
        """Create a LoRAModel from a dictionary of tensors.

        Args:
            lora_model_id: Integer id of the resulting model.
            tensors: Checkpoint tensor name -> tensor. Each name is split by
                ``parse_fine_tuned_lora_name`` into a module name and a flag
                telling whether it is a LoRA A (True) or LoRA B (False) weight.
            peft_helper: Adapter configuration; its ``r`` becomes the rank.
            device: Device the weights are moved to. When it is ``"cpu"`` and
                pinned memory is available, the weights are also pinned.
            dtype: Passed to ``Tensor.to`` together with ``device``.
            embeddings: Optional extra embedding tensors. When non-empty,
                ``embedding_modules`` must be provided as well.
            target_embedding_padding: If set, the LoRA B weight of every module
                whose name contains an entry of ``embedding_padding_modules``
                is zero-padded along dim -2 up to this size (it must not be
                smaller than the weight's ``shape[0]``).
            embedding_modules: Module-name fragment -> key into ``embeddings``.
            embedding_padding_modules: Module-name fragments whose LoRA B gets
                padded. ``None`` means no module is padded.
            weights_mapper: Optional mapper used when parsing tensor names.
        """
        pin_memory = str(device) == "cpu" and is_pin_memory_available()
        # Previously ``None`` (the default) tripped an assertion on the first
        # LoRA B tensor; treat it as "no modules need padding" instead.
        padding_modules = embedding_padding_modules or []
        loras: dict[str, LoRALayerWeights] = {}
        for tensor_name, tensor in tensors.items():
            module_name, is_lora_a = parse_fine_tuned_lora_name(
                tensor_name, weights_mapper
            )
            if module_name not in loras:
                lora_embeddings_tensor = None
                if embeddings:
                    assert embedding_modules is not None
                    # First embedding-module fragment contained in the name.
                    embeddings_module = next(
                        (k for k in embedding_modules if k in module_name), None
                    )
                    if embeddings_module:
                        lora_embeddings_tensor = embeddings[
                            embedding_modules[embeddings_module]
                        ].to(device=device, dtype=dtype)
                        if pin_memory:
                            lora_embeddings_tensor = lora_embeddings_tensor.pin_memory()
                loras[module_name] = LoRALayerWeights.from_config(
                    module_name, peft_helper, lora_embeddings_tensor
                )

            layer = loras[module_name]
            if is_lora_a:
                lora_a = tensor.to(device=device, dtype=dtype)
                if pin_memory:
                    lora_a = lora_a.pin_memory()
                layer.lora_a = lora_a
            else:
                lora_b = tensor.to(device=device, dtype=dtype)
                if target_embedding_padding is not None and any(
                    name in module_name for name in padding_modules
                ):
                    assert target_embedding_padding >= lora_b.shape[0]
                    addition = target_embedding_padding - lora_b.shape[0]
                    # (0, 0) leaves the last dim alone; (0, addition) appends
                    # zero rows to the second-to-last dim.
                    lora_b = torch.nn.functional.pad(lora_b, (0, 0, 0, addition))
                if pin_memory:
                    lora_b = lora_b.pin_memory()
                layer.lora_b = lora_b

        for lora in loras.values():
            lora.optimize()

        return cls(lora_model_id, peft_helper.r, loras)

    @classmethod
    def from_local_checkpoint(
        cls,
        lora_dir: str,
        expected_lora_modules: list[str],
        peft_helper: PEFTHelper,
        *,
        lora_model_id: int | None = None,
        device: str = "cuda",
        dtype: torch.dtype | None = None,
        target_embedding_padding: int | None = None,
        embedding_modules: dict[str, str] | None = None,
        embedding_padding_modules: list[str] | None = None,
        weights_mapper: WeightsMapper | None = None,
        tensorizer_config_dict: dict | None = None,
    ) -> "LoRAModel":
        """Create a LoRAModel from a local checkpoint.

        Weights are read, in this order of preference:
        - a tensorizer file, when ``tensorizer_config_dict`` is given;
        - ``adapter_model.safetensors``;
        - ``adapter_model.bin``;
        - ``adapter_model.pt``.

        Optional extra embeddings are read from ``new_embeddings.safetensors``
        or, failing that, ``new_embeddings.bin``.

        Args:
            lora_dir: The local path that has lora data.
            expected_lora_modules: Name of modules that are expected to be
                replaced by lora.
            peft_helper: Loaded lora configuration information.
            lora_model_id: LoRA model id. If not given, automatically set by
                a global counter.
            device: Device where the lora model is loaded.
            dtype: dtype of the lora model weights.
            target_embedding_padding, embedding_modules,
            embedding_padding_modules, weights_mapper: forwarded to
                ``from_lora_tensors``.
            tensorizer_config_dict: kwargs for ``TensorizerConfig``; enables
                loading ``adapter_model.tensors`` from its ``tensorizer_dir``.

        Returns:
            Loaded LoRA Model.

        Raises:
            ValueError: If the checkpoint targets modules that are not in
                ``expected_lora_modules``, or if ``lora_dir`` has no weights.
        """
        lora_tensor_path = os.path.join(lora_dir, "adapter_model.safetensors")
        lora_bin_file_path = os.path.join(lora_dir, "adapter_model.bin")
        lora_pt_file_path = os.path.join(lora_dir, "adapter_model.pt")
        new_embeddings_tensor_path = os.path.join(
            lora_dir, "new_embeddings.safetensors"
        )
        new_embeddings_bin_file_path = os.path.join(lora_dir, "new_embeddings.bin")
        tensors: dict[str, torch.Tensor] = {}

        def unexpected_modules_error(unexpected_modules: list) -> ValueError:
            # Single source for the error raised by both validation paths.
            return ValueError(
                f"While loading {lora_dir}, expected"
                f" target modules in {expected_lora_modules}"
                f" but received {unexpected_modules}."
                " Please verify that the loaded LoRA module is correct"
            )

        def check_unexpected_modules(modules: dict):
            # Every key's module (last dotted component) must be expected.
            unexpected_modules: list[str] = []
            for lora_module in modules.keys():  # noqa
                module_name, _ = parse_fine_tuned_lora_name(lora_module, weights_mapper)
                part_name = module_name.split(".")[-1]
                if part_name not in expected_lora_modules:
                    unexpected_modules.append(module_name)
            if unexpected_modules:
                raise unexpected_modules_error(unexpected_modules)

        if tensorizer_config_dict:
            from tensorizer import TensorDeserializer

            tensorizer_config = TensorizerConfig(**tensorizer_config_dict)
            lora_tensor_path = os.path.join(
                tensorizer_config.tensorizer_dir, "adapter_model.tensors"
            )
            tensorizer_args = tensorizer_config._construct_tensorizer_args()
            tensors = TensorDeserializer(
                lora_tensor_path,
                dtype=tensorizer_config.dtype,
                **tensorizer_args.deserialization_kwargs,
            )
            check_unexpected_modules(tensors)

        elif os.path.isfile(lora_tensor_path):
            # Find unexpected modules.
            # Use safetensor key as a source of truth to find expected modules.
            # In peft, if you have target_modules A, B, C and C does not exist
            # in the model, it won't error and the model will be trained with
            # only A and B loraified. C won't exist in the safetensor but it
            # will exist in the target_modules of the adapter_config.json.
            with safetensors.safe_open(lora_tensor_path, framework="pt") as f:  # type: ignore
                # Load tensors only if there are no unexpected modules.
                check_unexpected_modules(f)
                for module in f.keys():  # noqa
                    tensors[module] = f.get_tensor(module)
        elif os.path.isfile(lora_bin_file_path) or os.path.isfile(lora_pt_file_path):
            # When a bin/pt file is provided, we rely on config to find
            # unexpected modules.
            target_modules = peft_helper.target_modules
            if not isinstance(target_modules, list):
                target_modules = [target_modules]
            # Compare only the last dotted component so that fully-qualified
            # entries such as "layers.11.self_attn.k_proj" are accepted.
            unexpected_modules = [
                module
                for module in target_modules
                if module.split(".")[-1] not in expected_lora_modules
            ]
            # loaded lora's target modules must be a subset of
            # expected_lora_modules. It is not reliable. See
            # https://github.com/vllm-project/vllm/pull/5909. But there's no
            # other better mechanism.
            if unexpected_modules and not is_regex_target_modules(
                peft_helper.target_modules, expected_lora_modules
            ):
                raise unexpected_modules_error(unexpected_modules)
            lora_file_path = (
                lora_bin_file_path
                if os.path.isfile(lora_bin_file_path)
                else lora_pt_file_path
            )
            tensors = torch.load(lora_file_path, map_location=device, weights_only=True)
        else:
            raise ValueError(f"{lora_dir} doesn't contain tensors")

        embeddings = None
        if os.path.isfile(new_embeddings_tensor_path):
            embeddings = safetensors.torch.load_file(new_embeddings_tensor_path)
        elif os.path.isfile(new_embeddings_bin_file_path):
            embeddings = torch.load(
                new_embeddings_bin_file_path, map_location=device, weights_only=True
            )

        return cls.from_lora_tensors(
            lora_model_id=get_lora_id() if lora_model_id is None else lora_model_id,
            tensors=tensors,
            peft_helper=peft_helper,
            device=device,
            dtype=dtype,
            embeddings=embeddings,
            target_embedding_padding=target_embedding_padding,
            embedding_modules=embedding_modules,
            embedding_padding_modules=embedding_padding_modules,
            weights_mapper=weights_mapper,
        )


class LoRAModelManager:
    """A manager that manages multiple LoRA-fine-tuned models.

    On construction it replaces every supported layer of ``model`` with a
    LoRA layer. Registered ``LoRAModel``s live in ``_registered_adapters``
    (at most ``capacity`` of them). Up to ``lora_slots`` of them can be
    active at once, each occupying one slot index in the LoRA layers.
    """

    def __init__(
        self,
        model: SupportsLoRA,
        max_num_seqs: int,
        max_num_batched_tokens: int,
        vocab_size: int,
        lora_config: LoRAConfig,
        device: torch.device,
    ):
        """Create a LoRAModelManager and adapter for a given model.

        Args:
            model: the model to be adapted.
            max_num_seqs: the maximum number of sequences model can run in a
                single batch.
            max_num_batched_tokens: the maximum number of tokens model can run
                in a single batch.
            vocab_size: the vocab size of the model.
            lora_config: the LoRA configuration.
            device: device handed to the punica wrapper.
        """
        self.model: SupportsLoRA = model
        self._registered_adapters: dict[int, LoRAModel] = {}
        # Dict instead of a set for compatibility with LRUCache.
        self._active_adapters: dict[int, None] = {}
        self.adapter_type = "LoRA"
        self.lora_config = lora_config
        self.device = device
        self.max_num_seqs = max_num_seqs
        assert self.capacity >= self.lora_slots
        # Rounded up to the next multiple of 8.
        self.max_num_batched_tokens = math.ceil(max_num_batched_tokens / 8) * 8
        # Slot index -> id of the LoRA occupying that slot (None when free).
        self.lora_index_to_id: list[int | None] = [None] * self.lora_slots
        self.vocab_size = vocab_size
        self.punica_wrapper = get_punica_wrapper(
            max_num_batched_tokens,
            max_batches=self.max_num_seqs,
            device=self.device,
            max_loras=self.lora_config.max_loras,
        )

        self.supported_lora_modules = get_supported_lora_modules(self.model)
        # The message used to be split across two statements, which made the
        # f-string a dead expression and dropped the class name from it.
        assert self.supported_lora_modules, (
            "No supported LoRA modules found in"
            f" {self.model.__class__.__name__}."
        )

        self.packed_modules_mapping = get_packed_modules_mapping(self.model)
        # Used to indicate whether the model is a multimodal model
        self.supports_mm: bool = (
            supports_multimodal(self.model)
            # In case the model only supports LoRA for
            # text modules (e.g. ChatGLM)
            and hasattr(self.model, "get_mm_mapping")
        )
        self.is_pooling_model = is_pooling_model(self.model)
        self.is_moe_model = is_moe_model(self.model)
        # Packed module full name -> full names of the modules packed into it.
        self.packed_modules: dict[str, list[str]] = {}
        # Module full name -> LoRA layer that replaced it.
        self.modules: dict[str, BaseLayerWithLoRA] = {}
        # Last mapping applied by set_adapter_mapping; lets it skip repeats.
        self._last_mapping: LoRAMapping | None = None
        self._create_lora_modules()
        self.model.lora_manager = self

    def __len__(self) -> int:
        return len(self._registered_adapters)

    @property
    def capacity(self) -> int:
        """Maximum number of registered adapters (``max_cpu_loras``)."""
        return self.lora_config.max_cpu_loras

    @property
    def lora_slots(self) -> int:
        """Number of adapters that can be active at once (``max_loras``)."""
        return self.lora_config.max_loras

    @property
    def adapter_slots(self) -> int:
        return self.lora_slots

    def activate_adapter(
        self,
        lora_id: int,
    ) -> bool:
        """Move LoRA into a GPU buffer to be used in the forward pass.

        Returns:
            False if the LoRA was already active, True after it has been
            written into a free slot.

        Raises:
            ValueError: If every LoRA slot is occupied.
        """
        if lora_id in self._active_adapters:
            return False
        index = next(
            (i for i, slot_id in enumerate(self.lora_index_to_id) if slot_id is None),
            None,
        )
        if index is None:
            raise ValueError("No free lora slots")
        # Look the model up before touching any state, so an unknown id
        # (KeyError) no longer leaves a stale entry in _active_adapters.
        lora_model = self._registered_adapters[lora_id]
        self._active_adapters[lora_id] = None
        logger.debug(
            "Activating LoRA. int id: %d, slot index: %d", lora_model.id, index
        )
        self.lora_index_to_id[index] = lora_model.id
        for module_name, module in self.modules.items():
            module_lora = self._get_lora_layer_weights(lora_model, module_name)
            if module_lora:
                module_lora.optimize()
                module.set_lora(
                    index,
                    module_lora.lora_a,
                    module_lora.lora_b,
                    module_lora.embeddings_tensor,
                )
            else:
                # This adapter has no weights for this layer: reset the slot.
                module.reset_lora(index)
        return True

    def _deactivate_adapter(self, lora_id: int):
        """Free the slot held by ``lora_id``; a no-op if it holds none."""
        try:
            index = self.lora_index_to_id.index(lora_id)
        except ValueError:
            return
        self.lora_index_to_id[index] = None

    def _add_adapter(self, lora: LoRAModel):
        """Pack the adapter's packed-module weights and register it."""
        self._create_merged_loras_inplace(lora)
        self._registered_adapters[lora.id] = lora

    def pin_adapter(self, lora_id: int) -> bool:
        """Pin a LoRAModel in the manager cache."""
        raise NotImplementedError(
            "Pinning is not supported in LoRAModelManager. "
            "Use LRUCacheLoRAModelManager for pinning"
        )  # type: ignore

    def _set_adapter_mapping(self, mapping: LoRAMapping) -> None:
        # update lora states
        self.punica_wrapper.update_metadata(
            mapping,
            self.lora_index_to_id,
            self.lora_slots + 1,
            self.vocab_size,
            self.lora_config.lora_extra_vocab_size,
        )

    def remove_all_adapters(self):
        """Remove all LoRAModels from the manager."""
        self._registered_adapters.clear()
        self.lora_index_to_id = [None] * self.lora_slots
        self._active_adapters.clear()

    def _create_lora_modules(self):
        """Replace every targeted submodule of the model with a LoRA layer."""

        def _parent_module(module_name: str) -> str:
            # module name is a dot separated name.
            # for example:
            #  - given an input 'x.y.z' return 'x.y'
            #  - given an input 'x' return ''
            return module_name.rpartition(".")[0]

        for module_name, module in self.model.named_modules(remove_duplicate=False):
            if isinstance(module, PPMissingLayer):
                continue
            if not self._match_target_modules(module_name):
                continue
            # A temporary approach for multimodal models to support LoRA
            # TODO: Remove this restriction
            if self._filter_unsupported_mm_module(module_name):
                logger.warning(
                    "Regarding multimodal models, vLLM currently only supports "
                    "adding LoRA to language model, %s will be ignored.",
                    module_name,
                )
                continue
            module_suffix = module_name.split(".")[-1]
            packed_module_names = self.packed_modules_mapping.get(module_suffix, [])
            new_module = replace_submodule(
                self.model,
                module_name,
                from_layer(
                    module,
                    self.lora_slots,
                    self.lora_config,
                    packed_module_names,
                    self.model.config,
                ),
            )

            # (yard1): TODO make this more robust
            if "lm_head" in module_name:
                logits_processor_module_name = "logits_processor"
                parent_module = _parent_module(module_name)
                if parent_module:
                    logits_processor_module_name = (
                        f"{parent_module}.{logits_processor_module_name}"
                    )

                logits_processor_module = self.model.get_submodule(
                    logits_processor_module_name
                )

                new_module = replace_submodule(
                    self.model,
                    logits_processor_module_name,
                    from_layer_logits_processor(
                        logits_processor_module,
                        module,
                        self.lora_slots,
                        self.lora_config,
                        self.model.config,
                    ),
                )

            # In some models, especially multimodal ones, layers with the same
            # name may have different types, such as nn.Linear and
            # ReplicatedLinear. The nn.Linear layers cannot be replaced with
            # LoRA layers, leading to assertion error. The following check
            # aims to prevent this error
            if self.supports_mm and not isinstance(new_module, BaseLayerWithLoRA):
                continue
            self.register_module(module_name, new_module)
            self._register_packed_modules(module_name)
            # All lora layers share the same punica_wrapper based on reference.
            new_module.set_mapping(self.punica_wrapper)

    def register_module(self, module_name: str, module: "BaseLayerWithLoRA"):
        assert isinstance(module, BaseLayerWithLoRA)
        self.modules[module_name] = module

    def create_dummy_lora(
        self,
        lora_id: int,
        rank: int,
        embedding_modules: dict[str, str] | None = None,
    ) -> LoRAModel:
        """Create zero-initialized LoRAModel for warmup.

        Weights are created on ``"cpu"`` with shapes and dtypes taken from
        the registered LoRA layers. ``embedding_modules`` is required (it is
        asserted) whenever a non-packed target module is encountered.
        """
        model = LoRAModel(lora_id, rank, {})
        for module_name, module in self.model.named_modules():
            if (
                not self._match_target_modules(module_name)
                or not isinstance(module, BaseLayerWithLoRA)
                or self._filter_unsupported_mm_module(module_name)
            ):
                continue
            parts = module_name.split(".")
            if module_name not in self.packed_modules:
                assert embedding_modules is not None
                if parts[-1] in embedding_modules:
                    input_dim = (
                        module.base_layer.org_vocab_size
                        + self.lora_config.lora_extra_vocab_size
                        if hasattr(module.base_layer, "org_vocab_size")
                        else module.base_layer.weight.shape[1]
                    )
                    output_dim = (
                        module.base_layer.embedding_dim
                        if hasattr(module.base_layer, "embedding_dim")
                        else module.base_layer.weight.shape[0]
                    )
                    embeddings_tensor_dim = (
                        module.base_layer.embedding_dim
                        if hasattr(module.base_layer, "embedding_dim")
                        else module.base_layer.weight.shape[1]
                    )
                    lora = LoRALayerWeights.create_dummy_lora_weights(
                        module_name,
                        input_dim,
                        output_dim,
                        rank,
                        module.lora_a_stacked[0].dtype,
                        "cpu",
                        embeddings_tensor_dim=embeddings_tensor_dim,
                    )
                else:
                    lora = LoRALayerWeights.create_dummy_lora_weights(
                        module_name,
                        module.lora_a_stacked[0].shape[-1],
                        module.lora_b_stacked[0].shape[-2],
                        rank,
                        module.lora_a_stacked[0].dtype,
                        "cpu",
                    )
            else:
                replacements = self.packed_modules_mapping[parts[-1]]
                subloras: list[LoRALayerWeights | None] = []
                for i, r in enumerate(replacements):
                    lora = LoRALayerWeights.create_dummy_lora_weights(
                        module_name + "." + r,
                        module.lora_a_stacked[i].shape[-1],
                        module.lora_b_stacked[i].shape[-2],
                        rank,
                        module.lora_a_stacked[i].dtype,
                        "cpu",
                    )
                    subloras.append(lora)
                lora = PackedLoRALayerWeights.pack(subloras)
            model.loras[module_name] = lora
        return model

    def _match_target_modules(self, module_name: str):
        """True if ``module_name`` equals, or ends in ``.<name>`` for, a supported module."""
        return any(
            re.match(
                r".*\.{target_module}$".format(target_module=target_module), module_name
            )
            or target_module == module_name
            for target_module in self.supported_lora_modules
        )

    def _filter_unsupported_mm_module(self, module_name: str) -> bool:
        """
        Regarding multimodal models, vLLM currently only supports adding LoRA to
        language model. LoRA for other modules, such as the vision tower, will
        be filtered out.
        """
        if self.supports_mm:
            module_mapping: MultiModelKeys = self.model.get_mm_mapping()
            prefix_lst = module_mapping.connector + module_mapping.tower_model
            return any(module_name.startswith(prefix) for prefix in prefix_lst)
        return False

    def _register_packed_modules(self, module_full_name: str) -> None:
        """Record the full names of the modules packed into ``module_full_name``."""
        parts = module_full_name.split(".")
        module_name = parts[-1]
        replacements = self.packed_modules_mapping.get(module_name, [])
        # When replacements is less than or equal to 1, it indicates that this
        # module is not a packed module.
        if len(replacements) <= 1:
            return
        prefix = ".".join(parts[:-1])
        self.packed_modules[module_full_name] = [
            prefix + "." + r if prefix else r for r in replacements
        ]

    def _create_merged_loras_inplace(self, lora_model: LoRAModel) -> None:
        """Merge per-submodule LoRAs of each packed module into one entry.

        Mutates ``lora_model.loras``: it adds a ``PackedLoRALayerWeights`` for
        every packed module with at least one submodule LoRA, then removes
        the merged submodule entries.
        """
        for module_name, new_module_names in self.packed_modules.items():
            replacement_loras: list[LoRALayerWeights | None] = []
            replaced_module: set[str] = set()
            for r in new_module_names:
                lora = self._get_lora_layer_weights(lora_model, r)
                # Falsy entries are normalized to None for pack().
                replacement_loras.append(lora if lora else None)
                if lora:
                    replaced_module.add(r)
            if not replaced_module:
                continue
            # HACK Temporary solution for the pool model.
            if self.is_pooling_model and not lora_model.check_lora_name(module_name):
                replaced_module_name = module_name.replace("model.", "")
                # Must test the stripped name: re-testing ``module_name`` (as
                # before) is always False here, so the branch was dead.
                if lora_model.check_lora_name(replaced_module_name):
                    module_name = replaced_module_name
            lora_model.loras[module_name] = PackedLoRALayerWeights.pack(
                replacement_loras
            )
            # Remove the modules that have been replaced.
            for module in replaced_module:
                lora_model.loras.pop(module, None)

    def _get_lora_layer_weights(
        self, lora_model: LoRAModel, module_name: str
    ) -> LoRALayerWeights | None:
        """Look up ``module_name`` in ``lora_model``.

        For pooling models, a miss is retried with every "model." removed
        from the name.
        """
        org_module_name = module_name
        if self.is_pooling_model and not lora_model.check_lora_name(module_name):
            # If it's a pool model, and the layer name is not found,
            # remove the prefix 'model.' and search again.
            module_name = module_name.replace("model.", "")
            if lora_model.check_lora_name(module_name):
                org_module_name = module_name
                logger.info_once(
                    "For the pool model, successfully loaded the LoRA weights "
                    "after removing the prefix 'model.'."
                )
        return lora_model.get_lora(org_module_name)

    def deactivate_adapter(self, adapter_id: int) -> bool:
        """Deactivate an active adapter; return False if it was not active."""
        if adapter_id not in self._active_adapters:
            return False
        self._deactivate_adapter(adapter_id)
        self._active_adapters.pop(adapter_id, None)
        return True

    def add_adapter(self, adapter: LoRAModel) -> bool:
        """Register ``adapter``.

        Returns:
            False if its id is already registered, True otherwise.

        Raises:
            RuntimeError: If ``capacity`` adapters are already registered.
        """
        logger.debug("Adding lora. Model id: %d, int id: %d", adapter.id, adapter.id)
        if adapter.id in self._registered_adapters:
            return False
        if len(self._registered_adapters) >= self.capacity:
            raise RuntimeError("No free adapter slots.")
        self._add_adapter(adapter)
        return True

    def set_adapter_mapping(self, mapping: LoRAMapping) -> None:
        """Apply ``mapping`` unless it equals the last one applied."""
        if self._last_mapping != mapping:
            self._set_adapter_mapping(mapping)
            self._last_mapping = mapping

    def remove_adapter(self, adapter_id: int) -> bool:
        """Deactivate and unregister an adapter; False if it was not registered."""
        self.deactivate_adapter(adapter_id)
        if adapter_id not in self._registered_adapters:
            return False
        self._registered_adapters.pop(adapter_id, None)
        return True

    def list_adapters(self) -> dict[int, LoRAModel]:
        return dict(self._registered_adapters)

    def get_adapter(self, adapter_id: int) -> LoRAModel | None:
        return self._registered_adapters.get(adapter_id)


class LoRALRUCache(AdapterLRUCache[LoRAModel]):
    """``AdapterLRUCache`` specialised to hold ``LoRAModel`` values.

    Args:
        capacity: maximum number of entries kept in the cache.
        deactivate_lora_fn: callback forwarded to ``AdapterLRUCache`` as its
            ``deactivate_fn``; ``AdapterLRUCache._on_remove`` calls it with
            the int id of each entry being removed.
    """

    def __init__(self, capacity: int, deactivate_lora_fn: Callable[[int], bool]):
        super().__init__(capacity=capacity, deactivate_fn=deactivate_lora_fn)


class LRUCacheLoRAModelManager(LoRAModelManager):
    """A model manager that manages multiple LoRAs with LRU cache.

    Two ``LoRALRUCache`` instances back the manager:

    * ``_registered_adapters`` (capacity ``self.capacity``): all registered
      LoRAs. When an entry is removed, ``self.deactivate_adapter`` is called
      with its id.
    * ``_active_adapters`` (capacity ``self.lora_slots``): the currently
      activated LoRAs. When an entry is removed,
      ``self._deactivate_adapter`` is called with its id.
    """

    def __init__(
        self,
        model: nn.Module,
        max_num_seqs: int,
        max_num_batched_tokens: int,
        vocab_size: int,
        lora_config: LoRAConfig,
        device: torch.device,
    ):
        super().__init__(
            model, max_num_seqs, max_num_batched_tokens, vocab_size, lora_config, device
        )
        # Both registries are LRU caches whose removal callback keeps the
        # manager's activation state consistent with the cache contents.
        self._registered_adapters: LoRALRUCache = LoRALRUCache(
            capacity=self.capacity,
            deactivate_lora_fn=self.deactivate_adapter,
        )
        self._active_adapters: LoRALRUCache = LoRALRUCache(
            capacity=self.lora_slots,
            deactivate_lora_fn=self._deactivate_adapter,
        )

    def list_adapters(self) -> dict[int, LoRAModel]:
        """List all registered LoRAModels as a new ``{id: model}`` dict."""
        # Copies the cache's ``.cache`` storage rather than the cache object
        # itself; presumably so listing does not disturb LRU order — confirm.
        underlying = self._registered_adapters.cache
        return dict(underlying)

    def add_adapter(self, lora: LoRAModel) -> bool:
        """Add a LoRAModel to the manager.

        Returns:
            True if ``lora`` was newly registered. If its id is already
            registered, the entry's LRU position is refreshed with ``touch``
            and False is returned.
        """
        logger.debug("Adding lora. Model id: %d, int id: %d", lora.id, lora.id)
        registry = self._registered_adapters
        if lora.id in registry:
            # We always touch to update the LRU cache order
            registry.touch(lora.id)
            return False
        self._add_adapter(lora)
        return True

    def activate_adapter(
        self,
        lora_id: int,
    ) -> bool:
        """Activate ``lora_id`` via the base class, making room if needed.

        When ``lora_id`` is not yet active and all ``self.lora_slots`` are
        occupied, the least recently used active LoRA is evicted first.
        Afterwards ``lora_id`` is touched so it becomes the most recently
        used active entry.

        Returns:
            Whatever the base-class ``activate_adapter`` returns.
        """
        needs_slot = lora_id not in self._active_adapters
        if needs_slot and len(self._active_adapters) >= self.lora_slots:
            self._active_adapters.remove_oldest()
        outcome = super().activate_adapter(lora_id)
        # We always touch to update the LRU cache order
        self._active_adapters.touch(lora_id)
        return outcome

    def remove_oldest_adapter(self) -> bool:
        """Evict the least recently used registered LoRA.

        Returns:
            False if nothing is registered, True after evicting one entry.
        """
        registry = self._registered_adapters
        if len(registry) == 0:
            return False
        registry.remove_oldest()
        return True

    def pin_adapter(self, lora_id: int) -> bool:
        """Pin a LoRAModel in the manager cache.

        The LoRA is pinned in the registered-adapter cache, then activated
        if necessary and pinned in the active-adapter cache.

        Raises:
            ValueError: if ``lora_id`` is not registered.
        """
        self._pin_lora_in_cpu_cache(lora_id)
        self._pin_lora_in_gpu_cache(lora_id)
        return True

    def _pin_lora_in_cpu_cache(self, lora_id: int):
        """Pin ``lora_id`` in ``_registered_adapters``.

        The cache's ``ValueError`` is re-raised with a clearer message.
        """
        try:
            self._registered_adapters.pin(lora_id)
        except ValueError as exc:
            raise ValueError(
                f"Pinning failed. LoRA {lora_id} is not registered."
            ) from exc

    def _pin_lora_in_gpu_cache(self, lora_id: int):
        """Activate ``lora_id`` if needed, then pin it in ``_active_adapters``."""
        already_active = lora_id in self._active_adapters
        if not already_active:
            # move lora to gpu if not already active
            self.activate_adapter(lora_id)
        self._active_adapters.pin(lora_id)

def create_lora_manager(
    model: nn.Module,
    max_num_seqs: int,
    max_num_batched_tokens: int,
    vocab_size: int,
    lora_config: LoRAConfig,
    device: torch.device,
    lora_manager_cls: type[LoRAModelManager] = LoRAModelManager,
    **kwargs,
) -> LoRAModelManager:
    """Build a LoRA manager for ``model``.

    Args:
        model: the model to attach LoRA support to; must be ``SupportsLoRA``.
        max_num_seqs, max_num_batched_tokens, vocab_size, lora_config, device:
            forwarded by keyword to ``lora_manager_cls``.
        lora_manager_cls: manager class to instantiate
            (``LoRAModelManager`` by default).
        **kwargs: extra keyword arguments forwarded to ``lora_manager_cls``.

    Returns:
        The newly constructed manager instance.

    Raises:
        ValueError: if ``model`` is not an instance of ``SupportsLoRA``.
    """
    if not isinstance(model, SupportsLoRA):
        raise ValueError(f"Model {type(model)} is not supported for LoRA.")
    return lora_manager_cls(
        model=model,
        max_num_seqs=max_num_seqs,
        max_num_batched_tokens=max_num_batched_tokens,
        vocab_size=vocab_size,
        lora_config=lora_config,
        device=device,
        **kwargs,
    )