Unverified commit e963e4a9 authored by Tim Moon, committed by GitHub

[PyTorch] Add support for FP8 current scaling in operation-based API (#1858)



* Add FP8 current scaling to te.Sequential tests
Signed-off-by: Tim Moon <tmoon@nvidia.com>

* Helper function for test/ref tensors does not produce quantized tensor by default
Signed-off-by: Tim Moon <tmoon@nvidia.com>

* Add FP8 current scaling to distributed te.Sequential tests
Signed-off-by: Tim Moon <tmoon@nvidia.com>

* Add FP8 current scaling to Userbuffers te.Sequential tests
Signed-off-by: Tim Moon <tmoon@nvidia.com>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci



* Debug MXFP8 tests
Signed-off-by: Tim Moon <tmoon@nvidia.com>

---------
Signed-off-by: Tim Moon <tmoon@nvidia.com>
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
parent 655512c1
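For background, delayed scaling picks FP8 scale factors from an amax history collected over earlier iterations, while current scaling computes each scale from the tensor being quantized right now. A minimal sketch of driving the operation-based API with the new recipe; the container and op names (`te_ops.Sequential`, `te_ops.BasicLinear`, `te_ops.Bias`) follow the tests touched below, and the sizes are purely illustrative:

```python
import torch
import transformer_engine.pytorch as te
import transformer_engine.pytorch.ops as te_ops
from transformer_engine.common.recipe import Float8CurrentScaling, Format

# Operation-based model: fusible ops composed in a sequential container
# (names assumed from the operation-based API exercised by these tests).
model = te_ops.Sequential(
    te_ops.BasicLinear(128, 128),
    te_ops.Bias(128),
)

# Current scaling: each tensor's FP8 scale is derived from its own amax at
# quantization time, instead of from an amax history as in delayed scaling.
recipe = Float8CurrentScaling(fp8_format=Format.E4M3)

x = torch.randn(32, 128, device="cuda", requires_grad=True)
with te.fp8_autocast(enabled=True, fp8_recipe=recipe):
    y = model(x)
y.sum().backward()
```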
......@@ -22,19 +22,28 @@ import transformer_engine.common.recipe
import transformer_engine.pytorch as te
from transformer_engine.pytorch.fp8 import FP8GlobalStateManager
from transformer_engine.pytorch.tensor import QuantizedTensor
from transformer_engine.pytorch.tensor.float8_tensor import Float8Quantizer
from transformer_engine.pytorch.tensor.float8_tensor import (
Float8Quantizer,
Float8CurrentScalingQuantizer,
)
from transformer_engine.pytorch.tensor.mxfp8_tensor import MXFP8Quantizer
import transformer_engine.pytorch.ops as te_ops
from transformer_engine.pytorch.ops._common import is_float8_tensor
from transformer_engine.pytorch.utils import is_bf16_compatible
import transformer_engine_torch as tex
# Import utility functions
_current_file = pathlib.Path(__file__).resolve()
sys.path.append(str(_current_file.parent.parent))
from utils import dtype_tols, make_recipe
# Check what quantization schemes are supported
fp8_available, reason_for_no_fp8 = FP8GlobalStateManager.is_fp8_available()
mxfp8_available, reason_for_no_mxfp8 = FP8GlobalStateManager.is_mxfp8_available()
quantization_list: list[Optional[str]] = [None]
if fp8_available:
quantization_list.append("fp8")
quantization_list.extend(("fp8_delayed_scaling", "fp8_current_scaling"))
if mxfp8_available:
quantization_list.append("mxfp8")
......@@ -63,11 +72,12 @@ def reset_rng(seed: int = 1234) -> None:
@torch.no_grad()
def make_reference_and_test_tensors(
shape: int | Iterable[int],
quantization: Optional[str] = None,
ref_dtype: torch.dtype = torch.float64,
ref_device: torch.device = "cpu",
test_dtype: torch.dtype = torch.float32,
test_device: torch.device = "cuda",
test_is_fp8: bool = False,
test_is_quantized: bool = False,
requires_grad: bool = True,
) -> tuple[torch.Tensor, torch.Tensor]:
"""Construct tensors with the same values
......@@ -76,78 +86,55 @@ def make_reference_and_test_tensors(
operations in high precision. The test tensor is intended for use
in Transformer Engine operations.
If a quantization scheme is provided, the tensor values are
quantized so that they are representable.
"""
# Random reference tensor
ref = torch.rand(shape, dtype=ref_dtype, device=ref_device)
# Construct test tensor from reference tensor
test = ref.to(device=test_device, dtype=test_dtype)
if test_is_fp8:
if quantization is None:
if test_is_quantized:
raise ValueError("Quantization scheme not provided")
if test.data_ptr() == ref.data_ptr():
test = test.clone()
elif quantization in ("fp8", "fp8_delayed_scaling"):
quantizer = Float8Quantizer(
scale=torch.ones(1, dtype=torch.float32, device=test_device),
scale=torch.ones(1, dtype=torch.float32, device=test_device).squeeze(),
amax=torch.zeros(1, dtype=torch.float32, device=test_device),
fp8_dtype=tex.DType.kFloat8E4M3,
)
test = quantizer(test)
elif test.data_ptr() == ref.data_ptr():
test = test.clone()
elif quantization == "fp8_current_scaling":
quantizer = Float8CurrentScalingQuantizer(
fp8_dtype=tex.DType.kFloat8E4M3,
device=test_device,
)
test = quantizer(test)
elif quantization == "mxfp8":
test = MXFP8Quantizer(fp8_dtype=tex.DType.kFloat8E4M3)(test)
else:
raise ValueError(f"Unsupported quantization scheme ({quantization})")
if isinstance(test, QuantizedTensor) and not test_is_quantized:
test = test.dequantize()
# Make sure reference and test tensors match each other
ref.copy_(test)
ref.requires_grad_(requires_grad)
test.requires_grad_(requires_grad)
return ref, test
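A rough usage sketch of the updated helper (shapes and dtypes are arbitrary): `quantization` rounds the values so they are exactly representable in the given format, and `test_is_quantized` decides whether the test tensor is handed back as a `QuantizedTensor` or dequantized to a plain tensor.

```python
# Hypothetical call sites; the helper itself is defined above.
x_ref, x_test = make_reference_and_test_tensors((32, 32))

# Values made representable in FP8 with current scaling, but the test tensor
# is returned dequantized (the default when test_is_quantized is False).
w_ref, w_test = make_reference_and_test_tensors(
    (64, 32),
    quantization="fp8_current_scaling",
    test_dtype=torch.bfloat16,
)

# Keep the quantized tensor itself, e.g. for a quantized-input code path.
q_ref, q_test = make_reference_and_test_tensors(
    (32, 32),
    quantization="fp8_current_scaling",
    test_is_quantized=True,
)
assert isinstance(q_test, QuantizedTensor)
```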
def dtype_tols(dtype: torch.dtype | tex.DType) -> dict[str, float]:
"""Estimated numerical error for a datatype
Based on tolerances for torch.testing.assert_close.
"""
# Transformer Engine dtypes
if isinstance(dtype, tex.DType):
if dtype == tex.DType.kFloat8E4M3:
return dict(rtol=0.125, atol=0.0675) # epsilon = 0.0625
if dtype == tex.DType.kFloat8E5M2:
return dict(rtol=0.25, atol=0.125) # epsilon = 0.152
dtype = {
tex.DType.kByte: torch.uint8,
tex.DType.kInt32: torch.int32,
tex.DType.kFloat32: torch.float32,
tex.DType.kFloat16: torch.half,
tex.DType.kBFloat16: torch.bfloat16,
}[dtype]
# PyTorch dtypes
if dtype == torch.float16:
return dict(rtol=1e-3, atol=1e-5)
if dtype == torch.bfloat16:
return dict(rtol=1.6e-2, atol=1e-5)
if dtype == torch.float32:
return dict(rtol=1.3e-6, atol=1e-5)
if dtype == torch.float64:
return dict(rtol=1e-7, atol=1e-7)
raise ValueError(f"Unsupported dtype ({dtype})")
def make_recipe(name: Optional[str] = None) -> Optional[Recipe]:
"""Make recipe for quantization scheme"""
if name is None:
return None
if name == "fp8":
return transformer_engine.common.recipe.DelayedScaling(
fp8_format=transformer_engine.common.recipe.Format.E4M3,
)
if name == "mxfp8":
return transformer_engine.common.recipe.MXFP8BlockScaling(
fp8_format=transformer_engine.common.recipe.Format.E4M3,
)
raise ValueError(f"Unsupported quantization scheme ({name})")
def _test_all_reduce(
*,
local_size: int = 17,
local_size: int = 32,
dtype: torch.dtype = torch.float32,
device: torch.device = "cuda",
fp8: bool = False,
quantization: Optional[str] = None,
) -> None:
# Distributed process group
......@@ -156,22 +143,25 @@ def _test_all_reduce(
world_size = torch.distributed.get_world_size(process_group)
# Tensor dimensions
in_shape = [world_size, local_size]
out_shape = [local_size]
in_shape = [world_size, local_size, local_size]
out_shape = [local_size, local_size]
# Random data
reset_rng()
with_quantization = quantization is not None
x_ref, x_test = make_reference_and_test_tensors(
in_shape,
quantization=quantization,
test_dtype=dtype,
test_device=device,
test_is_fp8=fp8,
test_is_quantized=with_quantization,
)
dy_ref, dy_test = make_reference_and_test_tensors(
out_shape,
quantization=quantization,
test_dtype=dtype,
test_device=device,
test_is_fp8=fp8,
test_is_quantized=with_quantization,
)
# Plain PyTorch implementation
......@@ -199,10 +189,10 @@ def _test_all_reduce(
def _test_all_gather(
*,
local_size: int = 13,
local_size: int = 32,
dtype: torch.dtype = torch.float32,
device: torch.device = "cuda",
fp8: bool = False,
quantization: Optional[str] = None,
) -> None:
# Distributed process group
......@@ -211,26 +201,29 @@ def _test_all_gather(
world_size = torch.distributed.get_world_size(process_group)
# Tensor dimensions
in_shape = [world_size, local_size]
out_shape = [world_size, world_size * local_size]
in_shape = [world_size, local_size, local_size]
out_shape = [world_size, world_size * local_size, local_size]
# Random data
reset_rng()
with_quantization = quantization is not None
x_ref, x_test = make_reference_and_test_tensors(
in_shape,
quantization=quantization,
test_dtype=dtype,
test_device=device,
test_is_fp8=fp8,
test_is_quantized=with_quantization,
)
dy_ref, dy_test = make_reference_and_test_tensors(
out_shape,
quantization=quantization,
test_dtype=dtype,
test_device=device,
test_is_fp8=fp8,
test_is_quantized=with_quantization,
)
# Plain PyTorch implementation
y_ref = x_ref.tile((world_size, 1)).reshape(out_shape)
y_ref = x_ref.tile((world_size, 1, 1)).reshape(out_shape)
y_ref.backward(dy_ref)
# Convert to distributed tensors
......@@ -257,10 +250,10 @@ def _test_all_gather(
def _test_reduce_scatter(
*,
local_size: int = 11,
local_size: int = 32,
dtype: torch.dtype = torch.float32,
device: torch.device = "cuda",
fp8: bool = False,
quantization: Optional[str] = None,
) -> None:
# Distributed process group
......@@ -269,22 +262,25 @@ def _test_reduce_scatter(
world_size = torch.distributed.get_world_size(process_group)
# Tensor dimensions
in_shape = [world_size, world_size * local_size]
out_shape = [world_size, local_size]
in_shape = [world_size, world_size * local_size, local_size]
out_shape = [world_size, local_size, local_size]
# Random data
reset_rng()
with_quantization = quantization is not None
x_ref, x_test = make_reference_and_test_tensors(
in_shape,
quantization=quantization,
test_dtype=dtype,
test_device=device,
test_is_fp8=fp8,
test_is_quantized=with_quantization,
)
dy_ref, dy_test = make_reference_and_test_tensors(
out_shape,
quantization=quantization,
test_dtype=dtype,
test_device=device,
test_is_fp8=fp8,
test_is_quantized=with_quantization,
)
# Plain PyTorch implementation
......@@ -324,7 +320,11 @@ def _test_basic_linear(
tensor_parallel_mode: str = "column",
sequence_parallel: bool = False,
) -> None:
# Skip invalid configurations
quantized_compute = quantization is not None
if not quantized_compute and quantized_weight:
return
# Distributed process group
process_group = world_group()
......@@ -348,30 +348,23 @@ def _test_basic_linear(
reset_rng()
x_ref, x_test = make_reference_and_test_tensors(
in_shape,
quantization=quantization,
test_dtype=dtype,
test_device=device,
test_is_fp8=quantized_compute,
)
if isinstance(x_test, QuantizedTensor):
with torch.no_grad():
x_test = x_test.dequantize().requires_grad_()
w_ref, w_test = make_reference_and_test_tensors(
(out_features, in_features),
quantization=quantization,
test_dtype=dtype,
test_device=device,
test_is_fp8=(quantized_compute or quantized_weight),
)
if isinstance(w_test, QuantizedTensor):
w_test = w_test.dequantize()
dy_ref, dy_test = make_reference_and_test_tensors(
out_shape,
quantization=quantization,
test_dtype=dtype,
test_device=device,
test_is_fp8=quantized_compute,
requires_grad=False,
)
if isinstance(dy_test, QuantizedTensor):
dy_test = dy_test.dequantize()
# Plain PyTorch implementation
y_ref = torch.nn.functional.linear(x_ref, w_ref)
......@@ -468,7 +461,11 @@ def _test_linear(
tensor_parallel_mode: str = "column",
sequence_parallel: bool = False,
) -> None:
# Skip invalid configurations
quantized_compute = quantization is not None
if not quantized_compute and quantized_weight:
return
# Distributed process group
process_group = world_group()
......@@ -492,21 +489,16 @@ def _test_linear(
reset_rng()
x_ref, x_test = make_reference_and_test_tensors(
in_shape,
quantization=quantization,
test_dtype=dtype,
test_device=device,
test_is_fp8=quantized_compute,
)
if isinstance(x_test, QuantizedTensor):
with torch.no_grad():
x_test = x_test.dequantize().requires_grad_()
w_ref, w_test = make_reference_and_test_tensors(
(out_features, in_features),
quantization=quantization,
test_dtype=dtype,
test_device=device,
test_is_fp8=(quantized_compute or quantized_weight),
)
if isinstance(w_test, QuantizedTensor):
w_test = w_test.dequantize()
b_ref, b_test = None, None
if bias:
if tensor_parallel_mode == "row":
......@@ -520,13 +512,11 @@ def _test_linear(
)
dy_ref, dy_test = make_reference_and_test_tensors(
out_shape,
quantization=quantization,
test_dtype=dtype,
test_device=device,
test_is_fp8=quantized_compute,
requires_grad=False,
)
if isinstance(dy_test, QuantizedTensor):
dy_test = dy_test.dequantize()
# Plain PyTorch implementation
y_ref = torch.nn.functional.linear(x_ref, w_ref)
......@@ -773,9 +763,10 @@ def run_parallel_tests() -> None:
if rank == 0:
print(f"Running _test_all_reduce")
_test_all_reduce()
for quantization in quantization_list:
if rank == 0:
print(f"Running _test_all_gather")
_test_all_gather()
print(f"Running _test_all_gather with quantization={quantization}")
_test_all_gather(quantization=quantization)
if rank == 0:
print(f"Running _test_reduce_scatter")
_test_reduce_scatter()
......
......@@ -26,21 +26,25 @@ from transformer_engine.pytorch.ops.fused import (
UserbuffersBackwardLinear,
UserbuffersForwardLinear,
)
from transformer_engine.pytorch.tensor.float8_tensor import Float8Quantizer
from transformer_engine.pytorch.tensor.float8_tensor import (
Float8Quantizer,
Float8CurrentScalingQuantizer,
)
from transformer_engine.pytorch.tensor.mxfp8_tensor import MXFP8Quantizer
from transformer_engine.pytorch.tensor.quantized_tensor import QuantizedTensor
from transformer_engine.pytorch.utils import is_bf16_compatible
# Import utility functions
_current_file = pathlib.Path(__file__).resolve()
sys.path.append(str(_current_file.parent.parent))
from utils import dtype_tols, str_to_dtype
from utils import dtype_tols, make_recipe, str_to_dtype
# Check if FP8 is supported
fp8_available, reason_for_no_fp8 = FP8GlobalStateManager.is_fp8_available()
mxfp8_available, reason_for_no_mxfp8 = FP8GlobalStateManager.is_mxfp8_available()
quantization_list: list[Optional[str]] = [None]
if fp8_available:
quantization_list.append("fp8")
quantization_list.extend(("fp8_delayed_scaling", "fp8_current_scaling"))
if mxfp8_available:
quantization_list.append("mxfp8")
......@@ -118,11 +122,12 @@ def reset_rng(seed: int = 1234) -> None:
@torch.no_grad()
def make_reference_and_test_tensors(
shape: int | Iterable[int],
quantization: Optional[str] = None,
ref_dtype: torch.dtype = torch.float64,
ref_device: torch.device = "cpu",
test_dtype: torch.dtype = torch.float32,
test_device: torch.device = "cuda",
test_is_fp8: bool = False,
test_is_quantized: bool = False,
requires_grad: bool = True,
) -> tuple[torch.Tensor, torch.Tensor]:
"""Construct tensors with the same values
......@@ -131,47 +136,49 @@ def make_reference_and_test_tensors(
operations in high precision. The test tensor is intended for use
in Transformer Engine operations.
If a quantization scheme is provided, the tensor values are
quantized so that they are representable.
"""
# Random data
# Random reference tensor
ref = torch.rand(shape, dtype=ref_dtype, device=ref_device)
# Make copy of tensor
# Construct test tensor from reference tensor
test = ref.to(device=test_device, dtype=test_dtype)
if test_is_fp8:
if quantization is None:
if test_is_quantized:
raise ValueError("Quantization scheme not provided")
if test.data_ptr() == ref.data_ptr():
test = test.clone()
elif quantization in ("fp8", "fp8_delayed_scaling"):
quantizer = Float8Quantizer(
scale=torch.ones(1, dtype=torch.float32, device=test_device),
scale=torch.ones(1, dtype=torch.float32, device=test_device).squeeze(),
amax=torch.zeros(1, dtype=torch.float32, device=test_device),
fp8_dtype=tex.DType.kFloat8E4M3,
)
test = quantizer(test)
elif test.data_ptr() == ref.data_ptr():
test = test.clone()
elif quantization == "fp8_current_scaling":
quantizer = Float8CurrentScalingQuantizer(
fp8_dtype=tex.DType.kFloat8E4M3,
device=test_device,
)
test = quantizer(test)
elif quantization == "mxfp8":
test = MXFP8Quantizer(fp8_dtype=tex.DType.kFloat8E4M3)(test)
else:
raise ValueError(f"Unsupported quantization scheme ({quantization})")
if isinstance(test, QuantizedTensor) and not test_is_quantized:
test = test.dequantize()
# Make sure reference and test tensors represent exact same values
# Make sure reference and test tensors match each other
ref.copy_(test)
# Return reference and test tensors
ref.requires_grad_(requires_grad)
test.requires_grad_(requires_grad)
return ref, test
def make_recipe(name: Optional[str] = None) -> Optional[Recipe]:
"""Make recipe for quantization scheme"""
if name is None:
return None
if name == "fp8":
return transformer_engine.common.recipe.DelayedScaling(
fp8_format=transformer_engine.common.recipe.Format.E4M3,
)
if name == "mxfp8":
return transformer_engine.common.recipe.MXFP8BlockScaling(
fp8_format=transformer_engine.common.recipe.Format.E4M3,
)
raise ValueError(f"Unsupported quantization scheme ({name})")
def _test_linear(
*,
model_config: ModelConfig,
......@@ -201,21 +208,16 @@ def _test_linear(
reset_rng()
x_ref, x_test = make_reference_and_test_tensors(
in_shape,
quantization=quantization,
test_dtype=dtype,
test_device=device,
test_is_fp8=quantized_compute,
)
if isinstance(x_test, QuantizedTensor):
with torch.no_grad():
x_test = x_test.dequantize().requires_grad_()
w_ref, w_test = make_reference_and_test_tensors(
(out_features, in_features),
quantization=quantization,
test_dtype=dtype,
test_device=device,
test_is_fp8=quantized_compute,
)
if isinstance(w_test, QuantizedTensor):
w_test = w_test.dequantize()
b_ref, b_test = None, None
if bias:
if tensor_parallel_mode == "row":
......@@ -229,13 +231,11 @@ def _test_linear(
)
dy_ref, dy_test = make_reference_and_test_tensors(
out_shape,
quantization=quantization,
test_dtype=dtype,
test_device=device,
test_is_fp8=quantized_compute,
requires_grad=False,
)
if isinstance(dy_test, QuantizedTensor):
dy_test = dy_test.dequantize()
# Plain PyTorch implementation
y_ref = torch.nn.functional.linear(x_ref, w_ref)
......
......@@ -7,6 +7,8 @@ from __future__ import annotations
from collections.abc import Iterable
import io
import math
import pathlib
import sys
from typing import Optional
import pytest
......@@ -24,10 +26,20 @@ from transformer_engine.pytorch.ops.fused import (
ForwardLinearBiasAdd,
)
from transformer_engine.pytorch.tensor import QuantizedTensor
from transformer_engine.pytorch.tensor.float8_tensor import Float8Tensor, Float8Quantizer
from transformer_engine.pytorch.tensor.float8_tensor import (
Float8Tensor,
Float8CurrentScalingQuantizer,
Float8Quantizer,
)
from transformer_engine.pytorch.tensor.mxfp8_tensor import MXFP8Tensor, MXFP8Quantizer
from transformer_engine.pytorch.utils import is_bf16_compatible
import transformer_engine_torch as tex
# Import utility functions
_current_file = pathlib.Path(__file__).resolve()
sys.path.append(str(_current_file.parent))
from utils import dtype_tols, make_recipe
# Check if FP8 is supported
fp8_available, reason_for_no_fp8 = FP8GlobalStateManager.is_fp8_available()
mxfp8_available, reason_for_no_mxfp8 = FP8GlobalStateManager.is_mxfp8_available()
......@@ -40,6 +52,13 @@ if is_bf16_compatible(): # bf16 requires sm_80 or higher
# Supported devices
_devices: list[torch.device] = [torch.device("cpu"), torch.device("cuda")]
# Supported quantization recipes
_quantization_list: list[Optional[str]] = [None]
if fp8_available:
_quantization_list.extend(("fp8_delayed_scaling", "fp8_current_scaling"))
if mxfp8_available:
_quantization_list.append("mxfp8")
def maybe_skip_quantization(
quantization: Optional[str],
......@@ -47,13 +66,14 @@ def maybe_skip_quantization(
dims: Optional[Iterable[int] | int] = None,
device: Optional[torch.device | str] = None,
) -> None:
"""Skip test case if a quantization scheme is not supported"""
# Don't skip if there is no quantization
if quantization is None:
return
# Check if quantization scheme is supported
if quantization == "fp8" and not fp8_available:
if quantization in ("fp8", "fp8_delayed_scaling", "fp8_current_scaling") and not fp8_available:
pytest.skip(reason_for_no_fp8)
if quantization == "mxfp8" and not mxfp8_available:
pytest.skip(reason_for_no_mxfp8)
......@@ -61,7 +81,7 @@ def maybe_skip_quantization(
if dims is not None:
if not isinstance(dims, Iterable):
dims = (dims,)
if quantization == "fp8":
if quantization in ("fp8", "fp8_delayed_scaling", "fp8_current_scaling"):
if math.prod(dims[:-1]) % 16 != 0 or dims[-1] % 16 != 0:
pytest.skip("FP8 GEMMs require dims that are divisible by 16")
elif quantization == "mxfp8":
......@@ -73,47 +93,15 @@ def maybe_skip_quantization(
pytest.skip("Quantization is only supported on CUDA devices")
def dtype_tols(dtype: torch.dtype | tex.DType) -> dict[str, float]:
"""Estimated numerical error for a datatype
Based on tolerances for torch.testing.assert_close.
"""
# Transformer Engine dtypes
if isinstance(dtype, tex.DType):
if dtype == tex.DType.kFloat8E4M3:
return dict(rtol=0.125, atol=0.0675) # epsilon = 0.0625
if dtype == tex.DType.kFloat8E5M2:
return dict(rtol=0.25, atol=0.125) # epsilon = 0.152
dtype = {
tex.DType.kByte: torch.uint8,
tex.DType.kInt32: torch.int32,
tex.DType.kFloat32: torch.float32,
tex.DType.kFloat16: torch.half,
tex.DType.kBFloat16: torch.bfloat16,
}[dtype]
# PyTorch dtypes
if dtype == torch.float16:
return dict(rtol=1e-3, atol=1e-5)
if dtype == torch.bfloat16:
return dict(rtol=1.6e-2, atol=1e-5)
if dtype == torch.float32:
return dict(rtol=1.3e-6, atol=1e-5)
if dtype == torch.float64:
return dict(rtol=1e-7, atol=1e-7)
raise ValueError(f"Unsupported dtype ({dtype})")
@torch.no_grad()
def make_reference_and_test_tensors(
shape: int | Iterable[int],
quantization: Optional[str] = None,
ref_dtype: torch.dtype = torch.float64,
ref_device: torch.device = "cpu",
test_dtype: torch.dtype = torch.float32,
test_device: torch.device = "cuda",
test_is_fp8: bool = False,
test_is_quantized: bool = False,
requires_grad: bool = True,
) -> tuple[torch.Tensor, torch.Tensor]:
"""Construct tensors with the same values
......@@ -122,39 +110,49 @@ def make_reference_and_test_tensors(
operations in high precision. The test tensor is intended for use
in Transformer Engine operations.
If a quantization scheme is provided, the tensor values are
quantized so that they are representable.
"""
# Random reference tensor
ref = torch.rand(shape, dtype=ref_dtype, device=ref_device)
# Construct test tensor from reference tensor
test = ref.to(device=test_device, dtype=test_dtype)
if test_is_fp8:
if quantization is None:
if test_is_quantized:
raise ValueError("Quantization scheme not provided")
if test.data_ptr() == ref.data_ptr():
test = test.clone()
elif quantization in ("fp8", "fp8_delayed_scaling"):
quantizer = Float8Quantizer(
scale=torch.ones(1, dtype=torch.float32, device=test_device).squeeze(),
amax=torch.zeros(1, dtype=torch.float32, device=test_device),
fp8_dtype=tex.DType.kFloat8E4M3,
)
test = quantizer(test)
elif test.data_ptr() == ref.data_ptr():
test = test.clone()
elif quantization == "fp8_current_scaling":
quantizer = Float8CurrentScalingQuantizer(
fp8_dtype=tex.DType.kFloat8E4M3,
device=test_device,
)
test = quantizer(test)
elif quantization == "mxfp8":
test = MXFP8Quantizer(fp8_dtype=tex.DType.kFloat8E4M3)(test)
else:
raise ValueError(f"Unsupported quantization scheme ({quantization})")
if isinstance(test, QuantizedTensor) and not test_is_quantized:
test = test.dequantize()
# Make sure reference and test tensors match each other
ref.copy_(test)
ref.requires_grad_(requires_grad)
test.requires_grad_(requires_grad)
return ref, test
def make_recipe(name: Optional[str] = None) -> Optional[Recipe]:
"""Make recipe for quantization scheme"""
if name is None:
return None
if name == "fp8":
return transformer_engine.common.recipe.DelayedScaling(
fp8_format=transformer_engine.common.recipe.Format.E4M3,
)
if name == "mxfp8":
return transformer_engine.common.recipe.MXFP8BlockScaling(
fp8_format=transformer_engine.common.recipe.Format.E4M3,
)
raise ValueError(f"Unsupported quantization scheme ({name})")
class TestSequential:
"""Tests for sequential container"""
......@@ -364,7 +362,7 @@ class TestFuser:
@pytest.mark.parametrize("init_dtype", _dtypes)
@pytest.mark.parametrize("final_dtype", _dtypes)
@pytest.mark.parametrize("quantization", (None, "fp8", "mxfp8"))
@pytest.mark.parametrize("quantization", _quantization_list)
def test_dtype_cast(
self,
*,
......@@ -377,8 +375,9 @@ class TestFuser:
"""Check dtype cast functions"""
# Skip invalid configurations
maybe_skip_quantization(quantization, device=device)
in_shape = (size, size)
with_quantization = quantization is not None
maybe_skip_quantization(quantization, dims=in_shape, device=device)
# Random data
dtype = torch.float32
......@@ -388,9 +387,9 @@ class TestFuser:
dtype = torch.bfloat16
w_ref, w_test = make_reference_and_test_tensors(
(size, size),
quantization=quantization,
test_dtype=dtype,
test_device=device,
test_is_fp8=with_quantization,
)
# Construct operation
......@@ -412,11 +411,11 @@ class TestFuser:
assert isinstance(op.weight, QuantizedTensor) == with_quantization
assert op.weight.dtype == final_dtype
w_test = op.weight.to(dtype=torch.float64, device="cpu")
torch.testing.assert_close(w_test, w_ref, rtol=0, atol=0)
torch.testing.assert_close(w_test, w_ref, **dtype_tols(dtype))
# Check forward and backward pass
x = torch.zeros(
(size, size),
in_shape,
dtype=init_dtype,
device=device,
requires_grad=True,
......@@ -429,7 +428,7 @@ class TestFuser:
@pytest.mark.parametrize("model_dtype", _dtypes)
@pytest.mark.parametrize("autocast_dtype", _dtypes)
@pytest.mark.parametrize("quantization", (None, "fp8", "mxfp8"))
@pytest.mark.parametrize("quantization", _quantization_list)
def test_pyt_autocast(
self,
*,
......@@ -444,8 +443,9 @@ class TestFuser:
device = torch.device(device)
# Skip invalid configurations
in_shape = (size, size)
quantized_compute = quantization is not None
maybe_skip_quantization(quantization)
maybe_skip_quantization(quantization, dims=in_shape, device=device)
# Construct operation
recipe = make_recipe(quantization)
......@@ -454,7 +454,7 @@ class TestFuser:
# Check forward and backward pass
x = torch.zeros(
(size, size),
in_shape,
dtype=model_dtype,
device=device,
requires_grad=True,
......@@ -492,33 +492,34 @@ class TestBasicOps:
@pytest.mark.parametrize("dtype", _dtypes)
@pytest.mark.parametrize("device", ("cuda", "cpu"))
@pytest.mark.parametrize("fp8", (False, True))
@pytest.mark.parametrize("quantization", _quantization_list)
def test_identity(
self,
*,
in_shape: Iterable[int] = (1,),
in_shape: Iterable[int] = (32, 32),
dtype: torch.dtype,
device: torch.device,
fp8: bool,
quantization: Optional[str],
) -> None:
# Skip invalid configurations
if fp8 and not fp8_available:
pytest.skip(reason_for_no_fp8)
if fp8 and torch.device(device).type != "cuda":
pytest.skip("FP8 is only supported on CUDA devices")
with_quantization = quantization is not None
maybe_skip_quantization(quantization, dims=in_shape, device=device)
# Random data
x_ref, x_test = make_reference_and_test_tensors(
in_shape,
quantization=quantization,
test_dtype=dtype,
test_device=device,
test_is_fp8=fp8,
test_is_quantized=with_quantization,
)
dy_ref, dy_test = make_reference_and_test_tensors(
in_shape,
quantization=quantization,
test_dtype=dtype,
test_device=device,
test_is_quantized=with_quantization,
requires_grad=False,
)
......@@ -554,7 +555,7 @@ class TestBasicOps:
),
)
@pytest.mark.parametrize("dtype", _dtypes)
@pytest.mark.parametrize("fp8", (False, True))
@pytest.mark.parametrize("quantization", (None, "fp8_current_scaling"))
def test_reshape(
self,
*,
......@@ -562,31 +563,32 @@ class TestBasicOps:
dtype: torch.dtype,
device: torch.device = "cuda",
memory_format: torch.memory_format = torch.contiguous_format,
fp8: bool,
quantization: Optional[str],
) -> None:
in_shape, out_shape = shapes
# Skip invalid configurations
if memory_format == torch.channels_last and len(in_shape) != 4:
pytest.skip("torch.channels_last only supports 4D tensors")
if fp8 and not fp8_available:
pytest.skip(reason_for_no_fp8)
if fp8 and torch.device(device).type != "cuda":
pytest.skip("FP8 is only supported on CUDA devices")
maybe_skip_quantization(quantization, device=device)
with_quantization = quantization is not None
# Random data
x_ref, x_test = make_reference_and_test_tensors(
in_shape,
quantization=quantization,
test_dtype=dtype,
test_device=device,
test_is_fp8=fp8,
test_is_quantized=with_quantization,
)
x_test = x_test.contiguous(memory_format=memory_format)
x_test = x_test.detach().requires_grad_()
dy_ref, dy_test = make_reference_and_test_tensors(
x_ref.reshape(out_shape).size(),
quantization=quantization,
test_dtype=dtype,
test_device=device,
test_is_quantized=with_quantization,
requires_grad=False,
)
......@@ -615,10 +617,10 @@ class TestBasicOps:
torch.testing.assert_close(dx_test, x_ref.grad, **tols)
@pytest.mark.parametrize("size", (1, 7, 32))
@pytest.mark.parametrize("in_shape", ((-1,), (1, 3, -1), (2, 3, 4, -1)))
@pytest.mark.parametrize("in_shape", ((-1,), (1, 3, -1), (4, 3, 8, -1)))
@pytest.mark.parametrize("dtype", _dtypes)
@pytest.mark.parametrize("device", _devices)
@pytest.mark.parametrize("fp8", (False, True))
@pytest.mark.parametrize("quantization", _quantization_list)
def test_bias(
self,
*,
......@@ -626,24 +628,23 @@ class TestBasicOps:
in_shape: Iterable[int],
dtype: torch.dtype,
device: torch.device,
fp8: bool,
quantization: Optional[str],
) -> None:
# Make input and bias shapes consistent
in_shape = list(in_shape)[:-1] + [size]
# Skip invalid configurations
if fp8 and not fp8_available:
pytest.skip(reason_for_no_fp8)
if fp8 and torch.device(device).type != "cuda":
pytest.skip("FP8 is only supported on CUDA devices")
with_quantization = quantization is not None
maybe_skip_quantization(quantization, dims=in_shape, device=device)
# Random data
x_ref, x_test = make_reference_and_test_tensors(
in_shape,
quantization=quantization,
test_dtype=dtype,
test_device=device,
test_is_fp8=fp8,
test_is_quantized=with_quantization,
)
b_ref, b_test = make_reference_and_test_tensors(
size,
......@@ -652,8 +653,10 @@ class TestBasicOps:
)
dy_ref, dy_test = make_reference_and_test_tensors(
in_shape,
quantization=quantization,
test_dtype=dtype,
test_device=device,
test_is_quantized=with_quantization,
requires_grad=False,
)
......@@ -678,7 +681,7 @@ class TestBasicOps:
torch.testing.assert_close(dx_test, x_ref.grad, **tols)
torch.testing.assert_close(db_test, b_ref.grad, **tols)
@pytest.mark.parametrize("quantization", ("fp8", "mxfp8"))
@pytest.mark.parametrize("quantization", _quantization_list)
@pytest.mark.parametrize("cast_forward", (False, True))
@pytest.mark.parametrize("cast_backward", (False, True))
def test_quantize(
......@@ -694,25 +697,26 @@ class TestBasicOps:
"""Quantize"""
# Skip invalid configurations
maybe_skip_quantization(quantization)
with_quantization = quantization is not None
maybe_skip_quantization(quantization, device=device)
if quantization == "mxfp8":
maybe_skip_quantization(quantization, dims=in_shape)
# Random data
x_ref, x_test = make_reference_and_test_tensors(
in_shape,
quantization=quantization,
test_dtype=dtype,
test_device=device,
requires_grad=False,
test_is_fp8=True,
requires_grad=True,
)
x_test = x_test.dequantize().requires_grad_()
dy_ref, dy_test = make_reference_and_test_tensors(
in_shape,
quantization=quantization,
test_dtype=dtype,
test_device=device,
requires_grad=False,
test_is_fp8=True,
)
dy_test = dy_test.dequantize()
# Plain PyTorch implementation
y_ref = x_ref
......@@ -721,11 +725,12 @@ class TestBasicOps:
# Implementation with fusible operation
op = te_ops.Quantize(forward=cast_forward, backward=cast_backward)
recipe = make_recipe(quantization)
with te.fp8_autocast(fp8_recipe=recipe):
with te.fp8_autocast(enabled=with_quantization, fp8_recipe=recipe):
y_test = op(x_test)
y_test.backward(dy_test)
# Check tensor types
if with_quantization:
assert isinstance(y_test, QuantizedTensor) == cast_forward
assert isinstance(x_test.grad, QuantizedTensor) == cast_backward
......@@ -762,9 +767,24 @@ class TestBasicOps:
# Skip invalid configurations
maybe_skip_quantization(quantization, dims=in_shape, device=device)
maybe_skip_quantization(quantization, dims=out_shape)
if quantization == "fp8" and quantized_output and not quantized_compute:
quantization_needed = any(
(
quantized_compute,
quantized_input,
quantized_weight,
quantized_output,
quantized_grad_output,
quantized_grad_input,
)
)
if quantization is None and quantization_needed:
pytest.skip("Quantization scheme is not specified")
if quantization is not None and not quantization_needed:
pytest.skip("Quantization scheme is not used")
if quantization in ("fp8", "fp8_delayed_scaling", "fp8_current_scaling"):
if quantized_output and not quantized_compute:
pytest.skip("FP8 output is only supported with FP8 GEMMs")
if quantization == "fp8" and quantized_grad_input and not quantized_compute:
if quantized_grad_input and not quantized_compute:
pytest.skip("FP8 grad input is only supported with FP8 GEMMs")
if quantization == "mxfp8" and quantized_output:
pytest.skip("MXFP8 output is not supported with MXFP8 GEMMs")
......@@ -774,28 +794,25 @@ class TestBasicOps:
# Random data
x_ref, x_test = make_reference_and_test_tensors(
in_shape,
quantization=quantization,
test_dtype=dtype,
test_device=device,
test_is_fp8=(quantized_compute or quantized_input),
test_is_quantized=quantized_input,
)
if isinstance(x_test, QuantizedTensor):
with torch.no_grad():
x_test = x_test.dequantize().requires_grad_()
w_ref, w_test = make_reference_and_test_tensors(
(out_features, in_features),
quantization=quantization,
test_dtype=dtype,
test_device=device,
test_is_fp8=(quantized_compute or quantized_weight),
)
dy_ref, dy_test = make_reference_and_test_tensors(
out_shape,
quantization=quantization,
test_dtype=dtype,
test_device=device,
test_is_fp8=(quantized_compute or quantized_grad_output),
test_is_quantized=quantized_grad_output,
requires_grad=False,
)
if isinstance(dy_test, QuantizedTensor):
dy_test = dy_test.dequantize()
# Plain PyTorch implementation
y_ref = torch.nn.functional.linear(x_ref, w_ref)
......@@ -858,7 +875,7 @@ class TestBasicOps:
@pytest.mark.parametrize("weight_shape", ((64, 32), (3, 5)))
@pytest.mark.parametrize("in_shape", ((-1,), (5, 1, -1), (4, 2, 4, -1)))
@pytest.mark.parametrize("dtype", _dtypes)
@pytest.mark.parametrize("quantization", (None, "fp8", "mxfp8"))
@pytest.mark.parametrize("quantization", _quantization_list)
@pytest.mark.parametrize("accumulate_into_main_grad", (False, True))
def test_basic_linear(
self,
......@@ -880,7 +897,7 @@ class TestBasicOps:
)
@pytest.mark.skipif(not fp8_available, reason=reason_for_no_fp8)
@pytest.mark.parametrize("quantization", ("fp8", "mxfp8"))
@pytest.mark.parametrize("quantization", _quantization_list)
@pytest.mark.parametrize("quantized_compute", (False, True))
@pytest.mark.parametrize("quantized_input", (False, True))
@pytest.mark.parametrize("quantized_weight", (False, True))
......@@ -899,6 +916,8 @@ class TestBasicOps:
quantized_grad_input: bool,
) -> None:
"""GEMM with FP8 inputs and outputs"""
if quantization is None:
pytest.skip("Skipping case without quantization")
self._test_basic_linear(
dtype=torch.bfloat16,
quantization=quantization,
......@@ -911,7 +930,8 @@ class TestBasicOps:
)
@pytest.mark.parametrize("bias", (False, True))
@pytest.mark.parametrize("quantization", (None, "fp8", "mxfp8"))
@pytest.mark.parametrize("quantization", _quantization_list)
@pytest.mark.parametrize("quantized_compute", (False, True))
@pytest.mark.parametrize("quantized_weight", (False, True))
@pytest.mark.parametrize("input_requires_grad", (False, True))
@pytest.mark.parametrize("weight_requires_grad", (False, True))
......@@ -924,6 +944,7 @@ class TestBasicOps:
dtype: torch.dtype = torch.float32,
device: torch.device = "cuda",
quantization: Optional[str],
quantized_compute: bool,
quantized_weight: bool,
input_requires_grad: bool,
weight_requires_grad: bool,
......@@ -936,26 +957,25 @@ class TestBasicOps:
out_shape = in_shape[:-1] + [out_features]
# Skip invalid configurations
quantized_compute = quantization is not None
maybe_skip_quantization(quantization, dims=in_shape, device=device)
maybe_skip_quantization(quantization, dims=out_shape)
if quantization is None and (quantized_compute or quantized_weight):
pytest.skip("Quantization scheme is not specified")
if quantization is not None and not (quantized_compute or quantized_weight):
pytest.skip("Quantization scheme is not used")
# Random data
x_ref, x_test = make_reference_and_test_tensors(
in_shape,
quantization=quantization,
test_dtype=dtype,
test_device=device,
test_is_fp8=quantized_compute,
)
with torch.no_grad():
if isinstance(x_test, QuantizedTensor):
x_test = x_test.dequantize()
x_test.requires_grad_(requires_grad=input_requires_grad)
w_ref, w_test = make_reference_and_test_tensors(
(out_features, in_features),
quantization=quantization,
test_dtype=dtype,
test_device=device,
test_is_fp8=(quantized_compute or quantized_weight),
)
b_ref, b_test = None, None
if bias:
......@@ -966,6 +986,7 @@ class TestBasicOps:
)
dy_ref, dy_test = make_reference_and_test_tensors(
out_shape,
quantization=quantization,
test_dtype=dtype,
test_device=device,
requires_grad=False,
......@@ -1022,7 +1043,7 @@ class TestBasicOps:
@pytest.mark.parametrize("in_shape", ((-1,), (6, 16, -1)))
@pytest.mark.parametrize("dtype", _dtypes)
@pytest.mark.parametrize("zero_centered_gamma", (False, True))
@pytest.mark.parametrize("quantization", (None, "fp8", "mxfp8"))
@pytest.mark.parametrize("quantization", _quantization_list)
def test_layer_norm(
self,
*,
......@@ -1192,7 +1213,7 @@ class TestBasicOps:
@pytest.mark.parametrize("in_shape", ((-1,), (6, 16, -1)))
@pytest.mark.parametrize("dtype", _dtypes)
@pytest.mark.parametrize("zero_centered_gamma", (False, True))
@pytest.mark.parametrize("quantization", (None, "fp8", "mxfp8"))
@pytest.mark.parametrize("quantization", _quantization_list)
def test_rmsnorm(
self,
*,
......@@ -1327,14 +1348,14 @@ class TestBasicOps:
@pytest.mark.parametrize("dtype", _dtypes)
@pytest.mark.parametrize("device", ("cuda", "cpu"))
@pytest.mark.parametrize("fp8", (False, True))
@pytest.mark.parametrize("quantization", _quantization_list)
def test_add_in_place(
self,
*,
in_shape: Iterable[int] = (1,),
in_shape: Iterable[int] = (32, 32),
dtype: torch.dtype,
device: torch.device,
fp8: bool,
quantization: Optional[str],
) -> None:
"""Add two tensors
......@@ -1343,28 +1364,30 @@ class TestBasicOps:
"""
# Skip invalid configurations
if fp8 and not fp8_available:
pytest.skip(reason_for_no_fp8)
if fp8 and torch.device(device).type != "cuda":
pytest.skip("FP8 is only supported on CUDA devices")
with_quantization = quantization is not None
maybe_skip_quantization(quantization, dims=in_shape, device=device)
# Random data
x1_ref, x1_test = make_reference_and_test_tensors(
in_shape,
quantization=quantization,
test_dtype=dtype,
test_device=device,
test_is_fp8=fp8,
test_is_quantized=with_quantization,
)
x2_ref, x2_test = make_reference_and_test_tensors(
in_shape,
quantization=quantization,
test_dtype=dtype,
test_device=device,
test_is_fp8=fp8,
test_is_quantized=with_quantization,
)
dy_ref, dy_test = make_reference_and_test_tensors(
in_shape,
quantization=quantization,
test_dtype=dtype,
test_device=device,
test_is_quantized=with_quantization,
requires_grad=False,
)
......@@ -1381,7 +1404,7 @@ class TestBasicOps:
# Check results
tols = dtype_tols(dtype)
if fp8:
if with_quantization:
tols = dtype_tols(x1_test._fp8_dtype)
y_test = y_test.to(dtype=torch.float64, device="cpu")
dx1_test = x1_test.grad.to(dtype=torch.float64, device="cpu")
......@@ -1392,14 +1415,14 @@ class TestBasicOps:
@pytest.mark.parametrize("dtype", _dtypes)
@pytest.mark.parametrize("device", ("cuda", "cpu"))
@pytest.mark.parametrize("fp8", (False, True))
@pytest.mark.parametrize("quantization", _quantization_list)
def test_make_extra_output(
self,
*,
in_shape: Iterable[int] = (1,),
in_shape: Iterable[int] = (32, 32),
dtype: torch.dtype,
device: torch.device,
fp8: bool,
quantization: Optional[str],
) -> None:
"""Output tensor twice
......@@ -1408,28 +1431,31 @@ class TestBasicOps:
"""
# Skip invalid configurations
if fp8 and not fp8_available:
pytest.skip(reason_for_no_fp8)
if fp8 and torch.device(device).type != "cuda":
pytest.skip("FP8 is only supported on CUDA devices")
with_quantization = quantization is not None
maybe_skip_quantization(quantization, dims=in_shape, device=device)
# Random data
x_ref, x_test = make_reference_and_test_tensors(
in_shape,
quantization=quantization,
test_dtype=dtype,
test_device=device,
test_is_fp8=fp8,
test_is_quantized=with_quantization,
)
dy1_ref, dy1_test = make_reference_and_test_tensors(
in_shape,
quantization=quantization,
test_dtype=dtype,
test_device=device,
test_is_quantized=with_quantization,
requires_grad=False,
)
dy2_ref, dy2_test = make_reference_and_test_tensors(
in_shape,
quantization=quantization,
test_dtype=dtype,
test_device=device,
test_is_quantized=with_quantization,
requires_grad=False,
)
......@@ -1455,7 +1481,7 @@ class TestBasicOps:
@pytest.mark.parametrize("activation", ("relu", "gelu", "geglu", "reglu", "swiglu"))
@pytest.mark.parametrize("out_shape", ((37,), (2, 13), (32, 1, 32)))
@pytest.mark.parametrize("dtype", _dtypes)
@pytest.mark.parametrize("quantization", (None, "fp8", "mxfp8"))
@pytest.mark.parametrize("quantization", _quantization_list)
@pytest.mark.parametrize("cache_quantized_input", (False, True))
def test_activation(
self,
......@@ -1478,26 +1504,21 @@ class TestBasicOps:
quantized_compute = quantization is not None
maybe_skip_quantization(quantization, dims=in_shape, device=device)
if cache_quantized_input:
maybe_skip_quantization("fp8", device=device)
maybe_skip_quantization("fp8_current_scaling", device=device)
# Random data
x_ref, x_test = make_reference_and_test_tensors(
in_shape,
quantization="fp8_current_scaling" if cache_quantized_input else None,
test_dtype=dtype,
test_device=device,
test_is_fp8=quantized_compute,
)
dy_ref, dy_test = make_reference_and_test_tensors(
out_shape,
test_dtype=dtype,
test_device=device,
test_is_fp8=quantized_compute,
requires_grad=False,
)
if quantized_compute:
with torch.no_grad():
x_test = x_test.dequantize().requires_grad_()
dy_test = dy_test.dequantize()
# Plain PyTorch implementation
y_ref: torch.Tensor
......@@ -1540,8 +1561,6 @@ class TestBasicOps:
tols = dtype_tols(dtype)
if quantized_compute or cache_quantized_input:
tols = dtype_tols(tex.DType.kFloat8E4M3)
if activation == "relu" and not cache_quantized_input:
tols = {"atol": 0, "rtol": 0}
# Check results
y_test = y_test.to(dtype=torch.float64, device="cpu")
......@@ -1550,7 +1569,7 @@ class TestBasicOps:
torch.testing.assert_close(dx_test, x_ref.grad, **tols)
@pytest.mark.parametrize("dtype", _dtypes)
@pytest.mark.parametrize("quantization", (None, "fp8", "mxfp8"))
@pytest.mark.parametrize("quantization", _quantization_list)
@pytest.mark.parametrize("quantize_forward", (False, True))
@pytest.mark.parametrize("quantize_backward", (False, True))
def test_swiglu(
......@@ -1628,7 +1647,7 @@ class TestFusedOps:
@pytest.mark.parametrize("weight_shape", ((32, 64), (3, 5)))
@pytest.mark.parametrize("in_shape", ((-1,), (1, 7, -1), (8, 2, 10, -1)))
@pytest.mark.parametrize("dtype", _dtypes)
@pytest.mark.parametrize("quantization", (None, "fp8", "mxfp8"))
@pytest.mark.parametrize("quantization", _quantization_list)
@pytest.mark.parametrize("quantized_weight", (False, True))
def test_forward_linear_bias_activation(
self,
......@@ -1660,18 +1679,15 @@ class TestFusedOps:
# Random data
x_ref, x_test = make_reference_and_test_tensors(
in_shape,
quantization=quantization,
test_dtype=dtype,
test_device=device,
test_is_fp8=quantized_compute,
)
if quantized_compute:
with torch.no_grad():
x_test = x_test.dequantize().requires_grad_()
w_ref, w_test = make_reference_and_test_tensors(
(out_features, in_features),
quantization=quantization,
test_dtype=dtype,
test_device=device,
test_is_fp8=(quantized_compute or quantized_weight),
)
b_ref, b_test = None, None
if bias:
......@@ -1682,6 +1698,7 @@ class TestFusedOps:
)
dy_ref, dy_test = make_reference_and_test_tensors(
out_shape,
quantization=quantization,
test_dtype=dtype,
test_device=device,
requires_grad=False,
......@@ -1738,7 +1755,7 @@ class TestFusedOps:
@pytest.mark.parametrize("bias", (False, True))
@pytest.mark.parametrize("dtype", _dtypes)
@pytest.mark.parametrize("quantization", (None, "fp8", "mxfp8"))
@pytest.mark.parametrize("quantization", _quantization_list)
def test_forward_linear_bias_add(
self,
*,
......@@ -1767,18 +1784,15 @@ class TestFusedOps:
# Random data
x1_ref, x1_test = make_reference_and_test_tensors(
in_shape,
quantization=quantization,
test_dtype=dtype,
test_device=device,
test_is_fp8=quantized_compute,
)
if isinstance(x1_test, QuantizedTensor):
with torch.no_grad():
x1_test = x1_test.dequantize().requires_grad_()
w_ref, w_test = make_reference_and_test_tensors(
(out_features, in_features),
quantization=quantization,
test_dtype=dtype,
test_device=device,
test_is_fp8=(quantized_compute or quantized_weight),
)
b_ref, b_test = None, None
if bias:
......@@ -1794,6 +1808,7 @@ class TestFusedOps:
)
dy_ref, dy_test = make_reference_and_test_tensors(
out_shape,
quantization=quantization,
test_dtype=dtype,
test_device=device,
requires_grad=False,
......@@ -1852,7 +1867,7 @@ class TestFusedOps:
torch.testing.assert_close(db_test, b_ref.grad, **tols)
@pytest.mark.parametrize("dtype", _dtypes)
@pytest.mark.parametrize("quantization", (None, "fp8", "mxfp8"))
@pytest.mark.parametrize("quantization", _quantization_list)
def test_backward_linear_add(
self,
*,
......@@ -1880,27 +1895,26 @@ class TestFusedOps:
# Random data
x_ref, x_test = make_reference_and_test_tensors(
in_shape,
quantization=quantization,
test_dtype=dtype,
test_device=device,
test_is_fp8=quantized_compute,
)
if isinstance(x_test, QuantizedTensor):
with torch.no_grad():
x_test = x_test.dequantize().requires_grad_()
w_ref, w_test = make_reference_and_test_tensors(
(out_features, in_features),
quantization=quantization,
test_dtype=dtype,
test_device=device,
test_is_fp8=(quantized_compute or quantized_weight),
)
dy1_ref, dy1_test = make_reference_and_test_tensors(
out_shape,
quantization=quantization,
test_dtype=dtype,
test_device=device,
requires_grad=False,
)
dy2_ref, dy2_test = make_reference_and_test_tensors(
out_shape,
quantization=quantization,
test_dtype=dtype,
test_device=device,
requires_grad=False,
......@@ -1964,7 +1978,7 @@ class TestCheckpointing:
torch.manual_seed(seed)
torch.cuda.manual_seed(seed)
@pytest.mark.parametrize("quantization", (None, "fp8", "mxfp8"))
@pytest.mark.parametrize("quantization", _quantization_list)
@pytest.mark.parametrize("quantized_weight", (False, True))
def test_linear(
self,
......
......@@ -7,6 +7,7 @@ from __future__ import annotations
import torch
import transformer_engine
import transformer_engine.common.recipe
import transformer_engine.pytorch as te
import transformer_engine_torch as tex
......@@ -83,3 +84,24 @@ def dtype_tols(dtype: torch.dtype | tex.DType) -> dict[str, float]:
if dtype == torch.float8_e5m2:
return dict(rtol=0.25, atol=0.125) # epsilon = 0.152
raise ValueError(f"Unsupported dtype ({dtype})")
def make_recipe(name: Optional[str]) -> Optional[Recipe]:
"""Make recipe for quantization scheme"""
if name is None:
return None
if name in ("fp8", "fp8_delayed_scaling"):
return transformer_engine.common.recipe.DelayedScaling(
fp8_format=transformer_engine.common.recipe.Format.E4M3,
)
if name == "fp8_current_scaling":
return transformer_engine.common.recipe.Float8CurrentScaling(
fp8_format=transformer_engine.common.recipe.Format.E4M3,
)
if name == "mxfp8":
return transformer_engine.common.recipe.MXFP8BlockScaling(
fp8_format=transformer_engine.common.recipe.Format.E4M3,
)
if name == "fp8_block_scaling":
return transformer_engine.common.recipe.Float8BlockScaling()
raise ValueError(f"Unsupported quantization scheme ({name})")
......@@ -947,7 +947,7 @@ def _all_gather_fp8(
out = quantizer.make_empty(out_shape, dtype=dtype, device=device)
elif isinstance(inp, Float8Tensor):
out = inp.make_like(inp, shape=out_shape)
out._data = torch.empty_like(
out._data = torch.empty(
out_shape,
dtype=torch.uint8,
device=inp.device,
......
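The `_all_gather_fp8` change above replaces `torch.empty_like` with `torch.empty` when allocating the gathered FP8 data buffer: `empty_like` expects a tensor (it mirrors that tensor's shape and metadata), so it cannot allocate a buffer with the larger post-gather shape. A minimal illustration with made-up shapes:

```python
import torch

local = torch.empty((4, 8), dtype=torch.uint8, device="cuda")
out_shape = (16, 8)  # e.g. world_size * local rows after the all-gather

same_shape = torch.empty_like(local)  # always (4, 8): mirrors the input tensor
gathered = torch.empty(out_shape, dtype=torch.uint8, device=local.device)
```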
......@@ -22,7 +22,7 @@ from ...distributed import (
from ...fp8 import FP8GlobalStateManager
from ...module.base import _2X_ACC_FPROP, _2X_ACC_DGRAD, _2X_ACC_WGRAD
from ...tensor import Quantizer, QuantizedTensor
from ...tensor.float8_tensor import Float8Quantizer
from ...tensor.float8_tensor import Float8Quantizer, Float8CurrentScalingQuantizer
from ...tensor.float8_blockwise_tensor import Float8BlockQuantizer
from ...tensor.mxfp8_tensor import MXFP8Quantizer
from ...tensor._internal.float8_tensor_base import Float8TensorBase
......@@ -324,12 +324,38 @@ class BasicLinear(BasicOperation):
weight_quantizer.set_usage(rowwise=True, columnwise=is_grad_enabled)
grad_output_quantizer.set_usage(rowwise=True, columnwise=weight_requires_grad)
# Recipe-specific configuration
recipe = FP8GlobalStateManager.get_fp8_recipe()
if recipe.float8_current_scaling():
if any(
not isinstance(q, Float8CurrentScalingQuantizer)
for q in (input_quantizer, weight_quantizer, grad_output_quantizer)
):
raise RuntimeError(
"FP8 current-scaling recipe is enabled, "
f"but input quantizer is {input_quantizer.__class__.__name__}, "
f"weight quantizer is {weight_quantizer.__class__.__name__}, "
f"grad output quantizer is {grad_output_quantizer.__class__.__name__}"
)
input_quantizer.force_pow_2_scales = recipe.fp8_quant_fwd_inp.power_2_scale
input_quantizer.amax_epsilon_scales = recipe.fp8_quant_fwd_inp.amax_epsilon
weight_quantizer.force_pow_2_scales = recipe.fp8_quant_fwd_inp.power_2_scale
weight_quantizer.amax_epsilon_scales = recipe.fp8_quant_fwd_inp.amax_epsilon
grad_output_quantizer.force_pow_2_scales = recipe.fp8_quant_fwd_inp.power_2_scale
grad_output_quantizer.amax_epsilon_scales = recipe.fp8_quant_fwd_inp.amax_epsilon
if self.sequence_parallel and self.tensor_parallel_mode == "column":
input_quantizer.with_amax_reduction = True
input_quantizer.amax_reduction_group = self.tensor_parallel_group
if self.sequence_parallel and self.tensor_parallel_mode == "row":
grad_output_quantizer.with_amax_reduction = True
grad_output_quantizer.amax_reduction_group = self.tensor_parallel_group
# Make sure weight tensor has correct quantizer
# Note: Quantizer might have changed if quantization
# recipe changed
if isinstance(weight_quantizer, Float8Quantizer) and isinstance(
weight, Float8TensorBase
):
if isinstance(
weight_quantizer, (Float8Quantizer, Float8CurrentScalingQuantizer)
) and isinstance(weight, Float8TensorBase):
weight._quantizer = weight_quantizer
@staticmethod
......
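A rough sketch of what the current-scaling branch above configures. The attribute names (`force_pow_2_scales`, `amax_epsilon_scales`, `with_amax_reduction`, `amax_reduction_group`) are taken from the diff; the concrete values and process group here are placeholders. Because a current-scaling quantizer keeps no persistent scale/amax state, the recipe's per-tensor options and the optional sequence-parallel amax reduction are pushed onto the quantizer objects directly.

```python
import torch.distributed as dist
import transformer_engine_torch as tex
from transformer_engine.pytorch.tensor.float8_tensor import Float8CurrentScalingQuantizer

input_quantizer = Float8CurrentScalingQuantizer(
    fp8_dtype=tex.DType.kFloat8E4M3,
    device="cuda",
)

# Per-tensor quantization options mirrored from the recipe's forward-input section.
input_quantizer.force_pow_2_scales = True  # placeholder for recipe.fp8_quant_fwd_inp.power_2_scale
input_quantizer.amax_epsilon_scales = 0.0  # placeholder for recipe.fp8_quant_fwd_inp.amax_epsilon

# With sequence parallelism in column-parallel mode, the amax is reduced over
# the tensor-parallel group so every rank quantizes its shard with a consistent scale.
if dist.is_initialized():
    input_quantizer.with_amax_reduction = True
    input_quantizer.amax_reduction_group = dist.group.WORLD  # placeholder group
```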
......@@ -21,7 +21,7 @@ from ...module.base import (
_2X_ACC_FPROP,
)
from ...tensor.quantized_tensor import QuantizedTensorBase, Quantizer
from ...tensor.float8_tensor import Float8Quantizer
from ...tensor.float8_tensor import Float8Quantizer, Float8CurrentScalingQuantizer
from ...tensor._internal.float8_tensor_base import Float8TensorBase
from ...utils import canonicalize_device, canonicalize_dtype
from ..basic import BasicLinear, Bias, ReduceScatter
......@@ -208,7 +208,9 @@ class UserbuffersForwardLinear(FusedOperation):
if input_quantizer is not None:
if not isinstance(x_local, QuantizedTensorBase):
input_quantizer.set_usage(rowwise=True, columnwise=weight_requires_grad)
if isinstance(input_quantizer, Float8Quantizer):
if isinstance(
input_quantizer, (Float8Quantizer, Float8CurrentScalingQuantizer)
):
input_quantizer.set_usage(columnwise=False)
x_local = input_quantizer(x_local)
input_quantizer.set_usage(rowwise=True, columnwise=False)
......@@ -327,8 +329,10 @@ class UserbuffersForwardLinear(FusedOperation):
grad_input_quantizer = None
if with_quantized_compute:
recipe = FP8GlobalStateManager.get_fp8_recipe()
if not recipe.delayed() and not recipe.mxfp8():
raise RuntimeError("Userbuffers is only supported with FP8 delayed scaling recipe")
if not any((recipe.delayed(), recipe.float8_current_scaling(), recipe.mxfp8())):
raise RuntimeError(
f"Unsupported recipe for Userbuffers ({recipe.__class__.__name__})"
)
input_quantizer = linear_op.get_quantizer("forward", 0)
weight_quantizer = linear_op.get_quantizer("forward", 1)
grad_output_quantizer = linear_op.get_quantizer("backward", 0)
......