OpenDAS / MMCV · Commits

Commit f64d4858 (unverified), authored Aug 03, 2023 by youkaichao, committed by GitHub on Aug 03, 2023

    rename fast_conv_bn_eval to efficient_conv_bn_eval (#2884)

Parent: ad7284e8

Showing 2 changed files with 42 additions and 37 deletions:

    mmcv/cnn/bricks/conv_module.py        +22  -20
    tests/test_cnn/test_conv_module.py    +20  -17
mmcv/cnn/bricks/conv_module.py
@@ -15,8 +15,9 @@ from .norm import build_norm_layer
 from .padding import build_padding_layer
 
 
-def fast_conv_bn_eval_forward(bn: _BatchNorm, conv: nn.modules.conv._ConvNd,
-                              x: torch.Tensor):
+def efficient_conv_bn_eval_forward(bn: _BatchNorm,
+                                   conv: nn.modules.conv._ConvNd,
+                                   x: torch.Tensor):
     """
     Implementation based on https://arxiv.org/abs/2305.11624
     "Tune-Mode ConvBN Blocks For Efficient Transfer Learning"
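Context for the renamed helper: when the BN layer is in eval mode, its affine transform over frozen running statistics can be folded into the convolution's weight and bias on the fly, so the Conv+BN pair executes as a single convolution. Below is a minimal sketch of that folding, assuming an affine BN with tracked running stats; the helper name fused_conv_bn_eval_sketch is hypothetical and this is not the MMCV implementation:

import torch
import torch.nn as nn
from torch.nn.modules.batchnorm import _BatchNorm


def fused_conv_bn_eval_sketch(bn: _BatchNorm, conv: nn.modules.conv._ConvNd,
                              x: torch.Tensor) -> torch.Tensor:
    """Fold eval-mode BN statistics into the conv, then run one convolution.

    A sketch of the idea only; assumes an affine BN layer with tracked
    running statistics.
    """
    # Per-output-channel scale: gamma / sqrt(running_var + eps)
    scale = bn.weight / torch.sqrt(bn.running_var + bn.eps)
    # Fold the scale into the conv weight (broadcast over non-output dims).
    weight = conv.weight * scale.reshape(-1, *([1] * (conv.weight.dim() - 1)))
    # Fold the shift into the bias: beta + (b - running_mean) * scale
    bias = conv.bias if conv.bias is not None else torch.zeros_like(scale)
    bias = bn.bias + (bias - bn.running_mean) * scale
    # Conv1d/2d/3d define _conv_forward(input, weight, bias) internally, which
    # reuses the module's stride/padding/dilation settings.
    return conv._conv_forward(x, weight, bias)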
@@ -115,9 +116,9 @@ class ConvModule(nn.Module):
             sequence of "conv", "norm" and "act". Common examples are
             ("conv", "norm", "act") and ("act", "conv", "norm").
             Default: ('conv', 'norm', 'act').
-        fast_conv_bn_eval (bool): Whether use fast conv when the consecutive
-            bn is in eval mode (either training or testing), as proposed in
-            https://arxiv.org/abs/2305.11624 . Default: False.
+        efficient_conv_bn_eval (bool): Whether use efficient conv when the
+            consecutive bn is in eval mode (either training or testing), as
+            proposed in https://arxiv.org/abs/2305.11624 . Default: `False`.
     """
 
     _abbr_ = 'conv_block'
@@ -138,7 +139,7 @@ class ConvModule(nn.Module):
                  with_spectral_norm: bool = False,
                  padding_mode: str = 'zeros',
                  order: tuple = ('conv', 'norm', 'act'),
-                 fast_conv_bn_eval: bool = False):
+                 efficient_conv_bn_eval: bool = False):
         super().__init__()
         assert conv_cfg is None or isinstance(conv_cfg, dict)
         assert norm_cfg is None or isinstance(norm_cfg, dict)
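Enabling the renamed constructor flag from user code looks like this (a usage sketch modeled on the test changes later in this commit):

import torch
from mmcv.cnn import ConvModule

# Build a Conv+BN block with the optimization enabled; the fused path only
# activates while the norm layer is in eval mode.
module = ConvModule(
    3, 8, 2, norm_cfg=dict(type='BN'), efficient_conv_bn_eval=True).eval()
out = module(torch.rand(1, 3, 32, 32))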
@@ -209,7 +210,7 @@ class ConvModule(nn.Module):
         else:
             self.norm_name = None  # type: ignore
 
-        self.turn_on_fast_conv_bn_eval(fast_conv_bn_eval)
+        self.turn_on_efficient_conv_bn_eval(efficient_conv_bn_eval)
 
         # build activation layer
         if self.with_activation:
@@ -263,15 +264,16 @@ class ConvModule(nn.Module):
                 if self.with_explicit_padding:
                     x = self.padding_layer(x)
                 # if the next operation is norm and we have a norm layer in
-                # eval mode and we have enabled fast_conv_bn_eval for the conv
-                # operator, then activate the optimized forward and skip the
-                # next norm operator since it has been fused
+                # eval mode and we have enabled `efficient_conv_bn_eval` for
+                # the conv operator, then activate the optimized forward and
+                # skip the next norm operator since it has been fused
                 if layer_index + 1 < len(self.order) and \
                         self.order[layer_index + 1] == 'norm' and norm and \
                         self.with_norm and not self.norm.training and \
-                        self.fast_conv_bn_eval_forward is not None:
-                    self.conv.forward = partial(
-                        self.fast_conv_bn_eval_forward, self.norm, self.conv)
+                        self.efficient_conv_bn_eval_forward is not None:
+                    self.conv.forward = partial(
+                        self.efficient_conv_bn_eval_forward, self.norm,
+                        self.conv)
                     layer_index += 1
                     x = self.conv(x)
                     del self.conv.forward
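The patching above works because an instance attribute named forward shadows the class-level nn.Module.forward for that one module, and functools.partial pre-binds the norm and conv arguments; del self.conv.forward afterwards restores the class method. A self-contained illustration of the pattern, with a hypothetical stand-in for efficient_conv_bn_eval_forward:

from functools import partial

import torch
import torch.nn as nn


def fused_forward(bn, conv, x):
    # Stand-in for efficient_conv_bn_eval_forward; here it just runs the conv.
    return conv._conv_forward(x, conv.weight, conv.bias)


conv = nn.Conv2d(3, 8, 3)
bn = nn.BatchNorm2d(8).eval()

# The instance attribute shadows nn.Conv2d.forward for this object only.
conv.forward = partial(fused_forward, bn, conv)
y = conv(torch.rand(1, 3, 16, 16))  # dispatches to fused_forward(bn, conv, x)

# Deleting the instance attribute falls back to the class-level forward.
del conv.forward
y = conv(torch.rand(1, 3, 16, 16))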
@@ -284,20 +286,20 @@ class ConvModule(nn.Module):
                 layer_index += 1
         return x
 
-    def turn_on_fast_conv_bn_eval(self, fast_conv_bn_eval=True):
-        # fast_conv_bn_eval works for conv + bn
+    def turn_on_efficient_conv_bn_eval(self, efficient_conv_bn_eval=True):
+        # efficient_conv_bn_eval works for conv + bn
         # with `track_running_stats` option
-        if fast_conv_bn_eval and self.norm \
+        if efficient_conv_bn_eval and self.norm \
                 and isinstance(self.norm, _BatchNorm) \
                 and self.norm.track_running_stats:
-            self.fast_conv_bn_eval_forward = fast_conv_bn_eval_forward
+            self.efficient_conv_bn_eval_forward = efficient_conv_bn_eval_forward  # noqa: E501
         else:
-            self.fast_conv_bn_eval_forward = None  # type: ignore
+            self.efficient_conv_bn_eval_forward = None  # type: ignore
 
     @staticmethod
     def create_from_conv_bn(conv: torch.nn.modules.conv._ConvNd,
                             bn: torch.nn.modules.batchnorm._BatchNorm,
-                            fast_conv_bn_eval=True) -> 'ConvModule':
+                            efficient_conv_bn_eval=True) -> 'ConvModule':
        """Create a ConvModule from a conv and a bn module."""
        self = ConvModule.__new__(ConvModule)
        super(ConvModule, self).__init__()
@@ -331,6 +333,6 @@ class ConvModule(nn.Module):
         self.norm_name, norm = 'bn', bn
         self.add_module(self.norm_name, norm)
 
-        self.turn_on_fast_conv_bn_eval(fast_conv_bn_eval)
+        self.turn_on_efficient_conv_bn_eval(efficient_conv_bn_eval)
 
         return self
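For reference, create_from_conv_bn wraps an existing conv/bn pair without re-initializing their weights (the hunk above registers the given bn module directly). A hedged usage sketch:

import torch.nn as nn
from mmcv.cnn import ConvModule

conv = nn.Conv2d(3, 8, 3)
bn = nn.BatchNorm2d(8)
# The guard in turn_on_efficient_conv_bn_eval requires a _BatchNorm with
# track_running_stats=True; otherwise the optimized forward stays disabled.
module = ConvModule.create_from_conv_bn(conv, bn, efficient_conv_bn_eval=True)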
tests/test_cnn/test_conv_module.py
@@ -75,27 +75,30 @@ def test_conv_module():
     output = conv(x)
     assert output.shape == (1, 8, 255, 255)
 
-    # conv + norm with fast mode
-    fast_conv = ConvModule(
-        3, 8, 2, norm_cfg=dict(type='BN'), fast_conv_bn_eval=True).eval()
+    # conv + norm with efficient mode
+    efficient_conv = ConvModule(
+        3, 8, 2, norm_cfg=dict(type='BN'), efficient_conv_bn_eval=True).eval()
     plain_conv = ConvModule(
-        3, 8, 2, norm_cfg=dict(type='BN'), fast_conv_bn_eval=False).eval()
-    for fast_param, plain_param in zip(fast_conv.state_dict().values(),
-                                       plain_conv.state_dict().values()):
-        plain_param.copy_(fast_param)
-    fast_mode_output = fast_conv(x)
+        3, 8, 2, norm_cfg=dict(type='BN'),
+        efficient_conv_bn_eval=False).eval()
+    for efficient_param, plain_param in zip(
+            efficient_conv.state_dict().values(),
+            plain_conv.state_dict().values()):
+        plain_param.copy_(efficient_param)
+    efficient_mode_output = efficient_conv(x)
     plain_mode_output = plain_conv(x)
-    assert torch.allclose(fast_mode_output, plain_mode_output, atol=1e-5)
+    assert torch.allclose(efficient_mode_output, plain_mode_output, atol=1e-5)
 
-    # `conv` attribute can be dynamically modified in fast mode
-    fast_conv = ConvModule(
-        3, 8, 2, norm_cfg=dict(type='BN'), fast_conv_bn_eval=True).eval()
+    # `conv` attribute can be dynamically modified in efficient mode
+    efficient_conv = ConvModule(
+        3, 8, 2, norm_cfg=dict(type='BN'), efficient_conv_bn_eval=True).eval()
     new_conv = nn.Conv2d(3, 8, 2).eval()
-    fast_conv.conv = new_conv
-    fast_mode_output = fast_conv(x)
-    plain_mode_output = fast_conv.activate(
-        fast_conv.norm(new_conv(x)))
-    assert torch.allclose(fast_mode_output, plain_mode_output, atol=1e-5)
+    efficient_conv.conv = new_conv
+    efficient_mode_output = efficient_conv(x)
+    plain_mode_output = efficient_conv.activate(
+        efficient_conv.norm(new_conv(x)))
+    assert torch.allclose(
+        efficient_mode_output, plain_mode_output, atol=1e-5)
 
     # conv + act
     conv = ConvModule(3, 8, 2)
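The test asserts that efficient mode matches the plain sequential path within atol=1e-5. The same equivalence can be checked against plain PyTorch with the hypothetical fused_conv_bn_eval_sketch helper sketched earlier on this page:

import torch
import torch.nn as nn

conv = nn.Conv2d(3, 8, 2).eval()
bn = nn.BatchNorm2d(8).eval()
# Give BN non-trivial frozen statistics so the check is meaningful.
bn.running_mean.uniform_(-1.0, 1.0)
bn.running_var.uniform_(0.5, 2.0)
x = torch.rand(1, 3, 16, 16)

reference = bn(conv(x))                         # sequential Conv -> BN (eval)
fused = fused_conv_bn_eval_sketch(bn, conv, x)  # folded single convolution
assert torch.allclose(reference, fused, atol=1e-5)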