OpenDAS / MMCV · Commit 0633f911

print a warning information when eval_res is an empty dict

Authored Nov 06, 2021 by zhouzaida; committed Nov 10, 2021 by Wenwei Zhang.
Parent: f2d11076

Showing 2 changed files with 30 additions and 8 deletions:

- mmcv/runner/hooks/evaluation.py (+16 / -5)
- tests/test_runner/test_eval_hook.py (+14 / -3)
mmcv/runner/hooks/evaluation.py
```diff
@@ -271,7 +271,9 @@ class EvalHook(Hook):
         results = self.test_fn(runner.model, self.dataloader)
         runner.log_buffer.output['eval_iter_num'] = len(self.dataloader)
         key_score = self.evaluate(runner, results)
-        if self.save_best:
+        # the key_score may be `None` so it needs to skip the action to save
+        # the best checkpoint
+        if self.save_best and key_score:
             self._save_ckpt(runner, key_score)

     def _should_evaluate(self, runner):
```
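The new guard skips `_save_ckpt` whenever `evaluate` returns `None` instead of a score, rather than failing later on a `None` comparison. A minimal sketch of the pattern in isolation (`maybe_save_best` is a hypothetical helper for illustration, not part of the MMCV API):

```python
from typing import Optional


def maybe_save_best(key_score: Optional[float], save_best: bool) -> bool:
    """Return True only when a 'best' checkpoint should be saved.

    `key_score` may be None when evaluation produced no usable metric,
    so checking `save_best` alone is not enough.
    """
    # Stand-in for `self._save_ckpt(runner, key_score)`; the guard,
    # not the saving itself, is what the diff changes.
    return bool(save_best and key_score)


assert maybe_save_best(0.92, save_best=True) is True
assert maybe_save_best(None, save_best=True) is False  # skipped, no crash
```

Note that a truthiness check also skips a legitimate score of exactly 0.0; `key_score is not None` would be the stricter condition if a zero score should still produce a checkpoint.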
```diff
@@ -359,13 +361,21 @@ class EvalHook(Hook):
         eval_res = self.dataloader.dataset.evaluate(
             results, logger=runner.logger, **self.eval_kwargs)
-        assert eval_res, '`eval_res` should not be a null dict.'
+
         for name, val in eval_res.items():
             runner.log_buffer.output[name] = val
         runner.log_buffer.ready = True

         if self.save_best is not None:
+            # If the performance of the model is poor, `eval_res` may be an
+            # empty dict, which would raise an exception when `self.save_best`
+            # is not None. More details at
+            # https://github.com/open-mmlab/mmdetection/issues/6265.
+            if not eval_res:
+                warnings.warn(
+                    'Since `eval_res` is an empty dict, the behavior to save '
+                    'the best checkpoint will be skipped in this evaluation.')
+                return None
+
             if self.key_indicator == 'auto':
                 # infer from eval_results
                 self._init_rule(self.rule, list(eval_res.keys())[0])
```
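The hard `assert` on an empty `eval_res` becomes a warning plus an early `return None`, which the checkpoint guards above then absorb. A condensed, runnable sketch of the new control flow, assuming `key_score_from` stands in for the tail of `EvalHook.evaluate` (the real method also consults `self.rule` via `self._init_rule`):

```python
import warnings


def key_score_from(eval_res: dict, save_best='auto'):
    """Hypothetical condensation of the diff's empty-dict handling."""
    if save_best is not None:
        if not eval_res:
            # Warn and bail out instead of raising, mirroring the new branch.
            warnings.warn(
                'Since `eval_res` is an empty dict, the behavior to save '
                'the best checkpoint will be skipped in this evaluation.')
            return None
        # Under 'auto', the first reported metric becomes the key indicator.
        return eval_res[list(eval_res.keys())[0]]
    return None


assert key_score_from({'mAP': 0.5}) == 0.5  # first metric wins under 'auto'
assert key_score_from({}) is None           # warns instead of asserting
```

Returning `None` here is exactly what makes the `if self.save_best and key_score:` guards in both hooks necessary.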
```diff
@@ -493,6 +503,7 @@ class DistEvalHook(EvalHook):
             print('\n')
             runner.log_buffer.output['eval_iter_num'] = len(self.dataloader)
             key_score = self.evaluate(runner, results)
-            if self.save_best:
+            # the key_score may be `None` so it needs to skip the action to
+            # save the best checkpoint
+            if self.save_best and key_score:
                 self._save_ckpt(runner, key_score)
```
tests/test_runner/test_eval_hook.py
```diff
@@ -130,8 +130,9 @@ def test_eval_hook():
     data_loader = DataLoader(test_dataset)
     EvalHook(data_loader, save_best='auto', rule='unsupport')

-    with pytest.raises(AssertionError):
-        # eval_res returned by `dataset.evaluate()` should not be a null dict
+    # if eval_res is an empty dict, print a warning
+    with pytest.warns(UserWarning) as record_warnings:
+
         class _EvalDataset(ExampleDataset):

             def evaluate(self, results, logger=None):
```
```diff
@@ -139,10 +140,20 @@ def test_eval_hook():
         test_dataset = _EvalDataset()
         data_loader = DataLoader(test_dataset)
-        eval_hook = EvalHook(data_loader)
+        eval_hook = EvalHook(data_loader, save_best='auto')
         runner = _build_epoch_runner()

         runner.register_hook(eval_hook)
         runner.run([data_loader], [('train', 1)], 1)

+    # Since many warnings are thrown during the run, we just need to check
+    # that the expected warning is among them
+    expected_message = ('Since `eval_res` is an empty dict, the behavior to '
+                        'save the best checkpoint will be skipped in this '
+                        'evaluation.')
+    for warning in record_warnings:
+        if str(warning.message) == expected_message:
+            break
+    else:
+        assert False
+
     test_dataset = ExampleDataset()
     loader = DataLoader(test_dataset)
```
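Because the run emits other warnings besides the one under test, `pytest.warns(UserWarning)` alone would pass on any `UserWarning`; the `for`/`else` scan pins down the exact message. A self-contained sketch of the same idiom (`_noisy_evaluation` is a hypothetical stand-in for the runner call):

```python
import warnings

import pytest


def _noisy_evaluation():
    # Emits unrelated warnings plus the one the test actually cares about.
    warnings.warn('some unrelated deprecation notice')
    warnings.warn('Since `eval_res` is an empty dict, the behavior to save '
                  'the best checkpoint will be skipped in this evaluation.')


def test_expected_warning_among_many():
    with pytest.warns(UserWarning) as record_warnings:
        _noisy_evaluation()

    expected_message = ('Since `eval_res` is an empty dict, the behavior to '
                        'save the best checkpoint will be skipped in this '
                        'evaluation.')
    # for/else: the else branch runs only if no break happened, i.e. the
    # expected message was never recorded.
    for warning in record_warnings:
        if str(warning.message) == expected_message:
            break
    else:
        pytest.fail('expected empty-dict warning was not raised')
```

`pytest.fail(...)` is interchangeable with the diff's bare `assert False`; it just carries a more descriptive failure message.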