"...git@developer.sourcefind.cn:renzhc/diffusers_dcu.git" did not exist on "8ead643bb786fe6bc80c9a4bd1730372d410a9df"
Commit e70dca8f authored by dreamerlin's avatar dreamerlin
Browse files

add conv3d

parent 144e7567
...@@ -3,7 +3,7 @@ from .alexnet import AlexNet ...@@ -3,7 +3,7 @@ from .alexnet import AlexNet
# yapf: disable # yapf: disable
from .bricks import (ACTIVATION_LAYERS, CONV_LAYERS, NORM_LAYERS, from .bricks import (ACTIVATION_LAYERS, CONV_LAYERS, NORM_LAYERS,
PADDING_LAYERS, PLUGIN_LAYERS, UPSAMPLE_LAYERS, PADDING_LAYERS, PLUGIN_LAYERS, UPSAMPLE_LAYERS,
ContextBlock, Conv2d, ConvAWS2d, ConvModule, ContextBlock, Conv2d, Conv3d, ConvAWS2d, ConvModule,
ConvTranspose2d, ConvTranspose3d, ConvWS2d, ConvTranspose2d, ConvTranspose3d, ConvWS2d,
DepthwiseSeparableConvModule, GeneralizedAttention, DepthwiseSeparableConvModule, GeneralizedAttention,
HSigmoid, HSwish, Linear, MaxPool2d, MaxPool3d, HSigmoid, HSwish, Linear, MaxPool2d, MaxPool3d,
...@@ -30,5 +30,5 @@ __all__ = [ ...@@ -30,5 +30,5 @@ __all__ = [
'PLUGIN_LAYERS', 'Scale', 'get_model_complexity_info', 'conv_ws_2d', 'PLUGIN_LAYERS', 'Scale', 'get_model_complexity_info', 'conv_ws_2d',
'ConvAWS2d', 'ConvWS2d', 'fuse_conv_bn', 'DepthwiseSeparableConvModule', 'ConvAWS2d', 'ConvWS2d', 'fuse_conv_bn', 'DepthwiseSeparableConvModule',
'Linear', 'Conv2d', 'ConvTranspose2d', 'MaxPool2d', 'ConvTranspose3d', 'Linear', 'Conv2d', 'ConvTranspose2d', 'MaxPool2d', 'ConvTranspose3d',
'MaxPool3d' 'MaxPool3d', 'Conv3d'
] ]
...@@ -17,8 +17,8 @@ from .registry import (ACTIVATION_LAYERS, CONV_LAYERS, NORM_LAYERS, ...@@ -17,8 +17,8 @@ from .registry import (ACTIVATION_LAYERS, CONV_LAYERS, NORM_LAYERS,
from .scale import Scale from .scale import Scale
from .swish import Swish from .swish import Swish
from .upsample import build_upsample_layer from .upsample import build_upsample_layer
from .wrappers import (Conv2d, ConvTranspose2d, ConvTranspose3d, Linear, from .wrappers import (Conv2d, Conv3d, ConvTranspose2d, ConvTranspose3d,
MaxPool2d, MaxPool3d) Linear, MaxPool2d, MaxPool3d)
__all__ = [ __all__ = [
'ConvModule', 'build_activation_layer', 'build_conv_layer', 'ConvModule', 'build_activation_layer', 'build_conv_layer',
...@@ -29,5 +29,5 @@ __all__ = [ ...@@ -29,5 +29,5 @@ __all__ = [
'UPSAMPLE_LAYERS', 'PLUGIN_LAYERS', 'Scale', 'ConvAWS2d', 'ConvWS2d', 'UPSAMPLE_LAYERS', 'PLUGIN_LAYERS', 'Scale', 'ConvAWS2d', 'ConvWS2d',
'conv_ws_2d', 'DepthwiseSeparableConvModule', 'Swish', 'Linear', 'conv_ws_2d', 'DepthwiseSeparableConvModule', 'Swish', 'Linear',
'Conv2dAdaptivePadding', 'Conv2d', 'ConvTranspose2d', 'MaxPool2d', 'Conv2dAdaptivePadding', 'Conv2d', 'ConvTranspose2d', 'MaxPool2d',
'ConvTranspose3d', 'MaxPool3d' 'ConvTranspose3d', 'MaxPool3d', 'Conv3d'
] ]
...@@ -58,6 +58,27 @@ class Conv2d(nn.Conv2d): ...@@ -58,6 +58,27 @@ class Conv2d(nn.Conv2d):
return super().forward(x) return super().forward(x)
@CONV_LAYERS.register_module('Conv', force=True)
class Conv3d(nn.Conv3d):
def forward(self, x):
if x.numel() == 0 and obsolete_torch_version(TORCH_VERSION, (1, 4)):
out_shape = [x.shape[0], self.out_channels]
for i, k, p, s, d in zip(x.shape[-3:], self.kernel_size,
self.padding, self.stride, self.dilation):
o = (i + 2 * p - (d * (k - 1) + 1)) // s + 1
out_shape.append(o)
empty = NewEmptyTensorOp.apply(x, out_shape)
if self.training:
# produce dummy gradient to avoid DDP warning.
dummy = sum(x.view(-1)[0] for x in self.parameters()) * 0.0
return empty + dummy
else:
return empty
return super().forward(x)
@CONV_LAYERS.register_module() @CONV_LAYERS.register_module()
@CONV_LAYERS.register_module('deconv') @CONV_LAYERS.register_module('deconv')
@UPSAMPLE_LAYERS.register_module('deconv', force=True) @UPSAMPLE_LAYERS.register_module('deconv', force=True)
......
...@@ -4,8 +4,8 @@ import pytest ...@@ -4,8 +4,8 @@ import pytest
import torch import torch
import torch.nn as nn import torch.nn as nn
from mmcv.cnn.bricks import (Conv2d, ConvTranspose2d, ConvTranspose3d, Linear, from mmcv.cnn.bricks import (Conv2d, Conv3d, ConvTranspose2d, ConvTranspose3d,
MaxPool2d, MaxPool3d) Linear, MaxPool2d, MaxPool3d)
@patch('torch.__version__', '1.1') @patch('torch.__version__', '1.1')
...@@ -65,6 +65,64 @@ def test_conv2d(in_w, in_h, in_channel, out_channel, kernel_size, stride, ...@@ -65,6 +65,64 @@ def test_conv2d(in_w, in_h, in_channel, out_channel, kernel_size, stride,
wrapper(x_empty) wrapper(x_empty)
@patch('torch.__version__', '1.1')
@pytest.mark.parametrize(
'in_w,in_h,in_t,in_channel,out_channel,kernel_size,stride,padding,dilation', # noqa: E501
[(10, 10, 10, 1, 1, 3, 1, 0, 1), (20, 20, 20, 3, 3, 5, 2, 1, 2)])
def test_conv3d(in_w, in_h, in_t, in_channel, out_channel, kernel_size, stride,
padding, dilation):
"""
CommandLine:
xdoctest -m tests/test_wrappers.py test_conv3d
"""
# train mode
# wrapper op with 0-dim input
x_empty = torch.randn(0, in_channel, in_t, in_h, in_w)
torch.manual_seed(0)
wrapper = Conv3d(
in_channel,
out_channel,
kernel_size,
stride=stride,
padding=padding,
dilation=dilation)
wrapper_out = wrapper(x_empty)
# torch op with 3-dim input as shape reference
x_normal = torch.randn(3, in_channel, in_t, in_h,
in_w).requires_grad_(True)
torch.manual_seed(0)
ref = nn.Conv3d(
in_channel,
out_channel,
kernel_size,
stride=stride,
padding=padding,
dilation=dilation)
ref_out = ref(x_normal)
assert wrapper_out.shape[0] == 0
assert wrapper_out.shape[1:] == ref_out.shape[1:]
wrapper_out.sum().backward()
assert wrapper.weight.grad is not None
assert wrapper.weight.grad.shape == wrapper.weight.shape
assert torch.equal(wrapper(x_normal), ref_out)
# eval mode
x_empty = torch.randn(0, in_channel, in_t, in_h, in_w)
wrapper = Conv3d(
in_channel,
out_channel,
kernel_size,
stride=stride,
padding=padding,
dilation=dilation)
wrapper.eval()
wrapper(x_empty)
@patch('torch.__version__', '1.1') @patch('torch.__version__', '1.1')
@pytest.mark.parametrize( @pytest.mark.parametrize(
'in_w,in_h,in_channel,out_channel,kernel_size,stride,padding,dilation', 'in_w,in_h,in_channel,out_channel,kernel_size,stride,padding,dilation',
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment