Commit 063ef88d authored by wenjh

Merge nv main up to v2.10.0.dev0


Signed-off-by: wenjh <wenjh@sugon.com>
parents 91670b05 5624dbb4
# Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# See LICENSE for license information.
import pytest
import torch
import transformer_engine.pytorch as te
import transformer_engine_torch as tex
from transformer_engine.pytorch import NVFP4Quantizer
from transformer_engine.pytorch.experimental.quantization_nvfp4 import NVFP4QuantizerRef
from transformer_engine.common.recipe import NVFP4BlockScaling
from transformer_engine.pytorch.constants import TE_DType
from transformer_engine.pytorch.experimental import utils
recipe_available, reason_for_no_recipe = te.is_nvfp4_available(return_reason=True)
def unpack_fp4(x: torch.Tensor) -> torch.Tensor:
repeated = x.repeat_interleave(2, dim=1)
repeated[:, 0::2] &= 0x0F
repeated[:, 1::2] >>= 4
return repeated
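# Illustrative check of the nibble order assumed by unpack_fp4 above: each byte
# stores two E2M1 codes, with the low nibble holding the even-indexed element
# and the high nibble the odd-indexed one. Pure CPU sketch, no GPU needed.
def test_unpack_fp4_packing_convention():
    packed = torch.tensor([[0x21, 0x43]], dtype=torch.uint8)
    expected = torch.tensor([[0x1, 0x2, 0x3, 0x4]], dtype=torch.uint8)
    assert torch.equal(unpack_fp4(packed), expected)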
def check_quantization_nvfp4_versus_reference(
x_dtype: torch.dtype,
M: int,
N: int,
return_transpose: bool,
swizzled_scale: bool,
use_cpp_allocator: bool,
with_2d_quantization: bool,
) -> None:
te_dtype = tex.DType.kFloat4E2M1
# Setup device and random seed
device = "cuda"
seed = 0
torch.manual_seed(seed)
torch.cuda.manual_seed(seed)
# Input
x = torch.randn((M, N), dtype=x_dtype, device=device)
# Quantize
nvfp4_quantizer = NVFP4Quantizer(
fp4_dtype=te_dtype,
rowwise=True,
columnwise=return_transpose,
with_amax_reduction=False,
amax_reduction_group=None,
with_rht=False,
with_post_rht_amax=False,
with_2d_quantization=with_2d_quantization,
)
if use_cpp_allocator:
x_nvfp4_sut = nvfp4_quantizer(x)
else:
x_nvfp4_sut = nvfp4_quantizer.make_empty(
(M, N), dtype=x_dtype, device=device, requires_grad=False
)
x_nvfp4_sut = nvfp4_quantizer.update_quantized(x, x_nvfp4_sut)
# Extract data from NVFP4Tensor
assert x_nvfp4_sut._rowwise_data is not None
qx: torch.Tensor = x_nvfp4_sut._rowwise_data.view(dtype=torch.uint8)
assert x_nvfp4_sut._rowwise_scale_inv is not None
sx: torch.Tensor = x_nvfp4_sut._rowwise_scale_inv
qx_t = (
x_nvfp4_sut._columnwise_data.view(dtype=torch.uint8)
if x_nvfp4_sut._columnwise_data is not None
else None
)
sx_t = x_nvfp4_sut._columnwise_scale_inv
qx_amax = x_nvfp4_sut._amax_rowwise
# Reference quantization
quant_tile_shape = (1, 16) if not with_2d_quantization else (16, 16)
ref_quantizer = NVFP4QuantizerRef(
dtype=utils.Fp4Formats.E2M1,
rowwise=True,
columnwise=return_transpose,
pow_2_scales=False,
eps=0.0,
quant_tile_shape=quant_tile_shape,
)
x_nvfp4_ref = ref_quantizer.quantize(x)
# Extract data from RefNVFP4Tensor
qx_ref = (
unpack_fp4(x_nvfp4_ref.data.view(dtype=torch.uint8))
if x_nvfp4_ref.data is not None
else None
)
sx_ref = x_nvfp4_ref.scale.view(dtype=torch.uint8) if x_nvfp4_ref.scale is not None else None
qx_t_ref = (
unpack_fp4(x_nvfp4_ref.data_t.view(dtype=torch.uint8))
if x_nvfp4_ref.data_t is not None
else None
)
sx_t_ref = (
x_nvfp4_ref.scale_t.view(dtype=torch.uint8) if x_nvfp4_ref.scale_t is not None else None
)
ref_amax = x_nvfp4_ref.global_amax_row
qx = unpack_fp4(qx)
qx_t = unpack_fp4(qx_t) if qx_t is not None else None
torch.testing.assert_close(qx, qx_ref, atol=0.0, rtol=0.0)
# Compare only the valid portion of scale tensors (reference may not have padding)
ref_sx_shape = sx_ref.shape
sx_valid = sx[: ref_sx_shape[0], : ref_sx_shape[1]]
torch.testing.assert_close(sx_valid, sx_ref, atol=0.0, rtol=0.0)
if return_transpose:
torch.testing.assert_close(qx_t, qx_t_ref, atol=0.0, rtol=0.0)
# Compare only the valid portion of transpose scale tensors
ref_sx_t_shape = sx_t_ref.shape
sx_t_valid = sx_t[: ref_sx_t_shape[0], : ref_sx_t_shape[1]]
torch.testing.assert_close(sx_t_valid, sx_t_ref, atol=0.0, rtol=0.0)
torch.testing.assert_close(qx_amax, ref_amax, atol=0.0, rtol=0.0)
@pytest.mark.skipif(not recipe_available, reason=reason_for_no_recipe)
@pytest.mark.parametrize(
"M, N",
[
# full tile cases
(128, 128),
(256, 256),
(256, 1024),
(1024, 256),
# Padding required cases
(256, 272),
(304, 304),
(320, 256),
# Some larger tiles
(2048, 2048),
(1024, 2048),
(2048, 1024),
# largest tile
(8192, 8192),
],
)
@pytest.mark.parametrize("x_dtype", [torch.float32, torch.bfloat16], ids=str)
@pytest.mark.parametrize(
"return_transpose", [True, False], ids=["quantize_transpose", "skip_transpose"]
)
@pytest.mark.parametrize("swizzled_scale", [False], ids=["linear_scale"])
@pytest.mark.parametrize(
"use_cpp_allocator", [True, False], ids=["cpp_allocator", "python_allocator"]
)
@pytest.mark.parametrize(
"with_2d_quantization", [True, False], ids=["2d_quantization", "1d_quantization"]
)
def test_quantization_block_tiling_versus_reference(
x_dtype: torch.dtype,
M: int,
N: int,
return_transpose: bool,
swizzled_scale: bool,
use_cpp_allocator: bool,
with_2d_quantization: bool,
) -> None:
check_quantization_nvfp4_versus_reference(
x_dtype=x_dtype,
M=M,
N=N,
return_transpose=return_transpose,
swizzled_scale=swizzled_scale,
use_cpp_allocator=use_cpp_allocator,
with_2d_quantization=with_2d_quantization,
)
@pytest.mark.skipif(not recipe_available, reason=reason_for_no_recipe)
@pytest.mark.parametrize(
"M, N",
[
(128, 128),
],
)
@pytest.mark.parametrize("x_dtype", [torch.float32, torch.bfloat16], ids=str)
@pytest.mark.parametrize("extrema_high", [False, True], ids=["zeros", "maxes"])
@pytest.mark.parametrize(
"return_transpose", [True, False], ids=["quantize_transpose", "skip_transpose"]
)
@pytest.mark.parametrize(
"use_cpp_allocator", [True, False], ids=["cpp_allocator", "python_allocator"]
)
def test_nvfp4_quantization_extrema_versus_reference(
x_dtype: torch.dtype,
M: int,
N: int,
extrema_high: bool,
return_transpose: bool,
use_cpp_allocator: bool,
):
te_dtype = tex.DType.kFloat4E2M1
device = "cuda"
seed = 0
torch.manual_seed(seed)
torch.cuda.manual_seed(seed)
if extrema_high:
x = torch.full((M, N), torch.finfo(x_dtype).max, dtype=x_dtype, device=device)
else:
x = torch.zeros((M, N), dtype=x_dtype, device=device)
nvfp4_quantizer = NVFP4Quantizer(
fp4_dtype=te_dtype,
rowwise=True,
columnwise=return_transpose,
with_amax_reduction=False,
amax_reduction_group=None,
with_rht=False,
with_post_rht_amax=False,
)
if use_cpp_allocator:
x_nvfp4_sut = nvfp4_quantizer(x)
else:
x_nvfp4_sut = nvfp4_quantizer.make_empty(
(M, N), dtype=x_dtype, device=device, requires_grad=False
)
x_nvfp4_sut = nvfp4_quantizer.update_quantized(x, x_nvfp4_sut)
assert x_nvfp4_sut._rowwise_data is not None
qx = x_nvfp4_sut._rowwise_data.view(dtype=torch.uint8)
assert x_nvfp4_sut._rowwise_scale_inv is not None
sx = x_nvfp4_sut._rowwise_scale_inv
qx_t = (
x_nvfp4_sut._columnwise_data.view(dtype=torch.uint8)
if x_nvfp4_sut._columnwise_data is not None
else None
)
sx_t = x_nvfp4_sut._columnwise_scale_inv
qx_amax = x_nvfp4_sut._amax_rowwise
ref_quantizer = NVFP4QuantizerRef(
dtype=utils.Fp4Formats.E2M1,
rowwise=True,
columnwise=return_transpose,
pow_2_scales=False,
eps=0.0,
quant_tile_shape=(1, 16),
)
x_nvfp4_ref = ref_quantizer.quantize(x)
qx_ref = x_nvfp4_ref.data.view(dtype=torch.uint8) if x_nvfp4_ref.data is not None else None
sx_ref = x_nvfp4_ref.scale.view(dtype=torch.uint8) if x_nvfp4_ref.scale is not None else None
qx_t_ref = (
x_nvfp4_ref.data_t.view(dtype=torch.uint8) if x_nvfp4_ref.data_t is not None else None
)
sx_t_ref = (
x_nvfp4_ref.scale_t.view(dtype=torch.uint8) if x_nvfp4_ref.scale_t is not None else None
)
ref_amax = x_nvfp4_ref.global_amax_row
torch.testing.assert_close(qx, qx_ref, atol=0.0, rtol=0.0)
ref_sx_shape = sx_ref.shape
sx_valid = sx[: ref_sx_shape[0], : ref_sx_shape[1]]
torch.testing.assert_close(sx_valid, sx_ref, atol=0.0, rtol=0.0)
if return_transpose:
torch.testing.assert_close(qx_t, qx_t_ref, atol=0.0, rtol=0.0)
ref_sx_t_shape = sx_t_ref.shape
sx_t_valid = sx_t[: ref_sx_t_shape[0], : ref_sx_t_shape[1]]
torch.testing.assert_close(sx_t_valid, sx_t_ref, atol=0.0, rtol=0.0)
torch.testing.assert_close(qx_amax, ref_amax, atol=0.0, rtol=0.0)
@pytest.mark.skipif(not recipe_available, reason=reason_for_no_recipe)
@pytest.mark.parametrize(
"M, N",
[
(16, 128),
(32, 128),
],
)
@pytest.mark.parametrize("x_dtype", [torch.float32, torch.bfloat16], ids=str)
@pytest.mark.parametrize(
"return_transpose", [True, False], ids=["quantize_transpose", "skip_transpose"]
)
@pytest.mark.parametrize(
"use_cpp_allocator", [True, False], ids=["cpp_allocator", "python_allocator"]
)
def test_nvfp4_quantization_boundary_values(
x_dtype: torch.dtype,
M: int,
N: int,
return_transpose: bool,
use_cpp_allocator: bool,
):
"""
Stress rounding/threshold behavior by placing values just below/above
many potential bin edges within each 16-element microblock.
Validates native vs reference byte-for-byte and scale parity.
"""
te_dtype = tex.DType.kFloat4E2M1
device = "cuda"
seed = 123
torch.manual_seed(seed)
torch.cuda.manual_seed(seed)
# Construct a single row with paired boundary values: v-eps, v+eps
# spanning a wide dynamic range to exercise clipping and multiple bins.
# N must be even and a multiple of 16 for the 16-element microblocks; both hold for N = 128.
base = torch.linspace(-12.0, 12.0, steps=N // 2, dtype=torch.float32, device=device)
eps = torch.full_like(base, 1e-3)
# Avoid zero eps for very small magnitudes
eps = torch.maximum(eps, 1e-4 * torch.ones_like(base))
lower = base - eps
upper = base + eps
row = torch.empty(N, dtype=torch.float32, device=device)
row[0::2] = lower
row[1::2] = upper
x = row.unsqueeze(0).repeat(M, 1).to(dtype=x_dtype)
nvfp4_quantizer = NVFP4Quantizer(
fp4_dtype=te_dtype,
rowwise=True,
columnwise=return_transpose,
with_amax_reduction=False,
amax_reduction_group=None,
with_rht=False,
with_post_rht_amax=False,
)
if use_cpp_allocator:
x_nvfp4_sut = nvfp4_quantizer(x)
else:
x_nvfp4_sut = nvfp4_quantizer.make_empty(
(M, N), dtype=x_dtype, device=device, requires_grad=False
)
x_nvfp4_sut = nvfp4_quantizer.update_quantized(x, x_nvfp4_sut)
assert x_nvfp4_sut._rowwise_data is not None
qx = x_nvfp4_sut._rowwise_data.view(dtype=torch.uint8)
assert x_nvfp4_sut._rowwise_scale_inv is not None
sx = x_nvfp4_sut._rowwise_scale_inv
qx_t = (
x_nvfp4_sut._columnwise_data.view(dtype=torch.uint8)
if x_nvfp4_sut._columnwise_data is not None
else None
)
sx_t = x_nvfp4_sut._columnwise_scale_inv
qx_amax = x_nvfp4_sut._amax_rowwise
ref_quantizer = NVFP4QuantizerRef(
dtype=utils.Fp4Formats.E2M1,
rowwise=True,
columnwise=return_transpose,
pow_2_scales=False,
eps=0.0,
quant_tile_shape=(1, 16),
)
x_nvfp4_ref = ref_quantizer.quantize(x)
qx_ref = x_nvfp4_ref.data.view(dtype=torch.uint8) if x_nvfp4_ref.data is not None else None
sx_ref = x_nvfp4_ref.scale.view(dtype=torch.uint8) if x_nvfp4_ref.scale is not None else None
qx_t_ref = (
x_nvfp4_ref.data_t.view(dtype=torch.uint8) if x_nvfp4_ref.data_t is not None else None
)
sx_t_ref = (
x_nvfp4_ref.scale_t.view(dtype=torch.uint8) if x_nvfp4_ref.scale_t is not None else None
)
ref_amax = x_nvfp4_ref.global_amax_row
torch.testing.assert_close(qx, qx_ref, atol=0.0, rtol=0.0)
# Compare only valid portion of scales (trim any padding)
ref_sx_shape = sx_ref.shape
sx_valid = sx[: ref_sx_shape[0], : ref_sx_shape[1]]
torch.testing.assert_close(sx_valid, sx_ref, atol=0.0, rtol=0.0)
if return_transpose:
torch.testing.assert_close(qx_t, qx_t_ref, atol=0.0, rtol=0.0)
ref_sx_t_shape = sx_t_ref.shape
sx_t_valid = sx_t[: ref_sx_t_shape[0], : ref_sx_t_shape[1]]
torch.testing.assert_close(sx_t_valid, sx_t_ref, atol=0.0, rtol=0.0)
torch.testing.assert_close(qx_amax, ref_amax, atol=0.0, rtol=0.0)
@pytest.mark.skipif(not recipe_available, reason=reason_for_no_recipe)
@pytest.mark.parametrize(
"M, N",
[
(32, 128),
],
)
@pytest.mark.parametrize("x_dtype", [torch.float32, torch.bfloat16], ids=str)
@pytest.mark.parametrize(
"return_transpose", [True, False], ids=["quantize_transpose", "skip_transpose"]
)
@pytest.mark.parametrize(
"use_cpp_allocator", [True, False], ids=["cpp_allocator", "python_allocator"]
)
def test_nvfp4_quantization_noncontiguous_inputs(
x_dtype: torch.dtype,
M: int,
N: int,
return_transpose: bool,
use_cpp_allocator: bool,
):
te_dtype = tex.DType.kFloat4E2M1
device = "cuda"
seed = 17
torch.manual_seed(seed)
torch.cuda.manual_seed(seed)
# Start from a contiguous tensor, then make a non-contiguous view by transpose
x_base = torch.randn((M, N), dtype=x_dtype, device=device)
x_nc = x_base.t() # shape (N, M), non-contiguous
assert not x_nc.is_contiguous()
nvfp4_quantizer = NVFP4Quantizer(
fp4_dtype=te_dtype,
rowwise=True,
columnwise=return_transpose,
with_amax_reduction=False,
amax_reduction_group=None,
with_rht=False,
with_post_rht_amax=False,
)
if use_cpp_allocator:
x_nvfp4_sut = nvfp4_quantizer(x_nc)
else:
x_nvfp4_sut = nvfp4_quantizer.make_empty(
x_nc.shape, dtype=x_dtype, device=device, requires_grad=False
)
x_nvfp4_sut = nvfp4_quantizer.update_quantized(x_nc, x_nvfp4_sut)
assert x_nvfp4_sut._rowwise_data is not None
qx = x_nvfp4_sut._rowwise_data.view(dtype=torch.uint8)
assert x_nvfp4_sut._rowwise_scale_inv is not None
sx = x_nvfp4_sut._rowwise_scale_inv
qx_t = (
x_nvfp4_sut._columnwise_data.view(dtype=torch.uint8)
if x_nvfp4_sut._columnwise_data is not None
else None
)
sx_t = x_nvfp4_sut._columnwise_scale_inv
qx_amax = x_nvfp4_sut._amax_rowwise
ref_quantizer = NVFP4QuantizerRef(
dtype=utils.Fp4Formats.E2M1,
rowwise=True,
columnwise=return_transpose,
pow_2_scales=False,
eps=0.0,
quant_tile_shape=(1, 16),
)
x_nvfp4_ref = ref_quantizer.quantize(x_nc)
qx_ref = x_nvfp4_ref.data.view(dtype=torch.uint8) if x_nvfp4_ref.data is not None else None
sx_ref = x_nvfp4_ref.scale.view(dtype=torch.uint8) if x_nvfp4_ref.scale is not None else None
qx_t_ref = (
x_nvfp4_ref.data_t.view(dtype=torch.uint8) if x_nvfp4_ref.data_t is not None else None
)
sx_t_ref = (
x_nvfp4_ref.scale_t.view(dtype=torch.uint8) if x_nvfp4_ref.scale_t is not None else None
)
ref_amax = x_nvfp4_ref.global_amax_row
# Quantized must match
torch.testing.assert_close(qx, qx_ref, atol=0.0, rtol=0.0)
# Compare only valid portion of scales (trim padding)
ref_sx_shape = sx_ref.shape
sx_valid = sx[: ref_sx_shape[0], : ref_sx_shape[1]]
torch.testing.assert_close(sx_valid, sx_ref, atol=0.0, rtol=0.0)
if return_transpose:
torch.testing.assert_close(qx_t, qx_t_ref, atol=0.0, rtol=0.0)
ref_sx_t_shape = sx_t_ref.shape
sx_t_valid = sx_t[: ref_sx_t_shape[0], : ref_sx_t_shape[1]]
torch.testing.assert_close(sx_t_valid, sx_t_ref, atol=0.0, rtol=0.0)
torch.testing.assert_close(qx_amax, ref_amax, atol=0.0, rtol=0.0)
# Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# See LICENSE for license information.
# NOTE: This file depends on test_nvfp4_quantize_exact.py passing first.
# It is kept separate so that the base quantization functionality is verified on its own;
# otherwise the reference implementation here would get messy.
# Due to the structure of NVFP4Quantizer, the RHT functionality has to be tested
# together with the quantization functionality.
import transformer_engine.pytorch as te
import transformer_engine_torch as tex
from transformer_engine.pytorch import NVFP4Quantizer
from transformer_engine.common.recipe import NVFP4BlockScaling
from transformer_engine.pytorch.constants import TE_DType
from transformer_engine.pytorch.experimental.quantization_nvfp4 import NVFP4QuantizerRef
from transformer_engine.pytorch.experimental import utils
import pytest
import torch
recipe_available, reason_for_no_recipe = te.is_nvfp4_available(return_reason=True)
def unpack_fp4(x: torch.Tensor) -> torch.Tensor:
repeated = x.repeat_interleave(2, dim=1)
repeated[:, 0::2] &= 0x0F
repeated[:, 1::2] >>= 4
return repeated
def check_quantization_nvfp4_versus_reference(
x_dtype: torch.dtype,
M: int,
N: int,
contiguous: bool,
return_transpose: bool,
use_cpp_allocator: bool,
swizzled_scale: bool = False,
hadamard_dimension: int = 16,
with_rht: bool = True,
with_post_rht_amax: bool = True,
with_random_sign_mask: bool = True,
) -> None:
assert with_rht and with_post_rht_amax, "RHT and post-RHT amax reduction must be enabled."
te_dtype = tex.DType.kFloat4E2M1
# Setup device and random seed
device = "cuda"
seed = 0
torch.manual_seed(seed)
torch.cuda.manual_seed(seed)
# Input
x = torch.randn((M, N), dtype=x_dtype, device=device)
x = x.transpose(0, 1) if not contiguous else x
# Quantize
nvfp4_quantizer = NVFP4Quantizer(
fp4_dtype=te_dtype,
rowwise=True,
columnwise=return_transpose,
with_amax_reduction=False,
amax_reduction_group=None,
with_rht=with_rht,
with_post_rht_amax=with_post_rht_amax,
with_random_sign_mask=with_random_sign_mask,
)
if use_cpp_allocator:
x_nvfp4_sut = nvfp4_quantizer(x)
else:
x_nvfp4_sut = nvfp4_quantizer.make_empty(
x.shape, dtype=x_dtype, device=device, requires_grad=False
)
x_nvfp4_sut = nvfp4_quantizer.update_quantized(x, x_nvfp4_sut)
# Extract data from NVFP4Tensor
assert x_nvfp4_sut._rowwise_data is not None
qx: torch.Tensor = x_nvfp4_sut._rowwise_data.view(dtype=torch.uint8)
assert x_nvfp4_sut._rowwise_scale_inv is not None
sx: torch.Tensor = x_nvfp4_sut._rowwise_scale_inv
qx_t = (
x_nvfp4_sut._columnwise_data.view(dtype=torch.uint8)
if x_nvfp4_sut._columnwise_data is not None
else None
)
sx_t = x_nvfp4_sut._columnwise_scale_inv
amax_rowwise = x_nvfp4_sut._amax_rowwise
amax_colwise = x_nvfp4_sut._amax_columnwise
qx = unpack_fp4(qx)
qx_t = unpack_fp4(qx_t) if qx_t is not None else None
# Reference quantization using NVFP4QuantizerRef with built-in RHT
ref_quantizer = NVFP4QuantizerRef(
dtype=utils.Fp4Formats.E2M1,
rowwise=True,
columnwise=return_transpose,
pow_2_scales=False,
eps=0.0,
quant_tile_shape=(1, 16),
with_rht=with_rht,
with_random_sign_mask=with_random_sign_mask,
)
x_nvfp4_ref = ref_quantizer.quantize(x)
# Extract data from RefNVFP4Tensor
qx_ref = (
unpack_fp4(x_nvfp4_ref.data.view(dtype=torch.uint8))
if x_nvfp4_ref.data is not None
else None
)
sx_ref = x_nvfp4_ref.scale.view(dtype=torch.uint8) if x_nvfp4_ref.scale is not None else None
ref_amax_rowwise = x_nvfp4_ref.global_amax_row
if return_transpose:
assert x_nvfp4_ref.data_t is not None
assert x_nvfp4_ref.scale_t is not None
qx_t_ref = unpack_fp4(x_nvfp4_ref.data_t.view(dtype=torch.uint8))
sx_t_ref = x_nvfp4_ref.scale_t.view(dtype=torch.uint8)
# Compute transpose amax using the same reference quantizer
x_t_for_amax = (
ref_quantizer._apply_rht(x.t().contiguous()) if with_rht else x.t().contiguous()
)
ref_amax_colwise_t = torch.max(torch.abs(x_t_for_amax)).to(torch.float32).view(1)
else:
qx_t_ref = None
sx_t_ref = None
ref_amax_colwise_t = None
torch.testing.assert_close(amax_rowwise, ref_amax_rowwise, atol=0.0, rtol=0.0)
torch.testing.assert_close(qx, qx_ref, atol=0.0, rtol=0.0)
# Compare only the valid portion of scale tensors (reference may not have padding)
ref_sx_shape = sx_ref.shape
sx_valid = sx[: ref_sx_shape[0], : ref_sx_shape[1]]
torch.testing.assert_close(sx_valid, sx_ref, atol=0.0, rtol=0.0)
if return_transpose:
torch.testing.assert_close(amax_colwise, ref_amax_colwise_t, atol=0.0, rtol=0.0)
torch.testing.assert_close(qx_t, qx_t_ref, atol=0.0, rtol=0.0)
# Compare only the valid portion of transpose scale tensors
ref_sx_t_shape = sx_t_ref.shape
sx_t_valid = sx_t[: ref_sx_t_shape[0], : ref_sx_t_shape[1]]
torch.testing.assert_close(sx_t_valid, sx_t_ref, atol=0.0, rtol=0.0)
@pytest.mark.skipif(not recipe_available, reason=reason_for_no_recipe)
@pytest.mark.parametrize(
"M, N",
[
# full tile cases
(128, 128),
(256, 256),
(256, 1024),
(1024, 256),
# Padding required cases
(256, 272),
(304, 304),
(320, 256),
# Some larger tiles
(2048, 2048),
(1024, 2048),
(2048, 1024),
# Real shapes,
(8192, 5120),
(8192, 10240),
(8192, 2560),
(8192, 11328),
(8192, 512),
(8192, 3584),
(5120, 8192),
(10240, 8192),
(2560, 8192),
(11328, 8192),
(512, 8192),
(3584, 8192),
(4096, 16384),
(14336, 16384),
],
)
@pytest.mark.parametrize("x_dtype", [torch.bfloat16], ids=str)
@pytest.mark.parametrize(
"return_transpose", [True, False], ids=["quantize_transpose", "skip_transpose"]
)
@pytest.mark.parametrize(
"use_cpp_allocator", [True, False], ids=["cpp_allocator", "python_allocator"]
)
@pytest.mark.parametrize(
"with_random_sign_mask", [True, False], ids=["with_random_sign_mask", "no_random_sign_mask"]
)
def test_rht_with_quantization_block_tiling_versus_reference(
x_dtype: torch.dtype,
M: int,
N: int,
return_transpose: bool,
use_cpp_allocator: bool,
with_random_sign_mask: bool,
) -> None:
check_quantization_nvfp4_versus_reference(
x_dtype=x_dtype,
M=M,
N=N,
contiguous=True,
return_transpose=return_transpose,
use_cpp_allocator=use_cpp_allocator,
with_random_sign_mask=with_random_sign_mask,
)
@pytest.mark.skipif(not recipe_available, reason=reason_for_no_recipe)
@pytest.mark.parametrize(
"M, N",
[
(32, 128),
],
)
@pytest.mark.parametrize("x_dtype", [torch.bfloat16], ids=str)
@pytest.mark.parametrize(
"return_transpose", [True, False], ids=["quantize_transpose", "skip_transpose"]
)
@pytest.mark.parametrize(
"use_cpp_allocator", [True, False], ids=["cpp_allocator", "python_allocator"]
)
@pytest.mark.parametrize(
"with_random_sign_mask", [True, False], ids=["with_random_sign_mask", "no_random_sign_mask"]
)
def test_nvfp4_quantization_noncontiguous_inputs(
x_dtype: torch.dtype,
M: int,
N: int,
return_transpose: bool,
use_cpp_allocator: bool,
with_random_sign_mask: bool,
):
check_quantization_nvfp4_versus_reference(
x_dtype=x_dtype,
M=M,
N=N,
contiguous=False,
return_transpose=return_transpose,
use_cpp_allocator=use_cpp_allocator,
with_random_sign_mask=with_random_sign_mask,
)
# Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# See LICENSE for license information.
import pytest
import torch
import transformer_engine.pytorch as te
from transformer_engine.pytorch import NVFP4Quantizer
recipe_available, reason_for_no_recipe = te.is_nvfp4_available(return_reason=True)
seed = 12345
torch.manual_seed(seed)
torch.cuda.manual_seed(seed)
def unpack_fp4(x: torch.Tensor) -> torch.Tensor:
repeated = x.repeat_interleave(2, dim=1)
repeated[:, 0::2] &= 0x0F
repeated[:, 1::2] >>= 4
return repeated
_FP4_LUT = torch.tensor(
[
0.0, # 0: 0000 - zero
0.5, # 1: 0001 - smallest positive normal
1.0, # 2: 0010
1.5, # 3: 0011
2.0, # 4: 0100
3.0, # 5: 0101
4.0, # 6: 0110
6.0, # 7: 0111 - largest positive normal
-0.0, # 8: 1000 - negative zero
-0.5, # 9: 1001 - smallest negative normal
-1.0, # 10: 1010
-1.5, # 11: 1011
-2.0, # 12: 1100
-3.0, # 13: 1101
-4.0, # 14: 1110
-6.0, # 15: 1111 - largest negative normal
],
dtype=torch.float32,
)
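# The table above follows the E2M1 layout: 1 sign bit, 2 exponent bits, 1 mantissa
# bit. For a code with exponent e > 0 the value is (-1)**s * 2**(e - 1) * (1 + m / 2);
# for e == 0 it is the subnormal (-1)**s * 0.5 * m.
# Example: code 0b0101 -> s=0, e=2, m=1 -> 2 * 1.5 = 3.0, matching entry 5.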
def fp4_to_fp32(fp4: torch.Tensor) -> torch.Tensor:
# Convert FP4 indices to their corresponding floating point values
# Each index (0-15) represents a 4-bit FP4 value in E2M1 format
# Values based on the FP4 E2M1 specification
fp4_lut = _FP4_LUT.to(fp4.device)
return fp4_lut[fp4.to(torch.long)]
def dequantize_fp4(qx: torch.Tensor, sx: torch.Tensor, amax: torch.Tensor) -> torch.Tensor:
sf = sx.repeat_interleave(16, dim=1).view(torch.float8_e4m3fn).to(torch.float32)
dqx = fp4_to_fp32(unpack_fp4(qx))
sf = sf[: dqx.shape[0], : dqx.shape[1]]
dequant = dqx * sf * (amax / (6.0 * 448))
return dequant
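# Dequantization sketch: NVFP4 uses two-level scaling. Each 16-element block carries
# an FP8 (E4M3) decode scale, and the whole tensor carries one FP32 factor derived
# from its amax; 6.0 is the largest E2M1 magnitude and 448 the largest E4M3 magnitude,
# hence the global factor amax / (6.0 * 448) above.
# Worked example (assuming this encode convention): with amax = 6.0 * 448 = 2688, an
# element equal to amax is stored as FP4 value 6.0 with block scale 448 and decodes
# back to 6.0 * 448 * 1.0 = 2688.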
def RHT(x: torch.Tensor) -> torch.Tensor:
def get_wgrad_sign_vector() -> torch.Tensor:
"""Hard-coded signs for Hadamard transform"""
return torch.tensor(
[
1.0,
1.0,
1.0,
-1.0,
1.0,
-1.0,
-1.0,
-1.0,
-1.0,
-1.0,
-1.0,
1.0,
-1.0,
1.0,
-1.0,
-1.0,
],
dtype=torch.float32,
)
def _build_hadamard_matrix(
size: int, device: torch.device, dtype: torch.dtype, with_random_sign_mask: bool = True
) -> torch.Tensor:
"""Construct a Hadamard matrix of given power-of-two size with entries +-1.
Uses Sylvester construction to avoid SciPy dependency.
"""
assert (size & (size - 1)) == 0, "Hadamard size must be a power of two"
h = torch.ones((1, 1), device=device, dtype=torch.float32)
while h.shape[0] < size:
h = torch.cat(
[
torch.cat([h, h], dim=1),
torch.cat([h, -h], dim=1),
],
dim=0,
)
if with_random_sign_mask:
sign_mat = get_wgrad_sign_vector().to(device) * torch.eye(
size, device=device, dtype=torch.float32
)
h = sign_mat @ h
return h.to(dtype)
rht_dim = 16
# Build H and scale
H = _build_hadamard_matrix(rht_dim, x.device, x.dtype)
scale = 1.0 / float(rht_dim) ** 0.5
# Perform blockwise transform along the last dimension
original_shape = x.shape
x_mat = x.contiguous().view(-1, rht_dim)
# Any sign flipping is already folded into H by _build_hadamard_matrix above
transform = H * scale
out = x_mat @ transform
return out.view(original_shape)
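# The scaled Hadamard transform is orthonormal (H @ H.T == 16 * I for the Sylvester
# construction, with or without the diagonal sign mask), so RHT preserves the L2 norm
# of every 16-element block while spreading outliers across the block. A small
# CPU-only sanity check of that property, using the RHT reference above:
def test_rht_reference_is_orthonormal():
    x = torch.randn(4, 64, dtype=torch.float64)
    y = RHT(x)
    torch.testing.assert_close(
        x.reshape(-1, 16).norm(dim=1), y.reshape(-1, 16).norm(dim=1)
    )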
def quantize_fp4(
x: torch.Tensor, use_stochastic_rounding: bool, use_2D: bool, use_RHT: bool
) -> torch.Tensor:
nvfp4_quantizer = NVFP4Quantizer(
rowwise=True,
columnwise=True,
with_amax_reduction=False,
amax_reduction_group=None,
with_rht=use_RHT,
with_post_rht_amax=True,
stochastic_rounding=use_stochastic_rounding,
with_2d_quantization=use_2D,
)
x_nvfp4_sut = nvfp4_quantizer(x)
# Extract data from NVFP4Tensor
assert x_nvfp4_sut._rowwise_data is not None
qx: torch.Tensor = x_nvfp4_sut._rowwise_data.view(dtype=torch.uint8)
assert x_nvfp4_sut._rowwise_scale_inv is not None
sx: torch.Tensor = x_nvfp4_sut._rowwise_scale_inv
assert x_nvfp4_sut._columnwise_data is not None
qx_t: torch.Tensor = x_nvfp4_sut._columnwise_data.view(dtype=torch.uint8)
assert x_nvfp4_sut._columnwise_scale_inv is not None
sx_t: torch.Tensor = x_nvfp4_sut._columnwise_scale_inv
return qx, sx, qx_t, sx_t
def check_quantization_nvfp4_versus_reference(
x_dtype: torch.dtype, M: int, N: int, use_2D: bool, use_RHT: bool
) -> None:
device = "cuda"
torch.manual_seed(seed)
n_iters = 50
x = torch.randn((M, N), dtype=x_dtype, device=device) * 2 - 1
y = x.t().contiguous()
if use_RHT:
y = RHT(y)
amax = torch.max(torch.abs(x)).float()
q_rn, s_rn, q_t_rn, s_t_rn = quantize_fp4(
x, use_stochastic_rounding=False, use_2D=use_2D, use_RHT=use_RHT
)
dq_rn = dequantize_fp4(q_rn, s_rn, amax)
dq_t_rn = dequantize_fp4(q_t_rn, s_t_rn, amax)
error_rn = (dq_rn - x).float()
me_rn = torch.sqrt((error_rn * error_rn).mean())
error_t_rn = (dq_t_rn - y).float()
me_t_rn = torch.sqrt((error_t_rn * error_t_rn).mean())
sr_result = torch.zeros_like(x).float()
sr_t_result = torch.zeros_like(x).float().t().contiguous()
for i in range(n_iters):
q_sr, s_sr, q_t_sr, s_t_sr = quantize_fp4(
x, use_stochastic_rounding=True, use_2D=use_2D, use_RHT=use_RHT
)
dq_sr = dequantize_fp4(q_sr, s_sr, amax)
dq_t_sr = dequantize_fp4(q_t_sr, s_t_sr, amax)
sr_result += dq_sr.float()
sr_t_result += dq_t_sr.float()
# sr_result_tmp = sr_result / (i + 1)
# error_sr = (sr_result_tmp - x).float()
# me_sr = torch.sqrt((error_sr * error_sr).mean())
# sr_t_result_tmp = sr_t_result / (i + 1)
# error_t_sr = (sr_t_result_tmp - y).float()
# me_t_sr = torch.sqrt((error_t_sr * error_t_sr).mean())
# print(f"Iteration {i}: RMSE SR: {me_sr:.3e} | RMSE RN: {me_rn:.3e}")
# print(f"Iteration {i}: RMSE SR_t: {me_t_sr:.3e} | RMSE RN_t: {me_t_rn:.3e}")
# Get the mean result of the stochastic rounding
# It should be more accurate than the RN result
sr_result /= n_iters
error_sr = (sr_result - x).float()
me_sr = torch.sqrt((error_sr * error_sr).mean())
sr_t_result /= n_iters
error_t_sr = (sr_t_result - y).float()
me_t_sr = torch.sqrt((error_t_sr * error_t_sr).mean())
print(f"RMSE SR: {me_sr:.3e} | RMSE RN: {me_rn:.3e}")
print(f"RMSE SR_t: {me_t_sr:.3e} | RMSE RN_t: {me_t_rn:.3e}")
assert me_sr < me_rn, "Stochastic rounding failed: error larger than round-to-nearest."
assert me_t_sr < me_t_rn, "Stochastic rounding failed: error larger than round-to-nearest."
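# Why averaging helps (sketch): stochastic rounding is unbiased. For a value x lying a
# fraction p of the way between representable neighbors a and b, SR returns b with
# probability p and a with probability 1 - p, so E[SR(x)] = x. Averaging n_iters
# independent SR quantizations therefore converges to x and its RMSE shrinks, while
# round-to-nearest keeps a fixed, deterministic bias of up to half a quantization
# step, which is what the asserts above rely on.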
@pytest.mark.skipif(not recipe_available, reason=reason_for_no_recipe)
@pytest.mark.parametrize(
"M, N",
[
(8192, 8192),
(8192, 8256), # to test the nonfused RHT path
],
)
@pytest.mark.parametrize("x_dtype", [torch.float32, torch.bfloat16], ids=str)
@pytest.mark.parametrize("use_2D", [False, True], ids=str)
@pytest.mark.parametrize("use_RHT", [False, True], ids=str)
def test_quantization_block_tiling_versus_reference(
x_dtype: torch.dtype,
use_2D: bool,
use_RHT: bool,
M: int,
N: int,
) -> None:
if x_dtype == torch.float32 and use_RHT:
pytest.skip("RHT is only supported with bfloat16")
check_quantization_nvfp4_versus_reference(
x_dtype=x_dtype,
use_2D=use_2D,
use_RHT=use_RHT,
M=M,
N=N,
)
@@ -12,13 +12,15 @@ import pathlib
import pytest
import torch
+from typing import Optional
import transformer_engine.pytorch as te
from utils import make_recipe
# Check supported quantization schemes
-fp8_available, reason_for_no_fp8 = te.fp8.FP8GlobalStateManager.is_fp8_available()
-mxfp8_available, reason_for_no_mxfp8 = te.fp8.FP8GlobalStateManager.is_mxfp8_available()
+fp8_available, reason_for_no_fp8 = te.is_fp8_available(return_reason=True)
+mxfp8_available, reason_for_no_mxfp8 = te.is_mxfp8_available(return_reason=True)
# Test cases for loading checkpoint files
@@ -65,16 +67,16 @@ class TestLoadCheckpoint:
if name == "ops_linear":
return te.ops.Linear(1, 1)
if name == "linear.fp8":
-with te.fp8_model_init(recipe=make_recipe("fp8")):
+with te.quantized_model_init(recipe=make_recipe("fp8")):
return te.Linear(16, 16)
if name == "ops_linear.fp8":
-with te.fp8_model_init(recipe=make_recipe("fp8")):
+with te.quantized_model_init(recipe=make_recipe("fp8")):
return te.ops.Linear(16, 16)
if name == "linear.mxfp8":
-with te.fp8_model_init(recipe=make_recipe("mxfp8")):
+with te.quantized_model_init(recipe=make_recipe("mxfp8")):
return te.Linear(32, 32)
if name == "ops_linear.mxfp8":
-with te.fp8_model_init(recipe=make_recipe("mxfp8")):
+with te.quantized_model_init(recipe=make_recipe("mxfp8")):
return te.ops.Linear(32, 32)
raise ValueError(f"Unrecognized module name ({name})")
...
@@ -12,14 +12,13 @@ import torch
import transformer_engine.pytorch as te
from transformer_engine.common import recipe
-from transformer_engine.pytorch.fp8 import FP8GlobalStateManager
from transformer_engine.pytorch.attention.dot_product_attention import _attention_backends
from transformer_engine.pytorch.utils import is_non_tn_fp8_gemm_supported
from utils import ModelConfig, get_available_attention_backends
# Check supported quantization schemes
-fp8_available, _ = FP8GlobalStateManager.is_fp8_available()
-mxfp8_available, _ = FP8GlobalStateManager.is_mxfp8_available()
+fp8_available = te.is_fp8_available()
+mxfp8_available = te.is_mxfp8_available()
quantization_recipes: Optional[recipe.Recipe] = [None]
if fp8_available:
@@ -79,9 +78,9 @@ def _warmup_model(
"""Perform forward and backward pass"""
tensor = _make_input()
for module in modules:
-with te.fp8_autocast(
+with te.autocast(
enabled=quantization_recipe is not None,
-fp8_recipe=quantization_recipe,
+recipe=quantization_recipe,
):
tensor = module(tensor)
tensor.sum().backward()
@@ -159,8 +158,8 @@ def _measure_cached_memory(
tensor = inp
memory_before_forward = torch.cuda.memory_allocated() / (1024**2)
for module in modules:
-with te.fp8_autocast(
-enabled=quantization_recipe is not None, fp8_recipe=quantization_recipe
+with te.autocast(
+enabled=quantization_recipe is not None, recipe=quantization_recipe
), offload_context:
tensor = module(tensor)
tensor = sync_function(tensor)
...
@@ -13,12 +13,15 @@ from transformer_engine.pytorch import (
Linear,
MultiheadAttention,
TransformerLayer,
-fp8_autocast,
-fp8_model_init,
+autocast,
+quantized_model_init,
make_graphed_callables,
+is_fp8_available,
+is_fp8_block_scaling_available,
+is_mxfp8_available,
+is_bf16_available,
)
-from transformer_engine.pytorch.fp8 import FP8GlobalStateManager
-from transformer_engine.pytorch.utils import is_bf16_compatible
+from transformer_engine.pytorch.quantization import FP8GlobalStateManager
import transformer_engine.pytorch.ops as te_ops
from transformer_engine.common import recipe
from utils import ModelConfig, reset_rng_states
@@ -28,20 +31,67 @@ if IS_HIP_EXTENSION:
from functools import cache
# Check if FP8 is supported.
-fp8_available, reason_for_no_fp8 = FP8GlobalStateManager.is_fp8_available()
-fp8_block_scaling_available, reason_for_no_fp8_block_scaling = FP8GlobalStateManager.is_fp8_block_scaling_available()
-mxfp8_available, reason_for_no_mxfp8 = FP8GlobalStateManager.is_mxfp8_available()
+fp8_available = is_fp8_available()
+fp8_block_scaling_available = is_fp8_block_scaling_available()
+mxfp8_available = is_mxfp8_available()
# Reset RNG states.
reset_rng_states()
model_configs = {
-"small": ModelConfig(32, 2, 2, 32),
+"small": ModelConfig(2, 32, 2, 32),
}
+def nvfp4_vanilla():
+nvfp4_recipe = recipe.NVFP4BlockScaling()
+nvfp4_recipe.fp4_quant_fwd_inp = recipe.QParams()
+nvfp4_recipe.fp4_quant_fwd_weight = recipe.QParams()
+nvfp4_recipe.fp4_quant_bwd_grad = recipe.QParams()
+return nvfp4_recipe
+def nvfp4_rht_and_2d_quantization():
+nvfp4_recipe = recipe.NVFP4BlockScaling()
+nvfp4_recipe.fp4_quant_fwd_inp = recipe.QParams(
+random_hadamard_transform=True, fp4_2d_quantization=False
+)
+nvfp4_recipe.fp4_quant_fwd_weight = recipe.QParams(
+random_hadamard_transform=False, fp4_2d_quantization=True
+)
+nvfp4_recipe.fp4_quant_bwd_grad = recipe.QParams(
+random_hadamard_transform=True, fp4_2d_quantization=False
+)
+return nvfp4_recipe
+def check_rht_usage(recipe: recipe.Recipe) -> bool:
+# if using RHT, we can only support bf16
+# check fp4_quant_fwd_inp, fp4_quant_fwd_weight, fp4_quant_bwd_grad
+if recipe.nvfp4():
+if (
+recipe.fp4_quant_fwd_inp.random_hadamard_transform
+or recipe.fp4_quant_fwd_weight.random_hadamard_transform
+or recipe.fp4_quant_bwd_grad.random_hadamard_transform
+):
+return True
+return False
+def get_nvfp4_inp_supported_dtypes(recipe: recipe.Recipe, dtype: torch.dtype) -> bool:
+supported_input_dtypes = []
+if recipe.nvfp4():
+supported_input_dtypes.append(torch.bfloat16)
+# if not using RHT, we can add fp32 as well
+if not check_rht_usage(recipe):
+supported_input_dtypes.append(torch.float32)
+return supported_input_dtypes
fp8_recipes = []
if mxfp8_available:
fp8_recipes.append(recipe.MXFP8BlockScaling())
+fp8_recipes.append(nvfp4_rht_and_2d_quantization())
if fp8_block_scaling_available:
fp8_recipes.append(recipe.Float8BlockScaling())
if fp8_available:
@@ -50,7 +100,7 @@ if fp8_available:
# Supported data types
dtypes: List[torch.dtype] = [torch.float32, torch.float16]
-if is_bf16_compatible(): # bf16 requires sm_80 or higher
+if is_bf16_available(): # bf16 requires sm_80 or higher
dtypes.append(torch.bfloat16)
@@ -167,7 +217,7 @@ def _test_cuda_graphs(
fp8_weight_caching = False
# Create modules.
-with fp8_model_init(enabled=fp8_params, recipe=fp8_recipe):
+with quantized_model_init(enabled=fp8_params, recipe=fp8_recipe):
if module == "transformer":
modules = [
TransformerLayer(
@@ -247,9 +297,9 @@ def _test_cuda_graphs(
model,
(generate_data(model_config, dtype, warmup=True),),
num_warmup_iters=10,
-fp8_enabled=fp8,
-fp8_weight_caching=fp8_weight_caching,
-fp8_recipe=fp8_recipe,
+enabled=fp8,
+cache_quantized_params=fp8_weight_caching,
+recipe=fp8_recipe,
)
elif graph_mode == "individual":
# Graph individual modules.
@@ -258,9 +308,9 @@ def _test_cuda_graphs(
module,
(generate_data(model_config, dtype, warmup=True),),
num_warmup_iters=10,
-fp8_enabled=fp8,
-fp8_weight_caching=fp8_weight_caching,
-fp8_recipe=fp8_recipe,
+enabled=fp8,
+cache_quantized_params=fp8_weight_caching,
+recipe=fp8_recipe,
)
for module in modules
]
@@ -277,7 +327,7 @@ def _test_cuda_graphs(
for grad_accumulation_step in range(2):
input_ = generate_data(model_config, dtype)
grad_output = generate_data(model_config, dtype, requires_grad=False)
-with fp8_autocast(enabled=fp8, fp8_recipe=fp8_recipe):
+with autocast(enabled=fp8, recipe=fp8_recipe):
kwargs = {}
if fp8_weight_caching:
kwargs["is_first_microbatch"] = grad_accumulation_step == 0
@@ -291,7 +341,7 @@ def _test_cuda_graphs(
@pytest.mark.parametrize("module", _test_cuda_graphs_modules)
@pytest.mark.parametrize("dtype", dtypes)
@pytest.mark.parametrize("fp8_params", (False, True))
-@pytest.mark.parametrize("fp8_recipe", fp8_recipes + [None])
+@pytest.mark.parametrize("fp8_recipe", fp8_recipes + [None], ids=lambda r: type(r).__name__)
def test_make_graphed_callables(
*,
module: str,
@@ -308,15 +358,25 @@ def test_make_graphed_callables(
pytest.skip("FP8 needed for FP8 parameters.")
if fp8_weight_caching and not fp8:
pytest.skip("FP8 needed for FP8 parameters.")
-if fp8 and fp8_recipe.float8_block_scaling() and module == "linear_op":
-pytest.skip("Module not yet supported for float8_block_scaling with CUDA graphs")
+if fp8 and (fp8_recipe.float8_block_scaling() or fp8_recipe.nvfp4()) and module == "linear_op":
+pytest.skip(
+f"Module not yet supported for {fp8_recipe.__class__.__name__} with CUDA graphs"
+)
+if fp8 and fp8_recipe.nvfp4():
+if dtype not in get_nvfp4_inp_supported_dtypes(fp8_recipe, dtype):
+pytest.skip(
+f"Input dtype {dtype} not supported for NVFP4 Recipe"
+f" {fp8_recipe.__class__.__name__}"
+)
+if fp8_params:
+pytest.skip("NVFP4 params not supported")
if fp8 and not fp8_available:
-pytest.skip(reason_for_no_fp8)
+pytest.skip("FP8 not supported on rocm GPU.")
if fp8 and fp8_recipe.float8_block_scaling() and not fp8_block_scaling_available:
-pytest.skip(reason_for_no_fp8_block_scaling)
+pytest.skip("FP8 block scaling not supported on rocm GPU.")
if fp8 and fp8_recipe.mxfp8() and not mxfp8_available:
-pytest.skip(reason_for_no_mxfp8)
+pytest.skip("MXFP8 not supported on rocm GPU.")
# Run model with different CUDA graph settings.
model_config = model_configs[model_config]
kwargs = dict(
@@ -353,17 +413,19 @@ _test_make_graphed_callables_with_fp8_weight_caching_modules = [
"module",
_test_make_graphed_callables_with_fp8_weight_caching_modules,
)
+@pytest.mark.parametrize("dtype", dtypes)
@pytest.mark.parametrize("fp8_params", (False, True))
-@pytest.mark.parametrize("fp8_recipe", fp8_recipes)
+@pytest.mark.parametrize("fp8_recipe", fp8_recipes, ids=lambda r: type(r).__name__)
def test_make_graphed_callables_with_fp8_weight_caching(
*,
module: str,
+dtype: torch.dtype,
fp8_params: bool,
fp8_recipe: recipe.Recipe,
) -> None:
test_make_graphed_callables(
module=module,
-dtype=torch.float32,
+dtype=dtype,
fp8_params=fp8_params,
fp8_recipe=fp8_recipe,
fp8_weight_caching=True,
@@ -415,7 +477,7 @@ def _test_cuda_graphs_with_dot_product_attention(
model,
generate_data_for_dot_product_attention(model_config, dtype, warmup=True),
num_warmup_iters=10,
-fp8_enabled=False,
+enabled=False,
)
# Forward and backward passes.
...
# Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# See LICENSE for license information.
import pytest
import torch
import transformer_engine.pytorch as te
import transformer_engine_torch as tex
from transformer_engine.common import recipe
from transformer_engine.pytorch import (
autocast,
Linear,
LayerNormLinear,
LayerNormMLP,
GroupedLinear,
Float8CurrentScalingQuantizer,
)
import transformer_engine.pytorch.ops as te_ops
@pytest.mark.parametrize("module_type", ["Linear", "LayerNormLinear", "OpsLinear", "LayerNormMLP"])
def test_custom_recipe_sanity(module_type):
available, reason = te.is_fp8_available(return_reason=True)
if not torch.cuda.is_available() or not available:
pytest.skip(f"FP8 unsupported on this device: {reason}")
torch.manual_seed(0)
# Simple linear layer with dims divisible by 16
in_features = 64
out_features = 64
batch = 32
if module_type == "Linear":
model = Linear(in_features, out_features, params_dtype=torch.bfloat16).cuda()
elif module_type == "LayerNormLinear":
model = LayerNormLinear(in_features, out_features, params_dtype=torch.bfloat16).cuda()
elif module_type == "LayerNormMLP":
# hidden_size == in_features == out_features for simplicity
model = LayerNormMLP(
hidden_size=in_features, ffn_hidden_size=out_features, params_dtype=torch.bfloat16
).cuda()
else:
# OpsLinear path
model = te_ops.Linear(in_features, out_features, device="cuda", dtype=torch.bfloat16)
inp = torch.randn(batch, in_features, device="cuda", dtype=torch.bfloat16, requires_grad=True)
# Single factory: map roles to quantizers
def quantizer_factory(role):
if role in ("linear_input", "linear_weight", "linear_output"):
return Float8CurrentScalingQuantizer(tex.DType.kFloat8E4M3, device="cuda")
if role in ("linear_grad_output", "linear_grad_input"):
return Float8CurrentScalingQuantizer(tex.DType.kFloat8E5M2, device="cuda")
return Float8CurrentScalingQuantizer(tex.DType.kFloat8E4M3, device="cuda")
custom_recipe = recipe.CustomRecipe(qfactory=quantizer_factory)
# Execute with custom recipe
with autocast(enabled=True, recipe=custom_recipe):
out = model(inp)
loss = out.float().sum()
loss.backward()
# Basic sanity: gradients exist
assert inp.grad is not None
def test_custom_recipe_grouped_linear_sanity():
available, reason = te.is_fp8_available(return_reason=True)
if not torch.cuda.is_available() or not available:
pytest.skip(f"FP8 unsupported on this device: {reason}")
torch.manual_seed(0)
num_gemms = 3
in_features = 64
out_features = 64
batch = 32
base = batch // num_gemms
rem = batch % num_gemms
m_splits = [base + (1 if i < rem else 0) for i in range(num_gemms)]
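    # Worked example of the split: with batch = 32 and num_gemms = 3, base = 10 and
    # rem = 2, so m_splits = [11, 11, 10], which sums back to 32.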
model = GroupedLinear(num_gemms, in_features, out_features, params_dtype=torch.bfloat16).cuda()
inp = torch.randn(batch, in_features, device="cuda", dtype=torch.bfloat16, requires_grad=True)
def quantizer_factory(role):
if role in ("linear_input", "linear_weight", "linear_output"):
return Float8CurrentScalingQuantizer(tex.DType.kFloat8E4M3, device="cuda")
if role in ("linear_grad_output", "linear_grad_input"):
return Float8CurrentScalingQuantizer(tex.DType.kFloat8E5M2, device="cuda")
return Float8CurrentScalingQuantizer(tex.DType.kFloat8E4M3, device="cuda")
custom_recipe = recipe.CustomRecipe(qfactory=quantizer_factory)
with autocast(enabled=True, recipe=custom_recipe):
out = model(inp, m_splits)
loss = out.float().sum()
loss.backward()
assert inp.grad is not None
def test_custom_recipe_matches_current_scaling():
available, reason = te.is_fp8_available(return_reason=True)
if not torch.cuda.is_available() or not available:
pytest.skip(f"FP8 unsupported on this device: {reason}")
torch.manual_seed(123)
in_features = 64
out_features = 64
batch = 32
# Create two identical models
model_ref = Linear(in_features, out_features, params_dtype=torch.bfloat16).cuda()
model_custom = Linear(in_features, out_features, params_dtype=torch.bfloat16).cuda()
model_custom.load_state_dict(model_ref.state_dict())
# Identical inputs for both paths
base_inp = torch.randn(batch, in_features, device="cuda", dtype=torch.bfloat16)
inp_ref = base_inp.clone().detach().requires_grad_(True)
inp_custom = base_inp.clone().detach().requires_grad_(True)
# Reference: use Float8CurrentScaling recipe
ref_recipe = recipe.Float8CurrentScaling()
with autocast(enabled=True, recipe=ref_recipe):
out_ref = model_ref(inp_ref)
# Assert dtypes for reference quantizers: HYBRID = E4M3 (fwd), E5M2 (bwd)
ref_fwd_in = model_ref.quantizers["scaling_fwd"][tex.FP8FwdTensors.GEMM1_INPUT]
ref_fwd_w = model_ref.quantizers["scaling_fwd"][tex.FP8FwdTensors.GEMM1_WEIGHT]
ref_fwd_out = model_ref.quantizers["scaling_fwd"][tex.FP8FwdTensors.GEMM1_OUTPUT]
ref_bwd_go = model_ref.quantizers["scaling_bwd"][tex.FP8BwdTensors.GRAD_OUTPUT1]
ref_bwd_gi = model_ref.quantizers["scaling_bwd"][tex.FP8BwdTensors.GRAD_INPUT1]
assert ref_fwd_in.dtype == tex.DType.kFloat8E4M3
assert ref_fwd_w.dtype == tex.DType.kFloat8E4M3
assert ref_fwd_out.dtype == tex.DType.kFloat8E4M3
assert ref_bwd_go.dtype == tex.DType.kFloat8E5M2
assert ref_bwd_gi.dtype == tex.DType.kFloat8E5M2
# Stress dynamic range in grad_output
scale = torch.ones(out_features, device="cuda", dtype=torch.float32)
scale[0] = 1e8
scale[1] = 1e-8
loss_ref = (out_ref.float() * scale.view(1, -1)).sum()
loss_ref.backward()
# Custom: single factory returning quantizers per role to match Float8CurrentScaling
def quantizer_factory(role):
if role in ("linear_input", "linear_weight", "linear_output"):
return Float8CurrentScalingQuantizer(tex.DType.kFloat8E4M3, device="cuda")
if role in ("linear_grad_output", "linear_grad_input"):
return Float8CurrentScalingQuantizer(tex.DType.kFloat8E5M2, device="cuda")
return Float8CurrentScalingQuantizer(tex.DType.kFloat8E4M3, device="cuda")
custom_recipe = recipe.CustomRecipe(qfactory=quantizer_factory)
with autocast(enabled=True, recipe=custom_recipe):
out_custom = model_custom(inp_custom)
# Assert dtypes for custom quantizers match reference mapping
cus_fwd_in = model_custom.quantizers["scaling_fwd"][tex.FP8FwdTensors.GEMM1_INPUT]
cus_fwd_w = model_custom.quantizers["scaling_fwd"][tex.FP8FwdTensors.GEMM1_WEIGHT]
cus_fwd_out = model_custom.quantizers["scaling_fwd"][tex.FP8FwdTensors.GEMM1_OUTPUT]
cus_bwd_go = model_custom.quantizers["scaling_bwd"][tex.FP8BwdTensors.GRAD_OUTPUT1]
cus_bwd_gi = model_custom.quantizers["scaling_bwd"][tex.FP8BwdTensors.GRAD_INPUT1]
assert cus_fwd_in.dtype == tex.DType.kFloat8E4M3
assert cus_fwd_w.dtype == tex.DType.kFloat8E4M3
assert cus_fwd_out.dtype == tex.DType.kFloat8E4M3
assert cus_bwd_go.dtype == tex.DType.kFloat8E5M2
assert cus_bwd_gi.dtype == tex.DType.kFloat8E5M2
loss_custom = (out_custom.float() * scale.view(1, -1)).sum()
loss_custom.backward()
# Compare forward outputs (exact match expected)
assert torch.allclose(out_ref, out_custom, rtol=0.0, atol=0.0)
# Compare input gradients
assert inp_ref.grad is not None and inp_custom.grad is not None
assert torch.allclose(inp_ref.grad, inp_custom.grad, rtol=0.0, atol=0.0)
# Compare parameter gradients (weights and bias if present)
ref_params = dict(model_ref.named_parameters())
custom_params = dict(model_custom.named_parameters())
for name, p_ref in ref_params.items():
p_cus = custom_params[name]
assert p_ref.grad is not None and p_cus.grad is not None
assert torch.allclose(p_ref.grad, p_cus.grad, rtol=0.0, atol=0.0)
def test_custom_recipe_ops_linear_2_1_layout():
available, reason = te.is_fp8_available(return_reason=True)
if not torch.cuda.is_available() or not available:
pytest.skip(f"FP8 unsupported on this device: {reason}")
torch.manual_seed(7)
in_features = 64
out_features = 64
batch = 16
# Use ops.Linear which consumes 2 forward quantizers and 1 backward quantizer
op = te_ops.Linear(in_features, out_features, device="cuda", dtype=torch.bfloat16)
inp = torch.randn(batch, in_features, device="cuda", dtype=torch.bfloat16, requires_grad=True)
def quantizer_factory(role):
if role in ("linear_input", "linear_weight", "linear_output"):
return Float8CurrentScalingQuantizer(tex.DType.kFloat8E4M3, device="cuda")
if role in ("linear_grad_output", "linear_grad_input"):
return Float8CurrentScalingQuantizer(tex.DType.kFloat8E5M2, device="cuda")
return Float8CurrentScalingQuantizer(tex.DType.kFloat8E4M3, device="cuda")
custom = recipe.CustomRecipe(qfactory=quantizer_factory)
with autocast(enabled=True, recipe=custom):
out = op(inp)
loss = out.float().sum()
loss.backward()
assert inp.grad is not None
def test_custom_recipe_factory_invocation_counts_and_cycling():
available, reason = te.is_fp8_available(return_reason=True)
if not torch.cuda.is_available() or not available:
pytest.skip(f"FP8 unsupported on this device: {reason}")
torch.manual_seed(13)
in_features = 64
out_features = 64
batch = 8
op = Linear(in_features, out_features, params_dtype=torch.bfloat16)
inp = torch.randn(batch, in_features, device="cuda", dtype=torch.bfloat16, requires_grad=True)
# Counters per role
counts = {
"linear_input": 0,
"linear_weight": 0,
"linear_output": 0,
"linear_grad_output": 0,
"linear_grad_input": 0,
}
def quantizer_factory(role):
if role in counts:
counts[role] += 1
if role in ("linear_input", "linear_weight", "linear_output"):
return Float8CurrentScalingQuantizer(tex.DType.kFloat8E4M3, device=torch.device("cuda"))
if role in ("linear_grad_output", "linear_grad_input"):
return Float8CurrentScalingQuantizer(tex.DType.kFloat8E5M2, device=torch.device("cuda"))
return Float8CurrentScalingQuantizer(tex.DType.kFloat8E4M3, device=torch.device("cuda"))
custom = recipe.CustomRecipe(qfactory=quantizer_factory)
# Run fwd+bwd once; for a single GEMM, expect forward to build 3 quantizers (cycled from 1 factory),
# and backward to build 2 quantizers (cycled from 1 factory).
with autocast(enabled=True, recipe=custom):
out = op(inp)
loss = out.float().sum()
loss.backward()
# Single GEMM: forward should request input, weight, output; backward grad_output, grad_input
assert counts["linear_input"] == 1
assert counts["linear_weight"] == 1
assert counts["linear_output"] == 1
assert counts["linear_grad_output"] == 1
assert counts["linear_grad_input"] == 1
def test_factories_return_distinct_instances_and_buffers():
available, reason = te.is_fp8_available(return_reason=True)
if not torch.cuda.is_available() or not available:
pytest.skip(f"FP8 unsupported on this device: {reason}")
# Two calls should produce distinct quantizer objects and distinct tensor buffers
def factory():
return Float8CurrentScalingQuantizer(tex.DType.kFloat8E4M3, device=torch.device("cuda"))
q1 = factory()
q2 = factory()
assert q1 is not q2
assert q1.scale.data_ptr() != q2.scale.data_ptr()
assert q1.amax.data_ptr() != q2.amax.data_ptr()
# Mutating one should not affect the other
q1.scale.fill_(123.0)
assert not torch.equal(q1.scale, q2.scale)
@@ -4,7 +4,6 @@
import pytest
import torch
-import torch.distributed as dist
import transformer_engine.pytorch as te
...
@@ -4,13 +4,13 @@
import pytest
import torch
-import transformer_engine as te
+import transformer_engine.pytorch as te
import transformer_engine_torch as tex
from transformer_engine.pytorch.constants import TE_DType
-from transformer_engine.pytorch.fp8 import (FP8GlobalStateManager, blockwise_fp8_block_len, int8_simulation_fp8)
-from transformer_engine.pytorch.tensor.float8_blockwise_tensor import (
+from transformer_engine.pytorch.fp8 import (blockwise_fp8_block_len, int8_simulation_fp8)
+from transformer_engine.pytorch import (
Float8BlockQuantizer,
-Float8BlockwiseQTensor,
+get_device_compute_capability,
)
from references.blockwise_quantizer_reference import CuBLASScaleMunger
@@ -19,8 +19,9 @@ from torch.utils.cpp_extension import IS_HIP_EXTENSION
from transformer_engine.pytorch.cpp_extensions.gemm import w8a8_int8_general_gemm
def fp8_blockwise_gemm_supported() -> bool:
-supported, _ = FP8GlobalStateManager.is_fp8_block_scaling_available()
-return supported
+supported = te.is_fp8_block_scaling_available()
+emulated = get_device_compute_capability() >= (10, 0)
+return supported and not emulated
def cublas_gemm_fp8_blockwise_case(
...
@@ -8,14 +8,13 @@ import os
import pathlib
import pytest
import torch
-import transformer_engine as te
-import transformer_engine_torch as tex
-from transformer_engine.pytorch.fp8 import (FP8GlobalStateManager, blockwise_fp8_block_len)
+import transformer_engine.pytorch as te
+from transformer_engine.pytorch.fp8 import blockwise_fp8_block_len
from transformer_engine.common.recipe import Float8BlockScaling
from transformer_engine.pytorch.constants import TE_DType
-from transformer_engine.pytorch.tensor.float8_blockwise_tensor import (
+from transformer_engine.pytorch import (
Float8BlockQuantizer,
-Float8BlockwiseQTensor,
+get_device_compute_capability,
)
from references.blockwise_quantizer_reference import (
BlockwiseQuantizerReference,
@@ -32,7 +31,8 @@ TENSOR_DUMP_DIR = pathlib.Path(__file__).resolve().parent.parent.parent / "tenso
tensor_dump_dir_env = os.getenv("NVTE_TEST_BLOCK_CURRENT_SCALING_EXACT_TENSOR_DUMP_DIR")
if tensor_dump_dir_env is not None:
TENSOR_DUMP_DIR = pathlib.Path(tensor_dump_dir_env)
-recipe_available, reason_for_no_recipe = FP8GlobalStateManager.is_fp8_block_scaling_available()
+recipe_available, reason_for_no_recipe = te.is_fp8_block_scaling_available(return_reason=True)
+recipe_emulated = get_device_compute_capability() >= (10, 0)
class GetRecipes:
@@ -219,6 +219,12 @@ def check_quantization_block_tiling_versus_reference(
pow_2_scales: bool,
tile_size: Tuple[int, int],
) -> None:
+if recipe_emulated and not pow_2_scales:
+pytest.skip(
+"On Blackwell and newer, the FP8 block scaling recipe is emulated "
+"with MXFP8, which requires using power of two scaling factors."
+)
te_dtype = TE_DType[quant_dtype]
if tile_size in ((1, 128), (1, 64)):
block_scaling_dim = 1
@@ -414,6 +420,12 @@ def test_quantization_block_tiling_extrema_versus_reference(
tile_size: Tuple[int, int],
extrema_high: bool,
) -> None:
+if recipe_emulated and not pow_2_scales:
+pytest.skip(
+"On Blackwell and newer, the FP8 block scaling recipe is emulated "
+"with MXFP8, which requires using power of two scaling factors."
+)
# This test runs a single tile through a quantizer as a way to test
# branch coverage of scale computation.
if blockwise_fp8_block_len != tile_size[1]:
...
@@ -8,12 +8,9 @@ import torch
import pytest
import transformer_engine.pytorch as te
-import transformer_engine_torch as tex
-import transformer_engine_torch as tex
-from transformer_engine.pytorch.fp8 import FP8GlobalStateManager
from transformer_engine.common.recipe import Float8CurrentScaling
-from transformer_engine.pytorch.fp8 import fp8_autocast, get_fp8_torch_dtype
+from transformer_engine.pytorch.quantization import autocast, get_fp8_torch_dtype
from transformer_engine.pytorch.fp8 import int8_simulation_fp8
@@ -25,7 +22,7 @@ if tensor_dump_dir_env is not None:
# Check if FP8 is supported
-fp8_available, reason_for_no_fp8 = FP8GlobalStateManager.is_fp8_available()
+fp8_available, reason_for_no_fp8 = te.is_fp8_available(return_reason=True)
class GetRecipes:
@@ -274,6 +271,14 @@ class TestFP8RecipeLinearBase:
            if bgrad_list is not None and bgrad is not None:
                bgrad_list.append(bgrad.detach().clone())
+        # Stack the results
+        return (
+            torch.stack(y_q_list),
+            torch.stack(dgrad_list),
+            torch.stack(wgrad_list),
+            torch.stack(bgrad_list) if bgrad_list is not None else None,
+        )
    @classmethod
    def run_linear(
        cls,
@@ -388,7 +393,7 @@ class TestFP8RecipeLinearBase:
        # recipe1
        using_fp8_recipe = recipe1() != GetRecipes.none()
        if using_fp8_recipe:
-            with fp8_autocast(enabled=True, fp8_recipe=recipe1()):
+            with autocast(enabled=True, recipe=recipe1()):
                y_q_ref, dgrad_ref, wgrad_ref, bgrad_ref = self.run_linear(x, w, bias, gradient)
        else:
            y_q_ref, dgrad_ref, wgrad_ref, bgrad_ref = self.run_linear(x, w, bias, gradient)
@@ -396,7 +401,7 @@ class TestFP8RecipeLinearBase:
        # recipe2
        using_fp8_recipe = recipe2() != GetRecipes.none()
        if using_fp8_recipe:
-            with fp8_autocast(enabled=True, fp8_recipe=recipe2()):
+            with autocast(enabled=True, recipe=recipe2()):
                y_q, dgrad, wgrad, bgrad = self.run_linear(x, w, bias, gradient)
        else:
            y_q, dgrad, wgrad, bgrad = self.run_linear(x, w, bias, gradient)
@@ -611,7 +616,7 @@ class TestFP8RecipeLayerNormLinearBase(TestFP8RecipeLinearBase):
        # recipe1
        using_fp8_recipe = recipe1() != GetRecipes.none()
        if using_fp8_recipe:
-            with fp8_autocast(enabled=True, fp8_recipe=recipe1()):
+            with autocast(enabled=True, recipe=recipe1()):
                y_q_ref, ln_out_ref, dgrad_ref, wgrad_ref, bgrad_ref = self.run_layernorm_linear(
                    x,
                    w,
@@ -633,7 +638,7 @@ class TestFP8RecipeLayerNormLinearBase(TestFP8RecipeLinearBase):
        # recipe2
        using_fp8_recipe = recipe2() != GetRecipes.none()
        if using_fp8_recipe:
-            with fp8_autocast(enabled=True, fp8_recipe=recipe2()):
+            with autocast(enabled=True, recipe=recipe2()):
                y_q, ln_out, dgrad, wgrad, bgrad = self.run_layernorm_linear(
                    x,
                    w,
...
@@ -11,12 +11,11 @@ import pytest
import torch
import transformer_engine.common.recipe
-import transformer_engine.pytorch as te
-from transformer_engine.pytorch.tensor.float8_blockwise_tensor import (
+from transformer_engine.pytorch import (
    Float8BlockQuantizer,
    Float8BlockwiseQTensor,
+    get_device_compute_capability,
)
-from transformer_engine.pytorch.utils import get_device_compute_capability
from torch.utils.cpp_extension import IS_HIP_EXTENSION
import transformer_engine_torch as tex
...
@@ -11,13 +11,11 @@ import torch
import transformer_engine.common.recipe
import transformer_engine.pytorch as te
-from transformer_engine.pytorch.fp8 import FP8GlobalStateManager
-from transformer_engine.pytorch.tensor.float8_tensor import (
+from transformer_engine.pytorch import (
    Float8Quantizer,
    Float8Tensor,
    Float8CurrentScalingQuantizer,
)
-from transformer_engine.pytorch.constants import TE_DType, TE_DType_To_Torch
from transformer_engine.pytorch.utils import is_non_tn_fp8_gemm_supported
import transformer_engine_torch as tex
@@ -47,7 +45,7 @@ def _to_list(x: Union[Iterable, Any]) -> List:
DimsType = Union[Iterable[int], int]
# Check if FP8 is supported
-fp8_available, reason_for_no_fp8 = FP8GlobalStateManager.is_fp8_available()
+fp8_available, reason_for_no_fp8 = te.is_fp8_available(return_reason=True)
# delayed scaling
...
@@ -11,14 +11,11 @@ from torch import nn
from torch.testing._internal.common_device_type import largeTensorTest
import transformer_engine.pytorch as te
from transformer_engine.common.recipe import DelayedScaling
-from transformer_engine.pytorch.attention.multi_head_attention import MultiheadAttention
-from transformer_engine.pytorch import fp8_model_init
-from transformer_engine.pytorch.utils import is_bf16_compatible
-from transformer_engine.pytorch.fp8 import FP8GlobalStateManager
+from transformer_engine.pytorch import MultiheadAttention, quantized_model_init, is_bf16_available
from transformer_engine.pytorch.utils import gpu_autocast_ctx
# Check if FP8 is supported
-fp8_available, reason_for_no_fp8 = FP8GlobalStateManager.is_fp8_available()
+fp8_available, reason_for_no_fp8 = te.is_fp8_available(return_reason=True)
class TestFusedOptimizer:
@@ -188,7 +185,7 @@ class TestFusedAdam(TestFusedOptimizer):
        build_model_context = nullcontext
        build_model_context_args = {}
        if use_fp8_params:
-            build_model_context = fp8_model_init
+            build_model_context = quantized_model_init
            build_model_context_args["enabled"] = True
        with build_model_context(**build_model_context_args):
@@ -286,7 +283,7 @@ class TestFusedAdam(TestFusedOptimizer):
            exp_avg_sq_dtype=torch.float32,
        )
-    @pytest.mark.skipif(not is_bf16_compatible(), reason="bf16 if not supported")
+    @pytest.mark.skipif(not is_bf16_available(), reason="bf16 if not supported")
    def test_fp32_master(self):
        self.gen_precision_aware_test(
            use_fp8_params=False,
@@ -298,7 +295,7 @@ class TestFusedAdam(TestFusedOptimizer):
            exp_avg_sq_dtype=torch.float32,
        )
-    @pytest.mark.skipif(not is_bf16_compatible(), reason="bf16 if not supported")
+    @pytest.mark.skipif(not is_bf16_available(), reason="bf16 if not supported")
    def test_fp32_master_store_param_remainders(self):
        self.gen_precision_aware_test(
            use_fp8_params=False,
@@ -311,7 +308,7 @@ class TestFusedAdam(TestFusedOptimizer):
            store_param_remainders=True,
        )
-    @pytest.mark.skipif(not is_bf16_compatible(), reason="bf16 if not supported")
+    @pytest.mark.skipif(not is_bf16_available(), reason="bf16 if not supported")
    def test_fp16_master(self):
        self.gen_precision_aware_test(
            use_fp8_params=False,
@@ -325,7 +322,7 @@ class TestFusedAdam(TestFusedOptimizer):
            master_atol=2e-3,
        )
-    @pytest.mark.skipif(not is_bf16_compatible(), reason="bf16 if not supported")
+    @pytest.mark.skipif(not is_bf16_available(), reason="bf16 if not supported")
    def test_bf16_grad(self):
        self.gen_precision_aware_test(
            use_fp8_params=False,
@@ -339,7 +336,7 @@ class TestFusedAdam(TestFusedOptimizer):
            master_atol=2e-3,
        )
-    @pytest.mark.skipif(not is_bf16_compatible(), reason="bf16 if not supported")
+    @pytest.mark.skipif(not is_bf16_available(), reason="bf16 if not supported")
    def test_fp16_exp_avg(self):
        self.gen_precision_aware_test(
            use_fp8_params=False,
@@ -353,7 +350,7 @@ class TestFusedAdam(TestFusedOptimizer):
            master_atol=2e-3,
        )
-    @pytest.mark.skipif(not is_bf16_compatible(), reason="bf16 if not supported")
+    @pytest.mark.skipif(not is_bf16_available(), reason="bf16 if not supported")
    def test_bf16_exp_avg(self):
        self.gen_precision_aware_test(
            use_fp8_params=False,
@@ -367,7 +364,7 @@ class TestFusedAdam(TestFusedOptimizer):
            master_atol=2e-3,
        )
-    @pytest.mark.skipif(not is_bf16_compatible(), reason="bf16 if not supported")
+    @pytest.mark.skipif(not is_bf16_available(), reason="bf16 if not supported")
    @pytest.mark.skipif(not fp8_available, reason=reason_for_no_fp8)
    def test_fp8_exp_avg(self):
        self.gen_precision_aware_test(
@@ -382,7 +379,7 @@ class TestFusedAdam(TestFusedOptimizer):
            master_atol=1e-2,
        )
-    @pytest.mark.skipif(not is_bf16_compatible(), reason="bf16 if not supported")
+    @pytest.mark.skipif(not is_bf16_available(), reason="bf16 if not supported")
    def test_fp16_exp_avg_sq(self):
        self.gen_precision_aware_test(
            use_fp8_params=False,
@@ -396,7 +393,7 @@ class TestFusedAdam(TestFusedOptimizer):
            master_atol=2e-3,
        )
-    @pytest.mark.skipif(not is_bf16_compatible(), reason="bf16 if not supported")
+    @pytest.mark.skipif(not is_bf16_available(), reason="bf16 if not supported")
    def test_bf16_exp_avg_sq(self):
        self.gen_precision_aware_test(
            use_fp8_params=False,
@@ -410,7 +407,7 @@ class TestFusedAdam(TestFusedOptimizer):
            master_atol=2e-3,
        )
-    @pytest.mark.skipif(not is_bf16_compatible(), reason="bf16 if not supported")
+    @pytest.mark.skipif(not is_bf16_available(), reason="bf16 if not supported")
    @pytest.mark.skipif(not fp8_available, reason=reason_for_no_fp8)
    def test_fp8_exp_avg_sq(self):
        self.gen_precision_aware_test(
@@ -424,7 +421,7 @@ class TestFusedAdam(TestFusedOptimizer):
            skip_assert=True,
        )
-    @pytest.mark.skipif(not is_bf16_compatible(), reason="bf16 if not supported")
+    @pytest.mark.skipif(not is_bf16_available(), reason="bf16 if not supported")
    def test_bf16_model_weight_cast(self):
        dtype = torch.bfloat16
        model = MultiheadAttention(
@@ -468,7 +465,7 @@ class TestFusedAdam(TestFusedOptimizer):
    @pytest.mark.skipif(not fp8_available, reason=reason_for_no_fp8)
    def test_fp8_model_weight_cast(self):
        dtype = torch.bfloat16
-        with fp8_model_init(enabled=True, recipe=DelayedScaling()):
+        with quantized_model_init(enabled=True, recipe=DelayedScaling()):
            model = MultiheadAttention(
                hidden_size=1024,
                num_attention_heads=16,
...
@@ -373,3 +373,19 @@ def test_fused_qkv_rope(
    if not isinstance(start_positions, torch.Tensor):
        torch.testing.assert_close(grad_fused, grad_unfused)
+def test_rotary_position_embedding_forward_with_autocast_gives_same_result_as_without_autocast():
+    rope_layer = RotaryPositionEmbedding(128)
+    rope_embeddings_no_autocast = rope_layer(max_seq_len=1024)
+    with torch.autocast(device_type="cuda", dtype=torch.bfloat16):
+        rope_embeddings_autocast = rope_layer(max_seq_len=1024)
+    torch.testing.assert_close(
+        rope_embeddings_no_autocast.to(dtype=torch.bfloat16),
+        rope_embeddings_autocast.to(dtype=torch.bfloat16),
+        atol=1e-8,
+        rtol=1e-8,
+    )
@@ -7,8 +7,6 @@ from __future__ import annotations
from collections.abc import Iterable
import io
import math
-import pathlib
-import sys
from typing import Optional
import pytest
@@ -18,7 +16,6 @@ from torch.utils.cpp_extension import IS_HIP_EXTENSION
import transformer_engine
import transformer_engine.common.recipe
import transformer_engine.pytorch as te
-from transformer_engine.pytorch.fp8 import FP8GlobalStateManager
import transformer_engine.pytorch.ops as te_ops
from transformer_engine.pytorch.ops.fused import (
    BackwardActivationBias,
@@ -29,20 +26,18 @@ from transformer_engine.pytorch.ops.fused import (
    ForwardLinearBiasAdd,
    ForwardLinearScaleAdd,
)
-from transformer_engine.pytorch.tensor import QuantizedTensor
-from transformer_engine.pytorch.tensor.float8_tensor import (
-    Float8Tensor,
+from transformer_engine.pytorch import (
+    QuantizedTensor,
    Float8CurrentScalingQuantizer,
    Float8Quantizer,
+    MXFP8Quantizer,
+    NVFP4Quantizer,
+    is_bf16_available,
)
-from transformer_engine.pytorch.tensor.mxfp8_tensor import MXFP8Tensor, MXFP8Quantizer
-from transformer_engine.pytorch.utils import is_bf16_compatible
import transformer_engine_torch as tex
# Import utility functions
-_current_file = pathlib.Path(__file__).resolve()
-sys.path.append(str(_current_file.parent))
-from utils import dtype_tols, make_recipe, reset_rng_states
+from utils import dtype_tols, make_recipe, quantization_tols, reset_rng_states
if IS_HIP_EXTENSION:
    import os
@@ -52,13 +47,14 @@ if IS_HIP_EXTENSION:
        return (os.getenv("NVTE_USE_HIPBLASLT") is not None
                or os.getenv("NVTE_USE_ROCBLAS") is None )
-# Check if FP8 is supported
-fp8_available, reason_for_no_fp8 = FP8GlobalStateManager.is_fp8_available()
-mxfp8_available, reason_for_no_mxfp8 = FP8GlobalStateManager.is_mxfp8_available()
+# Check for supported quantization schemes
+fp8_available, reason_for_no_fp8 = te.is_fp8_available(return_reason=True)
+mxfp8_available, reason_for_no_mxfp8 = te.is_mxfp8_available(return_reason=True)
+nvfp4_available, reason_for_no_nvfp4 = te.is_nvfp4_available(return_reason=True)
# Supported data types
_dtypes: list[torch.dtype] = [torch.float32, torch.float16]
-if is_bf16_compatible():  # bf16 requires sm_80 or higher
+if is_bf16_available():  # bf16 requires sm_80 or higher
    _dtypes.append(torch.bfloat16)
# Supported devices
@@ -70,6 +66,8 @@ if fp8_available:
    _quantization_list.extend(("fp8_delayed_scaling", "fp8_current_scaling"))
if mxfp8_available:
    _quantization_list.append("mxfp8")
+if nvfp4_available:
+    _quantization_list.append("nvfp4")
def maybe_skip_quantization(
@@ -77,6 +75,7 @@ def maybe_skip_quantization(
    *,
    dims: Optional[Iterable[int] | int] = None,
    device: Optional[torch.device | str] = None,
+    dtype: Optional[torch.dtype] = None,
) -> None:
    """Skip test case if a quantization scheme is not supported"""
@@ -84,12 +83,17 @@
    if quantization is None:
        return
-    # Check if quantization scheme is supported
+    # Check if quantization scheme is supported on device
+    if device is not None and torch.device(device).type != "cuda":
+        pytest.skip("Quantization is only supported on CUDA devices")
    if quantization in ("fp8", "fp8_delayed_scaling", "fp8_current_scaling") and not fp8_available:
        pytest.skip(reason_for_no_fp8)
    if quantization == "mxfp8" and not mxfp8_available:
        pytest.skip(reason_for_no_mxfp8)
+    if quantization == "nvfp4" and not nvfp4_available:
+        pytest.skip(reason_for_no_nvfp4)
+    # Check dims
    if dims is not None:
        if not isinstance(dims, Iterable):
            dims = (dims,)
@@ -99,10 +103,14 @@
    elif quantization == "mxfp8":
        if math.prod(dims[:-1]) % 32 != 0 or dims[-1] % 32 != 0:
            pytest.skip("MXFP8 GEMMs require dims that are divisible by 32")
+    elif quantization == "nvfp4":
+        if math.prod(dims[:-1]) % 16 != 0 or dims[-1] % 16 != 0:
+            pytest.skip("NVFP4 GEMMs require dims that are divisible by 16")
-    # Check if device is supported
-    if device is not None and torch.device(device).type != "cuda":
-        pytest.skip("Quantization is only supported on CUDA devices")
+    # Check dtype
+    if dtype is not None:
+        if quantization == "nvfp4" and dtype != torch.bfloat16:
+            pytest.skip("NVFP4 quantization is only supported with BF16 data")
@torch.no_grad()
@@ -152,6 +160,14 @@ def make_reference_and_test_tensors(
        test = quantizer(test)
    elif quantization == "mxfp8":
        test = MXFP8Quantizer(fp8_dtype=tex.DType.kFloat8E4M3)(test)
+    elif quantization == "nvfp4":
+        test = NVFP4Quantizer(
+            with_rht=False,
+            with_post_rht_amax=False,
+            with_2d_quantization=False,
+            stochastic_rounding=False,
+            with_random_sign_mask=False,
+        )(test)
    else:
        raise ValueError(f"Unsupported quantization scheme ({quantization})")
    if isinstance(test, QuantizedTensor) and not test_is_quantized:
@@ -361,7 +377,7 @@ class TestFuser:
        )
        # Construct model
-        with te.fp8_model_init(recipe=recipe):
+        with te.quantized_model_init(recipe=recipe):
            model = te_ops.basic.BasicLinear(
                size,
                size,
@@ -393,7 +409,7 @@ class TestFuser:
        )
        # Training step
-        with te.fp8_autocast(fp8_recipe=recipe):
+        with te.autocast(recipe=recipe):
            y = model(x)
            y.backward(dy)
        with torch.no_grad():
@@ -406,12 +422,12 @@ class TestFuser:
        torch.testing.assert_close(
            y,
            torch.full_like(y, y_val_ref),
-            **dtype_tols(tex.DType.kFloat8E4M3),
+            **quantization_tols("fp8_delayed_scaling"),
        )
        torch.testing.assert_close(
            x.grad,
            torch.full_like(x.grad, dx_val_ref),
-            **dtype_tols(tex.DType.kFloat8E5M2),
+            **quantization_tols("fp8_delayed_scaling"),
        )
        # Check that scaling factors match expected
@@ -445,7 +461,8 @@ class TestFuser:
        # Skip invalid configurations
        in_shape = (size, size)
        with_quantization = quantization is not None
-        maybe_skip_quantization(quantization, dims=in_shape, device=device)
+        maybe_skip_quantization(quantization, dims=in_shape, device=device, dtype=init_dtype)
+        maybe_skip_quantization(quantization, dtype=final_dtype)
        # Random data
        dtype = torch.float32
@@ -461,7 +478,7 @@
        )
        # Construct operation
-        with te.fp8_model_init(enabled=with_quantization, recipe=make_recipe(quantization)):
+        with te.quantized_model_init(enabled=with_quantization, recipe=make_recipe(quantization)):
            op = te_ops.Linear(size, size, bias=False, device=device, dtype=init_dtype)
        with torch.no_grad():
            op.weight.copy_(w_test)
@@ -513,11 +530,12 @@
        # Skip invalid configurations
        in_shape = (size, size)
        quantized_compute = quantization is not None
-        maybe_skip_quantization(quantization, dims=in_shape, device=device)
+        maybe_skip_quantization(quantization, dims=in_shape, device=device, dtype=model_dtype)
+        maybe_skip_quantization(quantization, dtype=autocast_dtype)
        # Construct operation
        recipe = make_recipe(quantization)
-        with te.fp8_model_init(enabled=quantized_weights, recipe=recipe):
+        with te.quantized_model_init(enabled=quantized_weights, recipe=recipe):
            op = te_ops.Linear(size, size, bias=False, device=device, dtype=model_dtype)
        # Check forward and backward pass
@@ -527,7 +545,7 @@
            device=device,
            requires_grad=True,
        )
-        with te.fp8_autocast(enabled=quantized_compute, fp8_recipe=recipe):
+        with te.autocast(enabled=quantized_compute, recipe=recipe):
            with torch.autocast(device_type=device.type, dtype=autocast_dtype):
                y = op(x)
        y.backward(torch.zeros_like(y))
@@ -540,7 +558,7 @@
        x.grad = None
        op.weight.grad = None
        with torch.autocast(device_type=device.type, dtype=autocast_dtype):
-            with te.fp8_autocast(enabled=quantized_compute, fp8_recipe=recipe):
+            with te.autocast(enabled=quantized_compute, recipe=recipe):
                y = op(x)
        y.backward(torch.zeros_like(y))
        assert y.dtype == autocast_dtype
@@ -569,7 +587,7 @@ class TestBasicOps:
        # Skip invalid configurations
        with_quantization = quantization is not None
-        maybe_skip_quantization(quantization, dims=in_shape, device=device)
+        maybe_skip_quantization(quantization, dims=in_shape, device=device, dtype=dtype)
        # Random data
        x_ref, x_test = make_reference_and_test_tensors(
@@ -635,7 +653,7 @@ class TestBasicOps:
        # Skip invalid configurations
        if memory_format == torch.channels_last and len(in_shape) != 4:
            pytest.skip("torch.channels_last only supports 4D tensors")
-        maybe_skip_quantization(quantization, device=device)
+        maybe_skip_quantization(quantization, device=device, dtype=dtype)
        with_quantization = quantization is not None
        # Random data
@@ -701,7 +719,7 @@ class TestBasicOps:
        # Skip invalid configurations
        with_quantization = quantization is not None
-        maybe_skip_quantization(quantization, dims=in_shape, device=device)
+        maybe_skip_quantization(quantization, dims=in_shape, device=device, dtype=dtype)
        # Random data
        x_ref, x_test = make_reference_and_test_tensors(
@@ -763,7 +781,7 @@ class TestBasicOps:
        # Skip invalid configurations
        with_quantization = quantization is not None
-        maybe_skip_quantization(quantization, device=device)
+        maybe_skip_quantization(quantization, device=device, dtype=dtype)
        if quantization == "mxfp8":
            maybe_skip_quantization(quantization, dims=in_shape)
@@ -790,7 +808,7 @@ class TestBasicOps:
        # Implementation with fusible operation
        op = te_ops.Quantize(forward=cast_forward, backward=cast_backward)
        recipe = make_recipe(quantization)
-        with te.fp8_autocast(enabled=with_quantization, fp8_recipe=recipe):
+        with te.autocast(enabled=with_quantization, recipe=recipe):
            y_test = op(x_test)
        y_test.backward(dy_test)
@@ -830,7 +848,7 @@ class TestBasicOps:
        out_shape = in_shape[:-1] + [out_features]
        # Skip invalid configurations
-        maybe_skip_quantization(quantization, dims=in_shape, device=device)
+        maybe_skip_quantization(quantization, dims=in_shape, device=device, dtype=dtype)
        maybe_skip_quantization(quantization, dims=out_shape)
        quantization_needed = any(
            (
@@ -887,7 +905,7 @@ class TestBasicOps:
        # Implementation with fusible operation
        recipe = make_recipe(quantization)
-        with te.fp8_model_init(enabled=quantized_weight, recipe=recipe):
+        with te.quantized_model_init(enabled=quantized_weight, recipe=recipe):
            op = te_ops.BasicLinear(
                in_features,
                out_features,
@@ -904,7 +922,7 @@ class TestBasicOps:
            op,
            te_ops.Quantize(forward=quantized_output, backward=quantized_grad_output),
        )
-        with te.fp8_autocast(enabled=quantized_compute, fp8_recipe=recipe):
+        with te.autocast(enabled=quantized_compute, recipe=recipe):
            y_test = forward(x_test)
        y_test.backward(dy_test)
@@ -913,7 +931,7 @@ class TestBasicOps:
        if dtype == torch.float32:
            tols = dtype_tols(torch.float16)  # TF32 GEMM
        if quantized_compute or quantized_output or quantized_grad_input:
-            tols = dtype_tols(tex.DType.kFloat8E4M3)
+            tols = quantization_tols(quantization)
        # Check results
        y_test = y_test.to(dtype=torch.float64, device="cpu")
@@ -1024,7 +1042,7 @@ class TestBasicOps:
        out_shape = in_shape[:-1] + [out_features]
        # Skip invalid configurations
-        maybe_skip_quantization(quantization, dims=in_shape, device=device)
+        maybe_skip_quantization(quantization, dims=in_shape, device=device, dtype=dtype)
        maybe_skip_quantization(quantization, dims=out_shape)
        if quantization is None and (quantized_compute or quantized_weight):
            pytest.skip("Quantization scheme is not specified")
@@ -1065,7 +1083,7 @@ class TestBasicOps:
        # Implementation with fusible operation
        recipe = make_recipe(quantization)
-        with te.fp8_model_init(enabled=quantized_weight, recipe=recipe):
+        with te.quantized_model_init(enabled=quantized_weight, recipe=recipe):
            op = te_ops.Linear(
                in_features,
                out_features,
@@ -1081,7 +1099,7 @@ class TestBasicOps:
        del b_test
        for param in op.parameters():
            param.requires_grad_(requires_grad=weight_requires_grad)
-        with te.fp8_autocast(enabled=quantized_compute, fp8_recipe=recipe):
+        with te.autocast(enabled=quantized_compute, recipe=recipe):
            y_test = op(x_test)
        if input_requires_grad or weight_requires_grad:
            y_test.backward(dy_test)
@@ -1091,7 +1109,7 @@ class TestBasicOps:
        if dtype == torch.float32:
            tols = dtype_tols(torch.float16)  # TF32 GEMM
        if quantized_compute:
-            tols = dtype_tols(tex.DType.kFloat8E4M3)
+            tols = quantization_tols(quantization)
        # Check results
        y_test = y_test.to(dtype=torch.float64, device="cpu")
@@ -1128,7 +1146,7 @@ class TestBasicOps:
        in_shape = list(in_shape)[:-1] + list(weight_shape)
        # Skip invalid configurations
-        maybe_skip_quantization(quantization, dims=in_shape, device=device)
+        maybe_skip_quantization(quantization, dims=in_shape, device=device, dtype=dtype)
        # Random data
        x_ref, x_test = make_reference_and_test_tensors(
@@ -1182,14 +1200,14 @@ class TestBasicOps:
            op,
            te_ops.Quantize(forward=quantized_compute, backward=False),
        )
-        with te.fp8_autocast(enabled=quantized_compute, fp8_recipe=recipe):
+        with te.autocast(enabled=quantized_compute, recipe=recipe):
            y_test = forward(x_test)
        y_test.backward(dy_test)
        # Expected numerical error
        tols = dtype_tols(dtype)
        if quantized_compute:
-            tols = dtype_tols(tex.DType.kFloat8E4M3)
+            tols = quantization_tols(quantization)
        # Check results
        y_test = y_test.to(dtype=torch.float64, device="cpu")
@@ -1298,7 +1316,7 @@ class TestBasicOps:
        in_shape = list(in_shape)[:-1] + list(weight_shape)
        # Skip invalid configurations
-        maybe_skip_quantization(quantization, dims=in_shape, device=device)
+        maybe_skip_quantization(quantization, dims=in_shape, device=device, dtype=dtype)
        # Random data
        x_ref, x_test = make_reference_and_test_tensors(
@@ -1344,14 +1362,14 @@ class TestBasicOps:
            op,
            te_ops.Quantize(forward=quantized_compute, backward=False),
        )
-        with te.fp8_autocast(enabled=quantized_compute, fp8_recipe=recipe):
+        with te.autocast(enabled=quantized_compute, recipe=recipe):
            y_test = forward(x_test)
        y_test.backward(dy_test)
        # Expected numerical error
        tols = dtype_tols(dtype)
        if quantized_compute:
-            tols = dtype_tols(tex.DType.kFloat8E4M3)
+            tols = quantization_tols(quantization)
        # Check results
        y_test = y_test.to(dtype=torch.float64, device="cpu")
@@ -1431,7 +1449,7 @@ class TestBasicOps:
        # Skip invalid configurations
        with_quantization = quantization is not None
-        maybe_skip_quantization(quantization, dims=in_shape, device=device)
+        maybe_skip_quantization(quantization, dims=in_shape, device=device, dtype=dtype)
        # Random data
        x1_ref, x1_test = make_reference_and_test_tensors(
@@ -1470,8 +1488,11 @@ class TestBasicOps:
        # Check results
        tols = dtype_tols(dtype)
-        if with_quantization:
-            tols = dtype_tols(x1_test._fp8_dtype)
+        if in_place:
+            if quantization in ("fp8_delayed_scaling", "fp8_current_scaling", "mxfp8"):
+                tols = dtype_tols(x1_test._fp8_dtype)
+            elif quantization == "nvfp4":
+                tols = dtype_tols(x1_test._fp4_dtype)
        y_test = y_test.to(dtype=torch.float64, device="cpu")
        dx1_test = x1_test.grad.to(dtype=torch.float64, device="cpu")
        dx2_test = x2_test.grad.to(dtype=torch.float64, device="cpu")
@@ -1500,7 +1521,7 @@ class TestBasicOps:
        # Skip invalid configurations
        with_quantization = quantization is not None
-        maybe_skip_quantization(quantization, dims=in_shape, device=device)
+        maybe_skip_quantization(quantization, dims=in_shape, device=device, dtype=dtype)
        # Random data
        x_ref, x_test = make_reference_and_test_tensors(
@@ -1573,7 +1594,7 @@ class TestBasicOps:
        # Skip invalid configurations
        quantized_compute = quantization is not None
-        maybe_skip_quantization(quantization, dims=in_shape, device=device)
+        maybe_skip_quantization(quantization, dims=in_shape, device=device, dtype=dtype)
        if cache_quantized_input:
            maybe_skip_quantization("fp8_current_scaling", device=device)
@@ -1641,14 +1662,16 @@ class TestBasicOps:
            make_op(cache_quantized_input=cache_quantized_input),
            te_ops.Quantize(forward=quantized_compute, backward=False),
        )
-        with te.fp8_autocast(enabled=quantized_compute, fp8_recipe=recipe):
+        with te.autocast(enabled=quantized_compute, recipe=recipe):
            y_test = forward(x_test)
        y_test.backward(dy_test)
        # Expected numerical error
        tols = dtype_tols(dtype)
-        if quantized_compute or cache_quantized_input:
-            tols = dtype_tols(tex.DType.kFloat8E4M3)
+        if quantized_compute:
+            tols = quantization_tols(quantization)
+        elif cache_quantized_input:
+            tols = quantization_tols("fp8_current_scaling")
        # Check results
        y_test = y_test.to(dtype=torch.float64, device="cpu")
@@ -1679,7 +1702,7 @@ class TestBasicOps:
        quantized_compute = quantization is not None
        if not quantized_compute and (quantize_forward or quantize_backward):
            pytest.skip("Quantization scheme has not been provided")
-        maybe_skip_quantization(quantization, dims=in_shape, device=device)
+        maybe_skip_quantization(quantization, dims=in_shape, device=device, dtype=dtype)
        # Random data
        x_ref, x_test = make_reference_and_test_tensors(
@@ -1706,13 +1729,87 @@ class TestBasicOps:
            te_ops.SwiGLU(),
            te_ops.Quantize(forward=quantize_forward, backward=False),
        )
-        with te.fp8_autocast(enabled=quantized_compute, fp8_recipe=recipe):
+        with te.autocast(enabled=quantized_compute, recipe=recipe):
            y_test = forward(x_test)
        y_test.backward(dy_test)
        # Expected numerical error
        tols = dtype_tols(dtype)
        if quantized_compute:
+            tols = quantization_tols(quantization)
+        # Check results
+        y_test = y_test.to(dtype=torch.float64, device="cpu")
+        dx_test = x_test.grad.to(dtype=torch.float64, device="cpu")
+        torch.testing.assert_close(y_test, y_ref, **tols)
+        torch.testing.assert_close(dx_test, x_ref.grad, **tols)
+    @pytest.mark.parametrize("dtype", _dtypes)
+    @pytest.mark.parametrize("quantization", _quantization_list)
+    @pytest.mark.parametrize("quantize_forward", (False, True))
+    @pytest.mark.parametrize("quantize_backward", (False, True))
+    def test_clamped_swiglu(
+        self,
+        *,
+        out_shape: Iterable[int] = (32, 32),
+        dtype: torch.dtype,
+        device: torch.device = "cuda",
+        quantization: Optional[str],
+        quantize_forward: bool,
+        quantize_backward: bool,
+        limit: float = 0.75,
+        alpha: float = 1.702,
+    ):
+        # Test SwiGLU variant used in GPT OSS.
+        # Tensor dimensions
+        in_shape = list(out_shape)
+        in_shape[-1] *= 2
+        # Skip invalid configurations
+        quantized_compute = quantization is not None
+        if not quantized_compute and (quantize_forward or quantize_backward):
+            pytest.skip("Quantization scheme has not been provided")
+        maybe_skip_quantization(quantization, dims=in_shape, device=device)
+        # Random data
+        x_ref, x_test = make_reference_and_test_tensors(
+            in_shape,
+            test_dtype=dtype,
+            test_device=device,
+        )
+        dy_ref, dy_test = make_reference_and_test_tensors(
+            out_shape,
+            test_dtype=dtype,
+            test_device=device,
+            requires_grad=False,
+        )
+        # Plain PyTorch implementation
+        x_glu, x_linear = x_ref.chunk(2, dim=-1)
+        x_glu = x_glu.clamp(min=None, max=limit)
+        x_linear = x_linear.clamp(min=-limit, max=limit)
+        out_glu = x_glu * torch.sigmoid(alpha * x_glu)
+        y_ref = out_glu * (x_linear + 1)
+        y_ref.backward(dy_ref)
+        # Implementation with fusible operation
+        recipe = make_recipe(quantization)
+        forward = te_ops.Sequential(
+            te_ops.Quantize(forward=False, backward=quantize_backward),
+            te_ops.ClampedSwiGLU(limit=limit, alpha=alpha),
+            te_ops.Quantize(forward=quantize_forward, backward=False),
+        )
+        with te.autocast(enabled=quantized_compute, recipe=recipe):
+            y_test = forward(x_test)
+        y_test.backward(dy_test)
+        # Expected numerical error
+        tols = dtype_tols(dtype)
+        if quantized_compute and quantization == "nvfp4":
+            tols = dtype_tols(tex.DType.kFloat4E2M1)
+        elif quantized_compute:
            tols = dtype_tols(tex.DType.kFloat8E4M3)
        # Check results
@@ -1781,7 +1878,7 @@ class TestBasicOps:
        # Skip invalid configurations
        quantized_input = quantization is not None
-        maybe_skip_quantization(quantization, dims=shape, device=device)
+        maybe_skip_quantization(quantization, dims=shape, device=device, dtype=dtype)
        # Random data
        # Note: Shift values to make sure inputs are non-zero
@@ -1872,7 +1969,7 @@ class TestFusedOps:
        # Skip invalid configurations
        quantized_compute = quantization is not None
-        maybe_skip_quantization(quantization, dims=in_shape, device=device)
+        maybe_skip_quantization(quantization, dims=in_shape, device=device, dtype=dtype)
        maybe_skip_quantization(quantization, dims=out_shape)
        if dtype not in (torch.float16, torch.bfloat16):
            pytest.skip(
@@ -1913,7 +2010,7 @@ class TestFusedOps:
        # Implementation with fusible operations
        recipe = make_recipe(quantization)
-        with te.fp8_model_init(enabled=quantized_compute, recipe=recipe):
+        with te.quantized_model_init(enabled=quantized_compute, recipe=recipe):
            model = te_ops.Sequential(
                te_ops.Linear(
                    in_features,
@@ -1929,7 +2026,7 @@ class TestFusedOps:
            model[0].bias.copy_(b_test)
        del w_test
        del b_test
-        with te.fp8_autocast(enabled=quantized_compute, fp8_recipe=recipe):
+        with te.autocast(enabled=quantized_compute, recipe=recipe):
            y_test = model(x_test)
        y_test.backward(dy_test)
@@ -1943,7 +2040,7 @@ class TestFusedOps:
        if dtype == torch.float32:
            tols = dtype_tols(torch.float16)  # TF32 GEMM
        if quantized_compute:
-            tols = dtype_tols(tex.DType.kFloat8E4M3)
+            tols = quantization_tols(quantization)
        # Check results
        y_test = y_test.to(dtype=torch.float64, device="cpu")
@@ -1979,7 +2076,7 @@ class TestFusedOps:
        # Skip invalid configurations
        quantized_compute = quantization is not None
-        maybe_skip_quantization(quantization, dims=in_shape, device=device)
+        maybe_skip_quantization(quantization, dims=in_shape, device=device, dtype=dtype)
        maybe_skip_quantization(quantization, dims=out_shape)
        if quantized_compute and dtype not in (torch.float16, torch.bfloat16):
            pytest.skip("FP8 GEMM is only supported with FP8, FP16, or BF16 output")
@@ -2023,7 +2120,7 @@ class TestFusedOps:
        # Implementation with fusible operations
        recipe = make_recipe(quantization)
-        with te.fp8_model_init(enabled=quantized_weight, recipe=recipe):
+        with te.quantized_model_init(enabled=quantized_weight, recipe=recipe):
            model = te_ops.Sequential(
                te_ops.Linear(
                    in_features,
@@ -2040,7 +2137,7 @@ class TestFusedOps:
            model[0].bias.copy_(b_test)
        del w_test
        del b_test
-        with te.fp8_autocast(enabled=quantized_compute, fp8_recipe=recipe):
+        with te.autocast(enabled=quantized_compute, recipe=recipe):
            y_test = model(x1_test, x2_test)
        y_test.backward(dy_test)
@@ -2054,7 +2151,7 @@ class TestFusedOps:
        if dtype == torch.float32:
            tols = dtype_tols(torch.float16)  # TF32 GEMM
        if quantized_compute:
-            tols = dtype_tols(tex.DType.kFloat8E4M3)
+            tols = quantization_tols(quantization)
        # Check results
        y_test = y_test.to(dtype=torch.float64, device="cpu")
@@ -2093,7 +2190,7 @@ class TestFusedOps:
        # Skip invalid configurations
        quantized_compute = quantization is not None
-        maybe_skip_quantization(quantization, dims=in_shape, device=device)
+        maybe_skip_quantization(quantization, dims=in_shape, device=device, dtype=dtype)
        maybe_skip_quantization(quantization, dims=out_shape)
        if quantized_compute and dtype not in (torch.float16, torch.bfloat16):
            pytest.skip("FP8 GEMM is only supported with FP8, FP16, or BF16 output")
@@ -2130,7 +2227,7 @@ class TestFusedOps:
        # Implementation with fusible operations
        recipe = make_recipe(quantization)
-        with te.fp8_model_init(enabled=quantized_weight, recipe=recipe):
+        with te.quantized_model_init(enabled=quantized_weight, recipe=recipe):
            model = te_ops.Sequential(
                te_ops.Linear(
                    in_features,
@@ -2146,7 +2243,7 @@ class TestFusedOps:
        with torch.no_grad():
            model[0].weight.copy_(w_test)
        del w_test
-        with te.fp8_autocast(enabled=quantized_compute, fp8_recipe=recipe):
+        with te.autocast(enabled=quantized_compute, recipe=recipe):
            y_test = model(x1_test, x2_test)
        y_test.backward(dy_test)
@@ -2161,7 +2258,7 @@ class TestFusedOps:
        if dtype == torch.float32:
            tols = dtype_tols(torch.float16)  # TF32 GEMM
        if quantized_compute:
-            tols = dtype_tols(tex.DType.kFloat8E4M3)
+            tols = quantization_tols(quantization)
        # Check results
        y_test = y_test.to(dtype=torch.float64, device="cpu")
@@ -2194,7 +2291,7 @@ class TestFusedOps:
        # Skip invalid configurations
        with_quantization = quantization is not None
-        maybe_skip_quantization(quantization, device=device)
+        maybe_skip_quantization(quantization, device=device, dtype=dtype)
        if quantization == "mxfp8" and (len(in_shape) < 2 or in_shape[-1] % 32 != 0):
            pytest.skip("Unsupported tensor size for MXFP8")
@@ -2237,7 +2334,7 @@ class TestFusedOps:
        with torch.no_grad():
            model[1].bias.copy_(b_test)
        del b_test
-        with te.fp8_autocast(enabled=with_quantization, fp8_recipe=recipe):
+        with te.autocast(enabled=with_quantization, recipe=recipe):
            y_test = model(x_test)
        y_test.backward(dy_test)
@@ -2256,7 +2353,7 @@ class TestFusedOps:
        # Expected numerical error
        tols = dtype_tols(dtype)
        if with_quantization:
-            tols = dtype_tols(tex.DType.kFloat8E4M3)
+            tols = quantization_tols(quantization)
        # Check results
        y_test = y_test.to(dtype=torch.float64, device="cpu")
@@ -2375,7 +2472,7 @@ class TestFusedOps:
        # Skip invalid configurations
        quantized_compute = quantization is not None
-        maybe_skip_quantization(quantization, dims=in_shape, device=device)
+        maybe_skip_quantization(quantization, dims=in_shape, device=device, dtype=dtype)
        maybe_skip_quantization(quantization, dims=out_shape)
        if quantized_compute and dtype not in (torch.float16, torch.bfloat16):
            pytest.skip("FP8 GEMM is only supported with FP8, FP16, or BF16 output")
@@ -2415,7 +2512,7 @@
        # Implementation with fusible operations
        recipe = make_recipe(quantization)
-        with te.fp8_model_init(enabled=quantized_weight):
+        with te.quantized_model_init(enabled=quantized_weight):
            model = te_ops.Sequential(
                te_ops.MakeExtraOutput(in_place=True),
                te_ops.Linear(
@@ -2429,7 +2526,7 @@
        with torch.no_grad():
            model[1].weight.copy_(w_test)
        del w_test
-        with te.fp8_autocast(enabled=quantized_compute, fp8_recipe=recipe):
+        with te.autocast(enabled=quantized_compute, recipe=recipe):
            y1_test, y2_test = model(x_test)
        (y1_test * dy1_test + y2_test * dy2_test).sum().backward()
@@ -2443,7 +2540,7 @@
        if dtype == torch.float32:
            tols = dtype_tols(torch.float16)  # TF32 GEMM
        if quantized_compute:
-            tols = dtype_tols(tex.DType.kFloat8E4M3)
+            tols = quantization_tols(quantization)
        # Check results
        y1_test = y1_test.to(dtype=torch.float64, device="cpu")
@@ -2479,7 +2576,7 @@
        # Skip invalid configurations
        quantized_compute = quantization is not None
-        maybe_skip_quantization(quantization, dims=in_shape, device=device)
+        maybe_skip_quantization(quantization, dims=in_shape, device=device, dtype=dtype)
        maybe_skip_quantization(quantization, dims=out_shape)
        if quantized_compute and dtype not in (torch.float16, torch.bfloat16):
            pytest.skip("FP8 GEMM is only supported with FP8, FP16, or BF16 output")
@@ -2511,7 +2608,7 @@
        # Implementation with fusible operations
        recipe = make_recipe(quantization)
-        with te.fp8_model_init(enabled=quantized_weight):
+        with te.quantized_model_init(enabled=quantized_weight):
            model = te_ops.Sequential(
                te_ops.Linear(
                    in_features,
@@ -2525,7 +2622,7 @@
        with torch.no_grad():
            model[0].weight.copy_(w_test)
        del w_test
-        with te.fp8_autocast(enabled=quantized_compute, fp8_recipe=recipe):
+        with te.autocast(enabled=quantized_compute, recipe=recipe):
            y_test = model(x_test)
        (y_test * dy_test).sum().backward()
@@ -2539,7 +2636,7 @@
        if dtype == torch.float32:
            tols = dtype_tols(torch.float16)  # TF32 GEMM
        if quantized_compute:
-            tols = dtype_tols(tex.DType.kFloat8E4M3)
+            tols = quantization_tols(quantization)
        # Check results
        y_test = y_test.to(dtype=torch.float64, device="cpu")
...@@ -2580,12 +2677,12 @@ class TestCheckpointing: ...@@ -2580,12 +2677,12 @@ class TestCheckpointing:
# Skip invalid configurations # Skip invalid configurations
quantized_compute = quantization is not None quantized_compute = quantization is not None
maybe_skip_quantization(quantization, dims=in_shape, device=device) maybe_skip_quantization(quantization, dims=in_shape, device=device, dtype=dtype)
maybe_skip_quantization(quantization, dims=out_shape) maybe_skip_quantization(quantization, dims=out_shape)
# Construct model # Construct model
recipe = make_recipe(quantization) recipe = make_recipe(quantization)
with te.fp8_model_init(enabled=quantized_weight, recipe=recipe): with te.quantized_model_init(enabled=quantized_weight, recipe=recipe):
model_save = te_ops.Sequential( model_save = te_ops.Sequential(
te_ops.Linear(in_features, out_features, device=device, dtype=dtype) te_ops.Linear(in_features, out_features, device=device, dtype=dtype)
) )
...@@ -2596,7 +2693,7 @@ class TestCheckpointing: ...@@ -2596,7 +2693,7 @@ class TestCheckpointing:
x = torch.randn(in_shape, dtype=dtype, device=device, requires_grad=True) x = torch.randn(in_shape, dtype=dtype, device=device, requires_grad=True)
dy = torch.randn(out_shape, dtype=dtype, device=device) dy = torch.randn(out_shape, dtype=dtype, device=device)
optim_save.zero_grad() optim_save.zero_grad()
with te.fp8_autocast(enabled=quantized_compute, fp8_recipe=recipe): with te.autocast(enabled=quantized_compute, recipe=recipe):
y = model_save(x) y = model_save(x)
y.backward(dy) y.backward(dy)
optim_save.step() optim_save.step()
...@@ -2625,14 +2722,14 @@ class TestCheckpointing: ...@@ -2625,14 +2722,14 @@ class TestCheckpointing:
ys_save = [] ys_save = []
for i in range(post_checkpoint_steps): for i in range(post_checkpoint_steps):
optim_save.zero_grad() optim_save.zero_grad()
with te.fp8_autocast(enabled=quantized_compute, fp8_recipe=recipe): with te.autocast(enabled=quantized_compute, recipe=recipe):
y = model_save(xs_save[i]) y = model_save(xs_save[i])
y.backward(dys[i]) y.backward(dys[i])
optim_save.step() optim_save.step()
ys_save.append(y) ys_save.append(y)
# Load checkpoint # Load checkpoint
with te.fp8_model_init(enabled=quantized_weight, recipe=recipe): with te.quantized_model_init(enabled=quantized_weight, recipe=recipe):
model_load = te_ops.Sequential( model_load = te_ops.Sequential(
te_ops.Linear(in_features, out_features, device=device, dtype=dtype) te_ops.Linear(in_features, out_features, device=device, dtype=dtype)
) )
...@@ -2645,7 +2742,7 @@ class TestCheckpointing: ...@@ -2645,7 +2742,7 @@ class TestCheckpointing:
ys_load = [] ys_load = []
for i in range(post_checkpoint_steps): for i in range(post_checkpoint_steps):
optim_load.zero_grad() optim_load.zero_grad()
with te.fp8_autocast(enabled=quantized_compute, fp8_recipe=recipe): with te.autocast(enabled=quantized_compute, recipe=recipe):
y = model_load(xs_load[i]) y = model_load(xs_load[i])
y.backward(dys[i]) y.backward(dys[i])
optim_load.step() optim_load.step()
...@@ -2706,7 +2803,7 @@ class TestSequentialModules: ...@@ -2706,7 +2803,7 @@ class TestSequentialModules:
ffn_shape = in_shape[:-1] + (ffn_hidden_size,) ffn_shape = in_shape[:-1] + (ffn_hidden_size,)
# Skip invalid configurations # Skip invalid configurations
maybe_skip_quantization(quantization, dims=in_shape, device=device) maybe_skip_quantization(quantization, dims=in_shape, device=device, dtype=dtype)
maybe_skip_quantization(quantization, dims=ffn_shape, device=device) maybe_skip_quantization(quantization, dims=ffn_shape, device=device)
quantization_needed = quantized_compute or quantized_weight quantization_needed = quantized_compute or quantized_weight
if quantization is None and quantization_needed: if quantization is None and quantization_needed:
...@@ -2732,7 +2829,7 @@ class TestSequentialModules: ...@@ -2732,7 +2829,7 @@ class TestSequentialModules:
# Implementation with fusible operations # Implementation with fusible operations
recipe = make_recipe(quantization) recipe = make_recipe(quantization)
with te.fp8_model_init(enabled=quantized_weight, recipe=recipe): with te.quantized_model_init(enabled=quantized_weight, recipe=recipe):
if normalization == "LayerNorm": if normalization == "LayerNorm":
norm = te_ops.LayerNorm( norm = te_ops.LayerNorm(
hidden_size, hidden_size,
...@@ -2763,6 +2860,6 @@ class TestSequentialModules: ...@@ -2763,6 +2860,6 @@ class TestSequentialModules:
dtype=dtype, dtype=dtype,
) )
forward = te_ops.Sequential(norm, ffn1, act, ffn2) forward = te_ops.Sequential(norm, ffn1, act, ffn2)
with te.fp8_autocast(enabled=quantized_compute, fp8_recipe=recipe): with te.autocast(enabled=quantized_compute, recipe=recipe):
y_test = forward(x_test) y_test = forward(x_test)
y_test.backward(dy_test) y_test.backward(dy_test)
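Note (illustrative, not part of the diff): the hunks above replace te.fp8_model_init/te.fp8_autocast with the renamed, quantization-agnostic entry points. A minimal standalone sketch of the new pattern, assuming an FP8-capable GPU; the DelayedScaling recipe and layer sizes are placeholders for whatever make_recipe(quantization) builds in the tests.

import torch
import transformer_engine.pytorch as te
from transformer_engine.common.recipe import DelayedScaling

recipe = DelayedScaling()  # stand-in for the recipe the tests build via make_recipe()

# Renamed from te.fp8_model_init: parameters are created in quantized form when enabled.
with te.quantized_model_init(enabled=True, recipe=recipe):
    model = te.Linear(1024, 1024, device="cuda")

x = torch.randn(32, 1024, device="cuda", requires_grad=True)

# Renamed from te.fp8_autocast(..., fp8_recipe=recipe): compute runs under the given recipe.
with te.autocast(enabled=True, recipe=recipe):
    y = model(x)
y.sum().backward()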
@@ -6,7 +6,7 @@ import pytest
from transformers.configuration_utils import PretrainedConfig
from transformers.modeling_utils import PreTrainedModel
-from transformer_engine.pytorch.transformer import TransformerLayer
+from transformer_engine.pytorch import TransformerLayer
class SimpleTEModel(PreTrainedModel):
...
@@ -5,7 +5,7 @@
import pytest
import torch
-import transformer_engine.pytorch as te
+import transformer_engine.pytorch
import transformer_engine_torch as tex
from transformer_engine.pytorch.optimizers import MultiTensorApply
...
@@ -13,18 +13,15 @@ import torch.nn as nn
from torch.nn import Parameter
from torch.utils.cpp_extension import IS_HIP_EXTENSION
-from transformer_engine.pytorch.fp8 import (
-    FP8GlobalStateManager,
-    fp8_autocast,
-    fp8_model_init,
-)
+from transformer_engine.pytorch.quantization import FP8GlobalStateManager
from transformer_engine.pytorch.utils import (
    init_method_normal,
    scaled_init_method_normal,
    attention_mask_func,
-    is_bf16_compatible,
)
from transformer_engine.pytorch import (
+    autocast,
+    quantized_model_init,
    DotProductAttention,
    LayerNormLinear,
    LayerNormMLP,
@@ -36,27 +33,29 @@ from transformer_engine.pytorch import (
    LayerNorm,
    Fp8Padding,
    Fp8Unpadding,
+    Float8Quantizer,
+    Float8CurrentScalingQuantizer,
+    MXFP8Quantizer,
+    get_device_compute_capability,
+    is_fp8_available,
+    is_mxfp8_available,
+    is_fp8_block_scaling_available,
+    is_bf16_available,
)
from transformer_engine.pytorch import torch_version
-from transformer_engine.pytorch.distributed import checkpoint as te_checkpoint
+from transformer_engine.pytorch import checkpoint as te_checkpoint
from transformer_engine.pytorch.cpp_extensions import general_gemm, general_grouped_gemm
from transformer_engine.pytorch.cpp_extensions.fused_attn import FusedAttnBackend
-from transformer_engine.pytorch.tensor.float8_tensor import (
-    Float8Quantizer,
-    Float8CurrentScalingQuantizer,
-)
-from transformer_engine.pytorch.tensor.mxfp8_tensor import MXFP8Quantizer
from transformer_engine.pytorch.module.base import get_multi_stream_cublas_workspace, get_workspace
-from transformer_engine.pytorch.utils import get_device_compute_capability
from transformer_engine.common import recipe
import transformer_engine_torch as tex
from utils import ModelConfig, reset_rng_states, get_available_attention_backends
# Only run FP8 tests on supported devices.
-fp8_available, reason_for_no_fp8 = FP8GlobalStateManager.is_fp8_available()
+fp8_available, reason_for_no_fp8 = is_fp8_available(return_reason=True)
-mxfp8_available, reason_for_no_mxfp8 = FP8GlobalStateManager.is_mxfp8_available()
+mxfp8_available, reason_for_no_mxfp8 = is_mxfp8_available(return_reason=True)
-fp8_block_scaling_available, reason_for_no_fp8_block_scaling = FP8GlobalStateManager.is_fp8_block_scaling_available()
+fp8_block_scaling_available, reason_for_no_fp8_block_scaling = is_fp8_block_scaling_available(return_reason=True)
sm_80plus = get_device_compute_capability() >= (8, 0)
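Note (illustrative, not part of the diff): a minimal sketch of how these availability flags and reason strings are typically consumed, using the helpers imported directly from transformer_engine.pytorch in the new import block above; the test names here are placeholders.

import pytest
from transformer_engine.pytorch import is_fp8_available, is_mxfp8_available

fp8_available, reason_for_no_fp8 = is_fp8_available(return_reason=True)
mxfp8_available, reason_for_no_mxfp8 = is_mxfp8_available(return_reason=True)

# Skip whole test functions when the current GPU lacks the required support.
@pytest.mark.skipif(not fp8_available, reason=reason_for_no_fp8)
def test_fp8_path():
    ...

@pytest.mark.skipif(not mxfp8_available, reason=reason_for_no_mxfp8)
def test_mxfp8_path():
    ...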
@@ -82,7 +81,7 @@ module_inference = ["TransformerLayer", "MultiheadAttention"]
input_formats_inference = ["sbhd", "bshd"]
param_types = [torch.float32, torch.float16]
-if is_bf16_compatible():  # bf16 requires sm_80 or higher
+if is_bf16_available():  # bf16 requires sm_80 or higher
param_types.append(torch.bfloat16)
batch_sizes = [1, 2]
@@ -553,7 +552,7 @@ def _test_e2e_selective_recompute(
init_method = init_method_normal(sigma)
output_layer_init_method = scaled_init_method_normal(sigma, config.num_layers)
-with fp8_model_init(enabled=fp8 and fp8_model_params, recipe=recipe):
+with quantized_model_init(enabled=fp8 and fp8_model_params, recipe=recipe):
block = TransformerLayer(
config.hidden_size,
4 * config.hidden_size,
@@ -580,7 +579,7 @@ def _test_e2e_selective_recompute(
te_inp_hidden_states.retain_grad()
te_inp_attn_mask = get_causal_attn_mask(config.max_seqlen_q)
-with fp8_autocast(enabled=fp8, fp8_recipe=recipe):
+with autocast(enabled=fp8, recipe=recipe):
te_out = block(
te_inp_hidden_states,
attention_mask=te_inp_attn_mask,
@@ -649,7 +648,7 @@ def _test_e2e_full_recompute(
init_method = init_method_normal(sigma)
output_layer_init_method = scaled_init_method_normal(sigma, config.num_layers)
-with fp8_model_init(enabled=fp8 and fp8_model_params, recipe=recipe):
+with quantized_model_init(enabled=fp8 and fp8_model_params, recipe=recipe):
block = TransformerLayer(
config.hidden_size,
4 * config.hidden_size,
@@ -677,7 +676,7 @@ def _test_e2e_full_recompute(
te_inp_hidden_states.retain_grad()
te_inp_attn_mask = get_causal_attn_mask(config.max_seqlen_q)
-with fp8_autocast(enabled=fp8, fp8_recipe=recipe):
+with autocast(enabled=fp8, recipe=recipe):
if recompute:
te_out = te_checkpoint(
block,
@@ -1107,7 +1106,7 @@ def _test_granular_accuracy(block, bs, dtype, config, delay_wgrad_compute=False,
)
inp_hidden_states.retain_grad()
-with fp8_autocast(enabled=fp8, fp8_recipe=recipe):
+with autocast(enabled=fp8, recipe=recipe):
out = block(inp_hidden_states)
if isinstance(out, (List, Tuple)):
out = out[0]
@@ -1328,7 +1327,7 @@ def test_linear_accuracy_save_original_input(dtype, model, recipe):
if config.max_seqlen_q % 16 != 0 and fp8:
pytest.skip("FP8 requires sequence length to be divisible by 16.")
-with fp8_model_init(enabled=fp8 and fp8_model_params, recipe=recipe):
+with quantized_model_init(enabled=fp8 and fp8_model_params, recipe=recipe):
te_linear_ref = Linear(
config.hidden_size,
4 * config.hidden_size,
@@ -1782,7 +1781,7 @@ def _test_grouped_linear_accuracy(
else:
m_splits = torch.tensor([config.max_seqlen_q])
-with fp8_autocast(enabled=fp8, fp8_recipe=recipe):
+with autocast(enabled=fp8, recipe=recipe):
if isinstance(block, GroupedLinear):
m_splits = m_splits * bs
out = block(inp_hidden_states, m_splits.tolist())
@@ -1850,7 +1849,7 @@ def test_grouped_linear_accuracy(
if config.max_seqlen_q % 16 != 0 and fp8:
pytest.skip("FP8 requires sequence length to be divisible by 16.")
-with fp8_model_init(enabled=fp8 and fp8_model_params, recipe=recipe):
+with quantized_model_init(enabled=fp8 and fp8_model_params, recipe=recipe):
grouped_linear = GroupedLinear(
num_gemms,
config.hidden_size,
@@ -1994,7 +1993,7 @@ def test_grouped_linear_accuracy_save_original_input(
if config.max_seqlen_q % 16 != 0 and fp8:
pytest.skip("FP8 requires sequence length to be divisible by 16.")
-with fp8_model_init(enabled=fp8 and fp8_model_params, recipe=recipe):
+with quantized_model_init(enabled=fp8 and fp8_model_params, recipe=recipe):
grouped_linear = GroupedLinear(
num_gemms,
config.hidden_size,
@@ -2154,7 +2153,7 @@ def _test_padding_grouped_linear_accuracy(block, num_gemms, bs, dtype, config, r
m_splits = _generate_random_numbers(num_gemms, config.max_seqlen_q * bs)
-with fp8_autocast(enabled=fp8, fp8_recipe=recipe):
+with autocast(enabled=fp8, recipe=recipe):
if isinstance(block, TorchGroupedLinearWithPadding):
out = block(inp_hidden_states, m_splits)
else:
@@ -2208,7 +2207,7 @@ def test_padding_grouped_linear_accuracy(
if config.max_seqlen_q % 16 != 0 and fp8:
pytest.skip("FP8 requires sequence length to be divisible by 16.")
-with fp8_model_init(enabled=fp8 and fp8_model_params, recipe=recipe):
+with quantized_model_init(enabled=fp8 and fp8_model_params, recipe=recipe):
grouped_linear = TorchGroupedLinearWithPadding(
num_gemms,
config.hidden_size,
@@ -2219,7 +2218,7 @@ def test_padding_grouped_linear_accuracy(
fp8=fp8,
).eval()
-with fp8_model_init(enabled=fp8 and fp8_model_params, recipe=recipe):
+with quantized_model_init(enabled=fp8 and fp8_model_params, recipe=recipe):
ref_grouped_linear = GroupedLinear(
num_gemms,
config.hidden_size,
@@ -2285,7 +2284,7 @@ def test_padding_grouped_linear_accuracy_save_original_input(
if config.max_seqlen_q % 16 != 0 and fp8:
pytest.skip("FP8 requires sequence length to be divisible by 16.")
-with fp8_model_init(enabled=fp8 and fp8_model_params, recipe=recipe):
+with quantized_model_init(enabled=fp8 and fp8_model_params, recipe=recipe):
grouped_linear = TorchGroupedLinearWithPadding(
num_gemms,
config.hidden_size,
@@ -2296,7 +2295,7 @@ def test_padding_grouped_linear_accuracy_save_original_input(
fp8=fp8,
).eval()
-with fp8_model_init(enabled=fp8 and fp8_model_params, recipe=recipe):
+with quantized_model_init(enabled=fp8 and fp8_model_params, recipe=recipe):
ref_grouped_linear = GroupedLinear(
num_gemms,
config.hidden_size,
@@ -2446,7 +2445,7 @@ def _test_gpt_fp8_parameters(bs, dtype, config, fp8_model_params, recipe):
init_method = init_method_normal(sigma)
output_layer_init_method = scaled_init_method_normal(sigma, config.num_layers)
-with fp8_model_init(enabled=fp8_model_params, recipe=recipe):
+with quantized_model_init(enabled=fp8_model_params, recipe=recipe):
block = TransformerLayer(
config.hidden_size,
4 * config.hidden_size,
@@ -2473,7 +2472,7 @@ def _test_gpt_fp8_parameters(bs, dtype, config, fp8_model_params, recipe):
te_inp_hidden_states.retain_grad()
te_inp_attn_mask = get_causal_attn_mask(config.max_seqlen_q)
-with fp8_autocast(enabled=True, fp8_recipe=recipe):
+with autocast(enabled=True, recipe=recipe):
te_out = block(te_inp_hidden_states, attention_mask=te_inp_attn_mask)
loss = te_out.sum()
loss.backward()
...
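Note (illustrative, not part of the diff): several hunks above exercise GroupedLinear, which runs one GEMM per group over a row-wise split of the input. A minimal high-precision sketch with hypothetical sizes and split values; the real tests derive m_splits from config.max_seqlen_q and the batch size.

import torch
from transformer_engine.pytorch import GroupedLinear

num_gemms, hidden = 4, 256
layer = GroupedLinear(num_gemms, hidden, hidden, device="cuda")

# One split size per GEMM; the splits must sum to the number of input rows.
m_splits = [32, 64, 16, 48]
x = torch.randn(sum(m_splits), hidden, device="cuda", requires_grad=True)

y = layer(x, m_splits)
y.sum().backward()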
@@ -33,10 +33,9 @@ from onnxruntime_extensions import PyCustomOpDef, get_library_path, onnx_op
import transformer_engine.pytorch as te
from transformer_engine.common import recipe
import transformer_engine_torch as tex
-from transformer_engine.pytorch.onnx_extensions import te_translation_table
from torch.utils.cpp_extension import IS_HIP_EXTENSION
-from transformer_engine.pytorch.export import is_in_onnx_export_mode
+from transformer_engine.pytorch.export import is_in_onnx_export_mode, te_translation_table
-from transformer_engine.pytorch.fp8 import FP8GlobalStateManager
+from transformer_engine.pytorch.quantization import FP8GlobalStateManager
from transformer_engine.pytorch.utils import get_default_init_method
import tensorrt as trt
@@ -59,8 +58,8 @@ NVTE_TEST_ARTIFACTS_DIR = NVTE_TEST_ARTIFACTS_DIR or os.path.join(
# The directory where this file is stored.
TESTS_DIR = os.path.dirname(os.path.abspath(__file__))
-fp8_available, reason_for_no_fp8 = FP8GlobalStateManager.is_fp8_available()
+fp8_available, reason_for_no_fp8 = te.is_fp8_available(return_reason=True)
-mxfp8_available, reason_for_no_mxfp8 = FP8GlobalStateManager.is_mxfp8_available()
+mxfp8_available, reason_for_no_mxfp8 = te.is_mxfp8_available(return_reason=True)
fp8_recipes = []
if mxfp8_available:
@@ -179,8 +178,8 @@ def do_export(
input_names = input_names or ["input"]
output_names = output_names or ["output"]
-with torch.inference_mode(), te.fp8_autocast(
-    enabled=fp8_recipe is not None, fp8_recipe=fp8_recipe
+with torch.inference_mode(), te.autocast(
+    enabled=fp8_recipe is not None, recipe=fp8_recipe
), warnings.catch_warnings():
warnings.filterwarnings(action="ignore", category=torch.jit.TracerWarning, module=r".*")
@@ -234,8 +233,8 @@ def te_infer(
fp8_recipe: recipe.Recipe,
):
"""Transformer Engine forward propagation."""
-with torch.inference_mode(), te.fp8_autocast(
-    enabled=is_fp8, fp8_recipe=fp8_recipe
+with torch.inference_mode(), te.autocast(
+    enabled=is_fp8, recipe=fp8_recipe
), warnings.catch_warnings():
te_outputs = model(*inps if isinstance(inps, tuple) else (inps,))
if not isinstance(te_outputs, tuple):
@@ -441,7 +440,7 @@ def _test_export_linear(
bias_str = "_bias" if use_bias else ""
high_prec_str = dtype2str(precision)
fname = f"te.linear{fp8_str}{bias_str}{high_prec_str}.onnx"
-with te.fp8_autocast(enabled=fp8_recipe is not None, fp8_recipe=fp8_recipe):
+with te.autocast(enabled=fp8_recipe is not None, recipe=fp8_recipe):
model = Test_Linear(in_features, out_features, use_bias, return_bias, precision).to(
device="cuda"
)
@@ -507,7 +506,7 @@ def _test_export_layernorm(
fname = f"te.layernorm_linear{fp8_str}{high_prec_str}.onnx"
with torch.no_grad():
-with te.fp8_autocast(enabled=fp8_recipe is not None, fp8_recipe=fp8_recipe):
+with te.autocast(enabled=fp8_recipe is not None, recipe=fp8_recipe):
layernorm_cls = te.LayerNorm if normalization == "LayerNorm" else te.RMSNorm
model = layernorm_cls(
hidden_size,
@@ -577,7 +576,7 @@ def _test_export_layernorm_linear(
fname = f"te.layernorm_linear{fp8_str}{bias_str}{high_prec_str}.onnx"
with torch.no_grad():
-with te.fp8_autocast(enabled=fp8_recipe is not None, fp8_recipe=fp8_recipe):
+with te.autocast(enabled=fp8_recipe is not None, recipe=fp8_recipe):
model = te.LayerNormLinear(
hidden_size,
3 * hidden_size,
@@ -673,7 +672,7 @@ def _test_export_layernorm_mlp(
bias_str = "_bias" if use_bias else ""
high_prec_str = dtype2str(precision)
fname = f"te.layernorm_mlp{fp8_str}{bias_str}{high_prec_str}_{activation}.onnx"
-with te.fp8_autocast(enabled=fp8_recipe is not None, fp8_recipe=fp8_recipe):
+with te.autocast(enabled=fp8_recipe is not None, recipe=fp8_recipe):
model = te.LayerNormMLP(
hidden_size,
ffn_hidden_size,
@@ -1215,13 +1214,13 @@ def test_trt_integration(fp8_recipe: recipe.Recipe):
).eval()
inps = (torch.randn([16, 16, 128], device="cuda", requires_grad=False),)
-with te.fp8_autocast(enabled=fp8_recipe is not None, fp8_recipe=fp8_recipe):
+with te.autocast(enabled=fp8_recipe is not None, recipe=fp8_recipe):
out_ref = model(*inps)
onnx_fd, onnx_path = tempfile.mkstemp(suffix=".onnx")
os.close(onnx_fd)
try:
-with te.fp8_autocast(enabled=fp8_recipe is not None, fp8_recipe=fp8_recipe):
+with te.autocast(enabled=fp8_recipe is not None, recipe=fp8_recipe):
with te.onnx_export(enabled=True):
torch.onnx.export(
model,
...
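Note (illustrative, not part of the diff): the export paths above now wrap tracing in te.autocast instead of te.fp8_autocast. A minimal sketch of the high-precision export flow, with a hypothetical output path; extra export options (opset, translation table, input/output names) that the real tests pass are omitted here.

import torch
import transformer_engine.pytorch as te

model = te.LayerNorm(128, eps=1e-5, device="cuda").eval()
inps = (torch.randn(16, 128, device="cuda", requires_grad=False),)

# enabled=False: no recipe is given, so the quantized path stays off during tracing.
with torch.inference_mode(), te.autocast(enabled=False):
    with te.onnx_export(enabled=True):
        torch.onnx.export(model, inps, "te_layernorm.onnx")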