rtn.py 9.67 KB
Newer Older
1
# SPDX-License-Identifier: Apache-2.0
2
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
# Copyright © 2025, Oracle and/or its affiliates.

import os
from typing import Any, Optional

import torch
import torch.nn.functional as F
from torch.nn.parameter import Parameter

from vllm.logger import init_logger
from vllm.model_executor.layers.linear import (LinearBase, LinearMethodBase,
                                               set_weight_attrs)
from vllm.model_executor.layers.quantization import QuantizationMethods
from vllm.model_executor.layers.quantization.base_config import (
    QuantizationConfig)

logger = init_logger(__name__)
"""By default, use 8 bit as target precision, but it can be 
overridden by setting the RTN_NUM_BITS envvar
"""
NUM_BITS = os.getenv('RTN_NUM_BITS', "8")
"""By default, use group size of 128 parameters, but it can be 
overridden by setting the RTN_GROUP_SIZE envvar
"""
GROUP_SIZE = os.getenv('RTN_GROUP_SIZE', "128")


class RTNConfig(QuantizationConfig):
    """Config class for RTN.
    """

    def __init__(
            self,
            weight_bits: int = int(NUM_BITS),
            group_size: int = int(GROUP_SIZE),
    ) -> None:
        self.weight_bits = weight_bits
        self.group_size = group_size

        if self.weight_bits != 4 and self.weight_bits != 8:
            raise ValueError(
                "Currently, only 4-bit or 8-bit weight quantization is "
                f"supported for RTN, but got {self.weight_bits} bits.")

    def __repr__(self) -> str:
        return (f"RTNConfig(weight_bits={self.weight_bits}, "
                f"group_size={self.group_size})")

    @classmethod
    def get_name(cls) -> QuantizationMethods:
        return "rtn"

    @classmethod
    def get_supported_act_dtypes(cls) -> list[torch.dtype]:
        return [torch.bfloat16, torch.half]

    @classmethod
    def get_min_capability(cls) -> int:
        return 80

    @classmethod
    def get_config_filenames(cls) -> list[str]:
        return []

    @classmethod
    def from_config(cls, config: dict[str, Any]) -> "RTNConfig":
        weight_bits = cls.get_from_keys(config, ["bits"])
        group_size = cls.get_from_keys(config, ["group_size"])
        return cls(weight_bits, group_size)

    def get_quant_method(self, layer: torch.nn.Module,
                         prefix: str) -> Optional["RTNLinearMethod"]:
        if isinstance(layer, LinearBase):
            return RTNLinearMethod(self)
        return None


class RTNTensor:
    """A wrapper over Tensor that enables quantization on-the-fly by
    overloading the copy_ method.
    """

    def __init__(self, data: torch.Tensor, scale: torch.Tensor,
                 quant_config: RTNConfig) -> None:
        self.data = data
        self.scale = scale
        self.quant_config = quant_config

    def narrow(self, dim, start, length):
        factor = 1 if self.quant_config.weight_bits == 8 else 2
        return RTNTensor(
            self.data.narrow(dim, start // factor, length // factor),
            self.scale.narrow(dim, start, length), self.quant_config)

    @property
    def shape(self):
        shape = self.data.shape
        factor = 1 if self.quant_config.weight_bits == 8 else 2
        return torch.Size((shape[0] * factor, shape[1]))

    def copy_(self, loaded_weight: torch.Tensor) -> None:
        qweight, weight_scale = rtn_quantize(loaded_weight.cuda(),
                                             self.quant_config.weight_bits,
                                             self.quant_config.group_size)

        self.data.copy_(qweight)
        self.scale.data.copy_(weight_scale)


class RTNParameter(Parameter):
    """A wrapper over Parameter that returns RTNTensor (a wrapper over Tensor)
    when its data is accessed. We need this wrapper for the data loading phase
    only, so we can intercept a weight copying function (torch.Tensor.copy_)
    and apply quantization on-the-fly.
    """

    def __new__(cls, data: torch.Tensor, **kwargs):
        return super().__new__(cls, data=data, requires_grad=False)

    def __init__(self, data: torch.Tensor, scale: torch.Tensor,
                 quant_config: RTNConfig) -> None:
        self.scale = scale
        self.quant_config = quant_config

    @property
    def data(self):
        return RTNTensor(super().data, self.scale, self.quant_config)


class RTNLinearMethod(LinearMethodBase):
    """Linear method for RTN.

    Args:
        quant_config: The RTN quantization config.
    """

    def __init__(self, quant_config: RTNConfig):
        self.quant_config = quant_config

    def create_weights(
        self,
        layer: torch.nn.Module,
        input_size_per_partition: int,
        output_partition_sizes: list[int],
        input_size: int,
        output_size: int,
        params_dtype: torch.dtype,
        **extra_weight_attrs,
    ):
        output_size_per_partition = sum(output_partition_sizes)
        num_groups_per_col = (input_size_per_partition //
                              self.quant_config.group_size
                              if self.quant_config.group_size != -1 else 1)

        scale = Parameter(
            torch.empty(output_size_per_partition,
                        num_groups_per_col,
                        dtype=params_dtype),
            requires_grad=False,
        )
        factor = 1 if self.quant_config.weight_bits == 8 else 2

        weight = RTNParameter(data=torch.empty(output_size_per_partition //
                                               factor,
                                               input_size_per_partition,
                                               dtype=torch.int8),
                              scale=scale,
                              quant_config=self.quant_config)

        layer.register_parameter("weight", weight)
        set_weight_attrs(weight, {
            **extra_weight_attrs,
            "input_dim": 1,
            "output_dim": 0,
        })

        layer.register_parameter("scale", scale)
        layer.output_size_per_partition = output_size_per_partition

    def process_weights_after_loading(self, layer: torch.nn.Module) -> None:
        """torch.compile does not know how to deal with a Parameter subclass
        (aka RTNParameter). As we don't really need RTNParameters for the
        forward pass, we replace them with equivalent instances of Parameters.
        """
        old_weight = layer.weight
        assert isinstance(old_weight, RTNParameter)
        data = old_weight.data.data

        delattr(layer, "weight")

        new_weight = Parameter(data=data, requires_grad=False)
        layer.register_parameter("weight", new_weight)

    def apply(self,
              layer: torch.nn.Module,
              x: torch.Tensor,
              bias: Optional[torch.Tensor] = None) -> torch.Tensor:
        qweight = layer.weight
        scale = layer.scale

        weight = rtn_dequantize(qweight, scale)
        out = F.linear(x, weight)
        del weight
        if bias is not None:
            out.add_(bias)

        return out


def rtn_quantize(tensor: torch.Tensor, num_bits: int,
                 group_size: int) -> tuple[torch.Tensor, torch.Tensor]:
    """Quantize a tensor using per-group static scaling factor.

    Args:
        tensor: The input tensor.
        num_bits: Target precision for the result (supported values are
                  8 or 4).
        group_size: Quantization granularity. 
                    If equal to -1, each row in the input tensor is treated
                    as one group.
    """

    q_range = 2**num_bits
    num_groups = (tensor.shape[0] * tensor.shape[1] //
                  group_size if group_size != -1 else tensor.shape[0])
    """Calculate a scaling factor per input group.
    """
    input_flat = tensor.reshape(num_groups, -1)
    input_min = torch.min(input_flat, dim=1, keepdim=True)[0]
    input_max = torch.max(input_flat, dim=1, keepdim=True)[0]
    input_max_abs = torch.max(input_min.abs(), input_max.abs())
    scale = (input_max_abs * 2.0 / (q_range - 1))
    """Scale each input group, truncate and round to the nearest integer.
    """
    scaled_input = input_flat / scale
    scaled_input = scaled_input.clamp(-q_range // 2, q_range // 2 - 1)
    scaled_input = scaled_input.round()

    scale = scale.reshape(tensor.shape[0], -1).contiguous()
    inputs_q = scaled_input.reshape(tensor.shape).to(torch.int8)
    inputs_q = inputs_q.contiguous()

    if num_bits == 4:
        """Pack two 4-bit values into each byte.
        """
        inputs_q = (inputs_q[:, 1::2] << 4) | (inputs_q[:, ::2] & 0xf)
        inputs_q = inputs_q.reshape(tensor.shape[0] // 2, tensor.shape[1])
        inputs_q = inputs_q.contiguous()

    return inputs_q, scale


def rtn_dequantize(tensor: torch.Tensor, scale: torch.Tensor) -> torch.Tensor:
    """Dequantize a tensor using per-group static scaling factors.

    Args:
        tensor: The input tensor.
        scale: The tensor with per-group scale factors.
    """

    num_groups = scale.size(0) * scale.size(1)
    input_dim, output_dim = tensor.shape

    num_bits = 8 if input_dim == scale.size(0) else 4
    if num_bits == 4:
        input_dim *= 2

    data = torch.empty((input_dim, output_dim),
                       dtype=scale.dtype,
                       device=tensor.device)

    if num_bits == 8:
        data.copy_(tensor)
    else:
        """Unpack two 4-bit values from each byte.
        """
        tensor = tensor.reshape(input_dim, output_dim // 2)
        for i in range(2):
            data[:, i::2] = (tensor << 4 * (1 - i)) >> 4
    """Scale each input group with its scaling factor.
    """
    scale = scale.reshape(num_groups, -1)
    data = data.reshape(num_groups, -1)
    data = torch.mul(data, scale)

    input_deq = data.reshape((input_dim, output_dim)).contiguous()
    return input_deq