Unverified Commit 8735815a authored by ftbabi, committed by GitHub

Add GELU activation function (#843)

* Support GELU activation function

* Support different torch version

* Fix bug in importing TORCH_VERSION

* Wrap GELU module for PyTorch 1.3.x

parent b2ffab13
@@ -20,7 +20,7 @@ layer = build_conv_layer(cfg, in_channels=3, out_channels=8, kernel_size=3)
 - `build_conv_layer`: Supported types are Conv1d, Conv2d, Conv3d, Conv (alias for Conv2d).
 - `build_norm_layer`: Supported types are BN1d, BN2d, BN3d, BN (alias for BN2d), SyncBN, GN, LN, IN1d, IN2d, IN3d, IN (alias for IN2d).
-- `build_activation_layer`: Supported types are ReLU, LeakyReLU, PReLU, RReLU, ReLU6, ELU, Sigmoid, Tanh.
+- `build_activation_layer`: Supported types are ReLU, LeakyReLU, PReLU, RReLU, ReLU6, ELU, Sigmoid, Tanh, GELU.
 - `build_upsample_layer`: Supported types are nearest, bilinear, deconv, pixel_shuffle.
 - `build_padding_layer`: Supported types are zero, reflect, replicate.
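The builders above all follow mmcv's registry pattern: a config dict whose `type` key selects a registered class, with the remaining keys passed to its constructor. As a rough illustration, here is a minimal stand-in; the `Registry` class and `ReLU` stub below are simplified placeholders, not mmcv's actual implementation.

```python
class Registry:
    """Simplified stand-in for mmcv's Registry (illustration only)."""

    def __init__(self, name):
        self.name = name
        self._module_dict = {}

    def register_module(self, module=None):
        # Register under the class name, mirroring mmcv's default behavior.
        self._module_dict[module.__name__] = module
        return module

    def get(self, key):
        return self._module_dict[key]


ACTIVATION_LAYERS = Registry('activation layer')


class ReLU:
    """Placeholder activation standing in for torch.nn.ReLU."""

    def __init__(self, inplace=False):
        self.inplace = inplace


ACTIVATION_LAYERS.register_module(module=ReLU)


def build_activation_layer(cfg):
    """Build an activation module from a cfg dict with a 'type' key."""
    cfg = dict(cfg)  # copy so the caller's config is not mutated
    cls = ACTIVATION_LAYERS.get(cfg.pop('type'))
    return cls(**cfg)  # remaining keys become constructor kwargs


act = build_activation_layer(dict(type='ReLU', inplace=True))
```

With this pattern, adding GELU support is just one more `register_module` call, which is exactly what the diff below does.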
 import torch
 import torch.nn as nn
 import torch.nn.functional as F
-from mmcv.utils import build_from_cfg
+from mmcv.utils import TORCH_VERSION, build_from_cfg
 from .registry import ACTIVATION_LAYERS

 for module in [
@@ -43,6 +44,38 @@ class Clamp(nn.Module):
         return torch.clamp(x, min=self.min, max=self.max)


+class GELU(nn.Module):
+    r"""Applies the Gaussian Error Linear Units function:
+
+    .. math::
+        \text{GELU}(x) = x * \Phi(x)
+
+    where :math:`\Phi(x)` is the Cumulative Distribution Function for
+    Gaussian Distribution.
+
+    Shape:
+        - Input: :math:`(N, *)` where `*` means, any number of additional
+          dimensions
+        - Output: :math:`(N, *)`, same shape as the input
+
+    .. image:: scripts/activation_images/GELU.png
+
+    Examples::
+
+        >>> m = nn.GELU()
+        >>> input = torch.randn(2)
+        >>> output = m(input)
+    """
+
+    def forward(self, input):
+        return F.gelu(input)
+
+
+if TORCH_VERSION == 'parrots' or TORCH_VERSION < '1.4':
+    ACTIVATION_LAYERS.register_module(module=GELU)
+else:
+    ACTIVATION_LAYERS.register_module(module=nn.GELU)
 def build_activation_layer(cfg):
     """Build activation layer.
......