Unverified Commit 50c255bc, authored by Miao Zheng and committed by GitHub

[Feature] Support to use name of the base classes in init_cfg (#1057)

* [Fix] Support names of base classes matching in init_cfg

* revise bool to len
parent bf2c9fa8
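In practice, this lets an init_cfg entry name a shared base class instead of enumerating every concrete layer. A minimal before/after sketch, assuming mmcv's usual dict-style init_cfg (note that _ConvNd is the private torch.nn base class shared by Conv1d, Conv2d, and Conv3d):

# Before this commit: each concrete conv class had to be listed explicitly.
init_cfg = dict(type='Constant', val=4., bias=5., layer=['Conv1d', 'Conv2d', 'Conv3d'])

# After: one base-class name covers every direct subclass of _ConvNd.
init_cfg = dict(type='Constant', val=4., bias=5., layer='_ConvNd')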
@@ -93,6 +93,10 @@ def bias_init_with_prob(prior_prob):
     return bias_init
 
 
+def _get_bases_name(m):
+    return [b.__name__ for b in m.__class__.__bases__]
+
+
 class BaseInit(object):
 
     def __init__(self, *, bias=0, bias_prob=None, layer=None):
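The new helper inspects only the direct bases of a module's class (__bases__), not the full MRO, so the named class must be an immediate parent to match. A quick standalone check of what it returns for the layers used in the tests below:

import torch.nn as nn

def _get_bases_name(m):
    return [b.__name__ for b in m.__class__.__bases__]

print(_get_bases_name(nn.Conv2d(3, 1, 3)))  # ['_ConvNd'] -- direct base of Conv2d
print(_get_bases_name(nn.Linear(4, 2)))     # ['Module']  -- Linear subclasses nn.Module directly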
@@ -146,7 +150,8 @@ class ConstantInit(BaseInit):
                 constant_init(m, self.val, self.bias)
             else:
                 layername = m.__class__.__name__
-                if layername in self.layer:
+                basesname = _get_bases_name(m)
+                if len(set(self.layer) & set([layername] + basesname)):
                     constant_init(m, self.val, self.bias)
 
         module.apply(init)
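The replaced membership test now matches either the class name itself or any direct base name via a set intersection; per the "revise bool to len" commit note, len() of the intersection serves as the truth value (non-empty means at least one match). A minimal reproduction of the predicate:

layer = ['_ConvNd']       # names requested in init_cfg
layername = 'Conv2d'      # m.__class__.__name__
basesname = ['_ConvNd']   # _get_bases_name(m)

# A non-empty intersection (len >= 1) is truthy, so Conv2d matches '_ConvNd'.
assert len(set(layer) & set([layername] + basesname))

The same two-line change is repeated in each of the init classes below.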
@@ -183,7 +188,8 @@ class XavierInit(BaseInit):
                 xavier_init(m, self.gain, self.bias, self.distribution)
             else:
                 layername = m.__class__.__name__
-                if layername in self.layer:
+                basesname = _get_bases_name(m)
+                if len(set(self.layer) & set([layername] + basesname)):
                     xavier_init(m, self.gain, self.bias, self.distribution)
 
         module.apply(init)
@@ -219,8 +225,8 @@ class NormalInit(BaseInit):
                 normal_init(m, self.mean, self.std, self.bias)
             else:
                 layername = m.__class__.__name__
-                for layer_ in self.layer:
-                    if layername == layer_:
-                        normal_init(m, self.mean, self.std, self.bias)
+                basesname = _get_bases_name(m)
+                if len(set(self.layer) & set([layername] + basesname)):
+                    normal_init(m, self.mean, self.std, self.bias)
 
         module.apply(init)
@@ -267,10 +273,10 @@ class TruncNormalInit(BaseInit):
                                   self.bias)
             else:
                 layername = m.__class__.__name__
-                for layer_ in self.layer:
-                    if layername == layer_:
-                        trunc_normal_init(m, self.mean, self.std, self.a,
-                                          self.b, self.bias)
+                basesname = _get_bases_name(m)
+                if len(set(self.layer) & set([layername] + basesname)):
+                    trunc_normal_init(m, self.mean, self.std, self.a, self.b,
+                                      self.bias)
 
         module.apply(init)
@@ -305,7 +311,8 @@ class UniformInit(BaseInit):
                 uniform_init(m, self.a, self.b, self.bias)
             else:
                 layername = m.__class__.__name__
-                if layername in self.layer:
+                basesname = _get_bases_name(m)
+                if len(set(self.layer) & set([layername] + basesname)):
                     uniform_init(m, self.a, self.b, self.bias)
 
         module.apply(init)
@@ -359,7 +366,8 @@ class KaimingInit(BaseInit):
                              self.bias, self.distribution)
             else:
                 layername = m.__class__.__name__
-                if layername in self.layer:
+                basesname = _get_bases_name(m)
+                if len(set(self.layer) & set([layername] + basesname)):
                     kaiming_init(m, self.a, self.mode, self.nonlinearity,
                                  self.bias, self.distribution)
...
@@ -134,6 +134,15 @@ def test_constaninit():
     assert torch.equal(model[0].bias, torch.full(model[0].bias.shape, 2.))
     assert torch.equal(model[2].bias, torch.full(model[2].bias.shape, res))
 
+    # test layer key with base class name
+    model = nn.Sequential(nn.Conv2d(3, 1, 3), nn.ReLU(), nn.Conv1d(1, 2, 1))
+    func = ConstantInit(val=4., bias=5., layer='_ConvNd')
+    func(model)
+    assert torch.all(model[0].weight == 4.)
+    assert torch.all(model[2].weight == 4.)
+    assert torch.all(model[0].bias == 5.)
+    assert torch.all(model[2].bias == 5.)
+
     # test bias input type
     with pytest.raises(TypeError):
         func = ConstantInit(val=1, bias='1')
@@ -170,6 +179,22 @@ def test_xavierinit():
     assert torch.equal(model[0].bias, torch.full(model[0].bias.shape, res))
     assert torch.equal(model[2].bias, torch.full(model[2].bias.shape, res))
 
+    # test layer key with base class name
+    model = nn.Sequential(nn.Conv2d(3, 1, 3), nn.ReLU(), nn.Conv1d(1, 2, 1))
+    func = ConstantInit(val=4., bias=5., layer='_ConvNd')
+    func(model)
+    assert torch.all(model[0].weight == 4.)
+    assert torch.all(model[2].weight == 4.)
+    assert torch.all(model[0].bias == 5.)
+    assert torch.all(model[2].bias == 5.)
+
+    func = XavierInit(gain=100, bias_prob=0.01, layer='_ConvNd')
+    func(model)
+    assert not torch.all(model[0].weight == 4.)
+    assert not torch.all(model[2].weight == 4.)
+    assert torch.all(model[0].bias == res)
+    assert torch.all(model[2].bias == res)
+
     # test bias input type
     with pytest.raises(TypeError):
         func = XavierInit(bias='0.1', layer='Conv2d')
@@ -198,6 +223,16 @@ def test_normalinit():
     assert model[0].bias.allclose(torch.tensor(res))
     assert model[2].bias.allclose(torch.tensor(res))
 
+    # test layer key with base class name
+    model = nn.Sequential(nn.Conv2d(3, 1, 3), nn.ReLU(), nn.Conv1d(1, 2, 1))
+    func = NormalInit(mean=300, std=1e-5, bias_prob=0.01, layer='_ConvNd')
+    func(model)
+    assert model[0].weight.allclose(torch.tensor(300.))
+    assert model[2].weight.allclose(torch.tensor(300.))
+    assert torch.all(model[0].bias == res)
+    assert torch.all(model[2].bias == res)
+
 
 def test_truncnormalinit():
     """test TruncNormalInit class."""
@@ -225,6 +260,17 @@ def test_truncnormalinit():
     assert model[0].bias.allclose(torch.tensor(res))
     assert model[2].bias.allclose(torch.tensor(res))
 
+    # test layer key with base class name
+    model = nn.Sequential(nn.Conv2d(3, 1, 3), nn.ReLU(), nn.Conv1d(1, 2, 1))
+    func = TruncNormalInit(
+        mean=300, std=1e-5, a=100, b=400, bias_prob=0.01, layer='_ConvNd')
+    func(model)
+    assert model[0].weight.allclose(torch.tensor(300.))
+    assert model[2].weight.allclose(torch.tensor(300.))
+    assert torch.all(model[0].bias == res)
+    assert torch.all(model[2].bias == res)
+
 
 def test_uniforminit():
     """test UniformInit class."""
@@ -245,6 +291,17 @@ def test_uniforminit():
     assert torch.equal(model[0].bias, torch.full(model[0].bias.shape, 10.))
     assert torch.equal(model[2].bias, torch.full(model[2].bias.shape, 10.))
 
+    # test layer key with base class name
+    model = nn.Sequential(nn.Conv2d(3, 1, 3), nn.ReLU(), nn.Conv1d(1, 2, 1))
+    func = UniformInit(a=100, b=100, bias_prob=0.01, layer='_ConvNd')
+    res = bias_init_with_prob(0.01)
+    func(model)
+    assert torch.all(model[0].weight == 100.)
+    assert torch.all(model[2].weight == 100.)
+    assert torch.all(model[0].bias == res)
+    assert torch.all(model[2].bias == res)
+
 
 def test_kaiminginit():
     """test KaimingInit class."""
@@ -270,6 +327,29 @@ def test_kaiminginit():
     assert torch.equal(model[0].bias, torch.full(model[0].bias.shape, 10.))
     assert torch.equal(model[2].bias, torch.full(model[2].bias.shape, 10.))
 
+    # test layer key with base class name
+    model = nn.Sequential(nn.Conv2d(3, 1, 3), nn.ReLU(), nn.Conv1d(1, 2, 1))
+    func = KaimingInit(bias=0.1, layer='_ConvNd')
+    func(model)
+    assert torch.all(model[0].bias == 0.1)
+    assert torch.all(model[2].bias == 0.1)
+
+    func = KaimingInit(a=100, bias=10, layer='_ConvNd')
+    constant_func = ConstantInit(val=0, bias=0, layer='_ConvNd')
+    model.apply(constant_func)
+    assert torch.equal(model[0].weight, torch.full(model[0].weight.shape, 0.))
+    assert torch.equal(model[2].weight, torch.full(model[2].weight.shape, 0.))
+    assert torch.equal(model[0].bias, torch.full(model[0].bias.shape, 0.))
+    assert torch.equal(model[2].bias, torch.full(model[2].bias.shape, 0.))
+
+    func(model)
+    assert not torch.equal(model[0].weight,
+                           torch.full(model[0].weight.shape, 0.))
+    assert not torch.equal(model[2].weight,
+                           torch.full(model[2].weight.shape, 0.))
+    assert torch.equal(model[0].bias, torch.full(model[0].bias.shape, 10.))
+    assert torch.equal(model[2].bias, torch.full(model[2].bias.shape, 10.))
+
 
 def test_caffe2xavierinit():
     """test Caffe2XavierInit."""
...