Unverified commit 416dff73 authored by Sylvain Gugger, committed by GitHub

Fix SiluActivation (#15718)

parent e93763d4
@@ -30,9 +30,6 @@ class NewGELUActivation(nn.Module):
     the Gaussian Error Linear Units paper: https://arxiv.org/abs/1606.08415
     """
 
-    def __init__(self):
-        super().__init__()
-
     def forward(self, input: Tensor) -> Tensor:
         return 0.5 * input * (1.0 + torch.tanh(math.sqrt(2.0 / math.pi) * (input + 0.044715 * torch.pow(input, 3.0))))
 
@@ -64,9 +61,6 @@ class FastGELUActivation(nn.Module):
     Applies GELU approximation that is slower than QuickGELU but more accurate. See: https://github.com/hendrycks/GELUs
     """
 
-    def __init__(self):
-        super().__init__()
-
     def forward(self, input: Tensor) -> Tensor:
         return 0.5 * input * (1.0 + torch.tanh(input * 0.7978845608 * (1.0 + 0.044715 * input * input)))
 
@@ -76,9 +70,6 @@ class QuickGELUActivation(nn.Module):
     Applies GELU approximation that is fast but somewhat inaccurate. See: https://github.com/hendrycks/GELUs
     """
 
-    def __init__(self):
-        super().__init__()
-
    def forward(self, input: Tensor) -> Tensor:
         return input * torch.sigmoid(1.702 * input)
 
@@ -93,6 +84,7 @@ class SiLUActivation(nn.Module):
     """
 
     def __init__(self):
+        super().__init__()
         if version.parse(torch.__version__) < version.parse("1.7"):
             self.act = self._silu_python
         else:
@@ -130,9 +122,6 @@ class LinearActivation(nn.Module):
     Applies the linear activation function, i.e. forwarding input directly to output.
     """
 
-    def __init__(self):
-        super().__init__()
-
     def forward(self, input: Tensor) -> Tensor:
         return input
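What the patch boils down to: the no-op `__init__` methods (which only called `super().__init__()`) are deleted from `NewGELUActivation`, `FastGELUActivation`, `QuickGELUActivation` and `LinearActivation`, while `SiLUActivation` keeps a real `__init__` that picks between a Python fallback and the native silu by torch version, so it must call `super().__init__()` before anything else; otherwise `nn.Module` never sets up its internal state and the module breaks when used. Below is a minimal sketch of that pattern, assuming torch >= 1.7 and the `packaging` package are available; `PatchedSiLU`, the else-branch body, and `_silu_python` are illustrative reconstructions, not the library class itself.

```python
# Minimal sketch of the pattern this commit fixes, not the library class itself.
import torch
from packaging import version
from torch import Tensor, nn


class PatchedSiLU(nn.Module):
    def __init__(self):
        super().__init__()  # the call this commit adds: nn.Module must be initialized first
        if version.parse(torch.__version__) < version.parse("1.7"):
            # older torch has no native silu, fall back to the Python definition
            self.act = self._silu_python
        else:
            self.act = nn.functional.silu  # assumed else-branch (truncated in the diff above)

    def _silu_python(self, input: Tensor) -> Tensor:
        # standard SiLU / swish definition: x * sigmoid(x)
        return input * torch.sigmoid(input)

    def forward(self, input: Tensor) -> Tensor:
        return self.act(input)


x = torch.randn(4)
assert torch.allclose(PatchedSiLU()(x), nn.functional.silu(x))
```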
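A side note on the docstrings above: QuickGELU trades accuracy for speed by replacing the erf/tanh-based GELU with a single sigmoid. A quick, illustrative comparison against torch's exact GELU (not part of the commit, just a sanity check one can run):

```python
# Illustrative only: compare QuickGELU's sigmoid approximation with torch's exact GELU.
import torch

x = torch.linspace(-5.0, 5.0, steps=1001)
quick = x * torch.sigmoid(1.702 * x)  # QuickGELUActivation.forward
exact = torch.nn.functional.gelu(x)   # erf-based GELU
print((quick - exact).abs().max())    # small but nonzero, hence "somewhat inaccurate"
```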