Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
ModelZoo
SOLOv2-pytorch
Commits
60d5bc54
"...text-generation-inference.git" did not exist on "f7f61876cff78d934b96cb80a8b312d5f9600802"
Commit
60d5bc54
authored
Dec 12, 2018
by
Kai Chen
Browse files
add eval hooks for VOC dataset
parent
9d38a278
Changes
3
Hide whitespace changes
Inline
Side-by-side
Showing
3 changed files
with
49 additions
and
6 deletions
+49
-6
mmdet/apis/train.py
mmdet/apis/train.py
+8
-4
mmdet/core/evaluation/__init__.py
mmdet/core/evaluation/__init__.py
+2
-2
mmdet/core/evaluation/eval_hooks.py
mmdet/core/evaluation/eval_hooks.py
+39
-0
No files found.
mmdet/apis/train.py
View file @
60d5bc54
...
@@ -6,8 +6,8 @@ import torch
...
@@ -6,8 +6,8 @@ import torch
from
mmcv.runner
import
Runner
,
DistSamplerSeedHook
from
mmcv.runner
import
Runner
,
DistSamplerSeedHook
from
mmcv.parallel
import
MMDataParallel
,
MMDistributedDataParallel
from
mmcv.parallel
import
MMDataParallel
,
MMDistributedDataParallel
from
mmdet.core
import
(
DistOptimizerHook
,
Coco
DistEval
Recall
Hook
,
from
mmdet.core
import
(
DistOptimizerHook
,
DistEval
mAP
Hook
,
CocoDistEvalmAPHook
)
CocoDistEvalRecallHook
,
CocoDistEvalmAPHook
)
from
mmdet.datasets
import
build_dataloader
from
mmdet.datasets
import
build_dataloader
from
mmdet.models
import
RPN
from
mmdet.models
import
RPN
from
.env
import
get_root_logger
from
.env
import
get_root_logger
...
@@ -81,9 +81,13 @@ def _dist_train(model, dataset, cfg, validate=False):
...
@@ -81,9 +81,13 @@ def _dist_train(model, dataset, cfg, validate=False):
# register eval hooks
# register eval hooks
if
validate
:
if
validate
:
if
isinstance
(
model
.
module
,
RPN
):
if
isinstance
(
model
.
module
,
RPN
):
# TODO: implement recall hooks for other datasets
runner
.
register_hook
(
CocoDistEvalRecallHook
(
cfg
.
data
.
val
))
runner
.
register_hook
(
CocoDistEvalRecallHook
(
cfg
.
data
.
val
))
elif
cfg
.
data
.
val
.
type
==
'CocoDataset'
:
else
:
runner
.
register_hook
(
CocoDistEvalmAPHook
(
cfg
.
data
.
val
))
if
cfg
.
data
.
val
.
type
==
'CocoDataset'
:
runner
.
register_hook
(
CocoDistEvalmAPHook
(
cfg
.
data
.
val
))
else
:
runner
.
register_hook
(
DistEvalmAPHook
(
cfg
.
data
.
val
))
if
cfg
.
resume_from
:
if
cfg
.
resume_from
:
runner
.
resume
(
cfg
.
resume_from
)
runner
.
resume
(
cfg
.
resume_from
)
...
...
mmdet/core/evaluation/__init__.py
View file @
60d5bc54
...
@@ -2,7 +2,7 @@ from .class_names import (voc_classes, imagenet_det_classes,
...
@@ -2,7 +2,7 @@ from .class_names import (voc_classes, imagenet_det_classes,
imagenet_vid_classes
,
coco_classes
,
dataset_aliases
,
imagenet_vid_classes
,
coco_classes
,
dataset_aliases
,
get_classes
)
get_classes
)
from
.coco_utils
import
coco_eval
,
fast_eval_recall
,
results2json
from
.coco_utils
import
coco_eval
,
fast_eval_recall
,
results2json
from
.eval_hooks
import
(
DistEvalHook
,
CocoDistEvalRecallHook
,
from
.eval_hooks
import
(
DistEvalHook
,
DistEvalmAPHook
,
CocoDistEvalRecallHook
,
CocoDistEvalmAPHook
)
CocoDistEvalmAPHook
)
from
.mean_ap
import
average_precision
,
eval_map
,
print_map_summary
from
.mean_ap
import
average_precision
,
eval_map
,
print_map_summary
from
.recall
import
(
eval_recalls
,
print_recall_summary
,
plot_num_recall
,
from
.recall
import
(
eval_recalls
,
print_recall_summary
,
plot_num_recall
,
...
@@ -11,7 +11,7 @@ from .recall import (eval_recalls, print_recall_summary, plot_num_recall,
...
@@ -11,7 +11,7 @@ from .recall import (eval_recalls, print_recall_summary, plot_num_recall,
__all__
=
[
__all__
=
[
'voc_classes'
,
'imagenet_det_classes'
,
'imagenet_vid_classes'
,
'voc_classes'
,
'imagenet_det_classes'
,
'imagenet_vid_classes'
,
'coco_classes'
,
'dataset_aliases'
,
'get_classes'
,
'coco_eval'
,
'coco_classes'
,
'dataset_aliases'
,
'get_classes'
,
'coco_eval'
,
'fast_eval_recall'
,
'results2json'
,
'DistEvalHook'
,
'fast_eval_recall'
,
'results2json'
,
'DistEvalHook'
,
'DistEvalmAPHook'
,
'CocoDistEvalRecallHook'
,
'CocoDistEvalmAPHook'
,
'average_precision'
,
'CocoDistEvalRecallHook'
,
'CocoDistEvalmAPHook'
,
'average_precision'
,
'eval_map'
,
'print_map_summary'
,
'eval_recalls'
,
'print_recall_summary'
,
'eval_map'
,
'print_map_summary'
,
'eval_recalls'
,
'print_recall_summary'
,
'plot_num_recall'
,
'plot_iou_recall'
'plot_num_recall'
,
'plot_iou_recall'
...
...
mmdet/core/evaluation/eval_hooks.py
View file @
60d5bc54
...
@@ -12,6 +12,7 @@ from pycocotools.cocoeval import COCOeval
...
@@ -12,6 +12,7 @@ from pycocotools.cocoeval import COCOeval
from
torch.utils.data
import
Dataset
from
torch.utils.data
import
Dataset
from
.coco_utils
import
results2json
,
fast_eval_recall
from
.coco_utils
import
results2json
,
fast_eval_recall
from
.mean_ap
import
eval_map
from
mmdet
import
datasets
from
mmdet
import
datasets
...
@@ -102,6 +103,44 @@ class DistEvalHook(Hook):
...
@@ -102,6 +103,44 @@ class DistEvalHook(Hook):
raise
NotImplementedError
raise
NotImplementedError
class
DistEvalmAPHook
(
DistEvalHook
):
def
evaluate
(
self
,
runner
,
results
):
gt_bboxes
=
[]
gt_labels
=
[]
gt_ignore
=
[]
if
self
.
dataset
.
with_crowd
else
None
for
i
in
range
(
len
(
self
.
dataset
)):
ann
=
self
.
dataset
.
get_ann_info
(
i
)
bboxes
=
ann
[
'bboxes'
]
labels
=
ann
[
'labels'
]
if
gt_ignore
is
not
None
:
ignore
=
np
.
concatenate
([
np
.
zeros
(
bboxes
.
shape
[
0
],
dtype
=
np
.
bool
),
np
.
ones
(
ann
[
'bboxes_ignore'
].
shape
[
0
],
dtype
=
np
.
bool
)
])
gt_ignore
.
append
(
ignore
)
bboxes
=
np
.
vstack
([
bboxes
,
ann
[
'bboxes_ignore'
]])
labels
=
np
.
concatenate
([
labels
,
ann
[
'labels_ignore'
]])
gt_bboxes
.
append
(
bboxes
)
gt_labels
.
append
(
labels
)
# If the dataset is VOC2007, then use 11 points mAP evaluation.
if
hasattr
(
self
.
dataset
,
'year'
)
and
self
.
dataset
.
year
==
2007
:
ds_name
=
'voc07'
else
:
ds_name
=
self
.
dataset
.
CLASSES
mean_ap
,
eval_results
=
eval_map
(
results
,
gt_bboxes
,
gt_labels
,
gt_ignore
=
gt_ignore
,
scale_ranges
=
None
,
iou_thr
=
0.5
,
dataset
=
ds_name
,
print_summary
=
True
)
runner
.
log_buffer
.
output
[
'mAP'
]
=
mean_ap
runner
.
log_buffer
.
ready
=
True
class
CocoDistEvalRecallHook
(
DistEvalHook
):
class
CocoDistEvalRecallHook
(
DistEvalHook
):
def
__init__
(
self
,
def
__init__
(
self
,
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment