utils.py 3.25 KB
Newer Older
Jee Jee Li's avatar
Jee Jee Li committed
1
2
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
3

Jee Jee Li's avatar
Jee Jee Li committed
4
from dataclasses import dataclass
5
from enum import Enum
Jee Jee Li's avatar
Jee Jee Li committed
6
7
8
9

import torch
import torch.nn as nn

10
11
12
from vllm.model_executor.layers.fused_moe.fused_moe import try_get_optimal_moe_config
from vllm.utils.math_utils import next_power_of_2

Jee Jee Li's avatar
Jee Jee Li committed
13

14
15
16
17
18
19
class LoRAMappingType(Enum):
    """Enumerates the component a LoRA mapping applies to.

    Member values are explicit integers (1, 2, 3); ``LANGUAGE`` is the
    default used by ``LoRAMapping``.
    """

    # NOTE(review): the names suggest language model / (vision) tower /
    # multimodal connector — confirm against the callers.
    LANGUAGE = 1
    TOWER = 2
    CONNECTOR = 3


Jee Jee Li's avatar
Jee Jee Li committed
20
@dataclass
class LoRAMapping:
    """Bundle of LoRA index/prompt mappings plus batch metadata.

    ``index_mapping`` and ``prompt_mapping`` are normalized to tuples on
    construction; ``is_prefill`` defaults to ``False`` and ``type`` defaults
    to ``LoRAMappingType.LANGUAGE``.
    """

    index_mapping: tuple[int, ...]
    prompt_mapping: tuple[int, ...]
    is_prefill: bool = False
    # The field is named ``type``; this shadows the builtin only inside the
    # class body, not inside methods.
    type: LoRAMappingType = LoRAMappingType.LANGUAGE

    def __post_init__(self):
        # At runtime the mappings may arrive as any iterable (e.g. lists);
        # coerce each one to a tuple, index_mapping first.
        for mapping_name in ("index_mapping", "prompt_mapping"):
            setattr(self, mapping_name, tuple(getattr(self, mapping_name)))

Jee Jee Li's avatar
Jee Jee Li committed
31
32
33
34
35
36
37
38
39
40
41
42
43

def _get_lora_device(base_layer: nn.Module) -> torch.device:
    """Return the device on which LoRA tensors for ``base_layer`` should live.

    The device is read from the first weight-like attribute present on the
    layer, probed in a fixed priority order: dense weight, compressed-tensors
    packed weight, GPTQ/AWQ ``qweight``, then the MoE ``w2_*`` counterparts.

    Raises:
        ValueError: if none of the probed attributes exist on ``base_layer``.
    """
    # Adapted from
    # https://github.com/fmmoret/vllm/blob/fm-support-lora-on-quantized-models/vllm/lora/layers.py#L34
    probe_order = (
        "weight",  # unquantized linear
        "weight_packed",  # compressed tensors
        "qweight",  # GPTQ / AWQ
        "w2_weight",  # MoE layer
        "w2_weight_packed",  # MoE compressed tensors
        "w2_qweight",  # MoE GPTQ / AWQ / GGUF
    )
    for attr_name in probe_order:
        if hasattr(base_layer, attr_name):
            return getattr(base_layer, attr_name).device
    raise ValueError(f"Unsupported base layer: {base_layer}")


def _not_fully_sharded_can_replace(can_replace):
    """Decorate ``can_replace_layer()`` to also require non-fully-sharded LoRAs.

    The wrapper returns ``can_replace(*args, **kwargs) and condition``, where
    ``condition`` is ``not lora_config.fully_sharded_loras``. ``lora_config``
    must be passed as a keyword argument while the condition is active.

    Callers may pass ``decorate=False`` to disable the extra condition (it then
    evaluates to ``True``). The ``decorate`` keyword is consumed by the wrapper
    and never forwarded to ``can_replace``.

    The wrapper preserves ``can_replace``'s name, docstring and ``__wrapped__``
    via ``functools.wraps``.
    """
    import functools

    @functools.wraps(can_replace)
    def dec(*args, **kwargs):
        decorate = kwargs.pop("decorate", True)
        # Compute the sharding condition before calling the original check.
        # Note the original check is always invoked, even when the condition
        # is False.
        condition = not kwargs["lora_config"].fully_sharded_loras if decorate else True
        return can_replace(*args, **kwargs) and condition

    return dec


def _fully_sharded_can_replace(can_replace):
    """Decorate ``can_replace_layer()`` to also require fully sharded LoRAs.

    The wrapper returns
    ``can_replace(*args, **kwargs) and lora_config.fully_sharded_loras``.
    ``lora_config`` must be passed as a keyword argument. It is only read
    when ``can_replace`` returns a truthy value, because ``and``
    short-circuits.

    The wrapper preserves ``can_replace``'s name, docstring and ``__wrapped__``
    via ``functools.wraps``.
    """
    import functools

    @functools.wraps(can_replace)
    def dec(*args, **kwargs):
        return (
            can_replace(*args, **kwargs) and kwargs["lora_config"].fully_sharded_loras
        )

    return dec
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112


def try_get_optimal_moe_lora_config(
    op_type: str,
    w1_shape: tuple[int, ...],
    w2_shape: tuple[int, ...],
    rank: int,
    top_k: int,
    dtype: str | None,
    M: int,
    block_shape: list[int] | None = None,
) -> dict[str, int | None]:
    """Return a fused-MoE kernel config adjusted for the LoRA op ``op_type``.

    The result starts as a ``.copy()`` of
    ``try_get_optimal_moe_config(w1_shape, w2_shape, top_k, dtype, M,
    block_shape)``, so it is never the base config object itself. It is then
    adjusted as follows:

    * shrink ops (``fused_moe_lora_w13_shrink`` / ``fused_moe_lora_w2_shrink``):
      ``BLOCK_SIZE_N`` is capped at ``next_power_of_2(rank)``. The base value
      defaults to 64 when absent.
    * expand ops (``fused_moe_lora_w13_expand`` / ``fused_moe_lora_w2_expand``):
      ``BLOCK_SIZE_K`` is capped at ``next_power_of_2(rank)`` but never set
      below 16. The base value defaults to 32 when absent.

    Any other ``op_type`` gets the unmodified copy.
    """
    shrink_ops = ("fused_moe_lora_w13_shrink", "fused_moe_lora_w2_shrink")
    expand_ops = ("fused_moe_lora_w13_expand", "fused_moe_lora_w2_expand")

    tuned = try_get_optimal_moe_config(
        w1_shape, w2_shape, top_k, dtype, M, block_shape
    ).copy()

    if op_type in shrink_ops:
        current_n = tuned.get("BLOCK_SIZE_N", 64)
        tuned["BLOCK_SIZE_N"] = min(current_n, next_power_of_2(rank))
    elif op_type in expand_ops:
        current_k = tuned.get("BLOCK_SIZE_K", 32)
        # NOTE(review): the floor of 16 presumably reflects a Triton tile-size
        # requirement of the expand kernel — confirm.
        tuned["BLOCK_SIZE_K"] = max(16, min(current_k, next_power_of_2(rank)))
    return tuned