Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
OpenDAS
mmdetection3d
Commits
d71edf6c
"vscode:/vscode.git/clone" did not exist on "781c552ff9c5e3e4c60fe58126cf0f01f3989887"
Commit
d71edf6c
authored
May 12, 2020
by
yinchimaoliang
Browse files
finish test getitem
parent
49121b64
Changes
4
Show whitespace changes
Inline
Side-by-side
Showing
4 changed files
with
110 additions
and
44 deletions
+110
-44
mmdet3d/datasets/pipelines/formating.py
mmdet3d/datasets/pipelines/formating.py
+9
-6
mmdet3d/datasets/scannet_dataset.py
mmdet3d/datasets/scannet_dataset.py
+30
-37
tests/test_scannet_dataset.py
tests/test_scannet_dataset.py
+70
-0
tools/test.py
tools/test.py
+1
-1
No files found.
mmdet3d/datasets/pipelines/formating.py
View file @
d71edf6c
...
...
@@ -127,6 +127,7 @@ class DefaultFormatBundle3D(DefaultFormatBundle):
gt_bboxes_3d_mask
=
results
[
'gt_bboxes_3d_mask'
]
results
[
'gt_bboxes_3d'
]
=
results
[
'gt_bboxes_3d'
][
gt_bboxes_3d_mask
]
if
'gt_names_3d'
in
results
:
results
[
'gt_names_3d'
]
=
results
[
'gt_names_3d'
][
gt_bboxes_3d_mask
]
if
'gt_bboxes_mask'
in
results
:
...
...
@@ -151,8 +152,10 @@ class DefaultFormatBundle3D(DefaultFormatBundle):
dtype
=
np
.
int64
)
# we still assume one pipeline for one frame LiDAR
# thus, the 3D name is list[string]
if
'gt_names_3d'
in
results
:
results
[
'gt_labels_3d'
]
=
np
.
array
([
self
.
class_names
.
index
(
n
)
for
n
in
results
[
'gt_names_3d'
]
self
.
class_names
.
index
(
n
)
for
n
in
results
[
'gt_names_3d'
]
],
dtype
=
np
.
int64
)
results
=
super
(
DefaultFormatBundle3D
,
self
).
__call__
(
results
)
...
...
mmdet3d/datasets/scannet_dataset.py
View file @
d71edf6c
...
...
@@ -11,7 +11,7 @@ from .pipelines import Compose
@
DATASETS
.
register_module
()
class
ScannetDataset
(
torch_data
.
d
ataset
):
class
ScannetDataset
(
torch_data
.
D
ataset
):
type2class
=
{
'cabinet'
:
0
,
'bed'
:
1
,
...
...
@@ -60,15 +60,14 @@ class ScannetDataset(torch_data.dataset):
def
__init__
(
self
,
root_path
,
ann_file
,
split
,
pipeline
=
None
,
training
=
False
,
class_names
=
None
,
test_mode
=
False
):
test_mode
=
False
,
with_label
=
True
):
super
().
__init__
()
self
.
root_path
=
root_path
self
.
class_names
=
class_names
if
class_names
else
self
.
CLASSES
self
.
split
=
split
self
.
data_path
=
os
.
path
.
join
(
root_path
,
'scannet_train_instance_data'
)
self
.
test_mode
=
test_mode
...
...
@@ -76,10 +75,6 @@ class ScannetDataset(torch_data.dataset):
self
.
mode
=
'TRAIN'
if
self
.
training
else
'TEST'
self
.
ann_file
=
ann_file
# set group flag for the sampler
if
not
self
.
test_mode
:
self
.
_set_group_flag
()
self
.
scannet_infos
=
mmcv
.
load
(
ann_file
)
# dataset config
...
...
@@ -93,25 +88,26 @@ class ScannetDataset(torch_data.dataset):
}
if
pipeline
is
not
None
:
self
.
pipeline
=
Compose
(
pipeline
)
self
.
with_label
=
with_label
def
__getitem__
(
self
,
idx
):
if
self
.
test_mode
:
return
self
.
prepare_test_data
(
idx
)
return
self
.
_
prepare_test_data
(
idx
)
while
True
:
data
=
self
.
prepare_train_data
(
idx
)
data
=
self
.
_
prepare_train_data
(
idx
)
if
data
is
None
:
idx
=
self
.
_rand_another
(
idx
)
continue
return
data
def
prepare_test_data
(
self
,
index
):
input_dict
=
self
.
get_sensor_data
(
index
)
def
_
prepare_test_data
(
self
,
index
):
input_dict
=
self
.
_
get_sensor_data
(
index
)
example
=
self
.
pipeline
(
input_dict
)
return
example
def
prepare_train_data
(
self
,
index
):
input_dict
=
self
.
get_sensor_data
(
index
)
input_dict
=
self
.
train_pre_pipeline
(
input_dict
)
def
_
prepare_train_data
(
self
,
index
):
input_dict
=
self
.
_
get_sensor_data
(
index
)
input_dict
=
self
.
_
train_pre_pipeline
(
input_dict
)
if
input_dict
is
None
:
return
None
example
=
self
.
pipeline
(
input_dict
)
...
...
@@ -119,43 +115,40 @@ class ScannetDataset(torch_data.dataset):
return
None
return
example
def
train_pre_pipeline
(
self
,
input_dict
):
def
_
train_pre_pipeline
(
self
,
input_dict
):
if
len
(
input_dict
[
'gt_bboxes_3d'
])
==
0
:
return
None
return
input_dict
def
get_sensor_data
(
self
,
index
):
def
_
get_sensor_data
(
self
,
index
):
info
=
self
.
scannet_infos
[
index
]
sample_idx
=
info
[
'point_cloud'
][
'lidar_idx'
]
p
oints
=
self
.
get_
lidar
(
sample_idx
)
p
ts_filename
=
self
.
_
get_
pts_filename
(
sample_idx
)
input_dict
=
dict
(
sample_idx
=
sample_idx
,
points
=
points
,
)
input_dict
=
dict
(
pts_filename
=
pts_filename
)
if
self
.
with_label
:
annos
=
self
.
get_ann_info
(
index
,
sample_idx
)
annos
=
self
.
_
get_ann_info
(
index
,
sample_idx
)
input_dict
.
update
(
annos
)
return
input_dict
def
get_
lidar
(
self
,
sample_idx
):
lidar_fil
e
=
os
.
path
.
join
(
self
.
data_path
,
sample_idx
+
'_vert.npy'
)
assert
os
.
path
.
exists
(
lidar_fil
e
)
return
np
.
load
(
lidar
_file
)
def
_
get_
pts_filename
(
self
,
sample_idx
):
pts_filenam
e
=
os
.
path
.
join
(
self
.
data_path
,
sample_idx
+
'_vert.npy'
)
mmcv
.
check_file_exist
(
pts_filenam
e
)
return
pts
_file
name
def
get_ann_info
(
self
,
index
,
sample_idx
):
def
_
get_ann_info
(
self
,
index
,
sample_idx
):
# Use index to get the annos, thus the evalhook could also use this api
info
=
self
.
kitti
_infos
[
index
]
info
=
self
.
scannet
_infos
[
index
]
if
info
[
'annos'
][
'gt_num'
]
!=
0
:
gt_bboxes_3d
=
info
[
'annos'
][
'gt_boxes_upright_depth'
]
# k, 6
gt_labels
=
info
[
'annos'
][
'class'
]
.
reshape
(
-
1
,
1
)
gt_bboxes_3d_mask
=
np
.
ones_like
(
gt_labels
)
gt_labels
=
info
[
'annos'
][
'class'
]
gt_bboxes_3d_mask
=
np
.
ones_like
(
gt_labels
)
.
astype
(
np
.
bool
)
else
:
gt_bboxes_3d
=
np
.
zeros
((
1
,
6
),
dtype
=
np
.
float32
)
gt_labels
=
np
.
zeros
(
(
1
,
1
))
gt_bboxes_3d_mask
=
np
.
zeros
(
(
1
,
1
))
gt_labels
=
np
.
zeros
(
1
,
)
.
astype
(
np
.
bool
)
gt_bboxes_3d_mask
=
np
.
zeros
(
1
,
)
.
astype
(
np
.
bool
)
pts_instance_mask_path
=
osp
.
join
(
self
.
data_path
,
sample_idx
+
'_ins_label.npy'
)
pts_semantic_mask_path
=
osp
.
join
(
self
.
data_path
,
...
...
@@ -173,7 +166,7 @@ class ScannetDataset(torch_data.dataset):
pool
=
np
.
where
(
self
.
flag
==
self
.
flag
[
idx
])[
0
]
return
np
.
random
.
choice
(
pool
)
def
generate_annotations
(
self
,
output
):
def
_
generate_annotations
(
self
,
output
):
'''
transfer input_dict & pred_dicts to anno format
which is needed by AP calculator
...
...
@@ -209,15 +202,15 @@ class ScannetDataset(torch_data.dataset):
return
result
def
format_results
(
self
,
outputs
):
def
_
format_results
(
self
,
outputs
):
results
=
[]
for
output
in
outputs
:
result
=
self
.
generate_annotations
(
output
)
result
=
self
.
_
generate_annotations
(
output
)
results
.
append
(
result
)
return
results
def
evaluate
(
self
,
results
,
metric
=
None
,
logger
=
None
,
pklfile_prefix
=
None
):
results
=
self
.
format_results
(
results
)
results
=
self
.
_
format_results
(
results
)
from
mmdet3d.core.evaluation.scannet_utils.eval
import
scannet_eval
assert
(
'AP_IOU_THRESHHOLDS'
in
metric
)
gt_annos
=
[
...
...
tests/test_scannet_dataset.py
0 → 100644
View file @
d71edf6c
import
numpy
as
np
from
mmdet3d.datasets.scannet_dataset
import
ScannetDataset
def
test_getitem
():
np
.
random
.
seed
(
0
)
root_path
=
'./tests/data/scannet'
ann_file
=
'./tests/data/scannet/scannet_infos.pkl'
class_names
=
(
'cabinet'
,
'bed'
,
'chair'
,
'sofa'
,
'table'
,
'door'
,
'window'
,
'bookshelf'
,
'picture'
,
'counter'
,
'desk'
,
'curtain'
,
'refrigerator'
,
'showercurtrain'
,
'toilet'
,
'sink'
,
'bathtub'
,
'garbagebin'
)
pipelines
=
[
dict
(
type
=
'IndoorLoadPointsFromFile'
,
use_height
=
True
,
load_dim
=
6
,
use_dim
=
[
0
,
1
,
2
]),
dict
(
type
=
'IndoorLoadAnnotations3D'
),
dict
(
type
=
'IndoorPointSample'
,
num_points
=
5
),
dict
(
type
=
'IndoorFlipData'
,
flip_ratio_yz
=
1.0
,
flip_ratio_xz
=
1.0
),
dict
(
type
=
'IndoorGlobalRotScale'
,
use_height
=
True
,
rot_range
=
[
-
np
.
pi
*
1
/
36
,
np
.
pi
*
1
/
36
],
scale_range
=
None
),
dict
(
type
=
'DefaultFormatBundle3D'
,
class_names
=
class_names
),
dict
(
type
=
'Collect3D'
,
keys
=
[
'points'
,
'gt_bboxes_3d'
,
'gt_labels'
,
'pts_semantic_mask'
,
'pts_instance_mask'
]),
]
scannet_dataset
=
ScannetDataset
(
root_path
,
ann_file
,
pipelines
,
True
)
data
=
scannet_dataset
[
0
]
points
=
data
[
'points'
].
_data
gt_bboxes_3d
=
data
[
'gt_bboxes_3d'
].
_data
gt_labels
=
data
[
'gt_labels'
].
_data
pts_semantic_mask
=
data
[
'pts_semantic_mask'
]
pts_instance_mask
=
data
[
'pts_instance_mask'
]
expected_points
=
np
.
array
(
[[
-
2.9078157
,
-
1.9569951
,
2.3543026
,
2.389488
],
[
-
0.71360034
,
-
3.4359822
,
2.1330001
,
2.1681855
],
[
-
1.332374
,
1.474838
,
-
0.04405887
,
-
0.00887359
],
[
2.1336637
,
-
1.3265059
,
-
0.02880373
,
0.00638155
],
[
0.43895668
,
-
3.0259454
,
1.5560012
,
1.5911865
]])
expected_gt_bboxes_3d
=
np
.
array
([
[
-
1.5005362
,
-
3.512584
,
1.8565295
,
1.7457027
,
0.24149807
,
0.57235193
],
[
-
2.8848705
,
3.4961755
,
1.5268247
,
0.66170084
,
0.17433672
,
0.67153597
],
[
-
1.1585636
,
-
2.192365
,
0.61649567
,
0.5557011
,
2.5375574
,
1.2144762
],
[
-
2.930457
,
-
2.4856408
,
0.9722377
,
0.6270478
,
1.8461524
,
0.28697443
],
[
3.3114715
,
-
0.00476722
,
1.0712197
,
0.46191898
,
3.8605113
,
2.1603441
]
])
expected_gt_labels
=
np
.
array
([
6
,
6
,
4
,
9
,
11
,
11
,
10
,
0
,
15
,
17
,
17
,
17
,
3
,
12
,
4
,
4
,
14
,
1
,
0
,
0
,
0
,
0
,
0
,
0
,
5
,
5
,
5
])
expected_pts_semantic_mask
=
np
.
array
([
3
,
1
,
2
,
2
,
15
])
expected_pts_instance_mask
=
np
.
array
([
44
,
22
,
10
,
10
,
57
])
assert
np
.
allclose
(
points
,
expected_points
)
assert
gt_bboxes_3d
[:
5
].
shape
==
(
5
,
6
)
assert
np
.
allclose
(
gt_bboxes_3d
[:
5
],
expected_gt_bboxes_3d
)
assert
np
.
all
(
gt_labels
.
numpy
()
==
expected_gt_labels
)
assert
np
.
all
(
pts_semantic_mask
==
expected_pts_semantic_mask
)
assert
np
.
all
(
pts_instance_mask
==
expected_pts_instance_mask
)
tools/test.py
View file @
d71edf6c
...
...
@@ -161,7 +161,7 @@ def main():
mmcv
.
dump
(
outputs
,
args
.
out
)
kwargs
=
{}
if
args
.
options
is
None
else
args
.
options
if
args
.
format_only
:
dataset
.
format_results
(
outputs
,
**
kwargs
)
dataset
.
_
format_results
(
outputs
,
**
kwargs
)
if
args
.
eval
:
dataset
.
evaluate
(
outputs
,
args
.
eval
,
**
kwargs
)
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment