Unverified commit dfa36dfe authored by Wenwei Zhang, committed by GitHub

Merge pull request #652 from dreamerlin/3d

[Feature] Add 3D support in wrapper
parents ec43b671 1a12ac75
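
The new wrappers mirror the existing 2D ones: they subclass the torch.nn ops and special-case empty (zero-batch) inputs, which PyTorch versions before 1.4 cannot shape-infer. A minimal usage sketch, assuming an mmcv build that includes this patch:

import torch
from mmcv.cnn.bricks import Conv3d

conv = Conv3d(3, 8, kernel_size=3, padding=1)
x = torch.randn(0, 3, 16, 16, 16)  # zero-sized batch, e.g. no sampled RoIs
y = conv(x)  # plain nn.Conv3d would fail here on PyTorch < 1.4
print(y.shape)  # torch.Size([0, 8, 16, 16, 16])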
mmcv/cnn/__init__.py
 # Copyright (c) Open-MMLab. All rights reserved.
 from .alexnet import AlexNet
 # yapf: disable
 from .bricks import (ACTIVATION_LAYERS, CONV_LAYERS, NORM_LAYERS,
                      PADDING_LAYERS, PLUGIN_LAYERS, UPSAMPLE_LAYERS,
-                     ContextBlock, Conv2d, ConvAWS2d, ConvModule,
-                     ConvTranspose2d, ConvWS2d, DepthwiseSeparableConvModule,
-                     GeneralizedAttention, HSigmoid, HSwish, Linear, MaxPool2d,
+                     ContextBlock, Conv2d, Conv3d, ConvAWS2d, ConvModule,
+                     ConvTranspose2d, ConvTranspose3d, ConvWS2d,
+                     DepthwiseSeparableConvModule, GeneralizedAttention,
+                     HSigmoid, HSwish, Linear, MaxPool2d, MaxPool3d,
                      NonLocal1d, NonLocal2d, NonLocal3d, Scale, Swish,
                      build_activation_layer, build_conv_layer,
                      build_norm_layer, build_padding_layer, build_plugin_layer,
                      build_upsample_layer, conv_ws_2d, is_norm)
 # yapf: enable
 from .resnet import ResNet, make_res_layer
 from .utils import (bias_init_with_prob, caffe2_xavier_init, constant_init,
                     fuse_conv_bn, get_model_complexity_info, kaiming_init,
@@ -26,5 +29,6 @@ __all__ = [
     'CONV_LAYERS', 'NORM_LAYERS', 'PADDING_LAYERS', 'UPSAMPLE_LAYERS',
     'PLUGIN_LAYERS', 'Scale', 'get_model_complexity_info', 'conv_ws_2d',
     'ConvAWS2d', 'ConvWS2d', 'fuse_conv_bn', 'DepthwiseSeparableConvModule',
-    'Linear', 'Conv2d', 'ConvTranspose2d', 'MaxPool2d'
+    'Linear', 'Conv2d', 'ConvTranspose2d', 'MaxPool2d', 'ConvTranspose3d',
+    'MaxPool3d', 'Conv3d'
 ]

mmcv/cnn/bricks/__init__.py
@@ -17,7 +17,8 @@ from .registry import (ACTIVATION_LAYERS, CONV_LAYERS, NORM_LAYERS,
 from .scale import Scale
 from .swish import Swish
 from .upsample import build_upsample_layer
-from .wrappers import Conv2d, ConvTranspose2d, Linear, MaxPool2d
+from .wrappers import (Conv2d, Conv3d, ConvTranspose2d, ConvTranspose3d,
+                       Linear, MaxPool2d, MaxPool3d)

 __all__ = [
     'ConvModule', 'build_activation_layer', 'build_conv_layer',
@@ -27,5 +28,6 @@ __all__ = [
     'ACTIVATION_LAYERS', 'CONV_LAYERS', 'NORM_LAYERS', 'PADDING_LAYERS',
     'UPSAMPLE_LAYERS', 'PLUGIN_LAYERS', 'Scale', 'ConvAWS2d', 'ConvWS2d',
     'conv_ws_2d', 'DepthwiseSeparableConvModule', 'Swish', 'Linear',
-    'Conv2dAdaptivePadding', 'Conv2d', 'ConvTranspose2d', 'MaxPool2d'
+    'Conv2dAdaptivePadding', 'Conv2d', 'ConvTranspose2d', 'MaxPool2d',
+    'ConvTranspose3d', 'MaxPool3d', 'Conv3d'
 ]

mmcv/cnn/bricks/wrappers.py
@@ -8,7 +8,7 @@ import math
 import torch
 import torch.nn as nn
-from torch.nn.modules.utils import _pair
+from torch.nn.modules.utils import _pair, _triple

 from .registry import CONV_LAYERS, UPSAMPLE_LAYERS
@@ -58,6 +58,27 @@ class Conv2d(nn.Conv2d):
         return super().forward(x)


+@CONV_LAYERS.register_module('Conv3d', force=True)
+class Conv3d(nn.Conv3d):
+
+    def forward(self, x):
+        if x.numel() == 0 and obsolete_torch_version(TORCH_VERSION, (1, 4)):
+            out_shape = [x.shape[0], self.out_channels]
+            for i, k, p, s, d in zip(x.shape[-3:], self.kernel_size,
+                                     self.padding, self.stride, self.dilation):
+                o = (i + 2 * p - (d * (k - 1) + 1)) // s + 1
+                out_shape.append(o)
+            empty = NewEmptyTensorOp.apply(x, out_shape)
+            if self.training:
+                # produce dummy gradient to avoid DDP warning.
+                dummy = sum(x.view(-1)[0] for x in self.parameters()) * 0.0
+                return empty + dummy
+            else:
+                return empty
+
+        return super().forward(x)
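
The empty-input branch of Conv3d precomputes the output shape per spatial dimension with the usual convolution formula, o = floor((i + 2p - (d*(k - 1) + 1)) / s) + 1, and the dummy term keeps every parameter in the autograd graph so DistributedDataParallel does not flag unused parameters on an empty batch. A standalone sketch (not part of the patch) that checks the shape arithmetic against nn.Conv3d:

import torch
import torch.nn as nn

# input size, kernel, padding, stride, dilation
i, k, p, s, d = 10, 3, 1, 2, 2
o = (i + 2 * p - (d * (k - 1) + 1)) // s + 1  # (10 + 2 - 5) // 2 + 1 = 4
ref = nn.Conv3d(1, 1, k, stride=s, padding=p, dilation=d)
assert ref(torch.randn(1, 1, i, i, i)).shape[-1] == o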

 @CONV_LAYERS.register_module()
 @CONV_LAYERS.register_module('deconv')
 @UPSAMPLE_LAYERS.register_module('deconv', force=True)
@@ -78,7 +99,30 @@ class ConvTranspose2d(nn.ConvTranspose2d):
             else:
                 return empty

-        return super(ConvTranspose2d, self).forward(x)
+        return super().forward(x)


+@CONV_LAYERS.register_module()
+@CONV_LAYERS.register_module('deconv3d')
+@UPSAMPLE_LAYERS.register_module('deconv3d', force=True)
+class ConvTranspose3d(nn.ConvTranspose3d):
+
+    def forward(self, x):
+        if x.numel() == 0 and obsolete_torch_version(TORCH_VERSION, (1, 4)):
+            out_shape = [x.shape[0], self.out_channels]
+            for i, k, p, s, d, op in zip(x.shape[-3:], self.kernel_size,
+                                         self.padding, self.stride,
+                                         self.dilation, self.output_padding):
+                out_shape.append((i - 1) * s - 2 * p + (d * (k - 1) + 1) + op)
+            empty = NewEmptyTensorOp.apply(x, out_shape)
+            if self.training:
+                # produce dummy gradient to avoid DDP warning.
+                dummy = sum(x.view(-1)[0] for x in self.parameters()) * 0.0
+                return empty + dummy
+            else:
+                return empty
+
+        return super().forward(x)
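
ConvTranspose3d inverts that computation: each output dimension is (i - 1)*s - 2p + (d*(k - 1) + 1) + op, where op is the output padding, which (as the tests below note) must be smaller than either the stride or the dilation. A standalone check against nn.ConvTranspose3d (not part of the patch):

import torch
import torch.nn as nn

i, k, p, s, d, op = 4, 3, 1, 2, 2, 1
o = (i - 1) * s - 2 * p + (d * (k - 1) + 1) + op  # 3*2 - 2 + 5 + 1 = 10
ref = nn.ConvTranspose3d(
    1, 1, k, stride=s, padding=p, dilation=d, output_padding=op)
assert ref(torch.randn(1, 1, i, i, i)).shape[-1] == o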

 class MaxPool2d(nn.MaxPool2d):
@@ -99,6 +143,25 @@ class MaxPool2d(nn.MaxPool2d):
         return super().forward(x)


+class MaxPool3d(nn.MaxPool3d):
+
+    def forward(self, x):
+        # PyTorch 1.7 does not support empty tensor inference yet
+        if x.numel() == 0 and obsolete_torch_version(TORCH_VERSION, (1, 7)):
+            out_shape = list(x.shape[:2])
+            for i, k, p, s, d in zip(x.shape[-3:], _triple(self.kernel_size),
+                                     _triple(self.padding),
+                                     _triple(self.stride),
+                                     _triple(self.dilation)):
+                o = (i + 2 * p - (d * (k - 1) + 1)) / s + 1
+                o = math.ceil(o) if self.ceil_mode else math.floor(o)
+                out_shape.append(o)
+            empty = NewEmptyTensorOp.apply(x, out_shape)
+            return empty
+        return super().forward(x)
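
MaxPool3d differs from the conv wrappers in two ways: its arguments may be scalars, so _triple expands them to 3-tuples, and ceil_mode selects whether a fractional output size rounds up or down. A standalone illustration (not part of the patch):

import math

import torch
import torch.nn as nn
from torch.nn.modules.utils import _triple

k = _triple(3)  # scalar kernel_size expands to (3, 3, 3)
i, p, s, d = 10, 1, 2, 1
for ceil_mode in (False, True):
    o = (i + 2 * p - (d * (k[0] - 1) + 1)) / s + 1  # 5.5 before rounding
    o = math.ceil(o) if ceil_mode else math.floor(o)  # 6 or 5
    ref = nn.MaxPool3d(k, stride=s, padding=p, ceil_mode=ceil_mode)
    assert ref(torch.randn(1, 1, i, i, i)).shape[-1] == o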

 class Linear(torch.nn.Linear):

     def forward(self, x):
......
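
The register_module decorators above are what make the new ops reachable from configs: 'Conv3d' joins CONV_LAYERS, while 'deconv3d' joins both CONV_LAYERS and (with force=True) UPSAMPLE_LAYERS. A hedged sketch of config-driven construction, assuming an mmcv build with this patch:

from mmcv.cnn import build_conv_layer, build_upsample_layer

conv = build_conv_layer(dict(type='Conv3d'), 3, 8, kernel_size=3)
upsample = build_upsample_layer(
    dict(type='deconv3d', in_channels=8, out_channels=8, kernel_size=2,
         stride=2))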

tests/test_wrappers.py
-from collections import OrderedDict
-from itertools import product
 from unittest.mock import patch

+import pytest
 import torch
 import torch.nn as nn

-from mmcv.cnn.bricks import Conv2d, ConvTranspose2d, Linear, MaxPool2d
+from mmcv.cnn.bricks import (Conv2d, Conv3d, ConvTranspose2d, ConvTranspose3d,
+                             Linear, MaxPool2d, MaxPool3d)

 @patch('torch.__version__', '1.1')
-def test_conv2d():
+@pytest.mark.parametrize(
+    'in_w,in_h,in_channel,out_channel,kernel_size,stride,padding,dilation',
+    [(10, 10, 1, 1, 3, 1, 0, 1), (20, 20, 3, 3, 5, 2, 1, 2)])
+def test_conv2d(in_w, in_h, in_channel, out_channel, kernel_size, stride,
+                padding, dilation):
     """
     CommandLine:
         xdoctest -m tests/test_wrappers.py test_conv2d
     """
-    test_cases = OrderedDict([('in_w', [10, 20]), ('in_h', [10, 20]),
-                              ('in_channel', [1, 3]), ('out_channel', [1, 3]),
-                              ('kernel_size', [3, 5]), ('stride', [1, 2]),
-                              ('padding', [0, 1]), ('dilation', [1, 2])])
-
     # train mode
     # wrapper op with 0-dim input
+    x_empty = torch.randn(0, in_channel, in_h, in_w)
     torch.manual_seed(0)
+    wrapper = Conv2d(
+        in_channel,
+        out_channel,
+        kernel_size,
+        stride=stride,
+        padding=padding,
+        dilation=dilation)
     wrapper_out = wrapper(x_empty)

     # torch op with 3-dim input as shape reference
+    x_normal = torch.randn(3, in_channel, in_h, in_w).requires_grad_(True)
     torch.manual_seed(0)
+    ref = nn.Conv2d(
+        in_channel,
+        out_channel,
+        kernel_size,
+        stride=stride,
+        padding=padding,
+        dilation=dilation)
     ref_out = ref(x_normal)

     assert wrapper_out.shape[0] == 0
     assert wrapper_out.shape[1:] == ref_out.shape[1:]

     wrapper_out.sum().backward()
     assert wrapper.weight.grad is not None
     assert wrapper.weight.grad.shape == wrapper.weight.shape

     assert torch.equal(wrapper(x_normal), ref_out)

     # eval mode
+    x_empty = torch.randn(0, in_channel, in_h, in_w)
+    wrapper = Conv2d(
+        in_channel,
+        out_channel,
+        kernel_size,
+        stride=stride,
+        padding=padding,
+        dilation=dilation)
     wrapper.eval()
     wrapper(x_empty)

+@patch('torch.__version__', '1.1')
+@pytest.mark.parametrize(
+    'in_w,in_h,in_t,in_channel,out_channel,kernel_size,stride,padding,dilation',  # noqa: E501
+    [(10, 10, 10, 1, 1, 3, 1, 0, 1), (20, 20, 20, 3, 3, 5, 2, 1, 2)])
+def test_conv3d(in_w, in_h, in_t, in_channel, out_channel, kernel_size,
+                stride, padding, dilation):
+    """
+    CommandLine:
+        xdoctest -m tests/test_wrappers.py test_conv3d
+    """
+    # train mode
-    for in_h, in_w, in_cha, out_cha, k, s, p, d in product(
-            *list(test_cases.values())):
+    # wrapper op with 0-dim input
-        x_empty = torch.randn(0, in_cha, in_h, in_w)
+    x_empty = torch.randn(0, in_channel, in_t, in_h, in_w)
     torch.manual_seed(0)
-        wrapper = Conv2d(in_cha, out_cha, k, stride=s, padding=p, dilation=d)
+    wrapper = Conv3d(
+        in_channel,
+        out_channel,
+        kernel_size,
+        stride=stride,
+        padding=padding,
+        dilation=dilation)
+    wrapper_out = wrapper(x_empty)

+    # torch op with 3-dim input as shape reference
-        x_normal = torch.randn(3, in_cha, in_h, in_w).requires_grad_(True)
+    x_normal = torch.randn(3, in_channel, in_t, in_h,
+                           in_w).requires_grad_(True)
     torch.manual_seed(0)
-        ref = nn.Conv2d(in_cha, out_cha, k, stride=s, padding=p, dilation=d)
+    ref = nn.Conv3d(
+        in_channel,
+        out_channel,
+        kernel_size,
+        stride=stride,
+        padding=padding,
+        dilation=dilation)
+    ref_out = ref(x_normal)

+    assert wrapper_out.shape[0] == 0
@@ -45,46 +111,49 @@ def test_conv2d():
+    assert torch.equal(wrapper(x_normal), ref_out)

+    # eval mode
-        x_empty = torch.randn(0, in_cha, in_h, in_w)
-        wrapper = Conv2d(in_cha, out_cha, k, stride=s, padding=p, dilation=d)
+    x_empty = torch.randn(0, in_channel, in_t, in_h, in_w)
+    wrapper = Conv3d(
+        in_channel,
+        out_channel,
+        kernel_size,
+        stride=stride,
+        padding=padding,
+        dilation=dilation)
     wrapper.eval()
     wrapper(x_empty)

 @patch('torch.__version__', '1.1')
-def test_conv_transposed_2d():
-    test_cases = OrderedDict([('in_w', [10, 20]), ('in_h', [10, 20]),
-                              ('in_channel', [1, 3]), ('out_channel', [1, 3]),
-                              ('kernel_size', [3, 5]), ('stride', [1, 2]),
-                              ('padding', [0, 1]), ('dilation', [1, 2])])
-    for in_h, in_w, in_cha, out_cha, k, s, p, d in product(
-            *list(test_cases.values())):
+@pytest.mark.parametrize(
+    'in_w,in_h,in_channel,out_channel,kernel_size,stride,padding,dilation',
+    [(10, 10, 1, 1, 3, 1, 0, 1), (20, 20, 3, 3, 5, 2, 1, 2)])
+def test_conv_transposed_2d(in_w, in_h, in_channel, out_channel, kernel_size,
+                            stride, padding, dilation):
     # wrapper op with 0-dim input
-        x_empty = torch.randn(0, in_cha, in_h, in_w, requires_grad=True)
+    x_empty = torch.randn(0, in_channel, in_h, in_w, requires_grad=True)
     # out padding must be smaller than either stride or dilation
-        op = min(s, d) - 1
+    op = min(stride, dilation) - 1
     torch.manual_seed(0)
     wrapper = ConvTranspose2d(
-            in_cha,
-            out_cha,
-            k,
-            stride=s,
-            padding=p,
-            dilation=d,
+        in_channel,
+        out_channel,
+        kernel_size,
+        stride=stride,
+        padding=padding,
+        dilation=dilation,
         output_padding=op)
     wrapper_out = wrapper(x_empty)

     # torch op with 3-dim input as shape reference
-        x_normal = torch.randn(3, in_cha, in_h, in_w)
+    x_normal = torch.randn(3, in_channel, in_h, in_w)
     torch.manual_seed(0)
     ref = nn.ConvTranspose2d(
-            in_cha,
-            out_cha,
-            k,
-            stride=s,
-            padding=p,
-            dilation=d,
+        in_channel,
+        out_channel,
+        kernel_size,
+        stride=stride,
+        padding=padding,
+        dilation=dilation,
         output_padding=op)
     ref_out = ref(x_normal)
@@ -98,30 +167,116 @@ def test_conv_transposed_2d():
     assert torch.equal(wrapper(x_normal), ref_out)

     # eval mode
-        x_empty = torch.randn(0, in_cha, in_h, in_w)
+    x_empty = torch.randn(0, in_channel, in_h, in_w)
     wrapper = ConvTranspose2d(
-        in_cha, out_cha, k, stride=s, padding=p, dilation=d, output_padding=op)
+        in_channel,
+        out_channel,
+        kernel_size,
+        stride=stride,
+        padding=padding,
+        dilation=dilation,
+        output_padding=op)
     wrapper.eval()
     wrapper(x_empty)

+@patch('torch.__version__', '1.1')
+@pytest.mark.parametrize(
+    'in_w,in_h,in_t,in_channel,out_channel,kernel_size,stride,padding,dilation',  # noqa: E501
+    [(10, 10, 10, 1, 1, 3, 1, 0, 1), (20, 20, 20, 3, 3, 5, 2, 1, 2)])
+def test_conv_transposed_3d(in_w, in_h, in_t, in_channel, out_channel,
+                            kernel_size, stride, padding, dilation):
+    # wrapper op with 0-dim input
+    x_empty = torch.randn(0, in_channel, in_t, in_h, in_w, requires_grad=True)
+    # out padding must be smaller than either stride or dilation
+    op = min(stride, dilation) - 1
+    torch.manual_seed(0)
+    wrapper = ConvTranspose3d(
+        in_channel,
+        out_channel,
+        kernel_size,
+        stride=stride,
+        padding=padding,
+        dilation=dilation,
+        output_padding=op)
+    wrapper_out = wrapper(x_empty)

+    # torch op with 3-dim input as shape reference
+    x_normal = torch.randn(3, in_channel, in_t, in_h, in_w)
+    torch.manual_seed(0)
+    ref = nn.ConvTranspose3d(
+        in_channel,
+        out_channel,
+        kernel_size,
+        stride=stride,
+        padding=padding,
+        dilation=dilation,
+        output_padding=op)
+    ref_out = ref(x_normal)

+    assert wrapper_out.shape[0] == 0
+    assert wrapper_out.shape[1:] == ref_out.shape[1:]

+    wrapper_out.sum().backward()
+    assert wrapper.weight.grad is not None
+    assert wrapper.weight.grad.shape == wrapper.weight.shape

+    assert torch.equal(wrapper(x_normal), ref_out)

+    # eval mode
+    x_empty = torch.randn(0, in_channel, in_t, in_h, in_w)
+    wrapper = ConvTranspose3d(
+        in_channel,
+        out_channel,
+        kernel_size,
+        stride=stride,
+        padding=padding,
+        dilation=dilation,
+        output_padding=op)
+    wrapper.eval()
+    wrapper(x_empty)

 @patch('torch.__version__', '1.1')
-def test_max_pool_2d():
-    test_cases = OrderedDict([('in_w', [10, 20]), ('in_h', [10, 20]),
-                              ('in_channel', [1, 3]), ('out_channel', [1, 3]),
-                              ('kernel_size', [3, 5]), ('stride', [1, 2]),
-                              ('padding', [0, 1]), ('dilation', [1, 2])])
-    for in_h, in_w, in_cha, out_cha, k, s, p, d in product(
-            *list(test_cases.values())):
+@pytest.mark.parametrize(
+    'in_w,in_h,in_channel,out_channel,kernel_size,stride,padding,dilation',
+    [(10, 10, 1, 1, 3, 1, 0, 1), (20, 20, 3, 3, 5, 2, 1, 2)])
+def test_max_pool_2d(in_w, in_h, in_channel, out_channel, kernel_size, stride,
+                     padding, dilation):
     # wrapper op with 0-dim input
+    x_empty = torch.randn(0, in_channel, in_h, in_w, requires_grad=True)
+    wrapper = MaxPool2d(
+        kernel_size, stride=stride, padding=padding, dilation=dilation)
     wrapper_out = wrapper(x_empty)

     # torch op with 3-dim input as shape reference
+    x_normal = torch.randn(3, in_channel, in_h, in_w)
+    ref = nn.MaxPool2d(
+        kernel_size, stride=stride, padding=padding, dilation=dilation)
     ref_out = ref(x_normal)

     assert wrapper_out.shape[0] == 0
     assert wrapper_out.shape[1:] == ref_out.shape[1:]
     assert torch.equal(wrapper(x_normal), ref_out)

+@patch('torch.__version__', '1.1')
+@pytest.mark.parametrize(
+    'in_w,in_h,in_t,in_channel,out_channel,kernel_size,stride,padding,dilation',  # noqa: E501
+    [(10, 10, 10, 1, 1, 3, 1, 0, 1), (20, 20, 20, 3, 3, 5, 2, 1, 2)])
+def test_max_pool_3d(in_w, in_h, in_t, in_channel, out_channel, kernel_size,
+                     stride, padding, dilation):
+    # wrapper op with 0-dim input
-        x_empty = torch.randn(0, in_cha, in_h, in_w, requires_grad=True)
-        wrapper = MaxPool2d(k, stride=s, padding=p, dilation=d)
+    x_empty = torch.randn(0, in_channel, in_t, in_h, in_w, requires_grad=True)
+    wrapper = MaxPool3d(
+        kernel_size, stride=stride, padding=padding, dilation=dilation)
+    wrapper_out = wrapper(x_empty)

+    # torch op with 3-dim input as shape reference
-        x_normal = torch.randn(3, in_cha, in_h, in_w)
-        ref = nn.MaxPool2d(k, stride=s, padding=p, dilation=d)
+    x_normal = torch.randn(3, in_channel, in_t, in_h, in_w)
+    ref = nn.MaxPool3d(
+        kernel_size, stride=stride, padding=padding, dilation=dilation)
+    ref_out = ref(x_normal)

+    assert wrapper_out.shape[0] == 0
@@ -131,16 +286,9 @@ def test_max_pool_2d():

 @patch('torch.__version__', '1.1')
-def test_linear():
-    test_cases = OrderedDict([
-        ('in_w', [10, 20]),
-        ('in_h', [10, 20]),
-        ('in_feature', [1, 3]),
-        ('out_feature', [1, 3]),
-    ])
-    for in_h, in_w, in_feature, out_feature in product(
-            *list(test_cases.values())):
+@pytest.mark.parametrize('in_w,in_h,in_feature,out_feature', [(10, 10, 1, 1),
+                                                              (20, 20, 3, 3)])
+def test_linear(in_w, in_h, in_feature, out_feature):
     # wrapper op with 0-dim input
     x_empty = torch.randn(0, in_feature, requires_grad=True)
     torch.manual_seed(0)
......