Unverified Commit c937d395 authored by Wenwei Zhang's avatar Wenwei Zhang Committed by GitHub
Browse files

Mv wrappers into bricks and use wrappers in registry (#550)

* Mv wrappers into bricks and use wrappers in registry

* resolve import issues

* fix import issues

* set nn op forward to torch 1.6.1

* fix CI bug and add warning

* Fix CI by using patch mock

* mv warnings inside deprecated module's initialization
parent c054a239
...@@ -2,10 +2,11 @@ ...@@ -2,10 +2,11 @@
from .alexnet import AlexNet from .alexnet import AlexNet
from .bricks import (ACTIVATION_LAYERS, CONV_LAYERS, NORM_LAYERS, from .bricks import (ACTIVATION_LAYERS, CONV_LAYERS, NORM_LAYERS,
PADDING_LAYERS, PLUGIN_LAYERS, UPSAMPLE_LAYERS, PADDING_LAYERS, PLUGIN_LAYERS, UPSAMPLE_LAYERS,
ContextBlock, ConvAWS2d, ConvModule, ConvWS2d, ContextBlock, Conv2d, ConvAWS2d, ConvModule,
DepthwiseSeparableConvModule, GeneralizedAttention, ConvTranspose2d, ConvWS2d, DepthwiseSeparableConvModule,
HSigmoid, HSwish, NonLocal1d, NonLocal2d, NonLocal3d, GeneralizedAttention, HSigmoid, HSwish, Linear, MaxPool2d,
Scale, Swish, build_activation_layer, build_conv_layer, NonLocal1d, NonLocal2d, NonLocal3d, Scale, Swish,
build_activation_layer, build_conv_layer,
build_norm_layer, build_padding_layer, build_plugin_layer, build_norm_layer, build_padding_layer, build_plugin_layer,
build_upsample_layer, conv_ws_2d, is_norm) build_upsample_layer, conv_ws_2d, is_norm)
from .resnet import ResNet, make_res_layer from .resnet import ResNet, make_res_layer
...@@ -24,5 +25,6 @@ __all__ = [ ...@@ -24,5 +25,6 @@ __all__ = [
'HSigmoid', 'Swish', 'HSwish', 'GeneralizedAttention', 'ACTIVATION_LAYERS', 'HSigmoid', 'Swish', 'HSwish', 'GeneralizedAttention', 'ACTIVATION_LAYERS',
'CONV_LAYERS', 'NORM_LAYERS', 'PADDING_LAYERS', 'UPSAMPLE_LAYERS', 'CONV_LAYERS', 'NORM_LAYERS', 'PADDING_LAYERS', 'UPSAMPLE_LAYERS',
'PLUGIN_LAYERS', 'Scale', 'get_model_complexity_info', 'conv_ws_2d', 'PLUGIN_LAYERS', 'Scale', 'get_model_complexity_info', 'conv_ws_2d',
'ConvAWS2d', 'ConvWS2d', 'fuse_conv_bn', 'DepthwiseSeparableConvModule' 'ConvAWS2d', 'ConvWS2d', 'fuse_conv_bn', 'DepthwiseSeparableConvModule',
'Linear', 'Conv2d', 'ConvTranspose2d', 'MaxPool2d'
] ]
...@@ -17,6 +17,7 @@ from .registry import (ACTIVATION_LAYERS, CONV_LAYERS, NORM_LAYERS, ...@@ -17,6 +17,7 @@ from .registry import (ACTIVATION_LAYERS, CONV_LAYERS, NORM_LAYERS,
from .scale import Scale from .scale import Scale
from .swish import Swish from .swish import Swish
from .upsample import build_upsample_layer from .upsample import build_upsample_layer
from .wrappers import Conv2d, ConvTranspose2d, Linear, MaxPool2d
__all__ = [ __all__ = [
'ConvModule', 'build_activation_layer', 'build_conv_layer', 'ConvModule', 'build_activation_layer', 'build_conv_layer',
...@@ -25,6 +26,6 @@ __all__ = [ ...@@ -25,6 +26,6 @@ __all__ = [
'NonLocal2d', 'NonLocal3d', 'ContextBlock', 'GeneralizedAttention', 'NonLocal2d', 'NonLocal3d', 'ContextBlock', 'GeneralizedAttention',
'ACTIVATION_LAYERS', 'CONV_LAYERS', 'NORM_LAYERS', 'PADDING_LAYERS', 'ACTIVATION_LAYERS', 'CONV_LAYERS', 'NORM_LAYERS', 'PADDING_LAYERS',
'UPSAMPLE_LAYERS', 'PLUGIN_LAYERS', 'Scale', 'ConvAWS2d', 'ConvWS2d', 'UPSAMPLE_LAYERS', 'PLUGIN_LAYERS', 'Scale', 'ConvAWS2d', 'ConvWS2d',
'conv_ws_2d', 'DepthwiseSeparableConvModule', 'Swish', 'conv_ws_2d', 'DepthwiseSeparableConvModule', 'Swish', 'Linear',
'Conv2dAdaptivePadding' 'Conv2dAdaptivePadding', 'Conv2d', 'ConvTranspose2d', 'MaxPool2d'
] ]
...@@ -6,7 +6,6 @@ from .registry import UPSAMPLE_LAYERS ...@@ -6,7 +6,6 @@ from .registry import UPSAMPLE_LAYERS
UPSAMPLE_LAYERS.register_module('nearest', module=nn.Upsample) UPSAMPLE_LAYERS.register_module('nearest', module=nn.Upsample)
UPSAMPLE_LAYERS.register_module('bilinear', module=nn.Upsample) UPSAMPLE_LAYERS.register_module('bilinear', module=nn.Upsample)
UPSAMPLE_LAYERS.register_module('deconv', module=nn.ConvTranspose2d)
@UPSAMPLE_LAYERS.register_module(name='pixel_shuffle') @UPSAMPLE_LAYERS.register_module(name='pixel_shuffle')
......
...@@ -10,7 +10,7 @@ import torch ...@@ -10,7 +10,7 @@ import torch
import torch.nn as nn import torch.nn as nn
from torch.nn.modules.utils import _pair from torch.nn.modules.utils import _pair
from ..cnn import CONV_LAYERS from .registry import CONV_LAYERS, UPSAMPLE_LAYERS
class NewEmptyTensorOp(torch.autograd.Function): class NewEmptyTensorOp(torch.autograd.Function):
...@@ -47,6 +47,7 @@ class Conv2d(nn.Conv2d): ...@@ -47,6 +47,7 @@ class Conv2d(nn.Conv2d):
return super().forward(x) return super().forward(x)
@UPSAMPLE_LAYERS.register_module('deconv', force=True)
class ConvTranspose2d(nn.ConvTranspose2d): class ConvTranspose2d(nn.ConvTranspose2d):
def forward(self, x): def forward(self, x):
...@@ -70,7 +71,8 @@ class ConvTranspose2d(nn.ConvTranspose2d): ...@@ -70,7 +71,8 @@ class ConvTranspose2d(nn.ConvTranspose2d):
class MaxPool2d(nn.MaxPool2d): class MaxPool2d(nn.MaxPool2d):
def forward(self, x): def forward(self, x):
if x.numel() == 0 and torch.__version__ <= '1.4': # PyTorch 1.6 does not support empty tensor inference yet
if x.numel() == 0 and torch.__version__ <= '1.6':
out_shape = list(x.shape[:2]) out_shape = list(x.shape[:2])
for i, k, p, s, d in zip(x.shape[-2:], _pair(self.kernel_size), for i, k, p, s, d in zip(x.shape[-2:], _pair(self.kernel_size),
_pair(self.padding), _pair(self.stride), _pair(self.padding), _pair(self.stride),
......
...@@ -5,6 +5,10 @@ from .corner_pool import CornerPool ...@@ -5,6 +5,10 @@ from .corner_pool import CornerPool
from .deform_conv import DeformConv2d, DeformConv2dPack, deform_conv2d from .deform_conv import DeformConv2d, DeformConv2dPack, deform_conv2d
from .deform_roi_pool import (DeformRoIPool, DeformRoIPoolPack, from .deform_roi_pool import (DeformRoIPool, DeformRoIPoolPack,
ModulatedDeformRoIPoolPack, deform_roi_pool) ModulatedDeformRoIPoolPack, deform_roi_pool)
from .deprecated_wrappers import Conv2d_deprecated as Conv2d
from .deprecated_wrappers import ConvTranspose2d_deprecated as ConvTranspose2d
from .deprecated_wrappers import Linear_deprecated as Linear
from .deprecated_wrappers import MaxPool2d_deprecated as MaxPool2d
from .focal_loss import (SigmoidFocalLoss, SoftmaxFocalLoss, from .focal_loss import (SigmoidFocalLoss, SoftmaxFocalLoss,
sigmoid_focal_loss, softmax_focal_loss) sigmoid_focal_loss, softmax_focal_loss)
from .info import get_compiler_version, get_compiling_cuda_version from .info import get_compiler_version, get_compiling_cuda_version
...@@ -21,7 +25,6 @@ from .roi_pool import RoIPool, roi_pool ...@@ -21,7 +25,6 @@ from .roi_pool import RoIPool, roi_pool
from .saconv import SAConv2d from .saconv import SAConv2d
from .sync_bn import SyncBatchNorm from .sync_bn import SyncBatchNorm
from .tin_shift import TINShift, tin_shift from .tin_shift import TINShift, tin_shift
from .wrappers import Conv2d, ConvTranspose2d, Linear, MaxPool2d
__all__ = [ __all__ = [
'bbox_overlaps', 'CARAFE', 'CARAFENaive', 'CARAFEPack', 'carafe', 'bbox_overlaps', 'CARAFE', 'CARAFENaive', 'CARAFEPack', 'carafe',
......
# This file is for backward compatibility.
# Module wrappers for empty tensor have been moved to mmcv.cnn.bricks.
import warnings
from ..cnn.bricks.wrappers import Conv2d, ConvTranspose2d, Linear, MaxPool2d
class Conv2d_deprecated(Conv2d):
def __init__(*args, **kwargs):
super().__init__(*args, **kwargs)
warnings.warn(
'Importing Conv2d wrapper from "mmcv.ops" will be deprecated in'
' the future. Please import them from "mmcv.cnn" instead')
class ConvTranspose2d_deprecated(ConvTranspose2d):
def __init__(*args, **kwargs):
super().__init__(*args, **kwargs)
warnings.warn(
'Importing ConvTranspose2d wrapper from "mmcv.ops" will be '
'deprecated in the future. Please import them from "mmcv.cnn" '
'instead')
class MaxPool2d_deprecated(MaxPool2d):
def __init__(*args, **kwargs):
super().__init__(*args, **kwargs)
warnings.warn(
'Importing MaxPool2d wrapper from "mmcv.ops" will be deprecated in'
' the future. Please import them from "mmcv.cnn" instead')
class Linear_deprecated(Linear):
def __init__(*args, **kwargs):
super().__init__(*args, **kwargs)
warnings.warn(
'Importing Linear wrapper from "mmcv.ops" will be deprecated in'
' the future. Please import them from "mmcv.cnn" instead')
...@@ -5,11 +5,10 @@ from unittest.mock import patch ...@@ -5,11 +5,10 @@ from unittest.mock import patch
import torch import torch
import torch.nn as nn import torch.nn as nn
from mmcv.ops import Conv2d, ConvTranspose2d, Linear, MaxPool2d from mmcv.cnn.bricks import Conv2d, ConvTranspose2d, Linear, MaxPool2d
torch.__version__ = '1.1' # force test
@patch('torch.__version__', '1.1')
def test_conv2d(): def test_conv2d():
""" """
CommandLine: CommandLine:
...@@ -52,6 +51,7 @@ def test_conv2d(): ...@@ -52,6 +51,7 @@ def test_conv2d():
wrapper(x_empty) wrapper(x_empty)
@patch('torch.__version__', '1.1')
def test_conv_transposed_2d(): def test_conv_transposed_2d():
test_cases = OrderedDict([('in_w', [10, 20]), ('in_h', [10, 20]), test_cases = OrderedDict([('in_w', [10, 20]), ('in_h', [10, 20]),
('in_channel', [1, 3]), ('out_channel', [1, 3]), ('in_channel', [1, 3]), ('out_channel', [1, 3]),
...@@ -105,6 +105,7 @@ def test_conv_transposed_2d(): ...@@ -105,6 +105,7 @@ def test_conv_transposed_2d():
wrapper(x_empty) wrapper(x_empty)
@patch('torch.__version__', '1.1')
def test_max_pool_2d(): def test_max_pool_2d():
test_cases = OrderedDict([('in_w', [10, 20]), ('in_h', [10, 20]), test_cases = OrderedDict([('in_w', [10, 20]), ('in_h', [10, 20]),
('in_channel', [1, 3]), ('out_channel', [1, 3]), ('in_channel', [1, 3]), ('out_channel', [1, 3]),
...@@ -129,6 +130,7 @@ def test_max_pool_2d(): ...@@ -129,6 +130,7 @@ def test_max_pool_2d():
assert torch.equal(wrapper(x_normal), ref_out) assert torch.equal(wrapper(x_normal), ref_out)
@patch('torch.__version__', '1.1')
def test_linear(): def test_linear():
test_cases = OrderedDict([ test_cases = OrderedDict([
('in_w', [10, 20]), ('in_w', [10, 20]),
...@@ -167,8 +169,8 @@ def test_linear(): ...@@ -167,8 +169,8 @@ def test_linear():
wrapper(x_empty) wrapper(x_empty)
@patch('torch.__version__', '1.6.1')
def test_nn_op_forward_called(): def test_nn_op_forward_called():
torch.__version__ = '1.4.1'
for m in ['Conv2d', 'ConvTranspose2d', 'MaxPool2d']: for m in ['Conv2d', 'ConvTranspose2d', 'MaxPool2d']:
with patch(f'torch.nn.{m}.forward') as nn_module_forward: with patch(f'torch.nn.{m}.forward') as nn_module_forward:
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment