Unverified Commit a7641639 authored by Miao Zheng's avatar Miao Zheng Committed by GitHub
Browse files

[Fix] Rename init_weight to init_weights (#971)

* [Fix] Rename init_weight to init_weights

* warning msg
parent ee041cec
......@@ -34,7 +34,7 @@ class BaseModule(nn.Module, metaclass=ABCMeta):
def is_init(self):
return self._is_init
def init_weight(self):
def init_weights(self):
"""Initialize the weights."""
from ..cnn import initialize
......@@ -42,11 +42,11 @@ class BaseModule(nn.Module, metaclass=ABCMeta):
if self.init_cfg:
initialize(self, self.init_cfg)
for m in self.children():
if hasattr(m, 'init_weight'):
m.init_weight()
if hasattr(m, 'init_weights'):
m.init_weights()
self._is_init = True
else:
warnings.warn(f'init_weight of {self.__class__.__name__} has '
warnings.warn(f'init_weights of {self.__class__.__name__} has '
f'been called more than once.')
def __repr__(self):
......
......@@ -119,7 +119,7 @@ def test_model_weight_init():
conv1d=dict(type='FooConv1d')))
model = build_from_cfg(model_cfg, FOOMODELS)
model.init_weight()
model.init_weights()
assert torch.equal(model.component1.conv1d.weight,
torch.full(model.component1.conv1d.weight.shape, 3.0))
......@@ -199,7 +199,7 @@ def test_nest_components_weight_init():
conv1d=dict(type='FooConv1d')))
model = build_from_cfg(model_cfg, FOOMODELS)
model.init_weight()
model.init_weights()
assert torch.equal(model.component1.conv1d.weight,
torch.full(model.component1.conv1d.weight.shape, 7.0))
......@@ -243,7 +243,7 @@ def test_without_layer_weight_init():
component2=dict(type='FooConv2d'),
component3=dict(type='FooLinear'))
model = build_from_cfg(model_cfg, FOOMODELS)
model.init_weight()
model.init_weights()
assert torch.equal(model.component1.conv1d.weight,
torch.full(model.component1.conv1d.weight.shape, 3.0))
......@@ -276,7 +276,7 @@ def test_override_weight_init():
component1=dict(type='FooConv1d'),
component3=dict(type='FooLinear'))
model = build_from_cfg(model_cfg, FOOMODELS)
model.init_weight()
model.init_weights()
assert torch.equal(model.reg.weight,
torch.full(model.reg.weight.shape, 10.0))
assert torch.equal(model.reg.bias, torch.full(model.reg.bias.shape, 20.0))
......@@ -308,7 +308,7 @@ def test_override_weight_init():
component2=dict(type='FooConv2d'),
component3=dict(type='FooLinear'))
model = build_from_cfg(model_cfg, FOOMODELS)
model.init_weight()
model.init_weights()
assert torch.equal(model.reg.weight,
torch.full(model.reg.weight.shape, 30.0))
......@@ -326,7 +326,7 @@ def test_sequential_model_weight_init():
]
layers = [build_from_cfg(cfg, COMPONENTS) for cfg in seq_model_cfg]
seq_model = Sequential(*layers)
seq_model.init_weight()
seq_model.init_weights()
assert torch.equal(seq_model[0].conv1d.weight,
torch.full(seq_model[0].conv1d.weight.shape, 0.))
assert torch.equal(seq_model[0].conv1d.bias,
......@@ -341,7 +341,7 @@ def test_sequential_model_weight_init():
*layers,
init_cfg=dict(
type='Constant', layer=['Conv1d', 'Conv2d'], val=4., bias=5.))
seq_model.init_weight()
seq_model.init_weights()
assert torch.equal(seq_model[0].conv1d.weight,
torch.full(seq_model[0].conv1d.weight.shape, 0.))
assert torch.equal(seq_model[0].conv1d.bias,
......@@ -363,7 +363,7 @@ def test_modulelist_weight_init():
]
layers = [build_from_cfg(cfg, COMPONENTS) for cfg in models_cfg]
modellist = ModuleList(layers)
modellist.init_weight()
modellist.init_weights()
assert torch.equal(modellist[0].conv1d.weight,
torch.full(modellist[0].conv1d.weight.shape, 0.))
assert torch.equal(modellist[0].conv1d.bias,
......@@ -378,7 +378,7 @@ def test_modulelist_weight_init():
layers,
init_cfg=dict(
type='Constant', layer=['Conv1d', 'Conv2d'], val=4., bias=5.))
modellist.init_weight()
modellist.init_weights()
assert torch.equal(modellist[0].conv1d.weight,
torch.full(modellist[0].conv1d.weight.shape, 0.))
assert torch.equal(modellist[0].conv1d.bias,
......
Markdown is supported
Attach a file by drag &amp; drop or click to upload.
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment.