Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
OpenDAS
OpenPCDet
Commits
01b425bd
"pcdet/ops/vscode:/vscode.git/clone" did not exist on "8a64de5d41359d6fb84c5644caac4f9636c3bd27"
Commit
01b425bd
authored
Sep 09, 2021
by
Shaoshuai Shi
Browse files
support USE_SHARED_MEMORY=True for GT sampling
parent
ee11621c
Changes
2
Show whitespace changes
Inline
Side-by-side
Showing
2 changed files
with
42 additions
and
4 deletions
+42
-4
pcdet/datasets/augmentor/database_sampler.py
pcdet/datasets/augmentor/database_sampler.py
+35
-4
pcdet/utils/common_utils.py
pcdet/utils/common_utils.py
+7
-0
No files found.
pcdet/datasets/augmentor/database_sampler.py
View file @
01b425bd
import
pickle
import
pickle
import
os
import
numpy
as
np
import
numpy
as
np
import
SharedArray
import
torch.distributed
as
dist
from
...ops.iou3d_nms
import
iou3d_nms_utils
from
...ops.iou3d_nms
import
iou3d_nms_utils
from
...utils
import
box_utils
from
...utils
import
box_utils
,
common_utils
class
DataBaseSampler
(
object
):
class
DataBaseSampler
(
object
):
...
@@ -25,9 +28,14 @@ class DataBaseSampler(object):
...
@@ -25,9 +28,14 @@ class DataBaseSampler(object):
for
func_name
,
val
in
sampler_cfg
.
PREPARE
.
items
():
for
func_name
,
val
in
sampler_cfg
.
PREPARE
.
items
():
self
.
db_infos
=
getattr
(
self
,
func_name
)(
self
.
db_infos
,
val
)
self
.
db_infos
=
getattr
(
self
,
func_name
)(
self
.
db_infos
,
val
)
self
.
use_shared_memory
=
sampler_cfg
.
get
(
'USE_SHARED_MEMORY'
,
False
)
if
self
.
use_shared_memory
:
self
.
load_db_to_shared_memory
()
self
.
sample_groups
=
{}
self
.
sample_groups
=
{}
self
.
sample_class_num
=
{}
self
.
sample_class_num
=
{}
self
.
limit_whole_scene
=
sampler_cfg
.
get
(
'LIMIT_WHOLE_SCENE'
,
False
)
self
.
limit_whole_scene
=
sampler_cfg
.
get
(
'LIMIT_WHOLE_SCENE'
,
False
)
for
x
in
sampler_cfg
.
SAMPLE_GROUPS
:
for
x
in
sampler_cfg
.
SAMPLE_GROUPS
:
class_name
,
sample_num
=
x
.
split
(
':'
)
class_name
,
sample_num
=
x
.
split
(
':'
)
if
class_name
not
in
class_names
:
if
class_name
not
in
class_names
:
...
@@ -47,6 +55,25 @@ class DataBaseSampler(object):
...
@@ -47,6 +55,25 @@ class DataBaseSampler(object):
def
__setstate__
(
self
,
d
):
def
__setstate__
(
self
,
d
):
self
.
__dict__
.
update
(
d
)
self
.
__dict__
.
update
(
d
)
def
load_db_to_shared_memory
(
self
):
self
.
logger
.
info
(
'Loading GT database to shared memory'
)
cur_rank
,
num_gpus
=
common_utils
.
get_dist_info
()
for
cur_class
in
self
.
class_names
:
cur_info_list
=
self
.
db_infos
[
cur_class
]
cur_info_list
=
cur_info_list
[
cur_rank
::
num_gpus
]
for
info
in
cur_info_list
:
file_path
=
self
.
root_path
/
info
[
'path'
]
sa_key
=
info
[
'path'
].
replace
(
'/'
,
'___'
)
if
os
.
path
.
exists
(
f
"/dev/shm/
{
sa_key
}
"
):
continue
obj_points
=
np
.
fromfile
(
str
(
file_path
),
dtype
=
np
.
float32
).
reshape
([
-
1
,
self
.
sampler_cfg
.
NUM_POINT_FEATURES
])
common_utils
.
sa_create
(
f
"shm://
{
sa_key
}
"
,
obj_points
)
dist
.
barrier
()
self
.
logger
.
info
(
'GT database has been saved to shared memory'
)
def
filter_by_difficulty
(
self
,
db_infos
,
removed_difficulty
):
def
filter_by_difficulty
(
self
,
db_infos
,
removed_difficulty
):
new_db_infos
=
{}
new_db_infos
=
{}
for
key
,
dinfos
in
db_infos
.
items
():
for
key
,
dinfos
in
db_infos
.
items
():
...
@@ -129,6 +156,10 @@ class DataBaseSampler(object):
...
@@ -129,6 +156,10 @@ class DataBaseSampler(object):
obj_points_list
=
[]
obj_points_list
=
[]
for
idx
,
info
in
enumerate
(
total_valid_sampled_dict
):
for
idx
,
info
in
enumerate
(
total_valid_sampled_dict
):
if
self
.
use_shared_memory
:
sa_key
=
info
[
'path'
].
replace
(
'/'
,
'___'
)
obj_points
=
SharedArray
.
attach
(
f
"shm://
{
sa_key
}
"
).
copy
()
else
:
file_path
=
self
.
root_path
/
info
[
'path'
]
file_path
=
self
.
root_path
/
info
[
'path'
]
obj_points
=
np
.
fromfile
(
str
(
file_path
),
dtype
=
np
.
float32
).
reshape
(
obj_points
=
np
.
fromfile
(
str
(
file_path
),
dtype
=
np
.
float32
).
reshape
(
[
-
1
,
self
.
sampler_cfg
.
NUM_POINT_FEATURES
])
[
-
1
,
self
.
sampler_cfg
.
NUM_POINT_FEATURES
])
...
...
pcdet/utils/common_utils.py
View file @
01b425bd
...
@@ -4,6 +4,7 @@ import pickle
...
@@ -4,6 +4,7 @@ import pickle
import
random
import
random
import
shutil
import
shutil
import
subprocess
import
subprocess
import
SharedArray
import
numpy
as
np
import
numpy
as
np
import
torch
import
torch
...
@@ -233,3 +234,9 @@ def generate_voxel2pinds(sparse_tensor):
...
@@ -233,3 +234,9 @@ def generate_voxel2pinds(sparse_tensor):
return
v2pinds_tensor
return
v2pinds_tensor
def
sa_create
(
name
,
var
):
x
=
SharedArray
.
create
(
name
,
var
.
shape
,
dtype
=
var
.
dtype
)
x
[...]
=
var
[...]
x
.
flags
.
writeable
=
False
return
x
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment