Unverified Commit 37da2d3b authored by Jan Bielak's avatar Jan Bielak Committed by GitHub
Browse files

Add backward fusions of dbias+quantize and dbias+dactivation+quantize to `te.Sequential` (#1942)



* Fix clearing tensor data in backward removing is_first_op
Signed-off-by: default avatarJan Bielak <jbielak@nvidia.com>

* Misc fixes
Signed-off-by: default avatarJan Bielak <jbielak@nvidia.com>

* Use Linear weight dtype and device for compute consistently
Signed-off-by: default avatarJan Bielak <jbielak@nvidia.com>

* Add backward dbias + quantize fusion
Signed-off-by: default avatarJan Bielak <jbielak@nvidia.com>

* Pass recipe to OperationFuser to allow recipe-dependent fusions
Signed-off-by: default avatarJan Bielak <jbielak@nvidia.com>

* Remove redundant view from activations
Signed-off-by: default avatarJan Bielak <jbielak@nvidia.com>

* Add bias activation backward fusion
Signed-off-by: default avatarJan Bielak <jbielak@nvidia.com>

* Apply suggestions from code review
Co-authored-by: default avatarTim Moon <4406448+timmoon10@users.noreply.github.com>
Signed-off-by: default avatarJan Bielak <jbielak@nvidia.com>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci



---------
Signed-off-by: default avatarJan Bielak <jbielak@nvidia.com>
Co-authored-by: default avatarTim Moon <4406448+timmoon10@users.noreply.github.com>
Co-authored-by: default avatarpre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
parent ac76d55c
...@@ -61,7 +61,6 @@ class ForwardLinearBiasActivation(FusedOperation): ...@@ -61,7 +61,6 @@ class ForwardLinearBiasActivation(FusedOperation):
basic_op_extra_inputs: list[tuple[torch.Tensor, ...]], basic_op_extra_inputs: list[tuple[torch.Tensor, ...]],
prev_op_grad_input_quantizer: Optional[Quantizer], prev_op_grad_input_quantizer: Optional[Quantizer],
next_op_input_quantizer: Optional[Quantizer], next_op_input_quantizer: Optional[Quantizer],
is_first_op: bool,
basic_op_kwargs: list[dict[str, Any]], basic_op_kwargs: list[dict[str, Any]],
) -> tuple[torch.Tensor, Iterable[Iterable[torch.Tensor]]]: ) -> tuple[torch.Tensor, Iterable[Iterable[torch.Tensor]]]:
...@@ -71,10 +70,12 @@ class ForwardLinearBiasActivation(FusedOperation): ...@@ -71,10 +70,12 @@ class ForwardLinearBiasActivation(FusedOperation):
linear_op_ctx = basic_op_ctxs[idx] linear_op_ctx = basic_op_ctxs[idx]
if self._op_idxs["bias"] is None: if self._op_idxs["bias"] is None:
bias_op = None bias_op = None
bias_op_ctx = None
bias = None bias = None
else: else:
idx = self._op_idxs["bias"] idx = self._op_idxs["bias"]
bias_op = self.basic_ops[idx] bias_op = self.basic_ops[idx]
bias_op_ctx = basic_op_ctxs[idx]
bias = bias_op.bias bias = bias_op.bias
if basic_op_kwargs[idx]: if basic_op_kwargs[idx]:
raise ValueError("Bias operation forward does not expect keyword arguments") raise ValueError("Bias operation forward does not expect keyword arguments")
...@@ -102,9 +103,10 @@ class ForwardLinearBiasActivation(FusedOperation): ...@@ -102,9 +103,10 @@ class ForwardLinearBiasActivation(FusedOperation):
grad_input_quantizer = prev_op_grad_input_quantizer grad_input_quantizer = prev_op_grad_input_quantizer
# Get autocast dtype if needed # Get autocast dtype if needed
dtype = None
if torch.is_autocast_enabled(): if torch.is_autocast_enabled():
dtype = torch.get_autocast_dtype("cuda") dtype = torch.get_autocast_dtype("cuda")
else:
dtype = linear_op.weight.dtype
# Linear forward # Linear forward
output, x_local, w = BasicLinear._functional_forward( output, x_local, w = BasicLinear._functional_forward(
...@@ -133,7 +135,9 @@ class ForwardLinearBiasActivation(FusedOperation): ...@@ -133,7 +135,9 @@ class ForwardLinearBiasActivation(FusedOperation):
linear_op_ctx.dtype = dtype linear_op_ctx.dtype = dtype
linear_op_ctx.input_requires_grad = input_requires_grad linear_op_ctx.input_requires_grad = input_requires_grad
linear_op_ctx.weight_requires_grad = weight_requires_grad linear_op_ctx.weight_requires_grad = weight_requires_grad
linear_op_ctx.has_prev_op = not is_first_op if bias_op is not None:
bias_op_ctx.with_quantized_compute = with_quantized_compute
bias_op_ctx.grad_input_quantizer = linear_op.get_grad_input_quantizer()
return output, [() for _ in range(len(self.basic_ops))] return output, [() for _ in range(len(self.basic_ops))]
......
...@@ -59,7 +59,6 @@ class ForwardLinearBiasAdd(FusedOperation): ...@@ -59,7 +59,6 @@ class ForwardLinearBiasAdd(FusedOperation):
basic_op_extra_inputs: list[tuple[torch.Tensor, ...]], basic_op_extra_inputs: list[tuple[torch.Tensor, ...]],
prev_op_grad_input_quantizer: Optional[Quantizer], prev_op_grad_input_quantizer: Optional[Quantizer],
next_op_input_quantizer: Optional[Quantizer], next_op_input_quantizer: Optional[Quantizer],
is_first_op: bool,
basic_op_kwargs: list[dict[str, Any]], basic_op_kwargs: list[dict[str, Any]],
) -> tuple[torch.Tensor, Iterable[Iterable[torch.Tensor]]]: ) -> tuple[torch.Tensor, Iterable[Iterable[torch.Tensor]]]:
...@@ -69,10 +68,12 @@ class ForwardLinearBiasAdd(FusedOperation): ...@@ -69,10 +68,12 @@ class ForwardLinearBiasAdd(FusedOperation):
linear_op_ctx = basic_op_ctxs[idx] linear_op_ctx = basic_op_ctxs[idx]
if self._op_idxs["bias"] is None: if self._op_idxs["bias"] is None:
bias_op = None bias_op = None
bias_op_ctx = None
bias = None bias = None
else: else:
idx = self._op_idxs["bias"] idx = self._op_idxs["bias"]
bias_op = self.basic_ops[idx] bias_op = self.basic_ops[idx]
bias_op_ctx = basic_op_ctxs[idx]
bias = bias_op.bias bias = bias_op.bias
if basic_op_kwargs[idx]: if basic_op_kwargs[idx]:
raise ValueError("Bias operation forward does not expect keyword arguments") raise ValueError("Bias operation forward does not expect keyword arguments")
...@@ -95,9 +96,10 @@ class ForwardLinearBiasAdd(FusedOperation): ...@@ -95,9 +96,10 @@ class ForwardLinearBiasAdd(FusedOperation):
grad_input_quantizer = prev_op_grad_input_quantizer grad_input_quantizer = prev_op_grad_input_quantizer
# Get autocast dtype if needed # Get autocast dtype if needed
dtype = None
if torch.is_autocast_enabled(): if torch.is_autocast_enabled():
dtype = torch.get_autocast_dtype("cuda") dtype = torch.get_autocast_dtype("cuda")
else:
dtype = linear_op.weight.dtype
# Linear forward # Linear forward
output = basic_op_extra_inputs[self._op_idxs["add"]][0] output = basic_op_extra_inputs[self._op_idxs["add"]][0]
...@@ -105,6 +107,7 @@ class ForwardLinearBiasAdd(FusedOperation): ...@@ -105,6 +107,7 @@ class ForwardLinearBiasAdd(FusedOperation):
input=input_, input=input_,
weight=linear_op.weight, weight=linear_op.weight,
bias=bias, bias=bias,
dtype=output.dtype,
out=output, out=output,
accumulate_into_out=True, accumulate_into_out=True,
tensor_parallel_mode=linear_op.tensor_parallel_mode, tensor_parallel_mode=linear_op.tensor_parallel_mode,
...@@ -128,7 +131,9 @@ class ForwardLinearBiasAdd(FusedOperation): ...@@ -128,7 +131,9 @@ class ForwardLinearBiasAdd(FusedOperation):
linear_op_ctx.dtype = dtype linear_op_ctx.dtype = dtype
linear_op_ctx.input_requires_grad = input_requires_grad linear_op_ctx.input_requires_grad = input_requires_grad
linear_op_ctx.weight_requires_grad = weight_requires_grad linear_op_ctx.weight_requires_grad = weight_requires_grad
linear_op_ctx.has_prev_op = not is_first_op if bias_op is not None:
bias_op_ctx.with_quantized_compute = with_quantized_compute
bias_op_ctx.grad_input_quantizer = linear_op.get_grad_input_quantizer()
return output, [() for _ in range(len(self.basic_ops))] return output, [() for _ in range(len(self.basic_ops))]
......
...@@ -547,7 +547,6 @@ class UserbuffersBackwardLinear(FusedOperation): ...@@ -547,7 +547,6 @@ class UserbuffersBackwardLinear(FusedOperation):
grad_bias = extra_outputs["grad_bias"] grad_bias = extra_outputs["grad_bias"]
# Clear input tensor if possible # Clear input tensor if possible
if linear_op_ctx.has_prev_op:
clear_tensor_data(x_local) clear_tensor_data(x_local)
# Return gradients # Return gradients
...@@ -569,13 +568,13 @@ def fuse_userbuffers_backward_linear( ...@@ -569,13 +568,13 @@ def fuse_userbuffers_backward_linear(
Parameters Parameters
---------- ----------
ops: list of tuples ops: list of tuples
Forward pass operations and the indices of the corresponding Backward pass operations and the indices of the corresponding
basic operations. basic operations.
Returns Returns
------- -------
ops: list of tuples ops: list of tuples
Updated forward pass operations Updated backward pass operations
""" """
......
...@@ -23,7 +23,6 @@ from ...module.base import ( ...@@ -23,7 +23,6 @@ from ...module.base import (
from ...tensor.quantized_tensor import Quantizer from ...tensor.quantized_tensor import Quantizer
from ...tensor.float8_tensor import Float8Quantizer, Float8CurrentScalingQuantizer from ...tensor.float8_tensor import Float8Quantizer, Float8CurrentScalingQuantizer
from ...tensor._internal.float8_tensor_base import Float8TensorBase from ...tensor._internal.float8_tensor_base import Float8TensorBase
from ...utils import canonicalize_device, canonicalize_dtype
from .._common import maybe_dequantize, is_quantized_tensor from .._common import maybe_dequantize, is_quantized_tensor
from ..basic import BasicLinear, Bias, ReduceScatter from ..basic import BasicLinear, Bias, ReduceScatter
from ..op import ( from ..op import (
...@@ -88,8 +87,8 @@ class UserbuffersForwardLinear(FusedOperation): ...@@ -88,8 +87,8 @@ class UserbuffersForwardLinear(FusedOperation):
weight: torch.Tensor, weight: torch.Tensor,
*, *,
bias: Optional[torch.Tensor] = None, bias: Optional[torch.Tensor] = None,
device: Optional[torch.device] = None, device: torch.device,
dtype: Optional[torch.dtype] = None, dtype: torch.dtype,
tensor_parallel_mode: Optional[str] = None, tensor_parallel_mode: Optional[str] = None,
tensor_parallel_group: Optional[torch.distributed.ProcessGroup] = None, tensor_parallel_group: Optional[torch.distributed.ProcessGroup] = None,
tensor_parallel_size: Optional[int] = None, tensor_parallel_size: Optional[int] = None,
...@@ -112,9 +111,9 @@ class UserbuffersForwardLinear(FusedOperation): ...@@ -112,9 +111,9 @@ class UserbuffersForwardLinear(FusedOperation):
Weight tensor Weight tensor
bias: torch.Tensor, optional bias: torch.Tensor, optional
Bias tensor Bias tensor
device: torch.device, default = default CUDA device device: torch.device
Tensor device Tensor device
dtype: torch.dtype, default = default dtype dtype: torch.dtype
Tensor datatype Tensor datatype
tensor_parallel_mode: {`None`, "column", "row"}, default = `None` tensor_parallel_mode: {`None`, "column", "row"}, default = `None`
Mode for tensor parallelism Mode for tensor parallelism
...@@ -156,16 +155,10 @@ class UserbuffersForwardLinear(FusedOperation): ...@@ -156,16 +155,10 @@ class UserbuffersForwardLinear(FusedOperation):
""" """
# Check device # Check device
if device is None:
device = weight.device
device = canonicalize_device(device)
if device.type != "cuda": if device.type != "cuda":
raise ValueError(f"Only CUDA devices are supported (got {device})") raise ValueError(f"Only CUDA devices are supported (got {device})")
# Check datatype # Check datatype
if dtype is None:
dtype = weight.dtype
dtype = canonicalize_dtype(dtype)
if dtype not in (torch.float32, torch.float16, torch.bfloat16): if dtype not in (torch.float32, torch.float16, torch.bfloat16):
raise ValueError(f"Supported dtypes are float32, float16, bfloat16 (got {dtype})") raise ValueError(f"Supported dtypes are float32, float16, bfloat16 (got {dtype})")
...@@ -291,7 +284,6 @@ class UserbuffersForwardLinear(FusedOperation): ...@@ -291,7 +284,6 @@ class UserbuffersForwardLinear(FusedOperation):
basic_op_extra_inputs: list[tuple[torch.Tensor, ...]], basic_op_extra_inputs: list[tuple[torch.Tensor, ...]],
prev_op_grad_input_quantizer: Optional[Quantizer], prev_op_grad_input_quantizer: Optional[Quantizer],
next_op_input_quantizer: Optional[Quantizer], next_op_input_quantizer: Optional[Quantizer],
is_first_op: bool,
basic_op_kwargs: list[dict[str, Any]], basic_op_kwargs: list[dict[str, Any]],
) -> tuple[torch.Tensor, Iterable[Iterable[torch.Tensor]]]: ) -> tuple[torch.Tensor, Iterable[Iterable[torch.Tensor]]]:
...@@ -300,10 +292,12 @@ class UserbuffersForwardLinear(FusedOperation): ...@@ -300,10 +292,12 @@ class UserbuffersForwardLinear(FusedOperation):
linear_op = self.basic_ops[idx] linear_op = self.basic_ops[idx]
linear_op_ctx = basic_op_ctxs[idx] linear_op_ctx = basic_op_ctxs[idx]
bias_op = None bias_op = None
bias_op_ctx = None
bias = None bias = None
if self._op_idxs["bias"] is not None: if self._op_idxs["bias"] is not None:
idx = self._op_idxs["bias"] idx = self._op_idxs["bias"]
bias_op = self.basic_ops[idx] bias_op = self.basic_ops[idx]
bias_op_ctx = basic_op_ctxs[idx]
bias = bias_op.bias bias = bias_op.bias
if basic_op_kwargs[idx]: if basic_op_kwargs[idx]:
raise ValueError("Bias operation forward does not expect keyword arguments") raise ValueError("Bias operation forward does not expect keyword arguments")
...@@ -330,9 +324,10 @@ class UserbuffersForwardLinear(FusedOperation): ...@@ -330,9 +324,10 @@ class UserbuffersForwardLinear(FusedOperation):
grad_input_quantizer = prev_op_grad_input_quantizer grad_input_quantizer = prev_op_grad_input_quantizer
# Get autocast dtype if needed # Get autocast dtype if needed
dtype = None
if torch.is_autocast_enabled(): if torch.is_autocast_enabled():
dtype = torch.get_autocast_dtype("cuda") dtype = torch.get_autocast_dtype("cuda")
else:
dtype = linear_op.weight.dtype
# Userbuffers options # Userbuffers options
if linear_op._userbuffers_options is None: if linear_op._userbuffers_options is None:
...@@ -344,6 +339,7 @@ class UserbuffersForwardLinear(FusedOperation): ...@@ -344,6 +339,7 @@ class UserbuffersForwardLinear(FusedOperation):
weight=linear_op.weight, weight=linear_op.weight,
bias=bias, bias=bias,
dtype=dtype, dtype=dtype,
device=linear_op.weight.device,
tensor_parallel_mode=self.tensor_parallel_mode, tensor_parallel_mode=self.tensor_parallel_mode,
tensor_parallel_group=self.tensor_parallel_group, tensor_parallel_group=self.tensor_parallel_group,
tensor_parallel_size=self.tensor_parallel_size, tensor_parallel_size=self.tensor_parallel_size,
...@@ -370,7 +366,9 @@ class UserbuffersForwardLinear(FusedOperation): ...@@ -370,7 +366,9 @@ class UserbuffersForwardLinear(FusedOperation):
linear_op_ctx.input_dims = input_.size() linear_op_ctx.input_dims = input_.size()
linear_op_ctx.input_requires_grad = input_requires_grad linear_op_ctx.input_requires_grad = input_requires_grad
linear_op_ctx.weight_requires_grad = weight_requires_grad linear_op_ctx.weight_requires_grad = weight_requires_grad
linear_op_ctx.has_prev_op = not is_first_op if bias_op is not None:
bias_op_ctx.with_quantized_compute = with_quantized_compute
bias_op_ctx.grad_input_quantizer = linear_op.get_grad_input_quantizer()
return output, [() for _ in range(len(self.basic_ops))] return output, [() for _ in range(len(self.basic_ops))]
......
...@@ -10,13 +10,14 @@ from typing import Any, Optional ...@@ -10,13 +10,14 @@ from typing import Any, Optional
import torch import torch
from transformer_engine.pytorch.fp8 import FP8GlobalStateManager from transformer_engine.pytorch.fp8 import FP8GlobalStateManager, Recipe
from transformer_engine.pytorch.ops.op import ( from transformer_engine.pytorch.ops.op import (
BasicOperation, BasicOperation,
FusibleOperation, FusibleOperation,
OperationContext, OperationContext,
) )
from transformer_engine.pytorch.ops.fused import ( from transformer_engine.pytorch.ops.fused import (
fuse_backward_bias_activation,
fuse_backward_linear_add, fuse_backward_linear_add,
fuse_forward_linear_bias_activation, fuse_forward_linear_bias_activation,
fuse_forward_linear_bias_add, fuse_forward_linear_bias_add,
...@@ -100,6 +101,10 @@ class _OperationFuserAutogradFunction(torch.autograd.Function): ...@@ -100,6 +101,10 @@ class _OperationFuserAutogradFunction(torch.autograd.Function):
# Operation autograd contexts # Operation autograd contexts
basic_op_ctxs = [OperationContext() for _ in range(fuser._num_basic_ops)] basic_op_ctxs = [OperationContext() for _ in range(fuser._num_basic_ops)]
# Mark input tensors as not deletable in backward
for tensor in (input_,) + params_and_extra_inputs:
tensor.do_not_clear = True
# Unflatten list of parameters and extra tensor inputs # Unflatten list of parameters and extra tensor inputs
extra_inputs = params_and_extra_inputs[-fuser._num_extra_inputs :] extra_inputs = params_and_extra_inputs[-fuser._num_extra_inputs :]
basic_op_extra_inputs = [] basic_op_extra_inputs = []
...@@ -110,6 +115,7 @@ class _OperationFuserAutogradFunction(torch.autograd.Function): ...@@ -110,6 +115,7 @@ class _OperationFuserAutogradFunction(torch.autograd.Function):
# Apply forward ops # Apply forward ops
x = input_ x = input_
requires_grad = is_grad_enabled and x.requires_grad requires_grad = is_grad_enabled and x.requires_grad
with_quantized_compute = FP8GlobalStateManager.is_fp8_enabled()
extra_outputs = [None] * fuser._num_basic_ops extra_outputs = [None] * fuser._num_basic_ops
for op, basic_op_idxs in fuser._forward_ops: for op, basic_op_idxs in fuser._forward_ops:
...@@ -125,16 +131,15 @@ class _OperationFuserAutogradFunction(torch.autograd.Function): ...@@ -125,16 +131,15 @@ class _OperationFuserAutogradFunction(torch.autograd.Function):
# Forward op # Forward op
extra_inputs = [basic_op_extra_inputs[idx] for idx in basic_op_idxs] extra_inputs = [basic_op_extra_inputs[idx] for idx in basic_op_idxs]
prev_op_idx = basic_op_idxs[0] - 1 prev_op_idx = basic_op_idxs[0] - 1
prev_op = fuser._basic_ops[prev_op_idx] if prev_op_idx > 0 else None prev_op = fuser._basic_ops[prev_op_idx] if prev_op_idx >= 0 else None
prev_op_grad_input_quantizer = None prev_op_grad_input_quantizer = None
if prev_op is not None: if prev_op is not None and with_quantized_compute:
prev_op_grad_input_quantizer = prev_op.get_grad_input_quantizer() prev_op_grad_input_quantizer = prev_op.get_grad_input_quantizer()
next_op_idx = basic_op_idxs[-1] + 1 next_op_idx = basic_op_idxs[-1] + 1
next_op = fuser._basic_ops[next_op_idx] if next_op_idx < fuser._num_basic_ops else None next_op = fuser._basic_ops[next_op_idx] if next_op_idx < fuser._num_basic_ops else None
next_op_input_quantizer = None next_op_input_quantizer = None
if next_op is not None: if next_op is not None and with_quantized_compute:
next_op_input_quantizer = next_op.get_input_quantizer() next_op_input_quantizer = next_op.get_input_quantizer()
is_first_op = prev_op is None
x, fused_op_extra_outputs = op.fuser_forward( x, fused_op_extra_outputs = op.fuser_forward(
[basic_op_ctxs[idx] for idx in basic_op_idxs], [basic_op_ctxs[idx] for idx in basic_op_idxs],
...@@ -142,7 +147,6 @@ class _OperationFuserAutogradFunction(torch.autograd.Function): ...@@ -142,7 +147,6 @@ class _OperationFuserAutogradFunction(torch.autograd.Function):
basic_op_extra_inputs=extra_inputs, basic_op_extra_inputs=extra_inputs,
prev_op_grad_input_quantizer=prev_op_grad_input_quantizer, prev_op_grad_input_quantizer=prev_op_grad_input_quantizer,
next_op_input_quantizer=next_op_input_quantizer, next_op_input_quantizer=next_op_input_quantizer,
is_first_op=is_first_op,
basic_op_kwargs=[basic_op_kwargs[idx] for idx in basic_op_idxs], basic_op_kwargs=[basic_op_kwargs[idx] for idx in basic_op_idxs],
) )
for idx, ys in zip(basic_op_idxs, fused_op_extra_outputs): for idx, ys in zip(basic_op_idxs, fused_op_extra_outputs):
...@@ -177,7 +181,6 @@ class _OperationFuserAutogradFunction(torch.autograd.Function): ...@@ -177,7 +181,6 @@ class _OperationFuserAutogradFunction(torch.autograd.Function):
ctx._saved_tensors_range = (range_start, range_end) ctx._saved_tensors_range = (range_start, range_end)
# Save tensors for backward # Save tensors for backward
with_quantized_compute = FP8GlobalStateManager.is_fp8_enabled()
if with_quantized_compute: if with_quantized_compute:
tensors_to_save, tensor_objects = prepare_for_saving(*to_save) tensors_to_save, tensor_objects = prepare_for_saving(*to_save)
func_ctx.save_for_backward(*tensors_to_save) func_ctx.save_for_backward(*tensors_to_save)
...@@ -195,11 +198,11 @@ class _OperationFuserAutogradFunction(torch.autograd.Function): ...@@ -195,11 +198,11 @@ class _OperationFuserAutogradFunction(torch.autograd.Function):
func_ctx.is_first_module = FP8GlobalStateManager.is_first_fp8_module() func_ctx.is_first_module = FP8GlobalStateManager.is_first_fp8_module()
func_ctx.with_quantized_compute = with_quantized_compute func_ctx.with_quantized_compute = with_quantized_compute
x.requires_grad_(requires_grad)
if extra_outputs_flat: if extra_outputs_flat:
return x, *extra_outputs_flat return x, *extra_outputs_flat
x.requires_grad_(requires_grad)
return x return x
@staticmethod @staticmethod
...@@ -314,15 +317,19 @@ class OperationFuser: ...@@ -314,15 +317,19 @@ class OperationFuser:
---------- ----------
ops: list of FusibleOperation ops: list of FusibleOperation
Pipeline of operations Pipeline of operations
fuse_ops: bool, default = `True` fuse_ops: bool
Whether to attempt fusing operations Whether to attempt fusing operations
recipe: Recipe, optional
Quantization recipe to use when fusing and executing operations.
Note: certain fusions may depend on what kind of recipe is being used.
""" """
def __init__( def __init__(
self, self,
ops: list[FusibleOperation], ops: list[FusibleOperation],
fuse_ops: bool = True, fuse_ops: bool,
recipe: Optional[Recipe],
) -> None: ) -> None:
# Get list of basic operations # Get list of basic operations
...@@ -348,6 +355,7 @@ class OperationFuser: ...@@ -348,6 +355,7 @@ class OperationFuser:
self._is_first_forward = True self._is_first_forward = True
# Fuse ops if needed # Fuse ops if needed
self.recipe = recipe
if fuse_ops: if fuse_ops:
self.fuse_ops() self.fuse_ops()
...@@ -359,6 +367,7 @@ class OperationFuser: ...@@ -359,6 +367,7 @@ class OperationFuser:
def _fuse_forward_ops( def _fuse_forward_ops(
cls, cls,
ops: list[tuple[FusibleOperation, list[int]]], ops: list[tuple[FusibleOperation, list[int]]],
recipe: Optional[Recipe], # pylint: disable=unused-argument
) -> list[tuple[FusibleOperation, list[int]]]: ) -> list[tuple[FusibleOperation, list[int]]]:
"""Attempt to fuse operations in forward pass""" """Attempt to fuse operations in forward pass"""
ops = fuse_userbuffers_forward_linear(ops) ops = fuse_userbuffers_forward_linear(ops)
...@@ -370,16 +379,18 @@ class OperationFuser: ...@@ -370,16 +379,18 @@ class OperationFuser:
def _fuse_backward_ops( def _fuse_backward_ops(
cls, cls,
ops: list[tuple[FusibleOperation, list[int]]], ops: list[tuple[FusibleOperation, list[int]]],
recipe: Optional[Recipe],
) -> list[tuple[FusibleOperation, list[int]]]: ) -> list[tuple[FusibleOperation, list[int]]]:
"""Attempt to fuse operations in backward pass""" """Attempt to fuse operations in backward pass"""
ops = fuse_userbuffers_backward_linear(ops) ops = fuse_userbuffers_backward_linear(ops)
ops = fuse_backward_linear_add(ops) ops = fuse_backward_linear_add(ops)
ops = fuse_backward_bias_activation(ops, recipe)
return ops return ops
def fuse_ops(self) -> None: def fuse_ops(self) -> None:
"""Attempt to fuse operations""" """Attempt to fuse operations"""
self._forward_ops = self._fuse_forward_ops(self._forward_ops) self._forward_ops = self._fuse_forward_ops(self._forward_ops, self.recipe)
self._backward_ops = self._fuse_backward_ops(self._backward_ops) self._backward_ops = self._fuse_backward_ops(self._backward_ops, self.recipe)
def __call__( def __call__(
self, self,
...@@ -395,10 +406,8 @@ class OperationFuser: ...@@ -395,10 +406,8 @@ class OperationFuser:
# Initialization before forward pass # Initialization before forward pass
if self._is_first_forward: if self._is_first_forward:
with_quantized_compute = FP8GlobalStateManager.is_fp8_enabled()
recipe = FP8GlobalStateManager.get_fp8_recipe() if with_quantized_compute else None
for op in self._basic_ops: for op in self._basic_ops:
op.pre_first_forward(recipe=recipe) op.pre_first_forward(recipe=self.recipe)
self._is_first_forward = False self._is_first_forward = False
# Canonicalize op kwargs # Canonicalize op kwargs
......
...@@ -86,7 +86,6 @@ class FusibleOperation(torch.nn.Module, metaclass=abc.ABCMeta): ...@@ -86,7 +86,6 @@ class FusibleOperation(torch.nn.Module, metaclass=abc.ABCMeta):
basic_op_extra_inputs: list[tuple[torch.Tensor, ...]], basic_op_extra_inputs: list[tuple[torch.Tensor, ...]],
prev_op_grad_input_quantizer: Optional[Quantizer], prev_op_grad_input_quantizer: Optional[Quantizer],
next_op_input_quantizer: Optional[Quantizer], next_op_input_quantizer: Optional[Quantizer],
is_first_op: bool,
basic_op_kwargs: list[dict[str, Any]], basic_op_kwargs: list[dict[str, Any]],
) -> tuple[torch.Tensor, Iterable[Iterable[torch.Tensor]]]: ) -> tuple[torch.Tensor, Iterable[Iterable[torch.Tensor]]]:
"""Forward pass """Forward pass
...@@ -109,10 +108,6 @@ class FusibleOperation(torch.nn.Module, metaclass=abc.ABCMeta): ...@@ -109,10 +108,6 @@ class FusibleOperation(torch.nn.Module, metaclass=abc.ABCMeta):
The grad_input_quantizer of the preceeding operation The grad_input_quantizer of the preceeding operation
next_op_input_quantizer: Quantizer, optional next_op_input_quantizer: Quantizer, optional
The input_quantizer of the following operation The input_quantizer of the following operation
is_first_op: bool
Does this op have a preceeding op or is it the first one in the
fuser. Used in the backward pass to safely delete the saved input
tensor when no longer needed and there is a preceeding op.
basic_op_kwargs: list of dict basic_op_kwargs: list of dict
Keyword arguments to forward functions of basic Keyword arguments to forward functions of basic
operations. operations.
...@@ -421,7 +416,6 @@ class BasicOperation(FusibleOperation, metaclass=abc.ABCMeta): ...@@ -421,7 +416,6 @@ class BasicOperation(FusibleOperation, metaclass=abc.ABCMeta):
*, *,
prev_op_grad_input_quantizer: Optional[Quantizer], prev_op_grad_input_quantizer: Optional[Quantizer],
next_op_input_quantizer: Optional[Quantizer], next_op_input_quantizer: Optional[Quantizer],
is_first_op: bool,
**kwargs: Any, **kwargs: Any,
) -> torch.Tensor: ) -> torch.Tensor:
"""Forward pass """Forward pass
...@@ -436,10 +430,6 @@ class BasicOperation(FusibleOperation, metaclass=abc.ABCMeta): ...@@ -436,10 +430,6 @@ class BasicOperation(FusibleOperation, metaclass=abc.ABCMeta):
The grad_input_quantizer of the preceeding operation The grad_input_quantizer of the preceeding operation
next_op_input_quantizer: Quantizer, optional next_op_input_quantizer: Quantizer, optional
The input_quantizer of the following operation The input_quantizer of the following operation
is_first_op: bool
Does this op have a preceeding op or is it the first one in the
fuser. Used in the backward pass to safely delete the saved input
tensor when no longer needed and there is a preceeding op.
Returns Returns
------- -------
...@@ -480,7 +470,6 @@ class BasicOperation(FusibleOperation, metaclass=abc.ABCMeta): ...@@ -480,7 +470,6 @@ class BasicOperation(FusibleOperation, metaclass=abc.ABCMeta):
basic_op_extra_inputs: list[tuple[torch.Tensor, ...]], basic_op_extra_inputs: list[tuple[torch.Tensor, ...]],
prev_op_grad_input_quantizer: Optional[Quantizer], prev_op_grad_input_quantizer: Optional[Quantizer],
next_op_input_quantizer: Optional[Quantizer], next_op_input_quantizer: Optional[Quantizer],
is_first_op: bool,
basic_op_kwargs: list[dict[str, Any]], basic_op_kwargs: list[dict[str, Any]],
) -> tuple[torch.Tensor, list[tuple[()]]]: ) -> tuple[torch.Tensor, list[tuple[()]]]:
if self.num_extra_inputs > 0 or self.num_extra_outputs > 0: if self.num_extra_inputs > 0 or self.num_extra_outputs > 0:
...@@ -495,7 +484,6 @@ class BasicOperation(FusibleOperation, metaclass=abc.ABCMeta): ...@@ -495,7 +484,6 @@ class BasicOperation(FusibleOperation, metaclass=abc.ABCMeta):
input_, input_,
prev_op_grad_input_quantizer=prev_op_grad_input_quantizer, prev_op_grad_input_quantizer=prev_op_grad_input_quantizer,
next_op_input_quantizer=next_op_input_quantizer, next_op_input_quantizer=next_op_input_quantizer,
is_first_op=is_first_op,
**basic_op_kwargs[0], **basic_op_kwargs[0],
) )
return output, [()] return output, [()]
...@@ -530,7 +518,9 @@ class BasicOperation(FusibleOperation, metaclass=abc.ABCMeta): ...@@ -530,7 +518,9 @@ class BasicOperation(FusibleOperation, metaclass=abc.ABCMeta):
"""Apply operation""" """Apply operation"""
from .fuser import OperationFuser from .fuser import OperationFuser
return OperationFuser([self], fuse_ops=False)( with_quantized_compute = FP8GlobalStateManager.is_fp8_enabled()
recipe = FP8GlobalStateManager.get_fp8_recipe() if with_quantized_compute else None
return OperationFuser([self], fuse_ops=False, recipe=recipe)(
input, input,
*extra_inputs, *extra_inputs,
basic_op_kwargs=[kwargs], basic_op_kwargs=[kwargs],
...@@ -737,7 +727,9 @@ class FusedOperation(FusibleOperation): ...@@ -737,7 +727,9 @@ class FusedOperation(FusibleOperation):
basic_op_kwargs = [{} for _ in range(len(self.basic_ops))] basic_op_kwargs = [{} for _ in range(len(self.basic_ops))]
from .fuser import OperationFuser from .fuser import OperationFuser
return OperationFuser([self], fuse_ops=False)( with_quantized_compute = FP8GlobalStateManager.is_fp8_enabled()
recipe = FP8GlobalStateManager.get_fp8_recipe() if with_quantized_compute else None
return OperationFuser([self], fuse_ops=False, recipe=recipe)(
input, input,
*extra_inputs, *extra_inputs,
basic_op_kwargs=basic_op_kwargs, basic_op_kwargs=basic_op_kwargs,
......
...@@ -10,7 +10,7 @@ from typing import Optional ...@@ -10,7 +10,7 @@ from typing import Optional
import torch import torch
from transformer_engine.pytorch.fp8 import FP8GlobalStateManager from transformer_engine.pytorch.fp8 import FP8GlobalStateManager, Recipe
from transformer_engine.pytorch.ops.op import FusibleOperation from transformer_engine.pytorch.ops.op import FusibleOperation
from transformer_engine.pytorch.ops.fuser import OperationFuser from transformer_engine.pytorch.ops.fuser import OperationFuser
...@@ -147,6 +147,7 @@ class Sequential(torch.nn.Module): ...@@ -147,6 +147,7 @@ class Sequential(torch.nn.Module):
def _make_module_groups( def _make_module_groups(
cls, cls,
modules: Iterable[torch.nn.Module], modules: Iterable[torch.nn.Module],
recipe: Optional[Recipe],
) -> list[OperationFuser | torch.nn.Module]: ) -> list[OperationFuser | torch.nn.Module]:
"""Make list of modules, with fusible operations grouped together""" """Make list of modules, with fusible operations grouped together"""
...@@ -161,7 +162,7 @@ class Sequential(torch.nn.Module): ...@@ -161,7 +162,7 @@ class Sequential(torch.nn.Module):
groups.append(module) groups.append(module)
for idx, group in enumerate(groups): for idx, group in enumerate(groups):
if isinstance(group, list): if isinstance(group, list):
groups[idx] = OperationFuser(group, fuse_ops=True) groups[idx] = OperationFuser(group, fuse_ops=True, recipe=recipe)
# Check if operations expect extra input or output tensors # Check if operations expect extra input or output tensors
# Note: If any op has extra inputs or outputs, then the entire # Note: If any op has extra inputs or outputs, then the entire
...@@ -190,9 +191,9 @@ class Sequential(torch.nn.Module): ...@@ -190,9 +191,9 @@ class Sequential(torch.nn.Module):
"""Forward pass""" """Forward pass"""
# Get current global state # Get current global state
fp8_enabled = FP8GlobalStateManager.is_fp8_enabled() with_quantized_compute = FP8GlobalStateManager.is_fp8_enabled()
fp8_recipe = FP8GlobalStateManager.get_fp8_recipe() if fp8_enabled else None recipe = FP8GlobalStateManager.get_fp8_recipe() if with_quantized_compute else None
global_state = (fp8_enabled, type(fp8_recipe)) global_state = (with_quantized_compute, type(recipe))
# Reset module groups is global state changed # Reset module groups is global state changed
if self._last_global_state != global_state: if self._last_global_state != global_state:
...@@ -201,7 +202,7 @@ class Sequential(torch.nn.Module): ...@@ -201,7 +202,7 @@ class Sequential(torch.nn.Module):
# Create module groups if needed # Create module groups if needed
if self._module_groups is None: if self._module_groups is None:
self._module_groups = self._make_module_groups(self._modules.values()) self._module_groups = self._make_module_groups(self._modules.values(), recipe)
# Forward pass for each module group # Forward pass for each module group
x = input x = input
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment