gate_linear.py 1.01 KB
Newer Older
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project

import torch.nn as nn
from transformers import PretrainedConfig

from vllm.config.lora import LoRAConfig
from vllm.model_executor.custom_op import maybe_get_oot_by_class
from vllm.model_executor.layers.fused_moe.router.gate_linear import GateLinear

from .replicated_linear import ReplicatedLinearWithLoRA


class GateLinearWithLoRA(ReplicatedLinearWithLoRA):
    """LoRA wrapper for a ``GateLinear`` base layer.

    Everything except layer selection is inherited from
    ``ReplicatedLinearWithLoRA``; this subclass only defines which source
    layers it claims through ``can_replace_layer``.
    """

    def __init__(self, base_layer: GateLinear) -> None:
        """Wrap ``base_layer`` by delegating to the parent constructor."""
        super().__init__(base_layer)

    @classmethod
    def can_replace_layer(
        cls,
        source_layer: nn.Module,
        lora_config: LoRAConfig,
        packed_modules_list: list,
        model_config: PretrainedConfig | None = None,
    ) -> bool:
        """Report whether ``source_layer`` should be wrapped by this class.

        Only the exact type of ``source_layer`` is examined; ``lora_config``,
        ``packed_modules_list`` and ``model_config`` are accepted but not
        consulted. In particular the result does not depend on the fully
        sharded LoRAs setting, because a gate layer is, by definition, copied
        per GPU and should therefore always be replaced.

        Returns:
            ``True`` when ``type(source_layer)`` is the class returned by
            ``maybe_get_oot_by_class(GateLinear)``, ``False`` otherwise.
        """
        layer_type = type(source_layer)
        # NOTE(review): presumably resolves to an out-of-tree override of
        # GateLinear when one is registered - confirm against custom_op.
        expected_type = maybe_get_oot_by_class(GateLinear)
        # Identity comparison on the type: subclasses are not matched.
        return layer_type is expected_type