Commit 065160ab authored by wenjh

Add int8 blockwise gemm test to float8


Signed-off-by: wenjh <wenjh@sugon.com>
parent 0c461880
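This commit lets the blockwise FP8 GEMM exactness test also exercise an INT8 path on ROCm: when int8_simulation_fp8 is enabled (the CI script below sets NVTE_INT8_SIM_FP8=1 for this test), the quantized operands are torch.int8 and the GEMM is dispatched to lightop assembly kernels when available, otherwise to the Triton w8a8_block_int8_matmul kernels, instead of tex.generic_gemm. As a rough illustration only, a pure-PyTorch reference of the block-scaled W8A8 contract those kernels implement could look like the sketch below; the helper name and the assumption of 1D-scaled int8 activations against 2D tile-scaled int8 weights with a 128-element block are illustrative, not part of this commit.

import torch

def ref_blockwise_w8a8_gemm(qx, qw, sx, sw, block_len=128, out_dtype=torch.bfloat16):
    # qx: (M, K) int8 activations, sx: (M, K // block_len) scales, one per 1 x block_len segment
    # qw: (N, K) int8 weights, sw: (N // block_len, K // block_len) scales, one per block_len x block_len tile
    assert qx.shape[1] % block_len == 0 and qw.shape[0] % block_len == 0
    x = qx.float() * sx.repeat_interleave(block_len, dim=1)
    w = qw.float() * sw.repeat_interleave(block_len, dim=0).repeat_interleave(block_len, dim=1)
    # 'TN' layout: both operands are row-major with K innermost, so y = x @ w.T
    return (x @ w.t()).to(out_dtype)

Nothing in this sketch comes from the repository; it is only meant to make the qx_data / qw_data / ref_scales_x / ref_scales_w arguments that appear in the diff easier to read.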
@@ -40,7 +40,7 @@ python3 -m pytest -v -s --junitxml=$XML_LOG_DIR/pytest_test_float8blockwisetenso
NVTE_INT8_SIM_FP8=1 python3 -m pytest -v -s test_float8_current_scaling_exact.py
NVTE_INT8_SIM_FP8=1 NVTE_INT8_SIM_FP8_TENSORWISE=1 python3 -m pytest -v -s test_float8_current_scaling_exact.py
python3 -m pytest -v -s --junitxml=$XML_LOG_DIR/pytest_test_float8_blockwise_scaling_exact.xml $TE_PATH/tests/pytorch/test_float8_blockwise_scaling_exact.py || test_fail "test_float8_blockwise_scaling_exact.py"
NVTE_INT8_SIM_FP8=1 python3 -m pytest -v -s --junitxml=$XML_LOG_DIR/pytest_test_float8_blockwise_gemm_exact.xml $TE_PATH/tests/pytorch/test_float8_blockwise_gemm_exact.py || test_fail "test_float8_blockwise_gemm_exact.py"
python3 $TE_PATH/tests/pytorch/test_int8_blockwise_gemm_exact.py
python3 -m pytest -v -s --junitxml=$XML_LOG_DIR/pytest_test_gqa.xml $TE_PATH/tests/pytorch/test_gqa.py || test_fail "test_gqa.py"
python3 -m pytest -v -s --junitxml=$XML_LOG_DIR/pytest_test_fused_optimizer.xml $TE_PATH/tests/pytorch/test_fused_optimizer.py || test_fail "test_fused_optimizer.py"
...
@@ -6,16 +6,23 @@ import pytest
import torch
import transformer_engine as te
import transformer_engine_torch as tex
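# Optional dependency: lightop ships hand-written W8A8 GEMM kernels for ROCm.
# Fall back silently when it is missing; use_lightop_w8a8 gates its use below.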
try:
    import lightop
except ImportError:
    pass
from transformer_engine.pytorch.utils import use_lightop_w8a8
from transformer_engine.pytorch.constants import TE_DType
from transformer_engine.pytorch.fp8 import (FP8GlobalStateManager, blockwise_fp8_block_len, int8_simulation_fp8)
from transformer_engine.pytorch.tensor.float8_blockwise_tensor import (
    Float8BlockQuantizer,
    Float8BlockwiseQTensor,
)
from transformer_engine.pytorch.triton.blockwise_int8_gemm_nt import w8a8_block_int8_matmul
from transformer_engine.pytorch.triton.blockwise_int8_gemm_nt_wgrad import w8a8_block_int8_matmul_wgrad
from references.blockwise_quantizer_reference import CuBLASScaleMunger
from references.blockwise_fp8_gemm_reference import CuBLASRefBlockwiseGemm
from torch.utils.cpp_extension import IS_HIP_EXTENSION
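# IS_HIP_EXTENSION is True on ROCm builds of PyTorch; the int8-simulation GEMM paths below only apply there.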
def fp8_blockwise_gemm_supported() -> bool:
    supported, _ = FP8GlobalStateManager.is_fp8_block_scaling_available()
@@ -45,6 +52,11 @@ def cublas_gemm_fp8_blockwise_case(
    atol: float = 0.0,
    rtol: float = 0.0
):
    if IS_HIP_EXTENSION and int8_simulation_fp8:
        if use_bias or use_gelu:
            pytest.skip("Bias and GELU not supported in int8 simulation mode on ROCm.")
        if not (
            (not x_columnwise and not w_columnwise and is_x_1d_scaled and not is_w_1d_scaled)
            or (not x_columnwise and w_columnwise and is_x_1d_scaled and not is_w_1d_scaled)
            or (x_columnwise and w_columnwise and is_x_1d_scaled and is_w_1d_scaled)
        ):
            pytest.skip("Only 1Dx2D, 1Dx1D, and 2Dx1D block scaling supported in int8 simulation mode on ROCm.")
    if x_dtype == torch.float8_e5m2 and w_dtype == torch.float8_e5m2:
        pytest.skip("FP8 GEMM doesn't support both a and b types being torch.float8_e5m2")
    if not (is_x_1d_scaled or is_w_1d_scaled):
@@ -157,27 +169,64 @@ def cublas_gemm_fp8_blockwise_case(
    aux_tensor_ref = aux_tensor.clone() if use_gelu else None
    bias_dtype = TE_DType[torch.bfloat16 if bias is None else bias.dtype]
    if IS_HIP_EXTENSION and int8_simulation_fp8:
        if use_lightop_w8a8([block_len, block_len]):
            if (not x_columnwise and not w_columnwise and is_x_1d_scaled and not is_w_1d_scaled):
                y = lightop.gemm_w8a8_asm(
                    qx_data, qw_data, ref_scales_x, ref_scales_w, [block_len, block_len], out_dtype, 'TN'
                )
            elif (not x_columnwise and w_columnwise and is_x_1d_scaled and not is_w_1d_scaled):
                y = lightop.gemm_w8a8_xgrad_asm(
                    qx_data, qw_data, ref_scales_x, ref_scales_w, [block_len, block_len], out_dtype, 'TN'
                )
            elif (x_columnwise and w_columnwise and is_x_1d_scaled and is_w_1d_scaled):
                y = lightop.gemm_w8a8_wgrad_asm(
                    qx_data, qw_data, ref_scales_x, ref_scales_w, out.clone() if accumulate else None,
                    accumulate, [block_len, block_len], out_dtype, 'TN'
                )
            else:
                assert False, "Only 1Dx2D, 1Dx1D, and 2Dx1D block scaling supported in int8 simulation mode on ROCm."
        else:
            if (not x_columnwise and not w_columnwise and is_x_1d_scaled and not is_w_1d_scaled):
                y, _ = w8a8_block_int8_matmul(
                    qx_data, qw_data, ref_scales_x, ref_scales_w, [block_len, block_len],
                    output_dtype=out_dtype
                )
            elif (not x_columnwise and w_columnwise and is_x_1d_scaled and not is_w_1d_scaled):
                y, _ = w8a8_block_int8_matmul(
                    qx_data, qw_data, ref_scales_x, ref_scales_w, [block_len, block_len],
                    output_dtype=out_dtype
                )
            elif (x_columnwise and w_columnwise and is_x_1d_scaled and is_w_1d_scaled):
                y, _ = w8a8_block_int8_matmul_wgrad(
                    qx_data, qw_data, ref_scales_x, ref_scales_w, out.clone() if accumulate else None,
                    accumulate, [block_len, block_len],
                    output_dtype=out_dtype
                )
            else:
                assert False, "Only 1Dx2D, 1Dx1D, and 2Dx1D block scaling supported in int8 simulation mode on ROCm."
    else:
        # cuBLAS GEMM
        # return type is out, bias_grad, gelu_input, extra_output
        # We are just capturing out.
        y = tex.generic_gemm(
            qw,
            transa,
            qx,
            transb,
            out.clone() if accumulate else None,
            out_quantizer,
            TE_DType[out_dtype],
            bias,
            bias_dtype,
            use_gelu,
            aux_tensor,
            use_grad,
            workspace,
            workspace.shape[0],
            accumulate,
            use_split_accumulator,
        )[0]
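    # y now holds the GEMM output from whichever backend ran above (lightop asm, Triton INT8
    # kernels, or cuBLAS via tex.generic_gemm); in the int8 paths only the wgrad variants
    # take the prior output and the accumulate flag.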
    # just in case of accumulation, make sure y_ref and y are not the same tensor
    assert y_ref is not y, "y_ref and y should not be the same tensor"
@@ -227,6 +276,8 @@ def cublas_gemm_test_constraint_enforced(
    expected_err_msg="CUBLAS_STATUS_NOT_SUPPORTED",
    expected_err_cls=RuntimeError
):
    if IS_HIP_EXTENSION:
        pytest.skip("ROCm does not support cuBLAS GEMM. No need to test constraint enforcement.")
    if not fp8_blockwise_gemm_supported():
        pytest.skip("CUDA version does not support blockwise FP8 gemm.")
    # Setup device and random seed
@@ -333,8 +384,8 @@ def cublas_gemm_test_constraint_enforced(
        (1024, 4096, 1024),
    ],
)
@pytest.mark.parametrize("x_dtype", [torch.float8_e4m3fn, torch.float8_e5m2] if not int8_simulation_fp8 else [torch.int8], ids=str)
@pytest.mark.parametrize("w_dtype", [torch.float8_e4m3fn, torch.float8_e5m2] if not int8_simulation_fp8 else [torch.int8], ids=str)
@pytest.mark.parametrize("out_dtype", [torch.bfloat16, torch.float32], ids=str) @pytest.mark.parametrize("out_dtype", [torch.bfloat16, torch.float32], ids=str)
@pytest.mark.parametrize("noise_type", ["normal"], ids=str) @pytest.mark.parametrize("noise_type", ["normal"], ids=str)
@pytest.mark.parametrize("x_magnitude", [1], ids=str) @pytest.mark.parametrize("x_magnitude", [1], ids=str)
...@@ -389,8 +440,8 @@ def test_cublas_gemm_fp8_blockwise_shape_varying( ...@@ -389,8 +440,8 @@ def test_cublas_gemm_fp8_blockwise_shape_varying(
(320, 256, 336), (320, 256, 336),
], ],
) )
@pytest.mark.parametrize("x_dtype", [torch.float8_e4m3fn, torch.float8_e5m2], ids=str) @pytest.mark.parametrize("x_dtype", [torch.float8_e4m3fn, torch.float8_e5m2] if not int8_simulation_fp8 else [torch.int8], ids=str)
@pytest.mark.parametrize("w_dtype", [torch.float8_e4m3fn, torch.float8_e5m2], ids=str) @pytest.mark.parametrize("w_dtype", [torch.float8_e4m3fn, torch.float8_e5m2] if not int8_simulation_fp8 else [torch.int8], ids=str)
@pytest.mark.parametrize("out_dtype", [torch.bfloat16, torch.float32], ids=str) @pytest.mark.parametrize("out_dtype", [torch.bfloat16, torch.float32], ids=str)
@pytest.mark.parametrize("noise_type", ["normal", "uniform"], ids=str) @pytest.mark.parametrize("noise_type", ["normal", "uniform"], ids=str)
@pytest.mark.parametrize("x_magnitude", [1e-28, 1, 1e3], ids=str) @pytest.mark.parametrize("x_magnitude", [1e-28, 1, 1e3], ids=str)
@@ -449,8 +500,8 @@ def test_cublas_gemm_fp8_blockwise_accumulate_magnitude_varying(
        (256, 256, 256),
    ],
)
@pytest.mark.parametrize("x_dtype", [torch.float8_e4m3fn, torch.float8_e5m2] if not int8_simulation_fp8 else [torch.int8], ids=str)
@pytest.mark.parametrize("w_dtype", [torch.float8_e4m3fn, torch.float8_e5m2] if not int8_simulation_fp8 else [torch.int8], ids=str)
@pytest.mark.parametrize("out_dtype", [torch.bfloat16, torch.float32], ids=str)
@pytest.mark.parametrize("noise_type", ["normal"], ids=str)
@pytest.mark.parametrize("x_magnitude", [1e-3], ids=str)
@@ -511,8 +562,8 @@ def test_cublas_gemm_fp8_blockwise_bias(
        (4096, 128, 4096),
    ],
)
@pytest.mark.parametrize("x_dtype", [torch.float8_e4m3fn, torch.float8_e5m2] if not int8_simulation_fp8 else [torch.int8], ids=str)
@pytest.mark.parametrize("w_dtype", [torch.float8_e4m3fn, torch.float8_e5m2] if not int8_simulation_fp8 else [torch.int8], ids=str)
@pytest.mark.parametrize("out_dtype", [torch.bfloat16, torch.float32], ids=str)
@pytest.mark.parametrize("noise_type", ["normal"], ids=str)
@pytest.mark.parametrize("x_magnitude", [1], ids=str)
@@ -584,8 +635,8 @@ def test_cublas_gemm_fp8_blockwise_columnwise(
        (256, 256, 256),
    ],
)
@pytest.mark.parametrize("x_dtype", [torch.float8_e4m3fn] if not int8_simulation_fp8 else [torch.int8], ids=str)
@pytest.mark.parametrize("w_dtype", [torch.float8_e4m3fn] if not int8_simulation_fp8 else [torch.int8], ids=str)
@pytest.mark.parametrize("out_dtype", [torch.bfloat16], ids=str)
@pytest.mark.parametrize("noise_type", ["normal"], ids=str)
@pytest.mark.parametrize("x_magnitude", [1], ids=str)
@@ -913,8 +964,8 @@ def test_illegal_2D_by_2D_enforced(
        (256, 128, 252, False, False),
    ],
)
@pytest.mark.parametrize("x_dtype", [torch.float8_e4m3fn] if not int8_simulation_fp8 else [torch.int8], ids=str)
@pytest.mark.parametrize("w_dtype", [torch.float8_e4m3fn] if not int8_simulation_fp8 else [torch.int8], ids=str)
@pytest.mark.parametrize("out_dtype", [torch.bfloat16], ids=str)
@pytest.mark.parametrize("accumulate", [False], ids=["no_accumulate"])
@pytest.mark.parametrize("use_split_accumulator", [True], ids=["split_acc"])
...