Unverified commit 20b0473c authored by Tim Moon, committed by GitHub

[PyTorch] Activation operations (#1164)



* Add activation ops
Signed-off-by: Tim Moon <tmoon@nvidia.com>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci



* Fix lint warnings
Signed-off-by: Tim Moon <tmoon@nvidia.com>

* Fix linter warning
Signed-off-by: Tim Moon <4406448+timmoon10@users.noreply.github.com>

* Update to use QuantizedTensor
Signed-off-by: Tim Moon <tmoon@nvidia.com>

* Respect PyTorch autograd dtype
Signed-off-by: Tim Moon <tmoon@nvidia.com>

* Rename CastFloat8 op to Quantize
Signed-off-by: Tim Moon <tmoon@nvidia.com>

* Add support for fused dSwiGLU-cast-transpose
Signed-off-by: Tim Moon <tmoon@nvidia.com>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci



---------
Signed-off-by: Tim Moon <tmoon@nvidia.com>
Signed-off-by: Tim Moon <4406448+timmoon10@users.noreply.github.com>
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
parent d1488e73
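As a rough usage sketch of the new API (based on the tests added below; the module aliases, shapes, and dtypes are illustrative rather than prescriptive):

import torch
import transformer_engine.pytorch as te
import transformer_engine.pytorch.ops as te_ops

# Gated activations halve the last dimension, so the input is twice as wide as the output.
x = torch.randn(16, 32, dtype=torch.bfloat16, device="cuda", requires_grad=True)
dy = torch.randn(16, 16, dtype=torch.bfloat16, device="cuda")

# Compose the new activation op with an FP8 cast of its output.
forward = te_ops.Sequential(
    te_ops.SwiGLU(),
    te_ops.Quantize(forward=True, backward=False),
)

with te.fp8_autocast(enabled=True):
    y = forward(x)  # Float8Tensor output when FP8 is enabled
y.backward(dy)      # dSwiGLU computed in the op's backward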
@@ -1362,6 +1362,166 @@ class TestBasicOps:
torch.testing.assert_close(y2_test, y2_ref, rtol=0, atol=0)
torch.testing.assert_close(dx_test, x_ref.grad, **tols)
@pytest.mark.parametrize("activation", ("relu", "gelu", "geglu", "reglu", "swiglu"))
@pytest.mark.parametrize("out_shape", ((37,), (2, 13), (4, 1, 16)))
@pytest.mark.parametrize("dtype", _dtypes)
@pytest.mark.parametrize("fp8_input", (False, True))
@pytest.mark.parametrize("fp8_output", (False, True))
def test_activation(
self,
*,
activation: str,
out_shape: Iterable[int],
dtype: torch.dtype,
device: torch.device = "cuda",
fp8_input: bool,
fp8_output: bool,
) -> None:
"""Activation functions"""
# Tensor dimensions
in_shape = list(out_shape)
if activation in ("geglu", "reglu", "swiglu"):
in_shape[-1] *= 2
# Skip invalid configurations
if fp8_input or fp8_output:
if not fp8_available:
pytest.skip(reason_for_no_fp8)
if torch.device(device).type != "cuda":
pytest.skip("FP8 is only supported on CUDA devices")
# Random data
x_ref, x_test = make_reference_and_test_tensors(
in_shape,
test_dtype=dtype,
test_device=device,
test_is_fp8=fp8_input,
)
dy_ref, dy_test = make_reference_and_test_tensors(
out_shape,
test_dtype=dtype,
test_device=device,
requires_grad=False,
)
# Plain PyTorch implementation
y_ref: torch.Tensor
if activation == "gelu":
y_ref = torch.nn.functional.gelu(x_ref, approximate="tanh")
elif activation == "relu":
y_ref = torch.nn.functional.relu(x_ref)
elif activation == "geglu":
x1, x2 = x_ref.chunk(2, dim=-1)
y_ref = torch.nn.functional.gelu(x1, approximate="tanh") * x2
elif activation == "reglu":
x1, x2 = x_ref.chunk(2, dim=-1)
y_ref = torch.nn.functional.relu(x1) * x2
elif activation == "swiglu":
x1, x2 = x_ref.chunk(2, dim=-1)
y_ref = torch.nn.functional.silu(x1) * x2
else:
raise ValueError(f"Unexpected activation function ({activation})")
y_ref.backward(dy_ref)
# Implementation with fusible operation
make_op = dict(
gelu=te_ops.GELU,
relu=te_ops.ReLU,
geglu=te_ops.GEGLU,
reglu=te_ops.ReGLU,
swiglu=te_ops.SwiGLU,
)[activation]
forward = te_ops.Sequential(
make_op(),
te_ops.Quantize(forward=fp8_output, backward=False),
)
with te.fp8_autocast(enabled=fp8_output):
y_test = forward(x_test)
y_test.backward(dy_test)
# Expected numerical error
tols = dtype_tols(dtype)
if fp8_output:
tols = dtype_tols(tex.DType.kFloat8E4M3)
# Check results
y_test = y_test.to(dtype=torch.float64, device="cpu")
dx_test = x_test.grad.to(dtype=torch.float64, device="cpu")
torch.testing.assert_close(y_test, y_ref, **tols)
torch.testing.assert_close(dx_test, x_ref.grad, **tols)
@pytest.mark.parametrize("dtype", _dtypes)
@pytest.mark.parametrize("fp8_output", (False, True))
@pytest.mark.parametrize("fp8_grad_input", (False, True))
def test_swiglu(
self,
*,
out_shape: Iterable[int] = (16, 16),
dtype: torch.dtype,
device: torch.device = "cuda",
fp8_output: bool,
fp8_grad_input: bool,
):
# Tensor dimensions
in_shape = list(out_shape)
in_shape[-1] *= 2
# Skip invalid configurations
fp8 = fp8_output or fp8_grad_input
if fp8:
if not fp8_available:
pytest.skip(reason_for_no_fp8)
if torch.device(device).type != "cuda":
pytest.skip("FP8 is only supported on CUDA devices")
# FP8 recipe
fp8_recipe = None
if fp8_grad_input:
fp8_recipe = transformer_engine.common.recipe.DelayedScaling(
fp8_format=transformer_engine.common.recipe.Format.E4M3,
)
# Random data
x_ref, x_test = make_reference_and_test_tensors(
in_shape,
test_dtype=dtype,
test_device=device,
)
dy_ref, dy_test = make_reference_and_test_tensors(
out_shape,
test_dtype=dtype,
test_device=device,
requires_grad=False,
)
# Plain PyTorch implementation
x1, x2 = x_ref.chunk(2, dim=-1)
y_ref = torch.nn.functional.silu(x1) * x2
y_ref.backward(dy_ref)
# Implementation with fusible operation
forward = te_ops.Sequential(
te_ops.Quantize(forward=False, backward=fp8_grad_input),
te_ops.SwiGLU(),
te_ops.Quantize(forward=fp8_output, backward=False),
)
with te.fp8_autocast(enabled=fp8, fp8_recipe=fp8_recipe):
y_test = forward(x_test)
y_test.backward(dy_test)
# Expected numerical error
tols = dtype_tols(dtype)
if fp8:
tols = dtype_tols(tex.DType.kFloat8E4M3)
# Check results
y_test = y_test.to(dtype=torch.float64, device="cpu")
dx_test = x_test.grad.to(dtype=torch.float64, device="cpu")
torch.testing.assert_close(y_test, y_ref, **tols)
torch.testing.assert_close(dx_test, x_ref.grad, **tols)
class TestFusedOps:
"""Tests for fused operations"""
...
@@ -16,6 +16,7 @@ __all__ = [
"fp8_cast_transpose_fused",
"fp8_cast_transpose_bgrad_fused",
"fp8_cast_transpose_bgrad_dgelu_fused",
"fp8_dswiglu_cast_transpose_fused",
"fp8_multi_cast_transpose_fused",
"fp8_transpose_bgrad_fused",
]
@@ -168,6 +169,44 @@ def fp8_cast_transpose_bgrad_dgelu_fused(
)
def fp8_dswiglu_cast_transpose_fused(
grad_output: torch.Tensor,
inp: torch.Tensor,
*,
grad_input: torch.Tensor,
grad_input_transpose: torch.Tensor,
otype: tex.DType,
fp8_meta: Optional[tex.FP8TensorMeta] = None,
fp8_meta_index: Union[tex.FP8FwdTensors, tex.FP8BwdTensors, None] = None,
scale: Optional[torch.Tensor] = None,
amax: Optional[torch.Tensor] = None,
scale_inv: Optional[torch.Tensor] = None,
) -> None:
"""Fused SwiGLU backward + FP8 cast + FP8 transpose"""
# Get FP8 scaling factors
fp8_scales, fp8_scales_offsets = canonicalize_fp8_scales(
scale=scale,
amax=amax,
scale_inv=scale_inv,
fp8_meta=fp8_meta,
fp8_meta_index=fp8_meta_index,
)
# Launch kernel
return tex.fused_dswiglu_cast_transpose(
grad_output,
inp,
grad_input,
grad_input_transpose,
fp8_scales["scale"],
fp8_scales["amax"],
fp8_scales["scale_inv"],
otype,
**fp8_scales_offsets,
)
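For reference, a minimal sketch of how this wrapper is meant to be called, mirroring the SwiGLU backward later in this commit (the tensor names, shapes, and scale handling here are illustrative only, not part of the diff):

import torch
import transformer_engine_torch as tex
from transformer_engine.pytorch.cpp_extensions import fp8_dswiglu_cast_transpose_fused

M, N = 16, 16
x = torch.randn(M, 2 * N, dtype=torch.bfloat16, device="cuda")  # SwiGLU forward input
dy = torch.randn(M, N, dtype=torch.bfloat16, device="cuda")     # incoming gradient
dx = torch.empty(M, 2 * N, dtype=torch.uint8, device="cuda")    # FP8 grad input buffer
dx_t = torch.empty(2 * N, M, dtype=torch.uint8, device="cuda")  # FP8 grad input, transposed
scale = torch.ones(1, dtype=torch.float32, device="cuda")
amax = torch.zeros(1, dtype=torch.float32, device="cuda")
scale_inv = torch.empty(1, dtype=torch.float32, device="cuda")

fp8_dswiglu_cast_transpose_fused(
    dy,
    x,
    grad_input=dx,
    grad_input_transpose=dx_t,
    otype=tex.DType.kFloat8E4M3,  # illustrative FP8 output dtype
    scale=scale,                  # raw scaling factors instead of an fp8_meta struct
    amax=amax,
    scale_inv=scale_inv,
)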
def fp8_multi_cast_transpose_fused(
input_list: List[torch.Tensor],
fp8_meta_tensor: tex.FP8TensorMeta,
...
@@ -210,6 +210,12 @@ std::vector<at::Tensor> fused_cast_transpose_bgrad_dgelu(at::Tensor grad_output,
int scale_offset = 0, int amax_offset = 0,
int scale_inv_offset = 0);
void fused_dswiglu_cast_transpose(at::Tensor grad_output, at::Tensor input, at::Tensor grad_input,
at::Tensor grad_input_transpose, at::Tensor scale,
at::Tensor amax, at::Tensor scale_inv,
transformer_engine::DType otype, int scale_offset = 0,
int amax_offset = 0, int scale_inv_offset = 0);
void fused_multi_cast_transpose(std::vector<at::Tensor> input_list,
std::vector<at::Tensor> scale_list,
std::vector<at::Tensor> cast_output_list,
...
@@ -91,6 +91,12 @@ PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
py::arg("grad_output"), py::arg("gelu_input"), py::arg("scale"), py::arg("amax"),
py::arg("scale_inv"), py::arg("otype"), py::arg("scale_offset") = 0,
py::arg("amax_offset") = 0, py::arg("scale_inv_offset") = 0);
m.def("fused_dswiglu_cast_transpose", &fused_dswiglu_cast_transpose,
"Fused SwiGLU backward + FP8 cast + FP8 transpose",
py::call_guard<py::gil_scoped_release>(), py::arg("grad_output"), py::arg("input"),
py::arg("grad_input"), py::arg("grad_input_transpose"), py::arg("scale"), py::arg("amax"),
py::arg("scale_inv"), py::arg("otype"), py::arg("scale_offset") = 0,
py::arg("amax_offset") = 0, py::arg("scale_inv_offset") = 0);
m.def("fused_multi_cast_transpose", &fused_multi_cast_transpose, m.def("fused_multi_cast_transpose", &fused_multi_cast_transpose,
"Fused Multi-tensor Cast + Transpose", py::call_guard<py::gil_scoped_release>()); "Fused Multi-tensor Cast + Transpose", py::call_guard<py::gil_scoped_release>());
m.def("fused_multi_cast_transpose_alloc", &fused_multi_cast_transpose_alloc, m.def("fused_multi_cast_transpose_alloc", &fused_multi_cast_transpose_alloc,
......
@@ -196,6 +196,75 @@ std::vector<at::Tensor> fused_cast_transpose_bgrad_dgelu(at::Tensor grad_output,
return {grad_bias, dgelu, dgelu_transpose};
}
void fused_dswiglu_cast_transpose(at::Tensor grad_output, at::Tensor input, at::Tensor grad_input,
at::Tensor grad_input_transpose, at::Tensor scale,
at::Tensor amax, at::Tensor scale_inv,
transformer_engine::DType otype, int scale_offset,
int amax_offset, int scale_inv_offset) {
using namespace transformer_engine;
// Tensor dimensions
auto outer_dim = [](const at::Tensor& tensor) -> size_t {
return tensor.numel() / tensor.size(-1);
};
const auto M = outer_dim(grad_output);
const auto N = static_cast<size_t>(grad_output.size(-1));
// Check tensor dims
NVTE_CHECK(grad_output.dim() == 2, "Expected grad output tensor to have 2 dims, but found ",
grad_output.dim());
NVTE_CHECK(input.dim() == 2, "Expected input tensor to have 2 dims, but found ", input.dim());
NVTE_CHECK(outer_dim(input) == M, "Expected input tensor to have outer dimension of ", M,
", but found ", outer_dim(input));
NVTE_CHECK(input.size(-1) == 2 * N, "Expected input tensor to have inner dimension of ", 2 * N,
", but found ", input.size(-1));
NVTE_CHECK(grad_input.dim() == 2, "Expected grad input tensor to have 2 dims, but found ",
grad_input.dim());
NVTE_CHECK(outer_dim(grad_input) == M, "Expected grad input tensor to have outer dimension of ",
M, ", but found ", outer_dim(grad_input));
NVTE_CHECK(grad_input.size(-1) == 2 * N, "Expected grad input tensor to have inner dimension of ",
2 * N, ", but found ", grad_input.size(-1));
NVTE_CHECK(grad_input_transpose.dim() == 2,
"Expected grad input transpose tensor to have 2 dims, but found ",
grad_input_transpose.dim());
NVTE_CHECK(grad_input_transpose.size(0) == 2 * N,
"Expected grad input transpose tensor to have outer dimension of ", 2 * N,
", but found ", grad_input_transpose.size(0));
NVTE_CHECK(grad_input_transpose.size(1) == M,
"Expected grad input transpose tensor to have inner dimension of ", M, ", but found ",
grad_input_transpose.size(1));
// Check tensor format
NVTE_CHECK(grad_output.is_contiguous(), "Expected grad output tensor to be contiguous");
NVTE_CHECK(input.is_contiguous(), "Expected input tensor to be contiguous");
NVTE_CHECK(grad_input.is_contiguous(), "Expected grad input tensor to be contiguous");
NVTE_CHECK(grad_input_transpose.is_contiguous(),
"Expected grad input transpose tensor to be contiguous");
NVTE_CHECK(grad_output.scalar_type() == input.scalar_type(),
"Expected grad output tensor and input tensor to have same dtype");
NVTE_CHECK(grad_input.scalar_type() == at::ScalarType::Byte,
"Expected grad input tensor to be uint8 buffer");
NVTE_CHECK(grad_input_transpose.scalar_type() == at::ScalarType::Byte,
"Expected grad input transpose tensor to be uint8 buffer");
// Get pointers for FP8 scale, amax, scale-inverse
void* scale_dptr = getDataPtr(scale, scale_offset);
void* amax_dptr = getDataPtr(amax, amax_offset);
void* scale_inv_dptr = getDataPtr(scale_inv, scale_inv_offset);
// Construct Transformer Engine tensors
auto dy_cu = makeTransformerEngineTensor(grad_output);
auto x_cu = makeTransformerEngineTensor(input);
auto dx_cu = makeTransformerEngineTensor(grad_input.data_ptr(), {M, 2 * N}, otype, amax_dptr,
scale_dptr, scale_inv_dptr);
auto dx_t_cu = makeTransformerEngineTensor(grad_input_transpose.data_ptr(), {2 * N, M}, otype,
amax_dptr, scale_dptr, scale_inv_dptr);
// Launch kernel
nvte_dswiglu_cast_transpose(dy_cu.data(), x_cu.data(), dx_cu.data(), dx_t_cu.data(),
at::cuda::getCurrentCUDAStream());
}
void fused_multi_cast_transpose_base(std::vector<at::Tensor> input_list,
std::vector<void*> scale_dptr_list,
std::vector<at::Tensor> cast_output_list,
...
@@ -4,6 +4,7 @@
"""Single tensor operations supported by the operation fuser."""
from .activation import GELU, ReLU, GEGLU, ReGLU, SwiGLU
from .add_in_place import AddInPlace
from .all_gather import AllGather
from .all_reduce import AllReduce
...
# Copyright (c) 2022-2024, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# See LICENSE for license information.
"""Fusible operations for activation functions."""
from __future__ import annotations
import abc
from typing import Optional
import torch
import transformer_engine_torch
from ...constants import TE_DType
from ...cpp_extensions import (
geglu as tex_geglu,
gelu as tex_gelu,
reglu as tex_reglu,
relu as tex_relu,
swiglu as tex_swiglu,
fp8_dswiglu_cast_transpose_fused,
)
from ...fp8 import FP8GlobalStateManager, get_fp8_te_dtype
from ...tensor import Float8Tensor, QuantizedTensor
from ...utils import clear_tensor_data, devices_match
from ..op import BasicOperation, OperationContext
class _ActivationOperation(BasicOperation, metaclass=abc.ABCMeta):
r"""Apply activation function
Activation functions are either element-wise unary functions or
variants of the gated linear unit (GLU). Recall that GLU is
computed by splitting the input tensor into chunks :math:`a` and
:math:`b` along the last dimension and computing
.. math::
\text{GLU}(a,b) = \sigma(a) * b
.. warning::
Transformer Engine gated activations and PyTorch's GLU
activation follow opposite conventions for :math:`a` and
:math:`b`. Transformer Engine applies the gating function to
the first half of the input tensor, while PyTorch applies it to
the second half.
"""
@abc.abstractmethod
def _activation_forward_impl(self, *args, **kwargs) -> torch.Tensor:
"""Forward implementation
Implementation from transformer_engine.pytorch.cpp_extensions.
"""
@abc.abstractmethod
def _activation_backward_impl(self, *args, **kwargs) -> torch.Tensor:
"""Backward implementation
Implementation from transformer_engine_torch.
"""
def op_forward(
self,
ctx: OperationContext,
input_: torch.Tensor,
prev_op: Optional[BasicOperation] = None,
next_op: Optional[BasicOperation] = None,
) -> torch.Tensor:
# Compute dtype
dtype: torch.dtype
if torch.is_autocast_enabled():
dtype = torch.get_autocast_dtype("cuda")
else:
dtype = input_.dtype
if dtype not in (torch.float32, torch.float16, torch.bfloat16):
raise RuntimeError(f"Unsupported dtype ({dtype})")
# Check input tensor
x = input_
if isinstance(x, QuantizedTensor):
x = x.dequantize()
if x.device.type != "cuda":
x = x.cuda()
if x.dtype != dtype:
x = x.to(dtype=dtype)
if not x.is_contiguous():
x = x.contiguous()
# Check if FP8 is enabled
fp8_enabled = FP8GlobalStateManager.is_fp8_enabled()
with_fp8_output = False
output_fp8_meta = None
output_dtype = TE_DType[dtype]
output_fp8_scale_inv = None
if fp8_enabled and next_op is not None and next_op.num_fp8_scales("input") > 0:
with_fp8_output = True
fp8_meta = next_op.get_fp8_meta("input")
fp8_meta_key = FP8GlobalStateManager.get_meta_tensor_key(forward=True)
output_fp8_meta = fp8_meta[fp8_meta_key]
output_dtype = get_fp8_te_dtype(fp8_meta["recipe"], fprop_tensor=True)
output_fp8_scale_inv = torch.empty([1], dtype=torch.float32, device=x.device)
# Launch kernel
y = self._activation_forward_impl(
x,
output_fp8_meta,
0,
output_dtype,
scale_inv=output_fp8_scale_inv,
)
# Check output tensor
if y.dim() != x.dim():
y = y.reshape(list(x.shape[:-1]) + [-1])
if with_fp8_output:
y = Float8Tensor(
data=y,
fp8_meta=output_fp8_meta,
fp8_meta_forward=True,
fp8_meta_index=0,
fp8_dtype=output_dtype,
fp8_scale_inv=output_fp8_scale_inv,
dtype=dtype,
)
# Save state for backward pass
ctx.save_for_backward(x)
ctx.fp8_enabled = fp8_enabled
ctx.prev_op = prev_op
return y
def op_backward(
self,
ctx: OperationContext,
grad_output: torch.Tensor,
) -> tuple[torch.Tensor, tuple[()]]:
# Saved tensors from forward pass
(x,) = ctx.saved_tensors
# Check grad output tensor
dy = grad_output
if isinstance(dy, QuantizedTensor):
dy = dy.dequantize()
if not devices_match(dy.device, x.device) or dy.dtype != x.dtype:
dy = dy.to(device=x.device, dtype=x.dtype)
if not dy.is_contiguous():
dy = dy.contiguous()
# Launch kernel
dx = self._activation_backward_impl(dy, x, TE_DType[x.dtype])
# Check grad input tensor
if dx.size() != x.size():
dx = dx.reshape(x.size())
# Clear input tensor if possible
if ctx.prev_op is not None:
clear_tensor_data(x)
return dx, ()
class GELU(_ActivationOperation):
r"""Gaussian Error Linear Unit
This computes the "tanh" approximation to GELU:
.. math::
\text{GELU}(x) \approx \frac{x}{2} \left( 1 + \tanh\left( 0.797x+0.036 x^3 \right) \right)
See `Gaussian Error Linear Units (GELUs) <https://arxiv.org/abs/1606.08415>`__.
"""
def _activation_forward_impl(self, *args, **kwargs) -> torch.Tensor:
return tex_gelu(*args, **kwargs)
def _activation_backward_impl(self, *args, **kwargs) -> torch.Tensor:
return transformer_engine_torch.dgelu(*args, **kwargs)
class ReLU(_ActivationOperation):
r"""Rectified linear unit
.. math::
\text{ReLU}(x) = \max(x,0)
"""
def _activation_forward_impl(self, *args, **kwargs) -> torch.Tensor:
return tex_relu(*args, **kwargs)
def _activation_backward_impl(self, *args, **kwargs) -> torch.Tensor:
return transformer_engine_torch.drelu(*args, **kwargs)
class GEGLU(_ActivationOperation):
r"""Gaussian error gated linear unit
The input tensor is split into chunks :math:`a` and :math:`b`
along the last dimension and the following is computed:
.. math::
\text{GEGLU}(a,b) = \text{GELU}(a) * b
where
.. math::
\text{GELU}(x) \approx \frac{x}{2} \left( 1 + \tanh\left( 0.797x+0.036 x^3 \right) \right)
.. warning::
Transformer Engine's gated activations and PyTorch's GLU
activation follow opposite conventions for :math:`a` and
:math:`b`. Transformer Engine applies the gating function to
the first half of the input tensor, while PyTorch applies it to
the second half.
See `GLU Variants Improve Transformer <https://arxiv.org/abs/2002.05202>`__.
"""
def _activation_forward_impl(self, *args, **kwargs) -> torch.Tensor:
return tex_geglu(*args, **kwargs)
def _activation_backward_impl(self, *args, **kwargs) -> torch.Tensor:
return transformer_engine_torch.dgeglu(*args, **kwargs)
class ReGLU(_ActivationOperation):
r"""Rectified gated linear unit
The input tensor is split into chunks :math:`a` and :math:`b`
along the last dimension and the following is computed:
.. math::
\text{ReGLU}(a,b) = \max(a,0) * b
.. warning::
Transformer Engine's gated activations and PyTorch's GLU
activation follow opposite conventions for :math:`a` and
:math:`b`. Transformer Engine applies the gating function to
the first half of the input tensor, while PyTorch applies it to
the second half.
See `GLU Variants Improve Transformer <https://arxiv.org/abs/2002.05202>`__.
"""
def _activation_forward_impl(self, *args, **kwargs) -> torch.Tensor:
return tex_reglu(*args, **kwargs)
def _activation_backward_impl(self, *args, **kwargs) -> torch.Tensor:
return transformer_engine_torch.dreglu(*args, **kwargs)
class SwiGLU(_ActivationOperation):
r"""Swish gated linear unit
The input tensor is split into chunks :math:`a` and :math:`b`
along the last dimension and the following is computed:
.. math::
\text{SwiGLU}(a,b) = \text{SiLU}(a) * b
where
.. math::
\text{SiLU}(x) = x \sigma(x) = \frac{x}{1+\exp(-x)}
.. warning::
Transformer Engine's gated activations and PyTorch's GLU
activation follow opposite conventions for :math:`a` and
:math:`b`. Transformer Engine applies the gating function to
the first half of the input tensor, while PyTorch applies it to
the second half.
The Sigmoid Linear Unit (SiLU) gating function is also known as
the swish function. See
`GLU Variants Improve Transformer <https://arxiv.org/abs/2002.05202>`__
and `Gaussian Error Linear Units (GELUs) <https://arxiv.org/abs/1606.08415>`__.
"""
def _activation_forward_impl(self, *args, **kwargs) -> torch.Tensor:
return tex_swiglu(*args, **kwargs)
def _activation_backward_impl(self, *args, **kwargs) -> torch.Tensor:
return transformer_engine_torch.dswiglu(*args, **kwargs)
def op_backward(
self,
ctx: OperationContext,
grad_output: torch.Tensor,
) -> tuple[torch.Tensor, tuple[()]]:
# Saved tensors from forward pass
(x,) = ctx.saved_tensors
# Tensor attributes
dtype = x.dtype
device = x.device
# Check grad output tensor
dy = grad_output
if isinstance(dy, QuantizedTensor):
dy = dy.dequantize()
if not devices_match(dy.device, device) or dy.dtype != dtype:
dy = dy.to(device=device, dtype=dtype)
if not dy.is_contiguous():
dy = dy.contiguous()
# Check if FP8 is enabled
with_fp8_grad_input = False
grad_input_fp8_meta = None
grad_input_dtype = TE_DType[dtype]
grad_input_fp8_scale_inv = None
if (
ctx.fp8_enabled
and ctx.prev_op is not None
and ctx.prev_op.num_fp8_scales("grad_output") > 0
):
with_fp8_grad_input = True
fp8_meta = ctx.prev_op.get_fp8_meta("grad_output")
fp8_meta_key = FP8GlobalStateManager.get_meta_tensor_key(forward=False)
grad_input_fp8_meta = fp8_meta[fp8_meta_key]
grad_input_dtype = get_fp8_te_dtype(fp8_meta["recipe"], fprop_tensor=False)
grad_input_fp8_scale_inv = torch.empty([1], dtype=torch.float32, device=device)
# Launch kernel
if with_fp8_grad_input:
# Fused with FP8 cast-transpose
input_dims = x.size()
flat_input_dims = [x.numel() // input_dims[-1], input_dims[-1]]
flat_output_dims = [flat_input_dims[0], flat_input_dims[1] // 2]
dx = torch.empty(input_dims, dtype=torch.uint8, device=device)
dx_t = torch.empty(
(flat_input_dims[1], flat_input_dims[0]),
dtype=torch.uint8,
device=device,
)
fp8_dswiglu_cast_transpose_fused(
dy.reshape(flat_output_dims),
x.reshape(flat_input_dims),
grad_input=dx.reshape(flat_input_dims),
grad_input_transpose=dx_t,
otype=grad_input_dtype,
fp8_meta=grad_input_fp8_meta,
fp8_meta_index=0,
scale_inv=grad_input_fp8_scale_inv,
)
dx = Float8Tensor(
data=dx,
fp8_meta=grad_input_fp8_meta,
fp8_meta_forward=True,
fp8_meta_index=0,
fp8_dtype=grad_input_dtype,
fp8_scale_inv=grad_input_fp8_scale_inv,
dtype=dtype,
)
dx._transpose = dx_t
dx._transpose_invalid = False
else:
# Standard impl
dx = self._activation_backward_impl(dy, x, TE_DType[dtype])
if dx.size() != x.size():
dx = dx.reshape(x.size())
# Note: This fails if op is preceded by an identity op like Quantize(forward=False)
# # Clear input tensor if possible
# if ctx.prev_op is not None:
# clear_tensor_data(x)
return dx, ()
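To exercise the fused dSwiGLU-cast-transpose path above, the gradient entering SwiGLU's backward must be requested in FP8 by the preceding op, as in the test_swiglu test added in this commit. A minimal sketch (module aliases, shapes, and recipe settings are illustrative):

import torch
import transformer_engine.pytorch as te
import transformer_engine.pytorch.ops as te_ops
import transformer_engine.common.recipe

fp8_recipe = transformer_engine.common.recipe.DelayedScaling(
    fp8_format=transformer_engine.common.recipe.Format.E4M3,
)
model = te_ops.Sequential(
    te_ops.Quantize(forward=False, backward=True),  # makes SwiGLU produce an FP8 grad input
    te_ops.SwiGLU(),
)
x = torch.randn(16, 32, dtype=torch.bfloat16, device="cuda", requires_grad=True)
dy = torch.randn(16, 16, dtype=torch.bfloat16, device="cuda")
with te.fp8_autocast(enabled=True, fp8_recipe=fp8_recipe):
    y = model(x)
y.backward(dy)  # backward runs fp8_dswiglu_cast_transpose_fused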