Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
OpenDAS
mmdetection3d
Commits
72350b2d
Commit
72350b2d
authored
May 14, 2020
by
liyinhao
Browse files
merge funcs, change names
parent
cbb549aa
Changes
3
Show whitespace changes
Inline
Side-by-side
Showing
3 changed files
with
23 additions
and
33 deletions
+23
-33
mmdet3d/datasets/indoor_base_dataset.py
mmdet3d/datasets/indoor_base_dataset.py
+17
-23
mmdet3d/datasets/scannet_dataset.py
mmdet3d/datasets/scannet_dataset.py
+3
-5
mmdet3d/datasets/sunrgbd_dataset.py
mmdet3d/datasets/sunrgbd_dataset.py
+3
-5
No files found.
mmdet3d/datasets/indoor_base_dataset.py
View file @
72350b2d
...
...
@@ -16,44 +16,42 @@ class IndoorBaseDataset(torch_data.Dataset):
ann_file
,
pipeline
=
None
,
training
=
False
,
c
at_id
s
=
None
,
c
lasse
s
=
None
,
test_mode
=
False
,
with_label
=
True
):
super
().
__init__
()
self
.
root_path
=
root_path
self
.
cat_ids
=
cat_id
s
if
c
at_id
s
else
self
.
CLASSES
self
.
CLASSES
=
classe
s
if
c
lasse
s
else
self
.
CLASSES
self
.
test_mode
=
test_mode
self
.
training
=
training
self
.
mode
=
'TRAIN'
if
self
.
training
else
'TEST'
self
.
label2cat
=
{
i
:
cat_id
for
i
,
cat_id
in
enumerate
(
self
.
cat_ids
)}
self
.
label2cat
=
{
i
:
cat_id
for
i
,
cat_id
in
enumerate
(
self
.
CLASSES
)}
mmcv
.
check_file_exist
(
ann_file
)
self
.
infos
=
mmcv
.
load
(
ann_file
)
self
.
data_
infos
=
mmcv
.
load
(
ann_file
)
# dataset config
self
.
num_class
=
len
(
self
.
cat_ids
)
self
.
pcd_limit_range
=
[
0
,
-
40
,
-
3.0
,
70.4
,
40
,
3.0
]
self
.
num_class
=
len
(
self
.
CLASSES
)
if
pipeline
is
not
None
:
self
.
pipeline
=
Compose
(
pipeline
)
self
.
with_label
=
with_label
def
__getitem__
(
self
,
idx
):
if
self
.
test_mode
:
return
self
.
_
prepare_test_data
(
idx
)
return
self
.
prepare_test_data
(
idx
)
while
True
:
data
=
self
.
_
prepare_train_data
(
idx
)
data
=
self
.
prepare_train_data
(
idx
)
if
data
is
None
:
idx
=
self
.
_rand_another
(
idx
)
continue
return
data
def
_
prepare_test_data
(
self
,
index
):
input_dict
=
self
.
_
get_
sensor_data
(
index
)
def
prepare_test_data
(
self
,
index
):
input_dict
=
self
.
get_
data_info
(
index
)
example
=
self
.
pipeline
(
input_dict
)
return
example
def
_prepare_train_data
(
self
,
index
):
input_dict
=
self
.
_get_sensor_data
(
index
)
input_dict
=
self
.
_train_pre_pipeline
(
input_dict
)
def
prepare_train_data
(
self
,
index
):
input_dict
=
self
.
get_data_info
(
index
)
if
input_dict
is
None
:
return
None
example
=
self
.
pipeline
(
input_dict
)
...
...
@@ -61,13 +59,8 @@ class IndoorBaseDataset(torch_data.Dataset):
return
None
return
example
def
_train_pre_pipeline
(
self
,
input_dict
):
if
len
(
input_dict
[
'gt_bboxes_3d'
])
==
0
:
return
None
return
input_dict
def
_get_sensor_data
(
self
,
index
):
info
=
self
.
infos
[
index
]
def
get_data_info
(
self
,
index
):
info
=
self
.
data_infos
[
index
]
sample_idx
=
info
[
'point_cloud'
][
'lidar_idx'
]
pts_filename
=
self
.
_get_pts_filename
(
sample_idx
)
...
...
@@ -76,7 +69,8 @@ class IndoorBaseDataset(torch_data.Dataset):
if
self
.
with_label
:
annos
=
self
.
_get_ann_info
(
index
,
sample_idx
)
input_dict
.
update
(
annos
)
if
len
(
input_dict
[
'gt_bboxes_3d'
])
==
0
:
return
None
return
input_dict
def
_rand_another
(
self
,
idx
):
...
...
@@ -132,9 +126,9 @@ class IndoorBaseDataset(torch_data.Dataset):
results
=
self
.
format_results
(
results
)
from
mmdet3d.core.evaluation
import
indoor_eval
assert
len
(
metric
)
>
0
gt_annos
=
[
copy
.
deepcopy
(
info
[
'annos'
])
for
info
in
self
.
infos
]
gt_annos
=
[
copy
.
deepcopy
(
info
[
'annos'
])
for
info
in
self
.
data_
infos
]
ret_dict
=
indoor_eval
(
gt_annos
,
results
,
metric
,
self
.
label2cat
)
return
ret_dict
def
__len__
(
self
):
return
len
(
self
.
infos
)
return
len
(
self
.
data_
infos
)
mmdet3d/datasets/scannet_dataset.py
View file @
72350b2d
import
os.path
as
osp
import
mmcv
import
numpy
as
np
from
mmdet.datasets
import
DATASETS
...
...
@@ -20,22 +19,21 @@ class ScannetBaseDataset(IndoorBaseDataset):
ann_file
,
pipeline
=
None
,
training
=
False
,
c
at_id
s
=
None
,
c
lasse
s
=
None
,
test_mode
=
False
,
with_label
=
True
):
super
().
__init__
(
root_path
,
ann_file
,
pipeline
,
training
,
c
at_id
s
,
super
().
__init__
(
root_path
,
ann_file
,
pipeline
,
training
,
c
lasse
s
,
test_mode
,
with_label
)
self
.
data_path
=
osp
.
join
(
root_path
,
'scannet_train_instance_data'
)
def
_get_pts_filename
(
self
,
sample_idx
):
pts_filename
=
osp
.
join
(
self
.
data_path
,
f
'
{
sample_idx
}
_vert.npy'
)
mmcv
.
check_file_exist
(
pts_filename
)
return
pts_filename
def
_get_ann_info
(
self
,
index
,
sample_idx
):
# Use index to get the annos, thus the evalhook could also use this api
info
=
self
.
infos
[
index
]
info
=
self
.
data_
infos
[
index
]
if
info
[
'annos'
][
'gt_num'
]
!=
0
:
gt_bboxes_3d
=
info
[
'annos'
][
'gt_boxes_upright_depth'
]
# k, 6
gt_labels
=
info
[
'annos'
][
'class'
]
...
...
mmdet3d/datasets/sunrgbd_dataset.py
View file @
72350b2d
import
os.path
as
osp
import
mmcv
import
numpy
as
np
from
mmdet.datasets
import
DATASETS
...
...
@@ -18,22 +17,21 @@ class SunrgbdBaseDataset(IndoorBaseDataset):
ann_file
,
pipeline
=
None
,
training
=
False
,
c
at_id
s
=
None
,
c
lasse
s
=
None
,
test_mode
=
False
,
with_label
=
True
):
super
().
__init__
(
root_path
,
ann_file
,
pipeline
,
training
,
c
at_id
s
,
super
().
__init__
(
root_path
,
ann_file
,
pipeline
,
training
,
c
lasse
s
,
test_mode
,
with_label
)
self
.
data_path
=
osp
.
join
(
root_path
,
'sunrgbd_trainval'
)
def
_get_pts_filename
(
self
,
sample_idx
):
pts_filename
=
osp
.
join
(
self
.
data_path
,
'lidar'
,
f
'
{
sample_idx
:
06
d
}
.npy'
)
mmcv
.
check_file_exist
(
pts_filename
)
return
pts_filename
def
_get_ann_info
(
self
,
index
,
sample_idx
):
# Use index to get the annos, thus the evalhook could also use this api
info
=
self
.
infos
[
index
]
info
=
self
.
data_
infos
[
index
]
if
info
[
'annos'
][
'gt_num'
]
!=
0
:
gt_bboxes_3d
=
info
[
'annos'
][
'gt_boxes_upright_depth'
]
# k, 6
gt_labels
=
info
[
'annos'
][
'class'
]
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment