Commit ab3e5a92 authored by yuguo


Merge commit '04c730c0' of https://github.com/NVIDIA/TransformerEngine
parents a8d19fd9 04c730c0
# Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# See LICENSE for license information.
from typing import Tuple
import torch
def scale_from_amax_tensor(
    x_dtype: torch.dtype,
    amax: torch.Tensor,
    quant_dtype: torch.dtype,
    *,
    eps: float,
    pow_2_scales: bool,
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
    """Derives quantization and dequantization scales from amax and options.

    Reference implementation for scale calculation.

    Returns:
    - scale: quantization scales
    - scale_inv: dequantization scales
    - amax: amax tensor with updates made for extrema values.
    """
    assert amax.dtype == torch.float, "amax must be a float tensor."
    fp8_max = torch.finfo(quant_dtype).max
    # Clamp amax to avoid division by small numbers
    amax = torch.max(amax, torch.tensor(eps))
    # Compute scale factor
    scale = torch.div(fp8_max, amax)
    # Note: frexp doesn't give back inf for the exponent with an inf input,
    # so inf is handled before the pow_2_scales branch.
    scale = torch.where(scale == torch.inf, torch.finfo(x_dtype).max, scale)
    if pow_2_scales:
        # Calculate rounded-down exponent
        _, exp = torch.frexp(scale)
        # Positive numbers are always returned as (mant, exp) with the
        # mantissa in [0.5, 1.0). Because a normal float has a mantissa with
        # hidden bit in [1.0, 2.0), the exponent will be off by exactly one
        # because of the shift. Subnormal and zero cases need not be considered
        # because the smallest possible result of fp8_max / amax is still normal.
        exp = exp - 1
        # No subnormals and zero.
        assert (exp > -127).all()
        unity = torch.tensor([1.0], device=exp.device)
        torch.ldexp(unity, exp, out=scale)
        # Case where amax is inf. The frexp/ldexp logic changes 0.0 scales;
        # return 0.0 for a 0.0 scale for consistency with the non-pow2 scale
        # calculation.
        scale = torch.where(amax == float("inf"), 0.0, scale)
    # Handle overflow cases for amax zero causing NaN
    scale = torch.where(amax == 0, 1.0, scale)
    # Compute scale_inv
    scale_inv = torch.reciprocal(scale)
    return scale, scale_inv, amax
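

# Minimal usage sketch (illustration only, not part of the committed reference
# file); torch.float8_e4m3fn is assumed to be available in the local PyTorch build.
if __name__ == "__main__":
    x = torch.randn(64, 64)
    amax = torch.amax(torch.abs(x)).view(1)
    scale, scale_inv, amax = scale_from_amax_tensor(
        torch.float32, amax, torch.float8_e4m3fn, eps=0.0, pow_2_scales=True
    )
    # With pow_2_scales=True the scale is rounded down to a power of two,
    # so its reciprocal is exact and scale * scale_inv is exactly 1.0.
    assert torch.all(scale * scale_inv == 1.0)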
@@ -6,63 +6,16 @@ import torch
import transformer_engine_torch as tex
from transformer_engine.pytorch.constants import TE_DType_To_Torch
# Compute scale and scale_inv from amax
def _ref_compute_scale_and_scale_inv_from_amax(amax, fp8_max, eps, pow_2_scales):
    # Clamping amax to avoid division by small numbers
    amax = torch.max(amax, torch.tensor(eps))
    # Compute scale factor
    scale = torch.div(fp8_max, amax)
    # Note frexp doesn't give back inf for exponent with an inf input
    # We take care of inf before pow_2_scales
    # option1: set scale to fp32 max when scale is inf
    scale = torch.where(scale == torch.inf, torch.finfo(torch.float32).max, scale)
    # option2: when scale is inf, set scale to 1
    scale = torch.where(scale == torch.inf, 1.0, scale)
    if pow_2_scales:
        # Calculate rounded down exponent
        _, exp = torch.frexp(scale)
        # Positive numbers are always returned as mant, exp with
        # a mantissa in [0.5, 1.0). Because a normal float has a mantissa with
        # hidden bit in [1.0, 2.0), the exponent will be off by exactly one because
        # of the shift. Subnormal and zero cases need not be considered because
        # the smallest possible result of fp8_max / amax is still normal.
        exp = exp - 1
        # No subnormals and zero.
        assert (exp > -127).all()
        # TODO: If/when adding a URM option an option is to cap to 126
        # rather than allowing the full range of FP32 (2 - 2^23) x 2^127
        # addresses cases where adding a mantissa overflows into inf scales.
        # Not necessary currently without additional scale smudging options.
        unity = torch.tensor([1.0], device=exp.device)
        torch.ldexp(unity, exp, out=scale)
        # Case where amax is inf. The frexp, ldexp logic changes 0.0 scales
        # Return 0.0 for 0.0 scale for consistency with non-pow2 scale
        # calculation.
        scale = torch.where(amax == float("inf"), 0.0, scale)
    # Handle overflow cases for amax zero causing NaN
    scale = torch.where(amax == 0, 1.0, scale)
    # Compute scale_inv
    scale_inv = torch.reciprocal(scale)
    return scale, scale_inv
from references.quantize_scale_calc import scale_from_amax_tensor
# compute amax and scale
def _ref_compute_amax_scale(x, quant_dtype, eps, pow_2_scales):
    x_fp32 = x.to(torch.float32)
    amax = torch.amax(torch.abs(x_fp32)).view(1)
    assert amax.dtype == torch.float, "amax must be a float tensor."
    fp8_max = torch.finfo(quant_dtype).max
    scale, scale_inv = _ref_compute_scale_and_scale_inv_from_amax(amax, fp8_max, eps, pow_2_scales)
    # Clamping amax to avoid division by small numbers
    amax = torch.max(amax, torch.tensor(eps))
    return scale, scale_inv, amax
    return scale_from_amax_tensor(
        torch.float32, amax, quant_dtype, eps=eps, pow_2_scales=pow_2_scales
    )
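

# Usage sketch (illustration only, not part of the committed file); names are
# taken from this file and torch.float8_e4m3fn is assumed to be available in
# the local PyTorch build:
#
#     s, s_inv, amax = _ref_compute_amax_scale(
#         torch.randn(32, 32), torch.float8_e4m3fn, eps=0.0, pow_2_scales=False
#     )
#     assert torch.allclose(s * amax, torch.full_like(s, torch.finfo(torch.float8_e4m3fn).max))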
def _multi_dim_transpose(tensor):
@@ -113,7 +66,3 @@ def ref_per_tensor_cs_cast(
    qx_t = _multi_dim_transpose(qx)
    sx_t = sx
    return qx, sx, qx_t, sx_t


def ref_compute_scale_and_scale_inv_from_amax(amax, fp8_max, eps, pow_2_scales):
    return _ref_compute_scale_and_scale_inv_from_amax(amax, fp8_max, eps, pow_2_scales)
@@ -30,6 +30,9 @@ if IS_HIP_EXTENSION:
# Check if FP8 is supported.
fp8_available, reason_for_no_fp8 = FP8GlobalStateManager.is_fp8_available()
fp8_block_scaling_available, reason_for_no_fp8_block_scaling = (
    FP8GlobalStateManager.is_fp8_block_scaling_available()
)
mxfp8_available, reason_for_no_mxfp8 = FP8GlobalStateManager.is_mxfp8_available()
@@ -58,6 +61,7 @@ fp8_recipes = [
    recipe.DelayedScaling(),
    recipe.MXFP8BlockScaling(),
    recipe.Float8CurrentScaling(),
    recipe.Float8BlockScaling(),
]
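
# Illustrative sketch (not part of this test file): recipes like those listed
# above are typically passed to te.fp8_autocast, so the new Float8BlockScaling
# entry is exercised the same way as the existing recipes.
#
#     import transformer_engine.pytorch as te
#     linear = te.Linear(128, 128)
#     inp = torch.randn(128, 128, device="cuda")
#     with te.fp8_autocast(enabled=True, fp8_recipe=recipe.Float8BlockScaling()):
#         out = linear(inp)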
# Supported data types
@@ -328,9 +332,13 @@ def test_make_graphed_callables(
pytest.skip("FP8 needed for FP8 parameters.")
if fp8_weight_caching and not fp8:
pytest.skip("FP8 needed for FP8 parameters.")
if fp8_recipe.float8_block_scaling() and not fp8_block_scaling_available:
pytest.skip(reason_for_no_fp8_block_scaling)
if fp8_recipe.mxfp8() and not mxfp8_available:
pytest.skip(reason_for_no_mxfp8)
if fp8_recipe.float8_block_scaling() and module == "linear_op":
pytest.skip("Module not yet supported for float8_block_scaling with CUDA graphs")
# Run model with different CUDA graph settings.
model_config = model_configs[model_config]
kwargs = dict(
@@ -82,7 +82,8 @@ class TestFP8RecipeLinearBase:
    @staticmethod
    def _get_mean_abs_relative_error(a, b):
        return torch.mean(torch.abs((a - b) / b))
        error = torch.where(b == 0, torch.ne(a, b), torch.abs((a - b) / b))
        return torch.mean(error)
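
    # Behavior sketch for the zero-denominator handling above (comment only):
    #   a = torch.tensor([1.0, 0.0, 2.0]); b = torch.tensor([1.0, 0.0, 0.0])
    #   torch.where(b == 0, torch.ne(a, b), torch.abs((a - b) / b)) -> [0., 0., 1.]
    # Positions where b == 0 contribute 0 when a matches b and 1 when it does not,
    # instead of the inf/nan produced by a direct division.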
    @staticmethod
    def _load_golden_tensor_values(a, b):
@@ -97,9 +98,12 @@ class TestFP8RecipeLinearBase:
        fp8_type_g = get_fp8_torch_dtype(recipe, fprop_tensor=False)

        # Expected tensor names based on the naming template
        scaling_type = (  # Assuming the scaling type is PER_TENSOR for this example
            "ScalingType.PER_TENSOR"
        )
        if recipe.float8_current_scaling():
            scaling_type = "ScalingType.PER_TENSOR"
        elif recipe.float8_block_scaling():
            scaling_type = "ScalingType.VECTOR_TILED_X_AND_G_BLOCK_TILED_W"
        else:
            scaling_type = "Unknown"
        current_seed = torch.initial_seed()  # Get the current seed
        expected_tensor_names = {
@@ -437,9 +441,13 @@ class TestFP8RecipeLayerNormLinearBase(TestFP8RecipeLinearBase):
        fp8_type_g = get_fp8_torch_dtype(recipe, fprop_tensor=False)

        # Expected tensor names based on the naming template
        scaling_type = (  # Assuming the scaling type is PER_TENSOR for this example
            "ScalingType.PER_TENSOR"
        )
        if recipe.float8_current_scaling():
            scaling_type = "ScalingType.PER_TENSOR"
        elif recipe.float8_block_scaling():
            scaling_type = "ScalingType.VECTOR_TILED_X_AND_G_BLOCK_TILED_W"
        else:
            scaling_type = "Unknown"
        current_seed = torch.initial_seed()  # Get the current seed
        expected_tensor_names = {
@@ -32,8 +32,8 @@ void act_fn(const NVTETensor input, NVTETensor output, cudaStream_t stream) {
  constexpr NVTETensor workspace = nullptr;
  constexpr const NVTETensor grad = nullptr;

  quantize_helper<IS_DBIAS, IS_DACT, IS_ACT, Empty, OP>(input, grad, nullptr, output, dbias,
                                                        workspace, stream);
  quantize_helper<IS_DBIAS, IS_DACT, IS_ACT, Empty, OP>(input, grad, output, dbias, workspace,
                                                        nullptr, stream);
}
template <typename ComputeType, typename Param, ComputeType (*OP)(ComputeType, const Param &)>
@@ -46,8 +46,8 @@ void dact_fn(const NVTETensor grad, const NVTETensor input, NVTETensor output,
  constexpr NVTETensor dbias = nullptr;
  constexpr NVTETensor workspace = nullptr;

  quantize_helper<IS_DBIAS, IS_DACT, IS_ACT, Empty, OP>(input, grad, nullptr, output, dbias,
                                                        workspace, stream);
  quantize_helper<IS_DBIAS, IS_DACT, IS_ACT, Empty, OP>(input, grad, output, dbias, workspace,
                                                        nullptr, stream);
}
template <typename ComputeType, typename Param, ComputeType (*ActOP)(ComputeType, const Param &)>