OpenDAS / mmdetection3d

Commit 92018ce1
authored May 14, 2020 by liyinhao

use self.get_classes()

parent cbb98068
Showing 4 changed files with 42 additions and 22 deletions (+42 -22):

    mmdet3d/core/evaluation/indoor_eval.py     +11 -14
    mmdet3d/datasets/indoor_base_dataset.py    +24 -1
    tests/test_scannet_dataset.py              +4  -4
    tests/test_sunrgbd_dataset.py              +3  -3
mmdet3d/core/evaluation/indoor_eval.py

@@ -91,14 +91,14 @@ def eval_det_cls(pred, gt, ovthresh=None):
         for a single class.

     Args:
-        pred (dict): map of {img_id: [(bbox, score)]} where bbox is numpy array
-        gt (dict): map of {img_id: [bbox]}
-        ovthresh (List[float]): a list, iou threshold
+        pred (dict): {img_id: [(bbox, score)]} where bbox is numpy array.
+        gt (dict): {img_id: [bbox]}.
+        ovthresh (List[float]): a list, iou threshold.

     Return:
-        ndarray: numpy array of length nd
-        ndarray: numpy array of length nd
-        float: scalar, average precision
+        ndarray: numpy array of length nd.
+        ndarray: numpy array of length nd.
+        float: scalar, average precision.
     """
     # construct gt objects
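For reference, a minimal sketch of inputs matching the layout described in this docstring; the 7-number box encoding and the scores are illustrative assumptions, not part of this commit.

```python
import numpy as np

# Hypothetical inputs for eval_det_cls, following the documented layout:
# pred maps img_id -> [(bbox, score)], gt maps img_id -> [bbox].
pred = {
    0: [(np.array([0.0, 0.0, 0.0, 1.0, 1.0, 1.0, 0.0]), 0.9)],
    1: [],
}
gt = {
    0: [np.array([0.0, 0.0, 0.0, 1.0, 1.0, 1.0, 0.0])],
    1: [np.array([2.0, 2.0, 2.0, 1.0, 1.0, 1.0, 0.0])],
}
ovthresh = [0.25, 0.5]  # IoU thresholds to evaluate at
```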
@@ -295,13 +295,10 @@ def indoor_eval(gt_annos, dt_annos, metric, label2cat):
     ret_dict = {}
     for i, iou_thresh in enumerate(metric):
         for label in ap[i].keys():
-            ret_dict[f'{label2cat[label]}_AP_{int(iou_thresh * 100)}'] = ap[i][label]
-        ret_dict[f'mAP_{int(iou_thresh * 100)}'] = sum(ap[i].values()) / len(ap[i])
+            ret_dict[f'{label2cat[label]}_AP_{iou_thresh:.2f}'] = ap[i][label]
+        ret_dict[f'mAP_{iou_thresh:.2f}'] = sum(ap[i].values()) / len(ap[i])
         for label in rec[i].keys():
-            ret_dict[f'{label2cat[label]}_rec_{int(iou_thresh * 100)}'] = rec[i][label]
-        ret_dict[f'mAR_{int(iou_thresh * 100)}'] = sum(rec[i].values()) / len(rec[i])
+            ret_dict[f'{label2cat[label]}_rec_{iou_thresh:.2f}'] = rec[i][label]
+        ret_dict[f'mAR_{iou_thresh:.2f}'] = sum(rec[i].values()) / len(rec[i])
     return ret_dict
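The only behavioural change in this hunk is the key naming: IoU thresholds are now formatted with two decimals instead of being scaled to integers. A quick sketch of the difference:

```python
iou_thresh = 0.25
old_key = f'mAP_{int(iou_thresh * 100)}'  # -> 'mAP_25'   (key format before this commit)
new_key = f'mAP_{iou_thresh:.2f}'         # -> 'mAP_0.25' (key format after this commit)

# A threshold of 0.5 therefore yields keys ending in '0.50',
# e.g. 'table_AP_0.50' instead of the previous 'table_AP_50'.
```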
mmdet3d/datasets/indoor_base_dataset.py

@@ -20,7 +20,7 @@ class IndoorBaseDataset(torch_data.Dataset):
                  with_label=True):
         super().__init__()
         self.root_path = root_path
-        self.CLASSES = classes if classes else self.CLASSES
+        self.CLASSES = self.get_classes(classes)
         self.test_mode = test_mode
         self.label2cat = {i: cat_id for i, cat_id in enumerate(self.CLASSES)}
         mmcv.check_file_exist(ann_file)
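Since class resolution now goes through get_classes(), the mapping built in __init__ can be sketched as follows; the class names here are purely illustrative, not taken from this commit.

```python
CLASSES = ('table', 'window')  # whatever self.get_classes(classes) resolves to
label2cat = {i: cat_id for i, cat_id in enumerate(CLASSES)}
# label2cat == {0: 'table', 1: 'window'} -- integer labels mapped back to category names
```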
@@ -77,6 +77,29 @@ class IndoorBaseDataset(torch_data.Dataset):
         example = self.pipeline(input_dict)
         return example

+    @classmethod
+    def get_classes(cls, classes=None):
+        """Get class names of current dataset.
+
+        Args:
+            classes (Sequence[str] | str | None): If classes is None, use
+                default CLASSES defined by builtin dataset. If classes is a
+                string, take it as a file name. The file contains the name of
+                classes where each line contains one class name. If classes is
+                a tuple or list, override the CLASSES defined by the dataset.
+        """
+        if classes is None:
+            return cls.CLASSES
+
+        if isinstance(classes, str):
+            # take it as a file path
+            class_names = mmcv.list_from_file(classes)
+        elif isinstance(classes, (tuple, list)):
+            class_names = classes
+        else:
+            raise ValueError(f'Unsupported type {type(classes)} of classes.')
+
+        return class_names
+
     def _generate_annotations(self, output):
         """Generate Annotations.
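A hedged usage sketch of the new classmethod. ScanNetDataset is used here only as an example subclass of IndoorBaseDataset (assuming it is exported from mmdet3d.datasets), and my_classes.txt is a hypothetical file with one class name per line.

```python
from mmdet3d.datasets import ScanNetDataset

ScanNetDataset.get_classes()                     # None -> the dataset's default CLASSES
ScanNetDataset.get_classes(('table', 'window'))  # tuple/list -> used as-is
ScanNetDataset.get_classes('my_classes.txt')     # str -> names read via mmcv.list_from_file
# Any other type, e.g. ScanNetDataset.get_classes(0.25), raises ValueError.
```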
tests/test_scannet_dataset.py

@@ -109,10 +109,10 @@ def test_evaluate():
     results.append([pred_boxes])
     metric = [0.25, 0.5]
     ret_dict = scannet_dataset.evaluate(results, metric)
-    table_average_precision_25 = ret_dict['table_AP_25']
-    window_average_precision_25 = ret_dict['window_AP_25']
-    counter_average_precision_25 = ret_dict['counter_AP_25']
-    curtain_average_precision_25 = ret_dict['curtain_AP_25']
+    table_average_precision_25 = ret_dict['table_AP_0.25']
+    window_average_precision_25 = ret_dict['window_AP_0.25']
+    counter_average_precision_25 = ret_dict['counter_AP_0.25']
+    curtain_average_precision_25 = ret_dict['curtain_AP_0.25']
     assert abs(table_average_precision_25 - 0.3333) < 0.01
     assert abs(window_average_precision_25 - 1) < 0.01
     assert abs(counter_average_precision_25 - 1) < 0.01
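The test only reads the 0.25-threshold entries, but since metric = [0.25, 0.5] the returned dict should also carry 0.5-threshold results; under the new formatting those keys would presumably be read like this (a sketch, not part of the test):

```python
table_average_precision_50 = ret_dict['table_AP_0.50']
mean_average_precision_50 = ret_dict['mAP_0.50']
```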
tests/test_sunrgbd_dataset.py

@@ -85,9 +85,9 @@ def test_evaluate():
     results.append([pred_boxes])
     metric = [0.25, 0.5]
     ap_dict = sunrgbd_dataset.evaluate(results, metric)
-    bed_precision_25 = ap_dict['bed_AP_25']
-    dresser_precision_25 = ap_dict['dresser_AP_25']
-    night_stand_precision_25 = ap_dict['night_stand_AP_25']
+    bed_precision_25 = ap_dict['bed_AP_0.25']
+    dresser_precision_25 = ap_dict['dresser_AP_0.25']
+    night_stand_precision_25 = ap_dict['night_stand_AP_0.25']
     assert abs(bed_precision_25 - 1) < 0.01
     assert abs(dresser_precision_25 - 1) < 0.01
     assert abs(night_stand_precision_25 - 1) < 0.01