Unverified commit 72592763, authored by Tim Moon, committed by GitHub

[PyTorch] Support user-defined op fusions (#2597)



* Expose option for custom op fusions

Refactor fusion functions to remove index bookkeeping. Refactor fused ops to use consistent operation order.

Signed-off-by: Tim Moon <tmoon@nvidia.com>

* Add tests for custom ops

Signed-off-by: Tim Moon <tmoon@nvidia.com>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

For more information, see https://pre-commit.ci

* Fix linter warnings and numerical test failures

Signed-off-by: Tim Moon <tmoon@nvidia.com>

* Tweak pattern matching logic with fixed window sizes

Signed-off-by: Tim Moon <tmoon@nvidia.com>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

For more information, see https://pre-commit.ci

* Use TF32 tols in fused op tests

Signed-off-by: Tim Moon <tmoon@nvidia.com>

* Review suggestion from @greptile-apps

Signed-off-by: Tim Moon <tmoon@nvidia.com>

* Backpropagate fixes from #2622

Signed-off-by: Tim Moon <tmoon@nvidia.com>

---------

Signed-off-by: Tim Moon <tmoon@nvidia.com>
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
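The headline feature is that downstream code can now register its own fusion passes and drop user-defined ops into the fuser. A minimal sketch of the intended usage, patterned on the tests added below; the Double op and the simplified method signatures are illustrative, not part of this commit:

    import torch
    import transformer_engine.pytorch as te
    from transformer_engine.pytorch.ops import BasicOperation

    class Double(BasicOperation):
        """Hypothetical user-defined op that multiplies by 2."""

        def op_forward(self, ctx, input_, *args, **kwargs):
            # No state needs to be saved for backward
            return 2 * input_

        def op_backward(self, ctx, grad_output):
            # Return grad w.r.t. input and an (empty) tuple of param grads
            return 2 * grad_output, ()

    # User-defined basic ops drop straight into te.ops.Sequential
    model = te.ops.Sequential(te.ops.Linear(16, 16), Double())
    y = model(torch.randn(8, 16, device="cuda"))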
parent a0a89a8e
tests/pytorch/test_fusible_ops.py
@@ -2329,13 +2329,13 @@ class TestFusedOps:
         backward_ops = model._module_groups[0]._backward_ops
         if with_quantization:
             assert len(backward_ops) == 2
-            assert isinstance(backward_ops[0][0], BackwardActivationBias)
-            assert isinstance(backward_ops[1][0], te_ops.Quantize)
+            assert isinstance(backward_ops[0][0], te_ops.Quantize)
+            assert isinstance(backward_ops[1][0], BackwardActivationBias)
         else:
             assert len(backward_ops) == 3
-            assert isinstance(backward_ops[0][0], act_type)
+            assert isinstance(backward_ops[0][0], te_ops.Quantize)
             assert isinstance(backward_ops[1][0], te_ops.Bias)
-            assert isinstance(backward_ops[2][0], te_ops.Quantize)
+            assert isinstance(backward_ops[2][0], act_type)

         # Expected numerical error
         tols = dtype_tols(dtype)
tests/pytorch/test_fusible_ops.py
@@ -2930,3 +2930,317 @@ class TestSequentialModules:
         if bias:
             torch.testing.assert_close(to_cpu(ffn1.bias.grad), b1_ref.grad, **tols)
             torch.testing.assert_close(to_cpu(ffn2.bias.grad), b2_ref.grad, **tols)
+
+
+class TestCustomOps:
+    """Test with ops that are defined externally"""
+
+    def test_custom_basic_op(
+        self,
+        *,
+        shape: Iterable[int] = (7, 5),
+        dtype: torch.dtype = torch.float32,
+        device: torch.device = "cuda",
+    ) -> None:
+        """Custom basic op"""
+
+        class CustomScaleOp(te.ops.BasicOperation):
+            """Custom op that applies a learnable scale"""
+
+            def __init__(self) -> None:
+                super().__init__()
+                self.scale: torch.nn.Parameter
+                scale = torch.ones((), dtype=dtype, device=device)
+                scale = torch.nn.Parameter(scale)
+                self.register_parameter("scale", scale)
+
+            def op_forward(
+                self,
+                ctx: OperationContext,
+                input_: torch.Tensor,
+                prev_op_grad_output_quantizer: Optional[Quantizer],
+                next_op_input_quantizer: Optional[Quantizer],
+            ) -> torch.Tensor:
+                ctx.save_for_backward(self.scale, input_)
+                return self.scale * input_
+
+            def op_backward(
+                self,
+                ctx: OperationContext,
+                grad_output: torch.Tensor,
+            ) -> torch.Tensor:
+                (
+                    scale,
+                    input_,
+                ) = ctx.saved_tensors
+                grad_scale = torch.inner(input_.reshape(-1), grad_output.reshape(-1))
+                grad_scale = grad_scale.reshape(())
+                grad_input = scale * grad_output
+                return grad_input, (grad_scale,)
+
+        # Random data
+        x_ref, x_test = make_reference_and_test_tensors(
+            shape,
+            test_dtype=dtype,
+            test_device=device,
+        )
+        w_ref, w_test = make_reference_and_test_tensors(
+            (),
+            test_dtype=dtype,
+            test_device=device,
+        )
+        dy_ref, dy_test = make_reference_and_test_tensors(
+            shape,
+            test_dtype=dtype,
+            test_device=device,
+            requires_grad=False,
+        )
+
+        # Plain PyTorch implementation
+        y_ref = w_ref * x_ref
+        y_ref.backward(dy_ref)
+
+        # Implementation with fusible operation
+        op = CustomScaleOp()
+        forward = te.ops.Sequential(te.ops.Identity(), op, te.ops.Identity())
+        with torch.no_grad():
+            op.scale.copy_(w_test)
+            del w_test
+        y_test = forward(x_test)
+        y_test.backward(dy_test)
+
+        # Check results
+        tols = dtype_tols(dtype)
+        y_test = y_test.to(dtype=torch.float64, device="cpu")
+        dx_test = x_test.grad.to(dtype=torch.float64, device="cpu")
+        dw_test = op.scale.grad.to(dtype=torch.float64, device="cpu")
+        torch.testing.assert_close(y_test, y_ref, **tols)
+        torch.testing.assert_close(dx_test, x_ref.grad, **tols)
+        torch.testing.assert_close(dw_test, w_ref.grad, **tols)
+
+    def test_custom_forward_fused_op(
+        self,
+        *,
+        shape: Iterable[int] = (7, 11),
+        dtype: torch.dtype = torch.float32,
+        device: torch.device = "cuda",
+    ):
+        """Custom fused op in forward pass"""
+
+        class CustomForwardLinearSiLU(te.ops.FusedOperation):
+            """Custom fused op for GEMM + SiLU"""
+
+            _enabled = True
+
+            def __init__(self, *, linear, silu) -> None:
+                super().__init__((linear, silu))
+
+            def fuser_forward(
+                self,
+                basic_op_ctxs: list[OperationContext],
+                input_: torch.Tensor,
+                **unused,
+            ) -> torch.Tensor:
+                weight = self.basic_ops[0].weight
+                dtype = weight.dtype
+                device = weight.device
+
+                # Perform compute on CPU, because why not?
+                x = input_.cpu()
+                w = weight.cpu()
+                y = torch.matmul(x, w.T)
+                z = torch.nn.functional.silu(y)
+                out = z.to(device=device)
+
+                # Save state for linear backward
+                linear_op_ctx = basic_op_ctxs[0]
+                linear_op_ctx.save_for_backward(input_, weight)
+                linear_op_ctx.with_quantized_compute = False
+                linear_op_ctx.input_quantizer = None
+                linear_op_ctx.weight_quantizer = None
+                linear_op_ctx.grad_output_quantizer = None
+                linear_op_ctx.grad_input_quantizer = None
+                linear_op_ctx.dtype = dtype
+                linear_op_ctx.input_requires_grad = True
+                linear_op_ctx.weight_requires_grad = True
+
+                # Save state for SiLU backward
+                silu_op_ctx = basic_op_ctxs[1]
+                silu_op_ctx.save_for_backward(y.to(device=device))
+                silu_op_ctx.dtype = dtype
+                silu_op_ctx.prev_op_grad_output_quantizer = None
+
+                return out, [(), ()]
+
+            @staticmethod
+            def fuse_ops(
+                ops: list[FusibleOperation],
+                **unused,
+            ) -> list[FusibleOperation]:
+                """Apply fusion the first time this function is called"""
+                if CustomForwardLinearSiLU._enabled:
+                    CustomForwardLinearSiLU._enabled = False
+                    op = CustomForwardLinearSiLU(linear=ops[0], silu=ops[1])
+                    return [op] + ops[2:]
+                return ops
+
+        # Random data
+        x_ref, x_test = make_reference_and_test_tensors(
+            shape,
+            test_dtype=dtype,
+            test_device=device,
+        )
+        w_ref, w_test = make_reference_and_test_tensors(
+            (shape[-1], shape[-1]),
+            test_dtype=dtype,
+            test_device=device,
+        )
+        dy_ref, dy_test = make_reference_and_test_tensors(
+            shape,
+            test_dtype=dtype,
+            test_device=device,
+            requires_grad=False,
+        )
+
+        # Plain PyTorch implementation
+        y_ref = torch.nn.functional.linear(x_ref, w_ref)
+        y_ref = torch.nn.functional.silu(y_ref)
+        y_ref.backward(dy_ref)
+
+        # Implementation with fusible operation
+        te.ops.register_forward_fusion(CustomForwardLinearSiLU.fuse_ops)
+        model = te.ops.Sequential(
+            te.ops.Linear(shape[-1], shape[-1], bias=False),
+            te.ops.SiLU(),
+        )
+        with torch.no_grad():
+            model[0].weight.copy_(w_test)
+            del w_test
+        y_test = model(x_test)
+        y_test.backward(dy_test)
+
+        # Check that forward operations have been fused
+        forward_ops = model._module_groups[0]._forward_ops
+        assert len(forward_ops) == 1
+        assert isinstance(forward_ops[0][0], CustomForwardLinearSiLU)
+
+        # Expected numerical error
+        tols = dtype_tols(dtype)
+        if dtype == torch.float32:
+            tols = dtype_tols(torch.float16)  # TF32 GEMM
+
+        # Check results
+        y_test = y_test.to(dtype=torch.float64, device="cpu")
+        dx_test = x_test.grad.to(dtype=torch.float64, device="cpu")
+        dw_test = model[0].weight.grad.to(dtype=torch.float64, device="cpu")
+        torch.testing.assert_close(y_test, y_ref, **tols)
+        torch.testing.assert_close(dx_test, x_ref.grad, **tols)
+        torch.testing.assert_close(dw_test, w_ref.grad, **tols)
+
+    def test_custom_backward_fused_op(
+        self,
+        *,
+        shape: Iterable[int] = (13, 5),
+        dtype: torch.dtype = torch.float32,
+        device: torch.device = "cuda",
+    ):
+        """Custom fused op in backward pass"""
+
+        class CustomBackwardLinearScale(te.ops.FusedOperation):
+            """Custom fused op for backward linear + scale"""
+
+            _enabled: bool = True
+
+            def __init__(self, *, scale, linear) -> None:
+                super().__init__((scale, linear))
+
+            def fuser_backward(
+                self,
+                basic_op_ctxs: list[OperationContext],
+                grad_output: torch.Tensor,
+                **unused,
+            ) -> torch.Tensor:
+
+                # Load state from linear forward
+                linear_op_ctx = basic_op_ctxs[1]
+                x, w = linear_op_ctx.saved_tensors
+                dtype = linear_op_ctx.dtype
+                device = w.device
+
+                # Perform compute in FP64 and apply scale before dgrad
+                # GEMM instead of after
+                scale = self.basic_ops[0].scale
+                dy = grad_output.double()
+                x = x.double()
+                w = w.double()
+                dx = torch.matmul(dy, scale * w)
+                dw = torch.matmul(dy.T, x)
+                dx = dx.to(dtype=dtype)
+                dw = dw.to(dtype=dtype)
+                return dx, [(), (dw,)], [(), ()]
+
+            @staticmethod
+            def fuse_ops(
+                ops: list[FusibleOperation],
+                **unused,
+            ) -> list[FusibleOperation]:
+                """Apply fusion the first time this function is called"""
+                if CustomBackwardLinearScale._enabled:
+                    CustomBackwardLinearScale._enabled = False
+                    op = CustomBackwardLinearScale(scale=ops[0], linear=ops[1])
+                    return [op] + ops[2:]
+                return ops
+
+        # Random data
+        x_ref, x_test = make_reference_and_test_tensors(
+            shape,
+            test_dtype=dtype,
+            test_device=device,
+        )
+        w_ref, w_test = make_reference_and_test_tensors(
+            (shape[-1], shape[-1]),
+            test_dtype=dtype,
+            test_device=device,
+        )
+        dy_ref, dy_test = make_reference_and_test_tensors(
+            shape,
+            test_dtype=dtype,
+            test_device=device,
+            requires_grad=False,
+        )
+        scale = 1.234
+
+        # Plain PyTorch implementation
+        y_ref = torch.nn.functional.linear(scale * x_ref, w_ref)
+        y_ref.backward(dy_ref)
+
+        # Implementation with fusible operation
+        te.ops.register_backward_fusion(CustomBackwardLinearScale.fuse_ops, prepend=True)
+        model = te.ops.Sequential(
+            te.ops.ConstantScale(scale),
+            te.ops.Linear(shape[-1], shape[-1], bias=False),
+        )
+        with torch.no_grad():
+            model[1].weight.copy_(w_test)
+            del w_test
+        y_test = model(x_test)
+        y_test.backward(dy_test)
+
+        # Check that backward operations have been fused
+        backward_ops = model._module_groups[0]._backward_ops
+        assert len(backward_ops) == 1
+        assert isinstance(backward_ops[0][0], CustomBackwardLinearScale)
+
+        # Expected numerical error
+        tols = dtype_tols(dtype)
+        if dtype == torch.float32:
+            tols = dtype_tols(torch.float16)  # TF32 GEMM
+
+        # Check results
+        y_test = y_test.to(dtype=torch.float64, device="cpu")
+        dx_test = x_test.grad.to(dtype=torch.float64, device="cpu")
+        dw_test = model[1].weight.grad.to(dtype=torch.float64, device="cpu")
+        torch.testing.assert_close(y_test, y_ref, **tols)
+        torch.testing.assert_close(dx_test, x_ref.grad, **tols)
+        torch.testing.assert_close(dw_test, w_ref.grad, **tols)
transformer_engine/pytorch/ops/__init__.py
@@ -8,7 +8,9 @@ This operation-based API is experimental and subject to change.
 """

-from transformer_engine.pytorch.ops.basic import *
-from transformer_engine.pytorch.ops.linear import Linear
-from transformer_engine.pytorch.ops.op import FusibleOperation
-from transformer_engine.pytorch.ops.sequential import Sequential
+from .basic import *
+from .fuser import register_backward_fusion, register_forward_fusion
+from .linear import Linear
+from .op import BasicOperation, FusedOperation, FusibleOperation
+from .sequential import Sequential
+from . import fused
transformer_engine/pytorch/ops/fused/__init__.py
@@ -4,39 +4,27 @@
 """Compound tensor operation supported by the operation fuser."""

-from .backward_activation_bias import (
-    BackwardActivationBias,
-    fuse_backward_activation_bias,
-)
-from .backward_add_rmsnorm import (
-    BackwardAddRMSNorm,
-    fuse_backward_add_rmsnorm,
-)
-from .backward_linear_add import (
-    BackwardLinearAdd,
-    fuse_backward_linear_add,
-)
-from .backward_linear_scale import (
-    BackwardLinearScale,
-    fuse_backward_linear_scale,
-)
-from .forward_linear_bias_activation import (
-    ForwardLinearBiasActivation,
-    fuse_forward_linear_bias_activation,
-)
-from .forward_linear_bias_add import (
-    ForwardLinearBiasAdd,
-    fuse_forward_linear_bias_add,
-)
-from .forward_linear_scale_add import (
-    ForwardLinearScaleAdd,
-    fuse_forward_linear_scale_add,
-)
-from .userbuffers_backward_linear import (
-    UserbuffersBackwardLinear,
-    fuse_userbuffers_backward_linear,
-)
-from .userbuffers_forward_linear import (
-    UserbuffersForwardLinear,
-    fuse_userbuffers_forward_linear,
-)
+from ..fuser import register_backward_fusion, register_forward_fusion
+from .backward_activation_bias import BackwardActivationBias
+from .backward_add_rmsnorm import BackwardAddRMSNorm
+from .backward_linear_add import BackwardLinearAdd
+from .backward_linear_scale import BackwardLinearScale
+from .forward_linear_bias_activation import ForwardLinearBiasActivation
+from .forward_linear_bias_add import ForwardLinearBiasAdd
+from .forward_linear_scale_add import ForwardLinearScaleAdd
+from .userbuffers_backward_linear import UserbuffersBackwardLinear
+from .userbuffers_forward_linear import UserbuffersForwardLinear
+
+# Register forward fusions
+register_forward_fusion(UserbuffersForwardLinear.fuse_forward_ops)
+register_forward_fusion(ForwardLinearBiasAdd.fuse_forward_ops)
+register_forward_fusion(ForwardLinearBiasActivation.fuse_forward_ops)
+register_forward_fusion(ForwardLinearScaleAdd.fuse_forward_ops)
+
+# Register backward fusions
+register_backward_fusion(UserbuffersBackwardLinear.fuse_backward_ops)
+register_backward_fusion(BackwardLinearAdd.fuse_backward_ops)
+register_backward_fusion(BackwardLinearScale.fuse_backward_ops)
+register_backward_fusion(BackwardActivationBias.fuse_backward_ops)
+register_backward_fusion(BackwardAddRMSNorm.fuse_backward_ops)
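Note that this registration order reproduces the order previously hard-coded in OperationFuser._fuse_forward_ops and _fuse_backward_ops, so the built-in fusions behave exactly as before: the Userbuffers substitutions get first chance at the op list, and each later pass sees the output of the earlier ones. A fusion registered by user code with the default prepend=False therefore runs after all of the built-ins, while prepend=True (as in the custom backward test above) runs before them.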
transformer_engine/pytorch/ops/fused/backward_activation_bias.py
@@ -53,8 +53,8 @@ class BackwardActivationBias(FusedOperation):
     ]:

         # Get basic operation contexts
-        activation_op_ctx = basic_op_ctxs[0]
-        bias_op_ctx = basic_op_ctxs[1]
+        bias_op_ctx = basic_op_ctxs[0]
+        activation_op_ctx = basic_op_ctxs[1]

         # Saved tensors from forward pass
         (act_input,) = activation_op_ctx.saved_tensors
@@ -79,26 +79,27 @@ class BackwardActivationBias(FusedOperation):
         # Clear activation input tensor
         clear_tensor_data(act_input)

-        return dx, [(), (db,)], [(), ()]
+        return dx, [(db,), ()], [(), ()]

-def fuse_backward_activation_bias(
-    ops: list[tuple[FusibleOperation, list[int]]],
-    recipe: Optional[Recipe],
-) -> list[tuple[FusibleOperation, list[int]]]:
-    """Fused backward dact + dbias + quantize
-
-    Parameters
-    ----------
-    ops : list of tuples
-        Backward pass operations and the indices of the corresponding
-        basic operations.
-    recipe : Recipe, optional
-        Used quantization recipe
-
-    Returns
-    -------
-    ops : list of tuples
-        Updated backward pass operations
-    """
+    @staticmethod
+    def fuse_backward_ops(
+        ops: list[FusibleOperation],
+        *,
+        recipe: Optional[Recipe] = None,
+        **unused,  # pylint: disable=unused-argument
+    ) -> list[FusibleOperation]:
+        """Apply operation fusion for backward pass.
+
+        Parameters
+        ----------
+        ops : list of FusibleOperation
+            Backward pass operations.
+        recipe : Recipe, optional
+            Quantization recipe.
+
+        Returns
+        -------
+        ops : list of FusibleOperation
+            Updated backward pass operations
+        """
@@ -109,38 +110,28 @@ def fuse_backward_activation_bias(
-    # Scan through ops, fusing if possible
-    out = []
-    window = []
-    while len(ops) >= 3:
-        out.extend(window)
-
-        # Check if first op is a supported activation
-        window, ops = ops[:1], ops[1:]
-        op, _ = window[0]
-        if not isinstance(op, _fusible_activations):
-            continue
-
-        # Check if second op is bias
-        op, _ = ops[0]
-        if not isinstance(op, Bias):
-            continue
-
-        # Check if third op has a grad input quantizer
-        op, _ = ops[1]
-        if not op.num_quantizers("backward") > 0:
-            continue
-
-        window.extend(ops[:1])
-        ops = ops[1:]
-
-        # Replace window with fused op
-        op = BackwardActivationBias(
-            activation=window[0][0],
-            bias=window[1][0],
-        )
-        basic_op_idxs = [basic_op_idxs[0] for _, basic_op_idxs in window]
-        window = [(op, basic_op_idxs)]
-
-    # Return list of ops
-    out.extend(window)
-    out.extend(ops)
-    return out
+        # Scan through ops, fusing if possible
+        out = []
+        window, ops = ops[:3], ops[3:]
+        while len(window) == 3:
+            if (
+                isinstance(window[2], _fusible_activations)
+                and isinstance(window[1], Bias)
+                and window[0].get_grad_output_quantizer() is not None
+            ):
+                # Construct fused op if window matches pattern
+                op = BackwardActivationBias(bias=window[1], activation=window[2])
+                window = [window[0], op]
+            else:
+                # Shift window if window doesn't match pattern
+                out.extend(window[:-2])
+                window = window[-2:]
+
+            # Adjust window to expected size
+            out.extend(window[:-3])
+            window = window[-3:]
+            while ops and len(window) < 3:
+                window.append(ops[0])
+                ops = ops[1:]
+
+        # Return list of ops
+        out.extend(window)
+        return out
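All of the rewritten fuse_*_ops passes in this commit follow the same fixed-size sliding-window scan. A distilled, generic version of that control flow is sketched below; the predicate and constructor arguments are illustrative, and the real passes differ in small details (e.g. the activation+bias pass above keeps a leading unmatched op in the window):

    from collections.abc import Callable, Sequence

    def window_scan(
        ops: list,
        window_size: int,
        matches: Callable[[Sequence], bool],
        make_fused: Callable[[Sequence], object],
    ) -> list:
        """Sketch of the fixed-window fusion scan used by the fused ops."""
        out = []
        window, ops = ops[:window_size], ops[window_size:]
        while len(window) == window_size:
            if matches(window):
                # Replace the whole window with a single fused op
                window = [make_fused(window)]
            else:
                # Emit the oldest op and slide the window forward by one
                out.extend(window[:1])
                window = window[1:]
            # Refill the window from the remaining ops
            while ops and len(window) < window_size:
                window.append(ops.pop(0))
        # Flush whatever is left once the window can no longer be filled
        out.extend(window)
        return out

Relative to the old scan, the window never tracks basic-op indices; per the commit message, that bookkeeping is removed, and OperationFuser._fuse_ops reconstructs the indices afterwards (see the fuser.py hunk below).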
transformer_engine/pytorch/ops/fused/backward_add_rmsnorm.py
@@ -42,7 +42,7 @@ class BackwardAddRMSNorm(FusedOperation):
         # Get basic operations
         rmsnorm_op = self.basic_ops[1]
-        rmsnorm_op_ctx = basic_op_ctxs[0]
+        rmsnorm_op_ctx = basic_op_ctxs[1]

         # Saved tensors from forward pass
         x, rstdevs = rmsnorm_op_ctx.saved_tensors
@@ -53,7 +53,7 @@ class BackwardAddRMSNorm(FusedOperation):
         # Check input tensors
         dtype = rmsnorm_op_ctx.dtype
-        extra_grad = basic_op_grad_extra_outputs[1][0]
+        extra_grad = basic_op_grad_extra_outputs[0][0]
         dy = maybe_dequantize(grad_output.contiguous(), dtype).view(x.size())
         w = maybe_dequantize(rmsnorm_op.weight, dtype).view((inner_dim,))
         add = maybe_dequantize(extra_grad.contiguous(), dtype).view(x.size())
@@ -77,57 +77,51 @@ class BackwardAddRMSNorm(FusedOperation):
         grad_input = dx.view(grad_output.size())
         grad_weight = dw.view(weight_dims)
-        return grad_input, [(grad_weight,), ()], [(), ()]
+        return grad_input, [(), (grad_weight,)], [(), ()]

-def fuse_backward_add_rmsnorm(
-    ops: list[tuple[FusibleOperation, list[int]]],
-) -> list[tuple[FusibleOperation, list[int]]]:
-    """Fused backward RMNorm + add
-
-    Parameters
-    ----------
-    ops : list of tuples
-        Backward pass operations and the indices of the corresponding
-        basic operations.
-
-    Returns
-    -------
-    ops : list of tuples
-        Updated backward pass operations
-    """
-
-    # Scan through ops, fusing if possible
-    out = []
-    window = []
-    while len(ops) >= 2:
-        out.extend(window)
-
-        # Check if first op is linear
-        window, ops = ops[:1], ops[1:]
-        op, _ = window[0]
-        if not isinstance(op, RMSNorm):
-            continue
-
-        # Check if second op is "make extra output"
-        op, _ = ops[0]
-        if not isinstance(op, MakeExtraOutput):
-            continue
-        if op._in_place:
-            continue
-        window.extend(ops[:1])
-        ops = ops[1:]
-
-        # Replace window with fused op
-        op = BackwardAddRMSNorm(
-            rmsnorm=window[0][0],
-            add=window[1][0],
-        )
-        basic_op_idxs = [basic_op_idxs[0] for _, basic_op_idxs in window]
-        window = [(op, basic_op_idxs)]
-
-    # Return list of ops
-    out.extend(window)
-    out.extend(ops)
-    return out
+    @staticmethod
+    def fuse_backward_ops(
+        ops: list[FusibleOperation],
+        **unused,  # pylint: disable=unused-argument
+    ) -> list[FusibleOperation]:
+        """Apply operation fusion for backward pass.
+
+        Parameters
+        ----------
+        ops : list of FusibleOperation
+            Backward pass operations.
+
+        Returns
+        -------
+        ops : list of FusibleOperation
+            Updated backward pass operations
+        """
+
+        # Scan through ops, fusing if possible
+        out = []
+        window, ops = ops[:2], ops[2:]
+        while len(window) == 2:
+            if (
+                isinstance(window[0], MakeExtraOutput)
+                and isinstance(window[1], RMSNorm)
+                and not window[0]._in_place
+            ):
+                # Construct fused op if window matches pattern
+                op = BackwardAddRMSNorm(add=window[0], rmsnorm=window[1])
+                window = [op]
+            else:
+                # Shift window if window doesn't match pattern
+                out.extend(window[:-1])
+                window = window[-1:]
+
+            # Adjust window to expected size
+            out.extend(window[:-2])
+            window = window[-2:]
+            while ops and len(window) < 2:
+                window.append(ops[0])
+                ops = ops[1:]
+
+        # Return list of ops
+        out.extend(window)
+        return out
transformer_engine/pytorch/ops/fused/backward_linear_add.py
@@ -45,7 +45,7 @@ class BackwardLinearAdd(FusedOperation):
         # Get basic operations
         linear_op = self.basic_ops[1]
-        linear_op_ctx = basic_op_ctxs[0]
+        linear_op_ctx = basic_op_ctxs[1]

         # Saved tensors from forward pass
         (x_local, w) = linear_op_ctx.saved_tensors
@@ -71,7 +71,7 @@ class BackwardLinearAdd(FusedOperation):
             accumulate_into_main_grad = False

         # Linear backward pass
-        grad_input = basic_op_grad_extra_outputs[1][0]
+        grad_input = basic_op_grad_extra_outputs[0][0]
         grad_input, grad_weight = BasicLinear._functional_backward(
             grad_output=grad_output,
             input=x_local,
@@ -109,61 +109,60 @@ class BackwardLinearAdd(FusedOperation):
             zero=getattr(weight_param, "zero_out_wgrad", False),
         )

-        return grad_input, [(grad_weight,), ()], [(), ()]
+        return grad_input, [(), (grad_weight,)], [(), ()]

-def fuse_backward_linear_add(
-    ops: list[tuple[FusibleOperation, list[int]]],
-) -> list[tuple[FusibleOperation, list[int]]]:
-    """Fused backward dgrad GEMM + add
-
-    Parameters
-    ----------
-    ops : list of tuples
-        Backward pass operations and the indices of the corresponding
-        basic operations.
-
-    Returns
-    -------
-    ops : list of tuples
-        Updated backward pass operations
-    """
-
-    # Scan through ops, fusing if possible
-    out = []
-    window = []
-    while len(ops) >= 2:
-        out.extend(window)
-
-        # Check if first op is linear
-        window, ops = ops[:1], ops[1:]
-        op, _ = window[0]
-        if not isinstance(op, BasicLinear):
-            continue
-        if op.tensor_parallel_mode == "column":
-            # Row tensor-parallelism requires communication after the
-            # GEMM
-            continue
-
-        # Check if second op is "make extra output"
-        op, _ = ops[0]
-        if not isinstance(op, MakeExtraOutput):
-            continue
-        if not op._in_place:
-            continue
-        window.extend(ops[:1])
-        ops = ops[1:]
-
-        # Replace window with fused op
-        op = BackwardLinearAdd(
-            linear=window[0][0],
-            backward_add=window[1][0],
-        )
-        basic_op_idxs = [basic_op_idxs[0] for _, basic_op_idxs in window]
-        window = [(op, basic_op_idxs)]
-
-    # Return list of ops
-    out.extend(window)
-    out.extend(ops)
-    return out
+    @staticmethod
+    def fuse_backward_ops(
+        ops: list[FusibleOperation],
+        **unused,  # pylint: disable=unused-argument
+    ) -> list[FusibleOperation]:
+        """Apply operation fusion for backward pass.
+
+        Parameters
+        ----------
+        ops : list of FusibleOperation
+            Backward pass operations.
+
+        Returns
+        -------
+        ops : list of FusibleOperation
+            Updated backward pass operations
+        """
+
+        # Scan through ops, fusing if possible
+        out = []
+        window, ops = ops[:2], ops[2:]
+        while len(window) == 2:
+
+            # Check if window matches pattern
+            matches_pattern = True
+            if not (isinstance(window[0], MakeExtraOutput) and isinstance(window[1], BasicLinear)):
+                matches_pattern = False
+            elif not window[0]._in_place:
+                # Fused op accumulates grad input in-place
+                matches_pattern = False
+            elif window[1].tensor_parallel_mode == "column":
+                # Column tensor-parallelism requires communication
+                # after the dgrad GEMM
+                matches_pattern = False
+
+            if matches_pattern:
+                # Construct fused op if window matches pattern
+                op = BackwardLinearAdd(backward_add=window[0], linear=window[1])
+                window = [op]
+            else:
+                # Shift window if window doesn't match pattern
+                out.extend(window[:-1])
+                window = window[-1:]
+
+            # Adjust window to expected size
+            out.extend(window[:-2])
+            window = window[-2:]
+            while ops and len(window) < 2:
+                window.append(ops[0])
+                ops = ops[1:]
+
+        # Return list of ops
+        out.extend(window)
+        return out
transformer_engine/pytorch/ops/fused/backward_linear_scale.py
@@ -45,7 +45,7 @@ class BackwardLinearScale(FusedOperation):
         # Get basic operations
         linear_op = self.basic_ops[0]
-        linear_op_ctx = basic_op_ctxs[1]
+        linear_op_ctx = basic_op_ctxs[0]
         scale_op = self.basic_ops[1]

         # Saved tensors from forward pass
@@ -109,58 +109,57 @@ class BackwardLinearScale(FusedOperation):
             zero=getattr(weight_param, "zero_out_wgrad", False),
         )

-        return grad_input, [(), (grad_weight,)], [(), ()]
+        return grad_input, [(grad_weight,), ()], [(), ()]

-def fuse_backward_linear_scale(
-    ops: list[tuple[FusibleOperation, list[int]]],
-) -> list[tuple[FusibleOperation, list[int]]]:
-    """Fused backward dgrad GEMM + constant scale
-
-    Parameters
-    ----------
-    ops : list of tuples
-        Backward pass operations and the indices of the corresponding
-        basic operations.
-
-    Returns
-    -------
-    ops : list of tuples
-        Updated backward pass operations
-    """
-
-    # Scan through ops, fusing if possible
-    out = []
-    window = []
-    while len(ops) >= 2:
-        out.extend(window)
-
-        # Check if first op is constant scale
-        window, ops = ops[:1], ops[1:]
-        op, _ = window[0]
-        if not isinstance(op, ConstantScale):
-            continue
-
-        # Check if second op is linear
-        op, _ = ops[0]
-        if not isinstance(op, BasicLinear):
-            continue
-        if op.tensor_parallel_mode == "column":
-            # Column tensor-parallelism requires communication after the dgrad GEMM
-            continue
-        window.extend(ops[:1])
-        ops = ops[1:]
-
-        # Replace window with fused op
-        op = BackwardLinearScale(
-            scale=window[0][0],
-            linear=window[1][0],
-        )
-        basic_op_idxs = [basic_op_idxs[0] for _, basic_op_idxs in window]
-        window = [(op, basic_op_idxs)]
-
-    # Return list of ops
-    out.extend(window)
-    out.extend(ops)
-    return out
+    @staticmethod
+    def fuse_backward_ops(
+        ops: list[FusibleOperation],
+        **unused,  # pylint: disable=unused-argument
+    ) -> list[FusibleOperation]:
+        """Apply operation fusion for backward pass.
+
+        Parameters
+        ----------
+        ops : list of FusibleOperation
+            Backward pass operations.
+
+        Returns
+        -------
+        ops : list of FusibleOperation
+            Updated backward pass operations
+        """
+
+        # Scan through ops, fusing if possible
+        out = []
+        window, ops = ops[:2], ops[2:]
+        while len(window) == 2:
+
+            # Check if window matches pattern
+            matches_pattern = True
+            if not (isinstance(window[0], BasicLinear) and isinstance(window[1], ConstantScale)):
+                matches_pattern = False
+            elif window[0].tensor_parallel_mode == "column":
+                # Column tensor-parallelism requires communication
+                # after the dgrad GEMM
+                matches_pattern = False
+
+            if matches_pattern:
+                # Construct fused op if window matches pattern
+                op = BackwardLinearScale(linear=window[0], scale=window[1])
+                window = [op]
+            else:
+                # Shift window if window doesn't match pattern
+                out.extend(window[:-1])
+                window = window[-1:]
+
+            # Adjust window to expected size
+            out.extend(window[:-2])
+            window = window[-2:]
+            while ops and len(window) < 2:
+                window.append(ops[0])
+                ops = ops[1:]
+
+        # Return list of ops
+        out.extend(window)
+        return out
transformer_engine/pytorch/ops/fused/forward_linear_bias_activation.py
@@ -134,62 +134,63 @@ class ForwardLinearBiasActivation(FusedOperation):
         return output, [() for _ in range(len(self.basic_ops))]

-def fuse_forward_linear_bias_activation(
-    ops: list[tuple[FusibleOperation, list[int]]],
-) -> list[tuple[FusibleOperation, list[int]]]:
-    """Fuse forward GEMM + bias + activation
-
-    Parameters
-    ----------
-    ops : list of tuples
-        Forward pass operations and the indices of the corresponding
-        basic operations.
-
-    Returns
-    -------
-    ops : list of tuples
-        Updated forward pass operations
-    """
-
-    # Scan through ops, fusing if possible
-    out = []
-    window = []
-    while len(ops) >= 2:
-        out.extend(window)
-
-        # Check if first op is linear
-        window, ops = ops[:1], ops[1:]
-        op1, _ = window[0]
-        if not isinstance(op1, BasicLinear):
-            continue
-        if op1.tensor_parallel_mode == "row":
-            # Row tensor-parallelism requires communication after the
-            # GEMM
-            continue
-        if op1.weight.dtype not in (torch.float16, torch.bfloat16):
-            # cuBLAS only supports fused GEMM+bias+activation with
-            # FP16 and BF16 output
-            continue
-
-        # Check if second op is bias
-        op2, _ = ops[0]
-        if not isinstance(op2, Bias):
-            continue
-        window.extend(ops[:1])
-        ops = ops[1:]
-
-        # Replace window with fused op
-        op = ForwardLinearBiasActivation(
-            linear=window[0][0],
-            bias=window[1][0],
-            activation=None,
-        )
-        basic_op_idxs = [basic_op_idxs[0] for _, basic_op_idxs in window]
-        window = [(op, basic_op_idxs)]
-
-    # Return list of ops
-    out.extend(window)
-    out.extend(ops)
-    return out
+    @staticmethod
+    def fuse_forward_ops(
+        ops: list[FusibleOperation],
+        **unused,  # pylint: disable=unused-argument
+    ) -> list[FusibleOperation]:
+        """Apply operation fusion for forward pass.
+
+        Parameters
+        ----------
+        ops : list of FusibleOperation
+            Forward pass operations.
+
+        Returns
+        -------
+        ops : list of FusibleOperation
+            Updated forward pass operations
+        """
+
+        # Scan through ops, fusing if possible
+        out = []
+        window, ops = ops[:2], ops[2:]
+        while len(window) == 2:
+
+            # Check if window matches pattern
+            matches_pattern = True
+            if not (isinstance(window[0], BasicLinear) and isinstance(window[1], Bias)):
+                matches_pattern = False
+            elif window[0].tensor_parallel_mode == "row":
+                # Row tensor-parallelism requires communication after
+                # the GEMM
+                matches_pattern = False
+            elif window[0].weight.dtype not in (torch.float16, torch.bfloat16):
+                # cuBLAS only supports fused GEMM+bias+activation with
+                # FP16 and BF16 output
+                matches_pattern = False
+
+            if matches_pattern:
+                # Construct fused op if window matches pattern
+                op = ForwardLinearBiasActivation(
+                    linear=window[0],
+                    bias=window[1],
+                    activation=None,
+                )
+                window = [op]
+            else:
+                # Shift window if window doesn't match pattern
+                out.extend(window[:-1])
+                window = window[-1:]
+
+            # Adjust window to expected size
+            out.extend(window[:-2])
+            window = window[-2:]
+            while ops and len(window) < 2:
+                window.append(ops[0])
+                ops = ops[1:]
+
+        # Return list of ops
+        out.extend(window)
+        return out
transformer_engine/pytorch/ops/fused/forward_linear_bias_add.py
@@ -131,21 +131,21 @@ class ForwardLinearBiasAdd(FusedOperation):
         return output, [() for _ in range(len(self.basic_ops))]

-def fuse_forward_linear_bias_add(
-    ops: list[tuple[FusibleOperation, list[int]]],
-) -> list[tuple[FusibleOperation, list[int]]]:
-    """Fuse forward GEMM + bias + add
-
-    Parameters
-    ----------
-    ops : list of tuples
-        Forward pass operations and the indices of the corresponding
-        basic operations.
-
-    Returns
-    -------
-    ops : list of tuples
-        Updated forward pass operations
-    """
+    @staticmethod
+    def fuse_forward_ops(
+        ops: list[FusibleOperation],
+        **unused,  # pylint: disable=unused-argument
+    ) -> list[FusibleOperation]:
+        """Apply operation fusion for forward pass.
+
+        Parameters
+        ----------
+        ops : list of FusibleOperation
+            Forward pass operations.
+
+        Returns
+        -------
+        ops : list of FusibleOperation
+            Updated forward pass operations
+        """
@@ -153,50 +153,41 @@ def fuse_forward_linear_bias_add(
-    # Scan through ops, fusing if possible
-    out = []
-    window = []
-    while len(ops) >= 2:
-        out.extend(window)
-
-        # Check if first op is linear
-        window, ops = ops[:1], ops[1:]
-        op, _ = window[0]
-        if not isinstance(op, BasicLinear):
-            continue
-        if op.tensor_parallel_mode == "row":
-            # Row tensor-parallelism requires communication after the
-            # GEMM
-            continue
-        linear = op
-        op, _ = ops[0]
-
-        # Check if next op is bias
-        bias = None
-        if isinstance(op, Bias):
-            bias = op
-            window.extend(ops[:1])
-            ops = ops[1:]
-            if len(ops) == 0:
-                continue
-            op, _ = ops[0]
-
-        # Check if next op is in-place add extra input
-        if not isinstance(op, AddExtraInput):
-            continue
-        if not op._in_place:
-            continue
-        add = op
-        window.extend(ops[:1])
-        ops = ops[1:]
-
-        # Replace window with fused op
-        op = ForwardLinearBiasAdd(
-            linear=linear,
-            bias=bias,
-            add=add,
-        )
-        basic_op_idxs = [basic_op_idxs[0] for _, basic_op_idxs in window]
-        window = [(op, basic_op_idxs)]
-
-    # Return list of ops
-    out.extend(window)
-    out.extend(ops)
-    return out
+        # Scan through ops, fusing if possible
+        out = []
+        window = []
+        while ops:
+
+            # Shift window
+            out.extend(window)
+            window = [ops[0]]
+            ops = ops[1:]
+
+            # Check if first op is linear
+            if not isinstance(window[0], BasicLinear):
+                continue
+            if window[0].tensor_parallel_mode == "row":
+                # Row tensor-parallelism requires communication after
+                # the GEMM
+                continue
+            linear = window[0]
+
+            # Check if next op is bias
+            bias = None
+            if ops and isinstance(ops[0], Bias):
+                window.append(ops[0])
+                ops = ops[1:]
+                bias = window[-1]
+
+            # Check if next op is in-place add extra input
+            if ops and isinstance(ops[0], AddExtraInput) and ops[0]._in_place:
+                window.append(ops[0])
+                ops = ops[1:]
+                add = window[-1]
+            else:
+                continue
+
+            # Replace window with fused op
+            op = ForwardLinearBiasAdd(linear=linear, bias=bias, add=add)
+            window = [op]
+
+        # Return list of ops
+        out.extend(window)
+        return out
transformer_engine/pytorch/ops/fused/forward_linear_scale_add.py
@@ -110,70 +110,66 @@ class ForwardLinearScaleAdd(FusedOperation):
         return output, [() for _ in range(len(self.basic_ops))]

-def fuse_forward_linear_scale_add(
-    ops: list[tuple[FusibleOperation, list[int]]],
-) -> list[tuple[FusibleOperation, list[int]]]:
-    """Fuse forward GEMM + scale + add
-
-    Parameters
-    ----------
-    ops : list of tuples
-        Forward pass operations and the indices of the corresponding
-        basic operations.
-
-    Returns
-    -------
-    ops : list of tuples
-        Updated forward pass operations
-    """
-
-    # Scan through ops, fusing if possible
-    out = []
-    window = []
-    while len(ops) >= 3:
-        out.extend(window)
-
-        # Check if first op is linear
-        window, ops = ops[:1], ops[1:]
-        op, _ = window[0]
-        if not isinstance(op, BasicLinear):
-            continue
-        if op.tensor_parallel_mode == "row":
-            # Row tensor-parallelism requires communication after the
-            # GEMM
-            continue
-        linear = op
-        op, _ = ops[0]
-
-        # Check if next op is constant scale
-        if not isinstance(op, ConstantScale):
-            continue
-        scale = op
-        window.extend(ops[:1])
-        ops = ops[1:]
-        op, _ = ops[0]
-
-        # Check if next op is in-place add extra input
-        if not isinstance(op, AddExtraInput):
-            continue
-        if not op._in_place:
-            continue
-        add = op
-        window.extend(ops[:1])
-        ops = ops[1:]
-
-        # Replace window with fused op
-        op = ForwardLinearScaleAdd(
-            linear=linear,
-            scale=scale,
-            add=add,
-        )
-        basic_op_idxs = [basic_op_idxs[0] for _, basic_op_idxs in window]
-        window = [(op, basic_op_idxs)]
-
-    # Return list of ops
-    out.extend(window)
-    out.extend(ops)
-    return out
+    @staticmethod
+    def fuse_forward_ops(
+        ops: list[FusibleOperation],
+        **unused,  # pylint: disable=unused-argument
+    ) -> list[FusibleOperation]:
+        """Apply operation fusion for forward pass.
+
+        Parameters
+        ----------
+        ops : list of FusibleOperation
+            Forward pass operations.
+
+        Returns
+        -------
+        ops : list of FusibleOperation
+            Updated forward pass operations
+        """
+
+        # Scan through ops, fusing if possible
+        out = []
+        window, ops = ops[:3], ops[3:]
+        while len(window) == 3:
+
+            # Check if window matches pattern
+            matches_pattern = True
+            if not (
+                isinstance(window[0], BasicLinear)
+                and isinstance(window[1], ConstantScale)
+                and isinstance(window[2], AddExtraInput)
+            ):
+                matches_pattern = False
+            elif window[0].tensor_parallel_mode == "row":
+                # Row tensor-parallelism requires communication after
+                # the GEMM
+                matches_pattern = False
+            elif not window[2]._in_place:
+                # Fused op accumulates output in-place
+                matches_pattern = False
+
+            if matches_pattern:
+                # Construct fused op if window matches pattern
+                op = ForwardLinearScaleAdd(
+                    linear=window[0],
+                    scale=window[1],
+                    add=window[2],
+                )
+                window = [op]
+            else:
+                # Shift window if window doesn't match pattern
+                out.extend(window[:-2])
+                window = window[-2:]
+
+            # Adjust window to expected size
+            out.extend(window[:-3])
+            window = window[-3:]
+            while ops and len(window) < 3:
+                window.append(ops[0])
+                ops = ops[1:]
+
+        # Return list of ops
+        out.extend(window)
+        return out
transformer_engine/pytorch/ops/fused/userbuffers_backward_linear.py
@@ -503,7 +503,7 @@ class UserbuffersBackwardLinear(FusedOperation):
         # Get basic operations
         idx = self._op_idxs["linear"]
         linear_op = self.basic_ops[idx]
-        linear_op_ctx = basic_op_ctxs[-1]
+        linear_op_ctx = basic_op_ctxs[0]
         bias_op = None
         if self._op_idxs["bias"] is not None:
             idx = self._op_idxs["bias"]
@@ -578,25 +578,26 @@ class UserbuffersBackwardLinear(FusedOperation):
         grad_params[self._op_idxs["linear"]] = (grad_weight,)
         if bias_op is not None:
             grad_params[self._op_idxs["bias"]] = (grad_bias,)
-        grad_params.reverse()
         grad_extra_inputs = [() for _ in range(len(self.basic_ops))]
         return grad_input, grad_params, grad_extra_inputs

-def fuse_userbuffers_backward_linear(
-    ops: list[tuple[FusibleOperation, list[int]]],
-) -> list[tuple[FusibleOperation, list[int]]]:
-    """Substitute linear operations with Userbuffers implementation
-
-    Parameters
-    ----------
-    ops : list of tuples
-        Backward pass operations and the indices of the corresponding
-        basic operations.
-
-    Returns
-    -------
-    ops : list of tuples
-        Updated backward pass operations
-    """
+    @staticmethod
+    def fuse_backward_ops(
+        ops: list[FusibleOperation],
+        **unused,  # pylint: disable=unused-argument
+    ) -> list[FusibleOperation]:
+        """Apply operation fusion for backward pass.
+
+        Parameters
+        ----------
+        ops : list of FusibleOperation
+            Backward pass operations.
+        recipe : Recipe, optional
+            Quantization recipe.
+
+        Returns
+        -------
+        ops : list of FusibleOperation
+            Updated backward pass operations
+        """
@@ -605,46 +606,33 @@ def fuse_userbuffers_backward_linear(
         if not torch.distributed.is_initialized() or torch.distributed.get_world_size() == 1:
             return ops

-    # Sliding window in list of ops
-    window = []
-
-    def peek_next_op() -> Optional[FusibleOperation]:
-        """Get next op in list of ops"""
-        nonlocal ops
-        if not ops:
-            return None
-        return ops[-1][0]
-
-    def pop_next_op() -> FusibleOperation:
-        """Remove next op from list of ops and add to sliding window"""
-        nonlocal ops, window
-        window.insert(0, ops[-1])
-        ops = ops[:-1]
-        return window[0][0]
-
-    # Scan through ops in reverse order, fusing if possible
-    out_reversed = []
-    while ops:
-        out_reversed.extend(reversed(window))
-        window.clear()
-
-        # Check if next op is linear
-        next_op = pop_next_op()
-        if not isinstance(next_op, BasicLinear):
-            continue
-        linear = next_op
-        if linear._userbuffers_options is None:
-            continue
-
-        # Check if next op is bias
-        bias = None
-        if linear.tensor_parallel_mode != "row" and isinstance(peek_next_op(), Bias):
-            bias = pop_next_op()
-
-        # Check if next op is reduce-scatter
-        reduce_scatter = None
-        if linear.tensor_parallel_mode is None and isinstance(peek_next_op(), ReduceScatter):
-            reduce_scatter = pop_next_op()
+        # Scan through ops, fusing if possible
+        out = []
+        window = []
+        while ops:
+
+            # Shift window
+            out.extend(window)
+            window, ops = ops[:1], ops[1:]
+
+            # Check if first op is linear
+            if not isinstance(window[0], BasicLinear):
+                continue
+            linear = window[0]
+            if linear._userbuffers_options is None:
+                continue
+
+            # Check if next op is bias
+            bias = None
+            if linear.tensor_parallel_mode != "row" and ops and isinstance(ops[0], Bias):
+                bias, ops = ops[0], ops[1:]
+                window.append(bias)
+
+            # Check if next op is reduce-scatter
+            reduce_scatter = None
+            if linear.tensor_parallel_mode is None and ops and isinstance(ops[0], ReduceScatter):
+                reduce_scatter, ops = ops[0], ops[1:]
+                window.append(reduce_scatter)

         # Check for invalid combinations
         if reduce_scatter is None:
@@ -666,11 +654,8 @@ def fuse_userbuffers_backward_linear(
             bias=bias,
             reduce_scatter=reduce_scatter,
         )
-        basic_op_idxs = [basic_op_idxs[0] for _, basic_op_idxs in window]
-        window = [(op, basic_op_idxs)]
+            window = [op]

         # Return list of ops
-    out_reversed.extend(reversed(window))
-    out = out_reversed
-    out.reverse()
-    return out
+        out.extend(window)
+        return out
transformer_engine/pytorch/ops/fused/userbuffers_forward_linear.py
@@ -369,21 +369,21 @@ class UserbuffersForwardLinear(FusedOperation):
         return output, [() for _ in range(len(self.basic_ops))]

-def fuse_userbuffers_forward_linear(
-    ops: list[tuple[FusibleOperation, list[int]]],
-) -> list[tuple[FusibleOperation, list[int]]]:
-    """Substitute linear operations with Userbuffers implementation
-
-    Parameters
-    ----------
-    ops : list of tuples
-        Forward pass operations and the indices of the corresponding
-        basic operations.
-
-    Returns
-    -------
-    ops : list of tuples
-        Updated forward pass operations
-    """
+    @staticmethod
+    def fuse_forward_ops(
+        ops: list[FusibleOperation],
+        **unused,  # pylint: disable=unused-argument
+    ) -> list[FusibleOperation]:
+        """Apply operation fusion for forward pass.
+
+        Parameters
+        ----------
+        ops : list of FusibleOperation
+            Forward pass operations.
+
+        Returns
+        -------
+        ops : list of FusibleOperation
+            Updated forward pass operations
+        """
@@ -392,46 +392,33 @@ def fuse_userbuffers_forward_linear(
         if not torch.distributed.is_initialized() or torch.distributed.get_world_size() == 1:
             return ops

-    # Sliding window in list of ops
-    window = []
-
-    def peek_next_op() -> Optional[FusibleOperation]:
-        """Get next op in list of ops"""
-        nonlocal ops
-        if not ops:
-            return None
-        return ops[0][0]
-
-    def pop_next_op() -> FusibleOperation:
-        """Remove next op from list of ops and add to sliding window"""
-        nonlocal ops, window
-        window.append(ops[0])
-        ops = ops[1:]
-        return window[-1][0]
-
-    # Scan through ops, fusing if possible
-    out = []
-    while ops:
-        out.extend(window)
-        window.clear()
-
-        # Check if next op is linear
-        next_op = pop_next_op()
-        if not isinstance(next_op, BasicLinear):
-            continue
-        linear = next_op
-        if linear._userbuffers_options is None:
-            continue
-
-        # Check if next op is bias
-        bias = None
-        if linear.tensor_parallel_mode != "row" and isinstance(peek_next_op(), Bias):
-            bias = pop_next_op()
-
-        # Check if next op is reduce-scatter
-        reduce_scatter = None
-        if linear.tensor_parallel_mode is None and isinstance(peek_next_op(), ReduceScatter):
-            reduce_scatter = pop_next_op()
+        # Scan through ops, fusing if possible
+        out = []
+        window = []
+        while ops:
+
+            # Shift window
+            out.extend(window)
+            window, ops = ops[:1], ops[1:]
+
+            # Check if first op is linear
+            if not isinstance(window[0], BasicLinear):
+                continue
+            linear = window[0]
+            if linear._userbuffers_options is None:
+                continue
+
+            # Check if next op is bias
+            bias = None
+            if linear.tensor_parallel_mode != "row" and ops and isinstance(ops[0], Bias):
+                bias, ops = ops[0], ops[1:]
+                window.append(bias)
+
+            # Check if next op is reduce-scatter
+            reduce_scatter = None
+            if linear.tensor_parallel_mode is None and ops and isinstance(ops[0], ReduceScatter):
+                reduce_scatter, ops = ops[0], ops[1:]
+                window.append(reduce_scatter)

         # Check for invalid combinations
         if reduce_scatter is None:
@@ -453,8 +440,7 @@ def fuse_userbuffers_forward_linear(
             bias=bias,
             reduce_scatter=reduce_scatter,
         )
-        basic_op_idxs = [basic_op_idxs[0] for _, basic_op_idxs in window]
-        window = [(op, basic_op_idxs)]
+            window = [op]

         # Return list of ops
         out.extend(window)
transformer_engine/pytorch/ops/fuser.py
@@ -5,33 +5,20 @@
 """Manager class for a pipeline of fusible operations."""

 from __future__ import annotations

-from collections.abc import Callable, Iterable
-from typing import Any, Optional
 import itertools
+from collections.abc import Callable, Iterable, Sequence
+from typing import Any, Optional, TypeAlias

 import torch

-from transformer_engine.pytorch.quantization import FP8GlobalStateManager, Recipe, DelayedScaling
-from transformer_engine.pytorch.ops.op import (
+from ..quantization import FP8GlobalStateManager, Recipe, DelayedScaling
+from ..quantized_tensor import prepare_for_saving, restore_from_saved
+from .op import (
     BasicOperation,
     FusibleOperation,
+    FusedOperation,
     OperationContext,
 )
-from transformer_engine.pytorch.ops.fused import (
-    fuse_backward_activation_bias,
-    fuse_backward_add_rmsnorm,
-    fuse_backward_linear_add,
-    fuse_backward_linear_scale,
-    fuse_forward_linear_bias_activation,
-    fuse_forward_linear_bias_add,
-    fuse_forward_linear_scale_add,
-    fuse_userbuffers_backward_linear,
-    fuse_userbuffers_forward_linear,
-)
-from transformer_engine.pytorch.quantized_tensor import (
-    prepare_for_saving,
-    restore_from_saved,
-)


 def _split_tuple(t: tuple, idx: int) -> tuple[tuple, tuple]:
@@ -57,6 +44,12 @@ def _is_graph_capturing() -> bool:
     return _is_graph_capturing_function()


+# Type alias for a function that may perform operation fusion
+OperationFusionFunction: TypeAlias = (
+    "Callable[tuple[list[FusibleOperation], ...], list[FusibleOperation]]"
+)
+
+
 class _OperationFuserAutogradFunction(torch.autograd.Function):
     """Autograd function for a pipeline of operations
@@ -241,7 +234,7 @@ class _OperationFuserAutogradFunction(torch.autograd.Function):
         dx = grad_output
         grad_params = [None for _ in range(len(basic_ops))]
         grad_extra_inputs = [None for _ in range(len(basic_ops))]
-        for op, basic_op_idxs in backward_ops:
+        for op, basic_op_idxs in reversed(backward_ops):

             # Stop if no more gradients are required
             if all(not basic_op_ctxs[idx].requires_grad for idx in basic_op_idxs):
@@ -315,6 +308,10 @@ class OperationFuser:
     """

+    # Functions to perform operation fusion
+    forward_fusion_functions: list[OperationFusionFunction] = []
+    backward_fusion_functions: list[OperationFusionFunction] = []
+
     def __init__(
         self,
         ops: list[FusibleOperation],
@@ -334,7 +331,7 @@ class OperationFuser:
         self._basic_op_num_extra_inputs: list[int] = list(op.num_extra_inputs for op in basic_ops)
         self.num_extra_inputs: int = sum(self._basic_op_num_extra_inputs)

-        # Ops for forward and backward pass, will be populated in fuse_ops
+        # Ops for forward and backward pass, will be populated in maybe_fuse_ops
         self._forward_ops: list[tuple[FusibleOperation, list[int]]]
         self._backward_ops: list[tuple[FusibleOperation, list[int]]]
@@ -349,31 +346,48 @@ class OperationFuser:
         self._flat_basic_op_params = sum(self._basic_op_params, [])

     @classmethod
-    def _fuse_forward_ops(
-        cls,
-        ops: list[tuple[FusibleOperation, list[int]]],
-        recipe: Optional[Recipe],  # pylint: disable=unused-argument
-    ) -> list[tuple[FusibleOperation, list[int]]]:
-        """Attempt to fuse operations in forward pass"""
-        ops = fuse_userbuffers_forward_linear(ops)
-        ops = fuse_forward_linear_bias_add(ops)
-        ops = fuse_forward_linear_bias_activation(ops)
-        ops = fuse_forward_linear_scale_add(ops)
-        return ops
-
-    @classmethod
-    def _fuse_backward_ops(
+    def _fuse_ops(
         cls,
-        ops: list[tuple[FusibleOperation, list[int]]],
+        basic_ops: Sequence[BasicOperation],
+        fusion_funcs: Iterable[OperationFusionFunction],
         recipe: Optional[Recipe],
     ) -> list[tuple[FusibleOperation, list[int]]]:
-        """Attempt to fuse operations in backward pass"""
-        ops = fuse_userbuffers_backward_linear(ops)
-        ops = fuse_backward_linear_add(ops)
-        ops = fuse_backward_linear_scale(ops)
-        ops = fuse_backward_activation_bias(ops, recipe)
-        ops = fuse_backward_add_rmsnorm(ops)
-        return ops
+        """Apply operation fusions"""
+
+        # Apply op fusions
+        fused_ops = list(basic_ops)
+        for func in fusion_funcs:
+            fused_ops = func(fused_ops, recipe=recipe)
+
+        def raise_mismatch_error() -> None:
+            """Throw error indicating invalid op fusion"""
+            raise RuntimeError(
+                "Found mismatch after fusing operations "
+                f"(basic_ops={[o.__class__.__name__ for o in basic_ops]}, "
+                f"fused_ops={[o.__class__.__name__ for o in fused_ops]})"
+            )
+
+        # Determine basic op indices corresponding to each op
+        out = []
+        idx = 0
+        for op in fused_ops:
+            if isinstance(op, FusedOperation):
+                idxs = []
+                for basic_op in op.basic_ops:
+                    if basic_op is not basic_ops[idx]:
+                        raise_mismatch_error()
+                    idxs.append(idx)
+                    idx += 1
+                out.append((op, idxs))
+            else:
+                if op is not basic_ops[idx]:
+                    raise_mismatch_error()
+                out.append((op, [idx]))
+                idx += 1
+        if idx != len(basic_ops):
+            raise_mismatch_error()
+        return out
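The index-mapping pass above also doubles as validation of the contract every fusion function must satisfy: the fused op list, flattened back to basic ops, must reproduce the input list exactly (same objects, same order). A rough sketch of that invariant; some_fusion_func and basic_ops are placeholders, not code from the commit:

    from transformer_engine.pytorch.ops import FusedOperation

    def flatten(ops):
        """Flatten fused ops back to their constituent basic ops."""
        out = []
        for op in ops:
            if isinstance(op, FusedOperation):
                out.extend(op.basic_ops)  # constituent ops, in order
            else:
                out.append(op)
        return out

    # Invariant enforced by OperationFuser._fuse_ops: a fusion function may
    # group adjacent ops but must not reorder, drop, or duplicate them.
    fused = some_fusion_func(list(basic_ops), recipe=None)  # placeholders
    assert len(flatten(fused)) == len(basic_ops)
    assert all(a is b for a, b in zip(flatten(fused), basic_ops))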
     def maybe_fuse_ops(
         self,
@@ -424,12 +438,16 @@ class OperationFuser:
                 op.pre_first_fuser_forward()

         # Prepare basic op lists for fusions
-        forward_ops = [(op, [idx]) for idx, op in enumerate(self._basic_ops)]
-        backward_ops = list(reversed(forward_ops[first_op_requiring_backward:]))
-
-        # Fuse ops
-        self._forward_ops = self._fuse_forward_ops(forward_ops, recipe)
-        self._backward_ops = self._fuse_backward_ops(backward_ops, recipe)
+        self._forward_ops = OperationFuser._fuse_ops(
+            self._basic_ops,
+            OperationFuser.forward_fusion_functions,
+            recipe=recipe,
+        )
+        self._backward_ops = OperationFuser._fuse_ops(
+            self._basic_ops,
+            OperationFuser.backward_fusion_functions,
+            recipe=recipe,
+        )

         # Save current fusion params
         self.recipe_type, self.first_op_requiring_backward = fusion_params
@@ -491,3 +509,55 @@ class OperationFuser:
             *extra_inputs,
         )
         return forward_func(*args)
+
+
+def register_forward_fusion(
+    op_fusion_func: OperationFusionFunction,
+    prepend: bool = False,
+) -> None:
+    """Register function to perform operation fusion for forward pass.
+
+    The fusion function should have the following signature:
+
+        func(ops, *, recipe) -> updated ops
+
+    Parameters
+    ----------
+    op_fusion_func: function
+        Function that takes a list of operations and may substitute
+        them with fused operations.
+    prepend: bool, default = ``False``
+        Whether the operation fuser should apply this fusion function
+        first. The default is to apply it last.
+
+    """
+    if prepend:
+        OperationFuser.forward_fusion_functions.insert(0, op_fusion_func)
+    else:
+        OperationFuser.forward_fusion_functions.append(op_fusion_func)
+
+
+def register_backward_fusion(
+    op_fusion_func: OperationFusionFunction,
+    prepend: bool = False,
+) -> None:
+    """Register function to perform operation fusion for backward pass.
+
+    The fusion function should have the following signature:
+
+        func(ops, *, recipe) -> updated ops
+
+    Parameters
+    ----------
+    op_fusion_func: function
+        Function that takes a list of operations and may substitute
+        them with fused operations.
+    prepend: bool, default = ``False``
+        Whether the operation fuser should apply this fusion function
+        first. The default is to apply it last.
+
+    """
+    if prepend:
+        OperationFuser.backward_fusion_functions.insert(0, op_fusion_func)
+    else:
+        OperationFuser.backward_fusion_functions.append(op_fusion_func)
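Taken together with the exports in transformer_engine/pytorch/ops/__init__.py, downstream code hooks in like this; a usage sketch, with the fusion function body elided:

    from transformer_engine.pytorch.ops import register_forward_fusion

    def my_fusion(ops, *, recipe=None, **unused):
        # Substitute fused ops here, preserving basic-op order
        return ops

    # Default: applied after all built-in fusions
    register_forward_fusion(my_fusion)

    # Alternative: applied before the built-ins
    # register_forward_fusion(my_fusion, prepend=True)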