# fp8.py: FP8 quantization methods (linear layers, fused MoE, KV cache).
from typing import Any, Dict, List, Optional

import torch
from torch.nn import Module
from torch.nn.parameter import Parameter

from vllm import _custom_ops as ops
from vllm.logger import init_logger
from vllm.model_executor.layers.fused_moe import (FusedMoE, FusedMoEMethodBase,
                                                  fused_moe)
from vllm.model_executor.layers.linear import (LinearBase, LinearMethodBase,
                                               UnquantizedLinearMethod)
from vllm.model_executor.layers.quantization.base_config import (
    QuantizationConfig, QuantizeMethodBase)
from vllm.model_executor.layers.quantization.kv_cache import BaseKVCacheMethod
from vllm.model_executor.layers.quantization.utils.marlin_utils_fp8 import (
    apply_fp8_marlin_linear, prepare_fp8_layer_for_marlin)
from vllm.model_executor.layers.quantization.utils.quant_utils import (
    is_layer_skipped)
from vllm.model_executor.layers.quantization.utils.w8a8_utils import (
    all_close_1d, apply_fp8_linear, convert_to_channelwise,
    create_per_tensor_scale_param, cutlass_fp8_supported,
    per_tensor_dequantize, requantize_with_max_scale)
from vllm.model_executor.utils import set_weight_attrs
from vllm.platforms import current_platform
from vllm.utils import print_warning_once
# Activation-scaling schemes accepted by Fp8Config; any other value makes
# Fp8Config.__init__ raise ValueError.
ACTIVATION_SCHEMES = ["static", "dynamic"]

logger = init_logger(__name__)


class Fp8Config(QuantizationConfig):
    """Config class for FP8.

    Args:
        is_checkpoint_fp8_serialized: True when the checkpoint already stores
            FP8 weights (plus their scales); False when weights are loaded in
            their original dtype and quantized after loading.
        activation_scheme: How activation scales are obtained; must be one of
            ``ACTIVATION_SCHEMES`` ("static" or "dynamic").
        ignored_layers: Names handed to ``is_layer_skipped``; linear layers
            it reports as skipped stay unquantized. ``None`` means no layers
            are ignored.

    Raises:
        ValueError: If ``activation_scheme`` is not in ``ACTIVATION_SCHEMES``.
    """

    def __init__(
        self,
        is_checkpoint_fp8_serialized: bool = False,
        activation_scheme: str = "dynamic",
        ignored_layers: Optional[List[str]] = None,
    ) -> None:
        self.is_checkpoint_fp8_serialized = is_checkpoint_fp8_serialized
        if is_checkpoint_fp8_serialized:
            logger.warning("Detected fp8 checkpoint. Please note that the "
                           "format is experimental and subject to change.")
        if activation_scheme not in ACTIVATION_SCHEMES:
            raise ValueError(
                f"Unsupported activation scheme {activation_scheme}")
        self.activation_scheme = activation_scheme
        # Normalize None to a fresh list so callers can always iterate it.
        self.ignored_layers = ignored_layers or []

    @classmethod
    def get_name(cls) -> str:
        """Return the registry name of this quantization method."""
        return "fp8"

    @classmethod
    def get_supported_act_dtypes(cls) -> List[torch.dtype]:
        """Return the activation dtypes this method accepts."""
        return [torch.bfloat16, torch.half]

    @classmethod
    def get_min_capability(cls) -> int:
        """Return the minimum GPU compute capability (80 == 8.0)."""
        return 80

    @classmethod
    def get_config_filenames(cls) -> List[str]:
        """Return extra config filenames to look for (none for FP8)."""
        return []

    @classmethod
    def from_config(cls, config: Dict[str, Any]) -> "Fp8Config":
        """Build an Fp8Config from a checkpoint quantization-config dict.

        ``quant_method`` and ``activation_scheme`` are required keys;
        ``ignored_layers`` is optional. The checkpoint counts as
        FP8-serialized when ``"fp8"`` occurs in ``quant_method``.
        """
        quant_method = cls.get_from_keys(config, ["quant_method"])
        is_checkpoint_fp8_serialized = ("fp8" in quant_method)
        activation_scheme = cls.get_from_keys(config, ["activation_scheme"])
        ignored_layers = cls.get_from_keys_or(config, ["ignored_layers"], None)
        return cls(is_checkpoint_fp8_serialized=is_checkpoint_fp8_serialized,
                   activation_scheme=activation_scheme,
                   ignored_layers=ignored_layers)

    def get_quant_method(self, layer: torch.nn.Module,
                         prefix: str) -> Optional["QuantizeMethodBase"]:
        """Pick the quantize method for ``layer``.

        Linear layers get Fp8LinearMethod, unless ``prefix`` is listed in
        ``ignored_layers``; in that case they get UnquantizedLinearMethod.
        FusedMoE layers get Fp8MoEMethod and Attention layers get
        Fp8KVCacheMethod. Any other layer gets None.
        """
        from vllm.attention.layer import Attention  # Avoid circular import

        if isinstance(layer, LinearBase):
            if is_layer_skipped(prefix, self.ignored_layers):
                return UnquantizedLinearMethod()
            return Fp8LinearMethod(self)
        elif isinstance(layer, FusedMoE):
            return Fp8MoEMethod(self)
        elif isinstance(layer, Attention):
            return Fp8KVCacheMethod(self)
        return None

    def get_scaled_act_names(self) -> List[str]:
        """Return names of activation layers needing scaling (none)."""
        return []


class Fp8LinearMethod(LinearMethodBase):
    """Linear method for FP8.

    Supports loading FP8 checkpoints with static weight scale and
    dynamic/static activation scale.

    Also supports loading quantized FP16/BF16 model checkpoints with dynamic
    activation scaling. The weight scaling factor will be initialized after
    the model weights are loaded.

    On GPUs whose compute capability is below 8.9, the Marlin kernel is used
    for weight-only (w8a16) FP8; activations are then not quantized.

    Limitations:
    1. Only supports per-tensor quantization due to torch._scaled_mm support.
    2. Only supports float8_e4m3fn data type due to the limitation of
       torch._scaled_mm (https://github.com/pytorch/pytorch/blob/2e48b39603411a41c5025efbe52f89560b827825/aten/src/ATen/native/cuda/Blas.cpp#L854-L856)

    Args:
        quant_config: The quantization config.
    """

    def __init__(self, quant_config: Fp8Config):
        self.quant_config = quant_config
        self.cutlass_fp8_supported = cutlass_fp8_supported()

        # For GPUs that lack FP8 hardware support, we can leverage the Marlin
        # kernel for fast weight-only FP8 quantization
        capability = current_platform.get_device_capability()
        capability = capability[0] * 10 + capability[1]
        self.use_marlin = capability < 89

    def create_weights(
        self,
        layer: torch.nn.Module,
        input_size_per_partition: int,
        output_partition_sizes: List[int],
        input_size: int,
        output_size: int,
        params_dtype: torch.dtype,
        **extra_weight_attrs,
    ):
        """Register the weight (and, for FP8 checkpoints, scale) parameters.

        The weight is allocated as
        (sum(output_partition_sizes), input_size_per_partition). It is FP8
        when the checkpoint is FP8-serialized and ``params_dtype`` otherwise.
        For FP8 checkpoints, ``weight_scale`` (and ``input_scale`` with the
        static scheme) are built from ``output_partition_sizes`` by
        ``create_per_tensor_scale_param``. Per the comment in
        process_weights_after_loading, that is one scale per fused shard.
        """
        del input_size, output_size
        output_size_per_partition = sum(output_partition_sizes)

        # Widths of the logical shards fused into this layer; used later to
        # reconcile per-shard scales.
        layer.logical_widths = output_partition_sizes

        layer.input_size_per_partition = input_size_per_partition
        layer.output_size_per_partition = output_size_per_partition
        layer.orig_dtype = params_dtype

        # WEIGHT
        weight_dtype = (torch.float8_e4m3fn
                        if self.quant_config.is_checkpoint_fp8_serialized else
                        params_dtype)
        weight = Parameter(torch.empty(output_size_per_partition,
                                       input_size_per_partition,
                                       dtype=weight_dtype),
                           requires_grad=False)
        layer.register_parameter("weight", weight)
        set_weight_attrs(weight, {
            **extra_weight_attrs,
            "input_dim": 1,
            "output_dim": 0,
        })

        # If checkpoint is serialized fp8, load them.
        # Otherwise, wait until process_weights_after_loading.
        if self.quant_config.is_checkpoint_fp8_serialized:
            # WEIGHT SCALE
            scale = create_per_tensor_scale_param(output_partition_sizes,
                                                  **extra_weight_attrs)
            layer.register_parameter("weight_scale", scale)

            # INPUT ACTIVATION SCALE
            if self.quant_config.activation_scheme == "static":
                scale = create_per_tensor_scale_param(output_partition_sizes,
                                                      **extra_weight_attrs)
                layer.register_parameter("input_scale", scale)

    def process_weights_after_loading(self, layer: Module) -> None:
        """Convert loaded weights/scales into the form the kernels expect.

        The weight is stored transposed (``weight.t()``) in both branches.
        ``input_scale`` is None for dynamic activation scaling. For static
        scaling it is reduced to its maximum. When Marlin is used, the layer
        is repacked by ``prepare_fp8_layer_for_marlin`` and ``input_scale``
        is deleted.
        """
        # If checkpoint not serialized fp8, quantize the weights.
        if not self.quant_config.is_checkpoint_fp8_serialized:
            qweight, weight_scale = ops.scaled_fp8_quant(layer.weight,
                                                         scale=None)
            # NOTE(review): on the Marlin path this scale stays per-tensor,
            # whereas the fp8-checkpoint branch below converts it to
            # channelwise for Marlin. Confirm prepare_fp8_layer_for_marlin
            # accepts a per-tensor scale here.

            # Update the layer with the new values.
            layer.weight = Parameter(qweight.t(), requires_grad=False)
            layer.weight_scale = Parameter(weight_scale, requires_grad=False)
            layer.input_scale = None

        else:
            # Checkpoint is fp8: a fused module carries N scales for its N
            # logical shards, which must be reconciled for the kernel.

            # If using marlin (w8a16), kernel uses channelwise weights,
            # so extend the weight scales to be channelwise.
            if self.use_marlin:
                weight = layer.weight
                weight_scale = convert_to_channelwise(layer.weight_scale,
                                                      layer.logical_widths)

            # If using w8a8, torch._scaled_mm needs per tensor, so
            # requantize the logical shards as a single weight.
            else:
                # Dequant -> Quant with max scale so we can run per tensor.
                weight_scale, weight = requantize_with_max_scale(
                    weight=layer.weight,
                    weight_scale=layer.weight_scale,
                    logical_widths=layer.logical_widths,
                )

            # Update layer with new values.
            layer.weight = Parameter(weight.t(), requires_grad=False)
            layer.weight_scale = Parameter(weight_scale, requires_grad=False)
            if self.quant_config.activation_scheme == "static":
                # A single activation scale is used: take the largest.
                layer.input_scale = Parameter(layer.input_scale.max(),
                                              requires_grad=False)
            else:
                layer.input_scale = None

        if self.use_marlin:
            prepare_fp8_layer_for_marlin(layer)
            # Activations not quantized for marlin.
            del layer.input_scale

    def apply(self,
              layer: torch.nn.Module,
              x: torch.Tensor,
              bias: Optional[torch.Tensor] = None) -> torch.Tensor:
        """Run the FP8 linear layer on ``x`` (optionally adding ``bias``).

        Uses the Marlin w8a16 kernel when ``self.use_marlin`` is set. The
        ``layer.workspace`` it reads is presumably created by
        ``prepare_fp8_layer_for_marlin``. Otherwise it uses
        ``apply_fp8_linear`` with the layer's weight/input scales.
        """
        if self.use_marlin:
            return apply_fp8_marlin_linear(
                input=x,
                weight=layer.weight,
                weight_scale=layer.weight_scale,
                workspace=layer.workspace,
                size_n=layer.output_size_per_partition,
                size_k=layer.input_size_per_partition,
                bias=bias)

        return apply_fp8_linear(
            input=x,
            weight=layer.weight,
            weight_scale=layer.weight_scale,
            input_scale=layer.input_scale,
            bias=bias,
            cutlass_fp8_supported=self.cutlass_fp8_supported,
            use_per_token_if_dynamic=False)


class Fp8MoEMethod(FusedMoEMethodBase):
    """MoE method for FP8.

    Supports loading FP8 checkpoints with static weight scale and
    dynamic/static activation scale.

    Also supports loading quantized FP16/BF16 model checkpoints with dynamic
    activation scaling. The weight scaling factor will be initialized after
    the model weights are loaded.

    Args:
        quant_config: The quantization config.
    """

    def __init__(self, quant_config: Fp8Config):
        self.quant_config = quant_config

    def create_weights(self, layer: Module, num_experts: int, hidden_size: int,
                       intermediate_size: int, params_dtype: torch.dtype,
                       **extra_weight_attrs):
        """Register expert weights, weight scales and activation scales.

        Shapes allocated here:
          * w13_weight: (num_experts, 2 * intermediate_size, hidden_size),
            holding w1 and w3 stacked along dim 1.
          * w2_weight: (num_experts, hidden_size, intermediate_size).
          * w13_scale: (num_experts, 2), i.e. one scale each for w1 and w3.
          * w2_scale: (num_experts,).
          * a13_scale / a2_scale: (num_experts,) for the static scheme;
            otherwise they are set to None.

        Raises:
            ValueError: If the activation scheme is static but the checkpoint
                is not FP8-serialized.
        """
        if self.quant_config.is_checkpoint_fp8_serialized:
            params_dtype = torch.float8_e4m3fn

        # WEIGHTS
        w13_weight = torch.nn.Parameter(torch.empty(num_experts,
                                                    2 * intermediate_size,
                                                    hidden_size,
                                                    dtype=params_dtype),
                                        requires_grad=False)
        layer.register_parameter("w13_weight", w13_weight)
        set_weight_attrs(w13_weight, extra_weight_attrs)

        w2_weight = torch.nn.Parameter(torch.empty(num_experts,
                                                   hidden_size,
                                                   intermediate_size,
                                                   dtype=params_dtype),
                                       requires_grad=False)
        layer.register_parameter("w2_weight", w2_weight)
        set_weight_attrs(w2_weight, extra_weight_attrs)

        # WEIGHT_SCALES
        # Allocate 2 scales for w1 and w3 respectively.
        # They will be combined to a single scale after weight loading.
        w13_scale = torch.nn.Parameter(torch.ones(num_experts,
                                                  2,
                                                  dtype=torch.float32),
                                       requires_grad=False)
        layer.register_parameter("w13_scale", w13_scale)

        w2_scale = torch.nn.Parameter(torch.ones(num_experts,
                                                 dtype=torch.float32),
                                      requires_grad=False)
        layer.register_parameter("w2_scale", w2_scale)

        # If loading fp8 checkpoint, pass the weight loaders.
        # If loading an fp16 checkpoint, do not (we will quantize in
        # process_weights_after_loading()).
        if self.quant_config.is_checkpoint_fp8_serialized:
            set_weight_attrs(w13_scale, extra_weight_attrs)
            set_weight_attrs(w2_scale, extra_weight_attrs)

        # INPUT_SCALES
        if self.quant_config.activation_scheme == "static":
            if not self.quant_config.is_checkpoint_fp8_serialized:
                raise ValueError(
                    "Found static activation scheme for checkpoint that "
                    "was not serialized fp8.")

            a13_scale = torch.nn.Parameter(torch.ones(num_experts,
                                                      dtype=torch.float32),
                                           requires_grad=False)
            layer.register_parameter("a13_scale", a13_scale)
            set_weight_attrs(a13_scale, extra_weight_attrs)

            a2_scale = torch.nn.Parameter(torch.ones(num_experts,
                                                     dtype=torch.float32),
                                          requires_grad=False)
            layer.register_parameter("a2_scale", a2_scale)
            set_weight_attrs(a2_scale, extra_weight_attrs)
        else:
            layer.a13_scale = None
            layer.a2_scale = None

    def process_weights_after_loading(self, layer: Module) -> None:
        """Bring expert weights/scales into the form the fused kernel needs.

        For a non-FP8 checkpoint, each expert's w13 and w2 are quantized to
        FP8, yielding one scale per expert for each. For an FP8 checkpoint,
        static activation scales are reduced to their maximum. Each expert's
        w1/w3 shards are then re-quantized under the larger of their two
        scales, so w13 ends up with a single scale per expert.

        Raises:
            ValueError: If the scheme is static but activation scales are
                missing.
        """
        # If checkpoint is fp16, quantize in place.
        if not self.quant_config.is_checkpoint_fp8_serialized:
            w13_weight = torch.empty_like(layer.w13_weight.data,
                                          dtype=torch.float8_e4m3fn)
            w2_weight = torch.empty_like(layer.w2_weight.data,
                                         dtype=torch.float8_e4m3fn)

            # Re-initialize w13_scale because we directly quantize
            # merged w13 weights and generate a single scaling factor.
            layer.w13_scale = torch.nn.Parameter(
                torch.ones(layer.num_experts,
                           dtype=torch.float32,
                           device=w13_weight.device),
                requires_grad=False)
            for expert in range(layer.num_experts):
                w13_weight[expert, :, :], layer.w13_scale[expert] = (
                    ops.scaled_fp8_quant(layer.w13_weight.data[expert, :, :]))
                w2_weight[expert, :, :], layer.w2_scale[expert] = (
                    ops.scaled_fp8_quant(layer.w2_weight.data[expert, :, :]))
            layer.w13_weight = torch.nn.Parameter(w13_weight,
                                                  requires_grad=False)
            layer.w2_weight = torch.nn.Parameter(w2_weight,
                                                 requires_grad=False)
            return

        # Checkpoint is fp8: the MoE kernels require a single activation
        # scale and a single w13 weight scale per expert.

        # Fp8 moe kernels require a single activation scale.
        # We take the max of all the scales in case they differ.
        if self.quant_config.activation_scheme == "static":
            if layer.a13_scale is None or layer.a2_scale is None:
                raise ValueError(
                    "QuantConfig has static quantization, but found "
                    "activation scales are None.")
            if (not all_close_1d(layer.a13_scale)
                    or not all_close_1d(layer.a2_scale)):
                print_warning_once(
                    "Found input_scales that are not equal for "
                    "fp8 MoE layer. Using the maximum across experts "
                    "for each layer. ")
            layer.a13_scale = torch.nn.Parameter(layer.a13_scale.max(),
                                                 requires_grad=False)
            layer.a2_scale = torch.nn.Parameter(layer.a2_scale.max(),
                                                requires_grad=False)

        # Fp8 moe kernel needs single weight scale for w13 per expert.
        # We take the max then dequant and requant each expert.
        assert layer.w13_scale is not None
        shard_size = layer.intermediate_size_per_partition
        max_w13_scales = layer.w13_scale.max(dim=1).values
        for expert_id in range(layer.num_experts):
            # Shard 0 (w1) occupies rows [0, shard_size), and shard 1 (w3)
            # occupies [shard_size, 2 * shard_size).
            for shard_id in range(2):
                start = shard_id * shard_size
                dq_weight = per_tensor_dequantize(
                    layer.w13_weight[expert_id][start:start + shard_size, :],
                    layer.w13_scale[expert_id][shard_id])
                layer.w13_weight[expert_id][start:start + shard_size, :], _ = (
                    ops.scaled_fp8_quant(dq_weight, max_w13_scales[expert_id]))

        layer.w13_scale = torch.nn.Parameter(max_w13_scales,
                                             requires_grad=False)

    def apply(self,
              layer: torch.nn.Module,
              x: torch.Tensor,
              router_logits: torch.Tensor,
              top_k: int,
              renormalize: bool = True,
              use_grouped_topk: bool = False,
              num_expert_group: Optional[int] = None,
              topk_group: Optional[int] = None) -> torch.Tensor:
        """Run the FP8 fused-MoE kernel on ``x``.

        The call passes ``inplace=True`` and uses the layer's FP8 weight and
        activation scales. ``a13_scale``/``a2_scale`` are None under dynamic
        activation scaling.
        """
        return fused_moe(x,
                         layer.w13_weight,
                         layer.w2_weight,
                         router_logits,
                         top_k,
                         renormalize=renormalize,
                         inplace=True,
                         use_fp8=True,
                         w1_scale=layer.w13_scale,
                         w2_scale=layer.w2_scale,
                         a1_scale=layer.a13_scale,
                         a2_scale=layer.a2_scale,
                         use_grouped_topk=use_grouped_topk,
                         num_expert_group=num_expert_group,
                         topk_group=topk_group)


class Fp8KVCacheMethod(BaseKVCacheMethod):
    """
    Supports loading kv-cache scaling factors from FP8 checkpoints.

    Args:
        quant_config: The FP8 quantization config, forwarded unchanged to
            BaseKVCacheMethod.
    """

    def __init__(self, quant_config: Fp8Config):
        super().__init__(quant_config)