Commit 70d5aef3 authored by Shaoshuai Shi's avatar Shaoshuai Shi
Browse files

support USE_SHARED_MEMORY=True in DBSampler with Global GT database

parent 9fc9f152
......@@ -25,12 +25,12 @@ class DataBaseSampler(object):
infos = pickle.load(f)
[self.db_infos[cur_class].extend(infos[cur_class]) for cur_class in class_names]
for func_name, val in sampler_cfg.PREPARE.items():
self.db_infos = getattr(self, func_name)(self.db_infos, val)
self.use_shared_memory = sampler_cfg.get('USE_SHARED_MEMORY', False)
if self.use_shared_memory:
self.load_db_to_shared_memory()
self.gt_database_data_key = self.load_db_to_shared_memory() if self.use_shared_memory else None
self.sample_groups = {}
self.sample_class_num = {}
......@@ -58,21 +58,19 @@ class DataBaseSampler(object):
def load_db_to_shared_memory(self):
    """Load the global GT-database array into POSIX shared memory (/dev/shm).

    The span as scraped interleaved the old per-object-file loop with the new
    single-global-array load from this commit; this is the coherent post-commit
    implementation: one `.npy` file holding all GT points is loaded once and
    published under a single shared-memory key.

    Returns:
        str: the shared-memory key (the DB_DATA_PATH filename) that callers
        later pass to ``SharedArray.attach(f"shm://{key}")``.
    """
    self.logger.info('Loading GT database to shared memory')
    cur_rank, num_gpus = common_utils.get_dist_info()

    # The shared-memory path only supports one consolidated database file.
    assert self.sampler_cfg.DB_DATA_PATH.__len__() == 1, 'Current only support single DB_DATA'
    db_data_path = self.root_path.resolve() / self.sampler_cfg.DB_DATA_PATH[0]
    sa_key = self.sampler_cfg.DB_DATA_PATH[0]

    # Only one rank creates the segment (cur_rank % num_gpus == 0 — presumably
    # one loader per node; confirm against common_utils.get_dist_info).
    # Skip creation if a segment from a previous run is still present.
    if cur_rank % num_gpus == 0 and not os.path.exists(f"/dev/shm/{sa_key}"):
        gt_database_data = np.load(db_data_path)
        common_utils.sa_create(f"shm://{sa_key}", gt_database_data)

    # Make every rank wait until the segment exists before attaching to it.
    if num_gpus > 1:
        dist.barrier()
    self.logger.info('GT database has been saved to shared memory')
    return sa_key
def filter_by_difficulty(self, db_infos, removed_difficulty):
new_db_infos = {}
......@@ -155,10 +153,15 @@ class DataBaseSampler(object):
data_dict.pop('road_plane')
obj_points_list = []
if self.use_shared_memory:
gt_database_data = SharedArray.attach(f"shm://{self.gt_database_data_key}")
else:
gt_database_data = None
for idx, info in enumerate(total_valid_sampled_dict):
if self.use_shared_memory:
sa_key = info['path'].replace('/', '___')
obj_points = SharedArray.attach(f"shm://{sa_key}").copy()
start_offset, end_offset = info['global_data_offset']
obj_points = gt_database_data[start_offset:end_offset]
else:
file_path = self.root_path / info['path']
obj_points = np.fromfile(str(file_path), dtype=np.float32).reshape(
......
......@@ -24,6 +24,12 @@ DATA_AUGMENTOR:
USE_ROAD_PLANE: False
DB_INFO_PATH:
- waymo_processed_data_v0_3_1_waymo_dbinfos_train_sampled_1.pkl
# - waymo_processed_data_v0_3_1_waymo_dbinfos_train_sampled_1_global.pkl
USE_SHARED_MEMORY: False
DB_DATA_PATH:
- waymo_processed_data_v0_3_1_waymo_dbinfos_train_sampled_1_global.npy
PREPARE: {
filter_by_min_points: ['Vehicle:5', 'Pedestrian:5', 'Cyclist:5'],
filter_by_difficulty: [-1],
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment.