OpenDAS / OpenPCDet · Commits

Commit 38474d6c, authored Aug 18, 2022 by Shaoshuai Shi (parent 6a1f253a)

Support using multiple worker processes (CPU/CUDA) to generate the gt_database of WOD.

1 changed file: pcdet/datasets/waymo/waymo_dataset.py (+130 additions, -8 deletions)
```diff
@@ -13,6 +13,8 @@ import SharedArray
 import torch.distributed as dist
 from tqdm import tqdm
 from pathlib import Path
+from functools import partial

 from ...ops.roiaware_pool3d import roiaware_pool3d_utils
 from ...utils import box_utils, common_utils
 from ..dataset import DatasetTemplate
```
```diff
@@ -140,7 +142,6 @@ class WaymoDataset(DatasetTemplate):
         return sequence_file

     def get_infos(self, raw_data_path, save_path, num_workers=multiprocessing.cpu_count(), has_label=True, sampled_interval=1, update_info_only=False):
-        from functools import partial
         from . import waymo_utils
         print('---------------The waymo sample interval is %d, total sequecnes is %d-----------------'
               % (sampled_interval, len(self.sample_sequence_list)))
```
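The only substantive change in this hunk is import hygiene: `partial` is now imported once at module level (first hunk above) because the new parallel driver added later in this commit also uses it, so the function-local import inside `get_infos` is dropped.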
```diff
@@ -413,7 +414,7 @@ class WaymoDataset(DatasetTemplate):
         point_offset_cnt = 0
         stacked_gt_points = []
-        for k in range(0, len(infos), sampled_interval):
+        for k in tqdm(range(0, len(infos), sampled_interval)):
             print('gt_database sample: %d/%d' % (k + 1, len(infos)))
             info = infos[k]
```
```diff
@@ -487,6 +488,118 @@ class WaymoDataset(DatasetTemplate):
         stacked_gt_points = np.concatenate(stacked_gt_points, axis=0)
         np.save(db_data_save_path, stacked_gt_points)

+    def create_gt_database_of_single_scene(self, info_with_idx, database_save_path=None, use_sequence_data=False, used_classes=None, use_cuda=False):
+        info, info_idx = info_with_idx
+        all_db_infos = {}
+        pc_info = info['point_cloud']
+        sequence_name = pc_info['lidar_sequence']
+        sample_idx = pc_info['sample_idx']
+        points = self.get_lidar(sequence_name, sample_idx)
+
+        if use_sequence_data:
+            points, num_points_all, sample_idx_pre_list = self.get_sequence_data(
+                info, points, sequence_name, sample_idx, self.dataset_cfg.SEQUENCE_CONFIG
+            )
+
+        annos = info['annos']
+        names = annos['name']
+        difficulty = annos['difficulty']
+        gt_boxes = annos['gt_boxes_lidar']
+
+        if info_idx % 4 != 0 and len(names) > 0:
+            mask = (names == 'Vehicle')
+            names = names[~mask]
+            difficulty = difficulty[~mask]
+            gt_boxes = gt_boxes[~mask]
+
+        if info_idx % 2 != 0 and len(names) > 0:
+            mask = (names == 'Pedestrian')
+            names = names[~mask]
+            difficulty = difficulty[~mask]
+            gt_boxes = gt_boxes[~mask]
+
+        num_obj = gt_boxes.shape[0]
+        if num_obj == 0:
+            return {}
+
+        if use_cuda:
+            box_idxs_of_pts = roiaware_pool3d_utils.points_in_boxes_gpu(
+                torch.from_numpy(points[:, 0:3]).unsqueeze(dim=0).float().cuda(),
+                torch.from_numpy(gt_boxes[:, 0:7]).unsqueeze(dim=0).float().cuda()
+            ).long().squeeze(dim=0).cpu().numpy()
+        else:
+            box_point_mask = roiaware_pool3d_utils.points_in_boxes_cpu(
+                torch.from_numpy(points[:, 0:3]).float(),
+                torch.from_numpy(gt_boxes[:, 0:7]).float()
+            ).long().numpy()  # (num_boxes, num_points)
+
+        for i in range(num_obj):
+            filename = '%s_%04d_%s_%d.bin' % (sequence_name, sample_idx, names[i], i)
+            filepath = database_save_path / filename
+            if use_cuda:
+                gt_points = points[box_idxs_of_pts == i]
+            else:
+                gt_points = points[box_point_mask[i] > 0]
+
+            gt_points[:, :3] -= gt_boxes[i, :3]
+
+            if (used_classes is None) or names[i] in used_classes:
+                with open(filepath, 'w') as f:
+                    gt_points.tofile(f)
+
+                db_path = str(filepath.relative_to(self.root_path))  # gt_database/xxxxx.bin
+                db_info = {'name': names[i], 'path': db_path, 'sequence_name': sequence_name,
+                           'sample_idx': sample_idx, 'gt_idx': i, 'box3d_lidar': gt_boxes[i],
+                           'num_points_in_gt': gt_points.shape[0], 'difficulty': difficulty[i]}
+
+                if names[i] in all_db_infos:
+                    all_db_infos[names[i]].append(db_info)
+                else:
+                    all_db_infos[names[i]] = [db_info]
+        return all_db_infos
```
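Two details of the new per-scene worker are easy to miss. First, it thins the database as it goes: Vehicle crops are kept only from every 4th frame and Pedestrian crops only from every 2nd (`info_idx % 4` / `% 2`), which bounds the size of the dominant classes. Second, the CPU and GPU point-in-box ops index differently: `points_in_boxes_cpu` returns a `(num_boxes, num_points)` 0/1 mask, while `points_in_boxes_gpu` returns one box index per point (with -1 for background), which is why the two branches slice `points` differently. Below is a minimal pure-NumPy sketch of that indexing difference; the `toy_*` helpers are hypothetical axis-aligned stand-ins for OpenPCDet's compiled 7-DoF ops, not the real API.

```python
import numpy as np

# Toy stand-ins for the two indexing conventions used in the diff above.
# Real OpenPCDet boxes are 7-DoF (x, y, z, dx, dy, dz, heading); this sketch
# uses axis-aligned (x, y, z, dx, dy, dz) boxes only to show the shapes.

def toy_points_in_boxes_cpu(points, boxes):
    """Return a (num_boxes, num_points) 0/1 mask, like points_in_boxes_cpu."""
    lo = boxes[:, None, :3] - boxes[:, None, 3:6] / 2   # (B, 1, 3) box minima
    hi = boxes[:, None, :3] + boxes[:, None, 3:6] / 2   # (B, 1, 3) box maxima
    inside = ((points[None, :, :3] >= lo) & (points[None, :, :3] <= hi)).all(axis=-1)
    return inside.astype(np.int64)                      # (B, N)

def toy_points_in_boxes_gpu(points, boxes):
    """Return per-point box indices (-1 = background), like points_in_boxes_gpu."""
    mask = toy_points_in_boxes_cpu(points, boxes)       # (B, N)
    idx = np.full(points.shape[0], -1, dtype=np.int64)
    for i, row in enumerate(mask):
        idx[row > 0] = i
    return idx                                          # (N,)

points = np.random.uniform(-5, 5, size=(1000, 3))
boxes = np.array([[0.0, 0.0, 0.0, 2.0, 2.0, 2.0],      # non-overlapping boxes
                  [3.0, 3.0, 0.0, 1.0, 1.0, 1.0]])

mask = toy_points_in_boxes_cpu(points, boxes)
idx = toy_points_in_boxes_gpu(points, boxes)
for i in range(len(boxes)):
    # Both conventions select the same crop, matching the use_cuda / CPU
    # branches in the worker above.
    assert np.array_equal(points[mask[i] > 0], points[idx == i])
```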
The same hunk continues with the new parallel driver:

```diff
+    def create_groundtruth_database_parallel(self, info_path, save_path, used_classes=None, split='train', sampled_interval=10, processed_data_tag=None, num_workers=16):
+        use_sequence_data = self.dataset_cfg.get('SEQUENCE_CONFIG', None) is not None and self.dataset_cfg.SEQUENCE_CONFIG.ENABLED
+        if use_sequence_data:
+            st_frame, ed_frame = self.dataset_cfg.SEQUENCE_CONFIG.SAMPLE_OFFSET[0], self.dataset_cfg.SEQUENCE_CONFIG.SAMPLE_OFFSET[1]
+            st_frame = min(-4, st_frame)  # at least we use 5 frames for generating gt database to support various sequence configs (<= 5 frames)
+            database_save_path = save_path / ('%s_gt_database_%s_sampled_%d_multiframe_%s_to_%s_parallel' % (processed_data_tag, split, sampled_interval, st_frame, ed_frame))
+            db_info_save_path = save_path / ('%s_waymo_dbinfos_%s_sampled_%d_multiframe_%s_to_%s_parallel.pkl' % (processed_data_tag, split, sampled_interval, st_frame, ed_frame))
+            db_data_save_path = save_path / ('%s_gt_database_%s_sampled_%d_multiframe_%s_to_%s_global_parallel.npy' % (processed_data_tag, split, sampled_interval, st_frame, ed_frame))
+        else:
+            database_save_path = save_path / ('%s_gt_database_%s_sampled_%d_parallel' % (processed_data_tag, split, sampled_interval))
+            db_info_save_path = save_path / ('%s_waymo_dbinfos_%s_sampled_%d_parallel.pkl' % (processed_data_tag, split, sampled_interval))
+            db_data_save_path = save_path / ('%s_gt_database_%s_sampled_%d_global_parallel.npy' % (processed_data_tag, split, sampled_interval))
+
+        database_save_path.mkdir(parents=True, exist_ok=True)
+
+        with open(info_path, 'rb') as f:
+            infos = pickle.load(f)
+
+        create_gt_database_of_single_scene = partial(
+            self.create_gt_database_of_single_scene,
+            use_sequence_data=use_sequence_data, database_save_path=database_save_path,
+            used_classes=used_classes, use_cuda=True
+        )
+        # create_gt_database_of_single_scene((infos[0], 0))
+        with multiprocessing.Pool(num_workers) as p:
+            all_db_infos_list = list(tqdm(p.imap(create_gt_database_of_single_scene, zip(infos, np.arange(len(infos)))), total=len(infos)))
+
+        all_db_infos = {}
+
+        for cur_db_infos in all_db_infos_list:
+            for key, val in cur_db_infos.items():
+                if key not in all_db_infos:
+                    all_db_infos[key] = val
+                else:
+                    all_db_infos[key].extend(val)
+
+        for k, v in all_db_infos.items():
+            print('Database %s: %d' % (k, len(v)))
+
+        with open(db_info_save_path, 'wb') as f:
+            pickle.dump(all_db_infos, f)
+

 def create_waymo_infos(dataset_cfg, class_names, data_path, save_path,
                        raw_data_tag='raw_data', processed_data_tag='waymo_processed_data',
```
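The driver itself is a standard fan-out/merge over `multiprocessing.Pool`: `partial` binds the fixed keyword arguments so each worker only receives an `(info, index)` pair, `imap` streams ordered results through `tqdm`, and the per-scene dicts are merged into one class-keyed `all_db_infos` at the end. Note that this path binds `use_cuda=True`, so every pool process issues its own CUDA calls. A self-contained sketch of the same pattern with a toy worker (all names below are illustrative, not OpenPCDet API):

```python
import multiprocessing
from functools import partial

import numpy as np
from tqdm import tqdm

def process_one_scene(info_with_idx, used_classes=None):
    """Toy stand-in for create_gt_database_of_single_scene: returns a dict
    mapping class name -> list of db_info records for one scene."""
    info, idx = info_with_idx
    name = info['name']
    if used_classes is not None and name not in used_classes:
        return {}
    return {name: [{'scene_idx': int(idx), 'name': name}]}

if __name__ == '__main__':
    infos = [{'name': 'Vehicle' if i % 2 == 0 else 'Pedestrian'} for i in range(100)]

    # Bind the fixed kwargs once so workers only receive (info, idx) pairs,
    # mirroring the partial(...) call in the diff above.
    worker = partial(process_one_scene, used_classes=['Vehicle', 'Pedestrian'])

    with multiprocessing.Pool(4) as p:
        # imap preserves input order and yields lazily, so tqdm shows progress.
        results = list(tqdm(p.imap(worker, zip(infos, np.arange(len(infos)))),
                            total=len(infos)))

    # Merge the per-scene dicts into one class-keyed database, as the
    # parallel driver does for all_db_infos.
    all_db_infos = {}
    for cur in results:
        for key, val in cur.items():
            all_db_infos.setdefault(key, []).extend(val)

    for k, v in all_db_infos.items():
        print('Database %s: %d' % (k, len(v)))
```

With `multiprocessing`, the callable and its bound arguments must be picklable, which is why the worker is a regular method combined with `partial` rather than a local closure.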
```diff
@@ -538,7 +651,7 @@ def create_waymo_infos(dataset_cfg, class_names, data_path, save_path,
 def create_waymo_gt_database(
         dataset_cfg, class_names, data_path, save_path, processed_data_tag='waymo_processed_data',
-        workers=min(16, multiprocessing.cpu_count())):
+        workers=min(16, multiprocessing.cpu_count()), use_parallel=False):
     dataset = WaymoDataset(
         dataset_cfg=dataset_cfg, class_names=class_names, root_path=data_path,
         training=False, logger=common_utils.create_logger()
```
```diff
@@ -549,10 +662,17 @@ def create_waymo_gt_database(
     print('---------------Start create groundtruth database for data augmentation---------------')
     dataset.set_split(train_split)
-    dataset.create_groundtruth_database(
-        info_path=train_filename, save_path=save_path, split='train', sampled_interval=1,
-        used_classes=['Vehicle', 'Pedestrian', 'Cyclist'], processed_data_tag=processed_data_tag
-    )
+    if use_parallel:
+        dataset.create_groundtruth_database_parallel(
+            info_path=train_filename, save_path=save_path, split='train', sampled_interval=1,
+            used_classes=['Vehicle', 'Pedestrian', 'Cyclist'], processed_data_tag=processed_data_tag,
+            num_workers=workers
+        )
+    else:
+        dataset.create_groundtruth_database(
+            info_path=train_filename, save_path=save_path, split='train', sampled_interval=1,
+            used_classes=['Vehicle', 'Pedestrian', 'Cyclist'], processed_data_tag=processed_data_tag
+        )
     print('---------------Data preparation Done---------------')
```
```diff
@@ -566,6 +686,7 @@ if __name__ == '__main__':
     parser.add_argument('--func', type=str, default='create_waymo_infos', help='')
     parser.add_argument('--processed_data_tag', type=str, default='waymo_processed_data_v0_5_0', help='')
     parser.add_argument('--update_info_only', action='store_true', default=False, help='')
+    parser.add_argument('--use_parallel', action='store_true', default=False, help='')
     args = parser.parse_args()
```
```diff
@@ -599,7 +720,8 @@ if __name__ == '__main__':
             class_names=['Vehicle', 'Pedestrian', 'Cyclist'],
             data_path=ROOT_DIR / 'data' / 'waymo',
             save_path=ROOT_DIR / 'data' / 'waymo',
-            processed_data_tag=args.processed_data_tag
+            processed_data_tag=args.processed_data_tag,
+            use_parallel=args.use_parallel
         )
     else:
         raise NotImplementedError
```
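End-to-end, the new path should be reachable from the existing entry point, presumably with something like `python -m pcdet.datasets.waymo.waymo_dataset --func create_waymo_gt_database --cfg_file tools/cfgs/dataset_configs/waymo_dataset.yaml --use_parallel`. The `--cfg_file` flag and the `create_waymo_gt_database` value for `--func` come from OpenPCDet's usual Waymo preparation workflow and are not shown in this excerpt; `--use_parallel` switches `create_waymo_gt_database` from the sequential `create_groundtruth_database` to the new `create_groundtruth_database_parallel`.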