Commit ab3e5a92 authored by yuguo


Merge commit '04c730c0' of https://github.com/NVIDIA/TransformerEngine
parents a8d19fd9 04c730c0
# Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# See LICENSE for license information.
from typing import Tuple
import torch
def scale_from_amax_tensor(
x_dtype: torch.dtype,
amax: torch.Tensor,
quant_dtype: torch.dtype,
*,
eps: float,
pow_2_scales: bool,
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
"""Derives quantization and dequantization from amax and options.
Reference implementation for scale calculation.
Returns:
- scale: quantization scales
- scale_inv: dequantization scales
- amax: Amax tensor with updates made for extrema values.
"""
assert amax.dtype == torch.float, "amax must be a float tensor."
fp8_max = torch.finfo(quant_dtype).max
# Clamping amax to avoid division by small numbers
amax = torch.max(amax, torch.tensor(eps))
# Compute scale factor
scale = torch.div(fp8_max, amax)
# Note frexp doesn't give back inf for exponent with an inf input
# We take care of inf before pow_2_scales
scale = torch.where(scale == torch.inf, torch.finfo(x_dtype).max, scale)
if pow_2_scales:
# Calculate rounded down exponent
_, exp = torch.frexp(scale)
# Positive numbers are always returned as mant, exp with
# a mantissa in [0.5, 1.0). Because a normal float has a mantissa with
# hidden bit in [1.0, 2.0), the exponent will be off by exactly one because
# of the shift. Subnormal and zero cases need not be considered because
# the smallest possible result of fp8_max / amax is still normal.
exp = exp - 1
# No subnormals and zero.
assert (exp > -127).all()
unity = torch.tensor([1.0], device=exp.device)
torch.ldexp(unity, exp, out=scale)
# Case where amax is inf. The frexp, ldexp logic changes 0.0 scales
# Return 0.0 for 0.0 scale for consistency with non-pow2 scale
# calculation.
scale = torch.where(amax == float("inf"), 0.0, scale)
# Handle overflow cases for amax zero causing NaN
scale = torch.where(amax == 0, 1.0, scale)
# Compute scale_inv
scale_inv = torch.reciprocal(scale)
return scale, scale_inv, amax
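# Illustrative sanity check of the calculation above (assumes torch.float8_e4m3fn is
# available in this torch build): for amax = 3.0 the FP8 max is 448, so the raw scale is
# 448 / 3 ~= 149.33; with pow_2_scales=True it is rounded down to 2**7 = 128 and
# scale_inv = 1 / 128.
if __name__ == "__main__":
    _amax = torch.tensor([3.0])
    _scale, _scale_inv, _ = scale_from_amax_tensor(
        torch.float32, _amax, torch.float8_e4m3fn, eps=0.0, pow_2_scales=True
    )
    assert _scale.item() == 128.0
    assert _scale_inv.item() == 1.0 / 128.0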
@@ -6,63 +6,16 @@ import torch
 import transformer_engine_torch as tex
 from transformer_engine.pytorch.constants import TE_DType_To_Torch
+from references.quantize_scale_calc import scale_from_amax_tensor


-# Compute scale and scale_inv from amax
-def _ref_compute_scale_and_scale_inv_from_amax(amax, fp8_max, eps, pow_2_scales):
-    # Clamping amax to avoid division by small numbers
-    amax = torch.max(amax, torch.tensor(eps))
-    # Compute scale factor
-    scale = torch.div(fp8_max, amax)
-    # Note frexp doesn't give back inf for exponent with an inf input
-    # We take care of inf before pow_2_scales
-    # option1: set scale to fp32 max when scale is inf
-    scale = torch.where(scale == torch.inf, torch.finfo(torch.float32).max, scale)
-    # option2: when scale is inf, set scale to 1
-    scale = torch.where(scale == torch.inf, 1.0, scale)
-    if pow_2_scales:
-        # Calculate rounded down exponent
-        _, exp = torch.frexp(scale)
-        # Positive numbers are always returned as mant, exp with
-        # a mantissa in [0.5, 1.0). Because a normal float has a mantissa with
-        # hidden bit in [1.0, 2.0), the exponent will be off by exactly one because
-        # of the shift. Subnormal and zero cases need not be considered because
-        # the smallest possible result of fp8_max / amax is still normal.
-        exp = exp - 1
-        # No subnormals and zero.
-        assert (exp > -127).all()
-        # TODO: If/when adding a URM option an option is to cap to 126
-        # rather than allowing the full range of FP32 (2 - 2^23) x 2^127
-        # addresses cases where adding a mantissa overflows into inf scales.
-        # Not necessary currently without additional scale smudging options.
-        unity = torch.tensor([1.0], device=exp.device)
-        torch.ldexp(unity, exp, out=scale)
-        # Case where amax is inf. The frexp, ldexp logic changes 0.0 scales
-        # Return 0.0 for 0.0 scale for consistency with non-pow2 scale
-        # calculation.
-        scale = torch.where(amax == float("inf"), 0.0, scale)
-    # Handle overflow cases for amax zero causing NaN
-    scale = torch.where(amax == 0, 1.0, scale)
-    # Compute scale_inv
-    scale_inv = torch.reciprocal(scale)
-    return scale, scale_inv


 # compute amax and scale
 def _ref_compute_amax_scale(x, quant_dtype, eps, pow_2_scales):
     x_fp32 = x.to(torch.float32)
     amax = torch.amax(torch.abs(x_fp32)).view(1)
-    assert amax.dtype == torch.float, "amax must be a float tensor."
-    fp8_max = torch.finfo(quant_dtype).max
-
-    scale, scale_inv = _ref_compute_scale_and_scale_inv_from_amax(amax, fp8_max, eps, pow_2_scales)
-    # Clamping amax to avoid division by small numbers
-    amax = torch.max(amax, torch.tensor(eps))
-    return scale, scale_inv, amax
+    return scale_from_amax_tensor(
+        torch.float32, amax, quant_dtype, eps=eps, pow_2_scales=pow_2_scales
+    )


 def _multi_dim_transpose(tensor):
@@ -113,7 +66,3 @@ def ref_per_tensor_cs_cast(
     qx_t = _multi_dim_transpose(qx)
     sx_t = sx
     return qx, sx, qx_t, sx_t
-
-
-def ref_compute_scale_and_scale_inv_from_amax(amax, fp8_max, eps, pow_2_scales):
-    return _ref_compute_scale_and_scale_inv_from_amax(amax, fp8_max, eps, pow_2_scales)
@@ -2,41 +2,84 @@
 #
 # See LICENSE for license information.
+import os
+from contextlib import nullcontext
+
 import pytest
 import torch
-from contextlib import nullcontext

 import transformer_engine.pytorch as te
+from transformer_engine.common import recipe
 from transformer_engine.pytorch.fp8 import FP8GlobalStateManager

-# Check if FP8 supported
+# Check if FP8 is supported
 fp8_available, reason_for_no_fp8 = FP8GlobalStateManager.is_fp8_available()
+mxfp8_available, reason_for_no_mxfp8 = FP8GlobalStateManager.is_mxfp8_available()
+
+fp8_recipes = [
+    None,  # non-fp8
+    # recipe.MXFP8BlockScaling(), - scale inverse tensors offloading does not work yet
+    recipe.Float8CurrentScaling(),
+    recipe.DelayedScaling(),
+]

 SIZE = 512
+NUM_HEADS = 8
+NUM_LAYERS = 5
+EPSILON = 0.1
+
+# Flash attention saves some internal tensors for the backward pass
+# that cannot be offloaded to CPU.
+assert os.getenv("NVTE_FLASH_ATTN") == "0"

-models = {
-    "linear": te.Linear,
-    "layernorm_mlp": te.LayerNormMLP,
-    "layernorm_linear": te.LayerNormLinear,
+# Offloading is supported for attention only with the fused and flash attention
+# backends, so the use of bfloat16 is required.
+#
+# For the TransformerLayer, activation offloading with dropout is not supported,
+# so we set hidden_dropout to 0.0.
+model_types = {
+    "linear": lambda: te.Linear(SIZE, SIZE, params_dtype=torch.bfloat16),
+    "layernorm_mlp": lambda: te.LayerNormMLP(SIZE, SIZE, params_dtype=torch.bfloat16),
+    "layernorm_linear": lambda: te.LayerNormLinear(SIZE, SIZE, params_dtype=torch.bfloat16),
+    "multihead_attention": lambda: te.MultiheadAttention(
+        SIZE, NUM_HEADS, params_dtype=torch.bfloat16
+    ),
+    "transformer_layer": lambda: te.TransformerLayer(
+        SIZE, SIZE, NUM_HEADS, params_dtype=torch.bfloat16, hidden_dropout=0.0
+    ),
 }


 def _get_input():
-    return torch.empty((128, SIZE, SIZE)).cuda()
+    return torch.empty((128, SIZE, SIZE), dtype=torch.bfloat16).cuda()


-def _measure_memory_between_forward_and_backward(model_cls, fp8, cpu_offload):
-    input_layer = model_cls(SIZE, SIZE)
-    hidden_layer = model_cls(SIZE, SIZE)
-    output_layer = model_cls(SIZE, SIZE)
-
-    input = _get_input()
+def _get_fp8_weight_cache_size(models, fp8_recipe):
+    """
+    Calculate the total FP8 weight cache size (in MB) for a list of models.
+    """
+    if fp8_recipe is None:
+        return 0
+
+    params_bytes = 0
+    for model in models:
+        for name, param in model.named_parameters():
+            if "weight" in name:
+                params_bytes += param.numel()
+
+    # One byte for columnwise and one byte for rowwise,
+    # hence multiply by 2 and convert to MB;
+    # there is 1 byte of scale per 32 elements in mxFP8.
+    factor_for_scale_inv_tensor = (1 + 1 / 32) if fp8_recipe.mxfp8() else 1
+    return (2 * params_bytes * factor_for_scale_inv_tensor) / (1024**2)
+
+
+def _measure_memory_between_forward_and_backward(models, fp8_recipe, cpu_offload):
+    tensor = _get_input()
     if cpu_offload:
         offload_context, sync_function = te.get_cpu_offload_context(
             enabled=True,
-            num_layers=2,
-            model_layers=3,
+            num_layers=len(models) - 1,
+            model_layers=len(models),
             offload_activations=True,
             offload_weights=False,
         )
@@ -44,42 +87,58 @@ def _measure_memory_between_forward_and_backward(model_cls, fp8, cpu_offload):
         offload_context = nullcontext()
         sync_function = lambda x: x

-    with te.fp8_autocast(enabled=fp8), offload_context:
-        out = input_layer(input)
-    out = sync_function(out)
-    with te.fp8_autocast(enabled=fp8), offload_context:
-        out = hidden_layer(out)
-    out = sync_function(out)
-    with te.fp8_autocast(enabled=fp8), offload_context:
-        out = output_layer(out)
-    out = sync_function(out)
-
-    max_mem_used = torch.cuda.memory_allocated() / 1024**2
-    out.sum().backward()
-
-    del input_layer
-    del hidden_layer
-    del output_layer
-    del input
-    del out
+    for model in models:
+        with te.fp8_autocast(
+            enabled=fp8_recipe is not None, fp8_recipe=fp8_recipe
+        ), offload_context:
+            tensor = model(tensor)
+        tensor = sync_function(tensor)

+    max_mem_used = torch.cuda.memory_allocated() / (1024**2)
     torch.cuda.synchronize()

     return max_mem_used


-@pytest.mark.parametrize("fp8", [True, False])
-@pytest.mark.parametrize("model_key", models.keys())
-def test_cpu_offload(fp8, model_key) -> None:
-
-    if fp8 and not fp8_available:
-        pytest.skip(reason_for_no_fp8)
-
-    model_cls = models[model_key]
-
-    without_offloading = _measure_memory_between_forward_and_backward(model_cls, fp8, False)
-    with_offloading = _measure_memory_between_forward_and_backward(model_cls, fp8, True)
+@pytest.mark.parametrize("fp8_recipe", fp8_recipes)
+@pytest.mark.parametrize("model_key", model_types.keys())
+def test_cpu_offload(fp8_recipe, model_key) -> None:
+    """
+    We run three configurations:
+    (1) No offloading: All activations remain on the GPU between forward and backward passes.
+    (2) No offloading (one layer): Only the first layer's activations remain on the GPU
+        between forward and backward passes.
+    (3) With offloading (all layers): Only the last layer's activations remain on the GPU
+        between forward and backward passes, while all other layers are offloaded to the CPU.
+
+    We expect the memory consumption of configurations (2) and (3) to be similar, with
+    the difference being the size of the FP8 cache that is not offloaded to the CPU.
+    We also expect this memory consumption to be smaller than in scenario (1).
+    """
+    model_cls = model_types[model_key]
+    models_list = [model_cls() for _ in range(NUM_LAYERS)]
+
+    if fp8_recipe and not fp8_available:
+        pytest.skip(reason_for_no_fp8)
+    if fp8_recipe is not None:
+        if fp8_recipe.mxfp8() and not mxfp8_available:
+            pytest.skip(reason_for_no_mxfp8)
+
+    without_offloading = _measure_memory_between_forward_and_backward(
+        models_list, fp8_recipe, False
+    )
+    without_offloading_one_layer = _measure_memory_between_forward_and_backward(
+        models_list[:1], fp8_recipe, False
+    )
+    with_offloading = _measure_memory_between_forward_and_backward(models_list, fp8_recipe, True)

     assert with_offloading < without_offloading
+
+    # The only difference between the memory consumption of with_offloading
+    # and without_offloading_one_layer should be the size of the FP8 weights cache,
+    # which is not offloaded to the CPU.
+    memory_consumption_diff = abs(with_offloading - without_offloading_one_layer)
+    assert (
+        memory_consumption_diff < _get_fp8_weight_cache_size(models_list[1:], fp8_recipe) + EPSILON
+    )
@@ -30,6 +30,9 @@ if IS_HIP_EXTENSION:
 # Check if FP8 is supported.
 fp8_available, reason_for_no_fp8 = FP8GlobalStateManager.is_fp8_available()
+fp8_block_scaling_available, reason_for_no_fp8_block_scaling = (
+    FP8GlobalStateManager.is_fp8_block_scaling_available()
+)
 mxfp8_available, reason_for_no_mxfp8 = FP8GlobalStateManager.is_mxfp8_available()
@@ -58,6 +61,7 @@ fp8_recipes = [
     recipe.DelayedScaling(),
     recipe.MXFP8BlockScaling(),
     recipe.Float8CurrentScaling(),
+    recipe.Float8BlockScaling(),
 ]

 # Supported data types
@@ -328,9 +332,13 @@ def test_make_graphed_callables(
         pytest.skip("FP8 needed for FP8 parameters.")
     if fp8_weight_caching and not fp8:
         pytest.skip("FP8 needed for FP8 parameters.")
+    if fp8_recipe.float8_block_scaling() and not fp8_block_scaling_available:
+        pytest.skip(reason_for_no_fp8_block_scaling)
     if fp8_recipe.mxfp8() and not mxfp8_available:
         pytest.skip(reason_for_no_mxfp8)
+    if fp8_recipe.float8_block_scaling() and module == "linear_op":
+        pytest.skip("Module not yet supported for float8_block_scaling with CUDA graphs")

     # Run model with different CUDA graph settings.
     model_config = model_configs[model_config]
     kwargs = dict(
# Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# See LICENSE for license information.
import pytest
import torch
import transformer_engine as te
import transformer_engine_torch as tex
from transformer_engine.pytorch.constants import TE_DType
from transformer_engine.pytorch.fp8 import FP8GlobalStateManager
from transformer_engine.pytorch.tensor.float8_blockwise_tensor import (
Float8BlockQuantizer,
Float8BlockwiseQTensor,
)
from references.blockwise_quantizer_reference import CuBLASScaleMunger
from references.blockwise_fp8_gemm_reference import CuBLASRefBlockwiseGemm
def fp8_blockwise_gemm_supported() -> bool:
supported, _ = FP8GlobalStateManager.is_fp8_block_scaling_available()
return supported
def cublas_gemm_fp8_blockwise_case(
x_dtype,
w_dtype,
out_dtype,
M,
K,
N,
noise_type,
x_magnitude,
w_magnitude,
accumulate,
use_split_accumulator,
is_x_1d_scaled,
is_w_1d_scaled,
*,
x_columnwise: bool = False,
w_columnwise: bool = False,
use_bias: bool = False,
use_gelu: bool = False,
use_grad: bool = False,
atol: float = 0.0,
rtol: float = 0.0
):
if x_dtype == torch.float8_e5m2 and w_dtype == torch.float8_e5m2:
pytest.skip("FP8 GEMM doesn't support both a and b types being torch.float8_e5m2")
if not (is_x_1d_scaled or is_w_1d_scaled):
pytest.skip("FP8 GEMM doesn't support 2dimensional qtile by 2dimensional qtile")
if not fp8_blockwise_gemm_supported():
pytest.skip("CUDA version does not support blockwise FP8 gemm.")
# Setup device and random seed
device = "cuda"
seed = 0
torch.manual_seed(seed)
torch.cuda.manual_seed(seed)
x_shape = (K, M) if x_columnwise else (M, K)
w_shape = (K, N) if w_columnwise else (N, K)
# generate random input and weight
if noise_type == "uniform":
x = torch.rand(x_shape, dtype=torch.float32, device=device) * x_magnitude * 2 - x_magnitude
w = torch.rand(w_shape, dtype=torch.float32, device=device) * w_magnitude * 2 - w_magnitude
elif noise_type == "normal":
x = torch.randn(x_shape, dtype=torch.float32, device=device) * x_magnitude
w = torch.randn(w_shape, dtype=torch.float32, device=device) * w_magnitude
else:
assert False
# Setup out tensor if accumulate is True
if accumulate:
out = torch.randn((M, N), dtype=out_dtype, device=device) * x_magnitude
else:
out = None
assert not (use_bias and use_grad), "Bias grad not supported by GEMM"
# Set quantize_op and quantization parameters
x_quant_tile_shape = (1, 128) if is_x_1d_scaled else (128, 128)
w_quant_tile_shape = (1, 128) if is_w_1d_scaled else (128, 128)
x_block_scaling_dim = 1 if is_x_1d_scaled else 2
w_block_scaling_dim = 1 if is_w_1d_scaled else 2
x_te_dtype = TE_DType[x_dtype]
w_te_dtype = TE_DType[w_dtype]
x_quantizer = Float8BlockQuantizer(
fp8_dtype=x_te_dtype,
rowwise=True,
columnwise=True,
amax_epsilon=0.0,
force_pow_2_scales=True,
block_scaling_dim=x_block_scaling_dim,
)
w_quantizer = Float8BlockQuantizer(
fp8_dtype=w_te_dtype,
rowwise=True,
columnwise=True,
amax_epsilon=0.0,
force_pow_2_scales=True,
block_scaling_dim=w_block_scaling_dim,
)
# Quantize x and w
qx = x_quantizer.make_empty(x_shape, dtype=x_dtype, device=device, requires_grad=False)
qx = x_quantizer.update_quantized(x, qx)
qw = w_quantizer.make_empty(w_shape, dtype=w_dtype, device=device, requires_grad=False)
qw = w_quantizer.update_quantized(w, qw)
if not use_bias:
bias = None
else:
bias = torch.randn((1, N), dtype=torch.bfloat16, device=device)
# Reference GEMM
ref_gemm = CuBLASRefBlockwiseGemm()
scale_decoder = CuBLASScaleMunger()
qx_data = (
qx._columnwise_data.view(dtype=x_dtype)
if x_columnwise
else qx._rowwise_data.view(dtype=x_dtype)
)
qw_data = (
qw._columnwise_data.view(dtype=w_dtype)
if w_columnwise
else qw._rowwise_data.view(dtype=w_dtype)
)
ref_scales_x = qx._columnwise_scale_inv if x_columnwise else qx._rowwise_scale_inv
ref_scales_w = qw._columnwise_scale_inv if w_columnwise else qw._rowwise_scale_inv
y_ref = ref_gemm.qgemm(
qx=qx_data,
qw=qw_data,
out_dtype=out_dtype,
demunged_sx=CuBLASScaleMunger.demunge_scale_shape_from_backend(
qtensor_shape=(M, K), scales=ref_scales_x, tile_shape=x_quant_tile_shape
),
demunged_sw=CuBLASScaleMunger.demunge_scale_shape_from_backend(
qtensor_shape=(N, K), scales=ref_scales_w, tile_shape=w_quant_tile_shape
),
quant_tile_shape_x=x_quant_tile_shape,
quant_tile_shape_w=w_quant_tile_shape,
bias=bias,
out=out.clone() if accumulate else None,
accumulate=accumulate,
use_split_accumulator=use_split_accumulator,
)
# Allocate cuBLAS workspace
workspace_size = 0
workspace = torch.empty(0, dtype=torch.uint8, device=device)
transa = True if not w_columnwise else False
transb = False if not x_columnwise else True
out_quantizer = None
assert not (use_gelu and use_bias), "Bias and GELU not supported by GEMM"
aux_tensor = torch.randn((M, N), dtype=out_dtype, device=device) if use_gelu else None
aux_tensor_ref = aux_tensor.clone() if use_gelu else None
bias_dtype = TE_DType[torch.bfloat16 if bias is None else bias.dtype]
# cuBLAS GEMM
# return type is out, bias_grad, gelu_input, extra_output
# We are just capturing out.
y = tex.generic_gemm(
qw,
transa,
qx,
transb,
out.clone() if accumulate else None,
out_quantizer,
TE_DType[out_dtype],
bias,
bias_dtype,
use_gelu,
aux_tensor,
use_grad,
workspace,
workspace.shape[0],
accumulate,
use_split_accumulator,
)[0]
# just in case of accumulation, make sure y_ref and y are not the same tensor
assert y_ref is not y, "y_ref and y should not be the same tensor"
# Reset NaNs to zeros because torch.testing.assert_close does not assume NaNs to be equal
assert not torch.isnan(y_ref.float()).all(), "All elements are nan"
y_ref = torch.where(y_ref.isnan(), torch.zeros_like(y_ref), y_ref)
y = torch.where(y.isnan(), torch.zeros_like(y), y)
if use_gelu:
# Check
if use_grad:
# With use_grad, GEMM should use aux tensor to calculate
# gradient
gelu_ref = tex.dgelu(y_ref, aux_tensor_ref, None)
# TODO: How do we decide whether this is acceptably close?
# Could also try to put the activation inside the reference
# before the output cast to see different tolerances.
torch.testing.assert_close(y, gelu_ref, atol=1e-3, rtol=1e-2)
else:
# aux tensor is pre-gelu aux output. Verify against y_ref.
torch.testing.assert_close(aux_tensor, y_ref, atol=atol, rtol=rtol)
act = torch.nn.GELU()
gelu_ref = act(y_ref)
# gelu_ref = tex.gelu(y_ref, None)
torch.testing.assert_close(y, gelu_ref, atol=atol, rtol=rtol)
else:
torch.testing.assert_close(y, y_ref, atol=atol, rtol=rtol)
def cublas_gemm_test_constraint_enforced(
x_dtype,
w_dtype,
out_dtype,
M,
K,
N,
accumulate,
use_split_accumulator,
is_x_1d_scaled,
is_w_1d_scaled,
*,
x_columnwise: bool = False,
w_columnwise: bool = False,
use_bias: bool = False,
use_gelu: bool = False,
use_grad: bool = False,
expected_err_msg="CUBLAS_STATUS_NOT_SUPPORTED",
expected_err_cls=RuntimeError
):
if not fp8_blockwise_gemm_supported():
pytest.skip("CUDA version does not support blockwise FP8 gemm.")
# Setup device and random seed
device = "cuda"
seed = 0
torch.manual_seed(seed)
torch.cuda.manual_seed(seed)
x_shape = (K, M) if x_columnwise else (M, K)
w_shape = (K, N) if w_columnwise else (N, K)
# generate random input and weight
x = torch.rand(x_shape, dtype=torch.float32, device=device) * 2.0 - 1.0
w = torch.rand(w_shape, dtype=torch.float32, device=device) * 2.0 - 1.0
# Setup out tensor if accumulate is True
if accumulate:
out = torch.randn((M, N), dtype=out_dtype, device=device)
else:
out = None
# Set quantize_op and quantization parameters
x_quant_tile_shape = (1, 128) if is_x_1d_scaled else (128, 128)
w_quant_tile_shape = (1, 128) if is_w_1d_scaled else (128, 128)
x_block_scaling_dim = 1 if is_x_1d_scaled else 2
w_block_scaling_dim = 1 if is_w_1d_scaled else 2
x_te_dtype = TE_DType[x_dtype]
w_te_dtype = TE_DType[w_dtype]
x_quantizer = Float8BlockQuantizer(
fp8_dtype=x_te_dtype,
rowwise=True,
columnwise=True,
amax_epsilon=0.0,
force_pow_2_scales=True,
block_scaling_dim=x_block_scaling_dim,
)
w_quantizer = Float8BlockQuantizer(
fp8_dtype=w_te_dtype,
rowwise=True,
columnwise=True,
amax_epsilon=0.0,
force_pow_2_scales=True,
block_scaling_dim=w_block_scaling_dim,
)
# Quantize x and w
qx = x_quantizer.make_empty(x_shape, dtype=x_dtype, device=device, requires_grad=False)
qx = x_quantizer.update_quantized(x, qx)
qw = w_quantizer.make_empty(w_shape, dtype=w_dtype, device=device, requires_grad=False)
qw = w_quantizer.update_quantized(w, qw)
if not use_bias:
bias = None
else:
bias = torch.randn((1, N), dtype=torch.bfloat16, device=device)
# Allocate cuBLAS workspace
workspace_size = 0
workspace = torch.empty(0, dtype=torch.uint8, device=device)
transa = True if not w_columnwise else False
transb = False if not x_columnwise else True
out_quantizer = None
grad = use_grad
gelu_in = None if not use_gelu else torch.randn((M, N), dtype=out_dtype, device=device)
bias_dtype = TE_DType[torch.bfloat16 if bias is None else bias.dtype]
# cuBLAS GEMM
# return type is out, bias_grad, gelu_input, extra_output
# We are just capturing out.
with pytest.raises(expected_err_cls, match=expected_err_msg):
y = tex.generic_gemm(
qw,
transa,
qx,
transb,
out.clone() if accumulate else None,
out_quantizer,
TE_DType[out_dtype],
bias,
bias_dtype,
use_gelu,
gelu_in,
grad,
workspace,
workspace.shape[0],
accumulate,
use_split_accumulator,
)
@pytest.mark.parametrize(
"M, K, N",
[
# k = 128
(128, 128, 128),
(256, 128, 256),
# non 128x128 divisible input shapes
(320, 128, 336),
(320, 64, 336),
# k > 128
(256, 256, 256),
(320, 256, 336),
(1024, 4096, 1024),
],
)
@pytest.mark.parametrize("x_dtype", [torch.float8_e4m3fn, torch.float8_e5m2], ids=str)
@pytest.mark.parametrize("w_dtype", [torch.float8_e4m3fn, torch.float8_e5m2], ids=str)
@pytest.mark.parametrize("out_dtype", [torch.bfloat16, torch.float32], ids=str)
@pytest.mark.parametrize("noise_type", ["normal"], ids=str)
@pytest.mark.parametrize("x_magnitude", [1], ids=str)
@pytest.mark.parametrize("w_magnitude", [1], ids=str)
@pytest.mark.parametrize("accumulate", [False], ids=["no_accumulate"])
@pytest.mark.parametrize("use_split_accumulator", [True], ids=["split_acc"])
@pytest.mark.parametrize(
"is_x_1d_scaled, is_w_1d_scaled",
[
(True, False),
(True, True),
(False, True),
],
ids=["1Dx2D", "1Dx1D", "2Dx1D"],
)
def test_cublas_gemm_fp8_blockwise_shape_varying(
x_dtype,
w_dtype,
out_dtype,
M,
K,
N,
noise_type,
x_magnitude,
w_magnitude,
accumulate,
use_split_accumulator,
is_x_1d_scaled,
is_w_1d_scaled,
):
cublas_gemm_fp8_blockwise_case(
x_dtype,
w_dtype,
out_dtype,
M,
K,
N,
noise_type,
x_magnitude,
w_magnitude,
accumulate,
use_split_accumulator,
is_x_1d_scaled,
is_w_1d_scaled,
)
@pytest.mark.parametrize(
"M, K, N",
[
(256, 128, 256),
(320, 256, 336),
],
)
@pytest.mark.parametrize("x_dtype", [torch.float8_e4m3fn, torch.float8_e5m2], ids=str)
@pytest.mark.parametrize("w_dtype", [torch.float8_e4m3fn, torch.float8_e5m2], ids=str)
@pytest.mark.parametrize("out_dtype", [torch.bfloat16, torch.float32], ids=str)
@pytest.mark.parametrize("noise_type", ["normal", "uniform"], ids=str)
@pytest.mark.parametrize("x_magnitude", [1e-28, 1, 1e3], ids=str)
@pytest.mark.parametrize("w_magnitude", [1], ids=str)
@pytest.mark.parametrize("accumulate", [True, False], ids=["accumulate", "no_accumulate"])
@pytest.mark.parametrize("use_split_accumulator", [True], ids=["split_acc"])
@pytest.mark.parametrize(
"is_x_1d_scaled, is_w_1d_scaled",
[
(True, False),
(True, True),
(False, True),
],
ids=["1Dx2D", "1Dx1D", "2Dx1D"],
)
def test_cublas_gemm_fp8_blockwise_accumulate_magnitude_varying(
x_dtype,
w_dtype,
out_dtype,
M,
K,
N,
noise_type,
x_magnitude,
w_magnitude,
accumulate,
use_split_accumulator,
is_x_1d_scaled,
is_w_1d_scaled,
):
cublas_gemm_fp8_blockwise_case(
x_dtype,
w_dtype,
out_dtype,
M,
K,
N,
noise_type,
x_magnitude,
w_magnitude,
accumulate,
use_split_accumulator,
is_x_1d_scaled,
is_w_1d_scaled,
)
@pytest.mark.parametrize(
"M, K, N",
[
# k = 128
(256, 128, 256),
# non 128x128 divisible input shapes
(320, 64, 336),
# k > 128
(256, 256, 256),
],
)
@pytest.mark.parametrize("x_dtype", [torch.float8_e4m3fn, torch.float8_e5m2], ids=str)
@pytest.mark.parametrize("w_dtype", [torch.float8_e4m3fn, torch.float8_e5m2], ids=str)
@pytest.mark.parametrize("out_dtype", [torch.bfloat16, torch.float32], ids=str)
@pytest.mark.parametrize("noise_type", ["normal"], ids=str)
@pytest.mark.parametrize("x_magnitude", [1e-3], ids=str)
@pytest.mark.parametrize("w_magnitude", [1], ids=str)
@pytest.mark.parametrize("accumulate", [True, False], ids=["accumulate", "no_accumulate"])
@pytest.mark.parametrize("use_split_accumulator", [True], ids=["split_acc"])
@pytest.mark.parametrize(
"is_x_1d_scaled, is_w_1d_scaled",
[
(True, False),
(True, True),
(False, True),
],
ids=["1Dx2D", "1Dx1D", "2Dx1D"],
)
def test_cublas_gemm_fp8_blockwise_bias(
x_dtype,
w_dtype,
out_dtype,
M,
K,
N,
noise_type,
x_magnitude,
w_magnitude,
accumulate,
use_split_accumulator,
is_x_1d_scaled,
is_w_1d_scaled,
):
cublas_gemm_fp8_blockwise_case(
x_dtype,
w_dtype,
out_dtype,
M,
K,
N,
noise_type,
x_magnitude,
w_magnitude,
accumulate,
use_split_accumulator,
is_x_1d_scaled,
is_w_1d_scaled,
use_bias=True,
)
@pytest.mark.parametrize(
"M, K, N",
[
# k = 128
(256, 128, 256),
# non 128x128 divisible input shapes
(16, 128, 128),
(320, 64, 336),
# k > 128
(4096, 128, 4096),
],
)
@pytest.mark.parametrize("x_dtype", [torch.float8_e4m3fn, torch.float8_e5m2], ids=str)
@pytest.mark.parametrize("w_dtype", [torch.float8_e4m3fn, torch.float8_e5m2], ids=str)
@pytest.mark.parametrize("out_dtype", [torch.bfloat16, torch.float32], ids=str)
@pytest.mark.parametrize("noise_type", ["normal"], ids=str)
@pytest.mark.parametrize("x_magnitude", [1], ids=str)
@pytest.mark.parametrize("w_magnitude", [1], ids=str)
@pytest.mark.parametrize("accumulate", [True, False], ids=["accumulate", "no_accumulate"])
@pytest.mark.parametrize("use_split_accumulator", [True], ids=["split_acc"])
@pytest.mark.parametrize(
"is_x_1d_scaled, is_w_1d_scaled",
[
(True, False),
(True, True),
(False, True),
],
ids=["1Dx2D", "1Dx1D", "2Dx1D"],
)
@pytest.mark.parametrize(
"is_x_columnwise, is_w_columnwise",
[
(True, False),
(True, True),
(False, True),
],
ids=["colxrow", "colxcol", "rowxcol"],
)
def test_cublas_gemm_fp8_blockwise_columnwise(
x_dtype,
w_dtype,
out_dtype,
M,
K,
N,
noise_type,
x_magnitude,
w_magnitude,
accumulate,
use_split_accumulator,
is_x_1d_scaled,
is_w_1d_scaled,
is_x_columnwise,
is_w_columnwise,
):
cublas_gemm_fp8_blockwise_case(
x_dtype,
w_dtype,
out_dtype,
M,
K,
N,
noise_type,
x_magnitude,
w_magnitude,
accumulate,
use_split_accumulator,
is_x_1d_scaled,
is_w_1d_scaled,
x_columnwise=is_x_columnwise,
w_columnwise=is_w_columnwise,
)
@pytest.mark.parametrize(
"M, K, N",
[
# k = 128
(256, 128, 256),
# non 128x128 divisible input shapes
(320, 64, 336),
# k > 128
(256, 256, 256),
],
)
@pytest.mark.parametrize("x_dtype", [torch.float8_e4m3fn], ids=str)
@pytest.mark.parametrize("w_dtype", [torch.float8_e4m3fn], ids=str)
@pytest.mark.parametrize("out_dtype", [torch.bfloat16], ids=str)
@pytest.mark.parametrize("noise_type", ["normal"], ids=str)
@pytest.mark.parametrize("x_magnitude", [1], ids=str)
@pytest.mark.parametrize("w_magnitude", [1], ids=str)
@pytest.mark.parametrize("accumulate", [True, False], ids=["accumulate", "no_accumulate"])
@pytest.mark.parametrize("use_split_accumulator", [True], ids=["split_acc"])
@pytest.mark.parametrize(
"is_x_1d_scaled, is_w_1d_scaled",
[
(True, False),
(True, True),
(False, True),
],
ids=["1Dx2D", "1Dx1D", "2Dx1D"],
)
@pytest.mark.parametrize(
"use_grad",
[
True,
],
ids=["grad"],
)
def test_cublas_gemm_fp8_gelu(
x_dtype,
w_dtype,
out_dtype,
M,
K,
N,
noise_type,
x_magnitude,
w_magnitude,
accumulate,
use_split_accumulator,
is_x_1d_scaled,
is_w_1d_scaled,
use_grad,
):
# NOTE: cuBLAS does not complain when use_grad is False, but the tests do not pass,
# so the epilogue is disabled on the Transformer Engine side.
if not use_grad and not (is_x_1d_scaled and not is_w_1d_scaled):
pytest.skip(
"CUBLASLT_EPILOGUE_GELU_AUX epilogue is only supported for 1Dx2D (cuBLAS 2Dx1D)."
)
cublas_gemm_fp8_blockwise_case(
x_dtype,
w_dtype,
out_dtype,
M,
K,
N,
noise_type,
x_magnitude,
w_magnitude,
accumulate,
use_split_accumulator,
is_x_1d_scaled,
is_w_1d_scaled,
use_gelu=True,
use_grad=use_grad,
)
@pytest.mark.parametrize(
"M, K, N",
[
# k = 128
(256, 128, 256),
],
)
@pytest.mark.parametrize("x_dtype", [torch.float8_e4m3fn], ids=str)
@pytest.mark.parametrize("w_dtype", [torch.float8_e4m3fn], ids=str)
@pytest.mark.parametrize("out_dtype", [torch.bfloat16, torch.float32], ids=str)
@pytest.mark.parametrize("accumulate", [True, False], ids=["accumulate", "no_accumulate"])
@pytest.mark.parametrize("use_split_accumulator", [False], ids=["split_acc"])
@pytest.mark.parametrize(
"is_x_1d_scaled, is_w_1d_scaled",
[
(True, False),
(True, True),
(False, True),
],
ids=["1Dx2D", "1Dx1D", "2Dx1D"],
)
def test_split_accumulator_enforced(
x_dtype,
w_dtype,
out_dtype,
M,
K,
N,
accumulate,
use_split_accumulator,
is_x_1d_scaled,
is_w_1d_scaled,
) -> None:
cublas_gemm_test_constraint_enforced(
x_dtype,
w_dtype,
out_dtype,
M,
K,
N,
accumulate,
use_split_accumulator,
is_x_1d_scaled,
is_w_1d_scaled,
)
@pytest.mark.parametrize(
"M, K, N",
[
# k = 128
(256, 128, 256),
],
)
@pytest.mark.parametrize("x_dtype", [torch.float8_e4m3fn], ids=str)
@pytest.mark.parametrize("w_dtype", [torch.float8_e4m3fn], ids=str)
@pytest.mark.parametrize("out_dtype", [torch.bfloat16, torch.float32], ids=str)
@pytest.mark.parametrize("accumulate", [True, False], ids=["accumulate", "no_accumulate"])
@pytest.mark.parametrize("use_split_accumulator", [True], ids=["split_acc"])
@pytest.mark.parametrize(
"is_x_1d_scaled, is_w_1d_scaled",
[
(True, False),
(True, True),
(False, True),
],
ids=["1Dx2D", "1Dx1D", "2Dx1D"],
)
def test_bgrad_not_supported(
x_dtype,
w_dtype,
out_dtype,
M,
K,
N,
accumulate,
use_split_accumulator,
is_x_1d_scaled,
is_w_1d_scaled,
) -> None:
# NOTE: BGRAD epilogue is not supported for fp8.
cublas_gemm_test_constraint_enforced(
x_dtype,
w_dtype,
out_dtype,
M,
K,
N,
accumulate,
use_split_accumulator,
is_x_1d_scaled,
is_w_1d_scaled,
use_grad=True,
use_bias=True,
expected_err_msg="Epilogue requested outside of the available",
)
@pytest.mark.parametrize(
"M, K, N",
[
# k = 128
(256, 128, 256),
],
)
@pytest.mark.parametrize("x_dtype", [torch.float8_e4m3fn], ids=str)
@pytest.mark.parametrize("w_dtype", [torch.float8_e4m3fn], ids=str)
@pytest.mark.parametrize("out_dtype", [torch.bfloat16, torch.float32], ids=str)
@pytest.mark.parametrize("accumulate", [True, False], ids=["accumulate", "no_accumulate"])
@pytest.mark.parametrize("use_bias", [True, False], ids=["bias", "no_bias"])
@pytest.mark.parametrize("use_grad", [True, False], ids=["grad", "no_grad"])
@pytest.mark.parametrize("use_split_accumulator", [True], ids=["split_acc"])
@pytest.mark.parametrize(
"is_x_1d_scaled, is_w_1d_scaled",
[
(True, False),
(True, True),
(False, True),
],
ids=["1Dx2D", "1Dx1D", "2Dx1D"],
)
def test_gelu_unsupported_cases_error(
x_dtype,
w_dtype,
out_dtype,
M,
K,
N,
accumulate,
use_bias,
use_grad,
use_split_accumulator,
is_x_1d_scaled,
is_w_1d_scaled,
) -> None:
if use_grad and not use_bias and out_dtype == torch.bfloat16:
pytest.skip("DGELU epilogue is supported for bfloat16.")
elif use_grad and not use_bias:
expected_err = "an unsupported value or parameter was passed"
else:
expected_err = "Epilogue requested outside of the available"
cublas_gemm_test_constraint_enforced(
x_dtype,
w_dtype,
out_dtype,
M,
K,
N,
accumulate,
use_split_accumulator,
is_x_1d_scaled,
is_w_1d_scaled,
use_grad=use_grad,
use_bias=use_bias,
use_gelu=True,
expected_err_msg=expected_err,
)
@pytest.mark.parametrize(
"M, K, N",
[
(256, 128, 256),
],
)
@pytest.mark.parametrize("x_dtype", [torch.float8_e5m2], ids=str)
@pytest.mark.parametrize("w_dtype", [torch.float8_e5m2], ids=str)
@pytest.mark.parametrize("out_dtype", [torch.bfloat16, torch.float32], ids=str)
@pytest.mark.parametrize("accumulate", [True, False], ids=["accumulate", "no_accumulate"])
@pytest.mark.parametrize("use_split_accumulator", [True], ids=["split_acc"])
@pytest.mark.parametrize(
"is_x_1d_scaled, is_w_1d_scaled",
[
(True, False),
(True, True),
(False, True),
],
ids=["1Dx2D", "1Dx1D", "2Dx1D"],
)
def test_illegal_dtype_enforced(
x_dtype,
w_dtype,
out_dtype,
M,
K,
N,
accumulate,
use_split_accumulator,
is_x_1d_scaled,
is_w_1d_scaled,
) -> None:
# e5m2 by e5m2 not supported.
cublas_gemm_test_constraint_enforced(
x_dtype,
w_dtype,
out_dtype,
M,
K,
N,
accumulate,
use_split_accumulator,
is_x_1d_scaled,
is_w_1d_scaled,
)
@pytest.mark.parametrize(
"M, K, N",
[
(256, 128, 256),
],
)
@pytest.mark.parametrize("x_dtype", [torch.float8_e4m3fn], ids=str)
@pytest.mark.parametrize("w_dtype", [torch.float8_e4m3fn], ids=str)
@pytest.mark.parametrize("out_dtype", [torch.bfloat16, torch.float32], ids=str)
@pytest.mark.parametrize("accumulate", [True, False], ids=["accumulate", "no_accumulate"])
@pytest.mark.parametrize("use_split_accumulator", [True], ids=["split_acc"])
@pytest.mark.parametrize(
"is_x_1d_scaled, is_w_1d_scaled",
[
(False, False),
],
ids=["2Dx2D"],
)
def test_illegal_2D_by_2D_enforced(
x_dtype,
w_dtype,
out_dtype,
M,
K,
N,
accumulate,
use_split_accumulator,
is_x_1d_scaled,
is_w_1d_scaled,
) -> None:
# 2D block quantization by 2D block quantization is not supported.
expected_err_msg = "Only 1D by 1D, 1D by 2D, and 2D by 1D block scaling supported"
cublas_gemm_test_constraint_enforced(
x_dtype,
w_dtype,
out_dtype,
M,
K,
N,
accumulate,
use_split_accumulator,
is_x_1d_scaled,
is_w_1d_scaled,
expected_err_msg=expected_err_msg,
)
@pytest.mark.parametrize(
"M, K, N, legalX1d, legalX2d",
[
# M dim unconstrained when X is 2D.
(255, 128, 256, False, True),
# K must be multiple of 16
(256, 120, 256, False, False),
# N must be a multiple of 8
(256, 128, 252, False, False),
],
)
@pytest.mark.parametrize("x_dtype", [torch.float8_e4m3fn], ids=str)
@pytest.mark.parametrize("w_dtype", [torch.float8_e4m3fn], ids=str)
@pytest.mark.parametrize("out_dtype", [torch.bfloat16], ids=str)
@pytest.mark.parametrize("accumulate", [False], ids=["no_accumulate"])
@pytest.mark.parametrize("use_split_accumulator", [True], ids=["split_acc"])
@pytest.mark.parametrize(
"is_x_1d_scaled, is_w_1d_scaled",
[
(True, False),
(False, True),
(True, True),
],
ids=["1Dx2D", "2Dx1D", "1Dx1D"],
)
def test_unaligned_shapes(
x_dtype,
w_dtype,
out_dtype,
M,
K,
N,
legalX1d,
legalX2d,
accumulate,
use_split_accumulator,
is_x_1d_scaled,
is_w_1d_scaled,
) -> None:
legal = legalX1d if is_x_1d_scaled else legalX2d
if not legal:
cublas_gemm_test_constraint_enforced(
x_dtype,
w_dtype,
out_dtype,
M,
K,
N,
accumulate,
use_split_accumulator,
is_x_1d_scaled,
is_w_1d_scaled,
expected_err_msg="dimension requirement",
)
else:
cublas_gemm_fp8_blockwise_case(
x_dtype,
w_dtype,
out_dtype,
M,
K,
N,
"uniform", # noise type
1.0, # x_magnitude
1.0, # w_magnitude
accumulate,
use_split_accumulator,
is_x_1d_scaled,
is_w_1d_scaled,
)
# Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# See LICENSE for license information.
from typing import Tuple
import math
import os
import pathlib
import pytest
import torch
import transformer_engine as te
import transformer_engine_torch as tex
from transformer_engine.pytorch.fp8 import FP8GlobalStateManager
from transformer_engine.common.recipe import Float8BlockScaling
from transformer_engine.pytorch.constants import TE_DType
from transformer_engine.pytorch.tensor.float8_blockwise_tensor import (
Float8BlockQuantizer,
Float8BlockwiseQTensor,
)
from references.blockwise_quantizer_reference import (
BlockwiseQuantizerReference,
QuantizeResult,
)
from test_float8_current_scaling_exact import (
TestFP8RecipeLinearBase,
TestFP8RecipeLayerNormLinearBase,
)
# Read env variable NVTE_TEST_BLOCK_CURRENT_SCALING_EXACT_TENSOR_DUMP_DIR to override the default tensor dump directory
TENSOR_DUMP_DIR = pathlib.Path(__file__).resolve().parent.parent.parent / "tensor_dumps"
tensor_dump_dir_env = os.getenv("NVTE_TEST_BLOCK_CURRENT_SCALING_EXACT_TENSOR_DUMP_DIR")
if tensor_dump_dir_env is not None:
TENSOR_DUMP_DIR = pathlib.Path(tensor_dump_dir_env)
recipe_available, reason_for_no_recipe = FP8GlobalStateManager.is_fp8_block_scaling_available()
class GetRecipes:
@staticmethod
def none():
return None
@staticmethod
def fp8_blockwise():
# return default configs
return Float8BlockScaling()
def initialize_for_many_scales(
x_shape_2d: Tuple[int, int], tile_shape: Tuple[int, int], *, dtype: torch.dtype, device: str
) -> torch.Tensor:
"""
Put separate distributions into each quantization tile
to avoid many tiles having similar scale values and
causing false passes.
"""
tile_grid_shape = (
math.ceil(x_shape_2d[0] / tile_shape[0]),
math.ceil(x_shape_2d[1] / tile_shape[1]),
)
# Arbitrary size
max_val = 8192.0
# Make a uniform distribution of [-max_val, max_val]
tile_extrema = torch.rand(*tile_grid_shape, dtype=dtype) * max_val * 2 - max_val
result = torch.empty(x_shape_2d, dtype=dtype, device=device)
tile_elements = tile_shape[0] * tile_shape[1]
for i in range(tile_grid_shape[0]):
for j in range(tile_grid_shape[1]):
target = tile_extrema[i, j].item()
step = target / (tile_elements)
if target == 0:
tile = torch.zeros(tile_shape, dtype=dtype, device=device)
else:
tile = torch.arange(0.0, target, step=step, dtype=dtype, device=device)
tile = tile.reshape(*tile_shape)
min_dst_vals = (i * tile_shape[0], j * tile_shape[1])
max_dst_vals = (
min((i + 1) * tile_shape[0], x_shape_2d[0]),
min((j + 1) * tile_shape[1], x_shape_2d[1]),
)
max_src_vals = (
max_dst_vals[0] - min_dst_vals[0],
max_dst_vals[1] - min_dst_vals[1],
)
result[min_dst_vals[0] : max_dst_vals[0], min_dst_vals[1] : max_dst_vals[1]] = tile[
: max_src_vals[0], : max_src_vals[1]
]
return result
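# Illustrative example (values assumed): initialize_for_many_scales((256, 256), (128, 128),
# dtype=torch.float32, device="cuda") fills a 2x2 grid of 128x128 tiles, each with its own
# ramp from 0 toward a randomly drawn extremum, so each tile quantizes with a distinct scale.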
def check_quantization_block_tiling_versus_reference(
x_dtype: torch.dtype,
M: int,
N: int,
quant_dtype: torch.dtype,
eps: float,
return_transpose: bool,
pow_2_scales: bool,
tile_size: Tuple[int, int],
) -> None:
te_dtype = TE_DType[quant_dtype]
if tile_size == (1, 128):
block_scaling_dim = 1
elif tile_size == (128, 128):
block_scaling_dim = 2
else:
raise ValueError("Non support tile size")
# This test runs a comparison of the ref class versus the class using
# CUDA kernels to quantize. They should quantize identically for pixels
# that are not DC values in the scale factor shape.
ref_quantizer = BlockwiseQuantizerReference()
sut_quantizer = Float8BlockQuantizer(
fp8_dtype=te_dtype,
rowwise=True,
columnwise=return_transpose,
amax_epsilon=eps,
force_pow_2_scales=pow_2_scales,
block_scaling_dim=block_scaling_dim,
)
# Setup device and random seed
device = "cuda"
seed = 0
torch.manual_seed(seed)
torch.cuda.manual_seed(seed)
# Input
x = initialize_for_many_scales((M, N), tile_size, dtype=x_dtype, device=device)
x_fp8_sut = sut_quantizer.make_empty((M, N), dtype=x_dtype, device=device, requires_grad=False)
x_fp8_sut = sut_quantizer.update_quantized(x, x_fp8_sut)
assert x_fp8_sut._rowwise_data is not None
qx: torch.Tensor = x_fp8_sut._rowwise_data.view(dtype=quant_dtype)
assert x_fp8_sut._rowwise_scale_inv is not None
sx: torch.Tensor = x_fp8_sut._rowwise_scale_inv
qx_t = x_fp8_sut._columnwise_data
sx_t = x_fp8_sut._columnwise_scale_inv
qresult_ref = ref_quantizer.quantize(
x,
quant_dtype=quant_dtype,
return_transpose=return_transpose,
eps=eps,
pow_2_scales=pow_2_scales,
quant_tile_shape=tile_size,
)
qx_ref, sx_ref, qx_t_ref, sx_t_ref = (
qresult_ref.data,
qresult_ref.scale,
qresult_ref.data_t,
qresult_ref.scale_t,
)
# Check
torch.testing.assert_close(qx.float(), qx_ref.float(), atol=0.0, rtol=0.0)
# Zero out values that are don't care values
# Scale format has padding.
scale_mask = torch.ones(
(math.ceil(M / tile_size[0]), math.ceil(N / tile_size[1])), device=sx.device
)
scale_mask = ref_quantizer.scale_munger.munge_scale_shapes_for_backend(
QuantizeResult(qx, scale_mask, None, None), tile_size
).scale
sx = sx * scale_mask
torch.testing.assert_close(sx, sx_ref, atol=0.0, rtol=0.0)
if return_transpose:
assert qx_t is not None
qx_t = qx_t.view(dtype=quant_dtype)
assert qx_t_ref is not None
assert sx_t is not None
assert sx_t_ref is not None
scale_mask = torch.ones(
(math.ceil(N / tile_size[0]), math.ceil(M / tile_size[1])),
device=sx_t.device,
)
scale_mask = ref_quantizer.scale_munger.munge_scale_shapes_for_backend(
QuantizeResult(qx_t, scale_mask, None, None), tile_size
).scale
sx_t = sx_t * scale_mask
torch.testing.assert_close(qx_t.float(), qx_t_ref.float(), atol=0.0, rtol=0.0)
torch.testing.assert_close(sx_t, sx_t_ref, atol=0.0, rtol=0.0)
else:
# should be None
assert qx_t is None and qx_t_ref is None
assert sx_t is None and sx_t_ref is None
@pytest.mark.skipif(not recipe_available, reason=reason_for_no_recipe)
@pytest.mark.parametrize(
"M, N",
[
# full tile cases
(128, 128),
(256, 256),
(256, 1024),
(1024, 256),
# Padding required cases
(256, 272),
(303, 300),
(305, 256),
# Some larger tiles.
(2000, 2000),
(2048, 2000),
(2000, 1024),
(2048, 1024),
],
)
@pytest.mark.parametrize("x_dtype", [torch.float32, torch.bfloat16], ids=str)
@pytest.mark.parametrize("quant_dtype", [torch.float8_e4m3fn, torch.float8_e5m2], ids=str)
@pytest.mark.parametrize("eps", [0], ids=["eps_0"])
@pytest.mark.parametrize(
"return_transpose", [True, False], ids=["quantize_transpose", "quantize_only"]
)
@pytest.mark.parametrize("pow_2_scales", [True], ids=["pow2scales"])
@pytest.mark.parametrize("tile_size", [(1, 128), (128, 128)], ids=["1DTile", "2DTile"])
def test_quantization_block_tiling_versus_reference(
x_dtype: torch.dtype,
M: int,
N: int,
quant_dtype: torch.dtype,
eps: float,
return_transpose: bool,
pow_2_scales: bool,
tile_size: Tuple[int, int],
) -> None:
check_quantization_block_tiling_versus_reference(
x_dtype, M, N, quant_dtype, eps, return_transpose, pow_2_scales, tile_size
)
@pytest.mark.skipif(not recipe_available, reason=reason_for_no_recipe)
@pytest.mark.parametrize(
"M, N",
[
# full tile cases
(256, 256),
(2048, 1024),
# Padding required cases
(256, 272),
(303, 300),
],
)
@pytest.mark.parametrize("x_dtype", [torch.float32, torch.bfloat16], ids=str)
@pytest.mark.parametrize("quant_dtype", [torch.float8_e4m3fn, torch.float8_e5m2], ids=str)
@pytest.mark.parametrize("eps", [0], ids=["eps_0"])
@pytest.mark.parametrize(
"return_transpose", [True, False], ids=["quantize_transpose", "quantize_only"]
)
@pytest.mark.parametrize("pow_2_scales", [False], ids=["fp32scales"])
@pytest.mark.parametrize("tile_size", [(1, 128), (128, 128)], ids=["1DTile", "2DTile"])
def test_quantization_block_tiling_versus_reference_fp32_scales(
x_dtype: torch.dtype,
M: int,
N: int,
quant_dtype: torch.dtype,
eps: float,
return_transpose: bool,
pow_2_scales: bool,
tile_size: Tuple[int, int],
) -> None:
check_quantization_block_tiling_versus_reference(
x_dtype, M, N, quant_dtype, eps, return_transpose, pow_2_scales, tile_size
)
@pytest.mark.skipif(not recipe_available, reason=reason_for_no_recipe)
@pytest.mark.parametrize(
"M, N",
[
# full tile cases
(128, 128),
],
)
@pytest.mark.parametrize("x_dtype", [torch.float32, torch.bfloat16], ids=str)
@pytest.mark.parametrize("quant_dtype", [torch.float8_e4m3fn, torch.float8_e5m2], ids=str)
@pytest.mark.parametrize("eps", [0], ids=["eps_0"])
@pytest.mark.parametrize("pow_2_scales", [True, False], ids=["pow2scales", "fp32scales"])
@pytest.mark.parametrize("tile_size", [(128, 128)])
@pytest.mark.parametrize("extrema_high", [False, True], ids=["zeros", "maxes"])
def test_quantization_block_tiling_extrema_versus_reference(
x_dtype: torch.dtype,
M: int,
N: int,
quant_dtype: torch.dtype,
eps: float,
pow_2_scales: bool,
tile_size: Tuple[int, int],
extrema_high: bool,
) -> None:
# This test runs a single tile through a quantizer as a way to test
# branch coverage of scale computation.
te_dtype = TE_DType[quant_dtype]
if tile_size == (1, 128):
block_scaling_dim = 1
elif tile_size == (128, 128):
block_scaling_dim = 2
else:
raise ValueError("Non support tile size")
ref_quantizer = BlockwiseQuantizerReference()
sut_quantizer = Float8BlockQuantizer(
fp8_dtype=te_dtype,
rowwise=True,
columnwise=False,
amax_epsilon=eps,
force_pow_2_scales=pow_2_scales,
block_scaling_dim=block_scaling_dim,
)
# Setup device and random seed
device = "cuda"
seed = 0
torch.manual_seed(seed)
torch.cuda.manual_seed(seed)
return_transpose = False
# Input
if extrema_high:
x = torch.full((M, N), torch.finfo(x_dtype).max, dtype=x_dtype, device=device)
else:
x = torch.zeros((M, N), dtype=x_dtype, device=device)
# Run cast and transpose kernel
# Internal call ops.quantize_tensorwise
x_fp8_sut = sut_quantizer.make_empty((M, N), dtype=x_dtype, device=device, requires_grad=False)
x_fp8_sut = sut_quantizer.update_quantized(x, x_fp8_sut)
qx = x_fp8_sut._rowwise_data.view(dtype=quant_dtype)
sx = x_fp8_sut._rowwise_scale_inv
qresult_ref = ref_quantizer.quantize(
x,
quant_dtype=quant_dtype,
return_transpose=return_transpose,
eps=eps,
pow_2_scales=pow_2_scales,
quant_tile_shape=tile_size,
)
qx_ref, sx_ref = (
qresult_ref.data,
qresult_ref.scale,
)
# Check
torch.testing.assert_close(qx.float(), qx_ref.float(), atol=0.0, rtol=0.0)
torch.testing.assert_close(sx.flatten()[0], sx_ref.flatten()[0], atol=0.0, rtol=0.0)
if extrema_high:
expected_value = torch.finfo(quant_dtype).max / torch.finfo(x_dtype).max
if pow_2_scales:
expected_value = math.floor(math.log2(expected_value))
expected_value = math.pow(2.0, expected_value)
expected_value = 1 / expected_value
elif not extrema_high and eps == 0:
expected_value = 1.0
else:
assert not extrema_high
# eps is small enough to trigger inf in quant_dtype_max / eps
if pow_2_scales:
expected_value = math.pow(2.0, -127)
else:
expected_value = 1 / torch.finfo(x_dtype).max
torch.testing.assert_close(
sx.flatten()[0],
torch.tensor(expected_value, device=sx.device),
atol=0.0,
rtol=0.0,
)
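# Worked example of the expected_value branches above (illustrative): with
# x_dtype=torch.bfloat16, quant_dtype=torch.float8_e4m3fn and extrema_high=True,
# amax is ~3.39e38 and fp8_max is 448, so the float scale is ~1.32e-36; with
# pow_2_scales=True it is rounded down to 2**-120, and the stored scale_inv is 2**120.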
# FP8 block scaling recipe
@pytest.mark.skipif(not recipe_available, reason=reason_for_no_recipe)
class TestFP8BlockScalingRecipeLinear(TestFP8RecipeLinearBase):
@staticmethod
def setup_class(cls) -> None:
# Configure RNG
seed = 1234
torch.manual_seed(seed)
torch.cuda.manual_seed(seed)
@pytest.mark.parametrize(
"batch_size, hidden_size, out_size",
[
(16, 256, 128),
],
)
@pytest.mark.parametrize("dtype", [torch.bfloat16], ids=["bf16"])
@pytest.mark.parametrize(
"recipe1, recipe2",
[
(GetRecipes.none, GetRecipes.fp8_blockwise),
],
)
def test_fp8_current_scaling_with_linear_module(
self,
recipe1,
recipe2,
batch_size,
hidden_size,
out_size,
dtype,
use_bias=True,
):
fp8_zero_tolerance_tensor_dumps_recipe2 = None
# check tensor dumps dir, if the dir exists, then read files to get y, dgrad, wgrad, bgrad
# if we cannot get all four tensors, then still set the tensor dump to None
tensor_map = self._check_golden_tensor_dumps(
TENSOR_DUMP_DIR, recipe2, (batch_size, hidden_size, out_size), dtype, use_bias
)
if tensor_map is not None:
fp8_zero_tolerance_tensor_dumps_recipe2 = tensor_map
self.compare_recipe(
recipe1,
recipe2,
batch_size,
hidden_size,
out_size,
use_bias,
seed=torch.initial_seed(),
dtype=dtype,
y_error=0.5,
dgrad_error=1,
wgrad_error=1,
bgrad_error=0.5,
recipe1_golden_tensors=None,
recipe2_golden_tensors=fp8_zero_tolerance_tensor_dumps_recipe2,
)
@pytest.mark.skipif(not recipe_available, reason=reason_for_no_recipe)
class TestFP8BlockScalingRecipeLayerNormLinear(TestFP8RecipeLayerNormLinearBase):
@staticmethod
def setup_class(cls) -> None:
# Configure RNG
seed = 1234
torch.manual_seed(seed)
torch.cuda.manual_seed(seed)
@pytest.mark.parametrize(
"batch_size, hidden_size, out_size",
[
(16, 256, 128),
],
)
@pytest.mark.parametrize("dtype", [torch.bfloat16], ids=["bf16"])
@pytest.mark.parametrize(
"recipe1, recipe2",
[
(GetRecipes.none, GetRecipes.fp8_blockwise),
],
)
def test_fp8_current_scaling_with_layernorm_linear_module(
self,
recipe1,
recipe2,
batch_size,
hidden_size,
out_size,
dtype,
use_bias=True,
):
fp8_zero_tolerance_tensor_dumps_recipe2 = None
# check tensor dumps dir, if the dir exists, then read files to get y, dgrad, wgrad, bgrad
# if we cannot get all four tensors, then still set the tensor dump to None
tensor_map = self._check_golden_tensor_dumps(
TENSOR_DUMP_DIR,
recipe2,
(batch_size, hidden_size, out_size),
dtype,
use_bias,
"LayerNorm",
)
if tensor_map is not None:
fp8_zero_tolerance_tensor_dumps_recipe2 = tensor_map
self.compare_recipe(
recipe1,
recipe2,
batch_size,
hidden_size,
out_size,
use_bias,
seed=torch.initial_seed(),
dtype=dtype,
y_error=0.5,
ln_out_error=0.5,
dgrad_error=1.6,
wgrad_error=1,
bgrad_error=0.5,
recipe1_golden_tensors=None,
recipe2_golden_tensors=fp8_zero_tolerance_tensor_dumps_recipe2,
)
@@ -82,7 +82,8 @@ class TestFP8RecipeLinearBase:
     @staticmethod
     def _get_mean_abs_relative_error(a, b):
-        return torch.mean(torch.abs((a - b) / b))
+        error = torch.where(b == 0, torch.ne(a, b), torch.abs((a - b) / b))
+        return torch.mean(error)

     @staticmethod
     def _load_golden_tensor_values(a, b):
@@ -97,9 +98,12 @@ class TestFP8RecipeLinearBase:
         fp8_type_g = get_fp8_torch_dtype(recipe, fprop_tensor=False)

         # Expected tensor names based on the naming template
-        scaling_type = (  # Assuming the scaling type is PER_TENSOR for this example
-            "ScalingType.PER_TENSOR"
-        )
+        if recipe.float8_current_scaling():
+            scaling_type = "ScalingType.PER_TENSOR"
+        elif recipe.float8_block_scaling():
+            scaling_type = "ScalingType.VECTOR_TILED_X_AND_G_BLOCK_TILED_W"
+        else:
+            scaling_type = "Unknown"
         current_seed = torch.initial_seed()  # Get the current seed
         expected_tensor_names = {
@@ -437,9 +441,13 @@ class TestFP8RecipeLayerNormLinearBase(TestFP8RecipeLinearBase):
         fp8_type_g = get_fp8_torch_dtype(recipe, fprop_tensor=False)

         # Expected tensor names based on the naming template
-        scaling_type = (  # Assuming the scaling type is PER_TENSOR for this example
-            "ScalingType.PER_TENSOR"
-        )
+        if recipe.float8_current_scaling():
+            scaling_type = "ScalingType.PER_TENSOR"
+        elif recipe.float8_block_scaling():
+            scaling_type = "ScalingType.VECTOR_TILED_X_AND_G_BLOCK_TILED_W"
+        else:
+            scaling_type = "Unknown"
         current_seed = torch.initial_seed()  # Get the current seed
         expected_tensor_names = {
# Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# See LICENSE for license information.
from collections.abc import Iterable
import io
import math
from typing import Any, Dict, List, Tuple, Union
import pytest
import torch
import transformer_engine.common.recipe
import transformer_engine.pytorch as te
from transformer_engine.pytorch.tensor.float8_blockwise_tensor import (
Float8BlockQuantizer,
Float8BlockwiseQTensor,
)
from transformer_engine.pytorch.utils import get_device_compute_capability
import transformer_engine_torch as tex
# PyTorch tensor dtypes
_dtypes: List[torch.dtype] = [torch.float32, torch.float16, torch.bfloat16]
# TE FP8 dtypes
_fp8_dtypes: List[tex.DType] = [tex.DType.kFloat8E4M3, tex.DType.kFloat8E5M2]
# Numerical tolerances with FP8 types
_tols: Dict[tex.DType, Dict[str, float]] = {
tex.DType.kFloat8E4M3: dict(rtol=0.125, atol=0.08),
tex.DType.kFloat8E5M2: dict(rtol=0.25, atol=0.125),
}
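# The rtol values above correspond to the relative spacing of the FP8 formats:
# E4M3 has 3 mantissa bits (~2**-3 = 0.125) and E5M2 has 2 (~2**-2 = 0.25).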
def _to_list(x: Union[Iterable, Any]) -> List:
"""Convert to list if iterable, otherwise put in singleton list"""
if isinstance(x, Iterable):
return list(x)
else:
return [x]
# Types that can be interpreted as tensor dims
DimsType = Union[Iterable[int], int]
# TODO replace with call to fp8.py when recipe added.
recipe_available = get_device_compute_capability() >= (9, 0) and float(torch.version.cuda) >= 12.8
reason_for_no_recipe = "Quantize kernels require TMA and are only relevant with GEMMS."
@pytest.mark.skipif(not recipe_available, reason=reason_for_no_recipe)
class TestFloat8BlockwiseTensor:
@staticmethod
def setup_class(cls) -> None:
# Configure RNG
seed = 1234
torch.manual_seed(seed)
torch.cuda.manual_seed(seed)
def test_constructor(
self,
dims: DimsType = 1,
fp8_dtype: tex.DType = tex.DType.kFloat8E4M3,
dtype: torch.dtype = torch.float32,
is_2D_scaled: bool = True,
) -> None:
"""Call constructor and perform sanity checks"""
dims = _to_list(dims)
rowwise = True
columnwise = True
quantizer = Float8BlockQuantizer(
fp8_dtype=fp8_dtype,
rowwise=rowwise,
columnwise=columnwise,
block_scaling_dim=2 if is_2D_scaled else 1,
)
scale_dims = quantizer.get_scale_shape(dims, columnwise=False)
columnwise_scale_dims = quantizer.get_scale_shape(dims, columnwise=True)
columnwise_dims = quantizer.get_columnwise_shape(dims)
tensor = Float8BlockwiseQTensor(
shape=dims,
dtype=dtype,
rowwise_data=torch.zeros(dims, device="cuda", dtype=torch.uint8),
rowwise_scale_inv=torch.zeros(scale_dims, device="cuda", dtype=torch.float32),
columnwise_data=torch.zeros(columnwise_dims, device="cuda", dtype=torch.uint8),
columnwise_scale_inv=torch.zeros(
columnwise_scale_dims, device="cuda", dtype=torch.float32
),
fp8_dtype=fp8_dtype,
is_2D_scaled=is_2D_scaled,
quantizer=quantizer,
)
assert list(tensor.size()) == dims, "Incorrect dims"
assert tensor.dtype == dtype, "Incorrect nominal dtype"
assert tensor.is_cuda, "Incorrect device"
def _test_quantize_dequantize(
self,
quantizer: Float8BlockQuantizer,
dtype: torch.dtype = torch.float32,
dims: DimsType = (23, 128),
rtol: float = 0.0,
atol: float = 0.0,
dequant_columnwise: bool = False,
use_cpp_allocation: bool = False,
) -> None:
"""Check numerical error when casting to FP8 and back"""
dims = _to_list(dims)
# Initialize random data
# Note: Make sure values are not all close to zero, or else
# test may pass trivially.
x_ref = 2 * torch.rand(dims, dtype=dtype, device="cpu") - 1
x_ref.view(-1)[0] = 0.75
x_ref_cuda = x_ref.to("cuda")
# Cast to FP8 and back
if not use_cpp_allocation:
x_fp8 = quantizer.make_empty(shape=dims, device="cuda")
quantizer.update_quantized(x_ref_cuda, x_fp8)
else:
# This codepath allows the CPP binding to allocate the output
# tensor
x_fp8 = tex.quantize(x_ref_cuda, quantizer, None, None)
if dequant_columnwise:
# Strip out rowwise data to verify dequantization of
# columnwise data.
x_fp8.update_usage(rowwise_usage=False, columnwise_usage=True)
x_fp8 = x_fp8.dequantize(dtype=dtype).cpu()
# Check results
torch.testing.assert_close(x_fp8, x_ref, rtol=rtol, atol=atol)
# Make sure we are not trivially passing the test
with pytest.raises(AssertionError):
torch.testing.assert_close(x_fp8, -x_ref, rtol=rtol, atol=atol)
@pytest.mark.parametrize("fp8_dtype", _fp8_dtypes)
@pytest.mark.parametrize("dtype", _dtypes)
@pytest.mark.parametrize("block_scaling_dim", [1, 2])
def test_quantize_dequantize_dtypes(
self, fp8_dtype: tex.DType, dtype: torch.dtype, block_scaling_dim: int
) -> None:
atol = _tols[fp8_dtype]["atol"]
rtol = _tols[fp8_dtype]["rtol"]
quantizer = Float8BlockQuantizer(
fp8_dtype=fp8_dtype,
rowwise=True,
columnwise=False,
block_scaling_dim=block_scaling_dim,
)
self._test_quantize_dequantize(quantizer=quantizer, dtype=dtype, atol=atol, rtol=rtol)
@pytest.mark.parametrize("fp8_dtype", _fp8_dtypes)
@pytest.mark.parametrize("dtype", _dtypes)
@pytest.mark.parametrize("block_scaling_dim", [1])
def test_quantize_dequantize_columnwise_only(
self, fp8_dtype: tex.DType, dtype: torch.dtype, block_scaling_dim: int
) -> None:
atol = _tols[fp8_dtype]["atol"]
rtol = _tols[fp8_dtype]["rtol"]
quantizer = Float8BlockQuantizer(
fp8_dtype=fp8_dtype,
rowwise=False,
columnwise=True,
block_scaling_dim=block_scaling_dim,
)
self._test_quantize_dequantize(
quantizer=quantizer, dtype=dtype, atol=atol, rtol=rtol, use_cpp_allocation=True
)
@pytest.mark.parametrize(
"dims", [[], 256, 311, [264], [256, 512], [250, 500], [7, 5, 3], [2, 3, 5, 3]]
)
@pytest.mark.parametrize("block_scaling_dim", [1, 2])
@pytest.mark.parametrize("dq_columnwise", [True, False])
def test_quantize_dequantize_dims(
self, dims: DimsType, block_scaling_dim: int, dq_columnwise: bool
) -> None:
atol = _tols[tex.DType.kFloat8E4M3]["atol"]
rtol = _tols[tex.DType.kFloat8E4M3]["rtol"]
quantizer = Float8BlockQuantizer(
fp8_dtype=tex.DType.kFloat8E4M3,
rowwise=True,
columnwise=dq_columnwise,
block_scaling_dim=block_scaling_dim,
)
self._test_quantize_dequantize(
quantizer=quantizer,
dims=dims,
atol=atol,
rtol=rtol,
dequant_columnwise=dq_columnwise,
)
@pytest.mark.parametrize(
"dims", [[], 256, 311, [264], [256, 512], [250, 500], [7, 5, 3], [2, 3, 5, 3]]
)
@pytest.mark.parametrize("block_scaling_dim", [1, 2])
@pytest.mark.parametrize("fp8_dtype", _fp8_dtypes)
@pytest.mark.parametrize("dq_columnwise", [True, False])
def test_quantize_dequantize_dims_cpp_allocate_output(
self, dims: DimsType, block_scaling_dim: int, fp8_dtype: tex.DType, dq_columnwise: bool
) -> None:
atol = _tols[fp8_dtype]["atol"]
rtol = _tols[fp8_dtype]["rtol"]
quantizer = Float8BlockQuantizer(
fp8_dtype=fp8_dtype,
rowwise=True,
columnwise=dq_columnwise,
block_scaling_dim=block_scaling_dim,
)
self._test_quantize_dequantize(
quantizer=quantizer,
dims=dims,
atol=atol,
rtol=rtol,
dequant_columnwise=dq_columnwise,
use_cpp_allocation=True,
)
@pytest.mark.parametrize("dims", [[256, 512], [250, 500]])
@pytest.mark.parametrize("block_scaling_dim", [1, 2])
def test_data_accessors(self, dims: DimsType, block_scaling_dim: int) -> None:
"""Test data accessors of Float8BlockwiseQTensor"""
device = "cuda"
dtype = torch.bfloat16
x_hp = torch.rand(_to_list(dims), dtype=dtype, device=device)
y_hp = torch.rand(_to_list(dims), dtype=dtype, device=device)
fp8_dtype = tex.DType.kFloat8E4M3
quantizer = Float8BlockQuantizer(
fp8_dtype=fp8_dtype,
rowwise=True,
columnwise=True,
block_scaling_dim=block_scaling_dim,
)
# Create FP8 tensor
x_fp8 = quantizer.quantize(x_hp)
x_recovered = x_fp8.data
torch.testing.assert_close(x_recovered, x_hp, **_tols[fp8_dtype])
x_fp8.data = y_hp
y_recovered = x_fp8.data
torch.testing.assert_close(y_recovered, y_hp, **_tols[fp8_dtype])
@pytest.mark.parametrize("dims", [[256, 512], [250, 500]])
@pytest.mark.parametrize("block_scaling_dim", [1, 2])
def test_serialization(self, dims: DimsType, block_scaling_dim: int) -> None:
"""Test serialization of Float8BlockwiseQTensor"""
device = "cuda"
dtype = torch.bfloat16
x_hp = torch.rand(_to_list(dims), dtype=dtype, device=device)
quantizer = Float8BlockQuantizer(
fp8_dtype=tex.DType.kFloat8E5M2,
rowwise=True,
columnwise=True,
block_scaling_dim=block_scaling_dim,
)
# Create FP8 tensor
x_fp8 = quantizer.quantize(x_hp)
# Save tensor
buffer = io.BytesIO()
torch.save(x_fp8, buffer)
# Load tensor
buffer.seek(0)
x_fp8_loaded = torch.load(buffer, weights_only=False)
# Test that loaded tensor matches original
assert isinstance(x_fp8_loaded, Float8BlockwiseQTensor)
torch.testing.assert_close(x_fp8_loaded._rowwise_data, x_fp8._rowwise_data)
torch.testing.assert_close(x_fp8_loaded._columnwise_data, x_fp8._columnwise_data)
torch.testing.assert_close(x_fp8_loaded._rowwise_scale_inv, x_fp8._rowwise_scale_inv)
torch.testing.assert_close(x_fp8_loaded._columnwise_scale_inv, x_fp8._columnwise_scale_inv)
torch.testing.assert_close(x_fp8_loaded.data, x_fp8.data)
assert x_fp8_loaded._is_2D_scaled == x_fp8._is_2D_scaled
assert x_fp8_loaded.dtype == x_fp8.dtype
assert x_fp8_loaded._fp8_dtype == x_fp8._fp8_dtype
# Test that dequantized values match
x_fp8_dequant = x_fp8.dequantize()
x_fp8_loaded_dequant = x_fp8_loaded.dequantize()
torch.testing.assert_close(x_fp8_loaded_dequant, x_fp8_dequant)
@pytest.mark.parametrize("fp8_dtype", [tex.DType.kFloat8E4M3], ids=str)
@pytest.mark.parametrize("dtype", [torch.bfloat16], ids=str)
@pytest.mark.parametrize("dims", [[256, 512], [250, 500]])
@pytest.mark.parametrize("block_scaling_dim", [1, 2])
def test_inplace_ops(
self, fp8_dtype: tex.DType, dtype: torch.dtype, dims: DimsType, block_scaling_dim: int
) -> None:
"""Test in-place operations"""
device = "cuda"
x_hp = torch.rand(_to_list(dims), dtype=dtype, device=device)
y_hp = torch.rand(_to_list(dims), dtype=dtype, device=device)
quantizer = Float8BlockQuantizer(
fp8_dtype=fp8_dtype,
rowwise=True,
columnwise=True,
block_scaling_dim=block_scaling_dim,
)
# Test in-place add
x_fp8 = quantizer.quantize(x_hp.clone())
y_fp8 = quantizer.quantize(y_hp.clone())
x_fp8.add_(y_fp8)
torch.testing.assert_close(x_fp8.dequantize(), x_hp + y_hp, **_tols[fp8_dtype])
# Test in-place subtract
x_fp8 = quantizer.quantize(x_hp.clone())
y_fp8 = quantizer.quantize(y_hp.clone())
x_fp8.sub_(y_fp8)
torch.testing.assert_close(x_fp8.dequantize(), x_hp - y_hp, **_tols[fp8_dtype])
# Test in-place multiply
x_fp8 = quantizer.quantize(x_hp.clone())
y_fp8 = quantizer.quantize(y_hp.clone())
x_fp8.mul_(y_fp8)
torch.testing.assert_close(x_fp8.dequantize(), x_hp * y_hp, **_tols[fp8_dtype])
@pytest.mark.parametrize("fp8_dtype", [tex.DType.kFloat8E4M3], ids=str)
@pytest.mark.parametrize("dtype", [torch.bfloat16], ids=str)
@pytest.mark.parametrize("dims", [[256, 512], [250, 500]])
@pytest.mark.parametrize("block_scaling_dim", [1, 2])
def test_out_of_place_ops(
self, fp8_dtype: tex.DType, dtype: torch.dtype, dims: DimsType, block_scaling_dim: int
) -> None:
"""Test out-of-place operations"""
device = "cuda"
x_hp = torch.rand(_to_list(dims), dtype=dtype, device=device)
y_hp = torch.rand(_to_list(dims), dtype=dtype, device=device)
quantizer = Float8BlockQuantizer(
fp8_dtype=fp8_dtype,
rowwise=True,
columnwise=True,
block_scaling_dim=block_scaling_dim,
)
x_fp8 = quantizer.quantize(x_hp.clone())
y_fp8 = quantizer.quantize(y_hp.clone())
# Test exact operations
torch.testing.assert_close(-x_fp8, -x_hp, **_tols[fp8_dtype])
torch.testing.assert_close(x_fp8.abs(), x_hp.abs(), **_tols[fp8_dtype])
# Test elementwise operations
torch.testing.assert_close(x_fp8 + y_fp8, x_hp + y_hp, **_tols[fp8_dtype])
torch.testing.assert_close(x_fp8 - y_fp8, x_hp - y_hp, **_tols[fp8_dtype])
torch.testing.assert_close(x_fp8 * y_fp8, x_hp * y_hp, **_tols[fp8_dtype])
torch.testing.assert_close(torch.sin(x_fp8), torch.sin(x_hp), **_tols[fp8_dtype])
# Make sure we are not trivially passing tests
with pytest.raises(AssertionError):
torch.testing.assert_close(x_fp8 + y_fp8, x_hp - y_hp, **_tols[fp8_dtype])
@pytest.mark.parametrize("fp8_dtype", [tex.DType.kFloat8E4M3], ids=str)
@pytest.mark.parametrize("dtype", [torch.bfloat16], ids=str)
@pytest.mark.parametrize("dims", [[256, 512], [250, 500]])
@pytest.mark.parametrize("block_scaling_dim", [1, 2])
def test_view_same_shape(
self, fp8_dtype: tex.DType, dtype: torch.dtype, dims: DimsType, block_scaling_dim: int
) -> None:
"""Test view operations that preserve tensor shape"""
device = "cuda"
x_hp = torch.rand(_to_list(dims), dtype=dtype, device=device)
quantizer = Float8BlockQuantizer(
fp8_dtype=fp8_dtype,
rowwise=True,
columnwise=True,
block_scaling_dim=block_scaling_dim,
)
x_fp8 = quantizer.make_empty(x_hp.shape, dtype=dtype, device=device)
quantizer.update_quantized(x_hp.clone(), x_fp8)
# Test view with same shape
x_view = x_fp8.view(*dims)
torch.testing.assert_close(x_view.dequantize(), x_hp, **_tols[fp8_dtype])
assert x_view.shape == x_fp8.shape, "Shape changed after view with same dims"
# Make sure we are not trivially passing tests
with pytest.raises(AssertionError):
torch.testing.assert_close(x_view.dequantize(), -x_hp, **_tols[fp8_dtype])
@pytest.mark.parametrize("fp8_dtype", [tex.DType.kFloat8E4M3], ids=str)
@pytest.mark.parametrize("dtype", [torch.bfloat16], ids=str)
@pytest.mark.parametrize("dims", [[256, 512], [250, 500]])
@pytest.mark.parametrize("block_scaling_dim", [1, 2])
def test_reshape_same_shape(
self, fp8_dtype: tex.DType, dtype: torch.dtype, dims: DimsType, block_scaling_dim: int
) -> None:
"""Test reshape operations that preserve tensor shape"""
device = "cuda"
x_hp = torch.rand(_to_list(dims), dtype=dtype, device=device)
quantizer = Float8BlockQuantizer(
fp8_dtype=fp8_dtype,
rowwise=True,
columnwise=True,
block_scaling_dim=block_scaling_dim,
)
x_fp8 = quantizer.make_empty(x_hp.shape, dtype=dtype, device=device)
quantizer.update_quantized(x_hp.clone(), x_fp8)
# Test reshape with same shape
x_reshape = x_fp8.reshape(*dims)
torch.testing.assert_close(x_reshape.dequantize(), x_hp, **_tols[fp8_dtype])
assert x_reshape.shape == x_fp8.shape, "Shape changed after reshape with same dims"
# Test reshape with -1 canonicalization
new_dims = [-1, dims[1]]
x_reshape = x_fp8.reshape(*new_dims)
torch.testing.assert_close(x_reshape.dequantize(), x_hp, **_tols[fp8_dtype])
assert x_reshape.shape == x_fp8.shape, "Shape changed after reshape with -1"
# Make sure we are not trivially passing tests
with pytest.raises(AssertionError):
torch.testing.assert_close(x_reshape.dequantize(), -x_hp, **_tols[fp8_dtype])
@pytest.mark.parametrize("fp8_dtype", [tex.DType.kFloat8E4M3], ids=str)
@pytest.mark.parametrize("dtype", [torch.bfloat16], ids=str)
@pytest.mark.parametrize("dims", [[256, 512], [250, 500]])
@pytest.mark.parametrize("block_scaling_dim", [1, 2])
def test_clone_detach(
self, fp8_dtype: tex.DType, dtype: torch.dtype, dims: DimsType, block_scaling_dim: int
) -> None:
"""Test clone and detach operations"""
device = "cuda"
x_hp = torch.rand(_to_list(dims), dtype=dtype, device=device)
quantizer = Float8BlockQuantizer(
fp8_dtype=fp8_dtype,
rowwise=True,
columnwise=True,
block_scaling_dim=block_scaling_dim,
)
x_fp8 = quantizer.quantize(x_hp.clone())
# Test clone
x_clone = x_fp8.clone()
torch.testing.assert_close(x_clone.dequantize(), x_hp, **_tols[fp8_dtype])
assert x_clone.shape == x_fp8.shape, "Shape changed after clone"
# Test detach
x_detach = x_fp8.detach()
torch.testing.assert_close(x_detach.dequantize(), x_hp, **_tols[fp8_dtype])
assert x_detach.shape == x_fp8.shape, "Shape changed after detach"
# Make sure we are not trivially passing tests
with pytest.raises(AssertionError):
torch.testing.assert_close(x_clone.dequantize(), -x_hp, **_tols[fp8_dtype])
...@@ -4,7 +4,7 @@ ...@@ -4,7 +4,7 @@
from collections.abc import Iterable from collections.abc import Iterable
import io import io
from typing import Any, Dict, List, Tuple, Union from typing import Any, Dict, List, Tuple, Union, Optional
import pytest import pytest
import torch import torch
...@@ -158,6 +158,32 @@ class TestFloat8Tensor: ...@@ -158,6 +158,32 @@ class TestFloat8Tensor:
def test_quantize_dequantize_dims(self, dims: DimsType) -> None: def test_quantize_dequantize_dims(self, dims: DimsType) -> None:
self._test_quantize_dequantize(dims=dims) self._test_quantize_dequantize(dims=dims)
@pytest.mark.parametrize("fp8_dtype", _fp8_dtypes)
@pytest.mark.parametrize("dtype", _dtypes)
@pytest.mark.parametrize("noop", [True, False])
def test_quantize_dequantize_noop(
self, fp8_dtype: tex.DType, dtype: torch.dtype, noop: bool
) -> None:
noop_tensor = torch.zeros(1, dtype=torch.float32, device="cuda")
if noop:
noop_tensor = torch.ones(1, dtype=torch.float32, device="cuda")
dims = 23
scale: float = 3.5
# Initialize random data
x_ref = 2 * torch.rand(_to_list(dims), dtype=dtype, device="cpu") - 1
# Cast to FP8 and back
x_fp8 = to_float8(x_ref, fp8_dtype=fp8_dtype, scale=scale)
# If the noop flag is set, quantizing a different input should leave the output unchanged (still x_fp8_orig)
x_ref_noop_test = 2 * x_ref.cuda()
x_fp8_orig = x_fp8.clone()
x_fp8.quantize_(x_ref_noop_test, noop_flag=noop_tensor)
if noop_tensor.item() == 1.0:
torch.testing.assert_close(x_fp8, x_fp8_orig, atol=0, rtol=0)
else:
torch.testing.assert_close(x_fp8, x_ref_noop_test, **_tols[fp8_dtype])
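# Illustrative sketch (hypothetical tensors): a nonzero noop_flag tells quantize_ to
# skip the update, which is exactly what the two branches above assert.
#
#   noop = torch.ones(1, dtype=torch.float32, device="cuda")  # flag set: skip update
#   x_fp8.quantize_(new_input, noop_flag=noop)                # x_fp8 is left unchanged
#   noop.zero_()                                              # flag cleared: perform update
#   x_fp8.quantize_(new_input, noop_flag=noop)                # x_fp8 now encodes new_input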
def test_basic_ops( def test_basic_ops(
self, self,
dims: DimsType = 23, dims: DimsType = 23,
......
...@@ -360,6 +360,20 @@ class TestFusedAdam(TestFusedOptimizer): ...@@ -360,6 +360,20 @@ class TestFusedAdam(TestFusedOptimizer):
master_atol=2e-3, master_atol=2e-3,
) )
@pytest.mark.skipif(not is_bf16_compatible(), reason="bf16 is not supported")
def test_bf16_exp_avg(self):
self.gen_precision_aware_test(
use_fp8_params=False,
param_dtype=torch.bfloat16,
use_master_weights=True,
master_weight_dtype=torch.float32,
grad_dtype=torch.float32,
exp_avg_dtype=torch.bfloat16,
exp_avg_sq_dtype=torch.float32,
master_rtol=2e-3,
master_atol=2e-3,
)
@pytest.mark.skipif(not is_bf16_compatible(), reason="bf16 is not supported") @pytest.mark.skipif(not is_bf16_compatible(), reason="bf16 is not supported")
@pytest.mark.skipif(not fp8_available, reason=reason_for_no_fp8) @pytest.mark.skipif(not fp8_available, reason=reason_for_no_fp8)
def test_fp8_exp_avg(self): def test_fp8_exp_avg(self):
...@@ -389,6 +403,20 @@ class TestFusedAdam(TestFusedOptimizer): ...@@ -389,6 +403,20 @@ class TestFusedAdam(TestFusedOptimizer):
master_atol=2e-3, master_atol=2e-3,
) )
@pytest.mark.skipif(not is_bf16_compatible(), reason="bf16 is not supported")
def test_bf16_exp_avg_sq(self):
self.gen_precision_aware_test(
use_fp8_params=False,
param_dtype=torch.bfloat16,
use_master_weights=True,
master_weight_dtype=torch.float32,
grad_dtype=torch.float32,
exp_avg_dtype=torch.float32,
exp_avg_sq_dtype=torch.bfloat16,
master_rtol=2e-3,
master_atol=2e-3,
)
@pytest.mark.skipif(not is_bf16_compatible(), reason="bf16 is not supported") @pytest.mark.skipif(not is_bf16_compatible(), reason="bf16 is not supported")
@pytest.mark.skipif(not fp8_available, reason=reason_for_no_fp8) @pytest.mark.skipif(not fp8_available, reason=reason_for_no_fp8)
def test_fp8_exp_avg_sq(self): def test_fp8_exp_avg_sq(self):
......
...@@ -11,52 +11,6 @@ from transformer_engine.pytorch.dot_product_attention.rope import ( ...@@ -11,52 +11,6 @@ from transformer_engine.pytorch.dot_product_attention.rope import (
) )
def _get_thd_freqs_on_this_cp_rank(
cp_rank: int, cp_size: int, x: torch.Tensor, freqs: torch.Tensor
) -> torch.Tensor:
if cp_size > 1:
cp_seg = x.size(0) // 2
full_seqlen = cp_size * x.size(0)
return torch.cat(
[
freqs[cp_rank * cp_seg : (cp_rank + 1) * cp_seg],
freqs[full_seqlen - (cp_rank + 1) * cp_seg : full_seqlen - cp_rank * cp_seg],
]
)
else:
return freqs[: x.size(0)]
def apply_rotary_pos_emb_thd(
t: torch.Tensor,
cu_seqlens: torch.Tensor,
freqs: torch.Tensor,
cp_size: int = 1,
cp_rank: int = 0,
) -> torch.Tensor:
"""A baseline implementation of applying RoPE for `thd` format.
Args:
t (Tensor): Input tensor T is of shape [t, h, d]
cu_seqlens(Tensor): Cumulative sum of sequence lengths in a batch for `t`,
with shape [b + 1] and dtype torch.int32.
freqs (Tensor): Rotary Positional embedding tensor freq is of shape [max_s, 1, 1, d]
Returns:
Tensor: Shape [t, h, d]. The input tensor after applying RoPE.
"""
cu_seqlens = cu_seqlens // cp_size
seqlens = (cu_seqlens[1:] - cu_seqlens[:-1]).tolist()
return torch.cat(
[
apply_rotary_pos_emb(
x.unsqueeze(1), _get_thd_freqs_on_this_cp_rank(cp_rank, cp_size, x, freqs)
)
for x in torch.split(t, seqlens)
]
).squeeze(1)
# Gradient is a broadcasted scalar # Gradient is a broadcasted scalar
def _overlapping_grad(output: torch.Tensor) -> torch.Tensor: def _overlapping_grad(output: torch.Tensor) -> torch.Tensor:
return output.sum() * 2 return output.sum() * 2
...@@ -76,6 +30,8 @@ def _non_overlapping_grad(output: torch.Tensor) -> torch.Tensor: ...@@ -76,6 +30,8 @@ def _non_overlapping_grad(output: torch.Tensor) -> torch.Tensor:
@pytest.mark.parametrize("transpose", [None, (0, 1), (2, 3)]) @pytest.mark.parametrize("transpose", [None, (0, 1), (2, 3)])
@pytest.mark.parametrize("tensor_format", ["sbhd", "bshd"]) @pytest.mark.parametrize("tensor_format", ["sbhd", "bshd"])
@pytest.mark.parametrize("loss_func", [_overlapping_grad, _non_overlapping_grad]) @pytest.mark.parametrize("loss_func", [_overlapping_grad, _non_overlapping_grad])
@pytest.mark.parametrize("cp_size", [1, 2])
@pytest.mark.parametrize("interleaved", [True, False])
def test_fused_rope( def test_fused_rope(
dtype: torch.dtype, dtype: torch.dtype,
seq_length: int, seq_length: int,
...@@ -85,6 +41,8 @@ def test_fused_rope( ...@@ -85,6 +41,8 @@ def test_fused_rope(
transpose: Union[Tuple, None], transpose: Union[Tuple, None],
tensor_format: str, tensor_format: str,
loss_func: Callable, loss_func: Callable,
cp_size: int,
interleaved: bool,
) -> None: ) -> None:
device = torch.device("cuda:0") device = torch.device("cuda:0")
batch_size, head_num = 2, 64 batch_size, head_num = 2, 64
...@@ -99,14 +57,22 @@ def test_fused_rope( ...@@ -99,14 +57,22 @@ def test_fused_rope(
t = t.transpose(*transpose).contiguous().transpose(*transpose) t = t.transpose(*transpose).contiguous().transpose(*transpose)
t.requires_grad = True t.requires_grad = True
rotary_pos_emb = RotaryPositionEmbedding(hidden_size, rotary_percent) rotary_pos_emb = RotaryPositionEmbedding(hidden_size, rotary_percent, interleaved=interleaved)
emb = rotary_pos_emb(seq_length) emb = rotary_pos_emb(seq_length * cp_size)
assert emb.is_contiguous()
for cp_rank in range(cp_size):
# unfused # unfused
# The fused kernel computes in float32 internally, so we force the unfused func to use float32 # The fused kernel computes in float32 internally, so we force the unfused func to use float32
# for more accurate comparison # for more accurate comparison
output_unfused = apply_rotary_pos_emb( output_unfused = apply_rotary_pos_emb(
t.float(), emb, tensor_format=tensor_format, fused=False t.float(),
emb,
tensor_format=tensor_format,
interleaved=interleaved,
fused=False,
cp_size=cp_size,
cp_rank=cp_rank,
).to(dtype) ).to(dtype)
loss_unfused = loss_func(output_unfused) loss_unfused = loss_func(output_unfused)
loss_unfused.backward() loss_unfused.backward()
...@@ -118,7 +84,10 @@ def test_fused_rope( ...@@ -118,7 +84,10 @@ def test_fused_rope(
t, t,
emb, emb,
tensor_format=tensor_format, tensor_format=tensor_format,
interleaved=interleaved,
fused=True, fused=True,
cp_size=cp_size,
cp_rank=cp_rank,
) )
loss_fused = loss_func(output_fused) loss_fused = loss_func(output_fused)
loss_fused.backward() loss_fused.backward()
...@@ -135,7 +104,8 @@ def test_fused_rope( ...@@ -135,7 +104,8 @@ def test_fused_rope(
@pytest.mark.parametrize("rotary_percent", [0.5, 1.0]) @pytest.mark.parametrize("rotary_percent", [0.5, 1.0])
@pytest.mark.parametrize("transpose", [None, (1, 2)]) @pytest.mark.parametrize("transpose", [None, (1, 2)])
@pytest.mark.parametrize("loss_func", [_overlapping_grad, _non_overlapping_grad]) @pytest.mark.parametrize("loss_func", [_overlapping_grad, _non_overlapping_grad])
@pytest.mark.parametrize("cp_size", [1, 2, 3]) @pytest.mark.parametrize("cp_size", [1, 2])
@pytest.mark.parametrize("interleaved", [True, False])
def test_fused_rope_thd( def test_fused_rope_thd(
dtype: torch.dtype, dtype: torch.dtype,
hidden_size: int, hidden_size: int,
...@@ -143,6 +113,7 @@ def test_fused_rope_thd( ...@@ -143,6 +113,7 @@ def test_fused_rope_thd(
transpose: Union[Tuple, None], transpose: Union[Tuple, None],
loss_func: Callable, loss_func: Callable,
cp_size: int, cp_size: int,
interleaved: bool,
) -> None: ) -> None:
device = torch.device("cuda:0") device = torch.device("cuda:0")
batch_size, head_num = 2, 64 batch_size, head_num = 2, 64
...@@ -170,15 +141,23 @@ def test_fused_rope_thd( ...@@ -170,15 +141,23 @@ def test_fused_rope_thd(
t = t.transpose(*transpose).contiguous().transpose(*transpose) t = t.transpose(*transpose).contiguous().transpose(*transpose)
t.requires_grad = True t.requires_grad = True
rotary_pos_emb = RotaryPositionEmbedding(hidden_size, rotary_percent) rotary_pos_emb = RotaryPositionEmbedding(hidden_size, rotary_percent, interleaved=interleaved)
emb = rotary_pos_emb(cu_seqlens_padded[-1]) emb = rotary_pos_emb(cu_seqlens_padded[-1])
assert emb.is_contiguous()
for cp_rank in range(cp_size): for cp_rank in range(cp_size):
# unfused # unfused
# The fused kernel computes in float32 internally, so we force the unfused func to use float32 # The fused kernel computes in float32 internally, so we force the unfused func to use float32
# for more accurate comparison # for more accurate comparison
output_unfused = apply_rotary_pos_emb_thd( output_unfused = apply_rotary_pos_emb(
t.float(), cu_seqlens_padded, emb, cp_size, cp_rank t.float(),
emb,
tensor_format="thd",
interleaved=interleaved,
fused=False,
cu_seqlens=cu_seqlens_padded,
cp_size=cp_size,
cp_rank=cp_rank,
).to(dtype) ).to(dtype)
loss_unfused = loss_func(output_unfused) loss_unfused = loss_func(output_unfused)
loss_unfused.backward() loss_unfused.backward()
...@@ -189,6 +168,7 @@ def test_fused_rope_thd( ...@@ -189,6 +168,7 @@ def test_fused_rope_thd(
output_fused = apply_rotary_pos_emb( output_fused = apply_rotary_pos_emb(
t, t,
emb, emb,
interleaved=interleaved,
fused=True, fused=True,
tensor_format="thd", tensor_format="thd",
cu_seqlens=cu_seqlens_padded, cu_seqlens=cu_seqlens_padded,
......
...@@ -5,6 +5,7 @@ ...@@ -5,6 +5,7 @@
from __future__ import annotations from __future__ import annotations
from collections.abc import Iterable from collections.abc import Iterable
import io
import math import math
from typing import Optional from typing import Optional
...@@ -1405,6 +1406,7 @@ class TestBasicOps: ...@@ -1405,6 +1406,7 @@ class TestBasicOps:
@pytest.mark.parametrize("out_shape", ((37,), (2, 13), (32, 1, 32))) @pytest.mark.parametrize("out_shape", ((37,), (2, 13), (32, 1, 32)))
@pytest.mark.parametrize("dtype", _dtypes) @pytest.mark.parametrize("dtype", _dtypes)
@pytest.mark.parametrize("quantization", (None, "fp8", "mxfp8")) @pytest.mark.parametrize("quantization", (None, "fp8", "mxfp8"))
@pytest.mark.parametrize("cache_quantized_input", (False, True))
def test_activation( def test_activation(
self, self,
*, *,
...@@ -1413,6 +1415,7 @@ class TestBasicOps: ...@@ -1413,6 +1415,7 @@ class TestBasicOps:
dtype: torch.dtype, dtype: torch.dtype,
device: torch.device = "cuda", device: torch.device = "cuda",
quantization: Optional[str], quantization: Optional[str],
cache_quantized_input: bool,
) -> None: ) -> None:
"""Activation functions""" """Activation functions"""
...@@ -1424,6 +1427,8 @@ class TestBasicOps: ...@@ -1424,6 +1427,8 @@ class TestBasicOps:
# Skip invalid configurations # Skip invalid configurations
quantized_compute = quantization is not None quantized_compute = quantization is not None
maybe_skip_quantization(quantization, dims=in_shape, device=device) maybe_skip_quantization(quantization, dims=in_shape, device=device)
if cache_quantized_input:
maybe_skip_quantization("fp8", device=device)
# Random data # Random data
x_ref, x_test = make_reference_and_test_tensors( x_ref, x_test = make_reference_and_test_tensors(
...@@ -1432,15 +1437,17 @@ class TestBasicOps: ...@@ -1432,15 +1437,17 @@ class TestBasicOps:
test_device=device, test_device=device,
test_is_fp8=quantized_compute, test_is_fp8=quantized_compute,
) )
if quantized_compute:
with torch.no_grad():
x_test = x_test.dequantize().requires_grad_()
dy_ref, dy_test = make_reference_and_test_tensors( dy_ref, dy_test = make_reference_and_test_tensors(
out_shape, out_shape,
test_dtype=dtype, test_dtype=dtype,
test_device=device, test_device=device,
test_is_fp8=quantized_compute,
requires_grad=False, requires_grad=False,
) )
if quantized_compute:
with torch.no_grad():
x_test = x_test.dequantize().requires_grad_()
dy_test = dy_test.dequantize()
# Plain PyTorch implementation # Plain PyTorch implementation
y_ref: torch.Tensor y_ref: torch.Tensor
...@@ -1471,7 +1478,8 @@ class TestBasicOps: ...@@ -1471,7 +1478,8 @@ class TestBasicOps:
swiglu=te_ops.SwiGLU, swiglu=te_ops.SwiGLU,
)[activation] )[activation]
forward = te_ops.Sequential( forward = te_ops.Sequential(
make_op(), te_ops.Quantize(forward=False, backward=quantized_compute),
make_op(cache_quantized_input=cache_quantized_input),
te_ops.Quantize(forward=quantized_compute, backward=False), te_ops.Quantize(forward=quantized_compute, backward=False),
) )
with te.fp8_autocast(enabled=quantized_compute, fp8_recipe=recipe): with te.fp8_autocast(enabled=quantized_compute, fp8_recipe=recipe):
...@@ -1480,9 +1488,9 @@ class TestBasicOps: ...@@ -1480,9 +1488,9 @@ class TestBasicOps:
# Expected numerical error # Expected numerical error
tols = dtype_tols(dtype) tols = dtype_tols(dtype)
if quantized_compute: if quantized_compute or cache_quantized_input:
tols = dtype_tols(tex.DType.kFloat8E4M3) tols = dtype_tols(tex.DType.kFloat8E4M3)
if activation == "relu": if activation == "relu" and not cache_quantized_input:
tols = {"atol": 0, "rtol": 0} tols = {"atol": 0, "rtol": 0}
# Check results # Check results
...@@ -1894,3 +1902,118 @@ class TestFusedOps: ...@@ -1894,3 +1902,118 @@ class TestFusedOps:
torch.testing.assert_close(y2_test, y2_ref, **tols) torch.testing.assert_close(y2_test, y2_ref, **tols)
torch.testing.assert_close(dx_test, x_ref.grad, **tols) torch.testing.assert_close(dx_test, x_ref.grad, **tols)
torch.testing.assert_close(dw_test, w_ref.grad, **tols) torch.testing.assert_close(dw_test, w_ref.grad, **tols)
class TestCheckpointing:
"""Tests for checkpointing"""
@staticmethod
def setup_class(cls) -> None:
# Configure RNG
seed = 1234
torch.manual_seed(seed)
torch.cuda.manual_seed(seed)
@pytest.mark.parametrize("quantization", (None, "fp8", "mxfp8"))
@pytest.mark.parametrize("quantized_weight", (False, True))
def test_linear(
self,
*,
pre_checkpoint_steps: int = 2,
post_checkpoint_steps: int = 2,
weight_shape: tuple[int, int] = (32, 32),
in_shape: Iterable[int] = (32, -1),
dtype: torch.dtype = torch.float32,
device: torch.device = "cuda",
quantization: Optional[str],
quantized_weight: bool,
) -> None:
"""Check checkpointing with linear op"""
# Make input and weight shapes consistent
out_features, in_features = weight_shape
in_shape = list(in_shape)[:-1] + [in_features]
out_shape = in_shape[:-1] + [out_features]
# Skip invalid configurations
quantized_compute = quantization is not None
maybe_skip_quantization(quantization, dims=in_shape, device=device)
maybe_skip_quantization(quantization, dims=out_shape)
# Construct model
recipe = make_recipe(quantization)
with te.fp8_model_init(enabled=quantized_weight, recipe=recipe):
model_save = te_ops.Sequential(
te_ops.Linear(in_features, out_features, device=device, dtype=dtype)
)
optim_save = torch.optim.SGD(model_save.parameters(), lr=0.25)
# Warmup training steps
for _ in range(pre_checkpoint_steps):
x = torch.randn(in_shape, dtype=dtype, device=device, requires_grad=True)
dy = torch.randn(out_shape, dtype=dtype, device=device)
optim_save.zero_grad()
with te.fp8_autocast(enabled=quantized_compute, fp8_recipe=recipe):
y = model_save(x)
y.backward(dy)
optim_save.step()
# Save checkpoint
byte_stream = io.BytesIO()
torch.save(
{"model": model_save.state_dict(), "optim": optim_save.state_dict()},
byte_stream,
)
checkpoint_bytes = byte_stream.getvalue()
del byte_stream
# Synthetic data for evaluation
xs_save = [
torch.randn(in_shape, dtype=dtype, device=device, requires_grad=True)
for _ in range(post_checkpoint_steps)
]
with torch.no_grad():
xs_load = [x.clone().requires_grad_() for x in xs_save]
dys = [
torch.randn(out_shape, dtype=dtype, device=device) for _ in range(post_checkpoint_steps)
]
# Training steps with original model
ys_save = []
for i in range(post_checkpoint_steps):
optim_save.zero_grad()
with te.fp8_autocast(enabled=quantized_compute, fp8_recipe=recipe):
y = model_save(xs_save[i])
y.backward(dys[i])
optim_save.step()
ys_save.append(y)
# Load checkpoint
with te.fp8_model_init(enabled=quantized_weight, recipe=recipe):
model_load = te_ops.Sequential(
te_ops.Linear(in_features, out_features, device=device, dtype=dtype)
)
optim_load = torch.optim.SGD(model_load.parameters(), lr=0.25)
state_dict = torch.load(io.BytesIO(checkpoint_bytes), weights_only=False)
model_load.load_state_dict(state_dict["model"])
optim_load.load_state_dict(state_dict["optim"])
# Training steps with loaded model
ys_load = []
for i in range(post_checkpoint_steps):
optim_load.zero_grad()
with te.fp8_autocast(enabled=quantized_compute, fp8_recipe=recipe):
y = model_load(xs_load[i])
y.backward(dys[i])
optim_load.step()
ys_load.append(y)
# Check that original and loaded model match exactly
tols = {"rtol": 0, "atol": 0}
for param_load, param_save in zip(model_load.parameters(), model_save.parameters()):
torch.testing.assert_close(param_load, param_save, **tols)
torch.testing.assert_close(param_load.grad, param_save.grad, **tols)
for y_load, y_save in zip(ys_load, ys_save):
torch.testing.assert_close(y_load, y_save, **tols)
for x_load, x_save in zip(xs_load, xs_save):
torch.testing.assert_close(x_load.grad, x_save.grad, **tols)
...@@ -9,9 +9,10 @@ import transformer_engine.pytorch as te ...@@ -9,9 +9,10 @@ import transformer_engine.pytorch as te
import transformer_engine_torch as tex import transformer_engine_torch as tex
from transformer_engine.pytorch.optimizers import MultiTensorApply from transformer_engine.pytorch.optimizers import MultiTensorApply
from references.ref_per_tensor_cs import ref_compute_scale_and_scale_inv_from_amax from references.quantize_scale_calc import scale_from_amax_tensor
from torch.utils.cpp_extension import IS_HIP_EXTENSION from torch.utils.cpp_extension import IS_HIP_EXTENSION
input_size_pairs = [ input_size_pairs = [
(7777 * 77, 555 * 555), (7777 * 77, 555 * 555),
(777, 555), (777, 555),
...@@ -224,17 +225,18 @@ def test_multi_tensor_unscale_l2norm(input_size_pair, applier, repeat, in_type, ...@@ -224,17 +225,18 @@ def test_multi_tensor_unscale_l2norm(input_size_pair, applier, repeat, in_type,
@pytest.mark.parametrize("input_size_pair", input_size_pairs + [(1, 1)]) @pytest.mark.parametrize("input_size_pair", input_size_pairs + [(1, 1)])
@pytest.mark.parametrize("applier", appliers) @pytest.mark.parametrize("applier", appliers)
@pytest.mark.parametrize("repeat", [1, 55]) @pytest.mark.parametrize("repeat", [1, 55])
@pytest.mark.parametrize("max_fp8", [448.0 if not IS_HIP_EXTENSION else 240.0, 57344.0]) @pytest.mark.parametrize("fp8_dtype", [torch.float8_e4m3fn, torch.float8_e5m2])
@pytest.mark.parametrize("pow_2_scales", [False, True]) @pytest.mark.parametrize("pow_2_scales", [False, True])
@pytest.mark.parametrize("epsilon", [0.0, 100.0]) @pytest.mark.parametrize("epsilon", [0.0, 100.0])
def test_multi_tensor_compute_scale_and_scale_inv( def test_multi_tensor_compute_scale_and_scale_inv(
input_size_pair, applier, repeat, max_fp8, pow_2_scales, epsilon input_size_pair, applier, repeat, fp8_dtype, pow_2_scales, epsilon
): ):
sizea, sizeb = input_size_pair sizea, sizeb = input_size_pair
device = torch.device("cuda") device = torch.device("cuda")
overflow_buf = torch.zeros(1, dtype=torch.int32, device=device) overflow_buf = torch.zeros(1, dtype=torch.int32, device=device)
a = torch.randn([sizea], dtype=torch.float32, device=device).abs() a = torch.randn([sizea], dtype=torch.float32, device=device).abs()
b = torch.randn([sizeb], dtype=torch.float32, device=device).abs() b = torch.randn([sizeb], dtype=torch.float32, device=device).abs()
max_fp8 = torch.finfo(fp8_dtype).max
amax_list = [] amax_list = []
for i in range(repeat): for i in range(repeat):
...@@ -253,8 +255,8 @@ def test_multi_tensor_compute_scale_and_scale_inv( ...@@ -253,8 +255,8 @@ def test_multi_tensor_compute_scale_and_scale_inv(
) )
for amax, scale, scale_inv in zip(amax_list, scale_list, scale_inv_list): for amax, scale, scale_inv in zip(amax_list, scale_list, scale_inv_list):
scale_ref, scale_inv_ref = ref_compute_scale_and_scale_inv_from_amax( scale_ref, scale_inv_ref, _ = scale_from_amax_tensor(
amax, max_fp8, epsilon, pow_2_scales torch.float32, amax, fp8_dtype, eps=epsilon, pow_2_scales=pow_2_scales
) )
torch.testing.assert_close(scale, scale_ref, rtol=0, atol=0) torch.testing.assert_close(scale, scale_ref, rtol=0, atol=0)
torch.testing.assert_close(scale_inv, scale_inv_ref, rtol=0, atol=0) torch.testing.assert_close(scale_inv, scale_inv_ref, rtol=0, atol=0)
...@@ -52,6 +52,9 @@ import transformer_engine_torch as tex ...@@ -52,6 +52,9 @@ import transformer_engine_torch as tex
# Only run FP8 tests on supported devices. # Only run FP8 tests on supported devices.
fp8_available, reason_for_no_fp8 = FP8GlobalStateManager.is_fp8_available() fp8_available, reason_for_no_fp8 = FP8GlobalStateManager.is_fp8_available()
mxfp8_available, reason_for_no_mxfp8 = FP8GlobalStateManager.is_mxfp8_available() mxfp8_available, reason_for_no_mxfp8 = FP8GlobalStateManager.is_mxfp8_available()
fp8_block_scaling_available, reason_for_no_fp8_block_scaling = (
FP8GlobalStateManager.is_fp8_block_scaling_available()
)
sm_80plus = get_device_compute_capability() >= (8, 0) sm_80plus = get_device_compute_capability() >= (8, 0)
...@@ -108,6 +111,7 @@ fp8_recipes = [ ...@@ -108,6 +111,7 @@ fp8_recipes = [
recipe.MXFP8BlockScaling(), recipe.MXFP8BlockScaling(),
recipe.DelayedScaling(), recipe.DelayedScaling(),
recipe.Float8CurrentScaling(), recipe.Float8CurrentScaling(),
recipe.Float8BlockScaling(),
] ]
...@@ -567,6 +571,8 @@ def test_gpt_selective_activation_recompute(dtype, bs, model, fp8, recipe, fp8_m ...@@ -567,6 +571,8 @@ def test_gpt_selective_activation_recompute(dtype, bs, model, fp8, recipe, fp8_m
pytest.skip(reason_for_no_fp8) pytest.skip(reason_for_no_fp8)
if recipe.mxfp8() and not mxfp8_available: if recipe.mxfp8() and not mxfp8_available:
pytest.skip(reason_for_no_mxfp8) pytest.skip(reason_for_no_mxfp8)
if recipe.float8_block_scaling() and not fp8_block_scaling_available:
pytest.skip(reason_for_no_fp8_block_scaling)
config = model_configs[model] config = model_configs[model]
...@@ -679,6 +685,8 @@ def test_gpt_full_activation_recompute( ...@@ -679,6 +685,8 @@ def test_gpt_full_activation_recompute(
pytest.skip(reason_for_no_fp8) pytest.skip(reason_for_no_fp8)
if recipe.mxfp8() and not mxfp8_available: if recipe.mxfp8() and not mxfp8_available:
pytest.skip(reason_for_no_mxfp8) pytest.skip(reason_for_no_mxfp8)
if recipe.float8_block_scaling() and not fp8_block_scaling_available:
pytest.skip(reason_for_no_fp8_block_scaling)
config = model_configs[model] config = model_configs[model]
...@@ -1032,7 +1040,7 @@ def test_mha_accuracy(dtype, bs, model, mask_type): ...@@ -1032,7 +1040,7 @@ def test_mha_accuracy(dtype, bs, model, mask_type):
assert_allclose(te_output, torch_output, atol[dtype], rtol[dtype]) assert_allclose(te_output, torch_output, atol[dtype], rtol[dtype])
def _test_granular_accuracy(block, bs, dtype, config): def _test_granular_accuracy(block, bs, dtype, config, delay_wgrad_compute=False):
reset_rng_states() reset_rng_states()
inp_hidden_states = torch.randn( inp_hidden_states = torch.randn(
...@@ -1048,11 +1056,17 @@ def _test_granular_accuracy(block, bs, dtype, config): ...@@ -1048,11 +1056,17 @@ def _test_granular_accuracy(block, bs, dtype, config):
out = out[0] out = out[0]
loss = out.sum() loss = out.sum()
loss.backward() loss.backward()
if delay_wgrad_compute:
block.backward_dw()
torch.cuda.synchronize() torch.cuda.synchronize()
outputs = [out, inp_hidden_states.grad] outputs = [out, inp_hidden_states.grad]
for p in block.parameters(): for p in block.parameters():
if p.requires_grad: if p.requires_grad:
if getattr(p, "main_grad", None) is not None:
outputs.append(p.main_grad)
assert p.grad is None # grad should be None if fuse_wgrad_accumulation is True
else:
outputs.append(p.grad) outputs.append(p.grad)
return outputs return outputs
...@@ -1187,6 +1201,54 @@ def test_linear_accuracy(dtype, bs, model, return_bias, bias): ...@@ -1187,6 +1201,54 @@ def test_linear_accuracy(dtype, bs, model, return_bias, bias):
assert_allclose(te_output, torch_output, tolerance, rtol[dtype]) assert_allclose(te_output, torch_output, tolerance, rtol[dtype])
@pytest.mark.parametrize("dtype", param_types)
@pytest.mark.parametrize("bs", batch_sizes)
@pytest.mark.parametrize("model", ["small"])
@pytest.mark.parametrize("bias", all_boolean)
@pytest.mark.parametrize("fuse_wgrad_accumulation", all_boolean)
def test_linear_accuracy_delay_wgrad_compute(dtype, bs, model, bias, fuse_wgrad_accumulation):
config = model_configs[model]
te_linear_ref = Linear(
config.hidden_size,
4 * config.hidden_size,
bias=bias,
params_dtype=dtype,
device="cuda",
delay_wgrad_compute=False,
fuse_wgrad_accumulation=fuse_wgrad_accumulation,
).eval()
te_linear = Linear(
config.hidden_size,
4 * config.hidden_size,
bias=bias,
params_dtype=dtype,
device="cuda",
delay_wgrad_compute=True,
fuse_wgrad_accumulation=fuse_wgrad_accumulation,
).eval()
# Share params
with torch.no_grad():
te_linear_ref.weight = Parameter(te_linear.weight.clone())
if bias:
te_linear_ref.bias = Parameter(te_linear.bias.clone())
if fuse_wgrad_accumulation:
weight = te_linear.weight
weight.main_grad = torch.rand_like(weight, dtype=torch.float32)
te_linear_ref.weight.main_grad = weight.main_grad.clone()
te_outputs = _test_granular_accuracy(te_linear, bs, dtype, config, delay_wgrad_compute=True)
te_outputs_ref = _test_granular_accuracy(
te_linear_ref, bs, dtype, config, delay_wgrad_compute=False
)
# Should be bit-wise match
for i, (o, o_ref) in enumerate(zip(te_outputs, te_outputs_ref)):
torch.testing.assert_close(o, o_ref, rtol=0, atol=0)
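# Note: with delay_wgrad_compute=True the weight gradient is not produced during
# loss.backward(); _test_granular_accuracy calls backward_dw() afterwards to
# materialize it, and the delayed path must match the eager path bit-for-bit.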
@pytest.mark.parametrize("dtype", param_types) @pytest.mark.parametrize("dtype", param_types)
@pytest.mark.parametrize("bs", batch_sizes) @pytest.mark.parametrize("bs", batch_sizes)
@pytest.mark.parametrize("model", ["126m"]) @pytest.mark.parametrize("model", ["126m"])
...@@ -1372,6 +1434,67 @@ def test_layernorm_linear_accuracy( ...@@ -1372,6 +1434,67 @@ def test_layernorm_linear_accuracy(
assert_allclose(te_output, torch_output, atol[dtype], rtol[dtype]) assert_allclose(te_output, torch_output, atol[dtype], rtol[dtype])
@pytest.mark.parametrize("dtype", param_types)
@pytest.mark.parametrize("bs", batch_sizes)
@pytest.mark.parametrize("model", ["small"])
@pytest.mark.parametrize("normalization", all_normalizations)
@pytest.mark.parametrize("zero_centered_gamma", all_boolean)
@pytest.mark.parametrize("bias", all_boolean)
@pytest.mark.parametrize("fuse_wgrad_accumulation", all_boolean)
def test_layernorm_linear_accuracy_delay_wgrad_compute(
dtype, bs, model, normalization, zero_centered_gamma, bias, fuse_wgrad_accumulation
):
config = model_configs[model]
ln_linear_ref = LayerNormLinear(
config.hidden_size,
4 * config.hidden_size,
config.eps,
bias=bias,
normalization=normalization,
params_dtype=dtype,
zero_centered_gamma=zero_centered_gamma,
device="cuda",
delay_wgrad_compute=False,
fuse_wgrad_accumulation=fuse_wgrad_accumulation,
).eval()
ln_linear = LayerNormLinear(
config.hidden_size,
4 * config.hidden_size,
config.eps,
bias=bias,
normalization=normalization,
params_dtype=dtype,
zero_centered_gamma=zero_centered_gamma,
device="cuda",
delay_wgrad_compute=True,
fuse_wgrad_accumulation=fuse_wgrad_accumulation,
).eval()
# Share params
with torch.no_grad():
ln_linear_ref.layer_norm_weight = Parameter(ln_linear.layer_norm_weight.clone())
if normalization != "RMSNorm":
ln_linear_ref.layer_norm_bias = Parameter(ln_linear.layer_norm_bias.clone())
ln_linear_ref.weight = Parameter(ln_linear.weight.clone())
if bias:
ln_linear_ref.bias = Parameter(ln_linear.bias.clone())
if fuse_wgrad_accumulation:
weight = ln_linear.weight
weight.main_grad = torch.rand_like(weight, dtype=torch.float32)
ln_linear_ref.weight.main_grad = weight.main_grad.clone()
te_outputs = _test_granular_accuracy(ln_linear, bs, dtype, config, delay_wgrad_compute=True)
te_outputs_ref = _test_granular_accuracy(
ln_linear_ref, bs, dtype, config, delay_wgrad_compute=False
)
# Should be bit-wise match
for i, (o, o_ref) in enumerate(zip(te_outputs, te_outputs_ref)):
torch.testing.assert_close(o, o_ref, rtol=0, atol=0)
@pytest.mark.parametrize("dtype", param_types) @pytest.mark.parametrize("dtype", param_types)
@pytest.mark.parametrize("bs", batch_sizes) @pytest.mark.parametrize("bs", batch_sizes)
@pytest.mark.parametrize("model", ["small"]) @pytest.mark.parametrize("model", ["small"])
...@@ -1448,8 +1571,78 @@ def test_layernorm_mlp_accuracy(dtype, bs, model, activation, normalization, ret ...@@ -1448,8 +1571,78 @@ def test_layernorm_mlp_accuracy(dtype, bs, model, activation, normalization, ret
assert_allclose(te_output, torch_output, atol[dtype], rtol[dtype]) assert_allclose(te_output, torch_output, atol[dtype], rtol[dtype])
@pytest.mark.parametrize("dtype", param_types)
@pytest.mark.parametrize("bs", batch_sizes)
@pytest.mark.parametrize("model", ["small"])
@pytest.mark.parametrize("activation", all_activations)
@pytest.mark.parametrize("normalization", all_normalizations)
@pytest.mark.parametrize("bias", all_boolean)
@pytest.mark.parametrize("fuse_wgrad_accumulation", all_boolean)
def test_layernorm_mlp_accuracy_delay_wgrad_compute(
dtype, bs, model, activation, normalization, bias, fuse_wgrad_accumulation
):
config = model_configs[model]
ln_mlp = LayerNormMLP(
hidden_size=config.hidden_size,
ffn_hidden_size=4 * config.hidden_size,
eps=config.eps,
bias=bias,
normalization=normalization,
params_dtype=dtype,
device="cuda",
delay_wgrad_compute=True,
fuse_wgrad_accumulation=fuse_wgrad_accumulation,
).eval()
ln_mlp_ref = LayerNormMLP(
hidden_size=config.hidden_size,
ffn_hidden_size=4 * config.hidden_size,
eps=config.eps,
bias=bias,
normalization=normalization,
params_dtype=dtype,
device="cuda",
delay_wgrad_compute=False,
fuse_wgrad_accumulation=fuse_wgrad_accumulation,
).eval()
# Share params
with torch.no_grad():
ln_mlp_ref.layer_norm_weight = Parameter(ln_mlp.layer_norm_weight.clone())
if normalization != "RMSNorm":
ln_mlp_ref.layer_norm_bias = Parameter(ln_mlp.layer_norm_bias.clone())
ln_mlp_ref.fc1_weight = Parameter(ln_mlp.fc1_weight.clone())
ln_mlp_ref.fc2_weight = Parameter(ln_mlp.fc2_weight.clone())
if bias:
ln_mlp_ref.fc1_bias = Parameter(ln_mlp.fc1_bias.clone())
ln_mlp_ref.fc2_bias = Parameter(ln_mlp.fc2_bias.clone())
if fuse_wgrad_accumulation:
ln_mlp.fc1_weight.main_grad = torch.rand_like(ln_mlp.fc1_weight, dtype=torch.float32)
ln_mlp_ref.fc1_weight.main_grad = ln_mlp.fc1_weight.main_grad.clone()
ln_mlp.fc2_weight.main_grad = torch.rand_like(ln_mlp.fc2_weight, dtype=torch.float32)
ln_mlp_ref.fc2_weight.main_grad = ln_mlp.fc2_weight.main_grad.clone()
te_outputs = _test_granular_accuracy(ln_mlp, bs, dtype, config, delay_wgrad_compute=True)
te_outputs_ref = _test_granular_accuracy(
ln_mlp_ref, bs, dtype, config, delay_wgrad_compute=False
)
# Should be bit-wise match
for i, (o, o_ref) in enumerate(zip(te_outputs, te_outputs_ref)):
torch.testing.assert_close(o, o_ref, rtol=0, atol=0)
def _test_grouped_linear_accuracy( def _test_grouped_linear_accuracy(
block, num_gemms, bs, dtype, config, recipe, fp8, fuse_wgrad_accumulation block,
num_gemms,
bs,
dtype,
config,
recipe,
fp8,
fuse_wgrad_accumulation,
delay_wgrad_compute=False,
): ):
reset_rng_states() reset_rng_states()
if fp8: if fp8:
...@@ -1466,7 +1659,6 @@ def _test_grouped_linear_accuracy( ...@@ -1466,7 +1659,6 @@ def _test_grouped_linear_accuracy(
if num_gemms > 1: if num_gemms > 1:
split_size = 1 split_size = 1
if fp8: if fp8:
if recipe.delayed():
split_size = 16 split_size = 16
if recipe.mxfp8(): if recipe.mxfp8():
split_size = 128 split_size = 128
...@@ -1492,6 +1684,12 @@ def _test_grouped_linear_accuracy( ...@@ -1492,6 +1684,12 @@ def _test_grouped_linear_accuracy(
) )
loss = out.sum() loss = out.sum()
loss.backward() loss.backward()
if delay_wgrad_compute:
if isinstance(block, GroupedLinear):
block.backward_dw()
else:
for i in range(num_gemms):
block[i].backward_dw()
torch.cuda.synchronize() torch.cuda.synchronize()
outputs = [out, inp_hidden_states.grad] outputs = [out, inp_hidden_states.grad]
...@@ -1505,33 +1703,34 @@ def _test_grouped_linear_accuracy( ...@@ -1505,33 +1703,34 @@ def _test_grouped_linear_accuracy(
return outputs return outputs
@pytest.mark.parametrize("dtype", param_types) @pytest.mark.parametrize("dtype", param_types, ids=str)
@pytest.mark.parametrize("num_gemms", [3, 6]) @pytest.mark.parametrize("num_gemms", [3, 6])
@pytest.mark.parametrize("bs", batch_sizes) @pytest.mark.parametrize("bs", batch_sizes)
@pytest.mark.parametrize("model", ["126m"]) @pytest.mark.parametrize("model", ["126m"])
@pytest.mark.parametrize("fp8", all_boolean) @pytest.mark.parametrize("recipe", fp8_recipes + [None])
@pytest.mark.parametrize("recipe", fp8_recipes)
@pytest.mark.parametrize("fp8_model_params", all_boolean) @pytest.mark.parametrize("fp8_model_params", all_boolean)
@pytest.mark.parametrize("fuse_wgrad_accumulation", all_boolean) @pytest.mark.parametrize("fuse_wgrad_accumulation", all_boolean)
@pytest.mark.parametrize("bias", all_boolean)
@pytest.mark.parametrize("delay_wgrad_compute", all_boolean)
def test_grouped_linear_accuracy( def test_grouped_linear_accuracy(
dtype, dtype,
num_gemms, num_gemms,
bs, bs,
model, model,
fp8,
recipe, recipe,
fp8_model_params, fp8_model_params,
fuse_wgrad_accumulation, fuse_wgrad_accumulation,
bias,
delay_wgrad_compute,
parallel_mode=None, parallel_mode=None,
): ):
fp8 = recipe is not None
if fp8 and not fp8_available: if fp8 and not fp8_available:
pytest.skip(reason_for_no_fp8) pytest.skip(reason_for_no_fp8)
if recipe.mxfp8() and not mxfp8_available: if fp8 and recipe.mxfp8() and not mxfp8_available:
pytest.skip(reason_for_no_mxfp8) pytest.skip(reason_for_no_mxfp8)
if fp8 and recipe.mxfp8(): # TODO(ksivamani): debug mismatches if fp8 and recipe.float8_block_scaling() and not fp8_block_scaling_available:
pytest.skip("MXFP8 unsupported for grouped linear.") pytest.skip(reason_for_no_fp8_block_scaling)
if fp8 and recipe.float8_current_scaling():
pytest.skip("Float8 Current Scaling unsupported for grouped linear.")
config = model_configs[model] config = model_configs[model]
if config.seq_len % 16 != 0 and fp8: if config.seq_len % 16 != 0 and fp8:
...@@ -1542,18 +1741,19 @@ def test_grouped_linear_accuracy( ...@@ -1542,18 +1741,19 @@ def test_grouped_linear_accuracy(
num_gemms, num_gemms,
config.hidden_size, config.hidden_size,
4 * config.hidden_size, 4 * config.hidden_size,
bias=True, bias=bias,
params_dtype=dtype, params_dtype=dtype,
parallel_mode=parallel_mode, parallel_mode=parallel_mode,
device="cuda", device="cuda",
fuse_wgrad_accumulation=fuse_wgrad_accumulation, fuse_wgrad_accumulation=fuse_wgrad_accumulation,
delay_wgrad_compute=delay_wgrad_compute,
).eval() ).eval()
sequential_linear = torch.nn.ModuleList( sequential_linear = torch.nn.ModuleList(
[ [
Linear( Linear(
config.hidden_size, config.hidden_size,
4 * config.hidden_size, 4 * config.hidden_size,
bias=True, bias=bias,
params_dtype=dtype, params_dtype=dtype,
parallel_mode=parallel_mode, parallel_mode=parallel_mode,
device="cuda", device="cuda",
...@@ -1567,6 +1767,7 @@ def test_grouped_linear_accuracy( ...@@ -1567,6 +1767,7 @@ def test_grouped_linear_accuracy(
with torch.no_grad(): with torch.no_grad():
for i in range(num_gemms): for i in range(num_gemms):
sequential_linear[i].weight = Parameter(getattr(grouped_linear, f"weight{i}").clone()) sequential_linear[i].weight = Parameter(getattr(grouped_linear, f"weight{i}").clone())
if bias:
sequential_linear[i].bias = Parameter(getattr(grouped_linear, f"bias{i}").clone()) sequential_linear[i].bias = Parameter(getattr(grouped_linear, f"bias{i}").clone())
if fuse_wgrad_accumulation: if fuse_wgrad_accumulation:
weight_i = getattr(grouped_linear, f"weight{i}") weight_i = getattr(grouped_linear, f"weight{i}")
...@@ -1578,10 +1779,26 @@ def test_grouped_linear_accuracy( ...@@ -1578,10 +1779,26 @@ def test_grouped_linear_accuracy(
os.environ["NVTE_FORCE_ROCM_GEMM"] = "1" os.environ["NVTE_FORCE_ROCM_GEMM"] = "1"
outputs_ref = _test_grouped_linear_accuracy( outputs_ref = _test_grouped_linear_accuracy(
sequential_linear, num_gemms, bs, dtype, config, recipe, fp8, fuse_wgrad_accumulation sequential_linear,
num_gemms,
bs,
dtype,
config,
recipe,
fp8,
fuse_wgrad_accumulation,
delay_wgrad_compute,
) )
outputs = _test_grouped_linear_accuracy( outputs = _test_grouped_linear_accuracy(
grouped_linear, num_gemms, bs, dtype, config, recipe, fp8, fuse_wgrad_accumulation grouped_linear,
num_gemms,
bs,
dtype,
config,
recipe,
fp8,
fuse_wgrad_accumulation,
delay_wgrad_compute,
) )
# Should be bit-wise match # Should be bit-wise match
...@@ -1589,24 +1806,7 @@ def test_grouped_linear_accuracy( ...@@ -1589,24 +1806,7 @@ def test_grouped_linear_accuracy(
torch.testing.assert_close(o, o_ref, rtol=0, atol=0) torch.testing.assert_close(o, o_ref, rtol=0, atol=0)
@pytest.mark.parametrize("parallel_mode", ["column", "row"]) @pytest.mark.parametrize("recipe", fp8_recipes + [None])
@pytest.mark.parametrize("recipe", fp8_recipes)
def test_grouped_linear_accuracy_parallel_mode(parallel_mode, recipe):
"""Split the tests to save CI time"""
test_grouped_linear_accuracy(
dtype=torch.float32,
num_gemms=6,
bs=2,
model="126m",
fp8=True,
recipe=recipe,
fp8_model_params=True,
parallel_mode=parallel_mode,
fuse_wgrad_accumulation=True,
)
@pytest.mark.parametrize("recipe", fp8_recipes)
def test_grouped_linear_accuracy_single_gemm(recipe): def test_grouped_linear_accuracy_single_gemm(recipe):
"""Split the tests to save CI time""" """Split the tests to save CI time"""
test_grouped_linear_accuracy( test_grouped_linear_accuracy(
...@@ -1614,19 +1814,23 @@ def test_grouped_linear_accuracy_single_gemm(recipe): ...@@ -1614,19 +1814,23 @@ def test_grouped_linear_accuracy_single_gemm(recipe):
num_gemms=1, num_gemms=1,
bs=2, bs=2,
model="126m", model="126m",
fp8=True,
recipe=recipe, recipe=recipe,
fp8_model_params=True, fp8_model_params=True,
fuse_wgrad_accumulation=True, fuse_wgrad_accumulation=True,
bias=True,
delay_wgrad_compute=False,
) )
def _test_padding_grouped_linear_accuracy(block, num_gemms, bs, dtype, config, recipe, fp8=False): def _test_padding_grouped_linear_accuracy(block, num_gemms, bs, dtype, config, recipe, fp8=False):
def _pad_tensor_for_fp8(hidden_states, tokens_per_expert): def _pad_tensor_for_fp8(hidden_states, tokens_per_expert):
"""Padding tensor shapes to multiples of 16.""" align_size = 16
if recipe.mxfp8():
align_size = 32
padded_tokens_per_expert = [ padded_tokens_per_expert = [
(num_tokens + 15) // 16 * 16 for num_tokens in tokens_per_expert (num_tokens + align_size - 1) // align_size * align_size
for num_tokens in tokens_per_expert
] ]
hidden_states = torch.split(hidden_states, tokens_per_expert) hidden_states = torch.split(hidden_states, tokens_per_expert)
padded_hidden_states = [] padded_hidden_states = []
...@@ -1727,10 +1931,8 @@ def test_padding_grouped_linear_accuracy( ...@@ -1727,10 +1931,8 @@ def test_padding_grouped_linear_accuracy(
pytest.skip(reason_for_no_fp8) pytest.skip(reason_for_no_fp8)
if recipe.mxfp8() and not mxfp8_available: if recipe.mxfp8() and not mxfp8_available:
pytest.skip(reason_for_no_mxfp8) pytest.skip(reason_for_no_mxfp8)
if fp8 and recipe.mxfp8(): # TODO(ksivamani): debug mismatches if recipe.float8_block_scaling() and not fp8_block_scaling_available:
pytest.skip("MXFP8 unsupported for grouped linear.") pytest.skip(reason_for_no_fp8_block_scaling)
if fp8 and recipe.float8_current_scaling():
pytest.skip("Float8 Current Scaling unsupported for grouped linear.")
config = model_configs[model] config = model_configs[model]
if config.seq_len % 16 != 0 and fp8: if config.seq_len % 16 != 0 and fp8:
...@@ -1941,6 +2143,8 @@ def test_gpt_fp8_parameters(dtype, bs, model, recipe): ...@@ -1941,6 +2143,8 @@ def test_gpt_fp8_parameters(dtype, bs, model, recipe):
pytest.skip(reason_for_no_fp8) pytest.skip(reason_for_no_fp8)
if recipe.mxfp8() and not mxfp8_available: if recipe.mxfp8() and not mxfp8_available:
pytest.skip(reason_for_no_mxfp8) pytest.skip(reason_for_no_mxfp8)
if recipe.float8_block_scaling() and not fp8_block_scaling_available:
pytest.skip(reason_for_no_fp8_block_scaling)
config = model_configs[model] config = model_configs[model]
......
...@@ -8,6 +8,7 @@ import torch ...@@ -8,6 +8,7 @@ import torch
import pytest import pytest
from typing import Dict, List from typing import Dict, List
from transformer_engine.common import recipe
from transformer_engine.pytorch import ( from transformer_engine.pytorch import (
moe_permute as te_permute, moe_permute as te_permute,
moe_permute_with_probs as te_permute_with_probs, moe_permute_with_probs as te_permute_with_probs,
...@@ -17,9 +18,14 @@ from transformer_engine.pytorch import ( ...@@ -17,9 +18,14 @@ from transformer_engine.pytorch import (
) )
from transformer_engine.pytorch.utils import is_bf16_compatible from transformer_engine.pytorch.utils import is_bf16_compatible
from transformer_engine.pytorch.fp8 import FP8GlobalStateManager from transformer_engine.pytorch.fp8 import FP8GlobalStateManager
from transformer_engine.pytorch.tensor.float8_tensor import Float8Quantizer from transformer_engine.pytorch.tensor.float8_tensor import (
Float8Quantizer,
Float8CurrentScalingQuantizer,
)
from transformer_engine.pytorch.tensor.float8_blockwise_tensor import Float8BlockQuantizer
from transformer_engine.pytorch.tensor.mxfp8_tensor import MXFP8Quantizer
import transformer_engine_torch as tex import transformer_engine_torch as tex
import copy
seed = 1234 seed = 1234
torch.manual_seed(seed) torch.manual_seed(seed)
...@@ -234,7 +240,6 @@ def _test_permutation_index_map( ...@@ -234,7 +240,6 @@ def _test_permutation_index_map(
f" token:{num_tokens} hidden_size:{hidden_size} expert:{num_expert} topK:{topK} {te_dtype}" f" token:{num_tokens} hidden_size:{hidden_size} expert:{num_expert} topK:{topK} {te_dtype}"
) )
fp8 = False
# Convert TE dtypes to PyTorch dtypes # Convert TE dtypes to PyTorch dtypes
if te_dtype == tex.DType.kFloat32: if te_dtype == tex.DType.kFloat32:
dtype = torch.float32 dtype = torch.float32
...@@ -242,45 +247,9 @@ def _test_permutation_index_map( ...@@ -242,45 +247,9 @@ def _test_permutation_index_map(
dtype = torch.float16 dtype = torch.float16
elif te_dtype == tex.DType.kBFloat16: elif te_dtype == tex.DType.kBFloat16:
dtype = torch.bfloat16 dtype = torch.bfloat16
elif fp8_available and (te_dtype == tex.DType.kFloat8E5M2 or te_dtype == tex.DType.kFloat8E4M3):
dtype = torch.uint8
fp8 = True
else: else:
pytest.skip("Invalid dtype.") pytest.skip("Invalid dtype.")
if fp8:
permute_fwd_input = torch.rand(
size=(num_tokens, hidden_size), dtype=torch.float32, device="cuda"
)
permute_bwd_input = torch.rand(
size=(num_out_tokens, hidden_size), dtype=torch.float32, device="cuda"
)
unpermute_bwd_input = torch.rand(
size=(num_tokens, hidden_size), dtype=torch.float32, device="cuda"
)
_permute_fwd_input_quantizer = Float8Quantizer(
scale=torch.full([1], 1.0).cuda().squeeze(),
amax=torch.full([1], 1.0).cuda(),
fp8_dtype=te_dtype,
)
_permute_bwd_input_quantizer = Float8Quantizer(
scale=torch.full([1], 1.0).cuda().squeeze(),
amax=torch.full([1], 1.0).cuda(),
fp8_dtype=te_dtype,
)
_unpermute_bwd_quantizer = Float8Quantizer(
scale=torch.full([1], 1.0).cuda().squeeze(),
amax=torch.full([1], 1.0).cuda(),
fp8_dtype=te_dtype,
)
permute_fwd_input = _permute_fwd_input_quantizer(permute_fwd_input)
permute_bwd_input = _permute_bwd_input_quantizer(permute_bwd_input)
unpermute_bwd_input = _unpermute_bwd_quantizer(unpermute_bwd_input)
pytorch_permute_fwd_input = permute_fwd_input.dequantize(dtype=torch.float16)
pytorch_permute_bwd_input = permute_bwd_input.dequantize(dtype=torch.float16)
pytorch_unpermute_bwd_input = unpermute_bwd_input.dequantize(dtype=torch.float16)
else:
pytorch_permute_fwd_input = torch.rand((num_tokens, hidden_size), dtype=dtype).cuda() pytorch_permute_fwd_input = torch.rand((num_tokens, hidden_size), dtype=dtype).cuda()
pytorch_permute_bwd_input = torch.rand((num_out_tokens, hidden_size), dtype=dtype).cuda() pytorch_permute_bwd_input = torch.rand((num_out_tokens, hidden_size), dtype=dtype).cuda()
pytorch_unpermute_bwd_input = torch.rand((num_tokens, hidden_size), dtype=dtype).cuda() pytorch_unpermute_bwd_input = torch.rand((num_tokens, hidden_size), dtype=dtype).cuda()
...@@ -323,9 +292,9 @@ def _test_permutation_index_map( ...@@ -323,9 +292,9 @@ def _test_permutation_index_map(
# TE Permutation # TE Permutation
# #
################################################################################################################################### ###################################################################################################################################
te_permute_fwd_input = permute_fwd_input if fp8 else pytorch_permute_fwd_input.detach() te_permute_fwd_input = pytorch_permute_fwd_input.detach()
te_permute_fwd_input.requires_grad_(True) te_permute_fwd_input.requires_grad_(True)
te_permute_bwd_input = permute_bwd_input if fp8 else pytorch_permute_bwd_input.detach() te_permute_bwd_input = pytorch_permute_bwd_input.detach()
te_permute_output, row_id_map = te_permute( te_permute_output, row_id_map = te_permute(
te_permute_fwd_input, indices, num_out_tokens, map_type="index" te_permute_fwd_input, indices, num_out_tokens, map_type="index"
...@@ -338,7 +307,7 @@ def _test_permutation_index_map( ...@@ -338,7 +307,7 @@ def _test_permutation_index_map(
te_probs.requires_grad_(True) te_probs.requires_grad_(True)
te_unpermute_fwd_input = te_permute_output.detach() te_unpermute_fwd_input = te_permute_output.detach()
te_unpermute_fwd_input.requires_grad_(True) te_unpermute_fwd_input.requires_grad_(True)
te_unpermute_bwd_input = unpermute_bwd_input if fp8 else pytorch_unpermute_bwd_input.detach() te_unpermute_bwd_input = pytorch_unpermute_bwd_input.detach()
te_unpermute_output = te_unpermute( te_unpermute_output = te_unpermute(
te_unpermute_fwd_input, row_id_map, te_probs, map_type="index" te_unpermute_fwd_input, row_id_map, te_probs, map_type="index"
...@@ -352,12 +321,6 @@ def _test_permutation_index_map( ...@@ -352,12 +321,6 @@ def _test_permutation_index_map(
################################################################################################################################### ###################################################################################################################################
tols = dtype_tols(te_dtype) tols = dtype_tols(te_dtype)
if fp8:
te_permute_output_ = te_permute_output.dequantize(dtype=torch.float32)
te_permute_fwd_input_grad = te_permute_fwd_input.grad.dequantize(dtype=torch.float32)
te_unpermute_output_ = te_unpermute_output.dequantize(dtype=torch.float32)
te_unpermute_fwd_input_grad = te_unpermute_fwd_input.grad.dequantize(dtype=torch.float32)
else:
te_permute_output_ = te_permute_output.float() te_permute_output_ = te_permute_output.float()
te_permute_fwd_input_grad = te_permute_fwd_input.grad.float() te_permute_fwd_input_grad = te_permute_fwd_input.grad.float()
te_unpermute_output_ = te_unpermute_output.float() te_unpermute_output_ = te_unpermute_output.float()
...@@ -487,7 +450,6 @@ def _test_permutation_mask_map( ...@@ -487,7 +450,6 @@ def _test_permutation_mask_map(
f" token:{num_tokens} hidden_size:{hidden_size} expert:{num_expert} topK:{topK} {te_dtype}" f" token:{num_tokens} hidden_size:{hidden_size} expert:{num_expert} topK:{topK} {te_dtype}"
) )
fp8 = False
# Convert TE dtypes to PyTorch dtypes # Convert TE dtypes to PyTorch dtypes
if te_dtype == tex.DType.kFloat32: if te_dtype == tex.DType.kFloat32:
dtype = torch.float32 dtype = torch.float32
...@@ -495,46 +457,9 @@ def _test_permutation_mask_map( ...@@ -495,46 +457,9 @@ def _test_permutation_mask_map(
dtype = torch.float16 dtype = torch.float16
elif te_dtype == tex.DType.kBFloat16: elif te_dtype == tex.DType.kBFloat16:
dtype = torch.bfloat16 dtype = torch.bfloat16
elif fp8_available and (te_dtype == tex.DType.kFloat8E5M2 or te_dtype == tex.DType.kFloat8E4M3):
dtype = torch.uint8
fp8 = True
else: else:
pytest.skip("Invalid dtype.") pytest.skip("Invalid dtype.")
if fp8:
permute_fwd_input = torch.rand(
size=(num_tokens, hidden_size), dtype=torch.float32, device="cuda"
)
permute_bwd_input = torch.rand(
size=(num_out_tokens, hidden_size), dtype=torch.float32, device="cuda"
)
unpermute_bwd_input = torch.rand(
size=(num_tokens, hidden_size), dtype=torch.float32, device="cuda"
)
_permute_fwd_input_quantizer = Float8Quantizer(
scale=torch.full([1], 1.0).cuda().squeeze(),
amax=torch.full([1], 1.0).cuda(),
fp8_dtype=te_dtype,
)
_permute_bwd_input_quantizer = Float8Quantizer(
scale=torch.full([1], 1.0).cuda().squeeze(),
amax=torch.full([1], 1.0).cuda(),
fp8_dtype=te_dtype,
)
_unpermute_bwd_input_quantizer = Float8Quantizer(
scale=torch.full([1], 1.0).cuda().squeeze(),
amax=torch.full([1], 1.0).cuda(),
fp8_dtype=te_dtype,
)
permute_fwd_input = _permute_fwd_input_quantizer(permute_fwd_input)
permute_bwd_input = _permute_bwd_input_quantizer(permute_bwd_input)
unpermute_bwd_input = _unpermute_bwd_input_quantizer(unpermute_bwd_input)
pytorch_permute_fwd_input = permute_fwd_input.dequantize(dtype=torch.float16)
pytorch_permute_bwd_input = permute_bwd_input.dequantize(dtype=torch.float16)
pytorch_unpermute_bwd_input = unpermute_bwd_input.dequantize(dtype=torch.float16)
else:
pytorch_permute_fwd_input = torch.rand((num_tokens, hidden_size), dtype=dtype).cuda() pytorch_permute_fwd_input = torch.rand((num_tokens, hidden_size), dtype=dtype).cuda()
pytorch_permute_bwd_input = torch.rand((num_out_tokens, hidden_size), dtype=dtype).cuda() pytorch_permute_bwd_input = torch.rand((num_out_tokens, hidden_size), dtype=dtype).cuda()
pytorch_unpermute_bwd_input = torch.rand((num_tokens, hidden_size), dtype=dtype).cuda() pytorch_unpermute_bwd_input = torch.rand((num_tokens, hidden_size), dtype=dtype).cuda()
...@@ -553,9 +478,6 @@ def _test_permutation_mask_map( ...@@ -553,9 +478,6 @@ def _test_permutation_mask_map(
probs = torch.rand(num_tokens, num_expert).cuda() * routing_map probs = torch.rand(num_tokens, num_expert).cuda() * routing_map
row_sums = probs.sum(dim=1, keepdim=True) row_sums = probs.sum(dim=1, keepdim=True)
probs = probs / row_sums probs = probs / row_sums
if fp8:
probs = probs.to(torch.float16)
else:
probs = probs.to(dtype) probs = probs.to(dtype)
probs.requires_grad_(True) probs.requires_grad_(True)
...@@ -582,9 +504,9 @@ def _test_permutation_mask_map( ...@@ -582,9 +504,9 @@ def _test_permutation_mask_map(
# TE Permutation # TE Permutation
# #
################################################################################################################################### ###################################################################################################################################
te_permute_fwd_input = permute_fwd_input if fp8 else pytorch_permute_fwd_input.detach() te_permute_fwd_input = pytorch_permute_fwd_input.detach()
te_permute_fwd_input.requires_grad_(True) te_permute_fwd_input.requires_grad_(True)
te_permute_bwd_input = permute_bwd_input if fp8 else pytorch_permute_bwd_input.detach() te_permute_bwd_input = pytorch_permute_bwd_input.detach()
te_permute_output, row_id_map = te_permute( te_permute_output, row_id_map = te_permute(
te_permute_fwd_input, routing_map, num_out_tokens=num_out_tokens, map_type="mask" te_permute_fwd_input, routing_map, num_out_tokens=num_out_tokens, map_type="mask"
...@@ -597,7 +519,7 @@ def _test_permutation_mask_map( ...@@ -597,7 +519,7 @@ def _test_permutation_mask_map(
te_probs.requires_grad_(True) te_probs.requires_grad_(True)
te_unpermute_fwd_input = te_permute_output.detach() te_unpermute_fwd_input = te_permute_output.detach()
te_unpermute_fwd_input.requires_grad_(True) te_unpermute_fwd_input.requires_grad_(True)
te_unpermute_bwd_input = unpermute_bwd_input if fp8 else pytorch_unpermute_bwd_input.detach() te_unpermute_bwd_input = pytorch_unpermute_bwd_input.detach()
te_unpermute_output = te_unpermute( te_unpermute_output = te_unpermute(
te_unpermute_fwd_input, row_id_map, te_probs, restore_shape, map_type="mask" te_unpermute_fwd_input, row_id_map, te_probs, restore_shape, map_type="mask"
...@@ -611,12 +533,6 @@ def _test_permutation_mask_map( ...@@ -611,12 +533,6 @@ def _test_permutation_mask_map(
################################################################################################################################### ###################################################################################################################################
tols = dtype_tols(te_dtype) tols = dtype_tols(te_dtype)
if fp8:
te_permute_output_ = te_permute_output.dequantize(dtype=torch.float32)
te_permute_fwd_input_grad = te_permute_fwd_input.grad.dequantize(dtype=torch.float32)
te_unpermute_output_ = te_unpermute_output.dequantize(dtype=torch.float32)
te_unpermute_fwd_input_grad = te_unpermute_fwd_input.grad.dequantize(dtype=torch.float32)
else:
te_permute_output_ = te_permute_output.float() te_permute_output_ = te_permute_output.float()
te_permute_fwd_input_grad = te_permute_fwd_input.grad.float() te_permute_fwd_input_grad = te_permute_fwd_input.grad.float()
te_unpermute_output_ = te_unpermute_output.float() te_unpermute_output_ = te_unpermute_output.float()
...@@ -730,6 +646,118 @@ def _test_permutation_mask_map( ...@@ -730,6 +646,118 @@ def _test_permutation_mask_map(
print(f"unpermute\tbwd: pytorch: {t1:.3f} ms, TE: {t2:.3f} ms") print(f"unpermute\tbwd: pytorch: {t1:.3f} ms, TE: {t2:.3f} ms")
def _test_permutation_mask_map_fp8(
te_dtype,
num_tokens,
num_expert,
hidden_size,
topK,
num_out_tokens,
recipe,
):
if topK > num_expert:
pytest.skip("topK should be smaller than the number of experts.")
if num_out_tokens == None:
num_out_tokens = num_tokens * topK
if recipe.delayed():
quantizer = Float8Quantizer(
scale=torch.full([1], 1.0).cuda().squeeze(),
amax=torch.full([1], 1.0).cuda(),
fp8_dtype=te_dtype,
)
elif recipe.float8_current_scaling():
quantizer = Float8CurrentScalingQuantizer(
fp8_dtype=te_dtype,
device=torch.device("cuda"),
columnwise=False,
)
elif recipe.float8_block_scaling():
quantizer = Float8BlockQuantizer(
fp8_dtype=te_dtype,
rowwise=True,
columnwise=False,
amax_epsilon=0.0,
force_pow_2_scales=True,  # FP8 sub-channel all-to-all (a2a) requires power-of-2 (E8M0) scales
block_scaling_dim=1, # 1x128 scaling
)
elif recipe.mxfp8():
quantizer = MXFP8Quantizer(
fp8_dtype=te_dtype,
rowwise=True,
columnwise=False,
)
else:
raise ValueError("Unsupported FP8 recipe")
permute_fwd_input = torch.rand(
size=(num_tokens, hidden_size), dtype=torch.float32, device="cuda"
)
# Make an empty fp8 tensor
permute_fwd_input_fp8 = quantizer.make_empty(
permute_fwd_input.shape,
dtype=permute_fwd_input.dtype,
device=permute_fwd_input.device,
)
# quantize the tensor
quantizer.update_quantized(permute_fwd_input, permute_fwd_input_fp8)
if recipe.float8_block_scaling():
pytorch_permute_fwd_input = copy.deepcopy(permute_fwd_input_fp8._rowwise_data)
pytorch_permute_fwd_scale_input = copy.deepcopy(
permute_fwd_input_fp8._rowwise_scale_inv.T.contiguous()
)
elif recipe.mxfp8():
pytorch_permute_fwd_input = copy.deepcopy(permute_fwd_input_fp8._rowwise_data)
pytorch_permute_fwd_scale_input = copy.deepcopy(
permute_fwd_input_fp8._rowwise_scale_inv.contiguous()
)
else:
pytorch_permute_fwd_input = copy.deepcopy(permute_fwd_input_fp8._data)
pytorch_permute_fwd_scale_input = None
_tmp_tensor = torch.zeros((num_tokens * num_expert,))
_tmp_tensor[: int(num_out_tokens)] = 1.0
_tmp_idx = torch.randperm(num_tokens * num_expert)
routing_map = torch.reshape(_tmp_tensor[_tmp_idx], (num_tokens, num_expert)).bool().cuda()
# PyTorch Permutation
pytorch_permute_output, _ = pytorch_permute_mask_map(pytorch_permute_fwd_input, routing_map)
if pytorch_permute_fwd_scale_input is not None:
pytorch_permute_scale_output, _ = pytorch_permute_mask_map(
pytorch_permute_fwd_scale_input, routing_map
)
# TE Permutation
permute_output, _ = te_permute(
permute_fwd_input_fp8, routing_map, num_out_tokens=num_out_tokens, map_type="mask"
)
if recipe.float8_block_scaling():
te_permute_output = permute_output._rowwise_data
te_permute_scale_output = permute_output._rowwise_scale_inv.T.contiguous()
elif recipe.mxfp8():
te_permute_output = permute_output._rowwise_data
te_permute_scale_output = permute_output._rowwise_scale_inv.contiguous()
else:
te_permute_output = permute_output._data
te_permute_scale_output = None
# Check that the permuted raw FP8 data matches exactly
torch.testing.assert_close(
pytorch_permute_output,
te_permute_output,
atol=0,
rtol=0,
)
if recipe.float8_block_scaling() or recipe.mxfp8():
torch.testing.assert_close(
pytorch_permute_scale_output,
te_permute_scale_output,
atol=0,
rtol=0,
)
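# Editor's note -- a hedged sketch, not part of the diff. The test above compares
# raw FP8 bytes (and, for blockwise/MXFP8 recipes, per-token scale-inverse rows)
# with atol=0/rtol=0. That is valid because permutation only relocates whole token
# rows and never re-quantizes. A toy illustration of this invariant with plain
# uint8 data and a hypothetical destination map:
import torch

def permute_rows(data: torch.Tensor, dest_rows: torch.Tensor) -> torch.Tensor:
    # Scatter each source row i to row dest_rows[i] of the output.
    out = torch.empty_like(data)
    out[dest_rows] = data
    return out

raw = torch.randint(0, 256, (4, 8), dtype=torch.uint8)
dest = torch.tensor([2, 0, 3, 1])
# Gathering the permuted rows back with the same map recovers the input bit-for-bit.
assert torch.equal(permute_rows(raw, dest)[dest], raw)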
def _test_moe_chunk_sort( def _test_moe_chunk_sort(
te_dtype, te_dtype,
num_tokens, num_tokens,
...@@ -743,7 +771,6 @@ def _test_moe_chunk_sort( ...@@ -743,7 +771,6 @@ def _test_moe_chunk_sort(
f" token:{num_tokens} hidden_size:{hidden_size} num_expert:{num_expert} tp_size:{tp_size} {te_dtype}" f" token:{num_tokens} hidden_size:{hidden_size} num_expert:{num_expert} tp_size:{tp_size} {te_dtype}"
) )
fp8 = False
# Convert TE dtypes to PyTorch dtypes # Convert TE dtypes to PyTorch dtypes
if te_dtype == tex.DType.kFloat32: if te_dtype == tex.DType.kFloat32:
dtype = torch.float32 dtype = torch.float32
...@@ -751,32 +778,9 @@ def _test_moe_chunk_sort( ...@@ -751,32 +778,9 @@ def _test_moe_chunk_sort(
dtype = torch.float16 dtype = torch.float16
elif te_dtype == tex.DType.kBFloat16: elif te_dtype == tex.DType.kBFloat16:
dtype = torch.bfloat16 dtype = torch.bfloat16
elif fp8_available and (te_dtype == tex.DType.kFloat8E5M2 or te_dtype == tex.DType.kFloat8E4M3):
dtype = torch.uint8
fp8 = True
else: else:
pytest.skip("Invalid dtype.") pytest.skip("Invalid dtype.")
if fp8:
fwd_input = torch.rand(size=(num_tokens, hidden_size), dtype=torch.float32, device="cuda")
bwd_input = torch.rand(size=(num_tokens, hidden_size), dtype=torch.float32, device="cuda")
_fwd_input_quantizer = Float8Quantizer(
scale=torch.full([1], 1.0).cuda().squeeze(),
amax=torch.full([1], 1.0).cuda(),
fp8_dtype=te_dtype,
)
_bwd_input_quantizer = Float8Quantizer(
scale=torch.full([1], 1.0).cuda().squeeze(),
amax=torch.full([1], 1.0).cuda(),
fp8_dtype=te_dtype,
)
fwd_input = _fwd_input_quantizer.quantize(fwd_input)
bwd_input = _bwd_input_quantizer.quantize(bwd_input)
pytorch_fwd_input = fwd_input.dequantize(dtype=torch.float16)
pytorch_bwd_input = bwd_input.dequantize(dtype=torch.float16)
else:
pytorch_fwd_input = torch.rand((num_tokens, hidden_size), dtype=dtype).cuda() pytorch_fwd_input = torch.rand((num_tokens, hidden_size), dtype=dtype).cuda()
pytorch_bwd_input = torch.rand((num_tokens, hidden_size), dtype=dtype).cuda() pytorch_bwd_input = torch.rand((num_tokens, hidden_size), dtype=dtype).cuda()
...@@ -806,9 +810,9 @@ def _test_moe_chunk_sort( ...@@ -806,9 +810,9 @@ def _test_moe_chunk_sort(
# TE Permutation # TE Permutation
# #
################################################################################################################################### ###################################################################################################################################
te_fwd_input = fwd_input if fp8 else pytorch_fwd_input.detach() te_fwd_input = pytorch_fwd_input.detach()
te_fwd_input.requires_grad_(True) te_fwd_input.requires_grad_(True)
te_bwd_input = bwd_input if fp8 else pytorch_bwd_input.detach() te_bwd_input = pytorch_bwd_input.detach()
te_output = te_sort_chunks_by_index(te_fwd_input, split_sizes_cuda, sorted_idxs_cuda) te_output = te_sort_chunks_by_index(te_fwd_input, split_sizes_cuda, sorted_idxs_cuda)
te_output.backward(te_bwd_input, retain_graph=True) te_output.backward(te_bwd_input, retain_graph=True)
...@@ -820,10 +824,6 @@ def _test_moe_chunk_sort( ...@@ -820,10 +824,6 @@ def _test_moe_chunk_sort(
################################################################################################################################### ###################################################################################################################################
tols = dtype_tols(te_dtype) tols = dtype_tols(te_dtype)
if fp8:
te_output_ = te_output.dequantize(dtype=torch.float32)
te_fwd_input_grad = te_fwd_input.grad.dequantize(dtype=torch.float32)
else:
te_output_ = te_output.float() te_output_ = te_output.float()
te_fwd_input_grad = te_fwd_input.grad.float() te_fwd_input_grad = te_fwd_input.grad.float()
...@@ -899,7 +899,6 @@ def _test_permutation_mask_map_alongside_probs( ...@@ -899,7 +899,6 @@ def _test_permutation_mask_map_alongside_probs(
f" token:{num_tokens} hidden_size:{hidden_size} expert:{num_expert} topK:{topK} {te_dtype}" f" token:{num_tokens} hidden_size:{hidden_size} expert:{num_expert} topK:{topK} {te_dtype}"
) )
fp8 = False
# Convert TE dtypes to PyTorch dtypes # Convert TE dtypes to PyTorch dtypes
if te_dtype == tex.DType.kFloat32: if te_dtype == tex.DType.kFloat32:
dtype = torch.float32 dtype = torch.float32
...@@ -907,36 +906,9 @@ def _test_permutation_mask_map_alongside_probs( ...@@ -907,36 +906,9 @@ def _test_permutation_mask_map_alongside_probs(
dtype = torch.float16 dtype = torch.float16
elif te_dtype == tex.DType.kBFloat16: elif te_dtype == tex.DType.kBFloat16:
dtype = torch.bfloat16 dtype = torch.bfloat16
elif fp8_available and (te_dtype == tex.DType.kFloat8E5M2 or te_dtype == tex.DType.kFloat8E4M3):
dtype = torch.uint8
fp8 = True
else: else:
pytest.skip("Invalid dtype.") pytest.skip("Invalid dtype.")
if fp8:
permute_fwd_input = torch.rand(
size=(num_tokens, hidden_size), dtype=torch.float32, device="cuda"
)
unpermute_bwd_input = torch.rand(
size=(num_tokens, hidden_size), dtype=torch.float32, device="cuda"
)
_permute_fwd_input_quantizer = Float8Quantizer(
scale=torch.full([1], 1.0).cuda().squeeze(),
amax=torch.full([1], 1.0).cuda(),
fp8_dtype=te_dtype,
)
_unpermute_bwd_quantizer = Float8Quantizer(
scale=torch.full([1], 1.0).cuda().squeeze(),
amax=torch.full([1], 1.0).cuda(),
fp8_dtype=te_dtype,
)
permute_fwd_input = _permute_fwd_input_quantizer.quantize(permute_fwd_input)
unpermute_bwd_input = _unpermute_bwd_quantizer.quantize(unpermute_bwd_input)
pytorch_permute_fwd_input = permute_fwd_input.dequantize(dtype=torch.float16)
pytorch_unpermute_bwd_input = unpermute_bwd_input.dequantize(dtype=torch.float16)
else:
pytorch_permute_fwd_input = torch.rand((num_tokens, hidden_size), dtype=dtype).cuda() pytorch_permute_fwd_input = torch.rand((num_tokens, hidden_size), dtype=dtype).cuda()
pytorch_unpermute_bwd_input = torch.rand((num_tokens, hidden_size), dtype=dtype).cuda() pytorch_unpermute_bwd_input = torch.rand((num_tokens, hidden_size), dtype=dtype).cuda()
...@@ -952,9 +924,6 @@ def _test_permutation_mask_map_alongside_probs( ...@@ -952,9 +924,6 @@ def _test_permutation_mask_map_alongside_probs(
probs = torch.rand(num_tokens, num_expert).cuda() * routing_map probs = torch.rand(num_tokens, num_expert).cuda() * routing_map
row_sums = probs.sum(dim=1, keepdim=True) row_sums = probs.sum(dim=1, keepdim=True)
probs = probs / row_sums probs = probs / row_sums
if fp8:
probs = probs.to(torch.float16)
else:
probs = probs.to(dtype) probs = probs.to(dtype)
probs.requires_grad_(True) probs.requires_grad_(True)
...@@ -1006,13 +975,12 @@ def _test_permutation_mask_map_alongside_probs( ...@@ -1006,13 +975,12 @@ def _test_permutation_mask_map_alongside_probs(
# TE Permutation # TE Permutation
# #
################################################################################################################################### ###################################################################################################################################
te_permute_fwd_input = permute_fwd_input if fp8 else pytorch_permute_fwd_input.detach() te_permute_fwd_input = pytorch_permute_fwd_input.detach()
te_permute_fwd_input.requires_grad_(True) te_permute_fwd_input.requires_grad_(True)
te_unpermute_bwd_input = unpermute_bwd_input if fp8 else pytorch_unpermute_bwd_input.detach() te_unpermute_bwd_input = pytorch_unpermute_bwd_input.detach()
te_probs = probs.detach() te_probs = probs.detach()
te_probs.requires_grad_(True) te_probs.requires_grad_(True)
print(te_probs.shape)
te_permute_output, te_permuted_probs, row_id_map = te_permute_with_probs( te_permute_output, te_permuted_probs, row_id_map = te_permute_with_probs(
te_permute_fwd_input, te_permute_fwd_input,
...@@ -1020,25 +988,12 @@ def _test_permutation_mask_map_alongside_probs( ...@@ -1020,25 +988,12 @@ def _test_permutation_mask_map_alongside_probs(
routing_map, routing_map,
num_out_tokens=num_out_tokens, num_out_tokens=num_out_tokens,
) )
print(te_permuted_probs.shape)
te_permute_output, te_permuted_probs = te_sort_chunks_by_index_with_probs( te_permute_output, te_permuted_probs = te_sort_chunks_by_index_with_probs(
te_permute_output, te_permuted_probs, split_sizes_cuda, sorted_idxs_cuda te_permute_output, te_permuted_probs, split_sizes_cuda, sorted_idxs_cuda
) )
if fp8:
_permute_output_quantizer = Float8Quantizer(
scale=torch.full([1], 1.0).cuda().squeeze(),
amax=torch.full([1], 1.0).cuda(),
fp8_dtype=te_dtype,
)
te_permute_output = te_permute_output.dequantize(dtype=torch.float32)
te_permute_output = te_permute_output * te_permuted_probs.unsqueeze(-1)
te_permute_output = _permute_output_quantizer.quantize(te_permute_output)
else:
te_permute_output_dtype = te_permute_output.dtype te_permute_output_dtype = te_permute_output.dtype
print(te_permute_output.shape)
print(te_permuted_probs.shape)
te_permute_output = te_permute_output * te_permuted_probs.unsqueeze(-1) te_permute_output = te_permute_output * te_permuted_probs.unsqueeze(-1)
te_permute_output = te_permute_output.to(dtype=te_permute_output_dtype) te_permute_output = te_permute_output.to(dtype=te_permute_output_dtype)
...@@ -1058,11 +1013,6 @@ def _test_permutation_mask_map_alongside_probs( ...@@ -1058,11 +1013,6 @@ def _test_permutation_mask_map_alongside_probs(
tols = dtype_tols(te_dtype) tols = dtype_tols(te_dtype)
if fp8:
# backward of dequantize is in high precision
te_permute_fwd_input_grad = te_permute_fwd_input.grad.float()
te_unpermute_output_ = te_unpermute_output.dequantize(dtype=torch.float32)
else:
te_permute_fwd_input_grad = te_permute_fwd_input.grad.float() te_permute_fwd_input_grad = te_permute_fwd_input.grad.float()
te_unpermute_output_ = te_unpermute_output.float() te_unpermute_output_ = te_unpermute_output.float()
...@@ -1228,6 +1178,16 @@ def test_permutation_mask_map_alongside_probs_empty_input(te_dtype): ...@@ -1228,6 +1178,16 @@ def test_permutation_mask_map_alongside_probs_empty_input(te_dtype):
# Only run FP8 tests on H100. # Only run FP8 tests on H100.
fp8_available, reason_for_no_fp8 = FP8GlobalStateManager.is_fp8_available() fp8_available, reason_for_no_fp8 = FP8GlobalStateManager.is_fp8_available()
mxfp8_available, reason_for_no_mxfp8 = FP8GlobalStateManager.is_mxfp8_available()
fp8_block_scaling_available, reason_for_no_fp8_block_scaling = (
FP8GlobalStateManager.is_fp8_block_scaling_available()
)
fp8_recipes = [
recipe.MXFP8BlockScaling(),
recipe.DelayedScaling(),
recipe.Float8CurrentScaling(),
recipe.Float8BlockScaling(),
]
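# Editor's note -- a hedged sketch, not part of the diff. Each recipe object
# exposes predicate methods (delayed(), float8_current_scaling(),
# float8_block_scaling(), mxfp8()); the FP8 tests below use them both to select
# a quantizer and to skip on hardware without support.
for _recipe in fp8_recipes:
    print(
        type(_recipe).__name__,
        _recipe.delayed(),
        _recipe.float8_current_scaling(),
        _recipe.float8_block_scaling(),
        _recipe.mxfp8(),
    )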
@pytest.mark.skipif(not fp8_available, reason=reason_for_no_fp8) @pytest.mark.skipif(not fp8_available, reason=reason_for_no_fp8)
...@@ -1237,36 +1197,7 @@ fp8_available, reason_for_no_fp8 = FP8GlobalStateManager.is_fp8_available() ...@@ -1237,36 +1197,7 @@ fp8_available, reason_for_no_fp8 = FP8GlobalStateManager.is_fp8_available()
@pytest.mark.parametrize("hidden_size", [4096]) @pytest.mark.parametrize("hidden_size", [4096])
@pytest.mark.parametrize("topK", [1, 2, 5]) @pytest.mark.parametrize("topK", [1, 2, 5])
@pytest.mark.parametrize("num_out_tokens", [None, 2039]) @pytest.mark.parametrize("num_out_tokens", [None, 2039])
def test_permutation_index_map_fp8( @pytest.mark.parametrize("recipe", fp8_recipes)
te_dtype,
num_tokens,
num_expert,
hidden_size,
topK,
num_out_tokens,
):
with_probs = True
BENCHMARK = False
_test_permutation_index_map(
te_dtype=te_dtype,
num_tokens=num_tokens,
num_expert=num_expert,
hidden_size=hidden_size,
topK=topK,
num_out_tokens=num_out_tokens,
with_probs=with_probs,
BENCHMARK=BENCHMARK,
)
@pytest.mark.skipif(not fp8_available, reason=reason_for_no_fp8)
@pytest.mark.parametrize("te_dtype", [tex.DType.kFloat8E4M3, tex.DType.kFloat8E5M2])
@pytest.mark.parametrize("num_tokens", [2048])
@pytest.mark.parametrize("num_expert", [8, 16])
@pytest.mark.parametrize("hidden_size", [4096])
@pytest.mark.parametrize("topK", [1, 2, 5])
@pytest.mark.parametrize("num_out_tokens", [None, 2039])
def test_permutation_mask_map_fp8( def test_permutation_mask_map_fp8(
te_dtype, te_dtype,
num_tokens, num_tokens,
...@@ -1274,47 +1205,21 @@ def test_permutation_mask_map_fp8( ...@@ -1274,47 +1205,21 @@ def test_permutation_mask_map_fp8(
hidden_size, hidden_size,
topK, topK,
num_out_tokens, num_out_tokens,
recipe,
): ):
with_probs = True if recipe.mxfp8() and not mxfp8_available:
BENCHMARK = False pytest.skip(reason_for_no_mxfp8)
if recipe.float8_block_scaling() and not fp8_block_scaling_available:
_test_permutation_mask_map( pytest.skip(reason_for_no_fp8_block_scaling)
te_dtype=te_dtype,
num_tokens=num_tokens,
num_expert=num_expert,
hidden_size=hidden_size,
topK=topK,
num_out_tokens=num_out_tokens,
with_probs=with_probs,
BENCHMARK=BENCHMARK,
)
@pytest.mark.skipif(not fp8_available, reason=reason_for_no_fp8) _test_permutation_mask_map_fp8(
@pytest.mark.parametrize("te_dtype", [tex.DType.kFloat8E4M3, tex.DType.kFloat8E5M2])
@pytest.mark.parametrize("num_tokens", [2048])
@pytest.mark.parametrize("num_expert", [8, 16])
@pytest.mark.parametrize("hidden_size", [4096])
@pytest.mark.parametrize("topK", [1, 2, 5])
@pytest.mark.parametrize("num_out_tokens", [None, 2039])
@pytest.mark.parametrize("tp_size", [1, 2, 8])
def test_permutation_mask_map_alongside_probs_fp8(
te_dtype,
num_tokens,
num_expert,
hidden_size,
topK,
num_out_tokens,
tp_size,
):
_test_permutation_mask_map_alongside_probs(
te_dtype=te_dtype, te_dtype=te_dtype,
num_tokens=num_tokens, num_tokens=num_tokens,
num_expert=num_expert, num_expert=num_expert,
hidden_size=hidden_size, hidden_size=hidden_size,
topK=topK, topK=topK,
num_out_tokens=num_out_tokens, num_out_tokens=num_out_tokens,
tp_size=tp_size, recipe=recipe,
) )
...@@ -1415,11 +1320,9 @@ def test_permutation_single_case(): ...@@ -1415,11 +1320,9 @@ def test_permutation_single_case():
# te_dtype = tex.DType.kFloat32 # te_dtype = tex.DType.kFloat32
# te_dtype = tex.DType.kFloat16 # te_dtype = tex.DType.kFloat16
# te_dtype = tex.DType.kBFloat16 te_dtype = tex.DType.kBFloat16
te_dtype = tex.DType.kFloat8E5M2
# te_dtype = tex.DType.kFloat8E4M3
num_tokens = 10 num_tokens = 12
num_expert = 4 num_expert = 4
hidden_size = 16 hidden_size = 16
topK = 2 topK = 2
......
...@@ -43,10 +43,14 @@ from transformer_engine.pytorch.tensor.float8_tensor import ( ...@@ -43,10 +43,14 @@ from transformer_engine.pytorch.tensor.float8_tensor import (
Float8CurrentScalingQuantizer, Float8CurrentScalingQuantizer,
) )
from transformer_engine.pytorch.tensor.utils import replace_raw_data from transformer_engine.pytorch.tensor.utils import replace_raw_data
from transformer_engine.pytorch.distributed import checkpoint
from test_numerics import reset_rng_states, dtype_tols from test_numerics import reset_rng_states, dtype_tols
# Only run FP8 tests on supported devices. # Only run FP8 tests on supported devices.
fp8_available, reason_for_no_fp8 = FP8GlobalStateManager.is_fp8_available() fp8_available, reason_for_no_fp8 = FP8GlobalStateManager.is_fp8_available()
fp8_block_scaling_available, reason_for_no_fp8_block_scaling = (
FP8GlobalStateManager.is_fp8_block_scaling_available()
)
mxfp8_available, reason_for_no_mxfp8 = FP8GlobalStateManager.is_mxfp8_available() mxfp8_available, reason_for_no_mxfp8 = FP8GlobalStateManager.is_mxfp8_available()
...@@ -113,6 +117,7 @@ fp8_recipes = [ ...@@ -113,6 +117,7 @@ fp8_recipes = [
None, # Test non-FP8 None, # Test non-FP8
recipe.MXFP8BlockScaling(), # Test default recipe.MXFP8BlockScaling(), # Test default
recipe.Float8CurrentScaling(), # Test default recipe.Float8CurrentScaling(), # Test default
recipe.Float8BlockScaling(), # Test default
recipe.DelayedScaling(), # Test default recipe.DelayedScaling(), # Test default
recipe.DelayedScaling( # Test most_recent algo recipe.DelayedScaling( # Test most_recent algo
amax_history_len=16, amax_history_len=16,
...@@ -446,6 +451,8 @@ def test_sanity_layernorm_linear( ...@@ -446,6 +451,8 @@ def test_sanity_layernorm_linear(
if fp8_recipe is not None: if fp8_recipe is not None:
if not fp8_available: if not fp8_available:
pytest.skip(reason_for_no_fp8) pytest.skip(reason_for_no_fp8)
if fp8_recipe.float8_block_scaling() and not fp8_block_scaling_available:
pytest.skip(reason_for_no_fp8_block_scaling)
if fp8_recipe.mxfp8() and not mxfp8_available: if fp8_recipe.mxfp8() and not mxfp8_available:
pytest.skip(reason_for_no_mxfp8) pytest.skip(reason_for_no_mxfp8)
if not config.is_fp8_supported(): if not config.is_fp8_supported():
...@@ -477,6 +484,8 @@ def test_sanity_linear(dtype, fp8_recipe, model, skip_wgrad, skip_dgrad): ...@@ -477,6 +484,8 @@ def test_sanity_linear(dtype, fp8_recipe, model, skip_wgrad, skip_dgrad):
if fp8_recipe is not None: if fp8_recipe is not None:
if not fp8_available: if not fp8_available:
pytest.skip(reason_for_no_fp8) pytest.skip(reason_for_no_fp8)
if fp8_recipe.float8_block_scaling() and not fp8_block_scaling_available:
pytest.skip(reason_for_no_fp8_block_scaling)
if fp8_recipe.mxfp8() and not mxfp8_available: if fp8_recipe.mxfp8() and not mxfp8_available:
pytest.skip(reason_for_no_mxfp8) pytest.skip(reason_for_no_mxfp8)
if not config.is_fp8_supported(): if not config.is_fp8_supported():
...@@ -509,6 +518,8 @@ def test_sanity_linear_with_zero_tokens(dtype, bs, model, fp8_recipe, fp8_model_ ...@@ -509,6 +518,8 @@ def test_sanity_linear_with_zero_tokens(dtype, bs, model, fp8_recipe, fp8_model_
if fp8_recipe is not None: if fp8_recipe is not None:
if not fp8_available: if not fp8_available:
pytest.skip(reason_for_no_fp8) pytest.skip(reason_for_no_fp8)
if fp8_recipe.float8_block_scaling() and not fp8_block_scaling_available:
pytest.skip(reason_for_no_fp8_block_scaling)
if fp8_recipe.mxfp8() and not mxfp8_available: if fp8_recipe.mxfp8() and not mxfp8_available:
pytest.skip(reason_for_no_mxfp8) pytest.skip(reason_for_no_mxfp8)
if not config.is_fp8_supported(): if not config.is_fp8_supported():
...@@ -550,10 +561,10 @@ def test_sanity_grouped_linear( ...@@ -550,10 +561,10 @@ def test_sanity_grouped_linear(
if fp8_recipe is not None: if fp8_recipe is not None:
if not fp8_available: if not fp8_available:
pytest.skip(reason_for_no_fp8) pytest.skip(reason_for_no_fp8)
if fp8_recipe.mxfp8(): if fp8_recipe.mxfp8() and not mxfp8_available:
pytest.skip("Grouped linear does not support MXFP8") pytest.skip(reason_for_no_mxfp8)
if fp8_recipe.float8_current_scaling(): if fp8_recipe.float8_block_scaling() and not fp8_block_scaling_available:
pytest.skip("Grouped linear does not support FP8 current scaling") pytest.skip(reason_for_no_fp8_block_scaling)
if not config.is_fp8_supported(): if not config.is_fp8_supported():
pytest.skip("Model config does not support FP8") pytest.skip("Model config does not support FP8")
...@@ -597,6 +608,8 @@ def test_sanity_layernorm_mlp( ...@@ -597,6 +608,8 @@ def test_sanity_layernorm_mlp(
if fp8_recipe is not None: if fp8_recipe is not None:
if not fp8_available: if not fp8_available:
pytest.skip(reason_for_no_fp8) pytest.skip(reason_for_no_fp8)
if fp8_recipe.float8_block_scaling() and not fp8_block_scaling_available:
pytest.skip(reason_for_no_fp8_block_scaling)
if fp8_recipe.mxfp8() and not mxfp8_available: if fp8_recipe.mxfp8() and not mxfp8_available:
pytest.skip(reason_for_no_mxfp8) pytest.skip(reason_for_no_mxfp8)
if not config.is_fp8_supported(): if not config.is_fp8_supported():
...@@ -647,6 +660,8 @@ def test_sanity_gpt( ...@@ -647,6 +660,8 @@ def test_sanity_gpt(
if fp8_recipe is not None: if fp8_recipe is not None:
if not fp8_available: if not fp8_available:
pytest.skip(reason_for_no_fp8) pytest.skip(reason_for_no_fp8)
if fp8_recipe.float8_block_scaling() and not fp8_block_scaling_available:
pytest.skip(reason_for_no_fp8_block_scaling)
if fp8_recipe.mxfp8() and not mxfp8_available: if fp8_recipe.mxfp8() and not mxfp8_available:
pytest.skip(reason_for_no_mxfp8) pytest.skip(reason_for_no_mxfp8)
if not config.is_fp8_supported(): if not config.is_fp8_supported():
...@@ -714,6 +729,8 @@ def test_sanity_bert(dtype, fp8_recipe, model, skip_wgrad, zero_centered_gamma, ...@@ -714,6 +729,8 @@ def test_sanity_bert(dtype, fp8_recipe, model, skip_wgrad, zero_centered_gamma,
if fp8_recipe is not None: if fp8_recipe is not None:
if not fp8_available: if not fp8_available:
pytest.skip(reason_for_no_fp8) pytest.skip(reason_for_no_fp8)
if fp8_recipe.float8_block_scaling() and not fp8_block_scaling_available:
pytest.skip(reason_for_no_fp8_block_scaling)
if fp8_recipe.mxfp8() and not mxfp8_available: if fp8_recipe.mxfp8() and not mxfp8_available:
pytest.skip(reason_for_no_mxfp8) pytest.skip(reason_for_no_mxfp8)
if not config.is_fp8_supported(): if not config.is_fp8_supported():
...@@ -773,6 +790,8 @@ def test_sanity_T5(dtype, fp8_recipe, model, skip_wgrad, zero_centered_gamma, no ...@@ -773,6 +790,8 @@ def test_sanity_T5(dtype, fp8_recipe, model, skip_wgrad, zero_centered_gamma, no
if fp8_recipe is not None: if fp8_recipe is not None:
if not fp8_available: if not fp8_available:
pytest.skip(reason_for_no_fp8) pytest.skip(reason_for_no_fp8)
if fp8_recipe.float8_block_scaling() and not fp8_block_scaling_available:
pytest.skip(reason_for_no_fp8_block_scaling)
if fp8_recipe.mxfp8() and not mxfp8_available: if fp8_recipe.mxfp8() and not mxfp8_available:
pytest.skip(reason_for_no_mxfp8) pytest.skip(reason_for_no_mxfp8)
if not config.is_fp8_supported(): if not config.is_fp8_supported():
...@@ -830,6 +849,8 @@ def test_sanity_amp_and_nvfuser(dtype, fp8_recipe, model, skip_wgrad): ...@@ -830,6 +849,8 @@ def test_sanity_amp_and_nvfuser(dtype, fp8_recipe, model, skip_wgrad):
if fp8_recipe is not None: if fp8_recipe is not None:
if not fp8_available: if not fp8_available:
pytest.skip(reason_for_no_fp8) pytest.skip(reason_for_no_fp8)
if fp8_recipe.float8_block_scaling() and not fp8_block_scaling_available:
pytest.skip(reason_for_no_fp8_block_scaling)
if fp8_recipe.mxfp8() and not mxfp8_available: if fp8_recipe.mxfp8() and not mxfp8_available:
pytest.skip(reason_for_no_mxfp8) pytest.skip(reason_for_no_mxfp8)
if not config.is_fp8_supported(): if not config.is_fp8_supported():
...@@ -865,6 +886,8 @@ def test_sanity_drop_path(dtype, fp8_recipe, model, skip_wgrad): ...@@ -865,6 +886,8 @@ def test_sanity_drop_path(dtype, fp8_recipe, model, skip_wgrad):
if fp8_recipe is not None: if fp8_recipe is not None:
if not fp8_available: if not fp8_available:
pytest.skip(reason_for_no_fp8) pytest.skip(reason_for_no_fp8)
if fp8_recipe.float8_block_scaling() and not fp8_block_scaling_available:
pytest.skip(reason_for_no_fp8_block_scaling)
if fp8_recipe.mxfp8() and not mxfp8_available: if fp8_recipe.mxfp8() and not mxfp8_available:
pytest.skip(reason_for_no_mxfp8) pytest.skip(reason_for_no_mxfp8)
if not config.is_fp8_supported(): if not config.is_fp8_supported():
...@@ -903,6 +926,8 @@ def test_sanity_fused_qkv_params(dtype, fp8_recipe, model, skip_wgrad): ...@@ -903,6 +926,8 @@ def test_sanity_fused_qkv_params(dtype, fp8_recipe, model, skip_wgrad):
if fp8_recipe is not None: if fp8_recipe is not None:
if not fp8_available: if not fp8_available:
pytest.skip(reason_for_no_fp8) pytest.skip(reason_for_no_fp8)
if fp8_recipe.float8_block_scaling() and not fp8_block_scaling_available:
pytest.skip(reason_for_no_fp8_block_scaling)
if fp8_recipe.mxfp8() and not mxfp8_available: if fp8_recipe.mxfp8() and not mxfp8_available:
pytest.skip(reason_for_no_mxfp8) pytest.skip(reason_for_no_mxfp8)
if not config.is_fp8_supported(): if not config.is_fp8_supported():
...@@ -944,6 +969,8 @@ def test_sanity_gradient_accumulation_fusion( ...@@ -944,6 +969,8 @@ def test_sanity_gradient_accumulation_fusion(
if fp8_recipe is not None: if fp8_recipe is not None:
if not fp8_available: if not fp8_available:
pytest.skip(reason_for_no_fp8) pytest.skip(reason_for_no_fp8)
if fp8_recipe.float8_block_scaling() and not fp8_block_scaling_available:
pytest.skip(reason_for_no_fp8_block_scaling)
if fp8_recipe.mxfp8() and not mxfp8_available: if fp8_recipe.mxfp8() and not mxfp8_available:
pytest.skip(reason_for_no_mxfp8) pytest.skip(reason_for_no_mxfp8)
if not config.is_fp8_supported(): if not config.is_fp8_supported():
...@@ -990,8 +1017,12 @@ def test_gpt_cuda_graph(dtype, fp8_recipe, model, skip_wgrad, zero_centered_gamm ...@@ -990,8 +1017,12 @@ def test_gpt_cuda_graph(dtype, fp8_recipe, model, skip_wgrad, zero_centered_gamm
if fp8_recipe is not None: if fp8_recipe is not None:
if not fp8_available: if not fp8_available:
pytest.skip(reason_for_no_fp8) pytest.skip(reason_for_no_fp8)
if fp8_recipe.float8_block_scaling() and not fp8_block_scaling_available:
pytest.skip(reason_for_no_fp8_block_scaling)
if fp8_recipe.mxfp8() and not mxfp8_available: if fp8_recipe.mxfp8() and not mxfp8_available:
pytest.skip(reason_for_no_mxfp8) pytest.skip(reason_for_no_mxfp8)
if fp8_recipe.float8_block_scaling():
pytest.skip("cuda graph not supported for float8_block_scaling recipe")
if not config.is_fp8_supported(): if not config.is_fp8_supported():
pytest.skip("Model config does not support FP8") pytest.skip("Model config does not support FP8")
...@@ -1266,3 +1297,31 @@ def test_fp8_model_init_high_precision_init_val(): ...@@ -1266,3 +1297,31 @@ def test_fp8_model_init_high_precision_init_val():
assert not hasattr( assert not hasattr(
weight, "._high_precision_init_val" weight, "._high_precision_init_val"
), "clear_high_precision_init_val() not work" ), "clear_high_precision_init_val() not work"
def test_sanity_checkpointing_on_callables():
"""Test that TE checkpointing works correctly on callable modules."""
# torch.autograd.Function
class MyFunction(torch.autograd.Function):
@staticmethod
def forward(ctx, inp):
return inp
@staticmethod
def backward(ctx, grad_output):
return grad_output
module = MyFunction.apply
inp = torch.randn(10, 10, device="cuda", requires_grad=True)
out_checkpoint = checkpoint(module, inp)
out_checkpoint.sum().backward()
grad_checkpoint = inp.grad.clone()
inp.grad = None
out_standard = module(inp)
out_standard.sum().backward()
grad_standard = inp.grad
# Assert that gradients are the same
torch.testing.assert_close(grad_checkpoint, grad_standard)
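# Editor's note -- a hedged sketch, not part of the diff. The same TE checkpoint
# wrapper imported above also accepts ordinary nn.Module callables; a minimal
# variant of the test using a small Linear layer (names are illustrative only):
def _sketch_checkpoint_on_module():
    layer = torch.nn.Linear(10, 10).cuda()
    inp = torch.randn(4, 10, device="cuda", requires_grad=True)
    out = checkpoint(layer, inp)
    out.sum().backward()
    assert inp.grad is not None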
...@@ -116,6 +116,8 @@ if(USE_CUDA) ...@@ -116,6 +116,8 @@ if(USE_CUDA)
transpose/cast_transpose_fusion.cu transpose/cast_transpose_fusion.cu
transpose/transpose_fusion.cu transpose/transpose_fusion.cu
transpose/multi_cast_transpose.cu transpose/multi_cast_transpose.cu
transpose/quantize_transpose_square_blockwise.cu
transpose/quantize_transpose_vector_blockwise.cu
activation/gelu.cu activation/gelu.cu
fused_attn/fused_attn_f16_max512_seqlen.cu fused_attn/fused_attn_f16_max512_seqlen.cu
fused_attn/fused_attn_f16_arbitrary_seqlen.cu fused_attn/fused_attn_f16_arbitrary_seqlen.cu
...@@ -161,6 +163,8 @@ else() ...@@ -161,6 +163,8 @@ else()
transpose/cast_transpose_fusion.cu transpose/cast_transpose_fusion.cu
transpose/transpose_fusion.cu transpose/transpose_fusion.cu
transpose/multi_cast_transpose.cu transpose/multi_cast_transpose.cu
# transpose/quantize_transpose_square_blockwise.cu
# transpose/quantize_transpose_vector_blockwise.cu
activation/gelu.cu activation/gelu.cu
activation/relu.cu activation/relu.cu
activation/swiglu.cu activation/swiglu.cu
...@@ -271,6 +275,20 @@ if (NVTE_UB_WITH_MPI) ...@@ -271,6 +275,20 @@ if (NVTE_UB_WITH_MPI)
target_compile_definitions(transformer_engine PUBLIC NVTE_UB_WITH_MPI) target_compile_definitions(transformer_engine PUBLIC NVTE_UB_WITH_MPI)
endif() endif()
option(NVTE_ENABLE_NVSHMEM "Compile with NVSHMEM library" OFF)
if (NVTE_ENABLE_NVSHMEM)
add_subdirectory(nvshmem_api)
target_link_libraries(transformer_engine PUBLIC nvshmemapi)
target_include_directories(transformer_engine PUBLIC ${NVSHMEMAPI_INCLUDE_DIR})
endif()
if (USE_CUDA) if (USE_CUDA)
# Hack to enable dynamic loading in cuDNN frontend # Hack to enable dynamic loading in cuDNN frontend
target_compile_definitions(transformer_engine PUBLIC NV_CUDNN_FRONTEND_USE_DYNAMIC_LOADING) target_compile_definitions(transformer_engine PUBLIC NV_CUDNN_FRONTEND_USE_DYNAMIC_LOADING)
......
...@@ -32,8 +32,8 @@ void act_fn(const NVTETensor input, NVTETensor output, cudaStream_t stream) { ...@@ -32,8 +32,8 @@ void act_fn(const NVTETensor input, NVTETensor output, cudaStream_t stream) {
constexpr NVTETensor workspace = nullptr; constexpr NVTETensor workspace = nullptr;
constexpr const NVTETensor grad = nullptr; constexpr const NVTETensor grad = nullptr;
quantize_helper<IS_DBIAS, IS_DACT, IS_ACT, Empty, OP>(input, grad, nullptr, output, dbias, quantize_helper<IS_DBIAS, IS_DACT, IS_ACT, Empty, OP>(input, grad, output, dbias, workspace,
workspace, stream); nullptr, stream);
} }
template <typename ComputeType, typename Param, ComputeType (*OP)(ComputeType, const Param &)> template <typename ComputeType, typename Param, ComputeType (*OP)(ComputeType, const Param &)>
...@@ -46,8 +46,8 @@ void dact_fn(const NVTETensor grad, const NVTETensor input, NVTETensor output, ...@@ -46,8 +46,8 @@ void dact_fn(const NVTETensor grad, const NVTETensor input, NVTETensor output,
constexpr NVTETensor dbias = nullptr; constexpr NVTETensor dbias = nullptr;
constexpr NVTETensor workspace = nullptr; constexpr NVTETensor workspace = nullptr;
quantize_helper<IS_DBIAS, IS_DACT, IS_ACT, Empty, OP>(input, grad, nullptr, output, dbias, quantize_helper<IS_DBIAS, IS_DACT, IS_ACT, Empty, OP>(input, grad, output, dbias, workspace,
workspace, stream); nullptr, stream);
} }
template <typename ComputeType, typename Param, ComputeType (*ActOP)(ComputeType, const Param &)> template <typename ComputeType, typename Param, ComputeType (*ActOP)(ComputeType, const Param &)>
......
...@@ -83,8 +83,8 @@ struct SimpleTensor { ...@@ -83,8 +83,8 @@ struct SimpleTensor {
SimpleTensor() : SimpleTensor(nullptr, {}, DType::kFloat32) {} SimpleTensor() : SimpleTensor(nullptr, {}, DType::kFloat32) {}
operator NVTEBasicTensor() const { operator NVTEBasicTensor() const {
const NVTEShape shape = {this->shape.data(), this->shape.size()}; return {dptr, static_cast<NVTEDType>(dtype),
return {dptr, static_cast<NVTEDType>(dtype), shape}; nvte_make_shape(this->shape.data(), this->shape.size())};
} }
int numel() const { int numel() const {
...@@ -104,6 +104,7 @@ struct Tensor { ...@@ -104,6 +104,7 @@ struct Tensor {
SimpleTensor scale_inv; SimpleTensor scale_inv;
SimpleTensor columnwise_scale_inv; SimpleTensor columnwise_scale_inv;
public:
NVTEScalingMode scaling_mode; NVTEScalingMode scaling_mode;
Tensor() Tensor()
...@@ -165,6 +166,28 @@ struct Tensor { ...@@ -165,6 +166,28 @@ struct Tensor {
return data.shape; return data.shape;
} }
break; break;
case NVTE_BLOCK_SCALING_1D:
case NVTE_BLOCK_SCALING_2D: {
if (!has_data() && has_columnwise_data()) {
std::vector<size_t> shape;
size_t ndim = columnwise_data.shape.size();
shape.reserve(ndim);
for (size_t i = 0; i + 1 < ndim; ++i) {
shape.push_back(columnwise_data.shape[i + 1]);
}
if (ndim > 0) {
shape.push_back(columnwise_data.shape[0]);
}
return shape;
} else {
// NOTE: We may have removed the data pointer from
// data by setting usage. In that case, we return
// the non-null shape. It is our best guess at the most
// recent shape.
return data.shape;
}
break;
}
default: default:
NVTE_ERROR("Cannot parse tensor shape with scaling mode \"", to_string(scaling_mode), "\""); NVTE_ERROR("Cannot parse tensor shape with scaling mode \"", to_string(scaling_mode), "\"");
return {}; return {};
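# Editor's note -- a hedged Python sketch, not part of the diff. For the new
# block-scaling branch above, when only columnwise (transposed) data exists,
# the reported shape moves the leading dimension to the end:
def rowwise_shape_from_columnwise(columnwise_shape):
    if not columnwise_shape:
        return []
    return list(columnwise_shape[1:]) + [columnwise_shape[0]]

assert rowwise_shape_from_columnwise([128, 1024]) == [1024, 128]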
...@@ -205,10 +228,12 @@ struct Tensor { ...@@ -205,10 +228,12 @@ struct Tensor {
struct QuantizationConfig { struct QuantizationConfig {
bool force_pow_2_scales = false; bool force_pow_2_scales = false;
float amax_epsilon = 0.0f; float amax_epsilon = 0.0f;
NVTETensor noop_tensor = nullptr;
static constexpr size_t attr_sizes[] = { static constexpr size_t attr_sizes[] = {
sizeof(bool), // force_pow_2_scales sizeof(bool), // force_pow_2_scales
sizeof(float) // amax_epsilon sizeof(float), // amax_epsilon
sizeof(NVTETensor) // noop_tensor
}; };
}; };
...@@ -264,6 +289,36 @@ TRANSFORMER_ENGINE_TYPE_NAME(__nv_fp8_e8m0) ...@@ -264,6 +289,36 @@ TRANSFORMER_ENGINE_TYPE_NAME(__nv_fp8_e8m0)
#endif #endif
#undef TRANSFORMER_ENGINE_TYPE_NAME #undef TRANSFORMER_ENGINE_TYPE_NAME
template <typename T>
struct TypeExtrema;
template <>
struct TypeExtrema<fp8e4m3> {
static constexpr float max = 448.0f;
};
template <>
struct TypeExtrema<fp8e5m2> {
static constexpr float max = 57344.0f;
};
template <>
struct TypeExtrema<bf16> {
// Hex float format of 1.(7 bits of 1) * 2 ^ 127
static constexpr float max = 0x1.FEp127;
};
template <>
struct TypeExtrema<fp16> {
// Hex float format of 1.(10 bits of 1) * 2 ^ 15
static constexpr float max = 0x1.FFCp15;
};
template <typename T>
struct TypeExtrema {
static constexpr float max = std::numeric_limits<T>::max();
};
} // namespace detail } // namespace detail
template <typename T> template <typename T>
...@@ -294,6 +349,7 @@ struct TypeInfo { ...@@ -294,6 +349,7 @@ struct TypeInfo {
constexpr static DType dtype = getType<T>(); constexpr static DType dtype = getType<T>();
constexpr static size_t size = sizeof(T); constexpr static size_t size = sizeof(T);
constexpr static float max_finite_value = detail::TypeExtrema<T>::max;
constexpr static const char *name = detail::type_name<T>(); constexpr static const char *name = detail::type_name<T>();
}; };
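# Editor's note -- a hedged sketch, not part of the diff. The hard-coded extrema
# added above agree with what PyTorch reports for the matching dtypes (assuming a
# PyTorch build that exposes the FP8 dtypes):
import torch

assert torch.finfo(torch.float8_e4m3fn).max == 448.0
assert torch.finfo(torch.float8_e5m2).max == 57344.0
assert torch.finfo(torch.bfloat16).max == float.fromhex("0x1.FEp127")
assert torch.finfo(torch.float16).max == float.fromhex("0x1.FFCp15")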
......
...@@ -16,10 +16,11 @@ namespace transformer_engine { ...@@ -16,10 +16,11 @@ namespace transformer_engine {
template <typename scalar_t> template <typename scalar_t>
__device__ void fused_rope_block_forward(const scalar_t *src, const float *freqs, scalar_t *dst, __device__ void fused_rope_block_forward(const scalar_t *src, const float *freqs, scalar_t *dst,
const int s_id, const int offset_block, const bool interleaved, const int s_id,
const int offset_block_dst, const int h, const int d, const int offset_block, const int offset_block_dst,
const int d2, const int stride_h, const int stride_d, const int h, const int d, const int d2, const int stride_h,
const int o_stride_h, const int o_stride_d) { const int stride_d, const int o_stride_h,
const int o_stride_d) {
#pragma unroll #pragma unroll
for (int d_id = threadIdx.x; d_id < d2; d_id += blockDim.x) { for (int d_id = threadIdx.x; d_id < d2; d_id += blockDim.x) {
float v_cos, v_sin; float v_cos, v_sin;
...@@ -29,9 +30,18 @@ __device__ void fused_rope_block_forward(const scalar_t *src, const float *freqs ...@@ -29,9 +30,18 @@ __device__ void fused_rope_block_forward(const scalar_t *src, const float *freqs
int offset_src = offset_block + h_id * stride_h + d_id * stride_d; int offset_src = offset_block + h_id * stride_h + d_id * stride_d;
int offset_dst = offset_block_dst + h_id * o_stride_h + d_id * o_stride_d; int offset_dst = offset_block_dst + h_id * o_stride_h + d_id * o_stride_d;
float v_src = src[offset_src]; float v_src = src[offset_src];
float v_src_rotate = (d_id + d2 / 2 < d2) float v_src_rotate;
if (!interleaved) {
v_src_rotate = (d_id + d2 / 2 < d2)
? -static_cast<float>(src[offset_src + (d2 / 2) * stride_d]) ? -static_cast<float>(src[offset_src + (d2 / 2) * stride_d])
: static_cast<float>(src[offset_src + (d2 / 2 - d2) * stride_d]); : static_cast<float>(src[offset_src + (d2 / 2 - d2) * stride_d]);
} else {
v_src_rotate = (d_id % 2 == 0)
// d_id + 1
? -static_cast<float>(src[offset_src + stride_d])
// d_id - 1
: static_cast<float>(src[offset_src - stride_d]);
}
dst[offset_dst] = v_src * v_cos + v_src_rotate * v_sin; dst[offset_dst] = v_src * v_cos + v_src_rotate * v_sin;
} }
} }
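# Editor's note -- a hedged PyTorch sketch, not part of the diff, of the two
# rotation layouts the forward kernel now supports ("half" rotation vs. the new
# interleaved pairing), written elementwise for readability:
import torch

def rope_rotate(x: torch.Tensor, freqs: torch.Tensor, interleaved: bool) -> torch.Tensor:
    # x, freqs: (..., d2); freqs holds the rotation angles.
    cos, sin = freqs.cos(), freqs.sin()
    if not interleaved:
        # Half rotation: element i pairs with i +/- d2/2.
        half = x.shape[-1] // 2
        x_rot = torch.cat([-x[..., half:], x[..., :half]], dim=-1)
    else:
        # Interleaved rotation: element 2k pairs with 2k + 1.
        x_even, x_odd = x[..., 0::2], x[..., 1::2]
        x_rot = torch.stack([-x_odd, x_even], dim=-1).flatten(-2)
    return x * cos + x_rot * sin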
...@@ -52,22 +62,39 @@ __device__ void fused_rope_block_forward(const scalar_t *src, const float *freqs ...@@ -52,22 +62,39 @@ __device__ void fused_rope_block_forward(const scalar_t *src, const float *freqs
template <typename scalar_t> template <typename scalar_t>
__device__ void fused_rope_block_backward(const scalar_t *src, const float *freqs, scalar_t *dst, __device__ void fused_rope_block_backward(const scalar_t *src, const float *freqs, scalar_t *dst,
const int s_id, const int offset_block, const bool interleaved, const int s_id,
const int offset_block_dst, const int h, const int d, const int offset_block, const int offset_block_dst,
const int d2, const int stride_h, const int stride_d, const int h, const int d, const int d2,
const int stride_h, const int stride_d,
const int o_stride_h, const int o_stride_d) { const int o_stride_h, const int o_stride_d) {
#pragma unroll #pragma unroll
for (int d_id = threadIdx.x; d_id < d2; d_id += blockDim.x) { for (int d_id = threadIdx.x; d_id < d2; d_id += blockDim.x) {
float v_cos = cosf(freqs[s_id * d2 + d_id]); float v_cos = cosf(freqs[s_id * d2 + d_id]);
float v_sin = (d_id + d2 / 2 < d2) ? sinf(freqs[s_id * d2 + d_id + d2 / 2]) float v_sin;
if (!interleaved) {
v_sin = (d_id + d2 / 2 < d2) ? sinf(freqs[s_id * d2 + d_id + d2 / 2])
: -sinf(freqs[s_id * d2 + d_id + d2 / 2 - d2]); : -sinf(freqs[s_id * d2 + d_id + d2 / 2 - d2]);
} else {
v_sin =
(d_id % 2 == 0) ? sinf(freqs[s_id * d2 + d_id + 1]) : -sinf(freqs[s_id * d2 + d_id - 1]);
}
#pragma unroll #pragma unroll
for (int h_id = threadIdx.y; h_id < h; h_id += blockDim.y) { for (int h_id = threadIdx.y; h_id < h; h_id += blockDim.y) {
int offset_src = offset_block + h_id * stride_h + d_id * stride_d; int offset_src = offset_block + h_id * stride_h + d_id * stride_d;
int offset_dst = offset_block_dst + h_id * o_stride_h + d_id * o_stride_d; int offset_dst = offset_block_dst + h_id * o_stride_h + d_id * o_stride_d;
float v_src = src[offset_src]; float v_src = src[offset_src];
float v_src_rotate = (d_id + d2 / 2 < d2) ? src[offset_src + (d2 / 2) * stride_d] float v_src_rotate;
: src[offset_src + (d2 / 2 - d2) * stride_d]; if (!interleaved) {
v_src_rotate = (d_id + d2 / 2 < d2)
? static_cast<float>(src[offset_src + (d2 / 2) * stride_d])
: static_cast<float>(src[offset_src + (d2 / 2 - d2) * stride_d]);
} else {
v_src_rotate = (d_id % 2 == 0)
// d_id + 1
? static_cast<float>(src[offset_src + stride_d])
// d_id - 1
: static_cast<float>(src[offset_src - stride_d]);
}
dst[offset_dst] = v_src * v_cos + v_src_rotate * v_sin; dst[offset_dst] = v_src * v_cos + v_src_rotate * v_sin;
} }
} }
...@@ -87,51 +114,33 @@ __device__ void fused_rope_block_backward(const scalar_t *src, const float *freq ...@@ -87,51 +114,33 @@ __device__ void fused_rope_block_backward(const scalar_t *src, const float *freq
} }
template <typename scalar_t> template <typename scalar_t>
__global__ void fused_rope_forward_kernel(const scalar_t *src, const float *freqs, scalar_t *dst, __global__ void fused_rope_forward_kernel(const scalar_t *src, const int *cu_seqlens,
const float *freqs, scalar_t *dst, const bool interleaved,
const int cp_size, const int cp_rank, const int s,
const int h, const int d, const int d2, const int h, const int d, const int d2,
const int stride_s, const int stride_b, const int stride_s_or_t, const int stride_b,
const int stride_h, const int stride_d, const int stride_h, const int stride_d,
const int o_stride_s, const int o_stride_b, const int o_stride_s_or_t, const int o_stride_b,
const int o_stride_h, const int o_stride_d) {
int s_id = blockIdx.x, b_id = blockIdx.y;
int offset_block = s_id * stride_s + b_id * stride_b;
int offset_block_dst = s_id * o_stride_s + b_id * o_stride_b;
fused_rope_block_forward(src, freqs, dst, s_id, offset_block, offset_block_dst, h, d, d2,
stride_h, stride_d, o_stride_h, o_stride_d);
}
template <typename scalar_t>
__global__ void fused_rope_backward_kernel(const scalar_t *src, const float *freqs, scalar_t *dst,
const int h, const int d, const int d2,
const int stride_s, const int stride_b,
const int stride_h, const int stride_d,
const int o_stride_s, const int o_stride_b,
const int o_stride_h, const int o_stride_d) {
int s_id = blockIdx.x, b_id = blockIdx.y;
int offset_block = s_id * stride_s + b_id * stride_b;
int offset_block_dst = s_id * o_stride_s + b_id * o_stride_b;
fused_rope_block_backward(src, freqs, dst, s_id, offset_block, offset_block_dst, h, d, d2,
stride_h, stride_d, o_stride_h, o_stride_d);
}
template <typename scalar_t>
__global__ void fused_rope_thd_forward_kernel(const scalar_t *src, const int *cu_seqlens,
const float *freqs, scalar_t *dst, const int cp_size,
const int cp_rank, const int h, const int d,
const int d2, const int stride_t, const int stride_h,
const int stride_d, const int o_stride_t,
const int o_stride_h, const int o_stride_d) { const int o_stride_h, const int o_stride_d) {
int s_id = blockIdx.x, b_id = blockIdx.y; int s_id = blockIdx.x, b_id = blockIdx.y;
int offset_block, offset_block_dst;
int cur_seqlens;
if (cu_seqlens != nullptr) { // THD
int start = cu_seqlens[b_id] / cp_size; int start = cu_seqlens[b_id] / cp_size;
int end = cu_seqlens[b_id + 1] / cp_size; int end = cu_seqlens[b_id + 1] / cp_size;
int t_id = s_id + start; int t_id = s_id + start;
if (t_id >= end) return; if (t_id >= end) return;
int offset_block = t_id * stride_t; offset_block = t_id * stride_s_or_t;
int offset_block_dst = t_id * o_stride_t; offset_block_dst = t_id * o_stride_s_or_t;
cur_seqlens = end - start;
} else { // SBHD/BSHD
offset_block = s_id * stride_s_or_t + b_id * stride_b;
offset_block_dst = s_id * o_stride_s_or_t + b_id * o_stride_b;
cur_seqlens = s;
}
int s_id_for_freqs; int s_id_for_freqs;
if (cp_size > 1) { if (cp_size > 1) {
int cur_seqlens = end - start;
assert(cur_seqlens % 2 == 0); assert(cur_seqlens % 2 == 0);
if (s_id < cur_seqlens / 2) { if (s_id < cur_seqlens / 2) {
s_id_for_freqs = s_id + cp_rank * cur_seqlens / 2; s_id_for_freqs = s_id + cp_rank * cur_seqlens / 2;
...@@ -142,28 +151,37 @@ __global__ void fused_rope_thd_forward_kernel(const scalar_t *src, const int *cu ...@@ -142,28 +151,37 @@ __global__ void fused_rope_thd_forward_kernel(const scalar_t *src, const int *cu
} else { } else {
s_id_for_freqs = s_id; s_id_for_freqs = s_id;
} }
fused_rope_block_forward(src, freqs, dst, s_id_for_freqs, offset_block, offset_block_dst, h, d,
d2, stride_h, stride_d, o_stride_h, o_stride_d); fused_rope_block_forward(src, freqs, dst, interleaved, s_id_for_freqs, offset_block,
offset_block_dst, h, d, d2, stride_h, stride_d, o_stride_h, o_stride_d);
} }
template <typename scalar_t> template <typename scalar_t>
__global__ void fused_rope_thd_backward_kernel(const scalar_t *src, const int *cu_seqlens, __global__ void fused_rope_backward_kernel(
const float *freqs, scalar_t *dst, const int cp_size, const scalar_t *src, const int *cu_seqlens, const float *freqs, scalar_t *dst,
const int cp_rank, const int h, const int d, const bool interleaved, const int cp_size, const int cp_rank, const int s, const int h,
const int d2, const int stride_t, const int stride_h, const int d, const int d2, const int stride_s_or_t, const int stride_b, const int stride_h,
const int stride_d, const int o_stride_t, const int stride_d, const int o_stride_s_or_t, const int o_stride_b, const int o_stride_h,
const int o_stride_h, const int o_stride_d) { const int o_stride_d) {
int s_id = blockIdx.x, b_id = blockIdx.y; int s_id = blockIdx.x, b_id = blockIdx.y;
int offset_block, offset_block_dst;
int cur_seqlens;
if (cu_seqlens != nullptr) { // THD
int start = cu_seqlens[b_id] / cp_size; int start = cu_seqlens[b_id] / cp_size;
int end = cu_seqlens[b_id + 1] / cp_size; int end = cu_seqlens[b_id + 1] / cp_size;
int t_id = s_id + start; int t_id = s_id + start;
if (t_id >= end) return; if (t_id >= end) return;
int offset_block = t_id * stride_t; offset_block = t_id * stride_s_or_t;
int offset_block_dst = t_id * o_stride_t; offset_block_dst = t_id * o_stride_s_or_t;
cur_seqlens = end - start;
} else { // SBHD/BSHD
offset_block = s_id * stride_s_or_t + b_id * stride_b;
offset_block_dst = s_id * o_stride_s_or_t + b_id * o_stride_b;
cur_seqlens = s;
}
int s_id_for_freqs; int s_id_for_freqs;
if (cp_size > 1) { if (cp_size > 1) {
int cur_seqlens = end - start;
assert(cur_seqlens % 2 == 0); assert(cur_seqlens % 2 == 0);
if (s_id < cur_seqlens / 2) { if (s_id < cur_seqlens / 2) {
s_id_for_freqs = s_id + cp_rank * cur_seqlens / 2; s_id_for_freqs = s_id + cp_rank * cur_seqlens / 2;
...@@ -174,193 +192,136 @@ __global__ void fused_rope_thd_backward_kernel(const scalar_t *src, const int *c ...@@ -174,193 +192,136 @@ __global__ void fused_rope_thd_backward_kernel(const scalar_t *src, const int *c
} else { } else {
s_id_for_freqs = s_id; s_id_for_freqs = s_id;
} }
fused_rope_block_backward(src, freqs, dst, s_id_for_freqs, offset_block, offset_block_dst, h, d,
d2, stride_h, stride_d, o_stride_h, o_stride_d); fused_rope_block_backward(src, freqs, dst, interleaved, s_id_for_freqs, offset_block,
offset_block_dst, h, d, d2, stride_h, stride_d, o_stride_h, o_stride_d);
} }
template <typename scalar_t> template <typename scalar_t>
void fused_rope_forward_launcher(const scalar_t *input, const float *freqs, scalar_t *output, void fused_rope_forward_launcher(const scalar_t *input, const int *cu_seqlens, const float *freqs,
scalar_t *output, const NVTE_QKV_Format qkv_format,
const bool interleaved, const int cp_size, const int cp_rank,
const int s, const int b, const int h, const int d, const int d2, const int s, const int b, const int h, const int d, const int d2,
const int stride_s, const int stride_b, const int stride_h, const int stride_s_or_t, const int stride_b, const int stride_h,
const int stride_d, const int o_stride_s, const int o_stride_b, const int stride_d, cudaStream_t stream) {
const int o_stride_h, const int o_stride_d, cudaStream_t stream) {
int warps_per_block = h < 16 ? 4 : 8; int warps_per_block = h < 16 ? 4 : 8;
dim3 blocks(s, b); dim3 blocks(s, b);
dim3 threads(THREADS_PER_WARP, warps_per_block); dim3 threads(THREADS_PER_WARP, warps_per_block);
int o_stride_s_or_t, o_stride_b;
if (qkv_format == NVTE_QKV_Format::NVTE_THD) {
NVTE_CHECK(cu_seqlens != nullptr, "cu_seqlens is required for THD format");
o_stride_s_or_t = h * d;
o_stride_b = 0;
} else if (qkv_format == NVTE_QKV_Format::NVTE_SBHD) {
o_stride_s_or_t = b * h * d;
o_stride_b = h * d;
} else {
o_stride_s_or_t = h * d;
o_stride_b = s * h * d;
}
const int o_stride_h = d;
const int o_stride_d = 1;
fused_rope_forward_kernel<<<blocks, threads, 0, stream>>>( fused_rope_forward_kernel<<<blocks, threads, 0, stream>>>(
input, freqs, output, h, d, d2, stride_s, stride_b, stride_h, stride_d, o_stride_s, input, cu_seqlens, freqs, output, interleaved, cp_size, cp_rank, s, h, d, d2, stride_s_or_t,
o_stride_b, o_stride_h, o_stride_d); stride_b, stride_h, stride_d, o_stride_s_or_t, o_stride_b, o_stride_h, o_stride_d);
NVTE_CHECK_CUDA(cudaGetLastError()); NVTE_CHECK_CUDA(cudaGetLastError());
} }
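// Illustrative sketch (not part of this commit): the launcher above derives the *output*
// strides from qkv_format because the output is always written densely, while the *input*
// strides are still passed in since the source tensor may be a non-contiguous view.  For a
// contiguous input a caller could derive them the same way; the helper name is hypothetical.
inline void dense_input_strides(NVTE_QKV_Format qkv_format, int s, int b, int h, int d,
                                int &stride_s_or_t, int &stride_b) {
  if (qkv_format == NVTE_QKV_Format::NVTE_THD) {
    stride_s_or_t = h * d;  // one packed token
    stride_b = 0;           // batch dimension is folded into the token dimension
  } else if (qkv_format == NVTE_QKV_Format::NVTE_SBHD) {
    stride_s_or_t = b * h * d;
    stride_b = h * d;
  } else {  // BSHD
    stride_s_or_t = h * d;
    stride_b = s * h * d;
  }
  // stride_h == d and stride_d == 1 in all three dense layouts.
}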
template <typename scalar_t> template <typename scalar_t>
void fused_rope_backward_launcher(const scalar_t *output_grads, const float *freqs, void fused_rope_backward_launcher(const scalar_t *output_grads, const int *cu_seqlens,
scalar_t *input_grads, const int s, const int b, const int h, const float *freqs, scalar_t *input_grads,
const int d, const int d2, const int stride_s, const int stride_b, const NVTE_QKV_Format qkv_format, const bool interleaved,
const int stride_h, const int stride_d, const int o_stride_s, const int cp_size, const int cp_rank, const int s, const int b,
const int o_stride_b, const int o_stride_h, const int o_stride_d, const int h, const int d, const int d2, const int stride_s_or_t,
const int stride_b, const int stride_h, const int stride_d,
cudaStream_t stream) { cudaStream_t stream) {
int warps_per_block = h < 16 ? 4 : 8; int warps_per_block = h < 16 ? 4 : 8;
dim3 blocks(s, b); dim3 blocks(s, b);
dim3 threads(THREADS_PER_WARP, warps_per_block); dim3 threads(THREADS_PER_WARP, warps_per_block);
int o_stride_s_or_t, o_stride_b;
if (qkv_format == NVTE_QKV_Format::NVTE_THD) {
NVTE_CHECK(cu_seqlens != nullptr, "cu_seqlens is required for THD format");
o_stride_s_or_t = h * d;
o_stride_b = 0;
} else if (qkv_format == NVTE_QKV_Format::NVTE_SBHD) {
o_stride_s_or_t = b * h * d;
o_stride_b = h * d;
} else {
o_stride_s_or_t = h * d;
o_stride_b = s * h * d;
}
const int o_stride_h = d;
const int o_stride_d = 1;
fused_rope_backward_kernel<<<blocks, threads, 0, stream>>>( fused_rope_backward_kernel<<<blocks, threads, 0, stream>>>(
output_grads, freqs, input_grads, h, d, d2, stride_s, stride_b, stride_h, stride_d, output_grads, cu_seqlens, freqs, input_grads, interleaved, cp_size, cp_rank, s, h, d, d2,
o_stride_s, o_stride_b, o_stride_h, o_stride_d); stride_s_or_t, stride_b, stride_h, stride_d, o_stride_s_or_t, o_stride_b, o_stride_h,
NVTE_CHECK_CUDA(cudaGetLastError()); o_stride_d);
}
template <typename scalar_t>
void fused_rope_thd_forward_launcher(const scalar_t *input, const int *cu_seqlens,
const float *freqs, scalar_t *output, const int cp_size,
const int cp_rank, const int max_s, const int b, const int h,
const int d, const int d2, const int stride_t,
const int stride_h, const int stride_d, const int o_stride_t,
const int o_stride_h, const int o_stride_d,
cudaStream_t stream) {
int warps_per_block = h < 16 ? 4 : 8;
dim3 blocks(max_s, b);
dim3 threads(THREADS_PER_WARP, warps_per_block);
fused_rope_thd_forward_kernel<<<blocks, threads, 0, stream>>>(
input, cu_seqlens, freqs, output, cp_size, cp_rank, h, d, d2, stride_t, stride_h, stride_d,
o_stride_t, o_stride_h, o_stride_d);
NVTE_CHECK_CUDA(cudaGetLastError()); NVTE_CHECK_CUDA(cudaGetLastError());
} }
template <typename scalar_t> void fused_rope_forward(const Tensor &input, const Tensor &cu_seqlens, const Tensor &freqs,
void fused_rope_thd_backward_launcher(const scalar_t *output_grads, const int *cu_seqlens, Tensor *output, const NVTE_QKV_Format qkv_format, const bool interleaved,
const float *freqs, scalar_t *input_grads, const int cp_size, const int cp_size, const int cp_rank, const int s, const int b, const int h,
const int cp_rank, const int max_s, const int b, const int h, const int d, const int d2, const int stride_s_or_t, const int stride_b,
const int d, const int d2, const int stride_t, const int stride_h, const int stride_d, cudaStream_t stream) {
const int stride_h, const int stride_d, const int o_stride_t,
const int o_stride_h, const int o_stride_d,
cudaStream_t stream) {
int warps_per_block = h < 16 ? 4 : 8;
dim3 blocks(max_s, b);
dim3 threads(THREADS_PER_WARP, warps_per_block);
fused_rope_thd_backward_kernel<<<blocks, threads, 0, stream>>>(
output_grads, cu_seqlens, freqs, input_grads, cp_size, cp_rank, h, d, d2, stride_t, stride_h,
stride_d, o_stride_t, o_stride_h, o_stride_d);
NVTE_CHECK_CUDA(cudaGetLastError());
}
void fused_rope_forward(const Tensor &input, const Tensor &freqs, Tensor *output, const int s,
const int b, const int h, const int d, const int d2, const int stride_s,
const int stride_b, const int stride_h, const int stride_d,
const int o_stride_s, const int o_stride_b, const int o_stride_h,
const int o_stride_d, cudaStream_t stream) {
TRANSFORMER_ENGINE_TYPE_SWITCH_INPUT( TRANSFORMER_ENGINE_TYPE_SWITCH_INPUT(
input.data.dtype, scalar_t, input.data.dtype, scalar_t,
fused_rope_forward_launcher(reinterpret_cast<const scalar_t *>(input.data.dptr), fused_rope_forward_launcher(reinterpret_cast<const scalar_t *>(input.data.dptr),
reinterpret_cast<const float *>(freqs.data.dptr),
reinterpret_cast<scalar_t *>(output->data.dptr), s, b, h, d, d2,
stride_s, stride_b, stride_h, stride_d, o_stride_s, o_stride_b,
o_stride_h, o_stride_d, stream););
}
void fused_rope_backward(const Tensor &output_grads, const Tensor &freqs, Tensor *input_grads,
const int s, const int b, const int h, const int d, const int d2,
const int stride_s, const int stride_b, const int stride_h,
const int stride_d, const int o_stride_s, const int o_stride_b,
const int o_stride_h, const int o_stride_d, cudaStream_t stream) {
TRANSFORMER_ENGINE_TYPE_SWITCH_INPUT(
output_grads.data.dtype, scalar_t,
fused_rope_backward_launcher(reinterpret_cast<const scalar_t *>(output_grads.data.dptr),
reinterpret_cast<const float *>(freqs.data.dptr),
reinterpret_cast<scalar_t *>(input_grads->data.dptr), s, b, h, d,
d2, stride_s, stride_b, stride_h, stride_d, o_stride_s,
o_stride_b, o_stride_h, o_stride_d, stream););
}
void fused_rope_thd_forward(const Tensor &input, const Tensor &cu_seqlens, const Tensor &freqs,
Tensor *output, const int cp_size, const int cp_rank, const int max_s,
const int b, const int h, const int d, const int d2, const int stride_t,
const int stride_h, const int stride_d, const int o_stride_t,
const int o_stride_h, const int o_stride_d, cudaStream_t stream) {
TRANSFORMER_ENGINE_TYPE_SWITCH_INPUT(
input.data.dtype, scalar_t,
fused_rope_thd_forward_launcher(reinterpret_cast<const scalar_t *>(input.data.dptr),
reinterpret_cast<const int *>(cu_seqlens.data.dptr), reinterpret_cast<const int *>(cu_seqlens.data.dptr),
reinterpret_cast<const float *>(freqs.data.dptr), reinterpret_cast<const float *>(freqs.data.dptr),
reinterpret_cast<scalar_t *>(output->data.dptr), cp_size, reinterpret_cast<scalar_t *>(output->data.dptr), qkv_format,
cp_rank, max_s, b, h, d, d2, stride_t, stride_h, stride_d, interleaved, cp_size, cp_rank, s, b, h, d, d2, stride_s_or_t,
o_stride_t, o_stride_h, o_stride_d, stream);); stride_b, stride_h, stride_d, stream););
} }
void fused_rope_thd_backward(const Tensor &output_grads, const Tensor &cu_seqlens, void fused_rope_backward(const Tensor &output_grads, const Tensor &cu_seqlens, const Tensor &freqs,
const Tensor &freqs, Tensor *input_grads, const int cp_size, Tensor *input_grads, const NVTE_QKV_Format qkv_format,
const int cp_rank, const int max_s, const int b, const int h, const bool interleaved, const int cp_size, const int cp_rank, const int s,
const int d, const int d2, const int stride_t, const int stride_h, const int b, const int h, const int d, const int d2,
const int stride_d, const int o_stride_t, const int o_stride_h, const int stride_s_or_t, const int stride_b, const int stride_h,
const int o_stride_d, cudaStream_t stream) { const int stride_d, cudaStream_t stream) {
TRANSFORMER_ENGINE_TYPE_SWITCH_INPUT( TRANSFORMER_ENGINE_TYPE_SWITCH_INPUT(
output_grads.data.dtype, scalar_t, output_grads.data.dtype, scalar_t,
fused_rope_thd_backward_launcher(reinterpret_cast<const scalar_t *>(output_grads.data.dptr), fused_rope_backward_launcher(reinterpret_cast<const scalar_t *>(output_grads.data.dptr),
reinterpret_cast<const int *>(cu_seqlens.data.dptr), reinterpret_cast<const int *>(cu_seqlens.data.dptr),
reinterpret_cast<const float *>(freqs.data.dptr), reinterpret_cast<const float *>(freqs.data.dptr),
reinterpret_cast<scalar_t *>(input_grads->data.dptr), reinterpret_cast<scalar_t *>(input_grads->data.dptr), qkv_format,
cp_size, cp_rank, max_s, b, h, d, d2, stride_t, stride_h, interleaved, cp_size, cp_rank, s, b, h, d, d2, stride_s_or_t,
stride_d, o_stride_t, o_stride_h, o_stride_d, stream);); stride_b, stride_h, stride_d, stream););
} }
} // end namespace transformer_engine } // end namespace transformer_engine
void nvte_fused_rope_forward(const NVTETensor input, const NVTETensor freqs, NVTETensor output, void nvte_fused_rope_forward(const NVTETensor input, const NVTETensor cu_seqlens,
const int s, const int b, const int h, const int d, const int d2, const NVTETensor freqs, NVTETensor output,
const int stride_s, const int stride_b, const int stride_h, const NVTE_QKV_Format qkv_format, const bool interleaved,
const int stride_d, const int o_stride_s, const int o_stride_b, const int cp_size, const int cp_rank, const int s, const int b,
const int o_stride_h, const int o_stride_d, cudaStream_t stream) { const int h, const int d, const int d2, const int stride_s_or_t,
const int stride_b, const int stride_h, const int stride_d,
cudaStream_t stream) {
NVTE_API_CALL(nvte_fused_rope_forward); NVTE_API_CALL(nvte_fused_rope_forward);
using namespace transformer_engine; using namespace transformer_engine;
fused_rope_forward(*reinterpret_cast<const Tensor *>(input), fused_rope_forward(*reinterpret_cast<const Tensor *>(input),
*reinterpret_cast<const Tensor *>(cu_seqlens),
*reinterpret_cast<const Tensor *>(freqs), reinterpret_cast<Tensor *>(output), *reinterpret_cast<const Tensor *>(freqs), reinterpret_cast<Tensor *>(output),
s, b, h, d, d2, stride_s, stride_b, stride_h, stride_d, o_stride_s, o_stride_b, qkv_format, interleaved, cp_size, cp_rank, s, b, h, d, d2, stride_s_or_t,
o_stride_h, o_stride_d, stream); stride_b, stride_h, stride_d, stream);
} }
void nvte_fused_rope_backward(const NVTETensor output_grads, const NVTETensor freqs, void nvte_fused_rope_backward(const NVTETensor output_grads, const NVTETensor cu_seqlens,
NVTETensor input_grads, const int s, const int b, const int h, const NVTETensor freqs, NVTETensor input_grads,
const int d, const int d2, const int stride_s, const int stride_b, const NVTE_QKV_Format qkv_format, const bool interleaved,
const int stride_h, const int stride_d, const int o_stride_s, const int cp_size, const int cp_rank, const int s, const int b,
const int o_stride_b, const int o_stride_h, const int o_stride_d, const int h, const int d, const int d2, const int stride_s_or_t,
const int stride_b, const int stride_h, const int stride_d,
cudaStream_t stream) { cudaStream_t stream) {
NVTE_API_CALL(nvte_fused_rope_backward); NVTE_API_CALL(nvte_fused_rope_backward);
using namespace transformer_engine; using namespace transformer_engine;
fused_rope_backward(*reinterpret_cast<const Tensor *>(output_grads), fused_rope_backward(*reinterpret_cast<const Tensor *>(output_grads),
*reinterpret_cast<const Tensor *>(freqs),
reinterpret_cast<Tensor *>(input_grads), s, b, h, d, d2, stride_s, stride_b,
stride_h, stride_d, o_stride_s, o_stride_b, o_stride_h, o_stride_d, stream);
}
void nvte_fused_rope_thd_forward(const NVTETensor input, const NVTETensor cu_seqlens,
const NVTETensor freqs, NVTETensor output, const int cp_size,
const int cp_rank, const int max_s, const int b, const int h,
const int d, const int d2, const int stride_t, const int stride_h,
const int stride_d, const int o_stride_t, const int o_stride_h,
const int o_stride_d, cudaStream_t stream) {
NVTE_API_CALL(nvte_fused_rope_thd_forward);
using namespace transformer_engine;
fused_rope_thd_forward(*reinterpret_cast<const Tensor *>(input),
*reinterpret_cast<const Tensor *>(cu_seqlens), *reinterpret_cast<const Tensor *>(cu_seqlens),
*reinterpret_cast<const Tensor *>(freqs), *reinterpret_cast<const Tensor *>(freqs),
reinterpret_cast<Tensor *>(output), cp_size, cp_rank, max_s, b, h, d, d2, reinterpret_cast<Tensor *>(input_grads), qkv_format, interleaved, cp_size,
stride_t, stride_h, stride_d, o_stride_t, o_stride_h, o_stride_d, stream); cp_rank, s, b, h, d, d2, stride_s_or_t, stride_b, stride_h, stride_d, stream);
}
void nvte_fused_rope_thd_backward(const NVTETensor output_grads, const NVTETensor cu_seqlens,
const NVTETensor freqs, NVTETensor input_grads, const int cp_size,
const int cp_rank, const int max_s, const int b, const int h,
const int d, const int d2, const int stride_t, const int stride_h,
const int stride_d, const int o_stride_t, const int o_stride_h,
const int o_stride_d, cudaStream_t stream) {
NVTE_API_CALL(nvte_fused_rope_thd_backward);
using namespace transformer_engine;
fused_rope_thd_backward(
*reinterpret_cast<const Tensor *>(output_grads),
*reinterpret_cast<const Tensor *>(cu_seqlens), *reinterpret_cast<const Tensor *>(freqs),
reinterpret_cast<Tensor *>(input_grads), cp_size, cp_rank, max_s, b, h, d, d2, stride_t,
stride_h, stride_d, o_stride_t, o_stride_h, o_stride_d, stream);
} }
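// Illustrative usage sketch (not part of this commit): assembling the arguments of the
// unified forward entry point for a contiguous BSHD tensor of shape [b, s, h, d] with
// rotary dimension d2.  The NVTETensor handles are assumed to be created elsewhere; for
// non-THD layouts the kernels only check whether cu_seqlens carries a null data pointer.
void launch_rope_bshd_example(NVTETensor input, NVTETensor cu_seqlens, NVTETensor freqs,
                              NVTETensor output, int b, int s, int h, int d, int d2,
                              cudaStream_t stream) {
  const int stride_s_or_t = h * d;       // step between consecutive positions
  const int stride_b = s * h * d;        // step between consecutive batch entries
  const int stride_h = d, stride_d = 1;  // dense head / feature strides
  nvte_fused_rope_forward(input, cu_seqlens, freqs, output, NVTE_QKV_Format::NVTE_BSHD,
                          /*interleaved=*/false, /*cp_size=*/1, /*cp_rank=*/0, s, b, h, d, d2,
                          stride_s_or_t, stride_b, stride_h, stride_d, stream);
}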