rtn.py 19.5 KB
Newer Older
1
# SPDX-License-Identifier: Apache-2.0
2
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
3
4
5
# Copyright © 2025, Oracle and/or its affiliates.

import os
6
from typing import Any, Optional
7

8
import numpy as np
9
10
11
12
import torch
from torch.nn.parameter import Parameter

from vllm.logger import init_logger
13
from vllm.model_executor.layers.fused_moe import FusedMoERouter
14
from vllm.model_executor.layers.fused_moe.config import (
15
    FusedMoEConfig,
16
17
    FusedMoEQuantConfig,
)
18
from vllm.model_executor.layers.fused_moe.fused_marlin_moe import fused_marlin_moe
19
20
21
22
from vllm.model_executor.layers.fused_moe.layer import (
    FusedMoE,
    FusedMoEMethodBase,
)
23
24
25
26
27
from vllm.model_executor.layers.linear import (
    LinearBase,
    LinearMethodBase,
    set_weight_attrs,
)
28
29
from vllm.model_executor.layers.quantization import QuantizationMethods
from vllm.model_executor.layers.quantization.base_config import (
30
31
32
    QuantizationConfig,
    QuantizeMethodBase,
)
33
34
35
36
37
38
from vllm.model_executor.layers.quantization.utils import replace_parameter
from vllm.model_executor.layers.quantization.utils.marlin_utils import (
    apply_rtn_marlin_linear,
    marlin_make_workspace_new,
)
from vllm.scalar_type import scalar_types
39
40
41
42
43

logger = init_logger(__name__)

# Target precision in bits. Defaults to 8; override it by setting the
# RTN_NUM_BITS environment variable. The value is kept as a string here.
NUM_BITS = os.getenv("RTN_NUM_BITS", "8")

# Quantization group size (number of parameters per group). Defaults to 128;
# override it by setting the RTN_GROUP_SIZE environment variable.
GROUP_SIZE = os.getenv("RTN_GROUP_SIZE", "128")

# Global Marlin workspace shared by all modules. Starts out unset.
workspace = None
52
53
54


class RTNConfig(QuantizationConfig):
    """Config class for RTN.

    Args:
        weight_bits: Target weight precision in bits; only 4 and 8 are
            accepted. Defaults to the module-level ``NUM_BITS``.
        group_size: Quantization granularity. Defaults to the module-level
            ``GROUP_SIZE``; -1 is treated by the weight-creation code as
            "one group per row".

    Raises:
        ValueError: If ``weight_bits`` is neither 4 nor 8.
    """

    def __init__(
        self,
        weight_bits: int = int(NUM_BITS),
        group_size: int = int(GROUP_SIZE),
    ) -> None:
        # Initialise whatever state the base config class maintains before
        # adding the RTN-specific fields.
        super().__init__()
        self.weight_bits = weight_bits
        self.group_size = group_size

        if self.weight_bits not in (4, 8):
            raise ValueError(
                "Currently, only 4-bit or 8-bit weight quantization is "
                f"supported for RTN, but got {self.weight_bits} bits."
            )

        # Scalar type handed to the Marlin kernels for the chosen precision.
        self.quant_type = (
            scalar_types.uint8b128 if self.weight_bits == 8 else scalar_types.uint4b8
        )

    def __repr__(self) -> str:
        return (
            f"RTNConfig(weight_bits={self.weight_bits}, group_size={self.group_size})"
        )

    @classmethod
    def get_name(cls) -> QuantizationMethods:
        """Return the identifier of this quantization method."""
        return "rtn"

    @classmethod
    def get_supported_act_dtypes(cls) -> list[torch.dtype]:
        """Return the activation dtypes this method accepts."""
        return [torch.bfloat16, torch.half]

    @classmethod
    def get_min_capability(cls) -> int:
        """Return the minimum device capability required.

        NOTE(review): 80 presumably encodes CUDA compute capability 8.0 --
        confirm against the base class contract.
        """
        return 80

    @classmethod
    def get_config_filenames(cls) -> list[str]:
        """Return the config file names to look for (none for RTN)."""
        return []

    @classmethod
    def from_config(cls, config: dict[str, Any]) -> "RTNConfig":
        """Build an ``RTNConfig`` from the ``bits`` and ``group_size`` keys."""
        weight_bits = cls.get_from_keys(config, ["bits"])
        group_size = cls.get_from_keys(config, ["group_size"])
        return cls(weight_bits, group_size)

    def get_quant_method(
        self, layer: torch.nn.Module, prefix: str
    ) -> Optional["QuantizeMethodBase"]:
        """Return the RTN method matching ``layer``, or None if unsupported."""
        if isinstance(layer, LinearBase):
            return RTNLinearMethod(self)
        elif isinstance(layer, FusedMoE):
            return RTNMoEMethod(self, layer.moe_config)
        return None


class RTNTensor:
    """A wrapper over Tensor that enables quantization on-the-fly by
    overloading the copy_ method.

    ``data`` holds the quantized bytes and ``scale`` the per-group scaling
    factors. With 4-bit weights two values share one byte of ``data``, so the
    packed dimension of ``data`` is half the logical size.
    """

    def __init__(
        self, data: torch.Tensor, scale: torch.Tensor, quant_config: RTNConfig
    ) -> None:
        self.data = data
        self.scale = scale
        self.quant_config = quant_config

    def narrow(self, dim, start, length):
        """Return a wrapper over the narrowed data and scale tensors.

        ``start`` and ``length`` are given in logical (unpacked) units: they
        are applied unchanged to ``scale`` and divided by the packing factor
        for ``data``.
        """
        factor = 1 if self.quant_config.weight_bits == 8 else 2
        return RTNTensor(
            self.data.narrow(dim, start // factor, length // factor),
            self.scale.narrow(dim, start, length),
            self.quant_config,
        )

    def __getitem__(self, key):
        """Index data and scale with the same key and re-wrap the result."""
        return RTNTensor(self.data[key], self.scale[key], self.quant_config)

    @property
    def shape(self):
        """Logical shape, i.e. the data shape with the packed dim expanded."""
        shape = self.data.shape
        factor = 1 if self.quant_config.weight_bits == 8 else 2
        batch_present = len(shape) == 3
        if batch_present:
            return torch.Size((shape[0], shape[1] * factor, shape[2]))
        else:
            return torch.Size((shape[0] * factor, shape[1]))

    def copy_(self, loaded_weight: torch.Tensor) -> None:
        """Quantize ``loaded_weight`` and store the result in data and scale.

        The incoming weight is moved to the CUDA device before quantization.
        """
        qweight, weight_scale = rtn_quantize(
            loaded_weight.cuda(),
            self.quant_config.weight_bits,
            self.quant_config.group_size,
        )

        self.data.copy_(qweight)
        self.scale.data.copy_(weight_scale)


class RTNParameter(Parameter):
    """A wrapper over Parameter that returns RTNTensor (a wrapper over Tensor)
    when its data is accessed. We need this wrapper for the data loading phase
    only, so we can intercept a weight copying function (torch.Tensor.copy_)
    and apply quantization on-the-fly.
    """

    def __new__(cls, data: torch.Tensor, **kwargs):
        # Extra keyword arguments (scale, quant_config) are accepted here only
        # so construction succeeds; they are consumed by __init__.
        return super().__new__(cls, data=data, requires_grad=False)

    def __init__(
        self, data: torch.Tensor, scale: torch.Tensor, quant_config: RTNConfig
    ) -> None:
        self.scale = scale
        self.quant_config = quant_config

    @property
    def data(self):
        """Return the underlying storage wrapped in an RTNTensor."""
        return RTNTensor(super().data, self.scale, self.quant_config)


class RTNLinearMethod(LinearMethodBase):
    """Linear method for RTN.

    Args:
        quant_config: The RTN quantization config.
    """

    def __init__(self, quant_config: RTNConfig):
        self.quant_config = quant_config

    def create_weights(
        self,
        layer: torch.nn.Module,
        input_size_per_partition: int,
        output_partition_sizes: list[int],
        input_size: int,
        output_size: int,
        params_dtype: torch.dtype,
        **extra_weight_attrs,
    ):
        """Register the quantized ``weight`` and its ``scale`` on ``layer``.

        ``weight`` is an RTNParameter of uint8 storage (output rows are halved
        for 4-bit weights); ``scale`` holds one factor per output row and
        group, in ``params_dtype``.
        """
        output_size_per_partition = sum(output_partition_sizes)
        group_size = self.quant_config.group_size
        # A group size of -1 means a single group per row.
        num_groups_per_col = (
            input_size_per_partition // group_size if group_size != -1 else 1
        )

        scale = Parameter(
            torch.empty(
                output_size_per_partition, num_groups_per_col, dtype=params_dtype
            ),
            requires_grad=False,
        )
        # Two 4-bit values are stored per byte.
        factor = 1 if self.quant_config.weight_bits == 8 else 2

        weight = RTNParameter(
            data=torch.empty(
                output_size_per_partition // factor,
                input_size_per_partition,
                dtype=torch.uint8,
            ),
            scale=scale,
            quant_config=self.quant_config,
        )

        layer.register_parameter("weight", weight)
        set_weight_attrs(
            weight,
            {
                **extra_weight_attrs,
                "input_dim": 1,
                "output_dim": 0,
            },
        )

        layer.register_parameter("scale", scale)
        # apply() reads both partition sizes from the layer, so record both
        # here instead of relying on the layer class to provide one of them.
        layer.input_size_per_partition = input_size_per_partition
        layer.output_size_per_partition = output_size_per_partition

    def process_weights_after_loading(self, layer: torch.nn.Module) -> None:
        """Repack weights and scales for Marlin kernels."""
        weight_bits = self.quant_config.weight_bits

        weight, scale = repack_weights(layer.weight, layer.scale, weight_bits)

        replace_parameter(layer, "weight", weight)
        replace_parameter(layer, "scale", scale)

        # Make sure the shared Marlin workspace exists on this device.
        init_workspace(layer.weight.device)

    def apply(
        self,
        layer: torch.nn.Module,
        x: torch.Tensor,
        bias: torch.Tensor | None = None,
    ) -> torch.Tensor:
        """Run the Marlin linear kernel on ``x`` with the repacked weights."""
        return apply_rtn_marlin_linear(
            input=x,
            weight=layer.weight,
            weight_scale=layer.scale,
            workspace=workspace,
            quant_type=self.quant_config.quant_type,
            output_size_per_partition=layer.output_size_per_partition,
            input_size_per_partition=layer.input_size_per_partition,
            bias=bias,
        )
262
263


264
class RTNMoEMethod(FusedMoEMethodBase):
    """Fused MoE method for RTN.

    Args:
        quant_config: The RTN quantization config.
        moe: The fused MoE config, forwarded to the base class.
    """

    def __init__(self, quant_config: RTNConfig, moe: FusedMoEConfig):
        super().__init__(moe)
        self.quant_config = quant_config

    def create_weights(
        self,
        layer: torch.nn.Module,
        num_experts: int,
        hidden_size: int,
        intermediate_size_per_partition: int,
        params_dtype: torch.dtype,
        **extra_weight_attrs,
    ):
        """Register quantized expert weights and their scales on ``layer``.

        Creates ``w13_weight``/``w13_scale`` for the fused gate_up projection
        and ``w2_weight``/``w2_scale`` for the down projection. Weights are
        RTNParameters with uint8 storage; scales use ``params_dtype``.
        """
        # Two 4-bit values are stored per byte.
        factor = 1 if self.quant_config.weight_bits == 8 else 2
        group_size = self.quant_config.group_size

        # Fused gate_up_proj (column parallel)
        num_groups_per_col = hidden_size // group_size if group_size != -1 else 1
        w13_scale = Parameter(
            torch.empty(
                num_experts,
                2 * intermediate_size_per_partition,
                num_groups_per_col,
                dtype=params_dtype,
            ),
            requires_grad=False,
        )
        layer.register_parameter("w13_scale", w13_scale)

        w13_weight = RTNParameter(
            data=torch.empty(
                num_experts,
                2 * intermediate_size_per_partition // factor,
                hidden_size,
                dtype=torch.uint8,
            ),
            scale=w13_scale,
            quant_config=self.quant_config,
        )
        layer.register_parameter("w13_weight", w13_weight)
        set_weight_attrs(w13_weight, extra_weight_attrs)

        # down_proj (row parallel)
        num_groups_per_col = (
            intermediate_size_per_partition // group_size if group_size != -1 else 1
        )
        w2_scale = Parameter(
            torch.zeros(
                num_experts, hidden_size, num_groups_per_col, dtype=params_dtype
            ),
            requires_grad=False,
        )
        layer.register_parameter("w2_scale", w2_scale)

        w2_weight = RTNParameter(
            data=torch.empty(
                num_experts,
                hidden_size // factor,
                intermediate_size_per_partition,
                dtype=torch.uint8,
            ),
            scale=w2_scale,
            quant_config=self.quant_config,
        )
        layer.register_parameter("w2_weight", w2_weight)
        set_weight_attrs(w2_weight, extra_weight_attrs)

    def process_weights_after_loading(self, layer: torch.nn.Module) -> None:
        """Repack weights and scales for Marlin kernels."""
        weight_bits = self.quant_config.weight_bits

        w13_weight, w13_scale = repack_weights(
            layer.w13_weight, layer.w13_scale, weight_bits
        )
        replace_parameter(layer, "w13_weight", w13_weight)
        replace_parameter(layer, "w13_scale", w13_scale)

        w2_weight, w2_scale = repack_weights(
            layer.w2_weight, layer.w2_scale, weight_bits
        )
        replace_parameter(layer, "w2_weight", w2_weight)
        replace_parameter(layer, "w2_scale", w2_scale)

        # Make sure the shared Marlin workspace exists on this device.
        init_workspace(layer.w13_weight.device)

    def get_fused_moe_quant_config(
        self, layer: torch.nn.Module
    ) -> FusedMoEQuantConfig | None:
        """Return no fused-MoE quant config for this method."""
        return None

    def apply(
        self,
        layer: FusedMoE,
        router: FusedMoERouter,
        x: torch.Tensor,
        router_logits: torch.Tensor,
    ) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor]:
        """Select experts with ``router`` and run the fused Marlin MoE kernel."""
        topk_weights, topk_ids = router.select_experts(
            hidden_states=x,
            router_logits=router_logits,
        )

        return fused_marlin_moe(
            x,
            layer.w13_weight,
            layer.w2_weight,
            # Biases are optional; pass None when the layer has none.
            getattr(layer, "w13_bias", None),
            getattr(layer, "w2_bias", None),
            layer.w13_scale,
            layer.w2_scale,
            router_logits,
            topk_weights,
            topk_ids,
            quant_type_id=self.quant_config.quant_type.id,
            apply_router_weight_on_input=layer.apply_router_weight_on_input,
            global_num_experts=layer.global_num_experts,
            expert_map=layer.expert_map,
            workspace=workspace,
        )
389
390


391
392
393
def rtn_quantize(
    tensor: torch.Tensor, num_bits: int, group_size: int
) -> tuple[torch.Tensor, torch.Tensor]:
    """Quantize a tensor using per-group static scaling factor.

    Args:
        tensor: The input tensor (2-D, or 3-D with a leading batch dim).
        num_bits: Target precision for the result (supported values are
                  8 or 4).
        group_size: Quantization granularity.
                    If equal to -1, each row in the input tensor is treated
                    as one group.

    Returns:
        A ``(quantized, scale)`` tuple. ``quantized`` is a uint8 tensor; for
        4-bit precision two values are packed into each byte and the
        second-to-last dimension is halved. ``scale`` holds one factor per
        group. A group whose values are all zero gets a scale of 0 and is
        quantized to the midpoint of the range.
    """
    batch_present = len(tensor.shape) == 3
    if not batch_present:
        tensor = tensor.unsqueeze(0)

    q_range = 2**num_bits
    num_groups = (
        tensor.shape[1] * tensor.shape[2] // group_size
        if group_size != -1
        else tensor.shape[1]
    )

    # Calculate a scaling factor per input group.
    input_flat = tensor.reshape(tensor.shape[0], num_groups, -1)
    input_min = torch.min(input_flat, dim=2, keepdim=True)[0]
    input_max = torch.max(input_flat, dim=2, keepdim=True)[0]
    input_max_abs = torch.max(input_min.abs(), input_max.abs())
    scale = input_max_abs * 2.0 / (q_range - 1)

    # An all-zero group has a zero scale; dividing by it would produce NaN,
    # whose cast to uint8 is undefined. Divide by 1 for those groups instead.
    # The returned scale is left at 0, so they still dequantize to 0.
    safe_scale = torch.where(scale == 0, torch.ones_like(scale), scale)

    # Scale each input group, round to the nearest integer, shift the range
    # and truncate.
    scaled_input = (input_flat / safe_scale).round()
    scaled_input += q_range // 2
    scaled_input = scaled_input.clamp(0, q_range - 1)

    scale = scale.reshape(tensor.shape[0], tensor.shape[1], -1).contiguous()
    inputs_q = scaled_input.reshape(tensor.shape).to(torch.uint8)
    inputs_q = inputs_q.contiguous()

    if num_bits == 4:
        # Pack two 4-bit values into each byte: odd positions of the last
        # dim go to the high nibble, even positions to the low nibble.
        inputs_q = (inputs_q[:, :, 1::2] << 4) | (inputs_q[:, :, ::2] & 0xF)
        inputs_q = inputs_q.reshape(
            tensor.shape[0], tensor.shape[1] // 2, tensor.shape[2]
        )
        inputs_q = inputs_q.contiguous()

    if not batch_present:
        inputs_q = inputs_q.squeeze(0)
        scale = scale.squeeze(0)

    return inputs_q, scale


def rtn_dequantize(tensor: torch.Tensor, scale: torch.Tensor) -> torch.Tensor:
    """Dequantize a tensor using per-group static scaling factors.

    The precision is inferred from the shapes: if the quantized tensor has as
    many rows as ``scale`` it is treated as 8-bit, otherwise as 4-bit with two
    values packed per byte.

    Args:
        tensor: The input tensor.
        scale: The tensor with per-group scale factors.

    Returns:
        The dequantized tensor, in the dtype of ``scale``.
    """
    batch_present = len(tensor.shape) == 3
    if not batch_present:
        tensor = tensor.unsqueeze(0)
        scale = scale.unsqueeze(0)

    num_groups = scale.size(1) * scale.size(2)
    batch, input_dim, output_dim = tensor.shape

    num_bits = 8 if input_dim == scale.size(1) else 4
    q_range = 2**num_bits
    if num_bits == 4:
        # Each packed row expands into two logical rows.
        input_dim *= 2

    data = torch.empty(
        (batch, input_dim, output_dim), dtype=scale.dtype, device=tensor.device
    )

    if num_bits == 8:
        data.copy_(tensor)
        # Undo the range shift applied during quantization.
        data -= q_range // 2
    else:
        # Unpack two 4-bit values from each byte.
        tensor = tensor.reshape(batch, input_dim, output_dim // 2)
        for i in range(2):
            # i == 0 extracts the low nibble, i == 1 the high nibble.
            data[:, :, i::2] = ((tensor << 4 * (1 - i)) >> 4).to(
                torch.int8
            ) - q_range // 2

    # Scale each input group with its scaling factor.
    scale = scale.reshape(batch, num_groups, -1)
    data = data.reshape(batch, num_groups, -1)
    data = torch.mul(data, scale)

    input_deq = data.reshape((batch, input_dim, output_dim)).contiguous()
    if not batch_present:
        input_deq = input_deq.squeeze(0)

    return input_deq
495
496


497
498
499
500
501
502
503
504
505
506
507
508
509
510
511
512
513
514
515
516
517
518
519
520
521
522
523
524
525
526
527
528
529
530
531
532
533
534
535
536
537
538
539
540
541
542
543
544
545
546
547
548
549
550
551
552
553
554
555
556
557
558
559
560
561
562
563
564
565
566
567
568
569
570
571
572
573
574
575
576
577
578
579
580
581
582
583
584
585
586
587
588
589
590
591
592
593
594
595
596
597
598
599
600
601
602
603
604
605
606
607
608
609
610
def _get_perms():
    """Build the index permutations used when packing for Marlin.

    Returns:
        A tuple ``(perm, scale_perm, scale_perm_single)``: a 1024-element
        index tensor for the weights, and two index lists (64 and 32 entries)
        for the scales.
    """
    row_offsets = (0, 1, 8, 9)
    weight_perm = []
    for i in range(32):
        col, quad = divmod(i, 4)
        rows = [2 * quad + offset for offset in row_offsets]
        base = [16 * row + col + 8 * block for block in (0, 1) for row in rows]
        for j in range(4):
            weight_perm.extend(p + 256 * j for p in base)

    # Reorder every run of eight indices as 0, 2, 4, 6, 1, 3, 5, 7.
    interleave = np.array([0, 2, 4, 6, 1, 3, 5, 7])
    weight_perm_arr = np.array(weight_perm).reshape((-1, 8))[:, interleave].ravel()
    perm_tensor = torch.from_numpy(weight_perm_arr)

    scale_perm = [i + 8 * j for i in range(8) for j in range(8)]
    scale_perm_single = [
        2 * i + j for i in range(4) for j in (0, 1, 8, 9, 16, 17, 24, 25)
    ]
    return perm_tensor, scale_perm, scale_perm_single


# Permutation tables computed once at import time; pack_for_marlin reads them
# as module globals.
_perm, _scale_perm, _scale_perm_single = _get_perms()


def pack_for_marlin(weight, scale, qbits):
    """Permute quantized weights and scales into the Marlin layout.

    Args:
        weight: Unpacked quantized weights, indexed as (batch, n, k).
        scale: Per-group scales, indexed as (batch, n, groups).
        qbits: Weight precision; 4 packs two values per byte, any other value
            takes the 8-bit path.

    Returns:
        A ``(q, s)`` tuple with the permuted int8 weights and the permuted
        scales, both contiguous.
    """
    batch = weight.shape[0]

    n = weight.size(1)
    k = weight.size(2)
    groupsize = k // scale.size(2)

    tile = 16
    s = scale.permute(0, 2, 1)  # transpose
    w = weight.permute(0, 2, 1)  # transpose

    # Only the scales depend on whether grouping is used. The weights need no
    # group-wise regrouping here: splitting them into groups and merging them
    # back yields the same (batch, k, n) tensor, so that round trip is skipped.
    if groupsize != k:
        s = s.reshape((batch, -1, len(_scale_perm)))[:, :, _scale_perm]
    else:
        s = s.reshape((batch, -1, len(_scale_perm_single)))[:, :, _scale_perm_single]
    s = s.reshape((batch, -1, n)).contiguous()

    # Regroup the weights into 16x16 tiles and apply the weight permutation.
    w = w.reshape((batch, k // tile, tile, n // tile, tile))
    w = w.permute((0, 1, 3, 2, 4))
    w = w.reshape((batch, k // tile, n * tile))
    res = w
    res = res.reshape((batch, -1, _perm.numel()))[:, :, _perm].reshape(res.shape)
    if qbits == 4:
        # Pack pairs of neighbouring 4-bit values into one byte.
        q = torch.zeros(
            (batch, res.shape[1], res.shape[2] // 2), dtype=torch.int8, device=w.device
        )
        for i in range(2):
            q |= res[:, :, i::2] << 4 * i
        q = q.reshape(batch, -1, n).contiguous()
    else:
        # Within every run of eight values, swap positions 2,3 with 4,5.
        q = res.clone()
        q[:, :, 2::8] = res[:, :, 4::8]
        q[:, :, 3::8] = res[:, :, 5::8]
        q[:, :, 4::8] = res[:, :, 2::8]
        q[:, :, 5::8] = res[:, :, 3::8]
        q = q.reshape(batch, -1, n).to(torch.int8).contiguous()

    return q, s


def repack_8bit_into_32bit(input):
    """Pack every four consecutive 8-bit values of the last dim into an int32.

    The value at position ``4 * j + i`` of the last dimension lands in byte
    ``i`` (counted from the least significant) of output element ``j``.
    """
    packed_shape = (input.shape[0], input.shape[1], input.shape[2] // 4)
    packed = torch.zeros(packed_shape, dtype=torch.int32, device=input.device)
    for byte_index, shift in enumerate((0, 8, 16, 24)):
        lane = input[:, :, byte_index::4] & 0xFF
        packed |= lane.to(torch.int32) << shift

    return packed


def repack_weights(qweight, scale, weight_bits):
    """Convert quantized weights and scales into the Marlin kernel format.

    Args:
        qweight: Quantized weights (2-D, or 3-D with a leading batch dim);
            for 4-bit precision two values are packed per byte.
        scale: The matching per-group scales.
        weight_bits: Weight precision; 4 triggers unpacking first.

    Returns:
        A ``(qweight, scale)`` tuple: int32-packed weights and permuted
        scales, without the batch dim if the input had none.
    """
    batch_present = len(qweight.shape) == 3
    if not batch_present:
        qweight = qweight.unsqueeze(0)
        scale = scale.unsqueeze(0)

    if weight_bits == 4:
        # Unpack two 4-bit values from each byte.
        qweight_unpacked = torch.empty(
            (qweight.shape[0], qweight.shape[1] * 2, qweight.shape[2]),
            dtype=torch.uint8,
            device=qweight.device,
        )
        for i in range(2):
            # i == 0 extracts the low nibble, i == 1 the high nibble.
            qweight_unpacked[:, :, i::2] = ((qweight << 4 * (1 - i)) >> 4).reshape(
                qweight.shape[0], qweight.shape[1] * 2, qweight.shape[2] // 2
            )
    else:
        qweight_unpacked = qweight

    qweight_packed, scale_packed = pack_for_marlin(qweight_unpacked, scale, weight_bits)

    # Marlin kernels expect tensors in int32 format in a certain shape.
    qweight_repacked = repack_8bit_into_32bit(qweight_packed.to(torch.uint8))
    qweight_reshaped = qweight_repacked.reshape(
        qweight.shape[0], qweight.shape[2] // 16, -1
    )
    if not batch_present:
        qweight_reshaped = qweight_reshaped.squeeze(0)
        scale_packed = scale_packed.squeeze(0)

    return qweight_reshaped, scale_packed
621
622


623
624
625
626
def init_workspace(device):
    """Create the shared Marlin workspace on ``device`` if it is still unset.

    Later calls are no-ops, whatever device they pass.
    """
    global workspace
    if workspace is not None:
        return
    workspace = marlin_make_workspace_new(device, 4)