Unverified Commit a7641639 authored by Miao Zheng's avatar Miao Zheng Committed by GitHub
Browse files

[Fix] Rename init_weight to init_weights (#971)

* [Fix] Rename init_weight to init_weights

* warning msg
parent ee041cec
...@@ -34,7 +34,7 @@ class BaseModule(nn.Module, metaclass=ABCMeta): ...@@ -34,7 +34,7 @@ class BaseModule(nn.Module, metaclass=ABCMeta):
def is_init(self): def is_init(self):
return self._is_init return self._is_init
def init_weight(self): def init_weights(self):
"""Initialize the weights.""" """Initialize the weights."""
from ..cnn import initialize from ..cnn import initialize
...@@ -42,11 +42,11 @@ class BaseModule(nn.Module, metaclass=ABCMeta): ...@@ -42,11 +42,11 @@ class BaseModule(nn.Module, metaclass=ABCMeta):
if self.init_cfg: if self.init_cfg:
initialize(self, self.init_cfg) initialize(self, self.init_cfg)
for m in self.children(): for m in self.children():
if hasattr(m, 'init_weight'): if hasattr(m, 'init_weights'):
m.init_weight() m.init_weights()
self._is_init = True self._is_init = True
else: else:
warnings.warn(f'init_weight of {self.__class__.__name__} has ' warnings.warn(f'init_weights of {self.__class__.__name__} has '
f'been called more than once.') f'been called more than once.')
def __repr__(self): def __repr__(self):
......
...@@ -119,7 +119,7 @@ def test_model_weight_init(): ...@@ -119,7 +119,7 @@ def test_model_weight_init():
conv1d=dict(type='FooConv1d'))) conv1d=dict(type='FooConv1d')))
model = build_from_cfg(model_cfg, FOOMODELS) model = build_from_cfg(model_cfg, FOOMODELS)
model.init_weight() model.init_weights()
assert torch.equal(model.component1.conv1d.weight, assert torch.equal(model.component1.conv1d.weight,
torch.full(model.component1.conv1d.weight.shape, 3.0)) torch.full(model.component1.conv1d.weight.shape, 3.0))
...@@ -199,7 +199,7 @@ def test_nest_components_weight_init(): ...@@ -199,7 +199,7 @@ def test_nest_components_weight_init():
conv1d=dict(type='FooConv1d'))) conv1d=dict(type='FooConv1d')))
model = build_from_cfg(model_cfg, FOOMODELS) model = build_from_cfg(model_cfg, FOOMODELS)
model.init_weight() model.init_weights()
assert torch.equal(model.component1.conv1d.weight, assert torch.equal(model.component1.conv1d.weight,
torch.full(model.component1.conv1d.weight.shape, 7.0)) torch.full(model.component1.conv1d.weight.shape, 7.0))
...@@ -243,7 +243,7 @@ def test_without_layer_weight_init(): ...@@ -243,7 +243,7 @@ def test_without_layer_weight_init():
component2=dict(type='FooConv2d'), component2=dict(type='FooConv2d'),
component3=dict(type='FooLinear')) component3=dict(type='FooLinear'))
model = build_from_cfg(model_cfg, FOOMODELS) model = build_from_cfg(model_cfg, FOOMODELS)
model.init_weight() model.init_weights()
assert torch.equal(model.component1.conv1d.weight, assert torch.equal(model.component1.conv1d.weight,
torch.full(model.component1.conv1d.weight.shape, 3.0)) torch.full(model.component1.conv1d.weight.shape, 3.0))
...@@ -276,7 +276,7 @@ def test_override_weight_init(): ...@@ -276,7 +276,7 @@ def test_override_weight_init():
component1=dict(type='FooConv1d'), component1=dict(type='FooConv1d'),
component3=dict(type='FooLinear')) component3=dict(type='FooLinear'))
model = build_from_cfg(model_cfg, FOOMODELS) model = build_from_cfg(model_cfg, FOOMODELS)
model.init_weight() model.init_weights()
assert torch.equal(model.reg.weight, assert torch.equal(model.reg.weight,
torch.full(model.reg.weight.shape, 10.0)) torch.full(model.reg.weight.shape, 10.0))
assert torch.equal(model.reg.bias, torch.full(model.reg.bias.shape, 20.0)) assert torch.equal(model.reg.bias, torch.full(model.reg.bias.shape, 20.0))
...@@ -308,7 +308,7 @@ def test_override_weight_init(): ...@@ -308,7 +308,7 @@ def test_override_weight_init():
component2=dict(type='FooConv2d'), component2=dict(type='FooConv2d'),
component3=dict(type='FooLinear')) component3=dict(type='FooLinear'))
model = build_from_cfg(model_cfg, FOOMODELS) model = build_from_cfg(model_cfg, FOOMODELS)
model.init_weight() model.init_weights()
assert torch.equal(model.reg.weight, assert torch.equal(model.reg.weight,
torch.full(model.reg.weight.shape, 30.0)) torch.full(model.reg.weight.shape, 30.0))
...@@ -326,7 +326,7 @@ def test_sequential_model_weight_init(): ...@@ -326,7 +326,7 @@ def test_sequential_model_weight_init():
] ]
layers = [build_from_cfg(cfg, COMPONENTS) for cfg in seq_model_cfg] layers = [build_from_cfg(cfg, COMPONENTS) for cfg in seq_model_cfg]
seq_model = Sequential(*layers) seq_model = Sequential(*layers)
seq_model.init_weight() seq_model.init_weights()
assert torch.equal(seq_model[0].conv1d.weight, assert torch.equal(seq_model[0].conv1d.weight,
torch.full(seq_model[0].conv1d.weight.shape, 0.)) torch.full(seq_model[0].conv1d.weight.shape, 0.))
assert torch.equal(seq_model[0].conv1d.bias, assert torch.equal(seq_model[0].conv1d.bias,
...@@ -341,7 +341,7 @@ def test_sequential_model_weight_init(): ...@@ -341,7 +341,7 @@ def test_sequential_model_weight_init():
*layers, *layers,
init_cfg=dict( init_cfg=dict(
type='Constant', layer=['Conv1d', 'Conv2d'], val=4., bias=5.)) type='Constant', layer=['Conv1d', 'Conv2d'], val=4., bias=5.))
seq_model.init_weight() seq_model.init_weights()
assert torch.equal(seq_model[0].conv1d.weight, assert torch.equal(seq_model[0].conv1d.weight,
torch.full(seq_model[0].conv1d.weight.shape, 0.)) torch.full(seq_model[0].conv1d.weight.shape, 0.))
assert torch.equal(seq_model[0].conv1d.bias, assert torch.equal(seq_model[0].conv1d.bias,
...@@ -363,7 +363,7 @@ def test_modulelist_weight_init(): ...@@ -363,7 +363,7 @@ def test_modulelist_weight_init():
] ]
layers = [build_from_cfg(cfg, COMPONENTS) for cfg in models_cfg] layers = [build_from_cfg(cfg, COMPONENTS) for cfg in models_cfg]
modellist = ModuleList(layers) modellist = ModuleList(layers)
modellist.init_weight() modellist.init_weights()
assert torch.equal(modellist[0].conv1d.weight, assert torch.equal(modellist[0].conv1d.weight,
torch.full(modellist[0].conv1d.weight.shape, 0.)) torch.full(modellist[0].conv1d.weight.shape, 0.))
assert torch.equal(modellist[0].conv1d.bias, assert torch.equal(modellist[0].conv1d.bias,
...@@ -378,7 +378,7 @@ def test_modulelist_weight_init(): ...@@ -378,7 +378,7 @@ def test_modulelist_weight_init():
layers, layers,
init_cfg=dict( init_cfg=dict(
type='Constant', layer=['Conv1d', 'Conv2d'], val=4., bias=5.)) type='Constant', layer=['Conv1d', 'Conv2d'], val=4., bias=5.))
modellist.init_weight() modellist.init_weights()
assert torch.equal(modellist[0].conv1d.weight, assert torch.equal(modellist[0].conv1d.weight,
torch.full(modellist[0].conv1d.weight.shape, 0.)) torch.full(modellist[0].conv1d.weight.shape, 0.))
assert torch.equal(modellist[0].conv1d.bias, assert torch.equal(modellist[0].conv1d.bias,
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment