Unverified Commit 37da2d3b authored by Jan Bielak, committed by GitHub

Add backward fusions of dbias+quantize and dbias+dactivation+quantize to `te.Sequential` (#1942)



* Fix clearing tensor data in backward, removing is_first_op
Signed-off-by: Jan Bielak <jbielak@nvidia.com>

* Misc fixes
Signed-off-by: Jan Bielak <jbielak@nvidia.com>

* Use Linear weight dtype and device for compute consistently
Signed-off-by: Jan Bielak <jbielak@nvidia.com>

* Add backward dbias + quantize fusion
Signed-off-by: Jan Bielak <jbielak@nvidia.com>

* Pass recipe to OperationFuser to allow recipe-dependent fusions
Signed-off-by: Jan Bielak <jbielak@nvidia.com>

* Remove redundant view from activations
Signed-off-by: Jan Bielak <jbielak@nvidia.com>

* Add bias activation backward fusion
Signed-off-by: Jan Bielak <jbielak@nvidia.com>

* Apply suggestions from code review
Co-authored-by: Tim Moon <4406448+timmoon10@users.noreply.github.com>
Signed-off-by: Jan Bielak <jbielak@nvidia.com>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci



---------
Signed-off-by: Jan Bielak <jbielak@nvidia.com>
Co-authored-by: Tim Moon <4406448+timmoon10@users.noreply.github.com>
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
parent ac76d55c
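Note: for context, the two new backward fusions combine work that previously ran as separate kernels: the bias gradient is a column-wise reduction of grad_output, the activation gradient is a pointwise chain-rule step, and quantization casts the result to FP8. A minimal PyTorch sketch of the reference math these fusions replace (assuming a GELU activation and a simple scale-and-clamp stand-in for the real FP8 quantizer):

import torch

def dbias_dgelu_quantize_reference(grad_output, gelu_input, scale=1.0):
    """Unfused reference for what dbias + dactivation + quantize computes.
    The fused kernel produces the same results in a single pass over
    grad_output instead of three."""
    # dactivation: backprop grad_output through GELU at the forward input
    gelu_input = gelu_input.detach().requires_grad_(True)
    with torch.enable_grad():
        act_out = torch.nn.functional.gelu(gelu_input)
    (dgrad,) = torch.autograd.grad(act_out, gelu_input, grad_output)
    # dbias: reduce over all dimensions except the feature dimension
    dbias = dgrad.reshape(-1, dgrad.shape[-1]).sum(dim=0)
    # quantize: stand-in for FP8 casting (clamp to the E4M3 range)
    dgrad_fp8 = torch.clamp(dgrad * scale, min=-448.0, max=448.0)
    return dbias, dgrad_fp8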
@@ -61,7 +61,6 @@ class ForwardLinearBiasActivation(FusedOperation):
         basic_op_extra_inputs: list[tuple[torch.Tensor, ...]],
         prev_op_grad_input_quantizer: Optional[Quantizer],
         next_op_input_quantizer: Optional[Quantizer],
-        is_first_op: bool,
         basic_op_kwargs: list[dict[str, Any]],
     ) -> tuple[torch.Tensor, Iterable[Iterable[torch.Tensor]]]:
@@ -71,10 +70,12 @@ class ForwardLinearBiasActivation(FusedOperation):
         linear_op_ctx = basic_op_ctxs[idx]
         if self._op_idxs["bias"] is None:
             bias_op = None
+            bias_op_ctx = None
             bias = None
         else:
             idx = self._op_idxs["bias"]
             bias_op = self.basic_ops[idx]
+            bias_op_ctx = basic_op_ctxs[idx]
             bias = bias_op.bias
             if basic_op_kwargs[idx]:
                 raise ValueError("Bias operation forward does not expect keyword arguments")
@@ -102,9 +103,10 @@ class ForwardLinearBiasActivation(FusedOperation):
             grad_input_quantizer = prev_op_grad_input_quantizer

         # Get autocast dtype if needed
-        dtype = None
         if torch.is_autocast_enabled():
             dtype = torch.get_autocast_dtype("cuda")
+        else:
+            dtype = linear_op.weight.dtype

         # Linear forward
         output, x_local, w = BasicLinear._functional_forward(
@@ -133,7 +135,9 @@ class ForwardLinearBiasActivation(FusedOperation):
         linear_op_ctx.dtype = dtype
         linear_op_ctx.input_requires_grad = input_requires_grad
         linear_op_ctx.weight_requires_grad = weight_requires_grad
-        linear_op_ctx.has_prev_op = not is_first_op
+        if bias_op is not None:
+            bias_op_ctx.with_quantized_compute = with_quantized_compute
+            bias_op_ctx.grad_input_quantizer = linear_op.get_grad_input_quantizer()

         return output, [() for _ in range(len(self.basic_ops))]
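Note: the last hunk above stashes with_quantized_compute and a grad_input_quantizer on the Bias op's autograd context during the fused forward. A hypothetical sketch of how the backward dbias + quantize fusion can consume that state (OpContext and backward_dbias_quantize below are illustrative names, not TE's implementation):

import torch
from dataclasses import dataclass
from typing import Callable, Optional

@dataclass
class OpContext:
    """Stand-in for the bias op's OperationContext."""
    with_quantized_compute: bool = False
    grad_input_quantizer: Optional[Callable[[torch.Tensor], torch.Tensor]] = None

def backward_dbias_quantize(ctx: OpContext, grad_output: torch.Tensor):
    # dbias is a column-wise reduction of grad_output
    dbias = grad_output.reshape(-1, grad_output.shape[-1]).sum(dim=0)
    if ctx.with_quantized_compute and ctx.grad_input_quantizer is not None:
        # A fused kernel computes dbias while quantizing grad_output,
        # reading the tensor from memory only once
        grad_input = ctx.grad_input_quantizer(grad_output)
    else:
        grad_input = grad_output
    return grad_input, dbias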
@@ -59,7 +59,6 @@ class ForwardLinearBiasAdd(FusedOperation):
         basic_op_extra_inputs: list[tuple[torch.Tensor, ...]],
         prev_op_grad_input_quantizer: Optional[Quantizer],
         next_op_input_quantizer: Optional[Quantizer],
-        is_first_op: bool,
         basic_op_kwargs: list[dict[str, Any]],
     ) -> tuple[torch.Tensor, Iterable[Iterable[torch.Tensor]]]:
@@ -69,10 +68,12 @@ class ForwardLinearBiasAdd(FusedOperation):
         linear_op_ctx = basic_op_ctxs[idx]
         if self._op_idxs["bias"] is None:
             bias_op = None
+            bias_op_ctx = None
             bias = None
         else:
             idx = self._op_idxs["bias"]
             bias_op = self.basic_ops[idx]
+            bias_op_ctx = basic_op_ctxs[idx]
             bias = bias_op.bias
             if basic_op_kwargs[idx]:
                 raise ValueError("Bias operation forward does not expect keyword arguments")
@@ -95,9 +96,10 @@ class ForwardLinearBiasAdd(FusedOperation):
             grad_input_quantizer = prev_op_grad_input_quantizer

         # Get autocast dtype if needed
-        dtype = None
         if torch.is_autocast_enabled():
             dtype = torch.get_autocast_dtype("cuda")
+        else:
+            dtype = linear_op.weight.dtype

         # Linear forward
         output = basic_op_extra_inputs[self._op_idxs["add"]][0]
@@ -105,6 +107,7 @@ class ForwardLinearBiasAdd(FusedOperation):
             input=input_,
             weight=linear_op.weight,
             bias=bias,
+            dtype=output.dtype,
             out=output,
             accumulate_into_out=True,
             tensor_parallel_mode=linear_op.tensor_parallel_mode,
@@ -128,7 +131,9 @@ class ForwardLinearBiasAdd(FusedOperation):
         linear_op_ctx.dtype = dtype
         linear_op_ctx.input_requires_grad = input_requires_grad
         linear_op_ctx.weight_requires_grad = weight_requires_grad
-        linear_op_ctx.has_prev_op = not is_first_op
+        if bias_op is not None:
+            bias_op_ctx.with_quantized_compute = with_quantized_compute
+            bias_op_ctx.grad_input_quantizer = linear_op.get_grad_input_quantizer()

         return output, [() for _ in range(len(self.basic_ops))]
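Note: the new dtype=output.dtype argument matters because this fused op accumulates the GEMM result into the "add" extra input rather than allocating a fresh output, so the compute dtype must follow the accumulation tensor, not the autocast dtype. A small illustration (shapes and names assumed):

import torch

out = torch.randn(16, 32, dtype=torch.float32)  # extra input of the "add" op
x = torch.randn(16, 64, dtype=torch.bfloat16)
w = torch.randn(32, 64, dtype=torch.bfloat16)

# out += x @ w.T must run in out's dtype; accumulating bf16 products into
# a float32 tensor without the cast would raise a dtype mismatch
out.add_(x.to(out.dtype) @ w.to(out.dtype).T)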
@@ -547,8 +547,7 @@ class UserbuffersBackwardLinear(FusedOperation):
             grad_bias = extra_outputs["grad_bias"]

         # Clear input tensor if possible
-        if linear_op_ctx.has_prev_op:
-            clear_tensor_data(x_local)
+        clear_tensor_data(x_local)

         # Return gradients
         grad_params = [() for _ in range(len(self.basic_ops))]
@@ -569,13 +568,13 @@ def fuse_userbuffers_backward_linear(
     Parameters
     ----------
     ops: list of tuples
-        Forward pass operations and the indices of the corresponding
+        Backward pass operations and the indices of the corresponding
         basic operations.

     Returns
     -------
     ops: list of tuples
-        Updated forward pass operations
+        Updated backward pass operations

     """
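Note: the has_prev_op bookkeeping disappears because safety now travels with the tensors themselves: the fuser tags user-visible inputs with do_not_clear (see the fuser.py hunks below), so clearing can be unconditional here. A simplified stand-in for what a flag-aware clear helper can look like (TE ships its own clear_tensor_data; this sketch only illustrates the contract):

import torch

def clear_tensor_data(*tensors):
    """Free saved activation storage early, skipping tensors that were
    tagged as not safe to clear (e.g. tensors the user still holds)."""
    for t in tensors:
        if t is None or getattr(t, "do_not_clear", False):
            continue
        t.data = torch.empty((0,), device=t.device, dtype=t.dtype)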
@@ -23,7 +23,6 @@ from ...module.base import (
 from ...tensor.quantized_tensor import Quantizer
 from ...tensor.float8_tensor import Float8Quantizer, Float8CurrentScalingQuantizer
 from ...tensor._internal.float8_tensor_base import Float8TensorBase
-from ...utils import canonicalize_device, canonicalize_dtype
 from .._common import maybe_dequantize, is_quantized_tensor
 from ..basic import BasicLinear, Bias, ReduceScatter
 from ..op import (
@@ -88,8 +87,8 @@ class UserbuffersForwardLinear(FusedOperation):
         weight: torch.Tensor,
         *,
         bias: Optional[torch.Tensor] = None,
-        device: Optional[torch.device] = None,
-        dtype: Optional[torch.dtype] = None,
+        device: torch.device,
+        dtype: torch.dtype,
         tensor_parallel_mode: Optional[str] = None,
         tensor_parallel_group: Optional[torch.distributed.ProcessGroup] = None,
         tensor_parallel_size: Optional[int] = None,
@@ -112,9 +111,9 @@ class UserbuffersForwardLinear(FusedOperation):
             Weight tensor
         bias: torch.Tensor, optional
             Bias tensor
-        device: torch.device, default = default CUDA device
+        device: torch.device
             Tensor device
-        dtype: torch.dtype, default = default dtype
+        dtype: torch.dtype
             Tensor datatype
         tensor_parallel_mode: {`None`, "column", "row"}, default = `None`
             Mode for tensor parallelism
@@ -156,16 +155,10 @@ class UserbuffersForwardLinear(FusedOperation):
         """

         # Check device
-        if device is None:
-            device = weight.device
-        device = canonicalize_device(device)
         if device.type != "cuda":
             raise ValueError(f"Only CUDA devices are supported (got {device})")

         # Check datatype
-        if dtype is None:
-            dtype = weight.dtype
-        dtype = canonicalize_dtype(dtype)
         if dtype not in (torch.float32, torch.float16, torch.bfloat16):
             raise ValueError(f"Supported dtypes are float32, float16, bfloat16 (got {dtype})")
@@ -291,7 +284,6 @@ class UserbuffersForwardLinear(FusedOperation):
         basic_op_extra_inputs: list[tuple[torch.Tensor, ...]],
         prev_op_grad_input_quantizer: Optional[Quantizer],
         next_op_input_quantizer: Optional[Quantizer],
-        is_first_op: bool,
         basic_op_kwargs: list[dict[str, Any]],
     ) -> tuple[torch.Tensor, Iterable[Iterable[torch.Tensor]]]:
@@ -300,10 +292,12 @@ class UserbuffersForwardLinear(FusedOperation):
         linear_op = self.basic_ops[idx]
         linear_op_ctx = basic_op_ctxs[idx]
         bias_op = None
+        bias_op_ctx = None
         bias = None
         if self._op_idxs["bias"] is not None:
             idx = self._op_idxs["bias"]
             bias_op = self.basic_ops[idx]
+            bias_op_ctx = basic_op_ctxs[idx]
             bias = bias_op.bias
             if basic_op_kwargs[idx]:
                 raise ValueError("Bias operation forward does not expect keyword arguments")
@@ -330,9 +324,10 @@ class UserbuffersForwardLinear(FusedOperation):
             grad_input_quantizer = prev_op_grad_input_quantizer

         # Get autocast dtype if needed
-        dtype = None
         if torch.is_autocast_enabled():
             dtype = torch.get_autocast_dtype("cuda")
+        else:
+            dtype = linear_op.weight.dtype

         # Userbuffers options
         if linear_op._userbuffers_options is None:
@@ -344,6 +339,7 @@ class UserbuffersForwardLinear(FusedOperation):
             weight=linear_op.weight,
             bias=bias,
             dtype=dtype,
+            device=linear_op.weight.device,
             tensor_parallel_mode=self.tensor_parallel_mode,
             tensor_parallel_group=self.tensor_parallel_group,
             tensor_parallel_size=self.tensor_parallel_size,
@@ -370,7 +366,9 @@ class UserbuffersForwardLinear(FusedOperation):
         linear_op_ctx.input_dims = input_.size()
         linear_op_ctx.input_requires_grad = input_requires_grad
         linear_op_ctx.weight_requires_grad = weight_requires_grad
-        linear_op_ctx.has_prev_op = not is_first_op
+        if bias_op is not None:
+            bias_op_ctx.with_quantized_compute = with_quantized_compute
+            bias_op_ctx.grad_input_quantizer = linear_op.get_grad_input_quantizer()

         return output, [() for _ in range(len(self.basic_ops))]
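Note: making device and dtype required (instead of defaulting and canonicalizing) pushes resolution to the call sites, which now pass the linear op's weight attributes explicitly, as the fuser_forward hunk above shows with device=linear_op.weight.device. The tightened checks reduce to roughly:

import torch

def check_device_dtype(device: torch.device, dtype: torch.dtype) -> None:
    # No None fallbacks and no canonicalization: callers resolve both
    # values up front, typically from the weight tensor
    if device.type != "cuda":
        raise ValueError(f"Only CUDA devices are supported (got {device})")
    if dtype not in (torch.float32, torch.float16, torch.bfloat16):
        raise ValueError(f"Supported dtypes are float32, float16, bfloat16 (got {dtype})")

weight = torch.empty(8, 8, device="cuda", dtype=torch.bfloat16)  # needs a GPU
check_device_dtype(weight.device, weight.dtype)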
@@ -10,13 +10,14 @@ from typing import Any, Optional

 import torch

-from transformer_engine.pytorch.fp8 import FP8GlobalStateManager
+from transformer_engine.pytorch.fp8 import FP8GlobalStateManager, Recipe
 from transformer_engine.pytorch.ops.op import (
     BasicOperation,
     FusibleOperation,
     OperationContext,
 )
 from transformer_engine.pytorch.ops.fused import (
+    fuse_backward_bias_activation,
     fuse_backward_linear_add,
     fuse_forward_linear_bias_activation,
     fuse_forward_linear_bias_add,
@@ -100,6 +101,10 @@ class _OperationFuserAutogradFunction(torch.autograd.Function):
         # Operation autograd contexts
         basic_op_ctxs = [OperationContext() for _ in range(fuser._num_basic_ops)]

+        # Mark input tensors as not deletable in backward
+        for tensor in (input_,) + params_and_extra_inputs:
+            tensor.do_not_clear = True
+
         # Unflatten list of parameters and extra tensor inputs
         extra_inputs = params_and_extra_inputs[-fuser._num_extra_inputs :]
         basic_op_extra_inputs = []
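Note: the do_not_clear tag works because torch.Tensor instances accept arbitrary Python attributes; the flag travels with the tensor object into the backward pass, where the clear helper checks it:

import torch

t = torch.randn(4)
t.do_not_clear = True  # plain Python attribute on the tensor object
assert getattr(t, "do_not_clear", False)
assert not getattr(torch.randn(4), "do_not_clear", False)  # default: clearable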
@@ -110,6 +115,7 @@ class _OperationFuserAutogradFunction(torch.autograd.Function):

         # Apply forward ops
         x = input_
         requires_grad = is_grad_enabled and x.requires_grad
+        with_quantized_compute = FP8GlobalStateManager.is_fp8_enabled()
         extra_outputs = [None] * fuser._num_basic_ops
         for op, basic_op_idxs in fuser._forward_ops:
@@ -125,16 +131,15 @@ class _OperationFuserAutogradFunction(torch.autograd.Function):
             # Forward op
             extra_inputs = [basic_op_extra_inputs[idx] for idx in basic_op_idxs]
             prev_op_idx = basic_op_idxs[0] - 1
-            prev_op = fuser._basic_ops[prev_op_idx] if prev_op_idx > 0 else None
+            prev_op = fuser._basic_ops[prev_op_idx] if prev_op_idx >= 0 else None
             prev_op_grad_input_quantizer = None
-            if prev_op is not None:
+            if prev_op is not None and with_quantized_compute:
                 prev_op_grad_input_quantizer = prev_op.get_grad_input_quantizer()
             next_op_idx = basic_op_idxs[-1] + 1
             next_op = fuser._basic_ops[next_op_idx] if next_op_idx < fuser._num_basic_ops else None
             next_op_input_quantizer = None
-            if next_op is not None:
+            if next_op is not None and with_quantized_compute:
                 next_op_input_quantizer = next_op.get_input_quantizer()
-            is_first_op = prev_op is None
             x, fused_op_extra_outputs = op.fuser_forward(
                 [basic_op_ctxs[idx] for idx in basic_op_idxs],
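Note: the change from prev_op_idx > 0 to prev_op_idx >= 0 fixes an off-by-one bug: for the second op in the pipeline (basic_op_idxs[0] == 1), the predecessor index is 0, which the old test rejected, so the first op's grad_input_quantizer was never picked up. In miniature:

basic_ops = ["linear", "bias", "gelu"]

def get_prev_op(op_idx: int, fixed: bool):
    prev_idx = op_idx - 1
    valid = prev_idx >= 0 if fixed else prev_idx > 0
    return basic_ops[prev_idx] if valid else None

assert get_prev_op(1, fixed=False) is None     # bug: op 0 was invisible
assert get_prev_op(1, fixed=True) == "linear"  # fix: index 0 is valid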
@@ -142,7 +147,6 @@ class _OperationFuserAutogradFunction(torch.autograd.Function):
                 basic_op_extra_inputs=extra_inputs,
                 prev_op_grad_input_quantizer=prev_op_grad_input_quantizer,
                 next_op_input_quantizer=next_op_input_quantizer,
-                is_first_op=is_first_op,
                 basic_op_kwargs=[basic_op_kwargs[idx] for idx in basic_op_idxs],
             )
             for idx, ys in zip(basic_op_idxs, fused_op_extra_outputs):
@@ -177,7 +181,6 @@ class _OperationFuserAutogradFunction(torch.autograd.Function):
             ctx._saved_tensors_range = (range_start, range_end)

         # Save tensors for backward
-        with_quantized_compute = FP8GlobalStateManager.is_fp8_enabled()
         if with_quantized_compute:
             tensors_to_save, tensor_objects = prepare_for_saving(*to_save)
             func_ctx.save_for_backward(*tensors_to_save)
@@ -195,11 +198,11 @@ class _OperationFuserAutogradFunction(torch.autograd.Function):
         func_ctx.is_first_module = FP8GlobalStateManager.is_first_fp8_module()
         func_ctx.with_quantized_compute = with_quantized_compute

+        x.requires_grad_(requires_grad)
         if extra_outputs_flat:
             return x, *extra_outputs_flat
-        x.requires_grad_(requires_grad)
         return x

     @staticmethod
@@ -314,15 +317,19 @@ class OperationFuser:
     ----------
     ops: list of FusibleOperation
         Pipeline of operations
-    fuse_ops: bool, default = `True`
+    fuse_ops: bool
        Whether to attempt fusing operations
+    recipe: Recipe, optional
+        Quantization recipe to use when fusing and executing operations.
+        Note: certain fusions may depend on what kind of recipe is being used.

     """

     def __init__(
         self,
         ops: list[FusibleOperation],
-        fuse_ops: bool = True,
+        fuse_ops: bool,
+        recipe: Optional[Recipe],
     ) -> None:

         # Get list of basic operations
@@ -348,6 +355,7 @@ class OperationFuser:
         self._is_first_forward = True

         # Fuse ops if needed
+        self.recipe = recipe
         if fuse_ops:
             self.fuse_ops()
@@ -359,6 +367,7 @@ class OperationFuser:
     def _fuse_forward_ops(
         cls,
         ops: list[tuple[FusibleOperation, list[int]]],
+        recipe: Optional[Recipe],  # pylint: disable=unused-argument
     ) -> list[tuple[FusibleOperation, list[int]]]:
         """Attempt to fuse operations in forward pass"""
         ops = fuse_userbuffers_forward_linear(ops)
@@ -370,16 +379,18 @@ class OperationFuser:
     def _fuse_backward_ops(
         cls,
         ops: list[tuple[FusibleOperation, list[int]]],
+        recipe: Optional[Recipe],
     ) -> list[tuple[FusibleOperation, list[int]]]:
         """Attempt to fuse operations in backward pass"""
         ops = fuse_userbuffers_backward_linear(ops)
         ops = fuse_backward_linear_add(ops)
+        ops = fuse_backward_bias_activation(ops, recipe)
         return ops

     def fuse_ops(self) -> None:
         """Attempt to fuse operations"""
-        self._forward_ops = self._fuse_forward_ops(self._forward_ops)
-        self._backward_ops = self._fuse_backward_ops(self._backward_ops)
+        self._forward_ops = self._fuse_forward_ops(self._forward_ops, self.recipe)
+        self._backward_ops = self._fuse_backward_ops(self._backward_ops, self.recipe)

     def __call__(
         self,
@@ -395,10 +406,8 @@ class OperationFuser:

         # Initialization before forward pass
         if self._is_first_forward:
-            with_quantized_compute = FP8GlobalStateManager.is_fp8_enabled()
-            recipe = FP8GlobalStateManager.get_fp8_recipe() if with_quantized_compute else None
             for op in self._basic_ops:
-                op.pre_first_forward(recipe=recipe)
+                op.pre_first_forward(recipe=self.recipe)
             self._is_first_forward = False

         # Canonicalize op kwargs
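Note: with the recipe now a constructor argument, a fuser can make recipe-dependent fusion decisions up front (fuse_backward_bias_activation receives it above). A hedged usage sketch based on the signature in this diff; the ops list is left abstract:

from transformer_engine.common.recipe import DelayedScaling
from transformer_engine.pytorch.ops.fuser import OperationFuser

def make_fuser(ops, recipe=None):
    # recipe=None means unquantized compute; fusions that require a
    # particular recipe kind simply do not fire in that case
    return OperationFuser(ops, fuse_ops=True, recipe=recipe)

# fuser = make_fuser([linear_op, bias_op, gelu_op], recipe=DelayedScaling())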
@@ -86,7 +86,6 @@ class FusibleOperation(torch.nn.Module, metaclass=abc.ABCMeta):
         basic_op_extra_inputs: list[tuple[torch.Tensor, ...]],
         prev_op_grad_input_quantizer: Optional[Quantizer],
         next_op_input_quantizer: Optional[Quantizer],
-        is_first_op: bool,
         basic_op_kwargs: list[dict[str, Any]],
     ) -> tuple[torch.Tensor, Iterable[Iterable[torch.Tensor]]]:
         """Forward pass
@@ -109,10 +108,6 @@ class FusibleOperation(torch.nn.Module, metaclass=abc.ABCMeta):
             The grad_input_quantizer of the preceding operation
         next_op_input_quantizer: Quantizer, optional
             The input_quantizer of the following operation
-        is_first_op: bool
-            Whether this op is the first one in the fuser, i.e. has no
-            preceding op. Used in the backward pass to safely delete the
-            saved input tensor when it is no longer needed.
         basic_op_kwargs: list of dict
             Keyword arguments to forward functions of basic
             operations.
@@ -421,7 +416,6 @@ class BasicOperation(FusibleOperation, metaclass=abc.ABCMeta):
         *,
         prev_op_grad_input_quantizer: Optional[Quantizer],
         next_op_input_quantizer: Optional[Quantizer],
-        is_first_op: bool,
         **kwargs: Any,
     ) -> torch.Tensor:
         """Forward pass
@@ -436,10 +430,6 @@ class BasicOperation(FusibleOperation, metaclass=abc.ABCMeta):
             The grad_input_quantizer of the preceding operation
         next_op_input_quantizer: Quantizer, optional
             The input_quantizer of the following operation
-        is_first_op: bool
-            Whether this op is the first one in the fuser, i.e. has no
-            preceding op. Used in the backward pass to safely delete the
-            saved input tensor when it is no longer needed.

         Returns
         -------
@@ -480,7 +470,6 @@ class BasicOperation(FusibleOperation, metaclass=abc.ABCMeta):
         basic_op_extra_inputs: list[tuple[torch.Tensor, ...]],
         prev_op_grad_input_quantizer: Optional[Quantizer],
         next_op_input_quantizer: Optional[Quantizer],
-        is_first_op: bool,
         basic_op_kwargs: list[dict[str, Any]],
     ) -> tuple[torch.Tensor, list[tuple[()]]]:
         if self.num_extra_inputs > 0 or self.num_extra_outputs > 0:
@@ -495,7 +484,6 @@ class BasicOperation(FusibleOperation, metaclass=abc.ABCMeta):
             input_,
             prev_op_grad_input_quantizer=prev_op_grad_input_quantizer,
             next_op_input_quantizer=next_op_input_quantizer,
-            is_first_op=is_first_op,
             **basic_op_kwargs[0],
         )
         return output, [()]
@@ -530,7 +518,9 @@ class BasicOperation(FusibleOperation, metaclass=abc.ABCMeta):
         """Apply operation"""
         from .fuser import OperationFuser

-        return OperationFuser([self], fuse_ops=False)(
+        with_quantized_compute = FP8GlobalStateManager.is_fp8_enabled()
+        recipe = FP8GlobalStateManager.get_fp8_recipe() if with_quantized_compute else None
+        return OperationFuser([self], fuse_ops=False, recipe=recipe)(
             input,
             *extra_inputs,
             basic_op_kwargs=[kwargs],
@@ -737,7 +727,9 @@ class FusedOperation(FusibleOperation):
         basic_op_kwargs = [{} for _ in range(len(self.basic_ops))]

         from .fuser import OperationFuser

-        return OperationFuser([self], fuse_ops=False)(
+        with_quantized_compute = FP8GlobalStateManager.is_fp8_enabled()
+        recipe = FP8GlobalStateManager.get_fp8_recipe() if with_quantized_compute else None
+        return OperationFuser([self], fuse_ops=False, recipe=recipe)(
             input,
             *extra_inputs,
             basic_op_kwargs=basic_op_kwargs,
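Note: both __call__ sites above resolve the recipe the same way before constructing a single-op fuser; the pattern, standalone (the two calls are exactly the ones used in the diff):

from transformer_engine.pytorch.fp8 import FP8GlobalStateManager

def current_recipe():
    """Recipe from global state when FP8 is enabled, else None."""
    if FP8GlobalStateManager.is_fp8_enabled():
        return FP8GlobalStateManager.get_fp8_recipe()
    return None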
@@ -10,7 +10,7 @@ from typing import Optional

 import torch

-from transformer_engine.pytorch.fp8 import FP8GlobalStateManager
+from transformer_engine.pytorch.fp8 import FP8GlobalStateManager, Recipe
 from transformer_engine.pytorch.ops.op import FusibleOperation
 from transformer_engine.pytorch.ops.fuser import OperationFuser
@@ -147,6 +147,7 @@ class Sequential(torch.nn.Module):
     def _make_module_groups(
         cls,
         modules: Iterable[torch.nn.Module],
+        recipe: Optional[Recipe],
     ) -> list[OperationFuser | torch.nn.Module]:
         """Make list of modules, with fusible operations grouped together"""
@@ -161,7 +162,7 @@ class Sequential(torch.nn.Module):
                 groups.append(module)
         for idx, group in enumerate(groups):
             if isinstance(group, list):
-                groups[idx] = OperationFuser(group, fuse_ops=True)
+                groups[idx] = OperationFuser(group, fuse_ops=True, recipe=recipe)

         # Check if operations expect extra input or output tensors
         # Note: If any op has extra inputs or outputs, then the entire
@@ -190,9 +191,9 @@ class Sequential(torch.nn.Module):
         """Forward pass"""

         # Get current global state
-        fp8_enabled = FP8GlobalStateManager.is_fp8_enabled()
-        fp8_recipe = FP8GlobalStateManager.get_fp8_recipe() if fp8_enabled else None
-        global_state = (fp8_enabled, type(fp8_recipe))
+        with_quantized_compute = FP8GlobalStateManager.is_fp8_enabled()
+        recipe = FP8GlobalStateManager.get_fp8_recipe() if with_quantized_compute else None
+        global_state = (with_quantized_compute, type(recipe))

         # Reset module groups if global state changed
         if self._last_global_state != global_state:
@@ -201,7 +202,7 @@ class Sequential(torch.nn.Module):
         # Create module groups if needed
         if self._module_groups is None:
-            self._module_groups = self._make_module_groups(self._modules.values())
+            self._module_groups = self._make_module_groups(self._modules.values(), recipe)

         # Forward pass for each module group
         x = input
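Note: end to end, the fusions added by this PR target pipelines like linear -> bias -> activation inside te.Sequential. A hedged sketch of exercising them (operation names assumed from TE's fusible-ops API; requires an FP8-capable GPU):

import torch
import transformer_engine.pytorch as te
from transformer_engine.common.recipe import DelayedScaling

model = te.ops.Sequential(
    te.ops.BasicLinear(1024, 1024),
    te.ops.Bias(1024),
    te.ops.GELU(),
)

x = torch.randn(16, 1024, device="cuda", requires_grad=True)
with te.fp8_autocast(enabled=True, fp8_recipe=DelayedScaling()):
    y = model(x)
y.sum().backward()  # backward can take the fused dbias + dgelu + quantize path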