Commit 8ccea202 authored by dreamerlin

add ConvTranspose3d

parent c390e327
mmcv/cnn/__init__.py
# Copyright (c) Open-MMLab. All rights reserved.
from .alexnet import AlexNet
# yapf: disable
from .bricks import (ACTIVATION_LAYERS, CONV_LAYERS, NORM_LAYERS,
                     PADDING_LAYERS, PLUGIN_LAYERS, UPSAMPLE_LAYERS,
                     ContextBlock, Conv2d, ConvAWS2d, ConvModule,
                     ConvTranspose2d, ConvTranspose3d, ConvWS2d,
                     DepthwiseSeparableConvModule, GeneralizedAttention,
                     HSigmoid, HSwish, Linear, MaxPool2d, NonLocal1d,
                     NonLocal2d, NonLocal3d, Scale, Swish,
                     build_activation_layer, build_conv_layer,
                     build_norm_layer, build_padding_layer, build_plugin_layer,
                     build_upsample_layer, conv_ws_2d, is_norm)
# yapf: enable
from .resnet import ResNet, make_res_layer
from .utils import (bias_init_with_prob, caffe2_xavier_init, constant_init,
                    fuse_conv_bn, get_model_complexity_info, kaiming_init,
@@ -26,5 +29,5 @@ __all__ = [
    'CONV_LAYERS', 'NORM_LAYERS', 'PADDING_LAYERS', 'UPSAMPLE_LAYERS',
    'PLUGIN_LAYERS', 'Scale', 'get_model_complexity_info', 'conv_ws_2d',
    'ConvAWS2d', 'ConvWS2d', 'fuse_conv_bn', 'DepthwiseSeparableConvModule',
    'Linear', 'Conv2d', 'ConvTranspose2d', 'MaxPool2d', 'ConvTranspose3d'
]
mmcv/cnn/bricks/__init__.py
@@ -17,7 +17,8 @@ from .registry import (ACTIVATION_LAYERS, CONV_LAYERS, NORM_LAYERS,
from .scale import Scale
from .swish import Swish
from .upsample import build_upsample_layer
from .wrappers import (Conv2d, ConvTranspose2d, ConvTranspose3d, Linear,
                       MaxPool2d)

__all__ = [
    'ConvModule', 'build_activation_layer', 'build_conv_layer',
@@ -27,5 +28,6 @@ __all__ = [
    'ACTIVATION_LAYERS', 'CONV_LAYERS', 'NORM_LAYERS', 'PADDING_LAYERS',
    'UPSAMPLE_LAYERS', 'PLUGIN_LAYERS', 'Scale', 'ConvAWS2d', 'ConvWS2d',
    'conv_ws_2d', 'DepthwiseSeparableConvModule', 'Swish', 'Linear',
    'Conv2dAdaptivePadding', 'Conv2d', 'ConvTranspose2d', 'MaxPool2d',
    'ConvTranspose3d'
]
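With both __init__.py files updated, the wrapper is importable straight from mmcv.cnn, and the 'deconv3d' registrations in wrappers.py below also make it buildable from a config dict. A minimal usage sketch, not part of the commit: the channel sizes are illustrative, and it assumes build_upsample_layer forwards the remaining cfg keys as keyword arguments, as it does for the existing 'deconv' entry.

import torch

from mmcv.cnn import ConvTranspose3d, build_upsample_layer

# drop-in replacement for nn.ConvTranspose3d that also tolerates empty inputs
layer = ConvTranspose3d(4, 2, kernel_size=3, stride=2)
out = layer(torch.randn(1, 4, 8, 8, 8))

# the same layer built from a config via the UPSAMPLE_LAYERS registry
layer = build_upsample_layer(
    dict(type='deconv3d', in_channels=4, out_channels=2, kernel_size=3,
         stride=2))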
mmcv/cnn/bricks/wrappers.py
@@ -78,7 +78,30 @@ class ConvTranspose2d(nn.ConvTranspose2d):
            else:
                return empty

        return super().forward(x)


@CONV_LAYERS.register_module()
@CONV_LAYERS.register_module('deconv3d')
@UPSAMPLE_LAYERS.register_module('deconv3d', force=True)
class ConvTranspose3d(nn.ConvTranspose3d):

    def forward(self, x):
        if x.numel() == 0 and obsolete_torch_version(TORCH_VERSION, (1, 4)):
            out_shape = [x.shape[0], self.out_channels]
            for i, k, p, s, d, op in zip(x.shape[-3:], self.kernel_size,
                                         self.padding, self.stride,
                                         self.dilation, self.output_padding):
                out_shape.append((i - 1) * s - 2 * p + (d * (k - 1) + 1) + op)
            empty = NewEmptyTensorOp.apply(x, out_shape)
            if self.training:
                # produce dummy gradient to avoid DDP warning.
                dummy = sum(x.view(-1)[0] for x in self.parameters()) * 0.0
                return empty + dummy
            else:
                return empty

        return super().forward(x)
class MaxPool2d(nn.MaxPool2d):
    ...
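The out_shape loop in ConvTranspose3d.forward applies the standard transposed-convolution size formula to each of the three spatial dims, D_out = (D_in - 1) * s - 2 * p + (d * (k - 1) + 1) + op, so the empty tensor it returns has exactly the shape a real forward pass would produce. A quick sanity check against plain PyTorch; the concrete numbers are illustrative:

import torch
import torch.nn as nn

# expected spatial size for i=10, k=3, s=2, p=1, d=2, op=1:
# (10 - 1) * 2 - 2 * 1 + (2 * (3 - 1) + 1) + 1 = 22
layer = nn.ConvTranspose3d(
    4, 2, kernel_size=3, stride=2, padding=1, dilation=2, output_padding=1)
out = layer(torch.randn(1, 4, 10, 10, 10))
assert out.shape == (1, 2, 22, 22, 22)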
tests/test_cnn/test_wrappers.py
@@ -5,7 +5,8 @@ from unittest.mock import patch
import torch
import torch.nn as nn

from mmcv.cnn.bricks import (Conv2d, ConvTranspose2d, ConvTranspose3d, Linear,
                             MaxPool2d)
@@ -105,6 +106,61 @@ def test_conv_transposed_2d():
        wrapper(x_empty)
@patch('torch.__version__', '1.1')
def test_conv_transposed_3d():
    test_cases = OrderedDict([('in_w', [10, 20]), ('in_h', [10, 20]),
                              ('in_t', [10, 20]), ('in_channel', [1, 3]),
                              ('out_channel', [1, 3]), ('kernel_size', [3, 5]),
                              ('stride', [1, 2]), ('padding', [0, 1]),
                              ('dilation', [1, 2])])

    for in_h, in_w, in_t, in_cha, out_cha, k, s, p, d in product(
            *list(test_cases.values())):
        # wrapper op with an empty (batch size 0) input
        x_empty = torch.randn(0, in_cha, in_t, in_h, in_w, requires_grad=True)
        # output padding must be smaller than either stride or dilation
        op = min(s, d) - 1
        torch.manual_seed(0)
        wrapper = ConvTranspose3d(
            in_cha,
            out_cha,
            k,
            stride=s,
            padding=p,
            dilation=d,
            output_padding=op)
        wrapper_out = wrapper(x_empty)

        # torch op with a non-empty input (batch size 3) as shape reference
        x_normal = torch.randn(3, in_cha, in_t, in_h, in_w)
        torch.manual_seed(0)
        ref = nn.ConvTranspose3d(
            in_cha,
            out_cha,
            k,
            stride=s,
            padding=p,
            dilation=d,
            output_padding=op)
        ref_out = ref(x_normal)

        assert wrapper_out.shape[0] == 0
        assert wrapper_out.shape[1:] == ref_out.shape[1:]
        wrapper_out.sum().backward()
        assert wrapper.weight.grad is not None
        assert wrapper.weight.grad.shape == wrapper.weight.shape

        assert torch.equal(wrapper(x_normal), ref_out)

    # eval mode
    x_empty = torch.randn(0, in_cha, in_t, in_h, in_w)
    wrapper = ConvTranspose3d(
        in_cha, out_cha, k, stride=s, padding=p, dilation=d, output_padding=op)
    wrapper.eval()
    wrapper(x_empty)
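The op = min(s, d) - 1 line above keeps every sampled combination legal: PyTorch only accepts an output_padding smaller than the stride or the dilation, and in current versions the check fires when the layer is applied rather than at construction. A standalone illustration, not part of the commit:

import torch
import torch.nn as nn

# valid: output_padding (1) is smaller than stride (2)
ok = nn.ConvTranspose3d(1, 1, 3, stride=2, output_padding=1)
ok(torch.randn(1, 1, 4, 4, 4))

# invalid: with stride=1 and dilation=1, no nonzero output_padding is allowed
bad = nn.ConvTranspose3d(1, 1, 3, stride=1, output_padding=1)
try:
    bad(torch.randn(1, 1, 4, 4, 4))
except RuntimeError as e:
    print(e)  # output padding must be smaller than either stride or dilation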
@patch('torch.__version__', '1.1')
def test_max_pool_2d():
    test_cases = OrderedDict([('in_w', [10, 20]), ('in_h', [10, 20]),
    ...