Unverified Commit 89cc2a7e authored by Zhongbo Zhu, committed by GitHub

[PyTorch][NVFP4][MOE] NVFP4 Grouped Hadamard Amax Kernel (#2351)



* minor fix of torch view dtype
Signed-off-by: Zhongbo Zhu <zhongboz@nvidia.com>

* multi-tensor RHT amax, compiles
Signed-off-by: Zhongbo Zhu <zhongboz@nvidia.com>

* setup multi_tensor_quantize_nvfp4_impl
Signed-off-by: Zhongbo Zhu <zhongboz@nvidia.com>

* wire things up and run without crash
Signed-off-by: Zhongbo Zhu <zhongboz@nvidia.com>

* numerical test
Signed-off-by: Zhongbo Zhu <zhongboz@nvidia.com>

* unit test passing
Signed-off-by: Zhongbo Zhu <zhongboz@nvidia.com>

* finish unit test of split quantize api
Signed-off-by: Zhongbo Zhu <zhongboz@nvidia.com>

* bump up padding to 64 for nvfp4 grouped quantize
Signed-off-by: Zhongbo Zhu <zhongboz@nvidia.com>

* fix stochastic rounding
Signed-off-by: Zhongbo Zhu <zhongboz@nvidia.com>

* lint
Signed-off-by: Zhongbo Zhu <zhongboz@nvidia.com>

* change error message
Signed-off-by: Zhongbo Zhu <zhongboz@nvidia.com>

* clean up
Signed-off-by: Zhongbo Zhu <zhongboz@nvidia.com>

* enable multi-amax without RHT
Signed-off-by: Zhongbo Zhu <zhongboz@nvidia.com>

* fix col-only quantize mode
Signed-off-by: Zhongbo Zhu <zhongboz@nvidia.com>

* improve benchmark script
Signed-off-by: Zhongbo Zhu <zhongboz@nvidia.com>

* add NCU example script
Signed-off-by: Zhongbo Zhu <zhongboz@nvidia.com>

* add larger test case
Signed-off-by: Zhongbo Zhu <zhongboz@nvidia.com>

* add contiguous_data_and_scale check to bulk allocator
Signed-off-by: Zhongbo Zhu <zhongboz@nvidia.com>

* unified naming and differentiate between group_ and multi_
Signed-off-by: Zhongbo Zhu <zhongboz@nvidia.com>

* move regular amax into multi_tensor.h
Signed-off-by: Zhongbo Zhu <zhongboz@nvidia.com>

* Disentangle logic for split-quantize and general multi-tensor quantize
Signed-off-by: Tim Moon <tmoon@nvidia.com>

* Use size_t for split sections
Signed-off-by: Tim Moon <tmoon@nvidia.com>

* Suggestions from @greptile-apps
Signed-off-by: Tim Moon <tmoon@nvidia.com>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci



---------
Signed-off-by: Zhongbo Zhu <zhongboz@nvidia.com>
Signed-off-by: Tim Moon <tmoon@nvidia.com>
Co-authored-by: Tim Moon <tmoon@nvidia.com>
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
parent 3b8d9a8a
@@ -45,6 +45,16 @@ nsys profile \
--trace=cuda,nvtx,cudnn,cublas \
python benchmarks/linear/benchmark_grouped_linear.py --profile --recipe nvfp4
# Example for jagged input benchmark to simulate unbalanced token splits
python benchmarks/linear/benchmark_grouped_linear.py --recipe nvfp4 --jagged-input "15296,8960,14656,14784,11712,7936,14080,10880"
# Example to look at a single kernel target with NCU, like the fused hadamard amax kernel for NVFP4 recipe
ncu -f -o ./benchmarks/linear/ncu_b200_numgemm_8_nvfp4_rht_amax \
--set=full \
--kernel-name "GroupHadamardAmaxTmaKernel" \
-s 5 -c 5 \
python benchmarks/linear/benchmark_grouped_linear.py --profile --recipe nvfp4
""" """
RECIPES = { RECIPES = {
...@@ -163,7 +173,7 @@ def benchmark_linear( ...@@ -163,7 +173,7 @@ def benchmark_linear(
return timing_ms return timing_ms
def run_benchmark_linear(mkns, recipe_name, use_bias, num_gemms=4): def run_benchmark_linear(mkns, recipe_name, use_bias, num_gemms=4, m_splits=None):
data = [] data = []
assert not use_bias, "Bias is not supported for GroupedLinear benchmark" assert not use_bias, "Bias is not supported for GroupedLinear benchmark"
@@ -173,12 +183,13 @@ def run_benchmark_linear(mkns, recipe_name, use_bias, num_gemms=4):
x = torch.randn((m, k), dtype=torch.bfloat16, device=device, requires_grad=True)
ws = [torch.randn((n, k), dtype=torch.bfloat16, device=device) for _ in range(num_gemms)]
assert m % num_gemms == 0
-m_splits = [m // num_gemms] * num_gemms
+m_splits = [m // num_gemms] * num_gemms if m_splits is None else m_splits
# Bias is not supported for GroupedLinear benchmark
bias = None
# Run the benchmark
print(f"fwd_m={m}, fwd_k={k}, fwd_n={n}")
print(f"m_splits: {m_splits}")
grouped_fwd_bwd_timing_ms = benchmark_linear(
x,
@@ -235,8 +246,35 @@ if __name__ == "__main__":
default="bf16",
help="Recipe to use, options are fp8_sub_channel, mxfp8, bf16, or all",
)
# add an argument for the jagged input
# example: [15296, 8960, 14656, 14784, 11712, 7936, 14080, 10880] => sums up to 98304
parser.add_argument(
"--jagged-input",
type=str,
default=None,
help="Jagged input to use, example: [15296, 8960, 14656, 14784, 11712, 7936, 14080, 10880]",
)
parser.add_argument(
"--hidden-dim",
type=int,
default=7168,
help="Hidden dimension to use, default is 7168",
)
parser.add_argument(
"--output-dim",
type=int,
default=2048,
help="Output dimension to use, default is 2048",
)
args = parser.parse_args()
jagged_input_splits = None
if args.jagged_input is not None:
jagged_input_splits = [int(x) for x in args.jagged_input.split(",")]
print(f"Jagged input splits: {jagged_input_splits}")
print(f"Jagged input splits sum: {sum(jagged_input_splits)}")
print(f"Jagged input splits num_gemms: {len(jagged_input_splits)}")
use_bias = False
# Set the MKN values to benchmark
# Deepseek V3 EP64, SEQ_LEN=8192, topK8
@@ -256,11 +294,28 @@ if __name__ == "__main__":
# 4 or 8 local experts per rank
num_gemms_list = [4, 8]
if jagged_input_splits is not None:
num_gemms_list = [len(jagged_input_splits)]
token_dim_list = [65536]
hidden_dim_list = [7168]
output_dim_list = [2048]
# override the default targets to benchmark if specified
if jagged_input_splits is not None:
token_dim_list = [sum(jagged_input_splits)]
if args.hidden_dim is not None:
hidden_dim_list = [args.hidden_dim]
if args.output_dim is not None:
output_dim_list = [args.output_dim]
# MKN for group linear
mkns = []
-for m in [65536]:
+for m in token_dim_list:
-for k in [7168]:
+for k in hidden_dim_list:
-for n in [2048]:
+for n in output_dim_list:
mkns.append((m, k, n))
# default recipes to run if not specified
@@ -272,14 +327,20 @@ if __name__ == "__main__":
recipe_list = [args.recipe]
if args.profile:
-mkns = [(8192 * 8, 7168, 2048)]
+num_gemms_list = [8]
hidden_dim_to_profile = 7168 if args.hidden_dim is None else args.hidden_dim
output_dim_to_profile = 2048 if args.output_dim is None else args.output_dim
token_dim_to_profile = 8192 * 8
if jagged_input_splits is not None:
num_gemms_list = [len(jagged_input_splits)]
token_dim_to_profile = sum(jagged_input_splits)
mkns = [(token_dim_to_profile, hidden_dim_to_profile, output_dim_to_profile)]
# in profile mode, only run one recipe specified in args.recipe
assert args.recipe != "all", (
"In profile mode, only one recipe can be specified, please specify the recipe as"
" fp8_sub_channel, mxfp8, nvfp4, or bf16"
)
recipe_list = [args.recipe]
-num_gemms_list = [8]
torch.autograd.profiler.emit_nvtx(record_shapes=True).__enter__()
# Initialize a dataframe to store the results
@@ -310,6 +371,7 @@ if __name__ == "__main__":
recipe_name,
use_bias,
num_gemms=num_gemms,
m_splits=jagged_input_splits,
)
df_linears = pd.concat([df_linears, df])
......
# Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# See LICENSE for license information.
# NOTE: This file depends on test_nvfp4_quantize_exact.py and
# test_nvfp4_rht_quantize_exact.py passing first.
# It is kept separate so that each piece of functionality is verified on its own;
# otherwise the reference implementation would get messy.
# Due to the structure of NVFP4Quantizer, the RHT functionality must be tested
# together with the quantization functionality.
import transformer_engine.pytorch as te
import transformer_engine_torch as tex
from transformer_engine.pytorch import NVFP4Quantizer
from transformer_engine.pytorch.custom_recipes.quantization_nvfp4 import NVFP4QuantizerRef
from transformer_engine.pytorch.custom_recipes import utils
from transformer_engine.pytorch.constants import TE_DType
from transformer_engine.common.recipe import NVFP4BlockScaling
import pytest
import torch
import random
import math
recipe_available, reason_for_no_recipe = te.is_nvfp4_available(return_reason=True)
def generate_random_multiples_sum(total=8192, n=4, multiple=64):
if total % multiple != 0:
raise ValueError(f"Total ({total}) must be a multiple of {multiple}")
if (total // multiple) < n:
raise ValueError("Total too small for given n and multiple.")
# Work in units of multiples
total_units = total // multiple
# choose n-1 distinct random cut points in [1, total_units)
cuts = sorted(random.sample(range(1, total_units), n - 1))
# convert to segment lengths
parts = (
[cuts[0]] + [cuts[i] - cuts[i - 1] for i in range(1, len(cuts))] + [total_units - cuts[-1]]
)
# convert back to multiples
return [p * multiple for p in parts]
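As a rough illustration (editorial, not part of the test file), the helper produces positive, 64-aligned chunks that sum to the requested total; the exact values depend on the RNG state:
import random
random.seed(0)
splits = generate_random_multiples_sum(total=8192, n=4, multiple=64)
assert sum(splits) == 8192
assert all(s > 0 and s % 64 == 0 for s in splits)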
def generate_split_sections(M: int, N: int, edge_cases: str) -> list[int]:
least_multiple = 64
num_chunks = 4
split_sections = None
avg_split = M // num_chunks
if M == 0 or N == 0:
# all zeros
return [0] * num_chunks
if edge_cases == "regular":
split_sections = [avg_split] * num_chunks
elif edge_cases == "zero_tokens_front":
split_sections = [0] + [avg_split] * (num_chunks - 2) + [avg_split * 2]
elif edge_cases == "zero_tokens_end":
split_sections = [avg_split * 2] + [avg_split] * (num_chunks - 2) + [0]
elif edge_cases == "zero_tokens_middle":
split_sections = [avg_split] * (num_chunks - 2) + [0] + [avg_split * 2]
elif edge_cases == "random_uneven_split":
split_sections = generate_random_multiples_sum(M, num_chunks, least_multiple)
else:
raise ValueError(f"Invalid edge case: {edge_cases}")
# check that the split_sections add up to M
assert sum(split_sections) == M, "The split_sections do not add up to M"
# make sure every split_section is a multiple of least_multiple
for split_section in split_sections:
assert (
split_section % least_multiple == 0
), "The split_sections are not multiples of least_multiple"
return split_sections
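For concreteness, here is how the edge cases play out for M=1024 and N=512 with the fixed num_chunks of 4 (an editorial sketch, not part of the test file):
# avg_split = 1024 // 4 = 256
assert generate_split_sections(1024, 512, "regular") == [256, 256, 256, 256]
assert generate_split_sections(1024, 512, "zero_tokens_front") == [0, 256, 256, 512]
assert generate_split_sections(1024, 512, "zero_tokens_end") == [512, 256, 256, 0]
assert generate_split_sections(1024, 512, "zero_tokens_middle") == [256, 256, 0, 512]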
# Calculate the shape of the scaling tensor for NVFP4 1D blockwise quantization without padding
def get_nvfp4_scale_shape_no_padding(shape, columnwise):
M, K = 1, 1
M = math.prod(shape[:-1])
K = shape[-1]
if columnwise:
outer = K
inner = math.ceil(M / 16)
return (outer, inner)
# rowwise
outer = M
inner = math.ceil(K / 16)
return (outer, inner)
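For example, with a (256, 1024) input and one scale per 16 contiguous elements, the unpadded scale shapes work out as follows (an illustrative check, not part of the test file):
assert get_nvfp4_scale_shape_no_padding((256, 1024), columnwise=False) == (256, 64)  # one scale per 16 elements of each row
assert get_nvfp4_scale_shape_no_padding((256, 1024), columnwise=True) == (1024, 16)  # scales for the transposed view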
def reference_group_quantize(
x: torch.Tensor,
quantizers: list[NVFP4Quantizer],
split_sections: list[int],
return_identity: bool,
return_transpose: bool,
) -> torch.Tensor:
x_view = x.reshape(-1, x.size(-1))
x_chunks = torch.split(x, split_sections)
# rowwise quantization
x_qx = []
x_sx = []
x_amax_rowwise = []
# columnwise quantization
x_qx_t = []
x_sx_t = []
x_amax_colwise = []
for i in range(len(x_chunks)):
x_chunk = x_chunks[i]
x_nvfp4_res = quantizers[i](x_chunk)
if return_identity:
x_qx.append(x_nvfp4_res._rowwise_data.view(dtype=torch.uint8))
x_sx.append(x_nvfp4_res._rowwise_scale_inv)
x_amax_rowwise.append(x_nvfp4_res._amax_rowwise)
else:
x_qx.append(None)
x_sx.append(None)
x_amax_rowwise.append(None)
if return_transpose:
x_qx_t.append(x_nvfp4_res._columnwise_data.view(dtype=torch.uint8))
x_sx_t.append(x_nvfp4_res._columnwise_scale_inv)
x_amax_colwise.append(x_nvfp4_res._amax_columnwise)
else:
x_qx_t.append(None)
x_sx_t.append(None)
x_amax_colwise.append(None)
return x_qx, x_sx, x_amax_rowwise, x_qx_t, x_sx_t, x_amax_colwise
def assert_same_shape_and_dtype(x: torch.Tensor, y: torch.Tensor) -> None:
assert x.shape == y.shape
assert x.dtype == y.dtype
def check_group_quantization_nvfp4_versus_reference(
x_dtype: torch.dtype,
M: int,
N: int,
return_identity: bool,
return_transpose: bool,
split_sections: list[int],
with_rht: bool = True,
with_post_rht_amax: bool = True,
with_random_sign_mask: bool = True,
) -> None:
te_dtype = tex.DType.kFloat4E2M1
# Setup device and random seed
device = "cuda"
seed = 0
torch.manual_seed(seed)
torch.cuda.manual_seed(seed)
# Input
x = torch.randn((M, N), dtype=x_dtype, device=device)
num_chunks = len(split_sections)
x_splits = torch.split(x, split_sections)
# Quantize
quantizers = [
NVFP4Quantizer(
fp4_dtype=te_dtype,
rowwise=return_identity,
columnwise=return_transpose,
with_amax_reduction=False,
amax_reduction_group=None,
with_rht=with_rht,
with_post_rht_amax=with_post_rht_amax,
with_random_sign_mask=with_random_sign_mask,
)
for _ in range(len(split_sections))
]
x_qx_ref, x_sx_ref, x_amax_rowwise_ref, x_qx_t_ref, x_sx_t_ref, x_amax_colwise_ref = (
reference_group_quantize(x, quantizers, split_sections, return_identity, return_transpose)
)
split_quantize_outputs = tex.split_quantize(x, split_sections, quantizers)
if return_identity:
x_qx = [output._rowwise_data.view(dtype=torch.uint8) for output in split_quantize_outputs]
x_sx = [output._rowwise_scale_inv for output in split_quantize_outputs]
x_amax_rowwise = [output._amax_rowwise for output in split_quantize_outputs]
for i in range(len(x_qx)):
if split_sections[i] == 0:
# then just assert the same shape and dtype because the buffer won't be zeroed out
assert_same_shape_and_dtype(x_amax_rowwise[i], x_amax_rowwise_ref[i])
assert_same_shape_and_dtype(x_qx[i], x_qx_ref[i])
assert_same_shape_and_dtype(x_sx[i], x_sx_ref[i])
else:
torch.testing.assert_close(
x_amax_rowwise[i], x_amax_rowwise_ref[i], atol=0.0, rtol=0.0
)
torch.testing.assert_close(x_qx[i], x_qx_ref[i], atol=0.0, rtol=0.0)
valid_scale_shape = get_nvfp4_scale_shape_no_padding(x_splits[i].shape, False)
x_sx_valid = x_sx[i][: valid_scale_shape[0], : valid_scale_shape[1]]
x_sx_ref_valid = x_sx_ref[i][: valid_scale_shape[0], : valid_scale_shape[1]]
torch.testing.assert_close(x_sx_valid, x_sx_ref_valid, atol=0.0, rtol=0.0)
if return_transpose:
x_qx_t = [
output._columnwise_data.view(dtype=torch.uint8) for output in split_quantize_outputs
]
x_sx_t = [output._columnwise_scale_inv for output in split_quantize_outputs]
x_amax_colwise = [output._amax_columnwise for output in split_quantize_outputs]
# assert with zero tolerance
for i in range(len(x_qx_t)):
if split_sections[i] == 0:
# then just assert the same shape and dtype because the buffer won't be zeroed out
assert_same_shape_and_dtype(x_amax_colwise[i], x_amax_colwise_ref[i])
assert_same_shape_and_dtype(x_qx_t[i], x_qx_t_ref[i])
assert_same_shape_and_dtype(x_sx_t[i], x_sx_t_ref[i])
else:
torch.testing.assert_close(
x_amax_colwise[i], x_amax_colwise_ref[i], atol=0.0, rtol=0.0
)
torch.testing.assert_close(x_qx_t[i], x_qx_t_ref[i], atol=0.0, rtol=0.0)
valid_scale_shape = get_nvfp4_scale_shape_no_padding(x_splits[i].shape, True)
x_sx_t_valid = x_sx_t[i][: valid_scale_shape[0], : valid_scale_shape[1]]
x_sx_t_ref_valid = x_sx_t_ref[i][: valid_scale_shape[0], : valid_scale_shape[1]]
torch.testing.assert_close(x_sx_t_valid, x_sx_t_ref_valid, atol=0.0, rtol=0.0)
@pytest.mark.skipif(not recipe_available, reason=reason_for_no_recipe)
@pytest.mark.parametrize(
"M, N",
[
# edge case, zero tokens for all
(0, 512),
# full tile cases
(256, 1024),
(1024, 256),
# larger sizes
(8192, 1024),
(16384, 16384),
],
)
@pytest.mark.parametrize("x_dtype", [torch.bfloat16], ids=str)
@pytest.mark.parametrize(
"edge_cases",
[
"regular",
"zero_tokens_front",
"zero_tokens_end",
"zero_tokens_middle",
"random_uneven_split",
],
)
@pytest.mark.parametrize(
"quantize_mode", ["quantize", "quantize_transpose", "quantize_colwise_only"]
)
@pytest.mark.parametrize(
"with_random_sign_mask", [True, False], ids=["with_random_sign_mask", "no_random_sign_mask"]
)
@pytest.mark.parametrize("with_rht", [True, False], ids=["with_rht", "no_rht"])
def test_rht_with_quantization_block_tiling_versus_reference(
x_dtype: torch.dtype,
M: int,
N: int,
edge_cases: str,
quantize_mode: str,
with_random_sign_mask: bool,
with_rht: bool,
) -> None:
split_sections = generate_split_sections(M, N, edge_cases)
# pre-RHT amax is currently disabled
with_post_rht_amax = with_rht
if quantize_mode == "quantize":
return_identity = True
return_transpose = False
elif quantize_mode == "quantize_transpose":
return_identity = True
return_transpose = True
elif quantize_mode == "quantize_colwise_only":
return_identity = False
return_transpose = True
else:
raise ValueError(f"Invalid quantize mode: {quantize_mode}")
check_group_quantization_nvfp4_versus_reference(
x_dtype=x_dtype,
M=M,
N=N,
return_identity=return_identity,
return_transpose=return_transpose,
split_sections=split_sections,
with_rht=with_rht,
with_post_rht_amax=with_post_rht_amax,
with_random_sign_mask=with_random_sign_mask,
)
@@ -2,9 +2,14 @@
#
# See LICENSE for license information.
from typing import List, Tuple
import pytest
import torch
import transformer_engine.pytorch as te
import transformer_engine_torch as tex
from transformer_engine.pytorch import NVFP4Quantizer
recipe_available, reason_for_no_recipe = te.is_nvfp4_available(return_reason=True)
@@ -151,6 +156,74 @@ def quantize_fp4(
return qx, sx, qx_t, sx_t
def group_quantize_fp4(
x: torch.Tensor,
use_stochastic_rounding: bool,
use_2D: bool,
use_RHT: bool,
split_sections: list[int],
use_tex_split_quantize: bool = True,
) -> Tuple[List[torch.Tensor], List[torch.Tensor], List[torch.Tensor], List[torch.Tensor]]:
"""
Group quantize function with toggle between tex.split_quantize and manual split/call methods.
Args:
x (torch.Tensor): Input tensor.
use_stochastic_rounding (bool): Use stochastic rounding.
use_2D (bool): Use 2D quantization.
use_RHT (bool): Use RHT.
split_sections (list[int]): Split sizes for inputs.
use_tex_split_quantize (bool): Toggle method. If True, use tex.split_quantize, else use manual split and per-quantizer invocation.
Returns:
tuple: Lists of quantized tensors and scale tensors for all sections.
"""
num_tensors = len(split_sections)
nvfp4_quantizers = [
NVFP4Quantizer(
rowwise=True,
columnwise=True,
with_amax_reduction=False,
amax_reduction_group=None,
with_rht=use_RHT,
with_post_rht_amax=True,
stochastic_rounding=use_stochastic_rounding,
with_2d_quantization=use_2D,
)
for _ in range(num_tensors)
]
if use_tex_split_quantize:
outputs = tex.split_quantize(x, split_sections, nvfp4_quantizers)
qx_list = [output._rowwise_data.view(dtype=torch.uint8) for output in outputs]
sx_list = [output._rowwise_scale_inv for output in outputs]
qx_t_list = [output._columnwise_data.view(dtype=torch.uint8) for output in outputs]
sx_t_list = [output._columnwise_scale_inv for output in outputs]
else:
x_chunks = torch.split(x, split_sections)
qx_list = []
sx_list = []
qx_t_list = []
sx_t_list = []
for i in range(num_tensors):
x_chunk = x_chunks[i]
x_nvfp4_sut = nvfp4_quantizers[i](x_chunk)
assert x_nvfp4_sut._rowwise_data is not None
qx = x_nvfp4_sut._rowwise_data.view(dtype=torch.uint8)
assert x_nvfp4_sut._rowwise_scale_inv is not None
sx = x_nvfp4_sut._rowwise_scale_inv
assert x_nvfp4_sut._columnwise_data is not None
qx_t = x_nvfp4_sut._columnwise_data.view(dtype=torch.uint8)
assert x_nvfp4_sut._columnwise_scale_inv is not None
sx_t = x_nvfp4_sut._columnwise_scale_inv
qx_list.append(qx)
sx_list.append(sx)
qx_t_list.append(qx_t)
sx_t_list.append(sx_t)
return qx_list, sx_list, qx_t_list, sx_t_list
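A minimal usage sketch (editorial, assuming a CUDA device with NVFP4 support); it exercises the same call pattern the tests below use:
x = torch.randn((256, 1024), dtype=torch.bfloat16, device="cuda")
qx, sx, qx_t, sx_t = group_quantize_fp4(
    x,
    use_stochastic_rounding=False,
    use_2D=False,
    use_RHT=True,
    split_sections=[128, 128],
    use_tex_split_quantize=True,  # set False to run the manual per-quantizer path
)
assert len(qx) == len(sx) == len(qx_t) == len(sx_t) == 2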
def check_quantization_nvfp4_versus_reference(
x_dtype: torch.dtype, M: int, N: int, use_2D: bool, use_RHT: bool
) -> None:
@@ -209,6 +282,92 @@ def check_quantization_nvfp4_versus_reference(
assert me_t_sr < me_t_rn, "Stochastic rounding failed - error larger than the round to nearest."
def check_group_quantization_nvfp4_versus_reference(
x_dtype: torch.dtype,
M: int,
N: int,
use_2D: bool,
use_RHT: bool,
num_splits: int,
use_tex_split_quantize: bool = True,
) -> None:
device = "cuda"
torch.manual_seed(seed)
n_iters = 50
split_sections = [M // num_splits] * num_splits
x_total = torch.randn((M, N), dtype=x_dtype, device=device) * 2 - 1
x_splits = torch.split(x_total, split_sections)
q_rn_list, s_rn_list, q_t_rn_list, s_t_rn_list = group_quantize_fp4(
x_total,
use_stochastic_rounding=False,
use_2D=use_2D,
use_RHT=use_RHT,
split_sections=split_sections,
use_tex_split_quantize=use_tex_split_quantize,
)
sr_n_iter_results = []
for i in range(n_iters):
q_sr_list, s_sr_list, q_t_sr_list, s_t_sr_list = group_quantize_fp4(
x_total,
use_stochastic_rounding=True,
use_2D=use_2D,
use_RHT=use_RHT,
split_sections=split_sections,
use_tex_split_quantize=use_tex_split_quantize,
)
sr_n_iter_results.append((q_sr_list, s_sr_list, q_t_sr_list, s_t_sr_list))
for i, x in enumerate(x_splits):
y = x.t().contiguous()
if use_RHT:
y = RHT(y)
amax = torch.max(torch.abs(x)).float()
# fetch q_rn, s_rn, q_t_rn, s_t_rn
q_rn = q_rn_list[i]
s_rn = s_rn_list[i]
q_t_rn = q_t_rn_list[i]
s_t_rn = s_t_rn_list[i]
dq_rn = dequantize_fp4(q_rn, s_rn, amax)
dq_t_rn = dequantize_fp4(q_t_rn, s_t_rn, amax)
error_rn = (dq_rn - x).float()
me_rn = torch.sqrt((error_rn * error_rn).mean())
error_t_rn = (dq_t_rn - y).float()
me_t_rn = torch.sqrt((error_t_rn * error_t_rn).mean())
sr_result = torch.zeros_like(x).float()
sr_t_result = torch.zeros_like(x).float().t().contiguous()
for iter_idx in range(n_iters):
result_sr = sr_n_iter_results[iter_idx]
q_sr = result_sr[0][i]
s_sr = result_sr[1][i]
q_t_sr = result_sr[2][i]
s_t_sr = result_sr[3][i]
dq_sr = dequantize_fp4(q_sr, s_sr, amax)
dq_t_sr = dequantize_fp4(q_t_sr, s_t_sr, amax)
sr_result += dq_sr.float()
sr_t_result += dq_t_sr.float()
# Get the mean result of the stochastic rounding
# It should be more accurate than the RN result
sr_result /= n_iters
error_sr = (sr_result - x).float()
me_sr = torch.sqrt((error_sr * error_sr).mean())
sr_t_result /= n_iters
error_t_sr = (sr_t_result - y).float()
me_t_sr = torch.sqrt((error_t_sr * error_t_sr).mean())
print(f"RMSE SR: {me_sr:.3e} | RMSE RN: {me_rn:.3e}")
print(f"RMSE SR_t: {me_t_sr:.3e} | RMSE RN_t: {me_t_rn:.3e}")
assert me_sr < me_rn, "Stochastic rounding failed - error larger than the round to nearest."
assert (
me_t_sr < me_t_rn
), "Stochastic rounding failed - error larger than the round to nearest."
@pytest.mark.skipif(not recipe_available, reason=reason_for_no_recipe)
@pytest.mark.parametrize(
"M, N",
@@ -236,3 +395,39 @@ def test_quantization_block_tiling_versus_reference(
M=M,
N=N,
)
@pytest.mark.skipif(not recipe_available, reason=reason_for_no_recipe)
@pytest.mark.parametrize(
"M, N",
[
(8192, 8192),
(4096, 7168),
(16384, 2048),
],
)
@pytest.mark.parametrize("x_dtype", [torch.bfloat16], ids=str)
@pytest.mark.parametrize("use_2D", [False], ids=str)
@pytest.mark.parametrize("use_RHT", [True], ids=str)
@pytest.mark.parametrize("num_splits", [4, 8], ids=str)
@pytest.mark.parametrize("use_tex_split_quantize", [True, False], ids=str)
def test_group_stochastic_rounding_quantization_versus_reference(
x_dtype: torch.dtype,
use_2D: bool,
use_RHT: bool,
num_splits: int,
use_tex_split_quantize: bool,
M: int,
N: int,
) -> None:
if x_dtype == torch.float32 and use_RHT:
pytest.skip("RHT is only supported with bfloat16")
check_group_quantization_nvfp4_versus_reference(
x_dtype=x_dtype,
use_2D=use_2D,
use_RHT=use_RHT,
M=M,
N=N,
num_splits=num_splits,
use_tex_split_quantize=use_tex_split_quantize,
)
@@ -12,7 +12,10 @@ import torch
import torch.nn as nn
from torch.nn import Parameter
-from transformer_engine.pytorch.quantization import FP8GlobalStateManager
+from transformer_engine.pytorch.quantization import (
FP8GlobalStateManager,
get_align_size_for_quantization,
)
from transformer_engine.pytorch.utils import (
init_method_normal,
scaled_init_method_normal,
@@ -1829,9 +1832,7 @@ def _test_grouped_linear_accuracy(
if num_gemms > 1:
split_size = 1
if fp8:
-split_size = 16
-if recipe.mxfp8() or recipe.nvfp4():
-split_size = 32
+split_size = get_align_size_for_quantization(recipe)
m = config.max_seqlen_q // split_size
dist = torch.sort(torch.randint(0, m, (num_gemms - 2,))).values.tolist()
dist.append(dist[-1])  # Manually add a zero
@@ -2137,9 +2138,7 @@ def test_grouped_linear_accuracy_single_gemm(recipe):
def _test_padding_grouped_linear_accuracy(block, num_gemms, bs, dtype, config, recipe, fp8=False):
def _pad_tensor_for_fp8(hidden_states, tokens_per_expert):
-align_size = 16
-if recipe.mxfp8() or recipe.nvfp4():
-align_size = 32
+align_size = get_align_size_for_quantization(recipe)
padded_tokens_per_expert = [
(num_tokens + align_size - 1) // align_size * align_size
for num_tokens in tokens_per_expert
......
@@ -175,6 +175,7 @@ list(APPEND transformer_engine_cuda_arch_specific_sources
transpose/quantize_transpose_square_blockwise.cu
transpose/quantize_transpose_vector_blockwise_fp4.cu
hadamard_transform/hadamard_transform.cu
hadamard_transform/group_hadamard_transform.cu
hadamard_transform/hadamard_transform_cast_fusion.cu)
# Compiling the files with the worst compilation time first to hopefully overlap
......
@@ -16,185 +16,12 @@
#include "common/common.h"
#include "common/util/ptx.cuh"
#include "common/utils.cuh"
#include "hadamard_transform_utils.cuh"
namespace transformer_engine {
namespace {
constexpr int kThreadsPerWarp = 32;
constexpr float k16x16HadamardScale = 0.25f;
template <bool kTranspose>
__device__ __forceinline__ void ldmatrix_x4_m8n8_shared_b16(uint32_t& a0, uint32_t& a1,
uint32_t& a2, uint32_t& a3,
void* addr) {
auto smem_addr = static_cast<uint32_t>(__cvta_generic_to_shared(addr));
if constexpr (kTranspose) {
asm volatile("ldmatrix.sync.aligned.x4.trans.m8n8.shared.b16 {%0,%1,%2,%3}, [%4];\n"
: "=r"(a0), "=r"(a1), "=r"(a2), "=r"(a3)
: "r"(smem_addr));
} else {
asm volatile("ldmatrix.sync.aligned.x4.m8n8.shared.b16 {%0,%1,%2,%3}, [%4];\n"
: "=r"(a0), "=r"(a1), "=r"(a2), "=r"(a3)
: "r"(smem_addr));
}
}
template <bool kTranspose>
__device__ __forceinline__ void load_matrix_16x16_from_shared(uint32_t& a0, uint32_t& a1,
uint32_t& a2, uint32_t& a3,
void* addr, uint32_t stride) {
if constexpr (kTranspose) {
asm volatile(
"wmma.load.a.sync.aligned.col.m16n16k16.shared::cta.bf16 "
"{%0,%1,%2,%3}, [%4], %5;\n"
: "=r"(a0), "=r"(a1), "=r"(a2), "=r"(a3)
: "l"(addr), "r"(stride));
} else {
asm volatile(
"wmma.load.a.sync.aligned.row.m16n16k16.shared::cta.bf16 "
"{%0,%1,%2,%3}, [%4], %5;\n"
: "=r"(a0), "=r"(a1), "=r"(a2), "=r"(a3)
: "l"(addr), "r"(stride));
}
}
template <bool kTranspose>
__device__ __forceinline__ void store_matrix_16x16_to_global(uint32_t& a0, uint32_t& a1,
uint32_t& a2, uint32_t& a3, void* addr,
uint32_t stride) {
if constexpr (kTranspose) {
asm volatile("wmma.store.d.sync.aligned.col.m16n16k16.global.f16 [%0], {%1, %2, %3, %4}, %5;\n"
:
: "l"(addr), "r"(a0), "r"(a1), "r"(a2), "r"(a3), "r"(stride));
} else {
asm volatile("wmma.store.d.sync.aligned.row.m16n16k16.global.f16 [%0], {%1, %2, %3, %4}, %5;\n"
:
: "l"(addr), "r"(a0), "r"(a1), "r"(a2), "r"(a3), "r"(stride));
}
}
__device__ __forceinline__ void matrix_transpose_m8_n8_b16_inplace(uint32_t& a0) {
asm volatile(
"movmatrix.sync.aligned.m8n8.trans.b16 "
"%0, %1;\n\t"
: "=r"(a0)
: "r"(a0));
}
__device__ __forceinline__ void unpack_max_of_packed_bf16(uint32_t& packed_bf16, float& float_dst) {
__nv_bfloat162 bf16x2 = *reinterpret_cast<__nv_bfloat162*>(&packed_bf16);
float f_a = __bfloat162float(bf16x2.x);
float f_b = __bfloat162float(bf16x2.y);
asm volatile("max.xorsign.abs.f32 %0, %1, %2;\n\t" : "=f"(float_dst) : "f"(f_a), "f"(f_b));
float_dst = fabsf(float_dst);
}
template <bool kCalculateAmax>
__device__ __forceinline__ void mma_m16_n16_k16_b16_b16_b16_noacc(
uint32_t& a0, uint32_t& a1, uint32_t& a2, uint32_t& a3, uint32_t& b0, uint32_t& b1,
uint32_t& b2, uint32_t& b3, uint32_t& c0, uint32_t& c1, uint32_t& c2, uint32_t& c3,
uint32_t& amax_result) {
uint32_t zero = 0;
uint32_t temp0, temp1, temp2, temp3, temp4, temp5, temp6, temp7;
asm volatile(
"wmma.mma.sync.aligned.row.row.m16n16k16.f32.bf16.bf16.f32 \n"
"{%0, %1, %2, %3, %4, %5, %6, %7}, \n"
"{%8, %9, %10, %11}, \n"
"{%12, %13, %14, %15}, \n"
"{%16, %17, %18, %19, %20, %21, %22, %23};\n\t"
: "=r"(temp0), "=r"(temp1), "=r"(temp2), "=r"(temp3), "=r"(temp4), "=r"(temp5), "=r"(temp6),
"=r"(temp7)
: "r"(a0), "r"(a1), "r"(a2), "r"(a3), "r"(b0), "r"(b1), "r"(b2), "r"(b3), "r"(zero),
"r"(zero), "r"(zero), "r"(zero), "r"(zero), "r"(zero), "r"(zero), "r"(zero));
asm volatile("cvt.rn.bf16x2.f32 %0, %1, %2;\n\t" : "=r"(c0) : "r"(temp1), "r"(temp0));
asm volatile("cvt.rn.bf16x2.f32 %0, %1, %2;\n\t" : "=r"(c1) : "r"(temp3), "r"(temp2));
asm volatile("cvt.rn.bf16x2.f32 %0, %1, %2;\n\t" : "=r"(c2) : "r"(temp5), "r"(temp4));
asm volatile("cvt.rn.bf16x2.f32 %0, %1, %2;\n\t" : "=r"(c3) : "r"(temp7), "r"(temp6));
if constexpr (kCalculateAmax) {
uint32_t max_even;
uint32_t max_odd;
// Reduction tree to amax(abs(result)) into bf16x2 reg outparam.
asm volatile("max.xorsign.abs.bf16x2 %0, %1, %2;\n\t" : "=r"(max_even) : "r"(c0), "r"(c2));
asm volatile("max.xorsign.abs.bf16x2 %0, %1, %2;\n\t" : "=r"(max_odd) : "r"(c1), "r"(c3));
// N.B. mma is only called up to once per thread for identity and transpose respectively, so
// we don't have to accumulate into amax_result and can directly store into it.
asm volatile("max.xorsign.abs.bf16x2 %0, %1, %2;\n\t"
: "=r"(amax_result)
: "r"(max_even), "r"(max_odd));
}
}
template <bool kReturnIdentity, bool kReturnTransposed, bool kInverseHadamardIdentity,
bool kInverseHadamardTransposed>
__device__ __forceinline__ void get_hadamard_matrix_fragment(uint32_t* had_frag_i,
uint16_t random_sign_mask,
uint32_t* had_frag_t,
uint16_t random_sign_mask_t) {
int32_t tid = threadIdx.x % 32; // Local tid
float temp_i[2];
float temp_t[2];
#pragma unroll
for (int i = 0; i < 2; i++) {
// i is the vertical fragment index.
// For a 16x16 matrix fragment, 4 threads fill a fragment of 8 BF16 vals.
uint32_t r = i * 8 + tid / 4;
#pragma unroll
for (int j = 0; j < 2; j++) {
#pragma unroll
for (int k = 0; k < 2; k++) {
// k is column position [0, 1] within a quad of 2 BF16s stored together in 32 bits.
// j is the column fragment idx selecting between even and odd fragments.
// j increments 8 columns by switching fragments.
uint32_t c = j * 8 + k + tid % 4 * 2;
// 1 -> -1.0f, 0 -> 1.0f
int32_t base_sign = __popc(r & c);
if constexpr (kReturnIdentity) {
int32_t sign_i;
// Because tensor cores want the dot product dimension,
// contiguous, the regular, non-inverse hadamard swaps
// signs of columns and rows for inverse. In a simple reference,
// x.reshape(-1, 16) @ sign @ H16, this would be opposite but
// (sign @ H16) is transposed in this fragment.
if constexpr (kInverseHadamardIdentity) {
sign_i = ((random_sign_mask >> r) ^ base_sign);
} else {
sign_i = ((random_sign_mask >> c) ^ base_sign);
}
temp_i[k] = copysignf(k16x16HadamardScale, __int_as_float(sign_i << 31));
}
if constexpr (kReturnTransposed) {
int32_t sign_t;
if constexpr (kInverseHadamardTransposed) {
sign_t = ((random_sign_mask_t >> r) ^ base_sign);
} else {
sign_t = ((random_sign_mask_t >> c) ^ base_sign);
}
temp_t[k] = copysignf(k16x16HadamardScale, __int_as_float(sign_t << 31));
}
}
if constexpr (kReturnIdentity) {
asm volatile("cvt.rn.bf16x2.f32 %0, %1, %2;\n\t"
: "=r"(had_frag_i[i * 2 + j])
: "f"(temp_i[1]), "f"(temp_i[0]));
}
if constexpr (kReturnTransposed) {
asm volatile("cvt.rn.bf16x2.f32 %0, %1, %2;\n\t"
: "=r"(had_frag_t[i * 2 + j])
: "f"(temp_t[1]), "f"(temp_t[0]));
}
}
}
}
__device__ __forceinline__ uint32_t swizzle_128B_atom_32B(uint32_t gmem_row_idx,
uint32_t gmem_col_idx) {
uint32_t smem_row_idx = gmem_row_idx;
uint32_t xor_factor = (smem_row_idx * 2) % 8;
uint32_t smem_col_idx = gmem_col_idx ^ xor_factor;
return smem_row_idx * 8 + smem_col_idx;
}
template <typename IType, int kHadamardDimension, int BUFF_DIM_Y, int BUFF_DIM_X,
bool kReturnPreRhtAmax, bool kReturnIdentityAmax, bool kReturnTransposedAmax>
......
/*************************************************************************
* Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
*
* See LICENSE for license information.
************************************************************************/
#ifndef TRANSFORMER_ENGINE_HADAMARD_TRANSFORM_UTILS_CUH_
#define TRANSFORMER_ENGINE_HADAMARD_TRANSFORM_UTILS_CUH_
#include <cuda.h>
#include <cudaTypedefs.h>
#include <cuda_bf16.h>
#include <cuda_runtime.h>
#include "common/common.h"
#include "common/util/ptx.cuh"
#include "common/utils.cuh"
namespace transformer_engine {
constexpr float k16x16HadamardScale = 0.25f;
template <bool kTranspose>
__device__ __forceinline__ void ldmatrix_x4_m8n8_shared_b16(uint32_t& a0, uint32_t& a1,
uint32_t& a2, uint32_t& a3,
void* addr) {
auto smem_addr = static_cast<uint32_t>(__cvta_generic_to_shared(addr));
if constexpr (kTranspose) {
asm volatile("ldmatrix.sync.aligned.x4.trans.m8n8.shared.b16 {%0,%1,%2,%3}, [%4];\n"
: "=r"(a0), "=r"(a1), "=r"(a2), "=r"(a3)
: "r"(smem_addr));
} else {
asm volatile("ldmatrix.sync.aligned.x4.m8n8.shared.b16 {%0,%1,%2,%3}, [%4];\n"
: "=r"(a0), "=r"(a1), "=r"(a2), "=r"(a3)
: "r"(smem_addr));
}
}
template <bool kTranspose>
__device__ __forceinline__ void load_matrix_16x16_from_shared(uint32_t& a0, uint32_t& a1,
uint32_t& a2, uint32_t& a3,
void* addr, uint32_t stride) {
if constexpr (kTranspose) {
asm volatile(
"wmma.load.a.sync.aligned.col.m16n16k16.shared::cta.bf16 "
"{%0,%1,%2,%3}, [%4], %5;\n"
: "=r"(a0), "=r"(a1), "=r"(a2), "=r"(a3)
: "l"(addr), "r"(stride));
} else {
asm volatile(
"wmma.load.a.sync.aligned.row.m16n16k16.shared::cta.bf16 "
"{%0,%1,%2,%3}, [%4], %5;\n"
: "=r"(a0), "=r"(a1), "=r"(a2), "=r"(a3)
: "l"(addr), "r"(stride));
}
}
template <bool kTranspose>
__device__ __forceinline__ void store_matrix_16x16_to_global(uint32_t& a0, uint32_t& a1,
uint32_t& a2, uint32_t& a3, void* addr,
uint32_t stride) {
if constexpr (kTranspose) {
asm volatile("wmma.store.d.sync.aligned.col.m16n16k16.global.f16 [%0], {%1, %2, %3, %4}, %5;\n"
:
: "l"(addr), "r"(a0), "r"(a1), "r"(a2), "r"(a3), "r"(stride));
} else {
asm volatile("wmma.store.d.sync.aligned.row.m16n16k16.global.f16 [%0], {%1, %2, %3, %4}, %5;\n"
:
: "l"(addr), "r"(a0), "r"(a1), "r"(a2), "r"(a3), "r"(stride));
}
}
__device__ __forceinline__ void matrix_transpose_m8_n8_b16_inplace(uint32_t& a0) {
asm volatile(
"movmatrix.sync.aligned.m8n8.trans.b16 "
"%0, %1;\n\t"
: "=r"(a0)
: "r"(a0));
}
__device__ __forceinline__ void unpack_max_of_packed_bf16(uint32_t& packed_bf16, float& float_dst) {
__nv_bfloat162 bf16x2 = *reinterpret_cast<__nv_bfloat162*>(&packed_bf16);
float f_a = __bfloat162float(bf16x2.x);
float f_b = __bfloat162float(bf16x2.y);
asm volatile("max.xorsign.abs.f32 %0, %1, %2;\n\t" : "=f"(float_dst) : "f"(f_a), "f"(f_b));
float_dst = fabsf(float_dst);
}
template <bool kCalculateAmax>
__device__ __forceinline__ void mma_m16_n16_k16_b16_b16_b16_noacc(
uint32_t& a0, uint32_t& a1, uint32_t& a2, uint32_t& a3, uint32_t& b0, uint32_t& b1,
uint32_t& b2, uint32_t& b3, uint32_t& c0, uint32_t& c1, uint32_t& c2, uint32_t& c3,
uint32_t& amax_result) {
uint32_t zero = 0;
uint32_t temp0, temp1, temp2, temp3, temp4, temp5, temp6, temp7;
asm volatile(
"wmma.mma.sync.aligned.row.row.m16n16k16.f32.bf16.bf16.f32 \n"
"{%0, %1, %2, %3, %4, %5, %6, %7}, \n"
"{%8, %9, %10, %11}, \n"
"{%12, %13, %14, %15}, \n"
"{%16, %17, %18, %19, %20, %21, %22, %23};\n\t"
: "=r"(temp0), "=r"(temp1), "=r"(temp2), "=r"(temp3), "=r"(temp4), "=r"(temp5), "=r"(temp6),
"=r"(temp7)
: "r"(a0), "r"(a1), "r"(a2), "r"(a3), "r"(b0), "r"(b1), "r"(b2), "r"(b3), "r"(zero),
"r"(zero), "r"(zero), "r"(zero), "r"(zero), "r"(zero), "r"(zero), "r"(zero));
asm volatile("cvt.rn.bf16x2.f32 %0, %1, %2;\n\t" : "=r"(c0) : "r"(temp1), "r"(temp0));
asm volatile("cvt.rn.bf16x2.f32 %0, %1, %2;\n\t" : "=r"(c1) : "r"(temp3), "r"(temp2));
asm volatile("cvt.rn.bf16x2.f32 %0, %1, %2;\n\t" : "=r"(c2) : "r"(temp5), "r"(temp4));
asm volatile("cvt.rn.bf16x2.f32 %0, %1, %2;\n\t" : "=r"(c3) : "r"(temp7), "r"(temp6));
if constexpr (kCalculateAmax) {
uint32_t max_even;
uint32_t max_odd;
// Reduction tree to amax(abs(result)) into bf16x2 reg outparam.
asm volatile("max.xorsign.abs.bf16x2 %0, %1, %2;\n\t" : "=r"(max_even) : "r"(c0), "r"(c2));
asm volatile("max.xorsign.abs.bf16x2 %0, %1, %2;\n\t" : "=r"(max_odd) : "r"(c1), "r"(c3));
// N.B. mma is only called up to once per thread for identity and transpose respectively, so
// we don't have to accumulate into amax_result and can directly store into it.
asm volatile("max.xorsign.abs.bf16x2 %0, %1, %2;\n\t"
: "=r"(amax_result)
: "r"(max_even), "r"(max_odd));
}
}
template <bool kReturnIdentity, bool kReturnTransposed, bool kInverseHadamardIdentity,
bool kInverseHadamardTransposed>
__device__ __forceinline__ void get_hadamard_matrix_fragment(uint32_t* had_frag_i,
uint16_t random_sign_mask,
uint32_t* had_frag_t,
uint16_t random_sign_mask_t) {
int32_t tid = threadIdx.x % 32; // Local tid
float temp_i[2];
float temp_t[2];
#pragma unroll
for (int i = 0; i < 2; i++) {
// i is the vertical fragment index.
// For a 16x16 matrix fragment, 4 threads fill a fragment of 8 BF16 vals.
uint32_t r = i * 8 + tid / 4;
#pragma unroll
for (int j = 0; j < 2; j++) {
#pragma unroll
for (int k = 0; k < 2; k++) {
// k is column position [0, 1] within a quad of 2 BF16s stored together in 32 bits.
// j is the column fragment idx selecting between even and odd fragments.
// j increments 8 columns by switching fragments.
uint32_t c = j * 8 + k + tid % 4 * 2;
// 1 -> -1.0f, 0 -> 1.0f
int32_t base_sign = __popc(r & c);
if constexpr (kReturnIdentity) {
int32_t sign_i;
// Because tensor cores want the dot product dimension,
// contiguous, the regular, non-inverse hadamard swaps
// signs of columns and rows for inverse. In a simple reference,
// x.reshape(-1, 16) @ sign @ H16, this would be opposite but
// (sign @ H16) is transposed in this fragment.
if constexpr (kInverseHadamardIdentity) {
sign_i = ((random_sign_mask >> r) ^ base_sign);
} else {
sign_i = ((random_sign_mask >> c) ^ base_sign);
}
temp_i[k] = copysignf(k16x16HadamardScale, __int_as_float(sign_i << 31));
}
if constexpr (kReturnTransposed) {
int32_t sign_t;
if constexpr (kInverseHadamardTransposed) {
sign_t = ((random_sign_mask_t >> r) ^ base_sign);
} else {
sign_t = ((random_sign_mask_t >> c) ^ base_sign);
}
temp_t[k] = copysignf(k16x16HadamardScale, __int_as_float(sign_t << 31));
}
}
if constexpr (kReturnIdentity) {
asm volatile("cvt.rn.bf16x2.f32 %0, %1, %2;\n\t"
: "=r"(had_frag_i[i * 2 + j])
: "f"(temp_i[1]), "f"(temp_i[0]));
}
if constexpr (kReturnTransposed) {
asm volatile("cvt.rn.bf16x2.f32 %0, %1, %2;\n\t"
: "=r"(had_frag_t[i * 2 + j])
: "f"(temp_t[1]), "f"(temp_t[0]));
}
}
}
}
__device__ __forceinline__ uint32_t swizzle_128B_atom_32B(uint32_t gmem_row_idx,
uint32_t gmem_col_idx) {
uint32_t smem_row_idx = gmem_row_idx;
uint32_t xor_factor = (smem_row_idx * 2) % 8;
uint32_t smem_col_idx = gmem_col_idx ^ xor_factor;
return smem_row_idx * 8 + smem_col_idx;
}
} // namespace transformer_engine
#endif // TRANSFORMER_ENGINE_HADAMARD_TRANSFORM_UTILS_CUH_
@@ -61,6 +61,31 @@ void nvte_hadamard_transform_cast_fusion_columnwise(const NVTETensor input, NVTE
const NVTEQuantizationConfig quant_config,
cudaStream_t stream);
/*! \brief Split a tensor along dimension 0 and compute RHT amaxes for each split.
*
* This function is experimental and the API is not stable.
*
* This is intended for quantizing to NVFP4 with random Hadamard
* transforms (RHT). For each tensor split, compute the maximum
* absolute value (amax) and populate the row-wise amax of the
* corresponding output tensor. Also, compute the amax after a
* transposed RHT and populate the column-wise amax of the
* corresponding output tensor.
*
* \param[in] input Input tensor.
* \param[in,out] outputs Array of NVFP4 output tensors. Only the row-wise and
* column-wise amaxes are updated.
* \param[in] split_sections Size of each tensor split along dimension 0.
* \param[in] num_tensors Number of tensor splits.
* \param[in] random_sign_mask 16-bit sign mask for RHT.
* \param[in] random_sign_mask_t 16-bit sign mask for transposed RHT.
* \param[in] stream CUDA stream used for the operation.
*/
void nvte_group_hadamard_transform_amax(const NVTETensor input, NVTETensor* outputs,
const size_t* split_sections, size_t num_tensors,
int random_sign_mask, int random_sign_mask_t,
cudaStream_t stream);
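From the PyTorch extension, this path is exercised through tex.split_quantize when the quantizers are NVFP4 quantizers with RHT enabled (see the unit tests earlier in this commit). A minimal editorial sketch of that call pattern, assuming x, split_sections, and a list of NVFP4Quantizer objects are already set up:
outputs = tex.split_quantize(x, split_sections, quantizers)
amax_rowwise = [out._amax_rowwise for out in outputs]      # amax after the RHT
amax_colwise = [out._amax_columnwise for out in outputs]   # amax after the transposed RHT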
#ifdef __cplusplus
} // extern "C"
#endif
......
@@ -265,6 +265,22 @@ void nvte_multi_tensor_compute_scale_and_scale_inv_cuda(int chunk_size, NVTETens
float max_fp8, int force_pow_2_scales,
float epsilon, cudaStream_t stream);
/*! \brief Split a tensor along dimension 0 and compute the amax for each split.
*
* This function is experimental and the API is not stable.
*
* For each tensor split, compute the maximum absolute value (amax)
* and populate the amax of the corresponding output tensor.
*
* \param[in] input Input tensor.
* \param[in,out] outputs Array of output tensors. Only the amax is updated.
* \param[in] split_sections Size of each tensor split along dimension 0.
* \param[in] num_tensors Number of tensor splits.
* \param[in] stream CUDA stream used for the operation.
*/
void nvte_group_amax(const NVTETensor input, NVTETensor *outputs, const size_t *split_sections,
size_t num_tensors, cudaStream_t stream);
#ifdef __cplusplus
} // extern "C"
#endif
......
@@ -253,7 +253,7 @@ std::vector<py::object> multi_tensor_quantize(const std::vector<at::Tensor> &ten
std::vector<py::handle> quantizer_list);
std::vector<py::object> split_quantize(const at::Tensor &tensor,
-const std::vector<int> &split_sections,
+const std::vector<size_t> &split_sections,
std::vector<py::handle> quantizer_list);
/***************************************************************************************************
......
@@ -10,7 +10,7 @@ import torch
import transformer_engine_torch as tex
-from ..quantization import FP8GlobalStateManager
+from ..quantization import FP8GlobalStateManager, get_align_size_for_quantization
from ..jit import no_torch_dynamo
@@ -114,14 +114,8 @@ class Fp8Padding(torch.nn.Module):
assert len(m_splits) == self.num_gemms, "Number of splits should match number of GEMMs."
if self.align_size is None:
-self.align_size = (
-    32
-    if (
-        FP8GlobalStateManager.get_fp8_recipe().mxfp8()
-        or FP8GlobalStateManager.get_fp8_recipe().nvfp4()
-    )
-    else 16
-)
+recipe = FP8GlobalStateManager.get_fp8_recipe()
+self.align_size = get_align_size_for_quantization(recipe)
# FP8 padding calculate
padded_m_splits = [
......
@@ -10,7 +10,7 @@ import torch
import transformer_engine_torch as tex
-from ..quantization import FP8GlobalStateManager
+from ..quantization import FP8GlobalStateManager, get_align_size_for_quantization
from ..jit import no_torch_dynamo
@@ -112,14 +112,8 @@ class Fp8Unpadding(torch.nn.Module):
assert len(m_splits) == self.num_gemms, "Number of splits should match number of GEMMs."
if self.align_size is None:
-self.align_size = (
-    32
-    if (
-        FP8GlobalStateManager.get_fp8_recipe().mxfp8()
-        or FP8GlobalStateManager.get_fp8_recipe().nvfp4()
-    )
-    else 16
-)
+recipe = FP8GlobalStateManager.get_fp8_recipe()
+self.align_size = get_align_size_for_quantization(recipe)
# FP8 padding calculate
padded_m_splits = [
......
@@ -40,6 +40,7 @@ __all__ = [
"is_fp8_block_scaling_available",
"is_nvfp4_available",
"get_default_recipe",
"get_align_size_for_quantization",
]
@@ -114,6 +115,15 @@ def get_default_recipe() -> Recipe:
return get_default_fp8_recipe()
def get_align_size_for_quantization(recipe: Recipe) -> int:
"""Get the alignment size for quantization."""
if recipe.mxfp8():
return 32
if recipe.nvfp4():
return 64
return 16
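For illustration, this is how callers such as Fp8Padding use the helper to round per-expert token counts up to the recipe's alignment (editorial sketch; recipe and m_splits are assumed to come from the caller):
align_size = get_align_size_for_quantization(recipe)  # 64 for NVFP4, 32 for MXFP8, 16 otherwise
padded_m_splits = [(m + align_size - 1) // align_size * align_size for m in m_splits]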
def get_fp8_torch_dtype(fp8_recipe: Recipe, fprop_tensor: bool = True) -> torch.dtype:
"""Get fp8 data type according to recipe and tensor"""
if fp8_recipe.fp8_format == Format.E4M3 or (
......