Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
OpenDAS
MMCV
Commits
999f2d08
Unverified
Commit
999f2d08
authored
Feb 07, 2021
by
Miao Zheng
Committed by
GitHub
Feb 07, 2021
Browse files
fix weight_init.py (#825)
* fix weight_init.py * revise BaseInit args
parent
4712db75
Changes
1
Hide whitespace changes
Inline
Side-by-side
Showing
1 changed file
with
26 additions
and
28 deletions
+26
-28
mmcv/cnn/utils/weight_init.py
mmcv/cnn/utils/weight_init.py
+26
-28
No files found.
mmcv/cnn/utils/weight_init.py
View file @
999f2d08
...
@@ -77,7 +77,7 @@ def bias_init_with_prob(prior_prob):
...
@@ -77,7 +77,7 @@ def bias_init_with_prob(prior_prob):
class
BaseInit
(
object
):
class
BaseInit
(
object
):
def
__init__
(
self
,
bias
,
bias_prob
,
layer
):
def
__init__
(
self
,
*
,
bias
=
0
,
bias_prob
=
None
,
layer
=
None
):
if
not
isinstance
(
bias
,
(
int
,
float
)):
if
not
isinstance
(
bias
,
(
int
,
float
)):
raise
TypeError
(
f
'bias must be a numbel, but got a
{
type
(
bias
)
}
'
)
raise
TypeError
(
f
'bias must be a numbel, but got a
{
type
(
bias
)
}
'
)
...
@@ -88,7 +88,7 @@ class BaseInit(object):
...
@@ -88,7 +88,7 @@ class BaseInit(object):
if
layer
is
not
None
:
if
layer
is
not
None
:
if
not
isinstance
(
layer
,
(
str
,
list
)):
if
not
isinstance
(
layer
,
(
str
,
list
)):
raise
TypeError
(
f
'layer must be str or list
[
str
]
,
\
raise
TypeError
(
f
'layer must be
a
str or
a
list
of
str,
\
but got a
{
type
(
layer
)
}
'
)
but got a
{
type
(
layer
)
}
'
)
if
bias_prob
is
not
None
:
if
bias_prob
is
not
None
:
...
@@ -112,8 +112,8 @@ class ConstantInit(BaseInit):
...
@@ -112,8 +112,8 @@ class ConstantInit(BaseInit):
Defaults to None.
Defaults to None.
"""
"""
def
__init__
(
self
,
val
,
bias
=
0
,
bias_prob
=
None
,
layer
=
None
):
def
__init__
(
self
,
val
,
**
kwargs
):
super
().
__init__
(
bias
,
bias_prob
,
layer
)
super
().
__init__
(
**
kwargs
)
self
.
val
=
val
self
.
val
=
val
def
__call__
(
self
,
module
):
def
__call__
(
self
,
module
):
...
@@ -149,13 +149,8 @@ class XavierInit(BaseInit):
...
@@ -149,13 +149,8 @@ class XavierInit(BaseInit):
Defaults to None.
Defaults to None.
"""
"""
def
__init__
(
self
,
def
__init__
(
self
,
gain
=
1
,
distribution
=
'normal'
,
**
kwargs
):
gain
=
1
,
super
().
__init__
(
**
kwargs
)
bias
=
0
,
bias_prob
=
None
,
distribution
=
'normal'
,
layer
=
None
):
super
().
__init__
(
bias
,
bias_prob
,
layer
)
self
.
gain
=
gain
self
.
gain
=
gain
self
.
distribution
=
distribution
self
.
distribution
=
distribution
...
@@ -191,8 +186,8 @@ class NormalInit(BaseInit):
...
@@ -191,8 +186,8 @@ class NormalInit(BaseInit):
"""
"""
def
__init__
(
self
,
mean
=
0
,
std
=
1
,
bias
=
0
,
bias_prob
=
None
,
layer
=
None
):
def
__init__
(
self
,
mean
=
0
,
std
=
1
,
**
kwargs
):
super
().
__init__
(
bias
,
bias_prob
,
layer
)
super
().
__init__
(
**
kwargs
)
self
.
mean
=
mean
self
.
mean
=
mean
self
.
std
=
std
self
.
std
=
std
...
@@ -228,8 +223,8 @@ class UniformInit(BaseInit):
...
@@ -228,8 +223,8 @@ class UniformInit(BaseInit):
Defaults to None.
Defaults to None.
"""
"""
def
__init__
(
self
,
a
=
0
,
b
=
1
,
bias
=
0
,
bias_prob
=
None
,
layer
=
None
):
def
__init__
(
self
,
a
=
0
,
b
=
1
,
**
kwargs
):
super
().
__init__
(
bias
,
bias_prob
,
layer
)
super
().
__init__
(
**
kwargs
)
self
.
a
=
a
self
.
a
=
a
self
.
b
=
b
self
.
b
=
b
...
@@ -279,11 +274,9 @@ class KaimingInit(BaseInit):
...
@@ -279,11 +274,9 @@ class KaimingInit(BaseInit):
a
=
0
,
a
=
0
,
mode
=
'fan_out'
,
mode
=
'fan_out'
,
nonlinearity
=
'relu'
,
nonlinearity
=
'relu'
,
bias
=
0
,
bias_prob
=
None
,
distribution
=
'normal'
,
distribution
=
'normal'
,
layer
=
None
):
**
kwargs
):
super
().
__init__
(
bias
,
bias_prob
,
layer
)
super
().
__init__
(
**
kwargs
)
self
.
a
=
a
self
.
a
=
a
self
.
mode
=
mode
self
.
mode
=
mode
self
.
nonlinearity
=
nonlinearity
self
.
nonlinearity
=
nonlinearity
...
@@ -307,10 +300,15 @@ class KaimingInit(BaseInit):
...
@@ -307,10 +300,15 @@ class KaimingInit(BaseInit):
@
INITIALIZERS
.
register_module
(
name
=
'Pretrained'
)
@
INITIALIZERS
.
register_module
(
name
=
'Pretrained'
)
class
PretrainedInit
(
object
):
class
PretrainedInit
(
object
):
"""Initialize module by loading a pretrained model
"""Initialize module by loading a pretrained model.
Args:
Args:
checkpoint (str): the file should be load
checkpoint (str): the checkpoint file of the pretrained model should
prefix (str, optional): the prefix to indicate the sub-module.
be load.
prefix (str, optional): the prefix of a sub-module in the pretrained
model. it is for loading a part of the pretrained model to
initialize. For example, if we would like to only load the
backbone of a detector model, we can set ``prefix='backbone.'``.
Defaults to None.
Defaults to None.
"""
"""
...
@@ -347,8 +345,8 @@ def _initialize(module, cfg):
...
@@ -347,8 +345,8 @@ def _initialize(module, cfg):
def
_initialize_override
(
module
,
override
):
def
_initialize_override
(
module
,
override
):
if
not
isinstance
(
override
,
(
dict
,
list
)):
if
not
isinstance
(
override
,
(
dict
,
list
)):
raise
TypeError
(
raise
TypeError
(
f
'override must be a dict or a list of dict,
\
f
'override must be a dict or list,
but got
{
type
(
override
)
}
'
)
but got
{
type
(
override
)
}
'
)
override
=
[
override
]
if
isinstance
(
override
,
dict
)
else
override
override
=
[
override
]
if
isinstance
(
override
,
dict
)
else
override
...
@@ -366,10 +364,9 @@ def initialize(module, init_cfg):
...
@@ -366,10 +364,9 @@ def initialize(module, init_cfg):
Args:
Args:
module (``torch.nn.Module``): the module will be initialized.
module (``torch.nn.Module``): the module will be initialized.
init_cfg (dict | list[dict]): initialization configuration dict to
init_cfg (dict | list[dict]): initialization configuration dict to
define initializer. OpenMMLab has implemented
7
initializers
define initializer. OpenMMLab has implemented
6
initializers
including ``Constant``, ``Xavier``, ``Normal``, ``Uniform``,
including ``Constant``, ``Xavier``, ``Normal``, ``Uniform``,
``Kaiming``, ``Pretrained`` and ``BiasProb`` for bias
``Kaiming``, and ``Pretrained``.
initialization.
Example:
Example:
>>> module = nn.Linear(2, 3, bias=True)
>>> module = nn.Linear(2, 3, bias=True)
...
@@ -415,7 +412,8 @@ def initialize(module, init_cfg):
...
@@ -415,7 +412,8 @@ def initialize(module, init_cfg):
checkpoint=url, prefix='backbone.')
checkpoint=url, prefix='backbone.')
"""
"""
if
not
isinstance
(
init_cfg
,
(
dict
,
list
)):
if
not
isinstance
(
init_cfg
,
(
dict
,
list
)):
raise
TypeError
(
f
'init_cfg must be a dict, but got
{
type
(
init_cfg
)
}
'
)
raise
TypeError
(
f
'init_cfg must be a dict or a list of dict,
\
but got
{
type
(
init_cfg
)
}
'
)
if
isinstance
(
init_cfg
,
dict
):
if
isinstance
(
init_cfg
,
dict
):
init_cfg
=
[
init_cfg
]
init_cfg
=
[
init_cfg
]
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment