rtn.py 16.3 KB
Newer Older
1
# SPDX-License-Identifier: Apache-2.0
2
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
3
4
5
# Copyright © 2025, Oracle and/or its affiliates.

import os
6
7
from collections.abc import Callable
from typing import Any, Optional
8
9
10
11
12
13

import torch
import torch.nn.functional as F
from torch.nn.parameter import Parameter

from vllm.logger import init_logger
14
15
16
17
18
from vllm.model_executor.layers.fused_moe import (
    FusedMoE,
    FusedMoEConfig,
    FusedMoEMethodBase,
)
19
from vllm.model_executor.layers.fused_moe.config import (
20
21
22
23
24
25
26
27
28
    FusedMoEQuantConfig,
    int4_w4a16_moe_quant_config,
    int8_w8a16_moe_quant_config,
)
from vllm.model_executor.layers.linear import (
    LinearBase,
    LinearMethodBase,
    set_weight_attrs,
)
29
30
from vllm.model_executor.layers.quantization import QuantizationMethods
from vllm.model_executor.layers.quantization.base_config import (
31
32
33
    QuantizationConfig,
    QuantizeMethodBase,
)
34
35
36
37
38

logger = init_logger(__name__)

# Target precision for RTN weights. Defaults to 8 bits; override it with the
# RTN_NUM_BITS environment variable. The value is kept as a string here and is
# converted with int() where it is used as a default.
NUM_BITS = os.environ.get("RTN_NUM_BITS", "8")

# Number of consecutive parameters that share one scale. Defaults to 128;
# override it with the RTN_GROUP_SIZE environment variable (also a string).
GROUP_SIZE = os.environ.get("RTN_GROUP_SIZE", "128")


class RTNConfig(QuantizationConfig):
    """Config class for RTN (round-to-nearest) weight-only quantization.

    Args:
        weight_bits: Precision of the stored weights; only 4 and 8 are
            supported. The default comes from the ``RTN_NUM_BITS``
            environment variable as read when this module is imported.
        group_size: Number of consecutive weights sharing one scale, or -1
            for a single group per row. The default comes from the
            ``RTN_GROUP_SIZE`` environment variable as read at import time.

    Raises:
        ValueError: If ``weight_bits`` is not 4 or 8, or if ``group_size`` is
            neither -1 nor a positive integer.
    """

    def __init__(
        self,
        weight_bits: int = int(NUM_BITS),
        group_size: int = int(GROUP_SIZE),
    ) -> None:
        self.weight_bits = weight_bits
        self.group_size = group_size

        if self.weight_bits != 4 and self.weight_bits != 8:
            raise ValueError(
                "Currently, only 4-bit or 8-bit weight quantization is "
                f"supported for RTN, but got {self.weight_bits} bits."
            )
        # group_size is used as a divisor when weights are created: 0 would
        # raise ZeroDivisionError and other negatives would yield negative
        # group counts. Reject both here with a clear message.
        if self.group_size != -1 and self.group_size <= 0:
            raise ValueError(
                "RTN group_size must be a positive integer or -1 (one group "
                f"per row), but got {self.group_size}."
            )

    def __repr__(self) -> str:
        return (
            f"RTNConfig(weight_bits={self.weight_bits}, group_size={self.group_size})"
        )

    @classmethod
    def get_name(cls) -> QuantizationMethods:
        return "rtn"

    @classmethod
    def get_supported_act_dtypes(cls) -> list[torch.dtype]:
        return [torch.bfloat16, torch.half]

    @classmethod
    def get_min_capability(cls) -> int:
        return 80

    @classmethod
    def get_config_filenames(cls) -> list[str]:
        return []

    @classmethod
    def from_config(cls, config: dict[str, Any]) -> "RTNConfig":
        """Build a config from a dict carrying ``bits`` and ``group_size``."""
        weight_bits = cls.get_from_keys(config, ["bits"])
        group_size = cls.get_from_keys(config, ["group_size"])
        return cls(weight_bits, group_size)

    def get_quant_method(
        self, layer: torch.nn.Module, prefix: str
    ) -> Optional["QuantizeMethodBase"]:
        """Return RTNLinearMethod for LinearBase layers, RTNMoEMethod for
        FusedMoE layers, and None for any other layer."""
        if isinstance(layer, LinearBase):
            return RTNLinearMethod(self)
        elif isinstance(layer, FusedMoE):
            return RTNMoEMethod(self, layer.moe_config)
        return None


class RTNTensor:
    """Stand-in for a weight tensor used while checkpoint weights are loaded.

    It bundles the uint8 storage with its per-group scales and overrides
    ``copy_`` so that full-precision weights are quantized on the fly (via
    ``rtn_quantize``) instead of being copied verbatim.
    """

    def __init__(
        self, data: torch.Tensor, scale: torch.Tensor, quant_config: RTNConfig
    ) -> None:
        self.data = data
        self.scale = scale
        self.quant_config = quant_config

    def _packing_factor(self) -> int:
        # 8-bit weights take one byte each; anything else (4-bit) packs two
        # values per byte.
        return 2 if self.quant_config.weight_bits != 8 else 1

    def narrow(self, dim, start, length):
        """Narrow both storage and scales along ``dim``.

        Offsets into the storage are divided by the packing factor, while the
        scales are narrowed with the logical offsets.
        NOTE(review): the division is applied to whatever ``dim`` is given,
        which is only correct for the packed (output) dimension — presumably
        callers never narrow along the input dimension; confirm.
        """
        packing = self._packing_factor()
        packed_view = self.data.narrow(dim, start // packing, length // packing)
        scale_view = self.scale.narrow(dim, start, length)
        return RTNTensor(packed_view, scale_view, self.quant_config)

    def __getitem__(self, key):
        picked_data = self.data[key]
        picked_scale = self.scale[key]
        return RTNTensor(picked_data, picked_scale, self.quant_config)

    @property
    def shape(self):
        """Logical (unpacked) shape.

        The packed dimension (dim 0 for 2-D storage, dim 1 for 3-D storage)
        is multiplied back by the packing factor.
        """
        dims = self.data.shape
        packing = self._packing_factor()
        if len(dims) != 3:
            return torch.Size((dims[0] * packing, dims[1]))
        return torch.Size((dims[0], dims[1] * packing, dims[2]))

    def copy_(self, loaded_weight: torch.Tensor) -> None:
        """Quantize ``loaded_weight`` (on the CUDA device) and store the
        packed values and scales in place."""
        cfg = self.quant_config
        quantized, scales = rtn_quantize(
            loaded_weight.cuda(), cfg.weight_bits, cfg.group_size
        )
        self.data.copy_(quantized)
        self.scale.data.copy_(scales)


class RTNParameter(Parameter):
    """Parameter subclass used only while checkpoint weights are loaded.

    Reading ``.data`` returns an RTNTensor wrapping the raw storage together
    with its scales, so the weight loader's ``copy_`` call (which RTNTensor
    overrides) quantizes the incoming weight on the fly.
    """

    def __new__(cls, data: torch.Tensor, **kwargs):
        # ``scale`` and ``quant_config`` are consumed by __init__; the
        # Parameter itself is built from the storage tensor only.
        return super().__new__(cls, data=data, requires_grad=False)

    def __init__(
        self, data: torch.Tensor, scale: torch.Tensor, quant_config: RTNConfig
    ) -> None:
        self.quant_config = quant_config
        self.scale = scale

    @property
    def data(self):
        raw_storage = super().data
        return RTNTensor(raw_storage, self.scale, self.quant_config)


class RTNLinearMethod(LinearMethodBase):
    """Linear method for RTN.

    The weight is kept as uint8 storage plus per-group scales and is
    dequantized with ``rtn_dequantize`` on every forward call.

    Args:
        quant_config: The RTN quantization config.
    """

    def __init__(self, quant_config: RTNConfig):
        self.quant_config = quant_config

    def create_weights(
        self,
        layer: torch.nn.Module,
        input_size_per_partition: int,
        output_partition_sizes: list[int],
        input_size: int,
        output_size: int,
        params_dtype: torch.dtype,
        **extra_weight_attrs,
    ):
        """Register ``weight`` (an RTNParameter) and ``scale`` on ``layer``.

        ``scale`` has one entry per (output row, input group) in
        ``params_dtype``. ``weight`` is uint8 with the output dimension
        divided by the packing factor (2 for 4-bit, 1 for 8-bit).
        ``input_size`` and ``output_size`` are not used here.
        """
        cfg = self.quant_config
        rows = sum(output_partition_sizes)
        if cfg.group_size == -1:
            groups_per_row = 1
        else:
            groups_per_row = input_size_per_partition // cfg.group_size
        packing = 2 if cfg.weight_bits != 8 else 1

        scale = Parameter(
            torch.empty(rows, groups_per_row, dtype=params_dtype),
            requires_grad=False,
        )
        weight = RTNParameter(
            data=torch.empty(
                rows // packing, input_size_per_partition, dtype=torch.uint8
            ),
            scale=scale,
            quant_config=cfg,
        )

        layer.register_parameter("weight", weight)
        weight_attrs = {**extra_weight_attrs, "input_dim": 1, "output_dim": 0}
        set_weight_attrs(weight, weight_attrs)
        layer.register_parameter("scale", scale)
        layer.output_size_per_partition = rows

    def process_weights_after_loading(self, layer: torch.nn.Module) -> None:
        # Swap the loading-only RTNParameter for a plain Parameter.
        fix_weights(layer, "weight")

    def apply(
        self,
        layer: torch.nn.Module,
        x: torch.Tensor,
        bias: torch.Tensor | None = None,
    ) -> torch.Tensor:
        """Dequantize the weight, run ``F.linear`` and add ``bias`` in place."""
        dense_weight = rtn_dequantize(layer.weight, layer.scale)
        result = F.linear(x, dense_weight)
        # Drop the reference to the temporary dequantized weight early.
        del dense_weight
        if bias is not None:
            result.add_(bias)
        return result


244
class RTNMoEMethod(FusedMoEMethodBase):
    """Fused-MoE method for RTN.

    Expert weights are quantized on the fly while loading (through
    RTNParameter) and consumed by ``fused_experts`` with an int4/int8
    weight-only quant config.
    """

    def __init__(self, quant_config: RTNConfig, moe: FusedMoEConfig):
        super().__init__(moe)
        self.quant_config = quant_config

    def create_weights(
        self,
        layer: torch.nn.Module,
        num_experts: int,
        hidden_size: int,
        intermediate_size_per_partition: int,
        params_dtype: torch.dtype,
        **extra_weight_attrs,
    ):
        """Register w13/w2 scales and packed weights on ``layer``.

        Registration order is w13_scale, w13_weight, w2_scale, w2_weight.
        """
        cfg = self.quant_config
        packing = 2 if cfg.weight_bits != 8 else 1

        def groups_along(length: int) -> int:
            # -1 means one group per row.
            return 1 if cfg.group_size == -1 else length // cfg.group_size

        # Fused gate_up_proj (column parallel): groups run along hidden_size.
        w13_scale = Parameter(
            torch.empty(
                num_experts,
                2 * intermediate_size_per_partition,
                groups_along(hidden_size),
                dtype=params_dtype,
            ),
            requires_grad=False,
        )
        layer.register_parameter("w13_scale", w13_scale)

        w13_weight = RTNParameter(
            data=torch.empty(
                num_experts,
                2 * intermediate_size_per_partition // packing,
                hidden_size,
                dtype=torch.uint8,
            ),
            scale=w13_scale,
            quant_config=cfg,
        )
        layer.register_parameter("w13_weight", w13_weight)
        set_weight_attrs(w13_weight, extra_weight_attrs)

        # down_proj (row parallel): groups run along the intermediate size.
        w2_scale = Parameter(
            torch.zeros(
                num_experts,
                hidden_size,
                groups_along(intermediate_size_per_partition),
                dtype=params_dtype,
            ),
            requires_grad=False,
        )
        layer.register_parameter("w2_scale", w2_scale)

        w2_weight = RTNParameter(
            data=torch.empty(
                num_experts,
                hidden_size // packing,
                intermediate_size_per_partition,
                dtype=torch.uint8,
            ),
            scale=w2_scale,
            quant_config=cfg,
        )
        layer.register_parameter("w2_weight", w2_weight)
        set_weight_attrs(w2_weight, extra_weight_attrs)

    def process_weights_after_loading(self, layer: torch.nn.Module) -> None:
        # Replace the loading-only RTNParameters; 4-bit storage is reshaped.
        packed = self.quant_config.weight_bits == 4
        for param_name in ("w13_weight", "w2_weight"):
            fix_weights(layer, param_name, packed)

    def get_fused_moe_quant_config(
        self, layer: torch.nn.Module
    ) -> FusedMoEQuantConfig | None:
        """Build the int4/int8 weight-only MoE quant config.

        Passes the layer's scales, no zero points, and
        ``block_shape=[0, group_size]``.
        """
        bits = self.quant_config.weight_bits
        group = self.quant_config.group_size
        assert bits == 4 or bits == 8

        if bits == 4:
            make_config = int4_w4a16_moe_quant_config
        else:
            make_config = int8_w8a16_moe_quant_config

        return make_config(
            w1_scale=layer.w13_scale,
            w2_scale=layer.w2_scale,
            w1_zp=None,
            w2_zp=None,
            block_shape=[0, group],
        )

    def apply(
        self,
        layer: torch.nn.Module,
        x: torch.Tensor,
        router_logits: torch.Tensor,
        top_k: int,
        renormalize: bool,
        use_grouped_topk: bool = False,
        topk_group: int | None = None,
        num_expert_group: int | None = None,
        global_num_experts: int = -1,
        expert_map: torch.Tensor | None = None,
        custom_routing_function: Callable | None = None,
        scoring_func: str = "softmax",
        routed_scaling_factor: float = 1.0,
        e_score_correction_bias: torch.Tensor | None = None,
        apply_router_weight_on_input: bool = False,
        activation: str = "silu",
        enable_eplb: bool = False,
        expert_load_view: torch.Tensor | None = None,
        logical_to_physical_map: torch.Tensor | None = None,
        logical_replica_count: torch.Tensor | None = None,
    ) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor]:
        """Route tokens with ``FusedMoE.select_experts`` and run
        ``fused_experts`` in place on the packed expert weights.

        Uses ``self.moe_quant_config`` — presumably populated by the base
        class from get_fused_moe_quant_config; confirm.

        Raises:
            NotImplementedError: If ``enable_eplb`` is set.
        """
        assert self.fused_experts is None

        if enable_eplb:
            raise NotImplementedError("EPLB not supported for `RTNMoEMethod` yet.")

        from vllm.model_executor.layers.fused_moe import fused_experts

        weights_per_token, experts_per_token, _ = FusedMoE.select_experts(
            hidden_states=x,
            router_logits=router_logits,
            use_grouped_topk=use_grouped_topk,
            top_k=top_k,
            renormalize=renormalize,
            topk_group=topk_group,
            num_expert_group=num_expert_group,
            custom_routing_function=custom_routing_function,
            scoring_func=scoring_func,
            routed_scaling_factor=routed_scaling_factor,
            e_score_correction_bias=e_score_correction_bias,
            indices_type=self.topk_indices_dtype,
        )

        return fused_experts(
            x,
            layer.w13_weight,
            layer.w2_weight,
            topk_weights=weights_per_token,
            topk_ids=experts_per_token,
            inplace=True,
            activation=activation,
            apply_router_weight_on_input=apply_router_weight_on_input,
            global_num_experts=global_num_experts,
            expert_map=expert_map,
            quant_config=self.moe_quant_config,
        )
399
400


401
402
403
def rtn_quantize(
    tensor: torch.Tensor, num_bits: int, group_size: int
) -> tuple[torch.Tensor, torch.Tensor]:
    """Quantize a tensor using per-group static scaling factor.

    Each group is scaled symmetrically so that its largest magnitude maps to
    the edge of the unsigned ``num_bits`` range. Values are then rounded to
    the nearest integer, shifted by ``2**num_bits // 2``, and clamped to
    ``[0, 2**num_bits - 1]``.

    Args:
        tensor: The input tensor, 2-D ``(rows, cols)`` or 3-D
            ``(batch, rows, cols)``.
        num_bits: Target precision for the result (supported values are
            8 or 4).
        group_size: Quantization granularity: the number of consecutive
            elements (in row-major order) sharing one scale. If equal to -1,
            each row in the input tensor is treated as one group.

    Returns:
        ``(qweight, scale)``.

        ``qweight`` is ``uint8``. For 4-bit, adjacent column pairs are packed
        into one byte: the odd column goes in the high nibble and the even
        column in the low nibble. The packed result is then reshaped to
        ``(batch, rows // 2, cols)``, so 4-bit input needs an even number of
        rows and columns.

        ``scale`` has shape ``(batch, rows, groups_per_row)`` and keeps the
        input's floating-point dtype.

        For 2-D input the batch dimension is dropped from both outputs. Groups
        whose elements are all zero get a scale of 0 and quantize to the zero
        point ``2**num_bits // 2``.
    """
    batch_present = tensor.dim() == 3
    if not batch_present:
        tensor = tensor.unsqueeze(0)
    batch, rows, cols = tensor.shape

    q_range = 2**num_bits
    num_groups = rows * cols // group_size if group_size != -1 else rows

    # Per-group symmetric scale: the largest magnitude maps to the range edge.
    grouped = tensor.reshape(batch, num_groups, -1)
    max_abs = grouped.abs().amax(dim=2, keepdim=True)
    scale = max_abs * 2.0 / (q_range - 1)

    # An all-zero group has scale 0, and 0 / 0 would be NaN; casting NaN to
    # uint8 is undefined. Divide such groups by 1 instead so they land
    # exactly on the zero point. The stored scale stays 0, so these groups
    # still dequantize to 0.
    divisor = torch.where(scale == 0, torch.ones_like(scale), scale)
    scaled = (grouped / divisor).round() + q_range // 2
    scaled = scaled.clamp(0, q_range - 1)

    scale = scale.reshape(batch, rows, -1).contiguous()
    inputs_q = scaled.reshape(batch, rows, cols).to(torch.uint8).contiguous()

    if num_bits == 4:
        # Pack two 4-bit values into each byte.
        inputs_q = (inputs_q[:, :, 1::2] << 4) | (inputs_q[:, :, ::2] & 0xF)
        inputs_q = inputs_q.reshape(batch, rows // 2, cols).contiguous()

    if not batch_present:
        inputs_q = inputs_q.squeeze(0)
        scale = scale.squeeze(0)

    return inputs_q, scale


def rtn_dequantize(tensor: torch.Tensor, scale: torch.Tensor) -> torch.Tensor:
    """Dequantize a tensor using per-group static scaling factors.

    The precision is inferred from the shapes. If the storage has as many
    rows as ``scale`` it is treated as 8-bit; otherwise it is treated as
    4-bit with two values packed per byte. The zero point
    (``2**num_bits // 2``) is subtracted before each group is multiplied by
    its scale.

    Args:
        tensor: The uint8 storage, 2-D or 3-D (leading batch dimension).
        scale: The tensor with per-group scale factors. The result uses its
            dtype.

    Returns:
        The dequantized tensor. For 4-bit storage the row dimension is
        doubled. The batch dimension is dropped again for 2-D input.
    """
    squeeze_back = tensor.dim() != 3
    if squeeze_back:
        tensor = tensor.unsqueeze(0)
        scale = scale.unsqueeze(0)

    batch, rows, cols = tensor.shape
    total_groups = scale.size(1) * scale.size(2)

    packed = rows != scale.size(1)
    zero_point = (2**4 if packed else 2**8) // 2
    if packed:
        rows *= 2

    out = torch.empty((batch, rows, cols), dtype=scale.dtype, device=tensor.device)
    if packed:
        # Undo the packing: low nibble -> even column, high nibble -> odd.
        nibbles = tensor.reshape(batch, rows, cols // 2)
        out[:, :, 0::2] = ((nibbles << 4) >> 4).to(torch.int8) - zero_point
        out[:, :, 1::2] = (nibbles >> 4).to(torch.int8) - zero_point
    else:
        out.copy_(tensor)
        out -= zero_point

    # Apply one scale per group.
    grouped = out.reshape(batch, total_groups, -1) * scale.reshape(
        batch, total_groups, -1
    )
    result = grouped.reshape((batch, rows, cols)).contiguous()
    return result.squeeze(0) if squeeze_back else result
505
506


507
def fix_weights(layer: torch.nn.Module, param_name: str, reshape: bool = False):
    """Replace the RTNParameter ``layer.<param_name>`` with a plain Parameter.

    torch.compile does not know how to deal with a Parameter subclass such as
    RTNParameter. That subclass is only needed while weights are loaded, so
    afterwards it is swapped for an ordinary Parameter with
    ``requires_grad=False`` that wraps the loaded quantized data.

    Args:
        layer: Module that owns the parameter.
        param_name: Attribute name of the RTNParameter on ``layer``.
        reshape: If True, reshape the data to
            ``(shape[0], 2 * shape[1], -1)``, where ``shape`` is the stored
            (packed) shape. The MoE method in this file passes True for 4-bit
            weights.
    """
    rtn_param = getattr(layer, param_name)
    assert isinstance(rtn_param, RTNParameter)
    # ``.data`` yields an RTNTensor; its own ``.data`` is the raw storage.
    storage = rtn_param.data.data
    if reshape:
        leading, packed_rows = rtn_param.shape[0], rtn_param.shape[1]
        storage = storage.reshape(leading, packed_rows * 2, -1)
    delattr(layer, param_name)
    layer.register_parameter(
        param_name, Parameter(data=storage, requires_grad=False)
    )