Unverified Commit 999f2d08 authored by Miao Zheng's avatar Miao Zheng Committed by GitHub
Browse files

fix weight_init.py (#825)

* fix weight_init.py

* revise BaseInit args
parent 4712db75
...@@ -77,7 +77,7 @@ def bias_init_with_prob(prior_prob): ...@@ -77,7 +77,7 @@ def bias_init_with_prob(prior_prob):
class BaseInit(object): class BaseInit(object):
def __init__(self, bias, bias_prob, layer): def __init__(self, *, bias=0, bias_prob=None, layer=None):
if not isinstance(bias, (int, float)): if not isinstance(bias, (int, float)):
raise TypeError(f'bias must be a number, but got a {type(bias)}') raise TypeError(f'bias must be a number, but got a {type(bias)}')
...@@ -88,7 +88,7 @@ class BaseInit(object): ...@@ -88,7 +88,7 @@ class BaseInit(object):
if layer is not None: if layer is not None:
if not isinstance(layer, (str, list)): if not isinstance(layer, (str, list)):
raise TypeError(f'layer must be str or list[str], \ raise TypeError(f'layer must be a str or a list of str, \
but got a {type(layer)}') but got a {type(layer)}')
if bias_prob is not None: if bias_prob is not None:
...@@ -112,8 +112,8 @@ class ConstantInit(BaseInit): ...@@ -112,8 +112,8 @@ class ConstantInit(BaseInit):
Defaults to None. Defaults to None.
""" """
def __init__(self, val, bias=0, bias_prob=None, layer=None): def __init__(self, val, **kwargs):
super().__init__(bias, bias_prob, layer) super().__init__(**kwargs)
self.val = val self.val = val
def __call__(self, module): def __call__(self, module):
...@@ -149,13 +149,8 @@ class XavierInit(BaseInit): ...@@ -149,13 +149,8 @@ class XavierInit(BaseInit):
Defaults to None. Defaults to None.
""" """
def __init__(self, def __init__(self, gain=1, distribution='normal', **kwargs):
gain=1, super().__init__(**kwargs)
bias=0,
bias_prob=None,
distribution='normal',
layer=None):
super().__init__(bias, bias_prob, layer)
self.gain = gain self.gain = gain
self.distribution = distribution self.distribution = distribution
...@@ -191,8 +186,8 @@ class NormalInit(BaseInit): ...@@ -191,8 +186,8 @@ class NormalInit(BaseInit):
""" """
def __init__(self, mean=0, std=1, bias=0, bias_prob=None, layer=None): def __init__(self, mean=0, std=1, **kwargs):
super().__init__(bias, bias_prob, layer) super().__init__(**kwargs)
self.mean = mean self.mean = mean
self.std = std self.std = std
...@@ -228,8 +223,8 @@ class UniformInit(BaseInit): ...@@ -228,8 +223,8 @@ class UniformInit(BaseInit):
Defaults to None. Defaults to None.
""" """
def __init__(self, a=0, b=1, bias=0, bias_prob=None, layer=None): def __init__(self, a=0, b=1, **kwargs):
super().__init__(bias, bias_prob, layer) super().__init__(**kwargs)
self.a = a self.a = a
self.b = b self.b = b
...@@ -279,11 +274,9 @@ class KaimingInit(BaseInit): ...@@ -279,11 +274,9 @@ class KaimingInit(BaseInit):
a=0, a=0,
mode='fan_out', mode='fan_out',
nonlinearity='relu', nonlinearity='relu',
bias=0,
bias_prob=None,
distribution='normal', distribution='normal',
layer=None): **kwargs):
super().__init__(bias, bias_prob, layer) super().__init__(**kwargs)
self.a = a self.a = a
self.mode = mode self.mode = mode
self.nonlinearity = nonlinearity self.nonlinearity = nonlinearity
...@@ -307,10 +300,15 @@ class KaimingInit(BaseInit): ...@@ -307,10 +300,15 @@ class KaimingInit(BaseInit):
@INITIALIZERS.register_module(name='Pretrained') @INITIALIZERS.register_module(name='Pretrained')
class PretrainedInit(object): class PretrainedInit(object):
"""Initialize module by loading a pretrained model """Initialize module by loading a pretrained model.
Args: Args:
checkpoint (str): the file should be load checkpoint (str): the checkpoint file of the pretrained model should
prefix (str, optional): the prefix to indicate the sub-module. be loaded.
prefix (str, optional): the prefix of a sub-module in the pretrained
model. It is for loading a part of the pretrained model to
initialize. For example, if we would like to only load the
backbone of a detector model, we can set ``prefix='backbone.'``.
Defaults to None. Defaults to None.
""" """
...@@ -347,8 +345,8 @@ def _initialize(module, cfg): ...@@ -347,8 +345,8 @@ def _initialize(module, cfg):
def _initialize_override(module, override): def _initialize_override(module, override):
if not isinstance(override, (dict, list)): if not isinstance(override, (dict, list)):
raise TypeError( raise TypeError(f'override must be a dict or a list of dict, \
f'override must be a dict or list, but got {type(override)}') but got {type(override)}')
override = [override] if isinstance(override, dict) else override override = [override] if isinstance(override, dict) else override
...@@ -366,10 +364,9 @@ def initialize(module, init_cfg): ...@@ -366,10 +364,9 @@ def initialize(module, init_cfg):
Args: Args:
module (``torch.nn.Module``): the module will be initialized. module (``torch.nn.Module``): the module will be initialized.
init_cfg (dict | list[dict]): initialization configuration dict to init_cfg (dict | list[dict]): initialization configuration dict to
define initializer. OpenMMLab has implemented 7 initializers define initializer. OpenMMLab has implemented 6 initializers
including ``Constant``, ``Xavier``, ``Normal``, ``Uniform``, including ``Constant``, ``Xavier``, ``Normal``, ``Uniform``,
``Kaiming``, ``Pretrained`` and ``BiasProb`` for bias ``Kaiming``, and ``Pretrained``.
initialization.
Example: Example:
>>> module = nn.Linear(2, 3, bias=True) >>> module = nn.Linear(2, 3, bias=True)
...@@ -415,7 +412,8 @@ def initialize(module, init_cfg): ...@@ -415,7 +412,8 @@ def initialize(module, init_cfg):
checkpoint=url, prefix='backbone.') checkpoint=url, prefix='backbone.')
""" """
if not isinstance(init_cfg, (dict, list)): if not isinstance(init_cfg, (dict, list)):
raise TypeError(f'init_cfg must be a dict, but got {type(init_cfg)}') raise TypeError(f'init_cfg must be a dict or a list of dict, \
but got {type(init_cfg)}')
if isinstance(init_cfg, dict): if isinstance(init_cfg, dict):
init_cfg = [init_cfg] init_cfg = [init_cfg]
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment