Commit 0d874a4e authored by wenjh

Merge branch 'nv_main' of v2.12

parents a68e5f87 dfdd3820
# Copyright (c) 2022-2026, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# See LICENSE for license information.
...@@ -909,7 +909,7 @@ def test_illegal_2D_by_2D_enforced(
is_w_1d_scaled,
) -> None:
# 2D block quantization by 2D block quantization is not supported.
expected_err_msg = "Only 1D by 1D, 1D by 2D, and 2D by 1D block scaling GEMM is supported"
cublas_gemm_test_constraint_enforced(
x_dtype,
w_dtype,
...
# Copyright (c) 2022-2026, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# See LICENSE for license information.
...@@ -89,126 +89,6 @@ def initialize_for_many_scales(
return result
@pytest.mark.skipif(not recipe_available, reason=reason_for_no_recipe)
@pytest.mark.parametrize(
"M, N",
[
# full tile cases
(128, 128),
(256, 256),
(256, 1024),
(1024, 256),
# Padding required cases
(256, 272),
(303, 300),
(305, 256),
# Some larger tiles.
(2000, 2000),
(2048, 2000),
(2000, 1024),
(2048, 1024),
],
)
@pytest.mark.parametrize("x_dtype", [torch.float32, torch.bfloat16], ids=str)
@pytest.mark.parametrize("quant_dtype", [torch.float8_e4m3fn, torch.float8_e5m2], ids=str)
@pytest.mark.parametrize("eps", [0], ids=["eps_0"])
@pytest.mark.parametrize("pow_2_scales", [True], ids=["pow2scales"])
def test_quantization_1D_block_tiling_with_compact_data_and_scales(
x_dtype: torch.dtype,
M: int,
N: int,
quant_dtype: torch.dtype,
eps: float,
pow_2_scales: bool,
) -> None:
te_dtype = TE_DType[quant_dtype]
tile_size = (1, blockwise_fp8_block_len)
# This test runs a comparison of the ref class versus the class using
# CUDA kernels to quantize. They should quantize identically for pixels
# that are not DC values in the scale factor shape.
ref_quantizer = BlockwiseQuantizerReference()
sut_quantizer = Float8BlockQuantizer(
fp8_dtype=te_dtype,
rowwise=True,
columnwise=True,
amax_epsilon=eps,
force_pow_2_scales=pow_2_scales,
block_scaling_dim=1,
all_gather_usage=True,
)
# Setup device and random seed
device = "cuda"
seed = 0
torch.manual_seed(seed)
torch.cuda.manual_seed(seed)
# Input
x = initialize_for_many_scales((M, N), tile_size, dtype=x_dtype, device=device)
x_fp8_sut = sut_quantizer.make_empty((M, N), dtype=x_dtype, device=device, requires_grad=False)
x_fp8_sut = sut_quantizer.update_quantized(x, x_fp8_sut)
x_fp8_sut_cpp_alloc = sut_quantizer(x)
assert x_fp8_sut._rowwise_data is not None
qx: torch.Tensor = x_fp8_sut._rowwise_data.view(dtype=quant_dtype)
assert x_fp8_sut._rowwise_scale_inv is not None
sx: torch.Tensor = x_fp8_sut._rowwise_scale_inv
qx_t = x_fp8_sut._columnwise_data
sx_t = x_fp8_sut._columnwise_scale_inv
qresult_ref = ref_quantizer.quantize(
x,
quant_dtype=quant_dtype,
return_transpose=True,
eps=eps,
pow_2_scales=pow_2_scales,
quant_tile_shape=tile_size,
munge_scale_shapes=False,
)
qx_ref, sx_ref, qx_t_ref, sx_t_ref = (
qresult_ref.data,
qresult_ref.scale,
qresult_ref.data_t,
qresult_ref.scale_t,
)
# match the reference quantize transpose output with the columnwise non-transpose method
qx_t_ref = qx_t_ref.transpose(-1, -2).contiguous()
sx_t_ref = sx_t_ref.transpose(-1, -2).contiguous()
# Check
torch.testing.assert_close(qx.float(), qx_ref.float(), atol=0.0, rtol=0.0)
torch.testing.assert_close(sx, sx_ref, atol=0.0, rtol=0.0)
assert qx_t is not None
qx_t = qx_t.view(dtype=quant_dtype)
assert qx_t_ref is not None
assert sx_t is not None
assert sx_t_ref is not None
torch.testing.assert_close(qx_t.float(), qx_t_ref.float(), atol=0.0, rtol=0.0)
torch.testing.assert_close(sx_t, sx_t_ref, atol=0.0, rtol=0.0)
# check that the C++ and Python allocators are equivalent
torch.testing.assert_close(
x_fp8_sut._rowwise_data, x_fp8_sut_cpp_alloc._rowwise_data, atol=0.0, rtol=0.0
)
torch.testing.assert_close(
x_fp8_sut._rowwise_scale_inv, x_fp8_sut_cpp_alloc._rowwise_scale_inv, atol=0.0, rtol=0.0
)
torch.testing.assert_close(
x_fp8_sut._columnwise_data, x_fp8_sut_cpp_alloc._columnwise_data, atol=0.0, rtol=0.0
)
torch.testing.assert_close(
x_fp8_sut._columnwise_scale_inv,
x_fp8_sut_cpp_alloc._columnwise_scale_inv,
atol=0.0,
rtol=0.0,
)
# check if the fp8 output between C++ and Python are the same
assert x_fp8_sut._data_format == x_fp8_sut_cpp_alloc._data_format
def check_quantization_block_tiling_versus_reference(
x_dtype: torch.dtype,
M: int,
...
# Copyright (c) 2022-2026, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# See LICENSE for license information.
...@@ -8,9 +8,15 @@ import torch
import pytest
import transformer_engine.pytorch as te
import transformer_engine_torch as tex
from transformer_engine.common.recipe import Float8CurrentScaling
from transformer_engine.pytorch.quantization import autocast, get_fp8_torch_dtype
from transformer_engine.pytorch.constants import TE_DType
from transformer_engine.pytorch.custom_recipes.quantization import MMParams
from transformer_engine.pytorch.custom_recipes.quantization_current_scaling import (
CurrentScalingQuantizerRef,
)
from transformer_engine.pytorch.fp8 import int8_simulation_fp8
...@@ -750,6 +756,132 @@ class TestFP8CurrentScalingRecipeLinear(TestFP8RecipeLinearBase):
)
@pytest.mark.skipif(not fp8_available, reason=reason_for_no_fp8)
class TestFP8CurrentScalingNativeVsRef:
@staticmethod
def _make_quantizers(rowwise=True, columnwise=True):
# TE native FP8 current scaling quantizer
te_quant = te.Float8CurrentScalingQuantizer(
fp8_dtype=tex.DType.kFloat8E4M3,
device=torch.device("cuda"),
rowwise=rowwise,
columnwise=columnwise,
)
# Reference quantizer
ref_quant = CurrentScalingQuantizerRef(
dtype=torch.float8_e4m3fn,
rowwise=rowwise,
columnwise=columnwise,
pow_2_scales=False,
eps=0.0,
)
return te_quant, ref_quant
@pytest.mark.parametrize(
"M, N, dtype",
[
(128, 256, torch.bfloat16),
],
ids=["rowwise"],
)
def test_current_scaling_quantization_versus_reference(self, M, N, dtype):
device = "cuda"
seed = 123
torch.manual_seed(seed)
torch.cuda.manual_seed(seed)
x = torch.randn((M, N), dtype=dtype, device=device)
te_quant, ref_quant = self._make_quantizers(rowwise=True, columnwise=False)
# Native TE quantization
x_te = te_quant(x)
assert x_te._data is not None
qx_native = x_te._data.view(dtype=torch.float8_e4m3fn)
sx_native = x_te._scale_inv
# Reference quantization
x_ref = ref_quant.quantize(x)
qx_ref = x_ref.data
sx_ref = x_ref.scale
# Byte-for-byte equality on data and exact scale_inv match
torch.testing.assert_close(qx_native, qx_ref, atol=0.0, rtol=0.0)
torch.testing.assert_close(sx_native, sx_ref, atol=0.0, rtol=0.0)
@pytest.mark.parametrize(
"M, K, N, out_dtype, accumulate",
[
(128, 256, 96, torch.bfloat16, False),
(64, 128, 64, torch.float32, True),
],
ids=["bf16_no_acc", "fp32_acc"],
)
def test_current_scaling_gemm_versus_reference(self, M, K, N, out_dtype, accumulate):
device = "cuda"
seed = 42
torch.manual_seed(seed)
torch.cuda.manual_seed(seed)
x = torch.randn((M, K), dtype=torch.bfloat16, device=device)
w = torch.randn((N, K), dtype=torch.bfloat16, device=device)
out = torch.randn((M, N), dtype=out_dtype, device=device) if accumulate else None
te_quant_x, ref_quant = self._make_quantizers(rowwise=True, columnwise=True)
te_quant_w, _ = self._make_quantizers(rowwise=True, columnwise=True)
# Native TE quantization (direct)
qx_native = te_quant_x(x)
qw_native = te_quant_w(w)
# Prepare inputs for reference qgemm
assert qx_native._data is not None and qw_native._data is not None
qx_data = qx_native._data.view(dtype=torch.float8_e4m3fn)
qw_data = qw_native._data.view(dtype=torch.float8_e4m3fn)
sx = qx_native._scale_inv
sw = qw_native._scale_inv
# Reference GEMM
m_params = MMParams(out_dtype=out_dtype, use_split_accumulator=False)
y_ref = ref_quant.qgemm(
qx=qx_data,
qw=qw_data,
m_params=m_params,
out_dtype=out_dtype,
sx=sx,
sw=sw,
bias=None,
out=out.clone() if accumulate else None,
accumulate=accumulate,
gemm_type=None,
qresult_x=None,
qresult_w=None,
)
# Native TE GEMM
# return type is out, bias_grad, gelu_input, extra_output
y_native = tex.generic_gemm(
qw_native, # A
True, # transa (treat (N,K) as (K,N))
qx_native, # B
False, # transb
out.clone() if accumulate else None,
None, # out quantizer
TE_DType[out_dtype],
None, # bias
TE_DType[torch.bfloat16],
False, # use_gelu
None, # gelu_input
False, # use_grad
torch.empty(0, dtype=torch.uint8, device=device),
0,
accumulate,
False, # use_split_accumulator
)[0]
torch.testing.assert_close(y_native, y_ref, atol=0.0, rtol=0.0)
@pytest.mark.skipif(not fp8_available, reason=reason_for_no_fp8)
class TestFP8CurrentScalingRecipeLayerNormLinear(TestFP8RecipeLayerNormLinearBase):
...
# Copyright (c) 2022-2026, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# See LICENSE for license information.
...@@ -179,16 +179,12 @@ class TestFloat8BlockwiseTensor:
)
@pytest.mark.parametrize("block_scaling_dim", [1, 2])
@pytest.mark.parametrize("dq_columnwise", [True, False])
@pytest.mark.parametrize("all_gather_usage", [True, False])
def test_quantize_dequantize_dims(
self,
dims: DimsType,
block_scaling_dim: int,
dq_columnwise: bool,
all_gather_usage: bool,
) -> None:
if all_gather_usage and block_scaling_dim != 1:
pytest.skip("all_gather_usage only implemented for 1D block quantization.")
atol = _tols[tex.DType.kFloat8E4M3]["atol"]
rtol = _tols[tex.DType.kFloat8E4M3]["rtol"]
quantizer = Float8BlockQuantizer(
...@@ -196,7 +192,6 @@ class TestFloat8BlockwiseTensor:
rowwise=True,
columnwise=dq_columnwise,
block_scaling_dim=block_scaling_dim,
all_gather_usage=all_gather_usage,
)
self._test_quantize_dequantize(
quantizer=quantizer,
...@@ -222,7 +217,6 @@ class TestFloat8BlockwiseTensor:
rowwise=True,
columnwise=dq_columnwise,
block_scaling_dim=block_scaling_dim,
all_gather_usage=(block_scaling_dim == 1),
)
self._test_quantize_dequantize(
quantizer=quantizer,
...@@ -287,13 +281,8 @@ class TestFloat8BlockwiseTensor:
@pytest.mark.parametrize("dims", [[256, 512], [250, 500]])
@pytest.mark.parametrize("block_scaling_dim", [1, 2])
@pytest.mark.parametrize("all_gather_usage", [True, False])
def test_serialization(
self, dims: DimsType, block_scaling_dim: int, all_gather_usage: bool
) -> None:
def test_serialization(self, dims: DimsType, block_scaling_dim: int) -> None:
"""Test serialization of Float8BlockwiseQTensor"""
if all_gather_usage and block_scaling_dim != 1:
pytest.skip("all_gather_usage only implemented for 1D block quantization.")
device = "cuda"
dtype = torch.bfloat16
x_hp = torch.rand(_to_list(dims), dtype=dtype, device=device)
...@@ -302,7 +291,6 @@ class TestFloat8BlockwiseTensor:
rowwise=True,
columnwise=True,
block_scaling_dim=block_scaling_dim,
all_gather_usage=all_gather_usage,
)
# Create FP8 tensor
...@@ -326,7 +314,6 @@ class TestFloat8BlockwiseTensor:
assert x_fp8_loaded._is_2D_scaled == x_fp8._is_2D_scaled
assert x_fp8_loaded.dtype == x_fp8.dtype
assert x_fp8_loaded._fp8_dtype == x_fp8._fp8_dtype
assert x_fp8_loaded._data_format == x_fp8._data_format
# Test that dequantized values match
x_fp8_dequant = x_fp8.dequantize()
...
# Copyright (c) 2022-2026, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# See LICENSE for license information.
...
# Copyright (c) 2022-2026, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# See LICENSE for license information.
from typing import Callable, Tuple, Union, List
...
# Copyright (c) 2022-2026, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# See LICENSE for license information.
import torch
...
# Copyright (c) 2022-2026, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# See LICENSE for license information.
...@@ -913,15 +913,15 @@ class TestBasicOps:
dtype=dtype,
accumulate_into_main_grad=accumulate_into_main_grad,
)
forward = te_ops.Sequential(
te_ops.Quantize(forward=quantized_input, backward=quantized_grad_input),
op,
te_ops.Quantize(forward=quantized_output, backward=quantized_grad_output),
)
with torch.no_grad():
op.weight.copy_(w_test)
del w_test
op.weight.main_grad = torch.full_like(op.weight, 0.5, dtype=torch.float32)
forward = te_ops.Sequential(
te_ops.Quantize(forward=quantized_input, backward=quantized_grad_input),
op,
te_ops.Quantize(forward=quantized_output, backward=quantized_grad_output),
)
with te.autocast(enabled=quantized_compute, recipe=recipe):
y_test = forward(x_test)
y_test.backward(dy_test)
...@@ -2751,7 +2751,11 @@ class TestCheckpointing:
# Check that original and loaded model match exactly
tols = {"rtol": 0, "atol": 0}
for param_load, param_save in zip(model_load.parameters(), model_save.parameters()):
torch.testing.assert_close(  # Force dequantization by casting to FP64
param_load.to(dtype=torch.float64, device="cpu"),
param_save.to(dtype=torch.float64, device="cpu"),
**tols,
)
torch.testing.assert_close(param_load.grad, param_save.grad, **tols)
for y_load, y_save in zip(ys_load, ys_save):
torch.testing.assert_close(y_load, y_save, **tols)
...@@ -2768,7 +2772,6 @@ class TestSequentialModules:
@pytest.mark.parametrize("requires_grad", (False, True))
@pytest.mark.parametrize("bias", (False, True))
@pytest.mark.parametrize("normalization", ("LayerNorm", "RMSNorm"))
@pytest.mark.parametrize("quantized_compute", (False, True)) @pytest.mark.parametrize("quantized_compute", (False, True))
@pytest.mark.parametrize("quantized_weight", (False, True)) @pytest.mark.parametrize("quantized_weight", (False, True))
@pytest.mark.parametrize("dtype", _dtypes) @pytest.mark.parametrize("dtype", _dtypes)
...@@ -2778,25 +2781,18 @@ class TestSequentialModules: ...@@ -2778,25 +2781,18 @@ class TestSequentialModules:
*, *,
requires_grad: bool, requires_grad: bool,
bias: bool, bias: bool,
normalization: str,
quantized_compute: bool,
quantized_weight: bool,
dtype: torch.dtype,
quantization: Optional[str],
device: torch.device = "cuda",
hidden_size: int = 256,
sequence_length: int = 48,
batch_size: int = 4,
ffn_hidden_size: int = 384,
layernorm_epsilon: float = 1e-5,
) -> None:
"""LayerNorm/RMSNorm + Linear + SwiGLU + Linear"""
LayerNorm/RMSNorm + Linear + GELU + Linear
Note that this test checks only if the module runs
as when chaining multiple modules it is hard to validate
numerical accuracy.
"""
# Make input shape
in_shape = (sequence_length, batch_size, hidden_size)
...@@ -2812,38 +2808,90 @@ class TestSequentialModules:
pytest.skip("Quantization scheme is not used")
# Random data
x_ref, x_test = make_reference_and_test_tensors(
in_shape,
quantization=quantization,
test_dtype=dtype,
test_device=device,
requires_grad=requires_grad,
)
norm_w_ref, norm_w_test = make_reference_and_test_tensors(
hidden_size,
test_dtype=dtype,
test_device=device,
)
norm_b_ref, norm_b_test = make_reference_and_test_tensors(
hidden_size,
test_dtype=dtype,
test_device=device,
)
w1_ref, w1_test = make_reference_and_test_tensors(
(ffn_hidden_size, hidden_size),
quantization=quantization,
test_dtype=dtype,
test_device=device,
)
w2_ref, w2_test = make_reference_and_test_tensors(
(hidden_size, ffn_hidden_size // 2),
quantization=quantization,
test_dtype=dtype,
test_device=device,
)
b1_ref, b1_test, b2_ref, b2_test = None, None, None, None
if bias:
b1_ref, b1_test = make_reference_and_test_tensors(
ffn_hidden_size,
test_dtype=dtype,
test_device=device,
)
b2_ref, b2_test = make_reference_and_test_tensors(
hidden_size,
test_dtype=dtype,
test_device=device,
)
dy_ref, dy_test = make_reference_and_test_tensors(
in_shape,
quantization=quantization,
test_dtype=dtype,
test_device=device,
requires_grad=False,
)
with torch.no_grad():
for t in (norm_w_ref, norm_w_test, norm_b_ref, norm_b_test):
t -= 0.5
for t in (w1_ref, w1_test, w2_ref, w2_test):
t *= 1 / 64
if bias:
for t in (b1_ref, b1_test, b2_ref, b2_test):
t -= 0.5
for t in (dy_ref, dy_test):
t -= 0.5
# Reference implementation
x = x_ref
x = torch.nn.functional.layer_norm(
x,
(hidden_size,),
weight=norm_w_ref,
bias=norm_b_ref,
eps=layernorm_epsilon,
)
x = torch.nn.functional.linear(x, w1_ref, bias=b1_ref)
x1, x2 = x.chunk(2, dim=-1)
x = torch.nn.functional.silu(x1) * x2
x = torch.nn.functional.linear(x, w2_ref, bias=b2_ref)
y_ref = x
y_ref.backward(dy_ref)
# Construct operations
recipe = make_recipe(quantization)
with te.quantized_model_init(enabled=quantized_weight, recipe=recipe):
norm = te_ops.LayerNorm(
hidden_size,
eps=layernorm_epsilon,
device=device,
dtype=dtype,
)
)
else:
norm = te_ops.RMSNorm(
hidden_size,
eps=layernorm_epsilon,
device=device,
dtype=dtype,
)
ffn1 = te_ops.Linear(
hidden_size,
ffn_hidden_size,
...@@ -2851,15 +2899,48 @@ class TestSequentialModules:
device=device,
dtype=dtype,
)
act = te_ops.SwiGLU()
ffn2 = te_ops.Linear(
ffn_hidden_size // 2,
hidden_size,
bias=bias,
device=device,
dtype=dtype,
)
# Copy weights
with torch.no_grad():
norm.weight.copy_(norm_w_test)
norm.bias.copy_(norm_b_test)
ffn1.weight.copy_(w1_test)
ffn2.weight.copy_(w2_test)
if bias:
ffn1.bias.copy_(b1_test)
ffn2.bias.copy_(b2_test)
del norm_w_test, norm_b_test, w1_test, b1_test, w2_test, b2_test
# Fuse ops and perform forward and backward pass
forward = te_ops.Sequential(norm, ffn1, act, ffn2)
with te.autocast(enabled=quantized_compute, recipe=recipe):
y_test = forward(x_test)
y_test.backward(dy_test)
def to_cpu(tensor: Optional[torch.Tensor]) -> Optional[torch.Tensor]:
"""Convert to FP64 CPU tensor"""
if tensor is None:
return None
out = tensor.detach().to(dtype=torch.float64, device="cpu")
out = out.requires_grad_(requires_grad=tensor.requires_grad)
return out
# Check values
tols = {"rtol": 0.25, "atol": 0.5} # Loose tols for sanity checking
torch.testing.assert_close(to_cpu(y_test), y_ref, **tols)
torch.testing.assert_close(to_cpu(x_test.grad), x_ref.grad, **tols)
torch.testing.assert_close(to_cpu(norm.weight.grad), norm_w_ref.grad, **tols)
torch.testing.assert_close(to_cpu(norm.bias.grad), norm_b_ref.grad, **tols)
torch.testing.assert_close(to_cpu(ffn2.weight.grad), w2_ref.grad, **tols)
torch.testing.assert_close(to_cpu(ffn1.weight.grad), w1_ref.grad, **tols)
if bias:
torch.testing.assert_close(to_cpu(ffn1.bias.grad), b1_ref.grad, **tols)
torch.testing.assert_close(to_cpu(ffn2.bias.grad), b2_ref.grad, **tols)
# Copyright (c) 2022-2026, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# See LICENSE for license information.
...
# Copyright (c) 2022-2026, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# See LICENSE for license information.
...
# Copyright (c) 2022-2026, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# See LICENSE for license information.
...
# Copyright (c) 2022-2026, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# See LICENSE for license information.
...@@ -7,6 +7,7 @@ import torch
import transformer_engine.pytorch
import transformer_engine_torch as tex
from transformer_engine.pytorch import is_mxfp8_available
from transformer_engine.pytorch.optimizers import MultiTensorApply
from references.quantize_scale_calc import scale_from_amax_tensor
...@@ -23,6 +24,7 @@ input_size_pairs = [
(555, 33333),
]
appliers = [MultiTensorApply(2048 * 32), MultiTensorApply(333), MultiTensorApply(33333)]
mxfp8_available, reason_for_no_mxfp8 = is_mxfp8_available(return_reason=True)
@pytest.mark.parametrize("input_size_pair", input_size_pairs) @pytest.mark.parametrize("input_size_pair", input_size_pairs)
...@@ -260,3 +262,33 @@ def test_multi_tensor_compute_scale_and_scale_inv( ...@@ -260,3 +262,33 @@ def test_multi_tensor_compute_scale_and_scale_inv(
torch.testing.assert_close(scale, scale_ref, rtol=0, atol=0) torch.testing.assert_close(scale, scale_ref, rtol=0, atol=0)
torch.testing.assert_close(scale_inv, scale_inv_ref, rtol=0, atol=0) torch.testing.assert_close(scale_inv, scale_inv_ref, rtol=0, atol=0)
@pytest.mark.skipif(not mxfp8_available, reason=reason_for_no_mxfp8)
@pytest.mark.parametrize("input_size_pair", input_size_pairs + [(1, 1)])
@pytest.mark.parametrize("applier", appliers)
@pytest.mark.parametrize("repeat", [1, 55])
def test_multi_tensor_compute_scale_inv_e8m0(input_size_pair, applier, repeat):
sizea, sizeb = input_size_pair
device = torch.device("cuda")
a = torch.randn([sizea], dtype=torch.bfloat16, device=device).abs()
b = torch.randn([sizeb], dtype=torch.bfloat16, device=device).abs()
amax_list = []
for _ in range(repeat):
amax_list += [a.clone(), b.clone()]
scale_inv_list = [torch.empty_like(x).to(torch.uint8) for x in amax_list]
applier(
tex.multi_tensor_compute_scale_inv_e8m0,
None, # overflow_buf
[amax_list, scale_inv_list],
)
max_fp8 = torch.finfo(torch.float8_e4m3fn).max
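# Reference check (a sketch of the expected kernel behavior): the E8M0 scale_inv is the
# biased FP32 exponent of amax / max_fp8, rounded up to the next power of two whenever a
# mantissa remains, except at the top of the exponent range (0xFE) and for near-zero
# subnormal inputs.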
for amax, scale_inv in zip(amax_list, scale_inv_list):
scale_inv_u32 = (amax.float() / max_fp8).view(torch.int)
exponent = scale_inv_u32 // 2**23
mantissa = scale_inv_u32 & 0x7FFFFF
exponent += (
((mantissa > 0) & (exponent != 0xFE)) & ~((exponent == 0) & (mantissa <= 0x400000))
).to(torch.int)
torch.testing.assert_close(exponent.to(torch.uint8), scale_inv)
# Copyright (c) 2022-2026, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# See LICENSE for license information.
...@@ -13,7 +13,10 @@ import torch.nn as nn
from torch.nn import Parameter
from torch.utils.cpp_extension import IS_HIP_EXTENSION
from transformer_engine.pytorch.quantization import (
FP8GlobalStateManager,
get_align_size_for_quantization,
)
from transformer_engine.pytorch.utils import (
init_method_normal,
scaled_init_method_normal,
...@@ -46,7 +49,6 @@ from transformer_engine.pytorch import (
from transformer_engine.pytorch import torch_version
from transformer_engine.pytorch import checkpoint as te_checkpoint
from transformer_engine.pytorch.cpp_extensions import general_gemm, general_grouped_gemm
from transformer_engine.pytorch.module.base import get_multi_stream_cublas_workspace, get_workspace
from transformer_engine.common import recipe
import transformer_engine_torch as tex
from utils import ModelConfig, reset_rng_states
...@@ -191,7 +193,7 @@ def dtype_tols(dtype: torch.dtype) -> Dict[str, float]:
return dict(rtol=1e-3, atol=1e-5)
if dtype == torch.bfloat16:
return dict(rtol=1.6e-2, atol=1e-5)
raise ValueError(f"Unsupported dtype ({dtype})")
def assert_allclose(
...@@ -1279,6 +1281,9 @@ def test_linear_accuracy(dtype, bs, model, return_bias, bias):
@pytest.mark.parametrize("bias", all_boolean)
@pytest.mark.parametrize("fuse_wgrad_accumulation", all_boolean)
def test_linear_accuracy_delay_wgrad_compute(dtype, bs, model, bias, fuse_wgrad_accumulation):
if NVTE_TEST_NVINSPECT_ENABLED:
pytest.skip("Delayed wgrad compute is not supported in debug mode.")
config = model_configs[model]
te_linear_ref = Linear(
...@@ -1376,7 +1381,7 @@ def test_linear_accuracy_save_original_input(dtype, model, recipe):
te_outputs = _test_granular_accuracy(te_linear, bs, dtype, config, recipe=recipe)
te_outputs_ref = _test_granular_accuracy(te_linear_ref, bs, dtype, config, recipe=recipe)
# Should be bit-wise match
for i, (o, o_ref) in enumerate(zip(te_outputs, te_outputs_ref)):
torch.testing.assert_close(o, o_ref, rtol=0, atol=0)
...@@ -1576,6 +1581,9 @@ def test_layernorm_linear_accuracy(
def test_layernorm_linear_accuracy_delay_wgrad_compute(
dtype, bs, model, normalization, zero_centered_gamma, bias, fuse_wgrad_accumulation
):
if NVTE_TEST_NVINSPECT_ENABLED:
pytest.skip("Delayed wgrad compute is not supported in debug mode.")
config = model_configs[model]
ln_linear_ref = LayerNormLinear(
...@@ -1709,8 +1717,15 @@ def test_layernorm_mlp_accuracy(dtype, bs, model, activation, normalization, ret
@pytest.mark.parametrize("bias", all_boolean)
@pytest.mark.parametrize("fuse_wgrad_accumulation", all_boolean)
def test_layernorm_mlp_accuracy_delay_wgrad_compute(
dtype,
bs,
model,
bias,
fuse_wgrad_accumulation,
):
if NVTE_TEST_NVINSPECT_ENABLED:
pytest.skip("Delayed wgrad compute is not supported in debug mode.")
config = model_configs[model]
ln_mlp = LayerNormMLP(
...@@ -1760,6 +1775,58 @@ def test_layernorm_mlp_accuracy_delay_wgrad_compute(
torch.testing.assert_close(o, o_ref, rtol=0, atol=0)
@pytest.mark.parametrize("dtype", param_types)
@pytest.mark.parametrize("bs", [2])
@pytest.mark.parametrize("model", ["small"])
@pytest.mark.parametrize("bias", all_boolean)
def test_layernorm_mlp_accuracy_checkpoint(
dtype,
bs,
model,
bias,
):
config = model_configs[model]
ln_mlp = LayerNormMLP(
hidden_size=config.hidden_size,
ffn_hidden_size=4 * config.hidden_size,
eps=config.eps,
bias=bias,
params_dtype=dtype,
device="cuda",
checkpoint=True,
).eval()
ln_mlp_ref = LayerNormMLP(
hidden_size=config.hidden_size,
ffn_hidden_size=4 * config.hidden_size,
eps=config.eps,
bias=bias,
params_dtype=dtype,
device="cuda",
checkpoint=False,
).eval()
# Share params
with torch.no_grad():
ln_mlp_ref.layer_norm_weight = Parameter(ln_mlp.layer_norm_weight.clone())
ln_mlp_ref.layer_norm_bias = Parameter(ln_mlp.layer_norm_bias.clone())
ln_mlp_ref.fc1_weight = Parameter(ln_mlp.fc1_weight.clone())
ln_mlp_ref.fc2_weight = Parameter(ln_mlp.fc2_weight.clone())
if bias:
ln_mlp_ref.fc1_bias = Parameter(ln_mlp.fc1_bias.clone())
ln_mlp_ref.fc2_bias = Parameter(ln_mlp.fc2_bias.clone())
te_outputs = _test_granular_accuracy(ln_mlp, bs, dtype, config, delay_wgrad_compute=False)
te_outputs_ref = _test_granular_accuracy(
ln_mlp_ref, bs, dtype, config, delay_wgrad_compute=False
)
# Should be bit-wise match
for i, (o, o_ref) in enumerate(zip(te_outputs, te_outputs_ref)):
torch.testing.assert_close(o, o_ref, rtol=0, atol=0)
def _test_grouped_linear_accuracy(
block,
num_gemms,
...@@ -1786,9 +1853,7 @@ def _test_grouped_linear_accuracy(
if num_gemms > 1:
split_size = 1
if fp8:
split_size = get_align_size_for_quantization(recipe)
if recipe.mxfp8() or recipe.nvfp4():
split_size = 32
m = config.max_seqlen_q // split_size
dist = torch.sort(torch.randint(0, m, (num_gemms - 2,))).values.tolist()
dist.append(dist[-1]) # Manually add a zero
...@@ -1857,6 +1922,8 @@ def test_grouped_linear_accuracy(
pytest.skip(reason_for_no_fp8)
if fp8 and fp8_model_params and NVTE_TEST_NVINSPECT_ENABLED:
pytest.skip("FP8 parameters are not supported in debug mode.")
if NVTE_TEST_NVINSPECT_ENABLED and delay_wgrad_compute:
pytest.skip("Delayed wgrad compute is not supported in debug mode.")
config = model_configs[model]
if config.max_seqlen_q % 16 != 0 and fp8:
...@@ -2001,6 +2068,8 @@ def test_grouped_linear_accuracy_save_original_input(
pytest.skip("FP8 parameters are not supported in debug mode.")
if fp8 and recipe.delayed():
pytest.skip("DelayedScaling recipe is not supported with save_original_input")
if NVTE_TEST_NVINSPECT_ENABLED and delay_wgrad_compute:
pytest.skip("Delayed wgrad compute is not supported in debug mode.")
if fp8 and not fp8_available:
pytest.skip(reason_for_no_fp8)
...@@ -2106,9 +2175,7 @@ def test_grouped_linear_accuracy_single_gemm(recipe):
def _test_padding_grouped_linear_accuracy(block, num_gemms, bs, dtype, config, recipe, fp8=False):
def _pad_tensor_for_fp8(hidden_states, tokens_per_expert):
align_size = get_align_size_for_quantization(recipe)
if recipe.mxfp8() or recipe.nvfp4():
align_size = 32
padded_tokens_per_expert = [
(num_tokens + align_size - 1) // align_size * align_size
for num_tokens in tokens_per_expert
...@@ -2725,7 +2792,6 @@ def test_grouped_gemm(shape, dtype, layout, accumulate, use_cutlass):
general_gemm(
A[i],
B[i],
get_workspace(),
dtype,
grad=grad,
accumulate=accumulate,
...@@ -2739,8 +2805,8 @@ def test_grouped_gemm(shape, dtype, layout, accumulate, use_cutlass):
A,
B,
out,
[None] * z,
dtype,
get_multi_stream_cublas_workspace(),
m_splits=m_splits,
grad=grad,
accumulate=accumulate,
...@@ -2800,7 +2866,6 @@ def test_fp8gemm_with_unfused_quantization(N, datatype, input_quantizer, out_qua
quantized_out, *_ = general_gemm(
weight_fp8,
inp_fp8,
get_workspace(),
outp_type,
quantization_params=out_quantizer,
bias=None,
...@@ -2810,7 +2875,6 @@ def test_fp8gemm_with_unfused_quantization(N, datatype, input_quantizer, out_qua
out, *_ = general_gemm(
weight_fp8,
inp_fp8,
get_workspace(),
outp_type,
quantization_params=None,
bias=None,
...@@ -2886,7 +2950,6 @@ def test_fp8_grouped_gemm(shape, accumulate):
general_gemm(
A_fp8[i],
B_fp8[i],
get_workspace(),
dtype,
out=out_ref[i],
accumulate=accumulate,
...@@ -2895,8 +2958,8 @@ def test_fp8_grouped_gemm(shape, accumulate):
A_fp8,
B_fp8,
out,
[None] * z,
dtype,
get_multi_stream_cublas_workspace(),
m_splits=m_splits,
accumulate=accumulate,
)
...
# Copyright (c) 2022-2026, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# See LICENSE for license information.
...
# Copyright (c) 2022-2026, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# See LICENSE for license information.
...@@ -89,7 +89,7 @@ class TestParallelCrossEntropy:
# Check that loss and grad input match
tols = dtype_tols(dtype)
test_loss = test_loss.to(dtype=torch.float64, device="cpu")
ref_loss = ref_loss.to(dtype=torch.float64, device="cpu")
ref_loss = ref_loss.reshape(test_loss.size())
test_grad_input = self.input_test.grad.to(dtype=torch.float64, device="cpu")
ref_grad_input = self.input_ref.grad.to(dtype=torch.float64, device="cpu")
...@@ -154,3 +154,16 @@ class TestParallelCrossEntropy:
reduce_loss=False,
ignore_idx=True,
)
def test_ignore_idx_reduced_loss(self):
"""Test ignore_idx with reduce_loss=True"""
self.generate_iters(5)
self.generate_infra(True, 0) # reduce_loss=True
for i in range(self.iters):
self.one_iteration_test(
dtype=torch.float32,
swap_dim=random.choice([True, False]),
label_smoothing=0,
reduce_loss=True,
ignore_idx=True,
)
# Copyright (c) 2022-2026, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# See LICENSE for license information.
import pytest
import torch
import transformer_engine.pytorch as te
import transformer_engine_torch as tex
from transformer_engine_torch import multi_tensor_compute_scale_inv_e8m0
from transformer_engine.pytorch import is_mxfp8_available
from transformer_engine.pytorch.optimizers.multi_tensor_apply import multi_tensor_applier
mxfp8_available, reason_for_no_mxfp8 = is_mxfp8_available(return_reason=True)
def compute_partial_amax_reference(inp, amax_rowwise, amax_colwise, h, w, start_offset):
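# Reference: embed the (possibly partial) flat input into a zero-filled h x w buffer at
# start_offset, then reduce |x| over 32-element blocks along rows (rowwise amax) and over
# 32-row blocks along columns (colwise amax), as MXFP8 block scaling requires.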
n = inp.view(-1).size(0)
if n == h * w:
full = inp.view(-1)
else:
full = torch.zeros(h * w, dtype=inp.dtype, device=inp.device)
full[start_offset : start_offset + n].copy_(inp)
full = torch.abs(full)
_amax_rowwise, _ = torch.max(full.view(h, w // 32, 32), dim=2)
amax_rowwise[:h, : (w // 32)].copy_(_amax_rowwise)
_amax_colwise, _ = torch.max(full.view(h // 32, 32, w), dim=1)
amax_colwise[: (h // 32), :w].copy_(_amax_colwise)
def partial_cast_reference(
inp, rowwise_out, colwise_out, rowwise_inv_scale, colwise_inv_scale, h, w, start_offset
):
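# Decode the stored E8M0 inverse scales (biased exponents e) into FP32 multipliers:
# placing (254 - e) in the exponent bits yields 2**(127 - e), the reciprocal of 2**(e - 127).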
rowwise_scale = ((254 - rowwise_inv_scale.int()) * 2**23).view(torch.float32)
colwise_scale = ((254 - colwise_inv_scale.int()) * 2**23).view(torch.float32)
n = inp.view(-1).size(0)
if n == h * w:
full = inp
else:
full = torch.empty(h * w, dtype=inp.dtype, device=inp.device)
full[start_offset : start_offset + n].copy_(inp)
full = full.float()
rowwise_scale = rowwise_scale[:h, : (w // 32)].contiguous().float()
colwise_scale = colwise_scale[: (h // 32), :w].contiguous().float()
scaled = (full.view(-1, 32) * rowwise_scale.view(-1, 1)).view(-1)
rowwise_out.copy_(
scaled[start_offset : start_offset + n].to(torch.float8_e4m3fn).view(rowwise_out.dtype)
)
scaled = (full.view(h // 32, 32, w) * colwise_scale.view(h // 32, 1, w)).view(-1)
colwise_out.copy_(
scaled[start_offset : start_offset + n].to(torch.float8_e4m3fn).view(colwise_out.dtype)
)
def run_one_case(n, h, w, start_offset):
inp = torch.randn(n, dtype=torch.bfloat16, device="cuda")
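# Amax/scale buffers are padded to multiples of (128, 4) rowwise and (4, 128) colwise,
# presumably to match the scale-factor layout the MXFP8 kernels expect.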
rowwise_padding = [128, 4]
colwise_padding = [4, 128]
def _pad(x, padding):
return (x + padding - 1) // padding * padding
rowwise_shape = [_pad(h, rowwise_padding[0]), _pad(w // 32, rowwise_padding[1])]
colwise_shape = [_pad(h // 32, colwise_padding[0]), _pad(w, colwise_padding[1])]
# Partial amax cuda kernel
amax_rowwise = torch.zeros(*rowwise_shape, dtype=inp.dtype, device=inp.device)
amax_colwise = torch.zeros(*colwise_shape, dtype=inp.dtype, device=inp.device)
tex.mxfp8_scaling_compute_partial_amax(inp, amax_rowwise, amax_colwise, h, w, start_offset)
# Partial amax pytorch reference
amax_rowwise_ref = torch.zeros(*rowwise_shape, dtype=inp.dtype, device=inp.device)
amax_colwise_ref = torch.zeros(*colwise_shape, dtype=inp.dtype, device=inp.device)
compute_partial_amax_reference(inp, amax_rowwise_ref, amax_colwise_ref, h, w, start_offset)
# Check partial amax
torch.testing.assert_close(amax_rowwise, amax_rowwise_ref, atol=0, rtol=0)
torch.testing.assert_close(amax_colwise, amax_colwise_ref, atol=0, rtol=0)
# Calculate scales and scale_invs
scale_inv_rowwise = torch.empty_like(amax_rowwise).to(torch.uint8)
scale_inv_colwise = torch.empty_like(amax_colwise).to(torch.uint8)
multi_tensor_applier(
multi_tensor_compute_scale_inv_e8m0,
None,
[
[amax_rowwise, amax_colwise],
[scale_inv_rowwise, scale_inv_colwise],
],
)
# Partial cast cuda kernel
output_rowwise = torch.empty_like(inp).to(torch.uint8)
output_colwise = torch.empty_like(inp).to(torch.uint8)
tex.mxfp8_scaling_partial_cast(
inp,
output_rowwise,
output_colwise,
scale_inv_rowwise,
scale_inv_colwise,
h,
w,
start_offset,
)
# Partial cast pytorch reference
output_rowwise_ref = torch.empty_like(inp).to(torch.uint8)
output_colwise_ref = torch.empty_like(inp).to(torch.uint8)
partial_cast_reference(
inp,
output_rowwise_ref,
output_colwise_ref,
scale_inv_rowwise,
scale_inv_colwise,
h,
w,
start_offset,
)
# Check partial cast results
torch.testing.assert_close(output_rowwise, output_rowwise_ref, atol=0, rtol=0)
torch.testing.assert_close(output_colwise, output_colwise_ref, atol=0, rtol=0)
@pytest.mark.skipif(not mxfp8_available, reason=reason_for_no_mxfp8)
def test_mxfp8_scaling_partial_cast():
torch.cuda.manual_seed(1234)
run_one_case(3, 32, 64, 31)
run_one_case(64 * 64 - 2, 64, 64, 1)
run_one_case(16384 * 6144, 16384, 6144, 0)
run_one_case(32768, 256, 128, 0)
run_one_case(131072, 768, 256, 0)
run_one_case(65536, 768, 256, 131072)
run_one_case(98304, 128, 768, 0)
# Copyright (c) 2022-2026, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# See LICENSE for license information.
import os
import random
import torch
...@@ -13,6 +14,7 @@ from transformer_engine.common import recipe
from transformer_engine.pytorch import (
moe_permute as te_permute,
moe_permute_with_probs as te_permute_with_probs,
moe_permute_and_pad_with_probs as te_permute_and_pad_with_probs,
moe_unpermute as te_unpermute,
moe_sort_chunks_by_index as te_sort_chunks_by_index,
moe_sort_chunks_by_index_with_probs as te_sort_chunks_by_index_with_probs,
...@@ -24,6 +26,7 @@ from transformer_engine.pytorch import (
MXFP8Quantizer,
)
import transformer_engine_torch as tex
from transformer_engine.pytorch import Fp8Padding, Fp8Unpadding
import copy
seed = 1234
...@@ -653,6 +656,522 @@ def _test_permutation_mask_map(
print(f"unpermute\tbwd: pytorch: {t1:.3f} ms, TE: {t2:.3f} ms")
def _test_permutation_and_padding_mask_map(
te_dtype,
num_tokens,
num_expert,
hidden_size,
topK,
num_out_tokens,
with_merging_probs=False,
align_size=16,
BENCHMARK=False,
):
if topK > num_expert:
pytest.skip("topK should be smaller than the number of experts.")
if num_out_tokens is None:
num_out_tokens = num_tokens * topK
print(
"permutation and padding:"
f" token:{num_tokens} hidden_size:{hidden_size} expert:{num_expert} topK:{topK}"
f" with_merging_probs:{with_merging_probs} align_size:{align_size} {te_dtype}"
)
# Convert TE dtypes to PyTorch dtypes
if te_dtype == tex.DType.kFloat32:
dtype = torch.float32
elif te_dtype == tex.DType.kFloat16:
dtype = torch.float16
elif te_dtype == tex.DType.kBFloat16:
dtype = torch.bfloat16
else:
pytest.skip("Invalid dtype.")
_tmp_tensor = torch.zeros((num_tokens * num_expert,))
_tmp_tensor[: int(num_out_tokens)] = 1.0
_tmp_idx = torch.randperm(num_tokens * num_expert)
routing_map = torch.reshape(_tmp_tensor[_tmp_idx], (num_tokens, num_expert)).bool().cuda()
probs = torch.rand(num_tokens, num_expert).cuda() * routing_map
row_sums = probs.sum(dim=1, keepdim=True)
probs = probs / row_sums
probs = probs.to(dtype)
probs.requires_grad_(True)
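# Round each expert's token count up to a multiple of align_size; the padded total
# determines the shape of the gradient fed back through the permute+pad path.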
tokens_per_expert = routing_map.sum(dim=0).cpu()
target_tokens_per_expert = (torch.ceil(tokens_per_expert / align_size) * align_size).long()
num_permute_pad_out_tokens = target_tokens_per_expert.sum().item()
permute_pad_fwd_input = torch.rand((num_tokens, hidden_size), dtype=dtype).cuda()
permute_pad_bwd_input = torch.rand(
(num_permute_pad_out_tokens, hidden_size), dtype=dtype
).cuda()
unpermute_unpad_bwd_input = torch.rand((num_tokens, hidden_size), dtype=dtype).cuda()
permute_pad_fwd_input.requires_grad_(True)
restore_shape = permute_pad_fwd_input.shape
###################################################################################################################################
#
# moe_permute_with_probs and Fp8Padding, moe_unpermute and Fp8Unpadding
#
###################################################################################################################################
# permute + padding
permuted_output, permuted_probs, row_id_map = te_permute_with_probs(
permute_pad_fwd_input,
probs,
routing_map,
num_out_tokens=num_out_tokens,
)
tokens_per_expert_list = tokens_per_expert.tolist()
fp8_padding = Fp8Padding(num_expert, align_size)
permuted_paded_output, _ = fp8_padding(permuted_output, tokens_per_expert_list)
permuted_paded_probs, _ = fp8_padding(permuted_probs.unsqueeze(-1), tokens_per_expert_list)
permuted_paded_output.backward(permute_pad_bwd_input, retain_graph=True)
# unpadding + unpermute
unpermute_unpad_fwd_input = permuted_paded_output.detach()
unpermute_unpad_fwd_input.requires_grad_(True)
fp8_unpadding = Fp8Unpadding(num_expert, align_size)
unpaded_output = fp8_unpadding(unpermute_unpad_fwd_input, tokens_per_expert_list)
probs_naive = probs
unpermuted_unpaded_output = te_unpermute(
unpaded_output,
row_id_map,
merging_probs=probs_naive if with_merging_probs else None,
restore_shape=restore_shape,
)
unpermuted_unpaded_output.backward(unpermute_unpad_bwd_input, retain_graph=True)
###################################################################################################################################
#
# fused moe_permute_with_probs and Fp8Padding, fused moe_unpermute and Fp8Unpadding
#
###################################################################################################################################
# fusion permute_and_pad
fusion_permute_and_pad_fwd_input = permute_pad_fwd_input.detach()
fusion_permute_and_pad_fwd_input.requires_grad_(True)
probs_fusion = probs_naive.detach().clone()
probs_fusion.requires_grad_(True)
(
fusion_permuted_padded_output,
fusion_permuted_padded_probs,
row_id_map,
pad_offsets,
target_tokens_per_expert,
) = te_permute_and_pad_with_probs(
fusion_permute_and_pad_fwd_input,
probs_fusion,
routing_map,
tokens_per_expert,
align_size,
)
fusion_permuted_padded_probs = fusion_permuted_padded_probs.unsqueeze(-1)
fusion_permute_pad_bwd_input = permute_pad_bwd_input.detach()
fusion_permuted_padded_output.backward(fusion_permute_pad_bwd_input, retain_graph=True)
# fusion unpad and unpermute
fusion_unpermute_unpad_fwd_input = fusion_permuted_padded_output.detach()
fusion_unpermute_unpad_fwd_input.requires_grad_(True)
fusion_unpermuted_unpaded_output = te_unpermute(
fusion_unpermute_unpad_fwd_input,
row_id_map,
merging_probs=probs_fusion if with_merging_probs else None,
restore_shape=restore_shape,
pad_offsets=pad_offsets,
)
fusion_unpermute_bwd_input = unpermute_unpad_bwd_input.detach()
fusion_unpermuted_unpaded_output.backward(fusion_unpermute_bwd_input, retain_graph=True)
###################################################################################################################################
#
# Results Check
#
###################################################################################################################################
tols = dtype_tols(te_dtype)
permuted_paded_output_ = permuted_paded_output.float()
fusion_permuted_padded_output_ = fusion_permuted_padded_output.float()
permute_pad_fwd_input_grad = permute_pad_fwd_input.grad.float()
fusion_permute_and_pad_fwd_input_grad = fusion_permute_and_pad_fwd_input.grad.float()
unpermuted_unpaded_output_ = unpermuted_unpaded_output.float()
fusion_unpermuted_unpaded_output_ = fusion_unpermuted_unpaded_output.float()
unpermute_unpad_fwd_input_grad = unpermute_unpad_fwd_input.grad.float()
fusion_unpermute_unpad_fwd_input_grad = fusion_unpermute_unpad_fwd_input.grad.float()
if not BENCHMARK:
torch.testing.assert_close(
permuted_paded_output_,
fusion_permuted_padded_output_,
msg=f"Mismatch in te_permute_and_pad fwd",
**tols,
)
torch.testing.assert_close(
permute_pad_fwd_input_grad,
fusion_permute_and_pad_fwd_input_grad,
msg=f"Mismatch in te_permute_and_pad bwd",
**tols,
)
torch.testing.assert_close(
unpermuted_unpaded_output_,
fusion_unpermuted_unpaded_output_,
msg=f"Mismatch in te_unpermute fwd",
**tols,
)
torch.testing.assert_close(
unpermute_unpad_fwd_input_grad,
fusion_unpermute_unpad_fwd_input_grad,
msg=f"Mismatch in te_unpermute bwd",
**tols,
)
torch.testing.assert_close(
permuted_paded_probs.float(),
fusion_permuted_padded_probs.float(),
msg=f"Mismatch in te_permute_and_pad bwd",
**tols,
)
if with_merging_probs:
torch.testing.assert_close(
probs_naive.grad.float(),
probs_fusion.grad.float(),
msg=f"Mismatch in te_unpermute bwd",
**tols,
)
###################################################################################################################################
#
# Benchmark
#
###################################################################################################################################
if BENCHMARK:
def permute_and_pad():
permuted_output, permuted_probs, row_id_map = te_permute_with_probs(
permute_pad_fwd_input,
probs,
routing_map,
num_out_tokens=num_out_tokens,
)
fp8_padding(permuted_output, tokens_per_expert_list)
fp8_padding(permuted_probs.unsqueeze(-1), tokens_per_expert_list)
def fusion_permute_and_pad():
(
fusion_permuted_padded_output,
fusion_permuted_padded_probs,
row_id_map,
pad_offsets,
target_tokens_per_expert,
) = te_permute_and_pad_with_probs(
fusion_permute_and_pad_fwd_input,
probs,
routing_map,
tokens_per_expert,
align_size,
)
fusion_permuted_padded_probs = fusion_permuted_padded_probs.unsqueeze(-1)
t1 = perf_test_cuda_kernel(lambda: permute_and_pad())
t2 = perf_test_cuda_kernel(lambda: fusion_permute_and_pad())
print(f"permute_and_pad\t\tfwd: naive: {t1:.3f} ms, fusion: {t2:.3f} ms")
t1 = perf_test_cuda_kernel(
lambda: backward_wrapper(
permuted_paded_output,
permute_pad_bwd_input,
forward_input=[permute_pad_fwd_input],
retain_graph=True,
accumulate_grad=False,
)
)
t2 = perf_test_cuda_kernel(
lambda: backward_wrapper(
fusion_permuted_padded_output,
fusion_permute_pad_bwd_input,
forward_input=[fusion_permute_and_pad_fwd_input],
retain_graph=True,
accumulate_grad=False,
)
)
print(f"permute_and_pad\t\tbwd: naive: {t1:.3f} ms, fusion: {t2:.3f} ms")
def unpad_unpermute():
unpaded_output = fp8_unpadding(unpermute_unpad_fwd_input, tokens_per_expert_list)
unpermuted_unpaded_output = te_unpermute(
unpaded_output, row_id_map, restore_shape=restore_shape
)
unpermuted_unpaded_output.backward(unpermute_unpad_bwd_input, retain_graph=True)
t1 = perf_test_cuda_kernel(lambda: unpad_unpermute())
t2 = perf_test_cuda_kernel(
lambda: te_unpermute(
fusion_unpermute_unpad_fwd_input,
row_id_map,
restore_shape=restore_shape,
pad_offsets=pad_offsets,
)
)
print(f"unpermute_and_unpad\tfwd: naive: {t1:.3f} ms, fusion: {t2:.3f} ms")
t1 = perf_test_cuda_kernel(
lambda: backward_wrapper(
unpermuted_unpaded_output,
unpermute_unpad_bwd_input,
forward_input=([unpermute_unpad_fwd_input, probs]),
retain_graph=True,
accumulate_grad=False,
)
)
t2 = perf_test_cuda_kernel(
lambda: backward_wrapper(
fusion_unpermuted_unpaded_output,
fusion_unpermute_bwd_input,
forward_input=([fusion_unpermute_unpad_fwd_input, probs]),
retain_graph=True,
accumulate_grad=False,
)
)
print(f"unpermute_and_unpad\tbwd: naive: {t1:.3f} ms, fusion: {t2:.3f} ms")
def _test_permutation_and_padding_with_merging_probs(
te_dtype,
num_tokens,
num_expert,
hidden_size,
topK,
num_out_tokens,
align_size=16,
BENCHMARK=False,
):
"""
Test the combination of merging_probs AND pad_offsets together in moe_unpermute.
This specifically tests the backward pass fix where pad_offsets must be used
when computing gradients with merging_probs.
"""
if topK > num_expert:
pytest.skip("topK should be smaller than the number of experts.")
if num_out_tokens is None:
num_out_tokens = num_tokens * topK
print(
"permutation and padding with merging probs:"
f" token:{num_tokens} hidden_size:{hidden_size} expert:{num_expert} topK:{topK} align_size:{align_size} {te_dtype}"
)
# Convert TE dtypes to PyTorch dtypes
if te_dtype == tex.DType.kFloat32:
dtype = torch.float32
elif te_dtype == tex.DType.kFloat16:
dtype = torch.float16
elif te_dtype == tex.DType.kBFloat16:
dtype = torch.bfloat16
else:
pytest.skip("Invalid dtype.")
_tmp_tensor = torch.zeros((num_tokens * num_expert,))
_tmp_tensor[: int(num_out_tokens)] = 1.0
_tmp_idx = torch.randperm(num_tokens * num_expert)
routing_map = torch.reshape(_tmp_tensor[_tmp_idx], (num_tokens, num_expert)).bool().cuda()
probs = torch.rand(num_tokens, num_expert).cuda() * routing_map
row_sums = probs.sum(dim=1, keepdim=True)
probs = probs / row_sums
probs = probs.to(dtype)
probs.requires_grad_(True)
tokens_per_expert = routing_map.sum(dim=0).cpu()
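# Round each expert's token count up to a multiple of align_size; e.g. with align_size=16,
# an expert receiving 37 tokens is padded up to 48.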
target_tokens_per_expert = (torch.ceil(tokens_per_expert / align_size) * align_size).long()
num_permute_pad_out_tokens = target_tokens_per_expert.sum().item()
permute_pad_fwd_input = torch.rand((num_tokens, hidden_size), dtype=dtype).cuda()
permute_pad_bwd_input = torch.rand(
(num_permute_pad_out_tokens, hidden_size), dtype=dtype
).cuda()
unpermute_unpad_bwd_input = torch.rand((num_tokens, hidden_size), dtype=dtype).cuda()
permute_pad_fwd_input.requires_grad_(True)
restore_shape = permute_pad_fwd_input.shape
###################################################################################################################################
#
# Reference: moe_permute_with_probs + Fp8Padding, then Fp8Unpadding + moe_unpermute with merging_probs
#
###################################################################################################################################
# permute + padding
permuted_output, permuted_probs, row_id_map = te_permute_with_probs(
permute_pad_fwd_input,
probs,
routing_map,
num_out_tokens=num_out_tokens,
)
tokens_per_expert_list = tokens_per_expert.tolist()
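# Fp8Padding pads each expert's token chunk up to a multiple of align_size, matching the
# target_tokens_per_expert computed above, so the reference output layout mirrors the fused path.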
fp8_padding = Fp8Padding(num_expert, align_size)
permuted_paded_output, _ = fp8_padding(permuted_output, tokens_per_expert_list)
permuted_paded_output.backward(permute_pad_bwd_input, retain_graph=True)
# Reference: unpadding + unpermute WITH merging_probs
ref_unpermute_fwd_input = permuted_paded_output.detach()
ref_unpermute_fwd_input.requires_grad_(True)
ref_probs = probs.detach()
ref_probs.requires_grad_(True)
fp8_unpadding = Fp8Unpadding(num_expert, align_size)
unpaded_output = fp8_unpadding(ref_unpermute_fwd_input, tokens_per_expert_list)
ref_unpermuted_output = te_unpermute(
unpaded_output, row_id_map, ref_probs, restore_shape=restore_shape
)
ref_unpermuted_output.backward(unpermute_unpad_bwd_input, retain_graph=True)
###################################################################################################################################
#
# Fused: moe_permute_and_pad_with_probs, then moe_unpermute with BOTH merging_probs AND pad_offsets
#
###################################################################################################################################
# fusion permute_and_pad
fusion_permute_fwd_input = permute_pad_fwd_input.detach()
fusion_permute_fwd_input.requires_grad_(True)
fusion_probs = probs.detach()
fusion_probs.requires_grad_(True)
(
fusion_permuted_padded_output,
fusion_permuted_padded_probs,
fused_row_id_map,
pad_offsets,
_,
) = te_permute_and_pad_with_probs(
fusion_permute_fwd_input,
fusion_probs,
routing_map,
tokens_per_expert,
align_size,
)
fusion_permute_pad_bwd_input = permute_pad_bwd_input.detach()
fusion_permuted_padded_output.backward(fusion_permute_pad_bwd_input, retain_graph=True)
# Fused: unpermute with BOTH merging_probs AND pad_offsets
fusion_unpermute_fwd_input = fusion_permuted_padded_output.detach()
fusion_unpermute_fwd_input.requires_grad_(True)
fusion_merging_probs = probs.detach()
fusion_merging_probs.requires_grad_(True)
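# pad_offsets (returned by the fused permute-and-pad above) lets te_unpermute account for the
# padded expert chunks directly, replacing the separate Fp8Unpadding step of the reference path;
# merging_probs additionally weights each expert's copy of a token when merging them back.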
fusion_unpermuted_output = te_unpermute(
fusion_unpermute_fwd_input,
fused_row_id_map,
fusion_merging_probs,
restore_shape=restore_shape,
pad_offsets=pad_offsets,
)
fusion_unpermute_bwd_input = unpermute_unpad_bwd_input.detach()
fusion_unpermuted_output.backward(fusion_unpermute_bwd_input, retain_graph=True)
###################################################################################################################################
#
# Results Check
#
###################################################################################################################################
tols = dtype_tols(te_dtype)
# Check forward pass
ref_unpermuted_output_ = ref_unpermuted_output.float()
fusion_unpermuted_output_ = fusion_unpermuted_output.float()
if not BENCHMARK:
torch.testing.assert_close(
ref_unpermuted_output_,
fusion_unpermuted_output_,
msg=f"Mismatch in te_unpermute with merging_probs and pad_offsets fwd",
**tols,
)
# Check backward pass - activation gradients
ref_unpermute_fwd_input_grad = ref_unpermute_fwd_input.grad.float()
fusion_unpermute_fwd_input_grad = fusion_unpermute_fwd_input.grad.float()
torch.testing.assert_close(
ref_unpermute_fwd_input_grad,
fusion_unpermute_fwd_input_grad,
msg=f"Mismatch in te_unpermute with merging_probs and pad_offsets bwd (act_grad)",
**tols,
)
# Check backward pass - probs gradients
ref_probs_grad = ref_probs.grad.float()
fusion_probs_grad = fusion_merging_probs.grad.float()
torch.testing.assert_close(
ref_probs_grad,
fusion_probs_grad,
msg=f"Mismatch in te_unpermute with merging_probs and pad_offsets bwd (probs_grad)",
**tols,
)
###################################################################################################################################
#
# Benchmark
#
###################################################################################################################################
if BENCHMARK:
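# Compare the two-step reference (Fp8Unpadding then te_unpermute with probs) against the single
# fused te_unpermute call that consumes pad_offsets directly.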
def ref_unpad_unpermute():
unpaded = fp8_unpadding(ref_unpermute_fwd_input, tokens_per_expert_list)
return te_unpermute(unpaded, row_id_map, ref_probs, restore_shape=restore_shape)
def fused_unpermute():
return te_unpermute(
fusion_unpermute_fwd_input,
fused_row_id_map,
fusion_merging_probs,
restore_shape=restore_shape,
pad_offsets=pad_offsets,
)
t1 = perf_test_cuda_kernel(lambda: ref_unpad_unpermute())
t2 = perf_test_cuda_kernel(lambda: fused_unpermute())
print(f"unpermute_unpad_with_probs\tfwd: naive: {t1:.3f} ms, fusion: {t2:.3f} ms")
t1 = perf_test_cuda_kernel(
lambda: backward_wrapper(
ref_unpermuted_output,
unpermute_unpad_bwd_input,
forward_input=[ref_unpermute_fwd_input, ref_probs],
retain_graph=True,
accumulate_grad=False,
)
)
t2 = perf_test_cuda_kernel(
lambda: backward_wrapper(
fusion_unpermuted_output,
fusion_unpermute_bwd_input,
forward_input=[fusion_unpermute_fwd_input, fusion_merging_probs],
retain_graph=True,
accumulate_grad=False,
)
)
print(f"unpermute_unpad_with_probs\tbwd: naive: {t1:.3f} ms, fusion: {t2:.3f} ms")
def _test_permutation_mask_map_fp8( def _test_permutation_mask_map_fp8(
te_dtype, te_dtype,
num_tokens, num_tokens,
...@@ -1126,7 +1645,7 @@ if te.is_bf16_available(): ...@@ -1126,7 +1645,7 @@ if te.is_bf16_available():
@pytest.mark.parametrize("num_tokens", [4096]) @pytest.mark.parametrize("num_tokens", [4096])
@pytest.mark.parametrize("num_expert", [7, 16]) @pytest.mark.parametrize("num_expert", [7, 16])
@pytest.mark.parametrize("hidden_size", [4096]) @pytest.mark.parametrize("hidden_size", [4096])
@pytest.mark.parametrize("topK", [1, 2, 5]) @pytest.mark.parametrize("topK", [2, 5])
@pytest.mark.parametrize("num_out_tokens", [None, 2039]) @pytest.mark.parametrize("num_out_tokens", [None, 2039])
def test_permutation_index_map( def test_permutation_index_map(
te_dtype, te_dtype,
...@@ -1155,7 +1674,7 @@ def test_permutation_index_map( ...@@ -1155,7 +1674,7 @@ def test_permutation_index_map(
@pytest.mark.parametrize("num_tokens", [4096]) @pytest.mark.parametrize("num_tokens", [4096])
@pytest.mark.parametrize("num_expert", [7, 16]) @pytest.mark.parametrize("num_expert", [7, 16])
@pytest.mark.parametrize("hidden_size", [4096]) @pytest.mark.parametrize("hidden_size", [4096])
@pytest.mark.parametrize("topK", [1, 2, 5]) @pytest.mark.parametrize("topK", [2, 5])
@pytest.mark.parametrize("num_out_tokens", [None, 2039]) @pytest.mark.parametrize("num_out_tokens", [None, 2039])
def test_permutation_mask_map( def test_permutation_mask_map(
te_dtype, te_dtype,
...@@ -1180,6 +1699,74 @@ def test_permutation_mask_map( ...@@ -1180,6 +1699,74 @@ def test_permutation_mask_map(
) )
@pytest.mark.parametrize("te_dtype", _te_dtypes)
@pytest.mark.parametrize("num_out_tokens", [None])
@pytest.mark.parametrize(
"num_tokens, num_expert, hidden_size, topK",
[
(4096, 8, 1280, 2),
(4096, 64, 4096, 6),
(4096, 256, 7168, 6),
(4096, 512, 9216, 8),
],
)
@pytest.mark.parametrize("with_merging_probs", [True, False])
def test_permutation_and_padding_mask_map(
te_dtype,
num_tokens,
num_expert,
hidden_size,
topK,
num_out_tokens,
with_merging_probs,
):
BENCHMARK = False
_test_permutation_and_padding_mask_map(
te_dtype=te_dtype,
num_tokens=num_tokens,
num_expert=num_expert,
hidden_size=hidden_size,
topK=topK,
num_out_tokens=num_out_tokens,
with_merging_probs=with_merging_probs,
BENCHMARK=BENCHMARK,
)
@pytest.mark.parametrize("te_dtype", _te_dtypes)
@pytest.mark.parametrize("num_out_tokens", [None])
@pytest.mark.parametrize(
"num_tokens, num_expert, hidden_size, topK",
[
(4096, 8, 1280, 2),
(4096, 64, 4096, 6),
(4096, 256, 7168, 6),
(4096, 512, 9216, 8),
],
)
def test_permutation_and_padding_with_merging_probs(
te_dtype,
num_tokens,
num_expert,
hidden_size,
topK,
num_out_tokens,
):
"""Test moe_unpermute backward pass with BOTH merging_probs AND pad_offsets."""
BENCHMARK = False
_test_permutation_and_padding_with_merging_probs(
te_dtype=te_dtype,
num_tokens=num_tokens,
num_expert=num_expert,
hidden_size=hidden_size,
topK=topK,
num_out_tokens=num_out_tokens,
BENCHMARK=BENCHMARK,
)
@pytest.mark.parametrize("te_dtype", _te_dtypes) @pytest.mark.parametrize("te_dtype", _te_dtypes)
def test_permutation_mask_map_empty_input(te_dtype): def test_permutation_mask_map_empty_input(te_dtype):
with_probs = True with_probs = True
...@@ -1201,9 +1788,9 @@ def test_permutation_mask_map_empty_input(te_dtype): ...@@ -1201,9 +1788,9 @@ def test_permutation_mask_map_empty_input(te_dtype):
@pytest.mark.parametrize("num_tokens", [4096]) @pytest.mark.parametrize("num_tokens", [4096])
@pytest.mark.parametrize("num_expert", [7, 16]) @pytest.mark.parametrize("num_expert", [7, 16])
@pytest.mark.parametrize("hidden_size", [4096]) @pytest.mark.parametrize("hidden_size", [4096])
@pytest.mark.parametrize("topK", [1, 2, 5]) @pytest.mark.parametrize("topK", [2, 5])
@pytest.mark.parametrize("num_out_tokens", [None, 2039]) @pytest.mark.parametrize("num_out_tokens", [None, 2039])
@pytest.mark.parametrize("tp_size", [1, 2, 8]) @pytest.mark.parametrize("tp_size", [1, 2])
def test_permutation_mask_map_alongside_probs( def test_permutation_mask_map_alongside_probs(
te_dtype, te_dtype,
num_tokens, num_tokens,
...@@ -1253,10 +1840,10 @@ fp8_recipes = [ ...@@ -1253,10 +1840,10 @@ fp8_recipes = [
@pytest.mark.skipif(not fp8_available, reason=reason_for_no_fp8) @pytest.mark.skipif(not fp8_available, reason=reason_for_no_fp8)
@pytest.mark.parametrize("te_dtype", [tex.DType.kFloat8E4M3, tex.DType.kFloat8E5M2]) @pytest.mark.parametrize("te_dtype", [tex.DType.kFloat8E4M3, tex.DType.kFloat8E5M2])
@pytest.mark.parametrize("num_tokens", [2048]) @pytest.mark.parametrize("num_tokens", [4096])
@pytest.mark.parametrize("num_expert", [7, 16]) @pytest.mark.parametrize("num_expert", [7, 16])
@pytest.mark.parametrize("hidden_size", [4096]) @pytest.mark.parametrize("hidden_size", [4096])
@pytest.mark.parametrize("topK", [1, 2, 5]) @pytest.mark.parametrize("topK", [2, 5])
@pytest.mark.parametrize("num_out_tokens", [None, 2039]) @pytest.mark.parametrize("num_out_tokens", [None, 2039])
@pytest.mark.parametrize("recipe", fp8_recipes) @pytest.mark.parametrize("recipe", fp8_recipes)
def test_permutation_mask_map_fp8( def test_permutation_mask_map_fp8(
...@@ -1341,7 +1928,7 @@ def test_permutation_mask_map_topk1_no_probs( ...@@ -1341,7 +1928,7 @@ def test_permutation_mask_map_topk1_no_probs(
@pytest.mark.parametrize("te_dtype", _te_dtypes) @pytest.mark.parametrize("te_dtype", _te_dtypes)
@pytest.mark.parametrize("num_tokens", [4096]) @pytest.mark.parametrize("num_tokens", [4096])
@pytest.mark.parametrize("num_expert", [7, 16]) @pytest.mark.parametrize("num_expert", [7, 16])
@pytest.mark.parametrize("tp_size", [1, 2, 8]) @pytest.mark.parametrize("tp_size", [2, 8])
@pytest.mark.parametrize("hidden_size", [4096]) @pytest.mark.parametrize("hidden_size", [4096])
def test_chunk_permutation( def test_chunk_permutation(
te_dtype, te_dtype,
...@@ -1376,6 +1963,10 @@ def test_chunk_permutation_empty_input(te_dtype): ...@@ -1376,6 +1963,10 @@ def test_chunk_permutation_empty_input(te_dtype):
) )
@pytest.mark.skipif(
os.getenv("RUN_BENCHMARK_TESTS", "0") != "1",
reason="Benchmark test - run with: RUN_BENCHMARK_TESTS=1 pytest -k single_case",
)
def test_permutation_single_case(): def test_permutation_single_case():
print("GPU:", torch.cuda.get_device_name(0)) print("GPU:", torch.cuda.get_device_name(0))
...@@ -1413,6 +2004,26 @@ def test_permutation_single_case(): ...@@ -1413,6 +2004,26 @@ def test_permutation_single_case():
BENCHMARK=Benchmark, BENCHMARK=Benchmark,
) )
_test_permutation_and_padding_mask_map(
te_dtype=te_dtype,
num_tokens=num_tokens,
num_expert=num_expert,
hidden_size=hidden_size,
topK=topK,
num_out_tokens=num_out_tokens,
BENCHMARK=Benchmark,
)
_test_permutation_and_padding_with_merging_probs(
te_dtype=te_dtype,
num_tokens=num_tokens,
num_expert=num_expert,
hidden_size=hidden_size,
topK=topK,
num_out_tokens=num_out_tokens,
BENCHMARK=Benchmark,
)
_test_moe_chunk_sort( _test_moe_chunk_sort(
te_dtype=te_dtype, te_dtype=te_dtype,
num_tokens=num_tokens, num_tokens=num_tokens,
...@@ -1479,6 +2090,30 @@ def benchmark_single_case( ...@@ -1479,6 +2090,30 @@ def benchmark_single_case(
) )
torch.cuda.nvtx.range_pop() torch.cuda.nvtx.range_pop()
torch.cuda.nvtx.range_push("permutation_and_padding_mask_map")
_test_permutation_and_padding_mask_map(
te_dtype=te_dtype,
num_tokens=num_tokens,
num_expert=num_expert,
hidden_size=hidden_size,
topK=topK,
num_out_tokens=num_out_tokens,
BENCHMARK=True,
)
torch.cuda.nvtx.range_pop()
torch.cuda.nvtx.range_push("permutation_and_padding_with_merging_probs")
_test_permutation_and_padding_with_merging_probs(
te_dtype=te_dtype,
num_tokens=num_tokens,
num_expert=num_expert,
hidden_size=hidden_size,
topK=topK,
num_out_tokens=num_out_tokens,
BENCHMARK=True,
)
torch.cuda.nvtx.range_pop()
torch.cuda.nvtx.range_push("permutation_mask_map_alongside_probs") torch.cuda.nvtx.range_push("permutation_mask_map_alongside_probs")
_test_permutation_mask_map_alongside_probs( _test_permutation_mask_map_alongside_probs(
te_dtype=te_dtype, te_dtype=te_dtype,
...@@ -1495,7 +2130,12 @@ def benchmark_single_case( ...@@ -1495,7 +2130,12 @@ def benchmark_single_case(
torch.cuda.nvtx.range_pop() torch.cuda.nvtx.range_pop()
def benchmark_multiple_cases(): @pytest.mark.skipif(
os.getenv("RUN_BENCHMARK_TESTS", "0") != "1",
reason="Benchmark test - run with: RUN_BENCHMARK_TESTS=1 pytest -k benchmark",
)
def test_benchmark_multiple_cases():
"""Benchmark test - skipped by default. Run with: RUN_BENCHMARK_TESTS=1 pytest -k benchmark"""
print("GPU:", torch.cuda.get_device_name(0)) print("GPU:", torch.cuda.get_device_name(0))
# te_dtype = tex.DType.kFloat32 # te_dtype = tex.DType.kFloat32
...@@ -1537,4 +2177,4 @@ def benchmark_multiple_cases(): ...@@ -1537,4 +2177,4 @@ def benchmark_multiple_cases():
if __name__ == "__main__": if __name__ == "__main__":
benchmark_multiple_cases() test_benchmark_multiple_cases()
# Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved. # Copyright (c) 2022-2026, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# #
# See LICENSE for license information. # See LICENSE for license information.
......
# Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved. # Copyright (c) 2022-2026, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# #
# See LICENSE for license information. # See LICENSE for license information.
...@@ -13,9 +13,16 @@ import transformer_engine.common.recipe ...@@ -13,9 +13,16 @@ import transformer_engine.common.recipe
import transformer_engine.pytorch as te import transformer_engine.pytorch as te
from transformer_engine.pytorch import ( from transformer_engine.pytorch import (
Float8Quantizer, Float8Quantizer,
Float8Tensor,
Float8CurrentScalingQuantizer, Float8CurrentScalingQuantizer,
Float8BlockQuantizer,
MXFP8Quantizer,
NVFP4Quantizer,
Float8Tensor,
MXFP8Tensor,
NVFP4Tensor,
QuantizedTensor,
) )
from transformer_engine.pytorch.utils import is_non_tn_fp8_gemm_supported from transformer_engine.pytorch.utils import is_non_tn_fp8_gemm_supported
import transformer_engine_torch as tex import transformer_engine_torch as tex
...@@ -44,8 +51,22 @@ def _to_list(x: Union[Iterable, Any]) -> List: ...@@ -44,8 +51,22 @@ def _to_list(x: Union[Iterable, Any]) -> List:
# Types that can be interpreted as tensor dims # Types that can be interpreted as tensor dims
DimsType = Union[Iterable[int], int] DimsType = Union[Iterable[int], int]
# Check if FP8 is supported # Supported quantization recipes
fp8_available, reason_for_no_fp8 = te.is_fp8_available(return_reason=True) fp8_available, reason_for_no_fp8 = te.is_fp8_available(return_reason=True)
fp8_block_scaling_available, reason_for_no_fp8_block_scaling = te.is_fp8_block_scaling_available(
return_reason=True
)
mxfp8_available, reason_for_no_mxfp8 = te.is_mxfp8_available(return_reason=True)
nvfp4_available, reason_for_no_nvfp4 = te.is_nvfp4_available(return_reason=True)
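# Quantization schemes exercised by the generic QuantizedTensor tests below, restricted to what
# the current device supports.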
_quantization_list: List[str] = []
if fp8_available:
_quantization_list.append("fp8")
if fp8_block_scaling_available:
_quantization_list.append("fp8_blockwise")
if mxfp8_available:
_quantization_list.append("mxfp8")
if nvfp4_available:
_quantization_list.append("nvfp4")
# delayed scaling # delayed scaling
...@@ -86,6 +107,79 @@ def to_float8_CS( ...@@ -86,6 +107,79 @@ def to_float8_CS(
return quantizer(tensor) return quantizer(tensor)
@torch.no_grad()
def make_reference_and_test_tensors(
shape: int | Iterable[int],
quantization: Optional[str] = None,
ref_dtype: torch.dtype = torch.float64,
ref_device: torch.device = "cpu",
test_dtype: torch.dtype = torch.float32,
test_device: torch.device = "cuda",
requires_grad: bool = True,
) -> tuple[torch.Tensor, torch.Tensor]:
"""Construct tensors with the same values
The reference tensor is intended for use in plain PyTorch
operations in high precision. The test tensor is intended for use
in Transformer Engine operations.
If a quantization scheme is provided, the tensor values are
quantized so that they are representable.
"""
# Random reference tensor
ref = torch.rand(shape, dtype=ref_dtype, device=ref_device)
# Construct test tensor from reference tensor
test = ref.to(device=test_device, dtype=test_dtype)
if quantization is None:
if test.data_ptr() == ref.data_ptr():
test = test.clone()
elif quantization in ("fp8", "fp8_delayed_scaling"):
quantizer = Float8Quantizer(
scale=torch.ones(1, dtype=torch.float32, device=test_device).squeeze(),
amax=torch.zeros(1, dtype=torch.float32, device=test_device),
fp8_dtype=tex.DType.kFloat8E4M3,
)
test = quantizer(test)
elif quantization == "fp8_current_scaling":
quantizer = Float8CurrentScalingQuantizer(
fp8_dtype=tex.DType.kFloat8E4M3,
device=test_device,
)
test = quantizer(test)
elif quantization == "fp8_blockwise":
quantizer = Float8BlockQuantizer(
fp8_dtype=tex.DType.kFloat8E4M3,
rowwise=True,
columnwise=True,
force_pow_2_scales=True,
amax_epsilon=0.0,
block_scaling_dim=1,
)
test = quantizer(test)
elif quantization == "mxfp8":
test = MXFP8Quantizer(fp8_dtype=tex.DType.kFloat8E4M3)(test)
elif quantization == "nvfp4":
test = NVFP4Quantizer(
with_rht=False,
with_post_rht_amax=False,
with_2d_quantization=False,
stochastic_rounding=False,
with_random_sign_mask=False,
)(test)
else:
raise ValueError(f"Unsupported quantization scheme ({quantization})")
# Make sure reference and test tensors match each other
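# (copying the possibly-quantized test values back into ref means both tensors share any
# quantization error, so later comparisons can use tight or exact tolerances)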
ref.copy_(test)
ref.requires_grad_(requires_grad)
test.requires_grad_(requires_grad)
return ref, test
@pytest.mark.skipif(not fp8_available, reason=reason_for_no_fp8) @pytest.mark.skipif(not fp8_available, reason=reason_for_no_fp8)
class TestFloat8Tensor: class TestFloat8Tensor:
...@@ -452,3 +546,113 @@ class TestCurrentScalingFloat8Tensor: ...@@ -452,3 +546,113 @@ class TestCurrentScalingFloat8Tensor:
# Make sure we are not trivially passing the test # Make sure we are not trivially passing the test
with pytest.raises(AssertionError): with pytest.raises(AssertionError):
torch.testing.assert_close(x_fp8_dequantized, -x_hp, **_tols[fp8_dtype]) torch.testing.assert_close(x_fp8_dequantized, -x_hp, **_tols[fp8_dtype])
class TestQuantizedTensor:
@staticmethod
def setup_class(cls) -> None:
# Configure RNG
seed = 1234
torch.manual_seed(seed)
torch.cuda.manual_seed(seed)
@pytest.mark.parametrize("op", ("clone", "view", "reshape", "contiguous"))
@pytest.mark.parametrize("quantization", _quantization_list)
def test_identity_op(
self,
*,
op: str,
quantization: str,
shape: Iterable[int] = (128, 128),
dtype: torch.dtype = torch.bfloat16,
device: torch.device = "cuda",
) -> None:
"""Test operations that do not affect tensor values.
These operations must produce outputs that are bit-wise
equivalent to the inputs, and they must support autograd.
"""
# Create reference and quantized tensor
x_ref, x_test = make_reference_and_test_tensors(
shape=shape,
quantization=quantization,
test_dtype=dtype,
)
dy_ref, dy_test = make_reference_and_test_tensors(
shape=shape,
test_dtype=dtype,
requires_grad=False,
)
# Apply identity operation
if op == "clone":
y_ref = x_ref.clone()
y_test = x_test.clone()
elif op == "view":
y_ref = x_ref.view(shape)
y_test = x_test.view(shape)
elif op == "reshape":
y_ref = x_ref.reshape(shape)
y_test = x_test.reshape(shape)
elif op == "contiguous":
y_ref = x_ref.contiguous()
y_test = x_test.contiguous()
# Check autograd
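# These ops have an identity Jacobian, so the incoming dy should flow through to the input
# gradient unchanged (hence dx_ref = dy_ref below).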
y_test.backward(dy_test)
assert x_test.grad is not None
# Check values
tols = dict(rtol=0, atol=0)
if isinstance(y_test, QuantizedTensor):
y_test = y_test.dequantize()
y_test = y_test.to(dtype=torch.float64, device="cpu")
dx_test = x_test.grad.to(dtype=torch.float64, device="cpu")
dx_ref = dy_ref
torch.testing.assert_close(y_test, y_ref, **tols)
torch.testing.assert_close(dx_test, dx_ref, **tols)
@pytest.mark.parametrize("quantization", _quantization_list)
@pytest.mark.parametrize("dim", [0, 1])
def test_chunk(
self,
*,
quantization: str,
dim: int,
shape: Iterable[int] = (128, 128),
chunks: int = 2,
dtype: torch.dtype = torch.bfloat16,
device: torch.device = "cuda",
) -> None:
# Create reference and quantized tensor
x_ref, x_test = make_reference_and_test_tensors(
shape=shape,
quantization=quantization,
test_dtype=dtype,
)
# Chunk tensors
ys_ref = torch.chunk(x_ref, chunks, dim=dim)
ys_test = torch.chunk(x_test, chunks, dim=dim)
# Check splits
for y_ref, y_test in zip(ys_ref, ys_test):
# Check split shapes
assert y_ref.size() == y_test.size()
# Check that splits are quantized when expected
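# A per-tensor-scaled FP8 tensor can be split along either dim and stay quantized; for MXFP8 this
# test only expects the dim-0 split to remain quantized, presumably because its block scales
# follow the inner dimension. Other cases may fall back to dequantized outputs.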
if quantization == "fp8":
assert isinstance(y_test, Float8Tensor)
y_test = y_test.dequantize()
elif quantization == "mxfp8" and dim == 0:
assert isinstance(y_test, MXFP8Tensor)
y_test = y_test.dequantize()
# Check values
tols = dict(rtol=0, atol=0) # Chunking is exact
y_test = y_test.to(dtype=torch.float64, device="cpu")
torch.testing.assert_close(y_test, y_ref, **tols)
# Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved. # Copyright (c) 2022-2026, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# #
# See LICENSE for license information. # See LICENSE for license information.
...@@ -30,7 +30,6 @@ from transformer_engine.pytorch.quantization import ( ...@@ -30,7 +30,6 @@ from transformer_engine.pytorch.quantization import (
) )
import transformer_engine.pytorch.ops as te_ops import transformer_engine.pytorch.ops as te_ops
from transformer_engine.common.recipe import DelayedScaling, Float8BlockScaling, MXFP8BlockScaling from transformer_engine.common.recipe import DelayedScaling, Float8BlockScaling, MXFP8BlockScaling
import transformer_engine_torch as tex
# Check if FP8 is supported # Check if FP8 is supported
fp8_available, reason_for_no_fp8 = te.is_fp8_available(return_reason=True) fp8_available, reason_for_no_fp8 = te.is_fp8_available(return_reason=True)
......