Commit e70dca8f authored by dreamerlin

add conv3d

parent 144e7567

@@ -3,7 +3,7 @@ from .alexnet import AlexNet
# yapf: disable
from .bricks import (ACTIVATION_LAYERS, CONV_LAYERS, NORM_LAYERS,
                     PADDING_LAYERS, PLUGIN_LAYERS, UPSAMPLE_LAYERS,
                     ContextBlock, Conv2d, ConvAWS2d, ConvModule,
                     ContextBlock, Conv2d, Conv3d, ConvAWS2d, ConvModule,
                     ConvTranspose2d, ConvTranspose3d, ConvWS2d,
                     DepthwiseSeparableConvModule, GeneralizedAttention,
                     HSigmoid, HSwish, Linear, MaxPool2d, MaxPool3d,

@@ -30,5 +30,5 @@ __all__ = [
    'PLUGIN_LAYERS', 'Scale', 'get_model_complexity_info', 'conv_ws_2d',
    'ConvAWS2d', 'ConvWS2d', 'fuse_conv_bn', 'DepthwiseSeparableConvModule',
    'Linear', 'Conv2d', 'ConvTranspose2d', 'MaxPool2d', 'ConvTranspose3d',
    'MaxPool3d'
    'MaxPool3d', 'Conv3d'
]

@@ -17,8 +17,8 @@ from .registry import (ACTIVATION_LAYERS, CONV_LAYERS, NORM_LAYERS,
from .scale import Scale
from .swish import Swish
from .upsample import build_upsample_layer
from .wrappers import (Conv2d, ConvTranspose2d, ConvTranspose3d, Linear,
                       MaxPool2d, MaxPool3d)
from .wrappers import (Conv2d, Conv3d, ConvTranspose2d, ConvTranspose3d,
                       Linear, MaxPool2d, MaxPool3d)

__all__ = [
    'ConvModule', 'build_activation_layer', 'build_conv_layer',

@@ -29,5 +29,5 @@ __all__ = [
    'UPSAMPLE_LAYERS', 'PLUGIN_LAYERS', 'Scale', 'ConvAWS2d', 'ConvWS2d',
    'conv_ws_2d', 'DepthwiseSeparableConvModule', 'Swish', 'Linear',
    'Conv2dAdaptivePadding', 'Conv2d', 'ConvTranspose2d', 'MaxPool2d',
    'ConvTranspose3d', 'MaxPool3d'
    'ConvTranspose3d', 'MaxPool3d', 'Conv3d'
]
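
With both __init__.py exports in place, the new wrapper is importable like the existing ones. A minimal usage sketch (not part of the diff; module and shapes are illustrative):

import torch
from mmcv.cnn.bricks import Conv3d

# On non-empty input the wrapper simply defers to nn.Conv3d.
conv = Conv3d(3, 8, kernel_size=3, padding=1)
out = conv(torch.randn(2, 3, 4, 16, 16))
assert out.shape == (2, 8, 4, 16, 16)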

@@ -58,6 +58,27 @@ class Conv2d(nn.Conv2d):
        return super().forward(x)


@CONV_LAYERS.register_module('Conv3d', force=True)
class Conv3d(nn.Conv3d):

    def forward(self, x):
        if x.numel() == 0 and obsolete_torch_version(TORCH_VERSION, (1, 4)):
            out_shape = [x.shape[0], self.out_channels]
            for i, k, p, s, d in zip(x.shape[-3:], self.kernel_size,
                                     self.padding, self.stride,
                                     self.dilation):
                o = (i + 2 * p - (d * (k - 1) + 1)) // s + 1
                out_shape.append(o)
            empty = NewEmptyTensorOp.apply(x, out_shape)
            if self.training:
                # produce a zero-valued dummy gradient so DDP does not
                # complain about parameters that received no gradient
                dummy = sum(p.view(-1)[0] for p in self.parameters()) * 0.0
                return empty + dummy
            else:
                return empty

        return super().forward(x)

@CONV_LAYERS.register_module()
@CONV_LAYERS.register_module('deconv')
@UPSAMPLE_LAYERS.register_module('deconv', force=True)
......
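
For reference, the out_shape loop above is the standard convolution output-size arithmetic applied once per spatial (t, h, w) dimension. A small sketch of the same formula (conv_out_size is a hypothetical helper; the numbers match the second test case below):

def conv_out_size(i, k, p, s, d):
    # o = (i + 2*p - (d*(k - 1) + 1)) // s + 1, as in Conv3d.forward above
    return (i + 2 * p - (d * (k - 1) + 1)) // s + 1

assert conv_out_size(20, 5, 1, 2, 2) == 7  # (20 + 2 - 9) // 2 + 1

Since the class is registered under the name 'Conv3d', it should also be constructible through the registry, e.g. build_conv_layer(dict(type='Conv3d'), 3, 8, 3) with mmcv.cnn.build_conv_layer.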

@@ -4,8 +4,8 @@ import pytest
import torch
import torch.nn as nn
from mmcv.cnn.bricks import (Conv2d, ConvTranspose2d, ConvTranspose3d, Linear,
                             MaxPool2d, MaxPool3d)
from mmcv.cnn.bricks import (Conv2d, Conv3d, ConvTranspose2d, ConvTranspose3d,
                             Linear, MaxPool2d, MaxPool3d)

@patch('torch.__version__', '1.1')

@@ -65,6 +65,64 @@ def test_conv2d(in_w, in_h, in_channel, out_channel, kernel_size, stride,
    wrapper(x_empty)

@patch('torch.__version__', '1.1')
@pytest.mark.parametrize(
    'in_w,in_h,in_t,in_channel,out_channel,kernel_size,stride,padding,dilation',  # noqa: E501
    [(10, 10, 10, 1, 1, 3, 1, 0, 1), (20, 20, 20, 3, 3, 5, 2, 1, 2)])
def test_conv3d(in_w, in_h, in_t, in_channel, out_channel, kernel_size, stride,
                padding, dilation):
    """
    CommandLine:
        xdoctest -m tests/test_wrappers.py test_conv3d
    """
    # train mode
    # wrapper op with an empty (zero-batch) input
    x_empty = torch.randn(0, in_channel, in_t, in_h, in_w)
    torch.manual_seed(0)
    wrapper = Conv3d(
        in_channel,
        out_channel,
        kernel_size,
        stride=stride,
        padding=padding,
        dilation=dilation)
    wrapper_out = wrapper(x_empty)

    # torch op with a non-empty input (batch size 3) as the shape reference
    x_normal = torch.randn(3, in_channel, in_t, in_h,
                           in_w).requires_grad_(True)
    torch.manual_seed(0)
    ref = nn.Conv3d(
        in_channel,
        out_channel,
        kernel_size,
        stride=stride,
        padding=padding,
        dilation=dilation)
    ref_out = ref(x_normal)

    assert wrapper_out.shape[0] == 0
    assert wrapper_out.shape[1:] == ref_out.shape[1:]

    wrapper_out.sum().backward()
    assert wrapper.weight.grad is not None
    assert wrapper.weight.grad.shape == wrapper.weight.shape

    assert torch.equal(wrapper(x_normal), ref_out)

    # eval mode
    x_empty = torch.randn(0, in_channel, in_t, in_h, in_w)
    wrapper = Conv3d(
        in_channel,
        out_channel,
        kernel_size,
        stride=stride,
        padding=padding,
        dilation=dilation)
    wrapper.eval()
    wrapper(x_empty)

@patch('torch.__version__', '1.1')
@pytest.mark.parametrize(
    'in_w,in_h,in_channel,out_channel,kernel_size,stride,padding,dilation',
......
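
The dummy-gradient branch means that a backward pass through an empty batch still reaches every parameter, which is what the weight.grad assertions above rely on. A rough sketch of that behaviour (assuming the patched torch < 1.4 path that the tests exercise):

import torch
from mmcv.cnn.bricks import Conv3d

conv = Conv3d(1, 2, kernel_size=3)         # modules default to train mode
out = conv(torch.randn(0, 1, 10, 10, 10))  # empty batch -> shape (0, 2, 8, 8, 8)
out.sum().backward()                       # zero-valued dummy term still populates grads
assert conv.weight.grad is not None
assert float(conv.weight.grad.abs().sum()) == 0.0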