Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
ModelZoo
SOLOv2-pytorch
Commits
40325555
Commit
40325555
authored
Dec 12, 2018
by
yhcao6
Browse files
merge new master
parents
cfdd8050
c95c6373
Changes
14
Hide whitespace changes
Inline
Side-by-side
Showing
14 changed files
with
201 additions
and
37 deletions
+201
-37
README.md
README.md
+3
-3
mmdet/apis/train.py
mmdet/apis/train.py
+8
-4
mmdet/core/evaluation/__init__.py
mmdet/core/evaluation/__init__.py
+2
-2
mmdet/core/evaluation/class_names.py
mmdet/core/evaluation/class_names.py
+8
-8
mmdet/core/evaluation/eval_hooks.py
mmdet/core/evaluation/eval_hooks.py
+39
-0
mmdet/datasets/__init__.py
mmdet/datasets/__init__.py
+6
-3
mmdet/datasets/coco.py
mmdet/datasets/coco.py
+15
-0
mmdet/datasets/concat_dataset.py
mmdet/datasets/concat_dataset.py
+8
-6
mmdet/datasets/custom.py
mmdet/datasets/custom.py
+5
-2
mmdet/datasets/repeat_dataset.py
mmdet/datasets/repeat_dataset.py
+5
-3
mmdet/datasets/voc.py
mmdet/datasets/voc.py
+18
-0
mmdet/datasets/xml_style.py
mmdet/datasets/xml_style.py
+76
-0
mmdet/models/detectors/base.py
mmdet/models/detectors/base.py
+4
-3
tools/test.py
tools/test.py
+4
-3
No files found.
README.md
View file @
40325555
...
...
@@ -194,7 +194,7 @@ Here is an example.
'bboxes': <np.ndarray> (n, 4),
'labels': <np.ndarray> (n, ),
'bboxes_ignore': <np.ndarray> (k, 4),
'labels_ignore': <np.ndarray> (k,
4
) (optional field)
'labels_ignore': <np.ndarray> (k, ) (optional field)
}
},
...
...
...
@@ -206,12 +206,12 @@ There are two ways to work with custom datasets.
-
online conversion
You can write a new Dataset class inherited from
`CustomDataset`
, and overwrite two methods
`load_annotations(self, ann_file)`
and
`get_ann_info(self, idx)`
, like
[
CocoDataset
](
mmdet/datasets/coco.py
)
.
`load_annotations(self, ann_file)`
and
`get_ann_info(self, idx)`
, like
[
CocoDataset
](
mmdet/datasets/coco.py
)
and
[
VOCDataset
](
mmdet/datasets/voc.py
)
.
-
offline conversion
You can convert the annotation format to the expected format above and save it to
a pickle file, like
[
pascal_voc.py
](
tools/convert_datasets/pascal_voc.py
)
.
a pickle
or json
file, like
[
pascal_voc.py
](
tools/convert_datasets/pascal_voc.py
)
.
Then you can simply use
`CustomDataset`
.
## Technical details
...
...
mmdet/apis/train.py
View file @
40325555
...
...
@@ -6,8 +6,8 @@ import torch
from
mmcv.runner
import
Runner
,
DistSamplerSeedHook
from
mmcv.parallel
import
MMDataParallel
,
MMDistributedDataParallel
from
mmdet.core
import
(
DistOptimizerHook
,
Coco
DistEval
Recall
Hook
,
CocoDistEvalmAPHook
)
from
mmdet.core
import
(
DistOptimizerHook
,
DistEval
mAP
Hook
,
CocoDistEvalRecallHook
,
CocoDistEvalmAPHook
)
from
mmdet.datasets
import
build_dataloader
from
mmdet.models
import
RPN
from
.env
import
get_root_logger
...
...
@@ -81,9 +81,13 @@ def _dist_train(model, dataset, cfg, validate=False):
# register eval hooks
if
validate
:
if
isinstance
(
model
.
module
,
RPN
):
# TODO: implement recall hooks for other datasets
runner
.
register_hook
(
CocoDistEvalRecallHook
(
cfg
.
data
.
val
))
elif
cfg
.
data
.
val
.
type
==
'CocoDataset'
:
runner
.
register_hook
(
CocoDistEvalmAPHook
(
cfg
.
data
.
val
))
else
:
if
cfg
.
data
.
val
.
type
==
'CocoDataset'
:
runner
.
register_hook
(
CocoDistEvalmAPHook
(
cfg
.
data
.
val
))
else
:
runner
.
register_hook
(
DistEvalmAPHook
(
cfg
.
data
.
val
))
if
cfg
.
resume_from
:
runner
.
resume
(
cfg
.
resume_from
)
...
...
mmdet/core/evaluation/__init__.py
View file @
40325555
...
...
@@ -2,7 +2,7 @@ from .class_names import (voc_classes, imagenet_det_classes,
imagenet_vid_classes
,
coco_classes
,
dataset_aliases
,
get_classes
)
from
.coco_utils
import
coco_eval
,
fast_eval_recall
,
results2json
from
.eval_hooks
import
(
DistEvalHook
,
CocoDistEvalRecallHook
,
from
.eval_hooks
import
(
DistEvalHook
,
DistEvalmAPHook
,
CocoDistEvalRecallHook
,
CocoDistEvalmAPHook
)
from
.mean_ap
import
average_precision
,
eval_map
,
print_map_summary
from
.recall
import
(
eval_recalls
,
print_recall_summary
,
plot_num_recall
,
...
...
@@ -11,7 +11,7 @@ from .recall import (eval_recalls, print_recall_summary, plot_num_recall,
__all__
=
[
'voc_classes'
,
'imagenet_det_classes'
,
'imagenet_vid_classes'
,
'coco_classes'
,
'dataset_aliases'
,
'get_classes'
,
'coco_eval'
,
'fast_eval_recall'
,
'results2json'
,
'DistEvalHook'
,
'fast_eval_recall'
,
'results2json'
,
'DistEvalHook'
,
'DistEvalmAPHook'
,
'CocoDistEvalRecallHook'
,
'CocoDistEvalmAPHook'
,
'average_precision'
,
'eval_map'
,
'print_map_summary'
,
'eval_recalls'
,
'print_recall_summary'
,
'plot_num_recall'
,
'plot_iou_recall'
...
...
mmdet/core/evaluation/class_names.py
View file @
40325555
...
...
@@ -63,18 +63,18 @@ def imagenet_vid_classes():
def
coco_classes
():
return
[
'person'
,
'bicycle'
,
'car'
,
'motorcycle'
,
'airplane'
,
'bus'
,
'train'
,
'truck'
,
'boat'
,
'traffic
light'
,
'fire
hydrant'
,
'stop
sign'
,
'parking
meter'
,
'bench'
,
'bird'
,
'cat'
,
'dog'
,
'horse'
,
'sheep'
,
'truck'
,
'boat'
,
'traffic
_
light'
,
'fire
_
hydrant'
,
'stop
_
sign'
,
'parking
_
meter'
,
'bench'
,
'bird'
,
'cat'
,
'dog'
,
'horse'
,
'sheep'
,
'cow'
,
'elephant'
,
'bear'
,
'zebra'
,
'giraffe'
,
'backpack'
,
'umbrella'
,
'handbag'
,
'tie'
,
'suitcase'
,
'frisbee'
,
'skis'
,
'snowboard'
,
'sports
ball'
,
'kite'
,
'baseball
bat'
,
'baseball
glove'
,
'skateboard'
,
'surfboard'
,
'tennis
racket'
,
'bottle'
,
'wine
glass'
,
'cup'
,
'fork'
,
'sports
_
ball'
,
'kite'
,
'baseball
_
bat'
,
'baseball
_
glove'
,
'skateboard'
,
'surfboard'
,
'tennis
_
racket'
,
'bottle'
,
'wine
_
glass'
,
'cup'
,
'fork'
,
'knife'
,
'spoon'
,
'bowl'
,
'banana'
,
'apple'
,
'sandwich'
,
'orange'
,
'broccoli'
,
'carrot'
,
'hot
dog'
,
'pizza'
,
'donut'
,
'cake'
,
'chair'
,
'couch'
,
'potted
plant'
,
'bed'
,
'dining
table'
,
'toilet'
,
'tv'
,
'laptop'
,
'mouse'
,
'remote'
,
'keyboard'
,
'cell
phone'
,
'microwave'
,
'broccoli'
,
'carrot'
,
'hot
_
dog'
,
'pizza'
,
'donut'
,
'cake'
,
'chair'
,
'couch'
,
'potted
_
plant'
,
'bed'
,
'dining
_
table'
,
'toilet'
,
'tv'
,
'laptop'
,
'mouse'
,
'remote'
,
'keyboard'
,
'cell
_
phone'
,
'microwave'
,
'oven'
,
'toaster'
,
'sink'
,
'refrigerator'
,
'book'
,
'clock'
,
'vase'
,
'scissors'
,
'teddy
bear'
,
'hair
drier'
,
'toothbrush'
'scissors'
,
'teddy
_
bear'
,
'hair
_
drier'
,
'toothbrush'
]
...
...
mmdet/core/evaluation/eval_hooks.py
View file @
40325555
...
...
@@ -12,6 +12,7 @@ from pycocotools.cocoeval import COCOeval
from
torch.utils.data
import
Dataset
from
.coco_utils
import
results2json
,
fast_eval_recall
from
.mean_ap
import
eval_map
from
mmdet
import
datasets
...
...
@@ -102,6 +103,44 @@ class DistEvalHook(Hook):
raise
NotImplementedError
class
DistEvalmAPHook
(
DistEvalHook
):
def
evaluate
(
self
,
runner
,
results
):
gt_bboxes
=
[]
gt_labels
=
[]
gt_ignore
=
[]
if
self
.
dataset
.
with_crowd
else
None
for
i
in
range
(
len
(
self
.
dataset
)):
ann
=
self
.
dataset
.
get_ann_info
(
i
)
bboxes
=
ann
[
'bboxes'
]
labels
=
ann
[
'labels'
]
if
gt_ignore
is
not
None
:
ignore
=
np
.
concatenate
([
np
.
zeros
(
bboxes
.
shape
[
0
],
dtype
=
np
.
bool
),
np
.
ones
(
ann
[
'bboxes_ignore'
].
shape
[
0
],
dtype
=
np
.
bool
)
])
gt_ignore
.
append
(
ignore
)
bboxes
=
np
.
vstack
([
bboxes
,
ann
[
'bboxes_ignore'
]])
labels
=
np
.
concatenate
([
labels
,
ann
[
'labels_ignore'
]])
gt_bboxes
.
append
(
bboxes
)
gt_labels
.
append
(
labels
)
# If the dataset is VOC2007, then use 11 points mAP evaluation.
if
hasattr
(
self
.
dataset
,
'year'
)
and
self
.
dataset
.
year
==
2007
:
ds_name
=
'voc07'
else
:
ds_name
=
self
.
dataset
.
CLASSES
mean_ap
,
eval_results
=
eval_map
(
results
,
gt_bboxes
,
gt_labels
,
gt_ignore
=
gt_ignore
,
scale_ranges
=
None
,
iou_thr
=
0.5
,
dataset
=
ds_name
,
print_summary
=
True
)
runner
.
log_buffer
.
output
[
'mAP'
]
=
mean_ap
runner
.
log_buffer
.
ready
=
True
class
CocoDistEvalRecallHook
(
DistEvalHook
):
def
__init__
(
self
,
...
...
mmdet/datasets/__init__.py
View file @
40325555
from
.custom
import
CustomDataset
from
.xml_style
import
XMLDataset
from
.coco
import
CocoDataset
from
.voc
import
VOCDataset
from
.loader
import
GroupSampler
,
DistributedGroupSampler
,
build_dataloader
from
.utils
import
to_tensor
,
random_scale
,
show_ann
,
get_dataset
from
.concat_dataset
import
ConcatDataset
...
...
@@ -7,7 +9,8 @@ from .repeat_dataset import RepeatDataset
from
.extra_aug
import
ExtraAugmentation
__all__
=
[
'CustomDataset'
,
'CocoDataset'
,
'GroupSampler'
,
'DistributedGroupSampler'
,
'build_dataloader'
,
'to_tensor'
,
'random_scale'
,
'show_ann'
,
'get_dataset'
,
'ConcatDataset'
,
'RepeatDataset'
,
'ExtraAugmentation'
'CustomDataset'
,
'XMLDataset'
,
'CocoDataset'
,
'VOCDataset'
,
'GroupSampler'
,
'DistributedGroupSampler'
,
'build_dataloader'
,
'to_tensor'
,
'random_scale'
,
'show_ann'
,
'get_dataset'
,
'ConcatDataset'
,
'RepeatDataset'
,
'ExtraAugmentation'
]
mmdet/datasets/coco.py
View file @
40325555
...
...
@@ -6,6 +6,21 @@ from .custom import CustomDataset
class
CocoDataset
(
CustomDataset
):
CLASSES
=
(
'person'
,
'bicycle'
,
'car'
,
'motorcycle'
,
'airplane'
,
'bus'
,
'train'
,
'truck'
,
'boat'
,
'traffic_light'
,
'fire_hydrant'
,
'stop_sign'
,
'parking_meter'
,
'bench'
,
'bird'
,
'cat'
,
'dog'
,
'horse'
,
'sheep'
,
'cow'
,
'elephant'
,
'bear'
,
'zebra'
,
'giraffe'
,
'backpack'
,
'umbrella'
,
'handbag'
,
'tie'
,
'suitcase'
,
'frisbee'
,
'skis'
,
'snowboard'
,
'sports_ball'
,
'kite'
,
'baseball_bat'
,
'baseball_glove'
,
'skateboard'
,
'surfboard'
,
'tennis_racket'
,
'bottle'
,
'wine_glass'
,
'cup'
,
'fork'
,
'knife'
,
'spoon'
,
'bowl'
,
'banana'
,
'apple'
,
'sandwich'
,
'orange'
,
'broccoli'
,
'carrot'
,
'hot_dog'
,
'pizza'
,
'donut'
,
'cake'
,
'chair'
,
'couch'
,
'potted_plant'
,
'bed'
,
'dining_table'
,
'toilet'
,
'tv'
,
'laptop'
,
'mouse'
,
'remote'
,
'keyboard'
,
'cell_phone'
,
'microwave'
,
'oven'
,
'toaster'
,
'sink'
,
'refrigerator'
,
'book'
,
'clock'
,
'vase'
,
'scissors'
,
'teddy_bear'
,
'hair_drier'
,
'toothbrush'
)
def
load_annotations
(
self
,
ann_file
):
self
.
coco
=
COCO
(
ann_file
)
self
.
cat_ids
=
self
.
coco
.
getCatIds
()
...
...
mmdet/datasets/concat_dataset.py
View file @
40325555
...
...
@@ -3,16 +3,18 @@ from torch.utils.data.dataset import ConcatDataset as _ConcatDataset
class
ConcatDataset
(
_ConcatDataset
):
"""
Same as torch.utils.data.dataset.ConcatDataset, but
"""A wrapper of concatenated dataset.
Same as :obj:`torch.utils.data.dataset.ConcatDataset`, but
concat the group flag for image aspect ratio.
Args:
datasets (list[:obj:`Dataset`]): A list of datasets.
"""
def
__init__
(
self
,
datasets
):
"""
flag: Images with aspect ratio greater than 1 will be set as group 1,
otherwise group 0.
"""
super
(
ConcatDataset
,
self
).
__init__
(
datasets
)
self
.
CLASSES
=
datasets
[
0
].
CLASSES
if
hasattr
(
datasets
[
0
],
'flag'
):
flags
=
[]
for
i
in
range
(
0
,
len
(
datasets
)):
...
...
mmdet/datasets/custom.py
View file @
40325555
...
...
@@ -33,6 +33,8 @@ class CustomDataset(Dataset):
The `ann` field is optional for testing.
"""
CLASSES
=
None
def
__init__
(
self
,
ann_file
,
img_prefix
,
...
...
@@ -48,6 +50,9 @@ class CustomDataset(Dataset):
test_mode
=
False
,
extra_aug
=
None
,
resize_keep_ratio
=
True
):
# prefix of images path
self
.
img_prefix
=
img_prefix
# load annotations (and proposals)
self
.
img_infos
=
self
.
load_annotations
(
ann_file
)
if
proposal_file
is
not
None
:
...
...
@@ -61,8 +66,6 @@ class CustomDataset(Dataset):
if
self
.
proposals
is
not
None
:
self
.
proposals
=
[
self
.
proposals
[
i
]
for
i
in
valid_inds
]
# prefix of images path
self
.
img_prefix
=
img_prefix
# (long_edge, short_edge) or [(long1, short1), (long2, short2), ...]
self
.
img_scales
=
img_scale
if
isinstance
(
img_scale
,
list
)
else
[
img_scale
]
...
...
mmdet/datasets/repeat_dataset.py
View file @
40325555
...
...
@@ -6,12 +6,14 @@ class RepeatDataset(object):
def
__init__
(
self
,
dataset
,
times
):
self
.
dataset
=
dataset
self
.
times
=
times
self
.
CLASSES
=
dataset
.
CLASSES
if
hasattr
(
self
.
dataset
,
'flag'
):
self
.
flag
=
np
.
tile
(
self
.
dataset
.
flag
,
times
)
self
.
_original_length
=
len
(
self
.
dataset
)
self
.
_ori_len
=
len
(
self
.
dataset
)
def
__getitem__
(
self
,
idx
):
return
self
.
dataset
[
idx
%
self
.
_ori
ginal
_len
gth
]
return
self
.
dataset
[
idx
%
self
.
_ori_len
]
def
__len__
(
self
):
return
self
.
times
*
self
.
_ori
ginal
_len
gth
return
self
.
times
*
self
.
_ori_len
mmdet/datasets/voc.py
0 → 100644
View file @
40325555
from
.xml_style
import
XMLDataset
class
VOCDataset
(
XMLDataset
):
CLASSES
=
(
'aeroplane'
,
'bicycle'
,
'bird'
,
'boat'
,
'bottle'
,
'bus'
,
'car'
,
'cat'
,
'chair'
,
'cow'
,
'diningtable'
,
'dog'
,
'horse'
,
'motorbike'
,
'person'
,
'pottedplant'
,
'sheep'
,
'sofa'
,
'train'
,
'tvmonitor'
)
def
__init__
(
self
,
**
kwargs
):
super
(
VOCDataset
,
self
).
__init__
(
**
kwargs
)
if
'VOC2007'
in
self
.
img_prefix
:
self
.
year
=
2007
elif
'VOC2012'
in
self
.
img_prefix
:
self
.
year
=
2012
else
:
raise
ValueError
(
'Cannot infer dataset year from img_prefix'
)
mmdet/datasets/xml_style.py
0 → 100644
View file @
40325555
import
os.path
as
osp
import
xml.etree.ElementTree
as
ET
import
mmcv
import
numpy
as
np
from
.custom
import
CustomDataset
class
XMLDataset
(
CustomDataset
):
def
__init__
(
self
,
**
kwargs
):
super
(
XMLDataset
,
self
).
__init__
(
**
kwargs
)
self
.
cat2label
=
{
cat
:
i
+
1
for
i
,
cat
in
enumerate
(
self
.
CLASSES
)}
def
load_annotations
(
self
,
ann_file
):
img_infos
=
[]
img_ids
=
mmcv
.
list_from_file
(
ann_file
)
for
img_id
in
img_ids
:
filename
=
'JPEGImages/{}.jpg'
.
format
(
img_id
)
xml_path
=
osp
.
join
(
self
.
img_prefix
,
'Annotations'
,
'{}.xml'
.
format
(
img_id
))
tree
=
ET
.
parse
(
xml_path
)
root
=
tree
.
getroot
()
size
=
root
.
find
(
'size'
)
width
=
int
(
size
.
find
(
'width'
).
text
)
height
=
int
(
size
.
find
(
'height'
).
text
)
img_infos
.
append
(
dict
(
id
=
img_id
,
filename
=
filename
,
width
=
width
,
height
=
height
))
return
img_infos
def
get_ann_info
(
self
,
idx
):
img_id
=
self
.
img_infos
[
idx
][
'id'
]
xml_path
=
osp
.
join
(
self
.
img_prefix
,
'Annotations'
,
'{}.xml'
.
format
(
img_id
))
tree
=
ET
.
parse
(
xml_path
)
root
=
tree
.
getroot
()
bboxes
=
[]
labels
=
[]
bboxes_ignore
=
[]
labels_ignore
=
[]
for
obj
in
root
.
findall
(
'object'
):
name
=
obj
.
find
(
'name'
).
text
label
=
self
.
cat2label
[
name
]
difficult
=
int
(
obj
.
find
(
'difficult'
).
text
)
bnd_box
=
obj
.
find
(
'bndbox'
)
bbox
=
[
int
(
bnd_box
.
find
(
'xmin'
).
text
),
int
(
bnd_box
.
find
(
'ymin'
).
text
),
int
(
bnd_box
.
find
(
'xmax'
).
text
),
int
(
bnd_box
.
find
(
'ymax'
).
text
)
]
if
difficult
:
bboxes_ignore
.
append
(
bbox
)
labels_ignore
.
append
(
label
)
else
:
bboxes
.
append
(
bbox
)
labels
.
append
(
label
)
if
not
bboxes
:
bboxes
=
np
.
zeros
((
0
,
4
))
labels
=
np
.
zeros
((
0
,
))
else
:
bboxes
=
np
.
array
(
bboxes
,
ndmin
=
2
)
-
1
labels
=
np
.
array
(
labels
)
if
not
bboxes_ignore
:
bboxes_ignore
=
np
.
zeros
((
0
,
4
))
labels_ignore
=
np
.
zeros
((
0
,
))
else
:
bboxes_ignore
=
np
.
array
(
bboxes_ignore
,
ndmin
=
2
)
-
1
labels_ignore
=
np
.
array
(
labels_ignore
)
ann
=
dict
(
bboxes
=
bboxes
.
astype
(
np
.
float32
),
labels
=
labels
.
astype
(
np
.
int64
),
bboxes_ignore
=
bboxes_ignore
.
astype
(
np
.
float32
),
labels_ignore
=
labels_ignore
.
astype
(
np
.
int64
))
return
ann
mmdet/models/detectors/base.py
View file @
40325555
...
...
@@ -99,11 +99,12 @@ class BaseDetector(nn.Module):
if
isinstance
(
dataset
,
str
):
class_names
=
get_classes
(
dataset
)
elif
isinstance
(
dataset
,
list
)
:
elif
isinstance
(
dataset
,
(
list
,
tuple
))
or
dataset
is
None
:
class_names
=
dataset
else
:
raise
TypeError
(
'dataset must be a valid dataset name or a list'
' of class names, not {}'
.
format
(
type
(
dataset
)))
raise
TypeError
(
'dataset must be a valid dataset name or a sequence'
' of class names, not {}'
.
format
(
type
(
dataset
)))
for
img
,
img_meta
in
zip
(
imgs
,
img_metas
):
h
,
w
,
_
=
img_meta
[
'img_shape'
]
...
...
tools/test.py
View file @
40325555
...
...
@@ -14,15 +14,16 @@ from mmdet.models import build_detector, detectors
def
single_test
(
model
,
data_loader
,
show
=
False
):
model
.
eval
()
results
=
[]
prog_bar
=
mmcv
.
ProgressBar
(
len
(
data_loader
.
dataset
))
dataset
=
data_loader
.
dataset
prog_bar
=
mmcv
.
ProgressBar
(
len
(
dataset
))
for
i
,
data
in
enumerate
(
data_loader
):
with
torch
.
no_grad
():
result
=
model
(
return_loss
=
False
,
rescale
=
not
show
,
**
data
)
results
.
append
(
result
)
if
show
:
model
.
module
.
show_result
(
data
,
result
,
data
_loader
.
dataset
.
img_norm_cfg
)
model
.
module
.
show_result
(
data
,
result
,
dataset
.
img_norm_cfg
,
data
set
.
CLASSES
)
batch_size
=
data
[
'img'
][
0
].
size
(
0
)
for
_
in
range
(
batch_size
):
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment