# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project

# Copyright © 2025, Oracle and/or its affiliates.

import os
from typing import Any, Optional

import numpy as np
import torch
from torch.nn.parameter import Parameter

from vllm.logger import init_logger
from vllm.model_executor.layers.fused_moe.config import (
    FusedMoEConfig,
    FusedMoEQuantConfig,
)
from vllm.model_executor.layers.fused_moe.fused_marlin_moe import fused_marlin_moe
from vllm.model_executor.layers.fused_moe.layer import FusedMoE, FusedMoEMethodBase
from vllm.model_executor.layers.linear import (
    LinearBase,
    LinearMethodBase,
    set_weight_attrs,
)
from vllm.model_executor.layers.quantization import QuantizationMethods
from vllm.model_executor.layers.quantization.base_config import (
    QuantizationConfig,
    QuantizeMethodBase,
)
from vllm.model_executor.layers.quantization.utils import replace_parameter
from vllm.model_executor.layers.quantization.utils.marlin_utils import (
    apply_rtn_marlin_linear,
    marlin_make_workspace_new,
)
from vllm.scalar_type import scalar_types

logger = init_logger(__name__)

# By default, use 8 bit as target precision, but it can be overridden by
# setting the RTN_NUM_BITS envvar. Kept as a string; converted with int()
# where it is used.
NUM_BITS = os.getenv("RTN_NUM_BITS", "8")

# By default, use group size of 128 parameters, but it can be overridden by
# setting the RTN_GROUP_SIZE envvar. Kept as a string; converted with int()
# where it is used.
GROUP_SIZE = os.getenv("RTN_GROUP_SIZE", "128")

# Global Marlin workspace shared by all modules; allocated lazily by
# init_workspace().
workspace = None


class RTNConfig(QuantizationConfig):
    """Config class for RTN (round-to-nearest) weight quantization.

    Args:
        weight_bits: Target precision, 4 or 8. The default is read from the
            RTN_NUM_BITS environment variable when this class is defined.
        group_size: Number of consecutive input elements sharing one scale,
            or -1 for one scale per row. The default is read from the
            RTN_GROUP_SIZE environment variable when this class is defined.

    Raises:
        ValueError: if weight_bits is not 4 or 8, or group_size is neither
            -1 nor a positive integer.
    """

    def __init__(
        self,
        weight_bits: int = int(NUM_BITS),
        group_size: int = int(GROUP_SIZE),
    ) -> None:
        self.weight_bits = weight_bits
        self.group_size = group_size

        if self.weight_bits != 4 and self.weight_bits != 8:
            raise ValueError(
                "Currently, only 4-bit or 8-bit weight quantization is "
                f"supported for RTN, but got {self.weight_bits} bits."
            )

        # group_size is used as a divisor when sizing the scale tensors; 0 or
        # a negative value other than -1 would only fail much later with an
        # obscure shape error (or a ZeroDivisionError).
        if self.group_size != -1 and self.group_size <= 0:
            raise ValueError(
                "RTN group_size must be a positive integer or -1 (one group "
                f"per row), but got {self.group_size}."
            )

        # Marlin scalar type; its bias matches the zero point q_range // 2
        # (128 for 8 bits, 8 for 4 bits) that rtn_quantize adds.
        self.quant_type = (
            scalar_types.uint8b128 if self.weight_bits == 8 else scalar_types.uint4b8
        )

    def __repr__(self) -> str:
        return (
            f"RTNConfig(weight_bits={self.weight_bits}, group_size={self.group_size})"
        )

    @classmethod
    def get_name(cls) -> QuantizationMethods:
        return "rtn"

    @classmethod
    def get_supported_act_dtypes(cls) -> list[torch.dtype]:
        return [torch.bfloat16, torch.half]

    @classmethod
    def get_min_capability(cls) -> int:
        return 80

    @classmethod
    def get_config_filenames(cls) -> list[str]:
        # RTN quantizes on the fly, so there is no config file to look for.
        return []

    @classmethod
    def from_config(cls, config: dict[str, Any]) -> "RTNConfig":
        """Build a config from a dict with "bits" and "group_size" keys."""
        weight_bits = cls.get_from_keys(config, ["bits"])
        group_size = cls.get_from_keys(config, ["group_size"])
        return cls(weight_bits, group_size)

    def get_quant_method(
        self, layer: torch.nn.Module, prefix: str
    ) -> Optional["QuantizeMethodBase"]:
        """Return the RTN method for linear and fused-MoE layers, else None."""
        if isinstance(layer, LinearBase):
            return RTNLinearMethod(self)
        elif isinstance(layer, FusedMoE):
            return RTNMoEMethod(self, layer.moe_config)
        return None


class RTNTensor:
    """A wrapper over Tensor that enables quantization on-the-fly by
    overloading the copy_ method.

    ``data`` is the packed uint8 storage (two 4-bit values per byte when
    ``weight_bits == 4``) and ``scale`` the matching per-group scales.
    ``shape`` reports the logical (unpacked) shape, so weight loaders can
    size shards as if the weight were not quantized.
    """

    def __init__(
        self, data: torch.Tensor, scale: torch.Tensor, quant_config: RTNConfig
    ) -> None:
        self.data = data
        self.scale = scale
        self.quant_config = quant_config

    def narrow(self, dim, start, length):
        """Return an RTNTensor covering ``length`` logical rows from ``start``.

        ``start``/``length`` are logical (unpacked) offsets. They are divided
        by the packing factor for ``data`` but applied unchanged to ``scale``.
        NOTE(review): that is only consistent along the packed (output) dim,
        which is presumably the only dim loaders narrow on — confirm before
        narrowing along the input dim.
        """
        factor = 1 if self.quant_config.weight_bits == 8 else 2
        return RTNTensor(
            self.data.narrow(dim, start // factor, length // factor),
            self.scale.narrow(dim, start, length),
            self.quant_config,
        )

    def __getitem__(self, key):
        """Index ``data`` and ``scale`` with the same key."""
        return RTNTensor(self.data[key], self.scale[key], self.quant_config)

    @property
    def shape(self):
        """Logical shape: the packed dim (dim 0 for 2-D, dim 1 for 3-D data)
        is multiplied back by the packing factor."""
        shape = self.data.shape
        factor = 1 if self.quant_config.weight_bits == 8 else 2
        batch_present = len(shape) == 3
        if batch_present:
            return torch.Size((shape[0], shape[1] * factor, shape[2]))
        else:
            return torch.Size((shape[0] * factor, shape[1]))

    def copy_(self, loaded_weight: torch.Tensor) -> None:
        """Quantize ``loaded_weight`` and store the result in place.

        The weight is moved to the current CUDA device before quantization;
        the packed values are copied into ``data`` and the per-group scales
        into ``scale``.
        """
        qweight, weight_scale = rtn_quantize(
            loaded_weight.cuda(),
            self.quant_config.weight_bits,
            self.quant_config.group_size,
        )

        self.data.copy_(qweight)
        self.scale.data.copy_(weight_scale)


class RTNParameter(Parameter):
    """A wrapper over Parameter that returns RTNTensor (a wrapper over Tensor)
    when its data is accessed. We need this wrapper for the data loading phase
    only, so we can intercept a weight copying function (torch.Tensor.copy_)
    and apply quantization on-the-fly.
    """

    def __new__(cls, data: torch.Tensor, **kwargs):
        # Only the storage goes to Parameter; ``scale`` and ``quant_config``
        # arrive through kwargs and are stored by __init__.
        return super().__new__(cls, data=data, requires_grad=False)

    def __init__(
        self, data: torch.Tensor, scale: torch.Tensor, quant_config: RTNConfig
    ) -> None:
        self.scale = scale
        self.quant_config = quant_config

    @property
    def data(self):
        # Wrap the real storage so that loaders calling ``param.data.copy_``
        # (or narrow/index first) go through RTNTensor's quantizing copy_.
        return RTNTensor(super().data, self.scale, self.quant_config)


class RTNLinearMethod(LinearMethodBase):
    """Linear method for RTN.

    Weights are quantized while they are loaded (through RTNParameter) and
    repacked for the Marlin kernel in process_weights_after_loading.

    Args:
        quant_config: The RTN quantization config.
    """

    def __init__(self, quant_config: RTNConfig):
        self.quant_config = quant_config

    def create_weights(
        self,
        layer: torch.nn.Module,
        input_size_per_partition: int,
        output_partition_sizes: list[int],
        input_size: int,
        output_size: int,
        params_dtype: torch.dtype,
        **extra_weight_attrs,
    ):
        """Register the packed ``weight`` and its per-group ``scale``.

        ``weight`` is a uint8 RTNParameter of shape
        ``(output_size_per_partition // factor, input_size_per_partition)``
        (factor 2 for 4 bits, 1 for 8 bits); ``scale`` has shape
        ``(output_size_per_partition, num_groups_per_col)`` in params_dtype.

        Raises:
            ValueError: if group_size is not -1 and does not divide
                input_size_per_partition.
        """
        output_size_per_partition = sum(output_partition_sizes)
        group_size = self.quant_config.group_size

        if group_size == -1:
            # One scale per output row.
            num_groups_per_col = 1
        elif input_size_per_partition % group_size != 0:
            # Groups would straddle rows and the scale tensor would be
            # mis-sized; fail here with a clear message instead of later.
            raise ValueError(
                f"RTN group_size={group_size} must divide the input size per "
                f"partition ({input_size_per_partition}); use a divisor of it "
                "or group_size=-1."
            )
        else:
            num_groups_per_col = input_size_per_partition // group_size

        scale = Parameter(
            torch.empty(
                output_size_per_partition, num_groups_per_col, dtype=params_dtype
            ),
            requires_grad=False,
        )
        factor = 1 if self.quant_config.weight_bits == 8 else 2

        weight = RTNParameter(
            data=torch.empty(
                output_size_per_partition // factor,
                input_size_per_partition,
                dtype=torch.uint8,
            ),
            scale=scale,
            quant_config=self.quant_config,
        )

        layer.register_parameter("weight", weight)
        set_weight_attrs(
            weight,
            {
                **extra_weight_attrs,
                "input_dim": 1,
                "output_dim": 0,
            },
        )

        layer.register_parameter("scale", scale)
        # apply() reads both partition sizes from the layer.
        layer.input_size_per_partition = input_size_per_partition
        layer.output_size_per_partition = output_size_per_partition

    def process_weights_after_loading(self, layer: torch.nn.Module) -> None:
        """Repack weights and scales for Marlin kernels."""
        weight_bits = self.quant_config.weight_bits

        weight, scale = repack_weights(layer.weight, layer.scale, weight_bits)

        replace_parameter(layer, "weight", weight)
        replace_parameter(layer, "scale", scale)

        init_workspace(layer.weight.device)

    def apply(
        self,
        layer: torch.nn.Module,
        x: torch.Tensor,
        bias: torch.Tensor | None = None,
    ) -> torch.Tensor:
        """Run the Marlin RTN GEMM using the module-wide workspace."""
        return apply_rtn_marlin_linear(
            input=x,
            weight=layer.weight,
            weight_scale=layer.scale,
            workspace=workspace,
            quant_type=self.quant_config.quant_type,
            output_size_per_partition=layer.output_size_per_partition,
            input_size_per_partition=layer.input_size_per_partition,
            bias=bias,
        )


class RTNMoEMethod(FusedMoEMethodBase):
    """Fused-MoE method for RTN.

    Expert weights are quantized while they are loaded (through
    RTNParameter), repacked for Marlin after loading, and executed with
    fused_marlin_moe.

    Args:
        quant_config: The RTN quantization config.
        moe: The fused-MoE layer config, forwarded to the base class.
    """

    def __init__(self, quant_config: RTNConfig, moe: FusedMoEConfig):
        super().__init__(moe)
        self.quant_config = quant_config

    def _num_groups_per_col(self, size: int) -> int:
        """Number of scale groups along an input dim of ``size`` elements.

        Raises:
            ValueError: if group_size is not -1 and does not divide ``size``.
        """
        group_size = self.quant_config.group_size
        if group_size == -1:
            return 1
        if size % group_size != 0:
            raise ValueError(
                f"RTN group_size={group_size} must divide the input dimension "
                f"of every expert weight, but got {size}; use a divisor of it "
                "or group_size=-1."
            )
        return size // group_size

    def create_weights(
        self,
        layer: torch.nn.Module,
        num_experts: int,
        hidden_size: int,
        intermediate_size_per_partition: int,
        params_dtype: torch.dtype,
        **extra_weight_attrs,
    ):
        """Register packed expert weights and their per-group scales.

        Registers ``w13_weight``/``w13_scale`` (fused gate_up_proj) and
        ``w2_weight``/``w2_scale`` (down_proj). The weights are uint8
        RTNParameters whose dim 1 is divided by the packing factor
        (2 for 4 bits, 1 for 8 bits).

        Raises:
            ValueError: if group_size is not -1 and does not divide
                hidden_size or intermediate_size_per_partition.
        """
        factor = 1 if self.quant_config.weight_bits == 8 else 2

        # Fused gate_up_proj (column parallel); its input dim is hidden_size.
        w13_scale = Parameter(
            torch.empty(
                num_experts,
                2 * intermediate_size_per_partition,
                self._num_groups_per_col(hidden_size),
                dtype=params_dtype,
            ),
            requires_grad=False,
        )
        layer.register_parameter("w13_scale", w13_scale)

        w13_weight = RTNParameter(
            data=torch.empty(
                num_experts,
                2 * intermediate_size_per_partition // factor,
                hidden_size,
                dtype=torch.uint8,
            ),
            scale=w13_scale,
            quant_config=self.quant_config,
        )
        layer.register_parameter("w13_weight", w13_weight)
        set_weight_attrs(w13_weight, extra_weight_attrs)

        # down_proj (row parallel); its input dim is the intermediate size.
        w2_scale = Parameter(
            torch.zeros(
                num_experts,
                hidden_size,
                self._num_groups_per_col(intermediate_size_per_partition),
                dtype=params_dtype,
            ),
            requires_grad=False,
        )
        layer.register_parameter("w2_scale", w2_scale)

        w2_weight = RTNParameter(
            data=torch.empty(
                num_experts,
                hidden_size // factor,
                intermediate_size_per_partition,
                dtype=torch.uint8,
            ),
            scale=w2_scale,
            quant_config=self.quant_config,
        )
        layer.register_parameter("w2_weight", w2_weight)
        set_weight_attrs(w2_weight, extra_weight_attrs)

    def process_weights_after_loading(self, layer: torch.nn.Module) -> None:
        """Repack weights and scales for Marlin kernels."""
        weight_bits = self.quant_config.weight_bits

        w13_weight, w13_scale = repack_weights(
            layer.w13_weight, layer.w13_scale, weight_bits
        )
        replace_parameter(layer, "w13_weight", w13_weight)
        replace_parameter(layer, "w13_scale", w13_scale)

        w2_weight, w2_scale = repack_weights(
            layer.w2_weight, layer.w2_scale, weight_bits
        )
        replace_parameter(layer, "w2_weight", w2_weight)
        replace_parameter(layer, "w2_scale", w2_scale)

        init_workspace(layer.w13_weight.device)

    def get_fused_moe_quant_config(
        self, layer: torch.nn.Module
    ) -> FusedMoEQuantConfig | None:
        return None

    def apply(
        self,
        layer: FusedMoE,
        x: torch.Tensor,
        router_logits: torch.Tensor,
    ) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor]:
        """Route tokens and run the fused Marlin MoE kernel."""
        topk_weights, topk_ids, _ = layer.select_experts(
            hidden_states=x,
            router_logits=router_logits,
        )

        return fused_marlin_moe(
            x,
            layer.w13_weight,
            layer.w2_weight,
            getattr(layer, "w13_bias", None),
            getattr(layer, "w2_bias", None),
            layer.w13_scale,
            layer.w2_scale,
            router_logits,
            topk_weights,
            topk_ids,
            quant_type_id=self.quant_config.quant_type.id,
            apply_router_weight_on_input=layer.apply_router_weight_on_input,
            global_num_experts=layer.global_num_experts,
            expert_map=layer.expert_map,
            workspace=workspace,
        )


def rtn_quantize(
    tensor: torch.Tensor, num_bits: int, group_size: int
) -> tuple[torch.Tensor, torch.Tensor]:
    """Quantize a tensor using per-group static scaling factor.

    Each group of ``group_size`` consecutive elements along the last dim
    gets one symmetric scale ``max|x| * 2 / (2**num_bits - 1)``. Values are
    divided by their group's scale, rounded to nearest, shifted by the zero
    point ``2**num_bits // 2`` and clamped to ``[0, 2**num_bits - 1]``.

    Args:
        tensor: The input tensor, 2-D ``(rows, cols)`` or 3-D
                ``(batch, rows, cols)``.
        num_bits: Target precision for the result (supported values are
                  8 or 4).
        group_size: Quantization granularity along the last dim.
                    If equal to -1, each row in the input tensor is treated
                    as one group.

    Returns:
        ``(q, scale)``. ``q`` is uint8 with the shape of ``tensor`` for
        8 bits; for 4 bits two values share a byte and the row dim is halved.
        ``scale`` has ``tensor``'s dtype and shape ``(..., rows, groups)``.
        A leading batch dim is kept only if ``tensor`` had one.

    Raises:
        ValueError: if ``group_size`` is neither -1 nor a positive divisor
            of the last dim.
    """
    batch_present = len(tensor.shape) == 3
    if not batch_present:
        tensor = tensor.unsqueeze(0)

    # A group must not straddle two rows, otherwise the scales no longer
    # line up with the (rows, groups) scale layout.
    if group_size != -1 and (group_size <= 0 or tensor.shape[2] % group_size != 0):
        raise ValueError(
            f"group_size={group_size} must be -1 or a positive divisor of the "
            f"last dimension ({tensor.shape[2]})."
        )

    q_range = 2**num_bits
    num_groups = (
        tensor.shape[1] * tensor.shape[2] // group_size
        if group_size != -1
        else tensor.shape[1]
    )

    # Calculate a scaling factor per input group.
    input_flat = tensor.reshape(tensor.shape[0], num_groups, -1)
    input_min = torch.min(input_flat, dim=2, keepdim=True)[0]
    input_max = torch.max(input_flat, dim=2, keepdim=True)[0]
    input_max_abs = torch.max(input_min.abs(), input_max.abs())
    scale = input_max_abs * 2.0 / (q_range - 1)

    # An all-zero group has scale 0 and 0 / 0 would give NaN, whose uint8
    # cast is undefined. Divide such groups by 1 so they land exactly on the
    # zero point; the stored scale stays 0, so they still dequantize to 0.
    safe_scale = torch.where(scale == 0, torch.ones_like(scale), scale)

    # Scale each input group, round to the nearest integer, shift the range
    # and truncate.
    scaled_input = (input_flat / safe_scale).round()
    scaled_input += q_range // 2
    scaled_input = scaled_input.clamp(0, q_range - 1)

    scale = scale.reshape(tensor.shape[0], tensor.shape[1], -1).contiguous()
    inputs_q = scaled_input.reshape(tensor.shape).to(torch.uint8).contiguous()

    if num_bits == 4:
        # Pack two 4-bit values into each byte (even column in the low
        # nibble, odd column in the high nibble), then view the result as
        # (batch, rows // 2, cols).
        inputs_q = (inputs_q[:, :, 1::2] << 4) | (inputs_q[:, :, ::2] & 0xF)
        inputs_q = inputs_q.reshape(
            tensor.shape[0], tensor.shape[1] // 2, tensor.shape[2]
        ).contiguous()

    if not batch_present:
        inputs_q = inputs_q.squeeze(0)
        scale = scale.squeeze(0)

    return inputs_q, scale


def rtn_dequantize(tensor: torch.Tensor, scale: torch.Tensor) -> torch.Tensor:
    """Dequantize a tensor using per-group static scaling factors.

    The bit width is inferred from the shapes: if ``tensor`` has as many
    rows as ``scale`` the values are 8-bit, otherwise they are 4-bit values
    packed two per byte. Each value has the zero point ``2**num_bits // 2``
    subtracted and is multiplied by its group's scale.

    Args:
        tensor: The quantized uint8 tensor, 2-D or 3-D (leading batch dim).
        scale: The tensor with per-group scale factors,
               ``(..., rows, groups)``.

    Returns:
        A tensor in ``scale``'s dtype with the logical (unpacked) shape; the
        batch dim is present only if ``tensor`` had one.
    """
    batch_present = len(tensor.shape) == 3
    if not batch_present:
        tensor = tensor.unsqueeze(0)
        scale = scale.unsqueeze(0)

    num_groups = scale.size(1) * scale.size(2)
    batch, input_dim, output_dim = tensor.shape

    # Packed 4-bit storage has half as many rows as there are scale rows.
    num_bits = 8 if input_dim == scale.size(1) else 4
    q_range = 2**num_bits
    if num_bits == 4:
        input_dim *= 2

    data = torch.empty(
        (batch, input_dim, output_dim), dtype=scale.dtype, device=tensor.device
    )

    if num_bits == 8:
        data.copy_(tensor)
        data -= q_range // 2
    else:
        # Unpack two 4-bit values from each byte: low nibble -> even column,
        # high nibble -> odd column.
        tensor = tensor.reshape(batch, input_dim, output_dim // 2)
        for i in range(2):
            data[:, :, i::2] = ((tensor << 4 * (1 - i)) >> 4).to(
                torch.int8
            ) - q_range // 2

    # Scale each input group with its scaling factor.
    scale = scale.reshape(batch, num_groups, -1)
    data = data.reshape(batch, num_groups, -1)
    data = torch.mul(data, scale)

    input_deq = data.reshape((batch, input_dim, output_dim)).contiguous()
    if not batch_present:
        input_deq = input_deq.squeeze(0)

    return input_deq


def _get_perms():
    """Build the index tables used to reorder weights and scales for Marlin.

    Returns:
        ``(perm, scale_perm, scale_perm_single)``: ``perm`` is a 1-D tensor
        holding a permutation of ``range(1024)``; ``scale_perm`` (64 entries)
        and ``scale_perm_single`` (32 entries) are plain lists, used for
        grouped and single-group scales respectively.
    """
    weight_order = []
    for lane in range(32):
        quad = lane % 4
        # Row indices visited by this lane (presumably a thread lane in the
        # Marlin kernel's tile layout — not verifiable from here).
        rows = (2 * quad, 2 * quad + 1, 2 * (quad + 4), 2 * (quad + 4) + 1)
        lane_offsets = [
            16 * row + lane // 4 + 8 * half for half in (0, 1) for row in rows
        ]
        for tile_idx in range(4):
            weight_order.extend(off + 256 * tile_idx for off in lane_offsets)

    # Inside every run of 8 entries, take even positions first, then odd.
    even_then_odd = np.array([0, 2, 4, 6, 1, 3, 5, 7])
    reordered = np.array(weight_order).reshape((-1, 8))[:, even_then_odd].ravel()

    grouped_scale_order = [c + 8 * r for c in range(8) for r in range(8)]
    single_scale_order = [
        2 * pair + step
        for pair in range(4)
        for step in (0, 1, 8, 9, 16, 17, 24, 25)
    ]
    return torch.from_numpy(reordered), grouped_scale_order, single_scale_order


_perm, _scale_perm, _scale_perm_single = _get_perms()


def pack_for_marlin(weight, scale, qbits):
    """Reorder quantized weights and scales into the Marlin kernel layout.

    Args:
        weight: Unpacked quantized values (one value per element), shape
            ``(batch, n, k)``.
        scale: Per-group scales, shape ``(batch, n, k // groupsize)``.
        qbits: 4 or 8.

    Returns:
        ``(q, s)``: ``q`` is int8 of shape ``(batch, k * qbits // 8, n)`` —
        two values per byte for 4 bits — and ``s`` has shape
        ``(batch, k // groupsize, n)`` with the Marlin scale permutation
        applied.
    """
    batch = weight.shape[0]

    n = weight.size(1)
    k = weight.size(2)
    groupsize = k // scale.size(2)

    tile = 16
    s = scale.permute(0, 2, 1)  # transpose to (batch, groups, n)
    w = weight.permute(0, 2, 1)  # transpose to (batch, k, n)

    # Grouped and single-group (groupsize == k) scales use different tables.
    # (A previous reshape/permute of ``w`` here was undone immediately
    # afterwards and has been dropped; the values are unchanged.)
    if groupsize != k:
        s = s.reshape((batch, -1, len(_scale_perm)))[:, :, _scale_perm]
    else:
        s = s.reshape((batch, -1, len(_scale_perm_single)))[:, :, _scale_perm_single]
    s = s.reshape((batch, -1, n)).contiguous()

    # Cut into 16x16 tiles and lay out each row of tiles contiguously.
    w = w.reshape((batch, k // tile, tile, n // tile, tile))
    w = w.permute((0, 1, 3, 2, 4))
    w = w.reshape((batch, k // tile, n * tile))
    res = w.reshape((batch, -1, _perm.numel()))[:, :, _perm].reshape(w.shape)

    if qbits == 4:
        # Two values per byte: even position in the low nibble, odd position
        # in the high nibble.
        q = torch.zeros(
            (batch, res.shape[1], res.shape[2] // 2), dtype=torch.int8, device=w.device
        )
        for i in range(2):
            q |= res[:, :, i::2] << 4 * i
        q = q.reshape(batch, -1, n).contiguous()
    else:
        # Swap positions 2-3 with 4-5 inside every run of 8 bytes.
        q = res.clone()
        q[:, :, 2::8] = res[:, :, 4::8]
        q[:, :, 3::8] = res[:, :, 5::8]
        q[:, :, 4::8] = res[:, :, 2::8]
        q[:, :, 5::8] = res[:, :, 3::8]
        q = q.reshape(batch, -1, n).to(torch.int8).contiguous()

    return q, s


def repack_8bit_into_32bit(input):
    """Fold every four consecutive bytes along the last dim into one int32.

    Within each output word, byte ``j`` (counting from the least
    significant) holds input element ``4 * w + j``.

    Args:
        input: 3-D tensor of 8-bit values; the last dim is expected to be a
            multiple of 4.

    Returns:
        int32 tensor with the last dim divided by 4, on ``input``'s device.
    """
    out_shape = (input.shape[0], input.shape[1], input.shape[2] // 4)
    packed = torch.zeros(out_shape, dtype=torch.int32, device=input.device)
    for byte_pos in range(4):
        lane_bytes = (input[:, :, byte_pos::4] & 0xFF).to(torch.int32)
        packed |= lane_bytes << (8 * byte_pos)
    return packed


def repack_weights(qweight, scale, weight_bits):
    """Repack RTN-quantized weights and scales for the Marlin kernels.

    Args:
        qweight: uint8 weights in the layout written by rtn_quantize, 2-D or
            3-D (leading batch/expert dim); for 4 bits two values share a
            byte and the row dim is halved.
        scale: Per-group scales matching ``qweight``.
        weight_bits: 4 or 8.

    Returns:
        ``(qweight, scale)`` in Marlin format. ``qweight`` is int32 with
        shape ``(batch, k // 16, -1)``, where ``k`` is the last dim of the
        input; the batch dim is dropped again for 2-D input, as for
        ``scale``.
    """
    batch_present = len(qweight.shape) == 3
    if not batch_present:
        qweight = qweight.unsqueeze(0)
        scale = scale.unsqueeze(0)

    if weight_bits == 4:
        # Unpack two 4-bit values from each byte (inverse of the packing in
        # rtn_quantize): low nibble -> even column, high nibble -> odd.
        qweight_unpacked = torch.empty(
            (qweight.shape[0], qweight.shape[1] * 2, qweight.shape[2]),
            dtype=torch.uint8,
            device=qweight.device,
        )
        for i in range(2):
            qweight_unpacked[:, :, i::2] = ((qweight << 4 * (1 - i)) >> 4).reshape(
                qweight.shape[0], qweight.shape[1] * 2, qweight.shape[2] // 2
            )
    else:
        qweight_unpacked = qweight

    qweight_packed, scale_packed = pack_for_marlin(qweight_unpacked, scale, weight_bits)

    # Marlin kernels expect tensors in int32 format in a certain shape.
    qweight_repacked = repack_8bit_into_32bit(qweight_packed.to(torch.uint8))
    qweight_reshaped = qweight_repacked.reshape(
        qweight.shape[0], qweight.shape[2] // 16, -1
    )
    if not batch_present:
        qweight_reshaped = qweight_reshaped.squeeze(0)
        scale_packed = scale_packed.squeeze(0)

    return qweight_reshaped, scale_packed


def init_workspace(device):
    """Allocate the module-wide Marlin workspace on ``device`` once.

    Only the first call allocates; later calls return immediately, even if
    they pass a different device.
    """
    global workspace
    if workspace is not None:
        return
    workspace = marlin_make_workspace_new(device, 4)