Unverified Commit ee841084 authored by Jan Bielak, committed by GitHub

Add `in_place` kwarg to extra tensor ops (#1983)



* Mark output tensors as not deletable in backward
Signed-off-by: Jan Bielak <jbielak@nvidia.com>

* Add `in_place` kwarg to `MakeExtraOutput`
Signed-off-by: Jan Bielak <jbielak@nvidia.com>

* Rename `AddInPlace` to `AddExtraInput` and add an `in_place` kwarg
Signed-off-by: Jan Bielak <jbielak@nvidia.com>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci



---------
Signed-off-by: Jan Bielak <jbielak@nvidia.com>
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
Co-authored-by: Tim Moon <4406448+timmoon10@users.noreply.github.com>
parent fe27bf1c
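
For orientation, a minimal sketch of the API change, assuming the `te_ops` alias used in the tests below (the `in_place` default of `False` is taken from the diff):

```python
import transformer_engine.pytorch.ops as te_ops

# Both ops now take an `in_place` kwarg; out-of-place is the default, and
# the old in-place behavior (and the related fusions) is opt-in.
make_extra_output = te_ops.MakeExtraOutput()           # out-of-place default
add_extra_input = te_ops.AddExtraInput(in_place=True)  # opt in to in-place add
```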
@@ -272,19 +272,19 @@ class TestSequentialContainer:
             bias.bias.copy_(torch.rand((size,)))
         model = te_ops.Sequential(                 # | Inputs  | Outputs
             torch.nn.Identity(),                   # | x1      | x1
-            te_ops.MakeExtraOutput(),              # | x1      | x1 [x1]
+            te_ops.MakeExtraOutput(in_place=True), # | x1      | x1 [x1]
             bias,                                  # | x1      | h1 (= x1 + b)
-            te_ops.MakeExtraOutput(),              # | h1      | h1 [h1]
-            te_ops.AddInPlace(),                   # | h1 [x2] | x2 (= x2 + h1)
-            te_ops.MakeExtraOutput(),              # | x2      | x2 [x2]
+            te_ops.MakeExtraOutput(in_place=True), # | h1      | h1 [h1]
+            te_ops.AddExtraInput(in_place=True),   # | h1 [x2] | x2 (= x2 + h1)
+            te_ops.MakeExtraOutput(in_place=True), # | x2      | x2 [x2]
             torch.nn.Identity(),                   # | x2      | x2
             bias,                                  # | x2      | h2 (= x2 + b)
-            te_ops.AddInPlace(),                   # | h2 [x3] | x3 (= x3 + h2)
-            te_ops.MakeExtraOutput(),              # | x3      | x3 [x3]
-            te_ops.AddInPlace(),                   # | x3 [x4] | x4 (= x4 + x3)
+            te_ops.AddExtraInput(in_place=True),   # | h2 [x3] | x3 (= x3 + h2)
+            te_ops.MakeExtraOutput(in_place=True), # | x3      | x3 [x3]
+            te_ops.AddExtraInput(in_place=True),   # | x3 [x4] | x4 (= x4 + x3)
             torch.nn.Identity(),                   # | x4      | x4
             te_ops.Identity(),                     # | x4      | x4
-            te_ops.MakeExtraOutput(),              # | x4      | x4 [x4]
+            te_ops.MakeExtraOutput(in_place=True), # | x4      | x4 [x4]
             te_ops.Identity(),                     # | x4      | x4
         )
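
The bracketed tensors in the comment table above are the extra inputs and outputs threaded through the operation fuser. A hedged sketch of how such a container might be invoked; the call shape is inferred from the comment table, not shown in this diff:

```python
# Hypothetical call shape: extra inputs (x2, x3, x4) follow the main input
# x1, and the extra outputs ([x1], [h1], [x2], [x3], [x4]) are returned
# alongside the main output x4.
outputs = model(x1, x2, x3, x4)
```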
@@ -1402,13 +1402,15 @@ class TestBasicOps:
         torch.testing.assert_close(y_test, y_ref, **tols)
         torch.testing.assert_close(dx_test, x_ref.grad, **tols)
 
+    @pytest.mark.parametrize("in_place", (True, False))
     @pytest.mark.parametrize("dtype", _dtypes)
     @pytest.mark.parametrize("device", ("cuda", "cpu"))
     @pytest.mark.parametrize("quantization", _quantization_list)
-    def test_add_in_place(
+    def test_add_extra_input(
         self,
         *,
         in_shape: Iterable[int] = (32, 32),
+        in_place: bool,
         dtype: torch.dtype,
         device: torch.device,
         quantization: Optional[str],
@@ -1454,7 +1456,7 @@ class TestBasicOps:
         dx2_ref = dy_ref
 
         # Implementation with fusible operation
-        op = te_ops.AddInPlace()
+        op = te_ops.AddExtraInput(in_place=in_place)
         y_test = op(x1_test, x2_test)
         y_test.backward(dy_test)
@@ -1469,6 +1471,7 @@ class TestBasicOps:
         torch.testing.assert_close(dx1_test, dx1_ref, rtol=0, atol=0)
         torch.testing.assert_close(dx2_test, dx2_ref, rtol=0, atol=0)
 
+    @pytest.mark.parametrize("in_place", (True, False))
     @pytest.mark.parametrize("dtype", _dtypes)
     @pytest.mark.parametrize("device", ("cuda", "cpu"))
     @pytest.mark.parametrize("quantization", _quantization_list)
@@ -1476,6 +1479,7 @@ class TestBasicOps:
         self,
         *,
         in_shape: Iterable[int] = (32, 32),
+        in_place: bool,
         dtype: torch.dtype,
         device: torch.device,
         quantization: Optional[str],
@@ -1521,7 +1525,7 @@ class TestBasicOps:
         (y1_ref * dy1_ref + y2_ref * dy2_ref).sum().backward()
 
         # Implementation with fusible operation
-        op = te_ops.MakeExtraOutput()
+        op = te_ops.MakeExtraOutput(in_place=in_place)
         y1_test, y2_test = op(x_test)
         (y1_test * dy1_test + y2_test * dy2_test).sum().backward()
@@ -1885,7 +1889,7 @@ class TestFusedOps:
                 device=device,
                 dtype=dtype,
             ),
-            te_ops.AddInPlace(),
+            te_ops.AddExtraInput(in_place=True),
         )
         with torch.no_grad():
             model[0].weight.copy_(w_test)
@@ -2077,7 +2081,7 @@ class TestFusedOps:
         recipe = make_recipe(quantization)
         with te.fp8_model_init(enabled=quantized_weight):
             model = te_ops.Sequential(
-                te_ops.MakeExtraOutput(),
+                te_ops.MakeExtraOutput(in_place=True),
                 te_ops.Linear(
                     in_features,
                     out_features,
......
@@ -5,7 +5,7 @@
 """Single tensor operations supported by the operation fuser."""
 
 from .activation import GELU, ReLU, GEGLU, ReGLU, SwiGLU
-from .add_in_place import AddInPlace
+from .add_extra_input import AddExtraInput
 from .all_gather import AllGather
 from .all_reduce import AllReduce
 from .basic_linear import BasicLinear
......
@@ -2,7 +2,7 @@
 #
 # See LICENSE for license information.
 
-"""Fusible operation for in-place add."""
+"""Fusible operation for adding an extra input tensor."""
 
 from __future__ import annotations
 from collections.abc import Iterable
@@ -18,16 +18,17 @@ from transformer_engine.pytorch.ops.op import (
 from transformer_engine.pytorch.tensor import Quantizer
 
 
-class AddInPlace(BasicOperation):
-    """Add in-place
+class AddExtraInput(BasicOperation):
+    """Add extra input tensor
 
     This operation requires an extra tensor input to the operation
-    fuser. The main input is added in-place to the extra input, and a
-    view of the extra input is output.
+    fuser. It returns the sum of the main input and the extra input.
+    If in_place=True, the main input is added in-place to the extra
+    input, and a view of the extra input is output.
 
-    This operation is considered an advanced feature and most users
-    are discouraged from using it. In-place operations break some
-    autograd assumptions and they can result in subtle, esoteric bugs.
+    Using this operation with in_place=True is considered an advanced
+    feature and most users are discouraged from using it. In-place
+    operations break some autograd assumptions and they can result in
+    subtle, esoteric bugs.
 
     Compare to `MakeExtraOutput`, which does a similar operation in
     the backward pass.
@@ -37,6 +38,10 @@ class AddInPlace(BasicOperation):
     # Operation expects buffer for output tensor
     num_extra_inputs: int = 1
 
+    def __init__(self, *, in_place: bool = False):
+        super().__init__()
+        self._in_place = in_place
+
     def op_forward(self, *args, **kwargs) -> None:
         raise RuntimeError(
             f"{self.__class__.__name__} operation has "
@@ -63,8 +68,13 @@ class AddInPlace(BasicOperation):
         next_op_input_quantizer: Optional[Quantizer],
         basic_op_kwargs: list[dict[str, Any]],
     ) -> tuple[torch.Tensor, Iterable[Iterable[torch.Tensor]]]:
-        output = basic_op_extra_inputs[0][0].detach()
-        output += input_
+        extra_input = basic_op_extra_inputs[0][0]
+        if self._in_place:
+            extra_input = extra_input.detach()
+            extra_input += input_
+            output = extra_input
+        else:
+            output = extra_input + input_
         return output, [()]
 
     def fuser_backward(
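
The forward path above reduces to a choice between aliasing the extra-input buffer and allocating a fresh tensor. A standalone PyTorch sketch of the two branches:

```python
import torch

buf = torch.zeros(4)  # stands in for the extra input buffer
x = torch.ones(4)     # stands in for the main input

# in_place=True branch: detach a view, accumulate into it; output aliases buf
out = buf.detach()
out += x
assert out.data_ptr() == buf.data_ptr()  # same storage as the buffer

# in_place=False branch: ordinary add; buf is left untouched
out2 = buf + x
assert out2.data_ptr() != buf.data_ptr()  # freshly allocated
```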
......
@@ -22,14 +22,20 @@ class MakeExtraOutput(BasicOperation):
     If this operation is included in the operation fuser, then the
     operation fuser will return the intermediate tensor as an extra
-    tensor output. In the backward pass, the gradient is directly
-    accumulated into the gradient w.r.t. the extra output.
+    tensor output.
 
-    This operation is considered an advanced feature and most users
-    are discouraged from using it. In-place operations break some
-    autograd assumptions and they can result in subtle, esoteric bugs.
+    In the backward pass, the gradient may be directly accumulated
+    into the gradient w.r.t. the extra output. This is controlled by
+    the in_place kwarg. Currently, the BackwardLinearAdd fusion can
+    only occur with in_place=True.
 
-    Compare to `AddInPlace`, which does a similar operation in the
+    Using this operation with in_place=True is considered an advanced
+    feature. Most users are discouraged from enabling in-place
+    gradient accumulation, as in-place operations break some autograd
+    assumptions and they can result in subtle, esoteric bugs.
+
+    Compare to `AddExtraInput`, which does a similar operation in the
     backward pass.
     """
@@ -37,6 +43,10 @@ class MakeExtraOutput(BasicOperation):
     # Operation expects buffer for output tensor
     num_extra_outputs: int = 1
 
+    def __init__(self, *, in_place: bool = False):
+        super().__init__()
+        self._in_place: bool = in_place
+
     def op_forward(self, *args, **kwargs) -> None:
         raise RuntimeError(
             f"{self.__class__.__name__} operation has "
@@ -76,6 +86,10 @@ class MakeExtraOutput(BasicOperation):
         Iterable[Iterable[Optional[torch.Tensor]]],
         Iterable[Iterable[Optional[torch.Tensor]]],
     ]:
-        grad_input = basic_op_grad_extra_outputs[0][0]
-        grad_input += grad_output
+        grad_extra_output = basic_op_grad_extra_outputs[0][0]
+        if self._in_place:
+            grad_extra_output += grad_output
+            grad_input = grad_extra_output
+        else:
+            grad_input = grad_extra_output + grad_output
         return grad_input, [()], [()]
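
The backward logic mirrors `AddExtraInput`'s forward: with `in_place=True` the incoming gradient is accumulated directly into the gradient buffer for the extra output, and the result aliases that buffer. A plain-tensor sketch of the two branches:

```python
import torch

def backward_add(grad_extra_output: torch.Tensor,
                 grad_output: torch.Tensor,
                 in_place: bool) -> torch.Tensor:
    """Return the gradient w.r.t. the op input (sketch of the logic above)."""
    if in_place:
        grad_extra_output += grad_output  # accumulate into the existing buffer
        return grad_extra_output          # aliases the buffer
    return grad_extra_output + grad_output  # fresh tensor; buffer untouched
```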
@@ -139,6 +139,8 @@ def fuse_backward_linear_add(
         op, _ = ops[0]
         if not isinstance(op, MakeExtraOutput):
             continue
+        if not op._in_place:
+            continue
         window.extend(ops[:1])
         ops = ops[1:]
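
The backward-pass fusion only fires for the in-place variant, per the `op._in_place` gate above. The matched pattern is the one from the second `TestFusedOps` hunk earlier in this diff; a condensed sketch, reusing the `te_ops` alias from the tests:

```python
import torch
import transformer_engine.pytorch.ops as te_ops

# Only MakeExtraOutput(in_place=True) is eligible for the BackwardLinearAdd
# fusion; sizes and kwargs here are illustrative.
model = te_ops.Sequential(
    te_ops.MakeExtraOutput(in_place=True),
    te_ops.Linear(32, 32, device="cuda", dtype=torch.float32),
)
```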
......
@@ -11,7 +11,7 @@ from typing import Any, Optional
 import torch
 
 from transformer_engine.pytorch.fp8 import FP8GlobalStateManager
-from transformer_engine.pytorch.ops.basic import AddInPlace, BasicLinear, Bias
+from transformer_engine.pytorch.ops.basic import AddExtraInput, BasicLinear, Bias
 from transformer_engine.pytorch.ops.op import (
     FusedOperation,
     FusibleOperation,
@@ -33,7 +33,7 @@ class ForwardLinearBiasAdd(FusedOperation):
         *,
         linear: BasicLinear,
         bias: Optional[Bias],
-        add: AddInPlace,
+        add: AddExtraInput,
     ) -> None:
 
         # Basic operations that comprise this fused operation
@@ -179,8 +179,10 @@ def fuse_forward_linear_bias_add(
             continue
         op, _ = ops[0]
-        # Check if next op is add in-place
-        if not isinstance(op, AddInPlace):
+        # Check if next op is an in-place add of an extra input
+        if not isinstance(op, AddExtraInput):
             continue
+        if not op._in_place:
+            continue
         add = op
         window.extend(ops[:1])
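
Symmetrically, the forward-pass fusion matches the pattern built in the first `TestFusedOps` hunk above: a linear op (optionally with bias) followed by an in-place `AddExtraInput`. A condensed sketch, with sizes and kwargs mirroring that test:

```python
import torch
import transformer_engine.pytorch.ops as te_ops

# Only the in_place=True variant participates in the ForwardLinearBiasAdd
# fusion, per the gate above.
model = te_ops.Sequential(
    te_ops.Linear(32, 32, device="cuda", dtype=torch.float32),
    te_ops.AddExtraInput(in_place=True),
)
```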
......
@@ -197,6 +197,10 @@ class _OperationFuserAutogradFunction(torch.autograd.Function):
         func_ctx.is_first_module = FP8GlobalStateManager.is_first_fp8_module()
         func_ctx.with_quantized_compute = with_quantized_compute
 
+        # Mark output tensors as not deletable in backward
+        for tensor in [x] + extra_outputs_flat:
+            tensor.do_not_clear = True
+
         x.requires_grad_(fuser.first_op_requiring_backward < fuser._num_basic_ops)
         if extra_outputs_flat:
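
`do_not_clear` looks like an ad-hoc Python attribute that the fuser is assumed to consult before freeing activations during backward; PyTorch tensors accept arbitrary attributes, so the mechanism is plain attribute tagging. A sketch under that assumption:

```python
import torch

t = torch.empty(2)
t.do_not_clear = True  # plain Python attribute, stored in t.__dict__

# Consumer side (assumed): skip tensors tagged as not clearable.
for saved in (t,):
    if getattr(saved, "do_not_clear", False):
        continue  # keep the storage alive for in-place consumers
```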
......