models.py 32.9 KB
Newer Older
1
# SPDX-License-Identifier: Apache-2.0
2
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
3

4
5
import math
import os
6
from typing import Callable, Optional, TypeVar, Union
7

8
import regex as re
9
10
11
12
import safetensors.torch
import torch
from torch import nn

13
from vllm.config.lora import LoRAConfig
14
from vllm.logger import init_logger
15
from vllm.lora.layers import BaseLayerWithLoRA, LoRAMapping
16
from vllm.lora.lora_weights import LoRALayerWeights, PackedLoRALayerWeights
17
from vllm.lora.peft_helper import PEFTHelper
18
from vllm.lora.punica_wrapper import get_punica_wrapper
19
20
21
22
23
24
25
26
from vllm.lora.utils import (
    from_layer,
    from_layer_logits_processor,
    get_supported_lora_modules,
    is_regex_target_modules,
    parse_fine_tuned_lora_name,
    replace_submodule,
)
27
from vllm.model_executor.layers.fused_moe import FusedMoE
28
from vllm.model_executor.model_loader.tensorizer import TensorizerConfig
29
from vllm.model_executor.models import SupportsLoRA, supports_multimodal
30
from vllm.model_executor.models.interfaces import is_pooling_model
31
from vllm.model_executor.models.module_mapping import MultiModelKeys
32
from vllm.model_executor.models.utils import PPMissingLayer, WeightsMapper
33
from vllm.model_executor.utils import get_packed_modules_mapping
34
35
from vllm.utils import is_pin_memory_available
from vllm.utils.cache import LRUCache
36

37
logger = init_logger(__name__)
38

39
40
41
42
43
44
45
46
47
48
49
50
51
52
T = TypeVar("T")


class AdapterLRUCache(LRUCache[int, T]):
    """LRU cache keyed by adapter int id.

    Whenever the base cache reports a removal through ``_on_remove``, the
    configured ``deactivate_fn`` is invoked with the adapter id before the
    base-class hook runs.
    """

    def __init__(self, capacity: int, deactivate_fn: Callable[[int], object]):
        super().__init__(capacity)
        # Callback receiving the int id of each adapter being removed.
        self.deactivate_fn = deactivate_fn

    def _on_remove(self, key: int, value: Optional[T]):
        logger.debug("Removing adapter int id: %d", key)
        on_deactivate = self.deactivate_fn
        on_deactivate(key)
        return super()._on_remove(key, value)


53
54
55
56
57
58
59
60
61
_GLOBAL_LORA_ID = 0


def get_lora_id():
    """Return a fresh integer LoRA id from a module-level counter.

    The counter starts at 0 and is incremented before being returned, so the
    first id handed out is 1 and ids increase by one per call.
    """
    global _GLOBAL_LORA_ID
    _GLOBAL_LORA_ID = _GLOBAL_LORA_ID + 1
    return _GLOBAL_LORA_ID


62
63
64
65
66
67
def is_moe_model(model: nn.Module) -> bool:
    """Return True if ``model`` contains at least one FusedMoE layer.

    When it does, a warning (via ``logger.warning_once``) tells the user that
    fused MoE LoRA inference is not supported.
    """
    has_fused_moe = any(isinstance(m, FusedMoE) for m in model.modules())
    if not has_fused_moe:
        return False
    logger.warning_once(
        "For MoE models, vLLM currently does not support fused MoE LoRA "
        "inference. Please ensure that the loaded LoRA model does not "
        "contain expert weights."
    )
    return True


74
class LoRAModel:
    """A LoRA fine-tuned model.

    Holds the per-module LoRA weights of one adapter, keyed by the module name
    produced by ``parse_fine_tuned_lora_name``.
    """

    def __init__(
        self,
        lora_model_id: int,
        rank: int,
        loras: dict[str, LoRALayerWeights],
    ) -> None:
        """
        Args:
            lora_model_id: The integer id for the lora model. Must be > 0.
            rank: lora rank.
            loras: module name -> weights for lora-replaced layers.
        """
        self.id = lora_model_id

        assert lora_model_id > 0, (
            f"a valid lora id should be greater than 0, got {self.id}"
        )
        self.rank = rank
        self.loras: dict[str, LoRALayerWeights] = loras

    def clone(self, lora_model_id: int) -> "LoRAModel":
        """Return a copy of the object with a different id.

        Only the ``loras`` dict is copied (shallowly); the LoRALayerWeights
        objects and their tensors are shared with the original."""
        return self.__class__(
            lora_model_id,
            rank=self.rank,
            loras=self.loras.copy(),
        )

    @property
    def extra_vocab_size(self) -> int:
        """Largest ``extra_vocab_size`` over all layer weights (0 if empty)."""
        return (
            max(lora.extra_vocab_size for lora in self.loras.values())
            if self.loras
            else 0
        )

    def get_lora(self, module_name: str) -> Optional[LoRALayerWeights]:
        """Get LoRA for a given module by name, or None if absent."""
        return self.loras.get(module_name, None)

    def check_lora_name(self, lora_name: str) -> bool:
        """Return True if weights are registered under ``lora_name``."""
        return lora_name in self.loras

    # (yard1): TODO see if we can derive target_embedding_padding automatically
    @classmethod
    def from_lora_tensors(
        cls,
        lora_model_id: int,
        tensors: dict[str, torch.Tensor],
        peft_helper: PEFTHelper,
        device: str = "cuda",
        dtype: Optional[torch.dtype] = None,
        embeddings: Optional[dict[str, torch.Tensor]] = None,
        target_embedding_padding: Optional[int] = None,
        embedding_modules: Optional[dict[str, str]] = None,
        embedding_padding_modules: Optional[list[str]] = None,
        weights_mapper: Optional[WeightsMapper] = None,
    ) -> "LoRAModel":
        """Create a LoRAModel from a dictionary of tensors.

        Args:
            lora_model_id: id assigned to the new model.
            tensors: checkpoint tensor name -> tensor. Each name is split by
                ``parse_fine_tuned_lora_name`` into a module name and a flag
                telling whether it is the A or the B matrix.
            peft_helper: adapter config; ``peft_helper.r`` becomes the rank.
            device: device every tensor is moved to. When it is "cpu" and
                pinned memory is available, tensors are also pinned.
            dtype: dtype every tensor is cast to (None keeps the original).
            embeddings: optional extra embedding tensors, selected through
                ``embedding_modules``.
            target_embedding_padding: when set, lora_b tensors of modules
                matching ``embedding_padding_modules`` are zero-padded up to
                this many rows.
            embedding_modules: substring of a module name -> key into
                ``embeddings``. Required whenever ``embeddings`` is non-empty.
            embedding_padding_modules: substrings selecting modules whose
                lora_b may be padded. None means "no modules".
            weights_mapper: optional mapper applied while parsing names.

        Returns:
            The assembled LoRAModel (each layer's ``optimize()`` already run).
        """
        pin_memory = str(device) == "cpu" and is_pin_memory_available()
        # None is treated as "nothing to pad". Previously this was asserted
        # non-None on every lora_b tensor, so the default value always failed.
        padding_modules = embedding_padding_modules or []
        loras: dict[str, LoRALayerWeights] = {}
        for tensor_name, tensor in tensors.items():
            module_name, is_lora_a = parse_fine_tuned_lora_name(
                tensor_name, weights_mapper
            )
            if module_name not in loras:
                lora_embeddings_tensor = None
                if embeddings:
                    assert embedding_modules is not None
                    embeddings_module = next(
                        (k for k in embedding_modules if k in module_name), None
                    )
                    if embeddings_module:
                        lora_embeddings_tensor = embeddings[
                            embedding_modules[embeddings_module]
                        ].to(device=device, dtype=dtype)
                        if pin_memory:
                            lora_embeddings_tensor = lora_embeddings_tensor.pin_memory()
                loras[module_name] = LoRALayerWeights.from_config(
                    module_name, peft_helper, lora_embeddings_tensor
                )

            layer = loras[module_name]
            moved = tensor.to(device=device, dtype=dtype)
            if is_lora_a:
                layer.lora_a = moved.pin_memory() if pin_memory else moved
                continue

            if target_embedding_padding is not None and any(
                name in module_name for name in padding_modules
            ):
                # Zero-pad the second-to-last dim up to the target size
                # (assumes lora_b is 2-D, so that is dim 0 — TODO confirm).
                assert target_embedding_padding >= moved.shape[0]
                addition = target_embedding_padding - moved.shape[0]
                moved = torch.nn.functional.pad(moved, (0, 0, 0, addition))
            layer.lora_b = moved.pin_memory() if pin_memory else moved

        for lora in loras.values():
            lora.optimize()

        return cls(lora_model_id, peft_helper.r, loras)

    @classmethod
    def from_local_checkpoint(
        cls,
        lora_dir: str,
        expected_lora_modules: list[str],
        peft_helper: PEFTHelper,
        *,
        lora_model_id: Optional[int] = None,
        device: str = "cuda",
        dtype: Optional[torch.dtype] = None,
        target_embedding_padding: Optional[int] = None,
        embedding_modules: Optional[dict[str, str]] = None,
        embedding_padding_modules: Optional[list[str]] = None,
        weights_mapper: Optional[WeightsMapper] = None,
        tensorizer_config_dict: Optional[dict] = None,
    ) -> "LoRAModel":
        """Create a LoRAModel from a local checkpoint.

        Tensors are read from the first source that applies: a tensorizer
        file (when ``tensorizer_config_dict`` is given), then
        ``adapter_model.safetensors``, then ``adapter_model.bin`` /
        ``adapter_model.pt``. Optional new embeddings are read from
        ``new_embeddings.safetensors`` or ``new_embeddings.bin``.

        Args:
            lora_dir: The local path that has lora data.
            expected_lora_modules: Name of modules that are expected to be
                replaced by lora.
            peft_helper: Loaded lora configuration information.
            lora_model_id: LoRA model id. If not given, automatically set by
                a global counter.
            device: Device where the lora model is loaded.
            dtype: dtype of the lora model weights.
            target_embedding_padding: forwarded to ``from_lora_tensors``.
            embedding_modules: forwarded to ``from_lora_tensors``.
            embedding_padding_modules: forwarded to ``from_lora_tensors``.
            weights_mapper: name mapper used for parsing and forwarding.
            tensorizer_config_dict: kwargs for ``TensorizerConfig``; when set,
                tensors are deserialized with tensorizer.

        Returns:
            Loaded LoRA Model.

        Raises:
            ValueError: if the checkpoint targets modules outside
                ``expected_lora_modules`` or ``lora_dir`` has no tensors.
        """
        lora_tensor_path = os.path.join(lora_dir, "adapter_model.safetensors")
        lora_bin_file_path = os.path.join(lora_dir, "adapter_model.bin")
        lora_pt_file_path = os.path.join(lora_dir, "adapter_model.pt")
        new_embeddings_tensor_path = os.path.join(
            lora_dir, "new_embeddings.safetensors"
        )
        new_embeddings_bin_file_path = os.path.join(lora_dir, "new_embeddings.bin")
        tensors: dict[str, torch.Tensor] = {}
        # Filled by whichever loading branch runs (only one ever does).
        unexpected_modules: list[Union[list[str], str]] = []

        def raise_unexpected_modules() -> None:
            raise ValueError(
                f"While loading {lora_dir}, expected"
                f" target modules in {expected_lora_modules}"
                f" but received {unexpected_modules}."
                " Please verify that the loaded LoRA module is correct"
            )

        def check_unexpected_modules(modules: dict):
            for lora_module in modules.keys():  # noqa
                module_name, _ = parse_fine_tuned_lora_name(lora_module, weights_mapper)
                part_name = module_name.split(".")[-1]
                if part_name not in expected_lora_modules:
                    unexpected_modules.append(module_name)
            if unexpected_modules:
                raise_unexpected_modules()

        if tensorizer_config_dict:
            from tensorizer import TensorDeserializer

            tensorizer_config = TensorizerConfig(**tensorizer_config_dict)
            lora_tensor_path = os.path.join(
                tensorizer_config.tensorizer_dir, "adapter_model.tensors"
            )
            tensorizer_args = tensorizer_config._construct_tensorizer_args()
            tensors = TensorDeserializer(
                lora_tensor_path,
                dtype=tensorizer_config.dtype,
                **tensorizer_args.deserialization_kwargs,
            )
            check_unexpected_modules(tensors)

        elif os.path.isfile(lora_tensor_path):
            # Find unexpected modules.
            # Use safetensor key as a source of truth to find expected modules.
            # in peft if you have target_modules A, B, C and C does not exist
            # in the model it won't error and model will be trained with A, B
            # loraified. C won't exist in the safetensor but it will exist in
            # the target_modules of the adapter_config.json.
            with safetensors.safe_open(lora_tensor_path, framework="pt") as f:  # type: ignore
                # Load tensors if there are only expected modules.
                check_unexpected_modules(f)
                for module in f.keys():  # noqa
                    tensors[module] = f.get_tensor(module)
        elif os.path.isfile(lora_bin_file_path) or os.path.isfile(lora_pt_file_path):
            # When a bin/pt file is provided, we rely on config to find
            # unexpected modules.
            target_modules = peft_helper.target_modules
            if not isinstance(target_modules, list):
                target_modules = [target_modules]
            for module in target_modules:
                # Compatible with more modules,
                # such as:layers.11.self_attn.k_proj
                part_name = module.split(".")[-1]
                if part_name not in expected_lora_modules:
                    unexpected_modules.append(module)
            # loaded lora's target modules must be a subset of
            # expected_lora_modules. It is not reliable. See
            # https://github.com/vllm-project/vllm/pull/5909. But there's no
            # other better mechanism.
            if unexpected_modules and not is_regex_target_modules(
                peft_helper.target_modules, expected_lora_modules
            ):
                raise_unexpected_modules()
            lora_file_path = (
                lora_bin_file_path
                if os.path.isfile(lora_bin_file_path)
                else lora_pt_file_path
            )
            tensors = torch.load(lora_file_path, map_location=device, weights_only=True)
        else:
            raise ValueError(f"{lora_dir} doesn't contain tensors")

        embeddings = None
        if os.path.isfile(new_embeddings_tensor_path):
            embeddings = safetensors.torch.load_file(new_embeddings_tensor_path)
        elif os.path.isfile(new_embeddings_bin_file_path):
            embeddings = torch.load(
                new_embeddings_bin_file_path, map_location=device, weights_only=True
            )

        return cls.from_lora_tensors(
            lora_model_id=get_lora_id() if lora_model_id is None else lora_model_id,
            tensors=tensors,
            peft_helper=peft_helper,
            device=device,
            dtype=dtype,
            embeddings=embeddings,
            target_embedding_padding=target_embedding_padding,
            embedding_modules=embedding_modules,
            embedding_padding_modules=embedding_padding_modules,
            weights_mapper=weights_mapper,
        )
325
326


327
class LoRAModelManager:
    """A manager that manages multiple LoRA-fine-tuned models."""

    def __init__(
        self,
        model: SupportsLoRA,
        max_num_seqs: int,
        max_num_batched_tokens: int,
        vocab_size: int,
        lora_config: LoRAConfig,
        device: torch.device,
    ):
        """Create a LoRAModelManager and adapter for a given model.

        Args:
            model: the model to be adapted.
            max_num_seqs: the maximum number of sequences model can run in a
                single batch.
            max_num_batched_tokens: the maximum number of tokens model can run
                in a single batch.
            vocab_size: the vocab size of the model.
            lora_config: the LoRA configuration.
            device: device the punica wrapper is created on.
        """
        self.model: SupportsLoRA = model
        self._registered_adapters: dict[int, LoRAModel] = {}
        # Dict instead of a set for compatibility with LRUCache.
        self._active_adapters: dict[int, None] = {}
        self.adapter_type = "LoRA"
        self.lora_config = lora_config
        self.device = device
        self.max_num_seqs = max_num_seqs
        assert self.capacity >= self.lora_slots
        # Rounded up to a multiple of 8.
        self.max_num_batched_tokens = math.ceil(max_num_batched_tokens / 8) * 8
        # Slot index -> id of the LoRA occupying that slot (None = free).
        self.lora_index_to_id: list[Optional[int]] = [None] * self.lora_slots
        self.vocab_size = vocab_size
        # NOTE(review): the wrapper gets the un-rounded token count, unlike
        # self.max_num_batched_tokens above — confirm this is intended.
        self.punica_wrapper = get_punica_wrapper(
            max_num_batched_tokens,
            max_batches=self.max_num_seqs,
            device=self.device,
            max_loras=self.lora_config.max_loras,
        )

        self.supported_lora_modules = get_supported_lora_modules(self.model)
        # The whole message must be one parenthesized expression; previously
        # the model-name part was a separate no-op statement.
        assert self.supported_lora_modules, (
            "No supported LoRA modules found in"
            f" {self.model.__class__.__name__}."
        )

        self.packed_modules_mapping = get_packed_modules_mapping(self.model)
        # Used to indicate whether the model is a multimodal model
        self.supports_mm: bool = (
            supports_multimodal(self.model)
            # In case the model only supports LoRA for
            # text modules (e.g. ChatGLM)
            and hasattr(self.model, "get_mm_mapping")
        )
        self.is_pooling_model = is_pooling_model(self.model)
        self.is_moe_model = is_moe_model(self.model)
        # Packed module full name -> full names of its sub-modules.
        self.packed_modules: dict[str, list[str]] = {}
        # Module full name -> LoRA-wrapped layer registered for it.
        self.modules: dict[str, BaseLayerWithLoRA] = {}
        # Last mapping pushed to the punica wrapper; used to skip repeats.
        self._last_mapping: Optional[LoRAMapping] = None
        self._create_lora_modules()
        self.model.lora_manager = self

    def __len__(self) -> int:
        return len(self._registered_adapters)

    @property
    def capacity(self) -> int:
        """Maximum number of registered adapters (``max_cpu_loras``)."""
        return self.lora_config.max_cpu_loras

    @property
    def lora_slots(self) -> int:
        """Number of adapter slots usable at once (``max_loras``)."""
        return self.lora_config.max_loras

    @property
    def adapter_slots(self) -> int:
        return self.lora_slots

    def activate_adapter(
        self,
        lora_id: int,
    ) -> bool:
        """Move LoRA into a GPU buffer to be used in the forward pass.

        Returns:
            False if the adapter is already active, True once it is loaded
            into a free slot.

        Raises:
            ValueError: if no slot is free.
        """
        if lora_id in self._active_adapters:
            return False
        index = next(
            (i for i, slot_id in enumerate(self.lora_index_to_id) if slot_id is None),
            None,
        )
        if index is None:
            raise ValueError("No free lora slots")
        # Look the adapter up before marking it active so that a failed
        # lookup leaves no stale entry in _active_adapters.
        lora_model = self._registered_adapters[lora_id]
        self._active_adapters[lora_id] = None
        logger.debug(
            "Activating LoRA. int id: %d, slot index: %d", lora_model.id, index
        )
        self.lora_index_to_id[index] = lora_model.id
        for module_name, module in self.modules.items():
            module_lora = self._get_lora_layer_weights(lora_model, module_name)
            if module_lora:
                module_lora.optimize()
                module.set_lora(
                    index,
                    module_lora.lora_a,
                    module_lora.lora_b,
                    module_lora.embeddings_tensor,
                )
            else:
                # This adapter has no weights for the module: clear the slot.
                module.reset_lora(index)
        return True

    def _deactivate_adapter(self, lora_id: int):
        """Free the slot held by ``lora_id``; no-op if it holds none."""
        try:
            index = self.lora_index_to_id.index(lora_id)
        except ValueError:
            return
        self.lora_index_to_id[index] = None

    def _add_adapter(self, lora: LoRAModel):
        """Pack sub-module weights into packed modules, then register."""
        self._create_merged_loras_inplace(lora)
        self._registered_adapters[lora.id] = lora

    def pin_adapter(self, lora_id: int) -> bool:
        """Pin a LoRAModel in the manager cache."""
        raise NotImplementedError(
            "Pinning is not supported in LoRAModelManager. "
            "Use LRUCacheLoRAModelManager for pinning"
        )

    def _set_adapter_mapping(self, mapping: LoRAMapping) -> None:
        # update lora states
        # NOTE(review): lora_slots + 1 presumably reserves an index for
        # "no LoRA" — confirm against the punica wrapper.
        self.punica_wrapper.update_metadata(
            mapping,
            self.lora_index_to_id,
            self.lora_slots + 1,
            self.vocab_size,
            self.lora_config.lora_extra_vocab_size,
        )

    def remove_all_adapters(self):
        """Remove all LoRAModels from the manager."""
        self._registered_adapters.clear()
        self.lora_index_to_id = [None] * self.lora_slots
        self._active_adapters.clear()

    def _create_lora_modules(self):
        """Replace each supported target module with its LoRA wrapper and
        register it (and its packed sub-modules) with this manager."""

        def _parent_module(module_name: str) -> str:
            # module name is a dot separated name.
            # for example:
            #  - given an input 'x.y.z' return 'x.y'
            #  - given an input 'x' return ''
            return module_name.rpartition(".")[0]

        for module_name, module in self.model.named_modules(remove_duplicate=False):
            if isinstance(module, PPMissingLayer):
                continue

            if not self._match_target_modules(module_name):
                continue
            # A temporary approach for multimodal models to support LoRA
            # TODO: Remove this restriction
            if self._filter_unsupported_mm_module(module_name):
                logger.warning(
                    "Regarding multimodal models, vLLM currently only supports "
                    "adding LoRA to language model, %s will be ignored.",
                    module_name,
                )
                continue
            leaf_name = module_name.split(".")[-1]
            packed_module_list = self.packed_modules_mapping.get(leaf_name, [])
            new_module = replace_submodule(
                self.model,
                module_name,
                from_layer(
                    module,
                    self.lora_slots,
                    self.lora_config,
                    packed_module_list,
                    self.model.config,
                ),
            )

            # (yard1): TODO make this more robust
            if "lm_head" in module_name:
                logits_processor_module_name = "logits_processor"
                parent_module = _parent_module(module_name)
                if parent_module:
                    logits_processor_module_name = (
                        f"{parent_module}.{logits_processor_module_name}"
                    )

                logits_processor_module = self.model.get_submodule(
                    logits_processor_module_name
                )

                new_module = replace_submodule(
                    self.model,
                    logits_processor_module_name,
                    from_layer_logits_processor(
                        logits_processor_module,
                        module,
                        self.lora_slots,
                        self.lora_config,
                        self.model.config,
                    ),
                )

            # In some models, especially multimodal ones, layers with the same
            # name may have different types, such as nn.Linear and
            # ReplicatedLinear. The nn.Linear layers cannot be replaced with
            # LoRA layers, leading to assertion error. The following check
            # aims to prevent this error
            if self.supports_mm and not isinstance(new_module, BaseLayerWithLoRA):
                continue
            self.register_module(module_name, new_module)
            self._register_packed_modules(module_name)
            # All lora layers share the same punica_wrapper based on reference.
            new_module.set_mapping(self.punica_wrapper)

    def register_module(self, module_name: str, module: "BaseLayerWithLoRA"):
        assert isinstance(module, BaseLayerWithLoRA)
        self.modules[module_name] = module

    def create_dummy_lora(
        self,
        lora_id: int,
        rank: int,
        embedding_modules: Optional[dict[str, str]] = None,
    ) -> LoRAModel:
        """Create zero-initialized LoRAModel for warmup.

        Args:
            lora_id: id of the dummy model.
            rank: LoRA rank of every dummy layer.
            embedding_modules: modules whose last dotted name component is a
                key here get embedding-shaped dummy weights. None is treated
                as an empty mapping (previously it was asserted non-None).
        """
        if embedding_modules is None:
            embedding_modules = {}
        model = LoRAModel(lora_id, rank, {})
        for module_name, module in self.model.named_modules():
            if (
                not self._match_target_modules(module_name)
                or not isinstance(module, BaseLayerWithLoRA)
                or self._filter_unsupported_mm_module(module_name)
            ):
                continue
            leaf_name = module_name.split(".")[-1]
            if module_name not in self.packed_modules:
                if leaf_name in embedding_modules:
                    input_dim = (
                        module.base_layer.org_vocab_size
                        + self.lora_config.lora_extra_vocab_size
                        if hasattr(module.base_layer, "org_vocab_size")
                        else module.base_layer.weight.shape[1]
                    )
                    output_dim = (
                        module.base_layer.embedding_dim
                        if hasattr(module.base_layer, "embedding_dim")
                        else module.base_layer.weight.shape[0]
                    )
                    embeddings_tensor_dim = (
                        module.base_layer.embedding_dim
                        if hasattr(module.base_layer, "embedding_dim")
                        else module.base_layer.weight.shape[1]
                    )
                    lora = LoRALayerWeights.create_dummy_lora_weights(
                        module_name,
                        input_dim,
                        output_dim,
                        rank,
                        module.lora_a_stacked[0].dtype,
                        "cpu",
                        embeddings_tensor_dim=embeddings_tensor_dim,
                    )
                else:
                    lora = LoRALayerWeights.create_dummy_lora_weights(
                        module_name,
                        module.lora_a_stacked[0].shape[-1],
                        module.lora_b_stacked[0].shape[-2],
                        rank,
                        module.lora_a_stacked[0].dtype,
                        "cpu",
                    )
            else:
                replacements = self.packed_modules_mapping[leaf_name]
                subloras: list[Optional[LoRALayerWeights]] = [
                    LoRALayerWeights.create_dummy_lora_weights(
                        module_name + "." + r,
                        module.lora_a_stacked[i].shape[-1],
                        module.lora_b_stacked[i].shape[-2],
                        rank,
                        module.lora_a_stacked[i].dtype,
                        "cpu",
                    )
                    for i, r in enumerate(replacements)
                ]
                lora = PackedLoRALayerWeights.pack(subloras)
            model.loras[module_name] = lora
        return model

    def _match_target_modules(self, module_name: str):
        """True if ``module_name`` equals, or ends with ``.<name>`` for, any
        supported LoRA module name (names are used as regex fragments)."""
        return any(
            re.match(
                r".*\.{target_module}$".format(target_module=target_module), module_name
            )
            or target_module == module_name
            for target_module in self.supported_lora_modules
        )

    def _filter_unsupported_mm_module(self, module_name: str) -> bool:
        """
        Regarding multimodal models, vLLM currently only supports adding LoRA to
        language model. LoRA for other modules, such as the vision tower, will
        be filtered out.
        """
        if not self.supports_mm:
            return False
        module_mapping: MultiModelKeys = self.model.get_mm_mapping()
        prefix_lst = module_mapping.connector + module_mapping.tower_model
        return any(module_name.startswith(prefix) for prefix in prefix_lst)

    def _register_packed_modules(self, module_full_name: str) -> None:
        """Record the sub-module full names of ``module_full_name`` if the
        packed-modules mapping lists more than one for its leaf name."""
        parts = module_full_name.split(".")
        module_name = parts[-1]
        replacements = self.packed_modules_mapping.get(module_name, [])
        # When replacements is less than or equal to 1, it indicates that this
        # module is not a packed module.
        if len(replacements) <= 1:
            return
        prefix = ".".join(parts[:-1])
        self.packed_modules[module_full_name] = [
            prefix + "." + r if prefix else r for r in replacements
        ]

    def _create_merged_loras_inplace(self, lora_model: LoRAModel) -> None:
        """Pack per-sub-module weights of ``lora_model`` into one entry per
        packed module and drop the sub-module entries that were used."""
        for module_name, new_module_names in self.packed_modules.items():
            replacement_loras: list[Optional[LoRALayerWeights]] = []
            replaced_module: set[str] = set()
            has_replacement = False
            for r in new_module_names:
                lora = self._get_lora_layer_weights(lora_model, r)
                # Missing sub-modules are represented by None in the pack.
                replacement_loras.append(lora if lora else None)
                if lora:
                    has_replacement = True
                    replaced_module.add(r)
            if not has_replacement:
                continue
            # HACK Temporary solution for the pool model: if the packed name is
            # absent but its 'model.'-stripped form exists, store under the
            # stripped name (same fallback as _get_lora_layer_weights). The
            # inner check previously re-tested module_name, so it never fired.
            if self.is_pooling_model and not lora_model.check_lora_name(module_name):
                replaced_module_name = module_name.replace("model.", "")
                if lora_model.check_lora_name(replaced_module_name):
                    module_name = replaced_module_name
            lora_model.loras[module_name] = PackedLoRALayerWeights.pack(
                replacement_loras
            )
            # Remove the modules that have been replaced.
            for module in replaced_module:
                lora_model.loras.pop(module, None)

    def _get_lora_layer_weights(
        self, lora_model: LoRAModel, module_name: str
    ) -> Optional[LoRALayerWeights]:
        """Return ``lora_model``'s weights for ``module_name`` (or None).

        For pooling models, a miss is retried with every 'model.' removed
        from the name."""
        org_module_name = module_name
        if self.is_pooling_model and not lora_model.check_lora_name(module_name):
            # If it's a pool model, and the layer name is not found,
            # remove the prefix 'model.' and search again.
            module_name = module_name.replace("model.", "")
            if lora_model.check_lora_name(module_name):
                org_module_name = module_name
                logger.info_once(
                    "For the pool model, successfully loaded the LoRA weights "
                    "after removing the prefix 'model.'."
                )
        return lora_model.get_lora(org_module_name)

    def deactivate_adapter(self, adapter_id: int) -> bool:
        """Deactivate an active adapter; False if it was not active."""
        if adapter_id not in self._active_adapters:
            return False
        self._deactivate_adapter(adapter_id)
        self._active_adapters.pop(adapter_id, None)
        return True

    def add_adapter(self, adapter: LoRAModel) -> bool:
        """Register ``adapter``; False if its id is already registered.

        Raises:
            RuntimeError: if ``capacity`` adapters are already registered.
        """
        logger.debug("Adding lora. Model id: %d, int id: %d", adapter.id, adapter.id)
        if adapter.id in self._registered_adapters:
            return False
        if len(self._registered_adapters) >= self.capacity:
            raise RuntimeError("No free adapter slots.")
        self._add_adapter(adapter)
        return True

    def set_adapter_mapping(self, mapping: LoRAMapping) -> None:
        """Push ``mapping`` to the punica wrapper unless unchanged."""
        if self._last_mapping != mapping:
            self._set_adapter_mapping(mapping)
            self._last_mapping = mapping

    def remove_adapter(self, adapter_id: int) -> bool:
        """Deactivate and unregister an adapter; False if not registered."""
        self.deactivate_adapter(adapter_id)
        if adapter_id not in self._registered_adapters:
            return False
        self._registered_adapters.pop(adapter_id, None)
        return True

    def list_adapters(self) -> dict[int, LoRAModel]:
        return dict(self._registered_adapters)

    def get_adapter(self, adapter_id: int) -> Optional[LoRAModel]:
        return self._registered_adapters.get(adapter_id)
737
738
739


class LoRALRUCache(AdapterLRUCache[LoRAModel]):
    """LRU cache of ``LoRAModel`` objects keyed by integer adapter id.

    Whenever an entry is removed, ``AdapterLRUCache._on_remove`` calls the
    ``deactivate_lora_fn`` supplied to the constructor with the removed id.
    """

    def __init__(self, capacity: int, deactivate_lora_fn: Callable[[int], bool]):
        """Create a cache holding at most ``capacity`` adapters.

        Args:
            capacity: Capacity passed through to the underlying LRU cache.
            deactivate_lora_fn: Callback invoked with an adapter id when that
                entry is removed from the cache.
        """
        super().__init__(capacity, deactivate_lora_fn)


class LRUCacheLoRAModelManager(LoRAModelManager):
    """A model manager that manages multiple LoRAs with LRU cache."""

    def __init__(
        self,
        model: nn.Module,
        max_num_seqs: int,
        max_num_batched_tokens: int,
        vocab_size: int,
        lora_config: LoRAConfig,
        device: torch.device,
    ):
        super().__init__(
            model, max_num_seqs, max_num_batched_tokens, vocab_size, lora_config, device
        )
        # Registered ("CPU cache") adapters, sized by ``self.capacity``.
        # Removing an entry calls ``deactivate_adapter`` with its id.
        self._registered_adapters: LoRALRUCache = LoRALRUCache(
            self.capacity, self.deactivate_adapter
        )
        # Active ("GPU cache") adapters, sized by ``self.lora_slots``.
        # Removing an entry calls ``_deactivate_adapter`` directly.
        self._active_adapters: LoRALRUCache = LoRALRUCache(
            self.lora_slots, self._deactivate_adapter
        )

    def list_adapters(self) -> dict[int, LoRAModel]:
        """List all registered LoRAModels."""
        return dict(self._registered_adapters.cache)

    def add_adapter(self, lora: LoRAModel) -> bool:
        """Add a LoRAModel to the manager.

        No explicit capacity check is made here; the registered-adapter
        cache was created with ``self.capacity`` and is left to manage its
        own size. Re-adding an already registered adapter only refreshes its
        LRU position.

        Returns:
            ``True`` if the adapter was newly added, ``False`` if it was
            already registered.
        """
        logger.debug("Adding lora. Model id: %d, int id: %d", lora.id, lora.id)
        if lora.id in self._registered_adapters:
            # We always touch to update the LRU cache order
            self._registered_adapters.touch(lora.id)
            return False
        self._add_adapter(lora)
        return True

    def activate_adapter(
        self,
        lora_id: int,
    ) -> bool:
        """Activate ``lora_id``, freeing the oldest active slot if all are used.

        Returns:
            Whatever the parent class ``activate_adapter`` returns.
        """
        if (
            lora_id not in self._active_adapters
            and len(self._active_adapters) >= self.lora_slots
        ):
            # Every slot is occupied by another adapter: drop the oldest one.
            self._active_adapters.remove_oldest()
        result = super().activate_adapter(lora_id)
        # We always touch to update the LRU cache order
        self._active_adapters.touch(lora_id)
        return result

    def remove_oldest_adapter(self) -> bool:
        """Remove the oldest registered adapter.

        Returns:
            ``True`` if an adapter was removed, ``False`` if none were
            registered.
        """
        if len(self._registered_adapters) > 0:
            self._registered_adapters.remove_oldest()
            return True
        return False

    def pin_adapter(self, lora_id: int) -> bool:
        """Pin a LoRAModel in the manager cache.

        The adapter is pinned in both the registered (CPU) cache and the
        active (GPU) cache, activating it first if necessary.

        Raises:
            ValueError: If ``lora_id`` is not registered.
        """
        self._pin_lora_in_cpu_cache(lora_id)
        self._pin_lora_in_gpu_cache(lora_id)
        return True

    def _pin_lora_in_cpu_cache(self, lora_id: int):
        """Pin ``lora_id`` in the registered-adapter cache."""
        try:
            self._registered_adapters.pin(lora_id)
        except ValueError as err:
            raise ValueError(
                f"Pinning failed. LoRA {lora_id} is not registered."
            ) from err

    def _pin_lora_in_gpu_cache(self, lora_id: int):
        """Pin ``lora_id`` in the active-adapter cache, activating it first."""
        if lora_id not in self._active_adapters:
            # move lora to gpu if not already active
            self.activate_adapter(lora_id)

        self._active_adapters.pin(lora_id)

def create_lora_manager(
    model: nn.Module,
    max_num_seqs: int,
    max_num_batched_tokens: int,
    vocab_size: int,
    lora_config: LoRAConfig,
    device: torch.device,
    lora_manager_cls: type[LoRAModelManager] = LoRAModelManager,
    **kwargs,
) -> LoRAModelManager:
    """Create a LoRA model manager for a given model.

    Args:
        model: Model to manage LoRA adapters for; must be an instance of
            ``SupportsLoRA``.
        max_num_seqs: Forwarded to ``lora_manager_cls``.
        max_num_batched_tokens: Forwarded to ``lora_manager_cls``.
        vocab_size: Forwarded to ``lora_manager_cls``.
        lora_config: Forwarded to ``lora_manager_cls``.
        device: Forwarded to ``lora_manager_cls``.
        lora_manager_cls: Manager class to instantiate; defaults to
            ``LoRAModelManager``.
        **kwargs: Extra keyword arguments forwarded to ``lora_manager_cls``.

    Returns:
        The newly constructed manager instance.

    Raises:
        ValueError: If ``model`` is not an instance of ``SupportsLoRA``.
    """
    if not isinstance(model, SupportsLoRA):
        raise ValueError(f"Model {type(model)} is not supported for LoRA.")
    return lora_manager_cls(
        model=model,
        max_num_seqs=max_num_seqs,
        max_num_batched_tokens=max_num_batched_tokens,
        vocab_size=vocab_size,
        lora_config=lora_config,
        device=device,
        **kwargs,
    )