# marlin.py - Marlin quantization config and linear method.
from typing import Any, Dict, List, Optional

import torch
from torch.nn.parameter import Parameter

6
from vllm import _custom_ops as ops
7
from vllm.logger import init_logger
8
from vllm.model_executor.layers.linear import LinearBase, LinearMethodBase
9
10
from vllm.model_executor.layers.quantization.base_config import (
    QuantizationConfig)
11
from vllm.model_executor.layers.vocab_parallel_embedding import ParallelLMHead
12
from vllm.model_executor.utils import set_weight_attrs
13

14
15
# Module-level logger obtained from vllm.logger.init_logger, named after
# this module.
logger = init_logger(__name__)

16
17
18
19
20
21
22
23
24
25

class MarlinConfig(QuantizationConfig):
    """Config class for Marlin.

    Reference: https://github.com/IST-DASLab/marlin/tree/master

    Holds the two checkpoint-supplied settings (``group_size`` and
    ``lm_head_quantized``) together with the fixed constants that
    ``MarlinLinearMethod`` reads when validating and allocating weights
    (``pack_factor``, ``tile_size``, ``min_n_threads``, ``min_k_threads``,
    ``max_parallel`` and ``perm_len``).
    """

    def __init__(
        self,
        group_size: int,
        lm_head_quantized: bool,
    ) -> None:
        """Store the settings and initialise the fixed kernel constants.

        Args:
            group_size: Quantization group size; must be 128 or -1
                (channelwise).
            lm_head_quantized: Whether ``ParallelLMHead`` layers should also
                get the Marlin linear method (see ``get_quant_method``).

        Raises:
            ValueError: If ``group_size`` is neither 128 nor -1.
        """
        # Group size for the quantization.
        self.group_size = group_size
        self.lm_head_quantized = lm_head_quantized
        if self.group_size != 128 and self.group_size != -1:
            raise ValueError(
                "Currently, only group size 128 and -1 (channelwise) "
                "is supported for Marlin, but got group_size of "
                f"{self.group_size}")

        # 4 Bits packed into 32 bit datatype.
        self.pack_factor = 32 // 4

        # Tile size used by marlin kernels.
        self.tile_size = 16

        # Min out_features dim
        self.min_n_threads = 64

        # Min in_features dim
        self.min_k_threads = 128

        # Max parallel problems to solve at once (improves large
        # batch performance)
        self.max_parallel = 16

        # Permutation length used by the marlin kernels.
        self.perm_len = 1024

    def __repr__(self) -> str:
        """Return a debug representation showing the two user settings."""
        return (f"MarlinConfig(group_size={self.group_size}, "
                f"lm_head_quantized={self.lm_head_quantized})")

    @classmethod
    def get_name(cls) -> str:
        """Return the name of this quantization method: ``"marlin"``."""
        return "marlin"

    @classmethod
    def get_supported_act_dtypes(cls) -> List[torch.dtype]:
        """Return the supported activation dtypes (only ``torch.half``)."""
        return [torch.half]

    @classmethod
    def get_min_capability(cls) -> int:
        """Return the minimum capability value required by this method."""
        # TODO: Need to figure it out.
        # NOTE(review): 80 presumably encodes CUDA compute capability 8.0;
        # confirm against the base-class contract.
        return 80

    @classmethod
    def get_config_filenames(cls) -> List[str]:
        """Return the file names that may hold the quantization config."""
        return ["quantize_config.json"]

    @classmethod
    def from_config(cls, config: Dict[str, Any]) -> "MarlinConfig":
        """Build a ``MarlinConfig`` from a parsed quantization config dict.

        ``group_size`` is read via ``get_from_keys``; ``lm_head`` is read via
        ``get_from_keys_or`` with a default of ``False``.
        """
        group_size = cls.get_from_keys(config, ["group_size"])
        lm_head_quantized = cls.get_from_keys_or(config, ["lm_head"],
                                                 default=False)
        return cls(group_size, lm_head_quantized)

    @classmethod
    def override_quantization_method(cls, hf_quant_cfg,
                                     user_quant) -> Optional[str]:
        """Select the Marlin method when the checkpoint is in Marlin format.

        Args:
            hf_quant_cfg: Quantization config mapping of the checkpoint;
                only ``.get`` is used on it.
            user_quant: Quantization method requested by the user, or
                ``None`` if none was requested.

        Returns:
            ``cls.get_name()`` when the checkpoint is flagged as Marlin format
            and ``user_quant`` is ``None``, ``"gptq"`` or ``"marlin"``;
            otherwise ``None``.
        """
        # compat: autogptq >=0.8.0 use checkpoint_format: str
        # compat: autogptq <=0.7.1 is_marlin_format: bool
        is_marlin_format = (hf_quant_cfg.get("checkpoint_format") == "marlin"
                            or hf_quant_cfg.get("is_marlin_format", False))

        is_valid_user_quant = (user_quant is None or user_quant == "gptq"
                               or user_quant == "marlin")

        if is_marlin_format and is_valid_user_quant:
            name = cls.get_name()
            # Lazy %-style arguments: the message is only rendered when the
            # record is actually emitted; the rendered text is unchanged.
            logger.info(
                "The model is serialized in %s format. Using %s kernel.",
                name, name)
            return name

        return None

    def get_quant_method(
            self, layer: torch.nn.Module) -> Optional["MarlinLinearMethod"]:
        """Return the linear method to use for ``layer``, if any.

        A ``MarlinLinearMethod`` is returned for every ``LinearBase`` layer,
        and for ``ParallelLMHead`` layers only when ``lm_head_quantized`` is
        set. Any other layer yields ``None``.
        """
        if (isinstance(layer, LinearBase) or
            (isinstance(layer, ParallelLMHead) and self.lm_head_quantized)):
            return MarlinLinearMethod(self)
        return None

    def get_scaled_act_names(self) -> List[str]:
        """Return the scaled activation names; always an empty list here."""
        return []


class MarlinLinearMethod(LinearMethodBase):
    """Linear method for Marlin.

    Args:
        quant_config: The Marlin quantization config.
    """

    def __init__(self, quant_config: MarlinConfig):
        self.quant_config = quant_config

    def create_weights(
        self,
        layer: torch.nn.Module,
        input_size_per_partition: int,
        output_partition_sizes: List[int],
        input_size: int,
        output_size: int,
        params_dtype: torch.dtype,
        **extra_weight_attrs,
    ):
        """Validate the shard sizes and register the Marlin parameters.

        Three parameters are registered on ``layer``, all allocated on the
        ``"cuda"`` device with ``requires_grad=False``:

        * ``B``: int32 packed 4-bit weights of shape
          ``(input_size_per_partition // tile_size,
          output_size_per_partition * tile_size // pack_factor)``.
        * ``s``: scales of dtype ``params_dtype`` and shape
          ``(input_groups, output_size_per_partition)``, where
          ``input_groups`` is 1 when ``group_size == -1``.
        * ``workspace``: zero-initialised int buffer of
          ``(output_size_per_partition // min_n_threads) * max_parallel``
          elements.

        Args:
            layer: Module that receives the parameters.
            input_size_per_partition: Input features held by this shard.
            output_partition_sizes: Output sizes of the logical partitions
                on this shard; their sum is the shard's output size.
            input_size: Unused.
            output_size: Unused.
            params_dtype: Must be ``torch.float16``.
            **extra_weight_attrs: Extra attributes attached to every
                registered parameter via ``set_weight_attrs``.

        Raises:
            ValueError: If ``params_dtype`` is not float16, or if a shard
                size is not divisible by the corresponding config constant.
        """
        del input_size, output_size  # Unused.

        cfg = self.quant_config

        if params_dtype != torch.float16:
            raise ValueError(
                f"The params dtype must be float16, but got {params_dtype}")

        # Validate output_size_per_partition
        output_size_per_partition = sum(output_partition_sizes)
        if output_size_per_partition % cfg.min_n_threads != 0:
            raise ValueError(
                "Weight output_size_per_partition = "
                f"{output_size_per_partition} is not divisible by "
                f"min_n_threads = {cfg.min_n_threads}.")
        if output_size_per_partition % cfg.pack_factor != 0:
            raise ValueError(
                "Weight output_size_per_partition = "
                f"{output_size_per_partition} is not divisible by "
                f"pack_factor = {cfg.pack_factor}.")

        # Validate input_size_per_partition
        if input_size_per_partition % cfg.min_k_threads != 0:
            raise ValueError(
                "Weight input_size_per_partition = "
                f"{input_size_per_partition} is not divisible by "
                f"min_k_threads = {cfg.min_k_threads}.")
        if (cfg.group_size != -1
                and input_size_per_partition % cfg.group_size != 0):
            raise ValueError("Weight input_size_per_partition = "
                             f"{input_size_per_partition} is not divisible by "
                             f"group_size = {cfg.group_size}.")

        # The shard's output size must be a whole multiple of the number of
        # tiles covered by one permutation (perm_len // tile_size**2).
        num_tiles_per_perm = cfg.perm_len // (cfg.tile_size**2)
        if output_size_per_partition % num_tiles_per_perm != 0:
            raise ValueError(
                "Each permutation group must reside on the same gpu")

        # Quantized 4Bit weights packed into Int32.
        qweight = Parameter(
            torch.empty(
                input_size_per_partition // cfg.tile_size,
                output_size_per_partition * cfg.tile_size // cfg.pack_factor,
                device="cuda",
                dtype=torch.int32,
            ),
            requires_grad=False,
        )
        set_weight_attrs(
            qweight,
            {
                "input_dim": 0,
                "output_dim": 1,
                "packed_dim": 1,
                "pack_factor": cfg.pack_factor,
                "marlin_tile_size": cfg.tile_size,
            },
        )

        # Determine if channelwise or not
        input_groups = (1 if cfg.group_size == -1 else
                        input_size_per_partition // cfg.group_size)

        scales = Parameter(
            torch.empty(
                input_groups,
                output_size_per_partition,
                device="cuda",
                dtype=params_dtype,
            ),
            requires_grad=False,
        )
        set_weight_attrs(
            scales,
            {
                # With a single group there is no input dimension recorded
                # for the scales.
                "input_dim": None if input_groups == 1 else 0,
                "output_dim": 1,
            },
        )

        # Allocate workspace (Used for internal locking mechanism)
        max_workspace_size = (output_size_per_partition //
                              cfg.min_n_threads) * cfg.max_parallel
        workspace = Parameter(torch.zeros(max_workspace_size,
                                          device="cuda",
                                          dtype=torch.int),
                              requires_grad=False)

        layer.register_parameter("B", qweight)
        set_weight_attrs(qweight, extra_weight_attrs)
        layer.register_parameter("s", scales)
        set_weight_attrs(scales, extra_weight_attrs)
        layer.register_parameter("workspace", workspace)
        set_weight_attrs(workspace, extra_weight_attrs)

    def apply(
        self,
        layer: torch.nn.Module,
        x: torch.Tensor,
        bias: Optional[torch.Tensor] = None,
    ) -> torch.Tensor:
        """Run the Marlin GEMM for ``x`` using the parameters on ``layer``.

        Args:
            layer: Module carrying the ``B``, ``s`` and ``workspace``
                parameters registered by ``create_weights``.
            x: Input activations; all leading dimensions are flattened into
                one before the kernel call.
            bias: Optional bias, added in place to the output.

        Returns:
            Tensor with the leading dimensions of ``x`` and a last dimension
            equal to the kernel output's second dimension.
        """
        qweight = layer.B
        scales = layer.s
        workspace = layer.workspace

        # reshape (rather than view) so that a non-contiguous ``x`` is copied
        # instead of raising; contiguous inputs still yield a plain view.
        x_2d = x.reshape(-1, x.shape[-1])

        size_m = x_2d.shape[0]
        size_k = x_2d.shape[1]
        size_n = scales.shape[1]

        output_2d = ops.marlin_gemm(x_2d, qweight, scales, workspace, size_m,
                                    size_n, size_k)

        # Restore the original leading dimensions of ``x``.
        output = output_2d.view(x.shape[:-1] + (output_2d.shape[1], ))

        if bias is not None:
            output.add_(bias)  # In-place add

        return output