Unverified Commit 2fadb1a5 authored by Miao Zheng's avatar Miao Zheng Committed by GitHub
Browse files

[Fix] Revise override in init_cfg (#930)

* [Fix] Config deep copy in initialize_override

* add asserts&comments

* add test

* test org init_cfg

* test override without name

* typo
parent 375605fb
......@@ -376,13 +376,26 @@ def _initialize_override(module, override, cfg):
override = [override] if isinstance(override, dict) else override
for override_ in override:
if 'type' not in override_.keys():
override_.update(cfg)
name = override_.pop('name', None)
cp_override = copy.deepcopy(override_)
name = cp_override.pop('name', None)
if name is None:
raise ValueError('`override` must contain the key "name",'
f'but got {cp_override}')
# if override only has name kay, it means use args in init_cfg
if not cp_override:
cp_override.update(cfg)
# if override has name key and other args except type key, it will
# raise error
elif 'type' not in cp_override.keys():
raise ValueError(
f'`override` need "type" key, but got {cp_override}')
if hasattr(module, name):
_initialize(getattr(module, name), override_, wholemodule=True)
_initialize(getattr(module, name), cp_override, wholemodule=True)
else:
raise RuntimeError(f'module did not have attribute {name}')
raise RuntimeError(f'module did not have attribute {name}, '
f'but init_cfg is {cp_override}.')
def initialize(module, init_cfg):
......@@ -394,10 +407,9 @@ def initialize(module, init_cfg):
define initializer. OpenMMLab has implemented 6 initializers
including ``Constant``, ``Xavier``, ``Normal``, ``Uniform``,
``Kaiming``, and ``Pretrained``.
Example:
>>> module = nn.Linear(2, 3, bias=True)
>>> init_cfg = dict(type='Constant', val =1 , bias =2)
>>> init_cfg = dict(type='Constant', layer='Linear', val =1 , bias =2)
>>> initialize(module, init_cfg)
>>> module = nn.Sequential(nn.Conv1d(3, 1, 3), nn.Linear(1,2))
......@@ -407,11 +419,7 @@ def initialize(module, init_cfg):
dict(type='Constant', layer='Linear', val=2)]
>>> initialize(module, init_cfg)
>>> # Omitting ``'layer'`` initialize module with same configuration
>>> init_cfg = dict(type='Constant', val=1, bias=2)
>>> initialize(module, init_cfg)
>>> # define key``'override'`` to initialize some specific override in
>>> # define key``'override'`` to initialize some specific part in
>>> # module
>>> class FooNet(nn.Module):
>>> def __init__(self):
......@@ -420,7 +428,7 @@ def initialize(module, init_cfg):
>>> self.reg = nn.Conv2d(16, 10, 3)
>>> self.cls = nn.Conv2d(16, 5, 3)
>>> model = FooNet()
>>> init_cfg = dict(type='Constant', val=1, bias=2,
>>> init_cfg = dict(type='Constant', val=1, bias=2, layer='Conv2d',
>>> override=dict(type='Constant', name='reg', val=3, bias=4))
>>> initialize(model, init_cfg)
......
......@@ -266,13 +266,17 @@ def test_initialize():
model = nn.Sequential(nn.Conv2d(3, 1, 3), nn.ReLU(), nn.Linear(1, 2))
foonet = FooModule()
# test layer key
init_cfg = dict(type='Constant', layer=['Conv2d', 'Linear'], val=1, bias=2)
initialize(model, init_cfg)
assert torch.equal(model[0].weight, torch.full(model[0].weight.shape, 1.))
assert torch.equal(model[2].weight, torch.full(model[2].weight.shape, 1.))
assert torch.equal(model[0].bias, torch.full(model[0].bias.shape, 2.))
assert torch.equal(model[2].bias, torch.full(model[2].bias.shape, 2.))
assert init_cfg == dict(
type='Constant', layer=['Conv2d', 'Linear'], val=1, bias=2)
# test init_cfg with list type
init_cfg = [
dict(type='Constant', layer='Conv2d', val=1, bias=2),
dict(type='Constant', layer='Linear', val=3, bias=4)
......@@ -282,7 +286,12 @@ def test_initialize():
assert torch.equal(model[2].weight, torch.full(model[2].weight.shape, 3.))
assert torch.equal(model[0].bias, torch.full(model[0].bias.shape, 2.))
assert torch.equal(model[2].bias, torch.full(model[2].bias.shape, 4.))
assert init_cfg == [
dict(type='Constant', layer='Conv2d', val=1, bias=2),
dict(type='Constant', layer='Linear', val=3, bias=4)
]
# test layer key and override key
init_cfg = dict(
type='Constant',
val=1,
......@@ -302,6 +311,31 @@ def test_initialize():
torch.full(foonet.conv2d_2.weight.shape, 3.))
assert torch.equal(foonet.conv2d_2.bias,
torch.full(foonet.conv2d_2.bias.shape, 4.))
assert init_cfg == dict(
type='Constant',
val=1,
bias=2,
layer=['Conv2d', 'Linear'],
override=dict(type='Constant', name='conv2d_2', val=3, bias=4))
# test override key
init_cfg = dict(
type='Constant', val=5, bias=6, override=dict(name='conv2d_2'))
initialize(foonet, init_cfg)
assert not torch.equal(foonet.linear.weight,
torch.full(foonet.linear.weight.shape, 5.))
assert not torch.equal(foonet.linear.bias,
torch.full(foonet.linear.bias.shape, 6.))
assert not torch.equal(foonet.conv2d.weight,
torch.full(foonet.conv2d.weight.shape, 5.))
assert not torch.equal(foonet.conv2d.bias,
torch.full(foonet.conv2d.bias.shape, 6.))
assert torch.equal(foonet.conv2d_2.weight,
torch.full(foonet.conv2d_2.weight.shape, 5.))
assert torch.equal(foonet.conv2d_2.bias,
torch.full(foonet.conv2d_2.bias.shape, 6.))
assert init_cfg == dict(
type='Constant', val=5, bias=6, override=dict(name='conv2d_2'))
init_cfg = dict(
type='Pretrained',
......@@ -325,6 +359,11 @@ def test_initialize():
torch.full(foonet.conv2d_2.weight.shape, 3.))
assert torch.equal(foonet.conv2d_2.bias,
torch.full(foonet.conv2d_2.bias.shape, 4.))
assert init_cfg == dict(
type='Pretrained',
checkpoint='modelA.pth',
override=dict(type='Constant', name='conv2d_2', val=3, bias=4))
# test init_cfg type
with pytest.raises(TypeError):
init_cfg = 'init_cfg'
......@@ -362,3 +401,21 @@ def test_initialize():
dict(type='Constant', name='conv2d_3', val=5, bias=6)
])
initialize(foonet, init_cfg)
# test override with args except type key
with pytest.raises(ValueError):
init_cfg = dict(
type='Constant',
val=1,
bias=2,
override=dict(name='conv2d_2', val=3, bias=4))
initialize(foonet, init_cfg)
# test override without name
with pytest.raises(ValueError):
init_cfg = dict(
type='Constant',
val=1,
bias=2,
override=dict(type='Constant', val=3, bias=4))
initialize(foonet, init_cfg)
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment