# Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# See LICENSE for license information.
from collections.abc import Iterable
import io
from typing import Any, Dict, List, Tuple, Union
import pytest
import torch
import transformer_engine.common.recipe
import transformer_engine.pytorch as te
from transformer_engine.pytorch.fp8 import FP8GlobalStateManager
from transformer_engine.pytorch.tensor.float8_tensor import (
Float8Quantizer,
Float8Tensor,
Float8CurrentScalingQuantizer,
)
from transformer_engine.pytorch.constants import TE_DType, TE_DType_To_Torch
from transformer_engine.pytorch.utils import non_tn_fp8_gemm_supported
import transformer_engine_torch as tex
from references.ref_per_tensor_cs import ref_per_tensor_cs_cast
# PyTorch tensor dtypes
_dtypes: List[torch.dtype] = [torch.float32, torch.float16, torch.bfloat16]
# TE FP8 dtypes
_fp8_dtypes: List[tex.DType] = [tex.DType.kFloat8E4M3, tex.DType.kFloat8E5M2]
# Numerical tolerances with FP8 types
_tols: Dict[tex.DType, Dict[str, float]] = {
tex.DType.kFloat8E4M3: dict(rtol=0.125, atol=0.0675), # epsilon = 0.0625
tex.DType.kFloat8E5M2: dict(rtol=0.25, atol=0.125), # epsilon = 0.125
}
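# Rough rationale (editor's note): rtol is about twice the per-format epsilon noted above,
# and atol provides comparable absolute slack for values that round toward zero.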
def _to_list(x: Union[Iterable, Any]) -> List:
"""Convert to list if iterable, otherwise put in singleton list"""
if isinstance(x, Iterable):
return list(x)
else:
return [x]
# Types that can be interpreted as tensor dims
DimsType = Union[Iterable[int], int]
# Check if FP8 is supported
fp8_available, reason_for_no_fp8 = FP8GlobalStateManager.is_fp8_available()
# delayed scaling
def to_float8(
tensor: torch.Tensor,
fp8_dtype: tex.DType = tex.DType.kFloat8E4M3,
scale: float = 1.0,
) -> Float8Tensor:
"""Cast tensor to FP8"""
quantizer = Float8Quantizer(
scale=torch.full([1], scale, dtype=torch.float32, device="cuda"),
amax=torch.empty([1], dtype=torch.float32, device="cuda"),
fp8_dtype=fp8_dtype,
)
return quantizer(tensor.cuda())
# current scaling
def to_float8_CS(
tensor: torch.Tensor,
fp8_dtype: tex.DType = tex.DType.kFloat8E4M3,
return_transpose: bool = False,
force_pow_2_scales: bool = False,
amax_epsilon: float = 0.0,
) -> Float8Tensor:
"""Cast tensor to FP8"""
tensor = tensor.cuda()
quantizer = Float8CurrentScalingQuantizer(
fp8_dtype=fp8_dtype,
device=tensor.device,
force_pow_2_scales=force_pow_2_scales,
amax_epsilon=amax_epsilon,
)
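# Rowwise usage produces the FP8 data; columnwise usage additionally materializes
# the FP8 transpose (checked against the reference when return_transpose=True).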
if return_transpose:
quantizer.set_usage(rowwise=True, columnwise=True)
else:
quantizer.set_usage(rowwise=True, columnwise=False)
return quantizer(tensor)
@pytest.mark.skipif(not fp8_available, reason=reason_for_no_fp8)
class TestFloat8Tensor:
@staticmethod
def setup_class(cls) -> None:
# Configure RNG
seed = 1234
torch.manual_seed(seed)
torch.cuda.manual_seed(seed)
def test_constructor(
self,
dims: DimsType = 1,
fp8_dtype: tex.DType = tex.DType.kFloat8E4M3,
scale_inv: float = 0.375,
dtype: torch.dtype = torch.float32,
) -> None:
"""Call constructor and perform sanity checks"""
dims = _to_list(dims)
tensor = Float8Tensor(
shape=dims,
dtype=dtype,
data=torch.zeros(dims, device="cuda", dtype=torch.uint8),
fp8_dtype=fp8_dtype,
fp8_scale_inv=torch.full([1], scale_inv),
)
assert list(tensor.size()) == dims, "Incorrect dims"
assert tensor.dtype == dtype, "Incorrect nominal dtype"
assert tensor.is_cuda, "Incorrect device"
def _test_quantize_dequantize(
self,
fp8_dtype: tex.DType = tex.DType.kFloat8E4M3,
scale: float = 3.5,
dtype: torch.dtype = torch.float32,
dims: DimsType = 23,
) -> None:
"""Check numerical error when casting to FP8 and back"""
# Initialize random data
x_ref = 2 * torch.rand(_to_list(dims), dtype=dtype, device="cpu") - 1
# Cast to FP8 and back
x_fp8 = to_float8(x_ref, fp8_dtype=fp8_dtype, scale=scale)
x_fp8 = x_fp8.dequantize().cpu()
# Check results
torch.testing.assert_close(x_fp8, x_ref, **_tols[fp8_dtype])
# Make sure we are not trivially passing the test
with pytest.raises(AssertionError):
torch.testing.assert_close(x_fp8, -x_ref, **_tols[fp8_dtype])
@pytest.mark.parametrize("fp8_dtype", _fp8_dtypes)
@pytest.mark.parametrize("dtype", _dtypes)
def test_quantize_dequantize_dtypes(
self,
fp8_dtype: tex.DType,
dtype: torch.dtype,
) -> None:
self._test_quantize_dequantize(fp8_dtype=fp8_dtype, dtype=dtype)
@pytest.mark.parametrize("scale", [0.375, 1, 3.5])
def test_quantize_dequantize_scales(self, scale: float) -> None:
self._test_quantize_dequantize(scale=scale)
@pytest.mark.parametrize("dims", [[], 1, 311, [7, 11], [7, 5, 3], [2, 3, 5, 3]])
def test_quantize_dequantize_dims(self, dims: DimsType) -> None:
self._test_quantize_dequantize(dims=dims)
def test_basic_ops(
self,
dims: DimsType = 23,
fp8_dtype: tex.DType = tex.DType.kFloat8E4M3,
scale: float = 3.5,
dtype: torch.dtype = torch.float32,
) -> None:
"""Test basic out-of-place ops"""
# Initialize random data
dims = _to_list(dims)
x_ref = 2 * torch.rand(dims, dtype=dtype, device="cpu") - 1
y_ref = 2 * torch.rand(dims, dtype=dtype, device="cpu") - 1
x_fp8 = to_float8(x_ref, fp8_dtype=fp8_dtype, scale=scale)
y_fp8 = to_float8(y_ref, fp8_dtype=fp8_dtype, scale=scale)
x_ref = x_fp8.dequantize()
y_ref = y_fp8.dequantize()
# Exact operations
torch.testing.assert_close(-x_fp8, -x_ref, rtol=0, atol=0)
torch.testing.assert_close(x_fp8.abs(), x_ref.abs(), rtol=0, atol=0)
# Operations with numerical error
tols = _tols[fp8_dtype]
torch.testing.assert_close(x_fp8 + y_fp8, x_ref + y_ref, **tols)
torch.testing.assert_close(x_fp8 - y_fp8, x_ref - y_ref, **tols)
torch.testing.assert_close(x_fp8 * y_fp8, x_ref * y_ref, **tols)
torch.testing.assert_close(x_fp8 + y_ref, x_ref + y_ref, **tols)
torch.testing.assert_close(x_ref + y_fp8, x_ref + y_ref, **tols)
torch.testing.assert_close(torch.sin(x_fp8), torch.sin(x_ref), **tols)
# Make sure we are not trivially passing tests
with pytest.raises(AssertionError):
torch.testing.assert_close(x_fp8 + y_fp8, x_ref - y_fp8, **tols)
@pytest.mark.parametrize("dims", [2, [4, 4], [8, 5, 3, 3]])
def test_chunk_op(
self,
dims: DimsType,
fp8_dtype: tex.DType = tex.DType.kFloat8E4M3,
scale: float = 3.5,
dtype: torch.dtype = torch.float32,
) -> None:
"""Test for ops for which shape of inputs and outputs differ."""
# Initialize random data
dims = _to_list(dims)
x_ref = torch.randn(dims, dtype=dtype, device="cpu")
x_fp8 = to_float8(x_ref, fp8_dtype=fp8_dtype, scale=1.0)
# Get chunks.
chunk1, chunk2 = x_fp8.chunk(2, dim=0)
# Test chunks.
torch.testing.assert_close(x_fp8[0 : dims[0] // 2,], chunk1, atol=0, rtol=0)
torch.testing.assert_close(x_fp8[dims[0] // 2 :,], chunk2, atol=0, rtol=0)
# Check shapes.
assert (
chunk1.shape == torch.Size([x_fp8.shape[0] // 2]) + x_fp8.shape[1:]
), "Wrong shape for chunk1"
assert (
chunk2.shape == torch.Size([x_fp8.shape[0] // 2]) + x_fp8.shape[1:]
), "Wrong shape for chunk2"
def test_inplace_ops(
self,
dims: DimsType = 23,
fp8_dtype: tex.DType = tex.DType.kFloat8E4M3,
scale: float = 3.5,
dtype: torch.dtype = torch.float32,
) -> None:
"""Test in-place ops"""
# Initialize random data
dims = _to_list(dims)
x_ref = 2 * torch.rand(dims, dtype=dtype, device="cpu") - 1
y_ref = 2 * torch.rand(dims, dtype=dtype, device="cpu") - 1
x_fp8 = to_float8(x_ref, fp8_dtype=fp8_dtype, scale=scale)
y_fp8 = to_float8(y_ref, fp8_dtype=fp8_dtype, scale=scale)
x_ref = x_fp8.dequantize()
y_ref = y_fp8.dequantize()
# In-place operations
tols = _tols[fp8_dtype]
x_fp8 += y_ref
x_ref += y_ref
torch.testing.assert_close(x_fp8, x_ref, **tols)
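# Re-sync the reference after each check so quantization error does not accumulate
# across the in-place ops.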
x_ref = x_fp8.dequantize()
x_fp8 -= y_fp8
x_ref -= y_fp8
torch.testing.assert_close(x_fp8, x_ref, **tols)
x_ref = x_fp8.dequantize()
x_fp8 *= 2
x_ref *= 2
torch.testing.assert_close(x_fp8, x_ref, **tols)
x_ref = x_fp8.dequantize()
# Make sure we are not trivially passing tests
x_ref += 123
with pytest.raises(AssertionError):
torch.testing.assert_close(x_fp8, x_ref, **tols)
def test_serialization(
self,
dims: DimsType = [2, 3, 5],
fp8_dtype: tex.DType = tex.DType.kFloat8E4M3,
scale: float = 0.5,
dtype: torch.dtype = torch.float32,
):
# Initialize random data
dims = _to_list(dims)
x_ref = 2 * torch.rand(dims, dtype=dtype, device="cpu") - 1
x_fp8 = to_float8(x_ref, fp8_dtype=fp8_dtype, scale=scale)
x_ref = x_fp8.dequantize()
# Serialize tensor
byte_stream = io.BytesIO()
torch.save(x_fp8, byte_stream)
x_bytes = byte_stream.getvalue()
# Mess up and delete old tensor
x_fp8._data.zero_()
x_fp8._scale_inv.zero_()
del x_fp8, byte_stream
# Deserialize tensor
x_fp8 = torch.load(io.BytesIO(x_bytes), weights_only=False)
del x_bytes
# Check results
tols = dict(rtol=0, atol=0)
torch.testing.assert_close(x_fp8, x_ref, **tols)
# Make sure we are not trivially passing tests
x_fp8._data.zero_()
x_fp8._scale_inv.zero_()
with pytest.raises(AssertionError):
torch.testing.assert_close(x_fp8, x_ref, **tols)
def test_set_data(self):
"""Test directly setting .data attr"""
# Initialize Float8Tensor
x0 = torch.zeros(4, dtype=torch.float32)
x = to_float8(x0)
assert isinstance(x, Float8Tensor)
assert x0.size() == x.size() == x._data.size()
assert x.dtype == torch.float32
assert x.is_cuda and x._data.is_cuda
y = x.dequantize()
assert not isinstance(y, Float8Tensor)
assert x.size() == y.size()
assert x.dtype == y.dtype
assert x.device == y.device
# Set data to plain tensor
x0 = torch.zeros((3, 2), dtype=torch.float16, device=x.device)
x.data = x0
assert isinstance(x, Float8Tensor)
assert x0.size() == x.size() == x._data.size()
assert x0.dtype == x.dtype
assert x0.device == x.device == x._data.device
y = x.dequantize()
assert not isinstance(y, Float8Tensor)
assert x.size() == y.size()
assert x.dtype == y.dtype
assert x.device == y.device
# Set data to Float8Tensor
x0 = to_float8(torch.zeros((4, 3, 1), dtype=torch.float32))
x.data = x0
assert isinstance(x, Float8Tensor)
assert x0.size() == x.size() == x._data.size()
assert x0.dtype == x.dtype
assert x0.device == x.device == x._data.device
assert x0._data is x._data
assert x0._scale_inv is x._scale_inv
y = x.dequantize()
assert not isinstance(y, Float8Tensor)
assert x.size() == y.size()
assert x.dtype == y.dtype
assert x.device == y.device
@pytest.mark.skipif(not fp8_available, reason=reason_for_no_fp8)
class TestCurrentScalingFloat8Tensor:
@staticmethod
def setup_class(cls) -> None:
# Configure RNG
seed = 1234
torch.manual_seed(seed)
torch.cuda.manual_seed(seed)
@pytest.mark.parametrize("fp8_dtype", _fp8_dtypes)
@pytest.mark.parametrize("dtype", _dtypes)
@pytest.mark.parametrize(
"dims", [[], 1, 311, [7, 11], [7, 5, 3], [2, 3, 5, 3], [128, 128], [611, 782]]
)
@pytest.mark.parametrize("return_transpose", [True, False], ids=str)
@pytest.mark.parametrize("force_pow_2_scales", [True, False], ids=str)
@pytest.mark.parametrize("amax_epsilon", [0.0, 1e-6], ids=str)
def test_quantize(
self,
fp8_dtype: tex.DType,
dtype: torch.dtype,
dims: DimsType,
return_transpose: bool,
force_pow_2_scales: bool,
amax_epsilon: float,
) -> None:
"""Check numerical error when casting to FP8"""
# Skip invalid configurations
if non_tn_fp8_gemm_supported() and return_transpose:
pytest.skip("FP8 transpose is neither needed nor supported on current system")
# Initialize random high precision data
device = "cuda"
x_hp = 2 * torch.rand(_to_list(dims), dtype=dtype, device=device) - 1
# Cast to FP8
x_fp8 = to_float8_CS(
x_hp,
fp8_dtype=fp8_dtype,
return_transpose=return_transpose,
force_pow_2_scales=force_pow_2_scales,
amax_epsilon=amax_epsilon,
)
# get reference implementation of current scaling
x_fp8_ref, sx_ref, x_fp8_t_ref, _ = ref_per_tensor_cs_cast(
x_hp,
fp8_dtype=fp8_dtype,
return_transpose=return_transpose,
force_pow_2_scales=force_pow_2_scales,
amax_epsilon=amax_epsilon,
)
torch.testing.assert_close(x_fp8._data, x_fp8_ref.view(torch.uint8), atol=0.0, rtol=0.0)
torch.testing.assert_close(x_fp8._scale_inv, sx_ref, atol=0.0, rtol=0.0)
if return_transpose:
torch.testing.assert_close(
x_fp8._transpose, x_fp8_t_ref.view(torch.uint8), atol=0.0, rtol=0.0
)
@pytest.mark.parametrize("fp8_dtype", [tex.DType.kFloat8E4M3], ids=str)
@pytest.mark.parametrize("dtype", [torch.bfloat16], ids=str)
@pytest.mark.parametrize("dims", [[], 1, 311, [7, 11], [7, 5, 3], [2, 3, 5, 3]])
def test_quantize_dequantize(
self, fp8_dtype: tex.DType, dtype: torch.dtype, dims: DimsType
) -> None:
"""Check numerical error when casting to FP8 and back"""
# Initialize random high precision data
device = "cuda"
x_hp = 2 * torch.rand(_to_list(dims), dtype=dtype, device=device) - 1
# Cast to FP8 and back
x_fp8 = to_float8_CS(x_hp, fp8_dtype=fp8_dtype)
x_fp8_dequantized = x_fp8.dequantize()
# Check results
torch.testing.assert_close(x_fp8_dequantized, x_hp, **_tols[fp8_dtype])
# Make sure we are not trivially passing the test
with pytest.raises(AssertionError):
torch.testing.assert_close(x_fp8_dequantized, -x_hp, **_tols[fp8_dtype])
# Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# See LICENSE for license information.
from itertools import product
import copy
from contextlib import nullcontext
import pytest
import torch
from torch import nn
from torch.testing._internal.common_device_type import largeTensorTest
import transformer_engine.pytorch as te
from transformer_engine.common.recipe import DelayedScaling
from transformer_engine.pytorch.attention import MultiheadAttention
from transformer_engine.pytorch import fp8_model_init
from transformer_engine.pytorch.utils import is_bf16_compatible
from transformer_engine.pytorch.fp8 import FP8GlobalStateManager
# Check if FP8 is supported
fp8_available, reason_for_no_fp8 = FP8GlobalStateManager.is_fp8_available()
class TestFusedOptimizer:
def setup_method(self, *, iters: int = 7) -> None:
self.iters = iters
torch.manual_seed(9876)
def gen_param_optim(self, tensors, options, tst_options=None):
# For backward compatibility with existing tests: if "tst_options" is not
# provided, it falls back to "options", which holds the parameters for the
# reference optimizer.
if tst_options is None:
tst_options = options
ref_param = []
tst_param = []
for tensor in tensors:
ref_param.append(torch.nn.Parameter(tensor.clone()))
tst_param.append(torch.nn.Parameter(tensor.clone()))
ref_optim = self.ref_optim(ref_param, **options)
tst_optim = self.fused_optim(tst_param, **tst_options)
return (ref_param, tst_param, ref_optim, tst_optim)
def gen_grad(self, ref_param, tst_param):
for p_ref, p_tst in zip(ref_param, tst_param):
p_ref.grad = torch.rand_like(p_ref)
p_tst.grad = p_ref.grad
def gen_mixed_grad(self, ref_param, tst_param, scale=1.0):
half_grads = []
for p_ref, p_tst in zip(ref_param, tst_param):
half_grads.append(torch.rand_like(p_ref).half())
p_ref.grad = half_grads[-1].float() / scale
return half_grads
def gen_single_type_test(
self, param_type=torch.float, device="cuda", *, skip_assert: bool = False
):
nelem = 278011
# Some reference and test optimizers may require different sets of options.
# As a quick workaround that keeps changes to existing code minimal, the test
# optimizer is initialized with the reference optimizer's options when no
# "tst_options" attribute is provided.
if not hasattr(self, "tst_options"):
self.tst_options = self.options
tensor = torch.rand(nelem, dtype=param_type, device=device)
ref_param, tst_param, ref_optim, tst_optim = self.gen_param_optim(
[tensor], self.options, self.tst_options
)
for i in range(self.iters):
self.gen_grad(ref_param, tst_param)
ref_optim.step()
tst_optim.step()
if skip_assert:
return
torch.testing.assert_close(ref_param, tst_param)
class TestFusedAdam(TestFusedOptimizer):
def setup_method(self) -> None:
super().setup_method()
self.options = {
"lr": 5e-4,
"betas": (0.9, 0.999),
"eps": 1e-08,
"weight_decay": 0,
"amsgrad": False,
}
self.ref_optim = torch.optim.Adam
self.fused_optim = te.optimizers.FusedAdam
def test_float(self):
self.gen_single_type_test(param_type=torch.float)
# NOTE(mkozuki): Current threshold values look too small for BFloat16.
# TODO(mkozuki): Refactor `TestFusedOptimizer`
def test_half(self):
self.gen_single_type_test(param_type=torch.float16, skip_assert=True)
def test_bfloat16(self):
self.gen_single_type_test(param_type=torch.bfloat16, skip_assert=True)
@pytest.mark.skipif(torch.cuda.device_count() < 2, reason="more than 1 GPU required")
def test_multi_device(self):
devices = ("cuda:0", "cuda:1")
for current_dev, tensor_dev in product(devices, devices):
with torch.cuda.device(current_dev):
self.gen_single_type_test(param_type=torch.float, device=tensor_dev)
def test_multi_params(self):
sizes = [[4096, 1024], [4096], [4096, 2048], [32320, 1024], [1]]
tensors = []
for size in sizes:
tensors.append(torch.rand(size, dtype=torch.float, device="cuda"))
ref_param, tst_param, ref_optim, tst_optim = self.gen_param_optim(tensors, self.options)
for i in range(self.iters):
self.gen_grad(ref_param, tst_param)
ref_optim.step()
tst_optim.step()
torch.testing.assert_close(ref_param, tst_param)
def test_adam_option(self):
nelem = 1
adam_option = {
"lr": 0.01,
"betas": (0.6, 0.9),
"eps": 3e-06,
"weight_decay": 0,
"amsgrad": False,
}
tensor = torch.rand(nelem, dtype=torch.float, device="cuda")
ref_param, tst_param, ref_optim, tst_optim = self.gen_param_optim([tensor], adam_option)
for i in range(self.iters):
self.gen_grad(ref_param, tst_param)
ref_optim.step()
tst_optim.step()
torch.testing.assert_close(ref_param, tst_param)
def test_frozen_model(self):
nelem = 1
adam_option = {
"lr": 0.01,
"betas": (0.6, 0.9),
"eps": 3e-06,
"weight_decay": 0,
"amsgrad": False,
}
tensor = torch.rand(nelem, dtype=torch.float, device="cuda")
ref_param, tst_param, ref_optim, tst_optim = self.gen_param_optim([tensor], adam_option)
# Add an empty param group which may occur for pipeline parallel p-tuning
tst_optim.add_param_group({"params": []})
for i in range(self.iters):
self.gen_grad(ref_param, tst_param)
ref_optim.step()
tst_optim.step()
torch.testing.assert_close(ref_param, tst_param)
def gen_precision_aware_test(
self,
use_fp8_params,
param_dtype,
use_master_weights,
master_weight_dtype,
grad_dtype,
exp_avg_dtype,
exp_avg_sq_dtype,
store_param_remainders=False,
model_rtol=None,
model_atol=None,
master_rtol=None,
master_atol=None,
skip_assert=False,
):
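# store_param_remainders lets FusedAdam reconstruct FP32 master weights from BF16
# params plus stored low-order bits instead of keeping full FP32 copies; note that
# master weights are only compared against the FP32 reference when it is disabled
# (see the check below).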
build_model_context = nullcontext
build_model_context_args = {}
if use_fp8_params:
build_model_context = fp8_model_init
build_model_context_args["enabled"] = True
with build_model_context(**build_model_context_args):
model = MultiheadAttention(
hidden_size=1024,
num_attention_heads=16,
layer_number=1,
params_dtype=param_dtype,
fuse_qkv_params=True,
).cuda()
ref_params = []
model_params = []
for p in model.parameters():
if p.requires_grad:
ref_params.append(p.detach().clone().float())
model_params.append(p)
options = {
"lr": 1,
"betas": (0.1, 0.25),
"eps": 1e-08,
"weight_decay": 0,
"amsgrad": False,
}
ref_optim = torch.optim.Adam(ref_params, **options)
tst_optim = te.optimizers.FusedAdam(
model_params,
master_weights=use_master_weights,
master_weight_dtype=master_weight_dtype,
exp_avg_dtype=exp_avg_dtype,
exp_avg_sq_dtype=exp_avg_sq_dtype,
use_decoupled_grad=True,
store_param_remainders=store_param_remainders,
**options,
)
def test_one_iteration(ref_optimizer, tst_optimizer):
for p_ref, p in zip(ref_params, model_params):
p_ref.grad = torch.rand_like(p_ref)
p.decoupled_grad = p_ref.grad.clone().to(grad_dtype)
ref_optimizer.step()
tst_optimizer.step()
if use_master_weights and not store_param_remainders:
master_weights_to_fp32 = [
tst_optim.get_unscaled_state(p, "master_param") for p in model_params
]
if not skip_assert:
torch.testing.assert_close(
ref_params,
master_weights_to_fp32,
rtol=master_rtol,
atol=master_atol,
equal_nan=True,
)
ref_params_to_model_dtype = [p.to(param_dtype) for p in ref_params]
if not skip_assert:
torch.testing.assert_close(
ref_params_to_model_dtype,
model_params,
rtol=model_rtol,
atol=model_atol,
equal_nan=True,
)
for i in range(self.iters):
test_one_iteration(ref_optim, tst_optim)
state_dict = tst_optim.state_dict()
tst_optim = te.optimizers.FusedAdam(
model_params,
master_weights=use_master_weights,
master_weight_dtype=master_weight_dtype,
exp_avg_dtype=exp_avg_dtype,
exp_avg_sq_dtype=exp_avg_sq_dtype,
use_decoupled_grad=True,
store_param_remainders=store_param_remainders,
**options,
)
tst_optim.load_state_dict(state_dict)
for i in range(self.iters):
test_one_iteration(ref_optim, tst_optim)
def test_fp32_no_master(self):
self.gen_precision_aware_test(
use_fp8_params=False,
param_dtype=torch.float32,
use_master_weights=False,
master_weight_dtype=torch.float32,
grad_dtype=torch.float32,
exp_avg_dtype=torch.float32,
exp_avg_sq_dtype=torch.float32,
)
@pytest.mark.skipif(not is_bf16_compatible(), reason="bf16 is not supported")
def test_fp32_master(self):
self.gen_precision_aware_test(
use_fp8_params=False,
param_dtype=torch.bfloat16,
use_master_weights=True,
master_weight_dtype=torch.float32,
grad_dtype=torch.float32,
exp_avg_dtype=torch.float32,
exp_avg_sq_dtype=torch.float32,
)
@pytest.mark.skipif(not is_bf16_compatible(), reason="bf16 is not supported")
def test_fp32_master_store_param_remainders(self):
self.gen_precision_aware_test(
use_fp8_params=False,
param_dtype=torch.bfloat16,
use_master_weights=True,
master_weight_dtype=torch.float32,
grad_dtype=torch.float32,
exp_avg_dtype=torch.float32,
exp_avg_sq_dtype=torch.float32,
store_param_remainders=True,
)
@pytest.mark.skipif(not is_bf16_compatible(), reason="bf16 is not supported")
def test_fp16_master(self):
self.gen_precision_aware_test(
use_fp8_params=False,
param_dtype=torch.bfloat16,
use_master_weights=True,
master_weight_dtype=torch.half,
grad_dtype=torch.float32,
exp_avg_dtype=torch.float32,
exp_avg_sq_dtype=torch.float32,
master_rtol=2e-3,
master_atol=2e-3,
)
@pytest.mark.skipif(not is_bf16_compatible(), reason="bf16 is not supported")
def test_bf16_grad(self):
self.gen_precision_aware_test(
use_fp8_params=False,
param_dtype=torch.bfloat16,
use_master_weights=True,
master_weight_dtype=torch.float32,
grad_dtype=torch.bfloat16,
exp_avg_dtype=torch.float32,
exp_avg_sq_dtype=torch.float32,
master_rtol=2e-3,
master_atol=2e-3,
)
@pytest.mark.skipif(not is_bf16_compatible(), reason="bf16 is not supported")
def test_fp16_exp_avg(self):
self.gen_precision_aware_test(
use_fp8_params=False,
param_dtype=torch.bfloat16,
use_master_weights=True,
master_weight_dtype=torch.float32,
grad_dtype=torch.float32,
exp_avg_dtype=torch.half,
exp_avg_sq_dtype=torch.float32,
master_rtol=2e-3,
master_atol=2e-3,
)
@pytest.mark.skipif(not is_bf16_compatible(), reason="bf16 is not supported")
@pytest.mark.skipif(not fp8_available, reason=reason_for_no_fp8)
def test_fp8_exp_avg(self):
self.gen_precision_aware_test(
use_fp8_params=False,
param_dtype=torch.bfloat16,
use_master_weights=True,
master_weight_dtype=torch.float32,
grad_dtype=torch.float32,
exp_avg_dtype=torch.uint8,
exp_avg_sq_dtype=torch.float32,
master_rtol=1e-2,
master_atol=1e-2,
)
@pytest.mark.skipif(not is_bf16_compatible(), reason="bf16 is not supported")
def test_fp16_exp_avg_sq(self):
self.gen_precision_aware_test(
use_fp8_params=False,
param_dtype=torch.bfloat16,
use_master_weights=True,
master_weight_dtype=torch.float32,
grad_dtype=torch.float32,
exp_avg_dtype=torch.float32,
exp_avg_sq_dtype=torch.half,
master_rtol=2e-3,
master_atol=2e-3,
)
@pytest.mark.skipif(not is_bf16_compatible(), reason="bf16 is not supported")
@pytest.mark.skipif(not fp8_available, reason=reason_for_no_fp8)
def test_fp8_exp_avg_sq(self):
self.gen_precision_aware_test(
use_fp8_params=False,
param_dtype=torch.bfloat16,
use_master_weights=True,
master_weight_dtype=torch.float32,
grad_dtype=torch.float32,
exp_avg_dtype=torch.float32,
exp_avg_sq_dtype=torch.uint8,
skip_assert=True,
)
@pytest.mark.skipif(not is_bf16_compatible(), reason="bf16 is not supported")
def test_bf16_model_weight_cast(self):
dtype = torch.bfloat16
model = MultiheadAttention(
hidden_size=1024,
num_attention_heads=16,
layer_number=1,
params_dtype=dtype,
fuse_qkv_params=True,
).cuda()
ref_params = []
model_params = []
for p in model.parameters():
if p.requires_grad:
ref_params.append(p.detach().clone().float())
model_params.append(p)
options = {
"lr": 5e-4,
"betas": (0.9, 0.999),
"eps": 1e-08,
"weight_decay": 0,
"amsgrad": False,
}
ref_optim = torch.optim.Adam(ref_params, **options)
tst_optim = te.optimizers.FusedAdam(
model_params, master_weights=True, use_decoupled_grad=True, **options
)
for i in range(self.iters):
for p_ref, p in zip(ref_params, model_params):
p_ref.grad = torch.rand_like(p_ref)
p.decoupled_grad = p_ref.grad.clone()
ref_optim.step()
tst_optim.step()
master_params = [tst_optim.get_unscaled_state(p, "master_param") for p in model_params]
torch.testing.assert_close(ref_params, master_params)
model_params_to_fp32 = [p.float() for p in model_params]
torch.testing.assert_close(
ref_params, model_params_to_fp32, rtol=1e-3, atol=1e-3, equal_nan=True
)
@pytest.mark.skipif(not fp8_available, reason=reason_for_no_fp8)
def test_fp8_model_weight_cast(self):
dtype = torch.bfloat16
with fp8_model_init(enabled=True, recipe=DelayedScaling()):
model = MultiheadAttention(
hidden_size=1024,
num_attention_heads=16,
layer_number=1,
params_dtype=dtype,
fuse_qkv_params=True,
).cuda()
ref_params = []
model_params = []
for p in model.parameters():
if p.requires_grad:
ref_params.append(p.detach().clone().float())
model_params.append(p)
options = {
"lr": 5e-4,
"betas": (0.9, 0.999),
"eps": 1e-08,
"weight_decay": 0,
"amsgrad": False,
}
ref_optim = torch.optim.Adam(ref_params, **options)
tst_optim = te.optimizers.FusedAdam(
model_params, master_weights=True, use_decoupled_grad=True, **options
)
for i in range(self.iters):
for p_ref, p in zip(ref_params, model_params):
p_ref.grad = torch.rand_like(p_ref)
p.decoupled_grad = p_ref.grad.clone()
ref_optim.step()
tst_optim.step()
master_params = [tst_optim.get_unscaled_state(p, "master_param") for p in model_params]
torch.testing.assert_close(ref_params, master_params)
model_params_to_fp32 = [p.float() for p in model_params]
torch.testing.assert_close(
ref_params, model_params_to_fp32, rtol=1e-2, atol=1e-2, equal_nan=True
)
class TestFusedSGD(TestFusedOptimizer):
def setup_method(self) -> None:
super().setup_method()
self.options = {"lr": 0.25, "momentum": 0.125}
self.ref_optim = torch.optim.SGD
self.fused_optim = te.optimizers.FusedSGD
def test_float(self):
self.gen_single_type_test(param_type=torch.float)
def test_half(self):
self.gen_single_type_test(param_type=torch.float16)
@pytest.mark.skipif(torch.cuda.device_count() < 2, reason="more than 1 GPU required")
def test_multi_device(self):
devices = ("cuda:0", "cuda:1")
for current_dev, tensor_dev in product(devices, devices):
with torch.cuda.device(current_dev):
self.gen_single_type_test(param_type=torch.float, device=tensor_dev)
class Model(torch.nn.Module):
def __init__(self):
super(Model, self).__init__()
self.conv1 = nn.Conv2d(1, 6, 5)
self.relu1 = nn.ReLU()
self.pool1 = nn.MaxPool2d(2)
self.conv2 = nn.Conv2d(6, 16, 5)
self.relu2 = nn.ReLU()
self.pool2 = nn.MaxPool2d(2)
self.fc1 = nn.Linear(256, 120)
self.relu3 = nn.ReLU()
self.fc2 = nn.Linear(120, 84)
self.relu4 = nn.ReLU()
self.fc3 = nn.Linear(84, 10)
self.relu5 = nn.ReLU()
def forward(self, x):
y = self.conv1(x)
y = self.relu1(y)
y = self.pool1(y)
y = self.conv2(y)
y = self.relu2(y)
y = self.pool2(y)
y = y.reshape(y.shape[0], -1)
y = self.fc1(y)
y = self.relu3(y)
y = self.fc2(y)
y = self.relu4(y)
y = self.fc3(y)
y = self.relu5(y)
return y
class AdamTest:
def setup_method(self, *, seed: int = 0) -> None:
torch.manual_seed(seed)
self.model = Model().cuda()
self.model_ = Model().cuda()
self.model_.load_state_dict(copy.deepcopy(self.model.state_dict()))
self.lr = 0.00001
params = [p for p in self.model.parameters() if p.requires_grad]
self.optimizer = torch.optim.Adam(params, lr=self.lr)
def test_grad_scaler(self):
params_ = [p for p in self.model_.parameters() if p.requires_grad]
optimizer_ = te.optimizers.FusedAdam(params_, lr=self.lr, capturable=False)
scaler = torch.cuda.amp.GradScaler(enabled=True)
scaler_ = torch.cuda.amp.GradScaler(enabled=True)
for i in range(100):
x = torch.rand([32, 1, 28, 28]).cuda().to(memory_format=torch.channels_last)
x_ = x.clone()
gt = torch.rand([32, 10]).cuda()
gt_ = gt.clone()
# Reference
with torch.cuda.amp.autocast(enabled=True):
y = self.model(x)
loss = ((gt - y) ** 2).mean()
scaler.scale(loss).backward()
scaler.step(self.optimizer)
scaler.update()
# DUT
with torch.cuda.amp.autocast(enabled=True):
y = self.model_(x)
loss_ = ((gt_ - y) ** 2).mean()
scaler_.scale(loss_).backward()
scaler_.step(optimizer_)
scaler_.update()
for module in zip(self.model.modules(), self.model_.modules()):
m = module[0]
m_ = module[1]
if isinstance(m, nn.Conv2d) or isinstance(m_, nn.Linear):
torch.testing.assert_close(
m.weight, m_.weight, atol=1e-3, rtol=1e-3, equal_nan=True
)
torch.testing.assert_close(
m.weight.grad,
m_.weight.grad,
atol=1e-3,
rtol=1e-3,
equal_nan=True,
)
# Init for next iteration
self.optimizer.zero_grad()
optimizer_.zero_grad()
self.model_.load_state_dict(copy.deepcopy(self.model.state_dict()))
def test_grad_scaler_capturable(self):
params_ = [p for p in self.model_.parameters() if p.requires_grad]
optimizer_ = te.optimizers.FusedAdam(params_, lr=self.lr, capturable=True)
scaler = torch.cuda.amp.GradScaler(enabled=True)
scaler_ = torch.cuda.amp.GradScaler(enabled=True)
for i in range(100):
x = torch.rand([32, 1, 28, 28]).cuda().to(memory_format=torch.channels_last)
x_ = x.clone()
gt = torch.rand([32, 10]).cuda()
gt_ = gt.clone()
# Reference
with torch.cuda.amp.autocast(enabled=True):
y = self.model(x)
loss = ((gt - y) ** 2).mean()
scaler.scale(loss).backward()
scaler.step(self.optimizer)
scaler.update()
# DUT
with torch.cuda.amp.autocast(enabled=True):
y = self.model_(x)
loss_ = ((gt_ - y) ** 2).mean()
scaler_.scale(loss_).backward()
scaler_.step(optimizer_)
scaler_.update()
for module in zip(self.model.modules(), self.model_.modules()):
m = module[0]
m_ = module[1]
if isinstance(m, nn.Conv2d) or isinstance(m_, nn.Linear):
torch.testing.assert_close(
m.weight, m_.weight, atol=1e-3, rtol=1e-3, equal_nan=True
)
torch.testing.assert_close(
m.weight.grad,
m_.weight.grad,
atol=1e-3,
rtol=1e-3,
equal_nan=True,
)
# Init for next iteration
self.optimizer.zero_grad()
optimizer_.zero_grad()
self.model_.load_state_dict(copy.deepcopy(self.model.state_dict()))
def test_grad_scaler_capturable_master(self):
# Cast conv layers to FP16
for m in self.model_.modules():
if m.__class__ in [torch.nn.Conv2d]:
m.half()
params_ = [p for p in self.model_.parameters() if p.requires_grad]
master_weights = [p.float() for p in self.model_.parameters() if p.requires_grad]
optimizer_ = te.optimizers.FusedAdam(
params_, lr=self.lr, capturable=True, master_weights=master_weights
)
scaler = torch.cuda.amp.GradScaler(enabled=True)
scaler_ = torch.cuda.amp.GradScaler(enabled=True)
for i in range(100):
x = torch.rand([32, 1, 28, 28]).cuda().to(memory_format=torch.channels_last)
x_ = x.clone()
gt = torch.rand([32, 10]).cuda()
gt_ = gt.clone()
# Reference
with torch.cuda.amp.autocast(enabled=True):
y = self.model(x)
loss = ((gt - y) ** 2).mean()
scaler.scale(loss).backward()
scaler.step(self.optimizer)
scaler.update()
# DUT
with torch.cuda.amp.autocast(enabled=True):
y = self.model_(x)
loss_ = ((gt_ - y) ** 2).mean()
scaler_.scale(loss_).backward()
scaler_.step(optimizer_)
scaler_.update()
for module in zip(self.model.modules(), self.model_.modules()):
m = module[0]
m_ = module[1]
if isinstance(m, nn.Conv2d) or isinstance(m_, nn.Linear):
torch.testing.assert_close(
m.weight,
m_.weight.float(),
atol=1e-3,
rtol=1e-3,
equal_nan=True,
)
torch.testing.assert_close(
m.weight.grad,
m_.weight.grad.float(),
atol=1e-3,
rtol=1e-3,
equal_nan=True,
)
# Init for next iteration
self.optimizer.zero_grad()
optimizer_.zero_grad()
self.model_.load_state_dict(copy.deepcopy(self.model.state_dict()))
def test_native(self):
params_ = [p for p in self.model_.parameters() if p.requires_grad]
optimizer_ = te.optimizers.FusedAdam(params_, lr=self.lr, capturable=False)
for i in range(100):
x = torch.rand([32, 1, 28, 28]).cuda().to(memory_format=torch.channels_last)
x_ = x.clone()
gt = torch.rand([32, 10]).cuda()
gt_ = gt.clone()
# Reference
y = self.model(x)
loss = ((gt - y) ** 2).mean()
loss.backward()
self.optimizer.step()
# DUT
y = self.model_(x)
loss_ = ((gt_ - y) ** 2).mean()
loss_.backward()
optimizer_.step()
for module in zip(self.model.modules(), self.model_.modules()):
m = module[0]
m_ = module[1]
if isinstance(m, nn.Conv2d) or isinstance(m_, nn.Linear):
torch.testing.assert_close(
m.weight, m_.weight, atol=1e-3, rtol=1e-3, equal_nan=True
)
torch.testing.assert_close(
m.weight.grad,
m_.weight.grad,
atol=1e-3,
rtol=1e-3,
equal_nan=True,
)
# Init for next iteration
self.optimizer.zero_grad()
optimizer_.zero_grad()
self.model_.load_state_dict(copy.deepcopy(self.model.state_dict()))
@largeTensorTest("60GB", "cuda")
def test_large_tensor(self):
t = torch.zeros(2359332864, dtype=torch.half, device="cuda")
t2 = torch.zeros(2359332864, dtype=torch.half, device="cuda")
grad = torch.randn_like(t)
t.grad = grad
t2.grad = grad
params = [t]
params2 = [t2]
optimizer = te.optimizers.FusedAdam(params, lr=self.lr)
optimizer.step()
optimizer2 = torch.optim.Adam(params2, lr=self.lr)
optimizer2.step()
torch.testing.assert_close(t, t2)
torch.cuda.synchronize()
# Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# See LICENSE for license information.
import math
import pytest
import torch
from typing import Callable, Tuple, Union
from transformer_engine.pytorch.dot_product_attention.rope import (
RotaryPositionEmbedding,
apply_rotary_pos_emb,
)
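# Under context parallelism each rank holds two chunks of every sequence: one from the
# front half and the mirrored chunk from the back half (for load balancing). Slice the
# RoPE freqs to match the chunks owned by `cp_rank`.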
def _get_thd_freqs_on_this_cp_rank(
cp_rank: int, cp_size: int, x: torch.Tensor, freqs: torch.Tensor
) -> torch.Tensor:
if cp_size > 1:
cp_seg = x.size(0) // 2
full_seqlen = cp_size * x.size(0)
return torch.cat(
[
freqs[cp_rank * cp_seg : (cp_rank + 1) * cp_seg],
freqs[full_seqlen - (cp_rank + 1) * cp_seg : full_seqlen - cp_rank * cp_seg],
]
)
else:
return freqs[: x.size(0)]
def apply_rotary_pos_emb_thd(
t: torch.Tensor,
cu_seqlens: torch.Tensor,
freqs: torch.Tensor,
cp_size: int = 1,
cp_rank: int = 0,
) -> torch.Tensor:
"""A baseline implementation of applying RoPE for `thd` format.
Args:
t (Tensor): Input tensor T is of shape [t, h, d]
cu_seqlens(Tensor): Cumulative sum of sequence lengths in a batch for `t`,
with shape [b + 1] and dtype torch.int32.
freqs (Tensor): Rotary Positional embedding tensor freq is of shape [max_s, 1, 1, d]
Returns:
Tensor: Shape [t, h, d]. The input tensor after applying RoPE.
"""
cu_seqlens = cu_seqlens // cp_size
seqlens = (cu_seqlens[1:] - cu_seqlens[:-1]).tolist()
return torch.cat(
[
apply_rotary_pos_emb(
x.unsqueeze(1), _get_thd_freqs_on_this_cp_rank(cp_rank, cp_size, x, freqs)
)
for x in torch.split(t, seqlens)
]
).squeeze(1)
# Gradient is a broadcasted scalar
def _overlapping_grad(output: torch.Tensor) -> torch.Tensor:
return output.sum() * 2
# Gradient is a full tensor
def _non_overlapping_grad(output: torch.Tensor) -> torch.Tensor:
t = torch.ones_like(output)
return torch.sum(output * t)
@pytest.mark.parametrize("dtype", [torch.float32, torch.bfloat16, torch.float16])
@pytest.mark.parametrize("seq_length", [2048, 4096])
@pytest.mark.parametrize("hidden_size", [128, 256])
@pytest.mark.parametrize("rotary_percent", [0.5, 1.0])
@pytest.mark.parametrize("margin", [0, 10])
@pytest.mark.parametrize("transpose", [None, (0, 1), (2, 3)])
@pytest.mark.parametrize("tensor_format", ["sbhd", "bshd"])
@pytest.mark.parametrize("loss_func", [_overlapping_grad, _non_overlapping_grad])
def test_fused_rope(
dtype: torch.dtype,
seq_length: int,
hidden_size: int,
rotary_percent: float,
margin: int,
transpose: Union[Tuple, None],
tensor_format: str,
loss_func: Callable,
) -> None:
device = torch.device("cuda:0")
batch_size, head_num = 2, 64
t = torch.rand(
(seq_length - margin, batch_size, head_num, hidden_size),
dtype=dtype,
device=device,
)
if tensor_format == "bshd":
t = t.transpose(0, 1).contiguous()
if transpose:
t = t.transpose(*transpose).contiguous().transpose(*transpose)
t.requires_grad = True
rotary_pos_emb = RotaryPositionEmbedding(hidden_size, rotary_percent)
emb = rotary_pos_emb(seq_length)
# unfused
# The fused kernel computes in float32 internally, so we force the unfused func to use float32
# for more accurate comparison
output_unfused = apply_rotary_pos_emb(
t.float(), emb, tensor_format=tensor_format, fused=False
).to(dtype)
loss_unfused = loss_func(output_unfused)
loss_unfused.backward()
grad_unfused = t.grad.detach().clone()
t.grad = None
# fused
output_fused = apply_rotary_pos_emb(
t,
emb,
tensor_format=tensor_format,
fused=True,
)
loss_fused = loss_func(output_fused)
loss_fused.backward()
grad_fused = t.grad.detach().clone()
t.grad = None
torch.testing.assert_close(output_fused, output_unfused)
torch.testing.assert_close(grad_fused, grad_unfused)
assert output_fused.is_contiguous()
@pytest.mark.parametrize("dtype", [torch.float32, torch.bfloat16, torch.float16])
@pytest.mark.parametrize("hidden_size", [128, 256])
@pytest.mark.parametrize("rotary_percent", [0.5, 1.0])
@pytest.mark.parametrize("transpose", [None, (1, 2)])
@pytest.mark.parametrize("loss_func", [_overlapping_grad, _non_overlapping_grad])
@pytest.mark.parametrize("cp_size", [1, 2, 3])
def test_fused_rope_thd(
dtype: torch.dtype,
hidden_size: int,
rotary_percent: float,
transpose: Union[Tuple, None],
loss_func: Callable,
cp_size: int,
) -> None:
device = torch.device("cuda:0")
batch_size, head_num = 2, 64
cu_seqlens = [0, 400, 542, 711, 727, 752, 1270, 1426, 1450, 1954, 2044, 2048]
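# With context parallelism, pad each sequence length to a multiple of 2 * cp_size so it
# can be split into the two equal per-rank chunks used for load balancing.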
if cp_size > 1:
cu_seqlens_padded = [0]
for i in range(1, len(cu_seqlens)):
cu_seqlens_padded.append(
cu_seqlens_padded[i - 1]
+ math.ceil((cu_seqlens[i] - cu_seqlens[i - 1]) / (cp_size * 2)) * (cp_size * 2)
)
else:
cu_seqlens_padded = cu_seqlens
cu_seqlens_padded = torch.tensor(
cu_seqlens_padded,
dtype=torch.int32,
device=device,
)
t = torch.rand(
(cu_seqlens_padded[-1] // cp_size, head_num, hidden_size),
dtype=dtype,
device=device,
)
if transpose:
t = t.transpose(*transpose).contiguous().transpose(*transpose)
t.requires_grad = True
rotary_pos_emb = RotaryPositionEmbedding(hidden_size, rotary_percent)
emb = rotary_pos_emb(cu_seqlens_padded[-1])
for cp_rank in range(cp_size):
# unfused
# The fused kernel computes in float32 internally, so we force the unfused func to use float32
# for more accurate comparison
output_unfused = apply_rotary_pos_emb_thd(
t.float(), cu_seqlens_padded, emb, cp_size, cp_rank
).to(dtype)
loss_unfused = loss_func(output_unfused)
loss_unfused.backward()
grad_unfused = t.grad.detach().clone()
t.grad = None
# fused
output_fused = apply_rotary_pos_emb(
t,
emb,
fused=True,
tensor_format="thd",
cu_seqlens=cu_seqlens_padded,
cp_size=cp_size,
cp_rank=cp_rank,
)
loss_fused = loss_func(output_fused)
loss_fused.backward()
grad_fused = t.grad.detach().clone()
t.grad = None
torch.testing.assert_close(output_fused, output_unfused)
torch.testing.assert_close(grad_fused, grad_unfused)
# Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# See LICENSE for license information.
from __future__ import annotations
from collections.abc import Iterable
import math
from typing import Optional
import pytest
import torch
import transformer_engine
import transformer_engine.common.recipe
import transformer_engine.pytorch as te
from transformer_engine.pytorch.fp8 import FP8GlobalStateManager
import transformer_engine.pytorch.ops as te_ops
from transformer_engine.pytorch.ops._common import is_float8_tensor
from transformer_engine.pytorch.ops.fused import (
BackwardLinearAdd,
ForwardLinearBiasActivation,
ForwardLinearBiasAdd,
)
from transformer_engine.pytorch.tensor import QuantizedTensor
from transformer_engine.pytorch.tensor.float8_tensor import Float8Tensor, Float8Quantizer
from transformer_engine.pytorch.utils import is_bf16_compatible
import transformer_engine_torch as tex
# Check if FP8 is supported
fp8_available, reason_for_no_fp8 = FP8GlobalStateManager.is_fp8_available()
mxfp8_available, reason_for_no_mxfp8 = FP8GlobalStateManager.is_mxfp8_available()
# Supported data types
_dtypes: list[torch.dtype] = [torch.float32, torch.float16]
if is_bf16_compatible(): # bf16 requires sm_80 or higher
_dtypes.append(torch.bfloat16)
# Supported devices
_devices: list[torch.device] = [torch.device("cpu"), torch.device("cuda")]
def maybe_skip_quantization(
quantization: Optional[str],
*,
dims: Optional[Iterable[int] | int] = None,
device: Optional[torch.device | str] = None,
) -> None:
# Don't skip if there is no quantization
if quantization is None:
return
# Check if quantization scheme is supported
if quantization == "fp8" and not fp8_available:
pytest.skip(reason_for_no_fp8)
if quantization == "mxfp8" and not mxfp8_available:
pytest.skip(reason_for_no_mxfp8)
if dims is not None:
if not isinstance(dims, Iterable):
dims = (dims,)
if quantization == "fp8":
if math.prod(dims[:-1]) % 16 != 0 or dims[-1] % 16 != 0:
pytest.skip("FP8 GEMMs require dims that are divisible by 16")
elif quantization == "mxfp8":
if math.prod(dims[:-1]) % 32 != 0 or dims[-1] % 32 != 0:
pytest.skip("MXFP8 GEMMs require dims that are divisible by 32")
# Check if device is supported
if device is not None and torch.device(device).type != "cuda":
pytest.skip("Quantization is only supported on CUDA devices")
def dtype_tols(dtype: torch.dtype | tex.DType) -> dict[str, float]:
"""Estimated numerical error for a datatype
Based on tolerances for torch.testing.assert_close.
"""
# Transformer Engine dtypes
if isinstance(dtype, tex.DType):
if dtype == tex.DType.kFloat8E4M3:
return dict(rtol=0.125, atol=0.0675) # epsilon = 0.0625
if dtype == tex.DType.kFloat8E5M2:
return dict(rtol=0.25, atol=0.125) # epsilon = 0.125
dtype = {
tex.DType.kByte: torch.uint8,
tex.DType.kInt32: torch.int32,
tex.DType.kFloat32: torch.float32,
tex.DType.kFloat16: torch.half,
tex.DType.kBFloat16: torch.bfloat16,
}[dtype]
# PyTorch dtypes
if dtype == torch.float16:
return dict(rtol=1e-3, atol=1e-5)
if dtype == torch.bfloat16:
return dict(rtol=1.6e-2, atol=1e-5)
if dtype == torch.float32:
return dict(rtol=1.3e-6, atol=1e-5)
if dtype == torch.float64:
return dict(rtol=1e-7, atol=1e-7)
raise ValueError(f"Unsupported dtype ({dtype})")
@torch.no_grad()
def make_reference_and_test_tensors(
shape: int | Iterable[int],
ref_dtype: torch.dtype = torch.float64,
ref_device: torch.device = "cpu",
test_dtype: torch.dtype = torch.float32,
test_device: torch.device = "cuda",
test_is_fp8: bool = False,
requires_grad: bool = True,
) -> tuple[torch.Tensor, torch.Tensor]:
"""Construct tensors with the same values
The reference tensor is intended for use in plain PyTorch
operations in high precision. The test tensor is intended for use
in Transformer Engine operations.
"""
ref = torch.rand(shape, dtype=ref_dtype, device=ref_device)
test = ref.to(device=test_device, dtype=test_dtype)
if test_is_fp8:
quantizer = Float8Quantizer(
scale=torch.ones(1, dtype=torch.float32, device=test_device).squeeze(),
amax=torch.zeros(1, dtype=torch.float32, device=test_device),
fp8_dtype=tex.DType.kFloat8E4M3,
)
test = quantizer(test)
elif test.data_ptr() == ref.data_ptr():
test = test.clone()
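# Sync the reference with the (possibly quantized) test values so later comparisons
# measure only the error of the op under test, not the initial cast.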
ref.copy_(test)
ref.requires_grad_(requires_grad)
test.requires_grad_(requires_grad)
return ref, test
def make_recipe(name: Optional[str] = None) -> Optional[Recipe]:
"""Make recipe for quantization scheme"""
if name is None:
return None
if name == "fp8":
return transformer_engine.common.recipe.DelayedScaling(
fp8_format=transformer_engine.common.recipe.Format.E4M3,
)
if name == "mxfp8":
return transformer_engine.common.recipe.MXFP8BlockScaling(
fp8_format=transformer_engine.common.recipe.Format.E4M3,
)
raise ValueError(f"Unsupported quantization scheme ({name})")
class TestSequential:
"""Tests for sequential container"""
def test_modules(self) -> None:
"""Check that list of modules can be manipulated as expected"""
# Construct sequential container
modules = [
te_ops.Identity(),
te_ops.Identity(),
torch.nn.Identity(),
te_ops.Identity(),
]
model = te_ops.Sequential(*modules)
# Length
assert len(model) == len(modules)
# Iterator
for module1, module2 in zip(model, modules):
assert module1 is module2
# Index by int
for i, module in enumerate(modules):
assert model[i] is module
assert model[i - len(modules)] is module
# Index by slice
model_subset = model[1:-1]
modules_subset = modules[1:-1]
assert isinstance(model_subset, te_ops.Sequential)
for module1, module2 in zip(model_subset, modules_subset):
assert module1 is module2
# Set element
new_module = torch.nn.Identity()
idx = 1
modules[idx] = new_module
model[idx] = new_module
for module1, module2 in zip(model, modules):
assert module1 is module2
# Delete element
idx = 1
del modules[idx]
del model[idx]
for module1, module2 in zip(model, modules):
assert module1 is module2
# Append
new_module = torch.nn.Identity()
modules.append(new_module)
model.append(new_module)
for module1, module2 in zip(model, modules):
assert module1 is module2
# Extend
new_modules = [te_ops.Identity(), te_ops.Identity()]
modules.extend(new_modules)
model.extend(new_modules)
for module1, module2 in zip(model, modules):
assert module1 is module2
# Insert
new_module = te_ops.Identity()
idx = 2
modules.insert(idx, new_module)
model.insert(idx, new_module)
for module1, module2 in zip(model, modules):
assert module1 is module2
# Pop
idx = 2
assert model.pop(idx) is modules.pop(idx)
for module1, module2 in zip(model, modules):
assert module1 is module2
# Out-of-place add
new_modules = [torch.nn.Identity(), te_ops.Identity()]
added_modules = modules + new_modules
added_model = model + te_ops.Sequential(*new_modules)
for module1, module2 in zip(model, modules):
assert module1 is module2
for module1, module2 in zip(added_model, added_modules):
assert module1 is module2
# In-place add
new_modules = [te_ops.Identity(), torch.nn.Identity()]
modules += new_modules
model += te_ops.Sequential(*new_modules)
for module1, module2 in zip(model, modules):
assert module1 is module2
def test_module_groups(self) -> None:
"""Check that modules are grouped together correctly"""
model = te_ops.Sequential(
te_ops.Identity(),
te_ops.Identity(),
torch.nn.Identity(),
torch.nn.Identity(),
te_ops.Identity(),
torch.nn.Identity(),
te_ops.Identity(),
te_ops.Identity(),
te_ops.Identity(),
)
model(torch.zeros(1))
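# Expected grouping: [te, te], [nn], [nn], [te], [nn], [te, te, te] -> 6 module groups,
# since consecutive fusible te_ops modules are merged while each torch.nn module
# stands alone.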
assert len(model._module_groups) == 6
class TestFuser:
"""Tests for operation fusion infrastructure"""
@staticmethod
def setup_class(cls) -> None:
# Configure RNG
seed = 1234
torch.manual_seed(seed)
torch.cuda.manual_seed(seed)
@pytest.mark.skipif(not fp8_available, reason=reason_for_no_fp8)
def test_fp8_scale_update(
self,
size: int = 16,
dtype: torch.dtype = torch.float32,
device: torch.device = "cuda",
):
"""Test FP8 scaling factors with delayed scaling recipe"""
# FP8 recipe
margin = 2
fp8_format = transformer_engine.common.recipe.Format.HYBRID
recipe = transformer_engine.common.recipe.DelayedScaling(
margin=margin,
interval=1,
fp8_format=fp8_format,
amax_history_len=8,
amax_compute_algo="max",
)
# Construct model
with te.fp8_model_init(recipe=recipe):
model = te_ops.basic.BasicLinear(
size,
size,
device=device,
dtype=dtype,
)
# Training steps
w_vals = [2, 5, 3, 11]
x_vals = [7, 3, 5]
dy_vals = [1, 2, 1]
with torch.no_grad():
model.weight.fill_(w_vals[0])
for step in range(3):
# Data tensors
x = torch.full(
(size, size),
x_vals[step],
dtype=dtype,
device=device,
requires_grad=True,
)
dy = torch.full(
(size, size),
dy_vals[step],
dtype=dtype,
device=device,
)
# Training step
with te.fp8_autocast(fp8_recipe=recipe):
y = model(x)
y.backward(dy)
with torch.no_grad():
model.weight.fill_(w_vals[step + 1])
# Check that output tensors match expected
tols = dict(rtol=0, atol=0)
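# With constant-filled tensors, each linear output element is w * x summed over `size`
# input features; similarly each grad-input element is w * dy summed over `size` outputs.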
y_val_ref = w_vals[step] * x_vals[step] * size
dx_val_ref = w_vals[step] * dy_vals[step] * size
torch.testing.assert_close(
y,
torch.full_like(y, y_val_ref),
**dtype_tols(tex.DType.kFloat8E4M3),
)
torch.testing.assert_close(
x.grad,
torch.full_like(x.grad, dx_val_ref),
**dtype_tols(tex.DType.kFloat8E5M2),
)
# Check that scaling factors match expected
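# Delayed scaling computes scale = fp8_max / amax / 2**margin, with amax taken as the
# max over the recorded amax history ("max" algo). The HYBRID format uses the E4M3 max
# (max_fwd) for forward tensors and the E5M2 max (max_bwd) for grads.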
w_amax_ref = max(w_vals[: step + 1])
x_amax_ref = max(x_vals[: step + 1])
dy_amax_ref = max(dy_vals[: step + 1])
w_scale_ref = (fp8_format.value.max_fwd / w_amax_ref) / (2**margin)
x_scale_ref = (fp8_format.value.max_fwd / x_amax_ref) / (2**margin)
dy_scale_ref = (fp8_format.value.max_bwd / dy_amax_ref) / (2**margin)
w_scale = model.get_quantizer("forward", 1).scale
x_scale = model.get_quantizer("forward", 0).scale
dy_scale = model.get_quantizer("backward", 0).scale
torch.testing.assert_close(w_scale, torch.full_like(w_scale, w_scale_ref))
torch.testing.assert_close(x_scale, torch.full_like(x_scale, x_scale_ref))
torch.testing.assert_close(dy_scale, torch.full_like(dy_scale, dy_scale_ref))
@pytest.mark.parametrize("init_dtype", _dtypes)
@pytest.mark.parametrize("final_dtype", _dtypes)
@pytest.mark.parametrize("quantization", (None, "fp8", "mxfp8"))
def test_dtype_cast(
self,
*,
size: int = 32,
init_dtype: torch.dtype,
final_dtype: torch.dtype,
device: torch.device = "cuda",
quantization: Optional[str],
) -> None:
"""Check dtype cast functions"""
# Skip invalid configurations
maybe_skip_quantization(quantization, device=device)
with_quantization = quantization is not None
# Random data
dtype = torch.float32
if torch.float16 in (init_dtype, final_dtype):
dtype = torch.float16
if torch.bfloat16 in (init_dtype, final_dtype):
dtype = torch.bfloat16
w_ref, w_test = make_reference_and_test_tensors(
(size, size),
test_dtype=dtype,
test_device=device,
test_is_fp8=with_quantization,
)
# Construct operation
with te.fp8_model_init(enabled=with_quantization, recipe=make_recipe(quantization)):
op = te_ops.Linear(size, size, bias=False, device=device, dtype=init_dtype)
with torch.no_grad():
op.weight.copy_(w_test)
del w_test
# Cast operation dtype
if final_dtype == torch.float32:
op.float()
elif final_dtype == torch.float16:
op.half()
elif final_dtype == torch.bfloat16:
op.bfloat16()
# Check weights
assert isinstance(op.weight, QuantizedTensor) == with_quantization
assert op.weight.dtype == final_dtype
w_test = op.weight.to(dtype=torch.float64, device="cpu")
torch.testing.assert_close(w_test, w_ref, rtol=0, atol=0)
# Check forward and backward pass
x = torch.zeros(
(size, size),
dtype=init_dtype,
device=device,
requires_grad=True,
)
y = op(x)
y.backward(torch.zeros_like(y))
assert y.dtype == final_dtype
assert x.grad.dtype == init_dtype
assert op.weight.grad.dtype == final_dtype
@pytest.mark.parametrize("model_dtype", _dtypes)
@pytest.mark.parametrize("autocast_dtype", _dtypes)
@pytest.mark.parametrize("quantization", (None, "fp8", "mxfp8"))
def test_pyt_autocast(
self,
*,
size: int = 32,
model_dtype: torch.dtype,
autocast_dtype: torch.dtype,
device: torch.device = "cuda",
quantization: Optional[str],
quantized_weights: bool = False,
) -> None:
"""Test with PyTorch autocast"""
device = torch.device(device)
# Skip invalid configurations
quantized_compute = quantization is not None
maybe_skip_quantization(quantization)
# Construct operation
recipe = make_recipe(quantization)
with te.fp8_model_init(enabled=quantized_weights, recipe=recipe):
op = te_ops.Linear(size, size, bias=False, device=device, dtype=model_dtype)
# Check forward and backward pass
x = torch.zeros(
(size, size),
dtype=model_dtype,
device=device,
requires_grad=True,
)
with te.fp8_autocast(enabled=quantized_compute, fp8_recipe=recipe):
with torch.autocast(device_type=device.type, dtype=autocast_dtype):
y = op(x)
y.backward(torch.zeros_like(y))
assert y.dtype == autocast_dtype
assert x.grad.dtype == model_dtype
assert op.weight.grad.dtype == model_dtype
# Check forward and backward pass (swapped context order)
if quantized_compute:
x.grad = None
op.weight.grad = None
with torch.autocast(device_type=device.type, dtype=autocast_dtype):
with te.fp8_autocast(enabled=quantized_compute, fp8_recipe=recipe):
y = op(x)
y.backward(torch.zeros_like(y))
assert y.dtype == autocast_dtype
assert x.grad.dtype == model_dtype
assert op.weight.grad.dtype == model_dtype
class TestBasicOps:
"""Tests for individual operations"""
@staticmethod
def setup_class(cls) -> None:
# Configure RNG
seed = 1234
torch.manual_seed(seed)
torch.cuda.manual_seed(seed)
@pytest.mark.parametrize("dtype", _dtypes)
@pytest.mark.parametrize("device", ("cuda", "cpu"))
@pytest.mark.parametrize("fp8", (False, True))
def test_identity(
self,
*,
in_shape: Iterable[int] = (1,),
dtype: torch.dtype,
device: torch.device,
fp8: bool,
) -> None:
# Skip invalid configurations
if fp8 and not fp8_available:
pytest.skip(reason_for_no_fp8)
if fp8 and torch.device(device).type != "cuda":
pytest.skip("FP8 is only supported on CUDA devices")
# Random data
x_ref, x_test = make_reference_and_test_tensors(
in_shape,
test_dtype=dtype,
test_device=device,
test_is_fp8=fp8,
)
dy_ref, dy_test = make_reference_and_test_tensors(
in_shape,
test_dtype=dtype,
test_device=device,
requires_grad=False,
)
# Plain PyTorch implementation
y_ref = x_ref
dx_ref = dy_ref
# Implementation with fusible operation
op = te_ops.Identity()
y_test = op(x_test)
y_test.backward(dy_test)
# Check results
tols = dict(rtol=0, atol=0) # Identity is exact
y_test = y_test.to(dtype=torch.float64, device="cpu")
dx_test = x_test.grad.to(dtype=torch.float64, device="cpu")
torch.testing.assert_close(y_test, y_ref, **tols)
torch.testing.assert_close(dx_test, dx_ref, **tols)
# Make sure we are not trivially passing the test
with pytest.raises(AssertionError):
torch.testing.assert_close(y_test, -y_ref, **tols)
with pytest.raises(AssertionError):
torch.testing.assert_close(dx_test, -dx_ref, **tols)
@pytest.mark.parametrize(
"shapes",
(
((1, 2, 3, 4), (2, 12)),
((5, 4, 3, 2), (-1, 6)),
((30,), (2, 3, -1)),
((6, 7), (3, -1, 7)),
),
)
@pytest.mark.parametrize("dtype", _dtypes)
@pytest.mark.parametrize("fp8", (False, True))
def test_reshape(
self,
*,
shapes: tuple[Iterable[int], Iterable[int]],
dtype: torch.dtype,
device: torch.device = "cuda",
memory_format: torch.memory_format = torch.contiguous_format,
fp8: bool,
) -> None:
in_shape, out_shape = shapes
# Skip invalid configurations
if memory_format == torch.channels_last and len(in_shape) != 4:
pytest.skip("torch.channels_last only supports 4D tensors")
if fp8 and not fp8_available:
pytest.skip(reason_for_no_fp8)
if fp8 and torch.device(device).type != "cuda":
pytest.skip("FP8 is only supported on CUDA devices")
# Random data
x_ref, x_test = make_reference_and_test_tensors(
in_shape,
test_dtype=dtype,
test_device=device,
test_is_fp8=fp8,
)
x_test = x_test.contiguous(memory_format=memory_format)
x_test = x_test.detach().requires_grad_()
dy_ref, dy_test = make_reference_and_test_tensors(
x_ref.reshape(out_shape).size(),
test_dtype=dtype,
test_device=device,
requires_grad=False,
)
# Plain PyTorch implementation
y_ref = x_ref.reshape(out_shape)
y_ref.backward(dy_ref)
# Implementation with fusible operation
op = te_ops.Reshape(out_shape)
y_test = op(x_test)
y_test.backward(dy_test)
# Check results
tols = dict(rtol=0, atol=0) # Reshape is exact
y_test = y_test.to(
dtype=torch.float64,
device="cpu",
memory_format=torch.contiguous_format,
)
dx_test = x_test.grad.to(
dtype=torch.float64,
device="cpu",
memory_format=torch.contiguous_format,
)
torch.testing.assert_close(y_test, y_ref, **tols)
torch.testing.assert_close(dx_test, x_ref.grad, **tols)
@pytest.mark.parametrize("size", (1, 7, 32))
@pytest.mark.parametrize("in_shape", ((-1,), (1, 3, -1), (2, 3, 4, -1)))
@pytest.mark.parametrize("dtype", _dtypes)
@pytest.mark.parametrize("device", _devices)
@pytest.mark.parametrize("fp8", (False, True))
def test_bias(
self,
*,
size: int,
in_shape: Iterable[int],
dtype: torch.dtype,
device: torch.device,
fp8: bool,
) -> None:
# Make input and bias shapes consistent
in_shape = list(in_shape)[:-1] + [size]
# Skip invalid configurations
if fp8 and not fp8_available:
pytest.skip(reason_for_no_fp8)
if fp8 and torch.device(device).type != "cuda":
pytest.skip("FP8 is only supported on CUDA devices")
# Random data
x_ref, x_test = make_reference_and_test_tensors(
in_shape,
test_dtype=dtype,
test_device=device,
test_is_fp8=fp8,
)
b_ref, b_test = make_reference_and_test_tensors(
size,
test_dtype=dtype,
test_device=device,
)
dy_ref, dy_test = make_reference_and_test_tensors(
in_shape,
test_dtype=dtype,
test_device=device,
requires_grad=False,
)
# Plain PyTorch implementation
y_ref = x_ref + b_ref.reshape([1] * (len(in_shape) - 1) + [size])
y_ref.backward(dy_ref)
# Implementation with fusible operation
op = te_ops.Bias(size, device=device, dtype=dtype)
with torch.no_grad():
op.bias.copy_(b_test)
del b_test
y_test = op(x_test)
y_test.backward(dy_test)
# Check results
tols = dtype_tols(dtype)
y_test = y_test.to(dtype=torch.float64, device="cpu")
dx_test = x_test.grad.to(dtype=torch.float64, device="cpu")
db_test = op.bias.grad.to(dtype=torch.float64, device="cpu")
torch.testing.assert_close(y_test, y_ref, **tols)
torch.testing.assert_close(dx_test, x_ref.grad, **tols)
torch.testing.assert_close(db_test, b_ref.grad, **tols)
@pytest.mark.parametrize("quantization", ("fp8", "mxfp8"))
@pytest.mark.parametrize("cast_forward", (False, True))
@pytest.mark.parametrize("cast_backward", (False, True))
def test_quantize(
self,
*,
in_shape: Iterable[int] = (32, 32),
dtype: torch.dtype = torch.bfloat16,
device: torch.device = "cuda",
quantization: str,
cast_forward: bool,
cast_backward: bool,
) -> None:
"""Quantize"""
# Skip invalid configurations
maybe_skip_quantization(quantization)
# Random data
x_ref, x_test = make_reference_and_test_tensors(
in_shape,
test_dtype=dtype,
test_device=device,
requires_grad=False,
test_is_fp8=True,
)
x_test = x_test.dequantize().requires_grad_()
dy_ref, dy_test = make_reference_and_test_tensors(
in_shape,
test_dtype=dtype,
test_device=device,
requires_grad=False,
test_is_fp8=True,
)
dy_test = dy_test.dequantize()
# Plain PyTorch implementation
y_ref = x_ref
dx_ref = dy_ref
# Implementation with fusible operation
op = te_ops.Quantize(forward=cast_forward, backward=cast_backward)
recipe = make_recipe(quantization)
with te.fp8_autocast(fp8_recipe=recipe):
y_test = op(x_test)
y_test.backward(dy_test)
# Check tensor types
assert isinstance(y_test, QuantizedTensor) == cast_forward
assert isinstance(x_test.grad, QuantizedTensor) == cast_backward
# Check values
tols = dict(rtol=0, atol=0)
y_test = y_test.to(dtype=torch.float64, device="cpu")
dx_test = x_test.grad.to(dtype=torch.float64, device="cpu")
torch.testing.assert_close(y_test, y_ref, **tols)
torch.testing.assert_close(dx_test, dx_ref, **tols)
def _test_basic_linear(
self,
*,
weight_shape: tuple[int, int] = (32, 32),
in_shape: Iterable[int] = (32, -1),
dtype: torch.dtype = torch.float32,
device: torch.device = "cuda",
quantization: Optional[str] = None,
quantized_compute: bool = False,
quantized_input: bool = False,
quantized_weight: bool = False,
quantized_output: bool = False,
quantized_grad_output: bool = False,
quantized_grad_input: bool = False,
accumulate_into_main_grad: bool = False,
) -> None:
"""Helper function for tests with GEMM"""
# Make input and weight shapes consistent
out_features, in_features = weight_shape
in_shape = list(in_shape)[:-1] + [in_features]
out_shape = in_shape[:-1] + [out_features]
# Skip invalid configurations
maybe_skip_quantization(quantization, dims=in_shape, device=device)
maybe_skip_quantization(quantization, dims=out_shape)
if quantization == "fp8" and quantized_output and not quantized_compute:
pytest.skip("FP8 output is only supported with FP8 GEMMs")
if quantization == "fp8" and quantized_grad_input and not quantized_compute:
pytest.skip("FP8 grad input is only supported with FP8 GEMMs")
if quantization == "mxfp8" and quantized_output:
pytest.skip("MXFP8 output is not supported with MXFP8 GEMMs")
if quantization == "mxfp8" and quantized_grad_input:
pytest.skip("MXFP8 grad input is not supported with MXFP8 GEMMs")
# Random data
x_ref, x_test = make_reference_and_test_tensors(
in_shape,
test_dtype=dtype,
test_device=device,
test_is_fp8=(quantized_compute or quantized_input),
)
if isinstance(x_test, QuantizedTensor):
with torch.no_grad():
x_test = x_test.dequantize().requires_grad_()
w_ref, w_test = make_reference_and_test_tensors(
(out_features, in_features),
test_dtype=dtype,
test_device=device,
test_is_fp8=(quantized_compute or quantized_weight),
)
dy_ref, dy_test = make_reference_and_test_tensors(
out_shape,
test_dtype=dtype,
test_device=device,
test_is_fp8=(quantized_compute or quantized_grad_output),
requires_grad=False,
)
if isinstance(dy_test, QuantizedTensor):
dy_test = dy_test.dequantize()
# Plain PyTorch implementation
y_ref = torch.nn.functional.linear(x_ref, w_ref)
y_ref.backward(dy_ref)
# Implementation with fusible operation
recipe = make_recipe(quantization)
with te.fp8_model_init(enabled=quantized_weight, recipe=recipe):
op = te_ops.BasicLinear(
in_features,
out_features,
device=device,
dtype=dtype,
accumulate_into_main_grad=accumulate_into_main_grad,
)
with torch.no_grad():
op.weight.copy_(w_test)
del w_test
op.weight.main_grad = torch.full_like(op.weight, 0.5, dtype=torch.float32)
forward = te_ops.Sequential(
te_ops.Quantize(forward=quantized_input, backward=quantized_grad_input),
op,
te_ops.Quantize(forward=quantized_output, backward=quantized_grad_output),
)
with te.fp8_autocast(enabled=quantized_compute, fp8_recipe=recipe):
y_test = forward(x_test)
y_test.backward(dy_test)
# Expected numerical error
tols = dtype_tols(dtype)
if dtype == torch.float32:
tols = dtype_tols(torch.float16) # TF32 GEMM
if quantized_compute or quantized_output or quantized_grad_input:
tols = dtype_tols(tex.DType.kFloat8E4M3)
# Check results
y_test = y_test.to(dtype=torch.float64, device="cpu")
dx_test = x_test.grad.to(dtype=torch.float64, device="cpu")
torch.testing.assert_close(y_test, y_ref, **tols)
torch.testing.assert_close(dx_test, x_ref.grad, **tols)
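        # With accumulate_into_main_grad, the weight gradient is expected to be
        # accumulated into weight.main_grad (pre-filled with 0.5 above) while
        # weight.grad stays unset or zero; otherwise weight.grad holds the wgrad
        # and main_grad must remain untouched at 0.5.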
if accumulate_into_main_grad:
if op.weight.grad is not None:
torch.testing.assert_close(
op.weight.grad,
torch.zeros_like(op.weight.grad),
rtol=0,
atol=0,
)
dw_test = op.weight.main_grad.to(dtype=torch.float64, device="cpu") - 0.5
else:
dw_test = op.weight.grad.to(dtype=torch.float64, device="cpu")
torch.testing.assert_close(
op.weight.main_grad,
torch.full_like(op.weight.main_grad, 0.5),
rtol=0,
atol=0,
)
torch.testing.assert_close(dw_test, w_ref.grad, **tols)
@pytest.mark.parametrize("weight_shape", ((64, 32), (3, 5)))
@pytest.mark.parametrize("in_shape", ((-1,), (5, 1, -1), (4, 2, 4, -1)))
@pytest.mark.parametrize("dtype", _dtypes)
@pytest.mark.parametrize("quantization", (None, "fp8", "mxfp8"))
@pytest.mark.parametrize("accumulate_into_main_grad", (False, True))
def test_basic_linear(
self,
*,
weight_shape: tuple[int, int],
in_shape: Iterable[int],
dtype: torch.dtype,
quantization: Optional[str],
accumulate_into_main_grad: bool,
) -> None:
"""GEMM"""
self._test_basic_linear(
weight_shape=weight_shape,
in_shape=in_shape,
dtype=dtype,
quantization=quantization,
quantized_compute=quantization is not None,
accumulate_into_main_grad=accumulate_into_main_grad,
)
@pytest.mark.skipif(not fp8_available, reason=reason_for_no_fp8)
@pytest.mark.parametrize("quantization", ("fp8", "mxfp8"))
@pytest.mark.parametrize("quantized_compute", (False, True))
@pytest.mark.parametrize("quantized_input", (False, True))
@pytest.mark.parametrize("quantized_weight", (False, True))
@pytest.mark.parametrize("quantized_output", (False, True))
@pytest.mark.parametrize("quantized_grad_output", (False, True))
@pytest.mark.parametrize("quantized_grad_input", (False, True))
def test_basic_linear_quantized(
self,
*,
quantization: str,
quantized_compute: bool,
quantized_input: bool,
quantized_weight: bool,
quantized_output: bool,
quantized_grad_output: bool,
quantized_grad_input: bool,
) -> None:
"""GEMM with FP8 inputs and outputs"""
self._test_basic_linear(
dtype=torch.bfloat16,
quantization=quantization,
quantized_compute=quantized_compute,
quantized_input=quantized_input,
quantized_weight=quantized_weight,
quantized_output=quantized_output,
quantized_grad_output=quantized_grad_output,
quantized_grad_input=quantized_grad_input,
)
@pytest.mark.parametrize("bias", (False, True))
@pytest.mark.parametrize("quantization", (None, "fp8", "mxfp8"))
@pytest.mark.parametrize("quantized_weight", (False, True))
def test_linear(
self,
*,
bias: bool,
weight_shape: tuple[int, int] = (32, 32),
in_shape: Iterable[int] = (32, -1),
dtype: torch.dtype = torch.float32,
device: torch.device = "cuda",
quantization: Optional[str],
quantized_weight: bool,
) -> None:
"""GEMM + bias"""
# Make input and weight shapes consistent
out_features, in_features = weight_shape
in_shape = list(in_shape)[:-1] + [in_features]
out_shape = in_shape[:-1] + [out_features]
# Skip invalid configurations
quantized_compute = quantization is not None
maybe_skip_quantization(quantization, dims=in_shape, device=device)
maybe_skip_quantization(quantization, dims=out_shape)
# Random data
x_ref, x_test = make_reference_and_test_tensors(
in_shape,
test_dtype=dtype,
test_device=device,
test_is_fp8=quantized_compute,
)
if isinstance(x_test, QuantizedTensor):
with torch.no_grad():
x_test = x_test.dequantize().requires_grad_()
w_ref, w_test = make_reference_and_test_tensors(
(out_features, in_features),
test_dtype=dtype,
test_device=device,
test_is_fp8=(quantized_compute or quantized_weight),
)
b_ref, b_test = None, None
if bias:
b_ref, b_test = make_reference_and_test_tensors(
out_features,
test_dtype=dtype,
test_device=device,
)
dy_ref, dy_test = make_reference_and_test_tensors(
out_shape,
test_dtype=dtype,
test_device=device,
requires_grad=False,
)
# Plain PyTorch implementation
y_ref = torch.nn.functional.linear(x_ref, w_ref, bias=b_ref)
y_ref.backward(dy_ref)
# Implementation with fusible operation
recipe = make_recipe(quantization)
with te.fp8_model_init(enabled=quantized_weight, recipe=recipe):
op = te_ops.Linear(
in_features,
out_features,
bias=bias,
device=device,
dtype=dtype,
)
with torch.no_grad():
op.weight.copy_(w_test)
if bias:
op.bias.copy_(b_test)
del w_test
del b_test
with te.fp8_autocast(enabled=quantized_compute, fp8_recipe=recipe):
y_test = op(x_test)
y_test.backward(dy_test)
# Expected numerical error
tols = dtype_tols(dtype)
if dtype == torch.float32:
tols = dtype_tols(torch.float16) # TF32 GEMM
if quantized_compute:
tols = dtype_tols(tex.DType.kFloat8E4M3)
# Check results
y_test = y_test.to(dtype=torch.float64, device="cpu")
dx_test = x_test.grad.to(dtype=torch.float64, device="cpu")
dw_test = op.weight.grad.to(dtype=torch.float64, device="cpu")
torch.testing.assert_close(y_test, y_ref, **tols)
torch.testing.assert_close(dx_test, x_ref.grad, **tols)
torch.testing.assert_close(dw_test, w_ref.grad, **tols)
if bias:
db_test = op.bias.grad.to(dtype=torch.float64, device="cpu")
torch.testing.assert_close(db_test, b_ref.grad, **tols)
@pytest.mark.parametrize("weight_shape", ((7, 2), (32,)))
@pytest.mark.parametrize("in_shape", ((-1,), (6, 16, -1)))
@pytest.mark.parametrize("dtype", _dtypes)
@pytest.mark.parametrize("zero_centered_gamma", (False, True))
@pytest.mark.parametrize("quantization", (None, "fp8", "mxfp8"))
def test_layer_norm(
self,
*,
weight_shape: Iterable[int],
in_shape: Iterable[int],
dtype: torch.dtype,
device: torch.device = "cuda",
eps: float = 0.3,
zero_centered_gamma: bool,
quantization: Optional[str],
) -> None:
"""Layer norm"""
# Make input and weight shapes consistent
in_shape = list(in_shape)[:-1] + list(weight_shape)
# Skip invalid configurations
maybe_skip_quantization(quantization, dims=in_shape, device=device)
# Random data
x_ref, x_test = make_reference_and_test_tensors(
in_shape,
test_dtype=dtype,
test_device=device,
)
w_ref, w_test = make_reference_and_test_tensors(
weight_shape,
test_dtype=dtype,
test_device=device,
)
b_ref, b_test = make_reference_and_test_tensors(
weight_shape,
test_dtype=dtype,
test_device=device,
)
dy_ref, dy_test = make_reference_and_test_tensors(
in_shape,
test_dtype=dtype,
test_device=device,
requires_grad=False,
)
# Plain PyTorch implementation
y_ref = torch.nn.functional.layer_norm(
x_ref,
weight_shape,
weight=(w_ref + 1 if zero_centered_gamma else w_ref),
bias=b_ref,
eps=eps,
)
y_ref.backward(dy_ref)
# Implementation with fusible operation
op = te_ops.LayerNorm(
weight_shape,
eps=eps,
device=device,
dtype=dtype,
zero_centered_gamma=zero_centered_gamma,
)
with torch.no_grad():
op.weight.copy_(w_test)
op.bias.copy_(b_test)
del w_test
del b_test
quantized_compute = quantization is not None
recipe = make_recipe(quantization)
forward = te_ops.Sequential(
op,
te_ops.Quantize(forward=quantized_compute, backward=False),
)
with te.fp8_autocast(enabled=quantized_compute, fp8_recipe=recipe):
y_test = forward(x_test)
y_test.backward(dy_test)
# Expected numerical error
tols = dtype_tols(dtype)
if quantized_compute:
tols = dtype_tols(tex.DType.kFloat8E4M3)
# Check results
y_test = y_test.to(dtype=torch.float64, device="cpu")
dx_test = x_test.grad.to(dtype=torch.float64, device="cpu")
dw_test = op.weight.grad.to(dtype=torch.float64, device="cpu")
db_test = op.bias.grad.to(dtype=torch.float64, device="cpu")
torch.testing.assert_close(y_test, y_ref, **tols)
torch.testing.assert_close(dx_test, x_ref.grad, **tols)
torch.testing.assert_close(dw_test, w_ref.grad, **tols)
torch.testing.assert_close(db_test, b_ref.grad, **tols)
def test_layer_norm_autocast(
self,
*,
weight_shape: Iterable[int] = (32,),
in_shape: Iterable[int] = (32,),
dtype: torch.dtype = torch.float16,
autocast_dtype: torch.dtype = torch.float32,
device: torch.device = "cuda",
eps: float = 0.3,
) -> None:
"""Layer norm with PyTorch autocast"""
# Make input and weight shapes consistent
in_shape = list(in_shape)[:-1] + list(weight_shape)
# Random data
x_ref, x_test = make_reference_and_test_tensors(
in_shape,
test_dtype=autocast_dtype,
test_device=device,
)
w_ref, w_test = make_reference_and_test_tensors(
weight_shape,
test_dtype=dtype,
test_device=device,
)
b_ref, b_test = make_reference_and_test_tensors(
weight_shape,
test_dtype=dtype,
test_device=device,
)
dy_ref, dy_test = make_reference_and_test_tensors(
in_shape,
test_dtype=autocast_dtype,
test_device=device,
requires_grad=False,
)
# Plain PyTorch implementation
y_ref = torch.nn.functional.layer_norm(
x_ref,
weight_shape,
weight=w_ref,
bias=b_ref,
eps=eps,
)
y_ref.backward(dy_ref)
# Implementation with fusible operation
op = te_ops.LayerNorm(
weight_shape,
eps=eps,
device=device,
dtype=dtype,
)
with torch.no_grad():
op.weight.copy_(w_test)
op.bias.copy_(b_test)
del w_test
del b_test
with torch.autocast(device, dtype=autocast_dtype):
y_test = op(x_test)
y_test.backward(dy_test)
# Check results
assert y_test.dtype == autocast_dtype
y_test = y_test.to(dtype=torch.float64, device="cpu")
dx_test = x_test.grad.to(dtype=torch.float64, device="cpu")
dw_test = op.weight.grad.to(dtype=torch.float64, device="cpu")
db_test = op.bias.grad.to(dtype=torch.float64, device="cpu")
torch.testing.assert_close(y_test, y_ref, **dtype_tols(autocast_dtype))
torch.testing.assert_close(dx_test, x_ref.grad, **dtype_tols(autocast_dtype))
torch.testing.assert_close(dw_test, w_ref.grad, **dtype_tols(dtype))
torch.testing.assert_close(db_test, b_ref.grad, **dtype_tols(dtype))
@pytest.mark.parametrize("weight_shape", ((19,), (64,)))
@pytest.mark.parametrize("in_shape", ((-1,), (6, 16, -1)))
@pytest.mark.parametrize("dtype", _dtypes)
@pytest.mark.parametrize("zero_centered_gamma", (False, True))
@pytest.mark.parametrize("quantization", (None, "fp8", "mxfp8"))
def test_rmsnorm(
self,
*,
weight_shape: Iterable[int],
in_shape: Iterable[int],
dtype: torch.dtype,
device: torch.device = "cuda",
eps: float = 0.3,
zero_centered_gamma: bool,
quantization: Optional[str],
) -> None:
"""Layer norm"""
# Make input and weight shapes consistent
in_shape = list(in_shape)[:-1] + list(weight_shape)
# Skip invalid configurations
maybe_skip_quantization(quantization, dims=in_shape, device=device)
# Random data
x_ref, x_test = make_reference_and_test_tensors(
in_shape,
test_dtype=dtype,
test_device=device,
)
w_ref, w_test = make_reference_and_test_tensors(
weight_shape,
test_dtype=dtype,
test_device=device,
)
dy_ref, dy_test = make_reference_and_test_tensors(
in_shape,
test_dtype=dtype,
test_device=device,
requires_grad=False,
)
# Plain PyTorch implementation
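        # Reference RMSNorm: y = x / sqrt(mean(x^2) + eps) * gamma, with gamma
        # replaced by (1 + gamma) when zero_centered_gamma is set.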
inner_dims = tuple(range(len(in_shape) - len(weight_shape), len(in_shape)))
var_ref = x_ref.square().sum(dim=inner_dims, keepdim=True) / math.prod(weight_shape)
if zero_centered_gamma:
y_ref = x_ref / torch.sqrt(eps + var_ref) * (1 + w_ref)
else:
y_ref = x_ref / torch.sqrt(eps + var_ref) * w_ref
y_ref.backward(dy_ref)
# Implementation with fusible operation
op = te_ops.RMSNorm(
weight_shape,
eps=eps,
device=device,
dtype=dtype,
zero_centered_gamma=zero_centered_gamma,
)
with torch.no_grad():
op.weight.copy_(w_test)
del w_test
quantized_compute = quantization is not None
recipe = make_recipe(quantization)
forward = te_ops.Sequential(
op,
te_ops.Quantize(forward=quantized_compute, backward=False),
)
with te.fp8_autocast(enabled=quantized_compute, fp8_recipe=recipe):
y_test = forward(x_test)
y_test.backward(dy_test)
# Expected numerical error
tols = dtype_tols(dtype)
if quantized_compute:
tols = dtype_tols(tex.DType.kFloat8E4M3)
# Check results
y_test = y_test.to(dtype=torch.float64, device="cpu")
dx_test = x_test.grad.to(dtype=torch.float64, device="cpu")
dw_test = op.weight.grad.to(dtype=torch.float64, device="cpu")
torch.testing.assert_close(y_test, y_ref, **tols)
torch.testing.assert_close(dx_test, x_ref.grad, **tols)
torch.testing.assert_close(dw_test, w_ref.grad, **tols)
@pytest.mark.parametrize("dtype", _dtypes)
@pytest.mark.parametrize("device", ("cuda", "cpu"))
@pytest.mark.parametrize("fp8", (False, True))
def test_add_in_place(
self,
*,
in_shape: Iterable[int] = (1,),
dtype: torch.dtype,
device: torch.device,
fp8: bool,
) -> None:
"""Add two tensors
Join in compute graph.
"""
# Skip invalid configurations
if fp8 and not fp8_available:
pytest.skip(reason_for_no_fp8)
if fp8 and torch.device(device).type != "cuda":
pytest.skip("FP8 is only supported on CUDA devices")
# Random data
x1_ref, x1_test = make_reference_and_test_tensors(
in_shape,
test_dtype=dtype,
test_device=device,
test_is_fp8=fp8,
)
x2_ref, x2_test = make_reference_and_test_tensors(
in_shape,
test_dtype=dtype,
test_device=device,
test_is_fp8=fp8,
)
dy_ref, dy_test = make_reference_and_test_tensors(
in_shape,
test_dtype=dtype,
test_device=device,
requires_grad=False,
)
# Plain PyTorch implementation
y_ref = x2_ref.detach()
y_ref += x1_ref
dx1_ref = dy_ref
dx2_ref = dy_ref
# Implementation with fusible operation
op = te_ops.AddInPlace()
y_test = op(x1_test, x2_test)
y_test.backward(dy_test)
# Check results
tols = dtype_tols(dtype)
if fp8:
tols = dtype_tols(x1_test._fp8_dtype)
y_test = y_test.to(dtype=torch.float64, device="cpu")
dx1_test = x1_test.grad.to(dtype=torch.float64, device="cpu")
dx2_test = x2_test.grad.to(dtype=torch.float64, device="cpu")
torch.testing.assert_close(y_test, y_ref, **tols)
torch.testing.assert_close(dx1_test, dx1_ref, rtol=0, atol=0)
torch.testing.assert_close(dx2_test, dx2_ref, rtol=0, atol=0)
@pytest.mark.parametrize("dtype", _dtypes)
@pytest.mark.parametrize("device", ("cuda", "cpu"))
@pytest.mark.parametrize("fp8", (False, True))
def test_make_extra_output(
self,
*,
in_shape: Iterable[int] = (1,),
dtype: torch.dtype,
device: torch.device,
fp8: bool,
) -> None:
"""Output tensor twice
Split in compute graph.
"""
# Skip invalid configurations
if fp8 and not fp8_available:
pytest.skip(reason_for_no_fp8)
if fp8 and torch.device(device).type != "cuda":
pytest.skip("FP8 is only supported on CUDA devices")
# Random data
x_ref, x_test = make_reference_and_test_tensors(
in_shape,
test_dtype=dtype,
test_device=device,
test_is_fp8=fp8,
)
dy1_ref, dy1_test = make_reference_and_test_tensors(
in_shape,
test_dtype=dtype,
test_device=device,
requires_grad=False,
)
dy2_ref, dy2_test = make_reference_and_test_tensors(
in_shape,
test_dtype=dtype,
test_device=device,
requires_grad=False,
)
# Plain PyTorch implementation
y1_ref = x_ref
y2_ref = x_ref
(y1_ref * dy1_ref + y2_ref * dy2_ref).sum().backward()
# Implementation with fusible operation
op = te_ops.MakeExtraOutput()
y1_test, y2_test = op(x_test)
(y1_test * dy1_test + y2_test * dy2_test).sum().backward()
# Check results
tols = dtype_tols(dtype)
y1_test = y1_test.to(dtype=torch.float64, device="cpu")
y2_test = y2_test.to(dtype=torch.float64, device="cpu")
dx_test = x_test.grad.to(dtype=torch.float64, device="cpu")
torch.testing.assert_close(y1_test, y1_ref, rtol=0, atol=0)
torch.testing.assert_close(y2_test, y2_ref, rtol=0, atol=0)
torch.testing.assert_close(dx_test, x_ref.grad, **tols)
@pytest.mark.parametrize("activation", ("relu", "gelu", "geglu", "reglu", "swiglu"))
@pytest.mark.parametrize("out_shape", ((37,), (2, 13), (32, 1, 32)))
@pytest.mark.parametrize("dtype", _dtypes)
@pytest.mark.parametrize("quantization", (None, "fp8", "mxfp8"))
def test_activation(
self,
*,
activation: str,
out_shape: Iterable[int],
dtype: torch.dtype,
device: torch.device = "cuda",
quantization: Optional[str],
) -> None:
"""Activation functions"""
# Tensor dimensions
in_shape = list(out_shape)
if activation in ("geglu", "reglu", "swiglu"):
in_shape[-1] *= 2
# Skip invalid configurations
quantized_compute = quantization is not None
maybe_skip_quantization(quantization, dims=in_shape, device=device)
# Random data
x_ref, x_test = make_reference_and_test_tensors(
in_shape,
test_dtype=dtype,
test_device=device,
test_is_fp8=quantized_compute,
)
if quantized_compute:
with torch.no_grad():
x_test = x_test.dequantize().requires_grad_()
dy_ref, dy_test = make_reference_and_test_tensors(
out_shape,
test_dtype=dtype,
test_device=device,
requires_grad=False,
)
# Plain PyTorch implementation
y_ref: torch.Tensor
if activation == "gelu":
y_ref = torch.nn.functional.gelu(x_ref, approximate="tanh")
elif activation == "relu":
y_ref = torch.nn.functional.relu(x_ref)
elif activation == "geglu":
x1, x2 = x_ref.chunk(2, dim=-1)
y_ref = torch.nn.functional.gelu(x1, approximate="tanh") * x2
elif activation == "reglu":
x1, x2 = x_ref.chunk(2, dim=-1)
y_ref = torch.nn.functional.relu(x1) * x2
elif activation == "swiglu":
x1, x2 = x_ref.chunk(2, dim=-1)
y_ref = torch.nn.functional.silu(x1) * x2
else:
raise ValueError(f"Unexpected activation function ({activation})")
y_ref.backward(dy_ref)
# Implementation with fusible operation
recipe = make_recipe(quantization)
make_op = dict(
gelu=te_ops.GELU,
relu=te_ops.ReLU,
geglu=te_ops.GEGLU,
reglu=te_ops.ReGLU,
swiglu=te_ops.SwiGLU,
)[activation]
forward = te_ops.Sequential(
make_op(),
te_ops.Quantize(forward=quantized_compute, backward=False),
)
with te.fp8_autocast(enabled=quantized_compute, fp8_recipe=recipe):
y_test = forward(x_test)
y_test.backward(dy_test)
# Expected numerical error
tols = dtype_tols(dtype)
if quantized_compute:
tols = dtype_tols(tex.DType.kFloat8E4M3)
if activation == "relu":
tols = {"atol": 0, "rtol": 0}
# Check results
y_test = y_test.to(dtype=torch.float64, device="cpu")
dx_test = x_test.grad.to(dtype=torch.float64, device="cpu")
torch.testing.assert_close(y_test, y_ref, **tols)
torch.testing.assert_close(dx_test, x_ref.grad, **tols)
@pytest.mark.parametrize("dtype", _dtypes)
@pytest.mark.parametrize("quantization", (None, "fp8", "mxfp8"))
@pytest.mark.parametrize("quantize_forward", (False, True))
@pytest.mark.parametrize("quantize_backward", (False, True))
def test_swiglu(
self,
*,
out_shape: Iterable[int] = (32, 32),
dtype: torch.dtype,
device: torch.device = "cuda",
quantization: Optional[str],
quantize_forward: bool,
quantize_backward: bool,
):
# Tensor dimensions
in_shape = list(out_shape)
in_shape[-1] *= 2
# Skip invalid configurations
quantized_compute = quantization is not None
if not quantized_compute and (quantize_forward or quantize_backward):
pytest.skip("Quantization scheme has not been provided")
maybe_skip_quantization(quantization, dims=in_shape, device=device)
# Random data
x_ref, x_test = make_reference_and_test_tensors(
in_shape,
test_dtype=dtype,
test_device=device,
)
dy_ref, dy_test = make_reference_and_test_tensors(
out_shape,
test_dtype=dtype,
test_device=device,
requires_grad=False,
)
# Plain PyTorch implementation
x1, x2 = x_ref.chunk(2, dim=-1)
y_ref = torch.nn.functional.silu(x1) * x2
y_ref.backward(dy_ref)
# Implementation with fusible operation
recipe = make_recipe(quantization)
forward = te_ops.Sequential(
te_ops.Quantize(forward=False, backward=quantize_backward),
te_ops.SwiGLU(),
te_ops.Quantize(forward=quantize_forward, backward=False),
)
with te.fp8_autocast(enabled=quantized_compute, fp8_recipe=recipe):
y_test = forward(x_test)
y_test.backward(dy_test)
# Expected numerical error
tols = dtype_tols(dtype)
if quantized_compute:
tols = dtype_tols(tex.DType.kFloat8E4M3)
# Check results
y_test = y_test.to(dtype=torch.float64, device="cpu")
dx_test = x_test.grad.to(dtype=torch.float64, device="cpu")
torch.testing.assert_close(y_test, y_ref, **tols)
torch.testing.assert_close(dx_test, x_ref.grad, **tols)
class TestFusedOps:
"""Tests for fused operations"""
@staticmethod
def setup_class(cls) -> None:
# Configure RNG
seed = 1234
torch.manual_seed(seed)
torch.cuda.manual_seed(seed)
@pytest.mark.parametrize("weight_shape", ((32, 64), (3, 5)))
@pytest.mark.parametrize("in_shape", ((-1,), (1, 7, -1), (8, 2, 10, -1)))
@pytest.mark.parametrize("dtype", _dtypes)
@pytest.mark.parametrize("quantization", (None, "fp8", "mxfp8"))
@pytest.mark.parametrize("quantized_weight", (False, True))
def test_forward_linear_bias_activation(
self,
*,
bias: bool = True,
weight_shape: tuple[int, int],
in_shape: Iterable[int],
dtype: torch.dtype,
device: torch.device = "cuda",
quantization: Optional[str],
quantized_weight: bool,
) -> None:
"""Forward GEMM + bias + activation"""
# Make input and weight shapes consistent
out_features, in_features = weight_shape
in_shape = list(in_shape)[:-1] + [in_features]
out_shape = in_shape[:-1] + [out_features]
# Skip invalid configurations
quantized_compute = quantization is not None
maybe_skip_quantization(quantization, dims=in_shape, device=device)
maybe_skip_quantization(quantization, dims=out_shape)
        if quantized_compute and dtype not in (torch.float16, torch.bfloat16):
            pytest.skip(
                "FP8 fused linear-bias-activation is only supported with FP16 or BF16 output"
            )
# Random data
x_ref, x_test = make_reference_and_test_tensors(
in_shape,
test_dtype=dtype,
test_device=device,
test_is_fp8=quantized_compute,
)
if quantized_compute:
with torch.no_grad():
x_test = x_test.dequantize().requires_grad_()
w_ref, w_test = make_reference_and_test_tensors(
(out_features, in_features),
test_dtype=dtype,
test_device=device,
test_is_fp8=(quantized_compute or quantized_weight),
)
b_ref, b_test = None, None
if bias:
b_ref, b_test = make_reference_and_test_tensors(
out_features,
test_dtype=dtype,
test_device=device,
)
dy_ref, dy_test = make_reference_and_test_tensors(
out_shape,
test_dtype=dtype,
test_device=device,
requires_grad=False,
)
# Plain PyTorch implementation
y_ref = torch.nn.functional.linear(x_ref, w_ref, bias=b_ref)
y_ref.backward(dy_ref)
# Implementation with fusible operations
recipe = make_recipe(quantization)
with te.fp8_model_init(enabled=quantized_compute, recipe=recipe):
model = te_ops.Sequential(
te_ops.Linear(
in_features,
out_features,
bias=bias,
device=device,
dtype=dtype,
),
)
with torch.no_grad():
model[0].weight.copy_(w_test)
if bias:
model[0].bias.copy_(b_test)
del w_test
del b_test
with te.fp8_autocast(enabled=quantized_compute, fp8_recipe=recipe):
y_test = model(x_test)
y_test.backward(dy_test)
# Check that forward operations have been fused
forward_ops = model._module_groups[0]._forward_ops
assert len(forward_ops) == 1
assert isinstance(forward_ops[0][0], ForwardLinearBiasActivation)
# Expected numerical error
tols = dtype_tols(dtype)
if dtype == torch.float32:
tols = dtype_tols(torch.float16) # TF32 GEMM
if quantized_compute:
tols = dtype_tols(tex.DType.kFloat8E4M3)
# Check results
y_test = y_test.to(dtype=torch.float64, device="cpu")
dx_test = x_test.grad.to(dtype=torch.float64, device="cpu")
dw_test = model[0].weight.grad.to(dtype=torch.float64, device="cpu")
torch.testing.assert_close(y_test, y_ref, **tols)
torch.testing.assert_close(dx_test, x_ref.grad, **tols)
torch.testing.assert_close(dw_test, w_ref.grad, **tols)
if bias:
db_test = model[0].bias.grad.to(dtype=torch.float64, device="cpu")
torch.testing.assert_close(db_test, b_ref.grad, **tols)
@pytest.mark.parametrize("bias", (False, True))
@pytest.mark.parametrize("dtype", _dtypes)
@pytest.mark.parametrize("quantization", (None, "fp8", "mxfp8"))
def test_forward_linear_bias_add(
self,
*,
bias: bool,
weight_shape: tuple[int, int] = (32, 32),
in_shape: Iterable[int] = (32, -1),
dtype: torch.dtype,
device: torch.device = "cuda",
quantization: Optional[str],
quantized_weight: bool = False,
) -> None:
"""Forward GEMM + bias + add"""
# Make input and weight shapes consistent
out_features, in_features = weight_shape
in_shape = list(in_shape)[:-1] + [in_features]
out_shape = in_shape[:-1] + [out_features]
# Skip invalid configurations
quantized_compute = quantization is not None
maybe_skip_quantization(quantization, dims=in_shape, device=device)
maybe_skip_quantization(quantization, dims=out_shape)
if quantized_compute and dtype not in (torch.float16, torch.bfloat16):
pytest.skip("FP8 GEMM is only supported with FP8, FP16, or BF16 output")
# Random data
x1_ref, x1_test = make_reference_and_test_tensors(
in_shape,
test_dtype=dtype,
test_device=device,
test_is_fp8=quantized_compute,
)
if isinstance(x1_test, QuantizedTensor):
with torch.no_grad():
x1_test = x1_test.dequantize().requires_grad_()
w_ref, w_test = make_reference_and_test_tensors(
(out_features, in_features),
test_dtype=dtype,
test_device=device,
test_is_fp8=(quantized_compute or quantized_weight),
)
b_ref, b_test = None, None
if bias:
b_ref, b_test = make_reference_and_test_tensors(
out_features,
test_dtype=dtype,
test_device=device,
)
x2_ref, x2_test = make_reference_and_test_tensors(
out_shape,
test_dtype=dtype,
test_device=device,
)
dy_ref, dy_test = make_reference_and_test_tensors(
out_shape,
test_dtype=dtype,
test_device=device,
requires_grad=False,
)
# Plain PyTorch implementation
y_ref = torch.nn.functional.linear(x1_ref, w_ref, bias=b_ref) + x2_ref
y_ref.backward(dy_ref)
# Implementation with fusible operations
recipe = make_recipe(quantization)
with te.fp8_model_init(enabled=quantized_weight, recipe=recipe):
model = te_ops.Sequential(
te_ops.Linear(
in_features,
out_features,
bias=bias,
device=device,
dtype=dtype,
),
te_ops.AddInPlace(),
)
with torch.no_grad():
model[0].weight.copy_(w_test)
if bias:
model[0].bias.copy_(b_test)
del w_test
del b_test
with te.fp8_autocast(enabled=quantized_compute, fp8_recipe=recipe):
y_test = model(x1_test, x2_test)
y_test.backward(dy_test)
# Check that forward operations have been fused
forward_ops = model._module_groups[0]._forward_ops
assert len(forward_ops) == 1
assert isinstance(forward_ops[0][0], ForwardLinearBiasAdd)
# Expected numerical error
tols = dtype_tols(dtype)
if dtype == torch.float32:
tols = dtype_tols(torch.float16) # TF32 GEMM
if quantized_compute:
tols = dtype_tols(tex.DType.kFloat8E4M3)
# Check results
y_test = y_test.to(dtype=torch.float64, device="cpu")
dx1_test = x1_test.grad.to(dtype=torch.float64, device="cpu")
dx2_test = x2_test.grad.to(dtype=torch.float64, device="cpu")
dw_test = model[0].weight.grad.to(dtype=torch.float64, device="cpu")
torch.testing.assert_close(y_test, y_ref, **tols)
torch.testing.assert_close(dx1_test, x1_ref.grad, **tols)
torch.testing.assert_close(dx2_test, x2_ref.grad, **tols)
torch.testing.assert_close(dw_test, w_ref.grad, **tols)
if bias:
db_test = model[0].bias.grad.to(dtype=torch.float64, device="cpu")
torch.testing.assert_close(db_test, b_ref.grad, **tols)
@pytest.mark.parametrize("dtype", _dtypes)
@pytest.mark.parametrize("quantization", (None, "fp8", "mxfp8"))
def test_backward_linear_add(
self,
*,
weight_shape: tuple[int, int] = (32, 32),
in_shape: Iterable[int] = (32, -1),
dtype: torch.dtype,
device: torch.device = "cuda",
quantization: Optional[str],
quantized_weight: bool = False,
) -> None:
"""Backward dgrad GEMM + add"""
# Make input and weight shapes consistent
out_features, in_features = weight_shape
in_shape = list(in_shape)[:-1] + [in_features]
out_shape = in_shape[:-1] + [out_features]
# Skip invalid configurations
quantized_compute = quantization is not None
maybe_skip_quantization(quantization, dims=in_shape, device=device)
maybe_skip_quantization(quantization, dims=out_shape)
if quantized_compute and dtype not in (torch.float16, torch.bfloat16):
pytest.skip("FP8 GEMM is only supported with FP8, FP16, or BF16 output")
# Random data
x_ref, x_test = make_reference_and_test_tensors(
in_shape,
test_dtype=dtype,
test_device=device,
test_is_fp8=quantized_compute,
)
if isinstance(x_test, QuantizedTensor):
with torch.no_grad():
x_test = x_test.dequantize().requires_grad_()
w_ref, w_test = make_reference_and_test_tensors(
(out_features, in_features),
test_dtype=dtype,
test_device=device,
test_is_fp8=(quantized_compute or quantized_weight),
)
dy1_ref, dy1_test = make_reference_and_test_tensors(
out_shape,
test_dtype=dtype,
test_device=device,
requires_grad=False,
)
dy2_ref, dy2_test = make_reference_and_test_tensors(
out_shape,
test_dtype=dtype,
test_device=device,
requires_grad=False,
)
# Plain PyTorch implementation
y1_ref = torch.nn.functional.linear(x_ref, w_ref)
y2_ref = x_ref
(y1_ref * dy1_ref + y2_ref * dy2_ref).sum().backward()
# Implementation with fusible operations
recipe = make_recipe(quantization)
        with te.fp8_model_init(enabled=quantized_weight, recipe=recipe):
model = te_ops.Sequential(
te_ops.MakeExtraOutput(),
te_ops.Linear(
in_features,
out_features,
bias=False,
device=device,
dtype=dtype,
),
)
with torch.no_grad():
model[1].weight.copy_(w_test)
del w_test
with te.fp8_autocast(enabled=quantized_compute, fp8_recipe=recipe):
y1_test, y2_test = model(x_test)
(y1_test * dy1_test + y2_test * dy2_test).sum().backward()
# Check that backward operations have been fused
backward_ops = model._module_groups[0]._backward_ops
assert len(backward_ops) == 1
assert isinstance(backward_ops[0][0], BackwardLinearAdd)
# Expected numerical error
tols = dtype_tols(dtype)
if dtype == torch.float32:
tols = dtype_tols(torch.float16) # TF32 GEMM
if quantized_compute:
tols = dtype_tols(tex.DType.kFloat8E4M3)
# Check results
y1_test = y1_test.to(dtype=torch.float64, device="cpu")
y2_test = y2_test.to(dtype=torch.float64, device="cpu")
dx_test = x_test.grad.to(dtype=torch.float64, device="cpu")
dw_test = model[1].weight.grad.to(dtype=torch.float64, device="cpu")
torch.testing.assert_close(y1_test, y1_ref, **tols)
torch.testing.assert_close(y2_test, y2_ref, **tols)
torch.testing.assert_close(dx_test, x_ref.grad, **tols)
torch.testing.assert_close(dw_test, w_ref.grad, **tols)
# Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# See LICENSE for license information.
import pytest
import torch
import transformer_engine.pytorch as te
batch_size = 32
seq_length = 2048
num_heads = 16
head_dim = 64
dtype = torch.bfloat16
num_attn_head = 16
ffn_hidden_size = 1024
@pytest.mark.parametrize("kv_channels", [128, 256])
@pytest.mark.parametrize("hidden_size", [128, 256])
@pytest.mark.parametrize("num_gqa_groups", [1, 2, 4, 8, 16])
def test_gqa(kv_channels, hidden_size, num_gqa_groups) -> None:
model = te.TransformerLayer(
hidden_size, ffn_hidden_size, num_attn_head, num_gqa_groups, kv_channels=kv_channels
)
# Run forward pass
x = torch.randn((batch_size, 1, hidden_size)).cuda()
model(x)
# Check shapes of weights.
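    # With grouped-query attention, the query projection has num_attn_head heads
    # while the key/value projections only have num_gqa_groups heads, so their
    # output widths are kv_channels * num_attn_head and kv_channels * num_gqa_groups.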
assert model.self_attention.layernorm_qkv.key_weight.shape[0] == kv_channels * num_gqa_groups
assert model.self_attention.layernorm_qkv.key_weight.shape[1] == hidden_size
assert model.self_attention.layernorm_qkv.query_weight.shape[0] == kv_channels * num_attn_head
assert model.self_attention.layernorm_qkv.query_weight.shape[1] == hidden_size
assert model.self_attention.layernorm_qkv.value_weight.shape[0] == kv_channels * num_gqa_groups
assert model.self_attention.layernorm_qkv.value_weight.shape[1] == hidden_size
assert model.self_attention.proj.weight.shape[0] == hidden_size
assert model.self_attention.proj.weight.shape[1] == kv_channels * num_attn_head
# Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# See LICENSE for license information.
from typing import Tuple
import pytest
import torch
import transformer_engine.pytorch as te
# Model names for test_torch_dynamo
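# Each entry maps a module name to a (model constructor, input shape) pair.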
_model_factory = {
"Linear": [(lambda: te.Linear(16, 16)), [16, 16]],
"LayerNorm": [(lambda: te.LayerNorm(16)), [16, 16]],
"LayerNormLinear": [(lambda: te.LayerNormLinear(16, 16)), [16, 16]],
"LayerNormMLP": [(lambda: te.LayerNormMLP(16, 16)), [16, 16]],
"TransformerLayer": [(lambda: te.TransformerLayer(128, 128, 2)), [4, 1, 128]],
}
@pytest.mark.skipif(torch.__version__ < "2", reason="torch.compile not available")
@pytest.mark.parametrize("model_name", list(_model_factory.keys()))
def test_torch_dynamo(model_name: str):
"""Test compatibility with Torch Dynamo
Construct model, optimize with Torch Dynamo, and perform a single
forward and backward pass.
"""
# Helper function to construct tensor with default options
def make_tensor(
dims: Tuple[int],
dtype: torch.dtype = torch.float32,
device: torch.device = "cuda",
requires_grad: bool = True,
**kwargs,
):
return torch.zeros(
dims,
dtype=dtype,
device=device,
requires_grad=requires_grad,
**kwargs,
)
# Construct model and input tensors
model_builder, input_builder = _model_factory[model_name]
model = model_builder()
inputs = [make_tensor(input_builder)]
# Optimize model with TorchDynamo
    model = torch.compile(model)
# Forward and backward pass
out = model(*inputs)
out.backward(torch.zeros_like(out))
# Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# See LICENSE for license information.
import pytest
import torch
import transformer_engine.pytorch as te
import transformer_engine_torch as tex
from transformer_engine.pytorch.optimizers import MultiTensorApply
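# Tensor size pairs, presumably chosen to straddle the multi-tensor-apply chunk
# sizes used below (e.g. 2048 * 32 +/- 1) as well as much smaller and larger sizes.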
input_size_pairs = [
(7777 * 77, 555 * 555),
(777, 555),
(555, 2048 * 32 + 1),
(2048 * 32 + 1, 555),
(555, 2048 * 32),
(2048 * 32, 555),
(33333, 555),
(555, 33333),
]
appliers = [MultiTensorApply(2048 * 32), MultiTensorApply(333), MultiTensorApply(33333)]
@pytest.mark.parametrize("input_size_pair", input_size_pairs)
@pytest.mark.parametrize("applier", appliers)
@pytest.mark.parametrize("repeat", [1, 55])
@pytest.mark.parametrize("in_type", [torch.float32, torch.float16, torch.bfloat16])
@pytest.mark.parametrize("out_type", [torch.float32, torch.float16, torch.bfloat16])
@pytest.mark.parametrize("inplace", [False, True])
def test_multi_tensor_scale(input_size_pair, applier, repeat, in_type, out_type, inplace):
    if inplace and out_type is not in_type:
        pytest.skip("inplace=True with out_type != in_type is not supported.")
    elif (in_type == torch.float16 and out_type == torch.bfloat16) or (
        in_type == torch.bfloat16 and out_type == torch.float16
    ):
        pytest.skip("Conversion between float16 and bfloat16 is unnecessary, so it is not tested.")
device = torch.device("cuda")
scale = 4.0
overflow_buf = torch.zeros(1, dtype=torch.int32, device=device)
ref = torch.tensor([1.0], dtype=torch.float32, device=device)
sizea, sizeb = input_size_pair
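    # downscale: scale lists of tensors filled with `scale` by 1/scale and check
    # that every output equals 1.0 with no overflow flagged. find_inf: inject a
    # NaN/Inf into one input element and check that the overflow buffer is set.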
def downscale(sizea, sizeb, applier, repeat, in_type, out_type, inplace=False):
overflow_buf.zero_()
a = torch.full([sizea], scale, dtype=torch.float32, device=device)
b = torch.full([sizeb], scale, dtype=torch.float32, device=device)
out_list = []
for i in range(repeat):
out_list += [a.clone().to(out_type), b.clone().to(out_type)]
if inplace:
in_list = out_list
else:
in_list = [out.clone().to(in_type) for out in out_list]
applier(tex.multi_tensor_scale, overflow_buf, [in_list, out_list], 1.0 / scale)
assert all([torch.allclose(out, ref.to(out_type)) for out in out_list])
assert overflow_buf.item() == 0
def find_inf(
sizea,
sizeb,
applier,
repeat,
in_type,
out_type,
t,
ind,
val,
inplace=False,
):
overflow_buf.zero_()
a = torch.full([sizea], scale, dtype=torch.float32, device=device)
b = torch.full([sizeb], scale, dtype=torch.float32, device=device)
out_list = []
for i in range(repeat):
out_list += [a.clone().to(out_type), b.clone().to(out_type)]
if inplace:
in_list = out_list
else:
in_list = [out.clone().to(in_type) for out in out_list]
applier(tex.multi_tensor_scale, overflow_buf, [in_list, out_list], 1.0 / scale)
overflow_buf.zero_()
in_list[t][ind] = val
applier(tex.multi_tensor_scale, overflow_buf, [in_list, out_list], 1.0 / scale)
assert overflow_buf.item() > 0
downscale(sizea, sizeb, applier, repeat, in_type, out_type, inplace=inplace)
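    # Inject non-finite values at three positions: a NaN at the first element of
    # the first tensor, an Inf at the last element of the last tensor, and an
    # Inf at the midpoint of the tensor at index 2 * (repeat // 2).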
find_inf(
sizea,
sizeb,
applier,
repeat,
in_type,
out_type,
0,
0,
float("nan"),
inplace=inplace,
)
find_inf(
sizea,
sizeb,
applier,
repeat,
in_type,
out_type,
2 * repeat - 1,
sizeb - 1,
float("inf"),
inplace=inplace,
)
find_inf(
sizea,
sizeb,
applier,
repeat,
in_type,
out_type,
2 * (repeat // 2),
sizea // 2,
float("inf"),
inplace=inplace,
)
@pytest.mark.parametrize("input_size_pair", input_size_pairs)
@pytest.mark.parametrize("applier", appliers)
@pytest.mark.parametrize("repeat", [1, 55])
@pytest.mark.parametrize("in_type", [torch.float32, torch.float16, torch.bfloat16])
@pytest.mark.parametrize("per_tensor", [False, True])
def test_multi_tensor_l2norm(input_size_pair, applier, repeat, in_type, per_tensor):
sizea, sizeb = input_size_pair
device = torch.device("cuda")
val = 4.0
overflow_buf = torch.zeros(1, dtype=torch.int32, device=device)
overflow_buf.zero_()
a = torch.full([sizea], val, dtype=torch.float32, device=device)
b = torch.full([sizeb], val, dtype=torch.float32, device=device)
in_list = []
for i in range(repeat):
in_list += [a.clone().to(in_type), b.clone().to(in_type)]
if per_tensor:
norm, norm_per_tensor = applier(tex.multi_tensor_l2norm, overflow_buf, [in_list], True)
normab = torch.cat((a.norm().view(1), b.norm().view(1)))
norm_per_tensor = norm_per_tensor.view(-1, 2)
else:
norm, _ = applier(tex.multi_tensor_l2norm, overflow_buf, [in_list], True)
reference = torch.full(
[(sizea + sizeb) * repeat], val, dtype=torch.float32, device=device
).norm()
torch.testing.assert_close(norm, reference.broadcast_to(norm.shape))
if per_tensor:
torch.testing.assert_close(norm_per_tensor, normab.broadcast_to(norm_per_tensor.shape))
assert overflow_buf.item() == 0
@pytest.mark.parametrize("input_size_pair", input_size_pairs)
@pytest.mark.parametrize("applier", appliers)
@pytest.mark.parametrize("repeat", [1, 55])
@pytest.mark.parametrize("in_type", [torch.float32, torch.float16, torch.bfloat16])
@pytest.mark.parametrize("per_tensor", [False, True])
def test_multi_tensor_unscale_l2norm(input_size_pair, applier, repeat, in_type, per_tensor):
sizea, sizeb = input_size_pair
device = torch.device("cuda")
val = 4.0
inv_scale = 0.5
inv_scale_cuda = torch.tensor([inv_scale], dtype=torch.float32, device=device)
overflow_buf = torch.zeros(1, dtype=torch.int32, device=device)
overflow_buf.zero_()
a = torch.full([sizea], val, dtype=torch.float32, device=device)
b = torch.full([sizeb], val, dtype=torch.float32, device=device)
in_list = []
for i in range(repeat):
in_list += [a.clone().to(in_type), b.clone().to(in_type)]
if per_tensor:
norm, norm_per_tensor = applier(
tex.multi_tensor_unscale_l2norm,
overflow_buf,
[in_list],
inv_scale_cuda,
True,
)
normab = torch.cat(((a * inv_scale).norm().view(1), (b * inv_scale).norm().view(1)))
norm_per_tensor = norm_per_tensor.view(-1, 2)
else:
norm, _ = applier(
tex.multi_tensor_unscale_l2norm,
overflow_buf,
[in_list],
inv_scale_cuda,
True,
)
reference = torch.full(
[(sizea + sizeb) * repeat], val * inv_scale, dtype=torch.float32, device=device
).norm()
torch.testing.assert_close(norm, reference.broadcast_to(norm.shape))
if per_tensor:
torch.testing.assert_close(norm_per_tensor, normab.broadcast_to(norm_per_tensor.shape))
assert overflow_buf.item() == 0
# Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# See LICENSE for license information.
import math
import os
from typing import Dict, List, Optional
import pytest
import copy
import random
import torch
import torch.nn as nn
from torch.nn import Parameter
from transformer_engine.pytorch.fp8 import (
FP8GlobalStateManager,
fp8_autocast,
fp8_model_init,
)
from transformer_engine.pytorch.utils import (
init_method_normal,
scaled_init_method_normal,
attention_mask_func,
is_bf16_compatible,
)
from transformer_engine.pytorch import (
DotProductAttention,
LayerNormLinear,
LayerNormMLP,
Linear,
GroupedLinear,
MultiheadAttention,
RMSNorm,
TransformerLayer,
LayerNorm,
Fp8Padding,
Fp8Unpadding,
)
from transformer_engine.pytorch.dot_product_attention.inference import InferenceParams
from transformer_engine.pytorch.distributed import checkpoint as te_checkpoint
from transformer_engine.pytorch.cpp_extensions import general_gemm, general_grouped_gemm
from transformer_engine.pytorch.tensor.float8_tensor import Float8Quantizer
from transformer_engine.pytorch.module.base import get_multi_stream_cublas_workspace, get_workspace
from transformer_engine.pytorch.utils import get_device_compute_capability
from transformer_engine.common import recipe
import transformer_engine_torch as tex
# Only run FP8 tests on supported devices.
fp8_available, reason_for_no_fp8 = FP8GlobalStateManager.is_fp8_available()
mxfp8_available, reason_for_no_mxfp8 = FP8GlobalStateManager.is_mxfp8_available()
sm_80plus = get_device_compute_capability() >= (8, 0)
seed = 1234
torch.manual_seed(seed)
torch.cuda.manual_seed(seed)
# Record initial RNG state from script run.
_cpu_rng_state = torch.get_rng_state()
_cuda_rng_state = torch.cuda.get_rng_state()
class ModelConfig:
def __init__(self, hidden_size, eps, num_attention_heads, embed, num_layers, seq_len):
self.hidden_size = hidden_size
self.eps = eps
self.num_attention_heads = num_attention_heads
self.embed = embed
self.num_layers = num_layers
self.seq_len = seq_len
model_configs = {
"small": ModelConfig(128, 1e-5, 8, 36, 4, 128),
"126m": ModelConfig(768, 1e-5, 12, 64, 12, 2048),
}
model_configs_inference = {
# hidden_size, eps, num_attention_heads, embed, num_layers, seq_len
"126m": ModelConfig(768, 1e-5, 12, 64, 12, 16),
}
backends_inference = ["FlashAttention", "UnfusedAttention"]
module_inference = ["TransformerLayer", "MultiheadAttention"]
input_formats_inference = ["sbhd", "bshd"]
param_types = [torch.float32, torch.float16]
if is_bf16_compatible(): # bf16 requires sm_80 or higher
param_types.append(torch.bfloat16)
batch_sizes = [1, 2]
all_boolean = [True, False]
all_activations = ["gelu", "relu", "reglu", "geglu", "swiglu", "qgelu", "srelu"]
all_normalizations = ["LayerNorm", "RMSNorm"]
mask_types = ["causal", "no_mask"]
fp8_recipes = [
recipe.MXFP8BlockScaling(),
recipe.DelayedScaling(),
recipe.Float8CurrentScaling(),
]
def get_causal_attn_mask(sq: int) -> torch.Tensor:
return torch.triu(torch.ones(sq, sq, device="cuda"), diagonal=1).bool()
def dtype_tols(dtype: torch.dtype) -> Dict[str, float]:
"""Estimated numerical error for a datatype
Based on tolerances for torch.testing.assert_close.
"""
if dtype == torch.float32:
return dict(rtol=1.3e-6, atol=1e-5)
if dtype == torch.float16:
return dict(rtol=1e-3, atol=1e-5)
if dtype == torch.bfloat16:
return dict(rtol=1.6e-2, atol=1e-5)
raise ValueError(f"Unsuppored dtype ({dtype})")
def assert_allclose(
    l1: List[torch.Tensor], l2: List[torch.Tensor], atol: float, rtol: Optional[float] = None
) -> None:
    """Ensure two lists of tensors match within tolerances, raising AssertionError otherwise."""
    assert len(l1) == len(l2), "Unequal number of outputs."
    for i, (t1, t2) in enumerate(zip(l1, l2)):
        tols = dict(atol=atol)
        if rtol is not None:
            tols["rtol"] = rtol
        result = torch.allclose(t1, t2, **tols)
        if not result:
            diff = torch.abs(t1 - t2)
            # Fall back to torch.allclose's default rtol when none was given
            effective_rtol = rtol if rtol is not None else 1e-5
            tol = atol + (effective_rtol * torch.abs(t2))
exceed_mask = diff > tol
if exceed_mask.any():
indices = torch.nonzero(exceed_mask, as_tuple=True)
max_diff = diff[exceed_mask].max()
max_idx = (diff[exceed_mask] == max_diff).nonzero(as_tuple=True)[0][0]
max_location = [idx[max_idx].item() for idx in indices]
msg = (
f"Outputs not close enough in tensor at idx={i}. "
f"Maximum difference at location {max_location} "
f"with {t1[exceed_mask][max_idx].item()} vs {t2[exceed_mask][max_idx].item()} "
f"(diff {max_diff.item()})."
)
raise AssertionError(msg)
def reset_rng_states() -> None:
"""revert back to initial RNG state."""
torch.set_rng_state(_cpu_rng_state)
torch.cuda.set_rng_state(_cuda_rng_state)
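# Reset FP8 global state after every test so recipe/autocast state does not leak
# between tests.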
@pytest.fixture(autouse=True)
def reset_global_fp8_state():
yield
FP8GlobalStateManager.reset()
class TorchScaledMaskedSoftmax(nn.Module):
def __init__(self) -> None:
super().__init__()
def forward(
self, inp: torch.Tensor, mask: torch.Tensor, scale: Optional[float] = None
) -> torch.Tensor:
dtype = inp.dtype
inp = inp.float()
if scale is not None:
inp = inp * scale
mask_output = attention_mask_func(inp, mask) if mask is not None else inp
probs = torch.nn.Softmax(dim=-1)(mask_output)
probs = probs.to(dtype)
return probs
class TorchDotProductAttention(torch.nn.Module):
def __init__(
self,
kv_channels: int,
attention_dropout: float = 0.0,
) -> None:
super().__init__()
self.norm_factor = math.sqrt(kv_channels)
self.scale_mask_softmax = TorchScaledMaskedSoftmax()
self.attention_dropout = torch.nn.Dropout(attention_dropout)
def forward(
self,
query_layer: torch.Tensor,
key_layer: torch.Tensor,
value_layer: torch.Tensor,
attention_mask: Optional[torch.Tensor] = None,
) -> torch.Tensor:
batch_size, seqlen = query_layer.shape[1], query_layer.shape[0]
# [b, np, sq, sk]
output_size = (
query_layer.size(1),
query_layer.size(2),
query_layer.size(0),
key_layer.size(0),
)
# [sq, b, np, hn] -> [sq, b * np, hn]
query_layer = query_layer.reshape(output_size[2], output_size[0] * output_size[1], -1)
# [sk, b, np, hn] -> [sk, b * np, hn]
key_layer = key_layer.reshape(output_size[3], output_size[0] * output_size[1], -1)
        # preallocating result tensor: [b * np, sq, sk]
matmul_result = torch.empty(
output_size[0] * output_size[1],
output_size[2],
output_size[3],
dtype=query_layer.dtype,
device=torch.cuda.current_device(),
)
# Raw attention scores. [b * np, sq, sk]
matmul_result = torch.baddbmm(
matmul_result,
query_layer.transpose(0, 1), # [b * np, sq, hn]
key_layer.transpose(0, 1).transpose(1, 2), # [b * np, hn, sk]
beta=0.0,
alpha=(1.0 / self.norm_factor),
)
# change view to [b, np, sq, sk]
attention_scores = matmul_result.view(*output_size)
# attention scores and attention mask [b, np, sq, sk]
attention_probs = self.scale_mask_softmax(attention_scores, attention_mask)
attention_probs = self.attention_dropout(attention_probs)
# value_layer -> context layer.
# [sk, b, np, hn] --> [b, np, sq, hn]
output_size = (
value_layer.size(1),
value_layer.size(2),
query_layer.size(0),
value_layer.size(3),
)
# change view [sk, b * np, hn]
value_layer = value_layer.reshape(value_layer.size(0), output_size[0] * output_size[1], -1)
# change view [b * np, sq, sk]
attention_probs = attention_probs.view(output_size[0] * output_size[1], output_size[2], -1)
# matmul: [b * np, sq, hn]
context_layer = torch.bmm(attention_probs, value_layer.transpose(0, 1))
# change view [b, np, sq, hn]
context_layer = context_layer.view(*output_size)
# [b, np, sq, hn] --> [sq, b, np, hn]
context_layer = context_layer.permute(2, 0, 1, 3).contiguous()
# [sq, b, np, hn] --> [sq, b, hp]
context_layer = context_layer.view(seqlen, batch_size, -1)
return context_layer
class TorchLayerNorm(nn.Module):
def __init__(self, in_features: int, eps: float, zero_centered_gamma: bool):
super().__init__()
self.eps = eps
self.in_features = in_features
self.zero_centered_gamma = zero_centered_gamma
initial_value = torch.ones(in_features) if zero_centered_gamma else torch.zeros(in_features)
self.weight = nn.Parameter(initial_value)
self.bias = nn.Parameter(torch.zeros(in_features))
self.register_parameter("weight", self.weight)
self.register_parameter("bias", self.bias)
def forward(self, x: torch.Tensor) -> torch.Tensor:
w = self.weight if not self.zero_centered_gamma else 1 + self.weight
w = w.to(torch.float32)
b = self.bias.to(torch.float32)
inp = x.to(torch.float32)
out = torch.nn.functional.layer_norm(
inp, (self.in_features,), weight=w, bias=b, eps=self.eps
)
return out.to(x.dtype)
# Adapted from https://github.com/bzhangGo/rmsnorm/blob/c6691f20ec0af4128c8159c903071f7575404295/rmsnorm_torch.py
class TorchRMSNorm(nn.Module):
def __init__(self, in_features, zero_centered_gamma, eps=1e-5):
super().__init__()
self.eps = eps
self.in_features = in_features
self.zero_centered_gamma = zero_centered_gamma
initial_value = torch.ones(in_features) if zero_centered_gamma else torch.zeros(in_features)
self.weight = nn.Parameter(initial_value)
self.register_parameter("weight", self.weight)
def forward(self, x):
norm_x2 = torch.sum(x.float() ** 2, dim=-1, keepdim=True)
d_x = self.in_features
rms_x2 = norm_x2 / d_x + self.eps
r_rms_x = rms_x2 ** (-1.0 / 2)
x_normed = x * r_rms_x
w = self.weight.float()
if self.zero_centered_gamma:
w = 1 + w
return (w * x_normed).to(x.dtype)
class TorchLayerNormLinear(nn.Module):
def __init__(
self,
in_features: int,
out_features: int,
eps: float,
bias: bool = True,
normalization: str = "LayerNorm",
zero_centered_gamma: bool = False,
):
super().__init__()
if normalization == "LayerNorm":
self.layernorm = TorchLayerNorm(
in_features, eps=eps, zero_centered_gamma=zero_centered_gamma
)
elif normalization == "RMSNorm":
self.layernorm = TorchRMSNorm(
in_features, eps=eps, zero_centered_gamma=zero_centered_gamma
)
else:
raise RuntimeError("Unsupported normalization")
self.linear = nn.Linear(in_features, out_features)
def forward(self, x: torch.Tensor) -> torch.Tensor:
return self.linear(self.layernorm(x))
class TorchMHA(nn.Module):
def __init__(self, hidden_size: int, num_attention_heads: int):
super().__init__()
self.mhsa = nn.MultiheadAttention(
embed_dim=hidden_size,
num_heads=num_attention_heads,
dropout=0.1,
bias=True,
batch_first=False,
)
def forward(self, x, attention_mask=None):
output = self.mhsa(x, x, x, attn_mask=attention_mask, need_weights=False)
if isinstance(output, tuple):
output = output[0]
return output
class TorchQuickGELU(nn.Module):
def forward(self, input: torch.Tensor) -> torch.Tensor:
return input * torch.sigmoid(1.702 * input)
class TorchSquaredRELU(nn.Module):
def forward(self, input: torch.Tensor) -> torch.Tensor:
return (input > 0) * input * input
class TorchGroupedLinearWithPadding(nn.Module):
def __init__(
self, num_gemms, in_features, out_features, bias, params_dtype, parallel_mode, fp8
) -> None:
super().__init__()
self.padding = Fp8Padding(num_gemms)
self.linear_fn = GroupedLinear(
num_gemms,
in_features,
out_features,
bias=bias,
params_dtype=params_dtype,
parallel_mode=parallel_mode,
device="cuda",
)
self.unpadding = Fp8Unpadding(num_gemms)
self.fp8 = fp8
def forward(self, inp: torch.Tensor, m_splits: List[int]) -> torch.Tensor:
if self.fp8:
orig_m_splits = m_splits
inp, m_splits = self.padding(inp, m_splits)
out = self.linear_fn(inp, m_splits)
if self.fp8:
out = self.unpadding(out, orig_m_splits)
return out
_supported_act = {
"geglu": nn.GELU(approximate="tanh"),
"gelu": nn.GELU(approximate="tanh"),
"reglu": nn.ReLU(),
"relu": nn.ReLU(),
"swiglu": nn.SiLU(),
"qgelu": TorchQuickGELU(),
"srelu": TorchSquaredRELU(),
}
class TorchGLU(nn.Module):
def __init__(self, activation: str):
super().__init__()
self.act = _supported_act[activation]
def forward(self, x):
shape = x.size(-1)
a = x[..., : shape // 2]
b = x[..., (shape // 2) :]
a = self.act(a)
return a * b
class TorchLayerNormMLP(nn.Module):
def __init__(
self,
hidden_size: int,
ffn_hidden_size: int,
eps: float = 1e-5,
activation="gelu",
normalization: str = "LayerNorm",
):
super().__init__()
if normalization == "LayerNorm":
self.ln = TorchLayerNorm(hidden_size, eps=eps, zero_centered_gamma=False)
elif normalization == "RMSNorm":
self.ln = TorchRMSNorm(hidden_size, eps=eps, zero_centered_gamma=False)
else:
raise RuntimeError("Unsupported normalization")
if "glu" in activation:
fc1_output_features = 2 * ffn_hidden_size
self.gelu = TorchGLU(activation)
else:
fc1_output_features = ffn_hidden_size
self.gelu = _supported_act[activation]
self.fc1 = nn.Linear(hidden_size, fc1_output_features)
self.fc2 = nn.Linear(ffn_hidden_size, hidden_size)
def forward(self, x):
t = self.gelu(self.fc1(self.ln(x)))
return self.fc2(t)
class TorchGPT(nn.Module):
def __init__(
self, hidden_size: int, eps: float, num_attention_heads: int, parallel_attention_mlp: bool
):
super().__init__()
self.ln = nn.LayerNorm(hidden_size, eps=eps)
self.causal_attn = TorchMHA(hidden_size, num_attention_heads)
self.ln_mlp = TorchLayerNormMLP(hidden_size, 4 * hidden_size, eps)
self.parallel_attention_mlp = parallel_attention_mlp
def forward(
self,
x: torch.Tensor,
attention_mask: Optional[torch.Tensor] = None,
) -> torch.Tensor:
a = self.ln(x)
b = self.causal_attn(a, attention_mask)
if self.parallel_attention_mlp:
n = self.ln_mlp(x)
x = x + nn.functional.dropout(b + n, p=0.1, training=self.training)
else:
x = x + nn.functional.dropout(b, p=0.1, training=self.training)
n = self.ln_mlp(x)
x = x + nn.functional.dropout(n, p=0.1, training=self.training)
return x
def _test_e2e_selective_recompute(
bs, dtype, config, fp8, recipe, fp8_model_params=False, recompute=False
):
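    # Run a single TransformerLayer forward/backward, optionally checkpointing
    # the core attention (selective recompute), and return the output, input
    # grad, and all parameter grads for comparison.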
reset_rng_states()
FP8GlobalStateManager.reset()
sigma = 0.023
init_method = init_method_normal(sigma)
output_layer_init_method = scaled_init_method_normal(sigma, config.num_layers)
with fp8_model_init(enabled=fp8 and fp8_model_params, recipe=recipe):
block = TransformerLayer(
config.hidden_size,
4 * config.hidden_size,
config.num_attention_heads,
layernorm_epsilon=config.eps,
init_method=init_method,
output_layer_init_method=output_layer_init_method,
hidden_dropout=0.1,
attention_dropout=0.1,
kv_channels=config.embed,
apply_residual_connection_post_layernorm=False,
output_layernorm=False,
params_dtype=dtype,
fuse_qkv_params=True,
device="cuda",
)
te_inp_hidden_states = torch.randn(
(config.seq_len, bs, config.hidden_size),
dtype=dtype,
device="cuda",
requires_grad=True,
)
te_inp_hidden_states.retain_grad()
te_inp_attn_mask = get_causal_attn_mask(config.seq_len)
with fp8_autocast(enabled=fp8, fp8_recipe=recipe):
te_out = block(
te_inp_hidden_states,
attention_mask=te_inp_attn_mask,
checkpoint_core_attention=recompute,
)
loss = te_out.sum()
loss.backward()
torch.cuda.synchronize()
outputs = [te_out, te_inp_hidden_states.grad]
for p in block.parameters():
if p.requires_grad:
outputs.append(p.grad)
return outputs
@pytest.mark.parametrize("dtype", param_types)
@pytest.mark.parametrize("bs", batch_sizes)
@pytest.mark.parametrize("model", ["126m"])
@pytest.mark.parametrize("fp8", all_boolean)
@pytest.mark.parametrize("recipe", fp8_recipes)
@pytest.mark.parametrize("fp8_model_params", all_boolean)
def test_gpt_selective_activation_recompute(dtype, bs, model, fp8, recipe, fp8_model_params):
if fp8 and not fp8_available:
pytest.skip(reason_for_no_fp8)
if recipe.mxfp8() and not mxfp8_available:
pytest.skip(reason_for_no_mxfp8)
config = model_configs[model]
outputs = _test_e2e_selective_recompute(
bs, dtype, config, fp8, recipe, fp8_model_params, recompute=False
)
outputs_recompute = _test_e2e_selective_recompute(
bs, dtype, config, fp8, recipe, fp8_model_params, recompute=True
)
# Check that results match
tols = dtype_tols(dtype)
if dtype in (torch.float16, torch.bfloat16):
tols["atol"] = 1e-4
if fp8 or fp8_model_params:
tols.update(dict(rtol=0.125, atol=0.0675))
for i, (ref, test) in enumerate(zip(outputs, outputs_recompute)):
torch.testing.assert_close(
test,
ref,
msg=f"Mismatch in tensor {i}",
**tols,
)
def _test_e2e_full_recompute(
bs, dtype, config, fp8, recipe, fp8_model_params=False, recompute=False, use_reentrant=True
):
reset_rng_states()
FP8GlobalStateManager.reset()
sigma = 0.023
init_method = init_method_normal(sigma)
output_layer_init_method = scaled_init_method_normal(sigma, config.num_layers)
with fp8_model_init(enabled=fp8 and fp8_model_params, recipe=recipe):
block = TransformerLayer(
config.hidden_size,
4 * config.hidden_size,
config.num_attention_heads,
layernorm_epsilon=config.eps,
init_method=init_method,
output_layer_init_method=output_layer_init_method,
hidden_dropout=0.1,
attention_dropout=0.1,
kv_channels=config.embed,
apply_residual_connection_post_layernorm=False,
output_layernorm=False,
params_dtype=dtype,
fuse_qkv_params=True,
device="cuda",
)
te_inp_hidden_states = torch.randn(
(config.seq_len, bs, config.hidden_size),
dtype=dtype,
device="cuda",
requires_grad=use_reentrant,
)
if use_reentrant:
te_inp_hidden_states.retain_grad()
te_inp_attn_mask = get_causal_attn_mask(config.seq_len)
with fp8_autocast(enabled=fp8, fp8_recipe=recipe):
if recompute:
te_out = te_checkpoint(
block,
te_inp_hidden_states,
attention_mask=te_inp_attn_mask,
checkpoint_core_attention=False,
distribute_saved_activations=False,
tp_group=None,
use_reentrant=use_reentrant,
)
else:
te_out = block(
te_inp_hidden_states,
attention_mask=te_inp_attn_mask,
checkpoint_core_attention=False,
)
loss = te_out.sum()
loss.backward()
torch.cuda.synchronize()
outputs = [te_out]
names = ["output"]
if use_reentrant:
outputs.append(te_inp_hidden_states.grad)
names.append("input")
for name, p in block.named_parameters():
if p.requires_grad:
outputs.append(p.grad)
names.append(name)
return outputs, names
@pytest.mark.parametrize("dtype", param_types)
@pytest.mark.parametrize("bs", batch_sizes)
@pytest.mark.parametrize("model", ["126m"])
@pytest.mark.parametrize("fp8", all_boolean)
@pytest.mark.parametrize("recipe", fp8_recipes)
@pytest.mark.parametrize("fp8_model_params", all_boolean)
@pytest.mark.parametrize("use_reentrant", all_boolean)
def test_gpt_full_activation_recompute(
dtype, bs, model, fp8, recipe, fp8_model_params, use_reentrant
):
if fp8 and not fp8_available:
pytest.skip(reason_for_no_fp8)
if recipe.mxfp8() and not mxfp8_available:
pytest.skip(reason_for_no_mxfp8)
if fp8 and recipe.float8_current_scaling():
pytest.skip("Float8 Current Scaling unsupported for full recompute.")
config = model_configs[model]
if not use_reentrant:
# Non-reentrant checkpoint becomes non-deterministic with bias+GELU fusion
os.environ["NVTE_BIAS_GELU_NVFUSION"] = "0"
outputs, names = _test_e2e_full_recompute(
bs,
dtype,
config,
fp8,
recipe,
fp8_model_params,
recompute=False,
use_reentrant=use_reentrant,
)
outputs_recompute, _ = _test_e2e_full_recompute(
bs,
dtype,
config,
fp8,
recipe,
fp8_model_params,
recompute=True,
use_reentrant=use_reentrant,
)
if not use_reentrant:
# Reset bias+GELU fusion flag to avoid contaminating other tests
del os.environ["NVTE_BIAS_GELU_NVFUSION"]
# Check that results match
tols = dtype_tols(dtype)
if dtype in (torch.float16, torch.bfloat16):
tols["atol"] = 1e-3
if fp8 or fp8_model_params:
tols.update(dict(rtol=0.125, atol=0.0675))
for i, (ref, test) in enumerate(zip(outputs, outputs_recompute)):
torch.testing.assert_close(
test,
ref,
msg=f"Mismatch in tensor {i}",
**tols,
)
def _test_e2e_checkpointing_get_model(config, dtype):
sigma = 0.023
init_method = init_method_normal(sigma)
output_layer_init_method = scaled_init_method_normal(sigma, config.num_layers)
return TransformerLayer(
config.hidden_size,
4 * config.hidden_size,
config.num_attention_heads,
layernorm_epsilon=config.eps,
init_method=init_method,
output_layer_init_method=output_layer_init_method,
hidden_dropout=0.1,
attention_dropout=0.1,
kv_channels=config.embed,
apply_residual_connection_post_layernorm=False,
output_layernorm=False,
params_dtype=dtype,
device="cuda",
)
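# Trains for `steps` iterations; with checkpoint=True, the model is saved and reloaded halfway
# through, restoring the accumulated gradients and RNG state before continuing.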
def _test_e2e_checkpointing(bs, dtype, config, checkpoint=False, steps=10, path="checkpoint.pt"):
reset_rng_states()
te_inp_hidden_states = torch.randn(
(config.seq_len, bs, config.hidden_size),
dtype=dtype,
device="cuda",
requires_grad=True,
)
te_inp_hidden_states.retain_grad()
block = _test_e2e_checkpointing_get_model(config, dtype)
for _ in range(steps // 2):
te_out = block(
te_inp_hidden_states,
None,
)
loss = te_out.sum()
loss.backward()
if checkpoint:
# This process is necessary so that we can start afresh with
# a new model while erasing all internal state to ensure that
# loading from a checkpoint gives bitwise identical results.
# Since gradients are being accumulated, it is important to
# restore them post loading the checkpoint.
torch.save(block.state_dict(), path)
param_grads = []
for p in block.parameters():
if p.requires_grad:
param_grads.append(p.grad.clone())
global _cpu_rng_state, _cuda_rng_state
_cpu_rng_state = torch.get_rng_state()
_cuda_rng_state = torch.cuda.get_rng_state()
del block
block = _test_e2e_checkpointing_get_model(config, dtype)
block.load_state_dict(torch.load(path, weights_only=False))
reset_rng_states()
for p in block.parameters():
if p.requires_grad:
p.grad = param_grads.pop(0)
        assert not param_grads, "Leftover gradients after restoring the checkpoint."
for _ in range(steps // 2):
te_out = block(
te_inp_hidden_states,
None,
)
loss = te_out.sum()
loss.backward()
torch.cuda.synchronize()
if os.path.exists(path):
os.remove(path)
outputs = [te_out, te_inp_hidden_states.grad]
for p in block.parameters():
if p.requires_grad:
outputs.append(p.grad)
return outputs
@pytest.mark.parametrize("dtype", param_types)
@pytest.mark.parametrize("bs", batch_sizes)
@pytest.mark.parametrize("model", ["126m"])
def test_gpt_checkpointing(dtype, bs, model):
config = model_configs[model]
outputs = _test_e2e_checkpointing(bs, dtype, config, checkpoint=False)
outputs_checkpoint = _test_e2e_checkpointing(bs, dtype, config, checkpoint=True)
# Check that results match
tols = dtype_tols(dtype)
if dtype in (torch.float16, torch.bfloat16):
tols.update(dict(rtol=2e-2, atol=2e-3))
for i, (ref, test) in enumerate(zip(outputs, outputs_checkpoint)):
torch.testing.assert_close(
test,
ref,
msg=f"Mismatch in tensor {i}",
**tols,
)
def _test_e2e_gpt_accuracy(block, bs, dtype, config):
reset_rng_states()
inp_hidden_states = torch.randn(
(config.seq_len, bs, config.hidden_size),
dtype=dtype,
device="cuda",
requires_grad=True,
)
inp_hidden_states.retain_grad()
inp_attn_mask = get_causal_attn_mask(config.seq_len)
out = block(inp_hidden_states, attention_mask=inp_attn_mask)
loss = out.sum()
loss.backward()
torch.cuda.synchronize()
outputs = [out, inp_hidden_states.grad]
for p in block.parameters():
if p.requires_grad:
outputs.append(p.grad)
return outputs
@pytest.mark.parametrize("dtype", param_types)
@pytest.mark.parametrize("bs", batch_sizes)
@pytest.mark.parametrize("model", ["small"])
@pytest.mark.parametrize("parallel_attention_mlp", all_boolean)
def test_gpt_accuracy(dtype, bs, model, parallel_attention_mlp):
config = model_configs[model]
te_gpt = TransformerLayer(
hidden_size=config.hidden_size,
ffn_hidden_size=4 * config.hidden_size,
num_attention_heads=config.num_attention_heads,
layernorm_epsilon=config.eps,
attention_dropout=0.1,
hidden_dropout=0.1,
params_dtype=dtype,
fuse_qkv_params=True,
qkv_weight_interleaved=False,
parallel_attention_mlp=parallel_attention_mlp,
device="cuda",
).eval()
torch_gpt = (
TorchGPT(
config.hidden_size,
config.eps,
config.num_attention_heads,
parallel_attention_mlp=parallel_attention_mlp,
)
.to(dtype=dtype)
.cuda()
.eval()
)
# Share params
with torch.no_grad():
torch_gpt.ln.weight = Parameter(
te_gpt.self_attention.layernorm_qkv.layer_norm_weight.clone()
)
torch_gpt.ln.bias = Parameter(te_gpt.self_attention.layernorm_qkv.layer_norm_bias.clone())
torch_gpt.causal_attn.mhsa.in_proj_weight = Parameter(
te_gpt.self_attention.layernorm_qkv.weight.clone()
)
torch_gpt.causal_attn.mhsa.in_proj_bias = Parameter(
te_gpt.self_attention.layernorm_qkv.bias.clone()
)
torch_gpt.causal_attn.mhsa.out_proj.weight = Parameter(
te_gpt.self_attention.proj.weight.clone()
)
torch_gpt.causal_attn.mhsa.out_proj.bias = Parameter(
te_gpt.self_attention.proj.bias.clone()
)
torch_gpt.ln_mlp.ln.weight = Parameter(te_gpt.layernorm_mlp.layer_norm_weight.clone())
torch_gpt.ln_mlp.ln.bias = Parameter(te_gpt.layernorm_mlp.layer_norm_bias.clone())
torch_gpt.ln_mlp.fc1.weight = Parameter(te_gpt.layernorm_mlp.fc1_weight.clone())
torch_gpt.ln_mlp.fc1.bias = Parameter(te_gpt.layernorm_mlp.fc1_bias.clone())
torch_gpt.ln_mlp.fc2.weight = Parameter(te_gpt.layernorm_mlp.fc2_weight.clone())
torch_gpt.ln_mlp.fc2.bias = Parameter(te_gpt.layernorm_mlp.fc2_bias.clone())
te_outputs = _test_e2e_gpt_accuracy(te_gpt, bs, dtype, config)
torch_outputs = _test_e2e_gpt_accuracy(torch_gpt, bs, dtype, config)
atol = {
torch.float32: 5e-3,
torch.half: 5e-2,
torch.bfloat16: 1e-1,
}
# Check output.
assert_allclose(te_outputs[0], torch_outputs[0], atol[dtype])
# Check gradients, only for small model
if model == "small":
atol[torch.float32] = 5e-2
rtol = {
torch.float32: 1e-2,
torch.half: 1e-2,
torch.bfloat16: 1e-2,
}
for te_output, torch_output in zip(te_outputs[1:], torch_outputs[1:]):
assert_allclose(te_output, torch_output, atol[dtype], rtol[dtype])
def _test_mha_accuracy(block, bs, dtype, config, mask_type, te=True):
reset_rng_states()
inp_hidden_states = torch.randn(
(config.seq_len, bs, config.hidden_size),
dtype=dtype,
device="cuda",
requires_grad=True,
)
inp_hidden_states.retain_grad()
inp_attn_mask = get_causal_attn_mask(config.seq_len) if mask_type == "causal" else None
forward_kwargs = {}
if te:
forward_kwargs["attn_mask_type"] = mask_type
forward_kwargs["attention_mask"] = inp_attn_mask
out = block(inp_hidden_states, **forward_kwargs)
loss = out.sum()
loss.backward()
torch.cuda.synchronize()
outputs = [out, inp_hidden_states.grad]
for p in block.parameters():
if p.requires_grad:
outputs.append(p.grad)
return outputs
@pytest.mark.parametrize("dtype", param_types)
@pytest.mark.parametrize("bs", batch_sizes)
@pytest.mark.parametrize("model", ["small"])
@pytest.mark.parametrize("mask_type", mask_types)
def test_mha_accuracy(dtype, bs, model, mask_type):
config = model_configs[model]
te_mha = MultiheadAttention(
config.hidden_size,
config.num_attention_heads,
fuse_qkv_params=True,
params_dtype=dtype,
qkv_weight_interleaved=False,
input_layernorm=False,
device="cuda",
).eval()
torch_mha = (
TorchMHA(
config.hidden_size,
config.num_attention_heads,
)
.to(dtype=dtype)
.cuda()
.eval()
)
# Share params
with torch.no_grad():
torch_mha.mhsa.in_proj_weight = Parameter(te_mha.qkv.weight.clone())
torch_mha.mhsa.in_proj_bias = Parameter(te_mha.qkv.bias.clone())
torch_mha.mhsa.out_proj.weight = Parameter(te_mha.proj.weight.clone())
torch_mha.mhsa.out_proj.bias = Parameter(te_mha.proj.bias.clone())
te_outputs = _test_mha_accuracy(te_mha, bs, dtype, config, mask_type, te=True)
torch_outputs = _test_mha_accuracy(torch_mha, bs, dtype, config, mask_type, te=False)
# Check output.
if dtype == torch.float32:
assert_allclose(te_outputs[0], torch_outputs[0], 5e-3)
else:
assert_allclose(te_outputs[0], torch_outputs[0], 5e-2)
# Check gradients, only for small model
if model == "small":
atol = {
torch.float32: 5e-2,
torch.half: 5e-2,
torch.bfloat16: 5e-2,
}
rtol = {
torch.float32: 1e-2,
torch.half: 1e-2,
torch.bfloat16: 1e-2,
}
for te_output, torch_output in zip(te_outputs[1:], torch_outputs[1:]):
assert_allclose(te_output, torch_output, atol[dtype], rtol[dtype])
def _test_granular_accuracy(block, bs, dtype, config):
reset_rng_states()
inp_hidden_states = torch.randn(
(config.seq_len, bs, config.hidden_size),
dtype=dtype,
device="cuda",
requires_grad=True,
)
inp_hidden_states.retain_grad()
out = block(inp_hidden_states)
loss = out.sum()
loss.backward()
torch.cuda.synchronize()
outputs = [out, inp_hidden_states.grad]
for p in block.parameters():
if p.requires_grad:
outputs.append(p.grad)
return outputs
def _test_dpa_accuracy(block, bs, dtype, config):
reset_rng_states()
mask = torch.triu(
torch.ones(config.seq_len, config.seq_len, dtype=torch.bool, device="cuda"), diagonal=1
)
query, key, value = [
torch.randn(
(config.seq_len, bs, config.num_attention_heads, config.embed),
dtype=dtype,
device="cuda",
requires_grad=True,
)
for _ in range(3)
]
query.retain_grad()
key.retain_grad()
value.retain_grad()
out = block(query, key, value, attention_mask=mask)
loss = out.sum()
loss.backward()
torch.cuda.synchronize()
return [out, query.grad, key.grad, value.grad]
@pytest.mark.parametrize("dtype", param_types)
@pytest.mark.parametrize("bs", batch_sizes)
@pytest.mark.parametrize("model", ["126m"])
def test_dpa_accuracy(dtype, bs, model):
config = model_configs[model]
te_dpa = (
DotProductAttention(
config.num_attention_heads,
config.embed,
attention_dropout=0.0, # disable dropout, FU uses rng differently
)
.to(dtype=dtype)
.cuda()
)
torch_dpa = (
TorchDotProductAttention(
config.embed,
0.0, # dropout
)
.to(dtype=dtype)
.cuda()
)
te_outputs = _test_dpa_accuracy(te_dpa, bs, dtype, config)
torch_outputs = _test_dpa_accuracy(torch_dpa, bs, dtype, config)
# Check output.
if dtype == torch.float32:
assert_allclose(te_outputs[0], torch_outputs[0], 5e-3)
else:
assert_allclose(te_outputs[0], torch_outputs[0], 5e-2)
for te_output, torch_output in zip(te_outputs[1:], torch_outputs[1:]):
assert_allclose(te_output, torch_output, atol=5e-2, rtol=1e-2)
@pytest.mark.parametrize("dtype", param_types)
@pytest.mark.parametrize("bs", batch_sizes)
@pytest.mark.parametrize("model", ["small"])
def test_linear_accuracy(dtype, bs, model):
config = model_configs[model]
te_linear = Linear(
config.hidden_size,
4 * config.hidden_size,
bias=True,
params_dtype=dtype,
device="cuda",
).eval()
torch_linear = torch.nn.Linear(
config.hidden_size,
4 * config.hidden_size,
bias=True,
device="cuda",
dtype=dtype,
).eval()
# Share params
with torch.no_grad():
torch_linear.weight = Parameter(te_linear.weight.clone())
torch_linear.bias = Parameter(te_linear.bias.clone())
te_outputs = _test_granular_accuracy(te_linear, bs, dtype, config)
torch_outputs = _test_granular_accuracy(torch_linear, bs, dtype, config)
# Check output.
if model == "small":
tolerance = 5e-3 if dtype == torch.float32 else 5e-2
rtol = {
torch.float32: 1.3e-6,
torch.half: 1e-2,
torch.bfloat16: 2e-2,
}
for te_output, torch_output in zip(te_outputs, torch_outputs):
assert_allclose(te_output, torch_output, tolerance, rtol[dtype])
@pytest.mark.parametrize("dtype", param_types)
@pytest.mark.parametrize("bs", batch_sizes)
@pytest.mark.parametrize("model", ["126m"])
@pytest.mark.parametrize("eps", [1e-1, 1e-3, 1e-5, 1e-7])
@pytest.mark.parametrize("zero_centered_gamma", all_boolean)
def test_rmsnorm_accuracy(dtype, bs, model, eps, zero_centered_gamma):
config = model_configs[model]
te_rmsnorm = RMSNorm(
config.hidden_size,
eps=eps,
params_dtype=dtype,
zero_centered_gamma=zero_centered_gamma,
device="cuda",
).eval()
torch_rmsnorm = (
TorchRMSNorm(config.hidden_size, eps=eps, zero_centered_gamma=zero_centered_gamma)
.to(dtype=dtype)
.cuda()
.eval()
)
# Share params
with torch.no_grad():
torch_rmsnorm.weight = Parameter(te_rmsnorm.weight.clone())
te_outputs = _test_granular_accuracy(te_rmsnorm, bs, dtype, config)
torch_outputs = _test_granular_accuracy(torch_rmsnorm, bs, dtype, config)
atol = {
torch.float32: 1e-7,
torch.half: 2e-3,
torch.bfloat16: 2e-2,
}
# Check output.
assert_allclose(te_outputs[0], torch_outputs[0], atol[dtype])
atol[torch.float32] = 2e-3
rtol = {
torch.float32: 1.3e-6,
torch.half: 1e-3,
torch.bfloat16: 1.6e-2,
}
# Check gradients
for te_output, torch_output in zip(te_outputs[1:], torch_outputs[1:]):
assert_allclose(te_output, torch_output, atol[dtype], rtol[dtype])
@pytest.mark.parametrize("dtype", param_types)
@pytest.mark.parametrize("bs", batch_sizes)
@pytest.mark.parametrize("model", ["126m"])
@pytest.mark.parametrize("eps", [1e-1, 1e-3, 1e-5, 1e-7])
@pytest.mark.parametrize("zero_centered_gamma", all_boolean)
def test_layernorm_accuracy(dtype, bs, model, eps, zero_centered_gamma):
config = model_configs[model]
te_layernorm = LayerNorm(
config.hidden_size,
eps=eps,
params_dtype=dtype,
zero_centered_gamma=zero_centered_gamma,
device="cuda",
).eval()
torch_layernorm = (
TorchLayerNorm(config.hidden_size, eps=eps, zero_centered_gamma=zero_centered_gamma)
.to(dtype=dtype)
.cuda()
.eval()
)
# Share params
with torch.no_grad():
torch_layernorm.weight = Parameter(te_layernorm.weight.clone())
torch_layernorm.bias = Parameter(te_layernorm.bias.clone())
te_outputs = _test_granular_accuracy(te_layernorm, bs, dtype, config)
torch_outputs = _test_granular_accuracy(torch_layernorm, bs, dtype, config)
atol = {
torch.float32: 1e-7,
torch.half: 2e-3,
torch.bfloat16: 2e-2,
}
# Check output.
assert_allclose(te_outputs[0], torch_outputs[0], atol[dtype])
rtol = {
torch.float32: 1.3e-6,
torch.half: 1e-3,
torch.bfloat16: 1.6e-2,
}
atol[torch.float32] = 1e-4
# Check gradients
for te_output, torch_output in zip(te_outputs[1:], torch_outputs[1:]):
assert_allclose(te_output, torch_output, atol[dtype], rtol[dtype])
@pytest.mark.parametrize("dtype", param_types)
@pytest.mark.parametrize("bs", batch_sizes)
@pytest.mark.parametrize("model", ["small"])
@pytest.mark.parametrize("normalization", all_normalizations)
@pytest.mark.parametrize("zero_centered_gamma", all_boolean)
def test_layernorm_linear_accuracy(dtype, bs, model, normalization, zero_centered_gamma):
config = model_configs[model]
te_ln_linear = LayerNormLinear(
config.hidden_size,
4 * config.hidden_size,
config.eps,
bias=True,
normalization=normalization,
params_dtype=dtype,
zero_centered_gamma=zero_centered_gamma,
device="cuda",
).eval()
torch_ln_linear = (
TorchLayerNormLinear(
config.hidden_size,
4 * config.hidden_size,
config.eps,
bias=True,
normalization=normalization,
zero_centered_gamma=zero_centered_gamma,
)
.to(dtype=dtype)
.cuda()
.eval()
)
# Share params
with torch.no_grad():
torch_ln_linear.layernorm.weight = Parameter(te_ln_linear.layer_norm_weight.clone())
if normalization != "RMSNorm":
torch_ln_linear.layernorm.bias = Parameter(te_ln_linear.layer_norm_bias.clone())
torch_ln_linear.linear.weight = Parameter(te_ln_linear.weight.clone())
torch_ln_linear.linear.bias = Parameter(te_ln_linear.bias.clone())
te_outputs = _test_granular_accuracy(te_ln_linear, bs, dtype, config)
torch_outputs = _test_granular_accuracy(torch_ln_linear, bs, dtype, config)
atol = {
torch.float32: 2.5e-4,
torch.half: 2e-3,
torch.bfloat16: 2e-2,
}
rtol = {
torch.float32: 1e-3,
torch.half: 4e-2,
torch.bfloat16: 4e-2,
}
# Check output.
assert_allclose(te_outputs[0], torch_outputs[0], atol[dtype], rtol[dtype])
if model == "small":
atol = {
torch.float32: 1e-3,
torch.half: 5e-2,
torch.bfloat16: 5e-2,
}
rtol = {
torch.float32: 1e-3,
torch.half: 4e-2,
torch.bfloat16: 4e-2,
}
# Check gradients
for te_output, torch_output in zip(te_outputs[1:], torch_outputs[1:]):
assert_allclose(te_output, torch_output, atol[dtype], rtol[dtype])
@pytest.mark.parametrize("dtype", param_types)
@pytest.mark.parametrize("bs", batch_sizes)
@pytest.mark.parametrize("model", ["small"])
@pytest.mark.parametrize("activation", all_activations)
@pytest.mark.parametrize("normalization", all_normalizations)
def test_layernorm_mlp_accuracy(dtype, bs, model, activation, normalization):
config = model_configs[model]
te_ln_mlp = LayerNormMLP(
config.hidden_size,
4 * config.hidden_size,
activation=activation,
normalization=normalization,
params_dtype=dtype,
device="cuda",
).eval()
torch_ln_mlp = (
TorchLayerNormMLP(
config.hidden_size,
4 * config.hidden_size,
activation=activation,
normalization=normalization,
)
.to(dtype=dtype)
.cuda()
.eval()
)
# Share params
with torch.no_grad():
torch_ln_mlp.ln.weight = Parameter(te_ln_mlp.layer_norm_weight.clone())
if normalization != "RMSNorm":
torch_ln_mlp.ln.bias = Parameter(te_ln_mlp.layer_norm_bias.clone())
torch_ln_mlp.fc1.weight = Parameter(te_ln_mlp.fc1_weight.clone())
torch_ln_mlp.fc1.bias = Parameter(te_ln_mlp.fc1_bias.clone())
torch_ln_mlp.fc2.weight = Parameter(te_ln_mlp.fc2_weight.clone())
torch_ln_mlp.fc2.bias = Parameter(te_ln_mlp.fc2_bias.clone())
te_outputs = _test_granular_accuracy(te_ln_mlp, bs, dtype, config)
torch_outputs = _test_granular_accuracy(torch_ln_mlp, bs, dtype, config)
atol = {
torch.float32: 2e-2,
torch.half: 5e-2,
torch.bfloat16: 5e-2,
}
rtol = {
torch.float32: 1e-3,
torch.half: 4e-2,
torch.bfloat16: 4e-2,
}
# Check output.
assert_allclose(te_outputs[0], torch_outputs[0], atol[dtype], rtol[dtype])
# Check gradients, only for small model
rtol = {
torch.float32: 1e-3,
torch.half: 1e-2,
torch.bfloat16: 4e-2,
}
atol[torch.half] = 2e-1
atol[torch.bfloat16] = 2e-1
if model == "small":
for te_output, torch_output in zip(te_outputs[1:], torch_outputs[1:]):
assert_allclose(te_output, torch_output, atol[dtype], rtol[dtype])
def _test_grouped_linear_accuracy(
block, num_gemms, bs, dtype, config, recipe, fp8, fuse_wgrad_accumulation
):
reset_rng_states()
if fp8:
FP8GlobalStateManager.reset()
inp_hidden_states = torch.randn(
(config.seq_len, bs, config.hidden_size),
dtype=dtype,
device="cuda",
requires_grad=True,
)
inp_hidden_states.retain_grad()
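    # Build random m_splits that sum to seq_len. With FP8, each split is drawn in units of
    # split_size (16 for delayed scaling, 128 for MXFP8) to satisfy alignment requirements.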
if num_gemms > 1:
split_size = 1
if fp8:
if recipe.delayed():
split_size = 16
if recipe.mxfp8():
split_size = 128
m = config.seq_len // split_size
dist = torch.sort(torch.randint(0, m, (num_gemms - 2,))).values.tolist()
        dist.append(dist[-1])  # Duplicate the last breakpoint so that one split has size zero
m_splits = torch.tensor(dist + [m]) - torch.tensor([0] + dist)
m_splits = m_splits * split_size
assert m_splits.sum() == config.seq_len and len(m_splits) == num_gemms
else:
m_splits = torch.tensor([config.seq_len])
with fp8_autocast(enabled=fp8, fp8_recipe=recipe):
if isinstance(block, GroupedLinear):
m_splits = m_splits * bs
out = block(inp_hidden_states, m_splits.tolist())
else:
out = torch.cat(
[
block[i](inp)
for i, inp in enumerate(torch.split(inp_hidden_states, m_splits.tolist()))
]
)
loss = out.sum()
loss.backward()
torch.cuda.synchronize()
outputs = [out, inp_hidden_states.grad]
for p in block.parameters():
if p.requires_grad:
if getattr(p, "main_grad", None) is not None:
outputs.append(p.main_grad)
assert p.grad is None # grad should be None if fuse_wgrad_accumulation is True
else:
outputs.append(p.grad)
return outputs
@pytest.mark.parametrize("dtype", param_types)
@pytest.mark.parametrize("num_gemms", [3, 6])
@pytest.mark.parametrize("bs", batch_sizes)
@pytest.mark.parametrize("model", ["126m"])
@pytest.mark.parametrize("fp8", all_boolean)
@pytest.mark.parametrize("recipe", fp8_recipes)
@pytest.mark.parametrize("fp8_model_params", all_boolean)
@pytest.mark.parametrize("fuse_wgrad_accumulation", all_boolean)
def test_grouped_linear_accuracy(
dtype,
num_gemms,
bs,
model,
fp8,
recipe,
fp8_model_params,
fuse_wgrad_accumulation,
parallel_mode=None,
):
if fp8 and not fp8_available:
pytest.skip(reason_for_no_fp8)
if recipe.mxfp8() and not mxfp8_available:
pytest.skip(reason_for_no_mxfp8)
if fp8 and recipe.mxfp8(): # TODO(ksivamani): debug mismatches
pytest.skip("MXFP8 unsupported for grouped linear.")
if fp8 and recipe.float8_current_scaling():
pytest.skip("Float8 Current Scaling unsupported for grouped linear.")
config = model_configs[model]
if config.seq_len % 16 != 0 and fp8:
pytest.skip("FP8 requires sequence length to be divisible by 16.")
with fp8_model_init(enabled=fp8 and fp8_model_params, recipe=recipe):
grouped_linear = GroupedLinear(
num_gemms,
config.hidden_size,
4 * config.hidden_size,
bias=True,
params_dtype=dtype,
parallel_mode=parallel_mode,
device="cuda",
fuse_wgrad_accumulation=fuse_wgrad_accumulation,
).eval()
sequential_linear = torch.nn.ModuleList(
[
Linear(
config.hidden_size,
4 * config.hidden_size,
bias=True,
params_dtype=dtype,
parallel_mode=parallel_mode,
device="cuda",
fuse_wgrad_accumulation=fuse_wgrad_accumulation,
).eval()
for _ in range(num_gemms)
]
)
# Share params
with torch.no_grad():
for i in range(num_gemms):
sequential_linear[i].weight = Parameter(getattr(grouped_linear, f"weight{i}").clone())
sequential_linear[i].bias = Parameter(getattr(grouped_linear, f"bias{i}").clone())
if fuse_wgrad_accumulation:
weight_i = getattr(grouped_linear, f"weight{i}")
weight_i.main_grad = torch.rand_like(weight_i, dtype=torch.float32)
sequential_linear[i].weight.main_grad = weight_i.main_grad.clone()
outputs_ref = _test_grouped_linear_accuracy(
sequential_linear, num_gemms, bs, dtype, config, recipe, fp8, fuse_wgrad_accumulation
)
outputs = _test_grouped_linear_accuracy(
grouped_linear, num_gemms, bs, dtype, config, recipe, fp8, fuse_wgrad_accumulation
)
    # Should be a bit-wise match
for i, (o, o_ref) in enumerate(zip(outputs, outputs_ref)):
torch.testing.assert_close(o, o_ref, rtol=0, atol=0)
@pytest.mark.parametrize("parallel_mode", ["column", "row"])
@pytest.mark.parametrize("recipe", fp8_recipes)
def test_grouped_linear_accuracy_parallel_mode(parallel_mode, recipe):
"""Split the tests to save CI time"""
test_grouped_linear_accuracy(
dtype=torch.float32,
num_gemms=6,
bs=2,
model="126m",
fp8=True,
recipe=recipe,
fp8_model_params=True,
parallel_mode=parallel_mode,
fuse_wgrad_accumulation=True,
)
@pytest.mark.parametrize("recipe", fp8_recipes)
def test_grouped_linear_accuracy_single_gemm(recipe):
"""Split the tests to save CI time"""
test_grouped_linear_accuracy(
dtype=torch.float32,
num_gemms=1,
bs=2,
model="126m",
fp8=True,
recipe=recipe,
fp8_model_params=True,
fuse_wgrad_accumulation=True,
)
def _test_padding_grouped_linear_accuracy(block, num_gemms, bs, dtype, config, recipe, fp8=False):
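    # Pads each expert's token count up to a multiple of 16 for FP8 GEMMs,
    # e.g. tokens_per_expert [3, 18, 7] -> padded_tokens_per_expert [16, 32, 16].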
def _pad_tensor_for_fp8(hidden_states, tokens_per_expert):
"""Padding tensor shapes to multiples of 16."""
padded_tokens_per_expert = [
(num_tokens + 15) // 16 * 16 for num_tokens in tokens_per_expert
]
hidden_states = torch.split(hidden_states, tokens_per_expert)
padded_hidden_states = []
for hidden_state, actual_num_tokens, padded_num_tokens in zip(
hidden_states, tokens_per_expert, padded_tokens_per_expert
):
padded_hidden_states.append(hidden_state)
if padded_num_tokens > actual_num_tokens:
pad_tensor = torch.zeros(
padded_num_tokens - actual_num_tokens,
hidden_state.shape[1],
dtype=hidden_state.dtype,
device=hidden_state.device,
)
padded_hidden_states.append(pad_tensor)
padded_hidden_states = torch.cat(padded_hidden_states, dim=0)
return padded_hidden_states, padded_tokens_per_expert
def _unpad_tensor_for_fp8(padded_hidden_states, actual_tokens_per_expert, tokens_per_expert):
inputmats = torch.split(
padded_hidden_states.view(-1, padded_hidden_states.shape[-1]), tokens_per_expert
)
hidden_states = torch.cat(
[
grad_output_mat[: actual_tokens_per_expert[i]]
for i, grad_output_mat in enumerate(inputmats)
],
dim=0,
)
return hidden_states
def _generate_random_numbers(n, total_sum):
if n <= 0:
return []
# reset seed
random.seed(seed)
breaks = sorted(random.sample(range(1, total_sum), n - 1))
random_numbers = (
[breaks[0]]
+ [breaks[i] - breaks[i - 1] for i in range(1, n - 1)]
+ [total_sum - breaks[-1]]
)
return random_numbers
reset_rng_states()
if fp8:
FP8GlobalStateManager.reset()
inp_hidden_states = torch.randn(
(config.seq_len * bs, config.hidden_size),
dtype=dtype,
device="cuda",
requires_grad=True,
)
inp_hidden_states.retain_grad()
m_splits = _generate_random_numbers(num_gemms, config.seq_len * bs)
with fp8_autocast(enabled=fp8, fp8_recipe=recipe):
if isinstance(block, TorchGroupedLinearWithPadding):
out = block(inp_hidden_states, m_splits)
else:
if fp8:
padded_inp_hidden_states, padding_m_splits = _pad_tensor_for_fp8(
inp_hidden_states, m_splits
)
padded_inp_hidden_states = block(padded_inp_hidden_states, padding_m_splits)
out = _unpad_tensor_for_fp8(padded_inp_hidden_states, m_splits, padding_m_splits)
else:
out = block(inp_hidden_states, m_splits)
loss = out.sum()
loss.backward()
torch.cuda.synchronize()
outputs = [out, inp_hidden_states.grad]
for p in block.parameters():
if p.requires_grad:
outputs.append(p.grad)
return outputs
@pytest.mark.parametrize("dtype", param_types)
@pytest.mark.parametrize("num_gemms", [3, 6])
@pytest.mark.parametrize("bs", batch_sizes)
@pytest.mark.parametrize("model", ["126m"])
@pytest.mark.parametrize("fp8", [True])
@pytest.mark.parametrize("recipe", fp8_recipes)
@pytest.mark.parametrize("fp8_model_params", all_boolean)
def test_padding_grouped_linear_accuracy(
dtype, num_gemms, bs, model, fp8, recipe, fp8_model_params, parallel_mode=None
):
if fp8 and not fp8_available:
pytest.skip(reason_for_no_fp8)
if recipe.mxfp8() and not mxfp8_available:
pytest.skip(reason_for_no_mxfp8)
if fp8 and recipe.mxfp8(): # TODO(ksivamani): debug mismatches
pytest.skip("MXFP8 unsupported for grouped linear.")
if fp8 and recipe.float8_current_scaling():
pytest.skip("Float8 Current Scaling unsupported for grouped linear.")
config = model_configs[model]
if config.seq_len % 16 != 0 and fp8:
pytest.skip("FP8 requires sequence length to be divisible by 16.")
with fp8_model_init(enabled=fp8 and fp8_model_params, recipe=recipe):
grouped_linear = TorchGroupedLinearWithPadding(
num_gemms,
config.hidden_size,
4 * config.hidden_size,
bias=False,
params_dtype=dtype,
parallel_mode=parallel_mode,
fp8=fp8,
).eval()
with fp8_model_init(enabled=fp8 and fp8_model_params, recipe=recipe):
ref_grouped_linear = GroupedLinear(
num_gemms,
config.hidden_size,
4 * config.hidden_size,
bias=False,
params_dtype=dtype,
parallel_mode=parallel_mode,
device="cuda",
).eval()
# Share params
with torch.no_grad():
inner_grouped_linear = grouped_linear.linear_fn
for i in range(num_gemms):
setattr(
ref_grouped_linear,
f"weight{i}",
Parameter(getattr(inner_grouped_linear, f"weight{i}").clone()),
)
outputs = _test_padding_grouped_linear_accuracy(
grouped_linear, num_gemms, bs, dtype, config, recipe, fp8
)
outputs_ref = _test_padding_grouped_linear_accuracy(
ref_grouped_linear, num_gemms, bs, dtype, config, recipe, fp8
)
    # Should be a bit-wise match
for i, (o, o_ref) in enumerate(zip(outputs, outputs_ref)):
torch.testing.assert_close(o, o_ref, rtol=0, atol=0)
def _test_gpt_e2e_cuda_graph(block, bs, dtype, config, graph):
reset_rng_states()
# Initialize loss function and optimizer.
loss_fn = torch.nn.MSELoss()
optimizer = torch.optim.SGD(block.parameters(), lr=0.1)
# Placeholders used for graph capture.
static_input = torch.randn(
config.seq_len, bs, config.hidden_size, device="cuda", dtype=dtype, requires_grad=True
)
static_target = torch.randn(config.seq_len, bs, config.hidden_size, device="cuda", dtype=dtype)
real_input = torch.rand_like(static_input)
real_target = torch.rand_like(static_target)
# Basic training loop.
def train_step():
optimizer.zero_grad(set_to_none=False)
out = block(static_input)
loss = loss_fn(out, static_target)
loss.backward()
optimizer.step()
return out
    # Warmup steps in a separate stream (CUDA graph capture requires prior warmup iterations on a side stream).
s = torch.cuda.Stream()
s.wait_stream(torch.cuda.current_stream())
with torch.cuda.stream(s):
for _ in range(3):
train_step()
torch.cuda.current_stream().wait_stream(s)
# Capture graph.
g = None
static_output = None
if graph:
g = torch.cuda.CUDAGraph()
with torch.cuda.graph(g):
static_output = train_step()
# Run with new data.
with torch.no_grad():
static_input.copy_(real_input)
static_target.copy_(real_target)
if graph:
g.replay()
else:
static_output = train_step()
grads = [static_input.grad]
for p in block.parameters():
if p.requires_grad:
grads.append(p.grad)
with torch.no_grad():
output = static_output.clone()
return output, grads
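# Compares an eager training step with a CUDA-graph captured/replayed step on an identical
# TransformerLayer: outputs, parameters, and gradients should agree within tolerance.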
@pytest.mark.parametrize("dtype", param_types)
@pytest.mark.parametrize("bs", batch_sizes)
@pytest.mark.parametrize("model", ["126m"])
def test_gpt_cuda_graph(dtype, bs, model):
config = model_configs[model]
sigma = 0.023
init_method = init_method_normal(sigma)
output_layer_init_method = scaled_init_method_normal(sigma, config.num_layers)
block_args = (
config.hidden_size,
4 * config.hidden_size,
config.num_attention_heads,
)
block_kwargs = dict(
layernorm_epsilon=config.eps,
init_method=init_method,
output_layer_init_method=output_layer_init_method,
hidden_dropout=0.1,
attention_dropout=0.1,
kv_channels=config.embed,
params_dtype=dtype,
apply_residual_connection_post_layernorm=False,
output_layernorm=False,
device="cuda",
)
block = TransformerLayer(*block_args, **block_kwargs)
graphed_block = TransformerLayer(*block_args, **block_kwargs)
with torch.no_grad():
for param1, param2 in zip(block.parameters(), graphed_block.parameters()):
param2.copy_(param1)
out, grads = _test_gpt_e2e_cuda_graph(block, bs, dtype, config, False)
graphed_out, graphed_grads = _test_gpt_e2e_cuda_graph(graphed_block, bs, dtype, config, True)
params = list(block.parameters())
graphed_params = list(graphed_block.parameters())
# Check that results match
assert_allclose(out, graphed_out, 1e-3)
assert_allclose(params, graphed_params, 1e-3)
assert_allclose(grads, graphed_grads, 1e-3)
def _test_gpt_fp8_parameters(bs, dtype, config, fp8_model_params, recipe):
reset_rng_states()
FP8GlobalStateManager.reset()
sigma = 0.023
init_method = init_method_normal(sigma)
output_layer_init_method = scaled_init_method_normal(sigma, config.num_layers)
with fp8_model_init(enabled=fp8_model_params, recipe=recipe):
block = TransformerLayer(
config.hidden_size,
4 * config.hidden_size,
config.num_attention_heads,
layernorm_epsilon=config.eps,
init_method=init_method,
output_layer_init_method=output_layer_init_method,
hidden_dropout=0.1,
attention_dropout=0.1,
kv_channels=config.embed,
apply_residual_connection_post_layernorm=False,
output_layernorm=False,
params_dtype=dtype,
fuse_qkv_params=True,
device="cuda",
)
te_inp_hidden_states = torch.randn(
(config.seq_len, bs, config.hidden_size),
dtype=dtype,
device="cuda",
requires_grad=True,
)
te_inp_hidden_states.retain_grad()
te_inp_attn_mask = get_causal_attn_mask(config.seq_len)
with fp8_autocast(enabled=True, fp8_recipe=recipe):
te_out = block(te_inp_hidden_states, attention_mask=te_inp_attn_mask)
loss = te_out.sum()
loss.backward()
torch.cuda.synchronize()
outputs = [te_out, te_inp_hidden_states.grad]
for p in block.parameters():
if p.requires_grad:
outputs.append(p.grad)
return outputs
@pytest.mark.parametrize("dtype", param_types)
@pytest.mark.parametrize("bs", batch_sizes)
@pytest.mark.parametrize("model", ["126m"])
@pytest.mark.parametrize("recipe", fp8_recipes)
def test_gpt_fp8_parameters(dtype, bs, model, recipe):
if not fp8_available:
pytest.skip(reason_for_no_fp8)
if recipe.mxfp8() and not mxfp8_available:
pytest.skip(reason_for_no_mxfp8)
config = model_configs[model]
outputs = _test_gpt_fp8_parameters(bs, dtype, config, False, recipe)
outputs_fp8_params = _test_gpt_fp8_parameters(bs, dtype, config, True, recipe)
# Check that results match
    tols = dict(rtol=0.125, atol=0.0675)
    for i, (ref, test) in enumerate(zip(outputs, outputs_fp8_params)):
        torch.testing.assert_close(
            test,
            ref,
            msg=f"Mismatch in tensor {i}",
            **tols,
        )
@pytest.mark.parametrize("dtype", param_types)
@pytest.mark.parametrize("bs", batch_sizes)
@pytest.mark.parametrize("model", ["126m"])
def test_transformer_layer_hidden_states_format(dtype, bs, model):
config = model_configs[model]
sigma = 0.023
init_method = init_method_normal(sigma)
output_layer_init_method = scaled_init_method_normal(sigma, config.num_layers)
# Set `torch.manual_seed` to make sure the weights are identical to the
# other layer. Set `*dropout` values to 0 to make sure the forward pass
# is identical to the other layer.
torch.manual_seed(0)
block_sbhd = TransformerLayer(
config.hidden_size,
4 * config.hidden_size,
config.num_attention_heads,
layernorm_epsilon=config.eps,
init_method=init_method,
output_layer_init_method=output_layer_init_method,
hidden_dropout=0,
attention_dropout=0,
kv_channels=config.embed,
params_dtype=dtype,
apply_residual_connection_post_layernorm=False,
output_layernorm=False,
device="cuda",
attn_input_format="sbhd",
)
# Set `torch.manual_seed` to make sure the weights are identical to the
# other layer. Set `*dropout` values to 0 to make sure the forward pass
# is identical to the other layer.
torch.manual_seed(0)
block_bshd = TransformerLayer(
config.hidden_size,
4 * config.hidden_size,
config.num_attention_heads,
layernorm_epsilon=config.eps,
init_method=init_method,
output_layer_init_method=output_layer_init_method,
hidden_dropout=0,
attention_dropout=0,
kv_channels=config.embed,
params_dtype=dtype,
apply_residual_connection_post_layernorm=False,
output_layernorm=False,
device="cuda",
attn_input_format="bshd",
)
torch.manual_seed(0)
block_thd = TransformerLayer(
config.hidden_size,
4 * config.hidden_size,
config.num_attention_heads,
layernorm_epsilon=config.eps,
init_method=init_method,
output_layer_init_method=output_layer_init_method,
hidden_dropout=0,
attention_dropout=0,
kv_channels=config.embed,
params_dtype=dtype,
apply_residual_connection_post_layernorm=False,
output_layernorm=False,
device="cuda",
attn_input_format="thd",
self_attn_mask_type="padding_causal",
)
for (n1, p1), (n2, p2), (n3, p3) in zip(
block_bshd.named_parameters(), block_sbhd.named_parameters(), block_thd.named_parameters()
):
assert torch.all(torch.eq(p1, p2) & torch.eq(p1, p3)), f"{n1}, {n2} and {n3} not identical"
x_sbhd = torch.randn(
(config.seq_len, bs, config.hidden_size),
dtype=dtype,
device="cuda",
requires_grad=True,
)
x_bshd = x_sbhd.transpose(0, 1).contiguous()
x_thd = x_bshd.reshape(bs * config.seq_len, config.hidden_size).contiguous()
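    # Cumulative sequence lengths for the packed THD layout: [0, S, 2*S, ..., B*S]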
x_thd_cumsum = torch.arange(bs + 1, device="cuda", dtype=torch.int32) * config.seq_len
# To make sure forward is also identical (just in case some module decides
# to act fancy)
torch.manual_seed(0)
y_sbhd = block_sbhd(x_sbhd)
# To make sure forward is also identical (just in case some module decides
# to act fancy)
torch.manual_seed(0)
y_bshd = block_bshd(x_bshd)
# Check that results match
torch.testing.assert_close(
y_bshd,
y_sbhd.transpose(0, 1).contiguous(),
)
    # THD is not supported in float32 or on GPUs older than Ampere; skip the THD check in those cases
if dtype != torch.float32 and sm_80plus:
# To make sure forward is also identical (just in case some module decides
# to act fancy)
torch.manual_seed(0)
y_thd = block_thd(
x_thd,
cu_seqlens_q=x_thd_cumsum,
cu_seqlens_kv=x_thd_cumsum,
max_seqlen_q=config.seq_len,
max_seqlen_kv=config.seq_len,
)
torch.testing.assert_close(
y_bshd,
y_thd.reshape(bs, config.seq_len, config.hidden_size).contiguous(),
)
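# KV-cache test: the full sequence is processed in a single pass and again token by token with
# InferenceParams; the incrementally generated output must match the full-sequence output.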
@pytest.mark.parametrize("dtype", param_types)
@pytest.mark.parametrize("bs", batch_sizes)
@pytest.mark.parametrize("model_key", model_configs_inference.keys())
@pytest.mark.parametrize("use_RoPE", all_boolean)
@pytest.mark.parametrize("input_format", input_formats_inference)
@pytest.mark.parametrize("module", module_inference)
@pytest.mark.parametrize("backend", backends_inference)
def test_kv_cache_accuracy(dtype, bs, model_key, use_RoPE, input_format, module, backend):
os.environ["NVTE_FLASH_ATTN"] = "0"
os.environ["NVTE_FUSED_ATTN"] = "0"
if backend == "FlashAttention":
os.environ["NVTE_FLASH_ATTN"] = "1"
elif backend == "FusedAttention":
os.environ["NVTE_FUSED_ATTN"] = "1"
config = model_configs_inference[model_key]
S = config.seq_len
B = bs
H = config.num_attention_heads
D = config.hidden_size
head_size = config.embed
layer_number = 1
# Limits the max size of KV-cache
B_max = B
S_max = S + 2
if module == "TransformerLayer":
model = TransformerLayer(
hidden_size=D,
ffn_hidden_size=4 * D,
num_attention_heads=H,
attn_input_format=input_format,
self_attn_mask_type="causal",
enc_dec_attn_mask_type="causal",
layer_number=layer_number,
attention_dropout=0.0,
params_dtype=dtype,
device="cuda",
).eval()
else:
model = (
MultiheadAttention(
hidden_size=D,
num_attention_heads=H,
qkv_format=input_format,
layer_number=layer_number,
attention_dropout=0.0,
attn_mask_type="causal",
params_dtype=dtype,
)
.cuda()
.eval()
)
inference_params = InferenceParams(max_batch_size=B_max, max_sequence_length=S_max)
rotary_freqs = torch.randn((S_max, 1, 1, head_size), dtype=torch.float, device="cuda")
input = torch.randn((S, B, D), dtype=dtype, device="cuda")
if input_format == "bshd":
input = input.transpose(0, 1).contiguous()
incremental_output = torch.zeros_like(input)
# Generate output for the entire sequence
full_output = model(hidden_states=input, rotary_pos_emb=rotary_freqs if use_RoPE else None)
    # Incrementally generate outputs using the KV-cache
for i in range(S):
if input_format == "sbhd":
incremental_input = input[i].view(1, B, D)
else:
incremental_input = input[:, i, :].view(B, 1, D)
line_output = model(
hidden_states=incremental_input,
inference_params=inference_params,
rotary_pos_emb=rotary_freqs if use_RoPE else None,
)
inference_params.sequence_len_offset += 1
if input_format == "sbhd":
incremental_output[i] = line_output.view(B, D)
else:
incremental_output[:, i, :] = line_output.view(B, D)
if module == "TransformerLayer":
atol = {
torch.float32: 5e-3,
torch.half: 5e-3,
torch.bfloat16: 5e-2,
}
else:
atol = {
torch.float32: 1e-3,
torch.half: 1e-3,
torch.bfloat16: 1e-2,
}
# Check if the fully generated output matches the one generated incrementally
assert_allclose(full_output, incremental_output, atol[dtype])
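# Grouped GEMM layouts mirror the three GEMMs of a linear layer: "TN" is the forward GEMM,
# "NN" is dgrad, and "NT" is wgrad; each is checked against per-group general_gemm calls and
# must match bitwise.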
@pytest.mark.parametrize(
"shape",
[
(1, 127, 128, 512),
(8, 15, 128, 512),
(8, 1027, 128, 512),
(16, 10027, 128, 512),
],
)
@pytest.mark.parametrize("dtype", param_types)
@pytest.mark.parametrize("layout", ["TN", "NN", "NT"])
@pytest.mark.parametrize("accumulate", [False, True])
def test_grouped_gemm(shape, dtype, layout, accumulate):
torch.manual_seed(0)
z, m, k, n = shape
dist = torch.sort(torch.randint(0, m, (z - 1,))).values.tolist()
m_splits = torch.tensor(dist + [m]) - torch.tensor([0] + dist)
assert m_splits.sum() == m and len(m_splits) == z
m_splits = m_splits.tolist()
if layout == "TN":
A = [torch.randn(n, k, dtype=dtype, device="cuda") for _ in range(z)] # weight
B = list(torch.split(torch.randn(m, k, dtype=dtype, device="cuda"), m_splits)) # input
out = [torch.randn(m, n, dtype=dtype, device="cuda")] # output
out_ref = [o.clone() for o in torch.split(out[0], m_splits)]
grad = False
single_output = True
elif layout == "NN":
A = [torch.randn(n, k, dtype=dtype, device="cuda") for _ in range(z)] # weight
B = list(
torch.split(torch.randn(m, n, dtype=dtype, device="cuda"), m_splits)
) # grad_output
out = [torch.randn(m, k, dtype=dtype, device="cuda")] # dgrad
out_ref = [o.clone() for o in torch.split(out[0], m_splits)]
grad = True
single_output = True
else: # layout == "NT"
A = list(torch.split(torch.randn(m, k, dtype=dtype, device="cuda"), m_splits)) # input
B = list(
torch.split(torch.randn(m, n, dtype=dtype, device="cuda"), m_splits)
) # grad_output
out = [torch.randn(n, k, dtype=dtype, device="cuda") for _ in range(z)] # wgrad
out_ref = [o.clone() for o in out]
grad = True
single_output = False
for i in range(z):
general_gemm(
A[i],
B[i],
get_workspace(),
dtype,
grad=grad,
accumulate=accumulate,
layout=layout,
out=out_ref[i],
)
if single_output:
out_ref = [torch.cat(out_ref)]
general_grouped_gemm(
A,
B,
out,
dtype,
get_multi_stream_cublas_workspace(),
m_splits=m_splits,
grad=grad,
accumulate=accumulate,
layout=layout,
single_output=single_output,
)
    # Should be a bit-wise match
for o, o_ref in zip(out, out_ref):
torch.testing.assert_close(o, o_ref, rtol=0, atol=0)
@pytest.mark.parametrize(
"shape",
[
(1, 128, 128, 512),
(8, 1024, 128, 512),
(16, 4096, 128, 512),
],
)
@pytest.mark.parametrize("fp8_dtype", [tex.DType.kFloat8E4M3, tex.DType.kFloat8E5M2])
@pytest.mark.parametrize("accumulate", [False, True])
def test_fp8_grouped_gemm(shape, fp8_dtype, accumulate):
if not fp8_available:
pytest.skip(reason_for_no_fp8)
z, m, k, n = shape
m_splits = [m // z] * z
dtype = torch.bfloat16
A = [torch.randn(n, k, dtype=dtype, device="cuda") for _ in range(z)] # weight
B = torch.split(torch.randn(m, k, dtype=dtype, device="cuda"), m_splits) # input
out = torch.split(torch.randn(m, n, dtype=dtype, device="cuda"), m_splits) # output
out_ref = [o.clone() for o in out]
    # FP8 should be robust to this fake scale
scale = 1 + torch.rand(1, dtype=torch.float32, device="cuda").squeeze()
amax = torch.zeros(1, 1, dtype=torch.float32, device="cuda")
a_quantizers = [
Float8Quantizer(
scale.clone(),
amax.clone(),
tex.DType.kFloat8E4M3,
)
for _ in range(z)
]
b_quantizers = [
Float8Quantizer(
scale.clone(),
amax.clone(),
tex.DType.kFloat8E4M3,
)
for _ in range(z)
]
A_fp8 = []
B_fp8 = []
for i in range(z):
A_fp8.append(a_quantizers[i](A[i]))
B_fp8.append(b_quantizers[i](B[i]))
# baseline
for i in range(z):
general_gemm(
A_fp8[i],
B_fp8[i],
get_workspace(),
dtype,
out=out_ref[i],
accumulate=accumulate,
)
general_grouped_gemm(
A_fp8,
B_fp8,
out,
dtype,
get_multi_stream_cublas_workspace(),
m_splits=m_splits,
accumulate=accumulate,
)
    # Should be a bit-wise match
for o, o_ref in zip(out, out_ref):
torch.testing.assert_close(o, o_ref, rtol=0, atol=0)
def test_noncontiguous():
def _create2modules(m, params):
mod1 = m(*params)
mod2 = m(*params)
for p1, p2 in zip(mod1.parameters(), mod2.parameters()):
p2.data = p1.data.clone()
return mod1, mod2
def _run_module(m, inp):
out = m(inp)
out.sum().backward()
ret = [out]
if inp.grad is not None:
ret.append(inp.grad)
for p in m.parameters():
if p.requires_grad:
ret.append(p.grad)
return ret
a = torch.randn((128, 256), device="cuda", requires_grad=True)
a = a.T
assert not a.is_contiguous(), "The test is supposed to test noncontiguous input."
b = a.contiguous()
# LayerNorm
ln1, ln2 = _create2modules(LayerNorm, [128])
outT = _run_module(ln1, a)
out = _run_module(ln2, b)
assert_allclose(out, outT, 1e-7)
# RMSNorm
ln1, ln2 = _create2modules(RMSNorm, [128])
outT = _run_module(ln1, a)
out = _run_module(ln2, b)
assert_allclose(out, outT, 1e-7)
# GEMM
g1, g2 = _create2modules(Linear, [128, 128])
outT = _run_module(g1, a)
out = _run_module(g2, b)
assert_allclose(out, outT, 1e-7)
# Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# See LICENSE for license information.
import random
import pytest
import torch
from transformer_engine.pytorch.cross_entropy import parallel_cross_entropy
class TestParallelCrossEntropy:
def generate_iters(self, iters: int):
self.iters = iters
def generate_infra(self, reduce_loss: bool, label_smoothing: float):
self.test_loss_func = parallel_cross_entropy
self.ref_loss_func = torch.nn.CrossEntropyLoss(
label_smoothing=label_smoothing, reduction="mean" if reduce_loss else "none"
)
def generate_input(self, dtype: torch.dtype, swap_dim: bool):
SQ = random.choice([64, 128])
batch = random.choice([1, 2])
vocab = random.choice([64000, 128000])
if swap_dim:
self.input_test = torch.rand((SQ, batch, vocab), dtype=dtype).cuda()
self.tar_test = torch.randint(0, vocab, (SQ, batch)).cuda()
else:
self.input_test = torch.rand((batch, SQ, vocab), dtype=dtype).cuda()
self.tar_test = torch.randint(0, vocab, (batch, SQ)).cuda()
self.input_ref = torch.reshape(self.input_test.clone().detach(), (batch * SQ, vocab))
self.tar_ref = torch.reshape(self.tar_test.clone().detach(), (batch * SQ,))
def one_iteration_test(
self, dtype: torch.dtype, swap_dim: bool, label_smoothing: float, reduce_loss: bool
):
self.generate_input(dtype, swap_dim)
self.input_test.requires_grad_(True)
self.input_ref.requires_grad_(True)
test_loss = self.test_loss_func(
self.input_test, self.tar_test, label_smoothing, reduce_loss, None
)
if reduce_loss:
test_loss.backward()
ref_loss = self.ref_loss_func(self.input_ref, self.tar_ref)
if reduce_loss:
ref_loss.backward()
test_loss = torch.flatten(test_loss) if not reduce_loss else test_loss
torch.testing.assert_close(test_loss, ref_loss, check_dtype=False)
if reduce_loss:
torch.testing.assert_close(
torch.flatten(self.input_test.grad, start_dim=0, end_dim=1), self.input_ref.grad
)
self.input_test = None
self.input_ref = None
self.tar_test = None
self.tar_ref = None
def test_float32_input(self):
self.generate_iters(5)
self.generate_infra(True, 0)
for i in range(self.iters):
self.one_iteration_test(
dtype=torch.float32, swap_dim=False, label_smoothing=0, reduce_loss=True
)
def test_bfloat16_input(self):
self.generate_iters(5)
self.generate_infra(True, 0)
for i in range(self.iters):
self.one_iteration_test(
dtype=torch.bfloat16, swap_dim=False, label_smoothing=0, reduce_loss=True
)
def test_swapped_input(self):
self.generate_iters(5)
self.generate_infra(True, 0)
for i in range(self.iters):
self.one_iteration_test(
dtype=torch.float32, swap_dim=True, label_smoothing=0, reduce_loss=True
)
def test_label_smoothing(self):
self.generate_iters(3)
self.generate_infra(True, 0.1)
for i in range(self.iters):
self.one_iteration_test(
dtype=torch.float32, swap_dim=False, label_smoothing=0.1, reduce_loss=True
)
def test_non_reduced_loss(self):
self.generate_iters(1)
self.generate_infra(False, 0)
for i in range(self.iters):
self.one_iteration_test(
dtype=torch.float32, swap_dim=False, label_smoothing=0, reduce_loss=False
)
# Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# See LICENSE for license information.
import random
import torch
import pytest
from typing import Dict, List
from transformer_engine.pytorch import (
moe_permute as te_permute,
moe_permute_with_probs as te_permute_with_probs,
moe_unpermute as te_unpermute,
moe_sort_chunks_by_index as te_sort_chunks_by_index,
moe_sort_chunks_by_index_with_probs as te_sort_chunks_by_index_with_probs,
)
from transformer_engine.pytorch.utils import is_bf16_compatible
from transformer_engine.pytorch.fp8 import FP8GlobalStateManager
from transformer_engine.pytorch.tensor.float8_tensor import Float8Quantizer
import transformer_engine_torch as tex
seed = 1234
torch.manual_seed(seed)
torch.cuda.manual_seed(seed)
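# Reference (non-fused) index-map permutation. Example with topk=1: tokens=[t0, t1, t2] and
# indices=[[2], [0], [1]] give sorted_indices=[1, 2, 0] and permuted tokens [t1, t2, t0].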
def pytorch_permute_index_map(tokens, indices, num_out_tokens: int = None):
"""
    Permute the tokens based on the indices. Tokens with the same index are grouped together.
    The indices tensor has shape [num_tokens, top_k] and indicates which experts were selected by each token.
Args:
tokens: torch.Tensor
The input token tensor.
indices: torch.Tensor
The token to expert indices tensor, should have a shape of [num_tokens] or [num_tokens, topk].
num_out_tokens: int, optional
The effective output token count, when enabling the capacity factor, should equal the number of tokens not dropped.
By default, set to None, meaning no tokens are dropped.
    Returns:
        torch.Tensor:
            The permuted tensor.
        torch.Tensor:
            The sorted indices used to produce the permuted tensor (needed for unpermutation).
"""
if indices.dim() == 1:
topk = 1
else:
topk = indices.size(1)
flatten_indices = indices.view(-1)
sorted_indices = torch.argsort(flatten_indices, stable=True)
num_out_tokens = num_out_tokens if num_out_tokens is not None else flatten_indices.size(0)
permuted_tokens = tokens.index_select(0, sorted_indices[:num_out_tokens] // topk)
return permuted_tokens, sorted_indices
def pytorch_unpermute_index_map(
permuted_tokens: torch.Tensor,
sorted_indices: torch.Tensor,
probs: torch.Tensor = None,
):
"""
Unpermute a tensor of permuted tokens based on sorted indices, and optionally merge the tokens with their
corresponding probabilities.
Args:
permuted_tokens: torch.Tensor
The tensor of permuted tokens to be unpermuted.
sorted_indices: torch.Tensor
The tensor of sorted indices used to unpermute the tokens.
probs: torch.Tensor, optional
The tensor of probabilities corresponding to the permuted tokens. If provided, the unpermuted tokens will
be merged with their respective probabilities.
Returns:
torch.Tensor:
The unpermuted tokens, optionally merged with probabilities.
"""
if probs is not None:
# Unpermute and merge the tokens with their probabilities
num_unpermuted_tokens = probs.numel()
topk = probs.size(1)
else:
# Unpermute the tokens without merge
num_unpermuted_tokens = sorted_indices.size(0)
topk = 1
unpermuted_tokens = torch.zeros(
[num_unpermuted_tokens, permuted_tokens.shape[-1]],
dtype=permuted_tokens.dtype,
device=permuted_tokens.device,
)
unpermuted_tokens.index_copy_(0, sorted_indices[: permuted_tokens.size(0)], permuted_tokens)
unpermuted_tokens = unpermuted_tokens.reshape(-1, topk, permuted_tokens.size(-1))
if probs is not None:
unpermuted_tokens = unpermuted_tokens * probs.unsqueeze(-1)
unpermuted_tokens = unpermuted_tokens.sum(dim=1)
return unpermuted_tokens
def pytorch_permute_mask_map(tokens, routing_map):
"""Permute the tokens and probs based on the mask.
Tokens with the same designated expert will be grouped together.
The shape of mask is [tokens, num_experts], it indicates which experts were selected
by each token.
Args:
tokens (torch.Tensor): The input token tensor, [num_tokens, hidden].
routing_map (torch.Tensor): The sparse token to expert mapping, [num_tokens, num_experts].
"""
num_tokens, _ = tokens.shape
num_experts = routing_map.shape[1]
# mask [num_tokens, num_experts] -> [num_experts, num_tokens]
routing_map = routing_map.bool().T.contiguous()
# Create a dense expert-to-token mapping from the sparse token-to-expert mapping
token_indices = (
torch.arange(num_tokens, device=routing_map.device).unsqueeze(0).expand(num_experts, -1)
)
sorted_indices = token_indices.masked_select(routing_map)
# use the mapping to permute the tokens
permuted_input = tokens.index_select(0, sorted_indices)
return permuted_input, sorted_indices
def pytorch_unpermute_mask_map(
permuted_tokens: torch.Tensor,
sorted_indices: torch.Tensor,
restore_shape: torch.Size,
probs: torch.Tensor = None,
routing_map: torch.Tensor = None,
):
"""
Restore the original order of tokens after permutation. If probs are provided, it
will also apply them to the tokens before restoring the order.
Args:
permuted_tokens (torch.Tensor): The permuted token tensor.
sorted_indices (torch.Tensor): The indices used to sort the tokens.
restore_shape (torch.Size): The shape of the unpermuted tensor.
probs (torch.Tensor, optional): The unpermuted probs tensor,
routing_map (torch.Tensor, optional): Token to expert mapping, shape
[num_tokens, num_experts].
Returns:
torch.Tensor: The tokens restored to their original order.
"""
_, hidden = restore_shape
if probs is not None:
assert routing_map is not None, "Mask must be provided to permute the probs."
permuted_probs = probs.T.contiguous().masked_select(routing_map.T.contiguous())
permuted_tokens = permuted_tokens * permuted_probs.unsqueeze(-1)
# Create an output tensor filled with zeros
output_tokens = torch.zeros(
restore_shape, device=permuted_tokens.device, dtype=permuted_tokens.dtype
)
# Scatter add the permuted_input back to the original positions
output_tokens.scatter_add_(0, sorted_indices.unsqueeze(1).expand(-1, hidden), permuted_tokens)
return output_tokens
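# Reference chunk sort. Example: split_sizes=[2, 1] and sorted_idxs=[1, 0] split the input into
# chunks of 2 and 1 rows and emit the 1-row chunk first, i.e. rows [2, 0, 1] of the input.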
def pytorch_sort_chunks_by_index(
input: torch.Tensor,
split_sizes: torch.Tensor,
sorted_idxs: torch.Tensor,
):
"""
Split and sort the input tensor based on the split_sizes and sorted indices.
return a tuple of (output, row_id_map). row_id_map is only used when fused=True.
"""
input = torch.split(input, split_sizes.tolist(), dim=0)
output = torch.cat([input[i] for i in sorted_idxs.tolist()], dim=0)
return output
def dtype_tols(te_dtype: tex.DType) -> Dict[str, float]:
"""Estimated tolerances for a datatype
Based on tolerances for torch.testing.assert_close.
"""
if te_dtype == tex.DType.kFloat32:
return dict(rtol=1.0e-6, atol=1.0e-6)
if te_dtype == tex.DType.kFloat16:
return dict(rtol=3.0e-3, atol=1.0e-5)
if te_dtype == tex.DType.kBFloat16:
return dict(rtol=2.0e-2, atol=1.0e-5)
if te_dtype == tex.DType.kFloat8E5M2 or te_dtype == tex.DType.kFloat8E4M3:
return dict(rtol=2.0e-1, atol=1.0e-1)
    raise ValueError(f"Unsupported dtype ({te_dtype})")
def backward_wrapper(
act, backward_input, forward_input=[], retain_graph=True, accumulate_grad=False
):
# Set forward_input.grad to None to avoid grad accumulation.
    if not accumulate_grad:
for i in forward_input:
i.grad = None
return act.backward(backward_input, retain_graph=retain_graph)
def _test_permutation_index_map(
te_dtype,
num_tokens,
num_expert,
hidden_size,
topK,
num_out_tokens,
with_probs,
BENCHMARK=False,
):
    if not with_probs and topK > 1:
        pytest.skip("Permutations without probabilities are only supported with topK=1.")
if topK > num_expert:
pytest.skip("topK should be smaller than the number of experts.")
    if num_out_tokens is None:
num_out_tokens = num_tokens * topK
print(
"index map:"
f" token:{num_tokens} hidden_size:{hidden_size} expert:{num_expert} topK:{topK} {te_dtype}"
)
fp8 = False
# Convert TE dtypes to PyTorch dtypes
if te_dtype == tex.DType.kFloat32:
dtype = torch.float32
elif te_dtype == tex.DType.kFloat16:
dtype = torch.float16
elif te_dtype == tex.DType.kBFloat16:
dtype = torch.bfloat16
elif fp8_available and (te_dtype == tex.DType.kFloat8E5M2 or te_dtype == tex.DType.kFloat8E4M3):
dtype = torch.uint8
fp8 = True
else:
pytest.skip("Invalid dtype.")
if fp8:
permute_fwd_input = torch.rand(
size=(num_tokens, hidden_size), dtype=torch.float32, device="cuda"
)
permute_bwd_input = torch.rand(
size=(num_out_tokens, hidden_size), dtype=torch.float32, device="cuda"
)
unpermute_bwd_input = torch.rand(
size=(num_tokens, hidden_size), dtype=torch.float32, device="cuda"
)
_permute_fwd_input_quantizer = Float8Quantizer(
scale=torch.full([1], 1.0).cuda().squeeze(),
amax=torch.full([1], 1.0).cuda(),
fp8_dtype=te_dtype,
)
_permute_bwd_input_quantizer = Float8Quantizer(
scale=torch.full([1], 1.0).cuda().squeeze(),
amax=torch.full([1], 1.0).cuda(),
fp8_dtype=te_dtype,
)
_unpermute_bwd_quantizer = Float8Quantizer(
scale=torch.full([1], 1.0).cuda().squeeze(),
amax=torch.full([1], 1.0).cuda(),
fp8_dtype=te_dtype,
)
permute_fwd_input = _permute_fwd_input_quantizer(permute_fwd_input)
permute_bwd_input = _permute_bwd_input_quantizer(permute_bwd_input)
unpermute_bwd_input = _unpermute_bwd_quantizer(unpermute_bwd_input)
pytorch_permute_fwd_input = permute_fwd_input.dequantize(dtype=torch.float16)
pytorch_permute_bwd_input = permute_bwd_input.dequantize(dtype=torch.float16)
pytorch_unpermute_bwd_input = unpermute_bwd_input.dequantize(dtype=torch.float16)
else:
pytorch_permute_fwd_input = torch.rand((num_tokens, hidden_size), dtype=dtype).cuda()
pytorch_permute_bwd_input = torch.rand((num_out_tokens, hidden_size), dtype=dtype).cuda()
pytorch_unpermute_bwd_input = torch.rand((num_tokens, hidden_size), dtype=dtype).cuda()
pytorch_permute_fwd_input.requires_grad_(True)
if num_tokens > 0:
indices = torch.stack([torch.randperm(num_expert)[:topK] for _ in range(num_tokens)])
else:
indices = torch.empty((num_tokens, topK))
indices = indices.to(torch.int32).cuda()
probs = None
if with_probs:
probs = torch.rand(num_tokens, topK).cuda()
row_sums = probs.sum(dim=1, keepdim=True)
probs = probs / row_sums
probs.requires_grad_(True)
###################################################################################################################################
#
# PyTorch Permutation
#
###################################################################################################################################
pytorch_permute_output, sorted_indices = pytorch_permute_index_map(
pytorch_permute_fwd_input, indices, num_out_tokens
)
pytorch_permute_output.backward(pytorch_permute_bwd_input, retain_graph=True)
pytorch_unpermute_fwd_input = pytorch_permute_output.detach()
pytorch_unpermute_fwd_input.requires_grad_(True)
pytorch_unpermute_output = pytorch_unpermute_index_map(
pytorch_unpermute_fwd_input, sorted_indices, probs=probs
)
pytorch_unpermute_output.backward(pytorch_unpermute_bwd_input, retain_graph=True)
###################################################################################################################################
#
# TE Permutation
#
###################################################################################################################################
te_permute_fwd_input = permute_fwd_input if fp8 else pytorch_permute_fwd_input.detach()
te_permute_fwd_input.requires_grad_(True)
te_permute_bwd_input = permute_bwd_input if fp8 else pytorch_permute_bwd_input.detach()
te_permute_output, row_id_map = te_permute(
te_permute_fwd_input, indices, num_out_tokens, map_type="index"
)
te_permute_output.backward(te_permute_bwd_input, retain_graph=True)
te_probs = None
if with_probs:
te_probs = probs.detach()
te_probs.requires_grad_(True)
te_unpermute_fwd_input = te_permute_output.detach()
te_unpermute_fwd_input.requires_grad_(True)
te_unpermute_bwd_input = unpermute_bwd_input if fp8 else pytorch_unpermute_bwd_input.detach()
te_unpermute_output = te_unpermute(
te_unpermute_fwd_input, row_id_map, te_probs, map_type="index"
)
te_unpermute_output.backward(te_unpermute_bwd_input, retain_graph=True)
###################################################################################################################################
#
# Results Check
#
###################################################################################################################################
tols = dtype_tols(te_dtype)
if fp8:
te_permute_output_ = te_permute_output.dequantize(dtype=torch.float32)
te_permute_fwd_input_grad = te_permute_fwd_input.grad.dequantize(dtype=torch.float32)
te_unpermute_output_ = te_unpermute_output.dequantize(dtype=torch.float32)
te_unpermute_fwd_input_grad = te_unpermute_fwd_input.grad.dequantize(dtype=torch.float32)
else:
te_permute_output_ = te_permute_output.float()
te_permute_fwd_input_grad = te_permute_fwd_input.grad.float()
te_unpermute_output_ = te_unpermute_output.float()
te_unpermute_fwd_input_grad = te_unpermute_fwd_input.grad.float()
torch.testing.assert_close(
pytorch_permute_output.float(),
te_permute_output_,
msg=f"Mismatch in te_permute fwd",
)
torch.testing.assert_close(
pytorch_permute_fwd_input.grad.float(),
te_permute_fwd_input_grad,
msg=f"Mismatch in te_permute bwd",
**tols,
)
torch.testing.assert_close(
pytorch_unpermute_output.float(),
te_unpermute_output_,
msg=f"Mismatch in te_unpermute fwd",
**tols,
)
torch.testing.assert_close(
pytorch_unpermute_fwd_input.grad.float(),
te_unpermute_fwd_input_grad,
msg=f"Mismatch in te_unpermute bwd",
**tols,
)
if with_probs:
torch.testing.assert_close(
probs.grad.float(), te_probs.grad.float(), msg=f"Mismatch in te_unpermute bwd", **tols
)
if not pytorch_permute_fwd_input.numel():
print("Empty pytorch_permute_fwd_input activation test passed.")
return
###################################################################################################################################
#
# Benchmark
#
###################################################################################################################################
if BENCHMARK:
t1 = perf_test_cuda_kernel(
lambda: pytorch_permute_index_map(pytorch_permute_fwd_input, indices, num_out_tokens)
)
t2 = perf_test_cuda_kernel(
lambda: te_permute(te_permute_fwd_input, indices, num_out_tokens, map_type="index")
)
print(f"permute\t\tfwd: pytorch: {t1:.3f} ms, TE: {t2:.3f} ms")
t1 = perf_test_cuda_kernel(
lambda: backward_wrapper(
pytorch_permute_output,
pytorch_permute_bwd_input,
forward_input=[pytorch_permute_fwd_input],
retain_graph=True,
accumulate_grad=False,
)
)
t2 = perf_test_cuda_kernel(
lambda: backward_wrapper(
te_permute_output,
te_permute_bwd_input,
forward_input=[te_permute_fwd_input],
retain_graph=True,
accumulate_grad=False,
)
)
print(f"permute\t\tbwd: pytorch: {t1:.3f} ms, TE: {t2:.3f} ms")
t1 = perf_test_cuda_kernel(
lambda: pytorch_unpermute_index_map(
pytorch_unpermute_fwd_input, sorted_indices, probs=probs
)
)
t2 = perf_test_cuda_kernel(
lambda: te_unpermute(te_unpermute_fwd_input, row_id_map, te_probs, map_type="index")
)
print(f"unpermute\tfwd: pytorch: {t1:.3f} ms, TE: {t2:.3f} ms")
t1 = perf_test_cuda_kernel(
lambda: backward_wrapper(
pytorch_unpermute_output,
pytorch_unpermute_bwd_input,
forward_input=(
[pytorch_unpermute_fwd_input, probs]
if with_probs
else [pytorch_unpermute_fwd_input]
),
retain_graph=True,
accumulate_grad=False,
)
)
t2 = perf_test_cuda_kernel(
lambda: backward_wrapper(
te_unpermute_output,
te_unpermute_bwd_input,
forward_input=(
[te_unpermute_fwd_input, te_probs] if with_probs else [te_unpermute_fwd_input]
),
retain_graph=True,
accumulate_grad=False,
)
)
print(f"unpermute\tbwd: pytorch: {t1:.3f} ms, TE: {t2:.3f} ms")
def _test_permutation_mask_map(
te_dtype,
num_tokens,
num_expert,
hidden_size,
topK,
num_out_tokens,
with_probs,
BENCHMARK=False,
):
if topK > num_expert:
pytest.skip("topK should be smaller than the number of experts.")
    if num_out_tokens is None:
num_out_tokens = num_tokens * topK
print(
"mask map:"
f" token:{num_tokens} hidden_size:{hidden_size} expert:{num_expert} topK:{topK} {te_dtype}"
)
fp8 = False
# Convert TE dtypes to PyTorch dtypes
if te_dtype == tex.DType.kFloat32:
dtype = torch.float32
elif te_dtype == tex.DType.kFloat16:
dtype = torch.float16
elif te_dtype == tex.DType.kBFloat16:
dtype = torch.bfloat16
elif fp8_available and (te_dtype == tex.DType.kFloat8E5M2 or te_dtype == tex.DType.kFloat8E4M3):
dtype = torch.uint8
fp8 = True
else:
pytest.skip("Invalid dtype.")
if fp8:
permute_fwd_input = torch.rand(
size=(num_tokens, hidden_size), dtype=torch.float32, device="cuda"
)
permute_bwd_input = torch.rand(
size=(num_out_tokens, hidden_size), dtype=torch.float32, device="cuda"
)
unpermute_bwd_input = torch.rand(
size=(num_tokens, hidden_size), dtype=torch.float32, device="cuda"
)
_permute_fwd_input_quantizer = Float8Quantizer(
scale=torch.full([1], 1.0).cuda().squeeze(),
amax=torch.full([1], 1.0).cuda(),
fp8_dtype=te_dtype,
)
_permute_bwd_input_quantizer = Float8Quantizer(
scale=torch.full([1], 1.0).cuda().squeeze(),
amax=torch.full([1], 1.0).cuda(),
fp8_dtype=te_dtype,
)
_unpermute_bwd_input_quantizer = Float8Quantizer(
scale=torch.full([1], 1.0).cuda().squeeze(),
amax=torch.full([1], 1.0).cuda(),
fp8_dtype=te_dtype,
)
permute_fwd_input = _permute_fwd_input_quantizer(permute_fwd_input)
permute_bwd_input = _permute_bwd_input_quantizer(permute_bwd_input)
unpermute_bwd_input = _unpermute_bwd_input_quantizer(unpermute_bwd_input)
pytorch_permute_fwd_input = permute_fwd_input.dequantize(dtype=torch.float16)
pytorch_permute_bwd_input = permute_bwd_input.dequantize(dtype=torch.float16)
pytorch_unpermute_bwd_input = unpermute_bwd_input.dequantize(dtype=torch.float16)
else:
pytorch_permute_fwd_input = torch.rand((num_tokens, hidden_size), dtype=dtype).cuda()
pytorch_permute_bwd_input = torch.rand((num_out_tokens, hidden_size), dtype=dtype).cuda()
pytorch_unpermute_bwd_input = torch.rand((num_tokens, hidden_size), dtype=dtype).cuda()
pytorch_permute_fwd_input.requires_grad_(True)
restore_shape = pytorch_permute_fwd_input.shape
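    # Build a random routing map with exactly num_out_tokens active (token, expert) pairs:
    # scatter num_out_tokens ones into a flat [num_tokens * num_expert] buffer, shuffle, and
    # reshape to [num_tokens, num_expert].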
_tmp_tensor = torch.zeros((num_tokens * num_expert,))
_tmp_tensor[: int(num_out_tokens)] = 1.0
_tmp_idx = torch.randperm(num_tokens * num_expert)
routing_map = torch.reshape(_tmp_tensor[_tmp_idx], (num_tokens, num_expert)).bool().cuda()
probs = None
if with_probs:
probs = torch.rand(num_tokens, num_expert).cuda() * routing_map
row_sums = probs.sum(dim=1, keepdim=True)
probs = probs / row_sums
if fp8:
probs = probs.to(torch.float16)
else:
probs = probs.to(dtype)
probs.requires_grad_(True)
###################################################################################################################################
#
# PyTorch Permutation
#
###################################################################################################################################
pytorch_permute_output, sorted_indices = pytorch_permute_mask_map(
pytorch_permute_fwd_input, routing_map
)
pytorch_permute_output.backward(pytorch_permute_bwd_input, retain_graph=True)
pytorch_unpermute_fwd_input = pytorch_permute_output.detach()
pytorch_unpermute_fwd_input.requires_grad_(True)
pytorch_unpermute_output = pytorch_unpermute_mask_map(
pytorch_unpermute_fwd_input, sorted_indices, restore_shape, probs, routing_map
)
pytorch_unpermute_output.backward(pytorch_unpermute_bwd_input, retain_graph=True)
###################################################################################################################################
#
# TE Permutation
#
###################################################################################################################################
te_permute_fwd_input = permute_fwd_input if fp8 else pytorch_permute_fwd_input.detach()
te_permute_fwd_input.requires_grad_(True)
te_permute_bwd_input = permute_bwd_input if fp8 else pytorch_permute_bwd_input.detach()
te_permute_output, row_id_map = te_permute(
te_permute_fwd_input, routing_map, num_out_tokens=num_out_tokens, map_type="mask"
)
te_permute_output.backward(te_permute_bwd_input, retain_graph=True)
te_probs = None
if with_probs:
te_probs = probs.detach()
te_probs.requires_grad_(True)
te_unpermute_fwd_input = te_permute_output.detach()
te_unpermute_fwd_input.requires_grad_(True)
te_unpermute_bwd_input = unpermute_bwd_input if fp8 else pytorch_unpermute_bwd_input.detach()
te_unpermute_output = te_unpermute(
te_unpermute_fwd_input, row_id_map, te_probs, restore_shape, map_type="mask"
)
te_unpermute_output.backward(te_unpermute_bwd_input, retain_graph=True)
###################################################################################################################################
#
# Results Check
#
###################################################################################################################################
tols = dtype_tols(te_dtype)
if fp8:
te_permute_output_ = te_permute_output.dequantize(dtype=torch.float32)
te_permute_fwd_input_grad = te_permute_fwd_input.grad.dequantize(dtype=torch.float32)
te_unpermute_output_ = te_unpermute_output.dequantize(dtype=torch.float32)
te_unpermute_fwd_input_grad = te_unpermute_fwd_input.grad.dequantize(dtype=torch.float32)
else:
te_permute_output_ = te_permute_output.float()
te_permute_fwd_input_grad = te_permute_fwd_input.grad.float()
te_unpermute_output_ = te_unpermute_output.float()
te_unpermute_fwd_input_grad = te_unpermute_fwd_input.grad.float()
torch.testing.assert_close(
pytorch_permute_output.float(),
te_permute_output_,
msg=f"Mismatch in te_permute fwd",
**tols,
)
torch.testing.assert_close(
pytorch_permute_fwd_input.grad.float(),
te_permute_fwd_input_grad,
msg=f"Mismatch in te_permute bwd",
**tols,
)
torch.testing.assert_close(
pytorch_unpermute_output.float(),
te_unpermute_output_,
msg=f"Mismatch in te_unpermute fwd",
**tols,
)
torch.testing.assert_close(
pytorch_unpermute_fwd_input.grad.float(),
te_unpermute_fwd_input_grad,
msg=f"Mismatch in te_unpermute bwd",
**tols,
)
if with_probs:
torch.testing.assert_close(
probs.grad.float(), te_probs.grad.float(), msg=f"Mismatch in te_unpermute bwd", **tols
)
if not pytorch_permute_fwd_input.numel():
print("Empty pytorch_permute_fwd_input activation test passed.")
return
###################################################################################################################################
#
# Benchmark
#
###################################################################################################################################
if BENCHMARK:
t1 = perf_test_cuda_kernel(
lambda: pytorch_permute_mask_map(pytorch_permute_fwd_input, routing_map)
)
t2 = perf_test_cuda_kernel(
lambda: te_permute(
te_permute_fwd_input, routing_map, num_out_tokens=num_out_tokens, map_type="mask"
)
)
print(f"permute\t\tfwd: pytorch: {t1:.3f} ms, TE: {t2:.3f} ms")
t1 = perf_test_cuda_kernel(
lambda: backward_wrapper(
pytorch_permute_output,
pytorch_permute_bwd_input,
forward_input=[pytorch_permute_fwd_input],
retain_graph=True,
accumulate_grad=False,
)
)
t2 = perf_test_cuda_kernel(
lambda: backward_wrapper(
te_permute_output,
te_permute_bwd_input,
forward_input=[te_permute_fwd_input],
retain_graph=True,
accumulate_grad=False,
)
)
print(f"permute\t\tbwd: pytorch: {t1:.3f} ms, TE: {t2:.3f} ms")
t1 = perf_test_cuda_kernel(
lambda: pytorch_unpermute_mask_map(
pytorch_unpermute_fwd_input, sorted_indices, restore_shape, probs, routing_map
)
)
t2 = perf_test_cuda_kernel(
lambda: te_unpermute(
te_unpermute_fwd_input, row_id_map, te_probs, restore_shape, map_type="mask"
)
)
print(f"unpermute\tfwd: pytorch: {t1:.3f} ms, TE: {t2:.3f} ms")
t1 = perf_test_cuda_kernel(
lambda: backward_wrapper(
pytorch_unpermute_output,
pytorch_unpermute_bwd_input,
forward_input=(
[pytorch_unpermute_fwd_input, probs]
if with_probs
else [pytorch_unpermute_fwd_input]
),
retain_graph=True,
accumulate_grad=False,
)
)
t2 = perf_test_cuda_kernel(
lambda: backward_wrapper(
te_unpermute_output,
te_unpermute_bwd_input,
forward_input=(
[te_unpermute_fwd_input, te_probs] if with_probs else [te_unpermute_fwd_input]
),
retain_graph=True,
accumulate_grad=False,
)
)
print(f"unpermute\tbwd: pytorch: {t1:.3f} ms, TE: {t2:.3f} ms")
def _test_moe_chunk_sort(
te_dtype,
num_tokens,
num_expert,
tp_size,
hidden_size,
BENCHMARK=False,
):
print(
"chunk permute:"
f" token:{num_tokens} hidden_size:{hidden_size} num_expert:{num_expert} tp_size:{tp_size} {te_dtype}"
)
fp8 = False
# Convert TE dtypes to PyTorch dtypes
if te_dtype == tex.DType.kFloat32:
dtype = torch.float32
elif te_dtype == tex.DType.kFloat16:
dtype = torch.float16
elif te_dtype == tex.DType.kBFloat16:
dtype = torch.bfloat16
elif fp8_available and (te_dtype == tex.DType.kFloat8E5M2 or te_dtype == tex.DType.kFloat8E4M3):
dtype = torch.uint8
fp8 = True
else:
pytest.skip("Invalid dtype.")
if fp8:
fwd_input = torch.rand(size=(num_tokens, hidden_size), dtype=torch.float32, device="cuda")
bwd_input = torch.rand(size=(num_tokens, hidden_size), dtype=torch.float32, device="cuda")
_fwd_input_quantizer = Float8Quantizer(
scale=torch.full([1], 1.0).cuda().squeeze(),
amax=torch.full([1], 1.0).cuda(),
fp8_dtype=te_dtype,
)
_bwd_input_quantizer = Float8Quantizer(
scale=torch.full([1], 1.0).cuda().squeeze(),
amax=torch.full([1], 1.0).cuda(),
fp8_dtype=te_dtype,
)
fwd_input = _fwd_input_quantizer.quantize(fwd_input)
bwd_input = _bwd_input_quantizer.quantize(bwd_input)
pytorch_fwd_input = fwd_input.dequantize(dtype=torch.float16)
pytorch_bwd_input = bwd_input.dequantize(dtype=torch.float16)
else:
pytorch_fwd_input = torch.rand((num_tokens, hidden_size), dtype=dtype).cuda()
pytorch_bwd_input = torch.rand((num_tokens, hidden_size), dtype=dtype).cuda()
pytorch_fwd_input.requires_grad_(True)
_split_sizes = [0] * (num_expert * tp_size)
for _ in range(num_tokens):
idx = random.randint(0, num_expert * tp_size - 1)
_split_sizes[idx] += 1
split_sizes = torch.tensor(_split_sizes, dtype=torch.int32).ravel()
split_sizes_cuda = split_sizes.to(device="cuda")
_sorted_idxs = torch.arange(num_expert * tp_size, dtype=torch.int32)
sorted_idxs = _sorted_idxs.reshape(tp_size, num_expert).T.ravel()
sorted_idxs_cuda = sorted_idxs.to(device="cuda")
###################################################################################################################################
#
# PyTorch Permutation
#
###################################################################################################################################
pytorch_output = pytorch_sort_chunks_by_index(pytorch_fwd_input, split_sizes, sorted_idxs)
pytorch_output.backward(pytorch_bwd_input, retain_graph=True)
###################################################################################################################################
#
# TE Permutation
#
###################################################################################################################################
te_fwd_input = fwd_input if fp8 else pytorch_fwd_input.detach()
te_fwd_input.requires_grad_(True)
te_bwd_input = bwd_input if fp8 else pytorch_bwd_input.detach()
te_output = te_sort_chunks_by_index(te_fwd_input, split_sizes_cuda, sorted_idxs_cuda)
te_output.backward(te_bwd_input, retain_graph=True)
###################################################################################################################################
#
# Results Check
#
###################################################################################################################################
tols = dtype_tols(te_dtype)
if fp8:
te_output_ = te_output.dequantize(dtype=torch.float32)
te_fwd_input_grad = te_fwd_input.grad.dequantize(dtype=torch.float32)
else:
te_output_ = te_output.float()
te_fwd_input_grad = te_fwd_input.grad.float()
torch.testing.assert_close(
pytorch_output.float(),
te_output_,
msg=f"Mismatch in te_permute fwd",
**tols,
)
torch.testing.assert_close(
pytorch_fwd_input.grad.float(),
te_fwd_input_grad,
msg=f"Mismatch in te_permute bwd",
**tols,
)
if not pytorch_fwd_input.numel():
print("Empty pytorch_fwd_input activation test passed.")
return
###################################################################################################################################
#
# Benchmark
#
###################################################################################################################################
if BENCHMARK:
t1 = perf_test_cuda_kernel(
lambda: pytorch_sort_chunks_by_index(pytorch_fwd_input, split_sizes, sorted_idxs)
)
t2 = perf_test_cuda_kernel(
lambda: te_sort_chunks_by_index(te_fwd_input, split_sizes_cuda, sorted_idxs_cuda)
)
print(f"chunk sort\t\tfwd: pytorch: {t1:.3f} ms, TE: {t2:.3f} ms")
t1 = perf_test_cuda_kernel(
lambda: backward_wrapper(
pytorch_output,
pytorch_bwd_input,
forward_input=[pytorch_fwd_input],
retain_graph=True,
accumulate_grad=False,
)
)
t2 = perf_test_cuda_kernel(
lambda: backward_wrapper(
te_output,
te_bwd_input,
forward_input=[te_fwd_input],
retain_graph=True,
accumulate_grad=False,
)
)
print(f"chunk sort\t\tbwd: pytorch: {t1:.3f} ms, TE: {t2:.3f} ms")
def _test_permutation_mask_map_alongside_probs(
te_dtype,
num_tokens,
num_expert,
hidden_size,
topK,
num_out_tokens,
tp_size,
):
if topK > num_expert:
pytest.skip("topK should be smaller than the number of experts.")
    if num_out_tokens is None:
num_out_tokens = num_tokens * topK
print(
"mask map alongside probs:"
f" token:{num_tokens} hidden_size:{hidden_size} expert:{num_expert} topK:{topK} {te_dtype}"
)
fp8 = False
# Convert TE dtypes to PyTorch dtypes
if te_dtype == tex.DType.kFloat32:
dtype = torch.float32
elif te_dtype == tex.DType.kFloat16:
dtype = torch.float16
elif te_dtype == tex.DType.kBFloat16:
dtype = torch.bfloat16
elif fp8_available and (te_dtype == tex.DType.kFloat8E5M2 or te_dtype == tex.DType.kFloat8E4M3):
dtype = torch.uint8
fp8 = True
else:
pytest.skip("Invalid dtype.")
if fp8:
permute_fwd_input = torch.rand(
size=(num_tokens, hidden_size), dtype=torch.float32, device="cuda"
)
unpermute_bwd_input = torch.rand(
size=(num_tokens, hidden_size), dtype=torch.float32, device="cuda"
)
_permute_fwd_input_quantizer = Float8Quantizer(
scale=torch.full([1], 1.0).cuda().squeeze(),
amax=torch.full([1], 1.0).cuda(),
fp8_dtype=te_dtype,
)
_unpermute_bwd_quantizer = Float8Quantizer(
scale=torch.full([1], 1.0).cuda().squeeze(),
amax=torch.full([1], 1.0).cuda(),
fp8_dtype=te_dtype,
)
permute_fwd_input = _permute_fwd_input_quantizer.quantize(permute_fwd_input)
unpermute_bwd_input = _unpermute_bwd_quantizer.quantize(unpermute_bwd_input)
pytorch_permute_fwd_input = permute_fwd_input.dequantize(dtype=torch.float16)
pytorch_unpermute_bwd_input = unpermute_bwd_input.dequantize(dtype=torch.float16)
else:
pytorch_permute_fwd_input = torch.rand((num_tokens, hidden_size), dtype=dtype).cuda()
pytorch_unpermute_bwd_input = torch.rand((num_tokens, hidden_size), dtype=dtype).cuda()
pytorch_permute_fwd_input.requires_grad_(True)
restore_shape = pytorch_permute_fwd_input.shape
_tmp_tensor = torch.zeros((num_tokens * num_expert,))
_tmp_tensor[: int(num_out_tokens)] = 1.0
_tmp_idx = torch.randperm(num_tokens * num_expert)
routing_map = torch.reshape(_tmp_tensor[_tmp_idx], (num_tokens, num_expert)).bool().cuda()
probs = torch.rand(num_tokens, num_expert).cuda() * routing_map
row_sums = probs.sum(dim=1, keepdim=True)
probs = probs / row_sums
if fp8:
probs = probs.to(torch.float16)
else:
probs = probs.to(dtype)
probs.requires_grad_(True)
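    # Chunk metadata for the two chunk sorts: split_sizes randomly distributes the
    # num_out_tokens permuted tokens over num_expert * tp_size chunks, sorted_idxs regroups
    # the chunks expert-major, and (split_sizes_2, sorted_idxs_2) is the inverse permutation,
    # so applying the two sorts back to back restores the original chunk order.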
split_sizes = [0] * (num_expert * tp_size)
for i in range(num_out_tokens):
idx = random.randint(0, num_expert * tp_size - 1)
split_sizes[idx] += 1
split_sizes = torch.tensor(split_sizes, dtype=torch.int32)
split_sizes_cuda = split_sizes.to(device="cuda")
_sorted_idxs = torch.arange(num_expert * tp_size, dtype=torch.int32)
sorted_idxs = _sorted_idxs.reshape(tp_size, num_expert).T.ravel()
sorted_idxs_cuda = sorted_idxs.to(device="cuda")
split_sizes_2 = [split_sizes[i] for i in sorted_idxs.tolist()]
split_sizes_2 = torch.tensor(split_sizes_2, dtype=torch.int32)
split_sizes_2_cuda = split_sizes_2.to(device="cuda")
sorted_idxs_2 = [0] * (num_expert * tp_size)
for i in range(num_expert * tp_size):
sorted_idxs_2[sorted_idxs[i]] = i
sorted_idxs_2 = torch.tensor(sorted_idxs_2, dtype=torch.int32)
sorted_idxs_2_cuda = sorted_idxs_2.to(device="cuda")
###################################################################################################################################
#
# PyTorch Permutation
#
###################################################################################################################################
pytorch_permute_output, sorted_indices = pytorch_permute_mask_map(
pytorch_permute_fwd_input, routing_map
)
pytorch_permute_output = pytorch_sort_chunks_by_index(
pytorch_permute_output, split_sizes, sorted_idxs
)
pytorch_permute_output = pytorch_sort_chunks_by_index(
pytorch_permute_output, split_sizes_2, sorted_idxs_2
)
pytorch_unpermute_output = pytorch_unpermute_mask_map(
pytorch_permute_output, sorted_indices, restore_shape, probs, routing_map
)
pytorch_unpermute_output.backward(pytorch_unpermute_bwd_input, retain_graph=True)
###################################################################################################################################
#
# TE Permutation
#
###################################################################################################################################
te_permute_fwd_input = permute_fwd_input if fp8 else pytorch_permute_fwd_input.detach()
te_permute_fwd_input.requires_grad_(True)
te_unpermute_bwd_input = unpermute_bwd_input if fp8 else pytorch_unpermute_bwd_input.detach()
te_probs = probs.detach()
te_probs.requires_grad_(True)
te_permute_output, te_permuted_probs, row_id_map = te_permute_with_probs(
te_permute_fwd_input,
te_probs,
routing_map,
num_out_tokens=num_out_tokens,
)
te_permute_output, te_permuted_probs = te_sort_chunks_by_index_with_probs(
te_permute_output, te_permuted_probs, split_sizes_cuda, sorted_idxs_cuda
)
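    # Apply the permuted probs to the sorted tokens. For FP8 inputs this is done in high
    # precision (dequantize, scale, re-quantize); otherwise the product is cast back to the
    # original dtype.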
if fp8:
_permute_output_quantizer = Float8Quantizer(
scale=torch.full([1], 1.0).cuda().squeeze(),
amax=torch.full([1], 1.0).cuda(),
fp8_dtype=te_dtype,
)
te_permute_output = te_permute_output.dequantize(dtype=torch.float32)
te_permute_output = te_permute_output * te_permuted_probs.unsqueeze(-1)
te_permute_output = _permute_output_quantizer.quantize(te_permute_output)
else:
te_permute_output_dtype = te_permute_output.dtype
te_permute_output = te_permute_output * te_permuted_probs.unsqueeze(-1)
te_permute_output = te_permute_output.to(dtype=te_permute_output_dtype)
te_permute_output = te_sort_chunks_by_index(
te_permute_output, split_sizes_2_cuda, sorted_idxs_2_cuda
)
te_unpermute_output = te_unpermute(
te_permute_output,
row_id_map,
restore_shape=restore_shape,
map_type="mask",
)
te_unpermute_output.backward(te_unpermute_bwd_input, retain_graph=True)
###############################################################################################
tols = dtype_tols(te_dtype)
if fp8:
# backward of dequantize is in high precision
te_permute_fwd_input_grad = te_permute_fwd_input.grad.float()
te_unpermute_output_ = te_unpermute_output.dequantize(dtype=torch.float32)
else:
te_permute_fwd_input_grad = te_permute_fwd_input.grad.float()
te_unpermute_output_ = te_unpermute_output.float()
torch.testing.assert_close(
pytorch_unpermute_output.float(),
te_unpermute_output_,
msg=f"Mismatch in fused_unpermute fwd",
**tols,
)
torch.testing.assert_close(
pytorch_permute_fwd_input.grad.float(),
te_permute_fwd_input_grad,
msg=f"Mismatch in fused_permute bwd",
**tols,
)
torch.testing.assert_close(
probs.grad.float(), te_probs.grad.float(), msg=f"Mismatch in prob grad", **tols
)
def perf_test_cuda_kernel(cuda_kernel_fn):
if torch.cuda.is_available():
# create CUDA event
start_event = torch.cuda.Event(enable_timing=True)
end_event = torch.cuda.Event(enable_timing=True)
# warmup
for _ in range(50):
cuda_kernel_fn()
start_event.record()
for _ in range(100):
cuda_kernel_fn()
end_event.record()
torch.cuda.synchronize()
elapsed_time_ms = start_event.elapsed_time(end_event)
return elapsed_time_ms / 100
else:
pytest.skip("CUDA is not available.")
# TE tensor dtypes
_te_dtypes: List[tex.DType] = [tex.DType.kFloat32, tex.DType.kFloat16]
if is_bf16_compatible():
_te_dtypes.append(tex.DType.kBFloat16)
@pytest.mark.parametrize("te_dtype", _te_dtypes)
@pytest.mark.parametrize("num_tokens", [4096])
@pytest.mark.parametrize("num_expert", [8, 16])
@pytest.mark.parametrize("hidden_size", [4096])
@pytest.mark.parametrize("topK", [1, 2, 5])
@pytest.mark.parametrize("num_out_tokens", [None, 2039])
def test_permutation_index_map(
te_dtype,
num_tokens,
num_expert,
hidden_size,
topK,
num_out_tokens,
):
with_probs = True
BENCHMARK = False
_test_permutation_index_map(
te_dtype=te_dtype,
num_tokens=num_tokens,
num_expert=num_expert,
hidden_size=hidden_size,
topK=topK,
num_out_tokens=num_out_tokens,
with_probs=with_probs,
BENCHMARK=BENCHMARK,
)
@pytest.mark.parametrize("te_dtype", _te_dtypes)
@pytest.mark.parametrize("num_tokens", [4096])
@pytest.mark.parametrize("num_expert", [8, 16])
@pytest.mark.parametrize("hidden_size", [4096])
@pytest.mark.parametrize("topK", [1, 2, 5])
@pytest.mark.parametrize("num_out_tokens", [None, 2039])
def test_permutation_mask_map(
te_dtype,
num_tokens,
num_expert,
hidden_size,
topK,
num_out_tokens,
):
with_probs = True
BENCHMARK = False
_test_permutation_mask_map(
te_dtype=te_dtype,
num_tokens=num_tokens,
num_expert=num_expert,
hidden_size=hidden_size,
topK=topK,
num_out_tokens=num_out_tokens,
with_probs=with_probs,
BENCHMARK=BENCHMARK,
)
@pytest.mark.parametrize("te_dtype", _te_dtypes)
def test_permutation_mask_map_empty_input(te_dtype):
with_probs = True
BENCHMARK = False
_test_permutation_mask_map(
te_dtype=te_dtype,
num_tokens=0,
num_expert=8,
hidden_size=4096,
topK=2,
num_out_tokens=0,
with_probs=with_probs,
BENCHMARK=BENCHMARK,
)
@pytest.mark.parametrize("te_dtype", _te_dtypes)
@pytest.mark.parametrize("num_tokens", [4096])
@pytest.mark.parametrize("num_expert", [8, 16])
@pytest.mark.parametrize("hidden_size", [4096])
@pytest.mark.parametrize("topK", [1, 2, 5])
@pytest.mark.parametrize("num_out_tokens", [None, 2039])
@pytest.mark.parametrize("tp_size", [1, 2, 8])
def test_permutation_mask_map_alongside_probs(
te_dtype,
num_tokens,
num_expert,
hidden_size,
topK,
num_out_tokens,
tp_size,
):
_test_permutation_mask_map_alongside_probs(
te_dtype=te_dtype,
num_tokens=num_tokens,
num_expert=num_expert,
hidden_size=hidden_size,
topK=topK,
num_out_tokens=num_out_tokens,
tp_size=tp_size,
)
@pytest.mark.parametrize("te_dtype", _te_dtypes)
def test_permutation_mask_map_alongside_probs_empty_input(te_dtype):
_test_permutation_mask_map_alongside_probs(
te_dtype=te_dtype,
num_tokens=0,
num_expert=8,
hidden_size=4096,
topK=2,
num_out_tokens=0,
tp_size=2,
)
# Only run FP8 tests on FP8-capable devices.
fp8_available, reason_for_no_fp8 = FP8GlobalStateManager.is_fp8_available()
@pytest.mark.skipif(not fp8_available, reason=reason_for_no_fp8)
@pytest.mark.parametrize("te_dtype", [tex.DType.kFloat8E4M3, tex.DType.kFloat8E5M2])
@pytest.mark.parametrize("num_tokens", [2048])
@pytest.mark.parametrize("num_expert", [8, 16])
@pytest.mark.parametrize("hidden_size", [4096])
@pytest.mark.parametrize("topK", [1, 2, 5])
@pytest.mark.parametrize("num_out_tokens", [None, 2039])
def test_permutation_index_map_fp8(
te_dtype,
num_tokens,
num_expert,
hidden_size,
topK,
num_out_tokens,
):
with_probs = True
BENCHMARK = False
_test_permutation_index_map(
te_dtype=te_dtype,
num_tokens=num_tokens,
num_expert=num_expert,
hidden_size=hidden_size,
topK=topK,
num_out_tokens=num_out_tokens,
with_probs=with_probs,
BENCHMARK=BENCHMARK,
)
@pytest.mark.skipif(not fp8_available, reason=reason_for_no_fp8)
@pytest.mark.parametrize("te_dtype", [tex.DType.kFloat8E4M3, tex.DType.kFloat8E5M2])
@pytest.mark.parametrize("num_tokens", [2048])
@pytest.mark.parametrize("num_expert", [8, 16])
@pytest.mark.parametrize("hidden_size", [4096])
@pytest.mark.parametrize("topK", [1, 2, 5])
@pytest.mark.parametrize("num_out_tokens", [None, 2039])
def test_permutation_mask_map_fp8(
te_dtype,
num_tokens,
num_expert,
hidden_size,
topK,
num_out_tokens,
):
with_probs = True
BENCHMARK = False
_test_permutation_mask_map(
te_dtype=te_dtype,
num_tokens=num_tokens,
num_expert=num_expert,
hidden_size=hidden_size,
topK=topK,
num_out_tokens=num_out_tokens,
with_probs=with_probs,
BENCHMARK=BENCHMARK,
)
@pytest.mark.skipif(not fp8_available, reason=reason_for_no_fp8)
@pytest.mark.parametrize("te_dtype", [tex.DType.kFloat8E4M3, tex.DType.kFloat8E5M2])
@pytest.mark.parametrize("num_tokens", [2048])
@pytest.mark.parametrize("num_expert", [8, 16])
@pytest.mark.parametrize("hidden_size", [4096])
@pytest.mark.parametrize("topK", [1, 2, 5])
@pytest.mark.parametrize("num_out_tokens", [None, 2039])
@pytest.mark.parametrize("tp_size", [1, 2, 8])
def test_permutation_mask_map_alongside_probs_fp8(
te_dtype,
num_tokens,
num_expert,
hidden_size,
topK,
num_out_tokens,
tp_size,
):
_test_permutation_mask_map_alongside_probs(
te_dtype=te_dtype,
num_tokens=num_tokens,
num_expert=num_expert,
hidden_size=hidden_size,
topK=topK,
num_out_tokens=num_out_tokens,
tp_size=tp_size,
)
@pytest.mark.parametrize("te_dtype", _te_dtypes)
@pytest.mark.parametrize("num_tokens", [4096])
@pytest.mark.parametrize("num_expert", [8, 16])
@pytest.mark.parametrize("hidden_size", [4096])
def test_permutation_index_map_topk1_no_probs(
te_dtype,
num_tokens,
num_expert,
hidden_size,
):
topK = 1
num_out_tokens = None
with_probs = False
BENCHMARK = False
_test_permutation_index_map(
te_dtype=te_dtype,
num_tokens=num_tokens,
num_expert=num_expert,
hidden_size=hidden_size,
topK=topK,
num_out_tokens=num_out_tokens,
with_probs=with_probs,
BENCHMARK=BENCHMARK,
)
@pytest.mark.parametrize("te_dtype", _te_dtypes)
@pytest.mark.parametrize("num_tokens", [4096])
@pytest.mark.parametrize("num_expert", [8, 16])
@pytest.mark.parametrize("hidden_size", [4096])
def test_permutation_mask_map_topk1_no_probs(
te_dtype,
num_tokens,
num_expert,
hidden_size,
):
topK = 1
num_out_tokens = None
with_probs = False
BENCHMARK = False
_test_permutation_mask_map(
te_dtype=te_dtype,
num_tokens=num_tokens,
num_expert=num_expert,
hidden_size=hidden_size,
topK=topK,
num_out_tokens=num_out_tokens,
with_probs=with_probs,
BENCHMARK=BENCHMARK,
)
@pytest.mark.parametrize("te_dtype", _te_dtypes)
@pytest.mark.parametrize("num_tokens", [4096])
@pytest.mark.parametrize("num_expert", [8, 16])
@pytest.mark.parametrize("tp_size", [1, 2, 8])
@pytest.mark.parametrize("hidden_size", [4096])
def test_chunk_permutation(
te_dtype,
num_tokens,
num_expert,
tp_size,
hidden_size,
):
BENCHMARK = False
_test_moe_chunk_sort(
te_dtype=te_dtype,
num_tokens=num_tokens,
num_expert=num_expert,
tp_size=tp_size,
hidden_size=hidden_size,
BENCHMARK=BENCHMARK,
)
@pytest.mark.parametrize("te_dtype", _te_dtypes)
def test_chunk_permutation_empty_input(te_dtype):
BENCHMARK = False
_test_moe_chunk_sort(
te_dtype=te_dtype,
num_tokens=0,
num_expert=8,
tp_size=2,
hidden_size=4096,
BENCHMARK=BENCHMARK,
)
def test_permutation_single_case():
print("GPU:", torch.cuda.get_device_name(0))
# te_dtype = tex.DType.kFloat32
# te_dtype = tex.DType.kFloat16
# te_dtype = tex.DType.kBFloat16
te_dtype = tex.DType.kFloat8E5M2
# te_dtype = tex.DType.kFloat8E4M3
num_tokens = 10
num_expert = 4
hidden_size = 16
topK = 2
num_out_tokens = num_tokens * topK - 1
with_probs = True
Benchmark = True
_test_permutation_index_map(
te_dtype=te_dtype,
num_tokens=num_tokens,
num_expert=num_expert,
hidden_size=hidden_size,
topK=topK,
num_out_tokens=num_out_tokens,
with_probs=with_probs,
BENCHMARK=Benchmark,
)
_test_permutation_mask_map(
te_dtype=te_dtype,
num_tokens=num_tokens,
num_expert=num_expert,
hidden_size=hidden_size,
topK=topK,
num_out_tokens=num_out_tokens,
with_probs=with_probs,
BENCHMARK=Benchmark,
)
_test_moe_chunk_sort(
te_dtype=te_dtype,
num_tokens=num_tokens,
num_expert=num_expert,
tp_size=4,
hidden_size=hidden_size,
BENCHMARK=Benchmark,
)
_test_permutation_mask_map_alongside_probs(
te_dtype=te_dtype,
num_tokens=num_tokens,
num_expert=num_expert,
hidden_size=hidden_size,
topK=topK,
num_out_tokens=num_out_tokens,
tp_size=4,
)
if __name__ == "__main__":
test_permutation_single_case()
# Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# See LICENSE for license information.
from typing import Iterable, Optional
import pytest
import torch
import transformer_engine.common.recipe
import transformer_engine.pytorch as te
import transformer_engine_torch as tex
from transformer_engine.pytorch.fp8 import (
FP8GlobalStateManager,
_amax_and_scale_update,
get_default_fp8_recipe,
)
from transformer_engine.pytorch.tensor.float8_tensor import Float8Quantizer
import transformer_engine.pytorch.ops as te_ops
# Check if FP8 is supported
fp8_available, reason_for_no_fp8 = FP8GlobalStateManager.is_fp8_available()
# FP8 per tensor delayed scaling
@pytest.mark.skipif(not fp8_available, reason=reason_for_no_fp8)
class TestFP8Recipe:
@staticmethod
def setup_class(cls) -> None:
# Configure RNG
seed = 1234
torch.manual_seed(seed)
torch.cuda.manual_seed(seed)
@pytest.mark.parametrize("amax_history_len", [31, 1024])
@pytest.mark.parametrize("amax_compute_algo", ["max", "most_recent"])
@pytest.mark.parametrize("is_first_microbatch", [None, True, False])
def test_fp8_scale_update_with_linear_module(
self,
amax_history_len: int,
amax_compute_algo: str,
is_first_microbatch: Optional[bool],
margin: int = 2,
):
# Construct linear module
fp8_format = transformer_engine.common.recipe.Format.HYBRID
recipe = transformer_engine.common.recipe.DelayedScaling(
margin=margin,
fp8_format=fp8_format,
amax_history_len=amax_history_len,
amax_compute_algo=amax_compute_algo,
)
with te.fp8_autocast(fp8_recipe=recipe):
module = te.Linear(16, 16)
y = module(
torch.randn([16, 16], device="cuda"),
is_first_microbatch=True,
)
y.backward(torch.zeros_like(y))
# Get amax history and scaling factors
fp8_meta = module.fp8_meta
forward_key = FP8GlobalStateManager.get_meta_tensor_key(forward=True)
amax_history_forward = fp8_meta[forward_key].amax_history
scale_forward = fp8_meta[forward_key].scale
# scale_inv_forward = fp8_meta[forward_key].scale_inv
backward_key = FP8GlobalStateManager.get_meta_tensor_key(forward=False)
amax_history_backward = fp8_meta[backward_key].amax_history
scale_backward = fp8_meta[backward_key].scale
# scale_inv_backward = fp8_meta[backward_key].scale_inv
# Tweak amax history and scaling factors
amax_history_forward.copy_(2 * torch.rand_like(amax_history_forward) + 0.5)
amax_history_forward[0, :].zero_()
scale_forward.copy_(2 * torch.rand_like(scale_forward) + 0.5)
# scale_inv_forward.copy_(torch.reciprocal(scale_forward))
amax_history_backward[0, :].zero_()
# Expected amax history after update
# Note: amax history is only updated when amax is updated
update_weight_amax = is_first_microbatch is None or is_first_microbatch
ref_amax_history_forward = amax_history_forward.clone()
ref_amax_history_forward[:, 0].copy_(torch.roll(amax_history_forward[:, 0], -1))
if update_weight_amax:
ref_amax_history_forward[:, 1].copy_(torch.roll(amax_history_forward[:, 1], -1))
ref_amax_history_forward[0, :].zero_()
ref_amax_history_backward = amax_history_backward.clone()
ref_amax_history_backward[:, 0].copy_(torch.roll(amax_history_backward[:, 0], -1))
ref_amax_history_backward[0, :].zero_()
# Expected scale and scale inverse
if amax_compute_algo == "max":
ref_amax_forward = amax_history_forward.max(dim=0).values
ref_amax_backward = amax_history_backward.max(dim=0).values
elif amax_compute_algo == "most_recent":
ref_amax_forward = amax_history_forward[-1]
ref_amax_backward = amax_history_backward[-1]
else:
raise ValueError(f"{amax_compute_algo=} is not supported")
ref_scale_forward = (fp8_format.value.max_fwd / ref_amax_forward) / (2**margin)
ref_scale_backward = (fp8_format.value.max_bwd / ref_amax_backward) / (2**margin)
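        # Worked example (hypothetical numbers): with the HYBRID format (max_fwd = 448),
        # margin = 2, and ref_amax_forward = 2.0, the expected scale is (448 / 2.0) / 2**2 = 56.0.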
# ref_scale_inv_forward = torch.reciprocal(ref_scale_forward)
update_weight_amax = is_first_microbatch is None or is_first_microbatch
# if not update_weight_amax:
# ref_scale_inv_forward[1].copy_(scale_inv_forward[1])
# ref_scale_inv_backward = torch.reciprocal(ref_scale_backward)
# Perform forward, backward, and optimizer steps to update fp8_meta
with te.fp8_autocast(enabled=True, fp8_recipe=recipe):
x = torch.randn([16, 16], device="cuda")
y = module(x, is_first_microbatch=is_first_microbatch)
y.backward(torch.randn_like(y))
# Check that amax history matches expected values
torch.testing.assert_close(
amax_history_forward[:-1],
ref_amax_history_forward[:-1],
)
torch.testing.assert_close(
amax_history_backward[:-1],
ref_amax_history_backward[:-1],
)
# Expected scale and scale inverse
if amax_compute_algo == "max":
ref_amax_forward = amax_history_forward.max(dim=0).values
ref_amax_backward = amax_history_backward.max(dim=0).values
elif amax_compute_algo == "most_recent":
ref_amax_forward = amax_history_forward[-1]
ref_amax_backward = amax_history_backward[-1]
else:
raise ValueError(f"{amax_compute_algo=} is not supported")
ref_scale_forward = (fp8_format.value.max_fwd / ref_amax_forward) / (2**margin)
ref_scale_backward = (fp8_format.value.max_bwd / ref_amax_backward) / (2**margin)
# ref_scale_inv_forward = torch.reciprocal(ref_scale_forward)
# ref_scale_inv_backward = torch.reciprocal(ref_scale_backward)
# Check that scale and scale inverse match expected values
# Note: scale and scale inverse are only updated when amax is updated
torch.testing.assert_close(
scale_forward[0],
ref_scale_forward[0],
)
if update_weight_amax:
torch.testing.assert_close(
scale_forward[1],
ref_scale_forward[1],
)
torch.testing.assert_close(
scale_backward[0],
ref_scale_backward[0],
)
@pytest.mark.parametrize("amax_history_len", [31, 1024])
@pytest.mark.parametrize("amax_compute_algo", ["max", "most_recent"])
def test_fp8_scale_update_with_linear_fuser_op(
self,
amax_history_len: int,
amax_compute_algo: str,
margin: float = 2,
num_steps: int = 4,
in_shape: tuple[int] = (16, 16),
dtype: torch.dtype = torch.float32,
device: torch.device = "cuda",
):
# Construct linear op
op = te_ops.BasicLinear(in_shape[-1], in_shape[-1])
# FP8 recipe
forward_key = FP8GlobalStateManager.get_meta_tensor_key(forward=True)
backward_key = FP8GlobalStateManager.get_meta_tensor_key(forward=False)
fp8_format = transformer_engine.common.recipe.Format.HYBRID
recipe = transformer_engine.common.recipe.DelayedScaling(
margin=margin,
interval=1,
fp8_format=fp8_format,
amax_history_len=amax_history_len,
amax_compute_algo=amax_compute_algo,
)
# Get FP8 meta tensors
with te.fp8_autocast(fp8_recipe=recipe):
x_fp8_meta = op.get_quantizer("forward", 0)
w_fp8_meta = op.get_quantizer("forward", 1)
dy_fp8_meta = op.get_quantizer("backward", 0)
# Perform training steps
x_history = []
w_history = []
dy_history = []
for step in range(num_steps):
# Fill tensors with known values
x_history.append(step + 0.25)
w_history.append(step + 0.5)
dy_history.append(step + 0.75)
x = torch.full(
in_shape,
x_history[-1],
dtype=dtype,
device=device,
requires_grad=True,
)
dy = torch.full(
in_shape,
dy_history[-1],
dtype=dtype,
device=device,
)
with torch.no_grad():
op.weight.fill_(w_history[-1])
# Forward and backward pass
with te.fp8_autocast(fp8_recipe=recipe):
y = op(x)
y.backward(dy)
def check_amax_history(
fp8_meta: dict,
ref_amax_history: Iterable[float],
) -> None:
"""Check that amax history matches expected values"""
if len(ref_amax_history) > amax_history_len:
ref_amax_history = ref_amax_history[-amax_history_len:]
ref_amax_history = torch.tensor(
ref_amax_history,
dtype=torch.float32,
device=device,
)
test_amax_history = fp8_meta.amax_history[:, 0]
tols = dict(rtol=0, atol=0)
torch.testing.assert_close(
test_amax_history[-(step + 1) :],
ref_amax_history[: (step + 1)],
**tols,
)
def check_scale(
quantizer: Float8Quantizer,
ref_amax_history: Iterable[float],
stage: str,
):
"""Check that scale and scale reciprocal match expected values"""
# Compute amax
if len(ref_amax_history) > amax_history_len:
ref_amax_history = ref_amax_history[-(amax_history_len + 1) :]
if amax_compute_algo == "max":
ref_amax = max(ref_amax_history)
elif amax_compute_algo == "most_recent":
ref_amax = ref_amax_history[-1]
else:
raise RuntimeError(f"{amax_compute_algo=} is not supported")
# Compute scale
max_val = {
"forward": 448.0,
"backward": 57344.0,
}[stage]
ref_scale = (max_val / ref_amax) / (2**margin)
# Check values in FP8 meta tensors
torch.testing.assert_close(
quantizer.scale.item(),
ref_scale,
)
# Check that results match expected values
check_scale(x_fp8_meta, x_history, "forward")
check_scale(w_fp8_meta, w_history, "forward")
check_scale(dy_fp8_meta, dy_history, "backward")
@pytest.mark.parametrize("amax_case", ["zero", "tiny", "normal", "inf", "nan"])
@pytest.mark.parametrize("fused_update", [True, False], ids=["fused", "non-fused"])
@pytest.mark.parametrize(
"fp8_dtype", [tex.DType.kFloat8E4M3, tex.DType.kFloat8E5M2], ids=["E4M3", "E5M2"]
)
def test_scale_update_numeric_scenarios(self, amax_case, fused_update, fp8_dtype):
if fp8_dtype == tex.DType.kFloat8E4M3:
fp8_format = transformer_engine.common.recipe.Format.E4M3
fp8_max = fp8_format.value.max_fwd
elif fp8_dtype == tex.DType.kFloat8E5M2:
fp8_format = transformer_engine.common.recipe.Format.HYBRID
fp8_max = fp8_format.value.max_bwd
else:
raise ValueError(f"{fp8_dtype=} is not supported")
scaling_factor_compute_algo = None
if fused_update:
scaling_factor_compute_algo = (
lambda amax, scale, fp8_max, recipe: te.fp8._default_sf_compute(
amax, scale, fp8_max, recipe.margin
)
)
recipe = transformer_engine.common.recipe.DelayedScaling(
fp8_format=fp8_format, scaling_factor_compute_algo=scaling_factor_compute_algo
)
# Setup fp8_meta dictionary
def setup_fp8_meta():
with te.fp8_autocast(fp8_recipe=recipe):
module = te.Linear(16, 16)
y = module(torch.zeros([16, 16], device="cuda"))
y.backward(torch.zeros_like(y))
return module.fp8_meta
fp8_meta = setup_fp8_meta()
forward_key = FP8GlobalStateManager.get_meta_tensor_key(forward=True)
# Replace the fp8_meta[forward_key] with a new TensorMeta for test purpose
fp8_meta[forward_key] = tex.FP8TensorMeta()
fp8_meta[forward_key].scale = torch.ones(1, dtype=torch.float32, device="cuda")
fp8_meta[forward_key].scale_inv = torch.ones(1, dtype=torch.float32, device="cuda")
# test different scenarios
if amax_case == "zero":
fp8_meta[forward_key].amax_history = torch.tensor(
[[0]], dtype=torch.float32, device="cuda"
)
expected_scale = torch.tensor([1.0], dtype=torch.float32, device="cuda")
elif amax_case == "tiny":
            # calculate the smallest amax for which the scale still fits in FP32 (scale == FP32 max)
fp32_max = torch.tensor(torch.finfo(torch.float32).max)
tiny_amax = fp8_max / fp32_max
            # make the amax smaller than that minimum so the naive scale would overflow FP32
amax_value = tiny_amax / 2
fp8_meta[forward_key].amax_history = torch.tensor(
[[amax_value]], dtype=torch.float32, device="cuda"
)
# expected scale is FP32_max
expected_scale = fp32_max.view(1).cuda()
elif amax_case == "normal":
# plus a small epsilon to avoid zero amax
amax_value = torch.rand(1, dtype=torch.float32, device="cuda") + 1e-5
fp8_meta[forward_key].amax_history = amax_value.view(1, 1)
expected_scale = fp8_max / amax_value
elif amax_case == "inf":
fp8_meta[forward_key].amax_history = torch.tensor(
[[torch.inf]], dtype=torch.float32, device="cuda"
)
expected_scale = torch.tensor([1.0], dtype=torch.float32, device="cuda")
elif amax_case == "nan":
fp8_meta[forward_key].amax_history = torch.tensor(
[[torch.nan]], dtype=torch.float32, device="cuda"
)
expected_scale = torch.tensor([1.0], dtype=torch.float32, device="cuda")
if fused_update:
tex.fused_amax_and_scale_update_after_reduction(
fp8_meta[forward_key].amax_history.clone().view(-1),
[fp8_meta[forward_key].amax_history],
[fp8_meta[forward_key].scale],
recipe.amax_compute_algo,
fp8_dtype,
recipe.margin,
)
else:
_amax_and_scale_update(
fp8_meta[forward_key].amax_history,
fp8_meta[forward_key].scale,
fp8_max,
recipe,
)
torch.testing.assert_close(fp8_meta[forward_key].scale, expected_scale)
# Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# See LICENSE for license information.
from dataclasses import dataclass
from typing import Optional
from contextlib import nullcontext
import torch
import pytest
import os
from transformer_engine.pytorch.fp8 import (
fp8_autocast,
FP8GlobalStateManager,
fp8_model_init,
)
from transformer_engine.pytorch.utils import (
get_device_compute_capability,
init_method_normal,
scaled_init_method_normal,
is_bf16_compatible,
get_cudnn_version,
)
from transformer_engine.pytorch import (
LayerNormLinear,
Linear,
GroupedLinear,
LayerNormMLP,
TransformerLayer,
RMSNorm,
LayerNorm,
get_cpu_offload_context,
)
from transformer_engine.common import recipe
import transformer_engine_torch as tex
from transformer_engine.pytorch.cpp_extensions import general_gemm
from transformer_engine.pytorch.module.base import get_workspace
from transformer_engine.pytorch.tensor.float8_tensor import Float8Quantizer
from test_numerics import reset_rng_states, dtype_tols
# Only run FP8 tests on supported devices.
fp8_available, reason_for_no_fp8 = FP8GlobalStateManager.is_fp8_available()
mxfp8_available, reason_for_no_mxfp8 = FP8GlobalStateManager.is_mxfp8_available()
def create_meta(scale_factor: float, size: int = 1):
meta = tex.FP8TensorMeta()
meta.amax_history = torch.zeros(1, size, dtype=torch.float32, device="cuda")
meta.scale_inv = torch.ones(size, dtype=torch.float32, device="cuda") / scale_factor
meta.scale = torch.ones(size, dtype=torch.float32, device="cuda") * scale_factor
return meta
def custom_amax_to_scale(
amax: torch.Tensor,
scale: torch.Tensor,
fp8_max: torch.Tensor,
recipe: recipe.DelayedScaling,
) -> torch.Tensor:
"""Custom func to test recipe."""
sf = fp8_max / amax
sf = torch.where(amax > 0.0, sf, scale)
sf = torch.where(torch.isfinite(amax), sf, scale)
return sf
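# Behavior sketch (illustrative, not executed): for a finite, positive amax the new scale is
# fp8_max / amax; for amax == 0, inf, or NaN the previous scale is kept unchanged.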
def custom_amax_compute(amax_history: torch.Tensor) -> torch.Tensor:
"""Custom func to test recipe."""
return torch.min(amax_history, dim=0).values
@dataclass
class ModelConfig:
"""Transformer model configuration"""
num_layers: int
seq_len: int
batch_size: int
hidden_size: int
num_attention_heads: int
kv_channels: Optional[int] = None
def is_fp8_supported(self):
if self.seq_len * self.batch_size % 16:
return False
if self.hidden_size % 16:
return False
return True
model_configs = {
"126m": ModelConfig(12, 2048, 2, 768, 12),
"small": ModelConfig(2, 32, 2, 64, 2),
"weird": ModelConfig(2, 37, 3, 69, 3),
"large": ModelConfig(1, 128, 2, 512, 4, 128),
}
fp8_recipes = [
None, # Handles non-FP8 case
recipe.MXFP8BlockScaling(),
recipe.DelayedScaling(margin=0, fp8_format=recipe.Format.E4M3),
recipe.DelayedScaling(margin=0, fp8_format=recipe.Format.HYBRID),
recipe.DelayedScaling(
margin=0,
fp8_format=recipe.Format.E4M3,
amax_history_len=16,
amax_compute_algo="most_recent",
),
recipe.DelayedScaling(
margin=0,
fp8_format=recipe.Format.E4M3,
amax_history_len=16,
amax_compute_algo="max",
),
recipe.DelayedScaling(
margin=0,
fp8_format=recipe.Format.E4M3,
amax_history_len=16,
amax_compute_algo=custom_amax_compute,
),
recipe.DelayedScaling(
margin=0,
fp8_format=recipe.Format.E4M3,
amax_history_len=16,
scaling_factor_compute_algo=custom_amax_to_scale,
),
]
param_types = [torch.float32, torch.float16]
if is_bf16_compatible(): # bf16 requires sm_80 or higher
param_types.append(torch.bfloat16)
all_boolean = [True, False]
batch_sizes_with_zero = [0, 1, 2]
all_activations = ["gelu", "relu", "reglu", "geglu", "swiglu", "srelu", "qgelu", "qgeglu"]
all_normalizations = ["LayerNorm", "RMSNorm"]
def _disable_wgrads(block):
for p in block.parameters():
p.requires_grad = False
@pytest.fixture(autouse=True)
def reset_global_fp8_state():
yield
FP8GlobalStateManager.reset()
def _test_sanity_e2e_cuda_graph(block, dtype, config, fp8_recipe, skip_wgrad):
# Initialize loss function and optimizer.
loss_fn = torch.nn.MSELoss()
optimizer = torch.optim.SGD(block.parameters(), lr=0.1)
# Placeholders used for capture.
static_input = torch.randn(
config.seq_len,
config.batch_size,
config.hidden_size,
device="cuda",
dtype=dtype,
requires_grad=True,
)
static_target = torch.randn(
config.seq_len, config.batch_size, config.hidden_size, device="cuda", dtype=dtype
)
real_input = torch.rand_like(static_input)
real_target = torch.rand_like(static_target)
use_fp8 = fp8_recipe is not None
if skip_wgrad:
_disable_wgrads(block)
# Pre graph capture warmup in a separate stream.
s = torch.cuda.Stream()
s.wait_stream(torch.cuda.current_stream())
with torch.cuda.stream(s):
for _ in range(3):
optimizer.zero_grad(set_to_none=True)
with fp8_autocast(enabled=use_fp8, fp8_recipe=fp8_recipe, _graph=True):
out = block(static_input)
loss = loss_fn(out, static_target)
loss.backward()
optimizer.step()
torch.cuda.current_stream().wait_stream(s)
# Capture.
g = torch.cuda.CUDAGraph()
optimizer.zero_grad(set_to_none=True)
with torch.cuda.graph(g):
with fp8_autocast(enabled=use_fp8, fp8_recipe=fp8_recipe, _graph=True):
static_output = block(static_input)
static_loss = loss_fn(static_output, static_target)
static_loss.backward()
optimizer.step()
# Fills the graph's input memory with new data to compute on
with torch.no_grad():
static_input.copy_(real_input)
static_target.copy_(real_target)
g.replay()
torch.cuda.synchronize()
def _test_sanity_e2e_amp(block, dtype, config, fp8_recipe, skip_wgrad):
te_inp_hidden_states = torch.randn(
(config.seq_len, config.batch_size, config.hidden_size),
dtype=torch.float32,
device="cuda",
requires_grad=True,
)
te_inp_hidden_states.retain_grad()
te_inp_attn_mask = torch.randint(
2,
(1, 1, config.seq_len, config.seq_len),
dtype=torch.bool,
device="cuda",
)
if skip_wgrad:
_disable_wgrads(block)
use_fp8 = fp8_recipe is not None
with torch.autocast(device_type="cuda", enabled=True, dtype=dtype):
with fp8_autocast(enabled=use_fp8, fp8_recipe=fp8_recipe):
te_out = block(te_inp_hidden_states, attention_mask=te_inp_attn_mask)
loss = te_out.sum()
loss.backward()
torch.cuda.synchronize()
assert te_out.dtype == dtype, "AMP wrong output type."
assert te_inp_hidden_states.grad is not None, "Gradient should not be empty"
assert te_inp_hidden_states.grad.dtype == torch.float32, "AMP wrong dgrad type."
for name, p in block.named_parameters():
if p.requires_grad:
assert p.grad.dtype == torch.float32, f"AMP wrong wgrad type for {name}."
def _test_sanity_e2e_gradient_accumulation_fusion(block, dtype, config, fp8_recipe, skip_wgrad):
te_inp_hidden_states = torch.randn(
(config.seq_len, config.batch_size, config.hidden_size),
dtype=dtype,
device="cuda",
requires_grad=True,
)
te_inp_attn_mask = torch.randint(
2,
(1, 1, config.seq_len, config.seq_len),
dtype=torch.bool,
device="cuda",
)
if skip_wgrad:
_disable_wgrads(block)
for name, p in block.named_parameters():
if "layer_norm_weight" in name:
continue
elif "weight" in name and p.requires_grad:
p.main_grad = torch.zeros_like(p)
use_fp8 = fp8_recipe is not None
with fp8_autocast(enabled=use_fp8, fp8_recipe=fp8_recipe):
te_out = block(te_inp_hidden_states, attention_mask=te_inp_attn_mask)
loss = te_out.sum()
loss.backward()
torch.cuda.synchronize()
failed_grads = []
for name, p in block.named_parameters():
if "layer_norm_weight" in name:
continue
elif "weight" in name and p.requires_grad:
if not torch.count_nonzero(p.main_grad) > 0:
failed_grads.append(name)
assert len(failed_grads) == 0, f"Gradient not accumulated for {failed_grads}."
def _test_sanity_e2e(block, dtype, config, fp8_recipe, skip_wgrad, cpu_offload):
te_inp_hidden_states = torch.randn(
(config.seq_len, config.batch_size, config.hidden_size),
dtype=dtype,
device="cuda",
requires_grad=True,
)
if skip_wgrad:
_disable_wgrads(block)
if cpu_offload:
offload_context, sync_function = get_cpu_offload_context(enabled=True)
else:
offload_context = nullcontext()
sync_function = lambda x: x
use_fp8 = fp8_recipe is not None
with fp8_autocast(enabled=use_fp8, fp8_recipe=fp8_recipe), offload_context:
te_out = block(te_inp_hidden_states)
te_out = sync_function(te_out)
loss = te_out.sum()
loss.backward()
torch.cuda.synchronize()
def _test_sanity_e2e_bert(block, dtype, config, fp8_recipe, skip_wgrad):
te_inp_hidden_states = torch.randn(
(config.seq_len, config.batch_size, config.hidden_size),
dtype=dtype,
device="cuda",
requires_grad=True,
)
te_inp_attn_mask = torch.randint(
2,
(config.batch_size, 1, 1, config.seq_len),
dtype=torch.bool,
device="cuda",
)
if skip_wgrad:
_disable_wgrads(block)
use_fp8 = fp8_recipe is not None
with fp8_autocast(enabled=use_fp8, fp8_recipe=fp8_recipe):
te_out = block(te_inp_hidden_states, attention_mask=te_inp_attn_mask)
loss = te_out.sum()
loss.backward()
torch.cuda.synchronize()
def _test_sanity_e2e_T5(block, dtype, config, fp8_recipe, skip_wgrad):
te_inp_hidden_states = torch.randn(
(config.seq_len, config.batch_size, config.hidden_size),
dtype=dtype,
device="cuda",
requires_grad=True,
)
te_inp_attn_mask = torch.randint(
2,
(1, 1, config.seq_len, config.seq_len),
dtype=torch.bool,
device="cuda",
)
enc_dec_attn_mask = torch.randint(
2,
(config.batch_size, 1, 1, config.seq_len),
dtype=torch.bool,
device="cuda",
)
if skip_wgrad:
_disable_wgrads(block)
use_fp8 = fp8_recipe is not None
with fp8_autocast(enabled=use_fp8, fp8_recipe=fp8_recipe):
te_out = block(
te_inp_hidden_states,
attention_mask=te_inp_attn_mask,
encoder_output=te_inp_hidden_states,
enc_dec_attn_mask=enc_dec_attn_mask,
)
loss = te_out.sum()
loss.backward()
torch.cuda.synchronize()
def _test_sanity_common(block, dtype, config, fp8_recipe, skip_wgrad, skip_dgrad):
if skip_dgrad and skip_wgrad:
pytest.skip("No gradient computation; Skipping to avoid PyTorch RuntimeError.")
te_inp = torch.randn(
(config.seq_len, config.batch_size, config.hidden_size),
dtype=dtype,
device="cuda",
requires_grad=not skip_dgrad,
)
if skip_wgrad:
_disable_wgrads(block)
use_fp8 = fp8_recipe is not None
with fp8_autocast(enabled=use_fp8, fp8_recipe=fp8_recipe):
te_out = block(te_inp)
if isinstance(te_out, tuple):
te_out = te_out[0]
loss = te_out.sum()
loss.backward()
torch.cuda.synchronize()
def _test_sanity_normalization_amp(block, dtype, config, skip_wgrad, skip_dgrad):
if skip_dgrad and skip_wgrad:
pytest.skip("No gradient computation; Skipping to avoid PyTorch RuntimeError.")
te_inp = torch.randn(
(config.seq_len, config.batch_size, config.hidden_size),
device="cuda",
requires_grad=True,
)
te_inp.retain_grad()
with torch.autocast(device_type="cuda", enabled=True, dtype=dtype):
te_out = block(te_inp)
loss = te_out.sum()
loss.backward()
torch.cuda.synchronize()
assert te_out.dtype == dtype, "AMP wrong output type."
assert te_inp.grad is not None, "Gradient should not be empty"
assert te_inp.grad.dtype == torch.float32, "AMP wrong dgrad type."
for name, p in block.named_parameters():
if p.requires_grad:
assert p.grad.dtype == torch.float32, f"AMP wrong wgrad type for {name}."
@pytest.mark.parametrize("dtype", param_types)
@pytest.mark.parametrize("model", ["small", "weird"])
@pytest.mark.parametrize("skip_wgrad", all_boolean)
@pytest.mark.parametrize("skip_dgrad", all_boolean)
@pytest.mark.parametrize("normalization", all_normalizations)
def test_sanity_normalization_amp(dtype, model, skip_wgrad, skip_dgrad, normalization):
config = model_configs[model]
module = RMSNorm if normalization == "RMSNorm" else LayerNorm
block = module(config.hidden_size).to(dtype=torch.float32).cuda()
_test_sanity_normalization_amp(block, dtype, config, skip_wgrad, skip_dgrad)
@pytest.mark.parametrize("dtype", param_types)
@pytest.mark.parametrize("fp8_recipe", fp8_recipes)
@pytest.mark.parametrize("model", ["small", "weird"])
@pytest.mark.parametrize("skip_wgrad", all_boolean)
@pytest.mark.parametrize("zero_centered_gamma", all_boolean)
@pytest.mark.parametrize("skip_dgrad", all_boolean)
@pytest.mark.parametrize("normalization", all_normalizations)
def test_sanity_layernorm_linear(
dtype, fp8_recipe, model, skip_wgrad, zero_centered_gamma, skip_dgrad, normalization
):
config = model_configs[model]
if fp8_recipe is not None:
if not fp8_available:
pytest.skip(reason_for_no_fp8)
if fp8_recipe.mxfp8() and not mxfp8_available:
pytest.skip(reason_for_no_mxfp8)
if not config.is_fp8_supported():
pytest.skip("Model config does not support FP8")
sigma = 0.023
init_method = init_method_normal(sigma)
block = LayerNormLinear(
config.hidden_size,
config.hidden_size * 3,
init_method=init_method,
zero_centered_gamma=zero_centered_gamma,
normalization=normalization,
params_dtype=dtype,
device="cuda",
)
_test_sanity_common(block, dtype, config, fp8_recipe, skip_wgrad, skip_dgrad)
@pytest.mark.parametrize("dtype", param_types)
@pytest.mark.parametrize("fp8_recipe", fp8_recipes)
@pytest.mark.parametrize("model", ["small", "weird"])
@pytest.mark.parametrize("skip_wgrad", all_boolean)
@pytest.mark.parametrize("skip_dgrad", all_boolean)
def test_sanity_linear(dtype, fp8_recipe, model, skip_wgrad, skip_dgrad):
config = model_configs[model]
if fp8_recipe is not None:
if not fp8_available:
pytest.skip(reason_for_no_fp8)
if fp8_recipe.mxfp8() and not mxfp8_available:
pytest.skip(reason_for_no_mxfp8)
if not config.is_fp8_supported():
pytest.skip("Model config does not support FP8")
sigma = 0.023
output_layer_init_method = scaled_init_method_normal(sigma, config.num_layers)
block = Linear(
config.hidden_size,
config.hidden_size,
init_method=output_layer_init_method,
params_dtype=dtype,
device="cuda",
)
_test_sanity_common(block, dtype, config, fp8_recipe, skip_wgrad, skip_dgrad)
@pytest.mark.parametrize("dtype", param_types)
@pytest.mark.parametrize("bs", batch_sizes_with_zero)
@pytest.mark.parametrize("model", ["small", "weird"])
@pytest.mark.parametrize("fp8_recipe", fp8_recipes)
@pytest.mark.parametrize("fp8_model_params", all_boolean)
@pytest.mark.parametrize("use_bias", all_boolean)
def test_sanity_linear_with_zero_tokens(dtype, bs, model, fp8_recipe, fp8_model_params, use_bias):
config = model_configs[model]
ffn_hidden_size = 4 * config.hidden_size
num_tokens = bs * config.seq_len
if fp8_recipe is not None:
if not fp8_available:
pytest.skip(reason_for_no_fp8)
if fp8_recipe.mxfp8() and not mxfp8_available:
pytest.skip(reason_for_no_mxfp8)
if not config.is_fp8_supported():
pytest.skip("Model config does not support FP8")
use_fp8 = fp8_recipe is not None
with fp8_model_init(enabled=use_fp8 and fp8_model_params, recipe=fp8_recipe):
te_linear = Linear(
config.hidden_size, ffn_hidden_size, bias=use_bias, params_dtype=dtype
).cuda()
inp_hidden_states = torch.randn(
num_tokens, config.hidden_size, dtype=dtype, requires_grad=True
).cuda()
with fp8_autocast(enabled=use_fp8, fp8_recipe=fp8_recipe):
out = te_linear(inp_hidden_states)
loss = out.sum()
loss.backward()
assert out.shape == (num_tokens, ffn_hidden_size)
@pytest.mark.parametrize("dtype", param_types)
@pytest.mark.parametrize("bs", batch_sizes_with_zero)
@pytest.mark.parametrize("model", ["small", "weird"])
@pytest.mark.parametrize("fp8_recipe", fp8_recipes)
@pytest.mark.parametrize("fp8_model_params", all_boolean)
@pytest.mark.parametrize("use_bias", all_boolean)
@pytest.mark.parametrize("empty_split", ["first", "last", "middle"])
@pytest.mark.parametrize("num_gemms", [4])
def test_sanity_grouped_linear(
dtype, bs, model, fp8_recipe, fp8_model_params, use_bias, num_gemms, empty_split
):
config = model_configs[model]
ffn_hidden_size = 4 * config.hidden_size
# Small batch size used to catch bug from https://github.com/NVIDIA/TransformerEngine/pull/1527.
bs = bs * 16
num_tokens = bs * config.seq_len * (num_gemms - 1)
if fp8_recipe is not None:
if not fp8_available:
pytest.skip(reason_for_no_fp8)
if fp8_recipe.mxfp8():
pytest.skip("Grouped linear does not support MXFP8")
if not config.is_fp8_supported():
pytest.skip("Model config does not support FP8")
use_fp8 = fp8_recipe is not None
with fp8_model_init(enabled=use_fp8 and fp8_model_params, recipe=fp8_recipe):
te_grouped_linear = GroupedLinear(
num_gemms, config.hidden_size, ffn_hidden_size, bias=use_bias, params_dtype=dtype
).cuda()
inp_hidden_states = torch.randn(
num_tokens, config.hidden_size, dtype=dtype, requires_grad=True
).cuda()
m_splits = [bs * config.seq_len] * num_gemms
if empty_split == "first":
m_splits[0] = 0
elif empty_split == "last":
m_splits[-1] = 0
elif empty_split == "middle":
m_splits[num_gemms // 2] = 0
with fp8_autocast(enabled=use_fp8, fp8_recipe=fp8_recipe):
out = te_grouped_linear(inp_hidden_states, m_splits)
loss = out.sum()
loss.backward()
assert out.shape == (num_tokens, ffn_hidden_size)
@pytest.mark.parametrize("dtype", param_types)
@pytest.mark.parametrize("fp8_recipe", fp8_recipes)
@pytest.mark.parametrize("model", ["small", "weird"])
@pytest.mark.parametrize("skip_wgrad", all_boolean)
@pytest.mark.parametrize("zero_centered_gamma", all_boolean)
@pytest.mark.parametrize("skip_dgrad", all_boolean)
@pytest.mark.parametrize("activation", all_activations)
@pytest.mark.parametrize("normalization", all_normalizations)
def test_sanity_layernorm_mlp(
dtype, fp8_recipe, model, skip_wgrad, zero_centered_gamma, skip_dgrad, activation, normalization
):
config = model_configs[model]
if fp8_recipe is not None:
if not fp8_available:
pytest.skip(reason_for_no_fp8)
if fp8_recipe.mxfp8() and not mxfp8_available:
pytest.skip(reason_for_no_mxfp8)
if not config.is_fp8_supported():
pytest.skip("Model config does not support FP8")
sigma = 0.023
init_method = init_method_normal(sigma)
output_layer_init_method = scaled_init_method_normal(sigma, config.num_layers)
block = LayerNormMLP(
config.hidden_size,
4 * config.hidden_size,
init_method=init_method,
output_layer_init_method=output_layer_init_method,
zero_centered_gamma=zero_centered_gamma,
activation=activation,
normalization=normalization,
params_dtype=dtype,
device="cuda",
)
_test_sanity_common(block, dtype, config, fp8_recipe, skip_wgrad, skip_dgrad)
@pytest.mark.parametrize("dtype", param_types)
@pytest.mark.parametrize("fp8_recipe", fp8_recipes)
@pytest.mark.parametrize("model", ["small"])
@pytest.mark.parametrize("skip_wgrad", all_boolean)
@pytest.mark.parametrize("zero_centered_gamma", all_boolean)
@pytest.mark.parametrize("bias", all_boolean)
@pytest.mark.parametrize("activation", all_activations)
@pytest.mark.parametrize("normalization", all_normalizations)
@pytest.mark.parametrize("parallel_attention_mlp", all_boolean)
@pytest.mark.parametrize("cpu_offload", all_boolean)
def test_sanity_gpt(
dtype,
fp8_recipe,
model,
skip_wgrad,
zero_centered_gamma,
bias,
activation,
normalization,
parallel_attention_mlp,
cpu_offload,
):
config = model_configs[model]
if fp8_recipe is not None:
if not fp8_available:
pytest.skip(reason_for_no_fp8)
if fp8_recipe.mxfp8() and not mxfp8_available:
pytest.skip(reason_for_no_mxfp8)
if not config.is_fp8_supported():
pytest.skip("Model config does not support FP8")
sigma = 0.023
init_method = init_method_normal(sigma)
output_layer_init_method = scaled_init_method_normal(sigma, config.num_layers)
block = TransformerLayer(
config.hidden_size,
4 * config.hidden_size,
config.num_attention_heads,
init_method=init_method,
output_layer_init_method=output_layer_init_method,
hidden_dropout=0.1,
attention_dropout=0.1,
kv_channels=config.kv_channels,
params_dtype=dtype,
apply_residual_connection_post_layernorm=False,
output_layernorm=False,
zero_centered_gamma=zero_centered_gamma,
bias=bias,
activation=activation,
normalization=normalization,
device="cuda",
parallel_attention_mlp=parallel_attention_mlp,
)
_test_sanity_e2e(block, dtype, config, fp8_recipe, skip_wgrad, cpu_offload)
def test_sanity_gpt_126m():
fp8_recipe = None
if fp8_available:
fp8_recipe = recipe.DelayedScaling(
margin=0,
fp8_format=recipe.Format.E4M3,
amax_history_len=16,
amax_compute_algo="most_recent",
)
test_sanity_gpt(
dtype=param_types[-1],
fp8_recipe=fp8_recipe,
model="126m",
skip_wgrad=False,
zero_centered_gamma=True,
bias=True,
activation="gelu",
normalization="LayerNorm",
parallel_attention_mlp=False,
cpu_offload=False,
)
@pytest.mark.parametrize("dtype", param_types)
@pytest.mark.parametrize("fp8_recipe", fp8_recipes)
@pytest.mark.parametrize("model", ["small"])
@pytest.mark.parametrize("skip_wgrad", all_boolean)
@pytest.mark.parametrize("zero_centered_gamma", all_boolean)
@pytest.mark.parametrize("normalization", all_normalizations)
def test_sanity_bert(dtype, fp8_recipe, model, skip_wgrad, zero_centered_gamma, normalization):
config = model_configs[model]
if fp8_recipe is not None:
if not fp8_available:
pytest.skip(reason_for_no_fp8)
if fp8_recipe.mxfp8() and not mxfp8_available:
pytest.skip(reason_for_no_mxfp8)
if not config.is_fp8_supported():
pytest.skip("Model config does not support FP8")
sigma = 0.023
init_method = init_method_normal(sigma)
output_layer_init_method = scaled_init_method_normal(sigma, config.num_layers)
block = TransformerLayer(
config.hidden_size,
4 * config.hidden_size,
config.num_attention_heads,
init_method=init_method,
output_layer_init_method=output_layer_init_method,
hidden_dropout=0.1,
attention_dropout=0.1,
kv_channels=config.kv_channels,
params_dtype=dtype,
apply_residual_connection_post_layernorm=True,
output_layernorm=True,
zero_centered_gamma=zero_centered_gamma,
self_attn_mask_type="causal",
normalization=normalization,
device="cuda",
)
_test_sanity_e2e_bert(block, dtype, config, fp8_recipe, skip_wgrad)
def test_sanity_bert_126m():
fp8_recipe = recipe.DelayedScaling(
margin=0,
fp8_format=recipe.Format.E4M3,
amax_history_len=1,
amax_compute_algo="most_recent",
)
test_sanity_bert(
dtype=param_types[-1],
fp8_recipe=fp8_recipe,
model="126m",
skip_wgrad=False,
zero_centered_gamma=False,
normalization="LayerNorm",
)
@pytest.mark.parametrize("dtype", param_types)
@pytest.mark.parametrize("fp8_recipe", fp8_recipes)
@pytest.mark.parametrize("model", ["small"])
@pytest.mark.parametrize("skip_wgrad", all_boolean)
@pytest.mark.parametrize("zero_centered_gamma", all_boolean)
@pytest.mark.parametrize("normalization", all_normalizations)
def test_sanity_T5(dtype, fp8_recipe, model, skip_wgrad, zero_centered_gamma, normalization):
config = model_configs[model]
if fp8_recipe is not None:
if not fp8_available:
pytest.skip(reason_for_no_fp8)
if fp8_recipe.mxfp8() and not mxfp8_available:
pytest.skip(reason_for_no_mxfp8)
if not config.is_fp8_supported():
pytest.skip("Model config does not support FP8")
sigma = 0.023
init_method = init_method_normal(sigma)
output_layer_init_method = scaled_init_method_normal(sigma, config.num_layers)
block = TransformerLayer(
config.hidden_size,
4 * config.hidden_size,
config.num_attention_heads,
init_method=init_method,
output_layer_init_method=output_layer_init_method,
hidden_dropout=0.1,
attention_dropout=0.1,
kv_channels=config.kv_channels,
params_dtype=dtype,
apply_residual_connection_post_layernorm=False,
output_layernorm=False,
layer_type="decoder",
zero_centered_gamma=zero_centered_gamma,
normalization=normalization,
device="cuda",
)
_test_sanity_e2e_T5(block, dtype, config, fp8_recipe, skip_wgrad)
def test_sanity_T5_126m():
fp8_recipe = recipe.DelayedScaling(
margin=0,
fp8_format=recipe.Format.E4M3,
amax_history_len=1,
amax_compute_algo="most_recent",
)
test_sanity_T5(
dtype=param_types[-1],
fp8_recipe=fp8_recipe,
model="126m",
skip_wgrad=False,
zero_centered_gamma=False,
normalization="LayerNorm",
)
@pytest.mark.parametrize("dtype", param_types)
@pytest.mark.parametrize("fp8_recipe", fp8_recipes)
@pytest.mark.parametrize("model", ["small"])
@pytest.mark.parametrize("skip_wgrad", all_boolean)
def test_sanity_amp_and_nvfuser(dtype, fp8_recipe, model, skip_wgrad):
config = model_configs[model]
if fp8_recipe is not None:
if not fp8_available:
pytest.skip(reason_for_no_fp8)
if fp8_recipe.mxfp8() and not mxfp8_available:
pytest.skip(reason_for_no_mxfp8)
if not config.is_fp8_supported():
pytest.skip("Model config does not support FP8")
sigma = 0.023
init_method = init_method_normal(sigma)
output_layer_init_method = scaled_init_method_normal(sigma, config.num_layers)
block = TransformerLayer(
config.hidden_size,
4 * config.hidden_size,
config.num_attention_heads,
init_method=init_method,
output_layer_init_method=output_layer_init_method,
hidden_dropout=0.1,
attention_dropout=0.1,
kv_channels=config.kv_channels,
params_dtype=torch.float32,
device="cuda",
)
_test_sanity_e2e_amp(block, dtype, config, fp8_recipe, skip_wgrad)
@pytest.mark.parametrize("dtype", param_types)
@pytest.mark.parametrize("fp8_recipe", fp8_recipes)
@pytest.mark.parametrize("model", ["small"])
@pytest.mark.parametrize("skip_wgrad", all_boolean)
def test_sanity_drop_path(dtype, fp8_recipe, model, skip_wgrad):
config = model_configs[model]
if fp8_recipe is not None:
if not fp8_available:
pytest.skip(reason_for_no_fp8)
if fp8_recipe.mxfp8() and not mxfp8_available:
pytest.skip(reason_for_no_mxfp8)
if not config.is_fp8_supported():
pytest.skip("Model config does not support FP8")
sigma = 0.023
init_method = init_method_normal(sigma)
output_layer_init_method = scaled_init_method_normal(sigma, config.num_layers)
block = TransformerLayer(
config.hidden_size,
4 * config.hidden_size,
config.num_attention_heads,
init_method=init_method,
output_layer_init_method=output_layer_init_method,
hidden_dropout=0.1,
attention_dropout=0.1,
kv_channels=config.kv_channels,
params_dtype=dtype,
apply_residual_connection_post_layernorm=False,
output_layernorm=False,
drop_path_rate=1.0,
device="cuda",
)
_test_sanity_e2e(block, dtype, config, fp8_recipe, skip_wgrad, False)
@pytest.mark.parametrize("dtype", param_types)
@pytest.mark.parametrize("fp8_recipe", fp8_recipes)
@pytest.mark.parametrize("model", ["small"])
@pytest.mark.parametrize("skip_wgrad", all_boolean)
def test_sanity_fused_qkv_params(dtype, fp8_recipe, model, skip_wgrad):
config = model_configs[model]
if fp8_recipe is not None:
if not fp8_available:
pytest.skip(reason_for_no_fp8)
if fp8_recipe.mxfp8() and not mxfp8_available:
pytest.skip(reason_for_no_mxfp8)
if not config.is_fp8_supported():
pytest.skip("Model config does not support FP8")
sigma = 0.023
init_method = init_method_normal(sigma)
output_layer_init_method = scaled_init_method_normal(sigma, config.num_layers)
block = TransformerLayer(
config.hidden_size,
4 * config.hidden_size,
config.num_attention_heads,
init_method=init_method,
output_layer_init_method=output_layer_init_method,
hidden_dropout=0.1,
attention_dropout=0.1,
kv_channels=config.kv_channels,
params_dtype=dtype,
apply_residual_connection_post_layernorm=False,
output_layernorm=False,
fuse_qkv_params=True,
device="cuda",
)
_test_sanity_e2e(block, dtype, config, fp8_recipe, skip_wgrad, False)
@pytest.mark.parametrize("dtype", param_types)
@pytest.mark.parametrize("fp8_recipe", fp8_recipes)
@pytest.mark.parametrize("model", ["small"])
@pytest.mark.parametrize("skip_wgrad", all_boolean)
@pytest.mark.parametrize("zero_centered_gamma", all_boolean)
def test_sanity_gradient_accumulation_fusion(
dtype, fp8_recipe, model, skip_wgrad, zero_centered_gamma
):
config = model_configs[model]
if fp8_recipe is not None:
if not fp8_available:
pytest.skip(reason_for_no_fp8)
if fp8_recipe.mxfp8() and not mxfp8_available:
pytest.skip(reason_for_no_mxfp8)
if not config.is_fp8_supported():
pytest.skip("Model config does not support FP8")
sigma = 0.023
init_method = init_method_normal(sigma)
output_layer_init_method = scaled_init_method_normal(sigma, config.num_layers)
block = TransformerLayer(
config.hidden_size,
4 * config.hidden_size,
config.num_attention_heads,
init_method=init_method,
output_layer_init_method=output_layer_init_method,
hidden_dropout=0.1,
attention_dropout=0.1,
kv_channels=config.kv_channels,
params_dtype=dtype,
apply_residual_connection_post_layernorm=False,
output_layernorm=False,
zero_centered_gamma=zero_centered_gamma,
fuse_qkv_params=True,
fuse_wgrad_accumulation=True,
device="cuda",
)
_test_sanity_e2e_gradient_accumulation_fusion(block, dtype, config, fp8_recipe, skip_wgrad)
@pytest.mark.parametrize("dtype", param_types)
@pytest.mark.parametrize("fp8_recipe", fp8_recipes)
@pytest.mark.parametrize("model", ["small"])
@pytest.mark.parametrize("skip_wgrad", all_boolean)
@pytest.mark.parametrize("zero_centered_gamma", all_boolean)
@pytest.mark.parametrize("normalization", all_normalizations)
def test_gpt_cuda_graph(dtype, fp8_recipe, model, skip_wgrad, zero_centered_gamma, normalization):
config = model_configs[model]
if fp8_recipe is not None:
if not fp8_available:
pytest.skip(reason_for_no_fp8)
if fp8_recipe.mxfp8() and not mxfp8_available:
pytest.skip(reason_for_no_mxfp8)
if not config.is_fp8_supported():
pytest.skip("Model config does not support FP8")
sigma = 0.023
init_method = init_method_normal(sigma)
output_layer_init_method = scaled_init_method_normal(sigma, config.num_layers)
block = TransformerLayer(
config.hidden_size,
4 * config.hidden_size,
config.num_attention_heads,
init_method=init_method,
output_layer_init_method=output_layer_init_method,
hidden_dropout=0.1,
attention_dropout=0.1,
kv_channels=config.kv_channels,
params_dtype=dtype,
apply_residual_connection_post_layernorm=False,
output_layernorm=False,
zero_centered_gamma=zero_centered_gamma,
fuse_qkv_params=True,
normalization=normalization,
device="cuda",
)
_test_sanity_e2e_cuda_graph(block, dtype, config, fp8_recipe, skip_wgrad)
def test_model_multiple_cast():
a = torch.zeros((16, 16), device="cuda")
m = Linear(16, 32)
y = m(a)
assert y.dtype == torch.float32
m.half()
a = a.half()
y2 = m(a)
assert y2.dtype == torch.float16
@pytest.mark.parametrize("N", [32])
@pytest.mark.parametrize("offset", [1, 3, 5])
@pytest.mark.parametrize("datatype", param_types)
def test_sanity_gemm_with_unalignment(N, offset, datatype):
scratchpad = torch.randn(N * N + 2 * offset, device="cuda", dtype=datatype)
inp = torch.reshape(scratchpad[offset:-offset], (N, N))
weight = torch.reshape(scratchpad[offset * 2 :], (N, N))
_ = general_gemm(A=weight, B=inp, workspace=get_workspace())
torch.cuda.synchronize()
@pytest.mark.skipif(not fp8_available, reason=reason_for_no_fp8)
@pytest.mark.parametrize("N", [32])
@pytest.mark.parametrize("datatype", [torch.float16, torch.bfloat16])
def test_sanity_fp8_gemm_with_unalignment(N, datatype):
offset = 16
scratchpad = torch.randn(N, N * N + offset, device="cuda", dtype=datatype)
scales = torch.ones(1).cuda().squeeze()
amaxes = torch.ones(1).cuda().squeeze()
dtype = tex.DType.kFloat8E4M3
fp8_quantizer = Float8Quantizer(scales, amaxes, dtype)
outp_type = datatype
scratchpad_fp8 = fp8_quantizer(scratchpad)
inp_fp8 = torch.reshape(scratchpad_fp8[0][:-offset], (N, N))
weight_fp8 = torch.reshape(scratchpad_fp8[0][offset:], (N, N))
general_gemm(
weight_fp8,
inp_fp8,
get_workspace(),
outp_type,
bias=None,
use_split_accumulator=False,
)
torch.cuda.synchronize()
@pytest.mark.skipif(not fp8_available, reason=reason_for_no_fp8)
@pytest.mark.skipif(get_device_compute_capability() < (9, 0), reason="FP8 tests require Hopper.")
@pytest.mark.skipif(get_cudnn_version() < (9, 3, 0), reason="cuDNN 9.3.0+ is required.")
@pytest.mark.parametrize("model", ["large"])
@pytest.mark.parametrize("dtype", [torch.float16, torch.bfloat16])
def test_sanity_attention_extra_state(model, dtype):
config = model_configs[model]
outputs = _run_attention_extra_state(dtype, config, checkpoint=False)
outputs_checkpoint = _run_attention_extra_state(dtype, config, checkpoint=True)
outputs_checkpoint_v1_6 = _run_attention_extra_state(
dtype, config, mimic_v1_6=True, checkpoint=True
)
# Check that results match
tols = dtype_tols(dtype)
if dtype in (torch.float16, torch.bfloat16):
tols.update(dict(rtol=2e-2, atol=2e-3))
for i, (ref, test) in enumerate(zip(outputs, outputs_checkpoint)):
torch.testing.assert_close(
test,
ref,
**tols,
)
for i, (ref, test) in enumerate(zip(outputs, outputs_checkpoint_v1_6)):
torch.testing.assert_close(
test,
ref,
**tols,
)
def _run_attention_extra_state(dtype, config, checkpoint=False, mimic_v1_6=False):
steps = 10
path = "checkpoint.pt"
fp8_enabled = True
fp8_recipe = recipe.DelayedScaling(
margin=0,
fp8_format=recipe.Format.HYBRID,
amax_history_len=1,
amax_compute_algo="most_recent",
fp8_dpa=fp8_enabled,
fp8_mha=False,
)
reset_rng_states()
hidden_states = torch.randn(
(config.seq_len, config.batch_size, config.hidden_size),
dtype=dtype,
device="cuda",
requires_grad=True,
)
def get_model(dtype, config):
sigma = 0.023
init_method = init_method_normal(sigma)
output_layer_init_method = scaled_init_method_normal(sigma, config.num_layers)
with fp8_model_init(enabled=fp8_enabled, recipe=fp8_recipe):
block = TransformerLayer(
config.hidden_size,
4 * config.hidden_size,
config.num_attention_heads,
init_method=init_method,
output_layer_init_method=output_layer_init_method,
hidden_dropout=0.0,
attention_dropout=0.0,
fuse_qkv_params=True,
params_dtype=dtype,
device="cuda",
)
return block
block = get_model(dtype, config)
for i in range(steps // 2):
with fp8_autocast(enabled=fp8_enabled, fp8_recipe=fp8_recipe):
output = block(hidden_states, None)
loss = output.sum()
loss.backward()
if checkpoint:
sd = block.state_dict()
if mimic_v1_6:
sd["self_attention.core_attention.fused_attention._extra_state"] = sd[
"self_attention.core_attention._extra_state"
]
del sd["self_attention.core_attention._extra_state"]
torch.save(sd, path)
param_grads = []
for p in block.parameters():
if p.requires_grad:
param_grads.append(p.grad.clone())
_cpu_rng_state_new = torch.get_rng_state()
_cuda_rng_state_new = torch.cuda.get_rng_state()
del block
block = get_model(dtype, config)
block.load_state_dict(torch.load(path, weights_only=False))
torch.set_rng_state(_cpu_rng_state_new)
torch.cuda.set_rng_state(_cuda_rng_state_new)
for p in block.parameters():
if p.requires_grad:
p.grad = param_grads.pop(0)
        assert not param_grads, "Not all saved parameter gradients were restored."
for i in range((steps + 1) // 2):
with fp8_autocast(enabled=fp8_enabled, fp8_recipe=fp8_recipe):
output = block(hidden_states, None)
loss = output.sum()
loss.backward()
torch.cuda.synchronize()
if os.path.exists(path):
os.remove(path)
outputs = [output, hidden_states.grad]
for p in block.parameters():
if p.requires_grad:
outputs.append(p.grad)
return outputs
# Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# See LICENSE for license information.
import transformer_engine.pytorch
print("OK")
# Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# See LICENSE for license information.
from __future__ import annotations
import torch
import transformer_engine
import transformer_engine.pytorch as te
import transformer_engine_torch as tex
def str_to_dtype(dtype: str | torch.dtype) -> torch.dtype:
"""Convert type name to PyTorch dtype"""
if isinstance(dtype, torch.dtype):
return dtype
name = str(dtype).strip().lower()
if name.startswith("torch."):
name = name.replace("torch.", "", 1)
if name.startswith("fp"):
name = name.replace("fp", "float", 1)
dtype = dict(
float32=torch.float32,
float=torch.float32,
float64=torch.float64,
double=torch.float64,
float16=torch.float16,
half=torch.float16,
bfloat16=torch.bfloat16,
bf16=torch.bfloat16,
float8_e4m3fn=torch.float8_e4m3fn,
float8_e4m3=torch.float8_e4m3fn,
float8e4m3=torch.float8_e4m3fn,
float8=torch.float8_e4m3fn,
float8_e5m2=torch.float8_e5m2,
float8e5m2=torch.float8_e5m2,
uint8=torch.uint8,
byte=torch.uint8,
int8=torch.int8,
char=torch.int8,
int16=torch.int16,
short=torch.int16,
int32=torch.int32,
int=torch.int32,
int64=torch.int64,
long=torch.int64,
bool=torch.bool,
)[name]
return dtype
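# Illustrative usage of str_to_dtype (hypothetical examples, not exercised by the
# tests in this repository): names are lower-cased and the "torch." / "fp" prefixes
# are normalized before lookup.
#
#     str_to_dtype("bf16")           # -> torch.bfloat16
#     str_to_dtype("torch.float16")  # -> torch.float16
#     str_to_dtype(torch.float32)    # -> torch.float32 (passed through unchanged)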
def dtype_tols(dtype: torch.dtype | tex.DType) -> dict[str, float]:
"""Estimated numerical error for a datatype
Based on tolerances for torch.testing.assert_close.
"""
# Transformer Engine dtypes
if isinstance(dtype, tex.DType):
dtype = {
tex.DType.kByte: torch.uint8,
tex.DType.kInt32: torch.int32,
tex.DType.kFloat32: torch.float32,
tex.DType.kFloat16: torch.half,
tex.DType.kBFloat16: torch.bfloat16,
tex.DType.kFloat8E4M3: torch.float8_e4m3fn,
tex.DType.kFloat8E5M2: torch.float8_e5m2,
}[dtype]
# PyTorch dtypes
if dtype == torch.float16:
return dict(rtol=1e-3, atol=1e-5)
if dtype == torch.bfloat16:
return dict(rtol=1.6e-2, atol=1e-5)
if dtype == torch.float32:
return dict(rtol=1.3e-6, atol=1e-5)
if dtype == torch.float64:
return dict(rtol=1e-7, atol=1e-7)
if dtype == torch.float8_e4m3fn:
return dict(rtol=0.125, atol=0.0675) # epsilon = 0.0625
if dtype == torch.float8_e5m2:
        return dict(rtol=0.25, atol=0.125)  # epsilon = 0.125
raise ValueError(f"Unsupported dtype ({dtype})")
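if __name__ == "__main__":
    # Minimal, self-contained sketch of how dtype_tols is meant to be combined with
    # torch.testing.assert_close. The tensors below are illustrative placeholders,
    # not data from the test suite.
    ref = torch.randn(64, 64, dtype=torch.float32)
    test = ref.to(torch.bfloat16).to(torch.float32)
    torch.testing.assert_close(test, ref, **dtype_tols(torch.bfloat16))
    print("dtype_tols sanity check passed")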
# Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# See LICENSE for license information.
"""Top level package"""
# pylint: disable=unused-import
from importlib import metadata
import transformer_engine.common
try:
from . import pytorch
except (ImportError, StopIteration) as e:
pass
try:
from . import jax
except (ImportError, StopIteration) as e:
pass
try:
import transformer_engine_jax
except ImportError:
pass
__version__ = str(metadata.version("transformer_engine"))
# Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# See LICENSE for license information.
cmake_minimum_required(VERSION 3.21)
# Language options
if(NOT DEFINED CMAKE_CUDA_ARCHITECTURES)
if (CUDAToolkit_VERSION VERSION_GREATER_EQUAL 12.8)
set(CMAKE_CUDA_ARCHITECTURES 70 80 89 90 100 120)
else ()
set(CMAKE_CUDA_ARCHITECTURES 70 80 89 90)
endif()
endif()
set(CMAKE_CXX_STANDARD 17)
set(CMAKE_CUDA_STANDARD 17)
set(CMAKE_CUDA_STANDARD_REQUIRED ON)
if (CMAKE_BUILD_TYPE STREQUAL "Debug")
set(CMAKE_CUDA_FLAGS_DEBUG "${CMAKE_CUDA_FLAGS_DEBUG} -g -G")
endif()
# Hide non-necessary symbols in shared object.
set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -Wl,--version-script=${CMAKE_CURRENT_SOURCE_DIR}/libtransformer_engine.version")
set(CMAKE_CUDA_FLAGS "${CMAKE_CUDA_FLAGS} -Wl,--version-script=${CMAKE_CURRENT_SOURCE_DIR}/libtransformer_engine.version")
# Transformer Engine library
project(transformer_engine LANGUAGES CUDA CXX)
# CUDA Toolkit
find_package(CUDAToolkit REQUIRED)
if (CUDAToolkit_VERSION VERSION_LESS 12.0)
message(FATAL_ERROR "CUDA 12.0+ is required, but found CUDA ${CUDAToolkit_VERSION}")
endif()
# cuDNN frontend API
set(CUDNN_FRONTEND_INCLUDE_DIR
"${CMAKE_CURRENT_SOURCE_DIR}/../../3rdparty/cudnn-frontend/include")
if(NOT EXISTS "${CUDNN_FRONTEND_INCLUDE_DIR}")
message(FATAL_ERROR
"Could not find cuDNN frontend API at ${CUDNN_FRONTEND_INCLUDE_DIR}. "
"Try running 'git submodule update --init --recursive' "
"within the Transformer Engine source.")
endif()
include(${CMAKE_CURRENT_SOURCE_DIR}/../../3rdparty/cudnn-frontend/cmake/cuDNN.cmake)
# Python
find_package(Python COMPONENTS Interpreter Development.Module REQUIRED)
# Configure Transformer Engine library
include_directories(${PROJECT_SOURCE_DIR}/..)
set(transformer_engine_SOURCES)
list(APPEND transformer_engine_SOURCES
cudnn_utils.cpp
transformer_engine.cpp
common.cu
transpose/cast_transpose.cu
transpose/transpose.cu
transpose/cast_transpose_fusion.cu
transpose/transpose_fusion.cu
transpose/multi_cast_transpose.cu
activation/gelu.cu
fused_attn/fused_attn_f16_max512_seqlen.cu
fused_attn/fused_attn_f16_arbitrary_seqlen.cu
activation/relu.cu
activation/swiglu.cu
fused_attn/fused_attn_fp8.cu
fused_attn/fused_attn.cpp
fused_attn/utils.cu
gemm/cublaslt_gemm.cu
normalization/common.cpp
normalization/layernorm/ln_api.cpp
normalization/layernorm/ln_bwd_semi_cuda_kernel.cu
normalization/layernorm/ln_fwd_cuda_kernel.cu
normalization/rmsnorm/rmsnorm_api.cpp
normalization/rmsnorm/rmsnorm_bwd_semi_cuda_kernel.cu
normalization/rmsnorm/rmsnorm_fwd_cuda_kernel.cu
permutation/permutation.cu
util/cast.cu
util/padding.cu
util/cuda_driver.cpp
util/cuda_nvml.cpp
util/cuda_runtime.cpp
util/rtc.cpp
swizzle/swizzle.cu
fused_softmax/scaled_masked_softmax.cu
fused_softmax/scaled_upper_triang_masked_softmax.cu
fused_softmax/scaled_aligned_causal_masked_softmax.cu
fused_rope/fused_rope.cu
recipe/current_scaling.cu
recipe/delayed_scaling.cu
comm_gemm_overlap/userbuffers/ipcsocket.cc
comm_gemm_overlap/userbuffers/userbuffers-host.cpp
comm_gemm_overlap/userbuffers/userbuffers.cu
comm_gemm_overlap/comm_gemm_overlap.cpp)
add_library(transformer_engine SHARED ${transformer_engine_SOURCES})
target_include_directories(transformer_engine PUBLIC
"${CMAKE_CURRENT_SOURCE_DIR}/include")
# Configure dependencies
target_link_libraries(transformer_engine PUBLIC
CUDA::cublas
CUDA::cudart)
target_include_directories(transformer_engine PRIVATE
${CMAKE_CUDA_TOOLKIT_INCLUDE_DIRECTORIES})
target_include_directories(transformer_engine PRIVATE "${CUDNN_FRONTEND_INCLUDE_DIR}")
# Compiling Userbuffers with native MPI bootstrapping requires linking against MPI
option(NVTE_UB_WITH_MPI "Bootstrap Userbuffers with MPI" OFF)
if (NVTE_UB_WITH_MPI)
find_package(MPI REQUIRED)
target_link_libraries(transformer_engine PUBLIC MPI::MPI_CXX)
target_include_directories(transformer_engine PRIVATE ${MPI_CXX_INCLUDES})
target_compile_definitions(transformer_engine PUBLIC NVTE_UB_WITH_MPI)
endif()
# Hack to enable dynamic loading in cuDNN frontend
target_compile_definitions(transformer_engine PUBLIC NV_CUDNN_FRONTEND_USE_DYNAMIC_LOADING)
# Helper functions to make header files with C++ strings
function(make_string_header STRING STRING_NAME)
configure_file(util/string_header.h.in
"string_headers/${STRING_NAME}.h"
@ONLY)
endfunction()
function(make_string_header_from_file file_ STRING_NAME)
file(READ "${file_}" STRING)
configure_file(util/string_header.h.in
"string_headers/${STRING_NAME}.h"
@ONLY)
endfunction()
# Header files with C++ strings
list(GET CMAKE_CUDA_TOOLKIT_INCLUDE_DIRECTORIES 0 cuda_include_path)
make_string_header("${cuda_include_path}"
string_path_cuda_include)
make_string_header_from_file(transpose/rtc/cast_transpose_fusion.cu
string_code_transpose_rtc_cast_transpose_fusion_cu)
make_string_header_from_file(transpose/rtc/cast_transpose.cu
string_code_transpose_rtc_cast_transpose_cu)
make_string_header_from_file(transpose/rtc/transpose.cu
string_code_transpose_rtc_transpose_cu)
make_string_header_from_file(utils.cuh
string_code_utils_cuh)
make_string_header_from_file(util/math.h
string_code_util_math_h)
target_include_directories(transformer_engine PRIVATE
"${CMAKE_CURRENT_BINARY_DIR}/string_headers")
# Compiler options
set_source_files_properties(fused_softmax/scaled_masked_softmax.cu
fused_softmax/scaled_upper_triang_masked_softmax.cu
fused_softmax/scaled_aligned_causal_masked_softmax.cu
PROPERTIES
COMPILE_OPTIONS "--use_fast_math")
option(NVTE_BUILD_ACTIVATION_WITH_FAST_MATH "Compile activation kernels with --use_fast_math option" OFF)
if (NVTE_BUILD_ACTIVATION_WITH_FAST_MATH)
set_source_files_properties(activation/gelu.cu
activation/relu.cu
activation/swiglu.cu
PROPERTIES
COMPILE_OPTIONS "--use_fast_math")
endif()
set(CMAKE_CUDA_FLAGS "${CMAKE_CUDA_FLAGS} --expt-relaxed-constexpr")
set(CMAKE_CUDA_FLAGS "${CMAKE_CUDA_FLAGS} -O3")
# Number of parallel build jobs
if(DEFINED ENV{MAX_JOBS})
set(BUILD_JOBS_STR "$ENV{MAX_JOBS}")
elseif(DEFINED ENV{NVTE_BUILD_MAX_JOBS})
set(BUILD_JOBS_STR "$ENV{NVTE_BUILD_MAX_JOBS}")
else()
set(BUILD_JOBS_STR "max")
endif()
message(STATUS "Parallel build jobs: ${BUILD_JOBS_STR}")
# Number of threads per parallel build job
set(BUILD_THREADS_PER_JOB $ENV{NVTE_BUILD_THREADS_PER_JOB})
if (NOT BUILD_THREADS_PER_JOB)
set(BUILD_THREADS_PER_JOB 1)
endif()
set(CMAKE_CUDA_FLAGS "${CMAKE_CUDA_FLAGS} --threads ${BUILD_THREADS_PER_JOB}")
message(STATUS "Threads per parallel build job: ${BUILD_THREADS_PER_JOB}")
# Install library
install(TARGETS transformer_engine DESTINATION .)
# Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# See LICENSE for license information.
"""FW agnostic user-end APIs"""
import sys
import glob
import sysconfig
import subprocess
import ctypes
import os
import platform
from pathlib import Path
import transformer_engine
def is_package_installed(package):
"""Checks if a pip package is installed."""
return (
subprocess.run(
[sys.executable, "-m", "pip", "show", package], capture_output=True, check=False
).returncode
== 0
)
def get_te_path():
"""Find Transformer Engine install path using pip"""
return Path(transformer_engine.__path__[0]).parent
def _get_sys_extension():
system = platform.system()
if system == "Linux":
extension = "so"
elif system == "Darwin":
extension = "dylib"
elif system == "Windows":
extension = "dll"
else:
raise RuntimeError(f"Unsupported operating system ({system})")
return extension
def _load_cudnn():
"""Load CUDNN shared library."""
# Attempt to locate cuDNN in Python dist-packages
lib_path = glob.glob(
os.path.join(
sysconfig.get_path("purelib"),
f"nvidia/cudnn/lib/libcudnn.{_get_sys_extension()}.*[0-9]",
)
)
if lib_path:
assert (
len(lib_path) == 1
), f"Found {len(lib_path)} libcudnn.{_get_sys_extension()}.x in nvidia-cudnn-cuXX."
return ctypes.CDLL(lib_path[0], mode=ctypes.RTLD_GLOBAL)
# Attempt to locate cuDNN in CUDNN_HOME or CUDNN_PATH, if either is set
cudnn_home = os.environ.get("CUDNN_HOME") or os.environ.get("CUDNN_PATH")
if cudnn_home:
libs = glob.glob(f"{cudnn_home}/**/libcudnn.{_get_sys_extension()}*", recursive=True)
libs.sort(reverse=True, key=os.path.basename)
if libs:
return ctypes.CDLL(libs[0], mode=ctypes.RTLD_GLOBAL)
# Attempt to locate cuDNN in CUDA_HOME, CUDA_PATH or /usr/local/cuda
cuda_home = os.environ.get("CUDA_HOME") or os.environ.get("CUDA_PATH") or "/usr/local/cuda"
libs = glob.glob(f"{cuda_home}/**/libcudnn.{_get_sys_extension()}*", recursive=True)
libs.sort(reverse=True, key=os.path.basename)
if libs:
return ctypes.CDLL(libs[0], mode=ctypes.RTLD_GLOBAL)
# If all else fails, assume that it is in LD_LIBRARY_PATH and error out otherwise
return ctypes.CDLL(f"libcudnn.{_get_sys_extension()}", mode=ctypes.RTLD_GLOBAL)
def _load_library():
"""Load shared library with Transformer Engine C extensions"""
so_path = get_te_path() / "transformer_engine" / f"libtransformer_engine.{_get_sys_extension()}"
if not so_path.exists():
so_path = (
get_te_path()
/ "transformer_engine"
/ "wheel_lib"
/ f"libtransformer_engine.{_get_sys_extension()}"
)
if not so_path.exists():
so_path = get_te_path() / f"libtransformer_engine.{_get_sys_extension()}"
assert so_path.exists(), f"Could not find libtransformer_engine.{_get_sys_extension()}"
return ctypes.CDLL(so_path, mode=ctypes.RTLD_GLOBAL)
def _load_nvrtc():
"""Load NVRTC shared library."""
# Attempt to locate NVRTC in CUDA_HOME, CUDA_PATH or /usr/local/cuda
cuda_home = os.environ.get("CUDA_HOME") or os.environ.get("CUDA_PATH") or "/usr/local/cuda"
libs = glob.glob(f"{cuda_home}/**/libnvrtc.{_get_sys_extension()}*", recursive=True)
libs = list(filter(lambda x: not ("stub" in x or "libnvrtc-builtins" in x), libs))
libs.sort(reverse=True, key=os.path.basename)
if libs:
return ctypes.CDLL(libs[0], mode=ctypes.RTLD_GLOBAL)
# Attempt to locate NVRTC via ldconfig
libs = subprocess.check_output("ldconfig -p | grep 'libnvrtc'", shell=True)
libs = libs.decode("utf-8").split("\n")
sos = []
for lib in libs:
if "stub" in lib or "libnvrtc-builtins" in lib:
continue
if "libnvrtc" in lib and "=>" in lib:
sos.append(lib.split(">")[1].strip())
if sos:
return ctypes.CDLL(sos[0], mode=ctypes.RTLD_GLOBAL)
# If all else fails, assume that it is in LD_LIBRARY_PATH and error out otherwise
return ctypes.CDLL(f"libnvrtc.{_get_sys_extension()}", mode=ctypes.RTLD_GLOBAL)
if "NVTE_PROJECT_BUILDING" not in os.environ or bool(int(os.getenv("NVTE_RELEASE_BUILD", "0"))):
_CUDNN_LIB_CTYPES = _load_cudnn()
_NVRTC_LIB_CTYPES = _load_nvrtc()
_TE_LIB_CTYPES = _load_library()
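# Note: _load_cudnn and _load_nvrtc are called before _load_library and use
# ctypes.RTLD_GLOBAL, which makes the cuDNN and NVRTC symbols visible process-wide
# before libtransformer_engine itself is loaded.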
/*************************************************************************
* Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
*
* See LICENSE for license information.
************************************************************************/
/*! \file activation_template.h
* \brief Activation functions template.
*/
#ifndef TRANSFORMER_ENGINE_ACTIVATION_TEMPLATE_H_
#define TRANSFORMER_ENGINE_ACTIVATION_TEMPLATE_H_
#include <cuda_runtime.h>
#include <transformer_engine/activation.h>
#include "../common.h"
#include "../util/cast_gated_kernels.cuh"
#include "../util/cast_kernels.cuh"
#include "../util/math.h"
#include "../util/vectorized_pointwise.h"
namespace transformer_engine {
template <typename ComputeType, typename Param, ComputeType (*OP)(ComputeType, const Param &)>
void act_fn(const NVTETensor input, NVTETensor output, cudaStream_t stream) {
using namespace detail;
constexpr bool IS_DBIAS = false;
constexpr bool IS_DACT = false;
constexpr bool IS_ACT = true;
constexpr NVTETensor dbias = nullptr;
constexpr NVTETensor workspace = nullptr;
constexpr const NVTETensor grad = nullptr;
quantize_helper<IS_DBIAS, IS_DACT, IS_ACT, Empty, OP>(input, grad, nullptr, output, dbias,
workspace, stream);
}
template <typename ComputeType, typename Param, ComputeType (*OP)(ComputeType, const Param &)>
void dact_fn(const NVTETensor grad, const NVTETensor input, NVTETensor output,
cudaStream_t stream) {
using namespace detail;
constexpr bool IS_DBIAS = false;
constexpr bool IS_DACT = true;
constexpr bool IS_ACT = false;
constexpr NVTETensor dbias = nullptr;
constexpr NVTETensor workspace = nullptr;
quantize_helper<IS_DBIAS, IS_DACT, IS_ACT, Empty, OP>(input, grad, nullptr, output, dbias,
workspace, stream);
}
template <typename ComputeType, typename Param, ComputeType (*ActOP)(ComputeType, const Param &)>
void gated_act_fn(const NVTETensor input, NVTETensor output, cudaStream_t stream) {
using namespace detail;
constexpr bool IS_DGATED = false;
constexpr NVTETensor grad = nullptr;
quantize_gated_helper<IS_DGATED, Param, ActOP, nullptr>(grad, input, output, stream);
}
template <typename ComputeType, typename Param, ComputeType (*ActOP)(ComputeType, const Param &),
ComputeType (*DActOP)(ComputeType, const Param &)>
void dgated_act_fn(const NVTETensor grad, const NVTETensor input, NVTETensor output,
cudaStream_t stream) {
using namespace detail;
constexpr bool IS_DGATED = true;
quantize_gated_helper<IS_DGATED, Param, ActOP, DActOP>(grad, input, output, stream);
}
} // namespace transformer_engine
#endif // TRANSFORMER_ENGINE_ACTIVATION_TEMPLATE_H_
/*************************************************************************
* Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
*
* See LICENSE for license information.
************************************************************************/
#include "../util/math.h"
#include "./activation_template.h"
void nvte_gelu(const NVTETensor input, NVTETensor output, cudaStream_t stream) {
NVTE_API_CALL(nvte_gelu);
using namespace transformer_engine;
act_fn<fp32, Empty, gelu<fp32, fp32>>(input, output, stream);
}
void nvte_dgelu(const NVTETensor grad, const NVTETensor input, NVTETensor output,
cudaStream_t stream) {
NVTE_API_CALL(nvte_dgelu);
using namespace transformer_engine;
dact_fn<fp32, Empty, dgelu<fp32, fp32>>(grad, input, output, stream);
}
void nvte_geglu(const NVTETensor input, NVTETensor output, cudaStream_t stream) {
NVTE_API_CALL(nvte_geglu);
using namespace transformer_engine;
gated_act_fn<fp32, Empty, gelu<fp32, fp32>>(input, output, stream);
}
void nvte_dgeglu(const NVTETensor grad, const NVTETensor input, NVTETensor output,
cudaStream_t stream) {
NVTE_API_CALL(nvte_dgeglu);
using namespace transformer_engine;
dgated_act_fn<fp32, Empty, gelu<fp32, fp32>, dgelu<fp32, fp32>>(grad, input, output, stream);
}
void nvte_qgelu(const NVTETensor input, NVTETensor output, cudaStream_t stream) {
NVTE_API_CALL(nvte_qgelu);
using namespace transformer_engine;
act_fn<fp32, Empty, qgelu<fp32, fp32>>(input, output, stream);
}
void nvte_dqgelu(const NVTETensor grad, const NVTETensor input, NVTETensor output,
cudaStream_t stream) {
NVTE_API_CALL(nvte_dqgelu);
using namespace transformer_engine;
dact_fn<fp32, Empty, dqgelu<fp32, fp32>>(grad, input, output, stream);
}
void nvte_qgeglu(const NVTETensor input, NVTETensor output, cudaStream_t stream) {
NVTE_API_CALL(nvte_qgeglu);
using namespace transformer_engine;
gated_act_fn<fp32, Empty, qgelu<fp32, fp32>>(input, output, stream);
}
void nvte_dqgeglu(const NVTETensor grad, const NVTETensor input, NVTETensor output,
cudaStream_t stream) {
NVTE_API_CALL(nvte_dqgeglu);
using namespace transformer_engine;
dgated_act_fn<fp32, Empty, qgelu<fp32, fp32>, dqgelu<fp32, fp32>>(grad, input, output, stream);
}
/*************************************************************************
* Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
*
* See LICENSE for license information.
************************************************************************/
#include "../util/math.h"
#include "./activation_template.h"
void nvte_relu(const NVTETensor input, NVTETensor output, cudaStream_t stream) {
NVTE_API_CALL(nvte_relu);
using namespace transformer_engine;
act_fn<fp32, Empty, relu<fp32, fp32>>(input, output, stream);
}
void nvte_drelu(const NVTETensor grad, const NVTETensor input, NVTETensor output,
cudaStream_t stream) {
NVTE_API_CALL(nvte_drelu);
using namespace transformer_engine;
dact_fn<fp32, Empty, drelu<fp32, fp32>>(grad, input, output, stream);
}
void nvte_reglu(const NVTETensor input, NVTETensor output, cudaStream_t stream) {
NVTE_API_CALL(nvte_reglu);
using namespace transformer_engine;
gated_act_fn<fp32, Empty, relu<fp32, fp32>>(input, output, stream);
}
void nvte_dreglu(const NVTETensor grad, const NVTETensor input, NVTETensor output,
cudaStream_t stream) {
NVTE_API_CALL(nvte_dreglu);
using namespace transformer_engine;
dgated_act_fn<fp32, Empty, relu<fp32, fp32>, drelu<fp32, fp32>>(grad, input, output, stream);
}
void nvte_srelu(const NVTETensor input, NVTETensor output, cudaStream_t stream) {
NVTE_API_CALL(nvte_srelu);
using namespace transformer_engine;
act_fn<fp32, Empty, srelu<fp32, fp32>>(input, output, stream);
}
void nvte_dsrelu(const NVTETensor grad, const NVTETensor input, NVTETensor output,
cudaStream_t stream) {
NVTE_API_CALL(nvte_dsrelu);
using namespace transformer_engine;
dact_fn<fp32, Empty, dsrelu<fp32, fp32>>(grad, input, output, stream);
}
void nvte_sreglu(const NVTETensor input, NVTETensor output, cudaStream_t stream) {
NVTE_API_CALL(nvte_sreglu);
using namespace transformer_engine;
gated_act_fn<fp32, Empty, srelu<fp32, fp32>>(input, output, stream);
}
void nvte_dsreglu(const NVTETensor grad, const NVTETensor input, NVTETensor output,
cudaStream_t stream) {
NVTE_API_CALL(nvte_dsreglu);
using namespace transformer_engine;
dgated_act_fn<fp32, Empty, srelu<fp32, fp32>, dsrelu<fp32, fp32>>(grad, input, output, stream);
}