fp8.py 22.6 KB
Newer Older
1
from typing import Any, Callable, Dict, List, Optional
2
3
4
5
6

import torch
from torch.nn import Module
from torch.nn.parameter import Parameter

7
import vllm.envs as envs
8
from vllm import _custom_ops as ops
9
from vllm.logger import init_logger
10
11
from vllm.model_executor.layers.fused_moe import (FusedMoE, FusedMoEMethodBase,
                                                  FusedMoeWeightScaleSupported)
12
13
from vllm.model_executor.layers.linear import (LinearBase, LinearMethodBase,
                                               UnquantizedLinearMethod)
14
from vllm.model_executor.layers.quantization.base_config import (
15
    QuantizationConfig, QuantizeMethodBase)
16
from vllm.model_executor.layers.quantization.kv_cache import BaseKVCacheMethod
17
from vllm.model_executor.layers.quantization.utils.marlin_utils_fp8 import (
18
    apply_fp8_marlin_linear, prepare_fp8_layer_for_marlin)
19
20
from vllm.model_executor.layers.quantization.utils.quant_utils import (
    is_layer_skipped)
21
from vllm.model_executor.layers.quantization.utils.w8a8_utils import (
22
    all_close_1d, apply_fp8_linear, convert_to_channelwise,
23
    cutlass_fp8_supported, normalize_e4m3fn_to_e4m3fnuz, per_tensor_dequantize,
24
    requantize_with_max_scale)
25
26
from vllm.model_executor.parameter import (ModelWeightParameter,
                                           PerTensorScaleParameter)
27
from vllm.model_executor.utils import set_weight_attrs
28
from vllm.platforms import current_platform
29
from vllm.utils import print_warning_once
30

31
32
33
34
# Activation-scale schemes accepted by Fp8Config; its constructor rejects any
# other value with a ValueError.
ACTIVATION_SCHEMES = ["static", "dynamic"]

# Module-level logger, named after this module.
logger = init_logger(__name__)

35

36
class Fp8Config(QuantizationConfig):
    """Config class for FP8.

    Attributes:
        is_checkpoint_fp8_serialized: Whether the checkpoint being loaded
            already stores FP8 weights (as opposed to FP16/BF16 weights that
            are quantized after loading).
        activation_scheme: One of ``ACTIVATION_SCHEMES`` ("static" or
            "dynamic").
        ignored_layers: Layer prefixes that must stay unquantized; always a
            list (``None`` is normalized to ``[]``).
    """

    def __init__(
        self,
        is_checkpoint_fp8_serialized: bool = False,
        activation_scheme: str = "dynamic",
        ignored_layers: Optional[List[str]] = None,
    ) -> None:
        """Store the FP8 settings and validate the activation scheme.

        Args:
            is_checkpoint_fp8_serialized: True when the checkpoint already
                contains FP8 weights. A warning is logged in that case.
            activation_scheme: "static" or "dynamic".
            ignored_layers: Optional list of layer prefixes to skip.

        Raises:
            ValueError: If ``activation_scheme`` is not in
                ``ACTIVATION_SCHEMES``.
        """
        self.is_checkpoint_fp8_serialized = is_checkpoint_fp8_serialized
        if is_checkpoint_fp8_serialized:
            logger.warning("Detected fp8 checkpoint. Please note that the "
                           "format is experimental and subject to change.")
        if activation_scheme not in ACTIVATION_SCHEMES:
            raise ValueError(
                f"Unsupported activation scheme {activation_scheme}")
        self.activation_scheme = activation_scheme
        self.ignored_layers = ignored_layers or []

    @classmethod
    def get_name(cls) -> str:
        """Return the name of this quantization method, ``"fp8"``."""
        return "fp8"

    @classmethod
    def get_supported_act_dtypes(cls) -> List[torch.dtype]:
        """Return the activation dtypes supported: bfloat16 and half."""
        return [torch.bfloat16, torch.half]

    @classmethod
    def get_min_capability(cls) -> int:
        """Return the minimum device capability required (80).

        NOTE(review): presumably the compute capability encoded as
        major * 10 + minor -- confirm against the base-class contract.
        """
        return 80

    @classmethod
    def get_config_filenames(cls) -> List[str]:
        """Return the extra config filenames to look for (none)."""
        return []

    @classmethod
    def from_config(cls, config: Dict[str, Any]) -> "Fp8Config":
        """Build an ``Fp8Config`` from a quantization config dictionary.

        Reads ``quant_method`` and ``activation_scheme`` via
        ``get_from_keys`` and ``ignored_layers`` via ``get_from_keys_or``
        with a ``None`` default. The checkpoint is treated as FP8-serialized
        when ``"fp8"`` is contained in the ``quant_method`` value.
        """
        quant_method = cls.get_from_keys(config, ["quant_method"])
        is_checkpoint_fp8_serialized = ("fp8" in quant_method)
        activation_scheme = cls.get_from_keys(config, ["activation_scheme"])
        ignored_layers = cls.get_from_keys_or(config, ["ignored_layers"], None)
        return cls(is_checkpoint_fp8_serialized=is_checkpoint_fp8_serialized,
                   activation_scheme=activation_scheme,
                   ignored_layers=ignored_layers)

    def get_quant_method(self, layer: torch.nn.Module,
                         prefix: str) -> Optional["QuantizeMethodBase"]:
        """Pick the quantization method object for ``layer``.

        Args:
            layer: The module being constructed.
            prefix: The layer's name prefix, checked against
                ``ignored_layers`` for linear layers.

        Returns:
            ``UnquantizedLinearMethod`` for a skipped linear layer,
            ``Fp8LinearMethod`` for other linear layers, ``Fp8MoEMethod`` for
            fused MoE layers, ``Fp8KVCacheMethod`` for attention layers, and
            ``None`` for anything else.
        """
        from vllm.attention.layer import Attention  # Avoid circular import

        if isinstance(layer, LinearBase):
            if is_layer_skipped(prefix, self.ignored_layers):
                return UnquantizedLinearMethod()
            return Fp8LinearMethod(self)
        elif isinstance(layer, FusedMoE):
            return Fp8MoEMethod(self)
        elif isinstance(layer, Attention):
            return Fp8KVCacheMethod(self)
        return None
94
95
96
97


class Fp8LinearMethod(LinearMethodBase):
    """Linear method for FP8.

    Supports loading FP8 checkpoints with static weight scale and
    dynamic/static activation scale.

    Also supports loading quantized FP16/BF16 model checkpoints with dynamic
    activation scaling. The weight scaling factor will be initialized after
    the model weights are loaded.

    Limitations:
    1. Only support per-tensor quantization due to torch._scaled_mm support.
    2. Only support float8_e4m3fn data type due to the limitation of
       torch._scaled_mm (https://github.com/pytorch/pytorch/blob/2e48b39603411a41c5025efbe52f89560b827825/aten/src/ATen/native/cuda/Blas.cpp#L854-L856)

    Args:
        quant_config: The quantization config.
    """

    def __init__(self, quant_config: Fp8Config):
        """Record the config and decide which kernel path to use."""
        self.quant_config = quant_config
        self.cutlass_fp8_supported = cutlass_fp8_supported()

        # For GPUs that lack FP8 hardware support, we can leverage the Marlin
        # kernel for fast weight-only FP8 quantization
        self.use_marlin = (not current_platform.has_device_capability(89)
                           or envs.VLLM_TEST_FORCE_FP8_MARLIN)
        # Disable marlin for rocm
        if current_platform.is_rocm():
            self.use_marlin = False

    def create_weights(
        self,
        layer: torch.nn.Module,
        input_size_per_partition: int,
        output_partition_sizes: List[int],
        input_size: int,
        output_size: int,
        params_dtype: torch.dtype,
        **extra_weight_attrs,
    ):
        """Register the weight (and, for FP8 checkpoints, scale) parameters.

        Args:
            layer: Module that receives the parameters and bookkeeping
                attributes (``logical_widths``, ``input_size_per_partition``,
                ``output_size_per_partition``, ``orig_dtype``).
            input_size_per_partition: Number of input features on this
                partition.
            output_partition_sizes: Output width of each logical shard fused
                into this layer; their sum is the weight's first dimension.
            input_size: Unused.
            output_size: Unused.
            params_dtype: Dtype of the unquantized parameters; also the
                weight dtype when the checkpoint is not FP8-serialized.
            **extra_weight_attrs: Only ``weight_loader`` is read here.
        """
        del input_size, output_size
        output_size_per_partition = sum(output_partition_sizes)
        weight_loader = extra_weight_attrs.get("weight_loader")

        layer.logical_widths = output_partition_sizes

        layer.input_size_per_partition = input_size_per_partition
        layer.output_size_per_partition = output_size_per_partition
        layer.orig_dtype = params_dtype

        # WEIGHT
        weight_dtype = (torch.float8_e4m3fn
                        if self.quant_config.is_checkpoint_fp8_serialized else
                        params_dtype)

        weight = ModelWeightParameter(data=torch.empty(
            output_size_per_partition,
            input_size_per_partition,
            dtype=weight_dtype),
                                      input_dim=1,
                                      output_dim=0,
                                      weight_loader=weight_loader)
        layer.register_parameter("weight", weight)

        # If checkpoint is serialized fp8, load them.
        # Otherwise, wait until process_weights_after_loading.
        if self.quant_config.is_checkpoint_fp8_serialized:
            # WEIGHT SCALE
            # One scale slot per logical shard, pre-filled with the most
            # negative float32.
            # NOTE(review): presumably a sentinel for slots the loader never
            # fills -- confirm against the scale-consuming helpers.
            scale = PerTensorScaleParameter(data=torch.empty(
                len(output_partition_sizes), dtype=torch.float32),
                                            weight_loader=weight_loader)

            scale[:] = torch.finfo(torch.float32).min
            layer.register_parameter("weight_scale", scale)

            # INPUT ACTIVATION SCALE
            if self.quant_config.activation_scheme == "static":
                scale = PerTensorScaleParameter(data=torch.empty(
                    len(output_partition_sizes), dtype=torch.float32),
                                                weight_loader=weight_loader)

                scale[:] = torch.finfo(torch.float32).min
                layer.register_parameter("input_scale", scale)
            else:
                layer.register_parameter("input_scale", None)

    def process_weights_after_loading(self, layer: Module) -> None:
        """Finalize weights and scales once the checkpoint has been loaded.

        For non-FP8 checkpoints the weight is quantized here with a
        dynamically computed scale. For FP8 checkpoints the per-shard scales
        are either expanded to channelwise (Marlin) or collapsed to a single
        max scale with the weight requantized accordingly. In both cases the
        stored weight is transposed. When Marlin is used the layer is then
        repacked and ``input_scale`` is removed.
        """
        layer.weight = torch.nn.Parameter(layer.weight.data,
                                          requires_grad=False)
        # If checkpoint not serialized fp8, quantize the weights.
        if not self.quant_config.is_checkpoint_fp8_serialized:
            qweight, weight_scale = ops.scaled_fp8_quant(layer.weight,
                                                         scale=None)

            # If using marlin (w8a16), kernel uses channelwise weights,
            # so extend the weight scales to be channelwise.
            if self.use_marlin:
                assert weight_scale.numel() == 1
                weight_scale = convert_to_channelwise(
                    weight_scale.expand(len(layer.logical_widths)),
                    layer.logical_widths)

            # Update the layer with the new values.
            layer.weight = Parameter(qweight.t(), requires_grad=False)
            layer.weight_scale = Parameter(weight_scale, requires_grad=False)
            layer.input_scale = None

        # If checkpoint is fp8, handle that there are N scales for N
        # shards in a fused module
        else:
            layer.weight_scale = torch.nn.Parameter(layer.weight_scale.data,
                                                    requires_grad=False)
            if self.quant_config.activation_scheme == "static":
                layer.input_scale = torch.nn.Parameter(layer.input_scale.data,
                                                       requires_grad=False)

            # If using marlin (w8a16), kernel uses channelwise weights,
            # so extend the weight scales to be channelwise.
            if self.use_marlin:
                weight = layer.weight
                weight_scale = convert_to_channelwise(layer.weight_scale,
                                                      layer.logical_widths)

            # If using w8a8, torch._scaled_mm needs per tensor, so
            # requantize the logical shards as a single weight.
            else:
                # Dequant -> Quant with max scale so we can run per tensor.
                weight = layer.weight
                weight_scale = layer.weight_scale

                # If rocm, use float8_e4m3fnuz.
                if current_platform.is_rocm():
                    weight, weight_scale, input_scale = \
                        normalize_e4m3fn_to_e4m3fnuz(
                            weight=weight,
                            weight_scale=weight_scale,
                            input_scale=layer.input_scale)
                    if input_scale is not None:
                        layer.input_scale = Parameter(input_scale,
                                                      requires_grad=False)

                weight_scale, weight = requantize_with_max_scale(
                    weight=weight,
                    weight_scale=weight_scale,
                    logical_widths=layer.logical_widths,
                )

            # Update layer with new values.
            layer.weight = Parameter(weight.t(), requires_grad=False)
            layer.weight_scale = Parameter(weight_scale, requires_grad=False)
            if self.quant_config.activation_scheme == "static":
                # Collapse the per-shard input scales to their maximum.
                layer.input_scale = Parameter(layer.input_scale.max(),
                                              requires_grad=False)

        if self.use_marlin:
            prepare_fp8_layer_for_marlin(layer)
            # Activations not quantized for marlin.
            del layer.input_scale

    def apply(self,
              layer: torch.nn.Module,
              x: torch.Tensor,
              bias: Optional[torch.Tensor] = None) -> torch.Tensor:
        """Run the quantized linear layer on ``x``.

        Dispatches to the Marlin weight-only kernel when ``use_marlin`` is
        set, otherwise to ``apply_fp8_linear`` with per-token dynamic
        activation scaling disabled.

        Args:
            layer: Layer holding ``weight``, ``weight_scale`` and, on the
                non-Marlin path, ``input_scale``.
            x: Input activations.
            bias: Optional bias forwarded to the kernel.

        Returns:
            The layer output tensor.
        """
        if self.use_marlin:
            return apply_fp8_marlin_linear(
                input=x,
                weight=layer.weight,
                weight_scale=layer.weight_scale,
                workspace=layer.workspace,
                size_n=layer.output_size_per_partition,
                size_k=layer.input_size_per_partition,
                bias=bias)

        return apply_fp8_linear(
            input=x,
            weight=layer.weight,
            weight_scale=layer.weight_scale,
            input_scale=layer.input_scale,
            bias=bias,
            cutlass_fp8_supported=self.cutlass_fp8_supported,
            use_per_token_if_dynamic=False)
277
278


279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
class Fp8MoEMethod(FusedMoEMethodBase):
    """MoE method for FP8.

    Supports loading FP8 checkpoints with static weight scale and
    dynamic/static activation scale.

    Also supports loading quantized FP16/BF16 model checkpoints with dynamic
    activation scaling. The weight scaling factor will be initialized after
    the model weights are loaded.

    Args:
        quant_config: The quantization config.
    """

    def __init__(self, quant_config: Fp8Config):
        """Record the quantization config."""
        self.quant_config = quant_config

    def create_weights(self, layer: Module, num_experts: int, hidden_size: int,
                       intermediate_size: int, params_dtype: torch.dtype,
                       **extra_weight_attrs):
        """Register expert weights and their scale parameters on ``layer``.

        Args:
            layer: Module that receives the parameters.
            num_experts: Number of experts (leading dimension of every
                parameter created here).
            hidden_size: Hidden dimension of the model.
            intermediate_size: Intermediate dimension of one expert; the
                fused w13 weight uses twice this size.
            params_dtype: Dtype for the weights; replaced by
                ``torch.float8_e4m3fn`` when the checkpoint is FP8-serialized.
            **extra_weight_attrs: Attributes attached to the parameters via
                ``set_weight_attrs``.

        Raises:
            ValueError: If the activation scheme is "static" but the
                checkpoint is not FP8-serialized.
        """
        if self.quant_config.is_checkpoint_fp8_serialized:
            params_dtype = torch.float8_e4m3fn

        # WEIGHTS
        w13_weight = torch.nn.Parameter(torch.empty(num_experts,
                                                    2 * intermediate_size,
                                                    hidden_size,
                                                    dtype=params_dtype),
                                        requires_grad=False)
        layer.register_parameter("w13_weight", w13_weight)
        set_weight_attrs(w13_weight, extra_weight_attrs)

        w2_weight = torch.nn.Parameter(torch.empty(num_experts,
                                                   hidden_size,
                                                   intermediate_size,
                                                   dtype=params_dtype),
                                       requires_grad=False)
        layer.register_parameter("w2_weight", w2_weight)
        set_weight_attrs(w2_weight, extra_weight_attrs)

        # WEIGHT_SCALES
        # Allocate 2 scales for w1 and w3 respectively.
        # They will be combined to a single scale after weight loading.
        w13_weight_scale = torch.nn.Parameter(torch.ones(num_experts,
                                                         2,
                                                         dtype=torch.float32),
                                              requires_grad=False)
        layer.register_parameter("w13_weight_scale", w13_weight_scale)

        w2_weight_scale = torch.nn.Parameter(torch.ones(num_experts,
                                                        dtype=torch.float32),
                                             requires_grad=False)
        layer.register_parameter("w2_weight_scale", w2_weight_scale)

        # Add the quantization method used (per tensor/grouped/channel)
        # to ensure the weight scales are loaded in properly.
        # This update happens after the weights received their attrs, so
        # only the scale parameters below carry "quant_method".
        extra_weight_attrs.update(
            {"quant_method": FusedMoeWeightScaleSupported.TENSOR.value})

        # If loading fp8 checkpoint, pass the weight loaders.
        # If loading an fp16 checkpoint, do not (we will quantize in
        #   process_weights_after_loading()
        if self.quant_config.is_checkpoint_fp8_serialized:
            set_weight_attrs(w13_weight_scale, extra_weight_attrs)
            set_weight_attrs(w2_weight_scale, extra_weight_attrs)

        # INPUT_SCALES
        if self.quant_config.activation_scheme == "static":
            if not self.quant_config.is_checkpoint_fp8_serialized:
                raise ValueError(
                    "Found static activation scheme for checkpoint that "
                    "was not serialized fp8.")

            w13_input_scale = torch.nn.Parameter(torch.ones(
                num_experts, dtype=torch.float32),
                                                 requires_grad=False)
            layer.register_parameter("w13_input_scale", w13_input_scale)
            set_weight_attrs(w13_input_scale, extra_weight_attrs)

            w2_input_scale = torch.nn.Parameter(torch.ones(
                num_experts, dtype=torch.float32),
                                                requires_grad=False)
            layer.register_parameter("w2_input_scale", w2_input_scale)
            set_weight_attrs(w2_input_scale, extra_weight_attrs)

        else:
            layer.w13_input_scale = None
            layer.w2_input_scale = None

    def process_weights_after_loading(self, layer: Module) -> None:
        """Finalize expert weights and scales after checkpoint loading.

        For non-FP8 checkpoints every expert's w13 and w2 weights are
        quantized here, one scale per expert. For FP8 checkpoints the
        activation scales are collapsed to a single maximum (static scheme),
        weights and scales are normalized to e4m3fnuz on ROCm, and each
        expert's two w13 shards are requantized with their shared maximum
        scale.

        Raises:
            ValueError: If the scheme is "static" but an input scale is
                ``None``.
        """
        # If checkpoint is fp16, quantize in place.
        if not self.quant_config.is_checkpoint_fp8_serialized:
            # If rocm, use float8_e4m3fnuz as dtype
            fp8_dtype = torch.float8_e4m3fnuz \
                        if current_platform.is_rocm() else torch.float8_e4m3fn
            w13_weight = torch.empty_like(layer.w13_weight.data,
                                          dtype=fp8_dtype)
            w2_weight = torch.empty_like(layer.w2_weight.data, dtype=fp8_dtype)

            # Re-initialize w13_scale because we directly quantize
            # merged w13 weights and generate a single scaling factor.
            layer.w13_weight_scale = torch.nn.Parameter(torch.ones(
                layer.num_experts,
                dtype=torch.float32,
                device=w13_weight.device),
                                                        requires_grad=False)
            for expert in range(layer.num_experts):
                w13_weight[expert, :, :], layer.w13_weight_scale[
                    expert] = ops.scaled_fp8_quant(
                        layer.w13_weight.data[expert, :, :])
                w2_weight[expert, :, :], layer.w2_weight_scale[
                    expert] = ops.scaled_fp8_quant(
                        layer.w2_weight.data[expert, :, :])
            layer.w13_weight = torch.nn.Parameter(w13_weight,
                                                  requires_grad=False)
            layer.w2_weight = torch.nn.Parameter(w2_weight,
                                                 requires_grad=False)
            return

        # If checkpoint is fp8, we need to handle that the
        # MoE kernels require single activation scale and single weight
        # scale for w13 per expert.

        # Fp8 moe kernels require a single activation scale.
        # We take the max of all the scales in case they differ.
        if self.quant_config.activation_scheme == "static":
            if (layer.w13_input_scale is None
                    or layer.w2_input_scale is None):
                raise ValueError(
                    "QuantConfig has static quantization, but found "
                    "activation scales are None.")
            if (not all_close_1d(layer.w13_input_scale)
                    or not all_close_1d(layer.w2_input_scale)):
                print_warning_once(
                    "Found input_scales that are not equal for "
                    "fp8 MoE layer. Using the maximum across experts "
                    "for each layer. ")
            layer.w13_input_scale = torch.nn.Parameter(
                layer.w13_input_scale.max(), requires_grad=False)
            layer.w2_input_scale = torch.nn.Parameter(
                layer.w2_input_scale.max(), requires_grad=False)

        # If rocm, normalize the weights and scales to e4m3fnuz
        if current_platform.is_rocm():
            # Normalize the weights and scales
            w13_weight, w13_weight_scale, w13_input_scale = \
                normalize_e4m3fn_to_e4m3fnuz(
                    layer.w13_weight, layer.w13_weight_scale,
                    layer.w13_input_scale)
            w2_weight, w2_weight_scale, w2_input_scale = \
                normalize_e4m3fn_to_e4m3fnuz(
                    layer.w2_weight, layer.w2_weight_scale,
                    layer.w2_input_scale)
            # Reset the parameter
            layer.w13_weight = torch.nn.Parameter(w13_weight,
                                                  requires_grad=False)
            layer.w13_weight_scale = torch.nn.Parameter(
                w13_weight_scale, requires_grad=False)
            if w13_input_scale is not None:
                layer.w13_input_scale = torch.nn.Parameter(
                    w13_input_scale, requires_grad=False)
            layer.w2_weight = torch.nn.Parameter(w2_weight,
                                                 requires_grad=False)
            layer.w2_weight_scale = torch.nn.Parameter(w2_weight_scale,
                                                       requires_grad=False)
            if w2_input_scale is not None:
                layer.w2_input_scale = torch.nn.Parameter(
                    w2_input_scale, requires_grad=False)

        # Fp8 moe kernel needs single weight scale for w13 per expert.
        # We take the max then dequant and requant each expert.
        assert layer.w13_weight_scale is not None
        shard_size = layer.intermediate_size_per_partition
        max_w13_scales = layer.w13_weight_scale.max(dim=1).values
        for expert_id in range(layer.num_experts):
            start = 0
            for shard_id in range(2):
                # Dequantize the shard with its own scale, then requantize
                # it in place with the expert's max scale.
                dq_weight = per_tensor_dequantize(
                    layer.w13_weight[expert_id][start:start +
                                                shard_size, :],
                    layer.w13_weight_scale[expert_id][shard_id])
                layer.w13_weight[expert_id][
                    start:start + shard_size, :], _ = ops.scaled_fp8_quant(
                        dq_weight, max_w13_scales[expert_id])
                start += shard_size

        layer.w13_weight_scale = torch.nn.Parameter(max_w13_scales,
                                                    requires_grad=False)

    def apply(
        self,
        layer: torch.nn.Module,
        x: torch.Tensor,
        router_logits: torch.Tensor,
        top_k: int,
        renormalize: bool,
        use_grouped_topk: bool,
        topk_group: Optional[int] = None,
        num_expert_group: Optional[int] = None,
        custom_routing_function: Optional[Callable] = None,
    ) -> torch.Tensor:
        """Route tokens to experts and run the fused FP8 expert kernels.

        Expert selection is delegated to ``FusedMoE.select_experts``; the
        resulting top-k weights and ids are passed to ``fused_experts``
        together with the layer's FP8 weights and scales. The kernel is
        invoked with ``inplace=True``.

        Args:
            layer: Layer holding the w13/w2 weights and their scales.
            x: Input hidden states.
            router_logits: Router outputs used for expert selection.
            top_k: Number of experts selected per token.
            renormalize: Forwarded to ``select_experts``.
            use_grouped_topk: Forwarded to ``select_experts``.
            topk_group: Forwarded to ``select_experts``.
            num_expert_group: Forwarded to ``select_experts``.
            custom_routing_function: Forwarded to ``select_experts``.

        Returns:
            The output tensor of ``fused_experts``.
        """
        from vllm.model_executor.layers.fused_moe import fused_experts

        topk_weights, topk_ids = FusedMoE.select_experts(
            hidden_states=x,
            router_logits=router_logits,
            use_grouped_topk=use_grouped_topk,
            top_k=top_k,
            renormalize=renormalize,
            topk_group=topk_group,
            num_expert_group=num_expert_group,
            custom_routing_function=custom_routing_function)

        return fused_experts(x,
                             layer.w13_weight,
                             layer.w2_weight,
                             topk_weights=topk_weights,
                             topk_ids=topk_ids,
                             inplace=True,
                             use_fp8_w8a8=True,
                             w1_scale=layer.w13_weight_scale,
                             w2_scale=layer.w2_weight_scale,
                             a1_scale=layer.w13_input_scale,
                             a2_scale=layer.w2_input_scale)
503
504


505
506
507
class Fp8KVCacheMethod(BaseKVCacheMethod):
    """
    Supports loading kv-cache scaling factors from FP8 checkpoints.
    """

    def __init__(self, quant_config: Fp8Config):
        """Forward the FP8 quantization config to the base KV-cache method."""
        super().__init__(quant_config)