Unverified commit e0204fbb authored by Jan Bielak, committed by GitHub

Refactor `te.ops` (#1951)



* Refactor _OperationFuserAutogradFunction.forward to use less parameters
Signed-off-by: Jan Bielak <jbielak@nvidia.com>
(cherry picked from commit f8f59b1bb184e89468058521df4cfff029ad909c)

* Rename `BackwardBiasActivation` to `BackwardActivationBias`
Signed-off-by: Jan Bielak <jbielak@nvidia.com>
(cherry picked from commit 397c58fc296f801fe4ad600aadc2daff3b78be45)

* Use forward operation order in backward fused operations
Signed-off-by: Jan Bielak <jbielak@nvidia.com>
(cherry picked from commit 2d37a9385069b066e6cdeff3eb9173c2079cb791)

* Rename `prev_op_grad_input_quantizer` to `prev_op_grad_output_quantizer`
Signed-off-by: Jan Bielak <jbielak@nvidia.com>
(cherry picked from commit d7ab5dfb23e216866f7f4fc4d7a99f625d329f1e)

* Make OperationFuser persistent
Signed-off-by: Jan Bielak <jbielak@nvidia.com>
(cherry picked from commit 77984d9715d31e87519dc6ea1e02c483a81355a7)

* Distribute extra inputs to and collect extra outputs from multiple module groups in Sequential
Signed-off-by: Jan Bielak <jbielak@nvidia.com>
(cherry picked from commit 0716aaad542e59f2c1ac4620167965a0334bbf71)

* Take requires_grad into account when fusing operations
Signed-off-by: Jan Bielak <jbielak@nvidia.com>

* Change get_quantizer to return None if no quantization recipe is used
Signed-off-by: Jan Bielak <jbielak@nvidia.com>

* Refactor pre_first_forward
Signed-off-by: Jan Bielak <jbielak@nvidia.com>

* Fix for failing `test_make_graphed_callables[fp8_recipe0-*-True-*-linear_op]`
Signed-off-by: Jan Bielak <jbielak@nvidia.com>

* Fix linting errors
Signed-off-by: Jan Bielak <jbielak@nvidia.com>

* Apply suggestions from code review
Co-authored-by: Tim Moon <4406448+timmoon10@users.noreply.github.com>
Signed-off-by: Jan Bielak <jbielak@nvidia.com>

* Fix fp8 meta tensors in CUDA Graph capture
Signed-off-by: Jan Bielak <jbielak@nvidia.com>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci



* Fix failing distributed userbuffers tests
Signed-off-by: Jan Bielak <jbielak@nvidia.com>

---------
Signed-off-by: Jan Bielak <jbielak@nvidia.com>
Co-authored-by: Tim Moon <4406448+timmoon10@users.noreply.github.com>
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
parent cb504cda
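Taken together, the commits above make the `OperationFuser` a long-lived object that re-plans its fusions only when the state that affects fusion changes: the quantization recipe type and the first operation that needs a backward pass. Below is a minimal sketch of that caching pattern; the names (`TinyFuser`, `maybe_plan`) are illustrative only and are not part of Transformer Engine.

```python
# Illustrative sketch, not TE's actual API: a fuser that caches the parameters
# deciding how ops are fused and only re-plans when one of them changes,
# mirroring the "persistent OperationFuser" idea above.
from typing import Optional


class TinyFuser:
    def __init__(self, ops):
        self.ops = ops
        self._cached_key = None       # (recipe type, first op requiring backward)
        self._forward_plan = None     # list of (op, basic-op indices) groups

    def maybe_plan(self, recipe: Optional[object], first_op_requiring_backward: int):
        key = (type(recipe), first_op_requiring_backward)
        if key == self._cached_key:
            return self._forward_plan  # nothing relevant changed, reuse previous plan
        # Re-plan; here we simply keep the ops unfused, one group per op.
        self._forward_plan = [(op, [idx]) for idx, op in enumerate(self.ops)]
        self._cached_key = key
        return self._forward_plan
```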
......@@ -59,7 +59,7 @@ class ForwardLinearBiasActivation(FusedOperation):
input_: torch.Tensor,
*,
basic_op_extra_inputs: list[tuple[torch.Tensor, ...]],
prev_op_grad_input_quantizer: Optional[Quantizer],
prev_op_grad_output_quantizer: Optional[Quantizer],
next_op_input_quantizer: Optional[Quantizer],
basic_op_kwargs: list[dict[str, Any]],
) -> tuple[torch.Tensor, Iterable[Iterable[torch.Tensor]]]:
......@@ -89,18 +89,12 @@ class ForwardLinearBiasActivation(FusedOperation):
weight_requires_grad = linear_op_ctx.requires_grad and linear_op.weight.requires_grad
# FP8 metadata
input_quantizer = linear_op.get_quantizer("forward", 0)
weight_quantizer = linear_op.get_quantizer("forward", 1)
output_quantizer = next_op_input_quantizer
grad_output_quantizer = linear_op.get_quantizer("backward", 0)
grad_input_quantizer = prev_op_grad_output_quantizer
with_quantized_compute = FP8GlobalStateManager.is_fp8_enabled()
input_quantizer = None
weight_quantizer = None
output_quantizer = None
grad_output_quantizer = None
grad_input_quantizer = None
if with_quantized_compute:
input_quantizer = linear_op.get_quantizer("forward", 0)
weight_quantizer = linear_op.get_quantizer("forward", 1)
output_quantizer = next_op_input_quantizer
grad_output_quantizer = linear_op.get_quantizer("backward", 0)
grad_input_quantizer = prev_op_grad_input_quantizer
# Get autocast dtype if needed
if torch.is_autocast_enabled():
......@@ -126,18 +120,18 @@ class ForwardLinearBiasActivation(FusedOperation):
)
# Save state for backward pass
linear_op_ctx.save_for_backward(x_local, w)
linear_op_ctx.with_quantized_compute = with_quantized_compute
linear_op_ctx.input_quantizer = input_quantizer
linear_op_ctx.weight_quantizer = weight_quantizer
linear_op_ctx.grad_output_quantizer = grad_output_quantizer
linear_op_ctx.grad_input_quantizer = grad_input_quantizer
linear_op_ctx.dtype = dtype
linear_op_ctx.input_requires_grad = input_requires_grad
linear_op_ctx.weight_requires_grad = weight_requires_grad
if bias_op is not None:
bias_op_ctx.with_quantized_compute = with_quantized_compute
bias_op_ctx.grad_input_quantizer = linear_op.get_grad_input_quantizer()
if linear_op_ctx.requires_grad:
linear_op_ctx.save_for_backward(x_local, w)
linear_op_ctx.with_quantized_compute = with_quantized_compute
linear_op_ctx.input_quantizer = input_quantizer
linear_op_ctx.weight_quantizer = weight_quantizer
linear_op_ctx.grad_output_quantizer = grad_output_quantizer
linear_op_ctx.grad_input_quantizer = grad_input_quantizer
linear_op_ctx.dtype = dtype
linear_op_ctx.input_requires_grad = input_requires_grad
linear_op_ctx.weight_requires_grad = weight_requires_grad
if bias_op is not None and bias_op_ctx.requires_grad:
bias_op_ctx.grad_input_quantizer = linear_op.get_grad_output_quantizer()
return output, [() for _ in range(len(self.basic_ops))]
......
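The hunk above also changes when autograd state is stored: tensors and quantizers are saved on `linear_op_ctx` only if that context actually requires grad, and the bias context is touched only when its op needs backward. A small standalone sketch of the same guard, using a toy `torch.autograd.Function` rather than TE's operation contexts:

```python
# Hedged sketch with made-up names: save tensors on the autograd context only
# when a backward pass will actually consume them, as the hunk above does for
# linear_op_ctx and bias_op_ctx.
import torch


class _LinearFn(torch.autograd.Function):
    @staticmethod
    def forward(ctx, x, w):
        y = x @ w.t()
        if x.requires_grad or w.requires_grad:
            # Only pay the memory cost of saved state when gradients are needed.
            ctx.save_for_backward(x, w)
        return y

    @staticmethod
    def backward(ctx, dy):
        x, w = ctx.saved_tensors
        return dy @ w, dy.t() @ x


x = torch.randn(4, 8, requires_grad=True)
w = torch.randn(3, 8)
_LinearFn.apply(x, w).sum().backward()
```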
......@@ -57,7 +57,7 @@ class ForwardLinearBiasAdd(FusedOperation):
input_: torch.Tensor,
*,
basic_op_extra_inputs: list[tuple[torch.Tensor, ...]],
prev_op_grad_input_quantizer: Optional[Quantizer],
prev_op_grad_output_quantizer: Optional[Quantizer],
next_op_input_quantizer: Optional[Quantizer],
basic_op_kwargs: list[dict[str, Any]],
) -> tuple[torch.Tensor, Iterable[Iterable[torch.Tensor]]]:
......@@ -83,17 +83,12 @@ class ForwardLinearBiasAdd(FusedOperation):
weight_requires_grad = linear_op_ctx.requires_grad and linear_op.weight.requires_grad
# FP8 metadata
with_quantized_compute = FP8GlobalStateManager.is_fp8_enabled()
input_quantizer = None
weight_quantizer = None
input_quantizer = linear_op.get_quantizer("forward", 0)
weight_quantizer = linear_op.get_quantizer("forward", 1)
output_quantizer = None
grad_output_quantizer = None
grad_input_quantizer = None
if with_quantized_compute:
input_quantizer = linear_op.get_quantizer("forward", 0)
weight_quantizer = linear_op.get_quantizer("forward", 1)
grad_output_quantizer = linear_op.get_quantizer("backward", 0)
grad_input_quantizer = prev_op_grad_input_quantizer
grad_output_quantizer = linear_op.get_quantizer("backward", 0)
grad_input_quantizer = prev_op_grad_output_quantizer
with_quantized_compute = FP8GlobalStateManager.is_fp8_enabled()
# Get autocast dtype if needed
if torch.is_autocast_enabled():
......@@ -122,18 +117,18 @@ class ForwardLinearBiasAdd(FusedOperation):
)
# Save state for backward pass
linear_op_ctx.save_for_backward(x_local, w)
linear_op_ctx.with_quantized_compute = with_quantized_compute
linear_op_ctx.input_quantizer = input_quantizer
linear_op_ctx.weight_quantizer = weight_quantizer
linear_op_ctx.grad_output_quantizer = grad_output_quantizer
linear_op_ctx.grad_input_quantizer = grad_input_quantizer
linear_op_ctx.dtype = dtype
linear_op_ctx.input_requires_grad = input_requires_grad
linear_op_ctx.weight_requires_grad = weight_requires_grad
if bias_op is not None:
bias_op_ctx.with_quantized_compute = with_quantized_compute
bias_op_ctx.grad_input_quantizer = linear_op.get_grad_input_quantizer()
if linear_op_ctx.requires_grad:
linear_op_ctx.save_for_backward(x_local, w)
linear_op_ctx.with_quantized_compute = with_quantized_compute
linear_op_ctx.input_quantizer = input_quantizer
linear_op_ctx.weight_quantizer = weight_quantizer
linear_op_ctx.grad_output_quantizer = grad_output_quantizer
linear_op_ctx.grad_input_quantizer = grad_input_quantizer
linear_op_ctx.dtype = dtype
linear_op_ctx.input_requires_grad = input_requires_grad
linear_op_ctx.weight_requires_grad = weight_requires_grad
if bias_op is not None and bias_op_ctx.requires_grad:
bias_op_ctx.grad_input_quantizer = linear_op.get_grad_output_quantizer()
return output, [() for _ in range(len(self.basic_ops))]
......
......@@ -48,14 +48,14 @@ class UserbuffersBackwardLinear(FusedOperation):
# Basic operations that comprise this fused operation
op_idxs = {"linear": None, "bias": None, "reduce_scatter": None}
ops = []
if reduce_scatter is not None:
op_idxs["reduce_scatter"] = len(ops)
ops.append(reduce_scatter)
op_idxs["linear"] = len(ops)
ops.append(linear)
if bias is not None:
op_idxs["bias"] = len(ops)
ops.append(bias)
op_idxs["linear"] = len(ops)
ops.append(linear)
if reduce_scatter is not None:
op_idxs["reduce_scatter"] = len(ops)
ops.append(reduce_scatter)
# Initialize base class
super().__init__(ops)
......@@ -495,7 +495,7 @@ class UserbuffersBackwardLinear(FusedOperation):
# Get basic operations
idx = self._op_idxs["linear"]
linear_op = self.basic_ops[idx]
linear_op_ctx = basic_op_ctxs[idx]
linear_op_ctx = basic_op_ctxs[-1]
bias_op = None
if self._op_idxs["bias"] is not None:
idx = self._op_idxs["bias"]
......@@ -556,6 +556,7 @@ class UserbuffersBackwardLinear(FusedOperation):
grad_params[self._op_idxs["linear"]] = (grad_weight,)
if bias_op is not None:
grad_params[self._op_idxs["bias"]] = (grad_bias,)
grad_params.reverse()
grad_extra_inputs = [() for _ in range(len(self.basic_ops))]
return grad_input, grad_params, grad_extra_inputs
......
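Because the fused backward op now lists its basic ops in forward order (linear, then bias, then reduce-scatter) while the fuser consumes results in backward order, the gradient list built with forward-order indices has to be reversed before it is returned; hence the new `grad_params.reverse()` and the `basic_op_ctxs[-1]` lookup. A simplified illustration of the ordering convention (assumed, not TE code):

```python
# Sketch of the new convention: per-op gradients are filled in by forward-order
# index and reversed at the end, because the caller expects backward order.
op_idxs = {"linear": 0, "bias": 1, "reduce_scatter": 2}   # forward order

grad_params = [None] * 3
grad_params[op_idxs["linear"]] = ("grad_weight",)
grad_params[op_idxs["bias"]] = ("grad_bias",)
grad_params[op_idxs["reduce_scatter"]] = ()

grad_params.reverse()                                      # match backward order
assert grad_params == [(), ("grad_bias",), ("grad_weight",)]
```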
......@@ -282,7 +282,7 @@ class UserbuffersForwardLinear(FusedOperation):
input_: torch.Tensor,
*,
basic_op_extra_inputs: list[tuple[torch.Tensor, ...]],
prev_op_grad_input_quantizer: Optional[Quantizer],
prev_op_grad_output_quantizer: Optional[Quantizer],
next_op_input_quantizer: Optional[Quantizer],
basic_op_kwargs: list[dict[str, Any]],
) -> tuple[torch.Tensor, Iterable[Iterable[torch.Tensor]]]:
......@@ -307,21 +307,17 @@ class UserbuffersForwardLinear(FusedOperation):
weight_requires_grad = linear_op_ctx.requires_grad and linear_op.weight.requires_grad
# Quantization metadata
input_quantizer = linear_op.get_quantizer("forward", 0)
weight_quantizer = linear_op.get_quantizer("forward", 1)
grad_output_quantizer = linear_op.get_quantizer("backward", 0)
grad_input_quantizer = prev_op_grad_output_quantizer
with_quantized_compute = FP8GlobalStateManager.is_fp8_enabled()
input_quantizer = None
weight_quantizer = None
grad_output_quantizer = None
grad_input_quantizer = None
if with_quantized_compute:
recipe = FP8GlobalStateManager.get_fp8_recipe()
if not any((recipe.delayed(), recipe.float8_current_scaling(), recipe.mxfp8())):
raise RuntimeError(
f"Unsupported recipe for Userbuffers ({recipe.__class__.__name__})"
)
input_quantizer = linear_op.get_quantizer("forward", 0)
weight_quantizer = linear_op.get_quantizer("forward", 1)
grad_output_quantizer = linear_op.get_quantizer("backward", 0)
grad_input_quantizer = prev_op_grad_input_quantizer
# Get autocast dtype if needed
if torch.is_autocast_enabled():
......@@ -356,19 +352,19 @@ class UserbuffersForwardLinear(FusedOperation):
w = extra_outputs["weight"]
# Save state for backward pass
linear_op_ctx.save_for_backward(x_local, w)
linear_op_ctx.with_quantized_compute = with_quantized_compute
linear_op_ctx.input_quantizer = input_quantizer
linear_op_ctx.weight_quantizer = weight_quantizer
linear_op_ctx.grad_output_quantizer = grad_output_quantizer
linear_op_ctx.grad_input_quantizer = grad_input_quantizer
linear_op_ctx.dtype = dtype
linear_op_ctx.input_dims = input_.size()
linear_op_ctx.input_requires_grad = input_requires_grad
linear_op_ctx.weight_requires_grad = weight_requires_grad
if bias_op is not None:
bias_op_ctx.with_quantized_compute = with_quantized_compute
bias_op_ctx.grad_input_quantizer = linear_op.get_grad_input_quantizer()
if linear_op_ctx.requires_grad:
linear_op_ctx.save_for_backward(x_local, w)
linear_op_ctx.with_quantized_compute = with_quantized_compute
linear_op_ctx.input_quantizer = input_quantizer
linear_op_ctx.weight_quantizer = weight_quantizer
linear_op_ctx.grad_output_quantizer = grad_output_quantizer
linear_op_ctx.grad_input_quantizer = grad_input_quantizer
linear_op_ctx.dtype = dtype
linear_op_ctx.input_dims = input_.size()
linear_op_ctx.input_requires_grad = input_requires_grad
linear_op_ctx.weight_requires_grad = weight_requires_grad
if bias_op is not None and bias_op_ctx.requires_grad:
bias_op_ctx.grad_input_quantizer = linear_op.get_grad_output_quantizer()
return output, [() for _ in range(len(self.basic_ops))]
......
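Across the forward fused ops, the renamed parameter reflects the hand-off convention: an op quantizes its forward output with the next op's input quantizer and its backward grad input with the previous op's grad-output quantizer, so each tensor crossing an op boundary is produced directly in the format its consumer expects. A hedged sketch of that lookup with made-up stand-in classes (not TE's `Quantizer`):

```python
# Minimal sketch of the quantizer hand-off between adjacent ops in a pipeline.
from __future__ import annotations
from dataclasses import dataclass
from typing import Optional


@dataclass
class Op:
    name: str
    input_quantizer: Optional[str] = None        # stand-in for a Quantizer object
    grad_output_quantizer: Optional[str] = None


def boundary_quantizers(ops: list[Op], idx: int):
    prev_op = ops[idx - 1] if idx > 0 else None
    next_op = ops[idx + 1] if idx + 1 < len(ops) else None
    grad_input_quantizer = prev_op.grad_output_quantizer if prev_op else None
    output_quantizer = next_op.input_quantizer if next_op else None
    return output_quantizer, grad_input_quantizer


ops = [Op("linear", "fp8_in", "fp8_dgrad"), Op("gelu", "fp8_in2", "fp8_dgrad2")]
print(boundary_quantizers(ops, 0))  # ('fp8_in2', None)
```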
......@@ -5,19 +5,20 @@
"""Manager class for a pipeline of fusible operations."""
from __future__ import annotations
from collections.abc import Callable
from collections.abc import Callable, Iterable
from typing import Any, Optional
import itertools
import torch
from transformer_engine.pytorch.fp8 import FP8GlobalStateManager, Recipe
from transformer_engine.pytorch.fp8 import FP8GlobalStateManager, Recipe, DelayedScaling
from transformer_engine.pytorch.ops.op import (
BasicOperation,
FusibleOperation,
OperationContext,
)
from transformer_engine.pytorch.ops.fused import (
fuse_backward_bias_activation,
fuse_backward_activation_bias,
fuse_backward_linear_add,
fuse_forward_linear_bias_activation,
fuse_forward_linear_bias_add,
......@@ -68,8 +69,7 @@ class _OperationFuserAutogradFunction(torch.autograd.Function):
input_: torch.Tensor,
fuser: OperationFuser,
basic_op_kwargs: list[dict[str, Any]],
is_grad_enabled: bool,
*params_and_extra_inputs: torch.nn.Parameter,
*params_and_extra_inputs: torch.Tensor,
) -> torch.Tensor | tuple[torch.Tensor, ...]:
"""Forward pass
......@@ -83,8 +83,6 @@ class _OperationFuserAutogradFunction(torch.autograd.Function):
Container for the pipeline of operations to run
basic_op_kwargs: list of dict
Keyword arguments to BasicOperation
is_grad_enabled: bool
Should context be saved for backward
*params_and_extra_inputs: torch.Tensor
Other tensor inputs to include in autograd graph. Consists
of parameter tensors, followed by extra operation inputs.
......@@ -106,52 +104,53 @@ class _OperationFuserAutogradFunction(torch.autograd.Function):
tensor.do_not_clear = True
# Unflatten list of parameters and extra tensor inputs
extra_inputs = params_and_extra_inputs[-fuser._num_extra_inputs :]
extra_inputs = params_and_extra_inputs[-fuser.num_extra_inputs :]
basic_op_extra_inputs = []
for op in fuser._basic_ops:
xs, extra_inputs = _split_tuple(extra_inputs, op.num_extra_inputs)
basic_op_extra_inputs.append(xs)
# Get environment state
with_quantized_compute = FP8GlobalStateManager.is_fp8_enabled()
recipe = FP8GlobalStateManager.get_fp8_recipe() if with_quantized_compute else None
is_grad_enabled = func_ctx is not None
# Attempt to fuse operations if necessary
fuser.maybe_fuse_ops(is_grad_enabled, recipe, input_, basic_op_extra_inputs)
# Apply forward ops
x = input_
requires_grad = is_grad_enabled and x.requires_grad
with_quantized_compute = FP8GlobalStateManager.is_fp8_enabled()
extra_outputs = [None] * fuser._num_basic_ops
for op, basic_op_idxs in fuser._forward_ops:
# Check if backward op is required
if is_grad_enabled:
if not requires_grad:
requires_grad = any(param.requires_grad for param in op.parameters())
if not requires_grad:
requires_grad = any(any(x.requires_grad for x in xs) for xs in extra_inputs)
# Set if backward op is required
for idx in basic_op_idxs:
basic_op_ctxs[idx].requires_grad = requires_grad
basic_op_ctxs[idx].requires_grad = idx >= fuser.first_op_requiring_backward
# Forward op
extra_inputs = [basic_op_extra_inputs[idx] for idx in basic_op_idxs]
prev_op_idx = basic_op_idxs[0] - 1
prev_op = fuser._basic_ops[prev_op_idx] if prev_op_idx >= 0 else None
prev_op_grad_input_quantizer = None
if prev_op is not None and with_quantized_compute:
prev_op_grad_input_quantizer = prev_op.get_grad_input_quantizer()
prev_op_grad_output_quantizer = None
if prev_op is not None:
prev_op_grad_output_quantizer = prev_op.get_grad_output_quantizer()
next_op_idx = basic_op_idxs[-1] + 1
next_op = fuser._basic_ops[next_op_idx] if next_op_idx < fuser._num_basic_ops else None
next_op_input_quantizer = None
if next_op is not None and with_quantized_compute:
if next_op is not None:
next_op_input_quantizer = next_op.get_input_quantizer()
x, fused_op_extra_outputs = op.fuser_forward(
[basic_op_ctxs[idx] for idx in basic_op_idxs],
x,
basic_op_extra_inputs=extra_inputs,
prev_op_grad_input_quantizer=prev_op_grad_input_quantizer,
prev_op_grad_output_quantizer=prev_op_grad_output_quantizer,
next_op_input_quantizer=next_op_input_quantizer,
basic_op_kwargs=[basic_op_kwargs[idx] for idx in basic_op_idxs],
)
for idx, ys in zip(basic_op_idxs, fused_op_extra_outputs):
for y in ys:
y.requires_grad_(requires_grad)
y.requires_grad_(idx >= fuser.first_op_requiring_backward)
extra_outputs[idx] = ys
# Flatten list of extra outputs
......@@ -192,13 +191,13 @@ class _OperationFuserAutogradFunction(torch.autograd.Function):
func_ctx.backward_ops = fuser._backward_ops
func_ctx.basic_ops = fuser._basic_ops
func_ctx.basic_op_ctxs = basic_op_ctxs
func_ctx.basic_op_num_params = fuser._num_list_basic_op_params
func_ctx.num_extra_inputs = fuser._num_extra_inputs
func_ctx.basic_op_num_params = fuser._basic_op_num_params
func_ctx.num_extra_inputs = fuser.num_extra_inputs
func_ctx.num_extra_outputs = len(extra_outputs_flat)
func_ctx.is_first_module = FP8GlobalStateManager.is_first_fp8_module()
func_ctx.with_quantized_compute = with_quantized_compute
x.requires_grad_(requires_grad)
x.requires_grad_(fuser.first_op_requiring_backward < fuser._num_basic_ops)
if extra_outputs_flat:
return x, *extra_outputs_flat
......@@ -304,7 +303,6 @@ class _OperationFuserAutogradFunction(torch.autograd.Function):
dx, # input_
None, # fuser
None, # basic_op_kwargs
None, # is_grad_enabled
*grad_params_flat,
*grad_extra_inputs_flat,
)
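With the explicit `is_grad_enabled` argument removed, the autograd function infers grad mode from whether it received a context at all: the fuser goes through `apply` when autograd is on and calls the static `forward` directly with a `None` context when it is off. A self-contained sketch of that calling convention (toy function, not TE's fuser):

```python
# Hedged sketch of the assumed calling convention: apply() when grad is on,
# a direct forward(None, ...) call when it is off.
import torch


class _ScaleFn(torch.autograd.Function):
    @staticmethod
    def forward(ctx, x, scale: float):
        is_grad_enabled = ctx is not None   # no ctx means autograd is bypassed
        if is_grad_enabled:
            ctx.scale = scale
        return x * scale

    @staticmethod
    def backward(ctx, dy):
        return dy * ctx.scale, None


def run(x, scale):
    if torch.is_grad_enabled():
        return _ScaleFn.apply(x, scale)
    # Bypass autograd entirely; no context object is created.
    return _ScaleFn.forward(None, x, scale)


with torch.no_grad():
    y = run(torch.ones(2), 3.0)
```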
......@@ -317,19 +315,12 @@ class OperationFuser:
----------
ops: list of FusibleOperation
Pipeline of operations
fuse_ops: bool
Whether to attempt fusing operations
recipe: Recipe, optional
Quantization recipe to use when fusing and executing operations.
Note: certain fusions may depend on what kind of recipe is being used.
"""
def __init__(
self,
ops: list[FusibleOperation],
fuse_ops: bool,
recipe: Optional[Recipe],
) -> None:
# Get list of basic operations
......@@ -343,25 +334,22 @@ class OperationFuser:
self._basic_ops: list[BasicOperation] = basic_ops
# Number of extra tensor inputs
self._num_extra_inputs: int = sum(op.num_extra_inputs for op in basic_ops)
self._basic_op_num_extra_inputs: list[int] = list(op.num_extra_inputs for op in basic_ops)
self.num_extra_inputs: int = sum(self._basic_op_num_extra_inputs)
# Ops for forward and backward pass
# Ops for forward and backward pass, will be populated in fuse_ops
self._forward_ops: list[tuple[FusibleOperation, list[int]]]
self._backward_ops: list[tuple[FusibleOperation, list[int]]]
self._forward_ops = [(op, (idx,)) for idx, op in enumerate(self._basic_ops)]
self._backward_ops = list(reversed(self._forward_ops))
# Flag for checking if this is the first iteration
self._is_first_forward = True
# Fuse ops if needed
self.recipe = recipe
if fuse_ops:
self.fuse_ops()
# Cache and detect change of state relevant for fusing operations
self.recipe_type = None
self.first_op_requiring_backward = 0
self._last_amax_history_len = 0
# Flatten list of parameters
self._basic_op_params = [param for op in self._basic_ops for param in op.parameters()]
self._num_list_basic_op_params = [sum(1 for _ in op.parameters()) for op in self._basic_ops]
self._basic_op_params = [list(op.parameters()) for op in self._basic_ops]
self._basic_op_num_params = list(map(len, self._basic_op_params))
self._flat_basic_op_params = sum(self._basic_op_params, [])
@classmethod
def _fuse_forward_ops(
......@@ -384,13 +372,70 @@ class OperationFuser:
"""Attempt to fuse operations in backward pass"""
ops = fuse_userbuffers_backward_linear(ops)
ops = fuse_backward_linear_add(ops)
ops = fuse_backward_bias_activation(ops, recipe)
ops = fuse_backward_activation_bias(ops, recipe)
return ops
def fuse_ops(self) -> None:
"""Attempt to fuse operations"""
self._forward_ops = self._fuse_forward_ops(self._forward_ops, self.recipe)
self._backward_ops = self._fuse_backward_ops(self._backward_ops, self.recipe)
def maybe_fuse_ops(
self,
is_grad_enabled: bool,
recipe: Optional[Recipe],
input_: torch.Tensor,
extra_inputs: list[Iterable[torch.Tensor]],
):
"""Attempt to fuse operations if neccesary"""
# Determine which basic ops require backward
if not is_grad_enabled:
first_op_requiring_backward = self._num_basic_ops
elif input_.requires_grad:
first_op_requiring_backward = 0
else:
first_op_requiring_backward = self._num_basic_ops
for op_idx in range(self._num_basic_ops):
op_inputs = itertools.chain(self._basic_op_params[op_idx], extra_inputs[op_idx])
if any(tensor.requires_grad for tensor in op_inputs):
first_op_requiring_backward = op_idx
break
# Early exit if fusion parameters haven't changed
recipe_type = type(recipe)
fusion_params = (recipe_type, first_op_requiring_backward)
if fusion_params == (self.recipe_type, self.first_op_requiring_backward):
return
# Initialize ops if recipe type has changed
if self.recipe_type != recipe_type:
# Check if this is the first iteration
if self.recipe_type is None:
for op in self._basic_ops:
op.pre_first_fuser_forward()
# Inform ops that the recipe type has changed
for op in self._basic_ops:
op.reset_recipe_type(recipe=recipe)
# Check if amax history was invalidated
elif isinstance(recipe, DelayedScaling):
if recipe.amax_history_len != self._last_amax_history_len:
raise RuntimeError(
"Detected change of amax history length. "
"Changing the length of amax history is currently not supported."
)
# Prepare basic op lists for fusions
forward_ops = [(op, [idx]) for idx, op in enumerate(self._basic_ops)]
backward_ops = list(reversed(forward_ops[first_op_requiring_backward:]))
# Fuse ops
self._forward_ops = self._fuse_forward_ops(forward_ops, recipe)
self._backward_ops = self._fuse_backward_ops(backward_ops, recipe)
# Save current fusion params
self.recipe_type, self.first_op_requiring_backward = fusion_params
# Save amax history length
if isinstance(recipe, DelayedScaling):
self._last_amax_history_len = recipe.amax_history_len
else:
self._last_amax_history_len = 0
def __call__(
self,
......@@ -399,24 +444,17 @@ class OperationFuser:
basic_op_kwargs: Optional[list[dict[str, Any]]] = None,
) -> torch.Tensor | tuple[torch.Tensor, ...]:
# Verify extra input count
if len(extra_inputs) != self._num_extra_inputs:
if len(extra_inputs) != self.num_extra_inputs:
raise ValueError(
f"Expected {self._num_extra_inputs} extra inputs but got {len(extra_inputs)}"
f"Expected {self.num_extra_inputs} extra inputs but got {len(extra_inputs)}"
)
# Initialization before forward pass
if self._is_first_forward:
for op in self._basic_ops:
op.pre_first_forward(recipe=self.recipe)
self._is_first_forward = False
# Canonicalize op kwargs
if basic_op_kwargs is None:
basic_op_kwargs = [{}] * self._num_basic_ops
# Fuser forward pass
is_grad_enabled = torch.is_grad_enabled()
if is_grad_enabled:
if torch.is_grad_enabled():
forward_func = _OperationFuserAutogradFunction.apply
args = []
else:
......@@ -426,8 +464,7 @@ class OperationFuser:
input,
self,
basic_op_kwargs,
is_grad_enabled,
*self._basic_op_params,
*self._flat_basic_op_params,
*extra_inputs,
)
return forward_func(*args)
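The new `maybe_fuse_ops` decides which ops need backward by scanning, in order, the pipeline input, each op's parameters, and each op's extra tensor inputs, and treating everything from the first grad-requiring tensor onward as needing backward. A simplified, standalone version of that scan (assumed signature, not TE's):

```python
# Sketch of the "first op requiring backward" scan used to key the fusion cache.
import itertools
import torch


def first_op_requiring_backward(input_, op_params, op_extra_inputs, is_grad_enabled):
    num_ops = len(op_params)
    if not is_grad_enabled:
        return num_ops                      # nothing needs backward
    if input_.requires_grad:
        return 0                            # every op needs backward
    for idx in range(num_ops):
        tensors = itertools.chain(op_params[idx], op_extra_inputs[idx])
        if any(t.requires_grad for t in tensors):
            return idx
    return num_ops


params = [[torch.randn(2)], [torch.randn(2, requires_grad=True)], [torch.randn(2)]]
extras = [[], [], []]
print(first_op_requiring_backward(torch.randn(2), params, extras, True))  # 1
```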
......@@ -15,9 +15,6 @@ import torch
from transformer_engine.common.recipe import Recipe
from ..fp8 import (
MXFP8BlockScalingRecipeState,
DelayedScalingRecipeState,
Float8BlockScalingRecipeState,
FP8GlobalStateManager,
RecipeState,
fp8_autocast,
......@@ -65,18 +62,14 @@ class FusibleOperation(torch.nn.Module, metaclass=abc.ABCMeta):
def is_fused_op(self) -> bool:
"""Whether this op is the fusion of one or more basic ops"""
def pre_first_forward(
self,
*,
recipe: Optional[Recipe],
) -> None:
"""Preprocessing before forward pass"""
def pre_first_fuser_forward(self) -> None:
"""Preprocessing before first fuser forward pass"""
def get_input_quantizer(self) -> Optional[Quantizer]:
"""Get builder class for quantized input tensor"""
def get_grad_input_quantizer(self) -> Optional[Quantizer]:
"""Get builder class for quantized input's grad tensor"""
def get_grad_output_quantizer(self) -> Optional[Quantizer]:
"""Get builder class for quantized output's grad tensor"""
def fuser_forward(
self,
......@@ -84,7 +77,7 @@ class FusibleOperation(torch.nn.Module, metaclass=abc.ABCMeta):
input_: torch.Tensor,
*,
basic_op_extra_inputs: list[tuple[torch.Tensor, ...]],
prev_op_grad_input_quantizer: Optional[Quantizer],
prev_op_grad_output_quantizer: Optional[Quantizer],
next_op_input_quantizer: Optional[Quantizer],
basic_op_kwargs: list[dict[str, Any]],
) -> tuple[torch.Tensor, Iterable[Iterable[torch.Tensor]]]:
......@@ -104,8 +97,8 @@ class FusibleOperation(torch.nn.Module, metaclass=abc.ABCMeta):
Input tensor
basic_op_extra_inputs: list of torch.Tensor
Extra tensor inputs to basic operations
prev_op_grad_input_quantizer: Quantizer, optional
The grad_input_quantizer of the preceding operation
prev_op_grad_output_quantizer: Quantizer, optional
The grad_output_quantizer of the preceding operation
next_op_input_quantizer: Quantizer, optional
The input_quantizer of the following operation
basic_op_kwargs: list of dict
......@@ -186,8 +179,11 @@ class BasicOperation(FusibleOperation, metaclass=abc.ABCMeta):
super().__init__()
# Objects for quantization
self._quantizers: Optional[dict[str, list[Quantizer]]] = None
self._fp8_metas: Optional[dict[str, dict[str, Any]]] = None
self._quantizers: Optional[dict[str, list[Quantizer]]] = None
with_fp8_parameters = FP8GlobalStateManager.with_fp8_parameters()
recipe = FP8GlobalStateManager.get_fp8_recipe() if with_fp8_parameters else None
self.reset_recipe_type(recipe=recipe)
@property
def is_fused_op(self) -> bool:
......@@ -214,120 +210,90 @@ class BasicOperation(FusibleOperation, metaclass=abc.ABCMeta):
return self.get_quantizer("forward", 0)
return None
def get_grad_input_quantizer(self) -> Optional[Quantizer]:
def get_grad_output_quantizer(self) -> Optional[Quantizer]:
if self.num_quantizers("backward") > 0:
return self.get_quantizer("backward", 0)
return None
def _reset_quantization_recipe_state(
def reset_recipe_type(
self,
*,
recipe: Recipe,
recipe: Optional[Recipe],
) -> None:
"""Construct state for quantization recipe"""
# Quantization recipe state for forward and backward pass
self._fp8_metas = {"forward": None, "backward": None}
self._quantizers = {"forward": [], "backward": []}
for mode in ("forward", "backward"):
num_quantizers = self.num_quantizers(mode)
if num_quantizers == 0:
continue
# Clear quantization state if necessary
if recipe is None:
self._fp8_metas = None
self._quantizers = None
return
if recipe.float8_block_scaling():
raise NotImplementedError(
"Fusible operations do not support FP8 block scaling recipe"
# Skip resetting recipe type if it did not actually change.
# This could happen for example if calling BasicOperation.forward directly, as in that
# case, the OperationFuser is not persistent, or when loading from a checkpoint
need_to_reset_recipe_state = False
if self._fp8_metas is None or self._quantizers is None:
need_to_reset_recipe_state = True
else:
for mode in ("forward", "backward"):
fp8_meta_key = FP8GlobalStateManager.get_meta_tensor_key(
forward=(mode == "forward"),
)
if self._fp8_metas[mode] is None or fp8_meta_key not in self._fp8_metas[mode]:
continue
recipe_state = self._fp8_metas[mode][fp8_meta_key]
if not isinstance(recipe, type(recipe_state.recipe)):
need_to_reset_recipe_state = True
break
if need_to_reset_recipe_state:
# Quantization recipe state for forward and backward pass
self._fp8_metas = {"forward": None, "backward": None}
self._quantizers = {"forward": [], "backward": []}
for mode in ("forward", "backward"):
num_quantizers = self.num_quantizers(mode)
if num_quantizers == 0:
continue
# Construct quantization recipe state
recipe_state = RecipeState.create(
recipe,
mode=mode,
num_quantizers=num_quantizers,
)
fp8_meta_key = FP8GlobalStateManager.get_meta_tensor_key(
forward=(mode == "forward"),
)
self._fp8_metas[mode] = {
fp8_meta_key: recipe_state,
"recipe": recipe,
"fp8_group": FP8GlobalStateManager.get_fp8_group(),
}
if recipe.float8_block_scaling():
raise NotImplementedError(
"Fusible operations do not support FP8 block scaling recipe"
)
# Construct builder class for quantized tensors
self._quantizers[mode] = recipe_state.make_quantizers()
# Construct quantization recipe state
recipe_state = RecipeState.create(
recipe,
mode=mode,
num_quantizers=num_quantizers,
)
fp8_meta_key = FP8GlobalStateManager.get_meta_tensor_key(
forward=(mode == "forward"),
)
self._fp8_metas[mode] = {
fp8_meta_key: recipe_state,
"recipe": recipe,
"fp8_group": FP8GlobalStateManager.get_fp8_group(),
}
def _update_quantization_recipe_state(
self,
*,
recipe: Recipe,
) -> None:
"""Make sure quantizer state matches quantization recipe"""
# Construct builder class for quantized tensors
self._quantizers[mode] = recipe_state.make_quantizers()
# Reset quantization state if needed
if self._fp8_metas is None or self._quantizers is None:
self._reset_quantization_recipe_state(recipe=recipe)
return
# Add meta tensors to global buffer to participate in reduction
for mode in ("forward", "backward"):
fp8_meta_key = FP8GlobalStateManager.get_meta_tensor_key(
forward=(mode == "forward"),
)
if self._fp8_metas[mode] is None or fp8_meta_key not in self._fp8_metas[mode]:
continue
recipe_state = self._fp8_metas[mode][fp8_meta_key]
need_to_reset_recipe_state = (
(recipe.delayed() and not isinstance(recipe_state, DelayedScalingRecipeState))
or (recipe.mxfp8() and not isinstance(recipe_state, MXFP8BlockScalingRecipeState))
or (
recipe.float8_block_scaling()
and not isinstance(recipe_state, Float8BlockScalingRecipeState)
if (
FP8GlobalStateManager.is_fp8_enabled()
and self.num_quantizers(mode)
and not FP8GlobalStateManager.fp8_graph_capturing()
):
FP8GlobalStateManager.add_fp8_tensors_to_global_buffer(
self._fp8_metas[mode],
)
)
if need_to_reset_recipe_state:
self._reset_quantization_recipe_state(recipe=recipe)
return
# Quantization recipe state for forward and backward pass
for mode in ("forward", "backward"):
num_quantizers = self.num_quantizers(mode)
if num_quantizers == 0:
continue
# Update FP8 metadata
fp8_meta = self._fp8_metas[mode]
fp8_meta["recipe"] = recipe
fp8_meta["fp8_group"] = FP8GlobalStateManager.get_fp8_group()
# Get recipe state
fp8_meta_key = FP8GlobalStateManager.get_meta_tensor_key(
forward=(mode == "forward"),
)
recipe_state = fp8_meta[fp8_meta_key]
# Reallocate amax history if needed
if not recipe.delayed():
continue
current_length = recipe_state.amax_history.size(0)
target_length = recipe.amax_history_len
if current_length != target_length:
with torch.no_grad():
if target_length < current_length:
recipe_state.amax_history = recipe_state.amax_history[
:target_length
].clone()
else:
recipe_state.amax_history = torch.nn.functional.pad(
recipe_state.amax_history,
pad=(0, 0, 0, target_length - current_length),
)
self._quantizers[mode] = recipe_state.make_quantizers()
def get_quantizer(
self,
mode: str,
index: int,
) -> Quantizer:
) -> Optional[Quantizer]:
"""Get builder class for quantized tensor
Parameters
......@@ -337,7 +303,7 @@ class BasicOperation(FusibleOperation, metaclass=abc.ABCMeta):
"""
if self._quantizers is None:
self._reset_quantization_recipe_state(recipe=FP8GlobalStateManager.get_fp8_recipe())
return None
return self._quantizers[mode][index]
@torch.no_grad()
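`reset_recipe_type` replaces the old reset/update pair: it rebuilds quantization state only when the recipe type actually changed, which protects existing scales when `forward` is called directly (with a non-persistent fuser) or when a checkpoint is loaded, and registration of FP8 meta tensors for amax reduction now happens here rather than in `pre_first_forward`. A rough sketch of the skip-if-unchanged pattern, with hypothetical names standing in for TE's `RecipeState`:

```python
# Illustrative sketch only: rebuild quantization state only when the recipe
# *type* changed, so repeated calls with the same recipe keep existing state.
from typing import Optional


class QuantStateHolder:
    def __init__(self):
        self._recipe = None
        self._state = None

    def reset_recipe_type(self, recipe: Optional[object]) -> None:
        if recipe is None:
            self._recipe, self._state = None, None   # clear state entirely
            return
        if self._state is not None and isinstance(recipe, type(self._recipe)):
            return                                   # same recipe type: keep state
        self._recipe = recipe
        self._state = {"scales": [1.0], "recipe": type(recipe).__name__}
```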
......@@ -388,33 +354,13 @@ class BasicOperation(FusibleOperation, metaclass=abc.ABCMeta):
self._fp8_metas[mode][fp8_meta_key].scale.copy_(scale)
self._fp8_metas[mode][fp8_meta_key].amax_history.copy_(amax_history)
def pre_first_forward(
self,
*,
recipe: Optional[Recipe],
) -> None:
"""Preprocessing before forward pass"""
# Initialize FP8 metadata if needed
if recipe is not None:
self._update_quantization_recipe_state(recipe=recipe)
if not FP8GlobalStateManager.fp8_graph_capturing():
if self.num_quantizers("forward"):
FP8GlobalStateManager.add_fp8_tensors_to_global_buffer(
self._fp8_metas["forward"],
)
if self.num_quantizers("backward"):
FP8GlobalStateManager.add_fp8_tensors_to_global_buffer(
self._fp8_metas["backward"],
)
@abc.abstractmethod
def op_forward(
self,
ctx: OperationContext,
input_: torch.Tensor,
*,
prev_op_grad_input_quantizer: Optional[Quantizer],
prev_op_grad_output_quantizer: Optional[Quantizer],
next_op_input_quantizer: Optional[Quantizer],
**kwargs: Any,
) -> torch.Tensor:
......@@ -426,8 +372,8 @@ class BasicOperation(FusibleOperation, metaclass=abc.ABCMeta):
Context to coordinate between forward and backward passes
input_: torch.Tensor
Input tensor
prev_op_grad_input_quantizer: Quantizer, optional
The grad_input_quantizer of the preceding operation
prev_op_grad_output_quantizer: Quantizer, optional
The grad_output_quantizer of the preceding operation
next_op_input_quantizer: Quantizer, optional
The input_quantizer of the following operation
......@@ -468,7 +414,7 @@ class BasicOperation(FusibleOperation, metaclass=abc.ABCMeta):
input_: torch.Tensor,
*,
basic_op_extra_inputs: list[tuple[torch.Tensor, ...]],
prev_op_grad_input_quantizer: Optional[Quantizer],
prev_op_grad_output_quantizer: Optional[Quantizer],
next_op_input_quantizer: Optional[Quantizer],
basic_op_kwargs: list[dict[str, Any]],
) -> tuple[torch.Tensor, list[tuple[()]]]:
......@@ -482,7 +428,7 @@ class BasicOperation(FusibleOperation, metaclass=abc.ABCMeta):
output = self.op_forward(
basic_op_ctxs[0],
input_,
prev_op_grad_input_quantizer=prev_op_grad_input_quantizer,
prev_op_grad_output_quantizer=prev_op_grad_output_quantizer,
next_op_input_quantizer=next_op_input_quantizer,
**basic_op_kwargs[0],
)
......@@ -518,9 +464,7 @@ class BasicOperation(FusibleOperation, metaclass=abc.ABCMeta):
"""Apply operation"""
from .fuser import OperationFuser
with_quantized_compute = FP8GlobalStateManager.is_fp8_enabled()
recipe = FP8GlobalStateManager.get_fp8_recipe() if with_quantized_compute else None
return OperationFuser([self], fuse_ops=False, recipe=recipe)(
return OperationFuser([self])(
input,
*extra_inputs,
basic_op_kwargs=[kwargs],
......@@ -630,7 +574,7 @@ class BasicOperation(FusibleOperation, metaclass=abc.ABCMeta):
# Get op's quantizer state, initializing if needed
if self._fp8_metas is None or self._fp8_metas[mode] is None:
with fp8_autocast(fp8_recipe=state[mode]["recipe"]):
self._reset_quantization_recipe_state(recipe=state[mode]["recipe"])
self.reset_recipe_type(recipe=state[mode]["recipe"])
fp8_meta = self._fp8_metas[mode]
# Load extra items
......@@ -708,13 +652,13 @@ class FusedOperation(FusibleOperation):
def get_input_quantizer(self) -> Optional[Quantizer]:
return self.basic_ops[0].get_input_quantizer()
def get_grad_input_quantizer(self) -> Optional[Quantizer]:
return self.basic_ops[-1].get_grad_input_quantizer()
def get_grad_output_quantizer(self) -> Optional[Quantizer]:
return self.basic_ops[-1].get_grad_output_quantizer()
def pre_first_forward(self, *args, **kwargs) -> None:
"""Preprocessing before forward pass"""
def pre_first_fuser_forward(self) -> None:
"""Preprocessing before first fuser forward pass"""
for op in self.basic_ops:
op.pre_first_forward(*args, **kwargs)
op.pre_first_fuser_forward()
def forward(
self,
......@@ -727,9 +671,7 @@ class FusedOperation(FusibleOperation):
basic_op_kwargs = [{} for _ in range(len(self.basic_ops))]
from .fuser import OperationFuser
with_quantized_compute = FP8GlobalStateManager.is_fp8_enabled()
recipe = FP8GlobalStateManager.get_fp8_recipe() if with_quantized_compute else None
return OperationFuser([self], fuse_ops=False, recipe=recipe)(
return OperationFuser([self])(
input,
*extra_inputs,
basic_op_kwargs=basic_op_kwargs,
......
......@@ -10,7 +10,6 @@ from typing import Optional
import torch
from transformer_engine.pytorch.fp8 import FP8GlobalStateManager, Recipe
from transformer_engine.pytorch.ops.op import FusibleOperation
from transformer_engine.pytorch.ops.fuser import OperationFuser
......@@ -147,7 +146,6 @@ class Sequential(torch.nn.Module):
def _make_module_groups(
cls,
modules: Iterable[torch.nn.Module],
recipe: Optional[Recipe],
) -> list[OperationFuser | torch.nn.Module]:
"""Make list of modules, with fusible operations grouped together"""
......@@ -162,24 +160,7 @@ class Sequential(torch.nn.Module):
groups.append(module)
for idx, group in enumerate(groups):
if isinstance(group, list):
groups[idx] = OperationFuser(group, fuse_ops=True, recipe=recipe)
# Check if operations expect extra input or output tensors
# Note: If any op has extra inputs or outputs, then the entire
# Sequential must be made up of TE ops.
if len(groups) > 1:
ops = []
for group in groups:
if isinstance(group, OperationFuser):
ops.extend(group._basic_ops)
num_extra_inputs = sum(op.num_extra_inputs for op in ops)
num_extra_outputs = sum(op.num_extra_outputs for op in ops)
if num_extra_inputs > 0 or num_extra_outputs > 0:
raise RuntimeError(
f"`Sequential` expects {num_extra_inputs} extra inputs "
f"and {num_extra_outputs} extra outputs, "
"but it contains non-fusible operations"
)
groups[idx] = OperationFuser(group)
return groups
......@@ -190,22 +171,28 @@ class Sequential(torch.nn.Module):
) -> torch.Tensor | tuple[torch.Tensor, ...]:
"""Forward pass"""
# Get current global state
with_quantized_compute = FP8GlobalStateManager.is_fp8_enabled()
recipe = FP8GlobalStateManager.get_fp8_recipe() if with_quantized_compute else None
global_state = (with_quantized_compute, type(recipe))
# Reset module groups is global state changed
if self._last_global_state != global_state:
self._module_groups = None
self._last_global_state = global_state
# Create module groups if needed
if self._module_groups is None:
self._module_groups = self._make_module_groups(self._modules.values(), recipe)
self._module_groups = self._make_module_groups(self._modules.values())
# Forward pass for each module group
x = input
extra_outputs: list[torch.Tensor] = []
for module_group in self._module_groups:
x = module_group(x, *extra_inputs)
if isinstance(module_group, OperationFuser):
xs, extra_inputs = (
(x,) + extra_inputs[: module_group.num_extra_inputs],
extra_inputs[module_group.num_extra_inputs :],
)
xs = module_group(*xs)
if isinstance(xs, tuple):
x, ys = xs[0], xs[1:]
extra_outputs.extend(ys)
else:
x = xs
else:
x = module_group(x)
if extra_outputs:
return (x,) + tuple(extra_outputs)
return x
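`Sequential.forward` now slices the flat tuple of extra inputs so that each `OperationFuser` group receives exactly its own share, and it gathers any extra outputs a group returns alongside the main activation. A simplified sketch of this dispatch loop; `consumes` stands in for the group's `num_extra_inputs` and the groups here are toy callables rather than real fusers:

```python
# Sketch of distributing extra inputs across module groups and collecting
# their extra outputs, as in the Sequential.forward change above.
import torch


def run_groups(groups, x, extra_inputs):
    extra_outputs = []
    for consumes, group in groups:          # consumes = that group's extra-input count
        own, extra_inputs = extra_inputs[:consumes], extra_inputs[consumes:]
        out = group(x, *own)
        if isinstance(out, tuple):          # group produced extra outputs
            x, ys = out[0], out[1:]
            extra_outputs.extend(ys)
        else:
            x = out
    return (x, *extra_outputs) if extra_outputs else x


groups = [(0, torch.nn.ReLU()), (1, lambda x, extra: (x + extra, extra * 2))]
print(run_groups(groups, torch.ones(2), (torch.full((2,), 3.0),)))
```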