Commit e963e4a9 authored by Tim Moon, committed by GitHub

[PyTorch] Add support for FP8 current scaling in operation-based API (#1858)



* Add FP8 current scaling to te.Sequential tests
Signed-off-by: Tim Moon <tmoon@nvidia.com>

* Helper function for test/ref tensors does not produce quantized tensor by default
Signed-off-by: Tim Moon <tmoon@nvidia.com>

* Add FP8 current scaling to distributed te.Sequential tests
Signed-off-by: Tim Moon <tmoon@nvidia.com>

* Add FP8 current scaling to Userbuffers te.Sequential tests
Signed-off-by: Tim Moon <tmoon@nvidia.com>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci



* Debug MXFP8 tests
Signed-off-by: Tim Moon <tmoon@nvidia.com>

---------
Signed-off-by: Tim Moon <tmoon@nvidia.com>
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
parent 655512c1
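For context, a minimal sketch of the current-scaling quantizer these tests now exercise, using only names that appear in the diff below (shape, dtype, and device are illustrative, and a CUDA device with FP8 support is assumed):

    import torch
    import transformer_engine_torch as tex
    from transformer_engine.pytorch.tensor.float8_tensor import Float8CurrentScalingQuantizer

    # Current scaling computes the FP8 scale from the tensor's own amax at
    # quantization time, so no amax history or delayed scale update is kept.
    x = torch.randn(32, 32, dtype=torch.float32, device="cuda")
    quantizer = Float8CurrentScalingQuantizer(
        fp8_dtype=tex.DType.kFloat8E4M3,
        device="cuda",
    )
    x_fp8 = quantizer(x)       # quantized tensor (a QuantizedTensor subclass)
    x_hp = x_fp8.dequantize()  # round-trip back to high precision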
@@ -22,19 +22,28 @@ import transformer_engine.common.recipe
 import transformer_engine.pytorch as te
 from transformer_engine.pytorch.fp8 import FP8GlobalStateManager
 from transformer_engine.pytorch.tensor import QuantizedTensor
-from transformer_engine.pytorch.tensor.float8_tensor import Float8Quantizer
+from transformer_engine.pytorch.tensor.float8_tensor import (
+    Float8Quantizer,
+    Float8CurrentScalingQuantizer,
+)
+from transformer_engine.pytorch.tensor.mxfp8_tensor import MXFP8Quantizer
 import transformer_engine.pytorch.ops as te_ops
 from transformer_engine.pytorch.ops._common import is_float8_tensor
 from transformer_engine.pytorch.utils import is_bf16_compatible
 import transformer_engine_torch as tex

+# Import utility functions
+_current_file = pathlib.Path(__file__).resolve()
+sys.path.append(str(_current_file.parent.parent))
+from utils import dtype_tols, make_recipe
+
 # Check what quantization schemes are supported
 fp8_available, reason_for_no_fp8 = FP8GlobalStateManager.is_fp8_available()
 mxfp8_available, reason_for_no_mxfp8 = FP8GlobalStateManager.is_mxfp8_available()
 quantization_list: list[Optional[str]] = [None]
 if fp8_available:
-    quantization_list.append("fp8")
+    quantization_list.extend(("fp8_delayed_scaling", "fp8_current_scaling"))
 if mxfp8_available:
     quantization_list.append("mxfp8")
@@ -63,11 +72,12 @@ def reset_rng(seed: int = 1234) -> None:
 @torch.no_grad()
 def make_reference_and_test_tensors(
     shape: int | Iterable[int],
+    quantization: Optional[str] = None,
     ref_dtype: torch.dtype = torch.float64,
     ref_device: torch.device = "cpu",
     test_dtype: torch.dtype = torch.float32,
     test_device: torch.device = "cuda",
-    test_is_fp8: bool = False,
+    test_is_quantized: bool = False,
     requires_grad: bool = True,
 ) -> tuple[torch.Tensor, torch.Tensor]:
     """Construct tensors with the same values
@@ -76,78 +86,55 @@ def make_reference_and_test_tensors(
     operations in high precision. The test tensor is intended for use
     in Transformer Engine operations.

+    If a quantization scheme is provided, the tensor values are
+    quantized so that they are representable.
+
     """

+    # Random reference tensor
     ref = torch.rand(shape, dtype=ref_dtype, device=ref_device)

+    # Construct test tensor from reference tensor
     test = ref.to(device=test_device, dtype=test_dtype)
-    if test_is_fp8:
+    if quantization is None:
+        if test_is_quantized:
+            raise ValueError("Quantization scheme not provided")
+        if test.data_ptr() == ref.data_ptr():
+            test = test.clone()
+    elif quantization in ("fp8", "fp8_delayed_scaling"):
         quantizer = Float8Quantizer(
-            scale=torch.ones(1, dtype=torch.float32, device=test_device),
+            scale=torch.ones(1, dtype=torch.float32, device=test_device).squeeze(),
             amax=torch.zeros(1, dtype=torch.float32, device=test_device),
             fp8_dtype=tex.DType.kFloat8E4M3,
         )
         test = quantizer(test)
-    elif test.data_ptr() == ref.data_ptr():
-        test = test.clone()
+    elif quantization == "fp8_current_scaling":
+        quantizer = Float8CurrentScalingQuantizer(
+            fp8_dtype=tex.DType.kFloat8E4M3,
+            device=test_device,
+        )
+        test = quantizer(test)
+    elif quantization == "mxfp8":
+        test = MXFP8Quantizer(fp8_dtype=tex.DType.kFloat8E4M3)(test)
+    else:
+        raise ValueError(f"Unsupported quantization scheme ({quantization})")
+    if isinstance(test, QuantizedTensor) and not test_is_quantized:
+        test = test.dequantize()

+    # Make sure reference and test tensors match each other
     ref.copy_(test)

     ref.requires_grad_(requires_grad)
     test.requires_grad_(requires_grad)
     return ref, test


-def dtype_tols(dtype: torch.dtype | tex.DType) -> dict[str, float]:
-    """Estimated numerical error for a datatype
-
-    Based on tolerances for torch.testing.assert_close.
-
-    """
-
-    # Transformer Engine dtypes
-    if isinstance(dtype, tex.DType):
-        if dtype == tex.DType.kFloat8E4M3:
-            return dict(rtol=0.125, atol=0.0675)  # epsilon = 0.0625
-        if dtype == tex.DType.kFloat8E5M2:
-            return dict(rtol=0.25, atol=0.125)  # epsilon = 0.152
-        dtype = {
-            tex.DType.kByte: torch.uint8,
-            tex.DType.kInt32: torch.int32,
-            tex.DType.kFloat32: torch.float32,
-            tex.DType.kFloat16: torch.half,
-            tex.DType.kBFloat16: torch.bfloat16,
-        }[dtype]
-
-    # PyTorch dtypes
-    if dtype == torch.float16:
-        return dict(rtol=1e-3, atol=1e-5)
-    if dtype == torch.bfloat16:
-        return dict(rtol=1.6e-2, atol=1e-5)
-    if dtype == torch.float32:
-        return dict(rtol=1.3e-6, atol=1e-5)
-    if dtype == torch.float64:
-        return dict(rtol=1e-7, atol=1e-7)
-    raise ValueError(f"Unsupported dtype ({dtype})")
-
-
-def make_recipe(name: Optional[str] = None) -> Optional[Recipe]:
-    """Make recipe for quantization scheme"""
-    if name is None:
-        return None
-    if name == "fp8":
-        return transformer_engine.common.recipe.DelayedScaling(
-            fp8_format=transformer_engine.common.recipe.Format.E4M3,
-        )
-    if name == "mxfp8":
-        return transformer_engine.common.recipe.MXFP8BlockScaling(
-            fp8_format=transformer_engine.common.recipe.Format.E4M3,
-        )
-    raise ValueError(f"Unsupported quantization scheme ({name})")
-
-
 def _test_all_reduce(
     *,
-    local_size: int = 17,
+    local_size: int = 32,
     dtype: torch.dtype = torch.float32,
     device: torch.device = "cuda",
-    fp8: bool = False,
+    quantization: Optional[str] = None,
 ) -> None:

     # Distributed process group
@@ -156,22 +143,25 @@ def _test_all_reduce(
     world_size = torch.distributed.get_world_size(process_group)

     # Tensor dimensions
-    in_shape = [world_size, local_size]
-    out_shape = [local_size]
+    in_shape = [world_size, local_size, local_size]
+    out_shape = [local_size, local_size]

     # Random data
     reset_rng()
+    with_quantization = quantization is not None
     x_ref, x_test = make_reference_and_test_tensors(
         in_shape,
+        quantization=quantization,
         test_dtype=dtype,
         test_device=device,
-        test_is_fp8=fp8,
+        test_is_quantized=with_quantization,
     )
     dy_ref, dy_test = make_reference_and_test_tensors(
         out_shape,
+        quantization=quantization,
         test_dtype=dtype,
         test_device=device,
-        test_is_fp8=fp8,
+        test_is_quantized=with_quantization,
     )

     # Plain PyTorch implementation
@@ -199,10 +189,10 @@
 def _test_all_gather(
     *,
-    local_size: int = 13,
+    local_size: int = 32,
     dtype: torch.dtype = torch.float32,
     device: torch.device = "cuda",
-    fp8: bool = False,
+    quantization: Optional[str] = None,
 ) -> None:

     # Distributed process group
@@ -211,26 +201,29 @@ def _test_all_gather(
     world_size = torch.distributed.get_world_size(process_group)

     # Tensor dimensions
-    in_shape = [world_size, local_size]
-    out_shape = [world_size, world_size * local_size]
+    in_shape = [world_size, local_size, local_size]
+    out_shape = [world_size, world_size * local_size, local_size]

     # Random data
     reset_rng()
+    with_quantization = quantization is not None
     x_ref, x_test = make_reference_and_test_tensors(
         in_shape,
+        quantization=quantization,
         test_dtype=dtype,
         test_device=device,
-        test_is_fp8=fp8,
+        test_is_quantized=with_quantization,
     )
     dy_ref, dy_test = make_reference_and_test_tensors(
         out_shape,
+        quantization=quantization,
         test_dtype=dtype,
         test_device=device,
-        test_is_fp8=fp8,
+        test_is_quantized=with_quantization,
     )

     # Plain PyTorch implementation
-    y_ref = x_ref.tile((world_size, 1)).reshape(out_shape)
+    y_ref = x_ref.tile((world_size, 1, 1)).reshape(out_shape)
     y_ref.backward(dy_ref)

     # Convert to distributed tensors
@@ -257,10 +250,10 @@
 def _test_reduce_scatter(
     *,
-    local_size: int = 11,
+    local_size: int = 32,
     dtype: torch.dtype = torch.float32,
     device: torch.device = "cuda",
-    fp8: bool = False,
+    quantization: Optional[str] = None,
 ) -> None:

     # Distributed process group
@@ -269,22 +262,25 @@ def _test_reduce_scatter(
     world_size = torch.distributed.get_world_size(process_group)

     # Tensor dimensions
-    in_shape = [world_size, world_size * local_size]
-    out_shape = [world_size, local_size]
+    in_shape = [world_size, world_size * local_size, local_size]
+    out_shape = [world_size, local_size, local_size]

     # Random data
     reset_rng()
+    with_quantization = quantization is not None
     x_ref, x_test = make_reference_and_test_tensors(
         in_shape,
+        quantization=quantization,
         test_dtype=dtype,
         test_device=device,
-        test_is_fp8=fp8,
+        test_is_quantized=with_quantization,
     )
     dy_ref, dy_test = make_reference_and_test_tensors(
         out_shape,
+        quantization=quantization,
         test_dtype=dtype,
         test_device=device,
-        test_is_fp8=fp8,
+        test_is_quantized=with_quantization,
     )

     # Plain PyTorch implementation
@@ -324,7 +320,11 @@ def _test_basic_linear(
     tensor_parallel_mode: str = "column",
     sequence_parallel: bool = False,
 ) -> None:
+
+    # Skip invalid configurations
     quantized_compute = quantization is not None
+    if not quantized_compute and quantized_weight:
+        return

     # Distributed process group
     process_group = world_group()
@@ -348,30 +348,23 @@ def _test_basic_linear(
     reset_rng()
     x_ref, x_test = make_reference_and_test_tensors(
         in_shape,
+        quantization=quantization,
         test_dtype=dtype,
         test_device=device,
-        test_is_fp8=quantized_compute,
     )
-    if isinstance(x_test, QuantizedTensor):
-        with torch.no_grad():
-            x_test = x_test.dequantize().requires_grad_()
     w_ref, w_test = make_reference_and_test_tensors(
         (out_features, in_features),
+        quantization=quantization,
         test_dtype=dtype,
         test_device=device,
-        test_is_fp8=(quantized_compute or quantized_weight),
     )
-    if isinstance(w_test, QuantizedTensor):
-        w_test = w_test.dequantize()
     dy_ref, dy_test = make_reference_and_test_tensors(
         out_shape,
+        quantization=quantization,
         test_dtype=dtype,
         test_device=device,
-        test_is_fp8=quantized_compute,
         requires_grad=False,
     )
-    if isinstance(dy_test, QuantizedTensor):
-        dy_test = dy_test.dequantize()

     # Plain PyTorch implementation
     y_ref = torch.nn.functional.linear(x_ref, w_ref)
@@ -468,7 +461,11 @@ def _test_linear(
     tensor_parallel_mode: str = "column",
     sequence_parallel: bool = False,
 ) -> None:
+
+    # Skip invalid configurations
     quantized_compute = quantization is not None
+    if not quantized_compute and quantized_weight:
+        return

     # Distributed process group
     process_group = world_group()
@@ -492,21 +489,16 @@ def _test_linear(
     reset_rng()
     x_ref, x_test = make_reference_and_test_tensors(
         in_shape,
+        quantization=quantization,
         test_dtype=dtype,
         test_device=device,
-        test_is_fp8=quantized_compute,
     )
-    if isinstance(x_test, QuantizedTensor):
-        with torch.no_grad():
-            x_test = x_test.dequantize().requires_grad_()
     w_ref, w_test = make_reference_and_test_tensors(
         (out_features, in_features),
+        quantization=quantization,
         test_dtype=dtype,
         test_device=device,
-        test_is_fp8=(quantized_compute or quantized_weight),
     )
-    if isinstance(w_test, QuantizedTensor):
-        w_test = w_test.dequantize()
     b_ref, b_test = None, None
     if bias:
         if tensor_parallel_mode == "row":
@@ -520,13 +512,11 @@ def _test_linear(
         )
     dy_ref, dy_test = make_reference_and_test_tensors(
         out_shape,
+        quantization=quantization,
         test_dtype=dtype,
         test_device=device,
-        test_is_fp8=quantized_compute,
         requires_grad=False,
     )
-    if isinstance(dy_test, QuantizedTensor):
-        dy_test = dy_test.dequantize()

     # Plain PyTorch implementation
     y_ref = torch.nn.functional.linear(x_ref, w_ref)
@@ -773,9 +763,10 @@ def run_parallel_tests() -> None:
     if rank == 0:
         print(f"Running _test_all_reduce")
     _test_all_reduce()
-    if rank == 0:
-        print(f"Running _test_all_gather")
-    _test_all_gather()
+    for quantization in quantization_list:
+        if rank == 0:
+            print(f"Running _test_all_gather with quantization={quantization}")
+        _test_all_gather(quantization=quantization)
     if rank == 0:
         print(f"Running _test_reduce_scatter")
     _test_reduce_scatter()
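As a usage note for the reworked helper above: tests now pass a quantization scheme name and, by default, get back a high-precision pair whose values are exactly representable under that scheme. A hedged sketch (argument values are illustrative only):

    # Both tensors come back in high precision; the values have been
    # round-tripped through FP8 current scaling so they are representable.
    x_ref, x_test = make_reference_and_test_tensors(
        (32, 32),
        quantization="fp8_current_scaling",
        test_dtype=torch.bfloat16,
        test_device="cuda",
    )

    # Same values, but keep the test tensor in its quantized form.
    x_ref_q, x_test_q = make_reference_and_test_tensors(
        (32, 32),
        quantization="fp8_current_scaling",
        test_is_quantized=True,
    )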
...
@@ -26,21 +26,25 @@ from transformer_engine.pytorch.ops.fused import (
     UserbuffersBackwardLinear,
     UserbuffersForwardLinear,
 )
-from transformer_engine.pytorch.tensor.float8_tensor import Float8Quantizer
+from transformer_engine.pytorch.tensor.float8_tensor import (
+    Float8Quantizer,
+    Float8CurrentScalingQuantizer,
+)
+from transformer_engine.pytorch.tensor.mxfp8_tensor import MXFP8Quantizer
 from transformer_engine.pytorch.tensor.quantized_tensor import QuantizedTensor
 from transformer_engine.pytorch.utils import is_bf16_compatible

 # Import utility functions
 _current_file = pathlib.Path(__file__).resolve()
 sys.path.append(str(_current_file.parent.parent))
-from utils import dtype_tols, str_to_dtype
+from utils import dtype_tols, make_recipe, str_to_dtype

 # Check if FP8 is supported
 fp8_available, reason_for_no_fp8 = FP8GlobalStateManager.is_fp8_available()
 mxfp8_available, reason_for_no_mxfp8 = FP8GlobalStateManager.is_mxfp8_available()
 quantization_list: list[Optional[str]] = [None]
 if fp8_available:
-    quantization_list.append("fp8")
+    quantization_list.extend(("fp8_delayed_scaling", "fp8_current_scaling"))
 if mxfp8_available:
     quantization_list.append("mxfp8")
@@ -118,11 +122,12 @@ def reset_rng(seed: int = 1234) -> None:
 @torch.no_grad()
 def make_reference_and_test_tensors(
     shape: int | Iterable[int],
+    quantization: Optional[str] = None,
     ref_dtype: torch.dtype = torch.float64,
     ref_device: torch.device = "cpu",
     test_dtype: torch.dtype = torch.float32,
     test_device: torch.device = "cuda",
-    test_is_fp8: bool = False,
+    test_is_quantized: bool = False,
     requires_grad: bool = True,
 ) -> tuple[torch.Tensor, torch.Tensor]:
     """Construct tensors with the same values
@@ -131,47 +136,49 @@ def make_reference_and_test_tensors(
     operations in high precision. The test tensor is intended for use
     in Transformer Engine operations.

+    If a quantization scheme is provided, the tensor values are
+    quantized so that they are representable.
+
     """

-    # Random data
+    # Random reference tensor
     ref = torch.rand(shape, dtype=ref_dtype, device=ref_device)

-    # Make copy of tensor
+    # Construct test tensor from reference tensor
     test = ref.to(device=test_device, dtype=test_dtype)
-    if test_is_fp8:
+    if quantization is None:
+        if test_is_quantized:
+            raise ValueError("Quantization scheme not provided")
+        if test.data_ptr() == ref.data_ptr():
+            test = test.clone()
+    elif quantization in ("fp8", "fp8_delayed_scaling"):
         quantizer = Float8Quantizer(
-            scale=torch.ones(1, dtype=torch.float32, device=test_device),
+            scale=torch.ones(1, dtype=torch.float32, device=test_device).squeeze(),
             amax=torch.zeros(1, dtype=torch.float32, device=test_device),
             fp8_dtype=tex.DType.kFloat8E4M3,
         )
         test = quantizer(test)
-    elif test.data_ptr() == ref.data_ptr():
-        test = test.clone()
+    elif quantization == "fp8_current_scaling":
+        quantizer = Float8CurrentScalingQuantizer(
+            fp8_dtype=tex.DType.kFloat8E4M3,
+            device=test_device,
+        )
+        test = quantizer(test)
+    elif quantization == "mxfp8":
+        test = MXFP8Quantizer(fp8_dtype=tex.DType.kFloat8E4M3)(test)
+    else:
+        raise ValueError(f"Unsupported quantization scheme ({quantization})")
+    if isinstance(test, QuantizedTensor) and not test_is_quantized:
+        test = test.dequantize()

-    # Make sure reference and test tensors represent exact same values
+    # Make sure reference and test tensors match each other
     ref.copy_(test)

-    # Return reference and test tensors
     ref.requires_grad_(requires_grad)
     test.requires_grad_(requires_grad)
     return ref, test


-def make_recipe(name: Optional[str] = None) -> Optional[Recipe]:
-    """Make recipe for quantization scheme"""
-    if name is None:
-        return None
-    if name == "fp8":
-        return transformer_engine.common.recipe.DelayedScaling(
-            fp8_format=transformer_engine.common.recipe.Format.E4M3,
-        )
-    if name == "mxfp8":
-        return transformer_engine.common.recipe.MXFP8BlockScaling(
-            fp8_format=transformer_engine.common.recipe.Format.E4M3,
-        )
-    raise ValueError(f"Unsupported quantization scheme ({name})")
-
-
 def _test_linear(
     *,
     model_config: ModelConfig,
@@ -201,21 +208,16 @@ def _test_linear(
     reset_rng()
     x_ref, x_test = make_reference_and_test_tensors(
         in_shape,
+        quantization=quantization,
         test_dtype=dtype,
         test_device=device,
-        test_is_fp8=quantized_compute,
     )
-    if isinstance(x_test, QuantizedTensor):
-        with torch.no_grad():
-            x_test = x_test.dequantize().requires_grad_()
     w_ref, w_test = make_reference_and_test_tensors(
         (out_features, in_features),
+        quantization=quantization,
         test_dtype=dtype,
         test_device=device,
-        test_is_fp8=quantized_compute,
     )
-    if isinstance(w_test, QuantizedTensor):
-        w_test = w_test.dequantize()
     b_ref, b_test = None, None
     if bias:
         if tensor_parallel_mode == "row":
@@ -229,13 +231,11 @@ def _test_linear(
         )
     dy_ref, dy_test = make_reference_and_test_tensors(
         out_shape,
+        quantization=quantization,
         test_dtype=dtype,
         test_device=device,
-        test_is_fp8=quantized_compute,
         requires_grad=False,
     )
-    if isinstance(dy_test, QuantizedTensor):
-        dy_test = dy_test.dequantize()

     # Plain PyTorch implementation
     y_ref = torch.nn.functional.linear(x_ref, w_ref)
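The quantized-compute paths above get their recipe from make_recipe, which now lives in the shared utils module rather than in each test file (its old in-file definition is deleted above). A sketch of the presumed mapping for the new scheme; the Float8CurrentScaling recipe class is an assumption by analogy with the DelayedScaling and MXFP8BlockScaling branches that were removed:

    import transformer_engine.common.recipe
    import transformer_engine.pytorch as te

    # Assumed recipe for "fp8_current_scaling"; the real mapping is in the
    # shared utils.make_recipe, which is outside this diff.
    recipe = transformer_engine.common.recipe.Float8CurrentScaling(
        fp8_format=transformer_engine.common.recipe.Format.E4M3,
    )
    with te.fp8_autocast(enabled=True, fp8_recipe=recipe):
        y = model(x)  # hypothetical module under test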
...
@@ -7,6 +7,8 @@ from __future__ import annotations
 from collections.abc import Iterable
 import io
 import math
+import pathlib
+import sys
 from typing import Optional

 import pytest
@@ -24,10 +26,20 @@ from transformer_engine.pytorch.ops.fused import (
     ForwardLinearBiasAdd,
 )
 from transformer_engine.pytorch.tensor import QuantizedTensor
-from transformer_engine.pytorch.tensor.float8_tensor import Float8Tensor, Float8Quantizer
+from transformer_engine.pytorch.tensor.float8_tensor import (
+    Float8Tensor,
+    Float8CurrentScalingQuantizer,
+    Float8Quantizer,
+)
+from transformer_engine.pytorch.tensor.mxfp8_tensor import MXFP8Tensor, MXFP8Quantizer
 from transformer_engine.pytorch.utils import is_bf16_compatible
 import transformer_engine_torch as tex

+# Import utility functions
+_current_file = pathlib.Path(__file__).resolve()
+sys.path.append(str(_current_file.parent))
+from utils import dtype_tols, make_recipe
+
 # Check if FP8 is supported
 fp8_available, reason_for_no_fp8 = FP8GlobalStateManager.is_fp8_available()
 mxfp8_available, reason_for_no_mxfp8 = FP8GlobalStateManager.is_mxfp8_available()
@@ -40,6 +52,13 @@ if is_bf16_compatible():  # bf16 requires sm_80 or higher
 # Supported devices
 _devices: list[torch.device] = [torch.device("cpu"), torch.device("cuda")]

+# Supported quantization recipes
+_quantization_list: list[Optional[str]] = [None]
+if fp8_available:
+    _quantization_list.extend(("fp8_delayed_scaling", "fp8_current_scaling"))
+if mxfp8_available:
+    _quantization_list.append("mxfp8")
+

 def maybe_skip_quantization(
     quantization: Optional[str],
@@ -47,13 +66,14 @@ def maybe_skip_quantization(
     dims: Optional[Iterable[int] | int] = None,
     device: Optional[torch.device | str] = None,
 ) -> None:
+    """Skip test case if a quantization scheme is not supported"""

     # Don't skip if there is no quantization
     if quantization is None:
         return

     # Check if quantization scheme is supported
-    if quantization == "fp8" and not fp8_available:
+    if quantization in ("fp8", "fp8_delayed_scaling", "fp8_current_scaling") and not fp8_available:
         pytest.skip(reason_for_no_fp8)
     if quantization == "mxfp8" and not mxfp8_available:
         pytest.skip(reason_for_no_mxfp8)
@@ -61,7 +81,7 @@ def maybe_skip_quantization(
     if dims is not None:
         if not isinstance(dims, Iterable):
             dims = (dims,)
-        if quantization == "fp8":
+        if quantization in ("fp8", "fp8_delayed_scaling", "fp8_current_scaling"):
            if math.prod(dims[:-1]) % 16 != 0 or dims[-1] % 16 != 0:
                pytest.skip("FP8 GEMMs require dims that are divisible by 16")
        elif quantization == "mxfp8":
@@ -73,47 +93,15 @@
         pytest.skip("Quantization is only supported on CUDA devices")


-def dtype_tols(dtype: torch.dtype | tex.DType) -> dict[str, float]:
-    """Estimated numerical error for a datatype
-
-    Based on tolerances for torch.testing.assert_close.
-
-    """
-
-    # Transformer Engine dtypes
-    if isinstance(dtype, tex.DType):
-        if dtype == tex.DType.kFloat8E4M3:
-            return dict(rtol=0.125, atol=0.0675)  # epsilon = 0.0625
-        if dtype == tex.DType.kFloat8E5M2:
-            return dict(rtol=0.25, atol=0.125)  # epsilon = 0.152
-        dtype = {
-            tex.DType.kByte: torch.uint8,
-            tex.DType.kInt32: torch.int32,
-            tex.DType.kFloat32: torch.float32,
-            tex.DType.kFloat16: torch.half,
-            tex.DType.kBFloat16: torch.bfloat16,
-        }[dtype]
-
-    # PyTorch dtypes
-    if dtype == torch.float16:
-        return dict(rtol=1e-3, atol=1e-5)
-    if dtype == torch.bfloat16:
-        return dict(rtol=1.6e-2, atol=1e-5)
-    if dtype == torch.float32:
-        return dict(rtol=1.3e-6, atol=1e-5)
-    if dtype == torch.float64:
-        return dict(rtol=1e-7, atol=1e-7)
-    raise ValueError(f"Unsupported dtype ({dtype})")
-
-
 @torch.no_grad()
 def make_reference_and_test_tensors(
     shape: int | Iterable[int],
+    quantization: Optional[str] = None,
     ref_dtype: torch.dtype = torch.float64,
     ref_device: torch.device = "cpu",
     test_dtype: torch.dtype = torch.float32,
     test_device: torch.device = "cuda",
-    test_is_fp8: bool = False,
+    test_is_quantized: bool = False,
     requires_grad: bool = True,
 ) -> tuple[torch.Tensor, torch.Tensor]:
     """Construct tensors with the same values
@@ -122,39 +110,49 @@ def make_reference_and_test_tensors(
     operations in high precision. The test tensor is intended for use
     in Transformer Engine operations.

+    If a quantization scheme is provided, the tensor values are
+    quantized so that they are representable.
+
     """

+    # Random reference tensor
     ref = torch.rand(shape, dtype=ref_dtype, device=ref_device)

+    # Construct test tensor from reference tensor
     test = ref.to(device=test_device, dtype=test_dtype)
-    if test_is_fp8:
+    if quantization is None:
+        if test_is_quantized:
+            raise ValueError("Quantization scheme not provided")
+        if test.data_ptr() == ref.data_ptr():
+            test = test.clone()
+    elif quantization in ("fp8", "fp8_delayed_scaling"):
         quantizer = Float8Quantizer(
             scale=torch.ones(1, dtype=torch.float32, device=test_device).squeeze(),
             amax=torch.zeros(1, dtype=torch.float32, device=test_device),
             fp8_dtype=tex.DType.kFloat8E4M3,
         )
         test = quantizer(test)
-    elif test.data_ptr() == ref.data_ptr():
-        test = test.clone()
+    elif quantization == "fp8_current_scaling":
+        quantizer = Float8CurrentScalingQuantizer(
+            fp8_dtype=tex.DType.kFloat8E4M3,
+            device=test_device,
+        )
+        test = quantizer(test)
+    elif quantization == "mxfp8":
+        test = MXFP8Quantizer(fp8_dtype=tex.DType.kFloat8E4M3)(test)
+    else:
+        raise ValueError(f"Unsupported quantization scheme ({quantization})")
+    if isinstance(test, QuantizedTensor) and not test_is_quantized:
+        test = test.dequantize()

+    # Make sure reference and test tensors match each other
     ref.copy_(test)

     ref.requires_grad_(requires_grad)
     test.requires_grad_(requires_grad)
     return ref, test


-def make_recipe(name: Optional[str] = None) -> Optional[Recipe]:
-    """Make recipe for quantization scheme"""
-    if name is None:
-        return None
-    if name == "fp8":
-        return transformer_engine.common.recipe.DelayedScaling(
-            fp8_format=transformer_engine.common.recipe.Format.E4M3,
-        )
-    if name == "mxfp8":
-        return transformer_engine.common.recipe.MXFP8BlockScaling(
-            fp8_format=transformer_engine.common.recipe.Format.E4M3,
-        )
-    raise ValueError(f"Unsupported quantization scheme ({name})")
-
-
 class TestSequential:
     """Tests for sequential container"""
@@ -364,7 +362,7 @@ class TestFuser:
     @pytest.mark.parametrize("init_dtype", _dtypes)
     @pytest.mark.parametrize("final_dtype", _dtypes)
-    @pytest.mark.parametrize("quantization", (None, "fp8", "mxfp8"))
+    @pytest.mark.parametrize("quantization", _quantization_list)
     def test_dtype_cast(
         self,
         *,
@@ -377,8 +375,9 @@ class TestFuser:
         """Check dtype cast functions"""

         # Skip invalid configurations
-        maybe_skip_quantization(quantization, device=device)
+        in_shape = (size, size)
         with_quantization = quantization is not None
+        maybe_skip_quantization(quantization, dims=in_shape, device=device)

         # Random data
         dtype = torch.float32
@@ -388,9 +387,9 @@ class TestFuser:
             dtype = torch.bfloat16
         w_ref, w_test = make_reference_and_test_tensors(
             (size, size),
+            quantization=quantization,
             test_dtype=dtype,
             test_device=device,
-            test_is_fp8=with_quantization,
         )

         # Construct operation
@@ -412,11 +411,11 @@ class TestFuser:
         assert isinstance(op.weight, QuantizedTensor) == with_quantization
         assert op.weight.dtype == final_dtype
         w_test = op.weight.to(dtype=torch.float64, device="cpu")
-        torch.testing.assert_close(w_test, w_ref, rtol=0, atol=0)
+        torch.testing.assert_close(w_test, w_ref, **dtype_tols(dtype))

         # Check forward and backward pass
         x = torch.zeros(
-            (size, size),
+            in_shape,
             dtype=init_dtype,
             device=device,
             requires_grad=True,
@@ -429,7 +428,7 @@ class TestFuser:
     @pytest.mark.parametrize("model_dtype", _dtypes)
     @pytest.mark.parametrize("autocast_dtype", _dtypes)
-    @pytest.mark.parametrize("quantization", (None, "fp8", "mxfp8"))
+    @pytest.mark.parametrize("quantization", _quantization_list)
     def test_pyt_autocast(
         self,
         *,
@@ -444,8 +443,9 @@ class TestFuser:
         device = torch.device(device)

         # Skip invalid configurations
+        in_shape = (size, size)
         quantized_compute = quantization is not None
-        maybe_skip_quantization(quantization)
+        maybe_skip_quantization(quantization, dims=in_shape, device=device)

         # Construct operation
         recipe = make_recipe(quantization)
@@ -454,7 +454,7 @@ class TestFuser:

         # Check forward and backward pass
         x = torch.zeros(
-            (size, size),
+            in_shape,
             dtype=model_dtype,
             device=device,
             requires_grad=True,
@@ -492,33 +492,34 @@ class TestBasicOps:
     @pytest.mark.parametrize("dtype", _dtypes)
     @pytest.mark.parametrize("device", ("cuda", "cpu"))
-    @pytest.mark.parametrize("fp8", (False, True))
+    @pytest.mark.parametrize("quantization", _quantization_list)
     def test_identity(
         self,
         *,
-        in_shape: Iterable[int] = (1,),
+        in_shape: Iterable[int] = (32, 32),
         dtype: torch.dtype,
         device: torch.device,
-        fp8: bool,
+        quantization: Optional[str],
     ) -> None:

         # Skip invalid configurations
-        if fp8 and not fp8_available:
-            pytest.skip(reason_for_no_fp8)
-        if fp8 and torch.device(device).type != "cuda":
-            pytest.skip("FP8 is only supported on CUDA devices")
+        with_quantization = quantization is not None
+        maybe_skip_quantization(quantization, dims=in_shape, device=device)

         # Random data
         x_ref, x_test = make_reference_and_test_tensors(
             in_shape,
+            quantization=quantization,
             test_dtype=dtype,
             test_device=device,
-            test_is_fp8=fp8,
+            test_is_quantized=with_quantization,
         )
         dy_ref, dy_test = make_reference_and_test_tensors(
             in_shape,
+            quantization=quantization,
             test_dtype=dtype,
             test_device=device,
+            test_is_quantized=with_quantization,
             requires_grad=False,
         )
@@ -554,7 +555,7 @@ class TestBasicOps:
         ),
     )
     @pytest.mark.parametrize("dtype", _dtypes)
-    @pytest.mark.parametrize("fp8", (False, True))
+    @pytest.mark.parametrize("quantization", (None, "fp8_current_scaling"))
     def test_reshape(
         self,
         *,
@@ -562,31 +563,32 @@ class TestBasicOps:
         dtype: torch.dtype,
         device: torch.device = "cuda",
         memory_format: torch.memory_format = torch.contiguous_format,
-        fp8: bool,
+        quantization: Optional[str],
     ) -> None:
         in_shape, out_shape = shapes

         # Skip invalid configurations
         if memory_format == torch.channels_last and len(in_shape) != 4:
             pytest.skip("torch.channels_last only supports 4D tensors")
-        if fp8 and not fp8_available:
-            pytest.skip(reason_for_no_fp8)
-        if fp8 and torch.device(device).type != "cuda":
-            pytest.skip("FP8 is only supported on CUDA devices")
+        maybe_skip_quantization(quantization, device=device)
+        with_quantization = quantization is not None

         # Random data
         x_ref, x_test = make_reference_and_test_tensors(
             in_shape,
+            quantization=quantization,
             test_dtype=dtype,
             test_device=device,
-            test_is_fp8=fp8,
+            test_is_quantized=with_quantization,
         )
         x_test = x_test.contiguous(memory_format=memory_format)
         x_test = x_test.detach().requires_grad_()
         dy_ref, dy_test = make_reference_and_test_tensors(
             x_ref.reshape(out_shape).size(),
+            quantization=quantization,
             test_dtype=dtype,
             test_device=device,
+            test_is_quantized=with_quantization,
             requires_grad=False,
         )
@@ -615,10 +617,10 @@ class TestBasicOps:
         torch.testing.assert_close(dx_test, x_ref.grad, **tols)

     @pytest.mark.parametrize("size", (1, 7, 32))
-    @pytest.mark.parametrize("in_shape", ((-1,), (1, 3, -1), (2, 3, 4, -1)))
+    @pytest.mark.parametrize("in_shape", ((-1,), (1, 3, -1), (4, 3, 8, -1)))
     @pytest.mark.parametrize("dtype", _dtypes)
     @pytest.mark.parametrize("device", _devices)
-    @pytest.mark.parametrize("fp8", (False, True))
+    @pytest.mark.parametrize("quantization", _quantization_list)
     def test_bias(
         self,
         *,
@@ -626,24 +628,23 @@ class TestBasicOps:
         in_shape: Iterable[int],
         dtype: torch.dtype,
         device: torch.device,
-        fp8: bool,
+        quantization: Optional[str],
     ) -> None:

         # Make input and bias shapes consistent
         in_shape = list(in_shape)[:-1] + [size]

         # Skip invalid configurations
-        if fp8 and not fp8_available:
-            pytest.skip(reason_for_no_fp8)
-        if fp8 and torch.device(device).type != "cuda":
-            pytest.skip("FP8 is only supported on CUDA devices")
+        with_quantization = quantization is not None
+        maybe_skip_quantization(quantization, dims=in_shape, device=device)

         # Random data
         x_ref, x_test = make_reference_and_test_tensors(
             in_shape,
+            quantization=quantization,
             test_dtype=dtype,
             test_device=device,
-            test_is_fp8=fp8,
+            test_is_quantized=with_quantization,
         )
         b_ref, b_test = make_reference_and_test_tensors(
             size,
@@ -652,8 +653,10 @@ class TestBasicOps:
         )
         dy_ref, dy_test = make_reference_and_test_tensors(
             in_shape,
+            quantization=quantization,
             test_dtype=dtype,
             test_device=device,
+            test_is_quantized=with_quantization,
             requires_grad=False,
         )
@@ -678,7 +681,7 @@ class TestBasicOps:
         torch.testing.assert_close(dx_test, x_ref.grad, **tols)
         torch.testing.assert_close(db_test, b_ref.grad, **tols)

-    @pytest.mark.parametrize("quantization", ("fp8", "mxfp8"))
+    @pytest.mark.parametrize("quantization", _quantization_list)
     @pytest.mark.parametrize("cast_forward", (False, True))
     @pytest.mark.parametrize("cast_backward", (False, True))
     def test_quantize(
@@ -694,25 +697,26 @@ class TestBasicOps:
         """Quantize"""

         # Skip invalid configurations
-        maybe_skip_quantization(quantization)
+        with_quantization = quantization is not None
+        maybe_skip_quantization(quantization, device=device)
+        if quantization == "mxfp8":
+            maybe_skip_quantization(quantization, dims=in_shape)

         # Random data
         x_ref, x_test = make_reference_and_test_tensors(
             in_shape,
+            quantization=quantization,
             test_dtype=dtype,
             test_device=device,
-            requires_grad=False,
-            test_is_fp8=True,
+            requires_grad=True,
         )
-        x_test = x_test.dequantize().requires_grad_()
         dy_ref, dy_test = make_reference_and_test_tensors(
             in_shape,
+            quantization=quantization,
             test_dtype=dtype,
             test_device=device,
             requires_grad=False,
-            test_is_fp8=True,
         )
-        dy_test = dy_test.dequantize()

         # Plain PyTorch implementation
         y_ref = x_ref
@@ -721,13 +725,14 @@ class TestBasicOps:
         # Implementation with fusible operation
         op = te_ops.Quantize(forward=cast_forward, backward=cast_backward)
         recipe = make_recipe(quantization)
-        with te.fp8_autocast(fp8_recipe=recipe):
+        with te.fp8_autocast(enabled=with_quantization, fp8_recipe=recipe):
             y_test = op(x_test)
         y_test.backward(dy_test)

         # Check tensor types
-        assert isinstance(y_test, QuantizedTensor) == cast_forward
-        assert isinstance(x_test.grad, QuantizedTensor) == cast_backward
+        if with_quantization:
+            assert isinstance(y_test, QuantizedTensor) == cast_forward
+            assert isinstance(x_test.grad, QuantizedTensor) == cast_backward

         # Check values
         tols = dict(rtol=0, atol=0)
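For orientation, the pattern test_quantize follows above, reduced to a standalone hedged sketch (the DelayedScaling recipe is the one shown in the deleted make_recipe; any recipe from the test's quantization list would do):

    import torch
    import transformer_engine.pytorch as te
    import transformer_engine.pytorch.ops as te_ops
    import transformer_engine.common.recipe

    # Cast to FP8 in the forward pass only; gradients stay in high precision.
    op = te_ops.Quantize(forward=True, backward=False)
    recipe = transformer_engine.common.recipe.DelayedScaling(
        fp8_format=transformer_engine.common.recipe.Format.E4M3,
    )
    x = torch.randn(32, 32, device="cuda", requires_grad=True)
    with te.fp8_autocast(enabled=True, fp8_recipe=recipe):
        y = op(x)
    y.backward(torch.randn(32, 32, device="cuda"))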
@@ -762,10 +767,25 @@ class TestBasicOps:
         # Skip invalid configurations
         maybe_skip_quantization(quantization, dims=in_shape, device=device)
         maybe_skip_quantization(quantization, dims=out_shape)
-        if quantization == "fp8" and quantized_output and not quantized_compute:
-            pytest.skip("FP8 output is only supported with FP8 GEMMs")
-        if quantization == "fp8" and quantized_grad_input and not quantized_compute:
-            pytest.skip("FP8 grad input is only supported with FP8 GEMMs")
+        quantization_needed = any(
+            (
+                quantized_compute,
+                quantized_input,
+                quantized_weight,
+                quantized_output,
+                quantized_grad_output,
+                quantized_grad_input,
+            )
+        )
+        if quantization is None and quantization_needed:
+            pytest.skip("Quantization scheme is not specified")
+        if quantization is not None and not quantization_needed:
+            pytest.skip("Quantization scheme is not used")
+        if quantization in ("fp8", "fp8_delayed_scaling", "fp8_current_scaling"):
+            if quantized_output and not quantized_compute:
+                pytest.skip("FP8 output is only supported with FP8 GEMMs")
+            if quantized_grad_input and not quantized_compute:
+                pytest.skip("FP8 grad input is only supported with FP8 GEMMs")
         if quantization == "mxfp8" and quantized_output:
             pytest.skip("MXFP8 output is not supported with MXFP8 GEMMs")
         if quantization == "mxfp8" and quantized_grad_input:
@@ -774,28 +794,25 @@ class TestBasicOps:

         # Random data
         x_ref, x_test = make_reference_and_test_tensors(
             in_shape,
+            quantization=quantization,
             test_dtype=dtype,
             test_device=device,
-            test_is_fp8=(quantized_compute or quantized_input),
+            test_is_quantized=quantized_input,
         )
-        if isinstance(x_test, QuantizedTensor):
-            with torch.no_grad():
-                x_test = x_test.dequantize().requires_grad_()
         w_ref, w_test = make_reference_and_test_tensors(
             (out_features, in_features),
+            quantization=quantization,
             test_dtype=dtype,
             test_device=device,
-            test_is_fp8=(quantized_compute or quantized_weight),
         )
         dy_ref, dy_test = make_reference_and_test_tensors(
             out_shape,
+            quantization=quantization,
             test_dtype=dtype,
             test_device=device,
-            test_is_fp8=(quantized_compute or quantized_grad_output),
+            test_is_quantized=quantized_grad_output,
             requires_grad=False,
         )
-        if isinstance(dy_test, QuantizedTensor):
-            dy_test = dy_test.dequantize()

         # Plain PyTorch implementation
         y_ref = torch.nn.functional.linear(x_ref, w_ref)
@@ -858,7 +875,7 @@ class TestBasicOps:
     @pytest.mark.parametrize("weight_shape", ((64, 32), (3, 5)))
     @pytest.mark.parametrize("in_shape", ((-1,), (5, 1, -1), (4, 2, 4, -1)))
     @pytest.mark.parametrize("dtype", _dtypes)
-    @pytest.mark.parametrize("quantization", (None, "fp8", "mxfp8"))
+    @pytest.mark.parametrize("quantization", _quantization_list)
     @pytest.mark.parametrize("accumulate_into_main_grad", (False, True))
     def test_basic_linear(
         self,
@@ -880,7 +897,7 @@ class TestBasicOps:
         )

     @pytest.mark.skipif(not fp8_available, reason=reason_for_no_fp8)
-    @pytest.mark.parametrize("quantization", ("fp8", "mxfp8"))
+    @pytest.mark.parametrize("quantization", _quantization_list)
     @pytest.mark.parametrize("quantized_compute", (False, True))
     @pytest.mark.parametrize("quantized_input", (False, True))
     @pytest.mark.parametrize("quantized_weight", (False, True))
@@ -899,6 +916,8 @@ class TestBasicOps:
         quantized_grad_input: bool,
     ) -> None:
         """GEMM with FP8 inputs and outputs"""
+        if quantization is None:
+            pytest.skip("Skipping case without quantization")
         self._test_basic_linear(
             dtype=torch.bfloat16,
             quantization=quantization,
@@ -911,7 +930,8 @@ class TestBasicOps:
         )

     @pytest.mark.parametrize("bias", (False, True))
-    @pytest.mark.parametrize("quantization", (None, "fp8", "mxfp8"))
+    @pytest.mark.parametrize("quantization", _quantization_list)
+    @pytest.mark.parametrize("quantized_compute", (False, True))
     @pytest.mark.parametrize("quantized_weight", (False, True))
     @pytest.mark.parametrize("input_requires_grad", (False, True))
     @pytest.mark.parametrize("weight_requires_grad", (False, True))
@@ -924,6 +944,7 @@ class TestBasicOps:
         dtype: torch.dtype = torch.float32,
         device: torch.device = "cuda",
         quantization: Optional[str],
+        quantized_compute: bool,
         quantized_weight: bool,
         input_requires_grad: bool,
         weight_requires_grad: bool,
@@ -936,26 +957,25 @@ class TestBasicOps:
         out_shape = in_shape[:-1] + [out_features]

         # Skip invalid configurations
-        quantized_compute = quantization is not None
         maybe_skip_quantization(quantization, dims=in_shape, device=device)
         maybe_skip_quantization(quantization, dims=out_shape)
+        if quantization is None and (quantized_compute or quantized_weight):
+            pytest.skip("Quantization scheme is not specified")
+        if quantization is not None and not (quantized_compute or quantized_weight):
+            pytest.skip("Quantization scheme is not used")

         # Random data
         x_ref, x_test = make_reference_and_test_tensors(
             in_shape,
+            quantization=quantization,
             test_dtype=dtype,
             test_device=device,
-            test_is_fp8=quantized_compute,
         )
-        with torch.no_grad():
-            if isinstance(x_test, QuantizedTensor):
-                x_test = x_test.dequantize()
-            x_test.requires_grad_(requires_grad=input_requires_grad)
         w_ref, w_test = make_reference_and_test_tensors(
             (out_features, in_features),
+            quantization=quantization,
             test_dtype=dtype,
             test_device=device,
-            test_is_fp8=(quantized_compute or quantized_weight),
         )
         b_ref, b_test = None, None
         if bias:
@@ -966,6 +986,7 @@ class TestBasicOps:
         )
         dy_ref, dy_test = make_reference_and_test_tensors(
             out_shape,
+            quantization=quantization,
             test_dtype=dtype,
             test_device=device,
             requires_grad=False,
@@ -1022,7 +1043,7 @@ class TestBasicOps:
     @pytest.mark.parametrize("in_shape", ((-1,), (6, 16, -1)))
     @pytest.mark.parametrize("dtype", _dtypes)
     @pytest.mark.parametrize("zero_centered_gamma", (False, True))
-    @pytest.mark.parametrize("quantization", (None, "fp8", "mxfp8"))
+    @pytest.mark.parametrize("quantization", _quantization_list)
     def test_layer_norm(
         self,
         *,
@@ -1192,7 +1213,7 @@ class TestBasicOps:
     @pytest.mark.parametrize("in_shape", ((-1,), (6, 16, -1)))
     @pytest.mark.parametrize("dtype", _dtypes)
     @pytest.mark.parametrize("zero_centered_gamma", (False, True))
-    @pytest.mark.parametrize("quantization", (None, "fp8", "mxfp8"))
+    @pytest.mark.parametrize("quantization", _quantization_list)
     def test_rmsnorm(
         self,
         *,
@@ -1327,14 +1348,14 @@ class TestBasicOps:
     @pytest.mark.parametrize("dtype", _dtypes)
     @pytest.mark.parametrize("device", ("cuda", "cpu"))
-    @pytest.mark.parametrize("fp8", (False, True))
+    @pytest.mark.parametrize("quantization", _quantization_list)
     def test_add_in_place(
         self,
         *,
-        in_shape: Iterable[int] = (1,),
+        in_shape: Iterable[int] = (32, 32),
         dtype: torch.dtype,
         device: torch.device,
-        fp8: bool,
+        quantization: Optional[str],
     ) -> None:
         """Add two tensors
@@ -1343,28 +1364,30 @@ class TestBasicOps:
         """
         # Skip invalid configurations
-        if fp8 and not fp8_available:
-            pytest.skip(reason_for_no_fp8)
-        if fp8 and torch.device(device).type != "cuda":
-            pytest.skip("FP8 is only supported on CUDA devices")
+        with_quantization = quantization is not None
+        maybe_skip_quantization(quantization, dims=in_shape, device=device)

         # Random data
         x1_ref, x1_test = make_reference_and_test_tensors(
             in_shape,
+            quantization=quantization,
             test_dtype=dtype,
             test_device=device,
-            test_is_fp8=fp8,
+            test_is_quantized=with_quantization,
         )
         x2_ref, x2_test = make_reference_and_test_tensors(
             in_shape,
+            quantization=quantization,
             test_dtype=dtype,
             test_device=device,
-            test_is_fp8=fp8,
+            test_is_quantized=with_quantization,
         )
         dy_ref, dy_test = make_reference_and_test_tensors(
             in_shape,
+            quantization=quantization,
             test_dtype=dtype,
             test_device=device,
+            test_is_quantized=with_quantization,
             requires_grad=False,
         )
@@ -1381,7 +1404,7 @@ class TestBasicOps:
         # Check results
         tols = dtype_tols(dtype)
-        if fp8:
+        if with_quantization:
             tols = dtype_tols(x1_test._fp8_dtype)
         y_test = y_test.to(dtype=torch.float64, device="cpu")
         dx1_test = x1_test.grad.to(dtype=torch.float64, device="cpu")
@@ -1392,14 +1415,14 @@ class TestBasicOps:
     @pytest.mark.parametrize("dtype", _dtypes)
     @pytest.mark.parametrize("device", ("cuda", "cpu"))
-    @pytest.mark.parametrize("fp8", (False, True))
+    @pytest.mark.parametrize("quantization", _quantization_list)
     def test_make_extra_output(
         self,
         *,
-        in_shape: Iterable[int] = (1,),
+        in_shape: Iterable[int] = (32, 32),
         dtype: torch.dtype,
         device: torch.device,
-        fp8: bool,
+        quantization: Optional[str],
     ) -> None:
         """Output tensor twice
@@ -1408,28 +1431,31 @@ class TestBasicOps:
         """
         # Skip invalid configurations
-        if fp8 and not fp8_available:
-            pytest.skip(reason_for_no_fp8)
-        if fp8 and torch.device(device).type != "cuda":
-            pytest.skip("FP8 is only supported on CUDA devices")
+        with_quantization = quantization is not None
+        maybe_skip_quantization(quantization, dims=in_shape, device=device)

         # Random data
         x_ref, x_test = make_reference_and_test_tensors(
             in_shape,
+            quantization=quantization,
             test_dtype=dtype,
             test_device=device,
-            test_is_fp8=fp8,
+            test_is_quantized=with_quantization,
         )
         dy1_ref, dy1_test = make_reference_and_test_tensors(
             in_shape,
+            quantization=quantization,
             test_dtype=dtype,
             test_device=device,
+            test_is_quantized=with_quantization,
             requires_grad=False,
         )
         dy2_ref, dy2_test = make_reference_and_test_tensors(
             in_shape,
+            quantization=quantization,
             test_dtype=dtype,
             test_device=device,
+            test_is_quantized=with_quantization,
             requires_grad=False,
         )
@@ -1455,7 +1481,7 @@ class TestBasicOps:
     @pytest.mark.parametrize("activation", ("relu", "gelu", "geglu", "reglu", "swiglu"))
     @pytest.mark.parametrize("out_shape", ((37,), (2, 13), (32, 1, 32)))
     @pytest.mark.parametrize("dtype", _dtypes)
-    @pytest.mark.parametrize("quantization", (None, "fp8", "mxfp8"))
+    @pytest.mark.parametrize("quantization", _quantization_list)
     @pytest.mark.parametrize("cache_quantized_input", (False, True))
     def test_activation(
         self,
@@ -1478,26 +1504,21 @@ class TestBasicOps:
         quantized_compute = quantization is not None
         maybe_skip_quantization(quantization, dims=in_shape, device=device)
         if cache_quantized_input:
-            maybe_skip_quantization("fp8", device=device)
+            maybe_skip_quantization("fp8_current_scaling", device=device)

         # Random data
         x_ref, x_test = make_reference_and_test_tensors(
             in_shape,
+            quantization="fp8_current_scaling" if cache_quantized_input else None,
             test_dtype=dtype,
             test_device=device,
-            test_is_fp8=quantized_compute,
         )
         dy_ref, dy_test = make_reference_and_test_tensors(
             out_shape,
             test_dtype=dtype,
             test_device=device,
-            test_is_fp8=quantized_compute,
             requires_grad=False,
         )
-        if quantized_compute:
-            with torch.no_grad():
-                x_test = x_test.dequantize().requires_grad_()
-                dy_test = dy_test.dequantize()

         # Plain PyTorch implementation
         y_ref: torch.Tensor
@@ -1540,8 +1561,6 @@ class TestBasicOps:
         tols = dtype_tols(dtype)
         if quantized_compute or cache_quantized_input:
             tols = dtype_tols(tex.DType.kFloat8E4M3)
-        if activation == "relu" and not cache_quantized_input:
-            tols = {"atol": 0, "rtol": 0}

         # Check results
         y_test = y_test.to(dtype=torch.float64, device="cpu")
@@ -1550,7 +1569,7 @@ class TestBasicOps:
         torch.testing.assert_close(dx_test, x_ref.grad, **tols)

     @pytest.mark.parametrize("dtype", _dtypes)
-    @pytest.mark.parametrize("quantization", (None, "fp8", "mxfp8"))
+    @pytest.mark.parametrize("quantization", _quantization_list)
     @pytest.mark.parametrize("quantize_forward", (False, True))
     @pytest.mark.parametrize("quantize_backward", (False, True))
     def test_swiglu(
@@ -1628,7 +1647,7 @@ class TestFusedOps:
     @pytest.mark.parametrize("weight_shape", ((32, 64), (3, 5)))
     @pytest.mark.parametrize("in_shape", ((-1,), (1, 7, -1), (8, 2, 10, -1)))
     @pytest.mark.parametrize("dtype", _dtypes)
-    @pytest.mark.parametrize("quantization", (None, "fp8", "mxfp8"))
+    @pytest.mark.parametrize("quantization", _quantization_list)
     @pytest.mark.parametrize("quantized_weight", (False, True))
     def test_forward_linear_bias_activation(
         self,
@@ -1660,18 +1679,15 @@ class TestFusedOps:
         # Random data
         x_ref, x_test = make_reference_and_test_tensors(
             in_shape,
+            quantization=quantization,
             test_dtype=dtype,
             test_device=device,
-            test_is_fp8=quantized_compute,
         )
-        if quantized_compute:
-            with torch.no_grad():
-                x_test = x_test.dequantize().requires_grad_()
         w_ref, w_test = make_reference_and_test_tensors(
             (out_features, in_features),
+            quantization=quantization,
             test_dtype=dtype,
             test_device=device,
-            test_is_fp8=(quantized_compute or quantized_weight),
         )
         b_ref, b_test = None, None
         if bias:
@@ -1682,6 +1698,7 @@ class TestFusedOps:
         )
         dy_ref, dy_test = make_reference_and_test_tensors(
             out_shape,
+            quantization=quantization,
             test_dtype=dtype,
             test_device=device,
             requires_grad=False,
@@ -1738,7 +1755,7 @@ class TestFusedOps:
     @pytest.mark.parametrize("bias", (False, True))
     @pytest.mark.parametrize("dtype", _dtypes)
-    @pytest.mark.parametrize("quantization", (None, "fp8", "mxfp8"))
+    @pytest.mark.parametrize("quantization", _quantization_list)
     def test_forward_linear_bias_add(
         self,
         *,
@@ -1767,18 +1784,15 @@ class TestFusedOps:
         # Random data
         x1_ref, x1_test = make_reference_and_test_tensors(
             in_shape,
+            quantization=quantization,
             test_dtype=dtype,
             test_device=device,
-            test_is_fp8=quantized_compute,
         )
-        if isinstance(x1_test, QuantizedTensor):
-            with torch.no_grad():
-                x1_test = x1_test.dequantize().requires_grad_()
         w_ref, w_test = make_reference_and_test_tensors(
             (out_features, in_features),
+            quantization=quantization,
             test_dtype=dtype,
             test_device=device,
-            test_is_fp8=(quantized_compute or quantized_weight),
         )
         b_ref, b_test = None, None
         if bias:
@@ -1794,6 +1808,7 @@ class TestFusedOps:
         )
         dy_ref, dy_test = make_reference_and_test_tensors(
             out_shape,
+            quantization=quantization,
             test_dtype=dtype,
             test_device=device,
             requires_grad=False,
@@ -1852,7 +1867,7 @@ class TestFusedOps:
         torch.testing.assert_close(db_test, b_ref.grad, **tols)

     @pytest.mark.parametrize("dtype", _dtypes)
-    @pytest.mark.parametrize("quantization", (None, "fp8", "mxfp8"))
+    @pytest.mark.parametrize("quantization", _quantization_list)
     def test_backward_linear_add(
         self,
         *,
@@ -1880,27 +1895,26 @@ class TestFusedOps:
         # Random data
         x_ref, x_test = make_reference_and_test_tensors(
             in_shape,
+            quantization=quantization,
             test_dtype=dtype,
             test_device=device,
-            test_is_fp8=quantized_compute,
         )
-        if isinstance(x_test, QuantizedTensor):
-            with torch.no_grad():
-                x_test = x_test.dequantize().requires_grad_()
         w_ref, w_test = make_reference_and_test_tensors(
             (out_features, in_features),
+            quantization=quantization,
             test_dtype=dtype,
             test_device=device,
-            test_is_fp8=(quantized_compute or quantized_weight),
         )
         dy1_ref, dy1_test = make_reference_and_test_tensors(
             out_shape,
+            quantization=quantization,
             test_dtype=dtype,
             test_device=device,
             requires_grad=False,
         )
         dy2_ref, dy2_test = make_reference_and_test_tensors(
             out_shape,
+            quantization=quantization,
             test_dtype=dtype,
             test_device=device,
             requires_grad=False,
@@ -1964,7 +1978,7 @@ class TestCheckpointing:
         torch.manual_seed(seed)
         torch.cuda.manual_seed(seed)

-    @pytest.mark.parametrize("quantization", (None, "fp8", "mxfp8"))
+    @pytest.mark.parametrize("quantization", _quantization_list)
     @pytest.mark.parametrize("quantized_weight", (False, True))
     def test_linear(
         self,
@@ -7,6 +7,7 @@ from __future__ import annotations
 import torch

 import transformer_engine
+import transformer_engine.common.recipe
 import transformer_engine.pytorch as te
 import transformer_engine_torch as tex
@@ -83,3 +84,24 @@ def dtype_tols(dtype: torch.dtype | tex.DType) -> dict[str, float]:
     if dtype == torch.float8_e5m2:
         return dict(rtol=0.25, atol=0.125)  # epsilon = 0.152
     raise ValueError(f"Unsupported dtype ({dtype})")
+
+
+def make_recipe(name: Optional[str]) -> Optional[Recipe]:
+    """Make recipe for quantization scheme"""
+    if name is None:
+        return None
+    if name in ("fp8", "fp8_delayed_scaling"):
+        return transformer_engine.common.recipe.DelayedScaling(
+            fp8_format=transformer_engine.common.recipe.Format.E4M3,
+        )
+    if name == "fp8_current_scaling":
+        return transformer_engine.common.recipe.Float8CurrentScaling(
+            fp8_format=transformer_engine.common.recipe.Format.E4M3,
+        )
+    if name == "mxfp8":
+        return transformer_engine.common.recipe.MXFP8BlockScaling(
+            fp8_format=transformer_engine.common.recipe.Format.E4M3,
+        )
+    if name == "fp8_block_scaling":
+        return transformer_engine.common.recipe.Float8BlockScaling()
+    raise ValueError(f"Unsupported quantization scheme ({name})")
@@ -947,7 +947,7 @@ def _all_gather_fp8(
         out = quantizer.make_empty(out_shape, dtype=dtype, device=device)
     elif isinstance(inp, Float8Tensor):
         out = inp.make_like(inp, shape=out_shape)
-        out._data = torch.empty_like(
+        out._data = torch.empty(
             out_shape,
             dtype=torch.uint8,
             device=inp.device,
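This one-word fix matters: torch.empty_like expects a tensor as its first argument, so passing out_shape (a sequence of ints) raises a TypeError, while torch.empty allocates directly from a size. A quick illustration:

shape = (4, 8)
torch.empty(shape, dtype=torch.uint8)       # OK: empty takes a size
torch.empty_like(shape, dtype=torch.uint8)  # TypeError: expected a Tensor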
@@ -22,7 +22,7 @@ from ...distributed import (
 from ...fp8 import FP8GlobalStateManager
 from ...module.base import _2X_ACC_FPROP, _2X_ACC_DGRAD, _2X_ACC_WGRAD
 from ...tensor import Quantizer, QuantizedTensor
-from ...tensor.float8_tensor import Float8Quantizer
+from ...tensor.float8_tensor import Float8Quantizer, Float8CurrentScalingQuantizer
 from ...tensor.float8_blockwise_tensor import Float8BlockQuantizer
 from ...tensor.mxfp8_tensor import MXFP8Quantizer
 from ...tensor._internal.float8_tensor_base import Float8TensorBase
@@ -324,12 +324,38 @@ class BasicLinear(BasicOperation):
         weight_quantizer.set_usage(rowwise=True, columnwise=is_grad_enabled)
         grad_output_quantizer.set_usage(rowwise=True, columnwise=weight_requires_grad)

+        # Recipe-specific configuration
+        recipe = FP8GlobalStateManager.get_fp8_recipe()
+        if recipe.float8_current_scaling():
+            if any(
+                not isinstance(q, Float8CurrentScalingQuantizer)
+                for q in (input_quantizer, weight_quantizer, grad_output_quantizer)
+            ):
+                raise RuntimeError(
+                    "FP8 current-scaling recipe is enabled, "
+                    f"but input quantizer is {input_quantizer.__class__.__name__}, "
+                    f"weight quantizer is {weight_quantizer.__class__.__name__}, "
+                    f"grad output quantizer is {grad_output_quantizer.__class__.__name__}"
+                )
+            input_quantizer.force_pow_2_scales = recipe.fp8_quant_fwd_inp.power_2_scale
+            input_quantizer.amax_epsilon = recipe.fp8_quant_fwd_inp.amax_epsilon
+            weight_quantizer.force_pow_2_scales = recipe.fp8_quant_fwd_weight.power_2_scale
+            weight_quantizer.amax_epsilon = recipe.fp8_quant_fwd_weight.amax_epsilon
+            grad_output_quantizer.force_pow_2_scales = recipe.fp8_quant_bwd_grad.power_2_scale
+            grad_output_quantizer.amax_epsilon = recipe.fp8_quant_bwd_grad.amax_epsilon
+            if self.sequence_parallel and self.tensor_parallel_mode == "column":
+                input_quantizer.with_amax_reduction = True
+                input_quantizer.amax_reduction_group = self.tensor_parallel_group
+            if self.sequence_parallel and self.tensor_parallel_mode == "row":
+                grad_output_quantizer.with_amax_reduction = True
+                grad_output_quantizer.amax_reduction_group = self.tensor_parallel_group
+
         # Make sure weight tensor has correct quantizer
         # Note: Quantizer might have changed if quantization
         # recipe changed
-        if isinstance(weight_quantizer, Float8Quantizer) and isinstance(
-            weight, Float8TensorBase
-        ):
+        if isinstance(
+            weight_quantizer, (Float8Quantizer, Float8CurrentScalingQuantizer)
+        ) and isinstance(weight, Float8TensorBase):
             weight._quantizer = weight_quantizer

     @staticmethod
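The per-tensor knobs read above come from the recipe's QParams fields. A rough sketch of a recipe that would exercise them, assuming the QParams dataclass and Float8CurrentScaling fields from transformer_engine.common.recipe; the values are illustrative:

from transformer_engine.common.recipe import Float8CurrentScaling, Format, QParams

# Force power-of-two scales and a nonzero amax floor for forward inputs only
recipe = Float8CurrentScaling(
    fp8_format=Format.E4M3,
    fp8_quant_fwd_inp=QParams(power_2_scale=True, amax_epsilon=1e-12),
    fp8_quant_fwd_weight=QParams(power_2_scale=False, amax_epsilon=0.0),
    fp8_quant_bwd_grad=QParams(power_2_scale=False, amax_epsilon=0.0),
)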
@@ -21,7 +21,7 @@ from ...module.base import (
     _2X_ACC_FPROP,
 )
 from ...tensor.quantized_tensor import QuantizedTensorBase, Quantizer
-from ...tensor.float8_tensor import Float8Quantizer
+from ...tensor.float8_tensor import Float8Quantizer, Float8CurrentScalingQuantizer
 from ...tensor._internal.float8_tensor_base import Float8TensorBase
 from ...utils import canonicalize_device, canonicalize_dtype
 from ..basic import BasicLinear, Bias, ReduceScatter
@@ -208,7 +208,9 @@ class UserbuffersForwardLinear(FusedOperation):
         if input_quantizer is not None:
             if not isinstance(x_local, QuantizedTensorBase):
                 input_quantizer.set_usage(rowwise=True, columnwise=weight_requires_grad)
-                if isinstance(input_quantizer, Float8Quantizer):
+                if isinstance(
+                    input_quantizer, (Float8Quantizer, Float8CurrentScalingQuantizer)
+                ):
                     input_quantizer.set_usage(columnwise=False)
                 x_local = input_quantizer(x_local)
                 input_quantizer.set_usage(rowwise=True, columnwise=False)
@@ -327,8 +329,10 @@ class UserbuffersForwardLinear(FusedOperation):
         grad_input_quantizer = None
         if with_quantized_compute:
             recipe = FP8GlobalStateManager.get_fp8_recipe()
-            if not recipe.delayed() and not recipe.mxfp8():
-                raise RuntimeError("Userbuffers is only supported with FP8 delayed scaling recipe")
+            if not any((recipe.delayed(), recipe.float8_current_scaling(), recipe.mxfp8())):
+                raise RuntimeError(
+                    f"Unsupported recipe for Userbuffers ({recipe.__class__.__name__})"
+                )
             input_quantizer = linear_op.get_quantizer("forward", 0)
             weight_quantizer = linear_op.get_quantizer("forward", 1)
             grad_output_quantizer = linear_op.get_quantizer("backward", 0)