Unverified commit 07db17b5 authored by Xin Yao, committed by GitHub

[PyTorch] Expose more activation functions (#2106)



expose more activation functions
Signed-off-by: Xin Yao <xiny@nvidia.com>
parent ccc1abf9
@@ -1532,7 +1532,10 @@ class TestBasicOps:
         torch.testing.assert_close(y2_test, y2_ref, rtol=0, atol=0)
         torch.testing.assert_close(dx_test, x_ref.grad, **tols)

-    @pytest.mark.parametrize("activation", ("relu", "gelu", "geglu", "reglu", "swiglu"))
+    @pytest.mark.parametrize(
+        "activation",
+        ("gelu", "geglu", "qgelu", "qgeglu", "relu", "reglu", "srelu", "sreglu", "silu", "swiglu"),
+    )
     @pytest.mark.parametrize("out_shape", ((37,), (2, 13), (32, 1, 32)))
     @pytest.mark.parametrize("dtype", _dtypes)
     @pytest.mark.parametrize("quantization", _quantization_list)
@@ -1551,7 +1554,7 @@ class TestBasicOps:

         # Tensor dimensions
         in_shape = list(out_shape)
-        if activation in ("geglu", "reglu", "swiglu"):
+        if activation in ("geglu", "qgeglu", "reglu", "sreglu", "swiglu"):
             in_shape[-1] *= 2

         # Skip invalid configurations
@@ -1578,14 +1581,26 @@ class TestBasicOps:
         y_ref: torch.Tensor
         if activation == "gelu":
             y_ref = torch.nn.functional.gelu(x_ref, approximate="tanh")
-        elif activation == "relu":
-            y_ref = torch.nn.functional.relu(x_ref)
         elif activation == "geglu":
             x1, x2 = x_ref.chunk(2, dim=-1)
             y_ref = torch.nn.functional.gelu(x1, approximate="tanh") * x2
+        elif activation == "qgelu":
+            y_ref = x_ref * torch.sigmoid(1.702 * x_ref)
+        elif activation == "qgeglu":
+            x1, x2 = x_ref.chunk(2, dim=-1)
+            y_ref = x1 * torch.sigmoid(1.702 * x1) * x2
+        elif activation == "relu":
+            y_ref = torch.nn.functional.relu(x_ref)
         elif activation == "reglu":
             x1, x2 = x_ref.chunk(2, dim=-1)
             y_ref = torch.nn.functional.relu(x1) * x2
+        elif activation == "srelu":
+            y_ref = torch.nn.functional.relu(x_ref) ** 2
+        elif activation == "sreglu":
+            x1, x2 = x_ref.chunk(2, dim=-1)
+            y_ref = torch.nn.functional.relu(x1) ** 2 * x2
+        elif activation == "silu":
+            y_ref = torch.nn.functional.silu(x_ref)
        elif activation == "swiglu":
             x1, x2 = x_ref.chunk(2, dim=-1)
             y_ref = torch.nn.functional.silu(x1) * x2
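
A side note on the reference implementations above: the quick-GELU form x * sigmoid(1.702 * x) used for "qgelu" closely tracks the exact erf-based GELU. A minimal standalone check (plain PyTorch, not part of the commit):

    import torch

    x = torch.linspace(-6.0, 6.0, steps=1001)
    quick = x * torch.sigmoid(1.702 * x)   # the "qgelu" reference above
    exact = torch.nn.functional.gelu(x)    # exact erf-based GELU
    print((quick - exact).abs().max())     # roughly 2e-2, worst near |x| ~ 2
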
@@ -1597,9 +1612,14 @@ class TestBasicOps:
         recipe = make_recipe(quantization)

         make_op = dict(
             gelu=te_ops.GELU,
-            relu=te_ops.ReLU,
             geglu=te_ops.GEGLU,
+            qgelu=te_ops.QGELU,
+            qgeglu=te_ops.QGEGLU,
+            relu=te_ops.ReLU,
             reglu=te_ops.ReGLU,
+            srelu=te_ops.SReLU,
+            sreglu=te_ops.SReGLU,
+            silu=te_ops.SiLU,
             swiglu=te_ops.SwiGLU,
         )[activation]
         forward = te_ops.Sequential(
...
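For context, the ops exercised by this test are used roughly as follows. A hedged usage sketch, assuming `te_ops` is `transformer_engine.pytorch.ops` (as in the test file) and a CUDA device is available:

    import torch
    import transformer_engine.pytorch.ops as te_ops

    forward = te_ops.Sequential(te_ops.SReLU())  # one of the newly exposed ops
    x = torch.randn(32, 64, device="cuda", requires_grad=True)
    y = forward(x)                               # matches relu(x) ** 2 in high precision
    y.sum().backward()                           # routed through the dsrelu backward
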
@@ -79,7 +79,18 @@ batch_sizes = [1, 2]

 all_boolean = [True, False]

-all_activations = ["gelu", "relu", "reglu", "geglu", "swiglu", "qgelu", "srelu"]
+all_activations = [
+    "gelu",
+    "geglu",
+    "qgelu",
+    "qgeglu",
+    "relu",
+    "reglu",
+    "srelu",
+    "sreglu",
+    "silu",
+    "swiglu",
+]

 all_normalizations = ["LayerNorm", "RMSNorm"]
@@ -427,13 +438,16 @@ class TorchGroupedLinearWithPadding(nn.Module):

 _supported_act = {
-    "geglu": nn.GELU(approximate="tanh"),
     "gelu": nn.GELU(approximate="tanh"),
-    "reglu": nn.ReLU(),
-    "relu": nn.ReLU(),
-    "swiglu": nn.SiLU(),
+    "geglu": nn.GELU(approximate="tanh"),
     "qgelu": TorchQuickGELU(),
+    "qgeglu": TorchQuickGELU(),
+    "relu": nn.ReLU(),
+    "reglu": nn.ReLU(),
     "srelu": TorchSquaredRELU(),
+    "sreglu": TorchSquaredRELU(),
+    "silu": nn.SiLU(),
+    "swiglu": nn.SiLU(),
 }
...
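A sketch of how a torch-side reference can consume `_supported_act` for the gated ("*glu") names, following the chunk-and-gate pattern used in the tests (the `gated_reference` helper is hypothetical, not from the commit):

    import torch

    def gated_reference(act, x):
        a, b = x.chunk(2, dim=-1)  # TE convention: activation applies to the first half
        return act(a) * b

    # e.g. gated_reference(_supported_act["swiglu"], x) reproduces the SwiGLU reference
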
@@ -104,7 +104,18 @@ if is_bf16_compatible():  # bf16 requires sm_80 or higher

 all_boolean = [True, False]
 batch_sizes_with_zero = [0, 1, 2]

-all_activations = ["gelu", "relu", "reglu", "geglu", "swiglu", "srelu", "qgelu", "qgeglu"]
+all_activations = [
+    "gelu",
+    "geglu",
+    "qgelu",
+    "qgeglu",
+    "relu",
+    "reglu",
+    "srelu",
+    "sreglu",
+    "silu",
+    "swiglu",
+]

 all_normalizations = ["LayerNorm", "RMSNorm"]
...
@@ -154,38 +154,49 @@ at::Tensor swap_first_dims(at::Tensor tensor, std::optional<at::Tensor> out = st

/***************************************************************************************************
 * Activations
 **************************************************************************************************/

-py::object gelu(const at::Tensor &input, py::handle quantizer);
-py::object relu(const at::Tensor &input, py::handle quantizer);
-py::object geglu(const at::Tensor &input, py::handle quantizer);
-py::object qgeglu(const at::Tensor &input, py::handle quantizer);
-py::object reglu(const at::Tensor &input, py::handle quantizer);
-py::object swiglu(const at::Tensor &input, py::handle quantizer);
-py::object qgelu(const at::Tensor &input, py::handle quantizer);
-py::object srelu(const at::Tensor &input, py::handle quantizer);
-py::object dgelu(const at::Tensor &grad, const at::Tensor &input, py::handle quantizer);
-py::object drelu(const at::Tensor &grad, const at::Tensor &input, py::handle quantizer);
-py::object dgeglu(const at::Tensor &grad, const at::Tensor &input, py::handle quantizer);
-py::object dqgeglu(const at::Tensor &grad, const at::Tensor &input, py::handle quantizer);
-py::object dreglu(const at::Tensor &grad, const at::Tensor &input, py::handle quantizer);
-py::object dswiglu(const at::Tensor &grad, const at::Tensor &input, py::handle quantizer);
-py::object dqgelu(const at::Tensor &grad, const at::Tensor &input, py::handle quantizer);
-py::object dsrelu(const at::Tensor &grad, const at::Tensor &input, py::handle quantizer);
+/* GELU and variants */
+py::object gelu(const at::Tensor &input, py::handle quantizer);
+py::object dgelu(const at::Tensor &grad, const at::Tensor &input, py::handle quantizer);
+py::object geglu(const at::Tensor &input, py::handle quantizer);
+py::object dgeglu(const at::Tensor &grad, const at::Tensor &input, py::handle quantizer);
+py::object qgelu(const at::Tensor &input, py::handle quantizer);
+py::object dqgelu(const at::Tensor &grad, const at::Tensor &input, py::handle quantizer);
+py::object qgeglu(const at::Tensor &input, py::handle quantizer);
+py::object dqgeglu(const at::Tensor &grad, const at::Tensor &input, py::handle quantizer);
+
+/* ReLU and variants */
+py::object relu(const at::Tensor &input, py::handle quantizer);
+py::object drelu(const at::Tensor &grad, const at::Tensor &input, py::handle quantizer);
+py::object reglu(const at::Tensor &input, py::handle quantizer);
+py::object dreglu(const at::Tensor &grad, const at::Tensor &input, py::handle quantizer);
+py::object srelu(const at::Tensor &input, py::handle quantizer);
+py::object dsrelu(const at::Tensor &grad, const at::Tensor &input, py::handle quantizer);
+py::object sreglu(const at::Tensor &input, py::handle quantizer);
+py::object dsreglu(const at::Tensor &grad, const at::Tensor &input, py::handle quantizer);
+
+/* SiLU and variants */
+py::object silu(const at::Tensor &input, py::handle quantizer);
+py::object dsilu(const at::Tensor &grad, const at::Tensor &input, py::handle quantizer);
+py::object swiglu(const at::Tensor &input, py::handle quantizer);
+py::object dswiglu(const at::Tensor &grad, const at::Tensor &input, py::handle quantizer);

/***************************************************************************************************
 * LayerNorm
 **************************************************************************************************/
...
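Each forward declaration above is paired with a `d`-prefixed backward that takes `(grad, fwd_input, quantizer)`. A hedged sketch of the pairing through the Python bindings, assuming the extension module imports as `transformer_engine_torch` and that `None` is a valid quantizer for plain high-precision output:

    import torch
    import transformer_engine_torch as tex

    x = torch.randn(16, 32, device="cuda")
    y = tex.srelu(x, None)        # forward: max(x, 0) ** 2
    dy = torch.ones_like(y)
    dx = tex.dsrelu(dy, x, None)  # backward, recomputed from the forward input
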
@@ -101,6 +101,7 @@ py::object dactivation_helper(const at::Tensor& grad_output, const at::Tensor& i
   return grad_input_py;
 }

+/* GELU and variants */
 py::object gelu(const at::Tensor& input, py::handle quantizer) {
   return activation_helper<nvte_gelu>(input, quantizer);
 }
@@ -109,30 +110,39 @@ py::object dgelu(const at::Tensor& grad, const at::Tensor& input, py::handle qua
   return dactivation_helper<nvte_dgelu>(grad, input, quantizer);
 }

-py::object relu(const at::Tensor& input, py::handle quantizer) {
-  return activation_helper<nvte_relu>(input, quantizer);
-}
-
-py::object drelu(const at::Tensor& grad, const at::Tensor& input, py::handle quantizer) {
-  return dactivation_helper<nvte_drelu>(grad, input, quantizer);
-}
-
 py::object geglu(const at::Tensor& input, py::handle quantizer) {
   return activation_helper<nvte_geglu>(input, quantizer, 2);
 }

-py::object qgeglu(const at::Tensor& input, py::handle quantizer) {
-  return activation_helper<nvte_qgeglu>(input, quantizer, 2);
-}
-
 py::object dgeglu(const at::Tensor& grad, const at::Tensor& input, py::handle quantizer) {
   return dactivation_helper<nvte_dgeglu>(grad, input, quantizer);
 }

+py::object qgelu(const at::Tensor& input, py::handle quantizer) {
+  return activation_helper<nvte_qgelu>(input, quantizer);
+}
+
+py::object dqgelu(const at::Tensor& grad, const at::Tensor& input, py::handle quantizer) {
+  return dactivation_helper<nvte_dqgelu>(grad, input, quantizer);
+}
+
+py::object qgeglu(const at::Tensor& input, py::handle quantizer) {
+  return activation_helper<nvte_qgeglu>(input, quantizer, 2);
+}
+
 py::object dqgeglu(const at::Tensor& grad, const at::Tensor& input, py::handle quantizer) {
   return dactivation_helper<nvte_dqgeglu>(grad, input, quantizer);
 }

+/* ReLU and variants */
+py::object relu(const at::Tensor& input, py::handle quantizer) {
+  return activation_helper<nvte_relu>(input, quantizer);
+}
+
+py::object drelu(const at::Tensor& grad, const at::Tensor& input, py::handle quantizer) {
+  return dactivation_helper<nvte_drelu>(grad, input, quantizer);
+}
+
 py::object reglu(const at::Tensor& input, py::handle quantizer) {
   return activation_helper<nvte_reglu>(input, quantizer, 2);
 }
@@ -141,28 +151,36 @@ py::object dreglu(const at::Tensor& grad, const at::Tensor& input, py::handle qu
   return dactivation_helper<nvte_dreglu>(grad, input, quantizer);
 }

-py::object swiglu(const at::Tensor& input, py::handle quantizer) {
-  return activation_helper<nvte_swiglu>(input, quantizer, 2);
-}
-
-py::object dswiglu(const at::Tensor& grad, const at::Tensor& input, py::handle quantizer) {
-  return dactivation_helper<nvte_dswiglu>(grad, input, quantizer);
-}
-
-py::object qgelu(const at::Tensor& input, py::handle quantizer) {
-  return activation_helper<nvte_qgelu>(input, quantizer);
-}
-
-py::object dqgelu(const at::Tensor& grad, const at::Tensor& input, py::handle quantizer) {
-  return dactivation_helper<nvte_dqgelu>(grad, input, quantizer);
-}
-
 py::object srelu(const at::Tensor& input, py::handle quantizer) {
   return activation_helper<nvte_srelu>(input, quantizer);
 }

 py::object dsrelu(const at::Tensor& grad, const at::Tensor& input, py::handle quantizer) {
   return dactivation_helper<nvte_dsrelu>(grad, input, quantizer);
 }

+py::object sreglu(const at::Tensor& input, py::handle quantizer) {
+  return activation_helper<nvte_sreglu>(input, quantizer, 2);
+}
+
+py::object dsreglu(const at::Tensor& grad, const at::Tensor& input, py::handle quantizer) {
+  return dactivation_helper<nvte_dsreglu>(grad, input, quantizer);
+}
+
+/* SiLU and variants */
+py::object silu(const at::Tensor& input, py::handle quantizer) {
+  return activation_helper<nvte_silu>(input, quantizer);
+}
+
+py::object dsilu(const at::Tensor& grad, const at::Tensor& input, py::handle quantizer) {
+  return dactivation_helper<nvte_dsilu>(grad, input, quantizer);
+}
+
+py::object swiglu(const at::Tensor& input, py::handle quantizer) {
+  return activation_helper<nvte_swiglu>(input, quantizer, 2);
+}
+
+py::object dswiglu(const at::Tensor& grad, const at::Tensor& input, py::handle quantizer) {
+  return dactivation_helper<nvte_dswiglu>(grad, input, quantizer);
+}

 }  // namespace transformer_engine::pytorch
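
The trailing `2` passed to `activation_helper` for the gated variants tells the kernel that the input packs two chunks along the last dimension, so the output is half as wide. A hedged shape check through the bindings (same import assumptions as the sketch above):

    import torch
    import transformer_engine_torch as tex

    x = torch.randn(8, 2 * 64, device="cuda")
    y = tex.swiglu(x, None)    # silu(x[..., :64]) * x[..., 64:]
    assert y.shape == (8, 64)  # gated ops map (..., 2*H) -> (..., H)
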
@@ -113,38 +113,53 @@ PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
         py::arg("comm_overlap") = nullptr, py::arg("comm_type") = std::nullopt,
         py::arg("extra_output") = std::nullopt, py::arg("bulk_overlap") = false,
         py::arg("alpha") = 1.0f, py::arg("beta") = std::nullopt);
+  /* GELU and variants */
   m.def("gelu", transformer_engine::pytorch::gelu, "GeLU activation", py::arg("input"),
         py::arg("quantizer"));
-  m.def("relu", transformer_engine::pytorch::relu, "ReLU activation", py::arg("input"),
-        py::arg("quantizer"));
   m.def("geglu", transformer_engine::pytorch::geglu, "GeGLU activation", py::arg("input"),
         py::arg("quantizer"));
+  m.def("qgelu", transformer_engine::pytorch::qgelu, "QuickGELU activation", py::arg("input"),
+        py::arg("quantizer"));
   m.def("qgeglu", transformer_engine::pytorch::qgeglu, "QuickGeGLU activation", py::arg("input"),
         py::arg("quantizer"));
+  /* ReLU and variants */
+  m.def("relu", transformer_engine::pytorch::relu, "ReLU activation", py::arg("input"),
+        py::arg("quantizer"));
   m.def("reglu", transformer_engine::pytorch::reglu, "ReGLU activation", py::arg("input"),
         py::arg("quantizer"));
-  m.def("swiglu", transformer_engine::pytorch::swiglu, "SwiGLU activation", py::arg("input"),
-        py::arg("quantizer"));
-  m.def("qgelu", transformer_engine::pytorch::qgelu, "QuickGELU activation", py::arg("input"),
-        py::arg("quantizer"));
   m.def("srelu", transformer_engine::pytorch::srelu, "Squared ReLU activation", py::arg("input"),
         py::arg("quantizer"));
+  m.def("sreglu", transformer_engine::pytorch::sreglu, "Squared ReGLU activation",
+        py::arg("input"), py::arg("quantizer"));
+  /* SiLU and variants */
+  m.def("silu", transformer_engine::pytorch::silu, "SiLU activation", py::arg("input"),
+        py::arg("quantizer"));
+  m.def("swiglu", transformer_engine::pytorch::swiglu, "SwiGLU activation", py::arg("input"),
+        py::arg("quantizer"));
+  /* Backward of GELU and variants */
   m.def("dgelu", transformer_engine::pytorch::dgelu, "Backward of GeLU", py::arg("grad"),
         py::arg("fwd_input"), py::arg("quantizer"));
-  m.def("drelu", transformer_engine::pytorch::drelu, "Backward of ReLU", py::arg("grad"),
-        py::arg("fwd_input"), py::arg("quantizer"));
   m.def("dgeglu", transformer_engine::pytorch::dgeglu, "Backward of GeGLU", py::arg("grad"),
         py::arg("fwd_input"), py::arg("quantizer"));
+  m.def("dqgelu", transformer_engine::pytorch::dqgelu, "Backward of QuickGELU", py::arg("grad"),
+        py::arg("fwd_input"), py::arg("quantizer"));
   m.def("dqgeglu", transformer_engine::pytorch::dqgeglu, "Backward of QuickGeGLU", py::arg("grad"),
         py::arg("fwd_input"), py::arg("quantizer"));
+  /* Backward of ReLU and variants */
+  m.def("drelu", transformer_engine::pytorch::drelu, "Backward of ReLU", py::arg("grad"),
+        py::arg("fwd_input"), py::arg("quantizer"));
   m.def("dreglu", transformer_engine::pytorch::dreglu, "Backward of ReGLU", py::arg("grad"),
         py::arg("fwd_input"), py::arg("quantizer"));
-  m.def("dswiglu", transformer_engine::pytorch::dswiglu, "Backward of SwiGLU", py::arg("grad"),
-        py::arg("fwd_input"), py::arg("quantizer"));
-  m.def("dqgelu", transformer_engine::pytorch::dqgelu, "Backward of QuickGELU", py::arg("grad"),
-        py::arg("fwd_input"), py::arg("quantizer"));
   m.def("dsrelu", transformer_engine::pytorch::dsrelu, "Backward of Squared ReLU", py::arg("grad"),
         py::arg("fwd_input"), py::arg("quantizer"));
+  m.def("dsreglu", transformer_engine::pytorch::dsreglu, "Backward of Squared ReGLU",
+        py::arg("grad"), py::arg("fwd_input"), py::arg("quantizer"));
+  /* Backward of SiLU and variants */
+  m.def("dsilu", transformer_engine::pytorch::dsilu, "Backward of SiLU", py::arg("grad"),
+        py::arg("fwd_input"), py::arg("quantizer"));
+  m.def("dswiglu", transformer_engine::pytorch::dswiglu, "Backward of SwiGLU", py::arg("grad"),
+        py::arg("fwd_input"), py::arg("quantizer"));
+  /* DBias + DAct fusions */
   m.def("dbias_dgelu", transformer_engine::pytorch::dbias_dgelu, "DGeLU + DBias + Quantize",
         py::arg("grad"), py::arg("fwd_input"), py::arg("quantizer"));
   m.def("dbias_dsilu", transformer_engine::pytorch::dbias_dsilu, "DSiLU + DBias + Quantize",
...
@@ -87,39 +87,45 @@ def _get_act_func_supported_list(recipe: Optional[Recipe] = None):
         # bf16 (recipe is None):
         return {
             "gelu": (tex.gelu, tex.dgelu, None),
-            "relu": (tex.relu, tex.drelu, None),
             "geglu": (tex.geglu, tex.dgeglu, None),
-            "reglu": (tex.reglu, tex.dreglu, None),
-            "swiglu": (tex.swiglu, tex.dswiglu, None),
             "qgelu": (tex.qgelu, tex.dqgelu, None),
             "qgeglu": (tex.qgeglu, tex.dqgeglu, None),
+            "relu": (tex.relu, tex.drelu, None),
+            "reglu": (tex.reglu, tex.dreglu, None),
             "srelu": (tex.srelu, tex.dsrelu, None),
+            "sreglu": (tex.sreglu, tex.dsreglu, None),
+            "silu": (tex.silu, tex.dsilu, None),
+            "swiglu": (tex.swiglu, tex.dswiglu, None),
         }
     if recipe.delayed() or recipe.mxfp8():
         # Delayed scaling, fusion supported list: [tex.dbias_dgelu, tex.dbias_drelu, tex.dbias_dqgelu, tex.dbias_dsrelu]
         # MXFP8: [tex.dbias_dgelu, tex.dbias_drelu, tex.dbias_dqgelu, tex.dbias_dsrelu]
         return {
             "gelu": (tex.gelu, tex.dgelu, tex.dbias_dgelu),
-            "relu": (tex.relu, tex.drelu, tex.dbias_drelu),
             "geglu": (tex.geglu, tex.dgeglu, None),
-            "reglu": (tex.reglu, tex.dreglu, None),
-            "swiglu": (tex.swiglu, tex.dswiglu, None),
             "qgelu": (tex.qgelu, tex.dqgelu, tex.dbias_dqgelu),
             "qgeglu": (tex.qgeglu, tex.dqgeglu, None),
+            "relu": (tex.relu, tex.drelu, tex.dbias_drelu),
+            "reglu": (tex.reglu, tex.dreglu, None),
             "srelu": (tex.srelu, tex.dsrelu, tex.dbias_dsrelu),
+            "sreglu": (tex.sreglu, tex.dsreglu, None),
+            "silu": (tex.silu, tex.dsilu, tex.dbias_dsilu),
+            "swiglu": (tex.swiglu, tex.dswiglu, None),
         }
     # no activation fusion written yet
     # Per-tensor current scaling or fp8 blockwise scaling: []
     if recipe.float8_current_scaling() or recipe.float8_block_scaling():
         return {
             "gelu": (tex.gelu, tex.dgelu, None),
-            "relu": (tex.relu, tex.drelu, None),
             "geglu": (tex.geglu, tex.dgeglu, None),
-            "reglu": (tex.reglu, tex.dreglu, None),
-            "swiglu": (tex.swiglu, tex.dswiglu, None),
             "qgelu": (tex.qgelu, tex.dqgelu, None),
             "qgeglu": (tex.qgeglu, tex.dqgeglu, None),
+            "relu": (tex.relu, tex.drelu, None),
+            "reglu": (tex.reglu, tex.dreglu, None),
             "srelu": (tex.srelu, tex.dsrelu, None),
+            "sreglu": (tex.sreglu, tex.dsreglu, None),
+            "silu": (tex.silu, tex.dsilu, None),
+            "swiglu": (tex.swiglu, tex.dswiglu, None),
         }
     raise NotImplementedError(f"Unhandled recipe type {recipe}")
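
A hedged sketch of how the returned table is consumed: each entry is a (forward, backward, fused dbias+dact) triple, with the third slot `None` when no fused kernel exists. Variable names here (`inp`, `grad`, `quantizer`) are illustrative, not from the commit:

    act_func, dact_func, dbias_dact_func = _get_act_func_supported_list(recipe)[activation]

    y = act_func(inp, quantizer)                       # forward
    if dbias_dact_func is not None:
        grads = dbias_dact_func(grad, inp, quantizer)  # fused DBias + DAct path
    else:
        grads = dact_func(grad, inp, quantizer)        # plain activation backward
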
@@ -1375,7 +1381,7 @@ class _LayerNormMLP(torch.autograd.Function):

 class LayerNormMLP(TransformerEngineBaseModule):
     r"""
     Applies layer normalization on the input followed by the MLP module, consisting of
-    2 successive linear transformations, separated by the GeLU activation.
+    2 successive linear transformations, separated by the activation function.

     Parameters
     ----------
@@ -1391,7 +1397,8 @@ class LayerNormMLP(TransformerEngineBaseModule):
                    type of normalization applied.
     activation : str, default = 'gelu'
                 activation function used.
-                Options: 'gelu', 'geglu', 'relu', 'reglu', 'squared_relu', 'swiglu', 'qgelu', 'srelu'.
+                Options: 'gelu', 'geglu', 'qgelu', 'qgeglu', 'relu', 'reglu', 'srelu', 'sreglu',
+                'silu', and 'swiglu'.
     init_method : Callable, default = `None`
                  used for initializing FC1 weights in the following way: `init_method(weight)`.
                  When set to `None`, defaults to `torch.nn.init.normal_(mean=0.0, std=0.023)`.
@@ -1592,7 +1599,7 @@ class LayerNormMLP(TransformerEngineBaseModule):
             self.layer_norm_bias = None

         # FC1 init
-        if self.activation in ["reglu", "geglu", "qgeglu", "swiglu"]:
+        if self.activation in ["geglu", "qgeglu", "reglu", "sreglu", "swiglu"]:
             fc1_output_features = 2 * self.size_per_partition
         else:
             fc1_output_features = self.size_per_partition
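
A hedged usage sketch of the effect: with a gated activation, FC1 produces twice the hidden size so the kernel can split it into value and gate halves (constructor arguments as in the public LayerNormMLP API):

    import transformer_engine.pytorch as te

    mlp = te.LayerNormMLP(
        hidden_size=1024,
        ffn_hidden_size=4096,
        activation="sreglu",  # newly accepted option
    )
    # For the gated variants, FC1's output dimension is 2 * ffn_hidden_size.
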
@@ -1973,14 +1980,17 @@ class LayerNormMLP(TransformerEngineBaseModule):
         activation_map = {
             "gelu": lambda x: torch.nn.functional.gelu(x, approximate="tanh"),
-            "relu": torch.nn.functional.relu,
             "geglu": lambda x: torch.nn.functional.gelu(x.chunk(2, -1)[0]) * x.chunk(2, -1)[1],
-            "reglu": lambda x: torch.nn.functional.relu(x.chunk(2, -1)[0]) * x.chunk(2, -1)[1],
-            "swiglu": lambda x: torch.nn.functional.silu(x.chunk(2, -1)[0]) * x.chunk(2, -1)[1],
+            "qgelu": lambda x: torch.nn.functional.gelu(x, approximate="tanh"),
             "qgeglu": lambda x: torch.nn.functional.gelu(x.chunk(2, -1)[0], approximate="tanh")
             * x.chunk(2, -1)[1],
-            "qgelu": lambda x: torch.nn.functional.gelu(x, approximate="tanh"),
-            "srelu": torch.nn.functional.softplus,
+            "relu": torch.nn.functional.relu,
+            "reglu": lambda x: torch.nn.functional.relu(x.chunk(2, -1)[0]) * x.chunk(2, -1)[1],
+            "srelu": lambda x: torch.nn.functional.relu(x) ** 2,
+            "sreglu": lambda x: torch.nn.functional.relu(x.chunk(2, -1)[0]) ** 2
+            * x.chunk(2, -1)[1],
+            "silu": torch.nn.functional.silu,
+            "swiglu": lambda x: torch.nn.functional.silu(x.chunk(2, -1)[0]) * x.chunk(2, -1)[1],
         }
         if self.activation not in activation_map:
             raise ValueError(f"Unsupported activation in onnx export: {self.activation}")
...
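In the gated lambdas above, `x.chunk(2, -1)` is evaluated twice per call; an equivalent single-split form (purely illustrative, not part of the commit) would be:

    import torch

    def swiglu_export(x: torch.Tensor) -> torch.Tensor:
        a, b = x.chunk(2, dim=-1)  # split once instead of twice
        return torch.nn.functional.silu(a) * b
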
@@ -4,7 +4,7 @@

 """Single tensor operations supported by the operation fuser."""

-from .activation import GELU, ReLU, GEGLU, ReGLU, SwiGLU
+from .activation import GELU, GEGLU, QGELU, QGEGLU, ReLU, ReGLU, SReLU, SReGLU, SiLU, SwiGLU
 from .add_extra_input import AddExtraInput
 from .all_gather import AllGather
 from .all_reduce import AllReduce
...
@@ -16,6 +16,19 @@ from ...utils import clear_tensor_data
 from ..op import BasicOperation, OperationContext
 from .._common import maybe_dequantize

+__all__ = [
+    "GELU",
+    "GEGLU",
+    "QGELU",
+    "QGEGLU",
+    "ReLU",
+    "ReGLU",
+    "SReLU",
+    "SReGLU",
+    "SiLU",
+    "SwiGLU",
+]
+

 class _ActivationOperation(BasicOperation, metaclass=abc.ABCMeta):
     r"""Apply activation function
@@ -147,37 +160,75 @@ class GELU(_ActivationOperation):
         return tex.dgelu(*args, **kwargs)


-class ReLU(_ActivationOperation):
-    r"""Rectified linear unit
-
-    .. math::
-
-       \text{ReLU}(x) = \max(x,0)
-
-    """
-
-    def _activation_forward_impl(self, *args, **kwargs) -> torch.Tensor:
-        return tex.relu(*args, **kwargs)
-
-    def _activation_backward_impl(self, *args, **kwargs) -> torch.Tensor:
-        return tex.drelu(*args, **kwargs)
-
-
-class GEGLU(_ActivationOperation):
-    r"""Gaussian error gated linear unit
+class GEGLU(_ActivationOperation):
+    r"""Gaussian Error Gated Linear Unit
+
+    The input tensor is split into chunks :math:`a` and :math:`b`
+    along the last dimension and the following is computed:
+
+    .. math::
+
+       \text{GEGLU}(a,b) = \text{GELU}(a) * b
+
+    where
+
+    .. math::
+
+       \text{GELU}(x) \approx \frac{x}{2} \left( 1 + \tanh\left( 0.797x+0.036 x^3 \right) \right)
+
+    .. warning::
+
+       Transformer Engine's gated activations and PyTorch's GLU
+       activation follow opposite conventions for :math:`a` and
+       :math:`b`. Transformer Engine applies the gating function to
+       the first half of the input tensor, while PyTorch applies it to
+       the second half.
+
+    See `GLU Variants Improve Transformer <https://arxiv.org/abs/2002.05202>`__.
+
+    """
+
+    def _activation_forward_impl(self, *args, **kwargs) -> torch.Tensor:
+        return tex.geglu(*args, **kwargs)
+
+    def _activation_backward_impl(self, *args, **kwargs) -> torch.Tensor:
+        return tex.dgeglu(*args, **kwargs)
+
+
+class QGELU(_ActivationOperation):
+    r"""Quick Gaussian Error Linear Unit
+
+    Quick GELU from `HuggingFace <https://github.com/huggingface/transformers/blob/3e93dd295b5343557a83bc07b0b2ea64c926f9b4/src/transformers/activations.py#L90>`__
+    and `paper <https://github.com/hendrycks/GELUs>`__.
+
+    .. math::
+
+       \text{QGELU}(x) \approx x * \sigma(1.702 * x)
+
+    """
+
+    def _activation_forward_impl(self, *args, **kwargs) -> torch.Tensor:
+        return tex.qgelu(*args, **kwargs)
+
+    def _activation_backward_impl(self, *args, **kwargs) -> torch.Tensor:
+        return tex.dqgelu(*args, **kwargs)
+
+
+class QGEGLU(_ActivationOperation):
+    r"""Quick Gaussian Error Gated Linear Unit

     The input tensor is split into chunks :math:`a` and :math:`b`
     along the last dimension and the following is computed:

     .. math::

-       \text{GEGLU}(a,b) = \text{GELU}(a) * b
+       \text{QGEGLU}(a,b) = \text{QGELU}(a) * b

     where

     .. math::

-       \text{GELU}(x) \approx \frac{x}{2} \left( 1 + \tanh\left( 0.797x+0.036 x^3 \right) \right)
+       \text{QGELU}(x) \approx x * \sigma(1.702 * x)

     .. warning::
@@ -187,19 +238,33 @@ class GEGLU(_ActivationOperation):
        the first half of the input tensor, while PyTorch applies it to
        the second half.

-    See `GLU Variants Improve Transformer<https://arxiv.org/abs/2002.05202>`__.
-
     """

     def _activation_forward_impl(self, *args, **kwargs) -> torch.Tensor:
-        return tex.geglu(*args, **kwargs)
+        return tex.qgeglu(*args, **kwargs)

     def _activation_backward_impl(self, *args, **kwargs) -> torch.Tensor:
-        return tex.dgeglu(*args, **kwargs)
+        return tex.dqgeglu(*args, **kwargs)
+
+
+class ReLU(_ActivationOperation):
+    r"""Rectified Linear Unit
+
+    .. math::
+
+       \text{ReLU}(x) = \max(x,0)
+
+    """
+
+    def _activation_forward_impl(self, *args, **kwargs) -> torch.Tensor:
+        return tex.relu(*args, **kwargs)
+
+    def _activation_backward_impl(self, *args, **kwargs) -> torch.Tensor:
+        return tex.drelu(*args, **kwargs)


 class ReGLU(_ActivationOperation):
-    r"""Rectified gated linear unit
+    r"""Rectified Gated Linear Unit

     The input tensor is split into chunks :math:`a` and :math:`b`
     along the last dimension and the following is computed:
@@ -227,6 +292,67 @@ class ReGLU(_ActivationOperation):
         return tex.dreglu(*args, **kwargs)


+class SReLU(_ActivationOperation):
+    r"""Squared Rectified Linear Unit
+
+    .. math::
+
+       \text{SReLU}(x) = \max(x,0)^2
+
+    See `Primer: Searching for Efficient Transformers for Language Modeling <https://arxiv.org/abs/2109.08668v2>`__.
+
+    """
+
+    def _activation_forward_impl(self, *args, **kwargs) -> torch.Tensor:
+        return tex.srelu(*args, **kwargs)
+
+    def _activation_backward_impl(self, *args, **kwargs) -> torch.Tensor:
+        return tex.dsrelu(*args, **kwargs)
+
+
+class SReGLU(_ActivationOperation):
+    r"""Squared Rectified Gated Linear Unit
+
+    The input tensor is split into chunks :math:`a` and :math:`b`
+    along the last dimension and the following is computed:
+
+    .. math::
+
+       \text{SReGLU}(a,b) = \max(a,0)^2 * b
+
+    .. warning::
+
+       Transformer Engine's gated activations and PyTorch's GLU
+       activation follow opposite conventions for :math:`a` and
+       :math:`b`. Transformer Engine applies the gating function to
+       the first half of the input tensor, while PyTorch applies it to
+       the second half.
+
+    """
+
+    def _activation_forward_impl(self, *args, **kwargs) -> torch.Tensor:
+        return tex.sreglu(*args, **kwargs)
+
+    def _activation_backward_impl(self, *args, **kwargs) -> torch.Tensor:
+        return tex.dsreglu(*args, **kwargs)
+
+
+class SiLU(_ActivationOperation):
+    r"""Sigmoid Linear Unit
+
+    .. math::
+
+       \text{SiLU}(x) = x \sigma(x) = \frac{x}{1+\exp(-x)}
+
+    """
+
+    def _activation_forward_impl(self, *args, **kwargs) -> torch.Tensor:
+        return tex.silu(*args, **kwargs)
+
+    def _activation_backward_impl(self, *args, **kwargs) -> torch.Tensor:
+        return tex.dsilu(*args, **kwargs)
+
+
 class SwiGLU(_ActivationOperation):
     r"""Swish gated linear unit
...
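The chunk-order warning repeated in the gated docstrings above is easy to trip over; a standalone illustration of the two conventions (plain PyTorch, not TE code):

    import torch

    x = torch.randn(4, 8)
    a, b = x.chunk(2, dim=-1)
    te_style = torch.nn.functional.silu(a) * b  # TE: activation gates the first half
    pt_style = a * torch.sigmoid(b)             # torch.nn.functional.glu convention
    assert torch.equal(pt_style, torch.nn.functional.glu(x, dim=-1))
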
@@ -175,7 +175,8 @@ class TransformerLayer(torch.nn.Module):
                if set to `False`, the transformer layer will not learn any additive biases.
     activation : str, default = 'gelu'
                 Type of activation used in MLP block.
-                Options are: 'gelu', 'relu', 'reglu', 'geglu', 'swiglu', 'qgelu' and 'srelu'.
+                Options are: 'gelu', 'geglu', 'qgelu', 'qgeglu', 'relu', 'reglu', 'srelu', 'sreglu',
+                'silu', and 'swiglu'.
     device : Union[torch.device, str], default = "cuda"
              The device on which the parameters of the model will be allocated. It is the user's
              responsibility to ensure all parameters are moved to the GPU before running the
...
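A hedged end-to-end sketch of the expanded option list at the TransformerLayer level (constructor arguments as in the public API):

    import transformer_engine.pytorch as te

    layer = te.TransformerLayer(
        hidden_size=1024,
        ffn_hidden_size=4096,
        num_attention_heads=16,
        activation="qgeglu",  # any of the ten options listed above
    )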