fp8.py 25.3 KB
Newer Older
1
from typing import Any, Dict, List, Optional, Union
2
3
4
5
6

import torch
from torch.nn import Module
from torch.nn.parameter import Parameter

7
from vllm import _custom_ops as ops
8
from vllm.logger import init_logger
9
10
from vllm.model_executor.layers.fused_moe import (FusedMoE, FusedMoEMethodBase,
                                                  fused_moe)
11
from vllm.model_executor.layers.linear import LinearBase, LinearMethodBase
12
from vllm.model_executor.layers.quantization.base_config import (
13
    QuantizationConfig, QuantizeMethodBase)
14
15
16
17
18
from vllm.model_executor.layers.quantization.gptq_marlin import (
    GPTQ_MARLIN_MAX_PARALLEL, GPTQ_MARLIN_MIN_THREAD_N, GPTQMarlinState,
    marlin_permute_scales)
from vllm.model_executor.layers.quantization.utils.marlin_utils import (
    pack_fp8_to_int32)
19
from vllm.model_executor.utils import set_weight_attrs
20
21
from vllm.platforms import current_platform
from vllm.utils import print_warning_once
22

23
24
25
26
logger = init_logger(__name__)

# Activation-scale schemes accepted by Fp8Config: with "static" the input
# scales are loaded from the checkpoint; with "dynamic" they are computed
# from the activations at runtime.
ACTIVATION_SCHEMES = ["static", "dynamic"]

27

28
def cutlass_fp8_supported() -> bool:
    """Return whether the CUTLASS scaled-mm kernels support FP8 on this GPU.

    The device compute capability (major, minor) is folded into a single
    integer (e.g. (8, 9) -> 89) because that is the form passed to
    ``ops.cutlass_scaled_mm_supports_fp8``.
    """
    capability = current_platform.get_device_capability()
    capability = capability[0] * 10 + capability[1]
    return ops.cutlass_scaled_mm_supports_fp8(capability)
33
34


35
class Fp8Config(QuantizationConfig):
    """Config class for FP8.

    Args:
        is_checkpoint_fp8_serialized: True when the checkpoint already stores
            float8_e4m3fn weights together with their scales; False when
            higher-precision weights are quantized after loading.
        activation_scheme: One of ``ACTIVATION_SCHEMES`` ("static" or
            "dynamic").

    Raises:
        ValueError: if ``activation_scheme`` is not supported.
    """

    def __init__(
        self,
        is_checkpoint_fp8_serialized: bool = False,
        activation_scheme: str = "dynamic",
    ) -> None:
        self.is_checkpoint_fp8_serialized = is_checkpoint_fp8_serialized
        if is_checkpoint_fp8_serialized:
            logger.warning("Detected fp8 checkpoint. Please note that the "
                           "format is experimental and subject to change.")
        if activation_scheme not in ACTIVATION_SCHEMES:
            raise ValueError(
                f"Unsupported activation scheme {activation_scheme}")
        self.activation_scheme = activation_scheme

    @classmethod
    def get_name(cls) -> str:
        return "fp8"

    @classmethod
    def get_supported_act_dtypes(cls) -> List[torch.dtype]:
        return [torch.bfloat16, torch.half]

    @classmethod
    def get_min_capability(cls) -> int:
        # This file encodes capabilities as major * 10 + minor, so 80
        # presumably means SM 8.0. Devices below SM 8.9 are served by the
        # Marlin weight-only path in Fp8LinearMethod.
        return 80

    @classmethod
    def get_config_filenames(cls) -> List[str]:
        return []

    @classmethod
    def from_config(cls, config: Dict[str, Any]) -> "Fp8Config":
        """Build the config from a checkpoint's quantization config dict.

        The checkpoint is treated as FP8-serialized whenever "fp8" occurs as
        a substring of its ``quant_method`` entry.
        """
        quant_method = cls.get_from_keys(config, ["quant_method"])
        is_checkpoint_fp8_serialized = ("fp8" in quant_method)
        activation_scheme = cls.get_from_keys(config, ["activation_scheme"])
        return cls(is_checkpoint_fp8_serialized=is_checkpoint_fp8_serialized,
                   activation_scheme=activation_scheme)

    def get_quant_method(
            self, layer: torch.nn.Module) -> Optional["QuantizeMethodBase"]:
        """Return the FP8 method for ``layer``.

        Returns None when the layer is not a linear, fused-MoE or attention
        layer.
        """
        from vllm.attention.layer import Attention  # Avoid circular import

        if isinstance(layer, LinearBase):
            return Fp8LinearMethod(self)
        elif isinstance(layer, FusedMoE):
            return Fp8MoEMethod(self)
        elif isinstance(layer, Attention):
            return Fp8KVCacheMethod(self)
        return None

    def get_scaled_act_names(self) -> List[str]:
        return []


class Fp8LinearMethod(LinearMethodBase):
    """Linear method for FP8.

    Supports loading FP8 checkpoints with static weight scale and
    dynamic/static activation scale.

    Also supports loading unquantized FP16/BF16 model checkpoints. These are
    quantized to FP8 after loading and always use dynamic activation scaling.
    The weight scaling factor will be initialized after the model weights are
    loaded.

    On GPUs whose compute capability is below 8.9, the layer runs FP8
    weight-only through the Marlin kernel. Activations are passed to the
    kernel unquantized.

    Limitations:
    1. Only support per-tensor quantization due to torch._scaled_mm support.
    2. Only support float8_e4m3fn data type due to the limitation of
       torch._scaled_mm (https://github.com/pytorch/pytorch/blob/2e48b39603411a41c5025efbe52f89560b827825/aten/src/ATen/native/cuda/Blas.cpp#L854-L856)

    Args:
        quant_config: The quantization config.
    """

    def __init__(self, quant_config: Fp8Config):
        self.quant_config = quant_config
        self.cutlass_fp8_supported = cutlass_fp8_supported()

        # For GPUs that lack FP8 hardware support, we can leverage the Marlin
        # kernel for fast weight-only FP8 quantization. The capability is
        # encoded as major * 10 + minor, so this selects devices below SM 8.9.
        capability = current_platform.get_device_capability()
        capability = capability[0] * 10 + capability[1]
        self.use_marlin = capability < 89

    def _create_scale_param(
        self,
        scale_name: str,
        layer: torch.nn.Module,
        output_partition_sizes: List[int],
        **extra_weight_attrs,
    ) -> None:
        """Register a float32 scale with one entry per logical output
        partition.

        Entries are pre-filled with the most negative finite float8_e4m3fn
        value. process_weights_after_loading uses this sentinel to detect
        entries that were never loaded from the checkpoint.
        """
        scale = Parameter(torch.empty(len(output_partition_sizes),
                                      dtype=torch.float32),
                          requires_grad=False)
        scale[:] = torch.finfo(torch.float8_e4m3fn).min
        layer.register_parameter(scale_name, scale)
        set_weight_attrs(scale, {
            **extra_weight_attrs,
            "needs_scalar_to_array": True,
        })

    def create_weights(
        self,
        layer: torch.nn.Module,
        input_size_per_partition: int,
        output_partition_sizes: List[int],
        input_size: int,
        output_size: int,
        params_dtype: torch.dtype,
        **extra_weight_attrs,
    ):
        """Register the weight and, for FP8 checkpoints, its scales.

        The weight has shape (output_size_per_partition,
        input_size_per_partition). Its dtype is float8_e4m3fn for serialized
        FP8 checkpoints and ``params_dtype`` otherwise. ``input_size`` and
        ``output_size`` are unused here.
        """
        del input_size, output_size
        output_size_per_partition = sum(output_partition_sizes)

        layer.process_after_load = True
        layer.logical_widths = output_partition_sizes

        layer.input_size_per_partition = input_size_per_partition
        layer.output_size_per_partition = output_size_per_partition
        layer.orig_dtype = params_dtype

        # WEIGHT
        weight_dtype = (torch.float8_e4m3fn
                        if self.quant_config.is_checkpoint_fp8_serialized else
                        params_dtype)
        weight = Parameter(torch.empty(output_size_per_partition,
                                       input_size_per_partition,
                                       dtype=weight_dtype),
                           requires_grad=False)
        layer.register_parameter("weight", weight)
        set_weight_attrs(weight, {
            **extra_weight_attrs,
            "input_dim": 1,
            "output_dim": 0,
        })

        # If checkpoint is serialized fp8, load them.
        # Otherwise, wait until process_weights_after_loading.
        if self.quant_config.is_checkpoint_fp8_serialized:
            # WEIGHT SCALE
            self._create_scale_param(
                scale_name="weight_scale",
                layer=layer,
                output_partition_sizes=output_partition_sizes,
                **extra_weight_attrs)

            # INPUT ACTIVATION SCALE
            if self.quant_config.activation_scheme == "static":
                self._create_scale_param(
                    scale_name="input_scale",
                    layer=layer,
                    output_partition_sizes=output_partition_sizes,
                    **extra_weight_attrs)

        # For GPUs without FP8 hardware support, we use Marlin for fast
        # fused dequantization
        if self.use_marlin:
            layer.marlin_state = GPTQMarlinState.REPACK

    def prepare_layer_for_marlin(self, layer: Module) -> None:
        """Repack the FP8 weight and its per-tensor scale into Marlin's
        layout and allocate the Marlin workspace.

        Both callers have already stored ``layer.weight`` transposed.
        ``layer.marlin_state`` must be REPACK on entry (asserted); it is set
        to READY, so the repack happens at most once per layer.
        """
        print_warning_once(
            "Your GPU does not have native support for FP8 computation but "
            "FP8 quantization is being used. Weight-only FP8 compression will "
            "be used leveraging the Marlin kernel. This may degrade "
            "performance for compute-heavy workloads.")

        part_size_n = layer.output_size_per_partition
        part_size_k = layer.input_size_per_partition

        assert layer.marlin_state == GPTQMarlinState.REPACK
        layer.marlin_state = GPTQMarlinState.READY

        device = layer.weight.device

        # WEIGHTS
        # Repack weights to gptq format (packed int32 elements)
        packed_gptq_qweight = pack_fp8_to_int32(layer.weight)

        # Repack weights to marlin format
        marlin_qweight = ops.gptq_marlin_repack(
            b_q_weight=packed_gptq_qweight,
            perm=torch.empty(0, dtype=torch.int, device=device),
            size_k=part_size_k,
            size_n=part_size_n,
            num_bits=8,
        )
        layer.weight = Parameter(marlin_qweight, requires_grad=False)

        # WEIGHT SCALES
        # Currently Marlin doesn't support per-tensor scales, so we
        # expand it to channelwise
        scales = layer.weight_scale.repeat(1, part_size_n).to(
            layer.orig_dtype).to(device)
        # Permute scales
        marlin_scales = marlin_permute_scales(
            s=scales,
            size_k=part_size_k,
            size_n=part_size_n,
            group_size=-1,
            num_bits=8,
        )
        layer.weight_scale = Parameter(marlin_scales, requires_grad=False)

        # Allocate marlin workspace
        max_workspace_size = (
            part_size_n // GPTQ_MARLIN_MIN_THREAD_N) * GPTQ_MARLIN_MAX_PARALLEL
        workspace = torch.zeros(max_workspace_size,
                                dtype=torch.int,
                                device=device,
                                requires_grad=False)

        layer.workspace = workspace

    def process_weights_after_loading(self, layer: Module) -> None:
        """Finalize the weight and scales into the layout ``apply`` expects.

        * FP16/BF16 checkpoint: quantize the whole weight to FP8, letting
          ``ops.scaled_fp8_quant`` compute a single per-tensor scale
          (``scale=None``).
        * FP8 checkpoint: requantize every logical shard to the shared
          (maximum) weight scale, and reduce a static input scale to its
          maximum.

        In both cases the weight is stored transposed. On Marlin GPUs it is
        then repacked for the Marlin kernel. This is a no-op unless
        ``layer.process_after_load`` is set and truthy.
        """
        if (not hasattr(layer, "process_after_load")
                or not layer.process_after_load):
            return

        # If checkpoint is fp/bf16 (not serialized fp8), quantize the weights.
        if not self.quant_config.is_checkpoint_fp8_serialized:
            qweight, weight_scale = ops.scaled_fp8_quant(layer.weight,
                                                         scale=None)
            layer.weight = Parameter(qweight.t(), requires_grad=False)
            layer.weight_scale = Parameter(weight_scale, requires_grad=False)
            layer.logical_widths = None
            layer.input_scale = None

            if self.use_marlin:
                self.prepare_layer_for_marlin(layer)
            return

        # Checkpoint is fp8: requantize the separately quantized logical
        # weights into a single fp8 weight with a single weight scale.

        # WEIGHT_SCALE / WEIGHT
        #   Loop over logical weights, requantizing with single scale.
        max_w_scale = layer.weight_scale.max()

        # QKV / MLP is fused in the on disk checkpoint if any of the
        # weight scales are still set to the default since we initialize
        # N weight scales for N shards but we only load 1 weight scale
        # from disk in this case. As a result, we skip dequant -> requant
        # since we already have quantized QKV together.
        # Sample Model with fused checkpoint:
        #   * nm-testing/Phi-3-mini-128k-instruct-FP8
        unfused_module_in_checkpoint = (
            layer.weight_scale[-1] > torch.finfo(torch.float8_e4m3fn).min)

        if unfused_module_in_checkpoint:
            start = 0
            for idx, logical_width in enumerate(layer.logical_widths):
                end = start + logical_width
                weight_dq = per_tensor_dequantize(
                    layer.weight[start:end, :], layer.weight_scale[idx])
                # weight_scale is not modified inside this loop, so reusing
                # the precomputed max is equivalent and avoids one reduction
                # per shard.
                layer.weight[start:end, :] = per_tensor_quantize(
                    weight_dq, max_w_scale)
                start = end
        layer.weight_scale = Parameter(max_w_scale, requires_grad=False)

        # WEIGHT
        #   Transpose weight for passing to torch._scaled_mm
        weight = layer.weight
        layer.weight = Parameter(weight.t(), requires_grad=False)

        # INPUT ACTIVATION SCALE
        #   Dynamic: set to None (required input to ops.scaled_fp8_quant).
        #   Static:  set to max of the input_scales (since they are equal).
        if self.quant_config.activation_scheme == "dynamic":
            layer.input_scale = None
        elif self.quant_config.activation_scheme == "static":
            layer.input_scale = Parameter(layer.input_scale.max(),
                                          requires_grad=False)
        else:
            raise ValueError(
                f"Unknown scheme {self.quant_config.activation_scheme}")

        if self.use_marlin:
            self.prepare_layer_for_marlin(layer)

    def apply(self,
              layer: torch.nn.Module,
              x: torch.Tensor,
              bias: Optional[torch.Tensor] = None) -> torch.Tensor:
        """Run the FP8 linear layer on ``x`` and add ``bias`` if given.

        Dispatches to one of three kernels:

        * Marlin (weight-only) on GPUs without FP8 support.
        * CUTLASS when there is no bias and the device supports it.
        * ``torch._scaled_mm`` otherwise.
        """
        if self.use_marlin:
            # For GPUs that lack FP8 hardware support, we can leverage the
            # Marlin kernel for fast weight-only FP8 quantization
            reshaped_x = x.reshape(-1, x.shape[-1])
            out_shape = x.shape[:-1] + (layer.output_size_per_partition, )

            output = ops.fp8_marlin_gemm(
                a=reshaped_x,
                b_q_weight=layer.weight,
                b_scales=layer.weight_scale,
                workspace=layer.workspace,
                num_bits=8,
                size_m=reshaped_x.shape[0],
                size_n=layer.output_size_per_partition,
                size_k=layer.input_size_per_partition,
            )

            if bias is not None:
                output.add_(bias)  # In-place add

            return output.reshape(out_shape)

        # ops.scaled_fp8_quant supports both dynamic and static quant.
        # If dynamic, layer.input_scale is None and x_scale computed from x
        # If static, layer.input_scale is scalar and x_scale is input_scale
        if bias is None and self.cutlass_fp8_supported:
            qinput, x_scale = ops.scaled_fp8_quant(x, layer.input_scale)

            # Fused GEMM_DQ
            output = ops.cutlass_scaled_mm(
                qinput,
                layer.weight,
                out_dtype=x.dtype,
                scale_a=x_scale,
                scale_b=layer.weight_scale,
            )
        else:
            qinput, x_scale = ops.scaled_fp8_quant(x,
                                                   layer.input_scale,
                                                   batch_dim_padding=17)

            # Fused GEMM_DQ -- note we padded the input above because
            # torch._scaled_mm is more performant for matrices with
            # batch dimension > 16. Note that this could change
            # in the future.
            output, _ = torch._scaled_mm(
                qinput,
                layer.weight,
                out_dtype=x.dtype,
                scale_a=x_scale,
                scale_b=layer.weight_scale,
                bias=bias,
            )

        # Drop any rows added by the batch-dimension padding above.
        return torch.narrow(output, 0, 0, x.shape[0])
380
381


382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
495
496
497
498
499
500
501
502
503
504
505
506
507
508
509
510
511
512
513
514
515
516
517
518
519
520
521
522
523
524
525
526
527
528
529
530
531
532
533
534
535
536
537
538
539
540
541
542
543
544
545
546
547
548
549
550
551
552
553
554
555
556
557
558
559
560
561
562
class Fp8MoEMethod(FusedMoEMethodBase):
    """MoE method for FP8.
    Supports loading FP8 checkpoints with static weight scale and
    dynamic/static activation scale.

    Also supports loading unquantized FP16/BF16 model checkpoints with dynamic
    activation scaling. Each expert's weights are quantized to FP8 after
    loading, and the weight scaling factors are initialized at that point.
    A "static" activation scheme is rejected for such checkpoints.

    Args:
        quant_config: The quantization config.
    """

    def __init__(self, quant_config: Fp8Config):
        self.quant_config = quant_config

    def create_weights(self, layer: Module, num_experts: int, hidden_size: int,
                       intermediate_size: int, params_dtype: torch.dtype,
                       **extra_weight_attrs):
        """Register expert weights, weight scales and (static) input scales.

        Registered on ``layer``:
          * w13_weight: (num_experts, 2 * intermediate_size, hidden_size).
            Holds the w1 and w3 shards stacked along dim 1.
          * w2_weight: (num_experts, hidden_size, intermediate_size).
          * w13_scale: (num_experts, 2) float32, one scale per w1/w3 shard.
          * w2_scale: (num_experts,) float32.
          * a13_scale / a2_scale: (num_experts,) float32 when the activation
            scheme is "static"; otherwise both attributes are set to None.

        Weights are float8_e4m3fn for FP8-serialized checkpoints and
        ``params_dtype`` otherwise.

        Raises:
            ValueError: if the activation scheme is "static" but the
                checkpoint is not FP8-serialized.
        """

        layer.process_after_load = True

        if self.quant_config.is_checkpoint_fp8_serialized:
            params_dtype = torch.float8_e4m3fn

        # WEIGHTS
        w13_weight = torch.nn.Parameter(torch.empty(num_experts,
                                                    2 * intermediate_size,
                                                    hidden_size,
                                                    dtype=params_dtype),
                                        requires_grad=False)
        layer.register_parameter("w13_weight", w13_weight)
        set_weight_attrs(w13_weight, extra_weight_attrs)

        w2_weight = torch.nn.Parameter(torch.empty(num_experts,
                                                   hidden_size,
                                                   intermediate_size,
                                                   dtype=params_dtype),
                                       requires_grad=False)
        layer.register_parameter("w2_weight", w2_weight)
        set_weight_attrs(w2_weight, extra_weight_attrs)

        # WEIGHT_SCALES
        # Allocate 2 scales for w1 and w3 respectively.
        # They will be combined to a single scale after weight loading.
        w13_scale = torch.nn.Parameter(torch.ones(num_experts,
                                                  2,
                                                  dtype=torch.float32),
                                       requires_grad=False)
        layer.register_parameter("w13_scale", w13_scale)

        w2_scale = torch.nn.Parameter(torch.ones(num_experts,
                                                 dtype=torch.float32),
                                      requires_grad=False)
        layer.register_parameter("w2_scale", w2_scale)

        # If loading fp8 checkpoint, pass the weight loaders.
        # If loading an fp16 checkpoint, do not (we will quantize in
        #   process_weights_after_loading()
        if self.quant_config.is_checkpoint_fp8_serialized:
            set_weight_attrs(w13_scale, extra_weight_attrs)
            set_weight_attrs(w2_scale, extra_weight_attrs)

        # INPUT_SCALES
        if self.quant_config.activation_scheme == "static":
            # Static input scales only exist in fp8-serialized checkpoints.
            if not self.quant_config.is_checkpoint_fp8_serialized:
                raise ValueError(
                    "Found static activation scheme for checkpoint that "
                    "was not serialized fp8.")

            a13_scale = torch.nn.Parameter(torch.ones(num_experts,
                                                      dtype=torch.float32),
                                           requires_grad=False)
            layer.register_parameter("a13_scale", a13_scale)
            set_weight_attrs(a13_scale, extra_weight_attrs)

            a2_scale = torch.nn.Parameter(torch.ones(num_experts,
                                                     dtype=torch.float32),
                                          requires_grad=False)
            layer.register_parameter("a2_scale", a2_scale)
            set_weight_attrs(a2_scale, extra_weight_attrs)
        else:
            # Dynamic scheme: None is forwarded to fused_moe in apply().
            layer.a13_scale = None
            layer.a2_scale = None

    def process_weights_after_loading(self, layer: Module) -> None:
        """Convert loaded expert weights into the form the fused FP8 MoE
        kernel expects.

        * FP16/BF16 checkpoint: quantize each expert's w13 and w2 weights to
          float8_e4m3fn with one scale per expert. w13_scale is replaced by
          a (num_experts,) tensor.
        * FP8 checkpoint: reduce static input scales to a single scalar each
          (max over experts). Then requantize each expert's two w13 shards
          to a shared per-expert scale (the max of the two shard scales).

        This is a no-op unless ``layer.process_after_load`` is set and
        truthy.
        """
        if (not hasattr(layer, "process_after_load")
                or not layer.process_after_load):
            return

        # If checkpoint is not fp8-serialized, quantize each expert into
        # freshly allocated fp8 tensors (the loaded weights are replaced).
        if not self.quant_config.is_checkpoint_fp8_serialized:
            w13_weight = torch.empty_like(layer.w13_weight.data,
                                          dtype=torch.float8_e4m3fn)
            w2_weight = torch.empty_like(layer.w2_weight.data,
                                         dtype=torch.float8_e4m3fn)

            # Re-initialize w13_scale because we directly quantize
            # merged w13 weights and generate a single scaling factor.
            layer.w13_scale = torch.nn.Parameter(torch.ones(
                layer.num_experts,
                dtype=torch.float32,
                device=w13_weight.device),
                                                 requires_grad=False)
            # No scale is passed, so scaled_fp8_quant computes one per
            # expert from that expert's weight.
            for expert in range(layer.num_experts):
                w13_weight[expert, :, :], layer.w13_scale[
                    expert] = ops.scaled_fp8_quant(
                        layer.w13_weight.data[expert, :, :])
                w2_weight[expert, :, :], layer.w2_scale[
                    expert] = ops.scaled_fp8_quant(
                        layer.w2_weight.data[expert, :, :])
            layer.w13_weight = torch.nn.Parameter(w13_weight,
                                                  requires_grad=False)
            layer.w2_weight = torch.nn.Parameter(w2_weight,
                                                 requires_grad=False)
            return

        # If checkpoint is fp8, we need to handle that the
        # MoE kernels require single activation scale and single weight
        # scale for w13 per expert.
        else:
            # Fp8 moe kernels require a single activation scale.
            # We take the max of all the scales in case they differ.
            if self.quant_config.activation_scheme == "static":
                if layer.a13_scale is None or layer.a2_scale is None:
                    raise ValueError(
                        "QuantConfig has static quantization, but found "
                        "activation scales are None.")
                if (not all_close_1d(layer.a13_scale)
                        or not all_close_1d(layer.a2_scale)):
                    print_warning_once(
                        "Found input_scales that are not equal for "
                        "fp8 MoE layer. Using the maximum across experts "
                        "for each layer. ")
                layer.a13_scale = torch.nn.Parameter(layer.a13_scale.max(),
                                                     requires_grad=False)
                layer.a2_scale = torch.nn.Parameter(layer.a2_scale.max(),
                                                    requires_grad=False)

            # Fp8 moe kernel needs single weight scale for w13 per expert.
            # We take the max then dequant and requant each expert.
            assert layer.w13_scale is not None
            # Each w13 shard (w1, then w3) spans shard_size rows of
            # w13_weight[expert_id].
            shard_size = layer.intermediate_size_per_partition
            max_w13_scales = layer.w13_scale.max(dim=1).values
            for expert_id in range(layer.num_experts):
                start = 0
                for shard_id in range(2):
                    dq_weight = per_tensor_dequantize(
                        layer.w13_weight[expert_id][start:start +
                                                    shard_size, :],
                        layer.w13_scale[expert_id][shard_id])
                    layer.w13_weight[expert_id][
                        start:start + shard_size, :] = per_tensor_quantize(
                            dq_weight, max_w13_scales[expert_id])
                    start += shard_size

            layer.w13_scale = torch.nn.Parameter(max_w13_scales,
                                                 requires_grad=False)
            return

    def apply(self,
              layer: torch.nn.Module,
              x: torch.Tensor,
              router_logits: torch.Tensor,
              top_k: int,
              renormalize: bool = True) -> torch.Tensor:
        """Run the fused FP8 MoE kernel on ``x``, routed by
        ``router_logits``.

        Forwards the per-expert weight scales and the (possibly None)
        activation scales to ``fused_moe``. ``inplace=True`` is passed.
        NOTE(review): whether that overwrites ``x`` depends on fused_moe,
        which is not visible here.
        """

        return fused_moe(x,
                         layer.w13_weight,
                         layer.w2_weight,
                         router_logits,
                         top_k,
                         renormalize=renormalize,
                         inplace=True,
                         use_fp8=True,
                         w1_scale=layer.w13_scale,
                         w2_scale=layer.w2_scale,
                         a1_scale=layer.a13_scale,
                         a2_scale=layer.a2_scale)


563
564
565
566
567
568
569
570
class Fp8KVCacheMethod(QuantizeMethodBase):
    """Supports loading kv-cache scaling factors from FP8 checkpoints.
    """

    def __init__(self, quant_config: Fp8Config):
        self.quant_config = quant_config

    def create_weights(self, layer: torch.nn.Module):
        """Create "weight" (aka kv_scale) for an attention layer.

        Args:
            layer: The layer that is using the QuantizeMethodBase factory.
        """
        # Initialize the KV cache scale to 1.0 as the default value.
        # If the kv_scale appears in the checkpoint, it will be
        # overwritten when loading weights.
        layer.kv_scale = Parameter(torch.tensor(1.0), requires_grad=False)

    def apply(self, layer: torch.nn.Module) -> torch.Tensor:
        raise RuntimeError("Fp8KVCacheMethod.apply should not be called.")

    def process_weights_after_loading(self, layer: Module) -> None:
        """Copy the loaded kv_scale into ``layer._kv_scale`` as a Python
        float, then delete the ``kv_scale`` parameter.

        When ``layer.kv_cache_dtype`` is "auto", the loaded value is ignored
        and ``layer._kv_scale`` is not touched here. NOTE(review): it is
        presumably 1.0, set by the attention layer; confirm.

        Raises:
            ValueError: if ``tolist()`` of the loaded kv_scale does not yield
                a single float, e.g. for a non-0-dim tensor.
        """
        if layer.kv_cache_dtype != "auto":
            kv_scale = layer.kv_scale.to("cpu").tolist()
            if not isinstance(kv_scale, float):
                raise ValueError("Only support per-tensor scaling factor "
                                 "for fp8 KV cache")
            layer._kv_scale = kv_scale
            # A scale of exactly 1.0 usually means the checkpoint provided
            # none (1.0 is the default from create_weights). e5m2 dtypes are
            # exempt from the warning.
            if layer._kv_scale == 1.0 and "e5m2" not in layer.kv_cache_dtype:
                print_warning_once(
                    "Using KV cache scaling factor 1.0 for fp8_e4m3. This may "
                    "cause accuracy issues. Please make sure kv-cache scaling "
                    "factor is available in the fp8 checkpoint.")
        del layer.kv_scale


601
def per_tensor_quantize(tensor: torch.Tensor,
602
                        inv_scale: Union[float, torch.Tensor]) -> torch.Tensor:
603
604
605
606
607
    finfo = torch.finfo(torch.float8_e4m3fn)
    qweight = (tensor / inv_scale).clamp(min=finfo.min, max=finfo.max)
    return qweight.to(torch.float8_e4m3fn)


608
609
610
def per_tensor_dequantize(
        tensor: torch.Tensor, inv_scale: Union[float,
                                               torch.Tensor]) -> torch.Tensor:
611
612
613
    fake_qweight = tensor.to(torch.float16)
    dq_weight = fake_qweight * inv_scale
    return dq_weight
614
615
616
617
618


def all_close_1d(x: torch.Tensor) -> bool:
    """Return True when every element of the 1-D tensor ``x`` is
    ``torch.allclose`` to its first element.

    An empty tensor counts as all-close. Non-1-D input fails an assertion.
    """
    assert x.ndim == 1
    if x.numel() == 0:
        return True
    reference = x[0]
    for element in x:
        if not torch.allclose(reference, element):
            return False
    return True