model_manager.py 36 KB
Newer Older
1
# SPDX-License-Identifier: Apache-2.0
2
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
3

4
import math
5
6
from collections.abc import Callable
from typing import TypeVar
7
8
9
10

import torch
from torch import nn

11
from vllm.config import VllmConfig
12
from vllm.config.lora import LoRAConfig
13
from vllm.logger import init_logger
14
15
16
17
18
19
from vllm.lora.layers import (
    BaseLayerWithLoRA,
    FusedMoE3DWithLoRA,
    LoRAMapping,
    LoRAMappingType,
)
20
from vllm.lora.lora_model import LoRAModel
21
from vllm.lora.lora_weights import LoRALayerWeights, PackedLoRALayerWeights
22
from vllm.lora.punica_wrapper import PunicaWrapperBase, get_punica_wrapper
23
24
25
26
from vllm.lora.utils import (
    from_layer,
    from_layer_logits_processor,
    get_supported_lora_modules,
27
    is_in_target_modules,
28
    is_moe_model,
29
    is_supported_lora_module,
30
    process_packed_modules_mapping,
31
32
    replace_submodule,
)
33
from vllm.model_executor.layers.fused_moe import FusedMoE
34
35
36
37
38
from vllm.model_executor.models import (
    SupportsLoRA,
    is_pooling_model,
    supports_multimodal,
)
39
from vllm.model_executor.models.module_mapping import MultiModelKeys
40
from vllm.model_executor.models.utils import PPMissingLayer
41
from vllm.multimodal import MULTIMODAL_REGISTRY
42
from vllm.multimodal.encoder_budget import MultiModalBudget
43
from vllm.utils.cache import LRUCache
44
from vllm.utils.platform_utils import is_pin_memory_available
45

46
logger = init_logger(__name__)

# Generic value type stored in an ``AdapterLRUCache``.
T = TypeVar("T")
# Key under which the single language-model punica wrapper is stored in
# ``LoRAModelManager.punica_wrapper_mapping`` for non-multimodal models.
DEFAULT_LANGUAGE_WRAPPER_KEY = "language_model"


class AdapterLRUCache(LRUCache[int, T]):
    """LRU cache of adapters keyed by integer adapter id.

    Overrides the ``_on_remove`` hook so that ``deactivate_fn`` is called with
    the adapter id before the base-class hook runs. (Presumably ``LRUCache``
    invokes ``_on_remove`` on eviction/removal -- TODO confirm in
    ``vllm.utils.cache``.)
    """

    def __init__(self, capacity: int, deactivate_fn: Callable[[int], object]):
        super().__init__(capacity)
        # Called with the adapter id of every entry passed to ``_on_remove``.
        self.deactivate_fn = deactivate_fn

    def _on_remove(self, key: int, value: T | None):
        logger.debug("Removing adapter int id: %d", key)
        self.deactivate_fn(key)
        return super()._on_remove(key, value)


class LoRAModelManager:
64
65
66
67
    """A manager that manages multiple LoRA-fine-tuned models."""

    def __init__(
        self,
        model: SupportsLoRA,
        max_num_seqs: int,
        max_num_batched_tokens: int,
        vocab_size: int,
        lora_config: LoRAConfig,
        device: torch.device,
        vllm_config: VllmConfig | None = None,
    ):
        """Create a LoRAModelManager and adapter for a given model.

        Args:
            model: the model to be adapted.
            max_num_seqs: the maximum number of sequences model can run in a
                single batch.
            max_num_batched_tokens: the maximum number of tokens model can run
                in a single batch.
            vocab_size: the vocab size of the model.
            lora_config: the LoRA configuration.
            device: device passed to every punica wrapper that is created.
            vllm_config: full engine config; only used for multimodal models,
                where it sizes the tower/connector punica wrappers.
        """
        self.model: SupportsLoRA = model
        self.supported_lora_modules = get_supported_lora_modules(self.model)
        assert self.supported_lora_modules, (
            f"No supported LoRA modules found in {self.model.__class__.__name__}."
        )

        self._registered_adapters: dict[int, LoRAModel] = {}
        # Dict instead of a set for compatibility with LRUCache.
        self._active_adapters: dict[int, None] = {}
        self.adapter_type = "LoRA"
        self.lora_config = lora_config
        self.device = device
        self.max_num_seqs = max_num_seqs
        # Every active slot must also fit in the CPU-side cache.
        assert self.capacity >= self.lora_slots
        # Rounded up to a multiple of 8.
        self.max_num_batched_tokens = math.ceil(max_num_batched_tokens / 8) * 8
        # Slot index -> id of the LoRA occupying that slot (None = free).
        self.lora_index_to_id: list[int | None] = [None] * self.lora_slots
        self.vocab_size = vocab_size
        self.packed_modules_mapping = process_packed_modules_mapping(self.model)

        self.is_pooling_model = is_pooling_model(self.model)
        # Full name of a packed module -> full names of its sub-modules.
        self.packed_modules: dict[str, list[str]] = {}
        # Full module name -> LoRA layer registered for it.
        self.modules: dict[str, BaseLayerWithLoRA] = {}
        self._last_mapping: LoRAMapping | None = None

        is_moe = is_moe_model(self.model)
        self._is_3d_moe_model = is_moe and self.model.is_3d_moe_weight
        self._is_non_gated_moe = is_moe and self.model.is_non_gated_moe
        # NOTE: the wrappers receive the un-rounded token budget.
        self._init_punica_wrapper(max_num_batched_tokens, vllm_config)
        self._create_lora_modules()

        self.model.lora_manager = self

    def _init_punica_wrapper(
        self, max_num_batched_tokens: int, vllm_config: VllmConfig | None
    ) -> None:
        """Create the punica wrapper(s) and fill ``punica_wrapper_mapping``.

        Multimodal models that expose ``get_mm_mapping`` delegate to
        ``_maybe_init_mm``. Every other model gets a single language-model
        wrapper stored under ``DEFAULT_LANGUAGE_WRAPPER_KEY``.
        """
        # Used to indicate whether the model is a multimodal model
        self.supports_mm: bool = (
            supports_multimodal(self.model)
            # In case the model only supports LoRA for
            # text modules (e.g. ChatGLM)
            and hasattr(self.model, "get_mm_mapping")
        )
        # Only the multimodal path can enable this. Defining it here ensures
        # the attribute exists for text-only models as well.
        self.supports_tower_connector_lora = False
        self.punica_wrapper_mapping: dict[str, PunicaWrapperBase] = {}
        if self.supports_mm:
            self._maybe_init_mm(vllm_config, max_num_batched_tokens)
            return

        llm_punica_wrapper = get_punica_wrapper(
            max_num_batched_tokens,
            max_batches=self.max_num_seqs,
            device=self.device,
            lora_config=self.lora_config,
        )
        self.punica_wrapper_mapping[DEFAULT_LANGUAGE_WRAPPER_KEY] = (
            llm_punica_wrapper
        )

    def _maybe_init_mm(
        self,
        vllm_config: VllmConfig,
        max_num_batched_tokens: int,
    ) -> None:
        """Create punica wrappers for a multimodal model.

        The language model always gets its own wrapper. Dedicated wrappers are
        added when both of these hold:

        - ``lora_config.enable_tower_connector_lora`` is set.
        - The model implements ``get_num_mm_encoder_tokens``.

        In that case the tower modules get a wrapper sized from the
        multimodal encoder budget. The connector also gets one if the model
        implements ``get_num_mm_connector_tokens``.
        """
        mm_registry = MULTIMODAL_REGISTRY

        self.supports_tower_connector_lora = False
        self.mm_mapping: MultiModelKeys = self.model.get_mm_mapping()

        # Only one language model can be included in the model.
        assert len(self.mm_mapping.language_model) == 1

        # Language model punica wrapper
        llm_punica_wrapper = get_punica_wrapper(
            max_num_batched_tokens,
            max_batches=self.max_num_seqs,
            device=self.device,
            lora_config=self.lora_config,
        )
        lm_prefix = self.mm_mapping.language_model[0]
        self.punica_wrapper_mapping[lm_prefix] = llm_punica_wrapper

        if self.lora_config.enable_tower_connector_lora:
            self.supports_tower_connector_lora = self.supports_mm and hasattr(
                self.model, "get_num_mm_encoder_tokens"
            )
        if not self.supports_tower_connector_lora:
            return

        logger.warning(
            "LoRA for the tower and connector of multimodal models is "
            "experimental and may contain bugs. Please report any related issues on "
            "GitHub if you encounter them."
        )

        mm_budget = MultiModalBudget(vllm_config, mm_registry)
        limit_per_prompt = max(mm_budget.mm_max_items_per_prompt.values())
        num_encoder_tokens = self.model.get_num_mm_encoder_tokens(
            mm_budget.get_encoder_budget()
        )
        # Tower and connector wrappers are sized for up to
        # ``limit_per_prompt`` items per sequence.
        mm_max_batches = self.max_num_seqs * limit_per_prompt

        # Tower wrappers
        tower_punica_wrapper = get_punica_wrapper(
            num_encoder_tokens,
            max_batches=mm_max_batches,
            device=self.device,
            lora_config=self.lora_config,
        )
        for prefix in self.mm_mapping.tower_model:
            self.punica_wrapper_mapping[prefix] = tower_punica_wrapper

        # Use wrapper for connector if present.
        if not self.mm_mapping.connector:
            return
        if not hasattr(self.model, "get_num_mm_connector_tokens"):
            logger.warning_once(
                "Connector LoRA support disabled: model does not implement "
                "get_num_mm_connector_tokens(). This method is required to "
                "determine the connector's token budget for LoRA operations."
            )
            return

        connector_tokens = self.model.get_num_mm_connector_tokens(num_encoder_tokens)
        connector_punica_wrapper = get_punica_wrapper(
            connector_tokens,
            max_batches=mm_max_batches,
            device=self.device,
            lora_config=self.lora_config,
        )
        for prefix in self.mm_mapping.connector:
            self.punica_wrapper_mapping[prefix] = connector_punica_wrapper

    def __len__(self) -> int:
        """Return the number of registered adapters (active or not)."""
        return len(self._registered_adapters)

    @property
    def capacity(self) -> int:
        """Adapter capacity, read from ``lora_config.max_cpu_loras``."""
        config = self.lora_config
        return config.max_cpu_loras

    @property
    def lora_slots(self) -> int:
        """Number of concurrently active LoRA slots (``lora_config.max_loras``)."""
        return self.lora_config.max_loras

    @property
    def adapter_slots(self) -> int:
        """Alias of ``lora_slots``."""
        return self.lora_slots

    def activate_adapter(
        self,
        lora_id: int,
    ) -> bool:
        """Move LoRA into a GPU buffer to be used in the forward pass.

        The LoRA is placed in the first free slot. Every registered LoRA
        module either receives this LoRA's weights in that slot or has the
        slot reset when the LoRA has no weights for it.

        Args:
            lora_id: id of a previously registered LoRA.

        Returns:
            False if the LoRA was already active, otherwise True.

        Raises:
            ValueError: if every LoRA slot is occupied.
            KeyError: if ``lora_id`` has not been registered.
        """
        if lora_id in self._active_adapters:
            return False
        index = next(
            (
                slot
                for slot, occupant in enumerate(self.lora_index_to_id)
                if occupant is None
            ),
            None,
        )
        if index is None:
            raise ValueError("No free lora slots")
        # Resolve the adapter before mutating any state so that an unknown id
        # does not leave a stale entry in ``_active_adapters``.
        lora_model = self._registered_adapters[lora_id]
        self._active_adapters[lora_id] = None
        logger.debug(
            "Activating LoRA. int id: %d, slot index: %d", lora_model.id, index
        )
        self.lora_index_to_id[index] = lora_model.id
        for module_name, module in self.modules.items():
            module_lora = self._get_lora_layer_weights(lora_model, module_name)
            if not module_lora:
                module.reset_lora(index)
                continue
            module.set_lora(
                index,
                module_lora.lora_a,
                module_lora.lora_b,
            )
        return True

    def _deactivate_adapter(self, lora_id: int):
        """Free the slot held by ``lora_id``; a no-op if it holds none."""
        try:
            index = self.lora_index_to_id.index(lora_id)
        except ValueError:
            # Best effort: the adapter does not occupy any slot.
            return
        self.lora_index_to_id[index] = None

    def _add_adapter(self, lora: LoRAModel):
        """Merge packed-module weights of ``lora`` in place, then register it."""
        self._create_merged_loras_inplace(lora)
        self._registered_adapters[lora.id] = lora

    def pin_adapter(self, lora_id: int) -> bool:
        """Pin a LoRAModel in the manager cache.

        This base manager does not support pinning.

        Raises:
            NotImplementedError: always.
        """
        raise NotImplementedError(
            "Pinning is not supported in LoRAModelManager. "
            "Use LRUCacheLoRAModelManager for pinning"
        )

    def _set_adapter_mapping(self, mapping: LoRAMapping) -> None:
        """Push ``mapping`` into the punica wrapper it targets.

        Some models always use the language-model wrapper:

        - Text-only models.
        - Multimodal models without tower/connector LoRA.

        For the remaining models, ``mapping.type`` selects the tower or
        connector wrapper. It falls back to the language model when that
        component has no prefixes.
        """
        if not (self.supports_mm and self.supports_tower_connector_lora):
            # Default to the main language model wrapper
            target_prefix = (
                self.mm_mapping.language_model[0]
                if self.supports_mm
                else DEFAULT_LANGUAGE_WRAPPER_KEY
            )
        elif mapping.type == LoRAMappingType.TOWER and self.mm_mapping.tower_model:
            target_prefix = self.mm_mapping.tower_model[0]
        elif mapping.type == LoRAMappingType.CONNECTOR and self.mm_mapping.connector:
            target_prefix = self.mm_mapping.connector[0]
        else:
            target_prefix = self.mm_mapping.language_model[0]

        punica_wrapper = self._get_punica_wrapper(target_prefix)
        assert punica_wrapper is not None

        punica_wrapper.update_metadata(
            mapping,
            self.lora_index_to_id,
            self.lora_slots + 1,
            self.vocab_size,
        )

    def remove_all_adapters(self):
        """Remove all LoRAModels from the manager and free every slot."""
        self._registered_adapters.clear()
        self.lora_index_to_id = [None] * self.lora_slots
        self._active_adapters.clear()

    def _create_lora_modules(self):
        """Swap every LoRA-targeted submodule of ``self.model`` for a LoRA layer.

        Each targeted module that resolves to a punica wrapper is handled in
        three steps:

        - It is replaced via ``from_layer``.
        - For ``lm_head`` modules, the sibling ``logits_processor`` is also
          wrapped via ``from_layer_logits_processor``.
        - The resulting layer is registered in ``self.modules`` and bound to
          its punica wrapper.
        """
        for module_name, module in self.model.named_modules(remove_duplicate=False):
            if isinstance(module, PPMissingLayer):
                continue

            if not self._match_target_modules(module_name):
                continue

            punica_wrapper = self._get_punica_wrapper(module_name)
            if punica_wrapper is None:
                logger.warning(
                    "Regarding %s, vLLM currently only supports adding LoRA to"
                    " language model, %s will be ignored.",
                    self.model.__class__.__name__,
                    module_name,
                )
                continue

            # TODO: Remove this restriction
            # peft error when generating LoRA adapter with "gate" module:
            # "Target module NemotronHTopkRouter() is not supported."
            # Working LoRA adapter was created using peft with:
            # LoraConfig(target_modules="all-linear", ...)
            if self._is_non_gated_moe and module_name.endswith("mixer.gate"):
                logger.debug_once(
                    "LoRA is not supported for non-gated MoE gate module."
                    " %s will be ignored.",
                    module_name,
                    scope="local",
                )
                continue

            # 'x.y.z' -> ('x.y', 'z'); 'x' -> ('', 'x')
            parent_name, _, leaf_name = module_name.rpartition(".")
            packed_modules_lst = self.packed_modules_mapping.get(leaf_name, [])
            if isinstance(module, FusedMoE):
                # packed_modules_lst is used here to just determine whether to
                # instantiate FusedMoE3DWithLoRA or FusedMoEWithLoRA, and the
                # difference between these two LoRA layers is whether the
                # LoRA weights of w1 and w3 have already been fused on disk.
                packed_modules_lst = ["w13"] if self._is_3d_moe_model else ["w1", "w3"]
            new_module = replace_submodule(
                self.model,
                module_name,
                from_layer(
                    module,
                    self.lora_slots,
                    self.lora_config,
                    packed_modules_lst,
                    self.model.config,
                ),
            )

            # (yard1): TODO make this more robust
            if "lm_head" in module_name:
                logits_processor_module_name = "logits_processor"
                if parent_name:
                    logits_processor_module_name = (
                        f"{parent_name}.{logits_processor_module_name}"
                    )

                logits_processor_module = self.model.get_submodule(
                    logits_processor_module_name
                )

                new_module = replace_submodule(
                    self.model,
                    logits_processor_module_name,
                    from_layer_logits_processor(
                        logits_processor_module,
                        module,
                        self.lora_slots,
                        self.lora_config,
                        self.model.config,
                    ),
                )

            # In some models, especially multimodal ones, layers with the same
            # name may have different types, such as nn.Linear and
            # ReplicatedLinear. The nn.Linear layers cannot be replaced with
            # LoRA layers, leading to assertion error. The following check
            # aims to prevent this error
            if self.supports_mm and not isinstance(new_module, BaseLayerWithLoRA):
                continue
            self.register_module(module_name, new_module)
            self._register_packed_modules(module_name)
            # Layers resolving to the same prefix share one wrapper object.
            new_module.set_mapping(punica_wrapper)

    def register_module(self, module_name: str, module: "BaseLayerWithLoRA"):
        """Record ``module`` as the LoRA layer for ``module_name``.

        Raises:
            AssertionError: if ``module`` is not a ``BaseLayerWithLoRA``.
        """
        assert isinstance(module, BaseLayerWithLoRA), (
            f"Module {module_name} must be a BaseLayerWithLoRA instance, "
            f"got {type(module)}"
        )
        self.modules[module_name] = module

    @staticmethod
    def _pad_lora_pairs_to_triplets(
        loras: list[LoRALayerWeights | None],
    ) -> list[LoRALayerWeights | None]:
        """Pad LoRA weight pairs to triplets for non-gated MoE.

        For non-gated MoE, each expert has 2 entries (w1, w2) that need to be
        padded to triplets (w1, w2, None) to match pack_moe expectations,
        e.g. ``[a1, b1, a2, b2]`` -> ``[a1, b1, None, a2, b2, None]``.

        Raises:
            AssertionError: if ``loras`` has an odd number of entries.
        """
        assert len(loras) % 2 == 0, "Expected pairs of LoRA weights for non-gated MoE."
        padded: list[LoRALayerWeights | None] = []
        for first, second in zip(loras[::2], loras[1::2]):
            padded.extend((first, second, None))
        return padded

    def create_dummy_lora(
        self,
        lora_id: int,
        rank: int,
        embedding_modules: dict[str, str] | None = None,
    ) -> LoRAModel:
        """Create zero-initialized LoRAModel for warmup.

        Dummy weights (on ``"cpu"``) are created for every LoRA layer of the
        model that matches the target modules and has a punica wrapper.
        Shapes are taken from the layer's stacked LoRA buffers.

        Args:
            lora_id: id of the dummy LoRA.
            rank: LoRA rank of the dummy weights.
            embedding_modules: leaf names of embedding-like modules. Must be
                given whenever a non-packed module is encountered.

        Returns:
            The dummy ``LoRAModel``. It is not registered with this manager.
        """
        dummy_model = LoRAModel(lora_id, rank, {})
        for module_name, module in self.model.named_modules():
            if (
                not self._match_target_modules(module_name)
                or not isinstance(module, BaseLayerWithLoRA)
                or self._get_punica_wrapper(module_name) is None
            ):
                continue
            leaf_name = module_name.rpartition(".")[2]

            if module_name in self.packed_modules:
                # Packed layer: one dummy per sub-module, packed together.
                replacements = self.packed_modules_mapping[leaf_name]
                subloras: list[LoRALayerWeights | None] = [
                    LoRALayerWeights.create_dummy_lora_weights(
                        module_name + "." + r,
                        module.lora_a_stacked[i].shape[-1],
                        module.lora_b_stacked[i].shape[-2],
                        rank,
                        module.lora_a_stacked[i].dtype,
                        "cpu",
                    )
                    for i, r in enumerate(replacements)
                ]
                if module.__class__.__name__ == "FusedMoEWithLoRA":
                    # For non-gated MoE, pad subloras to 3 elements per expert
                    # to match pack_moe expectations (w1, w2, None for w3)
                    if self._is_non_gated_moe and subloras:
                        subloras = self._pad_lora_pairs_to_triplets(subloras)
                    lora = PackedLoRALayerWeights.pack_moe(
                        subloras, module_name, is_non_gated_moe=self._is_non_gated_moe
                    )
                else:
                    lora = PackedLoRALayerWeights.pack(subloras)
                dummy_model.loras[module_name] = lora
                continue

            assert embedding_modules is not None
            if leaf_name in embedding_modules:
                # Special-case lm_head: wrapped by LogitsProcessorWithLoRA.
                # LoRA input dim is hidden_size, output dim is vocab size.
                # LogitsProcessorWithLoRA handles extra vocab size directly.
                if leaf_name == "lm_head":
                    input_dim = module.lora_a_stacked[0].shape[-1]
                    output_dim = module.lora_b_stacked[0].shape[-2]
                else:
                    input_dim = (
                        module.base_layer.org_vocab_size
                        if hasattr(module.base_layer, "org_vocab_size")
                        else module.base_layer.weight.shape[1]
                    )
                    output_dim = (
                        module.base_layer.embedding_dim
                        if hasattr(module.base_layer, "embedding_dim")
                        else module.base_layer.weight.shape[0]
                    )
                dummy_model.loras[module_name] = (
                    LoRALayerWeights.create_dummy_lora_weights(
                        module_name,
                        input_dim,
                        output_dim,
                        rank,
                        module.lora_a_stacked[0].dtype,
                        "cpu",
                    )
                )
            elif isinstance(module, FusedMoE3DWithLoRA):
                # Case for 3D moe model
                # w2
                dummy_model.loras[module_name] = (
                    LoRALayerWeights.create_dummy_lora_weights(
                        module_name,
                        module.w2_input_size,
                        module.w2_output_size,
                        rank * module.w2_lora_a_stacked[0].shape[1],  # rank*num_experts
                        module.w2_lora_a_stacked[0].dtype,
                        "cpu",
                    )
                )
                # w13
                dummy_model.loras[module_name + ".base_layer"] = (
                    LoRALayerWeights.create_dummy_lora_weights(
                        module_name,
                        module.w13_input_size,
                        module.w13_output_size,
                        rank
                        * module.w13_lora_a_stacked[0].shape[1],  # rank*num_experts
                        module.w13_lora_a_stacked[0].dtype,
                        "cpu",
                    )
                )
            else:
                dummy_model.loras[module_name] = (
                    LoRALayerWeights.create_dummy_lora_weights(
                        module_name,
                        module.lora_a_stacked[0].shape[-1],
                        module.lora_b_stacked[0].shape[-2],
                        rank,
                        module.lora_a_stacked[0].dtype,
                        "cpu",
                    )
                )
        return dummy_model

    def _match_target_modules(self, module_name: str) -> bool:
        """Check if a module should have LoRA applied.

        This method first checks if the module is in vLLM's supported LoRA
        modules, then applies deployment-time restrictions based on
        LoRAConfig.target_modules.

        Args:
            module_name: Full dot-separated module name (e.g.,
                "model.layers.0.self_attn.o_proj")

        Returns:
            True if LoRA should be applied to this module, False otherwise.
        """
        if not is_supported_lora_module(module_name, self.supported_lora_modules):
            return False
        return is_in_target_modules(module_name, self.lora_config.target_modules)

    def _get_punica_wrapper(self, module_name: str) -> PunicaWrapperBase | None:
        """Determine whether this module supports LoRA and which wrapper to use.

        Returns:
            - For text-only models: the single language-model wrapper.
            - For multimodal models: the wrapper of the longest registered
              prefix of ``module_name``, or None if no prefix matches.
        """
        # For language model (early return)
        if not self.supports_mm:
            return self.punica_wrapper_mapping[DEFAULT_LANGUAGE_WRAPPER_KEY]

        # For multimodal model
        # NOTE Sort by prefix length (descending) to match the longest prefix first
        # e.g., 'visual.merger' should match 'visual.merger' instead of 'visual.'
        for prefix in sorted(self.punica_wrapper_mapping, key=len, reverse=True):
            if module_name.startswith(prefix):
                return self.punica_wrapper_mapping[prefix]

        return None

    def _register_packed_modules(self, module_full_name: str) -> None:
        """Record the full sub-module names of a packed module.

        The leaf name of ``module_full_name`` is looked up in
        ``packed_modules_mapping``. Each replacement is qualified with the
        module's parent path and stored in ``packed_modules``.
        """
        prefix, _, module_name = module_full_name.rpartition(".")
        replacements = self.packed_modules_mapping.get(module_name, [])
        # When replacements is less than or equal to 1, it indicates that this
        # module is not a packed module.
        if len(replacements) <= 1:
            return
        self.packed_modules[module_full_name] = [
            f"{prefix}.{r}" if prefix else r for r in replacements
        ]

    def _create_merged_loras_inplace(self, lora_model: LoRAModel) -> None:
        """Pack sub-module LoRAs into packed modules and prepare the weights.

        Steps, all applied to ``lora_model`` in place:

        1. For each packed module that has at least one sub-module LoRA, pack
           the sub-module LoRAs under the packed module's name (``pack_moe``
           for ``.experts`` modules, ``pack`` otherwise) and drop the
           individual entries.
        2. Optimize every LoRA.
        3. Stack 3D MoE weights.
        4. Pin CPU weights when pinned memory is available.
        """
        for module_name, new_module_names in self.packed_modules.items():
            replacement_loras: list[LoRALayerWeights | None] = []
            replaced_module: set[str] = set()
            for r in new_module_names:
                lora = self._get_lora_layer_weights(lora_model, r)
                # Normalise missing/falsy entries to None for the packers.
                replacement_loras.append(lora if lora else None)
                if lora:
                    replaced_module.add(r)
            if not replaced_module:
                continue
            # HACK Temporary solution for the pool model.
            if self.is_pooling_model and not lora_model.check_lora_name(module_name):
                replaced_module_name = module_name.removeprefix("model.")
                if lora_model.check_lora_name(replaced_module_name):
                    module_name = replaced_module_name
            if module_name.endswith(".experts"):
                if self._is_non_gated_moe and replacement_loras:
                    replacement_loras = self._pad_lora_pairs_to_triplets(
                        replacement_loras
                    )
                lora_model.loras[module_name] = PackedLoRALayerWeights.pack_moe(
                    replacement_loras,
                    module_name,
                    is_non_gated_moe=self._is_non_gated_moe,
                )
            else:
                lora_model.loras[module_name] = PackedLoRALayerWeights.pack(
                    replacement_loras
                )
            # Remove the modules that have been replaced.
            for replaced_name in replaced_module:
                lora_model.loras.pop(replaced_name, None)

        for lora in lora_model.loras.values():
            lora.optimize()

        for module_name, module in self.modules.items():
            if isinstance(module, FusedMoE3DWithLoRA):
                self._stack_moe_lora_weights(lora_model, module, module_name)

        if not lora_model.loras:
            # Nothing to pin (and ``next()`` below would raise StopIteration).
            return

        first_lora: LoRALayerWeights = next(iter(lora_model.loras.values()))
        assert first_lora.lora_a is not None
        if isinstance(first_lora.lora_a, list):
            # Take the device of the first real tensor; list entries may be
            # None. (Previously the tensor itself was used as the "device",
            # so the CPU check below could never succeed.)
            lora_device = next(
                (t.device for t in first_lora.lora_a if t is not None), None
            )
        else:
            lora_device = first_lora.lora_a.device
        # Execute pin_memory after LoRA weight merging, mainly because:
        # 1. Some MoE models have a large number of LoRA weights. If we
        # perform pin_memory immediately after loading weights, the
        # overhead is significant.
        # 2. The weight packing above (e.g., pack_moe) may invalidate the
        # pin_memory allocation, so we execute it after packing.
        pin_memory = (
            lora_device is not None
            and str(lora_device) == "cpu"
            and is_pin_memory_available()
        )
        if not pin_memory:
            return
        for lora in lora_model.loras.values():
            if isinstance(lora.lora_a, list):
                for index, lora_a in enumerate(lora.lora_a):
                    if lora_a is None:
                        continue
                    lora.lora_a[index] = lora_a.pin_memory()
                    lora.lora_b[index] = lora.lora_b[index].pin_memory()
            else:
                lora.lora_a = lora.lora_a.pin_memory()
                lora.lora_b = lora.lora_b.pin_memory()

    def _stack_moe_lora_weights(
        self, lora_model: LoRAModel, module: FusedMoE3DWithLoRA, module_name: str
    ):
        """Split un-split MoE LoRA weights into per-expert form, in place.

        Acts only when the LoRA found for ``module_name`` still holds a single
        tensor in ``lora_a``. Its ``lora_a``/``lora_b`` are then replaced by
        lists built from that tensor and the ``.base_layer`` LoRA:

        - 3D MoE models: ``[gate_up, down]``.
        - Other models: per expert ``(gate, down, up)``.
        """
        module_lora = self._get_lora_layer_weights(lora_model, module_name)

        # Note (gnovack) - If MOE lora weights are not split into
        # num_experts chunks, we split them here
        if not module_lora or not torch.is_tensor(module_lora.lora_a):
            return

        # Handle PEFT file format where experts.base_layer is the
        # gate_up_proj and experts is the down_proj
        gate_up_proj_lora = self._get_lora_layer_weights(
            lora_model, module_name + ".base_layer"
        )
        down_proj_lora = module_lora
        # FIXME Edge case where LoRA is not added to gate_up_proj
        # or down_proj
        assert gate_up_proj_lora is not None
        assert down_proj_lora is not None

        if self._is_3d_moe_model:
            num_experts = module.w13_lora_a_stacked[0].shape[1]

            # (num_experts,rank,input_size)
            gate_up_proj_lora.lora_a = gate_up_proj_lora.lora_a.reshape(
                num_experts, -1, gate_up_proj_lora.lora_a.shape[-1]
            )
            down_proj_lora.lora_a = down_proj_lora.lora_a.reshape(
                num_experts, -1, down_proj_lora.lora_a.shape[-1]
            )

            # (output_size,rank,num_experts)
            gate_up_proj_lora.lora_b = gate_up_proj_lora.lora_b.reshape(
                gate_up_proj_lora.lora_b.shape[0], -1, num_experts
            )
            down_proj_lora.lora_b = down_proj_lora.lora_b.reshape(
                down_proj_lora.lora_b.shape[0], -1, num_experts
            )

            # (num_experts,output_size,rank)
            gate_up_proj_lora.lora_b = gate_up_proj_lora.lora_b.permute(
                2, 0, 1
            ).contiguous()
            down_proj_lora.lora_b = down_proj_lora.lora_b.permute(
                2, 0, 1
            ).contiguous()

            # module_lora is down_proj_lora; both lists are built before
            # assignment, so they capture the reshaped tensors.
            module_lora.lora_a = [
                gate_up_proj_lora.lora_a,
                down_proj_lora.lora_a,
            ]
            module_lora.lora_b = [
                gate_up_proj_lora.lora_b,
                down_proj_lora.lora_b,
            ]
            return

        # Some 3D MoE models haven't added the `is_3d_moe_weight`
        # attribute yet, so fallback here
        num_experts = module_lora.lora_a.shape[0] // module_lora.rank

        # Gate and up both take chunks of the same fused lora_a.
        gate_proj_a = gate_up_proj_lora.lora_a.chunk(num_experts, dim=0)
        up_proj_a = gate_up_proj_lora.lora_a.chunk(num_experts, dim=0)

        # lora_b rows are interleaved: even rows -> gate, odd rows -> up.
        gate_proj_b = gate_up_proj_lora.lora_b[::2, ...].chunk(num_experts, dim=-1)
        up_proj_b = gate_up_proj_lora.lora_b[1::2, ...].chunk(num_experts, dim=-1)

        down_proj_a = down_proj_lora.lora_a.chunk(num_experts, dim=0)
        down_proj_b = down_proj_lora.lora_b.chunk(num_experts, dim=-1)

        # Per expert the order is (gate, down, up).
        lora_a = []
        lora_b = []
        for i in range(num_experts):
            lora_a.extend((gate_proj_a[i], down_proj_a[i], up_proj_a[i]))
            lora_b.extend((gate_proj_b[i], down_proj_b[i], up_proj_b[i]))

        module_lora.lora_a = lora_a
        module_lora.lora_b = lora_b

    def _get_lora_layer_weights(
752
        self, lora_model: LoRAModel, module_name: str
753
    ) -> LoRALayerWeights | None:
754
        org_module_name = module_name
755
        if self.is_pooling_model and not lora_model.check_lora_name(module_name):
756
757
            # If it's a pool model, and the layer name is not found,
            # remove the prefix 'model.' and search again.
758
            module_name = module_name.removeprefix("model.")
759
760
761
762
            if lora_model.check_lora_name(module_name):
                org_module_name = module_name
                logger.info_once(
                    "For the pool model, successfully loaded the LoRA weights "
763
764
                    "after removing the prefix 'model.'."
                )
765
766
        return lora_model.get_lora(org_module_name)

767
    def deactivate_adapter(self, adapter_id: int) -> bool:
768
769
770
771
772
        if adapter_id not in self._active_adapters:
            return False
        self._deactivate_adapter(adapter_id)
        self._active_adapters.pop(adapter_id, None)
        return True
773
774

    def add_adapter(self, adapter: LoRAModel) -> bool:
775
        logger.debug("Adding lora. Model id: %d, int id: %d", adapter.id, adapter.id)
776
777
778
779
780
781
        if adapter.id in self._registered_adapters:
            return False
        if len(self._registered_adapters) >= self.capacity:
            raise RuntimeError("No free adapter slots.")
        self._add_adapter(adapter)
        return True
782

783
    def set_adapter_mapping(self, mapping: LoRAMapping) -> None:
784
785
786
        if self._last_mapping != mapping:
            self._set_adapter_mapping(mapping)
            self._last_mapping = mapping
787
788

    def remove_adapter(self, adapter_id: int) -> bool:
789
790
791
792
793
        self.deactivate_adapter(adapter_id)
        if adapter_id not in self._registered_adapters:
            return False
        self._registered_adapters.pop(adapter_id, None)
        return True
794

795
796
    def list_adapters(self) -> dict[int, LoRAModel]:
        return dict(self._registered_adapters)
797

798
    def get_adapter(self, adapter_id: int) -> LoRAModel | None:
799
        return self._registered_adapters.get(adapter_id)
800
801
802


class LoRALRUCache(AdapterLRUCache[LoRAModel]):
803
    def __init__(self, capacity: int, deactivate_lora_fn: Callable[[int], bool]):
804
        super().__init__(capacity, deactivate_lora_fn)
805
806
807
808
809


class LRUCacheLoRAModelManager(LoRAModelManager):
    """A model manager that manages multiple LoRAs with LRU cache."""

810
811
812
813
814
815
816
817
    def __init__(
        self,
        model: nn.Module,
        max_num_seqs: int,
        max_num_batched_tokens: int,
        vocab_size: int,
        lora_config: LoRAConfig,
        device: torch.device,
818
        vllm_config: VllmConfig | None = None,
819
820
    ):
        super().__init__(
821
822
823
824
825
826
827
            model,
            max_num_seqs,
            max_num_batched_tokens,
            vocab_size,
            lora_config,
            device,
            vllm_config,
828
        )
829
        self._registered_adapters: LoRALRUCache = LoRALRUCache(
830
831
            self.capacity, self.deactivate_adapter
        )
832
        self._active_adapters: LoRALRUCache = LoRALRUCache(
833
834
            self.lora_slots, self._deactivate_adapter
        )
835

836
    def list_adapters(self) -> dict[int, LoRAModel]:
837
        """List all registered LoRAModels."""
838
        return dict(self._registered_adapters.cache)
839

840
    def add_adapter(self, lora: LoRAModel) -> bool:
841
        """Add a LoRAModel to the manager."""
842
        logger.debug("Adding lora. Model id: %d, int id: %d", lora.id, lora.id)
843
844
        if lora.id not in self._registered_adapters:
            self._add_adapter(lora)
845
846
847
            was_added = True
        else:
            # We always touch to update the LRU cache order
848
            self._registered_adapters.touch(lora.id)
849
850
851
            was_added = False
        return was_added

852
    def activate_adapter(
853
854
855
        self,
        lora_id: int,
    ) -> bool:
856
857
858
859
        if (
            lora_id not in self._active_adapters
            and len(self._active_adapters) >= self.lora_slots
        ):
860
861
            self._active_adapters.remove_oldest()
        result = super().activate_adapter(lora_id)
862
        # We always touch to update the LRU cache order
863
        self._active_adapters.touch(lora_id)
864
865
        return result

866
867
868
    def remove_oldest_adapter(self) -> bool:
        if len(self._registered_adapters) > 0:
            self._registered_adapters.remove_oldest()
869
870
871
            return True
        return False

872
    def pin_adapter(self, lora_id: int) -> bool:
873
874
875
876
877
878
879
        """Pin a LoRAModel in the manager cache."""
        self._pin_lora_in_cpu_cache(lora_id)
        self._pin_lora_in_gpu_cache(lora_id)
        return True

    def _pin_lora_in_cpu_cache(self, lora_id: int):
        try:
880
            self._registered_adapters.pin(lora_id)
881
        except ValueError as err:
882
883
884
            raise ValueError(
                f"Pinning failed. LoRA {lora_id} is not registered."
            ) from err
885
886

    def _pin_lora_in_gpu_cache(self, lora_id: int):
887
        if lora_id not in self._active_adapters:
888
            # move lora to gpu if not already active
889
            self.activate_adapter(lora_id)
890

891
        self._active_adapters.pin(lora_id)
892

893
894

def create_lora_manager(
895
896
897
898
899
    model: nn.Module,
    max_num_seqs: int,
    max_num_batched_tokens: int,
    vocab_size: int,
    lora_config: LoRAConfig,
900
    vllm_config: VllmConfig,
901
902
903
904
    device: torch.device,
    lora_manager_cls: type[LoRAModelManager] = LoRAModelManager,
    **kwargs,
) -> LoRAModelManager:
905
    """Create a LoRA adapter for a given model."""
906
    if not isinstance(model, SupportsLoRA):
907
908
909
910
911
912
913
        raise ValueError(f"Model {type(model)} is not supported for LoRA.")
    lora_manager = lora_manager_cls(
        model=model,
        max_num_seqs=max_num_seqs,
        max_num_batched_tokens=max_num_batched_tokens,
        vocab_size=vocab_size,
        lora_config=lora_config,
914
        vllm_config=vllm_config,
915
        device=device,
916
917
        **kwargs,
    )
918
    return lora_manager