Unverified Commit 17fa6670 authored by Shilong Zhang, committed by GitHub

[Features] Add logger for initialization of parameters (#1150)

* add logger for init

* change init_info of overloaded init_weights

* add judgement for params_init_info

* add delete comments for params_init_info

* add docstr and more comments

* add docstr and more comments

* resolve comments

* dump to a file

* add unit test

* fix unit test

* fix unit test

* write to ori log

* fix typo

* resolve comments
parent ef48a473
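For orientation before the diff itself: this commit makes `BaseModule.init_weights()` record, for every parameter, which initializer (or user-defined `init_weights`) last changed its value, and dump that record through the existing logger. A minimal sketch of the intended workflow, modeled on the unit test added below (the model, logger name, and log path are illustrative, not part of the patch):

import torch.nn as nn
from mmcv.runner import BaseModule
from mmcv.utils import get_logger


class ToyModel(BaseModule):  # illustrative model, mirrors the new test

    def __init__(self, init_cfg=None):
        super().__init__(init_cfg)
        self.conv = nn.Conv2d(1, 1, 1)
        self.fc = nn.Linear(1, 1)


model = ToyModel(init_cfg=[
    dict(type='Normal', layer='Conv2d', std=0.01),
    dict(type='Constant', layer='Linear', val=0., bias=1.)
])

# Register a logger with a file handler first; init_weights() then appends
# one "<param name> - <init info>" line per parameter to that log file.
get_logger('init_logger', log_file='init.log')
model.init_weights()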
@@ -8,6 +8,7 @@ import torch
import torch.nn as nn
from torch import Tensor
from mmcv.runner.base_module import update_init_info
from mmcv.utils import Registry, build_from_cfg, get_logger, print_log
INITIALIZERS = Registry('initializer')
@@ -122,6 +123,10 @@ class BaseInit(object):
self.bias = bias
self.layer = [layer] if isinstance(layer, str) else layer
def _get_init_info(self):
info = f'{self.__class__.__name__}, bias={self.bias}'
return info
@INITIALIZERS.register_module(name='Constant')
class ConstantInit(BaseInit):
@@ -152,6 +157,12 @@ class ConstantInit(BaseInit):
constant_init(m, self.val, self.bias)
module.apply(init)
if hasattr(module, '_params_init_info'):
update_init_info(module, init_info=self._get_init_info())
def _get_init_info(self):
info = f'{self.__class__.__name__}: val={self.val}, bias={self.bias}'
return info
@INITIALIZERS.register_module(name='Xavier')
@@ -189,6 +200,13 @@ class XavierInit(BaseInit):
xavier_init(m, self.gain, self.bias, self.distribution)
module.apply(init)
if hasattr(module, '_params_init_info'):
update_init_info(module, init_info=self._get_init_info())
def _get_init_info(self):
info = f'{self.__class__.__name__}: gain={self.gain}, ' \
f'distribution={self.distribution}, bias={self.bias}'
return info
@INITIALIZERS.register_module(name='Normal')
@@ -225,6 +243,13 @@ class NormalInit(BaseInit):
normal_init(m, self.mean, self.std, self.bias)
module.apply(init)
if hasattr(module, '_params_init_info'):
update_init_info(module, init_info=self._get_init_info())
def _get_init_info(self):
info = f'{self.__class__.__name__}: mean={self.mean},' \
f' std={self.std}, bias={self.bias}'
return info
@INITIALIZERS.register_module(name='TruncNormal')
@@ -273,6 +298,13 @@ class TruncNormalInit(BaseInit):
self.bias)
module.apply(init)
if hasattr(module, '_params_init_info'):
update_init_info(module, init_info=self._get_init_info())
def _get_init_info(self):
info = f'{self.__class__.__name__}: a={self.a}, b={self.b},' \
f' mean={self.mean}, std={self.std}, bias={self.bias}'
return info
@INITIALIZERS.register_module(name='Uniform')
@@ -309,6 +341,13 @@ class UniformInit(BaseInit):
uniform_init(m, self.a, self.b, self.bias)
module.apply(init)
if hasattr(module, '_params_init_info'):
update_init_info(module, init_info=self._get_init_info())
def _get_init_info(self):
info = f'{self.__class__.__name__}: a={self.a},' \
f' b={self.b}, bias={self.bias}'
return info
@INITIALIZERS.register_module(name='Kaiming')
@@ -364,6 +403,14 @@ class KaimingInit(BaseInit):
self.bias, self.distribution)
module.apply(init)
if hasattr(module, '_params_init_info'):
update_init_info(module, init_info=self._get_init_info())
def _get_init_info(self):
info = f'{self.__class__.__name__}: a={self.a}, mode={self.mode}, ' \
f'nonlinearity={self.nonlinearity}, ' \
f'distribution={self.distribution}, bias={self.bias}'
return info
@INITIALIZERS.register_module(name='Caffe2Xavier')
@@ -422,6 +469,13 @@ class PretrainedInit(object):
self.prefix, self.checkpoint, map_location=self.map_location)
load_state_dict(module, state_dict, strict=False, logger=logger)
if hasattr(module, '_params_init_info'):
update_init_info(module, init_info=self._get_init_info())
def _get_init_info(self):
info = f'{self.__class__.__name__}: load from {self.checkpoint}'
return info
def _initialize(module, cfg, wholemodule=False):
func = build_from_cfg(cfg, INITIALIZERS)
...
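Every initializer above follows the same two-part pattern: report what it did via `_get_init_info()`, and, when the module carries the `_params_init_info` bookkeeping attribute set up by the top-level `init_weights()`, push that string through `update_init_info()`. A custom initializer can hook into the logging the same way; a hedged sketch follows (the `Ones` initializer is made up for illustration and is not part of mmcv):

import torch.nn as nn

from mmcv.cnn.utils.weight_init import INITIALIZERS, BaseInit
from mmcv.runner.base_module import update_init_info


@INITIALIZERS.register_module(name='Ones')  # hypothetical initializer
class OnesInit(BaseInit):
    """Fill the weights of the matched layers with 1."""

    def __call__(self, module):

        def init(m):
            matched = self.wholemodule or \
                m.__class__.__name__ in (self.layer or [])
            if matched and getattr(m, 'weight', None) is not None:
                nn.init.constant_(m.weight, 1)

        module.apply(init)
        # Guarded exactly like the built-in initializers: only record
        # init info when the bookkeeping attribute is present.
        if hasattr(module, '_params_init_info'):
            update_init_info(module, init_info=self._get_init_info())

    def _get_init_info(self):
        return f'{self.__class__.__name__}: val=1, bias={self.bias}'

It could then be referenced from an `init_cfg` entry such as `dict(type='Ones', layer='Conv2d')`.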
# Copyright (c) Open-MMLab. All rights reserved.
import copy
import warnings
from abc import ABCMeta
from collections import defaultdict
from logging import FileHandler
import torch.nn as nn
from mmcv.runner.dist_utils import master_only
from mmcv.utils.logging import get_logger, logger_initialized, print_log
def update_init_info(module, *, init_info):
"""Update the `_params_init_info` in the module if the value of parameters
are changed.
Args:
module (obj:`nn.Module`): The module of PyTorch with a user-defined
attribute `_params_init_info` which records the initialization
information.
init_info (str): The string that describes the initialization.
"""
for param in module.parameters():
mean_value = param.data.mean()
if module._params_init_info[param]['tmp_mean_value'] != mean_value:
module._params_init_info[param]['init_info'] = init_info
module._params_init_info[param]['tmp_mean_value'] = mean_value
class BaseModule(nn.Module, metaclass=ABCMeta):
"""Base module for all modules in openmmlab.
``BaseModule`` is a wrapper of ``torch.nn.Module`` with additional
functionality of parameter initialization. Compared with
``torch.nn.Module``, ``BaseModule`` mainly adds three attributes.
- ``init_cfg``: the config to control the initialization.
- ``_params_init_info``: Used to track the parameter
initialization information.
- ``init_weights``: The function of parameter
initialization and recording initialization
information.
Args:
init_cfg (dict, optional): Initialization config dict.
"""
def __init__(self, init_cfg=None):
"""Initialize BaseModule, inherited from `torch.nn.Module`"""
# NOTE init_cfg can be defined in different levels, but init_cfg
# in low levels has a higher priority.
super(BaseModule, self).__init__()
# define default value of init_cfg instead of hard code
# in init_weights() function
self._is_init = False
self.init_cfg = copy.deepcopy(init_cfg)
# The `_params_init_info` is used to record the initialization
# information of the parameters
# the key should be the obj:`nn.Parameter` of model and the value
# should be a dict containing
# - param_name (str): The name of parameter.
# - init_info (str): The string that describes the initialization.
# - tmp_mean_value (FloatTensor): The mean of the parameter,
# which indicates whether the parameter has been modified.
# This attribute will be deleted after all parameters are initialized.
self._params_init_info = defaultdict(dict)
# Backward compatibility in derived classes
# if pretrained is not None:
@@ -38,26 +82,100 @@ class BaseModule(nn.Module, metaclass=ABCMeta):
def init_weights(self):
"""Initialize the weights."""
# check if it is top-level module
is_top_level_module = len(self._params_init_info) == 0
if is_top_level_module:
# Initialize the `_params_init_info`;
# when the `tmp_mean_value` of a parameter is detected to have
# changed, the related initialization information is updated.
for name, param in self.named_parameters():
self._params_init_info[param]['param_name'] = name
self._params_init_info[param][
'init_info'] = f'The value is the same before and ' \
f'after calling `init_weights` ' \
f'of {self.__class__.__name__} '
self._params_init_info[param][
'tmp_mean_value'] = param.data.mean()
# pass `params_init_info` to all submodules
# All submodules share the same `params_init_info`,
# so it will be updated when parameters are
# modified at any level of the model.
for sub_module in self.modules():
sub_module._params_init_info = self._params_init_info
# Get an already initialized logger; if none exists,
# create a logger named `mmcv`
logger_names = list(logger_initialized.keys())
logger_name = logger_names[0] if logger_names else 'mmcv'
from ..cnn import initialize
module_name = self.__class__.__name__
if not self._is_init:
if self.init_cfg:
print_log(
f'initialize {module_name} with init_cfg {self.init_cfg}',
logger=logger_name)
initialize(self, self.init_cfg)
if isinstance(self.init_cfg, dict):
# prevent the parameters of the pre-trained model
# from being overwritten by the `init_weights`
if self.init_cfg['type'] == 'Pretrained':
return
for m in self.children():
if hasattr(m, 'init_weights'):
m.init_weights()
# users may overload the `init_weights`
update_init_info(
m,
init_info=f'Initialized by '
f'user-defined `init_weights`'
f' in {m.__class__.__name__} ')
self._is_init = True
else:
warnings.warn(f'init_weights of {self.__class__.__name__} has '
f'been called more than once.')
if is_top_level_module:
self._dump_init_info(logger_name)
for sub_module in self.modules():
del sub_module._params_init_info
@master_only
def _dump_init_info(self, logger_name):
"""Dump the initialization information to a file named
`initialization.log.json` in workdir.
Args:
logger_name (str): The name of logger.
"""
logger = get_logger(logger_name)
with_file_handler = False
# dump the information to the logger file if there is a `FileHandler`
for handler in logger.handlers:
if isinstance(handler, FileHandler):
handler.stream.write(
'Name of parameter - Initialization information\n')
for item in list(self._params_init_info.values()):
handler.stream.write(
f"{item['param_name']} - {item['init_info']} \n")
handler.stream.flush()
with_file_handler = True
if not with_file_handler:
for item in list(self._params_init_info.values()):
print_log(
f"{item['param_name']} - {item['init_info']}",
logger=logger_name)
def __repr__(self):
s = super().__repr__()
if self.init_cfg:
...
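The bookkeeping above hinges on a simple change-detection heuristic: `update_init_info()` compares the stored mean of every parameter against its current mean and rewrites the recorded `init_info` only for parameters whose mean has changed since the last snapshot. A small self-contained sketch of that behaviour, building `_params_init_info` by hand the same way the new unit test below does (the model and the strings are illustrative):

from collections import defaultdict

import torch
import torch.nn as nn

from mmcv.runner.base_module import update_init_info

model = nn.Sequential(nn.Conv2d(1, 1, 1), nn.Linear(1, 1))

# Mimic what the top-level init_weights() sets up before initialization.
model._params_init_info = defaultdict(dict)
for name, param in model.named_parameters():
    model._params_init_info[param]['param_name'] = name
    model._params_init_info[param]['init_info'] = 'not touched yet'
    model._params_init_info[param]['tmp_mean_value'] = param.data.mean()

# Change only the Conv2d weight; its record is updated, while every
# untouched parameter keeps the previous string.
with torch.no_grad():
    model[0].weight.fill_(1)
update_init_info(model, init_info='filled with ones')

for item in model._params_init_info.values():
    print(item['param_name'], '-', item['init_info'])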
import tempfile
import torch
from torch import nn
from mmcv.runner import BaseModule, ModuleList, Sequential
from mmcv.runner.base_module import update_init_info
from mmcv.utils import Registry, build_from_cfg
COMPONENTS = Registry('component')
@@ -80,6 +83,96 @@ class FooModel(BaseModule):
self.reg = nn.Linear(3, 4)
def test_initilization_info_logger():
# 'override' has higher priority
import torch.nn as nn
from mmcv.utils.logging import get_logger
import os
class OverloadInitConv(nn.Conv2d, BaseModule):
def init_weights(self):
for p in self.parameters():
with torch.no_grad():
p.fill_(1)
class CheckLoggerModel(BaseModule):
def __init__(self, init_cfg=None):
super(CheckLoggerModel, self).__init__(init_cfg)
self.conv1 = nn.Conv2d(1, 1, 1, 1)
self.conv2 = OverloadInitConv(1, 1, 1, 1)
self.conv3 = nn.Conv2d(1, 1, 1, 1)
self.fc1 = nn.Linear(1, 1)
init_cfg = [
dict(
type='Normal',
layer='Conv2d',
std=0.01,
override=dict(
type='Normal', name='conv3', std=0.01, bias_prob=0.01)),
dict(type='Constant', layer='Linear', val=0., bias=1.)
]
model = CheckLoggerModel(init_cfg=init_cfg)
train_log = '20210720_132454.log'
workdir = tempfile.mkdtemp()
log_file = os.path.join(workdir, train_log)
# create a logger
get_logger('init_logger', log_file=log_file)
assert hasattr(model, '_params_init_info')
model.init_weights()
# assert `_params_init_info` would be deleted after `init_weights`
assert not hasattr(model, '_params_init_info')
# assert initialization information has been dumped
assert os.path.exists(log_file)
with open(log_file) as f:
lines = f.readlines()
for line in lines:
print(line)
# check initialization information is right
for line in lines:
if 'conv1.weight' in line:
assert 'NormalInit' in line
if 'conv2.weight' in line:
assert 'OverloadInitConv' in line
if 'fc1.weight' in line:
assert 'ConstantInit' in line
def test_update_init_info():
class DummyModel(BaseModule):
def __init__(self, init_cfg=None):
super().__init__(init_cfg)
self.conv1 = nn.Conv2d(1, 1, 1, 1)
self.conv3 = nn.Conv2d(1, 1, 1, 1)
self.fc1 = nn.Linear(1, 1)
model = DummyModel()
from collections import defaultdict
model._params_init_info = defaultdict(dict)
for name, param in model.named_parameters():
model._params_init_info[param]['param_name'] = name
model._params_init_info[param]['init_info'] = 'init'
model._params_init_info[param]['tmp_mean_value'] = param.data.mean()
with torch.no_grad():
for p in model.parameters():
p.fill_(1)
update_init_info(model, init_info='fill_1')
for item in model._params_init_info.values():
assert item['init_info'] == 'fill_1'
assert item['tmp_mean_value'] == 1
def test_model_weight_init():
"""
Config
...