Commit 063ef88d authored by wenjh

Merge nv main up to v2.10.0.dev0


Signed-off-by: wenjh <wenjh@sugon.com>
parents 91670b05 5624dbb4
# Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# See LICENSE for license information.
import pathlib, os
import pathlib
from nvdlfw_inspect.config_manager import ConfigManager
......
......@@ -8,18 +8,22 @@ import transformer_engine.pytorch as te
import torch
import tempfile
from transformer_engine.common import recipe
from transformer_engine.pytorch.fp8 import RecipeState
import pytest
import contextlib
import os
from transformer_engine.pytorch.fp8 import FP8GlobalStateManager
from transformer_engine.pytorch import (
is_fp8_available,
is_mxfp8_available,
is_fp8_block_scaling_available,
)
from transformer_engine.pytorch.quantization import RecipeState
from transformer_engine.debug.pytorch.debug_state import TEDebugState
fp8_available, reason_for_no_fp8 = FP8GlobalStateManager.is_fp8_available()
mxfp8_available, reason_for_no_mxfp8 = FP8GlobalStateManager.is_mxfp8_available()
fp8_block_scaling_available, reason_for_no_fp8_block_scaling = (
FP8GlobalStateManager.is_fp8_block_scaling_available()
fp8_available, reason_for_no_fp8 = is_fp8_available(return_reason=True)
mxfp8_available, reason_for_no_mxfp8 = is_mxfp8_available(return_reason=True)
fp8_block_scaling_available, reason_for_no_fp8_block_scaling = is_fp8_block_scaling_available(
return_reason=True
)
LOG_QUANTIZED_CONFIG_BASE = """
......@@ -128,7 +132,7 @@ def test_sanity(feature_dirs):
inp = torch.zeros(128, 128, dtype=torch.bfloat16).cuda()
for _ in range(10):
with te.fp8_autocast(fp8_recipe=recipe.DelayedScaling()):
with te.autocast(recipe=recipe.DelayedScaling()):
output = model(inp)
loss = output.sum()
loss.backward()
......@@ -232,7 +236,7 @@ def test_log_every_3_or_5_layers(layer, configs_dir, feature_dirs):
for i in range(20):
x = torch.randn(4, 128, 128).cuda()
with te.fp8_autocast(enabled=True):
with te.autocast(enabled=True):
y = model(x)
y.sum().backward()
debug_api.step()
......
......@@ -17,19 +17,19 @@ import transformer_engine.debug
import transformer_engine.pytorch as tepytorch
import transformer_engine_torch as tex
from transformer_engine.common.recipe import DelayedScaling, Format
from transformer_engine.pytorch.fp8 import _default_sf_compute
from transformer_engine.pytorch.tensor.float8_tensor import (
from transformer_engine.pytorch.quantization import _default_sf_compute
from transformer_engine.pytorch import (
Float8Quantizer,
Float8CurrentScalingQuantizer,
is_fp8_available,
)
from transformer_engine.pytorch.module.base import (
_2X_ACC_DGRAD,
_2X_ACC_FPROP,
_2X_ACC_WGRAD,
)
from transformer_engine.pytorch.fp8 import FP8GlobalStateManager
fp8_available, reason_for_no_fp8 = FP8GlobalStateManager.is_fp8_available()
fp8_available, reason_for_no_fp8 = is_fp8_available(return_reason=True)
all_boolean = [True, False]
FP8_FORMAT = Format.HYBRID
......@@ -250,7 +250,7 @@ def _init_model(weight):
def _run_forward_backward(x, model, loss_scale=1.0, is_first_microbatch=None, fp8=True):
with tepytorch.fp8_autocast(enabled=fp8, fp8_recipe=FP8_RECIPE):
with tepytorch.autocast(enabled=fp8, recipe=FP8_RECIPE):
y = model(x, is_first_microbatch=is_first_microbatch)
(y.sum() * loss_scale).backward()
debug_api.step()
......@@ -547,7 +547,7 @@ def run_per_tensor_scaling(
LOSS_MULTIPLIER = 100
with tepytorch.fp8_autocast(enabled=True, fp8_recipe=FP8_RECIPE):
with tepytorch.autocast(enabled=True, recipe=FP8_RECIPE):
y = model(x, is_first_microbatch=True)
model.zero_grad()
y.retain_grad()
......
......@@ -7,11 +7,10 @@ import torch
import nvdlfw_inspect.api as debug_api
import transformer_engine.pytorch as te
from transformer_engine.pytorch.fp8 import FP8GlobalStateManager
from test_numerics import create_config_file
fp8_available, reason_for_no_fp8 = FP8GlobalStateManager.is_fp8_available()
fp8_available, reason_for_no_fp8 = te.is_fp8_available(return_reason=True)
B, S, H, D = 64, 64, 64, 64
......@@ -68,7 +67,7 @@ def _get_model(model_key):
def _run_forward_backward(model, fp8):
for _ in range(3):
inp = torch.randn((S, B, H)).cuda()
with te.fp8_autocast(enabled=fp8):
with te.autocast(enabled=fp8):
out = model(inp)
out.sum().backward()
debug_api.step()
......
......@@ -21,13 +21,13 @@ from transformer_engine.common.recipe import (
Recipe,
)
import transformer_engine.pytorch as te
from transformer_engine.pytorch.tensor import QuantizedTensor, cast_master_weights_to_fp8
from transformer_engine.pytorch.tensor.float8_tensor import (
from transformer_engine.pytorch import (
QuantizedTensor,
Float8Tensor,
Float8CurrentScalingQuantizer,
Float8BlockwiseQTensor,
)
from transformer_engine.pytorch.tensor import cast_master_weights_to_fp8
from transformer_engine.pytorch.tensor.utils import replace_raw_data
from transformer_engine.pytorch.tensor.float8_blockwise_tensor import Float8BlockwiseQTensor
def _get_raw_data(quantized_tensor):
......@@ -439,7 +439,7 @@ def _test_fsdp_cast_master_weights_to_fp8(quantization, dp_group):
}
# Create model with FP8 weights
with te.fp8.fp8_model_init(
with te.quantized_model_init(
enabled=quantization is not None,
recipe=quantization_recipe(quantization),
preserve_high_precision_init_val=True,
......@@ -475,17 +475,17 @@ def _test_fsdp_cast_master_weights_to_fp8(quantization, dp_group):
# Choose based on rank to make sure the inputs of different ranks are different.
x = inputs[rank]
with te.fp8.fp8_autocast(
with te.autocast(
enabled=quantization is not None,
fp8_recipe=quantization_recipe(quantization),
fp8_group=mock_group,
recipe=quantization_recipe(quantization),
amax_reduction_group=mock_group,
):
y_fp8 = model_fp8(x)
with te.fp8_autocast(
with te.autocast(
enabled=quantization is not None,
fp8_recipe=quantization_recipe(quantization),
fp8_group=mock_group,
recipe=quantization_recipe(quantization),
amax_reduction_group=mock_group,
):
y = model(x)
......@@ -573,7 +573,7 @@ def _test_cast_master_weights_to_fp8(quantization, dp_group):
linear_kwargs = {"params_dtype": torch.bfloat16, "bias": False, "fuse_wgrad_accumulation": False}
# Create model with FP8 weights
with te.fp8.fp8_model_init(
with te.quantized_model_init(
enabled=quantization is not None,
recipe=quantization_recipe(quantization),
preserve_high_precision_init_val=True,
......@@ -615,17 +615,17 @@ def _test_cast_master_weights_to_fp8(quantization, dp_group):
# Choose based on rank to make sure the inputs of different ranks are different.
x = inputs[rank]
with te.fp8.fp8_autocast(
with te.autocast(
enabled=quantization is not None,
fp8_recipe=quantization_recipe(quantization),
fp8_group=mock_group,
recipe=quantization_recipe(quantization),
amax_reduction_group=mock_group,
):
y_fp8 = model_fp8(x)
with te.fp8_autocast(
with te.autocast(
enabled=quantization is not None,
fp8_recipe=quantization_recipe(quantization),
fp8_group=mock_group,
recipe=quantization_recipe(quantization),
amax_reduction_group=mock_group,
):
y = model(x)
......
......@@ -110,9 +110,9 @@ def _train(args):
build_model_context = nullcontext
build_model_context_args = {}
from transformer_engine.pytorch import fp8_model_init
from transformer_engine.pytorch import quantized_model_init
build_model_context = fp8_model_init
build_model_context = quantized_model_init
build_model_context_args["enabled"] = True
# Build the model with the specified context
......
......@@ -19,9 +19,12 @@ from torch.distributed.elastic.multiprocessing.errors import record
from torch.utils.cpp_extension import IS_HIP_EXTENSION
import transformer_engine.pytorch as te
from transformer_engine.pytorch import (
Float8Tensor,
Float8Quantizer,
MXFP8Quantizer,
)
import transformer_engine.pytorch.cpp_extensions as tex
from transformer_engine.pytorch.tensor.float8_tensor import Float8Quantizer
from transformer_engine.pytorch.tensor.mxfp8_tensor import MXFP8Quantizer
from transformer_engine.pytorch.module.base import (
fill_userbuffers_buffer_for_all_gather,
get_cublas_workspace_size_bytes,
......@@ -172,12 +175,12 @@ def _parse_args(argv=None, namespace=None):
opts.p2p = True
if opts.atomic:
if not te.fp8.check_fp8_support():
if not te.is_fp8_available():
assert opts.quantization == "none", "Atomic GEMM is only supported in FP8."
opts.quantization = "fp8"
if opts.fp8_output:
assert ops.quantization == "fp8", "FP8 output is only supported with FP8 compute."
assert opts.quantization == "fp8", "FP8 output is only supported with FP8 compute."
return opts
......
......@@ -165,7 +165,7 @@ def _parse_args(argv=None, namespace=None):
)
parser.add_argument("--seed", type=int, default=42, help="RNG seed.")
parser.add_argument(
"--fp8", action="store_true", default=False, help="Enables the te.fp8_autocast() context."
"--fp8", action="store_true", default=False, help="Enables the te.autocast() context."
)
parser.add_argument(
"--quantization",
......@@ -438,7 +438,7 @@ def _train(opts):
ub_cfgs=ub_cfgs if opts.ub_cfg is None else opts.ub_cfg,
)
with te.fp8_model_init(enabled=opts.fp8_init):
with te.quantized_model_init(enabled=opts.fp8_init):
test_model = multi_module_model(opts.layer_type, opts.num_layers, *args, **kwargs)
dist_print("Initialized test model...", debug=True)
if WORLD_RANK == 0:
......@@ -450,7 +450,7 @@ def _train(opts):
ref_args, ref_kwargs, _ = _get_layer_args(
opts, nccl_world, opts.tp, num_layers=opts.num_layers, reference=True
)
with te.fp8_model_init(enabled=opts.fp8_init):
with te.quantized_model_init(enabled=opts.fp8_init):
ref_model = multi_module_model(opts.layer_type, opts.num_layers, *ref_args, **ref_kwargs)
dist_print("Initialized reference model...", debug=True)
for test_param, ref_param in zip(test_model.parameters(), ref_model.parameters()):
......@@ -473,7 +473,9 @@ def _train(opts):
layer_contexts = [
(
partial(te.fp8_autocast, enabled=opts.fp8, fp8_recipe=fp8_recipe, fp8_group=nccl_world)
partial(
te.autocast, enabled=opts.fp8, recipe=fp8_recipe, amax_reduction_group=nccl_world
)
if opts.num_layers_at_start_in_bf16 <= i
and i < (opts.num_layers - opts.num_layers_at_end_in_bf16)
else nullcontext
......
......@@ -9,6 +9,7 @@ import datetime
import os
import sys
from functools import wraps
import math
import torch
from torch import nn
......@@ -20,10 +21,14 @@ from transformer_engine.common.recipe import (
DelayedScaling,
Float8CurrentScaling,
Float8BlockScaling,
NVFP4BlockScaling,
Format,
Recipe,
QParams,
)
from transformer_engine.pytorch.tensor.float8_tensor import Float8CurrentScalingQuantizer
from transformer_engine.pytorch import Float8CurrentScalingQuantizer, NVFP4Quantizer
from transformer_engine.pytorch.constants import NVFP4_BLOCK_SCALING_SIZE
from transformer_engine.pytorch.distributed import gather_along_first_dim
from run_layer_with_overlap import _compare_tensors
from torch.utils.cpp_extension import IS_HIP_EXTENSION
......@@ -48,6 +53,14 @@ if os.environ.get("NVTE_TEST_NVINSPECT_ENABLED", False):
)
def nvfp4_vanilla():
nvfp4_recipe = NVFP4BlockScaling()
nvfp4_recipe.fp4_quant_fwd_inp = QParams()
nvfp4_recipe.fp4_quant_fwd_weight = QParams()
nvfp4_recipe.fp4_quant_bwd_grad = QParams()
return nvfp4_recipe
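# Note: nvfp4_vanilla() keeps the default QParams for the forward input, forward
# weight and backward grad quantization, i.e. no explicit random Hadamard transform
# or 2D-quantization overrides (a "plain" NVFP4 block-scaling recipe).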
# Quantization recipe setup
def quantization_recipe() -> Recipe:
if QUANTIZATION == "fp8":
......@@ -60,7 +73,9 @@ def quantization_recipe() -> Recipe:
return Float8CurrentScaling()
if QUANTIZATION == "fp8_block_scaling":
return Float8BlockScaling()
return te.fp8.get_default_fp8_recipe()
if QUANTIZATION == "nvfp4":
return nvfp4_vanilla()
return te.quantization.get_default_fp8_recipe()
def main(argv=None, namespace=None):
......@@ -97,10 +112,14 @@ def main(argv=None, namespace=None):
# Quantization scheme
QUANTIZATION = args.quantization
global SEQ_LEN, BATCH_SIZE, HIDDEN_SIZE
if QUANTIZATION in ("fp8", "mxfp8"):
if QUANTIZATION in ("fp8", "mxfp8", "nvfp4"):
SEQ_LEN = 32
BATCH_SIZE = 32
HIDDEN_SIZE = 128
# For fp8 block scaling the block size is 128, so to make low-precision TP work
# the input tensor must be divisible into 128x128 tiles to be eligible for the
# low-precision all-gather when needed.
elif QUANTIZATION == "fp8_block_scaling":
SEQ_LEN = 128
BATCH_SIZE = 128
......@@ -108,6 +127,7 @@ def main(argv=None, namespace=None):
test_dict = [
test_quantizer,
test_quantized_all_gather,
test_linear,
test_layernorm,
test_layernorm_linear,
......@@ -177,6 +197,9 @@ def _get_tolerances(dtype):
# row parallel & sequence parallel, because we do the all_gather in backward pass
if QUANTIZATION == "fp8_cs":
return {"rtol": 0.4, "atol": 0.25}
elif QUANTIZATION == "nvfp4":
# TODO(zhongboz): investigate why the tolerance is so large
return {"rtol": 0.125, "atol": 0.12}
elif QUANTIZATION is not None:
return {"rtol": 0.125, "atol": 0.0625}
......@@ -293,15 +316,15 @@ def _apply_models(
_alloc_main_grad(model_single_node, model_distributed) # for fuse_wgrad_accumulation=True
input_single_node.requires_grad_()
input_distributed.requires_grad_()
with te.fp8_autocast(
with te.autocast(
enabled=QUANTIZATION is not None,
fp8_recipe=quantization_recipe(),
recipe=quantization_recipe(),
):
output_single_node = model_single_node(input_single_node, **kwargs)
with te.fp8_autocast(
with te.autocast(
enabled=QUANTIZATION is not None,
fp8_recipe=quantization_recipe(),
fp8_group=NCCL_WORLD,
recipe=quantization_recipe(),
amax_reduction_group=NCCL_WORLD,
):
output_distributed = model_distributed(input_distributed, **kwargs)
return output_single_node, output_distributed
......@@ -327,24 +350,36 @@ def _alloc_main_grad(model_single_node, model_distributed):
###############################################
# Quantizer #
###############################################
def _construct_quantizer(quantizer_class, fp8_dtype, device, tp_group, tp_size):
def _construct_quantizer(quantizer_class, low_precision_dtype, device, tp_group, tp_size):
"""
quantizer is the reference quantizer on a single GPU.
quantizer_dist is the distributed quantizer to be tested on multiple GPUs.
"""
if quantizer_class == Float8CurrentScalingQuantizer:
quantizer_dist = quantizer_class(
fp8_dtype=fp8_dtype,
fp8_dtype=low_precision_dtype,
device=device,
with_amax_reduction=True,
amax_reduction_group=tp_group,
)
quantizer = quantizer_class(
fp8_dtype=fp8_dtype,
fp8_dtype=low_precision_dtype,
device=device,
with_amax_reduction=False,
)
return quantizer, quantizer_dist
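# NVFP4 case: build a reference quantizer (no amax reduction) and a distributed
# quantizer that reduces amax across the TP group, mirroring the FP8
# current-scaling branch above.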
elif quantizer_class == NVFP4Quantizer:
quantizer_dist = quantizer_class(
fp4_dtype=low_precision_dtype,
with_amax_reduction=True,
amax_reduction_group=tp_group,
)
quantizer = quantizer_class(
fp4_dtype=low_precision_dtype,
with_amax_reduction=False,
amax_reduction_group=None,
)
return quantizer, quantizer_dist
else:
raise ValueError(f"Unsupported quantizer class: {quantizer_class}")
......@@ -415,6 +450,194 @@ def test_quantizer():
_test_quantizer(input_dtype, fp8_dtype)
############################################
# Quantized All-Gather #
############################################
def _ref_zero_padding_scale_inv(scale_inv, unpadded_shape):
"""
Zero-pad the scale_inv.
scale_inv has the padded shape, but its padding region is not zero-filled;
unpadded_shape is the original shape before padding.
"""
dim0, dim1 = scale_inv.shape
unpadded_dim0, unpadded_dim1 = unpadded_shape
pad_dim0 = (128 - unpadded_dim0 % 128) % 128
pad_dim1 = (4 - unpadded_dim1 % 4) % 4
new_dim0 = unpadded_dim0 + pad_dim0
new_dim1 = unpadded_dim1 + pad_dim1
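# Worked example: an unpadded (100, 3) scale_inv pads up to (128, 4), since
# pad_dim0 = (128 - 100 % 128) % 128 = 28 and pad_dim1 = (4 - 3 % 4) % 4 = 1.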
assert dim0 == new_dim0
assert dim1 == new_dim1
# return input if no padding is needed
if pad_dim0 == 0 and pad_dim1 == 0:
return scale_inv
# unpad first to remove random bits from torch empty
scale_inv = scale_inv[:unpadded_dim0, :unpadded_dim1].contiguous()
# using torch padding
new_scale_inv = torch.nn.functional.pad(
scale_inv, (0, pad_dim1, 0, pad_dim0), mode="constant", value=0
)
assert new_scale_inv.shape == (new_dim0, new_dim1)
return new_scale_inv
def _get_unpadded_scale_inv_shape(input_shape, quantizer_cls, columnwise):
"""
Calculate the unpadded shape of the scale_inv tensor.
"""
M = math.prod(input_shape[:-1])
K = input_shape[-1]
if quantizer_cls == NVFP4Quantizer:
if columnwise:
outer = K
inner = math.ceil(M / NVFP4_BLOCK_SCALING_SIZE)
return (outer, inner)
else:
outer = M
inner = math.ceil(K / NVFP4_BLOCK_SCALING_SIZE)
return (outer, inner)
else:
raise ValueError(f"Unsupported quantizer class: {quantizer_cls}")
@run_distributed_test()
def _test_quantized_all_gather(input_dtype, low_precision_dtype, quantizer_cls):
"""Test the quantizer under distributed settings.
Args:
input_dtype (torch.dtype): The data type of the input.
low_precision_dtype (tex.DType): The low-precision data type; can be FP4 or FP8.
quantizer_cls: The quantizer class under test (e.g. NVFP4Quantizer).
"""
M, N = WORLD_SIZE * BATCH_SIZE, HIDDEN_SIZE // 2
# high precision input
x_hp_cpu = torch.randn((M, N), device="cpu").to(input_dtype)
# Optionally set one element of the input to a very large value that does not live
# on rank 0 after the split, to stress the amax reduction (disabled by default).
# x_hp_cpu[M - 1, N - 1] = 1e4
# get the unpadded shapes
unpadded_rowwise_scale_inv_shape = _get_unpadded_scale_inv_shape((M, N), quantizer_cls, False)
unpadded_columnwise_scale_inv_shape = _get_unpadded_scale_inv_shape((M, N), quantizer_cls, True)
# rank 0 keeps a full copy and quantizes it on GPU 0 for verification
if WORLD_RANK == 0:
x_hp_rank0 = x_hp_cpu.clone().detach().requires_grad_(True).to("cuda")
x_hp_local_rank = _shard_tensor(x_hp_cpu, WORLD_SIZE, 0)[WORLD_RANK]
# Create quantizers
quantizer, quantizer_dist = _construct_quantizer(
quantizer_cls, low_precision_dtype, x_hp_local_rank.device, NCCL_WORLD, WORLD_SIZE
)
# quantize the entire input
if WORLD_RANK == 0:
x_low_precision_single = quantizer(x_hp_rank0)
# run all-gather with a quantizer as input for quantized all-gather
x_low_precision_total, _ = gather_along_first_dim(
x_hp_local_rank, NCCL_WORLD, async_op=False, quantizer=quantizer_dist
)
# check the outputs
if WORLD_RANK == 0:
# assert all data and scale_inv are the same
torch.testing.assert_close(
x_low_precision_single._rowwise_data,
x_low_precision_total._rowwise_data,
rtol=0.0,
atol=0.0,
)
# check the rowwise scale without any padding
unpad_dim0, unpad_dim1 = unpadded_rowwise_scale_inv_shape
unpadded_rowwise_scale_inv_ref = x_low_precision_single._rowwise_scale_inv[
:unpad_dim0, :unpad_dim1
]
unpadded_rowwise_scale_inv = x_low_precision_total._rowwise_scale_inv[
:unpad_dim0, :unpad_dim1
]
torch.testing.assert_close(
unpadded_rowwise_scale_inv_ref,
unpadded_rowwise_scale_inv,
rtol=0.0,
atol=0.0,
)
torch.testing.assert_close(
_ref_zero_padding_scale_inv(
x_low_precision_single._rowwise_scale_inv, unpadded_rowwise_scale_inv_shape
),
_ref_zero_padding_scale_inv(
x_low_precision_total._rowwise_scale_inv, unpadded_rowwise_scale_inv_shape
),
rtol=0.0,
atol=0.0,
)
torch.testing.assert_close(
x_low_precision_single._columnwise_data,
x_low_precision_total._columnwise_data,
rtol=0.0,
atol=0.0,
)
unpad_dim0, unpad_dim1 = unpadded_columnwise_scale_inv_shape
unpadded_columnwise_scale_inv_ref = x_low_precision_single._columnwise_scale_inv[
:unpad_dim0, :unpad_dim1
]
unpadded_columnwise_scale_inv = x_low_precision_total._columnwise_scale_inv[
:unpad_dim0, :unpad_dim1
]
torch.testing.assert_close(
unpadded_columnwise_scale_inv_ref,
unpadded_columnwise_scale_inv,
rtol=0.0,
atol=0.0,
)
torch.testing.assert_close(
_ref_zero_padding_scale_inv(
x_low_precision_single._columnwise_scale_inv, unpadded_columnwise_scale_inv_shape
),
_ref_zero_padding_scale_inv(
x_low_precision_total._columnwise_scale_inv, unpadded_columnwise_scale_inv_shape
),
rtol=0.0,
atol=0.0,
)
def test_quantized_all_gather():
"""
Run quantized all-gather tests with various configurations.
"""
# skip this test for other quantization schemes
is_nvfp4 = QUANTIZATION == "nvfp4"
# add other recipes for testing if needed
if not is_nvfp4:
return
input_dtypes = [torch.bfloat16]
fp4_dtype = [tex.DType.kFloat4E2M1]
fp8_dtype = [tex.DType.kFloat8E4M3, tex.DType.kFloat8E5M2]
quantizer_cls_nvfp4 = [NVFP4Quantizer]
# add FP8 quantizers if needed
quantizer_cls_fp8 = []
low_precision_dtypes = fp4_dtype if is_nvfp4 else fp8_dtype
quantizer_cls_list = quantizer_cls_nvfp4 if is_nvfp4 else quantizer_cls_fp8
for quantizer_cls in quantizer_cls_list:
for input_dtype in input_dtypes:
for low_precision_dtype in low_precision_dtypes:
_test_quantized_all_gather(input_dtype, low_precision_dtype, quantizer_cls)
############################################
# Linear #
############################################
......@@ -515,7 +738,7 @@ def test_linear():
{"init_method": _constant},
{"fuse_wgrad_accumulation": True},
{"return_bias": True},
{"params_dtype": torch.float16},
{"params_dtype": torch.float16 if QUANTIZATION != "nvfp4" else torch.bfloat16},
{"delay_wgrad_compute": True},
{"save_original_input": True},
]
......@@ -703,7 +926,7 @@ def test_layernorm_linear():
{"init_method": _constant},
{"fuse_wgrad_accumulation": True},
{"return_bias": True},
{"params_dtype": torch.float16},
{"params_dtype": torch.float16 if QUANTIZATION != "nvfp4" else torch.bfloat16},
{"zero_centered_gamma": False},
{"return_layernorm_output": True},
{"delay_wgrad_compute": True},
......@@ -818,7 +1041,7 @@ def test_layernorm_mlp():
{"normalization": "RMSNorm"},
{"zero_centered_gamma": True},
{"bias": False},
{"params_dtype": torch.float16},
{"params_dtype": torch.float16 if QUANTIZATION != "nvfp4" else torch.bfloat16},
{"activation": "relu"},
{"fuse_wgrad_accumulation": True},
{"return_bias": True},
......@@ -924,7 +1147,7 @@ def test_transformer_layer():
{"fuse_qkv_params": True, "fuse_wgrad_accumulation": True},
{"qkv_weight_interleaved": False},
{"bias": False},
{"params_dtype": torch.float16},
{"params_dtype": torch.float16 if QUANTIZATION != "nvfp4" else torch.bfloat16},
{"fuse_qkv_params": True},
{"activation": "relu"},
]
......
#!/usr/bin/python3
# Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# See LICENSE for license information.
import argparse
import datetime
import os
import sys
from functools import wraps
import transformer_engine.pytorch as te
import torch
from torch import nn
import torch.distributed as dist
from transformer_engine.common.recipe import (
NVFP4BlockScaling,
Recipe,
QParams,
CustomRecipe,
)
from transformer_engine.pytorch import NVFP4Quantizer
from transformer_engine.pytorch.constants import NVFP4_BLOCK_SCALING_SIZE
from transformer_engine.pytorch.experimental import quantization_nvfp4
from transformer_engine.pytorch.experimental import utils
from run_layer_with_overlap import _compare_tensors
BATCH_SIZE, HIDDEN_SIZE, OUT_SIZE = 128, 256, 128
WORLD_RANK, WORLD_SIZE = None, None
NCCL_WORLD = None
LOSS_FN = nn.MSELoss()
QUANTIZATION = None
def nvfp4_rht_and_2d_quantization():
nvfp4_recipe = NVFP4BlockScaling()
nvfp4_recipe.fp4_quant_fwd_inp = QParams(
random_hadamard_transform=True, fp4_2d_quantization=False
)
nvfp4_recipe.fp4_quant_fwd_weight = QParams(
random_hadamard_transform=False, fp4_2d_quantization=True
)
nvfp4_recipe.fp4_quant_bwd_grad = QParams(
random_hadamard_transform=True, fp4_2d_quantization=False
)
return nvfp4_recipe
def get_nvfp4_quantizer_factory():
"""
Create a quantizer factory for the NVFP4 reference implementation.
This factory returns NVFP4QuantizerRef instances with RHT and 2D quantization
enabled.
Returns:
A factory function that takes a role string and returns a quantizer instance
"""
def factory(role):
if role == "linear_input":
return quantization_nvfp4.NVFP4QuantizerRef(
dtype=utils.Fp4Formats.E2M1,
quant_tile_shape=(1, 16),
pow_2_scales=False,
with_rht=True, # RHT enabled for input
)
elif role == "linear_weight":
return quantization_nvfp4.NVFP4QuantizerRef(
dtype=utils.Fp4Formats.E2M1,
quant_tile_shape=(16, 16), # 2D quantization for weight
pow_2_scales=False,
with_rht=False,
)
elif role == "linear_output":
# Output quantization not used
return None
elif role == "linear_grad_output":
return quantization_nvfp4.NVFP4QuantizerRef(
dtype=utils.Fp4Formats.E2M1,
quant_tile_shape=(1, 16),
pow_2_scales=False,
with_rht=True, # RHT enabled for grad_output
)
elif role == "linear_grad_input":
# Grad input quantization not used
return None
else:
# For any other roles, return None
return None
return factory
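# Usage sketch: the factory is consumed by CustomRecipe below, e.g.
#   CustomRecipe(qfactory=get_nvfp4_quantizer_factory())
# which resolves per-role reference quantizers ("linear_input", "linear_weight",
# "linear_grad_output") for the experimental NVFP4 path.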
# Quantization recipe setup
def quantization_recipe() -> Recipe:
if QUANTIZATION == "nvfp4":
return nvfp4_rht_and_2d_quantization()
raise ValueError(f"Unsupported quantization: {QUANTIZATION}")
def quantization_reference_recipe() -> Recipe:
"""Create reference recipe using CustomRecipe with NVFP4 quantizer factory."""
if QUANTIZATION == "nvfp4":
nvfp4_ref_factory = get_nvfp4_quantizer_factory()
return CustomRecipe(qfactory=nvfp4_ref_factory)
raise ValueError(f"Unsupported quantization for reference: {QUANTIZATION}")
def main(argv=None, namespace=None):
global WORLD_RANK, WORLD_SIZE, NCCL_WORLD, QUANTIZATION, BATCH_SIZE, HIDDEN_SIZE, OUT_SIZE
WORLD_RANK = int(os.getenv("RANK", "0"))
WORLD_SIZE = int(os.getenv("WORLD_SIZE", "1"))
LOCAL_RANK = int(os.getenv("LOCAL_RANK", "0"))
LOCAL_SIZE = int(os.getenv("LOCAL_WORLD_SIZE", "1"))
assert WORLD_SIZE == LOCAL_SIZE # this test supports only 1 node
assert LOCAL_SIZE <= torch.cuda.device_count()
dist_init_kwargs = {
"backend": "nccl",
"rank": WORLD_RANK,
"world_size": WORLD_SIZE,
"timeout": datetime.timedelta(seconds=30),
}
dist_init_kwargs["init_method"] = "env://"
dist_init_kwargs["device_id"] = torch.device(f"cuda:{LOCAL_RANK}")
assert dist.is_nccl_available()
torch.cuda.set_device(LOCAL_RANK)
dist.init_process_group(**dist_init_kwargs)
NCCL_WORLD = dist.new_group(backend="nccl")
WORLD_SIZE = dist.get_world_size()
parser = argparse.ArgumentParser()
parser.add_argument("--quantization", type=str, default=None)
parser.add_argument("--batch-size", type=int, default=32)
parser.add_argument("--hidden-size", type=int, default=128)
parser.add_argument("--out-size", type=int, default=128)
args = parser.parse_args(argv, namespace)
# Quantization scheme
QUANTIZATION = args.quantization
BATCH_SIZE = args.batch_size
HIDDEN_SIZE = args.hidden_size
OUT_SIZE = args.out_size
test_dict = [
test_linear,
test_layernorm_linear,
]
for test in test_dict:
test()
dist.destroy_process_group()
return 0
def run_distributed_test(test_name=None):
def decorator(func):
@wraps(func)
def wrapper(*args, **kwargs):
name = test_name if test_name is not None else func.__name__
dist_print(f"Starting test {name} with args {args} and {kwargs}")
torch.cuda.set_device(WORLD_RANK)
torch.manual_seed(12345)
torch.cuda.manual_seed(12345)
func(*args, **kwargs)
dist.barrier()
dist_print(f"Passed test {name}")
return wrapper
return decorator
def dist_print(msg, src=None, end="\n", error=False):
stream = sys.stderr if error else sys.stdout
if WORLD_RANK == (0 if src is None else src):
stream.write(f"[rank{WORLD_RANK}] {msg}{end}\n")
############################################
# Linear #
############################################
class TestDistributedLinearBase:
@staticmethod
def _prepare_data(
batch_size, hidden_size, out_size, use_bias=True, seed=0, dtype=torch.float32
):
torch.manual_seed(seed)
torch.cuda.manual_seed(seed)
x = torch.randn((batch_size, hidden_size), dtype=dtype, device="cuda")
w = torch.randn((out_size, hidden_size), dtype=dtype, device="cuda")
bias = torch.randn((out_size), dtype=dtype, device="cuda") if use_bias else None
gradient = torch.randn((batch_size, out_size), dtype=dtype, device="cuda")
return x, w, bias, gradient
@staticmethod
def _shard_tensor(x, world_size, axis):
split_size = x.size()[axis] // world_size
split_tensor = torch.split(x, split_size, axis)
out = []
for tensor in split_tensor:
out.append(tensor.detach().clone().requires_grad_(x.requires_grad))
return out
@staticmethod
def _gather_tensor(local, world_size, tp_group, concat_dim):
out_list = [torch.zeros_like(local) for _ in range(world_size)]
torch.distributed.all_gather(out_list, local, tp_group)
return torch.cat(out_list, dim=concat_dim)
@staticmethod
def _all_reduce_tensor(local, world_size, tp_group):
if world_size == 1:
return local
torch.distributed.all_reduce(local, group=tp_group, async_op=False)
return local
@staticmethod
def _get_sum_abs_error(a, b):
return torch.sum(torch.abs(a - b))
@staticmethod
def _get_mean_abs_relative_error(a, b):
error = torch.where(b == 0, torch.ne(a, b), torch.abs((a - b) / b))
return torch.mean(error)
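# Note: where the reference value is zero, the relative error is replaced by an
# exact-match indicator (0 or 1) to avoid division by zero.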
@classmethod
def run_linear_preprocess_parallel(
cls,
x,
w,
bias,
gradient,
parallel_mode=None,
sequence_parallel=False,
tp_size=1,
rank=0,
):
if tp_size > 1:
if parallel_mode == "column":
# split w in N dim, which should be axis 0
w = cls._shard_tensor(w, tp_size, 0)[rank]
bias = cls._shard_tensor(bias, tp_size, 0)[rank] if bias is not None else None
# split gradient in N dim, which should be axis 1
gradient = cls._shard_tensor(gradient, tp_size, 1)[rank]
if sequence_parallel:
# split x in M dim, which should be axis 0
x = cls._shard_tensor(x, tp_size, 0)[rank]
# row parallel, split x in k dim, which should be axis 1, split w in k dim, should be axis 1
if parallel_mode == "row":
# split x in K dim, which should be axis 1
x = cls._shard_tensor(x, tp_size, 1)[rank]
# split w in K dim, which should be axis 1
w = cls._shard_tensor(w, tp_size, 1)[rank]
if sequence_parallel:
# split gradient in M dim, which should be axis 0
gradient = cls._shard_tensor(gradient, tp_size, 0)[rank]
return x, w, bias, gradient
@classmethod
def run_linear_postprocess_parallel(
cls,
y_q,
dgrad,
wgrad,
bgrad,
parallel_mode,
sequence_parallel,
tp_size,
tp_group,
):
if tp_size > 1:
if parallel_mode == "column":
# gather y_q in N dim, which should be axis 1
y_q = cls._gather_tensor(y_q, tp_size, tp_group, 1)
# gather wgrad in N dim, which should be axis 0
wgrad = cls._gather_tensor(wgrad, tp_size, tp_group, 0)
# gather bgrad in N dim, which should be axis 0
bgrad = (
cls._gather_tensor(bgrad, tp_size, tp_group, 0) if bgrad is not None else None
)
if sequence_parallel:
# gather dgrad in M dim, which should be axis 0
dgrad = cls._gather_tensor(dgrad, tp_size, tp_group, 0)
if parallel_mode == "row":
# gather dgrad in K dim, which should be axis 1
dgrad = cls._gather_tensor(dgrad, tp_size, tp_group, 1)
# gather wgrad in K dim, which should be axis 1
wgrad = cls._gather_tensor(wgrad, tp_size, tp_group, 1)
if sequence_parallel:
# gather y_q in M dim, which should be axis 0
y_q = cls._gather_tensor(y_q, tp_size, tp_group, 0)
# we need to sum bias gradient when using TP + SP
bgrad = (
cls._all_reduce_tensor(bgrad, tp_size, tp_group)
if bgrad is not None
else None
)
return y_q, dgrad, wgrad, bgrad
@classmethod
def run_linear_one_step(
cls, layer, x, gradient, is_first_microbatch=None, fuse_wgrad_accumulation=False
):
# reset gradients
layer.zero_grad()
x.grad = None
# Forward pass
if isinstance(layer, te.Linear):
# TE Linear supports the is_first_microbatch argument
y_q = layer.forward(x, is_first_microbatch=is_first_microbatch)
else:
# the default torch.nn.Linear
y_q = layer(x)
# Backward pass
y_q.backward(gradient)
# Collect gradients
dgrad = x.grad
bgrad = (
layer._parameters["bias"].grad
if layer._parameters.get("bias", None) is not None
else None
)
assert "weight" in layer._parameters
if fuse_wgrad_accumulation:
wgrad = layer._parameters["weight"].main_grad
assert layer._parameters["weight"].grad is None
else:
wgrad = layer._parameters["weight"].grad
return y_q, dgrad, wgrad, bgrad
@classmethod
def run_linear_multiple_steps(
cls,
layer,
x,
gradient,
run_num_steps,
enable_weight_cache,
fuse_wgrad_accumulation=False,
):
"""
Run multiple steps of linear layer and collect results.
"""
y_q_list, dgrad_list, wgrad_list = [], [], []
bgrad_list = [] if layer._parameters.get("bias", None) is not None else None
for i in range(run_num_steps):
x_i = (x + i).clone().detach().requires_grad_(True)
# each step uses a shifted input (x + i) so successive microbatches differ
y_q, dgrad, wgrad, bgrad = cls.run_linear_one_step(
layer,
x_i,
gradient,
is_first_microbatch=(i == 0) if enable_weight_cache else None,
fuse_wgrad_accumulation=fuse_wgrad_accumulation,
)
# Collect results
y_q_list.append(y_q.detach().clone())
dgrad_list.append(dgrad.detach().clone())
wgrad_list.append(wgrad.detach().clone())
if bgrad_list is not None and bgrad is not None:
bgrad_list.append(bgrad.detach().clone())
# Stack the results
return (
torch.stack(y_q_list),
torch.stack(dgrad_list),
torch.stack(wgrad_list),
torch.stack(bgrad_list) if bgrad_list is not None else None,
)
@classmethod
def run_linear(
cls,
x,
w,
bias,
gradient,
parallel_mode=None,
sequence_parallel=False,
tp_group=None,
tp_size=1,
rank=0,
run_num_steps=1,
enable_weight_cache=False,
fuse_wgrad_accumulation=False,
):
"""
If Model parallel, split inputs for a given rank and return the gathered output and gradients, so that they can be compared with
the reference single GPU run.
"""
# clone inputs and move to current device
# w has shape [N, K], x has shape [M, K], gradient has shape [M, N]
x = x.clone().detach().requires_grad_(True).to("cuda")
w = w.clone().detach().to("cuda")
gradient = gradient.clone().detach().to("cuda")
bias = bias.clone().detach().to("cuda") if bias is not None else None
in_features = x.shape[1]
out_features = w.shape[0]
# If Model parallel: split inputs for a given rank
x, w, bias, gradient = cls.run_linear_preprocess_parallel(
x, w, bias, gradient, parallel_mode, sequence_parallel, tp_size, rank
)
# set data types
params_dtype = x.dtype
# Create linear layer and copy weights
layer = te.Linear(
in_features,
out_features,
bias=bias is not None,
params_dtype=params_dtype,
parallel_mode=parallel_mode,
sequence_parallel=sequence_parallel,
tp_group=tp_group,
tp_size=tp_size,
fuse_wgrad_accumulation=fuse_wgrad_accumulation,
)
layer = layer.to("cuda")
with torch.no_grad():
layer.weight.copy_(w)
if bias is not None:
layer.bias.copy_(bias)
if fuse_wgrad_accumulation:
assert (
run_num_steps > 1
), "Fused weight gradient accumulation requires run_num_steps > 1"
layer.weight.main_grad = torch.zeros_like(layer.weight)
# Run one step or multiple steps
if run_num_steps == 1:
y_q, dgrad, wgrad, bgrad = cls.run_linear_one_step(layer, x, gradient)
else:
y_q, dgrad, wgrad, bgrad = cls.run_linear_multiple_steps(
layer,
x,
gradient,
run_num_steps,
enable_weight_cache,
fuse_wgrad_accumulation,
)
# If Model parallel: gather output and gradients from all ranks
y_q, dgrad, wgrad, bgrad = cls.run_linear_postprocess_parallel(
y_q,
dgrad,
wgrad,
bgrad,
parallel_mode,
sequence_parallel,
tp_size,
tp_group,
)
return y_q, dgrad, wgrad, bgrad
@run_distributed_test()
def _test_linear(parallel_mode=None, sequence_parallel=False, **kwargs):
"""Test the linear layer with specified parallel mode and sequence parallelization.
Args:
parallel_mode (str): 'row' or 'column' parallelism.
sequence_parallel (bool): Enable sequence parallelism if True.
kwargs (dict): Additional arguments for the linear layer.
QUANTIZATION options: nvfp4 (checked against the experimental NVFP4 implementation as a reference)
"""
params_dtype = torch.bfloat16
use_bias = kwargs.get("bias", True)
fuse_wgrad_accumulation = kwargs.get("fuse_wgrad_accumulation", False)
seed = torch.initial_seed()
recipe = quantization_recipe()
# turn on weight quantization cache when fusing wgrad accumulation
enable_weight_cache = fuse_wgrad_accumulation
run_num_steps = 1 if not fuse_wgrad_accumulation else 5
x, w, bias, gradient = TestDistributedLinearBase._prepare_data(
BATCH_SIZE, HIDDEN_SIZE, OUT_SIZE, use_bias=use_bias, seed=seed, dtype=params_dtype
)
# run the recipe under test
with te.autocast(enabled=True, recipe=recipe):
y_q, dgrad, wgrad, bgrad = TestDistributedLinearBase.run_linear(
x,
w,
bias,
gradient,
parallel_mode=parallel_mode,
sequence_parallel=sequence_parallel,
tp_group=NCCL_WORLD,
tp_size=WORLD_SIZE,
rank=WORLD_RANK,
fuse_wgrad_accumulation=fuse_wgrad_accumulation,
run_num_steps=run_num_steps,
enable_weight_cache=enable_weight_cache,
)
# run the reference
reference_recipe = quantization_reference_recipe()
with te.autocast(enabled=True, recipe=reference_recipe):
y_q_ref, dgrad_ref, wgrad_ref, bgrad_ref = TestDistributedLinearBase.run_linear(
x,
w,
bias,
gradient,
parallel_mode=parallel_mode,
sequence_parallel=sequence_parallel,
tp_group=NCCL_WORLD,
tp_size=WORLD_SIZE,
rank=WORLD_RANK,
fuse_wgrad_accumulation=fuse_wgrad_accumulation,
run_num_steps=run_num_steps,
enable_weight_cache=enable_weight_cache,
)
# compare results, zero tolerance
if WORLD_RANK == 0:
torch.testing.assert_close(y_q, y_q_ref, atol=0, rtol=0, msg="Output mismatch")
torch.testing.assert_close(dgrad, dgrad_ref, atol=0, rtol=0, msg="Dgrad mismatch")
torch.testing.assert_close(wgrad, wgrad_ref, atol=0, rtol=0, msg="Wgrad mismatch")
if bgrad is not None and bgrad_ref is not None:
torch.testing.assert_close(bgrad, bgrad_ref, atol=0, rtol=0, msg="Bgrad mismatch")
def test_linear():
"""Run linear layer tests with various configurations."""
kwargs_list = [
{"bias": False},
]
for kwargs in kwargs_list:
if kwargs.get("save_original_input", False) and QUANTIZATION == "fp8":
continue
for parallel_mode in ["column", "row"]:
for sequence_parallel in [False, True]:
_test_linear(parallel_mode, sequence_parallel, **kwargs)
############################################
# LayerNormLinear #
############################################
class TestDistributedLayerNormLinearBase(TestDistributedLinearBase):
@classmethod
def run_linear_one_step(cls, layer, x, gradient, is_first_microbatch=None):
# reset gradients
layer.zero_grad()
x.grad = None
# Forward pass
y_q, ln_out = layer.forward(x, is_first_microbatch=is_first_microbatch)
# Backward pass
y_q.backward(gradient)
# Collect gradients
dgrad = x.grad
parameters = layer._parameters
# bias and weight gradients
bgrad = parameters["bias"].grad if parameters.get("bias", None) is not None else None
assert "weight" in parameters
wgrad = parameters["weight"].grad
return y_q, ln_out, dgrad, wgrad, bgrad
@classmethod
def run_linear_multiple_steps(
cls, layer, x, gradient, run_num_steps, enable_weight_cache, fuse_wgrad_accumulation=False
):
# raise error, no test case for multiple steps for now
raise NotImplementedError("LayerNormLinear does not support test multiple steps for now")
@classmethod
def run_layernorm_linear(
cls,
x,
w,
bias,
gradient,
parallel_mode=None,
sequence_parallel=False,
tp_group=None,
tp_size=1,
rank=0,
run_num_steps=1,
enable_weight_cache=False,
LayerNormLinearClass=te.LayerNormLinear,
normalization="LayerNorm",
):
"""
If Model parallel, split inputs for a given rank and return the gathered output and gradients, so that they can be compared with
the reference single GPU run.
"""
# clone inputs and move to current device
# w has shape [N, K], x has shape [M, K], gradient has shape [M, N]
x = x.clone().detach().requires_grad_(True).to("cuda")
w = w.clone().detach().to("cuda")
gradient = gradient.clone().detach().to("cuda")
bias = bias.clone().detach().to("cuda") if bias is not None else None
in_features = x.shape[1]
out_features = w.shape[0]
# If Model parallel: split inputs for a given rank
x, w, bias, gradient = cls.run_linear_preprocess_parallel(
x, w, bias, gradient, parallel_mode, sequence_parallel, tp_size, rank
)
# set data types
params_dtype = x.dtype
# Create linear layer and copy weights
layer = LayerNormLinearClass(
in_features,
out_features,
bias=bias is not None,
params_dtype=params_dtype,
parallel_mode=parallel_mode,
sequence_parallel=sequence_parallel,
tp_group=tp_group,
tp_size=tp_size,
normalization=normalization,
return_layernorm_output=True,
)
layer = layer.to("cuda")
# Copy weights
# kitchen_linear has different parameter names
with torch.no_grad():
layer.weight.copy_(w)
if bias is not None:
layer.bias.copy_(bias)
# Run one step
y_q, ln_out, dgrad, wgrad, bgrad = cls.run_linear_one_step(layer, x, gradient)
# If Model parallel: gather output and gradients from all ranks
y_q, dgrad, wgrad, bgrad = cls.run_linear_postprocess_parallel(
y_q,
dgrad,
wgrad,
bgrad,
parallel_mode,
sequence_parallel,
tp_size,
tp_group,
)
return y_q, ln_out, dgrad, wgrad, bgrad
@run_distributed_test()
def _test_layernorm_linear(parallel_mode=None, sequence_parallel=False, **kwargs):
"""Test the linear layer with specified parallel mode and sequence parallelization.
Args:
parallel_mode (str): 'column' parallelism.
sequence_parallel (bool): Enable sequence parallelism if True.
kwargs (dict): Additional arguments for the linear layer.
"""
params_dtype = torch.bfloat16
use_bias = kwargs.get("bias", True)
seed = torch.initial_seed()
recipe = quantization_recipe()
# run multiple steps currently not supported for LayerNormLinear
run_num_steps = 1
x, w, bias, gradient = TestDistributedLayerNormLinearBase._prepare_data(
BATCH_SIZE, HIDDEN_SIZE, OUT_SIZE, use_bias=use_bias, seed=seed, dtype=params_dtype
)
# run the recipe under test
with te.autocast(enabled=True, recipe=recipe):
y_q, ln_out, dgrad, wgrad, bgrad = TestDistributedLayerNormLinearBase.run_layernorm_linear(
x,
w,
bias,
gradient,
parallel_mode=parallel_mode,
sequence_parallel=sequence_parallel,
tp_group=NCCL_WORLD,
tp_size=WORLD_SIZE,
rank=WORLD_RANK,
run_num_steps=run_num_steps,
enable_weight_cache=False,
)
# run the reference
reference_recipe = quantization_reference_recipe()
with te.autocast(enabled=True, recipe=reference_recipe):
y_q_ref, ln_out_ref, dgrad_ref, wgrad_ref, bgrad_ref = (
TestDistributedLayerNormLinearBase.run_layernorm_linear(
x,
w,
bias,
gradient,
parallel_mode=parallel_mode,
sequence_parallel=sequence_parallel,
tp_group=NCCL_WORLD,
tp_size=WORLD_SIZE,
rank=WORLD_RANK,
run_num_steps=run_num_steps,
enable_weight_cache=False,
)
)
# compare results, zero tolerance
if WORLD_RANK == 0:
torch.testing.assert_close(y_q, y_q_ref, atol=0, rtol=0, msg="Output mismatch")
torch.testing.assert_close(ln_out, ln_out_ref, atol=0, rtol=0, msg="LN output mismatch")
torch.testing.assert_close(dgrad, dgrad_ref, atol=0, rtol=0, msg="Dgrad mismatch")
torch.testing.assert_close(wgrad, wgrad_ref, atol=0, rtol=0, msg="Wgrad mismatch")
if bgrad is not None and bgrad_ref is not None:
torch.testing.assert_close(bgrad, bgrad_ref, atol=0, rtol=0, msg="Bgrad mismatch")
def test_layernorm_linear():
kwargs_list = [
{"bias": False},
]
for kwargs in kwargs_list:
for parallel_mode in ["column"]:
for sequence_parallel in [False, True]:
_test_layernorm_linear(parallel_mode, sequence_parallel, **kwargs)
if __name__ == "__main__":
sys.exit(main())
......@@ -8,15 +8,15 @@ from pathlib import Path
import pytest
import torch
from transformer_engine.pytorch.fp8 import FP8GlobalStateManager
from transformer_engine.pytorch import is_fp8_available, is_fp8_block_scaling_available
# NVTE_DISABLE_NVRTC=1 NVTE_INT8_SIM_FP8=1 torchrun --nproc_per_node=4 run_cast_master_weights_to_fp8.py --quantization fp8_block
if torch.cuda.device_count() < 2:
pytest.skip("cast_master_weights_to_fp8 test needs at least 2 GPUs.")
fp8_available, reason_for_no_fp8 = FP8GlobalStateManager.is_fp8_available()
fp8_block_scaling_available, reason_for_no_fp8_block_scaling = (
FP8GlobalStateManager.is_fp8_block_scaling_available()
fp8_available, reason_for_no_fp8 = is_fp8_available(return_reason=True)
fp8_block_scaling_available, reason_for_no_fp8_block_scaling = is_fp8_block_scaling_available(
return_reason=True
)
TEST_ROOT = Path(__file__).parent.resolve()
......
......@@ -14,14 +14,13 @@ import pytest
import torch
import transformer_engine.pytorch as te
import transformer_engine.pytorch.cpp_extensions as tex
from transformer_engine.pytorch.fp8 import FP8GlobalStateManager
import logging
if torch.cuda.device_count() < 2:
pytest.skip("Comm+GEMM overlap requires at least 2 GPUs.")
fp8_available, reason_for_no_fp8 = FP8GlobalStateManager.is_fp8_available()
mxfp8_available, reason_for_no_mxfp8 = FP8GlobalStateManager.is_mxfp8_available()
fp8_available, reason_for_no_fp8 = te.is_fp8_available(return_reason=True)
mxfp8_available, reason_for_no_mxfp8 = te.is_mxfp8_available(return_reason=True)
RNG_SEED: int = 42
SEQ_LENGTH: int = 1024
......
......@@ -20,31 +20,34 @@ import torch
import transformer_engine
import transformer_engine.common.recipe
import transformer_engine.pytorch as te
from transformer_engine.pytorch.fp8 import FP8GlobalStateManager
from transformer_engine.pytorch.tensor import QuantizedTensor
from transformer_engine.pytorch.tensor.float8_tensor import (
from transformer_engine.pytorch import (
QuantizedTensor,
Float8Quantizer,
Float8CurrentScalingQuantizer,
MXFP8Quantizer,
NVFP4Quantizer,
is_bf16_available,
)
from transformer_engine.pytorch.tensor.mxfp8_tensor import MXFP8Quantizer
import transformer_engine.pytorch.ops as te_ops
from transformer_engine.pytorch.utils import is_bf16_compatible
import transformer_engine_torch as tex
from torch.utils.cpp_extension import IS_HIP_EXTENSION
# Import utility functions
_current_file = pathlib.Path(__file__).resolve()
sys.path.append(str(_current_file.parent.parent))
from utils import dtype_tols, make_recipe
from utils import dtype_tols, make_recipe, quantization_tols
# Check what quantization schemes are supported
fp8_available, reason_for_no_fp8 = FP8GlobalStateManager.is_fp8_available()
mxfp8_available, reason_for_no_mxfp8 = FP8GlobalStateManager.is_mxfp8_available()
fp8_available, reason_for_no_fp8 = te.is_fp8_available(return_reason=True)
mxfp8_available, reason_for_no_mxfp8 = te.is_mxfp8_available(return_reason=True)
nvfp4_available, reason_for_no_nvfp4 = te.is_nvfp4_available(return_reason=True)
quantization_list: list[Optional[str]] = [None]
if fp8_available:
quantization_list.extend(("fp8_delayed_scaling", "fp8_current_scaling"))
if mxfp8_available:
quantization_list.append("mxfp8")
if nvfp4_available:
quantization_list.append("nvfp4")
@functools.cache
......@@ -115,6 +118,14 @@ def make_reference_and_test_tensors(
test = quantizer(test)
elif quantization == "mxfp8":
test = MXFP8Quantizer(fp8_dtype=tex.DType.kFloat8E4M3)(test)
elif quantization == "nvfp4":
test = NVFP4Quantizer(
with_rht=False,
with_post_rht_amax=False,
with_2d_quantization=False,
stochastic_rounding=False,
with_random_sign_mask=False,
)(test)
else:
raise ValueError(f"Unsupported quantization scheme ({quantization})")
if isinstance(test, QuantizedTensor) and not test_is_quantized:
......@@ -415,7 +426,7 @@ def _test_basic_linear(
# Implementation with fusible operation
recipe = make_recipe(quantization)
with te.fp8_model_init(enabled=quantized_weight, recipe=recipe):
with te.quantized_model_init(enabled=quantized_weight, recipe=recipe):
op = te_ops.BasicLinear(
in_features,
out_features,
......@@ -428,7 +439,7 @@ def _test_basic_linear(
with torch.no_grad():
op.weight.copy_(w_test)
del w_test
with te.fp8_autocast(enabled=quantized_compute, fp8_recipe=recipe):
with te.autocast(enabled=quantized_compute, recipe=recipe):
y_test = op(x_test)
y_test.backward(dy_test)
......@@ -437,7 +448,7 @@ def _test_basic_linear(
if dtype == torch.float32:
tols = dtype_tols(torch.float16) # TF32 GEMM
if quantized_compute:
tols = dtype_tols(tex.DType.kFloat8E4M3)
tols = quantization_tols(quantization)
# Check results
y_test = y_test.to(dtype=torch.float64, device="cpu")
......@@ -581,7 +592,7 @@ def _test_linear(
# Implementation with fusible operation
recipe = make_recipe(quantization)
with te.fp8_model_init(enabled=quantized_weight, recipe=recipe):
with te.quantized_model_init(enabled=quantized_weight, recipe=recipe):
model = te_ops.Sequential(
te_ops.Linear(
in_features,
......@@ -600,7 +611,7 @@ def _test_linear(
model[0].bias.copy_(b_test)
del w_test
del b_test
with te.fp8_autocast(enabled=quantized_compute, fp8_recipe=recipe):
with te.autocast(enabled=quantized_compute, recipe=recipe):
y_test = model(x_test)
y_test.backward(dy_test)
......@@ -609,7 +620,7 @@ def _test_linear(
if dtype == torch.float32:
tols = dtype_tols(torch.float16) # TF32 GEMM
if quantized_compute:
tols = dtype_tols(tex.DType.kFloat8E4M3)
tols = quantization_tols(quantization)
# Check results
y_test = y_test.to(dtype=torch.float64, device="cpu")
......@@ -623,6 +634,204 @@ def _test_linear(
torch.testing.assert_close(db_test, db_ref, **tols)
def _test_mlp(
*,
bias: bool = True,
hidden_size: int = 32,
local_batch_size: int = 32,
dtype: torch.dtype = torch.float32,
device: torch.device = "cuda",
quantization: Optional[str] = None,
quantized_weight: bool = False,
sequence_parallel: bool = False,
) -> None:
"""2-layer MLP
MLP includes GELU activation in order to test op fusions. Model
performs warmup steps in order to test inter-step logic.
"""
# Skip invalid configurations
quantized_compute = quantization is not None
if not quantized_compute and quantized_weight:
return
# Distributed process group
process_group = world_group()
rank = torch.distributed.get_rank(process_group)
world_size = torch.distributed.get_world_size(process_group)
# Tensor dimensions
mlp_size = hidden_size * world_size
batch_size = local_batch_size
if sequence_parallel:
batch_size *= world_size
in_shape = (batch_size, hidden_size)
# Random data
reset_rng()
x_ref, x_test = make_reference_and_test_tensors(
in_shape,
quantization=quantization,
test_dtype=dtype,
test_device=device,
)
w1_ref, w1_test = make_reference_and_test_tensors(
(mlp_size, hidden_size),
quantization=quantization,
test_dtype=dtype,
test_device=device,
)
b1_ref, b1_test = None, None
w2_ref, w2_test = make_reference_and_test_tensors(
(hidden_size, mlp_size),
quantization=quantization,
test_dtype=dtype,
test_device=device,
)
b2_ref, b2_test = None, None
if bias:
b1_ref, b1_test = make_reference_and_test_tensors(
(mlp_size,),
test_dtype=dtype,
test_device=device,
)
b2_ref, b2_test = make_reference_and_test_tensors(
(world_size, hidden_size),
test_dtype=dtype,
test_device=device,
)
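# b2 is stored as one (hidden_size,) slice per rank: the plain-PyTorch reference
# below adds b2_ref.sum(dim=0) (the row-parallel bias is applied once after the
# reduction), while each rank's test model keeps only its own b2_ref[rank, :] slice.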
dy_ref, dy_test = make_reference_and_test_tensors(
in_shape,
quantization=quantization,
test_dtype=dtype,
test_device=device,
requires_grad=False,
)
# Plain PyTorch implementation
y_ref = torch.nn.functional.gelu(x_ref, approximate="tanh")
y_ref = torch.nn.functional.linear(y_ref, w1_ref)
if bias:
y_ref += b1_ref
y_ref = torch.nn.functional.gelu(y_ref, approximate="tanh")
y_ref = torch.nn.functional.linear(y_ref, w2_ref)
if bias:
y_ref += b2_ref.sum(dim=0)
y_ref = torch.nn.functional.gelu(y_ref, approximate="tanh")
y_ref.backward(dy_ref)
# Convert to distributed tensors
with torch.no_grad():
local_mlp_size = mlp_size // world_size
local_mlp_slice = slice(rank * local_mlp_size, (rank + 1) * local_mlp_size)
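# Each rank owns a contiguous local_mlp_size slice of the MLP hidden dimension:
# the first Linear is column-parallel (rows of w1), the second is row-parallel
# (columns of w2), matching the te_ops.Sequential model built below.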
dx_ref = x_ref.grad
dw1_ref = w1_ref.grad[local_mlp_slice, :]
w1_ref = w1_ref[local_mlp_slice, :]
w1_test = w1_test[local_mlp_slice, :]
dw2_ref = w2_ref.grad[:, local_mlp_slice]
w2_ref = w2_ref[:, local_mlp_slice]
w2_test = w2_test[:, local_mlp_slice]
if bias:
db1_ref = b1_ref.grad[local_mlp_slice]
b1_ref = b1_ref[local_mlp_slice]
b1_test = b1_test[local_mlp_slice]
db2_ref = b2_ref.grad[rank, :]
b2_ref = b2_ref[rank, :]
b2_test = b2_test[rank, :]
else:
db1_ref = None
db2_ref = None
if sequence_parallel:
local_batch_slice = slice(
rank * local_batch_size,
(rank + 1) * local_batch_size,
)
x_ref = x_ref[local_batch_slice, ...]
dx_ref = dx_ref[local_batch_slice, ...]
x_test = x_test[local_batch_slice, ...].clone()
y_ref = y_ref[local_batch_slice, ...]
dy_ref = dy_ref[local_batch_slice, ...]
dy_test = dy_test[local_batch_slice, ...].clone()
x_test.requires_grad_()
# Implementation with fusible operation
recipe = make_recipe(quantization)
with te.quantized_model_init(enabled=quantized_weight, recipe=recipe):
model = te_ops.Sequential(
te_ops.GELU(),
te_ops.Linear(
hidden_size,
mlp_size,
bias=bias,
device=device,
dtype=dtype,
tensor_parallel_mode="column",
tensor_parallel_group=process_group,
sequence_parallel=sequence_parallel,
),
te_ops.GELU(),
te_ops.Linear(
mlp_size,
hidden_size,
bias=bias,
device=device,
dtype=dtype,
tensor_parallel_mode="row",
tensor_parallel_group=process_group,
sequence_parallel=sequence_parallel,
),
te_ops.GELU(),
)
with torch.no_grad():
model[1].weight.copy_(w1_test)
model[3].weight.copy_(w2_test)
if bias:
model[1].bias.copy_(b1_test)
model[3].bias.copy_(b2_test)
del w1_test, w2_test, b1_test, b2_test
# Warmup steps
for _ in range(3):
with te.autocast(enabled=quantized_compute, recipe=recipe):
y_test = model(x_test)
y_test.backward(dy_test)
x_test.grad = None
model[1].weight.grad = None
model[3].weight.grad = None
if bias:
model[1].bias.grad = None
model[3].bias.grad = None
# Forward and backward step
with te.autocast(enabled=quantized_compute, recipe=recipe):
y_test = model(x_test)
y_test.backward(dy_test)
# Expected numerical error
tols = dtype_tols(dtype)
if dtype == torch.float32:
tols = dtype_tols(torch.float16) # TF32 GEMM
if quantized_compute:
tols = quantization_tols(quantization)
# Check results
y_test = y_test.to(dtype=torch.float64, device="cpu")
dx_test = x_test.grad.to(dtype=torch.float64, device="cpu")
dw1_test = model[1].weight.grad.to(dtype=torch.float64, device="cpu")
dw2_test = model[3].weight.grad.to(dtype=torch.float64, device="cpu")
torch.testing.assert_close(y_test, y_ref, **tols)
torch.testing.assert_close(dx_test, dx_ref, **tols)
torch.testing.assert_close(dw1_test, dw1_ref, **tols)
torch.testing.assert_close(dw2_test, dw2_ref, **tols)
if bias:
db1_test = model[1].bias.grad.to(dtype=torch.float64, device="cpu")
db2_test = model[3].bias.grad.to(dtype=torch.float64, device="cpu")
torch.testing.assert_close(db1_test, db1_ref, **tols)
torch.testing.assert_close(db2_test, db2_ref, **tols)
def _test_fp8_scale_update(
*,
amax_history_len: int = 31,
......@@ -734,7 +943,7 @@ def _test_fp8_scale_update(
amax_history_len=amax_history_len,
amax_compute_algo=amax_compute_algo,
)
with te.fp8_autocast(fp8_recipe=recipe):
with te.autocast(recipe=recipe):
y_test = op(x_test)
y_test.backward(dy_test)
......@@ -789,16 +998,31 @@ def run_parallel_tests() -> None:
for config in itertools.product(
quantization_list,
("column", "row"),
(False, True),
):
if rank == 0:
print(f"Running _test_linear with {config=}")
quantization, tensor_parallel_mode = config
dtype = torch.bfloat16 if is_bf16_compatible() else torch.float32
quantization, tensor_parallel_mode, sequence_parallel = config
dtype = torch.bfloat16 if is_bf16_available() else torch.float32
_test_linear(
bias=True, # bias=False is tested in _test_basic_linear
dtype=dtype,
quantization=quantization,
tensor_parallel_mode=tensor_parallel_mode,
sequence_parallel=sequence_parallel,
)
# MLP
for config in itertools.product(quantization_list, (False, True)):
if rank == 0:
print(f"Running _test_mlp with {config=}")
quantization, sequence_parallel = config
dtype = torch.bfloat16 if is_bf16_available() else torch.float32
_test_mlp(
bias=True, # bias=False is tested in _test_basic_linear
dtype=dtype,
quantization=quantization,
sequence_parallel=sequence_parallel,
)
# FP8 scale update
......
......@@ -16,23 +16,23 @@ import sys
import pytest
import torch
from typing import Optional, Iterable
import transformer_engine
import transformer_engine.pytorch as te
import transformer_engine.pytorch.cpp_extensions as tex
from transformer_engine.pytorch.fp8 import FP8GlobalStateManager
import transformer_engine.pytorch.ops as te_ops
from transformer_engine.pytorch.ops.fused import (
UserbuffersBackwardLinear,
UserbuffersForwardLinear,
)
from transformer_engine.pytorch.tensor.float8_tensor import (
from transformer_engine.pytorch import (
Float8Quantizer,
Float8CurrentScalingQuantizer,
MXFP8Quantizer,
QuantizedTensor,
Float8Tensor,
)
from transformer_engine.pytorch.tensor.mxfp8_tensor import MXFP8Quantizer
from transformer_engine.pytorch.tensor.quantized_tensor import QuantizedTensor
from transformer_engine.pytorch.tensor.float8_tensor import Float8Tensor
from transformer_engine.pytorch.utils import is_bf16_compatible
# Import utility functions
_current_file = pathlib.Path(__file__).resolve()
......@@ -40,8 +40,8 @@ sys.path.append(str(_current_file.parent.parent))
from utils import dtype_tols, make_recipe, str_to_dtype
# Check if FP8 is supported
fp8_available, reason_for_no_fp8 = FP8GlobalStateManager.is_fp8_available()
mxfp8_available, reason_for_no_mxfp8 = FP8GlobalStateManager.is_mxfp8_available()
fp8_available, reason_for_no_fp8 = te.is_fp8_available(return_reason=True)
mxfp8_available, reason_for_no_mxfp8 = te.is_mxfp8_available(return_reason=True)
quantization_list: list[Optional[str]] = [None]
if fp8_available:
quantization_list.extend(("fp8_delayed_scaling", "fp8_current_scaling"))
......@@ -301,7 +301,7 @@ def _test_linear(
# Implementation with fusible operation
recipe = make_recipe(quantization)
with te.fp8_model_init(enabled=quantized_compute, recipe=recipe):
with te.quantized_model_init(enabled=quantized_compute, recipe=recipe):
ops = []
linear_op = None
bias_op = None
......@@ -351,7 +351,7 @@ def _test_linear(
bias_op.bias.copy_(b_test)
del w_test
del b_test
with te.fp8_autocast(enabled=quantized_compute, fp8_recipe=recipe):
with te.autocast(enabled=quantized_compute, recipe=recipe):
y_test = model(x_test)
y_test.backward(dy_test)
......
......@@ -8,9 +8,8 @@ from pathlib import Path
import pytest
import torch
from transformer_engine.pytorch.fp8 import FP8GlobalStateManager
import transformer_engine.pytorch as te
from torch.utils.cpp_extension import IS_HIP_EXTENSION
import transformer_engine as te
"""
Distributed numerics tests
......@@ -27,11 +26,12 @@ import transformer_engine as te
if torch.cuda.device_count() < 2:
pytest.skip("Distributed training needs at least 2 GPUs.")
fp8_available, reason_for_no_fp8 = FP8GlobalStateManager.is_fp8_available()
mxfp8_available, reason_for_no_mxfp8 = FP8GlobalStateManager.is_mxfp8_available()
fp8_block_scaling_available, reason_for_no_fp8_block_scaling = (
FP8GlobalStateManager.is_fp8_block_scaling_available()
fp8_available, reason_for_no_fp8 = te.is_fp8_available(return_reason=True)
mxfp8_available, reason_for_no_mxfp8 = te.is_mxfp8_available(return_reason=True)
fp8_block_scaling_available, reason_for_no_fp8_block_scaling = te.is_fp8_block_scaling_available(
return_reason=True
)
nvfp4_available, reason_for_no_nvfp4 = te.is_nvfp4_available(return_reason=True)
TEST_ROOT = Path(__file__).parent.resolve()
NUM_PROCS: int = min(4, torch.cuda.device_count())
......@@ -52,7 +52,9 @@ def _run_test(quantization):
all_boolean = [True, False]
@pytest.mark.parametrize("quantization", [None, "fp8", "mxfp8", "fp8_cs", "fp8_block_scaling"])
@pytest.mark.parametrize(
"quantization", [None, "fp8", "mxfp8", "fp8_cs", "fp8_block_scaling", "nvfp4"]
)
def test_distributed(quantization):
if quantization == "fp8" and not fp8_available:
pytest.skip(reason_for_no_fp8)
......@@ -62,15 +64,17 @@ def test_distributed(quantization):
pytest.skip(reason_for_no_mxfp8)
if quantization == "fp8_block_scaling" and not fp8_block_scaling_available:
pytest.skip(reason_for_no_fp8_block_scaling)
if quantization == "nvfp4" and not nvfp4_available:
pytest.skip(reason_for_no_nvfp4)
if IS_HIP_EXTENSION and quantization == "fp8_block_scaling":
import importlib
ori_int8_sim_fp8 = os.environ.get("NVTE_INT8_SIM_FP8", "None")
os.environ["NVTE_INT8_SIM_FP8"] = "1"
importlib.reload(te.pytorch.fp8)
importlib.reload(te.fp8)
_run_test(quantization)
if IS_HIP_EXTENSION and quantization == "fp8_block_scaling":
# Restore the original environment setting
if ori_int8_sim_fp8 == "None":
del os.environ["NVTE_INT8_SIM_FP8"]
else:
os.environ["NVTE_INT8_SIM_FP8"] = ori_int8_sim_fp8
importlib.reload(te.pytorch.fp8)
importlib.reload(te.fp8)
# Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# See LICENSE for license information.
import os
import subprocess
from pathlib import Path
import pytest
import torch
import transformer_engine.pytorch as te
"""
Distributed numerics tests
These numerical tests aim for zero tolerance, giving absolute confidence in numerics.
In the case of NVFP4, the experimental NVFP4 quantization matched the native silicon
results bitwise. For distributed test cases, we can do the same by comparing BF16 AG
results with the low-precision AG results at the layer level.
"""
if torch.cuda.device_count() < 2:
pytest.skip("Distributed training needs at least 2 GPUs.")
fp8_available, reason_for_no_fp8 = te.is_fp8_available(return_reason=True)
mxfp8_available, reason_for_no_mxfp8 = te.is_mxfp8_available(return_reason=True)
fp8_block_scaling_available, reason_for_no_fp8_block_scaling = te.is_fp8_block_scaling_available(
return_reason=True
)
nvfp4_available, reason_for_no_nvfp4 = te.is_nvfp4_available(return_reason=True)
TEST_ROOT = Path(__file__).parent.resolve()
NUM_PROCS: int = min(4, torch.cuda.device_count())
LAUNCH_CMD = ["torchrun", f"--nproc_per_node={NUM_PROCS}"]
def _run_test(quantization, batch_size, hidden_size, out_size):
test_path = TEST_ROOT / "run_numerics_exact.py"
test_cmd = LAUNCH_CMD + [str(test_path)]
test_cmd += ["--quantization", quantization]
test_cmd += ["--batch-size", str(batch_size)]
test_cmd += ["--hidden-size", str(hidden_size)]
test_cmd += ["--out-size", str(out_size)]
result = subprocess.run(test_cmd, env=os.environ, check=False)
assert result.returncode == 0
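# For example, with NUM_PROCS == 4 and the first parametrization below, the command
# launched is along the lines of (paths shortened for illustration):
#   torchrun --nproc_per_node=4 run_numerics_exact.py \
#       --quantization nvfp4 --batch-size 64 --hidden-size 128 --out-size 128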
all_boolean = [True, False]
@pytest.mark.parametrize("quantization", ["nvfp4"])
@pytest.mark.parametrize(
"batch_size, hidden_size, out_size",
[
(64, 128, 128),
(128, 128, 128),
(128, 256, 256),
(512, 1024, 768),
(512, 256, 1024),
(2048, 2048, 2048),
],
)
def test_distributed(quantization, batch_size, hidden_size, out_size):
if quantization == "nvfp4" and not nvfp4_available:
pytest.skip(reason_for_no_nvfp4)
_run_test(quantization, batch_size, hidden_size, out_size)
......@@ -7,8 +7,7 @@ import sys
import pytest
import torch
import transformer_engine
from transformer_engine.pytorch.attention.dot_product_attention import DotProductAttention
from transformer_engine.pytorch import TransformerLayer, Linear
from transformer_engine.pytorch import DotProductAttention, TransformerLayer, Linear
_current_file = pathlib.Path(__file__).resolve()
sys.path.append(str(_current_file.parent.parent))
......
......@@ -7,12 +7,11 @@ import pytest
import subprocess
from pathlib import Path
import torch
from transformer_engine.pytorch import torch_version
from transformer_engine.pytorch.fp8 import FP8GlobalStateManager
import transformer_engine.pytorch as te
fp8_available, reason_for_no_fp8 = FP8GlobalStateManager.is_fp8_available()
fp8_available, reason_for_no_fp8 = te.is_fp8_available(return_reason=True)
NUM_PROCS: int = torch.cuda.device_count()
......@@ -34,7 +33,7 @@ def _run_test(fp_init, sharding_dims):
@pytest.mark.skipif(NUM_PROCS < 4, reason="Requires 4+ GPUs")
@pytest.mark.skipif(NUM_PROCS % 2 != 0, reason="Requires even number of GPUs")
@pytest.mark.skipif(not torch_version() >= (2, 4, 0), reason="Requires PyTorch 2.4.0+")
@pytest.mark.skipif(not te.torch_version() >= (2, 4, 0), reason="Requires PyTorch 2.4.0+")
@pytest.mark.parametrize("sharding_dims", ([NUM_PROCS], [2, NUM_PROCS // 2]))
@pytest.mark.parametrize("fp8_init", (False, True))
def test_distributed(fp8_init, sharding_dims):
......
# Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# See LICENSE for license information.
import pytest
import torch
import transformer_engine.pytorch as te
import transformer_engine_torch as tex
from transformer_engine.pytorch.constants import TE_DType
from transformer_engine.pytorch import NVFP4Quantizer
from transformer_engine.pytorch.experimental.quantization_nvfp4 import NVFP4QuantizerRef
from transformer_engine.pytorch.experimental import utils
recipe_available, reason_for_no_recipe = te.is_nvfp4_available(return_reason=True)
def check_nvfp4_gemm_versus_reference(
x_dtype: torch.dtype,
w_dtype: torch.dtype,
out_dtype: torch.dtype,
M: int,
K: int,
N: int,
accumulate: bool,
*,
x_columnwise: bool = False,
w_columnwise: bool = False,
):
te_dtype = tex.DType.kFloat4E2M1
# Setup device and random seed
device = "cuda"
seed = 0
torch.manual_seed(seed)
torch.cuda.manual_seed(seed)
# Input tensors
x_shape = (K, M) if x_columnwise else (M, K)
w_shape = (K, N) if w_columnwise else (N, K)
x = torch.randn(x_shape, dtype=x_dtype, device=device)
w = torch.randn(w_shape, dtype=w_dtype, device=device)
# Setup out tensor if accumulate is True
if accumulate:
out = torch.randn((M, N), dtype=out_dtype, device=device)
else:
out = None
# Native TE NVFP4 quantization
x_quantizer = NVFP4Quantizer(
fp4_dtype=te_dtype,
rowwise=True,
columnwise=True,
with_amax_reduction=False,
amax_reduction_group=None,
with_rht=False,
with_post_rht_amax=False,
)
w_quantizer = NVFP4Quantizer(
fp4_dtype=te_dtype,
rowwise=True,
columnwise=True,
with_amax_reduction=False,
amax_reduction_group=None,
with_rht=False,
with_post_rht_amax=False,
)
# Quantize x and w
x_nvfp4_native = x_quantizer.make_empty(
x_shape, dtype=x_dtype, device=device, requires_grad=False
)
x_nvfp4_native = x_quantizer.update_quantized(x, x_nvfp4_native)
w_nvfp4_native = w_quantizer.make_empty(
w_shape, dtype=w_dtype, device=device, requires_grad=False
)
w_nvfp4_native = w_quantizer.update_quantized(w, w_nvfp4_native)
# Extract quantized data from native NVFP4Tensors
qx_data = (
x_nvfp4_native._columnwise_data.view(dtype=torch.uint8)
if x_columnwise
else x_nvfp4_native._rowwise_data.view(dtype=torch.uint8)
)
qw_data = (
w_nvfp4_native._columnwise_data.view(dtype=torch.uint8)
if w_columnwise
else w_nvfp4_native._rowwise_data.view(dtype=torch.uint8)
)
sx_native = (
x_nvfp4_native._columnwise_scale_inv if x_columnwise else x_nvfp4_native._rowwise_scale_inv
)
sw_native = (
w_nvfp4_native._columnwise_scale_inv if w_columnwise else w_nvfp4_native._rowwise_scale_inv
)
# Trim quantized data to match the actual tensor dimensions (remove padding)
qx_data = qx_data[:M, :]
qw_data = qw_data[:N, :]
# NVFP4 uses 16-element blocks; trim scales to remove padding
block_length = 16
expected_sx_cols = expected_sw_cols = K // block_length
# Trim the scales to remove padding
sx_trimmed = sx_native[:M, :expected_sx_cols]
sw_trimmed = sw_native[:N, :expected_sw_cols]
# Native scales are stored as uint8 but need to be interpreted as float8_e4m3fn
# for the reference GEMM to work correctly
sx_trimmed = sx_trimmed.view(torch.float8_e4m3fn)
sw_trimmed = sw_trimmed.view(torch.float8_e4m3fn)
# Create reference quantizer for reference GEMM
ref_quantizer = NVFP4QuantizerRef(
dtype=utils.Fp4Formats.E2M1,
rowwise=True,
columnwise=True,
pow_2_scales=False,
eps=0.0,
quant_tile_shape=(1, 16),
)
# Create reference quantized tensors needed by reference GEMM
x_nvfp4_ref = ref_quantizer.quantize(x)
w_nvfp4_ref = ref_quantizer.quantize(w)
# Reference GEMM using quantizer's qgemm method
y_ref = ref_quantizer.qgemm(
qx=qx_data,
qw=qw_data,
m_params=None, # MMParams not used in reference
out_dtype=out_dtype,
sx=sx_trimmed,
sw=sw_trimmed,
bias=None, # No bias for this test
out=out.clone() if accumulate else None,
accumulate=accumulate,
gemm_type=None, # GEMMType not used in reference
qresult_x=x_nvfp4_ref,
qresult_w=w_nvfp4_ref,
)
# Native TE GEMM using tex.generic_gemm (cuBLAS GEMM)
# Allocate cuBLAS workspace
workspace = torch.empty(4, dtype=torch.uint8, device=device)
transa = not w_columnwise
transb = x_columnwise
out_quantizer = None
bias = None
bias_dtype = TE_DType[torch.bfloat16]
use_gelu = False
gelu_input = None
use_grad = False
use_split_accumulator = False
# Native cuBLAS GEMM
# The return value is (out, bias_grad, gelu_input, extra_output); we only capture out.
y_native = tex.generic_gemm(
w_nvfp4_native,
transa,
x_nvfp4_native,
transb,
out.clone() if accumulate else None,
out_quantizer,
TE_DType[out_dtype],
bias,
bias_dtype,
use_gelu,
gelu_input,
use_grad,
workspace,
workspace.shape[0],
accumulate,
use_split_accumulator,
)[0]
# In case of accumulation, make sure y_ref and y_native are not the same tensor
assert y_ref is not y_native, "y_ref and y_native should not be the same tensor"
# Reset NaNs to zeros because torch.testing.assert_close treats NaNs as unequal by default
assert not torch.isnan(y_ref.float()).all(), "All elements are nan"
y_ref = torch.where(y_ref.isnan(), torch.zeros_like(y_ref), y_ref)
y_native = torch.where(y_native.isnan(), torch.zeros_like(y_native), y_native)
# Compare results with some tolerance
torch.testing.assert_close(y_native, y_ref, atol=8e-3, rtol=8e-3)
@pytest.mark.skipif(not recipe_available, reason=reason_for_no_recipe)
@pytest.mark.parametrize(
"M, K, N",
[
(128, 128, 128),
(256, 128, 256),
(256, 256, 256),
(256, 1024, 256),
(1024, 1024, 1024),
(4096, 512, 3072),
(112, 128, 96),
(304, 640, 304),
(1008, 3072, 992),
(256, 64, 256),
(128, 128, 112),
],
)
@pytest.mark.parametrize("x_dtype", [torch.float32, torch.bfloat16], ids=str)
@pytest.mark.parametrize("w_dtype", [torch.float32, torch.bfloat16], ids=str)
@pytest.mark.parametrize("out_dtype", [torch.bfloat16, torch.float32], ids=str)
@pytest.mark.parametrize("accumulate", [True, False], ids=["accumulate", "no_accumulate"])
@pytest.mark.parametrize(
"is_x_columnwise, is_w_columnwise",
[
(False, False), # Only rowwise x rowwise is supported by reference GEMM
# Note: Reference GEMM expects inputs as (M,K) x (N,K) with rowwise quantization
# Columnwise layouts are not supported by the reference implementation
],
ids=["rowxrow"],
)
def test_nvfp4_gemm_versus_reference(
M: int,
K: int,
N: int,
x_dtype: torch.dtype,
w_dtype: torch.dtype,
out_dtype: torch.dtype,
accumulate: bool,
is_x_columnwise: bool,
is_w_columnwise: bool,
):
check_nvfp4_gemm_versus_reference(
x_dtype=x_dtype,
w_dtype=w_dtype,
out_dtype=out_dtype,
M=M,
K=K,
N=N,
accumulate=accumulate,
x_columnwise=is_x_columnwise,
w_columnwise=is_w_columnwise,
)
# Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# See LICENSE for license information.
import pytest
import torch
import transformer_engine.pytorch as te
from transformer_engine.common import recipe
from transformer_engine.pytorch.experimental import quantization_nvfp4
from transformer_engine.pytorch.experimental import utils
recipe_available, reason_for_no_recipe = te.is_nvfp4_available(return_reason=True)
class GetRecipes:
@staticmethod
def nvfp4_vanilla():
nvfp4_recipe = recipe.NVFP4BlockScaling()
nvfp4_recipe.fp4_quant_fwd_inp = recipe.QParams()
nvfp4_recipe.fp4_quant_fwd_weight = recipe.QParams()
nvfp4_recipe.fp4_quant_bwd_grad = recipe.QParams()
return nvfp4_recipe
@staticmethod
def nvfp4_rht_only():
nvfp4_recipe = recipe.NVFP4BlockScaling()
nvfp4_recipe.fp4_quant_fwd_inp = recipe.QParams(random_hadamard_transform=True)
nvfp4_recipe.fp4_quant_fwd_weight = recipe.QParams(random_hadamard_transform=False)
nvfp4_recipe.fp4_quant_bwd_grad = recipe.QParams(random_hadamard_transform=True)
return nvfp4_recipe
@staticmethod
def nvfp4_2d_quantization_only():
nvfp4_recipe = recipe.NVFP4BlockScaling()
nvfp4_recipe.fp4_quant_fwd_inp = recipe.QParams(fp4_2d_quantization=False)
nvfp4_recipe.fp4_quant_fwd_weight = recipe.QParams(fp4_2d_quantization=True)
nvfp4_recipe.fp4_quant_bwd_grad = recipe.QParams(fp4_2d_quantization=False)
return nvfp4_recipe
@staticmethod
def nvfp4_rht_and_2d_quantization():
nvfp4_recipe = recipe.NVFP4BlockScaling()
nvfp4_recipe.fp4_quant_fwd_inp = recipe.QParams(
random_hadamard_transform=True, fp4_2d_quantization=False
)
nvfp4_recipe.fp4_quant_fwd_weight = recipe.QParams(
random_hadamard_transform=False, fp4_2d_quantization=True
)
nvfp4_recipe.fp4_quant_bwd_grad = recipe.QParams(
random_hadamard_transform=True, fp4_2d_quantization=False
)
return nvfp4_recipe
@staticmethod
def nvfp4_recipe_to_test(with_rht: bool = False, with_2d_quantization: bool = False):
if with_rht and with_2d_quantization:
return GetRecipes.nvfp4_rht_and_2d_quantization()
elif with_rht:
return GetRecipes.nvfp4_rht_only()
elif with_2d_quantization:
return GetRecipes.nvfp4_2d_quantization_only()
else:
return GetRecipes.nvfp4_vanilla()
def get_nvfp4_quantizer_factory(with_rht: bool = False, with_2d_quantization: bool = False):
"""
Create a quantizer factory for NVFP4 reference implementation.
This factory returns NVFP4QuantizerRef instances based on the role and configuration.
Used with CustomRecipe to create reference quantizers; a usage sketch follows this function.
Args:
with_rht: Whether to enable random Hadamard transform
with_2d_quantization: Whether to use 2D quantization (16x16 tiles for weights)
Returns:
A factory function that takes a role string and returns a quantizer instance
"""
def factory(role):
if role == "linear_input":
return quantization_nvfp4.NVFP4QuantizerRef(
dtype=utils.Fp4Formats.E2M1,
quant_tile_shape=(1, 16),
pow_2_scales=False,
with_rht=with_rht,
)
elif role == "linear_weight":
return quantization_nvfp4.NVFP4QuantizerRef(
dtype=utils.Fp4Formats.E2M1,
quant_tile_shape=(16, 16) if with_2d_quantization else (1, 16),
pow_2_scales=False,
with_rht=False,
)
elif role == "linear_output":
# Output quantization not used
return None
elif role == "linear_grad_output":
return quantization_nvfp4.NVFP4QuantizerRef(
dtype=utils.Fp4Formats.E2M1,
quant_tile_shape=(1, 16),
pow_2_scales=False,
with_rht=with_rht,
)
elif role == "linear_grad_input":
# Grad input quantization not used
return None
else:
# For any other roles, return None
return None
return factory
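# Usage sketch (mirrors how this factory is consumed later in this file): hand the
# factory to a CustomRecipe and run the reference module under te.autocast.
#
#     nvfp4_ref_factory = get_nvfp4_quantizer_factory(with_rht, with_2d_quantization)
#     nvfp4_ref_recipe = recipe.CustomRecipe(qfactory=nvfp4_ref_factory)
#     with te.autocast(enabled=True, recipe=nvfp4_ref_recipe):
#         y_ref = ref_module(x_ref)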
def reset_rng_states():
seed = 1234
torch.manual_seed(seed)
torch.cuda.manual_seed(seed)
def check_nvfp4_module_versus_reference(
module_class,
in_features: int,
out_features: int,
bias: bool,
x_dtype: torch.dtype,
num_steps: int = 1,
with_rht: bool = False,
with_2d_quantization: bool = False,
):
"""
Compare native NVFP4 module against reference implementation.
Args:
module_class: te.Linear or te.LayerNormLinear
in_features: Input feature dimension
out_features: Output feature dimension
bias: Whether to use bias
x_dtype: Input tensor dtype
num_steps: Number of forward/backward steps to test
with_rht: Whether to apply the random Hadamard transform to inputs and grad outputs
with_2d_quantization: Whether weights use 2D (16x16-tile) quantization
"""
device = "cuda"
batch_size = 32
seq_len = 128
# Create both modules with identical initialization
reset_rng_states()
# Create native module
print("\nCreate native module")
if module_class == te.Linear:
native_module = te.Linear(
in_features=in_features,
out_features=out_features,
bias=bias,
device=device,
params_dtype=x_dtype,
)
elif module_class == te.LayerNormLinear:
native_module = te.LayerNormLinear(
in_features=in_features,
out_features=out_features,
bias=bias,
device=device,
params_dtype=x_dtype,
)
else:
raise ValueError(f"Unsupported module class: {module_class}")
# Create reference module with same weights
reset_rng_states()
# Create reference module
print("Create reference module")
if module_class == te.Linear:
ref_module = te.Linear(
in_features=in_features,
out_features=out_features,
bias=bias,
device=device,
params_dtype=x_dtype,
)
elif module_class == te.LayerNormLinear:
ref_module = te.LayerNormLinear(
in_features=in_features,
out_features=out_features,
bias=bias,
device=device,
params_dtype=x_dtype,
)
# Sync weights between native and reference modules
with torch.no_grad():
# Copy main weight and bias parameters
if hasattr(native_module, "weight") and hasattr(ref_module, "weight"):
ref_module.weight.copy_(native_module.weight)
if bias and hasattr(native_module, "bias") and hasattr(ref_module, "bias"):
ref_module.bias.copy_(native_module.bias)
# Copy layer norm parameters if they exist
if hasattr(native_module, "layer_norm_weight") and hasattr(ref_module, "layer_norm_weight"):
ref_module.layer_norm_weight.copy_(native_module.layer_norm_weight)
if hasattr(native_module, "layer_norm_bias") and hasattr(ref_module, "layer_norm_bias"):
ref_module.layer_norm_bias.copy_(native_module.layer_norm_bias)
# Create recipes for native and reference implementations
nvfp4_recipe = GetRecipes.nvfp4_recipe_to_test(with_rht, with_2d_quantization)
nvfp4_ref_factory = get_nvfp4_quantizer_factory(with_rht, with_2d_quantization)
nvfp4_ref_recipe = recipe.CustomRecipe(qfactory=nvfp4_ref_factory)
# Training loop comparison
native_outputs = []
ref_outputs = []
for step in range(num_steps):
torch.manual_seed(1234 + step)
torch.cuda.manual_seed(1234 + step)
x_shape = (batch_size, seq_len, in_features)
x_val = torch.normal(mean=0.0, std=1.0, size=x_shape, dtype=x_dtype, device=device)
x_native = x_val.clone().detach().requires_grad_(True)
x_ref = x_native.clone().detach().requires_grad_(True)
grad_output_shape = (batch_size, seq_len, out_features)
grad_output_val = torch.normal(
mean=0.0, std=1.0, size=grad_output_shape, dtype=x_dtype, device=device
)
grad_output = grad_output_val.clone().detach()
# Native forward/backward
with te.autocast(enabled=True, recipe=nvfp4_recipe):
# enable weight cache by giving is_first_microbatch
y_native = native_module(x_native, is_first_microbatch=(step == 0))
y_native.backward(grad_output)
# Reference forward/backward
with te.autocast(enabled=True, recipe=nvfp4_ref_recipe):
y_ref = ref_module(x_ref)
y_ref.backward(grad_output)
# Store results
native_outputs.append(
{
"output": y_native.detach().clone(),
"input_grad": (
x_native.grad.detach().clone() if x_native.grad is not None else None
),
"weight_grad": (
native_module.weight.grad.detach().clone()
if native_module.weight.grad is not None
else None
),
"bias_grad": (
native_module.bias.grad.detach().clone()
if bias and native_module.bias.grad is not None
else None
),
}
)
ref_outputs.append(
{
"output": y_ref.detach().clone(),
"input_grad": (x_ref.grad.detach().clone() if x_ref.grad is not None else None),
"weight_grad": (
ref_module.weight.grad.detach().clone()
if ref_module.weight.grad is not None
else None
),
"bias_grad": (
ref_module.bias.grad.detach().clone()
if bias and ref_module.bias.grad is not None
else None
),
}
)
# Compare results across all steps
for step in range(num_steps):
native_out = native_outputs[step]
ref_out = ref_outputs[step]
# Compare outputs
torch.testing.assert_close(
native_out["output"],
ref_out["output"],
atol=1e-6,
rtol=1e-6,
msg=f"Output mismatch at step {step}",
)
# Compare input gradients
torch.testing.assert_close(
native_out["input_grad"],
ref_out["input_grad"],
atol=1e-6,
rtol=1e-6,
msg=(
f"Input gradient mismatch at step {step}. Native: {native_out['input_grad']}, Ref:"
f" {ref_out['input_grad']}"
),
)
# Compare weight gradients
torch.testing.assert_close(
native_out["weight_grad"],
ref_out["weight_grad"],
atol=1e-6,
rtol=1e-6,
msg=(
f"Weight gradient mismatch at step {step}. Native: {native_out['weight_grad']},"
f" Ref: {ref_out['weight_grad']}"
),
)
# Compare bias gradients
if bias and native_out["bias_grad"] is not None and ref_out["bias_grad"] is not None:
torch.testing.assert_close(
native_out["bias_grad"],
ref_out["bias_grad"],
atol=1e-6,
rtol=1e-6,
msg=f"Bias gradient mismatch at step {step}",
)
@pytest.mark.skipif(not recipe_available, reason=reason_for_no_recipe)
@pytest.mark.parametrize(
"in_features, out_features",
[
(128, 256),
(256, 128),
(512, 512),
(768, 3072),
(1024, 4096),
],
)
# @pytest.mark.parametrize("bias", [True, False], ids=["with_bias", "no_bias"])
@pytest.mark.parametrize("bias", [False], ids=["no_bias"])
@pytest.mark.parametrize("x_dtype", [torch.float32, torch.bfloat16], ids=str)
@pytest.mark.parametrize("num_steps", [1, 3], ids=["single_step", "multi_step"])
@pytest.mark.parametrize("with_rht", [True, False], ids=["with_rht", "no_rht"])
@pytest.mark.parametrize(
"with_2d_quantization", [True, False], ids=["with_2d_quantization", "no_2d_quantization"]
)
def test_nvfp4_linear_versus_reference(
in_features: int,
out_features: int,
bias: bool,
x_dtype: torch.dtype,
num_steps: int,
with_rht: bool,
with_2d_quantization: bool,
):
"""Test NVFP4 Linear module against reference implementation."""
if with_rht and x_dtype != torch.bfloat16:
pytest.skip("RHT is only supported for bfloat16 input")
check_nvfp4_module_versus_reference(
module_class=te.Linear,
in_features=in_features,
out_features=out_features,
bias=bias,
x_dtype=x_dtype,
num_steps=num_steps,
with_rht=with_rht,
with_2d_quantization=with_2d_quantization,
)
def check_nvfp4_layernorm_linear_versus_reference(
in_features: int,
out_features: int,
bias: bool,
normalization: str,
x_dtype: torch.dtype,
num_steps: int = 1,
with_rht: bool = False,
with_2d_quantization: bool = False,
):
"""
Compare native NVFP4 LayerNormLinear module against reference implementation,
including ln_out.
"""
device = "cuda"
batch_size = 32
seq_len = 128
# Create both modules with identical initialization
reset_rng_states()
# Native module
native_module = te.LayerNormLinear(
in_features=in_features,
out_features=out_features,
bias=bias,
device=device,
params_dtype=x_dtype,
normalization=normalization,
return_layernorm_output=True,
)
# Reference module
reset_rng_states()
ref_module = te.LayerNormLinear(
in_features=in_features,
out_features=out_features,
bias=bias,
device=device,
params_dtype=x_dtype,
normalization=normalization,
return_layernorm_output=True,
)
# Sync weights and LN params
with torch.no_grad():
if hasattr(native_module, "weight") and hasattr(ref_module, "weight"):
ref_module.weight.copy_(native_module.weight)
if bias and hasattr(native_module, "bias") and hasattr(ref_module, "bias"):
ref_module.bias.copy_(native_module.bias)
if hasattr(native_module, "layer_norm_weight") and hasattr(ref_module, "layer_norm_weight"):
if (
native_module.layer_norm_weight is not None
and ref_module.layer_norm_weight is not None
):
ref_module.layer_norm_weight.copy_(native_module.layer_norm_weight)
if hasattr(native_module, "layer_norm_bias") and hasattr(ref_module, "layer_norm_bias"):
if native_module.layer_norm_bias is not None and ref_module.layer_norm_bias is not None:
ref_module.layer_norm_bias.copy_(native_module.layer_norm_bias)
# Create recipes for native and reference implementations
nvfp4_recipe = GetRecipes.nvfp4_recipe_to_test(with_rht, with_2d_quantization)
nvfp4_ref_factory = get_nvfp4_quantizer_factory(with_rht, with_2d_quantization)
nvfp4_ref_recipe = recipe.CustomRecipe(qfactory=nvfp4_ref_factory)
native_outputs = []
ref_outputs = []
for step in range(num_steps):
torch.manual_seed(1234 + step)
torch.cuda.manual_seed(1234 + step)
x_shape = (batch_size, seq_len, in_features)
x_val = torch.normal(mean=0.0, std=1.0, size=x_shape, dtype=x_dtype, device=device)
x_native = x_val.clone().detach().requires_grad_(True)
x_ref = x_native.clone().detach().requires_grad_(True)
grad_output_shape = (batch_size, seq_len, out_features)
grad_output_val = torch.normal(
mean=0.0, std=1.0, size=grad_output_shape, dtype=x_dtype, device=device
)
grad_output = grad_output_val.clone().detach()
# Native forward/backward
with te.autocast(enabled=True, recipe=nvfp4_recipe):
y_native, ln_out_native = native_module(x_native, is_first_microbatch=(step == 0))
y_native.backward(grad_output)
# Reference forward/backward
with te.autocast(enabled=True, recipe=nvfp4_ref_recipe):
y_ref, ln_out_ref = ref_module(x_ref)
y_ref.backward(grad_output)
native_outputs.append(
{
"output": y_native.detach().clone(),
"ln_out": ln_out_native.detach().clone(),
"input_grad": (
x_native.grad.detach().clone() if x_native.grad is not None else None
),
"weight_grad": (
native_module.weight.grad.detach().clone()
if native_module.weight.grad is not None
else None
),
"bias_grad": (
native_module.bias.grad.detach().clone()
if bias and native_module.bias.grad is not None
else None
),
}
)
ref_outputs.append(
{
"output": y_ref.detach().clone(),
"ln_out": ln_out_ref.detach().clone(),
"input_grad": (x_ref.grad.detach().clone() if x_ref.grad is not None else None),
"weight_grad": (
ref_module.weight.grad.detach().clone()
if ref_module.weight.grad is not None
else None
),
"bias_grad": (
ref_module.bias.grad.detach().clone()
if bias and ref_module.bias.grad is not None
else None
),
}
)
# Compare results
for step in range(num_steps):
n = native_outputs[step]
r = ref_outputs[step]
torch.testing.assert_close(
n["output"],
r["output"],
atol=1e-6,
rtol=1e-6,
msg=f"Output mismatch at step {step}",
)
torch.testing.assert_close(
n["ln_out"],
r["ln_out"],
atol=1e-6,
rtol=1e-6,
msg=f"LN output mismatch at step {step}",
)
torch.testing.assert_close(
n["input_grad"],
r["input_grad"],
atol=1e-6,
rtol=1e-6,
msg=f"Input gradient mismatch at step {step}",
)
torch.testing.assert_close(
n["weight_grad"],
r["weight_grad"],
atol=1e-6,
rtol=1e-6,
msg=f"Weight gradient mismatch at step {step}",
)
if bias and n["bias_grad"] is not None and r["bias_grad"] is not None:
torch.testing.assert_close(
n["bias_grad"],
r["bias_grad"],
atol=1e-6,
rtol=1e-6,
msg=f"Bias gradient mismatch at step {step}",
)
@pytest.mark.skipif(not recipe_available, reason=reason_for_no_recipe)
@pytest.mark.parametrize(
"in_features, out_features",
[
(128, 256),
(256, 128),
],
)
@pytest.mark.parametrize("bias", [False], ids=["no_bias"])
@pytest.mark.parametrize("x_dtype", [torch.float32, torch.bfloat16], ids=str)
@pytest.mark.parametrize("num_steps", [1], ids=["single_step"])
@pytest.mark.parametrize("normalization", ["LayerNorm", "RMSNorm"], ids=["LayerNorm", "RMSNorm"])
@pytest.mark.parametrize("with_rht", [True, False], ids=["with_rht", "no_rht"])
@pytest.mark.parametrize(
"with_2d_quantization", [True, False], ids=["with_2d_quantization", "no_2d_quantization"]
)
def test_nvfp4_layernorm_linear_versus_reference(
in_features: int,
out_features: int,
bias: bool,
normalization: str,
x_dtype: torch.dtype,
num_steps: int,
with_rht: bool,
with_2d_quantization: bool,
):
if with_rht and x_dtype != torch.bfloat16:
pytest.skip("RHT is only supported for bfloat16 input")
check_nvfp4_layernorm_linear_versus_reference(
in_features=in_features,
out_features=out_features,
bias=bias,
normalization=normalization,
x_dtype=x_dtype,
num_steps=num_steps,
with_rht=with_rht,
with_2d_quantization=with_2d_quantization,
)