Unverified Commit 56e71a71 authored by Jerry Jiarui XU's avatar Jerry Jiarui XU Committed by GitHub
Browse files

Add Depthwise Seperable ConvModule (#477)

parent 530ae200
......@@ -3,11 +3,11 @@ from .alexnet import AlexNet
from .bricks import (ACTIVATION_LAYERS, CONV_LAYERS, NORM_LAYERS,
PADDING_LAYERS, PLUGIN_LAYERS, UPSAMPLE_LAYERS,
ContextBlock, ConvAWS2d, ConvModule, ConvWS2d,
GeneralizedAttention, HSigmoid, HSwish, NonLocal1d,
NonLocal2d, NonLocal3d, Scale, build_activation_layer,
build_conv_layer, build_norm_layer, build_padding_layer,
build_plugin_layer, build_upsample_layer, conv_ws_2d,
is_norm)
DepthwiseSeparableConvModule, GeneralizedAttention,
HSigmoid, HSwish, NonLocal1d, NonLocal2d, NonLocal3d,
Scale, build_activation_layer, build_conv_layer,
build_norm_layer, build_padding_layer, build_plugin_layer,
build_upsample_layer, conv_ws_2d, is_norm)
from .resnet import ResNet, make_res_layer
from .utils import (bias_init_with_prob, caffe2_xavier_init, constant_init,
fuse_conv_bn, get_model_complexity_info, kaiming_init,
......@@ -24,5 +24,5 @@ __all__ = [
'HSigmoid', 'HSwish', 'GeneralizedAttention', 'ACTIVATION_LAYERS',
'CONV_LAYERS', 'NORM_LAYERS', 'PADDING_LAYERS', 'UPSAMPLE_LAYERS',
'PLUGIN_LAYERS', 'Scale', 'get_model_complexity_info', 'conv_ws_2d',
'ConvAWS2d', 'ConvWS2d', 'fuse_conv_bn'
'ConvAWS2d', 'ConvWS2d', 'fuse_conv_bn', 'DepthwiseSeparableConvModule'
]
......@@ -3,6 +3,7 @@ from .context_block import ContextBlock
from .conv import build_conv_layer
from .conv_module import ConvModule
from .conv_ws import ConvAWS2d, ConvWS2d, conv_ws_2d
from .depthwise_separable_conv_module import DepthwiseSeparableConvModule
from .generalized_attention import GeneralizedAttention
from .hsigmoid import HSigmoid
from .hswish import HSwish
......@@ -22,5 +23,5 @@ __all__ = [
'NonLocal2d', 'NonLocal3d', 'ContextBlock', 'GeneralizedAttention',
'ACTIVATION_LAYERS', 'CONV_LAYERS', 'NORM_LAYERS', 'PADDING_LAYERS',
'UPSAMPLE_LAYERS', 'PLUGIN_LAYERS', 'Scale', 'ConvAWS2d', 'ConvWS2d',
'conv_ws_2d'
'conv_ws_2d', 'DepthwiseSeparableConvModule'
]
import torch.nn as nn
from .conv_module import ConvModule
class DepthwiseSeparableConvModule(nn.Module):
"""Depthwise separable convolution module.
See https://arxiv.org/pdf/1704.04861.pdf for details.
This module can replace a ConvModule with the conv block replaced by two
conv block: depthwise conv block and pointwise conv block. The depthwise
conv block contains depthwise-conv/norm/activation layers. The pointwise
conv block contains pointwise-conv/norm/activation layers. It should be
noted that there will be norm/activation layer in the depthwise conv block
if `norm_cfg` and `act_cfg` are specified.
Args:
in_channels (int): Same as nn.Conv2d.
out_channels (int): Same as nn.Conv2d.
kernel_size (int or tuple[int]): Same as nn.Conv2d.
stride (int or tuple[int]): Same as nn.Conv2d. Default: 1.
padding (int or tuple[int]): Same as nn.Conv2d. Default: 0.
dilation (int or tuple[int]): Same as nn.Conv2d. Default: 1.
norm_cfg (dict): Default norm config for both depthwise ConvModule and
pointwise ConvModule. Default: None.
act_cfg (dict): Default activation config for both depthwise ConvModule
and pointwise ConvModule. Default: dict(type='ReLU').
dw_norm_cfg (dict): Norm config of depthwise ConvModule. If it is
'default', it will be the same as `norm_cfg`. Default: 'default'.
dw_act_cfg (dict): Activation config of depthwise ConvModule. If it is
'default', it will be the same as `act_cfg`. Default: 'default'.
pw_norm_cfg (dict): Norm config of pointwise ConvModule. If it is
'default', it will be the same as `norm_cfg`. Default: 'default'.
pw_act_cfg (dict): Activation config of pointwise ConvModule. If it is
'default', it will be the same as `act_cfg`. Default: 'default'.
kwargs (optional): Other shared arguments for depthwise and pointwise
ConvModule. See ConvModule for ref.
"""
def __init__(self,
in_channels,
out_channels,
kernel_size,
stride=1,
padding=0,
dilation=1,
norm_cfg=None,
act_cfg=dict(type='ReLU'),
dw_norm_cfg='default',
dw_act_cfg='default',
pw_norm_cfg='default',
pw_act_cfg='default',
**kwargs):
super(DepthwiseSeparableConvModule, self).__init__()
assert 'groups' not in kwargs, 'groups should not be specified'
# if norm/activation config of depthwise/pointwise ConvModule is not
# specified, use default config.
dw_norm_cfg = dw_norm_cfg if dw_norm_cfg != 'default' else norm_cfg
dw_act_cfg = dw_act_cfg if dw_act_cfg != 'default' else act_cfg
pw_norm_cfg = pw_norm_cfg if pw_norm_cfg != 'default' else norm_cfg
pw_act_cfg = pw_act_cfg if pw_act_cfg != 'default' else act_cfg
# depthwise convolution
self.depthwise_conv = ConvModule(
in_channels,
in_channels,
kernel_size,
stride=stride,
padding=padding,
dilation=dilation,
groups=in_channels,
norm_cfg=dw_norm_cfg,
act_cfg=dw_act_cfg,
**kwargs)
self.pointwise_conv = ConvModule(
in_channels,
out_channels,
1,
norm_cfg=pw_norm_cfg,
act_cfg=pw_act_cfg,
**kwargs)
def forward(self, x):
x = self.depthwise_conv(x)
x = self.pointwise_conv(x)
return x
import pytest
import torch
import torch.nn as nn
from mmcv.cnn.bricks import DepthwiseSeparableConvModule
def test_depthwise_separable_conv():
with pytest.raises(AssertionError):
# conv_cfg must be a dict or None
DepthwiseSeparableConvModule(4, 8, 2, groups=2)
# test default config
conv = DepthwiseSeparableConvModule(3, 8, 2)
assert conv.depthwise_conv.conv.groups == 3
assert conv.pointwise_conv.conv.kernel_size == (1, 1)
assert not conv.depthwise_conv.with_norm
assert not conv.pointwise_conv.with_norm
assert conv.depthwise_conv.activate.__class__.__name__ == 'ReLU'
assert conv.pointwise_conv.activate.__class__.__name__ == 'ReLU'
x = torch.rand(1, 3, 256, 256)
output = conv(x)
assert output.shape == (1, 8, 255, 255)
# test dw_norm_cfg
conv = DepthwiseSeparableConvModule(3, 8, 2, dw_norm_cfg=dict(type='BN'))
assert conv.depthwise_conv.norm_name == 'bn'
assert not conv.pointwise_conv.with_norm
x = torch.rand(1, 3, 256, 256)
output = conv(x)
assert output.shape == (1, 8, 255, 255)
# test pw_norm_cfg
conv = DepthwiseSeparableConvModule(3, 8, 2, pw_norm_cfg=dict(type='BN'))
assert not conv.depthwise_conv.with_norm
assert conv.pointwise_conv.norm_name == 'bn'
x = torch.rand(1, 3, 256, 256)
output = conv(x)
assert output.shape == (1, 8, 255, 255)
# test norm_cfg
conv = DepthwiseSeparableConvModule(3, 8, 2, norm_cfg=dict(type='BN'))
assert conv.depthwise_conv.norm_name == 'bn'
assert conv.pointwise_conv.norm_name == 'bn'
x = torch.rand(1, 3, 256, 256)
output = conv(x)
assert output.shape == (1, 8, 255, 255)
# add test for ['norm', 'conv', 'act']
conv = DepthwiseSeparableConvModule(3, 8, 2, order=('norm', 'conv', 'act'))
x = torch.rand(1, 3, 256, 256)
output = conv(x)
assert output.shape == (1, 8, 255, 255)
conv = DepthwiseSeparableConvModule(
3, 8, 3, padding=1, with_spectral_norm=True)
assert hasattr(conv.depthwise_conv.conv, 'weight_orig')
assert hasattr(conv.pointwise_conv.conv, 'weight_orig')
output = conv(x)
assert output.shape == (1, 8, 256, 256)
conv = DepthwiseSeparableConvModule(
3, 8, 3, padding=1, padding_mode='reflect')
assert isinstance(conv.depthwise_conv.padding_layer, nn.ReflectionPad2d)
output = conv(x)
assert output.shape == (1, 8, 256, 256)
# test dw_act_cfg
conv = DepthwiseSeparableConvModule(
3, 8, 3, padding=1, dw_act_cfg=dict(type='LeakyReLU'))
assert conv.depthwise_conv.activate.__class__.__name__ == 'LeakyReLU'
assert conv.pointwise_conv.activate.__class__.__name__ == 'ReLU'
output = conv(x)
assert output.shape == (1, 8, 256, 256)
# test pw_act_cfg
conv = DepthwiseSeparableConvModule(
3, 8, 3, padding=1, pw_act_cfg=dict(type='LeakyReLU'))
assert conv.depthwise_conv.activate.__class__.__name__ == 'ReLU'
assert conv.pointwise_conv.activate.__class__.__name__ == 'LeakyReLU'
output = conv(x)
assert output.shape == (1, 8, 256, 256)
# test act_cfg
conv = DepthwiseSeparableConvModule(
3, 8, 3, padding=1, act_cfg=dict(type='LeakyReLU'))
assert conv.depthwise_conv.activate.__class__.__name__ == 'LeakyReLU'
assert conv.pointwise_conv.activate.__class__.__name__ == 'LeakyReLU'
output = conv(x)
assert output.shape == (1, 8, 256, 256)
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment