Unverified Commit c3d8eb34 authored by Xiaojie Li, committed by GitHub

add Swish activation (#522)



* update impad

* fix docstring

* add shape for impad

* fix unit test

* remove old version & fix doc

* fix linting

* fix doc

* add linear decay learning rate scheduler

* fix impad

* fix setup.cfg

* fix linting

* add yapf

* add swish

* fix lr_updater

* fix lr_updater.py

* update swish

* add swish

* fix inplace

* fix typo
Co-authored-by: lixiaojie <lixiaojie@sensetime.com>
parent 66a38c86
mmcv/cnn/__init__.py
@@ -5,7 +5,7 @@ from .bricks import (ACTIVATION_LAYERS, CONV_LAYERS, NORM_LAYERS,
                      ContextBlock, ConvAWS2d, ConvModule, ConvWS2d,
                      DepthwiseSeparableConvModule, GeneralizedAttention,
                      HSigmoid, HSwish, NonLocal1d, NonLocal2d, NonLocal3d,
-                     Scale, build_activation_layer, build_conv_layer,
+                     Scale, Swish, build_activation_layer, build_conv_layer,
                      build_norm_layer, build_padding_layer, build_plugin_layer,
                      build_upsample_layer, conv_ws_2d, is_norm)
 from .resnet import ResNet, make_res_layer
@@ -21,7 +21,7 @@ __all__ = [
     'build_activation_layer', 'build_conv_layer', 'build_norm_layer',
     'build_padding_layer', 'build_upsample_layer', 'build_plugin_layer',
     'is_norm', 'NonLocal1d', 'NonLocal2d', 'NonLocal3d', 'ContextBlock',
-    'HSigmoid', 'HSwish', 'GeneralizedAttention', 'ACTIVATION_LAYERS',
+    'HSigmoid', 'Swish', 'HSwish', 'GeneralizedAttention', 'ACTIVATION_LAYERS',
     'CONV_LAYERS', 'NORM_LAYERS', 'PADDING_LAYERS', 'UPSAMPLE_LAYERS',
     'PLUGIN_LAYERS', 'Scale', 'get_model_complexity_info', 'conv_ws_2d',
     'ConvAWS2d', 'ConvWS2d', 'fuse_conv_bn', 'DepthwiseSeparableConvModule'
mmcv/cnn/bricks/__init__.py
@@ -14,6 +14,7 @@ from .plugin import build_plugin_layer
 from .registry import (ACTIVATION_LAYERS, CONV_LAYERS, NORM_LAYERS,
                        PADDING_LAYERS, PLUGIN_LAYERS, UPSAMPLE_LAYERS)
 from .scale import Scale
+from .swish import Swish
 from .upsample import build_upsample_layer

 __all__ = [
@@ -23,5 +24,5 @@ __all__ = [
     'NonLocal2d', 'NonLocal3d', 'ContextBlock', 'GeneralizedAttention',
     'ACTIVATION_LAYERS', 'CONV_LAYERS', 'NORM_LAYERS', 'PADDING_LAYERS',
     'UPSAMPLE_LAYERS', 'PLUGIN_LAYERS', 'Scale', 'ConvAWS2d', 'ConvWS2d',
-    'conv_ws_2d', 'DepthwiseSeparableConvModule'
+    'conv_ws_2d', 'DepthwiseSeparableConvModule', 'Swish'
 ]
mmcv/cnn/bricks/conv_module.py
@@ -145,7 +145,7 @@ class ConvModule(nn.Module):
             act_cfg_ = act_cfg.copy()
             # nn.Tanh has no 'inplace' argument
             if act_cfg_['type'] not in [
-                    'Tanh', 'PReLU', 'Sigmoid', 'HSigmoid'
+                    'Tanh', 'PReLU', 'Sigmoid', 'HSigmoid', 'Swish'
             ]:
                 act_cfg_.setdefault('inplace', inplace)
             self.activate = build_activation_layer(act_cfg_)
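With 'Swish' added to the exclusion list above, ConvModule skips the inplace default for it, since the new activation takes no inplace argument. A minimal usage sketch, assuming ConvModule's standard (in_channels, out_channels, kernel_size) positional signature, which this diff does not show:

import torch
from mmcv.cnn import ConvModule

# act_cfg is forwarded to build_activation_layer; because 'Swish' is in the
# exclusion list, no 'inplace' key is injected into the config.
conv = ConvModule(3, 8, 3, padding=1, act_cfg=dict(type='Swish'))
out = conv(torch.randn(1, 3, 32, 32))
assert out.shape == (1, 8, 32, 32)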
mmcv/cnn/bricks/swish.py (new file)

import torch
import torch.nn as nn

from .registry import ACTIVATION_LAYERS


@ACTIVATION_LAYERS.register_module()
class Swish(nn.Module):
    """Swish Module.

    This module applies the swish function:

    .. math::
        Swish(x) = x * Sigmoid(x)

    Returns:
        Tensor: The output tensor.
    """

    def __init__(self):
        super(Swish, self).__init__()

    def forward(self, x):
        return x * torch.sigmoid(x)
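Because the class is decorated with ACTIVATION_LAYERS.register_module(), it can also be built from a config dict like the other activation bricks. A short sketch, relying on the build_activation_layer export shown in the mmcv/cnn/__init__.py hunk above:

import torch
from mmcv.cnn import build_activation_layer

# 'Swish' is resolved through the ACTIVATION_LAYERS registry.
act = build_activation_layer(dict(type='Swish'))

x = torch.randn(2, 4)
assert torch.allclose(act(x), x * torch.sigmoid(x))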
test_swish.py (new file)

import torch
from torch.nn.functional import sigmoid

from mmcv.cnn.bricks import Swish


def test_swish():
    act = Swish()
    input = torch.randn(1, 3, 64, 64)
    expected_output = input * sigmoid(input)
    output = act(input)
    # test output shape
    assert output.shape == expected_output.shape
    # test output value
    assert torch.equal(output, expected_output)