Unverified Commit e963e4a9 authored by Tim Moon's avatar Tim Moon Committed by GitHub
Browse files

[PyTorch] Add support for FP8 current scaling in operation-based API (#1858)



* Add FP8 current scaling to te.Sequential tests
Signed-off-by: default avatarTim Moon <tmoon@nvidia.com>

* Helper function for test/ref tensors does not produce quantized tensor by default
Signed-off-by: default avatarTim Moon <tmoon@nvidia.com>

* Add FP8 current scaling to distributed te.Sequential tests
Signed-off-by: default avatarTim Moon <tmoon@nvidia.com>

* Add FP8 current scaling to Userbuffers te.Sequential tests
Signed-off-by: default avatarTim Moon <tmoon@nvidia.com>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci



* Debug MXFP8 tests
Signed-off-by: default avatarTim Moon <tmoon@nvidia.com>

---------
Signed-off-by: default avatarTim Moon <tmoon@nvidia.com>
Co-authored-by: default avatarpre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
parent 655512c1
......@@ -22,19 +22,28 @@ import transformer_engine.common.recipe
import transformer_engine.pytorch as te
from transformer_engine.pytorch.fp8 import FP8GlobalStateManager
from transformer_engine.pytorch.tensor import QuantizedTensor
from transformer_engine.pytorch.tensor.float8_tensor import Float8Quantizer
from transformer_engine.pytorch.tensor.float8_tensor import (
Float8Quantizer,
Float8CurrentScalingQuantizer,
)
from transformer_engine.pytorch.tensor.mxfp8_tensor import MXFP8Quantizer
import transformer_engine.pytorch.ops as te_ops
from transformer_engine.pytorch.ops._common import is_float8_tensor
from transformer_engine.pytorch.utils import is_bf16_compatible
import transformer_engine_torch as tex
# Import utility functions
_current_file = pathlib.Path(__file__).resolve()
sys.path.append(str(_current_file.parent.parent))
from utils import dtype_tols, make_recipe
# Check what quantization schemes are supported
fp8_available, reason_for_no_fp8 = FP8GlobalStateManager.is_fp8_available()
mxfp8_available, reason_for_no_mxfp8 = FP8GlobalStateManager.is_mxfp8_available()
quantization_list: list[Optional[str]] = [None]
if fp8_available:
quantization_list.append("fp8")
quantization_list.extend(("fp8_delayed_scaling", "fp8_current_scaling"))
if mxfp8_available:
quantization_list.append("mxfp8")
......@@ -63,11 +72,12 @@ def reset_rng(seed: int = 1234) -> None:
@torch.no_grad()
def make_reference_and_test_tensors(
shape: int | Iterable[int],
quantization: Optional[str] = None,
ref_dtype: torch.dtype = torch.float64,
ref_device: torch.device = "cpu",
test_dtype: torch.dtype = torch.float32,
test_device: torch.device = "cuda",
test_is_fp8: bool = False,
test_is_quantized: bool = False,
requires_grad: bool = True,
) -> tuple[torch.Tensor, torch.Tensor]:
"""Construct tensors with the same values
......@@ -76,78 +86,55 @@ def make_reference_and_test_tensors(
operations in high precision. The test tensor is intended for use
in Transformer Engine operations.
If a quantization scheme is provided, the tensor values are
quantized so that they are representable.
"""
# Random reference tensor
ref = torch.rand(shape, dtype=ref_dtype, device=ref_device)
# Construct test tensor from reference tensor
test = ref.to(device=test_device, dtype=test_dtype)
if test_is_fp8:
if quantization is None:
if test_is_quantized:
raise ValueError("Quantization scheme not provided")
if test.data_ptr() == ref.data_ptr():
test = test.clone()
elif quantization in ("fp8", "fp8_delayed_scaling"):
quantizer = Float8Quantizer(
scale=torch.ones(1, dtype=torch.float32, device=test_device),
scale=torch.ones(1, dtype=torch.float32, device=test_device).squeeze(),
amax=torch.zeros(1, dtype=torch.float32, device=test_device),
fp8_dtype=tex.DType.kFloat8E4M3,
)
test = quantizer(test)
elif test.data_ptr() == ref.data_ptr():
test = test.clone()
elif quantization == "fp8_current_scaling":
quantizer = Float8CurrentScalingQuantizer(
fp8_dtype=tex.DType.kFloat8E4M3,
device=test_device,
)
test = quantizer(test)
elif quantization == "mxfp8":
test = MXFP8Quantizer(fp8_dtype=tex.DType.kFloat8E4M3)(test)
else:
raise ValueError(f"Unsupported quantization scheme ({quantization})")
if isinstance(test, QuantizedTensor) and not test_is_quantized:
test = test.dequantize()
# Make sure reference and test tensors match each other
ref.copy_(test)
ref.requires_grad_(requires_grad)
test.requires_grad_(requires_grad)
return ref, test
def dtype_tols(dtype: torch.dtype | tex.DType) -> dict[str, float]:
"""Estimated numerical error for a datatype
Based on tolerances for torch.testing.assert_close.
"""
# Transformer Engine dtypes
if isinstance(dtype, tex.DType):
if dtype == tex.DType.kFloat8E4M3:
return dict(rtol=0.125, atol=0.0675) # epsilon = 0.0625
if dtype == tex.DType.kFloat8E5M2:
return dict(rtol=0.25, atol=0.125) # epsilon = 0.152
dtype = {
tex.DType.kByte: torch.uint8,
tex.DType.kInt32: torch.int32,
tex.DType.kFloat32: torch.float32,
tex.DType.kFloat16: torch.half,
tex.DType.kBFloat16: torch.bfloat16,
}[dtype]
# PyTorch dtypes
if dtype == torch.float16:
return dict(rtol=1e-3, atol=1e-5)
if dtype == torch.bfloat16:
return dict(rtol=1.6e-2, atol=1e-5)
if dtype == torch.float32:
return dict(rtol=1.3e-6, atol=1e-5)
if dtype == torch.float64:
return dict(rtol=1e-7, atol=1e-7)
raise ValueError(f"Unsupported dtype ({dtype})")
def make_recipe(name: Optional[str] = None) -> Optional[Recipe]:
"""Make recipe for quantization scheme"""
if name is None:
return None
if name == "fp8":
return transformer_engine.common.recipe.DelayedScaling(
fp8_format=transformer_engine.common.recipe.Format.E4M3,
)
if name == "mxfp8":
return transformer_engine.common.recipe.MXFP8BlockScaling(
fp8_format=transformer_engine.common.recipe.Format.E4M3,
)
raise ValueError(f"Unsupported quantization scheme ({name})")
def _test_all_reduce(
*,
local_size: int = 17,
local_size: int = 32,
dtype: torch.dtype = torch.float32,
device: torch.device = "cuda",
fp8: bool = False,
quantization: Optional[str] = None,
) -> None:
# Distributed process group
......@@ -156,22 +143,25 @@ def _test_all_reduce(
world_size = torch.distributed.get_world_size(process_group)
# Tensor dimensions
in_shape = [world_size, local_size]
out_shape = [local_size]
in_shape = [world_size, local_size, local_size]
out_shape = [local_size, local_size]
# Random data
reset_rng()
with_quantization = quantization is not None
x_ref, x_test = make_reference_and_test_tensors(
in_shape,
quantization=quantization,
test_dtype=dtype,
test_device=device,
test_is_fp8=fp8,
test_is_quantized=with_quantization,
)
dy_ref, dy_test = make_reference_and_test_tensors(
out_shape,
quantization=quantization,
test_dtype=dtype,
test_device=device,
test_is_fp8=fp8,
test_is_quantized=with_quantization,
)
# Plain PyTorch implementation
......@@ -199,10 +189,10 @@ def _test_all_reduce(
def _test_all_gather(
*,
local_size: int = 13,
local_size: int = 32,
dtype: torch.dtype = torch.float32,
device: torch.device = "cuda",
fp8: bool = False,
quantization: Optional[str] = None,
) -> None:
# Distributed process group
......@@ -211,26 +201,29 @@ def _test_all_gather(
world_size = torch.distributed.get_world_size(process_group)
# Tensor dimensions
in_shape = [world_size, local_size]
out_shape = [world_size, world_size * local_size]
in_shape = [world_size, local_size, local_size]
out_shape = [world_size, world_size * local_size, local_size]
# Random data
reset_rng()
with_quantization = quantization is not None
x_ref, x_test = make_reference_and_test_tensors(
in_shape,
quantization=quantization,
test_dtype=dtype,
test_device=device,
test_is_fp8=fp8,
test_is_quantized=with_quantization,
)
dy_ref, dy_test = make_reference_and_test_tensors(
out_shape,
quantization=quantization,
test_dtype=dtype,
test_device=device,
test_is_fp8=fp8,
test_is_quantized=with_quantization,
)
# Plain PyTorch implementation
y_ref = x_ref.tile((world_size, 1)).reshape(out_shape)
y_ref = x_ref.tile((world_size, 1, 1)).reshape(out_shape)
y_ref.backward(dy_ref)
# Convert to distributed tensors
......@@ -257,10 +250,10 @@ def _test_all_gather(
def _test_reduce_scatter(
*,
local_size: int = 11,
local_size: int = 32,
dtype: torch.dtype = torch.float32,
device: torch.device = "cuda",
fp8: bool = False,
quantization: Optional[str] = None,
) -> None:
# Distributed process group
......@@ -269,22 +262,25 @@ def _test_reduce_scatter(
world_size = torch.distributed.get_world_size(process_group)
# Tensor dimensions
in_shape = [world_size, world_size * local_size]
out_shape = [world_size, local_size]
in_shape = [world_size, world_size * local_size, local_size]
out_shape = [world_size, local_size, local_size]
# Random data
reset_rng()
with_quantization = quantization is not None
x_ref, x_test = make_reference_and_test_tensors(
in_shape,
quantization=quantization,
test_dtype=dtype,
test_device=device,
test_is_fp8=fp8,
test_is_quantized=with_quantization,
)
dy_ref, dy_test = make_reference_and_test_tensors(
out_shape,
quantization=quantization,
test_dtype=dtype,
test_device=device,
test_is_fp8=fp8,
test_is_quantized=with_quantization,
)
# Plain PyTorch implementation
......@@ -324,7 +320,11 @@ def _test_basic_linear(
tensor_parallel_mode: str = "column",
sequence_parallel: bool = False,
) -> None:
# Skip invalid configurations
quantized_compute = quantization is not None
if not quantized_compute and quantized_weight:
return
# Distributed process group
process_group = world_group()
......@@ -348,30 +348,23 @@ def _test_basic_linear(
reset_rng()
x_ref, x_test = make_reference_and_test_tensors(
in_shape,
quantization=quantization,
test_dtype=dtype,
test_device=device,
test_is_fp8=quantized_compute,
)
if isinstance(x_test, QuantizedTensor):
with torch.no_grad():
x_test = x_test.dequantize().requires_grad_()
w_ref, w_test = make_reference_and_test_tensors(
(out_features, in_features),
quantization=quantization,
test_dtype=dtype,
test_device=device,
test_is_fp8=(quantized_compute or quantized_weight),
)
if isinstance(w_test, QuantizedTensor):
w_test = w_test.dequantize()
dy_ref, dy_test = make_reference_and_test_tensors(
out_shape,
quantization=quantization,
test_dtype=dtype,
test_device=device,
test_is_fp8=quantized_compute,
requires_grad=False,
)
if isinstance(dy_test, QuantizedTensor):
dy_test = dy_test.dequantize()
# Plain PyTorch implementation
y_ref = torch.nn.functional.linear(x_ref, w_ref)
......@@ -468,7 +461,11 @@ def _test_linear(
tensor_parallel_mode: str = "column",
sequence_parallel: bool = False,
) -> None:
# Skip invalid configurations
quantized_compute = quantization is not None
if not quantized_compute and quantized_weight:
return
# Distributed process group
process_group = world_group()
......@@ -492,21 +489,16 @@ def _test_linear(
reset_rng()
x_ref, x_test = make_reference_and_test_tensors(
in_shape,
quantization=quantization,
test_dtype=dtype,
test_device=device,
test_is_fp8=quantized_compute,
)
if isinstance(x_test, QuantizedTensor):
with torch.no_grad():
x_test = x_test.dequantize().requires_grad_()
w_ref, w_test = make_reference_and_test_tensors(
(out_features, in_features),
quantization=quantization,
test_dtype=dtype,
test_device=device,
test_is_fp8=(quantized_compute or quantized_weight),
)
if isinstance(w_test, QuantizedTensor):
w_test = w_test.dequantize()
b_ref, b_test = None, None
if bias:
if tensor_parallel_mode == "row":
......@@ -520,13 +512,11 @@ def _test_linear(
)
dy_ref, dy_test = make_reference_and_test_tensors(
out_shape,
quantization=quantization,
test_dtype=dtype,
test_device=device,
test_is_fp8=quantized_compute,
requires_grad=False,
)
if isinstance(dy_test, QuantizedTensor):
dy_test = dy_test.dequantize()
# Plain PyTorch implementation
y_ref = torch.nn.functional.linear(x_ref, w_ref)
......@@ -773,9 +763,10 @@ def run_parallel_tests() -> None:
if rank == 0:
print(f"Running _test_all_reduce")
_test_all_reduce()
for quantization in quantization_list:
if rank == 0:
print(f"Running _test_all_gather")
_test_all_gather()
print(f"Running _test_all_gather with quantization={quantization}")
_test_all_gather(quantization=quantization)
if rank == 0:
print(f"Running _test_reduce_scatter")
_test_reduce_scatter()
......
......@@ -26,21 +26,25 @@ from transformer_engine.pytorch.ops.fused import (
UserbuffersBackwardLinear,
UserbuffersForwardLinear,
)
from transformer_engine.pytorch.tensor.float8_tensor import Float8Quantizer
from transformer_engine.pytorch.tensor.float8_tensor import (
Float8Quantizer,
Float8CurrentScalingQuantizer,
)
from transformer_engine.pytorch.tensor.mxfp8_tensor import MXFP8Quantizer
from transformer_engine.pytorch.tensor.quantized_tensor import QuantizedTensor
from transformer_engine.pytorch.utils import is_bf16_compatible
# Import utility functions
_current_file = pathlib.Path(__file__).resolve()
sys.path.append(str(_current_file.parent.parent))
from utils import dtype_tols, str_to_dtype
from utils import dtype_tols, make_recipe, str_to_dtype
# Check if FP8 is supported
fp8_available, reason_for_no_fp8 = FP8GlobalStateManager.is_fp8_available()
mxfp8_available, reason_for_no_mxfp8 = FP8GlobalStateManager.is_mxfp8_available()
quantization_list: list[Optional[str]] = [None]
if fp8_available:
quantization_list.append("fp8")
quantization_list.extend(("fp8_delayed_scaling", "fp8_current_scaling"))
if mxfp8_available:
quantization_list.append("mxfp8")
......@@ -118,11 +122,12 @@ def reset_rng(seed: int = 1234) -> None:
@torch.no_grad()
def make_reference_and_test_tensors(
shape: int | Iterable[int],
quantization: Optional[str] = None,
ref_dtype: torch.dtype = torch.float64,
ref_device: torch.device = "cpu",
test_dtype: torch.dtype = torch.float32,
test_device: torch.device = "cuda",
test_is_fp8: bool = False,
test_is_quantized: bool = False,
requires_grad: bool = True,
) -> tuple[torch.Tensor, torch.Tensor]:
"""Construct tensors with the same values
......@@ -131,47 +136,49 @@ def make_reference_and_test_tensors(
operations in high precision. The test tensor is intended for use
in Transformer Engine operations.
If a quantization scheme is provided, the tensor values are
quantized so that they are representable.
"""
# Random data
# Random reference tensor
ref = torch.rand(shape, dtype=ref_dtype, device=ref_device)
# Make copy of tensor
# Construct test tensor from reference tensor
test = ref.to(device=test_device, dtype=test_dtype)
if test_is_fp8:
if quantization is None:
if test_is_quantized:
raise ValueError("Quantization scheme not provided")
if test.data_ptr() == ref.data_ptr():
test = test.clone()
elif quantization in ("fp8", "fp8_delayed_scaling"):
quantizer = Float8Quantizer(
scale=torch.ones(1, dtype=torch.float32, device=test_device),
scale=torch.ones(1, dtype=torch.float32, device=test_device).squeeze(),
amax=torch.zeros(1, dtype=torch.float32, device=test_device),
fp8_dtype=tex.DType.kFloat8E4M3,
)
test = quantizer(test)
elif test.data_ptr() == ref.data_ptr():
test = test.clone()
elif quantization == "fp8_current_scaling":
quantizer = Float8CurrentScalingQuantizer(
fp8_dtype=tex.DType.kFloat8E4M3,
device=test_device,
)
test = quantizer(test)
elif quantization == "mxfp8":
test = MXFP8Quantizer(fp8_dtype=tex.DType.kFloat8E4M3)(test)
else:
raise ValueError(f"Unsupported quantization scheme ({quantization})")
if isinstance(test, QuantizedTensor) and not test_is_quantized:
test = test.dequantize()
# Make sure reference and test tensors represent exact same values
# Make sure reference and test tensors match each other
ref.copy_(test)
# Return reference and test tensors
ref.requires_grad_(requires_grad)
test.requires_grad_(requires_grad)
return ref, test
def make_recipe(name: Optional[str] = None) -> Optional[Recipe]:
"""Make recipe for quantization scheme"""
if name is None:
return None
if name == "fp8":
return transformer_engine.common.recipe.DelayedScaling(
fp8_format=transformer_engine.common.recipe.Format.E4M3,
)
if name == "mxfp8":
return transformer_engine.common.recipe.MXFP8BlockScaling(
fp8_format=transformer_engine.common.recipe.Format.E4M3,
)
raise ValueError(f"Unsupported quantization scheme ({name})")
def _test_linear(
*,
model_config: ModelConfig,
......@@ -201,21 +208,16 @@ def _test_linear(
reset_rng()
x_ref, x_test = make_reference_and_test_tensors(
in_shape,
quantization=quantization,
test_dtype=dtype,
test_device=device,
test_is_fp8=quantized_compute,
)
if isinstance(x_test, QuantizedTensor):
with torch.no_grad():
x_test = x_test.dequantize().requires_grad_()
w_ref, w_test = make_reference_and_test_tensors(
(out_features, in_features),
quantization=quantization,
test_dtype=dtype,
test_device=device,
test_is_fp8=quantized_compute,
)
if isinstance(w_test, QuantizedTensor):
w_test = w_test.dequantize()
b_ref, b_test = None, None
if bias:
if tensor_parallel_mode == "row":
......@@ -229,13 +231,11 @@ def _test_linear(
)
dy_ref, dy_test = make_reference_and_test_tensors(
out_shape,
quantization=quantization,
test_dtype=dtype,
test_device=device,
test_is_fp8=quantized_compute,
requires_grad=False,
)
if isinstance(dy_test, QuantizedTensor):
dy_test = dy_test.dequantize()
# Plain PyTorch implementation
y_ref = torch.nn.functional.linear(x_ref, w_ref)
......
This diff is collapsed.
......@@ -7,6 +7,7 @@ from __future__ import annotations
import torch
import transformer_engine
import transformer_engine.common.recipe
import transformer_engine.pytorch as te
import transformer_engine_torch as tex
......@@ -83,3 +84,24 @@ def dtype_tols(dtype: torch.dtype | tex.DType) -> dict[str, float]:
if dtype == torch.float8_e5m2:
return dict(rtol=0.25, atol=0.125) # epsilon = 0.152
raise ValueError(f"Unsupported dtype ({dtype})")
def make_recipe(name: Optional[str]) -> Optional[Recipe]:
"""Make recipe for quantization scheme"""
if name is None:
return None
if name in ("fp8", "fp8_delayed_scaling"):
return transformer_engine.common.recipe.DelayedScaling(
fp8_format=transformer_engine.common.recipe.Format.E4M3,
)
if name == "fp8_current_scaling":
return transformer_engine.common.recipe.Float8CurrentScaling(
fp8_format=transformer_engine.common.recipe.Format.E4M3,
)
if name == "mxfp8":
return transformer_engine.common.recipe.MXFP8BlockScaling(
fp8_format=transformer_engine.common.recipe.Format.E4M3,
)
if name == "fp8_block_scaling":
return transformer_engine.common.recipe.Float8BlockScaling()
raise ValueError(f"Unsupported quantization scheme ({name})")
......@@ -947,7 +947,7 @@ def _all_gather_fp8(
out = quantizer.make_empty(out_shape, dtype=dtype, device=device)
elif isinstance(inp, Float8Tensor):
out = inp.make_like(inp, shape=out_shape)
out._data = torch.empty_like(
out._data = torch.empty(
out_shape,
dtype=torch.uint8,
device=inp.device,
......
......@@ -22,7 +22,7 @@ from ...distributed import (
from ...fp8 import FP8GlobalStateManager
from ...module.base import _2X_ACC_FPROP, _2X_ACC_DGRAD, _2X_ACC_WGRAD
from ...tensor import Quantizer, QuantizedTensor
from ...tensor.float8_tensor import Float8Quantizer
from ...tensor.float8_tensor import Float8Quantizer, Float8CurrentScalingQuantizer
from ...tensor.float8_blockwise_tensor import Float8BlockQuantizer
from ...tensor.mxfp8_tensor import MXFP8Quantizer
from ...tensor._internal.float8_tensor_base import Float8TensorBase
......@@ -324,12 +324,38 @@ class BasicLinear(BasicOperation):
weight_quantizer.set_usage(rowwise=True, columnwise=is_grad_enabled)
grad_output_quantizer.set_usage(rowwise=True, columnwise=weight_requires_grad)
# Recipe-specific configuration
recipe = FP8GlobalStateManager.get_fp8_recipe()
if recipe.float8_current_scaling():
if any(
not isinstance(q, Float8CurrentScalingQuantizer)
for q in (input_quantizer, weight_quantizer, grad_output_quantizer)
):
raise RuntimeError(
"FP8 current-scaling recipe is enabled, "
f"but input quantizer is {input_quantizer.__class__.__name__}, "
f"weight quantizer is {weight_quantizer.__class__.__name__}, "
f"grad output quantizer is {grad_output_quantizer.__class__.__name__}"
)
input_quantizer.force_pow_2_scales = recipe.fp8_quant_fwd_inp.power_2_scale
input_quantizer.amax_epsilon_scales = recipe.fp8_quant_fwd_inp.amax_epsilon
weight_quantizer.force_pow_2_scales = recipe.fp8_quant_fwd_inp.power_2_scale
weight_quantizer.amax_epsilon_scales = recipe.fp8_quant_fwd_inp.amax_epsilon
grad_output_quantizer.force_pow_2_scales = recipe.fp8_quant_fwd_inp.power_2_scale
grad_output_quantizer.amax_epsilon_scales = recipe.fp8_quant_fwd_inp.amax_epsilon
if self.sequence_parallel and self.tensor_parallel_mode == "column":
input_quantizer.with_amax_reduction = True
input_quantizer.amax_reduction_group = self.tensor_parallel_group
if self.sequence_parallel and self.tensor_parallel_mode == "row":
grad_output_quantizer.with_amax_reduction = True
grad_output_quantizer.amax_reduction_group = self.tensor_parallel_group
# Make sure weight tensor has correct quantizer
# Note: Quantizer might have changed if quantization
# recipe changed
if isinstance(weight_quantizer, Float8Quantizer) and isinstance(
weight, Float8TensorBase
):
if isinstance(
weight_quantizer, (Float8Quantizer, Float8CurrentScalingQuantizer)
) and isinstance(weight, Float8TensorBase):
weight._quantizer = weight_quantizer
@staticmethod
......
......@@ -21,7 +21,7 @@ from ...module.base import (
_2X_ACC_FPROP,
)
from ...tensor.quantized_tensor import QuantizedTensorBase, Quantizer
from ...tensor.float8_tensor import Float8Quantizer
from ...tensor.float8_tensor import Float8Quantizer, Float8CurrentScalingQuantizer
from ...tensor._internal.float8_tensor_base import Float8TensorBase
from ...utils import canonicalize_device, canonicalize_dtype
from ..basic import BasicLinear, Bias, ReduceScatter
......@@ -208,7 +208,9 @@ class UserbuffersForwardLinear(FusedOperation):
if input_quantizer is not None:
if not isinstance(x_local, QuantizedTensorBase):
input_quantizer.set_usage(rowwise=True, columnwise=weight_requires_grad)
if isinstance(input_quantizer, Float8Quantizer):
if isinstance(
input_quantizer, (Float8Quantizer, Float8CurrentScalingQuantizer)
):
input_quantizer.set_usage(columnwise=False)
x_local = input_quantizer(x_local)
input_quantizer.set_usage(rowwise=True, columnwise=False)
......@@ -327,8 +329,10 @@ class UserbuffersForwardLinear(FusedOperation):
grad_input_quantizer = None
if with_quantized_compute:
recipe = FP8GlobalStateManager.get_fp8_recipe()
if not recipe.delayed() and not recipe.mxfp8():
raise RuntimeError("Userbuffers is only supported with FP8 delayed scaling recipe")
if not any((recipe.delayed(), recipe.float8_current_scaling(), recipe.mxfp8())):
raise RuntimeError(
f"Unsupported recipe for Userbuffers ({recipe.__class__.__name__})"
)
input_quantizer = linear_op.get_quantizer("forward", 0)
weight_quantizer = linear_op.get_quantizer("forward", 1)
grad_output_quantizer = linear_op.get_quantizer("backward", 0)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment