Commit 063ef88d authored by wenjh

Merge nv main up to v2.10.0.dev0


Signed-off-by: wenjh <wenjh@sugon.com>
parents 91670b05 5624dbb4
......@@ -18,11 +18,11 @@ from flax.training import train_state
import transformer_engine.jax as te
import transformer_engine.jax.flax as te_flax
from transformer_engine.jax.quantize import is_fp8_available, ScalingMode
from transformer_engine.jax.quantize import is_scaling_mode_supported, ScalingMode
DIR = str(Path(__file__).resolve().parents[1])
sys.path.append(str(DIR))
from encoder.common import is_bf16_supported, get_fp8_recipe_from_name_string
from encoder.common import is_bf16_supported, get_quantization_recipe_from_name_string
IMAGE_H = 28
IMAGE_W = 28
......@@ -189,12 +189,12 @@ def train_and_evaluate(args):
label_shape = [args.batch_size]
if args.use_fp8:
fp8_recipe = get_fp8_recipe_from_name_string(args.fp8_recipe)
fp8_recipe = get_quantization_recipe_from_name_string(args.fp8_recipe)
else:
fp8_recipe = None
with te.fp8_autocast(
enabled=args.use_fp8, fp8_recipe=fp8_recipe, mesh_resource=te.sharding.MeshResource()
with te.autocast(
enabled=args.use_fp8, recipe=fp8_recipe, mesh_resource=te.sharding.MeshResource()
):
cnn = Net(args.use_te)
var_collect = cnn.init(init_rngs, jnp.empty(input_shape, dtype=jnp.bfloat16))
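For readers tracking the rename, here is a minimal sketch of the new JAX spelling introduced in this merge (a hypothetical standalone snippet; the recipe name string is illustrative, and `encoder.common` is the example helper module touched above):

```python
# Sketch of the te.fp8_autocast() -> te.autocast() rename (transformer_engine.jax,
# v2.10.0.dev0 as shown in this diff). "DelayedScaling" is an illustrative name.
import transformer_engine.jax as te
from encoder.common import get_quantization_recipe_from_name_string

recipe = get_quantization_recipe_from_name_string("DelayedScaling")

# Old spelling: te.fp8_autocast(enabled=..., fp8_recipe=recipe, mesh_resource=...)
with te.autocast(enabled=True, recipe=recipe, mesh_resource=te.sharding.MeshResource()):
    pass  # initialize and run the model inside the quantization context
```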
......@@ -308,8 +308,8 @@ def mnist_parser(args):
class TestMNIST(unittest.TestCase):
"""MNIST unittests"""
is_fp8_supported, fp8_reason = is_fp8_available(ScalingMode.DELAYED_TENSOR_SCALING)
is_mxfp8_supported, mxfp8_reason = is_fp8_available(ScalingMode.MXFP8_1D_SCALING)
is_fp8_supported, fp8_reason = is_scaling_mode_supported(ScalingMode.DELAYED_TENSOR_SCALING)
is_mxfp8_supported, mxfp8_reason = is_scaling_mode_supported(ScalingMode.MXFP8_1D_SCALING)
@classmethod
def setUpClass(cls):
......
......@@ -68,7 +68,7 @@ def _parse_args(argv=None, namespace=None):
)
parser.add_argument("--seed", type=int, default=1234, help="RNG seed.")
parser.add_argument(
"--fp8", action="store_true", default=False, help="Enables the te.fp8_autocast() context."
"--fp8", action="store_true", default=False, help="Enables the te.autocast() context."
)
parser.add_argument(
"--no-comm-overlap",
......@@ -299,7 +299,7 @@ def _train(opts):
dist_print(" |-- Forward pass", group=tp_group, debug=True)
with torch.amp.autocast("cuda", dtype=torch.bfloat16):
with te.fp8_autocast(enabled=opts.fp8, fp8_recipe=fp8_recipe, fp8_group=nccl_world):
with te.autocast(enabled=opts.fp8, recipe=fp8_recipe, amax_reduction_group=nccl_world):
y = model(x)
if isinstance(y, tuple):
out, *_ = y
......
......@@ -49,5 +49,5 @@ $ torchrun --standalone --nnodes=1 --nproc-per-node=$(nvidia-smi -L | wc -l) fsd
# ...
```
**NOTE:** This example has `fp8_autocast()` enabled by default. To run on GPUs without Fp8 support
**NOTE:** This example has `autocast()` enabled by default. To run on GPUs without Fp8 support
(e.g.: A100), add the `--no-fp8` option to the commands shown above.
......@@ -173,7 +173,7 @@ def parse_fsdp_args():
"--no-fp8",
action="store_true",
default=False,
help="Disables the te.fp8_autocast() context.",
help="Disables the te.autocast() context.",
)
parser.add_argument(
"--no-defer-init",
......@@ -284,11 +284,11 @@ def train(opts):
dtype=opts.dtype,
device="cuda",
)
# fp8_autocast needs to be given the FSDP process group for amax reductions
with te.fp8_autocast(enabled=not opts.no_fp8, fp8_recipe=fp8_recipe, fp8_group=all_gpus):
# autocast needs to be given the FSDP process group for amax reductions
with te.autocast(enabled=not opts.no_fp8, recipe=fp8_recipe, amax_reduction_group=all_gpus):
y = te_model(x)
loss = y.sum()
# calculate gradient and take training step outside the fp8_autocast context
# calculate gradient and take training step outside the autocast context
loss.backward()
optim.step()
optim.zero_grad(set_to_none=True)
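The PyTorch entry point follows the same rename, with `fp8_recipe` becoming `recipe` and `fp8_group` becoming `amax_reduction_group`. A minimal single-GPU sketch under that assumption (module, shapes, and recipe choice are illustrative):

```python
# Sketch of te.fp8_autocast(fp8_recipe=..., fp8_group=...) becoming
# te.autocast(recipe=..., amax_reduction_group=...); the distributed group
# argument is omitted here since this sketch is single-process.
import torch
import transformer_engine.pytorch as te
from transformer_engine.common.recipe import DelayedScaling

model = te.Linear(128, 128).cuda()
x = torch.randn(64, 128, device="cuda")

with te.autocast(enabled=True, recipe=DelayedScaling()):
    y = model(x)
loss = y.sum()
loss.backward()  # backward and the optimizer step stay outside the autocast context
```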
......
......@@ -52,7 +52,7 @@ def train(args, model, device, train_loader, optimizer, epoch, use_fp8):
for batch_idx, (data, target) in enumerate(train_loader):
data, target = data.to(device), target.to(device)
optimizer.zero_grad()
with te.fp8_autocast(enabled=use_fp8):
with te.autocast(enabled=use_fp8):
output = model(data)
loss = F.nll_loss(output, target)
loss.backward()
......@@ -76,7 +76,7 @@ def calibrate(model, device, test_loader, fp8):
with torch.no_grad():
for data, target in test_loader:
data, target = data.to(device), target.to(device)
with te.fp8_autocast(enabled=fp8, calibrating=True):
with te.autocast(enabled=fp8, calibrating=True):
output = model(data)
......@@ -88,7 +88,7 @@ def test(model, device, test_loader, use_fp8):
with torch.no_grad():
for data, target in test_loader:
data, target = data.to(device), target.to(device)
with te.fp8_autocast(enabled=use_fp8):
with te.autocast(enabled=use_fp8):
output = model(data)
test_loss += F.nll_loss(output, target, reduction="sum").item() # sum up batch loss
pred = output.argmax(dim=1, keepdim=True) # get the index of the max log-probability
......
# Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# See LICENSE for license information.
[build-system]
requires = ["setuptools>=61.0", "cmake>=3.21", "wheel", "pybind11[global]", "ninja", "nvidia-mathdx==25.1.1", "pip", "torch>=2.1", "jax>=0.5.0", "flax>=0.7.1"]
# Use legacy backend to import local packages in setup.py
build-backend = "setuptools.build_meta:__legacy__"
......@@ -29,6 +29,10 @@ wait
python3 -m pytest -c $TE_PATH/tests/jax/pytest.ini -v --junitxml=$XML_LOG_DIR/pytest_test_model_parallel_encoder.xml $TE_PATH/examples/jax/encoder/test_model_parallel_encoder.py || test_fail "test_model_parallel_encoder.py"
wait
TE_PATH=$TE_PATH bash $TE_PATH/examples/jax/encoder/run_test_multiprocessing_encoder.sh || test_fail "run_test_multiprocessing_encoder.sh"
wait
TE_PATH=$TE_PATH bash $TE_PATH/examples/jax/collective_gemm/run_test_cgemm.sh || test_fail "run_test_cgemm.sh"
wait
if [ $RET -ne 0 ]; then
echo "Error: some sub-tests failed: $FAILED_CASES"
......
......@@ -36,7 +36,7 @@ export XLA_FLAGS="${XLA_FLAGS} --xla_gpu_deterministic_ops"
python3 -m pytest -c $TE_PATH/tests/jax/pytest.ini -v --junitxml=$XML_LOG_DIR/pytest_test_single_gpu_encoder.xml $TE_PATH/examples/jax/encoder/test_single_gpu_encoder.py || test_fail "test_single_gpu_encoder.py"
# Test without custom calls
export XLA_FLAGS="${XLA_FLAGS} --xla_gpu_deterministic_ops"
NVTE_JAX_CUSTOM_CALLS="false" python3 -m pytest -c $TE_PATH/tests/jax/pytest.ini -v --junitxml=$XML_LOG_DIR/pytest_test_single_gpu_encoder.xml $TE_PATH/examples/jax/encoder/test_single_gpu_encoder.py || test_fail "test_single_gpu_encoder.py without custom calls"
NVTE_JAX_CUSTOM_CALLS="false" python3 -m pytest -c $TE_PATH/tests/jax/pytest.ini -v --junitxml=$XML_LOG_DIR/pytest_test_single_gpu_encoder_without_custom_call.xml $TE_PATH/examples/jax/encoder/test_single_gpu_encoder.py || test_fail "test_single_gpu_encoder.py without custom calls"
if [ $RET -ne 0 ]; then
echo "Error: some sub-tests failed: $FAILED_CASES"
......
......@@ -7,6 +7,8 @@
: ${TE_PATH:=/opt/transformerengine}
: ${NVTE_TEST_NVINSPECT_FEATURE_DIRS:=$TE_PATH/transformer_engine/debug/features}
: ${NVTE_TEST_NVINSPECT_CONFIGS_DIR:=$TE_PATH/tests/pytorch/debug/test_configs/}
: ${XML_LOG_DIR:=/logs}
mkdir -p "$XML_LOG_DIR"
# Config with the dummy feature which prevents nvinspect from being disabled.
# Nvinspect will be disabled if no feature is active.
......@@ -20,17 +22,16 @@ pip uninstall -y nvdlfw-inspect
pip install git+https://github.com/NVIDIA/nvidia-dlfw-inspect.git
pip install pytest==8.2.1
pytest -v -s $TE_PATH/tests/pytorch/debug/test_sanity.py --feature_dirs=$NVTE_TEST_NVINSPECT_FEATURE_DIRS || FAIL=1
pytest -v -s $TE_PATH/tests/pytorch/debug/test_config.py --feature_dirs=$NVTE_TEST_NVINSPECT_FEATURE_DIRS || FAIL=1
pytest -v -s $TE_PATH/tests/pytorch/debug/test_numerics.py --feature_dirs=$NVTE_TEST_NVINSPECT_FEATURE_DIRS || FAIL=1
pytest -v -s $TE_PATH/tests/pytorch/debug/test_log.py --feature_dirs=$NVTE_TEST_NVINSPECT_FEATURE_DIRS --configs_dir=$NVTE_TEST_NVINSPECT_CONFIGS_DIR || FAIL=1
NVTE_TORCH_COMPILE=0 pytest -v -s $TE_PATH/tests/pytorch/debug/test_api_features.py --feature_dirs=$NVTE_TEST_NVINSPECT_FEATURE_DIRS --configs_dir=$NVTE_TEST_NVINSPECT_CONFIGS_DIR || FAIL=1
pytest -v -s $TE_PATH/tests/pytorch/debug/test_log.py --feature_dirs=$NVTE_TEST_NVINSPECT_FEATURE_DIRS --configs_dir=$NVTE_TEST_NVINSPECT_CONFIGS_DIR || FAIL=1
pytest -v -s $TE_PATH/tests/pytorch/debug/test_perf.py --feature_dirs=$NVTE_TEST_NVINSPECT_FEATURE_DIRS --configs_dir=$NVTE_TEST_NVINSPECT_CONFIGS_DIR || FAIL=1
pytest -v -s --junitxml=$XML_LOG_DIR/test_sanity.xml $TE_PATH/tests/pytorch/debug/test_sanity.py --feature_dirs=$NVTE_TEST_NVINSPECT_FEATURE_DIRS || FAIL=1
pytest -v -s --junitxml=$XML_LOG_DIR/test_config.xml $TE_PATH/tests/pytorch/debug/test_config.py --feature_dirs=$NVTE_TEST_NVINSPECT_FEATURE_DIRS || FAIL=1
pytest -v -s --junitxml=$XML_LOG_DIR/test_numerics.xml $TE_PATH/tests/pytorch/debug/test_numerics.py --feature_dirs=$NVTE_TEST_NVINSPECT_FEATURE_DIRS || FAIL=1
pytest -v -s --junitxml=$XML_LOG_DIR/test_log.xml $TE_PATH/tests/pytorch/debug/test_log.py --feature_dirs=$NVTE_TEST_NVINSPECT_FEATURE_DIRS --configs_dir=$NVTE_TEST_NVINSPECT_CONFIGS_DIR || FAIL=1
NVTE_TORCH_COMPILE=0 pytest -v -s --junitxml=$XML_LOG_DIR/test_api_features.xml $TE_PATH/tests/pytorch/debug/test_api_features.py --feature_dirs=$NVTE_TEST_NVINSPECT_FEATURE_DIRS --configs_dir=$NVTE_TEST_NVINSPECT_CONFIGS_DIR || FAIL=1
pytest -v -s --junitxml=$XML_LOG_DIR/test_perf.xml $TE_PATH/tests/pytorch/debug/test_perf.py --feature_dirs=$NVTE_TEST_NVINSPECT_FEATURE_DIRS --configs_dir=$NVTE_TEST_NVINSPECT_CONFIGS_DIR || FAIL=1
# standard sanity and numerics tests with initialized debug
NVTE_TEST_NVINSPECT_ENABLED=1 NVTE_TEST_NVINSPECT_CONFIG_FILE=$NVTE_TEST_NVINSPECT_DUMMY_CONFIG_FILE NVTE_TEST_NVINSPECT_FEATURE_DIRS=$NVTE_TEST_NVINSPECT_FEATURE_DIRS PYTORCH_JIT=0 NVTE_TORCH_COMPILE=0 NVTE_ALLOW_NONDETERMINISTIC_ALGO=0 pytest -v -s $TE_PATH/tests/pytorch/test_sanity.py || FAIL=1
NVTE_TEST_NVINSPECT_ENABLED=1 NVTE_TEST_NVINSPECT_CONFIG_FILE=$NVTE_TEST_NVINSPECT_DUMMY_CONFIG_FILE NVTE_TEST_NVINSPECT_FEATURE_DIRS=$NVTE_TEST_NVINSPECT_FEATURE_DIRS PYTORCH_JIT=0 NVTE_TORCH_COMPILE=0 NVTE_ALLOW_NONDETERMINISTIC_ALGO=0 pytest -v -s $TE_PATH/tests/pytorch/test_numerics.py || FAIL=1
NVTE_TEST_NVINSPECT_ENABLED=1 NVTE_TEST_NVINSPECT_CONFIG_FILE=$NVTE_TEST_NVINSPECT_DUMMY_CONFIG_FILE NVTE_TEST_NVINSPECT_FEATURE_DIRS=$NVTE_TEST_NVINSPECT_FEATURE_DIRS PYTORCH_JIT=0 NVTE_TORCH_COMPILE=0 NVTE_ALLOW_NONDETERMINISTIC_ALGO=0 pytest -v -s --junitxml=$XML_LOG_DIR/test_sanity_2.xml $TE_PATH/tests/pytorch/test_sanity.py || FAIL=1
NVTE_TEST_NVINSPECT_ENABLED=1 NVTE_TEST_NVINSPECT_CONFIG_FILE=$NVTE_TEST_NVINSPECT_DUMMY_CONFIG_FILE NVTE_TEST_NVINSPECT_FEATURE_DIRS=$NVTE_TEST_NVINSPECT_FEATURE_DIRS PYTORCH_JIT=0 NVTE_TORCH_COMPILE=0 NVTE_ALLOW_NONDETERMINISTIC_ALGO=0 pytest -v -s --junitxml=$XML_LOG_DIR/test_numerics_2.xml $TE_PATH/tests/pytorch/test_numerics.py || FAIL=1
exit $FAIL
......@@ -31,6 +31,7 @@ ROCBLAS_ATOMICS_MOD=0 HIPBLASLT_ATOMICS_MOD=0 PYTORCH_JIT=0 NVTE_TORCH_COMPILE=0
ROCBLAS_ATOMICS_MOD=0 HIPBLASLT_ATOMICS_MOD=0 PYTORCH_JIT=0 NVTE_TORCH_COMPILE=0 NVTE_ALLOW_NONDETERMINISTIC_ALGO=0 python3 -m pytest --tb=auto --junitxml=$XML_LOG_DIR/pytest_test_cuda_graphs.xml $TE_PATH/tests/pytorch/test_cuda_graphs.py || test_fail "test_cuda_graphs.py"
python3 -m pytest --tb=auto --junitxml=$XML_LOG_DIR/pytest_test_jit.xml $TE_PATH/tests/pytorch/test_jit.py || test_fail "test_jit.py"
python3 -m pytest --tb=auto --junitxml=$XML_LOG_DIR/pytest_test_fused_rope.xml $TE_PATH/tests/pytorch/test_fused_rope.py || test_fail "test_fused_rope.py"
python3 -m pytest --tb=auto --junitxml=$XML_LOG_DIR/pytest_test_nvfp4.xml $TE_PATH/tests/pytorch/nvfp4 || test_fail "test_nvfp4"
python3 -m pytest --tb=auto --junitxml=$XML_LOG_DIR/pytest_test_float8tensor.xml $TE_PATH/tests/pytorch/test_float8tensor.py || test_fail "test_float8tensor.py"
python3 -m pytest --tb=auto --junitxml=$XML_LOG_DIR/pytest_test_float8blockwisetensor.xml $TE_PATH/tests/pytorch/test_float8blockwisetensor.py || test_fail "test_float8blockwisetensor.py"
python3 -m pytest --tb=auto --junitxml=$XML_LOG_DIR/pytest_test_float8_blockwise_scaling_exact.xml $TE_PATH/tests/pytorch/test_float8_blockwise_scaling_exact.py || test_fail "test_float8_blockwise_scaling_exact.py"
......
......@@ -9,4 +9,4 @@ set -xe
mkdir -p "$XML_LOG_DIR"
NVTE_JAX_UNITTEST_LEVEL="L1" python3 -m pytest -c $TE_PATH/tests/jax/pytest.ini -v --junitxml=$XML_LOG_DIR/pytest.xml $TE_PATH/tests/jax/test_distributed_*
SCRIPT_NAME=test_multi_process_distributed_grouped_gemm.py bash $TE_PATH/tests/jax/multi_process_launch.sh
SCRIPT_NAME=$TE_PATH/tests/jax/test_multi_process_distributed_grouped_gemm.py bash $TE_PATH/tests/jax/multi_process_launch.sh
......@@ -30,6 +30,7 @@ pip3 install pytest==8.2.1 || error_exit "Failed to install pytest"
python3 -m pytest -v -s --junitxml=$XML_LOG_DIR/pytest_test_sanity.xml $TE_PATH/tests/pytorch/distributed/test_sanity.py || test_fail "test_sanity.py"
python3 -m pytest -v -s --junitxml=$XML_LOG_DIR/pytest_test_numerics.xml $TE_PATH/tests/pytorch/distributed/test_numerics.py || test_fail "test_numerics.py"
python3 -m pytest -v -s --junitxml=$XML_LOG_DIR/pytest_test_numerics_exact.xml $TE_PATH/tests/pytorch/distributed/test_numerics_exact.py || test_fail "test_numerics_exact.py"
python3 -m pytest -v -s --junitxml=$XML_LOG_DIR/pytest_test_fusible_ops.xml $TE_PATH/tests/pytorch/distributed/test_fusible_ops.py || test_fail "test_fusible_ops.py"
python3 -m pytest -v -s --junitxml=$XML_LOG_DIR/pytest_test_torch_fsdp2.xml $TE_PATH/tests/pytorch/distributed/test_torch_fsdp2.py || test_fail "test_torch_fsdp2.py"
python3 -m pytest -v -s --log-cli-level=INFO --junitxml=$XML_LOG_DIR/pytest_test_comm_gemm_overlap.xml $TE_PATH/tests/pytorch/distributed/test_comm_gemm_overlap.py || test_fail "test_comm_gemm_overlap.py"
......@@ -47,9 +48,9 @@ python3 -m pytest -v -s --junitxml=$XML_LOG_DIR/pytest_test_cast_master_weights_
: ${NVTE_TEST_NVINSPECT_DUMMY_CONFIG_FILE:=$TE_PATH/tests/pytorch/debug/test_configs/dummy_feature.yaml}
: ${NVTE_TEST_NVINSPECT_FEATURE_DIRS:=$TE_PATH/transformer_engine/debug/features}
pytest -v -s $TE_PATH/tests/pytorch/debug/test_distributed.py --feature_dirs=$NVTE_TEST_NVINSPECT_FEATURE_DIRS || test_fail "debug test_distributed.py"
pytest -v -s --junitxml=$XML_LOG_DIR/pytest_test_distributed.xml $TE_PATH/tests/pytorch/debug/test_distributed.py --feature_dirs=$NVTE_TEST_NVINSPECT_FEATURE_DIRS || test_fail "debug test_distributed.py"
# standard numerics tests with initialized debug
NVTE_TEST_NVINSPECT_ENABLED=True NVTE_TEST_NVINSPECT_CONFIG_FILE=$NVTE_TEST_NVINSPECT_DUMMY_CONFIG_FILE NVTE_TEST_NVINSPECT_FEATURE_DIRS=$NVTE_TEST_NVINSPECT_FEATURE_DIRS pytest -v -s $TE_PATH/tests/pytorch/distributed/test_numerics.py || test_fail "debug test_numerics.py"
NVTE_TEST_NVINSPECT_ENABLED=True NVTE_TEST_NVINSPECT_CONFIG_FILE=$NVTE_TEST_NVINSPECT_DUMMY_CONFIG_FILE NVTE_TEST_NVINSPECT_FEATURE_DIRS=$NVTE_TEST_NVINSPECT_FEATURE_DIRS pytest -v -s --junitxml=$XML_LOG_DIR/pytest_test_numerics_2.xml $TE_PATH/tests/pytorch/distributed/test_numerics.py || test_fail "debug test_numerics.py"
if [ "$RET" -ne 0 ]; then
echo "Error in the following test cases:$FAILED_CASES"
......
......@@ -3,9 +3,11 @@
# See LICENSE for license information.
pip3 install onnxruntime==1.20.1
pip3 install onnxruntime_extensions==0.13.0
pip3 install onnxruntime
pip3 install onnxruntime_extensions
: ${TE_PATH:=/opt/transformerengine}
: ${XML_LOG_DIR:=/logs}
mkdir -p "$XML_LOG_DIR"
python3 -m pytest --tb=auto $TE_PATH/tests/pytorch/test_onnx_export.py
python3 -m pytest --tb=auto --junitxml=$XML_LOG_DIR/test_onnx_export.xml $TE_PATH/tests/pytorch/test_onnx_export.py
......@@ -23,6 +23,7 @@ from build_tools.utils import (
cuda_version,
get_frameworks,
remove_dups,
min_python_version_str,
)
frameworks = get_frameworks()
......@@ -211,7 +212,7 @@ if __name__ == "__main__":
long_description_content_type="text/x-rst",
ext_modules=ext_modules,
cmdclass={"build_ext": CMakeBuildExtension, "bdist_wheel": TimedBdist},
python_requires=">=3.8",
python_requires=f">={min_python_version_str()}",
classifiers=["Programming Language :: Python :: 3"],
install_requires=install_requires,
license_files=("LICENSE",),
......
......@@ -66,6 +66,13 @@ else()
add_executable(test_operator ${test_hip_sources})
endif()
# Add profiling and debug flags for CUDA compilation
set(CMAKE_CUDA_FLAGS "${CMAKE_CUDA_FLAGS} -lineinfo") # Generate line info for device code
set(CMAKE_CUDA_FLAGS "${CMAKE_CUDA_FLAGS} -g") # Add debug symbols for host code
set(CMAKE_CUDA_FLAGS "${CMAKE_CUDA_FLAGS} --ptxas-options=-v") # Add info about registers usage
# Note: Using -lineinfo instead of -G to avoid conflicts and get line mapping
# Find required packages
find_package(OpenMP REQUIRED)
if(USE_CUDA)
list(APPEND test_operator_LINKER_LIBS CUDA::cudart GTest::gtest_main ${TE_LIB} CUDA::nvrtc CUDNN::cudnn)
......
......@@ -529,6 +529,12 @@ TEST_P(FusedCastFloat8BlockwiseTestSuite, TestFusedCastFloat8Blockwise) {
q_opts.amax_epsilon = eps;
q_opts.block_scaling_dim = 2u;
// On Blackwell and newer, the FP8 block scaling recipe is emulated with MXFP8,
// which requires using power of two scaling factors. Skip unsupported tests.
if (getDeviceComputeCapability() >= blackwellComputeCapability && !force_pow_2) {
GTEST_SKIP();
}
if (colwise && matrix_size.size() < 2) {
// test_common Tensor initialization code does not
// handle this case.
......@@ -580,6 +586,12 @@ TEST_P(FusedCastFloat8VectorwiseTestSuite, TestFusedCastFloat8Vectorwise) {
q_opts.amax_epsilon = eps;
q_opts.block_scaling_dim = 1u;
// On Blackwell and newer, the FP8 block scaling recipe is emulated with MXFP8,
// which requires using power of two scaling factors. Skip unsupported tests.
if (getDeviceComputeCapability() >= blackwellComputeCapability && !force_pow_2) {
GTEST_SKIP();
}
if (colwise && matrix_size.size() < 2) {
// test_common Tensor initialization code does not
// handle this case.
......
......@@ -81,6 +81,7 @@ void compute_ref(const ProcessingMethod processing_method,
// Cache computations
for (size_t i = i_min; i < i_max; ++i) {
for (size_t j = j_min; j < j_max; ++j) {
const size_t idx = i * cols + j;
const size_t cache_idx = (i - i_min) * tile_size_X + (j - j_min);
......@@ -310,7 +311,8 @@ void performTest_x1(const ProcessingMethod processing_method,
const double rel_tolerable_mismatches_limit = 0.0;
size_t mismatches_scales = 0;
compare_e8m0_scaling_factors("scales", gpu_scales_ptr, ref_output_scales.get(),
compare_scaling_factors("scales", gpu_scales_ptr, ref_output_scales.get(),
unpadded_blocks_Y, unpadded_blocks_X, scales_stride,
mismatches_scales,
scale_diff_abs_tolerance,
......@@ -481,7 +483,7 @@ void performTest_x2(const ProcessingMethod processing_method,
const double rel_tolerable_mismatches_limit = 0.0;
size_t mismatches_scales_rowwise = 0;
compare_e8m0_scaling_factors("scales_rowwise", output.rowwise_cpu_scale_inv_ptr<fp8e8m0>(),
compare_scaling_factors("scales_rowwise", output.rowwise_cpu_scale_inv_ptr<fp8e8m0>(),
ref_scales_rowwise.get(), unpadded_blocks_Y_rowwise,
unpadded_blocks_X_rowwise, scales_stride_rowwise,
mismatches_scales_rowwise,
......@@ -490,7 +492,7 @@ void performTest_x2(const ProcessingMethod processing_method,
rel_tolerable_mismatches_limit);
size_t mismatches_scales_colwise = 0;
compare_e8m0_scaling_factors("scales_colwise", output.columnwise_cpu_scale_inv_ptr<fp8e8m0>(),
compare_scaling_factors("scales_colwise", output.columnwise_cpu_scale_inv_ptr<fp8e8m0>(),
ref_scales_colwise.get(), unpadded_blocks_Y_colwise,
unpadded_blocks_X_colwise, scales_stride_colwise,
mismatches_scales_colwise,
......
......@@ -267,19 +267,20 @@ void performTest_x1(const size_t rows,
? output.rowwise_cpu_scale_inv_ptr<fp8e8m0>()
: output.columnwise_cpu_scale_inv_ptr<fp8e8m0>();
if (rowwise) {
compare_e8m0_scaling_factors("rowwise scales", gpu_scales_ptr, ref_output_scales.get(),
compare_scaling_factors("rowwise scales", gpu_scales_ptr, ref_output_scales.get(),
unpadded_blocks_Y, unpadded_blocks_X, scales_stride,
mismatches_scales,
scale_diff_abs_tolerance,
abs_tolerable_mismatches_limit,
rel_tolerable_mismatches_limit);
} else {
compare_e8m0_scaling_factors("colwise scales", gpu_scales_ptr, ref_output_scales.get(),
compare_scaling_factors("colwise scales", gpu_scales_ptr, ref_output_scales.get(),
unpadded_blocks_Y, unpadded_blocks_X, scales_stride,
mismatches_scales,
scale_diff_abs_tolerance,
abs_tolerable_mismatches_limit,
rel_tolerable_mismatches_limit);
}
const size_t mismatches_elts = 32 * mismatches_scales;
......@@ -378,7 +379,7 @@ void performTest_x2(const size_t rows,
const double rel_tolerable_mismatches_limit = 1.0e-4;
size_t mismatches_scales_rowwise = 0;
compare_e8m0_scaling_factors("scales_rowwise", output.rowwise_cpu_scale_inv_ptr<fp8e8m0>(),
compare_scaling_factors("scales_rowwise", output.rowwise_cpu_scale_inv_ptr<fp8e8m0>(),
ref_scales_rowwise.get(), unpadded_blocks_Y_rowwise,
unpadded_blocks_X_rowwise, scales_stride_rowwise,
mismatches_scales_rowwise,
......@@ -386,7 +387,7 @@ void performTest_x2(const size_t rows,
abs_tolerable_mismatches_limit,
rel_tolerable_mismatches_limit);
size_t mismatches_scales_colwise = 0;
compare_e8m0_scaling_factors("scales_colwise", output.columnwise_cpu_scale_inv_ptr<fp8e8m0>(),
compare_scaling_factors("scales_colwise", output.columnwise_cpu_scale_inv_ptr<fp8e8m0>(),
ref_scales_colwise.get(), unpadded_blocks_Y_colwise,
unpadded_blocks_X_colwise, scales_stride_colwise,
mismatches_scales_colwise,
......@@ -394,6 +395,7 @@ void performTest_x2(const size_t rows,
abs_tolerable_mismatches_limit,
rel_tolerable_mismatches_limit);
const size_t mismatches_elts_rowwise = 32 * mismatches_scales_rowwise;
const size_t mismatches_elts_colwise = 32 * mismatches_scales_colwise;
......
/*************************************************************************
* Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
*
* See LICENSE for license information.
************************************************************************/
#include <cuda_bf16.h>
#include <cuda_fp8.h>
#include <cuda_fp4.h>
#include <cuda_runtime.h>
#include <gtest/gtest.h>
#include <transformer_engine/cast.h>
#include <transformer_engine/activation.h>
#include "../test_common.h"
#include "transformer_engine/transformer_engine.h"
#include <fstream>
using namespace transformer_engine;
using namespace test;
namespace {
enum ActivationType {
Identity,
GeLU,
SiLU,
ReLU,
QGeLU,
SReLU
};
double2 cvt_fp4x2_to_double2(fp4e2m1x2 fp4_pair) {
const __half2_raw raw_truncated_to_fp4e2m1_pair =
__nv_cvt_fp4x2_to_halfraw2(*reinterpret_cast<__nv_fp4x2_storage_t*>(&fp4_pair), __NV_E2M1);
const __half2 truncated_to_fp4e2m1_pair(raw_truncated_to_fp4e2m1_pair);
const double truncated_to_fp4e2m1_x = static_cast<double>(truncated_to_fp4e2m1_pair.x);
const double truncated_to_fp4e2m1_y = static_cast<double>(truncated_to_fp4e2m1_pair.y);
return {truncated_to_fp4e2m1_x, truncated_to_fp4e2m1_y};
}
template <typename InputType>
std::vector<InputType> create_transpose(const InputType* const input, const size_t rows, size_t cols) {
std::vector<InputType> input_t(cols * rows);
for (size_t i = 0; i < rows; ++i) {
for (size_t j = 0; j < cols; ++j) {
const size_t idx = i * cols + j;
const size_t idx_t = j * rows + i;
input_t[idx_t] = input[idx];
}
}
return input_t;
}
// Compute the global encode scale factor for a given global amax
float compute_global_encode_scaling_factor_FP4(const float global_amax) {
constexpr float fp8_max = 448.0f; // max representable FP8 E4M3 magnitude
constexpr float fp4_max = 6.0f; // max representable FP4 E2M1 magnitude
float global_encode_scale = fp8_max * fp4_max / global_amax;
// If scale is infinity, return max value of float32
global_encode_scale = fminf(global_encode_scale, Numeric_Traits<float>::maxNorm);
// If global amax is 0 or infinity, return 1
if (global_amax == 0.0f || global_encode_scale == 0.0f) {
return 1.0f;
}
return global_encode_scale;
}
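As a quick sanity check of the helper above, a Python sketch of the same arithmetic (the float32-max clamp stands in for `Numeric_Traits<float>::maxNorm`):

```python
# Mirrors compute_global_encode_scaling_factor_FP4: S_enc = (448 * 6) / amax,
# clamped to float32 max, falling back to 1.0 for zero or infinite amax.
import math

FP8_MAX = 448.0  # largest FP8 E4M3 magnitude
FP4_MAX = 6.0    # largest FP4 E2M1 magnitude

def global_encode_scale_fp4(global_amax: float) -> float:
    scale = FP8_MAX * FP4_MAX / global_amax if global_amax else math.inf
    scale = min(scale, 3.4028235e38)  # clamp, as with Numeric_Traits<float>::maxNorm
    if global_amax == 0.0 or scale == 0.0:
        return 1.0
    return scale

print(global_encode_scale_fp4(672.0))  # 2688 / 672 = 4.0
```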
// 1D Scaling: Original implementation with 1x16 blocks
template <typename InputType>
void quantize_nvfp4_1d(float (*OP)(const float),
const InputType* const input,
fp4e2m1x2* const output,
fp8e4m3* const scales,
const size_t rows,
const size_t cols,
const size_t scales_stride,
const float global_amax) {
// Compute a global encoding/decoding scaling factor for all S_dec_b
const float S_enc = compute_global_encode_scaling_factor_FP4(global_amax);
constexpr size_t block_size_X = 16;
const size_t blocks_X = divide_round_up(cols, block_size_X);
std::array<float, block_size_X> cache_buffer;
for (size_t i = 0; i < block_size_X; ++i) {
cache_buffer[i] = 0.0f;
}
for (size_t i = 0; i < rows; ++i) {
for (size_t block_X = 0; block_X < blocks_X; ++block_X) {
const size_t j_min = block_X * block_size_X;
const size_t j_max = j_min + block_size_X;
// Find block amax
float block_amax = 0.0f;
for (size_t j = j_min; j < j_max; ++j) {
const size_t idx = i * cols + j;
const size_t cache_idx = j - j_min;
const float input_elt = static_cast<float>(input[idx]);
const float act_elt = OP(input_elt);
// Numerical truncation: after downcast to InputType (BF16/FP16), upcast it back to FP32
const float elt = static_cast<float>(static_cast<InputType>(act_elt));
cache_buffer[cache_idx] = elt;
block_amax = std::max(block_amax, std::abs(elt));
}
// Compute the per-block encoding/decoding scaling factor (stored as E4M3)
const float S_dec_b = block_amax / 6.0f;
// Scale & Store per-block decoding scaling factor
const float S_dec_b_fp8 = S_dec_b * S_enc;
// Compute "correct" per-block encoding scaling factor
const float S_enc_b_fp8 = S_dec_b_fp8 == 0 ? 0.f : S_enc / S_dec_b_fp8;
const size_t scale_idx = i * scales_stride + block_X;
scales[scale_idx] = static_cast<fp8e4m3>(S_dec_b_fp8);
const float scale_reciprocal = S_enc_b_fp8;
for (size_t j = j_min; j < j_max; j += 2) {
const int idx_pair = (i * cols + j) / 2;
const int cache_idx_x = j - j_min;
const int cache_idx_y = cache_idx_x + 1;
const float cached_x = cache_buffer[cache_idx_x];
const float cached_y = cache_buffer[cache_idx_y];
const float scaled_elt_x = cached_x * scale_reciprocal;
const float scaled_elt_y = cached_y * scale_reciprocal;
const float2 scaled_elt_pair = {scaled_elt_x, scaled_elt_y};
fp4e2m1x2 casted_to_e2m1_pair(scaled_elt_pair);
output[idx_pair] = casted_to_e2m1_pair;
// const double2 truncated_pair = cvt_fp4x2_to_double2(casted_to_e2m1_pair);
}
}
}
}
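A NumPy sketch of the per-block scale math used in `quantize_nvfp4_1d` above, for one row and 1x16 blocks (the E4M3 rounding of the stored scale is omitted for brevity):

```python
# Per 1x16 block: S_dec_b = block_amax / 6; the stored scale is S_dec_b * S_enc,
# and elements are multiplied by S_enc / S_dec_b_fp8 before the FP4 cast.
import numpy as np

def nvfp4_1d_block_scales(row: np.ndarray, s_enc: float, block: int = 16):
    scales = []
    scaled = np.empty_like(row, dtype=np.float32)
    for j0 in range(0, row.size, block):
        blk = row[j0:j0 + block].astype(np.float32)
        s_dec_b_fp8 = (np.abs(blk).max() / 6.0) * s_enc  # stored per-block scale
        s_enc_b = 0.0 if s_dec_b_fp8 == 0 else s_enc / s_dec_b_fp8
        scaled[j0:j0 + block] = blk * s_enc_b  # values that feed the FP4 cast
        scales.append(s_dec_b_fp8)
    return np.array(scales, dtype=np.float32), scaled

row = np.random.randn(64).astype(np.float32)
scales, scaled = nvfp4_1d_block_scales(row, s_enc=1.0)  # 4 scales for 64 elements
```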
// Compute 2D mathematical scaling factors (8x8 for 128x128 input)
template <typename InputType>
void compute_2d_mathematical_scales(float (*OP)(const float),
const InputType* const input,
const size_t rows,
const size_t cols,
const float global_amax,
std::vector<std::vector<fp8e4m3>>& math_scales) {
const float S_enc = compute_global_encode_scaling_factor_FP4(global_amax);
constexpr size_t block_size_Y = 16;
constexpr size_t block_size_X = 16;
const size_t blocks_Y = divide_round_up(rows, block_size_Y);
const size_t blocks_X = divide_round_up(cols, block_size_X);
math_scales.resize(blocks_Y, std::vector<fp8e4m3>(blocks_X));
for (size_t block_Y = 0; block_Y < blocks_Y; ++block_Y) {
for (size_t block_X = 0; block_X < blocks_X; ++block_X) {
const size_t i_min = block_Y * block_size_Y;
const size_t i_max = std::min(i_min + block_size_Y, rows);
const size_t j_min = block_X * block_size_X;
const size_t j_max = std::min(j_min + block_size_X, cols);
// Find 2D block amax over entire 16x16 region
float block_amax = 0.0f;
for (size_t i = i_min; i < i_max; ++i) {
for (size_t j = j_min; j < j_max; ++j) {
const size_t idx = i * cols + j;
const float input_elt = static_cast<float>(input[idx]);
const float act_elt = OP(input_elt);
const float elt = static_cast<float>(static_cast<InputType>(act_elt));
block_amax = std::max(block_amax, std::abs(elt));
}
}
// Compute E4M3 scaling factor for this 16x16 block
const float S_dec_b = block_amax / 6.0f;
const fp8e4m3 S_dec_b_fp8 = static_cast<fp8e4m3>(S_dec_b * S_enc);
math_scales[block_Y][block_X] = S_dec_b_fp8;
}
}
}
// 2D Scaling: NEW implementation with proper replication
template <typename InputType>
void quantize_nvfp4_2d(float (*OP)(const float),
const InputType* const input,
fp4e2m1x2* const output,
fp8e4m3* const scales,
const size_t rows,
const size_t cols,
const size_t scales_stride,
const float global_amax) {
// Step 1: Compute mathematical 8x8 scaling factors
std::vector<std::vector<fp8e4m3>> math_scales;
compute_2d_mathematical_scales(OP, input, rows, cols, global_amax, math_scales);
const float S_enc = compute_global_encode_scaling_factor_FP4(global_amax);
constexpr size_t block_size_Y = 16;
constexpr size_t block_size_X = 16;
const size_t blocks_Y = divide_round_up(rows, block_size_Y);
const size_t blocks_X = divide_round_up(cols, block_size_X);
// Step 2: Replicate scaling factors row-wise (128×8 storage) - only if scales is not nullptr
if (scales != nullptr) {
// Each of the 128 rows gets scaling factors from its corresponding 16×16 block
for (size_t i = 0; i < rows; ++i) {
const size_t block_Y = i / block_size_Y;
for (size_t block_X = 0; block_X < blocks_X; ++block_X) {
const size_t scale_idx = i * scales_stride + block_X;
scales[scale_idx] = math_scales[block_Y][block_X];
}
}
}
// Step 3: Apply quantization using the mathematical scaling factors
std::array<std::array<float, block_size_X>, block_size_Y> cache_buffer;
for (size_t block_Y = 0; block_Y < blocks_Y; ++block_Y) {
for (size_t block_X = 0; block_X < blocks_X; ++block_X) {
const size_t i_min = block_Y * block_size_Y;
const size_t i_max = std::min(i_min + block_size_Y, rows);
const size_t j_min = block_X * block_size_X;
const size_t j_max = std::min(j_min + block_size_X, cols);
// Get the scaling factor for this block
const float S_dec_b_fp8 = static_cast<float>(math_scales[block_Y][block_X]);
const float S_enc_b_fp8 = S_dec_b_fp8 == 0 ? 0.f : S_enc / S_dec_b_fp8;
const float scale_reciprocal = S_enc_b_fp8;
// Process and cache data for this 16x16 block
for (size_t i = i_min; i < i_max; ++i) {
for (size_t j = j_min; j < j_max; ++j) {
const size_t idx = i * cols + j;
const size_t cache_idx_y = i - i_min;
const size_t cache_idx_x = j - j_min;
const float input_elt = static_cast<float>(input[idx]);
const float act_elt = OP(input_elt);
const float elt = static_cast<float>(static_cast<InputType>(act_elt));
cache_buffer[cache_idx_y][cache_idx_x] = elt;
}
}
// Apply scaling to all elements in this 16x16 block
for (size_t i = i_min; i < i_max; ++i) {
for (size_t j = j_min; j < j_max; j += 2) {
const int idx_pair = (i * cols + j) / 2;
const size_t cache_idx_y = i - i_min;
const size_t cache_idx_x1 = j - j_min;
const size_t cache_idx_x2 = std::min(cache_idx_x1 + 1, block_size_X - 1);
const float cached_x = cache_buffer[cache_idx_y][cache_idx_x1];
const float cached_y = ((j + 1) < j_max && cache_idx_x2 < block_size_X) ?
cache_buffer[cache_idx_y][cache_idx_x2] : 0.0f;
const float scaled_elt_x = cached_x * scale_reciprocal;
const float scaled_elt_y = cached_y * scale_reciprocal;
const float2 scaled_elt_pair = {scaled_elt_x, scaled_elt_y};
fp4e2m1x2 casted_to_e2m1_pair(scaled_elt_pair);
output[idx_pair] = casted_to_e2m1_pair;
}
}
}
}
}
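The key difference from the 1D path is the scale layout: amaxes are taken over 16x16 blocks, and each row of a block stores a copy of that block's scale. A NumPy sketch of just the scale computation and replication (E4M3 rounding again omitted):

```python
# For a 128x128 input: 8x8 "mathematical" scales, replicated row-wise into
# 128x8 storage, matching compute_2d_mathematical_scales + Step 2 above.
import numpy as np

def nvfp4_2d_scales(x: np.ndarray, s_enc: float, block: int = 16):
    rows, cols = x.shape  # assumed divisible by `block` in this sketch
    by, bx = rows // block, cols // block
    math_scales = np.empty((by, bx), dtype=np.float32)
    for i in range(by):
        for j in range(bx):
            amax = np.abs(x[i*block:(i+1)*block, j*block:(j+1)*block]).max()
            math_scales[i, j] = (amax / 6.0) * s_enc
    # Row-wise replication: row r stores the scales of block-row r // block.
    return np.repeat(math_scales, block, axis=0)

x = np.random.randn(128, 128).astype(np.float32)
print(nvfp4_2d_scales(x, s_enc=1.0).shape)  # (128, 8)
```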
// Wrapper function that calls appropriate implementation based on 2D flag
template <typename InputType>
void quantize_nvfp4(float (*OP)(const float),
const InputType* const input,
fp4e2m1x2* const output,
fp8e4m3* const scales,
const size_t rows,
const size_t cols,
const size_t scales_stride,
const float global_amax,
const bool use_2d_quantization = false) {
if (use_2d_quantization) {
quantize_nvfp4_2d(OP, input, output, scales, rows, cols, scales_stride, global_amax);
} else {
quantize_nvfp4_1d(OP, input, output, scales, rows, cols, scales_stride, global_amax);
}
}
template <typename InputType>
void compute_ref(float (*OP)(const float),
const InputType* input,
fp4e2m1x2* output,
fp4e2m1x2* output_t,
fp8e4m3* scales,
fp8e4m3* scales_t,
const float global_amax,
const size_t rows,
const size_t cols,
const size_t scales_stride,
const size_t scales_stride_t,
const bool use_2d_quantization = false)
{
std::vector<InputType> input_t = create_transpose(input, rows, cols);
if (use_2d_quantization) {
// Step 1: Compute mathematical 8×8 scaling factors
std::vector<std::vector<fp8e4m3>> math_scales;
compute_2d_mathematical_scales(OP, input, rows, cols, global_amax, math_scales);
constexpr size_t block_size_Y = 16;
constexpr size_t block_size_X = 16;
const size_t blocks_Y = divide_round_up(rows, block_size_Y);
const size_t blocks_X = divide_round_up(cols, block_size_X);
// Step 2: Generate scales (128×8) by replicating row-wise
for (size_t i = 0; i < rows; ++i) {
const size_t block_Y = i / block_size_Y;
for (size_t block_X = 0; block_X < blocks_X; ++block_X) {
const size_t scale_idx = i * scales_stride + block_X;
scales[scale_idx] = math_scales[block_Y][block_X];
}
}
// Step 3: Generate scales_t (128×8) with proper transposed block mapping
for (size_t i = 0; i < cols; ++i) { // cols = 128, which becomes rows of transposed data
const size_t block_X_orig = i / block_size_X; // i was column index in original, so maps to block_X
for (size_t block_Y_new = 0; block_Y_new < blocks_Y; ++block_Y_new) { // block in transposed coordinate
const size_t scale_idx = i * scales_stride_t + block_Y_new;
scales_t[scale_idx] = math_scales[block_Y_new][block_X_orig];
}
}
// Step 4: Process quantized outputs using the same algorithm as quantize_nvfp4_2d
// (This part processes the actual FP4 data using the mathematical scaling factors)
quantize_nvfp4_2d(OP, input, output, nullptr, rows, cols, scales_stride, global_amax); // scales already filled
quantize_nvfp4_2d(OP, input_t.data(), output_t, nullptr, cols, rows, scales_stride_t, global_amax); // scales_t already filled
} else {
quantize_nvfp4(OP, input, output, scales, rows, cols, scales_stride, global_amax, use_2d_quantization);
quantize_nvfp4(OP, input_t.data(), output_t, scales_t, cols, rows, scales_stride_t, global_amax, use_2d_quantization);
}
}
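Step 3 above is the subtle part: row `i` of the transposed tensor was column `i` of the original, so its scales come from the original block-column `i // 16`. A NumPy sketch of that mapping (dimensions assumed divisible by 16):

```python
# scales_t[i][bY] = math_scales[bY][i // block]: each transposed row reuses the
# scale column of the original block it came from.
import numpy as np

def transpose_2d_scales(math_scales: np.ndarray, cols: int, block: int = 16):
    by, bx = math_scales.shape
    scales_t = np.empty((cols, by), dtype=math_scales.dtype)
    for i in range(cols):  # i indexes rows of the transposed data
        scales_t[i, :] = math_scales[:, i // block]
    return scales_t

ms = np.arange(64, dtype=np.float32).reshape(8, 8)
print(transpose_2d_scales(ms, cols=128).shape)  # (128, 8)
```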
void compare_nvfp4_tensors(const std::string& name,
const fp4e2m1 *test_data, const fp4e2m1 *ref_data,
const int rows, const int cols,
double atol = 1e-5, double rtol = 1e-8) {
std::vector<std::string> mismatch_messages;
size_t total_mismatches = 0;
for (int i = 0; i < rows; ++i) {
for (int j = 0; j < cols; j += 2) {
const int idx = i * cols + j;
double2 test_data_pair = cvt_fp4x2_to_double2(*reinterpret_cast<const fp4e2m1x2*>(&test_data[idx/2]));
double2 ref_data_pair = cvt_fp4x2_to_double2(*reinterpret_cast<const fp4e2m1x2*>(&ref_data[idx/2]));
for (int k = 0; k < 2; ++k) {
const double t = (k == 0 ? test_data_pair.x : test_data_pair.y);
const double r = (k == 0 ? ref_data_pair.x : ref_data_pair.y);
bool mismatch = fabs(t - r) > atol && (r == 0 || fabs((t - r) / r) > rtol);
/* For Float32 the floating point comparison is enough to error out */
bool assertion = false;
if (mismatch && !assertion) {
/* Check if it is just a failure of round to nearest choosing different
side of the real value */
const double mean = (t + r) / 2;
const double mean_p = mean >= 0 ? mean * (1 + 1e-6) : mean * (1 - 1e-6);
const double mean_m = mean >= 0 ? mean * (1 - 1e-6) : mean * (1 + 1e-6);
const double cast_mean_p = static_cast<double>(static_cast<fp4e2m1>(mean_p));
const double cast_mean_m = static_cast<double>(static_cast<fp4e2m1>(mean_m));
assertion = !(cast_mean_m == std::min(t,r) && cast_mean_p == std::max(t,r));
}
if (assertion) {
total_mismatches++;
std::string msg = "Mismatch at place (" + std::to_string(idx + k) + "): " +
std::to_string(t) + " vs " + std::to_string(r) +
" (abs_diff: " + std::to_string(fabs(t - r)) +
", rel_diff: " + std::to_string(r == 0 ? 0.0 : fabs((t - r) / r)) + ")";
mismatch_messages.push_back(msg);
// Optional: limit number of detailed messages to avoid overwhelming output
if (mismatch_messages.size() <= 100) {
std::cout << "Error in tensor " << name << ": " << msg << std::endl;
}
}
}
}
}
// Always report summary - either success or failure
std::cout << "=== SUMMARY for tensor " << name << " ===" << std::endl;
std::cout << "Total elements checked: " << (rows * cols) << std::endl;
if (total_mismatches > 0) {
std::cout << "STATUS: FAILED for output" << std::endl;
std::cout << "Total mismatches found: " << total_mismatches << std::endl;
std::cout << "Mismatch rate: " << (100.0 * total_mismatches) / (rows * cols) << "%" << std::endl;
if (mismatch_messages.size() > 100) {
std::cout << "... and " << (mismatch_messages.size() - 100) << " more mismatches (showing first 100)" << std::endl;
}
std::cout << "============================" << std::endl;
GTEST_FAIL() << "Found " << total_mismatches << " mismatches in tensor " << name;
} else {
std::cout << "STATUS: PASSED for output" << std::endl;
std::cout << "All elements match within tolerance!" << std::endl;
std::cout << "Tensor " << name << " is IDENTICAL to reference" << std::endl;
std::cout << "============================" << std::endl;
}
}
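The forgiveness logic inside the loop deserves a note: a raw mismatch is tolerated when it is only round-to-nearest landing on different sides of the true value. A sketch of that check in isolation (`cast` is any round-to-nearest quantizer, hypothetical here):

```python
# Nudge the midpoint of (test, ref) by +/-1e-6 and re-cast; if the two casts
# bracket the pair exactly, the mismatch is a rounding tie, not an error.
def is_rounding_tie(t: float, r: float, cast) -> bool:
    mean = (t + r) / 2
    mean_p = mean * (1 + 1e-6) if mean >= 0 else mean * (1 - 1e-6)
    mean_m = mean * (1 - 1e-6) if mean >= 0 else mean * (1 + 1e-6)
    return cast(mean_m) == min(t, r) and cast(mean_p) == max(t, r)
```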
// Optional: Function to dump tensor data to files for detailed analysis
void dump_nvfp4_tensor_data(const std::string& prefix,
const fp4e2m1 *test_data, const fp4e2m1 *ref_data,
const int rows, const int cols) {
std::string test_file = prefix + "_test.txt";
std::string ref_file = prefix + "_ref.txt";
std::string diff_file = prefix + "_diff.txt";
std::ofstream test_out(test_file);
std::ofstream ref_out(ref_file);
std::ofstream diff_out(diff_file);
if (test_out.is_open() && ref_out.is_open() && diff_out.is_open()) {
for (int i = 0; i < rows; ++i) {
for (int j = 0; j < cols; j += 2) {
const int idx = i * cols + j;
double2 test_data_pair = cvt_fp4x2_to_double2(*reinterpret_cast<const fp4e2m1x2*>(&test_data[idx/2]));
double2 ref_data_pair = cvt_fp4x2_to_double2(*reinterpret_cast<const fp4e2m1x2*>(&ref_data[idx/2]));
for (int k = 0; k < 2; ++k) {
const double t = (k == 0 ? test_data_pair.x : test_data_pair.y);
const double r = (k == 0 ? ref_data_pair.x : ref_data_pair.y);
const int pos = idx + k;
test_out << "pos[" << pos << "] = " << t << std::endl;
ref_out << "pos[" << pos << "] = " << r << std::endl;
diff_out << "pos[" << pos << "] test=" << t << " ref=" << r
<< " abs_diff=" << fabs(t - r)
<< " rel_diff=" << (r == 0 ? 0.0 : fabs((t - r) / r)) << std::endl;
}
}
}
std::cout << "DEBUG: Dumped tensor data to files: " << test_file << ", " << ref_file << ", " << diff_file << std::endl;
} else {
std::cout << "WARNING: Could not open files for tensor data dump" << std::endl;
}
}
void print_detailed_tensor_comparison(const std::string& name,
const fp4e2m1 *test_data, const fp4e2m1 *ref_data,
const int rows, const int cols) {
printf("\n=== DETAILED COMPARISON for %s (%d×%d = %d elements) ===\n",
name.c_str(), rows, cols, rows * cols);
const int total_elements = rows * cols;
const int check_count = 128;
printf("--- FIRST %d ELEMENTS ---\n", check_count);
printf("Index | Test_Value | Ref_Value | Match\n");
printf("------|---------------|---------------|-------\n");
for (int i = 0; i < std::min(check_count, total_elements); ++i) {
double2 test_pair = cvt_fp4x2_to_double2(*reinterpret_cast<const fp4e2m1x2*>(&test_data[i/2]));
double2 ref_pair = cvt_fp4x2_to_double2(*reinterpret_cast<const fp4e2m1x2*>(&ref_data[i/2]));
double t = (i % 2 == 0) ? test_pair.x : test_pair.y;
double r = (i % 2 == 0) ? ref_pair.x : ref_pair.y;
bool match = (fabs(t - r) < 1e-6);
printf("%5d | %13.6f | %13.6f | %s\n", i, t, r, match ? "✓" : "✗");
}
if (total_elements > 2 * check_count) {
printf("\n--- LAST %d ELEMENTS ---\n", check_count);
printf("Index | Test_Value | Ref_Value | Match\n");
printf("------|---------------|---------------|-------\n");
for (int i = total_elements - check_count; i < total_elements; ++i) {
double2 test_pair = cvt_fp4x2_to_double2(*reinterpret_cast<const fp4e2m1x2*>(&test_data[i/2]));
double2 ref_pair = cvt_fp4x2_to_double2(*reinterpret_cast<const fp4e2m1x2*>(&ref_data[i/2]));
double t = (i % 2 == 0) ? test_pair.x : test_pair.y;
double r = (i % 2 == 0) ? ref_pair.x : ref_pair.y;
bool match = (fabs(t - r) < 1e-6);
printf("%5d | %13.6f | %13.6f | %s\n", i, t, r, match ? "✓" : "✗");
}
}
printf("==================================\n");
}
void compareResults_nvfp4(const Tensor &test,
const void *ref, const void *ref_t, const int rows, const int cols,
double atol = 1e-5, double rtol = 1e-8, bool if_on_gpus = true, bool dump_data = false) {
if (if_on_gpus) test.to_cpu();
const fp4e2m1 *test_data = test.rowwise_cpu_dptr<fp4e2m1>();
const fp4e2m1 *test_data_t = test.columnwise_cpu_dptr<fp4e2m1>();
const fp4e2m1 *ref_data = reinterpret_cast<const fp4e2m1*>(ref);
const fp4e2m1 *ref_data_t = reinterpret_cast<const fp4e2m1*>(ref_t);
// Print detailed element-by-element comparison
// print_detailed_tensor_comparison("output", test_data, ref_data, rows, cols);
// print_detailed_tensor_comparison("output_t", test_data_t, ref_data_t, cols, rows);
// Optionally dump tensor data to files for detailed analysis
if (dump_data) {
dump_nvfp4_tensor_data("output", test_data, ref_data, rows, cols);
dump_nvfp4_tensor_data("output_t", test_data_t, ref_data_t, cols, rows);
}
compare_nvfp4_tensors("output", test_data, ref_data, rows, cols, atol, rtol);
compare_nvfp4_tensors("output_t", test_data_t, ref_data_t, cols, rows, atol, rtol);
}
template <typename InputType>
void performTest(float (*OP)(const float),
const std::vector<size_t>& shape) {
using namespace test;
DType itype = TypeInfo<InputType>::dtype;
DType otype = DType::kFloat4E2M1;
const size_t rows = first_dimension(shape);
const size_t cols = last_dimension(shape);
// Use get_scale_tensor_dims for NVFP4 scale tensor dimensions
// Now that CheckScaleTensorShape is fixed, this should work correctly
const std::array<size_t,4> scale_dims = get_scale_tensor_dims(rows, cols, 1, 16);
const std::array<size_t,4> scale_dims_t = get_scale_tensor_dims(cols, rows, 1, 16);
const size_t unpadded_blocks_Y = scale_dims[0];
const size_t unpadded_blocks_X = scale_dims[1];
const size_t blocks_Y = scale_dims[2];
const size_t blocks_X = scale_dims[3];
const size_t scales_stride = blocks_X;
const size_t unpadded_blocks_Y_t = scale_dims_t[0];
const size_t unpadded_blocks_X_t = scale_dims_t[1];
const size_t blocks_Y_t = scale_dims_t[2];
const size_t blocks_X_t = scale_dims_t[3];
const size_t scales_stride_t = blocks_X_t;
Tensor input("input", shape, itype);
Tensor output("output", shape, otype, true, true, NVTE_NVFP4_1D_SCALING);
std::unique_ptr<fp4e2m1x2[]> ref_output = std::make_unique<fp4e2m1x2[]>(rows * (cols / 2));
std::unique_ptr<fp4e2m1x2[]> ref_output_t = std::make_unique<fp4e2m1x2[]>(cols * (rows / 2));
std::unique_ptr<fp8e4m3[]> ref_scales = std::make_unique<fp8e4m3[]>(blocks_Y * blocks_X);
std::unique_ptr<fp8e4m3[]> ref_scales_t = std::make_unique<fp8e4m3[]>(blocks_Y_t * blocks_X_t);
fillCase<fp32>(&input, InputsFillCase::uniform);
// Find global amax
float amax = 0.0f;
const InputType* input_dptr = input.rowwise_cpu_dptr<InputType>();
for (size_t i = 0; i < rows; ++i) {
for (size_t j = 0; j < cols; ++j) {
const size_t idx = i * cols + j;
amax = fmaxf(amax, static_cast<float>(input_dptr[idx]));
}
}
// Set 2nd stage NVFP4 scaling factor
output.set_scale(amax);
bool use_2d_quantization = false;
compute_ref<InputType>(OP,
input.rowwise_cpu_dptr<InputType>(),
ref_output.get(),
ref_output_t.get(),
ref_scales.get(),
ref_scales_t.get(),
output.scale(),
rows,
cols,
scales_stride,
scales_stride_t,
use_2d_quantization);
QuantizationConfigWrapper quant_config;
// Initialize stochastic rounding
Tensor rng_state("rng_state", std::vector<size_t>{2}, DType::kInt64);
rng_state.rowwise_cpu_dptr<int64_t>()[0] = 123; // rng_seed
rng_state.rowwise_cpu_dptr<int64_t>()[1] = 321; // rng_sequence
rng_state.from_cpu();
quant_config.set_stochastic_rounding(false);
quant_config.set_rng_state(rng_state.data());
// Set 2D quantization based on compile-time flag
quant_config.set_nvfp4_2d_quantization(use_2d_quantization);
// Call appropriate function based on operation type
// Activation functions take 3 parameters (input, output, stream)
// nvte_quantize_v2 takes 4 parameters (input, output, quant_config, stream)
if (OP == &gelu) {
nvte_gelu(input.data(), output.data(), 0);
} else if (OP == &silu) {
nvte_silu(input.data(), output.data(), 0);
} else if (OP == &relu) {
nvte_relu(input.data(), output.data(), 0);
} else if (OP == &qgelu) {
nvte_qgelu(input.data(), output.data(), 0);
} else if (OP == &srelu) {
nvte_srelu(input.data(), output.data(), 0);
} else {
nvte_quantize_v2(input.data(), output.data(), quant_config, 0);
}
cudaDeviceSynchronize();
auto err = cudaGetLastError();
if (err != cudaSuccess) {
printf("DEBUG: CUDA error detected: %s\n", cudaGetErrorString(err));
}
ASSERT_EQ(err, cudaSuccess) << cudaGetErrorString(err);
const double atol = 0.05;
const double rtol = 0.1;
// Set dump_data=true to enable dumping tensor data to files for analysis
compareResults_nvfp4(output, ref_output.get(), ref_output_t.get(), rows, cols, atol, rtol, true, false);
const fp8e4m3* kernel_scales = output.rowwise_cpu_scale_inv_ptr<fp8e4m3>();
const fp8e4m3* ref_scales_ptr = ref_scales.get();
const fp8e4m3* kernel_scales_t = output.columnwise_cpu_scale_inv_ptr<fp8e4m3>();
const fp8e4m3* ref_scales_t_ptr = ref_scales_t.get();
size_t scale_mismatches_num = 0;
compare_scaling_factors<fp8e4m3>("scales", output.rowwise_cpu_scale_inv_ptr<fp8e4m3>(),
ref_scales.get(),
unpadded_blocks_Y, unpadded_blocks_X, scales_stride,
scale_mismatches_num);
compare_scaling_factors<fp8e4m3>("scales_t", output.columnwise_cpu_scale_inv_ptr<fp8e4m3>(),
ref_scales_t.get(),
unpadded_blocks_Y_t, unpadded_blocks_X_t, scales_stride_t,
scale_mismatches_num);
}
std::vector<std::vector<size_t>> tensor_dims = {
{32, 32},
{32, 64},
{64, 32},
{64, 96},
{128, 128},
{256, 256},
{512, 512},
{1024, 1024},
{2048, 2048},
{128, 256},
{8192, 128},
{2048, 160},
{8, 32, 1024},
{16, 8, 4, 512},
{1024, 16384},
{4096, 13312},
};
// Activation types exercised by the NVFP4 cast tests
std::vector<ActivationType> Activation_types = {
ActivationType::Identity,
ActivationType::GeLU,
ActivationType::SiLU,
ActivationType::ReLU,
ActivationType::QGeLU,
ActivationType::SReLU,
};
} // namespace
class FusedCastTransposeNVFP4TestSuite : public ::testing::TestWithParam
<std::tuple<ActivationType,
std::vector<size_t>,
transformer_engine::DType>> {};
TEST_P(FusedCastTransposeNVFP4TestSuite, TestFusedCastTransposeNVFP4) {
// Skip tests for pre-Blackwell architectures
if (getDeviceComputeCapability() < blackwellComputeCapability) {
GTEST_SKIP();
}
using namespace transformer_engine;
using namespace test;
const ActivationType Act_type = std::get<0>(GetParam());
const auto tensor_dims = std::get<1>(GetParam());
const DType input_type = std::get<2>(GetParam());
// Skip tests if the input tensor is 1D
if (tensor_dims.size() < 2) {
GTEST_SKIP();
}
// Forward activations
auto OP = &identity;
switch (Act_type) {
case ActivationType::GeLU: OP = &gelu; break;
case ActivationType::SiLU: OP = &silu; break;
case ActivationType::ReLU: OP = &relu; break;
case ActivationType::QGeLU: OP = &qgelu; break;
case ActivationType::SReLU: OP = &srelu; break;
}
TRANSFORMER_ENGINE_TYPE_SWITCH_FP16_FP32_ONLY(input_type, InputType,
performTest<InputType>(OP, tensor_dims);
);
}
std::string to_string(const ActivationType Act_type) {
switch (Act_type) {
case ActivationType::Identity: return "CAST_ONLY";
case ActivationType::GeLU: return "GeLU";
case ActivationType::SiLU: return "SiLU";
case ActivationType::ReLU: return "ReLU";
case ActivationType::QGeLU: return "QGeLU";
case ActivationType::SReLU: return "SReLU";
default: return "";
}
}
INSTANTIATE_TEST_SUITE_P(
OperatorTest,
FusedCastTransposeNVFP4TestSuite,
::testing::Combine(
::testing::ValuesIn(Activation_types),
::testing::ValuesIn(tensor_dims),
::testing::Values(DType::kBFloat16)),
[](const testing::TestParamInfo<FusedCastTransposeNVFP4TestSuite::ParamType>& info) {
std::string name = to_string(std::get<0>(info.param));
const auto& shape = std::get<1>(info.param);
for ( const auto& s: shape) {
name += "X" + std::to_string(s);
}
name += "X" + test::typeName(std::get<2>(info.param));
return name;
});
......@@ -111,6 +111,10 @@ size_t DIVUP(const size_t &x, const size_t &y){
return (((x) + ((y)-1)) / (y));
}
size_t DIVUP_TO_MULTIPLE(const size_t &x, const size_t &y){
return DIVUP(x, y) * y;
}
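These two helpers drive all of the scale-tensor shape math below. A Python sketch of the MXFP8 row-wise case, assuming the alignment constants used in this file (128 along rows, 4 along columns of blocks):

```python
# Row-wise MXFP8 scales: one scale per 1x32 block, padded so the scale tensor
# dims are multiples of the alignments (values here are illustrative).
def divup(x: int, y: int) -> int:
    return (x + y - 1) // y

def divup_to_multiple(x: int, y: int) -> int:
    return divup(x, y) * y

def mxfp8_rowwise_scale_shape(first_dim, last_dim, block=32, align_y=128, align_x=4):
    return (divup_to_multiple(first_dim, align_y),
            divup_to_multiple(divup(last_dim, block), align_x))

print(mxfp8_rowwise_scale_shape(256, 1024))  # (256, 32)
```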
struct scale_inv_meta {
std::vector<size_t> shape;
DType type;
......@@ -147,21 +151,71 @@ std::pair<scale_inv_meta, scale_inv_meta> get_scales(const NVTEShape& shape,
scale_inv_meta ret_rowwise, ret_colwise;
auto block_alignment = std::vector<size_t>{128ul, 4ul};
{
auto alignment = block_alignment[0];
auto scale_dim_0 = DIVUP(DIVUP(first_dim, static_cast<size_t>(1)), alignment) * alignment;
alignment = block_alignment[1];
auto scale_dim_1 = DIVUP(DIVUP(last_dim, static_cast<size_t>(32)), alignment) * alignment;
ret_rowwise.shape = {scale_dim_0, scale_dim_1};
const size_t block_size_X_rowwise = 32;
size_t scale_dim_Y_rowwise = DIVUP_TO_MULTIPLE(first_dim, scale_tensor_alignment_Y_rowwise);
size_t scale_dim_X_rowwise = DIVUP_TO_MULTIPLE(DIVUP(last_dim, block_size_X_rowwise), scale_tensor_alignment_X_rowwise);
ret_rowwise.shape = {scale_dim_Y_rowwise, scale_dim_X_rowwise};
const size_t block_size_Y_colwise = 32;
size_t scale_dim_Y_colwise = DIVUP_TO_MULTIPLE(DIVUP(first_dim, block_size_Y_colwise), scale_tensor_alignment_Y_colwise);
size_t scale_dim_X_colwise = DIVUP_TO_MULTIPLE(last_dim, scale_tensor_alignment_X_colwise);
ret_colwise.shape = {scale_dim_Y_colwise, scale_dim_X_colwise};
ret_rowwise.type = DType::kFloat8E8M0;
ret_rowwise.type_size_bits = typeToNumBits(DType::kFloat8E8M0);
ret_colwise.type = DType::kFloat8E8M0;
ret_colwise.type_size_bits = typeToNumBits(DType::kFloat8E8M0);
return {ret_rowwise, ret_colwise};
}
{
auto alignment = block_alignment[1];
auto scale_dim_0 = DIVUP(DIVUP(first_dim, static_cast<size_t>(32)), alignment) * alignment;
alignment = block_alignment[0];
auto scale_dim_1 = DIVUP(DIVUP(last_dim, static_cast<size_t>(1)), alignment) * alignment;
ret_colwise.shape = {scale_dim_0, scale_dim_1};
if (scaling_mode == NVTE_NVFP4_1D_SCALING) {
std::vector<size_t> shape_vec;
for (size_t i = 0; i < shape.ndim; ++i) {
shape_vec.push_back(shape.data[i]);
}
size_t first_dim = first_dimension(shape_vec);
size_t last_dim = last_dimension(shape_vec);
NVTE_CHECK(last_dim % 32 == 0);
NVTE_CHECK(first_dim % 32 == 0);
scale_inv_meta ret_rowwise, ret_colwise;
size_t scale_dim_Y = DIVUP_TO_MULTIPLE(first_dim, scale_tensor_alignment_Y_rowwise);
size_t scale_dim_X = DIVUP_TO_MULTIPLE(DIVUP(last_dim, 16lu), scale_tensor_alignment_X_rowwise);
ret_rowwise.shape = {scale_dim_Y, scale_dim_X};
size_t scale_dim_Y_t = DIVUP_TO_MULTIPLE(last_dim, scale_tensor_alignment_Y_rowwise);
size_t scale_dim_X_t = DIVUP_TO_MULTIPLE(DIVUP(first_dim, 16lu), scale_tensor_alignment_X_rowwise);
ret_colwise.shape = {scale_dim_Y_t, scale_dim_X_t};
ret_rowwise.type = DType::kFloat8E4M3;
ret_rowwise.type_size_bits = typeToNumBits(DType::kFloat8E4M3);
ret_colwise.type = DType::kFloat8E4M3;
ret_colwise.type_size_bits = typeToNumBits(DType::kFloat8E4M3);
return {ret_rowwise, ret_colwise};
}
if (scaling_mode == NVTE_MXFP8_1D_SCALING) {
std::vector<size_t> shape_vec;
for (size_t i = 0; i < shape.ndim; ++i) {
shape_vec.push_back(shape.data[i]);
}
size_t first_dim = first_dimension(shape_vec);
size_t last_dim = last_dimension(shape_vec);
scale_inv_meta ret_rowwise, ret_colwise;
const size_t block_size_X_rowwise = 32;
size_t scale_dim_Y_rowwise = DIVUP_TO_MULTIPLE(first_dim, scale_tensor_alignment_Y_rowwise);
size_t scale_dim_X_rowwise = DIVUP_TO_MULTIPLE(DIVUP(last_dim, block_size_X_rowwise), scale_tensor_alignment_X_rowwise);
ret_rowwise.shape = {scale_dim_Y_rowwise, scale_dim_X_rowwise};
const size_t block_size_Y_colwise = 32;
size_t scale_dim_Y_colwise = DIVUP_TO_MULTIPLE(DIVUP(first_dim, block_size_Y_colwise), scale_tensor_alignment_Y_colwise);
size_t scale_dim_X_colwise = DIVUP_TO_MULTIPLE(last_dim, scale_tensor_alignment_X_colwise);
ret_colwise.shape = {scale_dim_Y_colwise, scale_dim_X_colwise};
ret_rowwise.type = DType::kFloat8E8M0;
ret_colwise.type = DType::kFloat8E8M0;
ret_rowwise.type_size_bits = typeToNumBits(DType::kFloat8E8M0);
......@@ -254,14 +308,15 @@ Tensor::Tensor(const std::string& name,
NVTEShape columnwise_shape = {};
std::vector<size_t> columnwise_shape_vec;
if (scaling_mode == NVTE_DELAYED_TENSOR_SCALING || scaling_mode == NVTE_BLOCK_SCALING_1D || scaling_mode == NVTE_BLOCK_SCALING_2D) {
if (scaling_mode == NVTE_DELAYED_TENSOR_SCALING
|| scaling_mode == NVTE_BLOCK_SCALING_1D || scaling_mode == NVTE_BLOCK_SCALING_2D) {
// Transpose when tensor scaling
columnwise_shape_vec.emplace_back(shape.data[shape.ndim - 1]);
for (size_t i = 0; i < shape.ndim - 1; ++i) {
columnwise_shape_vec.emplace_back(shape.data[i]);
}
} else {
// Same shape for MX
// Same shape for MX and NVFP4
for (size_t i = 0; i < shape.ndim; ++i) {
columnwise_shape_vec.emplace_back(shape.data[i]);
}
......@@ -287,10 +342,13 @@ Tensor::Tensor(const std::string& name,
std::fill_n(cpu_data_columnwise_.get(), total_size, 0);
}
}
tensor_.set_rowwise_data(dptr_rowwise, type, shape);
tensor_.set_columnwise_data(dptr_columnwise, type, columnwise_shape);
if (isFp8Type(type)) {
const DType rowwise_type = (scaling_mode == NVTE_NVFP4_1D_SCALING) ? DType::kFloat4E2M1 : type;
const DType colwise_type = (scaling_mode == NVTE_NVFP4_1D_SCALING) ? DType::kFloat4E2M1 : type;
tensor_.set_rowwise_data(dptr_rowwise, rowwise_type, shape);
tensor_.set_columnwise_data(dptr_columnwise, colwise_type, columnwise_shape);
if (isFp8Type(type) || isFp4Type(type)) {
if (scaling_mode == NVTE_DELAYED_TENSOR_SCALING) {
cudaMalloc((void**)&amax, sizeof(float)); // NOLINT(*)
cudaMemset(amax, 0, sizeof(float));
......@@ -314,8 +372,14 @@ Tensor::Tensor(const std::string& name,
std::fill_n(columnwise_scale_inv_cpu_data_.get(), sizeof(float), 0);
}
} else {
auto [rowwise_scale_meta, colwise_scale_meta] =
get_scales(normalized_shape, tensor_.scaling_mode());
if (scaling_mode == NVTE_NVFP4_1D_SCALING) {
// Used for NVFP4 second stage scaling
cudaMalloc((void**)&scale, sizeof(float)); // NOLINT(*)
cudaMemset(scale, 0, sizeof(float));
scale_cpu_data_ = std::make_shared<float>(0);
tensor_.set_scale(scale, DType::kFloat32, std::vector<size_t>{1});
}
auto [rowwise_scale_meta, colwise_scale_meta] = get_scales(normalized_shape, tensor_.scaling_mode());
auto rowwise_scale_size = rowwise_scale_meta.bytes();
auto columnwise_scale_size = colwise_scale_meta.bytes();
auto scale_shape = rowwise_scale_meta.shape;
......@@ -350,13 +414,16 @@ void Tensor::to_cpu() const {
cudaMemcpyDeviceToHost);
}
if (columnwise_) {
const DType colwise_type = tensor_.dtype();
const size_t colwise_size = bytes(s, colwise_type);
cudaMemcpy(cpu_data_columnwise_.get(),
tensor_.get_columnwise_data().data_ptr,
size,
colwise_size,
cudaMemcpyDeviceToHost);
}
if (isFp8Type(dtype())) {
if (tensor_.scaling_mode() == NVTE_DELAYED_TENSOR_SCALING) {
if (isFp8Type(dtype()) || isFp4Type(dtype())) {
if ((tensor_.scaling_mode() == NVTE_DELAYED_TENSOR_SCALING)) {
if (tensor_.amax() != nullptr){
cudaMemcpy(amax_cpu_data_.get(),
tensor_.amax(),
......@@ -368,8 +435,7 @@ void Tensor::to_cpu() const {
sizeof(float),
cudaMemcpyDeviceToHost);
}
auto [rowwise_scale_meta, colwise_scale_meta] =
get_scales(s, tensor_.scaling_mode());
auto [rowwise_scale_meta, colwise_scale_meta] = get_scales(s, tensor_.scaling_mode());
if (rowwise_) {
auto scale_size = rowwise_scale_meta.bytes();
cudaMemcpy(rowwise_scale_inv_cpu_data_.get(),
......@@ -398,15 +464,15 @@ void Tensor::from_cpu() const {
cudaMemcpy(tensor_.get_columnwise_data().data_ptr, cpu_data_columnwise_.get(), size,
cudaMemcpyHostToDevice);
}
if (isFp8Type(dtype())) {
if (tensor_.scaling_mode() == NVTE_DELAYED_TENSOR_SCALING) {
if (isFp8Type(dtype()) || isFp4Type(dtype())) {
if ((tensor_.scaling_mode() == NVTE_DELAYED_TENSOR_SCALING)
|| (tensor_.scaling_mode() == NVTE_NVFP4_1D_SCALING)) {
if (tensor_.amax() != nullptr){
cudaMemcpy(tensor_.amax(), amax_cpu_data_.get(), sizeof(float), cudaMemcpyHostToDevice);
}
cudaMemcpy(tensor_.scale(), scale_cpu_data_.get(), sizeof(float), cudaMemcpyHostToDevice);
}
auto [rowwise_scale_meta, colwise_scale_meta] =
get_scales(s, tensor_.scaling_mode());
auto [rowwise_scale_meta, colwise_scale_meta] = get_scales(s, tensor_.scaling_mode());
if (rowwise_) {
auto scale_size = rowwise_scale_meta.bytes();
cudaMemcpy(tensor_.get_rowwise_scale_inv().data_ptr,
......@@ -423,7 +489,7 @@ void Tensor::from_cpu() const {
}
void Tensor::set_scale(float scale) {
if (isFp8Type(dtype())) {
if (isFp8Type(dtype()) || isFp4Type(dtype())) {
NVTE_CHECK(scale_cpu_data_);
if (tensor_.scaling_mode() == NVTE_DELAYED_TENSOR_SCALING) {
*scale_cpu_data_ = scale;
......@@ -433,7 +499,7 @@ void Tensor::set_scale(float scale) {
}
void Tensor::set_scale_inv(float scale_inv) {
if (isFp8Type(dtype())) {
if (isFp8Type(dtype()) || isFp4Type(dtype())) {
if (rowwise_) {
NVTE_CHECK(rowwise_scale_inv_cpu_data_);
}
......@@ -441,8 +507,7 @@ void Tensor::set_scale_inv(float scale_inv) {
NVTE_CHECK(columnwise_scale_inv_cpu_data_);
}
auto [rowwise_scale_meta, colwise_scale_meta] =
get_scales(tensor_.shape(), tensor_.scaling_mode());
auto [rowwise_scale_meta, colwise_scale_meta] = get_scales(tensor_.shape(), tensor_.scaling_mode());
if (rowwise_) {
auto num_scales = product(rowwise_scale_meta.shape);
if (num_scales == 1) {
......@@ -472,7 +537,8 @@ void Tensor::set_scale_inv(float scale_inv) {
}
void Tensor::shareFP8Meta(const Tensor &other) {
if (isFp8Type(dtype()) && isFp8Type(other.dtype())) {
if ((isFp8Type(dtype()) && isFp8Type(other.dtype()))
|| (isFp4Type(dtype()) && isFp4Type(other.dtype()))) {
auto new_tensor = TensorWrapper(other.tensor_.scaling_mode());
auto my_rowwise_data = tensor_.get_rowwise_data();
new_tensor.set_rowwise_data(my_rowwise_data.data_ptr, static_cast<DType>(my_rowwise_data.dtype),
......@@ -724,12 +790,30 @@ void compareResults(const std::string &name, const uint8_t *test, const uint8_t
}
}
void compare_e8m0_scaling_factors(const std::string &name, const uint8_t *test, const uint8_t *ref,
template <typename T>
struct CastToType;
template <>
struct CastToType<uint8_t> {
using type = int;
};
template <>
struct CastToType<fp8e4m3> {
using type = float;
};
template <typename T>
void compare_scaling_factors(const std::string &name, const T *test, const T *ref,
const size_t row_blocks, const size_t col_blocks, const size_t stride,
size_t& mismatches_num, const size_t atol,
const double abs_tolerable_mismatches_limit,
const double rel_tolerable_mismatches_limit)
{
using UpcastType = typename CastToType<T>::type;
auto [atol_fp8e4m3, rtol_fp8e4m3] = getTolerances(DType::kFloat8E4M3);
const size_t N = row_blocks * col_blocks;
const size_t tolerable_mismatches_limit = std::min(abs_tolerable_mismatches_limit,
std::floor(N * rel_tolerable_mismatches_limit));
......@@ -739,11 +823,31 @@ void compare_e8m0_scaling_factors(const std::string &name, const uint8_t *test,
for (int i = 0; i < row_blocks; ++i) {
for (int j = 0; j < col_blocks; ++j) {
const int idx = i * stride + j;
const int test_val = static_cast<int>(test[idx]);
const int ref_val = static_cast<int>(ref[idx]);
const int abs_delta = std::abs(test_val - ref_val);
float t, r;
bool assertion = false;
if (abs_delta > atol) {
if (std::is_same<T, uint8_t>::value) {
t = static_cast<float>(test[idx]);
r = static_cast<float>(ref[idx]);
assertion = std::abs(t - r) > atol;
} else {
t = static_cast<float>(*reinterpret_cast<const fp8e4m3*>(&test[idx]));
r = static_cast<float>(*reinterpret_cast<const fp8e4m3*>(&ref[idx]));
const bool mismatch = (fabs(t - r) > atol_fp8e4m3)
&& (r == 0 || fabs((t - r) / r) > rtol_fp8e4m3);
if (mismatch) {
/* Check if it is just a failure of round to nearest choosing different
side of the real value */
const double mean = (t + r) / 2;
const double mean_p = mean >= 0 ? mean * (1 + 1e-6) : mean * (1 - 1e-6);
const double mean_m = mean >= 0 ? mean * (1 - 1e-6) : mean * (1 + 1e-6);
const double cast_mean_p = static_cast<double>(static_cast<T>(mean_p));
const double cast_mean_m = static_cast<double>(static_cast<T>(mean_m));
assertion = !(cast_mean_m == std::min(t,r) && cast_mean_p == std::max(t,r));
}
}
if (assertion) {
mismatches_num++;
mismatch_indices.push_back(idx);
}
......@@ -751,8 +855,8 @@ void compare_e8m0_scaling_factors(const std::string &name, const uint8_t *test,
std::cout << "Error in " << name << std::endl;
for (const int index : mismatch_indices) {
std::cout << "Mismatch at (" << index << "):"
<< static_cast<int>(test[index]) << " vs "
<< static_cast<int>(ref[index]) << std::endl;
<< static_cast<UpcastType>(test[index]) << " vs "
<< static_cast<UpcastType>(ref[index]) << std::endl;
}
GTEST_FAIL() << mismatches_num << " mismatch(es), which is more than the tolerable mismatch limit of "
<< tolerable_mismatches_limit << ".";
......@@ -761,6 +865,22 @@ void compare_e8m0_scaling_factors(const std::string &name, const uint8_t *test,
}
}
// Instantiate templates
template
void compare_scaling_factors<uint8_t>(const std::string &name, const uint8_t *test, const uint8_t *ref,
const size_t row_blocks, const size_t col_blocks, const size_t stride,
size_t& mismatches_num, const size_t atol,
const double abs_tolerable_mismatches_limit,
const double rel_tolerable_mismatches_limit);
template
void compare_scaling_factors<fp8e4m3>(const std::string &name, const fp8e4m3 *test, const fp8e4m3 *ref,
const size_t row_blocks, const size_t col_blocks, const size_t stride,
size_t& mismatches_num, const size_t atol,
const double abs_tolerable_mismatches_limit,
const double rel_tolerable_mismatches_limit);
std::pair<double, double> getTolerances(const DType type) {
switch(type) {
case DType::kFloat32:
......@@ -920,6 +1040,10 @@ bool isFp8Type(DType type) {
return type == DType::kFloat8E4M3 || type == DType::kFloat8E5M2 || type == DType::kFloat8E8M0;
}
bool isFp4Type(DType type) {
return type == DType::kFloat4E2M1;
}
int32_t getDeviceComputeCapability() {
cudaDeviceProp deviceProp;
cudaGetDeviceProperties(&deviceProp, 0);
......@@ -941,7 +1065,8 @@ std::array<size_t, 4> get_scale_tensor_dims(const size_t rows,
const size_t cols,
const size_t block_size_rows,
const size_t block_size_cols) {
const bool is_rowwise = (block_size_rows == 1) && (block_size_cols == 32);
const bool is_rowwise = (block_size_rows == 1)
&& ((block_size_cols == 32) || (block_size_cols == 16));
const size_t alignment_Y = is_rowwise
? scale_tensor_alignment_Y_rowwise
......