You need to sign in or sign up before continuing.
Unverified Commit 18c64d5f authored by Shilong Zhang's avatar Shilong Zhang Committed by GitHub
Browse files

Fix potential bugs of BaseModule when recording the initialization information (#1217)

* add logger for init

* change init_info of overloaded init_weights

* add judgement for params_init_info

* add delete comments for params_init_info

* add docstr and more comments

* add docstr and more comments

* resolve comments

* dump to a file

* add unit test

* fix unitest

* fix unitest

* write to ori log

* fix typo

* resolve comments

* fix call initweights twice in topmost module

* fix the potential bug of recursive import

* fix unitest

* fix potential bugs

* remove unnecessary change

* add more unitest

* fix add param in initweights

* add more comments

* raise error

* add more detail assert error
parent b5a4bbd0
...@@ -8,12 +8,43 @@ import torch ...@@ -8,12 +8,43 @@ import torch
import torch.nn as nn import torch.nn as nn
from torch import Tensor from torch import Tensor
from mmcv.runner.base_module import update_init_info
from mmcv.utils import Registry, build_from_cfg, get_logger, print_log from mmcv.utils import Registry, build_from_cfg, get_logger, print_log
INITIALIZERS = Registry('initializer') INITIALIZERS = Registry('initializer')
def update_init_info(module, init_info):
    """Refresh the recorded initialization info for changed parameters.

    Compares the current mean of every parameter in ``module`` against the
    snapshot stored in ``module._params_init_info``; whenever the mean has
    moved, the parameter's ``init_info`` string and cached mean are updated.

    Args:
        module (obj:`nn.Module`): The module of PyTorch with a user-defined
            attribute `_params_init_info` which records the initialization
            information.
        init_info (str): The string that describes the initialization.
    """
    assert hasattr(
        module,
        '_params_init_info'), f'Can not find `_params_init_info` in {module}'

    records = module._params_init_info
    for name, param in module.named_parameters():
        # Parameters must not appear or be replaced mid-initialization;
        # every parameter has to be registered before `init_weights` runs.
        assert param in records, (
            f'Find a new :obj:`Parameter` named `{name}` during executing the '
            f'`init_weights` of `{type(module).__name__}`. Please do not add '
            f'or replace parameters during executing the `init_weights`. ')

        # A shifted mean indicates the parameter was modified by the
        # current `init_weights` call, so attribute it to `init_info`.
        current_mean = param.data.mean()
        entry = records[param]
        if entry['tmp_mean_value'] != current_mean:
            entry['init_info'] = init_info
            entry['tmp_mean_value'] = current_mean
def constant_init(module, val, bias=0): def constant_init(module, val, bias=0):
if hasattr(module, 'weight') and module.weight is not None: if hasattr(module, 'weight') and module.weight is not None:
nn.init.constant_(module.weight, val) nn.init.constant_(module.weight, val)
......
...@@ -11,23 +11,6 @@ from mmcv.runner.dist_utils import master_only ...@@ -11,23 +11,6 @@ from mmcv.runner.dist_utils import master_only
from mmcv.utils.logging import get_logger, logger_initialized, print_log from mmcv.utils.logging import get_logger, logger_initialized, print_log
def update_init_info(module, *, init_info):
"""Update the `_params_init_info` in the module if the value of parameters
are changed.
Args:
module (obj:`nn.Module`): The module of PyTorch with a user-defined
attribute `_params_init_info` which records the initialization
information.
init_info (str): The string that describes the initialization.
"""
for param in module.parameters():
mean_value = param.data.mean()
if module._params_init_info[param]['tmp_mean_value'] != mean_value:
module._params_init_info[param]['init_info'] = init_info
module._params_init_info[param]['tmp_mean_value'] = mean_value
class BaseModule(nn.Module, metaclass=ABCMeta): class BaseModule(nn.Module, metaclass=ABCMeta):
"""Base module for all modules in openmmlab. """Base module for all modules in openmmlab.
...@@ -36,11 +19,12 @@ class BaseModule(nn.Module, metaclass=ABCMeta): ...@@ -36,11 +19,12 @@ class BaseModule(nn.Module, metaclass=ABCMeta):
``torch.nn.Module``, ``BaseModule`` mainly adds three attributes. ``torch.nn.Module``, ``BaseModule`` mainly adds three attributes.
- ``init_cfg``: the config to control the initialization. - ``init_cfg``: the config to control the initialization.
- ``_params_init_info``: Used to track the parameter
initialization information.
- ``init_weights``: The function of parameter - ``init_weights``: The function of parameter
initialization and recording initialization initialization and recording initialization
information. information.
- ``_params_init_info``: Used to track the parameter
initialization information. This attribute only
exists during executing the ``init_weights``.
Args: Args:
init_cfg (dict, optional): Initialization config dict. init_cfg (dict, optional): Initialization config dict.
...@@ -59,17 +43,6 @@ class BaseModule(nn.Module, metaclass=ABCMeta): ...@@ -59,17 +43,6 @@ class BaseModule(nn.Module, metaclass=ABCMeta):
self.init_cfg = copy.deepcopy(init_cfg) self.init_cfg = copy.deepcopy(init_cfg)
# The `_params_init_info` is used to record the initialization
# information of the parameters
# the key should be the obj:`nn.Parameter` of model and the value
# should be a dict containing
# - param_name (str): The name of parameter.
# - init_info (str): The string that describes the initialization.
# - tmp_mean_value (FloatTensor): The mean of the parameter,
# which indicates whether the parameter has been modified.
# this attribute would be deleted after all parameters is initialized.
self._params_init_info = defaultdict(dict)
# Backward compatibility in derived classes # Backward compatibility in derived classes
# if pretrained is not None: # if pretrained is not None:
# warnings.warn('DeprecationWarning: pretrained is a deprecated \ # warnings.warn('DeprecationWarning: pretrained is a deprecated \
...@@ -83,15 +56,26 @@ class BaseModule(nn.Module, metaclass=ABCMeta): ...@@ -83,15 +56,26 @@ class BaseModule(nn.Module, metaclass=ABCMeta):
def init_weights(self): def init_weights(self):
"""Initialize the weights.""" """Initialize the weights."""
is_top_level_module = False
# check if it is top-level module # check if it is top-level module
is_top_level_module = len(self._params_init_info) == 0 if not hasattr(self, '_params_init_info'):
if is_top_level_module: # The `_params_init_info` is used to record the initialization
# information of the parameters
# the key should be the obj:`nn.Parameter` of model and the value
# should be a dict containing
# - init_info (str): The string that describes the initialization.
# - tmp_mean_value (FloatTensor): The mean of the parameter,
# which indicates whether the parameter has been modified.
# this attribute would be deleted after all parameters
# is initialized.
self._params_init_info = defaultdict(dict)
is_top_level_module = True
# Initialize the `_params_init_info`, # Initialize the `_params_init_info`,
# When detecting the `tmp_mean_value` of # When detecting the `tmp_mean_value` of
# the corresponding parameter is changed, update related # the corresponding parameter is changed, update related
# initialization information # initialization information
for name, param in self.named_parameters(): for name, param in self.named_parameters():
self._params_init_info[param]['param_name'] = name
self._params_init_info[param][ self._params_init_info[param][
'init_info'] = f'The value is the same before and ' \ 'init_info'] = f'The value is the same before and ' \
f'after calling `init_weights` ' \ f'after calling `init_weights` ' \
...@@ -112,6 +96,7 @@ class BaseModule(nn.Module, metaclass=ABCMeta): ...@@ -112,6 +96,7 @@ class BaseModule(nn.Module, metaclass=ABCMeta):
logger_name = logger_names[0] if logger_names else 'mmcv' logger_name = logger_names[0] if logger_names else 'mmcv'
from ..cnn import initialize from ..cnn import initialize
from ..cnn.utils.weight_init import update_init_info
module_name = self.__class__.__name__ module_name = self.__class__.__name__
if not self._is_init: if not self._is_init:
if self.init_cfg: if self.init_cfg:
...@@ -165,15 +150,17 @@ class BaseModule(nn.Module, metaclass=ABCMeta): ...@@ -165,15 +150,17 @@ class BaseModule(nn.Module, metaclass=ABCMeta):
if isinstance(handler, FileHandler): if isinstance(handler, FileHandler):
handler.stream.write( handler.stream.write(
'Name of parameter - Initialization information\n') 'Name of parameter - Initialization information\n')
for item in list(self._params_init_info.values()): for name, param in self.named_parameters():
handler.stream.write( handler.stream.write(
f"{item['param_name']} - {item['init_info']} \n") f'\n{name} - {param.shape}: '
f"\n{self._params_init_info[param]['init_info']} \n")
handler.stream.flush() handler.stream.flush()
with_file_handler = True with_file_handler = True
if not with_file_handler: if not with_file_handler:
for item in list(self._params_init_info.values()): for name, param in self.named_parameters():
print_log( print_log(
f"{item['param_name']} - {item['init_info']}", f'\n{name} - {param.shape}: '
f"\n{self._params_init_info[param]['init_info']} \n ",
logger=logger_name) logger=logger_name)
def __repr__(self): def __repr__(self):
......
import tempfile import tempfile
import pytest
import torch import torch
from torch import nn from torch import nn
import mmcv
from mmcv.cnn.utils.weight_init import update_init_info
from mmcv.runner import BaseModule, ModuleList, Sequential from mmcv.runner import BaseModule, ModuleList, Sequential
from mmcv.runner.base_module import update_init_info
from mmcv.utils import Registry, build_from_cfg from mmcv.utils import Registry, build_from_cfg
COMPONENTS = Registry('component') COMPONENTS = Registry('component')
...@@ -123,25 +125,94 @@ def test_initilization_info_logger(): ...@@ -123,25 +125,94 @@ def test_initilization_info_logger():
log_file = os.path.join(workdir, train_log) log_file = os.path.join(workdir, train_log)
# create a logger # create a logger
get_logger('init_logger', log_file=log_file) get_logger('init_logger', log_file=log_file)
assert hasattr(model, '_params_init_info') assert not hasattr(model, '_params_init_info')
model.init_weights() model.init_weights()
# assert `_params_init_info` would be deleted after `init_weights` # assert `_params_init_info` would be deleted after `init_weights`
assert not hasattr(model, '_params_init_info') assert not hasattr(model, '_params_init_info')
# assert initialization information has been dumped # assert initialization information has been dumped
assert os.path.exists(log_file) assert os.path.exists(log_file)
with open(log_file) as f: lines = mmcv.list_from_file(log_file)
lines = f.readlines()
for line in lines:
print(line)
# check initialization information is right # check initialization information is right
for line in lines: for i, line in enumerate(lines):
if 'conv1.weight' in line: if 'conv1.weight' in line:
assert 'NormalInit' in line assert 'NormalInit' in lines[i + 1]
if 'conv2.weight' in line: if 'conv2.weight' in line:
assert 'OverloadInitConv' in line assert 'OverloadInitConv' in lines[i + 1]
if 'fc1.weight' in line: if 'fc1.weight' in line:
assert 'ConstantInit' in line assert 'ConstantInit' in lines[i + 1]
# test corner case
class OverloadInitConvFc(nn.Conv2d, BaseModule):
    """Conv layer with an extra Linear child that overloads ``init_weights``.

    Used to check that a user-overloaded ``init_weights`` is correctly
    attributed in the recorded initialization information.
    """

    def __init__(self, *args, **kwargs):
        super(OverloadInitConvFc, self).__init__(*args, **kwargs)
        self.conv1 = nn.Linear(1, 1)

    def init_weights(self):
        # Overwrite every parameter (conv weight/bias and the child Linear)
        # with the constant 1.
        with torch.no_grad():
            for p in self.parameters():
                p.fill_(1)
class CheckLoggerModel(BaseModule):
    """Model mixing plain layers with an overloaded-init submodule.

    Exercises several distinct initialization paths (config-driven conv,
    overloaded ``init_weights``, config-driven linear) so the logged
    initialization info can be verified per layer.
    """

    def __init__(self, init_cfg=None):
        super().__init__(init_cfg)
        self.conv1 = nn.Conv2d(1, 1, 1, 1)
        self.conv2 = OverloadInitConvFc(1, 1, 1, 1)
        self.conv3 = nn.Conv2d(1, 1, 1, 1)
        self.fc1 = nn.Linear(1, 1)
class TopLevelModule(BaseModule):
    """Two-level wrapper used to test top-most ``init_weights`` behavior.

    Holds a ``CheckLoggerModel`` (with its own init config) and an
    ``OverloadInitConvFc`` so that nested and overloaded initialization
    are both driven from a single top-level ``init_weights`` call.
    """

    def __init__(self, init_cfg=None, checklog_init_cfg=None):
        super().__init__(init_cfg)
        self.module1 = CheckLoggerModel(checklog_init_cfg)
        self.module2 = OverloadInitConvFc(1, 1, 1, 1)
checklog_init_cfg = [
dict(
type='Normal',
layer='Conv2d',
std=0.01,
override=dict(
type='Normal', name='conv3', std=0.01, bias_prob=0.01)),
dict(type='Constant', layer='Linear', val=0., bias=1.)
]
top_level_init_cfg = [
dict(
type='Normal',
layer='Conv2d',
std=0.01,
override=dict(
type='Normal', name='module2', std=0.01, bias_prob=0.01))
]
model = TopLevelModule(
init_cfg=top_level_init_cfg, checklog_init_cfg=checklog_init_cfg)
model.module1.init_weights()
model.module2.init_weights()
model.init_weights()
model.module1.init_weights()
model.module2.init_weights()
assert not hasattr(model, '_params_init_info')
model.init_weights()
# assert `_params_init_info` would be deleted after `init_weights`
assert not hasattr(model, '_params_init_info')
# assert initialization information has been dumped
assert os.path.exists(log_file)
lines = mmcv.list_from_file(log_file)
# check initialization information is right
for i, line in enumerate(lines):
if 'TopLevelModule' in line and 'init_cfg' not in line:
# have been set init_flag
assert 'the same' in line
def test_update_init_info(): def test_update_init_info():
...@@ -158,7 +229,6 @@ def test_update_init_info(): ...@@ -158,7 +229,6 @@ def test_update_init_info():
from collections import defaultdict from collections import defaultdict
model._params_init_info = defaultdict(dict) model._params_init_info = defaultdict(dict)
for name, param in model.named_parameters(): for name, param in model.named_parameters():
model._params_init_info[param]['param_name'] = name
model._params_init_info[param]['init_info'] = 'init' model._params_init_info[param]['init_info'] = 'init'
model._params_init_info[param]['tmp_mean_value'] = param.data.mean() model._params_init_info[param]['tmp_mean_value'] = param.data.mean()
...@@ -172,6 +242,11 @@ def test_update_init_info(): ...@@ -172,6 +242,11 @@ def test_update_init_info():
assert item['init_info'] == 'fill_1' assert item['init_info'] == 'fill_1'
assert item['tmp_mean_value'] == 1 assert item['tmp_mean_value'] == 1
# test assert for new parameters
model.conv1.bias = nn.Parameter(torch.ones_like(model.conv1.bias))
with pytest.raises(AssertionError):
update_init_info(model, init_info=' ')
def test_model_weight_init(): def test_model_weight_init():
""" """
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment