Unverified Commit e963e4a9 authored by Tim Moon's avatar Tim Moon Committed by GitHub
Browse files

[PyTorch] Add support for FP8 current scaling in operation-based API (#1858)



* Add FP8 current scaling to te.Sequential tests
Signed-off-by: default avatarTim Moon <tmoon@nvidia.com>

* Helper function for test/ref tensors does not produce quantized tensor by default
Signed-off-by: default avatarTim Moon <tmoon@nvidia.com>

* Add FP8 current scaling to distributed te.Sequential tests
Signed-off-by: default avatarTim Moon <tmoon@nvidia.com>

* Add FP8 current scaling to Userbuffers te.Sequential tests
Signed-off-by: default avatarTim Moon <tmoon@nvidia.com>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci



* Debug MXFP8 tests
Signed-off-by: default avatarTim Moon <tmoon@nvidia.com>

---------
Signed-off-by: default avatarTim Moon <tmoon@nvidia.com>
Co-authored-by: default avatarpre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
parent 655512c1
...@@ -22,19 +22,28 @@ import transformer_engine.common.recipe ...@@ -22,19 +22,28 @@ import transformer_engine.common.recipe
import transformer_engine.pytorch as te import transformer_engine.pytorch as te
from transformer_engine.pytorch.fp8 import FP8GlobalStateManager from transformer_engine.pytorch.fp8 import FP8GlobalStateManager
from transformer_engine.pytorch.tensor import QuantizedTensor from transformer_engine.pytorch.tensor import QuantizedTensor
from transformer_engine.pytorch.tensor.float8_tensor import Float8Quantizer from transformer_engine.pytorch.tensor.float8_tensor import (
Float8Quantizer,
Float8CurrentScalingQuantizer,
)
from transformer_engine.pytorch.tensor.mxfp8_tensor import MXFP8Quantizer
import transformer_engine.pytorch.ops as te_ops import transformer_engine.pytorch.ops as te_ops
from transformer_engine.pytorch.ops._common import is_float8_tensor from transformer_engine.pytorch.ops._common import is_float8_tensor
from transformer_engine.pytorch.utils import is_bf16_compatible from transformer_engine.pytorch.utils import is_bf16_compatible
import transformer_engine_torch as tex import transformer_engine_torch as tex
# Import utility functions
_current_file = pathlib.Path(__file__).resolve()
sys.path.append(str(_current_file.parent.parent))
from utils import dtype_tols, make_recipe
# Check what quantization schemes are supported # Check what quantization schemes are supported
fp8_available, reason_for_no_fp8 = FP8GlobalStateManager.is_fp8_available() fp8_available, reason_for_no_fp8 = FP8GlobalStateManager.is_fp8_available()
mxfp8_available, reason_for_no_mxfp8 = FP8GlobalStateManager.is_mxfp8_available() mxfp8_available, reason_for_no_mxfp8 = FP8GlobalStateManager.is_mxfp8_available()
quantization_list: list[Optional[str]] = [None] quantization_list: list[Optional[str]] = [None]
if fp8_available: if fp8_available:
quantization_list.append("fp8") quantization_list.extend(("fp8_delayed_scaling", "fp8_current_scaling"))
if mxfp8_available: if mxfp8_available:
quantization_list.append("mxfp8") quantization_list.append("mxfp8")
...@@ -63,11 +72,12 @@ def reset_rng(seed: int = 1234) -> None: ...@@ -63,11 +72,12 @@ def reset_rng(seed: int = 1234) -> None:
@torch.no_grad() @torch.no_grad()
def make_reference_and_test_tensors( def make_reference_and_test_tensors(
shape: int | Iterable[int], shape: int | Iterable[int],
quantization: Optional[str] = None,
ref_dtype: torch.dtype = torch.float64, ref_dtype: torch.dtype = torch.float64,
ref_device: torch.device = "cpu", ref_device: torch.device = "cpu",
test_dtype: torch.dtype = torch.float32, test_dtype: torch.dtype = torch.float32,
test_device: torch.device = "cuda", test_device: torch.device = "cuda",
test_is_fp8: bool = False, test_is_quantized: bool = False,
requires_grad: bool = True, requires_grad: bool = True,
) -> tuple[torch.Tensor, torch.Tensor]: ) -> tuple[torch.Tensor, torch.Tensor]:
"""Construct tensors with the same values """Construct tensors with the same values
...@@ -76,78 +86,55 @@ def make_reference_and_test_tensors( ...@@ -76,78 +86,55 @@ def make_reference_and_test_tensors(
operations in high precision. The test tensor is intended for use operations in high precision. The test tensor is intended for use
in Transformer Engine operations. in Transformer Engine operations.
If a quantization scheme is provided, the tensor values are
quantized so that they are representable.
""" """
# Random reference tensor
ref = torch.rand(shape, dtype=ref_dtype, device=ref_device) ref = torch.rand(shape, dtype=ref_dtype, device=ref_device)
# Construct test tensor from reference tensor
test = ref.to(device=test_device, dtype=test_dtype) test = ref.to(device=test_device, dtype=test_dtype)
if test_is_fp8: if quantization is None:
if test_is_quantized:
raise ValueError("Quantization scheme not provided")
if test.data_ptr() == ref.data_ptr():
test = test.clone()
elif quantization in ("fp8", "fp8_delayed_scaling"):
quantizer = Float8Quantizer( quantizer = Float8Quantizer(
scale=torch.ones(1, dtype=torch.float32, device=test_device), scale=torch.ones(1, dtype=torch.float32, device=test_device).squeeze(),
amax=torch.zeros(1, dtype=torch.float32, device=test_device), amax=torch.zeros(1, dtype=torch.float32, device=test_device),
fp8_dtype=tex.DType.kFloat8E4M3, fp8_dtype=tex.DType.kFloat8E4M3,
) )
test = quantizer(test) test = quantizer(test)
elif test.data_ptr() == ref.data_ptr(): elif quantization == "fp8_current_scaling":
test = test.clone() quantizer = Float8CurrentScalingQuantizer(
fp8_dtype=tex.DType.kFloat8E4M3,
device=test_device,
)
test = quantizer(test)
elif quantization == "mxfp8":
test = MXFP8Quantizer(fp8_dtype=tex.DType.kFloat8E4M3)(test)
else:
raise ValueError(f"Unsupported quantization scheme ({quantization})")
if isinstance(test, QuantizedTensor) and not test_is_quantized:
test = test.dequantize()
# Make sure reference and test tensors match each other
ref.copy_(test) ref.copy_(test)
ref.requires_grad_(requires_grad) ref.requires_grad_(requires_grad)
test.requires_grad_(requires_grad) test.requires_grad_(requires_grad)
return ref, test return ref, test
def dtype_tols(dtype: torch.dtype | tex.DType) -> dict[str, float]:
"""Estimated numerical error for a datatype
Based on tolerances for torch.testing.assert_close.
"""
# Transformer Engine dtypes
if isinstance(dtype, tex.DType):
if dtype == tex.DType.kFloat8E4M3:
return dict(rtol=0.125, atol=0.0675) # epsilon = 0.0625
if dtype == tex.DType.kFloat8E5M2:
return dict(rtol=0.25, atol=0.125) # epsilon = 0.152
dtype = {
tex.DType.kByte: torch.uint8,
tex.DType.kInt32: torch.int32,
tex.DType.kFloat32: torch.float32,
tex.DType.kFloat16: torch.half,
tex.DType.kBFloat16: torch.bfloat16,
}[dtype]
# PyTorch dtypes
if dtype == torch.float16:
return dict(rtol=1e-3, atol=1e-5)
if dtype == torch.bfloat16:
return dict(rtol=1.6e-2, atol=1e-5)
if dtype == torch.float32:
return dict(rtol=1.3e-6, atol=1e-5)
if dtype == torch.float64:
return dict(rtol=1e-7, atol=1e-7)
raise ValueError(f"Unsupported dtype ({dtype})")
def make_recipe(name: Optional[str] = None) -> Optional[Recipe]:
"""Make recipe for quantization scheme"""
if name is None:
return None
if name == "fp8":
return transformer_engine.common.recipe.DelayedScaling(
fp8_format=transformer_engine.common.recipe.Format.E4M3,
)
if name == "mxfp8":
return transformer_engine.common.recipe.MXFP8BlockScaling(
fp8_format=transformer_engine.common.recipe.Format.E4M3,
)
raise ValueError(f"Unsupported quantization scheme ({name})")
def _test_all_reduce( def _test_all_reduce(
*, *,
local_size: int = 17, local_size: int = 32,
dtype: torch.dtype = torch.float32, dtype: torch.dtype = torch.float32,
device: torch.device = "cuda", device: torch.device = "cuda",
fp8: bool = False, quantization: Optional[str] = None,
) -> None: ) -> None:
# Distributed process group # Distributed process group
...@@ -156,22 +143,25 @@ def _test_all_reduce( ...@@ -156,22 +143,25 @@ def _test_all_reduce(
world_size = torch.distributed.get_world_size(process_group) world_size = torch.distributed.get_world_size(process_group)
# Tensor dimensions # Tensor dimensions
in_shape = [world_size, local_size] in_shape = [world_size, local_size, local_size]
out_shape = [local_size] out_shape = [local_size, local_size]
# Random data # Random data
reset_rng() reset_rng()
with_quantization = quantization is not None
x_ref, x_test = make_reference_and_test_tensors( x_ref, x_test = make_reference_and_test_tensors(
in_shape, in_shape,
quantization=quantization,
test_dtype=dtype, test_dtype=dtype,
test_device=device, test_device=device,
test_is_fp8=fp8, test_is_quantized=with_quantization,
) )
dy_ref, dy_test = make_reference_and_test_tensors( dy_ref, dy_test = make_reference_and_test_tensors(
out_shape, out_shape,
quantization=quantization,
test_dtype=dtype, test_dtype=dtype,
test_device=device, test_device=device,
test_is_fp8=fp8, test_is_quantized=with_quantization,
) )
# Plain PyTorch implementation # Plain PyTorch implementation
...@@ -199,10 +189,10 @@ def _test_all_reduce( ...@@ -199,10 +189,10 @@ def _test_all_reduce(
def _test_all_gather( def _test_all_gather(
*, *,
local_size: int = 13, local_size: int = 32,
dtype: torch.dtype = torch.float32, dtype: torch.dtype = torch.float32,
device: torch.device = "cuda", device: torch.device = "cuda",
fp8: bool = False, quantization: Optional[str] = None,
) -> None: ) -> None:
# Distributed process group # Distributed process group
...@@ -211,26 +201,29 @@ def _test_all_gather( ...@@ -211,26 +201,29 @@ def _test_all_gather(
world_size = torch.distributed.get_world_size(process_group) world_size = torch.distributed.get_world_size(process_group)
# Tensor dimensions # Tensor dimensions
in_shape = [world_size, local_size] in_shape = [world_size, local_size, local_size]
out_shape = [world_size, world_size * local_size] out_shape = [world_size, world_size * local_size, local_size]
# Random data # Random data
reset_rng() reset_rng()
with_quantization = quantization is not None
x_ref, x_test = make_reference_and_test_tensors( x_ref, x_test = make_reference_and_test_tensors(
in_shape, in_shape,
quantization=quantization,
test_dtype=dtype, test_dtype=dtype,
test_device=device, test_device=device,
test_is_fp8=fp8, test_is_quantized=with_quantization,
) )
dy_ref, dy_test = make_reference_and_test_tensors( dy_ref, dy_test = make_reference_and_test_tensors(
out_shape, out_shape,
quantization=quantization,
test_dtype=dtype, test_dtype=dtype,
test_device=device, test_device=device,
test_is_fp8=fp8, test_is_quantized=with_quantization,
) )
# Plain PyTorch implementation # Plain PyTorch implementation
y_ref = x_ref.tile((world_size, 1)).reshape(out_shape) y_ref = x_ref.tile((world_size, 1, 1)).reshape(out_shape)
y_ref.backward(dy_ref) y_ref.backward(dy_ref)
# Convert to distributed tensors # Convert to distributed tensors
...@@ -257,10 +250,10 @@ def _test_all_gather( ...@@ -257,10 +250,10 @@ def _test_all_gather(
def _test_reduce_scatter( def _test_reduce_scatter(
*, *,
local_size: int = 11, local_size: int = 32,
dtype: torch.dtype = torch.float32, dtype: torch.dtype = torch.float32,
device: torch.device = "cuda", device: torch.device = "cuda",
fp8: bool = False, quantization: Optional[str] = None,
) -> None: ) -> None:
# Distributed process group # Distributed process group
...@@ -269,22 +262,25 @@ def _test_reduce_scatter( ...@@ -269,22 +262,25 @@ def _test_reduce_scatter(
world_size = torch.distributed.get_world_size(process_group) world_size = torch.distributed.get_world_size(process_group)
# Tensor dimensions # Tensor dimensions
in_shape = [world_size, world_size * local_size] in_shape = [world_size, world_size * local_size, local_size]
out_shape = [world_size, local_size] out_shape = [world_size, local_size, local_size]
# Random data # Random data
reset_rng() reset_rng()
with_quantization = quantization is not None
x_ref, x_test = make_reference_and_test_tensors( x_ref, x_test = make_reference_and_test_tensors(
in_shape, in_shape,
quantization=quantization,
test_dtype=dtype, test_dtype=dtype,
test_device=device, test_device=device,
test_is_fp8=fp8, test_is_quantized=with_quantization,
) )
dy_ref, dy_test = make_reference_and_test_tensors( dy_ref, dy_test = make_reference_and_test_tensors(
out_shape, out_shape,
quantization=quantization,
test_dtype=dtype, test_dtype=dtype,
test_device=device, test_device=device,
test_is_fp8=fp8, test_is_quantized=with_quantization,
) )
# Plain PyTorch implementation # Plain PyTorch implementation
...@@ -324,7 +320,11 @@ def _test_basic_linear( ...@@ -324,7 +320,11 @@ def _test_basic_linear(
tensor_parallel_mode: str = "column", tensor_parallel_mode: str = "column",
sequence_parallel: bool = False, sequence_parallel: bool = False,
) -> None: ) -> None:
# Skip invalid configurations
quantized_compute = quantization is not None quantized_compute = quantization is not None
if not quantized_compute and quantized_weight:
return
# Distributed process group # Distributed process group
process_group = world_group() process_group = world_group()
...@@ -348,30 +348,23 @@ def _test_basic_linear( ...@@ -348,30 +348,23 @@ def _test_basic_linear(
reset_rng() reset_rng()
x_ref, x_test = make_reference_and_test_tensors( x_ref, x_test = make_reference_and_test_tensors(
in_shape, in_shape,
quantization=quantization,
test_dtype=dtype, test_dtype=dtype,
test_device=device, test_device=device,
test_is_fp8=quantized_compute,
) )
if isinstance(x_test, QuantizedTensor):
with torch.no_grad():
x_test = x_test.dequantize().requires_grad_()
w_ref, w_test = make_reference_and_test_tensors( w_ref, w_test = make_reference_and_test_tensors(
(out_features, in_features), (out_features, in_features),
quantization=quantization,
test_dtype=dtype, test_dtype=dtype,
test_device=device, test_device=device,
test_is_fp8=(quantized_compute or quantized_weight),
) )
if isinstance(w_test, QuantizedTensor):
w_test = w_test.dequantize()
dy_ref, dy_test = make_reference_and_test_tensors( dy_ref, dy_test = make_reference_and_test_tensors(
out_shape, out_shape,
quantization=quantization,
test_dtype=dtype, test_dtype=dtype,
test_device=device, test_device=device,
test_is_fp8=quantized_compute,
requires_grad=False, requires_grad=False,
) )
if isinstance(dy_test, QuantizedTensor):
dy_test = dy_test.dequantize()
# Plain PyTorch implementation # Plain PyTorch implementation
y_ref = torch.nn.functional.linear(x_ref, w_ref) y_ref = torch.nn.functional.linear(x_ref, w_ref)
...@@ -468,7 +461,11 @@ def _test_linear( ...@@ -468,7 +461,11 @@ def _test_linear(
tensor_parallel_mode: str = "column", tensor_parallel_mode: str = "column",
sequence_parallel: bool = False, sequence_parallel: bool = False,
) -> None: ) -> None:
# Skip invalid configurations
quantized_compute = quantization is not None quantized_compute = quantization is not None
if not quantized_compute and quantized_weight:
return
# Distributed process group # Distributed process group
process_group = world_group() process_group = world_group()
...@@ -492,21 +489,16 @@ def _test_linear( ...@@ -492,21 +489,16 @@ def _test_linear(
reset_rng() reset_rng()
x_ref, x_test = make_reference_and_test_tensors( x_ref, x_test = make_reference_and_test_tensors(
in_shape, in_shape,
quantization=quantization,
test_dtype=dtype, test_dtype=dtype,
test_device=device, test_device=device,
test_is_fp8=quantized_compute,
) )
if isinstance(x_test, QuantizedTensor):
with torch.no_grad():
x_test = x_test.dequantize().requires_grad_()
w_ref, w_test = make_reference_and_test_tensors( w_ref, w_test = make_reference_and_test_tensors(
(out_features, in_features), (out_features, in_features),
quantization=quantization,
test_dtype=dtype, test_dtype=dtype,
test_device=device, test_device=device,
test_is_fp8=(quantized_compute or quantized_weight),
) )
if isinstance(w_test, QuantizedTensor):
w_test = w_test.dequantize()
b_ref, b_test = None, None b_ref, b_test = None, None
if bias: if bias:
if tensor_parallel_mode == "row": if tensor_parallel_mode == "row":
...@@ -520,13 +512,11 @@ def _test_linear( ...@@ -520,13 +512,11 @@ def _test_linear(
) )
dy_ref, dy_test = make_reference_and_test_tensors( dy_ref, dy_test = make_reference_and_test_tensors(
out_shape, out_shape,
quantization=quantization,
test_dtype=dtype, test_dtype=dtype,
test_device=device, test_device=device,
test_is_fp8=quantized_compute,
requires_grad=False, requires_grad=False,
) )
if isinstance(dy_test, QuantizedTensor):
dy_test = dy_test.dequantize()
# Plain PyTorch implementation # Plain PyTorch implementation
y_ref = torch.nn.functional.linear(x_ref, w_ref) y_ref = torch.nn.functional.linear(x_ref, w_ref)
...@@ -773,9 +763,10 @@ def run_parallel_tests() -> None: ...@@ -773,9 +763,10 @@ def run_parallel_tests() -> None:
if rank == 0: if rank == 0:
print(f"Running _test_all_reduce") print(f"Running _test_all_reduce")
_test_all_reduce() _test_all_reduce()
for quantization in quantization_list:
if rank == 0: if rank == 0:
print(f"Running _test_all_gather") print(f"Running _test_all_gather with quantization={quantization}")
_test_all_gather() _test_all_gather(quantization=quantization)
if rank == 0: if rank == 0:
print(f"Running _test_reduce_scatter") print(f"Running _test_reduce_scatter")
_test_reduce_scatter() _test_reduce_scatter()
......
...@@ -26,21 +26,25 @@ from transformer_engine.pytorch.ops.fused import ( ...@@ -26,21 +26,25 @@ from transformer_engine.pytorch.ops.fused import (
UserbuffersBackwardLinear, UserbuffersBackwardLinear,
UserbuffersForwardLinear, UserbuffersForwardLinear,
) )
from transformer_engine.pytorch.tensor.float8_tensor import Float8Quantizer from transformer_engine.pytorch.tensor.float8_tensor import (
Float8Quantizer,
Float8CurrentScalingQuantizer,
)
from transformer_engine.pytorch.tensor.mxfp8_tensor import MXFP8Quantizer
from transformer_engine.pytorch.tensor.quantized_tensor import QuantizedTensor from transformer_engine.pytorch.tensor.quantized_tensor import QuantizedTensor
from transformer_engine.pytorch.utils import is_bf16_compatible from transformer_engine.pytorch.utils import is_bf16_compatible
# Import utility functions # Import utility functions
_current_file = pathlib.Path(__file__).resolve() _current_file = pathlib.Path(__file__).resolve()
sys.path.append(str(_current_file.parent.parent)) sys.path.append(str(_current_file.parent.parent))
from utils import dtype_tols, str_to_dtype from utils import dtype_tols, make_recipe, str_to_dtype
# Check if FP8 is supported # Check if FP8 is supported
fp8_available, reason_for_no_fp8 = FP8GlobalStateManager.is_fp8_available() fp8_available, reason_for_no_fp8 = FP8GlobalStateManager.is_fp8_available()
mxfp8_available, reason_for_no_mxfp8 = FP8GlobalStateManager.is_mxfp8_available() mxfp8_available, reason_for_no_mxfp8 = FP8GlobalStateManager.is_mxfp8_available()
quantization_list: list[Optional[str]] = [None] quantization_list: list[Optional[str]] = [None]
if fp8_available: if fp8_available:
quantization_list.append("fp8") quantization_list.extend(("fp8_delayed_scaling", "fp8_current_scaling"))
if mxfp8_available: if mxfp8_available:
quantization_list.append("mxfp8") quantization_list.append("mxfp8")
...@@ -118,11 +122,12 @@ def reset_rng(seed: int = 1234) -> None: ...@@ -118,11 +122,12 @@ def reset_rng(seed: int = 1234) -> None:
@torch.no_grad() @torch.no_grad()
def make_reference_and_test_tensors( def make_reference_and_test_tensors(
shape: int | Iterable[int], shape: int | Iterable[int],
quantization: Optional[str] = None,
ref_dtype: torch.dtype = torch.float64, ref_dtype: torch.dtype = torch.float64,
ref_device: torch.device = "cpu", ref_device: torch.device = "cpu",
test_dtype: torch.dtype = torch.float32, test_dtype: torch.dtype = torch.float32,
test_device: torch.device = "cuda", test_device: torch.device = "cuda",
test_is_fp8: bool = False, test_is_quantized: bool = False,
requires_grad: bool = True, requires_grad: bool = True,
) -> tuple[torch.Tensor, torch.Tensor]: ) -> tuple[torch.Tensor, torch.Tensor]:
"""Construct tensors with the same values """Construct tensors with the same values
...@@ -131,47 +136,49 @@ def make_reference_and_test_tensors( ...@@ -131,47 +136,49 @@ def make_reference_and_test_tensors(
operations in high precision. The test tensor is intended for use operations in high precision. The test tensor is intended for use
in Transformer Engine operations. in Transformer Engine operations.
If a quantization scheme is provided, the tensor values are
quantized so that they are representable.
""" """
# Random data # Random reference tensor
ref = torch.rand(shape, dtype=ref_dtype, device=ref_device) ref = torch.rand(shape, dtype=ref_dtype, device=ref_device)
# Make copy of tensor # Construct test tensor from reference tensor
test = ref.to(device=test_device, dtype=test_dtype) test = ref.to(device=test_device, dtype=test_dtype)
if test_is_fp8: if quantization is None:
if test_is_quantized:
raise ValueError("Quantization scheme not provided")
if test.data_ptr() == ref.data_ptr():
test = test.clone()
elif quantization in ("fp8", "fp8_delayed_scaling"):
quantizer = Float8Quantizer( quantizer = Float8Quantizer(
scale=torch.ones(1, dtype=torch.float32, device=test_device), scale=torch.ones(1, dtype=torch.float32, device=test_device).squeeze(),
amax=torch.zeros(1, dtype=torch.float32, device=test_device), amax=torch.zeros(1, dtype=torch.float32, device=test_device),
fp8_dtype=tex.DType.kFloat8E4M3, fp8_dtype=tex.DType.kFloat8E4M3,
) )
test = quantizer(test) test = quantizer(test)
elif test.data_ptr() == ref.data_ptr(): elif quantization == "fp8_current_scaling":
test = test.clone() quantizer = Float8CurrentScalingQuantizer(
fp8_dtype=tex.DType.kFloat8E4M3,
device=test_device,
)
test = quantizer(test)
elif quantization == "mxfp8":
test = MXFP8Quantizer(fp8_dtype=tex.DType.kFloat8E4M3)(test)
else:
raise ValueError(f"Unsupported quantization scheme ({quantization})")
if isinstance(test, QuantizedTensor) and not test_is_quantized:
test = test.dequantize()
# Make sure reference and test tensors represent exact same values # Make sure reference and test tensors match each other
ref.copy_(test) ref.copy_(test)
# Return reference and test tensors
ref.requires_grad_(requires_grad) ref.requires_grad_(requires_grad)
test.requires_grad_(requires_grad) test.requires_grad_(requires_grad)
return ref, test return ref, test
def make_recipe(name: Optional[str] = None) -> Optional[Recipe]:
"""Make recipe for quantization scheme"""
if name is None:
return None
if name == "fp8":
return transformer_engine.common.recipe.DelayedScaling(
fp8_format=transformer_engine.common.recipe.Format.E4M3,
)
if name == "mxfp8":
return transformer_engine.common.recipe.MXFP8BlockScaling(
fp8_format=transformer_engine.common.recipe.Format.E4M3,
)
raise ValueError(f"Unsupported quantization scheme ({name})")
def _test_linear( def _test_linear(
*, *,
model_config: ModelConfig, model_config: ModelConfig,
...@@ -201,21 +208,16 @@ def _test_linear( ...@@ -201,21 +208,16 @@ def _test_linear(
reset_rng() reset_rng()
x_ref, x_test = make_reference_and_test_tensors( x_ref, x_test = make_reference_and_test_tensors(
in_shape, in_shape,
quantization=quantization,
test_dtype=dtype, test_dtype=dtype,
test_device=device, test_device=device,
test_is_fp8=quantized_compute,
) )
if isinstance(x_test, QuantizedTensor):
with torch.no_grad():
x_test = x_test.dequantize().requires_grad_()
w_ref, w_test = make_reference_and_test_tensors( w_ref, w_test = make_reference_and_test_tensors(
(out_features, in_features), (out_features, in_features),
quantization=quantization,
test_dtype=dtype, test_dtype=dtype,
test_device=device, test_device=device,
test_is_fp8=quantized_compute,
) )
if isinstance(w_test, QuantizedTensor):
w_test = w_test.dequantize()
b_ref, b_test = None, None b_ref, b_test = None, None
if bias: if bias:
if tensor_parallel_mode == "row": if tensor_parallel_mode == "row":
...@@ -229,13 +231,11 @@ def _test_linear( ...@@ -229,13 +231,11 @@ def _test_linear(
) )
dy_ref, dy_test = make_reference_and_test_tensors( dy_ref, dy_test = make_reference_and_test_tensors(
out_shape, out_shape,
quantization=quantization,
test_dtype=dtype, test_dtype=dtype,
test_device=device, test_device=device,
test_is_fp8=quantized_compute,
requires_grad=False, requires_grad=False,
) )
if isinstance(dy_test, QuantizedTensor):
dy_test = dy_test.dequantize()
# Plain PyTorch implementation # Plain PyTorch implementation
y_ref = torch.nn.functional.linear(x_ref, w_ref) y_ref = torch.nn.functional.linear(x_ref, w_ref)
......
This diff is collapsed.
...@@ -7,6 +7,7 @@ from __future__ import annotations ...@@ -7,6 +7,7 @@ from __future__ import annotations
import torch import torch
import transformer_engine import transformer_engine
import transformer_engine.common.recipe
import transformer_engine.pytorch as te import transformer_engine.pytorch as te
import transformer_engine_torch as tex import transformer_engine_torch as tex
...@@ -83,3 +84,24 @@ def dtype_tols(dtype: torch.dtype | tex.DType) -> dict[str, float]: ...@@ -83,3 +84,24 @@ def dtype_tols(dtype: torch.dtype | tex.DType) -> dict[str, float]:
if dtype == torch.float8_e5m2: if dtype == torch.float8_e5m2:
return dict(rtol=0.25, atol=0.125) # epsilon = 0.152 return dict(rtol=0.25, atol=0.125) # epsilon = 0.152
raise ValueError(f"Unsupported dtype ({dtype})") raise ValueError(f"Unsupported dtype ({dtype})")
def make_recipe(name: Optional[str]) -> Optional[Recipe]:
"""Make recipe for quantization scheme"""
if name is None:
return None
if name in ("fp8", "fp8_delayed_scaling"):
return transformer_engine.common.recipe.DelayedScaling(
fp8_format=transformer_engine.common.recipe.Format.E4M3,
)
if name == "fp8_current_scaling":
return transformer_engine.common.recipe.Float8CurrentScaling(
fp8_format=transformer_engine.common.recipe.Format.E4M3,
)
if name == "mxfp8":
return transformer_engine.common.recipe.MXFP8BlockScaling(
fp8_format=transformer_engine.common.recipe.Format.E4M3,
)
if name == "fp8_block_scaling":
return transformer_engine.common.recipe.Float8BlockScaling()
raise ValueError(f"Unsupported quantization scheme ({name})")
...@@ -947,7 +947,7 @@ def _all_gather_fp8( ...@@ -947,7 +947,7 @@ def _all_gather_fp8(
out = quantizer.make_empty(out_shape, dtype=dtype, device=device) out = quantizer.make_empty(out_shape, dtype=dtype, device=device)
elif isinstance(inp, Float8Tensor): elif isinstance(inp, Float8Tensor):
out = inp.make_like(inp, shape=out_shape) out = inp.make_like(inp, shape=out_shape)
out._data = torch.empty_like( out._data = torch.empty(
out_shape, out_shape,
dtype=torch.uint8, dtype=torch.uint8,
device=inp.device, device=inp.device,
......
...@@ -22,7 +22,7 @@ from ...distributed import ( ...@@ -22,7 +22,7 @@ from ...distributed import (
from ...fp8 import FP8GlobalStateManager from ...fp8 import FP8GlobalStateManager
from ...module.base import _2X_ACC_FPROP, _2X_ACC_DGRAD, _2X_ACC_WGRAD from ...module.base import _2X_ACC_FPROP, _2X_ACC_DGRAD, _2X_ACC_WGRAD
from ...tensor import Quantizer, QuantizedTensor from ...tensor import Quantizer, QuantizedTensor
from ...tensor.float8_tensor import Float8Quantizer from ...tensor.float8_tensor import Float8Quantizer, Float8CurrentScalingQuantizer
from ...tensor.float8_blockwise_tensor import Float8BlockQuantizer from ...tensor.float8_blockwise_tensor import Float8BlockQuantizer
from ...tensor.mxfp8_tensor import MXFP8Quantizer from ...tensor.mxfp8_tensor import MXFP8Quantizer
from ...tensor._internal.float8_tensor_base import Float8TensorBase from ...tensor._internal.float8_tensor_base import Float8TensorBase
...@@ -324,12 +324,38 @@ class BasicLinear(BasicOperation): ...@@ -324,12 +324,38 @@ class BasicLinear(BasicOperation):
weight_quantizer.set_usage(rowwise=True, columnwise=is_grad_enabled) weight_quantizer.set_usage(rowwise=True, columnwise=is_grad_enabled)
grad_output_quantizer.set_usage(rowwise=True, columnwise=weight_requires_grad) grad_output_quantizer.set_usage(rowwise=True, columnwise=weight_requires_grad)
# Recipe-specific configuration
recipe = FP8GlobalStateManager.get_fp8_recipe()
if recipe.float8_current_scaling():
if any(
not isinstance(q, Float8CurrentScalingQuantizer)
for q in (input_quantizer, weight_quantizer, grad_output_quantizer)
):
raise RuntimeError(
"FP8 current-scaling recipe is enabled, "
f"but input quantizer is {input_quantizer.__class__.__name__}, "
f"weight quantizer is {weight_quantizer.__class__.__name__}, "
f"grad output quantizer is {grad_output_quantizer.__class__.__name__}"
)
input_quantizer.force_pow_2_scales = recipe.fp8_quant_fwd_inp.power_2_scale
input_quantizer.amax_epsilon_scales = recipe.fp8_quant_fwd_inp.amax_epsilon
weight_quantizer.force_pow_2_scales = recipe.fp8_quant_fwd_inp.power_2_scale
weight_quantizer.amax_epsilon_scales = recipe.fp8_quant_fwd_inp.amax_epsilon
grad_output_quantizer.force_pow_2_scales = recipe.fp8_quant_fwd_inp.power_2_scale
grad_output_quantizer.amax_epsilon_scales = recipe.fp8_quant_fwd_inp.amax_epsilon
if self.sequence_parallel and self.tensor_parallel_mode == "column":
input_quantizer.with_amax_reduction = True
input_quantizer.amax_reduction_group = self.tensor_parallel_group
if self.sequence_parallel and self.tensor_parallel_mode == "row":
grad_output_quantizer.with_amax_reduction = True
grad_output_quantizer.amax_reduction_group = self.tensor_parallel_group
# Make sure weight tensor has correct quantizer # Make sure weight tensor has correct quantizer
# Note: Quantizer might have changed if quantization # Note: Quantizer might have changed if quantization
# recipe changed # recipe changed
if isinstance(weight_quantizer, Float8Quantizer) and isinstance( if isinstance(
weight, Float8TensorBase weight_quantizer, (Float8Quantizer, Float8CurrentScalingQuantizer)
): ) and isinstance(weight, Float8TensorBase):
weight._quantizer = weight_quantizer weight._quantizer = weight_quantizer
@staticmethod @staticmethod
......
...@@ -21,7 +21,7 @@ from ...module.base import ( ...@@ -21,7 +21,7 @@ from ...module.base import (
_2X_ACC_FPROP, _2X_ACC_FPROP,
) )
from ...tensor.quantized_tensor import QuantizedTensorBase, Quantizer from ...tensor.quantized_tensor import QuantizedTensorBase, Quantizer
from ...tensor.float8_tensor import Float8Quantizer from ...tensor.float8_tensor import Float8Quantizer, Float8CurrentScalingQuantizer
from ...tensor._internal.float8_tensor_base import Float8TensorBase from ...tensor._internal.float8_tensor_base import Float8TensorBase
from ...utils import canonicalize_device, canonicalize_dtype from ...utils import canonicalize_device, canonicalize_dtype
from ..basic import BasicLinear, Bias, ReduceScatter from ..basic import BasicLinear, Bias, ReduceScatter
...@@ -208,7 +208,9 @@ class UserbuffersForwardLinear(FusedOperation): ...@@ -208,7 +208,9 @@ class UserbuffersForwardLinear(FusedOperation):
if input_quantizer is not None: if input_quantizer is not None:
if not isinstance(x_local, QuantizedTensorBase): if not isinstance(x_local, QuantizedTensorBase):
input_quantizer.set_usage(rowwise=True, columnwise=weight_requires_grad) input_quantizer.set_usage(rowwise=True, columnwise=weight_requires_grad)
if isinstance(input_quantizer, Float8Quantizer): if isinstance(
input_quantizer, (Float8Quantizer, Float8CurrentScalingQuantizer)
):
input_quantizer.set_usage(columnwise=False) input_quantizer.set_usage(columnwise=False)
x_local = input_quantizer(x_local) x_local = input_quantizer(x_local)
input_quantizer.set_usage(rowwise=True, columnwise=False) input_quantizer.set_usage(rowwise=True, columnwise=False)
...@@ -327,8 +329,10 @@ class UserbuffersForwardLinear(FusedOperation): ...@@ -327,8 +329,10 @@ class UserbuffersForwardLinear(FusedOperation):
grad_input_quantizer = None grad_input_quantizer = None
if with_quantized_compute: if with_quantized_compute:
recipe = FP8GlobalStateManager.get_fp8_recipe() recipe = FP8GlobalStateManager.get_fp8_recipe()
if not recipe.delayed() and not recipe.mxfp8(): if not any((recipe.delayed(), recipe.float8_current_scaling(), recipe.mxfp8())):
raise RuntimeError("Userbuffers is only supported with FP8 delayed scaling recipe") raise RuntimeError(
f"Unsupported recipe for Userbuffers ({recipe.__class__.__name__})"
)
input_quantizer = linear_op.get_quantizer("forward", 0) input_quantizer = linear_op.get_quantizer("forward", 0)
weight_quantizer = linear_op.get_quantizer("forward", 1) weight_quantizer = linear_op.get_quantizer("forward", 1)
grad_output_quantizer = linear_op.get_quantizer("backward", 0) grad_output_quantizer = linear_op.get_quantizer("backward", 0)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment