# marlin.py: Marlin quantization config and linear method.
from typing import Any, Dict, List, Optional

import torch
from torch.nn.parameter import Parameter

6
from vllm import _custom_ops as ops
7
from vllm.model_executor.layers.linear import LinearBase, LinearMethodBase
8
9
from vllm.model_executor.layers.quantization.base_config import (
    QuantizationConfig)
10
from vllm.model_executor.utils import set_weight_attrs
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26


class MarlinConfig(QuantizationConfig):
    """Config class for Marlin.

    Holds the quantization group size together with the fixed constants
    (pack factor, tile size, minimum thread dimensions, maximum parallelism
    and permutation length) that ``MarlinLinearMethod`` reads when it
    validates shapes and allocates weights.

    Reference: https://github.com/IST-DASLab/marlin/tree/master
    """

    def __init__(
        self,
        group_size: int,
    ) -> None:
        """Store ``group_size`` and set the fixed Marlin constants.

        Args:
            group_size: Quantization group size. Only 128 and -1
                (channelwise) are accepted.

        Raises:
            ValueError: If ``group_size`` is neither 128 nor -1.
        """
        # Group size for the quantization.
        self.group_size = group_size
        if self.group_size != 128 and self.group_size != -1:
            raise ValueError(
                "Currently, only group size 128 and -1 (channelwise) "
                "is supported for Marlin, but got group_size of "
                f"{self.group_size}")

        # 4 Bits packed into 32 bit datatype.
        self.pack_factor = 32 // 4

        # Tile size used by marlin kernels.
        self.tile_size = 16

        # Min out_features dim
        self.min_n_threads = 64

        # Min in_features dim
        self.min_k_threads = 128

        # Max parallel problems to solve at once (improves large
        # batch performance)
        self.max_parallel = 16

        # Permutation length used by the marlin kernels.
        self.perm_len = 1024

    def __repr__(self) -> str:
        """Return a debug representation that shows the group size."""
        return f"MarlinConfig(group_size={self.group_size})"

    @classmethod
    def get_name(cls) -> str:
        """Return the name of this quantization method, ``"marlin"``."""
        return "marlin"

    @classmethod
    def get_supported_act_dtypes(cls) -> List[torch.dtype]:
        """Return the supported activation dtypes (only ``torch.half``)."""
        return [torch.half]

    @classmethod
    # Need to figure it out
    def get_min_capability(cls) -> int:
        """Return the minimum capability value, currently 80.

        NOTE(review): presumably this encodes CUDA compute capability 8.0;
        the original author marked the value as still to be confirmed.
        """
        return 80

    @classmethod
    def get_config_filenames(cls) -> List[str]:
        """Return the config file names for this method."""
        return ["quantize_config.json"]

    @classmethod
    def from_config(cls, config: Dict[str, Any]) -> "MarlinConfig":
        """Build a ``MarlinConfig`` from a parsed quantization config dict.

        Args:
            config: Mapping expected to provide a ``"group_size"`` entry.

        Returns:
            A new config constructed with that group size.

        Raises:
            ValueError: If the group size is rejected by ``__init__``.
        """
        # get_from_keys is not defined in this class; presumably it is
        # inherited from QuantizationConfig -- verify against the base class.
        group_size = cls.get_from_keys(config, ["group_size"])
        return cls(group_size)

    def get_quant_method(
            self, layer: torch.nn.Module) -> Optional["MarlinLinearMethod"]:
        """Return the quantized linear method for ``layer``, if applicable.

        Args:
            layer: The module being prepared.

        Returns:
            A ``MarlinLinearMethod`` bound to this config when ``layer`` is
            a ``LinearBase`` instance, otherwise ``None``.
        """
        if isinstance(layer, LinearBase):
            return MarlinLinearMethod(self)
        return None

    def get_scaled_act_names(self) -> List[str]:
        """Return the scaled activation names (always an empty list)."""
        return []


class MarlinLinearMethod(LinearMethodBase):
    """Linear method for Marlin.

    Args:
        quant_config: The Marlin quantization config.
    """

    def __init__(self, quant_config: MarlinConfig):
        """Keep a reference to the Marlin config used by this method."""
        self.quant_config = quant_config

    def create_weights(
        self,
        layer: torch.nn.Module,
        input_size_per_partition: int,
        output_partition_sizes: List[int],
        input_size: int,
        output_size: int,
        params_dtype: torch.dtype,
        **extra_weight_attrs,
    ):
        """Validate the shard shape and register the Marlin parameters.

        Three parameters are registered on ``layer``: ``B`` (packed int32
        weights), ``s`` (scales) and ``workspace`` (int buffer). All are
        allocated on the ``"cuda"`` device with ``requires_grad=False``.

        Args:
            layer: Module that receives the registered parameters.
            input_size_per_partition: Input feature count of this shard.
            output_partition_sizes: Output sizes of the logical partitions
                in this shard; their sum is the shard's output size.
            input_size: Not read by this method.
            output_size: Unused; deleted immediately.
            params_dtype: Dtype for the scales; must be ``torch.float16``.
            **extra_weight_attrs: Extra attributes attached to every
                registered parameter via ``set_weight_attrs``.

        Raises:
            ValueError: If ``params_dtype`` is not float16, or if the shard
                sizes violate any of the divisibility checks below.
        """
        del output_size  # Unused.

        if params_dtype != torch.float16:
            raise ValueError(
                f"The params dtype must be float16, but got {params_dtype}")

        # Validate output_size_per_partition
        output_size_per_partition = sum(output_partition_sizes)
        if output_size_per_partition % self.quant_config.min_n_threads != 0:
            raise ValueError(
                f"Weight output_size_per_partition = "
                f"{output_size_per_partition} is not divisible by "
                f"min_n_threads = {self.quant_config.min_n_threads}.")
        if output_size_per_partition % self.quant_config.pack_factor != 0:
            raise ValueError(
                f"Weight output_size_per_partition = "
                f"{output_size_per_partition} is not divisible by "
                f"pack_factor = {self.quant_config.pack_factor}.")

        # Validate input_size_per_partition
        if input_size_per_partition % self.quant_config.min_k_threads != 0:
            raise ValueError(
                f"Weight input_size_per_partition = "
                f"{input_size_per_partition} is not divisible by "
                f"min_k_threads = {self.quant_config.min_k_threads}.")
        # The group-size check is skipped for channelwise (-1) quantization.
        if (self.quant_config.group_size != -1 and
                input_size_per_partition % self.quant_config.group_size != 0):
            raise ValueError(f"Weight input_size_per_partition = "
                             f"{input_size_per_partition} is not divisible by "
                             f"group_size = {self.quant_config.group_size}.")

        # The shard's output size must be a multiple of the number of tiles
        # covered by one permutation (perm_len // tile_size**2).
        num_tiles_per_perm = self.quant_config.perm_len // (
            self.quant_config.tile_size**2)
        if output_size_per_partition % num_tiles_per_perm != 0:
            raise ValueError(
                "Each permutation group must reside on the same gpu")

        # Quantized 4Bit weights packed into Int32.
        qweight = Parameter(
            torch.empty(
                input_size_per_partition // self.quant_config.tile_size,
                output_size_per_partition * self.quant_config.tile_size //
                self.quant_config.pack_factor,
                device="cuda",
                dtype=torch.int32,
            ),
            requires_grad=False,
        )
        set_weight_attrs(
            qweight,
            {
                "input_dim": 0,
                "output_dim": 1,
                "packed_dim": 1,
                "pack_factor": self.quant_config.pack_factor,
                "marlin_tile_size": self.quant_config.tile_size,
            },
        )

        # Determine if channelwise or not
        input_groups = (1 if self.quant_config.group_size == -1 else
                        input_size_per_partition //
                        self.quant_config.group_size)

        scales = Parameter(
            torch.empty(
                input_groups,
                output_size_per_partition,
                device="cuda",
                dtype=params_dtype,
            ),
            requires_grad=False,
        )
        set_weight_attrs(
            scales,
            {
                # A single group has no input dimension to shard along.
                "input_dim": None if input_groups == 1 else 0,
                "output_dim": 1,
            },
        )

        # Allocate workspace (Used for internal locking mechanism)
        max_workspace_size = (
            output_size_per_partition //
            self.quant_config.min_n_threads) * self.quant_config.max_parallel
        workspace = Parameter(torch.zeros(max_workspace_size,
                                          device="cuda",
                                          dtype=torch.int),
                              requires_grad=False)

        # Register under the names that apply() reads back from the layer.
        layer.register_parameter("B", qweight)
        set_weight_attrs(qweight, extra_weight_attrs)
        layer.register_parameter("s", scales)
        set_weight_attrs(scales, extra_weight_attrs)
        layer.register_parameter("workspace", workspace)
        set_weight_attrs(workspace, extra_weight_attrs)

    def apply(
        self,
        layer: torch.nn.Module,
        x: torch.Tensor,
        bias: Optional[torch.Tensor] = None,
    ) -> torch.Tensor:
        """Run the Marlin GEMM for ``layer`` on input ``x``.

        Args:
            layer: Module holding the ``B``, ``s`` and ``workspace``
                parameters registered by ``create_weights``.
            x: Input tensor; all leading dimensions are flattened into one
                before the kernel call and restored on the result.
            bias: Optional bias, added in place to the output.

        Returns:
            Tensor with the leading dimensions of ``x`` and a last dimension
            taken from the kernel output.
        """
        qweight = layer.B
        scales = layer.s
        workspace = layer.workspace

        # Collapse every leading dimension so the kernel sees a 2-D matrix.
        x_2d = x.view(-1, x.shape[-1])

        size_m = x_2d.shape[0]
        size_k = x_2d.shape[1]
        # Second dimension of the scales tensor (the per-partition output
        # size it was allocated with in create_weights).
        size_n = scales.shape[1]

        output_2d = ops.marlin_gemm(x_2d, qweight, scales, workspace, size_m,
                                    size_n, size_k)

        # Restore the original leading dimensions of x.
        output = output_2d.view(x.shape[:-1] + (output_2d.shape[1], ))

        if bias is not None:
            output.add_(bias)  # In-place add

        return output