"...git@developer.sourcefind.cn:chenpangpang/transformers.git" did not exist on "dce33f2150769825ca175df3209441122f85a814"
Unverified commit f65fe366, authored by Eldar Kurtic and committed by GitHub

Implementation of activations as pytorch modules (#15616)



* Implement activations as pytorch modules

* Apply fixup

* Add missing tests for activations

* Update docstring
Co-authored-by: Patrick von Platen <patrick.v.platen@gmail.com>
parent 66828a19
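
The diff below converts each activation from a module-level function into a torch.nn.Module subclass, while get_activation keeps returning a ready-to-call object. A minimal usage sketch (not part of the commit, assuming the transformers.activations module shown below):

import torch

from transformers.activations import get_activation

# get_activation now hands back an nn.Module instance, e.g. FastGELUActivation()
act = get_activation("gelu_fast")
x = torch.randn(2, 4)
y = act(x)  # modules are callable, so existing call sites keep working
print(type(act).__name__, y.shape)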
@@ -16,7 +16,7 @@ import math
 import torch
 from packaging import version
-from torch import nn
+from torch import Tensor, nn
 
 from .utils import logging
@@ -24,39 +24,66 @@ from .utils import logging
 logger = logging.get_logger(__name__)
 
 
-def gelu_python(x):
+class NewGELUActivation(nn.Module):
+    """
+    Implementation of the GELU activation function currently in Google BERT repo (identical to OpenAI GPT). Also see
+    the Gaussian Error Linear Units paper: https://arxiv.org/abs/1606.08415
+    """
+
+    def __init__(self):
+        super().__init__()
+
+    def forward(self, input: Tensor) -> Tensor:
+        return 0.5 * input * (1.0 + torch.tanh(math.sqrt(2.0 / math.pi) * (input + 0.044715 * torch.pow(input, 3.0))))
+
+
+class GELUActivation(nn.Module):
     """
     Original Implementation of the GELU activation function in Google BERT repo when initially created. For
     information: OpenAI GPT's GELU is slightly different (and gives slightly different results): 0.5 * x * (1 +
     torch.tanh(math.sqrt(2 / math.pi) * (x + 0.044715 * torch.pow(x, 3)))) This is now written in C in nn.functional
     Also see the Gaussian Error Linear Units paper: https://arxiv.org/abs/1606.08415
     """
-    return x * 0.5 * (1.0 + torch.erf(x / math.sqrt(2.0)))
-
-
-def gelu_new(x):
-    """
-    Implementation of the GELU activation function currently in Google BERT repo (identical to OpenAI GPT). Also see
-    the Gaussian Error Linear Units paper: https://arxiv.org/abs/1606.08415
-    """
-    return 0.5 * x * (1.0 + torch.tanh(math.sqrt(2.0 / math.pi) * (x + 0.044715 * torch.pow(x, 3.0))))
-
-
-if version.parse(torch.__version__) < version.parse("1.4"):
-    gelu = gelu_python
-else:
-    gelu = nn.functional.gelu
+
+    def __init__(self, use_gelu_python: bool = False):
+        super().__init__()
+        if version.parse(torch.__version__) < version.parse("1.4") or use_gelu_python:
+            self.act = self._gelu_python
+        else:
+            self.act = nn.functional.gelu
+
+    def _gelu_python(self, input: Tensor) -> Tensor:
+        return input * 0.5 * (1.0 + torch.erf(input / math.sqrt(2.0)))
+
+    def forward(self, input: Tensor) -> Tensor:
+        return self.act(input)
 
 
-def gelu_fast(x):
-    return 0.5 * x * (1.0 + torch.tanh(x * 0.7978845608 * (1.0 + 0.044715 * x * x)))
+class FastGELUActivation(nn.Module):
+    """
+    Applies GELU approximation that is slower than QuickGELU but more accurate. See: https://github.com/hendrycks/GELUs
+    """
+
+    def __init__(self):
+        super().__init__()
+
+    def forward(self, input: Tensor) -> Tensor:
+        return 0.5 * input * (1.0 + torch.tanh(input * 0.7978845608 * (1.0 + 0.044715 * input * input)))
 
 
-def quick_gelu(x):
-    return x * torch.sigmoid(1.702 * x)
+class QuickGELUActivation(nn.Module):
+    """
+    Applies GELU approximation that is fast but somewhat inaccurate. See: https://github.com/hendrycks/GELUs
+    """
+
+    def __init__(self):
+        super().__init__()
+
+    def forward(self, input: Tensor) -> Tensor:
+        return input * torch.sigmoid(1.702 * input)
 
 
-def _silu_python(x):
+class SiLUActivation(nn.Module):
     """
     See Gaussian Error Linear Units (Hendrycks et al., https://arxiv.org/abs/1606.08415) where the SiLU (Sigmoid Linear
     Unit) was originally introduced and coined, and see Sigmoid-Weighted Linear Units for Neural Network Function
@@ -64,46 +91,65 @@ def _silu_python(x):
     Activation Function (Ramachandran et al., https://arxiv.org/abs/1710.05941v1) where the SiLU was experimented with
     later.
     """
-    return x * torch.sigmoid(x)
 
+    def __init__(self):
+        if version.parse(torch.__version__) < version.parse("1.7"):
+            self.act = self._silu_python
+        else:
+            self.act = nn.functional.silu
 
-if version.parse(torch.__version__) < version.parse("1.7"):
-    silu = _silu_python
-else:
-    silu = nn.functional.silu
+    def _silu_python(self, input: Tensor) -> Tensor:
+        return input * torch.sigmoid(input)
 
+    def forward(self, input: Tensor) -> Tensor:
+        return self.act(input)
 
-def _mish_python(x):
+
+class MishActivation(nn.Module):
     """
     See Mish: A Self-Regularized Non-Monotonic Activation Function (Misra., https://arxiv.org/abs/1908.08681). Also
     visit the official repository for the paper: https://github.com/digantamisra98/Mish
     """
-    return x * torch.tanh(nn.functional.softplus(x))
 
+    def __init__(self):
+        super().__init__()
+        if version.parse(torch.__version__) < version.parse("1.9"):
+            self.act = self._mish_python
+        else:
+            self.act = nn.functional.mish
 
-if version.parse(torch.__version__) < version.parse("1.9"):
-    mish = _mish_python
-else:
-    mish = nn.functional.mish
+    def _mish_python(self, input: Tensor) -> Tensor:
+        return input * torch.tanh(nn.functional.softplus(input))
 
+    def forward(self, input: Tensor) -> Tensor:
+        return self.act(input)
 
-def linear_act(x):
-    return x
+
+class LinearActivation(nn.Module):
+    """
+    Applies the linear activation function, i.e. forwarding input directly to output.
+    """
+
+    def __init__(self):
+        super().__init__()
+
+    def forward(self, input: Tensor) -> Tensor:
+        return input
 
 
 ACT2FN = {
-    "relu": nn.functional.relu,
-    "silu": silu,
-    "swish": silu,
-    "gelu": gelu,
-    "tanh": torch.tanh,
-    "gelu_python": gelu_python,
-    "gelu_new": gelu_new,
-    "gelu_fast": gelu_fast,
-    "quick_gelu": quick_gelu,
-    "mish": mish,
-    "linear": linear_act,
-    "sigmoid": torch.sigmoid,
+    "relu": nn.ReLU(),
+    "silu": SiLUActivation(),
+    "swish": SiLUActivation(),
+    "gelu": GELUActivation(),
+    "tanh": nn.Tanh(),
+    "gelu_python": GELUActivation(use_gelu_python=True),
+    "gelu_new": NewGELUActivation(),
+    "gelu_fast": FastGELUActivation(),
+    "quick_gelu": QuickGELUActivation(),
+    "mish": MishActivation(),
+    "linear": LinearActivation(),
+    "sigmoid": nn.Sigmoid(),
 }
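
For a sense of the accuracy trade-off the docstrings above describe (QuickGELU fast but approximate, FastGELU and NewGELU closer to the exact erf-based GELU), a small comparison sketch, hedged and not part of the commit, using the classes added above:

import torch

from transformers.activations import get_activation

x = torch.linspace(-4.0, 4.0, steps=1000)
exact = get_activation("gelu")(x)  # erf-based GELU
for name in ("gelu_new", "gelu_fast", "quick_gelu"):
    approx = get_activation(name)(x)
    # printing the maximum absolute deviation from the exact GELU;
    # the tanh-based variants track it much more closely than quick_gelu
    print(name, (approx - exact).abs().max().item())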
@@ -112,3 +158,14 @@ def get_activation(activation_string):
         return ACT2FN[activation_string]
     else:
         raise KeyError(f"function {activation_string} not found in ACT2FN mapping {list(ACT2FN.keys())}")
+
+
+# For backwards compatibility with: from activations import gelu_python
+gelu_python = get_activation("gelu_python")
+gelu_new = get_activation("gelu_new")
+gelu = get_activation("gelu")
+gelu_fast = get_activation("gelu_fast")
+quick_gelu = get_activation("quick_gelu")
+silu = get_activation("silu")
+mish = get_activation("mish")
+linear_act = get_activation("linear")
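
The aliases above keep old imports such as `from activations import gelu_python` working; a short, hedged example of what callers see after this change (the names now resolve to nn.Module instances rather than plain functions):

import torch

from transformers.activations import gelu, gelu_python

t = torch.tensor([-1.0, 0.0, 1.0])
print(gelu(t))         # GELUActivation(): dispatches to nn.functional.gelu on torch >= 1.4
print(gelu_python(t))  # GELUActivation(use_gelu_python=True): pure-Python erf formulation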
@@ -40,6 +40,10 @@ class TestActivations(unittest.TestCase):
         get_activation("gelu_new")
         get_activation("gelu_fast")
         get_activation("gelu_python")
+        get_activation("quick_gelu")
+        get_activation("mish")
+        get_activation("linear")
+        get_activation("sigmoid")
         with self.assertRaises(KeyError):
             get_activation("bogus")
         with self.assertRaises(KeyError):
...
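
The test hunk above only checks that the new ACT2FN keys resolve. A hypothetical follow-up test, sketched here and not part of this commit, could also assert that the pure-Python GELU matches the C implementation numerically:

import unittest

import torch

from transformers.activations import get_activation


class TestGELUEquivalence(unittest.TestCase):
    def test_gelu_python_matches_functional(self):
        x = torch.linspace(-3.0, 3.0, steps=200)
        gelu_python = get_activation("gelu_python")  # GELUActivation(use_gelu_python=True)
        gelu = get_activation("gelu")                # nn.functional.gelu on recent torch
        self.assertTrue(torch.allclose(gelu_python(x), gelu(x), atol=1e-6))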