Unverified Commit 17fa6670 authored by Shilong Zhang, committed by GitHub

[Features] Add logger for initialization of parameters (#1150)

* add logger for init

* change init_info of overloaded init_weights

* add judgement for params_init_info

* add delete comments for params_init_info

* add docstr and more comments

* add docstr and more comments

* resolve comments

* dump to a file

* add unittest

* fix unittest

* fix unittest

* write to ori log

* fix typo

* resolve comments
parent ef48a473
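For context, a minimal end-to-end sketch of the behaviour this commit adds, mirroring the unit test further down. The module class, the logger name and the log path below are made up for illustration; only the mmcv calls themselves come from the diff:

import os
import tempfile

import torch.nn as nn

from mmcv.runner import BaseModule
from mmcv.utils import get_logger


class ToyModel(BaseModule):  # hypothetical module, for illustration only

    def __init__(self, init_cfg=None):
        super().__init__(init_cfg)
        self.conv = nn.Conv2d(3, 8, 3)
        self.fc = nn.Linear(8, 4)


# any logger created before `init_weights` is picked up for the dump
log_file = os.path.join(tempfile.mkdtemp(), 'train.log')
get_logger('init_logger', log_file=log_file)

model = ToyModel(init_cfg=[
    dict(type='Kaiming', layer='Conv2d'),
    dict(type='Constant', layer='Linear', val=0., bias=1.)
])
model.init_weights()

# the log file now ends with one line per parameter, e.g.
# "fc.weight - ConstantInit: val=0.0, bias=1.0"
with open(log_file) as f:
    print(f.read())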
......@@ -8,6 +8,7 @@ import torch
import torch.nn as nn
from torch import Tensor
from mmcv.runner.base_module import update_init_info
from mmcv.utils import Registry, build_from_cfg, get_logger, print_log
INITIALIZERS = Registry('initializer')
......@@ -122,6 +123,10 @@ class BaseInit(object):
self.bias = bias
self.layer = [layer] if isinstance(layer, str) else layer
def _get_init_info(self):
info = f'{self.__class__.__name__}, bias={self.bias}'
return info
@INITIALIZERS.register_module(name='Constant')
class ConstantInit(BaseInit):
......@@ -152,6 +157,12 @@ class ConstantInit(BaseInit):
constant_init(m, self.val, self.bias)
module.apply(init)
if hasattr(module, '_params_init_info'):
update_init_info(module, init_info=self._get_init_info())
def _get_init_info(self):
info = f'{self.__class__.__name__}: val={self.val}, bias={self.bias}'
return info
@INITIALIZERS.register_module(name='Xavier')
......@@ -189,6 +200,13 @@ class XavierInit(BaseInit):
xavier_init(m, self.gain, self.bias, self.distribution)
module.apply(init)
if hasattr(module, '_params_init_info'):
update_init_info(module, init_info=self._get_init_info())
def _get_init_info(self):
info = f'{self.__class__.__name__}: gain={self.gain}, ' \
f'distribution={self.distribution}, bias={self.bias}'
return info
@INITIALIZERS.register_module(name='Normal')
......@@ -225,6 +243,13 @@ class NormalInit(BaseInit):
normal_init(m, self.mean, self.std, self.bias)
module.apply(init)
if hasattr(module, '_params_init_info'):
update_init_info(module, init_info=self._get_init_info())
def _get_init_info(self):
info = f'{self.__class__.__name__}: mean={self.mean},' \
f' std={self.std}, bias={self.bias}'
return info
@INITIALIZERS.register_module(name='TruncNormal')
......@@ -273,6 +298,13 @@ class TruncNormalInit(BaseInit):
self.bias)
module.apply(init)
if hasattr(module, '_params_init_info'):
update_init_info(module, init_info=self._get_init_info())
def _get_init_info(self):
info = f'{self.__class__.__name__}: a={self.a}, b={self.b},' \
f' mean={self.mean}, std={self.std}, bias={self.bias}'
return info
@INITIALIZERS.register_module(name='Uniform')
......@@ -309,6 +341,13 @@ class UniformInit(BaseInit):
uniform_init(m, self.a, self.b, self.bias)
module.apply(init)
if hasattr(module, '_params_init_info'):
update_init_info(module, init_info=self._get_init_info())
def _get_init_info(self):
info = f'{self.__class__.__name__}: a={self.a},' \
f' b={self.b}, bias={self.bias}'
return info
@INITIALIZERS.register_module(name='Kaiming')
......@@ -364,6 +403,14 @@ class KaimingInit(BaseInit):
self.bias, self.distribution)
module.apply(init)
if hasattr(module, '_params_init_info'):
update_init_info(module, init_info=self._get_init_info())
def _get_init_info(self):
info = f'{self.__class__.__name__}: a={self.a}, mode={self.mode}, ' \
f'nonlinearity={self.nonlinearity}, ' \
f'distribution ={self.distribution}, bias={self.bias}'
return info
@INITIALIZERS.register_module(name='Caffe2Xavier')
......@@ -422,6 +469,13 @@ class PretrainedInit(object):
self.prefix, self.checkpoint, map_location=self.map_location)
load_state_dict(module, state_dict, strict=False, logger=logger)
if hasattr(module, '_params_init_info'):
update_init_info(module, init_info=self._get_init_info())
def _get_init_info(self):
info = f'{self.__class__.__name__}: load from {self.checkpoint}'
return info
def _initialize(module, cfg, wholemodule=False):
func = build_from_cfg(cfg, INITIALIZERS)
......
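A minimal sketch (not part of the diff) of how the recording path above fits together: an initializer built from the registry runs its init function and, because the target carries `_params_init_info`, appends its `_get_init_info()` description. The hand-built mapping below mimics what `BaseModule.init_weights` normally sets up; the module is made up:

from collections import defaultdict

import torch.nn as nn

from mmcv.cnn import initialize

conv = nn.Conv2d(3, 3, 1)

# mimic the bookkeeping that `BaseModule.init_weights` creates automatically
conv._params_init_info = defaultdict(dict)
for name, param in conv.named_parameters():
    conv._params_init_info[param]['param_name'] = name
    conv._params_init_info[param]['init_info'] = 'not initialized yet'
    conv._params_init_info[param]['tmp_mean_value'] = param.data.mean()

initialize(conv, dict(type='Constant', layer='Conv2d', val=1., bias=2.))

for item in conv._params_init_info.values():
    # prints e.g. "weight - ConstantInit: val=1.0, bias=2.0"
    print(f"{item['param_name']} - {item['init_info']}")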
# Copyright (c) Open-MMLab. All rights reserved.
import copy
import warnings
from abc import ABCMeta
from collections import defaultdict
from logging import FileHandler
import torch.nn as nn
from mmcv import ConfigDict
from mmcv.runner.dist_utils import master_only
from mmcv.utils.logging import get_logger, logger_initialized, print_log
def update_init_info(module, *, init_info):
"""Update the `_params_init_info` in the module if the value of parameters
are changed.
Args:
module (obj:`nn.Module`): The module of PyTorch with a user-defined
attribute `_params_init_info` which records the initialization
information.
init_info (str): The string that describes the initialization.
"""
for param in module.parameters():
mean_value = param.data.mean()
if module._params_init_info[param]['tmp_mean_value'] != mean_value:
module._params_init_info[param]['init_info'] = init_info
module._params_init_info[param]['tmp_mean_value'] = mean_value
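A small sketch of the change detection above (toy module, made up): only parameters whose stored mean differs from the current mean pick up the new description, so the untouched weight keeps its old record while the modified bias is updated. Note that an edit which happens to preserve the mean would go unnoticed by this heuristic:

from collections import defaultdict

import torch
import torch.nn as nn

from mmcv.runner.base_module import update_init_info

fc = nn.Linear(2, 2)
fc._params_init_info = defaultdict(dict)
for name, param in fc.named_parameters():
    fc._params_init_info[param]['param_name'] = name
    fc._params_init_info[param]['init_info'] = 'default init'
    fc._params_init_info[param]['tmp_mean_value'] = param.data.mean()

with torch.no_grad():
    fc.bias.fill_(1)  # change only the bias

update_init_info(fc, init_info='bias filled with 1')
assert fc._params_init_info[fc.bias]['init_info'] == 'bias filled with 1'
assert fc._params_init_info[fc.weight]['init_info'] == 'default init'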
class BaseModule(nn.Module, metaclass=ABCMeta):
"""Base module for all modules in openmmlab."""
"""Base module for all modules in openmmlab.
def __init__(self, init_cfg=None):
"""Initialize BaseModule, inherited from `torch.nn.Module`
``BaseModule`` is a wrapper of ``torch.nn.Module`` with additional
functionality of parameter initialization. Compared with
``torch.nn.Module``, ``BaseModule`` mainly adds three attributes.
Args:
init_cfg (dict, optional): Initialization config dict.
"""
- ``init_cfg``: the config to control the initialization.
- ``_params_init_info``: Used to track the parameter
initialization information.
- ``init_weights``: The function of parameter
initialization and recording initialization
information.
Args:
init_cfg (dict, optional): Initialization config dict.
"""
def __init__(self, init_cfg=None):
"""Initialize BaseModule, inherited from `torch.nn.Module`"""
# NOTE init_cfg can be defined in different levels, but init_cfg
# in low levels has a higher priority.
super(BaseModule, self).__init__()
# define default value of init_cfg instead of hard code
# in init_weights() function
self._is_init = False
self.init_cfg = copy.deepcopy(init_cfg)
# The `_params_init_info` is used to record the initialization
# information of the parameters.
# The key is the obj:`nn.Parameter` of the model and the value
# is a dict containing
# - param_name (str): The name of the parameter.
# - init_info (str): The string that describes the initialization.
# - tmp_mean_value (FloatTensor): The mean of the parameter,
# which indicates whether the parameter has been modified.
# This attribute will be deleted after all parameters are initialized.
self._params_init_info = defaultdict(dict)
# Backward compatibility in derived classes
# if pretrained is not None:
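For illustration, one representative entry of the mapping described above; the parameter name and values are invented, and the dict is keyed by the ``nn.Parameter`` object itself rather than by its name:

import torch

# one value of `_params_init_info`; names and numbers are made up
example_entry = {
    'param_name': 'conv1.weight',
    'init_info': 'NormalInit: mean=0, std=0.01, bias=0',
    'tmp_mean_value': torch.tensor(0.0013),
}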
......@@ -38,26 +82,100 @@ class BaseModule(nn.Module, metaclass=ABCMeta):
def init_weights(self):
"""Initialize the weights."""
# check if it is top-level module
is_top_level_module = len(self._params_init_info) == 0
if is_top_level_module:
# Initialize the `_params_init_info`;
# when the `tmp_mean_value` of a parameter is detected
# to have changed, the related initialization
# information is updated
for name, param in self.named_parameters():
self._params_init_info[param]['param_name'] = name
self._params_init_info[param][
'init_info'] = f'The value is the same before and ' \
f'after calling `init_weights` ' \
f'of {self.__class__.__name__} '
self._params_init_info[param][
'tmp_mean_value'] = param.data.mean()
# pass `params_init_info` to all submodules
# All submodules share the same `params_init_info`,
# so it will be updated when parameters are
# modified at any level of the model.
for sub_module in self.modules():
sub_module._params_init_info = self._params_init_info
# Get the initialized logger; if none exists,
# create a logger named `mmcv`
logger_names = list(logger_initialized.keys())
logger_name = logger_names[0] if logger_names else 'mmcv'
from ..cnn import initialize
module_name = self.__class__.__name__
if not self._is_init:
if self.init_cfg:
print_log(
f'initialize {module_name} with init_cfg {self.init_cfg}',
logger=logger_name)
initialize(self, self.init_cfg)
if isinstance(self.init_cfg, dict):
# prevent the parameters of the pre-trained model
# from being overwritten by the `init_weights`
if self.init_cfg['type'] == 'Pretrained':
return
for m in self.children():
if hasattr(m, 'init_weights'):
m.init_weights()
# users may overload the `init_weights`
update_init_info(
m,
init_info=f'Initialized by '
f'user-defined `init_weights`'
f' in {m.__class__.__name__} ')
self._is_init = True
else:
warnings.warn(f'init_weights of {self.__class__.__name__} has '
f'been called more than once.')
if is_top_level_module:
self._dump_init_info(logger_name)
for sub_module in self.modules():
del sub_module._params_init_info
@master_only
def _dump_init_info(self, logger_name):
"""Dump the initialization information to a file named
`initialization.log.json` in workdir.
Args:
logger_name (str): The name of logger.
"""
logger = get_logger(logger_name)
with_file_handler = False
# dump the information to the logger file if there is a `FileHandler`
for handler in logger.handlers:
if isinstance(handler, FileHandler):
handler.stream.write(
'Name of parameter - Initialization information\n')
for item in list(self._params_init_info.values()):
handler.stream.write(
f"{item['param_name']} - {item['init_info']} \n")
handler.stream.flush()
with_file_handler = True
if not with_file_handler:
for item in list(self._params_init_info.values()):
print_log(
f"{item['param_name']} - {item['init_info']}",
logger=logger_name)
def __repr__(self):
s = super().__repr__()
if self.init_cfg:
......
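One more hedged sketch (the module below is made up): when the picked logger has no ``FileHandler``, e.g. no log file was configured, ``_dump_init_info`` falls back to ``print_log``, so the same per-parameter lines end up on the console instead of in a file:

import torch.nn as nn

from mmcv.runner import BaseModule


class TinyModel(BaseModule):  # hypothetical module, for illustration only

    def __init__(self, init_cfg=None):
        super().__init__(init_cfg)
        self.fc = nn.Linear(2, 2)


model = TinyModel(
    init_cfg=dict(type='Constant', layer='Linear', val=0., bias=1.))
# no file logger was created, so the initialization information is printed,
# e.g. "fc.weight - ConstantInit: val=0.0, bias=1.0"
model.init_weights()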
import tempfile
import torch
from torch import nn
from mmcv.runner import BaseModule, ModuleList, Sequential
from mmcv.runner.base_module import update_init_info
from mmcv.utils import Registry, build_from_cfg
COMPONENTS = Registry('component')
......@@ -80,6 +83,96 @@ class FooModel(BaseModule):
self.reg = nn.Linear(3, 4)
def test_initilization_info_logger():
# 'override' has higher priority
import torch.nn as nn
from mmcv.utils.logging import get_logger
import os
class OverloadInitConv(nn.Conv2d, BaseModule):
def init_weights(self):
for p in self.parameters():
with torch.no_grad():
p.fill_(1)
class CheckLoggerModel(BaseModule):
def __init__(self, init_cfg=None):
super(CheckLoggerModel, self).__init__(init_cfg)
self.conv1 = nn.Conv2d(1, 1, 1, 1)
self.conv2 = OverloadInitConv(1, 1, 1, 1)
self.conv3 = nn.Conv2d(1, 1, 1, 1)
self.fc1 = nn.Linear(1, 1)
init_cfg = [
dict(
type='Normal',
layer='Conv2d',
std=0.01,
override=dict(
type='Normal', name='conv3', std=0.01, bias_prob=0.01)),
dict(type='Constant', layer='Linear', val=0., bias=1.)
]
model = CheckLoggerModel(init_cfg=init_cfg)
train_log = '20210720_132454.log'
workdir = tempfile.mkdtemp()
log_file = os.path.join(workdir, train_log)
# create a logger
get_logger('init_logger', log_file=log_file)
assert hasattr(model, '_params_init_info')
model.init_weights()
# assert `_params_init_info` would be deleted after `init_weights`
assert not hasattr(model, '_params_init_info')
# assert initialization information has been dumped
assert os.path.exists(log_file)
with open(log_file) as f:
lines = f.readlines()
for line in lines:
print(line)
# check initialization information is right
for line in lines:
if 'conv1.weight' in line:
assert 'NormalInit' in line
if 'conv2.weight' in line:
assert 'OverloadInitConv' in line
if 'fc1.weight' in line:
assert 'ConstantInit' in line
def test_update_init_info():
class DummyModel(BaseModule):
def __init__(self, init_cfg=None):
super().__init__(init_cfg)
self.conv1 = nn.Conv2d(1, 1, 1, 1)
self.conv3 = nn.Conv2d(1, 1, 1, 1)
self.fc1 = nn.Linear(1, 1)
model = DummyModel()
from collections import defaultdict
model._params_init_info = defaultdict(dict)
for name, param in model.named_parameters():
model._params_init_info[param]['param_name'] = name
model._params_init_info[param]['init_info'] = 'init'
model._params_init_info[param]['tmp_mean_value'] = param.data.mean()
with torch.no_grad():
for p in model.parameters():
p.fill_(1)
update_init_info(model, init_info='fill_1')
for item in model._params_init_info.values():
assert item['init_info'] == 'fill_1'
assert item['tmp_mean_value'] == 1
def test_model_weight_init():
"""
Config
......