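"""Mixed-precision W4AFp8 quantization: fp8 linear layers plus MoE experts with
int4 weights and fp8 activations."""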
from __future__ import annotations

import logging
from typing import TYPE_CHECKING, Any, Dict, List, Optional

import torch
from torch.nn import Module
from torch.nn.parameter import Parameter

from sglang.srt.distributed.parallel_state import get_moe_expert_parallel_world_size
from sglang.srt.layers.quantization.base_config import (
    FusedMoEMethodBase,
    QuantizationConfig,
    QuantizeMethodBase,
)
from sglang.srt.layers.quantization.fp8 import Fp8LinearMethod
from sglang.srt.layers.quantization.unquant import UnquantizedLinearMethod
from sglang.srt.layers.quantization.utils import is_layer_skipped
from sglang.srt.utils import set_weight_attrs

if TYPE_CHECKING:
    from sglang.srt.layers.moe import MoeRunnerConfig
    from sglang.srt.layers.moe.ep_moe.layer import EPMoE
    from sglang.srt.layers.moe.topk import StandardTopKOutput

ACTIVATION_SCHEMES = ["static", "dynamic"]
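# "dynamic" derives activation scales at runtime; "static" uses scales stored in
# the checkpoint (the MoE path below defaults to "static").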

logger = logging.getLogger(__name__)


class W4AFp8Config(QuantizationConfig):
    """Config class for MIXED_PRECISION W4AFp8."""

    def __init__(
        self,
        is_checkpoint_fp8_serialized: bool = True,
        is_checkpoint_w4afp8_serialized: bool = True,
        linear_activation_scheme: str = "dynamic",
        moe_activation_scheme: str = "static",
        ignored_layers: Optional[List[str]] = None,
        weight_block_size: Optional[List[int]] = None,
        group_size: int = 128,
    ) -> None:
        super().__init__()
        self.is_checkpoint_fp8_serialized = is_checkpoint_fp8_serialized
        self.is_checkpoint_w4afp8_serialized = is_checkpoint_w4afp8_serialized
        if is_checkpoint_w4afp8_serialized:
            logger.warning("Detected w4afp8 checkpoint. Please note that")
        if moe_activation_scheme not in ACTIVATION_SCHEMES:
            raise ValueError(f"Unsupported activation scheme {moe_activation_scheme}")
        self.linear_activation_scheme = linear_activation_scheme
        self.moe_activation_scheme = moe_activation_scheme
        self.ignored_layers = ignored_layers or []
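        # Note: the block size is currently hard-coded; the weight_block_size
        # argument is accepted but not used.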
        self.weight_block_size = [128, 128]
        self.group_size = group_size

    @classmethod
    def get_name(cls) -> str:
        return "w4afp8"

    @classmethod
    def get_supported_act_dtypes(cls) -> List[torch.dtype]:
        return [torch.bfloat16, torch.float8_e4m3fn]

    @classmethod
    def get_min_capability(cls) -> int:
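        # Compute capability 9.0 (Hopper) or newer.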
        return 90

    @classmethod
    def get_config_filenames(cls) -> List[str]:
        return []

    @classmethod
    def from_config(cls, config: Dict[str, Any]) -> W4AFp8Config:
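        """Build the config from a checkpoint's quantization metadata.

        Only the "quant_method" string is inspected; for example, a value
        containing "w4afp8" selects the mixed int4-weight / fp8-activation path.
        """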
        quant_method = cls.get_from_keys(config, ["quant_method"])
        is_checkpoint_fp8_serialized = "fp8" in quant_method
        is_checkpoint_w4afp8_serialized = "w4afp8" in quant_method
        linear_activation_scheme = "dynamic"
        moe_activation_scheme = "static"
        weight_block_size = [128, 128]
        return cls(
            is_checkpoint_fp8_serialized=is_checkpoint_fp8_serialized,
            is_checkpoint_w4afp8_serialized=is_checkpoint_w4afp8_serialized,
            linear_activation_scheme=linear_activation_scheme,
            moe_activation_scheme=moe_activation_scheme,
            weight_block_size=weight_block_size,
        )

    def get_quant_method(
        self, layer: torch.nn.Module, prefix: str
    ) -> Optional[QuantizeMethodBase]:
        from sglang.srt.layers.linear import LinearBase
        from sglang.srt.layers.moe.fused_moe_triton import FusedMoE

        if isinstance(layer, LinearBase):
            if is_layer_skipped(prefix, self.ignored_layers):
                return UnquantizedLinearMethod()
            return Fp8LinearMethod(self)
        elif isinstance(layer, FusedMoE):
            return W4AFp8MoEMethod(self)
        return None

    def get_scaled_act_names(self) -> List[str]:
        return []


def interleave_scales(scales: torch.Tensor) -> torch.Tensor:
    """Interleave scales in groups of 4 similar to TRT-LLM implementation."""
    s_shape = scales.shape
    # Reshape to separate groups of 4
    alignment = 4 if s_shape[2] % 4 == 0 else 1
    scales_interleaved = scales.reshape(
        s_shape[0], s_shape[1], (s_shape[2] // alignment), alignment
    )
    # Permute dimensions to interleave
    scales_interleaved = scales_interleaved.permute(0, 2, 1, 3)
    # Reshape back to original dimensions but with interleaved values
    scales_interleaved = scales_interleaved.reshape(
        s_shape[0], s_shape[2] // alignment, s_shape[1] * alignment
    )
    return scales_interleaved.contiguous()


class W4AFp8MoEMethod(FusedMoEMethodBase):
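    """MoE method for W4AFp8: int4 expert weights (packed two per int8 byte)
    with fp8 activations, executed through a cutlass grouped GEMM."""
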
    def __init__(self, quant_config: W4AFp8Config):
        self.quant_config = quant_config

    def create_weights(
        self,
        layer: EPMoE,
        num_experts: int,
        hidden_size: int,
        intermediate_size: int,
        params_dtype: torch.dtype,
        **extra_weight_attrs,
    ):
        from sglang.srt.layers.moe.fused_moe_triton import FusedMoeWeightScaleSupported

        assert "weight_loader" in extra_weight_attrs

        # Fused gate_up_proj (column parallel)
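        # int4 weights are packed two per int8 byte, hence hidden_size // 2 below.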
        w13_weight = torch.nn.Parameter(
            torch.empty(
                num_experts,
                intermediate_size * 2,
                hidden_size // 2,
                dtype=torch.int8,
            ),
            requires_grad=False,
        )
        layer.register_parameter("w13_weight", w13_weight)
        set_weight_attrs(w13_weight, extra_weight_attrs)

        # down_proj (row parallel)
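        # Same int4 packing: the contraction dim is intermediate_size // 2.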
        w2_weight = torch.nn.Parameter(
            torch.empty(
                num_experts,
                hidden_size,
                intermediate_size // 2,
                dtype=torch.int8,
            ),
            requires_grad=False,
        )
        layer.register_parameter("w2_weight", w2_weight)
        set_weight_attrs(w2_weight, extra_weight_attrs)

        extra_weight_attrs.update(
            {"quant_method": FusedMoeWeightScaleSupported.GROUP.value}
        )
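        # Group-wise scales: one fp32 scale per group_size input channels, hence
        # the "// group_size" trailing dims below.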
        w13_weight_scale = torch.nn.Parameter(
            torch.zeros(
                num_experts,
                2 * intermediate_size,
                hidden_size // self.quant_config.group_size,
                dtype=torch.float32,
            ),
            requires_grad=False,
        )
        layer.register_parameter("w13_weight_scale_inv", w13_weight_scale)
        set_weight_attrs(w13_weight_scale, extra_weight_attrs)

        w2_weight_scale = torch.nn.Parameter(
            torch.zeros(
                num_experts,
                hidden_size,
                intermediate_size // self.quant_config.group_size,
                dtype=torch.float32,
            ),
            requires_grad=False,
        )
        layer.register_parameter("w2_weight_scale_inv", w2_weight_scale)
        set_weight_attrs(w2_weight_scale, extra_weight_attrs)

        # Input scales
        w13_input_scale = torch.nn.Parameter(
            torch.ones((num_experts, 2), dtype=torch.bfloat16),
            requires_grad=False,
        )
        layer.register_parameter("w13_input_scale", w13_input_scale)
        set_weight_attrs(w13_input_scale, extra_weight_attrs)

        w2_input_scale = torch.nn.Parameter(
            torch.ones(num_experts, dtype=torch.bfloat16),
            requires_grad=False,
        )
        layer.register_parameter("w2_input_scale", w2_input_scale)
        set_weight_attrs(w2_input_scale, extra_weight_attrs)

        # Pre-populate the strides
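        # Each (num_experts, 3) tensor holds per-expert stride (leading-dimension)
        # values for the grouped GEMMs invoked in apply(); the b/s strides below
        # simply alias the a/c tensors.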
        device = layer.w13_weight.device

        self.a_strides1 = torch.full(
            (num_experts, 3),
            hidden_size,
            device=device,
            dtype=torch.int64,
        )
        self.c_strides1 = torch.full(
            (num_experts, 3),
            2 * intermediate_size,
            device=device,
            dtype=torch.int64,
        )
        self.a_strides2 = torch.full(
            (num_experts, 3),
            intermediate_size,
            device=device,
            dtype=torch.int64,
        )
        self.c_strides2 = torch.full(
            (num_experts, 3),
            hidden_size,
            device=device,
            dtype=torch.int64,
        )
        self.b_strides1 = self.a_strides1
        self.s_strides13 = self.c_strides1
        self.b_strides2 = self.a_strides2
        self.s_strides2 = self.c_strides2

        self.expert_offsets = torch.empty(
            (num_experts + 1), dtype=torch.int32, device=device
        )
        self.problem_sizes1 = torch.empty(
            (num_experts, 3), dtype=torch.int32, device=device
        )
        self.problem_sizes2 = torch.empty(
            (num_experts, 3), dtype=torch.int32, device=device
        )

        return

    def process_weights_after_loading(self, layer: Module) -> None:
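        """Cast scales to bf16, interleave the weight scales, and collapse the
        per-expert input scales to single max values."""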
        dtype = torch.bfloat16
        device = layer.w2_weight.device

        # Interleave w13_weight_scale (gate_up_proj)
        w13_weight_scale = layer.w13_weight_scale_inv.to(dtype)
        w13_weight_scale = interleave_scales(w13_weight_scale)
        layer.w13_weight_scale_inv = Parameter(w13_weight_scale, requires_grad=False)

        # Interleave w2_weight_scale (down_proj)
        w2_weight_scale = layer.w2_weight_scale_inv.to(dtype)
        w2_weight_scale = interleave_scales(w2_weight_scale)
        layer.w2_weight_scale_inv = Parameter(w2_weight_scale, requires_grad=False)

        # Process input scales
        w13_input_scale_max = layer.w13_input_scale.max().to(dtype).item()
        new_w13_input_scale = torch.tensor(
            [w13_input_scale_max],
            dtype=dtype,
            device=device,
        )
        layer.w13_input_scale = Parameter(new_w13_input_scale, requires_grad=False)

        w2_input_scale_max = layer.w2_input_scale.max().to(dtype).item()
        new_w2_input_scale = torch.tensor(
            [w2_input_scale_max], dtype=dtype, device=device
        )
        layer.w2_input_scale = Parameter(new_w2_input_scale, requires_grad=False)

    def apply(
        self,
        layer: EPMoE,
        x: torch.Tensor,
        topk_output: StandardTopKOutput,
        moe_runner_config: MoeRunnerConfig,
    ) -> torch.Tensor:
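        """Run the routed tokens through the cutlass W4A8 grouped-GEMM MoE kernel."""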

        # TODO(ch-wan): move it out of this class
        from sglang.srt.layers.moe.cutlass_w4a8_moe import cutlass_w4a8_moe

        topk_weights, topk_ids, _ = topk_output
        local_topk_ids = topk_ids
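        # Under expert parallelism, entries that do not map to a local expert are
        # marked -1; remap them to the out-of-range sentinel num_experts.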
        if get_moe_expert_parallel_world_size() > 1:
            local_topk_ids = torch.where(
                topk_ids == -1,
                layer.num_experts,
                topk_ids,
            )

        output = cutlass_w4a8_moe(
            layer.start_expert_id,
            layer.end_expert_id,
            layer.num_experts,
            x,
            layer.w13_weight,
            layer.w2_weight,
            layer.w13_weight_scale_inv,
            layer.w2_weight_scale_inv,
            topk_weights,
            topk_ids,
            local_topk_ids,
            self.a_strides1,
            self.b_strides1,
            self.c_strides1,
            self.a_strides2,
            self.b_strides2,
            self.c_strides2,
            self.s_strides13,
            self.s_strides2,
            self.expert_offsets,
            self.problem_sizes1,
            self.problem_sizes2,
            layer.w13_input_scale,
            layer.w2_input_scale,
        )
        if moe_runner_config.routed_scaling_factor is not None:
            output *= moe_runner_config.routed_scaling_factor
        return output