Unverified Commit 72592763 authored by Tim Moon's avatar Tim Moon Committed by GitHub
Browse files

[PyTorch] Support user-defined op fusions (#2597)



* Expose option for custom op fusions

Refactor fusion functions to remove index bookkeeping. Refactor fused ops to use consistent operation order.
Signed-off-by: default avatarTim Moon <tmoon@nvidia.com>

* Add tests for custom ops
Signed-off-by: default avatarTim Moon <tmoon@nvidia.com>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci



* Fix linter warnings and numerical test failures
Signed-off-by: default avatarTim Moon <tmoon@nvidia.com>

* Tweak pattern matching logic with fixed window sizes
Signed-off-by: default avatarTim Moon <tmoon@nvidia.com>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci



* Use TF32 tols in fused op tests
Signed-off-by: default avatarTim Moon <tmoon@nvidia.com>

* Review suggestion from @greptile-apps
Signed-off-by: default avatarTim Moon <tmoon@nvidia.com>

* Backpropagate fixes from #2622
Signed-off-by: default avatarTim Moon <tmoon@nvidia.com>

---------
Signed-off-by: default avatarTim Moon <tmoon@nvidia.com>
Co-authored-by: default avatarpre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
parent a0a89a8e
......@@ -2329,13 +2329,13 @@ class TestFusedOps:
backward_ops = model._module_groups[0]._backward_ops
if with_quantization:
assert len(backward_ops) == 2
assert isinstance(backward_ops[0][0], BackwardActivationBias)
assert isinstance(backward_ops[1][0], te_ops.Quantize)
assert isinstance(backward_ops[0][0], te_ops.Quantize)
assert isinstance(backward_ops[1][0], BackwardActivationBias)
else:
assert len(backward_ops) == 3
assert isinstance(backward_ops[0][0], act_type)
assert isinstance(backward_ops[0][0], te_ops.Quantize)
assert isinstance(backward_ops[1][0], te_ops.Bias)
assert isinstance(backward_ops[2][0], te_ops.Quantize)
assert isinstance(backward_ops[2][0], act_type)
# Expected numerical error
tols = dtype_tols(dtype)
......@@ -2930,3 +2930,317 @@ class TestSequentialModules:
if bias:
torch.testing.assert_close(to_cpu(ffn1.bias.grad), b1_ref.grad, **tols)
torch.testing.assert_close(to_cpu(ffn2.bias.grad), b2_ref.grad, **tols)
class TestCustomOps:
"""Test with ops that are defined externally"""
def test_custom_basic_op(
self,
*,
shape: Iterable[int] = (7, 5),
dtype: torch.dtype = torch.float32,
device: torch.device = "cuda",
) -> None:
"""Custom basic op"""
class CustomScaleOp(te.ops.BasicOperation):
"""Custom op that applies a learnable scale"""
def __init__(self) -> None:
super().__init__()
self.scale: torch.nn.Parameter
scale = torch.ones((), dtype=dtype, device=device)
scale = torch.nn.Parameter(scale)
self.register_parameter("scale", scale)
def op_forward(
self,
ctx: OperationContext,
input_: torch.Tensor,
prev_op_grad_output_quantizer: Optional[Quantizer],
next_op_input_quantizer: Optional[Quantizer],
) -> torch.Tensor:
ctx.save_for_backward(self.scale, input_)
return self.scale * input_
def op_backward(
self,
ctx: OperationContext,
grad_output: torch.Tensor,
) -> torch.Tensor:
(
scale,
input_,
) = ctx.saved_tensors
grad_scale = torch.inner(input_.reshape(-1), grad_output.reshape(-1))
grad_scale = grad_scale.reshape(())
grad_input = scale * grad_output
return grad_input, (grad_scale,)
# Random data
x_ref, x_test = make_reference_and_test_tensors(
shape,
test_dtype=dtype,
test_device=device,
)
w_ref, w_test = make_reference_and_test_tensors(
(),
test_dtype=dtype,
test_device=device,
)
dy_ref, dy_test = make_reference_and_test_tensors(
shape,
test_dtype=dtype,
test_device=device,
requires_grad=False,
)
# Plain PyTorch implementation
y_ref = w_ref * x_ref
y_ref.backward(dy_ref)
# Implementation with fusible operation
op = CustomScaleOp()
forward = te.ops.Sequential(te.ops.Identity(), op, te.ops.Identity())
with torch.no_grad():
op.scale.copy_(w_test)
del w_test
y_test = forward(x_test)
y_test.backward(dy_test)
# Check results
tols = dtype_tols(dtype)
y_test = y_test.to(dtype=torch.float64, device="cpu")
dx_test = x_test.grad.to(dtype=torch.float64, device="cpu")
dw_test = op.scale.grad.to(dtype=torch.float64, device="cpu")
torch.testing.assert_close(y_test, y_ref, **tols)
torch.testing.assert_close(dx_test, x_ref.grad, **tols)
torch.testing.assert_close(dw_test, w_ref.grad, **tols)
def test_custom_forward_fused_op(
self,
*,
shape: Iterable[int] = (7, 11),
dtype: torch.dtype = torch.float32,
device: torch.device = "cuda",
):
"""Custom fused op in forward pass"""
class CustomForwardLinearSiLU(te.ops.FusedOperation):
"""Custom fused op for GEMM + SiLU"""
_enabled = True
def __init__(self, *, linear, silu) -> None:
super().__init__((linear, silu))
def fuser_forward(
self,
basic_op_ctxs: list[OperationContext],
input_: torch.Tensor,
**unused,
) -> torch.Tensor:
weight = self.basic_ops[0].weight
dtype = weight.dtype
device = weight.device
# Perform compute on CPU, because why not?
x = input_.cpu()
w = weight.cpu()
y = torch.matmul(x, w.T)
z = torch.nn.functional.silu(y)
out = z.to(device=device)
# Save state for linear backward
linear_op_ctx = basic_op_ctxs[0]
linear_op_ctx.save_for_backward(input_, weight)
linear_op_ctx.with_quantized_compute = False
linear_op_ctx.input_quantizer = None
linear_op_ctx.weight_quantizer = None
linear_op_ctx.grad_output_quantizer = None
linear_op_ctx.grad_input_quantizer = None
linear_op_ctx.dtype = dtype
linear_op_ctx.input_requires_grad = True
linear_op_ctx.weight_requires_grad = True
# Save state for SiLU backward
silu_op_ctx = basic_op_ctxs[1]
silu_op_ctx.save_for_backward(y.to(device=device))
silu_op_ctx.dtype = dtype
silu_op_ctx.prev_op_grad_output_quantizer = None
return out, [(), ()]
@staticmethod
def fuse_ops(
ops: list[FusibleOperation],
**unused,
) -> list[FusibleOperation]:
"""Apply fusion the first time this function is called"""
if CustomForwardLinearSiLU._enabled:
CustomForwardLinearSiLU._enabled = False
op = CustomForwardLinearSiLU(linear=ops[0], silu=ops[1])
return [op] + ops[2:]
return ops
# Random data
x_ref, x_test = make_reference_and_test_tensors(
shape,
test_dtype=dtype,
test_device=device,
)
w_ref, w_test = make_reference_and_test_tensors(
(shape[-1], shape[-1]),
test_dtype=dtype,
test_device=device,
)
dy_ref, dy_test = make_reference_and_test_tensors(
shape,
test_dtype=dtype,
test_device=device,
requires_grad=False,
)
# Plain PyTorch implementation
y_ref = torch.nn.functional.linear(x_ref, w_ref)
y_ref = torch.nn.functional.silu(y_ref)
y_ref.backward(dy_ref)
# Implementation with fusible operation
te.ops.register_forward_fusion(CustomForwardLinearSiLU.fuse_ops)
model = te.ops.Sequential(
te.ops.Linear(shape[-1], shape[-1], bias=False),
te.ops.SiLU(),
)
with torch.no_grad():
model[0].weight.copy_(w_test)
del w_test
y_test = model(x_test)
y_test.backward(dy_test)
# Check that forward operations have been fused
forward_ops = model._module_groups[0]._forward_ops
assert len(forward_ops) == 1
assert isinstance(forward_ops[0][0], CustomForwardLinearSiLU)
# Expected numerical error
tols = dtype_tols(dtype)
if dtype == torch.float32:
tols = dtype_tols(torch.float16) # TF32 GEMM
# Check results
y_test = y_test.to(dtype=torch.float64, device="cpu")
dx_test = x_test.grad.to(dtype=torch.float64, device="cpu")
dw_test = model[0].weight.grad.to(dtype=torch.float64, device="cpu")
torch.testing.assert_close(y_test, y_ref, **tols)
torch.testing.assert_close(dx_test, x_ref.grad, **tols)
torch.testing.assert_close(dw_test, w_ref.grad, **tols)
def test_custom_backward_fused_op(
self,
*,
shape: Iterable[int] = (13, 5),
dtype: torch.dtype = torch.float32,
device: torch.device = "cuda",
):
"""Custom fused op in backward pass"""
class CustomBackwardLinearScale(te.ops.FusedOperation):
"""Custom fused op for backward linear + scale"""
_enabled: bool = True
def __init__(self, *, scale, linear) -> None:
super().__init__((scale, linear))
def fuser_backward(
self,
basic_op_ctxs: list[OperationContext],
grad_output: torch.Tensor,
**unused,
) -> torch.Tensor:
# Load state from linear forward
linear_op_ctx = basic_op_ctxs[1]
x, w = linear_op_ctx.saved_tensors
dtype = linear_op_ctx.dtype
device = w.device
# Perform compute in FP64 and apply scale before dgrad
# GEMM instead of after
scale = self.basic_ops[0].scale
dy = grad_output.double()
x = x.double()
w = w.double()
dx = torch.matmul(dy, scale * w)
dw = torch.matmul(dy.T, x)
dx = dx.to(dtype=dtype)
dw = dw.to(dtype=dtype)
return dx, [(), (dw,)], [(), ()]
@staticmethod
def fuse_ops(
ops: list[FusibleOperation],
**unused,
) -> list[FusibleOperation]:
"""Apply fusion the first time this function is called"""
if CustomBackwardLinearScale._enabled:
CustomBackwardLinearScale._enabled = False
op = CustomBackwardLinearScale(scale=ops[0], linear=ops[1])
return [op] + ops[2:]
return ops
# Random data
x_ref, x_test = make_reference_and_test_tensors(
shape,
test_dtype=dtype,
test_device=device,
)
w_ref, w_test = make_reference_and_test_tensors(
(shape[-1], shape[-1]),
test_dtype=dtype,
test_device=device,
)
dy_ref, dy_test = make_reference_and_test_tensors(
shape,
test_dtype=dtype,
test_device=device,
requires_grad=False,
)
scale = 1.234
# Plain PyTorch implementation
y_ref = torch.nn.functional.linear(scale * x_ref, w_ref)
y_ref.backward(dy_ref)
# Implementation with fusible operation
te.ops.register_backward_fusion(CustomBackwardLinearScale.fuse_ops, prepend=True)
model = te.ops.Sequential(
te.ops.ConstantScale(scale),
te.ops.Linear(shape[-1], shape[-1], bias=False),
)
with torch.no_grad():
model[1].weight.copy_(w_test)
del w_test
y_test = model(x_test)
y_test.backward(dy_test)
# Check that forward operations have been fused
backward_ops = model._module_groups[0]._backward_ops
assert len(backward_ops) == 1
assert isinstance(backward_ops[0][0], CustomBackwardLinearScale)
# Expected numerical error
tols = dtype_tols(dtype)
if dtype == torch.float32:
tols = dtype_tols(torch.float16) # TF32 GEMM
# Check results
y_test = y_test.to(dtype=torch.float64, device="cpu")
dx_test = x_test.grad.to(dtype=torch.float64, device="cpu")
dw_test = model[1].weight.grad.to(dtype=torch.float64, device="cpu")
torch.testing.assert_close(y_test, y_ref, **tols)
torch.testing.assert_close(dx_test, x_ref.grad, **tols)
torch.testing.assert_close(dw_test, w_ref.grad, **tols)
......@@ -8,7 +8,9 @@ This operation-based API is experimental and subject to change.
"""
from transformer_engine.pytorch.ops.basic import *
from transformer_engine.pytorch.ops.linear import Linear
from transformer_engine.pytorch.ops.op import FusibleOperation
from transformer_engine.pytorch.ops.sequential import Sequential
from .basic import *
from .fuser import register_backward_fusion, register_forward_fusion
from .linear import Linear
from .op import BasicOperation, FusedOperation, FusibleOperation
from .sequential import Sequential
from . import fused
......@@ -4,39 +4,27 @@
"""Compound tensor operation supported by the operation fuser."""
from .backward_activation_bias import (
BackwardActivationBias,
fuse_backward_activation_bias,
)
from .backward_add_rmsnorm import (
BackwardAddRMSNorm,
fuse_backward_add_rmsnorm,
)
from .backward_linear_add import (
BackwardLinearAdd,
fuse_backward_linear_add,
)
from .backward_linear_scale import (
BackwardLinearScale,
fuse_backward_linear_scale,
)
from .forward_linear_bias_activation import (
ForwardLinearBiasActivation,
fuse_forward_linear_bias_activation,
)
from .forward_linear_bias_add import (
ForwardLinearBiasAdd,
fuse_forward_linear_bias_add,
)
from .forward_linear_scale_add import (
ForwardLinearScaleAdd,
fuse_forward_linear_scale_add,
)
from .userbuffers_backward_linear import (
UserbuffersBackwardLinear,
fuse_userbuffers_backward_linear,
)
from .userbuffers_forward_linear import (
UserbuffersForwardLinear,
fuse_userbuffers_forward_linear,
)
from ..fuser import register_backward_fusion, register_forward_fusion
from .backward_activation_bias import BackwardActivationBias
from .backward_add_rmsnorm import BackwardAddRMSNorm
from .backward_linear_add import BackwardLinearAdd
from .backward_linear_scale import BackwardLinearScale
from .forward_linear_bias_activation import ForwardLinearBiasActivation
from .forward_linear_bias_add import ForwardLinearBiasAdd
from .forward_linear_scale_add import ForwardLinearScaleAdd
from .userbuffers_backward_linear import UserbuffersBackwardLinear
from .userbuffers_forward_linear import UserbuffersForwardLinear
# Register forward fusions
register_forward_fusion(UserbuffersForwardLinear.fuse_forward_ops)
register_forward_fusion(ForwardLinearBiasAdd.fuse_forward_ops)
register_forward_fusion(ForwardLinearBiasActivation.fuse_forward_ops)
register_forward_fusion(ForwardLinearScaleAdd.fuse_forward_ops)
# Register backward fusions
register_backward_fusion(UserbuffersBackwardLinear.fuse_backward_ops)
register_backward_fusion(BackwardLinearAdd.fuse_backward_ops)
register_backward_fusion(BackwardLinearScale.fuse_backward_ops)
register_backward_fusion(BackwardActivationBias.fuse_backward_ops)
register_backward_fusion(BackwardAddRMSNorm.fuse_backward_ops)
......@@ -53,8 +53,8 @@ class BackwardActivationBias(FusedOperation):
]:
# Get basic operation contexts
activation_op_ctx = basic_op_ctxs[0]
bias_op_ctx = basic_op_ctxs[1]
bias_op_ctx = basic_op_ctxs[0]
activation_op_ctx = basic_op_ctxs[1]
# Saved tensors from forward pass
(act_input,) = activation_op_ctx.saved_tensors
......@@ -79,68 +79,59 @@ class BackwardActivationBias(FusedOperation):
# Clear activation input tensor
clear_tensor_data(act_input)
return dx, [(), (db,)], [(), ()]
return dx, [(db,), ()], [(), ()]
def fuse_backward_activation_bias(
ops: list[tuple[FusibleOperation, list[int]]],
recipe: Optional[Recipe],
) -> list[tuple[FusibleOperation, list[int]]]:
"""Fused backward dact + dbias + quantize
Parameters
----------
ops : list of tuples
Backward pass operations and the indices of the corresponding
basic operations.
recipe : Recipe, optional
Used quantization recipe
Returns
-------
ops : list of tuples
Updated backward pass operations
"""
# Check if recipe supports bias activation fusion
if recipe is None:
return ops
# Scan through ops, fusing if possible
out = []
window = []
while len(ops) >= 3:
@staticmethod
def fuse_backward_ops(
ops: list[FusibleOperation],
*,
recipe: Optional[Recipe] = None,
**unused, # pylint: disable=unused-argument
) -> list[FusibleOperation]:
"""Apply operation fusion for backward pass.
Parameters
----------
ops : list of FusibleOperation
Backward pass operations.
recipe : Recipe, optional
Quantization recipe.
Returns
-------
ops : list of FusibleOperation
Updated backward pass operations
"""
# Check if recipe supports bias activation fusion
if recipe is None:
return ops
# Scan through ops, fusing if possible
out = []
window, ops = ops[:3], ops[3:]
while len(window) == 3:
if (
isinstance(window[2], _fusible_activations)
and isinstance(window[1], Bias)
and window[0].get_grad_output_quantizer() is not None
):
# Construct fused op if window matches pattern
op = BackwardActivationBias(bias=window[1], activation=window[2])
window = [window[0], op]
else:
# Shift window if window doesn't match pattern
out.extend(window[:-2])
window = window[-2:]
# Adjust window to expected size
out.extend(window[:-3])
window = window[-3:]
while ops and len(window) < 3:
window.append(ops[0])
ops = ops[1:]
# Return list of ops
out.extend(window)
# Check if first op is a supported activation
window, ops = ops[:1], ops[1:]
op, _ = window[0]
if not isinstance(op, _fusible_activations):
continue
# Check if second op is bias
op, _ = ops[0]
if not isinstance(op, Bias):
continue
# Check if third op has a grad input quantizer
op, _ = ops[1]
if not op.num_quantizers("backward") > 0:
continue
window.extend(ops[:1])
ops = ops[1:]
# Replace window with fused op
op = BackwardActivationBias(
activation=window[0][0],
bias=window[1][0],
)
basic_op_idxs = [basic_op_idxs[0] for _, basic_op_idxs in window]
window = [(op, basic_op_idxs)]
# Return list of ops
out.extend(window)
out.extend(ops)
return out
return out
......@@ -42,7 +42,7 @@ class BackwardAddRMSNorm(FusedOperation):
# Get basic operations
rmsnorm_op = self.basic_ops[1]
rmsnorm_op_ctx = basic_op_ctxs[0]
rmsnorm_op_ctx = basic_op_ctxs[1]
# Saved tensors from forward pass
x, rstdevs = rmsnorm_op_ctx.saved_tensors
......@@ -53,7 +53,7 @@ class BackwardAddRMSNorm(FusedOperation):
# Check input tensors
dtype = rmsnorm_op_ctx.dtype
extra_grad = basic_op_grad_extra_outputs[1][0]
extra_grad = basic_op_grad_extra_outputs[0][0]
dy = maybe_dequantize(grad_output.contiguous(), dtype).view(x.size())
w = maybe_dequantize(rmsnorm_op.weight, dtype).view((inner_dim,))
add = maybe_dequantize(extra_grad.contiguous(), dtype).view(x.size())
......@@ -77,57 +77,51 @@ class BackwardAddRMSNorm(FusedOperation):
grad_input = dx.view(grad_output.size())
grad_weight = dw.view(weight_dims)
return grad_input, [(grad_weight,), ()], [(), ()]
def fuse_backward_add_rmsnorm(
ops: list[tuple[FusibleOperation, list[int]]],
) -> list[tuple[FusibleOperation, list[int]]]:
"""Fused backward RMNorm + add
Parameters
----------
ops : list of tuples
Backward pass operations and the indices of the corresponding
basic operations.
Returns
-------
ops : list of tuples
Updated backward pass operations
"""
# Scan through ops, fusing if possible
out = []
window = []
while len(ops) >= 2:
return grad_input, [(), (grad_weight,)], [(), ()]
@staticmethod
def fuse_backward_ops(
ops: list[FusibleOperation],
**unused, # pylint: disable=unused-argument
) -> list[FusibleOperation]:
"""Apply operation fusion for backward pass.
Parameters
----------
ops : list of FusibleOperation
Backward pass operations.
Returns
-------
ops : list of FusibleOperation
Updated backward pass operations
"""
# Scan through ops, fusing if possible
out = []
window, ops = ops[:2], ops[2:]
while len(window) == 2:
if (
isinstance(window[0], MakeExtraOutput)
and isinstance(window[1], RMSNorm)
and not window[0]._in_place
):
# Construct fused op if window matches pattern
op = BackwardAddRMSNorm(add=window[0], rmsnorm=window[1])
window = [op]
else:
# Shift window if window doesn't match pattern
out.extend(window[:-1])
window = window[-1:]
# Adjust window to expected size
out.extend(window[:-2])
window = window[-2:]
while ops and len(window) < 2:
window.append(ops[0])
ops = ops[1:]
# Return list of ops
out.extend(window)
# Check if first op is linear
window, ops = ops[:1], ops[1:]
op, _ = window[0]
if not isinstance(op, RMSNorm):
continue
# Check if second op is "make extra output"
op, _ = ops[0]
if not isinstance(op, MakeExtraOutput):
continue
if op._in_place:
continue
window.extend(ops[:1])
ops = ops[1:]
# Replace window with fused op
op = BackwardAddRMSNorm(
rmsnorm=window[0][0],
add=window[1][0],
)
basic_op_idxs = [basic_op_idxs[0] for _, basic_op_idxs in window]
window = [(op, basic_op_idxs)]
# Return list of ops
out.extend(window)
out.extend(ops)
return out
return out
......@@ -45,7 +45,7 @@ class BackwardLinearAdd(FusedOperation):
# Get basic operations
linear_op = self.basic_ops[1]
linear_op_ctx = basic_op_ctxs[0]
linear_op_ctx = basic_op_ctxs[1]
# Saved tensors from forward pass
(x_local, w) = linear_op_ctx.saved_tensors
......@@ -71,7 +71,7 @@ class BackwardLinearAdd(FusedOperation):
accumulate_into_main_grad = False
# Linear backward pass
grad_input = basic_op_grad_extra_outputs[1][0]
grad_input = basic_op_grad_extra_outputs[0][0]
grad_input, grad_weight = BasicLinear._functional_backward(
grad_output=grad_output,
input=x_local,
......@@ -109,61 +109,60 @@ class BackwardLinearAdd(FusedOperation):
zero=getattr(weight_param, "zero_out_wgrad", False),
)
return grad_input, [(grad_weight,), ()], [(), ()]
def fuse_backward_linear_add(
ops: list[tuple[FusibleOperation, list[int]]],
) -> list[tuple[FusibleOperation, list[int]]]:
"""Fused backward dgrad GEMM + add
Parameters
----------
ops : list of tuples
Backward pass operations and the indices of the corresponding
basic operations.
Returns
-------
ops : list of tuples
Updated backward pass operations
"""
# Scan through ops, fusing if possible
out = []
window = []
while len(ops) >= 2:
return grad_input, [(), (grad_weight,)], [(), ()]
@staticmethod
def fuse_backward_ops(
ops: list[FusibleOperation],
**unused, # pylint: disable=unused-argument
) -> list[FusibleOperation]:
"""Apply operation fusion for backward pass.
Parameters
----------
ops : list of FusibleOperation
Backward pass operations.
Returns
-------
ops : list of FusibleOperation
Updated backward pass operations
"""
# Scan through ops, fusing if possible
out = []
window, ops = ops[:2], ops[2:]
while len(window) == 2:
# Check if window matches pattern
matches_pattern = True
if not (isinstance(window[0], MakeExtraOutput) and isinstance(window[1], BasicLinear)):
matches_pattern = False
elif not window[0]._in_place:
# Fused op accumulates grad input in-place
matches_pattern = False
elif window[1].tensor_parallel_mode == "column":
# Column tensor-parallelism requires communication
# after the dgrad GEMM
matches_pattern = False
if matches_pattern:
# Construct fused op if window matches pattern
op = BackwardLinearAdd(backward_add=window[0], linear=window[1])
window = [op]
else:
# Shift window if window doesn't match pattern
out.extend(window[:-1])
window = window[-1:]
# Adjust window to expected size
out.extend(window[:-2])
window = window[-2:]
while ops and len(window) < 2:
window.append(ops[0])
ops = ops[1:]
# Return list of ops
out.extend(window)
# Check if first op is linear
window, ops = ops[:1], ops[1:]
op, _ = window[0]
if not isinstance(op, BasicLinear):
continue
if op.tensor_parallel_mode == "column":
# Row tensor-parallelism requires communication after the
# GEMM
continue
# Check if second op is "make extra output"
op, _ = ops[0]
if not isinstance(op, MakeExtraOutput):
continue
if not op._in_place:
continue
window.extend(ops[:1])
ops = ops[1:]
# Replace window with fused op
op = BackwardLinearAdd(
linear=window[0][0],
backward_add=window[1][0],
)
basic_op_idxs = [basic_op_idxs[0] for _, basic_op_idxs in window]
window = [(op, basic_op_idxs)]
# Return list of ops
out.extend(window)
out.extend(ops)
return out
return out
......@@ -45,7 +45,7 @@ class BackwardLinearScale(FusedOperation):
# Get basic operations
linear_op = self.basic_ops[0]
linear_op_ctx = basic_op_ctxs[1]
linear_op_ctx = basic_op_ctxs[0]
scale_op = self.basic_ops[1]
# Saved tensors from forward pass
......@@ -109,58 +109,57 @@ class BackwardLinearScale(FusedOperation):
zero=getattr(weight_param, "zero_out_wgrad", False),
)
return grad_input, [(), (grad_weight,)], [(), ()]
def fuse_backward_linear_scale(
ops: list[tuple[FusibleOperation, list[int]]],
) -> list[tuple[FusibleOperation, list[int]]]:
"""Fused backward dgrad GEMM + constant scale
Parameters
----------
ops : list of tuples
Backward pass operations and the indices of the corresponding
basic operations.
Returns
-------
ops : list of tuples
Updated backward pass operations
"""
# Scan through ops, fusing if possible
out = []
window = []
while len(ops) >= 2:
return grad_input, [(grad_weight,), ()], [(), ()]
@staticmethod
def fuse_backward_ops(
ops: list[FusibleOperation],
**unused, # pylint: disable=unused-argument
) -> list[FusibleOperation]:
"""Apply operation fusion for backward pass.
Parameters
----------
ops : list of FusibleOperation
Backward pass operations.
Returns
-------
ops : list of FusibleOperation
Updated backward pass operations
"""
# Scan through ops, fusing if possible
out = []
window, ops = ops[:2], ops[2:]
while len(window) == 2:
# Check if window matches pattern
matches_pattern = True
if not (isinstance(window[0], BasicLinear) and isinstance(window[1], ConstantScale)):
matches_pattern = False
elif window[0].tensor_parallel_mode == "column":
# Column tensor-parallelism requires communication
# after the dgrad GEMM
matches_pattern = False
if matches_pattern:
# Construct fused op if window matches pattern
op = BackwardLinearScale(linear=window[0], scale=window[1])
window = [op]
else:
# Shift window if window doesn't match pattern
out.extend(window[:-1])
window = window[-1:]
# Adjust window to expected size
out.extend(window[:-2])
window = window[-2:]
while ops and len(window) < 2:
window.append(ops[0])
ops = ops[1:]
# Return list of ops
out.extend(window)
# Check if first op is constant scale
window, ops = ops[:1], ops[1:]
op, _ = window[0]
if not isinstance(op, ConstantScale):
continue
# Check if second op is linear
op, _ = ops[0]
if not isinstance(op, BasicLinear):
continue
if op.tensor_parallel_mode == "column":
# Column tensor-parallelism requires communication after the dgrad GEMM
continue
window.extend(ops[:1])
ops = ops[1:]
# Replace window with fused op
op = BackwardLinearScale(
scale=window[0][0],
linear=window[1][0],
)
basic_op_idxs = [basic_op_idxs[0] for _, basic_op_idxs in window]
window = [(op, basic_op_idxs)]
# Return list of ops
out.extend(window)
out.extend(ops)
return out
return out
......@@ -134,62 +134,63 @@ class ForwardLinearBiasActivation(FusedOperation):
return output, [() for _ in range(len(self.basic_ops))]
def fuse_forward_linear_bias_activation(
ops: list[tuple[FusibleOperation, list[int]]],
) -> list[tuple[FusibleOperation, list[int]]]:
"""Fuse forward GEMM + bias + activation
Parameters
----------
ops : list of tuples
Forward pass operations and the indices of the corresponding
basic operations.
Returns
-------
ops : list of tuples
Updated forward pass operations
"""
# Scan through ops, fusing if possible
out = []
window = []
while len(ops) >= 2:
@staticmethod
def fuse_forward_ops(
ops: list[FusibleOperation],
**unused, # pylint: disable=unused-argument
) -> list[FusibleOperation]:
"""Apply operation fusion for forward pass.
Parameters
----------
ops : list of FusibleOperation
Forward pass operations.
Returns
-------
ops : list of FusibleOperation
Updated forward pass operations
"""
# Scan through ops, fusing if possible
out = []
window, ops = ops[:2], ops[2:]
while len(window) == 2:
# Check if window matches pattern
matches_pattern = True
if not (isinstance(window[0], BasicLinear) and isinstance(window[1], Bias)):
matches_pattern = False
elif window[0].tensor_parallel_mode == "row":
# Row tensor-parallelism requires communication after
# the GEMM
matches_pattern = False
elif window[0].weight.dtype not in (torch.float16, torch.bfloat16):
# cuBLAS only supports fused GEMM+bias+activation with
# FP16 and BF16 output
matches_pattern = False
if matches_pattern:
# Construct fused op if window matches pattern
op = ForwardLinearBiasActivation(
linear=window[0],
bias=window[1],
activation=None,
)
window = [op]
else:
# Shift window if window doesn't match pattern
out.extend(window[:-1])
window = window[-1:]
# Adjust window to expected size
out.extend(window[:-2])
window = window[-2:]
while ops and len(window) < 2:
window.append(ops[0])
ops = ops[1:]
# Return list of ops
out.extend(window)
# Check if first op is linear
window, ops = ops[:1], ops[1:]
op1, _ = window[0]
if not isinstance(op1, BasicLinear):
continue
if op1.tensor_parallel_mode == "row":
# Row tensor-parallelism requires communication after the
# GEMM
continue
if op1.weight.dtype not in (torch.float16, torch.bfloat16):
# cuBLAS only supports fused GEMM+bias+activation with
# FP16 and BF16 output
continue
# Check if second op is bias
op2, _ = ops[0]
if not isinstance(op2, Bias):
continue
window.extend(ops[:1])
ops = ops[1:]
# Replace window with fused op
op = ForwardLinearBiasActivation(
linear=window[0][0],
bias=window[1][0],
activation=None,
)
basic_op_idxs = [basic_op_idxs[0] for _, basic_op_idxs in window]
window = [(op, basic_op_idxs)]
# Return list of ops
out.extend(window)
out.extend(ops)
return out
return out
......@@ -131,72 +131,63 @@ class ForwardLinearBiasAdd(FusedOperation):
return output, [() for _ in range(len(self.basic_ops))]
@staticmethod
def fuse_forward_ops(
ops: list[FusibleOperation],
**unused, # pylint: disable=unused-argument
) -> list[FusibleOperation]:
"""Apply operation fusion for forward pass.
Parameters
----------
ops : list of FusibleOperation
Forward pass operations.
Returns
-------
ops : list of FusibleOperation
Updated forward pass operations
"""
# Scan through ops, fusing if possible
out = []
window = []
while ops:
# Shift window
out.extend(window)
window = [ops[0]]
ops = ops[1:]
def fuse_forward_linear_bias_add(
ops: list[tuple[FusibleOperation, list[int]]],
) -> list[tuple[FusibleOperation, list[int]]]:
"""Fuse forward GEMM + bias + add
Parameters
----------
ops : list of tuples
Forward pass operations and the indices of the corresponding
basic operations.
# Check if first op is linear
if not isinstance(window[0], BasicLinear):
continue
if window[0].tensor_parallel_mode == "row":
# Row tensor-parallelism requires communication after
# the GEMM
continue
linear = window[0]
Returns
-------
ops : list of tuples
Updated forward pass operations
# Check if next op is bias
bias = None
if ops and isinstance(ops[0], Bias):
window.append(ops[0])
ops = ops[1:]
bias = window[-1]
# Check if next op is in-place add extra input
if ops and isinstance(ops[0], AddExtraInput) and ops[0]._in_place:
window.append(ops[0])
ops = ops[1:]
add = window[-1]
else:
continue
"""
# Replace window with fused op
op = ForwardLinearBiasAdd(linear=linear, bias=bias, add=add)
window = [op]
# Scan through ops, fusing if possible
out = []
window = []
while len(ops) >= 2:
# Return list of ops
out.extend(window)
# Check if first op is linear
window, ops = ops[:1], ops[1:]
op, _ = window[0]
if not isinstance(op, BasicLinear):
continue
if op.tensor_parallel_mode == "row":
# Row tensor-parallelism requires communication after the
# GEMM
continue
linear = op
op, _ = ops[0]
# Check if next op is bias
bias = None
if isinstance(op, Bias):
bias = op
window.extend(ops[:1])
ops = ops[1:]
if len(ops) == 0:
continue
op, _ = ops[0]
# Check if next op is in-place add extra input
if not isinstance(op, AddExtraInput):
continue
if not op._in_place:
continue
add = op
window.extend(ops[:1])
ops = ops[1:]
# Replace window with fused op
op = ForwardLinearBiasAdd(
linear=linear,
bias=bias,
add=add,
)
basic_op_idxs = [basic_op_idxs[0] for _, basic_op_idxs in window]
window = [(op, basic_op_idxs)]
# Return list of ops
out.extend(window)
out.extend(ops)
return out
return out
......@@ -110,70 +110,66 @@ class ForwardLinearScaleAdd(FusedOperation):
return output, [() for _ in range(len(self.basic_ops))]
def fuse_forward_linear_scale_add(
ops: list[tuple[FusibleOperation, list[int]]],
) -> list[tuple[FusibleOperation, list[int]]]:
"""Fuse forward GEMM + scale + add
Parameters
----------
ops : list of tuples
Forward pass operations and the indices of the corresponding
basic operations.
Returns
-------
ops : list of tuples
Updated forward pass operations
"""
# Scan through ops, fusing if possible
out = []
window = []
while len(ops) >= 3:
@staticmethod
def fuse_forward_ops(
ops: list[FusibleOperation],
**unused, # pylint: disable=unused-argument
) -> list[FusibleOperation]:
"""Apply operation fusion for forward pass.
Parameters
----------
ops : list of FusibleOperation
Forward pass operations.
Returns
-------
ops : list of FusibleOperation
Updated forward pass operations
"""
# Scan through ops, fusing if possible
out = []
window, ops = ops[:3], ops[3:]
while len(window) == 3:
# Check if window matches pattern
matches_pattern = True
if not (
isinstance(window[0], BasicLinear)
and isinstance(window[1], ConstantScale)
and isinstance(window[2], AddExtraInput)
):
matches_pattern = False
elif window[0].tensor_parallel_mode == "row":
# Row tensor-parallelism requires communication after
# the GEMM
matches_pattern = False
elif not window[2]._in_place:
# Fused op accumulates output in-place
matches_pattern = False
if matches_pattern:
# Construct fused op if window matches pattern
op = ForwardLinearScaleAdd(
linear=window[0],
scale=window[1],
add=window[2],
)
window = [op]
else:
# Shift window if window doesn't match pattern
out.extend(window[:-2])
window = window[-2:]
# Adjust window to expected size
out.extend(window[:-3])
window = window[-3:]
while ops and len(window) < 3:
window.append(ops[0])
ops = ops[1:]
# Return list of ops
out.extend(window)
# Check if first op is linear
window, ops = ops[:1], ops[1:]
op, _ = window[0]
if not isinstance(op, BasicLinear):
continue
if op.tensor_parallel_mode == "row":
# Row tensor-parallelism requires communication after the
# GEMM
continue
linear = op
op, _ = ops[0]
# Check if next op is constant scale
if not isinstance(op, ConstantScale):
continue
scale = op
window.extend(ops[:1])
ops = ops[1:]
op, _ = ops[0]
# Check if next op is in-place add extra input
if not isinstance(op, AddExtraInput):
continue
if not op._in_place:
continue
add = op
window.extend(ops[:1])
ops = ops[1:]
# Replace window with fused op
op = ForwardLinearScaleAdd(
linear=linear,
scale=scale,
add=add,
)
basic_op_idxs = [basic_op_idxs[0] for _, basic_op_idxs in window]
window = [(op, basic_op_idxs)]
# Return list of ops
out.extend(window)
out.extend(ops)
return out
return out
......@@ -503,7 +503,7 @@ class UserbuffersBackwardLinear(FusedOperation):
# Get basic operations
idx = self._op_idxs["linear"]
linear_op = self.basic_ops[idx]
linear_op_ctx = basic_op_ctxs[-1]
linear_op_ctx = basic_op_ctxs[0]
bias_op = None
if self._op_idxs["bias"] is not None:
idx = self._op_idxs["bias"]
......@@ -578,99 +578,84 @@ class UserbuffersBackwardLinear(FusedOperation):
grad_params[self._op_idxs["linear"]] = (grad_weight,)
if bias_op is not None:
grad_params[self._op_idxs["bias"]] = (grad_bias,)
grad_params.reverse()
grad_extra_inputs = [() for _ in range(len(self.basic_ops))]
return grad_input, grad_params, grad_extra_inputs
@staticmethod
def fuse_backward_ops(
ops: list[FusibleOperation],
**unused, # pylint: disable=unused-argument
) -> list[FusibleOperation]:
"""Apply operation fusion for backward pass.
def fuse_userbuffers_backward_linear(
ops: list[tuple[FusibleOperation, list[int]]],
) -> list[tuple[FusibleOperation, list[int]]]:
"""Substitute linear operations with Userbuffers implementation
Parameters
----------
ops : list of FusibleOperation
Backward pass operations.
recipe : Recipe, optional
Quantization recipe.
Parameters
----------
ops : list of tuples
Backward pass operations and the indices of the corresponding
basic operations.
Returns
-------
ops : list of FusibleOperation
Updated backward pass operations
Returns
-------
ops : list of tuples
Updated backward pass operations
"""
"""
# Return immediately if environment is not distributed
if not torch.distributed.is_initialized() or torch.distributed.get_world_size() == 1:
return ops
# Return immediately if environment is not distributed
if not torch.distributed.is_initialized() or torch.distributed.get_world_size() == 1:
return ops
# Sliding window in list of ops
window = []
def peek_next_op() -> Optional[FusibleOperation]:
"""Get next op in list of ops"""
nonlocal ops
if not ops:
return None
return ops[-1][0]
def pop_next_op() -> FusibleOperation:
"""Remove next op from list of ops and add to sliding window"""
nonlocal ops, window
window.insert(0, ops[-1])
ops = ops[:-1]
return window[0][0]
# Scan through ops in reverse order, fusing if possible
out_reversed = []
while ops:
out_reversed.extend(reversed(window))
window.clear()
# Check if next op is linear
next_op = pop_next_op()
if not isinstance(next_op, BasicLinear):
continue
linear = next_op
if linear._userbuffers_options is None:
continue
# Check if next op is bias
bias = None
if linear.tensor_parallel_mode != "row" and isinstance(peek_next_op(), Bias):
bias = pop_next_op()
# Check if next op is reduce-scatter
reduce_scatter = None
if linear.tensor_parallel_mode is None and isinstance(peek_next_op(), ReduceScatter):
reduce_scatter = pop_next_op()
# Check for invalid combinations
if reduce_scatter is None:
if linear.tensor_parallel_mode is None:
continue
if linear.tensor_parallel_size == 1:
continue
if linear.tensor_parallel_mode == "row" and bias is not None:
continue
else:
if linear.tensor_parallel_mode is not None:
# Scan through ops, fusing if possible
out = []
window = []
while ops:
# Shift window
out.extend(window)
window, ops = ops[:1], ops[1:]
# Check if first op is linear
if not isinstance(window[0], BasicLinear):
continue
if reduce_scatter.process_group_size == 1:
linear = window[0]
if linear._userbuffers_options is None:
continue
# Replace window with fused op
op = UserbuffersBackwardLinear(
linear=linear,
bias=bias,
reduce_scatter=reduce_scatter,
)
basic_op_idxs = [basic_op_idxs[0] for _, basic_op_idxs in window]
window = [(op, basic_op_idxs)]
# Return list of ops
out_reversed.extend(reversed(window))
out = out_reversed
out.reverse()
return out
# Check if next op is bias
bias = None
if linear.tensor_parallel_mode != "row" and ops and isinstance(ops[0], Bias):
bias, ops = ops[0], ops[1:]
window.append(bias)
# Check if next op is reduce-scatter
reduce_scatter = None
if linear.tensor_parallel_mode is None and ops and isinstance(ops[0], ReduceScatter):
reduce_scatter, ops = ops[0], ops[1:]
window.append(reduce_scatter)
# Check for invalid combinations
if reduce_scatter is None:
if linear.tensor_parallel_mode is None:
continue
if linear.tensor_parallel_size == 1:
continue
if linear.tensor_parallel_mode == "row" and bias is not None:
continue
else:
if linear.tensor_parallel_mode is not None:
continue
if reduce_scatter.process_group_size == 1:
continue
# Replace window with fused op
op = UserbuffersBackwardLinear(
linear=linear,
bias=bias,
reduce_scatter=reduce_scatter,
)
window = [op]
# Return list of ops
out.extend(window)
return out
......@@ -369,93 +369,79 @@ class UserbuffersForwardLinear(FusedOperation):
return output, [() for _ in range(len(self.basic_ops))]
@staticmethod
def fuse_forward_ops(
ops: list[FusibleOperation],
**unused, # pylint: disable=unused-argument
) -> list[FusibleOperation]:
"""Apply operation fusion for forward pass.
def fuse_userbuffers_forward_linear(
ops: list[tuple[FusibleOperation, list[int]]],
) -> list[tuple[FusibleOperation, list[int]]]:
"""Substitute linear operations with Userbuffers implementation
Parameters
----------
ops : list of tuples
Forward pass operations and the indices of the corresponding
basic operations.
Returns
-------
ops : list of tuples
Updated forward pass operations
Parameters
----------
ops : list of FusibleOperation
Forward pass operations.
"""
Returns
-------
ops : list of FusibleOperation
Updated forward pass operations
# Return immediately if environment is not distributed
if not torch.distributed.is_initialized() or torch.distributed.get_world_size() == 1:
return ops
# Sliding window in list of ops
window = []
def peek_next_op() -> Optional[FusibleOperation]:
"""Get next op in list of ops"""
nonlocal ops
if not ops:
return None
return ops[0][0]
def pop_next_op() -> FusibleOperation:
"""Remove next op from list of ops and add to sliding window"""
nonlocal ops, window
window.append(ops[0])
ops = ops[1:]
return window[-1][0]
# Scan through ops, fusing if possible
out = []
while ops:
out.extend(window)
window.clear()
"""
# Check if next op is linear
next_op = pop_next_op()
if not isinstance(next_op, BasicLinear):
continue
linear = next_op
if linear._userbuffers_options is None:
continue
# Return immediately if environment is not distributed
if not torch.distributed.is_initialized() or torch.distributed.get_world_size() == 1:
return ops
# Check if next op is bias
bias = None
if linear.tensor_parallel_mode != "row" and isinstance(peek_next_op(), Bias):
bias = pop_next_op()
# Scan through ops, fusing if possible
out = []
window = []
while ops:
# Check if next op is reduce-scatter
reduce_scatter = None
if linear.tensor_parallel_mode is None and isinstance(peek_next_op(), ReduceScatter):
reduce_scatter = pop_next_op()
# Shift window
out.extend(window)
window, ops = ops[:1], ops[1:]
# Check for invalid combinations
if reduce_scatter is None:
if linear.tensor_parallel_mode is None:
continue
if linear.tensor_parallel_size == 1:
continue
if linear.tensor_parallel_mode == "row" and bias is not None:
continue
else:
if linear.tensor_parallel_mode is not None:
# Check if first op is linear
if not isinstance(window[0], BasicLinear):
continue
if reduce_scatter.process_group_size == 1:
linear = window[0]
if linear._userbuffers_options is None:
continue
# Replace window with fused op
op = UserbuffersForwardLinear(
linear=linear,
bias=bias,
reduce_scatter=reduce_scatter,
)
basic_op_idxs = [basic_op_idxs[0] for _, basic_op_idxs in window]
window = [(op, basic_op_idxs)]
# Check if next op is bias
bias = None
if linear.tensor_parallel_mode != "row" and ops and isinstance(ops[0], Bias):
bias, ops = ops[0], ops[1:]
window.append(bias)
# Check if next op is reduce-scatter
reduce_scatter = None
if linear.tensor_parallel_mode is None and ops and isinstance(ops[0], ReduceScatter):
reduce_scatter, ops = ops[0], ops[1:]
window.append(reduce_scatter)
# Check for invalid combinations
if reduce_scatter is None:
if linear.tensor_parallel_mode is None:
continue
if linear.tensor_parallel_size == 1:
continue
if linear.tensor_parallel_mode == "row" and bias is not None:
continue
else:
if linear.tensor_parallel_mode is not None:
continue
if reduce_scatter.process_group_size == 1:
continue
# Replace window with fused op
op = UserbuffersForwardLinear(
linear=linear,
bias=bias,
reduce_scatter=reduce_scatter,
)
window = [op]
# Return list of ops
out.extend(window)
return out
# Return list of ops
out.extend(window)
return out
......@@ -5,33 +5,20 @@
"""Manager class for a pipeline of fusible operations."""
from __future__ import annotations
from collections.abc import Callable, Iterable
from typing import Any, Optional
from collections.abc import Callable, Iterable, Sequence
import itertools
from typing import Any, Optional, TypeAlias
import torch
from transformer_engine.pytorch.quantization import FP8GlobalStateManager, Recipe, DelayedScaling
from transformer_engine.pytorch.ops.op import (
from ..quantization import FP8GlobalStateManager, Recipe, DelayedScaling
from ..quantized_tensor import prepare_for_saving, restore_from_saved
from .op import (
BasicOperation,
FusibleOperation,
FusedOperation,
OperationContext,
)
from transformer_engine.pytorch.ops.fused import (
fuse_backward_activation_bias,
fuse_backward_add_rmsnorm,
fuse_backward_linear_add,
fuse_backward_linear_scale,
fuse_forward_linear_bias_activation,
fuse_forward_linear_bias_add,
fuse_forward_linear_scale_add,
fuse_userbuffers_backward_linear,
fuse_userbuffers_forward_linear,
)
from transformer_engine.pytorch.quantized_tensor import (
prepare_for_saving,
restore_from_saved,
)
def _split_tuple(t: tuple, idx: int) -> tuple[tuple, tuple]:
......@@ -57,6 +44,12 @@ def _is_graph_capturing() -> bool:
return _is_graph_capturing_function()
# Type alias for a function that may perform operation fusion
OperationFusionFunction: TypeAlias = (
"Callable[tuple[list[FusibleOperation], ...], list[FusibleOperation]]"
)
class _OperationFuserAutogradFunction(torch.autograd.Function):
"""Autograd function for a pipeline of operations
......@@ -241,7 +234,7 @@ class _OperationFuserAutogradFunction(torch.autograd.Function):
dx = grad_output
grad_params = [None for _ in range(len(basic_ops))]
grad_extra_inputs = [None for _ in range(len(basic_ops))]
for op, basic_op_idxs in backward_ops:
for op, basic_op_idxs in reversed(backward_ops):
# Stop if no more gradients are required
if all(not basic_op_ctxs[idx].requires_grad for idx in basic_op_idxs):
......@@ -315,6 +308,10 @@ class OperationFuser:
"""
# Functions to perform operation fusion
forward_fusion_functions: list[OperationFusionFunction] = []
backward_fusion_functions: list[OperationFusionFunction] = []
def __init__(
self,
ops: list[FusibleOperation],
......@@ -334,7 +331,7 @@ class OperationFuser:
self._basic_op_num_extra_inputs: list[int] = list(op.num_extra_inputs for op in basic_ops)
self.num_extra_inputs: int = sum(self._basic_op_num_extra_inputs)
# Ops for forward and backward pass, will be populated in fuse_ops
# Ops for forward and backward pass, will be populated in maybe_fuse_ops
self._forward_ops: list[tuple[FusibleOperation, list[int]]]
self._backward_ops: list[tuple[FusibleOperation, list[int]]]
......@@ -349,31 +346,48 @@ class OperationFuser:
self._flat_basic_op_params = sum(self._basic_op_params, [])
@classmethod
def _fuse_forward_ops(
cls,
ops: list[tuple[FusibleOperation, list[int]]],
recipe: Optional[Recipe], # pylint: disable=unused-argument
) -> list[tuple[FusibleOperation, list[int]]]:
"""Attempt to fuse operations in forward pass"""
ops = fuse_userbuffers_forward_linear(ops)
ops = fuse_forward_linear_bias_add(ops)
ops = fuse_forward_linear_bias_activation(ops)
ops = fuse_forward_linear_scale_add(ops)
return ops
@classmethod
def _fuse_backward_ops(
def _fuse_ops(
cls,
ops: list[tuple[FusibleOperation, list[int]]],
basic_ops: Sequence[BasicOperation],
fusion_funcs: Iterable[OperationFusionFunction],
recipe: Optional[Recipe],
) -> list[tuple[FusibleOperation, list[int]]]:
"""Attempt to fuse operations in backward pass"""
ops = fuse_userbuffers_backward_linear(ops)
ops = fuse_backward_linear_add(ops)
ops = fuse_backward_linear_scale(ops)
ops = fuse_backward_activation_bias(ops, recipe)
ops = fuse_backward_add_rmsnorm(ops)
return ops
"""Apply operation fusions"""
# Apply op fusions
fused_ops = list(basic_ops)
for func in fusion_funcs:
fused_ops = func(fused_ops, recipe=recipe)
def raise_mismatch_error() -> None:
"""Throw error indicating invalid op fusion"""
raise RuntimeError(
"Found mismatch after fusing operations "
f"(basic_ops={[o.__class__.__name__ for o in basic_ops]}, "
f"fused_ops={[o.__class__.__name__ for o in fused_ops]})"
)
# Determine basic op indices corresponding to each op
out = []
idx = 0
for op in fused_ops:
if isinstance(op, FusedOperation):
idxs = []
for basic_op in op.basic_ops:
if basic_op is not basic_ops[idx]:
raise_mismatch_error()
idxs.append(idx)
idx += 1
out.append((op, idxs))
else:
if op is not basic_ops[idx]:
raise_mismatch_error()
out.append((op, [idx]))
idx += 1
if idx != len(basic_ops):
raise_mismatch_error()
return out
def maybe_fuse_ops(
self,
......@@ -424,12 +438,16 @@ class OperationFuser:
op.pre_first_fuser_forward()
# Prepare basic op lists for fusions
forward_ops = [(op, [idx]) for idx, op in enumerate(self._basic_ops)]
backward_ops = list(reversed(forward_ops[first_op_requiring_backward:]))
# Fuse ops
self._forward_ops = self._fuse_forward_ops(forward_ops, recipe)
self._backward_ops = self._fuse_backward_ops(backward_ops, recipe)
self._forward_ops = OperationFuser._fuse_ops(
self._basic_ops,
OperationFuser.forward_fusion_functions,
recipe=recipe,
)
self._backward_ops = OperationFuser._fuse_ops(
self._basic_ops,
OperationFuser.backward_fusion_functions,
recipe=recipe,
)
# Save current fusion params
self.recipe_type, self.first_op_requiring_backward = fusion_params
......@@ -491,3 +509,55 @@ class OperationFuser:
*extra_inputs,
)
return forward_func(*args)
def register_forward_fusion(
op_fusion_func: OperationFusionFunction,
prepend: bool = False,
) -> None:
"""Register function to perform operation fusion for forward pass.
The fusion function should have the following signature:
func(ops, *, recipe) -> updated ops
Parameters
----------
op_fusion_func: function
Function that takes a list of operations and may substitute
them with fused operations.
prepend: bool, default = ``False``
Whether the operation fuser should apply this fusion function
first. The default is to apply it last.
"""
if prepend:
OperationFuser.forward_fusion_functions.insert(0, op_fusion_func)
else:
OperationFuser.forward_fusion_functions.append(op_fusion_func)
def register_backward_fusion(
op_fusion_func: OperationFusionFunction,
prepend: bool = False,
) -> None:
"""Register function to perform operation fusion for backward pass.
The fusion function should have the following signature:
func(ops, *, recipe) -> updated ops
Parameters
----------
op_fusion_func: function
Function that takes a list of operations and may substitute
them with fused operations.
prepend: bool, default = ``False``
Whether the operation fuser should apply this fusion function
first. The default is to apply it last.
"""
if prepend:
OperationFuser.backward_fusion_functions.insert(0, op_fusion_func)
else:
OperationFuser.backward_fusion_functions.append(op_fusion_func)
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment