"git@developer.sourcefind.cn:OpenDAS/openpcdet.git" did not exist on "30df59ed8cc0e366a9b9a3c2fa6a0c718da8d8bc"
Commit 70d5aef3 authored by Shaoshuai Shi's avatar Shaoshuai Shi
Browse files

support USE_SHARED_MEMORY=True in DBSampler with Global GT database

parent 9fc9f152
...@@ -18,19 +18,19 @@ class DataBaseSampler(object): ...@@ -18,19 +18,19 @@ class DataBaseSampler(object):
self.db_infos = {} self.db_infos = {}
for class_name in class_names: for class_name in class_names:
self.db_infos[class_name] = [] self.db_infos[class_name] = []
for db_info_path in sampler_cfg.DB_INFO_PATH: for db_info_path in sampler_cfg.DB_INFO_PATH:
db_info_path = self.root_path.resolve() / db_info_path db_info_path = self.root_path.resolve() / db_info_path
with open(str(db_info_path), 'rb') as f: with open(str(db_info_path), 'rb') as f:
infos = pickle.load(f) infos = pickle.load(f)
[self.db_infos[cur_class].extend(infos[cur_class]) for cur_class in class_names] [self.db_infos[cur_class].extend(infos[cur_class]) for cur_class in class_names]
for func_name, val in sampler_cfg.PREPARE.items(): for func_name, val in sampler_cfg.PREPARE.items():
self.db_infos = getattr(self, func_name)(self.db_infos, val) self.db_infos = getattr(self, func_name)(self.db_infos, val)
self.use_shared_memory = sampler_cfg.get('USE_SHARED_MEMORY', False) self.use_shared_memory = sampler_cfg.get('USE_SHARED_MEMORY', False)
if self.use_shared_memory: self.gt_database_data_key = self.load_db_to_shared_memory() if self.use_shared_memory else None
self.load_db_to_shared_memory()
self.sample_groups = {} self.sample_groups = {}
self.sample_class_num = {} self.sample_class_num = {}
...@@ -58,21 +58,19 @@ class DataBaseSampler(object): ...@@ -58,21 +58,19 @@ class DataBaseSampler(object):
def load_db_to_shared_memory(self): def load_db_to_shared_memory(self):
self.logger.info('Loading GT database to shared memory') self.logger.info('Loading GT database to shared memory')
cur_rank, num_gpus = common_utils.get_dist_info() cur_rank, num_gpus = common_utils.get_dist_info()
for cur_class in self.class_names:
cur_info_list = self.db_infos[cur_class]
cur_info_list = cur_info_list[cur_rank::num_gpus]
for info in cur_info_list:
file_path = self.root_path / info['path']
sa_key = info['path'].replace('/', '___')
if os.path.exists(f"/dev/shm/{sa_key}"):
continue
obj_points = np.fromfile(str(file_path), dtype=np.float32).reshape([-1, self.sampler_cfg.NUM_POINT_FEATURES]) assert self.sampler_cfg.DB_DATA_PATH.__len__() == 1, 'Current only support single DB_DATA'
common_utils.sa_create(f"shm://{sa_key}", obj_points) db_data_path = self.root_path.resolve() / self.sampler_cfg.DB_DATA_PATH[0]
sa_key = self.sampler_cfg.DB_DATA_PATH[0]
dist.barrier()
if cur_rank % num_gpus == 0 and not os.path.exists(f"/dev/shm/{sa_key}"):
gt_database_data = np.load(db_data_path)
common_utils.sa_create(f"shm://{sa_key}", gt_database_data)
if num_gpus > 1:
dist.barrier()
self.logger.info('GT database has been saved to shared memory') self.logger.info('GT database has been saved to shared memory')
return sa_key
def filter_by_difficulty(self, db_infos, removed_difficulty): def filter_by_difficulty(self, db_infos, removed_difficulty):
new_db_infos = {} new_db_infos = {}
...@@ -155,10 +153,15 @@ class DataBaseSampler(object): ...@@ -155,10 +153,15 @@ class DataBaseSampler(object):
data_dict.pop('road_plane') data_dict.pop('road_plane')
obj_points_list = [] obj_points_list = []
if self.use_shared_memory:
gt_database_data = SharedArray.attach(f"shm://{self.gt_database_data_key}")
else:
gt_database_data = None
for idx, info in enumerate(total_valid_sampled_dict): for idx, info in enumerate(total_valid_sampled_dict):
if self.use_shared_memory: if self.use_shared_memory:
sa_key = info['path'].replace('/', '___') start_offset, end_offset = info['global_data_offset']
obj_points = SharedArray.attach(f"shm://{sa_key}").copy() obj_points = gt_database_data[start_offset:end_offset]
else: else:
file_path = self.root_path / info['path'] file_path = self.root_path / info['path']
obj_points = np.fromfile(str(file_path), dtype=np.float32).reshape( obj_points = np.fromfile(str(file_path), dtype=np.float32).reshape(
......
...@@ -24,6 +24,12 @@ DATA_AUGMENTOR: ...@@ -24,6 +24,12 @@ DATA_AUGMENTOR:
USE_ROAD_PLANE: False USE_ROAD_PLANE: False
DB_INFO_PATH: DB_INFO_PATH:
- waymo_processed_data_v0_3_1_waymo_dbinfos_train_sampled_1.pkl - waymo_processed_data_v0_3_1_waymo_dbinfos_train_sampled_1.pkl
# - waymo_processed_data_v0_3_1_waymo_dbinfos_train_sampled_1_global.pkl
USE_SHARED_MEMORY: False
DB_DATA_PATH:
- waymo_processed_data_v0_3_1_waymo_dbinfos_train_sampled_1_global.npy
PREPARE: { PREPARE: {
filter_by_min_points: ['Vehicle:5', 'Pedestrian:5', 'Cyclist:5'], filter_by_min_points: ['Vehicle:5', 'Pedestrian:5', 'Cyclist:5'],
filter_by_difficulty: [-1], filter_by_difficulty: [-1],
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment