# Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved. # # See LICENSE for license information. from typing import Tuple import torch def scale_from_amax_tensor( x_dtype: torch.dtype, amax: torch.Tensor, quant_dtype: torch.dtype, *, eps: float, pow_2_scales: bool, ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]: """Derives quantization and dequantization from amax and options. Reference implementation for scale calculation. Returns: - scale: quantization scales - scale_inv: dequantization scales - amax: Amax tensor with updates made for extrema values. """ assert amax.dtype == torch.float, "amax must be a float tensor." if quant_dtype == torch.int8: fp8_max = torch.iinfo(quant_dtype).max else: fp8_max = torch.finfo(quant_dtype).max # Clamping amax to avoid division by small numbers amax = torch.max(amax, torch.tensor(eps)) # Compute scale factor scale = torch.div(fp8_max, amax) # Note frexp doesn't give back inf for exponent with an inf input # We take care of inf before pow_2_scales scale = torch.where(scale == torch.inf, torch.finfo(x_dtype).max, scale) if pow_2_scales: # Calculate rounded down exponent _, exp = torch.frexp(scale) # Positive numbers are always returned as mant, exp with # a mantissa in [0.5, 1.0). Because a normal float has a mantissa with # hidden bit in [1.0, 2.0), the exponent will be off by exactly one because # of the shift. Subnormal and zero cases need not be considered because # the smallest possible result of fp8_max / amax is still normal. exp = exp - 1 # No subnormals and zero. assert (exp > -127).all() unity = torch.tensor([1.0], device=exp.device) torch.ldexp(unity, exp, out=scale) # Case where amax is inf. The frexp, ldexp logic changes 0.0 scales # Return 0.0 for 0.0 scale for consistency with non-pow2 scale # calculation. scale = torch.where(amax == float("inf"), 0.0, scale) # Handle overflow cases for amax zero causing NaN scale = torch.where(amax == 0, 1.0, scale) # Compute scale_inv scale_inv = torch.reciprocal(scale) return scale, scale_inv, amax