Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
OpenDAS
MMCV
Commits
f64d4858
Unverified
Commit
f64d4858
authored
Aug 03, 2023
by
youkaichao
Committed by
GitHub
Aug 03, 2023
Browse files
rename fast_conv_bn_eval to efficient_conv_bn_eval (#2884)
parent
ad7284e8
Changes
2
Hide whitespace changes
Inline
Side-by-side
Showing
2 changed files
with
42 additions
and
37 deletions
+42
-37
mmcv/cnn/bricks/conv_module.py
mmcv/cnn/bricks/conv_module.py
+22
-20
tests/test_cnn/test_conv_module.py
tests/test_cnn/test_conv_module.py
+20
-17
No files found.
mmcv/cnn/bricks/conv_module.py
View file @
f64d4858
...
@@ -15,8 +15,9 @@ from .norm import build_norm_layer
...
@@ -15,8 +15,9 @@ from .norm import build_norm_layer
from
.padding
import
build_padding_layer
from
.padding
import
build_padding_layer
def
fast_conv_bn_eval_forward
(
bn
:
_BatchNorm
,
conv
:
nn
.
modules
.
conv
.
_ConvNd
,
def
efficient_conv_bn_eval_forward
(
bn
:
_BatchNorm
,
x
:
torch
.
Tensor
):
conv
:
nn
.
modules
.
conv
.
_ConvNd
,
x
:
torch
.
Tensor
):
"""
"""
Implementation based on https://arxiv.org/abs/2305.11624
Implementation based on https://arxiv.org/abs/2305.11624
"Tune-Mode ConvBN Blocks For Efficient Transfer Learning"
"Tune-Mode ConvBN Blocks For Efficient Transfer Learning"
...
@@ -115,9 +116,9 @@ class ConvModule(nn.Module):
...
@@ -115,9 +116,9 @@ class ConvModule(nn.Module):
sequence of "conv", "norm" and "act". Common examples are
sequence of "conv", "norm" and "act". Common examples are
("conv", "norm", "act") and ("act", "conv", "norm").
("conv", "norm", "act") and ("act", "conv", "norm").
Default: ('conv', 'norm', 'act').
Default: ('conv', 'norm', 'act').
fas
t_conv_bn_eval (bool): Whether use
fas
t conv when the
consecutive
efficien
t_conv_bn_eval (bool): Whether use
efficien
t conv when the
bn is in eval mode (either training or testing), as
proposed in
consecutive
bn is in eval mode (either training or testing), as
https://arxiv.org/abs/2305.11624 . Default: False.
proposed in
https://arxiv.org/abs/2305.11624 . Default:
`
False
`
.
"""
"""
_abbr_
=
'conv_block'
_abbr_
=
'conv_block'
...
@@ -138,7 +139,7 @@ class ConvModule(nn.Module):
...
@@ -138,7 +139,7 @@ class ConvModule(nn.Module):
with_spectral_norm
:
bool
=
False
,
with_spectral_norm
:
bool
=
False
,
padding_mode
:
str
=
'zeros'
,
padding_mode
:
str
=
'zeros'
,
order
:
tuple
=
(
'conv'
,
'norm'
,
'act'
),
order
:
tuple
=
(
'conv'
,
'norm'
,
'act'
),
fas
t_conv_bn_eval
:
bool
=
False
):
efficien
t_conv_bn_eval
:
bool
=
False
):
super
().
__init__
()
super
().
__init__
()
assert
conv_cfg
is
None
or
isinstance
(
conv_cfg
,
dict
)
assert
conv_cfg
is
None
or
isinstance
(
conv_cfg
,
dict
)
assert
norm_cfg
is
None
or
isinstance
(
norm_cfg
,
dict
)
assert
norm_cfg
is
None
or
isinstance
(
norm_cfg
,
dict
)
...
@@ -209,7 +210,7 @@ class ConvModule(nn.Module):
...
@@ -209,7 +210,7 @@ class ConvModule(nn.Module):
else
:
else
:
self
.
norm_name
=
None
# type: ignore
self
.
norm_name
=
None
# type: ignore
self
.
turn_on_
fas
t_conv_bn_eval
(
fas
t_conv_bn_eval
)
self
.
turn_on_
efficien
t_conv_bn_eval
(
efficien
t_conv_bn_eval
)
# build activation layer
# build activation layer
if
self
.
with_activation
:
if
self
.
with_activation
:
...
@@ -263,15 +264,16 @@ class ConvModule(nn.Module):
...
@@ -263,15 +264,16 @@ class ConvModule(nn.Module):
if
self
.
with_explicit_padding
:
if
self
.
with_explicit_padding
:
x
=
self
.
padding_layer
(
x
)
x
=
self
.
padding_layer
(
x
)
# if the next operation is norm and we have a norm layer in
# if the next operation is norm and we have a norm layer in
# eval mode and we have enabled
fas
t_conv_bn_eval for
the conv
# eval mode and we have enabled
`efficien
t_conv_bn_eval
`
for
# operator, then activate the optimized forward and
skip the
#
the conv
operator, then activate the optimized forward and
# next norm operator since it has been fused
#
skip the
next norm operator since it has been fused
if
layer_index
+
1
<
len
(
self
.
order
)
and
\
if
layer_index
+
1
<
len
(
self
.
order
)
and
\
self
.
order
[
layer_index
+
1
]
==
'norm'
and
norm
and
\
self
.
order
[
layer_index
+
1
]
==
'norm'
and
norm
and
\
self
.
with_norm
and
not
self
.
norm
.
training
and
\
self
.
with_norm
and
not
self
.
norm
.
training
and
\
self
.
fast_conv_bn_eval_forward
is
not
None
:
self
.
efficient_conv_bn_eval_forward
is
not
None
:
self
.
conv
.
forward
=
partial
(
self
.
fast_conv_bn_eval_forward
,
self
.
conv
.
forward
=
partial
(
self
.
norm
,
self
.
conv
)
self
.
efficient_conv_bn_eval_forward
,
self
.
norm
,
self
.
conv
)
layer_index
+=
1
layer_index
+=
1
x
=
self
.
conv
(
x
)
x
=
self
.
conv
(
x
)
del
self
.
conv
.
forward
del
self
.
conv
.
forward
...
@@ -284,20 +286,20 @@ class ConvModule(nn.Module):
...
@@ -284,20 +286,20 @@ class ConvModule(nn.Module):
layer_index
+=
1
layer_index
+=
1
return
x
return
x
def
turn_on_
fas
t_conv_bn_eval
(
self
,
fas
t_conv_bn_eval
=
True
):
def
turn_on_
efficien
t_conv_bn_eval
(
self
,
efficien
t_conv_bn_eval
=
True
):
#
fas
t_conv_bn_eval works for conv + bn
#
efficien
t_conv_bn_eval works for conv + bn
# with `track_running_stats` option
# with `track_running_stats` option
if
fas
t_conv_bn_eval
and
self
.
norm
\
if
efficien
t_conv_bn_eval
and
self
.
norm
\
and
isinstance
(
self
.
norm
,
_BatchNorm
)
\
and
isinstance
(
self
.
norm
,
_BatchNorm
)
\
and
self
.
norm
.
track_running_stats
:
and
self
.
norm
.
track_running_stats
:
self
.
fas
t_conv_bn_eval_forward
=
fas
t_conv_bn_eval_forward
self
.
efficien
t_conv_bn_eval_forward
=
efficien
t_conv_bn_eval_forward
# noqa: E501
else
:
else
:
self
.
fas
t_conv_bn_eval_forward
=
None
# type: ignore
self
.
efficien
t_conv_bn_eval_forward
=
None
# type: ignore
@
staticmethod
@
staticmethod
def
create_from_conv_bn
(
conv
:
torch
.
nn
.
modules
.
conv
.
_ConvNd
,
def
create_from_conv_bn
(
conv
:
torch
.
nn
.
modules
.
conv
.
_ConvNd
,
bn
:
torch
.
nn
.
modules
.
batchnorm
.
_BatchNorm
,
bn
:
torch
.
nn
.
modules
.
batchnorm
.
_BatchNorm
,
fas
t_conv_bn_eval
=
True
)
->
'ConvModule'
:
efficien
t_conv_bn_eval
=
True
)
->
'ConvModule'
:
"""Create a ConvModule from a conv and a bn module."""
"""Create a ConvModule from a conv and a bn module."""
self
=
ConvModule
.
__new__
(
ConvModule
)
self
=
ConvModule
.
__new__
(
ConvModule
)
super
(
ConvModule
,
self
).
__init__
()
super
(
ConvModule
,
self
).
__init__
()
...
@@ -331,6 +333,6 @@ class ConvModule(nn.Module):
...
@@ -331,6 +333,6 @@ class ConvModule(nn.Module):
self
.
norm_name
,
norm
=
'bn'
,
bn
self
.
norm_name
,
norm
=
'bn'
,
bn
self
.
add_module
(
self
.
norm_name
,
norm
)
self
.
add_module
(
self
.
norm_name
,
norm
)
self
.
turn_on_
fas
t_conv_bn_eval
(
fas
t_conv_bn_eval
)
self
.
turn_on_
efficien
t_conv_bn_eval
(
efficien
t_conv_bn_eval
)
return
self
return
self
tests/test_cnn/test_conv_module.py
View file @
f64d4858
...
@@ -75,27 +75,30 @@ def test_conv_module():
...
@@ -75,27 +75,30 @@ def test_conv_module():
output
=
conv
(
x
)
output
=
conv
(
x
)
assert
output
.
shape
==
(
1
,
8
,
255
,
255
)
assert
output
.
shape
==
(
1
,
8
,
255
,
255
)
# conv + norm with
fas
t mode
# conv + norm with
efficien
t mode
fas
t_conv
=
ConvModule
(
efficien
t_conv
=
ConvModule
(
3
,
8
,
2
,
norm_cfg
=
dict
(
type
=
'BN'
),
fas
t_conv_bn_eval
=
True
).
eval
()
3
,
8
,
2
,
norm_cfg
=
dict
(
type
=
'BN'
),
efficien
t_conv_bn_eval
=
True
).
eval
()
plain_conv
=
ConvModule
(
plain_conv
=
ConvModule
(
3
,
8
,
2
,
norm_cfg
=
dict
(
type
=
'BN'
),
fast_conv_bn_eval
=
False
).
eval
()
3
,
8
,
2
,
norm_cfg
=
dict
(
type
=
'BN'
),
for
fast_param
,
plain_param
in
zip
(
fast_conv
.
state_dict
().
values
(),
efficient_conv_bn_eval
=
False
).
eval
()
plain_conv
.
state_dict
().
values
()):
for
efficient_param
,
plain_param
in
zip
(
plain_param
.
copy_
(
fast_param
)
efficient_conv
.
state_dict
().
values
(),
plain_conv
.
state_dict
().
values
()):
fast_mode_output
=
fast_conv
(
x
)
plain_param
.
copy_
(
efficient_param
)
efficient_mode_output
=
efficient_conv
(
x
)
plain_mode_output
=
plain_conv
(
x
)
plain_mode_output
=
plain_conv
(
x
)
assert
torch
.
allclose
(
fas
t_mode_output
,
plain_mode_output
,
atol
=
1e-5
)
assert
torch
.
allclose
(
efficien
t_mode_output
,
plain_mode_output
,
atol
=
1e-5
)
# `conv` attribute can be dynamically modified in
fas
t mode
# `conv` attribute can be dynamically modified in
efficien
t mode
fas
t_conv
=
ConvModule
(
efficien
t_conv
=
ConvModule
(
3
,
8
,
2
,
norm_cfg
=
dict
(
type
=
'BN'
),
fas
t_conv_bn_eval
=
True
).
eval
()
3
,
8
,
2
,
norm_cfg
=
dict
(
type
=
'BN'
),
efficien
t_conv_bn_eval
=
True
).
eval
()
new_conv
=
nn
.
Conv2d
(
3
,
8
,
2
).
eval
()
new_conv
=
nn
.
Conv2d
(
3
,
8
,
2
).
eval
()
fast_conv
.
conv
=
new_conv
efficient_conv
.
conv
=
new_conv
fast_mode_output
=
fast_conv
(
x
)
efficient_mode_output
=
efficient_conv
(
x
)
plain_mode_output
=
fast_conv
.
activate
(
fast_conv
.
norm
(
new_conv
(
x
)))
plain_mode_output
=
efficient_conv
.
activate
(
assert
torch
.
allclose
(
fast_mode_output
,
plain_mode_output
,
atol
=
1e-5
)
efficient_conv
.
norm
(
new_conv
(
x
)))
assert
torch
.
allclose
(
efficient_mode_output
,
plain_mode_output
,
atol
=
1e-5
)
# conv + act
# conv + act
conv
=
ConvModule
(
3
,
8
,
2
)
conv
=
ConvModule
(
3
,
8
,
2
)
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment