Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
ModelZoo
SOLOv2-pytorch
Commits
d95727b2
Unverified
Commit
d95727b2
authored
Jun 22, 2019
by
Kai Chen
Committed by
GitHub
Jun 22, 2019
Browse files
add a field to support the evaluation interval (#849)
parent
e4917130
Changes
3
Hide whitespace changes
Inline
Side-by-side
Showing
3 changed files
with
13 additions
and
6 deletions
+13
-6
configs/mask_rcnn_r50_fpn_1x.py
configs/mask_rcnn_r50_fpn_1x.py
+1
-0
mmdet/apis/train.py
mmdet/apis/train.py
+9
-5
mmdet/core/evaluation/eval_hooks.py
mmdet/core/evaluation/eval_hooks.py
+3
-1
No files found.
configs/mask_rcnn_r50_fpn_1x.py
View file @
d95727b2
...
@@ -171,6 +171,7 @@ log_config = dict(
...
@@ -171,6 +171,7 @@ log_config = dict(
# dict(type='TensorboardLoggerHook')
# dict(type='TensorboardLoggerHook')
])
])
# yapf:enable
# yapf:enable
evaluation
=
dict
(
interval
=
1
)
# runtime settings
# runtime settings
total_epochs
=
12
total_epochs
=
12
dist_params
=
dict
(
backend
=
'nccl'
)
dist_params
=
dict
(
backend
=
'nccl'
)
...
...
mmdet/apis/train.py
View file @
d95727b2
...
@@ -91,8 +91,8 @@ def build_optimizer(model, optimizer_cfg):
...
@@ -91,8 +91,8 @@ def build_optimizer(model, optimizer_cfg):
paramwise_options
=
optimizer_cfg
.
pop
(
'paramwise_options'
,
None
)
paramwise_options
=
optimizer_cfg
.
pop
(
'paramwise_options'
,
None
)
# if no paramwise option is specified, just use the global setting
# if no paramwise option is specified, just use the global setting
if
paramwise_options
is
None
:
if
paramwise_options
is
None
:
return
obj_from_dict
(
return
obj_from_dict
(
optimizer_cfg
,
torch
.
optim
,
optimizer_cfg
,
torch
.
optim
,
dict
(
params
=
model
.
parameters
()))
dict
(
params
=
model
.
parameters
()))
else
:
else
:
assert
isinstance
(
paramwise_options
,
dict
)
assert
isinstance
(
paramwise_options
,
dict
)
# get base lr and weight decay
# get base lr and weight decay
...
@@ -154,15 +154,19 @@ def _dist_train(model, dataset, cfg, validate=False):
...
@@ -154,15 +154,19 @@ def _dist_train(model, dataset, cfg, validate=False):
# register eval hooks
# register eval hooks
if
validate
:
if
validate
:
val_dataset_cfg
=
cfg
.
data
.
val
val_dataset_cfg
=
cfg
.
data
.
val
eval_cfg
=
cfg
.
get
(
'evaluation'
,
{})
if
isinstance
(
model
.
module
,
RPN
):
if
isinstance
(
model
.
module
,
RPN
):
# TODO: implement recall hooks for other datasets
# TODO: implement recall hooks for other datasets
runner
.
register_hook
(
CocoDistEvalRecallHook
(
val_dataset_cfg
))
runner
.
register_hook
(
CocoDistEvalRecallHook
(
val_dataset_cfg
,
**
eval_cfg
))
else
:
else
:
dataset_type
=
getattr
(
datasets
,
val_dataset_cfg
.
type
)
dataset_type
=
getattr
(
datasets
,
val_dataset_cfg
.
type
)
if
issubclass
(
dataset_type
,
datasets
.
CocoDataset
):
if
issubclass
(
dataset_type
,
datasets
.
CocoDataset
):
runner
.
register_hook
(
CocoDistEvalmAPHook
(
val_dataset_cfg
))
runner
.
register_hook
(
CocoDistEvalmAPHook
(
val_dataset_cfg
,
**
eval_cfg
))
else
:
else
:
runner
.
register_hook
(
DistEvalmAPHook
(
val_dataset_cfg
))
runner
.
register_hook
(
DistEvalmAPHook
(
val_dataset_cfg
,
**
eval_cfg
))
if
cfg
.
resume_from
:
if
cfg
.
resume_from
:
runner
.
resume
(
cfg
.
resume_from
)
runner
.
resume
(
cfg
.
resume_from
)
...
...
mmdet/core/evaluation/eval_hooks.py
View file @
d95727b2
...
@@ -116,9 +116,11 @@ class CocoDistEvalRecallHook(DistEvalHook):
...
@@ -116,9 +116,11 @@ class CocoDistEvalRecallHook(DistEvalHook):
def
__init__
(
self
,
def
__init__
(
self
,
dataset
,
dataset
,
interval
=
1
,
proposal_nums
=
(
100
,
300
,
1000
),
proposal_nums
=
(
100
,
300
,
1000
),
iou_thrs
=
np
.
arange
(
0.5
,
0.96
,
0.05
)):
iou_thrs
=
np
.
arange
(
0.5
,
0.96
,
0.05
)):
super
(
CocoDistEvalRecallHook
,
self
).
__init__
(
dataset
)
super
(
CocoDistEvalRecallHook
,
self
).
__init__
(
dataset
,
interval
=
interval
)
self
.
proposal_nums
=
np
.
array
(
proposal_nums
,
dtype
=
np
.
int32
)
self
.
proposal_nums
=
np
.
array
(
proposal_nums
,
dtype
=
np
.
int32
)
self
.
iou_thrs
=
np
.
array
(
iou_thrs
,
dtype
=
np
.
float32
)
self
.
iou_thrs
=
np
.
array
(
iou_thrs
,
dtype
=
np
.
float32
)
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment