Unverified commit 37da2d3b authored by Jan Bielak, committed by GitHub

Add backward fusions of dbias+quantize and dbias+dactivation+quantize to `te.Sequential` (#1942)



* Fix clearing tensor data in backward by removing is_first_op
Signed-off-by: Jan Bielak <jbielak@nvidia.com>

* Misc fixes
Signed-off-by: Jan Bielak <jbielak@nvidia.com>

* Use Linear weight dtype and device for compute consistently
Signed-off-by: Jan Bielak <jbielak@nvidia.com>

* Add backward dbias + quantize fusion
Signed-off-by: Jan Bielak <jbielak@nvidia.com>

* Pass recipe to OperationFuser to allow recipe-dependent fusions
Signed-off-by: Jan Bielak <jbielak@nvidia.com>

* Remove redundant view from activations
Signed-off-by: Jan Bielak <jbielak@nvidia.com>

* Add bias activation backward fusion
Signed-off-by: Jan Bielak <jbielak@nvidia.com>

* Apply suggestions from code review
Co-authored-by: Tim Moon <4406448+timmoon10@users.noreply.github.com>
Signed-off-by: Jan Bielak <jbielak@nvidia.com>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci



---------
Signed-off-by: Jan Bielak <jbielak@nvidia.com>
Co-authored-by: Tim Moon <4406448+timmoon10@users.noreply.github.com>
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
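
Below is a minimal usage sketch of the new backward fusion, adapted from the test added in this PR. It assumes an FP8-capable GPU and a delayed-scaling recipe (the test's make_recipe helper is replaced here with DelayedScaling as a stand-in choice):

    import torch
    import transformer_engine.pytorch as te
    import transformer_engine.pytorch.ops as te_ops
    from transformer_engine.common.recipe import DelayedScaling

    # Quantize(backward=True) quantizes the gradient leaving the op chain, so
    # the backward pass can run dbias + drelu + quantize as one fused kernel.
    model = te_ops.Sequential(
        te_ops.Quantize(forward=False, backward=True),
        te_ops.Bias(32, device="cuda", dtype=torch.bfloat16),
        te_ops.ReLU(),
    )
    x = torch.randn(32, 32, device="cuda", dtype=torch.bfloat16, requires_grad=True)
    with te.fp8_autocast(enabled=True, fp8_recipe=DelayedScaling()):
        y = model(x)
    y.backward(torch.randn_like(y))

    # The fused op is visible in the planned backward ops:
    backward_ops = model._module_groups[0]._backward_ops
    print([type(op).__name__ for op, _ in backward_ops])
    # expected: ['BackwardBiasActivation', 'Quantize']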
parent ac76d55c
...
@@ -20,6 +20,7 @@ import transformer_engine.pytorch as te
 from transformer_engine.pytorch.fp8 import FP8GlobalStateManager
 import transformer_engine.pytorch.ops as te_ops
 from transformer_engine.pytorch.ops.fused import (
+    BackwardBiasActivation,
     BackwardLinearAdd,
     ForwardLinearBiasActivation,
     ForwardLinearBiasAdd,
...
@@ -1865,6 +1866,98 @@ class TestFusedOps:
         db_test = model[0].bias.grad.to(dtype=torch.float64, device="cpu")
         torch.testing.assert_close(db_test, b_ref.grad, **tols)
@pytest.mark.parametrize("activation", ("relu", "gelu"))
@pytest.mark.parametrize("out_shape", ((32, 32), (32, 1, 32), (8, 2, 2, 32)))
@pytest.mark.parametrize("dtype", _dtypes)
@pytest.mark.parametrize("quantization", _quantization_list)
def test_backward_bias_activation(
self,
*,
activation: str,
out_shape: Iterable[int],
dtype: torch.dtype,
device: torch.device = "cuda",
quantization: Optional[str],
) -> None:
"""Backward dbias + dact + quantize"""
# Tensor dimensions
in_shape = list(out_shape)
hidden_size = in_shape[-1]
# Skip invalid configurations
with_quantization = quantization is not None
maybe_skip_quantization(quantization, device=device)
if quantization == "mxfp8" and (len(in_shape) < 2 or in_shape[-1] % 32 != 0):
pytest.skip("Unsupported tensor size for MXFP8")
# Random data
x_ref, x_test = make_reference_and_test_tensors(
in_shape,
test_dtype=dtype,
test_device=device,
)
b_ref, b_test = make_reference_and_test_tensors(
hidden_size,
test_dtype=dtype,
test_device=device,
)
dy_ref, dy_test = make_reference_and_test_tensors(
in_shape,
test_dtype=dtype,
test_device=device,
requires_grad=False,
)
# Plain PyTorch implementation
y_ref = x_ref + b_ref.reshape([1] * (len(in_shape) - 1) + [hidden_size])
if activation == "gelu":
y_ref = torch.nn.functional.gelu(y_ref, approximate="tanh")
elif activation == "relu":
y_ref = torch.nn.functional.relu(y_ref)
else:
raise ValueError(f"Unexpected activation function ({activation})")
y_ref.backward(dy_ref)
# Implementation with fusible operations
recipe = make_recipe(quantization)
act_type = te_ops.GELU if activation == "gelu" else te_ops.ReLU
model = te_ops.Sequential(
te_ops.Quantize(forward=False, backward=True),
te_ops.Bias(hidden_size, device=device, dtype=dtype),
act_type(),
)
with torch.no_grad():
model[1].bias.copy_(b_test)
del b_test
with te.fp8_autocast(enabled=with_quantization, fp8_recipe=recipe):
y_test = model(x_test)
y_test.backward(dy_test)
# Check that backward operations have been fused
backward_ops = model._module_groups[0]._backward_ops
if with_quantization and quantization in ["fp8_delayed_scaling", "mxfp8"]:
assert len(backward_ops) == 2
assert isinstance(backward_ops[0][0], BackwardBiasActivation)
assert isinstance(backward_ops[1][0], te_ops.Quantize)
else:
assert len(backward_ops) == 3
assert isinstance(backward_ops[0][0], act_type)
assert isinstance(backward_ops[1][0], te_ops.Bias)
assert isinstance(backward_ops[2][0], te_ops.Quantize)
# Expected numerical error
tols = dtype_tols(dtype)
if with_quantization:
tols = dtype_tols(tex.DType.kFloat8E4M3)
y_test = y_test.to(dtype=torch.float64, device="cpu")
dx_test = x_test.grad.to(dtype=torch.float64, device="cpu")
db_test = model[1].bias.grad.to(dtype=torch.float64, device="cpu")
torch.testing.assert_close(y_test, y_ref, **tols)
torch.testing.assert_close(dx_test, x_ref.grad, **tols)
torch.testing.assert_close(db_test, b_ref.grad, **tols)
@pytest.mark.parametrize("dtype", _dtypes) @pytest.mark.parametrize("dtype", _dtypes)
@pytest.mark.parametrize("quantization", _quantization_list) @pytest.mark.parametrize("quantization", _quantization_list)
def test_backward_linear_add( def test_backward_linear_add(
......
...
@@ -922,17 +922,20 @@ template <typename ParamOP, float (*ActOP)(float, const ParamOP &)>
 void cast_gated(const Tensor &input, Tensor *output, cudaStream_t stream) {
   CheckInputTensor(input, "gated_act_input");
   CheckOutputTensor(*output, "gated_act_output");
-  NVTE_CHECK(input.data.shape.size() == 2, "Input must have 2 dimensions.");
-  NVTE_CHECK(output->data.shape.size() == 2, "Output must have 2 dimensions.");
-  NVTE_CHECK(input.data.shape[0] == output->data.shape[0],
-             "Input shape[0] must be equal to output shape[0].");
-  NVTE_CHECK(input.data.shape[1] == output->data.shape[1] * 2,
-             "Input shape[1] must be 2x larger than output shape[1].");
+  NVTE_CHECK(output->flat_first_dim() == input.flat_first_dim(),
+             "Wrong output shape. Expected (after flattening) [", input.flat_first_dim(),
+             ", *], got [", output->flat_first_dim(), ", ", output->flat_last_dim(), "].");
+  NVTE_CHECK(input.flat_last_dim() % 2 == 0,
+             "Wrong input shape. Expected (after flattening) last dimension to be even, ", "got [",
+             input.flat_first_dim(), ", ", input.flat_last_dim(), "].");
+  NVTE_CHECK(output->flat_last_dim() == input.flat_last_dim() / 2,
+             "Wrong output shape. Expected (after flattening) [*, ", input.flat_last_dim() / 2,
+             "], got [", output->flat_first_dim(), ", ", output->flat_last_dim(), "].");

   TRANSFORMER_ENGINE_TYPE_SWITCH_INPUT(
-      input.data.dtype, IType,
+      input.dtype(), IType,
       TRANSFORMER_ENGINE_TYPE_SWITCH_OUTPUT(
-          output->data.dtype, OType,
+          output->dtype(), OType,
           if (!is_fp8_dtype(output->data.dtype) ||
               is_delayed_tensor_scaling(output->scaling_mode)) {
...
@@ -942,8 +945,8 @@ void cast_gated(const Tensor &input, Tensor *output, cudaStream_t stream) {
             reinterpret_cast<OType *>(output->data.dptr),
             reinterpret_cast<const fp32 *>(output->scale.dptr),
             reinterpret_cast<fp32 *>(output->amax.dptr),
-            reinterpret_cast<fp32 *>(output->scale_inv.dptr), output->data.shape[0],
-            output->data.shape[1], {}, stream);
+            reinterpret_cast<fp32 *>(output->scale_inv.dptr), input.flat_first_dim(),
+            output->flat_last_dim(), {}, stream);
           } else {
             NVTE_ERROR("Not implemented scaling mode: " + to_string(output->scaling_mode) + ".");
           });  // NOLINT(*)
...
...
@@ -85,16 +85,16 @@ __all__ = ["LayerNormMLP"]

 def _get_act_func_supported_list(recipe: Optional[Recipe] = None):
     if recipe is None:
-        # bf16 (recipe is None): [tex.dbias_dgelu, tex.dbias_drelu, tex.dbias_dqgelu, tex.dbias_dsrelu]
+        # bf16 (recipe is None):
         return {
-            "gelu": (tex.gelu, tex.dgelu, tex.dbias_dgelu),
-            "relu": (tex.relu, tex.drelu, tex.dbias_drelu),
+            "gelu": (tex.gelu, tex.dgelu, None),
+            "relu": (tex.relu, tex.drelu, None),
             "geglu": (tex.geglu, tex.dgeglu, None),
             "reglu": (tex.reglu, tex.dreglu, None),
             "swiglu": (tex.swiglu, tex.dswiglu, None),
-            "qgelu": (tex.qgelu, tex.dqgelu, tex.dbias_dqgelu),
+            "qgelu": (tex.qgelu, tex.dqgelu, None),
             "qgeglu": (tex.qgeglu, tex.dqgeglu, None),
-            "srelu": (tex.srelu, tex.dsrelu, tex.dbias_dsrelu),
+            "srelu": (tex.srelu, tex.dsrelu, None),
         }
     if recipe.delayed() or recipe.mxfp8():
         # Delayed scaling, fusion supported list: [tex.dbias_dgelu, tex.dbias_drelu, tex.dbias_dqgelu, tex.dbias_dsrelu]
...
...
@@ -73,7 +73,6 @@ class _ActivationOperation(BasicOperation, metaclass=abc.ABCMeta):
         input_: torch.Tensor,
         prev_op_grad_input_quantizer: Optional[Quantizer],
         next_op_input_quantizer: Optional[Quantizer],
-        is_first_op: bool,
     ) -> torch.Tensor:

         # Compute dtype
...
@@ -95,14 +94,7 @@ class _ActivationOperation(BasicOperation, metaclass=abc.ABCMeta):
             quantizer = next_op_input_quantizer

         # Launch kernel
-        y = self._activation_forward_impl(
-            x.view((-1, x.size(-1))),
-            quantizer,
-        )
-
-        # Check output tensor
-        if len(y.size()) != x.dim():
-            y = y.view(list(x.shape[:-1]) + [-1])
+        y = self._activation_forward_impl(x, quantizer)

         # Quantize input to FP8 before caching if needed
         if self.cache_quantized_input:
...
@@ -114,7 +106,6 @@ class _ActivationOperation(BasicOperation, metaclass=abc.ABCMeta):
             ctx.save_for_backward(x)
             ctx.with_quantized_compute = with_quantized_compute
             ctx.dtype = dtype
-            ctx.is_first_op = is_first_op
             ctx.prev_op_grad_input_quantizer = prev_op_grad_input_quantizer

         return y
...
@@ -140,18 +131,9 @@ class _ActivationOperation(BasicOperation, metaclass=abc.ABCMeta):
             quantizer = ctx.prev_op_grad_input_quantizer

         # Launch kernel
-        dx = self._activation_backward_impl(
-            dy.view((-1, dy.size(-1))),
-            x.view((-1, x.size(-1))),
-            quantizer,
-        )
-
-        # Check grad input tensor
-        if dx.size() != x.size():
-            dx = dx.view(x.size())
+        dx = self._activation_backward_impl(dy, x, quantizer)

         # Clear input tensor if possible
-        if not ctx.is_first_op:
-            clear_tensor_data(x)
+        clear_tensor_data(x)

         return dx, ()
...
...
@@ -61,7 +61,6 @@ class AddInPlace(BasicOperation):
         basic_op_extra_inputs: list[tuple[torch.Tensor, ...]],
         prev_op_grad_input_quantizer: Optional[Quantizer],
         next_op_input_quantizer: Optional[Quantizer],
-        is_first_op: bool,
         basic_op_kwargs: list[dict[str, Any]],
     ) -> tuple[torch.Tensor, Iterable[Iterable[torch.Tensor]]]:
         output = basic_op_extra_inputs[0][0].detach()
...
@@ -79,4 +78,4 @@ class AddInPlace(BasicOperation):
         Iterable[Iterable[Optional[torch.Tensor]]],
         Iterable[Iterable[Optional[torch.Tensor]]],
     ]:
-        return grad_output, [], [(grad_output,)]
+        return grad_output, [()], [(grad_output,)]
...
@@ -42,7 +42,6 @@ class AllGather(BasicOperation):
         input_: torch.Tensor,
         prev_op_grad_input_quantizer: Optional[Quantizer],
         next_op_input_quantizer: Optional[Quantizer],
-        is_first_op: bool,
     ) -> torch.Tensor:
         out: torch.Tensor
         if self.process_group_size == 1:
...
...
@@ -44,7 +44,6 @@ class AllReduce(BasicOperation):
         input_: torch.Tensor,
         prev_op_grad_input_quantizer: Optional[Quantizer],
         next_op_input_quantizer: Optional[Quantizer],
-        is_first_op: bool,
     ) -> torch.Tensor:

         # Trivial case
...
...
@@ -392,7 +392,7 @@ class BasicLinear(BasicOperation):
         Bias tensor
     device: torch.device, default = default CUDA device
         Tensor device
-    dtype: torch.dtype, default = default dtype
+    dtype: torch.dtype, default = infer from out or weight
         Tensor datatype
     out: torch.Tensor, optional
         Output tensor
...
@@ -439,8 +439,14 @@ class BasicLinear(BasicOperation):

         # Check datatype
         if dtype is None:
-            dtype = weight.dtype if out is None else out.dtype
-        dtype = canonicalize_dtype(dtype)
+            if out is not None and isinstance(out, torch.Tensor):
+                dtype = out.dtype
+            elif weight is not None and isinstance(weight, torch.Tensor):
+                dtype = weight.dtype
+            else:
+                raise ValueError(
+                    "Could not infer dtype from weight or out, and dtype was not provided"
+                )
         if dtype not in (torch.float32, torch.float16, torch.bfloat16):
             raise ValueError(f"Supported dtypes are float32, float16, bfloat16 (got {dtype})")
         if out is not None and out.dtype != dtype:
...
@@ -890,7 +896,6 @@ class BasicLinear(BasicOperation):
         input_: torch.Tensor,
         prev_op_grad_input_quantizer: Optional[Quantizer],
         next_op_input_quantizer: Optional[Quantizer],
-        is_first_op: bool,
     ) -> torch.Tensor:

         # Check which grads are required
...
@@ -920,9 +925,10 @@ class BasicLinear(BasicOperation):
             weight_quantizer.set_usage(rowwise=True, columnwise=False)

         # Get autocast dtype if needed
-        dtype = None
         if torch.is_autocast_enabled():
             dtype = torch.get_autocast_dtype("cuda")
+        else:
+            dtype = self.weight.dtype

         # Linear forward
         output, x_local, w = BasicLinear._functional_forward(
...
@@ -950,7 +956,6 @@ class BasicLinear(BasicOperation):
             ctx.dtype = dtype
             ctx.input_requires_grad = input_requires_grad
             ctx.weight_requires_grad = weight_requires_grad
-            ctx.has_prev_op = not is_first_op

         return output
...
@@ -1001,7 +1006,6 @@ class BasicLinear(BasicOperation):
         )

         # Clear input tensor if possible
-        if ctx.has_prev_op:
-            clear_tensor_data(x_local)
+        clear_tensor_data(x_local)

         if accumulate_into_main_grad:
...
...
@@ -9,6 +9,7 @@ from typing import Optional

 import torch

+import transformer_engine_torch as tex
 from transformer_engine.pytorch.ops.op import (
     BasicOperation,
     OperationContext,
...
@@ -17,6 +18,7 @@ from ...utils import (
     canonicalize_device,
     canonicalize_dtype,
 )
+from ...fp8 import FP8GlobalStateManager
 from ...tensor import Quantizer
...
@@ -123,10 +125,23 @@ class Bias(BasicOperation):
         input_: torch.Tensor,
         prev_op_grad_input_quantizer: Optional[Quantizer],
         next_op_input_quantizer: Optional[Quantizer],
-        is_first_op: bool,
     ) -> torch.Tensor:
         x = input_
         b = self.bias.view([1] * (x.dim() - 1) + [self.local_size])
+
+        # Check if backward pass is needed
+        requires_grad = ctx.requires_grad
+
+        # Check if previous op quantizes its output's gradient
+        grad_input_quantizer = None
+        with_quantized_compute = FP8GlobalStateManager.is_fp8_enabled()
+        if with_quantized_compute:
+            grad_input_quantizer = prev_op_grad_input_quantizer
+
+        if requires_grad:
+            ctx.with_quantized_compute = with_quantized_compute
+            ctx.grad_input_quantizer = grad_input_quantizer
+
         return x + b

     def op_backward(
...
@@ -136,6 +151,10 @@ class Bias(BasicOperation):
     ) -> tuple[torch.Tensor, tuple[()]]:
         dy = grad_output
         if dy.dim() > 1:
-            db = dy.sum(tuple(range(dy.dim() - 1)))
+            quantizer = ctx.grad_input_quantizer
+            if ctx.with_quantized_compute and quantizer is not None:
+                db, dy = tex.bgrad_quantize(dy, quantizer)
+            else:
+                db = dy.sum(tuple(range(dy.dim() - 1)))
         else:
             db = dy
...
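
For reference, the dbias + quantize fusion above is numerically equivalent to this unfused sketch (assuming tex.bgrad_quantize returns the bias gradient and the quantized gradient in one pass, and writing the quantizer as a plain function):

    # Unfused equivalent of db, dy = tex.bgrad_quantize(dy, quantizer):
    db = dy.sum(tuple(range(dy.dim() - 1)))  # bias grad: reduce all leading dims
    dy = quantize(dy)                        # grad quantized for the previous op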
...
@@ -25,7 +25,6 @@ class Identity(BasicOperation):
         input_: torch.Tensor,
         prev_op_grad_input_quantizer: Optional[Quantizer],
         next_op_input_quantizer: Optional[Quantizer],
-        is_first_op: bool,
     ) -> torch.Tensor:
         return input_
...
...
@@ -76,7 +76,6 @@ class L2Normalization(BasicOperation):
         input_: torch.Tensor,
         prev_op_grad_input_quantizer: Optional[Quantizer],
         next_op_input_quantizer: Optional[Quantizer],
-        is_first_op: bool,
     ) -> torch.Tensor:
         # Use input directly - torch.compile can handle multi-dimensional tensors
         x = maybe_dequantize(input_)
...
@@ -97,7 +96,6 @@ class L2Normalization(BasicOperation):
         # Save state for backward pass
         if requires_grad:
             ctx.save_for_backward(x, rsqrt_norm)
-            ctx.has_prev_op = not is_first_op

         return y
...
@@ -116,7 +114,6 @@ class L2Normalization(BasicOperation):
         dx = l2normalization_backward_fused(dy, x, rsqrt_norm, self.eps)

         # Clear saved tensors if possible
-        if ctx.has_prev_op:
-            clear_tensor_data(x)
-            clear_tensor_data(rsqrt_norm)
+        clear_tensor_data(x)
+        clear_tensor_data(rsqrt_norm)
...
...
@@ -178,7 +178,6 @@ class LayerNorm(BasicOperation):
         input_: torch.Tensor,
         prev_op_grad_input_quantizer: Optional[Quantizer],
         next_op_input_quantizer: Optional[Quantizer],
-        is_first_op: bool,
     ) -> torch.Tensor:

         # Check tensor dims
...
@@ -225,7 +224,6 @@ class LayerNorm(BasicOperation):
         if requires_grad:
             ctx.save_for_backward(x, means, rstdevs)
             ctx.dtype = dtype
-            ctx.has_prev_op = not is_first_op

         # Reshape output tensor
         out = y.view(input_dims)
...
@@ -261,7 +259,6 @@ class LayerNorm(BasicOperation):
         )

         # Clear saved tensors if possible
-        if ctx.has_prev_op:
-            clear_tensor_data(x)
-            clear_tensor_data(means)
-            clear_tensor_data(rstdevs)
+        clear_tensor_data(x)
+        clear_tensor_data(means)
+        clear_tensor_data(rstdevs)
...
...
@@ -61,7 +61,6 @@ class MakeExtraOutput(BasicOperation):
         basic_op_extra_inputs: list[tuple[torch.Tensor, ...]],
         prev_op_grad_input_quantizer: Optional[Quantizer],
         next_op_input_quantizer: Optional[Quantizer],
-        is_first_op: bool,
         basic_op_kwargs: list[dict[str, Any]],
     ) -> tuple[torch.Tensor, Iterable[Iterable[torch.Tensor]]]:
         return input_, [(input_,)]
...
@@ -79,4 +78,4 @@ class MakeExtraOutput(BasicOperation):
     ]:
         grad_input = basic_op_grad_extra_outputs[0][0]
         grad_input += grad_output
-        return grad_input, [], [()]
+        return grad_input, [()], [()]
...
@@ -52,7 +52,6 @@ class Quantize(BasicOperation):
         input_: torch.Tensor,
         prev_op_grad_input_quantizer: Optional[Quantizer],
         next_op_input_quantizer: Optional[Quantizer],
-        is_first_op: bool,
     ) -> torch.Tensor:

         # Check if FP8 is enabled
...
...
@@ -42,7 +42,6 @@ class ReduceScatter(BasicOperation):
         input_: torch.Tensor,
         prev_op_grad_input_quantizer: Optional[Quantizer],
         next_op_input_quantizer: Optional[Quantizer],
-        is_first_op: bool,
     ) -> torch.Tensor:

         # Trivial case
...
...
@@ -40,7 +40,6 @@ class Reshape(BasicOperation):
         input_: torch.Tensor,
         prev_op_grad_input_quantizer: Optional[Quantizer],
         next_op_input_quantizer: Optional[Quantizer],
-        is_first_op: bool,
     ) -> torch.Tensor:
         ctx.input_shape = input_.size()
         return input_.reshape(*self._shape)
...
...
@@ -161,7 +161,6 @@ class RMSNorm(BasicOperation):
         input_: torch.Tensor,
         prev_op_grad_input_quantizer: Optional[Quantizer],
         next_op_input_quantizer: Optional[Quantizer],
-        is_first_op: bool,
     ) -> torch.Tensor:

         # Check tensor dims
...
@@ -206,7 +205,6 @@ class RMSNorm(BasicOperation):
         if requires_grad:
             ctx.save_for_backward(x, rstdevs)
             ctx.dtype = dtype
-            ctx.has_prev_op = not is_first_op

         # Reshape output tensor
         out = y.view(input_dims)
...
@@ -241,7 +239,6 @@ class RMSNorm(BasicOperation):
         )

         # Clear saved tensors if possible
-        if ctx.has_prev_op:
-            clear_tensor_data(x)
-            clear_tensor_data(rstdevs)
+        clear_tensor_data(x)
+        clear_tensor_data(rstdevs)
...
...
@@ -4,6 +4,10 @@

 """Compound tensor operation supported by the operation fuser."""

+from .backward_bias_activation import (
+    BackwardBiasActivation,
+    fuse_backward_bias_activation,
+)
 from .backward_linear_add import (
     BackwardLinearAdd,
     fuse_backward_linear_add,
...
# Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# See LICENSE for license information.

"""Fused backward dbias + dact + quantize."""

from __future__ import annotations
from typing import Optional

import torch

import transformer_engine_torch as tex
from transformer_engine.pytorch.fp8 import Recipe
from transformer_engine.pytorch.ops.basic import Bias
from transformer_engine.pytorch.ops.basic.activation import (
    _ActivationOperation,
    GELU,
    ReLU,
)
from transformer_engine.pytorch.ops.op import (
    FusedOperation,
    FusibleOperation,
    OperationContext,
)
from ...utils import clear_tensor_data
from .._common import maybe_dequantize

_fused_activations = {GELU: tex.dbias_dgelu, ReLU: tex.dbias_drelu}
_fusible_activations = tuple(_fused_activations.keys())


class BackwardBiasActivation(FusedOperation):
    """Fused backward dbias + dact + quantize

    Uses the next operation's input quantizer.

    """

    def __init__(self, *, bias: Bias, activation: _ActivationOperation):
        super().__init__((bias, activation))
        self._fused_function = _fused_activations[type(activation)]

    def fuser_backward(
        self,
        basic_op_ctxs: list[OperationContext],
        grad_output: torch.Tensor,
        *,
        basic_op_grad_extra_outputs: list[tuple[torch.Tensor, ...]],
    ) -> tuple[
        torch.Tensor,
        list[tuple[Optional[torch.Tensor], ...]],
        list[tuple[()]],
    ]:

        # Get basic operation contexts
        activation_op_ctx = basic_op_ctxs[0]
        bias_op_ctx = basic_op_ctxs[1]

        # Saved tensors from forward pass
        (act_input,) = activation_op_ctx.saved_tensors

        # Check activation input tensor
        act_input = maybe_dequantize(act_input.contiguous(), activation_op_ctx.dtype)

        # Check grad output tensor
        dy = maybe_dequantize(grad_output.contiguous(), act_input.dtype)

        # Get previous op quantizer
        if not bias_op_ctx.with_quantized_compute:
            raise RuntimeError(
                "BackwardBiasActivation requires quantized compute, "
                "but Bias context has it disabled"
            )
        quantizer = bias_op_ctx.grad_input_quantizer
        if quantizer is None:
            raise RuntimeError(
                "BackwardBiasActivation requires previous op's grad output quantizer, "
                "but Bias context has no quantizer"
            )

        # Launch kernel
        db, dx = self._fused_function(dy, act_input, quantizer)

        # Clear activation input tensor
        clear_tensor_data(act_input)

        return dx, [(), (db,)], [(), ()]


def fuse_backward_bias_activation(
    ops: list[tuple[FusibleOperation, list[int]]],
    recipe: Optional[Recipe],
) -> list[tuple[FusibleOperation, list[int]]]:
    """Fused backward dbias + dact + quantize

    Parameters
    ----------
    ops: list of tuples
        Backward pass operations and the indices of the corresponding
        basic operations.
    recipe: Recipe, optional
        Used quantization recipe

    Returns
    -------
    ops: list of tuples
        Updated backward pass operations

    """

    # Check if recipe supports bias activation fusion
    if recipe is None or not (recipe.delayed() or recipe.mxfp8()):
        return ops

    # Scan through ops, fusing if possible
    out = []
    window = []
    while len(ops) >= 3:
        out.extend(window)

        # Check if first op is a supported activation
        window, ops = ops[:1], ops[1:]
        op, _ = window[0]
        if not isinstance(op, _fusible_activations):
            continue

        # Check if second op is bias
        op, _ = ops[0]
        if not isinstance(op, Bias):
            continue

        # Check if third op has a grad input quantizer
        op, _ = ops[1]
        if not op.num_quantizers("backward") > 0:
            continue
        window.extend(ops[:1])
        ops = ops[1:]

        # Replace window with fused op
        op = BackwardBiasActivation(
            activation=window[0][0],
            bias=window[1][0],
        )
        basic_op_idxs = [basic_op_idxs[0] for _, basic_op_idxs in window]
        window = [(op, basic_op_idxs)]

    # Return list of ops
    out.extend(window)
    out.extend(ops)
    return out
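
The fused kernel called by BackwardBiasActivation (tex.dbias_dgelu / tex.dbias_drelu) computes the activation backward, the bias gradient, and the quantized gradient together. A hedged unfused sketch of its semantics, where dact and quantize are hypothetical stand-ins for the activation backward and the quantizer:

    import torch

    def dbias_dact_quantize_ref(dy, act_input, dact, quantize):
        # Sketch only: the real kernel fuses these three steps into one pass.
        dx = dact(dy, act_input)                     # activation backward
        db = dx.sum(dim=tuple(range(dx.dim() - 1)))  # bias grad over leading dims
        return db, quantize(dx)                      # grad quantized for previous op

    # Example with ReLU: the activation backward masks dy where the
    # pre-activation input was negative.
    db, dx = dbias_dact_quantize_ref(
        torch.randn(8, 32),
        torch.randn(8, 32),
        dact=lambda dy, x: dy * (x > 0),
        quantize=lambda t: t,  # identity stand-in for the FP8 quantizer
    )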
...
@@ -96,7 +96,6 @@ class BackwardLinearAdd(FusedOperation):
         grad_weight = None

         # Clear input tensor if possible
-        if linear_op_ctx.has_prev_op:
-            clear_tensor_data(x_local)
+        clear_tensor_data(x_local)

         return grad_input, [(grad_weight,), ()], [(), ()]
...
@@ -110,13 +109,13 @@ def fuse_backward_linear_add(

     Parameters
     ----------
     ops: list of tuples
-        Forward pass operations and the indices of the corresponding
+        Backward pass operations and the indices of the corresponding
         basic operations.

     Returns
     -------
     ops: list of tuples
-        Updated forward pass operations
+        Updated backward pass operations

     """
...