Commit e79c3e83 authored by wenjh's avatar wenjh
Browse files

Fix some bugs of nmz fp8


Signed-off-by: wenjh <wenjh@sugon.com>
parent 1d95abb9
......@@ -16,7 +16,7 @@ import transformer_engine
import transformer_engine_torch as tex
import nvdlfw_inspect.api as debug_api
from transformer_engine.debug import set_weight_tensor_tp_group_reduce
from transformer_engine.pytorch.fp8 import FP8GlobalStateManager, int8_simulation_fp8, int8_simulation_fp8_tensorwise
from transformer_engine.pytorch.fp8 import FP8GlobalStateManager
from torch.utils.cpp_extension import IS_HIP_EXTENSION
from test_numerics import (
......@@ -587,7 +587,7 @@ def test_fake_quant_fp8(
"dgrad_fp8": not (dgrad_weight or dgrad_grad),
"wgrad_fp8": not (wgrad_grad or wgrad_input),
}
if IS_HIP_EXTENSION and int8_simulation_fp8:
if IS_HIP_EXTENSION:
if fp8_kwargs["fprop_fp8"] or fp8_kwargs["dgrad_fp8"] or fp8_kwargs["wgrad_fp8"]:
return # Output type 32 (FP32) does not support int8 simulation.
if WORLD_RANK == 0:
......
......@@ -51,6 +51,18 @@ def _run_test(quantization):
all_boolean = [True, False]
@pytest.mark.parametrize("quantization", [None, "fp8", "mxfp8", "fp8_cs", "fp8_block_scaling"])
def test_distributed(quantization):
    """Run the distributed test for one quantization recipe.

    Skips the test (with the module-level reason string) when the requested
    recipe is not available on the current platform; ``None`` always runs.
    """
    # Map each recipe to its (availability flag, skip reason). Both the
    # delayed-scaling "fp8" and current-scaling "fp8_cs" recipes share the
    # same FP8 availability check.
    requirements = {
        "fp8": (fp8_available, reason_for_no_fp8),
        "fp8_cs": (fp8_available, reason_for_no_fp8),
        "mxfp8": (mxfp8_available, reason_for_no_mxfp8),
        "fp8_block_scaling": (fp8_block_scaling_available, reason_for_no_fp8_block_scaling),
    }
    if quantization in requirements:
        supported, skip_reason = requirements[quantization]
        if not supported:
            pytest.skip(skip_reason)
    _run_test(quantization)
@pytest.mark.parametrize("quantization", [None, "fp8", "mxfp8", "fp8_cs", "fp8_block_scaling"])
def test_int8_distributed(quantization):
if quantization == "fp8" and not fp8_available:
......@@ -73,15 +85,3 @@ def test_int8_distributed(quantization):
else:
del os.environ["NVTE_INT8_SIM_FP8"]
importlib.reload(te.pytorch.fp8)
@pytest.mark.parametrize("quantization", [None, "fp8", "mxfp8", "fp8_cs", "fp8_block_scaling"])
def test_distributed(quantization):
    """Run the distributed test for the given quantization recipe,
    skipping recipes this platform cannot execute."""
    # "fp8" (delayed scaling) and "fp8_cs" (current scaling) both require
    # FP8 support, so they share one guard.
    if quantization in ("fp8", "fp8_cs") and not fp8_available:
        pytest.skip(reason_for_no_fp8)
    if quantization == "mxfp8" and not mxfp8_available:
        pytest.skip(reason_for_no_mxfp8)
    if quantization == "fp8_block_scaling" and not fp8_block_scaling_available:
        pytest.skip(reason_for_no_fp8_block_scaling)
    _run_test(quantization)
......@@ -584,13 +584,13 @@ class TestFP8BlockScalingRecipeLinear(TestFP8RecipeLinearBase):
dtype,
use_bias=True,
):
if not fp8_blockwise_scaling_supported():
pytest.skip("CUDA version does not support blockwise FP8.")
if IS_HIP_EXTENSION:
import importlib
ori_int8_sim_fp8 = os.environ.get("NVTE_INT8_SIM_FP8", None)
os.environ["NVTE_INT8_SIM_FP8"] = "1"
importlib.reload(te.pytorch.fp8)
if not fp8_blockwise_scaling_supported():
pytest.skip("CUDA version does not support blockwise FP8.")
fp8_zero_tolerance_tensor_dumps_recipe2 = None
# check tensor dumps dir, if the dir exists, then read files to get y, dgrad, wgrad, bgrad
# if we cannot get all four tensors, then still set the tensor dump to None
......@@ -714,13 +714,13 @@ class TestFP8BlockScalingRecipeLayerNormLinear(TestFP8RecipeLayerNormLinearBase)
dtype,
use_bias=True,
):
if not fp8_blockwise_scaling_supported():
pytest.skip("CUDA version does not support blockwise FP8.")
if IS_HIP_EXTENSION:
import importlib
ori_int8_sim_fp8 = os.environ.get("NVTE_INT8_SIM_FP8", None)
os.environ["NVTE_INT8_SIM_FP8"] = "1"
importlib.reload(te.pytorch.fp8)
if not fp8_blockwise_scaling_supported():
pytest.skip("CUDA version does not support blockwise FP8.")
fp8_zero_tolerance_tensor_dumps_recipe2 = None
# check tensor dumps dir, if the dir exists, then read files to get y, dgrad, wgrad, bgrad
# if we cannot get all four tensors, then still set the tensor dump to None
......
......@@ -25,7 +25,7 @@ from transformer_engine.common.recipe import (
)
from .constants import dist_group_type
from .utils import get_device_compute_capability, is_gfx928, is_gfx936, is_gfx938
from .utils import (get_device_compute_capability, is_gfx928, is_gfx936, is_gfx938)
from .jit import jit_fuser
int8_simulation_fp8 = bool(int(os.getenv("NVTE_INT8_SIM_FP8", "0")))
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment