Commit 063ef88d authored by wenjh

Merge nv main up to v2.10.0.dev0


Signed-off-by: wenjh <wenjh@sugon.com>
parents 91670b05 5624dbb4
# Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# See LICENSE for license information.
import pytest
import torch
import transformer_engine.pytorch as te
import transformer_engine_torch as tex
from transformer_engine.pytorch import NVFP4Quantizer
from transformer_engine.pytorch.experimental.quantization_nvfp4 import NVFP4QuantizerRef
from transformer_engine.common.recipe import NVFP4BlockScaling
from transformer_engine.pytorch.constants import TE_DType
from transformer_engine.pytorch.experimental import utils
recipe_available, reason_for_no_recipe = te.is_nvfp4_available(return_reason=True)
def unpack_fp4(x: torch.Tensor) -> torch.Tensor:
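# Each uint8 packs two E2M1 codes: the low nibble holds the even output column and
# the high nibble holds the odd output column, so an (M, N//2) packed tensor expands
# to (M, N) 4-bit codes.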
repeated = x.repeat_interleave(2, dim=1)
repeated[:, 0::2] &= 0x0F
repeated[:, 1::2] >>= 4
return repeated
def check_quantization_nvfp4_versus_reference(
x_dtype: torch.dtype,
M: int,
N: int,
return_transpose: bool,
swizzled_scale: bool,
use_cpp_allocator: bool,
with_2d_quantization: bool,
) -> None:
te_dtype = tex.DType.kFloat4E2M1
# Setup device and random seed
device = "cuda"
seed = 0
torch.manual_seed(seed)
torch.cuda.manual_seed(seed)
# Input
x = torch.randn((M, N), dtype=x_dtype, device=device)
# Quantize
nvfp4_quantizer = NVFP4Quantizer(
fp4_dtype=te_dtype,
rowwise=True,
columnwise=return_transpose,
with_amax_reduction=False,
amax_reduction_group=None,
with_rht=False,
with_post_rht_amax=False,
with_2d_quantization=with_2d_quantization,
)
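# Exercise both allocation paths: calling the quantizer directly lets the C++
# extension allocate the output tensor, while the Python path pre-allocates with
# make_empty and fills it in place via update_quantized.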
if use_cpp_allocator:
x_nvfp4_sut = nvfp4_quantizer(x)
else:
x_nvfp4_sut = nvfp4_quantizer.make_empty(
(M, N), dtype=x_dtype, device=device, requires_grad=False
)
x_nvfp4_sut = nvfp4_quantizer.update_quantized(x, x_nvfp4_sut)
# Extract data from NVFP4Tensor
assert x_nvfp4_sut._rowwise_data is not None
qx: torch.Tensor = x_nvfp4_sut._rowwise_data.view(dtype=torch.uint8)
assert x_nvfp4_sut._rowwise_scale_inv is not None
sx: torch.Tensor = x_nvfp4_sut._rowwise_scale_inv
qx_t = (
x_nvfp4_sut._columnwise_data.view(dtype=torch.uint8)
if x_nvfp4_sut._columnwise_data is not None
else None
)
sx_t = x_nvfp4_sut._columnwise_scale_inv
qx_amax = x_nvfp4_sut._amax_rowwise
# Reference quantization
quant_tile_shape = (1, 16) if not with_2d_quantization else (16, 16)
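# (1, 16) scales each 16-element row micro-block independently, while (16, 16)
# shares a single scale factor across a 16x16 tile, matching the quantizer's
# with_2d_quantization flag above.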
ref_quantizer = NVFP4QuantizerRef(
dtype=utils.Fp4Formats.E2M1,
rowwise=True,
columnwise=return_transpose,
pow_2_scales=False,
eps=0.0,
quant_tile_shape=quant_tile_shape,
)
x_nvfp4_ref = ref_quantizer.quantize(x)
# Extract data from RefNVFP4Tensor
qx_ref = (
unpack_fp4(x_nvfp4_ref.data.view(dtype=torch.uint8))
if x_nvfp4_ref.data is not None
else None
)
sx_ref = x_nvfp4_ref.scale.view(dtype=torch.uint8) if x_nvfp4_ref.scale is not None else None
qx_t_ref = (
unpack_fp4(x_nvfp4_ref.data_t.view(dtype=torch.uint8))
if x_nvfp4_ref.data_t is not None
else None
)
sx_t_ref = (
x_nvfp4_ref.scale_t.view(dtype=torch.uint8) if x_nvfp4_ref.scale_t is not None else None
)
ref_amax = x_nvfp4_ref.global_amax_row
qx = unpack_fp4(qx)
qx_t = unpack_fp4(qx_t) if qx_t is not None else None
torch.testing.assert_close(qx, qx_ref, atol=0.0, rtol=0.0)
# Compare only the valid portion of scale tensors (reference may not have padding)
ref_sx_shape = sx_ref.shape
sx_valid = sx[: ref_sx_shape[0], : ref_sx_shape[1]]
torch.testing.assert_close(sx_valid, sx_ref, atol=0.0, rtol=0.0)
if return_transpose:
torch.testing.assert_close(qx_t, qx_t_ref, atol=0.0, rtol=0.0)
# Compare only the valid portion of transpose scale tensors
ref_sx_t_shape = sx_t_ref.shape
sx_t_valid = sx_t[: ref_sx_t_shape[0], : ref_sx_t_shape[1]]
torch.testing.assert_close(sx_t_valid, sx_t_ref, atol=0.0, rtol=0.0)
torch.testing.assert_close(qx_amax, ref_amax, atol=0.0, rtol=0.0)
@pytest.mark.skipif(not recipe_available, reason=reason_for_no_recipe)
@pytest.mark.parametrize(
"M, N",
[
# full tile cases
(128, 128),
(256, 256),
(256, 1024),
(1024, 256),
# Padding required cases
(256, 272),
(304, 304),
(320, 256),
# Some larger tiles
(2048, 2048),
(1024, 2048),
(2048, 1024),
# largest tile
(8192, 8192),
],
)
@pytest.mark.parametrize("x_dtype", [torch.float32, torch.bfloat16], ids=str)
@pytest.mark.parametrize(
"return_transpose", [True, False], ids=["quantize_transpose", "skip_transpose"]
)
@pytest.mark.parametrize("swizzled_scale", [False], ids=["linear_scale"])
@pytest.mark.parametrize(
"use_cpp_allocator", [True, False], ids=["cpp_allocator", "python_allocator"]
)
@pytest.mark.parametrize(
"with_2d_quantization", [True, False], ids=["2d_quantization", "1d_quantization"]
)
def test_quantization_block_tiling_versus_reference(
x_dtype: torch.dtype,
M: int,
N: int,
return_transpose: bool,
swizzled_scale: bool,
use_cpp_allocator: bool,
with_2d_quantization: bool,
) -> None:
check_quantization_nvfp4_versus_reference(
x_dtype=x_dtype,
M=M,
N=N,
return_transpose=return_transpose,
swizzled_scale=swizzled_scale,
use_cpp_allocator=use_cpp_allocator,
with_2d_quantization=with_2d_quantization,
)
@pytest.mark.skipif(not recipe_available, reason=reason_for_no_recipe)
@pytest.mark.parametrize(
"M, N",
[
(128, 128),
],
)
@pytest.mark.parametrize("x_dtype", [torch.float32, torch.bfloat16], ids=str)
@pytest.mark.parametrize("extrema_high", [False, True], ids=["zeros", "maxes"])
@pytest.mark.parametrize(
"return_transpose", [True, False], ids=["quantize_transpose", "skip_transpose"]
)
@pytest.mark.parametrize(
"use_cpp_allocator", [True, False], ids=["cpp_allocator", "python_allocator"]
)
def test_nvfp4_quantization_extrema_versus_reference(
x_dtype: torch.dtype,
M: int,
N: int,
extrema_high: bool,
return_transpose: bool,
use_cpp_allocator: bool,
):
te_dtype = tex.DType.kFloat4E2M1
device = "cuda"
seed = 0
torch.manual_seed(seed)
torch.cuda.manual_seed(seed)
if extrema_high:
x = torch.full((M, N), torch.finfo(x_dtype).max, dtype=x_dtype, device=device)
else:
x = torch.zeros((M, N), dtype=x_dtype, device=device)
nvfp4_quantizer = NVFP4Quantizer(
fp4_dtype=te_dtype,
rowwise=True,
columnwise=return_transpose,
with_amax_reduction=False,
amax_reduction_group=None,
with_rht=False,
with_post_rht_amax=False,
)
if use_cpp_allocator:
x_nvfp4_sut = nvfp4_quantizer(x)
else:
x_nvfp4_sut = nvfp4_quantizer.make_empty(
(M, N), dtype=x_dtype, device=device, requires_grad=False
)
x_nvfp4_sut = nvfp4_quantizer.update_quantized(x, x_nvfp4_sut)
assert x_nvfp4_sut._rowwise_data is not None
qx = x_nvfp4_sut._rowwise_data.view(dtype=torch.uint8)
assert x_nvfp4_sut._rowwise_scale_inv is not None
sx = x_nvfp4_sut._rowwise_scale_inv
qx_t = (
x_nvfp4_sut._columnwise_data.view(dtype=torch.uint8)
if x_nvfp4_sut._columnwise_data is not None
else None
)
sx_t = x_nvfp4_sut._columnwise_scale_inv
qx_amax = x_nvfp4_sut._amax_rowwise
ref_quantizer = NVFP4QuantizerRef(
dtype=utils.Fp4Formats.E2M1,
rowwise=True,
columnwise=return_transpose,
pow_2_scales=False,
eps=0.0,
quant_tile_shape=(1, 16),
)
x_nvfp4_ref = ref_quantizer.quantize(x)
qx_ref = x_nvfp4_ref.data.view(dtype=torch.uint8) if x_nvfp4_ref.data is not None else None
sx_ref = x_nvfp4_ref.scale.view(dtype=torch.uint8) if x_nvfp4_ref.scale is not None else None
qx_t_ref = (
x_nvfp4_ref.data_t.view(dtype=torch.uint8) if x_nvfp4_ref.data_t is not None else None
)
sx_t_ref = (
x_nvfp4_ref.scale_t.view(dtype=torch.uint8) if x_nvfp4_ref.scale_t is not None else None
)
ref_amax = x_nvfp4_ref.global_amax_row
torch.testing.assert_close(qx, qx_ref, atol=0.0, rtol=0.0)
ref_sx_shape = sx_ref.shape
sx_valid = sx[: ref_sx_shape[0], : ref_sx_shape[1]]
torch.testing.assert_close(sx_valid, sx_ref, atol=0.0, rtol=0.0)
if return_transpose:
torch.testing.assert_close(qx_t, qx_t_ref, atol=0.0, rtol=0.0)
ref_sx_t_shape = sx_t_ref.shape
sx_t_valid = sx_t[: ref_sx_t_shape[0], : ref_sx_t_shape[1]]
torch.testing.assert_close(sx_t_valid, sx_t_ref, atol=0.0, rtol=0.0)
torch.testing.assert_close(qx_amax, ref_amax, atol=0.0, rtol=0.0)
@pytest.mark.skipif(not recipe_available, reason=reason_for_no_recipe)
@pytest.mark.parametrize(
"M, N",
[
(16, 128),
(32, 128),
],
)
@pytest.mark.parametrize("x_dtype", [torch.float32, torch.bfloat16], ids=str)
@pytest.mark.parametrize(
"return_transpose", [True, False], ids=["quantize_transpose", "skip_transpose"]
)
@pytest.mark.parametrize(
"use_cpp_allocator", [True, False], ids=["cpp_allocator", "python_allocator"]
)
def test_nvfp4_quantization_boundary_values(
x_dtype: torch.dtype,
M: int,
N: int,
return_transpose: bool,
use_cpp_allocator: bool,
):
"""
Stress rounding/threshold behavior by placing values just below/above
many potential bin edges within each 16-element microblock.
Validates native vs reference byte-for-byte and scale parity.
"""
te_dtype = tex.DType.kFloat4E2M1
device = "cuda"
seed = 123
torch.manual_seed(seed)
torch.cuda.manual_seed(seed)
# Construct a single row with paired boundary values: v-eps, v+eps
# spanning a wide dynamic range to exercise clipping and multiple bins.
# N must be even and a multiple of 16 (the microblock size); this holds for N=128.
base = torch.linspace(-12.0, 12.0, steps=N // 2, dtype=torch.float32, device=device)
eps = torch.full_like(base, 1e-3)
# Avoid zero eps for very small magnitudes
eps = torch.maximum(eps, 1e-4 * torch.ones_like(base))
lower = base - eps
upper = base + eps
row = torch.empty(N, dtype=torch.float32, device=device)
row[0::2] = lower
row[1::2] = upper
x = row.unsqueeze(0).repeat(M, 1).to(dtype=x_dtype)
nvfp4_quantizer = NVFP4Quantizer(
fp4_dtype=te_dtype,
rowwise=True,
columnwise=return_transpose,
with_amax_reduction=False,
amax_reduction_group=None,
with_rht=False,
with_post_rht_amax=False,
)
if use_cpp_allocator:
x_nvfp4_sut = nvfp4_quantizer(x)
else:
x_nvfp4_sut = nvfp4_quantizer.make_empty(
(M, N), dtype=x_dtype, device=device, requires_grad=False
)
x_nvfp4_sut = nvfp4_quantizer.update_quantized(x, x_nvfp4_sut)
assert x_nvfp4_sut._rowwise_data is not None
qx = x_nvfp4_sut._rowwise_data.view(dtype=torch.uint8)
assert x_nvfp4_sut._rowwise_scale_inv is not None
sx = x_nvfp4_sut._rowwise_scale_inv
qx_t = (
x_nvfp4_sut._columnwise_data.view(dtype=torch.uint8)
if x_nvfp4_sut._columnwise_data is not None
else None
)
sx_t = x_nvfp4_sut._columnwise_scale_inv
qx_amax = x_nvfp4_sut._amax_rowwise
ref_quantizer = NVFP4QuantizerRef(
dtype=utils.Fp4Formats.E2M1,
rowwise=True,
columnwise=return_transpose,
pow_2_scales=False,
eps=0.0,
quant_tile_shape=(1, 16),
)
x_nvfp4_ref = ref_quantizer.quantize(x)
qx_ref = x_nvfp4_ref.data.view(dtype=torch.uint8) if x_nvfp4_ref.data is not None else None
sx_ref = x_nvfp4_ref.scale.view(dtype=torch.uint8) if x_nvfp4_ref.scale is not None else None
qx_t_ref = (
x_nvfp4_ref.data_t.view(dtype=torch.uint8) if x_nvfp4_ref.data_t is not None else None
)
sx_t_ref = (
x_nvfp4_ref.scale_t.view(dtype=torch.uint8) if x_nvfp4_ref.scale_t is not None else None
)
ref_amax = x_nvfp4_ref.global_amax_row
torch.testing.assert_close(qx, qx_ref, atol=0.0, rtol=0.0)
# Compare only valid portion of scales (trim any padding)
ref_sx_shape = sx_ref.shape
sx_valid = sx[: ref_sx_shape[0], : ref_sx_shape[1]]
torch.testing.assert_close(sx_valid, sx_ref, atol=0.0, rtol=0.0)
if return_transpose:
torch.testing.assert_close(qx_t, qx_t_ref, atol=0.0, rtol=0.0)
ref_sx_t_shape = sx_t_ref.shape
sx_t_valid = sx_t[: ref_sx_t_shape[0], : ref_sx_t_shape[1]]
torch.testing.assert_close(sx_t_valid, sx_t_ref, atol=0.0, rtol=0.0)
torch.testing.assert_close(qx_amax, ref_amax, atol=0.0, rtol=0.0)
@pytest.mark.skipif(not recipe_available, reason=reason_for_no_recipe)
@pytest.mark.parametrize(
"M, N",
[
(32, 128),
],
)
@pytest.mark.parametrize("x_dtype", [torch.float32, torch.bfloat16], ids=str)
@pytest.mark.parametrize(
"return_transpose", [True, False], ids=["quantize_transpose", "skip_transpose"]
)
@pytest.mark.parametrize(
"use_cpp_allocator", [True, False], ids=["cpp_allocator", "python_allocator"]
)
def test_nvfp4_quantization_noncontiguous_inputs(
x_dtype: torch.dtype,
M: int,
N: int,
return_transpose: bool,
use_cpp_allocator: bool,
):
te_dtype = tex.DType.kFloat4E2M1
device = "cuda"
seed = 17
torch.manual_seed(seed)
torch.cuda.manual_seed(seed)
# Start from a contiguous tensor, then make a non-contiguous view by transpose
x_base = torch.randn((M, N), dtype=x_dtype, device=device)
x_nc = x_base.t() # shape (N, M), non-contiguous
assert not x_nc.is_contiguous()
nvfp4_quantizer = NVFP4Quantizer(
fp4_dtype=te_dtype,
rowwise=True,
columnwise=return_transpose,
with_amax_reduction=False,
amax_reduction_group=None,
with_rht=False,
with_post_rht_amax=False,
)
if use_cpp_allocator:
x_nvfp4_sut = nvfp4_quantizer(x_nc)
else:
x_nvfp4_sut = nvfp4_quantizer.make_empty(
x_nc.shape, dtype=x_dtype, device=device, requires_grad=False
)
x_nvfp4_sut = nvfp4_quantizer.update_quantized(x_nc, x_nvfp4_sut)
assert x_nvfp4_sut._rowwise_data is not None
qx = x_nvfp4_sut._rowwise_data.view(dtype=torch.uint8)
assert x_nvfp4_sut._rowwise_scale_inv is not None
sx = x_nvfp4_sut._rowwise_scale_inv
qx_t = (
x_nvfp4_sut._columnwise_data.view(dtype=torch.uint8)
if x_nvfp4_sut._columnwise_data is not None
else None
)
sx_t = x_nvfp4_sut._columnwise_scale_inv
qx_amax = x_nvfp4_sut._amax_rowwise
ref_quantizer = NVFP4QuantizerRef(
dtype=utils.Fp4Formats.E2M1,
rowwise=True,
columnwise=return_transpose,
pow_2_scales=False,
eps=0.0,
quant_tile_shape=(1, 16),
)
x_nvfp4_ref = ref_quantizer.quantize(x_nc)
qx_ref = x_nvfp4_ref.data.view(dtype=torch.uint8) if x_nvfp4_ref.data is not None else None
sx_ref = x_nvfp4_ref.scale.view(dtype=torch.uint8) if x_nvfp4_ref.scale is not None else None
qx_t_ref = (
x_nvfp4_ref.data_t.view(dtype=torch.uint8) if x_nvfp4_ref.data_t is not None else None
)
sx_t_ref = (
x_nvfp4_ref.scale_t.view(dtype=torch.uint8) if x_nvfp4_ref.scale_t is not None else None
)
ref_amax = x_nvfp4_ref.global_amax_row
# Quantized must match
torch.testing.assert_close(qx, qx_ref, atol=0.0, rtol=0.0)
# Compare only valid portion of scales (trim padding)
ref_sx_shape = sx_ref.shape
sx_valid = sx[: ref_sx_shape[0], : ref_sx_shape[1]]
torch.testing.assert_close(sx_valid, sx_ref, atol=0.0, rtol=0.0)
if return_transpose:
torch.testing.assert_close(qx_t, qx_t_ref, atol=0.0, rtol=0.0)
ref_sx_t_shape = sx_t_ref.shape
sx_t_valid = sx_t[: ref_sx_t_shape[0], : ref_sx_t_shape[1]]
torch.testing.assert_close(sx_t_valid, sx_t_ref, atol=0.0, rtol=0.0)
torch.testing.assert_close(qx_amax, ref_amax, atol=0.0, rtol=0.0)
# Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# See LICENSE for license information.
# NOTE: This file depends on test_nvfp4_quantize_exact.py passing. It is kept
# separate so that each piece of functionality is verified on its own;
# otherwise the reference implementation would get messy.
# Due to the structure of NVFP4Quantizer, the RHT functionality has to be tested
# together with the quantization functionality.
import transformer_engine.pytorch as te
import transformer_engine_torch as tex
from transformer_engine.pytorch import NVFP4Quantizer
from transformer_engine.common.recipe import NVFP4BlockScaling
from transformer_engine.pytorch.constants import TE_DType
from transformer_engine.pytorch.experimental.quantization_nvfp4 import NVFP4QuantizerRef
from transformer_engine.pytorch.experimental import utils
import pytest
import torch
recipe_available, reason_for_no_recipe = te.is_nvfp4_available(return_reason=True)
def unpack_fp4(x: torch.Tensor) -> torch.Tensor:
repeated = x.repeat_interleave(2, dim=1)
repeated[:, 0::2] &= 0x0F
repeated[:, 1::2] >>= 4
return repeated
def check_quantization_nvfp4_versus_reference(
x_dtype: torch.dtype,
M: int,
N: int,
contiguous: bool,
return_transpose: bool,
use_cpp_allocator: bool,
swizzled_scale: bool = False,
hadamard_dimension: int = 16,
with_rht: bool = True,
with_post_rht_amax: bool = True,
with_random_sign_mask: bool = True,
) -> None:
assert with_rht and with_post_rht_amax, "RHT and post-RHT amax must both be enabled."
te_dtype = tex.DType.kFloat4E2M1
# Setup device and random seed
device = "cuda"
seed = 0
torch.manual_seed(seed)
torch.cuda.manual_seed(seed)
# Input
x = torch.randn((M, N), dtype=x_dtype, device=device)
x = x.transpose(0, 1) if not contiguous else x
# Quantize
nvfp4_quantizer = NVFP4Quantizer(
fp4_dtype=te_dtype,
rowwise=True,
columnwise=return_transpose,
with_amax_reduction=False,
amax_reduction_group=None,
with_rht=with_rht,
with_post_rht_amax=with_post_rht_amax,
with_random_sign_mask=with_random_sign_mask,
)
if use_cpp_allocator:
x_nvfp4_sut = nvfp4_quantizer(x)
else:
x_nvfp4_sut = nvfp4_quantizer.make_empty(
x.shape, dtype=x_dtype, device=device, requires_grad=False
)
x_nvfp4_sut = nvfp4_quantizer.update_quantized(x, x_nvfp4_sut)
# Extract data from NVFP4Tensor
assert x_nvfp4_sut._rowwise_data is not None
qx: torch.Tensor = x_nvfp4_sut._rowwise_data.view(dtype=torch.uint8)
assert x_nvfp4_sut._rowwise_scale_inv is not None
sx: torch.Tensor = x_nvfp4_sut._rowwise_scale_inv
qx_t = (
x_nvfp4_sut._columnwise_data.view(dtype=torch.uint8)
if x_nvfp4_sut._columnwise_data is not None
else None
)
sx_t = x_nvfp4_sut._columnwise_scale_inv
amax_rowwise = x_nvfp4_sut._amax_rowwise
amax_colwise = x_nvfp4_sut._amax_columnwise
qx = unpack_fp4(qx)
qx_t = unpack_fp4(qx_t) if qx_t is not None else None
# Reference quantization using NVFP4QuantizerRef with built-in RHT
ref_quantizer = NVFP4QuantizerRef(
dtype=utils.Fp4Formats.E2M1,
rowwise=True,
columnwise=return_transpose,
pow_2_scales=False,
eps=0.0,
quant_tile_shape=(1, 16),
with_rht=with_rht,
with_random_sign_mask=with_random_sign_mask,
)
x_nvfp4_ref = ref_quantizer.quantize(x)
# Extract data from RefNVFP4Tensor
qx_ref = (
unpack_fp4(x_nvfp4_ref.data.view(dtype=torch.uint8))
if x_nvfp4_ref.data is not None
else None
)
sx_ref = x_nvfp4_ref.scale.view(dtype=torch.uint8) if x_nvfp4_ref.scale is not None else None
ref_amax_rowwise = x_nvfp4_ref.global_amax_row
if return_transpose:
assert x_nvfp4_ref.data_t is not None
assert x_nvfp4_ref.scale_t is not None
qx_t_ref = unpack_fp4(x_nvfp4_ref.data_t.view(dtype=torch.uint8))
sx_t_ref = x_nvfp4_ref.scale_t.view(dtype=torch.uint8)
# Compute transpose amax using the same reference quantizer
x_t_for_amax = (
ref_quantizer._apply_rht(x.t().contiguous()) if with_rht else x.t().contiguous()
)
ref_amax_colwise_t = torch.max(torch.abs(x_t_for_amax)).to(torch.float32).view(1)
else:
qx_t_ref = None
sx_t_ref = None
ref_amax_colwise_t = None
torch.testing.assert_close(amax_rowwise, ref_amax_rowwise, atol=0.0, rtol=0.0)
torch.testing.assert_close(qx, qx_ref, atol=0.0, rtol=0.0)
# Compare only the valid portion of scale tensors (reference may not have padding)
ref_sx_shape = sx_ref.shape
sx_valid = sx[: ref_sx_shape[0], : ref_sx_shape[1]]
torch.testing.assert_close(sx_valid, sx_ref, atol=0.0, rtol=0.0)
if return_transpose:
torch.testing.assert_close(amax_colwise, ref_amax_colwise_t, atol=0.0, rtol=0.0)
torch.testing.assert_close(qx_t, qx_t_ref, atol=0.0, rtol=0.0)
# Compare only the valid portion of transpose scale tensors
ref_sx_t_shape = sx_t_ref.shape
sx_t_valid = sx_t[: ref_sx_t_shape[0], : ref_sx_t_shape[1]]
torch.testing.assert_close(sx_t_valid, sx_t_ref, atol=0.0, rtol=0.0)
@pytest.mark.skipif(not recipe_available, reason=reason_for_no_recipe)
@pytest.mark.parametrize(
"M, N",
[
# full tile cases
(128, 128),
(256, 256),
(256, 1024),
(1024, 256),
# Padding required cases
(256, 272),
(304, 304),
(320, 256),
# Some larger tiles
(2048, 2048),
(1024, 2048),
(2048, 1024),
# Real shapes,
(8192, 5120),
(8192, 10240),
(8192, 2560),
(8192, 11328),
(8192, 512),
(8192, 3584),
(5120, 8192),
(10240, 8192),
(2560, 8192),
(11328, 8192),
(512, 8192),
(3584, 8192),
(4096, 16384),
(14336, 16384),
],
)
@pytest.mark.parametrize("x_dtype", [torch.bfloat16], ids=str)
@pytest.mark.parametrize(
"return_transpose", [True, False], ids=["quantize_transpose", "skip_transpose"]
)
@pytest.mark.parametrize(
"use_cpp_allocator", [True, False], ids=["cpp_allocator", "python_allocator"]
)
@pytest.mark.parametrize(
"with_random_sign_mask", [True, False], ids=["with_random_sign_mask", "no_random_sign_mask"]
)
def test_rht_with_quantization_block_tiling_versus_reference(
x_dtype: torch.dtype,
M: int,
N: int,
return_transpose: bool,
use_cpp_allocator: bool,
with_random_sign_mask: bool,
) -> None:
check_quantization_nvfp4_versus_reference(
x_dtype=x_dtype,
M=M,
N=N,
contiguous=True,
return_transpose=return_transpose,
use_cpp_allocator=use_cpp_allocator,
with_random_sign_mask=with_random_sign_mask,
)
@pytest.mark.skipif(not recipe_available, reason=reason_for_no_recipe)
@pytest.mark.parametrize(
"M, N",
[
(32, 128),
],
)
@pytest.mark.parametrize("x_dtype", [torch.bfloat16], ids=str)
@pytest.mark.parametrize(
"return_transpose", [True, False], ids=["quantize_transpose", "skip_transpose"]
)
@pytest.mark.parametrize(
"use_cpp_allocator", [True, False], ids=["cpp_allocator", "python_allocator"]
)
@pytest.mark.parametrize(
"with_random_sign_mask", [True, False], ids=["with_random_sign_mask", "no_random_sign_mask"]
)
def test_nvfp4_quantization_noncontiguous_inputs(
x_dtype: torch.dtype,
M: int,
N: int,
return_transpose: bool,
use_cpp_allocator: bool,
with_random_sign_mask: bool,
):
check_quantization_nvfp4_versus_reference(
x_dtype=x_dtype,
M=M,
N=N,
contiguous=False,
return_transpose=return_transpose,
use_cpp_allocator=use_cpp_allocator,
with_random_sign_mask=with_random_sign_mask,
)
# Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# See LICENSE for license information.
import pytest
import torch
import transformer_engine.pytorch as te
from transformer_engine.pytorch import NVFP4Quantizer
recipe_available, reason_for_no_recipe = te.is_nvfp4_available(return_reason=True)
seed = 12345
torch.manual_seed(seed)
torch.cuda.manual_seed(seed)
def unpack_fp4(x: torch.Tensor) -> torch.Tensor:
repeated = x.repeat_interleave(2, dim=1)
repeated[:, 0::2] &= 0x0F
repeated[:, 1::2] >>= 4
return repeated
_FP4_LUT = torch.tensor(
[
0.0, # 0: 0000 - zero
0.5, # 1: 0001 - smallest positive normal
1.0, # 2: 0010
1.5, # 3: 0011
2.0, # 4: 0100
3.0, # 5: 0101
4.0, # 6: 0110
6.0, # 7: 0111 - largest positive normal
-0.0, # 8: 1000 - negative zero
-0.5, # 9: 1001 - smallest negative normal
-1.0, # 10: 1010
-1.5, # 11: 1011
-2.0, # 12: 1100
-3.0, # 13: 1101
-4.0, # 14: 1110
-6.0, # 15: 1111 - largest negative normal
],
dtype=torch.float32,
)
def fp4_to_fp32(fp4: torch.Tensor) -> torch.Tensor:
# Convert FP4 indices to their corresponding floating point values
# Each index (0-15) represents a 4-bit FP4 value in E2M1 format
# Values based on the FP4 E2M1 specification
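# For example, code 0b0010 (2) maps to 1.0 and code 0b1111 (15) maps to -6.0.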
fp4_lut = _FP4_LUT.to(fp4.device)
return fp4_lut[fp4.to(torch.long)]
def dequantize_fp4(qx: torch.Tensor, sx: torch.Tensor, amax: torch.Tensor) -> torch.Tensor:
sf = sx.repeat_interleave(16, dim=1).view(torch.float8_e4m3fn).to(torch.float32)
dqx = fp4_to_fp32(unpack_fp4(qx))
sf = sf[: dqx.shape[0], : dqx.shape[1]]
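# Two-level scaling: per-block FP8 (E4M3) scale factors sf times the global scale
# amax / (6.0 * 448), where 6.0 is the largest E2M1 magnitude and 448 is the
# largest E4M3 magnitude.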
dequant = dqx * sf * (amax / (6.0 * 448))
return dequant
def RHT(x: torch.Tensor) -> torch.Tensor:
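# Reference randomized Hadamard transform: multiply each contiguous block of 16
# elements along the last dimension by a 16x16 Hadamard matrix (with the fixed
# sign mask folded in), scaled by 1/sqrt(16) to keep the transform orthonormal.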
def get_wgrad_sign_vector() -> torch.Tensor:
"""Hard-coded signs for Hadamard transform"""
return torch.tensor(
[
1.0,
1.0,
1.0,
-1.0,
1.0,
-1.0,
-1.0,
-1.0,
-1.0,
-1.0,
-1.0,
1.0,
-1.0,
1.0,
-1.0,
-1.0,
],
dtype=torch.float32,
)
def _build_hadamard_matrix(
size: int, device: torch.device, dtype: torch.dtype, with_random_sign_mask: bool = True
) -> torch.Tensor:
"""Construct a Hadamard matrix of given power-of-two size with entries +-1.
Uses Sylvester construction to avoid SciPy dependency.
"""
assert (size & (size - 1)) == 0, "Hadamard size must be a power of two"
h = torch.ones((1, 1), device=device, dtype=torch.float32)
while h.shape[0] < size:
h = torch.cat(
[
torch.cat([h, h], dim=1),
torch.cat([h, -h], dim=1),
],
dim=0,
)
if with_random_sign_mask:
sign_mat = get_wgrad_sign_vector().to(device) * torch.eye(
size, device=device, dtype=torch.float32
)
h = sign_mat @ h
return h.to(dtype)
rht_dim = 16
# Build H and scale
H = _build_hadamard_matrix(rht_dim, x.device, x.dtype)
scale = 1.0 / float(rht_dim) ** 0.5
# Perform blockwise transform along the last dimension
original_shape = x.shape
x_mat = x.contiguous().view(-1, rht_dim)
# The fixed sign mask is already folded into H by _build_hadamard_matrix above.
transform = H * scale
out = x_mat @ transform
return out.view(original_shape)
def quantize_fp4(
x: torch.Tensor, use_stochastic_rounding: bool, use_2D: bool, use_RHT: bool
) -> torch.Tensor:
nvfp4_quantizer = NVFP4Quantizer(
rowwise=True,
columnwise=True,
with_amax_reduction=False,
amax_reduction_group=None,
with_rht=use_RHT,
with_post_rht_amax=True,
stochastic_rounding=use_stochastic_rounding,
with_2d_quantization=use_2D,
)
x_nvfp4_sut = nvfp4_quantizer(x)
# Extract data from NVFP4Tensor
assert x_nvfp4_sut._rowwise_data is not None
qx: torch.Tensor = x_nvfp4_sut._rowwise_data.view(dtype=torch.uint8)
assert x_nvfp4_sut._rowwise_scale_inv is not None
sx: torch.Tensor = x_nvfp4_sut._rowwise_scale_inv
assert x_nvfp4_sut._columnwise_data is not None
qx_t: torch.Tensor = x_nvfp4_sut._columnwise_data.view(dtype=torch.uint8)
assert x_nvfp4_sut._columnwise_scale_inv is not None
sx_t: torch.Tensor = x_nvfp4_sut._columnwise_scale_inv
return qx, sx, qx_t, sx_t
def check_quantization_nvfp4_versus_reference(
x_dtype: torch.dtype, M: int, N: int, use_2D: bool, use_RHT: bool
) -> None:
device = "cuda"
torch.manual_seed(seed)
n_iters = 50
x = torch.randn((M, N), dtype=x_dtype, device=device) * 2 - 1
y = x.t().contiguous()
if use_RHT:
y = RHT(y)
amax = torch.max(torch.abs(x)).float()
q_rn, s_rn, q_t_rn, s_t_rn = quantize_fp4(
x, use_stochastic_rounding=False, use_2D=use_2D, use_RHT=use_RHT
)
dq_rn = dequantize_fp4(q_rn, s_rn, amax)
dq_t_rn = dequantize_fp4(q_t_rn, s_t_rn, amax)
error_rn = (dq_rn - x).float()
me_rn = torch.sqrt((error_rn * error_rn).mean())
error_t_rn = (dq_t_rn - y).float()
me_t_rn = torch.sqrt((error_t_rn * error_t_rn).mean())
sr_result = torch.zeros_like(x).float()
sr_t_result = torch.zeros_like(x).float().t().contiguous()
for i in range(n_iters):
q_sr, s_sr, q_t_sr, s_t_sr = quantize_fp4(
x, use_stochastic_rounding=True, use_2D=use_2D, use_RHT=use_RHT
)
dq_sr = dequantize_fp4(q_sr, s_sr, amax)
dq_t_sr = dequantize_fp4(q_t_sr, s_t_sr, amax)
sr_result += dq_sr.float()
sr_t_result += dq_t_sr.float()
# sr_result_tmp = sr_result / (i + 1)
# error_sr = (sr_result_tmp - x).float()
# me_sr = torch.sqrt((error_sr * error_sr).mean())
# sr_t_result_tmp = sr_t_result / (i + 1)
# error_t_sr = (sr_t_result_tmp - y).float()
# me_t_sr = torch.sqrt((error_t_sr * error_t_sr).mean())
# print(f"Iteration {i}: RMSE SR: {me_sr:.3e} | RMSE RN: {me_rn:.3e}")
# print(f"Iteration {i}: RMSE SR_t: {me_t_sr:.3e} | RMSE RN_t: {me_t_rn:.3e}")
# Get the mean result of the stochastic rounding
# It should be more accurate than the RN result
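# With unbiased stochastic rounding the per-element quantization error has zero
# mean, so averaging over n_iters samples should give a lower RMSE than a single
# deterministic round-to-nearest pass.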
sr_result /= n_iters
error_sr = (sr_result - x).float()
me_sr = torch.sqrt((error_sr * error_sr).mean())
sr_t_result /= n_iters
error_t_sr = (sr_t_result - y).float()
me_t_sr = torch.sqrt((error_t_sr * error_t_sr).mean())
print(f"RMSE SR: {me_sr:.3e} | RMSE RN: {me_rn:.3e}")
print(f"RMSE SR_t: {me_t_sr:.3e} | RMSE RN_t: {me_t_rn:.3e}")
assert me_sr < me_rn, "Stochastic rounding failed - error is larger than with round-to-nearest."
assert me_t_sr < me_t_rn, "Stochastic rounding failed - error is larger than with round-to-nearest."
@pytest.mark.skipif(not recipe_available, reason=reason_for_no_recipe)
@pytest.mark.parametrize(
"M, N",
[
(8192, 8192),
(8192, 8256), # to test the nonfused RHT path
],
)
@pytest.mark.parametrize("x_dtype", [torch.float32, torch.bfloat16], ids=str)
@pytest.mark.parametrize("use_2D", [False, True], ids=str)
@pytest.mark.parametrize("use_RHT", [False, True], ids=str)
def test_quantization_block_tiling_versus_reference(
x_dtype: torch.dtype,
use_2D: bool,
use_RHT: bool,
M: int,
N: int,
) -> None:
if x_dtype == torch.float32 and use_RHT:
pytest.skip("RHT is only supported with bfloat16")
check_quantization_nvfp4_versus_reference(
x_dtype=x_dtype,
use_2D=use_2D,
use_RHT=use_RHT,
M=M,
N=N,
)
......@@ -12,13 +12,15 @@ import pathlib
import pytest
import torch
from typing import Optional
import transformer_engine.pytorch as te
from utils import make_recipe
# Check supported quantization schemes
fp8_available, reason_for_no_fp8 = te.fp8.FP8GlobalStateManager.is_fp8_available()
mxfp8_available, reason_for_no_mxfp8 = te.fp8.FP8GlobalStateManager.is_mxfp8_available()
fp8_available, reason_for_no_fp8 = te.is_fp8_available(return_reason=True)
mxfp8_available, reason_for_no_mxfp8 = te.is_mxfp8_available(return_reason=True)
# Test cases for loading checkpoint files
......@@ -65,16 +67,16 @@ class TestLoadCheckpoint:
if name == "ops_linear":
return te.ops.Linear(1, 1)
if name == "linear.fp8":
with te.fp8_model_init(recipe=make_recipe("fp8")):
with te.quantized_model_init(recipe=make_recipe("fp8")):
return te.Linear(16, 16)
if name == "ops_linear.fp8":
with te.fp8_model_init(recipe=make_recipe("fp8")):
with te.quantized_model_init(recipe=make_recipe("fp8")):
return te.ops.Linear(16, 16)
if name == "linear.mxfp8":
with te.fp8_model_init(recipe=make_recipe("mxfp8")):
with te.quantized_model_init(recipe=make_recipe("mxfp8")):
return te.Linear(32, 32)
if name == "ops_linear.mxfp8":
with te.fp8_model_init(recipe=make_recipe("mxfp8")):
with te.quantized_model_init(recipe=make_recipe("mxfp8")):
return te.ops.Linear(32, 32)
raise ValueError(f"Unrecognized module name ({name})")
......
......@@ -12,14 +12,13 @@ import torch
import transformer_engine.pytorch as te
from transformer_engine.common import recipe
from transformer_engine.pytorch.fp8 import FP8GlobalStateManager
from transformer_engine.pytorch.attention.dot_product_attention import _attention_backends
from transformer_engine.pytorch.utils import is_non_tn_fp8_gemm_supported
from utils import ModelConfig, get_available_attention_backends
# Check supported quantization schemes
fp8_available, _ = FP8GlobalStateManager.is_fp8_available()
mxfp8_available, _ = FP8GlobalStateManager.is_mxfp8_available()
fp8_available = te.is_fp8_available()
mxfp8_available = te.is_mxfp8_available()
quantization_recipes: Optional[recipe.Recipe] = [None]
if fp8_available:
......@@ -79,9 +78,9 @@ def _warmup_model(
"""Perform forward and backward pass"""
tensor = _make_input()
for module in modules:
with te.fp8_autocast(
with te.autocast(
enabled=quantization_recipe is not None,
fp8_recipe=quantization_recipe,
recipe=quantization_recipe,
):
tensor = module(tensor)
tensor.sum().backward()
......@@ -159,8 +158,8 @@ def _measure_cached_memory(
tensor = inp
memory_before_forward = torch.cuda.memory_allocated() / (1024**2)
for module in modules:
with te.fp8_autocast(
enabled=quantization_recipe is not None, fp8_recipe=quantization_recipe
with te.autocast(
enabled=quantization_recipe is not None, recipe=quantization_recipe
), offload_context:
tensor = module(tensor)
tensor = sync_function(tensor)
......
......@@ -13,12 +13,15 @@ from transformer_engine.pytorch import (
Linear,
MultiheadAttention,
TransformerLayer,
fp8_autocast,
fp8_model_init,
autocast,
quantized_model_init,
make_graphed_callables,
is_fp8_available,
is_fp8_block_scaling_available,
is_mxfp8_available,
is_bf16_available,
)
from transformer_engine.pytorch.fp8 import FP8GlobalStateManager
from transformer_engine.pytorch.utils import is_bf16_compatible
from transformer_engine.pytorch.quantization import FP8GlobalStateManager
import transformer_engine.pytorch.ops as te_ops
from transformer_engine.common import recipe
from utils import ModelConfig, reset_rng_states
......@@ -28,20 +31,67 @@ if IS_HIP_EXTENSION:
from functools import cache
# Check if FP8 is supported.
fp8_available, reason_for_no_fp8 = FP8GlobalStateManager.is_fp8_available()
fp8_block_scaling_available, reason_for_no_fp8_block_scaling = FP8GlobalStateManager.is_fp8_block_scaling_available()
mxfp8_available, reason_for_no_mxfp8 = FP8GlobalStateManager.is_mxfp8_available()
fp8_available = is_fp8_available()
fp8_block_scaling_available = is_fp8_block_scaling_available()
mxfp8_available = is_mxfp8_available()
# Reset RNG states.
reset_rng_states()
model_configs = {
"small": ModelConfig(32, 2, 2, 32),
"small": ModelConfig(2, 32, 2, 32),
}
def nvfp4_vanilla():
nvfp4_recipe = recipe.NVFP4BlockScaling()
nvfp4_recipe.fp4_quant_fwd_inp = recipe.QParams()
nvfp4_recipe.fp4_quant_fwd_weight = recipe.QParams()
nvfp4_recipe.fp4_quant_bwd_grad = recipe.QParams()
return nvfp4_recipe
def nvfp4_rht_and_2d_quantization():
nvfp4_recipe = recipe.NVFP4BlockScaling()
nvfp4_recipe.fp4_quant_fwd_inp = recipe.QParams(
random_hadamard_transform=True, fp4_2d_quantization=False
)
nvfp4_recipe.fp4_quant_fwd_weight = recipe.QParams(
random_hadamard_transform=False, fp4_2d_quantization=True
)
nvfp4_recipe.fp4_quant_bwd_grad = recipe.QParams(
random_hadamard_transform=True, fp4_2d_quantization=False
)
return nvfp4_recipe
def check_rht_usage(recipe: recipe.Recipe) -> bool:
# if using RHT, we can only support bf16
# check fp4_quant_fwd_inp, fp4_quant_fwd_weight, fp4_quant_bwd_grad
if recipe.nvfp4():
if (
recipe.fp4_quant_fwd_inp.random_hadamard_transform
or recipe.fp4_quant_fwd_weight.random_hadamard_transform
or recipe.fp4_quant_bwd_grad.random_hadamard_transform
):
return True
return False
def get_nvfp4_inp_supported_dtypes(recipe: recipe.Recipe, dtype: torch.dtype) -> List[torch.dtype]:
supported_input_dtypes = []
if recipe.nvfp4():
supported_input_dtypes.append(torch.bfloat16)
# if not using RHT, we can add fp32 as well
if not check_rht_usage(recipe):
supported_input_dtypes.append(torch.float32)
return supported_input_dtypes
fp8_recipes = []
if mxfp8_available:
fp8_recipes.append(recipe.MXFP8BlockScaling())
fp8_recipes.append(nvfp4_rht_and_2d_quantization())
if fp8_block_scaling_available:
fp8_recipes.append(recipe.Float8BlockScaling())
if fp8_available:
......@@ -50,7 +100,7 @@ if fp8_available:
# Supported data types
dtypes: List[torch.dtype] = [torch.float32, torch.float16]
if is_bf16_compatible(): # bf16 requires sm_80 or higher
if is_bf16_available(): # bf16 requires sm_80 or higher
dtypes.append(torch.bfloat16)
......@@ -167,7 +217,7 @@ def _test_cuda_graphs(
fp8_weight_caching = False
# Create modules.
with fp8_model_init(enabled=fp8_params, recipe=fp8_recipe):
with quantized_model_init(enabled=fp8_params, recipe=fp8_recipe):
if module == "transformer":
modules = [
TransformerLayer(
......@@ -247,9 +297,9 @@ def _test_cuda_graphs(
model,
(generate_data(model_config, dtype, warmup=True),),
num_warmup_iters=10,
fp8_enabled=fp8,
fp8_weight_caching=fp8_weight_caching,
fp8_recipe=fp8_recipe,
enabled=fp8,
cache_quantized_params=fp8_weight_caching,
recipe=fp8_recipe,
)
elif graph_mode == "individual":
# Graph individual modules.
......@@ -258,9 +308,9 @@ def _test_cuda_graphs(
module,
(generate_data(model_config, dtype, warmup=True),),
num_warmup_iters=10,
fp8_enabled=fp8,
fp8_weight_caching=fp8_weight_caching,
fp8_recipe=fp8_recipe,
enabled=fp8,
cache_quantized_params=fp8_weight_caching,
recipe=fp8_recipe,
)
for module in modules
]
......@@ -277,7 +327,7 @@ def _test_cuda_graphs(
for grad_accumulation_step in range(2):
input_ = generate_data(model_config, dtype)
grad_output = generate_data(model_config, dtype, requires_grad=False)
with fp8_autocast(enabled=fp8, fp8_recipe=fp8_recipe):
with autocast(enabled=fp8, recipe=fp8_recipe):
kwargs = {}
if fp8_weight_caching:
kwargs["is_first_microbatch"] = grad_accumulation_step == 0
......@@ -291,7 +341,7 @@ def _test_cuda_graphs(
@pytest.mark.parametrize("module", _test_cuda_graphs_modules)
@pytest.mark.parametrize("dtype", dtypes)
@pytest.mark.parametrize("fp8_params", (False, True))
@pytest.mark.parametrize("fp8_recipe", fp8_recipes + [None])
@pytest.mark.parametrize("fp8_recipe", fp8_recipes + [None], ids=lambda r: type(r).__name__)
def test_make_graphed_callables(
*,
module: str,
......@@ -308,15 +358,25 @@ def test_make_graphed_callables(
pytest.skip("FP8 needed for FP8 parameters.")
if fp8_weight_caching and not fp8:
pytest.skip("FP8 needed for FP8 parameters.")
if fp8 and fp8_recipe.float8_block_scaling() and module == "linear_op":
pytest.skip("Module not yet supported for float8_block_scaling with CUDA graphs")
if fp8 and (fp8_recipe.float8_block_scaling() or fp8_recipe.nvfp4()) and module == "linear_op":
pytest.skip(
f"Module not yet supported for {fp8_recipe.__class__.__name__} with CUDA graphs"
)
if fp8 and fp8_recipe.nvfp4():
if dtype not in get_nvfp4_inp_supported_dtypes(fp8_recipe, dtype):
pytest.skip(
f"Input dtype {dtype} not supported for NVFP4 Recipe"
f" {fp8_recipe.__class__.__name__}"
)
if fp8_params:
pytest.skip("NVFP4 params not supported")
if fp8 and not fp8_available:
pytest.skip(reason_for_no_fp8)
pytest.skip("FP8 not supported on rocm GPU.")
if fp8 and fp8_recipe.float8_block_scaling() and not fp8_block_scaling_available:
pytest.skip(reason_for_no_fp8_block_scaling)
pytest.skip("FP8 block scaling not supported on rocm GPU.")
if fp8 and fp8_recipe.mxfp8() and not mxfp8_available:
pytest.skip(reason_for_no_mxfp8)
pytest.skip("MXFP8 not supported on rocm GPU.")
# Run model with different CUDA graph settings.
model_config = model_configs[model_config]
kwargs = dict(
......@@ -353,17 +413,19 @@ _test_make_graphed_callables_with_fp8_weight_caching_modules = [
"module",
_test_make_graphed_callables_with_fp8_weight_caching_modules,
)
@pytest.mark.parametrize("dtype", dtypes)
@pytest.mark.parametrize("fp8_params", (False, True))
@pytest.mark.parametrize("fp8_recipe", fp8_recipes)
@pytest.mark.parametrize("fp8_recipe", fp8_recipes, ids=lambda r: type(r).__name__)
def test_make_graphed_callables_with_fp8_weight_caching(
*,
module: str,
dtype: torch.dtype,
fp8_params: bool,
fp8_recipe: recipe.Recipe,
) -> None:
test_make_graphed_callables(
module=module,
dtype=torch.float32,
dtype=dtype,
fp8_params=fp8_params,
fp8_recipe=fp8_recipe,
fp8_weight_caching=True,
......@@ -415,7 +477,7 @@ def _test_cuda_graphs_with_dot_product_attention(
model,
generate_data_for_dot_product_attention(model_config, dtype, warmup=True),
num_warmup_iters=10,
fp8_enabled=False,
enabled=False,
)
# Forward and backward passes.
......
# Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# See LICENSE for license information.
import pytest
import torch
import transformer_engine.pytorch as te
import transformer_engine_torch as tex
from transformer_engine.common import recipe
from transformer_engine.pytorch import (
autocast,
Linear,
LayerNormLinear,
LayerNormMLP,
GroupedLinear,
Float8CurrentScalingQuantizer,
)
import transformer_engine.pytorch.ops as te_ops
@pytest.mark.parametrize("module_type", ["Linear", "LayerNormLinear", "OpsLinear", "LayerNormMLP"])
def test_custom_recipe_sanity(module_type):
available, reason = te.is_fp8_available(return_reason=True)
if not torch.cuda.is_available() or not available:
pytest.skip(f"FP8 unsupported on this device: {reason}")
torch.manual_seed(0)
# Simple linear layer with dims divisible by 16
in_features = 64
out_features = 64
batch = 32
if module_type == "Linear":
model = Linear(in_features, out_features, params_dtype=torch.bfloat16).cuda()
elif module_type == "LayerNormLinear":
model = LayerNormLinear(in_features, out_features, params_dtype=torch.bfloat16).cuda()
elif module_type == "LayerNormMLP":
# hidden_size == in_features == out_features for simplicity
model = LayerNormMLP(
hidden_size=in_features, ffn_hidden_size=out_features, params_dtype=torch.bfloat16
).cuda()
else:
# OpsLinear path
model = te_ops.Linear(in_features, out_features, device="cuda", dtype=torch.bfloat16)
inp = torch.randn(batch, in_features, device="cuda", dtype=torch.bfloat16, requires_grad=True)
# Single factory: map roles to quantizers
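# E4M3 for forward tensors and E5M2 for gradients mirrors the HYBRID FP8 format
# used by the built-in current-scaling recipe.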
def quantizer_factory(role):
if role in ("linear_input", "linear_weight", "linear_output"):
return Float8CurrentScalingQuantizer(tex.DType.kFloat8E4M3, device="cuda")
if role in ("linear_grad_output", "linear_grad_input"):
return Float8CurrentScalingQuantizer(tex.DType.kFloat8E5M2, device="cuda")
return Float8CurrentScalingQuantizer(tex.DType.kFloat8E4M3, device="cuda")
custom_recipe = recipe.CustomRecipe(qfactory=quantizer_factory)
# Execute with custom recipe
with autocast(enabled=True, recipe=custom_recipe):
out = model(inp)
loss = out.float().sum()
loss.backward()
# Basic sanity: gradients exist
assert inp.grad is not None
def test_custom_recipe_grouped_linear_sanity():
available, reason = te.is_fp8_available(return_reason=True)
if not torch.cuda.is_available() or not available:
pytest.skip(f"FP8 unsupported on this device: {reason}")
torch.manual_seed(0)
num_gemms = 3
in_features = 64
out_features = 64
batch = 32
base = batch // num_gemms
rem = batch % num_gemms
m_splits = [base + (1 if i < rem else 0) for i in range(num_gemms)]
model = GroupedLinear(num_gemms, in_features, out_features, params_dtype=torch.bfloat16).cuda()
inp = torch.randn(batch, in_features, device="cuda", dtype=torch.bfloat16, requires_grad=True)
def quantizer_factory(role):
if role in ("linear_input", "linear_weight", "linear_output"):
return Float8CurrentScalingQuantizer(tex.DType.kFloat8E4M3, device="cuda")
if role in ("linear_grad_output", "linear_grad_input"):
return Float8CurrentScalingQuantizer(tex.DType.kFloat8E5M2, device="cuda")
return Float8CurrentScalingQuantizer(tex.DType.kFloat8E4M3, device="cuda")
custom_recipe = recipe.CustomRecipe(qfactory=quantizer_factory)
with autocast(enabled=True, recipe=custom_recipe):
out = model(inp, m_splits)
loss = out.float().sum()
loss.backward()
assert inp.grad is not None
def test_custom_recipe_matches_current_scaling():
available, reason = te.is_fp8_available(return_reason=True)
if not torch.cuda.is_available() or not available:
pytest.skip(f"FP8 unsupported on this device: {reason}")
torch.manual_seed(123)
in_features = 64
out_features = 64
batch = 32
# Create two identical models
model_ref = Linear(in_features, out_features, params_dtype=torch.bfloat16).cuda()
model_custom = Linear(in_features, out_features, params_dtype=torch.bfloat16).cuda()
model_custom.load_state_dict(model_ref.state_dict())
# Identical inputs for both paths
base_inp = torch.randn(batch, in_features, device="cuda", dtype=torch.bfloat16)
inp_ref = base_inp.clone().detach().requires_grad_(True)
inp_custom = base_inp.clone().detach().requires_grad_(True)
# Reference: use Float8CurrentScaling recipe
ref_recipe = recipe.Float8CurrentScaling()
with autocast(enabled=True, recipe=ref_recipe):
out_ref = model_ref(inp_ref)
# Assert dtypes for reference quantizers: HYBRID = E4M3 (fwd), E5M2 (bwd)
ref_fwd_in = model_ref.quantizers["scaling_fwd"][tex.FP8FwdTensors.GEMM1_INPUT]
ref_fwd_w = model_ref.quantizers["scaling_fwd"][tex.FP8FwdTensors.GEMM1_WEIGHT]
ref_fwd_out = model_ref.quantizers["scaling_fwd"][tex.FP8FwdTensors.GEMM1_OUTPUT]
ref_bwd_go = model_ref.quantizers["scaling_bwd"][tex.FP8BwdTensors.GRAD_OUTPUT1]
ref_bwd_gi = model_ref.quantizers["scaling_bwd"][tex.FP8BwdTensors.GRAD_INPUT1]
assert ref_fwd_in.dtype == tex.DType.kFloat8E4M3
assert ref_fwd_w.dtype == tex.DType.kFloat8E4M3
assert ref_fwd_out.dtype == tex.DType.kFloat8E4M3
assert ref_bwd_go.dtype == tex.DType.kFloat8E5M2
assert ref_bwd_gi.dtype == tex.DType.kFloat8E5M2
# Stress dynamic range in grad_output
scale = torch.ones(out_features, device="cuda", dtype=torch.float32)
scale[0] = 1e8
scale[1] = 1e-8
loss_ref = (out_ref.float() * scale.view(1, -1)).sum()
loss_ref.backward()
# Custom: single factory returning quantizers per role to match Float8CurrentScaling
def quantizer_factory(role):
if role in ("linear_input", "linear_weight", "linear_output"):
return Float8CurrentScalingQuantizer(tex.DType.kFloat8E4M3, device="cuda")
if role in ("linear_grad_output", "linear_grad_input"):
return Float8CurrentScalingQuantizer(tex.DType.kFloat8E5M2, device="cuda")
return Float8CurrentScalingQuantizer(tex.DType.kFloat8E4M3, device="cuda")
custom_recipe = recipe.CustomRecipe(qfactory=quantizer_factory)
with autocast(enabled=True, recipe=custom_recipe):
out_custom = model_custom(inp_custom)
# Assert dtypes for custom quantizers match reference mapping
cus_fwd_in = model_custom.quantizers["scaling_fwd"][tex.FP8FwdTensors.GEMM1_INPUT]
cus_fwd_w = model_custom.quantizers["scaling_fwd"][tex.FP8FwdTensors.GEMM1_WEIGHT]
cus_fwd_out = model_custom.quantizers["scaling_fwd"][tex.FP8FwdTensors.GEMM1_OUTPUT]
cus_bwd_go = model_custom.quantizers["scaling_bwd"][tex.FP8BwdTensors.GRAD_OUTPUT1]
cus_bwd_gi = model_custom.quantizers["scaling_bwd"][tex.FP8BwdTensors.GRAD_INPUT1]
assert cus_fwd_in.dtype == tex.DType.kFloat8E4M3
assert cus_fwd_w.dtype == tex.DType.kFloat8E4M3
assert cus_fwd_out.dtype == tex.DType.kFloat8E4M3
assert cus_bwd_go.dtype == tex.DType.kFloat8E5M2
assert cus_bwd_gi.dtype == tex.DType.kFloat8E5M2
loss_custom = (out_custom.float() * scale.view(1, -1)).sum()
loss_custom.backward()
# Compare forward outputs (exact match expected)
assert torch.allclose(out_ref, out_custom, rtol=0.0, atol=0.0)
# Compare input gradients
assert inp_ref.grad is not None and inp_custom.grad is not None
assert torch.allclose(inp_ref.grad, inp_custom.grad, rtol=0.0, atol=0.0)
# Compare parameter gradients (weights and bias if present)
ref_params = dict(model_ref.named_parameters())
custom_params = dict(model_custom.named_parameters())
for name, p_ref in ref_params.items():
p_cus = custom_params[name]
assert p_ref.grad is not None and p_cus.grad is not None
assert torch.allclose(p_ref.grad, p_cus.grad, rtol=0.0, atol=0.0)
def test_custom_recipe_ops_linear_2_1_layout():
available, reason = te.is_fp8_available(return_reason=True)
if not torch.cuda.is_available() or not available:
pytest.skip(f"FP8 unsupported on this device: {reason}")
torch.manual_seed(7)
in_features = 64
out_features = 64
batch = 16
# Use ops.Linear which consumes 2 forward quantizers and 1 backward quantizer
op = te_ops.Linear(in_features, out_features, device="cuda", dtype=torch.bfloat16)
inp = torch.randn(batch, in_features, device="cuda", dtype=torch.bfloat16, requires_grad=True)
def quantizer_factory(role):
if role in ("linear_input", "linear_weight", "linear_output"):
return Float8CurrentScalingQuantizer(tex.DType.kFloat8E4M3, device="cuda")
if role in ("linear_grad_output", "linear_grad_input"):
return Float8CurrentScalingQuantizer(tex.DType.kFloat8E5M2, device="cuda")
return Float8CurrentScalingQuantizer(tex.DType.kFloat8E4M3, device="cuda")
custom = recipe.CustomRecipe(qfactory=quantizer_factory)
with autocast(enabled=True, recipe=custom):
out = op(inp)
loss = out.float().sum()
loss.backward()
assert inp.grad is not None
def test_custom_recipe_factory_invocation_counts_and_cycling():
available, reason = te.is_fp8_available(return_reason=True)
if not torch.cuda.is_available() or not available:
pytest.skip(f"FP8 unsupported on this device: {reason}")
torch.manual_seed(13)
in_features = 64
out_features = 64
batch = 8
op = Linear(in_features, out_features, params_dtype=torch.bfloat16)
inp = torch.randn(batch, in_features, device="cuda", dtype=torch.bfloat16, requires_grad=True)
# Counters per role
counts = {
"linear_input": 0,
"linear_weight": 0,
"linear_output": 0,
"linear_grad_output": 0,
"linear_grad_input": 0,
}
def quantizer_factory(role):
if role in counts:
counts[role] += 1
if role in ("linear_input", "linear_weight", "linear_output"):
return Float8CurrentScalingQuantizer(tex.DType.kFloat8E4M3, device=torch.device("cuda"))
if role in ("linear_grad_output", "linear_grad_input"):
return Float8CurrentScalingQuantizer(tex.DType.kFloat8E5M2, device=torch.device("cuda"))
return Float8CurrentScalingQuantizer(tex.DType.kFloat8E4M3, device=torch.device("cuda"))
custom = recipe.CustomRecipe(qfactory=quantizer_factory)
# Run fwd+bwd once; for a single GEMM, the forward pass should request 3 quantizers
# (input, weight, output) and the backward pass 2 (grad_output, grad_input), each via
# one call to the factory.
with autocast(enabled=True, recipe=custom):
out = op(inp)
loss = out.float().sum()
loss.backward()
# Single GEMM: forward should request input, weight, output; backward grad_output, grad_input
assert counts["linear_input"] == 1
assert counts["linear_weight"] == 1
assert counts["linear_output"] == 1
assert counts["linear_grad_output"] == 1
assert counts["linear_grad_input"] == 1
def test_factories_return_distinct_instances_and_buffers():
available, reason = te.is_fp8_available(return_reason=True)
if not torch.cuda.is_available() or not available:
pytest.skip(f"FP8 unsupported on this device: {reason}")
# Two calls should produce distinct quantizer objects and distinct tensor buffers
def factory():
return Float8CurrentScalingQuantizer(tex.DType.kFloat8E4M3, device=torch.device("cuda"))
q1 = factory()
q2 = factory()
assert q1 is not q2
assert q1.scale.data_ptr() != q2.scale.data_ptr()
assert q1.amax.data_ptr() != q2.amax.data_ptr()
# Mutating one should not affect the other
q1.scale.fill_(123.0)
assert not torch.equal(q1.scale, q2.scale)
......@@ -4,7 +4,6 @@
import pytest
import torch
import torch.distributed as dist
import transformer_engine.pytorch as te
......
......@@ -4,13 +4,13 @@
import pytest
import torch
import transformer_engine as te
import transformer_engine.pytorch as te
import transformer_engine_torch as tex
from transformer_engine.pytorch.constants import TE_DType
from transformer_engine.pytorch.fp8 import (FP8GlobalStateManager, blockwise_fp8_block_len, int8_simulation_fp8)
from transformer_engine.pytorch.tensor.float8_blockwise_tensor import (
from transformer_engine.pytorch.fp8 import (blockwise_fp8_block_len, int8_simulation_fp8)
from transformer_engine.pytorch import (
Float8BlockQuantizer,
Float8BlockwiseQTensor,
get_device_compute_capability,
)
from references.blockwise_quantizer_reference import CuBLASScaleMunger
......@@ -19,8 +19,9 @@ from torch.utils.cpp_extension import IS_HIP_EXTENSION
from transformer_engine.pytorch.cpp_extensions.gemm import w8a8_int8_general_gemm
def fp8_blockwise_gemm_supported() -> bool:
supported, _ = FP8GlobalStateManager.is_fp8_block_scaling_available()
return supported
supported = te.is_fp8_block_scaling_available()
emulated = get_device_compute_capability() >= (10, 0)
return supported and not emulated
def cublas_gemm_fp8_blockwise_case(
......
......@@ -8,14 +8,13 @@ import os
import pathlib
import pytest
import torch
import transformer_engine as te
import transformer_engine_torch as tex
from transformer_engine.pytorch.fp8 import (FP8GlobalStateManager, blockwise_fp8_block_len)
import transformer_engine.pytorch as te
from transformer_engine.pytorch.fp8 import blockwise_fp8_block_len
from transformer_engine.common.recipe import Float8BlockScaling
from transformer_engine.pytorch.constants import TE_DType
from transformer_engine.pytorch.tensor.float8_blockwise_tensor import (
from transformer_engine.pytorch import (
Float8BlockQuantizer,
Float8BlockwiseQTensor,
get_device_compute_capability,
)
from references.blockwise_quantizer_reference import (
BlockwiseQuantizerReference,
......@@ -32,7 +31,8 @@ TENSOR_DUMP_DIR = pathlib.Path(__file__).resolve().parent.parent.parent / "tenso
tensor_dump_dir_env = os.getenv("NVTE_TEST_BLOCK_CURRENT_SCALING_EXACT_TENSOR_DUMP_DIR")
if tensor_dump_dir_env is not None:
TENSOR_DUMP_DIR = pathlib.Path(tensor_dump_dir_env)
recipe_available, reason_for_no_recipe = FP8GlobalStateManager.is_fp8_block_scaling_available()
recipe_available, reason_for_no_recipe = te.is_fp8_block_scaling_available(return_reason=True)
recipe_emulated = get_device_compute_capability() >= (10, 0)
class GetRecipes:
......@@ -219,6 +219,12 @@ def check_quantization_block_tiling_versus_reference(
pow_2_scales: bool,
tile_size: Tuple[int, int],
) -> None:
if recipe_emulated and not pow_2_scales:
pytest.skip(
"On Blackwell and newer, the FP8 block scaling recipe is emulated "
"with MXFP8, which requires using power of two scaling factors."
)
te_dtype = TE_DType[quant_dtype]
if tile_size in ((1, 128), (1, 64)):
block_scaling_dim = 1
......@@ -414,6 +420,12 @@ def test_quantization_block_tiling_extrema_versus_reference(
tile_size: Tuple[int, int],
extrema_high: bool,
) -> None:
if recipe_emulated and not pow_2_scales:
pytest.skip(
"On Blackwell and newer, the FP8 block scaling recipe is emulated "
"with MXFP8, which requires using power of two scaling factors."
)
# This test runs a single tile through a quantizer as a way to test
# branch coverage of scale computation.
if blockwise_fp8_block_len != tile_size[1]:
......
......@@ -8,12 +8,9 @@ import torch
import pytest
import transformer_engine.pytorch as te
import transformer_engine_torch as tex
import transformer_engine_torch as tex
from transformer_engine.pytorch.fp8 import FP8GlobalStateManager
from transformer_engine.common.recipe import Float8CurrentScaling
from transformer_engine.pytorch.fp8 import fp8_autocast, get_fp8_torch_dtype
from transformer_engine.pytorch.quantization import autocast, get_fp8_torch_dtype
from transformer_engine.pytorch.fp8 import int8_simulation_fp8
......@@ -25,7 +22,7 @@ if tensor_dump_dir_env is not None:
# Check if FP8 is supported
fp8_available, reason_for_no_fp8 = FP8GlobalStateManager.is_fp8_available()
fp8_available, reason_for_no_fp8 = te.is_fp8_available(return_reason=True)
class GetRecipes:
......@@ -274,6 +271,14 @@ class TestFP8RecipeLinearBase:
if bgrad_list is not None and bgrad is not None:
bgrad_list.append(bgrad.detach().clone())
# Stack the results
return (
torch.stack(y_q_list),
torch.stack(dgrad_list),
torch.stack(wgrad_list),
torch.stack(bgrad_list) if bgrad_list is not None else None,
)
@classmethod
def run_linear(
cls,
......@@ -388,7 +393,7 @@ class TestFP8RecipeLinearBase:
# recipe1
using_fp8_recipe = recipe1() != GetRecipes.none()
if using_fp8_recipe:
with fp8_autocast(enabled=True, fp8_recipe=recipe1()):
with autocast(enabled=True, recipe=recipe1()):
y_q_ref, dgrad_ref, wgrad_ref, bgrad_ref = self.run_linear(x, w, bias, gradient)
else:
y_q_ref, dgrad_ref, wgrad_ref, bgrad_ref = self.run_linear(x, w, bias, gradient)
......@@ -396,7 +401,7 @@ class TestFP8RecipeLinearBase:
# recipe2
using_fp8_recipe = recipe2() != GetRecipes.none()
if using_fp8_recipe:
with fp8_autocast(enabled=True, fp8_recipe=recipe2()):
with autocast(enabled=True, recipe=recipe2()):
y_q, dgrad, wgrad, bgrad = self.run_linear(x, w, bias, gradient)
else:
y_q, dgrad, wgrad, bgrad = self.run_linear(x, w, bias, gradient)
......@@ -611,7 +616,7 @@ class TestFP8RecipeLayerNormLinearBase(TestFP8RecipeLinearBase):
# recipe1
using_fp8_recipe = recipe1() != GetRecipes.none()
if using_fp8_recipe:
with fp8_autocast(enabled=True, fp8_recipe=recipe1()):
with autocast(enabled=True, recipe=recipe1()):
y_q_ref, ln_out_ref, dgrad_ref, wgrad_ref, bgrad_ref = self.run_layernorm_linear(
x,
w,
......@@ -633,7 +638,7 @@ class TestFP8RecipeLayerNormLinearBase(TestFP8RecipeLinearBase):
# recipe2
using_fp8_recipe = recipe2() != GetRecipes.none()
if using_fp8_recipe:
with fp8_autocast(enabled=True, fp8_recipe=recipe2()):
with autocast(enabled=True, recipe=recipe2()):
y_q, ln_out, dgrad, wgrad, bgrad = self.run_layernorm_linear(
x,
w,
......
......@@ -11,12 +11,11 @@ import pytest
import torch
import transformer_engine.common.recipe
import transformer_engine.pytorch as te
from transformer_engine.pytorch.tensor.float8_blockwise_tensor import (
from transformer_engine.pytorch import (
Float8BlockQuantizer,
Float8BlockwiseQTensor,
get_device_compute_capability,
)
from transformer_engine.pytorch.utils import get_device_compute_capability
from torch.utils.cpp_extension import IS_HIP_EXTENSION
import transformer_engine_torch as tex
......
......@@ -11,13 +11,11 @@ import torch
import transformer_engine.common.recipe
import transformer_engine.pytorch as te
from transformer_engine.pytorch.fp8 import FP8GlobalStateManager
from transformer_engine.pytorch.tensor.float8_tensor import (
from transformer_engine.pytorch import (
Float8Quantizer,
Float8Tensor,
Float8CurrentScalingQuantizer,
)
from transformer_engine.pytorch.constants import TE_DType, TE_DType_To_Torch
from transformer_engine.pytorch.utils import is_non_tn_fp8_gemm_supported
import transformer_engine_torch as tex
......@@ -47,7 +45,7 @@ def _to_list(x: Union[Iterable, Any]) -> List:
DimsType = Union[Iterable[int], int]
# Check if FP8 is supported
fp8_available, reason_for_no_fp8 = FP8GlobalStateManager.is_fp8_available()
fp8_available, reason_for_no_fp8 = te.is_fp8_available(return_reason=True)
# delayed scaling
......
......@@ -11,14 +11,11 @@ from torch import nn
from torch.testing._internal.common_device_type import largeTensorTest
import transformer_engine.pytorch as te
from transformer_engine.common.recipe import DelayedScaling
from transformer_engine.pytorch.attention.multi_head_attention import MultiheadAttention
from transformer_engine.pytorch import fp8_model_init
from transformer_engine.pytorch.utils import is_bf16_compatible
from transformer_engine.pytorch.fp8 import FP8GlobalStateManager
from transformer_engine.pytorch import MultiheadAttention, quantized_model_init, is_bf16_available
from transformer_engine.pytorch.utils import gpu_autocast_ctx
# Check if FP8 is supported
fp8_available, reason_for_no_fp8 = FP8GlobalStateManager.is_fp8_available()
fp8_available, reason_for_no_fp8 = te.is_fp8_available(return_reason=True)
class TestFusedOptimizer:
......@@ -188,7 +185,7 @@ class TestFusedAdam(TestFusedOptimizer):
build_model_context = nullcontext
build_model_context_args = {}
if use_fp8_params:
build_model_context = fp8_model_init
build_model_context = quantized_model_init
build_model_context_args["enabled"] = True
with build_model_context(**build_model_context_args):
......@@ -286,7 +283,7 @@ class TestFusedAdam(TestFusedOptimizer):
exp_avg_sq_dtype=torch.float32,
)
@pytest.mark.skipif(not is_bf16_compatible(), reason="bf16 is not supported")
@pytest.mark.skipif(not is_bf16_available(), reason="bf16 is not supported")
def test_fp32_master(self):
self.gen_precision_aware_test(
use_fp8_params=False,
......@@ -298,7 +295,7 @@ class TestFusedAdam(TestFusedOptimizer):
exp_avg_sq_dtype=torch.float32,
)
@pytest.mark.skipif(not is_bf16_compatible(), reason="bf16 is not supported")
@pytest.mark.skipif(not is_bf16_available(), reason="bf16 is not supported")
def test_fp32_master_store_param_remainders(self):
self.gen_precision_aware_test(
use_fp8_params=False,
......@@ -311,7 +308,7 @@ class TestFusedAdam(TestFusedOptimizer):
store_param_remainders=True,
)
@pytest.mark.skipif(not is_bf16_compatible(), reason="bf16 is not supported")
@pytest.mark.skipif(not is_bf16_available(), reason="bf16 is not supported")
def test_fp16_master(self):
self.gen_precision_aware_test(
use_fp8_params=False,
......@@ -325,7 +322,7 @@ class TestFusedAdam(TestFusedOptimizer):
master_atol=2e-3,
)
@pytest.mark.skipif(not is_bf16_compatible(), reason="bf16 is not supported")
@pytest.mark.skipif(not is_bf16_available(), reason="bf16 is not supported")
def test_bf16_grad(self):
self.gen_precision_aware_test(
use_fp8_params=False,
......@@ -339,7 +336,7 @@ class TestFusedAdam(TestFusedOptimizer):
master_atol=2e-3,
)
@pytest.mark.skipif(not is_bf16_compatible(), reason="bf16 is not supported")
@pytest.mark.skipif(not is_bf16_available(), reason="bf16 is not supported")
def test_fp16_exp_avg(self):
self.gen_precision_aware_test(
use_fp8_params=False,
......@@ -353,7 +350,7 @@ class TestFusedAdam(TestFusedOptimizer):
master_atol=2e-3,
)
@pytest.mark.skipif(not is_bf16_compatible(), reason="bf16 is not supported")
@pytest.mark.skipif(not is_bf16_available(), reason="bf16 is not supported")
def test_bf16_exp_avg(self):
self.gen_precision_aware_test(
use_fp8_params=False,
......@@ -367,7 +364,7 @@ class TestFusedAdam(TestFusedOptimizer):
master_atol=2e-3,
)
@pytest.mark.skipif(not is_bf16_compatible(), reason="bf16 is not supported")
@pytest.mark.skipif(not is_bf16_available(), reason="bf16 is not supported")
@pytest.mark.skipif(not fp8_available, reason=reason_for_no_fp8)
def test_fp8_exp_avg(self):
self.gen_precision_aware_test(
......@@ -382,7 +379,7 @@ class TestFusedAdam(TestFusedOptimizer):
master_atol=1e-2,
)
@pytest.mark.skipif(not is_bf16_compatible(), reason="bf16 is not supported")
@pytest.mark.skipif(not is_bf16_available(), reason="bf16 is not supported")
def test_fp16_exp_avg_sq(self):
self.gen_precision_aware_test(
use_fp8_params=False,
......@@ -396,7 +393,7 @@ class TestFusedAdam(TestFusedOptimizer):
master_atol=2e-3,
)
@pytest.mark.skipif(not is_bf16_compatible(), reason="bf16 is not supported")
@pytest.mark.skipif(not is_bf16_available(), reason="bf16 is not supported")
def test_bf16_exp_avg_sq(self):
self.gen_precision_aware_test(
use_fp8_params=False,
......@@ -410,7 +407,7 @@ class TestFusedAdam(TestFusedOptimizer):
master_atol=2e-3,
)
@pytest.mark.skipif(not is_bf16_compatible(), reason="bf16 is not supported")
@pytest.mark.skipif(not is_bf16_available(), reason="bf16 is not supported")
@pytest.mark.skipif(not fp8_available, reason=reason_for_no_fp8)
def test_fp8_exp_avg_sq(self):
self.gen_precision_aware_test(
......@@ -424,7 +421,7 @@ class TestFusedAdam(TestFusedOptimizer):
skip_assert=True,
)
@pytest.mark.skipif(not is_bf16_compatible(), reason="bf16 is not supported")
@pytest.mark.skipif(not is_bf16_available(), reason="bf16 is not supported")
def test_bf16_model_weight_cast(self):
dtype = torch.bfloat16
model = MultiheadAttention(
......@@ -468,7 +465,7 @@ class TestFusedAdam(TestFusedOptimizer):
@pytest.mark.skipif(not fp8_available, reason=reason_for_no_fp8)
def test_fp8_model_weight_cast(self):
dtype = torch.bfloat16
with fp8_model_init(enabled=True, recipe=DelayedScaling()):
with quantized_model_init(enabled=True, recipe=DelayedScaling()):
model = MultiheadAttention(
hidden_size=1024,
num_attention_heads=16,
......
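In the fused-optimizer tests above, fp8_model_init becomes quantized_model_init and is_bf16_compatible becomes is_bf16_available. A hedged sketch of the construction pattern those tests use; the hidden size and head count mirror the values visible in the hunk, and running it requires FP8 support.

from transformer_engine.common.recipe import DelayedScaling
from transformer_engine.pytorch import MultiheadAttention, quantized_model_init

# Parameters created inside this context are stored in the quantized format
# selected by the recipe, rather than being cast on the fly at compute time.
with quantized_model_init(enabled=True, recipe=DelayedScaling()):
    model = MultiheadAttention(
        hidden_size=1024,
        num_attention_heads=16,
    )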
......@@ -373,3 +373,19 @@ def test_fused_qkv_rope(
if not isinstance(start_positions, torch.Tensor):
torch.testing.assert_close(grad_fused, grad_unfused)
def test_rotary_position_embedding_forward_with_autocast_gives_same_result_as_without_autocast():
rope_layer = RotaryPositionEmbedding(128)
rope_embeddings_no_autocast = rope_layer(max_seq_len=1024)
with torch.autocast(device_type="cuda", dtype=torch.bfloat16):
rope_embeddings_autocast = rope_layer(max_seq_len=1024)
torch.testing.assert_close(
rope_embeddings_no_autocast.to(dtype=torch.bfloat16),
rope_embeddings_autocast.to(dtype=torch.bfloat16),
atol=1e-8,
rtol=1e-8,
)
......@@ -7,8 +7,6 @@ from __future__ import annotations
from collections.abc import Iterable
import io
import math
import pathlib
import sys
from typing import Optional
import pytest
......@@ -18,7 +16,6 @@ from torch.utils.cpp_extension import IS_HIP_EXTENSION
import transformer_engine
import transformer_engine.common.recipe
import transformer_engine.pytorch as te
from transformer_engine.pytorch.fp8 import FP8GlobalStateManager
import transformer_engine.pytorch.ops as te_ops
from transformer_engine.pytorch.ops.fused import (
BackwardActivationBias,
......@@ -29,20 +26,18 @@ from transformer_engine.pytorch.ops.fused import (
ForwardLinearBiasAdd,
ForwardLinearScaleAdd,
)
from transformer_engine.pytorch.tensor import QuantizedTensor
from transformer_engine.pytorch.tensor.float8_tensor import (
Float8Tensor,
from transformer_engine.pytorch import (
QuantizedTensor,
Float8CurrentScalingQuantizer,
Float8Quantizer,
MXFP8Quantizer,
NVFP4Quantizer,
is_bf16_available,
)
from transformer_engine.pytorch.tensor.mxfp8_tensor import MXFP8Tensor, MXFP8Quantizer
from transformer_engine.pytorch.utils import is_bf16_compatible
import transformer_engine_torch as tex
# Import utility functions
_current_file = pathlib.Path(__file__).resolve()
sys.path.append(str(_current_file.parent))
from utils import dtype_tols, make_recipe, reset_rng_states
from utils import dtype_tols, make_recipe, quantization_tols, reset_rng_states
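The test utilities now provide quantization_tols next to dtype_tols; later hunks call it as quantization_tols(quantization) and splat the result into torch.testing.assert_close. Its implementation is not part of this diff, so the sketch below is only a hypothetical stand-in for such a scheme-to-tolerance mapping; the numeric values are assumptions, not the real ones.

# Hypothetical illustration only; the real helper lives in the local test
# utils module and may pick different tolerances per recipe.
def quantization_tols(quantization: str) -> dict:
    """Map a quantization scheme name to assert_close keyword tolerances."""
    if quantization in ("fp8", "fp8_delayed_scaling", "fp8_current_scaling", "mxfp8"):
        return {"rtol": 0.125, "atol": 0.0675}  # roughly FP8 E4M3 resolution
    if quantization == "nvfp4":
        return {"rtol": 0.25, "atol": 0.25}  # FP4 is far coarser
    raise ValueError(f"Unsupported quantization scheme ({quantization})")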
if IS_HIP_EXTENSION:
import os
......@@ -52,13 +47,14 @@ if IS_HIP_EXTENSION:
return (os.getenv("NVTE_USE_HIPBLASLT") is not None
or os.getenv("NVTE_USE_ROCBLAS") is None )
# Check if FP8 is supported
fp8_available, reason_for_no_fp8 = FP8GlobalStateManager.is_fp8_available()
mxfp8_available, reason_for_no_mxfp8 = FP8GlobalStateManager.is_mxfp8_available()
# Check for supported quantization schemes
fp8_available, reason_for_no_fp8 = te.is_fp8_available(return_reason=True)
mxfp8_available, reason_for_no_mxfp8 = te.is_mxfp8_available(return_reason=True)
nvfp4_available, reason_for_no_nvfp4 = te.is_nvfp4_available(return_reason=True)
# Supported data types
_dtypes: list[torch.dtype] = [torch.float32, torch.float16]
if is_bf16_compatible(): # bf16 requires sm_80 or higher
if is_bf16_available(): # bf16 requires sm_80 or higher
_dtypes.append(torch.bfloat16)
# Supported devices
......@@ -70,6 +66,8 @@ if fp8_available:
_quantization_list.extend(("fp8_delayed_scaling", "fp8_current_scaling"))
if mxfp8_available:
_quantization_list.append("mxfp8")
if nvfp4_available:
_quantization_list.append("nvfp4")
def maybe_skip_quantization(
......@@ -77,6 +75,7 @@ def maybe_skip_quantization(
*,
dims: Optional[Iterable[int] | int] = None,
device: Optional[torch.device | str] = None,
dtype: Optional[torch.dtype] = None,
) -> None:
"""Skip test case if a quantization scheme is not supported"""
......@@ -84,12 +83,17 @@ def maybe_skip_quantization(
if quantization is None:
return
# Check if quantization scheme is supported
# Check if quantization scheme is supported on device
if device is not None and torch.device(device).type != "cuda":
pytest.skip("Quantization is only supported on CUDA devices")
if quantization in ("fp8", "fp8_delayed_scaling", "fp8_current_scaling") and not fp8_available:
pytest.skip(reason_for_no_fp8)
if quantization == "mxfp8" and not mxfp8_available:
pytest.skip(reason_for_no_mxfp8)
if quantization == "nvfp4" and not nvfp4_available:
pytest.skip(reason_for_no_nvfp4)
# Check dims
if dims is not None:
if not isinstance(dims, Iterable):
dims = (dims,)
......@@ -99,10 +103,14 @@ def maybe_skip_quantization(
elif quantization == "mxfp8":
if math.prod(dims[:-1]) % 32 != 0 or dims[-1] % 32 != 0:
pytest.skip("MXFP8 GEMMs require dims that are divisible by 32")
elif quantization == "nvfp4":
if math.prod(dims[:-1]) % 16 != 0 or dims[-1] % 16 != 0:
pytest.skip("NVFP4 GEMMs require dims that are divisible by 16")
# Check if device is supported
if device is not None and torch.device(device).type != "cuda":
pytest.skip("Quantization is only supported on CUDA devices")
# Check dtype
if dtype is not None:
if quantization == "nvfp4" and dtype != torch.bfloat16:
pytest.skip("NVFP4 quantization is only supported with BF16 data")
@torch.no_grad()
......@@ -152,6 +160,14 @@ def make_reference_and_test_tensors(
test = quantizer(test)
elif quantization == "mxfp8":
test = MXFP8Quantizer(fp8_dtype=tex.DType.kFloat8E4M3)(test)
elif quantization == "nvfp4":
test = NVFP4Quantizer(
with_rht=False,
with_post_rht_amax=False,
with_2d_quantization=False,
stochastic_rounding=False,
with_random_sign_mask=False,
)(test)
else:
raise ValueError(f"Unsupported quantization scheme ({quantization})")
if isinstance(test, QuantizedTensor) and not test_is_quantized:
......@@ -361,7 +377,7 @@ class TestFuser:
)
# Construct model
with te.fp8_model_init(recipe=recipe):
with te.quantized_model_init(recipe=recipe):
model = te_ops.basic.BasicLinear(
size,
size,
......@@ -393,7 +409,7 @@ class TestFuser:
)
# Training step
with te.fp8_autocast(fp8_recipe=recipe):
with te.autocast(recipe=recipe):
y = model(x)
y.backward(dy)
with torch.no_grad():
......@@ -406,12 +422,12 @@ class TestFuser:
torch.testing.assert_close(
y,
torch.full_like(y, y_val_ref),
**dtype_tols(tex.DType.kFloat8E4M3),
**quantization_tols("fp8_delayed_scaling"),
)
torch.testing.assert_close(
x.grad,
torch.full_like(x.grad, dx_val_ref),
**dtype_tols(tex.DType.kFloat8E5M2),
**quantization_tols("fp8_delayed_scaling"),
)
# Check that scaling factors match expected
......@@ -445,7 +461,8 @@ class TestFuser:
# Skip invalid configurations
in_shape = (size, size)
with_quantization = quantization is not None
maybe_skip_quantization(quantization, dims=in_shape, device=device)
maybe_skip_quantization(quantization, dims=in_shape, device=device, dtype=init_dtype)
maybe_skip_quantization(quantization, dtype=final_dtype)
# Random data
dtype = torch.float32
......@@ -461,7 +478,7 @@ class TestFuser:
)
# Construct operation
with te.fp8_model_init(enabled=with_quantization, recipe=make_recipe(quantization)):
with te.quantized_model_init(enabled=with_quantization, recipe=make_recipe(quantization)):
op = te_ops.Linear(size, size, bias=False, device=device, dtype=init_dtype)
with torch.no_grad():
op.weight.copy_(w_test)
......@@ -513,11 +530,12 @@ class TestFuser:
# Skip invalid configurations
in_shape = (size, size)
quantized_compute = quantization is not None
maybe_skip_quantization(quantization, dims=in_shape, device=device)
maybe_skip_quantization(quantization, dims=in_shape, device=device, dtype=model_dtype)
maybe_skip_quantization(quantization, dtype=autocast_dtype)
# Construct operation
recipe = make_recipe(quantization)
with te.fp8_model_init(enabled=quantized_weights, recipe=recipe):
with te.quantized_model_init(enabled=quantized_weights, recipe=recipe):
op = te_ops.Linear(size, size, bias=False, device=device, dtype=model_dtype)
# Check forward and backward pass
......@@ -527,7 +545,7 @@ class TestFuser:
device=device,
requires_grad=True,
)
with te.fp8_autocast(enabled=quantized_compute, fp8_recipe=recipe):
with te.autocast(enabled=quantized_compute, recipe=recipe):
with torch.autocast(device_type=device.type, dtype=autocast_dtype):
y = op(x)
y.backward(torch.zeros_like(y))
......@@ -540,7 +558,7 @@ class TestFuser:
x.grad = None
op.weight.grad = None
with torch.autocast(device_type=device.type, dtype=autocast_dtype):
with te.fp8_autocast(enabled=quantized_compute, fp8_recipe=recipe):
with te.autocast(enabled=quantized_compute, recipe=recipe):
y = op(x)
y.backward(torch.zeros_like(y))
assert y.dtype == autocast_dtype
......@@ -569,7 +587,7 @@ class TestBasicOps:
# Skip invalid configurations
with_quantization = quantization is not None
maybe_skip_quantization(quantization, dims=in_shape, device=device)
maybe_skip_quantization(quantization, dims=in_shape, device=device, dtype=dtype)
# Random data
x_ref, x_test = make_reference_and_test_tensors(
......@@ -635,7 +653,7 @@ class TestBasicOps:
# Skip invalid configurations
if memory_format == torch.channels_last and len(in_shape) != 4:
pytest.skip("torch.channels_last only supports 4D tensors")
maybe_skip_quantization(quantization, device=device)
maybe_skip_quantization(quantization, device=device, dtype=dtype)
with_quantization = quantization is not None
# Random data
......@@ -701,7 +719,7 @@ class TestBasicOps:
# Skip invalid configurations
with_quantization = quantization is not None
maybe_skip_quantization(quantization, dims=in_shape, device=device)
maybe_skip_quantization(quantization, dims=in_shape, device=device, dtype=dtype)
# Random data
x_ref, x_test = make_reference_and_test_tensors(
......@@ -763,7 +781,7 @@ class TestBasicOps:
# Skip invalid configurations
with_quantization = quantization is not None
maybe_skip_quantization(quantization, device=device)
maybe_skip_quantization(quantization, device=device, dtype=dtype)
if quantization == "mxfp8":
maybe_skip_quantization(quantization, dims=in_shape)
......@@ -790,7 +808,7 @@ class TestBasicOps:
# Implementation with fusible operation
op = te_ops.Quantize(forward=cast_forward, backward=cast_backward)
recipe = make_recipe(quantization)
with te.fp8_autocast(enabled=with_quantization, fp8_recipe=recipe):
with te.autocast(enabled=with_quantization, recipe=recipe):
y_test = op(x_test)
y_test.backward(dy_test)
......@@ -830,7 +848,7 @@ class TestBasicOps:
out_shape = in_shape[:-1] + [out_features]
# Skip invalid configurations
maybe_skip_quantization(quantization, dims=in_shape, device=device)
maybe_skip_quantization(quantization, dims=in_shape, device=device, dtype=dtype)
maybe_skip_quantization(quantization, dims=out_shape)
quantization_needed = any(
(
......@@ -887,7 +905,7 @@ class TestBasicOps:
# Implementation with fusible operation
recipe = make_recipe(quantization)
with te.fp8_model_init(enabled=quantized_weight, recipe=recipe):
with te.quantized_model_init(enabled=quantized_weight, recipe=recipe):
op = te_ops.BasicLinear(
in_features,
out_features,
......@@ -904,7 +922,7 @@ class TestBasicOps:
op,
te_ops.Quantize(forward=quantized_output, backward=quantized_grad_output),
)
with te.fp8_autocast(enabled=quantized_compute, fp8_recipe=recipe):
with te.autocast(enabled=quantized_compute, recipe=recipe):
y_test = forward(x_test)
y_test.backward(dy_test)
......@@ -913,7 +931,7 @@ class TestBasicOps:
if dtype == torch.float32:
tols = dtype_tols(torch.float16) # TF32 GEMM
if quantized_compute or quantized_output or quantized_grad_input:
tols = dtype_tols(tex.DType.kFloat8E4M3)
tols = quantization_tols(quantization)
# Check results
y_test = y_test.to(dtype=torch.float64, device="cpu")
......@@ -1024,7 +1042,7 @@ class TestBasicOps:
out_shape = in_shape[:-1] + [out_features]
# Skip invalid configurations
maybe_skip_quantization(quantization, dims=in_shape, device=device)
maybe_skip_quantization(quantization, dims=in_shape, device=device, dtype=dtype)
maybe_skip_quantization(quantization, dims=out_shape)
if quantization is None and (quantized_compute or quantized_weight):
pytest.skip("Quantization scheme is not specified")
......@@ -1065,7 +1083,7 @@ class TestBasicOps:
# Implementation with fusible operation
recipe = make_recipe(quantization)
with te.fp8_model_init(enabled=quantized_weight, recipe=recipe):
with te.quantized_model_init(enabled=quantized_weight, recipe=recipe):
op = te_ops.Linear(
in_features,
out_features,
......@@ -1081,7 +1099,7 @@ class TestBasicOps:
del b_test
for param in op.parameters():
param.requires_grad_(requires_grad=weight_requires_grad)
with te.fp8_autocast(enabled=quantized_compute, fp8_recipe=recipe):
with te.autocast(enabled=quantized_compute, recipe=recipe):
y_test = op(x_test)
if input_requires_grad or weight_requires_grad:
y_test.backward(dy_test)
......@@ -1091,7 +1109,7 @@ class TestBasicOps:
if dtype == torch.float32:
tols = dtype_tols(torch.float16) # TF32 GEMM
if quantized_compute:
tols = dtype_tols(tex.DType.kFloat8E4M3)
tols = quantization_tols(quantization)
# Check results
y_test = y_test.to(dtype=torch.float64, device="cpu")
......@@ -1128,7 +1146,7 @@ class TestBasicOps:
in_shape = list(in_shape)[:-1] + list(weight_shape)
# Skip invalid configurations
maybe_skip_quantization(quantization, dims=in_shape, device=device)
maybe_skip_quantization(quantization, dims=in_shape, device=device, dtype=dtype)
# Random data
x_ref, x_test = make_reference_and_test_tensors(
......@@ -1182,14 +1200,14 @@ class TestBasicOps:
op,
te_ops.Quantize(forward=quantized_compute, backward=False),
)
with te.fp8_autocast(enabled=quantized_compute, fp8_recipe=recipe):
with te.autocast(enabled=quantized_compute, recipe=recipe):
y_test = forward(x_test)
y_test.backward(dy_test)
# Expected numerical error
tols = dtype_tols(dtype)
if quantized_compute:
tols = dtype_tols(tex.DType.kFloat8E4M3)
tols = quantization_tols(quantization)
# Check results
y_test = y_test.to(dtype=torch.float64, device="cpu")
......@@ -1298,7 +1316,7 @@ class TestBasicOps:
in_shape = list(in_shape)[:-1] + list(weight_shape)
# Skip invalid configurations
maybe_skip_quantization(quantization, dims=in_shape, device=device)
maybe_skip_quantization(quantization, dims=in_shape, device=device, dtype=dtype)
# Random data
x_ref, x_test = make_reference_and_test_tensors(
......@@ -1344,14 +1362,14 @@ class TestBasicOps:
op,
te_ops.Quantize(forward=quantized_compute, backward=False),
)
with te.fp8_autocast(enabled=quantized_compute, fp8_recipe=recipe):
with te.autocast(enabled=quantized_compute, recipe=recipe):
y_test = forward(x_test)
y_test.backward(dy_test)
# Expected numerical error
tols = dtype_tols(dtype)
if quantized_compute:
tols = dtype_tols(tex.DType.kFloat8E4M3)
tols = quantization_tols(quantization)
# Check results
y_test = y_test.to(dtype=torch.float64, device="cpu")
......@@ -1431,7 +1449,7 @@ class TestBasicOps:
# Skip invalid configurations
with_quantization = quantization is not None
maybe_skip_quantization(quantization, dims=in_shape, device=device)
maybe_skip_quantization(quantization, dims=in_shape, device=device, dtype=dtype)
# Random data
x1_ref, x1_test = make_reference_and_test_tensors(
......@@ -1470,8 +1488,11 @@ class TestBasicOps:
# Check results
tols = dtype_tols(dtype)
if with_quantization:
tols = dtype_tols(x1_test._fp8_dtype)
if in_place:
if quantization in ("fp8_delayed_scaling", "fp8_current_scaling", "mxfp8"):
tols = dtype_tols(x1_test._fp8_dtype)
elif quantization == "nvfp4":
tols = dtype_tols(x1_test._fp4_dtype)
y_test = y_test.to(dtype=torch.float64, device="cpu")
dx1_test = x1_test.grad.to(dtype=torch.float64, device="cpu")
dx2_test = x2_test.grad.to(dtype=torch.float64, device="cpu")
......@@ -1500,7 +1521,7 @@ class TestBasicOps:
# Skip invalid configurations
with_quantization = quantization is not None
maybe_skip_quantization(quantization, dims=in_shape, device=device)
maybe_skip_quantization(quantization, dims=in_shape, device=device, dtype=dtype)
# Random data
x_ref, x_test = make_reference_and_test_tensors(
......@@ -1573,7 +1594,7 @@ class TestBasicOps:
# Skip invalid configurations
quantized_compute = quantization is not None
maybe_skip_quantization(quantization, dims=in_shape, device=device)
maybe_skip_quantization(quantization, dims=in_shape, device=device, dtype=dtype)
if cache_quantized_input:
maybe_skip_quantization("fp8_current_scaling", device=device)
......@@ -1641,14 +1662,16 @@ class TestBasicOps:
make_op(cache_quantized_input=cache_quantized_input),
te_ops.Quantize(forward=quantized_compute, backward=False),
)
with te.fp8_autocast(enabled=quantized_compute, fp8_recipe=recipe):
with te.autocast(enabled=quantized_compute, recipe=recipe):
y_test = forward(x_test)
y_test.backward(dy_test)
# Expected numerical error
tols = dtype_tols(dtype)
if quantized_compute or cache_quantized_input:
tols = dtype_tols(tex.DType.kFloat8E4M3)
if quantized_compute:
tols = quantization_tols(quantization)
elif cache_quantized_input:
tols = quantization_tols("fp8_current_scaling")
# Check results
y_test = y_test.to(dtype=torch.float64, device="cpu")
......@@ -1679,7 +1702,7 @@ class TestBasicOps:
quantized_compute = quantization is not None
if not quantized_compute and (quantize_forward or quantize_backward):
pytest.skip("Quantization scheme has not been provided")
maybe_skip_quantization(quantization, dims=in_shape, device=device)
maybe_skip_quantization(quantization, dims=in_shape, device=device, dtype=dtype)
# Random data
x_ref, x_test = make_reference_and_test_tensors(
......@@ -1706,13 +1729,87 @@ class TestBasicOps:
te_ops.SwiGLU(),
te_ops.Quantize(forward=quantize_forward, backward=False),
)
with te.fp8_autocast(enabled=quantized_compute, fp8_recipe=recipe):
with te.autocast(enabled=quantized_compute, recipe=recipe):
y_test = forward(x_test)
y_test.backward(dy_test)
# Expected numerical error
tols = dtype_tols(dtype)
if quantized_compute:
tols = quantization_tols(quantization)
# Check results
y_test = y_test.to(dtype=torch.float64, device="cpu")
dx_test = x_test.grad.to(dtype=torch.float64, device="cpu")
torch.testing.assert_close(y_test, y_ref, **tols)
torch.testing.assert_close(dx_test, x_ref.grad, **tols)
@pytest.mark.parametrize("dtype", _dtypes)
@pytest.mark.parametrize("quantization", _quantization_list)
@pytest.mark.parametrize("quantize_forward", (False, True))
@pytest.mark.parametrize("quantize_backward", (False, True))
def test_clamped_swiglu(
self,
*,
out_shape: Iterable[int] = (32, 32),
dtype: torch.dtype,
device: torch.device = "cuda",
quantization: Optional[str],
quantize_forward: bool,
quantize_backward: bool,
limit: float = 0.75,
alpha: float = 1.702,
):
# Test SwiGLU variant used in GPT OSS.
# Tensor dimensions
in_shape = list(out_shape)
in_shape[-1] *= 2
# Skip invalid configurations
quantized_compute = quantization is not None
if not quantized_compute and (quantize_forward or quantize_backward):
pytest.skip("Quantization scheme has not been provided")
maybe_skip_quantization(quantization, dims=in_shape, device=device)
# Random data
x_ref, x_test = make_reference_and_test_tensors(
in_shape,
test_dtype=dtype,
test_device=device,
)
dy_ref, dy_test = make_reference_and_test_tensors(
out_shape,
test_dtype=dtype,
test_device=device,
requires_grad=False,
)
# Plain PyTorch implementation
x_glu, x_linear = x_ref.chunk(2, dim=-1)
x_glu = x_glu.clamp(min=None, max=limit)
x_linear = x_linear.clamp(min=-limit, max=limit)
out_glu = x_glu * torch.sigmoid(alpha * x_glu)
y_ref = out_glu * (x_linear + 1)
y_ref.backward(dy_ref)
# Implementation with fusible operation
recipe = make_recipe(quantization)
forward = te_ops.Sequential(
te_ops.Quantize(forward=False, backward=quantize_backward),
te_ops.ClampedSwiGLU(limit=limit, alpha=alpha),
te_ops.Quantize(forward=quantize_forward, backward=False),
)
with te.autocast(enabled=quantized_compute, recipe=recipe):
y_test = forward(x_test)
y_test.backward(dy_test)
# Expected numerical error
tols = dtype_tols(dtype)
if quantized_compute and quantization == "nvfp4":
tols = dtype_tols(tex.DType.kFloat4E2M1)
elif quantized_compute:
tols = dtype_tols(tex.DType.kFloat8E4M3)
# Check results
......@@ -1781,7 +1878,7 @@ class TestBasicOps:
# Skip invalid configurations
quantized_input = quantization is not None
maybe_skip_quantization(quantization, dims=shape, device=device)
maybe_skip_quantization(quantization, dims=shape, device=device, dtype=dtype)
# Random data
# Note: Shift values to make sure inputs are non-zero
......@@ -1872,7 +1969,7 @@ class TestFusedOps:
# Skip invalid configurations
quantized_compute = quantization is not None
maybe_skip_quantization(quantization, dims=in_shape, device=device)
maybe_skip_quantization(quantization, dims=in_shape, device=device, dtype=dtype)
maybe_skip_quantization(quantization, dims=out_shape)
if dtype not in (torch.float16, torch.bfloat16):
pytest.skip(
......@@ -1913,7 +2010,7 @@ class TestFusedOps:
# Implementation with fusible operations
recipe = make_recipe(quantization)
with te.fp8_model_init(enabled=quantized_compute, recipe=recipe):
with te.quantized_model_init(enabled=quantized_compute, recipe=recipe):
model = te_ops.Sequential(
te_ops.Linear(
in_features,
......@@ -1929,7 +2026,7 @@ class TestFusedOps:
model[0].bias.copy_(b_test)
del w_test
del b_test
with te.fp8_autocast(enabled=quantized_compute, fp8_recipe=recipe):
with te.autocast(enabled=quantized_compute, recipe=recipe):
y_test = model(x_test)
y_test.backward(dy_test)
......@@ -1943,7 +2040,7 @@ class TestFusedOps:
if dtype == torch.float32:
tols = dtype_tols(torch.float16) # TF32 GEMM
if quantized_compute:
tols = dtype_tols(tex.DType.kFloat8E4M3)
tols = quantization_tols(quantization)
# Check results
y_test = y_test.to(dtype=torch.float64, device="cpu")
......@@ -1979,7 +2076,7 @@ class TestFusedOps:
# Skip invalid configurations
quantized_compute = quantization is not None
maybe_skip_quantization(quantization, dims=in_shape, device=device)
maybe_skip_quantization(quantization, dims=in_shape, device=device, dtype=dtype)
maybe_skip_quantization(quantization, dims=out_shape)
if quantized_compute and dtype not in (torch.float16, torch.bfloat16):
pytest.skip("FP8 GEMM is only supported with FP8, FP16, or BF16 output")
......@@ -2023,7 +2120,7 @@ class TestFusedOps:
# Implementation with fusible operations
recipe = make_recipe(quantization)
with te.fp8_model_init(enabled=quantized_weight, recipe=recipe):
with te.quantized_model_init(enabled=quantized_weight, recipe=recipe):
model = te_ops.Sequential(
te_ops.Linear(
in_features,
......@@ -2040,7 +2137,7 @@ class TestFusedOps:
model[0].bias.copy_(b_test)
del w_test
del b_test
with te.fp8_autocast(enabled=quantized_compute, fp8_recipe=recipe):
with te.autocast(enabled=quantized_compute, recipe=recipe):
y_test = model(x1_test, x2_test)
y_test.backward(dy_test)
......@@ -2054,7 +2151,7 @@ class TestFusedOps:
if dtype == torch.float32:
tols = dtype_tols(torch.float16) # TF32 GEMM
if quantized_compute:
tols = dtype_tols(tex.DType.kFloat8E4M3)
tols = quantization_tols(quantization)
# Check results
y_test = y_test.to(dtype=torch.float64, device="cpu")
......@@ -2093,7 +2190,7 @@ class TestFusedOps:
# Skip invalid configurations
quantized_compute = quantization is not None
maybe_skip_quantization(quantization, dims=in_shape, device=device)
maybe_skip_quantization(quantization, dims=in_shape, device=device, dtype=dtype)
maybe_skip_quantization(quantization, dims=out_shape)
if quantized_compute and dtype not in (torch.float16, torch.bfloat16):
pytest.skip("FP8 GEMM is only supported with FP8, FP16, or BF16 output")
......@@ -2130,7 +2227,7 @@ class TestFusedOps:
# Implementation with fusible operations
recipe = make_recipe(quantization)
with te.fp8_model_init(enabled=quantized_weight, recipe=recipe):
with te.quantized_model_init(enabled=quantized_weight, recipe=recipe):
model = te_ops.Sequential(
te_ops.Linear(
in_features,
......@@ -2146,7 +2243,7 @@ class TestFusedOps:
with torch.no_grad():
model[0].weight.copy_(w_test)
del w_test
with te.fp8_autocast(enabled=quantized_compute, fp8_recipe=recipe):
with te.autocast(enabled=quantized_compute, recipe=recipe):
y_test = model(x1_test, x2_test)
y_test.backward(dy_test)
......@@ -2161,7 +2258,7 @@ class TestFusedOps:
if dtype == torch.float32:
tols = dtype_tols(torch.float16) # TF32 GEMM
if quantized_compute:
tols = dtype_tols(tex.DType.kFloat8E4M3)
tols = quantization_tols(quantization)
# Check results
y_test = y_test.to(dtype=torch.float64, device="cpu")
......@@ -2194,7 +2291,7 @@ class TestFusedOps:
# Skip invalid configurations
with_quantization = quantization is not None
maybe_skip_quantization(quantization, device=device)
maybe_skip_quantization(quantization, device=device, dtype=dtype)
if quantization == "mxfp8" and (len(in_shape) < 2 or in_shape[-1] % 32 != 0):
pytest.skip("Unsupported tensor size for MXFP8")
......@@ -2237,7 +2334,7 @@ class TestFusedOps:
with torch.no_grad():
model[1].bias.copy_(b_test)
del b_test
with te.fp8_autocast(enabled=with_quantization, fp8_recipe=recipe):
with te.autocast(enabled=with_quantization, recipe=recipe):
y_test = model(x_test)
y_test.backward(dy_test)
......@@ -2256,7 +2353,7 @@ class TestFusedOps:
# Expected numerical error
tols = dtype_tols(dtype)
if with_quantization:
tols = dtype_tols(tex.DType.kFloat8E4M3)
tols = quantization_tols(quantization)
# Check results
y_test = y_test.to(dtype=torch.float64, device="cpu")
......@@ -2375,7 +2472,7 @@ class TestFusedOps:
# Skip invalid configurations
quantized_compute = quantization is not None
maybe_skip_quantization(quantization, dims=in_shape, device=device)
maybe_skip_quantization(quantization, dims=in_shape, device=device, dtype=dtype)
maybe_skip_quantization(quantization, dims=out_shape)
if quantized_compute and dtype not in (torch.float16, torch.bfloat16):
pytest.skip("FP8 GEMM is only supported with FP8, FP16, or BF16 output")
......@@ -2415,7 +2512,7 @@ class TestFusedOps:
# Implementation with fusible operations
recipe = make_recipe(quantization)
with te.fp8_model_init(enabled=quantized_weight):
with te.quantized_model_init(enabled=quantized_weight):
model = te_ops.Sequential(
te_ops.MakeExtraOutput(in_place=True),
te_ops.Linear(
......@@ -2429,7 +2526,7 @@ class TestFusedOps:
with torch.no_grad():
model[1].weight.copy_(w_test)
del w_test
with te.fp8_autocast(enabled=quantized_compute, fp8_recipe=recipe):
with te.autocast(enabled=quantized_compute, recipe=recipe):
y1_test, y2_test = model(x_test)
(y1_test * dy1_test + y2_test * dy2_test).sum().backward()
......@@ -2443,7 +2540,7 @@ class TestFusedOps:
if dtype == torch.float32:
tols = dtype_tols(torch.float16) # TF32 GEMM
if quantized_compute:
tols = dtype_tols(tex.DType.kFloat8E4M3)
tols = quantization_tols(quantization)
# Check results
y1_test = y1_test.to(dtype=torch.float64, device="cpu")
......@@ -2479,7 +2576,7 @@ class TestFusedOps:
# Skip invalid configurations
quantized_compute = quantization is not None
maybe_skip_quantization(quantization, dims=in_shape, device=device)
maybe_skip_quantization(quantization, dims=in_shape, device=device, dtype=dtype)
maybe_skip_quantization(quantization, dims=out_shape)
if quantized_compute and dtype not in (torch.float16, torch.bfloat16):
pytest.skip("FP8 GEMM is only supported with FP8, FP16, or BF16 output")
......@@ -2511,7 +2608,7 @@ class TestFusedOps:
# Implementation with fusible operations
recipe = make_recipe(quantization)
with te.fp8_model_init(enabled=quantized_weight):
with te.quantized_model_init(enabled=quantized_weight):
model = te_ops.Sequential(
te_ops.Linear(
in_features,
......@@ -2525,7 +2622,7 @@ class TestFusedOps:
with torch.no_grad():
model[0].weight.copy_(w_test)
del w_test
with te.fp8_autocast(enabled=quantized_compute, fp8_recipe=recipe):
with te.autocast(enabled=quantized_compute, recipe=recipe):
y_test = model(x_test)
(y_test * dy_test).sum().backward()
......@@ -2539,7 +2636,7 @@ class TestFusedOps:
if dtype == torch.float32:
tols = dtype_tols(torch.float16) # TF32 GEMM
if quantized_compute:
tols = dtype_tols(tex.DType.kFloat8E4M3)
tols = quantization_tols(quantization)
# Check results
y_test = y_test.to(dtype=torch.float64, device="cpu")
......@@ -2580,12 +2677,12 @@ class TestCheckpointing:
# Skip invalid configurations
quantized_compute = quantization is not None
maybe_skip_quantization(quantization, dims=in_shape, device=device)
maybe_skip_quantization(quantization, dims=in_shape, device=device, dtype=dtype)
maybe_skip_quantization(quantization, dims=out_shape)
# Construct model
recipe = make_recipe(quantization)
with te.fp8_model_init(enabled=quantized_weight, recipe=recipe):
with te.quantized_model_init(enabled=quantized_weight, recipe=recipe):
model_save = te_ops.Sequential(
te_ops.Linear(in_features, out_features, device=device, dtype=dtype)
)
......@@ -2596,7 +2693,7 @@ class TestCheckpointing:
x = torch.randn(in_shape, dtype=dtype, device=device, requires_grad=True)
dy = torch.randn(out_shape, dtype=dtype, device=device)
optim_save.zero_grad()
with te.fp8_autocast(enabled=quantized_compute, fp8_recipe=recipe):
with te.autocast(enabled=quantized_compute, recipe=recipe):
y = model_save(x)
y.backward(dy)
optim_save.step()
......@@ -2625,14 +2722,14 @@ class TestCheckpointing:
ys_save = []
for i in range(post_checkpoint_steps):
optim_save.zero_grad()
with te.fp8_autocast(enabled=quantized_compute, fp8_recipe=recipe):
with te.autocast(enabled=quantized_compute, recipe=recipe):
y = model_save(xs_save[i])
y.backward(dys[i])
optim_save.step()
ys_save.append(y)
# Load checkpoint
with te.fp8_model_init(enabled=quantized_weight, recipe=recipe):
with te.quantized_model_init(enabled=quantized_weight, recipe=recipe):
model_load = te_ops.Sequential(
te_ops.Linear(in_features, out_features, device=device, dtype=dtype)
)
......@@ -2645,7 +2742,7 @@ class TestCheckpointing:
ys_load = []
for i in range(post_checkpoint_steps):
optim_load.zero_grad()
with te.fp8_autocast(enabled=quantized_compute, fp8_recipe=recipe):
with te.autocast(enabled=quantized_compute, recipe=recipe):
y = model_load(xs_load[i])
y.backward(dys[i])
optim_load.step()
......@@ -2706,7 +2803,7 @@ class TestSequentialModules:
ffn_shape = in_shape[:-1] + (ffn_hidden_size,)
# Skip invalid configurations
maybe_skip_quantization(quantization, dims=in_shape, device=device)
maybe_skip_quantization(quantization, dims=in_shape, device=device, dtype=dtype)
maybe_skip_quantization(quantization, dims=ffn_shape, device=device)
quantization_needed = quantized_compute or quantized_weight
if quantization is None and quantization_needed:
......@@ -2732,7 +2829,7 @@ class TestSequentialModules:
# Implementation with fusible operations
recipe = make_recipe(quantization)
with te.fp8_model_init(enabled=quantized_weight, recipe=recipe):
with te.quantized_model_init(enabled=quantized_weight, recipe=recipe):
if normalization == "LayerNorm":
norm = te_ops.LayerNorm(
hidden_size,
......@@ -2763,6 +2860,6 @@ class TestSequentialModules:
dtype=dtype,
)
forward = te_ops.Sequential(norm, ffn1, act, ffn2)
with te.fp8_autocast(enabled=quantized_compute, fp8_recipe=recipe):
with te.autocast(enabled=quantized_compute, recipe=recipe):
y_test = forward(x_test)
y_test.backward(dy_test)
......@@ -6,7 +6,7 @@ import pytest
from transformers.configuration_utils import PretrainedConfig
from transformers.modeling_utils import PreTrainedModel
from transformer_engine.pytorch.transformer import TransformerLayer
from transformer_engine.pytorch import TransformerLayer
class SimpleTEModel(PreTrainedModel):
......
......@@ -5,7 +5,7 @@
import pytest
import torch
import transformer_engine.pytorch as te
import transformer_engine.pytorch
import transformer_engine_torch as tex
from transformer_engine.pytorch.optimizers import MultiTensorApply
......
......@@ -13,18 +13,15 @@ import torch.nn as nn
from torch.nn import Parameter
from torch.utils.cpp_extension import IS_HIP_EXTENSION
from transformer_engine.pytorch.fp8 import (
FP8GlobalStateManager,
fp8_autocast,
fp8_model_init,
)
from transformer_engine.pytorch.quantization import FP8GlobalStateManager
from transformer_engine.pytorch.utils import (
init_method_normal,
scaled_init_method_normal,
attention_mask_func,
is_bf16_compatible,
)
from transformer_engine.pytorch import (
autocast,
quantized_model_init,
DotProductAttention,
LayerNormLinear,
LayerNormMLP,
......@@ -36,27 +33,29 @@ from transformer_engine.pytorch import (
LayerNorm,
Fp8Padding,
Fp8Unpadding,
Float8Quantizer,
Float8CurrentScalingQuantizer,
MXFP8Quantizer,
get_device_compute_capability,
is_fp8_available,
is_mxfp8_available,
is_fp8_block_scaling_available,
is_bf16_available,
)
from transformer_engine.pytorch import torch_version
from transformer_engine.pytorch.distributed import checkpoint as te_checkpoint
from transformer_engine.pytorch import checkpoint as te_checkpoint
from transformer_engine.pytorch.cpp_extensions import general_gemm, general_grouped_gemm
from transformer_engine.pytorch.cpp_extensions.fused_attn import FusedAttnBackend
from transformer_engine.pytorch.tensor.float8_tensor import (
Float8Quantizer,
Float8CurrentScalingQuantizer,
)
from transformer_engine.pytorch.tensor.mxfp8_tensor import MXFP8Quantizer
from transformer_engine.pytorch.module.base import get_multi_stream_cublas_workspace, get_workspace
from transformer_engine.pytorch.utils import get_device_compute_capability
from transformer_engine.common import recipe
import transformer_engine_torch as tex
from utils import ModelConfig, reset_rng_states, get_available_attention_backends
# Only run FP8 tests on supported devices.
fp8_available, reason_for_no_fp8 = FP8GlobalStateManager.is_fp8_available()
mxfp8_available, reason_for_no_mxfp8 = FP8GlobalStateManager.is_mxfp8_available()
fp8_block_scaling_available, reason_for_no_fp8_block_scaling = FP8GlobalStateManager.is_fp8_block_scaling_available()
fp8_available, reason_for_no_fp8 = is_fp8_available(return_reason=True)
mxfp8_available, reason_for_no_mxfp8 = is_mxfp8_available(return_reason=True)
fp8_block_scaling_available, reason_for_no_fp8_block_scaling = is_fp8_block_scaling_available(return_reason=True)
sm_80plus = get_device_compute_capability() >= (8, 0)
......@@ -82,7 +81,7 @@ module_inference = ["TransformerLayer", "MultiheadAttention"]
input_formats_inference = ["sbhd", "bshd"]
param_types = [torch.float32, torch.float16]
if is_bf16_compatible(): # bf16 requires sm_80 or higher
if is_bf16_available(): # bf16 requires sm_80 or higher
param_types.append(torch.bfloat16)
batch_sizes = [1, 2]
......@@ -553,7 +552,7 @@ def _test_e2e_selective_recompute(
init_method = init_method_normal(sigma)
output_layer_init_method = scaled_init_method_normal(sigma, config.num_layers)
with fp8_model_init(enabled=fp8 and fp8_model_params, recipe=recipe):
with quantized_model_init(enabled=fp8 and fp8_model_params, recipe=recipe):
block = TransformerLayer(
config.hidden_size,
4 * config.hidden_size,
......@@ -580,7 +579,7 @@ def _test_e2e_selective_recompute(
te_inp_hidden_states.retain_grad()
te_inp_attn_mask = get_causal_attn_mask(config.max_seqlen_q)
with fp8_autocast(enabled=fp8, fp8_recipe=recipe):
with autocast(enabled=fp8, recipe=recipe):
te_out = block(
te_inp_hidden_states,
attention_mask=te_inp_attn_mask,
......@@ -649,7 +648,7 @@ def _test_e2e_full_recompute(
init_method = init_method_normal(sigma)
output_layer_init_method = scaled_init_method_normal(sigma, config.num_layers)
with fp8_model_init(enabled=fp8 and fp8_model_params, recipe=recipe):
with quantized_model_init(enabled=fp8 and fp8_model_params, recipe=recipe):
block = TransformerLayer(
config.hidden_size,
4 * config.hidden_size,
......@@ -677,7 +676,7 @@ def _test_e2e_full_recompute(
te_inp_hidden_states.retain_grad()
te_inp_attn_mask = get_causal_attn_mask(config.max_seqlen_q)
with fp8_autocast(enabled=fp8, fp8_recipe=recipe):
with autocast(enabled=fp8, recipe=recipe):
if recompute:
te_out = te_checkpoint(
block,
......@@ -1107,7 +1106,7 @@ def _test_granular_accuracy(block, bs, dtype, config, delay_wgrad_compute=False,
)
inp_hidden_states.retain_grad()
with fp8_autocast(enabled=fp8, fp8_recipe=recipe):
with autocast(enabled=fp8, recipe=recipe):
out = block(inp_hidden_states)
if isinstance(out, (List, Tuple)):
out = out[0]
......@@ -1328,7 +1327,7 @@ def test_linear_accuracy_save_original_input(dtype, model, recipe):
if config.max_seqlen_q % 16 != 0 and fp8:
pytest.skip("FP8 requires sequence length to be divisible by 16.")
with fp8_model_init(enabled=fp8 and fp8_model_params, recipe=recipe):
with quantized_model_init(enabled=fp8 and fp8_model_params, recipe=recipe):
te_linear_ref = Linear(
config.hidden_size,
4 * config.hidden_size,
......@@ -1782,7 +1781,7 @@ def _test_grouped_linear_accuracy(
else:
m_splits = torch.tensor([config.max_seqlen_q])
with fp8_autocast(enabled=fp8, fp8_recipe=recipe):
with autocast(enabled=fp8, recipe=recipe):
if isinstance(block, GroupedLinear):
m_splits = m_splits * bs
out = block(inp_hidden_states, m_splits.tolist())
......@@ -1850,7 +1849,7 @@ def test_grouped_linear_accuracy(
if config.max_seqlen_q % 16 != 0 and fp8:
pytest.skip("FP8 requires sequence length to be divisible by 16.")
with fp8_model_init(enabled=fp8 and fp8_model_params, recipe=recipe):
with quantized_model_init(enabled=fp8 and fp8_model_params, recipe=recipe):
grouped_linear = GroupedLinear(
num_gemms,
config.hidden_size,
......@@ -1994,7 +1993,7 @@ def test_grouped_linear_accuracy_save_original_input(
if config.max_seqlen_q % 16 != 0 and fp8:
pytest.skip("FP8 requires sequence length to be divisible by 16.")
with fp8_model_init(enabled=fp8 and fp8_model_params, recipe=recipe):
with quantized_model_init(enabled=fp8 and fp8_model_params, recipe=recipe):
grouped_linear = GroupedLinear(
num_gemms,
config.hidden_size,
......@@ -2154,7 +2153,7 @@ def _test_padding_grouped_linear_accuracy(block, num_gemms, bs, dtype, config, r
m_splits = _generate_random_numbers(num_gemms, config.max_seqlen_q * bs)
with fp8_autocast(enabled=fp8, fp8_recipe=recipe):
with autocast(enabled=fp8, recipe=recipe):
if isinstance(block, TorchGroupedLinearWithPadding):
out = block(inp_hidden_states, m_splits)
else:
......@@ -2208,7 +2207,7 @@ def test_padding_grouped_linear_accuracy(
if config.max_seqlen_q % 16 != 0 and fp8:
pytest.skip("FP8 requires sequence length to be divisible by 16.")
with fp8_model_init(enabled=fp8 and fp8_model_params, recipe=recipe):
with quantized_model_init(enabled=fp8 and fp8_model_params, recipe=recipe):
grouped_linear = TorchGroupedLinearWithPadding(
num_gemms,
config.hidden_size,
......@@ -2219,7 +2218,7 @@ def test_padding_grouped_linear_accuracy(
fp8=fp8,
).eval()
with fp8_model_init(enabled=fp8 and fp8_model_params, recipe=recipe):
with quantized_model_init(enabled=fp8 and fp8_model_params, recipe=recipe):
ref_grouped_linear = GroupedLinear(
num_gemms,
config.hidden_size,
......@@ -2285,7 +2284,7 @@ def test_padding_grouped_linear_accuracy_save_original_input(
if config.max_seqlen_q % 16 != 0 and fp8:
pytest.skip("FP8 requires sequence length to be divisible by 16.")
with fp8_model_init(enabled=fp8 and fp8_model_params, recipe=recipe):
with quantized_model_init(enabled=fp8 and fp8_model_params, recipe=recipe):
grouped_linear = TorchGroupedLinearWithPadding(
num_gemms,
config.hidden_size,
......@@ -2296,7 +2295,7 @@ def test_padding_grouped_linear_accuracy_save_original_input(
fp8=fp8,
).eval()
with fp8_model_init(enabled=fp8 and fp8_model_params, recipe=recipe):
with quantized_model_init(enabled=fp8 and fp8_model_params, recipe=recipe):
ref_grouped_linear = GroupedLinear(
num_gemms,
config.hidden_size,
......@@ -2446,7 +2445,7 @@ def _test_gpt_fp8_parameters(bs, dtype, config, fp8_model_params, recipe):
init_method = init_method_normal(sigma)
output_layer_init_method = scaled_init_method_normal(sigma, config.num_layers)
with fp8_model_init(enabled=fp8_model_params, recipe=recipe):
with quantized_model_init(enabled=fp8_model_params, recipe=recipe):
block = TransformerLayer(
config.hidden_size,
4 * config.hidden_size,
......@@ -2473,7 +2472,7 @@ def _test_gpt_fp8_parameters(bs, dtype, config, fp8_model_params, recipe):
te_inp_hidden_states.retain_grad()
te_inp_attn_mask = get_causal_attn_mask(config.max_seqlen_q)
with fp8_autocast(enabled=True, fp8_recipe=recipe):
with autocast(enabled=True, recipe=recipe):
te_out = block(te_inp_hidden_states, attention_mask=te_inp_attn_mask)
loss = te_out.sum()
loss.backward()
......
......@@ -33,10 +33,9 @@ from onnxruntime_extensions import PyCustomOpDef, get_library_path, onnx_op
import transformer_engine.pytorch as te
from transformer_engine.common import recipe
import transformer_engine_torch as tex
from transformer_engine.pytorch.onnx_extensions import te_translation_table
from torch.utils.cpp_extension import IS_HIP_EXTENSION
from transformer_engine.pytorch.export import is_in_onnx_export_mode
from transformer_engine.pytorch.fp8 import FP8GlobalStateManager
from transformer_engine.pytorch.export import is_in_onnx_export_mode, te_translation_table
from transformer_engine.pytorch.quantization import FP8GlobalStateManager
from transformer_engine.pytorch.utils import get_default_init_method
import tensorrt as trt
......@@ -59,8 +58,8 @@ NVTE_TEST_ARTIFACTS_DIR = NVTE_TEST_ARTIFACTS_DIR or os.path.join(
# The directory where this file is stored.
TESTS_DIR = os.path.dirname(os.path.abspath(__file__))
fp8_available, reason_for_no_fp8 = FP8GlobalStateManager.is_fp8_available()
mxfp8_available, reason_for_no_mxfp8 = FP8GlobalStateManager.is_mxfp8_available()
fp8_available, reason_for_no_fp8 = te.is_fp8_available(return_reason=True)
mxfp8_available, reason_for_no_mxfp8 = te.is_mxfp8_available(return_reason=True)
fp8_recipes = []
if mxfp8_available:
......@@ -179,8 +178,8 @@ def do_export(
input_names = input_names or ["input"]
output_names = output_names or ["output"]
with torch.inference_mode(), te.fp8_autocast(
enabled=fp8_recipe is not None, fp8_recipe=fp8_recipe
with torch.inference_mode(), te.autocast(
enabled=fp8_recipe is not None, recipe=fp8_recipe
), warnings.catch_warnings():
warnings.filterwarnings(action="ignore", category=torch.jit.TracerWarning, module=r".*")
......@@ -234,8 +233,8 @@ def te_infer(
fp8_recipe: recipe.Recipe,
):
"""Transformer Engine forward propagation."""
with torch.inference_mode(), te.fp8_autocast(
enabled=is_fp8, fp8_recipe=fp8_recipe
with torch.inference_mode(), te.autocast(
enabled=is_fp8, recipe=fp8_recipe
), warnings.catch_warnings():
te_outputs = model(*inps if isinstance(inps, tuple) else (inps,))
if not isinstance(te_outputs, tuple):
......@@ -441,7 +440,7 @@ def _test_export_linear(
bias_str = "_bias" if use_bias else ""
high_prec_str = dtype2str(precision)
fname = f"te.linear{fp8_str}{bias_str}{high_prec_str}.onnx"
with te.fp8_autocast(enabled=fp8_recipe is not None, fp8_recipe=fp8_recipe):
with te.autocast(enabled=fp8_recipe is not None, recipe=fp8_recipe):
model = Test_Linear(in_features, out_features, use_bias, return_bias, precision).to(
device="cuda"
)
......@@ -507,7 +506,7 @@ def _test_export_layernorm(
fname = f"te.layernorm_linear{fp8_str}{high_prec_str}.onnx"
with torch.no_grad():
with te.fp8_autocast(enabled=fp8_recipe is not None, fp8_recipe=fp8_recipe):
with te.autocast(enabled=fp8_recipe is not None, recipe=fp8_recipe):
layernorm_cls = te.LayerNorm if normalization == "LayerNorm" else te.RMSNorm
model = layernorm_cls(
hidden_size,
......@@ -577,7 +576,7 @@ def _test_export_layernorm_linear(
fname = f"te.layernorm_linear{fp8_str}{bias_str}{high_prec_str}.onnx"
with torch.no_grad():
with te.fp8_autocast(enabled=fp8_recipe is not None, fp8_recipe=fp8_recipe):
with te.autocast(enabled=fp8_recipe is not None, recipe=fp8_recipe):
model = te.LayerNormLinear(
hidden_size,
3 * hidden_size,
......@@ -673,7 +672,7 @@ def _test_export_layernorm_mlp(
bias_str = "_bias" if use_bias else ""
high_prec_str = dtype2str(precision)
fname = f"te.layernorm_mlp{fp8_str}{bias_str}{high_prec_str}_{activation}.onnx"
with te.fp8_autocast(enabled=fp8_recipe is not None, fp8_recipe=fp8_recipe):
with te.autocast(enabled=fp8_recipe is not None, recipe=fp8_recipe):
model = te.LayerNormMLP(
hidden_size,
ffn_hidden_size,
......@@ -1215,13 +1214,13 @@ def test_trt_integration(fp8_recipe: recipe.Recipe):
).eval()
inps = (torch.randn([16, 16, 128], device="cuda", requires_grad=False),)
with te.fp8_autocast(enabled=fp8_recipe is not None, fp8_recipe=fp8_recipe):
with te.autocast(enabled=fp8_recipe is not None, recipe=fp8_recipe):
out_ref = model(*inps)
onnx_fd, onnx_path = tempfile.mkstemp(suffix=".onnx")
os.close(onnx_fd)
try:
with te.fp8_autocast(enabled=fp8_recipe is not None, fp8_recipe=fp8_recipe):
with te.autocast(enabled=fp8_recipe is not None, recipe=fp8_recipe):
with te.onnx_export(enabled=True):
torch.onnx.export(
model,
......