Unverified Commit 3f5b4754 authored by Kirthi Shankar Sivamani, committed by GitHub

[Core][PyTorch] NVFP4 recipe (#2177)



* Add NVFP4 recipe
Signed-off-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com>
Co-authored-by: Frank Sun <frsun@nvidia.com>
Co-authored-by: Oleg Goncharov <ogoncharov@nvidia.com>
Co-authored-by: Zhongbo Zhu <zhongboz@nvidia.com>
Co-authored-by: Evgeny Tsykunov <etsykunov@nvidia.com>
Co-authored-by: Tim Moon <tmoon@nvidia.com>
Co-authored-by: Teddy Do <tdophung@nvidia.com>
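
For context, a minimal usage sketch of the new recipe with the PyTorch API. This is illustrative only: the recipe class name `NVFP4BlockScaling` and its default constructor arguments are assumptions based on this PR's title; see the recipe module added in this change for the actual interface.

import torch
import transformer_engine.pytorch as te
from transformer_engine.common import recipe

# Assumed class name for the new NVFP4 recipe; check recipe.py in this PR.
nvfp4_recipe = recipe.NVFP4BlockScaling()

model = te.Linear(1024, 1024).cuda()
inp = torch.randn(512, 1024, device="cuda")

# fp8_autocast accepts a quantization recipe; here it would select NVFP4 instead of FP8.
with te.fp8_autocast(enabled=True, fp8_recipe=nvfp4_recipe):
    out = model(inp)
out.sum().backward()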

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci



* Add MathDx dependency to GitHub builds
Signed-off-by: Tim Moon <tmoon@nvidia.com>

* Suggestions from GitHub Copilot
Signed-off-by: Tim Moon <tmoon@nvidia.com>

* Move 2x shape logic from core to PyTorch
Signed-off-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com>

* Fix compilation errors with CUDA 12.1
Signed-off-by: Tim Moon <tmoon@nvidia.com>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci



* SM 70 is not supported in CUDA 13
Signed-off-by: Tim Moon <4406448+timmoon10@users.noreply.github.com>

* Typo
Signed-off-by: Tim Moon <4406448+timmoon10@users.noreply.github.com>

* Revert "Move 2x shape logic from core to PyTorch"

This reverts commit f8b2a2d0111d9af690b43bb98ae448d9a430a185.
Signed-off-by: Tim Moon <tmoon@nvidia.com>

* Added dequantize kernel for FP4
Signed-off-by: Przemek Tredak <ptredak@nvidia.com>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci



* Fix linter warning
Signed-off-by: Tim Moon <tmoon@nvidia.com>

* Add NVFP4 support with fusible ops

Use logical tensor dims for PyTorch NVFP4 tensors. Temporarily add unfused dequantize impl. Fix bug where NVFP4 recipe was not configurable.
Signed-off-by: Tim Moon <tmoon@nvidia.com>

* Fix logic for 2x shapes and move to PyTorch
Signed-off-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com>

* Fix CG test model config
Signed-off-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com>

* Debug NVFP4 tensor size function
Signed-off-by: Tim Moon <tmoon@nvidia.com>

* Proper handling of the RNG state
Signed-off-by: Przemek Tredak <ptredak@nvidia.com>

* Test SR properly
Signed-off-by: Przemek Tredak <ptredak@nvidia.com>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci



* Fix workspace size for GEMM heuristic.
Signed-off-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com>

* Fix compile error in C++ NVFP4 test

Fix some numeric errors when blocks are all zero.
Signed-off-by: Tim Moon <tmoon@nvidia.com>

* fix distributed test problem shape
Signed-off-by: zhongboz <zhongboz@nvidia.com>

* proper assert dim for low precision AG TP
Signed-off-by: zhongboz <zhongboz@nvidia.com>

* clean up duplicated code in nvfp4_utils.cuh
Signed-off-by: zhongboz <zhongboz@nvidia.com>

* lint
Signed-off-by: zhongboz <zhongboz@nvidia.com>

* pylint: disable=unused-argument
Signed-off-by: zhongboz <zhongboz@nvidia.com>

* `nvte_cublas_gemm_v2` to take alpha pointer (#12)

* make nvte_cublas_gemm_v2 take alpha/beta pointers
Signed-off-by: Phuong Nguyen <phuonguyen@nvidia.com>

* users are expected to pass a valid C_tensor
Signed-off-by: Phuong Nguyen <phuonguyen@nvidia.com>

* typos
Signed-off-by: Phuong Nguyen <phuonguyen@nvidia.com>

* API to have const float* alpha
Signed-off-by: Phuong Nguyen <phuonguyen@nvidia.com>

* Minor tweaks

Support arbitrary beta scales. Align workspace size to 128 bytes.
Signed-off-by: Tim Moon <tmoon@nvidia.com>

* Debug IMA with alpha pointer
Signed-off-by: Tim Moon <tmoon@nvidia.com>

---------
Signed-off-by: Phuong Nguyen <phuonguyen@nvidia.com>
Signed-off-by: Tim Moon <tmoon@nvidia.com>
Co-authored-by: Tim Moon <tmoon@nvidia.com>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci



* Support fused amax kernels with NVFP4 quantization
Signed-off-by: Tim Moon <tmoon@nvidia.com>

* Disable fused amax with cuDNN LayerNorm kernel
Signed-off-by: Tim Moon <tmoon@nvidia.com>

* Add NVFP4 cases to distributed tests for TE ops
Signed-off-by: Tim Moon <tmoon@nvidia.com>

* Change assert to NVTE_CHECK in the Hadamard cast fusion
Signed-off-by: Przemek Tredak <ptredak@nvidia.com>

* Fix compile error
Signed-off-by: Tim Moon <tmoon@nvidia.com>

* Use global thread IDs for Philox subsequences
Signed-off-by: Tim Moon <tmoon@nvidia.com>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci



* Add shape checks for NVFP4 cast kernels
Signed-off-by: Tim Moon <tmoon@nvidia.com>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci



* Do not fuse amax if cuDNN normalization is forced by envvar
Signed-off-by: Przemek Tredak <ptredak@nvidia.com>

---------
Signed-off-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com>
Signed-off-by: Tim Moon <tmoon@nvidia.com>
Signed-off-by: Tim Moon <4406448+timmoon10@users.noreply.github.com>
Signed-off-by: Przemek Tredak <ptredak@nvidia.com>
Signed-off-by: zhongboz <zhongboz@nvidia.com>
Signed-off-by: Phuong Nguyen <phuonguyen@nvidia.com>
Co-authored-by: Frank Sun <frsun@nvidia.com>
Co-authored-by: Oleg Goncharov <ogoncharov@nvidia.com>
Co-authored-by: Zhongbo Zhu <zhongboz@nvidia.com>
Co-authored-by: Evgeny Tsykunov <etsykunov@nvidia.com>
Co-authored-by: Tim Moon <tmoon@nvidia.com>
Co-authored-by: Teddy Do <tdophung@nvidia.com>
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
Co-authored-by: Tim Moon <4406448+timmoon10@users.noreply.github.com>
Co-authored-by: Przemek Tredak <ptredak@nvidia.com>
Co-authored-by: Phuong Nguyen <phuonguyen@nvidia.com>
parent dfeef1a2
......@@ -19,7 +19,7 @@ jobs:
run: |
apt-get update
apt-get install -y git python3.9 pip cudnn9-cuda-12
pip install cmake==3.21.0 pybind11[global] ninja
pip install cmake==3.21.0 pybind11[global] ninja nvidia-mathdx==25.1.1
- name: 'Checkout'
uses: actions/checkout@v3
with:
......@@ -43,7 +43,7 @@ jobs:
run: |
apt-get update
apt-get install -y git python3.9 pip cudnn9-cuda-12
pip install cmake torch ninja pydantic importlib-metadata>=1.0 packaging pybind11 numpy einops onnxscript
pip install cmake torch ninja pydantic importlib-metadata>=1.0 packaging pybind11 numpy einops onnxscript nvidia-mathdx==25.1.1
- name: 'Checkout'
uses: actions/checkout@v3
with:
......@@ -63,7 +63,7 @@ jobs:
options: --user root
steps:
- name: 'Dependencies'
run: pip install pybind11[global]
run: pip install pybind11[global] nvidia-mathdx==25.1.1
- name: 'Checkout'
uses: actions/checkout@v3
with:
......@@ -83,7 +83,7 @@ jobs:
options: --user root
steps:
- name: 'Dependencies'
run: pip install torch pybind11[global] einops onnxscript
run: pip install torch pybind11[global] einops onnxscript nvidia-mathdx==25.1.1
- name: 'Checkout'
uses: actions/checkout@v3
with:
......
# Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# See LICENSE for license information.
import argparse
import torch
import pandas as pd
import torch.utils.benchmark as benchmark
import transformer_engine.pytorch as te
import transformer_engine_torch as tex
import transformer_engine.pytorch.cpp_extensions as ext
from transformer_engine.pytorch.tensor.nvfp4_tensor import NVFP4Quantizer
scale_padding_to = 1
permute_scale = False
TORCH_TO_TE_FLOAT_MAP = {
torch.bfloat16: tex.DType.kBFloat16,
}
def run_kernel(shape, stochastic_rounding: bool, input_dtype=torch.bfloat16):
# Generate random input data
M, K = shape
x = torch.randn([M, K], dtype=input_dtype, device="cuda")
assert shape[0] % 16 == 0, "Shape must be divisible by 16"
assert shape[1] % 16 == 0, "Shape must be divisible by 16"
# Quantize
nvfp4_quantizer = NVFP4Quantizer(
fp4_dtype=tex.DType.kFloat4E2M1,
rowwise=True,
columnwise=True,
with_amax_reduction=False,
amax_reduction_group=None,
with_rht=True,
with_post_rht_amax=True,
with_random_sign_mask=True,
stochastic_rounding=stochastic_rounding,
)
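# Note (editorial, inferred from the parameter names rather than this change's docs):
# with_rht applies a random Hadamard transform before quantization,
# with_post_rht_amax computes the global amax after that transform,
# with_random_sign_mask randomizes the transform's signs, and
# stochastic_rounding toggles SR for the FP4 cast.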
x_nvfp4_sut = nvfp4_quantizer.make_empty(
(M, K), dtype=x.dtype, device=x.device, requires_grad=False
)
x_nvfp4_sut = nvfp4_quantizer.update_quantized(x, x_nvfp4_sut)
with torch.no_grad():
stmt = "kernel_func(input, output)"
globals_dict = {
"kernel_func": nvfp4_quantizer.update_quantized,
"input": x,
"output": x_nvfp4_sut,
}
timing = benchmark.Timer(
stmt=stmt,
globals=globals_dict,
num_threads=1,
).blocked_autorange(min_run_time=5)
print(timing)
timing_us = timing.median * 1e6
input_nbytes = shape[0] * shape[1] * 2 # bf16
output_nbytes = shape[0] * shape[1] // 2 # //2 for fp4
sf_nbytes = shape[0] * shape[1] // 16 # //16 for 1 byte per 16 elems
total_nbytes = (
0
+ input_nbytes
* 3 # Reading input for Amax(x)&Amax(RHT(x.T)), Reading input for Cast(x), Reading input for Cast(RHT(x.T))
+ 2 * 4 # Output 2 * float for scale & amax
+ 2 * 4 # Input 2 * float
+ output_nbytes * 2 # Output from Cast(x) and Cast(RHT(x.T))
+ sf_nbytes * 2 # Scale factor
)
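# Worked example of this traffic model (illustrative, not part of the original change):
# for shape (8192, 5120), input reads are 3 * 8192 * 5120 * 2 B = 251,658,240 B,
# FP4 outputs are 2 * 8192 * 5120 / 2 B = 41,943,040 B, scale factors are
# 2 * 8192 * 5120 / 16 B = 5,242,880 B, plus 16 B of scalar scale/amax traffic,
# for a total of about 298.8 MB moved per quantization call.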
throughput_GBps = total_nbytes / (1024 * 1024 * 1024) / (timing_us / 1e6)
print(
f"Stochastic rounding: {stochastic_rounding}, Total: {total_nbytes} bytes, Throughput:"
f" {throughput_GBps} GB/s"
)
return timing_us, throughput_GBps
# Nsight Compute Profiling Command:
# ncu -f -o block_scaled_1d_cast_transpose_kernel --set=full --kernel-name "block_scaled_1d_cast_transpose_kernel" -s 5 -c 5 python benchmark_cast_transpose_1d_block.py --profile
if __name__ == "__main__":
parser = argparse.ArgumentParser()
parser.add_argument("--profile", action="store_true", help="Enable profiling mode")
args = parser.parse_args()
if args.profile:
print("Profiling is enabled.")
else:
print("Profiling is disabled.")
shapes = [
(8192, 5120),
(8192, 10240),
(8192, 2560),
(8192, 11328),
(8192, 512),
(8192, 3584),
(5120, 8192),
(10240, 8192),
(2560, 8192),
(11328, 8192),
(512, 8192),
(3584, 8192),
(4096, 16384),
(14336, 16384),
]
if args.profile:
shapes = [
(16384, 6144),
]
data = []
for stochastic_rounding in [True]: # , False]:
for shape in shapes:
print(
f"Running benchmark_func with shape {shape} and stochastic_rounding"
f" {stochastic_rounding}"
)
timing_us, throughput_GBps = run_kernel(shape, stochastic_rounding)
data.append(
[
"benchmark_func",
shape,
stochastic_rounding,
timing_us,
throughput_GBps,
]
)
df = pd.DataFrame(
data=data,
columns=[
"kernel",
"shape",
"stochastic_rounding",
"timing_us",
"throughput(GB/s)",
],
)
print(df)
df.to_csv("benchmark_cast_nvfp4.csv", index=False)
......@@ -234,15 +234,18 @@ def get_cuda_include_dirs() -> Tuple[str, str]:
@functools.lru_cache(maxsize=None)
def cuda_archs() -> str:
archs = os.getenv("NVTE_CUDA_ARCHS")
if archs is None:
version = cuda_version()
if os.getenv("NVTE_CUDA_ARCHS") is None:
if version >= (13, 0):
os.environ["NVTE_CUDA_ARCHS"] = "75;80;89;90;100;120"
archs = "75;80;89;90;100;100a;103a;120"
elif version >= (12, 9):
archs = "70;80;89;90;100;100a;103a;120"
elif version >= (12, 8):
os.environ["NVTE_CUDA_ARCHS"] = "70;80;89;90;100;120"
archs = "70;80;89;90;100;100a;120"
else:
os.environ["NVTE_CUDA_ARCHS"] = "70;80;89;90"
return os.getenv("NVTE_CUDA_ARCHS")
archs = "70;80;89;90"
return archs
def cuda_version() -> Tuple[int, ...]:
......
......@@ -3,8 +3,7 @@
# See LICENSE for license information.
[build-system]
requires = ["setuptools>=61.0", "cmake>=3.21", "wheel", "pybind11[global]", "ninja", "pip",
"torch>=2.1", "jax>=0.5.0", "flax>=0.7.1"]
requires = ["setuptools>=61.0", "cmake>=3.21", "wheel", "pybind11[global]", "ninja", "nvidia-mathdx==25.1.1", "pip", "torch>=2.1", "jax>=0.5.0", "flax>=0.7.1"]
# Use legacy backend to import local packages in setup.py
build-backend = "setuptools.build_meta:__legacy__"
......
......@@ -31,6 +31,7 @@ PYTORCH_JIT=0 NVTE_TORCH_COMPILE=0 NVTE_ALLOW_NONDETERMINISTIC_ALGO=0 python3 -m
PYTORCH_JIT=0 NVTE_TORCH_COMPILE=0 NVTE_ALLOW_NONDETERMINISTIC_ALGO=0 python3 -m pytest --tb=auto --junitxml=$XML_LOG_DIR/pytest_test_cuda_graphs.xml $TE_PATH/tests/pytorch/test_cuda_graphs.py || test_fail "test_cuda_graphs.py"
python3 -m pytest --tb=auto --junitxml=$XML_LOG_DIR/pytest_test_jit.xml $TE_PATH/tests/pytorch/test_jit.py || test_fail "test_jit.py"
python3 -m pytest --tb=auto --junitxml=$XML_LOG_DIR/pytest_test_fused_rope.xml $TE_PATH/tests/pytorch/test_fused_rope.py || test_fail "test_fused_rope.py"
python3 -m pytest --tb=auto --junitxml=$XML_LOG_DIR/pytest_test_nvfp4.xml $TE_PATH/tests/pytorch/nvfp4 || test_fail "test_nvfp4"
python3 -m pytest --tb=auto --junitxml=$XML_LOG_DIR/pytest_test_float8tensor.xml $TE_PATH/tests/pytorch/test_float8tensor.py || test_fail "test_float8tensor.py"
python3 -m pytest --tb=auto --junitxml=$XML_LOG_DIR/pytest_test_float8blockwisetensor.xml $TE_PATH/tests/pytorch/test_float8blockwisetensor.py || test_fail "test_float8blockwisetensor.py"
python3 -m pytest --tb=auto --junitxml=$XML_LOG_DIR/pytest_test_float8_blockwise_scaling_exact.xml $TE_PATH/tests/pytorch/test_float8_blockwise_scaling_exact.py || test_fail "test_float8_blockwise_scaling_exact.py"
......
......@@ -30,6 +30,7 @@ pip3 install pytest==8.2.1 || error_exit "Failed to install pytest"
python3 -m pytest -v -s --junitxml=$XML_LOG_DIR/pytest_test_sanity.xml $TE_PATH/tests/pytorch/distributed/test_sanity.py || test_fail "test_sanity.py"
python3 -m pytest -v -s --junitxml=$XML_LOG_DIR/pytest_test_numerics.xml $TE_PATH/tests/pytorch/distributed/test_numerics.py || test_fail "test_numerics.py"
python3 -m pytest -v -s --junitxml=$XML_LOG_DIR/pytest_test_numerics_exact.xml $TE_PATH/tests/pytorch/distributed/test_numerics_exact.py || test_fail "test_numerics_exact.py"
python3 -m pytest -v -s --junitxml=$XML_LOG_DIR/pytest_test_fusible_ops.xml $TE_PATH/tests/pytorch/distributed/test_fusible_ops.py || test_fail "test_fusible_ops.py"
python3 -m pytest -v -s --junitxml=$XML_LOG_DIR/pytest_test_torch_fsdp2.xml $TE_PATH/tests/pytorch/distributed/test_torch_fsdp2.py || test_fail "test_torch_fsdp2.py"
python3 -m pytest -v -s --junitxml=$XML_LOG_DIR/pytest_test_comm_gemm_overlap.xml $TE_PATH/tests/pytorch/distributed/test_comm_gemm_overlap.py || test_fail "test_comm_gemm_overlap.py"
......
......@@ -11,6 +11,7 @@ add_executable(test_operator
test_cast_mxfp8_gated_swiglu.cu
test_qdq.cu
test_cast_mxfp8.cu
test_cast_nvfp4_transpose.cu
test_cast_float8blockwise.cu
test_dequantize_mxfp8.cu
test_transpose.cu
......@@ -31,6 +32,13 @@ add_executable(test_operator
test_swap_first_dims.cu
../test_common.cu)
# Add profiling and debug flags for CUDA compilation
set(CMAKE_CUDA_FLAGS "${CMAKE_CUDA_FLAGS} -lineinfo") # Generate line info for device code
set(CMAKE_CUDA_FLAGS "${CMAKE_CUDA_FLAGS} -g") # Add debug symbols for host code
set(CMAKE_CUDA_FLAGS "${CMAKE_CUDA_FLAGS} --ptxas-options=-v") # Report register usage
# Note: Using -lineinfo instead of -G to avoid conflicts and get line mapping
# Find required packages
find_package(OpenMP REQUIRED)
list(APPEND test_operator_LINKER_LIBS CUDA::cudart GTest::gtest_main ${TE_LIB} CUDA::nvrtc CUDNN::cudnn)
......
......@@ -81,6 +81,7 @@ void compute_ref(const ProcessingMethod processing_method,
// Cache computations
for (size_t i = i_min; i < i_max; ++i) {
for (size_t j = j_min; j < j_max; ++j) {
const size_t idx = i * cols + j;
const size_t cache_idx = (i - i_min) * tile_size_X + (j - j_min);
......@@ -310,7 +311,8 @@ void performTest_x1(const ProcessingMethod processing_method,
const double rel_tolerable_mismatches_limit = 0.0;
size_t mismatches_scales = 0;
compare_e8m0_scaling_factors("scales", gpu_scales_ptr, ref_output_scales.get(),
compare_scaling_factors("scales", gpu_scales_ptr, ref_output_scales.get(),
unpadded_blocks_Y, unpadded_blocks_X, scales_stride,
mismatches_scales,
scale_diff_abs_tolerance,
......@@ -481,7 +483,7 @@ void performTest_x2(const ProcessingMethod processing_method,
const double rel_tolerable_mismatches_limit = 0.0;
size_t mismatches_scales_rowwise = 0;
compare_e8m0_scaling_factors("scales_rowwise", output.rowwise_cpu_scale_inv_ptr<fp8e8m0>(),
compare_scaling_factors("scales_rowwise", output.rowwise_cpu_scale_inv_ptr<fp8e8m0>(),
ref_scales_rowwise.get(), unpadded_blocks_Y_rowwise,
unpadded_blocks_X_rowwise, scales_stride_rowwise,
mismatches_scales_rowwise,
......@@ -490,7 +492,7 @@ void performTest_x2(const ProcessingMethod processing_method,
rel_tolerable_mismatches_limit);
size_t mismatches_scales_colwise = 0;
compare_e8m0_scaling_factors("scales_colwise", output.columnwise_cpu_scale_inv_ptr<fp8e8m0>(),
compare_scaling_factors("scales_colwise", output.columnwise_cpu_scale_inv_ptr<fp8e8m0>(),
ref_scales_colwise.get(), unpadded_blocks_Y_colwise,
unpadded_blocks_X_colwise, scales_stride_colwise,
mismatches_scales_colwise,
......
......@@ -267,19 +267,20 @@ void performTest_x1(const size_t rows,
? output.rowwise_cpu_scale_inv_ptr<fp8e8m0>()
: output.columnwise_cpu_scale_inv_ptr<fp8e8m0>();
if (rowwise) {
compare_e8m0_scaling_factors("rowwise scales", gpu_scales_ptr, ref_output_scales.get(),
compare_scaling_factors("rowwise scales", gpu_scales_ptr, ref_output_scales.get(),
unpadded_blocks_Y, unpadded_blocks_X, scales_stride,
mismatches_scales,
scale_diff_abs_tolerance,
abs_tolerable_mismatches_limit,
rel_tolerable_mismatches_limit);
} else {
compare_e8m0_scaling_factors("colwise scales", gpu_scales_ptr, ref_output_scales.get(),
compare_scaling_factors("colwise scales", gpu_scales_ptr, ref_output_scales.get(),
unpadded_blocks_Y, unpadded_blocks_X, scales_stride,
mismatches_scales,
scale_diff_abs_tolerance,
abs_tolerable_mismatches_limit,
rel_tolerable_mismatches_limit);
}
const size_t mismatches_elts = 32 * mismatches_scales;
......@@ -378,7 +379,7 @@ void performTest_x2(const size_t rows,
const double rel_tolerable_mismatches_limit = 1.0e-4;
size_t mismatches_scales_rowwise = 0;
compare_e8m0_scaling_factors("scales_rowwise", output.rowwise_cpu_scale_inv_ptr<fp8e8m0>(),
compare_scaling_factors("scales_rowwise", output.rowwise_cpu_scale_inv_ptr<fp8e8m0>(),
ref_scales_rowwise.get(), unpadded_blocks_Y_rowwise,
unpadded_blocks_X_rowwise, scales_stride_rowwise,
mismatches_scales_rowwise,
......@@ -386,7 +387,7 @@ void performTest_x2(const size_t rows,
abs_tolerable_mismatches_limit,
rel_tolerable_mismatches_limit);
size_t mismatches_scales_colwise = 0;
compare_e8m0_scaling_factors("scales_colwise", output.columnwise_cpu_scale_inv_ptr<fp8e8m0>(),
compare_scaling_factors("scales_colwise", output.columnwise_cpu_scale_inv_ptr<fp8e8m0>(),
ref_scales_colwise.get(), unpadded_blocks_Y_colwise,
unpadded_blocks_X_colwise, scales_stride_colwise,
mismatches_scales_colwise,
......@@ -394,6 +395,7 @@ void performTest_x2(const size_t rows,
abs_tolerable_mismatches_limit,
rel_tolerable_mismatches_limit);
const size_t mismatches_elts_rowwise = 32 * mismatches_scales_rowwise;
const size_t mismatches_elts_colwise = 32 * mismatches_scales_colwise;
......
/*************************************************************************
* Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
*
* See LICENSE for license information.
************************************************************************/
#include <cuda_bf16.h>
#include <cuda_fp8.h>
#include <cuda_fp4.h>
#include <cuda_runtime.h>
#include <gtest/gtest.h>
#include <transformer_engine/cast.h>
#include <transformer_engine/activation.h>
#include "../test_common.h"
#include "transformer_engine/transformer_engine.h"
#include <fstream>
using namespace transformer_engine;
using namespace test;
namespace {
enum ActivationType {
Identity,
GeLU,
SiLU,
ReLU,
QGeLU,
SReLU
};
double2 cvt_fp4x2_to_double2(fp4e2m1x2 fp4_pair) {
const __half2_raw raw_truncated_to_fp4e2m1_pair =
__nv_cvt_fp4x2_to_halfraw2(*reinterpret_cast<__nv_fp4x2_storage_t*>(&fp4_pair), __NV_E2M1);
const __half2 truncated_to_fp4e2m1_pair(raw_truncated_to_fp4e2m1_pair);
const double truncated_to_fp4e2m1_x = static_cast<double>(truncated_to_fp4e2m1_pair.x);
const double truncated_to_fp4e2m1_y = static_cast<double>(truncated_to_fp4e2m1_pair.y);
return {truncated_to_fp4e2m1_x, truncated_to_fp4e2m1_y};
}
template <typename InputType>
std::vector<InputType> create_transpose(const InputType* const input, const size_t rows, size_t cols) {
std::vector<InputType> input_t(cols * rows);
for (size_t i = 0; i < rows; ++i) {
for (size_t j = 0; j < cols; ++j) {
const size_t idx = i * cols + j;
const size_t idx_t = j * rows + i;
input_t[idx_t] = input[idx];
}
}
return input_t;
}
// Compute the global encode scale factor for a given global amax
float compute_global_encode_scaling_factor_FP4(const float global_amax) {
constexpr float fp8_max = 448.0f; // max normal value of FP8 E4M3
constexpr float fp4_max = 6.0f; // max normal value of FP4 E2M1
float global_encode_scale = fp8_max * fp4_max / global_amax;
// If scale is infinity, return max value of float32
global_encode_scale = fminf(global_encode_scale, Numeric_Traits<float>::maxNorm);
// If global amax is 0 or infinity, return 1
if (global_amax == 0.0f || global_encode_scale == 0.0f) {
return 1.0f;
}
return global_encode_scale;
}
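// Worked example of the two-level scaling (illustrative, not from the original change):
// with global_amax = 448, S_enc = 448 * 6 / 448 = 6. A block with block_amax = 3 gets
// S_dec_b = 3 / 6 = 0.5, stored as S_dec_b_fp8 = 0.5 * 6 = 3.0 (an exact E4M3 value),
// and the per-block encode scale S_enc / S_dec_b_fp8 = 2, so the element 3.0 maps to 6.0,
// the FP4 E2M1 maximum. Dequantization multiplies by S_dec_b_fp8 / S_enc = 0.5 to recover 3.0.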
// 1D Scaling: Original implementation with 1x16 blocks
template <typename InputType>
void quantize_nvfp4_1d(float (*OP)(const float),
const InputType* const input,
fp4e2m1x2* const output,
fp8e4m3* const scales,
const size_t rows,
const size_t cols,
const size_t scales_stride,
const float global_amax) {
// Compute a global encoding/decoding scaling factor for all S_dec_b
const float S_enc = compute_global_encode_scaling_factor_FP4(global_amax);
constexpr size_t block_size_X = 16;
const size_t blocks_X = divide_round_up(cols, block_size_X);
std::array<float, block_size_X> cache_buffer;
for (size_t i = 0; i < block_size_X; ++i) {
cache_buffer[i] = 0.0f;
}
for (size_t i = 0; i < rows; ++i) {
for (size_t block_X = 0; block_X < blocks_X; ++block_X) {
const size_t j_min = block_X * block_size_X;
const size_t j_max = j_min + block_size_X;
// Find block amax
float block_amax = 0.0f;
for (size_t j = j_min; j < j_max; ++j) {
const size_t idx = i * cols + j;
const size_t cache_idx = j - j_min;
const float input_elt = static_cast<float>(input[idx]);
const float act_elt = OP(input_elt);
// Numerical truncation: after downcast to InputType (BF16/FP16), upcast it back to FP32
const float elt = static_cast<float>(static_cast<InputType>(act_elt));
cache_buffer[cache_idx] = elt;
block_amax = std::max(block_amax, std::abs(elt));
}
// Compute E4M3 scaling factor
// Compute per-block encoding/decoding scaling factor
const float S_dec_b = block_amax / 6.0f;
// Scale & Store per-block decoding scaling factor
const float S_dec_b_fp8 = S_dec_b * S_enc;
// Compute "correct" per-block encoding scaling factor
const float S_enc_b_fp8 = S_dec_b_fp8 == 0 ? 0.f : S_enc / S_dec_b_fp8;
const size_t scale_idx = i * scales_stride + block_X;
scales[scale_idx] = static_cast<fp8e4m3>(S_dec_b_fp8);
const float scale_reciprocal = S_enc_b_fp8;
for (size_t j = j_min; j < j_max; j += 2) {
const int idx_pair = (i * cols + j) / 2;
const int cache_idx_x = j - j_min;
const int cache_idx_y = cache_idx_x + 1;
const float cached_x = cache_buffer[cache_idx_x];
const float cached_y = cache_buffer[cache_idx_y];
const float scaled_elt_x = cached_x * scale_reciprocal;
const float scaled_elt_y = cached_y * scale_reciprocal;
const float2 scaled_elt_pair = {scaled_elt_x, scaled_elt_y};
fp4e2m1x2 casted_to_e2m1_pair(scaled_elt_pair);
output[idx_pair] = casted_to_e2m1_pair;
}
}
}
}
// Compute 2D mathematical scaling factors (8x8 for 128x128 input)
template <typename InputType>
void compute_2d_mathematical_scales(float (*OP)(const float),
const InputType* const input,
const size_t rows,
const size_t cols,
const float global_amax,
std::vector<std::vector<fp8e4m3>>& math_scales) {
const float S_enc = compute_global_encode_scaling_factor_FP4(global_amax);
constexpr size_t block_size_Y = 16;
constexpr size_t block_size_X = 16;
const size_t blocks_Y = divide_round_up(rows, block_size_Y);
const size_t blocks_X = divide_round_up(cols, block_size_X);
math_scales.resize(blocks_Y, std::vector<fp8e4m3>(blocks_X));
for (size_t block_Y = 0; block_Y < blocks_Y; ++block_Y) {
for (size_t block_X = 0; block_X < blocks_X; ++block_X) {
const size_t i_min = block_Y * block_size_Y;
const size_t i_max = std::min(i_min + block_size_Y, rows);
const size_t j_min = block_X * block_size_X;
const size_t j_max = std::min(j_min + block_size_X, cols);
// Find 2D block amax over entire 16x16 region
float block_amax = 0.0f;
for (size_t i = i_min; i < i_max; ++i) {
for (size_t j = j_min; j < j_max; ++j) {
const size_t idx = i * cols + j;
const float input_elt = static_cast<float>(input[idx]);
const float act_elt = OP(input_elt);
const float elt = static_cast<float>(static_cast<InputType>(act_elt));
block_amax = std::max(block_amax, std::abs(elt));
}
}
// Compute E4M3 scaling factor for this 16x16 block
const float S_dec_b = block_amax / 6.0f;
const fp8e4m3 S_dec_b_fp8 = static_cast<fp8e4m3>(S_dec_b * S_enc);
math_scales[block_Y][block_X] = S_dec_b_fp8;
}
}
}
// 2D Scaling: NEW implementation with proper replication
template <typename InputType>
void quantize_nvfp4_2d(float (*OP)(const float),
const InputType* const input,
fp4e2m1x2* const output,
fp8e4m3* const scales,
const size_t rows,
const size_t cols,
const size_t scales_stride,
const float global_amax) {
// Step 1: Compute mathematical 8x8 scaling factors
std::vector<std::vector<fp8e4m3>> math_scales;
compute_2d_mathematical_scales(OP, input, rows, cols, global_amax, math_scales);
const float S_enc = compute_global_encode_scaling_factor_FP4(global_amax);
constexpr size_t block_size_Y = 16;
constexpr size_t block_size_X = 16;
const size_t blocks_Y = divide_round_up(rows, block_size_Y);
const size_t blocks_X = divide_round_up(cols, block_size_X);
// Step 2: Replicate scaling factors row-wise (128×8 storage) - only if scales is not nullptr
if (scales != nullptr) {
// Each of the 128 rows gets scaling factors from its corresponding 16×16 block
for (size_t i = 0; i < rows; ++i) {
const size_t block_Y = i / block_size_Y;
for (size_t block_X = 0; block_X < blocks_X; ++block_X) {
const size_t scale_idx = i * scales_stride + block_X;
scales[scale_idx] = math_scales[block_Y][block_X];
}
}
}
// Step 3: Apply quantization using the mathematical scaling factors
std::array<std::array<float, block_size_X>, block_size_Y> cache_buffer;
for (size_t block_Y = 0; block_Y < blocks_Y; ++block_Y) {
for (size_t block_X = 0; block_X < blocks_X; ++block_X) {
const size_t i_min = block_Y * block_size_Y;
const size_t i_max = std::min(i_min + block_size_Y, rows);
const size_t j_min = block_X * block_size_X;
const size_t j_max = std::min(j_min + block_size_X, cols);
// Get the scaling factor for this block
const float S_dec_b_fp8 = static_cast<float>(math_scales[block_Y][block_X]);
const float S_enc_b_fp8 = S_dec_b_fp8 == 0 ? 0.f : S_enc / S_dec_b_fp8;
const float scale_reciprocal = S_enc_b_fp8;
// Process and cache data for this 16x16 block
for (size_t i = i_min; i < i_max; ++i) {
for (size_t j = j_min; j < j_max; ++j) {
const size_t idx = i * cols + j;
const size_t cache_idx_y = i - i_min;
const size_t cache_idx_x = j - j_min;
const float input_elt = static_cast<float>(input[idx]);
const float act_elt = OP(input_elt);
const float elt = static_cast<float>(static_cast<InputType>(act_elt));
cache_buffer[cache_idx_y][cache_idx_x] = elt;
}
}
// Apply scaling to all elements in this 16x16 block
for (size_t i = i_min; i < i_max; ++i) {
for (size_t j = j_min; j < j_max; j += 2) {
const int idx_pair = (i * cols + j) / 2;
const size_t cache_idx_y = i - i_min;
const size_t cache_idx_x1 = j - j_min;
const size_t cache_idx_x2 = std::min(cache_idx_x1 + 1, block_size_X - 1);
const float cached_x = cache_buffer[cache_idx_y][cache_idx_x1];
const float cached_y = ((j + 1) < j_max && cache_idx_x2 < block_size_X) ?
cache_buffer[cache_idx_y][cache_idx_x2] : 0.0f;
const float scaled_elt_x = cached_x * scale_reciprocal;
const float scaled_elt_y = cached_y * scale_reciprocal;
const float2 scaled_elt_pair = {scaled_elt_x, scaled_elt_y};
fp4e2m1x2 casted_to_e2m1_pair(scaled_elt_pair);
output[idx_pair] = casted_to_e2m1_pair;
}
}
}
}
}
// Wrapper function that calls appropriate implementation based on 2D flag
template <typename InputType>
void quantize_nvfp4(float (*OP)(const float),
const InputType* const input,
fp4e2m1x2* const output,
fp8e4m3* const scales,
const size_t rows,
const size_t cols,
const size_t scales_stride,
const float global_amax,
const bool use_2d_quantization = false) {
if (use_2d_quantization) {
quantize_nvfp4_2d(OP, input, output, scales, rows, cols, scales_stride, global_amax);
} else {
quantize_nvfp4_1d(OP, input, output, scales, rows, cols, scales_stride, global_amax);
}
}
template <typename InputType>
void compute_ref(float (*OP)(const float),
const InputType* input,
fp4e2m1x2* output,
fp4e2m1x2* output_t,
fp8e4m3* scales,
fp8e4m3* scales_t,
const float global_amax,
const size_t rows,
const size_t cols,
const size_t scales_stride,
const size_t scales_stride_t,
const bool use_2d_quantization = false)
{
std::vector<InputType> input_t = create_transpose(input, rows, cols);
if (use_2d_quantization) {
// Step 1: Compute mathematical 8×8 scaling factors
std::vector<std::vector<fp8e4m3>> math_scales;
compute_2d_mathematical_scales(OP, input, rows, cols, global_amax, math_scales);
constexpr size_t block_size_Y = 16;
constexpr size_t block_size_X = 16;
const size_t blocks_Y = divide_round_up(rows, block_size_Y);
const size_t blocks_X = divide_round_up(cols, block_size_X);
// Step 2: Generate scales (128×8) by replicating row-wise
for (size_t i = 0; i < rows; ++i) {
const size_t block_Y = i / block_size_Y;
for (size_t block_X = 0; block_X < blocks_X; ++block_X) {
const size_t scale_idx = i * scales_stride + block_X;
scales[scale_idx] = math_scales[block_Y][block_X];
}
}
// Step 3: Generate scales_t (128×8) with proper transposed block mapping
for (size_t i = 0; i < cols; ++i) { // cols = 128, which becomes rows of transposed data
const size_t block_X_orig = i / block_size_X; // i was column index in original, so maps to block_X
for (size_t block_Y_new = 0; block_Y_new < blocks_Y; ++block_Y_new) { // block in transposed coordinate
const size_t scale_idx = i * scales_stride_t + block_Y_new;
scales_t[scale_idx] = math_scales[block_Y_new][block_X_orig];
}
}
// Step 4: Process quantized outputs using the same algorithm as quantize_nvfp4_2d
// (This part processes the actual FP4 data using the mathematical scaling factors)
quantize_nvfp4_2d(OP, input, output, nullptr, rows, cols, scales_stride, global_amax); // scales already filled
quantize_nvfp4_2d(OP, input_t.data(), output_t, nullptr, cols, rows, scales_stride_t, global_amax); // scales_t already filled
} else {
quantize_nvfp4(OP, input, output, scales, rows, cols, scales_stride, global_amax, use_2d_quantization);
quantize_nvfp4(OP, input_t.data(), output_t, scales_t, cols, rows, scales_stride_t, global_amax, use_2d_quantization);
}
}
void compare_nvfp4_tensors(const std::string& name,
const fp4e2m1 *test_data, const fp4e2m1 *ref_data,
const int rows, const int cols,
double atol = 1e-5, double rtol = 1e-8) {
std::vector<std::string> mismatch_messages;
size_t total_mismatches = 0;
for (int i = 0; i < rows; ++i) {
for (int j = 0; j < cols; j += 2) {
const int idx = i * cols + j;
double2 test_data_pair = cvt_fp4x2_to_double2(*reinterpret_cast<const fp4e2m1x2*>(&test_data[idx/2]));
double2 ref_data_pair = cvt_fp4x2_to_double2(*reinterpret_cast<const fp4e2m1x2*>(&ref_data[idx/2]));
for (int k = 0; k < 2; ++k) {
const double t = (k == 0 ? test_data_pair.x : test_data_pair.y);
const double r = (k == 0 ? ref_data_pair.x : ref_data_pair.y);
bool mismatch = fabs(t - r) > atol && (r == 0 || fabs((t - r) / r) > rtol);
/* For Float32 the floating point comparison is enough to error out */
bool assertion = false;
if (mismatch && !assertion) {
/* Check if it is just a failure of round to nearest choosing different
side of the real value */
const double mean = (t + r) / 2;
const double mean_p = mean >= 0 ? mean * (1 + 1e-6) : mean * (1 - 1e-6);
const double mean_m = mean >= 0 ? mean * (1 - 1e-6) : mean * (1 + 1e-6);
const double cast_mean_p = static_cast<double>(static_cast<fp4e2m1>(mean_p));
const double cast_mean_m = static_cast<double>(static_cast<fp4e2m1>(mean_m));
assertion = !(cast_mean_m == std::min(t,r) && cast_mean_p == std::max(t,r));
}
if (assertion) {
total_mismatches++;
std::string msg = "Mismatch at place (" + std::to_string(idx + k) + "): " +
std::to_string(t) + " vs " + std::to_string(r) +
" (abs_diff: " + std::to_string(fabs(t - r)) +
", rel_diff: " + std::to_string(r == 0 ? 0.0 : fabs((t - r) / r)) + ")";
mismatch_messages.push_back(msg);
// Optional: limit number of detailed messages to avoid overwhelming output
if (mismatch_messages.size() <= 100) {
std::cout << "Error in tensor " << name << ": " << msg << std::endl;
}
}
}
}
}
// Always report summary - either success or failure
std::cout << "=== SUMMARY for tensor " << name << " ===" << std::endl;
std::cout << "Total elements checked: " << (rows * cols) << std::endl;
if (total_mismatches > 0) {
std::cout << "STATUS: FAILED for output" << std::endl;
std::cout << "Total mismatches found: " << total_mismatches << std::endl;
std::cout << "Mismatch rate: " << (100.0 * total_mismatches) / (rows * cols) << "%" << std::endl;
if (mismatch_messages.size() > 100) {
std::cout << "... and " << (mismatch_messages.size() - 100) << " more mismatches (showing first 100)" << std::endl;
}
std::cout << "============================" << std::endl;
GTEST_FAIL() << "Found " << total_mismatches << " mismatches in tensor " << name;
} else {
std::cout << "STATUS: PASSED for output" << std::endl;
std::cout << "All elements match within tolerance!" << std::endl;
std::cout << "Tensor " << name << " is IDENTICAL to reference" << std::endl;
std::cout << "============================" << std::endl;
}
}
// Optional: Function to dump tensor data to files for detailed analysis
void dump_nvfp4_tensor_data(const std::string& prefix,
const fp4e2m1 *test_data, const fp4e2m1 *ref_data,
const int rows, const int cols) {
std::string test_file = prefix + "_test.txt";
std::string ref_file = prefix + "_ref.txt";
std::string diff_file = prefix + "_diff.txt";
std::ofstream test_out(test_file);
std::ofstream ref_out(ref_file);
std::ofstream diff_out(diff_file);
if (test_out.is_open() && ref_out.is_open() && diff_out.is_open()) {
for (int i = 0; i < rows; ++i) {
for (int j = 0; j < cols; j += 2) {
const int idx = i * cols + j;
double2 test_data_pair = cvt_fp4x2_to_double2(*reinterpret_cast<const fp4e2m1x2*>(&test_data[idx/2]));
double2 ref_data_pair = cvt_fp4x2_to_double2(*reinterpret_cast<const fp4e2m1x2*>(&ref_data[idx/2]));
for (int k = 0; k < 2; ++k) {
const double t = (k == 0 ? test_data_pair.x : test_data_pair.y);
const double r = (k == 0 ? ref_data_pair.x : ref_data_pair.y);
const int pos = idx + k;
test_out << "pos[" << pos << "] = " << t << std::endl;
ref_out << "pos[" << pos << "] = " << r << std::endl;
diff_out << "pos[" << pos << "] test=" << t << " ref=" << r
<< " abs_diff=" << fabs(t - r)
<< " rel_diff=" << (r == 0 ? 0.0 : fabs((t - r) / r)) << std::endl;
}
}
}
std::cout << "DEBUG: Dumped tensor data to files: " << test_file << ", " << ref_file << ", " << diff_file << std::endl;
} else {
std::cout << "WARNING: Could not open files for tensor data dump" << std::endl;
}
}
void print_detailed_tensor_comparison(const std::string& name,
const fp4e2m1 *test_data, const fp4e2m1 *ref_data,
const int rows, const int cols) {
printf("\n=== DETAILED COMPARISON for %s (%d×%d = %d elements) ===\n",
name.c_str(), rows, cols, rows * cols);
const int total_elements = rows * cols;
const int check_count = 128;
printf("--- FIRST %d ELEMENTS ---\n", check_count);
printf("Index | Test_Value | Ref_Value | Match\n");
printf("------|---------------|---------------|-------\n");
for (int i = 0; i < std::min(check_count, total_elements); ++i) {
double2 test_pair = cvt_fp4x2_to_double2(*reinterpret_cast<const fp4e2m1x2*>(&test_data[i/2]));
double2 ref_pair = cvt_fp4x2_to_double2(*reinterpret_cast<const fp4e2m1x2*>(&ref_data[i/2]));
double t = (i % 2 == 0) ? test_pair.x : test_pair.y;
double r = (i % 2 == 0) ? ref_pair.x : ref_pair.y;
bool match = (fabs(t - r) < 1e-6);
printf("%5d | %13.6f | %13.6f | %s\n", i, t, r, match ? "✓" : "✗");
}
if (total_elements > 2 * check_count) {
printf("\n--- LAST %d ELEMENTS ---\n", check_count);
printf("Index | Test_Value | Ref_Value | Match\n");
printf("------|---------------|---------------|-------\n");
for (int i = total_elements - check_count; i < total_elements; ++i) {
double2 test_pair = cvt_fp4x2_to_double2(*reinterpret_cast<const fp4e2m1x2*>(&test_data[i/2]));
double2 ref_pair = cvt_fp4x2_to_double2(*reinterpret_cast<const fp4e2m1x2*>(&ref_data[i/2]));
double t = (i % 2 == 0) ? test_pair.x : test_pair.y;
double r = (i % 2 == 0) ? ref_pair.x : ref_pair.y;
bool match = (fabs(t - r) < 1e-6);
printf("%5d | %13.6f | %13.6f | %s\n", i, t, r, match ? "✓" : "✗");
}
}
printf("==================================\n");
}
void compareResults_nvfp4(const Tensor &test,
const void *ref, const void *ref_t, const int rows, const int cols,
double atol = 1e-5, double rtol = 1e-8, bool if_on_gpus = true, bool dump_data = false) {
if (if_on_gpus) test.to_cpu();
const fp4e2m1 *test_data = test.rowwise_cpu_dptr<fp4e2m1>();
const fp4e2m1 *test_data_t = test.columnwise_cpu_dptr<fp4e2m1>();
const fp4e2m1 *ref_data = reinterpret_cast<const fp4e2m1*>(ref);
const fp4e2m1 *ref_data_t = reinterpret_cast<const fp4e2m1*>(ref_t);
// Print detailed element-by-element comparison
// print_detailed_tensor_comparison("output", test_data, ref_data, rows, cols);
// print_detailed_tensor_comparison("output_t", test_data_t, ref_data_t, cols, rows);
// Optionally dump tensor data to files for detailed analysis
if (dump_data) {
dump_nvfp4_tensor_data("output", test_data, ref_data, rows, cols);
dump_nvfp4_tensor_data("output_t", test_data_t, ref_data_t, cols, rows);
}
compare_nvfp4_tensors("output", test_data, ref_data, rows, cols, atol, rtol);
compare_nvfp4_tensors("output_t", test_data_t, ref_data_t, cols, rows, atol, rtol);
}
template <typename InputType>
void performTest(float (*OP)(const float),
const std::vector<size_t>& shape) {
using namespace test;
DType itype = TypeInfo<InputType>::dtype;
DType otype = DType::kFloat4E2M1;
const size_t rows = first_dimension(shape);
const size_t cols = last_dimension(shape);
// NVFP4 scale tensor dimensions (1x16 blocks, padded to the required alignment)
const std::array<size_t,4> scale_dims = get_scale_tensor_dims(rows, cols, 1, 16);
const std::array<size_t,4> scale_dims_t = get_scale_tensor_dims(cols, rows, 1, 16);
const size_t unpadded_blocks_Y = scale_dims[0];
const size_t unpadded_blocks_X = scale_dims[1];
const size_t blocks_Y = scale_dims[2];
const size_t blocks_X = scale_dims[3];
const size_t scales_stride = blocks_X;
const size_t unpadded_blocks_Y_t = scale_dims_t[0];
const size_t unpadded_blocks_X_t = scale_dims_t[1];
const size_t blocks_Y_t = scale_dims_t[2];
const size_t blocks_X_t = scale_dims_t[3];
const size_t scales_stride_t = blocks_X_t;
Tensor input("input", shape, itype);
Tensor output("output", shape, otype, true, true, NVTE_NVFP4_1D_SCALING);
std::unique_ptr<fp4e2m1x2[]> ref_output = std::make_unique<fp4e2m1x2[]>(rows * (cols / 2));
std::unique_ptr<fp4e2m1x2[]> ref_output_t = std::make_unique<fp4e2m1x2[]>(cols * (rows / 2));
std::unique_ptr<fp8e4m3[]> ref_scales = std::make_unique<fp8e4m3[]>(blocks_Y * blocks_X);
std::unique_ptr<fp8e4m3[]> ref_scales_t = std::make_unique<fp8e4m3[]>(blocks_Y_t * blocks_X_t);
fillCase<fp32>(&input, InputsFillCase::uniform);
// Find global amax
float amax = 0.0f;
const InputType* input_dptr = input.rowwise_cpu_dptr<InputType>();
for (size_t i = 0; i < rows; ++i) {
for (size_t j = 0; j < cols; ++j) {
const size_t idx = i * cols + j;
amax = fmaxf(amax, static_cast<float>(input_dptr[idx]));
}
}
// Set 2nd stage NVFP4 scaling factor
output.set_scale(amax);
bool use_2d_quantization = false;
compute_ref<InputType>(OP,
input.rowwise_cpu_dptr<InputType>(),
ref_output.get(),
ref_output_t.get(),
ref_scales.get(),
ref_scales_t.get(),
output.scale(),
rows,
cols,
scales_stride,
scales_stride_t,
use_2d_quantization);
QuantizationConfigWrapper quant_config;
// Initialize stochastic rounding
Tensor rng_state("rng_state", std::vector<size_t>{2}, DType::kInt64);
rng_state.rowwise_cpu_dptr<int64_t>()[0] = 123; // rng_seed
rng_state.rowwise_cpu_dptr<int64_t>()[1] = 321; // rng_sequence
rng_state.from_cpu();
quant_config.set_stochastic_rounding(false);
quant_config.set_rng_state(rng_state.data());
// Set 2D quantization based on compile-time flag
quant_config.set_nvfp4_2d_quantization(use_2d_quantization);
// Call appropriate function based on operation type
// Activation functions take 3 parameters (input, output, stream)
// nvte_quantize_v2 takes 4 parameters (input, output, quant_config, stream)
if (OP == &gelu) {
nvte_gelu(input.data(), output.data(), 0);
} else if (OP == &silu) {
nvte_silu(input.data(), output.data(), 0);
} else if (OP == &relu) {
nvte_relu(input.data(), output.data(), 0);
} else if (OP == &qgelu) {
nvte_qgelu(input.data(), output.data(), 0);
} else if (OP == &srelu) {
nvte_srelu(input.data(), output.data(), 0);
} else {
nvte_quantize_v2(input.data(), output.data(), quant_config, 0);
}
cudaDeviceSynchronize();
auto err = cudaGetLastError();
if (err != cudaSuccess) {
printf("DEBUG: CUDA error detected: %s\n", cudaGetErrorString(err));
}
ASSERT_EQ(err, cudaSuccess) << cudaGetErrorString(err);
const double atol = 0.05;
const double rtol = 0.1;
// Set dump_data=true to enable dumping tensor data to files for analysis
compareResults_nvfp4(output, ref_output.get(), ref_output_t.get(), rows, cols, atol, rtol, true, false);
size_t scale_mismatches_num = 0;
compare_scaling_factors<fp8e4m3>("scales", output.rowwise_cpu_scale_inv_ptr<fp8e4m3>(),
ref_scales.get(),
unpadded_blocks_Y, unpadded_blocks_X, scales_stride,
scale_mismatches_num);
compare_scaling_factors<fp8e4m3>("scales_t", output.columnwise_cpu_scale_inv_ptr<fp8e4m3>(),
ref_scales_t.get(),
unpadded_blocks_Y_t, unpadded_blocks_X_t, scales_stride_t,
scale_mismatches_num);
}
std::vector<std::vector<size_t>> tensor_dims = {
{32, 32},
{32, 64},
{64, 32},
{64, 96},
{128, 128},
{256, 256},
{512, 512},
{1024, 1024},
{2048, 2048},
{128, 256},
{8192, 128},
{2048, 160},
{8, 32, 1024},
{16, 8, 4, 512},
{1024, 16384},
{4096, 13312},
};
// Activation types to test
std::vector<ActivationType> Activation_types = {
ActivationType::Identity,
ActivationType::GeLU,
ActivationType::SiLU,
ActivationType::ReLU,
ActivationType::QGeLU,
ActivationType::SReLU,
};
} // namespace
class FusedCastTransposeNVFP4TestSuite : public ::testing::TestWithParam
<std::tuple<ActivationType,
std::vector<size_t>,
transformer_engine::DType>> {};
TEST_P(FusedCastTransposeNVFP4TestSuite, TestFusedCastTransposeNVFP4) {
// Skip tests for pre-Blackwell architectures
if (getDeviceComputeCapability() < blackwellComputeCapability) {
GTEST_SKIP();
}
using namespace transformer_engine;
using namespace test;
const ActivationType Act_type = std::get<0>(GetParam());
const auto tensor_dims = std::get<1>(GetParam());
const DType input_type = std::get<2>(GetParam());
// Skip tests if the input tensor is 1D
if (tensor_dims.size() < 2) {
GTEST_SKIP();
}
// Forward activations
auto OP = &identity;
switch (Act_type) {
case ActivationType::GeLU: OP = &gelu; break;
case ActivationType::SiLU: OP = &silu; break;
case ActivationType::ReLU: OP = &relu; break;
case ActivationType::QGeLU: OP = &qgelu; break;
case ActivationType::SReLU: OP = &srelu; break;
}
TRANSFORMER_ENGINE_TYPE_SWITCH_FP16_FP32_ONLY(input_type, InputType,
performTest<InputType>(OP, tensor_dims);
);
}
std::string to_string(const ActivationType Act_type) {
switch (Act_type) {
case ActivationType::Identity: return "CAST_ONLY";
case ActivationType::GeLU: return "GeLU";
case ActivationType::SiLU: return "SiLU";
case ActivationType::ReLU: return "ReLU";
case ActivationType::QGeLU: return "QGeLU";
case ActivationType::SReLU: return "SReLU";
default: return "";
}
}
INSTANTIATE_TEST_SUITE_P(
OperatorTest,
FusedCastTransposeNVFP4TestSuite,
::testing::Combine(
::testing::ValuesIn(Activation_types),
::testing::ValuesIn(tensor_dims),
::testing::Values(DType::kBFloat16)),
[](const testing::TestParamInfo<FusedCastTransposeNVFP4TestSuite::ParamType>& info) {
std::string name = to_string(std::get<0>(info.param));
const auto& shape = std::get<1>(info.param);
for ( const auto& s: shape) {
name += "X" + std::to_string(s);
}
name += "X" + test::typeName(std::get<2>(info.param));
return name;
});
......@@ -107,6 +107,10 @@ size_t DIVUP(const size_t &x, const size_t &y){
return (((x) + ((y)-1)) / (y));
}
size_t DIVUP_TO_MULTIPLE(const size_t &x, const size_t &y){
return DIVUP(x, y) * y;
}
struct scale_inv_meta {
std::vector<size_t> shape;
DType type;
......@@ -143,21 +147,71 @@ std::pair<scale_inv_meta, scale_inv_meta> get_scales(const NVTEShape& shape,
scale_inv_meta ret_rowwise, ret_colwise;
auto block_alignment = std::vector<size_t>{128ul, 4ul};
{
auto alignment = block_alignment[0];
auto scale_dim_0 = DIVUP(DIVUP(first_dim, static_cast<size_t>(1)), alignment) * alignment;
alignment = block_alignment[1];
auto scale_dim_1 = DIVUP(DIVUP(last_dim, static_cast<size_t>(32)), alignment) * alignment;
ret_rowwise.shape = {scale_dim_0, scale_dim_1};
const size_t block_size_X_rowwise = 32;
size_t scale_dim_Y_rowwise = DIVUP_TO_MULTIPLE(first_dim, scale_tensor_alignment_Y_rowwise);
size_t scale_dim_X_rowwise = DIVUP_TO_MULTIPLE(DIVUP(last_dim, block_size_X_rowwise), scale_tensor_alignment_X_rowwise);
ret_rowwise.shape = {scale_dim_Y_rowwise, scale_dim_X_rowwise};
const size_t block_size_Y_colwise = 32;
size_t scale_dim_Y_colwise = DIVUP_TO_MULTIPLE(DIVUP(first_dim, block_size_Y_colwise), scale_tensor_alignment_Y_colwise);
size_t scale_dim_X_colwise = DIVUP_TO_MULTIPLE(last_dim, scale_tensor_alignment_X_colwise);
ret_colwise.shape = {scale_dim_Y_colwise, scale_dim_X_colwise};
ret_rowwise.type = DType::kFloat8E8M0;
ret_rowwise.type_size_bits = typeToNumBits(DType::kFloat8E8M0);
ret_colwise.type = DType::kFloat8E8M0;
ret_colwise.type_size_bits = typeToNumBits(DType::kFloat8E8M0);
return {ret_rowwise, ret_colwise};
}
{
auto alignment = block_alignment[1];
auto scale_dim_0 = DIVUP(DIVUP(first_dim, static_cast<size_t>(32)), alignment) * alignment;
alignment = block_alignment[0];
auto scale_dim_1 = DIVUP(DIVUP(last_dim, static_cast<size_t>(1)), alignment) * alignment;
ret_colwise.shape = {scale_dim_0, scale_dim_1};
if (scaling_mode == NVTE_NVFP4_1D_SCALING) {
std::vector<size_t> shape_vec;
for (size_t i = 0; i < shape.ndim; ++i) {
shape_vec.push_back(shape.data[i]);
}
size_t first_dim = first_dimension(shape_vec);
size_t last_dim = last_dimension(shape_vec);
NVTE_CHECK(last_dim % 32 == 0);
NVTE_CHECK(first_dim % 32 == 0);
scale_inv_meta ret_rowwise, ret_colwise;
size_t scale_dim_Y = DIVUP_TO_MULTIPLE(first_dim, scale_tensor_alignment_Y_rowwise);
size_t scale_dim_X = DIVUP_TO_MULTIPLE(DIVUP(last_dim, 16lu), scale_tensor_alignment_X_rowwise);
ret_rowwise.shape = {scale_dim_Y, scale_dim_X};
size_t scale_dim_Y_t = DIVUP_TO_MULTIPLE(last_dim, scale_tensor_alignment_Y_rowwise);
size_t scale_dim_X_t = DIVUP_TO_MULTIPLE(DIVUP(first_dim, 16lu), scale_tensor_alignment_X_rowwise);
ret_colwise.shape = {scale_dim_Y_t, scale_dim_X_t};
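// Illustrative shapes (assuming the same 128-row / 4-column alignment constants as the
// MXFP8 path above): for a 256 x 1024 input, the rowwise NVFP4 scales pad to
// 256 x 64 (1024 / 16 = 64 blocks per row) and the columnwise scales to 1024 x 16
// (256 / 16 = 16 blocks per transposed row), both stored as FP8 E4M3.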
ret_rowwise.type = DType::kFloat8E4M3;
ret_rowwise.type_size_bits = typeToNumBits(DType::kFloat8E4M3);
ret_colwise.type = DType::kFloat8E4M3;
ret_colwise.type_size_bits = typeToNumBits(DType::kFloat8E4M3);
return {ret_rowwise, ret_colwise};
}
if (scaling_mode == NVTE_MXFP8_1D_SCALING) {
std::vector<size_t> shape_vec;
for (size_t i = 0; i < shape.ndim; ++i) {
shape_vec.push_back(shape.data[i]);
}
size_t first_dim = first_dimension(shape_vec);
size_t last_dim = last_dimension(shape_vec);
scale_inv_meta ret_rowwise, ret_colwise;
const size_t block_size_X_rowwise = 32;
size_t scale_dim_Y_rowwise = DIVUP_TO_MULTIPLE(first_dim, scale_tensor_alignment_Y_rowwise);
size_t scale_dim_X_rowwise = DIVUP_TO_MULTIPLE(DIVUP(last_dim, block_size_X_rowwise), scale_tensor_alignment_X_rowwise);
ret_rowwise.shape = {scale_dim_Y_rowwise, scale_dim_X_rowwise};
const size_t block_size_Y_colwise = 32;
size_t scale_dim_Y_colwise = DIVUP_TO_MULTIPLE(DIVUP(first_dim, block_size_Y_colwise), scale_tensor_alignment_Y_colwise);
size_t scale_dim_X_colwise = DIVUP_TO_MULTIPLE(last_dim, scale_tensor_alignment_X_colwise);
ret_colwise.shape = {scale_dim_Y_colwise, scale_dim_X_colwise};
ret_rowwise.type = DType::kFloat8E8M0;
ret_colwise.type = DType::kFloat8E8M0;
ret_rowwise.type_size_bits = typeToNumBits(DType::kFloat8E8M0);
......@@ -176,13 +230,13 @@ std::pair<scale_inv_meta, scale_inv_meta> get_scales(const NVTEShape& shape,
scale_inv_meta ret_rowwise, ret_colwise;
{
auto scale_dim_0 = DIVUP(first_dim, static_cast<size_t>(128));
auto scale_dim_1 = DIVUP(DIVUP(last_dim, static_cast<size_t>(128)), 4) * 4;
size_t scale_dim_0 = DIVUP(first_dim, 128lu);
size_t scale_dim_1 = DIVUP(DIVUP(last_dim, 128lu), 4) * 4;
ret_rowwise.shape = {scale_dim_0, scale_dim_1};
}
{
auto scale_dim_0 = DIVUP(last_dim, static_cast<size_t>(128));
auto scale_dim_1 = DIVUP(DIVUP(first_dim, static_cast<size_t>(128)), 4) * 4;
size_t scale_dim_0 = DIVUP(last_dim, 128lu);
size_t scale_dim_1 = DIVUP(DIVUP(first_dim, 128lu), 4) * 4;
ret_colwise.shape = {scale_dim_0, scale_dim_1};
}
ret_rowwise.type = DType::kFloat32;
......@@ -202,13 +256,13 @@ std::pair<scale_inv_meta, scale_inv_meta> get_scales(const NVTEShape& shape,
scale_inv_meta ret_rowwise, ret_colwise;
{
auto scale_dim_0 = DIVUP(last_dim, static_cast<size_t>(128));
auto scale_dim_1 = DIVUP(first_dim, 4) * 4;
size_t scale_dim_0 = DIVUP(last_dim, 128lu);
size_t scale_dim_1 = DIVUP(first_dim, 4) * 4;
ret_rowwise.shape = {scale_dim_0, scale_dim_1};
}
{
auto scale_dim_0 = DIVUP(first_dim, static_cast<size_t>(128));
auto scale_dim_1 = DIVUP(last_dim, 4) * 4;
size_t scale_dim_0 = DIVUP(first_dim, 128lu);
size_t scale_dim_1 = DIVUP(last_dim, 4) * 4;
ret_colwise.shape = {scale_dim_0, scale_dim_1};
}
ret_rowwise.type = DType::kFloat32;
......@@ -250,14 +304,15 @@ Tensor::Tensor(const std::string& name,
NVTEShape columnwise_shape = {};
std::vector<size_t> columnwise_shape_vec;
if (scaling_mode == NVTE_DELAYED_TENSOR_SCALING || scaling_mode == NVTE_BLOCK_SCALING_1D || scaling_mode == NVTE_BLOCK_SCALING_2D) {
if (scaling_mode == NVTE_DELAYED_TENSOR_SCALING
|| scaling_mode == NVTE_BLOCK_SCALING_1D || scaling_mode == NVTE_BLOCK_SCALING_2D) {
// Transpose when tensor scaling
columnwise_shape_vec.emplace_back(shape.data[shape.ndim - 1]);
for (size_t i = 0; i < shape.ndim - 1; ++i) {
columnwise_shape_vec.emplace_back(shape.data[i]);
}
} else {
// Same shape for MX
// Same shape for MX and NVFP4
for (size_t i = 0; i < shape.ndim; ++i) {
columnwise_shape_vec.emplace_back(shape.data[i]);
}
......@@ -283,10 +338,13 @@ Tensor::Tensor(const std::string& name,
std::fill_n(cpu_data_columnwise_.get(), total_size, 0);
}
}
tensor_.set_rowwise_data(dptr_rowwise, type, shape);
tensor_.set_columnwise_data(dptr_columnwise, type, columnwise_shape);
if (isFp8Type(type)) {
const DType rowwise_type = (scaling_mode == NVTE_NVFP4_1D_SCALING) ? DType::kFloat4E2M1 : type;
const DType colwise_type = (scaling_mode == NVTE_NVFP4_1D_SCALING) ? DType::kFloat4E2M1 : type;
tensor_.set_rowwise_data(dptr_rowwise, rowwise_type, shape);
tensor_.set_columnwise_data(dptr_columnwise, colwise_type, columnwise_shape);
if (isFp8Type(type) || isFp4Type(type)) {
if (scaling_mode == NVTE_DELAYED_TENSOR_SCALING) {
cudaMalloc((void**)&amax, sizeof(float)); // NOLINT(*)
cudaMemset(amax, 0, sizeof(float));
......@@ -310,8 +368,14 @@ Tensor::Tensor(const std::string& name,
std::fill_n(columnwise_scale_inv_cpu_data_.get(), sizeof(float), 0);
}
} else {
auto [rowwise_scale_meta, colwise_scale_meta] =
get_scales(normalized_shape, tensor_.scaling_mode());
if (scaling_mode == NVTE_NVFP4_1D_SCALING) {
// Used for NVFP4 second stage scaling
cudaMalloc((void**)&scale, sizeof(float)); // NOLINT(*)
cudaMemset(scale, 0, sizeof(float));
scale_cpu_data_ = std::make_shared<float>(0);
tensor_.set_scale(scale, DType::kFloat32, std::vector<size_t>{1});
}
auto [rowwise_scale_meta, colwise_scale_meta] = get_scales(normalized_shape, tensor_.scaling_mode());
auto rowwise_scale_size = rowwise_scale_meta.bytes();
auto columnwise_scale_size = colwise_scale_meta.bytes();
auto scale_shape = rowwise_scale_meta.shape;
......@@ -346,13 +410,16 @@ void Tensor::to_cpu() const {
cudaMemcpyDeviceToHost);
}
if (columnwise_) {
const DType colwise_type = tensor_.dtype();
const size_t colwise_size = bytes(s, colwise_type);
cudaMemcpy(cpu_data_columnwise_.get(),
tensor_.get_columnwise_data().data_ptr,
size,
colwise_size,
cudaMemcpyDeviceToHost);
}
if (isFp8Type(dtype())) {
if (tensor_.scaling_mode() == NVTE_DELAYED_TENSOR_SCALING) {
if (isFp8Type(dtype()) || isFp4Type(dtype())) {
if ((tensor_.scaling_mode() == NVTE_DELAYED_TENSOR_SCALING)) {
if (tensor_.amax() != nullptr){
cudaMemcpy(amax_cpu_data_.get(),
tensor_.amax(),
......@@ -364,8 +431,7 @@ void Tensor::to_cpu() const {
sizeof(float),
cudaMemcpyDeviceToHost);
}
auto [rowwise_scale_meta, colwise_scale_meta] =
get_scales(s, tensor_.scaling_mode());
auto [rowwise_scale_meta, colwise_scale_meta] = get_scales(s, tensor_.scaling_mode());
if (rowwise_) {
auto scale_size = rowwise_scale_meta.bytes();
cudaMemcpy(rowwise_scale_inv_cpu_data_.get(),
......@@ -394,15 +460,15 @@ void Tensor::from_cpu() const {
cudaMemcpy(tensor_.get_columnwise_data().data_ptr, cpu_data_columnwise_.get(), size,
cudaMemcpyHostToDevice);
}
if (isFp8Type(dtype())) {
if (tensor_.scaling_mode() == NVTE_DELAYED_TENSOR_SCALING) {
if (isFp8Type(dtype()) || isFp4Type(dtype())) {
if ((tensor_.scaling_mode() == NVTE_DELAYED_TENSOR_SCALING)
|| (tensor_.scaling_mode() == NVTE_NVFP4_1D_SCALING)) {
if (tensor_.amax() != nullptr){
cudaMemcpy(tensor_.amax(), amax_cpu_data_.get(), sizeof(float), cudaMemcpyHostToDevice);
}
cudaMemcpy(tensor_.scale(), scale_cpu_data_.get(), sizeof(float), cudaMemcpyHostToDevice);
}
auto [rowwise_scale_meta, colwise_scale_meta] =
get_scales(s, tensor_.scaling_mode());
auto [rowwise_scale_meta, colwise_scale_meta] = get_scales(s, tensor_.scaling_mode());
if (rowwise_) {
auto scale_size = rowwise_scale_meta.bytes();
cudaMemcpy(tensor_.get_rowwise_scale_inv().data_ptr,
......@@ -419,7 +485,7 @@ void Tensor::from_cpu() const {
}
void Tensor::set_scale(float scale) {
if (isFp8Type(dtype())) {
if (isFp8Type(dtype()) || isFp4Type(dtype())) {
NVTE_CHECK(scale_cpu_data_);
if (tensor_.scaling_mode() == NVTE_DELAYED_TENSOR_SCALING) {
*scale_cpu_data_ = scale;
......@@ -429,7 +495,7 @@ void Tensor::set_scale(float scale) {
}
void Tensor::set_scale_inv(float scale_inv) {
if (isFp8Type(dtype())) {
if (isFp8Type(dtype()) || isFp4Type(dtype())) {
if (rowwise_) {
NVTE_CHECK(rowwise_scale_inv_cpu_data_);
}
......@@ -437,8 +503,7 @@ void Tensor::set_scale_inv(float scale_inv) {
NVTE_CHECK(columnwise_scale_inv_cpu_data_);
}
auto [rowwise_scale_meta, colwise_scale_meta] =
get_scales(tensor_.shape(), tensor_.scaling_mode());
auto [rowwise_scale_meta, colwise_scale_meta] = get_scales(tensor_.shape(), tensor_.scaling_mode());
if (rowwise_) {
auto num_scales = product(rowwise_scale_meta.shape);
if (num_scales == 1) {
......@@ -468,7 +533,8 @@ void Tensor::set_scale_inv(float scale_inv) {
}
void Tensor::shareFP8Meta(const Tensor &other) {
if (isFp8Type(dtype()) && isFp8Type(other.dtype())) {
if ((isFp8Type(dtype()) && isFp8Type(other.dtype()))
|| (isFp4Type(dtype()) && isFp4Type(other.dtype()))) {
auto new_tensor = TensorWrapper(other.tensor_.scaling_mode());
auto my_rowwise_data = tensor_.get_rowwise_data();
new_tensor.set_rowwise_data(my_rowwise_data.data_ptr, static_cast<DType>(my_rowwise_data.dtype),
......@@ -681,12 +747,30 @@ void compareResults(const std::string &name, const uint8_t *test, const uint8_t
}
}
void compare_e8m0_scaling_factors(const std::string &name, const uint8_t *test, const uint8_t *ref,
template <typename T>
struct CastToType;
template <>
struct CastToType<uint8_t> {
using type = int;
};
template <>
struct CastToType<fp8e4m3> {
using type = float;
};
template <typename T>
void compare_scaling_factors(const std::string &name, const T *test, const T *ref,
const size_t row_blocks, const size_t col_blocks, const size_t stride,
size_t& mismatches_num, const size_t atol,
const double abs_tolerable_mismatches_limit,
const double rel_tolerable_mismatches_limit)
{
using UpcastType = typename CastToType<T>::type;
auto [atol_fp8e4m3, rtol_fp8e4m3] = getTolerances(DType::kFloat8E4M3);
const size_t N = row_blocks * col_blocks;
const size_t tolerable_mismatches_limit = std::min(abs_tolerable_mismatches_limit,
std::floor(N * rel_tolerable_mismatches_limit));
......@@ -696,11 +780,31 @@ void compare_e8m0_scaling_factors(const std::string &name, const uint8_t *test,
for (int i = 0; i < row_blocks; ++i) {
for (int j = 0; j < col_blocks; ++j) {
const int idx = i * stride + j;
const int test_val = static_cast<int>(test[idx]);
const int ref_val = static_cast<int>(ref[idx]);
const int abs_delta = std::abs(test_val - ref_val);
float t, r;
bool assertion = false;
if (abs_delta > atol) {
if (std::is_same<T, uint8_t>::value) {
t = static_cast<float>(test[idx]);
r = static_cast<float>(ref[idx]);
assertion = std::abs(t - r) > atol;
} else {
t = static_cast<float>(*reinterpret_cast<const fp8e4m3*>(&test[idx]));
r = static_cast<float>(*reinterpret_cast<const fp8e4m3*>(&ref[idx]));
const bool mismatch = (fabs(t - r) > atol_fp8e4m3)
&& (r == 0 || fabs((t - r) / r) > rtol_fp8e4m3);
if (mismatch) {
/* Check whether this is just round-to-nearest choosing a different
side of the real value */
const double mean = (t + r) / 2;
const double mean_p = mean >= 0 ? mean * (1 + 1e-6) : mean * (1 - 1e-6);
const double mean_m = mean >= 0 ? mean * (1 - 1e-6) : mean * (1 + 1e-6);
const double cast_mean_p = static_cast<double>(static_cast<T>(mean_p));
const double cast_mean_m = static_cast<double>(static_cast<T>(mean_m));
assertion = !(cast_mean_m == std::min(t,r) && cast_mean_p == std::max(t,r));
}
}
if (assertion) {
mismatches_num++;
mismatch_indices.push_back(idx);
}
......@@ -708,8 +812,8 @@ void compare_e8m0_scaling_factors(const std::string &name, const uint8_t *test,
std::cout << "Error in " << name << std::endl;
for (const int index : mismatch_indices) {
std::cout << "Mismatch at (" << index << "):"
<< static_cast<int>(test[index]) << " vs "
<< static_cast<int>(ref[index]) << std::endl;
<< static_cast<UpcastType>(test[index]) << " vs "
<< static_cast<UpcastType>(ref[index]) << std::endl;
}
GTEST_FAIL() << mismatches_num << " mismatch(es), which is more than the tolerable mismatch limit of "
<< tolerable_mismatches_limit << ".";
......@@ -718,6 +822,22 @@ void compare_e8m0_scaling_factors(const std::string &name, const uint8_t *test,
}
}
// Instantiate templates
template
void compare_scaling_factors<uint8_t>(const std::string &name, const uint8_t *test, const uint8_t *ref,
const size_t row_blocks, const size_t col_blocks, const size_t stride,
size_t& mismatches_num, const size_t atol,
const double abs_tolerable_mismatches_limit,
const double rel_tolerable_mismatches_limit);
template
void compare_scaling_factors<fp8e4m3>(const std::string &name, const fp8e4m3 *test, const fp8e4m3 *ref,
const size_t row_blocks, const size_t col_blocks, const size_t stride,
size_t& mismatches_num, const size_t atol,
const double abs_tolerable_mismatches_limit,
const double rel_tolerable_mismatches_limit);
std::pair<double, double> getTolerances(const DType type) {
switch(type) {
case DType::kFloat32:
......@@ -873,6 +993,10 @@ bool isFp8Type(DType type) {
return type == DType::kFloat8E4M3 || type == DType::kFloat8E5M2 || type == DType::kFloat8E8M0;
}
bool isFp4Type(DType type) {
return type == DType::kFloat4E2M1;
}
int32_t getDeviceComputeCapability() {
cudaDeviceProp deviceProp;
cudaGetDeviceProperties(&deviceProp, 0);
......@@ -894,7 +1018,8 @@ std::array<size_t, 4> get_scale_tensor_dims(const size_t rows,
const size_t cols,
const size_t block_size_rows,
const size_t block_size_cols) {
const bool is_rowwise = (block_size_rows == 1) && (block_size_cols == 32);
const bool is_rowwise = (block_size_rows == 1)
&& ((block_size_cols == 32) || (block_size_cols == 16));
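// A block size of 32 corresponds to MXFP8 rowwise scaling; 16 is assumed here for NVFP4,
// which scales along 16-element blocks.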
const size_t alignment_Y = is_rowwise
? scale_tensor_alignment_Y_rowwise
......
......@@ -62,6 +62,8 @@ using fp8e5m2 = __nv_fp8_e5m2;
using fp8e8m0 = uint8_t;
#if FP4_TYPE_SUPPORTED
using fp4e2m1 = __nv_fp4_e2m1;
using fp4e2m1x2 = __nv_fp4x2_e2m1;
using fp4e2m1x4 = __nv_fp4x4_e2m1;
#endif
template <typename T>
......@@ -223,7 +225,9 @@ class Tensor {
float scale() const {
if(scale_cpu_data_) {
NVTE_CHECK(tensor_.scaling_mode() == NVTE_DELAYED_TENSOR_SCALING, "Invalid scaling_mode!");
NVTE_CHECK((tensor_.scaling_mode() == NVTE_DELAYED_TENSOR_SCALING)
|| (tensor_.scaling_mode() == NVTE_NVFP4_1D_SCALING),
"Invalid scaling_mode!");
to_cpu();
return *scale_cpu_data_;
} else {
......@@ -237,6 +241,8 @@ class Tensor {
NVTE_CHECK(TypeInfo<T>::dtype == DType::kFloat32, "Invalid type!");
} else if (tensor_.scaling_mode() == NVTE_BLOCK_SCALING_1D || tensor_.scaling_mode() == NVTE_BLOCK_SCALING_2D) {
NVTE_CHECK(TypeInfo<T>::dtype == DType::kFloat32, "Invalid type!");
} else if (tensor_.scaling_mode() == NVTE_NVFP4_1D_SCALING) {
NVTE_CHECK(TypeInfo<T>::dtype == DType::kFloat8E4M3, "Invalid type!");
} else {
NVTE_CHECK(TypeInfo<T>::dtype == DType::kByte, "Invalid type!");
}
......@@ -250,6 +256,8 @@ class Tensor {
NVTE_CHECK(TypeInfo<T>::dtype == DType::kFloat32, "Invalid type!");
} else if (tensor_.scaling_mode() == NVTE_BLOCK_SCALING_1D || tensor_.scaling_mode() == NVTE_BLOCK_SCALING_2D) {
NVTE_CHECK(TypeInfo<T>::dtype == DType::kFloat32, "Invalid type!");
} else if (tensor_.scaling_mode() == NVTE_NVFP4_1D_SCALING) {
NVTE_CHECK(TypeInfo<T>::dtype == DType::kFloat8E4M3, "Invalid type!");
} else {
NVTE_CHECK(TypeInfo<T>::dtype == DType::kByte, "Invalid type!");
}
......@@ -304,10 +312,10 @@ constexpr uint32_t FP32_EXPONENT_BIAS = 127;
constexpr uint32_t FP32_MANTISSA_BITS = 23;
// [128,4] rowwise and [4,128] colwise alignment requirement
constexpr size_t scale_tensor_alignment_X_rowwise = 4;
constexpr size_t scale_tensor_alignment_Y_rowwise = 128;
constexpr size_t scale_tensor_alignment_X_colwise = 128;
constexpr size_t scale_tensor_alignment_X_rowwise = 4;
constexpr size_t scale_tensor_alignment_Y_colwise = 4;
constexpr size_t scale_tensor_alignment_X_colwise = 128;
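// divide_round_up(N, M) computes ceil(N / M); e.g. divide_round_up(33, 16) == 3.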
inline size_t divide_round_up(const size_t N, const size_t M) {
return (N - 1 + M) / M;
......@@ -456,13 +464,15 @@ void compareResults(const std::string &name, const float test, const float ref,
double atol = 1e-5, double rtol = 1e-8);
void compareResults(const std::string &name, const uint8_t *test, const uint8_t *ref,
size_t N, float mismatch_rate_tol = 0.);
void compare_e8m0_scaling_factors(const std::string &name, const uint8_t *test, const uint8_t *ref,
template <typename T>
void compare_scaling_factors(const std::string &name, const T *test, const T *ref,
const size_t row_blocks, const size_t col_blocks, const size_t stride,
size_t& mismatches_num,
const size_t scale_diff_abs_tolerance = 0,
const double abs_tolerable_mismatches_limit = 0,
const double rel_tolerable_mismatches_limit = 0);
std::array<size_t, 4> get_scale_tensor_dims(const size_t rows, const size_t cols,
const size_t block_size_rows, const size_t block_size_cols);
......@@ -484,6 +494,7 @@ const std::string& caseName(InputsFillCase type);
extern std::vector<DType> all_fp_types;
bool isFp8Type(DType type);
bool isFp4Type(DType type);
int32_t getDeviceComputeCapability();
constexpr int32_t hopperComputeCapability = 90;
......@@ -561,7 +572,7 @@ constexpr int32_t blackwellComputeCapability = 100;
SWITCH_FP4_TYPE_HANDLE(type, __VA_ARGS__) \
default: \
printf("dtype: %d\n", static_cast<int>(dtype)); \
NVTE_ERROR("Invalid type MARKED TEST."); \
NVTE_ERROR("Invalid type."); \
}
#define TRANSFORMER_ENGINE_TYPE_SWITCH_FP8_ONLY(dtype, type, ...) \
......@@ -580,7 +591,7 @@ constexpr int32_t blackwellComputeCapability = 100;
} \
break; \
default: \
NVTE_ERROR("Invalid type MARKED TEST 2."); \
NVTE_ERROR("Invalid type."); \
}
#define TRANSFORMER_ENGINE_TYPE_SWITCH_FP4_ONLY(dtype, type, ...) \
......@@ -588,7 +599,7 @@ constexpr int32_t blackwellComputeCapability = 100;
using namespace transformer_engine; \
SWITCH_FP4_HANDLE(type, __VA_ARGS__) \
default: \
NVTE_ERROR("Invalid type MARKED TEST 3."); \
NVTE_ERROR("Invalid type."); \
}
#define TRANSFORMER_ENGINE_TYPE_SWITCH_FP16_FP32_ONLY(dtype, type, ...) \
......@@ -613,5 +624,5 @@ constexpr int32_t blackwellComputeCapability = 100;
} \
break; \
default: \
NVTE_ERROR("Invalid type MARKED TEST 4."); \
NVTE_ERROR("Invalid type."); \
}
......@@ -9,6 +9,7 @@ import datetime
import os
import sys
from functools import wraps
import math
import transformer_engine.pytorch as te
import torch
......@@ -20,10 +21,15 @@ from transformer_engine.common.recipe import (
DelayedScaling,
Float8CurrentScaling,
Float8BlockScaling,
NVFP4BlockScaling,
Format,
Recipe,
QParams,
)
from transformer_engine.pytorch.tensor.float8_tensor import Float8CurrentScalingQuantizer
from transformer_engine.pytorch.tensor.nvfp4_tensor import NVFP4Quantizer
from transformer_engine.pytorch.constants import NVFP4_BLOCK_SCALING_SIZE
from transformer_engine.pytorch.distributed import gather_along_first_dim
from run_layer_with_overlap import _compare_tensors
SEQ_LEN, BATCH_SIZE = 16, 16
......@@ -47,6 +53,14 @@ if os.environ.get("NVTE_TEST_NVINSPECT_ENABLED", False):
)
def nvfp4_vanilla():
nvfp4_recipe = NVFP4BlockScaling()
nvfp4_recipe.fp4_quant_fwd_inp = QParams()
nvfp4_recipe.fp4_quant_fwd_weight = QParams()
nvfp4_recipe.fp4_quant_bwd_grad = QParams()
return nvfp4_recipe
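# Note: default-constructed QParams() is assumed to disable both the random Hadamard
# transform and 2D quantization, so this recipe exercises plain per-block NVFP4 scaling.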
# Quantization recipe setup
def quantization_recipe() -> Recipe:
if QUANTIZATION == "fp8":
......@@ -59,6 +73,8 @@ def quantization_recipe() -> Recipe:
return Float8CurrentScaling()
if QUANTIZATION == "fp8_block_scaling":
return Float8BlockScaling()
if QUANTIZATION == "nvfp4":
return nvfp4_vanilla()
return te.fp8.get_default_fp8_recipe()
......@@ -96,10 +112,14 @@ def main(argv=None, namespace=None):
# Quantization scheme
QUANTIZATION = args.quantization
global SEQ_LEN, BATCH_SIZE, HIDDEN_SIZE
if QUANTIZATION in ("fp8", "mxfp8"):
if QUANTIZATION in ("fp8", "mxfp8", "nvfp4"):
SEQ_LEN = 32
BATCH_SIZE = 32
HIDDEN_SIZE = 128
# For fp8 block scaling, block size is 128,
# and to make low precision TP work, input tensor
# must be 128x128 divisible to be eligible for
# low precision All-Gather when needed
elif QUANTIZATION == "fp8_block_scaling":
SEQ_LEN = 128
BATCH_SIZE = 128
......@@ -107,6 +127,7 @@ def main(argv=None, namespace=None):
test_dict = [
test_quantizer,
test_quantized_all_gather,
test_linear,
test_layernorm,
test_layernorm_linear,
......@@ -176,6 +197,9 @@ def _get_tolerances(dtype):
# row parallel & sequence parallel, because we do the all_gather in backward pass
if QUANTIZATION == "fp8_cs":
return {"rtol": 0.4, "atol": 0.25}
elif QUANTIZATION == "nvfp4":
# TODO(zhongboz): investigate why the tolerance is so large
return {"rtol": 0.125, "atol": 0.12}
elif QUANTIZATION is not None:
return {"rtol": 0.125, "atol": 0.0625}
......@@ -326,24 +350,36 @@ def _alloc_main_grad(model_single_node, model_distributed):
###############################################
# Quantizer #
###############################################
def _construct_quantizer(quantizer_class, fp8_dtype, device, tp_group, tp_size):
def _construct_quantizer(quantizer_class, low_precision_dtype, device, tp_group, tp_size):
"""
quantizer is the reference quantizer on a single GPU.
quantizer_dist is the distributed quantizer to be tested on multiple GPUs.
"""
if quantizer_class == Float8CurrentScalingQuantizer:
quantizer_dist = quantizer_class(
fp8_dtype=fp8_dtype,
fp8_dtype=low_precision_dtype,
device=device,
with_amax_reduction=True,
amax_reduction_group=tp_group,
)
quantizer = quantizer_class(
fp8_dtype=fp8_dtype,
fp8_dtype=low_precision_dtype,
device=device,
with_amax_reduction=False,
)
return quantizer, quantizer_dist
elif quantizer_class == NVFP4Quantizer:
quantizer_dist = quantizer_class(
fp4_dtype=low_precision_dtype,
with_amax_reduction=True,
amax_reduction_group=tp_group,
)
quantizer = quantizer_class(
fp4_dtype=low_precision_dtype,
with_amax_reduction=False,
amax_reduction_group=None,
)
return quantizer, quantizer_dist
else:
raise ValueError(f"Unsupported quantizer class: {quantizer_class}")
......@@ -414,6 +450,194 @@ def test_quantizer():
_test_quantizer(input_dtype, fp8_dtype)
############################################
# Quantized All-Gather #
############################################
def _ref_zero_padding_scale_inv(scale_inv, unpadded_shape):
"""
Zero-pad the scale_inv.
scale_inv has the padded shape but is not yet zero padded;
unpadded_shape is the original shape before padding.
"""
dim0, dim1 = scale_inv.shape
unpadded_dim0, unpadded_dim1 = unpadded_shape
pad_dim0 = (128 - unpadded_dim0 % 128) % 128
pad_dim1 = (4 - unpadded_dim1 % 4) % 4
new_dim0 = unpadded_dim0 + pad_dim0
new_dim1 = unpadded_dim1 + pad_dim1
assert dim0 == new_dim0
assert dim1 == new_dim1
# return input if no padding is needed
if pad_dim0 == 0 and pad_dim1 == 0:
return scale_inv
# unpad first to remove random bits from torch empty
scale_inv = scale_inv[:unpadded_dim0, :unpadded_dim1].contiguous()
# using torch padding
new_scale_inv = torch.nn.functional.pad(
scale_inv, (0, pad_dim1, 0, pad_dim0), mode="constant", value=0
)
assert new_scale_inv.shape == (new_dim0, new_dim1)
return new_scale_inv
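# Example with hypothetical shapes: an unpadded scale_inv of (100, 6) is expected to be
# allocated as (128, 8); this helper zeroes the trailing 28 rows and 2 columns.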
def _get_unpadded_scale_inv_shape(input_shape, quantizer_cls, columnwise):
"""
Calculate the unpadded shape of the scale_inv tensor.
"""
M = math.prod(input_shape[:-1])
K = input_shape[-1]
if quantizer_cls == NVFP4Quantizer:
if columnwise:
outer = K
inner = math.ceil(M / NVFP4_BLOCK_SCALING_SIZE)
return (outer, inner)
else:
outer = M
inner = math.ceil(K / NVFP4_BLOCK_SCALING_SIZE)
return (outer, inner)
else:
raise ValueError(f"Unsupported quantizer class: {quantizer_cls}")
@run_distributed_test()
def _test_quantized_all_gather(input_dtype, low_precision_dtype, quantizer_cls):
"""Test the quantizer under distributed settings.
Args:
input_dtype (torch.dtype): The data type of the input.
low_precision_dtype (tex.DType): The data type of the low precision, can be fp4 or fp8.
"""
M, N = WORLD_SIZE * BATCH_SIZE, HIDDEN_SIZE // 2
# high precision input
x_hp_cpu = torch.randn((M, N), device="cpu").to(input_dtype)
# set one element of the input to a very large value, which doesn't live in rank 0 after the split
# to test the amax reduction on purpose
# x_hp_cpu[M - 1, N - 1] = 1e4
# get the unpadded shapes
unpadded_rowwise_scale_inv_shape = _get_unpadded_scale_inv_shape((M, N), quantizer_cls, False)
unpadded_columnwise_scale_inv_shape = _get_unpadded_scale_inv_shape((M, N), quantizer_cls, True)
# rank 0 takes the full copy and quantize with GPU 0 for verification
if WORLD_RANK == 0:
x_hp_rank0 = x_hp_cpu.clone().detach().requires_grad_(True).to("cuda")
x_hp_local_rank = _shard_tensor(x_hp_cpu, WORLD_SIZE, 0)[WORLD_RANK]
# Create quantizers
quantizer, quantizer_dist = _construct_quantizer(
quantizer_cls, low_precision_dtype, x_hp_local_rank.device, NCCL_WORLD, WORLD_SIZE
)
# quantize the entire input
if WORLD_RANK == 0:
x_low_precision_single = quantizer(x_hp_rank0)
# run all-gather with a quantizer as input for quantized all-gather
x_low_precision_total, _ = gather_along_first_dim(
x_hp_local_rank, NCCL_WORLD, async_op=False, quantizer=quantizer_dist
)
# check the outputs
if WORLD_RANK == 0:
# assert all data and scale_inv are the same
torch.testing.assert_close(
x_low_precision_single._rowwise_data,
x_low_precision_total._rowwise_data,
rtol=0.0,
atol=0.0,
)
# check the rowwise scale without any padding
unpad_dim0, unpad_dim1 = unpadded_rowwise_scale_inv_shape
unpadded_rowwise_scale_inv_ref = x_low_precision_single._rowwise_scale_inv[
:unpad_dim0, :unpad_dim1
]
unpadded_rowwise_scale_inv = x_low_precision_total._rowwise_scale_inv[
:unpad_dim0, :unpad_dim1
]
torch.testing.assert_close(
unpadded_rowwise_scale_inv_ref,
unpadded_rowwise_scale_inv,
rtol=0.0,
atol=0.0,
)
torch.testing.assert_close(
_ref_zero_padding_scale_inv(
x_low_precision_single._rowwise_scale_inv, unpadded_rowwise_scale_inv_shape
),
_ref_zero_padding_scale_inv(
x_low_precision_total._rowwise_scale_inv, unpadded_rowwise_scale_inv_shape
),
rtol=0.0,
atol=0.0,
)
torch.testing.assert_close(
x_low_precision_single._columnwise_data,
x_low_precision_total._columnwise_data,
rtol=0.0,
atol=0.0,
)
unpad_dim0, unpad_dim1 = unpadded_columnwise_scale_inv_shape
unpadded_columnwise_scale_inv_ref = x_low_precision_single._columnwise_scale_inv[
:unpad_dim0, :unpad_dim1
]
unpadded_columnwise_scale_inv = x_low_precision_total._columnwise_scale_inv[
:unpad_dim0, :unpad_dim1
]
torch.testing.assert_close(
unpadded_columnwise_scale_inv_ref,
unpadded_columnwise_scale_inv,
rtol=0.0,
atol=0.0,
)
torch.testing.assert_close(
_ref_zero_padding_scale_inv(
x_low_precision_single._columnwise_scale_inv, unpadded_columnwise_scale_inv_shape
),
_ref_zero_padding_scale_inv(
x_low_precision_total._columnwise_scale_inv, unpadded_columnwise_scale_inv_shape
),
rtol=0.0,
atol=0.0,
)
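# Zero tolerance is used above because, with amax reduction enabled, the distributed
# quantizer is expected to produce bit-identical data and scales to the single-GPU reference.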
def test_quantized_all_gather():
"""
Run quantized all-gather tests with various configurations.
"""
# skip this test for other quantization schemes
is_nvfp4 = QUANTIZATION == "nvfp4"
# add other recipes for testing if needed
if not is_nvfp4:
return
input_dtypes = [torch.bfloat16]
fp4_dtype = [tex.DType.kFloat4E2M1]
fp8_dtype = [tex.DType.kFloat8E4M3, tex.DType.kFloat8E5M2]
quantizer_cls_nvfp4 = [NVFP4Quantizer]
# add FP8 quantizers if needed
quantizer_cls_fp8 = []
low_precision_dtypes = fp4_dtype if is_nvfp4 else fp8_dtype
quantizer_cls_list = quantizer_cls_nvfp4 if is_nvfp4 else quantizer_cls_fp8
for quantizer_cls in quantizer_cls_list:
for input_dtype in input_dtypes:
for low_precision_dtype in low_precision_dtypes:
_test_quantized_all_gather(input_dtype, low_precision_dtype, quantizer_cls)
############################################
# Linear #
############################################
......@@ -514,10 +738,11 @@ def test_linear():
{"init_method": _constant},
{"fuse_wgrad_accumulation": True},
{"return_bias": True},
{"params_dtype": torch.float16},
{"params_dtype": torch.float16 if QUANTIZATION != "nvfp4" else torch.bfloat16},
{"delay_wgrad_compute": True},
{"save_original_input": True},
]
for kwargs in kwargs_list:
if kwargs.get("save_original_input", False) and QUANTIZATION == "fp8":
continue
......@@ -693,11 +918,12 @@ def test_layernorm_linear():
{"init_method": _constant},
{"fuse_wgrad_accumulation": True},
{"return_bias": True},
{"params_dtype": torch.float16},
{"params_dtype": torch.float16 if QUANTIZATION != "nvfp4" else torch.bfloat16},
{"zero_centered_gamma": False},
{"return_layernorm_output": True},
{"delay_wgrad_compute": True},
]
for kwargs in kwargs_list:
for parallel_mode in ["column"]:
for sequence_parallel in [False, True]:
......@@ -799,7 +1025,7 @@ def test_layernorm_mlp():
{"normalization": "RMSNorm"},
{"zero_centered_gamma": True},
{"bias": False},
{"params_dtype": torch.float16},
{"params_dtype": torch.float16 if QUANTIZATION != "nvfp4" else torch.bfloat16},
{"activation": "relu"},
{"fuse_wgrad_accumulation": True},
{"return_bias": True},
......@@ -897,7 +1123,7 @@ def test_transformer_layer():
{"fuse_qkv_params": True, "fuse_wgrad_accumulation": True},
{"qkv_weight_interleaved": False},
{"bias": False},
{"params_dtype": torch.float16},
{"params_dtype": torch.float16 if QUANTIZATION != "nvfp4" else torch.bfloat16},
{"fuse_qkv_params": True},
{"activation": "relu"},
]
......
#!/usr/bin/python3
# Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# See LICENSE for license information.
import argparse
import datetime
import os
import sys
from functools import wraps
import math
import transformer_engine.pytorch as te
import torch
from torch import nn
import torch.distributed as dist
import transformer_engine_torch as tex
from transformer_engine.common.recipe import (
NVFP4BlockScaling,
Format,
Recipe,
QParams,
)
from transformer_engine.pytorch.tensor.nvfp4_tensor import NVFP4Quantizer
from transformer_engine.pytorch.constants import NVFP4_BLOCK_SCALING_SIZE
from run_layer_with_overlap import _compare_tensors
BATCH_SIZE, HIDDEN_SIZE, OUT_SIZE = 128, 256, 128
WORLD_RANK, WORLD_SIZE = None, None
NCCL_WORLD = None
LOSS_FN = nn.MSELoss()
QUANTIZATION = None
def nvfp4_rht_and_2d_quantization():
nvfp4_recipe = NVFP4BlockScaling()
nvfp4_recipe.fp4_quant_fwd_inp = QParams(
random_hadamard_transform=True, fp4_2d_quantization=False
)
nvfp4_recipe.fp4_quant_fwd_weight = QParams(
random_hadamard_transform=False, fp4_2d_quantization=True
)
nvfp4_recipe.fp4_quant_bwd_grad = QParams(
random_hadamard_transform=True, fp4_2d_quantization=False
)
return nvfp4_recipe
# Quantization recipe setup
def quantization_recipe() -> Recipe:
if QUANTIZATION == "nvfp4":
return nvfp4_rht_and_2d_quantization()
raise ValueError(f"Unsupported quantization: {QUANTIZATION}")
def setup_environment_for_reference():
if QUANTIZATION == "nvfp4":
os.environ["QAT_PARAMS"] = "9003"
else:
raise ValueError(f"Unsupported quantization for reference: {QUANTIZATION}")
def cleanup_environment():
if "QAT_PARAMS" in os.environ:
del os.environ["QAT_PARAMS"]
def main(argv=None, namespace=None):
global WORLD_RANK, WORLD_SIZE, NCCL_WORLD, QUANTIZATION, BATCH_SIZE, HIDDEN_SIZE, OUT_SIZE
WORLD_RANK = int(os.getenv("RANK", "0"))
WORLD_SIZE = int(os.getenv("WORLD_SIZE", "1"))
LOCAL_RANK = int(os.getenv("LOCAL_RANK", "0"))
LOCAL_SIZE = int(os.getenv("LOCAL_WORLD_SIZE", "1"))
assert WORLD_SIZE == LOCAL_SIZE # this test supports only 1 node
assert LOCAL_SIZE <= torch.cuda.device_count()
dist_init_kwargs = {
"backend": "nccl",
"rank": WORLD_RANK,
"world_size": WORLD_SIZE,
"timeout": datetime.timedelta(seconds=30),
}
dist_init_kwargs["init_method"] = "env://"
dist_init_kwargs["device_id"] = torch.device(f"cuda:{LOCAL_RANK}")
assert dist.is_nccl_available()
torch.cuda.set_device(LOCAL_RANK)
dist.init_process_group(**dist_init_kwargs)
NCCL_WORLD = dist.new_group(backend="nccl")
WORLD_SIZE = dist.get_world_size()
parser = argparse.ArgumentParser()
parser.add_argument("--quantization", type=str, default=None)
parser.add_argument("--batch-size", type=int, default=32)
parser.add_argument("--hidden-size", type=int, default=128)
parser.add_argument("--out-size", type=int, default=128)
args = parser.parse_args(argv, namespace)
# Quantization scheme
QUANTIZATION = args.quantization
BATCH_SIZE = args.batch_size
HIDDEN_SIZE = args.hidden_size
OUT_SIZE = args.out_size
test_dict = [
test_linear,
test_layernorm_linear,
]
for test in test_dict:
test()
dist.destroy_process_group()
return 0
def run_distributed_test(test_name=None):
def decorator(func):
@wraps(func)
def wrapper(*args, **kwargs):
name = test_name if test_name is not None else func.__name__
dist_print(f"Starting test {name} with args {args} and {kwargs}")
torch.cuda.set_device(WORLD_RANK)
torch.manual_seed(12345)
torch.cuda.manual_seed(12345)
func(*args, **kwargs)
dist.barrier()
dist_print(f"Passed test {name}")
return wrapper
return decorator
def dist_print(msg, src=None, end="\n", error=False):
stream = sys.stderr if error else sys.stdout
if WORLD_RANK == (0 if src is None else src):
stream.write(f"[rank{WORLD_RANK}] {msg}{end}\n")
############################################
# Linear #
############################################
class TestDistributedLinearBase:
@staticmethod
def _prepare_data(
batch_size, hidden_size, out_size, use_bias=True, seed=0, dtype=torch.float32
):
torch.manual_seed(seed)
torch.cuda.manual_seed(seed)
x = torch.randn((batch_size, hidden_size), dtype=dtype, device="cuda")
w = torch.randn((out_size, hidden_size), dtype=dtype, device="cuda")
bias = torch.randn((out_size), dtype=dtype, device="cuda") if use_bias else None
gradient = torch.randn((batch_size, out_size), dtype=dtype, device="cuda")
return x, w, bias, gradient
@staticmethod
def _shard_tensor(x, world_size, axis):
split_size = x.size()[axis] // world_size
split_tensor = torch.split(x, split_size, axis)
out = []
for tensor in split_tensor:
out.append(tensor.detach().clone().requires_grad_(x.requires_grad))
return out
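# Example with hypothetical shapes: sharding a (128, 256) tensor across 4 ranks along
# axis 0 yields four (32, 256) shards, one per rank.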
@staticmethod
def _gather_tensor(local, world_size, tp_group, concat_dim):
out_list = [torch.zeros_like(local) for _ in range(world_size)]
torch.distributed.all_gather(out_list, local, tp_group)
return torch.cat(out_list, dim=concat_dim)
@staticmethod
def _all_reduce_tensor(local, world_size, tp_group):
if world_size == 1:
return local
torch.distributed.all_reduce(local, group=tp_group, async_op=False)
return local
@staticmethod
def _get_sum_abs_error(a, b):
return torch.sum(torch.abs(a - b))
@staticmethod
def _get_mean_abs_relative_error(a, b):
error = torch.where(b == 0, torch.ne(a, b), torch.abs((a - b) / b))
return torch.mean(error)
@classmethod
def run_linear_preprocess_parallel(
cls,
x,
w,
bias,
gradient,
parallel_mode=None,
sequence_parallel=False,
tp_size=1,
rank=0,
):
if tp_size > 1:
if parallel_mode == "column":
# split w in N dim, which should be axis 0
w = cls._shard_tensor(w, tp_size, 0)[rank]
bias = cls._shard_tensor(bias, tp_size, 0)[rank] if bias is not None else None
# split gradient in N dim, which should be axis 1
gradient = cls._shard_tensor(gradient, tp_size, 1)[rank]
if sequence_parallel:
# split x in M dim, which should be axis 0
x = cls._shard_tensor(x, tp_size, 0)[rank]
# row parallel, split x in k dim, which should be axis 1, split w in k dim, should be axis 1
if parallel_mode == "row":
# split x in K dim, which should be axis 1
x = cls._shard_tensor(x, tp_size, 1)[rank]
# split w in K dim, which should be axis 1
w = cls._shard_tensor(w, tp_size, 1)[rank]
if sequence_parallel:
# split gradient in M dim, which should be axis 0
gradient = cls._shard_tensor(gradient, tp_size, 0)[rank]
return x, w, bias, gradient
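# Example with hypothetical shapes and tp_size=2, column parallelism: w of shape
# (out, in) = (128, 256) is split along dim 0 into two (64, 256) shards and gradient
# (M, 128) along dim 1 into two (M, 64) shards; with sequence parallelism, x is also split along dim 0.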
@classmethod
def run_linear_postprocess_parallel(
cls,
y_q,
dgrad,
wgrad,
bgrad,
parallel_mode,
sequence_parallel,
tp_size,
tp_group,
):
if tp_size > 1:
if parallel_mode == "column":
# gather y_q in N dim, which should be axis 1
y_q = cls._gather_tensor(y_q, tp_size, tp_group, 1)
# gather wgrad in N dim, which should be axis 0
wgrad = cls._gather_tensor(wgrad, tp_size, tp_group, 0)
# gather bgrad in N dim, which should be axis 0
bgrad = (
cls._gather_tensor(bgrad, tp_size, tp_group, 0) if bgrad is not None else None
)
if sequence_parallel:
# gather dgrad in M dim, which should be axis 0
dgrad = cls._gather_tensor(dgrad, tp_size, tp_group, 0)
if parallel_mode == "row":
# gather dgrad in K dim, which should be axis 1
dgrad = cls._gather_tensor(dgrad, tp_size, tp_group, 1)
# gather wgrad in K dim, which should be axis 1
wgrad = cls._gather_tensor(wgrad, tp_size, tp_group, 1)
if sequence_parallel:
# gather y_q in M dim, which should be axis 0
y_q = cls._gather_tensor(y_q, tp_size, tp_group, 0)
# we need to sum bias gradient when using TP + SP
bgrad = (
cls._all_reduce_tensor(bgrad, tp_size, tp_group)
if bgrad is not None
else None
)
return y_q, dgrad, wgrad, bgrad
@classmethod
def run_linear_one_step(
cls, layer, x, gradient, is_first_microbatch=None, fuse_wgrad_accumulation=False
):
# reset gradients
layer.zero_grad()
x.grad = None
# Forward pass
if isinstance(layer, te.Linear):
# TE Linear (supports is_first_microbatch)
y_q = layer.forward(x, is_first_microbatch=is_first_microbatch)
else:
# the default torch.nn.Linear
y_q = layer(x)
# Backward pass
y_q.backward(gradient)
# Collect gradients
dgrad = x.grad
bgrad = (
layer._parameters["bias"].grad
if layer._parameters.get("bias", None) is not None
else None
)
assert "weight" in layer._parameters
if fuse_wgrad_accumulation:
wgrad = layer._parameters["weight"].main_grad
assert layer._parameters["weight"].grad is None
else:
wgrad = layer._parameters["weight"].grad
return y_q, dgrad, wgrad, bgrad
@classmethod
def run_linear_multiple_steps(
cls,
layer,
x,
gradient,
run_num_steps,
enable_weight_cache,
fuse_wgrad_accumulation=False,
):
"""
Run multiple steps of linear layer and collect results.
"""
y_q_list, dgrad_list, wgrad_list = [], [], []
bgrad_list = [] if layer._parameters.get("bias", None) is not None else None
for i in range(run_num_steps):
x_i = (x + i).clone().detach().requires_grad_(True)
# run_linear_one_step
y_q, dgrad, wgrad, bgrad = cls.run_linear_one_step(
layer,
x_i,
gradient,
is_first_microbatch=(i == 0) if enable_weight_cache else None,
fuse_wgrad_accumulation=fuse_wgrad_accumulation,
)
# Collect results
y_q_list.append(y_q.detach().clone())
dgrad_list.append(dgrad.detach().clone())
wgrad_list.append(wgrad.detach().clone())
if bgrad_list is not None and bgrad is not None:
bgrad_list.append(bgrad.detach().clone())
# Stack the results
return (
torch.stack(y_q_list),
torch.stack(dgrad_list),
torch.stack(wgrad_list),
torch.stack(bgrad_list) if bgrad_list is not None else None,
)
@classmethod
def run_linear(
cls,
x,
w,
bias,
gradient,
parallel_mode=None,
sequence_parallel=False,
tp_group=None,
tp_size=1,
rank=0,
run_num_steps=1,
enable_weight_cache=False,
fuse_wgrad_accumulation=False,
):
"""
If Model parallel, split inputs for a given rank and return the gathered output and gradients, so that they can be compared with
the reference single GPU run.
"""
# clone inputs and move to current device
# w has shape [N, K], x has shape [M, K], gradient has shape [M, N]
x = x.clone().detach().requires_grad_(True).to("cuda")
w = w.clone().detach().to("cuda")
gradient = gradient.clone().detach().to("cuda")
bias = bias.clone().detach().to("cuda") if bias is not None else None
in_features = x.shape[1]
out_features = w.shape[0]
# If Model parallel: split inputs for a given rank
x, w, bias, gradient = cls.run_linear_preprocess_parallel(
x, w, bias, gradient, parallel_mode, sequence_parallel, tp_size, rank
)
# set data types
params_dtype = x.dtype
# Create linear layer and copy weights
layer = te.Linear(
in_features,
out_features,
bias=bias is not None,
params_dtype=params_dtype,
parallel_mode=parallel_mode,
sequence_parallel=sequence_parallel,
tp_group=tp_group,
tp_size=tp_size,
fuse_wgrad_accumulation=fuse_wgrad_accumulation,
)
layer = layer.to("cuda")
with torch.no_grad():
layer.weight.copy_(w)
if bias is not None:
layer.bias.copy_(bias)
if fuse_wgrad_accumulation:
assert (
run_num_steps > 1
), "Fused weight gradient accumulation requires run_num_steps > 1"
layer.weight.main_grad = torch.zeros_like(layer.weight)
# Run one step or multiple steps
if run_num_steps == 1:
y_q, dgrad, wgrad, bgrad = cls.run_linear_one_step(layer, x, gradient)
else:
y_q, dgrad, wgrad, bgrad = cls.run_linear_multiple_steps(
layer,
x,
gradient,
run_num_steps,
enable_weight_cache,
fuse_wgrad_accumulation,
)
# If Model parallel: gather output and gradients from all ranks
y_q, dgrad, wgrad, bgrad = cls.run_linear_postprocess_parallel(
y_q,
dgrad,
wgrad,
bgrad,
parallel_mode,
sequence_parallel,
tp_size,
tp_group,
)
return y_q, dgrad, wgrad, bgrad
@run_distributed_test()
def _test_linear(parallel_mode=None, sequence_parallel=False, **kwargs):
"""Test the linear layer with specified parallel mode and sequence parallelization.
Args:
parallel_mode (str): 'row' or 'column' parallelism.
sequence_parallel (bool): Enable sequence parallelism if True.
kwargs (dict): Additional arguments for the linear layer.
QUANTIZATION options: nvfp4 (checked against the experimental nvfp4 implementation as a reference)
"""
params_dtype = torch.bfloat16
use_bias = kwargs.get("bias", True)
fuse_wgrad_accumulation = kwargs.get("fuse_wgrad_accumulation", False)
seed = torch.initial_seed()
recipe = quantization_recipe()
# turn on weight quantization cache when fusing wgrad accumulation
enable_weight_cache = fuse_wgrad_accumulation
run_num_steps = 1 if not fuse_wgrad_accumulation else 5
x, w, bias, gradient = TestDistributedLinearBase._prepare_data(
BATCH_SIZE, HIDDEN_SIZE, OUT_SIZE, use_bias=use_bias, seed=seed, dtype=params_dtype
)
# run the recipe under test
with te.fp8_autocast(enabled=True, fp8_recipe=recipe):
y_q, dgrad, wgrad, bgrad = TestDistributedLinearBase.run_linear(
x,
w,
bias,
gradient,
parallel_mode=parallel_mode,
sequence_parallel=sequence_parallel,
tp_group=NCCL_WORLD,
tp_size=WORLD_SIZE,
rank=WORLD_RANK,
fuse_wgrad_accumulation=fuse_wgrad_accumulation,
run_num_steps=1 if not fuse_wgrad_accumulation else 5,
enable_weight_cache=fuse_wgrad_accumulation,
)
# run the reference
setup_environment_for_reference()
with te.fp8_autocast(enabled=True, fp8_recipe=recipe):
y_q_ref, dgrad_ref, wgrad_ref, bgrad_ref = TestDistributedLinearBase.run_linear(
x,
w,
bias,
gradient,
parallel_mode=parallel_mode,
sequence_parallel=sequence_parallel,
tp_group=NCCL_WORLD,
tp_size=WORLD_SIZE,
rank=WORLD_RANK,
fuse_wgrad_accumulation=fuse_wgrad_accumulation,
run_num_steps=run_num_steps,
enable_weight_cache=enable_weight_cache,
)
# Clean up env
cleanup_environment()
# compare results, zero tolerance
if WORLD_RANK == 0:
torch.testing.assert_close(y_q, y_q_ref, atol=0, rtol=0, msg="Output mismatch")
torch.testing.assert_close(dgrad, dgrad_ref, atol=0, rtol=0, msg="Dgrad mismatch")
torch.testing.assert_close(wgrad, wgrad_ref, atol=0, rtol=0, msg="Wgrad mismatch")
if bgrad is not None and bgrad_ref is not None:
torch.testing.assert_close(bgrad, bgrad_ref, atol=0, rtol=0, msg="Bgrad mismatch")
def test_linear():
"""Run linear layer tests with various configurations."""
kwargs_list = [
{"bias": False},
]
for kwargs in kwargs_list:
if kwargs.get("save_original_input", False) and QUANTIZATION == "fp8":
continue
for parallel_mode in ["column", "row"]:
for sequence_parallel in [False, True]:
_test_linear(parallel_mode, sequence_parallel, **kwargs)
############################################
# LayerNormLinear #
############################################
class TestDistributedLayerNormLinearBase(TestDistributedLinearBase):
@classmethod
def run_linear_one_step(cls, layer, x, gradient, is_first_microbatch=None):
# reset gradients
layer.zero_grad()
x.grad = None
# Forward pass
y_q, ln_out = layer.forward(x, is_first_microbatch=is_first_microbatch)
# Backward pass
y_q.backward(gradient)
# Collect gradients
dgrad = x.grad
parameters = layer._parameters
# bias and weight gradients
bgrad = parameters["bias"].grad if parameters.get("bias", None) is not None else None
assert "weight" in parameters
wgrad = parameters["weight"].grad
return y_q, ln_out, dgrad, wgrad, bgrad
@classmethod
def run_linear_multiple_steps(
cls, layer, x, gradient, run_num_steps, enable_weight_cache, fuse_wgrad_accumulation=False
):
# raise error, no test case for multiple steps for now
raise NotImplementedError("LayerNormLinear does not support test multiple steps for now")
@classmethod
def run_layernorm_linear(
cls,
x,
w,
bias,
gradient,
parallel_mode=None,
sequence_parallel=False,
tp_group=None,
tp_size=1,
rank=0,
run_num_steps=1,
enable_weight_cache=False,
LayerNormLinearClass=te.LayerNormLinear,
normalization="LayerNorm",
):
"""
If Model parallel, split inputs for a given rank and return the gathered output and gradients, so that they can be compared with
the reference single GPU run.
"""
# clone inputs and move to current device
# w has shape [N, K], x has shape [M, K], gradient has shape [M, N]
x = x.clone().detach().requires_grad_(True).to("cuda")
w = w.clone().detach().to("cuda")
gradient = gradient.clone().detach().to("cuda")
bias = bias.clone().detach().to("cuda") if bias is not None else None
in_features = x.shape[1]
out_features = w.shape[0]
# If Model parallel: split inputs for a given rank
x, w, bias, gradient = cls.run_linear_preprocess_parallel(
x, w, bias, gradient, parallel_mode, sequence_parallel, tp_size, rank
)
# set data types
params_dtype = x.dtype
# Create linear layer and copy weights
layer = LayerNormLinearClass(
in_features,
out_features,
bias=bias is not None,
params_dtype=params_dtype,
parallel_mode=parallel_mode,
sequence_parallel=sequence_parallel,
tp_group=tp_group,
tp_size=tp_size,
normalization=normalization,
return_layernorm_output=True,
)
layer = layer.to("cuda")
# Copy weights
# kitchen_linear has different parameter names
with torch.no_grad():
layer.weight.copy_(w)
if bias is not None:
layer.bias.copy_(bias)
# Run one step
y_q, ln_out, dgrad, wgrad, bgrad = cls.run_linear_one_step(layer, x, gradient)
# If Model parallel: gather output and gradients from all ranks
y_q, dgrad, wgrad, bgrad = cls.run_linear_postprocess_parallel(
y_q,
dgrad,
wgrad,
bgrad,
parallel_mode,
sequence_parallel,
tp_size,
tp_group,
)
return y_q, ln_out, dgrad, wgrad, bgrad
@run_distributed_test()
def _test_layernorm_linear(parallel_mode=None, sequence_parallel=False, **kwargs):
"""Test the linear layer with specified parallel mode and sequence parallelization.
Args:
parallel_mode (str): 'column' parallelism.
sequence_parallel (bool): Enable sequence parallelism if True.
kwargs (dict): Additional arguments for the linear layer.
"""
params_dtype = torch.bfloat16
use_bias = kwargs.get("bias", True)
seed = torch.initial_seed()
recipe = quantization_recipe()
# run multiple steps currently not supported for LayerNormLinear
run_num_steps = 1
x, w, bias, gradient = TestDistributedLayerNormLinearBase._prepare_data(
BATCH_SIZE, HIDDEN_SIZE, OUT_SIZE, use_bias=use_bias, seed=seed, dtype=params_dtype
)
# run the recipe under test
with te.fp8_autocast(enabled=True, fp8_recipe=recipe):
y_q, ln_out, dgrad, wgrad, bgrad = TestDistributedLayerNormLinearBase.run_layernorm_linear(
x,
w,
bias,
gradient,
parallel_mode=parallel_mode,
sequence_parallel=sequence_parallel,
tp_group=NCCL_WORLD,
tp_size=WORLD_SIZE,
rank=WORLD_RANK,
run_num_steps=run_num_steps,
enable_weight_cache=False,
)
# run the reference
setup_environment_for_reference()
with te.fp8_autocast(enabled=True, fp8_recipe=recipe):
y_q_ref, ln_out_ref, dgrad_ref, wgrad_ref, bgrad_ref = (
TestDistributedLayerNormLinearBase.run_layernorm_linear(
x,
w,
bias,
gradient,
parallel_mode=parallel_mode,
sequence_parallel=sequence_parallel,
tp_group=NCCL_WORLD,
tp_size=WORLD_SIZE,
rank=WORLD_RANK,
run_num_steps=run_num_steps,
enable_weight_cache=False,
)
)
# Clean up env
cleanup_environment()
# compare results, zero tolerance
if WORLD_RANK == 0:
torch.testing.assert_close(y_q, y_q_ref, atol=0, rtol=0, msg="Output mismatch")
torch.testing.assert_close(ln_out, ln_out_ref, atol=0, rtol=0, msg="LN output mismatch")
torch.testing.assert_close(dgrad, dgrad_ref, atol=0, rtol=0, msg="Dgrad mismatch")
torch.testing.assert_close(wgrad, wgrad_ref, atol=0, rtol=0, msg="Wgrad mismatch")
if bgrad is not None and bgrad_ref is not None:
torch.testing.assert_close(bgrad, bgrad_ref, atol=0, rtol=0, msg="Bgrad mismatch")
def test_layernorm_linear():
kwargs_list = [
{"bias": False},
]
for kwargs in kwargs_list:
for parallel_mode in ["column"]:
for sequence_parallel in [False, True]:
_test_layernorm_linear(parallel_mode, sequence_parallel, **kwargs)
if __name__ == "__main__":
sys.exit(main())
......@@ -27,6 +27,7 @@ from transformer_engine.pytorch.tensor.float8_tensor import (
Float8CurrentScalingQuantizer,
)
from transformer_engine.pytorch.tensor.mxfp8_tensor import MXFP8Quantizer
from transformer_engine.pytorch.tensor.nvfp4_tensor import NVFP4Quantizer
import transformer_engine.pytorch.ops as te_ops
from transformer_engine.pytorch.utils import is_bf16_compatible
import transformer_engine_torch as tex
......@@ -34,17 +35,20 @@ import transformer_engine_torch as tex
# Import utility functions
_current_file = pathlib.Path(__file__).resolve()
sys.path.append(str(_current_file.parent.parent))
from utils import dtype_tols, make_recipe
from utils import dtype_tols, make_recipe, quantization_tols
# Check what quantization schemes are supported
fp8_available, reason_for_no_fp8 = FP8GlobalStateManager.is_fp8_available()
mxfp8_available, reason_for_no_mxfp8 = FP8GlobalStateManager.is_mxfp8_available()
nvfp4_available, reason_for_no_nvfp4 = FP8GlobalStateManager.is_nvfp4_available()
quantization_list: list[Optional[str]] = [None]
if fp8_available:
quantization_list.extend(("fp8_delayed_scaling", "fp8_current_scaling"))
if mxfp8_available:
quantization_list.append("mxfp8")
if nvfp4_available:
quantization_list.append("nvfp4")
@functools.cache
......@@ -115,6 +119,14 @@ def make_reference_and_test_tensors(
test = quantizer(test)
elif quantization == "mxfp8":
test = MXFP8Quantizer(fp8_dtype=tex.DType.kFloat8E4M3)(test)
elif quantization == "nvfp4":
test = NVFP4Quantizer(
with_rht=False,
with_post_rht_amax=False,
with_2d_quantization=False,
stochastic_rounding=False,
with_random_sign_mask=False,
)(test)
else:
raise ValueError(f"Unsupported quantization scheme ({quantization})")
if isinstance(test, QuantizedTensor) and not test_is_quantized:
......@@ -437,7 +449,7 @@ def _test_basic_linear(
if dtype == torch.float32:
tols = dtype_tols(torch.float16) # TF32 GEMM
if quantized_compute:
tols = dtype_tols(tex.DType.kFloat8E4M3)
tols = quantization_tols(quantization)
# Check results
y_test = y_test.to(dtype=torch.float64, device="cpu")
......@@ -609,7 +621,7 @@ def _test_linear(
if dtype == torch.float32:
tols = dtype_tols(torch.float16) # TF32 GEMM
if quantized_compute:
tols = dtype_tols(tex.DType.kFloat8E4M3)
tols = quantization_tols(quantization)
# Check results
y_test = y_test.to(dtype=torch.float64, device="cpu")
......
......@@ -31,6 +31,7 @@ mxfp8_available, reason_for_no_mxfp8 = FP8GlobalStateManager.is_mxfp8_available(
fp8_block_scaling_available, reason_for_no_fp8_block_scaling = (
FP8GlobalStateManager.is_fp8_block_scaling_available()
)
nvfp4_available, reason_for_no_nvfp4 = FP8GlobalStateManager.is_nvfp4_available()
TEST_ROOT = Path(__file__).parent.resolve()
NUM_PROCS: int = min(4, torch.cuda.device_count())
......@@ -51,7 +52,9 @@ def _run_test(quantization):
all_boolean = [True, False]
@pytest.mark.parametrize("quantization", [None, "fp8", "mxfp8", "fp8_cs", "fp8_block_scaling"])
@pytest.mark.parametrize(
"quantization", [None, "fp8", "mxfp8", "fp8_cs", "fp8_block_scaling", "nvfp4"]
)
def test_distributed(quantization):
if quantization == "fp8" and not fp8_available:
pytest.skip(reason_for_no_fp8)
......@@ -61,4 +64,6 @@ def test_distributed(quantization):
pytest.skip(reason_for_no_mxfp8)
if quantization == "fp8_block_scaling" and not fp8_block_scaling_available:
pytest.skip(reason_for_no_fp8_block_scaling)
if quantization == "nvfp4" and not nvfp4_available:
pytest.skip(reason_for_no_nvfp4)
_run_test(quantization)
# Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# See LICENSE for license information.
import os
import subprocess
from pathlib import Path
import pytest
import torch
from transformer_engine.pytorch.fp8 import FP8GlobalStateManager
"""
Distributed numerics tests
This numerical test aims for a zero-tolerance comparison for absolute confidence in numerics.
In the case of NVFP4, with the experimental NVFP4 quantization, we matched results bitwise
with the native silicon. For distributed test cases, we can do the same by comparing
BF16 AG results with the low-precision AG results at the layer level.
"""
if torch.cuda.device_count() < 2:
pytest.skip("Distributed training needs at least 2 GPUs.")
fp8_available, reason_for_no_fp8 = FP8GlobalStateManager.is_fp8_available()
mxfp8_available, reason_for_no_mxfp8 = FP8GlobalStateManager.is_mxfp8_available()
fp8_block_scaling_available, reason_for_no_fp8_block_scaling = (
FP8GlobalStateManager.is_fp8_block_scaling_available()
)
nvfp4_available, reason_for_no_nvfp4 = FP8GlobalStateManager.is_nvfp4_available()
TEST_ROOT = Path(__file__).parent.resolve()
NUM_PROCS: int = min(4, torch.cuda.device_count())
LAUNCH_CMD = ["torchrun", f"--nproc_per_node={NUM_PROCS}"]
def _run_test(quantization, batch_size, hidden_size, out_size):
test_path = TEST_ROOT / "run_numerics_exact.py"
test_cmd = LAUNCH_CMD + [str(test_path)]
test_cmd += ["--quantization", quantization]
test_cmd += ["--batch-size", str(batch_size)]
test_cmd += ["--hidden-size", str(hidden_size)]
test_cmd += ["--out-size", str(out_size)]
result = subprocess.run(test_cmd, env=os.environ, check=False)
assert result.returncode == 0
all_boolean = [True, False]
@pytest.mark.parametrize("quantization", ["nvfp4"])
@pytest.mark.parametrize(
"batch_size, hidden_size, out_size",
[
(64, 128, 128),
(128, 128, 128),
(128, 256, 256),
(512, 1024, 768),
(512, 256, 1024),
(2048, 2048, 2048),
],
)
def test_distributed(quantization, batch_size, hidden_size, out_size):
if quantization == "nvfp4" and not nvfp4_available:
pytest.skip(reason_for_no_nvfp4)
_run_test(quantization, batch_size, hidden_size, out_size)
# Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# See LICENSE for license information.
import pytest
import torch
import transformer_engine as te
import transformer_engine_torch as tex
from transformer_engine.pytorch.fp8 import FP8GlobalStateManager
from transformer_engine.pytorch.constants import TE_DType
from transformer_engine.pytorch.tensor.nvfp4_tensor import NVFP4Quantizer
from transformer_engine.pytorch.experimental.quantization_microblock_ref import NVFP4QuantizerRef
from transformer_engine.pytorch.experimental import utils
recipe_available, reason_for_no_recipe = FP8GlobalStateManager.is_nvfp4_available()
def check_nvfp4_gemm_versus_reference(
x_dtype: torch.dtype,
w_dtype: torch.dtype,
out_dtype: torch.dtype,
M: int,
K: int,
N: int,
accumulate: bool,
*,
x_columnwise: bool = False,
w_columnwise: bool = False,
):
te_dtype = tex.DType.kFloat4E2M1
# Setup device and random seed
device = "cuda"
seed = 0
torch.manual_seed(seed)
torch.cuda.manual_seed(seed)
# Input tensors
x_shape = (K, M) if x_columnwise else (M, K)
w_shape = (K, N) if w_columnwise else (N, K)
x = torch.randn(x_shape, dtype=x_dtype, device=device)
w = torch.randn(w_shape, dtype=w_dtype, device=device)
# Setup out tensor if accumulate is True
if accumulate:
out = torch.randn((M, N), dtype=out_dtype, device=device)
else:
out = None
# Native TE NVFP4 quantization
x_quantizer = NVFP4Quantizer(
fp4_dtype=te_dtype,
rowwise=True,
columnwise=True,
with_amax_reduction=False,
amax_reduction_group=None,
with_rht=False,
with_post_rht_amax=False,
)
w_quantizer = NVFP4Quantizer(
fp4_dtype=te_dtype,
rowwise=True,
columnwise=True,
with_amax_reduction=False,
amax_reduction_group=None,
with_rht=False,
with_post_rht_amax=False,
)
# Quantize x and w
x_nvfp4_native = x_quantizer.make_empty(
x_shape, dtype=x_dtype, device=device, requires_grad=False
)
x_nvfp4_native = x_quantizer.update_quantized(x, x_nvfp4_native)
w_nvfp4_native = w_quantizer.make_empty(
w_shape, dtype=w_dtype, device=device, requires_grad=False
)
w_nvfp4_native = w_quantizer.update_quantized(w, w_nvfp4_native)
# Extract quantized data from native NVFP4Tensors
qx_data = (
x_nvfp4_native._columnwise_data.view(dtype=torch.uint8)
if x_columnwise
else x_nvfp4_native._rowwise_data.view(dtype=torch.uint8)
)
qw_data = (
w_nvfp4_native._columnwise_data.view(dtype=torch.uint8)
if w_columnwise
else w_nvfp4_native._rowwise_data.view(dtype=torch.uint8)
)
sx_native = (
x_nvfp4_native._columnwise_scale_inv if x_columnwise else x_nvfp4_native._rowwise_scale_inv
)
sw_native = (
w_nvfp4_native._columnwise_scale_inv if w_columnwise else w_nvfp4_native._rowwise_scale_inv
)
# Trim quantized data to match the actual tensor dimensions (remove padding)
qx_data = qx_data[:M, :]
qw_data = qw_data[:N, :]
# NVFP4 uses 16-element blocks, trim scales to remove padding
block_length = 16 # NVFP4 uses 16-element blocks
expected_sx_cols = expected_sw_cols = K // block_length
# Trim the scales to remove padding
sx_trimmed = sx_native[:M, :expected_sx_cols]
sw_trimmed = sw_native[:N, :expected_sw_cols]
# Native scales are stored as uint8 but need to be interpreted as float8_e4m3fn
# for the reference GEMM to work correctly
sx_trimmed = sx_trimmed.view(torch.float8_e4m3fn)
sw_trimmed = sw_trimmed.view(torch.float8_e4m3fn)
# Create reference quantizer for reference GEMM
ref_quantizer = NVFP4QuantizerRef(
dtype=utils.Fp4Formats.E2M1,
rowwise=True,
columnwise=True,
pow_2_scales=False,
eps=0.0,
quant_tile_shape=(1, 16),
)
# Create reference quantized tensors needed by reference GEMM
x_nvfp4_ref = ref_quantizer.quantize(x)
w_nvfp4_ref = ref_quantizer.quantize(w)
# Reference GEMM using quantizer's qgemm method
y_ref = ref_quantizer.qgemm(
qx=qx_data,
qw=qw_data,
m_params=None, # MMParams not used in reference
out_dtype=out_dtype,
sx=sx_trimmed,
sw=sw_trimmed,
bias=None, # No bias for this test
out=out.clone() if accumulate else None,
accumulate=accumulate,
gemm_type=None, # GEMMType not used in reference
qresult_x=x_nvfp4_ref,
qresult_w=w_nvfp4_ref,
)
# Native TE GEMM using tex.generic_gemm (cuBLAS GEMM)
# Allocate cuBLAS workspace
workspace = torch.empty(4, dtype=torch.uint8, device=device)
transa = not w_columnwise
transb = x_columnwise
out_quantizer = None
bias = None
bias_dtype = TE_DType[torch.bfloat16]
use_gelu = False
gelu_input = None
use_grad = False
use_split_accumulator = False
# Native cuBLAS GEMM
# return type is out, bias_grad, gelu_input, extra_output
# We are just capturing out.
y_native = tex.generic_gemm(
w_nvfp4_native,
transa,
x_nvfp4_native,
transb,
out.clone() if accumulate else None,
out_quantizer,
TE_DType[out_dtype],
bias,
bias_dtype,
use_gelu,
gelu_input,
use_grad,
workspace,
workspace.shape[0],
accumulate,
use_split_accumulator,
)[0]
# just in case of accumulation, make sure y_ref and y_native are not the same tensor
assert y_ref is not y_native, "y_ref and y_native should not be the same tensor"
# Reset nans to zeros because torch.assert_close does not assume nans to be equal
assert not torch.isnan(y_ref.float()).all(), "All elements are nan"
y_ref = torch.where(y_ref.isnan(), torch.zeros_like(y_ref), y_ref)
y_native = torch.where(y_native.isnan(), torch.zeros_like(y_native), y_native)
# Compare results with some tolerance
torch.testing.assert_close(y_native, y_ref, atol=8e-3, rtol=8e-3)
@pytest.mark.skipif(not recipe_available, reason=reason_for_no_recipe)
@pytest.mark.parametrize(
"M, K, N",
[
(128, 128, 128),
(256, 128, 256),
(256, 256, 256),
(256, 1024, 256),
(1024, 1024, 1024),
(4096, 512, 3072),
(112, 128, 96),
(304, 640, 304),
(1008, 3072, 992),
(256, 64, 256),
(128, 128, 112),
],
)
@pytest.mark.parametrize("x_dtype", [torch.float32, torch.bfloat16], ids=str)
@pytest.mark.parametrize("w_dtype", [torch.float32, torch.bfloat16], ids=str)
@pytest.mark.parametrize("out_dtype", [torch.bfloat16, torch.float32], ids=str)
@pytest.mark.parametrize("accumulate", [True, False], ids=["accumulate", "no_accumulate"])
@pytest.mark.parametrize(
"is_x_columnwise, is_w_columnwise",
[
(False, False), # Only rowwise x rowwise is supported by reference GEMM
# Note: Reference GEMM expects inputs as (M,K) x (N,K) with rowwise quantization
# Columnwise layouts are not supported by the reference implementation
],
ids=["rowxrow"],
)
def test_nvfp4_gemm_versus_reference(
M: int,
K: int,
N: int,
x_dtype: torch.dtype,
w_dtype: torch.dtype,
out_dtype: torch.dtype,
accumulate: bool,
is_x_columnwise: bool,
is_w_columnwise: bool,
):
check_nvfp4_gemm_versus_reference(
x_dtype=x_dtype,
w_dtype=w_dtype,
out_dtype=out_dtype,
M=M,
K=K,
N=N,
accumulate=accumulate,
x_columnwise=is_x_columnwise,
w_columnwise=is_w_columnwise,
)
# Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# See LICENSE for license information.
import os
import pytest
import torch
import transformer_engine as te
from transformer_engine.pytorch.fp8 import FP8GlobalStateManager
from transformer_engine.pytorch.distributed import fp8_autocast
from transformer_engine.common import recipe
recipe_available, reason_for_no_recipe = FP8GlobalStateManager.is_nvfp4_available()
class GetRecipes:
@staticmethod
def nvfp4_vanilla():
nvfp4_recipe = recipe.NVFP4BlockScaling()
nvfp4_recipe.fp4_quant_fwd_inp = recipe.QParams()
nvfp4_recipe.fp4_quant_fwd_weight = recipe.QParams()
nvfp4_recipe.fp4_quant_bwd_grad = recipe.QParams()
return nvfp4_recipe
@staticmethod
def nvfp4_rht_only():
nvfp4_recipe = recipe.NVFP4BlockScaling()
nvfp4_recipe.fp4_quant_fwd_inp = recipe.QParams(random_hadamard_transform=True)
nvfp4_recipe.fp4_quant_fwd_weight = recipe.QParams(random_hadamard_transform=False)
nvfp4_recipe.fp4_quant_bwd_grad = recipe.QParams(random_hadamard_transform=True)
return nvfp4_recipe
@staticmethod
def nvfp4_2d_quantization_only():
nvfp4_recipe = recipe.NVFP4BlockScaling()
nvfp4_recipe.fp4_quant_fwd_inp = recipe.QParams(fp4_2d_quantization=False)
nvfp4_recipe.fp4_quant_fwd_weight = recipe.QParams(fp4_2d_quantization=True)
nvfp4_recipe.fp4_quant_bwd_grad = recipe.QParams(fp4_2d_quantization=False)
return nvfp4_recipe
@staticmethod
def nvfp4_rht_and_2d_quantization():
nvfp4_recipe = recipe.NVFP4BlockScaling()
nvfp4_recipe.fp4_quant_fwd_inp = recipe.QParams(
random_hadamard_transform=True, fp4_2d_quantization=False
)
nvfp4_recipe.fp4_quant_fwd_weight = recipe.QParams(
random_hadamard_transform=False, fp4_2d_quantization=True
)
nvfp4_recipe.fp4_quant_bwd_grad = recipe.QParams(
random_hadamard_transform=True, fp4_2d_quantization=False
)
return nvfp4_recipe
@staticmethod
def nvfp4_recipe_to_test(with_rht: bool = False, with_2d_quantization: bool = False):
if with_rht and with_2d_quantization:
return GetRecipes.nvfp4_rht_and_2d_quantization()
elif with_rht:
return GetRecipes.nvfp4_rht_only()
elif with_2d_quantization:
return GetRecipes.nvfp4_2d_quantization_only()
else:
return GetRecipes.nvfp4_vanilla()
def setup_environment_for_reference(with_rht: bool = False, with_2d_quantization: bool = False):
if with_rht and with_2d_quantization:
os.environ["QAT_PARAMS"] = "9003"
elif with_rht:
os.environ["QAT_PARAMS"] = "960109"
elif with_2d_quantization:
os.environ["QAT_PARAMS"] = "9002"
else:
os.environ["QAT_PARAMS"] = "6010"
def cleanup_environment():
if "QAT_PARAMS" in os.environ:
del os.environ["QAT_PARAMS"]
def reset_rng_states():
seed = 1234
torch.manual_seed(seed)
torch.cuda.manual_seed(seed)
def check_nvfp4_module_versus_reference(
module_class,
in_features: int,
out_features: int,
bias: bool,
x_dtype: torch.dtype,
num_steps: int = 1,
with_rht: bool = False,
with_2d_quantization: bool = False,
):
"""
Compare native NVFP4 module against reference implementation.
Args:
module_class: te.Linear or te.LayerNormLinear
in_features: Input feature dimension
out_features: Output feature dimension
bias: Whether to use bias
x_dtype: Input tensor dtype
num_steps: Number of forward/backward steps to test
"""
device = "cuda"
batch_size = 32
seq_len = 128
# Create both modules with identical initialization
cleanup_environment()
reset_rng_states()
# Create native module
print("\nCreate native module")
if module_class == te.pytorch.Linear:
native_module = te.pytorch.Linear(
in_features=in_features,
out_features=out_features,
bias=bias,
device=device,
params_dtype=x_dtype,
)
elif module_class == te.pytorch.LayerNormLinear:
native_module = te.pytorch.LayerNormLinear(
in_features=in_features,
out_features=out_features,
bias=bias,
device=device,
params_dtype=x_dtype,
)
else:
raise ValueError(f"Unsupported module class: {module_class}")
# Create reference module with same weights
setup_environment_for_reference(with_rht, with_2d_quantization)
reset_rng_states()
# Create reference module
print("Create reference module")
if module_class == te.pytorch.Linear:
ref_module = te.pytorch.Linear(
in_features=in_features,
out_features=out_features,
bias=bias,
device=device,
params_dtype=x_dtype,
)
elif module_class == te.pytorch.LayerNormLinear:
ref_module = te.pytorch.LayerNormLinear(
in_features=in_features,
out_features=out_features,
bias=bias,
device=device,
params_dtype=x_dtype,
)
# Sync weights between native and reference modules
with torch.no_grad():
# Copy main weight and bias parameters
if hasattr(native_module, "weight") and hasattr(ref_module, "weight"):
ref_module.weight.copy_(native_module.weight)
if bias and hasattr(native_module, "bias") and hasattr(ref_module, "bias"):
ref_module.bias.copy_(native_module.bias)
# Copy layer norm parameters if they exist
if hasattr(native_module, "layer_norm_weight") and hasattr(ref_module, "layer_norm_weight"):
ref_module.layer_norm_weight.copy_(native_module.layer_norm_weight)
if hasattr(native_module, "layer_norm_bias") and hasattr(ref_module, "layer_norm_bias"):
ref_module.layer_norm_bias.copy_(native_module.layer_norm_bias)
nvfp4_recipe = GetRecipes.nvfp4_recipe_to_test(with_rht, with_2d_quantization)
# Training loop comparison
native_outputs = []
ref_outputs = []
for step in range(num_steps):
torch.manual_seed(1234 + step)
torch.cuda.manual_seed(1234 + step)
x_shape = (batch_size, seq_len, in_features)
x_val = torch.normal(mean=0.0, std=1.0, size=x_shape, dtype=x_dtype, device=device)
x_native = x_val.clone().detach().requires_grad_(True)
x_ref = x_native.clone().detach().requires_grad_(True)
grad_output_shape = (batch_size, seq_len, out_features)
grad_output_val = torch.normal(
mean=0.0, std=1.0, size=grad_output_shape, dtype=x_dtype, device=device
)
grad_output = grad_output_val.clone().detach()
# Native forward/backward
cleanup_environment()
with fp8_autocast(enabled=True, fp8_recipe=nvfp4_recipe):
# Enable the weight cache by passing is_first_microbatch
y_native = native_module(x_native, is_first_microbatch=(step == 0))
y_native.backward(grad_output)
# Reference forward/backward
setup_environment_for_reference(with_rht, with_2d_quantization)
with fp8_autocast(
enabled=True, fp8_recipe=nvfp4_recipe
): # The exact recipe is irrelevant here; the reference path is configured via QAT_PARAMS
y_ref = ref_module(x_ref)
y_ref.backward(grad_output)
# Store results
native_outputs.append(
{
"output": y_native.detach().clone(),
"input_grad": (
x_native.grad.detach().clone() if x_native.grad is not None else None
),
"weight_grad": (
native_module.weight.grad.detach().clone()
if native_module.weight.grad is not None
else None
),
"bias_grad": (
native_module.bias.grad.detach().clone()
if bias and native_module.bias.grad is not None
else None
),
}
)
ref_outputs.append(
{
"output": y_ref.detach().clone(),
"input_grad": (x_ref.grad.detach().clone() if x_ref.grad is not None else None),
"weight_grad": (
ref_module.weight.grad.detach().clone()
if ref_module.weight.grad is not None
else None
),
"bias_grad": (
ref_module.bias.grad.detach().clone()
if bias and ref_module.bias.grad is not None
else None
),
}
)
# Compare results across all steps
for step in range(num_steps):
native_out = native_outputs[step]
ref_out = ref_outputs[step]
# Compare outputs
torch.testing.assert_close(
native_out["output"],
ref_out["output"],
atol=1e-6,
rtol=1e-6,
msg=f"Output mismatch at step {step}",
)
# Compare input gradients
torch.testing.assert_close(
native_out["input_grad"],
ref_out["input_grad"],
atol=1e-6,
rtol=1e-6,
msg=(
f"Input gradient mismatch at step {step}. Native: {native_out['input_grad']}, Ref:"
f" {ref_out['input_grad']}"
),
)
# Compare weight gradients
torch.testing.assert_close(
native_out["weight_grad"],
ref_out["weight_grad"],
atol=1e-6,
rtol=1e-6,
msg=(
f"Weight gradient mismatch at step {step}. Native: {native_out['weight_grad']},"
f" Ref: {ref_out['weight_grad']}"
),
)
# Compare bias gradients
if bias and native_out["bias_grad"] is not None and ref_out["bias_grad"] is not None:
torch.testing.assert_close(
native_out["bias_grad"],
ref_out["bias_grad"],
atol=1e-6,
rtol=1e-6,
msg=f"Bias gradient mismatch at step {step}",
)
# Clean up
cleanup_environment()
@pytest.mark.skipif(not recipe_available, reason=reason_for_no_recipe)
@pytest.mark.parametrize(
"in_features, out_features",
[
(128, 256),
(256, 128),
(512, 512),
(768, 3072),
(1024, 4096),
],
)
# @pytest.mark.parametrize("bias", [True, False], ids=["with_bias", "no_bias"])
@pytest.mark.parametrize("bias", [False], ids=["no_bias"])
@pytest.mark.parametrize("x_dtype", [torch.float32, torch.bfloat16], ids=str)
@pytest.mark.parametrize("num_steps", [1, 3], ids=["single_step", "multi_step"])
@pytest.mark.parametrize("with_rht", [True, False], ids=["with_rht", "no_rht"])
@pytest.mark.parametrize(
"with_2d_quantization", [True, False], ids=["with_2d_quantization", "no_2d_quantization"]
)
def test_nvfp4_linear_versus_reference(
in_features: int,
out_features: int,
bias: bool,
x_dtype: torch.dtype,
num_steps: int,
with_rht: bool,
with_2d_quantization: bool,
):
"""Test NVFP4 Linear module against reference implementation."""
if with_rht and x_dtype != torch.bfloat16:
pytest.skip("RHT is only supported for bfloat16 input")
check_nvfp4_module_versus_reference(
module_class=te.pytorch.Linear,
in_features=in_features,
out_features=out_features,
bias=bias,
x_dtype=x_dtype,
num_steps=num_steps,
with_rht=with_rht,
with_2d_quantization=with_2d_quantization,
)
def check_nvfp4_layernorm_linear_versus_reference(
in_features: int,
out_features: int,
bias: bool,
normalization: str,
x_dtype: torch.dtype,
num_steps: int = 1,
with_rht: bool = False,
with_2d_quantization: bool = False,
):
"""
Compare native NVFP4 LayerNormLinear module against reference implementation,
including ln_out.
"""
device = "cuda"
batch_size = 32
seq_len = 128
# Create both modules with identical initialization
cleanup_environment()
reset_rng_states()
# Native module
native_module = te.pytorch.LayerNormLinear(
in_features=in_features,
out_features=out_features,
bias=bias,
device=device,
params_dtype=x_dtype,
normalization=normalization,
return_layernorm_output=True,
)
# Reference module
setup_environment_for_reference(with_rht, with_2d_quantization)
reset_rng_states()
ref_module = te.pytorch.LayerNormLinear(
in_features=in_features,
out_features=out_features,
bias=bias,
device=device,
params_dtype=x_dtype,
normalization=normalization,
return_layernorm_output=True,
)
# Sync weights and LN params
with torch.no_grad():
if hasattr(native_module, "weight") and hasattr(ref_module, "weight"):
ref_module.weight.copy_(native_module.weight)
if bias and hasattr(native_module, "bias") and hasattr(ref_module, "bias"):
ref_module.bias.copy_(native_module.bias)
if hasattr(native_module, "layer_norm_weight") and hasattr(ref_module, "layer_norm_weight"):
if (
native_module.layer_norm_weight is not None
and ref_module.layer_norm_weight is not None
):
ref_module.layer_norm_weight.copy_(native_module.layer_norm_weight)
if hasattr(native_module, "layer_norm_bias") and hasattr(ref_module, "layer_norm_bias"):
if native_module.layer_norm_bias is not None and ref_module.layer_norm_bias is not None:
ref_module.layer_norm_bias.copy_(native_module.layer_norm_bias)
nvfp4_recipe = GetRecipes.nvfp4_recipe_to_test(with_rht, with_2d_quantization)
native_outputs = []
ref_outputs = []
for step in range(num_steps):
torch.manual_seed(1234 + step)
torch.cuda.manual_seed(1234 + step)
x_shape = (batch_size, seq_len, in_features)
x_val = torch.normal(mean=0.0, std=1.0, size=x_shape, dtype=x_dtype, device=device)
x_native = x_val.clone().detach().requires_grad_(True)
x_ref = x_native.clone().detach().requires_grad_(True)
grad_output_shape = (batch_size, seq_len, out_features)
grad_output_val = torch.normal(
mean=0.0, std=1.0, size=grad_output_shape, dtype=x_dtype, device=device
)
grad_output = grad_output_val.clone().detach()
# Native forward/backward
cleanup_environment()
with fp8_autocast(enabled=True, fp8_recipe=nvfp4_recipe):
y_native, ln_out_native = native_module(x_native, is_first_microbatch=(step == 0))
y_native.backward(grad_output)
# Reference forward/backward
setup_environment_for_reference(with_rht, with_2d_quantization)
with fp8_autocast(enabled=True, fp8_recipe=nvfp4_recipe):
y_ref, ln_out_ref = ref_module(x_ref)
y_ref.backward(grad_output)
native_outputs.append(
{
"output": y_native.detach().clone(),
"ln_out": ln_out_native.detach().clone(),
"input_grad": (
x_native.grad.detach().clone() if x_native.grad is not None else None
),
"weight_grad": (
native_module.weight.grad.detach().clone()
if native_module.weight.grad is not None
else None
),
"bias_grad": (
native_module.bias.grad.detach().clone()
if bias and native_module.bias.grad is not None
else None
),
}
)
ref_outputs.append(
{
"output": y_ref.detach().clone(),
"ln_out": ln_out_ref.detach().clone(),
"input_grad": (x_ref.grad.detach().clone() if x_ref.grad is not None else None),
"weight_grad": (
ref_module.weight.grad.detach().clone()
if ref_module.weight.grad is not None
else None
),
"bias_grad": (
ref_module.bias.grad.detach().clone()
if bias and ref_module.bias.grad is not None
else None
),
}
)
# Compare results
for step in range(num_steps):
n = native_outputs[step]
r = ref_outputs[step]
torch.testing.assert_close(
n["output"],
r["output"],
atol=1e-6,
rtol=1e-6,
msg=f"Output mismatch at step {step}",
)
torch.testing.assert_close(
n["ln_out"],
r["ln_out"],
atol=1e-6,
rtol=1e-6,
msg=f"LN output mismatch at step {step}",
)
torch.testing.assert_close(
n["input_grad"],
r["input_grad"],
atol=1e-6,
rtol=1e-6,
msg=f"Input gradient mismatch at step {step}",
)
torch.testing.assert_close(
n["weight_grad"],
r["weight_grad"],
atol=1e-6,
rtol=1e-6,
msg=f"Weight gradient mismatch at step {step}",
)
if bias and n["bias_grad"] is not None and r["bias_grad"] is not None:
torch.testing.assert_close(
n["bias_grad"],
r["bias_grad"],
atol=1e-6,
rtol=1e-6,
msg=f"Bias gradient mismatch at step {step}",
)
cleanup_environment()
@pytest.mark.skipif(not recipe_available, reason=reason_for_no_recipe)
@pytest.mark.parametrize(
"in_features, out_features",
[
(128, 256),
(256, 128),
],
)
@pytest.mark.parametrize("bias", [False], ids=["no_bias"])
@pytest.mark.parametrize("x_dtype", [torch.float32, torch.bfloat16], ids=str)
@pytest.mark.parametrize("num_steps", [1], ids=["single_step"])
@pytest.mark.parametrize("normalization", ["LayerNorm", "RMSNorm"], ids=["LayerNorm", "RMSNorm"])
@pytest.mark.parametrize("with_rht", [True, False], ids=["with_rht", "no_rht"])
@pytest.mark.parametrize(
"with_2d_quantization", [True, False], ids=["with_2d_quantization", "no_2d_quantization"]
)
def test_nvfp4_layernorm_linear_versus_reference(
in_features: int,
out_features: int,
bias: bool,
normalization: str,
x_dtype: torch.dtype,
num_steps: int,
with_rht: bool,
with_2d_quantization: bool,
):
if with_rht and x_dtype != torch.bfloat16:
pytest.skip("RHT is only supported for bfloat16 input")
check_nvfp4_layernorm_linear_versus_reference(
in_features=in_features,
out_features=out_features,
bias=bias,
normalization=normalization,
x_dtype=x_dtype,
num_steps=num_steps,
with_rht=with_rht,
with_2d_quantization=with_2d_quantization,
)
# Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# See LICENSE for license information.
import pytest
import torch
import transformer_engine as te
import transformer_engine_torch as tex
from transformer_engine.pytorch.fp8 import FP8GlobalStateManager
from transformer_engine.common.recipe import NVFP4BlockScaling
from transformer_engine.pytorch.constants import TE_DType
from transformer_engine.pytorch.tensor.nvfp4_tensor import (
NVFP4Quantizer,
)
from transformer_engine.pytorch.experimental.quantization_microblock_ref import NVFP4QuantizerRef
from transformer_engine.pytorch.experimental import utils
from transformer_engine.pytorch.fp8 import fp8_autocast, get_fp4_te_dtype
recipe_available, reason_for_no_recipe = FP8GlobalStateManager.is_nvfp4_available()
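# Packed NVFP4 data stores two E2M1 values per byte: the low nibble holds the
# even-indexed element and the high nibble the odd-indexed one. unpack_fp4
# expands each packed byte into two uint8 values (e.g. 0x21 -> [0x1, 0x2]) so
# quantized tensors can be compared element-wise against the reference.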
def unpack_fp4(x: torch.Tensor) -> torch.Tensor:
repeated = x.repeat_interleave(2, dim=1)
repeated[:, 0::2] &= 0x0F
repeated[:, 1::2] >>= 4
return repeated
def check_quantization_nvfp4_versus_reference(
x_dtype: torch.dtype,
M: int,
N: int,
return_transpose: bool,
swizzled_scale: bool,
use_cpp_allocator: bool,
with_2d_quantization: bool,
) -> None:
te_dtype = tex.DType.kFloat4E2M1
# Setup device and random seed
device = "cuda"
seed = 0
torch.manual_seed(seed)
torch.cuda.manual_seed(seed)
# Input
x = torch.randn((M, N), dtype=x_dtype, device=device)
# Quantize
nvfp4_quantizer = NVFP4Quantizer(
fp4_dtype=te_dtype,
rowwise=True,
columnwise=return_transpose,
with_amax_reduction=False,
amax_reduction_group=None,
with_rht=False,
with_post_rht_amax=False,
with_2d_quantization=with_2d_quantization,
)
if use_cpp_allocator:
x_nvfp4_sut = nvfp4_quantizer(x)
else:
x_nvfp4_sut = nvfp4_quantizer.make_empty(
(M, N), dtype=x_dtype, device=device, requires_grad=False
)
x_nvfp4_sut = nvfp4_quantizer.update_quantized(x, x_nvfp4_sut)
# Extract data from NVFP4Tensor
assert x_nvfp4_sut._rowwise_data is not None
qx: torch.Tensor = x_nvfp4_sut._rowwise_data.view(dtype=torch.uint8)
assert x_nvfp4_sut._rowwise_scale_inv is not None
sx: torch.Tensor = x_nvfp4_sut._rowwise_scale_inv
qx_t = (
x_nvfp4_sut._columnwise_data.view(dtype=torch.uint8)
if x_nvfp4_sut._columnwise_data is not None
else None
)
sx_t = x_nvfp4_sut._columnwise_scale_inv
qx_amax = x_nvfp4_sut._amax_rowwise
# Reference quantization
quant_tile_shape = (1, 16) if not with_2d_quantization else (16, 16)
ref_quantizer = NVFP4QuantizerRef(
dtype=utils.Fp4Formats.E2M1,
rowwise=True,
columnwise=return_transpose,
pow_2_scales=False,
eps=0.0,
quant_tile_shape=quant_tile_shape,
)
x_nvfp4_ref = ref_quantizer.quantize(x)
# Extract data from RefNVFP4Tensor
qx_ref = (
unpack_fp4(x_nvfp4_ref.data.view(dtype=torch.uint8))
if x_nvfp4_ref.data is not None
else None
)
sx_ref = x_nvfp4_ref.scale.view(dtype=torch.uint8) if x_nvfp4_ref.scale is not None else None
qx_t_ref = (
unpack_fp4(x_nvfp4_ref.data_t.view(dtype=torch.uint8))
if x_nvfp4_ref.data_t is not None
else None
)
sx_t_ref = (
x_nvfp4_ref.scale_t.view(dtype=torch.uint8) if x_nvfp4_ref.scale_t is not None else None
)
ref_amax = x_nvfp4_ref.global_amax_row
qx = unpack_fp4(qx)
qx_t = unpack_fp4(qx_t) if qx_t is not None else None
torch.testing.assert_close(qx, qx_ref, atol=0.0, rtol=0.0)
# Compare only the valid portion of scale tensors (reference may not have padding)
ref_sx_shape = sx_ref.shape
sx_valid = sx[: ref_sx_shape[0], : ref_sx_shape[1]]
torch.testing.assert_close(sx_valid, sx_ref, atol=0.0, rtol=0.0)
if return_transpose:
torch.testing.assert_close(qx_t, qx_t_ref, atol=0.0, rtol=0.0)
# Compare only the valid portion of transpose scale tensors
ref_sx_t_shape = sx_t_ref.shape
sx_t_valid = sx_t[: ref_sx_t_shape[0], : ref_sx_t_shape[1]]
torch.testing.assert_close(sx_t_valid, sx_t_ref, atol=0.0, rtol=0.0)
torch.testing.assert_close(qx_amax, ref_amax, atol=0.0, rtol=0.0)
@pytest.mark.skipif(not recipe_available, reason=reason_for_no_recipe)
@pytest.mark.parametrize(
"M, N",
[
# full tile cases
(128, 128),
(256, 256),
(256, 1024),
(1024, 256),
# Padding required cases
(256, 272),
(304, 304),
(320, 256),
# Some larger tiles
(2048, 2048),
(1024, 2048),
(2048, 1024),
# Largest tile
(8192, 8192),
],
)
@pytest.mark.parametrize("x_dtype", [torch.float32, torch.bfloat16], ids=str)
@pytest.mark.parametrize(
"return_transpose", [True, False], ids=["quantize_transpose", "skip_transpose"]
)
@pytest.mark.parametrize("swizzled_scale", [False], ids=["linear_scale"])
@pytest.mark.parametrize(
"use_cpp_allocator", [True, False], ids=["cpp_allocator", "python_allocator"]
)
@pytest.mark.parametrize(
"with_2d_quantization", [True, False], ids=["2d_quantization", "1d_quantization"]
)
def test_quantization_block_tiling_versus_reference(
x_dtype: torch.dtype,
M: int,
N: int,
return_transpose: bool,
swizzled_scale: bool,
use_cpp_allocator: bool,
with_2d_quantization: bool,
) -> None:
check_quantization_nvfp4_versus_reference(
x_dtype=x_dtype,
M=M,
N=N,
return_transpose=return_transpose,
swizzled_scale=swizzled_scale,
use_cpp_allocator=use_cpp_allocator,
with_2d_quantization=with_2d_quantization,
)
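# Extrema test below: all-zero inputs exercise the zero-amax corner case, while
# inputs filled with finfo(x_dtype).max exercise saturation at the top of the
# representable range.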
@pytest.mark.skipif(not recipe_available, reason=reason_for_no_recipe)
@pytest.mark.parametrize(
"M, N",
[
(128, 128),
],
)
@pytest.mark.parametrize("x_dtype", [torch.float32, torch.bfloat16], ids=str)
@pytest.mark.parametrize("extrema_high", [False, True], ids=["zeros", "maxes"])
@pytest.mark.parametrize(
"return_transpose", [True, False], ids=["quantize_transpose", "skip_transpose"]
)
@pytest.mark.parametrize(
"use_cpp_allocator", [True, False], ids=["cpp_allocator", "python_allocator"]
)
def test_nvfp4_quantization_extrema_versus_reference(
x_dtype: torch.dtype,
M: int,
N: int,
extrema_high: bool,
return_transpose: bool,
use_cpp_allocator: bool,
):
te_dtype = tex.DType.kFloat4E2M1
device = "cuda"
seed = 0
torch.manual_seed(seed)
torch.cuda.manual_seed(seed)
if extrema_high:
x = torch.full((M, N), torch.finfo(x_dtype).max, dtype=x_dtype, device=device)
else:
x = torch.zeros((M, N), dtype=x_dtype, device=device)
nvfp4_quantizer = NVFP4Quantizer(
fp4_dtype=te_dtype,
rowwise=True,
columnwise=return_transpose,
with_amax_reduction=False,
amax_reduction_group=None,
with_rht=False,
with_post_rht_amax=False,
)
if use_cpp_allocator:
x_nvfp4_sut = nvfp4_quantizer(x)
else:
x_nvfp4_sut = nvfp4_quantizer.make_empty(
(M, N), dtype=x_dtype, device=device, requires_grad=False
)
x_nvfp4_sut = nvfp4_quantizer.update_quantized(x, x_nvfp4_sut)
assert x_nvfp4_sut._rowwise_data is not None
qx = x_nvfp4_sut._rowwise_data.view(dtype=torch.uint8)
assert x_nvfp4_sut._rowwise_scale_inv is not None
sx = x_nvfp4_sut._rowwise_scale_inv
qx_t = (
x_nvfp4_sut._columnwise_data.view(dtype=torch.uint8)
if x_nvfp4_sut._columnwise_data is not None
else None
)
sx_t = x_nvfp4_sut._columnwise_scale_inv
qx_amax = x_nvfp4_sut._amax_rowwise
ref_quantizer = NVFP4QuantizerRef(
dtype=utils.Fp4Formats.E2M1,
rowwise=True,
columnwise=return_transpose,
pow_2_scales=False,
eps=0.0,
quant_tile_shape=(1, 16),
)
x_nvfp4_ref = ref_quantizer.quantize(x)
qx_ref = x_nvfp4_ref.data.view(dtype=torch.uint8) if x_nvfp4_ref.data is not None else None
sx_ref = x_nvfp4_ref.scale.view(dtype=torch.uint8) if x_nvfp4_ref.scale is not None else None
qx_t_ref = (
x_nvfp4_ref.data_t.view(dtype=torch.uint8) if x_nvfp4_ref.data_t is not None else None
)
sx_t_ref = (
x_nvfp4_ref.scale_t.view(dtype=torch.uint8) if x_nvfp4_ref.scale_t is not None else None
)
ref_amax = x_nvfp4_ref.global_amax_row
torch.testing.assert_close(qx, qx_ref, atol=0.0, rtol=0.0)
ref_sx_shape = sx_ref.shape
sx_valid = sx[: ref_sx_shape[0], : ref_sx_shape[1]]
torch.testing.assert_close(sx_valid, sx_ref, atol=0.0, rtol=0.0)
if return_transpose:
torch.testing.assert_close(qx_t, qx_t_ref, atol=0.0, rtol=0.0)
ref_sx_t_shape = sx_t_ref.shape
sx_t_valid = sx_t[: ref_sx_t_shape[0], : ref_sx_t_shape[1]]
torch.testing.assert_close(sx_t_valid, sx_t_ref, atol=0.0, rtol=0.0)
torch.testing.assert_close(qx_amax, ref_amax, atol=0.0, rtol=0.0)
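# E2M1 represents only the magnitudes {0, 0.5, 1, 1.5, 2, 3, 4, 6} (plus sign),
# so the boundary-values test below places inputs just below and above the
# rounding thresholds between these codes within each 16-element microblock.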
@pytest.mark.skipif(not recipe_available, reason=reason_for_no_recipe)
@pytest.mark.parametrize(
"M, N",
[
(16, 128),
(32, 128),
],
)
@pytest.mark.parametrize("x_dtype", [torch.float32, torch.bfloat16], ids=str)
@pytest.mark.parametrize(
"return_transpose", [True, False], ids=["quantize_transpose", "skip_transpose"]
)
@pytest.mark.parametrize(
"use_cpp_allocator", [True, False], ids=["cpp_allocator", "python_allocator"]
)
def test_nvfp4_quantization_boundary_values(
x_dtype: torch.dtype,
M: int,
N: int,
return_transpose: bool,
use_cpp_allocator: bool,
):
"""
Stress rounding/threshold behavior by placing values just below/above
many potential bin edges within each 16-element microblock.
Validates native vs reference byte-for-byte and scale parity.
"""
te_dtype = tex.DType.kFloat4E2M1
device = "cuda"
seed = 123
torch.manual_seed(seed)
torch.cuda.manual_seed(seed)
# Construct a single row with paired boundary values: v-eps, v+eps
# spanning a wide dynamic range to exercise clipping and multiple bins.
# N must be even and a multiple of 16 (the microblock size); both hold for N=128.
base = torch.linspace(-12.0, 12.0, steps=N // 2, dtype=torch.float32, device=device)
eps = torch.full_like(base, 1e-3)
# Avoid zero eps for very small magnitudes
eps = torch.maximum(eps, 1e-4 * torch.ones_like(base))
lower = base - eps
upper = base + eps
row = torch.empty(N, dtype=torch.float32, device=device)
row[0::2] = lower
row[1::2] = upper
x = row.unsqueeze(0).repeat(M, 1).to(dtype=x_dtype)
nvfp4_quantizer = NVFP4Quantizer(
fp4_dtype=te_dtype,
rowwise=True,
columnwise=return_transpose,
with_amax_reduction=False,
amax_reduction_group=None,
with_rht=False,
with_post_rht_amax=False,
)
if use_cpp_allocator:
x_nvfp4_sut = nvfp4_quantizer(x)
else:
x_nvfp4_sut = nvfp4_quantizer.make_empty(
(M, N), dtype=x_dtype, device=device, requires_grad=False
)
x_nvfp4_sut = nvfp4_quantizer.update_quantized(x, x_nvfp4_sut)
assert x_nvfp4_sut._rowwise_data is not None
qx = x_nvfp4_sut._rowwise_data.view(dtype=torch.uint8)
assert x_nvfp4_sut._rowwise_scale_inv is not None
sx = x_nvfp4_sut._rowwise_scale_inv
qx_t = (
x_nvfp4_sut._columnwise_data.view(dtype=torch.uint8)
if x_nvfp4_sut._columnwise_data is not None
else None
)
sx_t = x_nvfp4_sut._columnwise_scale_inv
qx_amax = x_nvfp4_sut._amax_rowwise
ref_quantizer = NVFP4QuantizerRef(
dtype=utils.Fp4Formats.E2M1,
rowwise=True,
columnwise=return_transpose,
pow_2_scales=False,
eps=0.0,
quant_tile_shape=(1, 16),
)
x_nvfp4_ref = ref_quantizer.quantize(x)
qx_ref = x_nvfp4_ref.data.view(dtype=torch.uint8) if x_nvfp4_ref.data is not None else None
sx_ref = x_nvfp4_ref.scale.view(dtype=torch.uint8) if x_nvfp4_ref.scale is not None else None
qx_t_ref = (
x_nvfp4_ref.data_t.view(dtype=torch.uint8) if x_nvfp4_ref.data_t is not None else None
)
sx_t_ref = (
x_nvfp4_ref.scale_t.view(dtype=torch.uint8) if x_nvfp4_ref.scale_t is not None else None
)
ref_amax = x_nvfp4_ref.global_amax_row
torch.testing.assert_close(qx, qx_ref, atol=0.0, rtol=0.0)
# Compare only valid portion of scales (trim any padding)
ref_sx_shape = sx_ref.shape
sx_valid = sx[: ref_sx_shape[0], : ref_sx_shape[1]]
torch.testing.assert_close(sx_valid, sx_ref, atol=0.0, rtol=0.0)
if return_transpose:
torch.testing.assert_close(qx_t, qx_t_ref, atol=0.0, rtol=0.0)
ref_sx_t_shape = sx_t_ref.shape
sx_t_valid = sx_t[: ref_sx_t_shape[0], : ref_sx_t_shape[1]]
torch.testing.assert_close(sx_t_valid, sx_t_ref, atol=0.0, rtol=0.0)
torch.testing.assert_close(qx_amax, ref_amax, atol=0.0, rtol=0.0)
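# Non-contiguous input test below: quantizing a transposed (strided) view must
# match the reference quantizer byte-for-byte, including scales and amax.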
@pytest.mark.skipif(not recipe_available, reason=reason_for_no_recipe)
@pytest.mark.parametrize(
"M, N",
[
(32, 128),
],
)
@pytest.mark.parametrize("x_dtype", [torch.float32, torch.bfloat16], ids=str)
@pytest.mark.parametrize(
"return_transpose", [True, False], ids=["quantize_transpose", "skip_transpose"]
)
@pytest.mark.parametrize(
"use_cpp_allocator", [True, False], ids=["cpp_allocator", "python_allocator"]
)
def test_nvfp4_quantization_noncontiguous_inputs(
x_dtype: torch.dtype,
M: int,
N: int,
return_transpose: bool,
use_cpp_allocator: bool,
):
te_dtype = tex.DType.kFloat4E2M1
device = "cuda"
seed = 17
torch.manual_seed(seed)
torch.cuda.manual_seed(seed)
# Start from a contiguous tensor, then make a non-contiguous view by transpose
x_base = torch.randn((M, N), dtype=x_dtype, device=device)
x_nc = x_base.t() # shape (N, M), non-contiguous
assert not x_nc.is_contiguous()
nvfp4_quantizer = NVFP4Quantizer(
fp4_dtype=te_dtype,
rowwise=True,
columnwise=return_transpose,
with_amax_reduction=False,
amax_reduction_group=None,
with_rht=False,
with_post_rht_amax=False,
)
if use_cpp_allocator:
x_nvfp4_sut = nvfp4_quantizer(x_nc)
else:
x_nvfp4_sut = nvfp4_quantizer.make_empty(
x_nc.shape, dtype=x_dtype, device=device, requires_grad=False
)
x_nvfp4_sut = nvfp4_quantizer.update_quantized(x_nc, x_nvfp4_sut)
assert x_nvfp4_sut._rowwise_data is not None
qx = x_nvfp4_sut._rowwise_data.view(dtype=torch.uint8)
assert x_nvfp4_sut._rowwise_scale_inv is not None
sx = x_nvfp4_sut._rowwise_scale_inv
qx_t = (
x_nvfp4_sut._columnwise_data.view(dtype=torch.uint8)
if x_nvfp4_sut._columnwise_data is not None
else None
)
sx_t = x_nvfp4_sut._columnwise_scale_inv
qx_amax = x_nvfp4_sut._amax_rowwise
ref_quantizer = NVFP4QuantizerRef(
dtype=utils.Fp4Formats.E2M1,
rowwise=True,
columnwise=return_transpose,
pow_2_scales=False,
eps=0.0,
quant_tile_shape=(1, 16),
)
x_nvfp4_ref = ref_quantizer.quantize(x_nc)
qx_ref = x_nvfp4_ref.data.view(dtype=torch.uint8) if x_nvfp4_ref.data is not None else None
sx_ref = x_nvfp4_ref.scale.view(dtype=torch.uint8) if x_nvfp4_ref.scale is not None else None
qx_t_ref = (
x_nvfp4_ref.data_t.view(dtype=torch.uint8) if x_nvfp4_ref.data_t is not None else None
)
sx_t_ref = (
x_nvfp4_ref.scale_t.view(dtype=torch.uint8) if x_nvfp4_ref.scale_t is not None else None
)
ref_amax = x_nvfp4_ref.global_amax_row
# Quantized must match
torch.testing.assert_close(qx, qx_ref, atol=0.0, rtol=0.0)
# Compare only valid portion of scales (trim padding)
ref_sx_shape = sx_ref.shape
sx_valid = sx[: ref_sx_shape[0], : ref_sx_shape[1]]
torch.testing.assert_close(sx_valid, sx_ref, atol=0.0, rtol=0.0)
if return_transpose:
torch.testing.assert_close(qx_t, qx_t_ref, atol=0.0, rtol=0.0)
ref_sx_t_shape = sx_t_ref.shape
sx_t_valid = sx_t[: ref_sx_t_shape[0], : ref_sx_t_shape[1]]
torch.testing.assert_close(sx_t_valid, sx_t_ref, atol=0.0, rtol=0.0)
torch.testing.assert_close(qx_amax, ref_amax, atol=0.0, rtol=0.0)