Unverified commit 72592763, authored by Tim Moon, committed by GitHub
Browse files

[PyTorch] Support user-defined op fusions (#2597)



* Expose option for custom op fusions

Refactor fusion functions to remove index bookkeeping. Refactor fused ops to use consistent operation order.
Signed-off-by: Tim Moon <tmoon@nvidia.com>

* Add tests for custom ops
Signed-off-by: Tim Moon <tmoon@nvidia.com>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci



* Fix linter warnings and numerical test failures
Signed-off-by: Tim Moon <tmoon@nvidia.com>

* Tweak pattern matching logic with fixed window sizes
Signed-off-by: Tim Moon <tmoon@nvidia.com>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci



* Use TF32 tols in fused op tests
Signed-off-by: Tim Moon <tmoon@nvidia.com>

* Review suggestion from @greptile-apps
Signed-off-by: Tim Moon <tmoon@nvidia.com>

* Backpropagate fixes from #2622
Signed-off-by: Tim Moon <tmoon@nvidia.com>

---------
Signed-off-by: Tim Moon <tmoon@nvidia.com>
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
parent a0a89a8e
...@@ -2329,13 +2329,13 @@ class TestFusedOps: ...@@ -2329,13 +2329,13 @@ class TestFusedOps:
backward_ops = model._module_groups[0]._backward_ops backward_ops = model._module_groups[0]._backward_ops
if with_quantization: if with_quantization:
assert len(backward_ops) == 2 assert len(backward_ops) == 2
assert isinstance(backward_ops[0][0], BackwardActivationBias) assert isinstance(backward_ops[0][0], te_ops.Quantize)
assert isinstance(backward_ops[1][0], te_ops.Quantize) assert isinstance(backward_ops[1][0], BackwardActivationBias)
else: else:
assert len(backward_ops) == 3 assert len(backward_ops) == 3
assert isinstance(backward_ops[0][0], act_type) assert isinstance(backward_ops[0][0], te_ops.Quantize)
assert isinstance(backward_ops[1][0], te_ops.Bias) assert isinstance(backward_ops[1][0], te_ops.Bias)
assert isinstance(backward_ops[2][0], te_ops.Quantize) assert isinstance(backward_ops[2][0], act_type)
# Expected numerical error # Expected numerical error
tols = dtype_tols(dtype) tols = dtype_tols(dtype)
...@@ -2930,3 +2930,317 @@ class TestSequentialModules: ...@@ -2930,3 +2930,317 @@ class TestSequentialModules:
if bias: if bias:
torch.testing.assert_close(to_cpu(ffn1.bias.grad), b1_ref.grad, **tols) torch.testing.assert_close(to_cpu(ffn1.bias.grad), b1_ref.grad, **tols)
torch.testing.assert_close(to_cpu(ffn2.bias.grad), b2_ref.grad, **tols) torch.testing.assert_close(to_cpu(ffn2.bias.grad), b2_ref.grad, **tols)
class TestCustomOps:
    """Tests for fusible operations that are defined outside Transformer Engine.

    Exercises three extension points of the operation-fuser API:
    a custom ``BasicOperation``, a custom forward-pass ``FusedOperation``
    (registered via ``te.ops.register_forward_fusion``), and a custom
    backward-pass ``FusedOperation`` (registered via
    ``te.ops.register_backward_fusion``).
    """

    def test_custom_basic_op(
        self,
        *,
        shape: Iterable[int] = (7, 5),
        dtype: torch.dtype = torch.float32,
        device: torch.device = "cuda",
    ) -> None:
        """Custom basic op: learnable scalar scale, checked against plain PyTorch."""

        class CustomScaleOp(te.ops.BasicOperation):
            """Custom op that applies a learnable scale"""

            def __init__(self) -> None:
                super().__init__()
                # Scalar (0-dim) learnable parameter, initialized to 1
                self.scale: torch.nn.Parameter
                scale = torch.ones((), dtype=dtype, device=device)
                scale = torch.nn.Parameter(scale)
                self.register_parameter("scale", scale)

            def op_forward(
                self,
                ctx: OperationContext,
                input_: torch.Tensor,
                prev_op_grad_output_quantizer: Optional[Quantizer],
                next_op_input_quantizer: Optional[Quantizer],
            ) -> torch.Tensor:
                # Stash what the backward pass needs; quantizers are ignored
                ctx.save_for_backward(self.scale, input_)
                return self.scale * input_

            def op_backward(
                self,
                ctx: OperationContext,
                grad_output: torch.Tensor,
            ) -> torch.Tensor:
                (
                    scale,
                    input_,
                ) = ctx.saved_tensors
                # d(scale) = <input, grad_output> reduced to a 0-dim tensor
                grad_scale = torch.inner(input_.reshape(-1), grad_output.reshape(-1))
                grad_scale = grad_scale.reshape(())
                grad_input = scale * grad_output
                # Returns (grad wrt input, grads wrt params) — format expected
                # by BasicOperation.op_backward
                return grad_input, (grad_scale,)

        # Random data
        x_ref, x_test = make_reference_and_test_tensors(
            shape,
            test_dtype=dtype,
            test_device=device,
        )
        w_ref, w_test = make_reference_and_test_tensors(
            (),
            test_dtype=dtype,
            test_device=device,
        )
        dy_ref, dy_test = make_reference_and_test_tensors(
            shape,
            test_dtype=dtype,
            test_device=device,
            requires_grad=False,
        )

        # Plain PyTorch implementation
        y_ref = w_ref * x_ref
        y_ref.backward(dy_ref)

        # Implementation with fusible operation. Identity ops surround the
        # custom op so it runs inside a multi-op fuser group.
        op = CustomScaleOp()
        forward = te.ops.Sequential(te.ops.Identity(), op, te.ops.Identity())
        with torch.no_grad():
            op.scale.copy_(w_test)
            del w_test
        y_test = forward(x_test)
        y_test.backward(dy_test)

        # Check results
        tols = dtype_tols(dtype)
        y_test = y_test.to(dtype=torch.float64, device="cpu")
        dx_test = x_test.grad.to(dtype=torch.float64, device="cpu")
        dw_test = op.scale.grad.to(dtype=torch.float64, device="cpu")
        torch.testing.assert_close(y_test, y_ref, **tols)
        torch.testing.assert_close(dx_test, x_ref.grad, **tols)
        torch.testing.assert_close(dw_test, w_ref.grad, **tols)

    def test_custom_forward_fused_op(
        self,
        *,
        shape: Iterable[int] = (7, 11),
        dtype: torch.dtype = torch.float32,
        device: torch.device = "cuda",
    ):
        """Custom fused op in forward pass: Linear + SiLU fused into one op."""

        class CustomForwardLinearSiLU(te.ops.FusedOperation):
            """Custom fused op for GEMM + SiLU"""

            # Class-level one-shot switch so the fusion is applied at most
            # once per test run
            _enabled = True

            def __init__(self, *, linear, silu) -> None:
                super().__init__((linear, silu))

            def fuser_forward(
                self,
                basic_op_ctxs: list[OperationContext],
                input_: torch.Tensor,
                **unused,
            ) -> torch.Tensor:
                weight = self.basic_ops[0].weight
                dtype = weight.dtype
                device = weight.device

                # Perform compute on CPU, because why not?
                x = input_.cpu()
                w = weight.cpu()
                y = torch.matmul(x, w.T)
                z = torch.nn.functional.silu(y)
                out = z.to(device=device)

                # Save state for linear backward
                linear_op_ctx = basic_op_ctxs[0]
                linear_op_ctx.save_for_backward(input_, weight)
                linear_op_ctx.with_quantized_compute = False
                linear_op_ctx.input_quantizer = None
                linear_op_ctx.weight_quantizer = None
                linear_op_ctx.grad_output_quantizer = None
                linear_op_ctx.grad_input_quantizer = None
                linear_op_ctx.dtype = dtype
                linear_op_ctx.input_requires_grad = True
                linear_op_ctx.weight_requires_grad = True

                # Save state for SiLU backward (pre-activation tensor)
                silu_op_ctx = basic_op_ctxs[1]
                silu_op_ctx.save_for_backward(y.to(device=device))
                silu_op_ctx.dtype = dtype
                silu_op_ctx.prev_op_grad_output_quantizer = None

                # Second element is per-basic-op extra outputs (none here)
                return out, [(), ()]

            @staticmethod
            def fuse_ops(
                ops: list[FusibleOperation],
                **unused,
            ) -> list[FusibleOperation]:
                """Apply fusion the first time this function is called"""
                # NOTE(review): assumes ops[0]/ops[1] are the Linear and SiLU
                # built below — valid only because the fusion runs once on
                # this specific model
                if CustomForwardLinearSiLU._enabled:
                    CustomForwardLinearSiLU._enabled = False
                    op = CustomForwardLinearSiLU(linear=ops[0], silu=ops[1])
                    return [op] + ops[2:]
                return ops

        # Random data
        x_ref, x_test = make_reference_and_test_tensors(
            shape,
            test_dtype=dtype,
            test_device=device,
        )
        w_ref, w_test = make_reference_and_test_tensors(
            (shape[-1], shape[-1]),
            test_dtype=dtype,
            test_device=device,
        )
        dy_ref, dy_test = make_reference_and_test_tensors(
            shape,
            test_dtype=dtype,
            test_device=device,
            requires_grad=False,
        )

        # Plain PyTorch implementation
        y_ref = torch.nn.functional.linear(x_ref, w_ref)
        y_ref = torch.nn.functional.silu(y_ref)
        y_ref.backward(dy_ref)

        # Implementation with fusible operation
        te.ops.register_forward_fusion(CustomForwardLinearSiLU.fuse_ops)
        model = te.ops.Sequential(
            te.ops.Linear(shape[-1], shape[-1], bias=False),
            te.ops.SiLU(),
        )
        with torch.no_grad():
            model[0].weight.copy_(w_test)
            del w_test
        y_test = model(x_test)
        y_test.backward(dy_test)

        # Check that forward operations have been fused
        forward_ops = model._module_groups[0]._forward_ops
        assert len(forward_ops) == 1
        assert isinstance(forward_ops[0][0], CustomForwardLinearSiLU)

        # Expected numerical error
        tols = dtype_tols(dtype)
        if dtype == torch.float32:
            tols = dtype_tols(torch.float16)  # TF32 GEMM

        # Check results
        y_test = y_test.to(dtype=torch.float64, device="cpu")
        dx_test = x_test.grad.to(dtype=torch.float64, device="cpu")
        dw_test = model[0].weight.grad.to(dtype=torch.float64, device="cpu")
        torch.testing.assert_close(y_test, y_ref, **tols)
        torch.testing.assert_close(dx_test, x_ref.grad, **tols)
        torch.testing.assert_close(dw_test, w_ref.grad, **tols)

    def test_custom_backward_fused_op(
        self,
        *,
        shape: Iterable[int] = (13, 5),
        dtype: torch.dtype = torch.float32,
        device: torch.device = "cuda",
    ):
        """Custom fused op in backward pass: scale + linear fused for dgrad."""

        class CustomBackwardLinearScale(te.ops.FusedOperation):
            """Custom fused op for backward linear + scale"""

            # Class-level one-shot switch so the fusion is applied at most
            # once per test run
            _enabled: bool = True

            def __init__(self, *, scale, linear) -> None:
                super().__init__((scale, linear))

            def fuser_backward(
                self,
                basic_op_ctxs: list[OperationContext],
                grad_output: torch.Tensor,
                **unused,
            ) -> torch.Tensor:
                # Load state from linear forward
                linear_op_ctx = basic_op_ctxs[1]
                x, w = linear_op_ctx.saved_tensors
                dtype = linear_op_ctx.dtype
                device = w.device

                # Perform compute in FP64 and apply scale before dgrad
                # GEMM instead of after
                scale = self.basic_ops[0].scale
                dy = grad_output.double()
                x = x.double()
                w = w.double()
                dx = torch.matmul(dy, scale * w)
                dw = torch.matmul(dy.T, x)
                dx = dx.to(dtype=dtype)
                dw = dw.to(dtype=dtype)
                # Returns (grad input, per-op param grads, per-op extra grads)
                return dx, [(), (dw,)], [(), ()]

            @staticmethod
            def fuse_ops(
                ops: list[FusibleOperation],
                **unused,
            ) -> list[FusibleOperation]:
                """Apply fusion the first time this function is called"""
                # NOTE(review): assumes ops[0]/ops[1] are the ConstantScale
                # and Linear built below — valid only because the fusion runs
                # once on this specific model
                if CustomBackwardLinearScale._enabled:
                    CustomBackwardLinearScale._enabled = False
                    op = CustomBackwardLinearScale(scale=ops[0], linear=ops[1])
                    return [op] + ops[2:]
                return ops

        # Random data
        x_ref, x_test = make_reference_and_test_tensors(
            shape,
            test_dtype=dtype,
            test_device=device,
        )
        w_ref, w_test = make_reference_and_test_tensors(
            (shape[-1], shape[-1]),
            test_dtype=dtype,
            test_device=device,
        )
        dy_ref, dy_test = make_reference_and_test_tensors(
            shape,
            test_dtype=dtype,
            test_device=device,
            requires_grad=False,
        )
        scale = 1.234

        # Plain PyTorch implementation
        y_ref = torch.nn.functional.linear(scale * x_ref, w_ref)
        y_ref.backward(dy_ref)

        # Implementation with fusible operation. prepend=True so this fusion
        # runs before the built-in backward fusions.
        te.ops.register_backward_fusion(CustomBackwardLinearScale.fuse_ops, prepend=True)
        model = te.ops.Sequential(
            te.ops.ConstantScale(scale),
            te.ops.Linear(shape[-1], shape[-1], bias=False),
        )
        with torch.no_grad():
            model[1].weight.copy_(w_test)
            del w_test
        y_test = model(x_test)
        y_test.backward(dy_test)

        # Check that backward operations have been fused
        backward_ops = model._module_groups[0]._backward_ops
        assert len(backward_ops) == 1
        assert isinstance(backward_ops[0][0], CustomBackwardLinearScale)

        # Expected numerical error
        tols = dtype_tols(dtype)
        if dtype == torch.float32:
            tols = dtype_tols(torch.float16)  # TF32 GEMM

        # Check results
        y_test = y_test.to(dtype=torch.float64, device="cpu")
        dx_test = x_test.grad.to(dtype=torch.float64, device="cpu")
        dw_test = model[1].weight.grad.to(dtype=torch.float64, device="cpu")
        torch.testing.assert_close(y_test, y_ref, **tols)
        torch.testing.assert_close(dx_test, x_ref.grad, **tols)
        torch.testing.assert_close(dw_test, w_ref.grad, **tols)
...@@ -8,7 +8,9 @@ This operation-based API is experimental and subject to change. ...@@ -8,7 +8,9 @@ This operation-based API is experimental and subject to change.
""" """
from transformer_engine.pytorch.ops.basic import * from .basic import *
from transformer_engine.pytorch.ops.linear import Linear from .fuser import register_backward_fusion, register_forward_fusion
from transformer_engine.pytorch.ops.op import FusibleOperation from .linear import Linear
from transformer_engine.pytorch.ops.sequential import Sequential from .op import BasicOperation, FusedOperation, FusibleOperation
from .sequential import Sequential
from . import fused
...@@ -4,39 +4,27 @@ ...@@ -4,39 +4,27 @@
"""Compound tensor operation supported by the operation fuser.""" """Compound tensor operation supported by the operation fuser."""
from .backward_activation_bias import ( from ..fuser import register_backward_fusion, register_forward_fusion
BackwardActivationBias, from .backward_activation_bias import BackwardActivationBias
fuse_backward_activation_bias, from .backward_add_rmsnorm import BackwardAddRMSNorm
) from .backward_linear_add import BackwardLinearAdd
from .backward_add_rmsnorm import ( from .backward_linear_scale import BackwardLinearScale
BackwardAddRMSNorm, from .forward_linear_bias_activation import ForwardLinearBiasActivation
fuse_backward_add_rmsnorm, from .forward_linear_bias_add import ForwardLinearBiasAdd
) from .forward_linear_scale_add import ForwardLinearScaleAdd
from .backward_linear_add import ( from .userbuffers_backward_linear import UserbuffersBackwardLinear
BackwardLinearAdd, from .userbuffers_forward_linear import UserbuffersForwardLinear
fuse_backward_linear_add,
)
from .backward_linear_scale import ( # Register forward fusions
BackwardLinearScale, register_forward_fusion(UserbuffersForwardLinear.fuse_forward_ops)
fuse_backward_linear_scale, register_forward_fusion(ForwardLinearBiasAdd.fuse_forward_ops)
) register_forward_fusion(ForwardLinearBiasActivation.fuse_forward_ops)
from .forward_linear_bias_activation import ( register_forward_fusion(ForwardLinearScaleAdd.fuse_forward_ops)
ForwardLinearBiasActivation,
fuse_forward_linear_bias_activation, # Register backward fusions
) register_backward_fusion(UserbuffersBackwardLinear.fuse_backward_ops)
from .forward_linear_bias_add import ( register_backward_fusion(BackwardLinearAdd.fuse_backward_ops)
ForwardLinearBiasAdd, register_backward_fusion(BackwardLinearScale.fuse_backward_ops)
fuse_forward_linear_bias_add, register_backward_fusion(BackwardActivationBias.fuse_backward_ops)
) register_backward_fusion(BackwardAddRMSNorm.fuse_backward_ops)
from .forward_linear_scale_add import (
ForwardLinearScaleAdd,
fuse_forward_linear_scale_add,
)
from .userbuffers_backward_linear import (
UserbuffersBackwardLinear,
fuse_userbuffers_backward_linear,
)
from .userbuffers_forward_linear import (
UserbuffersForwardLinear,
fuse_userbuffers_forward_linear,
)
...@@ -53,8 +53,8 @@ class BackwardActivationBias(FusedOperation): ...@@ -53,8 +53,8 @@ class BackwardActivationBias(FusedOperation):
]: ]:
# Get basic operation contexts # Get basic operation contexts
activation_op_ctx = basic_op_ctxs[0] bias_op_ctx = basic_op_ctxs[0]
bias_op_ctx = basic_op_ctxs[1] activation_op_ctx = basic_op_ctxs[1]
# Saved tensors from forward pass # Saved tensors from forward pass
(act_input,) = activation_op_ctx.saved_tensors (act_input,) = activation_op_ctx.saved_tensors
...@@ -79,68 +79,59 @@ class BackwardActivationBias(FusedOperation): ...@@ -79,68 +79,59 @@ class BackwardActivationBias(FusedOperation):
# Clear activation input tensor # Clear activation input tensor
clear_tensor_data(act_input) clear_tensor_data(act_input)
return dx, [(), (db,)], [(), ()] return dx, [(db,), ()], [(), ()]
@staticmethod
def fuse_backward_activation_bias( def fuse_backward_ops(
ops: list[tuple[FusibleOperation, list[int]]], ops: list[FusibleOperation],
recipe: Optional[Recipe], *,
) -> list[tuple[FusibleOperation, list[int]]]: recipe: Optional[Recipe] = None,
"""Fused backward dact + dbias + quantize **unused, # pylint: disable=unused-argument
) -> list[FusibleOperation]:
Parameters """Apply operation fusion for backward pass.
----------
ops : list of tuples Parameters
Backward pass operations and the indices of the corresponding ----------
basic operations. ops : list of FusibleOperation
recipe : Recipe, optional Backward pass operations.
Used quantization recipe recipe : Recipe, optional
Quantization recipe.
Returns
------- Returns
ops : list of tuples -------
Updated backward pass operations ops : list of FusibleOperation
Updated backward pass operations
"""
"""
# Check if recipe supports bias activation fusion
if recipe is None: # Check if recipe supports bias activation fusion
return ops if recipe is None:
return ops
# Scan through ops, fusing if possible
out = [] # Scan through ops, fusing if possible
window = [] out = []
while len(ops) >= 3: window, ops = ops[:3], ops[3:]
while len(window) == 3:
if (
isinstance(window[2], _fusible_activations)
and isinstance(window[1], Bias)
and window[0].get_grad_output_quantizer() is not None
):
# Construct fused op if window matches pattern
op = BackwardActivationBias(bias=window[1], activation=window[2])
window = [window[0], op]
else:
# Shift window if window doesn't match pattern
out.extend(window[:-2])
window = window[-2:]
# Adjust window to expected size
out.extend(window[:-3])
window = window[-3:]
while ops and len(window) < 3:
window.append(ops[0])
ops = ops[1:]
# Return list of ops
out.extend(window) out.extend(window)
return out
# Check if first op is a supported activation
window, ops = ops[:1], ops[1:]
op, _ = window[0]
if not isinstance(op, _fusible_activations):
continue
# Check if second op is bias
op, _ = ops[0]
if not isinstance(op, Bias):
continue
# Check if third op has a grad input quantizer
op, _ = ops[1]
if not op.num_quantizers("backward") > 0:
continue
window.extend(ops[:1])
ops = ops[1:]
# Replace window with fused op
op = BackwardActivationBias(
activation=window[0][0],
bias=window[1][0],
)
basic_op_idxs = [basic_op_idxs[0] for _, basic_op_idxs in window]
window = [(op, basic_op_idxs)]
# Return list of ops
out.extend(window)
out.extend(ops)
return out
...@@ -42,7 +42,7 @@ class BackwardAddRMSNorm(FusedOperation): ...@@ -42,7 +42,7 @@ class BackwardAddRMSNorm(FusedOperation):
# Get basic operations # Get basic operations
rmsnorm_op = self.basic_ops[1] rmsnorm_op = self.basic_ops[1]
rmsnorm_op_ctx = basic_op_ctxs[0] rmsnorm_op_ctx = basic_op_ctxs[1]
# Saved tensors from forward pass # Saved tensors from forward pass
x, rstdevs = rmsnorm_op_ctx.saved_tensors x, rstdevs = rmsnorm_op_ctx.saved_tensors
...@@ -53,7 +53,7 @@ class BackwardAddRMSNorm(FusedOperation): ...@@ -53,7 +53,7 @@ class BackwardAddRMSNorm(FusedOperation):
# Check input tensors # Check input tensors
dtype = rmsnorm_op_ctx.dtype dtype = rmsnorm_op_ctx.dtype
extra_grad = basic_op_grad_extra_outputs[1][0] extra_grad = basic_op_grad_extra_outputs[0][0]
dy = maybe_dequantize(grad_output.contiguous(), dtype).view(x.size()) dy = maybe_dequantize(grad_output.contiguous(), dtype).view(x.size())
w = maybe_dequantize(rmsnorm_op.weight, dtype).view((inner_dim,)) w = maybe_dequantize(rmsnorm_op.weight, dtype).view((inner_dim,))
add = maybe_dequantize(extra_grad.contiguous(), dtype).view(x.size()) add = maybe_dequantize(extra_grad.contiguous(), dtype).view(x.size())
...@@ -77,57 +77,51 @@ class BackwardAddRMSNorm(FusedOperation): ...@@ -77,57 +77,51 @@ class BackwardAddRMSNorm(FusedOperation):
grad_input = dx.view(grad_output.size()) grad_input = dx.view(grad_output.size())
grad_weight = dw.view(weight_dims) grad_weight = dw.view(weight_dims)
return grad_input, [(grad_weight,), ()], [(), ()] return grad_input, [(), (grad_weight,)], [(), ()]
@staticmethod
def fuse_backward_add_rmsnorm( def fuse_backward_ops(
ops: list[tuple[FusibleOperation, list[int]]], ops: list[FusibleOperation],
) -> list[tuple[FusibleOperation, list[int]]]: **unused, # pylint: disable=unused-argument
"""Fused backward RMNorm + add ) -> list[FusibleOperation]:
"""Apply operation fusion for backward pass.
Parameters
---------- Parameters
ops : list of tuples ----------
Backward pass operations and the indices of the corresponding ops : list of FusibleOperation
basic operations. Backward pass operations.
Returns Returns
------- -------
ops : list of tuples ops : list of FusibleOperation
Updated backward pass operations Updated backward pass operations
""" """
# Scan through ops, fusing if possible # Scan through ops, fusing if possible
out = [] out = []
window = [] window, ops = ops[:2], ops[2:]
while len(ops) >= 2: while len(window) == 2:
if (
isinstance(window[0], MakeExtraOutput)
and isinstance(window[1], RMSNorm)
and not window[0]._in_place
):
# Construct fused op if window matches pattern
op = BackwardAddRMSNorm(add=window[0], rmsnorm=window[1])
window = [op]
else:
# Shift window if window doesn't match pattern
out.extend(window[:-1])
window = window[-1:]
# Adjust window to expected size
out.extend(window[:-2])
window = window[-2:]
while ops and len(window) < 2:
window.append(ops[0])
ops = ops[1:]
# Return list of ops
out.extend(window) out.extend(window)
return out
# Check if first op is linear
window, ops = ops[:1], ops[1:]
op, _ = window[0]
if not isinstance(op, RMSNorm):
continue
# Check if second op is "make extra output"
op, _ = ops[0]
if not isinstance(op, MakeExtraOutput):
continue
if op._in_place:
continue
window.extend(ops[:1])
ops = ops[1:]
# Replace window with fused op
op = BackwardAddRMSNorm(
rmsnorm=window[0][0],
add=window[1][0],
)
basic_op_idxs = [basic_op_idxs[0] for _, basic_op_idxs in window]
window = [(op, basic_op_idxs)]
# Return list of ops
out.extend(window)
out.extend(ops)
return out
...@@ -45,7 +45,7 @@ class BackwardLinearAdd(FusedOperation): ...@@ -45,7 +45,7 @@ class BackwardLinearAdd(FusedOperation):
# Get basic operations # Get basic operations
linear_op = self.basic_ops[1] linear_op = self.basic_ops[1]
linear_op_ctx = basic_op_ctxs[0] linear_op_ctx = basic_op_ctxs[1]
# Saved tensors from forward pass # Saved tensors from forward pass
(x_local, w) = linear_op_ctx.saved_tensors (x_local, w) = linear_op_ctx.saved_tensors
...@@ -71,7 +71,7 @@ class BackwardLinearAdd(FusedOperation): ...@@ -71,7 +71,7 @@ class BackwardLinearAdd(FusedOperation):
accumulate_into_main_grad = False accumulate_into_main_grad = False
# Linear backward pass # Linear backward pass
grad_input = basic_op_grad_extra_outputs[1][0] grad_input = basic_op_grad_extra_outputs[0][0]
grad_input, grad_weight = BasicLinear._functional_backward( grad_input, grad_weight = BasicLinear._functional_backward(
grad_output=grad_output, grad_output=grad_output,
input=x_local, input=x_local,
...@@ -109,61 +109,60 @@ class BackwardLinearAdd(FusedOperation): ...@@ -109,61 +109,60 @@ class BackwardLinearAdd(FusedOperation):
zero=getattr(weight_param, "zero_out_wgrad", False), zero=getattr(weight_param, "zero_out_wgrad", False),
) )
return grad_input, [(grad_weight,), ()], [(), ()] return grad_input, [(), (grad_weight,)], [(), ()]
@staticmethod
def fuse_backward_linear_add( def fuse_backward_ops(
ops: list[tuple[FusibleOperation, list[int]]], ops: list[FusibleOperation],
) -> list[tuple[FusibleOperation, list[int]]]: **unused, # pylint: disable=unused-argument
"""Fused backward dgrad GEMM + add ) -> list[FusibleOperation]:
"""Apply operation fusion for backward pass.
Parameters
---------- Parameters
ops : list of tuples ----------
Backward pass operations and the indices of the corresponding ops : list of FusibleOperation
basic operations. Backward pass operations.
Returns Returns
------- -------
ops : list of tuples ops : list of FusibleOperation
Updated backward pass operations Updated backward pass operations
""" """
# Scan through ops, fusing if possible # Scan through ops, fusing if possible
out = [] out = []
window = [] window, ops = ops[:2], ops[2:]
while len(ops) >= 2: while len(window) == 2:
# Check if window matches pattern
matches_pattern = True
if not (isinstance(window[0], MakeExtraOutput) and isinstance(window[1], BasicLinear)):
matches_pattern = False
elif not window[0]._in_place:
# Fused op accumulates grad input in-place
matches_pattern = False
elif window[1].tensor_parallel_mode == "column":
# Column tensor-parallelism requires communication
# after the dgrad GEMM
matches_pattern = False
if matches_pattern:
# Construct fused op if window matches pattern
op = BackwardLinearAdd(backward_add=window[0], linear=window[1])
window = [op]
else:
# Shift window if window doesn't match pattern
out.extend(window[:-1])
window = window[-1:]
# Adjust window to expected size
out.extend(window[:-2])
window = window[-2:]
while ops and len(window) < 2:
window.append(ops[0])
ops = ops[1:]
# Return list of ops
out.extend(window) out.extend(window)
return out
# Check if first op is linear
window, ops = ops[:1], ops[1:]
op, _ = window[0]
if not isinstance(op, BasicLinear):
continue
if op.tensor_parallel_mode == "column":
# Row tensor-parallelism requires communication after the
# GEMM
continue
# Check if second op is "make extra output"
op, _ = ops[0]
if not isinstance(op, MakeExtraOutput):
continue
if not op._in_place:
continue
window.extend(ops[:1])
ops = ops[1:]
# Replace window with fused op
op = BackwardLinearAdd(
linear=window[0][0],
backward_add=window[1][0],
)
basic_op_idxs = [basic_op_idxs[0] for _, basic_op_idxs in window]
window = [(op, basic_op_idxs)]
# Return list of ops
out.extend(window)
out.extend(ops)
return out
...@@ -45,7 +45,7 @@ class BackwardLinearScale(FusedOperation): ...@@ -45,7 +45,7 @@ class BackwardLinearScale(FusedOperation):
# Get basic operations # Get basic operations
linear_op = self.basic_ops[0] linear_op = self.basic_ops[0]
linear_op_ctx = basic_op_ctxs[1] linear_op_ctx = basic_op_ctxs[0]
scale_op = self.basic_ops[1] scale_op = self.basic_ops[1]
# Saved tensors from forward pass # Saved tensors from forward pass
...@@ -109,58 +109,57 @@ class BackwardLinearScale(FusedOperation): ...@@ -109,58 +109,57 @@ class BackwardLinearScale(FusedOperation):
zero=getattr(weight_param, "zero_out_wgrad", False), zero=getattr(weight_param, "zero_out_wgrad", False),
) )
return grad_input, [(), (grad_weight,)], [(), ()] return grad_input, [(grad_weight,), ()], [(), ()]
@staticmethod
def fuse_backward_linear_scale( def fuse_backward_ops(
ops: list[tuple[FusibleOperation, list[int]]], ops: list[FusibleOperation],
) -> list[tuple[FusibleOperation, list[int]]]: **unused, # pylint: disable=unused-argument
"""Fused backward dgrad GEMM + constant scale ) -> list[FusibleOperation]:
"""Apply operation fusion for backward pass.
Parameters
---------- Parameters
ops : list of tuples ----------
Backward pass operations and the indices of the corresponding ops : list of FusibleOperation
basic operations. Backward pass operations.
Returns Returns
------- -------
ops : list of tuples ops : list of FusibleOperation
Updated backward pass operations Updated backward pass operations
""" """
# Scan through ops, fusing if possible # Scan through ops, fusing if possible
out = [] out = []
window = [] window, ops = ops[:2], ops[2:]
while len(ops) >= 2: while len(window) == 2:
# Check if window matches pattern
matches_pattern = True
if not (isinstance(window[0], BasicLinear) and isinstance(window[1], ConstantScale)):
matches_pattern = False
elif window[0].tensor_parallel_mode == "column":
# Column tensor-parallelism requires communication
# after the dgrad GEMM
matches_pattern = False
if matches_pattern:
# Construct fused op if window matches pattern
op = BackwardLinearScale(linear=window[0], scale=window[1])
window = [op]
else:
# Shift window if window doesn't match pattern
out.extend(window[:-1])
window = window[-1:]
# Adjust window to expected size
out.extend(window[:-2])
window = window[-2:]
while ops and len(window) < 2:
window.append(ops[0])
ops = ops[1:]
# Return list of ops
out.extend(window) out.extend(window)
return out
# Check if first op is constant scale
window, ops = ops[:1], ops[1:]
op, _ = window[0]
if not isinstance(op, ConstantScale):
continue
# Check if second op is linear
op, _ = ops[0]
if not isinstance(op, BasicLinear):
continue
if op.tensor_parallel_mode == "column":
# Column tensor-parallelism requires communication after the dgrad GEMM
continue
window.extend(ops[:1])
ops = ops[1:]
# Replace window with fused op
op = BackwardLinearScale(
scale=window[0][0],
linear=window[1][0],
)
basic_op_idxs = [basic_op_idxs[0] for _, basic_op_idxs in window]
window = [(op, basic_op_idxs)]
# Return list of ops
out.extend(window)
out.extend(ops)
return out
...@@ -134,62 +134,63 @@ class ForwardLinearBiasActivation(FusedOperation): ...@@ -134,62 +134,63 @@ class ForwardLinearBiasActivation(FusedOperation):
return output, [() for _ in range(len(self.basic_ops))] return output, [() for _ in range(len(self.basic_ops))]
@staticmethod
def fuse_forward_linear_bias_activation( def fuse_forward_ops(
ops: list[tuple[FusibleOperation, list[int]]], ops: list[FusibleOperation],
) -> list[tuple[FusibleOperation, list[int]]]: **unused, # pylint: disable=unused-argument
"""Fuse forward GEMM + bias + activation ) -> list[FusibleOperation]:
"""Apply operation fusion for forward pass.
Parameters
---------- Parameters
ops : list of tuples ----------
Forward pass operations and the indices of the corresponding ops : list of FusibleOperation
basic operations. Forward pass operations.
Returns Returns
------- -------
ops : list of tuples ops : list of FusibleOperation
Updated forward pass operations Updated forward pass operations
""" """
# Scan through ops, fusing if possible # Scan through ops, fusing if possible
out = [] out = []
window = [] window, ops = ops[:2], ops[2:]
while len(ops) >= 2: while len(window) == 2:
# Check if window matches pattern
matches_pattern = True
if not (isinstance(window[0], BasicLinear) and isinstance(window[1], Bias)):
matches_pattern = False
elif window[0].tensor_parallel_mode == "row":
# Row tensor-parallelism requires communication after
# the GEMM
matches_pattern = False
elif window[0].weight.dtype not in (torch.float16, torch.bfloat16):
# cuBLAS only supports fused GEMM+bias+activation with
# FP16 and BF16 output
matches_pattern = False
if matches_pattern:
# Construct fused op if window matches pattern
op = ForwardLinearBiasActivation(
linear=window[0],
bias=window[1],
activation=None,
)
window = [op]
else:
# Shift window if window doesn't match pattern
out.extend(window[:-1])
window = window[-1:]
# Adjust window to expected size
out.extend(window[:-2])
window = window[-2:]
while ops and len(window) < 2:
window.append(ops[0])
ops = ops[1:]
# Return list of ops
out.extend(window) out.extend(window)
return out
# Check if first op is linear
window, ops = ops[:1], ops[1:]
op1, _ = window[0]
if not isinstance(op1, BasicLinear):
continue
if op1.tensor_parallel_mode == "row":
# Row tensor-parallelism requires communication after the
# GEMM
continue
if op1.weight.dtype not in (torch.float16, torch.bfloat16):
# cuBLAS only supports fused GEMM+bias+activation with
# FP16 and BF16 output
continue
# Check if second op is bias
op2, _ = ops[0]
if not isinstance(op2, Bias):
continue
window.extend(ops[:1])
ops = ops[1:]
# Replace window with fused op
op = ForwardLinearBiasActivation(
linear=window[0][0],
bias=window[1][0],
activation=None,
)
basic_op_idxs = [basic_op_idxs[0] for _, basic_op_idxs in window]
window = [(op, basic_op_idxs)]
# Return list of ops
out.extend(window)
out.extend(ops)
return out
...@@ -131,72 +131,63 @@ class ForwardLinearBiasAdd(FusedOperation): ...@@ -131,72 +131,63 @@ class ForwardLinearBiasAdd(FusedOperation):
return output, [() for _ in range(len(self.basic_ops))] return output, [() for _ in range(len(self.basic_ops))]
@staticmethod
def fuse_forward_ops(
ops: list[FusibleOperation],
**unused, # pylint: disable=unused-argument
) -> list[FusibleOperation]:
"""Apply operation fusion for forward pass.
Parameters
----------
ops : list of FusibleOperation
Forward pass operations.
Returns
-------
ops : list of FusibleOperation
Updated forward pass operations
"""
# Scan through ops, fusing if possible
out = []
window = []
while ops:
# Shift window
out.extend(window)
window = [ops[0]]
ops = ops[1:]
def fuse_forward_linear_bias_add( # Check if first op is linear
ops: list[tuple[FusibleOperation, list[int]]], if not isinstance(window[0], BasicLinear):
) -> list[tuple[FusibleOperation, list[int]]]: continue
"""Fuse forward GEMM + bias + add if window[0].tensor_parallel_mode == "row":
# Row tensor-parallelism requires communication after
Parameters # the GEMM
---------- continue
ops : list of tuples linear = window[0]
Forward pass operations and the indices of the corresponding
basic operations.
Returns # Check if next op is bias
------- bias = None
ops : list of tuples if ops and isinstance(ops[0], Bias):
Updated forward pass operations window.append(ops[0])
ops = ops[1:]
bias = window[-1]
# Check if next op is in-place add extra input
if ops and isinstance(ops[0], AddExtraInput) and ops[0]._in_place:
window.append(ops[0])
ops = ops[1:]
add = window[-1]
else:
continue
""" # Replace window with fused op
op = ForwardLinearBiasAdd(linear=linear, bias=bias, add=add)
window = [op]
# Scan through ops, fusing if possible # Return list of ops
out = []
window = []
while len(ops) >= 2:
out.extend(window) out.extend(window)
return out
# Check if first op is linear
window, ops = ops[:1], ops[1:]
op, _ = window[0]
if not isinstance(op, BasicLinear):
continue
if op.tensor_parallel_mode == "row":
# Row tensor-parallelism requires communication after the
# GEMM
continue
linear = op
op, _ = ops[0]
# Check if next op is bias
bias = None
if isinstance(op, Bias):
bias = op
window.extend(ops[:1])
ops = ops[1:]
if len(ops) == 0:
continue
op, _ = ops[0]
# Check if next op is in-place add extra input
if not isinstance(op, AddExtraInput):
continue
if not op._in_place:
continue
add = op
window.extend(ops[:1])
ops = ops[1:]
# Replace window with fused op
op = ForwardLinearBiasAdd(
linear=linear,
bias=bias,
add=add,
)
basic_op_idxs = [basic_op_idxs[0] for _, basic_op_idxs in window]
window = [(op, basic_op_idxs)]
# Return list of ops
out.extend(window)
out.extend(ops)
return out
...@@ -110,70 +110,66 @@ class ForwardLinearScaleAdd(FusedOperation): ...@@ -110,70 +110,66 @@ class ForwardLinearScaleAdd(FusedOperation):
return output, [() for _ in range(len(self.basic_ops))] return output, [() for _ in range(len(self.basic_ops))]
@staticmethod
def fuse_forward_linear_scale_add( def fuse_forward_ops(
ops: list[tuple[FusibleOperation, list[int]]], ops: list[FusibleOperation],
) -> list[tuple[FusibleOperation, list[int]]]: **unused, # pylint: disable=unused-argument
"""Fuse forward GEMM + scale + add ) -> list[FusibleOperation]:
"""Apply operation fusion for forward pass.
Parameters
---------- Parameters
ops : list of tuples ----------
Forward pass operations and the indices of the corresponding ops : list of FusibleOperation
basic operations. Forward pass operations.
Returns Returns
------- -------
ops : list of tuples ops : list of FusibleOperation
Updated forward pass operations Updated forward pass operations
""" """
# Scan through ops, fusing if possible # Scan through ops, fusing if possible
out = [] out = []
window = [] window, ops = ops[:3], ops[3:]
while len(ops) >= 3: while len(window) == 3:
# Check if window matches pattern
matches_pattern = True
if not (
isinstance(window[0], BasicLinear)
and isinstance(window[1], ConstantScale)
and isinstance(window[2], AddExtraInput)
):
matches_pattern = False
elif window[0].tensor_parallel_mode == "row":
# Row tensor-parallelism requires communication after
# the GEMM
matches_pattern = False
elif not window[2]._in_place:
# Fused op accumulates output in-place
matches_pattern = False
if matches_pattern:
# Construct fused op if window matches pattern
op = ForwardLinearScaleAdd(
linear=window[0],
scale=window[1],
add=window[2],
)
window = [op]
else:
# Shift window if window doesn't match pattern
out.extend(window[:-2])
window = window[-2:]
# Adjust window to expected size
out.extend(window[:-3])
window = window[-3:]
while ops and len(window) < 3:
window.append(ops[0])
ops = ops[1:]
# Return list of ops
out.extend(window) out.extend(window)
return out
# Check if first op is linear
window, ops = ops[:1], ops[1:]
op, _ = window[0]
if not isinstance(op, BasicLinear):
continue
if op.tensor_parallel_mode == "row":
# Row tensor-parallelism requires communication after the
# GEMM
continue
linear = op
op, _ = ops[0]
# Check if next op is constant scale
if not isinstance(op, ConstantScale):
continue
scale = op
window.extend(ops[:1])
ops = ops[1:]
op, _ = ops[0]
# Check if next op is in-place add extra input
if not isinstance(op, AddExtraInput):
continue
if not op._in_place:
continue
add = op
window.extend(ops[:1])
ops = ops[1:]
# Replace window with fused op
op = ForwardLinearScaleAdd(
linear=linear,
scale=scale,
add=add,
)
basic_op_idxs = [basic_op_idxs[0] for _, basic_op_idxs in window]
window = [(op, basic_op_idxs)]
# Return list of ops
out.extend(window)
out.extend(ops)
return out
...@@ -503,7 +503,7 @@ class UserbuffersBackwardLinear(FusedOperation): ...@@ -503,7 +503,7 @@ class UserbuffersBackwardLinear(FusedOperation):
# Get basic operations # Get basic operations
idx = self._op_idxs["linear"] idx = self._op_idxs["linear"]
linear_op = self.basic_ops[idx] linear_op = self.basic_ops[idx]
linear_op_ctx = basic_op_ctxs[-1] linear_op_ctx = basic_op_ctxs[0]
bias_op = None bias_op = None
if self._op_idxs["bias"] is not None: if self._op_idxs["bias"] is not None:
idx = self._op_idxs["bias"] idx = self._op_idxs["bias"]
...@@ -578,99 +578,84 @@ class UserbuffersBackwardLinear(FusedOperation): ...@@ -578,99 +578,84 @@ class UserbuffersBackwardLinear(FusedOperation):
grad_params[self._op_idxs["linear"]] = (grad_weight,) grad_params[self._op_idxs["linear"]] = (grad_weight,)
if bias_op is not None: if bias_op is not None:
grad_params[self._op_idxs["bias"]] = (grad_bias,) grad_params[self._op_idxs["bias"]] = (grad_bias,)
grad_params.reverse()
grad_extra_inputs = [() for _ in range(len(self.basic_ops))] grad_extra_inputs = [() for _ in range(len(self.basic_ops))]
return grad_input, grad_params, grad_extra_inputs return grad_input, grad_params, grad_extra_inputs
@staticmethod
def fuse_backward_ops(
ops: list[FusibleOperation],
**unused, # pylint: disable=unused-argument
) -> list[FusibleOperation]:
"""Apply operation fusion for backward pass.
def fuse_userbuffers_backward_linear( Parameters
ops: list[tuple[FusibleOperation, list[int]]], ----------
) -> list[tuple[FusibleOperation, list[int]]]: ops : list of FusibleOperation
"""Substitute linear operations with Userbuffers implementation Backward pass operations.
recipe : Recipe, optional
Quantization recipe.
Parameters Returns
---------- -------
ops : list of tuples ops : list of FusibleOperation
Backward pass operations and the indices of the corresponding Updated backward pass operations
basic operations.
Returns """
-------
ops : list of tuples
Updated backward pass operations
""" # Return immediately if environment is not distributed
if not torch.distributed.is_initialized() or torch.distributed.get_world_size() == 1:
return ops
# Return immediately if environment is not distributed # Scan through ops, fusing if possible
if not torch.distributed.is_initialized() or torch.distributed.get_world_size() == 1: out = []
return ops window = []
while ops:
# Sliding window in list of ops
window = [] # Shift window
out.extend(window)
def peek_next_op() -> Optional[FusibleOperation]: window, ops = ops[:1], ops[1:]
"""Get next op in list of ops"""
nonlocal ops # Check if first op is linear
if not ops: if not isinstance(window[0], BasicLinear):
return None
return ops[-1][0]
def pop_next_op() -> FusibleOperation:
"""Remove next op from list of ops and add to sliding window"""
nonlocal ops, window
window.insert(0, ops[-1])
ops = ops[:-1]
return window[0][0]
# Scan through ops in reverse order, fusing if possible
out_reversed = []
while ops:
out_reversed.extend(reversed(window))
window.clear()
# Check if next op is linear
next_op = pop_next_op()
if not isinstance(next_op, BasicLinear):
continue
linear = next_op
if linear._userbuffers_options is None:
continue
# Check if next op is bias
bias = None
if linear.tensor_parallel_mode != "row" and isinstance(peek_next_op(), Bias):
bias = pop_next_op()
# Check if next op is reduce-scatter
reduce_scatter = None
if linear.tensor_parallel_mode is None and isinstance(peek_next_op(), ReduceScatter):
reduce_scatter = pop_next_op()
# Check for invalid combinations
if reduce_scatter is None:
if linear.tensor_parallel_mode is None:
continue
if linear.tensor_parallel_size == 1:
continue
if linear.tensor_parallel_mode == "row" and bias is not None:
continue
else:
if linear.tensor_parallel_mode is not None:
continue continue
if reduce_scatter.process_group_size == 1: linear = window[0]
if linear._userbuffers_options is None:
continue continue
# Replace window with fused op # Check if next op is bias
op = UserbuffersBackwardLinear( bias = None
linear=linear, if linear.tensor_parallel_mode != "row" and ops and isinstance(ops[0], Bias):
bias=bias, bias, ops = ops[0], ops[1:]
reduce_scatter=reduce_scatter, window.append(bias)
)
basic_op_idxs = [basic_op_idxs[0] for _, basic_op_idxs in window] # Check if next op is reduce-scatter
window = [(op, basic_op_idxs)] reduce_scatter = None
if linear.tensor_parallel_mode is None and ops and isinstance(ops[0], ReduceScatter):
# Return list of ops reduce_scatter, ops = ops[0], ops[1:]
out_reversed.extend(reversed(window)) window.append(reduce_scatter)
out = out_reversed
out.reverse() # Check for invalid combinations
return out if reduce_scatter is None:
if linear.tensor_parallel_mode is None:
continue
if linear.tensor_parallel_size == 1:
continue
if linear.tensor_parallel_mode == "row" and bias is not None:
continue
else:
if linear.tensor_parallel_mode is not None:
continue
if reduce_scatter.process_group_size == 1:
continue
# Replace window with fused op
op = UserbuffersBackwardLinear(
linear=linear,
bias=bias,
reduce_scatter=reduce_scatter,
)
window = [op]
# Return list of ops
out.extend(window)
return out
...@@ -369,93 +369,79 @@ class UserbuffersForwardLinear(FusedOperation): ...@@ -369,93 +369,79 @@ class UserbuffersForwardLinear(FusedOperation):
return output, [() for _ in range(len(self.basic_ops))] return output, [() for _ in range(len(self.basic_ops))]
@staticmethod
def fuse_forward_ops(
ops: list[FusibleOperation],
**unused, # pylint: disable=unused-argument
) -> list[FusibleOperation]:
"""Apply operation fusion for forward pass.
def fuse_userbuffers_forward_linear( Parameters
ops: list[tuple[FusibleOperation, list[int]]], ----------
) -> list[tuple[FusibleOperation, list[int]]]: ops : list of FusibleOperation
"""Substitute linear operations with Userbuffers implementation Forward pass operations.
Parameters
----------
ops : list of tuples
Forward pass operations and the indices of the corresponding
basic operations.
Returns
-------
ops : list of tuples
Updated forward pass operations
""" Returns
-------
ops : list of FusibleOperation
Updated forward pass operations
# Return immediately if environment is not distributed """
if not torch.distributed.is_initialized() or torch.distributed.get_world_size() == 1:
return ops
# Sliding window in list of ops
window = []
def peek_next_op() -> Optional[FusibleOperation]:
"""Get next op in list of ops"""
nonlocal ops
if not ops:
return None
return ops[0][0]
def pop_next_op() -> FusibleOperation:
"""Remove next op from list of ops and add to sliding window"""
nonlocal ops, window
window.append(ops[0])
ops = ops[1:]
return window[-1][0]
# Scan through ops, fusing if possible
out = []
while ops:
out.extend(window)
window.clear()
# Check if next op is linear # Return immediately if environment is not distributed
next_op = pop_next_op() if not torch.distributed.is_initialized() or torch.distributed.get_world_size() == 1:
if not isinstance(next_op, BasicLinear): return ops
continue
linear = next_op
if linear._userbuffers_options is None:
continue
# Check if next op is bias # Scan through ops, fusing if possible
bias = None out = []
if linear.tensor_parallel_mode != "row" and isinstance(peek_next_op(), Bias): window = []
bias = pop_next_op() while ops:
# Check if next op is reduce-scatter # Shift window
reduce_scatter = None out.extend(window)
if linear.tensor_parallel_mode is None and isinstance(peek_next_op(), ReduceScatter): window, ops = ops[:1], ops[1:]
reduce_scatter = pop_next_op()
# Check for invalid combinations # Check if first op is linear
if reduce_scatter is None: if not isinstance(window[0], BasicLinear):
if linear.tensor_parallel_mode is None:
continue
if linear.tensor_parallel_size == 1:
continue
if linear.tensor_parallel_mode == "row" and bias is not None:
continue
else:
if linear.tensor_parallel_mode is not None:
continue continue
if reduce_scatter.process_group_size == 1: linear = window[0]
if linear._userbuffers_options is None:
continue continue
# Replace window with fused op # Check if next op is bias
op = UserbuffersForwardLinear( bias = None
linear=linear, if linear.tensor_parallel_mode != "row" and ops and isinstance(ops[0], Bias):
bias=bias, bias, ops = ops[0], ops[1:]
reduce_scatter=reduce_scatter, window.append(bias)
)
basic_op_idxs = [basic_op_idxs[0] for _, basic_op_idxs in window] # Check if next op is reduce-scatter
window = [(op, basic_op_idxs)] reduce_scatter = None
if linear.tensor_parallel_mode is None and ops and isinstance(ops[0], ReduceScatter):
reduce_scatter, ops = ops[0], ops[1:]
window.append(reduce_scatter)
# Check for invalid combinations
if reduce_scatter is None:
if linear.tensor_parallel_mode is None:
continue
if linear.tensor_parallel_size == 1:
continue
if linear.tensor_parallel_mode == "row" and bias is not None:
continue
else:
if linear.tensor_parallel_mode is not None:
continue
if reduce_scatter.process_group_size == 1:
continue
# Replace window with fused op
op = UserbuffersForwardLinear(
linear=linear,
bias=bias,
reduce_scatter=reduce_scatter,
)
window = [op]
# Return list of ops # Return list of ops
out.extend(window) out.extend(window)
return out return out
...@@ -5,33 +5,20 @@ ...@@ -5,33 +5,20 @@
"""Manager class for a pipeline of fusible operations.""" """Manager class for a pipeline of fusible operations."""
from __future__ import annotations from __future__ import annotations
from collections.abc import Callable, Iterable from collections.abc import Callable, Iterable, Sequence
from typing import Any, Optional
import itertools import itertools
from typing import Any, Optional, TypeAlias
import torch import torch
from transformer_engine.pytorch.quantization import FP8GlobalStateManager, Recipe, DelayedScaling from ..quantization import FP8GlobalStateManager, Recipe, DelayedScaling
from transformer_engine.pytorch.ops.op import ( from ..quantized_tensor import prepare_for_saving, restore_from_saved
from .op import (
BasicOperation, BasicOperation,
FusibleOperation, FusibleOperation,
FusedOperation,
OperationContext, OperationContext,
) )
from transformer_engine.pytorch.ops.fused import (
fuse_backward_activation_bias,
fuse_backward_add_rmsnorm,
fuse_backward_linear_add,
fuse_backward_linear_scale,
fuse_forward_linear_bias_activation,
fuse_forward_linear_bias_add,
fuse_forward_linear_scale_add,
fuse_userbuffers_backward_linear,
fuse_userbuffers_forward_linear,
)
from transformer_engine.pytorch.quantized_tensor import (
prepare_for_saving,
restore_from_saved,
)
def _split_tuple(t: tuple, idx: int) -> tuple[tuple, tuple]: def _split_tuple(t: tuple, idx: int) -> tuple[tuple, tuple]:
...@@ -57,6 +44,12 @@ def _is_graph_capturing() -> bool: ...@@ -57,6 +44,12 @@ def _is_graph_capturing() -> bool:
return _is_graph_capturing_function() return _is_graph_capturing_function()
# Type alias for a function that may perform operation fusion
OperationFusionFunction: TypeAlias = (
"Callable[tuple[list[FusibleOperation], ...], list[FusibleOperation]]"
)
class _OperationFuserAutogradFunction(torch.autograd.Function): class _OperationFuserAutogradFunction(torch.autograd.Function):
"""Autograd function for a pipeline of operations """Autograd function for a pipeline of operations
...@@ -241,7 +234,7 @@ class _OperationFuserAutogradFunction(torch.autograd.Function): ...@@ -241,7 +234,7 @@ class _OperationFuserAutogradFunction(torch.autograd.Function):
dx = grad_output dx = grad_output
grad_params = [None for _ in range(len(basic_ops))] grad_params = [None for _ in range(len(basic_ops))]
grad_extra_inputs = [None for _ in range(len(basic_ops))] grad_extra_inputs = [None for _ in range(len(basic_ops))]
for op, basic_op_idxs in backward_ops: for op, basic_op_idxs in reversed(backward_ops):
# Stop if no more gradients are required # Stop if no more gradients are required
if all(not basic_op_ctxs[idx].requires_grad for idx in basic_op_idxs): if all(not basic_op_ctxs[idx].requires_grad for idx in basic_op_idxs):
...@@ -315,6 +308,10 @@ class OperationFuser: ...@@ -315,6 +308,10 @@ class OperationFuser:
""" """
# Functions to perform operation fusion
forward_fusion_functions: list[OperationFusionFunction] = []
backward_fusion_functions: list[OperationFusionFunction] = []
def __init__( def __init__(
self, self,
ops: list[FusibleOperation], ops: list[FusibleOperation],
...@@ -334,7 +331,7 @@ class OperationFuser: ...@@ -334,7 +331,7 @@ class OperationFuser:
self._basic_op_num_extra_inputs: list[int] = list(op.num_extra_inputs for op in basic_ops) self._basic_op_num_extra_inputs: list[int] = list(op.num_extra_inputs for op in basic_ops)
self.num_extra_inputs: int = sum(self._basic_op_num_extra_inputs) self.num_extra_inputs: int = sum(self._basic_op_num_extra_inputs)
# Ops for forward and backward pass, will be populated in fuse_ops # Ops for forward and backward pass, will be populated in maybe_fuse_ops
self._forward_ops: list[tuple[FusibleOperation, list[int]]] self._forward_ops: list[tuple[FusibleOperation, list[int]]]
self._backward_ops: list[tuple[FusibleOperation, list[int]]] self._backward_ops: list[tuple[FusibleOperation, list[int]]]
...@@ -349,31 +346,48 @@ class OperationFuser: ...@@ -349,31 +346,48 @@ class OperationFuser:
self._flat_basic_op_params = sum(self._basic_op_params, []) self._flat_basic_op_params = sum(self._basic_op_params, [])
@classmethod @classmethod
def _fuse_forward_ops( def _fuse_ops(
cls,
ops: list[tuple[FusibleOperation, list[int]]],
recipe: Optional[Recipe], # pylint: disable=unused-argument
) -> list[tuple[FusibleOperation, list[int]]]:
"""Attempt to fuse operations in forward pass"""
ops = fuse_userbuffers_forward_linear(ops)
ops = fuse_forward_linear_bias_add(ops)
ops = fuse_forward_linear_bias_activation(ops)
ops = fuse_forward_linear_scale_add(ops)
return ops
@classmethod
def _fuse_backward_ops(
cls, cls,
ops: list[tuple[FusibleOperation, list[int]]], basic_ops: Sequence[BasicOperation],
fusion_funcs: Iterable[OperationFusionFunction],
recipe: Optional[Recipe], recipe: Optional[Recipe],
) -> list[tuple[FusibleOperation, list[int]]]: ) -> list[tuple[FusibleOperation, list[int]]]:
"""Attempt to fuse operations in backward pass""" """Apply operation fusions"""
ops = fuse_userbuffers_backward_linear(ops)
ops = fuse_backward_linear_add(ops) # Apply op fusions
ops = fuse_backward_linear_scale(ops) fused_ops = list(basic_ops)
ops = fuse_backward_activation_bias(ops, recipe) for func in fusion_funcs:
ops = fuse_backward_add_rmsnorm(ops) fused_ops = func(fused_ops, recipe=recipe)
return ops
def raise_mismatch_error() -> None:
"""Throw error indicating invalid op fusion"""
raise RuntimeError(
"Found mismatch after fusing operations "
f"(basic_ops={[o.__class__.__name__ for o in basic_ops]}, "
f"fused_ops={[o.__class__.__name__ for o in fused_ops]})"
)
# Determine basic op indices corresponding to each op
out = []
idx = 0
for op in fused_ops:
if isinstance(op, FusedOperation):
idxs = []
for basic_op in op.basic_ops:
if basic_op is not basic_ops[idx]:
raise_mismatch_error()
idxs.append(idx)
idx += 1
out.append((op, idxs))
else:
if op is not basic_ops[idx]:
raise_mismatch_error()
out.append((op, [idx]))
idx += 1
if idx != len(basic_ops):
raise_mismatch_error()
return out
def maybe_fuse_ops( def maybe_fuse_ops(
self, self,
...@@ -424,12 +438,16 @@ class OperationFuser: ...@@ -424,12 +438,16 @@ class OperationFuser:
op.pre_first_fuser_forward() op.pre_first_fuser_forward()
# Prepare basic op lists for fusions # Prepare basic op lists for fusions
forward_ops = [(op, [idx]) for idx, op in enumerate(self._basic_ops)] self._forward_ops = OperationFuser._fuse_ops(
backward_ops = list(reversed(forward_ops[first_op_requiring_backward:])) self._basic_ops,
OperationFuser.forward_fusion_functions,
# Fuse ops recipe=recipe,
self._forward_ops = self._fuse_forward_ops(forward_ops, recipe) )
self._backward_ops = self._fuse_backward_ops(backward_ops, recipe) self._backward_ops = OperationFuser._fuse_ops(
self._basic_ops,
OperationFuser.backward_fusion_functions,
recipe=recipe,
)
# Save current fusion params # Save current fusion params
self.recipe_type, self.first_op_requiring_backward = fusion_params self.recipe_type, self.first_op_requiring_backward = fusion_params
...@@ -491,3 +509,55 @@ class OperationFuser: ...@@ -491,3 +509,55 @@ class OperationFuser:
*extra_inputs, *extra_inputs,
) )
return forward_func(*args) return forward_func(*args)
def register_forward_fusion(
op_fusion_func: OperationFusionFunction,
prepend: bool = False,
) -> None:
"""Register function to perform operation fusion for forward pass.
The fusion function should have the following signature:
func(ops, *, recipe) -> updated ops
Parameters
----------
op_fusion_func: function
Function that takes a list of operations and may substitute
them with fused operations.
prepend: bool, default = ``False``
Whether the operation fuser should apply this fusion function
first. The default is to apply it last.
"""
if prepend:
OperationFuser.forward_fusion_functions.insert(0, op_fusion_func)
else:
OperationFuser.forward_fusion_functions.append(op_fusion_func)
def register_backward_fusion(
op_fusion_func: OperationFusionFunction,
prepend: bool = False,
) -> None:
"""Register function to perform operation fusion for backward pass.
The fusion function should have the following signature:
func(ops, *, recipe) -> updated ops
Parameters
----------
op_fusion_func: function
Function that takes a list of operations and may substitute
them with fused operations.
prepend: bool, default = ``False``
Whether the operation fuser should apply this fusion function
first. The default is to apply it last.
"""
if prepend:
OperationFuser.backward_fusion_functions.insert(0, op_fusion_func)
else:
OperationFuser.backward_fusion_functions.append(op_fusion_func)
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment