Commit 1d95abb9 authored by wenjh

Enable fp8 on nmz


Signed-off-by: wenjh <wenjh@sugon.com>
parent 5c63251d
@@ -36,10 +36,12 @@ python3 -m pytest -v -s --junitxml=$XML_LOG_DIR/pytest_test_fused_rope.xml $TE_P
 python3 -m pytest -v -s --junitxml=$XML_LOG_DIR/pytest_test_float8tensor.xml $TE_PATH/tests/pytorch/test_float8tensor.py || test_fail "test_float8tensor.py"
 python3 -m pytest -v -s --junitxml=$XML_LOG_DIR/pytest_test_float8blockwisetensor.xml $TE_PATH/tests/pytorch/test_float8blockwisetensor.py || test_fail "test_float8blockwisetensor.py"
 # channelwise int8 test
-NVTE_INT8_SIM_FP8=1 python3 -m pytest -v -s --tb=auto --junitxml=$XML_LOG_DIR/pytest_test_float8_current_scaling_exact.xml $TE_PATH/tests/pytorch/test_float8_current_scaling_exact.py
-NVTE_INT8_SIM_FP8=1 NVTE_INT8_SIM_FP8_TENSORWISE=1 python3 -m pytest -v -s --tb=auto --junitxml=$XML_LOG_DIR/pytest_test_float8_current_scaling_exact.xml $TE_PATH/tests/pytorch/test_float8_current_scaling_exact.py
+NVTE_INT8_SIM_FP8=1 python3 -m pytest -v -s --tb=auto --junitxml=$XML_LOG_DIR/pytest_test_float8_current_scaling_exact_int8.xml $TE_PATH/tests/pytorch/test_float8_current_scaling_exact.py || test_fail "test_float8_current_scaling_exact.py_int8"
+python3 -m pytest -v -s --tb=auto --junitxml=$XML_LOG_DIR/pytest_test_float8_current_scaling_exact.xml $TE_PATH/tests/pytorch/test_float8_current_scaling_exact.py || test_fail "test_float8_current_scaling_exact.py"
+NVTE_INT8_SIM_FP8=1 NVTE_INT8_SIM_FP8_TENSORWISE=1 python3 -m pytest -v -s --tb=auto --junitxml=$XML_LOG_DIR/pytest_test_float8_current_scaling_exact_int8_tensorwise.xml $TE_PATH/tests/pytorch/test_float8_current_scaling_exact.py || test_fail "test_float8_current_scaling_exact.py_int8_tensorwise"
 python3 -m pytest -v -s --junitxml=$XML_LOG_DIR/pytest_test_float8_blockwise_scaling_exact.xml $TE_PATH/tests/pytorch/test_float8_blockwise_scaling_exact.py || test_fail "test_float8_blockwise_scaling_exact.py"
-NVTE_INT8_SIM_FP8=1 python3 -m pytest -v -s --junitxml=$XML_LOG_DIR/pytest_test_float8_blockwise_gemm_exact.xml $TE_PATH/tests/pytorch/test_float8_blockwise_gemm_exact.py || test_fail "test_float8_blockwise_gemm_exact.py"
+NVTE_INT8_SIM_FP8=1 python3 -m pytest -v -s --junitxml=$XML_LOG_DIR/pytest_test_float8_blockwise_gemm_exact_int8.xml $TE_PATH/tests/pytorch/test_float8_blockwise_gemm_exact.py || test_fail "test_float8_blockwise_gemm_exact.py_int8"
+python3 -m pytest -v -s --junitxml=$XML_LOG_DIR/pytest_test_float8_blockwise_gemm_exact.xml $TE_PATH/tests/pytorch/test_float8_blockwise_gemm_exact.py || test_fail "test_float8_blockwise_gemm_exact.py"
 python3 $TE_PATH/tests/pytorch/test_int8_blockwise_gemm_exact.py
 python3 -m pytest -v -s --junitxml=$XML_LOG_DIR/pytest_test_gqa.xml $TE_PATH/tests/pytorch/test_gqa.py || test_fail "test_gqa.py"
 python3 -m pytest -v -s --junitxml=$XML_LOG_DIR/pytest_test_fused_optimizer.xml $TE_PATH/tests/pytorch/test_fused_optimizer.py || test_fail "test_fused_optimizer.py"
...
@@ -36,8 +36,8 @@ python3 -m pytest -v -s --junitxml=$XML_LOG_DIR/pytest_test_torch_fsdp2.xml $TE_
 python3 -m pytest -v -s --log-cli-level=INFO --junitxml=$XML_LOG_DIR/pytest_test_comm_gemm_overlap.xml $TE_PATH/tests/pytorch/distributed/test_comm_gemm_overlap.py || test_fail "test_comm_gemm_overlap.py"
 python3 -m pytest -v -s --junitxml=$XML_LOG_DIR/pytest_test_fusible_ops_with_userbuffers.xml $TE_PATH/tests/pytorch/distributed/test_fusible_ops_with_userbuffers.py || test_fail "test_fusible_ops_with_userbuffers.py"
 python3 -m pytest -v -s --junitxml=$XML_LOG_DIR/pytest_test_attention_with_cp.xml $TE_PATH/tests/pytorch/attention/test_attention_with_cp.py || test_fail "test_attention_with_cp.py"
-NVTE_INT8_SIM_FP8=1 python3 -m pytest -v -s --junitxml=$XML_LOG_DIR/pytest_test_cast_master_weights_to_fp8.xml $TE_PATH/tests/pytorch/distributed/test_cast_master_weights_to_fp8.py || test_fail "test_cast_master_weights_to_fp8.py"
+NVTE_INT8_SIM_FP8=1 python3 -m pytest -v -s --junitxml=$XML_LOG_DIR/pytest_test_cast_master_weights_to_fp8_int8.xml $TE_PATH/tests/pytorch/distributed/test_cast_master_weights_to_fp8.py || test_fail "test_cast_master_weights_to_fp8.py_int8"
+python3 -m pytest -v -s --junitxml=$XML_LOG_DIR/pytest_test_cast_master_weights_to_fp8.xml $TE_PATH/tests/pytorch/distributed/test_cast_master_weights_to_fp8.py || test_fail "test_cast_master_weights_to_fp8.py"
 # debug tests
@@ -47,7 +47,8 @@ NVTE_INT8_SIM_FP8=1 python3 -m pytest -v -s --junitxml=$XML_LOG_DIR/pytest_test_
 : ${NVTE_TEST_NVINSPECT_DUMMY_CONFIG_FILE:=$TE_PATH/tests/pytorch/debug/test_configs/dummy_feature.yaml}
 : ${NVTE_TEST_NVINSPECT_FEATURE_DIRS:=$TE_PATH/transformer_engine/debug/features}
-NVTE_INT8_SIM_FP8=1 pytest -v -s $TE_PATH/tests/pytorch/debug/test_distributed.py --feature_dirs=$NVTE_TEST_NVINSPECT_FEATURE_DIRS || test_fail "debug test_distributed.py"
+NVTE_INT8_SIM_FP8=1 pytest -v -s $TE_PATH/tests/pytorch/debug/test_distributed.py --feature_dirs=$NVTE_TEST_NVINSPECT_FEATURE_DIRS || test_fail "debug test_distributed.py_int8"
+pytest -v -s $TE_PATH/tests/pytorch/debug/test_distributed.py --feature_dirs=$NVTE_TEST_NVINSPECT_FEATURE_DIRS || test_fail "debug test_distributed.py"
 # standard numerics tests with initialized debug
 NVTE_TEST_NVINSPECT_ENABLED=True NVTE_TEST_NVINSPECT_CONFIG_FILE=$NVTE_TEST_NVINSPECT_DUMMY_CONFIG_FILE NVTE_TEST_NVINSPECT_FEATURE_DIRS=$NVTE_TEST_NVINSPECT_FEATURE_DIRS pytest -v -s $TE_PATH/tests/pytorch/distributed/test_numerics.py || test_fail "debug test_numerics.py"
...
@@ -16,7 +16,7 @@ import transformer_engine
 import transformer_engine_torch as tex
 import nvdlfw_inspect.api as debug_api
 from transformer_engine.debug import set_weight_tensor_tp_group_reduce
-from transformer_engine.pytorch.fp8 import FP8GlobalStateManager
+from transformer_engine.pytorch.fp8 import FP8GlobalStateManager, int8_simulation_fp8, int8_simulation_fp8_tensorwise
 from torch.utils.cpp_extension import IS_HIP_EXTENSION
 from test_numerics import (
@@ -587,7 +587,7 @@ def test_fake_quant_fp8(
         "dgrad_fp8": not (dgrad_weight or dgrad_grad),
         "wgrad_fp8": not (wgrad_grad or wgrad_input),
     }
-    if IS_HIP_EXTENSION:
+    if IS_HIP_EXTENSION and int8_simulation_fp8:
         if fp8_kwargs["fprop_fp8"] or fp8_kwargs["dgrad_fp8"] or fp8_kwargs["wgrad_fp8"]:
             return # Output type 32 (FP32) does not support int8 simulation.
     if WORLD_RANK == 0:
...
@@ -51,9 +51,8 @@ def _run_test(quantization):
 all_boolean = [True, False]
 
 @pytest.mark.parametrize("quantization", [None, "fp8", "mxfp8", "fp8_cs", "fp8_block_scaling"])
-def test_distributed(quantization):
+def test_int8_distributed(quantization):
     if quantization == "fp8" and not fp8_available:
         pytest.skip(reason_for_no_fp8)
     if quantization == "fp8_cs" and not fp8_available:
@@ -74,3 +73,15 @@ def test_distributed(quantization):
     else:
         del os.environ["NVTE_INT8_SIM_FP8"]
     importlib.reload(te.pytorch.fp8)
+
+@pytest.mark.parametrize("quantization", [None, "fp8", "mxfp8", "fp8_cs", "fp8_block_scaling"])
+def test_distributed(quantization):
+    if quantization == "fp8" and not fp8_available:
+        pytest.skip(reason_for_no_fp8)
+    if quantization == "fp8_cs" and not fp8_available:
+        pytest.skip(reason_for_no_fp8)
+    if quantization == "mxfp8" and not mxfp8_available:
+        pytest.skip(reason_for_no_mxfp8)
+    if quantization == "fp8_block_scaling" and not fp8_block_scaling_available:
+        pytest.skip(reason_for_no_fp8_block_scaling)
+    _run_test(quantization)
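The toggle-and-reload pattern used by these tests works because `NVTE_INT8_SIM_FP8` is read only once, at import time, into the module-level `int8_simulation_fp8` flag of `te.pytorch.fp8`. Below is a minimal sketch of that pattern, assuming only the environment variable and module shown in this diff; the context-manager wrapper itself is illustrative and is not part of the change.

```python
import importlib
import os
from contextlib import contextmanager

import transformer_engine as te


@contextmanager
def int8_sim_fp8(enabled: bool = True):
    """Temporarily set NVTE_INT8_SIM_FP8 and reload te.pytorch.fp8.

    The flag is read at import time via os.getenv, so the module must be
    reloaded for a new value to take effect (and again to restore it).
    """
    original = os.environ.get("NVTE_INT8_SIM_FP8", None)
    os.environ["NVTE_INT8_SIM_FP8"] = "1" if enabled else "0"
    importlib.reload(te.pytorch.fp8)
    try:
        yield
    finally:
        if original is not None:
            os.environ["NVTE_INT8_SIM_FP8"] = original
        else:
            del os.environ["NVTE_INT8_SIM_FP8"]
        importlib.reload(te.pytorch.fp8)
```

The setup/teardown inside the renamed `test_int8_distributed` above follows the same sequence inline around `_run_test`.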
@@ -46,7 +46,7 @@ def cublas_gemm_fp8_blockwise_case(
     atol: float = 0.0,
     rtol: float = 0.0
 ):
-    if IS_HIP_EXTENSION and int8_simulation_fp8:
+    if IS_HIP_EXTENSION:
         if use_bias or use_gelu:
             pytest.skip("Bias and GELU not supported in int8 simulation mode on ROCm.")
     if not ((not x_columnwise and not w_columnwise and is_x_1d_scaled and not is_w_1d_scaled) or (not x_columnwise and w_columnwise and is_x_1d_scaled and not is_w_1d_scaled) or (x_columnwise and w_columnwise and is_x_1d_scaled and is_w_1d_scaled)):
@@ -167,7 +167,7 @@ def cublas_gemm_fp8_blockwise_case(
     bias_dtype = TE_DType[torch.bfloat16 if bias is None else bias.dtype]
     if IS_HIP_EXTENSION and int8_simulation_fp8:
         if(not x_columnwise and not w_columnwise and is_x_1d_scaled and not is_w_1d_scaled):
             y = w8a8_int8_general_gemm(qw, qx, out_dtype, False, "TN", None)
         elif (not x_columnwise and w_columnwise and is_x_1d_scaled and not is_w_1d_scaled):
@@ -248,7 +248,7 @@ def cublas_gemm_test_constraint_enforced(
     expected_err_cls=RuntimeError
 ):
     if IS_HIP_EXTENSION:
-        pytest.skip("ROCm does not support cuBLAS GEMM. No need to test constraint enforcement.")
+        pytest.skip("ROCm does not support cuBLAS blockwise FP8 gemm. No need to test constraint enforcement.")
     if not fp8_blockwise_gemm_supported():
         pytest.skip("CUDA version does not support blockwise FP8 gemm.")
     # Setup device and random seed
...
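For readability: the long admission check in `cublas_gemm_fp8_blockwise_case` above only lets three layout/scaling combinations through. A hypothetical helper restating that same condition, using the test's parameter names (this is a paraphrase for orientation, not code from the change):

```python
def blockwise_case_supported(
    x_columnwise: bool,
    w_columnwise: bool,
    is_x_1d_scaled: bool,
    is_w_1d_scaled: bool,
) -> bool:
    """Restate the test's skip condition: only three combinations run."""
    allowed = {
        # (x_columnwise, w_columnwise, is_x_1d_scaled, is_w_1d_scaled)
        (False, False, True, False),
        (False, True, True, False),
        (True, True, True, True),
    }
    return (x_columnwise, w_columnwise, is_x_1d_scaled, is_w_1d_scaled) in allowed
```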
@@ -495,6 +495,9 @@ def test_quantization_block_tiling_extrema_versus_reference(
         rtol=0.0,
     )
 
+def fp8_blockwise_scaling_supported() -> bool:
+    supported, _ = FP8GlobalStateManager.is_fp8_block_scaling_available()
+    return supported
 
 # FP8 per tesnor current scaling
 @pytest.mark.skipif(not recipe_available, reason=reason_for_no_recipe)
@@ -530,6 +533,59 @@ class TestFP8BlockScalingRecipeLinear(TestFP8RecipeLinearBase):
         dtype,
         use_bias=True,
     ):
+        if not fp8_blockwise_scaling_supported():
+            pytest.skip("CUDA version does not support blockwise FP8.")
+        fp8_zero_tolerance_tensor_dumps_recipe2 = None
+        # check tensor dumps dir, if the dir exists, then read files to get y, dgrad, wgrad, bgrad
+        # if we cannot get all four tensors, then still set the tensor dump to None
+        tensor_map = self._check_golden_tensor_dumps(
+            TENSOR_DUMP_DIR, recipe2, (batch_size, hidden_size, out_size), dtype, use_bias
+        )
+        if tensor_map is not None:
+            fp8_zero_tolerance_tensor_dumps_recipe2 = tensor_map
+        self.compare_recipe(
+            recipe1,
+            recipe2,
+            batch_size,
+            hidden_size,
+            out_size,
+            use_bias,
+            seed=torch.initial_seed(),
+            dtype=dtype,
+            y_error=0.5,
+            dgrad_error=1,
+            wgrad_error=1,
+            bgrad_error=0.5,
+            recipe1_golden_tensors=None,
+            recipe2_golden_tensors=fp8_zero_tolerance_tensor_dumps_recipe2,
+        )
+
+    @pytest.mark.parametrize(
+        "batch_size, hidden_size, out_size",
+        [
+            (16, 256, 128),
+        ],
+    )
+    @pytest.mark.parametrize("dtype", [torch.bfloat16], ids=["bf16"])
+    @pytest.mark.parametrize(
+        "recipe1, recipe2",
+        [
+            (GetRecipes.none, GetRecipes.fp8_blockwise),
+        ],
+    )
+    def test_int_sim_fp8_current_scaling_with_linear_module(
+        self,
+        recipe1,
+        recipe2,
+        batch_size,
+        hidden_size,
+        out_size,
+        dtype,
+        use_bias=True,
+    ):
+        if not fp8_blockwise_scaling_supported():
+            pytest.skip("CUDA version does not support blockwise FP8.")
         if IS_HIP_EXTENSION:
             import importlib
             ori_int8_sim_fp8 = os.environ.get("NVTE_INT8_SIM_FP8", None)
@@ -601,6 +657,65 @@ class TestFP8BlockScalingRecipeLayerNormLinear(TestFP8RecipeLayerNormLinearBase)
         dtype,
         use_bias=True,
     ):
+        if not fp8_blockwise_scaling_supported():
+            pytest.skip("CUDA version does not support blockwise FP8.")
+        fp8_zero_tolerance_tensor_dumps_recipe2 = None
+        # check tensor dumps dir, if the dir exists, then read files to get y, dgrad, wgrad, bgrad
+        # if we cannot get all four tensors, then still set the tensor dump to None
+        tensor_map = self._check_golden_tensor_dumps(
+            TENSOR_DUMP_DIR,
+            recipe2,
+            (batch_size, hidden_size, out_size),
+            dtype,
+            use_bias,
+            "LayerNorm",
+        )
+        if tensor_map is not None:
+            fp8_zero_tolerance_tensor_dumps_recipe2 = tensor_map
+        self.compare_recipe(
+            recipe1,
+            recipe2,
+            batch_size,
+            hidden_size,
+            out_size,
+            use_bias,
+            seed=torch.initial_seed(),
+            dtype=dtype,
+            y_error=0.5,
+            ln_out_error=0.5,
+            dgrad_error=1.6,
+            wgrad_error=1,
+            bgrad_error=0.5,
+            recipe1_golden_tensors=None,
+            recipe2_golden_tensors=fp8_zero_tolerance_tensor_dumps_recipe2,
+        )
+
+    @pytest.mark.parametrize(
+        "batch_size, hidden_size, out_size",
+        [
+            (16, 256, 128),
+        ],
+    )
+    @pytest.mark.parametrize("dtype", [torch.bfloat16], ids=["bf16"])
+    @pytest.mark.parametrize(
+        "recipe1, recipe2",
+        [
+            (GetRecipes.none, GetRecipes.fp8_blockwise),
+        ],
+    )
+    def test_int8_sim_fp8_current_scaling_with_layernorm_linear_module(
+        self,
+        recipe1,
+        recipe2,
+        batch_size,
+        hidden_size,
+        out_size,
+        dtype,
+        use_bias=True,
+    ):
+        if not fp8_blockwise_scaling_supported():
+            pytest.skip("CUDA version does not support blockwise FP8.")
         if IS_HIP_EXTENSION:
             import importlib
             ori_int8_sim_fp8 = os.environ.get("NVTE_INT8_SIM_FP8", None)
@@ -629,7 +744,7 @@ class TestFP8BlockScalingRecipeLayerNormLinear(TestFP8RecipeLayerNormLinearBase)
             use_bias,
             seed=torch.initial_seed(),
             dtype=dtype,
             y_error=0.5 if not IS_HIP_EXTENSION else 0.9,
             ln_out_error=0.5,
             dgrad_error=1.6 if not IS_HIP_EXTENSION else 1.0,
             wgrad_error=1,
@@ -642,4 +757,4 @@ class TestFP8BlockScalingRecipeLayerNormLinear(TestFP8RecipeLayerNormLinearBase)
                 os.environ["NVTE_INT8_SIM_FP8"] = ori_int8_sim_fp8
             else:
                 del os.environ["NVTE_INT8_SIM_FP8"]
             importlib.reload(te.pytorch.fp8)
\ No newline at end of file
@@ -13,6 +13,7 @@ from collections import deque
 from typing import Callable, List, Optional, Dict, Any, Tuple, Union
 
 import torch
+from torch.utils.cpp_extension import IS_HIP_EXTENSION
 import transformer_engine_torch as tex
 from transformer_engine.common.recipe import (
     Recipe,
@@ -24,53 +25,51 @@ from transformer_engine.common.recipe import (
 )
 from .constants import dist_group_type
-from .utils import get_device_compute_capability
+from .utils import get_device_compute_capability, is_gfx928, is_gfx936, is_gfx938
 from .jit import jit_fuser
-from torch.utils.cpp_extension import IS_HIP_EXTENSION
 
 int8_simulation_fp8 = bool(int(os.getenv("NVTE_INT8_SIM_FP8", "0")))
 int8_simulation_fp8_tensorwise = bool(int(os.getenv("NVTE_INT8_SIM_FP8_TENSORWISE", "0")))
 blockwise_fp8_block_len = int(os.getenv("NVTE_BLOCKWISE_FP8_BLOCK_LEN", "128"))
 
 __all__ = ["fp8_autocast", "fp8_model_init"]
 
-if IS_HIP_EXTENSION:
-    from transformer_engine.pytorch.utils import is_K100_AI, is_BW
 
 def check_fp8_support() -> Tuple[bool, str]:
     """Return if fp8 support is available"""
     if IS_HIP_EXTENSION:
-        if (is_K100_AI() or is_BW()) and int8_simulation_fp8:
-            return True, "DCU turn on fp8 simulation with int8"
-        else:
-            return False, "DCU not support fp8 for now"
-    else:
-        if get_device_compute_capability() >= (9, 0): # hopper and above
-            return True, ""
-        if get_device_compute_capability() < (8, 9): # pre-ada
-            return False, "Device compute capability 8.9 or higher required for FP8 execution."
-        if tex.get_cublasLt_version() < 120103:
-            return False, "CublasLt version 12.1.3.x or higher required for FP8 execution on Ada."
-        if float(torch.version.cuda) < 12.1:
-            return False, "Cuda version 12.1 or higher required for FP8 execution on Ada."
+        if is_gfx938():
+            return True, ""
+        if (is_gfx928() or is_gfx936()) and int8_simulation_fp8 and int8_simulation_fp8_tensorwise:
+            return True, ""
+        return False, "DCU not support fp8 for now"
+    if get_device_compute_capability() >= (9, 0): # hopper and above
+        return True, ""
+    if get_device_compute_capability() < (8, 9): # pre-ada
+        return False, "Device compute capability 8.9 or higher required for FP8 execution."
+    if tex.get_cublasLt_version() < 120103:
+        return False, "CublasLt version 12.1.3.x or higher required for FP8 execution on Ada."
+    if float(torch.version.cuda) < 12.1:
+        return False, "Cuda version 12.1 or higher required for FP8 execution on Ada."
     return True, ""
 
 def check_mxfp8_support() -> Tuple[bool, str]:
     """Return if fp8 support is available"""
+    if IS_HIP_EXTENSION:
+        return False, "DCU not support mxfp8 for now"
     if get_device_compute_capability() >= (12, 0):
         return False, "MXFP8 (for all gemm layouts) is not supported on 12.0+ architectures yet."
     if get_device_compute_capability() >= (10, 0): # blackwell and above
         return True, ""
     return False, "Device compute capability 10.0 or higher required for MXFP8 execution."
 
 def check_fp8_block_scaling_support() -> Tuple[bool, str]:
     """Return if fp8 block scaling support is available"""
     if IS_HIP_EXTENSION:
-        if is_K100_AI() or is_BW():
-            return True, ""
-        else:
-            return False, "DCU not support block_scaling fp8 for now"
+        if is_gfx938():
+            return True, ""
+        if (is_gfx928() or is_gfx936()) and int8_simulation_fp8:
+            return True, ""
+        return False, "DCU not support block_scaling fp8 for now"
     if (
         get_device_compute_capability() >= (9, 0)
         and get_device_compute_capability() < (10, 0)
@@ -79,7 +78,6 @@ def check_fp8_block_scaling_support() -> Tuple[bool, str]:
         return True, ""
     return False, "FP8 block scaled GEMM requires Hopper and CUDA >= 12.9."
 
-
 def check_recipe_support(recipe: Recipe) -> None:
     """Check if the given recipe is supported."""
     recipe_supported = True
@@ -102,7 +100,6 @@ def get_default_fp8_recipe() -> Recipe:
         return Float8CurrentScaling()
     return DelayedScaling()
 
-
 def get_fp8_torch_dtype(fp8_recipe: Recipe, fprop_tensor: bool = True) -> torch.dtype:
     """Get fp8 data type according to recipe and tensor"""
     if fp8_recipe.fp8_format == Format.E4M3 or (
...
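Downstream code usually consults these checks through `FP8GlobalStateManager` rather than calling them directly; the block-scaling variant is used that way by the new `fp8_blockwise_scaling_supported()` helper in the tests above. A small sketch, assuming `is_fp8_available()` follows the same `(supported, reason)` convention as `is_fp8_block_scaling_available()`:

```python
from transformer_engine.pytorch.fp8 import FP8GlobalStateManager

# Each helper returns a (supported, reason) pair; the reason explains a False.
fp8_ok, why_not_fp8 = FP8GlobalStateManager.is_fp8_available()
blockwise_ok, why_not_blockwise = FP8GlobalStateManager.is_fp8_block_scaling_available()

if not fp8_ok:
    print(f"FP8 unavailable: {why_not_fp8}")
if not blockwise_ok:
    print(f"Blockwise FP8 unavailable: {why_not_blockwise}")
```

With this change, gfx938 reports FP8 and blockwise FP8 support unconditionally, while gfx928/gfx936 only report support when the int8-simulation switches (`NVTE_INT8_SIM_FP8`, and additionally `NVTE_INT8_SIM_FP8_TENSORWISE` for plain FP8) are set.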
@@ -10,19 +10,20 @@ import os
 from typing import Any, Callable, List, Optional, Sequence, Tuple, Union
 
 import numpy as np
 import torch
+from torch.utils.cpp_extension import IS_HIP_EXTENSION
 import transformer_engine.pytorch.cpp_extensions as ext
 from . import torch_version
-from torch.utils.cpp_extension import IS_HIP_EXTENSION
 
 ActivationOffloadEnabled = False
 
 def get_activation_offloading():
+    """Get global status of get_activation_offloading"""
     global ActivationOffloadEnabled
     return ActivationOffloadEnabled
 
 def set_activation_offloading(activation_offloading):
+    """Set global status of get_activation_offloading"""
     global ActivationOffloadEnabled
     ActivationOffloadEnabled = activation_offloading
@@ -467,34 +468,73 @@ def assert_dim_for_fp8_exec(*tensors: List[torch.Tensor]) -> None:
     )
 
 if IS_HIP_EXTENSION:
-    def is_mi200():
-        """check whether this machine is mi200/210/250"""
-        import re
-        return (re.search('AMD Instinct MI2.0', torch.cuda.get_device_name(torch.cuda.current_device())) is not None)
-
-    def is_K100_AI():
-        """check whether this machine is K100_AI"""
-        import re
-        return (re.search('K100_AI', torch.cuda.get_device_name(torch.cuda.current_device())) is not None)
-
-    def is_BW():
-        """check whether this machine is BW"""
-        import re
-        return (re.search('BW', torch.cuda.get_device_name(torch.cuda.current_device())) is not None)
+    @functools.lru_cache(maxsize=None)
+    def _get_gcn_arch_impl(device: torch.device) -> int:
+        props = torch.cuda.get_device_properties(device)
+        import re
+        if re.search('gfx906', props.gcnArchName) is not None:
+            return 906
+        if re.search('gfx926', props.gcnArchName) is not None:
+            return 926
+        if re.search('gfx928', props.gcnArchName) is not None:
+            return 928
+        if re.search('gfx936', props.gcnArchName) is not None:
+            return 936
+        if re.search('gfx938', props.gcnArchName) is not None:
+            return 938
+        raise RuntimeError(f"Unsupported GCN Arch {props.gcnArchName}")
+
+    def _get_gcn_arch() -> int:
+        return _get_gcn_arch_impl(torch.cuda.current_device())
+
+    def is_gfx906() -> bool:
+        """check whether this machine is gfx906"""
+        return _get_gcn_arch() == 906
+
+    def is_gfx926() -> bool:
+        """check whether this machine is gfx926"""
+        return _get_gcn_arch() == 926
+
+    def is_gfx928() -> bool:
+        """check whether this machine is gfx928"""
+        return _get_gcn_arch() == 928
+
+    def is_gfx936() -> bool:
+        """check whether this machine is gfx928"""
+        return _get_gcn_arch() == 936
+
+    def is_gfx938() -> bool:
+        """check whether this machine is gfx928"""
+        return _get_gcn_arch() == 938
+else:
+    def is_gfx906() -> bool:
+        """gfx906 is only available on ROCm"""
+        return False
+
+    def is_gfx926() -> bool:
+        """gfx926 is only available on ROCm"""
+        return False
+
+    def is_gfx928() -> bool:
+        """gfx928 is only available on ROCm"""
+        return False
+
+    def is_gfx936() -> bool:
+        """gfx936 is only available on ROCm"""
+        return False
+
+    def is_gfx938() -> bool:
+        """gfx938 is only available on ROCm"""
+        return False
 
 def is_bf16_compatible() -> None:
     """Replaces torch.cuda.is_bf16_compatible() with an explicit
     check on device compute capability to enforce sm_80 or higher.
     """
     if IS_HIP_EXTENSION:
-        # only MI200 and MI300 machines support bf16
-        if get_device_compute_capability() == (9, 4) or is_mi200() or is_K100_AI() or is_BW():
-            return True
-        else:
-            return False
-    else:
-        return torch.cuda.get_device_capability()[0] >= 8
+        # only these arch support bf16
+        return is_gfx928() or is_gfx936() or is_gfx938()
+    return torch.cuda.get_device_capability()[0] >= 8
 
 @functools.lru_cache(maxsize=None)
@@ -505,8 +545,7 @@ def is_non_tn_fp8_gemm_supported(is_blockwise: Optional[bool] = False) -> bool:
     if IS_HIP_EXTENSION:
         if is_blockwise:
             return False
-        else:
-            return True
+        return True
     device_capability = torch.cuda.get_device_capability()
     return (10, 0) <= device_capability < (12, 0) or device_capability >= (13, 0)
...
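A small usage sketch of the new arch helpers, assuming the `transformer_engine.pytorch.utils` import path used elsewhere in this commit (on non-ROCm builds every `is_gfx*` helper simply returns False, as in the else-branch above; the detected arch is cached per device by `_get_gcn_arch_impl`):

```python
from transformer_engine.pytorch.utils import is_gfx928, is_gfx936, is_gfx938

# Per check_fp8_support() in this commit: gfx938 has a native FP8 path, while
# gfx928/gfx936 only reach FP8 through the int8-simulation switches
# (NVTE_INT8_SIM_FP8 / NVTE_INT8_SIM_FP8_TENSORWISE).
if is_gfx938():
    print("native FP8 path")
elif is_gfx928() or is_gfx936():
    print("FP8 via int8 simulation only (requires the NVTE_INT8_SIM_FP8 switches)")
else:
    print("no DCU FP8 path on this device")
```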