Unverified Commit 999f2d08 authored by Miao Zheng's avatar Miao Zheng Committed by GitHub
Browse files

fix weight_init.py (#825)

* fix weight_init.py

* revise BaseInit args
parent 4712db75
......@@ -77,7 +77,7 @@ def bias_init_with_prob(prior_prob):
class BaseInit(object):
def __init__(self, bias, bias_prob, layer):
def __init__(self, *, bias=0, bias_prob=None, layer=None):
if not isinstance(bias, (int, float)):
raise TypeError(f'bias must be a numbel, but got a {type(bias)}')
......@@ -88,7 +88,7 @@ class BaseInit(object):
if layer is not None:
if not isinstance(layer, (str, list)):
raise TypeError(f'layer must be str or list[str], \
raise TypeError(f'layer must be a str or a list of str, \
but got a {type(layer)}')
if bias_prob is not None:
......@@ -112,8 +112,8 @@ class ConstantInit(BaseInit):
Defaults to None.
"""
def __init__(self, val, bias=0, bias_prob=None, layer=None):
super().__init__(bias, bias_prob, layer)
def __init__(self, val, **kwargs):
super().__init__(**kwargs)
self.val = val
def __call__(self, module):
......@@ -149,13 +149,8 @@ class XavierInit(BaseInit):
Defaults to None.
"""
def __init__(self,
gain=1,
bias=0,
bias_prob=None,
distribution='normal',
layer=None):
super().__init__(bias, bias_prob, layer)
def __init__(self, gain=1, distribution='normal', **kwargs):
super().__init__(**kwargs)
self.gain = gain
self.distribution = distribution
......@@ -191,8 +186,8 @@ class NormalInit(BaseInit):
"""
def __init__(self, mean=0, std=1, bias=0, bias_prob=None, layer=None):
super().__init__(bias, bias_prob, layer)
def __init__(self, mean=0, std=1, **kwargs):
super().__init__(**kwargs)
self.mean = mean
self.std = std
......@@ -228,8 +223,8 @@ class UniformInit(BaseInit):
Defaults to None.
"""
def __init__(self, a=0, b=1, bias=0, bias_prob=None, layer=None):
super().__init__(bias, bias_prob, layer)
def __init__(self, a=0, b=1, **kwargs):
super().__init__(**kwargs)
self.a = a
self.b = b
......@@ -279,11 +274,9 @@ class KaimingInit(BaseInit):
a=0,
mode='fan_out',
nonlinearity='relu',
bias=0,
bias_prob=None,
distribution='normal',
layer=None):
super().__init__(bias, bias_prob, layer)
**kwargs):
super().__init__(**kwargs)
self.a = a
self.mode = mode
self.nonlinearity = nonlinearity
......@@ -307,10 +300,15 @@ class KaimingInit(BaseInit):
@INITIALIZERS.register_module(name='Pretrained')
class PretrainedInit(object):
"""Initialize module by loading a pretrained model
"""Initialize module by loading a pretrained model.
Args:
checkpoint (str): the file should be load
prefix (str, optional): the prefix to indicate the sub-module.
checkpoint (str): the checkpoint file of the pretrained model should
be load.
prefix (str, optional): the prefix of a sub-module in the pretrained
model. it is for loading a part of the pretrained model to
initialize. For example, if we would like to only load the
backbone of a detector model, we can set ``prefix='backbone.'``.
Defaults to None.
"""
......@@ -347,8 +345,8 @@ def _initialize(module, cfg):
def _initialize_override(module, override):
if not isinstance(override, (dict, list)):
raise TypeError(
f'override must be a dict or list, but got {type(override)}')
raise TypeError(f'override must be a dict or a list of dict, \
but got {type(override)}')
override = [override] if isinstance(override, dict) else override
......@@ -366,10 +364,9 @@ def initialize(module, init_cfg):
Args:
module (``torch.nn.Module``): the module will be initialized.
init_cfg (dict | list[dict]): initialization configuration dict to
define initializer. OpenMMLab has implemented 7 initializers
define initializer. OpenMMLab has implemented 6 initializers
including ``Constant``, ``Xavier``, ``Normal``, ``Uniform``,
``Kaiming``, ``Pretrained`` and ``BiasProb`` for bias
initialization.
``Kaiming``, and ``Pretrained``.
Example:
>>> module = nn.Linear(2, 3, bias=True)
......@@ -415,7 +412,8 @@ def initialize(module, init_cfg):
checkpoint=url, prefix='backbone.')
"""
if not isinstance(init_cfg, (dict, list)):
raise TypeError(f'init_cfg must be a dict, but got {type(init_cfg)}')
raise TypeError(f'init_cfg must be a dict or a list of dict, \
but got {type(init_cfg)}')
if isinstance(init_cfg, dict):
init_cfg = [init_cfg]
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment