Unverified commit 3f5b4754, authored by Kirthi Shankar Sivamani and committed by GitHub

[Core][PyTorch] NVFP4 recipe (#2177)



* Add NVFP4 recipe
Signed-off-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com>
Co-authored-by: Frank Sun <frsun@nvidia.com>
Co-authored-by: Oleg Goncharov <ogoncharov@nvidia.com>
Co-authored-by: Zhongbo Zhu <zhongboz@nvidia.com>
Co-authored-by: Evgeny Tsykunov <etsykunov@nvidia.com>
Co-authored-by: Tim Moon <tmoon@nvidia.com>
Co-authored-by: Teddy Do <tdophung@nvidia.com>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci



* Add MathDx dependency to GitHub builds
Signed-off-by: Tim Moon <tmoon@nvidia.com>

* Suggestions from GitHub Copilot
Signed-off-by: Tim Moon <tmoon@nvidia.com>

* Move 2x shape logic from core to PyTorch
Signed-off-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com>

* Fix compilation errors with CUDA 12.1
Signed-off-by: Tim Moon <tmoon@nvidia.com>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci



* SM 70 is not supported in CUDA 13
Signed-off-by: Tim Moon <4406448+timmoon10@users.noreply.github.com>

* Typo
Signed-off-by: Tim Moon <4406448+timmoon10@users.noreply.github.com>

* Revert "Move 2x shape logic from core to PyTorch"

This reverts commit f8b2a2d0111d9af690b43bb98ae448d9a430a185.
Signed-off-by: Tim Moon <tmoon@nvidia.com>

* Added dequantize kernel for FP4
Signed-off-by: Przemek Tredak <ptredak@nvidia.com>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci



* Fix linter warning
Signed-off-by: Tim Moon <tmoon@nvidia.com>

* Add NVFP4 support with fusible ops

Use logical tensor dims for PyTorch NVFP4 tensors. Temporarily add unfused dequantize impl. Fix bug where NVFP4 recipe was not configurable.
Signed-off-by: Tim Moon <tmoon@nvidia.com>

* Fix logic for 2x shapes and move to PyTorch
Signed-off-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com>

* Fix CG test model config
Signed-off-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com>

* Debug NVFP4 tensor size function
Signed-off-by: Tim Moon <tmoon@nvidia.com>

* Proper handling of the RNG state
Signed-off-by: Przemek Tredak <ptredak@nvidia.com>

* Test SR properly
Signed-off-by: Przemek Tredak <ptredak@nvidia.com>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci



* Fix workspace size for GEMM heuristic
Signed-off-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com>

* Fix compile error in C++ NVFP4 test

Also fix some numeric errors when blocks are all zero.
Signed-off-by: Tim Moon <tmoon@nvidia.com>

* Fix distributed test problem shape
Signed-off-by: zhongboz <zhongboz@nvidia.com>

* Proper dim assert for low-precision AG TP
Signed-off-by: zhongboz <zhongboz@nvidia.com>

* Clean up duplicated code in nvfp4_utils.cuh
Signed-off-by: zhongboz <zhongboz@nvidia.com>

* Lint
Signed-off-by: zhongboz <zhongboz@nvidia.com>

* pylint: disable=unused-argument
Signed-off-by: zhongboz <zhongboz@nvidia.com>

* `nvte_cublas_gemm_v2` to take alpha pointer (#12)

* Make nvte_cublas_gemm_v2 take alpha/beta pointers
Signed-off-by: Phuong Nguyen <phuonguyen@nvidia.com>

* Users are expected to pass a valid C_tensor
Signed-off-by: Phuong Nguyen <phuonguyen@nvidia.com>

* Typos
Signed-off-by: Phuong Nguyen <phuonguyen@nvidia.com>

* API to have const float* alpha
Signed-off-by: Phuong Nguyen <phuonguyen@nvidia.com>

* Minor tweaks

Support arbitrary beta scales. Increase workspace to be aligned to 128 bytes.
Signed-off-by: Tim Moon <tmoon@nvidia.com>

* Debug IMA with alpha pointer
Signed-off-by: Tim Moon <tmoon@nvidia.com>

---------
Signed-off-by: Phuong Nguyen <phuonguyen@nvidia.com>
Signed-off-by: Tim Moon <tmoon@nvidia.com>
Co-authored-by: Tim Moon <tmoon@nvidia.com>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci



* Support fused amax kernels with NVFP4 quantization
Signed-off-by: Tim Moon <tmoon@nvidia.com>

* Disable fused amax with cuDNN LayerNorm kernel
Signed-off-by: Tim Moon <tmoon@nvidia.com>

* Add NVFP4 cases to distributed tests for TE ops
Signed-off-by: Tim Moon <tmoon@nvidia.com>

* Change assert to NVTE_CHECK in the Hadamard cast fusion
Signed-off-by: Przemek Tredak <ptredak@nvidia.com>

* Fix compile error
Signed-off-by: Tim Moon <tmoon@nvidia.com>

* Use global thread IDs for Philox subsequences
Signed-off-by: Tim Moon <tmoon@nvidia.com>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci



* Add shape checks for NVFP4 cast kernels
Signed-off-by: Tim Moon <tmoon@nvidia.com>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci



* Do not fuse amax if cuDNN normalization is forced by envvar
Signed-off-by: Przemek Tredak <ptredak@nvidia.com>

---------
Signed-off-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com>
Signed-off-by: Tim Moon <tmoon@nvidia.com>
Signed-off-by: Tim Moon <4406448+timmoon10@users.noreply.github.com>
Signed-off-by: Przemek Tredak <ptredak@nvidia.com>
Signed-off-by: zhongboz <zhongboz@nvidia.com>
Signed-off-by: Phuong Nguyen <phuonguyen@nvidia.com>
Co-authored-by: Frank Sun <frsun@nvidia.com>
Co-authored-by: Oleg Goncharov <ogoncharov@nvidia.com>
Co-authored-by: Zhongbo Zhu <zhongboz@nvidia.com>
Co-authored-by: Evgeny Tsykunov <etsykunov@nvidia.com>
Co-authored-by: Tim Moon <tmoon@nvidia.com>
Co-authored-by: Teddy Do <tdophung@nvidia.com>
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
Co-authored-by: Tim Moon <4406448+timmoon10@users.noreply.github.com>
Co-authored-by: Przemek Tredak <ptredak@nvidia.com>
Co-authored-by: Phuong Nguyen <phuonguyen@nvidia.com>
parent dfeef1a2
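As a hedged orientation sketch (not code from this PR): the recipe added below is driven through the usual Transformer Engine autocast entry points that the new tests import. Only te.Linear, fp8_autocast, and NVFP4BlockScaling are assumed here, and all of them appear in the diffs that follow.

import torch
import transformer_engine.pytorch as te
from transformer_engine.common.recipe import NVFP4BlockScaling

# Minimal usage sketch, assuming an NVFP4-capable GPU and current TE APIs.
model = te.Linear(1024, 1024, params_dtype=torch.bfloat16).cuda()
x = torch.randn(128, 1024, dtype=torch.bfloat16, device="cuda")
with te.fp8_autocast(enabled=True, fp8_recipe=NVFP4BlockScaling()):
    y = model(x)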
# Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# See LICENSE for license information.

# NOTE: This file depends on the success of test_nvfp4_quantize_exact.py.
# It is kept separate to make sure all the functionalities work as expected;
# otherwise the reference implementation would get messy.
# Due to the structure of NVFP4Quantizer, we need to test the RHT functionality
# together with the quantization functionality.
from typing import Tuple
import math

import pytest
import torch

import transformer_engine as te
import transformer_engine_torch as tex
from transformer_engine.pytorch.fp8 import FP8GlobalStateManager
from transformer_engine.common.recipe import NVFP4BlockScaling
from transformer_engine.pytorch.constants import TE_DType
from transformer_engine.pytorch.tensor.nvfp4_tensor import (
    NVFP4Quantizer,
)
from transformer_engine.pytorch.experimental.quantization_microblock_ref import NVFP4QuantizerRef
from transformer_engine.pytorch.experimental import utils
from transformer_engine.pytorch.fp8 import fp8_autocast, get_fp4_te_dtype

recipe_available, reason_for_no_recipe = FP8GlobalStateManager.is_nvfp4_available()
def unpack_fp4(x: torch.Tensor) -> torch.Tensor:
    """Unpack two FP4 codes per byte into one uint8 code per element.

    The low nibble holds the first element of each pair, the high nibble
    holds the second.
    """
    repeated = x.repeat_interleave(2, dim=1)
    repeated[:, 0::2] &= 0x0F
    repeated[:, 1::2] >>= 4
    return repeated
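
# Editorial sketch (not part of the original file): with the layout documented
# above, the byte 0x52 unpacks to [0x2, 0x5].
_demo_packed = torch.tensor([[0x52]], dtype=torch.uint8)
assert unpack_fp4(_demo_packed).tolist() == [[0x2, 0x5]]
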
def check_quantization_nvfp4_versus_reference(
    x_dtype: torch.dtype,
    M: int,
    N: int,
    contiguous: bool,
    return_transpose: bool,
    use_cpp_allocator: bool,
    swizzled_scale: bool = False,
    hadamard_dimension: int = 16,
    with_rht: bool = True,
    with_post_rht_amax: bool = True,
    with_random_sign_mask: bool = True,
) -> None:
    assert with_rht and with_post_rht_amax, "RHT and post-RHT amax reduction must be enabled."
    te_dtype = tex.DType.kFloat4E2M1

    # Setup device and random seed
    device = "cuda"
    seed = 0
    torch.manual_seed(seed)
    torch.cuda.manual_seed(seed)

    # Input
    x = torch.randn((M, N), dtype=x_dtype, device=device)
    x = x.transpose(0, 1) if not contiguous else x

    # Quantize
    nvfp4_quantizer = NVFP4Quantizer(
        fp4_dtype=te_dtype,
        rowwise=True,
        columnwise=return_transpose,
        with_amax_reduction=False,
        amax_reduction_group=None,
        with_rht=with_rht,
        with_post_rht_amax=with_post_rht_amax,
        with_random_sign_mask=with_random_sign_mask,
    )
    if use_cpp_allocator:
        x_nvfp4_sut = nvfp4_quantizer(x)
    else:
        x_nvfp4_sut = nvfp4_quantizer.make_empty(
            x.shape, dtype=x_dtype, device=device, requires_grad=False
        )
        x_nvfp4_sut = nvfp4_quantizer.update_quantized(x, x_nvfp4_sut)

    # Extract data from NVFP4Tensor
    assert x_nvfp4_sut._rowwise_data is not None
    qx: torch.Tensor = x_nvfp4_sut._rowwise_data.view(dtype=torch.uint8)
    assert x_nvfp4_sut._rowwise_scale_inv is not None
    sx: torch.Tensor = x_nvfp4_sut._rowwise_scale_inv
    qx_t = (
        x_nvfp4_sut._columnwise_data.view(dtype=torch.uint8)
        if x_nvfp4_sut._columnwise_data is not None
        else None
    )
    sx_t = x_nvfp4_sut._columnwise_scale_inv
    amax_rowwise = x_nvfp4_sut._amax_rowwise
    amax_colwise = x_nvfp4_sut._amax_columnwise
    qx = unpack_fp4(qx)
    qx_t = unpack_fp4(qx_t) if qx_t is not None else None

    # Reference quantization using NVFP4QuantizerRef with built-in RHT
    ref_quantizer = NVFP4QuantizerRef(
        dtype=utils.Fp4Formats.E2M1,
        rowwise=True,
        columnwise=return_transpose,
        pow_2_scales=False,
        eps=0.0,
        quant_tile_shape=(1, 16),
        with_rht=with_rht,
        with_random_sign_mask=with_random_sign_mask,
    )
    x_nvfp4_ref = ref_quantizer.quantize(x)

    # Extract data from RefNVFP4Tensor
    qx_ref = (
        unpack_fp4(x_nvfp4_ref.data.view(dtype=torch.uint8))
        if x_nvfp4_ref.data is not None
        else None
    )
    sx_ref = x_nvfp4_ref.scale.view(dtype=torch.uint8) if x_nvfp4_ref.scale is not None else None
    ref_amax_rowwise = x_nvfp4_ref.global_amax_row
    if return_transpose:
        assert x_nvfp4_ref.data_t is not None
        assert x_nvfp4_ref.scale_t is not None
        qx_t_ref = unpack_fp4(x_nvfp4_ref.data_t.view(dtype=torch.uint8))
        sx_t_ref = x_nvfp4_ref.scale_t.view(dtype=torch.uint8)
        # Compute transpose amax using the same reference quantizer
        x_t_for_amax = (
            ref_quantizer._apply_rht(x.t().contiguous()) if with_rht else x.t().contiguous()
        )
        ref_amax_colwise_t = torch.max(torch.abs(x_t_for_amax)).to(torch.float32).view(1)
    else:
        qx_t_ref = None
        sx_t_ref = None
        ref_amax_colwise_t = None

    torch.testing.assert_close(amax_rowwise, ref_amax_rowwise, atol=0.0, rtol=0.0)
    torch.testing.assert_close(qx, qx_ref, atol=0.0, rtol=0.0)
    # Compare only the valid portion of scale tensors (reference may not have padding)
    ref_sx_shape = sx_ref.shape
    sx_valid = sx[: ref_sx_shape[0], : ref_sx_shape[1]]
    torch.testing.assert_close(sx_valid, sx_ref, atol=0.0, rtol=0.0)
    if return_transpose:
        torch.testing.assert_close(amax_colwise, ref_amax_colwise_t, atol=0.0, rtol=0.0)
        torch.testing.assert_close(qx_t, qx_t_ref, atol=0.0, rtol=0.0)
        # Compare only the valid portion of transpose scale tensors
        ref_sx_t_shape = sx_t_ref.shape
        sx_t_valid = sx_t[: ref_sx_t_shape[0], : ref_sx_t_shape[1]]
        torch.testing.assert_close(sx_t_valid, sx_t_ref, atol=0.0, rtol=0.0)

@pytest.mark.skipif(not recipe_available, reason=reason_for_no_recipe)
@pytest.mark.parametrize(
    "M, N",
    [
        # Full tile cases
        (128, 128),
        (256, 256),
        (256, 1024),
        (1024, 256),
        # Padding required cases
        (256, 272),
        (304, 304),
        (320, 256),
        # Some larger tiles
        (2048, 2048),
        (1024, 2048),
        (2048, 1024),
        # Real shapes
        (8192, 5120),
        (8192, 10240),
        (8192, 2560),
        (8192, 11328),
        (8192, 512),
        (8192, 3584),
        (5120, 8192),
        (10240, 8192),
        (2560, 8192),
        (11328, 8192),
        (512, 8192),
        (3584, 8192),
        (4096, 16384),
        (14336, 16384),
    ],
)
@pytest.mark.parametrize("x_dtype", [torch.bfloat16], ids=str)
@pytest.mark.parametrize(
    "return_transpose", [True, False], ids=["quantize_transpose", "skip_transpose"]
)
@pytest.mark.parametrize(
    "use_cpp_allocator", [True, False], ids=["cpp_allocator", "python_allocator"]
)
@pytest.mark.parametrize(
    "with_random_sign_mask", [True, False], ids=["with_random_sign_mask", "no_random_sign_mask"]
)
def test_rht_with_quantization_block_tiling_versus_reference(
    x_dtype: torch.dtype,
    M: int,
    N: int,
    return_transpose: bool,
    use_cpp_allocator: bool,
    with_random_sign_mask: bool,
) -> None:
    check_quantization_nvfp4_versus_reference(
        x_dtype=x_dtype,
        M=M,
        N=N,
        contiguous=True,
        return_transpose=return_transpose,
        use_cpp_allocator=use_cpp_allocator,
        with_random_sign_mask=with_random_sign_mask,
    )


@pytest.mark.skipif(not recipe_available, reason=reason_for_no_recipe)
@pytest.mark.parametrize(
    "M, N",
    [
        (32, 128),
    ],
)
@pytest.mark.parametrize("x_dtype", [torch.bfloat16], ids=str)
@pytest.mark.parametrize(
    "return_transpose", [True, False], ids=["quantize_transpose", "skip_transpose"]
)
@pytest.mark.parametrize(
    "use_cpp_allocator", [True, False], ids=["cpp_allocator", "python_allocator"]
)
@pytest.mark.parametrize(
    "with_random_sign_mask", [True, False], ids=["with_random_sign_mask", "no_random_sign_mask"]
)
def test_nvfp4_quantization_noncontiguous_inputs(
    x_dtype: torch.dtype,
    M: int,
    N: int,
    return_transpose: bool,
    use_cpp_allocator: bool,
    with_random_sign_mask: bool,
):
    check_quantization_nvfp4_versus_reference(
        x_dtype=x_dtype,
        M=M,
        N=N,
        contiguous=False,
        return_transpose=return_transpose,
        use_cpp_allocator=use_cpp_allocator,
        with_random_sign_mask=with_random_sign_mask,
    )
# Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# See LICENSE for license information.

import pytest
import torch

from transformer_engine.pytorch.fp8 import FP8GlobalStateManager
from transformer_engine.pytorch.tensor.nvfp4_tensor import NVFP4Quantizer

recipe_available, reason_for_no_recipe = FP8GlobalStateManager.is_nvfp4_available()

seed = 12345
torch.manual_seed(seed)
torch.cuda.manual_seed(seed)


def unpack_fp4(x: torch.Tensor) -> torch.Tensor:
    repeated = x.repeat_interleave(2, dim=1)
    repeated[:, 0::2] &= 0x0F
    repeated[:, 1::2] >>= 4
    return repeated

_FP4_LUT = torch.tensor(
    [
        0.0,  # 0: 0000 - zero
        0.5,  # 1: 0001 - smallest positive (subnormal)
        1.0,  # 2: 0010
        1.5,  # 3: 0011
        2.0,  # 4: 0100
        3.0,  # 5: 0101
        4.0,  # 6: 0110
        6.0,  # 7: 0111 - largest positive normal
        -0.0,  # 8: 1000 - negative zero
        -0.5,  # 9: 1001 - smallest negative (subnormal)
        -1.0,  # 10: 1010
        -1.5,  # 11: 1011
        -2.0,  # 12: 1100
        -3.0,  # 13: 1101
        -4.0,  # 14: 1110
        -6.0,  # 15: 1111 - largest negative normal
    ],
    dtype=torch.float32,
)

def fp4_to_fp32(fp4: torch.Tensor) -> torch.Tensor:
    # Convert FP4 indices to their corresponding floating-point values.
    # Each index (0-15) represents a 4-bit FP4 value in E2M1 format;
    # values are based on the FP4 E2M1 specification.
    fp4_lut = _FP4_LUT.to(fp4.device)
    return fp4_lut[fp4.to(torch.long)]
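
# Editorial sketch (assumption: standard E2M1 layout of 1 sign bit, 2 exponent
# bits with bias 1, and 1 mantissa bit) showing how the _FP4_LUT entries arise.
def _e2m1_decode_demo(code: int) -> float:
    sign = -1.0 if code & 0b1000 else 1.0
    exp = (code >> 1) & 0b11
    man = code & 0b1
    if exp == 0:  # exponent bits 00: subnormal, value is man * 0.5
        return sign * man * 0.5
    return sign * (1.0 + 0.5 * man) * 2.0 ** (exp - 1)

assert all(_e2m1_decode_demo(i) == _FP4_LUT[i].item() for i in range(16))
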

def dequantize_fp4(qx: torch.Tensor, sx: torch.Tensor, amax: torch.Tensor) -> torch.Tensor:
    sf = sx.repeat_interleave(16, dim=1).view(torch.float8_e4m3fn).to(torch.float32)
    dqx = fp4_to_fp32(unpack_fp4(qx))
    sf = sf[: dqx.shape[0], : dqx.shape[1]]
    dequant = dqx * sf * (amax / (6.0 * 448))
    return dequant
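
# Editorial sketch (not part of the original file): the two-level scaling above
# decodes value = code * s_block * (amax / (6 * 448)), where 6 is the largest
# E2M1 magnitude and 448 the largest E4M3 magnitude. A value at the tensor-wide
# amax (code 6, block scale 448) therefore decodes back exactly:
_demo_amax = torch.tensor(12.0)
assert torch.isclose(6.0 * 448.0 * (_demo_amax / (6.0 * 448.0)), _demo_amax)
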

def RHT(x: torch.Tensor) -> torch.Tensor:
    def get_wgrad_sign_vector() -> torch.Tensor:
        """Hard-coded signs for the Hadamard transform"""
        return torch.tensor(
            [
                1.0, 1.0, 1.0, -1.0, 1.0, -1.0, -1.0, -1.0,
                -1.0, -1.0, -1.0, 1.0, -1.0, 1.0, -1.0, -1.0,
            ],
            dtype=torch.float32,
        )

    def _build_hadamard_matrix(
        size: int, device: torch.device, dtype: torch.dtype, with_random_sign_mask: bool = True
    ) -> torch.Tensor:
        """Construct a Hadamard matrix of a given power-of-two size with entries +-1.

        Uses the Sylvester construction to avoid a SciPy dependency.
        """
        assert (size & (size - 1)) == 0, "Hadamard size must be a power of two"
        h = torch.ones((1, 1), device=device, dtype=torch.float32)
        while h.shape[0] < size:
            h = torch.cat(
                [
                    torch.cat([h, h], dim=1),
                    torch.cat([h, -h], dim=1),
                ],
                dim=0,
            )
        if with_random_sign_mask:
            sign_mat = get_wgrad_sign_vector().to(device) * torch.eye(
                size, device=device, dtype=torch.float32
            )
            h = sign_mat @ h
        return h.to(dtype)

    rht_dim = 16
    # Build H and scale
    H = _build_hadamard_matrix(rht_dim, x.device, x.dtype)
    scale = 1.0 / float(rht_dim) ** 0.5
    # Perform blockwise transform along the last dimension
    original_shape = x.shape
    x_mat = x.contiguous().view(-1, rht_dim)
    # The sign mask is already folded into H by _build_hadamard_matrix above
    # (with_random_sign_mask defaults to True)
    transform = H * scale
    out = x_mat @ transform
    return out.view(original_shape)
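
# Editorial check (not part of the original file): a Sylvester-constructed
# Hadamard matrix satisfies H @ H.T == n * I, so the scaled transform above is
# orthogonal; the diagonal sign mask S keeps this property since S @ S.T == I.
_h = torch.ones((1, 1))
while _h.shape[0] < 16:
    _h = torch.cat([torch.cat([_h, _h], dim=1), torch.cat([_h, -_h], dim=1)], dim=0)
assert torch.equal(_h @ _h.T, 16.0 * torch.eye(16))
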

def quantize_fp4(
    x: torch.Tensor, use_stochastic_rounding: bool, use_2D: bool, use_RHT: bool
) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
    nvfp4_quantizer = NVFP4Quantizer(
        rowwise=True,
        columnwise=True,
        with_amax_reduction=False,
        amax_reduction_group=None,
        with_rht=use_RHT,
        with_post_rht_amax=True,
        stochastic_rounding=use_stochastic_rounding,
        with_2d_quantization=use_2D,
    )
    x_nvfp4_sut = nvfp4_quantizer(x)
    # Extract data from NVFP4Tensor
    assert x_nvfp4_sut._rowwise_data is not None
    qx: torch.Tensor = x_nvfp4_sut._rowwise_data.view(dtype=torch.uint8)
    assert x_nvfp4_sut._rowwise_scale_inv is not None
    sx: torch.Tensor = x_nvfp4_sut._rowwise_scale_inv
    assert x_nvfp4_sut._columnwise_data is not None
    qx_t: torch.Tensor = x_nvfp4_sut._columnwise_data.view(dtype=torch.uint8)
    assert x_nvfp4_sut._columnwise_scale_inv is not None
    sx_t: torch.Tensor = x_nvfp4_sut._columnwise_scale_inv
    return qx, sx, qx_t, sx_t

def check_quantization_nvfp4_versus_reference(
    x_dtype: torch.dtype, M: int, N: int, use_2D: bool, use_RHT: bool
) -> None:
    device = "cuda"
    torch.manual_seed(seed)
    n_iters = 50
    x = torch.randn((M, N), dtype=x_dtype, device=device) * 2 - 1
    y = x.t().contiguous()
    if use_RHT:
        y = RHT(y)
    amax = torch.max(torch.abs(x)).float()

    q_rn, s_rn, q_t_rn, s_t_rn = quantize_fp4(
        x, use_stochastic_rounding=False, use_2D=use_2D, use_RHT=use_RHT
    )
    dq_rn = dequantize_fp4(q_rn, s_rn, amax)
    dq_t_rn = dequantize_fp4(q_t_rn, s_t_rn, amax)
    error_rn = (dq_rn - x).float()
    me_rn = torch.sqrt((error_rn * error_rn).mean())
    error_t_rn = (dq_t_rn - y).float()
    me_t_rn = torch.sqrt((error_t_rn * error_t_rn).mean())

    sr_result = torch.zeros_like(x).float()
    sr_t_result = torch.zeros_like(x).float().t().contiguous()
    for i in range(n_iters):
        q_sr, s_sr, q_t_sr, s_t_sr = quantize_fp4(
            x, use_stochastic_rounding=True, use_2D=use_2D, use_RHT=use_RHT
        )
        dq_sr = dequantize_fp4(q_sr, s_sr, amax)
        dq_t_sr = dequantize_fp4(q_t_sr, s_t_sr, amax)
        sr_result += dq_sr.float()
        sr_t_result += dq_t_sr.float()
        # sr_result_tmp = sr_result / (i + 1)
        # error_sr = (sr_result_tmp - x).float()
        # me_sr = torch.sqrt((error_sr * error_sr).mean())
        # sr_t_result_tmp = sr_t_result / (i + 1)
        # error_t_sr = (sr_t_result_tmp - y).float()
        # me_t_sr = torch.sqrt((error_t_sr * error_t_sr).mean())
        # print(f"Iteration {i}: RMSE SR: {me_sr:.3e} | RMSE RN: {me_rn:.3e}")
        # print(f"Iteration {i}: RMSE SR_t: {me_t_sr:.3e} | RMSE RN_t: {me_t_rn:.3e}")

    # Average the stochastic-rounding results; the mean should be more
    # accurate than the round-to-nearest result.
    sr_result /= n_iters
    error_sr = (sr_result - x).float()
    me_sr = torch.sqrt((error_sr * error_sr).mean())
    sr_t_result /= n_iters
    error_t_sr = (sr_t_result - y).float()
    me_t_sr = torch.sqrt((error_t_sr * error_t_sr).mean())
    print(f"RMSE SR: {me_sr:.3e} | RMSE RN: {me_rn:.3e}")
    print(f"RMSE SR_t: {me_t_sr:.3e} | RMSE RN_t: {me_t_rn:.3e}")
    assert me_sr < me_rn, "Stochastic rounding failed - error larger than round-to-nearest."
    assert me_t_sr < me_t_rn, "Stochastic rounding failed - error larger than round-to-nearest."
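
# Editorial sketch (not part of the original file): why averaging works -
# stochastic rounding is unbiased, so the mean over many draws converges to
# the input, while round-to-nearest keeps a fixed per-element bias.
def _sr_unbiasedness_demo() -> None:
    torch.manual_seed(0)
    v = torch.full((100_000,), 2.3)  # sits between the E2M1 codes 2.0 and 3.0
    p_up = (v - 2.0) / (3.0 - 2.0)   # probability of rounding up
    sr = 2.0 + (torch.rand_like(v) < p_up).float()
    assert abs(sr.mean().item() - 2.3) < 0.01  # RN would give exactly 2.0
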

@pytest.mark.skipif(not recipe_available, reason=reason_for_no_recipe)
@pytest.mark.parametrize(
    "M, N",
    [
        (8192, 8192),
        (8192, 8256),  # to test the nonfused RHT path
    ],
)
@pytest.mark.parametrize("x_dtype", [torch.float32, torch.bfloat16], ids=str)
@pytest.mark.parametrize("use_2D", [False, True], ids=str)
@pytest.mark.parametrize("use_RHT", [False, True], ids=str)
def test_quantization_block_tiling_versus_reference(
    x_dtype: torch.dtype,
    use_2D: bool,
    use_RHT: bool,
    M: int,
    N: int,
) -> None:
    if x_dtype == torch.float32 and use_RHT:
        pytest.skip("RHT is only supported with bfloat16")
    check_quantization_nvfp4_versus_reference(
        x_dtype=x_dtype,
        use_2D=use_2D,
        use_RHT=use_RHT,
        M=M,
        N=N,
    )
@@ -32,12 +32,59 @@ mxfp8_available, _ = FP8GlobalStateManager.is_mxfp8_available()
reset_rng_states()

model_configs = {
-    "small": ModelConfig(32, 2, 2, 32),
+    "small": ModelConfig(2, 32, 2, 32),
}

def nvfp4_vanilla():
    nvfp4_recipe = recipe.NVFP4BlockScaling()
    nvfp4_recipe.fp4_quant_fwd_inp = recipe.QParams()
    nvfp4_recipe.fp4_quant_fwd_weight = recipe.QParams()
    nvfp4_recipe.fp4_quant_bwd_grad = recipe.QParams()
    return nvfp4_recipe


def nvfp4_rht_and_2d_quantization():
    nvfp4_recipe = recipe.NVFP4BlockScaling()
    nvfp4_recipe.fp4_quant_fwd_inp = recipe.QParams(
        random_hadamard_transform=True, fp4_2d_quantization=False
    )
    nvfp4_recipe.fp4_quant_fwd_weight = recipe.QParams(
        random_hadamard_transform=False, fp4_2d_quantization=True
    )
    nvfp4_recipe.fp4_quant_bwd_grad = recipe.QParams(
        random_hadamard_transform=True, fp4_2d_quantization=False
    )
    return nvfp4_recipe


def check_rht_usage(recipe: recipe.Recipe) -> bool:
    # If using RHT, only bf16 inputs are supported. Check fp4_quant_fwd_inp,
    # fp4_quant_fwd_weight, and fp4_quant_bwd_grad.
    if recipe.nvfp4():
        if (
            recipe.fp4_quant_fwd_inp.random_hadamard_transform
            or recipe.fp4_quant_fwd_weight.random_hadamard_transform
            or recipe.fp4_quant_bwd_grad.random_hadamard_transform
        ):
            return True
    return False


def get_nvfp4_inp_supported_dtypes(recipe: recipe.Recipe, dtype: torch.dtype) -> list[torch.dtype]:
    supported_input_dtypes = []
    if recipe.nvfp4():
        supported_input_dtypes.append(torch.bfloat16)
        # If not using RHT, fp32 is supported as well
        if not check_rht_usage(recipe):
            supported_input_dtypes.append(torch.float32)
    return supported_input_dtypes
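
# Editorial sketch (not part of the original file): how the helpers above
# combine; RHT-enabled recipes only support bfloat16 inputs.
_demo_recipe = nvfp4_rht_and_2d_quantization()
assert check_rht_usage(_demo_recipe)
assert get_nvfp4_inp_supported_dtypes(_demo_recipe, torch.bfloat16) == [torch.bfloat16]
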
fp8_recipes = []
if mxfp8_available:
    fp8_recipes.append(recipe.MXFP8BlockScaling())
    fp8_recipes.append(nvfp4_rht_and_2d_quantization())
if fp8_block_scaling_available:
    fp8_recipes.append(recipe.Float8BlockScaling())
if fp8_available:
@@ -278,7 +325,7 @@ def _test_cuda_graphs(
@pytest.mark.parametrize("module", _test_cuda_graphs_modules)
@pytest.mark.parametrize("dtype", dtypes)
@pytest.mark.parametrize("fp8_params", (False, True))
-@pytest.mark.parametrize("fp8_recipe", fp8_recipes + [None])
+@pytest.mark.parametrize("fp8_recipe", fp8_recipes + [None], ids=lambda r: type(r).__name__)
def test_make_graphed_callables(
    *,
    module: str,
@@ -295,8 +342,18 @@ def test_make_graphed_callables(
        pytest.skip("FP8 needed for FP8 parameters.")
    if fp8_weight_caching and not fp8:
        pytest.skip("FP8 needed for FP8 parameters.")
-    if fp8 and fp8_recipe.float8_block_scaling() and module == "linear_op":
-        pytest.skip("Module not yet supported for float8_block_scaling with CUDA graphs")
+    if fp8 and (fp8_recipe.float8_block_scaling() or fp8_recipe.nvfp4()) and module == "linear_op":
+        pytest.skip(
+            f"Module not yet supported for {fp8_recipe.__class__.__name__} with CUDA graphs"
+        )
+    if fp8 and fp8_recipe.nvfp4():
+        if dtype not in get_nvfp4_inp_supported_dtypes(fp8_recipe, dtype):
+            pytest.skip(
+                f"Input dtype {dtype} not supported for NVFP4 Recipe"
+                f" {fp8_recipe.__class__.__name__}"
+            )
+        if fp8_params:
+            pytest.skip("NVFP4 params not supported")

    # Run model with different CUDA graph settings.
    model_config = model_configs[model_config]
@@ -334,17 +391,19 @@ _test_make_graphed_callables_with_fp8_weight_caching_modules = [
    "module",
    _test_make_graphed_callables_with_fp8_weight_caching_modules,
)
+@pytest.mark.parametrize("dtype", dtypes)
@pytest.mark.parametrize("fp8_params", (False, True))
-@pytest.mark.parametrize("fp8_recipe", fp8_recipes)
+@pytest.mark.parametrize("fp8_recipe", fp8_recipes, ids=lambda r: type(r).__name__)
def test_make_graphed_callables_with_fp8_weight_caching(
    *,
    module: str,
+    dtype: torch.dtype,
    fp8_params: bool,
    fp8_recipe: recipe.Recipe,
) -> None:
    test_make_graphed_callables(
        module=module,
-        dtype=torch.float32,
+        dtype=dtype,
        fp8_params=fp8_params,
        fp8_recipe=fp8_recipe,
        fp8_weight_caching=True,
......
@@ -10,7 +10,6 @@ import pytest
import transformer_engine.pytorch as te
-import transformer_engine_torch as tex
import transformer_engine_torch as tex
from transformer_engine.pytorch.fp8 import FP8GlobalStateManager
from transformer_engine.common.recipe import Float8CurrentScaling
from transformer_engine.pytorch.fp8 import fp8_autocast, get_fp8_torch_dtype
@@ -273,6 +272,14 @@ class TestFP8RecipeLinearBase:
            if bgrad_list is not None and bgrad is not None:
                bgrad_list.append(bgrad.detach().clone())

        # Stack the results
        return (
            torch.stack(y_q_list),
            torch.stack(dgrad_list),
            torch.stack(wgrad_list),
            torch.stack(bgrad_list) if bgrad_list is not None else None,
        )

    @classmethod
    def run_linear(
        cls,
......
@@ -19,6 +19,7 @@ from transformer_engine.pytorch.fp8 import (
    fp8_model_init,
)
from transformer_engine.pytorch.tensor.float8_tensor import Float8Quantizer
+from transformer_engine.pytorch.tensor.nvfp4_tensor import NVFP4Quantizer
import transformer_engine.pytorch.ops as te_ops
from transformer_engine.pytorch import Linear, LayerNormLinear, LayerNormMLP, GroupedLinear
from transformer_engine.pytorch.distributed import fp8_autocast
@@ -499,3 +500,39 @@ class TestFP8Recipe:
            y = module(x, [batch_size])
        else:
            y = module(x)


fp4_available, reason_for_no_fp4 = FP8GlobalStateManager.is_nvfp4_available()


@pytest.mark.skipif(not fp4_available, reason=reason_for_no_fp4)
@pytest.mark.parametrize("dtype", [torch.float32, torch.bfloat16], ids=str)
@pytest.mark.parametrize(
    "M, N",
    [
        # full tile cases
        (128, 128),
        (256, 1024),
        (1024, 256),
        # Padding required cases
        (256, 272),
        (304, 304),
        (320, 256),
        # largest tile
        (8192, 8192),
    ],
)
def test_fp4_dequantize(dtype, M, N):
    q = NVFP4Quantizer()
    a = torch.rand((M, N)).cuda().to(dtype=dtype)
    starting_tensor = q(a)
    dequantized_tensor = starting_tensor.dequantize()
    new_tensor = q(dequantized_tensor)
    torch.testing.assert_close(
        new_tensor._rowwise_data,
        starting_tensor._rowwise_data,
        rtol=0,
        atol=0,
    )
    new_dequantized_tensor = new_tensor.dequantize()
    torch.testing.assert_close(dequantized_tensor, new_dequantized_tensor)
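
# Editorial sketch (not part of the original file): the exact round-trip above
# relies on quantization being idempotent - re-quantizing an already
# representable value is a no-op. A scalar analogue with the E2M1 magnitudes:
def _idempotence_demo() -> None:
    lut = torch.tensor([0.0, 0.5, 1.0, 1.5, 2.0, 3.0, 4.0, 6.0])
    quant = lambda t: lut[(t.unsqueeze(-1) - lut).abs().argmin(dim=-1)]
    t = torch.rand(32) * 6
    assert torch.equal(quant(quant(t)), quant(t))
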
@@ -87,9 +87,19 @@ model_configs = {
    "large": ModelConfig(2, 128, 4, 128, num_layers=1),
}


def nvfp4_vanilla():
    nvfp4_recipe = recipe.NVFP4BlockScaling()
    nvfp4_recipe.fp4_quant_fwd_inp = recipe.QParams()
    nvfp4_recipe.fp4_quant_fwd_weight = recipe.QParams()
    nvfp4_recipe.fp4_quant_bwd_grad = recipe.QParams()
    return nvfp4_recipe


fp8_recipes = []
if mxfp8_available:
    fp8_recipes.append(recipe.MXFP8BlockScaling())
    fp8_recipes.append(nvfp4_vanilla())  # TODO: fix check for this
if fp8_block_scaling_available:
    fp8_recipes.append(recipe.Float8BlockScaling())
if fp8_available:
@@ -379,6 +389,8 @@ def test_sanity_layernorm_linear(
    if fp8_recipe is not None:
        if not is_fp8_supported(config):
            pytest.skip("Model config does not support FP8")
        if fp8_recipe.nvfp4() and dtype == torch.float16:
            pytest.skip("FP16 output for NVFP4 not supported")

    sigma = 0.023
    init_method = init_method_normal(sigma)
@@ -407,6 +419,8 @@ def test_sanity_linear(dtype, fp8_recipe, model, skip_wgrad, skip_dgrad, microba
    if fp8_recipe is not None:
        if not is_fp8_supported(config):
            pytest.skip("Model config does not support FP8")
        if fp8_recipe.nvfp4() and dtype == torch.float16:
            pytest.skip("FP16 output for NVFP4 not supported")

    sigma = 0.023
    output_layer_init_method = scaled_init_method_normal(sigma, config.num_layers)
@@ -437,6 +451,8 @@ def test_sanity_linear_with_zero_tokens(dtype, bs, model, fp8_recipe, fp8_model_
    if fp8_recipe is not None:
        if not is_fp8_supported(config):
            pytest.skip("Model config does not support FP8")
        if fp8_recipe.nvfp4() and dtype == torch.float16:
            pytest.skip("FP16 output for NVFP4 not supported")

    use_fp8 = fp8_recipe is not None
    with fp8_model_init(enabled=use_fp8 and fp8_model_params, recipe=fp8_recipe):
@@ -476,6 +492,8 @@ def test_sanity_grouped_linear(
    if fp8_recipe is not None:
        if not is_fp8_supported(config):
            pytest.skip("Model config does not support FP8")
        if fp8_recipe.nvfp4():
            pytest.skip("NVFP4 not supported for grouped linear")

    use_fp8 = fp8_recipe is not None
    with fp8_model_init(enabled=use_fp8 and fp8_model_params, recipe=fp8_recipe):
@@ -526,6 +544,8 @@ def test_sanity_layernorm_mlp(
    if fp8_recipe is not None:
        if not is_fp8_supported(config):
            pytest.skip("Model config does not support FP8")
        if fp8_recipe.nvfp4() and dtype == torch.float16:
            pytest.skip("FP16 output for NVFP4 not supported")

    sigma = 0.023
    init_method = init_method_normal(sigma)
@@ -568,6 +588,8 @@ def test_sanity_gpt(
    if fp8_recipe is not None:
        if not is_fp8_supported(config):
            pytest.skip("Model config does not support FP8")
        if fp8_recipe.nvfp4() and dtype == torch.float16:
            pytest.skip("FP16 output for NVFP4 not supported")

    sigma = 0.023
    init_method = init_method_normal(sigma)
@@ -629,6 +651,8 @@ def test_sanity_bert(dtype, fp8_recipe, model, skip_wgrad, normalization):
            pytest.skip(reason_for_no_fp8)
        if not is_fp8_supported(config):
            pytest.skip("Model config does not support FP8")
        if fp8_recipe.nvfp4() and dtype == torch.float16:
            pytest.skip("FP16 output for NVFP4 not supported")

    sigma = 0.023
    init_method = init_method_normal(sigma)
@@ -683,6 +707,8 @@ def test_sanity_T5(dtype, fp8_recipe, model, skip_wgrad, normalization):
            pytest.skip(reason_for_no_fp8)
        if not is_fp8_supported(config):
            pytest.skip("Model config does not support FP8")
        if fp8_recipe.nvfp4() and dtype == torch.float16:
            pytest.skip("FP16 output for NVFP4 not supported")

    sigma = 0.023
    init_method = init_method_normal(sigma)
@@ -734,6 +760,8 @@ def test_sanity_amp_and_nvfuser(dtype, fp8_recipe, model, skip_wgrad):
    if fp8_recipe is not None:
        if not is_fp8_supported(config):
            pytest.skip("Model config does not support FP8")
        if fp8_recipe.nvfp4() and dtype == torch.float16:
            pytest.skip("FP16 output for NVFP4 not supported")

    sigma = 0.023
    init_method = init_method_normal(sigma)
@@ -764,6 +792,8 @@ def test_sanity_drop_path(dtype, fp8_recipe, model):
    if fp8_recipe is not None:
        if not is_fp8_supported(config):
            pytest.skip("Model config does not support FP8")
        if fp8_recipe.nvfp4() and dtype == torch.float16:
            pytest.skip("FP16 output for NVFP4 not supported")

    sigma = 0.023
    init_method = init_method_normal(sigma)
@@ -798,6 +828,8 @@ def test_sanity_fused_qkv_params(dtype, fp8_recipe, model, skip_wgrad):
    if fp8_recipe is not None:
        if not is_fp8_supported(config):
            pytest.skip("Model config does not support FP8")
        if fp8_recipe.nvfp4() and dtype == torch.float16:
            pytest.skip("FP16 output for NVFP4 not supported")

    sigma = 0.023
    init_method = init_method_normal(sigma)
@@ -832,6 +864,8 @@ def test_sanity_gradient_accumulation_fusion(dtype, fp8_recipe, model, skip_wgra
    if fp8_recipe is not None:
        if not is_fp8_supported(config):
            pytest.skip("Model config does not support FP8")
        if fp8_recipe.nvfp4() and dtype == torch.float16:
            pytest.skip("FP16 output for NVFP4 not supported")

    sigma = 0.023
    init_method = init_method_normal(sigma)
......
@@ -73,6 +73,8 @@ def dtype_tols(dtype: torch.dtype | tex.DType) -> dict[str, float]:
    # Transformer Engine dtypes
    if isinstance(dtype, tex.DType):
        if dtype == tex.DType.kFloat4E2M1:
            return dict(rtol=0.25, atol=0.125)  # epsilon = 0.25
        dtype = {
            tex.DType.kByte: torch.uint8,
            tex.DType.kInt32: torch.int32,
@@ -95,10 +97,25 @@ def dtype_tols(dtype: torch.dtype | tex.DType) -> dict[str, float]:
    if dtype == torch.float8_e4m3fn:
        return dict(rtol=0.125, atol=0.0675)  # epsilon = 0.0625
    if dtype == torch.float8_e5m2:
-        return dict(rtol=0.25, atol=0.125)  # epsilon = 0.152
+        return dict(rtol=0.25, atol=0.125)  # epsilon = 0.125
    raise ValueError(f"Unsupported dtype ({dtype})")


def quantization_tols(name: str) -> dict[str, float]:
    """Estimated numerical error for a quantization scheme"""
    if name in (
        "fp8",
        "fp8_delayed_scaling",
        "fp8_current_scaling",
        "mxfp8",
        "mxfp8_block_scaling",
    ):
        return dtype_tols(tex.DType.kFloat8E4M3)
    if name == "nvfp4":
        return dtype_tols(tex.DType.kFloat4E2M1)
    raise ValueError(f"Unsupported quantization scheme ({name})")


def make_recipe(name: Optional[str]) -> Optional[Recipe]:
    """Make recipe for quantization scheme"""
    if name is None:
@@ -118,6 +135,12 @@ def make_recipe(name: Optional[str]) -> Optional[Recipe]:
    )
    if name == "fp8_block_scaling":
        return transformer_engine.common.recipe.Float8BlockScaling()
    if name == "nvfp4":
        return transformer_engine.common.recipe.NVFP4BlockScaling(
            disable_rht=True,
            disable_stochastic_rounding=True,
            disable_2d_quantization=True,
        )
    raise ValueError(f"Unsupported quantization scheme ({name})")
......
@@ -53,6 +53,28 @@ set(CUTLASS_TOOLS_INCLUDE_DIR
# Python
find_package(Python COMPONENTS Interpreter Development.Module REQUIRED)

# NVIDIA MathDX include directory (from Python package install location)
if(NOT DEFINED MATHDX_INCLUDE_DIR)
  execute_process(
    COMMAND ${Python_EXECUTABLE} -m pip show nvidia-mathdx
    OUTPUT_VARIABLE _PIP_SHOW_MATHDX
    ERROR_VARIABLE _PIP_SHOW_MATHDX_ERR
    RESULT_VARIABLE _PIP_SHOW_MATHDX_RES
    OUTPUT_STRIP_TRAILING_WHITESPACE)
  if(NOT _PIP_SHOW_MATHDX_RES EQUAL 0)
    message(FATAL_ERROR "Failed to query 'nvidia-mathdx' with pip (using ${Python_EXECUTABLE}): ${_PIP_SHOW_MATHDX_ERR}")
  endif()
  string(REGEX MATCH "Location: ([^\n\r]+)" _MATHDX_LOC_MATCH "${_PIP_SHOW_MATHDX}")
  if(NOT _MATHDX_LOC_MATCH)
    message(FATAL_ERROR "Could not parse installation location for 'nvidia-mathdx'. Output was:\n${_PIP_SHOW_MATHDX}")
  endif()
  set(MATHDX_LOCATION "${CMAKE_MATCH_1}")
  set(MATHDX_INCLUDE_DIR "${MATHDX_LOCATION}/nvidia/mathdx/include")
endif()
if(NOT EXISTS "${MATHDX_INCLUDE_DIR}")
  message(FATAL_ERROR "MATHDX include directory not found at ${MATHDX_INCLUDE_DIR}. Set MATHDX_INCLUDE_DIR or ensure 'nvidia-mathdx' is installed for ${Python_EXECUTABLE}.")
endif()
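
# Editorial note (hedged): because of the NOT DEFINED guard above, a custom
# MathDX install can bypass the pip-based detection entirely, e.g.:
#   cmake -DMATHDX_INCLUDE_DIR=/opt/nvidia/mathdx/include ..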
# Configure Transformer Engine library
include_directories(${PROJECT_SOURCE_DIR}/..)
set(transformer_engine_SOURCES)
@@ -73,6 +95,7 @@ list(APPEND transformer_engine_SOURCES
transpose/quantize_transpose_square_blockwise.cu
transpose/quantize_transpose_vector_blockwise.cu
transpose/swap_first_dims.cu
transpose/quantize_transpose_vector_blockwise_fp4.cu
activation/gelu.cu
dropout/dropout.cu
fused_attn/flash_attn.cu
@@ -85,6 +108,7 @@ list(APPEND transformer_engine_SOURCES
fused_attn/fused_attn_fp8.cu
fused_attn/fused_attn.cpp
fused_attn/utils.cu
gemm/config.cpp
gemm/cublaslt_gemm.cu
gemm/cutlass_grouped_gemm.cu
normalization/common.cpp
@@ -113,6 +137,9 @@ list(APPEND transformer_engine_SOURCES
recipe/current_scaling.cu
recipe/delayed_scaling.cu
recipe/fp8_block_scaling.cu
recipe/nvfp4.cu
hadamard_transform/hadamard_transform.cu
hadamard_transform/hadamard_transform_cast_fusion.cu
comm_gemm_overlap/userbuffers/ipcsocket.cc
comm_gemm_overlap/userbuffers/userbuffers-host.cpp
comm_gemm_overlap/userbuffers/userbuffers.cu
@@ -145,6 +172,7 @@ target_link_libraries(transformer_engine PUBLIC
target_include_directories(transformer_engine PRIVATE
${CMAKE_CUDA_TOOLKIT_INCLUDE_DIRECTORIES})
target_include_directories(transformer_engine PRIVATE ${MATHDX_INCLUDE_DIR})
target_include_directories(transformer_engine SYSTEM PRIVATE
${CMAKE_CUDA_TOOLKIT_INCLUDE_DIRECTORIES}/cccl)
target_include_directories(transformer_engine PRIVATE "${CUDNN_FRONTEND_INCLUDE_DIR}")
......
@@ -39,6 +39,10 @@ cudaDataType_t get_cuda_dtype(const transformer_engine::DType t) {
      return CUDA_R_8F_E4M3;
    case DType::kFloat8E5M2:
      return CUDA_R_8F_E5M2;
#if CUDA_VERSION >= 12080
    case DType::kFloat4E2M1:
      return CUDA_R_4F_E2M1;
#endif
    default:
      NVTE_ERROR("Invalid type");
  }
@@ -160,7 +164,9 @@ CUtensorMapDataType get_CUtensorMapDataType(DType dtype) {
void create_2D_tensor_map(CUtensorMap &tensorMap, const SimpleTensor &tensor,
                          const uint64_t globalY, const uint64_t globalX, const uint32_t shmemY,
                          const uint32_t shmemX, const uint32_t stride_elems,
-                         const uint32_t offset_elems, const size_t type_num_bits) {
+                         const uint32_t offset_elems, const size_t type_num_bits,
+                         const CUtensorMapSwizzle swizzle) {
+  cuda_driver::ensure_context_exists();
  // Get a function pointer to the cuTensorMapEncodeTiled driver API
  // Note: PFN_cuTensorMapEncodeTiled is not defined in cuda13
  static PFN_cuTensorMapEncodeTiled_v12000 cuDriverTensorMapEncodeTiled = []() {
@@ -169,6 +175,8 @@ void create_2D_tensor_map(CUtensorMap &tensorMap, const SimpleTensor &tensor,
  }();

  // rank is the number of dimensions of the array
  constexpr uint32_t rank = 2;
+  // Dimension for the packed data types must reflect the number of individual values.
  uint64_t size[rank] = {globalX, globalY};
  // The stride is the number of bytes to traverse from the first element of one row to the next
@@ -207,7 +215,7 @@ void create_2D_tensor_map(CUtensorMap &tensorMap, const SimpleTensor &tensor,
      CUtensorMapInterleave::CU_TENSOR_MAP_INTERLEAVE_NONE,
      // Swizzling can be used to avoid shared memory bank conflicts.
-      CUtensorMapSwizzle::CU_TENSOR_MAP_SWIZZLE_NONE,
+      swizzle,
      // L2 Promotion can be used to widen the effect of a cache-policy to a wider
      // set of L2 cache lines.
......
@@ -122,6 +122,10 @@ void nvte_fp8_block_scaling_partial_cast(const NVTETensor inp, NVTETensor out,
                                         size_t start_offset, size_t block_len,
                                         const NVTEDType out_dtype, cudaStream_t stream);

void nvte_nvfp4_compute_per_tensor_scale(const NVTETensor inpA, const bool use_rowwise_amax_A,
                                         const NVTETensor inpB, const bool use_rowwise_amax_B,
                                         float alpha_in, NVTETensor alpha_out, cudaStream_t stream);
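
/* Editorial note (assumption, not from this diff): given the reference
 * dequantization x ~= q * s_block * (amax / (6 * 448)) used in the tests
 * above, the per-tensor GEMM scale plausibly combines both operands' global
 * factors:
 *   alpha_out = alpha_in * (amax_A / (6 * 448)) * (amax_B / (6 * 448))
 * The actual kernel may differ. */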
#ifdef __cplusplus
} // extern "C"
#endif
......