Unverified commit 33ca6150, authored by Kim, Jin (Jay@SKT) and committed by GitHub

Add sigmoid GLU (#2656)



* Add sigmoid GLU

Signed-off-by: Kim, Jin <jinn.kim@sk.com>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

Signed-off-by: Kim, Jin <jinn.kim@sk.com>

* Add test for GLU op

Signed-off-by: Tim Moon <tmoon@nvidia.com>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* Fix incorrect reshape

Signed-off-by: Tim Moon <4406448+timmoon10@users.noreply.github.com>

* Apply suggestion from @timmoon10

Signed-off-by: Tim Moon <4406448+timmoon10@users.noreply.github.com>

* Add omitted tests for GLU op

Signed-off-by: Kim, Jin <jinn.kim@sk.com>

* Add GLU activation type support in JAX extension

Signed-off-by: Kim, Jin <jinn.kim@sk.com>

* [PyTorch] Add Sigmoid activation for GLU support in numerics test (#2656)

Signed-off-by: Kim, Jin <jinn.kim@sk.com>

---------

Signed-off-by: Kim, Jin <jinn.kim@sk.com>
Signed-off-by: Tim Moon <tmoon@nvidia.com>
Signed-off-by: Tim Moon <4406448+timmoon10@users.noreply.github.com>
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
Co-authored-by: Tim Moon <tmoon@nvidia.com>
Co-authored-by: Tim Moon <4406448+timmoon10@users.noreply.github.com>
parent 3774aa37
@@ -1570,7 +1570,19 @@ class TestBasicOps:
     @pytest.mark.parametrize(
         "activation",
-        ("gelu", "geglu", "qgelu", "qgeglu", "relu", "reglu", "srelu", "sreglu", "silu", "swiglu"),
+        (
+            "gelu",
+            "geglu",
+            "qgelu",
+            "qgeglu",
+            "relu",
+            "reglu",
+            "glu",
+            "srelu",
+            "sreglu",
+            "silu",
+            "swiglu",
+        ),
     )
     @pytest.mark.parametrize("out_shape", ((37,), (2, 13), (32, 1, 32)))
     @pytest.mark.parametrize("dtype", _dtypes)
@@ -1590,7 +1602,7 @@ class TestBasicOps:
         # Tensor dimensions
         in_shape = list(out_shape)
-        if activation in ("geglu", "qgeglu", "reglu", "sreglu", "swiglu"):
+        if activation in ("geglu", "glu", "qgeglu", "reglu", "sreglu", "swiglu"):
             in_shape[-1] *= 2
 
         # Skip invalid configurations
@@ -1630,6 +1642,13 @@ class TestBasicOps:
         elif activation == "reglu":
             x1, x2 = x_ref.chunk(2, dim=-1)
             y_ref = torch.nn.functional.relu(x1) * x2
+        elif activation == "sigmoid":
+            y_ref = torch.nn.functional.sigmoid(x_ref)
+        elif activation == "glu":
+            x = x_ref.reshape(*in_shape[:-1], 2, in_shape[-1] // 2)
+            x = x.flip(-2)  # PyTorch GLU swaps gate and linear unit
+            x = x.reshape(in_shape)
+            y_ref = torch.nn.functional.glu(x)
         elif activation == "srelu":
             y_ref = torch.nn.functional.relu(x_ref) ** 2
         elif activation == "sreglu":
@@ -1649,6 +1668,7 @@ class TestBasicOps:
         make_op = dict(
             gelu=te_ops.GELU,
             geglu=te_ops.GEGLU,
+            glu=te_ops.GLU,
            qgelu=te_ops.QGELU,
             qgeglu=te_ops.QGEGLU,
             relu=te_ops.ReLU,
......
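
A note for readers of the flip(-2) above: Transformer Engine's gated activations apply the sigmoid gate to the first half of the last dimension, while torch.nn.functional.glu gates with the second half, so the reference implementation swaps the halves before calling it. A small self-contained check of that equivalence (illustrative sketch only, not part of this diff):

import torch

# TE convention: sigmoid(first half) * second half.
x = torch.randn(4, 8)                  # last dim = 2 * H, here H = 4
gate, lin = x.chunk(2, dim=-1)
te_style = torch.sigmoid(gate) * lin

# torch.nn.functional.glu gates with the SECOND half, so swap the halves first,
# exactly as the test does with reshape + flip(-2) + reshape.
swapped = x.reshape(4, 2, 4).flip(-2).reshape(4, 8)
torch_style = torch.nn.functional.glu(swapped, dim=-1)

assert torch.allclose(te_style, torch_style)
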
@@ -89,6 +89,7 @@ all_boolean = [True, False]
 all_activations = [
     "gelu",
     "geglu",
+    "glu",
     "qgelu",
     "qgeglu",
     "relu",
@@ -479,6 +480,7 @@ class TorchGroupedLinearWithPadding(nn.Module):
 _supported_act = {
     "gelu": nn.GELU(approximate="tanh"),
     "geglu": nn.GELU(approximate="tanh"),
+    "glu": nn.Sigmoid(),
     "qgelu": TorchQuickGELU(),
     "qgeglu": TorchQuickGELU(),
     "relu": nn.ReLU(),
......
@@ -113,6 +113,7 @@ batch_sizes_with_zero = [0, 1, 2]
 all_activations = [
     "gelu",
     "geglu",
+    "glu",
     "qgelu",
     "qgeglu",
     "relu",
......
@@ -168,6 +168,7 @@ list(APPEND transformer_engine_cuda_sources
 list(APPEND transformer_engine_cuda_arch_specific_sources
      activation/gelu.cu
+     activation/glu.cu
      activation/relu.cu
      activation/swiglu.cu
      cast/cast.cu
@@ -354,6 +355,7 @@ list(APPEND nvte_sources_with_fast_math fused_softmax/scaled_masked_softmax.cu
 option(NVTE_BUILD_ACTIVATION_WITH_FAST_MATH "Compile activation kernels with --use_fast_math option" OFF)
 if (NVTE_BUILD_ACTIVATION_WITH_FAST_MATH)
   list(APPEND nvte_sources_with_fast_math activation/gelu.cu
+                                           activation/glu.cu
                                            activation/relu.cu
                                            activation/swiglu.cu)
 endif()
......
/*************************************************************************
 * Copyright (c) 2022-2026, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
 *
 * See LICENSE for license information.
 ************************************************************************/

#include "../util/math.h"
#include "./activation_template.h"

void nvte_glu(const NVTETensor input, NVTETensor output, cudaStream_t stream) {
  NVTE_API_CALL(nvte_glu);
  using namespace transformer_engine;
  Empty e = {};
  gated_act_fn<fp32, Empty, sigmoid<fp32, fp32>>(input, output, e, stream);
}

void nvte_dglu(const NVTETensor grad, const NVTETensor input, NVTETensor output,
               cudaStream_t stream) {
  NVTE_API_CALL(nvte_dglu);
  using namespace transformer_engine;
  Empty e = {};
  dgated_act_fn<fp32, Empty, sigmoid<fp32, fp32>, dsigmoid<fp32, fp32>>(grad, input, output, e,
                                                                        stream);
}
@@ -31,6 +31,7 @@ extern "C" {
 enum class NVTE_Activation_Type {
   GELU,
   GEGLU,
+  GLU,
   SILU,
   SWIGLU,
   RELU,
@@ -262,6 +263,32 @@ void nvte_dsrelu(const NVTETensor grad, const NVTETensor input, NVTETensor outpu
 void nvte_group_dsrelu(const NVTEGroupedTensor grad, const NVTEGroupedTensor input,
                        NVTETensor output, cudaStream_t stream);
 
+/*! \brief Computes the GLU (Gated Linear Unit) activation of the input.
+ *  GLU(a,b) = sigmoid(a) * b
+ *  See "Language Modeling with Gated Convolutional Networks" (arXiv:1612.08083)
+ *  and "GLU Variants Improve Transformer" (arXiv:2002.05202).
+ *  If the scaling mode of the output tensor is set to NVTE_MXFP8_1D_SCALING,
+ *  the block quantization (MXFP8) of the specified shape of the block will be used.
+ *
+ *  \param[in]     input     Input tensor of shape [N, H * 2].
+ *  \param[in,out] output    Output tensor of shape [N, H].
+ *                           It computes sigmoid(input[N, :H]) x input[N, H:]
+ *  \param[in]     stream    CUDA stream used for the operation.
+ */
+void nvte_glu(const NVTETensor input, NVTETensor output, cudaStream_t stream);
+
+/*! \brief Computes the GLU activation gradient.
+ *  If the scaling mode of the output tensor is set to NVTE_MXFP8_1D_SCALING,
+ *  the block quantization (MXFP8) of the specified shape of the block will be used.
+ *
+ *  \param[in]     grad      Incoming gradient of shape [N, H].
+ *  \param[in]     input     Forward input tensor of shape [N, H * 2].
+ *  \param[in,out] output    Outgoing gradient of shape [N, H * 2].
+ *  \param[in]     stream    CUDA stream used for the operation.
+ */
+void nvte_dglu(const NVTETensor grad, const NVTETensor input, NVTETensor output,
+               cudaStream_t stream);
+
 /*! \brief Computes the gated GeLU activation of the input.
  *  If the scaling mode of the output tensor is set to NVTE_MXFP8_1D_SCALING,
  *  the block quantization (MXFP8) of the specified shape of the block will be used.
......
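
For readers of the nvte_glu / nvte_dglu contracts above, here is a hedged reference of the same math in plain PyTorch rather than CUDA; the [dgate, dlinear] packing order of the output gradient is an assumption that mirrors the forward convention (gate in the first half). Illustrative sketch only, not part of this diff:

import torch

def glu_ref(x):                        # x: [N, 2H] -> [N, H]
    gate, lin = x.chunk(2, dim=-1)
    return torch.sigmoid(gate) * lin

def dglu_ref(grad, x):                 # grad: [N, H], x: [N, 2H] -> [N, 2H]
    gate, lin = x.chunk(2, dim=-1)
    s = torch.sigmoid(gate)
    dgate = grad * lin * s * (1 - s)   # d(sigmoid(gate) * lin) / d(gate)
    dlin = grad * s                    # d(sigmoid(gate) * lin) / d(lin)
    return torch.cat([dgate, dlin], dim=-1)

# Cross-check the hand-written backward against autograd.
x = torch.randn(3, 10, requires_grad=True)
grad = torch.randn(3, 5)
glu_ref(x).backward(grad)
assert torch.allclose(x.grad, dglu_ref(grad, x.detach()))
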
@@ -44,6 +44,7 @@ __all__ = ["act_lu", "dact_lu", "quantize_dact_dbias"]
 ActivationEnum = {
     ("gelu",): NVTE_Activation_Type.GELU,
     ("gelu", "linear"): NVTE_Activation_Type.GEGLU,
+    ("sigmoid", "linear"): NVTE_Activation_Type.GLU,
     ("silu",): NVTE_Activation_Type.SILU,
     ("silu", "linear"): NVTE_Activation_Type.SWIGLU,
     ("relu",): NVTE_Activation_Type.RELU,
......
@@ -109,6 +109,9 @@ Error_Type ActLuFFI(cudaStream_t stream, Buffer_Type input_buf, Buffer_Type scal
     case NVTE_Activation_Type::GEGLU:
       nvte_geglu(input_tensor.data(), output_tensor.data(), stream);
       break;
+    case NVTE_Activation_Type::GLU:
+      nvte_glu(input_tensor.data(), output_tensor.data(), stream);
+      break;
     case NVTE_Activation_Type::SILU:
       nvte_silu(input_tensor.data(), output_tensor.data(), stream);
       break;
@@ -427,6 +430,9 @@ Error_Type DActLuDBiasQuantizeFFI(cudaStream_t stream, Buffer_Type input_buf,
     case NVTE_Activation_Type::GEGLU:
       nvte_dgeglu(input_tensor.data(), act_input_tensor.data(), output_tensor.data(), stream);
       break;
+    case NVTE_Activation_Type::GLU:
+      nvte_dglu(input_tensor.data(), act_input_tensor.data(), output_tensor.data(), stream);
+      break;
     case NVTE_Activation_Type::SWIGLU:
       nvte_dswiglu(input_tensor.data(), act_input_tensor.data(), output_tensor.data(), stream);
       break;
......
@@ -150,6 +150,7 @@ PYBIND11_MODULE(transformer_engine_jax, m) {
   pybind11::enum_<NVTE_Activation_Type>(m, "NVTE_Activation_Type", pybind11::module_local())
       .value("GELU", NVTE_Activation_Type::GELU)
       .value("GEGLU", NVTE_Activation_Type::GEGLU)
+      .value("GLU", NVTE_Activation_Type::GLU)
       .value("SILU", NVTE_Activation_Type::SILU)
       .value("SWIGLU", NVTE_Activation_Type::SWIGLU)
       .value("RELU", NVTE_Activation_Type::RELU)
......
@@ -163,6 +163,11 @@ at::Tensor swap_first_dims(at::Tensor tensor, std::optional<at::Tensor> out = st
  * Activations
  **************************************************************************************************/
 
+/* GLU (sigmoid gate) */
+py::object glu(const at::Tensor &input, py::handle quantizer);
+py::object dglu(const at::Tensor &grad, const at::Tensor &input, py::handle quantizer);
+
 /* GELU and variants*/
 py::object gelu(const at::Tensor &input, py::handle quantizer);
......
@@ -246,6 +246,14 @@ py::object dgelu(const at::Tensor& grad, const at::Tensor& input, py::handle qua
   return dactivation_helper<nvte_dgelu, nullptr>(grad, input, quantizer);
 }
 
+py::object glu(const at::Tensor& input, py::handle quantizer) {
+  return activation_helper<nvte_glu, nullptr>(input, quantizer, 2);
+}
+
+py::object dglu(const at::Tensor& grad, const at::Tensor& input, py::handle quantizer) {
+  return dactivation_helper<nvte_dglu, nullptr>(grad, input, quantizer);
+}
+
 py::object geglu(const at::Tensor& input, py::handle quantizer) {
   return activation_helper<nvte_geglu, nullptr>(input, quantizer, 2);
 }
......
@@ -132,6 +132,9 @@ PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
         py::arg("comm_overlap") = nullptr, py::arg("comm_type") = std::nullopt,
         py::arg("extra_output") = std::nullopt, py::arg("bulk_overlap") = false,
         py::arg("alpha") = 1.0f, py::arg("beta") = std::nullopt);
+  /* GLU (sigmoid gate) */
+  m.def("glu", transformer_engine::pytorch::glu, "GLU activation", py::arg("input"),
+        py::arg("quantizer"));
   /* GELU and variants*/
   m.def("gelu", transformer_engine::pytorch::gelu, "GeLU activation", py::arg("input"),
         py::arg("quantizer"));
@@ -158,6 +161,9 @@ PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
   m.def("clamped_swiglu", transformer_engine::pytorch::clamped_swiglu,
         "SwiGLU activation used in GPT OSS", py::arg("input"), py::arg("quantizer"),
         py::arg("limit") = 7.0f, py::arg("alpha") = 1.702f);
+  /* Backward of GLU */
+  m.def("dglu", transformer_engine::pytorch::dglu, "Backward of GLU", py::arg("grad"),
+        py::arg("fwd_input"), py::arg("quantizer"));
   /* Backward of GELU and variants */
   m.def("dgelu", transformer_engine::pytorch::dgelu, "Backward of GeLU", py::arg("grad"),
         py::arg("fwd_input"), py::arg("quantizer"));
......
@@ -98,6 +98,7 @@ def _get_act_func_supported_list(recipe: Optional[Recipe] = None):
         return {
             "gelu": (tex.gelu, tex.dgelu, None),
             "geglu": (tex.geglu, tex.dgeglu, None),
+            "glu": (tex.glu, tex.dglu, None),
             "qgelu": (tex.qgelu, tex.dqgelu, None),
             "qgeglu": (tex.qgeglu, tex.dqgeglu, None),
             "relu": (tex.relu, tex.drelu, None),
@@ -114,6 +115,7 @@ def _get_act_func_supported_list(recipe: Optional[Recipe] = None):
         return {
             "gelu": (tex.gelu, tex.dgelu, tex.dbias_dgelu),
             "geglu": (tex.geglu, tex.dgeglu, None),
+            "glu": (tex.glu, tex.dglu, None),
             "qgelu": (tex.qgelu, tex.dqgelu, tex.dbias_dqgelu),
             "qgeglu": (tex.qgeglu, tex.dqgeglu, None),
             "relu": (tex.relu, tex.drelu, tex.dbias_drelu),
@@ -136,6 +138,7 @@ def _get_act_func_supported_list(recipe: Optional[Recipe] = None):
         return {
             "gelu": (tex.gelu, tex.dgelu, None),
             "geglu": (tex.geglu, tex.dgeglu, None),
+            "glu": (tex.glu, tex.dglu, None),
             "qgelu": (tex.qgelu, tex.dqgelu, None),
             "qgeglu": (tex.qgeglu, tex.dqgeglu, None),
             "relu": (tex.relu, tex.drelu, None),
@@ -1665,7 +1668,7 @@ class LayerNormMLP(TransformerEngineBaseModule):
                   type of normalization applied.
     activation : str, default = 'gelu'
                  activation function used.
-                 Options: ``'gelu'``, ``'geglu'``, ``'qgelu'``, ``'qgeglu'``, ``'relu'``, ``'reglu'``, ``'srelu'``, ``'sreglu'``,
+                 Options: ``'gelu'``, ``'geglu'``, ``'glu'``, ``'qgelu'``, ``'qgeglu'``, ``'relu'``, ``'reglu'``, ``'srelu'``, ``'sreglu'``,
                  ``'silu'``, ``'swiglu'``, and ``'clamped_swiglu'``.
     activation_params : dict, default = None
                         Additional parameters for the activation function.
@@ -1884,7 +1887,15 @@ class LayerNormMLP(TransformerEngineBaseModule):
             self.layer_norm_bias = None
 
         # FC1 init
-        if self.activation in ["geglu", "qgeglu", "reglu", "sreglu", "swiglu", "clamped_swiglu"]:
+        if self.activation in [
+            "geglu",
+            "glu",
+            "qgeglu",
+            "reglu",
+            "sreglu",
+            "swiglu",
+            "clamped_swiglu",
+        ]:
             fc1_output_features = 2 * self.size_per_partition
         else:
             fc1_output_features = self.size_per_partition
@@ -2308,6 +2319,7 @@ class LayerNormMLP(TransformerEngineBaseModule):
         activation_map = {
             "gelu": lambda x: torch.nn.functional.gelu(x, approximate="tanh"),
             "geglu": lambda x: torch.nn.functional.gelu(x.chunk(2, -1)[0]) * x.chunk(2, -1)[1],
+            "glu": lambda x: torch.sigmoid(x.chunk(2, -1)[0]) * x.chunk(2, -1)[1],
             "qgelu": lambda x: torch.nn.functional.gelu(x, approximate="tanh"),
             "qgeglu": lambda x: torch.nn.functional.gelu(x.chunk(2, -1)[0], approximate="tanh")
             * x.chunk(2, -1)[1],
......
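
For context on the LayerNormMLP changes above, a hedged usage sketch of the new option at the module level; sizes are illustrative and a CUDA device is assumed (not part of this diff):

import torch
import transformer_engine.pytorch as te

# With a gated activation, FC1 produces 2 * ffn_hidden_size features and the
# GLU halves them back to ffn_hidden_size before FC2.
mlp = te.LayerNormMLP(hidden_size=256, ffn_hidden_size=1024, activation="glu").cuda()
x = torch.randn(128, 4, 256, device="cuda")  # [sequence, batch, hidden]
y = mlp(x)                                   # shape [128, 4, 256]
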
@@ -7,6 +7,7 @@
 from .activation import (
     GELU,
     GEGLU,
+    GLU,
     QGELU,
     QGEGLU,
     ReLU,
......
@@ -20,6 +20,7 @@ from .._common import maybe_dequantize
 __all__ = [
     "GELU",
     "GEGLU",
+    "GLU",
     "QGELU",
     "QGEGLU",
     "ReLU",
@@ -162,6 +163,38 @@ class GELU(_ActivationOperation):
         return tex.dgelu(*args, **kwargs)
 
+
+class GLU(_ActivationOperation):
+    r"""Gated Linear Unit
+
+    The input tensor is split into chunks :math:`a` and :math:`b`
+    along the last dimension and the following is computed:
+
+    .. math::
+
+       \text{GLU}(a,b) = \sigma(a) * b
+
+    where :math:`\sigma` is the sigmoid function.
+
+    .. warning::
+
+       Transformer Engine's gated activations and PyTorch's GLU
+       activation follow opposite conventions for :math:`a` and
+       :math:`b`. Transformer Engine applies the gating function to
+       the first half of the input tensor, while PyTorch applies it
+       to the second half.
+
+    See `Language Modeling with Gated Convolutional Networks <https://arxiv.org/abs/1612.08083>`__
+    and `GLU Variants Improve Transformer <https://arxiv.org/abs/2002.05202>`__.
+
+    """
+
+    def _activation_forward_impl(self, *args, **kwargs) -> torch.Tensor:
+        return tex.glu(*args, **kwargs)
+
+    def _activation_backward_impl(self, *args, **kwargs) -> torch.Tensor:
+        return tex.dglu(*args, **kwargs)
+
+
 class GEGLU(_ActivationOperation):
     r"""Gaussian Error Gated Linear Unit
......
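
A minimal, hedged sketch of calling the new fusible op directly (a CUDA device is assumed; per the warning in the docstring above, the gate comes from the first half of the last dimension). Illustrative only, not part of this diff:

import torch
import transformer_engine.pytorch.ops as te_ops

op = te_ops.GLU()
x = torch.randn(16, 64, device="cuda")  # last dimension is 2 * H
y = op(x)                               # sigmoid(x[..., :32]) * x[..., 32:], shape [16, 32]
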
@@ -184,7 +184,7 @@ class TransformerLayer(torch.nn.Module):
                               if set to ``False``, the transformer layer will not learn any additive biases.
     activation : str, default = 'gelu'
                  Type of activation used in MLP block.
-                 Options are: ``'gelu'``, ``'geglu'``, ``'qgelu'``, ``'qgeglu'``, ``'relu'``, ``'reglu'``, ``'srelu'``, ``'sreglu'``,
+                 Options are: ``'gelu'``, ``'geglu'``, ``'glu'``, ``'qgelu'``, ``'qgeglu'``, ``'relu'``, ``'reglu'``, ``'srelu'``, ``'sreglu'``,
                  ``'silu'``, ``'swiglu'``, and ``'clamped_swiglu'``.
     activation_params : Optional[dict], default = None
                         Additional parameters for the activation function.
......