rtn.py 20.7 KB
Newer Older
1
# SPDX-License-Identifier: Apache-2.0
2
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
3
4
5
# Copyright © 2025, Oracle and/or its affiliates.

import os
6
7
from collections.abc import Callable
from typing import Any, Optional
8

9
import numpy as np
10
11
12
13
import torch
from torch.nn.parameter import Parameter

from vllm.logger import init_logger
14
from vllm.model_executor.layers.fused_moe.config import (
15
    FusedMoEConfig,
16
17
    FusedMoEQuantConfig,
)
18
from vllm.model_executor.layers.fused_moe.layer import FusedMoE, FusedMoEMethodBase
19
20
21
22
23
from vllm.model_executor.layers.linear import (
    LinearBase,
    LinearMethodBase,
    set_weight_attrs,
)
24
25
from vllm.model_executor.layers.quantization import QuantizationMethods
from vllm.model_executor.layers.quantization.base_config import (
26
27
28
    QuantizationConfig,
    QuantizeMethodBase,
)
29
30
31
32
33
34
from vllm.model_executor.layers.quantization.utils import replace_parameter
from vllm.model_executor.layers.quantization.utils.marlin_utils import (
    apply_rtn_marlin_linear,
    marlin_make_workspace_new,
)
from vllm.scalar_type import scalar_types
35
36
37
38
39

logger = init_logger(__name__)

# Target precision, kept as the raw string read from the environment.
# Defaults to 8 bits; override with the RTN_NUM_BITS envvar.
NUM_BITS = os.getenv("RTN_NUM_BITS", "8")

# Quantization group size, kept as the raw string read from the environment.
# Defaults to 128 parameters per group; override with the RTN_GROUP_SIZE envvar.
GROUP_SIZE = os.getenv("RTN_GROUP_SIZE", "128")

# Global Marlin workspace shared by all modules; created lazily.
workspace = None
48
49
50


class RTNConfig(QuantizationConfig):
    """Config class for RTN.

    Args:
        weight_bits: Target weight precision; only 4 and 8 are accepted.
            The default is ``int(NUM_BITS)``, evaluated once at import time.
        group_size: Number of weights that share one scale, or -1 to use a
            single group. The default is ``int(GROUP_SIZE)``, evaluated once
            at import time.

    Raises:
        ValueError: If ``weight_bits`` is neither 4 nor 8, or if
            ``group_size`` is neither -1 nor a positive integer.
    """

    def __init__(
        self,
        weight_bits: int = int(NUM_BITS),
        group_size: int = int(GROUP_SIZE),
    ) -> None:
        # Let the base class set up whatever state it maintains for configs.
        super().__init__()
        self.weight_bits = weight_bits
        self.group_size = group_size

        if self.weight_bits not in (4, 8):
            raise ValueError(
                "Currently, only 4-bit or 8-bit weight quantization is "
                f"supported for RTN, but got {self.weight_bits} bits."
            )

        # group_size is used as a divisor when sizing the scale tensors, so
        # reject zero and negative values other than the -1 sentinel up front.
        if self.group_size != -1 and self.group_size <= 0:
            raise ValueError(
                "RTN group_size must be -1 or a positive integer, "
                f"but got {self.group_size}."
            )

        self.quant_type = (
            scalar_types.uint8b128 if self.weight_bits == 8 else scalar_types.uint4b8
        )

    def __repr__(self) -> str:
        return (
            f"RTNConfig(weight_bits={self.weight_bits}, group_size={self.group_size})"
        )

    @classmethod
    def get_name(cls) -> QuantizationMethods:
        return "rtn"

    @classmethod
    def get_supported_act_dtypes(cls) -> list[torch.dtype]:
        return [torch.bfloat16, torch.half]

    @classmethod
    def get_min_capability(cls) -> int:
        return 80

    @classmethod
    def get_config_filenames(cls) -> list[str]:
        return []

    @classmethod
    def from_config(cls, config: dict[str, Any]) -> "RTNConfig":
        """Build a config from a dict holding "bits" and "group_size" keys."""
        weight_bits = cls.get_from_keys(config, ["bits"])
        group_size = cls.get_from_keys(config, ["group_size"])
        return cls(weight_bits, group_size)

    def get_quant_method(
        self, layer: torch.nn.Module, prefix: str
    ) -> Optional["QuantizeMethodBase"]:
        """Return the RTN method for linear and fused-MoE layers, else None."""
        if isinstance(layer, LinearBase):
            return RTNLinearMethod(self)
        elif isinstance(layer, FusedMoE):
            return RTNMoEMethod(self, layer.moe_config)
        return None


class RTNTensor:
    """A wrapper over Tensor that enables quantization on-the-fly by
    overloading the copy_ method.

    Args:
        data: Storage for the quantized weights. For 4-bit weights two
            values share one byte, so the packed dimension is half-sized.
        scale: Storage for the per-group scales.
        quant_config: The RTN quantization config.
    """

    def __init__(
        self, data: torch.Tensor, scale: torch.Tensor, quant_config: RTNConfig
    ) -> None:
        self.data = data
        self.scale = scale
        self.quant_config = quant_config

    def narrow(self, dim, start, length):
        """Return a narrowed view over both the weights and the scales."""
        # `start`/`length` are given in unpacked units; the packed weight
        # storage holds `factor` values per byte, the scales do not.
        factor = 1 if self.quant_config.weight_bits == 8 else 2
        return RTNTensor(
            self.data.narrow(dim, start // factor, length // factor),
            self.scale.narrow(dim, start, length),
            self.quant_config,
        )

    def __getitem__(self, key):
        return RTNTensor(self.data[key], self.scale[key], self.quant_config)

    @property
    def shape(self):
        """Shape of the weights as if they were not packed."""
        shape = self.data.shape
        factor = 1 if self.quant_config.weight_bits == 8 else 2
        batch_present = len(shape) == 3
        if batch_present:
            return torch.Size((shape[0], shape[1] * factor, shape[2]))
        else:
            return torch.Size((shape[0] * factor, shape[1]))

    def copy_(self, loaded_weight: torch.Tensor) -> None:
        """Quantize `loaded_weight` and store the result and its scales."""
        # Quantize on the device that owns the destination storage rather
        # than on whichever CUDA device happens to be current.
        qweight, weight_scale = rtn_quantize(
            loaded_weight.to(self.data.device),
            self.quant_config.weight_bits,
            self.quant_config.group_size,
        )

        self.data.copy_(qweight)
        self.scale.data.copy_(weight_scale)


class RTNParameter(Parameter):
    """A wrapper over Parameter that returns RTNTensor (a wrapper over Tensor)
    when its data is accessed. We need this wrapper for the data loading phase
    only, so we can intercept a weight copying function (torch.Tensor.copy_)
    and apply quantization on-the-fly.

    Args:
        data: Storage for the quantized weights.
        scale: Storage for the per-group scales.
        quant_config: The RTN quantization config.
    """

    def __new__(cls, data: torch.Tensor, **kwargs):
        # Only `data` goes to Parameter; the remaining keyword arguments are
        # consumed by __init__. The parameter never requires gradients.
        return super().__new__(cls, data=data, requires_grad=False)

    def __init__(
        self, data: torch.Tensor, scale: torch.Tensor, quant_config: RTNConfig
    ) -> None:
        self.scale = scale
        self.quant_config = quant_config

    @property
    def data(self):
        # Hand out a quantizing wrapper instead of the raw tensor so that
        # copy_ calls made through `.data` quantize on the fly.
        return RTNTensor(super().data, self.scale, self.quant_config)


class RTNLinearMethod(LinearMethodBase):
    """Linear method for RTN.

    Args:
        quant_config: The RTN quantization config.
    """

    def __init__(self, quant_config: RTNConfig):
        self.quant_config = quant_config

    def create_weights(
        self,
        layer: torch.nn.Module,
        input_size_per_partition: int,
        output_partition_sizes: list[int],
        input_size: int,
        output_size: int,
        params_dtype: torch.dtype,
        **extra_weight_attrs,
    ):
        """Register the packed `weight` and per-group `scale` on `layer`."""
        output_size_per_partition = sum(output_partition_sizes)
        num_groups_per_col = (
            input_size_per_partition // self.quant_config.group_size
            if self.quant_config.group_size != -1
            else 1
        )

        scale = Parameter(
            torch.empty(
                output_size_per_partition, num_groups_per_col, dtype=params_dtype
            ),
            requires_grad=False,
        )
        # 4-bit weights store two values per byte along the output dimension.
        factor = 1 if self.quant_config.weight_bits == 8 else 2

        weight = RTNParameter(
            data=torch.empty(
                output_size_per_partition // factor,
                input_size_per_partition,
                dtype=torch.uint8,
            ),
            scale=scale,
            quant_config=self.quant_config,
        )

        layer.register_parameter("weight", weight)
        set_weight_attrs(
            weight,
            {
                **extra_weight_attrs,
                "input_dim": 1,
                "output_dim": 0,
            },
        )

        layer.register_parameter("scale", scale)
        # apply() reads both partition sizes from the layer, so record both
        # here instead of relying on the layer class to have set one of them.
        layer.input_size_per_partition = input_size_per_partition
        layer.output_size_per_partition = output_size_per_partition

    def process_weights_after_loading(self, layer: torch.nn.Module) -> None:
        """Repack weights and scales for Marlin kernels."""
        weight_bits = self.quant_config.weight_bits

        weight, scale = repack_weights(layer.weight, layer.scale, weight_bits)

        replace_parameter(layer, "weight", weight)
        replace_parameter(layer, "scale", scale)

        init_workspace(layer.weight.device)

    def apply(
        self,
        layer: torch.nn.Module,
        x: torch.Tensor,
        bias: torch.Tensor | None = None,
    ) -> torch.Tensor:
        """Run the RTN Marlin linear kernel on `x` with the layer's weights."""
        return apply_rtn_marlin_linear(
            input=x,
            weight=layer.weight,
            weight_scale=layer.scale,
            workspace=workspace,
            quant_type=self.quant_config.quant_type,
            output_size_per_partition=layer.output_size_per_partition,
            input_size_per_partition=layer.input_size_per_partition,
            bias=bias,
        )
258
259


260
class RTNMoEMethod(FusedMoEMethodBase):
    """Fused-MoE method for RTN.

    Args:
        quant_config: The RTN quantization config.
        moe: The fused-MoE config, forwarded to the base class.
    """

    def __init__(self, quant_config: RTNConfig, moe: FusedMoEConfig):
        super().__init__(moe)
        self.quant_config = quant_config

    def create_weights(
        self,
        layer: torch.nn.Module,
        num_experts: int,
        hidden_size: int,
        intermediate_size_per_partition: int,
        params_dtype: torch.dtype,
        **extra_weight_attrs,
    ):
        """Register packed expert weights and their per-group scales."""
        # 4-bit weights store two values per byte along the output dimension.
        factor = 1 if self.quant_config.weight_bits == 8 else 2

        # Fused gate_up_proj (column parallel)
        num_groups_per_col = (
            hidden_size // self.quant_config.group_size
            if self.quant_config.group_size != -1
            else 1
        )
        w13_scale = Parameter(
            torch.empty(
                num_experts,
                2 * intermediate_size_per_partition,
                num_groups_per_col,
                dtype=params_dtype,
            ),
            requires_grad=False,
        )
        layer.register_parameter("w13_scale", w13_scale)

        w13_weight = RTNParameter(
            data=torch.empty(
                num_experts,
                2 * intermediate_size_per_partition // factor,
                hidden_size,
                dtype=torch.uint8,
            ),
            scale=w13_scale,
            quant_config=self.quant_config,
        )
        layer.register_parameter("w13_weight", w13_weight)
        set_weight_attrs(w13_weight, extra_weight_attrs)

        # down_proj (row parallel)
        num_groups_per_col = (
            intermediate_size_per_partition // self.quant_config.group_size
            if self.quant_config.group_size != -1
            else 1
        )
        w2_scale = Parameter(
            torch.zeros(
                num_experts, hidden_size, num_groups_per_col, dtype=params_dtype
            ),
            requires_grad=False,
        )
        layer.register_parameter("w2_scale", w2_scale)

        w2_weight = RTNParameter(
            data=torch.empty(
                num_experts,
                hidden_size // factor,
                intermediate_size_per_partition,
                dtype=torch.uint8,
            ),
            scale=w2_scale,
            quant_config=self.quant_config,
        )
        layer.register_parameter("w2_weight", w2_weight)
        set_weight_attrs(w2_weight, extra_weight_attrs)

    def process_weights_after_loading(self, layer: torch.nn.Module) -> None:
        """Repack weights and scales for Marlin kernels."""
        weight_bits = self.quant_config.weight_bits

        w13_weight, w13_scale = repack_weights(
            layer.w13_weight, layer.w13_scale, weight_bits
        )
        replace_parameter(layer, "w13_weight", w13_weight)
        replace_parameter(layer, "w13_scale", w13_scale)

        w2_weight, w2_scale = repack_weights(
            layer.w2_weight, layer.w2_scale, weight_bits
        )
        replace_parameter(layer, "w2_weight", w2_weight)
        replace_parameter(layer, "w2_scale", w2_scale)

        init_workspace(layer.w13_weight.device)

    def get_fused_moe_quant_config(
        self, layer: torch.nn.Module
    ) -> FusedMoEQuantConfig | None:
        return None

    def apply(
        self,
        layer: torch.nn.Module,
        x: torch.Tensor,
        router_logits: torch.Tensor,
        top_k: int,
        renormalize: bool,
        use_grouped_topk: bool = False,
        topk_group: int | None = None,
        num_expert_group: int | None = None,
        global_num_experts: int = -1,
        expert_map: torch.Tensor | None = None,
        custom_routing_function: Callable | None = None,
        scoring_func: str = "softmax",
        routed_scaling_factor: float = 1.0,
        e_score_correction_bias: torch.Tensor | None = None,
        apply_router_weight_on_input: bool = False,
        activation: str = "silu",
        enable_eplb: bool = False,
        expert_load_view: torch.Tensor | None = None,
        logical_to_physical_map: torch.Tensor | None = None,
        logical_replica_count: torch.Tensor | None = None,
    ) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor]:
        """Select experts for `x` and run the fused Marlin MoE kernel.

        Raises:
            NotImplementedError: If `enable_eplb` is set.
        """
        assert self.fused_experts is None

        if enable_eplb:
            raise NotImplementedError("EPLB not supported for `RTNMoEMethod` yet.")

        topk_weights, topk_ids, _ = FusedMoE.select_experts(
            hidden_states=x,
            router_logits=router_logits,
            use_grouped_topk=use_grouped_topk,
            top_k=top_k,
            renormalize=renormalize,
            topk_group=topk_group,
            num_expert_group=num_expert_group,
            custom_routing_function=custom_routing_function,
            scoring_func=scoring_func,
            routed_scaling_factor=routed_scaling_factor,
            e_score_correction_bias=e_score_correction_bias,
            indices_type=self.topk_indices_dtype,
        )

        # NOTE(review): `activation` and the EPLB tensors are accepted for
        # interface compatibility but are not forwarded to the kernel.
        return torch.ops.vllm.fused_marlin_moe(
            x,
            layer.w13_weight,
            layer.w2_weight,
            getattr(layer, "w13_bias", None),
            getattr(layer, "w2_bias", None),
            layer.w13_scale,
            layer.w2_scale,
            router_logits,
            topk_weights,
            topk_ids,
            quant_type_id=self.quant_config.quant_type.id,
            apply_router_weight_on_input=apply_router_weight_on_input,
            global_num_experts=global_num_experts,
            expert_map=expert_map,
            workspace=workspace,
        )
416
417


418
419
420
def rtn_quantize(
    tensor: torch.Tensor, num_bits: int, group_size: int
) -> tuple[torch.Tensor, torch.Tensor]:
    """Quantize a tensor using per-group static scaling factor.

    Args:
        tensor: The input tensor (2-D, or 3-D with a leading batch dim).
        num_bits: Target precision for the result (supported values are
                  8 or 4).
        group_size: Quantization granularity.
                    If equal to -1, each row in the input tensor is treated
                    as one group.

    Returns:
        A tuple ``(quantized, scale)``. ``quantized`` is a uint8 tensor; for
        4-bit output two values are packed into each byte, halving the
        second-to-last dimension. ``scale`` holds one factor per group and
        is 0 for groups whose inputs are all zero.
    """
    batch_present = len(tensor.shape) == 3
    if not batch_present:
        tensor = tensor.unsqueeze(0)

    q_range = 2**num_bits
    num_groups = (
        tensor.shape[1] * tensor.shape[2] // group_size
        if group_size != -1
        else tensor.shape[1]
    )

    # Calculate a scaling factor per input group.
    input_flat = tensor.reshape(tensor.shape[0], num_groups, -1)
    input_min = torch.min(input_flat, dim=2, keepdim=True)[0]
    input_max = torch.max(input_flat, dim=2, keepdim=True)[0]
    input_max_abs = torch.max(input_min.abs(), input_max.abs())
    scale = input_max_abs * 2.0 / (q_range - 1)

    # An all-zero group has scale 0. Dividing by it would yield NaN, whose
    # cast to uint8 is undefined, so divide by 1 there instead; the group
    # then lands exactly on the zero point and the stored scale stays 0.
    safe_scale = torch.where(scale == 0, torch.ones_like(scale), scale)

    # Scale each input group, round to the nearest integer, shift
    # the range and truncate.
    scaled_input = (input_flat / safe_scale).round()
    scaled_input += q_range // 2
    scaled_input = scaled_input.clamp(0, q_range - 1)

    scale = scale.reshape(tensor.shape[0], tensor.shape[1], -1).contiguous()
    inputs_q = scaled_input.reshape(tensor.shape).to(torch.uint8)
    inputs_q = inputs_q.contiguous()

    if num_bits == 4:
        # Pack two 4-bit values into each byte.
        inputs_q = (inputs_q[:, :, 1::2] << 4) | (inputs_q[:, :, ::2] & 0xF)
        inputs_q = inputs_q.reshape(
            tensor.shape[0], tensor.shape[1] // 2, tensor.shape[2]
        )
        inputs_q = inputs_q.contiguous()

    if not batch_present:
        inputs_q = inputs_q.squeeze(0)
        scale = scale.squeeze(0)

    return inputs_q, scale


def rtn_dequantize(tensor: torch.Tensor, scale: torch.Tensor) -> torch.Tensor:
    """Dequantize a tensor using per-group static scaling factors.

    Args:
        tensor: The input tensor (2-D, or 3-D with a leading batch dim).
        scale: The tensor with per-group scale factors.

    Returns:
        The dequantized tensor, with the dtype of ``scale``. The bit width
        is inferred from the shapes: when the row count of ``tensor`` equals
        that of ``scale`` the data is treated as 8-bit, otherwise as packed
        4-bit (and the row count of the result is doubled).
    """
    batch_present = len(tensor.shape) == 3
    if not batch_present:
        tensor = tensor.unsqueeze(0)
        scale = scale.unsqueeze(0)

    num_groups = scale.size(1) * scale.size(2)
    batch, input_dim, output_dim = tensor.shape

    num_bits = 8 if input_dim == scale.size(1) else 4
    q_range = 2**num_bits
    if num_bits == 4:
        input_dim *= 2

    data = torch.empty(
        (batch, input_dim, output_dim), dtype=scale.dtype, device=tensor.device
    )

    if num_bits == 8:
        data.copy_(tensor)
        data -= q_range // 2
    else:
        # Unpack two 4-bit values from each byte: the low nibble goes to
        # even columns, the high nibble to odd columns.
        tensor = tensor.reshape(batch, input_dim, output_dim // 2)
        for i in range(2):
            data[:, :, i::2] = ((tensor << 4 * (1 - i)) >> 4).to(
                torch.int8
            ) - q_range // 2

    # Scale each input group with its scaling factor.
    scale = scale.reshape(batch, num_groups, -1)
    data = data.reshape(batch, num_groups, -1)
    data = torch.mul(data, scale)

    input_deq = data.reshape((batch, input_dim, output_dim)).contiguous()
    if not batch_present:
        input_deq = input_deq.squeeze(0)

    return input_deq
522
523


524
525
526
527
528
529
530
531
532
533
534
535
536
537
538
539
540
541
542
543
544
545
546
547
548
549
550
551
552
553
554
555
556
557
558
559
560
561
562
563
564
565
566
567
568
569
570
571
572
573
574
575
576
577
578
579
580
581
582
583
584
585
586
587
588
589
590
591
592
593
594
595
596
597
598
599
600
601
602
603
604
605
606
607
608
609
610
611
612
613
614
615
616
617
618
619
620
621
622
623
624
625
626
627
628
629
630
631
632
633
634
635
636
637
def _get_perms():
    """Build the index permutations used when packing for Marlin.

    Returns:
        A tuple ``(perm_tensor, scale_perm, scale_perm_single)``: a 1-D
        tensor of 1024 weight indices, a 64-entry list of scale indices and
        a 32-entry list of scale indices.
    """
    weight_perm = []
    for idx in range(32):
        col, quad = divmod(idx, 4)
        rows = (2 * quad, 2 * quad + 1, 2 * (quad + 4), 2 * (quad + 4) + 1)
        tile = [16 * row + col + 8 * block for block in (0, 1) for row in rows]
        # Repeat the 8-entry tile four times, offset by 256 each time.
        for rep in range(4):
            weight_perm.extend(p + 256 * rep for p in tile)

    # Reorder every run of 8 indices so even positions precede odd ones.
    interleave = np.array([0, 2, 4, 6, 1, 3, 5, 7])
    perm_arr = np.array(weight_perm).reshape((-1, 8))[:, interleave].ravel()
    perm_tensor = torch.from_numpy(perm_arr)

    scale_perm = [i + 8 * j for i in range(8) for j in range(8)]
    scale_perm_single = [
        2 * i + j for i in range(4) for j in (0, 1, 8, 9, 16, 17, 24, 25)
    ]
    return perm_tensor, scale_perm, scale_perm_single


_perm, _scale_perm, _scale_perm_single = _get_perms()


def pack_for_marlin(weight, scale, qbits):
    """Permute quantized weights and scales into the layout built from
    the module-level `_perm` / `_scale_perm` / `_scale_perm_single` tables.

    Args:
        weight: Batched tensor indexed as (batch, n, k); one quantized value
            per element (assumes values already fit in `qbits` bits —
            TODO confirm against callers).
        scale: Batched tensor indexed as (batch, n, groups); the group size
            is derived as ``k // groups``.
        qbits: 4 packs two values per byte; any other value takes the
            8-bit path.

    Returns:
        A tuple ``(q, s)`` of the permuted int8 weights reshaped to
        (batch, -1, n) and the permuted scales reshaped to (batch, -1, n).
    """
    batch = weight.shape[0]

    n = weight.size(1)
    k = weight.size(2)
    groupsize = k // scale.size(2)

    tile = 16
    s = scale.permute(0, 2, 1)  # transpose
    w = weight.permute(0, 2, 1)  # transpose
    # NOTE(review): this regrouping of `w` is inverted again by the first
    # three statements of the next `if`, so it looks like a net no-op on
    # `w`; only the flattening of `s` carries over.
    if groupsize != k:
        w = w.reshape((batch, -1, groupsize, n))
        w = w.permute(0, 2, 1, 3)
        w = w.reshape((batch, groupsize, -1))
        s = s.reshape((batch, 1, -1))

    # Grouped scales use the 64-entry table; a single group per column
    # (groupsize == k) uses the 32-entry table.
    if groupsize != k:
        w = w.reshape((batch, groupsize, -1, n))
        w = w.permute(0, 2, 1, 3)
        w = w.reshape((batch, k, n)).contiguous()
        s = s.reshape((batch, -1, len(_scale_perm)))[:, :, _scale_perm]
    else:
        s = s.reshape((batch, -1, len(_scale_perm_single)))[:, :, _scale_perm_single]
    s = s.reshape((batch, -1, n)).contiguous()
    # Split `w` into 16x16 tiles and lay each tile out contiguously.
    w = w.reshape((batch, k // tile, tile, n // tile, tile))
    w = w.permute((0, 1, 3, 2, 4))
    w = w.reshape((batch, k // tile, n * tile))
    res = w
    # Apply the weight permutation within every chunk of `_perm.numel()` values.
    res = res.reshape((batch, -1, _perm.numel()))[:, :, _perm].reshape(res.shape)
    if qbits == 4:
        # Pack adjacent pairs into one byte: even index in the low nibble,
        # odd index in the high nibble.
        q = torch.zeros(
            (batch, res.shape[1], res.shape[2] // 2), dtype=torch.int8, device=w.device
        )
        for i in range(2):
            q |= res[:, :, i::2] << 4 * i
        q = q.reshape(batch, -1, n).contiguous()
    else:
        # Within every run of 8 values, swap positions 2,3 with 4,5.
        q = res.clone()
        q[:, :, 2::8] = res[:, :, 4::8]
        q[:, :, 3::8] = res[:, :, 5::8]
        q[:, :, 4::8] = res[:, :, 2::8]
        q[:, :, 5::8] = res[:, :, 3::8]
        q = q.reshape(batch, -1, n).to(torch.int8).contiguous()

    return q, s


def repack_8bit_into_32bit(input):
    """Pack every four consecutive 8-bit values of the last dim into an int32.

    Byte ``i`` of each group of four lands in bits ``8*i .. 8*i+7`` of the
    output word, so the first value ends up in the least significant byte.

    Args:
        input: A 3-D integer tensor whose last dimension is a multiple of 4.

    Returns:
        An int32 tensor on the same device with the last dimension divided
        by 4.
    """
    packed_shape = (input.shape[0], input.shape[1], input.shape[2] // 4)
    packed = torch.zeros(packed_shape, dtype=torch.int32, device=input.device)
    for byte_idx in range(4):
        lane = (input[:, :, byte_idx::4] & 0xFF).to(torch.int32)
        packed |= lane << 8 * byte_idx

    return packed


def repack_weights(qweight, scale, weight_bits):
    """Convert RTN-quantized weights and scales into the Marlin layout.

    Args:
        qweight: Quantized weights (2-D, or 3-D with a leading batch dim);
            for 4-bit weights two values are packed in each byte.
        scale: Per-group scales matching `qweight`.
        weight_bits: 4 triggers nibble unpacking first; any other value
            uses `qweight` as is.

    Returns:
        A tuple ``(weight, scale)`` with the int32-packed weights and the
        permuted scales; the batch dim is dropped again if it was absent.
    """
    batch_present = len(qweight.shape) == 3
    if not batch_present:
        qweight = qweight.unsqueeze(0)
        scale = scale.unsqueeze(0)

    if weight_bits == 4:
        # Unpack two 4-bit values from each byte.
        qweight_unpacked = torch.empty(
            (qweight.shape[0], qweight.shape[1] * 2, qweight.shape[2]),
            dtype=torch.uint8,
            device=qweight.device,
        )
        for i in range(2):
            qweight_unpacked[:, :, i::2] = ((qweight << 4 * (1 - i)) >> 4).reshape(
                qweight.shape[0], qweight.shape[1] * 2, qweight.shape[2] // 2
            )
    else:
        qweight_unpacked = qweight

    qweight_packed, scale_packed = pack_for_marlin(qweight_unpacked, scale, weight_bits)

    # Marlin kernels expect tensors in int32 format in a certain shape.
    qweight_repacked = repack_8bit_into_32bit(qweight_packed.to(torch.uint8))
    qweight_reshaped = qweight_repacked.reshape(
        qweight.shape[0], qweight.shape[2] // 16, -1
    )
    if not batch_present:
        qweight_reshaped = qweight_reshaped.squeeze(0)
        scale_packed = scale_packed.squeeze(0)

    return qweight_reshaped, scale_packed
648
649


650
651
652
653
def init_workspace(device):
    """Create the shared Marlin workspace on first use.

    Later calls are no-ops, so the workspace stays on the device passed to
    the first call.

    Args:
        device: Device handed to `marlin_make_workspace_new`.
    """
    global workspace
    if workspace is not None:
        return
    workspace = marlin_make_workspace_new(device, 4)