# marlin.py: Marlin quantization config and linear method.
from typing import Any, Dict, List, Optional

import torch
from torch.nn.parameter import Parameter

6
from vllm import _custom_ops as ops
7
from vllm.logger import init_logger
8
from vllm.model_executor.layers.linear import LinearBase, LinearMethodBase
9
10
from vllm.model_executor.layers.quantization.base_config import (
    QuantizationConfig)
11
from vllm.model_executor.layers.vocab_parallel_embedding import ParallelLMHead
12
13
14
15
from vllm.model_executor.parameter import (BasevLLMParameter,
                                           ChannelQuantScaleParameter,
                                           GroupQuantScaleParameter,
                                           PackedvLLMParameter)
16

17
18
# Module-level logger; MarlinConfig.override_quantization_method uses it to
# report when a Marlin-format checkpoint is detected.
logger = init_logger(__name__)

19
20
21
22
23
24
25
26
27
28

class MarlinConfig(QuantizationConfig):
    """Config class for Marlin.

    Validates the checkpoint's quantization group size and records the
    packing/tiling constants that the Marlin kernels (4-bit weights packed
    into int32) are used with.

    Reference: https://github.com/IST-DASLab/marlin/tree/master
    """

    def __init__(
        self,
        group_size: int,
        lm_head_quantized: bool,
    ) -> None:
        """Create a Marlin config.

        Args:
            group_size: Quantization group size; only 128 and -1
                (channelwise) are accepted.
            lm_head_quantized: Whether the ParallelLMHead layer should also
                use the Marlin linear method.

        Raises:
            ValueError: if ``group_size`` is not 128 or -1.
        """
        # Group size for the quantization.
        self.group_size = group_size
        self.lm_head_quantized = lm_head_quantized
        if self.group_size not in (128, -1):
            raise ValueError(
                "Currently, only group size 128 and -1 (channelwise) "
                "is supported for Marlin, but got group_size of "
                f"{self.group_size}")

        # 4 Bits packed into 32 bit datatype.
        self.pack_factor = 32 // 4

        # Tile size used by marlin kernels.
        self.tile_size = 16

        # Min out_features dim
        self.min_n_threads = 64

        # Min in_features dim
        self.min_k_threads = 128

        # Max parallel problems to solve at once (improves large
        # batch performance)
        self.max_parallel = 16

        # Permutation length used by the marlin kernels.
        self.perm_len = 1024

    def __repr__(self) -> str:
        return (f"MarlinConfig(group_size={self.group_size}, "
                f"lm_head_quantized={self.lm_head_quantized})")

    @classmethod
    def get_name(cls) -> str:
        """Return the quantization method name, ``"marlin"``."""
        return "marlin"

    @classmethod
    def get_supported_act_dtypes(cls) -> List[torch.dtype]:
        """Return the activation dtypes Marlin accepts (fp16 only)."""
        return [torch.half]

    @classmethod
    def get_min_capability(cls) -> int:
        """Return the minimum GPU compute capability, encoded as 80."""
        # NOTE(review): the original author flagged this value as not yet
        # determined ("Need to figure it out") -- confirm 80 is correct.
        return 80

    @classmethod
    def get_config_filenames(cls) -> List[str]:
        """Return the checkpoint files that hold the quantization config."""
        return ["quantize_config.json"]

    @classmethod
    def from_config(cls, config: Dict[str, Any]) -> "MarlinConfig":
        """Build a MarlinConfig from a checkpoint quantization config dict.

        Reads ``"group_size"`` via ``get_from_keys`` and ``"lm_head"`` via
        ``get_from_keys_or`` (defaulting to ``False``).
        """
        group_size = cls.get_from_keys(config, ["group_size"])
        lm_head_quantized = cls.get_from_keys_or(config, ["lm_head"],
                                                 default=False)
        return cls(group_size, lm_head_quantized)

    @classmethod
    def override_quantization_method(
            cls, hf_quant_cfg: Dict[str, Any],
            user_quant: Optional[str]) -> Optional[str]:
        """Select Marlin when the checkpoint is serialized in Marlin format.

        Returns ``"marlin"`` if ``hf_quant_cfg`` marks the checkpoint as
        Marlin format and the user requested no method, ``"gptq"`` or
        ``"marlin"``; otherwise returns ``None``.
        """
        # compat: autogptq >=0.8.0 use checkpoint_format: str
        # compat: autogptq <=0.7.1 is_marlin_format: bool
        is_marlin_format = (hf_quant_cfg.get("checkpoint_format") == "marlin"
                            or hf_quant_cfg.get("is_marlin_format", False))

        is_valid_user_quant = user_quant in (None, "gptq", "marlin")

        if is_marlin_format and is_valid_user_quant:
            # Lazy %-style args: formatting happens only if INFO is enabled.
            logger.info(
                "The model is serialized in %s format. Using %s kernel.",
                cls.get_name(), cls.get_name())
            return cls.get_name()

        return None

    def get_quant_method(self, layer: torch.nn.Module,
                         prefix: str) -> Optional["MarlinLinearMethod"]:
        """Return a MarlinLinearMethod for supported layers, else None.

        Every LinearBase layer is supported; a ParallelLMHead layer is
        supported only when ``lm_head_quantized`` is set.
        """
        if (isinstance(layer, LinearBase) or
            (isinstance(layer, ParallelLMHead) and self.lm_head_quantized)):
            return MarlinLinearMethod(self)
        return None


class MarlinLinearMethod(LinearMethodBase):
    """Linear method for Marlin.

    Args:
        quant_config: The Marlin quantization config.
    """

    def __init__(self, quant_config: MarlinConfig):
        self.quant_config = quant_config

    def _validate_partition_sizes(self, input_size_per_partition: int,
                                  output_size_per_partition: int) -> None:
        """Raise ValueError if a weight shard's shape is unusable by Marlin.

        Args:
            input_size_per_partition: K, the shard's input features.
            output_size_per_partition: N, the shard's output features.
        """
        cfg = self.quant_config

        # Validate output_size_per_partition (N).
        if output_size_per_partition % cfg.min_n_threads != 0:
            raise ValueError(
                f"Weight output_size_per_partition = "
                f"{output_size_per_partition} is not divisible by "
                f"min_n_threads = {cfg.min_n_threads}.")
        if output_size_per_partition % cfg.pack_factor != 0:
            raise ValueError(
                f"Weight output_size_per_partition = "
                f"{output_size_per_partition} is not divisible by "
                f"pack_factor = {cfg.pack_factor}.")

        # Validate input_size_per_partition (K).
        if input_size_per_partition % cfg.min_k_threads != 0:
            raise ValueError(
                f"Weight input_size_per_partition = "
                f"{input_size_per_partition} is not divisible by "
                f"min_k_threads = {cfg.min_k_threads}.")
        if (cfg.group_size != -1
                and input_size_per_partition % cfg.group_size != 0):
            raise ValueError(f"Weight input_size_per_partition = "
                             f"{input_size_per_partition} is not divisible by "
                             f"group_size = {cfg.group_size}.")

        # N must be a whole multiple of the number of tiles in one
        # permutation group (perm_len // tile_size**2) so that no permutation
        # group is split across GPUs. This is a divisibility check, not an
        # "at least N tiles" check.
        num_tiles_per_perm = cfg.perm_len // (cfg.tile_size**2)
        if output_size_per_partition % num_tiles_per_perm != 0:
            raise ValueError(
                "Each permutation group must reside on the same gpu")

    def create_weights(
        self,
        layer: torch.nn.Module,
        input_size_per_partition: int,
        output_partition_sizes: List[int],
        input_size: int,
        output_size: int,
        params_dtype: torch.dtype,
        **extra_weight_attrs,
    ):
        """Allocate the Marlin parameters and register them on ``layer``.

        With K = ``input_size_per_partition`` and
        N = ``sum(output_partition_sizes)``, registers:

          * ``B``: int32 packed weights, shape
            (K // tile_size, N * tile_size // pack_factor);
          * ``s``: ``params_dtype`` scales, shape (input_groups, N), where
            input_groups is 1 for channelwise quantization and
            K // group_size otherwise;
          * ``workspace``: zero-filled int buffer of
            (N // min_n_threads) * max_parallel elements.

        Args:
            layer: Module that receives the parameters.
            input_size_per_partition: Input features of this shard.
            output_partition_sizes: Output sizes whose sum is this shard's
                output features.
            input_size: Unused.
            output_size: Unused.
            params_dtype: Must be ``torch.float16``.
            **extra_weight_attrs: Must contain ``"weight_loader"``.

        Raises:
            ValueError: if ``params_dtype`` is not float16 or the shard
                shape is incompatible with the Marlin kernel.
        """
        del input_size, output_size  # Unused.
        weight_loader = extra_weight_attrs["weight_loader"]
        cfg = self.quant_config

        if params_dtype != torch.float16:
            raise ValueError(
                f"The params dtype must be float16, but got {params_dtype}")

        output_size_per_partition = sum(output_partition_sizes)
        self._validate_partition_sizes(input_size_per_partition,
                                       output_size_per_partition)

        # Quantized 4Bit weights packed into Int32.
        qweight = PackedvLLMParameter(
            data=torch.empty(
                input_size_per_partition // cfg.tile_size,
                output_size_per_partition * cfg.tile_size //
                cfg.pack_factor,
                device="cuda",
                dtype=torch.int32,
            ),
            input_dim=0,
            output_dim=1,
            packed_dim=1,
            packed_factor=cfg.pack_factor,
            marlin_tile_size=cfg.tile_size,
            weight_loader=weight_loader)

        # One scale row per input group; a single row means channelwise.
        input_groups = (1 if cfg.group_size == -1 else
                        input_size_per_partition // cfg.group_size)

        weight_scale_args = {
            "data":
            torch.empty(
                input_groups,
                output_size_per_partition,
                device="cuda",
                dtype=params_dtype,
            ),
            "weight_loader":
            weight_loader
        }
        if input_groups == 1:
            scales = ChannelQuantScaleParameter(output_dim=1,
                                                **weight_scale_args)
        else:
            scales = GroupQuantScaleParameter(output_dim=1,
                                              input_dim=0,
                                              **weight_scale_args)

        # Allocate workspace (Used for internal locking mechanism)
        max_workspace_size = (output_size_per_partition //
                              cfg.min_n_threads) * cfg.max_parallel

        workspace = BasevLLMParameter(data=torch.zeros(max_workspace_size,
                                                       device="cuda",
                                                       dtype=torch.int),
                                      weight_loader=weight_loader)

        layer.register_parameter("B", qweight)
        layer.register_parameter("s", scales)
        layer.register_parameter("workspace", workspace)

    def process_weights_after_loading(self, layer: torch.nn.Module) -> None:
        """Re-wrap loaded tensors as plain, frozen ``Parameter`` objects."""
        # required by torch.compile
        layer.B = Parameter(layer.B.data, requires_grad=False)
        layer.s = Parameter(layer.s.data, requires_grad=False)
        layer.workspace = Parameter(layer.workspace.data, requires_grad=False)

    def apply(
        self,
        layer: torch.nn.Module,
        x: torch.Tensor,
        bias: Optional[torch.Tensor] = None,
    ) -> torch.Tensor:
        """Run the Marlin GEMM on ``x`` and optionally add ``bias``.

        All leading dimensions of ``x`` are flattened into M. The output
        has the same leading dimensions as ``x``, and its last dimension is
        N (taken from ``layer.s.shape[1]``).
        """
        qweight = layer.B
        scales = layer.s
        workspace = layer.workspace

        # reshape rather than view: view raises when x's strides cannot be
        # flattened in place; reshape falls back to a copy only then.
        x_2d = x.reshape(-1, x.shape[-1])

        size_m = x_2d.shape[0]
        size_k = x_2d.shape[1]
        size_n = scales.shape[1]

        output_2d = ops.marlin_gemm(x_2d, qweight, scales, workspace, size_m,
                                    size_n, size_k)

        output = output_2d.view(x.shape[:-1] + (output_2d.shape[1], ))

        if bias is not None:
            output.add_(bias)  # In-place add

        return output