"googlemock/git@developer.sourcefind.cn:yangql/googletest.git" did not exist on "0c799d0436e1b6d867c1738f6ff58166d153cacc"
Commit 70d5aef3 authored by Shaoshuai Shi's avatar Shaoshuai Shi
Browse files

support USE_SHARED_MEMORY=True in DBSampler with Global GT database

parent 9fc9f152
...@@ -25,12 +25,12 @@ class DataBaseSampler(object): ...@@ -25,12 +25,12 @@ class DataBaseSampler(object):
infos = pickle.load(f) infos = pickle.load(f)
[self.db_infos[cur_class].extend(infos[cur_class]) for cur_class in class_names] [self.db_infos[cur_class].extend(infos[cur_class]) for cur_class in class_names]
for func_name, val in sampler_cfg.PREPARE.items(): for func_name, val in sampler_cfg.PREPARE.items():
self.db_infos = getattr(self, func_name)(self.db_infos, val) self.db_infos = getattr(self, func_name)(self.db_infos, val)
self.use_shared_memory = sampler_cfg.get('USE_SHARED_MEMORY', False) self.use_shared_memory = sampler_cfg.get('USE_SHARED_MEMORY', False)
if self.use_shared_memory: self.gt_database_data_key = self.load_db_to_shared_memory() if self.use_shared_memory else None
self.load_db_to_shared_memory()
self.sample_groups = {} self.sample_groups = {}
self.sample_class_num = {} self.sample_class_num = {}
...@@ -58,21 +58,19 @@ class DataBaseSampler(object): ...@@ -58,21 +58,19 @@ class DataBaseSampler(object):
def load_db_to_shared_memory(self): def load_db_to_shared_memory(self):
self.logger.info('Loading GT database to shared memory') self.logger.info('Loading GT database to shared memory')
cur_rank, num_gpus = common_utils.get_dist_info() cur_rank, num_gpus = common_utils.get_dist_info()
for cur_class in self.class_names:
cur_info_list = self.db_infos[cur_class]
cur_info_list = cur_info_list[cur_rank::num_gpus]
for info in cur_info_list: assert self.sampler_cfg.DB_DATA_PATH.__len__() == 1, 'Current only support single DB_DATA'
file_path = self.root_path / info['path'] db_data_path = self.root_path.resolve() / self.sampler_cfg.DB_DATA_PATH[0]
sa_key = info['path'].replace('/', '___') sa_key = self.sampler_cfg.DB_DATA_PATH[0]
if os.path.exists(f"/dev/shm/{sa_key}"):
continue
obj_points = np.fromfile(str(file_path), dtype=np.float32).reshape([-1, self.sampler_cfg.NUM_POINT_FEATURES]) if cur_rank % num_gpus == 0 and not os.path.exists(f"/dev/shm/{sa_key}"):
common_utils.sa_create(f"shm://{sa_key}", obj_points) gt_database_data = np.load(db_data_path)
common_utils.sa_create(f"shm://{sa_key}", gt_database_data)
if num_gpus > 1:
dist.barrier() dist.barrier()
self.logger.info('GT database has been saved to shared memory') self.logger.info('GT database has been saved to shared memory')
return sa_key
def filter_by_difficulty(self, db_infos, removed_difficulty): def filter_by_difficulty(self, db_infos, removed_difficulty):
new_db_infos = {} new_db_infos = {}
...@@ -155,10 +153,15 @@ class DataBaseSampler(object): ...@@ -155,10 +153,15 @@ class DataBaseSampler(object):
data_dict.pop('road_plane') data_dict.pop('road_plane')
obj_points_list = [] obj_points_list = []
if self.use_shared_memory:
gt_database_data = SharedArray.attach(f"shm://{self.gt_database_data_key}")
else:
gt_database_data = None
for idx, info in enumerate(total_valid_sampled_dict): for idx, info in enumerate(total_valid_sampled_dict):
if self.use_shared_memory: if self.use_shared_memory:
sa_key = info['path'].replace('/', '___') start_offset, end_offset = info['global_data_offset']
obj_points = SharedArray.attach(f"shm://{sa_key}").copy() obj_points = gt_database_data[start_offset:end_offset]
else: else:
file_path = self.root_path / info['path'] file_path = self.root_path / info['path']
obj_points = np.fromfile(str(file_path), dtype=np.float32).reshape( obj_points = np.fromfile(str(file_path), dtype=np.float32).reshape(
......
...@@ -24,6 +24,12 @@ DATA_AUGMENTOR: ...@@ -24,6 +24,12 @@ DATA_AUGMENTOR:
USE_ROAD_PLANE: False USE_ROAD_PLANE: False
DB_INFO_PATH: DB_INFO_PATH:
- waymo_processed_data_v0_3_1_waymo_dbinfos_train_sampled_1.pkl - waymo_processed_data_v0_3_1_waymo_dbinfos_train_sampled_1.pkl
# - waymo_processed_data_v0_3_1_waymo_dbinfos_train_sampled_1_global.pkl
USE_SHARED_MEMORY: False
DB_DATA_PATH:
- waymo_processed_data_v0_3_1_waymo_dbinfos_train_sampled_1_global.npy
PREPARE: { PREPARE: {
filter_by_min_points: ['Vehicle:5', 'Pedestrian:5', 'Cyclist:5'], filter_by_min_points: ['Vehicle:5', 'Pedestrian:5', 'Cyclist:5'],
filter_by_difficulty: [-1], filter_by_difficulty: [-1],
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment