# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project

Jee Jee Li's avatar
Jee Jee Li committed
4
from dataclasses import dataclass
5
from enum import Enum
Jee Jee Li's avatar
Jee Jee Li committed
6
7
8
9

import torch
import torch.nn as nn

10
11
12
from vllm.model_executor.layers.fused_moe.fused_moe import try_get_optimal_moe_config
from vllm.utils.math_utils import next_power_of_2

Jee Jee Li's avatar
Jee Jee Li committed
13

14
15
16
17
18
19
class LoRAMappingType(Enum):
    """Category tag carried by a ``LoRAMapping`` in its ``type`` field.

    ``LoRAMapping`` defaults to ``LANGUAGE``.
    """

    # NOTE(review): the names suggest the language model, encoder tower and
    # connector parts of a multimodal model -- confirm against the callers.
    LANGUAGE = 1
    TOWER = 2
    CONNECTOR = 3


Jee Jee Li's avatar
Jee Jee Li committed
20
@dataclass
class LoRAMapping:
    """LoRA index/prompt mappings plus flags describing the batch.

    ``index_mapping`` and ``prompt_mapping`` may be passed as any iterable of
    ints; ``__post_init__`` normalizes both to tuples.
    """

    # NOTE(review): presumably per-token and per-prompt LoRA indices
    # respectively -- confirm against the code that builds the mapping.
    index_mapping: tuple[int, ...]
    prompt_mapping: tuple[int, ...]
    # Whether this mapping describes a prefill step (defaults to False).
    is_prefill: bool = False
    # Category of the mapping; the field name shadows the builtin ``type``
    # inside the class body but is kept because it is part of the public
    # interface.
    type: LoRAMappingType = LoRAMappingType.LANGUAGE

    def __post_init__(self) -> None:
        """Coerce both mappings to tuples, whatever iterable was passed in."""
        self.index_mapping = tuple(self.index_mapping)
        self.prompt_mapping = tuple(self.prompt_mapping)

Jee Jee Li's avatar
Jee Jee Li committed
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46

def _get_lora_device(base_layer: nn.Module) -> torch.device:
    """Returns the device for where to place the LoRA tensors.

    The device is taken from the first weight-like attribute found on
    ``base_layer``; the attribute names are probed in a fixed priority order.

    Args:
        base_layer: The layer whose weight placement decides the device.

    Returns:
        The ``.device`` of the first matching weight attribute.

    Raises:
        ValueError: If ``base_layer`` has none of the known weight attributes.
    """
    # code borrowed from https://github.com/fmmoret/vllm/blob/fm-support-lora-on-quantized-models/vllm/lora/layers.py#L34
    # Order matters: earlier entries win when a layer exposes several of them.
    weight_attr_names = (
        "weight",  # unquantizedLinear
        "weight_packed",  # Compressed Tensor
        "qweight",  # GPTQ/AWQ
        "W_q",  # HQQ marlin
        "w2_weight",  # MoE layer
        "w2_weight_packed",  # MoE Compressed Tensor
        "w2_qweight",  # MoE GPTQ/AWQ/GGUF
    )
    for attr_name in weight_attr_names:
        if hasattr(base_layer, attr_name):
            return getattr(base_layer, attr_name).device
    raise ValueError(f"Unsupported base layer: {base_layer}")


def _not_fully_sharded_can_replace(can_replace):
    """
    decorator which adds the condition of not using fully sharded loras
    intended to wrap can_replace_layer()

    The returned wrapper accepts one extra keyword argument, ``decorate``
    (default ``True``), which is stripped before ``can_replace`` is called.
    With ``decorate=False`` the extra condition is skipped and the result of
    ``can_replace`` is returned unchanged. Otherwise ``lora_config`` must be
    passed by keyword (a ``KeyError`` is raised if it is not).
    """
    from functools import wraps

    @wraps(can_replace)
    def dec(*args, **kwargs):
        # ``decorate`` is consumed here; ``can_replace`` never sees it.
        decorate = kwargs.pop("decorate", True)
        if decorate:
            condition = not kwargs["lora_config"].fully_sharded_loras
        else:
            condition = True
        return can_replace(*args, **kwargs) and condition

    return dec


def _fully_sharded_can_replace(can_replace):
    """
    decorator which adds the condition of fully sharded loras
    intended to wrap can_replace_layer()

    ``can_replace`` is evaluated first; only when it returns a truthy value
    is ``kwargs["lora_config"].fully_sharded_loras`` read and returned, so
    ``lora_config`` must be passed by keyword in that case.
    """
    from functools import wraps

    @wraps(can_replace)
    def dec(*args, **kwargs):
        return (
            can_replace(*args, **kwargs) and kwargs["lora_config"].fully_sharded_loras
        )

    return dec
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115


def try_get_optimal_moe_lora_config(
    op_type: str,
    w1_shape: tuple[int, ...],
    w2_shape: tuple[int, ...],
    rank: int,
    top_k: int,
    dtype: str | None,
    M: int,
    block_shape: list[int] | None = None,
) -> dict[str, int | None]:
    """Return the fused-MoE kernel config adjusted for a LoRA op.

    The base config comes from ``try_get_optimal_moe_config``; a copy of it
    is then tweaked depending on ``op_type``:

    * shrink ops (``fused_moe_lora_w13_shrink`` / ``fused_moe_lora_w2_shrink``):
      ``BLOCK_SIZE_N`` is capped at ``next_power_of_2(rank)``.
    * expand ops (``fused_moe_lora_w13_expand`` / ``fused_moe_lora_w2_expand``):
      ``BLOCK_SIZE_K`` is capped at ``next_power_of_2(rank)`` but never drops
      below 16.

    Any other ``op_type`` gets the copied base config unchanged.
    """
    shrink_ops = ("fused_moe_lora_w13_shrink", "fused_moe_lora_w2_shrink")
    expand_ops = ("fused_moe_lora_w13_expand", "fused_moe_lora_w2_expand")

    base_config = try_get_optimal_moe_config(
        w1_shape, w2_shape, top_k, dtype, M, block_shape
    )
    # Mutate a copy only; the base config object is left untouched.
    tuned = base_config.copy()

    if op_type in shrink_ops:
        current_n = tuned.get("BLOCK_SIZE_N", 64)
        tuned["BLOCK_SIZE_N"] = min(current_n, next_power_of_2(rank))
    elif op_type in expand_ops:
        current_k = tuned.get("BLOCK_SIZE_K", 32)
        capped_k = min(current_k, next_power_of_2(rank))
        tuned["BLOCK_SIZE_K"] = max(16, capped_k)

    return tuned