Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
OpenDAS
OpenPCDet
Commits
d4551b2f
Unverified
Commit
d4551b2f
authored
Jun 24, 2020
by
Gus-Guo
Committed by
GitHub
Jun 24, 2020
Browse files
Add codes for creating kitti infos and gt database (#4)
* add codes for creating kitti infos and gt database
parent
9027859a
Changes
3
Hide whitespace changes
Inline
Side-by-side
Showing
3 changed files
with
213 additions
and
7 deletions
+213
-7
pcdet/datasets/kitti/kitti_dataset.py
pcdet/datasets/kitti/kitti_dataset.py
+192
-4
pcdet/utils/box_utils.py
pcdet/utils/box_utils.py
+18
-0
tools/cfgs/kitti_models/pv_rcnn.yaml
tools/cfgs/kitti_models/pv_rcnn.yaml
+3
-3
No files found.
pcdet/datasets/kitti/kitti_dataset.py
View file @
d4551b2f
...
@@ -3,8 +3,9 @@ import copy
...
@@ -3,8 +3,9 @@ import copy
import
numpy
as
np
import
numpy
as
np
from
skimage
import
io
from
skimage
import
io
from
...utils
import
box_utils
,
common_utils
,
calibration_kitti
,
object3d_kitti
from
...utils
import
box_utils
,
common_utils
,
calibration_kitti
,
object3d_kitti
from
..dataset
import
DatasetTemplate
from
..dataset
import
DatasetTemplate
from
...ops.roiaware_pool3d
import
roiaware_pool3d_utils
class
KittiDataset
(
DatasetTemplate
):
class
KittiDataset
(
DatasetTemplate
):
...
@@ -45,6 +46,16 @@ class KittiDataset(DatasetTemplate):
...
@@ -45,6 +46,16 @@ class KittiDataset(DatasetTemplate):
if
self
.
logger
is
not
None
:
if
self
.
logger
is
not
None
:
self
.
logger
.
info
(
'Total samples for KITTI dataset: %d'
%
(
len
(
kitti_infos
)))
self
.
logger
.
info
(
'Total samples for KITTI dataset: %d'
%
(
len
(
kitti_infos
)))
def
set_split
(
self
,
split
):
super
().
__init__
(
dataset_cfg
=
self
.
dataset_cfg
,
class_names
=
self
.
class_names
,
training
=
self
.
training
,
root_path
=
self
.
root_path
,
logger
=
self
.
logger
)
self
.
split
=
split
self
.
root_split_path
=
self
.
root_path
/
(
'training'
if
self
.
split
!=
'test'
else
'testing'
)
split_dir
=
self
.
root_path
/
'ImageSets'
/
(
self
.
split
+
'.txt'
)
self
.
sample_id_list
=
[
x
.
strip
()
for
x
in
open
(
split_dir
).
readlines
()]
if
split_dir
.
exists
()
else
None
def
get_lidar
(
self
,
idx
):
def
get_lidar
(
self
,
idx
):
lidar_file
=
self
.
root_split_path
/
'velodyne'
/
(
'%s.bin'
%
idx
)
lidar_file
=
self
.
root_split_path
/
'velodyne'
/
(
'%s.bin'
%
idx
)
assert
lidar_file
.
exists
()
assert
lidar_file
.
exists
()
...
@@ -102,6 +113,134 @@ class KittiDataset(DatasetTemplate):
...
@@ -102,6 +113,134 @@ class KittiDataset(DatasetTemplate):
return
pts_valid_flag
return
pts_valid_flag
def
get_infos
(
self
,
num_workers
=
4
,
has_label
=
True
,
count_inside_pts
=
True
,
sample_id_list
=
None
):
import
concurrent.futures
as
futures
def
process_single_scene
(
sample_idx
):
print
(
'%s sample_idx: %s'
%
(
self
.
split
,
sample_idx
))
info
=
{}
pc_info
=
{
'num_features'
:
4
,
'lidar_idx'
:
sample_idx
}
info
[
'point_cloud'
]
=
pc_info
image_info
=
{
'image_idx'
:
sample_idx
,
'image_shape'
:
self
.
get_image_shape
(
sample_idx
)}
info
[
'image'
]
=
image_info
calib
=
self
.
get_calib
(
sample_idx
)
P2
=
np
.
concatenate
([
calib
.
P2
,
np
.
array
([[
0.
,
0.
,
0.
,
1.
]])],
axis
=
0
)
R0_4x4
=
np
.
zeros
([
4
,
4
],
dtype
=
calib
.
R0
.
dtype
)
R0_4x4
[
3
,
3
]
=
1.
R0_4x4
[:
3
,
:
3
]
=
calib
.
R0
V2C_4x4
=
np
.
concatenate
([
calib
.
V2C
,
np
.
array
([[
0.
,
0.
,
0.
,
1.
]])],
axis
=
0
)
calib_info
=
{
'P2'
:
P2
,
'R0_rect'
:
R0_4x4
,
'Tr_velo_to_cam'
:
V2C_4x4
}
info
[
'calib'
]
=
calib_info
if
has_label
:
obj_list
=
self
.
get_label
(
sample_idx
)
annotations
=
{}
annotations
[
'name'
]
=
np
.
array
([
obj
.
cls_type
for
obj
in
obj_list
])
annotations
[
'truncated'
]
=
np
.
array
([
obj
.
truncation
for
obj
in
obj_list
])
annotations
[
'occluded'
]
=
np
.
array
([
obj
.
occlusion
for
obj
in
obj_list
])
annotations
[
'alpha'
]
=
np
.
array
([
obj
.
alpha
for
obj
in
obj_list
])
annotations
[
'bbox'
]
=
np
.
concatenate
([
obj
.
box2d
.
reshape
(
1
,
4
)
for
obj
in
obj_list
],
axis
=
0
)
annotations
[
'dimensions'
]
=
np
.
array
([[
obj
.
l
,
obj
.
h
,
obj
.
w
]
for
obj
in
obj_list
])
# lhw(camera) format
annotations
[
'location'
]
=
np
.
concatenate
([
obj
.
loc
.
reshape
(
1
,
3
)
for
obj
in
obj_list
],
axis
=
0
)
annotations
[
'rotation_y'
]
=
np
.
array
([
obj
.
ry
for
obj
in
obj_list
])
annotations
[
'score'
]
=
np
.
array
([
obj
.
score
for
obj
in
obj_list
])
annotations
[
'difficulty'
]
=
np
.
array
([
obj
.
level
for
obj
in
obj_list
],
np
.
int32
)
num_objects
=
len
([
obj
.
cls_type
for
obj
in
obj_list
if
obj
.
cls_type
!=
'DontCare'
])
num_gt
=
len
(
annotations
[
'name'
])
index
=
list
(
range
(
num_objects
))
+
[
-
1
]
*
(
num_gt
-
num_objects
)
annotations
[
'index'
]
=
np
.
array
(
index
,
dtype
=
np
.
int32
)
loc
=
annotations
[
'location'
][:
num_objects
]
dims
=
annotations
[
'dimensions'
][:
num_objects
]
rots
=
annotations
[
'rotation_y'
][:
num_objects
]
loc_lidar
=
calib
.
rect_to_lidar
(
loc
)
l
,
h
,
w
=
dims
[:,
0
:
1
],
dims
[:,
1
:
2
],
dims
[:,
2
:
3
]
loc_lidar
[:,
2
]
+=
h
[:,
0
]
/
2
gt_boxes_lidar
=
np
.
concatenate
([
loc_lidar
,
l
,
w
,
h
,
-
(
np
.
pi
/
2
+
rots
[...,
np
.
newaxis
])],
axis
=
1
)
annotations
[
'gt_boxes_lidar'
]
=
gt_boxes_lidar
info
[
'annos'
]
=
annotations
if
count_inside_pts
:
points
=
self
.
get_lidar
(
sample_idx
)
calib
=
self
.
get_calib
(
sample_idx
)
pts_rect
=
calib
.
lidar_to_rect
(
points
[:,
0
:
3
])
fov_flag
=
self
.
get_fov_flag
(
pts_rect
,
info
[
'image'
][
'image_shape'
],
calib
)
pts_fov
=
points
[
fov_flag
]
corners_lidar
=
box_utils
.
boxes_to_corners_3d
(
gt_boxes_lidar
)
num_points_in_gt
=
-
np
.
ones
(
num_gt
,
dtype
=
np
.
int32
)
for
k
in
range
(
num_objects
):
flag
=
box_utils
.
in_hull
(
pts_fov
[:,
0
:
3
],
corners_lidar
[
k
])
num_points_in_gt
[
k
]
=
flag
.
sum
()
annotations
[
'num_points_in_gt'
]
=
num_points_in_gt
return
info
# temp = process_single_scene(self.sample_id_list[0])
sample_id_list
=
sample_id_list
if
sample_id_list
is
not
None
else
self
.
sample_id_list
sample_id_list
=
sample_id_list
[:
10
]
with
futures
.
ThreadPoolExecutor
(
num_workers
)
as
executor
:
infos
=
executor
.
map
(
process_single_scene
,
sample_id_list
)
return
list
(
infos
)
def
create_groundtruth_database
(
self
,
info_path
=
None
,
used_classes
=
None
,
split
=
'train'
):
import
torch
database_save_path
=
Path
(
self
.
root_path
)
/
(
'gt_database'
if
split
==
'train'
else
(
'gt_database_%s'
%
split
))
db_info_save_path
=
Path
(
self
.
root_path
)
/
(
'kitti_dbinfos_%s.pkl'
%
split
)
database_save_path
.
mkdir
(
parents
=
True
,
exist_ok
=
True
)
all_db_infos
=
{}
with
open
(
info_path
,
'rb'
)
as
f
:
infos
=
pickle
.
load
(
f
)
for
k
in
range
(
len
(
infos
)):
print
(
'gt_database sample: %d/%d'
%
(
k
+
1
,
len
(
infos
)))
info
=
infos
[
k
]
sample_idx
=
info
[
'point_cloud'
][
'lidar_idx'
]
points
=
self
.
get_lidar
(
sample_idx
)
annos
=
info
[
'annos'
]
names
=
annos
[
'name'
]
difficulty
=
annos
[
'difficulty'
]
bbox
=
annos
[
'bbox'
]
gt_boxes
=
annos
[
'gt_boxes_lidar'
]
num_obj
=
gt_boxes
.
shape
[
0
]
point_indices
=
roiaware_pool3d_utils
.
points_in_boxes_cpu
(
torch
.
from_numpy
(
points
[:,
0
:
3
]),
torch
.
from_numpy
(
gt_boxes
)
).
numpy
()
# (nboxes, npoints)
for
i
in
range
(
num_obj
):
filename
=
'%s_%s_%d.bin'
%
(
sample_idx
,
names
[
i
],
i
)
filepath
=
database_save_path
/
filename
gt_points
=
points
[
point_indices
[
i
]
>
0
]
gt_points
[:,
:
3
]
-=
gt_boxes
[
i
,
:
3
]
with
open
(
filepath
,
'w'
)
as
f
:
gt_points
.
tofile
(
f
)
if
(
used_classes
is
None
)
or
names
[
i
]
in
used_classes
:
db_path
=
str
(
filepath
.
relative_to
(
self
.
root_path
))
# gt_database/xxxxx.bin
db_info
=
{
'name'
:
names
[
i
],
'path'
:
db_path
,
'image_idx'
:
sample_idx
,
'gt_idx'
:
i
,
'box3d_lidar'
:
gt_boxes
[
i
],
'num_points_in_gt'
:
gt_points
.
shape
[
0
],
'difficulty'
:
difficulty
[
i
],
'bbox'
:
bbox
[
i
],
'score'
:
annos
[
'score'
][
i
]}
if
names
[
i
]
in
all_db_infos
:
all_db_infos
[
names
[
i
]].
append
(
db_info
)
else
:
all_db_infos
[
names
[
i
]]
=
[
db_info
]
for
k
,
v
in
all_db_infos
.
items
():
print
(
'Database %s: %d'
%
(
k
,
len
(
v
)))
with
open
(
db_info_save_path
,
'wb'
)
as
f
:
pickle
.
dump
(
all_db_infos
,
f
)
@
staticmethod
@
staticmethod
def
generate_prediction_dicts
(
batch_dict
,
pred_dicts
,
class_names
,
output_path
=
None
):
def
generate_prediction_dicts
(
batch_dict
,
pred_dicts
,
class_names
,
output_path
=
None
):
"""
"""
...
@@ -238,9 +377,58 @@ class KittiDataset(DatasetTemplate):
...
@@ -238,9 +377,58 @@ class KittiDataset(DatasetTemplate):
return
data_dict
return
data_dict
def
create_kitti_infos
(
data_path
,
save_path
,
workers
=
4
):
def
create_kitti_infos
(
dataset_cfg
,
class_names
,
data_path
,
save_path
,
workers
=
4
):
pass
dataset
=
KittiDataset
(
dataset_cfg
=
dataset_cfg
,
class_names
=
class_names
,
root_path
=
data_path
)
train_split
,
val_split
=
'train'
,
'val'
train_filename
=
save_path
/
(
'kitti_infos_%s.pkl'
%
train_split
)
val_filename
=
save_path
/
(
'kitti_infos_%s.pkl'
%
val_split
)
trainval_filename
=
save_path
/
'kitti_infos_trainval.pkl'
test_filename
=
save_path
/
'kitti_infos_test.pkl'
print
(
'---------------Start to generate data infos---------------'
)
dataset
.
set_split
(
train_split
)
kitti_infos_train
=
dataset
.
get_infos
(
num_workers
=
workers
,
has_label
=
True
,
count_inside_pts
=
True
)
with
open
(
train_filename
,
'wb'
)
as
f
:
pickle
.
dump
(
kitti_infos_train
,
f
)
print
(
'Kitti info train file is saved to %s'
%
train_filename
)
dataset
.
set_split
(
val_split
)
kitti_infos_val
=
dataset
.
get_infos
(
num_workers
=
workers
,
has_label
=
True
,
count_inside_pts
=
True
)
with
open
(
val_filename
,
'wb'
)
as
f
:
pickle
.
dump
(
kitti_infos_val
,
f
)
print
(
'Kitti info val file is saved to %s'
%
val_filename
)
with
open
(
trainval_filename
,
'wb'
)
as
f
:
pickle
.
dump
(
kitti_infos_train
+
kitti_infos_val
,
f
)
print
(
'Kitti info trainval file is saved to %s'
%
trainval_filename
)
dataset
.
set_split
(
'test'
)
kitti_infos_test
=
dataset
.
get_infos
(
num_workers
=
workers
,
has_label
=
False
,
count_inside_pts
=
False
)
with
open
(
test_filename
,
'wb'
)
as
f
:
pickle
.
dump
(
kitti_infos_test
,
f
)
print
(
'Kitti info test file is saved to %s'
%
test_filename
)
print
(
'---------------Start create groundtruth database for data augmentation---------------'
)
dataset
.
set_split
(
train_split
)
dataset
.
create_groundtruth_database
(
train_filename
,
split
=
train_split
)
print
(
'---------------Data preparation Done---------------'
)
if
__name__
==
'__main__'
:
if
__name__
==
'__main__'
:
pass
import
sys
if
sys
.
argv
.
__len__
()
>
1
and
sys
.
argv
[
1
]
==
'create_kitti_infos'
:
import
yaml
from
pathlib
import
Path
from
easydict
import
EasyDict
dataset_cfg
=
EasyDict
(
yaml
.
load
(
open
(
sys
.
argv
[
2
])))
ROOT_DIR
=
(
Path
(
__file__
).
resolve
().
parent
/
'../../../'
).
resolve
()
create_kitti_infos
(
dataset_cfg
=
dataset_cfg
,
class_names
=
[
'Car'
,
'Pedestrian'
,
'Cyclist'
],
data_path
=
ROOT_DIR
/
'data'
/
'kitti'
,
save_path
=
ROOT_DIR
/
'data'
/
'kitti'
)
pcdet/utils/box_utils.py
View file @
d4551b2f
...
@@ -2,6 +2,24 @@ import numpy as np
...
@@ -2,6 +2,24 @@ import numpy as np
import
torch
import
torch
from
.
import
common_utils
from
.
import
common_utils
from
..ops.roiaware_pool3d
import
roiaware_pool3d_utils
from
..ops.roiaware_pool3d
import
roiaware_pool3d_utils
from
scipy.spatial
import
Delaunay
import
scipy
def
in_hull
(
p
,
hull
):
"""
:param p: (N, K) test points
:param hull: (M, K) M corners of a box
:return (N) bool
"""
try
:
if
not
isinstance
(
hull
,
Delaunay
):
hull
=
Delaunay
(
hull
)
flag
=
hull
.
find_simplex
(
p
)
>=
0
except
scipy
.
spatial
.
qhull
.
QhullError
:
print
(
'Warning: not a hull %s'
%
str
(
hull
))
flag
=
np
.
zeros
(
p
.
shape
[
0
],
dtype
=
np
.
bool
)
return
flag
def
boxes_to_corners_3d
(
boxes3d
):
def
boxes_to_corners_3d
(
boxes3d
):
...
...
tools/cfgs/kitti_models/pv_rcnn.yaml
View file @
d4551b2f
...
@@ -79,8 +79,8 @@ MODEL:
...
@@ -79,8 +79,8 @@ MODEL:
PFE
:
PFE
:
NAME
:
VoxelSetAbstraction
NAME
:
VoxelSetAbstraction
POINT_SOURCE
:
voxel_center
s
POINT_SOURCE
:
raw_point
s
NUM_KEYPOINTS
:
2048
NUM_KEYPOINTS
:
4096
NUM_OUTPUT_FEATURES
:
128
NUM_OUTPUT_FEATURES
:
128
SAMPLE_METHOD
:
FPS
SAMPLE_METHOD
:
FPS
...
@@ -131,7 +131,7 @@ MODEL:
...
@@ -131,7 +131,7 @@ MODEL:
SHARED_FC
:
[
256
,
256
]
SHARED_FC
:
[
256
,
256
]
CLS_FC
:
[
256
,
256
]
CLS_FC
:
[
256
,
256
]
REG_FC
:
[
256
,
256
]
REG_FC
:
[
256
,
256
]
DP_RATIO
:
0.
2
DP_RATIO
:
0.
3
NMS_CONFIG
:
NMS_CONFIG
:
TRAIN
:
TRAIN
:
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment