Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
OpenDAS
MMCV
Commits
c4588297
"vscode:/vscode.git/clone" did not exist on "b2b554395ddd04393475c97cbccbd0cd3df8b362"
Commit
c4588297
authored
Jan 17, 2023
by
Mashiro
Committed by
Zaida Zhou
Mar 20, 2023
Browse files
Refine rfsearch and fix a typo
parent
1f9e5b57
Changes
6
Show whitespace changes
Inline
Side-by-side
Showing
6 changed files
with
31 additions
and
78 deletions
+31
-78
docs/en/api/cnn.rst
docs/en/api/cnn.rst
+1
-0
docs/zh_cn/api/cnn.rst
docs/zh_cn/api/cnn.rst
+1
-0
mmcv/cnn/rfsearch/operator.py
mmcv/cnn/rfsearch/operator.py
+11
-12
mmcv/cnn/rfsearch/search.py
mmcv/cnn/rfsearch/search.py
+17
-16
mmcv/image/geometric.py
mmcv/image/geometric.py
+1
-1
tests/test_cnn/test_rfsearch/test_search.py
tests/test_cnn/test_rfsearch/test_search.py
+0
-49
No files found.
docs/en/api/cnn.rst
View file @
c4588297
...
@@ -40,6 +40,7 @@ Module
...
@@ -40,6 +40,7 @@ Module
NonLocal3d
NonLocal3d
Scale
Scale
Swish
Swish
Conv2dRFSearchOp
Build Function
Build Function
----------------
----------------
...
...
docs/zh_cn/api/cnn.rst
View file @
c4588297
...
@@ -40,6 +40,7 @@ Module
...
@@ -40,6 +40,7 @@ Module
NonLocal3d
NonLocal3d
Scale
Scale
Swish
Swish
Conv2dRFSearchOp
Build Function
Build Function
----------------
----------------
...
...
mmcv/cnn/rfsearch/operator.py
View file @
c4588297
...
@@ -4,14 +4,12 @@ import copy
...
@@ -4,14 +4,12 @@ import copy
import
numpy
as
np
import
numpy
as
np
import
torch
import
torch
import
torch.nn
as
nn
import
torch.nn
as
nn
from
mmengine.logging
import
MMLogger
from
mmengine.logging
import
print_log
from
mmengine.model
import
BaseModule
from
mmengine.model
import
BaseModule
from
torch
import
Tensor
from
torch
import
Tensor
from
.utils
import
expand_rates
,
get_single_padding
from
.utils
import
expand_rates
,
get_single_padding
logger
=
MMLogger
.
get_current_instance
()
class
BaseConvRFSearchOp
(
BaseModule
):
class
BaseConvRFSearchOp
(
BaseModule
):
"""Based class of ConvRFSearchOp.
"""Based class of ConvRFSearchOp.
...
@@ -84,7 +82,7 @@ class Conv2dRFSearchOp(BaseConvRFSearchOp):
...
@@ -84,7 +82,7 @@ class Conv2dRFSearchOp(BaseConvRFSearchOp):
self
.
branch_weights
=
nn
.
Parameter
(
torch
.
Tensor
(
self
.
num_branches
))
self
.
branch_weights
=
nn
.
Parameter
(
torch
.
Tensor
(
self
.
num_branches
))
if
self
.
verbose
:
if
self
.
verbose
:
logger
.
info
(
f
'Expand as
{
self
.
dilation_rates
}
'
)
print_log
(
f
'Expand as
{
self
.
dilation_rates
}
'
,
'current'
)
nn
.
init
.
constant_
(
self
.
branch_weights
,
global_config
[
'init_alphas'
])
nn
.
init
.
constant_
(
self
.
branch_weights
,
global_config
[
'init_alphas'
])
def
forward
(
self
,
input
:
Tensor
)
->
Tensor
:
def
forward
(
self
,
input
:
Tensor
)
->
Tensor
:
...
@@ -118,13 +116,14 @@ class Conv2dRFSearchOp(BaseConvRFSearchOp):
...
@@ -118,13 +116,14 @@ class Conv2dRFSearchOp(BaseConvRFSearchOp):
output
+=
outputs
[
i
]
output
+=
outputs
[
i
]
return
output
return
output
def
estimate_rates
(
self
):
def
estimate_rates
(
self
)
->
None
:
"""Estimate new dilation rate based on trained branch_weights."""
"""Estimate new dilation rate based on trained branch_weights."""
norm_w
=
self
.
normlize
(
self
.
branch_weights
[:
len
(
self
.
dilation_rates
)])
norm_w
=
self
.
normlize
(
self
.
branch_weights
[:
len
(
self
.
dilation_rates
)])
if
self
.
verbose
:
if
self
.
verbose
:
logger
.
info
(
'Estimate dilation {} with weight {}.'
.
format
(
print_log
(
'Estimate dilation {} with weight {}.'
.
format
(
self
.
dilation_rates
,
self
.
dilation_rates
,
norm_w
.
detach
().
cpu
().
numpy
().
tolist
()))
norm_w
.
detach
().
cpu
().
numpy
().
tolist
())
,
'current'
)
sum0
,
sum1
,
w_sum
=
0
,
0
,
0
sum0
,
sum1
,
w_sum
=
0
,
0
,
0
for
i
in
range
(
len
(
self
.
dilation_rates
)):
for
i
in
range
(
len
(
self
.
dilation_rates
)):
...
@@ -143,9 +142,9 @@ class Conv2dRFSearchOp(BaseConvRFSearchOp):
...
@@ -143,9 +142,9 @@ class Conv2dRFSearchOp(BaseConvRFSearchOp):
self
.
op_layer
.
padding
=
self
.
get_padding
(
self
.
op_layer
.
dilation
)
self
.
op_layer
.
padding
=
self
.
get_padding
(
self
.
op_layer
.
dilation
)
self
.
dilation_rates
=
[
tuple
(
estimated
)]
self
.
dilation_rates
=
[
tuple
(
estimated
)]
if
self
.
verbose
:
if
self
.
verbose
:
logger
.
info
(
f
'Estimate as
{
tuple
(
estimated
)
}
'
)
print_log
(
f
'Estimate as
{
tuple
(
estimated
)
}
'
,
'current'
)
def
expand_rates
(
self
):
def
expand_rates
(
self
)
->
None
:
"""Expand dilation rate."""
"""Expand dilation rate."""
dilation
=
self
.
op_layer
.
dilation
dilation
=
self
.
op_layer
.
dilation
dilation_rates
=
expand_rates
(
dilation
,
self
.
global_config
)
dilation_rates
=
expand_rates
(
dilation
,
self
.
global_config
)
...
@@ -158,11 +157,11 @@ class Conv2dRFSearchOp(BaseConvRFSearchOp):
...
@@ -158,11 +157,11 @@ class Conv2dRFSearchOp(BaseConvRFSearchOp):
self
.
dilation_rates
=
copy
.
deepcopy
(
dilation_rates
)
self
.
dilation_rates
=
copy
.
deepcopy
(
dilation_rates
)
if
self
.
verbose
:
if
self
.
verbose
:
logger
.
info
(
f
'Expand as
{
self
.
dilation_rates
}
'
)
print_log
(
f
'Expand as
{
self
.
dilation_rates
}
'
,
'current'
)
nn
.
init
.
constant_
(
self
.
branch_weights
,
nn
.
init
.
constant_
(
self
.
branch_weights
,
self
.
global_config
[
'init_alphas'
])
self
.
global_config
[
'init_alphas'
])
def
get_padding
(
self
,
dilation
):
def
get_padding
(
self
,
dilation
)
->
tuple
:
padding
=
(
get_single_padding
(
self
.
op_layer
.
kernel_size
[
0
],
padding
=
(
get_single_padding
(
self
.
op_layer
.
kernel_size
[
0
],
self
.
op_layer
.
stride
[
0
],
dilation
[
0
]),
self
.
op_layer
.
stride
[
0
],
dilation
[
0
]),
get_single_padding
(
self
.
op_layer
.
kernel_size
[
1
],
get_single_padding
(
self
.
op_layer
.
kernel_size
[
1
],
...
...
mmcv/cnn/rfsearch/search.py
View file @
c4588297
...
@@ -3,15 +3,14 @@ import os
...
@@ -3,15 +3,14 @@ import os
from
typing
import
Dict
,
Optional
from
typing
import
Dict
,
Optional
import
mmengine
import
mmengine
import
torch
# noqa
import
torch.nn
as
nn
import
torch.nn
as
nn
from
mmengine.hooks
import
Hook
from
mmengine.hooks
import
Hook
from
mmengine.logging
import
MMLogger
from
mmengine.logging
import
print_log
from
mmengine.registry
import
HOOKS
from
mmengine.registry
import
HOOKS
from
mmcv.cnn.rfsearch.utils
import
get_single_padding
,
write_to_json
from
.operator
import
BaseConvRFSearchOp
,
Conv2dRFSearchOp
# noqa
from
.operator
import
BaseConvRFSearchOp
from
.utils
import
get_single_padding
,
write_to_json
logger
=
MMLogger
.
get_current_instance
()
@
HOOKS
.
register_module
()
@
HOOKS
.
register_module
()
...
@@ -82,7 +81,7 @@ class RFSearchHook(Hook):
...
@@ -82,7 +81,7 @@ class RFSearchHook(Hook):
search/fixed_single_branch/fixed_multi_branch
search/fixed_single_branch/fixed_multi_branch
"""
"""
if
self
.
verbose
:
if
self
.
verbose
:
logger
.
info
(
'RFSearch init begin.'
)
print_log
(
'RFSearch init begin.'
,
'current'
)
if
self
.
mode
==
'search'
:
if
self
.
mode
==
'search'
:
if
self
.
config
[
'structure'
]:
if
self
.
config
[
'structure'
]:
self
.
set_model
(
model
,
search_op
=
'Conv2d'
)
self
.
set_model
(
model
,
search_op
=
'Conv2d'
)
...
@@ -95,19 +94,19 @@ class RFSearchHook(Hook):
...
@@ -95,19 +94,19 @@ class RFSearchHook(Hook):
else
:
else
:
raise
NotImplementedError
raise
NotImplementedError
if
self
.
verbose
:
if
self
.
verbose
:
logger
.
info
(
'RFSearch init end.'
)
print_log
(
'RFSearch init end.'
,
'current'
)
def
after_train_epoch
(
self
,
runner
):
def
after_train_epoch
(
self
,
runner
):
"""Performs a dilation searching step after one training epoch."""
"""Performs a dilation searching step after one training epoch."""
if
self
.
by_epoch
and
self
.
mode
==
'search'
:
if
self
.
by_epoch
and
self
.
mode
==
'search'
:
self
.
step
(
runner
.
model
,
runner
.
work_dir
)
self
.
step
(
runner
.
model
,
runner
.
work_dir
)
def
after_train_iter
(
self
,
runner
):
def
after_train_iter
(
self
,
runner
,
batch_idx
,
data_batch
,
outputs
):
"""Performs a dilation searching step after one training iteration."""
"""Performs a dilation searching step after one training iteration."""
if
not
self
.
by_epoch
and
self
.
mode
==
'search'
:
if
not
self
.
by_epoch
and
self
.
mode
==
'search'
:
self
.
step
(
runner
.
model
,
runner
.
work_dir
)
self
.
step
(
runner
.
model
,
runner
.
work_dir
)
def
step
(
self
,
model
:
nn
.
Module
,
work_dir
:
str
):
def
step
(
self
,
model
:
nn
.
Module
,
work_dir
:
str
)
->
None
:
"""Performs a dilation searching step.
"""Performs a dilation searching step.
Args:
Args:
...
@@ -132,7 +131,7 @@ class RFSearchHook(Hook):
...
@@ -132,7 +131,7 @@ class RFSearchHook(Hook):
),
),
)
)
def
estimate_and_expand
(
self
,
model
:
nn
.
Module
):
def
estimate_and_expand
(
self
,
model
:
nn
.
Module
)
->
None
:
"""estimate and search for RFConvOp.
"""estimate and search for RFConvOp.
Args:
Args:
...
@@ -146,7 +145,7 @@ class RFSearchHook(Hook):
...
@@ -146,7 +145,7 @@ class RFSearchHook(Hook):
def
wrap_model
(
self
,
def
wrap_model
(
self
,
model
:
nn
.
Module
,
model
:
nn
.
Module
,
search_op
:
str
=
'Conv2d'
,
search_op
:
str
=
'Conv2d'
,
prefix
:
str
=
''
):
prefix
:
str
=
''
)
->
None
:
"""wrap model to support searchable conv op.
"""wrap model to support searchable conv op.
Args:
Args:
...
@@ -176,8 +175,9 @@ class RFSearchHook(Hook):
...
@@ -176,8 +175,9 @@ class RFSearchHook(Hook):
module
,
self
.
config
[
'search'
],
self
.
verbose
)
module
,
self
.
config
[
'search'
],
self
.
verbose
)
moduleWrap
=
moduleWrap
.
to
(
module
.
weight
.
device
)
moduleWrap
=
moduleWrap
.
to
(
module
.
weight
.
device
)
if
self
.
verbose
:
if
self
.
verbose
:
logger
.
info
(
'Wrap model %s to %s.'
%
print_log
(
(
str
(
module
),
str
(
moduleWrap
)))
'Wrap model %s to %s.'
%
(
str
(
module
),
str
(
moduleWrap
)),
'current'
)
setattr
(
model
,
name
,
moduleWrap
)
setattr
(
model
,
name
,
moduleWrap
)
elif
not
isinstance
(
module
,
BaseConvRFSearchOp
):
elif
not
isinstance
(
module
,
BaseConvRFSearchOp
):
self
.
wrap_model
(
module
,
search_op
,
fullname
)
self
.
wrap_model
(
module
,
search_op
,
fullname
)
...
@@ -186,7 +186,7 @@ class RFSearchHook(Hook):
...
@@ -186,7 +186,7 @@ class RFSearchHook(Hook):
model
:
nn
.
Module
,
model
:
nn
.
Module
,
search_op
:
str
=
'Conv2d'
,
search_op
:
str
=
'Conv2d'
,
init_rates
:
Optional
[
int
]
=
None
,
init_rates
:
Optional
[
int
]
=
None
,
prefix
:
str
=
''
):
prefix
:
str
=
''
)
->
None
:
"""set model based on config.
"""set model based on config.
Args:
Args:
...
@@ -231,8 +231,9 @@ class RFSearchHook(Hook):
...
@@ -231,8 +231,9 @@ class RFSearchHook(Hook):
self
.
config
[
'structure'
][
fullname
][
1
]))
self
.
config
[
'structure'
][
fullname
][
1
]))
setattr
(
model
,
name
,
module
)
setattr
(
model
,
name
,
module
)
if
self
.
verbose
:
if
self
.
verbose
:
logger
.
info
(
print_log
(
'Set module %s dilation as: [%d %d]'
%
'Set module %s dilation as: [%d %d]'
%
(
fullname
,
module
.
dilation
[
0
],
module
.
dilation
[
1
]))
(
fullname
,
module
.
dilation
[
0
],
module
.
dilation
[
1
]),
'current'
)
elif
not
isinstance
(
module
,
BaseConvRFSearchOp
):
elif
not
isinstance
(
module
,
BaseConvRFSearchOp
):
self
.
set_model
(
module
,
search_op
,
init_rates
,
fullname
)
self
.
set_model
(
module
,
search_op
,
init_rates
,
fullname
)
mmcv/image/geometric.py
View file @
c4588297
...
@@ -440,7 +440,7 @@ def imcrop(
...
@@ -440,7 +440,7 @@ def imcrop(
img (ndarray): Image to be cropped.
img (ndarray): Image to be cropped.
bboxes (ndarray): Shape (k, 4) or (4, ), location of cropped bboxes.
bboxes (ndarray): Shape (k, 4) or (4, ), location of cropped bboxes.
scale (float, optional): Scale ratio of bboxes, the default value
scale (float, optional): Scale ratio of bboxes, the default value
1.0 means no
padd
ing.
1.0 means no
scal
ing.
pad_fill (Number | list[Number]): Value to be filled for padding.
pad_fill (Number | list[Number]): Value to be filled for padding.
Default: None, which means no padding.
Default: None, which means no padding.
...
...
tests/test_cnn/test_rfsearch/test_search.py
View file @
c4588297
# Copyright (c) OpenMMLab. All rights reserved.
# Copyright (c) OpenMMLab. All rights reserved.
"""Tests the rfsearch with runners.
CommandLine:
pytest tests/test_runner/test_hooks.py
xdoctest tests/test_hooks.py zero
"""
import
torch
import
torch.nn
as
nn
import
torch.nn
as
nn
from
torch.utils.data
import
DataLoader
from
mmcv.cnn.rfsearch
import
Conv2dRFSearchOp
,
RFSearchHook
from
mmcv.cnn.rfsearch
import
Conv2dRFSearchOp
,
RFSearchHook
from
tests.test_runner.test_hooks
import
_build_demo_runner
def
test_rfsearchhook
():
def
test_rfsearchhook
():
...
@@ -114,20 +105,6 @@ def test_rfsearchhook():
...
@@ -114,20 +105,6 @@ def test_rfsearchhook():
assert
model
.
conv2
.
dilation_rates
==
[(
1
,
1
),
(
3
,
3
)]
assert
model
.
conv2
.
dilation_rates
==
[(
1
,
1
),
(
3
,
3
)]
assert
model
.
conv3
.
dilation_rates
==
[(
1
,
1
),
(
1
,
2
)]
assert
model
.
conv3
.
dilation_rates
==
[(
1
,
1
),
(
1
,
2
)]
# 1. test step() with mode of search
loader
=
DataLoader
(
torch
.
ones
((
1
,
1
,
1
,
1
)))
runner
=
_build_demo_runner
()
runner
.
model
=
model
runner
.
register_hook
(
rfsearchhook_search
)
runner
.
run
([
loader
],
[(
'train'
,
1
)])
test_skip_layer
()
assert
not
isinstance
(
model
.
conv1
,
Conv2dRFSearchOp
)
assert
isinstance
(
model
.
conv2
,
Conv2dRFSearchOp
)
assert
isinstance
(
model
.
conv3
,
Conv2dRFSearchOp
)
assert
model
.
conv2
.
dilation_rates
==
[(
1
,
1
),
(
3
,
3
)]
assert
model
.
conv3
.
dilation_rates
==
[(
1
,
1
),
(
1
,
3
)]
# 2. test init_model() with mode of fixed_single_branch
# 2. test init_model() with mode of fixed_single_branch
model
=
Model
()
model
=
Model
()
rfsearchhook_fixed_single_branch
.
init_model
(
model
)
rfsearchhook_fixed_single_branch
.
init_model
(
model
)
...
@@ -139,19 +116,6 @@ def test_rfsearchhook():
...
@@ -139,19 +116,6 @@ def test_rfsearchhook():
assert
model
.
conv2
.
dilation
==
(
2
,
2
)
assert
model
.
conv2
.
dilation
==
(
2
,
2
)
assert
model
.
conv3
.
dilation
==
(
1
,
1
)
assert
model
.
conv3
.
dilation
==
(
1
,
1
)
# 2. test step() with mode of fixed_single_branch
runner
=
_build_demo_runner
()
runner
.
model
=
model
runner
.
register_hook
(
rfsearchhook_fixed_single_branch
)
runner
.
run
([
loader
],
[(
'train'
,
1
)])
assert
not
isinstance
(
model
.
conv1
,
Conv2dRFSearchOp
)
assert
not
isinstance
(
model
.
conv2
,
Conv2dRFSearchOp
)
assert
not
isinstance
(
model
.
conv3
,
Conv2dRFSearchOp
)
assert
model
.
conv1
.
dilation
==
(
1
,
1
)
assert
model
.
conv2
.
dilation
==
(
2
,
2
)
assert
model
.
conv3
.
dilation
==
(
1
,
1
)
# 3. test init_model() with mode of fixed_multi_branch
# 3. test init_model() with mode of fixed_multi_branch
model
=
Model
()
model
=
Model
()
rfsearchhook_fixed_multi_branch
.
init_model
(
model
)
rfsearchhook_fixed_multi_branch
.
init_model
(
model
)
...
@@ -162,16 +126,3 @@ def test_rfsearchhook():
...
@@ -162,16 +126,3 @@ def test_rfsearchhook():
assert
isinstance
(
model
.
conv3
,
Conv2dRFSearchOp
)
assert
isinstance
(
model
.
conv3
,
Conv2dRFSearchOp
)
assert
model
.
conv2
.
dilation_rates
==
[(
1
,
1
),
(
3
,
3
)]
assert
model
.
conv2
.
dilation_rates
==
[(
1
,
1
),
(
3
,
3
)]
assert
model
.
conv3
.
dilation_rates
==
[(
1
,
1
),
(
1
,
2
)]
assert
model
.
conv3
.
dilation_rates
==
[(
1
,
1
),
(
1
,
2
)]
# 3. test step() with mode of fixed_single_branch
runner
=
_build_demo_runner
()
runner
.
model
=
model
runner
.
register_hook
(
rfsearchhook_fixed_multi_branch
)
runner
.
run
([
loader
],
[(
'train'
,
1
)])
test_skip_layer
()
assert
not
isinstance
(
model
.
conv1
,
Conv2dRFSearchOp
)
assert
isinstance
(
model
.
conv2
,
Conv2dRFSearchOp
)
assert
isinstance
(
model
.
conv3
,
Conv2dRFSearchOp
)
assert
model
.
conv2
.
dilation_rates
==
[(
1
,
1
),
(
3
,
3
)]
assert
model
.
conv3
.
dilation_rates
==
[(
1
,
1
),
(
1
,
2
)]
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment