Commit 2b05e121 authored by yuguo

Merge commit 'a69692ac' of https://github.com/NVIDIA/TransformerEngine
parents 0fd441c2 a69692ac
......@@ -97,6 +97,8 @@ def _measure_memory_between_forward_and_backward(models, fp8_recipe, cpu_offload
max_mem_used = torch.cuda.memory_allocated() / (1024**2)
torch.cuda.synchronize()
tensor.sum().backward()
return max_mem_used
......@@ -115,6 +117,9 @@ def test_cpu_offload(fp8_recipe, model_key) -> None:
the difference being the size of the FP8 cache that is not offloaded to the CPU.
We also expect this memory consumption to be smaller than in scenario (1).
"""
import gc
gc.collect()
model_cls = model_types[model_key]
models_list = [model_cls() for _ in range(NUM_LAYERS)]
......
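# For orientation, a minimal sketch of the measurement pattern the helper above
# follows (an illustration under assumptions, not the test's exact code): CPU
# offloading shrinks the memory held between forward and backward.
import torch

def peak_mem_between_fwd_and_bwd_sketch(model: torch.nn.Module, x: torch.Tensor) -> float:
    out = model(x)  # forward keeps activations alive for backward
    torch.cuda.synchronize()
    mem_mib = torch.cuda.memory_allocated() / (1024**2)  # MiB held before backward
    out.sum().backward()  # backward consumes and frees the activations
    return mem_mib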
......@@ -88,6 +88,126 @@ def initialize_for_many_scales(
return result
@pytest.mark.skipif(not recipe_available, reason=reason_for_no_recipe)
@pytest.mark.parametrize(
"M, N",
[
# full tile cases
(128, 128),
(256, 256),
(256, 1024),
(1024, 256),
# Padding required cases
(256, 272),
(303, 300),
(305, 256),
# Some larger tiles.
(2000, 2000),
(2048, 2000),
(2000, 1024),
(2048, 1024),
],
)
@pytest.mark.parametrize("x_dtype", [torch.float32, torch.bfloat16], ids=str)
@pytest.mark.parametrize("quant_dtype", [torch.float8_e4m3fn, torch.float8_e5m2], ids=str)
@pytest.mark.parametrize("eps", [0], ids=["eps_0"])
@pytest.mark.parametrize("pow_2_scales", [True], ids=["pow2scales"])
def test_quantization_1D_block_tiling_with_compact_data_and_scales(
x_dtype: torch.dtype,
M: int,
N: int,
quant_dtype: torch.dtype,
eps: float,
pow_2_scales: bool,
) -> None:
te_dtype = TE_DType[quant_dtype]
tile_size = (1, 128)
# This test compares the reference implementation against the class that uses
# CUDA kernels to quantize. The two should quantize identically for all
# elements that are not don't-care (DC) values in the scale factor shape.
ref_quantizer = BlockwiseQuantizerReference()
sut_quantizer = Float8BlockQuantizer(
fp8_dtype=te_dtype,
rowwise=True,
columnwise=True,
amax_epsilon=eps,
force_pow_2_scales=pow_2_scales,
block_scaling_dim=1,
all_gather_usage=True,
)
# Setup device and random seed
device = "cuda"
seed = 0
torch.manual_seed(seed)
torch.cuda.manual_seed(seed)
# Input
x = initialize_for_many_scales((M, N), tile_size, dtype=x_dtype, device=device)
x_fp8_sut = sut_quantizer.make_empty((M, N), dtype=x_dtype, device=device, requires_grad=False)
x_fp8_sut = sut_quantizer.update_quantized(x, x_fp8_sut)
x_fp8_sut_cpp_alloc = sut_quantizer(x)
assert x_fp8_sut._rowwise_data is not None
qx: torch.Tensor = x_fp8_sut._rowwise_data.view(dtype=quant_dtype)
assert x_fp8_sut._rowwise_scale_inv is not None
sx: torch.Tensor = x_fp8_sut._rowwise_scale_inv
qx_t = x_fp8_sut._columnwise_data
sx_t = x_fp8_sut._columnwise_scale_inv
qresult_ref = ref_quantizer.quantize(
x,
quant_dtype=quant_dtype,
return_transpose=True,
eps=eps,
pow_2_scales=pow_2_scales,
quant_tile_shape=tile_size,
munge_scale_shapes=False,
)
qx_ref, sx_ref, qx_t_ref, sx_t_ref = (
qresult_ref.data,
qresult_ref.scale,
qresult_ref.data_t,
qresult_ref.scale_t,
)
# Match the reference quantize-transpose output with the columnwise (non-transposed) method
qx_t_ref = qx_t_ref.transpose(-1, -2).contiguous()
sx_t_ref = sx_t_ref.transpose(-1, -2).contiguous()
# Check
torch.testing.assert_close(qx.float(), qx_ref.float(), atol=0.0, rtol=0.0)
torch.testing.assert_close(sx, sx_ref, atol=0.0, rtol=0.0)
assert qx_t is not None
qx_t = qx_t.view(dtype=quant_dtype)
assert qx_t_ref is not None
assert sx_t is not None
assert sx_t_ref is not None
torch.testing.assert_close(qx_t.float(), qx_t_ref.float(), atol=0.0, rtol=0.0)
torch.testing.assert_close(sx_t, sx_t_ref, atol=0.0, rtol=0.0)
# check that the C++ and Python allocators are equivalent
torch.testing.assert_close(
x_fp8_sut._rowwise_data, x_fp8_sut_cpp_alloc._rowwise_data, atol=0.0, rtol=0.0
)
torch.testing.assert_close(
x_fp8_sut._rowwise_scale_inv, x_fp8_sut_cpp_alloc._rowwise_scale_inv, atol=0.0, rtol=0.0
)
torch.testing.assert_close(
x_fp8_sut._columnwise_data, x_fp8_sut_cpp_alloc._columnwise_data, atol=0.0, rtol=0.0
)
torch.testing.assert_close(
x_fp8_sut._columnwise_scale_inv,
x_fp8_sut_cpp_alloc._columnwise_scale_inv,
atol=0.0,
rtol=0.0,
)
# Check that the data format matches between the C++ and Python allocators
assert x_fp8_sut._data_format == x_fp8_sut_cpp_alloc._data_format
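# For orientation, a pure-PyTorch sketch of what 1D (1, 128) blockwise
# quantization computes (an assumption-level illustration, not the TE kernel):
# each contiguous run of 128 row elements shares one scale derived from its amax.
import torch

def quantize_1x128_sketch(x: torch.Tensor, fp8_max: float = 448.0):
    M, N = x.shape  # assume N divisible by 128 for this sketch (no padding)
    tiles = x.view(M, N // 128, 128)
    amax = tiles.abs().amax(dim=-1, keepdim=True)
    # Per-tile scale; with force_pow_2_scales=True, TE additionally rounds
    # the scale to a power of two.
    scale = fp8_max / amax.clamp(min=1e-12)
    q = (tiles * scale).to(torch.float8_e4m3fn).view(M, N)  # quantized data
    scale_inv = scale.reciprocal().squeeze(-1)  # kept for dequantization
    return q, scale_inv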
def check_quantization_block_tiling_versus_reference(
x_dtype: torch.dtype,
M: int,
......
......@@ -385,7 +385,7 @@ class TestFP8RecipeLinearBase:
)
# recipe1
using_fp8_recipe = recipe1() is not None
using_fp8_recipe = recipe1() != GetRecipes.none()
if using_fp8_recipe:
with fp8_autocast(enabled=True, fp8_recipe=recipe1()):
y_q_ref, dgrad_ref, wgrad_ref, bgrad_ref = self.run_linear(x, w, bias, gradient)
......@@ -393,7 +393,7 @@ class TestFP8RecipeLinearBase:
y_q_ref, dgrad_ref, wgrad_ref, bgrad_ref = self.run_linear(x, w, bias, gradient)
# recipe2
using_fp8_recipe = recipe2 != GetRecipes.none
using_fp8_recipe = recipe2() != GetRecipes.none()
if using_fp8_recipe:
with fp8_autocast(enabled=True, fp8_recipe=recipe2()):
y_q, dgrad, wgrad, bgrad = self.run_linear(x, w, bias, gradient)
......@@ -608,7 +608,7 @@ class TestFP8RecipeLayerNormLinearBase(TestFP8RecipeLinearBase):
)
# recipe1
using_fp8_recipe = recipe1() is not None
using_fp8_recipe = recipe1() != GetRecipes.none()
if using_fp8_recipe:
with fp8_autocast(enabled=True, fp8_recipe=recipe1()):
y_q_ref, ln_out_ref, dgrad_ref, wgrad_ref, bgrad_ref = self.run_layernorm_linear(
......@@ -630,7 +630,7 @@ class TestFP8RecipeLayerNormLinearBase(TestFP8RecipeLinearBase):
)
# recipe2
using_fp8_recipe = recipe2 != GetRecipes.none
using_fp8_recipe = recipe2() != GetRecipes.none()
if using_fp8_recipe:
with fp8_autocast(enabled=True, fp8_recipe=recipe2()):
y_q, ln_out, dgrad, wgrad, bgrad = self.run_layernorm_linear(
......
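# Note on the recipe checks above: recipe1 and recipe2 are factories, so the
# guard must compare what they return, not the function objects themselves.
# A stand-in illustration (none_factory/fp8_factory are hypothetical):
def none_factory():
    return None

def fp8_factory():
    return object()  # stands in for a recipe instance

recipe = fp8_factory
assert recipe != none_factory  # compares function objects only
assert recipe() != none_factory()  # compares the produced recipe against None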
......@@ -176,7 +176,40 @@ class TestFloat8BlockwiseTensor:
)
@pytest.mark.parametrize("block_scaling_dim", [1, 2])
@pytest.mark.parametrize("dq_columnwise", [True, False])
@pytest.mark.parametrize("all_gather_usage", [True, False])
def test_quantize_dequantize_dims(
self,
dims: DimsType,
block_scaling_dim: int,
dq_columnwise: bool,
all_gather_usage: bool,
) -> None:
if all_gather_usage and block_scaling_dim != 1:
pytest.skip("all_gather_usage only implemented for 1D block quantization.")
atol = _tols[tex.DType.kFloat8E4M3]["atol"]
rtol = _tols[tex.DType.kFloat8E4M3]["rtol"]
quantizer = Float8BlockQuantizer(
fp8_dtype=tex.DType.kFloat8E4M3,
rowwise=True,
columnwise=dq_columnwise,
block_scaling_dim=block_scaling_dim,
all_gather_usage=all_gather_usage,
)
self._test_quantize_dequantize(
quantizer=quantizer,
dims=dims,
atol=atol,
rtol=rtol,
dequant_columnwise=dq_columnwise,
)
@pytest.mark.parametrize(
"dims", [[], 256, 311, [264], [256, 512], [250, 500], [7, 5, 3], [2, 3, 5, 3]]
)
@pytest.mark.parametrize("block_scaling_dim", [1, 2])
@pytest.mark.parametrize("dq_columnwise", [True, False])
@pytest.mark.xfail(raises=NotImplementedError)
def test_quantize_dequantize_compact_format(
self, dims: DimsType, block_scaling_dim: int, dq_columnwise: bool
) -> None:
atol = _tols[tex.DType.kFloat8E4M3]["atol"]
......@@ -186,6 +219,7 @@ class TestFloat8BlockwiseTensor:
rowwise=True,
columnwise=dq_columnwise,
block_scaling_dim=block_scaling_dim,
all_gather_usage=True,
)
self._test_quantize_dequantize(
quantizer=quantizer,
......@@ -250,8 +284,13 @@ class TestFloat8BlockwiseTensor:
@pytest.mark.parametrize("dims", [[256, 512], [250, 500]])
@pytest.mark.parametrize("block_scaling_dim", [1, 2])
def test_serialization(self, dims: DimsType, block_scaling_dim: int) -> None:
@pytest.mark.parametrize("all_gather_usage", [True, False])
def test_serialization(
self, dims: DimsType, block_scaling_dim: int, all_gather_usage: bool
) -> None:
"""Test serialization of Float8BlockwiseQTensor"""
if all_gather_usage and block_scaling_dim != 1:
pytest.skip("all_gather_usage only implemented for 1D block quantization.")
device = "cuda"
dtype = torch.bfloat16
x_hp = torch.rand(_to_list(dims), dtype=dtype, device=device)
......@@ -260,6 +299,7 @@ class TestFloat8BlockwiseTensor:
rowwise=True,
columnwise=True,
block_scaling_dim=block_scaling_dim,
all_gather_usage=all_gather_usage,
)
# Create FP8 tensor
......@@ -283,6 +323,7 @@ class TestFloat8BlockwiseTensor:
assert x_fp8_loaded._is_2D_scaled == x_fp8._is_2D_scaled
assert x_fp8_loaded.dtype == x_fp8.dtype
assert x_fp8_loaded._fp8_dtype == x_fp8._fp8_dtype
assert x_fp8_loaded._data_format == x_fp8._data_format
# Test that dequantized values match
x_fp8_dequant = x_fp8.dequantize()
......
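# A minimal sketch of the round-trip this serialization test exercises
# (assumed pattern; the real test saves a Float8BlockwiseQTensor):
import io
import torch

def roundtrip_sketch(t: torch.Tensor) -> torch.Tensor:
    buf = io.BytesIO()
    torch.save(t, buf)  # persists data, scale-inverse tensors, and metadata
    buf.seek(0)
    return torch.load(buf)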
......@@ -7,6 +7,8 @@ from __future__ import annotations
from collections.abc import Iterable
import io
import math
import pathlib
import sys
from typing import Optional
import pytest
......@@ -25,10 +27,20 @@ from transformer_engine.pytorch.ops.fused import (
ForwardLinearBiasAdd,
)
from transformer_engine.pytorch.tensor import QuantizedTensor
from transformer_engine.pytorch.tensor.float8_tensor import Float8Tensor, Float8Quantizer
from transformer_engine.pytorch.tensor.float8_tensor import (
Float8Tensor,
Float8CurrentScalingQuantizer,
Float8Quantizer,
)
from transformer_engine.pytorch.tensor.mxfp8_tensor import MXFP8Tensor, MXFP8Quantizer
from transformer_engine.pytorch.utils import is_bf16_compatible
import transformer_engine_torch as tex
# Import utility functions
_current_file = pathlib.Path(__file__).resolve()
sys.path.append(str(_current_file.parent))
from utils import dtype_tols, make_recipe
if IS_HIP_EXTENSION:
import os
from functools import cache
......@@ -49,6 +61,13 @@ if is_bf16_compatible(): # bf16 requires sm_80 or higher
# Supported devices
_devices: list[torch.device] = [torch.device("cpu"), torch.device("cuda")]
# Supported quantization recipes
_quantization_list: list[Optional[str]] = [None]
if fp8_available:
_quantization_list.extend(("fp8_delayed_scaling", "fp8_current_scaling"))
if mxfp8_available:
_quantization_list.append("mxfp8")
def maybe_skip_quantization(
quantization: Optional[str],
......@@ -56,13 +75,14 @@ def maybe_skip_quantization(
dims: Optional[Iterable[int] | int] = None,
device: Optional[torch.device | str] = None,
) -> None:
"""Skip test case if a quantization scheme is not supported"""
# Don't skip if there is no quantization
if quantization is None:
return
# Check if quantization scheme is supported
if quantization == "fp8" and not fp8_available:
if quantization in ("fp8", "fp8_delayed_scaling", "fp8_current_scaling") and not fp8_available:
pytest.skip(reason_for_no_fp8)
if quantization == "mxfp8" and not mxfp8_available:
pytest.skip(reason_for_no_mxfp8)
......@@ -70,7 +90,7 @@ def maybe_skip_quantization(
if dims is not None:
if not isinstance(dims, Iterable):
dims = (dims,)
if quantization == "fp8":
if quantization in ("fp8", "fp8_delayed_scaling", "fp8_current_scaling"):
if math.prod(dims[:-1]) % 16 != 0 or dims[-1] % 16 != 0:
pytest.skip("FP8 GEMMs require dims that are divisible by 16")
elif quantization == "mxfp8":
......@@ -82,47 +102,15 @@ def maybe_skip_quantization(
pytest.skip("Quantization is only supported on CUDA devices")
def dtype_tols(dtype: torch.dtype | tex.DType) -> dict[str, float]:
"""Estimated numerical error for a datatype
Based on tolerances for torch.testing.assert_close.
"""
# Transformer Engine dtypes
if isinstance(dtype, tex.DType):
if dtype == tex.DType.kFloat8E4M3:
return dict(rtol=0.125, atol=0.0675) # epsilon = 0.0625
if dtype == tex.DType.kFloat8E5M2:
return dict(rtol=0.25, atol=0.125) # epsilon = 0.152
dtype = {
tex.DType.kByte: torch.uint8,
tex.DType.kInt32: torch.int32,
tex.DType.kFloat32: torch.float32,
tex.DType.kFloat16: torch.half,
tex.DType.kBFloat16: torch.bfloat16,
}[dtype]
# PyTorch dtypes
if dtype == torch.float16:
return dict(rtol=1e-3, atol=1e-5)
if dtype == torch.bfloat16:
return dict(rtol=1.6e-2, atol=1e-5)
if dtype == torch.float32:
return dict(rtol=1.3e-6, atol=1e-5)
if dtype == torch.float64:
return dict(rtol=1e-7, atol=1e-7)
raise ValueError(f"Unsupported dtype ({dtype})")
@torch.no_grad()
def make_reference_and_test_tensors(
shape: int | Iterable[int],
quantization: Optional[str] = None,
ref_dtype: torch.dtype = torch.float64,
ref_device: torch.device = "cpu",
test_dtype: torch.dtype = torch.float32,
test_device: torch.device = "cuda",
test_is_fp8: bool = False,
test_is_quantized: bool = False,
requires_grad: bool = True,
) -> tuple[torch.Tensor, torch.Tensor]:
"""Construct tensors with the same values
......@@ -131,39 +119,49 @@ def make_reference_and_test_tensors(
operations in high precision. The test tensor is intended for use
in Transformer Engine operations.
If a quantization scheme is provided, the tensor values are
quantized so that they are representable.
"""
# Random reference tensor
ref = torch.rand(shape, dtype=ref_dtype, device=ref_device)
# Construct test tensor from reference tensor
test = ref.to(device=test_device, dtype=test_dtype)
if test_is_fp8:
if quantization is None:
if test_is_quantized:
raise ValueError("Quantization scheme not provided")
if test.data_ptr() == ref.data_ptr():
test = test.clone()
elif quantization in ("fp8", "fp8_delayed_scaling"):
quantizer = Float8Quantizer(
scale=torch.ones(1, dtype=torch.float32, device=test_device).squeeze(),
amax=torch.zeros(1, dtype=torch.float32, device=test_device),
fp8_dtype=tex.DType.kFloat8E4M3,
)
test = quantizer(test)
elif test.data_ptr() == ref.data_ptr():
test = test.clone()
elif quantization == "fp8_current_scaling":
quantizer = Float8CurrentScalingQuantizer(
fp8_dtype=tex.DType.kFloat8E4M3,
device=test_device,
)
test = quantizer(test)
elif quantization == "mxfp8":
test = MXFP8Quantizer(fp8_dtype=tex.DType.kFloat8E4M3)(test)
else:
raise ValueError(f"Unsupported quantization scheme ({quantization})")
if isinstance(test, QuantizedTensor) and not test_is_quantized:
test = test.dequantize()
# Make sure reference and test tensors match each other
ref.copy_(test)
ref.requires_grad_(requires_grad)
test.requires_grad_(requires_grad)
return ref, test
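# Typical usage of the helper above (shape and scheme chosen for illustration):
#     x_ref, x_test = make_reference_and_test_tensors(
#         (32, 32),
#         quantization="fp8_delayed_scaling",
#         test_dtype=torch.bfloat16,
#     )
# x_ref is a float64 CPU tensor whose values were rounded so that they are
# exactly representable under the scheme, making it a faithful reference.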
def make_recipe(name: Optional[str] = None) -> Optional[Recipe]:
"""Make recipe for quantization scheme"""
if name is None:
return None
if name == "fp8":
return transformer_engine.common.recipe.DelayedScaling(
fp8_format=transformer_engine.common.recipe.Format.E4M3,
)
if name == "mxfp8":
return transformer_engine.common.recipe.MXFP8BlockScaling(
fp8_format=transformer_engine.common.recipe.Format.E4M3,
)
raise ValueError(f"Unsupported quantization scheme ({name})")
class TestSequential:
"""Tests for sequential container"""
......@@ -373,7 +371,7 @@ class TestFuser:
@pytest.mark.parametrize("init_dtype", _dtypes)
@pytest.mark.parametrize("final_dtype", _dtypes)
@pytest.mark.parametrize("quantization", (None, "fp8", "mxfp8"))
@pytest.mark.parametrize("quantization", _quantization_list)
def test_dtype_cast(
self,
*,
......@@ -386,8 +384,9 @@ class TestFuser:
"""Check dtype cast functions"""
# Skip invalid configurations
maybe_skip_quantization(quantization, device=device)
in_shape = (size, size)
with_quantization = quantization is not None
maybe_skip_quantization(quantization, dims=in_shape, device=device)
# Random data
dtype = torch.float32
......@@ -397,9 +396,9 @@ class TestFuser:
dtype = torch.bfloat16
w_ref, w_test = make_reference_and_test_tensors(
(size, size),
quantization=quantization,
test_dtype=dtype,
test_device=device,
test_is_fp8=with_quantization,
)
# Construct operation
......@@ -421,11 +420,11 @@ class TestFuser:
assert isinstance(op.weight, QuantizedTensor) == with_quantization
assert op.weight.dtype == final_dtype
w_test = op.weight.to(dtype=torch.float64, device="cpu")
torch.testing.assert_close(w_test, w_ref, rtol=0, atol=0)
torch.testing.assert_close(w_test, w_ref, **dtype_tols(dtype))
# Check forward and backward pass
x = torch.zeros(
(size, size),
in_shape,
dtype=init_dtype,
device=device,
requires_grad=True,
......@@ -438,7 +437,7 @@ class TestFuser:
@pytest.mark.parametrize("model_dtype", _dtypes)
@pytest.mark.parametrize("autocast_dtype", _dtypes)
@pytest.mark.parametrize("quantization", (None, "fp8", "mxfp8"))
@pytest.mark.parametrize("quantization", _quantization_list)
def test_pyt_autocast(
self,
*,
......@@ -453,8 +452,9 @@ class TestFuser:
device = torch.device(device)
# Skip invalid configurations
in_shape = (size, size)
quantized_compute = quantization is not None
maybe_skip_quantization(quantization)
maybe_skip_quantization(quantization, dims=in_shape, device=device)
# Construct operation
recipe = make_recipe(quantization)
......@@ -463,7 +463,7 @@ class TestFuser:
# Check forward and backward pass
x = torch.zeros(
(size, size),
in_shape,
dtype=model_dtype,
device=device,
requires_grad=True,
......@@ -501,33 +501,34 @@ class TestBasicOps:
@pytest.mark.parametrize("dtype", _dtypes)
@pytest.mark.parametrize("device", ("cuda", "cpu"))
@pytest.mark.parametrize("fp8", (False, True))
@pytest.mark.parametrize("quantization", _quantization_list)
def test_identity(
self,
*,
in_shape: Iterable[int] = (1,),
in_shape: Iterable[int] = (32, 32),
dtype: torch.dtype,
device: torch.device,
fp8: bool,
quantization: Optional[str],
) -> None:
# Skip invalid configurations
if fp8 and not fp8_available:
pytest.skip(reason_for_no_fp8)
if fp8 and torch.device(device).type != "cuda":
pytest.skip("FP8 is only supported on CUDA devices")
with_quantization = quantization is not None
maybe_skip_quantization(quantization, dims=in_shape, device=device)
# Random data
x_ref, x_test = make_reference_and_test_tensors(
in_shape,
quantization=quantization,
test_dtype=dtype,
test_device=device,
test_is_fp8=fp8,
test_is_quantized=with_quantization,
)
dy_ref, dy_test = make_reference_and_test_tensors(
in_shape,
quantization=quantization,
test_dtype=dtype,
test_device=device,
test_is_quantized=with_quantization,
requires_grad=False,
)
......@@ -563,7 +564,7 @@ class TestBasicOps:
),
)
@pytest.mark.parametrize("dtype", _dtypes)
@pytest.mark.parametrize("fp8", (False, True))
@pytest.mark.parametrize("quantization", (None, "fp8_current_scaling"))
def test_reshape(
self,
*,
......@@ -571,31 +572,32 @@ class TestBasicOps:
dtype: torch.dtype,
device: torch.device = "cuda",
memory_format: torch.memory_format = torch.contiguous_format,
fp8: bool,
quantization: Optional[str],
) -> None:
in_shape, out_shape = shapes
# Skip invalid configurations
if memory_format == torch.channels_last and len(in_shape) != 4:
pytest.skip("torch.channels_last only supports 4D tensors")
if fp8 and not fp8_available:
pytest.skip(reason_for_no_fp8)
if fp8 and torch.device(device).type != "cuda":
pytest.skip("FP8 is only supported on CUDA devices")
maybe_skip_quantization(quantization, device=device)
with_quantization = quantization is not None
# Random data
x_ref, x_test = make_reference_and_test_tensors(
in_shape,
quantization=quantization,
test_dtype=dtype,
test_device=device,
test_is_fp8=fp8,
test_is_quantized=with_quantization,
)
x_test = x_test.contiguous(memory_format=memory_format)
x_test = x_test.detach().requires_grad_()
dy_ref, dy_test = make_reference_and_test_tensors(
x_ref.reshape(out_shape).size(),
quantization=quantization,
test_dtype=dtype,
test_device=device,
test_is_quantized=with_quantization,
requires_grad=False,
)
......@@ -624,10 +626,10 @@ class TestBasicOps:
torch.testing.assert_close(dx_test, x_ref.grad, **tols)
@pytest.mark.parametrize("size", (1, 7, 32))
@pytest.mark.parametrize("in_shape", ((-1,), (1, 3, -1), (2, 3, 4, -1)))
@pytest.mark.parametrize("in_shape", ((-1,), (1, 3, -1), (4, 3, 8, -1)))
@pytest.mark.parametrize("dtype", _dtypes)
@pytest.mark.parametrize("device", _devices)
@pytest.mark.parametrize("fp8", (False, True))
@pytest.mark.parametrize("quantization", _quantization_list)
def test_bias(
self,
*,
......@@ -635,24 +637,23 @@ class TestBasicOps:
in_shape: Iterable[int],
dtype: torch.dtype,
device: torch.device,
fp8: bool,
quantization: Optional[str],
) -> None:
# Make input and bias shapes consistent
in_shape = list(in_shape)[:-1] + [size]
# Skip invalid configurations
if fp8 and not fp8_available:
pytest.skip(reason_for_no_fp8)
if fp8 and torch.device(device).type != "cuda":
pytest.skip("FP8 is only supported on CUDA devices")
with_quantization = quantization is not None
maybe_skip_quantization(quantization, dims=in_shape, device=device)
# Random data
x_ref, x_test = make_reference_and_test_tensors(
in_shape,
quantization=quantization,
test_dtype=dtype,
test_device=device,
test_is_fp8=fp8,
test_is_quantized=with_quantization,
)
b_ref, b_test = make_reference_and_test_tensors(
size,
......@@ -661,8 +662,10 @@ class TestBasicOps:
)
dy_ref, dy_test = make_reference_and_test_tensors(
in_shape,
quantization=quantization,
test_dtype=dtype,
test_device=device,
test_is_quantized=with_quantization,
requires_grad=False,
)
......@@ -687,7 +690,7 @@ class TestBasicOps:
torch.testing.assert_close(dx_test, x_ref.grad, **tols)
torch.testing.assert_close(db_test, b_ref.grad, **tols)
@pytest.mark.parametrize("quantization", ("fp8", "mxfp8"))
@pytest.mark.parametrize("quantization", _quantization_list)
@pytest.mark.parametrize("cast_forward", (False, True))
@pytest.mark.parametrize("cast_backward", (False, True))
def test_quantize(
......@@ -703,25 +706,26 @@ class TestBasicOps:
"""Quantize"""
# Skip invalid configurations
maybe_skip_quantization(quantization)
with_quantization = quantization is not None
maybe_skip_quantization(quantization, device=device)
if quantization == "mxfp8":
maybe_skip_quantization(quantization, dims=in_shape)
# Random data
x_ref, x_test = make_reference_and_test_tensors(
in_shape,
quantization=quantization,
test_dtype=dtype,
test_device=device,
requires_grad=False,
test_is_fp8=True,
requires_grad=True,
)
x_test = x_test.dequantize().requires_grad_()
dy_ref, dy_test = make_reference_and_test_tensors(
in_shape,
quantization=quantization,
test_dtype=dtype,
test_device=device,
requires_grad=False,
test_is_fp8=True,
)
dy_test = dy_test.dequantize()
# Plain PyTorch implementation
y_ref = x_ref
......@@ -730,11 +734,12 @@ class TestBasicOps:
# Implementation with fusible operation
op = te_ops.Quantize(forward=cast_forward, backward=cast_backward)
recipe = make_recipe(quantization)
with te.fp8_autocast(fp8_recipe=recipe):
with te.fp8_autocast(enabled=with_quantization, fp8_recipe=recipe):
y_test = op(x_test)
y_test.backward(dy_test)
# Check tensor types
if with_quantization:
assert isinstance(y_test, QuantizedTensor) == cast_forward
assert isinstance(x_test.grad, QuantizedTensor) == cast_backward
......@@ -771,9 +776,24 @@ class TestBasicOps:
# Skip invalid configurations
maybe_skip_quantization(quantization, dims=in_shape, device=device)
maybe_skip_quantization(quantization, dims=out_shape)
if quantization == "fp8" and quantized_output and not quantized_compute:
quantization_needed = any(
(
quantized_compute,
quantized_input,
quantized_weight,
quantized_output,
quantized_grad_output,
quantized_grad_input,
)
)
if quantization is None and quantization_needed:
pytest.skip("Quantization scheme is not specified")
if quantization is not None and not quantization_needed:
pytest.skip("Quantization scheme is not used")
if quantization in ("fp8", "fp8_delayed_scaling", "fp8_current_scaling"):
if quantized_output and not quantized_compute:
pytest.skip("FP8 output is only supported with FP8 GEMMs")
if quantization == "fp8" and quantized_grad_input and not quantized_compute:
if quantized_grad_input and not quantized_compute:
pytest.skip("FP8 grad input is only supported with FP8 GEMMs")
if quantization == "mxfp8" and quantized_output:
pytest.skip("MXFP8 output is not supported with MXFP8 GEMMs")
......@@ -786,28 +806,25 @@ class TestBasicOps:
# Random data
x_ref, x_test = make_reference_and_test_tensors(
in_shape,
quantization=quantization,
test_dtype=dtype,
test_device=device,
test_is_fp8=(quantized_compute or quantized_input),
test_is_quantized=quantized_input,
)
if isinstance(x_test, QuantizedTensor):
with torch.no_grad():
x_test = x_test.dequantize().requires_grad_()
w_ref, w_test = make_reference_and_test_tensors(
(out_features, in_features),
quantization=quantization,
test_dtype=dtype,
test_device=device,
test_is_fp8=(quantized_compute or quantized_weight),
)
dy_ref, dy_test = make_reference_and_test_tensors(
out_shape,
quantization=quantization,
test_dtype=dtype,
test_device=device,
test_is_fp8=(quantized_compute or quantized_grad_output),
test_is_quantized=quantized_grad_output,
requires_grad=False,
)
if isinstance(dy_test, QuantizedTensor):
dy_test = dy_test.dequantize()
# Plain PyTorch implementation
y_ref = torch.nn.functional.linear(x_ref, w_ref)
......@@ -870,7 +887,7 @@ class TestBasicOps:
@pytest.mark.parametrize("weight_shape", ((64, 32), (3, 5)))
@pytest.mark.parametrize("in_shape", ((-1,), (5, 1, -1), (4, 2, 4, -1)))
@pytest.mark.parametrize("dtype", _dtypes)
@pytest.mark.parametrize("quantization", (None, "fp8", "mxfp8"))
@pytest.mark.parametrize("quantization", _quantization_list)
@pytest.mark.parametrize("accumulate_into_main_grad", (False, True))
def test_basic_linear(
self,
......@@ -892,7 +909,7 @@ class TestBasicOps:
)
@pytest.mark.skipif(not fp8_available, reason=reason_for_no_fp8)
@pytest.mark.parametrize("quantization", ("fp8", "mxfp8"))
@pytest.mark.parametrize("quantization", _quantization_list)
@pytest.mark.parametrize("quantized_compute", (False, True))
@pytest.mark.parametrize("quantized_input", (False, True))
@pytest.mark.parametrize("quantized_weight", (False, True))
......@@ -911,6 +928,8 @@ class TestBasicOps:
quantized_grad_input: bool,
) -> None:
"""GEMM with FP8 inputs and outputs"""
if quantization is None:
pytest.skip("Skipping case without quantization")
self._test_basic_linear(
dtype=torch.bfloat16,
quantization=quantization,
......@@ -923,8 +942,11 @@ class TestBasicOps:
)
@pytest.mark.parametrize("bias", (False, True))
@pytest.mark.parametrize("quantization", (None, "fp8", "mxfp8"))
@pytest.mark.parametrize("quantization", _quantization_list)
@pytest.mark.parametrize("quantized_compute", (False, True))
@pytest.mark.parametrize("quantized_weight", (False, True))
@pytest.mark.parametrize("input_requires_grad", (False, True))
@pytest.mark.parametrize("weight_requires_grad", (False, True))
def test_linear(
self,
*,
......@@ -934,7 +956,10 @@ class TestBasicOps:
dtype: torch.dtype = torch.float32,
device: torch.device = "cuda",
quantization: Optional[str],
quantized_compute: bool,
quantized_weight: bool,
input_requires_grad: bool,
weight_requires_grad: bool,
) -> None:
"""GEMM + bias"""
......@@ -944,25 +969,25 @@ class TestBasicOps:
out_shape = in_shape[:-1] + [out_features]
# Skip invalid configurations
quantized_compute = quantization is not None
maybe_skip_quantization(quantization, dims=in_shape, device=device)
maybe_skip_quantization(quantization, dims=out_shape)
if quantization is None and (quantized_compute or quantized_weight):
pytest.skip("Quantization scheme is not specified")
if quantization is not None and not (quantized_compute or quantized_weight):
pytest.skip("Quantization scheme is not used")
# Random data
x_ref, x_test = make_reference_and_test_tensors(
in_shape,
quantization=quantization,
test_dtype=dtype,
test_device=device,
test_is_fp8=quantized_compute,
)
if isinstance(x_test, QuantizedTensor):
with torch.no_grad():
x_test = x_test.dequantize().requires_grad_()
w_ref, w_test = make_reference_and_test_tensors(
(out_features, in_features),
quantization=quantization,
test_dtype=dtype,
test_device=device,
test_is_fp8=(quantized_compute or quantized_weight),
)
b_ref, b_test = None, None
if bias:
......@@ -973,6 +998,7 @@ class TestBasicOps:
)
dy_ref, dy_test = make_reference_and_test_tensors(
out_shape,
quantization=quantization,
test_dtype=dtype,
test_device=device,
requires_grad=False,
......@@ -998,8 +1024,11 @@ class TestBasicOps:
op.bias.copy_(b_test)
del w_test
del b_test
for param in op.parameters():
param.requires_grad_(requires_grad=weight_requires_grad)
with te.fp8_autocast(enabled=quantized_compute, fp8_recipe=recipe):
y_test = op(x_test)
if input_requires_grad or weight_requires_grad:
y_test.backward(dy_test)
# Expected numerical error
......@@ -1011,10 +1040,12 @@ class TestBasicOps:
# Check results
y_test = y_test.to(dtype=torch.float64, device="cpu")
dx_test = x_test.grad.to(dtype=torch.float64, device="cpu")
dw_test = op.weight.grad.to(dtype=torch.float64, device="cpu")
torch.testing.assert_close(y_test, y_ref, **tols)
if input_requires_grad:
dx_test = x_test.grad.to(dtype=torch.float64, device="cpu")
torch.testing.assert_close(dx_test, x_ref.grad, **tols)
if weight_requires_grad:
dw_test = op.weight.grad.to(dtype=torch.float64, device="cpu")
torch.testing.assert_close(dw_test, w_ref.grad, **tols)
if bias:
db_test = op.bias.grad.to(dtype=torch.float64, device="cpu")
......@@ -1024,7 +1055,7 @@ class TestBasicOps:
@pytest.mark.parametrize("in_shape", ((-1,), (6, 16, -1)))
@pytest.mark.parametrize("dtype", _dtypes)
@pytest.mark.parametrize("zero_centered_gamma", (False, True))
@pytest.mark.parametrize("quantization", (None, "fp8", "mxfp8"))
@pytest.mark.parametrize("quantization", _quantization_list)
def test_layer_norm(
self,
*,
......@@ -1194,7 +1225,7 @@ class TestBasicOps:
@pytest.mark.parametrize("in_shape", ((-1,), (6, 16, -1)))
@pytest.mark.parametrize("dtype", _dtypes)
@pytest.mark.parametrize("zero_centered_gamma", (False, True))
@pytest.mark.parametrize("quantization", (None, "fp8", "mxfp8"))
@pytest.mark.parametrize("quantization", _quantization_list)
def test_rmsnorm(
self,
*,
......@@ -1275,16 +1306,68 @@ class TestBasicOps:
torch.testing.assert_close(dx_test, x_ref.grad, **tols)
torch.testing.assert_close(dw_test, w_ref.grad, **tols)
@pytest.mark.parametrize("in_shape", ((32,), (6, 16, 64), (32, 64)))
@pytest.mark.parametrize("dtype", _dtypes)
def test_l2normalization(
self,
*,
in_shape: Iterable[int],
dtype: torch.dtype,
device: torch.device = "cuda",
eps: float = 1e-6,
) -> None:
"""L2 Normalization"""
# Random data
x_ref, x_test = make_reference_and_test_tensors(
in_shape,
test_dtype=dtype,
test_device=device,
)
dy_ref, dy_test = make_reference_and_test_tensors(
in_shape,
test_dtype=dtype,
test_device=device,
requires_grad=False,
)
# Plain PyTorch implementation
# L2 norm: x / ||x||_2 = x / sqrt(sum(x^2) + eps)
l2_norm_squared = x_ref.pow(2).sum(dim=-1, keepdim=True)
rsqrt_norm = torch.rsqrt(l2_norm_squared + eps)
y_ref = x_ref * rsqrt_norm
y_ref.backward(dy_ref)
# Implementation with fusible operation
op = te_ops.L2Normalization(
eps=eps,
)
y_test = op(x_test)
y_test.backward(dy_test)
# Expected numerical error
tols = dtype_tols(dtype)
# Check results
y_test = y_test.to(dtype=torch.float64, device="cpu")
dx_test = x_test.grad.to(dtype=torch.float64, device="cpu")
torch.testing.assert_close(y_test, y_ref, **tols)
# L2Norm backward pass requires slightly looser atol for bfloat16
if dtype == torch.bfloat16:
tols["atol"] = 2e-3
torch.testing.assert_close(dx_test, x_ref.grad, **tols)
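# Sanity check of the reference backward used above: with
# r = rsqrt(sum(x^2) + eps) and y = x * r, the input gradient is
#     dx = r * dy - r^3 * x * sum(x * dy, dim=-1, keepdim=True)
# A quick autograd verification of this formula (illustrative, float64):
import torch

x = torch.randn(4, 8, dtype=torch.float64, requires_grad=True)
dy = torch.randn(4, 8, dtype=torch.float64)
eps = 1e-6
r = torch.rsqrt(x.pow(2).sum(dim=-1, keepdim=True) + eps)
(x * r).backward(dy)
dx_manual = (r * dy - r.pow(3) * x * (x * dy).sum(dim=-1, keepdim=True)).detach()
torch.testing.assert_close(x.grad, dx_manual)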
@pytest.mark.parametrize("dtype", _dtypes)
@pytest.mark.parametrize("device", ("cuda", "cpu"))
@pytest.mark.parametrize("fp8", (False, True))
@pytest.mark.parametrize("quantization", _quantization_list)
def test_add_in_place(
self,
*,
in_shape: Iterable[int] = (1,),
in_shape: Iterable[int] = (32, 32),
dtype: torch.dtype,
device: torch.device,
fp8: bool,
quantization: Optional[str],
) -> None:
"""Add two tensors
......@@ -1293,28 +1376,30 @@ class TestBasicOps:
"""
# Skip invalid configurations
if fp8 and not fp8_available:
pytest.skip(reason_for_no_fp8)
if fp8 and torch.device(device).type != "cuda":
pytest.skip("FP8 is only supported on CUDA devices")
with_quantization = quantization is not None
maybe_skip_quantization(quantization, dims=in_shape, device=device)
# Random data
x1_ref, x1_test = make_reference_and_test_tensors(
in_shape,
quantization=quantization,
test_dtype=dtype,
test_device=device,
test_is_fp8=fp8,
test_is_quantized=with_quantization,
)
x2_ref, x2_test = make_reference_and_test_tensors(
in_shape,
quantization=quantization,
test_dtype=dtype,
test_device=device,
test_is_fp8=fp8,
test_is_quantized=with_quantization,
)
dy_ref, dy_test = make_reference_and_test_tensors(
in_shape,
quantization=quantization,
test_dtype=dtype,
test_device=device,
test_is_quantized=with_quantization,
requires_grad=False,
)
......@@ -1331,7 +1416,7 @@ class TestBasicOps:
# Check results
tols = dtype_tols(dtype)
if fp8:
if with_quantization:
tols = dtype_tols(x1_test._fp8_dtype)
y_test = y_test.to(dtype=torch.float64, device="cpu")
dx1_test = x1_test.grad.to(dtype=torch.float64, device="cpu")
......@@ -1342,14 +1427,14 @@ class TestBasicOps:
@pytest.mark.parametrize("dtype", _dtypes)
@pytest.mark.parametrize("device", ("cuda", "cpu"))
@pytest.mark.parametrize("fp8", (False, True))
@pytest.mark.parametrize("quantization", _quantization_list)
def test_make_extra_output(
self,
*,
in_shape: Iterable[int] = (1,),
in_shape: Iterable[int] = (32, 32),
dtype: torch.dtype,
device: torch.device,
fp8: bool,
quantization: Optional[str],
) -> None:
"""Output tensor twice
......@@ -1358,28 +1443,31 @@ class TestBasicOps:
"""
# Skip invalid configurations
if fp8 and not fp8_available:
pytest.skip(reason_for_no_fp8)
if fp8 and torch.device(device).type != "cuda":
pytest.skip("FP8 is only supported on CUDA devices")
with_quantization = quantization is not None
maybe_skip_quantization(quantization, dims=in_shape, device=device)
# Random data
x_ref, x_test = make_reference_and_test_tensors(
in_shape,
quantization=quantization,
test_dtype=dtype,
test_device=device,
test_is_fp8=fp8,
test_is_quantized=with_quantization,
)
dy1_ref, dy1_test = make_reference_and_test_tensors(
in_shape,
quantization=quantization,
test_dtype=dtype,
test_device=device,
test_is_quantized=with_quantization,
requires_grad=False,
)
dy2_ref, dy2_test = make_reference_and_test_tensors(
in_shape,
quantization=quantization,
test_dtype=dtype,
test_device=device,
test_is_quantized=with_quantization,
requires_grad=False,
)
......@@ -1405,7 +1493,7 @@ class TestBasicOps:
@pytest.mark.parametrize("activation", ("relu", "gelu", "geglu", "reglu", "swiglu"))
@pytest.mark.parametrize("out_shape", ((37,), (2, 13), (32, 1, 32)))
@pytest.mark.parametrize("dtype", _dtypes)
@pytest.mark.parametrize("quantization", (None, "fp8", "mxfp8"))
@pytest.mark.parametrize("quantization", _quantization_list)
@pytest.mark.parametrize("cache_quantized_input", (False, True))
def test_activation(
self,
......@@ -1428,26 +1516,21 @@ class TestBasicOps:
quantized_compute = quantization is not None
maybe_skip_quantization(quantization, dims=in_shape, device=device)
if cache_quantized_input:
maybe_skip_quantization("fp8", device=device)
maybe_skip_quantization("fp8_current_scaling", device=device)
# Random data
x_ref, x_test = make_reference_and_test_tensors(
in_shape,
quantization="fp8_current_scaling" if cache_quantized_input else None,
test_dtype=dtype,
test_device=device,
test_is_fp8=quantized_compute,
)
dy_ref, dy_test = make_reference_and_test_tensors(
out_shape,
test_dtype=dtype,
test_device=device,
test_is_fp8=quantized_compute,
requires_grad=False,
)
if quantized_compute:
with torch.no_grad():
x_test = x_test.dequantize().requires_grad_()
dy_test = dy_test.dequantize()
# Plain PyTorch implementation
y_ref: torch.Tensor
......@@ -1490,8 +1573,6 @@ class TestBasicOps:
tols = dtype_tols(dtype)
if quantized_compute or cache_quantized_input:
tols = dtype_tols(tex.DType.kFloat8E4M3)
if activation == "relu" and not cache_quantized_input:
tols = {"atol": 0, "rtol": 0}
# Check results
y_test = y_test.to(dtype=torch.float64, device="cpu")
......@@ -1500,7 +1581,7 @@ class TestBasicOps:
torch.testing.assert_close(dx_test, x_ref.grad, **tols)
@pytest.mark.parametrize("dtype", _dtypes)
@pytest.mark.parametrize("quantization", (None, "fp8", "mxfp8"))
@pytest.mark.parametrize("quantization", _quantization_list)
@pytest.mark.parametrize("quantize_forward", (False, True))
@pytest.mark.parametrize("quantize_backward", (False, True))
def test_swiglu(
......@@ -1578,7 +1659,7 @@ class TestFusedOps:
@pytest.mark.parametrize("weight_shape", ((32, 64), (3, 5)))
@pytest.mark.parametrize("in_shape", ((-1,), (1, 7, -1), (8, 2, 10, -1)))
@pytest.mark.parametrize("dtype", _dtypes)
@pytest.mark.parametrize("quantization", (None, "fp8", "mxfp8"))
@pytest.mark.parametrize("quantization", _quantization_list)
@pytest.mark.parametrize("quantized_weight", (False, True))
def test_forward_linear_bias_activation(
self,
......@@ -1610,18 +1691,15 @@ class TestFusedOps:
# Random data
x_ref, x_test = make_reference_and_test_tensors(
in_shape,
quantization=quantization,
test_dtype=dtype,
test_device=device,
test_is_fp8=quantized_compute,
)
if quantized_compute:
with torch.no_grad():
x_test = x_test.dequantize().requires_grad_()
w_ref, w_test = make_reference_and_test_tensors(
(out_features, in_features),
quantization=quantization,
test_dtype=dtype,
test_device=device,
test_is_fp8=(quantized_compute or quantized_weight),
)
b_ref, b_test = None, None
if bias:
......@@ -1632,6 +1710,7 @@ class TestFusedOps:
)
dy_ref, dy_test = make_reference_and_test_tensors(
out_shape,
quantization=quantization,
test_dtype=dtype,
test_device=device,
requires_grad=False,
......@@ -1688,7 +1767,7 @@ class TestFusedOps:
@pytest.mark.parametrize("bias", (False, True))
@pytest.mark.parametrize("dtype", _dtypes)
@pytest.mark.parametrize("quantization", (None, "fp8", "mxfp8"))
@pytest.mark.parametrize("quantization", _quantization_list)
def test_forward_linear_bias_add(
self,
*,
......@@ -1717,18 +1796,15 @@ class TestFusedOps:
# Random data
x1_ref, x1_test = make_reference_and_test_tensors(
in_shape,
quantization=quantization,
test_dtype=dtype,
test_device=device,
test_is_fp8=quantized_compute,
)
if isinstance(x1_test, QuantizedTensor):
with torch.no_grad():
x1_test = x1_test.dequantize().requires_grad_()
w_ref, w_test = make_reference_and_test_tensors(
(out_features, in_features),
quantization=quantization,
test_dtype=dtype,
test_device=device,
test_is_fp8=(quantized_compute or quantized_weight),
)
b_ref, b_test = None, None
if bias:
......@@ -1744,6 +1820,7 @@ class TestFusedOps:
)
dy_ref, dy_test = make_reference_and_test_tensors(
out_shape,
quantization=quantization,
test_dtype=dtype,
test_device=device,
requires_grad=False,
......@@ -1802,7 +1879,7 @@ class TestFusedOps:
torch.testing.assert_close(db_test, b_ref.grad, **tols)
@pytest.mark.parametrize("dtype", _dtypes)
@pytest.mark.parametrize("quantization", (None, "fp8", "mxfp8"))
@pytest.mark.parametrize("quantization", _quantization_list)
def test_backward_linear_add(
self,
*,
......@@ -1830,27 +1907,26 @@ class TestFusedOps:
# Random data
x_ref, x_test = make_reference_and_test_tensors(
in_shape,
quantization=quantization,
test_dtype=dtype,
test_device=device,
test_is_fp8=quantized_compute,
)
if isinstance(x_test, QuantizedTensor):
with torch.no_grad():
x_test = x_test.dequantize().requires_grad_()
w_ref, w_test = make_reference_and_test_tensors(
(out_features, in_features),
quantization=quantization,
test_dtype=dtype,
test_device=device,
test_is_fp8=(quantized_compute or quantized_weight),
)
dy1_ref, dy1_test = make_reference_and_test_tensors(
out_shape,
quantization=quantization,
test_dtype=dtype,
test_device=device,
requires_grad=False,
)
dy2_ref, dy2_test = make_reference_and_test_tensors(
out_shape,
quantization=quantization,
test_dtype=dtype,
test_device=device,
requires_grad=False,
......@@ -1914,7 +1990,7 @@ class TestCheckpointing:
torch.manual_seed(seed)
torch.cuda.manual_seed(seed)
@pytest.mark.parametrize("quantization", (None, "fp8", "mxfp8"))
@pytest.mark.parametrize("quantization", _quantization_list)
@pytest.mark.parametrize("quantized_weight", (False, True))
def test_linear(
self,
......
# Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# See LICENSE for license information.
import pytest
from transformers.configuration_utils import PretrainedConfig
from transformers.modeling_utils import PreTrainedModel
from transformer_engine.pytorch.transformer import TransformerLayer
from transformer_engine.pytorch.utils import is_bf16_compatible
class SimpleTEModel(PreTrainedModel):
config_class = PretrainedConfig
def __init__(self, config: PretrainedConfig):
super().__init__(config)
self.my_layer = TransformerLayer(
hidden_size=320,
num_attention_heads=16,
ffn_hidden_size=1024,
layer_number=None,
)
def forward(self, hidden_states, attention_mask):
return self.my_layer(hidden_states, attention_mask)
def test_save_hf_model(tmp_path):
model = SimpleTEModel(PretrainedConfig())
model.save_pretrained(tmp_path / "simple_te_model")
@pytest.mark.xfail(reason="This test is failing until huggingface/transformers#38155 is merged.")
def test_save_and_load_hf_model(tmp_path):
model = SimpleTEModel(PretrainedConfig())
model.save_pretrained(tmp_path / "simple_te_model")
del model
model = SimpleTEModel.from_pretrained(tmp_path / "simple_te_model")
assert model is not None
......@@ -63,3 +63,62 @@ def test_lazy_compile():
from transformer_engine.pytorch.jit import dgelu_fused_
dgelu_fused_(torch.randn(10, 10), torch.randn(10, 10))
def test_l2normalization_fused():
"""Smoke test for L2Normalization fusion functions."""
from transformer_engine.pytorch.jit import (
l2normalization_fused,
l2normalization_fwd_fused,
l2normalization_backward_fused,
)
# Basic smoke test like other JIT functions
x = torch.randn(10, 128, device="cuda", dtype=torch.float32)
eps = 1e-6
# Test inference version
output_inf = l2normalization_fused(x, eps)
# Test training version with backward
x_train = torch.randn(10, 128, device="cuda", dtype=torch.float32, requires_grad=True)
output_train, rsqrt_norm = l2normalization_fwd_fused(x_train, eps)
grad_output = torch.randn_like(output_train)
grad_input = l2normalization_backward_fused(grad_output, x_train, rsqrt_norm, eps)
def test_l2normalization_fused_correctness():
"""Simple verification that L2Normalization fusion matches reference implementation."""
from transformer_engine.pytorch.jit import (
l2normalization_fwd_fused,
l2normalization_backward_fused,
)
device = "cuda" if torch.cuda.is_available() else "cpu"
x = torch.randn(16, 64, device=device, dtype=torch.float32, requires_grad=True)
eps = 1e-6
# Test fused forward
output_fused, rsqrt_norm = l2normalization_fwd_fused(x, eps)
# Reference implementation
x_ref = x.clone().detach().requires_grad_(True)
x_squared = x_ref.pow(2)
l2_norm_squared = x_squared.sum(dim=-1, keepdim=True)
rsqrt_norm_ref = torch.rsqrt(l2_norm_squared + eps)
output_ref = x_ref * rsqrt_norm_ref
# Check forward pass matches
torch.testing.assert_close(output_fused, output_ref, atol=1e-6, rtol=1e-5)
torch.testing.assert_close(rsqrt_norm, rsqrt_norm_ref, atol=1e-6, rtol=1e-5)
# Test fused backward
grad_output = torch.randn_like(output_fused)
grad_input_fused = l2normalization_backward_fused(grad_output, x, rsqrt_norm, eps)
# Reference backward
output_ref.backward(grad_output)
grad_input_ref = x_ref.grad
# Check backward pass matches
torch.testing.assert_close(grad_input_fused, grad_input_ref, atol=1e-5, rtol=1e-4)
......@@ -106,6 +106,20 @@ all_normalizations = ["LayerNorm", "RMSNorm"]
mask_types = ["causal", "no_mask"]
NVTE_TEST_NVINSPECT_ENABLED = os.environ.get("NVTE_TEST_NVINSPECT_ENABLED", False)
if NVTE_TEST_NVINSPECT_ENABLED:
# The numerics of all the layers should be the same when debug=True.
# We feed them a dummy feature to prevent debug from being switched off,
# which can happen if no feature is active.
import nvdlfw_inspect.api as debug_api
debug_api.initialize(
os.environ["NVTE_TEST_NVINSPECT_CONFIG_FILE"],
feature_dirs=os.environ["NVTE_TEST_NVINSPECT_FEATURE_DIRS"],
)
fp8_recipes = [
recipe.MXFP8BlockScaling(),
recipe.DelayedScaling(),
......@@ -572,6 +586,8 @@ def test_gpt_selective_activation_recompute(dtype, bs, model, fp8, recipe, fp8_m
pytest.skip(reason_for_no_fp8)
if recipe.mxfp8() and not mxfp8_available:
pytest.skip(reason_for_no_mxfp8)
if fp8_model_params and NVTE_TEST_NVINSPECT_ENABLED:
pytest.skip("FP8 parameters are not supported in debug mode.")
if recipe.float8_block_scaling() and not fp8_block_scaling_available:
pytest.skip(reason_for_no_fp8_block_scaling)
......@@ -686,6 +702,8 @@ def test_gpt_full_activation_recompute(
pytest.skip(reason_for_no_fp8)
if recipe.mxfp8() and not mxfp8_available:
pytest.skip(reason_for_no_mxfp8)
if fp8_model_params and NVTE_TEST_NVINSPECT_ENABLED:
pytest.skip("FP8 parameters are not supported in debug mode.")
if recipe.float8_block_scaling() and not fp8_block_scaling_available:
pytest.skip(reason_for_no_fp8_block_scaling)
......@@ -1730,6 +1748,8 @@ def test_grouped_linear_accuracy(
pytest.skip(reason_for_no_fp8)
if fp8 and recipe.mxfp8() and not mxfp8_available:
pytest.skip(reason_for_no_mxfp8)
if fp8_model_params and NVTE_TEST_NVINSPECT_ENABLED:
pytest.skip("FP8 parameters are not supported in debug mode.")
if fp8 and recipe.float8_block_scaling() and not fp8_block_scaling_available:
pytest.skip(reason_for_no_fp8_block_scaling)
......@@ -1934,6 +1954,8 @@ def test_padding_grouped_linear_accuracy(
pytest.skip(reason_for_no_fp8)
if recipe.mxfp8() and not mxfp8_available:
pytest.skip(reason_for_no_mxfp8)
if fp8_model_params and NVTE_TEST_NVINSPECT_ENABLED:
pytest.skip("FP8 parameters are not supported in debug mode.")
if recipe.float8_block_scaling() and not fp8_block_scaling_available:
pytest.skip(reason_for_no_fp8_block_scaling)
......@@ -2049,6 +2071,8 @@ def _test_gpt_e2e_cuda_graph(block, bs, dtype, config, graph):
@pytest.mark.parametrize("bs", batch_sizes)
@pytest.mark.parametrize("model", ["126m"])
def test_gpt_cuda_graph(dtype, bs, model):
if NVTE_TEST_NVINSPECT_ENABLED:
pytest.skip("Cuda Graphs are not supported in debug mode.")
config = model_configs[model]
sigma = 0.023
......@@ -2146,6 +2170,8 @@ def test_gpt_fp8_parameters(dtype, bs, model, recipe):
pytest.skip(reason_for_no_fp8)
if recipe.mxfp8() and not mxfp8_available:
pytest.skip(reason_for_no_mxfp8)
if NVTE_TEST_NVINSPECT_ENABLED:
pytest.skip("FP8 parameters are not supported in debug mode.")
if recipe.float8_block_scaling() and not fp8_block_scaling_available:
pytest.skip(reason_for_no_fp8_block_scaling)
......
# Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# See LICENSE for license information.
from transformer_engine.pytorch import MultiheadAttention
import pytest
import torch
@pytest.mark.parametrize("use_qk_norm", [False, True])
@pytest.mark.parametrize("attention_type", ["self", "cross"])
@pytest.mark.parametrize("qk_norm_eps", [1e-6, 1e-5])
def test_qk_norm_functionality(use_qk_norm, attention_type, qk_norm_eps) -> None:
"""Test QK normalization functionality, module structure, and numerical behavior."""
hidden_size = 256
num_attention_heads = 8
seq_len = 128
# Create MultiheadAttention module
mha = MultiheadAttention(
hidden_size=hidden_size,
num_attention_heads=num_attention_heads,
attention_type=attention_type,
use_qk_norm=use_qk_norm,
qk_norm_eps=qk_norm_eps,
bias=False,
device="cuda",
).cuda()
# Check module structure based on use_qk_norm parameter
if use_qk_norm:
assert hasattr(mha, "qk_norm"), "Should have qk_norm module when use_qk_norm=True"
assert not hasattr(mha, "q_l2norm"), "Should not have separate q_l2norm module"
assert not hasattr(mha, "k_l2norm"), "Should not have separate k_l2norm module"
# Check that the module is L2Norm type
from transformer_engine.pytorch.ops.basic.l2normalization import L2Normalization
assert isinstance(
mha.qk_norm, L2Normalization
), "qk_norm should be an L2Normalization module"
else:
assert not hasattr(mha, "qk_norm"), "Should not have qk_norm module when use_qk_norm=False"
# Create input tensors
batch_size = 2 # Use a fixed batch size for testing
hidden_states = torch.randn(
seq_len, batch_size, hidden_size, device="cuda", dtype=torch.float32
)
if attention_type == "cross":
encoder_output = torch.randn(
seq_len, batch_size, hidden_size, device="cuda", dtype=torch.float32
)
else:
encoder_output = None
# Test forward pass
with torch.no_grad():
if attention_type == "cross":
output = mha(hidden_states, encoder_output=encoder_output)
else:
output = mha(hidden_states)
# Check output shape and numerical properties
assert output.shape == (
seq_len,
batch_size,
hidden_size,
), f"Output shape mismatch: {output.shape}"
assert not torch.isnan(output).any(), "Output contains NaN"
assert not torch.isinf(output).any(), "Output contains Inf"
# Test with RoPE (if self-attention)
if attention_type == "self":
head_dim = hidden_size // num_attention_heads
rotary_dim = head_dim // 2
rotary_pos_emb = torch.randn(seq_len, 1, 1, rotary_dim, device="cuda", dtype=torch.float32)
with torch.no_grad():
output_with_rope = mha(hidden_states, rotary_pos_emb=rotary_pos_emb)
assert output_with_rope.shape == (
seq_len,
batch_size,
hidden_size,
), "Output shape with RoPE mismatch"
assert not torch.isnan(output_with_rope).any(), "RoPE output contains NaN"
assert not torch.isinf(output_with_rope).any(), "RoPE output contains Inf"
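# Conceptually, QK normalization L2-normalizes queries and keys over the head
# dimension before attention; a sketch under assumptions (not TE's exact code
# path), consistent with the L2Normalization module asserted above:
import torch

def qk_norm_sketch(q: torch.Tensor, k: torch.Tensor, eps: float = 1e-6):
    q = q * torch.rsqrt(q.pow(2).sum(dim=-1, keepdim=True) + eps)
    k = k * torch.rsqrt(k.pow(2).sum(dim=-1, keepdim=True) + eps)
    return q, k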
def test_qk_norm_output_difference() -> None:
"""Test that QK normalization actually changes the output compared to no normalization."""
hidden_size = 256
num_attention_heads = 8
seq_len = 128
batch_size = 2
# Use same random seed to ensure identical weight initialization
current_rng_state = torch.get_rng_state()
current_cuda_rng_state = torch.cuda.get_rng_state()
# Reset to a known seed for reproducible initialization
torch.manual_seed(42)
torch.cuda.manual_seed(42)
# Create model with QK normalization
mha_with_norm = MultiheadAttention(
hidden_size=hidden_size,
num_attention_heads=num_attention_heads,
use_qk_norm=True,
bias=False,
device="cuda",
).cuda()
# Reset to same seed for identical initialization
torch.manual_seed(42)
torch.cuda.manual_seed(42)
# Create identical model without QK normalization
mha_no_norm = MultiheadAttention(
hidden_size=hidden_size,
num_attention_heads=num_attention_heads,
use_qk_norm=False,
bias=False,
device="cuda",
).cuda()
# Create input tensors
hidden_states = torch.randn(
seq_len, batch_size, hidden_size, device="cuda", dtype=torch.float32
)
# Compare outputs with identical weights but different QK norm settings
with torch.no_grad():
output_with_norm = mha_with_norm(hidden_states)
output_no_norm = mha_no_norm(hidden_states)
# Outputs should be different when QK normalization is enabled
assert not torch.allclose(
output_with_norm, output_no_norm, atol=1e-6
), "QK normalization should change the output, but outputs are identical"
def test_qk_norm_with_fused_qkv() -> None:
"""Test QK normalization works with fused QKV parameters."""
hidden_size = 256
num_attention_heads = 8
seq_len = 64
mha = MultiheadAttention(
hidden_size=hidden_size,
num_attention_heads=num_attention_heads,
fuse_qkv_params=True,
use_qk_norm=True,
bias=False,
device="cuda",
).cuda()
# Create input and test forward pass
batch_size = 2 # Use a fixed batch size for testing
hidden_states = torch.randn(
seq_len, batch_size, hidden_size, device="cuda", dtype=torch.float32
)
with torch.no_grad():
output = mha(hidden_states)
assert output.shape == (
seq_len,
batch_size,
hidden_size,
), f"Output shape mismatch: {output.shape}"
def test_qk_norm_transformer_layer_output_difference() -> None:
"""Test that QK normalization actually changes TransformerLayer output compared to no normalization."""
from transformer_engine.pytorch import TransformerLayer
hidden_size = 256
ffn_hidden_size = 1024
num_attention_heads = 8
seq_len = 128
batch_size = 2
# Use same random seed to ensure identical weight initialization
current_rng_state = torch.get_rng_state()
current_cuda_rng_state = torch.cuda.get_rng_state()
# Reset to a known seed for reproducible initialization
torch.manual_seed(42)
torch.cuda.manual_seed(42)
# Create TransformerLayer with QK normalization
transformer_with_norm = TransformerLayer(
hidden_size=hidden_size,
ffn_hidden_size=ffn_hidden_size,
num_attention_heads=num_attention_heads,
use_qk_norm=True,
bias=False,
device="cuda",
).cuda()
# Reset to same seed for identical initialization
torch.manual_seed(42)
torch.cuda.manual_seed(42)
# Create identical TransformerLayer without QK normalization
transformer_no_norm = TransformerLayer(
hidden_size=hidden_size,
ffn_hidden_size=ffn_hidden_size,
num_attention_heads=num_attention_heads,
use_qk_norm=False,
bias=False,
device="cuda",
).cuda()
# Create input tensors
hidden_states = torch.randn(
seq_len, batch_size, hidden_size, device="cuda", dtype=torch.float32
)
# Compare outputs with identical weights but different QK norm settings
with torch.no_grad():
output_with_norm = transformer_with_norm(hidden_states)
output_no_norm = transformer_no_norm(hidden_states)
# Outputs should be different when QK normalization is enabled
assert not torch.allclose(
output_with_norm, output_no_norm, atol=1e-6
), "QK normalization should change the TransformerLayer output, but outputs are identical"
# Check that outputs have expected shapes and properties
assert output_with_norm.shape == (
seq_len,
batch_size,
hidden_size,
), f"Output shape mismatch: {output_with_norm.shape}"
assert not torch.isnan(output_with_norm).any(), "Output with QK norm contains NaN"
assert not torch.isinf(output_with_norm).any(), "Output with QK norm contains Inf"
assert not torch.isnan(output_no_norm).any(), "Output without QK norm contains NaN"
assert not torch.isinf(output_no_norm).any(), "Output without QK norm contains Inf"
......@@ -6,22 +6,32 @@ from typing import Iterable, Optional
import pytest
import torch
import warnings
from torch.utils.cpp_extension import IS_HIP_EXTENSION
import transformer_engine.common.recipe
import transformer_engine.pytorch as te
from transformer_engine.pytorch.tensor.float8_blockwise_tensor import Float8BlockQuantizer
from transformer_engine.pytorch.tensor.mxfp8_tensor import MXFP8Quantizer
import transformer_engine_torch as tex
from transformer_engine.pytorch.fp8 import (
FP8GlobalStateManager,
_amax_and_scale_update,
get_default_fp8_recipe,
fp8_model_init,
)
from transformer_engine.pytorch.tensor.float8_tensor import Float8Quantizer
import transformer_engine.pytorch.ops as te_ops
from transformer_engine.pytorch import Linear, LayerNormLinear, LayerNormMLP, GroupedLinear
from transformer_engine.pytorch.distributed import fp8_autocast
from transformer_engine.common.recipe import DelayedScaling, Float8BlockScaling, MXFP8BlockScaling
import transformer_engine_torch as tex
# Check if FP8 is supported
fp8_available, reason_for_no_fp8 = FP8GlobalStateManager.is_fp8_available()
mxfp8_available, reason_for_no_mxfp8 = FP8GlobalStateManager.is_mxfp8_available()
fp8_block_scaling_available, reason_for_no_fp8_block_scaling = (
FP8GlobalStateManager.is_fp8_block_scaling_available()
)
# FP8 per tensor delayed scaling
......@@ -368,3 +378,127 @@ class TestFP8Recipe:
)
torch.testing.assert_close(fp8_meta[forward_key].scale, expected_scale)
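# For orientation, a sketch of the delayed-scaling update that expected_scale
# models (an assumption-level illustration of what _amax_and_scale_update
# maintains, using the "max" amax reduction):
import torch

def delayed_scaling_update_sketch(
    amax_history: torch.Tensor, fp8_max: float = 448.0, margin: int = 0
) -> torch.Tensor:
    amax = amax_history.max()  # reduce over the amax history window
    return (fp8_max / amax) / (2.0**margin)  # next iteration's scale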
@pytest.mark.parametrize(
"model_init_recipe",
[
pytest.param(
MXFP8BlockScaling(),
marks=pytest.mark.skipif(not mxfp8_available, reason=reason_for_no_mxfp8),
),
pytest.param(
Float8BlockScaling(),
marks=pytest.mark.skipif(
not fp8_block_scaling_available, reason=reason_for_no_fp8_block_scaling
),
),
],
)
def test_check_for_weight_tensor_and_recipe_correspondence(self, model_init_recipe):
with fp8_model_init(enabled=True, recipe=model_init_recipe):
linear = Linear(32, 32).cuda()
x = torch.randn(32, 32, device="cuda")
with fp8_autocast(enabled=True, fp8_recipe=DelayedScaling()):
with pytest.raises(RuntimeError) as excinfo:
_ = linear(x)
assert "Recipe mismatch for " in str(excinfo.value)
@pytest.mark.parametrize(
"target_recipe_class, expected_quantizer_type, available_flag, reason",
[
pytest.param(
MXFP8BlockScaling,
MXFP8Quantizer,
mxfp8_available,
reason_for_no_mxfp8,
id="DelayedScaling->MXFP8BlockScaling",
),
pytest.param(
Float8BlockScaling,
Float8BlockQuantizer,
fp8_block_scaling_available,
reason_for_no_fp8_block_scaling,
id="DelayedScaling->Float8BlockScaling",
),
],
)
def test_dynamic_recipe_update(
self, target_recipe_class, expected_quantizer_type, available_flag, reason
):
if not available_flag:
pytest.skip(reason)
in_features = 32
out_features = 32
batch_size = 32
linear = Linear(in_features, out_features).cuda()
initial_recipe = DelayedScaling()
# Run initial iterations with DelayedScaling
for _ in range(3):
x = torch.randn(batch_size, in_features, device="cuda")
with fp8_autocast(enabled=True, fp8_recipe=initial_recipe):
y = linear(x)
loss = y.mean()
loss.backward()
for quantizer in linear.quantizers["scaling_fwd"]:
assert isinstance(quantizer, Float8Quantizer)
# Change recipe
target_recipe = target_recipe_class()
# Run subsequent iterations with the target recipe
for i in range(3):
x = torch.randn(batch_size, in_features, device="cuda")
if i == 0:
# Expect a warning on the first iteration with the new recipe
with pytest.warns(UserWarning, match="Recipe type changed"):
with fp8_autocast(enabled=True, fp8_recipe=target_recipe):
y = linear(x)
for quantizer in linear.quantizers["scaling_fwd"]:
assert isinstance(quantizer, expected_quantizer_type)
else:
# No warning expected on subsequent iterations
with warnings.catch_warnings():
warnings.simplefilter("error") # Raise error if unexpected warning occurs
with fp8_autocast(enabled=True, fp8_recipe=target_recipe):
y = linear(x)
loss = y.mean()
loss.backward()
# Final check
for quantizer in linear.quantizers["scaling_fwd"]:
assert isinstance(quantizer, expected_quantizer_type)
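# Illustrative sketch (not part of this diff) of the behavior pinned down above,
# using only names already imported in this file: the first forward pass under a
# new recipe type warns and swaps the module's quantizers in place.
#
#   with fp8_autocast(enabled=True, fp8_recipe=DelayedScaling()):
#       y = linear(x)  # quantizers are Float8Quantizer
#   with fp8_autocast(enabled=True, fp8_recipe=MXFP8BlockScaling()):
#       y = linear(x)  # warns "Recipe type changed"; quantizers become MXFP8Quantizer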
@pytest.mark.parametrize(
"module_class",
[
Linear,
LayerNormLinear,
LayerNormMLP,
GroupedLinear,
],
)
def test_quantizer_update(self, module_class):
in_features = 32
out_features = 32
batch_size = 32
recipe = DelayedScaling(amax_history_len=1024)
with fp8_model_init(recipe=recipe):
if module_class == GroupedLinear:
module = module_class(1, in_features, out_features).cuda()
else:
module = module_class(in_features, out_features).cuda()
x = torch.randn(batch_size, in_features, device="cuda")
recipe = DelayedScaling(amax_history_len=1)
with fp8_autocast(enabled=True, fp8_recipe=recipe):
warn_msg = "Quantizer is being updated, this may affect model behavior"
with pytest.warns(UserWarning, match=warn_msg):
if module_class == GroupedLinear:
y = module(x, [batch_size])
else:
y = module(x)
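# Note (inferred from this test's setup, not stated elsewhere in the diff): the
# warning fires because the module was initialized under amax_history_len=1024 but
# is run under a recipe with amax_history_len=1, so its quantizers must be updated.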
......@@ -11,6 +11,7 @@ import pytest
import os
from torch.utils.cpp_extension import IS_HIP_EXTENSION
import transformer_engine.pytorch
from transformer_engine.pytorch.fp8 import (
fp8_autocast,
FP8GlobalStateManager,
......@@ -39,9 +40,11 @@ from transformer_engine.pytorch.cpp_extensions import general_gemm
from transformer_engine.pytorch.module.base import get_workspace
from transformer_engine.pytorch.tensor import QuantizedTensor
from transformer_engine.pytorch.tensor.float8_tensor import (
Float8Quantizer,
Float8CurrentScalingQuantizer,
Float8Quantizer,
Float8Tensor,
)
from transformer_engine.pytorch.tensor.mxfp8_tensor import MXFP8Tensor
from transformer_engine.pytorch.tensor.utils import replace_raw_data
from transformer_engine.pytorch.distributed import checkpoint
from test_numerics import reset_rng_states, dtype_tols
......@@ -1349,3 +1352,80 @@ def test_sanity_checkpointing_on_callables():
# Assert that gradients are the same
torch.testing.assert_close(grad_checkpoint, grad_standard)
@pytest.mark.parametrize(
"module_name",
("Linear", "LayerNormLinear", "LayerNormMLP", "GroupedLinear", "ops.Linear"),
)
@pytest.mark.parametrize(
"quantization",
(None, "fp8_delayed_scaling", "fp8_current_scaling", "mxfp8"),
)
def test_inference_mode(
module_name: str,
quantization: Optional[str],
) -> None:
"""Test heuristics for initializing quantized weights"""
# Tensor dimensions
sequence_length = 32
hidden_size = 32
# Skip invalid configurations
if quantization in ("fp8_delayed_scaling", "fp8_current_scaling") and not fp8_available:
pytest.skip(reason_for_no_fp8)
if quantization == "mxfp8" and not mxfp8_available:
pytest.skip(reason_for_no_mxfp8)
# Construct quantization recipe
with_quantization = quantization not in (None, "None")
quantization_recipe = None
if quantization == "fp8_delayed_scaling":
quantization_recipe = recipe.DelayedScaling()
elif quantization == "fp8_current_scaling":
quantization_recipe = recipe.Float8CurrentScaling()
elif quantization == "mxfp8":
quantization_recipe = recipe.MXFP8BlockScaling()
# Construct module
module = None
with torch.no_grad():
with fp8_model_init(enabled=with_quantization, recipe=quantization_recipe):
if module_name == "Linear":
module = Linear(hidden_size, hidden_size)
elif module_name == "LayerNormLinear":
module = LayerNormLinear(hidden_size, hidden_size)
elif module_name == "LayerNormMLP":
module = LayerNormMLP(hidden_size, hidden_size)
elif module_name == "GroupedLinear":
module = GroupedLinear(1, hidden_size, hidden_size)
elif module_name == "ops.Linear":
module = transformer_engine.pytorch.ops.Linear(hidden_size, hidden_size)
def check_weights():
"""Helper function to check that weight parameters have expected data"""
for param in module.parameters():
if isinstance(param, Float8Tensor):
assert param._data is not None, "Missing FP8 data"
assert (
param._transpose is None and param._transpose_invalid
), "FP8 transpose is not expected for inference"
if isinstance(param, MXFP8Tensor):
assert param._rowwise_data is not None, "Missing row-wise MXFP8 data"
assert (
param._columnwise_data is None
), "Column-wise MXFP8 data is not expected for inference"
# Check that modules have expected weights after initialization
check_weights()
# Check that modules have expected weights after forward pass
with torch.inference_mode():
x = torch.zeros(sequence_length, hidden_size, device="cuda")
kwargs = {}
if module_name == "GroupedLinear":
kwargs["m_splits"] = [sequence_length]
with fp8_autocast(enabled=with_quantization, fp8_recipe=quantization_recipe):
y = module(x, **kwargs)
check_weights()
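# Illustrative sketch (not part of this diff) of the inference pattern exercised
# above; module and recipe names are the ones used in this test:
#
#   with torch.inference_mode():
#       with fp8_autocast(enabled=True, fp8_recipe=recipe.DelayedScaling()):
#           y = module(x)  # weights keep forward data only: no FP8 transpose and
#                          # no column-wise MXFP8 data (GroupedLinear also takes m_splits)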
......@@ -7,6 +7,7 @@ from __future__ import annotations
import torch
import transformer_engine
import transformer_engine.common.recipe
import transformer_engine.pytorch as te
import transformer_engine_torch as tex
......@@ -83,3 +84,24 @@ def dtype_tols(dtype: torch.dtype | tex.DType) -> dict[str, float]:
if dtype == torch.float8_e5m2:
return dict(rtol=0.25, atol=0.125) # epsilon = 0.152
raise ValueError(f"Unsupported dtype ({dtype})")
def make_recipe(name: Optional[str]) -> Optional[Recipe]:
"""Make recipe for quantization scheme"""
if name is None:
return None
if name in ("fp8", "fp8_delayed_scaling"):
return transformer_engine.common.recipe.DelayedScaling(
fp8_format=transformer_engine.common.recipe.Format.E4M3,
)
if name == "fp8_current_scaling":
return transformer_engine.common.recipe.Float8CurrentScaling(
fp8_format=transformer_engine.common.recipe.Format.E4M3,
)
if name == "mxfp8":
return transformer_engine.common.recipe.MXFP8BlockScaling(
fp8_format=transformer_engine.common.recipe.Format.E4M3,
)
if name == "fp8_block_scaling":
return transformer_engine.common.recipe.Float8BlockScaling()
raise ValueError(f"Unsupported quantization scheme ({name})")
......@@ -6,17 +6,62 @@
# pylint: disable=unused-import
import os
from importlib import metadata
import transformer_engine.common
try:
from . import pytorch
except ImportError as e:
except ImportError:
pass
except FileNotFoundError as e:
if "Could not find shared object file" not in str(e):
raise e # Unexpected error
else:
if os.getenv("NVTE_FRAMEWORK"):
frameworks = os.getenv("NVTE_FRAMEWORK").split(",")
if "pytorch" in frameworks or "all" in frameworks:
raise e
else:
# If we got here, we could import `torch` but could not load the framework extension.
# This can happen when a user wants to work only with `transformer_engine.jax` on a system that
# also has a PyTorch installation. In order to enable that use case, we issue a warning here
# about the missing PyTorch extension in case the user hasn't set NVTE_FRAMEWORK.
import warnings
warnings.warn(
"Detected a PyTorch installation but could not find the shared object file for the "
"Transformer Engine PyTorch extension library. If this is not intentional, please "
"reinstall Transformer Engine with `pip install transformer_engine[pytorch]` or "
"build from source with `NVTE_FRAMEWORK=pytorch`.",
category=RuntimeWarning,
)
try:
from . import jax
except ImportError as e:
except ImportError:
pass
except FileNotFoundError as e:
if "Could not find shared object file" not in str(e):
raise e # Unexpected error
else:
if os.getenv("NVTE_FRAMEWORK"):
frameworks = os.getenv("NVTE_FRAMEWORK").split(",")
if "jax" in frameworks or "all" in frameworks:
raise e
else:
# If we got here, we could import `jax` but could not load the framework extension.
# This can happen when a user wants to work only with `transformer_engine.pytorch` on a system
# that also has a Jax installation. In order to enable that use case, we issue a warning here
# about the missing Jax extension in case the user hasn't set NVTE_FRAMEWORK.
import warnings
warnings.warn(
"Detected a Jax installation but could not find the shared object file for the "
"Transformer Engine Jax extension library. If this is not intentional, please "
"reinstall Transformer Engine with `pip install transformer_engine[jax]` or "
"build from source with `NVTE_FRAMEWORK=jax`.",
category=RuntimeWarning,
)
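# Illustrative sessions (hypothetical, not part of this diff) for the NVTE_FRAMEWORK
# handling above: once the variable names a framework (or "all"), a missing shared
# object becomes a hard error instead of the RuntimeWarning.
#
#   NVTE_FRAMEWORK=jax python -c "import transformer_engine"
#   NVTE_FRAMEWORK=pytorch,jax python -c "import transformer_engine"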
__version__ = str(metadata.version("transformer_engine"))
......@@ -30,7 +30,9 @@ endif()
# Language options
if(USE_CUDA)
if(NOT DEFINED CMAKE_CUDA_ARCHITECTURES)
if (CUDAToolkit_VERSION VERSION_GREATER_EQUAL 12.8)
if (CUDAToolkit_VERSION VERSION_GREATER_EQUAL 13.0)
set(CMAKE_CUDA_ARCHITECTURES 75 80 89 90 100 120)
elseif (CUDAToolkit_VERSION VERSION_GREATER_EQUAL 12.8)
set(CMAKE_CUDA_ARCHITECTURES 70 80 89 90 100 120)
else ()
set(CMAKE_CUDA_ARCHITECTURES 70 80 89 90)
......@@ -149,6 +151,7 @@ if(USE_CUDA)
util/cuda_driver.cpp
util/cuda_nvml.cpp
util/cuda_runtime.cpp
util/multi_stream.cpp
util/rtc.cpp
swizzle/swizzle.cu
fused_softmax/scaled_masked_softmax.cu
......@@ -201,6 +204,7 @@ else()
util/cuda_driver.cpp
util/cuda_nvml.cpp
util/cuda_runtime.cpp
util/multi_stream.cpp
util/rtc.cpp
swizzle/swizzle.cu
fused_softmax/scaled_masked_softmax.cu
......
......@@ -4,25 +4,26 @@
"""FW agnostic user-end APIs"""
import sys
import glob
import sysconfig
import subprocess
import ctypes
import functools
import glob
import importlib
from importlib.metadata import version, metadata, PackageNotFoundError
import logging
import os
import platform
import importlib
import functools
from pathlib import Path
from importlib.metadata import version, metadata, PackageNotFoundError
import platform
import subprocess
import sys
import sysconfig
from typing import Optional
_logger = logging.getLogger(__name__)
@functools.lru_cache(maxsize=None)
def _is_pip_package_installed(package):
def _is_pip_package_installed(package) -> bool:
"""Check if the given package is installed via pip."""
# This is needed because we only want to return true
......@@ -37,37 +38,37 @@ def _is_pip_package_installed(package):
@functools.lru_cache(maxsize=None)
def _find_shared_object_in_te_dir(te_path: Path, prefix: str):
def _find_shared_object_in_te_dir(te_path: Path, prefix: str) -> Optional[Path]:
"""
Find a shared object file of given prefix in the top level TE directory.
Only the following locations are searched to avoid stray SOs and build
artifacts:
1. The given top level directory (editable install).
2. `transformer_engine` named directories (source install).
3. `wheel_lib` named directories (PyPI install).
Find a shared object file with the given prefix within the top level TE directory.
The following locations are searched:
1. Top level directory (editable install).
2. `transformer_engine` directory (source install).
3. `wheel_lib` directory (PyPI install).
Returns None if no shared object files are found.
Raises an error if multiple shared object files are found.
"""
# Ensure top level dir exists and has the module. before searching.
if not te_path.exists() or not (te_path / "transformer_engine").exists():
# Ensure top level dir exists and has the module before searching.
if not te_path.is_dir() or not (te_path / "transformer_engine").exists():
return None
files = []
search_paths = (
te_path,
te_path / "transformer_engine",
te_path / "transformer_engine/wheel_lib",
te_path / "wheel_lib",
te_path, # Editable build.
te_path / "transformer_engine", # Regular source build.
te_path / "transformer_engine/wheel_lib", # PyPI.
)
# Search.
for dirname, _, names in os.walk(te_path):
if Path(dirname) in search_paths:
for name in names:
if name.startswith(prefix) and name.endswith(f".{_get_sys_extension()}"):
files.append(Path(dirname, name))
for dir_path in search_paths:
if not dir_path.is_dir():
continue
for file_path in dir_path.iterdir():
if file_path.name.startswith(prefix) and file_path.suffix == _get_sys_extension():
files.append(file_path)
if len(files) == 0:
return None
......@@ -79,16 +80,12 @@ def _find_shared_object_in_te_dir(te_path: Path, prefix: str):
@functools.lru_cache(maxsize=None)
def _get_shared_object_file(library: str) -> Path:
"""
Return the path of the shared object file for the given TE
library, one of 'core', 'torch', or 'jax'.
Several factors affect finding the correct location of the shared object:
1. System and environment.
2. If the installation is from source or via PyPI.
- Source installed .sos are placed in top level dir
- Wheel/PyPI installed .sos are placed in 'wheel_lib' dir to avoid conflicts.
3. For source installations, is the install editable/inplace?
4. The user directory from where TE is being imported.
Path to shared object file for a Transformer Engine library.
TE libraries are 'core', 'torch', or 'jax'. This function first
searches in the imported TE directory, and then in the
site-packages directory.
"""
# Check provided input and determine the correct prefix for .so.
......@@ -98,47 +95,25 @@ def _get_shared_object_file(library: str) -> Path:
else:
so_prefix = f"transformer_engine_{library}"
# Check TE install location (will be local if TE is available in current dir for import).
te_install_dir = Path(importlib.util.find_spec("transformer_engine").origin).parent.parent
so_path_in_install_dir = _find_shared_object_in_te_dir(te_install_dir, so_prefix)
# Search for shared lib in imported directory
te_path = Path(importlib.util.find_spec("transformer_engine").origin).parent.parent
so_path = _find_shared_object_in_te_dir(te_path, so_prefix)
if so_path is not None:
return so_path
# Check default python package install location in system.
site_packages_dir = Path(sysconfig.get_paths()["purelib"])
so_path_in_default_dir = _find_shared_object_in_te_dir(site_packages_dir, so_prefix)
# Search for shared lib in site-packages directory
te_path = Path(sysconfig.get_paths()["purelib"])
so_path = _find_shared_object_in_te_dir(te_path, so_prefix)
if so_path is not None:
return so_path
# Case 1: Typical user workflow: Both locations are the same, return any result.
if te_install_dir == site_packages_dir:
assert (
so_path_in_install_dir is not None
), f"Could not find shared object file for Transformer Engine {library} lib."
return so_path_in_install_dir
# Case 2: ERR! Both locations are different but returned a valid result.
# NOTE: Unlike for source installations, pip does not wipe out artifacts from
# editable builds. In case developers are executing inside a TE directory via
# an inplace build, and then move to a regular build, the local shared object
# file will be incorrectly picked up without the following logic.
if so_path_in_install_dir is not None and so_path_in_default_dir is not None:
raise RuntimeError(
f"Found multiple shared object files: {so_path_in_install_dir} and"
f" {so_path_in_default_dir}. Remove local shared objects installed"
f" here {so_path_in_install_dir} or change the working directory to"
"execute from outside TE."
raise FileNotFoundError(
f"Could not find shared object file for Transformer Engine {library} lib."
)
# Case 3: Typical dev workflow: Editable install
if so_path_in_install_dir is not None:
return so_path_in_install_dir
# Case 4: Executing from inside a TE directory without an inplace build available.
if so_path_in_default_dir is not None:
return so_path_in_default_dir
raise RuntimeError(f"Could not find shared object file for Transformer Engine {library} lib.")
@functools.lru_cache(maxsize=None)
def load_framework_extension(framework: str):
def load_framework_extension(framework: str) -> None:
"""
Load shared library with Transformer Engine framework bindings
and verify its correctness if installed via PyPI.
......@@ -196,19 +171,18 @@ def load_framework_extension(framework: str):
@functools.lru_cache(maxsize=None)
def _get_sys_extension():
def _get_sys_extension() -> str:
"""File extension for shared objects."""
system = platform.system()
if system == "Linux":
extension = "so"
elif system == "Darwin":
extension = "dylib"
elif system == "Windows":
extension = "dll"
else:
return ".so"
if system == "Darwin":
return ".dylib"
if system == "Windows":
return ".dll"
raise RuntimeError(f"Unsupported operating system ({system})")
return extension
@functools.lru_cache(maxsize=None)
def _load_nvidia_cuda_library(lib_name: str):
......@@ -221,7 +195,7 @@ def _load_nvidia_cuda_library(lib_name: str):
so_paths = glob.glob(
os.path.join(
sysconfig.get_path("purelib"),
f"nvidia/{lib_name}/lib/lib*.{_get_sys_extension()}.*[0-9]",
f"nvidia/{lib_name}/lib/lib*{_get_sys_extension()}.*[0-9]",
)
)
......@@ -236,7 +210,7 @@ def _load_nvidia_cuda_library(lib_name: str):
@functools.lru_cache(maxsize=None)
def _nvidia_cudart_include_dir():
def _nvidia_cudart_include_dir() -> str:
"""Returns the include directory for cuda_runtime.h if exists in python environment."""
try:
......@@ -255,14 +229,14 @@ def _load_cudnn():
# Attempt to locate cuDNN in CUDNN_HOME or CUDNN_PATH, if either is set
cudnn_home = os.environ.get("CUDNN_HOME") or os.environ.get("CUDNN_PATH")
if cudnn_home:
libs = glob.glob(f"{cudnn_home}/**/libcudnn.{_get_sys_extension()}*", recursive=True)
libs = glob.glob(f"{cudnn_home}/**/libcudnn{_get_sys_extension()}*", recursive=True)
libs.sort(reverse=True, key=os.path.basename)
if libs:
return ctypes.CDLL(libs[0], mode=ctypes.RTLD_GLOBAL)
# Attempt to locate cuDNN in CUDA_HOME, CUDA_PATH or /usr/local/cuda
cuda_home = os.environ.get("CUDA_HOME") or os.environ.get("CUDA_PATH") or "/usr/local/cuda"
libs = glob.glob(f"{cuda_home}/**/libcudnn.{_get_sys_extension()}*", recursive=True)
libs = glob.glob(f"{cuda_home}/**/libcudnn{_get_sys_extension()}*", recursive=True)
libs.sort(reverse=True, key=os.path.basename)
if libs:
return ctypes.CDLL(libs[0], mode=ctypes.RTLD_GLOBAL)
......@@ -273,7 +247,7 @@ def _load_cudnn():
return handle
# If all else fails, assume that it is in LD_LIBRARY_PATH and error out otherwise
return ctypes.CDLL(f"libcudnn.{_get_sys_extension()}", mode=ctypes.RTLD_GLOBAL)
return ctypes.CDLL(f"libcudnn{_get_sys_extension()}", mode=ctypes.RTLD_GLOBAL)
@functools.lru_cache(maxsize=None)
......@@ -281,7 +255,7 @@ def _load_nvrtc():
"""Load NVRTC shared library."""
# Attempt to locate NVRTC in CUDA_HOME, CUDA_PATH or /usr/local/cuda
cuda_home = os.environ.get("CUDA_HOME") or os.environ.get("CUDA_PATH") or "/usr/local/cuda"
libs = glob.glob(f"{cuda_home}/**/libnvrtc.{_get_sys_extension()}*", recursive=True)
libs = glob.glob(f"{cuda_home}/**/libnvrtc{_get_sys_extension()}*", recursive=True)
libs = list(filter(lambda x: not ("stub" in x or "libnvrtc-builtins" in x), libs))
libs.sort(reverse=True, key=os.path.basename)
if libs:
......@@ -305,7 +279,7 @@ def _load_nvrtc():
return ctypes.CDLL(sos[0], mode=ctypes.RTLD_GLOBAL)
# If all else fails, assume that it is in LD_LIBRARY_PATH and error out otherwise
return ctypes.CDLL(f"libnvrtc.{_get_sys_extension()}", mode=ctypes.RTLD_GLOBAL)
return ctypes.CDLL(f"libnvrtc{_get_sys_extension()}", mode=ctypes.RTLD_GLOBAL)
@functools.lru_cache(maxsize=None)
......
......@@ -248,7 +248,7 @@ TensorWrapper CommOverlapCore::get_tensor_chunk(const TensorWrapper &source, siz
if (param_type == NVTETensorParam::kNVTERowwiseData ||
param_type == NVTETensorParam::kNVTEColumnwiseData) {
// Offset data pointer
param_dptr += chunk_offset * typeToSize(param_dtype);
param_dptr += get_buffer_size_bytes(chunk_offset, param_dtype);
param_shape = chunk_shape;
if (param_type == NVTETensorParam::kNVTEColumnwiseData &&
......@@ -269,7 +269,7 @@ TensorWrapper CommOverlapCore::get_tensor_chunk(const TensorWrapper &source, siz
} else {
chunk_scale_height /= 32;
}
param_dptr += (chunk_offset / 32) * typeToSize(param_dtype);
param_dptr += get_buffer_size_bytes(chunk_offset / 32, param_dtype);
param_shape = {chunk_scale_height, chunk_scale_width};
}
......@@ -288,7 +288,7 @@ TensorWrapper CommOverlapCore::get_buffer_chunk_like(const TensorWrapper &source
auto chunk = get_tensor_chunk(source, chunk_offset, chunk_shape);
// Update chunk with offset data pointers from the communication buffer
auto ubuf_ptr = reinterpret_cast<char *>(_ubuf.dptr()) + (chunk_offset * _ubuf.element_size());
auto ubuf_ptr = reinterpret_cast<char *>(_ubuf.dptr()) + chunk_offset * _ubuf.element_size();
if (chunk.dptr() != nullptr) {
chunk.set_rowwise_data(reinterpret_cast<void *>(ubuf_ptr), chunk.dtype(), chunk.shape());
}
......@@ -326,7 +326,7 @@ CommOverlapBase::CommOverlapBase(const std::vector<size_t> &buffer_shape, DType
"or 2 (multi-atomic).");
NVTE_CHECK(buffer_shape.size() == 2, "Userbuffer shape must be 2-dimensional!");
size_t buffer_bytes = buffer_shape[0] * buffer_shape[1] * typeToSize(buffer_dtype);
size_t buffer_bytes = get_buffer_size_bytes(buffer_shape[0], buffer_shape[1], buffer_dtype);
void *buffer_ptr;
_ub_reg = register_user_buffer_collective(&buffer_ptr, buffer_bytes, _ub_comm, true);
if (_ub_comm->myrank == 0) printf("!!! [UB] Register UBuf %d\n", _ub_reg);
......@@ -398,7 +398,7 @@ void CommOverlapBase::bulk_overlap(const TensorWrapper &A, bool transa, const Te
NVTE_CHECK_CUDA(cudaStreamWaitEvent(_stream_compute[0], _start_comm, 0));
// Communication: AG and RS
int comm_elements = (_ubuf.numel() / 2) * _ubuf.element_size(); // UBUF uses 2Byte element size
int comm_elements = _ubuf.bytes() / 2;  // UBUF uses 2-byte element size
if (comm_type == CommOverlapType::AG) {
allgather2_userbuff_inplace(_ub_reg, 0, comm_elements, _ub_comm, _stream_comm,
(cudaEvent_t)_comm_launch_event);
......@@ -723,7 +723,7 @@ CommOverlapP2PBase::CommOverlapP2PBase(const std::vector<size_t> &buffer_shape,
// Create workspace tensor with userbuffer
NVTE_CHECK(buffer_shape.size() == 2, "Userbuffer shape must be 2-dimensional!");
size_t buffer_bytes = buffer_shape[0] * buffer_shape[1] * typeToSize(buffer_dtype);
size_t buffer_bytes = get_buffer_size_bytes(buffer_shape[0], buffer_shape[1], buffer_dtype);
int buffer_chunk_bytes = buffer_bytes / tp_size;
_num_ubuf_chunks = tp_size;
if (_is_reduce_scatter) {
......@@ -827,7 +827,7 @@ void CommOverlapP2PBase::atomic_gemm_overlap_ag(
assert(pre_gelu_out.numel() == 0);
// Get communication and GEMM output chunk sizes
const int comm_bytes = _ubufs[0].numel() * _ubufs[0].element_size();
const int comm_bytes = _ubufs[0].bytes();
// Create a GEMM output buffer with N+1 chunks in contiguous memory
void *D_buffer_ptr;
......@@ -885,21 +885,20 @@ void CommOverlapP2PBase::atomic_gemm_overlap_ag(
if (B_copy.numel() > 0) {
assert(B_copy.numel() == _ubufs[_self_chunk_id].numel());
assert(B_copy.element_size() == _ubufs[_self_chunk_id].element_size());
NVTE_CHECK_CUDA(
cudaMemcpyAsync(B_copy.dptr(), _ubufs[_self_chunk_id].dptr(),
_ubufs[_self_chunk_id].numel() * _ubufs[_self_chunk_id].element_size(),
cudaMemcpyDeviceToDevice, _stream_send[0]));
NVTE_CHECK_CUDA(cudaMemcpyAsync(B_copy.dptr(), _ubufs[_self_chunk_id].dptr(),
_ubufs[_self_chunk_id].bytes(), cudaMemcpyDeviceToDevice,
_stream_send[0]));
NVTE_CHECK_CUDA(cudaEventRecord(_stop_send, _stream_send[0]));
NVTE_CHECK_CUDA(cudaStreamWaitEvent(stream_main, _stop_send, 0));
}
// Copy the first GEMM output chunk to the end chunk position of D_buffer
char *src_ptr = reinterpret_cast<char *>(D_buffer.dptr());
NVTE_CHECK_CUDA(cudaMemcpyAsync(src_ptr + (D.numel() * D.element_size()), src_ptr, D_chunk_bytes,
NVTE_CHECK_CUDA(cudaMemcpyAsync(src_ptr + D.bytes(), src_ptr, D_chunk_bytes,
cudaMemcpyDeviceToDevice, stream_main));
// Return the last N rows of D_buffer
NVTE_CHECK_CUDA(cudaMemcpyAsync(D.dptr(), src_ptr + D_chunk_bytes, D.numel() * D.element_size(),
NVTE_CHECK_CUDA(cudaMemcpyAsync(D.dptr(), src_ptr + D_chunk_bytes, D.bytes(),
cudaMemcpyDeviceToDevice, stream_main));
// Clean up buffer allocation
......@@ -929,7 +928,7 @@ void CommOverlapP2PBase::split_overlap_ag(const TensorWrapper &A, bool transa,
const size_t n_chunk = _ubufs[0].size(0);
// Get communication and GEMM output chunk sizes
const int comm_bytes = _ubufs[0].numel() * _ubufs[0].element_size();
const int comm_bytes = _ubufs[0].bytes();
const bool do_gelu = pre_gelu_out.numel() > 0;
size_t workspace_size_chunk = workspace.numel() / _stream_compute.size();
......@@ -945,7 +944,7 @@ void CommOverlapP2PBase::split_overlap_ag(const TensorWrapper &A, bool transa,
// Chunk dims
std::vector<size_t> input_b_chunk_shape =
(transb ? std::vector<size_t>{k, 2 * n_chunk} : std::vector<size_t>{2 * n_chunk, k});
std::vector<size_t> output_chunk_shape = {2 * n_chunk, k};
std::vector<size_t> output_chunk_shape = {2 * n_chunk, m};
size_t input_b_chunk_size = 2 * n_chunk * k;
size_t output_chunk_size = 2 * n_chunk * m;
......@@ -976,12 +975,12 @@ void CommOverlapP2PBase::split_overlap_ag(const TensorWrapper &A, bool transa,
// GEMM
auto input_b_chunk =
get_buffer_chunk_like(B, input_b_chunk_size * send_chunk_id, input_b_chunk_shape);
get_buffer_chunk_like(B, input_b_chunk_size * send_chunk_id / 2, input_b_chunk_shape);
auto output_chunk =
get_tensor_chunk(D, output_chunk_size * send_chunk_id, output_chunk_shape);
auto aux_chunk =
(do_gelu)
? get_tensor_chunk(pre_gelu_out, output_chunk_size * send_chunk_id, {n_chunk * 2, k})
get_tensor_chunk(D, output_chunk_size * send_chunk_id / 2, output_chunk_shape);
auto aux_chunk = (do_gelu)
? get_tensor_chunk(pre_gelu_out, output_chunk_size * send_chunk_id / 2,
{n_chunk * 2, k})
: TensorWrapper(nullptr, std::vector<size_t>{0}, pre_gelu_out.dtype());
auto workspace_chunk = get_tensor_chunk(
workspace, (i % _stream_compute.size()) * workspace_size_chunk, {workspace_size_chunk});
......@@ -1012,8 +1011,8 @@ void CommOverlapP2PBase::split_overlap_ag(const TensorWrapper &A, bool transa,
assert(B_copy.numel() == _ubufs[_tp_id].numel());
assert(B_copy.element_size() == _ubufs[_tp_id].element_size());
NVTE_CHECK_CUDA(cudaMemcpyAsync(B_copy.dptr(), _ubufs[_tp_id].dptr(),
_ubufs[_tp_id].numel() * _ubufs[_tp_id].element_size(),
cudaMemcpyDeviceToDevice, _stream_send[0]));
_ubufs[_tp_id].bytes(), cudaMemcpyDeviceToDevice,
_stream_send[0]));
}
}
} else {
......@@ -1072,8 +1071,8 @@ void CommOverlapP2PBase::split_overlap_ag(const TensorWrapper &A, bool transa,
assert(B_copy.numel() == _ubufs[_tp_id].numel());
assert(B_copy.element_size() == _ubufs[_tp_id].element_size());
NVTE_CHECK_CUDA(cudaMemcpyAsync(B_copy.dptr(), _ubufs[_tp_id].dptr(),
_ubufs[_tp_id].numel() * _ubufs[_tp_id].element_size(),
cudaMemcpyDeviceToDevice, _stream_send[0]));
_ubufs[_tp_id].bytes(), cudaMemcpyDeviceToDevice,
_stream_send[0]));
}
}
}
......@@ -1103,7 +1102,7 @@ void CommOverlapP2PBase::atomic_gemm_overlap_rs(
_ub_comm->cga_size = _cga_size;
// Get communication and GEMM input chunk sizes
const int comm_bytes = _ubufs[0].numel() * _ubufs[0].element_size();
const int comm_bytes = _ubufs[0].bytes();
// Reset counters
int *counter_ptr = reinterpret_cast<int *>(_counter.dptr());
......@@ -1170,7 +1169,7 @@ void CommOverlapP2PBase::split_overlap_rs(const TensorWrapper &A, bool transa,
size_t m = transa ? A.size(0) : A.size(1);
size_t k = transa ? A.size(1) : A.size(0);
size_t n_chunk = _ubufs[0].size(0);
const int comm_bytes = _ubufs[0].numel() * _ubufs[0].element_size();
const int comm_bytes = _ubufs[0].bytes();
// Get input and workspace data pointers
size_t input_chunk_size = n_chunk * k;
......
......@@ -248,7 +248,8 @@ int create_communicator_grouped2(communicator **comm, int myrank, int numranks,
CUmemFabricHandle *tmphndl =
reinterpret_cast<CUmemFabricHandle *>(malloc(sizeof(CUmemFabricHandle)));
CUmemFabricHandle *exphndls;
NVTE_CHECK_CUDA(cudaMallocHost(&exphndls, (*comm)->nvsize * sizeof(CUmemFabricHandle)));
NVTE_CHECK_CUDA(cudaMallocHost(reinterpret_cast<void **>(&exphndls),
(*comm)->nvsize * sizeof(CUmemFabricHandle)));
if ((*comm)->ar2_nvrank == 0)
NVTE_CALL_CHECK_CUDA_DRIVER(cuMemExportToShareableHandle, static_cast<void *>(tmphndl),
(*comm)->mc_handle, CU_MEM_HANDLE_TYPE_FABRIC, 0);
......@@ -345,8 +346,10 @@ int create_communicator_grouped2(communicator **comm, int myrank, int numranks,
NVTE_CHECK_CUDA(cudaDeviceSynchronize());
register_user_buffer_collective(&((*comm)->gpu_ptrs), LOCALSIZE, *comm, true);
NVTE_CHECK_CUDA(cudaMalloc(&(*comm)->send_id, (*comm)->nranks * sizeof(int)));
NVTE_CHECK_CUDA(cudaMalloc(&(*comm)->recv_id, NVTE_MAX_REGIONS * (*comm)->nranks * sizeof(int)));
NVTE_CHECK_CUDA(
cudaMalloc(reinterpret_cast<void **>(&(*comm)->send_id), (*comm)->nranks * sizeof(int)));
NVTE_CHECK_CUDA(cudaMalloc(reinterpret_cast<void **>(&(*comm)->recv_id),
NVTE_MAX_REGIONS * (*comm)->nranks * sizeof(int)));
NVTE_CHECK_CUDA(cudaMemset((*comm)->send_id, 0, (*comm)->nranks * sizeof(int)));
NVTE_CHECK_CUDA(
cudaMemset((*comm)->recv_id, 0, NVTE_MAX_REGIONS * (*comm)->nranks * sizeof(int)));
......@@ -358,13 +361,14 @@ int create_communicator_grouped2(communicator **comm, int myrank, int numranks,
#define GPU_PAGE_OFFSET (GPU_PAGE_SIZE - 1)
#define GPU_PAGE_MASK (~GPU_PAGE_OFFSET)
NVTE_CHECK_CUDA(cudaMalloc(&(*comm)->flags, 2 * GPU_PAGE_SIZE));
NVTE_CHECK_CUDA(cudaMemset((*comm)->flags, 0, 2 * GPU_PAGE_SIZE));
(*comm)->flags =
NVTE_CHECK_CUDA(
cudaMalloc(reinterpret_cast<void **>(&(*comm)->flags_baseptr), 2 * GPU_PAGE_SIZE));
NVTE_CHECK_CUDA(cudaMemset((*comm)->flags_baseptr, 0, 2 * GPU_PAGE_SIZE));
(*comm)->flags = reinterpret_cast<int *>(
#ifdef USE_ROCM
reinterpret_cast<int *>((reinterpret_cast<uintptr_t>((*comm)->flags) + GPU_PAGE_SIZE - 1) & GPU_PAGE_MASK);
(reinterpret_cast<uintptr_t>((*comm)->flags_baseptr) + GPU_PAGE_SIZE - 1) & GPU_PAGE_MASK);
#else
reinterpret_cast<int *>(((CUdeviceptr)(*comm)->flags + GPU_PAGE_SIZE - 1) & GPU_PAGE_MASK);
((CUdeviceptr)(*comm)->flags_baseptr + GPU_PAGE_SIZE - 1) & GPU_PAGE_MASK);
#endif
using namespace std;
......@@ -442,20 +446,31 @@ int create_communicator_mpi(communicator **comm) {
}
void destroy_communicator(communicator *comm) {
for (int hndl = 0; hndl < comm->free_region; hndl++) {
// Clear memory allocated in register_user_buffer_collective calls
for (int hndl = comm->free_region - 1; hndl >= 0; hndl--) {
if (comm->use_mc && comm->mem_dealloc[hndl]) {
// Unbind the local device buffer from the Multicast handle
CUdevice dev;
NVTE_CALL_CHECK_CUDA_DRIVER(cuDeviceGet, &dev, comm->mydev);
NVTE_CALL_CHECK_CUDA_DRIVER(cuMulticastUnbind, comm->mc_handle, dev, comm->uc_offsets[hndl],
comm->mem_size[hndl]);
// Unmap memory addresses and release handles for both peer and own buffers
for (int rank = 0; rank < comm->nvsize; rank++) {
if (rank == comm->nvrank) {
NVTE_CALL_CHECK_CUDA_DRIVER(cuMemUnmap,
reinterpret_cast<CUdeviceptr>(comm->peer_ptr[hndl][rank]),
comm->mem_size[hndl]);
NVTE_CALL_CHECK_CUDA_DRIVER(cuMemRelease, comm->uchandles[hndl][rank]);
} else {
comm->uchandles[hndl][rank] = 0;
}
}
free(reinterpret_cast<void *>(comm->uchandles[hndl]));
// Free memory reserved for buffer allocations
NVTE_CALL_CHECK_CUDA_DRIVER(cuMemAddressFree, comm->ucbase_ptr[hndl],
static_cast<size_t>(comm->mem_size[hndl] * comm->nvsize));
} else {
for (int rank = 0; rank < comm->nvsize; rank++) {
if (rank != comm->nvrank) {
cudaIpcCloseMemHandle(comm->peer_ptr[hndl][rank]);
NVTE_CHECK_CUDA(cudaIpcCloseMemHandle(comm->peer_ptr[hndl][rank]));
} else if (comm->mem_dealloc[hndl]) {
NVTE_CHECK_CUDA(cudaFree(comm->peer_ptr[hndl][rank]));
} else {
......@@ -464,11 +479,16 @@ void destroy_communicator(communicator *comm) {
}
}
free(comm->peer_ptr[hndl]);
comm->mem_ptr[hndl] = nullptr;
comm->mem_ptr[hndl] = nullptr;  // pointed to the local device buffer freed above
}
cudaFree(reinterpret_cast<void *>(comm->recv_id));
cudaFree(reinterpret_cast<void *>(comm->send_id));
// Clear memory allocated in the communicator constructor
NVTE_CHECK_CUDA(cudaFree(reinterpret_cast<void *>(comm->recv_id)));
NVTE_CHECK_CUDA(cudaFree(reinterpret_cast<void *>(comm->send_id)));
NVTE_CHECK_CUDA(cudaFree(reinterpret_cast<void *>(comm->flags_baseptr)));
if (comm->use_mc) {
NVTE_CALL_CHECK_CUDA_DRIVER(cuMemUnmap, reinterpret_cast<CUdeviceptr>(comm->mc_baseptr),
comm->mc_maxsize);
NVTE_CALL_CHECK_CUDA_DRIVER(cuMemAddressFree, comm->mc_baseptr, comm->mc_maxsize);
NVTE_CALL_CHECK_CUDA_DRIVER(cuMemRelease, comm->mc_handle);
}
delete comm;
......@@ -535,7 +555,8 @@ int register_user_buffer_collective(void **gpubuff, size_t bytes, communicator *
CUmemFabricHandle myhndl;
NVTE_CALL_CHECK_CUDA_DRIVER(cuMemExportToShareableHandle, &myhndl,
comm->uchandles[hndl][myrank], CU_MEM_HANDLE_TYPE_FABRIC, 0);
NVTE_CHECK_CUDA(cudaMallocHost(&exphndl, comm->nvsize * sizeof(CUmemFabricHandle)));
NVTE_CHECK_CUDA(cudaMallocHost(reinterpret_cast<void **>(&exphndl),
comm->nvsize * sizeof(CUmemFabricHandle)));
comm->_allgather(reinterpret_cast<void *>(exphndl), comm->nvsize * sizeof(CUmemFabricHandle),
reinterpret_cast<void *>(&myhndl), sizeof(CUmemFabricHandle),
comm->comm_intra);
......@@ -619,6 +640,7 @@ int register_user_buffer_collective(void **gpubuff, size_t bytes, communicator *
aligned_size, (uint64_t)0);
comm->memflags[hndl] |= NVTE_UB_MEM_MC_CREATED;
comm->mc_ptr[hndl] = reinterpret_cast<char *>(comm->mc_baseptr) + comm->mc_offset;
comm->uc_offsets[hndl] = comm->mc_offset;
comm->mc_offset += aligned_size;
} else if (!comm->myrank) {
printf("UB: warning region %d size %ld MB registered without MC access\n", hndl,
......
......@@ -111,6 +111,7 @@ struct communicator {
CUmemGenericAllocationHandle *uchandles[NVTE_MAX_REGIONS];
#endif
void *ucbase_ptr[NVTE_MAX_REGIONS]; // only for cuMem allocated memory
size_t uc_offsets[NVTE_MAX_REGIONS];
size_t mem_size[NVTE_MAX_REGIONS];
bool mem_dealloc[NVTE_MAX_REGIONS];
......@@ -133,7 +134,7 @@ struct communicator {
// max value for running block counters in hostflags
int basecounter[userbuffers_op_types]; // NOLINT(*)
int *flags, *map_flags;
int *flags_baseptr, *flags, *map_flags;
void *mem_mr[NVTE_MAX_REGIONS];
......
......@@ -121,13 +121,20 @@ void checkCuDriverContext(CUstream stream) {
#ifndef __HIP_PLATFORM_AMD__
CUtensorMapDataType get_CUtensorMapDataType(DType dtype) {
static const std::unordered_map<DType, CUtensorMapDataType> dtypeMapping = {
static const std::unordered_map<DType, CUtensorMapDataType> dtypeMapping = []() {
std::unordered_map<DType, CUtensorMapDataType> typeMapping = {
{DType::kByte, CUtensorMapDataType::CU_TENSOR_MAP_DATA_TYPE_UINT8},
{DType::kFloat32, CUtensorMapDataType::CU_TENSOR_MAP_DATA_TYPE_FLOAT32},
{DType::kFloat16, CUtensorMapDataType::CU_TENSOR_MAP_DATA_TYPE_FLOAT16},
{DType::kBFloat16, CUtensorMapDataType::CU_TENSOR_MAP_DATA_TYPE_BFLOAT16},
{DType::kFloat8E4M3, CUtensorMapDataType::CU_TENSOR_MAP_DATA_TYPE_UINT8},
{DType::kFloat8E5M2, CUtensorMapDataType::CU_TENSOR_MAP_DATA_TYPE_UINT8}};
#if FP4_TYPE_SUPPORTED
typeMapping.insert(
{DType::kFloat4E2M1, CUtensorMapDataType::CU_TENSOR_MAP_DATA_TYPE_16U4_ALIGN8B});
#endif
return typeMapping;
}();
return dtypeMapping.at(dtype);
}
......@@ -135,18 +142,19 @@ CUtensorMapDataType get_CUtensorMapDataType(DType dtype) {
void create_2D_tensor_map(CUtensorMap &tensorMap, const SimpleTensor &tensor,
const uint64_t globalY, const uint64_t globalX, const uint32_t shmemY,
const uint32_t shmemX, const uint32_t stride_elems,
const uint32_t offset_elems, const size_t type_size) {
const uint32_t offset_elems, const size_t type_num_bits) {
// Get a function pointer to the cuTensorMapEncodeTiled driver API
static PFN_cuTensorMapEncodeTiled cuDriverTensorMapEncodeTiled = []() {
// Note: PFN_cuTensorMapEncodeTiled is not defined in CUDA 13
static PFN_cuTensorMapEncodeTiled_v12000 cuDriverTensorMapEncodeTiled = []() {
void *driver_ptr = cuda_driver::get_symbol("cuTensorMapEncodeTiled");
return reinterpret_cast<PFN_cuTensorMapEncodeTiled>(driver_ptr);
return reinterpret_cast<PFN_cuTensorMapEncodeTiled_v12000>(driver_ptr);
}();
// rank is the number of dimensions of the array
constexpr uint32_t rank = 2;
uint64_t size[rank] = {globalX, globalY};
// The stride is the number of bytes to traverse from the first element of one row to the next
uint64_t stride[rank - 1] = {stride_elems * type_size};
uint64_t stride[rank - 1] = {(stride_elems * type_num_bits) / 8};
// The boxSize is the size of the shared memory buffer that is used as the
// source/destination of a TMA transfer
......@@ -156,15 +164,15 @@ void create_2D_tensor_map(CUtensorMap &tensorMap, const SimpleTensor &tensor,
uint32_t elemStride[rank] = {1, 1};
const CUtensorMapDataType tensorDataType = get_CUtensorMapDataType(tensor.dtype);
void *dataPtr =
reinterpret_cast<void *>(reinterpret_cast<uint8_t *>(tensor.dptr) + offset_elems * type_size);
void *dataPtr = reinterpret_cast<void *>(reinterpret_cast<uint8_t *>(tensor.dptr) +
(offset_elems * type_num_bits) / 8);
NVTE_CHECK(is_aligned_ptr(dataPtr, TMA_gmem_alignment),
"Tensor data pointer must be 16B aligned");
const int TMA_needed_size = TMA_gmem_alignment / type_size;
NVTE_CHECK(globalX % TMA_needed_size == 0, "Shape not supported. For ", type_size,
"-byte data type, expected multiple of ", TMA_needed_size, ", got ", globalX);
const int TMA_needed_size = (TMA_gmem_alignment * 8) / type_num_bits;
NVTE_CHECK(globalX % TMA_needed_size == 0, "Shape not supported. For ", type_num_bits,
"-bit data type, expected multiple of ", TMA_needed_size, ", got ", globalX);
// Create the tensor descriptor.
NVTE_CHECK_CUDA_DRIVER(cuDriverTensorMapEncodeTiled(
......@@ -209,10 +217,24 @@ std::vector<std::vector<Tensor *>> convert_tensor_array(NVTETensor **nvte_tensor
for (size_t i = 0; i < outer_size; ++i) {
ret.emplace_back();
for (size_t j = 0; j < inner_size; ++j) {
ret.back().push_back(reinterpret_cast<Tensor *>(nvte_tensors[i][j]));
ret.back().push_back(convertNVTETensor(nvte_tensors[i][j]));
}
}
return ret;
}
size_t get_buffer_size_bytes(const size_t elements_num, const DType buffer_dtype) {
return (elements_num * typeToNumBits(buffer_dtype)) / 8;
}
size_t get_buffer_size_bytes(const size_t dim_first, const size_t dim_last,
const DType buffer_dtype) {
if (buffer_dtype == DType::kFloat4E2M1) {
NVTE_CHECK(dim_last % 2 == 0,
"Last dimension of a tensor with FP4 type of data must be an even number!");
}
const size_t elements_num = dim_first * dim_last;
return get_buffer_size_bytes(elements_num, buffer_dtype);
}
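// Worked example (illustrative): a [128, 256] tensor of DType::kFloat4E2M1 (4 bits
// per element) occupies (128 * 256 * 4) / 8 = 16384 bytes, half of an 8-bit dtype
// of the same shape; the even-last-dimension check above guarantees that the bit
// count divides evenly into whole bytes.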
} // namespace transformer_engine
......@@ -9,9 +9,15 @@
#ifndef __HIP_PLATFORM_AMD__
#include <cudaTypedefs.h>
#endif
#define FP4_TYPE_SUPPORTED (CUDA_VERSION >= 12080)
#include <cuda_bf16.h>
#include <cuda_fp16.h>
#include <cuda_fp8.h>
#if FP4_TYPE_SUPPORTED
#include <cuda_fp4.h>
#endif
#include <cuda_runtime_api.h>
#include <transformer_engine/transformer_engine.h>
......@@ -90,9 +96,16 @@ struct SimpleTensor {
}
return acc;
}
void clear() {
dptr = nullptr;
shape.resize(0);
dtype = DType::kFloat32;
}
};
struct Tensor {
public:
SimpleTensor data;
SimpleTensor columnwise_data;
SimpleTensor amax;
......@@ -100,8 +113,8 @@ struct Tensor {
SimpleTensor scale_inv;
SimpleTensor columnwise_scale_inv;
public:
NVTEScalingMode scaling_mode;
NVTETensor nvte_tensor;
Tensor()
: data(),
......@@ -110,7 +123,20 @@ struct Tensor {
scale(nullptr, {1}, DType::kFloat32),
scale_inv(nullptr, {1}, DType::kFloat32),
columnwise_scale_inv(nullptr, {1}, DType::kFloat32),
scaling_mode(NVTE_DELAYED_TENSOR_SCALING) {}
scaling_mode(NVTE_DELAYED_TENSOR_SCALING),
nvte_tensor(0) {}
void clear() {
data.clear();
columnwise_data.clear();
amax.clear();
scale.clear();
scale_inv.clear();
columnwise_scale_inv.clear();
scaling_mode = NVTE_DELAYED_TENSOR_SCALING;
}
explicit operator NVTETensor() const noexcept { return nvte_tensor; }
size_t numel() const {
size_t acc = 1;
......@@ -164,6 +190,7 @@ struct Tensor {
}
break;
case NVTE_MXFP8_1D_SCALING:
case NVTE_FWD_NVFP4_BWD_MXFP8_SCALING:
if (!has_data() && has_columnwise_data()) {
return columnwise_data.shape;
} else {
......@@ -233,11 +260,14 @@ struct QuantizationConfig {
bool force_pow_2_scales = false;
float amax_epsilon = 0.0f;
NVTETensor noop_tensor = nullptr;
Float8BlockScaleTensorFormat float8_block_scale_tensor_format =
Float8BlockScaleTensorFormat::GEMM_READY;
static constexpr size_t attr_sizes[] = {
sizeof(bool), // force_pow_2_scales
sizeof(float), // amax_epsilon
sizeof(NVTETensor) // noop_tensor
sizeof(NVTETensor), // noop_tensor
sizeof(Float8BlockScaleTensorFormat) // float8_block_scale_tensor_format
};
};
......@@ -246,6 +276,13 @@ constexpr T DIVUP(const T &x, const T &y) {
return (((x) + ((y)-1)) / (y));
}
template <typename T1, typename T2>
constexpr __device__ __host__ __forceinline__ uint64_t DIVUP_TO_MULTIPLE(const T1 &N, const T2 &M) {
static_assert(std::is_integral<T1>::value && std::is_integral<T2>::value,
"Integral type required.");
return DIVUP(static_cast<uint64_t>(N), static_cast<uint64_t>(M)) * M;
}
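// Worked example (illustrative): DIVUP_TO_MULTIPLE(5, 4) == 8 and
// DIVUP_TO_MULTIPLE(8, 4) == 8, i.e. N rounded up to the nearest multiple of M.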
using byte = uint8_t;
using int16 = int16_t;
using int32 = int32_t;
......@@ -259,8 +296,10 @@ using fp8e5m2 = __nv_fp8_e5m2;
#if CUDA_VERSION >= 12080
using fp8e8m0 = __nv_fp8_e8m0;
#endif
#if FP4_TYPE_SUPPORTED
using fp4e2m1 = __nv_fp4_e2m1;
#endif
using e8m0_t = uint8_t;
using int8 = int8_t;
namespace detail {
......@@ -284,11 +323,21 @@ TRANSFORMER_ENGINE_TYPE_NAME(int8_t)
#if CUDA_VERSION >= 12080
TRANSFORMER_ENGINE_TYPE_NAME(__nv_fp8_e8m0)
#endif
#if FP4_TYPE_SUPPORTED
TRANSFORMER_ENGINE_TYPE_NAME(__nv_fp4_e2m1)
#endif
#undef TRANSFORMER_ENGINE_TYPE_NAME
template <typename T>
struct TypeExtrema;
#if FP4_TYPE_SUPPORTED
template <>
struct TypeExtrema<fp4e2m1> {
static constexpr float max = 6.0f;
};
#endif
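// Note on the FP4 bound above: E2M1 encodes the magnitudes
// {0, 0.5, 1, 1.5, 2, 3, 4, 6}, so the largest finite value is 6.0f.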
template <>
struct TypeExtrema<fp8e4m3> {
static constexpr float max = 448.0f;
......@@ -323,9 +372,28 @@ struct TypeExtrema {
} // namespace detail
template <typename T>
struct BitsNumber;
#if FP4_TYPE_SUPPORTED
template <>
struct BitsNumber<fp4e2m1> {
static constexpr size_t num_bits = 4;
};
#endif
template <typename T>
struct BitsNumber {
static constexpr size_t num_bits = 8 * sizeof(T);
};
template <typename T>
struct TypeInfo {
#if FP4_TYPE_SUPPORTED
using types = std::tuple<byte, int16, int32, int64, fp32, fp16, bf16, fp8e4m3, fp8e5m2, fp4e2m1>;
#else
using types = std::tuple<byte, int16, int32, int64, fp32, fp16, bf16, fp8e4m3, fp8e5m2, int8>;
#endif
template <typename U, DType current>
struct Helper {
......@@ -350,11 +418,21 @@ struct TypeInfo {
}
constexpr static DType dtype = getType<T>();
constexpr static size_t size = sizeof(T);
constexpr static size_t size = BitsNumber<T>::num_bits;
constexpr static float max_finite_value = detail::TypeExtrema<T>::max;
constexpr static const char *name = detail::type_name<T>();
};
#if FP4_TYPE_SUPPORTED
#define SWITCH_FP4_TYPE_HANDLE(type, ...) \
case DType::kFloat4E2M1: { \
using type = fp4e2m1; \
{ __VA_ARGS__ } \
} break;
#else
#define SWITCH_FP4_TYPE_HANDLE(type, ...) // do nothing
#endif
#define TRANSFORMER_ENGINE_TYPE_SWITCH_ALL(dtype, type, ...) \
switch (dtype) { \
using namespace transformer_engine; \
......@@ -398,6 +476,7 @@ struct TypeInfo {
using type = byte; \
{ __VA_ARGS__ } \
} break; \
SWITCH_FP4_TYPE_HANDLE(type, __VA_ARGS__) \
default: \
NVTE_ERROR("Invalid type."); \
}
......@@ -559,6 +638,9 @@ struct TypeInfo {
case DType::kFloat8E4M3: { \
NVTE_ERROR("FP8 type not instantiated for input."); \
} break; \
case DType::kFloat4E2M1: { \
NVTE_ERROR("FP4 type not instantiated for input."); \
} break; \
default: \
NVTE_ERROR("Invalid type."); \
}
......@@ -629,6 +711,14 @@ struct is_fp8<fp8e4m3> : std::true_type {};
template <>
struct is_fp8<fp8e5m2> : std::true_type {};
template <typename T>
struct is_fp4 : std::false_type {};
#if FP4_TYPE_SUPPORTED
template <>
struct is_fp4<fp4e2m1> : std::true_type {};
#endif
// [128,4] rowwise and [4,128] colwise alignment requirements for the tensor with scaling factors
constexpr size_t scale_tensor_alignment_X_rowwise = 4;
constexpr size_t scale_tensor_alignment_Y_rowwise = 128;
......@@ -647,13 +737,16 @@ inline bool is_aligned_tensor_data(const Tensor &t, size_t alignment) {
}
size_t typeToSize(const DType type);
size_t typeToNumBits(const DType type);
size_t get_buffer_size_bytes(const size_t N, const DType buffer_dtype);
size_t get_buffer_size_bytes(const size_t dim_first, const size_t dim_last,
const DType buffer_dtype);
void CheckNoopTensor(const Tensor &t, const std::string &name);
void CheckInputTensor(const Tensor &t, const std::string &name);
void CheckOutputTensor(const Tensor &t, const std::string &name, bool allow_empty = false);
bool is_fp8_dtype(const DType t);
/*! \brief Update a tensor's FP8 scale-inverse
*
* The FP8 scale-inverse (dequantization scaling factor) is updated
......@@ -673,7 +766,7 @@ CUtensorMapDataType get_CUtensorMapDataType(DType dtype);
void create_2D_tensor_map(CUtensorMap &tensorMap, const SimpleTensor &tensor,
const uint64_t globalY, const uint64_t globalX, const uint32_t shmemY,
const uint32_t shmemX, const uint32_t stride_elems,
const uint32_t offset_elems, const size_t type_size);
const uint32_t offset_elems, const size_t type_num_bits);
#endif
bool is_supported_by_CC_100();
......@@ -681,6 +774,8 @@ bool is_supported_by_CC_100();
std::vector<std::vector<Tensor *>> convert_tensor_array(NVTETensor **nvte_tensors,
size_t outer_size, size_t inner_size);
Tensor *convertNVTETensor(const NVTETensor tensor);
Tensor *convertNVTETensorCheck(const NVTETensor tensor);
} // namespace transformer_engine
#endif // TRANSFORMER_ENGINE_COMMON_COMMON_H_