Unverified commit 25d07e3d authored by Jintao Lin, committed by GitHub

Support conv layers' own `init_weights` method (#278)

* Support conv layers' own `init_weights` method

* Add a related unit test for ConvModule's `init_weights`

* Make the comments more specific
parent c35c228c
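In short, the change makes `ConvModule.init_weights` skip its default Kaiming initialization whenever the underlying conv layer brings its own `init_weights` method. A minimal sketch of that dispatch, with a hypothetical helper name (the real logic is inlined in `ConvModule.init_weights`, as the diff below shows):

import torch.nn as nn
from mmcv.cnn import kaiming_init


def maybe_default_init(conv: nn.Module) -> None:
    # Hypothetical standalone helper mirroring the new ConvModule logic.
    if not hasattr(conv, 'init_weights'):
        # Plain layers such as nn.Conv2d still get the default init.
        kaiming_init(conv)
    # A layer that defines `init_weights` is left untouched: it is
    # expected to have initialized itself, e.g. by calling the method
    # in its own __init__, as the test's ExampleConv does.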
@@ -150,14 +150,22 @@ class ConvModule(nn.Module):
         return getattr(self, self.norm_name)
 
     def init_weights(self):
-        if self.with_activation and self.act_cfg['type'] == 'LeakyReLU':
-            nonlinearity = 'leaky_relu'
-            a = self.act_cfg.get('negative_slope', 0.01)
-        else:
-            nonlinearity = 'relu'
-            a = 0
-
-        kaiming_init(self.conv, a=a, nonlinearity=nonlinearity)
+        # 1. This check is mainly for customized conv layers with their
+        #    own initialization manners, where we do not want ConvModule
+        #    to override the initialization.
+        # 2. Customized conv layers without their own initialization
+        #    manners will be initialized here with the default
+        #    `kaiming_init`.
+        # 3. PyTorch's built-in conv layers are initialized anyway by
+        #    their own `reset_parameters` methods.
+        if not hasattr(self.conv, 'init_weights'):
+            if self.with_activation and self.act_cfg['type'] == 'LeakyReLU':
+                nonlinearity = 'leaky_relu'
+                a = self.act_cfg.get('negative_slope', 0.01)
+            else:
+                nonlinearity = 'relu'
+                a = 0
+            kaiming_init(self.conv, a=a, nonlinearity=nonlinearity)
 
         if self.with_norm:
             constant_init(self.norm, 1, bias=0)
...
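The third comment case leans on a PyTorch guarantee worth spelling out: built-in conv layers initialize themselves in their constructors via `reset_parameters()`, so although they take the `kaiming_init` branch above (they define no `init_weights`), they are never left uninitialized. A quick illustrative check:

import torch.nn as nn

conv = nn.Conv2d(3, 8, 3)
# PyTorch conv layers self-initialize in __init__ via reset_parameters().
assert hasattr(conv, 'reset_parameters')
# They expose no `init_weights`, so ConvModule still applies the default
# kaiming_init on top of that built-in initialization.
assert not hasattr(conv, 'init_weights')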
@@ -4,7 +4,44 @@
 import pytest
 import torch
 import torch.nn as nn
 
-from mmcv.cnn.bricks import ConvModule
+from mmcv.cnn.bricks import CONV_LAYERS, ConvModule
+
+
+@CONV_LAYERS.register_module()
+class ExampleConv(nn.Module):
+
+    def __init__(self,
+                 in_channels,
+                 out_channels,
+                 kernel_size,
+                 stride=1,
+                 padding=0,
+                 dilation=1,
+                 groups=1,
+                 bias=True,
+                 norm_cfg=None):
+        super(ExampleConv, self).__init__()
+        self.in_channels = in_channels
+        self.out_channels = out_channels
+        self.kernel_size = kernel_size
+        self.stride = stride
+        self.padding = padding
+        self.dilation = dilation
+        self.groups = groups
+        self.bias = bias
+        self.norm_cfg = norm_cfg
+        self.output_padding = (0, 0, 0)
+        self.transposed = False
+        self.conv0 = nn.Conv2d(in_channels, out_channels, kernel_size)
+        self.init_weights()
+
+    def forward(self, x):
+        x = self.conv0(x)
+        return x
+
+    def init_weights(self):
+        nn.init.constant_(self.conv0.weight, 0)
+
 
 def test_conv_module():
@@ -53,6 +90,11 @@ def test_conv_module():
     output = conv(x)
     assert output.shape == (1, 8, 255, 255)
 
+    # conv with its own `init_weights` method
+    conv_module = ConvModule(
+        3, 8, 2, conv_cfg=dict(type='ExampleConv'), act_cfg=None)
+    assert torch.equal(conv_module.conv.conv0.weight, torch.zeros(8, 3, 2, 2))
+
     # with_spectral_norm=True
     conv = ConvModule(3, 8, 3, padding=1, with_spectral_norm=True)
     assert hasattr(conv.conv, 'weight_orig')
...
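For context on how the test's `conv_cfg=dict(type='ExampleConv')` reaches the class registered above: ConvModule builds its conv layer through the `CONV_LAYERS` registry, where `type` selects the registered class and the remaining arguments are forwarded to its constructor. A hedged usage sketch, assuming mmcv's `build_conv_layer` helper and that `ExampleConv` has already been registered as in the test file:

from mmcv.cnn import build_conv_layer

# 'type' picks the registered class; 3, 8, 2 are forwarded as
# in_channels, out_channels and kernel_size.
conv = build_conv_layer(dict(type='ExampleConv'), 3, 8, 2)
assert hasattr(conv, 'init_weights')  # so ConvModule skips kaiming_init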