# marlin.py: Marlin quantization config and linear method for vLLM.
from typing import Any, Dict, List, Optional

import torch
from torch.nn.parameter import Parameter

from vllm import _custom_ops as ops
from vllm.logger import init_logger
from vllm.model_executor.layers.linear import LinearBase, LinearMethodBase
from vllm.model_executor.layers.quantization.base_config import (
    QuantizationConfig)
from vllm.model_executor.utils import set_weight_attrs

# Module-level logger, used to report when the Marlin kernel is selected.
logger = init_logger(__name__)


class MarlinConfig(QuantizationConfig):
    """Config class for Marlin.

    Holds the quantization group size together with the fixed tiling and
    packing constants used by the Marlin 4-bit kernels.

    Reference: https://github.com/IST-DASLab/marlin/tree/master
    """

    def __init__(
        self,
        group_size: int,
    ) -> None:
        """Create a Marlin config.

        Args:
            group_size: Number of input features that share one scale. Only
                128 and -1 (a single scale per output channel, i.e.
                channelwise) are supported.

        Raises:
            ValueError: If ``group_size`` is neither 128 nor -1.
        """
        # Group size for the quantization.
        self.group_size = group_size
        if self.group_size not in (128, -1):
            raise ValueError(
                "Currently, only group size 128 and -1 (channelwise) "
                "is supported for Marlin, but got group_size of "
                f"{self.group_size}")

        # 4 Bits packed into 32 bit datatype.
        self.pack_factor = 32 // 4

        # Tile size used by marlin kernels.
        self.tile_size = 16

        # Min out_features dim (the output shard size must be a multiple).
        self.min_n_threads = 64

        # Min in_features dim (the input shard size must be a multiple).
        self.min_k_threads = 128

        # Max parallel problems to solve at once (improves large
        # batch performance). Also scales the size of the kernel workspace.
        self.max_parallel = 16

        # Permutation length used by the marlin kernels.
        self.perm_len = 1024

    def __repr__(self) -> str:
        return f"MarlinConfig(group_size={self.group_size})"

    @classmethod
    def get_name(cls) -> str:
        """Return the quantization method name, ``"marlin"``."""
        return "marlin"

    @classmethod
    def get_supported_act_dtypes(cls) -> List[torch.dtype]:
        """Return the activation dtypes Marlin accepts (only float16)."""
        return [torch.half]

    @classmethod
    def get_min_capability(cls) -> int:
        """Return the minimum device capability required by Marlin."""
        # NOTE(review): 80 presumably denotes compute capability 8.0
        # (Ampere); the original author marked this value as unconfirmed.
        return 80

    @classmethod
    def get_config_filenames(cls) -> List[str]:
        """Return the checkpoint config filenames to look for."""
        return ["quantize_config.json"]

    @classmethod
    def from_config(cls, config: Dict[str, Any]) -> "MarlinConfig":
        """Build a ``MarlinConfig`` from a parsed quantization config dict.

        Only the ``group_size`` key is read.
        """
        group_size = cls.get_from_keys(config, ["group_size"])
        return cls(group_size)

    @classmethod
    def override_quantization_method(cls, hf_quant_cfg,
                                     user_quant) -> Optional[str]:
        """Select Marlin automatically for Marlin-serialized checkpoints.

        Args:
            hf_quant_cfg: The checkpoint's quantization config dict.
            user_quant: The quantization method requested by the user, or
                None.

        Returns:
            ``"marlin"`` if the checkpoint is flagged as Marlin format and the
            user requested nothing, ``"gptq"`` or ``"marlin"``; otherwise
            None.
        """
        # compat: autogptq >=0.8.0 use checkpoint_format: str
        # compat: autogptq <=0.7.1 is_marlin_format: bool
        is_marlin_format = (hf_quant_cfg.get("checkpoint_format") == "marlin"
                            or hf_quant_cfg.get("is_marlin_format", False))

        is_valid_user_quant = user_quant in (None, "gptq", "marlin")

        if is_marlin_format and is_valid_user_quant:
            name = cls.get_name()
            # Lazy %-style args: the message is only formatted if emitted.
            logger.info(
                "The model is serialized in %s format. Using %s kernel.",
                name, name)
            return name

        return None

    def get_quant_method(
            self, layer: torch.nn.Module) -> Optional["MarlinLinearMethod"]:
        """Return a ``MarlinLinearMethod`` for linear layers, else None."""
        if isinstance(layer, LinearBase):
            return MarlinLinearMethod(self)
        return None

    def get_scaled_act_names(self) -> List[str]:
        """Return an empty list; Marlin reports no scaled activations."""
        return []


class MarlinLinearMethod(LinearMethodBase):
    """Linear method for Marlin.

    Args:
        quant_config: The Marlin quantization config.
    """

    def __init__(self, quant_config: MarlinConfig):
        self.quant_config = quant_config

    def create_weights(
        self,
        layer: torch.nn.Module,
        input_size_per_partition: int,
        output_partition_sizes: List[int],
        input_size: int,
        output_size: int,
        params_dtype: torch.dtype,
        **extra_weight_attrs,
    ):
        """Allocate and register the Marlin parameters on ``layer``.

        Registers three parameters:
          * ``B``: packed 4-bit weights, int32, shape
            ``(input_size_per_partition // tile_size,
            output_size_per_partition * tile_size // pack_factor)``.
          * ``s``: scales of dtype ``params_dtype``, shape
            ``(input_groups, output_size_per_partition)``; ``input_groups``
            is 1 for channelwise quantization (``group_size == -1``).
          * ``workspace``: zero-initialised int buffer of length
            ``output_size_per_partition // min_n_threads * max_parallel``,
            later passed to the Marlin GEMM.

        Args:
            layer: Module on which to register the parameters.
            input_size_per_partition: Input features handled by this shard.
            output_partition_sizes: Per-partition output sizes; their sum is
                this shard's output size.
            input_size: Full input size. Unused.
            output_size: Full output size. Unused.
            params_dtype: Dtype of the scales; must be ``torch.float16``.
            **extra_weight_attrs: Extra attributes set on every registered
                parameter.

        Raises:
            ValueError: If ``params_dtype`` is not float16 or the shard sizes
                violate Marlin's divisibility requirements.
        """
        del input_size, output_size  # Unused.

        if params_dtype != torch.float16:
            raise ValueError(
                f"The params dtype must be float16, but got {params_dtype}")

        cfg = self.quant_config

        # Validate output_size_per_partition.
        output_size_per_partition = sum(output_partition_sizes)
        if output_size_per_partition % cfg.min_n_threads != 0:
            raise ValueError(
                "Weight output_size_per_partition = "
                f"{output_size_per_partition} is not divisible by "
                f"min_n_threads = {cfg.min_n_threads}.")
        if output_size_per_partition % cfg.pack_factor != 0:
            raise ValueError(
                "Weight output_size_per_partition = "
                f"{output_size_per_partition} is not divisible by "
                f"pack_factor = {cfg.pack_factor}.")

        # Validate input_size_per_partition.
        if input_size_per_partition % cfg.min_k_threads != 0:
            raise ValueError(
                "Weight input_size_per_partition = "
                f"{input_size_per_partition} is not divisible by "
                f"min_k_threads = {cfg.min_k_threads}.")
        if (cfg.group_size != -1
                and input_size_per_partition % cfg.group_size != 0):
            raise ValueError(
                "Weight input_size_per_partition = "
                f"{input_size_per_partition} is not divisible by "
                f"group_size = {cfg.group_size}.")

        # A permutation group spans perm_len // tile_size**2 tiles (4 with
        # the default constants). Require the shard's output size to be a
        # multiple of that so a permutation group is not split across shards.
        # NOTE(review): with the default constants this is already implied by
        # the min_n_threads (64) check above.
        num_tiles_per_perm = cfg.perm_len // (cfg.tile_size**2)
        if output_size_per_partition % num_tiles_per_perm != 0:
            raise ValueError(
                "Each permutation group must reside on the same gpu")

        # Quantized 4-bit weights packed into int32, in Marlin's tiled layout.
        qweight = Parameter(
            torch.empty(
                input_size_per_partition // cfg.tile_size,
                output_size_per_partition * cfg.tile_size // cfg.pack_factor,
                device="cuda",
                dtype=torch.int32,
            ),
            requires_grad=False,
        )
        set_weight_attrs(
            qweight,
            {
                "input_dim": 0,
                "output_dim": 1,
                "packed_dim": 1,
                "pack_factor": cfg.pack_factor,
                "marlin_tile_size": cfg.tile_size,
            },
        )

        # Channelwise (group_size == -1) uses a single row of scales.
        input_groups = (1 if cfg.group_size == -1 else
                        input_size_per_partition // cfg.group_size)

        scales = Parameter(
            torch.empty(
                input_groups,
                output_size_per_partition,
                device="cuda",
                dtype=params_dtype,
            ),
            requires_grad=False,
        )
        set_weight_attrs(
            scales,
            {
                # A single scale row has no input dimension to shard.
                "input_dim": None if input_groups == 1 else 0,
                "output_dim": 1,
            },
        )

        # Allocate workspace (Used for internal locking mechanism)
        max_workspace_size = (output_size_per_partition //
                              cfg.min_n_threads) * cfg.max_parallel
        workspace = Parameter(torch.zeros(max_workspace_size,
                                          device="cuda",
                                          dtype=torch.int),
                              requires_grad=False)

        layer.register_parameter("B", qweight)
        set_weight_attrs(qweight, extra_weight_attrs)
        layer.register_parameter("s", scales)
        set_weight_attrs(scales, extra_weight_attrs)
        layer.register_parameter("workspace", workspace)
        set_weight_attrs(workspace, extra_weight_attrs)

    def apply(
        self,
        layer: torch.nn.Module,
        x: torch.Tensor,
        bias: Optional[torch.Tensor] = None,
    ) -> torch.Tensor:
        """Run the Marlin GEMM on ``x`` and optionally add ``bias``.

        Args:
            layer: Module whose ``B``, ``s`` and ``workspace`` parameters were
                registered by ``create_weights``.
            x: Input whose last dimension is the input feature size; all
                leading dimensions are flattened for the kernel call.
            bias: Optional bias added in place to the output.

        Returns:
            Tensor with the leading dimensions of ``x`` and a last dimension
            equal to the kernel's output width.
        """
        qweight = layer.B
        scales = layer.s
        workspace = layer.workspace

        # Flatten leading dims; reshape (unlike view) also accepts
        # non-contiguous inputs and is a no-copy view for contiguous ones.
        x_2d = x.reshape(-1, x.shape[-1])

        size_m, size_k = x_2d.shape
        size_n = scales.shape[1]

        output_2d = ops.marlin_gemm(x_2d, qweight, scales, workspace, size_m,
                                    size_n, size_k)

        output = output_2d.view(x.shape[:-1] + (output_2d.shape[1], ))

        if bias is not None:
            output.add_(bias)  # In-place add

        return output