Unverified commit 967d9d58 authored by Miao Zheng, committed by GitHub

[Refactoring] Add Sequential with init_weight (#884)

* [Refactoring] Add BaseSequential with init_weight

* revise according to comments

* revise comments

* minors

* baseseq2seq

* add modulelist

* revise minors

* fix isort

* format
parent 73bff4ea
# Copyright (c) Open-MMLab. All rights reserved.
from .base_module import BaseModule, ModuleList, Sequential
from .base_runner import BaseRunner
from .builder import RUNNERS, build_runner
from .checkpoint import (CheckpointLoader, _load_checkpoint,
...@@ -36,5 +36,6 @@ __all__ = [
    'set_random_seed', 'auto_fp16', 'force_fp32', 'wrap_fp16_model',
    'Fp16OptimizerHook', 'SyncBuffersHook', 'EMAHook', 'build_runner',
    'RUNNERS', 'allreduce_grads', 'allreduce_params', 'LossScaler',
    'CheckpointLoader', 'BaseModule', '_load_checkpoint_with_prefix',
    'Sequential', 'ModuleList'
]
...@@ -51,3 +51,34 @@ class BaseModule(nn.Module, metaclass=ABCMeta):
        else:
            warnings.warn('This module has been initialized, \
                please call initialize(module, init_cfg) to reinitialize it')

    def __repr__(self):
        s = super().__repr__()
        if hasattr(self, 'init_cfg'):
            s += f'\ninit_cfg={self.init_cfg}'
        return s


class Sequential(BaseModule, nn.Sequential):
    """Sequential module in OpenMMLab.

    Args:
        init_cfg (dict, optional): Initialization config dict.
    """

    def __init__(self, *args, init_cfg=None):
        BaseModule.__init__(self, init_cfg)
        nn.Sequential.__init__(self, *args)


class ModuleList(BaseModule, nn.ModuleList):
    """ModuleList in OpenMMLab.

    Args:
        modules (iterable, optional): An iterable of modules to add.
        init_cfg (dict, optional): Initialization config dict.
    """

    def __init__(self, modules=None, init_cfg=None):
        BaseModule.__init__(self, init_cfg)
        nn.ModuleList.__init__(self, modules)
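
With these containers, submodules can carry their own init_cfg and be initialized in a single call, and the new __repr__ surfaces the config. A minimal usage sketch, assuming a hypothetical ConvBlock wrapper and the Constant initializer exercised by the tests below:

from mmcv.runner import BaseModule, Sequential
from torch import nn

class ConvBlock(BaseModule):
    """Hypothetical wrapper; any BaseModule subclass works the same way."""

    def __init__(self, init_cfg=None):
        super().__init__(init_cfg)
        self.conv = nn.Conv1d(4, 1, 4)

seq = Sequential(
    ConvBlock(init_cfg=dict(type='Constant', layer='Conv1d', val=0., bias=1.)),
    ConvBlock(init_cfg=dict(type='Constant', layer='Conv1d', val=2., bias=3.)))
seq.init_weight()  # each child is initialized from its own init_cfg
print(seq)  # repr now ends with a line like: init_cfg=None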
import torch
from torch import nn

from mmcv.runner import BaseModule, ModuleList, Sequential
from mmcv.utils import Registry, build_from_cfg

COMPONENTS = Registry('component')
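
The FooConv1d and FooConv2d components used below are registered in the elided lines; as a hypothetical sketch (the layer sizes are assumptions), FooConv1d might look like:

@COMPONENTS.register_module()
class FooConv1d(BaseModule):
    """Hypothetical test component wrapping a single Conv1d layer."""

    def __init__(self, init_cfg=None):
        super().__init__(init_cfg)
        self.conv1d = nn.Conv1d(4, 1, 4)

    def forward(self, x):
        return self.conv1d(x)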
...@@ -226,3 +226,65 @@ def test_nest_components_weight_init():
    assert torch.equal(model.reg.weight,
                       torch.full(model.reg.weight.shape, 13.0))
    assert torch.equal(model.reg.bias, torch.full(model.reg.bias.shape, 14.0))

def test_sequential_model_weight_init():
    seq_model_cfg = [
        dict(
            type='FooConv1d', init_cfg=dict(type='Constant', val=0., bias=1.)),
        dict(
            type='FooConv2d', init_cfg=dict(type='Constant', val=2., bias=3.)),
    ]
    layers = [build_from_cfg(cfg, COMPONENTS) for cfg in seq_model_cfg]
    seq_model = Sequential(*layers)
    seq_model.init_weight()
    assert torch.equal(seq_model[0].conv1d.weight,
                       torch.full(seq_model[0].conv1d.weight.shape, 0.))
    assert torch.equal(seq_model[0].conv1d.bias,
                       torch.full(seq_model[0].conv1d.bias.shape, 1.))
    assert torch.equal(seq_model[1].conv2d.weight,
                       torch.full(seq_model[1].conv2d.weight.shape, 2.))
    assert torch.equal(seq_model[1].conv2d.bias,
                       torch.full(seq_model[1].conv2d.bias.shape, 3.))

    # inner init_cfg has higher priority than the outer one
    layers = [build_from_cfg(cfg, COMPONENTS) for cfg in seq_model_cfg]
    seq_model = Sequential(
        *layers, init_cfg=dict(type='Constant', val=4., bias=5.))
    seq_model.init_weight()
    assert torch.equal(seq_model[0].conv1d.weight,
                       torch.full(seq_model[0].conv1d.weight.shape, 0.))
    assert torch.equal(seq_model[0].conv1d.bias,
                       torch.full(seq_model[0].conv1d.bias.shape, 1.))
    assert torch.equal(seq_model[1].conv2d.weight,
                       torch.full(seq_model[1].conv2d.weight.shape, 2.))
    assert torch.equal(seq_model[1].conv2d.bias,
                       torch.full(seq_model[1].conv2d.bias.shape, 3.))

def test_modulelist_weight_init():
    models_cfg = [
        dict(
            type='FooConv1d', init_cfg=dict(type='Constant', val=0., bias=1.)),
        dict(
            type='FooConv2d', init_cfg=dict(type='Constant', val=2., bias=3.)),
    ]
    layers = [build_from_cfg(cfg, COMPONENTS) for cfg in models_cfg]
    modellist = ModuleList(layers)
    modellist.init_weight()
    assert torch.equal(modellist[0].conv1d.weight,
                       torch.full(modellist[0].conv1d.weight.shape, 0.))
    assert torch.equal(modellist[0].conv1d.bias,
                       torch.full(modellist[0].conv1d.bias.shape, 1.))
    assert torch.equal(modellist[1].conv2d.weight,
                       torch.full(modellist[1].conv2d.weight.shape, 2.))
    assert torch.equal(modellist[1].conv2d.bias,
                       torch.full(modellist[1].conv2d.bias.shape, 3.))

    # inner init_cfg has higher priority than the outer one
    layers = [build_from_cfg(cfg, COMPONENTS) for cfg in models_cfg]
    modellist = ModuleList(
        layers, init_cfg=dict(type='Constant', val=4., bias=5.))
    modellist.init_weight()
    assert torch.equal(modellist[0].conv1d.weight,
                       torch.full(modellist[0].conv1d.weight.shape, 0.))
    assert torch.equal(modellist[0].conv1d.bias,
                       torch.full(modellist[0].conv1d.bias.shape, 1.))
    assert torch.equal(modellist[1].conv2d.weight,
                       torch.full(modellist[1].conv2d.weight.shape, 2.))
    assert torch.equal(modellist[1].conv2d.bias,
                       torch.full(modellist[1].conv2d.bias.shape, 3.))
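
One consequence of the priority rule asserted above is that an outer init_cfg acts as a fallback for children that define none, while a child's own init_cfg wins. A rough sketch of one plausible reading, assuming a hypothetical Head wrapper:

from mmcv.runner import BaseModule, Sequential
from torch import nn

class Head(BaseModule):
    """Hypothetical block used only for this sketch."""

    def __init__(self, init_cfg=None):
        super().__init__(init_cfg)
        self.fc = nn.Linear(4, 4)

model = Sequential(
    Head(init_cfg=dict(type='Constant', layer='Linear', val=1., bias=0.)),
    Head(),  # no inner cfg: falls back to the outer one
    init_cfg=dict(type='Constant', layer='Linear', val=2., bias=0.))
model.init_weight()
# first head keeps val=1. (inner cfg wins); second head ends up with val=2.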