Unverified Commit acee61d7 authored by GT9505's avatar GT9505 Committed by GitHub
Browse files

register deconv in CONV_LAYERS (#582)

* register deconv in CONV_LAYERS

* use ConvTranspose2d implemented in MMCV

* remove repetitive register_module

* update

* add unittest for deconv
parent 34127b9f
......@@ -47,6 +47,8 @@ class Conv2d(nn.Conv2d):
return super().forward(x)
@CONV_LAYERS.register_module()
@CONV_LAYERS.register_module('deconv')
@UPSAMPLE_LAYERS.register_module('deconv', force=True)
class ConvTranspose2d(nn.ConvTranspose2d):
......
......@@ -49,6 +49,15 @@ def test_build_conv_layer():
assert layer.groups == kwargs['groups']
assert layer.dilation == (kwargs['dilation'], kwargs['dilation'])
cfg = dict(type='deconv')
layer = build_conv_layer(cfg, **kwargs)
assert isinstance(layer, nn.ConvTranspose2d)
assert layer.in_channels == kwargs['in_channels']
assert layer.out_channels == kwargs['out_channels']
assert layer.kernel_size == (kwargs['kernel_size'], kwargs['kernel_size'])
assert layer.groups == kwargs['groups']
assert layer.dilation == (kwargs['dilation'], kwargs['dilation'])
for type_name, module in CONV_LAYERS.module_dict.items():
cfg = dict(type=type_name)
layer = build_conv_layer(cfg, **kwargs)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment