# Copyright (c) OpenMMLab. All rights reserved.
import warnings
from unittest.mock import patch

import pytest
import torch
import torch.nn as nn
from mmengine.registry import MODELS

from mmcv.cnn.bricks import ConvModule, HSigmoid, HSwish
from mmcv.utils import TORCH_VERSION, digit_version


@MODELS.register_module()
class ExampleConv(nn.Module):
    """Minimal custom conv layer registered in ``MODELS`` for ConvModule tests.

    It mirrors the public hyper-parameter attributes of ``nn.Conv2d``
    (presumably the ones ConvModule reads from its wrapped conv — confirm
    against ConvModule) and defines its own ``init_weights`` that zeroes the
    weight. This lets ``test_conv_module`` check that a layer's own
    initialisation survives ConvModule construction.
    """

    def __init__(self,
                 in_channels,
                 out_channels,
                 kernel_size,
                 stride=1,
                 padding=0,
                 dilation=1,
                 groups=1,
                 bias=True,
                 norm_cfg=None):
        super().__init__()
        self.in_channels = in_channels
        self.out_channels = out_channels
        self.kernel_size = kernel_size
        self.stride = stride
        self.padding = padding
        self.dilation = dilation
        self.groups = groups
        self.bias = bias
        self.norm_cfg = norm_cfg
        self.output_padding = (0, 0, 0)
        self.transposed = False

        # Forward the hyper-parameters so the attributes above describe the
        # layer that actually runs in ``forward`` (previously only the
        # channel counts and kernel size were passed through).
        self.conv0 = nn.Conv2d(
            in_channels,
            out_channels,
            kernel_size,
            stride=stride,
            padding=padding,
            dilation=dilation,
            groups=groups,
            bias=bias)
        self.init_weights()

    def forward(self, x):
        """Apply the wrapped ``nn.Conv2d``."""
        x = self.conv0(x)
        return x

    def init_weights(self):
        """Zero the conv weight so tests can detect whether it was kept."""
        nn.init.constant_(self.conv0.weight, 0)


def test_conv_module():
    """Check ConvModule config validation, layer composition and shapes."""
    with pytest.raises(AssertionError):
        # conv_cfg must be a dict or None
        conv_cfg = 'conv'
        ConvModule(3, 8, 2, conv_cfg=conv_cfg)

    with pytest.raises(AssertionError):
        # norm_cfg must be a dict or None
        norm_cfg = 'norm'
        ConvModule(3, 8, 2, norm_cfg=norm_cfg)

    with pytest.raises(KeyError):
        # softmax is not supported
        act_cfg = dict(type='softmax')
        ConvModule(3, 8, 2, act_cfg=act_cfg)

    # conv + norm + act
    conv = ConvModule(3, 8, 2, norm_cfg=dict(type='BN'))
    assert conv.with_activation
    assert hasattr(conv, 'activate')
    assert conv.with_norm
    assert hasattr(conv, 'norm')
    x = torch.rand(1, 3, 256, 256)
    output = conv(x)
    # kernel 2, no padding: 256 - 2 + 1 = 255
    assert output.shape == (1, 8, 255, 255)

    # conv + act
    conv = ConvModule(3, 8, 2)
    assert conv.with_activation
    assert hasattr(conv, 'activate')
    assert not conv.with_norm
    assert conv.norm is None
    x = torch.rand(1, 3, 256, 256)
    output = conv(x)
    assert output.shape == (1, 8, 255, 255)

    # conv
    conv = ConvModule(3, 8, 2, act_cfg=None)
    assert not conv.with_norm
    assert conv.norm is None
    assert not conv.with_activation
    assert not hasattr(conv, 'activate')
    x = torch.rand(1, 3, 256, 256)
    output = conv(x)
    assert output.shape == (1, 8, 255, 255)

    # conv with its own `init_weights` method: ExampleConv zeroes its
    # weight, and the check below requires that value to survive.
    conv_module = ConvModule(
        3, 8, 2, conv_cfg=dict(type='ExampleConv'), act_cfg=None)
    assert torch.equal(conv_module.conv.conv0.weight, torch.zeros(8, 3, 2, 2))

    # All remaining cases use kernel 3 with padding 1, which preserves the
    # 256x256 spatial size of this input.
    x = torch.rand(1, 3, 256, 256)

    # with_spectral_norm=True
    conv = ConvModule(3, 8, 3, padding=1, with_spectral_norm=True)
    assert hasattr(conv.conv, 'weight_orig')
    output = conv(x)
    assert output.shape == (1, 8, 256, 256)

    # padding_mode='reflect'
    conv = ConvModule(3, 8, 3, padding=1, padding_mode='reflect')
    assert isinstance(conv.padding_layer, nn.ReflectionPad2d)
    output = conv(x)
    assert output.shape == (1, 8, 256, 256)

    # non-existing padding mode
    with pytest.raises(KeyError):
        ConvModule(3, 8, 3, padding=1, padding_mode='non_exists')

    # leaky relu
    conv = ConvModule(3, 8, 3, padding=1, act_cfg=dict(type='LeakyReLU'))
    assert isinstance(conv.activate, nn.LeakyReLU)
    output = conv(x)
    assert output.shape == (1, 8, 256, 256)

    # tanh
    conv = ConvModule(3, 8, 3, padding=1, act_cfg=dict(type='Tanh'))
    assert isinstance(conv.activate, nn.Tanh)
    output = conv(x)
    assert output.shape == (1, 8, 256, 256)

    # Sigmoid
    conv = ConvModule(3, 8, 3, padding=1, act_cfg=dict(type='Sigmoid'))
    assert isinstance(conv.activate, nn.Sigmoid)
    output = conv(x)
    assert output.shape == (1, 8, 256, 256)

    # PReLU
    conv = ConvModule(3, 8, 3, padding=1, act_cfg=dict(type='PReLU'))
    assert isinstance(conv.activate, nn.PReLU)
    output = conv(x)
    assert output.shape == (1, 8, 256, 256)

    # HSwish: mmcv's own HSwish on parrots / torch < 1.7, otherwise
    # torch's built-in nn.Hardswish.
    conv = ConvModule(3, 8, 3, padding=1, act_cfg=dict(type='HSwish'))
    if (TORCH_VERSION == 'parrots'
            or digit_version(TORCH_VERSION) < digit_version('1.7')):
        assert isinstance(conv.activate, HSwish)
    else:
        assert isinstance(conv.activate, nn.Hardswish)
    output = conv(x)
    assert output.shape == (1, 8, 256, 256)

    # HSigmoid
    conv = ConvModule(3, 8, 3, padding=1, act_cfg=dict(type='HSigmoid'))
    assert isinstance(conv.activate, HSigmoid)
    output = conv(x)
    assert output.shape == (1, 8, 256, 256)


def test_bias():
    """Check how ConvModule resolves ``bias`` and when it warns about it."""
    # bias: auto, without norm
    conv = ConvModule(3, 8, 2)
    assert conv.conv.bias is not None

    # bias: auto, with norm
    conv = ConvModule(3, 8, 2, norm_cfg=dict(type='BN'))
    assert conv.conv.bias is None

    # bias: False, without norm
    conv = ConvModule(3, 8, 2, bias=False)
    assert conv.conv.bias is None

    # bias: True, with batch norm
    with pytest.warns(UserWarning) as record:
        ConvModule(3, 8, 2, bias=True, norm_cfg=dict(type='BN'))
    assert len(record) == 1
    assert record[0].message.args[
        0] == 'Unnecessary conv bias before batch/instance norm'

    # bias: True, with instance norm
    with pytest.warns(UserWarning) as record:
        ConvModule(3, 8, 2, bias=True, norm_cfg=dict(type='IN'))
    assert len(record) == 1
    assert record[0].message.args[
        0] == 'Unnecessary conv bias before batch/instance norm'

    # bias: True, with other norm -> no warning expected.
    # pytest.warns fails when no warning is raised, so a sentinel warning is
    # emitted; the record holding only that sentinel proves ConvModule
    # itself emitted nothing.
    with pytest.warns(UserWarning) as record:
        norm_cfg = dict(type='GN', num_groups=1)
        ConvModule(3, 8, 2, bias=True, norm_cfg=norm_cfg)
        warnings.warn('No warnings')
    assert len(record) == 1
    assert record[0].message.args[0] == 'No warnings'


# Stand-in ``forward`` methods patched onto the torch layers in test_order.
# Each appends a tag to its string input, so the final string records the
# order in which ConvModule invoked the layers.
def conv_forward(self, x):
    """Tag the input as having passed through the conv layer."""
    return x + '_conv'


def bn_forward(self, x):
    """Tag the input as having passed through the BN layer."""
    return x + '_bn'


def relu_forward(self, x):
    """Tag the input as having passed through the ReLU layer."""
    return x + '_relu'


@patch('torch.nn.ReLU.forward', relu_forward)
@patch('torch.nn.BatchNorm2d.forward', bn_forward)
@patch('torch.nn.Conv2d.forward', conv_forward)
def test_order():
    """Check ``order`` validation and the layer call order it produces."""
    with pytest.raises(AssertionError):
        # order must be a tuple
        order = ['conv', 'norm', 'act']
        ConvModule(3, 8, 2, order=order)

    with pytest.raises(AssertionError):
        # length of order must be 3
        order = ('conv', 'norm')
        ConvModule(3, 8, 2, order=order)

    with pytest.raises(AssertionError):
        # order must be an order of 'conv', 'norm', 'act'
        order = ('conv', 'norm', 'norm')
        ConvModule(3, 8, 2, order=order)

    with pytest.raises(AssertionError):
        # order must be an order of 'conv', 'norm', 'act'
        order = ('conv', 'norm', 'something')
        ConvModule(3, 8, 2, order=order)

    # ('conv', 'norm', 'act')
    conv = ConvModule(3, 8, 2, norm_cfg=dict(type='BN'))
    out = conv('input')
    assert out == 'input_conv_bn_relu'

    # ('norm', 'conv', 'act')
    conv = ConvModule(
        3, 8, 2, norm_cfg=dict(type='BN'), order=('norm', 'conv', 'act'))
    out = conv('input')
    assert out == 'input_bn_conv_relu'

    # ('conv', 'norm', 'act'), activate=False
    conv = ConvModule(3, 8, 2, norm_cfg=dict(type='BN'))
    out = conv('input', activate=False)
    assert out == 'input_conv_bn'

    # ('conv', 'norm', 'act'), norm=False
    conv = ConvModule(3, 8, 2, norm_cfg=dict(type='BN'))
    out = conv('input', norm=False)
    assert out == 'input_conv_relu'