Unverified Commit 864880de authored by Kai Chen, committed by GitHub

allow arbitrary layer order for ConvModule (#1078)

parent 36b6577e
@@ -42,7 +42,7 @@ def build_conv_layer(cfg, *args, **kwargs):
 
 
 class ConvModule(nn.Module):
-    """Conv-Norm-Activation block.
+    """A conv block that contains conv/norm/activation layers.
 
     Args:
         in_channels (int): Same as nn.Conv2d.
@@ -59,9 +59,9 @@ class ConvModule(nn.Module):
         norm_cfg (dict): Config dict for normalization layer.
         activation (str or None): Activation type, "ReLU" by default.
         inplace (bool): Whether to use inplace mode for activation.
-        activate_last (bool): Whether to apply the activation layer in the
-            last. (Do not use this flag since the behavior and api may be
-            changed in the future.)
+        order (tuple[str]): The order of conv/norm/activation layers. It is a
+            sequence of "conv", "norm" and "act". Examples are
+            ("conv", "norm", "act") and ("act", "conv", "norm").
     """
 
     def __init__(self,
@@ -77,7 +77,7 @@ class ConvModule(nn.Module):
                  norm_cfg=None,
                  activation='relu',
                  inplace=True,
-                 activate_last=True):
+                 order=('conv', 'norm', 'act')):
         super(ConvModule, self).__init__()
         assert conv_cfg is None or isinstance(conv_cfg, dict)
         assert norm_cfg is None or isinstance(norm_cfg, dict)
@@ -85,7 +85,9 @@ class ConvModule(nn.Module):
         self.norm_cfg = norm_cfg
         self.activation = activation
         self.inplace = inplace
-        self.activate_last = activate_last
+        self.order = order
+        assert isinstance(self.order, tuple) and len(self.order) == 3
+        assert set(order) == set(['conv', 'norm', 'act'])
 
         self.with_norm = norm_cfg is not None
         self.with_activatation = activation is not None
@@ -121,12 +123,17 @@ class ConvModule(nn.Module):
 
         # build normalization layers
         if self.with_norm:
-            norm_channels = out_channels if self.activate_last else in_channels
+            # norm layer is after conv layer
+            if order.index('norm') > order.index('conv'):
+                norm_channels = out_channels
+            else:
+                norm_channels = in_channels
             self.norm_name, norm = build_norm_layer(norm_cfg, norm_channels)
             self.add_module(self.norm_name, norm)
 
         # build activation layer
         if self.with_activatation:
+            # TODO: introduce `act_cfg` and supports more activation layers
            if self.activation not in ['relu']:
                 raise ValueError('{} is currently not supported.'.format(
                     self.activation))
@@ -147,17 +154,11 @@ class ConvModule(nn.Module):
             constant_init(self.norm, 1, bias=0)
 
     def forward(self, x, activate=True, norm=True):
-        if self.activate_last:
-            x = self.conv(x)
-            if norm and self.with_norm:
-                x = self.norm(x)
-            if activate and self.with_activatation:
-                x = self.activate(x)
-        else:
-            # WARN: this may be removed or modified
-            if norm and self.with_norm:
-                x = self.norm(x)
-            if activate and self.with_activatation:
-                x = self.activate(x)
-            x = self.conv(x)
+        for layer in self.order:
+            if layer == 'conv':
+                x = self.conv(x)
+            elif layer == 'norm' and norm and self.with_norm:
+                x = self.norm(x)
+            elif layer == 'act' and activate and self.with_activatation:
+                x = self.activate(x)
         return x
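A minimal usage sketch of the new `order` argument (not part of the commit): the import path and the leading positional arguments (in_channels, out_channels, kernel_size) are assumptions based on the signature elided above, not shown in this diff.

    # Usage sketch for ConvModule with a custom layer order.
    # Assumed: ConvModule is importable from mmdet.models.utils and takes
    # (in_channels, out_channels, kernel_size, ...) like nn.Conv2d.
    import torch
    from mmdet.models.utils import ConvModule  # assumed import path

    # Default order: conv -> norm -> act; norm is built with out_channels (16).
    block = ConvModule(3, 16, 3, padding=1, norm_cfg=dict(type='BN'))

    # Norm-first order: norm -> conv -> act; since 'norm' precedes 'conv',
    # the norm layer is built with in_channels (3) instead of out_channels.
    norm_first = ConvModule(3, 16, 3, padding=1, norm_cfg=dict(type='BN'),
                            order=('norm', 'conv', 'act'))

    x = torch.rand(1, 3, 32, 32)
    y = block(x)       # conv, then BN over 16 channels, then ReLU
    z = norm_first(x)  # BN over 3 channels, then conv, then ReLU

The forward pass simply walks `self.order`, so any permutation of the three layer names is applied in the order given, and the `activate`/`norm` flags still allow skipping those layers per call.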