Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
OpenDAS
MMCV
Commits
c4588297
Commit
c4588297
authored
Jan 17, 2023
by
Mashiro
Committed by
Zaida Zhou
Mar 20, 2023
Browse files
Refine rfsearch and fix a typo
parent
1f9e5b57
Changes
6
Hide whitespace changes
Inline
Side-by-side
Showing
6 changed files
with
31 additions
and
78 deletions
+31
-78
docs/en/api/cnn.rst
docs/en/api/cnn.rst
+1
-0
docs/zh_cn/api/cnn.rst
docs/zh_cn/api/cnn.rst
+1
-0
mmcv/cnn/rfsearch/operator.py
mmcv/cnn/rfsearch/operator.py
+11
-12
mmcv/cnn/rfsearch/search.py
mmcv/cnn/rfsearch/search.py
+17
-16
mmcv/image/geometric.py
mmcv/image/geometric.py
+1
-1
tests/test_cnn/test_rfsearch/test_search.py
tests/test_cnn/test_rfsearch/test_search.py
+0
-49
No files found.
docs/en/api/cnn.rst
View file @
c4588297
...
...
@@ -40,6 +40,7 @@ Module
NonLocal3d
Scale
Swish
Conv2dRFSearchOp
Build Function
----------------
...
...
docs/zh_cn/api/cnn.rst
View file @
c4588297
...
...
@@ -40,6 +40,7 @@ Module
NonLocal3d
Scale
Swish
Conv2dRFSearchOp
Build Function
----------------
...
...
mmcv/cnn/rfsearch/operator.py
View file @
c4588297
...
...
@@ -4,14 +4,12 @@ import copy
import
numpy
as
np
import
torch
import
torch.nn
as
nn
from
mmengine.logging
import
MMLogger
from
mmengine.logging
import
print_log
from
mmengine.model
import
BaseModule
from
torch
import
Tensor
from
.utils
import
expand_rates
,
get_single_padding
logger
=
MMLogger
.
get_current_instance
()
class
BaseConvRFSearchOp
(
BaseModule
):
"""Based class of ConvRFSearchOp.
...
...
@@ -84,7 +82,7 @@ class Conv2dRFSearchOp(BaseConvRFSearchOp):
self
.
branch_weights
=
nn
.
Parameter
(
torch
.
Tensor
(
self
.
num_branches
))
if
self
.
verbose
:
logger
.
info
(
f
'Expand as
{
self
.
dilation_rates
}
'
)
print_log
(
f
'Expand as
{
self
.
dilation_rates
}
'
,
'current'
)
nn
.
init
.
constant_
(
self
.
branch_weights
,
global_config
[
'init_alphas'
])
def
forward
(
self
,
input
:
Tensor
)
->
Tensor
:
...
...
@@ -118,13 +116,14 @@ class Conv2dRFSearchOp(BaseConvRFSearchOp):
output
+=
outputs
[
i
]
return
output
def
estimate_rates
(
self
):
def
estimate_rates
(
self
)
->
None
:
"""Estimate new dilation rate based on trained branch_weights."""
norm_w
=
self
.
normlize
(
self
.
branch_weights
[:
len
(
self
.
dilation_rates
)])
if
self
.
verbose
:
logger
.
info
(
'Estimate dilation {} with weight {}.'
.
format
(
self
.
dilation_rates
,
norm_w
.
detach
().
cpu
().
numpy
().
tolist
()))
print_log
(
'Estimate dilation {} with weight {}.'
.
format
(
self
.
dilation_rates
,
norm_w
.
detach
().
cpu
().
numpy
().
tolist
()),
'current'
)
sum0
,
sum1
,
w_sum
=
0
,
0
,
0
for
i
in
range
(
len
(
self
.
dilation_rates
)):
...
...
@@ -143,9 +142,9 @@ class Conv2dRFSearchOp(BaseConvRFSearchOp):
self
.
op_layer
.
padding
=
self
.
get_padding
(
self
.
op_layer
.
dilation
)
self
.
dilation_rates
=
[
tuple
(
estimated
)]
if
self
.
verbose
:
logger
.
info
(
f
'Estimate as
{
tuple
(
estimated
)
}
'
)
print_log
(
f
'Estimate as
{
tuple
(
estimated
)
}
'
,
'current'
)
def
expand_rates
(
self
):
def
expand_rates
(
self
)
->
None
:
"""Expand dilation rate."""
dilation
=
self
.
op_layer
.
dilation
dilation_rates
=
expand_rates
(
dilation
,
self
.
global_config
)
...
...
@@ -158,11 +157,11 @@ class Conv2dRFSearchOp(BaseConvRFSearchOp):
self
.
dilation_rates
=
copy
.
deepcopy
(
dilation_rates
)
if
self
.
verbose
:
logger
.
info
(
f
'Expand as
{
self
.
dilation_rates
}
'
)
print_log
(
f
'Expand as
{
self
.
dilation_rates
}
'
,
'current'
)
nn
.
init
.
constant_
(
self
.
branch_weights
,
self
.
global_config
[
'init_alphas'
])
def
get_padding
(
self
,
dilation
):
def
get_padding
(
self
,
dilation
)
->
tuple
:
padding
=
(
get_single_padding
(
self
.
op_layer
.
kernel_size
[
0
],
self
.
op_layer
.
stride
[
0
],
dilation
[
0
]),
get_single_padding
(
self
.
op_layer
.
kernel_size
[
1
],
...
...
mmcv/cnn/rfsearch/search.py
View file @
c4588297
...
...
@@ -3,15 +3,14 @@ import os
from
typing
import
Dict
,
Optional
import
mmengine
import
torch
# noqa
import
torch.nn
as
nn
from
mmengine.hooks
import
Hook
from
mmengine.logging
import
MMLogger
from
mmengine.logging
import
print_log
from
mmengine.registry
import
HOOKS
from
mmcv.cnn.rfsearch.utils
import
get_single_padding
,
write_to_json
from
.operator
import
BaseConvRFSearchOp
logger
=
MMLogger
.
get_current_instance
()
from
.operator
import
BaseConvRFSearchOp
,
Conv2dRFSearchOp
# noqa
from
.utils
import
get_single_padding
,
write_to_json
@
HOOKS
.
register_module
()
...
...
@@ -82,7 +81,7 @@ class RFSearchHook(Hook):
search/fixed_single_branch/fixed_multi_branch
"""
if
self
.
verbose
:
logger
.
info
(
'RFSearch init begin.'
)
print_log
(
'RFSearch init begin.'
,
'current'
)
if
self
.
mode
==
'search'
:
if
self
.
config
[
'structure'
]:
self
.
set_model
(
model
,
search_op
=
'Conv2d'
)
...
...
@@ -95,19 +94,19 @@ class RFSearchHook(Hook):
else
:
raise
NotImplementedError
if
self
.
verbose
:
logger
.
info
(
'RFSearch init end.'
)
print_log
(
'RFSearch init end.'
,
'current'
)
def
after_train_epoch
(
self
,
runner
):
"""Performs a dilation searching step after one training epoch."""
if
self
.
by_epoch
and
self
.
mode
==
'search'
:
self
.
step
(
runner
.
model
,
runner
.
work_dir
)
def
after_train_iter
(
self
,
runner
):
def
after_train_iter
(
self
,
runner
,
batch_idx
,
data_batch
,
outputs
):
"""Performs a dilation searching step after one training iteration."""
if
not
self
.
by_epoch
and
self
.
mode
==
'search'
:
self
.
step
(
runner
.
model
,
runner
.
work_dir
)
def
step
(
self
,
model
:
nn
.
Module
,
work_dir
:
str
):
def
step
(
self
,
model
:
nn
.
Module
,
work_dir
:
str
)
->
None
:
"""Performs a dilation searching step.
Args:
...
...
@@ -132,7 +131,7 @@ class RFSearchHook(Hook):
),
)
def
estimate_and_expand
(
self
,
model
:
nn
.
Module
):
def
estimate_and_expand
(
self
,
model
:
nn
.
Module
)
->
None
:
"""estimate and search for RFConvOp.
Args:
...
...
@@ -146,7 +145,7 @@ class RFSearchHook(Hook):
def
wrap_model
(
self
,
model
:
nn
.
Module
,
search_op
:
str
=
'Conv2d'
,
prefix
:
str
=
''
):
prefix
:
str
=
''
)
->
None
:
"""wrap model to support searchable conv op.
Args:
...
...
@@ -176,8 +175,9 @@ class RFSearchHook(Hook):
module
,
self
.
config
[
'search'
],
self
.
verbose
)
moduleWrap
=
moduleWrap
.
to
(
module
.
weight
.
device
)
if
self
.
verbose
:
logger
.
info
(
'Wrap model %s to %s.'
%
(
str
(
module
),
str
(
moduleWrap
)))
print_log
(
'Wrap model %s to %s.'
%
(
str
(
module
),
str
(
moduleWrap
)),
'current'
)
setattr
(
model
,
name
,
moduleWrap
)
elif
not
isinstance
(
module
,
BaseConvRFSearchOp
):
self
.
wrap_model
(
module
,
search_op
,
fullname
)
...
...
@@ -186,7 +186,7 @@ class RFSearchHook(Hook):
model
:
nn
.
Module
,
search_op
:
str
=
'Conv2d'
,
init_rates
:
Optional
[
int
]
=
None
,
prefix
:
str
=
''
):
prefix
:
str
=
''
)
->
None
:
"""set model based on config.
Args:
...
...
@@ -231,8 +231,9 @@ class RFSearchHook(Hook):
self
.
config
[
'structure'
][
fullname
][
1
]))
setattr
(
model
,
name
,
module
)
if
self
.
verbose
:
logger
.
info
(
print_log
(
'Set module %s dilation as: [%d %d]'
%
(
fullname
,
module
.
dilation
[
0
],
module
.
dilation
[
1
]))
(
fullname
,
module
.
dilation
[
0
],
module
.
dilation
[
1
]),
'current'
)
elif
not
isinstance
(
module
,
BaseConvRFSearchOp
):
self
.
set_model
(
module
,
search_op
,
init_rates
,
fullname
)
mmcv/image/geometric.py
View file @
c4588297
...
...
@@ -440,7 +440,7 @@ def imcrop(
img (ndarray): Image to be cropped.
bboxes (ndarray): Shape (k, 4) or (4, ), location of cropped bboxes.
scale (float, optional): Scale ratio of bboxes, the default value
1.0 means no
padd
ing.
1.0 means no
scal
ing.
pad_fill (Number | list[Number]): Value to be filled for padding.
Default: None, which means no padding.
...
...
tests/test_cnn/test_rfsearch/test_search.py
View file @
c4588297
# Copyright (c) OpenMMLab. All rights reserved.
"""Tests the rfsearch with runners.
CommandLine:
pytest tests/test_runner/test_hooks.py
xdoctest tests/test_hooks.py zero
"""
import
torch
import
torch.nn
as
nn
from
torch.utils.data
import
DataLoader
from
mmcv.cnn.rfsearch
import
Conv2dRFSearchOp
,
RFSearchHook
from
tests.test_runner.test_hooks
import
_build_demo_runner
def
test_rfsearchhook
():
...
...
@@ -114,20 +105,6 @@ def test_rfsearchhook():
assert
model
.
conv2
.
dilation_rates
==
[(
1
,
1
),
(
3
,
3
)]
assert
model
.
conv3
.
dilation_rates
==
[(
1
,
1
),
(
1
,
2
)]
# 1. test step() with mode of search
loader
=
DataLoader
(
torch
.
ones
((
1
,
1
,
1
,
1
)))
runner
=
_build_demo_runner
()
runner
.
model
=
model
runner
.
register_hook
(
rfsearchhook_search
)
runner
.
run
([
loader
],
[(
'train'
,
1
)])
test_skip_layer
()
assert
not
isinstance
(
model
.
conv1
,
Conv2dRFSearchOp
)
assert
isinstance
(
model
.
conv2
,
Conv2dRFSearchOp
)
assert
isinstance
(
model
.
conv3
,
Conv2dRFSearchOp
)
assert
model
.
conv2
.
dilation_rates
==
[(
1
,
1
),
(
3
,
3
)]
assert
model
.
conv3
.
dilation_rates
==
[(
1
,
1
),
(
1
,
3
)]
# 2. test init_model() with mode of fixed_single_branch
model
=
Model
()
rfsearchhook_fixed_single_branch
.
init_model
(
model
)
...
...
@@ -139,19 +116,6 @@ def test_rfsearchhook():
assert
model
.
conv2
.
dilation
==
(
2
,
2
)
assert
model
.
conv3
.
dilation
==
(
1
,
1
)
# 2. test step() with mode of fixed_single_branch
runner
=
_build_demo_runner
()
runner
.
model
=
model
runner
.
register_hook
(
rfsearchhook_fixed_single_branch
)
runner
.
run
([
loader
],
[(
'train'
,
1
)])
assert
not
isinstance
(
model
.
conv1
,
Conv2dRFSearchOp
)
assert
not
isinstance
(
model
.
conv2
,
Conv2dRFSearchOp
)
assert
not
isinstance
(
model
.
conv3
,
Conv2dRFSearchOp
)
assert
model
.
conv1
.
dilation
==
(
1
,
1
)
assert
model
.
conv2
.
dilation
==
(
2
,
2
)
assert
model
.
conv3
.
dilation
==
(
1
,
1
)
# 3. test init_model() with mode of fixed_multi_branch
model
=
Model
()
rfsearchhook_fixed_multi_branch
.
init_model
(
model
)
...
...
@@ -162,16 +126,3 @@ def test_rfsearchhook():
assert
isinstance
(
model
.
conv3
,
Conv2dRFSearchOp
)
assert
model
.
conv2
.
dilation_rates
==
[(
1
,
1
),
(
3
,
3
)]
assert
model
.
conv3
.
dilation_rates
==
[(
1
,
1
),
(
1
,
2
)]
# 3. test step() with mode of fixed_single_branch
runner
=
_build_demo_runner
()
runner
.
model
=
model
runner
.
register_hook
(
rfsearchhook_fixed_multi_branch
)
runner
.
run
([
loader
],
[(
'train'
,
1
)])
test_skip_layer
()
assert
not
isinstance
(
model
.
conv1
,
Conv2dRFSearchOp
)
assert
isinstance
(
model
.
conv2
,
Conv2dRFSearchOp
)
assert
isinstance
(
model
.
conv3
,
Conv2dRFSearchOp
)
assert
model
.
conv2
.
dilation_rates
==
[(
1
,
1
),
(
3
,
3
)]
assert
model
.
conv3
.
dilation_rates
==
[(
1
,
1
),
(
1
,
2
)]
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment