Unverified Commit 864880de authored by Kai Chen's avatar Kai Chen Committed by GitHub
Browse files

allow arbitrary layer order for ConvModule (#1078)

parent 36b6577e
......@@ -42,7 +42,7 @@ def build_conv_layer(cfg, *args, **kwargs):
class ConvModule(nn.Module):
"""Conv-Norm-Activation block.
"""A conv block that contains conv/norm/activation layers.
Args:
in_channels (int): Same as nn.Conv2d.
......@@ -59,9 +59,9 @@ class ConvModule(nn.Module):
norm_cfg (dict): Config dict for normalization layer.
activation (str or None): Activation type, "ReLU" by default.
inplace (bool): Whether to use inplace mode for activation.
activate_last (bool): Whether to apply the activation layer in the
last. (Do not use this flag since the behavior and api may be
changed in the future.)
order (tuple[str]): The order of conv/norm/activation layers. It is a
sequence of "conv", "norm" and "act". Examples are
("conv", "norm", "act") and ("act", "conv", "norm").
"""
def __init__(self,
......@@ -77,7 +77,7 @@ class ConvModule(nn.Module):
norm_cfg=None,
activation='relu',
inplace=True,
activate_last=True):
order=('conv', 'norm', 'act')):
super(ConvModule, self).__init__()
assert conv_cfg is None or isinstance(conv_cfg, dict)
assert norm_cfg is None or isinstance(norm_cfg, dict)
......@@ -85,7 +85,9 @@ class ConvModule(nn.Module):
self.norm_cfg = norm_cfg
self.activation = activation
self.inplace = inplace
self.activate_last = activate_last
self.order = order
assert isinstance(self.order, tuple) and len(self.order) == 3
assert set(order) == set(['conv', 'norm', 'act'])
self.with_norm = norm_cfg is not None
self.with_activatation = activation is not None
......@@ -121,12 +123,17 @@ class ConvModule(nn.Module):
# build normalization layers
if self.with_norm:
norm_channels = out_channels if self.activate_last else in_channels
# norm layer is after conv layer
if order.index('norm') > order.index('conv'):
norm_channels = out_channels
else:
norm_channels = in_channels
self.norm_name, norm = build_norm_layer(norm_cfg, norm_channels)
self.add_module(self.norm_name, norm)
# build activation layer
if self.with_activatation:
# TODO: introduce `act_cfg` and supports more activation layers
if self.activation not in ['relu']:
raise ValueError('{} is currently not supported.'.format(
self.activation))
......@@ -147,17 +154,11 @@ class ConvModule(nn.Module):
constant_init(self.norm, 1, bias=0)
def forward(self, x, activate=True, norm=True):
if self.activate_last:
x = self.conv(x)
if norm and self.with_norm:
for layer in self.order:
if layer == 'conv':
x = self.conv(x)
elif layer == 'norm' and norm and self.with_norm:
x = self.norm(x)
if activate and self.with_activatation:
elif layer == 'act' and activate and self.with_activatation:
x = self.activate(x)
else:
# WARN: this may be removed or modified
if norm and self.with_norm:
x = self.norm(x)
if activate and self.with_activatation:
x = self.activate(x)
x = self.conv(x)
return x
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment