Unverified commit 18c64d5f authored by Shilong Zhang, committed by GitHub

Fix potential bugs of BaseModule when recording the initialization information (#1217)

* add logger for init

* change init_info of overloaded init_weights

* add judgement for params_init_info

* add comments about deleting params_init_info

* add docstr and more comments

* add docstr and more comments

* resolve comments

* dump to a file

* add unit test

* fix unit test

* fix unit test

* write to original log

* fix typo

* resolve comments

* fix calling init_weights twice in topmost module

* fix the potential bug of recursive import

* fix unit test

* fix potential bugs

* remove unnecessary change

* add more unit tests

* fix adding param in init_weights

* add more comments

* raise error

* add more detailed assert error
parent b5a4bbd0
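
The patch below moves `update_init_info` from `mmcv/runner/base_module.py` into `mmcv/cnn/utils/weight_init.py` to break a circular import, creates `_params_init_info` lazily inside `init_weights` instead of in `__init__`, and deletes it once the top-level call finishes. A minimal sketch of the resulting behavior, assuming an mmcv build that includes this patch; the toy model and the `init_cfg` value are illustrative, not taken from the patch:

```python
import torch.nn as nn

from mmcv.runner import BaseModule
from mmcv.utils import get_logger


class ToyModel(BaseModule):

    def __init__(self, init_cfg=None):
        super(ToyModel, self).__init__(init_cfg)
        self.conv = nn.Conv2d(3, 8, 3)
        self.fc = nn.Linear(8, 4)


# A file-backed logger makes `init_weights` dump the per-parameter records.
get_logger('mmcv', log_file='init_info.log')

model = ToyModel(init_cfg=dict(type='Normal', layer='Conv2d', std=0.01))
# With this patch, `_params_init_info` does not exist before
# `init_weights` is called ...
assert not hasattr(model, '_params_init_info')
model.init_weights()
# ... and it is deleted again once the top-level call returns.
assert not hasattr(model, '_params_init_info')
```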
mmcv/cnn/utils/weight_init.py
@@ -8,12 +8,43 @@
 import torch
 import torch.nn as nn
 from torch import Tensor
-from mmcv.runner.base_module import update_init_info
 from mmcv.utils import Registry, build_from_cfg, get_logger, print_log

 INITIALIZERS = Registry('initializer')


+def update_init_info(module, init_info):
+    """Update the `_params_init_info` in the module if the values of the
+    parameters change.
+
+    Args:
+        module (obj:`nn.Module`): The PyTorch module with a user-defined
+            attribute `_params_init_info` which records the initialization
+            information.
+        init_info (str): The string that describes the initialization.
+    """
+    assert hasattr(
+        module,
+        '_params_init_info'), f'Cannot find `_params_init_info` in {module}'
+    for name, param in module.named_parameters():
+        assert param in module._params_init_info, (
+            f'Found a new :obj:`Parameter` '
+            f'named `{name}` while executing the '
+            f'`init_weights` of '
+            f'`{module.__class__.__name__}`. '
+            f'Please do not add or '
+            f'replace parameters while executing '
+            f'`init_weights`.')
+
+        # If the parameter's mean changed, it was modified during
+        # `init_weights`, so update its recorded information.
+        mean_value = param.data.mean()
+        if module._params_init_info[param]['tmp_mean_value'] != mean_value:
+            module._params_init_info[param]['init_info'] = init_info
+            module._params_init_info[param]['tmp_mean_value'] = mean_value
+
+
 def constant_init(module, val, bias=0):
     if hasattr(module, 'weight') and module.weight is not None:
         nn.init.constant_(module.weight, val)
......
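
With `update_init_info` living in `weight_init.py` and asserting that no parameter is added or replaced, a user-defined `init_weights` can still report what it did to existing parameters. A hedged sketch of such an override; `CustomConv` is a hypothetical class, not part of the patch:

```python
import torch
import torch.nn as nn

from mmcv.cnn.utils.weight_init import update_init_info
from mmcv.runner import BaseModule


class CustomConv(nn.Conv2d, BaseModule):

    def init_weights(self):
        # Mutate existing parameters only; adding or replacing a
        # parameter here would trip the new assertion.
        with torch.no_grad():
            self.weight.fill_(0.5)
        # `_params_init_info` only exists while the top-level
        # `init_weights` is running, so guard the call.
        if hasattr(self, '_params_init_info'):
            update_init_info(self, init_info='filled weight with 0.5')
```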
mmcv/runner/base_module.py
@@ -11,23 +11,6 @@
 from mmcv.runner.dist_utils import master_only
 from mmcv.utils.logging import get_logger, logger_initialized, print_log


-def update_init_info(module, *, init_info):
-    """Update the `_params_init_info` in the module if the value of parameters
-    are changed.
-
-    Args:
-        module (obj:`nn.Module`): The module of PyTorch with a user-defined
-            attribute `_params_init_info` which records the initialization
-            information.
-        init_info (str): The string that describes the initialization.
-    """
-    for param in module.parameters():
-        mean_value = param.data.mean()
-        if module._params_init_info[param]['tmp_mean_value'] != mean_value:
-            module._params_init_info[param]['init_info'] = init_info
-            module._params_init_info[param]['tmp_mean_value'] = mean_value


 class BaseModule(nn.Module, metaclass=ABCMeta):
     """Base module for all modules in openmmlab.
@@ -36,11 +19,12 @@ class BaseModule(nn.Module, metaclass=ABCMeta):
     ``torch.nn.Module``, ``BaseModule`` mainly adds three attributes.

     - ``init_cfg``: the config to control the initialization.
-    - ``_params_init_info``: Used to track the parameter
-      initialization information.
     - ``init_weights``: The function of parameter
       initialization and recording initialization
       information.
+    - ``_params_init_info``: Used to track the parameter
+      initialization information. This attribute only
+      exists while ``init_weights`` is being executed.

     Args:
         init_cfg (dict, optional): Initialization config dict.
@@ -59,17 +43,6 @@ class BaseModule(nn.Module, metaclass=ABCMeta):
         self.init_cfg = copy.deepcopy(init_cfg)

-        # The `_params_init_info` is used to record the initialization
-        # information of the parameters
-        # the key should be the obj:`nn.Parameter` of model and the value
-        # should be a dict containing
-        # - param_name (str): The name of parameter.
-        # - init_info (str): The string that describes the initialization.
-        # - tmp_mean_value (FloatTensor): The mean of the parameter,
-        #   which indicates whether the parameter has been modified.
-        # this attribute would be deleted after all parameters is initialized.
-        self._params_init_info = defaultdict(dict)
-
         # Backward compatibility in derived classes
         # if pretrained is not None:
         #     warnings.warn('DeprecationWarning: pretrained is a deprecated \
@@ -83,15 +56,26 @@ class BaseModule(nn.Module, metaclass=ABCMeta):
     def init_weights(self):
         """Initialize the weights."""

+        is_top_level_module = False
         # check if it is top-level module
-        is_top_level_module = len(self._params_init_info) == 0
-        if is_top_level_module:
+        if not hasattr(self, '_params_init_info'):
+            # `_params_init_info` records the initialization information
+            # of the parameters. Its keys are the obj:`nn.Parameter`s of
+            # the model and each value is a dict containing
+            # - init_info (str): The string that describes the
+            #   initialization.
+            # - tmp_mean_value (FloatTensor): The mean of the parameter,
+            #   which indicates whether the parameter has been modified.
+            # This attribute is deleted after all parameters
+            # are initialized.
+            self._params_init_info = defaultdict(dict)
+            is_top_level_module = True
+
             # Initialize the `_params_init_info`,
             # When detecting the `tmp_mean_value` of
             # the corresponding parameter is changed, update related
             # initialization information
             for name, param in self.named_parameters():
-                self._params_init_info[param]['param_name'] = name
                 self._params_init_info[param][
                     'init_info'] = f'The value is the same before and ' \
                                    f'after calling `init_weights` ' \
@@ -112,6 +96,7 @@ class BaseModule(nn.Module, metaclass=ABCMeta):
         logger_name = logger_names[0] if logger_names else 'mmcv'

         from ..cnn import initialize
+        from ..cnn.utils.weight_init import update_init_info
         module_name = self.__class__.__name__
         if not self._is_init:
             if self.init_cfg:
@@ -165,15 +150,17 @@ class BaseModule(nn.Module, metaclass=ABCMeta):
                 if isinstance(handler, FileHandler):
                     handler.stream.write(
                         'Name of parameter - Initialization information\n')
-                    for item in list(self._params_init_info.values()):
+                    for name, param in self.named_parameters():
                         handler.stream.write(
-                            f"{item['param_name']} - {item['init_info']} \n")
+                            f'\n{name} - {param.shape}: '
+                            f"\n{self._params_init_info[param]['init_info']} \n")
                     handler.stream.flush()
                     with_file_handler = True
         if not with_file_handler:
-            for item in list(self._params_init_info.values()):
+            for name, param in self.named_parameters():
                 print_log(
-                    f"{item['param_name']} - {item['init_info']}",
+                    f'\n{name} - {param.shape}: '
+                    f"\n{self._params_init_info[param]['init_info']} \n ",
                     logger=logger_name)

     def __repr__(self):
......
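
The central fix above is detecting the top-level call by the absence of `_params_init_info` rather than by its emptiness, which misfired when `init_weights` was called a second time on the topmost module. A condensed, framework-free sketch of the pattern; the sharing and teardown loops are assumptions based on the surrounding code rather than lines visible in this hunk:

```python
from collections import defaultdict

import torch.nn as nn


class Recorder(nn.Module):

    def init_weights(self):
        is_top_level_module = False
        if not hasattr(self, '_params_init_info'):
            # Only the outermost call creates the record table and
            # shares it with every submodule.
            self._params_init_info = defaultdict(dict)
            is_top_level_module = True
            for sub_module in self.modules():
                sub_module._params_init_info = self._params_init_info

        # ... run initializers and recurse into children here ...

        if is_top_level_module:
            # Only the outermost call dumps the records and removes the
            # attribute, so nested or repeated calls stay consistent.
            for sub_module in self.modules():
                del sub_module._params_init_info
```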
tests for BaseModule:
 import tempfile

+import pytest
 import torch
 from torch import nn

+import mmcv
+from mmcv.cnn.utils.weight_init import update_init_info
 from mmcv.runner import BaseModule, ModuleList, Sequential
-from mmcv.runner.base_module import update_init_info
 from mmcv.utils import Registry, build_from_cfg

 COMPONENTS = Registry('component')
@@ -123,25 +125,94 @@ def test_initilization_info_logger():
         log_file = os.path.join(workdir, train_log)
         # create a logger
         get_logger('init_logger', log_file=log_file)
-        assert hasattr(model, '_params_init_info')
+        assert not hasattr(model, '_params_init_info')
         model.init_weights()
         # assert `_params_init_info` would be deleted after `init_weights`
         assert not hasattr(model, '_params_init_info')
         # assert initialization information has been dumped
         assert os.path.exists(log_file)
-        with open(log_file) as f:
-            lines = f.readlines()
-        for line in lines:
-            print(line)
+        lines = mmcv.list_from_file(log_file)

         # check initialization information is right
-        for line in lines:
+        for i, line in enumerate(lines):
             if 'conv1.weight' in line:
-                assert 'NormalInit' in line
+                assert 'NormalInit' in lines[i + 1]
             if 'conv2.weight' in line:
-                assert 'OverloadInitConv' in line
+                assert 'OverloadInitConv' in lines[i + 1]
             if 'fc1.weight' in line:
-                assert 'ConstantInit' in line
+                assert 'ConstantInit' in lines[i + 1]
+
+        # test corner case
+        class OverloadInitConvFc(nn.Conv2d, BaseModule):
+
+            def __init__(self, *args, **kwargs):
+                super(OverloadInitConvFc, self).__init__(*args, **kwargs)
+                self.conv1 = nn.Linear(1, 1)
+
+            def init_weights(self):
+                for p in self.parameters():
+                    with torch.no_grad():
+                        p.fill_(1)
+
+        class CheckLoggerModel(BaseModule):
+
+            def __init__(self, init_cfg=None):
+                super(CheckLoggerModel, self).__init__(init_cfg)
+                self.conv1 = nn.Conv2d(1, 1, 1, 1)
+                self.conv2 = OverloadInitConvFc(1, 1, 1, 1)
+                self.conv3 = nn.Conv2d(1, 1, 1, 1)
+                self.fc1 = nn.Linear(1, 1)
+
+        class TopLevelModule(BaseModule):
+
+            def __init__(self, init_cfg=None, checklog_init_cfg=None):
+                super(TopLevelModule, self).__init__(init_cfg)
+                self.module1 = CheckLoggerModel(checklog_init_cfg)
+                self.module2 = OverloadInitConvFc(1, 1, 1, 1)
+
+        checklog_init_cfg = [
+            dict(
+                type='Normal',
+                layer='Conv2d',
+                std=0.01,
+                override=dict(
+                    type='Normal', name='conv3', std=0.01, bias_prob=0.01)),
+            dict(type='Constant', layer='Linear', val=0., bias=1.)
+        ]
+
+        top_level_init_cfg = [
+            dict(
+                type='Normal',
+                layer='Conv2d',
+                std=0.01,
+                override=dict(
+                    type='Normal', name='module2', std=0.01, bias_prob=0.01))
+        ]
+
+        model = TopLevelModule(
+            init_cfg=top_level_init_cfg, checklog_init_cfg=checklog_init_cfg)
+
+        model.module1.init_weights()
+        model.module2.init_weights()
+        model.init_weights()
+        model.module1.init_weights()
+        model.module2.init_weights()
+
+        assert not hasattr(model, '_params_init_info')
+
+        model.init_weights()
+        # assert `_params_init_info` would be deleted after `init_weights`
+        assert not hasattr(model, '_params_init_info')
+        # assert initialization information has been dumped
+        assert os.path.exists(log_file)
+        lines = mmcv.list_from_file(log_file)
+
+        # check initialization information is right
+        for i, line in enumerate(lines):
+            if 'TopLevelModule' in line and 'init_cfg' not in line:
+                # the init flag has been set
+                assert 'the same' in line


 def test_update_init_info():
@@ -158,7 +229,6 @@ def test_update_init_info():
     from collections import defaultdict
     model._params_init_info = defaultdict(dict)
     for name, param in model.named_parameters():
-        model._params_init_info[param]['param_name'] = name
         model._params_init_info[param]['init_info'] = 'init'
         model._params_init_info[param]['tmp_mean_value'] = param.data.mean()
@@ -172,6 +242,11 @@ def test_update_init_info():
         assert item['init_info'] == 'fill_1'
         assert item['tmp_mean_value'] == 1

+    # test the assert for newly added parameters
+    model.conv1.bias = nn.Parameter(torch.ones_like(model.conv1.bias))
+    with pytest.raises(AssertionError):
+        update_init_info(model, init_info=' ')


 def test_model_weight_init():
     """
......
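
The updated assertions read `lines[i + 1]` because the new dump format writes each parameter as two lines, the name plus shape first and the initialization info after it (see the `handler.stream.write` change above). A tiny re-creation of that layout with made-up record data:

```python
# Hypothetical records: parameter name -> (shape, init info).
records = {
    'conv1.weight': ('torch.Size([1, 1, 1, 1])', 'NormalInit: std=0.01'),
    'fc1.weight': ('torch.Size([1, 1])', 'ConstantInit: val=0, bias=1'),
}
for name, (shape, info) in records.items():
    # Matches the two-line pattern the tests scan for.
    print(f'\n{name} - {shape}: \n{info}')
```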