Commit 01b425bd authored by Shaoshuai Shi's avatar Shaoshuai Shi
Browse files

support USE_SHARED_MEMORY=True for GT sampling

parent ee11621c
import pickle import pickle
import os
import numpy as np import numpy as np
import SharedArray
import torch.distributed as dist
from ...ops.iou3d_nms import iou3d_nms_utils from ...ops.iou3d_nms import iou3d_nms_utils
from ...utils import box_utils from ...utils import box_utils, common_utils
class DataBaseSampler(object): class DataBaseSampler(object):
...@@ -25,9 +28,14 @@ class DataBaseSampler(object): ...@@ -25,9 +28,14 @@ class DataBaseSampler(object):
for func_name, val in sampler_cfg.PREPARE.items(): for func_name, val in sampler_cfg.PREPARE.items():
self.db_infos = getattr(self, func_name)(self.db_infos, val) self.db_infos = getattr(self, func_name)(self.db_infos, val)
self.use_shared_memory = sampler_cfg.get('USE_SHARED_MEMORY', False)
if self.use_shared_memory:
self.load_db_to_shared_memory()
self.sample_groups = {} self.sample_groups = {}
self.sample_class_num = {} self.sample_class_num = {}
self.limit_whole_scene = sampler_cfg.get('LIMIT_WHOLE_SCENE', False) self.limit_whole_scene = sampler_cfg.get('LIMIT_WHOLE_SCENE', False)
for x in sampler_cfg.SAMPLE_GROUPS: for x in sampler_cfg.SAMPLE_GROUPS:
class_name, sample_num = x.split(':') class_name, sample_num = x.split(':')
if class_name not in class_names: if class_name not in class_names:
...@@ -47,6 +55,25 @@ class DataBaseSampler(object): ...@@ -47,6 +55,25 @@ class DataBaseSampler(object):
def __setstate__(self, d): def __setstate__(self, d):
self.__dict__.update(d) self.__dict__.update(d)
def load_db_to_shared_memory(self):
self.logger.info('Loading GT database to shared memory')
cur_rank, num_gpus = common_utils.get_dist_info()
for cur_class in self.class_names:
cur_info_list = self.db_infos[cur_class]
cur_info_list = cur_info_list[cur_rank::num_gpus]
for info in cur_info_list:
file_path = self.root_path / info['path']
sa_key = info['path'].replace('/', '___')
if os.path.exists(f"/dev/shm/{sa_key}"):
continue
obj_points = np.fromfile(str(file_path), dtype=np.float32).reshape([-1, self.sampler_cfg.NUM_POINT_FEATURES])
common_utils.sa_create(f"shm://{sa_key}", obj_points)
dist.barrier()
self.logger.info('GT database has been saved to shared memory')
def filter_by_difficulty(self, db_infos, removed_difficulty): def filter_by_difficulty(self, db_infos, removed_difficulty):
new_db_infos = {} new_db_infos = {}
for key, dinfos in db_infos.items(): for key, dinfos in db_infos.items():
...@@ -129,6 +156,10 @@ class DataBaseSampler(object): ...@@ -129,6 +156,10 @@ class DataBaseSampler(object):
obj_points_list = [] obj_points_list = []
for idx, info in enumerate(total_valid_sampled_dict): for idx, info in enumerate(total_valid_sampled_dict):
if self.use_shared_memory:
sa_key = info['path'].replace('/', '___')
obj_points = SharedArray.attach(f"shm://{sa_key}").copy()
else:
file_path = self.root_path / info['path'] file_path = self.root_path / info['path']
obj_points = np.fromfile(str(file_path), dtype=np.float32).reshape( obj_points = np.fromfile(str(file_path), dtype=np.float32).reshape(
[-1, self.sampler_cfg.NUM_POINT_FEATURES]) [-1, self.sampler_cfg.NUM_POINT_FEATURES])
......
...@@ -4,6 +4,7 @@ import pickle ...@@ -4,6 +4,7 @@ import pickle
import random import random
import shutil import shutil
import subprocess import subprocess
import SharedArray
import numpy as np import numpy as np
import torch import torch
...@@ -233,3 +234,9 @@ def generate_voxel2pinds(sparse_tensor): ...@@ -233,3 +234,9 @@ def generate_voxel2pinds(sparse_tensor):
return v2pinds_tensor return v2pinds_tensor
def sa_create(name, var):
x = SharedArray.create(name, var.shape, dtype=var.dtype)
x[...] = var[...]
x.flags.writeable = False
return x
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment