Unverified commit 3f5b4754, authored by Kirthi Shankar Sivamani and committed by GitHub

[Core][PyTorch] NVFP4 recipe (#2177)



* Add NVFP4 recipe
Signed-off-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com>
Co-authored-by: Frank Sun <frsun@nvidia.com>
Co-authored-by: Oleg Goncharov <ogoncharov@nvidia.com>
Co-authored-by: Zhongbo Zhu <zhongboz@nvidia.com>
Co-authored-by: Evgeny Tsykunov <etsykunov@nvidia.com>
Co-authored-by: Tim Moon <tmoon@nvidia.com>
Co-authored-by: Teddy Do <tdophung@nvidia.com>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci



* Add MathDx dependency to GitHub builds
Signed-off-by: Tim Moon <tmoon@nvidia.com>

* Suggestions from GitHub Copilot
Signed-off-by: Tim Moon <tmoon@nvidia.com>

* Move 2x shape logic from core to PyTorch
Signed-off-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com>

* Fix compilation errors with CUDA 12.1
Signed-off-by: Tim Moon <tmoon@nvidia.com>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* SM 70 is not supported in CUDA 13
Signed-off-by: Tim Moon <4406448+timmoon10@users.noreply.github.com>

* Typo
Signed-off-by: Tim Moon <4406448+timmoon10@users.noreply.github.com>

* Revert "Move 2x shape logic from core to PyTorch"

This reverts commit f8b2a2d0111d9af690b43bb98ae448d9a430a185.
Signed-off-by: Tim Moon <tmoon@nvidia.com>

* Added dequantize kernel for FP4
Signed-off-by: Przemek Tredak <ptredak@nvidia.com>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci



* Fix linter warning
Signed-off-by: Tim Moon <tmoon@nvidia.com>

* Add NVFP4 support with fusible ops

Use logical tensor dims for PyTorch NVFP4 tensors. Temporarily add an unfused dequantize implementation. Fix a bug where the NVFP4 recipe was not configurable.
Signed-off-by: Tim Moon <tmoon@nvidia.com>

* Fix logic for 2x shapes and move to PyTorch
Signed-off-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com>

* Fix CG test model config
Signed-off-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com>

* Debug NVFP4 tensor size function
Signed-off-by: Tim Moon <tmoon@nvidia.com>

* Proper handling of the RNG state
Signed-off-by: Przemek Tredak <ptredak@nvidia.com>

* Test SR properly
Signed-off-by: Przemek Tredak <ptredak@nvidia.com>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci



* Fix workspace size for GEMM heuristic.
Signed-off-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com>

* Fix compile error in C++ NVFP4 test

Fix some numeric errors when blocks are all zero.
Signed-off-by: Tim Moon <tmoon@nvidia.com>

* Fix distributed test problem shape
Signed-off-by: zhongboz <zhongboz@nvidia.com>

* Properly assert dims for low-precision AG TP
Signed-off-by: zhongboz <zhongboz@nvidia.com>

* Clean up duplicated code in nvfp4_utils.cuh
Signed-off-by: zhongboz <zhongboz@nvidia.com>

* Lint
Signed-off-by: zhongboz <zhongboz@nvidia.com>

* pylint: disable=unused-argument
Signed-off-by: zhongboz <zhongboz@nvidia.com>

* `nvte_cublas_gemm_v2` to take alpha pointer (#12)

* Make nvte_cublas_gemm_v2 take alpha/beta pointers
Signed-off-by: Phuong Nguyen <phuonguyen@nvidia.com>

* Users are expected to pass a valid C_tensor
Signed-off-by: Phuong Nguyen <phuonguyen@nvidia.com>

* Fix typos
Signed-off-by: Phuong Nguyen <phuonguyen@nvidia.com>

* Change API to take const float* alpha
Signed-off-by: Phuong Nguyen <phuonguyen@nvidia.com>

* Minor tweaks

Support arbitrary beta scales. Increase workspace to be aligned to 128 bytes.
Signed-off-by: Tim Moon <tmoon@nvidia.com>

* Debug IMA with alpha pointer
Signed-off-by: Tim Moon <tmoon@nvidia.com>

---------
Signed-off-by: Phuong Nguyen <phuonguyen@nvidia.com>
Signed-off-by: Tim Moon <tmoon@nvidia.com>
Co-authored-by: Tim Moon <tmoon@nvidia.com>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci



* Support fused amax kernels with NVFP4 quantization
Signed-off-by: Tim Moon <tmoon@nvidia.com>

* Disable fused amax with cuDNN LayerNorm kernel
Signed-off-by: Tim Moon <tmoon@nvidia.com>

* Add NVFP4 cases to distributed tests for TE ops
Signed-off-by: Tim Moon <tmoon@nvidia.com>

* Change assert to NVTE_CHECK in the Hadamard cast fusion
Signed-off-by: Przemek Tredak <ptredak@nvidia.com>

* Fix compile error
Signed-off-by: Tim Moon <tmoon@nvidia.com>

* Use global thread IDs for Philox subsequences
Signed-off-by: Tim Moon <tmoon@nvidia.com>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci



* Add shape checks for NVFP4 cast kernels
Signed-off-by: Tim Moon <tmoon@nvidia.com>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci



* Do not fuse amax if cuDNN normalization is forced by envvar
Signed-off-by: Przemek Tredak <ptredak@nvidia.com>

---------
Signed-off-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com>
Signed-off-by: Tim Moon <tmoon@nvidia.com>
Signed-off-by: Tim Moon <4406448+timmoon10@users.noreply.github.com>
Signed-off-by: Przemek Tredak <ptredak@nvidia.com>
Signed-off-by: zhongboz <zhongboz@nvidia.com>
Signed-off-by: Phuong Nguyen <phuonguyen@nvidia.com>
Co-authored-by: Frank Sun <frsun@nvidia.com>
Co-authored-by: Oleg Goncharov <ogoncharov@nvidia.com>
Co-authored-by: Zhongbo Zhu <zhongboz@nvidia.com>
Co-authored-by: Evgeny Tsykunov <etsykunov@nvidia.com>
Co-authored-by: Tim Moon <tmoon@nvidia.com>
Co-authored-by: Teddy Do <tdophung@nvidia.com>
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
Co-authored-by: Tim Moon <4406448+timmoon10@users.noreply.github.com>
Co-authored-by: Przemek Tredak <ptredak@nvidia.com>
Co-authored-by: Phuong Nguyen <phuonguyen@nvidia.com>
parent dfeef1a2
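
For reference, a minimal sketch of enabling the new recipe (assuming NVFP4-capable hardware; it reuses the existing fp8_autocast entry point together with the NVFP4BlockScaling recipe added in this PR):

    import torch
    import transformer_engine.pytorch as te
    from transformer_engine.common.recipe import NVFP4BlockScaling

    # Default quantization params; RHT and 2D quantization are configurable per tensor.
    nvfp4_recipe = NVFP4BlockScaling()
    layer = te.Linear(1024, 1024, params_dtype=torch.bfloat16).cuda()
    x = torch.randn(128, 1024, dtype=torch.bfloat16, device="cuda")
    with te.fp8_autocast(enabled=True, fp8_recipe=nvfp4_recipe):
        y = layer(x)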
# Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# See LICENSE for license information.
# NOTE: This file depends on test_nvfp4_quantize_exact.py passing. It is kept
# separate to make sure all the base functionalities work as expected first;
# otherwise the reference implementation would get messy.
# Due to the structure of NVFP4Quantizer, we need to test the RHT functionality
# together with the quantization functionality.
from typing import Tuple
import math
import transformer_engine as te
import transformer_engine_torch as tex
from transformer_engine.pytorch.fp8 import FP8GlobalStateManager
from transformer_engine.common.recipe import NVFP4BlockScaling
from transformer_engine.pytorch.constants import TE_DType
from transformer_engine.pytorch.tensor.nvfp4_tensor import (
NVFP4Quantizer,
)
from transformer_engine.pytorch.experimental.quantization_microblock_ref import NVFP4QuantizerRef
from transformer_engine.pytorch.experimental import utils
from transformer_engine.pytorch.fp8 import fp8_autocast, get_fp4_te_dtype
import pytest
import torch
recipe_available, reason_for_no_recipe = FP8GlobalStateManager.is_nvfp4_available()
def unpack_fp4(x: torch.Tensor) -> torch.Tensor:
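    # Each uint8 packs two FP4 codes: the low nibble is the first value and the
    # high nibble the second, so duplicate each byte and mask/shift per column.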
repeated = x.repeat_interleave(2, dim=1)
repeated[:, 0::2] &= 0x0F
repeated[:, 1::2] >>= 4
return repeated
def check_quantization_nvfp4_versus_reference(
x_dtype: torch.dtype,
M: int,
N: int,
contiguous: bool,
return_transpose: bool,
use_cpp_allocator: bool,
swizzled_scale: bool = False,
hadamard_dimension: int = 16,
with_rht: bool = True,
with_post_rht_amax: bool = True,
with_random_sign_mask: bool = True,
) -> None:
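    """Quantize with the CUDA NVFP4Quantizer and check bit-exact agreement with
    the Python reference (packed data, block scales, and amax, rowwise and
    optionally columnwise)."""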
assert with_rht and with_post_rht_amax, "RHT and post-RHT amax reduction must be enabled."
te_dtype = tex.DType.kFloat4E2M1
# Setup device and random seed
device = "cuda"
seed = 0
torch.manual_seed(seed)
torch.cuda.manual_seed(seed)
# Input
x = torch.randn((M, N), dtype=x_dtype, device=device)
x = x.transpose(0, 1) if not contiguous else x
# Quantize
nvfp4_quantizer = NVFP4Quantizer(
fp4_dtype=te_dtype,
rowwise=True,
columnwise=return_transpose,
with_amax_reduction=False,
amax_reduction_group=None,
with_rht=with_rht,
with_post_rht_amax=with_post_rht_amax,
with_random_sign_mask=with_random_sign_mask,
)
if use_cpp_allocator:
x_nvfp4_sut = nvfp4_quantizer(x)
else:
x_nvfp4_sut = nvfp4_quantizer.make_empty(
x.shape, dtype=x_dtype, device=device, requires_grad=False
)
x_nvfp4_sut = nvfp4_quantizer.update_quantized(x, x_nvfp4_sut)
# Extract data from NVFP4Tensor
assert x_nvfp4_sut._rowwise_data is not None
qx: torch.Tensor = x_nvfp4_sut._rowwise_data.view(dtype=torch.uint8)
assert x_nvfp4_sut._rowwise_scale_inv is not None
sx: torch.Tensor = x_nvfp4_sut._rowwise_scale_inv
qx_t = (
x_nvfp4_sut._columnwise_data.view(dtype=torch.uint8)
if x_nvfp4_sut._columnwise_data is not None
else None
)
sx_t = x_nvfp4_sut._columnwise_scale_inv
amax_rowwise = x_nvfp4_sut._amax_rowwise
amax_colwise = x_nvfp4_sut._amax_columnwise
qx = unpack_fp4(qx)
qx_t = unpack_fp4(qx_t) if qx_t is not None else None
# Reference quantization using NVFP4QuantizerRef with built-in RHT
ref_quantizer = NVFP4QuantizerRef(
dtype=utils.Fp4Formats.E2M1,
rowwise=True,
columnwise=return_transpose,
pow_2_scales=False,
eps=0.0,
quant_tile_shape=(1, 16),
with_rht=with_rht,
with_random_sign_mask=with_random_sign_mask,
)
x_nvfp4_ref = ref_quantizer.quantize(x)
# Extract data from RefNVFP4Tensor
qx_ref = (
unpack_fp4(x_nvfp4_ref.data.view(dtype=torch.uint8))
if x_nvfp4_ref.data is not None
else None
)
sx_ref = x_nvfp4_ref.scale.view(dtype=torch.uint8) if x_nvfp4_ref.scale is not None else None
ref_amax_rowwise = x_nvfp4_ref.global_amax_row
if return_transpose:
assert x_nvfp4_ref.data_t is not None
assert x_nvfp4_ref.scale_t is not None
qx_t_ref = unpack_fp4(x_nvfp4_ref.data_t.view(dtype=torch.uint8))
sx_t_ref = x_nvfp4_ref.scale_t.view(dtype=torch.uint8)
# Compute transpose amax using the same reference quantizer
x_t_for_amax = (
ref_quantizer._apply_rht(x.t().contiguous()) if with_rht else x.t().contiguous()
)
ref_amax_colwise_t = torch.max(torch.abs(x_t_for_amax)).to(torch.float32).view(1)
else:
qx_t_ref = None
sx_t_ref = None
ref_amax_colwise_t = None
torch.testing.assert_close(amax_rowwise, ref_amax_rowwise, atol=0.0, rtol=0.0)
torch.testing.assert_close(qx, qx_ref, atol=0.0, rtol=0.0)
# Compare only the valid portion of scale tensors (reference may not have padding)
ref_sx_shape = sx_ref.shape
sx_valid = sx[: ref_sx_shape[0], : ref_sx_shape[1]]
torch.testing.assert_close(sx_valid, sx_ref, atol=0.0, rtol=0.0)
if return_transpose:
torch.testing.assert_close(amax_colwise, ref_amax_colwise_t, atol=0.0, rtol=0.0)
torch.testing.assert_close(qx_t, qx_t_ref, atol=0.0, rtol=0.0)
# Compare only the valid portion of transpose scale tensors
ref_sx_t_shape = sx_t_ref.shape
sx_t_valid = sx_t[: ref_sx_t_shape[0], : ref_sx_t_shape[1]]
torch.testing.assert_close(sx_t_valid, sx_t_ref, atol=0.0, rtol=0.0)
@pytest.mark.skipif(not recipe_available, reason=reason_for_no_recipe)
@pytest.mark.parametrize(
"M, N",
[
# full tile cases
(128, 128),
(256, 256),
(256, 1024),
(1024, 256),
# Padding required cases
(256, 272),
(304, 304),
(320, 256),
# Some larger tiles
(2048, 2048),
(1024, 2048),
(2048, 1024),
# Real shapes,
(8192, 5120),
(8192, 10240),
(8192, 2560),
(8192, 11328),
(8192, 512),
(8192, 3584),
(5120, 8192),
(10240, 8192),
(2560, 8192),
(11328, 8192),
(512, 8192),
(3584, 8192),
(4096, 16384),
(14336, 16384),
],
)
@pytest.mark.parametrize("x_dtype", [torch.bfloat16], ids=str)
@pytest.mark.parametrize(
"return_transpose", [True, False], ids=["quantize_transpose", "skip_transpose"]
)
@pytest.mark.parametrize(
"use_cpp_allocator", [True, False], ids=["cpp_allocator", "python_allocator"]
)
@pytest.mark.parametrize(
"with_random_sign_mask", [True, False], ids=["with_random_sign_mask", "no_random_sign_mask"]
)
def test_rht_with_quantization_block_tiling_versus_reference(
x_dtype: torch.dtype,
M: int,
N: int,
return_transpose: bool,
use_cpp_allocator: bool,
with_random_sign_mask: bool,
) -> None:
check_quantization_nvfp4_versus_reference(
x_dtype=x_dtype,
M=M,
N=N,
contiguous=True,
return_transpose=return_transpose,
use_cpp_allocator=use_cpp_allocator,
with_random_sign_mask=with_random_sign_mask,
)
@pytest.mark.skipif(not recipe_available, reason=reason_for_no_recipe)
@pytest.mark.parametrize(
"M, N",
[
(32, 128),
],
)
@pytest.mark.parametrize("x_dtype", [torch.bfloat16], ids=str)
@pytest.mark.parametrize(
"return_transpose", [True, False], ids=["quantize_transpose", "skip_transpose"]
)
@pytest.mark.parametrize(
"use_cpp_allocator", [True, False], ids=["cpp_allocator", "python_allocator"]
)
@pytest.mark.parametrize(
"with_random_sign_mask", [True, False], ids=["with_random_sign_mask", "no_random_sign_mask"]
)
def test_nvfp4_quantization_noncontiguous_inputs(
x_dtype: torch.dtype,
M: int,
N: int,
return_transpose: bool,
use_cpp_allocator: bool,
with_random_sign_mask: bool,
):
check_quantization_nvfp4_versus_reference(
x_dtype=x_dtype,
M=M,
N=N,
contiguous=False,
return_transpose=return_transpose,
use_cpp_allocator=use_cpp_allocator,
with_random_sign_mask=with_random_sign_mask,
)
# Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# See LICENSE for license information.
import pytest
import torch
from transformer_engine.pytorch.fp8 import FP8GlobalStateManager
from transformer_engine.pytorch.tensor.nvfp4_tensor import NVFP4Quantizer
recipe_available, reason_for_no_recipe = FP8GlobalStateManager.is_nvfp4_available()
seed = 12345
torch.manual_seed(seed)
torch.cuda.manual_seed(seed)
def unpack_fp4(x: torch.Tensor) -> torch.Tensor:
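    # Each uint8 packs two FP4 codes: the low nibble is the first value and the
    # high nibble the second, so duplicate each byte and mask/shift per column.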
repeated = x.repeat_interleave(2, dim=1)
repeated[:, 0::2] &= 0x0F
repeated[:, 1::2] >>= 4
return repeated
_FP4_LUT = torch.tensor(
[
0.0, # 0: 0000 - zero
0.5, # 1: 0001 - smallest positive normal
1.0, # 2: 0010
1.5, # 3: 0011
2.0, # 4: 0100
3.0, # 5: 0101
4.0, # 6: 0110
6.0, # 7: 0111 - largest positive normal
-0.0, # 8: 1000 - negative zero
-0.5, # 9: 1001 - smallest negative normal
-1.0, # 10: 1010
-1.5, # 11: 1011
-2.0, # 12: 1100
-3.0, # 13: 1101
-4.0, # 14: 1110
-6.0, # 15: 1111 - largest negative normal
],
dtype=torch.float32,
)
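# Example: a packed byte 0x2A decodes (low nibble first) to codes 10 and 2,
# i.e. the value pair (-1.0, 1.0).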
def fp4_to_fp32(fp4: torch.Tensor) -> torch.Tensor:
# Convert FP4 indices to their corresponding floating point values
# Each index (0-15) represents a 4-bit FP4 value in E2M1 format
# Values based on the FP4 E2M1 specification
fp4_lut = _FP4_LUT.to(fp4.device)
return fp4_lut[fp4.to(torch.long)]
def dequantize_fp4(qx: torch.Tensor, sx: torch.Tensor, amax: torch.Tensor) -> torch.Tensor:
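    # Two-level decode: each 16-element block carries an FP8 (E4M3) scale, and
    # the whole tensor carries a global scale amax / (6 * 448), where 6 is the
    # max E2M1 magnitude and 448 is the max E4M3 magnitude.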
sf = sx.repeat_interleave(16, dim=1).view(torch.float8_e4m3fn).to(torch.float32)
dqx = fp4_to_fp32(unpack_fp4(qx))
sf = sf[: dqx.shape[0], : dqx.shape[1]]
dequant = dqx * sf * (amax / (6.0 * 448))
return dequant
def RHT(x: torch.Tensor) -> torch.Tensor:
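    """Reference Hadamard transform: multiply each contiguous 16-element block
    along the last dim by a signed 16x16 Hadamard matrix scaled by 1/sqrt(16)."""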
def get_wgrad_sign_vector() -> torch.Tensor:
"""Hard-coded signs for Hadamard transform"""
return torch.tensor(
[
1.0,
1.0,
1.0,
-1.0,
1.0,
-1.0,
-1.0,
-1.0,
-1.0,
-1.0,
-1.0,
1.0,
-1.0,
1.0,
-1.0,
-1.0,
],
dtype=torch.float32,
)
def _build_hadamard_matrix(
size: int, device: torch.device, dtype: torch.dtype, with_random_sign_mask: bool = True
) -> torch.Tensor:
"""Construct a Hadamard matrix of given power-of-two size with entries +-1.
Uses Sylvester construction to avoid SciPy dependency.
"""
assert (size & (size - 1)) == 0, "Hadamard size must be a power of two"
h = torch.ones((1, 1), device=device, dtype=torch.float32)
while h.shape[0] < size:
h = torch.cat(
[
torch.cat([h, h], dim=1),
torch.cat([h, -h], dim=1),
],
dim=0,
)
if with_random_sign_mask:
sign_mat = get_wgrad_sign_vector().to(device) * torch.eye(
size, device=device, dtype=torch.float32
)
h = sign_mat @ h
return h.to(dtype)
rht_dim = 16
# Build H and scale
H = _build_hadamard_matrix(rht_dim, x.device, x.dtype)
scale = 1.0 / float(rht_dim) ** 0.5
# Perform blockwise transform along the last dimension
original_shape = x.shape
x_mat = x.contiguous().view(-1, rht_dim)
    # Note: _build_hadamard_matrix applies the hard-coded wgrad sign vector by
    # default (with_random_sign_mask=True), so H already includes the sign mask.
transform = H * scale
out = x_mat @ transform
return out.view(original_shape)
def quantize_fp4(
    x: torch.Tensor, use_stochastic_rounding: bool, use_2D: bool, use_RHT: bool
) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
nvfp4_quantizer = NVFP4Quantizer(
rowwise=True,
columnwise=True,
with_amax_reduction=False,
amax_reduction_group=None,
with_rht=use_RHT,
with_post_rht_amax=True,
stochastic_rounding=use_stochastic_rounding,
with_2d_quantization=use_2D,
)
x_nvfp4_sut = nvfp4_quantizer(x)
# Extract data from NVFP4Tensor
assert x_nvfp4_sut._rowwise_data is not None
qx: torch.Tensor = x_nvfp4_sut._rowwise_data.view(dtype=torch.uint8)
assert x_nvfp4_sut._rowwise_scale_inv is not None
sx: torch.Tensor = x_nvfp4_sut._rowwise_scale_inv
assert x_nvfp4_sut._columnwise_data is not None
qx_t: torch.Tensor = x_nvfp4_sut._columnwise_data.view(dtype=torch.uint8)
assert x_nvfp4_sut._columnwise_scale_inv is not None
sx_t: torch.Tensor = x_nvfp4_sut._columnwise_scale_inv
return qx, sx, qx_t, sx_t
def check_quantization_nvfp4_versus_reference(
x_dtype: torch.dtype, M: int, N: int, use_2D: bool, use_RHT: bool
) -> None:
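    # Compare a single round-to-nearest (RN) pass against the average of many
    # stochastic-rounding (SR) passes over the same input.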
device = "cuda"
torch.manual_seed(seed)
n_iters = 50
x = torch.randn((M, N), dtype=x_dtype, device=device) * 2 - 1
y = x.t().contiguous()
if use_RHT:
y = RHT(y)
amax = torch.max(torch.abs(x)).float()
q_rn, s_rn, q_t_rn, s_t_rn = quantize_fp4(
x, use_stochastic_rounding=False, use_2D=use_2D, use_RHT=use_RHT
)
dq_rn = dequantize_fp4(q_rn, s_rn, amax)
dq_t_rn = dequantize_fp4(q_t_rn, s_t_rn, amax)
error_rn = (dq_rn - x).float()
me_rn = torch.sqrt((error_rn * error_rn).mean())
error_t_rn = (dq_t_rn - y).float()
me_t_rn = torch.sqrt((error_t_rn * error_t_rn).mean())
sr_result = torch.zeros_like(x).float()
sr_t_result = torch.zeros_like(x).float().t().contiguous()
for i in range(n_iters):
q_sr, s_sr, q_t_sr, s_t_sr = quantize_fp4(
x, use_stochastic_rounding=True, use_2D=use_2D, use_RHT=use_RHT
)
dq_sr = dequantize_fp4(q_sr, s_sr, amax)
dq_t_sr = dequantize_fp4(q_t_sr, s_t_sr, amax)
sr_result += dq_sr.float()
sr_t_result += dq_t_sr.float()
# sr_result_tmp = sr_result / (i + 1)
# error_sr = (sr_result_tmp - x).float()
# me_sr = torch.sqrt((error_sr * error_sr).mean())
# sr_t_result_tmp = sr_t_result / (i + 1)
# error_t_sr = (sr_t_result_tmp - y).float()
# me_t_sr = torch.sqrt((error_t_sr * error_t_sr).mean())
# print(f"Iteration {i}: RMSE SR: {me_sr:.3e} | RMSE RN: {me_rn:.3e}")
# print(f"Iteration {i}: RMSE SR_t: {me_t_sr:.3e} | RMSE RN_t: {me_t_rn:.3e}")
# Get the mean result of the stochastic rounding
# It should be more accurate than the RN result
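    # (Stochastic rounding is unbiased, so the mean over iterations converges
    # toward the input and should beat the deterministic RN error.)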
sr_result /= n_iters
error_sr = (sr_result - x).float()
me_sr = torch.sqrt((error_sr * error_sr).mean())
sr_t_result /= n_iters
error_t_sr = (sr_t_result - y).float()
me_t_sr = torch.sqrt((error_t_sr * error_t_sr).mean())
print(f"RMSE SR: {me_sr:.3e} | RMSE RN: {me_rn:.3e}")
print(f"RMSE SR_t: {me_t_sr:.3e} | RMSE RN_t: {me_t_rn:.3e}")
    assert me_sr < me_rn, "Stochastic rounding failed: error larger than round-to-nearest."
    assert me_t_sr < me_t_rn, "Stochastic rounding failed: error larger than round-to-nearest."
@pytest.mark.skipif(not recipe_available, reason=reason_for_no_recipe)
@pytest.mark.parametrize(
"M, N",
[
(8192, 8192),
(8192, 8256), # to test the nonfused RHT path
],
)
@pytest.mark.parametrize("x_dtype", [torch.float32, torch.bfloat16], ids=str)
@pytest.mark.parametrize("use_2D", [False, True], ids=str)
@pytest.mark.parametrize("use_RHT", [False, True], ids=str)
def test_quantization_block_tiling_versus_reference(
x_dtype: torch.dtype,
use_2D: bool,
use_RHT: bool,
M: int,
N: int,
) -> None:
if x_dtype == torch.float32 and use_RHT:
pytest.skip("RHT is only supported with bfloat16")
check_quantization_nvfp4_versus_reference(
x_dtype=x_dtype,
use_2D=use_2D,
use_RHT=use_RHT,
M=M,
N=N,
)
@@ -32,12 +32,59 @@ mxfp8_available, _ = FP8GlobalStateManager.is_mxfp8_available()
 reset_rng_states()

 model_configs = {
-    "small": ModelConfig(32, 2, 2, 32),
+    "small": ModelConfig(2, 32, 2, 32),
 }


+def nvfp4_vanilla():
+    nvfp4_recipe = recipe.NVFP4BlockScaling()
+    nvfp4_recipe.fp4_quant_fwd_inp = recipe.QParams()
+    nvfp4_recipe.fp4_quant_fwd_weight = recipe.QParams()
+    nvfp4_recipe.fp4_quant_bwd_grad = recipe.QParams()
+    return nvfp4_recipe
+
+
+def nvfp4_rht_and_2d_quantization():
+    nvfp4_recipe = recipe.NVFP4BlockScaling()
+    nvfp4_recipe.fp4_quant_fwd_inp = recipe.QParams(
+        random_hadamard_transform=True, fp4_2d_quantization=False
+    )
+    nvfp4_recipe.fp4_quant_fwd_weight = recipe.QParams(
+        random_hadamard_transform=False, fp4_2d_quantization=True
+    )
+    nvfp4_recipe.fp4_quant_bwd_grad = recipe.QParams(
+        random_hadamard_transform=True, fp4_2d_quantization=False
+    )
+    return nvfp4_recipe
+
+
+def check_rht_usage(recipe: recipe.Recipe) -> bool:
+    # If RHT is used anywhere (fp4_quant_fwd_inp, fp4_quant_fwd_weight,
+    # fp4_quant_bwd_grad), only bf16 inputs are supported.
+    if recipe.nvfp4():
+        if (
+            recipe.fp4_quant_fwd_inp.random_hadamard_transform
+            or recipe.fp4_quant_fwd_weight.random_hadamard_transform
+            or recipe.fp4_quant_bwd_grad.random_hadamard_transform
+        ):
+            return True
+    return False
+
+
+def get_nvfp4_inp_supported_dtypes(recipe: recipe.Recipe, dtype: torch.dtype) -> list[torch.dtype]:
+    supported_input_dtypes = []
+    if recipe.nvfp4():
+        supported_input_dtypes.append(torch.bfloat16)
+        # If RHT is not used, fp32 inputs are supported as well.
+        if not check_rht_usage(recipe):
+            supported_input_dtypes.append(torch.float32)
+    return supported_input_dtypes
+
+
 fp8_recipes = []
 if mxfp8_available:
     fp8_recipes.append(recipe.MXFP8BlockScaling())
+    fp8_recipes.append(nvfp4_rht_and_2d_quantization())
 if fp8_block_scaling_available:
     fp8_recipes.append(recipe.Float8BlockScaling())
 if fp8_available:
@@ -278,7 +325,7 @@ def _test_cuda_graphs(
 @pytest.mark.parametrize("module", _test_cuda_graphs_modules)
 @pytest.mark.parametrize("dtype", dtypes)
 @pytest.mark.parametrize("fp8_params", (False, True))
-@pytest.mark.parametrize("fp8_recipe", fp8_recipes + [None])
+@pytest.mark.parametrize("fp8_recipe", fp8_recipes + [None], ids=lambda r: type(r).__name__)
 def test_make_graphed_callables(
     *,
     module: str,
@@ -295,8 +342,18 @@ def test_make_graphed_callables(
         pytest.skip("FP8 needed for FP8 parameters.")
     if fp8_weight_caching and not fp8:
         pytest.skip("FP8 needed for FP8 parameters.")
-    if fp8 and fp8_recipe.float8_block_scaling() and module == "linear_op":
-        pytest.skip("Module not yet supported for float8_block_scaling with CUDA graphs")
+    if fp8 and (fp8_recipe.float8_block_scaling() or fp8_recipe.nvfp4()) and module == "linear_op":
+        pytest.skip(
+            f"Module not yet supported for {fp8_recipe.__class__.__name__} with CUDA graphs"
+        )
+    if fp8 and fp8_recipe.nvfp4():
+        if dtype not in get_nvfp4_inp_supported_dtypes(fp8_recipe, dtype):
+            pytest.skip(
+                f"Input dtype {dtype} not supported for NVFP4 Recipe"
+                f" {fp8_recipe.__class__.__name__}"
+            )
+        if fp8_params:
+            pytest.skip("NVFP4 params not supported")

     # Run model with different CUDA graph settings.
     model_config = model_configs[model_config]
@@ -334,17 +391,19 @@ _test_make_graphed_callables_with_fp8_weight_caching_modules = [
     "module",
     _test_make_graphed_callables_with_fp8_weight_caching_modules,
 )
+@pytest.mark.parametrize("dtype", dtypes)
 @pytest.mark.parametrize("fp8_params", (False, True))
-@pytest.mark.parametrize("fp8_recipe", fp8_recipes)
+@pytest.mark.parametrize("fp8_recipe", fp8_recipes, ids=lambda r: type(r).__name__)
 def test_make_graphed_callables_with_fp8_weight_caching(
     *,
     module: str,
+    dtype: torch.dtype,
     fp8_params: bool,
     fp8_recipe: recipe.Recipe,
 ) -> None:
     test_make_graphed_callables(
         module=module,
-        dtype=torch.float32,
+        dtype=dtype,
         fp8_params=fp8_params,
         fp8_recipe=fp8_recipe,
         fp8_weight_caching=True,
...
@@ -10,7 +10,6 @@ import pytest
 import transformer_engine.pytorch as te
 import transformer_engine_torch as tex
-import transformer_engine_torch as tex
 from transformer_engine.pytorch.fp8 import FP8GlobalStateManager
 from transformer_engine.common.recipe import Float8CurrentScaling
 from transformer_engine.pytorch.fp8 import fp8_autocast, get_fp8_torch_dtype
@@ -273,6 +272,14 @@ class TestFP8RecipeLinearBase:
             if bgrad_list is not None and bgrad is not None:
                 bgrad_list.append(bgrad.detach().clone())

+        # Stack the results
+        return (
+            torch.stack(y_q_list),
+            torch.stack(dgrad_list),
+            torch.stack(wgrad_list),
+            torch.stack(bgrad_list) if bgrad_list is not None else None,
+        )
+
     @classmethod
     def run_linear(
         cls,
...
@@ -35,15 +35,17 @@ from transformer_engine.pytorch.tensor.float8_tensor import (
     Float8Quantizer,
 )
 from transformer_engine.pytorch.tensor.mxfp8_tensor import MXFP8Tensor, MXFP8Quantizer
+from transformer_engine.pytorch.tensor.nvfp4_tensor import NVFP4Quantizer
 from transformer_engine.pytorch.utils import is_bf16_compatible
 import transformer_engine_torch as tex

 # Import utility functions
-from utils import dtype_tols, make_recipe, reset_rng_states
+from utils import dtype_tols, make_recipe, quantization_tols, reset_rng_states

-# Check if FP8 is supported
+# Check for supported quantization schemes
 fp8_available, reason_for_no_fp8 = FP8GlobalStateManager.is_fp8_available()
 mxfp8_available, reason_for_no_mxfp8 = FP8GlobalStateManager.is_mxfp8_available()
+nvfp4_available, reason_for_no_nvfp4 = FP8GlobalStateManager.is_nvfp4_available()

 # Supported data types
 _dtypes: list[torch.dtype] = [torch.float32, torch.float16]
@@ -59,6 +61,8 @@ if fp8_available:
     _quantization_list.extend(("fp8_delayed_scaling", "fp8_current_scaling"))
 if mxfp8_available:
     _quantization_list.append("mxfp8")
+if nvfp4_available:
+    _quantization_list.append("nvfp4")


 def maybe_skip_quantization(
@@ -66,6 +70,7 @@ def maybe_skip_quantization(
     *,
     dims: Optional[Iterable[int] | int] = None,
     device: Optional[torch.device | str] = None,
+    dtype: Optional[torch.dtype] = None,
 ) -> None:
     """Skip test case if a quantization scheme is not supported"""
@@ -73,12 +78,17 @@ def maybe_skip_quantization(
     if quantization is None:
         return

-    # Check if quantization scheme is supported
+    # Check if quantization scheme is supported on device
+    if device is not None and torch.device(device).type != "cuda":
+        pytest.skip("Quantization is only supported on CUDA devices")
     if quantization in ("fp8", "fp8_delayed_scaling", "fp8_current_scaling") and not fp8_available:
         pytest.skip(reason_for_no_fp8)
     if quantization == "mxfp8" and not mxfp8_available:
         pytest.skip(reason_for_no_mxfp8)
+    if quantization == "nvfp4" and not nvfp4_available:
+        pytest.skip(reason_for_no_nvfp4)

+    # Check dims
     if dims is not None:
         if not isinstance(dims, Iterable):
             dims = (dims,)
@@ -88,10 +98,14 @@ def maybe_skip_quantization(
         elif quantization == "mxfp8":
             if math.prod(dims[:-1]) % 32 != 0 or dims[-1] % 32 != 0:
                 pytest.skip("MXFP8 GEMMs require dims that are divisible by 32")
+        elif quantization == "nvfp4":
+            if math.prod(dims[:-1]) % 16 != 0 or dims[-1] % 16 != 0:
+                pytest.skip("NVFP4 GEMMs require dims that are divisible by 16")

-    # Check if device is supported
-    if device is not None and torch.device(device).type != "cuda":
-        pytest.skip("Quantization is only supported on CUDA devices")
+    # Check dtype
+    if dtype is not None:
+        if quantization == "nvfp4" and dtype != torch.bfloat16:
+            pytest.skip("NVFP4 quantization is only supported with BF16 data")


 @torch.no_grad()
@@ -141,6 +155,14 @@ def make_reference_and_test_tensors(
             test = quantizer(test)
         elif quantization == "mxfp8":
             test = MXFP8Quantizer(fp8_dtype=tex.DType.kFloat8E4M3)(test)
+        elif quantization == "nvfp4":
+            test = NVFP4Quantizer(
+                with_rht=False,
+                with_post_rht_amax=False,
+                with_2d_quantization=False,
+                stochastic_rounding=False,
+                with_random_sign_mask=False,
+            )(test)
         else:
             raise ValueError(f"Unsupported quantization scheme ({quantization})")
     if isinstance(test, QuantizedTensor) and not test_is_quantized:
@@ -395,12 +417,12 @@ class TestFuser:
         torch.testing.assert_close(
             y,
             torch.full_like(y, y_val_ref),
-            **dtype_tols(tex.DType.kFloat8E4M3),
+            **quantization_tols("fp8_delayed_scaling"),
         )
         torch.testing.assert_close(
             x.grad,
             torch.full_like(x.grad, dx_val_ref),
-            **dtype_tols(tex.DType.kFloat8E5M2),
+            **quantization_tols("fp8_delayed_scaling"),
         )

         # Check that scaling factors match expected
@@ -434,7 +456,8 @@ class TestFuser:
         # Skip invalid configurations
         in_shape = (size, size)
         with_quantization = quantization is not None
-        maybe_skip_quantization(quantization, dims=in_shape, device=device)
+        maybe_skip_quantization(quantization, dims=in_shape, device=device, dtype=init_dtype)
+        maybe_skip_quantization(quantization, dtype=final_dtype)

         # Random data
         dtype = torch.float32
@@ -502,7 +525,8 @@ class TestFuser:
         # Skip invalid configurations
         in_shape = (size, size)
         quantized_compute = quantization is not None
-        maybe_skip_quantization(quantization, dims=in_shape, device=device)
+        maybe_skip_quantization(quantization, dims=in_shape, device=device, dtype=model_dtype)
+        maybe_skip_quantization(quantization, dtype=autocast_dtype)

         # Construct operation
         recipe = make_recipe(quantization)
@@ -558,7 +582,7 @@ class TestBasicOps:
         # Skip invalid configurations
         with_quantization = quantization is not None
-        maybe_skip_quantization(quantization, dims=in_shape, device=device)
+        maybe_skip_quantization(quantization, dims=in_shape, device=device, dtype=dtype)

         # Random data
         x_ref, x_test = make_reference_and_test_tensors(
@@ -624,7 +648,7 @@ class TestBasicOps:
         # Skip invalid configurations
         if memory_format == torch.channels_last and len(in_shape) != 4:
             pytest.skip("torch.channels_last only supports 4D tensors")
-        maybe_skip_quantization(quantization, device=device)
+        maybe_skip_quantization(quantization, device=device, dtype=dtype)
         with_quantization = quantization is not None

         # Random data
@@ -690,7 +714,7 @@ class TestBasicOps:
         # Skip invalid configurations
         with_quantization = quantization is not None
-        maybe_skip_quantization(quantization, dims=in_shape, device=device)
+        maybe_skip_quantization(quantization, dims=in_shape, device=device, dtype=dtype)

         # Random data
         x_ref, x_test = make_reference_and_test_tensors(
@@ -752,7 +776,7 @@ class TestBasicOps:
         # Skip invalid configurations
         with_quantization = quantization is not None
-        maybe_skip_quantization(quantization, device=device)
+        maybe_skip_quantization(quantization, device=device, dtype=dtype)
         if quantization == "mxfp8":
             maybe_skip_quantization(quantization, dims=in_shape)
@@ -819,7 +843,7 @@ class TestBasicOps:
         out_shape = in_shape[:-1] + [out_features]

         # Skip invalid configurations
-        maybe_skip_quantization(quantization, dims=in_shape, device=device)
+        maybe_skip_quantization(quantization, dims=in_shape, device=device, dtype=dtype)
         maybe_skip_quantization(quantization, dims=out_shape)
         quantization_needed = any(
             (
@@ -899,7 +923,7 @@ class TestBasicOps:
         if dtype == torch.float32:
             tols = dtype_tols(torch.float16)  # TF32 GEMM
         if quantized_compute or quantized_output or quantized_grad_input:
-            tols = dtype_tols(tex.DType.kFloat8E4M3)
+            tols = quantization_tols(quantization)

         # Check results
         y_test = y_test.to(dtype=torch.float64, device="cpu")
@@ -1010,7 +1034,7 @@ class TestBasicOps:
         out_shape = in_shape[:-1] + [out_features]

         # Skip invalid configurations
-        maybe_skip_quantization(quantization, dims=in_shape, device=device)
+        maybe_skip_quantization(quantization, dims=in_shape, device=device, dtype=dtype)
         maybe_skip_quantization(quantization, dims=out_shape)
         if quantization is None and (quantized_compute or quantized_weight):
             pytest.skip("Quantization scheme is not specified")
@@ -1077,7 +1101,7 @@ class TestBasicOps:
         if dtype == torch.float32:
             tols = dtype_tols(torch.float16)  # TF32 GEMM
         if quantized_compute:
-            tols = dtype_tols(tex.DType.kFloat8E4M3)
+            tols = quantization_tols(quantization)

         # Check results
         y_test = y_test.to(dtype=torch.float64, device="cpu")
@@ -1114,7 +1138,7 @@ class TestBasicOps:
         in_shape = list(in_shape)[:-1] + list(weight_shape)

         # Skip invalid configurations
-        maybe_skip_quantization(quantization, dims=in_shape, device=device)
+        maybe_skip_quantization(quantization, dims=in_shape, device=device, dtype=dtype)

         # Random data
         x_ref, x_test = make_reference_and_test_tensors(
@@ -1175,7 +1199,7 @@ class TestBasicOps:
         # Expected numerical error
         tols = dtype_tols(dtype)
         if quantized_compute:
-            tols = dtype_tols(tex.DType.kFloat8E4M3)
+            tols = quantization_tols(quantization)

         # Check results
         y_test = y_test.to(dtype=torch.float64, device="cpu")
@@ -1284,7 +1308,7 @@ class TestBasicOps:
         in_shape = list(in_shape)[:-1] + list(weight_shape)

         # Skip invalid configurations
-        maybe_skip_quantization(quantization, dims=in_shape, device=device)
+        maybe_skip_quantization(quantization, dims=in_shape, device=device, dtype=dtype)

         # Random data
         x_ref, x_test = make_reference_and_test_tensors(
@@ -1337,7 +1361,7 @@ class TestBasicOps:
         # Expected numerical error
         tols = dtype_tols(dtype)
         if quantized_compute:
-            tols = dtype_tols(tex.DType.kFloat8E4M3)
+            tols = quantization_tols(quantization)

         # Check results
         y_test = y_test.to(dtype=torch.float64, device="cpu")
@@ -1417,7 +1441,7 @@ class TestBasicOps:
         # Skip invalid configurations
         with_quantization = quantization is not None
-        maybe_skip_quantization(quantization, dims=in_shape, device=device)
+        maybe_skip_quantization(quantization, dims=in_shape, device=device, dtype=dtype)

         # Random data
         x1_ref, x1_test = make_reference_and_test_tensors(
@@ -1456,8 +1480,11 @@ class TestBasicOps:
         # Check results
         tols = dtype_tols(dtype)
-        if with_quantization:
-            tols = dtype_tols(x1_test._fp8_dtype)
+        if in_place:
+            if quantization in ("fp8_delayed_scaling", "fp8_current_scaling", "mxfp8"):
+                tols = dtype_tols(x1_test._fp8_dtype)
+            elif quantization == "nvfp4":
+                tols = dtype_tols(x1_test._fp4_dtype)
         y_test = y_test.to(dtype=torch.float64, device="cpu")
         dx1_test = x1_test.grad.to(dtype=torch.float64, device="cpu")
         dx2_test = x2_test.grad.to(dtype=torch.float64, device="cpu")
@@ -1486,7 +1513,7 @@ class TestBasicOps:
         # Skip invalid configurations
         with_quantization = quantization is not None
-        maybe_skip_quantization(quantization, dims=in_shape, device=device)
+        maybe_skip_quantization(quantization, dims=in_shape, device=device, dtype=dtype)

         # Random data
         x_ref, x_test = make_reference_and_test_tensors(
@@ -1559,7 +1586,7 @@ class TestBasicOps:
         # Skip invalid configurations
         quantized_compute = quantization is not None
-        maybe_skip_quantization(quantization, dims=in_shape, device=device)
+        maybe_skip_quantization(quantization, dims=in_shape, device=device, dtype=dtype)
         if cache_quantized_input:
             maybe_skip_quantization("fp8_current_scaling", device=device)
@@ -1633,8 +1660,10 @@ class TestBasicOps:
         # Expected numerical error
         tols = dtype_tols(dtype)
-        if quantized_compute or cache_quantized_input:
-            tols = dtype_tols(tex.DType.kFloat8E4M3)
+        if quantized_compute:
+            tols = quantization_tols(quantization)
+        elif cache_quantized_input:
+            tols = quantization_tols("fp8_current_scaling")

         # Check results
         y_test = y_test.to(dtype=torch.float64, device="cpu")
@@ -1665,7 +1694,7 @@ class TestBasicOps:
         quantized_compute = quantization is not None
         if not quantized_compute and (quantize_forward or quantize_backward):
             pytest.skip("Quantization scheme has not been provided")
-        maybe_skip_quantization(quantization, dims=in_shape, device=device)
+        maybe_skip_quantization(quantization, dims=in_shape, device=device, dtype=dtype)

         # Random data
         x_ref, x_test = make_reference_and_test_tensors(
@@ -1699,7 +1728,7 @@ class TestBasicOps:
         # Expected numerical error
         tols = dtype_tols(dtype)
         if quantized_compute:
-            tols = dtype_tols(tex.DType.kFloat8E4M3)
+            tols = quantization_tols(quantization)

         # Check results
         y_test = y_test.to(dtype=torch.float64, device="cpu")
@@ -1767,7 +1796,7 @@ class TestBasicOps:
         # Skip invalid configurations
         quantized_input = quantization is not None
-        maybe_skip_quantization(quantization, dims=shape, device=device)
+        maybe_skip_quantization(quantization, dims=shape, device=device, dtype=dtype)

         # Random data
         # Note: Shift values to make sure inputs are non-zero
@@ -1858,7 +1887,7 @@ class TestFusedOps:
         # Skip invalid configurations
         quantized_compute = quantization is not None
-        maybe_skip_quantization(quantization, dims=in_shape, device=device)
+        maybe_skip_quantization(quantization, dims=in_shape, device=device, dtype=dtype)
         maybe_skip_quantization(quantization, dims=out_shape)
         if dtype not in (torch.float16, torch.bfloat16):
             pytest.skip(
@@ -1929,7 +1958,7 @@ class TestFusedOps:
         if dtype == torch.float32:
             tols = dtype_tols(torch.float16)  # TF32 GEMM
         if quantized_compute:
-            tols = dtype_tols(tex.DType.kFloat8E4M3)
+            tols = quantization_tols(quantization)

         # Check results
         y_test = y_test.to(dtype=torch.float64, device="cpu")
@@ -1965,7 +1994,7 @@ class TestFusedOps:
         # Skip invalid configurations
         quantized_compute = quantization is not None
-        maybe_skip_quantization(quantization, dims=in_shape, device=device)
+        maybe_skip_quantization(quantization, dims=in_shape, device=device, dtype=dtype)
         maybe_skip_quantization(quantization, dims=out_shape)
         if quantized_compute and dtype not in (torch.float16, torch.bfloat16):
             pytest.skip("FP8 GEMM is only supported with FP8, FP16, or BF16 output")
@@ -2040,7 +2069,7 @@ class TestFusedOps:
         if dtype == torch.float32:
             tols = dtype_tols(torch.float16)  # TF32 GEMM
         if quantized_compute:
-            tols = dtype_tols(tex.DType.kFloat8E4M3)
+            tols = quantization_tols(quantization)

         # Check results
         y_test = y_test.to(dtype=torch.float64, device="cpu")
@@ -2078,7 +2107,7 @@ class TestFusedOps:
         # Skip invalid configurations
         quantized_compute = quantization is not None
-        maybe_skip_quantization(quantization, dims=in_shape, device=device)
+        maybe_skip_quantization(quantization, dims=in_shape, device=device, dtype=dtype)
         maybe_skip_quantization(quantization, dims=out_shape)
         if quantized_compute and dtype not in (torch.float16, torch.bfloat16):
             pytest.skip("FP8 GEMM is only supported with FP8, FP16, or BF16 output")
@@ -2146,7 +2175,7 @@ class TestFusedOps:
         if dtype == torch.float32:
             tols = dtype_tols(torch.float16)  # TF32 GEMM
         if quantized_compute:
-            tols = dtype_tols(tex.DType.kFloat8E4M3)
+            tols = quantization_tols(quantization)

         # Check results
         y_test = y_test.to(dtype=torch.float64, device="cpu")
@@ -2179,7 +2208,7 @@ class TestFusedOps:
         # Skip invalid configurations
         with_quantization = quantization is not None
-        maybe_skip_quantization(quantization, device=device)
+        maybe_skip_quantization(quantization, device=device, dtype=dtype)
         if quantization == "mxfp8" and (len(in_shape) < 2 or in_shape[-1] % 32 != 0):
             pytest.skip("Unsupported tensor size for MXFP8")
@@ -2241,7 +2270,7 @@ class TestFusedOps:
         # Expected numerical error
         tols = dtype_tols(dtype)
         if with_quantization:
-            tols = dtype_tols(tex.DType.kFloat8E4M3)
+            tols = quantization_tols(quantization)

         # Check results
         y_test = y_test.to(dtype=torch.float64, device="cpu")
@@ -2360,7 +2389,7 @@ class TestFusedOps:
         # Skip invalid configurations
         quantized_compute = quantization is not None
-        maybe_skip_quantization(quantization, dims=in_shape, device=device)
+        maybe_skip_quantization(quantization, dims=in_shape, device=device, dtype=dtype)
         maybe_skip_quantization(quantization, dims=out_shape)
         if quantized_compute and dtype not in (torch.float16, torch.bfloat16):
             pytest.skip("FP8 GEMM is only supported with FP8, FP16, or BF16 output")
@@ -2428,7 +2457,7 @@ class TestFusedOps:
         if dtype == torch.float32:
             tols = dtype_tols(torch.float16)  # TF32 GEMM
         if quantized_compute:
-            tols = dtype_tols(tex.DType.kFloat8E4M3)
+            tols = quantization_tols(quantization)

         # Check results
         y1_test = y1_test.to(dtype=torch.float64, device="cpu")
@@ -2463,7 +2492,7 @@ class TestFusedOps:
         # Skip invalid configurations
         quantized_compute = quantization is not None
-        maybe_skip_quantization(quantization, dims=in_shape, device=device)
+        maybe_skip_quantization(quantization, dims=in_shape, device=device, dtype=dtype)
         maybe_skip_quantization(quantization, dims=out_shape)
         if quantized_compute and dtype not in (torch.float16, torch.bfloat16):
             pytest.skip("FP8 GEMM is only supported with FP8, FP16, or BF16 output")
@@ -2523,7 +2552,7 @@ class TestFusedOps:
         if dtype == torch.float32:
             tols = dtype_tols(torch.float16)  # TF32 GEMM
         if quantized_compute:
-            tols = dtype_tols(tex.DType.kFloat8E4M3)
+            tols = quantization_tols(quantization)

         # Check results
         y_test = y_test.to(dtype=torch.float64, device="cpu")
@@ -2564,7 +2593,7 @@ class TestCheckpointing:
         # Skip invalid configurations
         quantized_compute = quantization is not None
-        maybe_skip_quantization(quantization, dims=in_shape, device=device)
+        maybe_skip_quantization(quantization, dims=in_shape, device=device, dtype=dtype)
         maybe_skip_quantization(quantization, dims=out_shape)

         # Construct model
@@ -2690,7 +2719,7 @@ class TestSequentialModules:
         ffn_shape = in_shape[:-1] + (ffn_hidden_size,)

         # Skip invalid configurations
-        maybe_skip_quantization(quantization, dims=in_shape, device=device)
+        maybe_skip_quantization(quantization, dims=in_shape, device=device, dtype=dtype)
         maybe_skip_quantization(quantization, dims=ffn_shape, device=device)
         quantization_needed = quantized_compute or quantized_weight
         if quantization is None and quantization_needed:
...
@@ -19,6 +19,7 @@ from transformer_engine.pytorch.fp8 import (
     fp8_model_init,
 )
 from transformer_engine.pytorch.tensor.float8_tensor import Float8Quantizer
+from transformer_engine.pytorch.tensor.nvfp4_tensor import NVFP4Quantizer
 import transformer_engine.pytorch.ops as te_ops
 from transformer_engine.pytorch import Linear, LayerNormLinear, LayerNormMLP, GroupedLinear
 from transformer_engine.pytorch.distributed import fp8_autocast
@@ -499,3 +500,39 @@ class TestFP8Recipe:
             y = module(x, [batch_size])
         else:
             y = module(x)
+
+
+fp4_available, reason_for_no_fp4 = FP8GlobalStateManager.is_nvfp4_available()
+
+
+@pytest.mark.skipif(not fp4_available, reason=reason_for_no_fp4)
+@pytest.mark.parametrize("dtype", [torch.float32, torch.bfloat16], ids=str)
+@pytest.mark.parametrize(
+    "M, N",
+    [
+        # Full tile cases
+        (128, 128),
+        (256, 1024),
+        (1024, 256),
+        # Padding required cases
+        (256, 272),
+        (304, 304),
+        (320, 256),
+        # Largest tile
+        (8192, 8192),
+    ],
+)
+def test_fp4_dequantize(dtype, M, N):
+    # Requantizing a dequantized tensor must reproduce the same packed FP4 data.
+    q = NVFP4Quantizer()
+    a = torch.rand((M, N)).cuda().to(dtype=dtype)
+    starting_tensor = q(a)
+    dequantized_tensor = starting_tensor.dequantize()
+    new_tensor = q(dequantized_tensor)
+    torch.testing.assert_close(
+        new_tensor._rowwise_data,
+        starting_tensor._rowwise_data,
+        rtol=0,
+        atol=0,
+    )
+    new_dequantized_tensor = new_tensor.dequantize()
+    torch.testing.assert_close(dequantized_tensor, new_dequantized_tensor)
...@@ -87,9 +87,19 @@ model_configs = { ...@@ -87,9 +87,19 @@ model_configs = {
"large": ModelConfig(2, 128, 4, 128, num_layers=1), "large": ModelConfig(2, 128, 4, 128, num_layers=1),
} }
def nvfp4_vanilla():
nvfp4_recipe = recipe.NVFP4BlockScaling()
nvfp4_recipe.fp4_quant_fwd_inp = recipe.QParams()
nvfp4_recipe.fp4_quant_fwd_weight = recipe.QParams()
nvfp4_recipe.fp4_quant_bwd_grad = recipe.QParams()
return nvfp4_recipe
fp8_recipes = [] fp8_recipes = []
if mxfp8_available: if mxfp8_available:
fp8_recipes.append(recipe.MXFP8BlockScaling()) fp8_recipes.append(recipe.MXFP8BlockScaling())
fp8_recipes.append(nvfp4_vanilla()) # TODO: fix check for this
if fp8_block_scaling_available: if fp8_block_scaling_available:
fp8_recipes.append(recipe.Float8BlockScaling()) fp8_recipes.append(recipe.Float8BlockScaling())
if fp8_available: if fp8_available:
...@@ -379,6 +389,8 @@ def test_sanity_layernorm_linear( ...@@ -379,6 +389,8 @@ def test_sanity_layernorm_linear(
if fp8_recipe is not None: if fp8_recipe is not None:
if not is_fp8_supported(config): if not is_fp8_supported(config):
pytest.skip("Model config does not support FP8") pytest.skip("Model config does not support FP8")
if fp8_recipe.nvfp4() and dtype == torch.float16:
pytest.skip("FP16 output for NVFP4 not supported")
sigma = 0.023 sigma = 0.023
init_method = init_method_normal(sigma) init_method = init_method_normal(sigma)
@@ -407,6 +419,8 @@ def test_sanity_linear(dtype, fp8_recipe, model, skip_wgrad, skip_dgrad, microba
    if fp8_recipe is not None:
        if not is_fp8_supported(config):
            pytest.skip("Model config does not support FP8")
        if fp8_recipe.nvfp4() and dtype == torch.float16:
            pytest.skip("FP16 output for NVFP4 not supported")

    sigma = 0.023
    output_layer_init_method = scaled_init_method_normal(sigma, config.num_layers)
@@ -437,6 +451,8 @@ def test_sanity_linear_with_zero_tokens(dtype, bs, model, fp8_recipe, fp8_model_
    if fp8_recipe is not None:
        if not is_fp8_supported(config):
            pytest.skip("Model config does not support FP8")
        if fp8_recipe.nvfp4() and dtype == torch.float16:
            pytest.skip("FP16 output for NVFP4 not supported")

    use_fp8 = fp8_recipe is not None
    with fp8_model_init(enabled=use_fp8 and fp8_model_params, recipe=fp8_recipe):
@@ -476,6 +492,8 @@ def test_sanity_grouped_linear(
    if fp8_recipe is not None:
        if not is_fp8_supported(config):
            pytest.skip("Model config does not support FP8")
        if fp8_recipe.nvfp4():
            pytest.skip("NVFP4 not supported for grouped linear")

    use_fp8 = fp8_recipe is not None
    with fp8_model_init(enabled=use_fp8 and fp8_model_params, recipe=fp8_recipe):
@@ -526,6 +544,8 @@ def test_sanity_layernorm_mlp(
    if fp8_recipe is not None:
        if not is_fp8_supported(config):
            pytest.skip("Model config does not support FP8")
        if fp8_recipe.nvfp4() and dtype == torch.float16:
            pytest.skip("FP16 output for NVFP4 not supported")

    sigma = 0.023
    init_method = init_method_normal(sigma)
@@ -568,6 +588,8 @@ def test_sanity_gpt(
    if fp8_recipe is not None:
        if not is_fp8_supported(config):
            pytest.skip("Model config does not support FP8")
        if fp8_recipe.nvfp4() and dtype == torch.float16:
            pytest.skip("FP16 output for NVFP4 not supported")

    sigma = 0.023
    init_method = init_method_normal(sigma)
@@ -629,6 +651,8 @@ def test_sanity_bert(dtype, fp8_recipe, model, skip_wgrad, normalization):
        pytest.skip(reason_for_no_fp8)
    if not is_fp8_supported(config):
        pytest.skip("Model config does not support FP8")
    if fp8_recipe.nvfp4() and dtype == torch.float16:
        pytest.skip("FP16 output for NVFP4 not supported")

    sigma = 0.023
    init_method = init_method_normal(sigma)
@@ -683,6 +707,8 @@ def test_sanity_T5(dtype, fp8_recipe, model, skip_wgrad, normalization):
        pytest.skip(reason_for_no_fp8)
    if not is_fp8_supported(config):
        pytest.skip("Model config does not support FP8")
    if fp8_recipe.nvfp4() and dtype == torch.float16:
        pytest.skip("FP16 output for NVFP4 not supported")

    sigma = 0.023
    init_method = init_method_normal(sigma)
@@ -734,6 +760,8 @@ def test_sanity_amp_and_nvfuser(dtype, fp8_recipe, model, skip_wgrad):
    if fp8_recipe is not None:
        if not is_fp8_supported(config):
            pytest.skip("Model config does not support FP8")
        if fp8_recipe.nvfp4() and dtype == torch.float16:
            pytest.skip("FP16 output for NVFP4 not supported")

    sigma = 0.023
    init_method = init_method_normal(sigma)
@@ -764,6 +792,8 @@ def test_sanity_drop_path(dtype, fp8_recipe, model):
    if fp8_recipe is not None:
        if not is_fp8_supported(config):
            pytest.skip("Model config does not support FP8")
        if fp8_recipe.nvfp4() and dtype == torch.float16:
            pytest.skip("FP16 output for NVFP4 not supported")

    sigma = 0.023
    init_method = init_method_normal(sigma)
@@ -798,6 +828,8 @@ def test_sanity_fused_qkv_params(dtype, fp8_recipe, model, skip_wgrad):
    if fp8_recipe is not None:
        if not is_fp8_supported(config):
            pytest.skip("Model config does not support FP8")
        if fp8_recipe.nvfp4() and dtype == torch.float16:
            pytest.skip("FP16 output for NVFP4 not supported")

    sigma = 0.023
    init_method = init_method_normal(sigma)
@@ -832,6 +864,8 @@ def test_sanity_gradient_accumulation_fusion(dtype, fp8_recipe, model, skip_wgra
    if fp8_recipe is not None:
        if not is_fp8_supported(config):
            pytest.skip("Model config does not support FP8")
        if fp8_recipe.nvfp4() and dtype == torch.float16:
            pytest.skip("FP16 output for NVFP4 not supported")

    sigma = 0.023
    init_method = init_method_normal(sigma)
...
@@ -73,6 +73,8 @@ def dtype_tols(dtype: torch.dtype | tex.DType) -> dict[str, float]:
    # Transformer Engine dtypes
    if isinstance(dtype, tex.DType):
        if dtype == tex.DType.kFloat4E2M1:
            return dict(rtol=0.25, atol=0.125)  # epsilon = 0.25
        dtype = {
            tex.DType.kByte: torch.uint8,
            tex.DType.kInt32: torch.int32,
@@ -95,10 +97,25 @@ def dtype_tols(dtype: torch.dtype | tex.DType) -> dict[str, float]:
    if dtype == torch.float8_e4m3fn:
        return dict(rtol=0.125, atol=0.0675)  # epsilon = 0.0625
    if dtype == torch.float8_e5m2:
        return dict(rtol=0.25, atol=0.125)  # epsilon = 0.125
    raise ValueError(f"Unsupported dtype ({dtype})")
def quantization_tols(name: str) -> dict[str, float]:
    """Estimated numerical error for a quantization scheme"""
    if name in (
        "fp8",
        "fp8_delayed_scaling",
        "fp8_current_scaling",
        "mxfp8",
        "mxfp8_block_scaling",
    ):
        return dtype_tols(tex.DType.kFloat8E4M3)
    if name == "nvfp4":
        return dtype_tols(tex.DType.kFloat4E2M1)
    raise ValueError(f"Unsupported quantization scheme ({name})")
def make_recipe(name: Optional[str]) -> Optional[Recipe]:
    """Make recipe for quantization scheme"""
    if name is None:
@@ -118,6 +135,12 @@ def make_recipe(name: Optional[str]) -> Optional[Recipe]:
        )
    if name == "fp8_block_scaling":
        return transformer_engine.common.recipe.Float8BlockScaling()
    if name == "nvfp4":
        return transformer_engine.common.recipe.NVFP4BlockScaling(
            disable_rht=True,
            disable_stochastic_rounding=True,
            disable_2d_quantization=True,
        )
    raise ValueError(f"Unsupported quantization scheme ({name})")
...
@@ -53,6 +53,28 @@ set(CUTLASS_TOOLS_INCLUDE_DIR
# Python
find_package(Python COMPONENTS Interpreter Development.Module REQUIRED)
# NVIDIA MathDX include directory (from Python package install location)
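# Note: MATHDX_INCLUDE_DIR can also be supplied explicitly (e.g. via
# -DMATHDX_INCLUDE_DIR=<path>), in which case the pip-based lookup below is skipped.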
if(NOT DEFINED MATHDX_INCLUDE_DIR)
execute_process(
COMMAND ${Python_EXECUTABLE} -m pip show nvidia-mathdx
OUTPUT_VARIABLE _PIP_SHOW_MATHDX
ERROR_VARIABLE _PIP_SHOW_MATHDX_ERR
RESULT_VARIABLE _PIP_SHOW_MATHDX_RES
OUTPUT_STRIP_TRAILING_WHITESPACE)
if(NOT _PIP_SHOW_MATHDX_RES EQUAL 0)
message(FATAL_ERROR "Failed to query 'nvidia-mathdx' with pip (using ${Python_EXECUTABLE}): ${_PIP_SHOW_MATHDX_ERR}")
endif()
string(REGEX MATCH "Location: ([^\n\r]+)" _MATHDX_LOC_MATCH "${_PIP_SHOW_MATHDX}")
if(NOT _MATHDX_LOC_MATCH)
message(FATAL_ERROR "Could not parse installation location for 'nvidia-mathdx'. Output was:\n${_PIP_SHOW_MATHDX}")
endif()
set(MATHDX_LOCATION "${CMAKE_MATCH_1}")
set(MATHDX_INCLUDE_DIR "${MATHDX_LOCATION}/nvidia/mathdx/include")
endif()
if(NOT EXISTS "${MATHDX_INCLUDE_DIR}")
message(FATAL_ERROR "MATHDX include directory not found at ${MATHDX_INCLUDE_DIR}. Set MATHDX_INCLUDE_DIR or ensure 'nvidia-mathdx' is installed for ${Python_EXECUTABLE}.")
endif()
# Configure Transformer Engine library
include_directories(${PROJECT_SOURCE_DIR}/..)
set(transformer_engine_SOURCES)
@@ -73,6 +95,7 @@ list(APPEND transformer_engine_SOURCES
     transpose/quantize_transpose_square_blockwise.cu
     transpose/quantize_transpose_vector_blockwise.cu
     transpose/swap_first_dims.cu
transpose/quantize_transpose_vector_blockwise_fp4.cu
     activation/gelu.cu
     dropout/dropout.cu
     fused_attn/flash_attn.cu
@@ -85,6 +108,7 @@ list(APPEND transformer_engine_SOURCES
     fused_attn/fused_attn_fp8.cu
     fused_attn/fused_attn.cpp
     fused_attn/utils.cu
gemm/config.cpp
     gemm/cublaslt_gemm.cu
     gemm/cutlass_grouped_gemm.cu
     normalization/common.cpp
@@ -113,6 +137,9 @@ list(APPEND transformer_engine_SOURCES
     recipe/current_scaling.cu
     recipe/delayed_scaling.cu
     recipe/fp8_block_scaling.cu
recipe/nvfp4.cu
hadamard_transform/hadamard_transform.cu
hadamard_transform/hadamard_transform_cast_fusion.cu
     comm_gemm_overlap/userbuffers/ipcsocket.cc
     comm_gemm_overlap/userbuffers/userbuffers-host.cpp
     comm_gemm_overlap/userbuffers/userbuffers.cu
@@ -144,7 +171,8 @@ target_link_libraries(transformer_engine PUBLIC
                      CUDNN::cudnn_all)
target_include_directories(transformer_engine PRIVATE
                           ${CMAKE_CUDA_TOOLKIT_INCLUDE_DIRECTORIES})
target_include_directories(transformer_engine PRIVATE ${MATHDX_INCLUDE_DIR})
target_include_directories(transformer_engine SYSTEM PRIVATE
                           ${CMAKE_CUDA_TOOLKIT_INCLUDE_DIRECTORIES}/cccl)
target_include_directories(transformer_engine PRIVATE "${CUDNN_FRONTEND_INCLUDE_DIR}")
...
@@ -39,6 +39,10 @@ cudaDataType_t get_cuda_dtype(const transformer_engine::DType t) {
      return CUDA_R_8F_E4M3;
    case DType::kFloat8E5M2:
      return CUDA_R_8F_E5M2;
#if CUDA_VERSION >= 12080
case DType::kFloat4E2M1:
return CUDA_R_4F_E2M1;
#endif
    default:
      NVTE_ERROR("Invalid type");
  }
@@ -160,7 +164,9 @@ CUtensorMapDataType get_CUtensorMapDataType(DType dtype) {
void create_2D_tensor_map(CUtensorMap &tensorMap, const SimpleTensor &tensor,
                          const uint64_t globalY, const uint64_t globalX, const uint32_t shmemY,
                          const uint32_t shmemX, const uint32_t stride_elems,
                          const uint32_t offset_elems, const size_t type_num_bits,
                          const CUtensorMapSwizzle swizzle) {
  cuda_driver::ensure_context_exists();
  // Get a function pointer to the cuTensorMapEncodeTiled driver API
  // Note: PFN_cuTensorMapEncodeTiled is not defined in cuda13
  static PFN_cuTensorMapEncodeTiled_v12000 cuDriverTensorMapEncodeTiled = []() {
@@ -169,6 +175,8 @@ void create_2D_tensor_map(CUtensorMap &tensorMap, const SimpleTensor &tensor,
  }();
  // rank is the number of dimensions of the array
  constexpr uint32_t rank = 2;
  // Dimensions for packed data types must count the individual packed values
  // (e.g. two 4-bit values per byte), not the storage elements.
  uint64_t size[rank] = {globalX, globalY};
  // The stride is the number of bytes to traverse from the first element of one row to the next
@@ -207,7 +215,7 @@ void create_2D_tensor_map(CUtensorMap &tensorMap, const SimpleTensor &tensor,
      CUtensorMapInterleave::CU_TENSOR_MAP_INTERLEAVE_NONE,
      // Swizzling can be used to avoid shared memory bank conflicts.
      swizzle,
      // L2 Promotion can be used to widen the effect of a cache-policy to a wider
      // set of L2 cache lines.
...
@@ -48,8 +48,14 @@ inline bool is_delayed_tensor_scaling(const NVTEScalingMode &mode) {
  return mode == NVTE_DELAYED_TENSOR_SCALING;
}
inline bool is_nvfp4_scaling(const NVTEScalingMode &mode) { return mode == NVTE_NVFP4_1D_SCALING; }
inline bool is_mxfp8_scaling(const NVTEScalingMode &mode) { return mode == NVTE_MXFP8_1D_SCALING; }
inline bool is_mxfp_scaling(const NVTEScalingMode &mode) { return mode == NVTE_MXFP8_1D_SCALING; }
inline bool is_nvfp_scaling(const NVTEScalingMode &mode) { return mode == NVTE_NVFP4_1D_SCALING; }
inline size_t product(const std::vector<size_t> &shape, const size_t begin, const size_t end) {
  NVTE_CHECK(begin <= end && end <= shape.size(), "Attempted to access entries ", begin, " to ",
             end, " in a vector with ", shape.size(), " entries");
@@ -108,6 +114,7 @@ struct Tensor {
  SimpleTensor data;
  SimpleTensor columnwise_data;
  SimpleTensor amax;
SimpleTensor columnwise_amax;
  SimpleTensor scale;
  SimpleTensor scale_inv;
  SimpleTensor columnwise_scale_inv;
@@ -119,6 +126,7 @@ struct Tensor {
      : data(),
        columnwise_data(),
        amax(nullptr, {1}, DType::kFloat32),
columnwise_amax(nullptr, {1}, DType::kFloat32),
        scale(nullptr, {1}, DType::kFloat32),
        scale_inv(nullptr, {1}, DType::kFloat32),
        columnwise_scale_inv(nullptr, {1}, DType::kFloat32),
@@ -129,6 +137,7 @@ struct Tensor {
    data.clear();
    columnwise_data.clear();
    amax.clear();
columnwise_amax.clear();
    scale.clear();
    scale_inv.clear();
    columnwise_scale_inv.clear();
@@ -174,6 +183,7 @@ struct Tensor {
     * https://gcc.gnu.org/bugzilla/show_bug.cgi?id=109569).
     */
    switch (scaling_mode) {
case NVTE_NVFP4_1D_SCALING:
      case NVTE_DELAYED_TENSOR_SCALING:
        if (!has_data() && has_columnwise_data()) {
          std::vector<size_t> ret;
@@ -189,7 +199,6 @@ struct Tensor {
        }
        break;
      case NVTE_MXFP8_1D_SCALING:
        if (!has_data() && has_columnwise_data()) {
          return columnwise_data.shape;
        } else {
@@ -261,12 +270,18 @@ struct QuantizationConfig {
  NVTETensor noop_tensor = nullptr;
  Float8BlockScaleTensorFormat float8_block_scale_tensor_format =
      Float8BlockScaleTensorFormat::GEMM_READY;
NVTETensor rng_state = nullptr;
bool nvfp4_2d_quantization = false;
bool stochastic_rounding = false;
  static constexpr size_t attr_sizes[] = {
      sizeof(bool),                          // force_pow_2_scales
      sizeof(float),                         // amax_epsilon
      sizeof(NVTETensor),                    // noop_tensor
      sizeof(Float8BlockScaleTensorFormat),  // float8_block_scale_tensor_format
sizeof(NVTETensor), // rng_seed and offset
sizeof(bool), // nvfp4_2d_quantization
sizeof(bool) // stochastic_rounding
  };
};
@@ -298,6 +313,8 @@ using fp8e8m0 = __nv_fp8_e8m0;
#endif
#if FP4_TYPE_SUPPORTED
using fp4e2m1 = __nv_fp4_e2m1;
using fp4e2m1x2 = __nv_fp4x2_e2m1;
using fp4e2m1x4 = __nv_fp4x4_e2m1;
#endif

using e8m0_t = uint8_t;
@@ -334,17 +351,20 @@ struct TypeExtrema;
template <>
struct TypeExtrema<fp4e2m1> {
  static constexpr float max = 6.0f;
  static constexpr float max_inverse = 1.0 / max;
};
#endif

template <>
struct TypeExtrema<fp8e4m3> {
  static constexpr float max = 448.0f;
  static constexpr float max_inverse = 1.0 / max;
};

template <>
struct TypeExtrema<fp8e5m2> {
  static constexpr float max = 57344.0f;
  static constexpr float max_inverse = 1.0 / max;
};

template <>
@@ -558,6 +578,18 @@ struct TypeInfo {
      NVTE_ERROR("Invalid type."); \
  }
// Add a pack_size argument to select the packed type for FP4
#define TRANSFORMER_ENGINE_TYPE_SWITCH_FP4x2_ONLY(dtype, pack_size, type, ...) \
switch (dtype) { \
using namespace transformer_engine; \
case DType::kFloat4E2M1: { \
using type = __nv_fp4x2_storage_t; \
{ __VA_ARGS__ } \
} break; \
default: \
NVTE_ERROR("Invalid type."); \
}
#define TRANSFORMER_ENGINE_TYPE_SWITCH_FP8ONLY(dtype, type, ...) \
  switch (dtype) { \
    using namespace transformer_engine; \
@@ -717,10 +749,11 @@ void checkCuDriverContext(CUstream stream);
CUtensorMapDataType get_CUtensorMapDataType(DType dtype);

// Set up parameters to create TMA descriptor.
void create_2D_tensor_map(
    CUtensorMap &tensorMap, const SimpleTensor &tensor, const uint64_t globalY,
    const uint64_t globalX, const uint32_t shmemY, const uint32_t shmemX,
    const uint32_t stride_elems, const uint32_t offset_elems, const size_t type_num_bits,
    const CUtensorMapSwizzle swizzle = CUtensorMapSwizzle::CU_TENSOR_MAP_SWIZZLE_NONE);

bool is_supported_by_CC_100();
...
/*************************************************************************
* Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
*
* See LICENSE for license information.
************************************************************************/
#include "./config.h"
#include <transformer_engine/gemm.h>
#include <transformer_engine/transformer_engine.h>
#include <cstring>
#include "../util/logging.h"
NVTEMatmulConfig nvte_create_matmul_config() { return new transformer_engine::MatmulConfig; }
void nvte_get_matmul_config_attribute(NVTEMatmulConfig config, NVTEMatmulConfigAttribute attr,
void *buf, size_t size_in_bytes, size_t *size_written) {
// Write attribute size
NVTE_CHECK(attr < kNVTEMatmulConfigNumAttributes, "Invalid NVTEMatmulConfigAttribute (got ",
static_cast<int>(attr), ")");
NVTE_CHECK(size_written != nullptr, "Invalid size_written (got NULL)");
const auto &attr_size = transformer_engine::MatmulConfig::attr_sizes[attr];
*size_written = attr_size;
// Return immediately if buffer is not provided
if (buf == nullptr) {
return;
}
// Check buffer size
NVTE_CHECK(size_in_bytes >= attr_size,
"Buffer is too small for matmul config attribute "
"(attribute ",
static_cast<int>(attr), " needs ", attr_size, " bytes, but buffer has ", size_in_bytes,
" bytes)");
// Write to buffer
NVTE_CHECK(config != nullptr, "Invalid NVTEMatmulConfig (got NULL)");
const auto &config_ = *reinterpret_cast<const transformer_engine::MatmulConfig *>(config);
switch (attr) {
case kNVTEMatmulConfigBiasTensor:
std::memcpy(buf, &config_.bias_tensor, attr_size);
break;
case kNVTEMatmulConfigDBiasTensor:
std::memcpy(buf, &config_.dbias_tensor, attr_size);
break;
case kNVTEMatmulConfigWithGELUEpilogue:
std::memcpy(buf, &config_.with_gelu_epilogue, attr_size);
break;
case kNVTEMatmulConfigWithDGELUEpilogue:
std::memcpy(buf, &config_.with_dgelu_epilogue, attr_size);
break;
case kNVTEMatmulConfigEpilogueAuxTensor:
std::memcpy(buf, &config_.epilogue_aux_tensor, attr_size);
break;
case kNVTEMatmulConfigUseSplitAccumulator:
std::memcpy(buf, &config_.use_split_accumulator, attr_size);
break;
case kNVTEMatmulConfigSMCount:
std::memcpy(buf, &config_.sm_count, attr_size);
break;
default:
NVTE_ERROR("Unsupported NVTEMatmulConfigAttribute (got ", static_cast<int>(attr), ")");
}
}
void nvte_set_matmul_config_attribute(NVTEMatmulConfig config, NVTEMatmulConfigAttribute attr,
const void *buf, size_t size_in_bytes) {
// Check attribute and buffer
NVTE_CHECK(attr < kNVTEMatmulConfigNumAttributes, "Invalid NVTEMatmulConfigAttribute (got ",
static_cast<int>(attr), ")");
const auto &attr_size = transformer_engine::MatmulConfig::attr_sizes[attr];
NVTE_CHECK(size_in_bytes >= attr_size,
"Buffer is too small for matmul config attribute "
"(attribute ",
static_cast<int>(attr), " needs ", attr_size, " bytes, but buffer has ", size_in_bytes,
" bytes)");
NVTE_CHECK(buf != nullptr, "Invalid buffer (got NULL)");
// Read from buffer
NVTE_CHECK(config != nullptr, "Invalid NVTEMatmulConfig (got NULL)");
auto &config_ = *reinterpret_cast<transformer_engine::MatmulConfig *>(config);
switch (attr) {
case kNVTEMatmulConfigBiasTensor:
std::memcpy(&config_.bias_tensor, buf, attr_size);
break;
case kNVTEMatmulConfigDBiasTensor:
std::memcpy(&config_.dbias_tensor, buf, attr_size);
break;
case kNVTEMatmulConfigWithGELUEpilogue:
std::memcpy(&config_.with_gelu_epilogue, buf, attr_size);
break;
case kNVTEMatmulConfigWithDGELUEpilogue:
std::memcpy(&config_.with_dgelu_epilogue, buf, attr_size);
break;
case kNVTEMatmulConfigEpilogueAuxTensor:
std::memcpy(&config_.epilogue_aux_tensor, buf, attr_size);
break;
case kNVTEMatmulConfigUseSplitAccumulator:
std::memcpy(&config_.use_split_accumulator, buf, attr_size);
break;
case kNVTEMatmulConfigSMCount:
std::memcpy(&config_.sm_count, buf, attr_size);
break;
default:
NVTE_ERROR("Unsupported NVTEMatmulConfigAttribute (got ", static_cast<int>(attr), ")");
}
}
void nvte_destroy_matmul_config(NVTEMatmulConfig config) {
if (config != nullptr) {
delete reinterpret_cast<transformer_engine::MatmulConfig *>(config);
}
}
/*************************************************************************
* Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
*
* See LICENSE for license information.
************************************************************************/
#ifndef TRANSFORMER_ENGINE_GEMM_CONFIG_H_
#define TRANSFORMER_ENGINE_GEMM_CONFIG_H_
#include <transformer_engine/transformer_engine.h>
namespace transformer_engine {
struct MatmulConfig {
NVTETensor bias_tensor = nullptr;
NVTETensor dbias_tensor = nullptr;
bool with_gelu_epilogue = false;
bool with_dgelu_epilogue = false;
NVTETensor epilogue_aux_tensor = nullptr;
bool use_split_accumulator = false;
int sm_count = 0;
static constexpr size_t attr_sizes[] = {
sizeof(NVTETensor), // bias_tensor
sizeof(NVTETensor), // dbias_tensor
sizeof(bool), // with_gelu_epilogue
sizeof(bool), // with_dgelu_epilogue
sizeof(NVTETensor), // epilogue_aux_tensor
sizeof(bool), // use_split_accumulator
sizeof(int) // sm_count
};
};
} // namespace transformer_engine
#endif // TRANSFORMER_ENGINE_GEMM_CONFIG_H_
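For orientation, a minimal sketch of how a caller might combine the attribute API above with the new nvte_cublas_gemm_v2 entry point added below. This is illustrative only and not part of the diff; the tensor handles, workspace, and stream are assumed to already exist, and the attribute values are arbitrary.

// Hypothetical caller-side usage sketch (not part of this change).
#include <cuda_runtime.h>
#include <transformer_engine/gemm.h>
#include <transformer_engine/transformer_engine.h>

void example_matmul(NVTETensor A, NVTETensor B, NVTETensor D, NVTETensor workspace,
                    cudaStream_t stream) {
  NVTEMatmulConfig cfg = nvte_create_matmul_config();

  // Limit the GEMM to a subset of SMs and request split accumulation.
  int sm_count = 16;
  nvte_set_matmul_config_attribute(cfg, kNVTEMatmulConfigSMCount, &sm_count, sizeof(sm_count));
  bool split_acc = true;
  nvte_set_matmul_config_attribute(cfg, kNVTEMatmulConfigUseSplitAccumulator, &split_acc,
                                   sizeof(split_acc));

  // Host-side alpha/beta; for NVFP4 operands cublas_gemm folds them into
  // device-side scales internally. C and D must currently alias.
  const float alpha = 1.0f;
  const float beta = 0.0f;
  nvte_cublas_gemm_v2(/*transa=*/1, /*transb=*/0, &alpha, A, B, &beta, /*C=*/D, D,
                      workspace, cfg, stream);

  nvte_destroy_matmul_config(cfg);
}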
@@ -9,20 +9,55 @@
#include <cuda.h>
#include <transformer_engine/gemm.h>
#include <transformer_engine/multi_stream.h>
#include <transformer_engine/recipe.h>
#include <transformer_engine/transformer_engine.h>

#include <algorithm>
#include <cstdint>
#include <mutex>
#include <vector>

#include "../common.h"
#include "../util/cuda_runtime.h"
#include "../util/handle_manager.h"
#include "../util/logging.h"
#include "../util/multi_stream.h"
#include "./config.h"
#include "./cutlass_grouped_gemm.cuh"

namespace {
/* Store the scalars 1.0f and 0.0f in CUDA constant memory so that cuBLASLt
 * can consume alpha/beta as device pointers (CUBLASLT_POINTER_MODE_DEVICE).
 */
__device__ __constant__ float one_device;
__device__ __constant__ float zero_device;
inline float *GetScalarOne() {
static std::once_flag init_flag;
std::call_once(init_flag, []() {
float one = 1.0f;
NVTE_CHECK_CUDA(cudaMemcpyToSymbol(one_device, &one, sizeof(float)));
});
  // Look up the device address of the symbol via cudaGetSymbolAddress
float *dev_ptr;
NVTE_CHECK_CUDA(cudaGetSymbolAddress(reinterpret_cast<void **>(&dev_ptr), one_device));
return dev_ptr;
}
inline float *GetScalarZero() {
static std::once_flag init_flag;
std::call_once(init_flag, []() {
float zero = 0.0f;
NVTE_CHECK_CUDA(cudaMemcpyToSymbol(zero_device, &zero, sizeof(float)));
});
  // Look up the device address of the symbol via cudaGetSymbolAddress
float *dev_ptr;
NVTE_CHECK_CUDA(cudaGetSymbolAddress(reinterpret_cast<void **>(&dev_ptr), zero_device));
return dev_ptr;
}
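
// Single-thread kernel used to stage an arbitrary host-side scalar (e.g. a
// nontrivial beta) into device memory when cuBLASLt runs in device-pointer mode.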
__global__ __launch_bounds__(1) void set_float_kernel(float *ptr, float val) { *ptr = val; }
uint32_t _getAlignment(uintptr_t address) {
  // alignment is in bytes
  uint32_t alignment = 256;
@@ -82,6 +117,10 @@ GemmParam CanonicalizeGemmInput(const transformer_engine::Tensor &A, const cubla
  bool is_A_transposed = transA == CUBLAS_OP_T;
  bool is_B_transposed = transB == CUBLAS_OP_T;
// Set conditions for MXFP8 and NVFP4 gemm execution.
const auto nvfp4 = is_nvfp_scaling(A.scaling_mode) && is_nvfp_scaling(B.scaling_mode);
const auto mxfp8 = !nvfp4 && is_mxfp_scaling(A.scaling_mode) && is_mxfp_scaling(B.scaling_mode);
  // Configure A matrix
  if (is_tensor_scaling(A.scaling_mode)) {
    // Unscaled or FP8 tensor scaling
@@ -102,10 +141,26 @@ GemmParam CanonicalizeGemmInput(const transformer_engine::Tensor &A, const cubla
        NVTE_CHECK(!is_fp8_dtype(ret.Atype), "Input A is missing column-wise usage");
      }
    }
  } else if (nvfp4) {
    // NVFP4 GEMM: either the pure NVFP4 recipe or the forward pass of the hybrid NVFP4/MXFP8 recipe.
if (is_A_transposed) {
NVTE_CHECK(A.has_data(), "Input A is missing row-wise usage");
} else {
NVTE_CHECK(is_nvfp4_scaling(A.scaling_mode),
"Input A has unsupported combination of recipe and layout");
NVTE_CHECK(A.has_columnwise_data(), "Input A is missing column-wise usage");
}
ret.A = is_A_transposed ? A.data.dptr : A.columnwise_data.dptr;
ret.transA = CUBLAS_OP_T; // NVFP4 gemm is only supported in TN layout.
ret.Atype = is_A_transposed ? A.data.dtype : A.columnwise_data.dtype;
ret.A_scale_inv = is_A_transposed ? A.scale_inv.dptr : A.columnwise_scale_inv.dptr;
ret.lda = k;
} else if (mxfp8) {
    // MXFP8 GEMM: either the pure MXFP8 recipe or the backward pass of the hybrid NVFP4/MXFP8 recipe.
    // Note: Row-wise and column-wise data are scaled along different
    // dimensions (with matrix interpreted in row-major order).
    if (is_A_transposed) {
      NVTE_CHECK(A.has_data(), "Input A is missing row-wise usage");
    } else {
@@ -161,10 +216,20 @@ GemmParam CanonicalizeGemmInput(const transformer_engine::Tensor &A, const cubla
        NVTE_CHECK(!is_fp8_dtype(ret.Btype), "Input B is missing column-wise usage");
      }
    }
  } else if (nvfp4) {
    if (is_B_transposed) {
      NVTE_CHECK(is_nvfp4_scaling(B.scaling_mode),
                 "Input B has unsupported combination of recipe and layout");
NVTE_CHECK(B.has_columnwise_data(), "Input B is missing column-wise usage");
} else {
NVTE_CHECK(B.has_data(), "Input B is missing row-wise usage");
}
ret.B = is_B_transposed ? B.columnwise_data.dptr : B.data.dptr;
ret.transB = CUBLAS_OP_N; // NVFP4 gemm is only supported in TN layout.
ret.Btype = is_B_transposed ? B.columnwise_data.dtype : B.data.dtype;
ret.B_scale_inv = is_B_transposed ? B.columnwise_scale_inv.dptr : B.scale_inv.dptr;
ret.ldb = k;
} else if (mxfp8) {
    if (is_B_transposed) {
      NVTE_CHECK(B.has_columnwise_data(), "Input B is missing column-wise usage");
    } else {
@@ -221,7 +286,7 @@ using cublasHandleManager = detail::HandleManager<cublasLtHandle_t, CreateCublas
void cublas_gemm(const Tensor *inputA, const Tensor *inputB, Tensor *outputD,
                 const Tensor *inputBias, Tensor *outputPreGelu, cublasOperation_t transa,
                 cublasOperation_t transb, bool grad, void *workspace, size_t workspaceSize,
                 const void *alpha, const void *beta, bool use_split_accumulator, int math_sm_count,
                 int m_split, int n_split, bool gemm_producer, const Tensor *inputCounter,
                 cudaStream_t stream) {
  // Tensor dims in row-major order
@@ -260,6 +325,49 @@ void cublas_gemm(const Tensor *inputA, const Tensor *inputB, Tensor *outputD,
  }
  const bool gelu = pre_gelu_out != nullptr;
  const bool use_fp8 = is_fp8_dtype(param.Atype) || is_fp8_dtype(param.Btype);
const bool use_fp4 = is_fp4_dtype(param.Atype) || is_fp4_dtype(param.Btype);
// Update scaling factors with NVFP4 tensor scales
// TODO: Check whether scales are on CPU/GPU or add API to control.
// Currently scales are assumed to be on CPU when amax is provided
// and on GPU when not provided, but this is brittle.
if (use_fp4 && (inputA->amax.dptr != nullptr || inputB->amax.dptr != nullptr)) {
// Reserve some workspace for alpha scale
NVTE_CHECK(workspaceSize >= 4,
"NVFP4 GEMM requires at least 4 byte workspace for alpha scale, but only has ",
workspaceSize, " bytes remaining.");
workspaceSize = (workspaceSize / 4) * 4 - 4; // Remove last 4 aligned bytes
uint8_t *workspace_ptr = reinterpret_cast<uint8_t *>(workspace);
float *new_alpha_ptr = reinterpret_cast<float *>(&workspace_ptr[workspaceSize]);
// Update alpha scale on device
    // Note: Compute the NVFP4 tensor scales from the amaxes and fold them
    // into the alpha scale. This way the NVFP4 tensor scales only need to
    // be applied to the matmul output, instead of to the matmul inputs.
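    // (Sketch, under the assumed NVFP4 convention: if each operand was encoded
    // with a per-tensor scale S = (6 * 448) / amax -- FP4 max = 6, FP8-E4M3
    // scale max = 448 -- the product must be rescaled by 1 / (S_A * S_B), i.e.
    //   alpha_new = alpha * amax_A * amax_B / (6 * 448)^2,
    // which is the quantity the device-side helper below is expected to produce.)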
float old_alpha = *reinterpret_cast<const float *>(alpha); // Assumed to be on CPU
TensorWrapper new_alpha_tensor(new_alpha_ptr, std::vector<size_t>{1}, DType::kFloat32);
nvte_nvfp4_compute_per_tensor_scale(inputA->nvte_tensor, transa, inputB->nvte_tensor, !transb,
old_alpha, new_alpha_tensor.data(), stream);
alpha = new_alpha_ptr;
// Make sure beta scale is on device
float old_beta = *reinterpret_cast<const float *>(beta); // Assumed to be on CPU
if (old_beta == 0) {
beta = GetScalarZero(); // Device constant memory
} else if (old_beta == 1) {
beta = GetScalarOne(); // Device constant memory
} else {
// Move beta to workspace
NVTE_CHECK(workspaceSize >= 4,
"NVFP4 GEMM requires at least 4 byte workspace for beta scale, but only has ",
workspaceSize, " bytes remaining.");
workspaceSize = (workspaceSize / 4) * 4 - 4; // Remove last 4 aligned bytes
float *new_beta_ptr = reinterpret_cast<float *>(&workspace_ptr[workspaceSize]);
set_float_kernel<<<1, 1, 0, stream>>>(new_beta_ptr, old_beta);
NVTE_CHECK_CUDA(cudaGetLastError());
beta = new_beta_ptr;
}
}
  const cudaDataType_t A_type = get_cuda_dtype(param.Atype);
  const cudaDataType_t B_type = get_cuda_dtype(param.Btype);
@@ -270,16 +378,23 @@ void cublas_gemm(const Tensor *inputA, const Tensor *inputB, Tensor *outputD,
             "FP8 input to GEMM requires inverse of scale!");
  NVTE_CHECK(!is_fp8_dtype(param.Btype) || param.B_scale_inv != nullptr,
             "FP8 input to GEMM requires inverse of scale!");
NVTE_CHECK(!is_fp4_dtype(param.Atype) || param.A_scale_inv != nullptr,
"FP4 input to GEMM requires inverse of scale!");
NVTE_CHECK(!is_fp4_dtype(param.Btype) || param.B_scale_inv != nullptr,
"FP4 input to GEMM requires inverse of scale!");
  // check consistency of arguments:
  // if fp8 is desired, context cannot be null
  // fp8 + gelu fusion + fp8 aux is unavailable right now.
  if ((use_fp8 || use_fp4) && gelu) {
    NVTE_CHECK(!is_fp8_dtype(outputPreGelu->data.dtype),
               "fp8 Aux output for gemm + gelu fusion not supported!");
  }
  if (is_fp4_dtype(outputD->data.dtype)) {
    NVTE_ERROR("FP4 GEMM output is not supported!");
  }
  if (use_fp4 && (D_type == CUDA_R_16F)) {
    NVTE_ERROR("FP4 GEMM does not support FP16 output!");
  }
  cublasLtHandle_t handle = cublasHandleManager::Instance().GetHandle();
@@ -319,12 +434,14 @@ void cublas_gemm(const Tensor *inputA, const Tensor *inputB, Tensor *outputD,
                                                     &math_sm_count, sizeof(math_sm_count)));
  }
  // set fp8/fp4 attributes -- input and output types should already be set to fp8/fp4
  // as appropriate. Note: gelu fusion isn't available right now, and we don't need
  // amax(D) either (next op is high precision).
  const bool mxfp8_gemm = !use_fp4 && is_mxfp8_scaling(inputA->scaling_mode);

  if (use_fp8 || use_fp4) {
    // Fast accumulation is only supported for FP8.
    const int8_t fastAccuMode = (use_split_accumulator) ? 0 : use_fp8;
    NVTE_CHECK_CUBLAS(cublasLtMatmulDescSetAttribute(operationDesc, CUBLASLT_MATMUL_DESC_FAST_ACCUM,
                                                     &fastAccuMode, sizeof(fastAccuMode)));
@@ -333,7 +450,7 @@ void cublas_gemm(const Tensor *inputA, const Tensor *inputB, Tensor *outputD,
    cublasLtMatmulMatrixScale_t scaling_mode_a;
    cublasLtMatmulMatrixScale_t scaling_mode_b;
#endif  // CUBLAS_VERSION >= 120800
    if (is_tensor_scaling(inputA->scaling_mode) && is_tensor_scaling(inputB->scaling_mode)) {
      void *A_scale_inverse = param.A_scale_inv;
      void *B_scale_inverse = param.B_scale_inv;
      NVTE_CHECK_CUBLAS(cublasLtMatmulDescSetAttribute(operationDesc,
@@ -346,7 +463,7 @@ void cublas_gemm(const Tensor *inputA, const Tensor *inputB, Tensor *outputD,
      scaling_mode_a = CUBLASLT_MATMUL_MATRIX_SCALE_SCALAR_32F;
      scaling_mode_b = CUBLASLT_MATMUL_MATRIX_SCALE_SCALAR_32F;
#endif  // CUBLAS_VERSION >= 120800
    } else if (mxfp8_gemm) {
#if CUBLAS_VERSION >= 120800
      NVTE_CHECK(cublas_version() >= 120800,
                 "MXFP8 requires cuBLAS 12.8+, but run-time cuBLAS version is ", cublas_version());
@@ -371,6 +488,34 @@ void cublas_gemm(const Tensor *inputA, const Tensor *inputB, Tensor *outputD,
#else
      NVTE_ERROR("MXFP8 requires cuBLAS 12.8+, but compile-time cuBLAS version is ",
                 CUBLAS_VERSION);
#endif // CUBLAS_VERSION >= 120800
} else if (use_fp4) { // NVFP4 GEMM
#if CUBLAS_VERSION >= 120800
NVTE_CHECK(cublas_version() >= 120800,
"FP4 requires cuBLAS 12.8+, but run-time cuBLAS version is ", cublas_version());
      // Keep the alpha/beta computation dtype at FP32 via CUBLASLT_MATMUL_DESC_SCALE_TYPE
cublasDataType_t scale_type = CUDA_R_32F;
NVTE_CHECK_CUBLAS(cublasLtMatmulDescSetAttribute(
operationDesc, CUBLASLT_MATMUL_DESC_SCALE_TYPE, &scale_type, sizeof(scale_type)));
// Set pointer mode: alpha and beta are both device pointers
// https://docs.nvidia.com/cuda/cublas/#cublasltpointermode-t
cublasLtPointerMode_t pointer_mode = CUBLASLT_POINTER_MODE_DEVICE;
NVTE_CHECK_CUBLAS(cublasLtMatmulDescSetAttribute(
operationDesc, CUBLASLT_MATMUL_DESC_POINTER_MODE, &pointer_mode, sizeof(pointer_mode)));
fp8e4m3 *A_scale_inverse = reinterpret_cast<fp8e4m3 *>(param.A_scale_inv);
fp8e4m3 *B_scale_inverse = reinterpret_cast<fp8e4m3 *>(param.B_scale_inv);
NVTE_CHECK_CUBLAS(cublasLtMatmulDescSetAttribute(operationDesc,
CUBLASLT_MATMUL_DESC_A_SCALE_POINTER,
&A_scale_inverse, sizeof(A_scale_inverse)));
NVTE_CHECK_CUBLAS(cublasLtMatmulDescSetAttribute(operationDesc,
CUBLASLT_MATMUL_DESC_B_SCALE_POINTER,
&B_scale_inverse, sizeof(B_scale_inverse)));
scaling_mode_a = CUBLASLT_MATMUL_MATRIX_SCALE_VEC16_UE4M3;
scaling_mode_b = CUBLASLT_MATMUL_MATRIX_SCALE_VEC16_UE4M3;
#else
NVTE_ERROR("FP4 requires cuBLAS 12.8+, but compile-time cuBLAS version is ", CUBLAS_VERSION);
#endif  // CUBLAS_VERSION >= 120800
    } else if ((inputA->scaling_mode == NVTE_BLOCK_SCALING_1D ||
                inputA->scaling_mode == NVTE_BLOCK_SCALING_2D) &&
@@ -503,14 +648,11 @@ void cublas_gemm(const Tensor *inputA, const Tensor *inputB, Tensor *outputD,
#if !(CUDA_VERSION >= 12020 && CUDA_VERSION < 13000)
  NVTE_ERROR("Atomic GEMM requires CUDA >=12.2.0 and <13.0.0, but compile-time CUDA version is ",
             CUDA_VERSION);
#elif !(CUBLAS_VERSION >= 120205 && CUBLAS_VERSION < 130000)
  NVTE_ERROR(
      "Atomic GEMM requires cuBLAS >=12.2.5 and <13.0.0, but compile-time cuBLAS version is ",
      CUBLAS_VERSION);
#else
  NVTE_CHECK(cuda::cudart_version() >= 12020 && cuda::cudart_version() < 13000,
             "Atomic GEMM requires CUDA >=12.2.0 and <13.0.0, but run-time CUDA version is ",
             cuda::cudart_version());
@@ -565,16 +707,15 @@ void cublas_gemm(const Tensor *inputA, const Tensor *inputB, Tensor *outputD,
  if (returnedResults == 0) NVTE_ERROR("Unable to find any suitable algorithms");
  // D = alpha * (A * B) + beta * C
  NVTE_CHECK_CUBLAS(cublasLtMatmul(handle, operationDesc, alpha,  /* alpha */
                                   param.A,                       /* A */
                                   Adesc, param.B,                /* B */
                                   Bdesc, beta,                   /* beta */
                                   C,                             /* C */
                                   Cdesc, D,                      /* D */
                                   Ddesc, &heuristicResult.algo,  /* algo */
                                   workspace,                     /* workspace */
                                   workspaceSize, stream));       /* stream */
  // Update FP8 scale-inv in output tensor
  // Note: This is a WAR for the case when we have fp8 output but D->scale_inv is not allocated.
@@ -600,35 +741,117 @@ void nvte_cublas_gemm(const NVTETensor A, const NVTETensor B, NVTETensor D, cons
                      int math_sm_count, cudaStream_t stream) {
  NVTE_API_CALL(nvte_cublas_gemm);
  using namespace transformer_engine;
// Tensors
  const Tensor *inputA = convertNVTETensorCheck(A);
  const Tensor *inputB = convertNVTETensorCheck(B);
  Tensor *outputD = convertNVTETensorCheck(D);
  const Tensor *biasTensor = convertNVTETensor(bias);
  Tensor *outputGelu = convertNVTETensor(pre_gelu_out);
  Tensor *wspace = convertNVTETensor(workspace);
// Scales
const float alpha = 1;
const float beta = accumulate ? 1 : 0;
// Check for NVFP4
// TODO Remove once alpha scale logic is moved into cublas_gemm function
if (is_nvfp_scaling(inputA->scaling_mode) || is_nvfp_scaling(inputB->scaling_mode)) {
NVTE_ERROR("nvte_cublas_gemm does not support NVFP4 data. Use nvte_cublas_gemm_v2 instead.");
}
// Launch GEMM
  cublas_gemm(inputA, inputB, outputD, biasTensor, outputGelu, (transa) ? CUBLAS_OP_T : CUBLAS_OP_N,
              (transb) ? CUBLAS_OP_T : CUBLAS_OP_N, grad, wspace->data.dptr, wspace->data.shape[0],
              &alpha, &beta, use_split_accumulator, math_sm_count, 0, 0, false, nullptr, stream);
}
void nvte_cublas_gemm_v2(int transa, int transb, const float *alpha, const NVTETensor A,
const NVTETensor B, const float *beta, const NVTETensor C, NVTETensor D,
NVTETensor workspace, NVTEMatmulConfig config, cudaStream_t stream) {
NVTE_API_CALL(nvte_cublas_gemm_v2);
using namespace transformer_engine;
// Data tensors
const Tensor *A_tensor = convertNVTETensorCheck(A);
const Tensor *B_tensor = convertNVTETensorCheck(B);
const Tensor *C_tensor = convertNVTETensorCheck(C);
Tensor *D_tensor = convertNVTETensorCheck(D);
NVTE_CHECK(C_tensor == D_tensor,
"Currently nvte_cublas_gemm_v2 does not support different C and D tensors.");
// Workspace
void *workspace_ptr = nullptr;
size_t workspace_size = 0;
Tensor *workspace_tensor = convertNVTETensor(workspace);
if (workspace_tensor != nullptr) {
workspace_ptr = workspace_tensor->data.dptr;
workspace_size =
get_buffer_size_bytes(workspace_tensor->data.numel(), workspace_tensor->data.dtype);
}
// Additional config
MatmulConfig config_;
if (config != nullptr) {
config_ = *reinterpret_cast<MatmulConfig *>(config);
}
// Configure GEMM epilogue
const bool with_grad_epilogue = (config_.dbias_tensor != nullptr || config_.with_dgelu_epilogue);
if (with_grad_epilogue) {
NVTE_CHECK(config_.bias_tensor == nullptr && !config_.with_gelu_epilogue,
"Invalid epilogue (bias=", config_.bias_tensor != nullptr,
", dbias=", config_.dbias_tensor != nullptr, ", gelu=", config_.with_gelu_epilogue,
", dgelu=", config_.with_dgelu_epilogue, ").");
}
Tensor dummy_tensor;
Tensor *epilogue_bias_tensor = &dummy_tensor;
if (!with_grad_epilogue && config_.bias_tensor != nullptr) {
epilogue_bias_tensor = convertNVTETensorCheck(config_.bias_tensor);
} else if (with_grad_epilogue && config_.dbias_tensor != nullptr) {
epilogue_bias_tensor = convertNVTETensorCheck(config_.dbias_tensor);
}
Tensor *epilogue_aux_tensor = &dummy_tensor;
if (config_.with_gelu_epilogue || config_.with_dgelu_epilogue) {
NVTE_CHECK(config_.epilogue_aux_tensor != nullptr,
"Requested epilogue (bias=", config_.bias_tensor != nullptr,
", dbias=", config_.dbias_tensor != nullptr, ", gelu=", config_.with_gelu_epilogue,
", dgelu=", config_.with_dgelu_epilogue, ") without providing aux tensor.");
epilogue_aux_tensor = convertNVTETensor(config_.epilogue_aux_tensor);
}
// Launch GEMM
cublas_gemm(A_tensor, B_tensor, D_tensor, epilogue_bias_tensor, epilogue_aux_tensor,
transa ? CUBLAS_OP_T : CUBLAS_OP_N, transb ? CUBLAS_OP_T : CUBLAS_OP_N,
with_grad_epilogue, workspace_ptr, workspace_size, alpha, beta,
config_.use_split_accumulator, config_.sm_count, 0, 0, false, nullptr, stream);
} }
void nvte_cublas_gemm_scaled(const NVTETensor A, const NVTETensor B, NVTETensor D,
                             const NVTETensor bias, NVTETensor pre_gelu_out, bool transa,
                             bool transb, bool grad, NVTETensor workspace, float alpha, float beta,
                             bool use_split_accumulator, int math_sm_count, cudaStream_t stream) {
  NVTE_API_CALL(nvte_cublas_gemm_scaled);
  using namespace transformer_engine;
// Tensors
  const Tensor *inputA = convertNVTETensorCheck(A);
  const Tensor *inputB = convertNVTETensorCheck(B);
  Tensor *outputD = convertNVTETensorCheck(D);
  const Tensor *biasTensor = convertNVTETensor(bias);
  Tensor *outputGelu = convertNVTETensor(pre_gelu_out);
  Tensor *wspace = convertNVTETensor(workspace);
// Check for NVFP4
// TODO Remove once alpha scale logic is moved into cublas_gemm function
if (is_nvfp_scaling(inputA->scaling_mode) || is_nvfp_scaling(inputB->scaling_mode)) {
NVTE_ERROR("nvte_cublas_gemm does not support NVFP4 data. Use nvte_cublas_gemm_v2 instead.");
}
// Launch GEMM
  cublas_gemm(inputA, inputB, outputD, biasTensor, outputGelu, (transa) ? CUBLAS_OP_T : CUBLAS_OP_N,
              (transb) ? CUBLAS_OP_T : CUBLAS_OP_N, grad, wspace->data.dptr, wspace->data.shape[0],
              &alpha, &beta, use_split_accumulator, math_sm_count, 0, 0, false, nullptr, stream);
}
void nvte_cublas_atomic_gemm(const NVTETensor A, const NVTETensor B, NVTETensor D,
@@ -639,17 +862,14 @@ void nvte_cublas_atomic_gemm(const NVTETensor A, const NVTETensor B, NVTETensor
                             cudaStream_t stream) {
  NVTE_API_CALL(nvte_cublas_atomic_gemm);
  using namespace transformer_engine;
// Check CUDA and cuBLAS versions
#if !(CUDA_VERSION >= 12020 && CUDA_VERSION < 13000)
  NVTE_ERROR("Atomic GEMM requires CUDA >=12.2.0 and <13.0.0, but compile-time CUDA version is ",
             CUDA_VERSION);
#elif !(CUBLAS_VERSION >= 120205 && CUBLAS_VERSION < 130000)
  NVTE_ERROR(
      "Atomic GEMM requires cuBLAS >=12.2.5 and <13.0.0, but compile-time cuBLAS version is ",
      CUBLAS_VERSION);
#else
  NVTE_CHECK(
      transformer_engine::cuda::cudart_version() >= 12020 &&
          transformer_engine::cuda::cudart_version() < 13000,
@@ -668,13 +888,17 @@ void nvte_cublas_atomic_gemm(const NVTETensor A, const NVTETensor B, NVTETensor
  const Tensor *inputCounter = convertNVTETensor(counter);
  Tensor *wspace = convertNVTETensor(workspace);
const void *alpha_ptr = GetScalarOne();
const void *beta_ptr = accumulate ? GetScalarOne() : GetScalarZero();
NVTE_CHECK(is_delayed_tensor_scaling(inputA->scaling_mode) && NVTE_CHECK(is_delayed_tensor_scaling(inputA->scaling_mode) &&
is_delayed_tensor_scaling(inputB->scaling_mode), is_delayed_tensor_scaling(inputB->scaling_mode),
"Atomic GEMM only supports delayed scaling."); "Atomic GEMM only supports delayed scaling.");
cublas_gemm(inputA, inputB, outputD, biasTensor, outputGelu, (transa) ? CUBLAS_OP_T : CUBLAS_OP_N, cublas_gemm(inputA, inputB, outputD, biasTensor, outputGelu, (transa) ? CUBLAS_OP_T : CUBLAS_OP_N,
(transb) ? CUBLAS_OP_T : CUBLAS_OP_N, grad, wspace->data.dptr, wspace->data.shape[0], (transb) ? CUBLAS_OP_T : CUBLAS_OP_N, grad, wspace->data.dptr, wspace->data.shape[0],
1.0f, (accumulate) ? 1.0f : 0.0f, use_split_accumulator, math_sm_count, m_split, alpha_ptr, beta_ptr, use_split_accumulator, math_sm_count, m_split, n_split,
n_split, gemm_producer, inputCounter, stream); gemm_producer, inputCounter, stream);
#endif
} }
void multi_stream_cublas_gemm(const NVTETensor *A, const NVTETensor *B, NVTETensor *D,
@@ -695,9 +919,30 @@ void multi_stream_cublas_gemm(const NVTETensor *A, const NVTETensor *B, NVTETensor *D,
  }

  for (int i = 0; i < num_gemms; i++) {
    // Check whether GELU or dGELU epilogue is requested
    Tensor *pre_gelu_tensor = convertNVTETensor(pre_gelu_out[i]);
    bool with_gelu_dgelu_epilogue =
        (pre_gelu_tensor != nullptr && pre_gelu_tensor->data.dptr != nullptr);

    // Construct config
    MatmulConfig config;
    if (grad) {
      config.dbias_tensor = bias[i];
      config.with_dgelu_epilogue = with_gelu_dgelu_epilogue;
    } else {
      config.bias_tensor = bias[i];
      config.with_gelu_epilogue = with_gelu_dgelu_epilogue;
    }
    config.epilogue_aux_tensor = pre_gelu_out[i];
    config.use_split_accumulator = use_split_accumulator;
    config.sm_count = math_sm_count;

    // Launch GEMM
    const float alpha = 1.f;
    const float beta = accumulate ? 1.f : 0.f;
    nvte_cublas_gemm_v2(transa, transb, &alpha, A[i], B[i], &beta, D[i], D[i],
                        workspace[i % num_streams], &config,
                        detail::get_compute_stream(i % num_streams));
  }

  // record events on compute streams
/*************************************************************************
* Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
*
* See LICENSE for license information.
************************************************************************/
#include <cuda.h>
#include <cudaTypedefs.h>
#include <cuda_bf16.h>
#include <cuda_pipeline.h>
#include <cuda_runtime.h>
#include <transformer_engine/hadamard_transform.h>
#include <cuda/barrier>
#include "common/common.h"
#include "common/util/ptx.cuh"
#include "common/utils.cuh"
namespace transformer_engine {
namespace {
constexpr int kThreadsPerWarp = 32;
constexpr float k16x16HadamardScale = 0.25f;
template <bool kTranspose>
__device__ __forceinline__ void ldmatrix_x4_m8n8_shared_b16(uint32_t& a0, uint32_t& a1,
uint32_t& a2, uint32_t& a3,
void* addr) {
auto smem_addr = static_cast<uint32_t>(__cvta_generic_to_shared(addr));
if constexpr (kTranspose) {
asm volatile("ldmatrix.sync.aligned.x4.trans.m8n8.shared.b16 {%0,%1,%2,%3}, [%4];\n"
: "=r"(a0), "=r"(a1), "=r"(a2), "=r"(a3)
: "r"(smem_addr));
} else {
asm volatile("ldmatrix.sync.aligned.x4.m8n8.shared.b16 {%0,%1,%2,%3}, [%4];\n"
: "=r"(a0), "=r"(a1), "=r"(a2), "=r"(a3)
: "r"(smem_addr));
}
}
template <bool kTranspose>
__device__ __forceinline__ void load_matrix_16x16_from_shared(uint32_t& a0, uint32_t& a1,
uint32_t& a2, uint32_t& a3,
void* addr, uint32_t stride) {
if constexpr (kTranspose) {
asm volatile(
"wmma.load.a.sync.aligned.col.m16n16k16.shared::cta.bf16 "
"{%0,%1,%2,%3}, [%4], %5;\n"
: "=r"(a0), "=r"(a1), "=r"(a2), "=r"(a3)
: "l"(addr), "r"(stride));
} else {
asm volatile(
"wmma.load.a.sync.aligned.row.m16n16k16.shared::cta.bf16 "
"{%0,%1,%2,%3}, [%4], %5;\n"
: "=r"(a0), "=r"(a1), "=r"(a2), "=r"(a3)
: "l"(addr), "r"(stride));
}
}
template <bool kTranspose>
__device__ __forceinline__ void store_matrix_16x16_to_global(uint32_t& a0, uint32_t& a1,
uint32_t& a2, uint32_t& a3, void* addr,
uint32_t stride) {
if constexpr (kTranspose) {
asm volatile("wmma.store.d.sync.aligned.col.m16n16k16.global.f16 [%0], {%1, %2, %3, %4}, %5;\n"
:
: "l"(addr), "r"(a0), "r"(a1), "r"(a2), "r"(a3), "r"(stride));
} else {
asm volatile("wmma.store.d.sync.aligned.row.m16n16k16.global.f16 [%0], {%1, %2, %3, %4}, %5;\n"
:
: "l"(addr), "r"(a0), "r"(a1), "r"(a2), "r"(a3), "r"(stride));
}
}
__device__ __forceinline__ void matrix_transpose_m8_n8_b16_inplace(uint32_t& a0) {
asm volatile(
"movmatrix.sync.aligned.m8n8.trans.b16 "
"%0, %1;\n\t"
: "=r"(a0)
: "r"(a0));
}
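// Reduces a packed bf16x2 value to max(|x|, |y|) as FP32. max.xorsign.abs
// returns the larger magnitude with sign = sign(x) XOR sign(y), so the final
// fabsf is needed. Worked example: packed {x = -3.5, y = 2.0} -> max.xorsign.abs
// yields -3.5 -> fabsf yields 3.5 = max(|x|, |y|).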
__device__ __forceinline__ void unpack_max_of_packed_bf16(uint32_t& packed_bf16, float& float_dst) {
__nv_bfloat162 bf16x2 = *reinterpret_cast<__nv_bfloat162*>(&packed_bf16);
float f_a = __bfloat162float(bf16x2.x);
float f_b = __bfloat162float(bf16x2.y);
asm volatile("max.xorsign.abs.f32 %0, %1, %2;\n\t" : "=f"(float_dst) : "f"(f_a), "f"(f_b));
float_dst = fabsf(float_dst);
}
template <bool kCalculateAmax>
__device__ __forceinline__ void mma_m16_n16_k16_b16_b16_b16_noacc(
uint32_t& a0, uint32_t& a1, uint32_t& a2, uint32_t& a3, uint32_t& b0, uint32_t& b1,
uint32_t& b2, uint32_t& b3, uint32_t& c0, uint32_t& c1, uint32_t& c2, uint32_t& c3,
uint32_t& amax_result) {
uint32_t zero = 0;
uint32_t temp0, temp1, temp2, temp3, temp4, temp5, temp6, temp7;
asm volatile(
"wmma.mma.sync.aligned.row.row.m16n16k16.f32.bf16.bf16.f32 \n"
"{%0, %1, %2, %3, %4, %5, %6, %7}, \n"
"{%8, %9, %10, %11}, \n"
"{%12, %13, %14, %15}, \n"
"{%16, %17, %18, %19, %20, %21, %22, %23};\n\t"
: "=r"(temp0), "=r"(temp1), "=r"(temp2), "=r"(temp3), "=r"(temp4), "=r"(temp5), "=r"(temp6),
"=r"(temp7)
: "r"(a0), "r"(a1), "r"(a2), "r"(a3), "r"(b0), "r"(b1), "r"(b2), "r"(b3), "r"(zero),
"r"(zero), "r"(zero), "r"(zero), "r"(zero), "r"(zero), "r"(zero), "r"(zero));
asm volatile("cvt.rn.bf16x2.f32 %0, %1, %2;\n\t" : "=r"(c0) : "r"(temp1), "r"(temp0));
asm volatile("cvt.rn.bf16x2.f32 %0, %1, %2;\n\t" : "=r"(c1) : "r"(temp3), "r"(temp2));
asm volatile("cvt.rn.bf16x2.f32 %0, %1, %2;\n\t" : "=r"(c2) : "r"(temp5), "r"(temp4));
asm volatile("cvt.rn.bf16x2.f32 %0, %1, %2;\n\t" : "=r"(c3) : "r"(temp7), "r"(temp6));
if constexpr (kCalculateAmax) {
uint32_t max_even;
uint32_t max_odd;
// Reduction tree to amax(abs(result)) into bf16x2 reg outparam.
asm volatile("max.xorsign.abs.bf16x2 %0, %1, %2;\n\t" : "=r"(max_even) : "r"(c0), "r"(c2));
asm volatile("max.xorsign.abs.bf16x2 %0, %1, %2;\n\t" : "=r"(max_odd) : "r"(c1), "r"(c3));
    // N.B. mma is called at most once per thread for the identity and transpose paths
    // respectively, so we don't have to accumulate into amax_result and can store into it directly.
asm volatile("max.xorsign.abs.bf16x2 %0, %1, %2;\n\t"
: "=r"(amax_result)
: "r"(max_even), "r"(max_odd));
}
}
template <bool kReturnIdentity, bool kReturnTransposed, bool kInverseHadamardIdentity,
bool kInverseHadamardTransposed>
__device__ __forceinline__ void get_hadamard_matrix_fragment(uint32_t* had_frag_i,
uint16_t random_sign_mask,
uint32_t* had_frag_t,
uint16_t random_sign_mask_t) {
int32_t tid = threadIdx.x % 32; // Local tid
float temp_i[2];
float temp_t[2];
#pragma unroll
for (int i = 0; i < 2; i++) {
// i is the vertical fragment index.
    // For a 16x16 matrix fragment, 4 threads fill a fragment of 8 BF16 vals.
uint32_t r = i * 8 + tid / 4;
#pragma unroll
for (int j = 0; j < 2; j++) {
#pragma unroll
for (int k = 0; k < 2; k++) {
// k is column position [0, 1] within a quad of 2 BF16s stored together in 32 bits.
// j is the column fragment idx selecting between even and odd fragments.
// j increments 8 columns by switching fragments.
uint32_t c = j * 8 + k + tid % 4 * 2;
// 1 -> -1.0f, 0 -> 1.0f
int32_t base_sign = __popc(r & c);
if constexpr (kReturnIdentity) {
int32_t sign_i;
          // Because tensor cores want the dot product dimension
          // contiguous, the regular (non-inverse) Hadamard applies the
          // random signs to columns, while the inverse applies them to
          // rows. In a simple reference, x.reshape(-1, 16) @ sign @ H16,
          // it would be the opposite, but (sign @ H16) is transposed in
          // this fragment.
if constexpr (kInverseHadamardIdentity) {
sign_i = ((random_sign_mask >> r) ^ base_sign);
} else {
sign_i = ((random_sign_mask >> c) ^ base_sign);
}
temp_i[k] = copysignf(k16x16HadamardScale, __int_as_float(sign_i << 31));
}
if constexpr (kReturnTransposed) {
int32_t sign_t;
if constexpr (kInverseHadamardTransposed) {
sign_t = ((random_sign_mask_t >> r) ^ base_sign);
} else {
sign_t = ((random_sign_mask_t >> c) ^ base_sign);
}
temp_t[k] = copysignf(k16x16HadamardScale, __int_as_float(sign_t << 31));
}
}
if constexpr (kReturnIdentity) {
asm volatile("cvt.rn.bf16x2.f32 %0, %1, %2;\n\t"
: "=r"(had_frag_i[i * 2 + j])
: "f"(temp_i[1]), "f"(temp_i[0]));
}
if constexpr (kReturnTransposed) {
asm volatile("cvt.rn.bf16x2.f32 %0, %1, %2;\n\t"
: "=r"(had_frag_t[i * 2 + j])
: "f"(temp_t[1]), "f"(temp_t[0]));
}
}
}
}
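// Hypothetical host-side reference for the fragments above (illustrative only,
// not used by any kernel): entry (r, c) of the signed, normalized 16x16
// Hadamard matrix is H[r][c] = (-1)^popcount(r & c) / 4, with random_sign_mask
// flipping the sign of column c (or of row r for the inverse transform).
[[maybe_unused]] inline float ReferenceSignedHadamardEntry(uint32_t r, uint32_t c,
                                                           uint16_t random_sign_mask) {
  const int base_sign = __builtin_popcount(r & c) & 1;  // (-1)^popcount(r & c)
  const int flip = (random_sign_mask >> c) & 1;         // per-column random sign
  return ((base_sign ^ flip) ? -1.0f : 1.0f) * k16x16HadamardScale;
}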
__device__ __forceinline__ uint32_t swizzle_128B_atom_32B(uint32_t gmem_row_idx,
uint32_t gmem_col_idx) {
uint32_t smem_row_idx = gmem_row_idx;
uint32_t xor_factor = (smem_row_idx * 2) % 8;
uint32_t smem_col_idx = gmem_col_idx ^ xor_factor;
return smem_row_idx * 8 + smem_col_idx;
}
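// Worked example: xor_factor = (2 * row) % 8, so row 0 is the identity and
// row 1 maps 32B columns {0,1,2,3,4,5,6,7} -> {2,3,0,1,6,7,4,5}; the pattern
// repeats every 4 rows. This matches the CU_TENSOR_MAP_SWIZZLE_128B_ATOM_32B
// mode used for the TMA loads below.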
template <typename IType, int kHadamardDimension, int BUFF_DIM_Y, int BUFF_DIM_X,
bool kReturnPreRhtAmax, bool kReturnIdentityAmax, bool kReturnTransposedAmax>
__device__ __forceinline__ void ComputeKernel(uint32_t b_frag_i[4], uint32_t b_frag_t[4],
IType* in_sh_ptr, uint32_t& local_pre_rht_amax_reg,
uint32_t& local_amax_reg,
uint32_t& local_amax_t_reg) {
uint32_t a_frag[4]; // A matrix fragment
uint32_t c_frag[4]; // Result fragment
int warp_id = threadIdx.x / kThreadsPerWarp;
int local_rank = (threadIdx.x % kThreadsPerWarp);
int ld_row_idx = local_rank % kHadamardDimension;
int ld_col_idx = local_rank / kHadamardDimension + warp_id * 2;
int swizzle_idx = swizzle_128B_atom_32B(ld_row_idx, ld_col_idx);
uint32_t temp_amax_reg;
uint32_t temp_amax_t_reg;
if (kReturnIdentityAmax) {
ldmatrix_x4_m8n8_shared_b16<false>(a_frag[0], a_frag[1], a_frag[2], a_frag[3],
reinterpret_cast<uint4*>(in_sh_ptr) + swizzle_idx);
mma_m16_n16_k16_b16_b16_b16_noacc<kReturnIdentityAmax>(
a_frag[0], a_frag[1], a_frag[2], a_frag[3], b_frag_i[0], b_frag_i[1], b_frag_i[2],
b_frag_i[3], c_frag[0], c_frag[1], c_frag[2], c_frag[3], temp_amax_reg);
asm volatile("max.xorsign.abs.bf16x2 %0, %1, %2;\n\t"
: "=r"(local_amax_reg)
: "r"(local_amax_reg), "r"(temp_amax_reg));
}
if (kReturnTransposedAmax) {
// TODO(Frank): This is not efficient, since we could directly load the
// matrix in transposed layout.
if (!kReturnIdentityAmax) {
ldmatrix_x4_m8n8_shared_b16<false>(a_frag[0], a_frag[1], a_frag[2], a_frag[3],
reinterpret_cast<uint4*>(in_sh_ptr) + swizzle_idx);
}
matrix_transpose_m8_n8_b16_inplace(a_frag[0]);
matrix_transpose_m8_n8_b16_inplace(a_frag[1]);
matrix_transpose_m8_n8_b16_inplace(a_frag[2]);
matrix_transpose_m8_n8_b16_inplace(a_frag[3]);
mma_m16_n16_k16_b16_b16_b16_noacc<kReturnTransposedAmax>(
a_frag[0], a_frag[2], a_frag[1], a_frag[3], b_frag_t[0], b_frag_t[1], b_frag_t[2],
b_frag_t[3], c_frag[0], c_frag[1], c_frag[2], c_frag[3], temp_amax_t_reg);
asm volatile("max.xorsign.abs.bf16x2 %0, %1, %2;\n\t"
: "=r"(local_amax_t_reg)
: "r"(local_amax_t_reg), "r"(temp_amax_t_reg));
}
if (kReturnPreRhtAmax) {
if (!kReturnIdentityAmax && !kReturnTransposedAmax) {
ldmatrix_x4_m8n8_shared_b16<false>(a_frag[0], a_frag[1], a_frag[2], a_frag[3],
reinterpret_cast<uint4*>(in_sh_ptr) + swizzle_idx);
}
asm volatile("max.xorsign.abs.bf16x2 %0, %1, %2;\n\t"
: "=r"(a_frag[0])
: "r"(a_frag[0]), "r"(a_frag[1]));
asm volatile("max.xorsign.abs.bf16x2 %0, %1, %2;\n\t"
: "=r"(a_frag[2])
: "r"(a_frag[2]), "r"(a_frag[3]));
asm volatile("max.xorsign.abs.bf16x2 %0, %1, %2;\n\t"
: "=r"(a_frag[0])
: "r"(a_frag[0]), "r"(a_frag[2]));
asm volatile("max.xorsign.abs.bf16x2 %0, %1, %2;\n\t"
: "=r"(local_pre_rht_amax_reg)
: "r"(a_frag[0]), "r"(local_pre_rht_amax_reg));
}
}
template <int kN>
__device__ __host__ constexpr int NextPowerOf2() {
  static_assert(kN > 0, "kN must be > 0");
  // Round up to the next power of 2 by counting leading zeros.
  // __builtin_clz(0) is undefined, so kN == 1 is handled explicitly.
  return kN == 1 ? 1 : 1 << (32 - __builtin_clz(kN - 1));
}
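// Compile-time sanity checks (illustrative only).
static_assert(NextPowerOf2<1>() == 1, "1 is already a power of 2");
static_assert(NextPowerOf2<5>() == 8, "5 rounds up to 8");
static_assert(NextPowerOf2<8>() == 8, "powers of 2 map to themselves");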
template <int kNumWarps, bool kReturnPreRhtAmax, bool kReturnIdentityAmax,
bool kReturnTransposedAmax>
__device__ __forceinline__ void ReduceMax(const float pre_rht_amax, const float identity_amax,
const float transpose_amax, float* staging_for_pre_rht,
float* staging_for_identity, float* staging_for_transpose,
float* output_pre_rht_amax_ptr,
float* output_identity_amax_ptr,
float* output_transpose_amax_ptr, const int warpid) {
// intra-warp reduction
constexpr int kWarpSize = 32;
int local_rank = threadIdx.x % 32;
float warp_pre_rht_amax = kReturnPreRhtAmax ? warp_reduce_max<kWarpSize>(pre_rht_amax) : 0.0f;
float warp_identity_amax = kReturnIdentityAmax ? warp_reduce_max<kWarpSize>(identity_amax) : 0.0f;
float warp_transpose_amax =
kReturnTransposedAmax ? warp_reduce_max<kWarpSize>(transpose_amax) : 0.0f;
// inter-warp reduction
if (threadIdx.x % 32 == 0) {
if (kReturnPreRhtAmax) {
staging_for_pre_rht[warpid] = warp_pre_rht_amax;
}
if (kReturnIdentityAmax) {
staging_for_identity[warpid] = warp_identity_amax;
}
if (kReturnTransposedAmax) {
staging_for_transpose[warpid] = warp_transpose_amax;
}
}
__syncthreads();
constexpr int kNumWarpsPow2 = NextPowerOf2<kNumWarps>();
if (warpid == 0) {
if (kReturnIdentityAmax) {
float identity_accum = local_rank < kNumWarps ? staging_for_identity[local_rank] : 0.0f;
identity_accum = warp_reduce_max<kNumWarpsPow2>(identity_accum);
if (local_rank == 0) {
atomicMaxFloat(output_identity_amax_ptr, identity_accum);
}
}
}
if (warpid == 1) {
if (kReturnTransposedAmax) {
float transpose_accum = local_rank < kNumWarps ? staging_for_transpose[local_rank] : 0.0f;
transpose_accum = warp_reduce_max<kNumWarpsPow2>(transpose_accum);
if (local_rank == 0) {
atomicMaxFloat(output_transpose_amax_ptr, transpose_accum);
}
}
}
if (warpid == 2) {
if (kReturnPreRhtAmax) {
float pre_rht_accum = local_rank < kNumWarps ? staging_for_pre_rht[local_rank] : 0.0f;
pre_rht_accum = warp_reduce_max<kNumWarpsPow2>(pre_rht_accum);
if (local_rank == 0) {
atomicMaxFloat(output_pre_rht_amax_ptr, pre_rht_accum);
}
}
}
}
__launch_bounds__(1) __global__ void ZeroAmaxKernel(float* __restrict__ output_pre_rht_amax_ptr,
float* __restrict__ output_identity_amax_ptr,
float* __restrict__ output_transpose_amax_ptr) {
if (output_pre_rht_amax_ptr != nullptr) {
*output_pre_rht_amax_ptr = 0;
}
if (output_identity_amax_ptr != nullptr) {
*output_identity_amax_ptr = 0;
}
if (output_transpose_amax_ptr != nullptr) {
*output_transpose_amax_ptr = 0;
}
}
template <typename IType, int kHadamardDimension, int CHUNK_DIM_Y, int CHUNK_DIM_X, int BUFF_DIM_Y,
int BUFF_DIM_X, int THREADS_PER_CHUNK, int THREADS_PER_Y, bool kReturnPreRhtAmax,
bool kReturnIdentityAmax, bool kReturnTransposedAmax>
__global__ void HadamardAmaxTmaKernel(const __grid_constant__ CUtensorMap tensor_map_input,
float* __restrict__ output_pre_rht_amax_ptr,
float* __restrict__ output_identity_amax_ptr,
float* __restrict__ output_transpose_amax_ptr,
uint16_t random_sign_mask, uint16_t random_sign_mask_t,
uint64_t num_rows, uint64_t row_length) {
#if (defined __CUDA_ARCH__) && (__CUDA_ARCH__ >= 1000)
static_assert(CHUNK_DIM_Y >= BUFF_DIM_Y && CHUNK_DIM_Y % BUFF_DIM_Y == 0);
static_assert(CHUNK_DIM_X >= BUFF_DIM_X && CHUNK_DIM_X % BUFF_DIM_X == 0);
constexpr size_t STAGES_Y = CHUNK_DIM_Y / BUFF_DIM_Y;
constexpr size_t STAGES_X = CHUNK_DIM_X / BUFF_DIM_X;
constexpr int kNumWarps = (THREADS_PER_CHUNK * THREADS_PER_Y) / kThreadsPerWarp;
const int input_block_offset_Y = blockIdx.y * CHUNK_DIM_Y;
const int input_block_offset_X = blockIdx.x * CHUNK_DIM_X;
extern __shared__ __align__(128) char dynamic_shmem[];
uintptr_t base_shmem_ptr = reinterpret_cast<uintptr_t>(dynamic_shmem);
// Manually align dynamic SHMEM per TMA requirements using padding
// __align__(128) Does not guarantee the pointer to be aligned!
uint8_t* dshmem = reinterpret_cast<uint8_t*>((base_shmem_ptr + 127) & ~127ULL);
// The destination shared memory buffer of a bulk tensor operation should be 16-byte aligned
constexpr size_t in_buff_size = BUFF_DIM_X * BUFF_DIM_Y * sizeof(IType);
IType* in_sh_0 = reinterpret_cast<IType*>(dshmem);
dshmem += in_buff_size;
IType* in_sh_1 = reinterpret_cast<IType*>(dshmem);
dshmem += in_buff_size;
IType* in_shs[2] = {in_sh_0, in_sh_1};
constexpr int shmem_buff_size = BUFF_DIM_X * BUFF_DIM_Y * sizeof(IType);
const bool is_master_thread = (threadIdx.x == 0 && threadIdx.y == 0);
// Initialize shared memory barrier with the number of threads participating in the barrier.
#pragma nv_diag_suppress static_var_with_dynamic_init
uint64_t* mbar = reinterpret_cast<uint64_t*>(dshmem);
dshmem += sizeof(uint64_t) * (STAGES_X * STAGES_Y);
float* max_staging_identity = reinterpret_cast<float*>(dshmem);
dshmem += sizeof(float) * kNumWarps;
float* max_staging_transpose = reinterpret_cast<float*>(dshmem);
dshmem += sizeof(float) * kNumWarps;
float* max_staging_pre_rht = reinterpret_cast<float*>(dshmem);
dshmem += sizeof(float) * kNumWarps;
initialize_barriers<STAGES_X * STAGES_Y, THREADS_PER_CHUNK * THREADS_PER_Y>(mbar,
is_master_thread);
copy_2d_to_shared(in_shs[0], reinterpret_cast<const void*>(&tensor_map_input),
input_block_offset_X, input_block_offset_Y, shmem_buff_size, &mbar[0],
is_master_thread);
uint32_t had_frag_i[4];
uint32_t had_frag_t[4];
get_hadamard_matrix_fragment<kReturnIdentityAmax, kReturnTransposedAmax, false, false>(
had_frag_i, random_sign_mask, had_frag_t, random_sign_mask_t);
float local_pre_rht_amax = 0.0;
float local_amax = 0.0;
float local_amax_t = 0.0;
uint32_t local_pre_rht_amax_reg = *reinterpret_cast<uint32_t*>(&local_pre_rht_amax);
uint32_t local_amax_reg = *reinterpret_cast<uint32_t*>(&local_amax);
uint32_t local_amax_t_reg = *reinterpret_cast<uint32_t*>(&local_amax_t);
for (int stage_y = 0; stage_y < STAGES_Y; ++stage_y) {
for (int stage_x = 0; stage_x < STAGES_X; ++stage_x) {
int stage = STAGES_X * stage_y + stage_x;
const int next_stage = stage + 1;
const int next_stage_x = stage_x + 1 == STAGES_X ? 0 : stage_x + 1;
const int next_stage_y = stage_x + 1 == STAGES_X ? stage_y + 1 : stage_y;
if (next_stage < STAGES_X * STAGES_Y) {
const int input_global_offset_Y = input_block_offset_Y + next_stage_y * BUFF_DIM_Y;
const int input_global_offset_X = input_block_offset_X + next_stage_x * BUFF_DIM_X;
copy_2d_to_shared(in_shs[next_stage % 2], // ping-pong
reinterpret_cast<const void*>(&tensor_map_input), input_global_offset_X,
input_global_offset_Y, shmem_buff_size, &mbar[next_stage],
is_master_thread);
}
ptx::fence_proxy_async_shared_cta();
// Wait for the data to have arrived
ptx::mbarrier_wait_parity(&mbar[stage], 0);
const size_t compute_stage_x_num =
BUFF_DIM_X / (kHadamardDimension * (THREADS_PER_CHUNK / kThreadsPerWarp));
const size_t compute_stage_y_num = BUFF_DIM_Y / (kHadamardDimension * THREADS_PER_Y);
const size_t in_row_stride = BUFF_DIM_X;
IType* in_sh_ptr = in_shs[stage % 2];
#pragma unroll
for (size_t compute_stage_y = 0; compute_stage_y < compute_stage_y_num; compute_stage_y++) {
const int row_idx_offset = (compute_stage_y * kHadamardDimension * THREADS_PER_Y +
threadIdx.y * kHadamardDimension);
const int in_row_offset = row_idx_offset * in_row_stride;
#pragma unroll
for (size_t compute_stage_x = 0; compute_stage_x < compute_stage_x_num; compute_stage_x++) {
ComputeKernel<IType, kHadamardDimension, BUFF_DIM_Y, BUFF_DIM_X, kReturnPreRhtAmax,
kReturnIdentityAmax, kReturnTransposedAmax>(
had_frag_i, had_frag_t,
in_sh_ptr + in_row_offset +
(compute_stage_x * kHadamardDimension * (THREADS_PER_CHUNK / kThreadsPerWarp)),
local_pre_rht_amax_reg, local_amax_reg, local_amax_t_reg);
}
// Ensure all threads have finished their computation before new data over-writes the shared
// memory.
__syncthreads();
}
}
}
const int warpid = (threadIdx.x + threadIdx.y * blockDim.x) / kThreadsPerWarp;
if constexpr (kReturnPreRhtAmax) {
unpack_max_of_packed_bf16(local_pre_rht_amax_reg, local_pre_rht_amax);
}
if constexpr (kReturnIdentityAmax) {
unpack_max_of_packed_bf16(local_amax_reg, local_amax);
}
if constexpr (kReturnTransposedAmax) {
unpack_max_of_packed_bf16(local_amax_t_reg, local_amax_t);
}
ReduceMax<kNumWarps, kReturnPreRhtAmax, kReturnIdentityAmax, kReturnTransposedAmax>(
local_pre_rht_amax, local_amax, local_amax_t, max_staging_pre_rht, max_staging_identity,
max_staging_transpose, output_pre_rht_amax_ptr, output_identity_amax_ptr,
output_transpose_amax_ptr, warpid);
destroy_barriers<STAGES_X * STAGES_Y>(mbar, is_master_thread);
#else
NVTE_DEVICE_ERROR("Kernel is only supported on SM 10.0+.");
#endif // #if (defined __CUDA_ARCH__) && (__CUDA_ARCH__ >= 1000)
}
template <typename T, int kHadamardDimension, bool kComputeIdentity, bool kComputeTransposed,
bool kReturnIdentity, bool kReturnTransposed, bool kUpdateIdentityAmax,
bool kUpdateTransposeAmax, bool kOutputTrueTransposed>
__global__ void HadamardTransformKernel(const T* __restrict__ input, T* __restrict__ output,
T* __restrict__ output_t, uint16_t random_sign_mask,
uint16_t random_sign_mask_t, uint64_t num_input_rows,
uint64_t num_input_cols, float* __restrict__ amax,
float* __restrict__ amax_t, bool inverse_hadamard) {
#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 900
static_assert(kHadamardDimension == 16, "Currently only hadamard dimension 16 is supported.");
// The whole threadblock will share the same smem.
extern __shared__ __align__(16) T smem[];
  // Each group of 32 threads processes one 16x16 matrix; the (y, z) thread grid tiles the 16x16 blocks.
  // If y = 4, z = 4, then each threadblock processes a 4x4 grid of 16x16 matrices.
int32_t tid = threadIdx.x;
int32_t warp_id = threadIdx.y * blockDim.z + threadIdx.z;
int32_t local_bx = threadIdx.y;
int32_t local_by = threadIdx.z;
// Define the register fragments
uint32_t a_frag[4]; // A matrix fragment
uint32_t b_frag_i[4]; // Transposed Hadamard matrix fragment, used for A @ B(col major)
uint32_t b_frag_t[4]; // Hadamard matrix fragment, used for A.T @ B.T(col major)
uint32_t c_frag[4]; // Result fragment
  // row and col for each thread. The 32 threads of a warp cooperate, each
  // copying one 128-bit (uint4) chunk from global memory to shared memory.
uint32_t row = tid / (kHadamardDimension * sizeof(T) / sizeof(uint4));
uint32_t col = tid % (kHadamardDimension * sizeof(T) / sizeof(uint4));
uint32_t smem_index = tid;
uint32_t input_start_col = (blockIdx.x * blockDim.y + local_bx) * kHadamardDimension;
uint32_t input_start_row = (blockIdx.y * blockDim.z + local_by) * kHadamardDimension;
bool load = (input_start_col < num_input_cols) && (input_start_row < num_input_rows);
if (!load) {
    // Out of bounds; return early. There is no thread divergence since the whole
    // warp returns early.
return;
}
uint64_t global_offset = input_start_col + input_start_row * num_input_cols;
uint64_t global_offset_t =
kOutputTrueTransposed ? (input_start_row + input_start_col * num_input_rows) : global_offset;
T* base_smem = smem + kHadamardDimension * kHadamardDimension * warp_id;
uint32_t* smem_b32 = reinterpret_cast<uint32_t*>(base_smem);
uint4* smem_b128 = reinterpret_cast<uint4*>(base_smem);
// Asynchronously load the data from global memory to shared memory.
const uint4* input_b128 = reinterpret_cast<const uint4*>(input + global_offset);
  // Each 16x16 chunk is divided into four 8x8 matrices. We load each 8x8 chunk
  // consecutively into smem so we can leverage ldmatrix m8n8x4 to read the data
  // in the tensor-core swizzled format.
__pipeline_memcpy_async(&smem_b128[smem_index],
&input_b128[row * num_input_cols / (sizeof(uint4) / sizeof(T)) + col],
sizeof(uint4));
__pipeline_commit(); // Commit the memcpy. Wait when we are in the computation.
if (inverse_hadamard) {
get_hadamard_matrix_fragment<kComputeIdentity, kComputeTransposed,
                                 /*kInverseHadamardIdentity=*/true,
/*kInverseHadamardTransposed=*/true>(b_frag_i, random_sign_mask,
b_frag_t, random_sign_mask_t);
} else {
get_hadamard_matrix_fragment<kComputeIdentity, kComputeTransposed,
                                 /*kInverseHadamardIdentity=*/false,
/*kInverseHadamardTransposed=*/false>(
b_frag_i, random_sign_mask, b_frag_t, random_sign_mask_t);
}
float local_amax = 0.0;
float local_amax_t = 0.0;
uint32_t local_amax_reg = *reinterpret_cast<uint32_t*>(&local_amax);
uint32_t local_amax_t_reg = *reinterpret_cast<uint32_t*>(&local_amax_t);
__pipeline_wait_prior(0);
__syncwarp(); // ensure all lanes finished their cp.async before reading smem
// Load the A to a_frag.
if constexpr (kComputeIdentity) {
load_matrix_16x16_from_shared<false>(a_frag[0], a_frag[1], a_frag[2], a_frag[3], smem_b32,
kHadamardDimension);
// 16x16 @ 16x16 leveraging all threads in the warp.
mma_m16_n16_k16_b16_b16_b16_noacc<kUpdateIdentityAmax>(
a_frag[0], a_frag[1], a_frag[2], a_frag[3], b_frag_i[0], b_frag_i[1], b_frag_i[2],
b_frag_i[3], c_frag[0], c_frag[1], c_frag[2], c_frag[3], local_amax_reg);
    // Store the result to global memory in non-transposed order.
if constexpr (kReturnIdentity) {
uint4* output_b128 = reinterpret_cast<uint4*>(output + global_offset);
store_matrix_16x16_to_global<false>(c_frag[0], c_frag[1], c_frag[2], c_frag[3], output_b128,
num_input_cols);
}
}
if constexpr (kComputeTransposed) {
if (kComputeIdentity) {
matrix_transpose_m8_n8_b16_inplace(a_frag[0]);
matrix_transpose_m8_n8_b16_inplace(a_frag[1]);
matrix_transpose_m8_n8_b16_inplace(a_frag[2]);
matrix_transpose_m8_n8_b16_inplace(a_frag[3]);
} else {
load_matrix_16x16_from_shared<true>(a_frag[0],
a_frag[2], // NOTE: intentional index swapping
a_frag[1], // NOTE: intentional index swapping
a_frag[3], smem_b32, kHadamardDimension);
}
mma_m16_n16_k16_b16_b16_b16_noacc<kUpdateTransposeAmax>(
a_frag[0],
        // The 2,1 fragment order matches what the movmatrix path above
        // produces, so loading the matrix in 2,1 order keeps both paths
        // compatible with the same mma operand layout.
a_frag[2], // NOTE: intentional index swapping for transpose purpose.
a_frag[1], // NOTE: intentional index swapping for transpose purpose.
a_frag[3], b_frag_t[0], b_frag_t[1], b_frag_t[2], b_frag_t[3], c_frag[0], c_frag[1],
c_frag[2], c_frag[3], local_amax_t_reg);
    // Store the result to global memory (transposed layout when kOutputTrueTransposed).
if constexpr (kReturnTransposed) {
uint4* output_t_b128 = reinterpret_cast<uint4*>(output_t + global_offset_t);
store_matrix_16x16_to_global<!kOutputTrueTransposed>(
c_frag[0], c_frag[1], c_frag[2], c_frag[3], output_t_b128,
kOutputTrueTransposed ? num_input_rows : num_input_cols);
}
}
if constexpr (kUpdateIdentityAmax) {
unpack_max_of_packed_bf16(local_amax_reg, local_amax);
local_amax = warp_reduce_max<kThreadsPerWarp>(local_amax);
// broadcast the amax to all threads in a warp from the lane 0
constexpr int lane_zero = 0;
local_amax = __shfl_sync(0xFFFFFFFF, local_amax, lane_zero);
// atomic CAS to output memory.
if (tid % kThreadsPerWarp == 0) {
atomicMaxFloat(amax, local_amax);
}
}
if constexpr (kUpdateTransposeAmax) {
unpack_max_of_packed_bf16(local_amax_t_reg, local_amax_t);
local_amax_t = warp_reduce_max<kThreadsPerWarp>(local_amax_t);
// broadcast the amax to all threads in a warp from the lane 0
constexpr int lane_zero = 0;
local_amax_t = __shfl_sync(0xFFFFFFFF, local_amax_t, lane_zero);
// atomic CAS to output memory.
if (tid % kThreadsPerWarp == 0) {
atomicMaxFloat(amax_t, local_amax_t);
}
}
#else
NVTE_DEVICE_ERROR("Kernel is only supported on SM 9.0+.");
#endif // defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 900
}
} // namespace
void hadamard_transform(const Tensor& input_, Tensor& output_, uint16_t random_sign_mask,
uint16_t random_sign_mask_t, cudaStream_t stream) {
NVTE_API_CALL(hadamard_transform);
// Check tensors
  // NOTE (frsun): This is non-intuitive: we write the result of the
  // transposed RHT to the rowwise output.
NVTE_CHECK(input_.scaling_mode == NVTE_DELAYED_TENSOR_SCALING,
"Input tensor must be BF16 tensor, but scaling mode is ",
to_string(input_.scaling_mode), ".");
NVTE_CHECK(input_.dtype() == transformer_engine::DType::kBFloat16,
"Input tensor must be BF16 tensor, but dtype is ", to_string(input_.dtype()), ".");
  NVTE_CHECK(input_.dim() >= 2, "Input must be at least a 2D tensor.");
NVTE_CHECK(output_.scaling_mode == NVTE_DELAYED_TENSOR_SCALING,
"Output tensor must be simple tensor, but scaling mode is ",
to_string(output_.scaling_mode), ".");
const SimpleTensor& input = input_.data;
SimpleTensor output;
SimpleTensor& output_t = output_.data;
// Check requested outputs
const bool return_identity = output.dptr != nullptr;
const bool return_transposed = output_t.dptr != nullptr;
if (!return_identity && !return_transposed) { // Nothing to do/ill-defined behavior.
return;
}
checkCuDriverContext(stream);
const size_t ndim = input.shape.size();
const size_t row_length = input.shape[ndim - 1];
size_t num_rows = 1;
for (size_t i = 0; i < ndim - 1; ++i) {
num_rows *= input.shape[i];
}
using IType = bf16;
constexpr int kHadamardDimension = 16;
NVTE_CHECK(row_length % kHadamardDimension == 0,
"row_length must be divisible by hadamard_dimension.");
NVTE_CHECK(num_rows % kHadamardDimension == 0,
"num_rows must be divisible by hadamard_dimension");
constexpr uint64_t kThreadBlockX = 4;
  // A value of 4 is used on Hopper; 8 is used on Blackwell for extra memory bandwidth.
constexpr uint64_t kThreadBlockY = 4;
uint64_t kNumWarpsPerSM = kThreadBlockX * kThreadBlockY;
// The shared memory number of bytes required for **the whole threadblock**.
size_t shmem_bytes = kHadamardDimension * kHadamardDimension * sizeof(IType) * kNumWarpsPerSM;
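  // e.g. 16 * 16 elements * 2 bytes (BF16) * 16 warps = 8 KiB of dynamic shared memory.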
dim3 block(kThreadsPerWarp, kThreadBlockX, kThreadBlockY);
dim3 grid(DIVUP(row_length / kHadamardDimension, kThreadBlockX),
DIVUP(num_rows / kHadamardDimension, kThreadBlockY));
TRANSFORMER_ENGINE_SWITCH_CONDITION(
return_transposed, kReturnTransposed,
TRANSFORMER_ENGINE_SWITCH_CONDITION(
return_identity, kReturnIdentity,
auto kernel =
HadamardTransformKernel<IType, kHadamardDimension, kReturnIdentity, kReturnTransposed,
kReturnIdentity, kReturnTransposed, false, false, true>;
cudaFuncSetAttribute(kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, shmem_bytes);
kernel<<<grid, block, shmem_bytes, stream>>>(
reinterpret_cast<const IType*>(input.dptr), reinterpret_cast<IType*>(output.dptr),
reinterpret_cast<IType*>(output_t.dptr), random_sign_mask, random_sign_mask_t,
num_rows, row_length, nullptr, nullptr, false);););
NVTE_CHECK_CUDA(cudaGetLastError());
}
// Kernel that applies the 16x16 Hadamard transform to the input and input.T, and then
// takes the absolute max value of the result.
void hadamard_transform_amax(const Tensor& input_, Tensor& output_, uint16_t random_sign_mask,
uint16_t random_sign_mask_t, cudaStream_t stream) {
NVTE_API_CALL(hadamard_transform_amax);
#if CUDA_VERSION >= 12080
// Check input tensor
NVTE_CHECK(input_.scaling_mode == NVTE_DELAYED_TENSOR_SCALING,
"Input tensor must be BF16 tensor, but scaling mode is ",
to_string(input_.scaling_mode), ".");
NVTE_CHECK(input_.dtype() == transformer_engine::DType::kBFloat16,
"Input tensor must be BF16 tensor, but dtype is ", to_string(input_.dtype()), ".");
  NVTE_CHECK(input_.dim() >= 2, "Input must be at least a 2D tensor.");
const SimpleTensor& input = input_.data;
// Check amax tensors
SimpleTensor& output_pre_rht_amax = output_.amax;
SimpleTensor output_identity_amax;
SimpleTensor& output_transpose_amax = output_.columnwise_amax;
// Check requested outputs
const bool return_pre_rht_amax = output_pre_rht_amax.dptr != nullptr;
const bool return_identity_amax = output_identity_amax.dptr != nullptr;
const bool return_transposed_amax = output_transpose_amax.dptr != nullptr;
if (!return_identity_amax && !return_transposed_amax &&
!return_pre_rht_amax) { // Nothing to do/ill-defined behavior.
return;
}
// Zero out amaxes if needed
ZeroAmaxKernel<<<1, 1, 0, stream>>>(reinterpret_cast<float*>(output_pre_rht_amax.dptr),
reinterpret_cast<float*>(output_identity_amax.dptr),
reinterpret_cast<float*>(output_transpose_amax.dptr));
NVTE_CHECK_CUDA(cudaGetLastError());
checkCuDriverContext(stream);
using IType = bf16;
const size_t ndim = input.shape.size();
const size_t row_length = input.shape[ndim - 1];
size_t num_rows = 1;
for (size_t i = 0; i < ndim - 1; ++i) {
num_rows *= input.shape[i];
}
constexpr int kHadamardDimension = 16;
NVTE_CHECK(row_length % kHadamardDimension == 0,
"row_length must be divisible by hadamard_dimension.");
NVTE_CHECK(num_rows % kHadamardDimension == 0,
"num_rows must be divisible by hadamard_dimension");
constexpr uint64_t kChunkBlockXSmall = 128;
constexpr uint64_t kChunkBlockYSmall = 128;
constexpr uint64_t kBuffDimX = 64;
constexpr uint64_t kBuffDimY = 64;
alignas(64) CUtensorMap tensor_map_input{};
create_2D_tensor_map(
/*tensorMap=*/tensor_map_input,
/*tensor=*/input,
/*globalY=*/num_rows,
/*globalX=*/row_length,
/*shmemY=*/kBuffDimY,
/*shmemX=*/kBuffDimX,
/*stride_elems=*/row_length,
/*offset_elems=*/0,
/*type_num_bits=*/sizeof(IType) * 8,
/*swizzle=*/CUtensorMapSwizzle::CU_TENSOR_MAP_SWIZZLE_128B_ATOM_32B);
constexpr uint64_t kThreadBlockX = 4;
constexpr uint64_t kThreadBlockY = 1;
constexpr uint64_t kNumWarps = kThreadBlockX * kThreadBlockY;
dim3 block(kThreadBlockX * kThreadsPerWarp, kThreadBlockY);
dim3 grid(DIVUP(row_length, kChunkBlockXSmall), DIVUP(num_rows, kChunkBlockYSmall));
TRANSFORMER_ENGINE_SWITCH_CONDITION(
return_transposed_amax, kReturnTransposedAmax,
TRANSFORMER_ENGINE_SWITCH_CONDITION(
return_identity_amax, kReturnIdentityAmax,
TRANSFORMER_ENGINE_SWITCH_CONDITION(
return_pre_rht_amax, kReturnPreRhtAmax,
// *2 for ping-pong
size_t in_sh_size = kBuffDimX * kBuffDimY * 2 * sizeof(IType);
size_t mbar_size = sizeof(uint64_t) * (kChunkBlockXSmall / kBuffDimX) *
(kChunkBlockYSmall / kBuffDimY);
size_t shmem_bytes = in_sh_size + mbar_size + kNumWarps * sizeof(float) * 3;
// Add padding in case shmem ptr is not aligned to 128 bytes.
shmem_bytes = (shmem_bytes + 128);
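            // With the defaults here: 64 * 64 * 2 B * 2 buffers = 16384 B of input
            // buffers, 8 B * 4 barriers = 32 B, and 4 warps * 4 B * 3 staging
            // arrays = 48 B, so shmem_bytes = 16464 + 128 = 16592 B.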
auto kernel = HadamardAmaxTmaKernel<
IType, kHadamardDimension, kChunkBlockYSmall, kChunkBlockXSmall, kBuffDimY,
kBuffDimX, kThreadBlockX * kThreadsPerWarp, kThreadBlockY, kReturnPreRhtAmax,
kReturnIdentityAmax, kReturnTransposedAmax>;
cudaFuncSetAttribute(kernel, cudaFuncAttributeMaxDynamicSharedMemorySize,
shmem_bytes);
kernel<<<grid, block, shmem_bytes, stream>>>(
tensor_map_input, reinterpret_cast<float*>(output_pre_rht_amax.dptr),
reinterpret_cast<float*>(output_identity_amax.dptr),
reinterpret_cast<float*>(output_transpose_amax.dptr), random_sign_mask,
random_sign_mask_t, num_rows, row_length);)));
NVTE_CHECK_CUDA(cudaGetLastError());
#else
NVTE_ERROR("Hadamard transform requires CUDA 12.8+, but compile-time CUDA version is ",
CUDA_VERSION);
#endif // CUDA_VERSION >= 12080
}
} // namespace transformer_engine
void nvte_hadamard_transform(const NVTETensor input, NVTETensor output, int random_sign_mask,
int random_sign_mask_t, cudaStream_t stream) {
NVTE_API_CALL(nvte_hadamard_transform);
using namespace transformer_engine;
hadamard_transform(*convertNVTETensorCheck(input), *convertNVTETensorCheck(output),
static_cast<uint16_t>(random_sign_mask),
static_cast<uint16_t>(random_sign_mask_t), stream);
}
void nvte_hadamard_transform_amax(const NVTETensor input, NVTETensor output, int random_sign_mask,
int random_sign_mask_t, cudaStream_t stream) {
NVTE_API_CALL(nvte_hadamard_transform_amax);
using namespace transformer_engine;
hadamard_transform_amax(*convertNVTETensorCheck(input), *convertNVTETensorCheck(output),
static_cast<uint16_t>(random_sign_mask),
static_cast<uint16_t>(random_sign_mask_t), stream);
}
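// Hypothetical usage sketch (illustrative, not part of this file): applying the
// randomized Hadamard transform to a BF16 tensor. Tensor construction is elided
// and the sign masks are arbitrary example values.
//
//   NVTETensor x = ...;  // BF16 input, shape (num_rows, row_length)
//   NVTETensor y = ...;  // BF16 output; receives the transposed RHT result
//   nvte_hadamard_transform(x, y, /*random_sign_mask=*/0xA5A5,
//                           /*random_sign_mask_t=*/0x5A5A, stream);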
/*************************************************************************
* Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
*
* See LICENSE for license information.
************************************************************************/
#include <cuda.h>
#include <cudaTypedefs.h>
#include <cuda_bf16.h>
#include <cuda_pipeline.h>
#include <cuda_runtime.h>
#include <cutlass/arch/barrier.h>
#include <transformer_engine/hadamard_transform.h>
#include <cuda/barrier>
#include <cute/algorithm/gemm.hpp>
#include <cute/arch/cluster_sm90.hpp>
#include <cute/tensor.hpp>
#include "common/common.h"
#include "common/util/cuda_runtime.h"
#include "common/util/ptx.cuh"
#include "common/utils.cuh"
#include "curanddx.hpp"
#include "cutlass/arch/barrier.h"
#include "cutlass/cutlass.h"
#include "cutlass/gemm/collective/builders/sm100_common.inl"
#include "cutlass/numeric_conversion.h"
#include "cutlass/pipeline/pipeline.hpp"
#include "cutlass/util/GPU_Clock.hpp"
#include "cutlass/util/command_line.h"
#include "cutlass/util/helper_cuda.hpp"
#include "cutlass/util/print_error.hpp"
// clang-format off
namespace transformer_engine {
namespace detail {
namespace {
// Define a cuRANDDx descriptor
// Note: curanddx::PhiloxRounds<4> means 4 rounds of philox4_32; if the operator is not specified, it defaults to 10.
// curanddx::SM<800>() does NOT mean the code can only run on SM 800. The operator is used to run some internal checks,
// e.g., whether shared memory, if needed, is sufficient for the described problem; usually not applicable here.
// curanddx doc: https://docs.nvidia.com/cuda/curanddx/index.html
using RNG = decltype(curanddx::Generator<curanddx::philox4_32>() + curanddx::PhiloxRounds<10>() + curanddx::SM<800>() + curanddx::Thread());
using namespace cute;
using cute::Tensor; // Ensure unqualified Tensor refers to cute::Tensor, not transformer_engine::Tensor
// calculate the global encode scale factor for a given global amax.
__device__ __forceinline__ float ComputeGlobalEncodeScaleFP4(const float global_amax) {
constexpr float kFP8E4M3Max = 448.0f;
constexpr float kFP4E2M1Max = 6.0f;
// If scale is infinity, return max value of float32
float global_encode_scale = cutlass::minimum_with_nan_propagation<float>{}(
kFP8E4M3Max * kFP4E2M1Max / global_amax, cutlass::platform::numeric_limits<float>::max());
// If global amax is 0 or infinity, return 1
return (global_amax == 0.f || global_encode_scale == 0.f) ? 1.f : global_encode_scale;
}
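// Worked example (illustrative): kFP8E4M3Max * kFP4E2M1Max = 2688.0f, so
//   global_amax = 2688.0f     -> encode scale 1.0f
//   global_amax = 5376.0f     -> encode scale 0.5f
//   global_amax = 0.0f or inf -> encode scale falls back to 1.0f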
template <class ElementA,
class ElementB,
class ASmemLayout,
class BSmemLayout>
struct SharedStorage {
static constexpr int AccumulatorPipelineStageCount = 16;
using AtomThrShapeMNK = cute::Shape<_1, _1, _1>;
using AccumulatorPipeline = cutlass::PipelineUmmaAsync<AccumulatorPipelineStageCount / 4, AtomThrShapeMNK>;
using AccumulatorPipelineStorage = typename AccumulatorPipeline::SharedStorage;
static constexpr int MainloopPipelineStageCount = size<3>(ASmemLayout{});
using MainloopPipeline = cutlass::PipelineTmaUmmaAsync<
MainloopPipelineStageCount,
Shape<_1,_1,_1>,
AtomThrShapeMNK>;
using MainloopPipelineStorage = typename MainloopPipeline::SharedStorage;
alignas(16) AccumulatorPipelineStorage accumulator;
alignas(16) MainloopPipelineStorage mainloop;
alignas(16) cute::uint64_t tma_barrier[1];
uint32_t tmem_base_ptr;
struct TensorStorage : cute::aligned_struct<128, _1> {
cute::array_aligned<ElementA, cute::cosize_v<ASmemLayout>> smem_A;
cute::array_aligned<ElementB, cute::cosize_v<BSmemLayout>> smem_B;
} tensors;
};
CUTLASS_DEVICE
cutlass::Array<cutlass::float_e2m1_t, 8>
StochasticNumericConverterBase(cutlass::Array<float, 8> const &input, cutlass::Array<uint32_t, 2> const &rbits) {
using result_type = cutlass::Array<cutlass::float_e2m1_t, 8>;
result_type output;
#if CUDA_ARCH_HAS_FEATURE_SM10X_ALL
auto output_ptr = reinterpret_cast<uint16_t *>(&output);
asm volatile( \
"{\n" \
"cvt.rs.satfinite.e2m1x4.f32 %0, {%5, %4, %3, %2}, %10;\n" \
"cvt.rs.satfinite.e2m1x4.f32 %1, {%9, %8, %7, %6}, %11;\n" \
"}" \
: "=h"(output_ptr[0]),
"=h"(output_ptr[1])
: "f"(input[0]), "f"(input[1]), "f"(input[2]), "f"(input[3]),
"f"(input[4]), "f"(input[5]), "f"(input[6]), "f"(input[7]),
"r"(rbits[0]), "r"(rbits[1]));
#else
NVTE_DEVICE_ERROR("FP4 cvt PTX instructions are architecture-specific. "
"Try recompiling with sm_XXXa instead of sm_XXX.");
#endif // CUDA_ARCH_HAS_FEATURE_SM10X_ALL
return output;
}
CUTLASS_DEVICE
cutlass::Array<cutlass::float_e2m1_t, 16>
StochasticNumericConverter(cutlass::Array<float, 16> const &input, cutlass::Array<uint32_t, 4> const *rbits) {
using result_type = cutlass::Array<cutlass::float_e2m1_t, 16>;
result_type output;
cutlass::Array<cutlass::float_e2m1_t, 8> *result_ptr = reinterpret_cast<cutlass::Array<cutlass::float_e2m1_t, 8> *>(&output);
cutlass::Array<float, 8> const *source_ptr = reinterpret_cast<cutlass::Array<float, 8> const *>(&input);
cutlass::Array<uint32_t, 2> const *rbits_ptr = reinterpret_cast<cutlass::Array<uint32_t, 2> const *>(rbits);
CUTLASS_PRAGMA_UNROLL
for (int i = 0; i < 2; i++) {
result_ptr[i] = StochasticNumericConverterBase(source_ptr[i], rbits_ptr[i]);
}
return output;
}
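// Note: each cvt.rs.satfinite.e2m1x4.f32 consumes one 32-bit random word per
// four floats, so converting 16 floats uses all four uint32 words in rbits
// (two per 8-wide base conversion above).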
template <class MShape, class NShape, class KShape, class ClusterTileShape,
class TA, class AStride, class ASmemLayout, class TmaLoadA,
class TB, class BStride, class BSmemLayout, class TmaLoadB,
class TC, class CStride, class CSmemLayout,
class TSFC,
class TiledMMA,
bool kEnableStochasticRounding = false>
__global__ static
void
rht_gemm_device(MShape M, NShape N, KShape K, ClusterTileShape cluster_tile,
TA const* A, AStride dA, ASmemLayout sAlayout, CUTE_GRID_CONSTANT TmaLoadA const tma_load_a,
TB const* B, BStride dB, BSmemLayout sBlayout, CUTE_GRID_CONSTANT TmaLoadB const tma_load_b,
TC * C, CStride dC, CSmemLayout ,
TSFC * SFC,
TiledMMA mma,
float const* global_amax,
const size_t* rng_state)
{
using namespace cute;
using X = Underscore;
using ElementAccumulator = float;
static constexpr int K_PIPE_MAX = size<3>(ASmemLayout{});
using AtomThrShapeMNK = Shape<decltype(shape<0>(typename TiledMMA::ThrLayoutVMNK{})), _1, _1>;
static constexpr uint32_t kTmaTransactionBytes =
cutlass::bits_to_bytes(size(AtomThrShapeMNK{}) * cosize(take<0,3>(ASmemLayout{})) * cute::sizeof_bits_v<TA>);
static constexpr int kTmaRhtTensorTransactionBytes =
cutlass::bits_to_bytes(16 * 16 * cute::sizeof_bits_v<TB>);
static constexpr int AccumulatorPipelineStageCount = 16;
static constexpr int MainloopPipelineStageCount = size<3>(ASmemLayout{});
using MainloopPipeline = cutlass::PipelineTmaUmmaAsync<
MainloopPipelineStageCount,
Shape<_1,_1,_1>,
AtomThrShapeMNK>;
using MainloopPipelineState = typename MainloopPipeline::PipelineState;
using TmemAllocator = cute::TMEM::Allocator1Sm;
static constexpr int VectorSize = 16;
const size_t rng_seed = rng_state != nullptr ? rng_state[0] : 0;
const size_t rng_offset = rng_state != nullptr ? rng_state[1] : 0;
// Preconditions
CUTE_STATIC_ASSERT(is_static<ASmemLayout>::value);
CUTE_STATIC_ASSERT(is_static<BSmemLayout>::value);
CUTE_STATIC_ASSERT(is_static<CSmemLayout>::value);
// Represent the full tensors
Tensor mA = tma_load_a.get_tma_tensor(make_shape(M,N));
Tensor mB = tma_load_b.get_tma_tensor(make_shape(16,16));
Tensor mC = make_tensor(cute::subbyte_iterator<TC>(C), make_shape(M,N), dC); // (M,N)
auto sfc_shape = make_shape(
M,
make_shape( make_shape(Int<16>{}, _4{}), N / 64 )
);
auto sfc_stride = make_stride(
N / 16,
make_stride( make_stride(_0{}, _1{}), _4{} )
);
auto sfc_layout = make_layout(sfc_shape, sfc_stride);
Tensor mSFC = make_tensor(make_gmem_ptr(SFC), sfc_layout);
auto cluster_shape = Shape< _1, _1, _1>{};
// Get the appropriate blocks for this Cluster
dim3 cluster_coord_in_grid = cluster_id_in_grid();
// Total number of k-tiles
const int K_TILE_MAX = min(N, K) / 64;
uint32_t tiles_in_m = (M + size<0>(cluster_tile) - 1) / size<0>(cluster_tile);
uint32_t tiles_in_n = (N + 64 - 1) / 64;
uint32_t linear_tile_idx = blockIdx.x;
uint32_t tile_idx_m = linear_tile_idx % tiles_in_m;
uint32_t tile_idx_n = (linear_tile_idx / tiles_in_m) * K_TILE_MAX;
auto mainloop_tiler = Shape<_128,_16,_64>{};
auto epilogue_tiler = Shape<_128,_64,_64>{};
Tensor gA_mk = local_tile(mA, mainloop_tiler, make_coord(_,_, _), Step<_1, X,_1>{});
Tensor gB_nk = local_tile(mB, cluster_tile, make_coord(_,_, _), Step< X,_1,_1>{}); // (BLK_N,BLK_K,k)
Tensor gC_mn = local_tile(mC, epilogue_tiler, make_coord(_,_, _), Step<_1,_1, X>{}); // (BLK_M,BLK_N)
Tensor gSFC_mn = local_tile(mSFC, epilogue_tiler, make_coord(_,_, _), Step<_1,_1, X>{}); // (BLK_M,BLK_N)
// Allocate SMEM
extern __shared__ char shared_memory[];
using SharedStorage = SharedStorage<TA, TB, ASmemLayout, BSmemLayout>;
SharedStorage& shared_storage = *reinterpret_cast<SharedStorage*>(shared_memory);
Tensor tCsA = make_tensor(make_smem_ptr(shared_storage.tensors.smem_A.data()), sAlayout); // (MMA,MMA_M,MMA_N,PIPE)
Tensor tCsB = make_tensor(make_smem_ptr(shared_storage.tensors.smem_B.data()), sBlayout); // (MMA,MMA_N,MMA_K,PIPE)
//
// MMA: Define C accumulators and A/B partitioning
//
int block_rank_in_cluster = cute::block_rank_in_cluster();
ThrMMA thr_mma = mma.get_slice(block_rank_in_cluster); // blk idx
Tensor tCgB = thr_mma.partition_B(gB_nk); // (MMA,MMA_N,MMA_K,k)
auto mma_epilogue = make_tiled_mma(SM100_MMA_F16BF16_SS<TA, TB, ElementAccumulator,
128, 64,
UMMA::Major::MN, UMMA::Major::MN>{},
Layout<Shape<_1,_1>>{});
ThrMMA thr_mma_epilogue = mma_epilogue.get_slice(block_rank_in_cluster);
using TiledMmaEpilogue = decltype(mma_epilogue);
Tensor tCgA = thr_mma.partition_A(gA_mk);
// Allocate "fragments" -- these are actually umma smem descriptors
Tensor tCrA = thr_mma.make_fragment_A(tCsA); // (MMA,MMA_M,MMA_K,PIPE)
Tensor tCrB = thr_mma.make_fragment_B(tCsB); // (MMA,MMA_M,MMA_K,PIPE)
auto acc_shape_mma = partition_shape_C(TiledMMA{}, take<0,2>(ClusterTileShape{}));
auto acc_shape_epilogue = partition_shape_C(TiledMmaEpilogue{}, take<0,2>(epilogue_tiler));
auto bulk_tmem_mma = TiledMMA::make_fragment_C(append(acc_shape_mma,
Int<AccumulatorPipelineStageCount>{}));
auto bulk_tmem_epilogue = TiledMmaEpilogue::make_fragment_C(append(acc_shape_epilogue,
Int<AccumulatorPipelineStageCount / 4>{}));
TmemAllocator tmem_allocator{};
cutlass::arch::NamedBarrier tmem_allocation_result_barrier(32 + 128, cutlass::arch::ReservedNamedBarriers::TmemAllocBarrier);
Layout cta_layout_mnk = make_layout(cluster_shape);
Layout cta_layout_vmnk = tiled_divide(cta_layout_mnk, make_tile(typename TiledMMA::AtomThrID{}));
auto cta_coord_vmnk = cta_layout_vmnk.get_flat_coord(block_rank_in_cluster);
auto [tAgA, tAsA] = tma_partition(tma_load_a,
get<2>(cta_coord_vmnk), make_layout(size<2>(cta_layout_vmnk)),
group_modes<0,3>(tCsA), group_modes<0,3>(tCgA));
auto [tBgB, tBsB] = tma_partition(tma_load_b,
get<1>(cta_coord_vmnk), make_layout(size<1>(cta_layout_vmnk)),
group_modes<0,3>(tCsB), group_modes<0,3>(tCgB));
uint16_t tma_mcast_mask_a = create_tma_multicast_mask<2>(cta_layout_vmnk, cta_coord_vmnk);
uint16_t tma_mcast_mask_b = create_tma_multicast_mask<1>(cta_layout_vmnk, cta_coord_vmnk);
int warp_idx = cutlass::canonical_warp_idx_sync();
bool is_mma_warp = (warp_idx == 0);
bool is_dma_warp = (warp_idx == 1);
bool is_epilogue_warp = (warp_idx >= 4 && warp_idx <= 7);
if (is_epilogue_warp && elect_one_sync()) {
cute::prefetch(raw_pointer_cast(global_amax));
}
typename MainloopPipeline::Params mainloop_pipeline_params;
if (is_dma_warp) {
mainloop_pipeline_params.role = MainloopPipeline::ThreadCategory::Producer;
}
if (is_mma_warp) {
mainloop_pipeline_params.role = MainloopPipeline::ThreadCategory::Consumer;
}
mainloop_pipeline_params.is_leader = cute::elect_one_sync() && is_dma_warp;
mainloop_pipeline_params.transaction_bytes = kTmaTransactionBytes;
mainloop_pipeline_params.initializing_warp = 0;
MainloopPipeline mainloop_pipeline(shared_storage.mainloop,
mainloop_pipeline_params,
cluster_shape,
cute::true_type{}, // Perform barrier init
cute::true_type{}); // Delay mask calculation
MainloopPipelineState mainloop_pipe_consumer_state;
MainloopPipelineState mainloop_pipe_producer_state = cutlass::make_producer_start_state<MainloopPipeline>();
using AccumulatorPipeline = cutlass::PipelineUmmaAsync<AccumulatorPipelineStageCount / 4, AtomThrShapeMNK>;
using AccumulatorPipelineState = typename AccumulatorPipeline::PipelineState;
AccumulatorPipelineState accumulator_pipe_consumer_state;
AccumulatorPipelineState accumulator_pipe_producer_state = cutlass::make_producer_start_state<AccumulatorPipeline>();
typename AccumulatorPipeline::Params accumulator_pipeline_params;
if (is_mma_warp) {
accumulator_pipeline_params.role = AccumulatorPipeline::ThreadCategory::Producer;
}
if (is_epilogue_warp) {
accumulator_pipeline_params.role = AccumulatorPipeline::ThreadCategory::Consumer;
}
// Only one producer thread arrives on this barrier.
accumulator_pipeline_params.producer_arv_count = 1;
accumulator_pipeline_params.consumer_arv_count = size(AtomThrShapeMNK{}) * 128;
accumulator_pipeline_params.initializing_warp = 1;
AccumulatorPipeline accumulator_pipeline(shared_storage.accumulator,
accumulator_pipeline_params,
cluster_shape,
cute::true_type{}, // Perform barrier init
cute::true_type{}); // Delay mask calculation
if (warp_idx == 2 && elect_one_sync()) {
cute::initialize_barrier(shared_storage.tma_barrier[0], /* num_threads */ 1);
}
__syncthreads();
using TMEM_LOAD_NEW = cute::SM100::TMEM::LOAD::SM100_TMEM_LOAD_32dp32b64x;
if (is_dma_warp) {
if (elect_one_sync()) {
cute::set_barrier_transaction_bytes(shared_storage.tma_barrier[0], kTmaRhtTensorTransactionBytes);
copy(tma_load_b.with(shared_storage.tma_barrier[0], tma_mcast_mask_b), tBgB(_,0,0), tBsB(_,0));
}
cute::wait_barrier(shared_storage.tma_barrier[0], 0 /*tma_phase_bit*/);
do {
bool is_first_wave = linear_tile_idx == blockIdx.x;
uint32_t skip_wait = is_first_wave;
auto tAgA_mk = tAgA(_,tile_idx_m,_);
int k_tile = 0;
auto barrier_token = mainloop_pipeline.producer_try_acquire(mainloop_pipe_producer_state, skip_wait);
CUTE_NO_UNROLL
while (k_tile < K_TILE_MAX && k_tile + tile_idx_n < tiles_in_n) {
int k_tile_idx_n = tile_idx_n + k_tile;
++k_tile;
skip_wait = (is_first_wave && k_tile < MainloopPipelineStageCount);
mainloop_pipeline.producer_acquire(mainloop_pipe_producer_state, barrier_token);
using BarrierType = typename MainloopPipeline::ProducerBarrierType;
BarrierType* tma_barrier = mainloop_pipeline.producer_get_barrier(mainloop_pipe_producer_state);
int write_stage = mainloop_pipe_producer_state.index();
++mainloop_pipe_producer_state;
barrier_token = mainloop_pipeline.producer_try_acquire(mainloop_pipe_producer_state, skip_wait);
if (cute::elect_one_sync()) {
copy(tma_load_a.with(*tma_barrier, tma_mcast_mask_a), tAgA_mk(_,k_tile_idx_n), tAsA(_,write_stage));
}
}
linear_tile_idx += gridDim.x;
tile_idx_m = linear_tile_idx % tiles_in_m;
tile_idx_n = (linear_tile_idx / tiles_in_m) * K_TILE_MAX;
} while (tile_idx_m < tiles_in_m && tile_idx_n < tiles_in_n);
mainloop_pipeline.producer_tail(mainloop_pipe_producer_state);
} else if (is_mma_warp) {
mma.accumulate_ = UMMA::ScaleOut::Zero;
tmem_allocator.allocate(TmemAllocator::Sm100TmemCapacityColumns, &shared_storage.tmem_base_ptr);
__syncwarp();
tmem_allocation_result_barrier.arrive();
uint32_t tmem_base_ptr = shared_storage.tmem_base_ptr;
bulk_tmem_mma.data() = tmem_base_ptr;
do {
uint32_t skip_wait = K_TILE_MAX <= 0;
auto barrier_token = mainloop_pipeline.consumer_try_wait(mainloop_pipe_consumer_state, skip_wait);
CUTE_NO_UNROLL
for (int k_tile = 0; k_tile < K_TILE_MAX && k_tile + tile_idx_n < tiles_in_n; )
{
mainloop_pipeline.consumer_wait(mainloop_pipe_consumer_state, barrier_token);
int read_stage = mainloop_pipe_consumer_state.index();
auto tCrA_mk = tCrA(_,_,_,read_stage);
auto tCrB_nk = tCrB(_,_,0,0);
CUTE_UNROLL
for (int k_block = 0; k_block < size<2>(tCrA) / 4; ++k_block)
{
accumulator_pipeline.producer_acquire(accumulator_pipe_producer_state);
CUTE_UNROLL
for (int i = 0; i < 4; i++) {
auto accumulators = bulk_tmem_mma(_,_,_,accumulator_pipe_producer_state.index() * 4 + i);
gemm(mma, tCrA_mk(_,_,k_block * 4 + i), tCrB_nk, accumulators);
}
accumulator_pipeline.producer_commit(accumulator_pipe_producer_state);
++accumulator_pipe_producer_state;
}
auto curr_mainloop_pipe_consumer_state = mainloop_pipe_consumer_state;
++mainloop_pipe_consumer_state;
++k_tile;
skip_wait = k_tile >= K_TILE_MAX;
barrier_token = mainloop_pipeline.consumer_try_wait(mainloop_pipe_consumer_state, skip_wait);
mainloop_pipeline.consumer_release(curr_mainloop_pipe_consumer_state);
}
linear_tile_idx += gridDim.x;
tile_idx_m = linear_tile_idx % tiles_in_m;
tile_idx_n = (linear_tile_idx / tiles_in_m) * K_TILE_MAX;
} while (tile_idx_m < tiles_in_m && tile_idx_n < tiles_in_n);
tmem_allocator.release_allocation_lock();
accumulator_pipeline.producer_tail(accumulator_pipe_producer_state);
tmem_allocator.free(tmem_base_ptr, TmemAllocator::Sm100TmemCapacityColumns);
} else if (is_epilogue_warp) {
const float global_amax_val = *global_amax;
static constexpr int FragmentSize = 256 / sizeof_bits_v<TC>;
tmem_allocation_result_barrier.arrive_and_wait();
uint32_t tmem_base_ptr = shared_storage.tmem_base_ptr;
bulk_tmem_epilogue.data() = tmem_base_ptr;
int thread_idx = threadIdx.x % 128;
Tensor tCgC = thr_mma_epilogue.partition_C(gC_mn); // (MMA,MMA_M,MMA_N) // (MMA,MMA_M,MMA_N)
auto tiled_t2r = make_tmem_copy(TMEM_LOAD_NEW{}, bulk_tmem_epilogue(_,_,_,_0{}));
auto tiled_r2g = make_tiled_copy_D(Copy_Atom<SM100_STORE_256bit_CACHE_NOALLOCATION, TC>{}, tiled_t2r);
auto thr_t2r = tiled_t2r.get_slice(thread_idx);
auto thr_r2g = tiled_r2g.get_slice(thread_idx);
// NVFP4 non-E8 recipe constants and global scales
static constexpr float fp4_max = 6.0f;
const float global_encode_scale = ComputeGlobalEncodeScaleFP4(global_amax_val);
const float global_decode_scale = 1.0f / global_encode_scale;
auto sfd_converter = cutlass::NumericConverter<TSFC, float>{};
do {
for (int k_tile = 0; k_tile < K_TILE_MAX && k_tile + tile_idx_n < tiles_in_n; ++k_tile) {
Tensor tCgC_mn = tCgC(_,_,_,tile_idx_m,tile_idx_n+k_tile);
Tensor tCgSFC_mn = gSFC_mn(_,_,tile_idx_m,tile_idx_n+k_tile);
accumulator_pipeline.consumer_wait(accumulator_pipe_consumer_state);
auto tCtC = bulk_tmem_epilogue(_,_,_,accumulator_pipe_consumer_state.index());
Tensor tDtC = thr_t2r.partition_S(tCtC); // ((TMEM_LOAD,#TMEM_LOAD),MMA_M,MMA_N)
Tensor tDgC = thr_t2r.partition_D(tCgC_mn); // ((TMEM_LOAD,#TMEM_LOAD),MMA_M,MMA_N)
Tensor tTR_rAcc = make_tensor<ElementAccumulator>(shape(tDgC)); // ((TMEM_LOAD,#TMEM_LOAD),MMA_M,MMA_N)
Tensor tDrC = make_tensor<TC>(shape(tDgC));
Tensor tTR_rAcc_frag = recast<cutlass::Array<ElementAccumulator, FragmentSize>>(coalesce(tTR_rAcc));
Tensor tDrC_frag = recast<cutlass::Array<TC, FragmentSize>>(coalesce(tDrC));
Tensor src = thr_r2g.retile_S(tDrC);
Tensor dst = thr_r2g.retile_D(tDgC);
Tensor tCgSFC = make_tensor(tCgSFC_mn.data(), make_layout(
make_shape(shape(tCgSFC_mn), Int<1>{}, Int<1>{}),
make_stride(stride(tCgSFC_mn), Int<0>{}, Int<0>{})
));
Tensor tDgSFC = filter(thr_t2r.partition_D(tCgSFC));
Tensor tDrSFC = make_tensor<TSFC>(shape(tDgSFC));
static constexpr int NumVecs = size(tDgC) / VectorSize;
Tensor tC_rRowSFD_frg = recast<cutlass::Array<TSFC, NumVecs>>(tDrSFC);
cutlass::maximum_absolute_value_reduction<cutlass::Array<ElementAccumulator, VectorSize>, true> amax_reduction;
cutlass::Array<ElementAccumulator, NumVecs> vec_maxs;
cutlass::Array<ElementAccumulator, NumVecs> pvscales;
// TMEM_LOAD
copy(tiled_t2r, tDtC, tTR_rAcc);
cutlass::arch::fence_view_async_tmem_load();
accumulator_pipeline.consumer_release(accumulator_pipe_consumer_state);
++accumulator_pipe_consumer_state;
              // Round-trip the accumulators FP32 -> BF16 -> FP32 so the values match BF16 precision.
auto convert_accum_to_bf16 = cutlass::NumericArrayConverter<cutlass::bfloat16_t, ElementAccumulator, FragmentSize>{};
auto convert_bf16_to_accum = cutlass::NumericArrayConverter<ElementAccumulator, cutlass::bfloat16_t, FragmentSize>{};
tTR_rAcc_frag(_0{}) = convert_bf16_to_accum(convert_accum_to_bf16(tTR_rAcc_frag(_0{})));
auto compute_frgs = reinterpret_cast<cutlass::Array< ElementAccumulator, VectorSize> *>(tTR_rAcc_frag.data());
auto output_frgs = reinterpret_cast<cutlass::Array< TC, VectorSize> *>(tDrC_frag.data());
CUTLASS_PRAGMA_UNROLL
for (int v = 0; v < NumVecs; v++) {
vec_maxs[v] = amax_reduction(ElementAccumulator(0), compute_frgs[v]);
}
pvscales = cutlass::divides<cutlass::Array<ElementAccumulator, NumVecs>>{}(vec_maxs, fp4_max);
pvscales = cutlass::multiplies<cutlass::Array<ElementAccumulator, NumVecs>>{}(pvscales, global_encode_scale);
auto pvscales_cvted = cutlass::NumericArrayConverter<TSFC, ElementAccumulator, NumVecs>{}(pvscales);
tC_rRowSFD_frg(_0{}) = pvscales_cvted;
auto qpvscale_ups = cutlass::NumericArrayConverter<ElementAccumulator, TSFC, NumVecs>{}(tC_rRowSFD_frg(_0{}));
auto qpvscale_scaled = cutlass::multiplies<cutlass::Array<ElementAccumulator, NumVecs>>{}(qpvscale_ups, global_decode_scale);
auto acc_scales = cutlass::divides<cutlass::Array<ElementAccumulator, NumVecs>>{}(1.0, qpvscale_scaled);
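// Two-level scaling sketch: for a block with amax a, the per-block scale is
// S = (a / fp4_max) * global_encode_scale, quantized to UE4M3 as S_q (stored
// in SFC). Elements are encoded as x / (S_q * global_decode_scale), so
// acc_scales holds 1 / (S_q * global_decode_scale) and dequantization
// recovers x (up to rounding) as x_enc * S_q * global_decode_scale.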
// Initialize RNG for this tile; 256 matches the CTA size, so each
// (thread, k-tile, tile) combination gets a distinct sequence.
const size_t rng_sequence
= thread_idx + k_tile * 256 + linear_tile_idx * K_TILE_MAX * 256;
RNG rng(rng_seed, rng_sequence, rng_offset);
curanddx::uniform_bits dist;
uint4 random_uint4 = uint4{0, 0, 0, 0};
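// Randomness for stochastic rounding: each dist.generate4(rng) call is
// assumed to produce 128 random bits (a uint4), which
// StochasticNumericConverter consumes as rounding noise for one
// VectorSize-element block.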
CUTLASS_PRAGMA_UNROLL
for (int v = 0; v < NumVecs; v++) {
auto acc_scale = cutlass::minimum_with_nan_propagation<ElementAccumulator>{}(acc_scales[v], cutlass::platform::numeric_limits<ElementAccumulator>::max());
if constexpr (kEnableStochasticRounding) {
random_uint4 = dist.generate4(rng);
output_frgs[v] = StochasticNumericConverter(
cutlass::multiplies<cutlass::Array<ElementAccumulator, VectorSize>>{}(
compute_frgs[v],
acc_scale
),
reinterpret_cast<cutlass::Array<uint32_t, 4>*>(&random_uint4));
} else {
output_frgs[v] = cutlass::NumericArrayConverter<TC, ElementAccumulator, VectorSize>{}(cutlass::multiplies<cutlass::Array<ElementAccumulator, VectorSize>>{}(compute_frgs[v], acc_scale));
}
}
copy(tiled_r2g, src, dst);
copy(AutoVectorizingCopyWithAssumedAlignment<128>{}, tDrSFC, tDgSFC);
}
linear_tile_idx += gridDim.x;
tile_idx_m = linear_tile_idx % tiles_in_m;
tile_idx_n = (linear_tile_idx / tiles_in_m) * K_TILE_MAX;
} while (tile_idx_m < tiles_in_m && tile_idx_n < tiles_in_n);
}
}
// This function computes RHT-GEMM for:
// A: m x n: col-major
// B: 16 x 16: row-major
// C: m x n: row-major
// SFC: m x (n/16): row-major
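// Worked example (illustrative): with m = 128, n = 256, and 16-element
// blocks, C is a 128 x 256 FP4 tensor and SFC is a 128 x 16 tensor of scale
// factors (UE4M3 in this use), one per 16 consecutive elements along n.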
template <typename TA, typename TB, typename TC, typename TSFC, bool kEnableStochasticRounding = false>
void
rht_gemm_ntt_w_sfc(int m, int n,
TA const* A,
TB const* B,
TC * C,
TSFC * SFC,
float const* global_amax,
const size_t* rng_state,
uint32_t sm_count,
cudaStream_t stream,
int k_tile_size = 2048)
{
using namespace cute;
// Define shapes (dynamic)
auto M = static_cast<int>(m);
auto N = static_cast<int>(n);
// Define strides (mixed)
auto dA = make_stride(Int<1>{}, m); // (dM,dK)
auto dB = make_stride(Int<1>{}, 16); // (dN,dK)
auto dC = make_stride(n, Int<1>{}); // (dM,dN)
auto cga_shape = Shape< _1, _1, _1>{};
auto cga_tile_shape = Shape<_128,_16,_16>{};
auto cluster_tile_mainloop = Shape<_128,_16,_64>{};
// Construct the MMA
auto mma = make_tiled_mma(SM100_MMA_F16BF16_SS<TA, TB, float,
128, 16,
UMMA::Major::MN, UMMA::Major::MN>{},
Layout<Shape<_1,_1>>{});
// MMA in CGA Layout XXX: Need to generalize synchro? {$nv-release-never}
// Assert that the TiledMMA uses all CTAs in the CGA.
CUTE_STATIC_ASSERT_V(size(cga_shape) == size(mma));
CUTE_STATIC_ASSERT_V(evenly_divides(cga_tile_shape, tile_shape(mma)));
// Determine the A and B shapes
auto mma_shape_B = partition_shape_B(mma, make_shape(size<1>(cga_tile_shape), size<2>(cga_tile_shape)));
using TiledMma = decltype(mma);
using AtomThrID = typename TiledMma::AtomThrID;
using SmemShape_M = decltype(shape_div(shape<0>(cga_tile_shape), shape_div(shape<0>(cga_tile_shape), size<0>(cga_tile_shape) / size(AtomThrID{}))));
using SmemShape_N = decltype(shape_div(shape<1>(cga_tile_shape), shape_div(shape<1>(cga_tile_shape), size<1>(cga_tile_shape) / size(AtomThrID{}))));
using SmemShape_K = decltype(cute::get<2>(cga_tile_shape));
using SmemLayoutAtomB = decltype(cutlass::gemm::collective::detail::sm100_smem_selector<
cute::UMMA::Major::MN, TB, SmemShape_N, SmemShape_K>());
auto mma_shape_A = partition_shape_A(mma, make_shape(size<0>(cluster_tile_mainloop), size<2>(cluster_tile_mainloop)));
using SmemShape_M_A = decltype(shape_div(shape<0>(cluster_tile_mainloop), shape_div(shape<0>(cluster_tile_mainloop), size<0>(cluster_tile_mainloop) / size(AtomThrID{}))));
using SmemShape_K_A = decltype(cute::get<2>(cluster_tile_mainloop));
using SmemLayoutAtomA = decltype(cutlass::gemm::collective::detail::sm100_smem_selector<
cute::UMMA::Major::MN, TA, SmemShape_M_A, SmemShape_K_A>());
// Define the smem layouts (static)
// Calculate max pipeline stages based on Blackwell SM100's 232KB shared memory
constexpr int kBlackwellSmemSize = 232448; // 232KB in bytes
constexpr int kBytesPerStage = cute::size(mma_shape_A) * sizeof(TA) + cute::size(mma_shape_B) * sizeof(TB);
constexpr int kReservedBytes = 256; // Reserve for barriers and other uses
constexpr int kMaxStages = (kBlackwellSmemSize - kReservedBytes) / kBytesPerStage;
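// Rough sizing sketch (assuming BF16 operands and the tile shapes above):
// an A stage holds 128 x 64 values (16 KiB in BF16) and a B stage holds
// 16 x 16 values (512 B), so kBytesPerStage is about 16.5 KiB and
// kMaxStages works out to roughly 13.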
auto sP = Int<kMaxStages>{}; // SMEM pipelines
auto sA = UMMA::tile_to_mma_shape(SmemLayoutAtomA{}, append(mma_shape_A, sP)); // (MMA,MMA_M,MMA_K,PIPE)
auto sB = UMMA::tile_to_mma_shape(SmemLayoutAtomB{}, append(mma_shape_B, sP)); // (MMA,MMA_N,MMA_K,PIPE)
auto sC = Layout<_1>{}; // XXX Dummy
// Create GMEM tensors
Tensor tensorA = make_tensor(A, make_layout(make_shape(M,N), dA)); // (M,N)
Tensor tensorB = make_tensor(B, make_layout(make_shape(16,16), dB)); // (16,16)
// Create the TiledCopy
auto tma_load_a = make_tma_copy_A_sm100(
SM90_TMA_LOAD{},
tensorA,
sA(_,_,_,0),
cluster_tile_mainloop,
mma);
auto tma_load_b = make_tma_copy_B_sm100(
SM90_TMA_LOAD{},
tensorB,
sB(_,_,_,0),
cga_tile_shape,
mma);
// Assert checks on tile sizes -- no predication
NVTE_CHECK(M % size<0>(cga_tile_shape) == 0,
"Inner dimension must be divisible by ", static_cast<size_t>(size<0>(cga_tile_shape)), " but got ", M, ".");
NVTE_CHECK(N % (4 * size<1>(cga_tile_shape)) == 0,
"Outer dimension must be divisible by ", 4 * static_cast<size_t>(size<1>(cga_tile_shape)),
" but got ", N, ".");
uint32_t tiles = size(ceil_div(M, get<0>(cga_tile_shape))) * size(ceil_div(N, k_tile_size));
tiles = (tiles < sm_count) ? tiles : sm_count;
dim3 dimBlock(256);
dim3 dimCluster(size<0>(cga_shape), size<1>(cga_shape), size<2>(cga_shape));
dim3 dimGrid(tiles, 1, 1);
int smem_size = sizeof(SharedStorage<TA, TB, decltype(sA), decltype(sB)>);
auto* kernel_ptr = &rht_gemm_device<
decltype(M), decltype(N), decltype(k_tile_size), decltype(cga_tile_shape),
TA, decltype(dA), decltype(sA), decltype(tma_load_a),
TB, decltype(dB), decltype(sB), decltype(tma_load_b),
TC, decltype(dC), decltype(sC),
TSFC,
decltype(mma),
kEnableStochasticRounding>;
cudaError_t status = cudaFuncSetAttribute(*kernel_ptr,
cudaFuncAttributeMaxDynamicSharedMemorySize,
smem_size);
if (status != cudaSuccess) {
std::cerr << "Error: Failed to set shared memory size: "
<< cudaGetErrorString(status) << std::endl;
return;
}
(*kernel_ptr)
<<< dimGrid, dimBlock, smem_size, stream >>>
(M, N, k_tile_size, cga_tile_shape,
A, dA, sA, tma_load_a,
B, dB, sB, tma_load_b,
C, dC, sC,
SFC,
mma, global_amax,
rng_state);
}
// This function wraps rht_gemm_ntt_w_sfc to transpose the input tensor A.
template <typename TA, typename TB, typename TC, typename TSFC, bool kEnableStochasticRounding = false>
void
rht_gemm_ttt_wrapper(int m, int n,
TA const* A,
TB const* B,
TC * C,
TSFC * SFC,
float const* global_amax,
const size_t* rng_state,
uint32_t sm_count,
cudaStream_t stream,
int k_tile_size = 1024)
{
// In addition to transposing the input tensor A, we also swap m and n to
// utilize as many SMs as possible while keeping a relatively large
// contiguous dimension. After swapping m and n for transpose purposes,
// the input / output tensor shapes for RHT-GEMM are:
// A: n x m: col-major
// B: 16 x 16: row-major
// C: n x m: row-major
// SFC: n x (m/16): row-major
rht_gemm_ntt_w_sfc<TA, TB, TC, TSFC, kEnableStochasticRounding>(
n, m,
A, B, C,
SFC, global_amax,
rng_state,
sm_count, stream,
k_tile_size);
}
} // namespace
} // namespace detail
// clang-format on
void hadamard_transform_cast_fusion_columnwise(const Tensor &input_, Tensor &output_,
const Tensor &hadamard_matrix_,
QuantizationConfig quant_config,
cudaStream_t stream) {
NVTE_API_CALL(hadamard_transform_cast_fusion_columnwise);
// Check input and output tensors
NVTE_CHECK(input_.scaling_mode == NVTE_DELAYED_TENSOR_SCALING,
"Input tensor must be an unquantized BF16 tensor, but scaling mode is ",
to_string(input_.scaling_mode), ".");
NVTE_CHECK(input_.dtype() == transformer_engine::DType::kBFloat16,
"Input tensor must be BF16 tensor, but dtype is ", to_string(input_.dtype()), ".");
NVTE_CHECK(input_.dim() >= 2, "Input must be a tensor with at least 2 dimensions.");
const SimpleTensor &input = input_.data;
SimpleTensor &global_amax = output_.amax;
SimpleTensor &output_t = output_.data;
SimpleTensor &scale_inv_t = output_.scale_inv;
// Stochastic rounding config
const bool use_stochastic_rounding = quant_config.stochastic_rounding;
const size_t *rng_state = nullptr;
if (quant_config.rng_state != nullptr) {
Tensor &rng_state_tensor = *convertNVTETensor(quant_config.rng_state);
NVTE_CHECK(rng_state_tensor.dtype() == DType::kInt64,
"RNG state must be an Int64 tensor with 2 values (seed and offset).");
NVTE_CHECK(rng_state_tensor.data.shape == std::vector<size_t>{2},
"Shape of the RNG state should be [2], but got ", rng_state_tensor.data.shape);
rng_state = reinterpret_cast<const size_t *>(rng_state_tensor.data.dptr);
}
// Template arguments
using TA = cute::bfloat16_t;
using TB = cute::bfloat16_t;
using TC = cutlass::float_e2m1_t;
using TSFC = cutlass::float_ue4m3_t;
checkCuDriverContext(stream);
// Check Hadamard matrix
constexpr int kHadamardDimension = 16;
NVTE_CHECK(hadamard_matrix_.scaling_mode == NVTE_DELAYED_TENSOR_SCALING,
"Hadamard matrix must be an unquantized BF16 tensor, but scaling mode is ",
to_string(hadamard_matrix_.scaling_mode), ".");
NVTE_CHECK(hadamard_matrix_.dtype() == transformer_engine::DType::kBFloat16,
"Hadamard matrix must be BF16 tensor, but dtype is ",
to_string(hadamard_matrix_.dtype()), ".");
const SimpleTensor &hadamard_matrix = hadamard_matrix_.data;
NVTE_CHECK(
(hadamard_matrix_.shape() == std::vector<size_t>{kHadamardDimension, kHadamardDimension}),
"Hadamard matrix must have shape=",
std::vector<size_t>{kHadamardDimension, kHadamardDimension},
", but got shape=", hadamard_matrix_.shape(), ".");
const size_t hadamard_dimension = hadamard_matrix.shape[0];
const size_t ndim = input.shape.size();
const size_t n = input.shape[ndim - 1];
size_t m = 1;
for (size_t i = 0; i < ndim - 1; ++i) {
m *= input.shape[i];
}
auto sm_count = transformer_engine::cuda::sm_count();
NVTE_CHECK(n % hadamard_dimension == 0, "row_length must be divisible by hadamard_dimension.");
NVTE_CHECK(m % hadamard_dimension == 0, "num_rows must be divisible by hadamard_dimension.");
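// The shape-specific k_tile_size values below are assumed to be empirically
// tuned for common model dimensions; unlisted shapes fall back to 1024, or
// 512 when either dimension is small.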
int k_tile_size = 1024;
if (m == 8192 && n == 5120) {
k_tile_size = 512;
} else if (m == 8192 && n == 10240) {
k_tile_size = 1024;
} else if (m == 8192 && n == 2560) {
k_tile_size = 1280;
} else if (m == 8192 && n == 11328) {
k_tile_size = 1024;
} else if (m == 8192 && n == 512) {
k_tile_size = 256;
} else if (m == 8192 && n == 3584) {
k_tile_size = 512;
} else if (m == 11328 && n == 8192) {
k_tile_size = 1024;
} else if (m == 5120 && n == 8192) {
k_tile_size = 512;
} else if (m == 10240 && n == 8192) {
k_tile_size = 1024;
} else if (m == 2560 && n == 8192) {
k_tile_size = 1280;
} else if (m == 512 && n == 8192) {
k_tile_size = 256;
} else if (m == 3584 && n == 8192) {
k_tile_size = 512;
} else if (m < 1024 || n < 1024) {
k_tile_size = 512;
}
TRANSFORMER_ENGINE_SWITCH_CONDITION(
use_stochastic_rounding, kUseStochasticRounding,
detail::rht_gemm_ttt_wrapper<TA, TB, TC, TSFC, kUseStochasticRounding>(
/*m=*/m,
/*n=*/n,
/*A=*/reinterpret_cast<TA const *>(input.dptr),
/*B=*/reinterpret_cast<TB const *>(hadamard_matrix.dptr),
/*C=*/reinterpret_cast<TC *>(output_t.dptr),
/*SFC=*/reinterpret_cast<TSFC *>(scale_inv_t.dptr),
/*global_amax=*/reinterpret_cast<float const *>(global_amax.dptr),
/*rng_state=*/rng_state,
/*sm_count=*/sm_count,
/*stream=*/stream,
/*k_tile_size=*/k_tile_size););
}
} // namespace transformer_engine
void nvte_hadamard_transform_cast_fusion_columnwise(const NVTETensor input, NVTETensor output,
const NVTETensor hadamard_matrix,
const NVTEQuantizationConfig quant_config,
cudaStream_t stream) {
NVTE_API_CALL(nvte_hadamard_transform_cast_fusion_columnwise);
using namespace transformer_engine;
QuantizationConfig quant_config_cpp;
if (quant_config != nullptr) {
quant_config_cpp = *reinterpret_cast<QuantizationConfig *>(quant_config);
}
hadamard_transform_cast_fusion_columnwise(
*convertNVTETensorCheck(input), *convertNVTETensorCheck(output),
*convertNVTETensorCheck(hadamard_matrix), quant_config_cpp, stream);
}
@@ -15,9 +15,76 @@
#ifdef __cplusplus
extern "C" {
#endif  // __cplusplus

/*! \brief Configuration for matrix multiplication. */
typedef void *NVTEMatmulConfig;
/*! \enum NVTEMatmulConfigAttribute
* \brief Type of option for matrix multiplication.
*/
enum NVTEMatmulConfigAttribute {
/*! Bias tensor
*
* If provided, the bias tensor is applied in the GEMM epilogue.
*/
kNVTEMatmulConfigBiasTensor = 0,
/*! Bias gradient tensor
*
* If provided, the bias gradient tensor will be filled in the GEMM epilogue.
*/
kNVTEMatmulConfigDBiasTensor = 1,
/*! Whether to compute GELU in GEMM epilogue. */
kNVTEMatmulConfigWithGELUEpilogue = 2,
/*! Whether to compute GELU backward in GEMM epilogue. */
kNVTEMatmulConfigWithDGELUEpilogue = 3,
/*! Auxiliary tensor for GEMM epilogue.
*
* For GELU, this will be filled with the GELU input. For GELU
* backward, this is expected to already be filled with the GELU
* input.
*/
kNVTEMatmulConfigEpilogueAuxTensor = 4,
/*! Whether to use split accumulator for FP8 GEMM. */
kNVTEMatmulConfigUseSplitAccumulator = 5,
/*! Number of streaming multiprocessors to use in GEMM kernel. */
kNVTEMatmulConfigSMCount = 6,
kNVTEMatmulConfigNumAttributes
};
/*! \brief Create a matrix multiplication configuration. */
NVTEMatmulConfig nvte_create_matmul_config();
/*! \brief Query an option in matrix multiplication configuration.
*
* \param[in] config Matrix multiplication configuration.
* \param[in] attr Option type.
* \param[out] buf Memory address to write option value. Ignored if
* NULL.
* \param[in] size_in_bytes Size of buf.
* \param[out] size_written Number of bytes that have been written to
* buf. If buf is NULL, then the number of
* bytes that would have been written.
*/
void nvte_get_matmul_config_attribute(NVTEMatmulConfig config, NVTEMatmulConfigAttribute attr,
void *buf, size_t size_in_bytes, size_t *size_written);
/*! \brief Set an option in matrix multiplication configuration.
*
* \param[in] config Matrix multiplication configuration.
* \param[in] attr Option type.
* \param[out] buf Memory address to read option value.
* \param[in] size_in_bytes Size of buf.
*/
void nvte_set_matmul_config_attribute(NVTEMatmulConfig config, NVTEMatmulConfigAttribute attr,
const void *buf, size_t size_in_bytes);
/*! \brief Destroy a matrix multiplication configuration. */
void nvte_destroy_matmul_config(NVTEMatmulConfig config);
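/* Example (illustrative sketch): typical C-side lifecycle of a matmul
 * config. The attribute value below is arbitrary.
 *
 *   NVTEMatmulConfig cfg = nvte_create_matmul_config();
 *   int sm_count = 8;
 *   nvte_set_matmul_config_attribute(cfg, kNVTEMatmulConfigSMCount,
 *                                    &sm_count, sizeof(int));
 *   // ... use cfg in nvte_cublas_gemm_v2 ...
 *   nvte_destroy_matmul_config(cfg);
 */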
/*! \brief Compute matrix multiplication of 2 matrices, potentially fused with other operations (deprecated).
*
* This has been deprecated in favor of nvte_cublas_gemm_v2.
 *
 * Computes:
 *  - `D = AB` if both `bias` and `pre_gelu_out` are empty tensors
@@ -44,8 +111,31 @@ void nvte_cublas_gemm(const NVTETensor A, const NVTETensor B, NVTETensor D, cons
                      NVTETensor workspace, bool accumulate, bool use_split_accumulator,
                      int math_sm_count, cudaStream_t stream);
/*! \brief Compute matrix multiplication of 2 matrices, potentially fused with other operations.
*
* Computes:
* - `D = alpha * op(A) * op(B) + beta * C`
*
* \param[in] transa Whether to transpose A matrix.
* \param[in] transb Whether to transpose B matrix.
* \param[in] alpha Scaling factor applied to matmul output.
* \param[in] A A matrix.
* \param[in] B B matrix.
* \param[in] beta Scaling factor applied to C matrix.
* \param[in] C C matrix.
* \param[out] D Output matrix.
* \param[in] workspace Workspace tensor.
* \param[in] config Additional configuration.
* \param[in] stream CUDA stream used for the operation.
*/
void nvte_cublas_gemm_v2(int transa, int transb, const float *alpha, const NVTETensor A,
const NVTETensor B, const float *beta, const NVTETensor C, NVTETensor D,
NVTETensor workspace, NVTEMatmulConfig config, cudaStream_t stream);
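/* Example (illustrative sketch): compute D = op(A) * op(B) with transa = 1,
 * transb = 0, and beta = 0 (no accumulation). A, B, D and workspace are
 * assumed to be valid NVTETensor handles; passing NULL for C and config is
 * assumed to request the defaults.
 *
 *   const float alpha = 1.0f, beta = 0.0f;
 *   nvte_cublas_gemm_v2(1, 0, &alpha, A, B, &beta, NULL, D,
 *                       workspace, NULL, stream);
 */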
/*! \brief Compute matrix multiplication of 2 matrices, potentially fused with other operations,
 * allowing for using a scaling factor for the GEMM result and the accumulation input (deprecated)
 *
 * This has been deprecated in favor of nvte_cublas_gemm_v2.
 *
 * Computes:
 *  - `D = alpha*AB` if both `bias` and `pre_gelu_out` are empty tensors
@@ -133,14 +223,16 @@ void nvte_cublas_atomic_gemm(const NVTETensor A, const NVTETensor B, NVTETensor
 * \param[in]  math_sm_count  Number of GPU SMs to use (default=0: use cuBLAS heuristics)
 * \param[in]  stream         CUDA stream to wait on.
 */
void nvte_multi_tensor_gemm(const NVTETensor *A, const NVTETensor *B, NVTETensor *D,
                            const NVTETensor *bias, NVTETensor *pre_gelu_out, const int num_gemms,
                            bool transa, bool transb, bool grad, NVTETensor *workspace,
                            bool accumulate, bool use_split_accumulator, int math_sm_count,
                            cudaStream_t stream);

#ifdef __cplusplus
}  // extern "C"
#endif  // __cplusplus

#ifdef __cplusplus

/*! \namespace transformer_engine
 */
@@ -153,6 +245,89 @@ namespace transformer_engine {
void nvte_cublas_handle_init();
/*! \struct MatmulConfigWrapper
* \brief C++ wrapper for NVTEMatmulConfig.
*/
class MatmulConfigWrapper {
public:
MatmulConfigWrapper() : config_{nvte_create_matmul_config()} {}
MatmulConfigWrapper(const MatmulConfigWrapper &) = delete;
MatmulConfigWrapper &operator=(const MatmulConfigWrapper &) = delete;
MatmulConfigWrapper(MatmulConfigWrapper &&other) : config_{other.config_} {
other.config_ = nullptr;
}
MatmulConfigWrapper &operator=(MatmulConfigWrapper &&other) {
if (config_ != nullptr) {
nvte_destroy_matmul_config(config_);
}
config_ = other.config_;
other.config_ = nullptr;
return *this;
}
~MatmulConfigWrapper() {
if (config_ != nullptr) {
nvte_destroy_matmul_config(config_);
config_ = nullptr;
}
}
/*! \brief Get the underlying NVTEMatmulConfig.
*
* \return NVTEMatmulConfig held by this MatmulConfigWrapper.
*/
operator NVTEMatmulConfig() const noexcept { return config_; }
/*! \brief Set bias tensor. */
void set_bias_tensor(NVTETensor bias_tensor) {
nvte_set_matmul_config_attribute(config_, kNVTEMatmulConfigBiasTensor, &bias_tensor,
sizeof(NVTETensor));
}
/*! \brief Set bias gradient tensor. */
void set_dbias_tensor(NVTETensor dbias_tensor) {
nvte_set_matmul_config_attribute(config_, kNVTEMatmulConfigDBiasTensor, &dbias_tensor,
sizeof(NVTETensor));
}
/*! \brief Set whether to compute GELU in GEMM epilogue. */
void set_with_gelu_epilogue(bool with_gelu_epilogue) {
nvte_set_matmul_config_attribute(config_, kNVTEMatmulConfigWithGELUEpilogue,
&with_gelu_epilogue, sizeof(bool));
}
/*! \brief Set whether to compute GELU backward in GEMM epilogue. */
void set_with_dgelu_epilogue(bool with_dgelu_epilogue) {
nvte_set_matmul_config_attribute(config_, kNVTEMatmulConfigWithDGELUEpilogue,
&with_dgelu_epilogue, sizeof(bool));
}
/*! \brief Set auxiliary tensor for GEMM epilogue. */
void set_epilogue_aux_tensor(NVTETensor epilogue_aux_tensor) {
nvte_set_matmul_config_attribute(config_, kNVTEMatmulConfigEpilogueAuxTensor,
&epilogue_aux_tensor, sizeof(NVTETensor));
}
/*! \brief Set whether to use split accumulator for FP8 GEMM. */
void set_use_split_accumulator(bool use_split_accumulator) {
nvte_set_matmul_config_attribute(config_, kNVTEMatmulConfigUseSplitAccumulator,
&use_split_accumulator, sizeof(bool));
}
/*! \brief Set number of streaming multiprocessors to use in GEMM kernel. */
void set_sm_count(int sm_count) {
nvte_set_matmul_config_attribute(config_, kNVTEMatmulConfigSMCount, &sm_count, sizeof(int));
}
private:
/*! \brief Wrapped NVTEMatmulConfig. */
NVTEMatmulConfig config_ = nullptr;
};
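/* Example (illustrative sketch): configuring a GEMM epilogue through the
 * wrapper. bias_tensor is assumed to be a valid NVTETensor created
 * elsewhere; the wrapper converts implicitly to NVTEMatmulConfig when
 * passed to nvte_cublas_gemm_v2.
 *
 *   MatmulConfigWrapper config;
 *   config.set_bias_tensor(bias_tensor);
 *   config.set_use_split_accumulator(true);
 *   // nvte_cublas_gemm_v2(transa, transb, &alpha, A, B, &beta, C, D,
 *   //                     workspace, config, stream);
 */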
}  // namespace transformer_engine
#endif // __cplusplus
#endif  // TRANSFORMER_ENGINE_GEMM_H_
/*************************************************************************
* Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
*
* See LICENSE for license information.
************************************************************************/
/*! \file hadamard_transform.h
* \brief Functions for Hadamard transforms.
*/
#ifndef TRANSFORMER_ENGINE_HADAMARD_TRANSFORM_H_
#define TRANSFORMER_ENGINE_HADAMARD_TRANSFORM_H_
#include "transformer_engine.h"
#ifdef __cplusplus
extern "C" {
#endif
/*! \brief Perform a randomized Hadamard transform on the input tensor.
*
* This function is experimental and the API is not stable.
*
* \param[in] input Input tensor to apply Hadamard transform.
* \param[in,out] output Output tensor.
* \param[in] random_sign_mask 16-bit sign mask.
* \param[in] random_sign_mask_t 16-bit sign mask.
* \param[in] stream CUDA stream used for the operation.
*/
void nvte_hadamard_transform(const NVTETensor input, NVTETensor output, int random_sign_mask,
int random_sign_mask_t, cudaStream_t stream);
/*! \brief Perform absolute-maximum reductions on the input tensor, with and
 * without a randomized Hadamard transform. The rowwise result is the absolute
 * maximum of the input tensor. The columnwise result is the absolute maximum
 * of the transposed input tensor after the randomized Hadamard transform is
 * applied.
*
* This function is experimental and the API is not stable.
*
* \param[in] input Input tensor to apply Hadamard transform.
* \param[in,out] output Output tensor.
* \param[in] random_sign_mask 16-bit sign mask.
* \param[in] random_sign_mask_t 16-bit sign mask.
* \param[in] stream CUDA stream used for the operation.
*/
void nvte_hadamard_transform_amax(const NVTETensor input, NVTETensor output, int random_sign_mask,
int random_sign_mask_t, cudaStream_t stream);
/*! \brief Perform the fused columnwise Hadamard transform and cast.
*
* This function is experimental and the API is not stable.
*
* \param[in] input Input tensor to apply Hadamard transform.
* \param[in,out] output Output tensor.
* \param[in] hadamard_matrix Hadamard matrix.
* \param[in] quant_config Quantization configuration.
* \param[in] stream CUDA stream used for the operation.
*/
void nvte_hadamard_transform_cast_fusion_columnwise(const NVTETensor input, NVTETensor output,
const NVTETensor hadamard_matrix,
const NVTEQuantizationConfig quant_config,
cudaStream_t stream);
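/* Example (illustrative sketch): fused columnwise RHT + NVFP4 cast. The
 * tensors are assumed to be valid NVTETensor handles, with output set up as
 * an NVFP4 tensor carrying columnwise data, scale-inv, and amax, and
 * hadamard holding a 16 x 16 BF16 Hadamard matrix.
 *
 *   NVTEQuantizationConfig qconfig = nvte_create_quantization_config();
 *   nvte_hadamard_transform_cast_fusion_columnwise(input, output, hadamard,
 *                                                  qconfig, stream);
 *   nvte_destroy_quantization_config(qconfig);
 */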
#ifdef __cplusplus
} // extern "C"
#endif
#endif // TRANSFORMER_ENGINE_HADAMARD_TRANSFORM_H_
@@ -122,6 +122,10 @@ void nvte_fp8_block_scaling_partial_cast(const NVTETensor inp, NVTETensor out,
                                         size_t start_offset, size_t block_len,
                                         const NVTEDType out_dtype, cudaStream_t stream);
void nvte_nvfp4_compute_per_tensor_scale(const NVTETensor inpA, const bool use_rowwise_amax_A,
const NVTETensor inpB, const bool use_rowwise_amax_B,
float alpha_in, NVTETensor alpha_out, cudaStream_t stream);
#ifdef __cplusplus
}  // extern "C"
#endif

...
@@ -66,6 +66,7 @@ enum NVTETensorParam {
  kNVTEAmax = 3,               /*!< Amax tensor */
  kNVTERowwiseScaleInv = 4,    /*!< Scale inverse tensor for decoding Rowwise Data */
  kNVTEColumnwiseScaleInv = 5, /*!< Scale inverse tensor for decoding Columnwise Data */
  kNVTEColumnwiseAmax = 6,     /*!< Columnwise Amax tensor */
  kNVTENumTensorParams
};
@@ -88,10 +89,9 @@ enum NVTEScalingMode {
   */
  NVTE_BLOCK_SCALING_1D = 2,
  NVTE_BLOCK_SCALING_2D = 3,
  /*! Single scale per block of 16 elements consecutive in either
   * rowwise or columnwise direction */
  NVTE_NVFP4_1D_SCALING = 4,
  NVTE_INVALID_SCALING = 100
};
@@ -330,6 +330,12 @@ enum NVTEQuantizationConfigAttribute {
   * likely be refactored away in the future.
   */
  kNVTEQuantizationConfigFloat8BlockScaleTensorFormat = 3,
  /*! RNG state (NVTETensor with 2 elements: seed and offset) */
  kNVTEQuantizationConfigRNGState = 4,
  /*! Whether to use 2D block scaling for NVFP4 */
  kNVTEQuantizationConfigNVFP42DQuantization = 5,
  /*! Whether to enable stochastic rounding */
  kNVTEQuantizationConfigStochasticRounding = 6,
  kNVTEQuantizationConfigNumAttributes
};
@@ -431,6 +437,15 @@ inline bool is_fp8_dtype(const DType t) {
 */
inline bool is_fp4_dtype(const DType t) { return t == DType::kFloat4E2M1; }
/*! \brief Check if TE datatype is high precision (FP32, FP16, BF16)
*
 * Return true if the TE datatype is high precision.
 * \param[in] t TE datatype of interest.
*/
inline bool is_high_precision_dtype(const DType t) {
return t == DType::kFloat32 || t == DType::kBFloat16 || t == DType::kFloat16;
}
/*! \struct TensorWrapper
 * \brief C++ wrapper for the NVTETensor class.
 */
@@ -566,6 +581,11 @@ class TensorWrapper {
    return set_parameter(kNVTEColumnwiseScaleInv, dptr, type, shape);
  }
template <typename ShapeType>
TensorWrapper &set_columnwise_amax(void *dptr, DType type, const ShapeType &shape) noexcept {
return set_parameter(kNVTEColumnwiseAmax, dptr, type, shape);
}
  // Parameter getters
  NVTEBasicTensor get_parameter(const NVTETensorParam param) const noexcept {
@@ -590,6 +610,10 @@ class TensorWrapper {
    return get_parameter(kNVTEColumnwiseScaleInv);
  }
NVTEBasicTensor get_columnwise_amax() const noexcept {
return get_parameter(kNVTEColumnwiseAmax);
}
  /*! \brief Get an underlying NVTETensor.
   *
   * \return NVTETensor held by this TensorWrapper.
@@ -838,6 +862,24 @@ class QuantizationConfigWrapper {
                                            &format, sizeof(Float8BlockScaleTensorFormat));
  }
  /*! \brief Set RNG state for stochastic rounding */
void set_rng_state(NVTETensor rng_state) {
nvte_set_quantization_config_attribute(config_, kNVTEQuantizationConfigRNGState, &rng_state,
sizeof(NVTETensor));
}
/*! \brief Set whether to use 2D block scaling for NVFP4 */
void set_nvfp4_2d_quantization(bool nvfp4_2d_quantization) {
nvte_set_quantization_config_attribute(config_, kNVTEQuantizationConfigNVFP42DQuantization,
&nvfp4_2d_quantization, sizeof(bool));
}
/*! \brief Set whether to use stochastic rounding */
void set_stochastic_rounding(bool stochastic_rounding) {
nvte_set_quantization_config_attribute(config_, kNVTEQuantizationConfigStochasticRounding,
&stochastic_rounding, sizeof(bool));
}
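  /* Example (illustrative sketch): enabling stochastic rounding with an RNG
   * state tensor. rng_state_tensor is assumed to be an Int64 NVTETensor of
   * shape [2] holding the seed and offset.
   *
   *   QuantizationConfigWrapper qconfig;
   *   qconfig.set_stochastic_rounding(true);
   *   qconfig.set_rng_state(rng_state_tensor);
   */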
 private:
  /*! \brief Wrapped NVTEQuantizationConfig. */
  NVTEQuantizationConfig config_ = nullptr;
...