".github/vscode:/vscode.git/clone" did not exist on "085bd66f5f2693be72de0ff284037ef36f9692b5"
Unverified commit e0204fbb authored by Jan Bielak, committed by GitHub

Refactor `te.ops` (#1951)



* Refactor _OperationFuserAutogradFunction.forward to use fewer parameters
Signed-off-by: Jan Bielak <jbielak@nvidia.com>
(cherry picked from commit f8f59b1bb184e89468058521df4cfff029ad909c)

* Rename `BackwardBiasActivation` to `BackwardActivationBias`
Signed-off-by: Jan Bielak <jbielak@nvidia.com>
(cherry picked from commit 397c58fc296f801fe4ad600aadc2daff3b78be45)

* Use forward operation order in backward fused operations
Signed-off-by: Jan Bielak <jbielak@nvidia.com>
(cherry picked from commit 2d37a9385069b066e6cdeff3eb9173c2079cb791)

* Rename `prev_op_grad_input_quantizer` to `prev_op_grad_output_quantizer`
Signed-off-by: Jan Bielak <jbielak@nvidia.com>
(cherry picked from commit d7ab5dfb23e216866f7f4fc4d7a99f625d329f1e)

* Make OperationFuser persistent
Signed-off-by: Jan Bielak <jbielak@nvidia.com>
(cherry picked from commit 77984d9715d31e87519dc6ea1e02c483a81355a7)

* Distribute extra inputs to and collect extra outputs from multiple module groups in Sequential
Signed-off-by: Jan Bielak <jbielak@nvidia.com>
(cherry picked from commit 0716aaad542e59f2c1ac4620167965a0334bbf71)

* Take requires_grad into account when fusing operations
Signed-off-by: Jan Bielak <jbielak@nvidia.com>

* Change get_quantizer to return None if no quantization recipe is used
Signed-off-by: Jan Bielak <jbielak@nvidia.com>

* Refactor pre_first_forward
Signed-off-by: Jan Bielak <jbielak@nvidia.com>

* Fix for failing `test_make_graphed_callables[fp8_recipe0-*-True-*-linear_op]`
Signed-off-by: Jan Bielak <jbielak@nvidia.com>

* Fix linting errors
Signed-off-by: Jan Bielak <jbielak@nvidia.com>

* Apply suggestions from code review
Co-authored-by: Tim Moon <4406448+timmoon10@users.noreply.github.com>
Signed-off-by: Jan Bielak <jbielak@nvidia.com>

* Fix fp8 meta tensors in CUDA Graph capture
Signed-off-by: Jan Bielak <jbielak@nvidia.com>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* Fix failing distributed userbuffers tests
Signed-off-by: Jan Bielak <jbielak@nvidia.com>

---------

Signed-off-by: Jan Bielak <jbielak@nvidia.com>
Co-authored-by: Tim Moon <4406448+timmoon10@users.noreply.github.com>
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
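For orientation, here is a rough sketch of the `te.ops` usage pattern this refactor targets. It is illustrative only: `Sequential`, `BasicLinear`, `Bias`, and `ReLU` are inferred from identifiers appearing in this diff, and the constructor signatures are assumptions that may differ between versions.

    # Illustrative sketch only; op names and signatures are assumptions.
    import torch
    import transformer_engine.pytorch as te
    from transformer_engine.pytorch import ops as te_ops

    # A pipeline of fusible basic ops; the OperationFuser may fuse the
    # matmul + bias + activation into a single ForwardLinearBiasActivation op.
    model = te_ops.Sequential(
        te_ops.BasicLinear(1024, 1024),  # matmul only
        te_ops.Bias(1024),               # bias as a separate fusible op
        te_ops.ReLU(),
    )

    x = torch.randn(32, 1024, device="cuda", requires_grad=True)
    with te.fp8_autocast(enabled=True):  # quantized compute path
        y = model(x)
    y.sum().backward()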
parent cb504cda
@@ -59,7 +59,7 @@ class ForwardLinearBiasActivation(FusedOperation):
         input_: torch.Tensor,
         *,
         basic_op_extra_inputs: list[tuple[torch.Tensor, ...]],
-        prev_op_grad_input_quantizer: Optional[Quantizer],
+        prev_op_grad_output_quantizer: Optional[Quantizer],
         next_op_input_quantizer: Optional[Quantizer],
         basic_op_kwargs: list[dict[str, Any]],
     ) -> tuple[torch.Tensor, Iterable[Iterable[torch.Tensor]]]:
@@ -89,18 +89,12 @@ class ForwardLinearBiasActivation(FusedOperation):
         weight_requires_grad = linear_op_ctx.requires_grad and linear_op.weight.requires_grad

         # FP8 metadata
+        input_quantizer = linear_op.get_quantizer("forward", 0)
+        weight_quantizer = linear_op.get_quantizer("forward", 1)
+        output_quantizer = next_op_input_quantizer
+        grad_output_quantizer = linear_op.get_quantizer("backward", 0)
+        grad_input_quantizer = prev_op_grad_output_quantizer
         with_quantized_compute = FP8GlobalStateManager.is_fp8_enabled()
-        input_quantizer = None
-        weight_quantizer = None
-        output_quantizer = None
-        grad_output_quantizer = None
-        grad_input_quantizer = None
-        if with_quantized_compute:
-            input_quantizer = linear_op.get_quantizer("forward", 0)
-            weight_quantizer = linear_op.get_quantizer("forward", 1)
-            output_quantizer = next_op_input_quantizer
-            grad_output_quantizer = linear_op.get_quantizer("backward", 0)
-            grad_input_quantizer = prev_op_grad_input_quantizer

         # Get autocast dtype if needed
         if torch.is_autocast_enabled():
@@ -126,18 +120,18 @@ class ForwardLinearBiasActivation(FusedOperation):
         )

         # Save state for backward pass
-        linear_op_ctx.save_for_backward(x_local, w)
-        linear_op_ctx.with_quantized_compute = with_quantized_compute
-        linear_op_ctx.input_quantizer = input_quantizer
-        linear_op_ctx.weight_quantizer = weight_quantizer
-        linear_op_ctx.grad_output_quantizer = grad_output_quantizer
-        linear_op_ctx.grad_input_quantizer = grad_input_quantizer
-        linear_op_ctx.dtype = dtype
-        linear_op_ctx.input_requires_grad = input_requires_grad
-        linear_op_ctx.weight_requires_grad = weight_requires_grad
-        if bias_op is not None:
-            bias_op_ctx.with_quantized_compute = with_quantized_compute
-            bias_op_ctx.grad_input_quantizer = linear_op.get_grad_input_quantizer()
+        if linear_op_ctx.requires_grad:
+            linear_op_ctx.save_for_backward(x_local, w)
+            linear_op_ctx.with_quantized_compute = with_quantized_compute
+            linear_op_ctx.input_quantizer = input_quantizer
+            linear_op_ctx.weight_quantizer = weight_quantizer
+            linear_op_ctx.grad_output_quantizer = grad_output_quantizer
+            linear_op_ctx.grad_input_quantizer = grad_input_quantizer
+            linear_op_ctx.dtype = dtype
+            linear_op_ctx.input_requires_grad = input_requires_grad
+            linear_op_ctx.weight_requires_grad = weight_requires_grad
+        if bias_op is not None and bias_op_ctx.requires_grad:
+            bias_op_ctx.grad_input_quantizer = linear_op.get_grad_output_quantizer()

         return output, [() for _ in range(len(self.basic_ops))]
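The recurring pattern in this and the following hunks is that autograd context state is now saved only when the op actually needs a backward pass. A minimal standalone sketch of the same idea, using plain PyTorch rather than TE's classes:

    import torch

    def _toy_forward(func_ctx, x, w):
        out = x @ w.t()
        # Mirror `if linear_op_ctx.requires_grad:` above: skip the backward
        # bookkeeping entirely when nothing will be differentiated.
        if func_ctx is not None and (x.requires_grad or w.requires_grad):
            func_ctx.save_for_backward(x, w)
        return out

    class _ToyLinearFn(torch.autograd.Function):
        @staticmethod
        def forward(ctx, x, w):
            return _toy_forward(ctx, x, w)

        @staticmethod
        def backward(ctx, grad_out):
            x, w = ctx.saved_tensors
            return grad_out @ w, grad_out.t() @ x

    def toy_linear(x, w):
        # Like OperationFuser.__call__: only enter autograd when grad is on
        if torch.is_grad_enabled():
            return _ToyLinearFn.apply(x, w)
        return _toy_forward(None, x, w)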
@@ -57,7 +57,7 @@ class ForwardLinearBiasAdd(FusedOperation):
         input_: torch.Tensor,
         *,
         basic_op_extra_inputs: list[tuple[torch.Tensor, ...]],
-        prev_op_grad_input_quantizer: Optional[Quantizer],
+        prev_op_grad_output_quantizer: Optional[Quantizer],
         next_op_input_quantizer: Optional[Quantizer],
         basic_op_kwargs: list[dict[str, Any]],
     ) -> tuple[torch.Tensor, Iterable[Iterable[torch.Tensor]]]:
@@ -83,17 +83,12 @@ class ForwardLinearBiasAdd(FusedOperation):
         weight_requires_grad = linear_op_ctx.requires_grad and linear_op.weight.requires_grad

         # FP8 metadata
-        with_quantized_compute = FP8GlobalStateManager.is_fp8_enabled()
-        input_quantizer = None
-        weight_quantizer = None
+        input_quantizer = linear_op.get_quantizer("forward", 0)
+        weight_quantizer = linear_op.get_quantizer("forward", 1)
         output_quantizer = None
-        grad_output_quantizer = None
-        grad_input_quantizer = None
-        if with_quantized_compute:
-            input_quantizer = linear_op.get_quantizer("forward", 0)
-            weight_quantizer = linear_op.get_quantizer("forward", 1)
-            grad_output_quantizer = linear_op.get_quantizer("backward", 0)
-            grad_input_quantizer = prev_op_grad_input_quantizer
+        grad_output_quantizer = linear_op.get_quantizer("backward", 0)
+        grad_input_quantizer = prev_op_grad_output_quantizer
+        with_quantized_compute = FP8GlobalStateManager.is_fp8_enabled()

         # Get autocast dtype if needed
         if torch.is_autocast_enabled():
@@ -122,18 +117,18 @@ class ForwardLinearBiasAdd(FusedOperation):
         )

         # Save state for backward pass
-        linear_op_ctx.save_for_backward(x_local, w)
-        linear_op_ctx.with_quantized_compute = with_quantized_compute
-        linear_op_ctx.input_quantizer = input_quantizer
-        linear_op_ctx.weight_quantizer = weight_quantizer
-        linear_op_ctx.grad_output_quantizer = grad_output_quantizer
-        linear_op_ctx.grad_input_quantizer = grad_input_quantizer
-        linear_op_ctx.dtype = dtype
-        linear_op_ctx.input_requires_grad = input_requires_grad
-        linear_op_ctx.weight_requires_grad = weight_requires_grad
-        if bias_op is not None:
-            bias_op_ctx.with_quantized_compute = with_quantized_compute
-            bias_op_ctx.grad_input_quantizer = linear_op.get_grad_input_quantizer()
+        if linear_op_ctx.requires_grad:
+            linear_op_ctx.save_for_backward(x_local, w)
+            linear_op_ctx.with_quantized_compute = with_quantized_compute
+            linear_op_ctx.input_quantizer = input_quantizer
+            linear_op_ctx.weight_quantizer = weight_quantizer
+            linear_op_ctx.grad_output_quantizer = grad_output_quantizer
+            linear_op_ctx.grad_input_quantizer = grad_input_quantizer
+            linear_op_ctx.dtype = dtype
+            linear_op_ctx.input_requires_grad = input_requires_grad
+            linear_op_ctx.weight_requires_grad = weight_requires_grad
+        if bias_op is not None and bias_op_ctx.requires_grad:
+            bias_op_ctx.grad_input_quantizer = linear_op.get_grad_output_quantizer()

         return output, [() for _ in range(len(self.basic_ops))]
@@ -48,14 +48,14 @@ class UserbuffersBackwardLinear(FusedOperation):
         # Basic operations that comprise this fused operation
         op_idxs = {"linear": None, "bias": None, "reduce_scatter": None}
         ops = []
-        if reduce_scatter is not None:
-            op_idxs["reduce_scatter"] = len(ops)
-            ops.append(reduce_scatter)
+        op_idxs["linear"] = len(ops)
+        ops.append(linear)
         if bias is not None:
             op_idxs["bias"] = len(ops)
             ops.append(bias)
-        op_idxs["linear"] = len(ops)
-        ops.append(linear)
+        if reduce_scatter is not None:
+            op_idxs["reduce_scatter"] = len(ops)
+            ops.append(reduce_scatter)

         # Initialize base class
         super().__init__(ops)
@@ -495,7 +495,7 @@ class UserbuffersBackwardLinear(FusedOperation):
         # Get basic operations
         idx = self._op_idxs["linear"]
         linear_op = self.basic_ops[idx]
-        linear_op_ctx = basic_op_ctxs[idx]
+        linear_op_ctx = basic_op_ctxs[-1]
         bias_op = None
         if self._op_idxs["bias"] is not None:
             idx = self._op_idxs["bias"]
@@ -556,6 +556,7 @@ class UserbuffersBackwardLinear(FusedOperation):
         grad_params[self._op_idxs["linear"]] = (grad_weight,)
         if bias_op is not None:
             grad_params[self._op_idxs["bias"]] = (grad_bias,)
+        grad_params.reverse()
         grad_extra_inputs = [() for _ in range(len(self.basic_ops))]
         return grad_input, grad_params, grad_extra_inputs
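With the backward fused op now listing its basic ops in forward order (linear, then bias, then reduce-scatter), the per-op gradients it builds are indexed in forward order too, while the fuser consumes them in backward order; hence the new `grad_params.reverse()`. A toy illustration of the bookkeeping (not TE code):

    # Gradients accumulated by forward-order op index ...
    op_names = ["linear", "bias", "reduce_scatter"]
    grad_params = [("grad_weight",), ("grad_bias",), ()]

    # ... must be flipped to match the fuser's reversed (backward) op order.
    grad_params.reverse()
    assert grad_params == [(), ("grad_bias",), ("grad_weight",)]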
@@ -282,7 +282,7 @@ class UserbuffersForwardLinear(FusedOperation):
         input_: torch.Tensor,
         *,
         basic_op_extra_inputs: list[tuple[torch.Tensor, ...]],
-        prev_op_grad_input_quantizer: Optional[Quantizer],
+        prev_op_grad_output_quantizer: Optional[Quantizer],
         next_op_input_quantizer: Optional[Quantizer],
         basic_op_kwargs: list[dict[str, Any]],
     ) -> tuple[torch.Tensor, Iterable[Iterable[torch.Tensor]]]:
@@ -307,21 +307,17 @@ class UserbuffersForwardLinear(FusedOperation):
         weight_requires_grad = linear_op_ctx.requires_grad and linear_op.weight.requires_grad

         # Quantization metadata
+        input_quantizer = linear_op.get_quantizer("forward", 0)
+        weight_quantizer = linear_op.get_quantizer("forward", 1)
+        grad_output_quantizer = linear_op.get_quantizer("backward", 0)
+        grad_input_quantizer = prev_op_grad_output_quantizer
         with_quantized_compute = FP8GlobalStateManager.is_fp8_enabled()
-        input_quantizer = None
-        weight_quantizer = None
-        grad_output_quantizer = None
-        grad_input_quantizer = None
         if with_quantized_compute:
             recipe = FP8GlobalStateManager.get_fp8_recipe()
             if not any((recipe.delayed(), recipe.float8_current_scaling(), recipe.mxfp8())):
                 raise RuntimeError(
                     f"Unsupported recipe for Userbuffers ({recipe.__class__.__name__})"
                 )
-            input_quantizer = linear_op.get_quantizer("forward", 0)
-            weight_quantizer = linear_op.get_quantizer("forward", 1)
-            grad_output_quantizer = linear_op.get_quantizer("backward", 0)
-            grad_input_quantizer = prev_op_grad_input_quantizer

         # Get autocast dtype if needed
         if torch.is_autocast_enabled():
@@ -356,19 +352,19 @@ class UserbuffersForwardLinear(FusedOperation):
         w = extra_outputs["weight"]

         # Save state for backward pass
-        linear_op_ctx.save_for_backward(x_local, w)
-        linear_op_ctx.with_quantized_compute = with_quantized_compute
-        linear_op_ctx.input_quantizer = input_quantizer
-        linear_op_ctx.weight_quantizer = weight_quantizer
-        linear_op_ctx.grad_output_quantizer = grad_output_quantizer
-        linear_op_ctx.grad_input_quantizer = grad_input_quantizer
-        linear_op_ctx.dtype = dtype
-        linear_op_ctx.input_dims = input_.size()
-        linear_op_ctx.input_requires_grad = input_requires_grad
-        linear_op_ctx.weight_requires_grad = weight_requires_grad
-        if bias_op is not None:
-            bias_op_ctx.with_quantized_compute = with_quantized_compute
-            bias_op_ctx.grad_input_quantizer = linear_op.get_grad_input_quantizer()
+        if linear_op_ctx.requires_grad:
+            linear_op_ctx.save_for_backward(x_local, w)
+            linear_op_ctx.with_quantized_compute = with_quantized_compute
+            linear_op_ctx.input_quantizer = input_quantizer
+            linear_op_ctx.weight_quantizer = weight_quantizer
+            linear_op_ctx.grad_output_quantizer = grad_output_quantizer
+            linear_op_ctx.grad_input_quantizer = grad_input_quantizer
+            linear_op_ctx.dtype = dtype
+            linear_op_ctx.input_dims = input_.size()
+            linear_op_ctx.input_requires_grad = input_requires_grad
+            linear_op_ctx.weight_requires_grad = weight_requires_grad
+        if bias_op is not None and bias_op_ctx.requires_grad:
+            bias_op_ctx.grad_input_quantizer = linear_op.get_grad_output_quantizer()

         return output, [() for _ in range(len(self.basic_ops))]
@@ -5,19 +5,20 @@
 """Manager class for a pipeline of fusible operations."""

 from __future__ import annotations
-from collections.abc import Callable
+from collections.abc import Callable, Iterable
 from typing import Any, Optional
+import itertools

 import torch

-from transformer_engine.pytorch.fp8 import FP8GlobalStateManager, Recipe
+from transformer_engine.pytorch.fp8 import FP8GlobalStateManager, Recipe, DelayedScaling
 from transformer_engine.pytorch.ops.op import (
     BasicOperation,
     FusibleOperation,
     OperationContext,
 )
 from transformer_engine.pytorch.ops.fused import (
-    fuse_backward_bias_activation,
+    fuse_backward_activation_bias,
     fuse_backward_linear_add,
     fuse_forward_linear_bias_activation,
     fuse_forward_linear_bias_add,
@@ -68,8 +69,7 @@ class _OperationFuserAutogradFunction(torch.autograd.Function):
         input_: torch.Tensor,
         fuser: OperationFuser,
         basic_op_kwargs: list[dict[str, Any]],
-        is_grad_enabled: bool,
-        *params_and_extra_inputs: torch.nn.Parameter,
+        *params_and_extra_inputs: torch.Tensor,
     ) -> torch.Tensor | tuple[torch.Tensor, ...]:
         """Forward pass
@@ -83,8 +83,6 @@ class _OperationFuserAutogradFunction(torch.autograd.Function):
             Container for the pipeline of operations to run
         basic_op_kwargs: list of dict
             Keyword arguments to BasicOperation
-        is_grad_enabled: bool
-            Should context be saved for backward
         *params_and_extra_inputs: torch.Tensor
             Other tensor inputs to include in autograd graph. Consists
             of parameter tensors, followed by extra operation inputs.
@@ -106,52 +104,53 @@ class _OperationFuserAutogradFunction(torch.autograd.Function):
                     tensor.do_not_clear = True

         # Unflatten list of parameters and extra tensor inputs
-        extra_inputs = params_and_extra_inputs[-fuser._num_extra_inputs :]
+        extra_inputs = params_and_extra_inputs[-fuser.num_extra_inputs :]
         basic_op_extra_inputs = []
         for op in fuser._basic_ops:
             xs, extra_inputs = _split_tuple(extra_inputs, op.num_extra_inputs)
             basic_op_extra_inputs.append(xs)

+        # Get environment state
+        with_quantized_compute = FP8GlobalStateManager.is_fp8_enabled()
+        recipe = FP8GlobalStateManager.get_fp8_recipe() if with_quantized_compute else None
+        is_grad_enabled = func_ctx is not None
+
+        # Attempt to fuse operations if neccesary
+        fuser.maybe_fuse_ops(is_grad_enabled, recipe, input_, basic_op_extra_inputs)
+
         # Apply forward ops
         x = input_
-        requires_grad = is_grad_enabled and x.requires_grad
-        with_quantized_compute = FP8GlobalStateManager.is_fp8_enabled()
         extra_outputs = [None] * fuser._num_basic_ops
         for op, basic_op_idxs in fuser._forward_ops:

-            # Check if backward op is required
-            if is_grad_enabled:
-                if not requires_grad:
-                    requires_grad = any(param.requires_grad for param in op.parameters())
-                if not requires_grad:
-                    requires_grad = any(any(x.requires_grad for x in xs) for xs in extra_inputs)
+            # Set if backward op is required
             for idx in basic_op_idxs:
-                basic_op_ctxs[idx].requires_grad = requires_grad
+                basic_op_ctxs[idx].requires_grad = idx >= fuser.first_op_requiring_backward

             # Forward op
             extra_inputs = [basic_op_extra_inputs[idx] for idx in basic_op_idxs]
             prev_op_idx = basic_op_idxs[0] - 1
             prev_op = fuser._basic_ops[prev_op_idx] if prev_op_idx >= 0 else None
-            prev_op_grad_input_quantizer = None
-            if prev_op is not None and with_quantized_compute:
-                prev_op_grad_input_quantizer = prev_op.get_grad_input_quantizer()
+            prev_op_grad_output_quantizer = None
+            if prev_op is not None:
+                prev_op_grad_output_quantizer = prev_op.get_grad_output_quantizer()
             next_op_idx = basic_op_idxs[-1] + 1
             next_op = fuser._basic_ops[next_op_idx] if next_op_idx < fuser._num_basic_ops else None
             next_op_input_quantizer = None
-            if next_op is not None and with_quantized_compute:
+            if next_op is not None:
                 next_op_input_quantizer = next_op.get_input_quantizer()
             x, fused_op_extra_outputs = op.fuser_forward(
                 [basic_op_ctxs[idx] for idx in basic_op_idxs],
                 x,
                 basic_op_extra_inputs=extra_inputs,
-                prev_op_grad_input_quantizer=prev_op_grad_input_quantizer,
+                prev_op_grad_output_quantizer=prev_op_grad_output_quantizer,
                 next_op_input_quantizer=next_op_input_quantizer,
                 basic_op_kwargs=[basic_op_kwargs[idx] for idx in basic_op_idxs],
             )
             for idx, ys in zip(basic_op_idxs, fused_op_extra_outputs):
                 for y in ys:
-                    y.requires_grad_(requires_grad)
+                    y.requires_grad_(idx >= fuser.first_op_requiring_backward)
                 extra_outputs[idx] = ys

         # Flatten list of extra outputs
@@ -192,13 +191,13 @@ class _OperationFuserAutogradFunction(torch.autograd.Function):
             func_ctx.backward_ops = fuser._backward_ops
             func_ctx.basic_ops = fuser._basic_ops
             func_ctx.basic_op_ctxs = basic_op_ctxs
-            func_ctx.basic_op_num_params = fuser._num_list_basic_op_params
-            func_ctx.num_extra_inputs = fuser._num_extra_inputs
+            func_ctx.basic_op_num_params = fuser._basic_op_num_params
+            func_ctx.num_extra_inputs = fuser.num_extra_inputs
             func_ctx.num_extra_outputs = len(extra_outputs_flat)
             func_ctx.is_first_module = FP8GlobalStateManager.is_first_fp8_module()
             func_ctx.with_quantized_compute = with_quantized_compute

-        x.requires_grad_(requires_grad)
+        x.requires_grad_(fuser.first_op_requiring_backward < fuser._num_basic_ops)
         if extra_outputs_flat:
             return x, *extra_outputs_flat
@@ -304,7 +303,6 @@ class _OperationFuserAutogradFunction(torch.autograd.Function):
             dx,  # input_
             None,  # fuser
             None,  # basic_op_kwargs
-            None,  # is_grad_enabled
             *grad_params_flat,
             *grad_extra_inputs_flat,
         )
@@ -317,19 +315,12 @@ class OperationFuser:
     ----------
     ops: list of FusibleOperation
         Pipeline of operations
-    fuse_ops: bool
-        Whether to attempt fusing operations
-    recipe: Recipe, optional
-        Quantization recipe to use when fusing and executing operations.
-        Note: certain fusions may depend on what kind of recipe is being used.

     """

     def __init__(
         self,
         ops: list[FusibleOperation],
-        fuse_ops: bool,
-        recipe: Optional[Recipe],
     ) -> None:

         # Get list of basic operations
@@ -343,25 +334,22 @@ class OperationFuser:
         self._basic_ops: list[BasicOperation] = basic_ops

         # Number of extra tensor inputs
-        self._num_extra_inputs: int = sum(op.num_extra_inputs for op in basic_ops)
+        self._basic_op_num_extra_inputs: list[int] = list(op.num_extra_inputs for op in basic_ops)
+        self.num_extra_inputs: int = sum(self._basic_op_num_extra_inputs)

-        # Ops for forward and backward pass
+        # Ops for forward and backward pass, will be populated in fuse_ops
         self._forward_ops: list[tuple[FusibleOperation, list[int]]]
         self._backward_ops: list[tuple[FusibleOperation, list[int]]]
-        self._forward_ops = [(op, (idx,)) for idx, op in enumerate(self._basic_ops)]
-        self._backward_ops = list(reversed(self._forward_ops))

-        # Flag for checking if this is the first iteration
-        self._is_first_forward = True
+        # Cache and detect change of state relevant for fusing operations
+        self.recipe_type = None
+        self.first_op_requiring_backward = 0
+        self._last_amax_history_len = 0

-        # Fuse ops if needed
-        self.recipe = recipe
-        if fuse_ops:
-            self.fuse_ops()

         # Flatten list of parameters
-        self._basic_op_params = [param for op in self._basic_ops for param in op.parameters()]
-        self._num_list_basic_op_params = [sum(1 for _ in op.parameters()) for op in self._basic_ops]
+        self._basic_op_params = [list(op.parameters()) for op in self._basic_ops]
+        self._basic_op_num_params = list(map(len, self._basic_op_params))
+        self._flat_basic_op_params = sum(self._basic_op_params, [])

     @classmethod
     def _fuse_forward_ops(
@@ -384,13 +372,70 @@ class OperationFuser:
         """Attempt to fuse operations in backward pass"""
         ops = fuse_userbuffers_backward_linear(ops)
         ops = fuse_backward_linear_add(ops)
-        ops = fuse_backward_bias_activation(ops, recipe)
+        ops = fuse_backward_activation_bias(ops, recipe)
         return ops

-    def fuse_ops(self) -> None:
-        """Attempt to fuse operations"""
-        self._forward_ops = self._fuse_forward_ops(self._forward_ops, self.recipe)
-        self._backward_ops = self._fuse_backward_ops(self._backward_ops, self.recipe)
+    def maybe_fuse_ops(
+        self,
+        is_grad_enabled: bool,
+        recipe: Optional[Recipe],
+        input_: torch.Tensor,
+        extra_inputs: list[Iterable[torch.Tensor]],
+    ):
+        """Attempt to fuse operations if neccesary"""
+
+        # Determine which basic ops require backward
+        if not is_grad_enabled:
+            first_op_requiring_backward = self._num_basic_ops
+        elif input_.requires_grad:
+            first_op_requiring_backward = 0
+        else:
+            first_op_requiring_backward = self._num_basic_ops
+            for op_idx in range(self._num_basic_ops):
+                op_inputs = itertools.chain(self._basic_op_params[op_idx], extra_inputs[op_idx])
+                if any(tensor.requires_grad for tensor in op_inputs):
+                    first_op_requiring_backward = op_idx
+                    break
+
+        # Early exit if fusion parameters haven't changed
+        recipe_type = type(recipe)
+        fusion_params = (recipe_type, first_op_requiring_backward)
+        if fusion_params == (self.recipe_type, self.first_op_requiring_backward):
+            return
+
+        # Initialize ops if recipe type has changed
+        if self.recipe_type != recipe_type:
+            # Check if this is the first iteration
+            if self.recipe_type is None:
+                for op in self._basic_ops:
+                    op.pre_first_fuser_forward()
+            # Inform ops that the recipe type has changed
+            for op in self._basic_ops:
+                op.reset_recipe_type(recipe=recipe)
+        # Check if amax history was invalidated
+        elif isinstance(recipe, DelayedScaling):
+            if recipe.amax_history_len != self._last_amax_history_len:
+                raise RuntimeError(
+                    "Detected change of amax history length. "
+                    "Changing the length of amax history is currently not supported."
+                )
+
+        # Prepare basic op lists for fusions
+        forward_ops = [(op, [idx]) for idx, op in enumerate(self._basic_ops)]
+        backward_ops = list(reversed(forward_ops[first_op_requiring_backward:]))
+
+        # Fuse ops
+        self._forward_ops = self._fuse_forward_ops(forward_ops, recipe)
+        self._backward_ops = self._fuse_backward_ops(backward_ops, recipe)
+
+        # Save current fusion params
+        self.recipe_type, self.first_op_requiring_backward = fusion_params
+
+        # Save amax history length
+        if isinstance(recipe, DelayedScaling):
+            self._last_amax_history_len = recipe.amax_history_len
+        else:
+            self._last_amax_history_len = 0

     def __call__(
         self,
@@ -399,24 +444,17 @@ class OperationFuser:
         basic_op_kwargs: Optional[list[dict[str, Any]]] = None,
     ) -> torch.Tensor | tuple[torch.Tensor, ...]:

         # Verify extra input count
-        if len(extra_inputs) != self._num_extra_inputs:
+        if len(extra_inputs) != self.num_extra_inputs:
             raise ValueError(
-                f"Expected {self._num_extra_inputs} extra inputs but got {len(extra_inputs)}"
+                f"Expected {self.num_extra_inputs} extra inputs but got {len(extra_inputs)}"
             )

-        # Initialization before forward pass
-        if self._is_first_forward:
-            for op in self._basic_ops:
-                op.pre_first_forward(recipe=self.recipe)
-            self._is_first_forward = False

         # Canonicalize op kwargs
         if basic_op_kwargs is None:
             basic_op_kwargs = [{}] * self._num_basic_ops

         # Fuser forward pass
-        is_grad_enabled = torch.is_grad_enabled()
-        if is_grad_enabled:
+        if torch.is_grad_enabled():
             forward_func = _OperationFuserAutogradFunction.apply
             args = []
         else:
@@ -426,8 +464,7 @@ class OperationFuser:
             input,
             self,
             basic_op_kwargs,
-            is_grad_enabled,
-            *self._basic_op_params,
+            *self._flat_basic_op_params,
             *extra_inputs,
         )
         return forward_func(*args)
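The core of the new persistence scheme is `maybe_fuse_ops`: the fusion passes only rerun when the recipe type or the index of the first op requiring backward changes between iterations. A condensed, self-contained sketch of that caching logic (the real method also resets recipe state on the ops and validates the amax history length):

    from typing import Any, Optional

    class ToyFuser:
        """Caches a fused execution plan keyed on (recipe type, first grad op)."""

        def __init__(self, num_ops: int) -> None:
            self.num_ops = num_ops
            self._cached_key: Optional[tuple] = None
            self.forward_plan: list = []

        def _fuse(self, recipe: Any, first_grad_op: int) -> list:
            # Stand-in for the real forward/backward fusion passes
            return [("op", idx) for idx in range(self.num_ops)]

        def maybe_fuse(self, is_grad_enabled: bool, recipe: Any, first_grad_op: int) -> None:
            if not is_grad_enabled:
                first_grad_op = self.num_ops  # no op needs a backward pass
            key = (type(recipe), first_grad_op)
            if key == self._cached_key:
                return  # fusion inputs unchanged since last forward: reuse plan
            self.forward_plan = self._fuse(recipe, first_grad_op)
            self._cached_key = key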
@@ -15,9 +15,6 @@
 import torch

 from transformer_engine.common.recipe import Recipe
 from ..fp8 import (
-    MXFP8BlockScalingRecipeState,
-    DelayedScalingRecipeState,
-    Float8BlockScalingRecipeState,
     FP8GlobalStateManager,
     RecipeState,
     fp8_autocast,
@@ -65,18 +62,14 @@ class FusibleOperation(torch.nn.Module, metaclass=abc.ABCMeta):
     def is_fused_op(self) -> bool:
         """Whether this op is the fusion of one or more basic ops"""

-    def pre_first_forward(
-        self,
-        *,
-        recipe: Optional[Recipe],
-    ) -> None:
-        """Preprocessing before forward pass"""
+    def pre_first_fuser_forward(self) -> None:
+        """Preprocessing before first fuser forward pass"""

     def get_input_quantizer(self) -> Optional[Quantizer]:
         """Get builder class for quantized input tensor"""

-    def get_grad_input_quantizer(self) -> Optional[Quantizer]:
-        """Get builder class for quantized input's grad tensor"""
+    def get_grad_output_quantizer(self) -> Optional[Quantizer]:
+        """Get builder class for quantized output's grad tensor"""

     def fuser_forward(
         self,
@@ -84,7 +77,7 @@ class FusibleOperation(torch.nn.Module, metaclass=abc.ABCMeta):
         input_: torch.Tensor,
         *,
         basic_op_extra_inputs: list[tuple[torch.Tensor, ...]],
-        prev_op_grad_input_quantizer: Optional[Quantizer],
+        prev_op_grad_output_quantizer: Optional[Quantizer],
         next_op_input_quantizer: Optional[Quantizer],
         basic_op_kwargs: list[dict[str, Any]],
     ) -> tuple[torch.Tensor, Iterable[Iterable[torch.Tensor]]]:
@@ -104,8 +97,8 @@ class FusibleOperation(torch.nn.Module, metaclass=abc.ABCMeta):
             Input tensor
         basic_op_extra_inputs: list of torch.Tensor
             Extra tensor inputs to basic operations
-        prev_op_grad_input_quantizer: Quantizer, optional
-            The grad_input_quantizer of the preceeding operation
+        prev_op_grad_output_quantizer: Quantizer, optional
+            The grad_output_quantizer of the preceeding operation
         next_op_input_quantizer: Quantizer, optional
             The input_quantizer of the following operation
         basic_op_kwargs: list of dict
@@ -186,8 +179,11 @@ class BasicOperation(FusibleOperation, metaclass=abc.ABCMeta):
         super().__init__()

         # Objects for quantization
-        self._quantizers: Optional[dict[str, list[Quantizer]]] = None
         self._fp8_metas: Optional[dict[str, dict[str, Any]]] = None
+        self._quantizers: Optional[dict[str, list[Quantizer]]] = None
+        with_fp8_parameters = FP8GlobalStateManager.with_fp8_parameters()
+        recipe = FP8GlobalStateManager.get_fp8_recipe() if with_fp8_parameters else None
+        self.reset_recipe_type(recipe=recipe)

     @property
     def is_fused_op(self) -> bool:
@@ -214,120 +210,90 @@ class BasicOperation(FusibleOperation, metaclass=abc.ABCMeta):
             return self.get_quantizer("forward", 0)
         return None

-    def get_grad_input_quantizer(self) -> Optional[Quantizer]:
+    def get_grad_output_quantizer(self) -> Optional[Quantizer]:
         if self.num_quantizers("backward") > 0:
             return self.get_quantizer("backward", 0)
         return None

-    def _reset_quantization_recipe_state(
+    def reset_recipe_type(
         self,
         *,
-        recipe: Recipe,
+        recipe: Optional[Recipe],
     ) -> None:
         """Construct state for quantization recipe"""

+        # Clear quantization state if necessary
+        if recipe is None:
+            self._fp8_metas = None
+            self._quantizers = None
+            return
+
+        # Skip resetting recipe type if it did not actually change.
+        # This could happen for example if calling BasicOperation.forward directly, as in that
+        # case, the OperationFuser is not persistent, or when loading from a checkpoint
+        need_to_reset_recipe_state = False
+        if self._fp8_metas is None or self._quantizers is None:
+            need_to_reset_recipe_state = True
+        else:
+            for mode in ("forward", "backward"):
+                fp8_meta_key = FP8GlobalStateManager.get_meta_tensor_key(
+                    forward=(mode == "forward"),
+                )
+                if self._fp8_metas[mode] is None or fp8_meta_key not in self._fp8_metas[mode]:
+                    continue
+                recipe_state = self._fp8_metas[mode][fp8_meta_key]
+                if not isinstance(recipe, type(recipe_state.recipe)):
+                    need_to_reset_recipe_state = True
+                    break
+
+        if need_to_reset_recipe_state:
+            # Quantization recipe state for forward and backward pass
+            self._fp8_metas = {"forward": None, "backward": None}
+            self._quantizers = {"forward": [], "backward": []}
+            for mode in ("forward", "backward"):
+                num_quantizers = self.num_quantizers(mode)
+                if num_quantizers == 0:
+                    continue
+
+                if recipe.float8_block_scaling():
+                    raise NotImplementedError(
+                        "Fusible operations do not support FP8 block scaling recipe"
+                    )
+
+                # Construct quantization recipe state
+                recipe_state = RecipeState.create(
+                    recipe,
+                    mode=mode,
+                    num_quantizers=num_quantizers,
+                )
+                fp8_meta_key = FP8GlobalStateManager.get_meta_tensor_key(
+                    forward=(mode == "forward"),
+                )
+                self._fp8_metas[mode] = {
+                    fp8_meta_key: recipe_state,
+                    "recipe": recipe,
+                    "fp8_group": FP8GlobalStateManager.get_fp8_group(),
+                }
+
+                # Construct builder class for quantized tensors
+                self._quantizers[mode] = recipe_state.make_quantizers()
+
+        # Add meta tensors to global buffer to participate in reduction
+        for mode in ("forward", "backward"):
+            if (
+                FP8GlobalStateManager.is_fp8_enabled()
+                and self.num_quantizers(mode)
+                and not FP8GlobalStateManager.fp8_graph_capturing()
+            ):
+                FP8GlobalStateManager.add_fp8_tensors_to_global_buffer(
+                    self._fp8_metas[mode],
+                )

-        # Quantization recipe state for forward and backward pass
-        self._fp8_metas = {"forward": None, "backward": None}
-        self._quantizers = {"forward": [], "backward": []}
-        for mode in ("forward", "backward"):
-            num_quantizers = self.num_quantizers(mode)
-            if num_quantizers == 0:
-                continue
-
-            if recipe.float8_block_scaling():
-                raise NotImplementedError(
-                    "Fusible operations do not support FP8 block scaling recipe"
-                )
-
-            # Construct quantization recipe state
-            recipe_state = RecipeState.create(
-                recipe,
-                mode=mode,
-                num_quantizers=num_quantizers,
-            )
-            fp8_meta_key = FP8GlobalStateManager.get_meta_tensor_key(
-                forward=(mode == "forward"),
-            )
-            self._fp8_metas[mode] = {
-                fp8_meta_key: recipe_state,
-                "recipe": recipe,
-                "fp8_group": FP8GlobalStateManager.get_fp8_group(),
-            }
-
-            # Construct builder class for quantized tensors
-            self._quantizers[mode] = recipe_state.make_quantizers()
-
-    def _update_quantization_recipe_state(
-        self,
-        *,
-        recipe: Recipe,
-    ) -> None:
-        """Make sure quantizer state matches quantization recipe"""
-
-        # Reset quantization state if needed
-        if self._fp8_metas is None or self._quantizers is None:
-            self._reset_quantization_recipe_state(recipe=recipe)
-            return
-        for mode in ("forward", "backward"):
-            fp8_meta_key = FP8GlobalStateManager.get_meta_tensor_key(
-                forward=(mode == "forward"),
-            )
-            if self._fp8_metas[mode] is None or fp8_meta_key not in self._fp8_metas[mode]:
-                continue
-            recipe_state = self._fp8_metas[mode][fp8_meta_key]
-            need_to_reset_recipe_state = (
-                (recipe.delayed() and not isinstance(recipe_state, DelayedScalingRecipeState))
-                or (recipe.mxfp8() and not isinstance(recipe_state, MXFP8BlockScalingRecipeState))
-                or (
-                    recipe.float8_block_scaling()
-                    and not isinstance(recipe_state, Float8BlockScalingRecipeState)
-                )
-            )
-            if need_to_reset_recipe_state:
-                self._reset_quantization_recipe_state(recipe=recipe)
-                return
-
-        # Quantization recipe state for forward and backward pass
-        for mode in ("forward", "backward"):
-            num_quantizers = self.num_quantizers(mode)
-            if num_quantizers == 0:
-                continue
-
-            # Update FP8 metadata
-            fp8_meta = self._fp8_metas[mode]
-            fp8_meta["recipe"] = recipe
-            fp8_meta["fp8_group"] = FP8GlobalStateManager.get_fp8_group()
-
-            # Get recipe state
-            fp8_meta_key = FP8GlobalStateManager.get_meta_tensor_key(
-                forward=(mode == "forward"),
-            )
-            recipe_state = fp8_meta[fp8_meta_key]
-
-            # Reallocate amax history if needed
-            if not recipe.delayed():
-                continue
-            current_length = recipe_state.amax_history.size(0)
-            target_length = recipe.amax_history_len
-            if current_length != target_length:
-                with torch.no_grad():
-                    if target_length < current_length:
-                        recipe_state.amax_history = recipe_state.amax_history[
-                            :target_length
-                        ].clone()
-                    else:
-                        recipe_state.amax_history = torch.nn.functional.pad(
-                            recipe_state.amax_history,
-                            pad=(0, 0, 0, target_length - current_length),
-                        )
-                self._quantizers[mode] = recipe_state.make_quantizers()

     def get_quantizer(
         self,
         mode: str,
         index: int,
-    ) -> Quantizer:
+    ) -> Optional[Quantizer]:
         """Get builder class for quantized tensor

         Parameters
@@ -337,7 +303,7 @@ class BasicOperation(FusibleOperation, metaclass=abc.ABCMeta):
         """
         if self._quantizers is None:
-            self._reset_quantization_recipe_state(recipe=FP8GlobalStateManager.get_fp8_recipe())
+            return None
         return self._quantizers[mode][index]

     @torch.no_grad()
@@ -388,33 +354,13 @@ class BasicOperation(FusibleOperation, metaclass=abc.ABCMeta):
             self._fp8_metas[mode][fp8_meta_key].scale.copy_(scale)
             self._fp8_metas[mode][fp8_meta_key].amax_history.copy_(amax_history)

-    def pre_first_forward(
-        self,
-        *,
-        recipe: Optional[Recipe],
-    ) -> None:
-        """Preprocessing before forward pass"""
-
-        # Initialize FP8 metadata if needed
-        if recipe is not None:
-            self._update_quantization_recipe_state(recipe=recipe)
-        if not FP8GlobalStateManager.fp8_graph_capturing():
-            if self.num_quantizers("forward"):
-                FP8GlobalStateManager.add_fp8_tensors_to_global_buffer(
-                    self._fp8_metas["forward"],
-                )
-            if self.num_quantizers("backward"):
-                FP8GlobalStateManager.add_fp8_tensors_to_global_buffer(
-                    self._fp8_metas["backward"],
-                )

     @abc.abstractmethod
     def op_forward(
         self,
         ctx: OperationContext,
         input_: torch.Tensor,
         *,
-        prev_op_grad_input_quantizer: Optional[Quantizer],
+        prev_op_grad_output_quantizer: Optional[Quantizer],
         next_op_input_quantizer: Optional[Quantizer],
         **kwargs: Any,
     ) -> torch.Tensor:
@@ -426,8 +372,8 @@ class BasicOperation(FusibleOperation, metaclass=abc.ABCMeta):
             Context to coordinate between forward and backward passes
         input_: torch.Tensor
             Input tensor
-        prev_op_grad_input_quantizer: Quantizer, optional
-            The grad_input_quantizer of the preceeding operation
+        prev_op_grad_output_quantizer: Quantizer, optional
+            The grad_output_quantizer of the preceeding operation
         next_op_input_quantizer: Quantizer, optional
             The input_quantizer of the following operation
@@ -468,7 +414,7 @@ class BasicOperation(FusibleOperation, metaclass=abc.ABCMeta):
         input_: torch.Tensor,
         *,
         basic_op_extra_inputs: list[tuple[torch.Tensor, ...]],
-        prev_op_grad_input_quantizer: Optional[Quantizer],
+        prev_op_grad_output_quantizer: Optional[Quantizer],
         next_op_input_quantizer: Optional[Quantizer],
         basic_op_kwargs: list[dict[str, Any]],
     ) -> tuple[torch.Tensor, list[tuple[()]]]:
@@ -482,7 +428,7 @@ class BasicOperation(FusibleOperation, metaclass=abc.ABCMeta):
         output = self.op_forward(
             basic_op_ctxs[0],
             input_,
-            prev_op_grad_input_quantizer=prev_op_grad_input_quantizer,
+            prev_op_grad_output_quantizer=prev_op_grad_output_quantizer,
             next_op_input_quantizer=next_op_input_quantizer,
             **basic_op_kwargs[0],
         )
@@ -518,9 +464,7 @@ class BasicOperation(FusibleOperation, metaclass=abc.ABCMeta):
         """Apply operation"""
         from .fuser import OperationFuser

-        with_quantized_compute = FP8GlobalStateManager.is_fp8_enabled()
-        recipe = FP8GlobalStateManager.get_fp8_recipe() if with_quantized_compute else None
-        return OperationFuser([self], fuse_ops=False, recipe=recipe)(
+        return OperationFuser([self])(
             input,
             *extra_inputs,
             basic_op_kwargs=[kwargs],
@@ -630,7 +574,7 @@ class BasicOperation(FusibleOperation, metaclass=abc.ABCMeta):
         # Get op's quantizer state, initializing if needed
         if self._fp8_metas is None or self._fp8_metas[mode] is None:
             with fp8_autocast(fp8_recipe=state[mode]["recipe"]):
-                self._reset_quantization_recipe_state(recipe=state[mode]["recipe"])
+                self.reset_recipe_type(recipe=state[mode]["recipe"])
             fp8_meta = self._fp8_metas[mode]

         # Load extra items
@@ -708,13 +652,13 @@ class FusedOperation(FusibleOperation):
     def get_input_quantizer(self) -> Optional[Quantizer]:
         return self.basic_ops[0].get_input_quantizer()

-    def get_grad_input_quantizer(self) -> Optional[Quantizer]:
-        return self.basic_ops[-1].get_grad_input_quantizer()
+    def get_grad_output_quantizer(self) -> Optional[Quantizer]:
+        return self.basic_ops[-1].get_grad_output_quantizer()

-    def pre_first_forward(self, *args, **kwargs) -> None:
-        """Preprocessing before forward pass"""
+    def pre_first_fuser_forward(self) -> None:
+        """Preprocessing before first fuser forward pass"""
         for op in self.basic_ops:
-            op.pre_first_forward(*args, **kwargs)
+            op.pre_first_fuser_forward()

     def forward(
         self,
@@ -727,9 +671,7 @@ class FusedOperation(FusibleOperation):
             basic_op_kwargs = [{} for _ in range(len(self.basic_ops))]
         from .fuser import OperationFuser

-        with_quantized_compute = FP8GlobalStateManager.is_fp8_enabled()
-        recipe = FP8GlobalStateManager.get_fp8_recipe() if with_quantized_compute else None
-        return OperationFuser([self], fuse_ops=False, recipe=recipe)(
+        return OperationFuser([self])(
             input,
             *extra_inputs,
             basic_op_kwargs=basic_op_kwargs,
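Two behavioral changes in this file are worth calling out together: `reset_recipe_type(recipe=None)` now clears all quantization state, and `get_quantizer` returns `None` instead of lazily constructing recipe state when no recipe is in use. A hedged sketch of that lifecycle (method names are from the diff; the bodies are simplified stand-ins, not TE's implementation):

    from typing import Any, Optional

    class ToyOp:
        def __init__(self) -> None:
            self._quantizers: Optional[dict] = None

        def reset_recipe_type(self, *, recipe: Optional[Any]) -> None:
            if recipe is None:
                # No quantization recipe in use: drop all quantizer state
                self._quantizers = None
                return
            # Stand-in for rebuilding per-mode quantizers for the new recipe
            self._quantizers = {"forward": [recipe], "backward": [recipe]}

        def get_quantizer(self, mode: str, index: int) -> Optional[Any]:
            if self._quantizers is None:
                return None  # callers must handle the unquantized case
            return self._quantizers[mode][index]

Callers such as the fused forward ops above now fetch quantizers unconditionally and treat `None` as a signal for unquantized compute.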
@@ -10,7 +10,6 @@ from typing import Optional

 import torch

-from transformer_engine.pytorch.fp8 import FP8GlobalStateManager, Recipe
 from transformer_engine.pytorch.ops.op import FusibleOperation
 from transformer_engine.pytorch.ops.fuser import OperationFuser
@@ -147,7 +146,6 @@ class Sequential(torch.nn.Module):
     def _make_module_groups(
         cls,
         modules: Iterable[torch.nn.Module],
-        recipe: Optional[Recipe],
     ) -> list[OperationFuser | torch.nn.Module]:
         """Make list of modules, with fusible operations grouped together"""
@@ -162,24 +160,7 @@ class Sequential(torch.nn.Module):
                 groups.append(module)
         for idx, group in enumerate(groups):
             if isinstance(group, list):
-                groups[idx] = OperationFuser(group, fuse_ops=True, recipe=recipe)
+                groups[idx] = OperationFuser(group)

-        # Check if operations expect extra input or output tensors
-        # Note: If any op has extra inputs or outputs, then the entire
-        # Sequential must be made up of TE ops.
-        if len(groups) > 1:
-            ops = []
-            for group in groups:
-                if isinstance(group, OperationFuser):
-                    ops.extend(group._basic_ops)
-            num_extra_inputs = sum(op.num_extra_inputs for op in ops)
-            num_extra_outputs = sum(op.num_extra_outputs for op in ops)
-            if num_extra_inputs > 0 or num_extra_outputs > 0:
-                raise RuntimeError(
-                    f"`Sequential` expects {num_extra_inputs} extra inputs "
-                    f"and {num_extra_outputs} extra outputs, "
-                    "but it contains non-fusible operations"
-                )
         return groups
@@ -190,22 +171,28 @@ class Sequential(torch.nn.Module):
     ) -> torch.Tensor | tuple[torch.Tensor, ...]:
         """Forward pass"""

-        # Get current global state
-        with_quantized_compute = FP8GlobalStateManager.is_fp8_enabled()
-        recipe = FP8GlobalStateManager.get_fp8_recipe() if with_quantized_compute else None
-        global_state = (with_quantized_compute, type(recipe))
-
-        # Reset module groups is global state changed
-        if self._last_global_state != global_state:
-            self._module_groups = None
-            self._last_global_state = global_state
-
         # Create module groups if needed
         if self._module_groups is None:
-            self._module_groups = self._make_module_groups(self._modules.values(), recipe)
+            self._module_groups = self._make_module_groups(self._modules.values())

         # Forward pass for each module group
         x = input
+        extra_outputs: list[torch.Tensor] = []
         for module_group in self._module_groups:
-            x = module_group(x, *extra_inputs)
+            if isinstance(module_group, OperationFuser):
+                xs, extra_inputs = (
+                    (x,) + extra_inputs[: module_group.num_extra_inputs],
+                    extra_inputs[module_group.num_extra_inputs :],
+                )
+                xs = module_group(*xs)
+                if isinstance(xs, tuple):
+                    x, ys = xs[0], xs[1:]
+                    extra_outputs.extend(ys)
+                else:
+                    x = xs
+            else:
+                x = module_group(x)
+        if extra_outputs:
+            return (x,) + tuple(extra_outputs)
         return x
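The new `Sequential.forward` threads extra inputs through mixed pipelines: each `OperationFuser` group consumes its own slice of `extra_inputs` and contributes any extra outputs, while plain modules only see the main activation. A toy version of the routing (not TE code):

    def run_groups(x, extra_inputs, groups):
        """groups: list of (callable, num_extra_inputs) pairs; a plain module
        is modeled as num_extra_inputs == 0 with a single tensor output."""
        extra_outputs = []
        for group, num_extra in groups:
            # Peel off this group's slice of the remaining extra inputs
            args, extra_inputs = (x, *extra_inputs[:num_extra]), extra_inputs[num_extra:]
            out = group(*args)
            if isinstance(out, tuple):  # fuser groups may return extra outputs
                x, extra_outputs = out[0], extra_outputs + list(out[1:])
            else:
                x = out
        return (x, *extra_outputs) if extra_outputs else x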