Commit e79c3e83 authored by wenjh

Fix some bugs in nmz fp8


Signed-off-by: wenjh <wenjh@sugon.com>
parent 1d95abb9
@@ -16,7 +16,7 @@ import transformer_engine
 import transformer_engine_torch as tex
 import nvdlfw_inspect.api as debug_api
 from transformer_engine.debug import set_weight_tensor_tp_group_reduce
-from transformer_engine.pytorch.fp8 import FP8GlobalStateManager, int8_simulation_fp8, int8_simulation_fp8_tensorwise
+from transformer_engine.pytorch.fp8 import FP8GlobalStateManager
 from torch.utils.cpp_extension import IS_HIP_EXTENSION
 from test_numerics import (
@@ -587,7 +587,7 @@ def test_fake_quant_fp8(
         "dgrad_fp8": not (dgrad_weight or dgrad_grad),
         "wgrad_fp8": not (wgrad_grad or wgrad_input),
     }
-    if IS_HIP_EXTENSION and int8_simulation_fp8:
+    if IS_HIP_EXTENSION:
        if fp8_kwargs["fprop_fp8"] or fp8_kwargs["dgrad_fp8"] or fp8_kwargs["wgrad_fp8"]:
            return # Output type 32 (FP32) does not support int8 simulation.
    if WORLD_RANK == 0:
......
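For orientation, a minimal standalone sketch of the ROCm guard as it reads after this change: any case where the fprop, dgrad, or wgrad GEMM would run in FP8 is skipped on HIP builds, because an FP32 output is not supported on the int8-simulation path. Only the flag names come from the hunk above; the rest of test_fake_quant_fp8 is omitted and IS_HIP_EXTENSION is hard-coded here purely for illustration.

IS_HIP_EXTENSION = True  # in the real test this comes from torch.utils.cpp_extension

def skip_on_rocm(fprop_fp8, dgrad_fp8, wgrad_fp8):
    # Mirrors the guard shown in the hunk: skip whenever any GEMM stays in FP8.
    fp8_kwargs = {"fprop_fp8": fprop_fp8, "dgrad_fp8": dgrad_fp8, "wgrad_fp8": wgrad_fp8}
    if IS_HIP_EXTENSION:
        if fp8_kwargs["fprop_fp8"] or fp8_kwargs["dgrad_fp8"] or fp8_kwargs["wgrad_fp8"]:
            return True  # FP32 output is not supported by the int8-simulated FP8 path
    return False

print(skip_on_rocm(True, False, False))   # True: the fprop GEMM would run in FP8
print(skip_on_rocm(False, False, False))  # False: nothing runs in FP8, so the case can proceed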
@@ -51,6 +51,18 @@ def _run_test(quantization):
 all_boolean = [True, False]
 
+@pytest.mark.parametrize("quantization", [None, "fp8", "mxfp8", "fp8_cs", "fp8_block_scaling"])
+def test_distributed(quantization):
+    if quantization == "fp8" and not fp8_available:
+        pytest.skip(reason_for_no_fp8)
+    if quantization == "fp8_cs" and not fp8_available:
+        pytest.skip(reason_for_no_fp8)
+    if quantization == "mxfp8" and not mxfp8_available:
+        pytest.skip(reason_for_no_mxfp8)
+    if quantization == "fp8_block_scaling" and not fp8_block_scaling_available:
+        pytest.skip(reason_for_no_fp8_block_scaling)
+    _run_test(quantization)
+
+
 @pytest.mark.parametrize("quantization", [None, "fp8", "mxfp8", "fp8_cs", "fp8_block_scaling"])
 def test_int8_distributed(quantization):
     if quantization == "fp8" and not fp8_available:
@@ -73,15 +85,3 @@ def test_int8_distributed(quantization):
     else:
         del os.environ["NVTE_INT8_SIM_FP8"]
     importlib.reload(te.pytorch.fp8)
-
-
-@pytest.mark.parametrize("quantization", [None, "fp8", "mxfp8", "fp8_cs", "fp8_block_scaling"])
-def test_distributed(quantization):
-    if quantization == "fp8" and not fp8_available:
-        pytest.skip(reason_for_no_fp8)
-    if quantization == "fp8_cs" and not fp8_available:
-        pytest.skip(reason_for_no_fp8)
-    if quantization == "mxfp8" and not mxfp8_available:
-        pytest.skip(reason_for_no_mxfp8)
-    if quantization == "fp8_block_scaling" and not fp8_block_scaling_available:
-        pytest.skip(reason_for_no_fp8_block_scaling)
-    _run_test(quantization)
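The tail of test_int8_distributed shown above restores NVTE_INT8_SIM_FP8 and reloads te.pytorch.fp8 once the test finishes. A hedged sketch of that save/set/reload/restore dance packaged as a context manager; the wrapper itself is illustrative and not part of this commit, and it assumes transformer_engine is importable as te with te.pytorch.fp8 loaded:

import importlib
import os

import transformer_engine as te
import transformer_engine.pytorch.fp8  # ensures te.pytorch.fp8 is a loaded module


class int8_sim_fp8_enabled:
    """Enable NVTE_INT8_SIM_FP8 for the duration of a with-block, then restore it."""

    def __enter__(self):
        self._orig = os.environ.get("NVTE_INT8_SIM_FP8", None)
        os.environ["NVTE_INT8_SIM_FP8"] = "1"
        importlib.reload(te.pytorch.fp8)  # re-evaluate the module-level flag
        return self

    def __exit__(self, exc_type, exc, tb):
        if self._orig is not None:
            os.environ["NVTE_INT8_SIM_FP8"] = self._orig
        else:
            del os.environ["NVTE_INT8_SIM_FP8"]
        importlib.reload(te.pytorch.fp8)  # pick the original setting back up
        return False

Used as "with int8_sim_fp8_enabled(): _run_test(quantization)", which is roughly what test_int8_distributed does inline.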
@@ -584,13 +584,13 @@ class TestFP8BlockScalingRecipeLinear(TestFP8RecipeLinearBase):
         dtype,
         use_bias=True,
     ):
-        if not fp8_blockwise_scaling_supported():
-            pytest.skip("CUDA version does not support blockwise FP8.")
         if IS_HIP_EXTENSION:
             import importlib
             ori_int8_sim_fp8 = os.environ.get("NVTE_INT8_SIM_FP8", None)
             os.environ["NVTE_INT8_SIM_FP8"] = "1"
             importlib.reload(te.pytorch.fp8)
+        if not fp8_blockwise_scaling_supported():
+            pytest.skip("CUDA version does not support blockwise FP8.")
         fp8_zero_tolerance_tensor_dumps_recipe2 = None
         # check tensor dumps dir, if the dir exists, then read files to get y, dgrad, wgrad, bgrad
         # if we cannot get all four tensors, then still set the tensor dump to None
@@ -714,13 +714,13 @@ class TestFP8BlockScalingRecipeLayerNormLinear(TestFP8RecipeLayerNormLinearBase)
         dtype,
         use_bias=True,
     ):
-        if not fp8_blockwise_scaling_supported():
-            pytest.skip("CUDA version does not support blockwise FP8.")
         if IS_HIP_EXTENSION:
             import importlib
             ori_int8_sim_fp8 = os.environ.get("NVTE_INT8_SIM_FP8", None)
             os.environ["NVTE_INT8_SIM_FP8"] = "1"
             importlib.reload(te.pytorch.fp8)
+        if not fp8_blockwise_scaling_supported():
+            pytest.skip("CUDA version does not support blockwise FP8.")
         fp8_zero_tolerance_tensor_dumps_recipe2 = None
         # check tensor dumps dir, if the dir exists, then read files to get y, dgrad, wgrad, bgrad
         # if we cannot get all four tensors, then still set the tensor dump to None
......
@@ -25,7 +25,7 @@ from transformer_engine.common.recipe import (
 )
 from .constants import dist_group_type
-from .utils import get_device_compute_capability, is_gfx928, is_gfx936, is_gfx938
+from .utils import (get_device_compute_capability, is_gfx928, is_gfx936, is_gfx938)
 from .jit import jit_fuser
 
 int8_simulation_fp8 = bool(int(os.getenv("NVTE_INT8_SIM_FP8", "0")))
......
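Note that int8_simulation_fp8 above is evaluated once, when transformer_engine.pytorch.fp8 is first imported. That is why the tests touched by this commit call importlib.reload(te.pytorch.fp8) after changing NVTE_INT8_SIM_FP8, and why a value copied out with "from ... import int8_simulation_fp8" (the import removed in the first hunk) would not track such a reload. A tiny standalone illustration; the helper _read_flag is invented here and simply repeats the expression from the line above:

import os

def _read_flag() -> bool:
    # Same expression as the module-level assignment shown above.
    return bool(int(os.getenv("NVTE_INT8_SIM_FP8", "0")))

os.environ.pop("NVTE_INT8_SIM_FP8", None)
flag_at_import = _read_flag()           # what the module captures when first imported
os.environ["NVTE_INT8_SIM_FP8"] = "1"   # flipping the env var afterwards...
print(flag_at_import)                   # False: the captured value does not change
print(_read_flag())                     # True: re-evaluating (what importlib.reload does) sees the new setting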