Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
ModelZoo
SOLOv2-pytorch
Commits
825cfa0c
Commit
825cfa0c
authored
Dec 11, 2018
by
Kai Chen
Browse files
add class attribute CLASSES to Dataset
parent
826a5613
Changes
7
Hide whitespace changes
Inline
Side-by-side
Showing
7 changed files
with
48 additions
and
25 deletions
+48
-25
mmdet/core/evaluation/class_names.py
mmdet/core/evaluation/class_names.py
+8
-8
mmdet/datasets/coco.py
mmdet/datasets/coco.py
+15
-0
mmdet/datasets/concat_dataset.py
mmdet/datasets/concat_dataset.py
+8
-6
mmdet/datasets/custom.py
mmdet/datasets/custom.py
+4
-2
mmdet/datasets/repeat_dataset.py
mmdet/datasets/repeat_dataset.py
+5
-3
mmdet/models/detectors/base.py
mmdet/models/detectors/base.py
+4
-3
tools/test.py
tools/test.py
+4
-3
No files found.
mmdet/core/evaluation/class_names.py
View file @
825cfa0c
...
@@ -63,18 +63,18 @@ def imagenet_vid_classes():
...
@@ -63,18 +63,18 @@ def imagenet_vid_classes():
def
coco_classes
():
def
coco_classes
():
return
[
return
[
'person'
,
'bicycle'
,
'car'
,
'motorcycle'
,
'airplane'
,
'bus'
,
'train'
,
'person'
,
'bicycle'
,
'car'
,
'motorcycle'
,
'airplane'
,
'bus'
,
'train'
,
'truck'
,
'boat'
,
'traffic
light'
,
'fire
hydrant'
,
'stop
sign'
,
'truck'
,
'boat'
,
'traffic
_
light'
,
'fire
_
hydrant'
,
'stop
_
sign'
,
'parking
meter'
,
'bench'
,
'bird'
,
'cat'
,
'dog'
,
'horse'
,
'sheep'
,
'parking
_
meter'
,
'bench'
,
'bird'
,
'cat'
,
'dog'
,
'horse'
,
'sheep'
,
'cow'
,
'elephant'
,
'bear'
,
'zebra'
,
'giraffe'
,
'backpack'
,
'umbrella'
,
'cow'
,
'elephant'
,
'bear'
,
'zebra'
,
'giraffe'
,
'backpack'
,
'umbrella'
,
'handbag'
,
'tie'
,
'suitcase'
,
'frisbee'
,
'skis'
,
'snowboard'
,
'handbag'
,
'tie'
,
'suitcase'
,
'frisbee'
,
'skis'
,
'snowboard'
,
'sports
ball'
,
'kite'
,
'baseball
bat'
,
'baseball
glove'
,
'skateboard'
,
'sports
_
ball'
,
'kite'
,
'baseball
_
bat'
,
'baseball
_
glove'
,
'skateboard'
,
'surfboard'
,
'tennis
racket'
,
'bottle'
,
'wine
glass'
,
'cup'
,
'fork'
,
'surfboard'
,
'tennis
_
racket'
,
'bottle'
,
'wine
_
glass'
,
'cup'
,
'fork'
,
'knife'
,
'spoon'
,
'bowl'
,
'banana'
,
'apple'
,
'sandwich'
,
'orange'
,
'knife'
,
'spoon'
,
'bowl'
,
'banana'
,
'apple'
,
'sandwich'
,
'orange'
,
'broccoli'
,
'carrot'
,
'hot
dog'
,
'pizza'
,
'donut'
,
'cake'
,
'chair'
,
'broccoli'
,
'carrot'
,
'hot
_
dog'
,
'pizza'
,
'donut'
,
'cake'
,
'chair'
,
'couch'
,
'potted
plant'
,
'bed'
,
'dining
table'
,
'toilet'
,
'tv'
,
'couch'
,
'potted
_
plant'
,
'bed'
,
'dining
_
table'
,
'toilet'
,
'tv'
,
'laptop'
,
'mouse'
,
'remote'
,
'keyboard'
,
'cell
phone'
,
'microwave'
,
'laptop'
,
'mouse'
,
'remote'
,
'keyboard'
,
'cell
_
phone'
,
'microwave'
,
'oven'
,
'toaster'
,
'sink'
,
'refrigerator'
,
'book'
,
'clock'
,
'vase'
,
'oven'
,
'toaster'
,
'sink'
,
'refrigerator'
,
'book'
,
'clock'
,
'vase'
,
'scissors'
,
'teddy
bear'
,
'hair
drier'
,
'toothbrush'
'scissors'
,
'teddy
_
bear'
,
'hair
_
drier'
,
'toothbrush'
]
]
...
...
mmdet/datasets/coco.py
View file @
825cfa0c
...
@@ -6,6 +6,21 @@ from .custom import CustomDataset
...
@@ -6,6 +6,21 @@ from .custom import CustomDataset
class
CocoDataset
(
CustomDataset
):
class
CocoDataset
(
CustomDataset
):
CLASSES
=
(
'person'
,
'bicycle'
,
'car'
,
'motorcycle'
,
'airplane'
,
'bus'
,
'train'
,
'truck'
,
'boat'
,
'traffic_light'
,
'fire_hydrant'
,
'stop_sign'
,
'parking_meter'
,
'bench'
,
'bird'
,
'cat'
,
'dog'
,
'horse'
,
'sheep'
,
'cow'
,
'elephant'
,
'bear'
,
'zebra'
,
'giraffe'
,
'backpack'
,
'umbrella'
,
'handbag'
,
'tie'
,
'suitcase'
,
'frisbee'
,
'skis'
,
'snowboard'
,
'sports_ball'
,
'kite'
,
'baseball_bat'
,
'baseball_glove'
,
'skateboard'
,
'surfboard'
,
'tennis_racket'
,
'bottle'
,
'wine_glass'
,
'cup'
,
'fork'
,
'knife'
,
'spoon'
,
'bowl'
,
'banana'
,
'apple'
,
'sandwich'
,
'orange'
,
'broccoli'
,
'carrot'
,
'hot_dog'
,
'pizza'
,
'donut'
,
'cake'
,
'chair'
,
'couch'
,
'potted_plant'
,
'bed'
,
'dining_table'
,
'toilet'
,
'tv'
,
'laptop'
,
'mouse'
,
'remote'
,
'keyboard'
,
'cell_phone'
,
'microwave'
,
'oven'
,
'toaster'
,
'sink'
,
'refrigerator'
,
'book'
,
'clock'
,
'vase'
,
'scissors'
,
'teddy_bear'
,
'hair_drier'
,
'toothbrush'
)
def
load_annotations
(
self
,
ann_file
):
def
load_annotations
(
self
,
ann_file
):
self
.
coco
=
COCO
(
ann_file
)
self
.
coco
=
COCO
(
ann_file
)
self
.
cat_ids
=
self
.
coco
.
getCatIds
()
self
.
cat_ids
=
self
.
coco
.
getCatIds
()
...
...
mmdet/datasets/concat_dataset.py
View file @
825cfa0c
...
@@ -3,16 +3,18 @@ from torch.utils.data.dataset import ConcatDataset as _ConcatDataset
...
@@ -3,16 +3,18 @@ from torch.utils.data.dataset import ConcatDataset as _ConcatDataset
class
ConcatDataset
(
_ConcatDataset
):
class
ConcatDataset
(
_ConcatDataset
):
"""
"""A wrapper of concatenated dataset.
Same as torch.utils.data.dataset.ConcatDataset, but
Same as :obj:`torch.utils.data.dataset.ConcatDataset`, but
concat the group flag for image aspect ratio.
concat the group flag for image aspect ratio.
Args:
datasets (list[:obj:`Dataset`]): A list of datasets.
"""
"""
def
__init__
(
self
,
datasets
):
def
__init__
(
self
,
datasets
):
"""
flag: Images with aspect ratio greater than 1 will be set as group 1,
otherwise group 0.
"""
super
(
ConcatDataset
,
self
).
__init__
(
datasets
)
super
(
ConcatDataset
,
self
).
__init__
(
datasets
)
self
.
CLASSES
=
datasets
[
0
].
CLASSES
if
hasattr
(
datasets
[
0
],
'flag'
):
if
hasattr
(
datasets
[
0
],
'flag'
):
flags
=
[]
flags
=
[]
for
i
in
range
(
0
,
len
(
datasets
)):
for
i
in
range
(
0
,
len
(
datasets
)):
...
...
mmdet/datasets/custom.py
View file @
825cfa0c
...
@@ -32,6 +32,8 @@ class CustomDataset(Dataset):
...
@@ -32,6 +32,8 @@ class CustomDataset(Dataset):
The `ann` field is optional for testing.
The `ann` field is optional for testing.
"""
"""
CLASSES
=
None
def
__init__
(
self
,
def
__init__
(
self
,
ann_file
,
ann_file
,
img_prefix
,
img_prefix
,
...
@@ -45,6 +47,8 @@ class CustomDataset(Dataset):
...
@@ -45,6 +47,8 @@ class CustomDataset(Dataset):
with_crowd
=
True
,
with_crowd
=
True
,
with_label
=
True
,
with_label
=
True
,
test_mode
=
False
):
test_mode
=
False
):
# prefix of images path
self
.
img_prefix
=
img_prefix
# load annotations (and proposals)
# load annotations (and proposals)
self
.
img_infos
=
self
.
load_annotations
(
ann_file
)
self
.
img_infos
=
self
.
load_annotations
(
ann_file
)
if
proposal_file
is
not
None
:
if
proposal_file
is
not
None
:
...
@@ -58,8 +62,6 @@ class CustomDataset(Dataset):
...
@@ -58,8 +62,6 @@ class CustomDataset(Dataset):
if
self
.
proposals
is
not
None
:
if
self
.
proposals
is
not
None
:
self
.
proposals
=
[
self
.
proposals
[
i
]
for
i
in
valid_inds
]
self
.
proposals
=
[
self
.
proposals
[
i
]
for
i
in
valid_inds
]
# prefix of images path
self
.
img_prefix
=
img_prefix
# (long_edge, short_edge) or [(long1, short1), (long2, short2), ...]
# (long_edge, short_edge) or [(long1, short1), (long2, short2), ...]
self
.
img_scales
=
img_scale
if
isinstance
(
img_scale
,
self
.
img_scales
=
img_scale
if
isinstance
(
img_scale
,
list
)
else
[
img_scale
]
list
)
else
[
img_scale
]
...
...
mmdet/datasets/repeat_dataset.py
View file @
825cfa0c
...
@@ -6,12 +6,14 @@ class RepeatDataset(object):
...
@@ -6,12 +6,14 @@ class RepeatDataset(object):
def
__init__
(
self
,
dataset
,
times
):
def
__init__
(
self
,
dataset
,
times
):
self
.
dataset
=
dataset
self
.
dataset
=
dataset
self
.
times
=
times
self
.
times
=
times
self
.
CLASSES
=
dataset
.
CLASSES
if
hasattr
(
self
.
dataset
,
'flag'
):
if
hasattr
(
self
.
dataset
,
'flag'
):
self
.
flag
=
np
.
tile
(
self
.
dataset
.
flag
,
times
)
self
.
flag
=
np
.
tile
(
self
.
dataset
.
flag
,
times
)
self
.
_original_length
=
len
(
self
.
dataset
)
self
.
_ori_len
=
len
(
self
.
dataset
)
def
__getitem__
(
self
,
idx
):
def
__getitem__
(
self
,
idx
):
return
self
.
dataset
[
idx
%
self
.
_ori
ginal
_len
gth
]
return
self
.
dataset
[
idx
%
self
.
_ori_len
]
def
__len__
(
self
):
def
__len__
(
self
):
return
self
.
times
*
self
.
_ori
ginal
_len
gth
return
self
.
times
*
self
.
_ori_len
mmdet/models/detectors/base.py
View file @
825cfa0c
...
@@ -99,11 +99,12 @@ class BaseDetector(nn.Module):
...
@@ -99,11 +99,12 @@ class BaseDetector(nn.Module):
if
isinstance
(
dataset
,
str
):
if
isinstance
(
dataset
,
str
):
class_names
=
get_classes
(
dataset
)
class_names
=
get_classes
(
dataset
)
elif
isinstance
(
dataset
,
list
)
:
elif
isinstance
(
dataset
,
(
list
,
tuple
))
or
dataset
is
None
:
class_names
=
dataset
class_names
=
dataset
else
:
else
:
raise
TypeError
(
'dataset must be a valid dataset name or a list'
raise
TypeError
(
' of class names, not {}'
.
format
(
type
(
dataset
)))
'dataset must be a valid dataset name or a sequence'
' of class names, not {}'
.
format
(
type
(
dataset
)))
for
img
,
img_meta
in
zip
(
imgs
,
img_metas
):
for
img
,
img_meta
in
zip
(
imgs
,
img_metas
):
h
,
w
,
_
=
img_meta
[
'img_shape'
]
h
,
w
,
_
=
img_meta
[
'img_shape'
]
...
...
tools/test.py
View file @
825cfa0c
...
@@ -14,15 +14,16 @@ from mmdet.models import build_detector, detectors
...
@@ -14,15 +14,16 @@ from mmdet.models import build_detector, detectors
def
single_test
(
model
,
data_loader
,
show
=
False
):
def
single_test
(
model
,
data_loader
,
show
=
False
):
model
.
eval
()
model
.
eval
()
results
=
[]
results
=
[]
prog_bar
=
mmcv
.
ProgressBar
(
len
(
data_loader
.
dataset
))
dataset
=
data_loader
.
dataset
prog_bar
=
mmcv
.
ProgressBar
(
len
(
dataset
))
for
i
,
data
in
enumerate
(
data_loader
):
for
i
,
data
in
enumerate
(
data_loader
):
with
torch
.
no_grad
():
with
torch
.
no_grad
():
result
=
model
(
return_loss
=
False
,
rescale
=
not
show
,
**
data
)
result
=
model
(
return_loss
=
False
,
rescale
=
not
show
,
**
data
)
results
.
append
(
result
)
results
.
append
(
result
)
if
show
:
if
show
:
model
.
module
.
show_result
(
data
,
result
,
model
.
module
.
show_result
(
data
,
result
,
dataset
.
img_norm_cfg
,
data
_loader
.
dataset
.
img_norm_cfg
)
data
set
.
CLASSES
)
batch_size
=
data
[
'img'
][
0
].
size
(
0
)
batch_size
=
data
[
'img'
][
0
].
size
(
0
)
for
_
in
range
(
batch_size
):
for
_
in
range
(
batch_size
):
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment