# Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# See LICENSE for license information.

from typing import Tuple
import torch
from torch.utils.cpp_extension import IS_HIP_EXTENSION

def scale_from_amax_tensor(
    x_dtype: torch.dtype,
    amax: torch.Tensor,
    quant_dtype: torch.dtype,
    *,
    eps: float,
    pow_2_scales: bool,
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
    """Derive quantization and dequantization scales from amax and options.

    Reference implementation for scale calculation.

    Args:
        x_dtype: dtype of the tensor being quantized. Its largest finite value
            replaces any infinite scale (e.g. ``amax == 0`` with ``eps == 0``).
        amax: Absolute-maximum tensor; must have dtype ``torch.float``.
        quant_dtype: Target quantized dtype. Floating-point dtypes use
            ``torch.finfo(quant_dtype).max`` as the representable maximum;
            integer dtypes use ``torch.iinfo(quant_dtype).max``.
        eps: Lower bound applied to ``amax`` before dividing.
        pow_2_scales: If True, round every scale down to a power of two.

    Returns:
        - scale: quantization scales, same shape as ``amax``.
        - scale_inv: dequantization scales, ``1 / scale``.
        - amax: ``amax`` after being clamped to at least ``eps``.

    Raises:
        AssertionError: if ``amax`` is not a float32 tensor, or if a
            power-of-two exponent falls outside the normal range.
    """
    assert amax.dtype == torch.float, "amax must be a float tensor."

    # Largest representable magnitude of the target format: finfo for any
    # floating dtype (fp8, fp16, ...), iinfo for any integer dtype (int8, ...).
    if quant_dtype.is_floating_point:
        quant_max = torch.finfo(quant_dtype).max
    else:
        quant_max = torch.iinfo(quant_dtype).max

    # Clamping amax to avoid division by small numbers
    amax = torch.max(amax, torch.tensor(eps))

    # Compute scale factor
    scale = torch.div(quant_max, amax)
    # frexp does not give back inf for the exponent of an inf input, so
    # replace inf scales with the largest finite x_dtype value first.
    scale = torch.where(scale == torch.inf, torch.finfo(x_dtype).max, scale)
    if pow_2_scales:
        # Calculate rounded-down exponent. For positive inputs frexp returns
        # a mantissa in [0.5, 1.0), whereas a normal float's hidden-bit
        # mantissa lies in [1.0, 2.0), so the exponent is off by exactly one.
        # Subnormal and zero cases need not be considered because the
        # smallest possible result of quant_max / amax is still normal.
        _, exp = torch.frexp(scale)
        exp = exp - 1
        # No subnormals and zero.
        assert (exp > -127).all()
        # ones_like keeps scale's shape/dtype/device, so a 0-dim amax yields
        # a 0-dim scale instead of being broadcast/resized to shape [1].
        unity = torch.ones_like(scale)
        if IS_HIP_EXTENSION:
            # NOTE(review): ldexp is evaluated on the host for HIP builds;
            # the reason is not visible here (presumably a device ldexp
            # issue on ROCm) — confirm before removing.
            scale = torch.ldexp(unity.cpu(), exp.cpu()).to(scale.device)
        else:
            scale = torch.ldexp(unity, exp)
        # Case where amax is inf: quant_max / inf == 0.0, which the
        # frexp/ldexp round trip turns into 0.5. Return 0.0 for consistency
        # with the non-pow2 scale calculation.
        scale = torch.where(amax == float("inf"), 0.0, scale)

    # amax == 0 (reachable only when eps <= 0) made the division produce inf,
    # which was mapped to a large finite value above; use a unit scale.
    scale = torch.where(amax == 0, 1.0, scale)

    # Compute scale_inv
    scale_inv = torch.reciprocal(scale)

    return scale, scale_inv, amax