Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
OpenDAS
OpenPCDet
Commits
70d5aef3
Commit
70d5aef3
authored
Sep 11, 2021
by
Shaoshuai Shi
Browse files
support USE_SHARED_MEMORY=True in DBSampler with Global GT database
parent
9fc9f152
Changes
2
Show whitespace changes
Inline
Side-by-side
Showing
2 changed files
with
27 additions
and
18 deletions
+27
-18
pcdet/datasets/augmentor/database_sampler.py
pcdet/datasets/augmentor/database_sampler.py
+21
-18
tools/cfgs/dataset_configs/waymo_dataset.yaml
tools/cfgs/dataset_configs/waymo_dataset.yaml
+6
-0
No files found.
pcdet/datasets/augmentor/database_sampler.py
View file @
70d5aef3
...
@@ -25,12 +25,12 @@ class DataBaseSampler(object):
...
@@ -25,12 +25,12 @@ class DataBaseSampler(object):
infos
=
pickle
.
load
(
f
)
infos
=
pickle
.
load
(
f
)
[
self
.
db_infos
[
cur_class
].
extend
(
infos
[
cur_class
])
for
cur_class
in
class_names
]
[
self
.
db_infos
[
cur_class
].
extend
(
infos
[
cur_class
])
for
cur_class
in
class_names
]
for
func_name
,
val
in
sampler_cfg
.
PREPARE
.
items
():
for
func_name
,
val
in
sampler_cfg
.
PREPARE
.
items
():
self
.
db_infos
=
getattr
(
self
,
func_name
)(
self
.
db_infos
,
val
)
self
.
db_infos
=
getattr
(
self
,
func_name
)(
self
.
db_infos
,
val
)
self
.
use_shared_memory
=
sampler_cfg
.
get
(
'USE_SHARED_MEMORY'
,
False
)
self
.
use_shared_memory
=
sampler_cfg
.
get
(
'USE_SHARED_MEMORY'
,
False
)
if
self
.
use_shared_memory
:
self
.
gt_database_data_key
=
self
.
load_db_to_shared_memory
()
if
self
.
use_shared_memory
else
None
self
.
load_db_to_shared_memory
()
self
.
sample_groups
=
{}
self
.
sample_groups
=
{}
self
.
sample_class_num
=
{}
self
.
sample_class_num
=
{}
...
@@ -58,21 +58,19 @@ class DataBaseSampler(object):
...
@@ -58,21 +58,19 @@ class DataBaseSampler(object):
def
load_db_to_shared_memory
(
self
):
def
load_db_to_shared_memory
(
self
):
self
.
logger
.
info
(
'Loading GT database to shared memory'
)
self
.
logger
.
info
(
'Loading GT database to shared memory'
)
cur_rank
,
num_gpus
=
common_utils
.
get_dist_info
()
cur_rank
,
num_gpus
=
common_utils
.
get_dist_info
()
for
cur_class
in
self
.
class_names
:
cur_info_list
=
self
.
db_infos
[
cur_class
]
cur_info_list
=
cur_info_list
[
cur_rank
::
num_gpus
]
for
info
in
cur_info_list
:
assert
self
.
sampler_cfg
.
DB_DATA_PATH
.
__len__
()
==
1
,
'Current only support single DB_DATA'
file_path
=
self
.
root_path
/
info
[
'path'
]
db_data_path
=
self
.
root_path
.
resolve
()
/
self
.
sampler_cfg
.
DB_DATA_PATH
[
0
]
sa_key
=
info
[
'path'
].
replace
(
'/'
,
'___'
)
sa_key
=
self
.
sampler_cfg
.
DB_DATA_PATH
[
0
]
if
os
.
path
.
exists
(
f
"/dev/shm/
{
sa_key
}
"
):
continue
obj_points
=
np
.
fromfile
(
str
(
file_path
),
dtype
=
np
.
float32
).
reshape
([
-
1
,
self
.
sampler_cfg
.
NUM_POINT_FEATURES
])
if
cur_rank
%
num_gpus
==
0
and
not
os
.
path
.
exists
(
f
"/dev/shm/
{
sa_key
}
"
):
common_utils
.
sa_create
(
f
"shm://
{
sa_key
}
"
,
obj_points
)
gt_database_data
=
np
.
load
(
db_data_path
)
common_utils
.
sa_create
(
f
"shm://
{
sa_key
}
"
,
gt_database_data
)
if
num_gpus
>
1
:
dist
.
barrier
()
dist
.
barrier
()
self
.
logger
.
info
(
'GT database has been saved to shared memory'
)
self
.
logger
.
info
(
'GT database has been saved to shared memory'
)
return
sa_key
def
filter_by_difficulty
(
self
,
db_infos
,
removed_difficulty
):
def
filter_by_difficulty
(
self
,
db_infos
,
removed_difficulty
):
new_db_infos
=
{}
new_db_infos
=
{}
...
@@ -155,10 +153,15 @@ class DataBaseSampler(object):
...
@@ -155,10 +153,15 @@ class DataBaseSampler(object):
data_dict
.
pop
(
'road_plane'
)
data_dict
.
pop
(
'road_plane'
)
obj_points_list
=
[]
obj_points_list
=
[]
if
self
.
use_shared_memory
:
gt_database_data
=
SharedArray
.
attach
(
f
"shm://
{
self
.
gt_database_data_key
}
"
)
else
:
gt_database_data
=
None
for
idx
,
info
in
enumerate
(
total_valid_sampled_dict
):
for
idx
,
info
in
enumerate
(
total_valid_sampled_dict
):
if
self
.
use_shared_memory
:
if
self
.
use_shared_memory
:
s
a_key
=
info
[
'path'
].
replace
(
'/'
,
'___'
)
s
tart_offset
,
end_offset
=
info
[
'global_data_offset'
]
obj_points
=
SharedArray
.
attach
(
f
"shm://
{
sa_key
}
"
).
copy
()
obj_points
=
gt_database_data
[
start_offset
:
end_offset
]
else
:
else
:
file_path
=
self
.
root_path
/
info
[
'path'
]
file_path
=
self
.
root_path
/
info
[
'path'
]
obj_points
=
np
.
fromfile
(
str
(
file_path
),
dtype
=
np
.
float32
).
reshape
(
obj_points
=
np
.
fromfile
(
str
(
file_path
),
dtype
=
np
.
float32
).
reshape
(
...
...
tools/cfgs/dataset_configs/waymo_dataset.yaml
View file @
70d5aef3
...
@@ -24,6 +24,12 @@ DATA_AUGMENTOR:
...
@@ -24,6 +24,12 @@ DATA_AUGMENTOR:
USE_ROAD_PLANE
:
False
USE_ROAD_PLANE
:
False
DB_INFO_PATH
:
DB_INFO_PATH
:
-
waymo_processed_data_v0_3_1_waymo_dbinfos_train_sampled_1.pkl
-
waymo_processed_data_v0_3_1_waymo_dbinfos_train_sampled_1.pkl
# - waymo_processed_data_v0_3_1_waymo_dbinfos_train_sampled_1_global.pkl
USE_SHARED_MEMORY
:
False
DB_DATA_PATH
:
-
waymo_processed_data_v0_3_1_waymo_dbinfos_train_sampled_1_global.npy
PREPARE
:
{
PREPARE
:
{
filter_by_min_points
:
[
'
Vehicle:5'
,
'
Pedestrian:5'
,
'
Cyclist:5'
],
filter_by_min_points
:
[
'
Vehicle:5'
,
'
Pedestrian:5'
,
'
Cyclist:5'
],
filter_by_difficulty
:
[
-1
],
filter_by_difficulty
:
[
-1
],
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment