Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
ModelZoo
SOLOv2-pytorch
Commits
40325555
Commit
40325555
authored
Dec 12, 2018
by
yhcao6
Browse files
merge new master
parents
cfdd8050
c95c6373
Changes
14
Hide whitespace changes
Inline
Side-by-side
Showing
14 changed files
with
201 additions
and
37 deletions
+201
-37
README.md
README.md
+3
-3
mmdet/apis/train.py
mmdet/apis/train.py
+8
-4
mmdet/core/evaluation/__init__.py
mmdet/core/evaluation/__init__.py
+2
-2
mmdet/core/evaluation/class_names.py
mmdet/core/evaluation/class_names.py
+8
-8
mmdet/core/evaluation/eval_hooks.py
mmdet/core/evaluation/eval_hooks.py
+39
-0
mmdet/datasets/__init__.py
mmdet/datasets/__init__.py
+6
-3
mmdet/datasets/coco.py
mmdet/datasets/coco.py
+15
-0
mmdet/datasets/concat_dataset.py
mmdet/datasets/concat_dataset.py
+8
-6
mmdet/datasets/custom.py
mmdet/datasets/custom.py
+5
-2
mmdet/datasets/repeat_dataset.py
mmdet/datasets/repeat_dataset.py
+5
-3
mmdet/datasets/voc.py
mmdet/datasets/voc.py
+18
-0
mmdet/datasets/xml_style.py
mmdet/datasets/xml_style.py
+76
-0
mmdet/models/detectors/base.py
mmdet/models/detectors/base.py
+4
-3
tools/test.py
tools/test.py
+4
-3
No files found.
README.md
View file @
40325555
...
@@ -194,7 +194,7 @@ Here is an example.
...
@@ -194,7 +194,7 @@ Here is an example.
'bboxes': <np.ndarray> (n, 4),
'bboxes': <np.ndarray> (n, 4),
'labels': <np.ndarray> (n, ),
'labels': <np.ndarray> (n, ),
'bboxes_ignore': <np.ndarray> (k, 4),
'bboxes_ignore': <np.ndarray> (k, 4),
'labels_ignore': <np.ndarray> (k,
4
) (optional field)
'labels_ignore': <np.ndarray> (k, ) (optional field)
}
}
},
},
...
...
...
@@ -206,12 +206,12 @@ There are two ways to work with custom datasets.
...
@@ -206,12 +206,12 @@ There are two ways to work with custom datasets.
-
online conversion
-
online conversion
You can write a new Dataset class inherited from
`CustomDataset`
, and overwrite two methods
You can write a new Dataset class inherited from
`CustomDataset`
, and overwrite two methods
`load_annotations(self, ann_file)`
and
`get_ann_info(self, idx)`
, like
[
CocoDataset
](
mmdet/datasets/coco.py
)
.
`load_annotations(self, ann_file)`
and
`get_ann_info(self, idx)`
, like
[
CocoDataset
](
mmdet/datasets/coco.py
)
and
[
VOCDataset
](
mmdet/datasets/voc.py
)
.
-
offline conversion
-
offline conversion
You can convert the annotation format to the expected format above and save it to
You can convert the annotation format to the expected format above and save it to
a pickle file, like
[
pascal_voc.py
](
tools/convert_datasets/pascal_voc.py
)
.
a pickle
or json
file, like
[
pascal_voc.py
](
tools/convert_datasets/pascal_voc.py
)
.
Then you can simply use
`CustomDataset`
.
Then you can simply use
`CustomDataset`
.
## Technical details
## Technical details
...
...
mmdet/apis/train.py
View file @
40325555
...
@@ -6,8 +6,8 @@ import torch
...
@@ -6,8 +6,8 @@ import torch
from
mmcv.runner
import
Runner
,
DistSamplerSeedHook
from
mmcv.runner
import
Runner
,
DistSamplerSeedHook
from
mmcv.parallel
import
MMDataParallel
,
MMDistributedDataParallel
from
mmcv.parallel
import
MMDataParallel
,
MMDistributedDataParallel
from
mmdet.core
import
(
DistOptimizerHook
,
Coco
DistEval
Recall
Hook
,
from
mmdet.core
import
(
DistOptimizerHook
,
DistEval
mAP
Hook
,
CocoDistEvalmAPHook
)
CocoDistEvalRecallHook
,
CocoDistEvalmAPHook
)
from
mmdet.datasets
import
build_dataloader
from
mmdet.datasets
import
build_dataloader
from
mmdet.models
import
RPN
from
mmdet.models
import
RPN
from
.env
import
get_root_logger
from
.env
import
get_root_logger
...
@@ -81,9 +81,13 @@ def _dist_train(model, dataset, cfg, validate=False):
...
@@ -81,9 +81,13 @@ def _dist_train(model, dataset, cfg, validate=False):
# register eval hooks
# register eval hooks
if
validate
:
if
validate
:
if
isinstance
(
model
.
module
,
RPN
):
if
isinstance
(
model
.
module
,
RPN
):
# TODO: implement recall hooks for other datasets
runner
.
register_hook
(
CocoDistEvalRecallHook
(
cfg
.
data
.
val
))
runner
.
register_hook
(
CocoDistEvalRecallHook
(
cfg
.
data
.
val
))
elif
cfg
.
data
.
val
.
type
==
'CocoDataset'
:
else
:
runner
.
register_hook
(
CocoDistEvalmAPHook
(
cfg
.
data
.
val
))
if
cfg
.
data
.
val
.
type
==
'CocoDataset'
:
runner
.
register_hook
(
CocoDistEvalmAPHook
(
cfg
.
data
.
val
))
else
:
runner
.
register_hook
(
DistEvalmAPHook
(
cfg
.
data
.
val
))
if
cfg
.
resume_from
:
if
cfg
.
resume_from
:
runner
.
resume
(
cfg
.
resume_from
)
runner
.
resume
(
cfg
.
resume_from
)
...
...
mmdet/core/evaluation/__init__.py
View file @
40325555
...
@@ -2,7 +2,7 @@ from .class_names import (voc_classes, imagenet_det_classes,
...
@@ -2,7 +2,7 @@ from .class_names import (voc_classes, imagenet_det_classes,
imagenet_vid_classes
,
coco_classes
,
dataset_aliases
,
imagenet_vid_classes
,
coco_classes
,
dataset_aliases
,
get_classes
)
get_classes
)
from
.coco_utils
import
coco_eval
,
fast_eval_recall
,
results2json
from
.coco_utils
import
coco_eval
,
fast_eval_recall
,
results2json
from
.eval_hooks
import
(
DistEvalHook
,
CocoDistEvalRecallHook
,
from
.eval_hooks
import
(
DistEvalHook
,
DistEvalmAPHook
,
CocoDistEvalRecallHook
,
CocoDistEvalmAPHook
)
CocoDistEvalmAPHook
)
from
.mean_ap
import
average_precision
,
eval_map
,
print_map_summary
from
.mean_ap
import
average_precision
,
eval_map
,
print_map_summary
from
.recall
import
(
eval_recalls
,
print_recall_summary
,
plot_num_recall
,
from
.recall
import
(
eval_recalls
,
print_recall_summary
,
plot_num_recall
,
...
@@ -11,7 +11,7 @@ from .recall import (eval_recalls, print_recall_summary, plot_num_recall,
...
@@ -11,7 +11,7 @@ from .recall import (eval_recalls, print_recall_summary, plot_num_recall,
__all__
=
[
__all__
=
[
'voc_classes'
,
'imagenet_det_classes'
,
'imagenet_vid_classes'
,
'voc_classes'
,
'imagenet_det_classes'
,
'imagenet_vid_classes'
,
'coco_classes'
,
'dataset_aliases'
,
'get_classes'
,
'coco_eval'
,
'coco_classes'
,
'dataset_aliases'
,
'get_classes'
,
'coco_eval'
,
'fast_eval_recall'
,
'results2json'
,
'DistEvalHook'
,
'fast_eval_recall'
,
'results2json'
,
'DistEvalHook'
,
'DistEvalmAPHook'
,
'CocoDistEvalRecallHook'
,
'CocoDistEvalmAPHook'
,
'average_precision'
,
'CocoDistEvalRecallHook'
,
'CocoDistEvalmAPHook'
,
'average_precision'
,
'eval_map'
,
'print_map_summary'
,
'eval_recalls'
,
'print_recall_summary'
,
'eval_map'
,
'print_map_summary'
,
'eval_recalls'
,
'print_recall_summary'
,
'plot_num_recall'
,
'plot_iou_recall'
'plot_num_recall'
,
'plot_iou_recall'
...
...
mmdet/core/evaluation/class_names.py
View file @
40325555
...
@@ -63,18 +63,18 @@ def imagenet_vid_classes():
...
@@ -63,18 +63,18 @@ def imagenet_vid_classes():
def
coco_classes
():
def
coco_classes
():
return
[
return
[
'person'
,
'bicycle'
,
'car'
,
'motorcycle'
,
'airplane'
,
'bus'
,
'train'
,
'person'
,
'bicycle'
,
'car'
,
'motorcycle'
,
'airplane'
,
'bus'
,
'train'
,
'truck'
,
'boat'
,
'traffic
light'
,
'fire
hydrant'
,
'stop
sign'
,
'truck'
,
'boat'
,
'traffic
_
light'
,
'fire
_
hydrant'
,
'stop
_
sign'
,
'parking
meter'
,
'bench'
,
'bird'
,
'cat'
,
'dog'
,
'horse'
,
'sheep'
,
'parking
_
meter'
,
'bench'
,
'bird'
,
'cat'
,
'dog'
,
'horse'
,
'sheep'
,
'cow'
,
'elephant'
,
'bear'
,
'zebra'
,
'giraffe'
,
'backpack'
,
'umbrella'
,
'cow'
,
'elephant'
,
'bear'
,
'zebra'
,
'giraffe'
,
'backpack'
,
'umbrella'
,
'handbag'
,
'tie'
,
'suitcase'
,
'frisbee'
,
'skis'
,
'snowboard'
,
'handbag'
,
'tie'
,
'suitcase'
,
'frisbee'
,
'skis'
,
'snowboard'
,
'sports
ball'
,
'kite'
,
'baseball
bat'
,
'baseball
glove'
,
'skateboard'
,
'sports
_
ball'
,
'kite'
,
'baseball
_
bat'
,
'baseball
_
glove'
,
'skateboard'
,
'surfboard'
,
'tennis
racket'
,
'bottle'
,
'wine
glass'
,
'cup'
,
'fork'
,
'surfboard'
,
'tennis
_
racket'
,
'bottle'
,
'wine
_
glass'
,
'cup'
,
'fork'
,
'knife'
,
'spoon'
,
'bowl'
,
'banana'
,
'apple'
,
'sandwich'
,
'orange'
,
'knife'
,
'spoon'
,
'bowl'
,
'banana'
,
'apple'
,
'sandwich'
,
'orange'
,
'broccoli'
,
'carrot'
,
'hot
dog'
,
'pizza'
,
'donut'
,
'cake'
,
'chair'
,
'broccoli'
,
'carrot'
,
'hot
_
dog'
,
'pizza'
,
'donut'
,
'cake'
,
'chair'
,
'couch'
,
'potted
plant'
,
'bed'
,
'dining
table'
,
'toilet'
,
'tv'
,
'couch'
,
'potted
_
plant'
,
'bed'
,
'dining
_
table'
,
'toilet'
,
'tv'
,
'laptop'
,
'mouse'
,
'remote'
,
'keyboard'
,
'cell
phone'
,
'microwave'
,
'laptop'
,
'mouse'
,
'remote'
,
'keyboard'
,
'cell
_
phone'
,
'microwave'
,
'oven'
,
'toaster'
,
'sink'
,
'refrigerator'
,
'book'
,
'clock'
,
'vase'
,
'oven'
,
'toaster'
,
'sink'
,
'refrigerator'
,
'book'
,
'clock'
,
'vase'
,
'scissors'
,
'teddy
bear'
,
'hair
drier'
,
'toothbrush'
'scissors'
,
'teddy
_
bear'
,
'hair
_
drier'
,
'toothbrush'
]
]
...
...
mmdet/core/evaluation/eval_hooks.py
View file @
40325555
...
@@ -12,6 +12,7 @@ from pycocotools.cocoeval import COCOeval
...
@@ -12,6 +12,7 @@ from pycocotools.cocoeval import COCOeval
from
torch.utils.data
import
Dataset
from
torch.utils.data
import
Dataset
from
.coco_utils
import
results2json
,
fast_eval_recall
from
.coco_utils
import
results2json
,
fast_eval_recall
from
.mean_ap
import
eval_map
from
mmdet
import
datasets
from
mmdet
import
datasets
...
@@ -102,6 +103,44 @@ class DistEvalHook(Hook):
...
@@ -102,6 +103,44 @@ class DistEvalHook(Hook):
raise
NotImplementedError
raise
NotImplementedError
class
DistEvalmAPHook
(
DistEvalHook
):
def
evaluate
(
self
,
runner
,
results
):
gt_bboxes
=
[]
gt_labels
=
[]
gt_ignore
=
[]
if
self
.
dataset
.
with_crowd
else
None
for
i
in
range
(
len
(
self
.
dataset
)):
ann
=
self
.
dataset
.
get_ann_info
(
i
)
bboxes
=
ann
[
'bboxes'
]
labels
=
ann
[
'labels'
]
if
gt_ignore
is
not
None
:
ignore
=
np
.
concatenate
([
np
.
zeros
(
bboxes
.
shape
[
0
],
dtype
=
np
.
bool
),
np
.
ones
(
ann
[
'bboxes_ignore'
].
shape
[
0
],
dtype
=
np
.
bool
)
])
gt_ignore
.
append
(
ignore
)
bboxes
=
np
.
vstack
([
bboxes
,
ann
[
'bboxes_ignore'
]])
labels
=
np
.
concatenate
([
labels
,
ann
[
'labels_ignore'
]])
gt_bboxes
.
append
(
bboxes
)
gt_labels
.
append
(
labels
)
# If the dataset is VOC2007, then use 11 points mAP evaluation.
if
hasattr
(
self
.
dataset
,
'year'
)
and
self
.
dataset
.
year
==
2007
:
ds_name
=
'voc07'
else
:
ds_name
=
self
.
dataset
.
CLASSES
mean_ap
,
eval_results
=
eval_map
(
results
,
gt_bboxes
,
gt_labels
,
gt_ignore
=
gt_ignore
,
scale_ranges
=
None
,
iou_thr
=
0.5
,
dataset
=
ds_name
,
print_summary
=
True
)
runner
.
log_buffer
.
output
[
'mAP'
]
=
mean_ap
runner
.
log_buffer
.
ready
=
True
class
CocoDistEvalRecallHook
(
DistEvalHook
):
class
CocoDistEvalRecallHook
(
DistEvalHook
):
def
__init__
(
self
,
def
__init__
(
self
,
...
...
mmdet/datasets/__init__.py
View file @
40325555
from
.custom
import
CustomDataset
from
.custom
import
CustomDataset
from
.xml_style
import
XMLDataset
from
.coco
import
CocoDataset
from
.coco
import
CocoDataset
from
.voc
import
VOCDataset
from
.loader
import
GroupSampler
,
DistributedGroupSampler
,
build_dataloader
from
.loader
import
GroupSampler
,
DistributedGroupSampler
,
build_dataloader
from
.utils
import
to_tensor
,
random_scale
,
show_ann
,
get_dataset
from
.utils
import
to_tensor
,
random_scale
,
show_ann
,
get_dataset
from
.concat_dataset
import
ConcatDataset
from
.concat_dataset
import
ConcatDataset
...
@@ -7,7 +9,8 @@ from .repeat_dataset import RepeatDataset
...
@@ -7,7 +9,8 @@ from .repeat_dataset import RepeatDataset
from
.extra_aug
import
ExtraAugmentation
from
.extra_aug
import
ExtraAugmentation
__all__
=
[
__all__
=
[
'CustomDataset'
,
'CocoDataset'
,
'GroupSampler'
,
'DistributedGroupSampler'
,
'CustomDataset'
,
'XMLDataset'
,
'CocoDataset'
,
'VOCDataset'
,
'GroupSampler'
,
'build_dataloader'
,
'to_tensor'
,
'random_scale'
,
'show_ann'
,
'DistributedGroupSampler'
,
'build_dataloader'
,
'to_tensor'
,
'random_scale'
,
'get_dataset'
,
'ConcatDataset'
,
'RepeatDataset'
,
'ExtraAugmentation'
'show_ann'
,
'get_dataset'
,
'ConcatDataset'
,
'RepeatDataset'
,
'ExtraAugmentation'
]
]
mmdet/datasets/coco.py
View file @
40325555
...
@@ -6,6 +6,21 @@ from .custom import CustomDataset
...
@@ -6,6 +6,21 @@ from .custom import CustomDataset
class
CocoDataset
(
CustomDataset
):
class
CocoDataset
(
CustomDataset
):
CLASSES
=
(
'person'
,
'bicycle'
,
'car'
,
'motorcycle'
,
'airplane'
,
'bus'
,
'train'
,
'truck'
,
'boat'
,
'traffic_light'
,
'fire_hydrant'
,
'stop_sign'
,
'parking_meter'
,
'bench'
,
'bird'
,
'cat'
,
'dog'
,
'horse'
,
'sheep'
,
'cow'
,
'elephant'
,
'bear'
,
'zebra'
,
'giraffe'
,
'backpack'
,
'umbrella'
,
'handbag'
,
'tie'
,
'suitcase'
,
'frisbee'
,
'skis'
,
'snowboard'
,
'sports_ball'
,
'kite'
,
'baseball_bat'
,
'baseball_glove'
,
'skateboard'
,
'surfboard'
,
'tennis_racket'
,
'bottle'
,
'wine_glass'
,
'cup'
,
'fork'
,
'knife'
,
'spoon'
,
'bowl'
,
'banana'
,
'apple'
,
'sandwich'
,
'orange'
,
'broccoli'
,
'carrot'
,
'hot_dog'
,
'pizza'
,
'donut'
,
'cake'
,
'chair'
,
'couch'
,
'potted_plant'
,
'bed'
,
'dining_table'
,
'toilet'
,
'tv'
,
'laptop'
,
'mouse'
,
'remote'
,
'keyboard'
,
'cell_phone'
,
'microwave'
,
'oven'
,
'toaster'
,
'sink'
,
'refrigerator'
,
'book'
,
'clock'
,
'vase'
,
'scissors'
,
'teddy_bear'
,
'hair_drier'
,
'toothbrush'
)
def
load_annotations
(
self
,
ann_file
):
def
load_annotations
(
self
,
ann_file
):
self
.
coco
=
COCO
(
ann_file
)
self
.
coco
=
COCO
(
ann_file
)
self
.
cat_ids
=
self
.
coco
.
getCatIds
()
self
.
cat_ids
=
self
.
coco
.
getCatIds
()
...
...
mmdet/datasets/concat_dataset.py
View file @
40325555
...
@@ -3,16 +3,18 @@ from torch.utils.data.dataset import ConcatDataset as _ConcatDataset
...
@@ -3,16 +3,18 @@ from torch.utils.data.dataset import ConcatDataset as _ConcatDataset
class
ConcatDataset
(
_ConcatDataset
):
class
ConcatDataset
(
_ConcatDataset
):
"""
"""A wrapper of concatenated dataset.
Same as torch.utils.data.dataset.ConcatDataset, but
Same as :obj:`torch.utils.data.dataset.ConcatDataset`, but
concat the group flag for image aspect ratio.
concat the group flag for image aspect ratio.
Args:
datasets (list[:obj:`Dataset`]): A list of datasets.
"""
"""
def
__init__
(
self
,
datasets
):
def
__init__
(
self
,
datasets
):
"""
flag: Images with aspect ratio greater than 1 will be set as group 1,
otherwise group 0.
"""
super
(
ConcatDataset
,
self
).
__init__
(
datasets
)
super
(
ConcatDataset
,
self
).
__init__
(
datasets
)
self
.
CLASSES
=
datasets
[
0
].
CLASSES
if
hasattr
(
datasets
[
0
],
'flag'
):
if
hasattr
(
datasets
[
0
],
'flag'
):
flags
=
[]
flags
=
[]
for
i
in
range
(
0
,
len
(
datasets
)):
for
i
in
range
(
0
,
len
(
datasets
)):
...
...
mmdet/datasets/custom.py
View file @
40325555
...
@@ -33,6 +33,8 @@ class CustomDataset(Dataset):
...
@@ -33,6 +33,8 @@ class CustomDataset(Dataset):
The `ann` field is optional for testing.
The `ann` field is optional for testing.
"""
"""
CLASSES
=
None
def
__init__
(
self
,
def
__init__
(
self
,
ann_file
,
ann_file
,
img_prefix
,
img_prefix
,
...
@@ -48,6 +50,9 @@ class CustomDataset(Dataset):
...
@@ -48,6 +50,9 @@ class CustomDataset(Dataset):
test_mode
=
False
,
test_mode
=
False
,
extra_aug
=
None
,
extra_aug
=
None
,
resize_keep_ratio
=
True
):
resize_keep_ratio
=
True
):
# prefix of images path
self
.
img_prefix
=
img_prefix
# load annotations (and proposals)
# load annotations (and proposals)
self
.
img_infos
=
self
.
load_annotations
(
ann_file
)
self
.
img_infos
=
self
.
load_annotations
(
ann_file
)
if
proposal_file
is
not
None
:
if
proposal_file
is
not
None
:
...
@@ -61,8 +66,6 @@ class CustomDataset(Dataset):
...
@@ -61,8 +66,6 @@ class CustomDataset(Dataset):
if
self
.
proposals
is
not
None
:
if
self
.
proposals
is
not
None
:
self
.
proposals
=
[
self
.
proposals
[
i
]
for
i
in
valid_inds
]
self
.
proposals
=
[
self
.
proposals
[
i
]
for
i
in
valid_inds
]
# prefix of images path
self
.
img_prefix
=
img_prefix
# (long_edge, short_edge) or [(long1, short1), (long2, short2), ...]
# (long_edge, short_edge) or [(long1, short1), (long2, short2), ...]
self
.
img_scales
=
img_scale
if
isinstance
(
img_scale
,
self
.
img_scales
=
img_scale
if
isinstance
(
img_scale
,
list
)
else
[
img_scale
]
list
)
else
[
img_scale
]
...
...
mmdet/datasets/repeat_dataset.py
View file @
40325555
...
@@ -6,12 +6,14 @@ class RepeatDataset(object):
...
@@ -6,12 +6,14 @@ class RepeatDataset(object):
def
__init__
(
self
,
dataset
,
times
):
def
__init__
(
self
,
dataset
,
times
):
self
.
dataset
=
dataset
self
.
dataset
=
dataset
self
.
times
=
times
self
.
times
=
times
self
.
CLASSES
=
dataset
.
CLASSES
if
hasattr
(
self
.
dataset
,
'flag'
):
if
hasattr
(
self
.
dataset
,
'flag'
):
self
.
flag
=
np
.
tile
(
self
.
dataset
.
flag
,
times
)
self
.
flag
=
np
.
tile
(
self
.
dataset
.
flag
,
times
)
self
.
_original_length
=
len
(
self
.
dataset
)
self
.
_ori_len
=
len
(
self
.
dataset
)
def
__getitem__
(
self
,
idx
):
def
__getitem__
(
self
,
idx
):
return
self
.
dataset
[
idx
%
self
.
_ori
ginal
_len
gth
]
return
self
.
dataset
[
idx
%
self
.
_ori_len
]
def
__len__
(
self
):
def
__len__
(
self
):
return
self
.
times
*
self
.
_ori
ginal
_len
gth
return
self
.
times
*
self
.
_ori_len
mmdet/datasets/voc.py
0 → 100644
View file @
40325555
from
.xml_style
import
XMLDataset
class
VOCDataset
(
XMLDataset
):
CLASSES
=
(
'aeroplane'
,
'bicycle'
,
'bird'
,
'boat'
,
'bottle'
,
'bus'
,
'car'
,
'cat'
,
'chair'
,
'cow'
,
'diningtable'
,
'dog'
,
'horse'
,
'motorbike'
,
'person'
,
'pottedplant'
,
'sheep'
,
'sofa'
,
'train'
,
'tvmonitor'
)
def
__init__
(
self
,
**
kwargs
):
super
(
VOCDataset
,
self
).
__init__
(
**
kwargs
)
if
'VOC2007'
in
self
.
img_prefix
:
self
.
year
=
2007
elif
'VOC2012'
in
self
.
img_prefix
:
self
.
year
=
2012
else
:
raise
ValueError
(
'Cannot infer dataset year from img_prefix'
)
mmdet/datasets/xml_style.py
0 → 100644
View file @
40325555
import
os.path
as
osp
import
xml.etree.ElementTree
as
ET
import
mmcv
import
numpy
as
np
from
.custom
import
CustomDataset
class
XMLDataset
(
CustomDataset
):
def
__init__
(
self
,
**
kwargs
):
super
(
XMLDataset
,
self
).
__init__
(
**
kwargs
)
self
.
cat2label
=
{
cat
:
i
+
1
for
i
,
cat
in
enumerate
(
self
.
CLASSES
)}
def
load_annotations
(
self
,
ann_file
):
img_infos
=
[]
img_ids
=
mmcv
.
list_from_file
(
ann_file
)
for
img_id
in
img_ids
:
filename
=
'JPEGImages/{}.jpg'
.
format
(
img_id
)
xml_path
=
osp
.
join
(
self
.
img_prefix
,
'Annotations'
,
'{}.xml'
.
format
(
img_id
))
tree
=
ET
.
parse
(
xml_path
)
root
=
tree
.
getroot
()
size
=
root
.
find
(
'size'
)
width
=
int
(
size
.
find
(
'width'
).
text
)
height
=
int
(
size
.
find
(
'height'
).
text
)
img_infos
.
append
(
dict
(
id
=
img_id
,
filename
=
filename
,
width
=
width
,
height
=
height
))
return
img_infos
def
get_ann_info
(
self
,
idx
):
img_id
=
self
.
img_infos
[
idx
][
'id'
]
xml_path
=
osp
.
join
(
self
.
img_prefix
,
'Annotations'
,
'{}.xml'
.
format
(
img_id
))
tree
=
ET
.
parse
(
xml_path
)
root
=
tree
.
getroot
()
bboxes
=
[]
labels
=
[]
bboxes_ignore
=
[]
labels_ignore
=
[]
for
obj
in
root
.
findall
(
'object'
):
name
=
obj
.
find
(
'name'
).
text
label
=
self
.
cat2label
[
name
]
difficult
=
int
(
obj
.
find
(
'difficult'
).
text
)
bnd_box
=
obj
.
find
(
'bndbox'
)
bbox
=
[
int
(
bnd_box
.
find
(
'xmin'
).
text
),
int
(
bnd_box
.
find
(
'ymin'
).
text
),
int
(
bnd_box
.
find
(
'xmax'
).
text
),
int
(
bnd_box
.
find
(
'ymax'
).
text
)
]
if
difficult
:
bboxes_ignore
.
append
(
bbox
)
labels_ignore
.
append
(
label
)
else
:
bboxes
.
append
(
bbox
)
labels
.
append
(
label
)
if
not
bboxes
:
bboxes
=
np
.
zeros
((
0
,
4
))
labels
=
np
.
zeros
((
0
,
))
else
:
bboxes
=
np
.
array
(
bboxes
,
ndmin
=
2
)
-
1
labels
=
np
.
array
(
labels
)
if
not
bboxes_ignore
:
bboxes_ignore
=
np
.
zeros
((
0
,
4
))
labels_ignore
=
np
.
zeros
((
0
,
))
else
:
bboxes_ignore
=
np
.
array
(
bboxes_ignore
,
ndmin
=
2
)
-
1
labels_ignore
=
np
.
array
(
labels_ignore
)
ann
=
dict
(
bboxes
=
bboxes
.
astype
(
np
.
float32
),
labels
=
labels
.
astype
(
np
.
int64
),
bboxes_ignore
=
bboxes_ignore
.
astype
(
np
.
float32
),
labels_ignore
=
labels_ignore
.
astype
(
np
.
int64
))
return
ann
mmdet/models/detectors/base.py
View file @
40325555
...
@@ -99,11 +99,12 @@ class BaseDetector(nn.Module):
...
@@ -99,11 +99,12 @@ class BaseDetector(nn.Module):
if
isinstance
(
dataset
,
str
):
if
isinstance
(
dataset
,
str
):
class_names
=
get_classes
(
dataset
)
class_names
=
get_classes
(
dataset
)
elif
isinstance
(
dataset
,
list
)
:
elif
isinstance
(
dataset
,
(
list
,
tuple
))
or
dataset
is
None
:
class_names
=
dataset
class_names
=
dataset
else
:
else
:
raise
TypeError
(
'dataset must be a valid dataset name or a list'
raise
TypeError
(
' of class names, not {}'
.
format
(
type
(
dataset
)))
'dataset must be a valid dataset name or a sequence'
' of class names, not {}'
.
format
(
type
(
dataset
)))
for
img
,
img_meta
in
zip
(
imgs
,
img_metas
):
for
img
,
img_meta
in
zip
(
imgs
,
img_metas
):
h
,
w
,
_
=
img_meta
[
'img_shape'
]
h
,
w
,
_
=
img_meta
[
'img_shape'
]
...
...
tools/test.py
View file @
40325555
...
@@ -14,15 +14,16 @@ from mmdet.models import build_detector, detectors
...
@@ -14,15 +14,16 @@ from mmdet.models import build_detector, detectors
def
single_test
(
model
,
data_loader
,
show
=
False
):
def
single_test
(
model
,
data_loader
,
show
=
False
):
model
.
eval
()
model
.
eval
()
results
=
[]
results
=
[]
prog_bar
=
mmcv
.
ProgressBar
(
len
(
data_loader
.
dataset
))
dataset
=
data_loader
.
dataset
prog_bar
=
mmcv
.
ProgressBar
(
len
(
dataset
))
for
i
,
data
in
enumerate
(
data_loader
):
for
i
,
data
in
enumerate
(
data_loader
):
with
torch
.
no_grad
():
with
torch
.
no_grad
():
result
=
model
(
return_loss
=
False
,
rescale
=
not
show
,
**
data
)
result
=
model
(
return_loss
=
False
,
rescale
=
not
show
,
**
data
)
results
.
append
(
result
)
results
.
append
(
result
)
if
show
:
if
show
:
model
.
module
.
show_result
(
data
,
result
,
model
.
module
.
show_result
(
data
,
result
,
dataset
.
img_norm_cfg
,
data
_loader
.
dataset
.
img_norm_cfg
)
data
set
.
CLASSES
)
batch_size
=
data
[
'img'
][
0
].
size
(
0
)
batch_size
=
data
[
'img'
][
0
].
size
(
0
)
for
_
in
range
(
batch_size
):
for
_
in
range
(
batch_size
):
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment