Unverified Commit 7acb5e2b authored by Jinze Xue, committed by GitHub

Handle the scaling factor when amax is so tiny that it leads to an infinite scale (#786)



* Handle the scaling factor when amax is so tiny that it leads to an infinite scale
Signed-off-by: Jinze Xue <jinzex@nvidia.com>

* revert formatting changes
Signed-off-by: Jinze Xue <jinzex@nvidia.com>

* fix comments
Signed-off-by: Jinze Xue <jinzex@nvidia.com>

* Apply review suggestion
Co-authored-by: Tim Moon <4406448+timmoon10@users.noreply.github.com>
Signed-off-by: Jinze Xue <155670984+jinzex@users.noreply.github.com>

* Apply review suggestion
Co-authored-by: Tim Moon <4406448+timmoon10@users.noreply.github.com>
Signed-off-by: Jinze Xue <155670984+jinzex@users.noreply.github.com>

* Apply review suggestion
Co-authored-by: Tim Moon <4406448+timmoon10@users.noreply.github.com>
Signed-off-by: Jinze Xue <155670984+jinzex@users.noreply.github.com>

* apply review suggestion
Signed-off-by: Jinze Xue <jinzex@nvidia.com>

* add test_recipe.py to qa/L0_pytorch_unittest/test.sh; fix unittest for is_first_microbatch=False
Signed-off-by: Jinze Xue <jinzex@nvidia.com>

* revert changes to update_weight_scale_inv
Signed-off-by: Jinze Xue <jinzex@nvidia.com>

* Debug test failures
Signed-off-by: Tim Moon <tmoon@nvidia.com>

---------
Signed-off-by: Jinze Xue <jinzex@nvidia.com>
Signed-off-by: Jinze Xue <155670984+jinzex@users.noreply.github.com>
Signed-off-by: Tim Moon <tmoon@nvidia.com>
Co-authored-by: Jinze Xue <jinzex@nvidia.com>
Co-authored-by: Tim Moon <4406448+timmoon10@users.noreply.github.com>
Co-authored-by: Tim Moon <tmoon@nvidia.com>
parent a8178684
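
For context, the failure mode and the clamp introduced here can be reproduced with a short PyTorch sketch (E4M3's fp8_max of 448 is assumed; the snippet is illustrative and not part of the diff):

import torch

fp8_max = 448.0  # FP8 E4M3 max representable value
fp32_max = torch.finfo(torch.float32).max

# An amax below fp8_max / fp32_max makes the naive scale overflow in FP32.
amax = torch.tensor([fp8_max / fp32_max / 2])
scale = fp8_max / amax
print(scale)  # tensor([inf])

# The fix clamps an infinite scale to the FP32 maximum so it stays finite and usable.
scale = torch.where(torch.isinf(scale), torch.full_like(scale, fp32_max), scale)
print(scale)  # tensor([3.4028e+38])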
......@@ -8,6 +8,7 @@ set -e
pip install pytest==7.2 onnxruntime==1.13.1
pytest -v -s $TE_PATH/tests/pytorch/test_sanity.py
pytest -v -s $TE_PATH/tests/pytorch/test_recipe.py
pytest -v -s $TE_PATH/tests/pytorch/test_deferred_init.py
PYTORCH_JIT=0 NVTE_TORCH_COMPILE=0 NVTE_ALLOW_NONDETERMINISTIC_ALGO=0 pytest -v -s $TE_PATH/tests/pytorch/test_numerics.py
PYTORCH_JIT=0 NVTE_TORCH_COMPILE=0 NVTE_ALLOW_NONDETERMINISTIC_ALGO=0 pytest -v -s $TE_PATH/tests/pytorch/test_cuda_graphs.py
......
......@@ -9,9 +9,10 @@ import torch
import transformer_engine.common.recipe
import transformer_engine.pytorch as te
import transformer_engine_extensions as tex
from transformer_engine.pytorch.fp8 import (
    FP8GlobalStateManager,
    amax_and_scale_update,
    _amax_and_scale_update,
    get_default_fp8_recipe,
)
......@@ -162,3 +163,98 @@ class TestFP8Recipe:
            fp8_meta[backward_key].scale_inv,
            ref_scale_inv_backward,
        )

    @pytest.mark.parametrize("amax_case", ["zero", "tiny", "normal", "inf", "nan"])
    @pytest.mark.parametrize("fused_update", [True, False], ids=["fused", "non-fused"])
    @pytest.mark.parametrize(
        "fp8_dtype", [tex.DType.kFloat8E4M3, tex.DType.kFloat8E5M2], ids=["E4M3", "E5M2"]
    )
    def test_scale_update_numeric_scenarios(self, amax_case, fused_update, fp8_dtype):
        if fp8_dtype == tex.DType.kFloat8E4M3:
            fp8_format = transformer_engine.common.recipe.Format.E4M3
            fp8_max = fp8_format.value.max_fwd
        elif fp8_dtype == tex.DType.kFloat8E5M2:
            fp8_format = transformer_engine.common.recipe.Format.HYBRID
            fp8_max = fp8_format.value.max_bwd
        else:
            raise ValueError(f"{fp8_dtype=} is not supported")

        scaling_factor_compute_algo = None
        if fused_update:
            scaling_factor_compute_algo = (
                lambda amax, scale, fp8_max, recipe:
                    te.fp8._default_sf_compute(amax, scale, fp8_max, recipe.margin)
            )
        recipe = transformer_engine.common.recipe.DelayedScaling(
            fp8_format=fp8_format, scaling_factor_compute_algo=scaling_factor_compute_algo
        )

        # Set up the fp8_meta dictionary
        def setup_fp8_meta():
            with te.fp8_autocast(enabled=True, fp8_recipe=recipe):
                module = te.Linear(16, 16)
                y = module(torch.zeros([16, 16], device="cuda"))
            y.backward(torch.zeros_like(y))
            return module.fp8_meta

        fp8_meta = setup_fp8_meta()
        forward_key = FP8GlobalStateManager.get_meta_tensor_key(forward=True)

        # Replace fp8_meta[forward_key] with a fresh FP8TensorMeta for test purposes
        fp8_meta[forward_key] = tex.FP8TensorMeta()
        fp8_meta[forward_key].scale = torch.ones(1, dtype=torch.float32, device="cuda")
        fp8_meta[forward_key].scale_inv = torch.ones(1, dtype=torch.float32, device="cuda")

        # Test the different amax scenarios
        if amax_case == "zero":
            fp8_meta[forward_key].amax_history = torch.tensor([[0]], dtype=torch.float32, device="cuda")
            expected_scale = torch.tensor([1.0], dtype=torch.float32, device="cuda")
        elif amax_case == "tiny":
            # Calculate the smallest amax whose scale still fits in FP32 (i.e. equals FP32 max)
            fp32_max = torch.tensor(torch.finfo(torch.float32).max)
            tiny_amax = fp8_max / fp32_max
            # Make the amax smaller than that minimum so the naive scale would be infinite
            amax_value = tiny_amax / 2
            fp8_meta[forward_key].amax_history = torch.tensor(
                [[amax_value]], dtype=torch.float32, device="cuda"
            )
            # The expected scale is clamped to FP32 max
            expected_scale = fp32_max.view(1).cuda()
        elif amax_case == "normal":
            # Add a small epsilon to avoid a zero amax
            amax_value = torch.rand(1, dtype=torch.float32, device="cuda") + 1e-5
            fp8_meta[forward_key].amax_history = amax_value.view(1, 1)
            expected_scale = fp8_max / amax_value
        elif amax_case == "inf":
            fp8_meta[forward_key].amax_history = torch.tensor(
                [[torch.inf]], dtype=torch.float32, device="cuda"
            )
            expected_scale = torch.tensor([1.0], dtype=torch.float32, device="cuda")
        elif amax_case == "nan":
            fp8_meta[forward_key].amax_history = torch.tensor(
                [[torch.nan]], dtype=torch.float32, device="cuda"
            )
            expected_scale = torch.tensor([1.0], dtype=torch.float32, device="cuda")

        if fused_update:
            tex.fused_amax_and_scale_update_after_reduction(
                fp8_meta[forward_key].amax_history.clone().view(-1),
                [fp8_meta[forward_key].amax_history],
                [fp8_meta[forward_key].scale],
                [fp8_meta[forward_key].scale_inv],
                recipe.amax_compute_algo,
                fp8_dtype,
                recipe.margin,
            )
        else:
            _amax_and_scale_update(
                fp8_meta[forward_key].amax_history,
                fp8_meta[forward_key].scale,
                fp8_meta[forward_key].scale_inv,
                fp8_max,
                recipe,
            )

        torch.testing.assert_close(fp8_meta[forward_key].scale, expected_scale)
        torch.testing.assert_close(fp8_meta[forward_key].scale_inv, torch.reciprocal(expected_scale))
......@@ -8,6 +8,7 @@
#include <cmath>
#include <string>
#include <limits>
#include "../common.h"
#include "../util/logging.h"
......@@ -151,6 +152,13 @@ kernel(const float* amax_history_ptr,
    } else {
      scale = scale_ptr[bid];
    }
    // When the amax is so tiny that the scale becomes infinite in FP32,
    // clamp the scale to the FP32 maximum. The tensor's amax then won't map
    // exactly to the FP8 max representable value, but to something below it,
    // which is the best we can do.
    if (isinf(scale)) {
      scale = std::numeric_limits<float>::max();
    }
    updated_scale_ptr[bid] = scale;
    // Update scale inverse
......@@ -239,12 +247,30 @@ kernel_bulk(
    // Update scale and scale inverse
    if (tid == 0) {
      // Computing the scaling factor requires consideration of the following scenarios:
      // 1. amax == 0:
      //    No action is possible; keep the previous scale (or 1).
      // 2. 0 < amax < tiny_amax:
      //    The amax is so tiny that the scale becomes infinite in FP32.
      //    Set scale = FP32_max.
      // 3. tiny_amax <= amax < FP32_max:
      //    Set scale = FP8_max (or scaled_max) / amax.
      // 4. amax == inf or amax == nan:
      //    No action is possible; keep the previous scale (or 1).
      float scale;
      if (isfinite(amax) && amax > 0) {
        scale = scaled_max / amax;
      } else {
        scale = p.param[bid].scale[count];
      }
      // When the amax is so tiny that the scale becomes infinite in FP32,
      // clamp the scale to the FP32 maximum. The tensor's amax then won't map
      // exactly to the FP8 max representable value, but to something below it,
      // which is the best we can do.
      if (isinf(scale)) {
        scale = std::numeric_limits<float>::max();
      }
      p.param[bid].scale[count] = scale;
      p.param[bid].scale_inv[count] = 1 / scale;
    }
......
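
For reference, the boundary between scenarios 2 and 3 above can be computed directly. A small sketch, assuming the standard FP8 maxima (448 for E4M3, 57344 for E5M2) and margin = 0:

import torch

fp32_max = torch.finfo(torch.float32).max
for name, fp8_max in [("E4M3", 448.0), ("E5M2", 57344.0)]:
    # Smallest amax for which fp8_max / amax still fits in FP32 (scenario 3);
    # anything smaller falls into scenario 2 and the scale gets clamped.
    tiny_amax = fp8_max / fp32_max
    print(f"{name}: scale overflows once amax < {tiny_amax:.3e}")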
......@@ -598,11 +598,24 @@ def _default_sf_compute(
    scale: torch.Tensor,
    fp8_max: float,
    margin: int,
    _fp32_max: float = torch.finfo(torch.float32).max,  # torch.finfo is not available inside jitted code
) -> torch.Tensor:
    """Default function to convert amax to a scaling factor.

    Computing the scaling factor requires consideration of the following scenarios:
    1. amax == 0:
       No action is possible; keep the previous scale (or 1).
    2. 0 < amax < tiny_amax:
       The amax is so tiny that the scale becomes infinite in FP32.
       Set scale = FP32_max.
    3. tiny_amax <= amax < FP32_max:
       Set scale = FP8_max (or scaled_max) / amax.
    4. amax == inf or amax == nan:
       No action is possible; keep the previous scale (or 1).
    """
    sf = (fp8_max / amax) / (2 ** margin)
    sf = torch.where(amax > 0.0, sf, scale)
    sf = torch.where(torch.isfinite(amax), sf, scale)
    sf = torch.where(torch.isinf(sf), torch.full_like(sf, _fp32_max), sf)
    scale.copy_(sf)
    return scale
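
A quick usage sketch of the updated helper (it is a private function, called here only for illustration; E4M3's fp8_max of 448 and margin = 0 are assumed):

import torch
from transformer_engine.pytorch.fp8 import _default_sf_compute

fp8_max = 448.0  # FP8 E4M3 max representable value
prev_scale = torch.ones(4, dtype=torch.float32)
amax = torch.tensor([0.0, 1e-39, 2.0, float("inf")], dtype=torch.float32)

scale = _default_sf_compute(amax, prev_scale, fp8_max, 0)
# amax == 0   -> previous scale kept (1.0)
# tiny amax   -> clamped to FP32 max instead of inf
# normal amax -> fp8_max / amax = 224.0
# inf amax    -> previous scale kept (1.0)
print(scale)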
......