Unverified commit ee841084, authored by Jan Bielak and committed by GitHub

Add `in_place` kwarg to extra tensor ops (#1983)



* Mark output tensors as not deletable in backward
Signed-off-by: Jan Bielak <jbielak@nvidia.com>

* Add `in_place` kwarg to `MakeExtraOutput`
Signed-off-by: Jan Bielak <jbielak@nvidia.com>

* Rename `AddInPlace` to `AddExtraInput` and add an `in_place` kwarg
Signed-off-by: Jan Bielak <jbielak@nvidia.com>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci



---------
Signed-off-by: Jan Bielak <jbielak@nvidia.com>
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
Co-authored-by: Tim Moon <4406448+timmoon10@users.noreply.github.com>
parent fe27bf1c
@@ -272,19 +272,19 @@ class TestSequentialContainer:
             bias.bias.copy_(torch.rand((size,)))
         model = te_ops.Sequential(                  # | Inputs   | Outputs
             torch.nn.Identity(),                    # | x1       | x1
-            te_ops.MakeExtraOutput(),               # | x1       | x1 [x1]
+            te_ops.MakeExtraOutput(in_place=True),  # | x1       | x1 [x1]
             bias,                                   # | x1       | h1 (= x1 + b)
-            te_ops.MakeExtraOutput(),               # | h1       | h1 [h1]
-            te_ops.AddInPlace(),                    # | h1 [x2]  | x2 (= x2 + h1)
-            te_ops.MakeExtraOutput(),               # | x2       | x2 [x2]
+            te_ops.MakeExtraOutput(in_place=True),  # | h1       | h1 [h1]
+            te_ops.AddExtraInput(in_place=True),    # | h1 [x2]  | x2 (= x2 + h1)
+            te_ops.MakeExtraOutput(in_place=True),  # | x2       | x2 [x2]
             torch.nn.Identity(),                    # | x2       | x2
             bias,                                   # | x2       | h2 (= x2 + b)
-            te_ops.AddInPlace(),                    # | h2 [x3]  | x3 (= x3 + h2)
-            te_ops.MakeExtraOutput(),               # | x3       | x3 [x3]
-            te_ops.AddInPlace(),                    # | x3 [x4]  | x4 (= x4 + x3)
+            te_ops.AddExtraInput(in_place=True),    # | h2 [x3]  | x3 (= x3 + h2)
+            te_ops.MakeExtraOutput(in_place=True),  # | x3       | x3 [x3]
+            te_ops.AddExtraInput(in_place=True),    # | x3 [x4]  | x4 (= x4 + x3)
             torch.nn.Identity(),                    # | x4       | x4
             te_ops.Identity(),                      # | x4       | x4
-            te_ops.MakeExtraOutput(),               # | x4       | x4 [x4]
+            te_ops.MakeExtraOutput(in_place=True),  # | x4       | x4 [x4]
             te_ops.Identity(),                      # | x4       | x4
         )
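For orientation, the comment table above encodes the following dataflow. A minimal plain-PyTorch sketch, for illustration only; the tensor names x1..x4 and b are taken from the comments, and the in-place ops are modeled with add_:

import torch

def reference_pipeline(x1, x2, x3, x4, b):
    # Extra inputs x2, x3, x4 are consumed by AddExtraInput(in_place=True);
    # MakeExtraOutput(in_place=True) exposes intermediates as extra outputs.
    extra_outputs = [x1]      # MakeExtraOutput: expose x1
    h1 = x1 + b               # bias
    extra_outputs.append(h1)  # MakeExtraOutput: expose h1
    x2.add_(h1)               # AddExtraInput(in_place=True): x2 += h1
    extra_outputs.append(x2)  # MakeExtraOutput: expose x2
    h2 = x2 + b               # bias
    x3.add_(h2)               # AddExtraInput(in_place=True): x3 += h2
    extra_outputs.append(x3)  # MakeExtraOutput: expose x3
    x4.add_(x3)               # AddExtraInput(in_place=True): x4 += x3
    extra_outputs.append(x4)  # MakeExtraOutput: expose x4
    return x4, extra_outputs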
@@ -1402,13 +1402,15 @@ class TestBasicOps:
         torch.testing.assert_close(y_test, y_ref, **tols)
         torch.testing.assert_close(dx_test, x_ref.grad, **tols)

+    @pytest.mark.parametrize("in_place", (True, False))
     @pytest.mark.parametrize("dtype", _dtypes)
     @pytest.mark.parametrize("device", ("cuda", "cpu"))
     @pytest.mark.parametrize("quantization", _quantization_list)
-    def test_add_in_place(
+    def test_add_extra_input(
         self,
         *,
         in_shape: Iterable[int] = (32, 32),
+        in_place: bool,
         dtype: torch.dtype,
         device: torch.device,
         quantization: Optional[str],
@@ -1454,7 +1456,7 @@ class TestBasicOps:
         dx2_ref = dy_ref

         # Implementation with fusible operation
-        op = te_ops.AddInPlace()
+        op = te_ops.AddExtraInput(in_place=in_place)
         y_test = op(x1_test, x2_test)
         y_test.backward(dy_test)
@@ -1469,6 +1471,7 @@ class TestBasicOps:
         torch.testing.assert_close(dx1_test, dx1_ref, rtol=0, atol=0)
         torch.testing.assert_close(dx2_test, dx2_ref, rtol=0, atol=0)

+    @pytest.mark.parametrize("in_place", (True, False))
     @pytest.mark.parametrize("dtype", _dtypes)
     @pytest.mark.parametrize("device", ("cuda", "cpu"))
     @pytest.mark.parametrize("quantization", _quantization_list)
@@ -1476,6 +1479,7 @@ class TestBasicOps:
         self,
         *,
         in_shape: Iterable[int] = (32, 32),
+        in_place: bool,
         dtype: torch.dtype,
         device: torch.device,
         quantization: Optional[str],
@@ -1521,7 +1525,7 @@ class TestBasicOps:
         (y1_ref * dy1_ref + y2_ref * dy2_ref).sum().backward()

         # Implementation with fusible operation
-        op = te_ops.MakeExtraOutput()
+        op = te_ops.MakeExtraOutput(in_place=in_place)
         y1_test, y2_test = op(x_test)
         (y1_test * dy1_test + y2_test * dy2_test).sum().backward()
@@ -1885,7 +1889,7 @@ class TestFusedOps:
                 device=device,
                 dtype=dtype,
             ),
-            te_ops.AddInPlace(),
+            te_ops.AddExtraInput(in_place=True),
         )
         with torch.no_grad():
             model[0].weight.copy_(w_test)
@@ -2077,7 +2081,7 @@ class TestFusedOps:
         recipe = make_recipe(quantization)
         with te.fp8_model_init(enabled=quantized_weight):
             model = te_ops.Sequential(
-                te_ops.MakeExtraOutput(),
+                te_ops.MakeExtraOutput(in_place=True),
                 te_ops.Linear(
                     in_features,
                     out_features,
......
@@ -5,7 +5,7 @@
 """Single tensor operations supported by the operation fuser."""

 from .activation import GELU, ReLU, GEGLU, ReGLU, SwiGLU
-from .add_in_place import AddInPlace
+from .add_extra_input import AddExtraInput
 from .all_gather import AllGather
 from .all_reduce import AllReduce
 from .basic_linear import BasicLinear
......
@@ -2,7 +2,7 @@
 #
 # See LICENSE for license information.

-"""Fusible operation for in-place add."""
+"""Fusible operation for adding extra input tensor."""

 from __future__ import annotations
 from collections.abc import Iterable
@@ -18,16 +18,17 @@ from transformer_engine.pytorch.ops.op import (
 from transformer_engine.pytorch.tensor import Quantizer


-class AddInPlace(BasicOperation):
-    """Add in-place
+class AddExtraInput(BasicOperation):
+    """Add extra input tensor

-    This operation requires an extra tensor input to the operation
-    fuser. The main input is added in-place to the extra input, and a
-    view of the extra input is output.
+    This operation requires an extra tensor input to the operation
+    fuser. It returns the sum of the main input and the extra input.
+    If in_place=True, the main input is added in-place to the extra
+    input, and a view of the extra input is output.

-    This operation is considered an advanced feature and most users
-    are discouraged from using it. In-place operations break some
-    autograd assumptions and they can result in subtle, esoteric bugs.
+    Using this operation with in_place=True is considered an advanced
+    feature and most users are discouraged from using it. In-place
+    operations break some autograd assumptions and they can result in
+    subtle, esoteric bugs.

     Compare to `MakeExtraOutput`, which does a similar operation in
     the backward pass.
@@ -37,6 +38,10 @@ class AddInPlace(BasicOperation):
     # Operation expects buffer for output tensor
     num_extra_inputs: int = 1

+    def __init__(self, *, in_place: bool = False):
+        super().__init__()
+        self._in_place = in_place
+
     def op_forward(self, *args, **kwargs) -> None:
         raise RuntimeError(
             f"{self.__class__.__name__} operation has "
@@ -63,8 +68,13 @@ class AddInPlace(BasicOperation):
         next_op_input_quantizer: Optional[Quantizer],
         basic_op_kwargs: list[dict[str, Any]],
     ) -> tuple[torch.Tensor, Iterable[Iterable[torch.Tensor]]]:
-        output = basic_op_extra_inputs[0][0].detach()
-        output += input_
+        extra_input = basic_op_extra_inputs[0][0]
+        if self._in_place:
+            extra_input = extra_input.detach()
+            extra_input += input_
+            output = extra_input
+        else:
+            output = extra_input + input_
         return output, [()]

     def fuser_backward(
......
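For context, the updated test above drives this op as shown below; a minimal usage sketch (shapes, device, and gradient values are illustrative):

import torch
import transformer_engine.pytorch.ops as te_ops

x1 = torch.randn(32, 32, device="cuda", requires_grad=True)
x2 = torch.randn(32, 32, device="cuda", requires_grad=True)

# Default (out-of-place): y = x1 + x2 and x2's buffer is untouched.
op = te_ops.AddExtraInput()
y = op(x1, x2)
y.backward(torch.ones_like(y))
# The gradient of a sum flows unchanged to both addends,
# so x1.grad == x2.grad == dL/dy here.

# Advanced: with in_place=True the main input is accumulated into the
# extra input's buffer, and a view of that buffer is returned.
op_in_place = te_ops.AddExtraInput(in_place=True)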
@@ -22,14 +22,20 @@ class MakeExtraOutput(BasicOperation):

     If this operation is included in the operation fuser, then the
     operation fuser will return the intermediate tensor as an extra
-    tensor output. In the backward pass, the gradient is directly
-    accumulated into the gradient w.r.t. the extra output.
+    tensor output.

-    This operation is considered an advanced feature and most users
-    are discouraged from using it. In-place operations break some
-    autograd assumptions and they can result in subtle, esoteric bugs.
+    In the backward pass, the gradient may be directly accumulated
+    into the gradient w.r.t. the extra output. This is controlled by
+    the in_place kwarg. Currently, the BackwardLinearAdd fusion can
+    only happen with in_place=True.
+
+    Using this operation with in_place=True is considered an advanced
+    feature. Most users are discouraged from enabling in-place
+    gradient accumulation, as in-place operations break some autograd
+    assumptions and they can result in subtle, esoteric bugs.

-    Compare to `AddInPlace`, which does a similar operation in the
-    backward pass.
+    Compare to `AddExtraInput`, which does a similar operation in the
+    forward pass.

     """
@@ -37,6 +43,10 @@ class MakeExtraOutput(BasicOperation):
     # Operation expects buffer for output tensor
     num_extra_outputs: int = 1

+    def __init__(self, *, in_place: bool = False):
+        super().__init__()
+        self._in_place: bool = in_place
+
     def op_forward(self, *args, **kwargs) -> None:
         raise RuntimeError(
             f"{self.__class__.__name__} operation has "
@@ -76,6 +86,10 @@ class MakeExtraOutput(BasicOperation):
         Iterable[Iterable[Optional[torch.Tensor]]],
         Iterable[Iterable[Optional[torch.Tensor]]],
     ]:
-        grad_input = basic_op_grad_extra_outputs[0][0]
-        grad_input += grad_output
+        grad_extra_output = basic_op_grad_extra_outputs[0][0]
+        if self._in_place:
+            grad_extra_output += grad_output
+            grad_input = grad_extra_output
+        else:
+            grad_input = grad_extra_output + grad_output
         return grad_input, [()], [()]
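As in the updated test above, a minimal usage sketch (sizes and loss weights are illustrative):

import torch
import transformer_engine.pytorch.ops as te_ops

x = torch.randn(32, 32, device="cuda", requires_grad=True)

# The op returns its input twice: the main output and an extra output.
op = te_ops.MakeExtraOutput()  # out-of-place gradient accumulation
y1, y2 = op(x)

# Both outputs feed the loss; their gradients are summed into x.grad.
# With in_place=True the sum is accumulated directly into the extra
# output's gradient buffer (advanced; see the docstring above).
loss = (2.0 * y1 + 3.0 * y2).sum()
loss.backward()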
@@ -139,6 +139,8 @@ def fuse_backward_linear_add(
         op, _ = ops[0]
         if not isinstance(op, MakeExtraOutput):
             continue
+        if not op._in_place:
+            continue
         window.extend(ops[:1])
         ops = ops[1:]
......
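In other words, the BackwardLinearAdd fusion now only fires for ops constructed with in_place=True, matching the fused-ops test above. A sketch of an eligible model; sizes, device, and dtype are illustrative:

import torch
import transformer_engine.pytorch.ops as te_ops

in_features, out_features = 32, 32  # illustrative sizes

# MakeExtraOutput(in_place=True) followed by a linear layer: in the
# backward pass, the linear's input gradient can be accumulated
# directly into the extra output's gradient, which is what the fusion
# exploits. With in_place=False, the ops run unfused.
model = te_ops.Sequential(
    te_ops.MakeExtraOutput(in_place=True),
    te_ops.Linear(in_features, out_features, device="cuda", dtype=torch.float32),
)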
@@ -11,7 +11,7 @@ from typing import Any, Optional
 import torch

 from transformer_engine.pytorch.fp8 import FP8GlobalStateManager
-from transformer_engine.pytorch.ops.basic import AddInPlace, BasicLinear, Bias
+from transformer_engine.pytorch.ops.basic import AddExtraInput, BasicLinear, Bias
 from transformer_engine.pytorch.ops.op import (
     FusedOperation,
     FusibleOperation,
@@ -33,7 +33,7 @@ class ForwardLinearBiasAdd(FusedOperation):
         *,
         linear: BasicLinear,
         bias: Optional[Bias],
-        add: AddInPlace,
+        add: AddExtraInput,
     ) -> None:

         # Basic operations that comprise this fused operation
@@ -179,8 +179,10 @@ def fuse_forward_linear_bias_add(
             continue
         op, _ = ops[0]

-        # Check if next op is add in-place
-        if not isinstance(op, AddInPlace):
+        # Check if next op is in-place add extra input
+        if not isinstance(op, AddExtraInput):
+            continue
+        if not op._in_place:
             continue
         add = op
         window.extend(ops[:1])
......
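Analogously to the backward-side fusion, ForwardLinearBiasAdd now requires the trailing add to be in-place. A sketch of an eligible model, assuming the fuser's convention (seen in TestSequentialContainer above) that extra inputs are passed as additional positional arguments:

import torch
import transformer_engine.pytorch.ops as te_ops

# The linear output is accumulated into the residual buffer by the
# fused op; with in_place=False the ops stay unfused.
model = te_ops.Sequential(
    te_ops.Linear(32, 32, device="cuda", dtype=torch.bfloat16),
    te_ops.AddExtraInput(in_place=True),
)
x = torch.randn(32, 32, device="cuda", dtype=torch.bfloat16)
residual = torch.randn(32, 32, device="cuda", dtype=torch.bfloat16)
y = model(x, residual)  # residual is the add op's extra input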
@@ -197,6 +197,10 @@ class _OperationFuserAutogradFunction(torch.autograd.Function):
         func_ctx.is_first_module = FP8GlobalStateManager.is_first_fp8_module()
         func_ctx.with_quantized_compute = with_quantized_compute

+        # Mark output tensors as not deletable in backward
+        for tensor in [x] + extra_outputs_flat:
+            tensor.do_not_clear = True
+
         x.requires_grad_(fuser.first_op_requiring_backward < fuser._num_basic_ops)

         if extra_outputs_flat:
......
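The do_not_clear marker above is a plain Python attribute set on the output tensors. A minimal sketch of the flagging pattern, with a hypothetical consumer (release_saved_tensors is illustrative, not the library's API) that skips flagged tensors when freeing saved buffers:

import torch

def release_saved_tensors(tensors: list) -> None:
    """Hypothetical consumer: free saved buffers after backward.

    Tensors flagged with `do_not_clear` (e.g. tensors that are also
    user-visible outputs) keep their storage intact.
    """
    for t in tensors:
        if getattr(t, "do_not_clear", False):
            continue  # user-visible output: leave its storage alone
        t.data = torch.empty(0, device=t.device)  # drop the storage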