Unverified Commit 89cc2a7e authored by Zhongbo Zhu, committed by GitHub

[PyTorch][NVFP4][MOE] NVFP4 Grouped Hadamard Amax Kernel (#2351)



* minor fix of torch view dtype
Signed-off-by: Zhongbo Zhu <zhongboz@nvidia.com>

* multi-tensor RHT amax, compiles
Signed-off-by: Zhongbo Zhu <zhongboz@nvidia.com>

* setup multi_tensor_quantize_nvfp4_impl
Signed-off-by: Zhongbo Zhu <zhongboz@nvidia.com>

* wire things up and run without crash
Signed-off-by: Zhongbo Zhu <zhongboz@nvidia.com>

* numerical test
Signed-off-by: Zhongbo Zhu <zhongboz@nvidia.com>

* unit test passing
Signed-off-by: Zhongbo Zhu <zhongboz@nvidia.com>

* finish unit test of split quantize api
Signed-off-by: Zhongbo Zhu <zhongboz@nvidia.com>

* bump up padding to 64 for nvfp4 grouped quantize
Signed-off-by: Zhongbo Zhu <zhongboz@nvidia.com>

* fix stochastic rounding
Signed-off-by: Zhongbo Zhu <zhongboz@nvidia.com>

* lint
Signed-off-by: Zhongbo Zhu <zhongboz@nvidia.com>

* change error message
Signed-off-by: Zhongbo Zhu <zhongboz@nvidia.com>

* clean up
Signed-off-by: Zhongbo Zhu <zhongboz@nvidia.com>

* enable multi-amax without RHT
Signed-off-by: Zhongbo Zhu <zhongboz@nvidia.com>

* fix col-only quantize mode
Signed-off-by: Zhongbo Zhu <zhongboz@nvidia.com>

* improve benchmark script
Signed-off-by: Zhongbo Zhu <zhongboz@nvidia.com>

* add NCU example script
Signed-off-by: Zhongbo Zhu <zhongboz@nvidia.com>

* add larger test case
Signed-off-by: Zhongbo Zhu <zhongboz@nvidia.com>

* add contiguous_data_and_scale check to bulk allocator
Signed-off-by: Zhongbo Zhu <zhongboz@nvidia.com>

* unified naming and differentiate between group_ and multi_
Signed-off-by: Zhongbo Zhu <zhongboz@nvidia.com>

* move regular amax into multi_tensor.h
Signed-off-by: Zhongbo Zhu <zhongboz@nvidia.com>

* Disentangle logic for split-quantize and general multi-tensor quantize
Signed-off-by: Tim Moon <tmoon@nvidia.com>

* Use size_t for split sections
Signed-off-by: Tim Moon <tmoon@nvidia.com>

* Suggestions from @greptile-apps
Signed-off-by: Tim Moon <tmoon@nvidia.com>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

---------
Signed-off-by: Zhongbo Zhu <zhongboz@nvidia.com>
Signed-off-by: Tim Moon <tmoon@nvidia.com>
Co-authored-by: Tim Moon <tmoon@nvidia.com>
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
parent 3b8d9a8a
@@ -45,6 +45,16 @@ nsys profile \
--trace=cuda,nvtx,cudnn,cublas \
python benchmarks/linear/benchmark_grouped_linear.py --profile --recipe nvfp4
# Example for jagged input benchmark to simulate unbalanced token splits
python benchmarks/linear/benchmark_grouped_linear.py --recipe nvfp4 --jagged-input "15296,8960,14656,14784,11712,7936,14080,10880"
# Example of profiling a single kernel with NCU, e.g. the fused Hadamard amax kernel for the NVFP4 recipe
ncu -f -o ./benchmarks/linear/ncu_b200_numgemm_8_nvfp4_rht_amax \
--set=full \
--kernel-name "GroupHadamardAmaxTmaKernel" \
-s 5 -c 5 \
python benchmarks/linear/benchmark_grouped_linear.py --profile --recipe nvfp4
"""
RECIPES = {
@@ -163,7 +173,7 @@ def benchmark_linear(
return timing_ms
def run_benchmark_linear(mkns, recipe_name, use_bias, num_gemms=4):
def run_benchmark_linear(mkns, recipe_name, use_bias, num_gemms=4, m_splits=None):
data = []
assert not use_bias, "Bias is not supported for GroupedLinear benchmark"
@@ -173,12 +183,13 @@ def run_benchmark_linear(mkns, recipe_name, use_bias, num_gemms=4):
x = torch.randn((m, k), dtype=torch.bfloat16, device=device, requires_grad=True)
ws = [torch.randn((n, k), dtype=torch.bfloat16, device=device) for _ in range(num_gemms)]
assert m % num_gemms == 0
m_splits = [m // num_gemms] * num_gemms
m_splits = [m // num_gemms] * num_gemms if m_splits is None else m_splits
# Bias is not supported for GroupedLinear benchmark
bias = None
# Run the benchmark
print(f"fwd_m={m}, fwd_k={k}, fwd_n={n}")
print(f"m_splits: {m_splits}")
grouped_fwd_bwd_timing_ms = benchmark_linear(
x,
@@ -235,8 +246,35 @@ if __name__ == "__main__":
default="bf16",
help="Recipe to use, options are fp8_sub_channel, mxfp8, bf16, or all",
)
# add an argument for the jagged input
# example: [15296, 8960, 14656, 14784, 11712, 7936, 14080, 10880] => sums up to 98304
parser.add_argument(
"--jagged-input",
type=str,
default=None,
help="Jagged input to use, example: [15296, 8960, 14656, 14784, 11712, 7936, 14080, 10880]",
)
parser.add_argument(
"--hidden-dim",
type=int,
default=7168,
help="Hidden dimension to use, default is 7168",
)
parser.add_argument(
"--output-dim",
type=int,
default=2048,
help="Output dimension to use, default is 2048",
)
args = parser.parse_args()
jagged_input_splits = None
if args.jagged_input is not None:
jagged_input_splits = [int(x) for x in args.jagged_input.split(",")]
print(f"Jagged input splits: {jagged_input_splits}")
print(f"Jagged input splits sum: {sum(jagged_input_splits)}")
print(f"Jagged input splits num_gemms: {len(jagged_input_splits)}")
use_bias = False
# Set the MKN values to benchmark
# Deepseek V3 EP64, SEQ_LEN=8192, topK8
@@ -256,11 +294,28 @@
# 4 or 8local experts per rank
num_gemms_list = [4, 8]
if jagged_input_splits is not None:
num_gemms_list = [len(jagged_input_splits)]
token_dim_list = [65536]
hidden_dim_list = [7168]
output_dim_list = [2048]
# override the default targets to benchmark if specified
if jagged_input_splits is not None:
token_dim_list = [sum(jagged_input_splits)]
if args.hidden_dim is not None:
hidden_dim_list = [args.hidden_dim]
if args.output_dim is not None:
output_dim_list = [args.output_dim]
# MKN for group linear
mkns = []
for m in [65536]:
for k in [7168]:
for n in [2048]:
for m in token_dim_list:
for k in hidden_dim_list:
for n in output_dim_list:
mkns.append((m, k, n))
# default recipes to run if not specified
@@ -272,14 +327,20 @@ if __name__ == "__main__":
recipe_list = [args.recipe]
if args.profile:
mkns = [(8192 * 8, 7168, 2048)]
num_gemms_list = [8]
hidden_dim_to_profile = 7168 if args.hidden_dim is None else args.hidden_dim
output_dim_to_profile = 2048 if args.output_dim is None else args.output_dim
token_dim_to_profile = 8192 * 8
if jagged_input_splits is not None:
num_gemms_list = [len(jagged_input_splits)]
token_dim_to_profile = sum(jagged_input_splits)
mkns = [(token_dim_to_profile, hidden_dim_to_profile, output_dim_to_profile)]
# in profile mode, only run one recipe specified in args.recipe
assert args.recipe != "all", (
"In profile mode, only one recipe can be specified, please specify the recipe as"
" fp8_sub_channel, mxfp8, nvfp4, or bf16"
)
recipe_list = [args.recipe]
num_gemms_list = [8]
torch.autograd.profiler.emit_nvtx(record_shapes=True).__enter__()
# Initialize a dataframe to store the results
@@ -310,6 +371,7 @@ if __name__ == "__main__":
recipe_name,
use_bias,
num_gemms=num_gemms,
m_splits=jagged_input_splits,
)
df_linears = pd.concat([df_linears, df])
# Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# See LICENSE for license information.
# NOTE: This file depends on test_nvfp4_quantize_exact.py and
# test_nvfp4_rht_quantize_exact.py passing first.
# It is kept separate so that each functionality is verified on its own;
# otherwise the reference implementation would get messy.
# Due to the structure of NVFP4Quantizer, the RHT functionality must be
# tested together with the quantization functionality.
import transformer_engine.pytorch as te
import transformer_engine_torch as tex
from transformer_engine.pytorch import NVFP4Quantizer
from transformer_engine.pytorch.custom_recipes.quantization_nvfp4 import NVFP4QuantizerRef
from transformer_engine.pytorch.custom_recipes import utils
from transformer_engine.pytorch.constants import TE_DType
from transformer_engine.common.recipe import NVFP4BlockScaling
import pytest
import torch
import random
import math
recipe_available, reason_for_no_recipe = te.is_nvfp4_available(return_reason=True)
def generate_random_multiples_sum(total=8192, n=4, multiple=64):
if total % multiple != 0:
raise ValueError(f"Total ({total}) must be a multiple of {multiple}")
if (total // multiple) < n:
raise ValueError("Total too small for given n and multiple.")
# Work in units of multiples
total_units = total // multiple
# choose n - 1 random cut points in [1, total_units)
cuts = sorted(random.sample(range(1, total_units), n - 1))
# convert to segment lengths
parts = (
[cuts[0]] + [cuts[i] - cuts[i - 1] for i in range(1, len(cuts))] + [total_units - cuts[-1]]
)
# convert back to multiples
return [p * multiple for p in parts]
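# Example (values are illustrative, not from a fixed seed): generate_random_multiples_sum(8192, 4, 64)
# might return [1024, 3072, 512, 3584]; each part is a positive multiple of 64 and they sum to 8192.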
def generate_split_sections(M: int, N: int, edge_cases: str) -> list[int]:
least_multiple = 64
num_chunks = 4
split_sections = None
avg_split = M // num_chunks
if M == 0 or N == 0:
# all zeros
return [0] * num_chunks
if edge_cases == "regular":
split_sections = [avg_split] * num_chunks
elif edge_cases == "zero_tokens_front":
split_sections = [0] + [avg_split] * (num_chunks - 2) + [avg_split * 2]
elif edge_cases == "zero_tokens_end":
split_sections = [avg_split * 2] + [avg_split] * (num_chunks - 2) + [0]
elif edge_cases == "zero_tokens_middle":
split_sections = [avg_split] * (num_chunks - 2) + [0] + [avg_split * 2]
elif edge_cases == "random_uneven_split":
split_sections = generate_random_multiples_sum(M, num_chunks, least_multiple)
else:
raise ValueError(f"Invalid edge case: {edge_cases}")
# check that the split_sections add up to M
assert sum(split_sections) == M, "The split_sections do not add up to M"
# make sure every split_section is a multiple of least_multiple
for split_section in split_sections:
assert (
split_section % least_multiple == 0
), "The split_sections are not multiples of least_multiple"
return split_sections
# Calculate the shape of the scaling tensor for NVFP4 1D blockwise quantization without padding
def get_nvfp4_scale_shape_no_padding(shape, columnwise):
M, K = 1, 1
M = math.prod(shape[:-1])
K = shape[-1]
if columnwise:
outer = K
inner = math.ceil(M / 16)
return (outer, inner)
# rowwise
outer = M
inner = math.ceil(K / 16)
return (outer, inner)
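# Worked example (illustrative): for a (128, 64) input, the rowwise scale shape is
# (128, ceil(64 / 16)) = (128, 4) and the columnwise scale shape is
# (64, ceil(128 / 16)) = (64, 8), i.e. one scale per 16-element 1D block.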
def reference_group_quantize(
x: torch.Tensor,
quantizers: list[NVFP4Quantizer],
split_sections: list[int],
return_identity: bool,
return_transpose: bool,
) -> torch.Tensor:
x_view = x.reshape(-1, x.size(-1))
x_chunks = torch.split(x_view, split_sections)
# rowwise quantization
x_qx = []
x_sx = []
x_amax_rowwise = []
# columnwise quantization
x_qx_t = []
x_sx_t = []
x_amax_colwise = []
for i in range(len(x_chunks)):
x_chunk = x_chunks[i]
x_nvfp4_res = quantizers[i](x_chunk)
if return_identity:
x_qx.append(x_nvfp4_res._rowwise_data.view(dtype=torch.uint8))
x_sx.append(x_nvfp4_res._rowwise_scale_inv)
x_amax_rowwise.append(x_nvfp4_res._amax_rowwise)
else:
x_qx.append(None)
x_sx.append(None)
x_amax_rowwise.append(None)
if return_transpose:
x_qx_t.append(x_nvfp4_res._columnwise_data.view(dtype=torch.uint8))
x_sx_t.append(x_nvfp4_res._columnwise_scale_inv)
x_amax_colwise.append(x_nvfp4_res._amax_columnwise)
else:
x_qx_t.append(None)
x_sx_t.append(None)
x_amax_colwise.append(None)
return x_qx, x_sx, x_amax_rowwise, x_qx_t, x_sx_t, x_amax_colwise
def assert_same_shape_and_dtype(x: torch.Tensor, y: torch.Tensor) -> None:
assert x.shape == y.shape
assert x.dtype == y.dtype
def check_group_quantization_nvfp4_versus_reference(
x_dtype: torch.dtype,
M: int,
N: int,
return_identity: bool,
return_transpose: bool,
split_sections: list[int],
with_rht: bool = True,
with_post_rht_amax: bool = True,
with_random_sign_mask: bool = True,
) -> None:
te_dtype = tex.DType.kFloat4E2M1
# Setup device and random seed
device = "cuda"
seed = 0
torch.manual_seed(seed)
torch.cuda.manual_seed(seed)
# Input
x = torch.randn((M, N), dtype=x_dtype, device=device)
num_chunks = len(split_sections)
x_splits = torch.split(x, split_sections)
# Quantize
quantizers = [
NVFP4Quantizer(
fp4_dtype=te_dtype,
rowwise=return_identity,
columnwise=return_transpose,
with_amax_reduction=False,
amax_reduction_group=None,
with_rht=with_rht,
with_post_rht_amax=with_post_rht_amax,
with_random_sign_mask=with_random_sign_mask,
)
for _ in range(len(split_sections))
]
x_qx_ref, x_sx_ref, x_amax_rowwise_ref, x_qx_t_ref, x_sx_t_ref, x_amax_colwise_ref = (
reference_group_quantize(x, quantizers, split_sections, return_identity, return_transpose)
)
split_quantize_outputs = tex.split_quantize(x, split_sections, quantizers)
if return_identity:
x_qx = [output._rowwise_data.view(dtype=torch.uint8) for output in split_quantize_outputs]
x_sx = [output._rowwise_scale_inv for output in split_quantize_outputs]
x_amax_rowwise = [output._amax_rowwise for output in split_quantize_outputs]
for i in range(len(x_qx)):
if split_sections[i] == 0:
# then just assert the same shape and dtype because the buffer won't be zeroed out
assert_same_shape_and_dtype(x_amax_rowwise[i], x_amax_rowwise_ref[i])
assert_same_shape_and_dtype(x_qx[i], x_qx_ref[i])
assert_same_shape_and_dtype(x_sx[i], x_sx_ref[i])
else:
torch.testing.assert_close(
x_amax_rowwise[i], x_amax_rowwise_ref[i], atol=0.0, rtol=0.0
)
torch.testing.assert_close(x_qx[i], x_qx_ref[i], atol=0.0, rtol=0.0)
valid_scale_shape = get_nvfp4_scale_shape_no_padding(x_splits[i].shape, False)
x_sx_valid = x_sx[i][: valid_scale_shape[0], : valid_scale_shape[1]]
x_sx_ref_valid = x_sx_ref[i][: valid_scale_shape[0], : valid_scale_shape[1]]
torch.testing.assert_close(x_sx_valid, x_sx_ref_valid, atol=0.0, rtol=0.0)
if return_transpose:
x_qx_t = [
output._columnwise_data.view(dtype=torch.uint8) for output in split_quantize_outputs
]
x_sx_t = [output._columnwise_scale_inv for output in split_quantize_outputs]
x_amax_colwise = [output._amax_columnwise for output in split_quantize_outputs]
# assert with zero tolerance
for i in range(len(x_qx_t)):
if split_sections[i] == 0:
# then just assert the same shape and dtype because the buffer won't be zeroed out
assert_same_shape_and_dtype(x_amax_colwise[i], x_amax_colwise_ref[i])
assert_same_shape_and_dtype(x_qx_t[i], x_qx_t_ref[i])
assert_same_shape_and_dtype(x_sx_t[i], x_sx_t_ref[i])
else:
torch.testing.assert_close(
x_amax_colwise[i], x_amax_colwise_ref[i], atol=0.0, rtol=0.0
)
torch.testing.assert_close(x_qx_t[i], x_qx_t_ref[i], atol=0.0, rtol=0.0)
valid_scale_shape = get_nvfp4_scale_shape_no_padding(x_splits[i].shape, True)
x_sx_t_valid = x_sx_t[i][: valid_scale_shape[0], : valid_scale_shape[1]]
x_sx_t_ref_valid = x_sx_t_ref[i][: valid_scale_shape[0], : valid_scale_shape[1]]
torch.testing.assert_close(x_sx_t_valid, x_sx_t_ref_valid, atol=0.0, rtol=0.0)
@pytest.mark.skipif(not recipe_available, reason=reason_for_no_recipe)
@pytest.mark.parametrize(
"M, N",
[
# edge case, zero tokens for all
(0, 512),
# full tile cases
(256, 1024),
(1024, 256),
# larger sizes
(8192, 1024),
(16384, 16384),
],
)
@pytest.mark.parametrize("x_dtype", [torch.bfloat16], ids=str)
@pytest.mark.parametrize(
"edge_cases",
[
"regular",
"zero_tokens_front",
"zero_tokens_end",
"zero_tokens_middle",
"random_uneven_split",
],
)
@pytest.mark.parametrize(
"quantize_mode", ["quantize", "quantize_transpose", "quantize_colwise_only"]
)
@pytest.mark.parametrize(
"with_random_sign_mask", [True, False], ids=["with_random_sign_mask", "no_random_sign_mask"]
)
@pytest.mark.parametrize("with_rht", [True, False], ids=["with_rht", "no_rht"])
def test_rht_with_quantization_block_tiling_versus_reference(
x_dtype: torch.dtype,
M: int,
N: int,
edge_cases: str,
quantize_mode: str,
with_random_sign_mask: bool,
with_rht: bool,
) -> None:
split_sections = generate_split_sections(M, N, edge_cases)
# pre-RHT amax is currently disabled
with_post_rht_amax = with_rht
if quantize_mode == "quantize":
return_identity = True
return_transpose = False
elif quantize_mode == "quantize_transpose":
return_identity = True
return_transpose = True
elif quantize_mode == "quantize_colwise_only":
return_identity = False
return_transpose = True
else:
raise ValueError(f"Invalid quantize mode: {quantize_mode}")
check_group_quantization_nvfp4_versus_reference(
x_dtype=x_dtype,
M=M,
N=N,
return_identity=return_identity,
return_transpose=return_transpose,
split_sections=split_sections,
with_rht=with_rht,
with_post_rht_amax=with_post_rht_amax,
with_random_sign_mask=with_random_sign_mask,
)
@@ -2,9 +2,14 @@
#
# See LICENSE for license information.
from typing import List, Tuple
import pytest
import torch
import transformer_engine.pytorch as te
import transformer_engine_torch as tex
from transformer_engine.pytorch import NVFP4Quantizer
recipe_available, reason_for_no_recipe = te.is_nvfp4_available(return_reason=True)
@@ -151,6 +156,74 @@ def quantize_fp4(
return qx, sx, qx_t, sx_t
def group_quantize_fp4(
x: torch.Tensor,
use_stochastic_rounding: bool,
use_2D: bool,
use_RHT: bool,
split_sections: list[int],
use_tex_split_quantize: bool = True,
) -> Tuple[List[torch.Tensor], List[torch.Tensor], List[torch.Tensor], List[torch.Tensor]]:
"""
Group quantize function with toggle between tex.split_quantize and manual split/call methods.
Args:
x (torch.Tensor): Input tensor.
use_stochastic_rounding (bool): Use stochastic rounding.
use_2D (bool): Use 2D quantization.
use_RHT (bool): Use RHT.
split_sections (list[int]): Split sizes for inputs.
use_tex_split_quantize (bool): Toggle method. If True, use tex.split_quantize, else use manual split and per-quantizer invocation.
Returns:
tuple: Lists of quantized tensors and scale tensors for all sections.
"""
num_tensors = len(split_sections)
nvfp4_quantizers = [
NVFP4Quantizer(
rowwise=True,
columnwise=True,
with_amax_reduction=False,
amax_reduction_group=None,
with_rht=use_RHT,
with_post_rht_amax=True,
stochastic_rounding=use_stochastic_rounding,
with_2d_quantization=use_2D,
)
for _ in range(num_tensors)
]
if use_tex_split_quantize:
outputs = tex.split_quantize(x, split_sections, nvfp4_quantizers)
qx_list = [output._rowwise_data.view(dtype=torch.uint8) for output in outputs]
sx_list = [output._rowwise_scale_inv for output in outputs]
qx_t_list = [output._columnwise_data.view(dtype=torch.uint8) for output in outputs]
sx_t_list = [output._columnwise_scale_inv for output in outputs]
else:
x_chunks = torch.split(x, split_sections)
qx_list = []
sx_list = []
qx_t_list = []
sx_t_list = []
for i in range(num_tensors):
x_chunk = x_chunks[i]
x_nvfp4_sut = nvfp4_quantizers[i](x_chunk)
assert x_nvfp4_sut._rowwise_data is not None
qx = x_nvfp4_sut._rowwise_data.view(dtype=torch.uint8)
assert x_nvfp4_sut._rowwise_scale_inv is not None
sx = x_nvfp4_sut._rowwise_scale_inv
assert x_nvfp4_sut._columnwise_data is not None
qx_t = x_nvfp4_sut._columnwise_data.view(dtype=torch.uint8)
assert x_nvfp4_sut._columnwise_scale_inv is not None
sx_t = x_nvfp4_sut._columnwise_scale_inv
qx_list.append(qx)
sx_list.append(sx)
qx_t_list.append(qx_t)
sx_t_list.append(sx_t)
return qx_list, sx_list, qx_t_list, sx_t_list
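# Minimal usage sketch (shapes are illustrative, not taken from the tests below):
#   x = torch.randn(256, 64, dtype=torch.bfloat16, device="cuda")
#   qx, sx, qx_t, sx_t = group_quantize_fp4(
#       x, use_stochastic_rounding=False, use_2D=False, use_RHT=True,
#       split_sections=[128, 128], use_tex_split_quantize=True)
# Each returned list holds one tensor per split section.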
def check_quantization_nvfp4_versus_reference(
x_dtype: torch.dtype, M: int, N: int, use_2D: bool, use_RHT: bool
) -> None:
@@ -209,6 +282,92 @@ def check_quantization_nvfp4_versus_reference(
assert me_t_sr < me_t_rn, "Stochastic rounding failed - error larger than the round to nearest."
def check_group_quantization_nvfp4_versus_reference(
x_dtype: torch.dtype,
M: int,
N: int,
use_2D: bool,
use_RHT: bool,
num_splits: int,
use_tex_split_quantize: bool = True,
) -> None:
device = "cuda"
torch.manual_seed(seed)
n_iters = 50
split_sections = [M // num_splits] * num_splits
x_total = torch.randn((M, N), dtype=x_dtype, device=device) * 2 - 1
x_splits = torch.split(x_total, split_sections)
q_rn_list, s_rn_list, q_t_rn_list, s_t_rn_list = group_quantize_fp4(
x_total,
use_stochastic_rounding=False,
use_2D=use_2D,
use_RHT=use_RHT,
split_sections=split_sections,
use_tex_split_quantize=use_tex_split_quantize,
)
sr_n_iter_results = []
for i in range(n_iters):
q_sr_list, s_sr_list, q_t_sr_list, s_t_sr_list = group_quantize_fp4(
x_total,
use_stochastic_rounding=True,
use_2D=use_2D,
use_RHT=use_RHT,
split_sections=split_sections,
use_tex_split_quantize=use_tex_split_quantize,
)
sr_n_iter_results.append((q_sr_list, s_sr_list, q_t_sr_list, s_t_sr_list))
for i, x in enumerate(x_splits):
y = x.t().contiguous()
if use_RHT:
y = RHT(y)
amax = torch.max(torch.abs(x)).float()
# fetch q_rn, s_rn, q_t_rn, s_t_rn
q_rn = q_rn_list[i]
s_rn = s_rn_list[i]
q_t_rn = q_t_rn_list[i]
s_t_rn = s_t_rn_list[i]
dq_rn = dequantize_fp4(q_rn, s_rn, amax)
dq_t_rn = dequantize_fp4(q_t_rn, s_t_rn, amax)
error_rn = (dq_rn - x).float()
me_rn = torch.sqrt((error_rn * error_rn).mean())
error_t_rn = (dq_t_rn - y).float()
me_t_rn = torch.sqrt((error_t_rn * error_t_rn).mean())
sr_result = torch.zeros_like(x).float()
sr_t_result = torch.zeros_like(x).float().t().contiguous()
for iter_idx in range(n_iters):
result_sr = sr_n_iter_results[iter_idx]
q_sr = result_sr[0][i]
s_sr = result_sr[1][i]
q_t_sr = result_sr[2][i]
s_t_sr = result_sr[3][i]
dq_sr = dequantize_fp4(q_sr, s_sr, amax)
dq_t_sr = dequantize_fp4(q_t_sr, s_t_sr, amax)
sr_result += dq_sr.float()
sr_t_result += dq_t_sr.float()
# Get the mean result of the stochastic rounding
# It should be more accurate than the RN result
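# Stochastic rounding is unbiased (E[q(x)] = x), so the average of many independently
# rounded samples converges to x, whereas round-to-nearest carries a fixed per-value bias.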
sr_result /= n_iters
error_sr = (sr_result - x).float()
me_sr = torch.sqrt((error_sr * error_sr).mean())
sr_t_result /= n_iters
error_t_sr = (sr_t_result - y).float()
me_t_sr = torch.sqrt((error_t_sr * error_t_sr).mean())
print(f"RMSE SR: {me_sr:.3e} | RMSE RN: {me_rn:.3e}")
print(f"RMSE SR_t: {me_t_sr:.3e} | RMSE RN_t: {me_t_rn:.3e}")
assert me_sr < me_rn, "Stochastic rounding failed - error larger than the round to nearest."
assert (
me_t_sr < me_t_rn
), "Stochastic rounding failed - error larger than the round to nearest."
@pytest.mark.skipif(not recipe_available, reason=reason_for_no_recipe)
@pytest.mark.parametrize(
"M, N",
@@ -236,3 +395,39 @@ def test_quantization_block_tiling_versus_reference(
M=M,
N=N,
)
@pytest.mark.skipif(not recipe_available, reason=reason_for_no_recipe)
@pytest.mark.parametrize(
"M, N",
[
(8192, 8192),
(4096, 7168),
(16384, 2048),
],
)
@pytest.mark.parametrize("x_dtype", [torch.bfloat16], ids=str)
@pytest.mark.parametrize("use_2D", [False], ids=str)
@pytest.mark.parametrize("use_RHT", [True], ids=str)
@pytest.mark.parametrize("num_splits", [4, 8], ids=str)
@pytest.mark.parametrize("use_tex_split_quantize", [True, False], ids=str)
def test_group_stochastic_rounding_quantization_versus_reference(
x_dtype: torch.dtype,
use_2D: bool,
use_RHT: bool,
num_splits: int,
use_tex_split_quantize: bool,
M: int,
N: int,
) -> None:
if x_dtype == torch.float32 and use_RHT:
pytest.skip("RHT is only supported with bfloat16")
check_group_quantization_nvfp4_versus_reference(
x_dtype=x_dtype,
use_2D=use_2D,
use_RHT=use_RHT,
M=M,
N=N,
num_splits=num_splits,
use_tex_split_quantize=use_tex_split_quantize,
)
@@ -12,7 +12,10 @@ import torch
import torch.nn as nn
from torch.nn import Parameter
from transformer_engine.pytorch.quantization import FP8GlobalStateManager
from transformer_engine.pytorch.quantization import (
FP8GlobalStateManager,
get_align_size_for_quantization,
)
from transformer_engine.pytorch.utils import (
init_method_normal,
scaled_init_method_normal,
@@ -1829,9 +1832,7 @@ def _test_grouped_linear_accuracy(
if num_gemms > 1:
split_size = 1
if fp8:
split_size = 16
if recipe.mxfp8() or recipe.nvfp4():
split_size = 32
split_size = get_align_size_for_quantization(recipe)
m = config.max_seqlen_q // split_size
dist = torch.sort(torch.randint(0, m, (num_gemms - 2,))).values.tolist()
dist.append(dist[-1]) # Manually add a zero
@@ -2137,9 +2138,7 @@ def test_grouped_linear_accuracy_single_gemm(recipe):
def _test_padding_grouped_linear_accuracy(block, num_gemms, bs, dtype, config, recipe, fp8=False):
def _pad_tensor_for_fp8(hidden_states, tokens_per_expert):
align_size = 16
if recipe.mxfp8() or recipe.nvfp4():
align_size = 32
align_size = get_align_size_for_quantization(recipe)
padded_tokens_per_expert = [
(num_tokens + align_size - 1) // align_size * align_size
for num_tokens in tokens_per_expert
@@ -175,6 +175,7 @@ list(APPEND transformer_engine_cuda_arch_specific_sources
transpose/quantize_transpose_square_blockwise.cu
transpose/quantize_transpose_vector_blockwise_fp4.cu
hadamard_transform/hadamard_transform.cu
hadamard_transform/group_hadamard_transform.cu
hadamard_transform/hadamard_transform_cast_fusion.cu)
# Compiling the files with the worst compilation time first to hopefully overlap
/*************************************************************************
* Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
*
* See LICENSE for license information.
************************************************************************/
#include <cuda.h>
#include <cudaTypedefs.h>
#include <cuda_bf16.h>
#include <cuda_pipeline.h>
#include <cuda_runtime.h>
#include <transformer_engine/hadamard_transform.h>
#include <transformer_engine/multi_tensor.h>
#include <cuda/barrier>
#include "common/common.h"
#include "common/util/ptx.cuh"
#include "common/utils.cuh"
#include "hadamard_transform_utils.cuh"
namespace transformer_engine {
namespace {
constexpr int kMaxTensorsPerKernel = 64;  // Kernel args must be <4 KB; raise beyond 64 if needed
struct MultiAmaxArgs {
// (output) Amax buffer for pre-RHT amax buffer
void* output_pre_rht_amax_list[kMaxTensorsPerKernel];
// (output) Amax buffer for RHT identity amax buffer
void* output_identity_amax_list[kMaxTensorsPerKernel];
// (output) Amax buffer for RHT transpose amax buffer
void* output_transpose_amax_list[kMaxTensorsPerKernel];
// Prefix sum (with leading zero) of the split_sections of the input tensors
int split_sections_range[kMaxTensorsPerKernel + 1];
// Number of tensors (splits) being processed by kernel
int num_tensors;
};
constexpr int kThreadsPerWarp = 32;
template <typename IType, int kHadamardDimension, int BUFF_DIM_Y, int BUFF_DIM_X,
bool kReturnPreRhtAmax, bool kReturnIdentityAmax, bool kReturnTransposedAmax>
__device__ __forceinline__ void ComputeKernel(uint32_t b_frag_i[4], uint32_t b_frag_t[4],
IType* in_sh_ptr, uint32_t& local_pre_rht_amax_reg,
uint32_t& local_amax_reg,
uint32_t& local_amax_t_reg) {
uint32_t a_frag[4]; // A matrix fragment
uint32_t c_frag[4]; // Result fragment
int warp_id = threadIdx.x / kThreadsPerWarp;
int local_rank = (threadIdx.x % kThreadsPerWarp);
int ld_row_idx = local_rank % kHadamardDimension;
int ld_col_idx = local_rank / kHadamardDimension + warp_id * 2;
int swizzle_idx = swizzle_128B_atom_32B(ld_row_idx, ld_col_idx);
uint32_t temp_amax_reg;
uint32_t temp_amax_t_reg;
if (kReturnIdentityAmax) {
ldmatrix_x4_m8n8_shared_b16<false>(a_frag[0], a_frag[1], a_frag[2], a_frag[3],
reinterpret_cast<uint4*>(in_sh_ptr) + swizzle_idx);
mma_m16_n16_k16_b16_b16_b16_noacc<kReturnIdentityAmax>(
a_frag[0], a_frag[1], a_frag[2], a_frag[3], b_frag_i[0], b_frag_i[1], b_frag_i[2],
b_frag_i[3], c_frag[0], c_frag[1], c_frag[2], c_frag[3], temp_amax_reg);
asm volatile("max.xorsign.abs.bf16x2 %0, %1, %2;\n\t"
: "=r"(local_amax_reg)
: "r"(local_amax_reg), "r"(temp_amax_reg));
}
if (kReturnTransposedAmax) {
// TODO(Frank): This is not efficient, since we could directly load the
// matrix in transposed layout.
if (!kReturnIdentityAmax) {
ldmatrix_x4_m8n8_shared_b16<false>(a_frag[0], a_frag[1], a_frag[2], a_frag[3],
reinterpret_cast<uint4*>(in_sh_ptr) + swizzle_idx);
}
matrix_transpose_m8_n8_b16_inplace(a_frag[0]);
matrix_transpose_m8_n8_b16_inplace(a_frag[1]);
matrix_transpose_m8_n8_b16_inplace(a_frag[2]);
matrix_transpose_m8_n8_b16_inplace(a_frag[3]);
mma_m16_n16_k16_b16_b16_b16_noacc<kReturnTransposedAmax>(
a_frag[0], a_frag[2], a_frag[1], a_frag[3], b_frag_t[0], b_frag_t[1], b_frag_t[2],
b_frag_t[3], c_frag[0], c_frag[1], c_frag[2], c_frag[3], temp_amax_t_reg);
asm volatile("max.xorsign.abs.bf16x2 %0, %1, %2;\n\t"
: "=r"(local_amax_t_reg)
: "r"(local_amax_t_reg), "r"(temp_amax_t_reg));
}
if (kReturnPreRhtAmax) {
if (!kReturnIdentityAmax && !kReturnTransposedAmax) {
ldmatrix_x4_m8n8_shared_b16<false>(a_frag[0], a_frag[1], a_frag[2], a_frag[3],
reinterpret_cast<uint4*>(in_sh_ptr) + swizzle_idx);
}
asm volatile("max.xorsign.abs.bf16x2 %0, %1, %2;\n\t"
: "=r"(a_frag[0])
: "r"(a_frag[0]), "r"(a_frag[1]));
asm volatile("max.xorsign.abs.bf16x2 %0, %1, %2;\n\t"
: "=r"(a_frag[2])
: "r"(a_frag[2]), "r"(a_frag[3]));
asm volatile("max.xorsign.abs.bf16x2 %0, %1, %2;\n\t"
: "=r"(a_frag[0])
: "r"(a_frag[0]), "r"(a_frag[2]));
asm volatile("max.xorsign.abs.bf16x2 %0, %1, %2;\n\t"
: "=r"(local_pre_rht_amax_reg)
: "r"(a_frag[0]), "r"(local_pre_rht_amax_reg));
}
}
template <int kN>
__device__ __host__ constexpr int NextPowerOf2() {
static_assert(kN > 0, "kN must be > 0");
// Round up to the next power of 2 by counting leading zeros.
return 1 << (32 - __builtin_clz(kN - 1));
}
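// E.g. NextPowerOf2<4>() == 4 and NextPowerOf2<5>() == 8. Note that __builtin_clz(0) is
// undefined, so callers should pass kN >= 2 (the kernel below instantiates kNumWarps == 4).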
template <int kNumWarps, bool kReturnPreRhtAmax, bool kReturnIdentityAmax,
bool kReturnTransposedAmax>
__device__ __forceinline__ void ReduceMax(const float pre_rht_amax, const float identity_amax,
const float transpose_amax, float* staging_for_pre_rht,
float* staging_for_identity, float* staging_for_transpose,
float* output_pre_rht_amax_ptr,
float* output_identity_amax_ptr,
float* output_transpose_amax_ptr, const int warpid) {
// intra-warp reduction
constexpr int kWarpSize = 32;
int local_rank = threadIdx.x % 32;
float warp_pre_rht_amax = kReturnPreRhtAmax ? warp_reduce_max<kWarpSize>(pre_rht_amax) : 0.0f;
float warp_identity_amax = kReturnIdentityAmax ? warp_reduce_max<kWarpSize>(identity_amax) : 0.0f;
float warp_transpose_amax =
kReturnTransposedAmax ? warp_reduce_max<kWarpSize>(transpose_amax) : 0.0f;
// inter-warp reduction
if (threadIdx.x % 32 == 0) {
if (kReturnPreRhtAmax) {
staging_for_pre_rht[warpid] = warp_pre_rht_amax;
}
if (kReturnIdentityAmax) {
staging_for_identity[warpid] = warp_identity_amax;
}
if (kReturnTransposedAmax) {
staging_for_transpose[warpid] = warp_transpose_amax;
}
}
__syncthreads();
constexpr int kNumWarpsPow2 = NextPowerOf2<kNumWarps>();
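// Warps 0, 1, and 2 each finalize one amax (identity, transpose, pre-RHT), so the
// three cross-warp reductions and their global atomics can proceed in parallel.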
if (warpid == 0) {
if (kReturnIdentityAmax) {
float identity_accum = local_rank < kNumWarps ? staging_for_identity[local_rank] : 0.0f;
identity_accum = warp_reduce_max<kNumWarpsPow2>(identity_accum);
if (local_rank == 0) {
atomicMaxFloat(output_identity_amax_ptr, identity_accum);
}
}
}
if (warpid == 1) {
if (kReturnTransposedAmax) {
float transpose_accum = local_rank < kNumWarps ? staging_for_transpose[local_rank] : 0.0f;
transpose_accum = warp_reduce_max<kNumWarpsPow2>(transpose_accum);
if (local_rank == 0) {
atomicMaxFloat(output_transpose_amax_ptr, transpose_accum);
}
}
}
if (warpid == 2) {
if (kReturnPreRhtAmax) {
float pre_rht_accum = local_rank < kNumWarps ? staging_for_pre_rht[local_rank] : 0.0f;
pre_rht_accum = warp_reduce_max<kNumWarpsPow2>(pre_rht_accum);
if (local_rank == 0) {
atomicMaxFloat(output_pre_rht_amax_ptr, pre_rht_accum);
}
}
}
}
// args: the multi-tensor amax arguments
__global__ void MultiZeroAmaxKernel(MultiAmaxArgs args) {
int num_tensors = args.num_tensors;
int tid = blockIdx.x * blockDim.x + threadIdx.x;
int stride = blockDim.x * gridDim.x;
for (; tid < num_tensors; tid += stride) {
float* output_pre_rht_amax_ptr = static_cast<float*>(args.output_pre_rht_amax_list[tid]);
float* output_identity_amax_ptr = static_cast<float*>(args.output_identity_amax_list[tid]);
float* output_transpose_amax_ptr = static_cast<float*>(args.output_transpose_amax_list[tid]);
if (output_pre_rht_amax_ptr != nullptr) {
*output_pre_rht_amax_ptr = 0;
}
if (output_identity_amax_ptr != nullptr) {
*output_identity_amax_ptr = 0;
}
if (output_transpose_amax_ptr != nullptr) {
*output_transpose_amax_ptr = 0;
}
}
}
// args: the multi-tensor amax arguments
__global__ void MultiAmaxMemcpyD2DKernelPreRHT(MultiAmaxArgs args) {
int num_tensors = args.num_tensors;
int tid = blockIdx.x * blockDim.x + threadIdx.x;
int stride = blockDim.x * gridDim.x;
for (; tid < num_tensors; tid += stride) {
float* output_pre_rht_amax_ptr = static_cast<float*>(args.output_pre_rht_amax_list[tid]);
float* output_identity_amax_ptr = static_cast<float*>(args.output_identity_amax_list[tid]);
float* output_transpose_amax_ptr = static_cast<float*>(args.output_transpose_amax_list[tid]);
if (output_pre_rht_amax_ptr != nullptr) {
float pre_rht_amax = *output_pre_rht_amax_ptr;
if (output_identity_amax_ptr != nullptr) {
*output_identity_amax_ptr = pre_rht_amax;
}
if (output_transpose_amax_ptr != nullptr) {
*output_transpose_amax_ptr = pre_rht_amax;
}
}
}
}
template <typename IType, int kHadamardDimension, int CHUNK_DIM_Y, int CHUNK_DIM_X, int BUFF_DIM_Y,
int BUFF_DIM_X, int THREADS_PER_CHUNK, int THREADS_PER_Y, bool kReturnPreRhtAmax,
bool kReturnIdentityAmax, bool kReturnTransposedAmax>
__global__ void GroupHadamardAmaxTmaKernel(const __grid_constant__ CUtensorMap tensor_map_input,
const MultiAmaxArgs args, uint16_t random_sign_mask,
uint16_t random_sign_mask_t, uint64_t num_rows,
uint64_t row_length) {
#if (defined __CUDA_ARCH__) && (__CUDA_ARCH__ >= 1000)
float* output_pre_rht_amax_ptr;
float* output_identity_amax_ptr;
float* output_transpose_amax_ptr;
// calculate the global offset in Y direction to access the correct amax buffer
int global_offset_y = blockIdx.y * CHUNK_DIM_Y;
int tensor_id = 0;
while (args.split_sections_range[tensor_id + 1] <= global_offset_y) {
++tensor_id;
}
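// split_sections_range is an inclusive prefix sum with a leading zero, so this linear
// scan finds the first split whose end row exceeds this chunk's starting row. A chunk
// never straddles two splits because every split is a multiple of 64 rows, the chunk
// height used below.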
output_pre_rht_amax_ptr = static_cast<float*>(args.output_pre_rht_amax_list[tensor_id]);
output_identity_amax_ptr = static_cast<float*>(args.output_identity_amax_list[tensor_id]);
output_transpose_amax_ptr = static_cast<float*>(args.output_transpose_amax_list[tensor_id]);
static_assert(CHUNK_DIM_Y >= BUFF_DIM_Y && CHUNK_DIM_Y % BUFF_DIM_Y == 0);
static_assert(CHUNK_DIM_X >= BUFF_DIM_X && CHUNK_DIM_X % BUFF_DIM_X == 0);
constexpr size_t STAGES_Y = CHUNK_DIM_Y / BUFF_DIM_Y;
constexpr size_t STAGES_X = CHUNK_DIM_X / BUFF_DIM_X;
constexpr int kNumWarps = (THREADS_PER_CHUNK * THREADS_PER_Y) / kThreadsPerWarp;
const int input_block_offset_Y = blockIdx.y * CHUNK_DIM_Y;
const int input_block_offset_X = blockIdx.x * CHUNK_DIM_X;
extern __shared__ __align__(128) char dynamic_shmem[];
uintptr_t base_shmem_ptr = reinterpret_cast<uintptr_t>(dynamic_shmem);
// Manually align dynamic SHMEM to TMA requirements using padding;
// __align__(128) does not guarantee the pointer is actually aligned!
uint8_t* dshmem = reinterpret_cast<uint8_t*>((base_shmem_ptr + 127) & ~127ULL);
// The destination shared memory buffer of a bulk tensor operation should be 16-byte aligned
constexpr size_t in_buff_size = BUFF_DIM_X * BUFF_DIM_Y * sizeof(IType);
IType* in_sh_0 = reinterpret_cast<IType*>(dshmem);
dshmem += in_buff_size;
IType* in_sh_1 = reinterpret_cast<IType*>(dshmem);
dshmem += in_buff_size;
IType* in_shs[2] = {in_sh_0, in_sh_1};
constexpr int shmem_buff_size = BUFF_DIM_X * BUFF_DIM_Y * sizeof(IType);
const bool is_master_thread = (threadIdx.x == 0 && threadIdx.y == 0);
// Initialize shared memory barrier with the number of threads participating in the barrier.
#pragma nv_diag_suppress static_var_with_dynamic_init
uint64_t* mbar = reinterpret_cast<uint64_t*>(dshmem);
dshmem += sizeof(uint64_t) * (STAGES_X * STAGES_Y);
float* max_staging_identity = reinterpret_cast<float*>(dshmem);
dshmem += sizeof(float) * kNumWarps;
float* max_staging_transpose = reinterpret_cast<float*>(dshmem);
dshmem += sizeof(float) * kNumWarps;
float* max_staging_pre_rht = reinterpret_cast<float*>(dshmem);
dshmem += sizeof(float) * kNumWarps;
initialize_barriers<STAGES_X * STAGES_Y, THREADS_PER_CHUNK * THREADS_PER_Y>(mbar,
is_master_thread);
copy_2d_to_shared(in_shs[0], reinterpret_cast<const void*>(&tensor_map_input),
input_block_offset_X, input_block_offset_Y, shmem_buff_size, &mbar[0],
is_master_thread);
uint32_t had_frag_i[4];
uint32_t had_frag_t[4];
get_hadamard_matrix_fragment<kReturnIdentityAmax, kReturnTransposedAmax, false, false>(
had_frag_i, random_sign_mask, had_frag_t, random_sign_mask_t);
float local_pre_rht_amax = 0.0;
float local_amax = 0.0;
float local_amax_t = 0.0;
uint32_t local_pre_rht_amax_reg = *reinterpret_cast<uint32_t*>(&local_pre_rht_amax);
uint32_t local_amax_reg = *reinterpret_cast<uint32_t*>(&local_amax);
uint32_t local_amax_t_reg = *reinterpret_cast<uint32_t*>(&local_amax_t);
for (int stage_y = 0; stage_y < STAGES_Y; ++stage_y) {
for (int stage_x = 0; stage_x < STAGES_X; ++stage_x) {
int stage = STAGES_X * stage_y + stage_x;
const int next_stage = stage + 1;
const int next_stage_x = stage_x + 1 == STAGES_X ? 0 : stage_x + 1;
const int next_stage_y = stage_x + 1 == STAGES_X ? stage_y + 1 : stage_y;
if (next_stage < STAGES_X * STAGES_Y) {
const int input_global_offset_Y = input_block_offset_Y + next_stage_y * BUFF_DIM_Y;
const int input_global_offset_X = input_block_offset_X + next_stage_x * BUFF_DIM_X;
copy_2d_to_shared(in_shs[next_stage % 2], // ping-pong
reinterpret_cast<const void*>(&tensor_map_input), input_global_offset_X,
input_global_offset_Y, shmem_buff_size, &mbar[next_stage],
is_master_thread);
}
ptx::fence_proxy_async_shared_cta();
// Wait for the data to have arrived
ptx::mbarrier_wait_parity(&mbar[stage], 0);
const size_t compute_stage_x_num =
BUFF_DIM_X / (kHadamardDimension * (THREADS_PER_CHUNK / kThreadsPerWarp));
const size_t compute_stage_y_num = BUFF_DIM_Y / (kHadamardDimension * THREADS_PER_Y);
const size_t in_row_stride = BUFF_DIM_X;
IType* in_sh_ptr = in_shs[stage % 2];
#pragma unroll
for (size_t compute_stage_y = 0; compute_stage_y < compute_stage_y_num; compute_stage_y++) {
const int row_idx_offset = (compute_stage_y * kHadamardDimension * THREADS_PER_Y +
threadIdx.y * kHadamardDimension);
const int in_row_offset = row_idx_offset * in_row_stride;
#pragma unroll
for (size_t compute_stage_x = 0; compute_stage_x < compute_stage_x_num; compute_stage_x++) {
ComputeKernel<IType, kHadamardDimension, BUFF_DIM_Y, BUFF_DIM_X, kReturnPreRhtAmax,
kReturnIdentityAmax, kReturnTransposedAmax>(
had_frag_i, had_frag_t,
in_sh_ptr + in_row_offset +
(compute_stage_x * kHadamardDimension * (THREADS_PER_CHUNK / kThreadsPerWarp)),
local_pre_rht_amax_reg, local_amax_reg, local_amax_t_reg);
}
// Ensure all threads have finished their computation before new data overwrites the
// shared memory.
__syncthreads();
}
}
}
const int warpid = (threadIdx.x + threadIdx.y * blockDim.x) / kThreadsPerWarp;
if constexpr (kReturnPreRhtAmax) {
unpack_max_of_packed_bf16(local_pre_rht_amax_reg, local_pre_rht_amax);
}
if constexpr (kReturnIdentityAmax) {
unpack_max_of_packed_bf16(local_amax_reg, local_amax);
}
if constexpr (kReturnTransposedAmax) {
unpack_max_of_packed_bf16(local_amax_t_reg, local_amax_t);
}
ReduceMax<kNumWarps, kReturnPreRhtAmax, kReturnIdentityAmax, kReturnTransposedAmax>(
local_pre_rht_amax, local_amax, local_amax_t, max_staging_pre_rht, max_staging_identity,
max_staging_transpose, output_pre_rht_amax_ptr, output_identity_amax_ptr,
output_transpose_amax_ptr, warpid);
destroy_barriers<STAGES_X * STAGES_Y>(mbar, is_master_thread);
#else
NVTE_DEVICE_ERROR("Kernel is only supported on SM 10.0+.");
#endif // #if (defined __CUDA_ARCH__) && (__CUDA_ARCH__ >= 1000)
}
} // namespace
// broadcast_pre_rht_amax: when true, the Hadamard transform is disabled.
// If, in that case, the output amax buffers expect both amax_rowwise and amax_colwise,
// MultiAmaxMemcpyD2DKernelPreRHT is called to copy the pre-RHT amax values D2D.
void group_hadamard_transform_amax(const Tensor& input_, std::vector<Tensor*>& output_list,
const size_t* split_sections, size_t num_tensors,
uint16_t random_sign_mask, uint16_t random_sign_mask_t,
bool broadcast_pre_rht_amax, cudaStream_t stream) {
NVTE_API_CALL(group_hadamard_transform_amax);
#if CUDA_VERSION >= 12080
// Check input tensor
NVTE_CHECK(input_.scaling_mode == NVTE_DELAYED_TENSOR_SCALING,
"Input tensor must be BF16 tensor, but scaling mode is ",
to_string(input_.scaling_mode), ".");
NVTE_CHECK(input_.dtype() == transformer_engine::DType::kBFloat16,
"Input tensor must be BF16 tensor, but dtype is ", to_string(input_.dtype()), ".");
NVTE_CHECK(input_.dim() >= 2, "Input must be a 2D tensor.");
const SimpleTensor& input = input_.data;
// TODO: validate num_tensors and split_sections
// Assert that num_tensors does not exceed kMaxTensorsPerKernel;
// kMaxTensorsPerKernel (currently 64) can be raised if needed, and if the kernel args
// would exceed the 4 KB launch limit, multi-launch support will be added.
NVTE_CHECK(num_tensors <= kMaxTensorsPerKernel,
"Number of tensors should be less than or equal to ", kMaxTensorsPerKernel);
// check split_sections
// TODO: support m_splits_tensor for device initiated API
NVTE_CHECK(split_sections != nullptr, "split_sections should not be nullptr");
MultiAmaxArgs kernel_args;
kernel_args.num_tensors = 0;
kernel_args.split_sections_range[0] = 0;
bool all_return_pre_rht_amax = true;
bool all_return_identity_amax = true;
bool all_return_transposed_amax = true;
for (size_t i = 0; i < num_tensors; ++i) {
void* output_pre_rht_amax_ptr = output_list[i]->amax.dptr;
// disable RHT(x) for now, only RHT_T(x) should be used
void* output_identity_amax_ptr = nullptr;
void* output_transpose_amax_ptr = output_list[i]->columnwise_amax.dptr;
all_return_pre_rht_amax &= (output_pre_rht_amax_ptr != nullptr);
all_return_identity_amax &= (output_identity_amax_ptr != nullptr);
all_return_transposed_amax &= (output_transpose_amax_ptr != nullptr);
// sanity check: each split_sections entry must be a multiple of 64
NVTE_CHECK(split_sections[i] % 64 == 0, "component ", i,
           " of split_sections must be a multiple of 64");
// also skip adding this tensor to the kernel args if there are zero elements in this split
if (split_sections[i] == 0) {
continue;
}
// fill in kernel arguments
kernel_args.output_pre_rht_amax_list[kernel_args.num_tensors] = output_pre_rht_amax_ptr;
kernel_args.output_identity_amax_list[kernel_args.num_tensors] = output_identity_amax_ptr;
kernel_args.output_transpose_amax_list[kernel_args.num_tensors] = output_transpose_amax_ptr;
kernel_args.split_sections_range[kernel_args.num_tensors + 1] =
kernel_args.split_sections_range[kernel_args.num_tensors] + split_sections[i];
kernel_args.num_tensors++;
}
NVTE_CHECK(all_return_pre_rht_amax || all_return_identity_amax || all_return_transposed_amax,
"At least one of return_pre_rht_amax, return_identity_amax, or return_transposed_amax "
"must be true");
// all_return_identity_amax is not supported yet; raise an error if it is mistakenly enabled
NVTE_CHECK(!all_return_identity_amax,
"Currently RHT transform should only be applied to transposed input");
if (broadcast_pre_rht_amax) {
NVTE_CHECK(all_return_pre_rht_amax,
"broadcast_pre_rht_amax is only supported when we compute pre-RHT amax");
// if both all_return_identity_amax and all_return_transposed_amax are false, there is nothing to broadcast
broadcast_pre_rht_amax &= (all_return_identity_amax || all_return_transposed_amax);
}
// Zero out the multiple amax buffers if needed.
// Multi-launch is currently not supported when num_tensors is larger than kMaxTensorsPerKernel.
// Use one block of kMaxTensorsPerKernel threads, i.e. one thread per tensor.
dim3 block_setup_amax(kMaxTensorsPerKernel);
dim3 grid_setup_amax(1);
MultiZeroAmaxKernel<<<grid_setup_amax, block_setup_amax, 0, stream>>>(kernel_args);
NVTE_CHECK_CUDA(cudaGetLastError());
checkCuDriverContext(stream);
using IType = bf16;
const size_t ndim = input.shape.size();
const size_t row_length = input.shape[ndim - 1];
size_t num_rows = 1;
for (size_t i = 0; i < ndim - 1; ++i) {
num_rows *= input.shape[i];
}
constexpr int kHadamardDimension = 16;
NVTE_CHECK(row_length % kHadamardDimension == 0,
"row_length must be divisible by hadamard_dimension.");
NVTE_CHECK(num_rows % kHadamardDimension == 0,
"num_rows must be divisible by hadamard_dimension");
// four (1x4) 64x64 sub-tiles for ping-pong overlap
constexpr uint64_t kChunkBlockXSmall = 256;
constexpr uint64_t kChunkBlockYSmall = 64;
constexpr uint64_t kBuffDimX = 64;
constexpr uint64_t kBuffDimY = 64;
alignas(64) CUtensorMap tensor_map_input{};
create_2D_tensor_map(
/*tensorMap=*/tensor_map_input,
/*tensor=*/input,
/*globalY=*/num_rows,
/*globalX=*/row_length,
/*shmemY=*/kBuffDimY,
/*shmemX=*/kBuffDimX,
/*stride_elems=*/row_length,
/*offset_elems=*/0,
/*type_num_bits=*/sizeof(IType) * 8,
/*swizzle=*/CUtensorMapSwizzle::CU_TENSOR_MAP_SWIZZLE_128B_ATOM_32B);
constexpr uint64_t kThreadBlockX = 4;
constexpr uint64_t kThreadBlockY = 1;
constexpr uint64_t kNumWarps = kThreadBlockX * kThreadBlockY;
dim3 block(kThreadBlockX * kThreadsPerWarp, kThreadBlockY);
dim3 grid(DIVUP(row_length, kChunkBlockXSmall), DIVUP(num_rows, kChunkBlockYSmall));
TRANSFORMER_ENGINE_SWITCH_CONDITION(
(all_return_transposed_amax && !broadcast_pre_rht_amax), kReturnTransposedAmax,
TRANSFORMER_ENGINE_SWITCH_CONDITION(
(all_return_identity_amax && !broadcast_pre_rht_amax), kReturnIdentityAmax,
TRANSFORMER_ENGINE_SWITCH_CONDITION(
all_return_pre_rht_amax, kReturnPreRhtAmax,
// *2 for ping-pong
size_t in_sh_size = kBuffDimX * kBuffDimY * 2 * sizeof(IType);
size_t mbar_size = sizeof(uint64_t) * (kChunkBlockXSmall / kBuffDimX) *
(kChunkBlockYSmall / kBuffDimY);
size_t shmem_bytes = in_sh_size + mbar_size + kNumWarps * sizeof(float) * 3;
// Add padding in case shmem ptr is not aligned to 128 bytes.
shmem_bytes = (shmem_bytes + 128);
auto kernel = GroupHadamardAmaxTmaKernel<
IType, kHadamardDimension, kChunkBlockYSmall, kChunkBlockXSmall, kBuffDimY,
kBuffDimX, kThreadBlockX * kThreadsPerWarp, kThreadBlockY, kReturnPreRhtAmax,
kReturnIdentityAmax, kReturnTransposedAmax>;
cudaFuncSetAttribute(kernel, cudaFuncAttributeMaxDynamicSharedMemorySize,
shmem_bytes);
kernel<<<grid, block, shmem_bytes, stream>>>(tensor_map_input, kernel_args,
random_sign_mask, random_sign_mask_t,
num_rows, row_length);
if (broadcast_pre_rht_amax) {
MultiAmaxMemcpyD2DKernelPreRHT<<<grid_setup_amax, block_setup_amax, 0, stream>>>(
kernel_args);
})));
NVTE_CHECK_CUDA(cudaGetLastError());
#else
NVTE_ERROR("Hadamard transform requires CUDA 12.8+, but compile-time CUDA version is ",
CUDA_VERSION);
#endif // CUDA_VERSION >= 12080
}
} // namespace transformer_engine
// Naming convention: "Group" kernels here take a single contiguous, concatenated input,
// while "Multi" kernels process a list of pointers, like the zero-amax kernel.
// The group Hadamard transform API is unlike other multi-input & multi-output APIs:
// it takes a single input tensor and directly calculates amax, with an optional RHT.
// That is because the input tensor list can be assumed to be contiguous in memory,
// so the tensors are only split along dimension 0.
// The RHT is 16x16, so as long as each split of the input has a dimension-0 size that is
// a multiple of 16, the entire input can be treated as a single tensor.
// Although a multiple of 16 is mathematically sufficient for correctness, this kernel
// requires each split to be a multiple of 64 in dimension 0 for better performance.
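// Worked example (illustrative): with split_sections = {128, 0, 192} on a (320, K) input,
// outputs[0] receives the amax over rows [0, 128), outputs[2] over rows [128, 320), and
// the zero-sized split is skipped (its amax buffers keep their previous contents).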
void nvte_group_hadamard_transform_amax(const NVTETensor input, NVTETensor* outputs,
const size_t* split_sections, size_t num_tensors,
int random_sign_mask, int random_sign_mask_t,
cudaStream_t stream) {
NVTE_API_CALL(nvte_group_hadamard_transform_amax);
using namespace transformer_engine;
if (num_tensors == 0) {
return;
}
Tensor* input_tensor = convertNVTETensorCheck(input);
std::vector<Tensor*> output_list(num_tensors);
for (size_t i = 0; i < num_tensors; ++i) {
output_list[i] = convertNVTETensorCheck(outputs[i]);
}
// Call the group tensor Hadamard transform amax implementation.
group_hadamard_transform_amax(*input_tensor, output_list, split_sections, num_tensors,
static_cast<uint16_t>(random_sign_mask),
static_cast<uint16_t>(random_sign_mask_t), false, stream);
}
// Grouped-tensor amax without doing hadamard transform
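// With both sign masks set to 0 and broadcast_pre_rht_amax = true, the shared
// implementation computes a plain per-split amax and then launches
// MultiAmaxMemcpyD2DKernelPreRHT to copy it into whichever rowwise/columnwise
// amax buffers each output provides.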
void nvte_group_amax(const NVTETensor input, NVTETensor* outputs, const size_t* split_sections,
size_t num_tensors, cudaStream_t stream) {
NVTE_API_CALL(nvte_group_amax);
using namespace transformer_engine;
if (num_tensors == 0) {
return;
}
Tensor* input_tensor = convertNVTETensorCheck(input);
std::vector<Tensor*> output_list(num_tensors);
for (size_t i = 0; i < num_tensors; ++i) {
output_list[i] = convertNVTETensorCheck(outputs[i]);
}
group_hadamard_transform_amax(*input_tensor, output_list, split_sections, num_tensors, 0, 0, true,
stream);
}
@@ -16,185 +16,12 @@
#include "common/common.h"
#include "common/util/ptx.cuh"
#include "common/utils.cuh"
#include "hadamard_transform_utils.cuh"
namespace transformer_engine {
namespace {
constexpr int kThreadsPerWarp = 32;
constexpr float k16x16HadamardScale = 0.25f;
template <bool kTranspose>
__device__ __forceinline__ void ldmatrix_x4_m8n8_shared_b16(uint32_t& a0, uint32_t& a1,
uint32_t& a2, uint32_t& a3,
void* addr) {
auto smem_addr = static_cast<uint32_t>(__cvta_generic_to_shared(addr));
if constexpr (kTranspose) {
asm volatile("ldmatrix.sync.aligned.x4.trans.m8n8.shared.b16 {%0,%1,%2,%3}, [%4];\n"
: "=r"(a0), "=r"(a1), "=r"(a2), "=r"(a3)
: "r"(smem_addr));
} else {
asm volatile("ldmatrix.sync.aligned.x4.m8n8.shared.b16 {%0,%1,%2,%3}, [%4];\n"
: "=r"(a0), "=r"(a1), "=r"(a2), "=r"(a3)
: "r"(smem_addr));
}
}
template <bool kTranspose>
__device__ __forceinline__ void load_matrix_16x16_from_shared(uint32_t& a0, uint32_t& a1,
uint32_t& a2, uint32_t& a3,
void* addr, uint32_t stride) {
if constexpr (kTranspose) {
asm volatile(
"wmma.load.a.sync.aligned.col.m16n16k16.shared::cta.bf16 "
"{%0,%1,%2,%3}, [%4], %5;\n"
: "=r"(a0), "=r"(a1), "=r"(a2), "=r"(a3)
: "l"(addr), "r"(stride));
} else {
asm volatile(
"wmma.load.a.sync.aligned.row.m16n16k16.shared::cta.bf16 "
"{%0,%1,%2,%3}, [%4], %5;\n"
: "=r"(a0), "=r"(a1), "=r"(a2), "=r"(a3)
: "l"(addr), "r"(stride));
}
}
template <bool kTranspose>
__device__ __forceinline__ void store_matrix_16x16_to_global(uint32_t& a0, uint32_t& a1,
uint32_t& a2, uint32_t& a3, void* addr,
uint32_t stride) {
if constexpr (kTranspose) {
asm volatile("wmma.store.d.sync.aligned.col.m16n16k16.global.f16 [%0], {%1, %2, %3, %4}, %5;\n"
:
: "l"(addr), "r"(a0), "r"(a1), "r"(a2), "r"(a3), "r"(stride));
} else {
asm volatile("wmma.store.d.sync.aligned.row.m16n16k16.global.f16 [%0], {%1, %2, %3, %4}, %5;\n"
:
: "l"(addr), "r"(a0), "r"(a1), "r"(a2), "r"(a3), "r"(stride));
}
}
__device__ __forceinline__ void matrix_transpose_m8_n8_b16_inplace(uint32_t& a0) {
asm volatile(
"movmatrix.sync.aligned.m8n8.trans.b16 "
"%0, %1;\n\t"
: "=r"(a0)
: "r"(a0));
}
__device__ __forceinline__ void unpack_max_of_packed_bf16(uint32_t& packed_bf16, float& float_dst) {
__nv_bfloat162 bf16x2 = *reinterpret_cast<__nv_bfloat162*>(&packed_bf16);
float f_a = __bfloat162float(bf16x2.x);
float f_b = __bfloat162float(bf16x2.y);
asm volatile("max.xorsign.abs.f32 %0, %1, %2;\n\t" : "=f"(float_dst) : "f"(f_a), "f"(f_b));
float_dst = fabsf(float_dst);
}
template <bool kCalculateAmax>
__device__ __forceinline__ void mma_m16_n16_k16_b16_b16_b16_noacc(
uint32_t& a0, uint32_t& a1, uint32_t& a2, uint32_t& a3, uint32_t& b0, uint32_t& b1,
uint32_t& b2, uint32_t& b3, uint32_t& c0, uint32_t& c1, uint32_t& c2, uint32_t& c3,
uint32_t& amax_result) {
uint32_t zero = 0;
uint32_t temp0, temp1, temp2, temp3, temp4, temp5, temp6, temp7;
asm volatile(
"wmma.mma.sync.aligned.row.row.m16n16k16.f32.bf16.bf16.f32 \n"
"{%0, %1, %2, %3, %4, %5, %6, %7}, \n"
"{%8, %9, %10, %11}, \n"
"{%12, %13, %14, %15}, \n"
"{%16, %17, %18, %19, %20, %21, %22, %23};\n\t"
: "=r"(temp0), "=r"(temp1), "=r"(temp2), "=r"(temp3), "=r"(temp4), "=r"(temp5), "=r"(temp6),
"=r"(temp7)
: "r"(a0), "r"(a1), "r"(a2), "r"(a3), "r"(b0), "r"(b1), "r"(b2), "r"(b3), "r"(zero),
"r"(zero), "r"(zero), "r"(zero), "r"(zero), "r"(zero), "r"(zero), "r"(zero));
asm volatile("cvt.rn.bf16x2.f32 %0, %1, %2;\n\t" : "=r"(c0) : "r"(temp1), "r"(temp0));
asm volatile("cvt.rn.bf16x2.f32 %0, %1, %2;\n\t" : "=r"(c1) : "r"(temp3), "r"(temp2));
asm volatile("cvt.rn.bf16x2.f32 %0, %1, %2;\n\t" : "=r"(c2) : "r"(temp5), "r"(temp4));
asm volatile("cvt.rn.bf16x2.f32 %0, %1, %2;\n\t" : "=r"(c3) : "r"(temp7), "r"(temp6));
if constexpr (kCalculateAmax) {
uint32_t max_even;
uint32_t max_odd;
// Reduction tree to amax(abs(result)) into bf16x2 reg outparam.
asm volatile("max.xorsign.abs.bf16x2 %0, %1, %2;\n\t" : "=r"(max_even) : "r"(c0), "r"(c2));
asm volatile("max.xorsign.abs.bf16x2 %0, %1, %2;\n\t" : "=r"(max_odd) : "r"(c1), "r"(c3));
// N.B. mma is only called up to once per thread for identity and transpose respectively, so
// we don't have to accumulate into amax_result and can directly store into it.
asm volatile("max.xorsign.abs.bf16x2 %0, %1, %2;\n\t"
: "=r"(amax_result)
: "r"(max_even), "r"(max_odd));
}
}
template <bool kReturnIdentity, bool kReturnTransposed, bool kInverseHadamardIdentity,
bool kInverseHadamardTransposed>
__device__ __forceinline__ void get_hadamard_matrix_fragment(uint32_t* had_frag_i,
uint16_t random_sign_mask,
uint32_t* had_frag_t,
uint16_t random_sign_mask_t) {
int32_t tid = threadIdx.x % 32; // Local tid
float temp_i[2];
float temp_t[2];
#pragma unroll
for (int i = 0; i < 2; i++) {
// i is the vertical fragment index.
// For a 16x16 matrix matrix fragment, 4 threads fill a fragment of 8 BF16 vals.
uint32_t r = i * 8 + tid / 4;
#pragma unroll
for (int j = 0; j < 2; j++) {
#pragma unroll
for (int k = 0; k < 2; k++) {
// k is column position [0, 1] within a quad of 2 BF16s stored together in 32 bits.
// j is the column fragment idx selecting between even and odd fragments.
// j increments 8 columns by switching fragments.
uint32_t c = j * 8 + k + tid % 4 * 2;
// 1 -> -1.0f, 0 -> 1.0f
int32_t base_sign = __popc(r & c);
if constexpr (kReturnIdentity) {
int32_t sign_i;
// Because tensor cores want the dot product dimension,
// contiguous, the regular, non-inverse hadamard swaps
// signs of columns and rows for inverse. In a simple reference,
// x.reshape(-1, 16) @ sign @ H16, this would be opposite but
// (sign @ H16) is transposed in this fragment.
if constexpr (kInverseHadamardIdentity) {
sign_i = ((random_sign_mask >> r) ^ base_sign);
} else {
sign_i = ((random_sign_mask >> c) ^ base_sign);
}
temp_i[k] = copysignf(k16x16HadamardScale, __int_as_float(sign_i << 31));
}
if constexpr (kReturnTransposed) {
int32_t sign_t;
if constexpr (kInverseHadamardTransposed) {
sign_t = ((random_sign_mask_t >> r) ^ base_sign);
} else {
sign_t = ((random_sign_mask_t >> c) ^ base_sign);
}
temp_t[k] = copysignf(k16x16HadamardScale, __int_as_float(sign_t << 31));
}
}
if constexpr (kReturnIdentity) {
asm volatile("cvt.rn.bf16x2.f32 %0, %1, %2;\n\t"
: "=r"(had_frag_i[i * 2 + j])
: "f"(temp_i[1]), "f"(temp_i[0]));
}
if constexpr (kReturnTransposed) {
asm volatile("cvt.rn.bf16x2.f32 %0, %1, %2;\n\t"
: "=r"(had_frag_t[i * 2 + j])
: "f"(temp_t[1]), "f"(temp_t[0]));
}
}
}
}
__device__ __forceinline__ uint32_t swizzle_128B_atom_32B(uint32_t gmem_row_idx,
uint32_t gmem_col_idx) {
uint32_t smem_row_idx = gmem_row_idx;
uint32_t xor_factor = (smem_row_idx * 2) % 8;
uint32_t smem_col_idx = gmem_col_idx ^ xor_factor;
return smem_row_idx * 8 + smem_col_idx;
}
template <typename IType, int kHadamardDimension, int BUFF_DIM_Y, int BUFF_DIM_X,
bool kReturnPreRhtAmax, bool kReturnIdentityAmax, bool kReturnTransposedAmax>
/*************************************************************************
* Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
*
* See LICENSE for license information.
************************************************************************/
#ifndef TRANSFORMER_ENGINE_HADAMARD_TRANSFORM_UTILS_CUH_
#define TRANSFORMER_ENGINE_HADAMARD_TRANSFORM_UTILS_CUH_
#include <cuda.h>
#include <cudaTypedefs.h>
#include <cuda_bf16.h>
#include <cuda_runtime.h>
#include "common/common.h"
#include "common/util/ptx.cuh"
#include "common/utils.cuh"
namespace transformer_engine {
constexpr float k16x16HadamardScale = 0.25f;
template <bool kTranspose>
__device__ __forceinline__ void ldmatrix_x4_m8n8_shared_b16(uint32_t& a0, uint32_t& a1,
uint32_t& a2, uint32_t& a3,
void* addr) {
auto smem_addr = static_cast<uint32_t>(__cvta_generic_to_shared(addr));
if constexpr (kTranspose) {
asm volatile("ldmatrix.sync.aligned.x4.trans.m8n8.shared.b16 {%0,%1,%2,%3}, [%4];\n"
: "=r"(a0), "=r"(a1), "=r"(a2), "=r"(a3)
: "r"(smem_addr));
} else {
asm volatile("ldmatrix.sync.aligned.x4.m8n8.shared.b16 {%0,%1,%2,%3}, [%4];\n"
: "=r"(a0), "=r"(a1), "=r"(a2), "=r"(a3)
: "r"(smem_addr));
}
}
template <bool kTranspose>
__device__ __forceinline__ void load_matrix_16x16_from_shared(uint32_t& a0, uint32_t& a1,
uint32_t& a2, uint32_t& a3,
void* addr, uint32_t stride) {
if constexpr (kTranspose) {
asm volatile(
"wmma.load.a.sync.aligned.col.m16n16k16.shared::cta.bf16 "
"{%0,%1,%2,%3}, [%4], %5;\n"
: "=r"(a0), "=r"(a1), "=r"(a2), "=r"(a3)
: "l"(addr), "r"(stride));
} else {
asm volatile(
"wmma.load.a.sync.aligned.row.m16n16k16.shared::cta.bf16 "
"{%0,%1,%2,%3}, [%4], %5;\n"
: "=r"(a0), "=r"(a1), "=r"(a2), "=r"(a3)
: "l"(addr), "r"(stride));
}
}
template <bool kTranspose>
__device__ __forceinline__ void store_matrix_16x16_to_global(uint32_t& a0, uint32_t& a1,
uint32_t& a2, uint32_t& a3, void* addr,
uint32_t stride) {
if constexpr (kTranspose) {
asm volatile("wmma.store.d.sync.aligned.col.m16n16k16.global.f16 [%0], {%1, %2, %3, %4}, %5;\n"
:
: "l"(addr), "r"(a0), "r"(a1), "r"(a2), "r"(a3), "r"(stride));
} else {
asm volatile("wmma.store.d.sync.aligned.row.m16n16k16.global.f16 [%0], {%1, %2, %3, %4}, %5;\n"
:
: "l"(addr), "r"(a0), "r"(a1), "r"(a2), "r"(a3), "r"(stride));
}
}
__device__ __forceinline__ void matrix_transpose_m8_n8_b16_inplace(uint32_t& a0) {
asm volatile(
"movmatrix.sync.aligned.m8n8.trans.b16 "
"%0, %1;\n\t"
: "=r"(a0)
: "r"(a0));
}
__device__ __forceinline__ void unpack_max_of_packed_bf16(uint32_t& packed_bf16, float& float_dst) {
__nv_bfloat162 bf16x2 = *reinterpret_cast<__nv_bfloat162*>(&packed_bf16);
float f_a = __bfloat162float(bf16x2.x);
float f_b = __bfloat162float(bf16x2.y);
asm volatile("max.xorsign.abs.f32 %0, %1, %2;\n\t" : "=f"(float_dst) : "f"(f_a), "f"(f_b));
float_dst = fabsf(float_dst);
}
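In plain C++, the helper above reduces to the following sketch (illustrative, not part of the kernel). The PTX max.xorsign.abs returns max(|a|, |b|) with the XOR of the input signs applied, and the trailing fabsf discards that sign, so the net effect is an amax over the two packed BF16 values:

#include <algorithm>
#include <cmath>

// Host-side equivalent: amax of the two bf16 values packed into 32 bits.
float unpack_max_of_packed_bf16_ref(float f_a, float f_b) {
  return std::max(std::fabs(f_a), std::fabs(f_b));
}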
template <bool kCalculateAmax>
__device__ __forceinline__ void mma_m16_n16_k16_b16_b16_b16_noacc(
uint32_t& a0, uint32_t& a1, uint32_t& a2, uint32_t& a3, uint32_t& b0, uint32_t& b1,
uint32_t& b2, uint32_t& b3, uint32_t& c0, uint32_t& c1, uint32_t& c2, uint32_t& c3,
uint32_t& amax_result) {
uint32_t zero = 0;
uint32_t temp0, temp1, temp2, temp3, temp4, temp5, temp6, temp7;
asm volatile(
"wmma.mma.sync.aligned.row.row.m16n16k16.f32.bf16.bf16.f32 \n"
"{%0, %1, %2, %3, %4, %5, %6, %7}, \n"
"{%8, %9, %10, %11}, \n"
"{%12, %13, %14, %15}, \n"
"{%16, %17, %18, %19, %20, %21, %22, %23};\n\t"
: "=r"(temp0), "=r"(temp1), "=r"(temp2), "=r"(temp3), "=r"(temp4), "=r"(temp5), "=r"(temp6),
"=r"(temp7)
: "r"(a0), "r"(a1), "r"(a2), "r"(a3), "r"(b0), "r"(b1), "r"(b2), "r"(b3), "r"(zero),
"r"(zero), "r"(zero), "r"(zero), "r"(zero), "r"(zero), "r"(zero), "r"(zero));
asm volatile("cvt.rn.bf16x2.f32 %0, %1, %2;\n\t" : "=r"(c0) : "r"(temp1), "r"(temp0));
asm volatile("cvt.rn.bf16x2.f32 %0, %1, %2;\n\t" : "=r"(c1) : "r"(temp3), "r"(temp2));
asm volatile("cvt.rn.bf16x2.f32 %0, %1, %2;\n\t" : "=r"(c2) : "r"(temp5), "r"(temp4));
asm volatile("cvt.rn.bf16x2.f32 %0, %1, %2;\n\t" : "=r"(c3) : "r"(temp7), "r"(temp6));
if constexpr (kCalculateAmax) {
uint32_t max_even;
uint32_t max_odd;
// Reduction tree to amax(abs(result)) into bf16x2 reg outparam.
asm volatile("max.xorsign.abs.bf16x2 %0, %1, %2;\n\t" : "=r"(max_even) : "r"(c0), "r"(c2));
asm volatile("max.xorsign.abs.bf16x2 %0, %1, %2;\n\t" : "=r"(max_odd) : "r"(c1), "r"(c3));
// N.B. The MMA is called at most once per thread for the identity and
// transposed paths respectively, so we can store directly into
// amax_result without accumulating into it.
asm volatile("max.xorsign.abs.bf16x2 %0, %1, %2;\n\t"
: "=r"(amax_result)
: "r"(max_even), "r"(max_odd));
}
}
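For reference, the per-thread amax this reduction tree produces is equivalent to the host-side sketch below over the eight FP32 accumulators (modulo the BF16 rounding applied by the cvt instructions; illustrative, not part of the kernel):

#include <algorithm>
#include <cmath>

// Host-side equivalent of the bf16x2 reduction tree: amax over the eight
// accumulator values held by one thread after the MMA.
float mma_amax_ref(const float c[8]) {
  float amax = 0.0f;
  for (int i = 0; i < 8; ++i) {
    amax = std::max(amax, std::fabs(c[i]));
  }
  return amax;
}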
template <bool kReturnIdentity, bool kReturnTransposed, bool kInverseHadamardIdentity,
bool kInverseHadamardTransposed>
__device__ __forceinline__ void get_hadamard_matrix_fragment(uint32_t* had_frag_i,
uint16_t random_sign_mask,
uint32_t* had_frag_t,
uint16_t random_sign_mask_t) {
int32_t tid = threadIdx.x % 32; // Local tid
float temp_i[2];
float temp_t[2];
#pragma unroll
for (int i = 0; i < 2; i++) {
// i is the vertical fragment index.
// For a 16x16 matrix fragment, 4 threads fill a fragment of 8 BF16 vals.
uint32_t r = i * 8 + tid / 4;
#pragma unroll
for (int j = 0; j < 2; j++) {
#pragma unroll
for (int k = 0; k < 2; k++) {
// k is the column position [0, 1] within a pair of 2 BF16s stored together
// in 32 bits. j is the column fragment idx selecting between even and odd
// fragments; incrementing j advances 8 columns by switching fragments.
uint32_t c = j * 8 + k + tid % 4 * 2;
// 1 -> -1.0f, 0 -> 1.0f
int32_t base_sign = __popc(r & c);
if constexpr (kReturnIdentity) {
int32_t sign_i;
// Because tensor cores want the dot-product dimension contiguous, the
// random signs are applied along rows for the inverse Hadamard and along
// columns for the regular (non-inverse) one. In a simple reference,
// x.reshape(-1, 16) @ sign @ H16, it would be the opposite, but
// (sign @ H16) is transposed in this fragment.
if constexpr (kInverseHadamardIdentity) {
sign_i = ((random_sign_mask >> r) ^ base_sign);
} else {
sign_i = ((random_sign_mask >> c) ^ base_sign);
}
temp_i[k] = copysignf(k16x16HadamardScale, __int_as_float(sign_i << 31));
}
if constexpr (kReturnTransposed) {
int32_t sign_t;
if constexpr (kInverseHadamardTransposed) {
sign_t = ((random_sign_mask_t >> r) ^ base_sign);
} else {
sign_t = ((random_sign_mask_t >> c) ^ base_sign);
}
temp_t[k] = copysignf(k16x16HadamardScale, __int_as_float(sign_t << 31));
}
}
if constexpr (kReturnIdentity) {
asm volatile("cvt.rn.bf16x2.f32 %0, %1, %2;\n\t"
: "=r"(had_frag_i[i * 2 + j])
: "f"(temp_i[1]), "f"(temp_i[0]));
}
if constexpr (kReturnTransposed) {
asm volatile("cvt.rn.bf16x2.f32 %0, %1, %2;\n\t"
: "=r"(had_frag_t[i * 2 + j])
: "f"(temp_t[1]), "f"(temp_t[0]));
}
}
}
}
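A host-side reference for the sign logic may be useful (a sketch, not part of the kernel): entry (r, c) of the randomized 16x16 Hadamard fragment is k16x16HadamardScale * (-1)^popcount(r & c), with the mask bit of column c (regular) or row r (inverse) flipping the sign, matching the copysignf trick above:

#include <bitset>
#include <cstdint>

// Reference for one entry of the randomized 16x16 Hadamard fragment.
// `inverse` selects whether the random sign is applied along rows (inverse
// RHT) or columns (regular RHT), mirroring the kInverseHadamard* branches.
float hadamard16_entry(uint32_t r, uint32_t c, uint16_t mask, bool inverse) {
  const int base_sign = std::bitset<16>(r & c).count() & 1;  // parity of popcount
  const int rand_bit = (mask >> (inverse ? r : c)) & 1;      // random sign flip
  return (base_sign ^ rand_bit) ? -0.25f : 0.25f;            // k16x16HadamardScale
}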
__device__ __forceinline__ uint32_t swizzle_128B_atom_32B(uint32_t gmem_row_idx,
uint32_t gmem_col_idx) {
uint32_t smem_row_idx = gmem_row_idx;
uint32_t xor_factor = (smem_row_idx * 2) % 8;
uint32_t smem_col_idx = gmem_col_idx ^ xor_factor;
return smem_row_idx * 8 + smem_col_idx;
}
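A quick host-side sketch (illustrative) shows the access pattern this swizzle produces: within each shared-memory row of eight 32B atoms, the column index is XORed with (row * 2) % 8, so the same logical column lands in different banks across consecutive rows:

#include <cstdint>
#include <cstdio>

int main() {
  for (uint32_t row = 0; row < 4; ++row) {
    for (uint32_t col = 0; col < 8; ++col) {
      const uint32_t linear = row * 8 + (col ^ ((row * 2) % 8));
      printf("%2u ", linear);  // e.g. row 1 prints: 10 11  8  9 14 15 12 13
    }
    printf("\n");
  }
  return 0;
}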
} // namespace transformer_engine
#endif // TRANSFORMER_ENGINE_HADAMARD_TRANSFORM_UTILS_CUH_
......@@ -61,6 +61,31 @@ void nvte_hadamard_transform_cast_fusion_columnwise(const NVTETensor input, NVTE
const NVTEQuantizationConfig quant_config,
cudaStream_t stream);
/*! \brief Split a tensor along dimension 0 and compute RHT amaxes for each split.
*
* This function is experimental and the API is not stable.
*
* This is intended for quantizing to NVFP4 with random Hadamard
* transforms (RHT). For each tensor split, compute the maximum
* absolute value (amax) and populate the row-wise amax of the
* corresponding output tensor. Also, compute the amax after a
* transposed RHT and populate the column-wise amax of the
* corresponding output tensor.
*
* \param[in] input Input tensor.
* \param[in,out] outputs Array of NVFP4 output tensors. Only the row-wise and
* column-wise amaxes are updated.
* \param[in] split_sections Size of each tensor split along dimension 0.
* \param[in] num_tensors Number of tensor splits.
* \param[in] random_sign_mask 16-bit sign mask for RHT.
* \param[in] random_sign_mask_t 16-bit sign mask for transposed RHT.
* \param[in] stream CUDA stream used for the operation.
*/
void nvte_group_hadamard_transform_amax(const NVTETensor input, NVTETensor* outputs,
const size_t* split_sections, size_t num_tensors,
int random_sign_mask, int random_sign_mask_t,
cudaStream_t stream);
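A minimal usage sketch may help (hypothetical handle names and sign masks; the tensors, stream, and this declaring header are assumed to be set up elsewhere):

// Split `input` into sections of 1024, 2048, and 512 rows along dim 0 and
// fill the row-wise and column-wise amaxes of the three NVFP4 outputs.
void group_rht_amax_example(const NVTETensor input, NVTETensor* outputs,
                            cudaStream_t stream) {
  const size_t split_sections[] = {1024, 2048, 512};
  nvte_group_hadamard_transform_amax(input, outputs, split_sections,
                                     /*num_tensors=*/3,
                                     /*random_sign_mask=*/0x9A3B,
                                     /*random_sign_mask_t=*/0x51C7, stream);
}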
#ifdef __cplusplus
} // extern "C"
#endif
......
......@@ -265,6 +265,22 @@ void nvte_multi_tensor_compute_scale_and_scale_inv_cuda(int chunk_size, NVTETens
float max_fp8, int force_pow_2_scales,
float epsilon, cudaStream_t stream);
/*! \brief Split a tensor along dimension 0 and compute the amax for each split.
*
* This function is experimental and the API is not stable.
*
* For each tensor split, compute the maximum absolute value (amax)
* and populate the amax of the corresponding output tensor.
*
* \param[in] input Input tensor.
* \param[in,out] outputs Array of output tensors. Only the amax is updated.
* \param[in] split_sections Size of each tensor split along dimension 0.
* \param[in] num_tensors Number of tensor splits.
* \param[in] stream CUDA stream used for the operation.
*/
void nvte_group_amax(const NVTETensor input, NVTETensor *outputs, const size_t *split_sections,
size_t num_tensors, cudaStream_t stream);
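A minimal usage sketch, analogous to the RHT variant above (hypothetical handle names; tensors and stream created elsewhere):

// Fill the amax of each output tensor from the corresponding 1024-, 2048-,
// and 512-row section of `input` along dim 0.
void group_amax_example(const NVTETensor input, NVTETensor* outputs,
                        cudaStream_t stream) {
  const size_t split_sections[] = {1024, 2048, 512};
  nvte_group_amax(input, outputs, split_sections, /*num_tensors=*/3, stream);
}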
#ifdef __cplusplus
} // extern "C"
#endif
......
......@@ -253,7 +253,7 @@ std::vector<py::object> multi_tensor_quantize(const std::vector<at::Tensor> &ten
std::vector<py::handle> quantizer_list);
std::vector<py::object> split_quantize(const at::Tensor &tensor,
const std::vector<int> &split_sections,
const std::vector<size_t> &split_sections,
std::vector<py::handle> quantizer_list);
/***************************************************************************************************
......
......@@ -6,6 +6,7 @@
#include "transformer_engine/cast.h"
#include <algorithm>
#include <cstdint>
#include <memory>
#include <optional>
......@@ -494,13 +495,15 @@ std::tuple<std::vector<py::object>, std::vector<TensorWrapper>> bulk_allocate_mx
// allocate fp4 data, fp8 scalings, and amax values
// layout: [fp4_data0, ..., fp4_dataN, fp8_scaling0, ..., fp8_scalingN, amax0, ..., amaxN]
// amax buffer will be zeroed out by later amax kernels, so we can use empty to allocate
std::tuple<std::vector<py::object>, std::vector<TensorWrapper>> bulk_allocate_nvfp4_tensors(
std::tuple<std::vector<py::object>, std::vector<TensorWrapper>, bool> bulk_allocate_nvfp4_tensors(
std::vector<std::vector<size_t>> &shape_list, std::vector<py::handle> &quantizer_py_list,
std::vector<NVFP4Quantizer *> &quantizer_cpp_list) {
init_extension();
std::tuple<std::vector<py::object>, std::vector<TensorWrapper>> retval;
std::tuple<std::vector<py::object>, std::vector<TensorWrapper>, bool> retval;
auto &tensor_py_list = std::get<0>(retval);
auto &tensor_cpp_list = std::get<1>(retval);
auto &contiguous_data_and_scale = std::get<2>(retval);
contiguous_data_and_scale = true;
// Number of tensors
const size_t num_tensors = shape_list.size();
......@@ -555,22 +558,29 @@ std::tuple<std::vector<py::object>, std::vector<TensorWrapper>> bulk_allocate_nv
size_t buffer_size = 0;
std::vector<size_t> data_offsets, scale_offsets, amax_offsets;
for (size_t i = 0; i < num_tensors; ++i) {
buffer_size = roundup(buffer_size, 256); // align to 256B
data_offsets.push_back(buffer_size);
// Store ceil(product/2) bytes for fp4 (since each element is 4 bits = 0.5 bytes).
// Integer arithmetic: ceil(product / 2) == (product + 1) / 2.
buffer_size += (product(rowwise_data_shapes[i]) + 1) / 2;
// FP4 data is aligned to 256B
const auto offset = roundup(buffer_size, 256);
if (offset != buffer_size) {
contiguous_data_and_scale = false;
}
data_offsets.push_back(offset);
buffer_size = offset + (product(rowwise_data_shapes[i]) + 1) / 2;
}
for (size_t i = 0; i < num_tensors; ++i) {
buffer_size = roundup(buffer_size, 16); // align to 16B
scale_offsets.push_back(buffer_size);
buffer_size += product(rowwise_scale_shapes[i]) * scale_elem_size;
// Scales are aligned to 16B
const auto offset = roundup(buffer_size, 16);
if (offset != buffer_size) {
contiguous_data_and_scale = false;
}
scale_offsets.push_back(offset);
buffer_size = offset + product(rowwise_scale_shapes[i]) * scale_elem_size;
}
for (size_t i = 0; i < num_tensors; ++i) {
buffer_size = roundup(buffer_size, 16); // align to 16B
amax_offsets.push_back(buffer_size);
// amax is scalar in fp32, 4 bytes each
buffer_size += 4;
// Amaxes (FP32) are aligned to 16B
// Note: Multi-quantize kernel does not require contiguous amaxes.
const auto offset = roundup(buffer_size, 16);
amax_offsets.push_back(offset);
buffer_size = offset + 4;
}
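To illustrate the alignment bookkeeping above, here is a small sketch with two hypothetical row-wise tensors of 1000 and 2048 FP4 elements; the padding from byte 500 up to byte 512 is exactly the kind of gap that clears contiguous_data_and_scale:

#include <cstddef>
#include <cstdio>

// Same contract as the roundup helper used above.
static size_t roundup(size_t x, size_t align) {
  return (x + align - 1) / align * align;
}

int main() {
  const size_t numels[] = {1000, 2048};
  size_t buffer_size = 0;
  for (const size_t n : numels) {
    const size_t offset = roundup(buffer_size, 256);  // 256B-aligned FP4 data
    const size_t nbytes = (n + 1) / 2;                // two FP4 values per byte
    printf("offset %zu, bytes %zu\n", offset, nbytes);  // (0, 500), then (512, 1024)
    buffer_size = offset + nbytes;
  }
  // The 12-byte gap between 500 and 512 breaks data/scale contiguity.
  return 0;
}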
// Allocate full buffer
......@@ -584,7 +594,7 @@ std::tuple<std::vector<py::object>, std::vector<TensorWrapper>> bulk_allocate_nv
rowwise_scale_list.emplace_back(
make_torch_view(buffer, rowwise_scale_shapes[i], scale_offsets[i], torch::kUInt8));
amax_rowwise_list.emplace_back(
make_torch_view(buffer, std::vector<size_t>{1}, amax_offsets[i], torch::kUInt8));
make_torch_view(buffer, std::vector<size_t>{1}, amax_offsets[i], torch::kFloat32));
}
}
......@@ -610,22 +620,29 @@ std::tuple<std::vector<py::object>, std::vector<TensorWrapper>> bulk_allocate_nv
size_t buffer_size = 0;
std::vector<size_t> data_offsets, scale_offsets, amax_offsets;
for (size_t i = 0; i < num_tensors; ++i) {
buffer_size = roundup(buffer_size, 256); // align to 256B
data_offsets.push_back(buffer_size);
// Store ceil(product/2) bytes for fp4 (since each element is 4 bits = 0.5 bytes).
// Integer arithmetic: ceil(product / 2) == (product + 1) / 2.
buffer_size += (product(columnwise_data_shapes[i]) + 1) / 2;
// FP4 data is aligned to 256B
const auto offset = roundup(buffer_size, 256);
if (offset != buffer_size) {
contiguous_data_and_scale = false;
}
data_offsets.push_back(offset);
buffer_size = offset + (product(columnwise_data_shapes[i]) + 1) / 2;
}
for (size_t i = 0; i < num_tensors; ++i) {
buffer_size = roundup(buffer_size, 16); // align to 16B
scale_offsets.push_back(buffer_size);
buffer_size += product(columnwise_scale_shapes[i]) * scale_elem_size;
// Scales are aligned to 16B
const auto offset = roundup(buffer_size, 16);
if (offset != buffer_size) {
contiguous_data_and_scale = false;
}
scale_offsets.push_back(offset);
buffer_size = offset + product(columnwise_scale_shapes[i]) * scale_elem_size;
}
for (size_t i = 0; i < num_tensors; ++i) {
buffer_size = roundup(buffer_size, 16); // align to 16B
amax_offsets.push_back(buffer_size);
// amax is scalar in fp32, 4 bytes each
buffer_size += 4;
// Amaxes (FP32) are aligned to 16B
// Note: Multi-quantize kernel does not require contiguous amaxes.
const auto offset = roundup(buffer_size, 16);
amax_offsets.push_back(offset);
buffer_size = offset + 4;
}
// Allocate full buffer
......@@ -639,7 +656,7 @@ std::tuple<std::vector<py::object>, std::vector<TensorWrapper>> bulk_allocate_nv
columnwise_scale_list.emplace_back(
make_torch_view(buffer, columnwise_scale_shapes[i], scale_offsets[i], torch::kUInt8));
amax_columnwise_list.emplace_back(
make_torch_view(buffer, std::vector<size_t>{1}, amax_offsets[i], torch::kUInt8));
make_torch_view(buffer, std::vector<size_t>{1}, amax_offsets[i], torch::kFloat32));
}
}
......@@ -692,10 +709,209 @@ std::tuple<std::vector<py::object>, std::vector<TensorWrapper>> bulk_allocate_nv
return retval;
}
void split_quantize_nvfp4_impl(const TensorWrapper &input,
const std::vector<TensorWrapper> &input_list,
std::vector<TensorWrapper> &output_list,
const std::vector<size_t> &split_sections,
const std::vector<NVFP4Quantizer *> &quantizers) {
// Check tensor lists
const size_t num_tensors = split_sections.size();
NVTE_CHECK(input_list.size() == num_tensors, "Expected ", num_tensors, " input tensors, but got ",
input_list.size(), ".");
NVTE_CHECK(output_list.size() == num_tensors, "Expected ", num_tensors,
" output tensors, but got ", output_list.size(), ".");
NVTE_CHECK(quantizers.size() == num_tensors, "Expected ", num_tensors,
" NVFP4 quantizers, but got ", quantizers.size(), ".");
// Trivial cases
if (num_tensors == 0) {
return;
}
if (input.numel() == 0) {
for (const auto &tensor : input_list) {
NVTE_CHECK(tensor.numel() == 0,
"Input tensor has zero elements, but a split has non-zero elements");
}
return;
}
// Assume all quantizers have identical config
const auto &quantizer = *quantizers.front();
NVTE_CHECK(!quantizer.with_2d_quantization,
"NVFP4 split-quantize does not support 2D quantization");
NVTE_CHECK(!quantizer.with_amax_reduction,
"NVFP4 split-quantize does not support amax reduction");
// Check input tensor shape
const size_t input_last_dim = input.ndim() > 0 ? input.size(input.ndim() - 1) : 1;
NVTE_CHECK(input_last_dim % 128 == 0,
"NVFP4 split-quantize requires the inner dimension to be a multiple of 128.");
// CUDA stream
auto stream = at::cuda::getCurrentCUDAStream();
// Objects for TE C API
std::vector<NVTETensor> nvte_tensor_input_list;
std::vector<NVTETensor> nvte_tensor_output_list;
std::vector<QuantizationConfigWrapper> quant_config_list;
for (size_t i = 0; i < num_tensors; ++i) {
nvte_tensor_input_list.push_back(input_list[i].data());
nvte_tensor_output_list.push_back(output_list[i].data());
quant_config_list.emplace_back(QuantizationConfigWrapper());
}
// Stochastic rounding
std::vector<TensorWrapper> te_rng_state_list;
at::Tensor rng_states_tensor;
if (quantizer.stochastic_rounding) {
// TODO(zhongbo): replace this loop with a single call that generates all
// RNG states at once (rng_elts_per_thread = 1024 * num_tensors); switch to
// the bulk RNG-state generation API once grouped quantize is available.
const size_t rng_elts_per_thread = 1024; // Wild guess, probably can be tightened
auto opts = at::TensorOptions().dtype(torch::kInt64).device(torch::kCUDA);
rng_states_tensor = torch::empty({static_cast<int64_t>(2 * num_tensors)}, opts);
for (size_t i = 0; i < num_tensors; ++i) {
auto gen = at::get_generator_or_default<at::CUDAGeneratorImpl>(
std::nullopt, at::cuda::detail::getDefaultCUDAGenerator());
at::PhiloxCudaState philox_args = init_philox_state(gen, rng_elts_per_thread);
int64_t *rng_state_ptr = static_cast<int64_t *>(rng_states_tensor.data_ptr()) + i * 2;
philox_unpack(philox_args, rng_state_ptr);
te_rng_state_list.push_back(makeTransformerEngineTensor(
static_cast<void *>(rng_state_ptr), std::vector<size_t>{2}, DType::kInt64));
quant_config_list[i].set_rng_state(te_rng_state_list[i].data());
quant_config_list[i].set_stochastic_rounding(true);
}
}
// Perform multi-tensor quantization
if (quantizer.with_rht) { // Quantize row-wise data, RHT+quantize column-wise data
// Check that config is supported
NVTE_CHECK(input.dtype() == DType::kBFloat16, "RHT is only supported for bfloat16 input");
// Compute amaxes
if (quantizer.with_post_rht_amax) {
// We need:
// 1. Rowwise amax = amax for input
// 2. Columnwise amax = amax for RHT(input.t)
NVTE_SCOPED_GIL_RELEASE({
nvte_group_hadamard_transform_amax(
input.data(), reinterpret_cast<NVTETensor *>(nvte_tensor_output_list.data()),
split_sections.data(), num_tensors, 0, quantizer.rht_matrix_random_sign_mask_t, stream);
});
} else {
// RHT is enabled, but amax is pre-RHT amax
NVTE_ERROR("NVFP4 split-quantize does not yet support pre-RHT amax");
}
// Check that RHT matrix is available
NVTE_CHECK(quantizer.rht_matrix.defined() && quantizer.rht_matrix.numel() > 0,
"RHT matrix is not available.");
auto rht_matrix_nvte = makeTransformerEngineTensor(quantizer.rht_matrix);
// Quantize tensors individually
NVTE_SCOPED_GIL_RELEASE({
for (size_t i = 0; i < num_tensors; i++) {
if (input_list[i].numel() == 0) {
continue; // Skip tensors with no elements
}
// Direct NVFP4 quantization for row-wise data
if (quantizer.rowwise_usage) {
auto out_rowwise_data = output_list[i].get_rowwise_data();
auto out_rowwise_scale_inv = output_list[i].get_rowwise_scale_inv();
auto out_rowwise_amax = output_list[i].get_amax();
TensorWrapper out_rowwise(output_list[i].scaling_mode());
out_rowwise.set_rowwise_data(out_rowwise_data.data_ptr,
static_cast<DType>(out_rowwise_data.dtype),
out_rowwise_data.shape);
out_rowwise.set_rowwise_scale_inv(out_rowwise_scale_inv.data_ptr,
static_cast<DType>(out_rowwise_scale_inv.dtype),
out_rowwise_scale_inv.shape);
out_rowwise.set_amax(out_rowwise_amax.data_ptr,
static_cast<DType>(out_rowwise_amax.dtype), out_rowwise_amax.shape);
nvte_quantize_v2(input_list[i].data(), out_rowwise.data(), quant_config_list[i], stream);
}
// RHT + NVFP4 quantize for column-wise data
if (quantizer.columnwise_usage) {
// Get the output column-wise data, scale_inv, and amax
auto out_columnwise_data = output_list[i].get_columnwise_data();
auto out_columnwise_scale_inv = output_list[i].get_columnwise_scale_inv();
auto out_columnwise_amax = output_list[i].get_columnwise_amax();
// Flatten column-wise data to 2D
auto colwise_data_shape = out_columnwise_data.shape;
std::vector<size_t> colwise_data_shape_2d;
colwise_data_shape_2d.push_back(colwise_data_shape.data[0]);
size_t last_dim = 1;
for (size_t i = 1; i < colwise_data_shape.ndim; ++i) {
last_dim *= colwise_data_shape.data[i];
}
colwise_data_shape_2d.push_back(last_dim);
// Create a wrapper that presents the column-wise output as a row-wise
// output: the fused kernel below already produces the RHT result in the
// transposed layout, so a row-wise quantization suffices to generate the
// column-wise output.
TensorWrapper out_transpose(output_list[i].scaling_mode());
out_transpose.set_rowwise_data(out_columnwise_data.data_ptr,
static_cast<DType>(out_columnwise_data.dtype),
colwise_data_shape_2d);
out_transpose.set_rowwise_scale_inv(out_columnwise_scale_inv.data_ptr,
static_cast<DType>(out_columnwise_scale_inv.dtype),
out_columnwise_scale_inv.shape);
out_transpose.set_amax(out_columnwise_amax.data_ptr,
static_cast<DType>(out_columnwise_amax.dtype),
out_columnwise_amax.shape);
// RHT + NVFP4 quantize kernel
nvte_hadamard_transform_cast_fusion_columnwise(input_list[i].data(), out_transpose.data(),
rht_matrix_nvte.data(),
quant_config_list[i], stream);
}
}
});
} else { // NVFP4 quantize
// We need:
// 1. Rowwise amax = amax for input
// 2. Columnwise amax = amax for input too
// Columnwise amax will be filled with a fused D2D copy from rowwise amax.
// Note that the grouped amax API expects the rowwise amax pointer to be
// non-null, so we set the pointer accordingly to make colwise-only
// quantization work.
std::vector<void *> orig_amax_ptr_list;
for (size_t i = 0; i < num_tensors; i++) {
auto rowwise_amax_ptr = output_list[i].get_amax().data_ptr;
orig_amax_ptr_list.push_back(rowwise_amax_ptr);
auto columnwise_amax_ptr = output_list[i].get_columnwise_amax().data_ptr;
void *amax_ptr = rowwise_amax_ptr != nullptr ? rowwise_amax_ptr : columnwise_amax_ptr;
NVTE_CHECK(amax_ptr != nullptr, "Could not find amax pointer");
output_list[i].set_amax(amax_ptr, DType::kFloat32, std::vector<size_t>{1});
}
NVTE_SCOPED_GIL_RELEASE({
nvte_group_amax(input.data(), reinterpret_cast<NVTETensor *>(nvte_tensor_output_list.data()),
split_sections.data(), num_tensors, stream);
});
for (size_t i = 0; i < num_tensors; i++) {
output_list[i].set_amax(orig_amax_ptr_list[i], DType::kFloat32, std::vector<size_t>{1});
}
// Quantize tensors individually
NVTE_SCOPED_GIL_RELEASE({
for (size_t i = 0; i < num_tensors; i++) {
// Skip tensors with no elements
if (input_list[i].numel() == 0) {
continue;
}
nvte_quantize_v2(input_list[i].data(), output_list[i].data(), quant_config_list[i], stream);
}
});
}
}
} // namespace
std::vector<py::object> split_quantize(const at::Tensor &tensor,
const std::vector<int> &split_sections,
const std::vector<size_t> &split_sections,
std::vector<py::handle> quantizer_list) {
init_extension();
......@@ -726,8 +942,6 @@ std::vector<py::object> split_quantize(const at::Tensor &tensor,
const size_t dim0_stride =
input_shape[0] == 0 ? 0 : input_py.element_size() * input_size / input_shape[0];
for (size_t i = 0; i < num_splits; ++i) {
NVTE_CHECK(split_sections[i] >= 0, "Attempted to split tensor with shape=", input_shape,
" along dim 0 with split_sections=", split_sections);
NVTE_CHECK(dim0_offset + split_sections[i] <= input_shape[0],
"Attempted to split tensor with shape=", input_shape,
" along dim 0 with split_sections=", split_sections);
......@@ -745,65 +959,96 @@ std::vector<py::object> split_quantize(const at::Tensor &tensor,
quantizer_cpp_list.push_back(convert_quantizer(quantizer_list[i]));
}
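For reference, a worked example of the dim-0 split bookkeeping validated above, with illustrative shapes and sections:

#include <cstddef>
#include <cstdio>

int main() {
  const size_t input_shape[] = {3584, 4096};
  const size_t elem_size = 2;  // bf16
  // Matches element_size * input_size / input_shape[0] above: 8192 bytes/row.
  const size_t dim0_stride = elem_size * input_shape[1];
  const size_t split_sections[] = {1024, 2048, 512};
  size_t dim0_offset = 0;
  for (const size_t s : split_sections) {
    printf("split of %zu rows at byte offset %zu\n", s, dim0_offset * dim0_stride);
    dim0_offset += s;  // must never exceed input_shape[0] == 3584
  }
  return 0;
}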
// For FP8 block-scaling, we construct output tensors with bulk allocations
// For MXFP8, we also use bulk allocations
bool use_fused_bulk_alloc = true;
for (size_t i = 0; i < quantizer_list.size(); i++) {
if (!detail::IsFloat8BlockwiseQuantizers(quantizer_list[i].ptr()) &&
!detail::IsMXFP8Quantizers(quantizer_list[i].ptr()) &&
!detail::IsNVFP4Quantizers(quantizer_list[i].ptr())) {
use_fused_bulk_alloc = false;
break;
}
// Choose implementation for allocating and populating tensors
enum class AllocationMethod { UNFUSED, BULK_FP8_BLOCKWISE, BULK_MXFP8, BULK_NVFP4 };
enum class QuantizationMethod { UNFUSED, FUSED_NVFP4 };
AllocationMethod allocation_method = AllocationMethod::UNFUSED;
QuantizationMethod quantization_method = QuantizationMethod::UNFUSED;
if (std::all_of(quantizer_list.begin(), quantizer_list.end(),
[](const py::handle &quantizer) -> bool {
return detail::IsFloat8BlockwiseQuantizers(quantizer.ptr());
})) {
allocation_method = AllocationMethod::BULK_FP8_BLOCKWISE;
} else if (std::all_of(quantizer_list.begin(), quantizer_list.end(),
[](const py::handle &quantizer) -> bool {
return detail::IsMXFP8Quantizers(quantizer.ptr());
})) {
allocation_method = AllocationMethod::BULK_MXFP8;
} else if (std::all_of(quantizer_list.begin(), quantizer_list.end(),
[](const py::handle &quantizer) -> bool {
return detail::IsNVFP4Quantizers(quantizer.ptr());
})) {
allocation_method = AllocationMethod::BULK_NVFP4;
quantization_method = QuantizationMethod::FUSED_NVFP4;
}
// Allocate output tensors
std::vector<TensorWrapper> output_cpp_list;
std::vector<py::object> output_py_list;
if (!use_fused_bulk_alloc) {
// Allocate output tensors individually
for (size_t i = 0; i < num_splits; ++i) {
auto [output_cpp, output_py] =
quantizer_cpp_list[i]->create_tensor(split_shapes[i], input_dtype);
output_cpp_list.emplace_back(std::move(output_cpp));
output_py_list.emplace_back(std::move(output_py));
}
} else {
// TODO(zhongbo): make a better api to make this part less hacky
bool is_fp8_blockwise = detail::IsFloat8BlockwiseQuantizers(quantizer_list[0].ptr());
bool is_mxfp8 = detail::IsMXFP8Quantizers(quantizer_list[0].ptr());
bool is_nvfp4 = detail::IsNVFP4Quantizers(quantizer_list[0].ptr());
if (is_fp8_blockwise) {
// FP8 block-scaling: construct output tensors with bulk allocations
switch (allocation_method) {
case AllocationMethod::BULK_FP8_BLOCKWISE: {
// Bulk allocation for FP8 block-scaling tensors
std::vector<Float8BlockQuantizer *> blockwise_quantizers;
for (auto &quantizer : quantizer_cpp_list) {
blockwise_quantizers.push_back(static_cast<Float8BlockQuantizer *>(quantizer.get()));
}
std::tie(output_py_list, output_cpp_list) =
bulk_allocate_fp8_blockwise_tensors(split_shapes, quantizer_list, blockwise_quantizers);
} else if (is_mxfp8) {
// MXFP8: construct output tensors with bulk allocations
break;
}
case AllocationMethod::BULK_MXFP8: {
// Bulk allocation for MXFP8 tensors
std::vector<MXFP8Quantizer *> mxfp8_quantizers;
for (auto &quantizer : quantizer_cpp_list) {
mxfp8_quantizers.push_back(static_cast<MXFP8Quantizer *>(quantizer.get()));
}
std::tie(output_py_list, output_cpp_list) =
bulk_allocate_mxfp8_tensors(split_shapes, quantizer_list, mxfp8_quantizers);
} else if (is_nvfp4) {
// NVFP4: construct output tensors with bulk allocations
break;
}
case AllocationMethod::BULK_NVFP4: {
// Bulk allocation for NVFP4 tensors
std::vector<NVFP4Quantizer *> nvfp4_quantizers;
for (auto &quantizer : quantizer_cpp_list) {
nvfp4_quantizers.push_back(static_cast<NVFP4Quantizer *>(quantizer.get()));
}
std::tie(output_py_list, output_cpp_list) =
bool contiguous_data_and_scale;
std::tie(output_py_list, output_cpp_list, contiguous_data_and_scale) =
bulk_allocate_nvfp4_tensors(split_shapes, quantizer_list, nvfp4_quantizers);
} else {
NVTE_CHECK(false, "Expected either FP8 block-scaling or MXFP8 quantizer");
if (!contiguous_data_and_scale) {
// Avoid fused quantize kernel if data is not contiguous
quantization_method = QuantizationMethod::UNFUSED;
}
break;
}
default: {
// Allocate output tensors individually
for (size_t i = 0; i < num_splits; ++i) {
auto [output_cpp, output_py] =
quantizer_cpp_list[i]->create_tensor(split_shapes[i], input_dtype);
output_cpp_list.emplace_back(std::move(output_cpp));
output_py_list.emplace_back(std::move(output_py));
}
}
}
// Perform multi-tensor quantization
multi_tensor_quantize_impl(input_list, quantizer_list, quantizer_cpp_list, output_cpp_list);
// Quantize into output tensors
switch (quantization_method) {
case QuantizationMethod::FUSED_NVFP4: {
// Fused NVFP4 quantize kernel
auto input_nvte = makeTransformerEngineTensor(input_dptr, input_shape, input_dtype);
std::vector<NVFP4Quantizer *> nvfp4_quantizers;
for (auto &quantizer : quantizer_cpp_list) {
nvfp4_quantizers.push_back(static_cast<NVFP4Quantizer *>(quantizer.get()));
}
split_quantize_nvfp4_impl(input_nvte, input_list, output_cpp_list, split_sections,
nvfp4_quantizers);
break;
}
default:
// General multi-tensor quantization
multi_tensor_quantize_impl(input_list, quantizer_list, quantizer_cpp_list, output_cpp_list);
}
return output_py_list;
}
......
......@@ -10,7 +10,7 @@ import torch
import transformer_engine_torch as tex
from ..quantization import FP8GlobalStateManager
from ..quantization import FP8GlobalStateManager, get_align_size_for_quantization
from ..jit import no_torch_dynamo
......@@ -114,14 +114,8 @@ class Fp8Padding(torch.nn.Module):
assert len(m_splits) == self.num_gemms, "Number of splits should match number of GEMMs."
if self.align_size is None:
self.align_size = (
32
if (
FP8GlobalStateManager.get_fp8_recipe().mxfp8()
or FP8GlobalStateManager.get_fp8_recipe().nvfp4()
)
else 16
)
recipe = FP8GlobalStateManager.get_fp8_recipe()
self.align_size = get_align_size_for_quantization(recipe)
# FP8 padding calculate
padded_m_splits = [
......
......@@ -10,7 +10,7 @@ import torch
import transformer_engine_torch as tex
from ..quantization import FP8GlobalStateManager
from ..quantization import FP8GlobalStateManager, get_align_size_for_quantization
from ..jit import no_torch_dynamo
......@@ -112,14 +112,8 @@ class Fp8Unpadding(torch.nn.Module):
assert len(m_splits) == self.num_gemms, "Number of splits should match number of GEMMs."
if self.align_size is None:
self.align_size = (
32
if (
FP8GlobalStateManager.get_fp8_recipe().mxfp8()
or FP8GlobalStateManager.get_fp8_recipe().nvfp4()
)
else 16
)
recipe = FP8GlobalStateManager.get_fp8_recipe()
self.align_size = get_align_size_for_quantization(recipe)
# FP8 padding calculate
padded_m_splits = [
......
......@@ -40,6 +40,7 @@ __all__ = [
"is_fp8_block_scaling_available",
"is_nvfp4_available",
"get_default_recipe",
"get_align_size_for_quantization",
]
......@@ -114,6 +115,15 @@ def get_default_recipe() -> Recipe:
return get_default_fp8_recipe()
def get_align_size_for_quantization(recipe: Recipe) -> int:
"""Get the alignment size for quantization."""
if recipe.mxfp8():
return 32
if recipe.nvfp4():
return 64
return 16
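For illustration, a sketch (in C++, with hypothetical split sizes) of the padding this alignment size implies for grouped GEMM inputs, presumably the arithmetic behind padded_m_splits in Fp8Padding:

#include <cstdio>

int main() {
  const int align_size = 64;  // NVFP4, per get_align_size_for_quantization
  const int m_splits[] = {1000, 8960, 777};
  for (const int m : m_splits) {
    const int padded = (m + align_size - 1) / align_size * align_size;
    printf("%d -> %d\n", m, padded);  // 1000 -> 1024, 8960 -> 8960, 777 -> 832
  }
  return 0;
}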
def get_fp8_torch_dtype(fp8_recipe: Recipe, fprop_tensor: bool = True) -> torch.dtype:
"""Get fp8 data type according to recipe and tensor"""
if fp8_recipe.fp8_format == Format.E4M3 or (
......