Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
OpenDAS
OpenPCDet
Commits
38474d6c
Commit
38474d6c
authored
Aug 18, 2022
by
Shaoshuai Shi
Browse files
support using multiple threads/processes (CPU/CUDA) to generate the gt_database of WOD
parent
6a1f253a
Changes
1
Hide whitespace changes
Inline
Side-by-side
Showing
1 changed file
with
130 additions
and
8 deletions
+130
-8
pcdet/datasets/waymo/waymo_dataset.py
pcdet/datasets/waymo/waymo_dataset.py
+130
-8
No files found.
pcdet/datasets/waymo/waymo_dataset.py
View file @
38474d6c
...
...
@@ -13,6 +13,8 @@ import SharedArray
import
torch.distributed
as
dist
from
tqdm
import
tqdm
from
pathlib
import
Path
from
functools
import
partial
from
...ops.roiaware_pool3d
import
roiaware_pool3d_utils
from
...utils
import
box_utils
,
common_utils
from
..dataset
import
DatasetTemplate
...
...
@@ -140,7 +142,6 @@ class WaymoDataset(DatasetTemplate):
return
sequence_file
def
get_infos
(
self
,
raw_data_path
,
save_path
,
num_workers
=
multiprocessing
.
cpu_count
(),
has_label
=
True
,
sampled_interval
=
1
,
update_info_only
=
False
):
from
functools
import
partial
from
.
import
waymo_utils
print
(
'---------------The waymo sample interval is %d, total sequecnes is %d-----------------'
%
(
sampled_interval
,
len
(
self
.
sample_sequence_list
)))
...
...
@@ -413,7 +414,7 @@ class WaymoDataset(DatasetTemplate):
point_offset_cnt
=
0
stacked_gt_points
=
[]
for
k
in
range
(
0
,
len
(
infos
),
sampled_interval
):
for
k
in
tqdm
(
range
(
0
,
len
(
infos
),
sampled_interval
)
)
:
print
(
'gt_database sample: %d/%d'
%
(
k
+
1
,
len
(
infos
)))
info
=
infos
[
k
]
...
...
@@ -487,6 +488,118 @@ class WaymoDataset(DatasetTemplate):
stacked_gt_points
=
np
.
concatenate
(
stacked_gt_points
,
axis
=
0
)
np
.
save
(
db_data_save_path
,
stacked_gt_points
)
def create_gt_database_of_single_scene(self, info_with_idx, database_save_path=None, use_sequence_data=False,
                                       used_classes=None, use_cuda=False):
    """Extract per-object ground-truth point clips for one frame and write them to disk.

    Designed to be mapped over ``zip(infos, np.arange(len(infos)))`` by a
    multiprocessing pool (see ``create_groundtruth_database_parallel``).

    Args:
        info_with_idx: tuple ``(info, info_idx)`` — one frame's info dict and its
            position in the overall info list (the index drives class subsampling).
        database_save_path: directory (``Path``) receiving the per-object ``.bin`` files.
        use_sequence_data: if True, build the multi-frame point cloud via
            ``self.get_sequence_data`` before cropping boxes.
        used_classes: optional whitelist of class names; objects of other classes
            are neither written to disk nor recorded.
        use_cuda: use the GPU points-in-boxes kernel instead of the CPU one.

    Returns:
        dict mapping class name -> list of db_info dicts for this frame,
        or an empty dict when no usable boxes remain.
    """
    info, info_idx = info_with_idx
    all_db_infos = {}
    pc_info = info['point_cloud']
    sequence_name = pc_info['lidar_sequence']
    sample_idx = pc_info['sample_idx']
    points = self.get_lidar(sequence_name, sample_idx)

    if use_sequence_data:
        points, num_points_all, sample_idx_pre_list = self.get_sequence_data(
            info, points, sequence_name, sample_idx, self.dataset_cfg.SEQUENCE_CONFIG
        )

    annos = info['annos']
    names = annos['name']
    difficulty = annos['difficulty']
    gt_boxes = annos['gt_boxes_lidar']

    # Subsample the dominant classes to keep the database balanced:
    # Vehicle boxes are kept only from every 4th frame, Pedestrian boxes
    # only from every 2nd frame.
    if info_idx % 4 != 0 and len(names) > 0:
        mask = (names == 'Vehicle')
        names = names[~mask]
        difficulty = difficulty[~mask]
        gt_boxes = gt_boxes[~mask]

    if info_idx % 2 != 0 and len(names) > 0:
        mask = (names == 'Pedestrian')
        names = names[~mask]
        difficulty = difficulty[~mask]
        gt_boxes = gt_boxes[~mask]

    num_obj = gt_boxes.shape[0]
    if num_obj == 0:
        return {}

    if use_cuda:
        # (num_points,) index of the box containing each point (-1 for none),
        # computed on the GPU with a batch dim of 1.
        box_idxs_of_pts = roiaware_pool3d_utils.points_in_boxes_gpu(
            torch.from_numpy(points[:, 0:3]).unsqueeze(dim=0).float().cuda(),
            torch.from_numpy(gt_boxes[:, 0:7]).unsqueeze(dim=0).float().cuda()
        ).long().squeeze(dim=0).cpu().numpy()
    else:
        box_point_mask = roiaware_pool3d_utils.points_in_boxes_cpu(
            torch.from_numpy(points[:, 0:3]).float(),
            torch.from_numpy(gt_boxes[:, 0:7]).float()
        ).long().numpy()  # (num_boxes, num_points)

    for i in range(num_obj):
        filename = '%s_%04d_%s_%d.bin' % (sequence_name, sample_idx, names[i], i)
        filepath = database_save_path / filename
        if use_cuda:
            gt_points = points[box_idxs_of_pts == i]
        else:
            gt_points = points[box_point_mask[i] > 0]

        # Store points relative to the box center.
        gt_points[:, :3] -= gt_boxes[i, :3]

        if (used_classes is None) or names[i] in used_classes:
            # BUGFIX: open in binary mode — ndarray.tofile() writes raw bytes,
            # and text mode corrupts the file on platforms that translate
            # newlines (e.g. Windows).
            with open(filepath, 'wb') as f:
                gt_points.tofile(f)

            db_path = str(filepath.relative_to(self.root_path))  # gt_database/xxxxx.bin
            db_info = {'name': names[i], 'path': db_path, 'sequence_name': sequence_name,
                       'sample_idx': sample_idx, 'gt_idx': i, 'box3d_lidar': gt_boxes[i],
                       'num_points_in_gt': gt_points.shape[0], 'difficulty': difficulty[i]}

            if names[i] in all_db_infos:
                all_db_infos[names[i]].append(db_info)
            else:
                all_db_infos[names[i]] = [db_info]
    return all_db_infos
def create_groundtruth_database_parallel(self, info_path, save_path, used_classes=None, split='train',
                                         sampled_interval=10, processed_data_tag=None, num_workers=16):
    """Create the GT-sampling database by fanning frames out to a worker pool.

    Each worker runs ``create_gt_database_of_single_scene`` on one frame; the
    per-frame ``{class: [db_info, ...]}`` dicts are merged and pickled into a
    single dbinfos file.

    Args:
        info_path: pickle file containing the list of frame infos.
        save_path: root output directory (``Path``).
        used_classes: optional class-name whitelist forwarded to the workers.
        split: dataset split name, used only in the output file names.
        sampled_interval: recorded in the output file names.
            NOTE(review): unlike the serial version, the interval is not
            applied here — every frame in ``infos`` is processed; confirm
            this is intended.
        processed_data_tag: tag prefixed to the output file names.
        num_workers: size of the multiprocessing pool.
    """
    use_sequence_data = self.dataset_cfg.get('SEQUENCE_CONFIG', None) is not None \
        and self.dataset_cfg.SEQUENCE_CONFIG.ENABLED
    if use_sequence_data:
        st_frame, ed_frame = self.dataset_cfg.SEQUENCE_CONFIG.SAMPLE_OFFSET[0], \
            self.dataset_cfg.SEQUENCE_CONFIG.SAMPLE_OFFSET[1]
        # at least we use 5 frames for generating gt database to support
        # various sequence configs (<= 5 frames)
        st_frame = min(-4, st_frame)
        database_save_path = save_path / ('%s_gt_database_%s_sampled_%d_multiframe_%s_to_%s_parallel' % (
            processed_data_tag, split, sampled_interval, st_frame, ed_frame))
        db_info_save_path = save_path / ('%s_waymo_dbinfos_%s_sampled_%d_multiframe_%s_to_%s_parallel.pkl' % (
            processed_data_tag, split, sampled_interval, st_frame, ed_frame))
    else:
        database_save_path = save_path / ('%s_gt_database_%s_sampled_%d_parallel' % (
            processed_data_tag, split, sampled_interval))
        db_info_save_path = save_path / ('%s_waymo_dbinfos_%s_sampled_%d_parallel.pkl' % (
            processed_data_tag, split, sampled_interval))
    # NOTE: the serial version also writes a *_global_*.npy of stacked points;
    # the parallel path never produced one, so the unused db_data_save_path
    # computation was removed as dead code.

    database_save_path.mkdir(parents=True, exist_ok=True)

    with open(info_path, 'rb') as f:
        infos = pickle.load(f)

    single_scene_fn = partial(
        self.create_gt_database_of_single_scene,
        use_sequence_data=use_sequence_data, database_save_path=database_save_path,
        used_classes=used_classes, use_cuda=True
    )
    # NOTE(review): use_cuda=True inside a fork-started Pool can fail on CUDA
    # re-initialization; confirm the multiprocessing start method is 'spawn'
    # when GPUs are used.
    # single_scene_fn((infos[0], 0))
    with multiprocessing.Pool(num_workers) as p:
        all_db_infos_list = list(tqdm(
            p.imap(single_scene_fn, zip(infos, np.arange(len(infos)))),
            total=len(infos)
        ))

    # Merge the per-frame dicts: first occurrence of a class seeds the list,
    # later occurrences extend it.
    all_db_infos = {}
    for cur_db_infos in all_db_infos_list:
        for key, val in cur_db_infos.items():
            if key not in all_db_infos:
                all_db_infos[key] = val
            else:
                all_db_infos[key].extend(val)

    for k, v in all_db_infos.items():
        print('Database %s: %d' % (k, len(v)))

    with open(db_info_save_path, 'wb') as f:
        pickle.dump(all_db_infos, f)
def
create_waymo_infos
(
dataset_cfg
,
class_names
,
data_path
,
save_path
,
raw_data_tag
=
'raw_data'
,
processed_data_tag
=
'waymo_processed_data'
,
...
...
@@ -538,7 +651,7 @@ def create_waymo_infos(dataset_cfg, class_names, data_path, save_path,
def
create_waymo_gt_database
(
dataset_cfg
,
class_names
,
data_path
,
save_path
,
processed_data_tag
=
'waymo_processed_data'
,
workers
=
min
(
16
,
multiprocessing
.
cpu_count
())):
workers
=
min
(
16
,
multiprocessing
.
cpu_count
())
,
use_parallel
=
False
):
dataset
=
WaymoDataset
(
dataset_cfg
=
dataset_cfg
,
class_names
=
class_names
,
root_path
=
data_path
,
training
=
False
,
logger
=
common_utils
.
create_logger
()
...
...
@@ -549,10 +662,17 @@ def create_waymo_gt_database(
print
(
'---------------Start create groundtruth database for data augmentation---------------'
)
dataset
.
set_split
(
train_split
)
dataset
.
create_groundtruth_database
(
info_path
=
train_filename
,
save_path
=
save_path
,
split
=
'train'
,
sampled_interval
=
1
,
used_classes
=
[
'Vehicle'
,
'Pedestrian'
,
'Cyclist'
],
processed_data_tag
=
processed_data_tag
)
if
use_parallel
:
dataset
.
create_groundtruth_database_parallel
(
info_path
=
train_filename
,
save_path
=
save_path
,
split
=
'train'
,
sampled_interval
=
1
,
used_classes
=
[
'Vehicle'
,
'Pedestrian'
,
'Cyclist'
],
processed_data_tag
=
processed_data_tag
,
num_workers
=
workers
)
else
:
dataset
.
create_groundtruth_database
(
info_path
=
train_filename
,
save_path
=
save_path
,
split
=
'train'
,
sampled_interval
=
1
,
used_classes
=
[
'Vehicle'
,
'Pedestrian'
,
'Cyclist'
],
processed_data_tag
=
processed_data_tag
)
print
(
'---------------Data preparation Done---------------'
)
...
...
@@ -566,6 +686,7 @@ if __name__ == '__main__':
parser
.
add_argument
(
'--func'
,
type
=
str
,
default
=
'create_waymo_infos'
,
help
=
''
)
parser
.
add_argument
(
'--processed_data_tag'
,
type
=
str
,
default
=
'waymo_processed_data_v0_5_0'
,
help
=
''
)
parser
.
add_argument
(
'--update_info_only'
,
action
=
'store_true'
,
default
=
False
,
help
=
''
)
parser
.
add_argument
(
'--use_parallel'
,
action
=
'store_true'
,
default
=
False
,
help
=
''
)
args
=
parser
.
parse_args
()
...
...
@@ -599,7 +720,8 @@ if __name__ == '__main__':
class_names
=
[
'Vehicle'
,
'Pedestrian'
,
'Cyclist'
],
data_path
=
ROOT_DIR
/
'data'
/
'waymo'
,
save_path
=
ROOT_DIR
/
'data'
/
'waymo'
,
processed_data_tag
=
args
.
processed_data_tag
processed_data_tag
=
args
.
processed_data_tag
,
use_parallel
=
args
.
use_parallel
)
else
:
raise
NotImplementedError
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment