OpenDAS / mmdetection3d · Commits

Commit ef19d3cb
authored May 09, 2020 by liyinhao

finish scannet dataset before evaluate

parent f584b970
Showing 1 changed file with 176 additions and 0 deletions

mmdet3d/datasets/scannet_dataset.py  0 → 100644  (+176, -0)
import os
import os.path as osp

import mmcv
import numpy as np
import torch.utils.data as torch_data

from mmdet.datasets import DATASETS

from .pipelines import Compose
@DATASETS.register_module()
class ScannetDataset(torch_data.Dataset):
    type2class = {
        'cabinet': 0, 'bed': 1, 'chair': 2, 'sofa': 3, 'table': 4, 'door': 5,
        'window': 6, 'bookshelf': 7, 'picture': 8, 'counter': 9, 'desk': 10,
        'curtain': 11, 'refrigerator': 12, 'showercurtrain': 13, 'toilet': 14,
        'sink': 15, 'bathtub': 16, 'garbagebin': 17
    }
    class2type = {
        0: 'cabinet', 1: 'bed', 2: 'chair', 3: 'sofa', 4: 'table', 5: 'door',
        6: 'window', 7: 'bookshelf', 8: 'picture', 9: 'counter', 10: 'desk',
        11: 'curtain', 12: 'refrigerator', 13: 'showercurtrain', 14: 'toilet',
        15: 'sink', 16: 'bathtub', 17: 'garbagebin'
    }
    CLASSES = ('cabinet', 'bed', 'chair', 'sofa', 'table', 'door', 'window',
               'bookshelf', 'picture', 'counter', 'desk', 'curtain',
               'refrigerator', 'showercurtrain', 'toilet', 'sink', 'bathtub',
               'garbagebin')
    def __init__(self,
                 root_path,
                 ann_file,
                 split,
                 pipeline=None,
                 training=False,
                 class_names=None,
                 test_mode=False):
        super().__init__()
        self.root_path = root_path
        self.class_names = class_names if class_names else self.CLASSES
        self.split = split
        self.data_path = os.path.join(root_path,
                                      'scannet_train_instance_data')
        self.test_mode = test_mode
        self.training = training
        self.mode = 'TRAIN' if self.training else 'TEST'
        self.ann_file = ann_file
        self.scannet_infos = mmcv.load(ann_file)
        # set group flag for the sampler; this needs the infos loaded first,
        # since the flag has one entry per sample
        if not self.test_mode:
            self._set_group_flag()
        # assumption: attach annotations whenever not in test mode
        # (get_sensor_data reads self.with_label)
        self.with_label = not self.test_mode
        # dataset config
        self.num_class = len(self.class_names)
        self.pcd_limit_range = [0, -40, -3.0, 70.4, 40, 3.0]
        # NYU40 label ids corresponding to training labels 0..17
        self.nyu40ids = np.array(
            [3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 14, 16, 24, 28, 33, 34, 36, 39])
        self.nyu40id2class = {
            nyu40id: i
            for i, nyu40id in enumerate(list(self.nyu40ids))
        }
        if pipeline is not None:
            self.pipeline = Compose(pipeline)
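    # Illustration only (derived from the mappings above): nyu40id2class
    # converts raw NYU40 label ids to the contiguous training labels used by
    # type2class/class2type, e.g.
    #   self.nyu40id2class[3]  == 0   # class2type[0]  == 'cabinet'
    #   self.nyu40id2class[8]  == 5   # class2type[5]  == 'door'
    #   self.nyu40id2class[39] == 17  # class2type[17] == 'garbagebin'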
    def __getitem__(self, idx):
        if self.test_mode:
            return self.prepare_test_data(idx)
        while True:
            data = self.prepare_train_data(idx)
            if data is None:
                idx = self._rand_another(idx)
                continue
            return data
    def prepare_test_data(self, index):
        input_dict = self.get_sensor_data(index)
        example = self.pipeline(input_dict)
        return example
    def prepare_train_data(self, index):
        input_dict = self.get_sensor_data(index)
        input_dict = self.train_pre_pipeline(input_dict)
        if input_dict is None:
            return None
        example = self.pipeline(input_dict)
        # after the formatting pipeline, 'gt_bboxes_3d' is expected to carry
        # its array in ._data; drop samples whose boxes were all filtered out
        if len(example['gt_bboxes_3d']._data) == 0:
            return None
        return example
    def train_pre_pipeline(self, input_dict):
        # skip scenes that come with no ground-truth boxes
        if len(input_dict['gt_bboxes_3d']) == 0:
            return None
        return input_dict
    def get_sensor_data(self, index):
        info = self.scannet_infos[index]
        sample_idx = info['point_cloud']['lidar_idx']
        points = self.get_lidar(sample_idx)
        input_dict = dict(
            sample_idx=sample_idx,
            points=points,
        )
        if self.with_label:
            annos = self.get_ann_info(index, sample_idx)
            input_dict.update(annos)
        return input_dict
    def get_lidar(self, sample_idx):
        lidar_file = os.path.join(self.data_path, sample_idx + '_vert.npy')
        assert os.path.exists(lidar_file)
        return np.load(lidar_file)
    def get_ann_info(self, index, sample_idx):
        # Use index to get the annos, so the evalhook can also use this api.
        info = self.scannet_infos[index]
        if info['annos']['gt_num'] != 0:
            gt_bboxes_3d = info['annos']['gt_boxes_upright_depth']  # k, 6
            gt_labels = info['annos']['class'].reshape(-1, 1)
            gt_bboxes_3d_mask = np.ones_like(gt_labels)
        else:
            # no annotated boxes: return a single dummy box with a zero mask
            gt_bboxes_3d = np.zeros((1, 6), dtype=np.float32)
            gt_labels = np.zeros((1, 1))
            gt_bboxes_3d_mask = np.zeros((1, 1))
        pts_instance_mask_path = osp.join(self.data_path,
                                          sample_idx + '_ins_label.npy')
        pts_semantic_mask_path = osp.join(self.data_path,
                                          sample_idx + '_sem_label.npy')
        anns_results = dict(
            gt_bboxes_3d=gt_bboxes_3d,
            gt_labels=gt_labels,
            gt_bboxes_3d_mask=gt_bboxes_3d_mask,
            pts_instance_mask_path=pts_instance_mask_path,
            pts_semantic_mask_path=pts_semantic_mask_path)
        return anns_results
    def _set_group_flag(self):
        # minimal group flag (assumed implementation, required by
        # _rand_another): point clouds have no aspect-ratio groups, so every
        # sample goes into group 0
        self.flag = np.zeros(len(self), dtype=np.uint8)

    def _rand_another(self, idx):
        # pick a random sample from the same group as idx
        pool = np.where(self.flag == self.flag[idx])[0]
        return np.random.choice(pool)
    def __len__(self):
        return len(self.scannet_infos)
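For reference, a minimal smoke-test sketch of how the class above might be driven. The data root, the info-file name, and the identity transform passed as the pipeline are assumptions for illustration, not part of this commit; a real setup would pass a list of transform configs, and the sketch assumes Compose accepts plain callables, as mmdet's Compose does.

# hypothetical smoke test; paths and the identity transform are assumptions
import os.path as osp

from mmdet3d.datasets.scannet_dataset import ScannetDataset

data_root = 'data/scannet'  # assumed data root
dataset = ScannetDataset(
    root_path=data_root,
    ann_file=osp.join(data_root, 'scannet_infos_val.pkl'),  # assumed name
    split='val',
    pipeline=[lambda x: x],  # identity stand-in for real transforms
    test_mode=True)
print('samples:', len(dataset))
sample = dataset[0]  # dict with 'sample_idx' and the raw 'points' array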