rtn.py 20.7 KB
Newer Older
1
# SPDX-License-Identifier: Apache-2.0
2
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
3
4
5
# Copyright © 2025, Oracle and/or its affiliates.

import os
6
7
from collections.abc import Callable
from typing import Any, Optional
8

9
import numpy as np
10
11
12
13
import torch
from torch.nn.parameter import Parameter

from vllm.logger import init_logger
14
from vllm.model_executor.layers.fused_moe.config import (
15
    FusedMoEConfig,
16
17
    FusedMoEQuantConfig,
)
18
from vllm.model_executor.layers.fused_moe.fused_marlin_moe import fused_marlin_moe
19
from vllm.model_executor.layers.fused_moe.layer import FusedMoE, FusedMoEMethodBase
20
21
22
23
24
from vllm.model_executor.layers.linear import (
    LinearBase,
    LinearMethodBase,
    set_weight_attrs,
)
25
26
from vllm.model_executor.layers.quantization import QuantizationMethods
from vllm.model_executor.layers.quantization.base_config import (
27
28
29
    QuantizationConfig,
    QuantizeMethodBase,
)
30
31
32
33
34
35
from vllm.model_executor.layers.quantization.utils import replace_parameter
from vllm.model_executor.layers.quantization.utils.marlin_utils import (
    apply_rtn_marlin_linear,
    marlin_make_workspace_new,
)
from vllm.scalar_type import scalar_types
36
37
38
39
40

logger = init_logger(__name__)

# Target precision in bits. Defaults to 8; set the RTN_NUM_BITS environment
# variable to override it. Kept as a string here and converted with int() where
# it is used.
NUM_BITS = os.getenv("RTN_NUM_BITS", "8")

# Number of parameters that share one scale factor. Defaults to 128; set the
# RTN_GROUP_SIZE environment variable to override it (also a string here).
GROUP_SIZE = os.getenv("RTN_GROUP_SIZE", "128")

# Marlin workspace shared by every RTN module in this process. It starts out
# unset and is filled in by init_workspace().
workspace = None
49
50
51


class RTNConfig(QuantizationConfig):
    """Config class for RTN (round-to-nearest) weight quantization.

    Args:
        weight_bits: Bits per quantized weight; only 4 and 8 are accepted.
            The default comes from the RTN_NUM_BITS environment variable
            (8 if unset), evaluated once when this module is imported.
        group_size: Number of consecutive weights sharing one scale, or -1
            for a single group per row. The default comes from the
            RTN_GROUP_SIZE environment variable (128 if unset).

    Raises:
        ValueError: If ``weight_bits`` is not 4 or 8, or ``group_size`` is
            neither -1 nor a positive integer.
    """

    def __init__(
        self,
        weight_bits: int = int(NUM_BITS),
        group_size: int = int(GROUP_SIZE),
    ) -> None:
        # Let the base class set up whatever state it owns before we add ours.
        super().__init__()
        self.weight_bits = weight_bits
        self.group_size = group_size

        if self.weight_bits != 4 and self.weight_bits != 8:
            raise ValueError(
                "Currently, only 4-bit or 8-bit weight quantization is "
                f"supported for RTN, but got {self.weight_bits} bits."
            )
        # group_size is later used as a divisor when sizing the scale tensors,
        # so reject zero / negative values (other than the -1 sentinel) early
        # instead of failing with an obscure shape error during loading.
        if self.group_size != -1 and self.group_size <= 0:
            raise ValueError(
                "RTN group_size must be a positive integer or -1 (one group "
                f"per row), but got {self.group_size}."
            )

        # Marlin scalar type: unsigned values with a fixed bias of 128 (8-bit)
        # or 8 (4-bit), matching the offset applied by rtn_quantize.
        self.quant_type = (
            scalar_types.uint8b128 if self.weight_bits == 8 else scalar_types.uint4b8
        )

    def __repr__(self) -> str:
        return (
            f"RTNConfig(weight_bits={self.weight_bits}, group_size={self.group_size})"
        )

    @classmethod
    def get_name(cls) -> QuantizationMethods:
        return "rtn"

    @classmethod
    def get_supported_act_dtypes(cls) -> list[torch.dtype]:
        return [torch.bfloat16, torch.half]

    @classmethod
    def get_min_capability(cls) -> int:
        return 80

    @classmethod
    def get_config_filenames(cls) -> list[str]:
        # RTN quantizes on the fly, so there is no quantization config file.
        return []

    @classmethod
    def from_config(cls, config: dict[str, Any]) -> "RTNConfig":
        """Build a config from a dict holding ``bits`` and ``group_size``."""
        weight_bits = cls.get_from_keys(config, ["bits"])
        group_size = cls.get_from_keys(config, ["group_size"])
        return cls(weight_bits, group_size)

    def get_quant_method(
        self, layer: torch.nn.Module, prefix: str
    ) -> Optional["QuantizeMethodBase"]:
        """Return the RTN method for linear and fused-MoE layers, else None."""
        if isinstance(layer, LinearBase):
            return RTNLinearMethod(self)
        elif isinstance(layer, FusedMoE):
            return RTNMoEMethod(self, layer.moe_config)
        return None


class RTNTensor:
    """A wrapper over Tensor that enables quantization on-the-fly by
    overloading the copy_ method.

    ``data`` holds the quantized uint8 payload. For 4-bit weights two values
    share one byte, so the second-to-last axis of ``data`` is half as long as
    the logical weight's (see ``shape``). ``scale`` is indexed in logical
    (unpacked) coordinates.
    """

    def __init__(
        self, data: torch.Tensor, scale: torch.Tensor, quant_config: "RTNConfig"
    ) -> None:
        self.data = data
        self.scale = scale
        self.quant_config = quant_config

    def narrow(self, dim, start, length):
        """Narrow along ``dim`` using logical (unpacked) coordinates.

        Only the packed axis (second-to-last axis of ``data``) is stored at
        reduced length for 4-bit weights, so ``start``/``length`` are converted
        to packed coordinates only there. Other axes, such as the leading
        expert axis of a 3-D MoE weight, are narrowed unchanged.
        NOTE(review): narrowing the last (input) axis is not meaningful for
        ``scale``, whose last axis counts groups, not elements.
        """
        ndim = self.data.dim()
        axis = dim + ndim if dim < 0 else dim
        packed = axis == ndim - 2
        factor = 2 if packed and self.quant_config.weight_bits != 8 else 1
        return RTNTensor(
            self.data.narrow(dim, start // factor, length // factor),
            self.scale.narrow(dim, start, length),
            self.quant_config,
        )

    def __getitem__(self, key):
        return RTNTensor(self.data[key], self.scale[key], self.quant_config)

    @property
    def shape(self):
        """Logical (unpacked) shape of the wrapped weight."""
        shape = self.data.shape
        factor = 1 if self.quant_config.weight_bits == 8 else 2
        batch_present = len(shape) == 3
        if batch_present:
            return torch.Size((shape[0], shape[1] * factor, shape[2]))
        else:
            return torch.Size((shape[0] * factor, shape[1]))

    def copy_(self, loaded_weight: torch.Tensor) -> None:
        """Quantize ``loaded_weight`` and store the result in data and scale.

        The quantization runs on the CUDA device (``loaded_weight.cuda()``).
        """
        qweight, weight_scale = rtn_quantize(
            loaded_weight.cuda(),
            self.quant_config.weight_bits,
            self.quant_config.group_size,
        )

        self.data.copy_(qweight)
        self.scale.data.copy_(weight_scale)


class RTNParameter(Parameter):
    """Parameter whose ``.data`` accessor yields an RTNTensor.

    This wrapper only matters while weights are being loaded: handing out an
    RTNTensor lets the weight-copy call (torch.Tensor.copy_) be intercepted
    so incoming weights are quantized on the fly.
    """

    def __new__(cls, data: torch.Tensor, **kwargs):
        # scale / quant_config arrive via **kwargs and are stored by __init__.
        instance = super().__new__(cls, data=data, requires_grad=False)
        return instance

    def __init__(
        self, data: torch.Tensor, scale: torch.Tensor, quant_config: RTNConfig
    ) -> None:
        # The storage itself was set up in __new__; remember its companions.
        self.scale, self.quant_config = scale, quant_config

    @property
    def data(self):
        raw = super().data
        return RTNTensor(raw, self.scale, self.quant_config)


class RTNLinearMethod(LinearMethodBase):
    """Linear method for RTN.

    Weights are quantized on the fly during loading (through RTNParameter)
    and then repacked into the Marlin layout by process_weights_after_loading.

    Args:
        quant_config: The RTN quantization config.
    """

    def __init__(self, quant_config: RTNConfig):
        self.quant_config = quant_config

    def create_weights(
        self,
        layer: torch.nn.Module,
        input_size_per_partition: int,
        output_partition_sizes: list[int],
        input_size: int,
        output_size: int,
        params_dtype: torch.dtype,
        **extra_weight_attrs,
    ):
        """Register the quantized ``weight`` and its per-group ``scale``.

        ``weight`` is uint8 with shape (output // factor, input), where factor
        is 2 for 4-bit (two values per byte) and 1 for 8-bit. ``scale`` has
        shape (output, input // group_size), or (output, 1) for group_size -1.
        """
        output_size_per_partition = sum(output_partition_sizes)

        num_groups_per_col = (
            input_size_per_partition // self.quant_config.group_size
            if self.quant_config.group_size != -1
            else 1
        )

        scale = Parameter(
            torch.empty(
                output_size_per_partition, num_groups_per_col, dtype=params_dtype
            ),
            requires_grad=False,
        )
        factor = 1 if self.quant_config.weight_bits == 8 else 2

        weight = RTNParameter(
            data=torch.empty(
                output_size_per_partition // factor,
                input_size_per_partition,
                dtype=torch.uint8,
            ),
            scale=scale,
            quant_config=self.quant_config,
        )

        layer.register_parameter("weight", weight)
        set_weight_attrs(
            weight,
            {
                **extra_weight_attrs,
                "input_dim": 1,
                "output_dim": 0,
            },
        )

        layer.register_parameter("scale", scale)
        layer.output_size_per_partition = output_size_per_partition
        # apply() reads both partition sizes from the layer; record the input
        # size here as well rather than relying on the layer class to set it.
        layer.input_size_per_partition = input_size_per_partition

    def process_weights_after_loading(self, layer: torch.nn.Module) -> None:
        """Repack weights and scales for Marlin kernels."""
        weight_bits = self.quant_config.weight_bits

        weight, scale = repack_weights(layer.weight, layer.scale, weight_bits)

        replace_parameter(layer, "weight", weight)
        replace_parameter(layer, "scale", scale)

        init_workspace(layer.weight.device)

    def apply(
        self,
        layer: torch.nn.Module,
        x: torch.Tensor,
        bias: torch.Tensor | None = None,
    ) -> torch.Tensor:
        """Run the Marlin RTN GEMM using the shared module-level workspace."""
        return apply_rtn_marlin_linear(
            input=x,
            weight=layer.weight,
            weight_scale=layer.scale,
            workspace=workspace,
            quant_type=self.quant_config.quant_type,
            output_size_per_partition=layer.output_size_per_partition,
            input_size_per_partition=layer.input_size_per_partition,
            bias=bias,
        )
259
260


261
class RTNMoEMethod(FusedMoEMethodBase):
    """Fused-MoE method for RTN, executed with the fused Marlin MoE kernel.

    Args:
        quant_config: The RTN quantization config.
        moe: The fused-MoE layer configuration.
    """

    def __init__(self, quant_config: RTNConfig, moe: FusedMoEConfig):
        super().__init__(moe)
        self.quant_config = quant_config

    def create_weights(
        self,
        layer: torch.nn.Module,
        num_experts: int,
        hidden_size: int,
        intermediate_size_per_partition: int,
        params_dtype: torch.dtype,
        **extra_weight_attrs,
    ):
        """Register quantized expert weights and their per-group scales.

        For 4-bit quantization the second axis of each uint8 weight is halved
        (two values per byte); scales keep the unpacked size on that axis.
        """
        factor = 1 if self.quant_config.weight_bits == 8 else 2

        # Fused gate_up_proj (column parallel)
        num_groups_per_col = (
            hidden_size // self.quant_config.group_size
            if self.quant_config.group_size != -1
            else 1
        )
        w13_scale = Parameter(
            torch.empty(
                num_experts,
                2 * intermediate_size_per_partition,
                num_groups_per_col,
                dtype=params_dtype,
            ),
            requires_grad=False,
        )
        layer.register_parameter("w13_scale", w13_scale)

        w13_weight = RTNParameter(
            data=torch.empty(
                num_experts,
                2 * intermediate_size_per_partition // factor,
                hidden_size,
                dtype=torch.uint8,
            ),
            scale=w13_scale,
            quant_config=self.quant_config,
        )
        layer.register_parameter("w13_weight", w13_weight)
        set_weight_attrs(w13_weight, extra_weight_attrs)

        # down_proj (row parallel)
        num_groups_per_col = (
            intermediate_size_per_partition // self.quant_config.group_size
            if self.quant_config.group_size != -1
            else 1
        )
        w2_scale = Parameter(
            torch.zeros(
                num_experts, hidden_size, num_groups_per_col, dtype=params_dtype
            ),
            requires_grad=False,
        )
        layer.register_parameter("w2_scale", w2_scale)

        w2_weight = RTNParameter(
            data=torch.empty(
                num_experts,
                hidden_size // factor,
                intermediate_size_per_partition,
                dtype=torch.uint8,
            ),
            scale=w2_scale,
            quant_config=self.quant_config,
        )
        layer.register_parameter("w2_weight", w2_weight)
        set_weight_attrs(w2_weight, extra_weight_attrs)

    def process_weights_after_loading(self, layer: torch.nn.Module) -> None:
        """Repack weights and scales for Marlin kernels."""
        weight_bits = self.quant_config.weight_bits

        w13_weight, w13_scale = repack_weights(
            layer.w13_weight, layer.w13_scale, weight_bits
        )
        replace_parameter(layer, "w13_weight", w13_weight)
        replace_parameter(layer, "w13_scale", w13_scale)

        w2_weight, w2_scale = repack_weights(
            layer.w2_weight, layer.w2_scale, weight_bits
        )
        replace_parameter(layer, "w2_weight", w2_weight)
        replace_parameter(layer, "w2_scale", w2_scale)

        init_workspace(layer.w13_weight.device)

    def get_fused_moe_quant_config(
        self, layer: torch.nn.Module
    ) -> FusedMoEQuantConfig | None:
        return None

    def apply(
        self,
        layer: torch.nn.Module,
        x: torch.Tensor,
        router_logits: torch.Tensor,
        top_k: int,
        renormalize: bool,
        use_grouped_topk: bool = False,
        topk_group: int | None = None,
        num_expert_group: int | None = None,
        global_num_experts: int = -1,
        expert_map: torch.Tensor | None = None,
        custom_routing_function: Callable | None = None,
        scoring_func: str = "softmax",
        routed_scaling_factor: float = 1.0,
        e_score_correction_bias: torch.Tensor | None = None,
        apply_router_weight_on_input: bool = False,
        activation: str = "silu",
        enable_eplb: bool = False,
        expert_load_view: torch.Tensor | None = None,
        logical_to_physical_map: torch.Tensor | None = None,
        logical_replica_count: torch.Tensor | None = None,
    ) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor]:
        """Route tokens to experts and run the fused Marlin MoE kernel.

        Raises:
            NotImplementedError: If EPLB is enabled, or if ``activation`` is
                anything other than "silu" (the activation is not forwarded to
                the kernel call below, so any other value would be silently
                replaced).
        """
        if enable_eplb:
            raise NotImplementedError("EPLB not supported for `RTNMoEMethod` yet.")
        if activation != "silu":
            raise NotImplementedError(
                f"Activation {activation!r} is not supported for `RTNMoEMethod`; "
                "only 'silu' is supported."
            )

        topk_weights, topk_ids, _ = FusedMoE.select_experts(
            hidden_states=x,
            router_logits=router_logits,
            use_grouped_topk=use_grouped_topk,
            top_k=top_k,
            renormalize=renormalize,
            topk_group=topk_group,
            num_expert_group=num_expert_group,
            custom_routing_function=custom_routing_function,
            scoring_func=scoring_func,
            routed_scaling_factor=routed_scaling_factor,
            e_score_correction_bias=e_score_correction_bias,
            indices_type=self.topk_indices_dtype,
        )

        return fused_marlin_moe(
            x,
            layer.w13_weight,
            layer.w2_weight,
            getattr(layer, "w13_bias", None),
            getattr(layer, "w2_bias", None),
            layer.w13_scale,
            layer.w2_scale,
            router_logits,
            topk_weights,
            topk_ids,
            quant_type_id=self.quant_config.quant_type.id,
            apply_router_weight_on_input=apply_router_weight_on_input,
            global_num_experts=global_num_experts,
            expert_map=expert_map,
            workspace=workspace,
        )
415
416


417
418
419
def rtn_quantize(
    tensor: torch.Tensor, num_bits: int, group_size: int
) -> tuple[torch.Tensor, torch.Tensor]:
    """Quantize a tensor using per-group static scaling factor.

    Args:
        tensor: The input tensor, 2-D or 3-D (leading batch axis).
        num_bits: Target precision for the result (supported values are
                  8 or 4).
        group_size: Quantization granularity.
                    If equal to -1, each row in the input tensor is treated
                    as one group.

    Returns:
        ``(qweight, scale)``. ``qweight`` is uint8 holding values offset by
        ``2**num_bits // 2``. For 4-bit, two values are packed per byte and
        the second-to-last axis is halved. ``scale`` holds one symmetric scale
        per group, shaped (rows, groups_per_row) plus the batch axis if present.
        All-zero groups get scale 0 and quantize to the zero point.
    """
    batch_present = len(tensor.shape) == 3
    if not batch_present:
        tensor = tensor.unsqueeze(0)

    q_range = 2**num_bits

    num_groups = (
        tensor.shape[1] * tensor.shape[2] // group_size
        if group_size != -1
        else tensor.shape[1]
    )

    # Symmetric per-group scale: the largest magnitude in each group maps to
    # the edge of the quantized range.
    input_flat = tensor.reshape(tensor.shape[0], num_groups, -1)
    input_min = torch.min(input_flat, dim=2, keepdim=True)[0]
    input_max = torch.max(input_flat, dim=2, keepdim=True)[0]
    input_max_abs = torch.max(input_min.abs(), input_max.abs())
    scale = input_max_abs * 2.0 / (q_range - 1)

    # A zero scale only arises for an all-zero group; dividing by it gives
    # 0/0 = NaN, whose uint8 conversion is undefined. Divide by 1 there so the
    # group lands on the zero point; the stored scale stays 0, so it still
    # dequantizes to exactly 0.
    safe_scale = torch.where(scale == 0, torch.ones_like(scale), scale)

    # Scale each group, round to the nearest integer, shift into the
    # unsigned range and clamp.
    scaled_input = (input_flat / safe_scale).round()
    scaled_input += q_range // 2
    scaled_input = scaled_input.clamp(0, q_range - 1)

    scale = scale.reshape(tensor.shape[0], tensor.shape[1], -1).contiguous()
    inputs_q = scaled_input.reshape(tensor.shape).to(torch.uint8).contiguous()

    if num_bits == 4:
        # Pack two 4-bit values per byte: even column -> low nibble, odd
        # column -> high nibble; then view the result with half the rows.
        inputs_q = (inputs_q[:, :, 1::2] << 4) | (inputs_q[:, :, ::2] & 0xF)
        inputs_q = inputs_q.reshape(
            tensor.shape[0], tensor.shape[1] // 2, tensor.shape[2]
        ).contiguous()

    if not batch_present:
        inputs_q = inputs_q.squeeze(0)
        scale = scale.squeeze(0)

    return inputs_q, scale


def rtn_dequantize(tensor: torch.Tensor, scale: torch.Tensor) -> torch.Tensor:
    """Reconstruct floating-point values from RTN-quantized data.

    Args:
        tensor: Quantized data, 2-D or 3-D with a leading batch axis.
        scale: The tensor with per-group scale factors.

    The bit width is inferred from the shapes. If the second axis of
    ``tensor`` matches the second axis of ``scale``, the data is treated as
    8-bit. Otherwise it is treated as 4-bit with two values per byte. The
    result has the dtype of ``scale``.
    """
    squeeze_back = len(tensor.shape) != 3
    if squeeze_back:
        tensor = tensor.unsqueeze(0)
        scale = scale.unsqueeze(0)

    group_count = scale.size(1) * scale.size(2)
    batch, rows, cols = tensor.shape

    is_8bit = rows == scale.size(1)
    half_range = (2**8 if is_8bit else 2**4) // 2
    if not is_8bit:
        rows *= 2

    out = torch.empty((batch, rows, cols), dtype=scale.dtype, device=tensor.device)

    if is_8bit:
        out.copy_(tensor)
        out -= half_range
    else:
        # Split each byte into two values; for uint8 input lane 0 is the low
        # nibble and lane 1 the high nibble.
        nibbles = tensor.reshape(batch, rows, cols // 2)
        for lane in range(2):
            out[:, :, lane::2] = (
                (nibbles << 4 * (1 - lane)) >> 4
            ).to(torch.int8) - half_range

    # Multiply every group by its own scale factor.
    grouped = torch.mul(
        out.reshape(batch, group_count, -1), scale.reshape(batch, group_count, -1)
    )

    result = grouped.reshape((batch, rows, cols)).contiguous()
    return result.squeeze(0) if squeeze_back else result
521
522


523
524
525
526
527
528
529
530
531
532
533
534
535
536
537
538
539
540
541
542
543
544
545
546
547
548
549
550
551
552
553
554
555
556
557
558
559
560
561
562
563
564
565
566
567
568
569
570
571
572
573
574
575
576
577
578
579
580
581
582
583
584
585
586
587
588
589
590
591
592
593
594
595
596
597
598
599
600
601
602
603
604
605
606
607
608
609
610
611
612
613
614
615
616
617
618
619
620
621
622
623
624
625
626
627
628
629
630
631
632
633
634
635
636
def _get_perms():
    """Build the Marlin weight and scale permutation tables.

    Returns:
        A tuple ``(perm_tensor, scale_perm, scale_perm_single)``:
        ``perm_tensor`` is a 1024-entry index tensor applied to flattened
        weight tiles; ``scale_perm`` (64 entries) is used for grouped scales and
        ``scale_perm_single`` (32 entries) for per-row (single group) scales.
    """
    weight_perm = []
    for lane in range(32):
        column = lane // 4
        low = 2 * (lane % 4)
        high = 2 * (lane % 4 + 4)
        row_ids = (low, low + 1, high, high + 1)
        lane_perm = [16 * r + column + 8 * blk for blk in (0, 1) for r in row_ids]
        for offset in range(0, 4 * 256, 256):
            weight_perm.extend(idx + offset for idx in lane_perm)

    # Reorder within each run of 8 so even positions come first.
    interleave = np.array([0, 2, 4, 6, 1, 3, 5, 7])
    reordered = np.array(weight_perm).reshape((-1, 8))[:, interleave].ravel()

    scale_perm = [i + 8 * j for i in range(8) for j in range(8)]
    scale_perm_single = [
        2 * i + j for i in range(4) for j in (0, 1, 8, 9, 16, 17, 24, 25)
    ]
    return torch.from_numpy(reordered), scale_perm, scale_perm_single


# Computed once at import and shared by every pack_for_marlin call.
(_perm, _scale_perm, _scale_perm_single) = _get_perms()


def pack_for_marlin(weight, scale, qbits):
    """Permute unpacked quantized weights and scales into Marlin's layout.

    Args:
        weight: Unpacked quantized values of shape (batch, n, k), one value
            per element (presumably uint8 from repack_weights — confirm).
        scale: Per-group scales of shape (batch, n, k // groupsize).
        qbits: 4 or 8; selects how permuted values are packed into bytes.

    Returns:
        ``(q, s)``: ``q`` is an int8 tensor reshaped to (batch, -1, n) holding
        the permuted values (two per byte for 4-bit, byte-swizzled for 8-bit);
        ``s`` holds the permuted scales, shaped (batch, -1, n).
    """
    batch = weight.shape[0]

    n = weight.size(1)
    k = weight.size(2)
    groupsize = k // scale.size(2)

    tile = 16
    # Work in (k, n) orientation for both tensors.
    s = scale.permute(0, 2, 1)
    w = weight.permute(0, 2, 1)

    # NOTE: a previous version reshaped `w` into per-group blocks and then
    # applied the exact inverse reshape/permute; that round trip is the
    # identity, so it is omitted. The scale permutation only depends on the
    # row-major order of `s`, which those reshapes did not change either.
    if groupsize != k:
        s = s.reshape((batch, -1, len(_scale_perm)))[:, :, _scale_perm]
    else:
        s = s.reshape((batch, -1, len(_scale_perm_single)))[:, :, _scale_perm_single]
    s = s.reshape((batch, -1, n)).contiguous()

    # Cut into 16x16 tiles, order them row-major and flatten each tile.
    w = w.reshape((batch, k // tile, tile, n // tile, tile))
    w = w.permute((0, 1, 3, 2, 4))
    w = w.reshape((batch, k // tile, n * tile))
    res = w.reshape((batch, -1, _perm.numel()))[:, :, _perm].reshape(w.shape)
    if qbits == 4:
        # Two 4-bit values per byte: even position -> low nibble.
        q = torch.zeros(
            (batch, res.shape[1], res.shape[2] // 2), dtype=torch.int8, device=w.device
        )
        for i in range(2):
            q |= res[:, :, i::2] << 4 * i
        q = q.reshape(batch, -1, n).contiguous()
    else:
        # Swap positions 2,3 with 4,5 inside every run of 8 bytes.
        q = res.clone()
        q[:, :, 2::8] = res[:, :, 4::8]
        q[:, :, 3::8] = res[:, :, 5::8]
        q[:, :, 4::8] = res[:, :, 2::8]
        q[:, :, 5::8] = res[:, :, 3::8]
        q = q.reshape(batch, -1, n).to(torch.int8).contiguous()

    return q, s


def repack_8bit_into_32bit(input):
    """Pack every four consecutive bytes on the last axis into one int32.

    The byte at last-axis position ``4 * m + b`` lands in bits ``8b..8b+7``
    of output word ``m`` (little-endian byte order). The last axis shrinks by
    a factor of four; the first two axes are unchanged.
    """
    packed_shape = (input.shape[0], input.shape[1], input.shape[2] // 4)
    packed = torch.zeros(packed_shape, dtype=torch.int32, device=input.device)
    for byte_idx in range(4):
        lane = input[:, :, byte_idx::4] & 0xFF
        packed |= lane.to(torch.int32) << (8 * byte_idx)
    return packed


def repack_weights(qweight, scale, weight_bits):
    """Convert RTN-quantized weights and scales into the Marlin layout.

    Args:
        qweight: uint8 weights as produced by rtn_quantize, 2-D or 3-D.
        scale: Matching per-group scales.
        weight_bits: 4 or 8.

    Returns:
        ``(weight, scale)`` ready for the Marlin kernels: the weight is int32
        with shape (batch, k // 16, -1) and the batch axis is dropped again
        when the inputs were 2-D.
    """
    squeeze_back = len(qweight.shape) != 3
    if squeeze_back:
        qweight = qweight.unsqueeze(0)
        scale = scale.unsqueeze(0)

    if weight_bits != 4:
        unpacked = qweight
    else:
        # Undo the 4-bit packing: each byte holds two values.
        doubled_rows = qweight.shape[1] * 2
        unpacked = torch.empty(
            (qweight.shape[0], doubled_rows, qweight.shape[2]),
            dtype=torch.uint8,
            device=qweight.device,
        )
        half_cols = (qweight.shape[0], doubled_rows, qweight.shape[2] // 2)
        for nibble in range(2):
            unpacked[:, :, nibble::2] = (
                (qweight << 4 * (1 - nibble)) >> 4
            ).reshape(half_cols)

    marlin_q, marlin_scale = pack_for_marlin(unpacked, scale, weight_bits)

    # Marlin kernels take int32 words (four packed bytes each) shaped as
    # (batch, k // 16, -1).
    words = repack_8bit_into_32bit(marlin_q.to(torch.uint8))
    marlin_weight = words.reshape(qweight.shape[0], qweight.shape[2] // 16, -1)

    if squeeze_back:
        marlin_weight = marlin_weight.squeeze(0)
        marlin_scale = marlin_scale.squeeze(0)

    return marlin_weight, marlin_scale
647
648


649
650
651
652
def init_workspace(device):
    """Create the module-wide Marlin workspace on ``device`` if not yet done.

    Only the first call allocates; later calls return immediately, so the
    workspace stays on whichever device first requested it.
    """
    global workspace
    if workspace is not None:
        return
    workspace = marlin_make_workspace_new(device, 4)