Commit 44740c6c authored by yuguo


Merge commit '7a9a0825' of https://github.com/NVIDIA/TransformerEngine
parents 8113d9e0 7a9a0825
......@@ -57,6 +57,9 @@ class BackwardLinearAdd(FusedOperation):
accumulate_into_main_grad = linear_op._accumulate_into_main_grad
grad_weight = None
if linear_op_ctx.weight_requires_grad and accumulate_into_main_grad:
if hasattr(linear_op.weight, "__fsdp_param__"):
linear_op.weight.main_grad = linear_op.weight.get_main_grad()
if not hasattr(linear_op.weight, "main_grad"):
raise RuntimeError(
"BasicLinear op is configured with "
......@@ -93,7 +96,6 @@ class BackwardLinearAdd(FusedOperation):
grad_weight = None
# Clear input tensor if possible
if linear_op_ctx.has_prev_op:
clear_tensor_data(x_local)
return grad_input, [(grad_weight,), ()], [(), ()]
......@@ -107,13 +109,13 @@ def fuse_backward_linear_add(
Parameters
----------
ops: list of tuples
Forward pass operations and the indices of the corresponding
Backward pass operations and the indices of the corresponding
basic operations.
Returns
-------
ops: list of tuples
Updated forward pass operations
Updated backward pass operations
"""
......
......@@ -13,11 +13,11 @@ import torch
from transformer_engine.pytorch.fp8 import FP8GlobalStateManager
from transformer_engine.pytorch.ops.basic import BasicLinear, Bias
from transformer_engine.pytorch.ops.op import (
BasicOperation,
FusedOperation,
FusibleOperation,
OperationContext,
)
from ...tensor import Quantizer
class ForwardLinearBiasActivation(FusedOperation):
......@@ -59,8 +59,8 @@ class ForwardLinearBiasActivation(FusedOperation):
input_: torch.Tensor,
*,
basic_op_extra_inputs: list[tuple[torch.Tensor, ...]],
basic_op_prev_ops: list[Optional[BasicOperation]],
basic_op_next_ops: list[Optional[BasicOperation]],
prev_op_grad_input_quantizer: Optional[Quantizer],
next_op_input_quantizer: Optional[Quantizer],
basic_op_kwargs: list[dict[str, Any]],
) -> tuple[torch.Tensor, Iterable[Iterable[torch.Tensor]]]:
......@@ -70,10 +70,12 @@ class ForwardLinearBiasActivation(FusedOperation):
linear_op_ctx = basic_op_ctxs[idx]
if self._op_idxs["bias"] is None:
bias_op = None
bias_op_ctx = None
bias = None
else:
idx = self._op_idxs["bias"]
bias_op = self.basic_ops[idx]
bias_op_ctx = basic_op_ctxs[idx]
bias = bias_op.bias
if basic_op_kwargs[idx]:
raise ValueError("Bias operation forward does not expect keyword arguments")
......@@ -83,7 +85,7 @@ class ForwardLinearBiasActivation(FusedOperation):
raise NotImplementedError("Activations are not yet supported")
# Check which grads are required
input_requires_grad = linear_op_ctx.requires_grad and input_.requires_grad
input_requires_grad = linear_op_ctx.requires_grad
weight_requires_grad = linear_op_ctx.requires_grad and linear_op.weight.requires_grad
# FP8 metadata
......@@ -96,18 +98,15 @@ class ForwardLinearBiasActivation(FusedOperation):
if with_quantized_compute:
input_quantizer = linear_op.get_quantizer("forward", 0)
weight_quantizer = linear_op.get_quantizer("forward", 1)
next_op = basic_op_next_ops[-1]
if next_op is not None and next_op.num_quantizers("forward") > 0:
output_quantizer = next_op.get_quantizer("forward", 0)
output_quantizer = next_op_input_quantizer
grad_output_quantizer = linear_op.get_quantizer("backward", 0)
prev_op = basic_op_prev_ops[0]
if prev_op is not None and prev_op.num_quantizers("backward") > 0:
grad_input_quantizer = prev_op.get_quantizer("backward", 0)
grad_input_quantizer = prev_op_grad_input_quantizer
# Get autocast dtype if needed
dtype = None
if torch.is_autocast_enabled():
dtype = torch.get_autocast_dtype("cuda")
else:
dtype = linear_op.weight.dtype
# Linear forward
output, x_local, w = BasicLinear._functional_forward(
......@@ -136,7 +135,9 @@ class ForwardLinearBiasActivation(FusedOperation):
linear_op_ctx.dtype = dtype
linear_op_ctx.input_requires_grad = input_requires_grad
linear_op_ctx.weight_requires_grad = weight_requires_grad
linear_op_ctx.has_prev_op = basic_op_prev_ops[0] is not None
if bias_op is not None:
bias_op_ctx.with_quantized_compute = with_quantized_compute
bias_op_ctx.grad_input_quantizer = linear_op.get_grad_input_quantizer()
return output, [() for _ in range(len(self.basic_ops))]
......
......@@ -13,11 +13,11 @@ import torch
from transformer_engine.pytorch.fp8 import FP8GlobalStateManager
from transformer_engine.pytorch.ops.basic import AddInPlace, BasicLinear, Bias
from transformer_engine.pytorch.ops.op import (
BasicOperation,
FusedOperation,
FusibleOperation,
OperationContext,
)
from transformer_engine.pytorch.tensor import Quantizer
class ForwardLinearBiasAdd(FusedOperation):
......@@ -57,8 +57,8 @@ class ForwardLinearBiasAdd(FusedOperation):
input_: torch.Tensor,
*,
basic_op_extra_inputs: list[tuple[torch.Tensor, ...]],
basic_op_prev_ops: list[Optional[BasicOperation]],
basic_op_next_ops: list[Optional[BasicOperation]],
prev_op_grad_input_quantizer: Optional[Quantizer],
next_op_input_quantizer: Optional[Quantizer],
basic_op_kwargs: list[dict[str, Any]],
) -> tuple[torch.Tensor, Iterable[Iterable[torch.Tensor]]]:
......@@ -68,16 +68,18 @@ class ForwardLinearBiasAdd(FusedOperation):
linear_op_ctx = basic_op_ctxs[idx]
if self._op_idxs["bias"] is None:
bias_op = None
bias_op_ctx = None
bias = None
else:
idx = self._op_idxs["bias"]
bias_op = self.basic_ops[idx]
bias_op_ctx = basic_op_ctxs[idx]
bias = bias_op.bias
if basic_op_kwargs[idx]:
raise ValueError("Bias operation forward does not expect keyword arguments")
# Check which grads are required
input_requires_grad = linear_op_ctx.requires_grad and input_.requires_grad
input_requires_grad = linear_op_ctx.requires_grad
weight_requires_grad = linear_op_ctx.requires_grad and linear_op.weight.requires_grad
# FP8 metadata
......@@ -91,14 +93,13 @@ class ForwardLinearBiasAdd(FusedOperation):
input_quantizer = linear_op.get_quantizer("forward", 0)
weight_quantizer = linear_op.get_quantizer("forward", 1)
grad_output_quantizer = linear_op.get_quantizer("backward", 0)
prev_op = basic_op_prev_ops[0]
if prev_op is not None and prev_op.num_quantizers("backward") > 0:
grad_input_quantizer = prev_op.get_quantizer("backward", 0)
grad_input_quantizer = prev_op_grad_input_quantizer
# Get autocast dtype if needed
dtype = None
if torch.is_autocast_enabled():
dtype = torch.get_autocast_dtype("cuda")
else:
dtype = linear_op.weight.dtype
# Linear forward
output = basic_op_extra_inputs[self._op_idxs["add"]][0]
......@@ -106,6 +107,7 @@ class ForwardLinearBiasAdd(FusedOperation):
input=input_,
weight=linear_op.weight,
bias=bias,
dtype=output.dtype,
out=output,
accumulate_into_out=True,
tensor_parallel_mode=linear_op.tensor_parallel_mode,
......@@ -129,7 +131,9 @@ class ForwardLinearBiasAdd(FusedOperation):
linear_op_ctx.dtype = dtype
linear_op_ctx.input_requires_grad = input_requires_grad
linear_op_ctx.weight_requires_grad = weight_requires_grad
linear_op_ctx.has_prev_op = basic_op_prev_ops[0] is not None
if bias_op is not None:
bias_op_ctx.with_quantized_compute = with_quantized_compute
bias_op_ctx.grad_input_quantizer = linear_op.get_grad_input_quantizer()
return output, [() for _ in range(len(self.basic_ops))]
......
......@@ -20,10 +20,11 @@ from ...module.base import (
_2X_ACC_DGRAD,
_2X_ACC_WGRAD,
)
from ...tensor.quantized_tensor import QuantizedTensorBase, Quantizer
from ...tensor.quantized_tensor import Quantizer
from ...tensor.mxfp8_tensor import MXFP8Quantizer
from ...utils import canonicalize_device, canonicalize_dtype, clear_tensor_data
from ..basic import BasicLinear, Bias, ReduceScatter
from .._common import maybe_dequantize, is_quantized_tensor
from ..op import FusedOperation, FusibleOperation, OperationContext
......@@ -279,7 +280,7 @@ class UserbuffersBackwardLinear(FusedOperation):
# Cast grad output tensor dtype if needed
dy_local = grad_output
if with_quantized_compute:
if not isinstance(dy_local, QuantizedTensorBase):
if not is_quantized_tensor(dy_local):
with_columnwise = weight_requires_grad
if (
with_columnwise
......@@ -293,24 +294,18 @@ class UserbuffersBackwardLinear(FusedOperation):
)
dy_local = grad_output_quantizer(dy_local)
else:
if isinstance(dy_local, QuantizedTensorBase):
dy_local = dy_local.dequantize(dtype=dtype)
elif dy_local.dtype != dtype:
dy_local = dy_local.to(dtype=dtype)
dy_local = maybe_dequantize(dy_local, dtype)
# Cast weight tensor dtype if needed
if weight is None:
raise ValueError("Weight tensor is required to compute input grad")
w = weight
if with_quantized_compute:
if not isinstance(w, QuantizedTensorBase):
if not is_quantized_tensor(w):
weight_quantizer.set_usage(columnwise=True)
w = weight_quantizer(w)
else:
if isinstance(w, QuantizedTensorBase):
w = w.dequantize(dtype=dtype)
elif w.dtype != dtype:
w = w.to(dtype=dtype)
w = maybe_dequantize(w, dtype)
# Cast input tensor dtype if needed
x_local = None
......@@ -319,14 +314,11 @@ class UserbuffersBackwardLinear(FusedOperation):
raise ValueError("Input tensor is required to compute weight grad")
x_local = input
if with_quantized_compute:
if not isinstance(x_local, QuantizedTensorBase):
if not is_quantized_tensor(x_local):
input_quantizer.set_usage(columnwise=True)
x_local = input_quantizer(x_local)
else:
if isinstance(x_local, QuantizedTensorBase):
x_local = x_local.dequantize(dtype=dtype)
elif x_local.dtype != dtype:
x_local = x_local.to(dtype=dtype)
x_local = maybe_dequantize(x_local, dtype)
# dgrad GEMM
dx_local = None
......@@ -433,7 +425,7 @@ class UserbuffersBackwardLinear(FusedOperation):
raise RuntimeError(
"wgrad GEMM requires grad output tensor, which has not been initialized"
)
if isinstance(dy, QuantizedTensorBase):
if is_quantized_tensor(dy):
dy.update_usage(rowwise_usage=False, columnwise_usage=True)
# Initialize input tensor
......@@ -443,7 +435,7 @@ class UserbuffersBackwardLinear(FusedOperation):
raise RuntimeError(
"wgrad GEMM requires input tensor, which has not been initialized"
)
if isinstance(x, QuantizedTensorBase):
if is_quantized_tensor(x):
x.update_usage(rowwise_usage=False, columnwise_usage=True)
# Check grad weight tensor
......@@ -516,6 +508,9 @@ class UserbuffersBackwardLinear(FusedOperation):
accumulate_into_main_grad = linear_op._accumulate_into_main_grad
grad_weight = None
if linear_op_ctx.weight_requires_grad and accumulate_into_main_grad:
if hasattr(linear_op.weight, "__fsdp_param__"):
linear_op.weight.main_grad = linear_op.weight.get_main_grad()
if not hasattr(linear_op.weight, "main_grad"):
raise RuntimeError(
"BasicLinear op is configured with "
......@@ -552,7 +547,6 @@ class UserbuffersBackwardLinear(FusedOperation):
grad_bias = extra_outputs["grad_bias"]
# Clear input tensor if possible
if linear_op_ctx.has_prev_op:
clear_tensor_data(x_local)
# Return gradients
......@@ -574,13 +568,13 @@ def fuse_userbuffers_backward_linear(
Parameters
----------
ops: list of tuples
Forward pass operations and the indices of the corresponding
Backward pass operations and the indices of the corresponding
basic operations.
Returns
-------
ops: list of tuples
Updated forward pass operations
Updated backward pass operations
"""
......
......@@ -20,13 +20,12 @@ from ...module.base import (
get_workspace,
_2X_ACC_FPROP,
)
from ...tensor.quantized_tensor import QuantizedTensorBase, Quantizer
from ...tensor.quantized_tensor import Quantizer
from ...tensor.float8_tensor import Float8Quantizer, Float8CurrentScalingQuantizer
from ...tensor._internal.float8_tensor_base import Float8TensorBase
from ...utils import canonicalize_device, canonicalize_dtype
from .._common import maybe_dequantize, is_quantized_tensor
from ..basic import BasicLinear, Bias, ReduceScatter
from ..op import (
BasicOperation,
FusedOperation,
FusibleOperation,
OperationContext,
......@@ -88,8 +87,8 @@ class UserbuffersForwardLinear(FusedOperation):
weight: torch.Tensor,
*,
bias: Optional[torch.Tensor] = None,
device: Optional[torch.device] = None,
dtype: Optional[torch.dtype] = None,
device: torch.device,
dtype: torch.dtype,
tensor_parallel_mode: Optional[str] = None,
tensor_parallel_group: Optional[torch.distributed.ProcessGroup] = None,
tensor_parallel_size: Optional[int] = None,
......@@ -112,9 +111,9 @@ class UserbuffersForwardLinear(FusedOperation):
Weight tensor
bias: torch.Tensor, optional
Bias tensor
device: torch.device, default = default CUDA device
device: torch.device
Tensor device
dtype: torch.dtype, default = default dtype
dtype: torch.dtype
Tensor datatype
tensor_parallel_mode: {`None`, "column", "row"}, default = `None`
Mode for tensor parallelism
......@@ -156,16 +155,10 @@ class UserbuffersForwardLinear(FusedOperation):
"""
# Check device
if device is None:
device = weight.device
device = canonicalize_device(device)
if device.type != "cuda":
raise ValueError(f"Only CUDA devices are supported (got {device})")
# Check datatype
if dtype is None:
dtype = weight.dtype
dtype = canonicalize_dtype(dtype)
if dtype not in (torch.float32, torch.float16, torch.bfloat16):
raise ValueError(f"Supported dtypes are float32, float16, bfloat16 (got {dtype})")
......@@ -206,7 +199,7 @@ class UserbuffersForwardLinear(FusedOperation):
x = None
if with_ub_all_gather:
if input_quantizer is not None:
if not isinstance(x_local, QuantizedTensorBase):
if not is_quantized_tensor(x_local):
input_quantizer.set_usage(rowwise=True, columnwise=weight_requires_grad)
if isinstance(
input_quantizer, (Float8Quantizer, Float8CurrentScalingQuantizer)
......@@ -222,26 +215,20 @@ class UserbuffersForwardLinear(FusedOperation):
)
else:
if with_quantized_compute:
if not isinstance(x_local, QuantizedTensorBase):
if not is_quantized_tensor(x_local):
input_quantizer.set_usage(rowwise=True, columnwise=weight_requires_grad)
x_local = input_quantizer(x_local)
else:
if isinstance(x_local, QuantizedTensorBase):
x_local = x_local.dequantize(dtype=dtype)
if x_local.dtype != dtype:
x_local = x_local.to(dtype=dtype)
x_local = maybe_dequantize(x_local, dtype)
x = x_local
# Initialize weight tensor
w = weight
w_is_quantized = isinstance(w, QuantizedTensorBase)
if with_quantized_compute and not w_is_quantized:
if not with_quantized_compute:
w = maybe_dequantize(w, dtype)
elif with_quantized_compute and not is_quantized_tensor(w):
weight_quantizer.set_usage(rowwise=True, columnwise=input_requires_grad)
w = weight_quantizer(w)
elif not with_quantized_compute and w_is_quantized:
w = w.dequantize()
if not with_quantized_compute and w.dtype != dtype:
w = w.to(dtype=dtype)
# Construct output tensor if needed
reduce_scatter_output = None
......@@ -271,18 +258,14 @@ class UserbuffersForwardLinear(FusedOperation):
# Prepare weight tensor for backward pass
if input_requires_grad:
if w is not weight and with_quantized_compute and isinstance(w, QuantizedTensorBase):
if w is not weight and with_quantized_compute and is_quantized_tensor(w):
w.update_usage(rowwise_usage=False, columnwise_usage=True)
else:
w = None
# Prepare input tensor for backward pass
if weight_requires_grad:
if x_local is input:
# PyTorch autograd produces esoteric errors if we
# cache input tensor directly.
x_local = x_local.detach()
if with_quantized_compute and isinstance(x_local, QuantizedTensorBase):
if with_quantized_compute and is_quantized_tensor(x_local):
if not (isinstance(x_local, Float8TensorBase) and with_ub_all_gather):
# FP8 does not support all-gather of transpose data
x_local.update_usage(rowwise_usage=False, columnwise_usage=True)
......@@ -299,8 +282,8 @@ class UserbuffersForwardLinear(FusedOperation):
input_: torch.Tensor,
*,
basic_op_extra_inputs: list[tuple[torch.Tensor, ...]],
basic_op_prev_ops: list[Optional[BasicOperation]],
basic_op_next_ops: list[Optional[BasicOperation]],
prev_op_grad_input_quantizer: Optional[Quantizer],
next_op_input_quantizer: Optional[Quantizer],
basic_op_kwargs: list[dict[str, Any]],
) -> tuple[torch.Tensor, Iterable[Iterable[torch.Tensor]]]:
......@@ -309,16 +292,18 @@ class UserbuffersForwardLinear(FusedOperation):
linear_op = self.basic_ops[idx]
linear_op_ctx = basic_op_ctxs[idx]
bias_op = None
bias_op_ctx = None
bias = None
if self._op_idxs["bias"] is not None:
idx = self._op_idxs["bias"]
bias_op = self.basic_ops[idx]
bias_op_ctx = basic_op_ctxs[idx]
bias = bias_op.bias
if basic_op_kwargs[idx]:
raise ValueError("Bias operation forward does not expect keyword arguments")
# Check which grads are required
input_requires_grad = linear_op_ctx.requires_grad and input_.requires_grad
input_requires_grad = linear_op_ctx.requires_grad
weight_requires_grad = linear_op_ctx.requires_grad and linear_op.weight.requires_grad
# Quantization metadata
......@@ -336,14 +321,13 @@ class UserbuffersForwardLinear(FusedOperation):
input_quantizer = linear_op.get_quantizer("forward", 0)
weight_quantizer = linear_op.get_quantizer("forward", 1)
grad_output_quantizer = linear_op.get_quantizer("backward", 0)
prev_op = basic_op_prev_ops[0]
if prev_op is not None and prev_op.num_quantizers("backward") > 0 and recipe.delayed():
grad_input_quantizer = prev_op.get_quantizer("backward", 0)
grad_input_quantizer = prev_op_grad_input_quantizer
# Get autocast dtype if needed
dtype = None
if torch.is_autocast_enabled():
dtype = torch.get_autocast_dtype("cuda")
else:
dtype = linear_op.weight.dtype
# Userbuffers options
if linear_op._userbuffers_options is None:
......@@ -355,6 +339,7 @@ class UserbuffersForwardLinear(FusedOperation):
weight=linear_op.weight,
bias=bias,
dtype=dtype,
device=linear_op.weight.device,
tensor_parallel_mode=self.tensor_parallel_mode,
tensor_parallel_group=self.tensor_parallel_group,
tensor_parallel_size=self.tensor_parallel_size,
......@@ -381,7 +366,9 @@ class UserbuffersForwardLinear(FusedOperation):
linear_op_ctx.input_dims = input_.size()
linear_op_ctx.input_requires_grad = input_requires_grad
linear_op_ctx.weight_requires_grad = weight_requires_grad
linear_op_ctx.has_prev_op = basic_op_prev_ops[0] is not None
if bias_op is not None:
bias_op_ctx.with_quantized_compute = with_quantized_compute
bias_op_ctx.grad_input_quantizer = linear_op.get_grad_input_quantizer()
return output, [() for _ in range(len(self.basic_ops))]
......
......@@ -10,19 +10,24 @@ from typing import Any, Optional
import torch
from transformer_engine.pytorch.fp8 import FP8GlobalStateManager
from transformer_engine.pytorch.fp8 import FP8GlobalStateManager, Recipe
from transformer_engine.pytorch.ops.op import (
BasicOperation,
FusibleOperation,
OperationContext,
)
from transformer_engine.pytorch.ops.fused import (
fuse_backward_bias_activation,
fuse_backward_linear_add,
fuse_forward_linear_bias_activation,
fuse_forward_linear_bias_add,
fuse_userbuffers_backward_linear,
fuse_userbuffers_forward_linear,
)
from transformer_engine.pytorch.tensor.quantized_tensor import (
prepare_for_saving,
restore_from_saved,
)
def _split_tuple(t: tuple, idx: int) -> tuple[tuple, tuple]:
......@@ -96,6 +101,10 @@ class _OperationFuserAutogradFunction(torch.autograd.Function):
# Operation autograd contexts
basic_op_ctxs = [OperationContext() for _ in range(fuser._num_basic_ops)]
# Mark input tensors as not deletable in backward
for tensor in (input_,) + params_and_extra_inputs:
tensor.do_not_clear = True
# Unflatten list of parameters and extra tensor inputs
extra_inputs = params_and_extra_inputs[-fuser._num_extra_inputs :]
basic_op_extra_inputs = []
......@@ -106,6 +115,7 @@ class _OperationFuserAutogradFunction(torch.autograd.Function):
# Apply forward ops
x = input_
requires_grad = is_grad_enabled and x.requires_grad
with_quantized_compute = FP8GlobalStateManager.is_fp8_enabled()
extra_outputs = [None] * fuser._num_basic_ops
for op, basic_op_idxs in fuser._forward_ops:
......@@ -117,31 +127,31 @@ class _OperationFuserAutogradFunction(torch.autograd.Function):
requires_grad = any(any(x.requires_grad for x in xs) for xs in extra_inputs)
for idx in basic_op_idxs:
basic_op_ctxs[idx].requires_grad = requires_grad
if requires_grad != x.requires_grad:
if requires_grad:
x.requires_grad_()
else:
x = x.detach()
# Forward op
extra_inputs = [basic_op_extra_inputs[idx] for idx in basic_op_idxs]
prev_ops = [fuser._basic_ops[idx - 1] if idx > 0 else None for idx in basic_op_idxs]
next_ops = [
fuser._basic_ops[idx + 1] if (idx < fuser._num_basic_ops - 1) else None
for idx in basic_op_idxs
]
prev_op_idx = basic_op_idxs[0] - 1
prev_op = fuser._basic_ops[prev_op_idx] if prev_op_idx >= 0 else None
prev_op_grad_input_quantizer = None
if prev_op is not None and with_quantized_compute:
prev_op_grad_input_quantizer = prev_op.get_grad_input_quantizer()
next_op_idx = basic_op_idxs[-1] + 1
next_op = fuser._basic_ops[next_op_idx] if next_op_idx < fuser._num_basic_ops else None
next_op_input_quantizer = None
if next_op is not None and with_quantized_compute:
next_op_input_quantizer = next_op.get_input_quantizer()
x, fused_op_extra_outputs = op.fuser_forward(
[basic_op_ctxs[idx] for idx in basic_op_idxs],
x,
basic_op_extra_inputs=extra_inputs,
basic_op_prev_ops=prev_ops,
basic_op_next_ops=next_ops,
prev_op_grad_input_quantizer=prev_op_grad_input_quantizer,
next_op_input_quantizer=next_op_input_quantizer,
basic_op_kwargs=[basic_op_kwargs[idx] for idx in basic_op_idxs],
)
x.requires_grad_(requires_grad=requires_grad)
for idx, ys in zip(basic_op_idxs, fused_op_extra_outputs):
for y in ys:
y.requires_grad_(requires_grad=requires_grad)
y.requires_grad_(requires_grad)
extra_outputs[idx] = ys
# Flatten list of extra outputs
......@@ -169,6 +179,13 @@ class _OperationFuserAutogradFunction(torch.autograd.Function):
range_end = len(to_save)
ctx.to_save = None
ctx._saved_tensors_range = (range_start, range_end)
# Save tensors for backward
if with_quantized_compute:
tensors_to_save, tensor_objects = prepare_for_saving(*to_save)
func_ctx.save_for_backward(*tensors_to_save)
func_ctx.tensor_objects = tensor_objects
else:
func_ctx.save_for_backward(*to_save)
# Other context
......@@ -179,9 +196,13 @@ class _OperationFuserAutogradFunction(torch.autograd.Function):
func_ctx.num_extra_inputs = fuser._num_extra_inputs
func_ctx.num_extra_outputs = len(extra_outputs_flat)
func_ctx.is_first_module = FP8GlobalStateManager.is_first_fp8_module()
func_ctx.with_quantized_compute = with_quantized_compute
x.requires_grad_(requires_grad)
if extra_outputs_flat:
return x, *extra_outputs_flat
return x
@staticmethod
......@@ -198,8 +219,13 @@ class _OperationFuserAutogradFunction(torch.autograd.Function):
basic_ops = func_ctx.basic_ops
basic_op_ctxs = func_ctx.basic_op_ctxs
# Unflatten list of saved tensors
# Restore saved tensors
if func_ctx.with_quantized_compute:
saved_tensors = restore_from_saved(func_ctx.tensor_objects, func_ctx.saved_tensors)
else:
saved_tensors = func_ctx.saved_tensors
# Unflatten list of saved tensors
for ctx in basic_op_ctxs:
ctx.saved_tensors = saved_tensors[slice(*ctx._saved_tensors_range)]
ctx._saved_tensors_range = None
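The quantized-compute branch above mirrors the save branch added in forward; a minimal self-contained sketch of that save/restore pattern follows, assuming only what this diff shows: prepare_for_saving splits (possibly quantized) tensors into plain tensors safe for save_for_backward plus lightweight tensor objects, and restore_from_saved rebuilds the originals from the two.

import torch
from transformer_engine.pytorch.tensor.quantized_tensor import (
    prepare_for_saving,
    restore_from_saved,
)

class _SaveQuantizedSketch(torch.autograd.Function):
    @staticmethod
    def forward(ctx, x, cached):
        # Decompose the cached tensor (quantized or not) for autograd saving.
        tensors_to_save, tensor_objects = prepare_for_saving(cached)
        ctx.save_for_backward(*tensors_to_save)
        ctx.tensor_objects = tensor_objects
        return x

    @staticmethod
    def backward(ctx, grad_output):
        # Rebuild the cached tensor; assumed to come back as a length-1 sequence here.
        (cached,) = restore_from_saved(ctx.tensor_objects, ctx.saved_tensors)
        del cached  # a real backward would use it to compute gradients
        return grad_output, None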
......@@ -291,15 +317,19 @@ class OperationFuser:
----------
ops: list of FusibleOperation
Pipeline of operations
fuse_ops: bool, default = `True`
fuse_ops: bool
Whether to attempt fusing operations
recipe: Recipe, optional
Quantization recipe to use when fusing and executing operations.
Note: certain fusions may depend on what kind of recipe is being used.
"""
def __init__(
self,
ops: list[FusibleOperation],
fuse_ops: bool = True,
fuse_ops: bool,
recipe: Optional[Recipe],
) -> None:
# Get list of basic operations
......@@ -321,7 +351,11 @@ class OperationFuser:
self._forward_ops = [(op, (idx,)) for idx, op in enumerate(self._basic_ops)]
self._backward_ops = list(reversed(self._forward_ops))
# Flag for checking if this is the first iteration
self._is_first_forward = True
# Fuse ops if needed
self.recipe = recipe
if fuse_ops:
self.fuse_ops()
......@@ -333,6 +367,7 @@ class OperationFuser:
def _fuse_forward_ops(
cls,
ops: list[tuple[FusibleOperation, list[int]]],
recipe: Optional[Recipe], # pylint: disable=unused-argument
) -> list[tuple[FusibleOperation, list[int]]]:
"""Attempt to fuse operations in forward pass"""
ops = fuse_userbuffers_forward_linear(ops)
......@@ -344,16 +379,18 @@ class OperationFuser:
def _fuse_backward_ops(
cls,
ops: list[tuple[FusibleOperation, list[int]]],
recipe: Optional[Recipe],
) -> list[tuple[FusibleOperation, list[int]]]:
"""Attempt to fuse operations in backward pass"""
ops = fuse_userbuffers_backward_linear(ops)
ops = fuse_backward_linear_add(ops)
ops = fuse_backward_bias_activation(ops, recipe)
return ops
def fuse_ops(self) -> None:
"""Attempt to fuse operations"""
self._forward_ops = self._fuse_forward_ops(self._forward_ops)
self._backward_ops = self._fuse_backward_ops(self._backward_ops)
self._forward_ops = self._fuse_forward_ops(self._forward_ops, self.recipe)
self._backward_ops = self._fuse_backward_ops(self._backward_ops, self.recipe)
def __call__(
self,
......@@ -368,8 +405,10 @@ class OperationFuser:
)
# Initialization before forward pass
if self._is_first_forward:
for op in self._basic_ops:
op.pre_forward()
op.pre_first_forward(recipe=self.recipe)
self._is_first_forward = False
# Canonicalize op kwargs
if basic_op_kwargs is None:
......
......@@ -65,17 +65,27 @@ class FusibleOperation(torch.nn.Module, metaclass=abc.ABCMeta):
def is_fused_op(self) -> bool:
"""Whether this op is the fusion of one or more basic ops"""
def pre_forward(self) -> None:
def pre_first_forward(
self,
*,
recipe: Optional[Recipe],
) -> None:
"""Preprocessing before forward pass"""
def get_input_quantizer(self) -> Optional[Quantizer]:
"""Get builder class for quantized input tensor"""
def get_grad_input_quantizer(self) -> Optional[Quantizer]:
"""Get builder class for quantized input's grad tensor"""
def fuser_forward(
self,
basic_op_ctxs: list[OperationContext],
input_: torch.Tensor,
*,
basic_op_extra_inputs: list[tuple[torch.Tensor, ...]],
basic_op_prev_ops: list[Optional[BasicOperation]],
basic_op_next_ops: list[Optional[BasicOperation]],
prev_op_grad_input_quantizer: Optional[Quantizer],
next_op_input_quantizer: Optional[Quantizer],
basic_op_kwargs: list[dict[str, Any]],
) -> tuple[torch.Tensor, Iterable[Iterable[torch.Tensor]]]:
"""Forward pass
......@@ -94,12 +104,10 @@ class FusibleOperation(torch.nn.Module, metaclass=abc.ABCMeta):
Input tensor
basic_op_extra_inputs: list of torch.Tensor
Extra tensor inputs to basic operations
basic_op_prev_ops: list of BasicOperation
Basic operations that preceed this operation's basic
operations
basic_op_next_ops: list of BasicOperation
Basic operations that follow this operation's basic
operations
prev_op_grad_input_quantizer: Quantizer, optional
The grad_input_quantizer of the preceding operation
next_op_input_quantizer: Quantizer, optional
The input_quantizer of the following operation
basic_op_kwargs: list of dict
Keyword arguments to forward functions of basic
operations.
......@@ -201,17 +209,23 @@ class BasicOperation(FusibleOperation, metaclass=abc.ABCMeta):
"""
return 0
def get_input_quantizer(self) -> Optional[Quantizer]:
if self.num_quantizers("forward") > 0:
return self.get_quantizer("forward", 0)
return None
def get_grad_input_quantizer(self) -> Optional[Quantizer]:
if self.num_quantizers("backward") > 0:
return self.get_quantizer("backward", 0)
return None
def _reset_quantization_recipe_state(
self,
*,
recipe: Optional[Recipe] = None,
recipe: Recipe,
) -> None:
"""Construct state for quantization recipe"""
# Quantization recipe
if recipe is None:
recipe = FP8GlobalStateManager.get_fp8_recipe()
# Quantization recipe state for forward and backward pass
self._fp8_metas = {"forward": None, "backward": None}
self._quantizers = {"forward": [], "backward": []}
......@@ -246,14 +260,10 @@ class BasicOperation(FusibleOperation, metaclass=abc.ABCMeta):
def _update_quantization_recipe_state(
self,
*,
recipe: Optional[Recipe] = None,
recipe: Recipe,
) -> None:
"""Make sure quantizer state matches quantization recipe"""
# Quantization recipe
if recipe is None:
recipe = FP8GlobalStateManager.get_fp8_recipe()
# Reset quantization state if needed
if self._fp8_metas is None or self._quantizers is None:
self._reset_quantization_recipe_state(recipe=recipe)
......@@ -327,7 +337,7 @@ class BasicOperation(FusibleOperation, metaclass=abc.ABCMeta):
"""
if self._quantizers is None:
self._reset_quantization_recipe_state()
self._reset_quantization_recipe_state(recipe=FP8GlobalStateManager.get_fp8_recipe())
return self._quantizers[mode][index]
@torch.no_grad()
......@@ -378,19 +388,16 @@ class BasicOperation(FusibleOperation, metaclass=abc.ABCMeta):
self._fp8_metas[mode][fp8_meta_key].scale.copy_(scale)
self._fp8_metas[mode][fp8_meta_key].amax_history.copy_(amax_history)
def pre_forward(
def pre_first_forward(
self,
*,
fp8_enabled: Optional[bool] = None,
fp8_recipe: Optional[Recipe] = None,
recipe: Optional[Recipe],
) -> None:
"""Preprocessing before forward pass"""
# Initialize FP8 metadata if needed
if fp8_enabled is None:
fp8_enabled = FP8GlobalStateManager.is_fp8_enabled()
if fp8_enabled:
self._update_quantization_recipe_state(recipe=fp8_recipe)
if recipe is not None:
self._update_quantization_recipe_state(recipe=recipe)
if not FP8GlobalStateManager.fp8_graph_capturing():
if self.num_quantizers("forward"):
FP8GlobalStateManager.add_fp8_tensors_to_global_buffer(
......@@ -407,8 +414,8 @@ class BasicOperation(FusibleOperation, metaclass=abc.ABCMeta):
ctx: OperationContext,
input_: torch.Tensor,
*,
prev_op: Optional[BasicOperation] = None,
next_op: Optional[BasicOperation] = None,
prev_op_grad_input_quantizer: Optional[Quantizer],
next_op_input_quantizer: Optional[Quantizer],
**kwargs: Any,
) -> torch.Tensor:
"""Forward pass
......@@ -419,10 +426,10 @@ class BasicOperation(FusibleOperation, metaclass=abc.ABCMeta):
Context to coordinate between forward and backward passes
input_: torch.Tensor
Input tensor
prev_op: BasicOperation, optional
Basic operation that preceeds this operation
next_op: BasicOperation, optional
Basic operation that follows this operation
prev_op_grad_input_quantizer: Quantizer, optional
The grad_input_quantizer of the preceding operation
next_op_input_quantizer: Quantizer, optional
The input_quantizer of the following operation
Returns
-------
......@@ -461,8 +468,8 @@ class BasicOperation(FusibleOperation, metaclass=abc.ABCMeta):
input_: torch.Tensor,
*,
basic_op_extra_inputs: list[tuple[torch.Tensor, ...]],
basic_op_prev_ops: list[Optional[BasicOperation]],
basic_op_next_ops: list[Optional[BasicOperation]],
prev_op_grad_input_quantizer: Optional[Quantizer],
next_op_input_quantizer: Optional[Quantizer],
basic_op_kwargs: list[dict[str, Any]],
) -> tuple[torch.Tensor, list[tuple[()]]]:
if self.num_extra_inputs > 0 or self.num_extra_outputs > 0:
......@@ -475,8 +482,8 @@ class BasicOperation(FusibleOperation, metaclass=abc.ABCMeta):
output = self.op_forward(
basic_op_ctxs[0],
input_,
prev_op=basic_op_prev_ops[0],
next_op=basic_op_next_ops[0],
prev_op_grad_input_quantizer=prev_op_grad_input_quantizer,
next_op_input_quantizer=next_op_input_quantizer,
**basic_op_kwargs[0],
)
return output, [()]
......@@ -511,7 +518,9 @@ class BasicOperation(FusibleOperation, metaclass=abc.ABCMeta):
"""Apply operation"""
from .fuser import OperationFuser
return OperationFuser([self], fuse_ops=False)(
with_quantized_compute = FP8GlobalStateManager.is_fp8_enabled()
recipe = FP8GlobalStateManager.get_fp8_recipe() if with_quantized_compute else None
return OperationFuser([self], fuse_ops=False, recipe=recipe)(
input,
*extra_inputs,
basic_op_kwargs=[kwargs],
......@@ -621,7 +630,7 @@ class BasicOperation(FusibleOperation, metaclass=abc.ABCMeta):
# Get op's quantizer state, initializing if needed
if self._fp8_metas is None or self._fp8_metas[mode] is None:
with fp8_autocast(fp8_recipe=state[mode]["recipe"]):
self._reset_quantization_recipe_state()
self._reset_quantization_recipe_state(recipe=state[mode]["recipe"])
fp8_meta = self._fp8_metas[mode]
# Load extra items
......@@ -696,10 +705,16 @@ class FusedOperation(FusibleOperation):
def is_fused_op(self) -> bool:
return True
def pre_forward(self) -> None:
def get_input_quantizer(self) -> Optional[Quantizer]:
return self.basic_ops[0].get_input_quantizer()
def get_grad_input_quantizer(self) -> Optional[Quantizer]:
return self.basic_ops[-1].get_grad_input_quantizer()
def pre_first_forward(self, *args, **kwargs) -> None:
"""Preprocessing before forward pass"""
for op in self.basic_ops:
op.pre_forward()
op.pre_first_forward(*args, **kwargs)
def forward(
self,
......@@ -712,7 +727,9 @@ class FusedOperation(FusibleOperation):
basic_op_kwargs = [{} for _ in range(len(self.basic_ops))]
from .fuser import OperationFuser
return OperationFuser([self], fuse_ops=False)(
with_quantized_compute = FP8GlobalStateManager.is_fp8_enabled()
recipe = FP8GlobalStateManager.get_fp8_recipe() if with_quantized_compute else None
return OperationFuser([self], fuse_ops=False, recipe=recipe)(
input,
*extra_inputs,
basic_op_kwargs=basic_op_kwargs,
......
......@@ -10,6 +10,7 @@ from typing import Optional
import torch
from transformer_engine.pytorch.fp8 import FP8GlobalStateManager, Recipe
from transformer_engine.pytorch.ops.op import FusibleOperation
from transformer_engine.pytorch.ops.fuser import OperationFuser
......@@ -37,6 +38,9 @@ class Sequential(torch.nn.Module):
self._module_groups: Optional[list[OperationFuser | torch.nn.Module]]
self._module_groups = None
# Global state of last iteration
self._last_global_state = None
# Add modules
if len(args) == 1 and isinstance(args[0], dict):
for key, module in args[0].items():
......@@ -143,6 +147,7 @@ class Sequential(torch.nn.Module):
def _make_module_groups(
cls,
modules: Iterable[torch.nn.Module],
recipe: Optional[Recipe],
) -> list[OperationFuser | torch.nn.Module]:
"""Make list of modules, with fusible operations grouped together"""
......@@ -157,7 +162,7 @@ class Sequential(torch.nn.Module):
groups.append(module)
for idx, group in enumerate(groups):
if isinstance(group, list):
groups[idx] = OperationFuser(group, fuse_ops=True)
groups[idx] = OperationFuser(group, fuse_ops=True, recipe=recipe)
# Check if operations expect extra input or output tensors
# Note: If any op has extra inputs or outputs, then the entire
......@@ -185,9 +190,19 @@ class Sequential(torch.nn.Module):
) -> torch.Tensor | tuple[torch.Tensor, ...]:
"""Forward pass"""
# Get current global state
with_quantized_compute = FP8GlobalStateManager.is_fp8_enabled()
recipe = FP8GlobalStateManager.get_fp8_recipe() if with_quantized_compute else None
global_state = (with_quantized_compute, type(recipe))
# Reset module groups if global state changed
if self._last_global_state != global_state:
self._module_groups = None
self._last_global_state = global_state
# Create module groups if needed
if self._module_groups is None:
self._module_groups = self._make_module_groups(self._modules.values())
self._module_groups = self._make_module_groups(self._modules.values(), recipe)
# Forward pass for each module group
x = input
......
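A hedged usage sketch of the recipe-aware regrouping added to Sequential.forward above; the import paths and constructor signatures are assumptions inferred from modules referenced elsewhere in this diff, not verified API.

import torch
import transformer_engine.pytorch as te
from transformer_engine.pytorch.ops.basic import BasicLinear

# Constructor signature (in_features, out_features) is assumed.
model = te.ops.Sequential(BasicLinear(128, 128), BasicLinear(128, 128))
x = torch.randn(4, 128, device="cuda")

y0 = model(x)                        # module groups built with recipe=None
with te.fp8_autocast(enabled=True):  # global state changes, so groups are rebuilt with the FP8 recipe
    y1 = model(x)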
......@@ -349,7 +349,7 @@ class _moe_unpermute_mask_map(torch.autograd.Function):
if restore_shape is None:
restore_shape = inp.shape
num_tokens, hidden_size = restore_shape
num_experts = row_id_map.size(0)
num_experts = (row_id_map.size(1) - 1) // 2
with_probs = merging_probs is not None
if with_probs:
......@@ -651,14 +651,20 @@ class _moe_chunk_sort(torch.autograd.Function):
fp8_scale_inv = inp._scale_inv
fake_dtype = inp.dtype
inp = inp._data
output, row_id_map, permuted_probs = triton_permutation.sort_chunks_by_idx(
inp,
row_id_map = triton_permutation.make_chunk_sort_map(
split_sizes,
sorted_idxs,
num_tokens,
num_splits,
)
output, permuted_probs = triton_permutation.sort_chunks_by_map(
inp,
row_id_map,
probs,
num_tokens,
hidden_size,
num_splits,
is_forward=True,
)
if fp8:
output = Float8Tensor(
......@@ -700,6 +706,7 @@ class _moe_chunk_sort(torch.autograd.Function):
permuted_probs_grad,
ctx.num_tokens,
ctx.hidden_size,
is_forward=False,
)
if fp8:
act_grad = Float8Tensor(
......
# Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# See LICENSE for license information.
"""
Fused functions used in the MoE router
"""
import torch
import transformer_engine_torch as tex
class FusedTopkScoreFunction(torch.autograd.Function):
"""
Fused Topk with Score Function router.
Currently, only softmax and sigmoid are supported.
"""
@staticmethod
def forward(
ctx,
logits: torch.Tensor,
topk: int,
use_pre_softmax: bool,
num_groups: int,
group_topk: int,
scaling_factor: float,
score_function: str,
expert_bias: torch.Tensor,
):
# pylint: disable=missing-function-docstring
# Save the shape of the logits
tensor_shape = logits.shape
logits = logits.view(-1, tensor_shape[-1])
# Get the metadata of the viewed logits
num_tokens = logits.size(0)
num_experts = logits.size(1)
probs, routing_map, intermediate_output = tex.fused_topk_with_score_function_fwd(
logits,
topk,
use_pre_softmax,
num_groups,
group_topk,
scaling_factor,
score_function,
expert_bias,
)
# Restore the shape
probs = probs.view(tensor_shape)
ctx.save_for_backward(routing_map, intermediate_output)
ctx.num_tokens = num_tokens
ctx.num_experts = num_experts
ctx.use_pre_softmax = use_pre_softmax
ctx.topk = topk
ctx.scaling_factor = scaling_factor
ctx.score_function = score_function
return probs, routing_map
@staticmethod
def backward(ctx, grad_probs, _):
# pylint: disable=missing-function-docstring
routing_map, intermediate_output = ctx.saved_tensors
# Save the shape of the grad_probs
tensor_shape = grad_probs.shape
# Adjust the shape of the grad_probs to 2D shape
grad_probs = grad_probs.contiguous().view(-1, tensor_shape[-1])
grad_logits = tex.fused_topk_with_score_function_bwd(
ctx.num_tokens,
ctx.num_experts,
routing_map,
intermediate_output,
grad_probs,
ctx.topk,
ctx.use_pre_softmax,
ctx.scaling_factor,
ctx.score_function,
)
# Restore the shape
grad_logits = grad_logits.view(tensor_shape)
return grad_logits, None, None, None, None, None, None, None
def fused_topk_with_score_function(
logits: torch.Tensor,
topk: int,
use_pre_softmax: bool,
num_groups: int,
group_topk: int,
scaling_factor: float,
score_function: str,
expert_bias: torch.Tensor,
):
"""
Fused topk with score function router.
Parameters
----------
logits: torch.Tensor
topk: int
use_pre_softmax: bool
if enabled, the computation order: softmax -> topk
num_groups: int
used in the group topk
group_topk: int
used in the group topk
scaling_factor: float
score_function: str
currently only softmax and sigmoid are supported
expert_bias: torch.Tensor
could be used in the sigmoid
Returns
-------
probs: torch.Tensor
routing_map: torch.Tensor
"""
if logits.dtype == torch.float64:
raise ValueError("Current TE does not support float64 router type")
return FusedTopkScoreFunction.apply(
logits,
topk,
use_pre_softmax,
num_groups,
group_topk,
scaling_factor,
score_function,
expert_bias,
)
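A hedged usage sketch of the wrapper above. The import is omitted because the file's module path is not shown in this diff, and the routing hyperparameters (num_groups, group_topk, expert_bias=None) are illustrative assumptions rather than values taken from the source.

import torch

logits = torch.randn(8, 16, device="cuda")       # [num_tokens, num_experts]
probs, routing_map = fused_topk_with_score_function(
    logits,
    topk=2,
    use_pre_softmax=True,
    num_groups=1,                 # illustrative value
    group_topk=1,                 # illustrative value
    scaling_factor=1.0,
    score_function="softmax",
    expert_bias=None,             # assumed optional when not using sigmoid
)
# probs has the same shape as logits; routing_map marks the selected experts per token.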
class FusedComputeScoresForMoEAuxLoss(torch.autograd.Function):
"""
Fused compute scores for MoE aux loss.
"""
@staticmethod
def forward(
ctx,
logits: torch.Tensor,
topk: int,
score_function: str,
):
# pylint: disable=missing-function-docstring
# Save the shape of the logits
tensor_shape = logits.shape
logits = logits.view(-1, tensor_shape[-1])
# Get the metadata of the viewed logits
num_tokens = logits.size(0)
num_experts = logits.size(1)
scores, routing_map, intermediate_output = tex.fused_score_for_moe_aux_loss_fwd(
logits=logits,
topk=topk,
score_function=score_function,
)
ctx.save_for_backward(intermediate_output)
ctx.topk = topk
ctx.score_function = score_function
ctx.num_tokens = num_tokens
ctx.num_experts = num_experts
return routing_map, scores
@staticmethod
def backward(ctx, _, grad_scores):
# pylint: disable=missing-function-docstring
intermediate_output = ctx.saved_tensors[0]
# Save the shape of the grad_scores
tensor_shape = grad_scores.shape
# Adjust the shape of the grad_scores to 2D shape
grad_scores = grad_scores.contiguous().view(-1, tensor_shape[-1])
grad_logits = tex.fused_score_for_moe_aux_loss_bwd(
num_tokens=ctx.num_tokens,
num_experts=ctx.num_experts,
intermediate_output=intermediate_output,
grad_scores=grad_scores,
topk=ctx.topk,
score_function=ctx.score_function,
)
# Restore the shape
grad_logits = grad_logits.view(tensor_shape)
return grad_logits, None, None
def fused_compute_score_for_moe_aux_loss(
logits: torch.Tensor,
topk: int,
score_function: str,
):
"""
Fused compute scores for MoE aux loss, subset of the fused_topk_with_score_function.
Parameters
----------
logits: torch.Tensor
topk: int
score_function: str
currently only softmax and sigmoid are supported
Returns
-------
routing_map: torch.Tensor
scores: torch.Tensor
"""
return FusedComputeScoresForMoEAuxLoss.apply(logits, topk, score_function)
class FusedAuxLoss(torch.autograd.Function):
"""
Fused MoE aux loss.
"""
@staticmethod
def forward(
ctx,
probs: torch.Tensor,
tokens_per_expert: torch.Tensor,
total_num_tokens: int,
num_experts: int,
topk: int,
coeff: float,
):
# pylint: disable=missing-function-docstring
num_rows = probs.size(0)
num_cols = probs.size(1)
aux_loss, Const_buf = tex.fused_moe_aux_loss_fwd(
probs=probs,
tokens_per_expert=tokens_per_expert,
total_num_tokens=total_num_tokens,
num_experts=num_experts,
num_rows=num_rows,
num_cols=num_cols,
topk=topk,
coeff=coeff,
)
ctx.save_for_backward(Const_buf, tokens_per_expert)
ctx.num_rows = num_rows
ctx.num_cols = num_cols
return aux_loss
@staticmethod
def backward(ctx, grad_aux_loss):
# pylint: disable=missing-function-docstring
Const_buf, tokens_per_expert = ctx.saved_tensors
grad_probs = tex.fused_moe_aux_loss_bwd(
Const_buf=Const_buf,
tokens_per_expert=tokens_per_expert,
num_rows=ctx.num_rows,
num_cols=ctx.num_cols,
grad_aux_loss=grad_aux_loss,
)
return grad_probs, None, None, None, None, None
def fused_moe_aux_loss(
probs: torch.Tensor,
tokens_per_expert: torch.Tensor,
total_num_tokens: int,
num_experts: int,
topk: int,
coeff: float,
):
"""
Fused MoE aux loss.
Parameters
----------
probs: torch.Tensor
tokens_per_expert: torch.Tensor
the number of tokens per expert
total_num_tokens: int
the total number of tokens involved in the aux loss calculation
num_experts: int
topk: int
coeff: float
the coefficient of the aux loss
Returns
-------
aux_loss: torch.scalar
"""
return FusedAuxLoss.apply(probs, tokens_per_expert, total_num_tokens, num_experts, topk, coeff)
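A hedged usage sketch for fused_moe_aux_loss; shapes, dtypes, and the coefficient below are illustrative assumptions.

import torch

num_tokens, num_experts, topk = 8, 4, 2
probs = torch.rand(num_tokens, num_experts, device="cuda", requires_grad=True)
tokens_per_expert = torch.tensor([5, 4, 4, 3], device="cuda")  # assumed integer counts, summing to num_tokens * topk
aux_loss = fused_moe_aux_loss(
    probs,
    tokens_per_expert,
    total_num_tokens=num_tokens,
    num_experts=num_experts,
    topk=topk,
    coeff=1e-2,
)
aux_loss.backward()  # FusedAuxLoss.backward produces the gradient w.r.t. probs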
......@@ -43,7 +43,6 @@ class Float8BlockwiseQTensorBase(QuantizedTensorBase):
def __new__(
cls,
*args,
rowwise_data: Optional[torch.Tensor],
rowwise_scale_inv: Optional[torch.Tensor],
columnwise_data: Optional[torch.Tensor],
......@@ -51,9 +50,13 @@ class Float8BlockwiseQTensorBase(QuantizedTensorBase):
fp8_dtype: TE_DType,
quantizer: Quantizer,
is_2D_scaled: bool,
data_format: Float8BlockScaleTensorFormat = Float8BlockScaleTensorFormat.GEMM_READY,
data_format: Float8BlockScaleTensorFormat,
*args,
**kwargs,
):
if cls is Float8BlockwiseQTensorBase:
instance = object.__new__(cls)
else:
instance = super().__new__(cls, *args, **kwargs)
instance._rowwise_data = rowwise_data
instance._columnwise_data = columnwise_data
......
......@@ -143,6 +143,23 @@ class Float8TensorBase(QuantizedTensorBase):
size = self._transpose.size(*args, **kwargs)
return torch.Size([size[-1], math.prod(size[:-1])])
def view(self, shape: torch.Size):
# pylint: disable=missing-function-docstring
out_data = self._data.view(shape)
out_transpose = None if self._transpose_invalid else self._transpose
if out_transpose is not None:
out_transpose_shape = out_transpose.size()
if out_transpose_shape[0] != shape[-1] or out_transpose_shape[1:] != shape[:-1]:
out_transpose = None
return Float8TensorBase(
data=out_data,
fp8_scale_inv=self._scale_inv,
fp8_dtype=self._fp8_dtype,
data_transpose=out_transpose,
quantizer=self._quantizer,
)
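To make the transpose-caching rule above concrete, here is a tiny shape-only check (no FP8 data involved): the cached transpose survives the view only when its dimensions already equal (shape[-1], *shape[:-1]).

shape = (2, 3, 4)              # requested view shape
transpose_shape = (4, 2, 3)    # shape of a previously cached transpose
keep = transpose_shape[0] == shape[-1] and transpose_shape[1:] == shape[:-1]
print(keep)  # True, so the cached transpose is reused for this view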
def __repr__(self):
return (
"Float8TensorBase("
......
......@@ -6,6 +6,8 @@
from __future__ import annotations
from typing import Optional, Dict, Any, Tuple
from collections.abc import Iterable
import math
import torch
import transformer_engine_torch as tex
......@@ -66,15 +68,18 @@ class MXFP8TensorBase(QuantizedTensorBase):
def __new__(
cls,
*args,
rowwise_data: Optional[torch.Tensor],
rowwise_scale_inv: torch.Tensor,
rowwise_scale_inv: Optional[torch.Tensor],
columnwise_data: Optional[torch.Tensor],
columnwise_scale_inv: torch.Tensor,
columnwise_scale_inv: Optional[torch.Tensor],
fp8_dtype: TE_DType,
quantizer: Optional[Quantizer] = None,
quantizer: Optional[Quantizer],
*args,
**kwargs,
):
if cls is MXFP8TensorBase:
instance = object.__new__(cls)
else:
instance = super().__new__(cls, *args, **kwargs)
instance._rowwise_data = rowwise_data
instance._columnwise_data = columnwise_data
......@@ -145,6 +150,51 @@ class MXFP8TensorBase(QuantizedTensorBase):
return self._rowwise_data.size(*args, **kwargs)
return self._columnwise_data.size(*args, **kwargs)
def view(self, shape: torch.Size):
# pylint: disable=missing-function-docstring
# Return input tensor if view not needed
cur_shape = self.size()
if shape is None or shape == cur_shape:
return self
# Canonicalize shape
if not isinstance(shape, Iterable):
shape = [shape]
elif len(shape) == 1 and isinstance(shape[0], Iterable):
shape = shape[0]
if -1 in shape:
shape = list(shape)
d_inferred = -math.prod(cur_shape) // math.prod(shape)
for i, d in enumerate(shape):
if d == -1:
shape[i] = d_inferred
break
if shape[-1] != cur_shape[-1]:
raise RuntimeError(
"MXFP8Tensor does not support reshaping inner dimension "
f"(attempted to reshape dims={tuple(cur_shape)} to {tuple(shape)})"
)
# Construct new tensor
cur_rowwise_data = self._rowwise_data
cur_columnwise_data = self._columnwise_data
new_rowwise_data = None
new_columnwise_data = None
if cur_rowwise_data is not None:
new_rowwise_data = cur_rowwise_data.view(*shape)
if cur_columnwise_data is not None:
new_columnwise_data = cur_columnwise_data.view(*shape)
return MXFP8TensorBase(
rowwise_data=new_rowwise_data,
rowwise_scale_inv=self._rowwise_scale_inv,
columnwise_data=new_columnwise_data,
columnwise_scale_inv=self._columnwise_scale_inv,
fp8_dtype=self._fp8_dtype,
quantizer=self._quantizer,
)
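A worked example of the -1 inference used above: because -1 is present, math.prod(shape) is negative, so the floor division recovers the missing dimension exactly when the element counts divide evenly.

import math

cur_shape = (4, 8, 16)
shape = [-1, 16]
d_inferred = -math.prod(cur_shape) // math.prod(shape)  # -(512) // (-16) = 32
shape[shape.index(-1)] = d_inferred
assert shape == [32, 16] and math.prod(shape) == math.prod(cur_shape)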
def __repr__(self):
data_rowwise = self.dequantize()
......
......@@ -11,6 +11,7 @@ import torch
import transformer_engine_torch as tex
import os
from transformer_engine_torch import DType as TE_DType
from transformer_engine_torch import Float8BlockScaleTensorFormat
from transformer_engine.common.recipe import Float8BlockScaling, Recipe
from ._internal.float8_blockwise_tensor_base import Float8BlockwiseQTensorBase
......@@ -297,6 +298,37 @@ class Float8BlockwiseQTensor(Float8BlockwiseQTensorBase, QuantizedTensor):
holds configuration about quantization and dequantization modes.
"""
# NOTE: We reorder the *args so that we can instantiate a Float8BlockwiseQTensorBase with positional args,
# which significantly reduces the Pybind11 overhead when calling the constructor from C++.
def __new__(
cls,
*args,
rowwise_data: Optional[torch.Tensor],
rowwise_scale_inv: Optional[torch.Tensor],
columnwise_data: Optional[torch.Tensor],
columnwise_scale_inv: Optional[torch.Tensor],
fp8_dtype: TE_DType,
quantizer: Quantizer,
is_2D_scaled: bool,
data_format: tex.Float8BlockScaleTensorFormat = Float8BlockScaleTensorFormat.GEMM_READY,
**kwargs,
):
instance = super().__new__(
cls,
rowwise_data,
rowwise_scale_inv,
columnwise_data,
columnwise_scale_inv,
fp8_dtype,
quantizer,
is_2D_scaled,
data_format,
*args,
**kwargs,
)
return instance
def __repr__(self, *, tensor_contents=None):
return (
f"Float8BlockwiseQTensor(fp8_dtype={self._fp8_dtype},"
......
......@@ -167,6 +167,21 @@ class Float8Quantizer(Quantizer):
quantizer=self,
)
def onnx_quantize(self, tensor: torch.Tensor) -> QuantizedTensor:
"""Function using primitives with ONNX defined translations."""
# Q inputs are currently constrained to FP32 due to a similar limitation in ORT
# custom ops, so cast the input if needed.
if tensor.dtype != torch.float32:
tensor = tensor.to(torch.float32)
data = torch.ops.tex.fp8_quantize(tensor, self.scale.item())
return self.create_tensor_from_data(data, fake_dtype=torch.float32)
def onnx_dequantize(self, tensor: QuantizedTensor) -> torch.Tensor:
"""Function using primitives with ONNX defined translations."""
out = torch.ops.tex.fp8_dequantize(tensor._data, self.scale.item())
out = out.to(tensor.dtype)
return out
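A hedged sketch of how these two hooks might be combined during ONNX export to emit a fake-quantize (quantize then dequantize) pattern; only onnx_quantize/onnx_dequantize and the torch.ops.tex primitives they call appear in this diff, so the wrapper below is an assumption.

import torch

def onnx_fake_quant(quantizer, tensor: torch.Tensor) -> torch.Tensor:
    # quantizer is assumed to be a Float8Quantizer as defined above.
    qt = quantizer.onnx_quantize(tensor)   # lowers to torch.ops.tex.fp8_quantize
    return quantizer.onnx_dequantize(qt)   # lowers to torch.ops.tex.fp8_dequantize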
def _get_compatible_recipe(self) -> Union[type[Recipe], None]:
return DelayedScaling
......@@ -328,6 +343,18 @@ class Float8CurrentScalingQuantizer(Quantizer):
quantizer=self,
)
def onnx_quantize(self, tensor: torch.Tensor) -> QuantizedTensor:
"""Function using primitives with ONNX defined translations."""
raise NotImplementedError(
"Float8CurrentScalingQuantizer does not support ONNX quantization yet."
)
def onnx_dequantize(self, tensor: QuantizedTensor) -> torch.Tensor:
"""Function using primitives with ONNX defined translations."""
raise NotImplementedError(
"Float8CurrentScalingQuantizer does not support ONNX dequantization yet."
)
def _canonicalized_amax_reduction_group(self) -> dist_group_type:
"""Get process group for amax reduction"""
return canonicalize_process_group(self.amax_reduction_group)
......
......@@ -136,6 +136,34 @@ class MXFP8Quantizer(Quantizer):
# TODO(ksivamani): No calibration needed for mxfp8?
pass
def create_tensor_from_data(
self,
data: torch.Tensor,
scale_inv: torch.Tensor,
fake_dtype: torch.dtype,
fp8_dtype: TE_DType = tex.DType.kFloat8E4M3,
) -> MXFP8Tensor:
"""Create a new MXFP8Tensor from data and scale_inv."""
return MXFP8Tensor(
shape=data.shape,
dtype=fake_dtype,
rowwise_data=data,
rowwise_scale_inv=scale_inv,
columnwise_data=None,
columnwise_scale_inv=None,
fp8_dtype=fp8_dtype,
quantizer=self,
)
def onnx_quantize(self, tensor: torch.Tensor) -> QuantizedTensor:
if tensor.dtype != torch.float32:
tensor = tensor.to(dtype=torch.float32)
data, scale_inv = torch.ops.tex.mxfp8_quantize(tensor)
return self.create_tensor_from_data(data, scale_inv, fake_dtype=torch.float32)
def onnx_dequantize(self, tensor: Union[MXFP8TensorBase, MXFP8Tensor]) -> torch.Tensor:
return torch.ops.tex.mxfp8_dequantize(tensor._rowwise_data, tensor._rowwise_scale_inv)
def _get_compatible_recipe(self) -> Union[type[Recipe], None]:
return MXFP8BlockScaling
......@@ -165,6 +193,32 @@ class MXFP8Tensor(MXFP8TensorBase, QuantizedTensor):
"""
# NOTE: We reorder the *args so that we can instantiate a MXFP8TensorBase with positional args,
# which significantly reduces the Pybind11 overhead when calling the constructor from C++.
def __new__(
cls,
*args,
rowwise_data: Optional[torch.Tensor],
rowwise_scale_inv: Optional[torch.Tensor],
columnwise_data: Optional[torch.Tensor],
columnwise_scale_inv: Optional[torch.Tensor],
fp8_dtype: TE_DType,
quantizer: Optional[Quantizer],
**kwargs,
):
instance = super().__new__(
cls,
rowwise_data,
rowwise_scale_inv,
columnwise_data,
columnwise_scale_inv,
fp8_dtype,
quantizer,
*args,
**kwargs,
)
return instance
def __repr__(self, *, tensor_contents=None):
return f"MXFP8Tensor(fp8_dtype={self._fp8_dtype}, data={self.dequantize(dtype=self.dtype)})"
......@@ -302,6 +356,7 @@ class MXFP8Tensor(MXFP8TensorBase, QuantizedTensor):
fp8_dtype: TE_DType,
dtype: torch.dtype,
shape: torch.shape,
quantizer: Optional[Quantizer] = None,
) -> MXFP8Tensor:
"""Build MXFP8Tensor, for use in __reduce__
......@@ -317,6 +372,7 @@ class MXFP8Tensor(MXFP8TensorBase, QuantizedTensor):
columnwise_scale_inv=columnwise_scale_inv,
dtype=dtype,
shape=shape,
quantizer=quantizer,
)
def __reduce_ex__(self, protocol: int) -> tuple:
......@@ -331,6 +387,7 @@ class MXFP8Tensor(MXFP8TensorBase, QuantizedTensor):
self._fp8_dtype,
self.dtype,
self.shape,
self._quantizer,
),
)
......@@ -437,8 +494,7 @@ class _ViewFunc(torch.autograd.Function):
if tensor._rowwise_data is not None:
new_rowwise_data = tensor._rowwise_data.view(*shape)
if tensor._columnwise_data is not None:
columnwise_shape = [shape[-1]] + list(shape[:-1])
new_columnwise_data = tensor._columnwise_data.view(columnwise_shape)
new_columnwise_data = tensor._columnwise_data.view(*shape)
return MXFP8Tensor(
shape,
tensor.dtype,
......@@ -462,7 +518,7 @@ class _ViewFunc(torch.autograd.Function):
grad._rowwise_data.view(*ctx.shape) if grad._rowwise_data is not None else None
)
if grad._columnwise_data is not None:
new_columnwise_data = grad._columnwise_data.view(ctx.shape[-1], -1)
new_columnwise_data = grad._columnwise_data.view(*ctx.shape)
else:
new_columnwise_data = None
dgrad = MXFP8Tensor(
......@@ -523,8 +579,7 @@ class _ReshapeFunc(torch.autograd.Function):
if tensor._rowwise_data is not None:
new_rowwise_data = tensor._rowwise_data.reshape(*shape)
if tensor._columnwise_data is not None:
columnwise_shape = [shape[-1]] + list(shape[:-1])
new_columnwise_data = tensor._columnwise_data.view(columnwise_shape)
new_columnwise_data = tensor._columnwise_data.view(*shape)
return MXFP8Tensor(
shape,
......@@ -550,8 +605,7 @@ class _ReshapeFunc(torch.autograd.Function):
if grad._rowwise_data is not None:
new_rowwise_data = grad._rowwise_data.view(*ctx.shape)
if grad._columnwise_data is not None:
columnwise_shape = [ctx.shape[-1]] + list(ctx.shape[:-1])
new_columnwise_data = grad._columnwise_data.view(columnwise_shape)
new_columnwise_data = grad._columnwise_data.view(*ctx.shape)
dgrad = MXFP8Tensor(
ctx.shape,
grad.dtype,
......
......@@ -250,6 +250,12 @@ class Quantizer(abc.ABC):
"""Create shallow copy"""
return copy.copy(self)
def onnx_quantize(self, tensor: torch.Tensor) -> QuantizedTensor:
"""Symbolic function for ONNX export"""
def onnx_dequantize(self, tensor) -> torch.Tensor:
"""Symbolic function for ONNX export"""
@abc.abstractmethod
def _get_compatible_recipe(self) -> Union[type[Recipe], None]:
"""Returns recipe class that is compatible with this quantizer"""
......
......@@ -194,7 +194,7 @@ def _cast_master_weights_to_fp8_delayed_scaling(params, group, use_fsdp_shard_mo
quantizer.update_quantized(master_weight.view(1, -1), shard_model_weight_fp8)
if len(amaxes) > 0:
dummy_overflow_buf = torch.tensor([0], dtype=torch.int, device=amaxes[0].device)
dummy_overflow_buf = torch.zeros(1, dtype=torch.int, device=amaxes[0].device)
# Reduce amaxes.
packed_amaxes = torch.empty(len(amaxes), dtype=torch.float32, device=amaxes[0].device)
......
......@@ -33,6 +33,7 @@ from transformer_engine.pytorch.constants import (
dist_group_type,
)
from transformer_engine.pytorch.distributed import get_distributed_world_size
from transformer_engine.pytorch.export import is_in_onnx_export_mode
from transformer_engine.pytorch.module.base import TransformerEngineBaseModule
......@@ -814,7 +815,12 @@ class TransformerLayer(torch.nn.Module):
return output
def _bias_dropout_add(self, hidden_state, bias, residual, drop_path=None):
if drop_path is None and bias is not None and bias.numel() != 0:
if (
drop_path is None
and bias is not None
and bias.numel() != 0
and not is_in_onnx_export_mode()
):
if self.bias_dropout_fusion:
if self.training:
bias_dropout_add_func = bias_dropout_add_fused_train
......
......@@ -98,6 +98,7 @@ def cross_entropy_kernel(
ignore_idx,
n_cols,
n_non_ignore,
reduce_loss: tl.constexpr,
label_smoothing: tl.constexpr,
BLOCK_SIZE: tl.constexpr,
):
......@@ -177,7 +178,13 @@ def cross_entropy_kernel(
if label_smoothing > 0:
# scale X beforehand to avoid overflow
scaled_x_sum += tl.sum(tl.where(X_offsets < n_cols, -eps * X_block, 0.0))
# Scale gradients based on reduction mode
# For reduce_loss=True: PyTorch will scale by 1/n_rows, so we need to scale by n_rows/n_non_ignore
# For reduce_loss=False: No additional scaling from PyTorch, so we don't scale here
if reduce_loss:
X_block = (tl.exp(X_block - m) / d - eps) / (n_non_ignore)
else:
X_block = tl.exp(X_block - m) / d - eps
tl.store(X_ptr + X_offsets, X_block.to(grad_dtype), mask=X_offsets < n_cols)
# We need tl.debug_barrier() to ensure the new result of X_ptr is written
......@@ -205,7 +212,11 @@ def cross_entropy_kernel(
if y >= vocab_start_idx:
if y < vocab_end_idx:
X_y = tl.load(X_ptr + y - vocab_start_idx)
# Apply the same conditional scaling logic for the target token
if reduce_loss:
X_y += -(1 - label_smoothing) / (n_non_ignore)
else:
X_y += -(1 - label_smoothing)
tl.store(X_ptr + y - vocab_start_idx, X_y)
tl.store(loss_ptr, loss)
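A plain-PyTorch illustration (not the Triton kernel) of the scaling difference the comments above describe: with reduction="mean" each logit's gradient carries an extra 1/<non-ignored token count> factor relative to reduction="none".

import torch
import torch.nn.functional as F

logits = torch.randn(5, 10, requires_grad=True)
labels = torch.tensor([1, 3, -100, 2, 0])  # one ignored token

loss_mean = F.cross_entropy(logits, labels, ignore_index=-100, reduction="mean")
grad_mean, = torch.autograd.grad(loss_mean, logits)

loss_none = F.cross_entropy(logits, labels, ignore_index=-100, reduction="none").sum()
grad_none, = torch.autograd.grad(loss_none, logits)

torch.testing.assert_close(grad_mean * 4, grad_none)  # 4 non-ignored tokens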
......@@ -319,6 +330,7 @@ def cross_entropy_forward(
ignore_idx=ignore_idx,
n_cols=V,
n_non_ignore=n_rows,
reduce_loss=reduce_loss,
label_smoothing=label_smoothing,
BLOCK_SIZE=BLOCK_SIZE,
num_warps=16 if IS_HIP_EXTENSION else 32,
......