Unverified Commit fef7b2fe authored by louzana's avatar louzana Committed by GitHub
Browse files

add hswish and hsigmoid (#358)

* add hswish and hsigmoid

* fix linting

* fix linting

* delete useless staticmethod in cnn/bricks/hsigmoid.py

* forward(x) ==> forward(self, x)

* add unittests for hsigmoid and hswish

* use torch.equal instead of numpy.equal
parent e3d9bf42
from .activation import build_activation_layer
from .conv import build_conv_layer
from .conv_module import ConvModule
from .hsigmoid import HSigmoid
from .hswish import HSwish
from .non_local import NonLocal1d, NonLocal2d, NonLocal3d
from .norm import build_norm_layer, is_norm
from .padding import build_padding_layer
__all__ = [
    'ConvModule', 'build_activation_layer', 'build_conv_layer',
    'build_norm_layer', 'build_padding_layer', 'build_upsample_layer',
    'is_norm', 'HSigmoid', 'HSwish', 'NonLocal1d', 'NonLocal2d', 'NonLocal3d',
    'ACTIVATION_LAYERS', 'CONV_LAYERS', 'NORM_LAYERS', 'PADDING_LAYERS',
    'UPSAMPLE_LAYERS', 'Scale'
]
import torch.nn as nn
from .registry import ACTIVATION_LAYERS
@ACTIVATION_LAYERS.register_module()
class HSigmoid(nn.Module):
    """Hard Sigmoid Module.

    Applies the (generalized) hard sigmoid function:

        HSigmoid(x) = min(max((x + bias) / divisor, min_value), max_value)

    The defaults reproduce the original behavior,
    ``min(max((x + 1) / 2, 0), 1)``.

    Args:
        bias (float): Bias added to the input. Default: 1.0.
        divisor (float): Divisor applied after the bias; must be non-zero.
            Default: 2.0.
        min_value (float): Lower bound of the output. Default: 0.0.
        max_value (float): Upper bound of the output. Default: 1.0.

    Returns:
        Tensor: The output tensor, same shape as the input.
    """

    def __init__(self, bias=1.0, divisor=2.0, min_value=0.0, max_value=1.0):
        super(HSigmoid, self).__init__()
        self.bias = bias
        self.divisor = divisor
        # A zero divisor would make forward() divide by zero; fail early.
        assert self.divisor != 0, 'divisor must not be zero'
        self.min_value = min_value
        self.max_value = max_value

    def forward(self, x):
        x = (x + self.bias) / self.divisor
        # clamp_ is safe here: x is a fresh tensor created by the line above,
        # so the in-place op cannot clobber the caller's input.
        return x.clamp_(self.min_value, self.max_value)
import torch.nn as nn
from .registry import ACTIVATION_LAYERS
@ACTIVATION_LAYERS.register_module()
class HSwish(nn.Module):
    """Hard Swish activation.

    Computes ``HSwish(x) = x * ReLU6(x + 3) / 6``.

    Args:
        inplace (bool): Whether the internal ReLU6 runs in-place.
            Default: False.

    Returns:
        Tensor: The output tensor, same shape as the input.
    """

    def __init__(self, inplace=False):
        super(HSwish, self).__init__()
        # Expose the ReLU6 as a submodule so its `inplace` flag is inspectable.
        self.act = nn.ReLU6(inplace)

    def forward(self, x):
        gate = self.act(x + 3)
        return x * gate / 6
import torch
from mmcv.cnn.bricks import HSigmoid
def test_hsigmoid():
    """HSigmoid output must equal min(max((x + 1) / 2, 0), 1)."""
    act = HSigmoid()
    inputs = torch.randn(1, 3, 64, 64)
    outputs = act(inputs)
    # Build the reference value from the definition of hard sigmoid.
    zeros = torch.zeros_like(inputs)
    ones = torch.ones_like(inputs)
    reference = torch.min(torch.max((inputs + 1) / 2, zeros), ones)
    # shape check
    assert outputs.shape == reference.shape
    # exact value check
    assert torch.equal(outputs, reference)
import torch
from torch.nn.functional import relu6
from mmcv.cnn.bricks import HSwish
def test_hswish():
    """HSwish output must equal x * relu6(x + 3) / 6; inplace flag forwards."""
    # the inplace flag should propagate to the wrapped ReLU6
    act = HSwish(inplace=True)
    assert act.act.inplace
    act = HSwish()
    assert not act.act.inplace

    data = torch.randn(1, 3, 64, 64)
    reference = data * relu6(data + 3) / 6
    result = act(data)
    # shape check
    assert result.shape == reference.shape
    # exact value check
    assert torch.equal(result, reference)
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment