Commit e0aa7992 authored by Tim Moon, committed by GitHub

[PyTorch] Branching operations (#1027)



* Add op for in-place add
Signed-off-by: Tim Moon <tmoon@nvidia.com>

* Add op for in-place add
Signed-off-by: Tim Moon <tmoon@nvidia.com>

* Add op that adds extra output to fuser
Signed-off-by: Tim Moon <tmoon@nvidia.com>

* Add fused op for GEMM+bias+add
Signed-off-by: Tim Moon <tmoon@nvidia.com>

* Add fused op for dgrad+add
Signed-off-by: Tim Moon <tmoon@nvidia.com>

* Add documentation
Signed-off-by: Tim Moon <tmoon@nvidia.com>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* Fix linter warnings
Signed-off-by: Tim Moon <tmoon@nvidia.com>

* Review suggestions from @ptrendx

Output tensor dtype and device take precedence over the weight tensor in the linear functional API. Move some index calculation to the fuser constructor. Avoid some unnecessary dereferences.
Signed-off-by: Tim Moon <tmoon@nvidia.com>

* Debug test failures
Signed-off-by: Tim Moon <tmoon@nvidia.com>

* Update transformer_engine/pytorch/ops/fuser.py
Signed-off-by: Tim Moon <4406448+timmoon10@users.noreply.github.com>

---------
Signed-off-by: Tim Moon <tmoon@nvidia.com>
Signed-off-by: Tim Moon <4406448+timmoon10@users.noreply.github.com>
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
parent 8b326059
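A minimal usage sketch of the two new branching operations, mirroring the tests added below (shapes, dtypes, and the "cuda" device are illustrative):

    import torch
    import transformer_engine.pytorch.ops as te_ops

    # AddInPlace consumes one extra tensor input: the GEMM output is added
    # in place into it, so Linear + AddInPlace can fuse into ForwardLinearBiasAdd.
    model = te_ops.Sequential(
        te_ops.Linear(16, 16, bias=True, device="cuda", dtype=torch.bfloat16),
        te_ops.AddInPlace(),
    )
    x1 = torch.randn(16, 16, device="cuda", dtype=torch.bfloat16, requires_grad=True)
    x2 = torch.randn(16, 16, device="cuda", dtype=torch.bfloat16, requires_grad=True)
    y = model(x1, x2)  # extra tensor inputs are passed after the main input

    # MakeExtraOutput returns the intermediate tensor as an extra output; its
    # backward accumulates in place, so MakeExtraOutput + Linear can fuse into
    # BackwardLinearAdd in the backward pass.
    model = te_ops.Sequential(
        te_ops.MakeExtraOutput(),
        te_ops.Linear(16, 16, bias=False, device="cuda", dtype=torch.bfloat16),
    )
    y1, y2 = model(x1)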
@@ -15,8 +15,10 @@ from transformer_engine.pytorch.float8_tensor import Float8Tensor
from transformer_engine.pytorch.fp8 import FP8GlobalStateManager
import transformer_engine.pytorch.ops as te_ops
from transformer_engine.pytorch.ops._common import is_float8_tensor
-from transformer_engine.pytorch.ops.fused_forward import (
+from transformer_engine.pytorch.ops.fused import (
+    BackwardLinearAdd,
    ForwardLinearBiasActivation,
+    ForwardLinearBiasAdd,
)
from transformer_engine.pytorch.utils import is_bf16_compatible
import transformer_engine_torch as tex
@@ -84,14 +86,13 @@ def make_reference_and_test_tensors(
"""
ref = torch.rand(shape, dtype=ref_dtype, device=ref_device)
+test = ref.to(device=test_device, dtype=test_dtype)
if test_is_fp8:
-test = Float8Tensor.to_float8(ref)
+test = Float8Tensor.to_float8(test)
test._transpose = test._data.reshape(-1, test.size(-1)).transpose(0, 1)
test._transpose = test._transpose.contiguous()
test._transpose_invalid = False
-else:
-test = ref.to(device=test_device, dtype=test_dtype)
-if test.data_ptr() == ref.data_ptr():
+elif test.data_ptr() == ref.data_ptr():
test = test.clone()
ref.copy_(test)
ref.requires_grad_(requires_grad)
@@ -320,14 +321,13 @@ class TestBasicOps:
torch.manual_seed(seed)
torch.cuda.manual_seed(seed)
-@pytest.mark.parametrize("in_shape", ((1,),))
@pytest.mark.parametrize("dtype", _dtypes)
@pytest.mark.parametrize("device", ("cuda", "cpu"))
@pytest.mark.parametrize("fp8", (False, True))
def test_identity(
self,
*,
-in_shape: Iterable[int],
+in_shape: Iterable[int] = (1,),
dtype: torch.dtype,
device: torch.device,
fp8: bool,
@@ -737,6 +737,123 @@ class TestBasicOps:
db_test = op.bias.grad.to(dtype=torch.float64, device="cpu")
torch.testing.assert_close(db_test, b_ref.grad, **tols)
@pytest.mark.parametrize("dtype", _dtypes)
@pytest.mark.parametrize("device", ("cuda", "cpu"))
@pytest.mark.parametrize("fp8", (False, True))
def test_add_in_place(
self,
*,
in_shape: Iterable[int] = (1,),
dtype: torch.dtype,
device: torch.device,
fp8: bool,
) -> None:
# Skip invalid configurations
if fp8 and not fp8_available:
pytest.skip(reason_for_no_fp8)
if fp8 and torch.device(device).type != "cuda":
pytest.skip("FP8 is only supported on CUDA devices")
# Random data
x1_ref, x1_test = make_reference_and_test_tensors(
in_shape,
test_dtype=dtype,
test_device=device,
test_is_fp8=fp8,
)
x2_ref, x2_test = make_reference_and_test_tensors(
in_shape,
test_dtype=dtype,
test_device=device,
test_is_fp8=fp8,
)
dy_ref, dy_test = make_reference_and_test_tensors(
in_shape,
test_dtype=dtype,
test_device=device,
requires_grad=False,
)
# Plain PyTorch implementation
y_ref = x2_ref.detach()
y_ref += x1_ref
dx1_ref = dy_ref
dx2_ref = dy_ref
# Implementation with fusible operation
op = te_ops.AddInPlace()
y_test = op(x1_test, x2_test)
y_test.backward(dy_test)
# Check results
tols = dtype_tols(dtype)
if fp8:
tols = dtype_tols(x1_test._fp8_dtype)
y_test = y_test.to(dtype=torch.float64, device="cpu")
dx1_test = x1_test.grad.to(dtype=torch.float64, device="cpu")
dx2_test = x2_test.grad.to(dtype=torch.float64, device="cpu")
torch.testing.assert_close(y_test, y_ref, **tols)
torch.testing.assert_close(dx1_test, dx1_ref, rtol=0, atol=0)
torch.testing.assert_close(dx2_test, dx2_ref, rtol=0, atol=0)
@pytest.mark.parametrize("dtype", _dtypes)
@pytest.mark.parametrize("device", ("cuda", "cpu"))
@pytest.mark.parametrize("fp8", (False, True))
def test_make_extra_output(
self,
*,
in_shape: Iterable[int] = (1,),
dtype: torch.dtype,
device: torch.device,
fp8: bool,
) -> None:
# Skip invalid configurations
if fp8 and not fp8_available:
pytest.skip(reason_for_no_fp8)
if fp8 and torch.device(device).type != "cuda":
pytest.skip("FP8 is only supported on CUDA devices")
# Random data
x_ref, x_test = make_reference_and_test_tensors(
in_shape,
test_dtype=dtype,
test_device=device,
test_is_fp8=fp8,
)
dy1_ref, dy1_test = make_reference_and_test_tensors(
in_shape,
test_dtype=dtype,
test_device=device,
requires_grad=False,
)
dy2_ref, dy2_test = make_reference_and_test_tensors(
in_shape,
test_dtype=dtype,
test_device=device,
requires_grad=False,
)
# Plain PyTorch implementation
y1_ref = x_ref
y2_ref = x_ref
(y1_ref * dy1_ref + y2_ref * dy2_ref).sum().backward()
# Implementation with fusible operation
op = te_ops.MakeExtraOutput()
y1_test, y2_test = op(x_test)
(y1_test * dy1_test + y2_test * dy2_test).sum().backward()
# Check results
tols = dtype_tols(dtype)
y1_test = y1_test.to(dtype=torch.float64, device="cpu")
y2_test = y2_test.to(dtype=torch.float64, device="cpu")
dx_test = x_test.grad.to(dtype=torch.float64, device="cpu")
torch.testing.assert_close(y1_test, y1_ref, rtol=0, atol=0)
torch.testing.assert_close(y2_test, y2_ref, rtol=0, atol=0)
torch.testing.assert_close(dx_test, x_ref.grad, **tols)
class TestFusedOps:
"""Tests for fused operations"""
@@ -754,7 +871,7 @@ class TestFusedOps:
@pytest.mark.parametrize("fp8_compute", (False, True))
@pytest.mark.parametrize("fp8_input", (False, True))
@pytest.mark.parametrize("fp8_weight", (False, True))
-def test_linear_bias_activation(
+def test_forward_linear_bias_activation(
self,
*,
bias: bool = True,
@@ -766,7 +883,7 @@
fp8_input: bool,
fp8_weight: bool,
) -> None:
-"""GEMM + bias + activation"""
+"""Forward GEMM + bias + activation"""
# Make input and weight shapes consistent
out_features, in_features = weight_shape
@@ -951,3 +1068,247 @@ class TestFusedOps:
torch.testing.assert_close(dx_test, x_ref.grad, **tols)
torch.testing.assert_close(dw0_test, w0_ref.grad, **tols)
torch.testing.assert_close(dw1_test, w1_ref.grad, **tols)
@pytest.mark.parametrize("bias", (False, True))
@pytest.mark.parametrize("dtype", _dtypes)
@pytest.mark.parametrize("fp8_compute", (False, True))
def test_forward_linear_bias_add(
self,
*,
bias: bool,
weight_shape: tuple[int, int] = (16, 16),
in_shape: Iterable[int] = (16, -1),
dtype: torch.dtype,
device: torch.device = "cuda",
fp8_compute: bool,
fp8_input: bool = False,
fp8_weight: bool = False,
fp8_output: bool = False,
) -> None:
"""Forward GEMM + bias + add"""
# Make input and weight shapes consistent
out_features, in_features = weight_shape
in_shape = list(in_shape)[:-1] + [in_features]
out_shape = in_shape[:-1] + [out_features]
# Skip invalid configurations
if fp8_input or fp8_weight or fp8_output or fp8_compute:
if not fp8_available:
pytest.skip(reason_for_no_fp8)
if torch.device(device).type != "cuda":
pytest.skip("FP8 is only supported on CUDA devices")
if fp8_compute:
if (
math.prod(in_shape[:-1]) % 16 != 0
or in_features % 16 != 0
or out_features % 16 != 0
):
pytest.skip("FP8 GEMMs require dims that are divisible by 16")
if fp8_output and not fp8_compute:
pytest.skip("FP8 output requires FP8 compute")
if fp8_compute and dtype not in (torch.float16, torch.bfloat16):
pytest.skip("FP8 GEMM is only supported with FP8, FP16, or BF16 output")
# Random data
x1_ref, x1_test = make_reference_and_test_tensors(
in_shape,
test_dtype=dtype,
test_device=device,
test_is_fp8=(fp8_compute or fp8_input),
)
w_ref, w_test = make_reference_and_test_tensors(
(out_features, in_features),
test_dtype=dtype,
test_device=device,
test_is_fp8=(fp8_compute or fp8_weight),
)
b_ref, b_test = None, None
if bias:
b_ref, b_test = make_reference_and_test_tensors(
out_features,
test_dtype=dtype,
test_device=device,
)
x2_ref, x2_test = make_reference_and_test_tensors(
out_shape,
test_dtype=dtype,
test_device=device,
test_is_fp8=fp8_output,
)
dy_ref, dy_test = make_reference_and_test_tensors(
out_shape,
test_dtype=dtype,
test_device=device,
requires_grad=False,
)
# Plain PyTorch implementation
y_ref = torch.nn.functional.linear(x1_ref, w_ref, bias=b_ref) + x2_ref
y_ref.backward(dy_ref)
# Implementation with fusible operations
with te.fp8_model_init(enabled=fp8_weight):
model = te_ops.Sequential(
te_ops.Linear(
in_features,
out_features,
bias=bias,
device=device,
dtype=dtype,
),
te_ops.AddInPlace(),
)
with torch.no_grad():
model[0].weight.copy_(w_test)
if bias:
model[0].bias.copy_(b_test)
del w_test
del b_test
with te.fp8_autocast(enabled=fp8_compute):
y_test = model(x1_test, x2_test)
y_test.backward(dy_test)
# Check that forward operations have been fused
forward_ops = model._module_groups[0]._forward_ops
assert len(forward_ops) == 1
assert isinstance(forward_ops[0][0], ForwardLinearBiasAdd)
# Expected numerical error
tols = dtype_tols(dtype)
if dtype == torch.float32:
tols = dtype_tols(torch.float16) # TF32 GEMM
if fp8_compute:
tols = dtype_tols(
model[0].weight._fp8_dtype
if is_float8_tensor(model[0].weight)
else tex.DType.kFloat8E4M3
)
# Check results
y_test = y_test.to(dtype=torch.float64, device="cpu")
dx1_test = x1_test.grad.to(dtype=torch.float64, device="cpu")
dx2_test = x2_test.grad.to(dtype=torch.float64, device="cpu")
dw_test = model[0].weight.grad.to(dtype=torch.float64, device="cpu")
torch.testing.assert_close(y_test, y_ref, **tols)
torch.testing.assert_close(dx1_test, x1_ref.grad, **tols)
torch.testing.assert_close(dx2_test, x2_ref.grad, **tols)
torch.testing.assert_close(dw_test, w_ref.grad, **tols)
if bias:
db_test = model[0].bias.grad.to(dtype=torch.float64, device="cpu")
torch.testing.assert_close(db_test, b_ref.grad, **tols)
@pytest.mark.parametrize("dtype", _dtypes)
@pytest.mark.parametrize("fp8_compute", (False, True))
def test_backward_linear_add(
self,
*,
weight_shape: tuple[int, int] = (16, 16),
in_shape: Iterable[int] = (16, -1),
dtype: torch.dtype,
device: torch.device = "cuda",
fp8_compute: bool,
fp8_input: bool = False,
fp8_weight: bool = False,
fp8_output: bool = False,
) -> None:
"""Backward dgrad GEMM + add"""
# Make input and weight shapes consistent
out_features, in_features = weight_shape
in_shape = list(in_shape)[:-1] + [in_features]
out_shape = in_shape[:-1] + [out_features]
# Skip invalid configurations
if fp8_input or fp8_weight or fp8_output or fp8_compute:
if not fp8_available:
pytest.skip(reason_for_no_fp8)
if torch.device(device).type != "cuda":
pytest.skip("FP8 is only supported on CUDA devices")
if fp8_compute:
if (
math.prod(in_shape[:-1]) % 16 != 0
or in_features % 16 != 0
or out_features % 16 != 0
):
pytest.skip("FP8 GEMMs require dims that are divisible by 16")
if fp8_output and not fp8_compute:
pytest.skip("FP8 output requires FP8 compute")
if fp8_compute and dtype not in (torch.float16, torch.bfloat16):
pytest.skip("FP8 GEMM is only supported with FP8, FP16, or BF16 output")
# Random data
x_ref, x_test = make_reference_and_test_tensors(
in_shape,
test_dtype=dtype,
test_device=device,
test_is_fp8=(fp8_compute or fp8_input),
)
w_ref, w_test = make_reference_and_test_tensors(
(out_features, in_features),
test_dtype=dtype,
test_device=device,
test_is_fp8=(fp8_compute or fp8_weight),
)
dy1_ref, dy1_test = make_reference_and_test_tensors(
out_shape,
test_dtype=dtype,
test_device=device,
requires_grad=False,
)
dy2_ref, dy2_test = make_reference_and_test_tensors(
out_shape,
test_dtype=dtype,
test_device=device,
requires_grad=False,
)
# Plain PyTorch implementation
y1_ref = torch.nn.functional.linear(x_ref, w_ref)
y2_ref = x_ref
(y1_ref * dy1_ref + y2_ref * dy2_ref).sum().backward()
# Implementation with fusible operations
with te.fp8_model_init(enabled=fp8_weight):
model = te_ops.Sequential(
te_ops.MakeExtraOutput(),
te_ops.Linear(
in_features,
out_features,
bias=False,
device=device,
dtype=dtype,
),
)
with torch.no_grad():
model[1].weight.copy_(w_test)
del w_test
with te.fp8_autocast(enabled=fp8_compute):
y1_test, y2_test = model(x_test)
(y1_test * dy1_test + y2_test * dy2_test).sum().backward()
# Check that backward operations have been fused
backward_ops = model._module_groups[0]._backward_ops
assert len(backward_ops) == 1
assert isinstance(backward_ops[0][0], BackwardLinearAdd)
# Expected numerical error
tols = dtype_tols(dtype)
if dtype == torch.float32:
tols = dtype_tols(torch.float16) # TF32 GEMM
if fp8_compute:
tols = dtype_tols(
model[1].weight._fp8_dtype
if is_float8_tensor(model[1].weight)
else tex.DType.kFloat8E4M3
)
# Check results
y1_test = y1_test.to(dtype=torch.float64, device="cpu")
y2_test = y2_test.to(dtype=torch.float64, device="cpu")
dx_test = x_test.grad.to(dtype=torch.float64, device="cpu")
dw_test = model[1].weight.grad.to(dtype=torch.float64, device="cpu")
torch.testing.assert_close(y1_test, y1_ref, **tols)
torch.testing.assert_close(y2_test, y2_ref, **tols)
torch.testing.assert_close(dx_test, x_ref.grad, **tols)
torch.testing.assert_close(dw_test, w_ref.grad, **tols)
@@ -9,11 +9,13 @@ This operation-based API is experimental and subject to change.
"""
from transformer_engine.pytorch.ops.basic import (
+    AddInPlace,
    AllGather,
    AllReduce,
    BasicLinear,
    Bias,
    Identity,
+    MakeExtraOutput,
    ReduceScatter,
    Reshape,
)
@@ -4,10 +4,12 @@
"""Single tensor operations supported by the operation fuser."""
+from .add_in_place import AddInPlace
from .all_gather import AllGather
from .all_reduce import AllReduce
from .basic_linear import BasicLinear
from .bias import Bias
from .identity import Identity
+from .make_extra_output import MakeExtraOutput
from .reduce_scatter import ReduceScatter
from .reshape import Reshape
# Copyright (c) 2022-2024, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# See LICENSE for license information.
"""Fusible operation for in-place add."""
from __future__ import annotations
from collections.abc import Iterable
from typing import Any, Optional
import torch
from transformer_engine.pytorch.ops.op import (
BasicOperation,
OperationContext,
)
class AddInPlace(BasicOperation):
"""Add in-place
This operation requires an extra tensor input to the operation
fuser. The main input is added in-place to the extra input, and a
view of the extra input is output.
This operation is considered an advanced feature and most users
are discouraged from using it. In-place operations break some
autograd assumptions and they can result in subtle, esoteric bugs.
Compare to `MakeExtraOutput`, which does a similar operation in
the backward pass.
"""
# Operation expects buffer for output tensor
num_extra_inputs: int = 1
def op_forward(self, *args, **kwargs) -> None:
raise RuntimeError(
"{self.__class__.__name__} operation has "
f"{self.num_extra_inputs} extra tensor inputs "
f"and {self.num_extra_outputs} extra tensor outputs. "
"It overrides `fuser_forward` instead of `op_forward`."
)
def op_backward(self, *args, **kwargs) -> None:
raise RuntimeError(
"{self.__class__.__name__} operation has "
f"{self.num_extra_inputs} extra tensor inputs "
f"and {self.num_extra_outputs} extra tensor outputs. "
"It overrides `fuser_backward` instead of `op_backward`."
)
def fuser_forward(
self,
basic_op_ctxs: list[OperationContext],
input_: torch.Tensor,
*,
basic_op_extra_inputs: list[tuple[torch.Tensor, ...]],
basic_op_prev_ops: list[Optional[BasicOperation]],
basic_op_next_ops: list[Optional[BasicOperation]],
basic_op_kwargs: list[dict[str, Any]],
) -> tuple[torch.Tensor, Iterable[Iterable[torch.Tensor]]]:
output = basic_op_extra_inputs[0][0].detach()
output += input_
return output, [()]
def fuser_backward(
self,
basic_op_ctxs: list[OperationContext],
grad_output: torch.Tensor,
*,
basic_op_grad_extra_outputs: list[tuple[torch.Tensor, ...]],
) -> tuple[
torch.Tensor,
Iterable[Iterable[Optional[torch.Tensor]]],
Iterable[Iterable[Optional[torch.Tensor]]],
]:
return grad_output, [], [(grad_output,)]
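As a rough plain-PyTorch reference for the forward semantics (mirroring the baseline used in test_add_in_place above):

    def add_in_place_reference(x1, x2):
        # The extra input x2 is updated in place; a view of it is returned.
        y = x2.detach()
        y += x1
        return y

In the backward pass both inputs receive the incoming gradient unchanged: fuser_backward returns grad_output as the gradient w.r.t. the main input and as the gradient w.r.t. the extra input.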
@@ -12,7 +12,11 @@ from typing import Any, Optional
import torch
-from transformer_engine.pytorch.cpp_extensions import fp8_gemm, gemm
+from transformer_engine.pytorch.cpp_extensions import (
+    FP8TensorMeta,
+    fp8_gemm,
+    gemm,
+)
from transformer_engine.pytorch.distributed import (
    CudaRNGStatesTracker,
    gather_along_first_dim,
@@ -32,6 +36,7 @@ from .._common import (
    canonicalize_device,
    canonicalize_dtype,
    convert_tensor,
+    devices_match,
    is_float8_tensor,
    reshape,
)
@@ -308,6 +313,8 @@ class BasicLinear(BasicOperation):
bias: Optional[torch.Tensor] = None,
device: Optional[torch.device] = None,
dtype: Optional[torch.dtype] = None,
+out: Optional[torch.Tensor] = None,
+accumulate_into_out: bool = False,
tensor_parallel_mode: Optional[str] = None,
tensor_parallel_group: Optional[torch.distributed.ProcessGroup] = None,
sequence_parallel: bool = False,
@@ -330,6 +337,10 @@ class BasicLinear(BasicOperation):
Tensor device
dtype: torch.dtype, default = default dtype
Tensor datatype
+out: torch.Tensor, optional
+Output tensor
+accumulate_into_out: bool, default = `False`
+Add result to output tensor instead of overwriting
tensor_parallel_mode: {`None`, "column", "row"}, default = `None`
Mode for tensor parallelism
tensor_parallel_group: torch.distributed.ProcessGroup, default = world group
@@ -365,19 +376,25 @@ class BasicLinear(BasicOperation):
# Check device
if device is None:
-device = weight.device
+device = weight.device if out is None else out.device
device = canonicalize_device(device)
if device.type != "cuda":
raise ValueError(f"Only CUDA devices are supported (got {device})")
+if out is not None and not devices_match(out.device, device):
+raise ValueError(
+f"Output tensor has invalid device (expected {device}, got {out.device})"
+)
# Check datatype
if dtype is None:
-dtype = weight.dtype
+dtype = weight.dtype if out is None else out.dtype
dtype = canonicalize_dtype(dtype)
if dtype not in (torch.float32, torch.float16, torch.bfloat16):
raise ValueError(f"Supported dtypes are float32, float16, bfloat16 (got {dtype})")
+if out is not None and out.dtype != dtype:
+raise ValueError(f"Output tensor has invalid dtype (expected {dtype}, got {out.dtype})")
-# Check tensor dims
+# Check input tensor dims
input_dims = tuple(input.size())
weight_dims = tuple(weight.size())
if len(weight_dims) != 2:
@@ -389,6 +406,32 @@ class BasicLinear(BasicOperation):
"are not compatible"
)
# Check output tensor dims
output_dims: list[int]
if out is None:
output_dims = list(input_dims)
output_dims[0] = -1
output_dims[-1] = weight_dims[0]
else:
output_dims = list(out.size())
if len(output_dims) == 0 or weight_dims[0] != output_dims[-1]:
raise ValueError(
f"Output tensor (shape={output_dims}) "
f"and weight tensor (shape={weight_dims}) "
"are not compatible"
)
# Check if accumulating into output tensor
if accumulate_into_out:
if out is None:
raise ValueError(
"Attempted to accumulate into output tensor without providing output tensor"
)
if tensor_parallel_mode == "row":
raise ValueError(
"Accumulating into output tensor is not supported with row tensor parallelism"
)
# Check if FP8 is enabled
if with_fp8_compute:
if input_fp8_meta is None and not is_float8_tensor(input):
@@ -399,9 +442,18 @@
input_fp8_meta = None
weight_fp8_meta = None
output_fp8_meta = None
-with_fp8_output = (
-with_fp8_compute and tensor_parallel_mode != "row" and output_fp8_meta is not None
-)
+with_fp8_output = with_fp8_compute and tensor_parallel_mode != "row"
+if out is None:
+with_fp8_output = with_fp8_output and output_fp8_meta is not None
+else:
+if is_float8_tensor(out):
+if not with_fp8_output:
+raise ValueError(
+"Output tensor is a Float8Tensor, but FP8 output is not supported"
+)
+out._reset_caches()
+else:
+with_fp8_output = False
# Check input tensor
x_local = reshape(
@@ -476,7 +528,9 @@
# Construct output tensor
y = None
-if with_fp8_output:
+if out is not None:
+y = reshape(out, (-1, output_dims[-1]))
+elif with_fp8_output:
fp8_dtype = get_fp8_te_dtype(
output_fp8_meta["recipe"],
fprop_tensor=True,
@@ -506,19 +560,31 @@
x_async = None
if with_fp8_compute:
kwargs = dict(
+accumulate=accumulate_into_out,
out=y,
bias=b,
use_bias=(b is not None),
)
if with_fp8_output:
+if y._fp8_meta is None:
+# Hackily create FP8TensorMeta if needed
+fp8_meta = FP8TensorMeta()
+fp8_meta.scale = y._scale_inv.reciprocal()
+fp8_meta.amax_history = torch.empty(1, 1, dtype=torch.float32, device=device)
+fp8_meta.scale_inv = y._scale_inv
+fp8_meta_index = 0
+else:
+# Get FP8TensorMeta from Float8Tensor
fp8_meta_key = FP8GlobalStateManager.get_meta_tensor_key(
forward=y._fp8_meta_forward,
)
+fp8_meta = y._fp8_meta[fp8_meta_key]
+fp8_meta_index = y._fp8_meta_index
kwargs.update(
dict(
out=y._data,
-out_index=y._fp8_meta_index,
+out_index=fp8_meta_index,
-fp8_meta_tensor=y._fp8_meta[fp8_meta_key],
+fp8_meta_tensor=fp8_meta,
D_dtype=y._fp8_dtype,
)
)
@@ -541,6 +607,7 @@
x,
y.dtype,
get_workspace(),
+accumulate=accumulate_into_out,
out=y,
bias=b,
use_bias=(b is not None),
@@ -553,13 +620,11 @@
else:
torch.distributed.all_reduce(y, group=tensor_parallel_group)
-# Reshape output tensor
-output_dims = list(input_dims)
-output_dims[0] = -1
-output_dims[-1] = weight_dims[0]
-output = reshape(y, output_dims)
-return output, x_local, w
+# Reshape output tensor if needed
+if out is None:
+out = reshape(y, output_dims)
+return out, x_local, w
@staticmethod
def _functional_backward(
@@ -573,6 +638,10 @@
weight_requires_grad: bool = True,
device: Optional[torch.device] = None,
dtype: Optional[torch.dtype] = None,
+grad_weight: Optional[torch.Tensor] = None,
+accumulate_into_grad_weight: bool = False,
+grad_input: Optional[torch.Tensor] = None,
+accumulate_into_grad_input: bool = False,
tensor_parallel_mode: Optional[str] = None,
tensor_parallel_group: Optional[torch.distributed.ProcessGroup] = None,
sequence_parallel: bool = False,
@@ -581,8 +650,6 @@
weight_fp8_meta: Optional[dict[str, Any]] = None,
grad_output_fp8_meta: Optional[dict[str, Any]] = None,
grad_input_fp8_meta: Optional[dict[str, Any]] = None,
-accumulate_into_grad_weight: bool = False,
-grad_weight: Optional[torch.Tensor] = None,
) -> tuple[torch.Tensor, torch.Tensor]:
"""Functional API for backward pass
@@ -608,6 +675,14 @@
Tensor device
dtype: torch.dtype, default = default dtype
Tensor datatype
+grad_weight: torch.Tensor, optional
+Loss gradient w.r.t. weight tensor
+accumulate_into_grad_weight: bool, default = `False`
+Add result to weight grad instead of overwriting
+grad_input: torch.Tensor, optional
+Loss gradient w.r.t. input tensor
+accumulate_into_grad_input: bool, default = `False`
+Add result to input grad instead of overwriting
tensor_parallel_mode: {`None`, "column", "row"}, default = `None`
Mode for tensor parallelism
tensor_parallel_group: torch.distributed.ProcessGroup, default = world group
@@ -632,10 +707,6 @@
grad_output_fp8_meta: dict, optional
FP8 metadata for casting loss gradient w.r.t. input
tensor to FP8
-accumulate_into_grad_weight: bool, default = `False`
-Accumulate into weight grad instead of overwriting
-grad_weight: torch.Tensor, optional
-Loss gradient w.r.t. weight tensor
Returns
-------
@@ -678,6 +749,34 @@
f"and weight tensor (shape={weight_dims}) "
"are not compatible"
)
if grad_input is not None and tuple(grad_input.size()) != input_dims:
raise ValueError(
f"Grad input tensor (shape={tuple(grad_input.size())}) "
f"does not match expected shape ({input_dims})"
)
# Check grad input tensor
if not input_requires_grad:
grad_input = None
if grad_input is not None and not devices_match(grad_input.device, device):
raise ValueError(
f"Grad input tensor has invalid device (expected {device}, got {grad_input.device})"
)
if grad_input is not None and grad_input.dtype != dtype:
raise ValueError(
f"Grad input tensor has invalid dtype (expected {dtype}, got {grad_input.dtype})"
)
if accumulate_into_grad_input:
if grad_input is None:
raise ValueError(
"Attempted to accumulate into grad input tensor "
"without providing grad input tensor"
)
if tensor_parallel_mode == "column":
raise ValueError(
"Accumulating into grad input tensor "
"is not supported with column tensor parallelism"
)
# Check if FP8 is enabled
if with_fp8_compute:
@@ -689,11 +788,19 @@
grad_output_fp8_meta = None
grad_input_fp8_meta = None
with_fp8_grad_input = (
-with_fp8_compute
-and input_requires_grad
-and tensor_parallel_mode != "column"
-and grad_input_fp8_meta is not None
+with_fp8_compute and input_requires_grad and tensor_parallel_mode != "column"
)
+if grad_input is None:
+with_fp8_grad_input = with_fp8_grad_input and grad_input_fp8_meta is not None
+else:
+if is_float8_tensor(grad_input):
+if not with_fp8_grad_input:
+raise ValueError(
+"Grad input tensor is a Float8Tensor, but FP8 output is not supported"
+)
+grad_input._reset_caches()
+else:
+with_fp8_grad_input = False
# Check grad output tensor
dy_async = None
@@ -806,7 +913,9 @@
w = w.from_float8()
# Construct grad input tensor
-if with_fp8_grad_input:
+if grad_input is not None:
+dx = reshape(grad_input, (-1, input_dims[-1]))
+elif with_fp8_grad_input:
fp8_dtype = get_fp8_te_dtype(
grad_input_fp8_meta["recipe"],
fprop_tensor=False,
@@ -835,16 +944,32 @@
_wait_async(dy_async)
dy_async = None
if with_fp8_compute:
-kwargs = dict(out=dx)
+kwargs = dict(
+accumulate=accumulate_into_grad_input,
+out=dx,
+)
if with_fp8_grad_input:
+if dx._fp8_meta is None:
+# Hackily create FP8TensorMeta if needed
+fp8_meta = FP8TensorMeta()
+fp8_meta.scale = dx._scale_inv.reciprocal()
+fp8_meta.amax_history = torch.empty(
+1, 1, dtype=torch.float32, device=device
+)
+fp8_meta.scale_inv = dx._scale_inv
+fp8_meta_index = 0
+else:
+# Get FP8TensorMeta from Float8Tensor
fp8_meta_key = FP8GlobalStateManager.get_meta_tensor_key(
forward=dx._fp8_meta_forward,
)
+fp8_meta = dx._fp8_meta[fp8_meta_key]
+fp8_meta_index = dx._fp8_meta_index
kwargs.update(
dict(
out=dx._data,
-out_index=dx._fp8_meta_index,
+out_index=fp8_meta_index,
-fp8_meta_tensor=dx._fp8_meta[fp8_meta_key],
+fp8_meta_tensor=fp8_meta,
D_dtype=dx._fp8_dtype,
)
)
@@ -867,6 +992,7 @@
dy,
dx.dtype,
get_workspace(),
+accumulate=accumulate_into_grad_input,
layout="NN",
out=dx,
)
@@ -936,8 +1062,7 @@
_wait_async(dy_async)
_wait_async(x_async)
_wait_async(dx_async)
-grad_input = None
-if dx is not None:
+if dx is not None and grad_input is None:
grad_input = reshape(dx, input_dims)
return grad_input, grad_weight
@@ -1027,6 +1152,8 @@ class BasicLinear(BasicOperation):
weight_requires_grad=ctx.weight_requires_grad,
device=self.device,
dtype=self.dtype,
+grad_weight=grad_weight,
+accumulate_into_grad_weight=accumulate_into_main_grad,
tensor_parallel_mode=self.tensor_parallel_mode,
tensor_parallel_group=self.tensor_parallel_group,
sequence_parallel=self.sequence_parallel,
@@ -1034,8 +1161,6 @@ class BasicLinear(BasicOperation):
weight_fp8_meta=ctx.weight_fp8_meta,
grad_output_fp8_meta=ctx.grad_output_fp8_meta,
grad_input_fp8_meta=ctx.grad_input_fp8_meta,
-accumulate_into_grad_weight=accumulate_into_main_grad,
-grad_weight=grad_weight,
)
# Clear input tensor if possible
# Copyright (c) 2022-2024, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# See LICENSE for license information.
"""Make extra tensor output in operation fuser."""
from __future__ import annotations
from collections.abc import Iterable
from typing import Any, Optional
import torch
from transformer_engine.pytorch.ops.op import (
BasicOperation,
OperationContext,
)
class MakeExtraOutput(BasicOperation):
"""Make extra output in operation fuser
If this operation is included in the operation fuser, then the
operation fuser will return the intermediate tensor as an extra
tensor output. In the backward pass, the gradient is directly
accumulated into the gradient w.r.t. the extra output.
This operation is considered an advanced feature and most users
are discouraged from using it. In-place operations break some
autograd assumptions and they can result in subtle, esoteric bugs.
Compare to `AddInPlace`, which does a similar operation in the
forward pass.
"""
# Operation expects buffer for output tensor
num_extra_outputs: int = 1
def op_forward(self, *args, **kwargs) -> None:
raise RuntimeError(
"{self.__class__.__name__} operation has "
f"{self.num_extra_inputs} extra tensor inputs "
f"and {self.num_extra_outputs} extra tensor outputs. "
"It overrides `fuser_forward` instead of `op_forward`."
)
def op_backward(self, *args, **kwargs) -> None:
raise RuntimeError(
"{self.__class__.__name__} operation has "
f"{self.num_extra_inputs} extra tensor inputs "
f"and {self.num_extra_outputs} extra tensor outputs. "
"It overrides `fuser_backward` instead of `op_backward`."
)
def fuser_forward(
self,
basic_op_ctxs: list[OperationContext],
input_: torch.Tensor,
*,
basic_op_extra_inputs: list[tuple[torch.Tensor, ...]],
basic_op_prev_ops: list[Optional[BasicOperation]],
basic_op_next_ops: list[Optional[BasicOperation]],
basic_op_kwargs: list[dict[str, Any]],
) -> tuple[torch.Tensor, Iterable[Iterable[torch.Tensor]]]:
return input_, [(input_,)]
def fuser_backward(
self,
basic_op_ctxs: list[OperationContext],
grad_output: torch.Tensor,
*,
basic_op_grad_extra_outputs: list[tuple[torch.Tensor, ...]],
) -> tuple[
torch.Tensor,
Iterable[Iterable[Optional[torch.Tensor]]],
Iterable[Iterable[Optional[torch.Tensor]]],
]:
grad_input = basic_op_grad_extra_outputs[0][0]
grad_input += grad_output
return grad_input, [], [()]
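A rough plain-PyTorch reference for the forward semantics (mirroring the baseline used in test_make_extra_output above):

    def make_extra_output_reference(x):
        # Forward: the input is passed through and also returned as an extra output.
        return x, x

In the backward pass the incoming gradient is accumulated in place into the gradient w.r.t. the extra output, so the gradient w.r.t. the input is that extra-output gradient with grad_output added to it.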
@@ -4,7 +4,15 @@
"""Compound tensor operation supported by the operation fuser."""
-from .linear_bias_activation import (
+from .backward_linear_add import (
+    BackwardLinearAdd,
+    fuse_backward_linear_add,
+)
+from .forward_linear_bias_activation import (
    ForwardLinearBiasActivation,
    fuse_forward_linear_bias_activation,
)
+from .forward_linear_bias_add import (
+    ForwardLinearBiasAdd,
+    fuse_forward_linear_bias_add,
+)
# Copyright (c) 2022-2024, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# See LICENSE for license information.
"""Fused backward dgrad GEMM + add."""
from __future__ import annotations
from typing import Optional
import torch
from transformer_engine.pytorch.ops.basic import BasicLinear, MakeExtraOutput
from transformer_engine.pytorch.ops.op import (
FusedOperation,
FusibleOperation,
OperationContext,
)
from ...utils import clear_tensor_data
class BackwardLinearAdd(FusedOperation):
"""Fused backward dgrad GEMM + add
Column tensor parallelism is not supported since that requires
communication immediately after the dgrad GEMM.
"""
def __init__(
self,
*,
linear: BasicLinear,
backward_add: MakeExtraOutput,
) -> None:
super().__init__((linear, backward_add))
def fuser_backward(
self,
basic_op_ctxs: list[OperationContext],
grad_output: torch.Tensor,
*,
basic_op_grad_extra_outputs: list[tuple[torch.Tensor, ...]],
) -> tuple[
torch.Tensor,
list[tuple[Optional[torch.Tensor], ...]],
list[tuple[()]],
]:
# Get basic operations
linear_op = self.basic_ops[0]
linear_op_ctx = basic_op_ctxs[0]
# Saved tensors from forward pass
(x_local,) = linear_op_ctx.saved_tensors
# wgrad fusion
accumulate_into_main_grad = linear_op._accumulate_into_main_grad
grad_weight = None
if linear_op_ctx.weight_requires_grad and accumulate_into_main_grad:
if not hasattr(linear_op.weight, "main_grad"):
raise RuntimeError(
"BasicLinear op is configured with "
"accumulate_into_main_grad=True, "
"but weight parameter does not have main_grad attribute"
)
grad_weight = linear_op.weight.main_grad.detach()
else:
accumulate_into_main_grad = False
# Linear backward pass
grad_input = basic_op_grad_extra_outputs[1][0]
grad_input, grad_weight = BasicLinear._functional_backward(
grad_output=grad_output,
input=x_local,
weight=linear_op.weight,
input_dims=linear_op_ctx.input_dims,
weight_dims=linear_op.weight.size(),
input_requires_grad=linear_op_ctx.input_requires_grad,
weight_requires_grad=linear_op_ctx.weight_requires_grad,
device=linear_op.device,
dtype=linear_op.dtype,
grad_weight=grad_weight,
accumulate_into_grad_weight=accumulate_into_main_grad,
grad_input=grad_input,
accumulate_into_grad_input=True,
tensor_parallel_mode=linear_op.tensor_parallel_mode,
tensor_parallel_group=linear_op.tensor_parallel_group,
sequence_parallel=linear_op.sequence_parallel,
with_fp8_compute=linear_op_ctx.with_fp8_compute,
weight_fp8_meta=linear_op_ctx.weight_fp8_meta,
grad_output_fp8_meta=linear_op_ctx.grad_output_fp8_meta,
grad_input_fp8_meta=linear_op_ctx.grad_input_fp8_meta,
)
if accumulate_into_main_grad:
grad_weight = None
# Clear input tensor if possible
if linear_op_ctx.has_prev_op:
clear_tensor_data(x_local)
return grad_input, [(grad_weight,), ()], [(), ()]
def fuse_backward_linear_add(
ops: list[tuple[FusibleOperation, list[int]]],
) -> list[tuple[FusibleOperation, list[int]]]:
"""Fused backward dgrad GEMM + add
Parameters
----------
ops: list of tuples
Forward pass operations and the indices of the corresponding
basic operations.
Returns
-------
ops: list of tuples
Updated forward pass operations
"""
# Scan through ops, fusing if possible
out = []
window = []
while len(ops) >= 2:
out.extend(window)
# Check if first op is linear
window, ops = ops[:1], ops[1:]
op, _ = window[0]
if not isinstance(op, BasicLinear):
continue
if op.tensor_parallel_mode == "column":
# Column tensor-parallelism requires communication after the
# dgrad GEMM
continue
# Check if second op is "make extra output"
op, _ = ops[0]
if not isinstance(op, MakeExtraOutput):
continue
window.extend(ops[:1])
ops = ops[1:]
# Replace window with fused op
op = BackwardLinearAdd(
linear=window[0][0],
backward_add=window[1][0],
)
basic_op_idxs = [basic_op_idxs[0] for _, basic_op_idxs in window]
window = [(op, basic_op_idxs)]
# Return list of ops
out.extend(window)
out.extend(ops)
return out
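With the MakeExtraOutput + Linear model from the usage sketch near the top of this commit message, the fusion can be observed the same way test_backward_linear_add does:

    y1, y2 = model(x)
    (y1 * dy1 + y2 * dy2).sum().backward()
    backward_ops = model._module_groups[0]._backward_ops
    assert len(backward_ops) == 1
    assert isinstance(backward_ops[0][0], BackwardLinearAdd)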
@@ -2,9 +2,10 @@
#
# See LICENSE for license information.
-"""Fused operation for GEMM, bias, activation in the forward pass."""
+"""Fused operation for forward GEMM + bias + activation."""
from __future__ import annotations
+from collections.abc import Iterable
from typing import Any, Optional
import torch
@@ -20,7 +21,7 @@ from transformer_engine.pytorch.ops.op import (
class ForwardLinearBiasActivation(FusedOperation):
-"""Fused GEMM, bias, activation in the forward pass
+"""Fused forward GEMM + bias + activation
Bias and activation are both optional. Row tensor parallelism is
not supported since that requires communication immediately after
@@ -60,10 +61,12 @@ class ForwardLinearBiasActivation(FusedOperation):
self,
basic_op_ctxs: list[OperationContext],
input_: torch.Tensor,
+*,
+basic_op_extra_inputs: list[tuple[torch.Tensor, ...]],
basic_op_prev_ops: list[Optional[BasicOperation]],
basic_op_next_ops: list[Optional[BasicOperation]],
basic_op_kwargs: list[dict[str, Any]],
-) -> torch.Tensor:
+) -> tuple[torch.Tensor, Iterable[Iterable[torch.Tensor]]]:
# Get basic operations
idx = self._op_idxs["linear"]
@@ -128,13 +131,13 @@ class ForwardLinearBiasActivation(FusedOperation):
linear_op_ctx.weight_requires_grad = linear_op.weight.requires_grad
linear_op_ctx.has_prev_op = basic_op_prev_ops[0] is not None
-return output
+return output, [() for _ in range(len(self.basic_ops))]
def fuse_forward_linear_bias_activation(
ops: list[tuple[FusibleOperation, list[int]]],
) -> list[tuple[FusibleOperation, list[int]]]:
-"""Fuse GEMM, bias, activation in the forward pass
+"""Fuse forward GEMM + bias + activation
Parameters
----------
# Copyright (c) 2022-2024, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# See LICENSE for license information.
"""Fused operation for forward GEMM + bias + add."""
from __future__ import annotations
from collections.abc import Iterable
from typing import Any, Optional
import torch
from transformer_engine.pytorch.fp8 import FP8GlobalStateManager
from transformer_engine.pytorch.ops.basic import AddInPlace, BasicLinear, Bias
from transformer_engine.pytorch.ops.op import (
BasicOperation,
FusedOperation,
FusibleOperation,
OperationContext,
)
class ForwardLinearBiasAdd(FusedOperation):
"""Fused forward GEMM + bias + add
Bias is optional. Row tensor parallelism is not supported since
that requires communication immediately after the GEMM.
"""
def __init__(
self,
*,
linear: BasicLinear,
bias: Optional[Bias],
add: AddInPlace,
) -> None:
# Basic operations that comprise this fused operation
op_idxs = dict(
linear=0,
bias=None,
add=None,
)
ops = [linear]
if bias is not None:
op_idxs["bias"] = len(ops)
ops.append(bias)
op_idxs["add"] = len(ops)
ops.append(add)
# Initialize base class
super().__init__(ops)
# Index of each basic operations
self._op_idxs: dict[str, Optional[int]] = op_idxs
def fuser_forward(
self,
basic_op_ctxs: list[OperationContext],
input_: torch.Tensor,
*,
basic_op_extra_inputs: list[tuple[torch.Tensor, ...]],
basic_op_prev_ops: list[Optional[BasicOperation]],
basic_op_next_ops: list[Optional[BasicOperation]],
basic_op_kwargs: list[dict[str, Any]],
) -> tuple[torch.Tensor, Iterable[Iterable[torch.Tensor]]]:
# Get basic operations
idx = self._op_idxs["linear"]
linear_op = self.basic_ops[idx]
linear_op_ctx = basic_op_ctxs[idx]
if self._op_idxs["bias"] is None:
bias_op = None
bias = None
else:
idx = self._op_idxs["bias"]
bias_op = self.basic_ops[idx]
bias = bias_op.bias
if basic_op_kwargs[idx]:
raise ValueError("Bias operation forward does not expect keyword arguments")
# FP8 metadata
with_fp8_compute = FP8GlobalStateManager.is_fp8_enabled()
input_fp8_meta = None
weight_fp8_meta = None
output_fp8_meta = None
grad_output_fp8_meta = None
grad_input_fp8_meta = None
if with_fp8_compute:
input_fp8_meta = linear_op.get_fp8_meta("input")
weight_fp8_meta = linear_op.get_fp8_meta("param")
grad_output_fp8_meta = linear_op.get_fp8_meta("grad_output")
prev_op = basic_op_prev_ops[0]
if prev_op is not None and prev_op.num_fp8_scales("grad_output") > 0:
grad_input_fp8_meta = prev_op.get_fp8_meta("grad_output")
# Linear forward
output = basic_op_extra_inputs[self._op_idxs["add"]][0]
output, x_local, _ = BasicLinear._functional_forward(
input=input_,
weight=linear_op.weight,
bias=bias,
device=linear_op.device,
dtype=linear_op.dtype,
out=output,
accumulate_into_out=True,
tensor_parallel_mode=linear_op.tensor_parallel_mode,
tensor_parallel_group=linear_op.tensor_parallel_group,
sequence_parallel=linear_op.sequence_parallel,
with_fp8_compute=with_fp8_compute,
input_fp8_meta=input_fp8_meta,
weight_fp8_meta=weight_fp8_meta,
output_fp8_meta=output_fp8_meta,
)
# Save state for backward pass
linear_op_ctx.save_for_backward(x_local)
linear_op_ctx.with_fp8_compute = with_fp8_compute
linear_op_ctx.weight_fp8_meta = weight_fp8_meta
linear_op_ctx.grad_output_fp8_meta = grad_output_fp8_meta
linear_op_ctx.grad_input_fp8_meta = grad_input_fp8_meta
linear_op_ctx.input_dims = input_.size()
linear_op_ctx.input_requires_grad = input_.requires_grad
linear_op_ctx.weight_requires_grad = linear_op.weight.requires_grad
linear_op_ctx.has_prev_op = basic_op_prev_ops[0] is not None
return output, [() for _ in range(len(self.basic_ops))]
def fuse_forward_linear_bias_add(
ops: list[tuple[FusibleOperation, list[int]]],
) -> list[tuple[FusibleOperation, list[int]]]:
"""Fuse forward GEMM + bias + add
Parameters
----------
ops: list of tuples
Forward pass operations and the indices of the corresponding
basic operations.
Returns
-------
ops: list of tuples
Updated forward pass operations
"""
# Scan through ops, fusing if possible
out = []
window = []
while len(ops) >= 2:
out.extend(window)
# Check if first op is linear
window, ops = ops[:1], ops[1:]
op, _ = window[0]
if not isinstance(op, BasicLinear):
continue
if op.tensor_parallel_mode == "row":
# Row tensor-parallelism requires communication after the
# GEMM
continue
linear = op
op, _ = ops[0]
# Check if next op is bias
bias = None
if isinstance(op, Bias):
bias = op
window.extend(ops[:1])
ops = ops[1:]
if len(ops) == 0:
continue
op, _ = ops[0]
# Check if next op is add in-place
if not isinstance(op, AddInPlace):
continue
add = op
window.extend(ops[:1])
ops = ops[1:]
# Replace window with fused op
op = ForwardLinearBiasAdd(
linear=linear,
bias=bias,
add=add,
)
basic_op_idxs = [basic_op_idxs[0] for _, basic_op_idxs in window]
window = [(op, basic_op_idxs)]
# Return list of ops
out.extend(window)
out.extend(ops)
return out
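Correspondingly, with the Linear + AddInPlace model from the usage sketch near the top, the forward fusion can be observed the same way test_forward_linear_bias_add does:

    y = model(x1, x2)
    forward_ops = model._module_groups[0]._forward_ops
    assert len(forward_ops) == 1
    assert isinstance(forward_ops[0][0], ForwardLinearBiasAdd)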
@@ -16,11 +16,18 @@ from transformer_engine.pytorch.ops.op import (
FusibleOperation,
OperationContext,
)
-from transformer_engine.pytorch.ops.fused_forward import (
+from transformer_engine.pytorch.ops.fused import (
+    fuse_backward_linear_add,
    fuse_forward_linear_bias_activation,
+    fuse_forward_linear_bias_add,
)
def _split_tuple(t: tuple, idx: int) -> tuple[tuple, tuple]:
"""Split tuple at index"""
return t[:idx], t[idx:]
class _OperationFuserAutogradFunction(torch.autograd.Function):
"""Autograd function for a pipeline of operations
@@ -38,8 +45,10 @@ class _OperationFuserAutogradFunction(torch.autograd.Function):
backward_ops: list[tuple[FusibleOperation, list[int]]],
basic_ops: list[BasicOperation],
basic_op_kwargs: list[dict[str, Any]],
-*params: torch.nn.Parameter,
-) -> torch.Tensor:
+num_params: int,
+num_extra_inputs: int,
*params_and_extra_inputs: torch.nn.Parameter,
) -> torch.Tensor | tuple[torch.Tensor, ...]:
"""Forward pass """Forward pass
Parameters Parameters
...@@ -60,39 +69,82 @@ class _OperationFuserAutogradFunction(torch.autograd.Function): ...@@ -60,39 +69,82 @@ class _OperationFuserAutogradFunction(torch.autograd.Function):
Basic operations Basic operations
basic_op_kwargs: list of dict basic_op_kwargs: list of dict
Keyword arguments to BasicOperation Keyword arguments to BasicOperation
*params: torch.nn.Parameter num_params: int
Parameters in operation pipeline Number of parameter tensors to include in autograd graph.
*params_and_extra_inputs: torch.Tensor
Other tensor inputs to include in autograd graph. Consists
of parameter tensors, followed by extra operation inputs.
Returns
-------
Output tensor(s). If none of the operations have any extra
tensor outputs, then the pipeline's output tensor is returned.
Otherwise, a tuple with the pipeline's output tensor and extra
tensor outputs is returned.
""" """
# Operation autograd contexts # Operation autograd contexts
basic_op_ctxs = [OperationContext() for _ in range(len(basic_ops))] basic_op_ctxs = [OperationContext() for _ in range(len(basic_ops))]
# Unflatten list of parameters and extra tensor inputs
if len(params_and_extra_inputs) != num_params + num_extra_inputs:
raise ValueError(
f"Expected {num_params + num_extra_inputs} extra tensor arguments "
f"({num_params} parameters, {num_extra_inputs} extra inputs), "
f"but got {len(params_and_extra_inputs)}"
)
_, extra_inputs = _split_tuple(params_and_extra_inputs, num_params)
basic_op_extra_inputs = []
for op in basic_ops:
xs, extra_inputs = _split_tuple(extra_inputs, op.num_extra_inputs)
basic_op_extra_inputs.append(xs)
# Apply forward ops
x = input_
requires_grad = x.requires_grad
extra_outputs = [None for _ in range(len(basic_ops))]
for op, basic_op_idxs in forward_ops:
# Forward op
+extra_inputs = [basic_op_extra_inputs[idx] for idx in basic_op_idxs]
prev_ops = [basic_ops[idx - 1] if idx > 0 else None for idx in basic_op_idxs]
next_ops = [
basic_ops[idx + 1] if (idx < len(basic_ops) - 1) else None for idx in basic_op_idxs
]
-x = op.fuser_forward(
+x, fused_op_extra_outputs = op.fuser_forward(
[basic_op_ctxs[idx] for idx in basic_op_idxs],
x,
-prev_ops,
-next_ops,
-[basic_op_kwargs[idx] for idx in basic_op_idxs],
+basic_op_extra_inputs=extra_inputs,
+basic_op_prev_ops=prev_ops,
+basic_op_next_ops=next_ops,
+basic_op_kwargs=[basic_op_kwargs[idx] for idx in basic_op_idxs],
)
+for idx, ys in zip(basic_op_idxs, fused_op_extra_outputs):
+extra_outputs[idx] = ys
# Check if backward op is required
if not requires_grad:
requires_grad = any(param.requires_grad for param in op.parameters())
+if not requires_grad:
+requires_grad = any(any(x.requires_grad for x in xs) for xs in extra_inputs)
for idx in basic_op_idxs:
basic_op_ctxs[idx]._requires_grad = requires_grad
x.requires_grad_(requires_grad=requires_grad)
# Flatten list of extra outputs
extra_outputs_flat = []
for idx, ys in enumerate(extra_outputs):
ys = list(ys)
num_extra_outputs = basic_ops[idx].num_extra_outputs
if len(ys) != num_extra_outputs:
raise RuntimeError(
f"Expected op {idx} to generate "
"{num_extra_outputs} extra inputs, "
f"but got {len(ys)}"
)
extra_outputs_flat.extend(ys)
# Flatten list of saved tensors
to_save = []
for ctx in basic_op_ctxs:
@@ -108,8 +160,13 @@ class _OperationFuserAutogradFunction(torch.autograd.Function):
func_ctx.backward_ops = backward_ops
func_ctx.basic_ops = basic_ops
func_ctx.basic_op_ctxs = basic_op_ctxs
func_ctx.num_params = num_params
func_ctx.num_extra_inputs = num_extra_inputs
func_ctx.num_extra_outputs = len(extra_outputs_flat)
func_ctx.is_first_module = FP8GlobalStateManager.is_first_fp8_module()
if extra_outputs_flat:
return x, *extra_outputs_flat
return x
@staticmethod
@@ -117,6 +174,7 @@ class _OperationFuserAutogradFunction(torch.autograd.Function):
def backward(
func_ctx: Any,
grad_output: torch.Tensor,
*grad_extra_outputs: torch.Tensor,
) -> tuple[Optional[torch.Tensor], ...]: ) -> tuple[Optional[torch.Tensor], ...]:
"""Backward pass""" """Backward pass"""
...@@ -126,15 +184,25 @@ class _OperationFuserAutogradFunction(torch.autograd.Function): ...@@ -126,15 +184,25 @@ class _OperationFuserAutogradFunction(torch.autograd.Function):
basic_op_ctxs = func_ctx.basic_op_ctxs basic_op_ctxs = func_ctx.basic_op_ctxs
# Unflatten list of saved tensors # Unflatten list of saved tensors
saved_tensors = func_ctx.saved_tensors
for ctx in basic_op_ctxs: for ctx in basic_op_ctxs:
ctx.saved_tensors = saved_tensors[slice(*ctx._saved_tensors_range)] ctx.saved_tensors = func_ctx.saved_tensors[slice(*ctx._saved_tensors_range)]
ctx._saved_tensors_range = None ctx._saved_tensors_range = None
del saved_tensors
# Unflatten list of extra tensor output grads
if len(grad_extra_outputs) != func_ctx.num_extra_outputs:
raise ValueError(
f"Expected grads for {func_ctx.num_extra_outputs} extra tensor outputs, "
f"but got {len(grad_extra_outputs)}"
)
basic_op_grad_extra_outputs = []
for op in basic_ops:
dys, grad_extra_outputs = _split_tuple(grad_extra_outputs, op.num_extra_outputs)
basic_op_grad_extra_outputs.append(dys)
# Apply backward ops # Apply backward ops
dx = grad_output dx = grad_output
grad_params = [None for _ in range(len(basic_ops))] grad_params = [None for _ in range(len(basic_ops))]
grad_extra_inputs = [None for _ in range(len(basic_ops))]
for op, basic_op_idxs in backward_ops: for op, basic_op_idxs in backward_ops:
# Stop if no more gradients are required # Stop if no more gradients are required
...@@ -143,13 +211,17 @@ class _OperationFuserAutogradFunction(torch.autograd.Function): ...@@ -143,13 +211,17 @@ class _OperationFuserAutogradFunction(torch.autograd.Function):
break break
# Backward op # Backward op
dx, fused_op_dparams = op.fuser_backward( grad_extra_outputs = [basic_op_grad_extra_outputs[idx] for idx in basic_op_idxs]
dx, fused_op_grad_params, fused_op_grad_extra_inputs = op.fuser_backward(
[basic_op_ctxs[idx] for idx in basic_op_idxs], [basic_op_ctxs[idx] for idx in basic_op_idxs],
dx, dx,
basic_op_grad_extra_outputs=grad_extra_outputs,
) )
for idx, basic_op_dparams in zip(basic_op_idxs, fused_op_dparams): for idx, dparams in zip(basic_op_idxs, fused_op_grad_params):
grad_params[idx] = basic_op_dparams grad_params[idx] = dparams
basic_op_ctxs[idx].saved_tensors = None basic_op_ctxs[idx].saved_tensors = None
for idx, dxs in zip(basic_op_idxs, fused_op_grad_extra_inputs):
grad_extra_inputs[idx] = dxs
# Flatten list of parameter gradients # Flatten list of parameter gradients
grad_params_flat = [] grad_params_flat = []
...@@ -166,6 +238,22 @@ class _OperationFuserAutogradFunction(torch.autograd.Function): ...@@ -166,6 +238,22 @@ class _OperationFuserAutogradFunction(torch.autograd.Function):
) )
grad_params_flat.extend(dparams) grad_params_flat.extend(dparams)
# Flatten list of extra input gradients
grad_extra_inputs_flat = []
for idx, dxs in enumerate(grad_extra_inputs):
num_extra_inputs = basic_ops[idx].num_extra_inputs
if dxs is None:
dxs = [None for _ in range(num_extra_inputs)]
else:
dxs = list(dxs)
if len(dxs) != num_extra_inputs:
raise RuntimeError(
f"Expected op {idx} to generate grads "
f"for {num_extra_inputs} extra inputs, "
f"but got {len(dxs)}"
)
grad_extra_inputs_flat.extend(dxs)
# Update FP8 scaling factors # Update FP8 scaling factors
if func_ctx.is_first_module and not is_graph_capturing(): if func_ctx.is_first_module and not is_graph_capturing():
FP8GlobalStateManager.reduce_and_update_fp8_tensors(forward=False) FP8GlobalStateManager.reduce_and_update_fp8_tensors(forward=False)
...@@ -176,7 +264,10 @@ class _OperationFuserAutogradFunction(torch.autograd.Function): ...@@ -176,7 +264,10 @@ class _OperationFuserAutogradFunction(torch.autograd.Function):
None, # backward_ops None, # backward_ops
None, # basic_ops None, # basic_ops
None, # basic_op_kwargs None, # basic_op_kwargs
*grad_params_flat, # params None, # num_params
None, # num_extra_inputs
*grad_params_flat,
*grad_extra_inputs_flat,
) )
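The tuple returned above must contain one entry per argument of the autograd `forward`: `None` for every non-tensor argument, followed by parameter gradients and extra-input gradients in the same order the tensors were passed in. A small, self-contained analogue of that positional rule:

```python
import torch

class _ScaleAdd(torch.autograd.Function):
    """Toy analogue: one tensor input, one non-tensor arg, one extra tensor input."""

    @staticmethod
    def forward(ctx, x, scale, extra):
        ctx.scale = scale
        return x * scale + extra

    @staticmethod
    def backward(ctx, grad_output):
        # One entry per forward argument, in order:
        #   grad for x, None for the non-tensor `scale`, grad for `extra`.
        return grad_output * ctx.scale, None, grad_output
```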
...@@ -208,6 +299,9 @@ class OperationFuser: ...@@ -208,6 +299,9 @@ class OperationFuser:
self._num_basic_ops: int = len(basic_ops) self._num_basic_ops: int = len(basic_ops)
self._basic_ops: list[BasicOperation] = basic_ops self._basic_ops: list[BasicOperation] = basic_ops
# Number of extra tensor inputs
self._num_extra_inputs: int = sum(op.num_extra_inputs for op in basic_ops)
# Ops for forward and backward pass # Ops for forward and backward pass
self._forward_ops: list[tuple[FusibleOperation, list[int]]] self._forward_ops: list[tuple[FusibleOperation, list[int]]]
self._backward_ops: list[tuple[FusibleOperation, list[int]]] self._backward_ops: list[tuple[FusibleOperation, list[int]]]
...@@ -224,6 +318,7 @@ class OperationFuser: ...@@ -224,6 +318,7 @@ class OperationFuser:
ops: list[tuple[FusibleOperation, list[int]]], ops: list[tuple[FusibleOperation, list[int]]],
) -> list[tuple[FusibleOperation, list[int]]]: ) -> list[tuple[FusibleOperation, list[int]]]:
"""Attempt to fuse operations in forward pass""" """Attempt to fuse operations in forward pass"""
ops = fuse_forward_linear_bias_add(ops)
ops = fuse_forward_linear_bias_activation(ops) ops = fuse_forward_linear_bias_activation(ops)
return ops return ops
...@@ -233,6 +328,7 @@ class OperationFuser: ...@@ -233,6 +328,7 @@ class OperationFuser:
ops: list[tuple[FusibleOperation, list[int]]], ops: list[tuple[FusibleOperation, list[int]]],
) -> list[tuple[FusibleOperation, list[int]]]: ) -> list[tuple[FusibleOperation, list[int]]]:
"""Attempt to fuse operations in backward pass""" """Attempt to fuse operations in backward pass"""
ops = fuse_backward_linear_add(ops)
return ops return ops
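For context on what these passes operate over: each entry in `ops` pairs a (possibly fused) operation with the indices of the basic ops it covers, and a fusion pass rewrites adjacent entries that match a pattern. The sketch below shows the general shape of such a pass; it is illustrative only and is not the implementation of `fuse_backward_linear_add`:

```python
def fuse_adjacent(ops, pattern, make_fused):
    """Generic shape of a fusion pass: scan for adjacent basic ops whose
    classes match `pattern` and replace them with one fused op (illustrative)."""
    out = []
    i = 0
    while i < len(ops):
        window = ops[i : i + len(pattern)]
        if len(window) == len(pattern) and all(
            isinstance(op, cls) for (op, _), cls in zip(window, pattern)
        ):
            idxs = [idx for _, idx_list in window for idx in idx_list]
            out.append((make_fused([op for op, _ in window]), idxs))
            i += len(pattern)
        else:
            out.append(ops[i])
            i += 1
    return out
```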
def fuse_ops(self) -> None: def fuse_ops(self) -> None:
...@@ -243,8 +339,9 @@ class OperationFuser: ...@@ -243,8 +339,9 @@ class OperationFuser:
def __call__( def __call__(
self, self,
input: torch.Tensor, # pylint: disable=redefined-builtin input: torch.Tensor, # pylint: disable=redefined-builtin
*extra_inputs: torch.Tensor,
basic_op_kwargs: Optional[list[dict[str, Any]]] = None, basic_op_kwargs: Optional[list[dict[str, Any]]] = None,
) -> torch.Tensor: ) -> torch.Tensor | tuple[torch.Tensor, ...]:
# Initialization before forward pass # Initialization before forward pass
for op in self._basic_ops: for op in self._basic_ops:
...@@ -255,9 +352,7 @@ class OperationFuser: ...@@ -255,9 +352,7 @@ class OperationFuser:
basic_op_kwargs = [{} for _ in range(len(self._basic_ops))] basic_op_kwargs = [{} for _ in range(len(self._basic_ops))]
# Flatten list of parameters # Flatten list of parameters
params = [] params = [param for op in self._basic_ops for param in op.parameters()]
for op in self._basic_ops:
params.extend(op.parameters())
# Fuser forward pass # Fuser forward pass
return _OperationFuserAutogradFunction.apply( return _OperationFuserAutogradFunction.apply(
...@@ -266,5 +361,8 @@ class OperationFuser: ...@@ -266,5 +361,8 @@ class OperationFuser:
self._backward_ops, self._backward_ops,
self._basic_ops, self._basic_ops,
basic_op_kwargs, basic_op_kwargs,
len(params),
self._num_extra_inputs,
*params, *params,
*extra_inputs,
) )
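The calling convention established here: primary input first, then any extra tensor inputs positionally, with per-op kwargs as a keyword argument; the result is either a single tensor or a tuple beginning with the primary output. A hedged stand-in that mimics only the signature, not the real `OperationFuser`:

```python
import torch

def fake_fuser(x, *extra_inputs, basic_op_kwargs=None):
    """Stand-in with the same calling convention as OperationFuser.__call__."""
    extra_outputs = tuple(t + 1 for t in extra_inputs)
    return (x, *extra_outputs) if extra_outputs else x

x = torch.randn(4, 8)
extra_in = torch.randn(4, 8)
out, extra_out = fake_fuser(x, extra_in)
```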
...@@ -67,10 +67,12 @@ class FusibleOperation(torch.nn.Module, metaclass=abc.ABCMeta): ...@@ -67,10 +67,12 @@ class FusibleOperation(torch.nn.Module, metaclass=abc.ABCMeta):
self, self,
basic_op_ctxs: list[OperationContext], basic_op_ctxs: list[OperationContext],
input_: torch.Tensor, input_: torch.Tensor,
*,
basic_op_extra_inputs: list[tuple[torch.Tensor, ...]],
basic_op_prev_ops: list[Optional[BasicOperation]], basic_op_prev_ops: list[Optional[BasicOperation]],
basic_op_next_ops: list[Optional[BasicOperation]], basic_op_next_ops: list[Optional[BasicOperation]],
basic_op_kwargs: list[dict[str, Any]], basic_op_kwargs: list[dict[str, Any]],
) -> torch.Tensor: ) -> tuple[torch.Tensor, Iterable[Iterable[torch.Tensor]]]:
"""Forward pass """Forward pass
This op is either a basic op or the fusion of basic ops, so This op is either a basic op or the fusion of basic ops, so
...@@ -82,24 +84,27 @@ class FusibleOperation(torch.nn.Module, metaclass=abc.ABCMeta): ...@@ -82,24 +84,27 @@ class FusibleOperation(torch.nn.Module, metaclass=abc.ABCMeta):
Parameters Parameters
---------- ----------
basic_op_ctxs: list of OperationContext basic_op_ctxs: list of OperationContext
Contexts for corresponding basic operations Contexts for basic operations
input_: torch.Tensor input_: torch.Tensor
Input tensor Input tensor
basic_op_extra_inputs: list of tuple of torch.Tensor
Extra tensor inputs to basic operations
basic_op_prev_ops: list of BasicOperation basic_op_prev_ops: list of BasicOperation
Basic operations that precede each of the corresponding Basic operations that precede this operation's basic
basic operations (or `None` if corresponding basic op is operations
first)
basic_op_next_ops: list of BasicOperation basic_op_next_ops: list of BasicOperation
Basic operations that follow each of the corresponding Basic operations that follow this operation's basic
basic operations (or `None` if corresponding basic op is operations
last)
basic_op_kwargs: list of dict basic_op_kwargs: list of dict
Keyword arguments to forward functions of corresponding Keyword arguments to forward functions of basic
basic operations operations.
Returns Returns
------- -------
torch.Tensor: Output tensor. torch.Tensor:
Output tensor.
Iterable of torch.Tensor:
Extra tensor outputs from basic operations.
""" """
raise NotImplementedError( raise NotImplementedError(
...@@ -110,7 +115,13 @@ class FusibleOperation(torch.nn.Module, metaclass=abc.ABCMeta): ...@@ -110,7 +115,13 @@ class FusibleOperation(torch.nn.Module, metaclass=abc.ABCMeta):
self, self,
basic_op_ctxs: list[OperationContext], basic_op_ctxs: list[OperationContext],
grad_output: torch.Tensor, grad_output: torch.Tensor,
) -> tuple[torch.Tensor, Iterable[Iterable[Optional[torch.Tensor]]]]: *,
basic_op_grad_extra_outputs: list[tuple[torch.Tensor, ...]],
) -> tuple[
torch.Tensor,
Iterable[Iterable[Optional[torch.Tensor]]],
Iterable[Iterable[Optional[torch.Tensor]]],
]:
"""Backward pass """Backward pass
This op is either a basic op or the fusion of basic ops, so This op is either a basic op or the fusion of basic ops, so
...@@ -122,24 +133,21 @@ class FusibleOperation(torch.nn.Module, metaclass=abc.ABCMeta): ...@@ -122,24 +133,21 @@ class FusibleOperation(torch.nn.Module, metaclass=abc.ABCMeta):
Parameters Parameters
---------- ----------
basic_op_ctxs: list of OperationContext basic_op_ctxs: list of OperationContext
Contexts for corresponding basic operations. Contexts for basic operations
grad_output: torch.Tensor grad_output: torch.Tensor
Loss gradient w.r.t. operation output. Loss gradient w.r.t. operation output
basic_op_prev_ops: list of BasicOperation basic_op_grad_extra_outputs: list of tuple of torch.Tensor
Basic operations that precede each of the corresponding Loss gradients w.r.t. extra tensor outputs from basic
basic operations (or `None` if corresponding basic op is operations.
first)
basic_op_next_ops: list of BasicOperation
Basic operations that follow each of the corresponding
basic operations (or `None` if corresponding basic op is
last)
Returns Returns
------- -------
torch.Tensor: torch.Tensor:
Loss gradient w.r.t. operation input Loss gradient w.r.t. operation input
Iterable of iterable of torch.Tensor: Iterable of iterable of torch.Tensor:
Loss gradients w.r.t. parameters for corresponding basic Loss gradients w.r.t. parameters for basic operations
Iterable of iterable of torch.Tensor:
Loss gradients w.r.t. extra tensor inputs to basic
operations operations
""" """
...@@ -156,6 +164,11 @@ class BasicOperation(FusibleOperation, metaclass=abc.ABCMeta): ...@@ -156,6 +164,11 @@ class BasicOperation(FusibleOperation, metaclass=abc.ABCMeta):
""" """
# Number of extra tensor inputs
num_extra_inputs: int = 0
# Number of extra tensor outputs
num_extra_outputs: int = 0
def __init__(self) -> None: def __init__(self) -> None:
super().__init__() super().__init__()
...@@ -297,6 +310,7 @@ class BasicOperation(FusibleOperation, metaclass=abc.ABCMeta): ...@@ -297,6 +310,7 @@ class BasicOperation(FusibleOperation, metaclass=abc.ABCMeta):
self, self,
ctx: OperationContext, ctx: OperationContext,
input_: torch.Tensor, input_: torch.Tensor,
*,
prev_op: Optional[BasicOperation] = None, prev_op: Optional[BasicOperation] = None,
next_op: Optional[BasicOperation] = None, next_op: Optional[BasicOperation] = None,
**kwargs: Any, **kwargs: Any,
...@@ -309,6 +323,10 @@ class BasicOperation(FusibleOperation, metaclass=abc.ABCMeta): ...@@ -309,6 +323,10 @@ class BasicOperation(FusibleOperation, metaclass=abc.ABCMeta):
Context to coordinate between forward and backward passes Context to coordinate between forward and backward passes
input_: torch.Tensor input_: torch.Tensor
Input tensor Input tensor
prev_op: BasicOperation, optional
Basic operation that precedes this operation
next_op: BasicOperation, optional
Basic operation that follows this operation
Returns Returns
------- -------
...@@ -345,35 +363,63 @@ class BasicOperation(FusibleOperation, metaclass=abc.ABCMeta): ...@@ -345,35 +363,63 @@ class BasicOperation(FusibleOperation, metaclass=abc.ABCMeta):
self, self,
basic_op_ctxs: list[OperationContext], basic_op_ctxs: list[OperationContext],
input_: torch.Tensor, input_: torch.Tensor,
*,
basic_op_extra_inputs: list[tuple[torch.Tensor, ...]],
basic_op_prev_ops: list[Optional[BasicOperation]], basic_op_prev_ops: list[Optional[BasicOperation]],
basic_op_next_ops: list[Optional[BasicOperation]], basic_op_next_ops: list[Optional[BasicOperation]],
basic_op_kwargs: list[dict[str, Any]], basic_op_kwargs: list[dict[str, Any]],
) -> torch.Tensor: ) -> tuple[torch.Tensor, list[tuple[()]]]:
return self.op_forward( if self.num_extra_inputs > 0 or self.num_extra_outputs > 0:
raise RuntimeError(
"{self.__class__.__name__} operation has "
f"{self.num_extra_inputs} extra tensor inputs "
f"and {self.num_extra_outputs} extra tensor outputs. "
"It should override `fuser_forward` instead of `op_forward`."
)
output = self.op_forward(
basic_op_ctxs[0], basic_op_ctxs[0],
input_, input_,
basic_op_prev_ops[0], prev_op=basic_op_prev_ops[0],
basic_op_next_ops[0], next_op=basic_op_next_ops[0],
**basic_op_kwargs[0], **basic_op_kwargs[0],
) )
return output, [()]
def fuser_backward( def fuser_backward(
self, self,
basic_op_ctxs: list[OperationContext], basic_op_ctxs: list[OperationContext],
grad_output: torch.Tensor, grad_output: torch.Tensor,
) -> tuple[torch.Tensor, Iterable[Iterable[Optional[torch.Tensor]]]]: *,
basic_op_grad_extra_outputs: list[tuple[torch.Tensor, ...]],
) -> tuple[
torch.Tensor,
list[Iterable[Optional[torch.Tensor]]],
list[tuple[()]],
]:
if self.num_extra_inputs > 0 or self.num_extra_outputs > 0:
raise RuntimeError(
"{self.__class__.__name__} operation has "
f"{self.num_extra_inputs} extra tensor inputs "
f"and {self.num_extra_outputs} extra tensor outputs. "
"It should override `fuser_backward` instead of `op_backward`."
)
grad_input, grad_params = self.op_backward(basic_op_ctxs[0], grad_output) grad_input, grad_params = self.op_backward(basic_op_ctxs[0], grad_output)
return grad_input, [grad_params] return grad_input, [grad_params], [()]
def forward( def forward(
self, self,
input: torch.Tensor, # pylint: disable=redefined-builtin input: torch.Tensor, # pylint: disable=redefined-builtin
*extra_inputs: torch.Tensor,
**kwargs: Any, **kwargs: Any,
) -> torch.Tensor: ) -> torch.Tensor | tuple[torch.Tensor, ...]:
"""Apply operation""" """Apply operation"""
from .fuser import OperationFuser from .fuser import OperationFuser
return OperationFuser([self], fuse_ops=False)(input, [kwargs]) return OperationFuser([self], fuse_ops=False)(
input,
*extra_inputs,
basic_op_kwargs=[kwargs],
)
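To ground the extra-input contract end to end, here is a hedged sketch of a basic op that consumes one extra tensor and adds it to its input, overriding `fuser_forward`/`fuser_backward` as the runtime checks above require. The import path and class name are assumptions made for illustration; the ops actually added in this PR are more involved:

```python
from transformer_engine.pytorch.ops.op import BasicOperation  # assumed module path

class AddExtraInput(BasicOperation):
    """Hedged sketch: add one extra tensor input to the primary input."""

    num_extra_inputs = 1
    num_extra_outputs = 0

    def op_forward(self, ctx, input_, *, prev_op=None, next_op=None, **kwargs):
        # Ops with extra inputs implement fuser_forward/fuser_backward directly.
        raise RuntimeError("AddExtraInput must be run through fuser_forward")

    def op_backward(self, ctx, grad_output):
        raise RuntimeError("AddExtraInput must be run through fuser_backward")

    def fuser_forward(
        self,
        basic_op_ctxs,
        input_,
        *,
        basic_op_extra_inputs,
        basic_op_prev_ops,
        basic_op_next_ops,
        basic_op_kwargs,
    ):
        (extra,) = basic_op_extra_inputs[0]
        output = input_ + extra
        return output, [()]  # one entry per basic op; no extra outputs here

    def fuser_backward(
        self,
        basic_op_ctxs,
        grad_output,
        *,
        basic_op_grad_extra_outputs,
    ):
        # Addition passes the gradient through to both the primary input and
        # the extra input; this op has no parameters, hence the empty tuple.
        return grad_output, [()], [(grad_output,)]
```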
class FusedOperation(FusibleOperation): class FusedOperation(FusibleOperation):
...@@ -417,6 +463,7 @@ class FusedOperation(FusibleOperation): ...@@ -417,6 +463,7 @@ class FusedOperation(FusibleOperation):
def forward( def forward(
self, self,
input: torch.Tensor, # pylint: disable=redefined-builtin input: torch.Tensor, # pylint: disable=redefined-builtin
*extra_inputs: torch.Tensor,
basic_op_kwargs: Optional[list[dict[str, Any]]] = None, basic_op_kwargs: Optional[list[dict[str, Any]]] = None,
) -> torch.Tensor: ) -> torch.Tensor | tuple[torch.Tensor, ...]:
"""Apply operation""" """Apply operation"""
...@@ -424,4 +471,8 @@ class FusedOperation(FusibleOperation): ...@@ -424,4 +471,8 @@ class FusedOperation(FusibleOperation):
basic_op_kwargs = [{} for _ in range(len(self.basic_ops))] basic_op_kwargs = [{} for _ in range(len(self.basic_ops))]
from .fuser import OperationFuser from .fuser import OperationFuser
return OperationFuser([self], fuse_ops=False)(input, basic_op_kwargs) return OperationFuser([self], fuse_ops=False)(
input,
*extra_inputs,
basic_op_kwargs=basic_op_kwargs,
)
...@@ -144,28 +144,44 @@ class Sequential(torch.nn.Module): ...@@ -144,28 +144,44 @@ class Sequential(torch.nn.Module):
modules: Iterable[torch.nn.Module], modules: Iterable[torch.nn.Module],
) -> list[OperationFuser | torch.nn.Module]: ) -> list[OperationFuser | torch.nn.Module]:
"""Make list of modules, with fusible operations grouped together""" """Make list of modules, with fusible operations grouped together"""
module_groups = []
fusible_ops = []
def maybe_add_fuser():
nonlocal fusible_ops
if fusible_ops:
module_groups.append(OperationFuser(fusible_ops, fuse_ops=True))
fusible_ops = []
# Group fusible operations together
groups = []
for module in modules: for module in modules:
if isinstance(module, FusibleOperation): if isinstance(module, FusibleOperation):
fusible_ops.append(module) if not groups or not isinstance(groups[-1], list):
groups.append([])
groups[-1].append(module)
else: else:
maybe_add_fuser() groups.append(module)
module_groups.append(module) for idx, group in enumerate(groups):
maybe_add_fuser() if isinstance(group, list):
return module_groups groups[idx] = OperationFuser(group, fuse_ops=True)
# Check if operations expect extra input or output tensors
# Note: If any op has extra inputs or outputs, then the entire
# Sequential must be made up of TE ops.
if len(groups) > 1:
ops = []
for group in groups:
if isinstance(group, OperationFuser):
ops.extend(group._basic_ops)
num_extra_inputs = sum(op.num_extra_inputs for op in ops)
num_extra_outputs = sum(op.num_extra_outputs for op in ops)
if num_extra_inputs > 0 or num_extra_outputs > 0:
raise RuntimeError(
f"`Sequential` expects {num_extra_inputs} extra inputs "
f"and {num_extra_outputs} extra outputs, "
"but it contains non-fusible operations"
)
return groups
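A small illustration of the grouping rule implemented above, using stand-in classes rather than the TE modules: consecutive fusible ops collapse into one group and everything else stays as its own entry:

```python
class _Fusible:      # stand-in for FusibleOperation
    pass

class _PlainModule:  # stand-in for a generic torch.nn.Module
    pass

def group_modules(modules):
    """Mirrors the grouping loop above, but keeps groups as plain lists."""
    groups = []
    for module in modules:
        if isinstance(module, _Fusible):
            if not groups or not isinstance(groups[-1], list):
                groups.append([])
            groups[-1].append(module)
        else:
            groups.append(module)
    return groups

mods = [_Fusible(), _Fusible(), _PlainModule(), _Fusible()]
groups = group_modules(mods)
assert isinstance(groups[0], list) and len(groups[0]) == 2
assert isinstance(groups[1], _PlainModule)
assert isinstance(groups[2], list) and len(groups[2]) == 1
```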
def forward( def forward(
self, self,
input: torch.Tensor, # pylint: disable=redefined-builtin input: torch.Tensor, # pylint: disable=redefined-builtin
) -> torch.Tensor: *extra_inputs: torch.Tensor,
) -> torch.Tensor | tuple[torch.Tensor, ...]:
"""Forward pass""" """Forward pass"""
# Create module groups if needed # Create module groups if needed
...@@ -175,5 +191,5 @@ class Sequential(torch.nn.Module): ...@@ -175,5 +191,5 @@ class Sequential(torch.nn.Module):
# Forward pass for each module group # Forward pass for each module group
x = input x = input
for module_group in self._module_groups: for module_group in self._module_groups:
x = module_group(x) x = module_group(x, *extra_inputs)
return x return x