# w4afp8.py
from __future__ import annotations

import logging
from typing import TYPE_CHECKING, Any, Dict, List, Optional

import torch
from torch.nn import Module
from torch.nn.parameter import Parameter

from sglang.srt.layers.quantization.base_config import (
    FusedMoEMethodBase,
    QuantizationConfig,
    QuantizeMethodBase,
)
from sglang.srt.layers.quantization.fp8 import Fp8LinearMethod
from sglang.srt.layers.quantization.unquant import UnquantizedLinearMethod
from sglang.srt.layers.quantization.utils import is_layer_skipped
from sglang.srt.utils import set_weight_attrs

if TYPE_CHECKING:
    from sglang.srt.layers.moe.ep_moe.layer import EPMoE, TopKOutput

# Activation quantization schemes accepted by W4AFp8Config.
ACTIVATION_SCHEMES = ["static", "dynamic"]

logger = logging.getLogger(__name__)


class W4AFp8Config(QuantizationConfig):
    """Config class for MIXED_PRECISION W4AFp8.

    Linear layers are handled by ``Fp8LinearMethod``, or left unquantized
    when ``is_layer_skipped`` matches them against ``ignored_layers``.
    ``EPMoE`` layers are handled by ``W4AFp8MoEMethod``.
    """

    def __init__(
        self,
        is_checkpoint_fp8_serialized: bool = True,
        is_checkpoint_w4afp8_serialized: bool = True,
        linear_activation_scheme: str = "dynamic",
        moe_activation_scheme: str = "static",
        ignored_layers: Optional[List[str]] = None,
        weight_block_size: Optional[List[int]] = None,
        group_size: int = 128,
    ) -> None:
        """Build a W4AFp8 config.

        Args:
            is_checkpoint_fp8_serialized: Whether the checkpoint is FP8-serialized.
            is_checkpoint_w4afp8_serialized: Whether the checkpoint uses the
                w4afp8 format; a warning is logged when it does.
            linear_activation_scheme: Activation scheme for linear layers; must
                be one of ``ACTIVATION_SCHEMES``.
            moe_activation_scheme: Activation scheme for MoE layers; must be
                one of ``ACTIVATION_SCHEMES``.
            ignored_layers: Names checked by ``is_layer_skipped``; matching
                linear layers stay unquantized. ``None`` means no layers.
            weight_block_size: Stored as ``self.weight_block_size``; defaults
                to ``[128, 128]`` when ``None``.
            group_size: Number of weight elements sharing one scale in the
                MoE layers (divides the last weight dimension when the MoE
                scale tensors are sized).

        Raises:
            ValueError: If either activation scheme is not supported.
        """
        super().__init__()
        self.is_checkpoint_fp8_serialized = is_checkpoint_fp8_serialized
        self.is_checkpoint_w4afp8_serialized = is_checkpoint_w4afp8_serialized
        if is_checkpoint_w4afp8_serialized:
            # The original message was truncated mid-sentence.
            logger.warning(
                "Detected w4afp8 checkpoint. Please note that the format is "
                "experimental and subject to change."
            )
        if moe_activation_scheme not in ACTIVATION_SCHEMES:
            raise ValueError(f"Unsupported activation scheme {moe_activation_scheme}")
        # Validate the linear scheme the same way as the MoE scheme.
        if linear_activation_scheme not in ACTIVATION_SCHEMES:
            raise ValueError(
                f"Unsupported activation scheme {linear_activation_scheme}"
            )
        self.linear_activation_scheme = linear_activation_scheme
        self.moe_activation_scheme = moe_activation_scheme
        self.ignored_layers = ignored_layers or []
        # Honour the caller's block size (previously ignored); keep the old
        # [128, 128] default. Copy so the caller's list is not aliased.
        self.weight_block_size = (
            [128, 128] if weight_block_size is None else list(weight_block_size)
        )
        self.group_size = group_size

    @classmethod
    def get_name(cls) -> str:
        """Return the registry name of this quantization method."""
        return "w4afp8"

    @classmethod
    def get_supported_act_dtypes(cls) -> List[torch.dtype]:
        """Return the activation dtypes this config accepts."""
        return [torch.bfloat16, torch.float8_e4m3fn]

    @classmethod
    def get_min_capability(cls) -> int:
        """Return the minimum GPU capability (90, i.e. compute capability 9.0)."""
        return 90

    @classmethod
    def get_config_filenames(cls) -> List[str]:
        """Return extra config filenames to read; none are needed."""
        return []

    @classmethod
    def from_config(cls, config: Dict[str, Any]) -> W4AFp8Config:
        """Build a config from a checkpoint quantization-config dict.

        Only ``quant_method`` is read from ``config``. The FP8 / w4afp8 flags
        are set by substring match on it. The activation schemes and
        weight block size are fixed.
        """
        quant_method = cls.get_from_keys(config, ["quant_method"])
        return cls(
            # "w4afp8" itself contains "fp8", so both flags are set for it.
            is_checkpoint_fp8_serialized="fp8" in quant_method,
            is_checkpoint_w4afp8_serialized="w4afp8" in quant_method,
            linear_activation_scheme="dynamic",
            moe_activation_scheme="static",
            weight_block_size=[128, 128],
        )

    def get_quant_method(
        self, layer: torch.nn.Module, prefix: str
    ) -> Optional[QuantizeMethodBase]:
        """Pick the quantization method for ``layer``.

        Returns:
            ``UnquantizedLinearMethod`` for linear layers skipped via
            ``ignored_layers``, ``Fp8LinearMethod`` for other linear layers,
            ``W4AFp8MoEMethod`` for ``EPMoE`` layers, and ``None`` otherwise.
        """
        # Imported lazily (presumably to avoid an import cycle with the layer
        # modules -- confirm before hoisting to module level).
        from sglang.srt.layers.linear import LinearBase
        from sglang.srt.layers.moe.ep_moe.layer import EPMoE

        if isinstance(layer, LinearBase):
            if is_layer_skipped(prefix, self.ignored_layers):
                return UnquantizedLinearMethod()
            return Fp8LinearMethod(self)
        if isinstance(layer, EPMoE):
            return W4AFp8MoEMethod(self)
        return None

    def get_scaled_act_names(self) -> List[str]:
        """Return names of activations with separate scales; none here."""
        return []


class W4AFp8MoEMethod(FusedMoEMethodBase):
    """W4AFp8 quantization method for ``EPMoE`` layers.

    Expert weights live in ``torch.int8`` tensors whose last dimension is half
    the logical one. Presumably two 4-bit values are packed per byte; confirm
    against the checkpoint format. They come with per-group float scales and
    per-expert bf16 input scales. The forward pass is delegated to
    ``cutlass_w4a8_moe``.
    """

    def __init__(self, quant_config: W4AFp8Config):
        self.quant_config = quant_config

    def create_weights(
        self,
        layer: EPMoE,
        num_experts: int,
        hidden_size: int,
        intermediate_size: int,
        params_dtype: torch.dtype,
        **extra_weight_attrs,
    ):
        """Register weights/scales on ``layer`` and pre-allocate kernel buffers.

        Args:
            layer: MoE layer that receives the parameters.
            num_experts: Number of experts held by ``layer``.
            hidden_size: Hidden dimension.
            intermediate_size: Per-expert intermediate dimension; the fused
                w13 weight has ``2 * intermediate_size`` rows.
            params_dtype: Unused; every tensor here has a fixed dtype.
            **extra_weight_attrs: Attributes attached to every registered
                parameter; must contain ``"weight_loader"``.
        """
        assert "weight_loader" in extra_weight_attrs
        group_size = self.quant_config.group_size

        def register(name: str, data: torch.Tensor) -> None:
            # Wrap as a frozen parameter, attach it, then tag it for loading.
            param = torch.nn.Parameter(data, requires_grad=False)
            layer.register_parameter(name, param)
            set_weight_attrs(param, extra_weight_attrs)

        # Fused gate_up_proj (column parallel)
        register(
            "w13_weight",
            torch.empty(
                num_experts,
                intermediate_size * 2,
                hidden_size // 2,
                dtype=torch.int8,
            ),
        )
        # down_proj (row parallel)
        register(
            "w2_weight",
            torch.empty(
                num_experts,
                hidden_size,
                intermediate_size // 2,
                dtype=torch.int8,
            ),
        )
        # One float32 scale per `group_size` elements of each weight row.
        register(
            "w13_weight_scale_inv",
            torch.zeros(
                num_experts,
                2 * intermediate_size,
                hidden_size // group_size,
                dtype=torch.float32,
            ),
        )
        register(
            "w2_weight_scale_inv",
            torch.zeros(
                num_experts,
                hidden_size,
                intermediate_size // group_size,
                dtype=torch.float32,
            ),
        )
        # Input scales: two per expert for w13, one per expert for w2.
        register("w13_input_scale", torch.ones((num_experts, 2), dtype=torch.bfloat16))
        register("w2_input_scale", torch.ones(num_experts, dtype=torch.bfloat16))

        # Pre-populate the strides on the same device as the weights.
        device = layer.w13_weight.device

        def strides(value: int) -> torch.Tensor:
            return torch.full(
                (num_experts, 3), value, device=device, dtype=torch.int64
            )

        self.a_strides1 = strides(hidden_size)
        self.c_strides1 = strides(2 * intermediate_size)
        self.a_strides2 = strides(intermediate_size)
        self.c_strides2 = strides(hidden_size)
        # B and scale strides are the very same tensor objects as A and C.
        self.b_strides1 = self.a_strides1
        self.s_strides13 = self.c_strides1
        self.b_strides2 = self.a_strides2
        self.s_strides2 = self.c_strides2

        # Uninitialised int32 buffers handed to cutlass_w4a8_moe in apply()
        # (presumably filled there per call -- confirm in the kernel wrapper).
        self.expert_offsets = torch.empty(
            num_experts + 1, dtype=torch.int32, device=device
        )
        self.problem_sizes1 = torch.empty(
            (num_experts, 3), dtype=torch.int32, device=device
        )
        self.problem_sizes2 = torch.empty(
            (num_experts, 3), dtype=torch.int32, device=device
        )

    def _interleave_scales(self, scales: torch.Tensor) -> torch.Tensor:
        """Interleave scales in groups of 4 similar to TRT-LLM implementation.

        For ``scales`` of shape ``(E, R, C)`` with ``C`` divisible by 4, this
        returns a contiguous ``(E, C // 4, R * 4)`` tensor. Row ``j`` of
        expert ``e`` concatenates ``scales[e, r, 4*j:4*j+4]`` for ``r`` in
        ``0..R-1``.
        """
        experts, rows, cols = scales.shape[:3]
        # (E, R, C) -> (E, R, C//4, 4): split the last dim into groups of 4.
        grouped = scales.reshape(experts, rows, cols // 4, 4)
        # (E, R, C//4, 4) -> (E, C//4, R, 4): bring groups in front of rows.
        grouped = grouped.permute(0, 2, 1, 3)
        # Flatten rows x 4 into one dimension.
        return grouped.reshape(experts, cols // 4, rows * 4).contiguous()

    def process_weights_after_loading(self, layer: Module) -> None:
        """Convert loaded scales into the layout consumed by ``apply``.

        Weight scales are cast to bf16 and interleaved. Each per-expert input
        scale tensor is collapsed into a single-element bf16 tensor holding
        its maximum.
        """
        dtype = torch.bfloat16
        device = layer.w2_weight.device

        # Interleave w13 (gate_up_proj) then w2 (down_proj) weight scales.
        for name in ("w13_weight_scale_inv", "w2_weight_scale_inv"):
            interleaved = self._interleave_scales(getattr(layer, name).to(dtype))
            setattr(layer, name, Parameter(interleaved, requires_grad=False))

        # Replace each input scale by its max over all experts.
        for name in ("w13_input_scale", "w2_input_scale"):
            scale_max = getattr(layer, name).max().to(dtype).item()
            collapsed = torch.tensor([scale_max], dtype=dtype, device=device)
            setattr(layer, name, Parameter(collapsed, requires_grad=False))

    def apply(
        self,
        layer: EPMoE,
        hidden_states: torch.Tensor,
        topk_output: TopKOutput,
        **kwargs,
    ) -> torch.Tensor:
        """Run the quantized MoE forward via ``cutlass_w4a8_moe``.

        ``topk_output`` is unpacked as ``(topk_weights, topk_ids, _)``. When
        ``layer.expert_map`` is set, ``topk_ids`` are also translated through
        it into ``local_topk_ids``. Extra ``kwargs`` are ignored.
        """
        # TODO(ch-wan): move it out of this class
        from sglang.srt.layers.moe.cutlass_w4a8_moe import cutlass_w4a8_moe

        topk_weights, topk_ids, _ = topk_output

        local_topk_ids = topk_ids
        if layer.expert_map is not None:
            # Translate info from expert_map to topk_ids. The previous
            # torch.where(m != num_experts, m, num_experts) wrapper returned m
            # unchanged everywhere, so a single gather is equivalent.
            local_topk_ids = layer.expert_map[topk_ids]

        return cutlass_w4a8_moe(
            layer.start_expert_id,
            layer.end_expert_id,
            layer.num_experts,
            hidden_states,
            layer.w13_weight,
            layer.w2_weight,
            layer.w13_weight_scale_inv,
            layer.w2_weight_scale_inv,
            topk_weights,
            topk_ids,
            local_topk_ids,
            self.a_strides1,
            self.b_strides1,
            self.c_strides1,
            self.a_strides2,
            self.b_strides2,
            self.c_strides2,
            self.s_strides13,
            self.s_strides2,
            self.expert_offsets,
            self.problem_sizes1,
            self.problem_sizes2,
            layer.w13_input_scale,
            layer.w2_input_scale,
        )