Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
OpenDAS
MMCV
Commits
5f5e8e83
Unverified
Commit
5f5e8e83
authored
Mar 24, 2021
by
Miao Zheng
Committed by
GitHub
Mar 24, 2021
Browse files
[Refactoring] Add Caffe2Xavier Initializer (#902)
* [Refactoring] Add Caffe2Xavier Initializer * fix lint
parent
933b052d
Changes
4
Show whitespace changes
Inline
Side-by-side
Showing
4 changed files
with
38 additions
and
11 deletions
+38
-11
mmcv/cnn/__init__.py
mmcv/cnn/__init__.py
+4
-3
mmcv/cnn/utils/__init__.py
mmcv/cnn/utils/__init__.py
+4
-4
mmcv/cnn/utils/weight_init.py
mmcv/cnn/utils/weight_init.py
+16
-0
tests/test_cnn/test_weight_init.py
tests/test_cnn/test_weight_init.py
+14
-4
No files found.
mmcv/cnn/__init__.py
View file @
5f5e8e83
...
@@ -13,8 +13,8 @@ from .bricks import (ACTIVATION_LAYERS, CONV_LAYERS, NORM_LAYERS,
...
@@ -13,8 +13,8 @@ from .bricks import (ACTIVATION_LAYERS, CONV_LAYERS, NORM_LAYERS,
build_upsample_layer
,
conv_ws_2d
,
is_norm
)
build_upsample_layer
,
conv_ws_2d
,
is_norm
)
# yapf: enable
# yapf: enable
from
.resnet
import
ResNet
,
make_res_layer
from
.resnet
import
ResNet
,
make_res_layer
from
.utils
import
(
INITIALIZERS
,
ConstantInit
,
KaimingInit
,
NormalInit
,
from
.utils
import
(
INITIALIZERS
,
Caffe2XavierInit
,
ConstantInit
,
KaimingInit
,
PretrainedInit
,
UniformInit
,
XavierInit
,
NormalInit
,
PretrainedInit
,
UniformInit
,
XavierInit
,
bias_init_with_prob
,
caffe2_xavier_init
,
constant_init
,
bias_init_with_prob
,
caffe2_xavier_init
,
constant_init
,
fuse_conv_bn
,
get_model_complexity_info
,
initialize
,
fuse_conv_bn
,
get_model_complexity_info
,
initialize
,
kaiming_init
,
normal_init
,
uniform_init
,
xavier_init
)
kaiming_init
,
normal_init
,
uniform_init
,
xavier_init
)
...
@@ -33,5 +33,6 @@ __all__ = [
...
@@ -33,5 +33,6 @@ __all__ = [
'ConvAWS2d'
,
'ConvWS2d'
,
'fuse_conv_bn'
,
'DepthwiseSeparableConvModule'
,
'ConvAWS2d'
,
'ConvWS2d'
,
'fuse_conv_bn'
,
'DepthwiseSeparableConvModule'
,
'Linear'
,
'Conv2d'
,
'ConvTranspose2d'
,
'MaxPool2d'
,
'ConvTranspose3d'
,
'Linear'
,
'Conv2d'
,
'ConvTranspose2d'
,
'MaxPool2d'
,
'ConvTranspose3d'
,
'MaxPool3d'
,
'Conv3d'
,
'initialize'
,
'INITIALIZERS'
,
'ConstantInit'
,
'MaxPool3d'
,
'Conv3d'
,
'initialize'
,
'INITIALIZERS'
,
'ConstantInit'
,
'XavierInit'
,
'NormalInit'
,
'UniformInit'
,
'KaimingInit'
,
'PretrainedInit'
'XavierInit'
,
'NormalInit'
,
'UniformInit'
,
'KaimingInit'
,
'PretrainedInit'
,
'Caffe2XavierInit'
]
]
mmcv/cnn/utils/__init__.py
View file @
5f5e8e83
# Copyright (c) Open-MMLab. All rights reserved.
# Copyright (c) Open-MMLab. All rights reserved.
from
.flops_counter
import
get_model_complexity_info
from
.flops_counter
import
get_model_complexity_info
from
.fuse_conv_bn
import
fuse_conv_bn
from
.fuse_conv_bn
import
fuse_conv_bn
from
.weight_init
import
(
INITIALIZERS
,
C
onstantInit
,
KaimingInit
,
Normal
Init
,
from
.weight_init
import
(
INITIALIZERS
,
C
affe2XavierInit
,
Constant
Init
,
PretrainedInit
,
UniformInit
,
XavierInit
,
KaimingInit
,
NormalInit
,
PretrainedInit
,
UniformInit
,
bias_init_with_prob
,
caffe2_xavier_init
,
XavierInit
,
bias_init_with_prob
,
caffe2_xavier_init
,
constant_init
,
initialize
,
kaiming_init
,
normal_init
,
constant_init
,
initialize
,
kaiming_init
,
normal_init
,
uniform_init
,
xavier_init
)
uniform_init
,
xavier_init
)
...
@@ -12,5 +12,5 @@ __all__ = [
...
@@ -12,5 +12,5 @@ __all__ = [
'constant_init'
,
'kaiming_init'
,
'normal_init'
,
'uniform_init'
,
'constant_init'
,
'kaiming_init'
,
'normal_init'
,
'uniform_init'
,
'xavier_init'
,
'fuse_conv_bn'
,
'initialize'
,
'INITIALIZERS'
,
'xavier_init'
,
'fuse_conv_bn'
,
'initialize'
,
'INITIALIZERS'
,
'ConstantInit'
,
'XavierInit'
,
'NormalInit'
,
'UniformInit'
,
'KaimingInit'
,
'ConstantInit'
,
'XavierInit'
,
'NormalInit'
,
'UniformInit'
,
'KaimingInit'
,
'PretrainedInit'
'PretrainedInit'
,
'Caffe2XavierInit'
]
]
mmcv/cnn/utils/weight_init.py
View file @
5f5e8e83
...
@@ -298,6 +298,22 @@ class KaimingInit(BaseInit):
...
@@ -298,6 +298,22 @@ class KaimingInit(BaseInit):
module
.
apply
(
init
)
module
.
apply
(
init
)
@
INITIALIZERS
.
register_module
(
name
=
'Caffe2Xavier'
)
class
Caffe2XavierInit
(
KaimingInit
):
# `XavierFill` in Caffe2 corresponds to `kaiming_uniform_` in PyTorch
# Acknowledgment to FAIR's internal code
def
__init__
(
self
,
**
kwargs
):
super
().
__init__
(
a
=
1
,
mode
=
'fan_in'
,
nonlinearity
=
'leaky_relu'
,
distribution
=
'uniform'
,
**
kwargs
)
def
__call__
(
self
,
module
):
super
().
__call__
(
module
)
@
INITIALIZERS
.
register_module
(
name
=
'Pretrained'
)
@
INITIALIZERS
.
register_module
(
name
=
'Pretrained'
)
class
PretrainedInit
(
object
):
class
PretrainedInit
(
object
):
"""Initialize module by loading a pretrained model.
"""Initialize module by loading a pretrained model.
...
...
tests/test_cnn/test_weight_init.py
View file @
5f5e8e83
...
@@ -6,10 +6,11 @@ import pytest
...
@@ -6,10 +6,11 @@ import pytest
import
torch
import
torch
from
torch
import
nn
from
torch
import
nn
from
mmcv.cnn
import
(
ConstantInit
,
KaimingInit
,
NormalInit
,
PretrainedInit
,
from
mmcv.cnn
import
(
Caffe2XavierInit
,
ConstantInit
,
KaimingInit
,
NormalInit
,
UniformInit
,
XavierInit
,
bias_init_with_prob
,
PretrainedInit
,
UniformInit
,
XavierInit
,
caffe2_xavier_init
,
constant_init
,
initialize
,
bias_init_with_prob
,
caffe2_xavier_init
,
constant_init
,
kaiming_init
,
normal_init
,
uniform_init
,
xavier_init
)
initialize
,
kaiming_init
,
normal_init
,
uniform_init
,
xavier_init
)
def
test_constant_init
():
def
test_constant_init
():
...
@@ -219,6 +220,15 @@ def test_kaiminginit():
...
@@ -219,6 +220,15 @@ def test_kaiminginit():
assert
torch
.
equal
(
model
[
2
].
bias
,
torch
.
full
(
model
[
2
].
bias
.
shape
,
10.
))
assert
torch
.
equal
(
model
[
2
].
bias
,
torch
.
full
(
model
[
2
].
bias
.
shape
,
10.
))
def
test_caffe2xavierinit
():
"""test Caffe2XavierInit."""
model
=
nn
.
Sequential
(
nn
.
Conv2d
(
3
,
1
,
3
),
nn
.
ReLU
(),
nn
.
Linear
(
1
,
2
))
func
=
Caffe2XavierInit
(
bias
=
0.1
,
layer
=
'Conv2d'
)
func
(
model
)
assert
torch
.
equal
(
model
[
0
].
bias
,
torch
.
full
(
model
[
0
].
bias
.
shape
,
0.1
))
assert
not
torch
.
equal
(
model
[
2
].
bias
,
torch
.
full
(
model
[
2
].
bias
.
shape
,
0.1
))
class
FooModule
(
nn
.
Module
):
class
FooModule
(
nn
.
Module
):
def
__init__
(
self
):
def
__init__
(
self
):
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment