Commit a68e5f87 authored by wenjh's avatar wenjh
Browse files

Enable fp8 on nmz


Signed-off-by: wenjh <wenjh@sugon.com>
parent 99a1c744
Pipeline #3434 failed with stages
in 0 seconds
...@@ -36,10 +36,12 @@ python3 -m pytest --tb=auto --junitxml=$XML_LOG_DIR/pytest_test_nvfp4.xml $TE_PA ...@@ -36,10 +36,12 @@ python3 -m pytest --tb=auto --junitxml=$XML_LOG_DIR/pytest_test_nvfp4.xml $TE_PA
python3 -m pytest --tb=auto --junitxml=$XML_LOG_DIR/pytest_test_float8tensor.xml $TE_PATH/tests/pytorch/test_float8tensor.py || test_fail "test_float8tensor.py" python3 -m pytest --tb=auto --junitxml=$XML_LOG_DIR/pytest_test_float8tensor.xml $TE_PATH/tests/pytorch/test_float8tensor.py || test_fail "test_float8tensor.py"
python3 -m pytest --tb=auto --junitxml=$XML_LOG_DIR/pytest_test_float8blockwisetensor.xml $TE_PATH/tests/pytorch/test_float8blockwisetensor.py || test_fail "test_float8blockwisetensor.py" python3 -m pytest --tb=auto --junitxml=$XML_LOG_DIR/pytest_test_float8blockwisetensor.xml $TE_PATH/tests/pytorch/test_float8blockwisetensor.py || test_fail "test_float8blockwisetensor.py"
python3 -m pytest --tb=auto --junitxml=$XML_LOG_DIR/pytest_test_float8_blockwise_scaling_exact.xml $TE_PATH/tests/pytorch/test_float8_blockwise_scaling_exact.py || test_fail "test_float8_blockwise_scaling_exact.py" python3 -m pytest --tb=auto --junitxml=$XML_LOG_DIR/pytest_test_float8_blockwise_scaling_exact.xml $TE_PATH/tests/pytorch/test_float8_blockwise_scaling_exact.py || test_fail "test_float8_blockwise_scaling_exact.py"
NVTE_INT8_SIM_FP8=1 python3 -m pytest --tb=auto --junitxml=$XML_LOG_DIR/pytest_test_float8_blockwise_gemm_exact.xml $TE_PATH/tests/pytorch/test_float8_blockwise_gemm_exact.py || test_fail "test_float8_blockwise_gemm_exact.py" NVTE_INT8_SIM_FP8=1 python3 -m pytest -v -s --junitxml=$XML_LOG_DIR/pytest_test_float8_blockwise_gemm_exact_int8.xml $TE_PATH/tests/pytorch/test_float8_blockwise_gemm_exact.py || test_fail "test_float8_blockwise_gemm_exact.py_int8"
python3 -m pytest -v -s --junitxml=$XML_LOG_DIR/pytest_test_float8_blockwise_gemm_exact.xml $TE_PATH/tests/pytorch/test_float8_blockwise_gemm_exact.py || test_fail "test_float8_blockwise_gemm_exact.py"
# channelwise int8 test # channelwise int8 test
NVTE_INT8_SIM_FP8=1 python3 -m pytest -v -s --tb=auto --junitxml=$XML_LOG_DIR/pytest_test_float8_current_scaling_exact.xml $TE_PATH/tests/pytorch/test_float8_current_scaling_exact.py python3 -m pytest -v -s --tb=auto --junitxml=$XML_LOG_DIR/pytest_test_float8_current_scaling_exact.xml $TE_PATH/tests/pytorch/test_float8_current_scaling_exact.py || test_fail "test_float8_current_scaling_exact.py"
NVTE_INT8_SIM_FP8=1 NVTE_INT8_SIM_FP8_TENSORWISE=1 python3 -m pytest -v -s --tb=auto --junitxml=$XML_LOG_DIR/pytest_test_float8_current_scaling_exact.xml $TE_PATH/tests/pytorch/test_float8_current_scaling_exact.py NVTE_INT8_SIM_FP8=1 python3 -m pytest -v -s --tb=auto --junitxml=$XML_LOG_DIR/pytest_test_float8_current_scaling_exact_int8.xml $TE_PATH/tests/pytorch/test_float8_current_scaling_exact.py || test_fail "test_float8_current_scaling_exact.py_int8"
NVTE_INT8_SIM_FP8=1 NVTE_INT8_SIM_FP8_TENSORWISE=1 python3 -m pytest -v -s --tb=auto --junitxml=$XML_LOG_DIR/pytest_test_float8_current_scaling_exact_int8_tensorwise.xml $TE_PATH/tests/pytorch/test_float8_current_scaling_exact.py || test_fail "test_float8_current_scaling_exact.py_int8_tensorwise"
python3 -m pytest --tb=auto --junitxml=$XML_LOG_DIR/pytest_test_gqa.xml $TE_PATH/tests/pytorch/test_gqa.py || test_fail "test_gqa.py" python3 -m pytest --tb=auto --junitxml=$XML_LOG_DIR/pytest_test_gqa.xml $TE_PATH/tests/pytorch/test_gqa.py || test_fail "test_gqa.py"
python3 -m pytest --tb=auto --junitxml=$XML_LOG_DIR/pytest_test_fused_optimizer.xml $TE_PATH/tests/pytorch/test_fused_optimizer.py || test_fail "test_fused_optimizer.py" python3 -m pytest --tb=auto --junitxml=$XML_LOG_DIR/pytest_test_fused_optimizer.xml $TE_PATH/tests/pytorch/test_fused_optimizer.py || test_fail "test_fused_optimizer.py"
python3 -m pytest --tb=auto --junitxml=$XML_LOG_DIR/pytest_test_multi_tensor.xml $TE_PATH/tests/pytorch/test_multi_tensor.py || test_fail "test_multi_tensor.py" python3 -m pytest --tb=auto --junitxml=$XML_LOG_DIR/pytest_test_multi_tensor.xml $TE_PATH/tests/pytorch/test_multi_tensor.py || test_fail "test_multi_tensor.py"
......
...@@ -51,11 +51,26 @@ def _run_test(quantization): ...@@ -51,11 +51,26 @@ def _run_test(quantization):
all_boolean = [True, False] all_boolean = [True, False]
@pytest.mark.parametrize( @pytest.mark.parametrize(
"quantization", [None, "fp8", "mxfp8", "fp8_cs", "fp8_block_scaling", "nvfp4"] "quantization", [None, "fp8", "mxfp8", "fp8_cs", "fp8_block_scaling", "nvfp4"]
) )
def test_distributed(quantization): def test_distributed(quantization):
if quantization == "fp8" and not fp8_available:
pytest.skip(reason_for_no_fp8)
if quantization == "fp8_cs" and not fp8_available:
pytest.skip(reason_for_no_fp8)
if quantization == "mxfp8" and not mxfp8_available:
pytest.skip(reason_for_no_mxfp8)
if quantization == "fp8_block_scaling" and not fp8_block_scaling_available:
pytest.skip(reason_for_no_fp8_block_scaling)
if quantization == "nvfp4" and not nvfp4_available:
pytest.skip(reason_for_no_nvfp4)
_run_test(quantization)
@pytest.mark.parametrize(
"quantization", [None, "fp8", "mxfp8", "fp8_cs", "fp8_block_scaling", "nvfp4"]
)
def test_int8_distributed(quantization):
if quantization == "fp8" and not fp8_available: if quantization == "fp8" and not fp8_available:
pytest.skip(reason_for_no_fp8) pytest.skip(reason_for_no_fp8)
if quantization == "fp8_cs" and not fp8_available: if quantization == "fp8_cs" and not fp8_available:
......
...@@ -47,7 +47,7 @@ def cublas_gemm_fp8_blockwise_case( ...@@ -47,7 +47,7 @@ def cublas_gemm_fp8_blockwise_case(
atol: float = 0.0, atol: float = 0.0,
rtol: float = 0.0 rtol: float = 0.0
): ):
if IS_HIP_EXTENSION and int8_simulation_fp8: if IS_HIP_EXTENSION:
if use_bias or use_gelu: if use_bias or use_gelu:
pytest.skip("Bias and GELU not supported in int8 simulation mode on ROCm.") pytest.skip("Bias and GELU not supported in int8 simulation mode on ROCm.")
if not ((not x_columnwise and not w_columnwise and is_x_1d_scaled and not is_w_1d_scaled) or (not x_columnwise and w_columnwise and is_x_1d_scaled and not is_w_1d_scaled) or (x_columnwise and w_columnwise and is_x_1d_scaled and is_w_1d_scaled)): if not ((not x_columnwise and not w_columnwise and is_x_1d_scaled and not is_w_1d_scaled) or (not x_columnwise and w_columnwise and is_x_1d_scaled and not is_w_1d_scaled) or (x_columnwise and w_columnwise and is_x_1d_scaled and is_w_1d_scaled)):
...@@ -249,7 +249,7 @@ def cublas_gemm_test_constraint_enforced( ...@@ -249,7 +249,7 @@ def cublas_gemm_test_constraint_enforced(
expected_err_cls=RuntimeError expected_err_cls=RuntimeError
): ):
if IS_HIP_EXTENSION: if IS_HIP_EXTENSION:
pytest.skip("ROCm does not support cuBLAS GEMM. No need to test constraint enforcement.") pytest.skip("ROCm does not support cuBLAS blockwise FP8 gemm. No need to test constraint enforcement.")
if not fp8_blockwise_gemm_supported(): if not fp8_blockwise_gemm_supported():
pytest.skip("CUDA version does not support blockwise FP8 gemm.") pytest.skip("CUDA version does not support blockwise FP8 gemm.")
# Setup device and random seed # Setup device and random seed
......
...@@ -9,7 +9,7 @@ import pathlib ...@@ -9,7 +9,7 @@ import pathlib
import pytest import pytest
import torch import torch
import transformer_engine.pytorch as te import transformer_engine.pytorch as te
from transformer_engine.pytorch.fp8 import blockwise_fp8_block_len from transformer_engine.pytorch.fp8 import (FP8GlobalStateManager, blockwise_fp8_block_len)
from transformer_engine.common.recipe import Float8BlockScaling from transformer_engine.common.recipe import Float8BlockScaling
from transformer_engine.pytorch.constants import TE_DType from transformer_engine.pytorch.constants import TE_DType
from transformer_engine.pytorch import ( from transformer_engine.pytorch import (
...@@ -507,6 +507,9 @@ def test_quantization_block_tiling_extrema_versus_reference( ...@@ -507,6 +507,9 @@ def test_quantization_block_tiling_extrema_versus_reference(
rtol=0.0, rtol=0.0,
) )
def fp8_blockwise_scaling_supported() -> bool:
    """Report whether the current platform supports blockwise FP8 scaling.

    Thin predicate over ``FP8GlobalStateManager.is_fp8_block_scaling_available``,
    discarding the human-readable reason string.
    """
    return FP8GlobalStateManager.is_fp8_block_scaling_available()[0]
# FP8 per tesnor current scaling # FP8 per tesnor current scaling
@pytest.mark.skipif(not recipe_available, reason=reason_for_no_recipe) @pytest.mark.skipif(not recipe_available, reason=reason_for_no_recipe)
...@@ -541,12 +544,65 @@ class TestFP8BlockScalingRecipeLinear(TestFP8RecipeLinearBase): ...@@ -541,12 +544,65 @@ class TestFP8BlockScalingRecipeLinear(TestFP8RecipeLinearBase):
out_size, out_size,
dtype, dtype,
use_bias=True, use_bias=True,
):
if not fp8_blockwise_scaling_supported():
pytest.skip("CUDA version does not support blockwise FP8.")
fp8_zero_tolerance_tensor_dumps_recipe2 = None
# check tensor dumps dir, if the dir exists, then read files to get y, dgrad, wgrad, bgrad
# if we cannot get all four tensors, then still set the tensor dump to None
tensor_map = self._check_golden_tensor_dumps(
TENSOR_DUMP_DIR, recipe2, (batch_size, hidden_size, out_size), dtype, use_bias
)
if tensor_map is not None:
fp8_zero_tolerance_tensor_dumps_recipe2 = tensor_map
self.compare_recipe(
recipe1,
recipe2,
batch_size,
hidden_size,
out_size,
use_bias,
seed=torch.initial_seed(),
dtype=dtype,
y_error=0.5,
dgrad_error=1,
wgrad_error=1,
bgrad_error=0.5,
recipe1_golden_tensors=None,
recipe2_golden_tensors=fp8_zero_tolerance_tensor_dumps_recipe2,
)
@pytest.mark.parametrize(
"batch_size, hidden_size, out_size",
[
(16, 256, 128),
],
)
@pytest.mark.parametrize("dtype", [torch.bfloat16], ids=["bf16"])
@pytest.mark.parametrize(
"recipe1, recipe2",
[
(GetRecipes.none, GetRecipes.fp8_blockwise),
],
)
def test_int8_current_scaling_with_linear_module(
self,
recipe1,
recipe2,
batch_size,
hidden_size,
out_size,
dtype,
use_bias=True,
): ):
if IS_HIP_EXTENSION: if IS_HIP_EXTENSION:
import importlib import importlib
ori_int8_sim_fp8 = os.environ.get("NVTE_INT8_SIM_FP8", None) ori_int8_sim_fp8 = os.environ.get("NVTE_INT8_SIM_FP8", None)
os.environ["NVTE_INT8_SIM_FP8"] = "1" os.environ["NVTE_INT8_SIM_FP8"] = "1"
importlib.reload(te.pytorch.fp8) importlib.reload(te.pytorch.fp8)
if not fp8_blockwise_scaling_supported():
pytest.skip("CUDA version does not support blockwise FP8.")
fp8_zero_tolerance_tensor_dumps_recipe2 = None fp8_zero_tolerance_tensor_dumps_recipe2 = None
# check tensor dumps dir, if the dir exists, then read files to get y, dgrad, wgrad, bgrad # check tensor dumps dir, if the dir exists, then read files to get y, dgrad, wgrad, bgrad
# if we cannot get all four tensors, then still set the tensor dump to None # if we cannot get all four tensors, then still set the tensor dump to None
...@@ -612,12 +668,71 @@ class TestFP8BlockScalingRecipeLayerNormLinear(TestFP8RecipeLayerNormLinearBase) ...@@ -612,12 +668,71 @@ class TestFP8BlockScalingRecipeLayerNormLinear(TestFP8RecipeLayerNormLinearBase)
out_size, out_size,
dtype, dtype,
use_bias=True, use_bias=True,
):
if not fp8_blockwise_scaling_supported():
pytest.skip("CUDA version does not support blockwise FP8.")
fp8_zero_tolerance_tensor_dumps_recipe2 = None
# check tensor dumps dir, if the dir exists, then read files to get y, dgrad, wgrad, bgrad
# if we cannot get all four tensors, then still set the tensor dump to None
tensor_map = self._check_golden_tensor_dumps(
TENSOR_DUMP_DIR,
recipe2,
(batch_size, hidden_size, out_size),
dtype,
use_bias,
"LayerNorm",
)
if tensor_map is not None:
fp8_zero_tolerance_tensor_dumps_recipe2 = tensor_map
self.compare_recipe(
recipe1,
recipe2,
batch_size,
hidden_size,
out_size,
use_bias,
seed=torch.initial_seed(),
dtype=dtype,
y_error=0.5,
ln_out_error=0.5,
dgrad_error=1.6,
wgrad_error=1,
bgrad_error=0.5,
recipe1_golden_tensors=None,
recipe2_golden_tensors=fp8_zero_tolerance_tensor_dumps_recipe2,
)
@pytest.mark.parametrize(
"batch_size, hidden_size, out_size",
[
(16, 256, 128),
],
)
@pytest.mark.parametrize("dtype", [torch.bfloat16], ids=["bf16"])
@pytest.mark.parametrize(
"recipe1, recipe2",
[
(GetRecipes.none, GetRecipes.fp8_blockwise),
],
)
def test_int8_current_scaling_with_layernorm_linear_module(
self,
recipe1,
recipe2,
batch_size,
hidden_size,
out_size,
dtype,
use_bias=True,
): ):
if IS_HIP_EXTENSION: if IS_HIP_EXTENSION:
import importlib import importlib
ori_int8_sim_fp8 = os.environ.get("NVTE_INT8_SIM_FP8", None) ori_int8_sim_fp8 = os.environ.get("NVTE_INT8_SIM_FP8", None)
os.environ["NVTE_INT8_SIM_FP8"] = "1" os.environ["NVTE_INT8_SIM_FP8"] = "1"
importlib.reload(te.pytorch.fp8) importlib.reload(te.pytorch.fp8)
if not fp8_blockwise_scaling_supported():
pytest.skip("CUDA version does not support blockwise FP8.")
fp8_zero_tolerance_tensor_dumps_recipe2 = None fp8_zero_tolerance_tensor_dumps_recipe2 = None
# check tensor dumps dir, if the dir exists, then read files to get y, dgrad, wgrad, bgrad # check tensor dumps dir, if the dir exists, then read files to get y, dgrad, wgrad, bgrad
# if we cannot get all four tensors, then still set the tensor dump to None # if we cannot get all four tensors, then still set the tensor dump to None
......
# NVTE_INT8_SIM_FP8_TENSORWISE=1 python3 test_int8_channelwise_gemm_exact.py
from collections.abc import Iterable
import io
from typing import Any, Dict, List, Tuple, Union, Optional
import pytest
import torch
import transformer_engine as te
import transformer_engine.common.recipe
import transformer_engine.pytorch as te
from transformer_engine.pytorch.fp8 import FP8GlobalStateManager
from transformer_engine.pytorch.tensor.float8_tensor import (
Float8Quantizer,
Float8Tensor,
Float8CurrentScalingQuantizer,
)
from transformer_engine.pytorch.constants import TE_DType, TE_DType_To_Torch
from transformer_engine.pytorch.utils import is_non_tn_fp8_gemm_supported
import transformer_engine_torch as tex
from references.ref_per_tensor_cs import ref_per_tensor_cs_cast
from transformer_engine.pytorch.triton.per_token_group_quant import (per_token_quant_fp8_to_int8,
per_token_quant_fp8_to_int8_opt,
channelwise_dequantize,
channelwise_dequantize_transA,
channelwise_dequantize_transB,
tensorwise_dequantize)
import time
import os
int8_simulation_fp8_tensorwise = bool(int(os.getenv("NVTE_INT8_SIM_FP8_TENSORWISE", "0")))
tensorwise_int8_check = bool(int(os.getenv("NVTE_INT8_SIM_FP8_TENSORWISE_CHECK", "0")))
def dtype_tols(dtype: torch.dtype) -> Dict[str, float]:
    """Estimated numerical error for a datatype.

    Based on tolerances for torch.testing.assert_close.

    Args:
        dtype: torch dtype whose default comparison tolerances are requested.

    Returns:
        Dict with ``"rtol"`` and ``"atol"`` keys, suitable for unpacking
        into ``torch.allclose``.

    Raises:
        ValueError: if ``dtype`` is not float32, float16, or bfloat16.
    """
    if dtype == torch.float32:
        return dict(rtol=1.3e-6, atol=1e-5)
    if dtype == torch.float16:
        return dict(rtol=1e-3, atol=1e-5)
    if dtype == torch.bfloat16:
        return dict(rtol=1.6e-2, atol=1e-5)
    # Fixed typo in the original message ("Unsuppored").
    raise ValueError(f"Unsupported dtype ({dtype})")
def assert_allclose(
    l1: List[torch.Tensor], l2: List[torch.Tensor], atol: float = None, rtol: float = None
) -> None:
    """Assert that two lists of tensors are elementwise close.

    Tolerances default to ``dtype_tols`` of the reference tensor (``l2``
    entry) and can be overridden per call via ``atol``/``rtol``.

    Args:
        l1: tensors under test.
        l2: reference tensors; tolerance dtype is taken from these.
        atol: optional absolute-tolerance override.
        rtol: optional relative-tolerance override.

    Raises:
        AssertionError: if the lists differ in length, or any pair of
            tensors is not close. Non-finite mismatches (NaN/inf) now
            raise too; the original silently passed them because the
            finite-excess mask came out empty.
    """
    assert len(l1) == len(l2), "Unequal number of outputs."
    for i, (t1, t2) in enumerate(zip(l1, l2)):
        tols = dtype_tols(t2.dtype)
        if rtol is not None:
            tols["rtol"] = rtol
        if atol is not None:
            tols["atol"] = atol
        if torch.allclose(t1, t2, **tols):
            continue
        # Recompute the per-element tolerance (same formula torch.allclose
        # uses) to locate and report the worst offender.
        diff = torch.abs(t1 - t2)
        tol = tols["atol"] + (tols["rtol"] * torch.abs(t2))
        exceed_mask = diff > tol
        if not exceed_mask.any():
            # allclose failed but no finite element exceeds tolerance:
            # the mismatch involves NaN or inf.
            raise AssertionError(
                f"Outputs not close enough in tensor at idx={i} "
                "(non-finite mismatch, e.g. NaN or inf)."
            )
        indices = torch.nonzero(exceed_mask, as_tuple=True)
        max_diff = diff[exceed_mask].max()
        max_idx = (diff[exceed_mask] == max_diff).nonzero(as_tuple=True)[0][0]
        max_location = [idx[max_idx].item() for idx in indices]
        msg = (
            f"Outputs not close enough in tensor at idx={i}. "
            f"Maximum difference at location {max_location} "
            f"with {t1[exceed_mask][max_idx].item()} vs {t2[exceed_mask][max_idx].item()} "
            f"(diff {max_diff.item()})."
        )
        raise AssertionError(msg)
# TN
# Problem sizes for the forward GEMM: y[m, n] = x[m, k] @ w[n, k]^T.
m = 4096
k = 4096
n = 6144
# Fix both CPU and CUDA RNG so the quantization inputs below are reproducible.
seed = 4096
torch.manual_seed(seed)
torch.cuda.manual_seed(seed)
device = "cuda"
# int8 x int8 GEMM accumulates into int32 outputs.
out_dtype = torch.int32
# Allocate cuBLAS workspace
workspace_size = 128
workspace = torch.empty(128, dtype=torch.uint8, device=device)
out_quantizer = None
accumulate = False
use_gelu = False
use_bias = False
bias = None
use_grad = False
assert not (use_gelu and use_bias), "Bias and GELU not supported by GEMM"
# aux/out buffers are only needed on the GELU / accumulate paths, which are
# disabled above, so both remain None here.
aux_tensor = torch.empty((m, n), dtype=out_dtype, device=device) if use_gelu else None
out = torch.empty((m, n), dtype=out_dtype, device=device) if accumulate else None
# bias is None above, so this resolves to the bf16 TE dtype.
bias_dtype = TE_DType[torch.bfloat16 if bias is None else bias.dtype]
use_split_accumulator = False
# bf16 to int8
# transa = True
# transb = False
# x_bf16 = (torch.randn((m, k), device=device)).to(dtype=torch.bfloat16)
# w_bf16 = (torch.randn((n, k), device=device)).to(dtype=torch.bfloat16)
# bf16_out = torch.matmul(x_bf16, w_bf16.t())
# x_int8, x_scales = per_token_quant_int8(x_bf16)
# w_int8, w_scales = per_token_quant_int8(w_bf16)
# # print("x_int8: ", x_int8)
# # print("w_int8: ", w_int8)
# # cuBLAS GEMM
# # return type is out, bias_grad, gelu_input, extra_output
# # We are just capturing out.
# y_int32 = tex.generic_gemm(
# w_int8,
# transa,
# x_int8,
# transb,
# out,
# out_quantizer,
# TE_DType[out_dtype],
# bias,
# bias_dtype,
# use_gelu,
# aux_tensor,
# use_grad,
# workspace,
# workspace.shape[0],
# accumulate,
# use_split_accumulator,
# )[0]
# # y_int32 = torch._int_mm(x_int8, w_int8.t())
# # print("y_int32: ", y_int32)
# output = channelwise_dequantize_transB(x_scales, w_scales, y_int32)
# # print("out_scales.shape: ", out_scales.shape)
# # print("out_scales: ", out_scales)
# # print("bf16_out: ", bf16_out)
# # print("output: ", output)
# # NN
# transa = False
# transb = False
# dy_bf16 = (torch.randn((m, n), device=device)).to(dtype=torch.bfloat16)
# w_bf16 = (torch.randn((n, k), device=device)).to(dtype=torch.bfloat16)
# bf16_dx = torch.matmul(dy_bf16, w_bf16)
# dy_int8, dy_scales = per_token_quant_int8(dy_bf16)
# w_int8, w_scales = per_token_quant_int8_v2(w_bf16)
# # print("dy_scales.shape: ", dy_scales.shape)
# # print("w_scales.shape: ", w_scales.shape)
# # print("dy_int8: ", dy_int8)
# # print("w_int8: ", w_int8)
# # cuBLAS GEMM
# # return type is out, bias_grad, gelu_input, extra_output
# # We are just capturing out.
# dx_int32 = tex.generic_gemm(
# w_int8,
# transa,
# dy_int8,
# transb,
# out,
# out_quantizer,
# TE_DType[out_dtype],
# bias,
# bias_dtype,
# use_gelu,
# aux_tensor,
# use_grad,
# workspace,
# workspace.shape[0],
# accumulate,
# use_split_accumulator,
# )[0]
# # dx_int32 = torch._int_mm(dy_int8, w_int8)
# # print("dx_int32: ", dx_int32)
# dx = channelwise_dequantize(dy_scales, w_scales, dx_int32)
# # print("dx_scales.shape: ", dx_scales.shape)
# # print("dx_scales: ", dx_scales)
# # print("bf16_dx: ", bf16_dx)
# # print("dx: ", dx)
# # NT
# transa = False
# transb = True
# dy_bf16 = (torch.randn((m, n), device=device)).to(dtype=torch.bfloat16)
# x_bf16 = (torch.randn((m, k), device=device)).to(dtype=torch.bfloat16)
# bf16_dw = torch.matmul(dy_bf16.t(), x_bf16)
# dy_int8, dy_scales = per_token_quant_int8_v2(dy_bf16)
# x_int8, x_scales = per_token_quant_int8_v2(x_bf16)
# # cuBLAS GEMM
# # return type is out, bias_grad, gelu_input, extra_output
# # We are just capturing out.
# dw_int32 = tex.generic_gemm(
# x_int8,
# transa,
# dy_int8,
# transb,
# out,
# out_quantizer,
# TE_DType[out_dtype],
# bias,
# bias_dtype,
# use_gelu,
# aux_tensor,
# use_grad,
# workspace,
# workspace.shape[0],
# accumulate,
# use_split_accumulator,
# )[0]
# dw = channelwise_dequantize_transA(dy_scales, x_scales, dw_int32)
# # print("bf16_dw: ", bf16_dw)
# # print("dw: ", dw)
# fp8 to int8
# Module-level current-scaling quantizers, one per FP8 format used below.
# E5M2 trades mantissa for exponent range; E4M3 keeps more mantissa bits.
quantizer_e5m2 = Float8CurrentScalingQuantizer(
    fp8_dtype=tex.DType.kFloat8E5M2,
    device="cuda",
    force_pow_2_scales=False,
    amax_epsilon=0.0,
)
quantizer_e4m3 = Float8CurrentScalingQuantizer(
    fp8_dtype=tex.DType.kFloat8E4M3,
    device="cuda",
    force_pow_2_scales=False,
    amax_epsilon=0.0,
)
# current scaling
def to_float8_CS(
    tensor: torch.Tensor,
    fp8_dtype: tex.DType = tex.DType.kFloat8E5M2,
    return_transpose: bool = False,
    force_pow_2_scales: bool = False,
    amax_epsilon: float = 0.0,
) -> Float8Tensor:
    """Quantize *tensor* to FP8 via the module-level current-scaling quantizers."""
    # Select the pre-built quantizer matching the requested FP8 format.
    if fp8_dtype == tex.DType.kFloat8E5M2:
        selected = quantizer_e5m2
    else:
        selected = quantizer_e4m3
    # Columnwise (transposed) data is materialized only when requested;
    # rowwise data is always produced.
    selected.set_usage(rowwise=True, columnwise=return_transpose)
    return selected(tensor)
# TN
transa = True
transb = False
x_bf16 = (torch.randn((m, k), device=device)).to(dtype=torch.bfloat16)
w_bf16 = (torch.randn((n, k), device=device)).to(dtype=torch.bfloat16)
output = (torch.randn((m, n), device=device)).to(dtype=torch.bfloat16)
for i in range(20):
bf16_out = torch.matmul(x_bf16, w_bf16.t())
print("bf16_out: ", bf16_out)
torch.cuda.synchronize()
start = time.time()
for i in range(20):
bf16_out = torch.matmul(x_bf16, w_bf16.t())
torch.cuda.synchronize()
end = time.time()
# x_int8, x_scales = per_token_quant_int8(x_bf16)
# w_int8, w_scales = per_token_quant_int8(w_bf16)
# print("x_int8: ", x_int8)
# print("w_int8: ", w_int8)
# Cast to FP8 and back
x_fp8 = to_float8_CS(x_bf16, fp8_dtype=tex.DType.kFloat8E4M3)
w_fp8 = to_float8_CS(w_bf16, fp8_dtype=tex.DType.kFloat8E4M3)
# print("x_fp8: ", x_fp8._data.view(dtype=torch.float8_e4m3fn))
# print("w_fp8: ", w_fp8._data.view(dtype=torch.float8_e4m3fn))
if int8_simulation_fp8_tensorwise:
x_int8, x_scales = x_fp8._data.view(dtype=torch.int8), x_fp8._scale_inv
w_int8, w_scales = w_fp8._data.view(dtype=torch.int8), w_fp8._scale_inv
else:
x_int8, x_scales = per_token_quant_fp8_to_int8(x_fp8._data.view(dtype=torch.float8_e4m3fn), x_fp8._scale_inv, False)
w_int8, w_scales = per_token_quant_fp8_to_int8(w_fp8._data.view(dtype=torch.float8_e4m3fn), w_fp8._scale_inv, False)
# print("x_int8: ", x_int8)
# print("w_int8: ", w_int8)
# cuBLAS GEMM
# return type is out, bias_grad, gelu_input, extra_output
# We are just capturing out.
y_int32 = tex.generic_gemm(
w_int8,
transa,
x_int8,
transb,
out,
out_quantizer,
TE_DType[out_dtype],
bias,
bias_dtype,
use_gelu,
aux_tensor,
use_grad,
workspace,
workspace.shape[0],
accumulate,
use_split_accumulator,
)[0]
# y_int32 = torch._int_mm(x_int8, w_int8.t())
# print("y_int32: ", y_int32)
if int8_simulation_fp8_tensorwise:
tensorwise_dequantize(x_scales, w_scales, y_int32, output)
else:
output = channelwise_dequantize_transB(x_scales, w_scales, y_int32)
print("output: ", output)
if tensorwise_int8_check:
lt_output = tex.generic_gemm(
w_fp8,
transa,
x_fp8,
transb,
out,
out_quantizer,
TE_DType[torch.bfloat16],
bias,
bias_dtype,
use_gelu,
aux_tensor,
use_grad,
workspace,
workspace.shape[0],
accumulate,
True,
)[0]
print("lt_output: ", lt_output)
assert_allclose([output], [lt_output])
# print("out_scales.shape: ", out_scales.shape)
# print("out_scales: ", out_scales)
# torch.cuda.synchronize()
# start = time.time()
# for i in range(20):
# x_fp8 = to_float8_CS(x_bf16, fp8_dtype=tex.DType.kFloat8E4M3)
# # w_fp8 = to_float8_CS(w_bf16, fp8_dtype=tex.DType.kFloat8E4M3)
# x_int8, x_scales = per_token_quant_fp8_to_int8(x_fp8._data.view(dtype=torch.float8_e4m3fn), x_fp8._scale_inv, False)
# w_int8, w_scales = per_token_quant_fp8_to_int8(w_fp8._data.view(dtype=torch.float8_e4m3fn), w_fp8._scale_inv, False)
# y_int32 = tex.generic_gemm(
# w_int8,
# transa,
# x_int8,
# transb,
# out,
# out_quantizer,
# TE_DType[out_dtype],
# bias,
# bias_dtype,
# use_gelu,
# aux_tensor,
# use_grad,
# workspace,
# workspace.shape[0],
# accumulate,
# use_split_accumulator,
# )[0]
# output = channelwise_dequantize_transB(x_scales, w_scales, y_int32)
# torch.cuda.synchronize()
# end = time.time()
# NN
# transa = True
transa = False
transb = False
dy_bf16 = (torch.randn((m, n), device=device)).to(dtype=torch.bfloat16)
w_bf16 = (torch.randn((n, k), device=device)).to(dtype=torch.bfloat16)
dx = (torch.randn((m, k), device=device)).to(dtype=torch.bfloat16)
bf16_dx = torch.matmul(dy_bf16, w_bf16)
print("bf16_dx: ", bf16_dx)
torch.cuda.synchronize()
start = time.time()
for i in range(20):
bf16_dx = torch.matmul(dy_bf16, w_bf16)
torch.cuda.synchronize()
end = time.time()
# Cast to FP8 and back
dy_fp8 = to_float8_CS(dy_bf16, fp8_dtype=tex.DType.kFloat8E5M2)
w_fp8 = to_float8_CS(w_bf16, fp8_dtype=tex.DType.kFloat8E5M2)
# w_fp8 = to_float8_CS(w_bf16, fp8_dtype=tex.DType.kFloat8E5M2, return_transpose=True)
if int8_simulation_fp8_tensorwise:
dy_int8, dy_scales = dy_fp8._data.view(dtype=torch.int8), dy_fp8._scale_inv
w_int8, w_scales = w_fp8._data.view(dtype=torch.int8), w_fp8._scale_inv
else:
dy_int8, dy_scales = per_token_quant_fp8_to_int8(dy_fp8._data.view(dtype=torch.float8_e5m2), dy_fp8._scale_inv, False)
w_int8, w_scales = per_token_quant_fp8_to_int8_opt(w_fp8._data.view(dtype=torch.float8_e5m2), w_fp8._scale_inv, False)
# w_int8, w_scales = per_token_quant_fp8_to_int8_v2(w_fp8._data.view(dtype=torch.float8_e5m2), w_fp8._scale_inv, False)
# w_int8, w_scales = per_token_quant_fp8_to_int8(w_fp8._transpose.view(dtype=torch.float8_e5m2), w_fp8._scale_inv, False)
# print("dy_scales.shape: ", dy_scales.shape)
# print("w_scales.shape: ", w_scales.shape)
# print("dy_int8: ", dy_int8)
# print("w_int8: ", w_int8)
# print("w_scales: ", w_scales)
# cuBLAS GEMM
# return type is out, bias_grad, gelu_input, extra_output
# We are just capturing out.
dx_int32 = tex.generic_gemm(
w_int8,
transa,
dy_int8,
transb,
out,
out_quantizer,
TE_DType[out_dtype],
bias,
bias_dtype,
use_gelu,
aux_tensor,
use_grad,
workspace,
workspace.shape[0],
accumulate,
use_split_accumulator,
)[0]
# dx_int32 = torch._int_mm(dy_int8, w_int8)
# print("dx_int32: ", dx_int32)
if int8_simulation_fp8_tensorwise:
tensorwise_dequantize(dy_scales, w_scales, dx_int32, dx)
else:
dx = channelwise_dequantize(dy_scales, w_scales, dx_int32)
# dx = channelwise_dequantize_transB(dy_scales, w_scales, dx_int32)
print("dx: ", dx)
if tensorwise_int8_check:
lt_dx = tex.generic_gemm(
w_fp8,
transa,
dy_fp8,
transb,
out,
out_quantizer,
TE_DType[torch.bfloat16],
bias,
bias_dtype,
use_gelu,
aux_tensor,
use_grad,
workspace,
workspace.shape[0],
accumulate,
True,
)[0]
print("lt_dx: ", lt_dx)
assert_allclose([dx], [lt_dx])
# print("dx_scales.shape: ", dx_scales.shape)
# print("dx_scales: ", dx_scales)
# torch.cuda.synchronize()
# start = time.time()
# for i in range(20):
# dy_fp8 = to_float8_CS(dy_bf16, fp8_dtype=tex.DType.kFloat8E5M2)
# # w_fp8 = to_float8_CS(w_bf16, fp8_dtype=tex.DType.kFloat8E5M2)
# # w_fp8 = to_float8_CS(w_bf16, fp8_dtype=tex.DType.kFloat8E5M2, return_transpose=True)
# dy_int8, dy_scales = per_token_quant_fp8_to_int8(dy_fp8._data.view(dtype=torch.float8_e5m2), dy_fp8._scale_inv, False)
# w_int8, w_scales = per_token_quant_fp8_to_int8_opt(w_fp8._data.view(dtype=torch.float8_e5m2), w_fp8._scale_inv, False)
# # w_int8, w_scales = per_token_quant_fp8_to_int8_v2(w_fp8._data.view(dtype=torch.float8_e5m2), w_fp8._scale_inv, False)
# # w_int8, w_scales = per_token_quant_fp8_to_int8(w_fp8._transpose.view(dtype=torch.float8_e5m2), w_fp8._scale_inv, False)
# dx_int32 = tex.generic_gemm(
# w_int8,
# transa,
# dy_int8,
# transb,
# out,
# out_quantizer,
# TE_DType[out_dtype],
# bias,
# bias_dtype,
# use_gelu,
# aux_tensor,
# use_grad,
# workspace,
# workspace.shape[0],
# accumulate,
# use_split_accumulator,
# )[0]
# dx = channelwise_dequantize(dy_scales, w_scales, dx_int32)
# # dx = channelwise_dequantize_transB(dy_scales, w_scales, dx_int32)
# torch.cuda.synchronize()
# end = time.time()
# NT
# transa = True
# transb = False
transa = False
transb = True
dy_bf16 = (torch.randn((m, n), device=device)).to(dtype=torch.bfloat16)
x_bf16 = (torch.randn((m, k), device=device)).to(dtype=torch.bfloat16)
dw = (torch.randn((n, k), device=device)).to(dtype=torch.bfloat16)
bf16_dw = torch.matmul(dy_bf16.t(), x_bf16)
print("bf16_dw: ", bf16_dw)
torch.cuda.synchronize()
start = time.time()
for i in range(20):
bf16_dw = torch.matmul(dy_bf16.t(), x_bf16)
torch.cuda.synchronize()
end = time.time()
# Cast to FP8 and back
# dy_fp8 = to_float8_CS(dy_bf16, fp8_dtype=tex.DType.kFloat8E5M2, return_transpose=True)
# x_fp8 = to_float8_CS(x_bf16, fp8_dtype=tex.DType.kFloat8E5M2, return_transpose=True)
dy_fp8 = to_float8_CS(dy_bf16, fp8_dtype=tex.DType.kFloat8E5M2)
x_fp8 = to_float8_CS(x_bf16, fp8_dtype=tex.DType.kFloat8E5M2)
if int8_simulation_fp8_tensorwise:
dy_int8, dy_scales = dy_fp8._data.view(dtype=torch.int8), dy_fp8._scale_inv
x_int8, x_scales = x_fp8._data.view(dtype=torch.int8), x_fp8._scale_inv
else:
# dy_int8, dy_scales = per_token_quant_fp8_to_int8_v2(dy_fp8._data.view(dtype=torch.float8_e5m2), dy_fp8._scale_inv, False)
# x_int8, x_scales = per_token_quant_fp8_to_int8_v2(x_fp8._data.view(dtype=torch.float8_e5m2), x_fp8._scale_inv, False)
dy_int8, dy_scales = per_token_quant_fp8_to_int8_opt(dy_fp8._data.view(dtype=torch.float8_e5m2), dy_fp8._scale_inv, False)
x_int8, x_scales = per_token_quant_fp8_to_int8_opt(x_fp8._data.view(dtype=torch.float8_e5m2), x_fp8._scale_inv, False)
# dy_int8, dy_scales = per_token_quant_fp8_to_int8(dy_fp8._transpose.view(dtype=torch.float8_e5m2), dy_fp8._scale_inv, False)
# x_int8, x_scales = per_token_quant_fp8_to_int8(x_fp8._transpose.view(dtype=torch.float8_e5m2), x_fp8._scale_inv, False)
# cuBLAS GEMM
# return type is out, bias_grad, gelu_input, extra_output
# We are just capturing out.
dw_int32 = tex.generic_gemm(
x_int8,
transa,
dy_int8,
transb,
out,
out_quantizer,
TE_DType[out_dtype],
bias,
bias_dtype,
use_gelu,
aux_tensor,
use_grad,
workspace,
workspace.shape[0],
accumulate,
use_split_accumulator,
)[0]
if int8_simulation_fp8_tensorwise:
tensorwise_dequantize(dy_scales, x_scales, dw_int32, dw)
else:
dw = channelwise_dequantize_transA(dy_scales, x_scales, dw_int32)
# dw = channelwise_dequantize_transB(dy_scales, x_scales, dw_int32)
print("dw: ", dw)
if tensorwise_int8_check:
lt_dw = tex.generic_gemm(
x_fp8,
transa,
dy_fp8,
transb,
out,
out_quantizer,
TE_DType[torch.bfloat16],
bias,
bias_dtype,
use_gelu,
aux_tensor,
use_grad,
workspace,
workspace.shape[0],
accumulate,
True,
)[0]
print("lt_dw: ", lt_dw)
assert_allclose([dw], [lt_dw])
# torch.cuda.synchronize()
# start = time.time()
# for i in range(20):
# # dy_fp8 = to_float8_CS(dy_bf16, fp8_dtype=tex.DType.kFloat8E5M2, return_transpose=True)
# # x_fp8 = to_float8_CS(x_bf16, fp8_dtype=tex.DType.kFloat8E5M2, return_transpose=True)
# # dy_fp8 = to_float8_CS(dy_bf16, fp8_dtype=tex.DType.kFloat8E5M2)
# # x_fp8 = to_float8_CS(x_bf16, fp8_dtype=tex.DType.kFloat8E5M2)
# # dy_int8, dy_scales = per_token_quant_fp8_to_int8_v2(dy_fp8._data.view(dtype=torch.float8_e5m2), dy_fp8._scale_inv, False)
# # x_int8, x_scales = per_token_quant_fp8_to_int8_v2(x_fp8._data.view(dtype=torch.float8_e5m2), x_fp8._scale_inv, False)
# dy_int8, dy_scales = per_token_quant_fp8_to_int8_opt(dy_fp8._data.view(dtype=torch.float8_e5m2), dy_fp8._scale_inv, False)
# x_int8, x_scales = per_token_quant_fp8_to_int8_opt(x_fp8._data.view(dtype=torch.float8_e5m2), x_fp8._scale_inv, False)
# # dy_int8, dy_scales = per_token_quant_fp8_to_int8(dy_fp8._transpose.view(dtype=torch.float8_e5m2), dy_fp8._scale_inv, False)
# # x_int8, x_scales = per_token_quant_fp8_to_int8(x_fp8._transpose.view(dtype=torch.float8_e5m2), x_fp8._scale_inv, False)
# dw_int32 = tex.generic_gemm(
# x_int8,
# transa,
# dy_int8,
# transb,
# out,
# out_quantizer,
# TE_DType[out_dtype],
# bias,
# bias_dtype,
# use_gelu,
# aux_tensor,
# use_grad,
# workspace,
# workspace.shape[0],
# accumulate,
# use_split_accumulator,
# )[0]
# dw = channelwise_dequantize_transA(dy_scales, x_scales, dw_int32)
# # dw = channelwise_dequantize_transB(dy_scales, x_scales, dw_int32)
# torch.cuda.synchronize()
# end = time.time()
# bacth gemm wgrad
m = 1024
k = 1024
n = 1024
b = 4
transa = False
transb = True
dy_int8 = (torch.randn((b, m, n), device=device)).to(dtype=torch.int8)
x_int8 = (torch.randn((b, m, k), device=device)).to(dtype=torch.int8)
int32_dw_list = []
for i in range(b):
int32_dw = torch._int_mm(dy_int8[i].t(), x_int8[i])
# bf16_dw = torch.matmul(dy_int8[i].t(), x_int8[i])
int32_dw_list.append(int32_dw)
batched_int32_dw = torch.stack(int32_dw_list)
# print("batched_int32_dw.shape: ", batched_int32_dw.shape)
# print("batched_int32_dw: ", batched_int32_dw)
out_dtype = torch.int32
out = torch.empty((b, n, k), dtype=out_dtype, device=device)
te_dw = tex.generic_batchgemm(
x_int8.view(-1, x_int8.size(-1)),
transa,
dy_int8.view(-1, dy_int8.size(-1)),
transb,
out.view(-1, out.size(-1)),
b,
out_quantizer,
TE_DType[out_dtype],
bias,
bias_dtype,
use_gelu,
aux_tensor,
use_grad,
workspace,
workspace.shape[0],
accumulate,
use_split_accumulator,
)[0]
# print("te_dw.shape: ", te_dw.view(b, -1, te_dw.size(-1)).shape)
# print("te_dw: ", te_dw.view(b, -1, te_dw.size(-1)))
torch.testing.assert_close(te_dw.view(b, -1, te_dw.size(-1)), batched_int32_dw, atol=0, rtol=0)
# NT
b = 4
transa = False
transb = True
dy_bf16 = [(torch.randn((m, n), device=device)).to(dtype=torch.bfloat16) for i in range(b)]
x_bf16 = [(torch.randn((m, k), device=device)).to(dtype=torch.bfloat16) for i in range(b)]
dw_ref = [(torch.randn((n, k), device=device)).to(dtype=torch.bfloat16) for i in range(b)]
dw = [(torch.randn((n, k), device=device)).to(dtype=torch.bfloat16) for i in range(b)]
# Cast to FP8 and back
dy_fp8 = [to_float8_CS(dy_bf16[i], fp8_dtype=tex.DType.kFloat8E5M2) for i in range(b)]
x_fp8 = [to_float8_CS(x_bf16[i], fp8_dtype=tex.DType.kFloat8E5M2) for i in range(b)]
if int8_simulation_fp8_tensorwise:
dy_int8, dy_scales = [dy_fp8[i]._data.view(dtype=torch.int8) for i in range(b)], [dy_fp8[i]._scale_inv for i in range(b)]
x_int8, x_scales = [x_fp8[i]._data.view(dtype=torch.int8) for i in range(b)], [x_fp8[i]._scale_inv for i in range(b)]
else:
dy_int8, dy_scales = [], []
x_int8, x_scales = [], []
assert False
for i in range(b):
# cuBLAS GEMM
# return type is out, bias_grad, gelu_input, extra_output
# We are just capturing out.
dw_int32 = tex.generic_gemm(
x_int8[i],
transa,
dy_int8[i],
transb,
None,
None,
TE_DType[out_dtype],
bias,
bias_dtype,
use_gelu,
aux_tensor,
use_grad,
workspace,
workspace.shape[0],
accumulate,
use_split_accumulator,
)[0]
if int8_simulation_fp8_tensorwise:
tensorwise_dequantize(dy_scales[i], x_scales[i], dw_int32, dw_ref[i])
else:
assert False
dw_ref_tensor = torch.stack(dw_ref).contiguous().view(-1, dw_ref[0].size(-1))
# print("dw_ref_tensor: ", dw_ref_tensor)
torch.cuda.synchronize()
dy_int8_tensor = torch.stack(dy_int8).contiguous()
dy_scales_tensor = torch.stack(dy_scales).contiguous()
x_int8_tensor = torch.stack(x_int8).contiguous()
x_scales_tensor = torch.stack(x_scales).contiguous()
dw_tensor = torch.stack(dw).contiguous()
out_dtype = torch.bfloat16
dw_tensor = tex.tensorwise_int8_batchgemm(
x_int8_tensor.view(-1, x_int8_tensor.size(-1)),
transa,
dy_int8_tensor.view(-1, dy_int8_tensor.size(-1)),
transb,
x_scales_tensor,
dy_scales_tensor,
dw_tensor.view(-1, dw_tensor.size(-1)),
b,
out_quantizer,
TE_DType[out_dtype],
bias,
bias_dtype,
use_gelu,
aux_tensor,
use_grad,
workspace,
workspace.shape[0],
accumulate,
use_split_accumulator,
)[0]
# print("dw_tensor: ", dw_tensor)
torch.testing.assert_close(dw_ref_tensor, dw_tensor, atol=1e-5, rtol=1e-5)
...@@ -28,7 +28,7 @@ from transformer_engine.common.recipe import ( ...@@ -28,7 +28,7 @@ from transformer_engine.common.recipe import (
) )
from .constants import dist_group_type from .constants import dist_group_type
from .utils import get_device_compute_capability from .utils import (get_device_compute_capability, is_gfx928, is_gfx936, is_gfx938)
from .jit import jit_fuser from .jit import jit_fuser
from torch.utils.cpp_extension import IS_HIP_EXTENSION from torch.utils.cpp_extension import IS_HIP_EXTENSION
int8_simulation_fp8 = bool(int(os.getenv("NVTE_INT8_SIM_FP8", "0"))) int8_simulation_fp8 = bool(int(os.getenv("NVTE_INT8_SIM_FP8", "0")))
...@@ -45,18 +45,14 @@ __all__ = [ ...@@ -45,18 +45,14 @@ __all__ = [
"get_default_recipe", "get_default_recipe",
] ]
if IS_HIP_EXTENSION:
from transformer_engine.pytorch.utils import is_K100_AI, is_BW
@functools.lru_cache(maxsize=None) @functools.lru_cache(maxsize=None)
def check_fp8_support() -> Tuple[bool, str]: def check_fp8_support() -> Tuple[bool, str]:
"""Return if fp8 support is available""" """Return if fp8 support is available"""
if IS_HIP_EXTENSION: if IS_HIP_EXTENSION:
if (is_K100_AI() or is_BW()) and int8_simulation_fp8: if is_gfx938():
return True, "DCU turn on fp8 simulation with int8" return True, ""
else: if (is_gfx928() or is_gfx936()) and int8_simulation_fp8 and int8_simulation_fp8_tensorwise:
return False, "DCU not support fp8 for now" return True, ""
else:
if get_device_compute_capability() >= (9, 0): # hopper and above if get_device_compute_capability() >= (9, 0): # hopper and above
return True, "" return True, ""
if get_device_compute_capability() < (8, 9): # pre-ada if get_device_compute_capability() < (8, 9): # pre-ada
...@@ -71,6 +67,8 @@ def check_fp8_support() -> Tuple[bool, str]: ...@@ -71,6 +67,8 @@ def check_fp8_support() -> Tuple[bool, str]:
@functools.lru_cache(maxsize=None) @functools.lru_cache(maxsize=None)
def check_mxfp8_support() -> Tuple[bool, str]: def check_mxfp8_support() -> Tuple[bool, str]:
"""Return if fp8 support is available""" """Return if fp8 support is available"""
if IS_HIP_EXTENSION:
return False, "DCU not support mxfp8 for now"
if get_device_compute_capability() >= (12, 0): if get_device_compute_capability() >= (12, 0):
return False, "MXFP8 (for all gemm layouts) is not supported on 12.0+ architectures yet." return False, "MXFP8 (for all gemm layouts) is not supported on 12.0+ architectures yet."
if get_device_compute_capability() >= (10, 0): # blackwell and above if get_device_compute_capability() >= (10, 0): # blackwell and above
...@@ -83,7 +81,6 @@ def check_nvfp4_support() -> Tuple[bool, str]: ...@@ -83,7 +81,6 @@ def check_nvfp4_support() -> Tuple[bool, str]:
"""Return if nvfp4 support is available""" """Return if nvfp4 support is available"""
if IS_HIP_EXTENSION: if IS_HIP_EXTENSION:
return False, "NVFP4 is not supported on rocm platform." return False, "NVFP4 is not supported on rocm platform."
else:
if get_device_compute_capability() >= (10, 0): # blackwell and above if get_device_compute_capability() >= (10, 0): # blackwell and above
return True, "" return True, ""
return False, "Device compute capability 10.0 or higher required for NVFP4 execution." return False, "Device compute capability 10.0 or higher required for NVFP4 execution."
...@@ -93,9 +90,10 @@ def check_nvfp4_support() -> Tuple[bool, str]: ...@@ -93,9 +90,10 @@ def check_nvfp4_support() -> Tuple[bool, str]:
def check_fp8_block_scaling_support() -> Tuple[bool, str]: def check_fp8_block_scaling_support() -> Tuple[bool, str]:
"""Return if fp8 block scaling support is available""" """Return if fp8 block scaling support is available"""
if IS_HIP_EXTENSION: if IS_HIP_EXTENSION:
if is_K100_AI() or is_BW() and int8_simulation_fp8: if is_gfx938():
return True, ""
if (is_gfx928() or is_gfx936()) and int8_simulation_fp8:
return True, "" return True, ""
else:
return False, "DCU not support block_scaling fp8 for now" return False, "DCU not support block_scaling fp8 for now"
if get_device_compute_capability() >= (9, 0) and float(torch.version.cuda) >= 12.9: if get_device_compute_capability() >= (9, 0) and float(torch.version.cuda) >= 12.9:
return True, "" return True, ""
......
...@@ -10,10 +10,9 @@ import os ...@@ -10,10 +10,9 @@ import os
from typing import Any, Callable, List, Optional, Sequence, Tuple, Union from typing import Any, Callable, List, Optional, Sequence, Tuple, Union
import numpy as np import numpy as np
import torch import torch
from torch.utils.cpp_extension import IS_HIP_EXTENSION
from . import torch_version from . import torch_version
from .quantized_tensor import Quantizer from .quantized_tensor import Quantizer
from torch.utils.cpp_extension import IS_HIP_EXTENSION
__all__ = ["get_device_compute_capability", "get_cudnn_version", "is_bf16_available"] __all__ = ["get_device_compute_capability", "get_cudnn_version", "is_bf16_available"]
...@@ -445,20 +444,64 @@ def assert_dim_for_fp8_exec(*tensors: List[torch.Tensor]) -> None: ...@@ -445,20 +444,64 @@ def assert_dim_for_fp8_exec(*tensors: List[torch.Tensor]) -> None:
) )
if IS_HIP_EXTENSION: if IS_HIP_EXTENSION:
def is_mi200(): @functools.lru_cache(maxsize=None)
"""check whether this machine is mi200/210/250""" def _get_gcn_arch_impl(device: torch.device) -> int:
props = torch.cuda.get_device_properties(device)
import re import re
return (re.search('AMD Instinct MI2.0', torch.cuda.get_device_name(torch.cuda.current_device())) is not None) if re.search('gfx906', props.gcnArchName) is not None:
return 906
if re.search('gfx926', props.gcnArchName) is not None:
return 926
if re.search('gfx928', props.gcnArchName) is not None:
return 928
if re.search('gfx936', props.gcnArchName) is not None:
return 936
if re.search('gfx938', props.gcnArchName) is not None:
return 938
raise RuntimeError(f"Unsupported GCN Arch {props.gcnArchName}")
def _get_gcn_arch() -> int:
return _get_gcn_arch_impl(torch.cuda.current_device())
def is_gfx906() -> bool:
"""check whether this machine is gfx906"""
return _get_gcn_arch() == 906
def is_gfx926() -> bool:
"""check whether this machine is gfx926"""
return _get_gcn_arch() == 926
def is_gfx928() -> bool:
"""check whether this machine is gfx928"""
return _get_gcn_arch() == 928
def is_gfx936() -> bool:
"""check whether this machine is gfx928"""
return _get_gcn_arch() == 936
def is_gfx938() -> bool:
"""check whether this machine is gfx928"""
return _get_gcn_arch() == 938
else:
def is_gfx906() -> bool:
"""gfx906 is only available on ROCm"""
return False
def is_K100_AI(): def is_gfx926() -> bool:
"""check whether this machine is K100_AI""" """gfx926 is only available on ROCm"""
import re return False
return (re.search('K100_AI', torch.cuda.get_device_name(torch.cuda.current_device())) is not None)
def is_BW(): def is_gfx928() -> bool:
"""check whether this machine is BW""" """gfx928 is only available on ROCm"""
import re return False
return (re.search('BW', torch.cuda.get_device_name(torch.cuda.current_device())) is not None)
def is_gfx936() -> bool:
"""gfx936 is only available on ROCm"""
return False
def is_gfx938() -> bool:
"""gfx938 is only available on ROCm"""
return False
def assert_dim_for_all_gather( def assert_dim_for_all_gather(
tensor: torch.Tensor, with_all_gather: bool, quantizer: Quantizer tensor: torch.Tensor, with_all_gather: bool, quantizer: Quantizer
...@@ -475,12 +518,8 @@ def is_bf16_compatible() -> bool: ...@@ -475,12 +518,8 @@ def is_bf16_compatible() -> bool:
check on device compute capability to enforce sm_80 or higher. check on device compute capability to enforce sm_80 or higher.
""" """
if IS_HIP_EXTENSION: if IS_HIP_EXTENSION:
# only MI200 and MI300 machines support bf16 # only these arch support bf16
if get_device_compute_capability() >= (9, 4) or is_mi200() or is_K100_AI() or is_BW(): return is_gfx928() or is_gfx936() or is_gfx938()
return True
else:
return False
else:
return torch.cuda.get_device_capability()[0] >= 8 return torch.cuda.get_device_capability()[0] >= 8
...@@ -515,7 +554,6 @@ def is_non_tn_fp8_gemm_supported(is_blockwise: Optional[bool] = False) -> bool: ...@@ -515,7 +554,6 @@ def is_non_tn_fp8_gemm_supported(is_blockwise: Optional[bool] = False) -> bool:
if IS_HIP_EXTENSION: if IS_HIP_EXTENSION:
if is_blockwise: if is_blockwise:
return False return False
else:
return True return True
device_capability = torch.cuda.get_device_capability() device_capability = torch.cuda.get_device_capability()
return (10, 0) <= device_capability < (12, 0) or device_capability >= (13, 0) return (10, 0) <= device_capability < (12, 0) or device_capability >= (13, 0)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment