Commit 2e4d090e authored by Shaoshuai Shi's avatar Shaoshuai Shi
Browse files

support SharedMemory option for WOD Data and GT_database to speed up the training process

parents 14e2adc5 14fdedca
import pickle
import os
import copy
import numpy as np
import SharedArray
import torch.distributed as dist
from ...ops.iou3d_nms import iou3d_nms_utils
from ...utils import box_utils
from ...utils import box_utils, common_utils
class DataBaseSampler(object):
......@@ -25,9 +29,13 @@ class DataBaseSampler(object):
for func_name, val in sampler_cfg.PREPARE.items():
self.db_infos = getattr(self, func_name)(self.db_infos, val)
self.use_shared_memory = sampler_cfg.get('USE_SHARED_MEMORY', False)
self.gt_database_data_key = self.load_db_to_shared_memory() if self.use_shared_memory else None
self.sample_groups = {}
self.sample_class_num = {}
self.limit_whole_scene = sampler_cfg.get('LIMIT_WHOLE_SCENE', False)
for x in sampler_cfg.SAMPLE_GROUPS:
class_name, sample_num = x.split(':')
if class_name not in class_names:
......@@ -47,6 +55,35 @@ class DataBaseSampler(object):
def __setstate__(self, d):
self.__dict__.update(d)
def __del__(self):
if self.use_shared_memory:
self.logger.info('Deleting GT database from shared memory')
cur_rank, num_gpus = common_utils.get_dist_info()
sa_key = self.sampler_cfg.DB_DATA_PATH[0]
if cur_rank % num_gpus == 0 and os.path.exists(f"/dev/shm/{sa_key}"):
SharedArray.delete(f"shm://{sa_key}")
if num_gpus > 1:
dist.barrier()
self.logger.info('GT database has been removed from shared memory')
def load_db_to_shared_memory(self):
self.logger.info('Loading GT database to shared memory')
cur_rank, world_size, num_gpus = common_utils.get_dist_info(return_gpu_per_machine=True)
assert self.sampler_cfg.DB_DATA_PATH.__len__() == 1, 'Current only support single DB_DATA'
db_data_path = self.root_path.resolve() / self.sampler_cfg.DB_DATA_PATH[0]
sa_key = self.sampler_cfg.DB_DATA_PATH[0]
if cur_rank % num_gpus == 0 and not os.path.exists(f"/dev/shm/{sa_key}"):
gt_database_data = np.load(db_data_path)
common_utils.sa_create(f"shm://{sa_key}", gt_database_data)
if num_gpus > 1:
dist.barrier()
self.logger.info('GT database has been saved to shared memory')
return sa_key
def filter_by_difficulty(self, db_infos, removed_difficulty):
new_db_infos = {}
for key, dinfos in db_infos.items():
......@@ -128,7 +165,17 @@ class DataBaseSampler(object):
data_dict.pop('road_plane')
obj_points_list = []
if self.use_shared_memory:
gt_database_data = SharedArray.attach(f"shm://{self.gt_database_data_key}")
gt_database_data.setflags(write=0)
else:
gt_database_data = None
for idx, info in enumerate(total_valid_sampled_dict):
if self.use_shared_memory:
start_offset, end_offset = info['global_data_offset']
obj_points = copy.deepcopy(gt_database_data[start_offset:end_offset])
else:
file_path = self.root_path / info['path']
obj_points = np.fromfile(str(file_path), dtype=np.float32).reshape(
[-1, self.sampler_cfg.NUM_POINT_FEATURES])
......
......@@ -9,6 +9,8 @@ import copy
import numpy as np
import torch
import multiprocessing
import SharedArray
import torch.distributed as dist
from tqdm import tqdm
from pathlib import Path
from ...ops.roiaware_pool3d import roiaware_pool3d_utils
......@@ -29,6 +31,11 @@ class WaymoDataset(DatasetTemplate):
self.infos = []
self.include_waymo_data(self.mode)
self.use_shared_memory = self.dataset_cfg.get('USE_SHARED_MEMORY', False) and self.training
if self.use_shared_memory:
self.shared_memory_file_limit = self.dataset_cfg.get('SHARED_MEMORY_FILE_LIMIT', 0x7FFFFFFF)
self.load_data_to_shared_memory()
def set_split(self, split):
super().__init__(
dataset_cfg=self.dataset_cfg, class_names=self.class_names, training=self.training,
......@@ -67,6 +74,50 @@ class WaymoDataset(DatasetTemplate):
self.infos = sampled_waymo_infos
self.logger.info('Total sampled samples for Waymo dataset: %d' % len(self.infos))
def load_data_to_shared_memory(self):
self.logger.info(f'Loading training data to shared memory (file limit={self.shared_memory_file_limit})')
cur_rank, num_gpus = common_utils.get_dist_info()
all_infos = self.infos[:self.shared_memory_file_limit] \
if self.shared_memory_file_limit < len(self.infos) else self.infos
cur_infos = all_infos[cur_rank::num_gpus]
for info in cur_infos:
pc_info = info['point_cloud']
sequence_name = pc_info['lidar_sequence']
sample_idx = pc_info['sample_idx']
sa_key = f'{sequence_name}___{sample_idx}'
if os.path.exists(f"/dev/shm/{sa_key}"):
continue
points = self.get_lidar(sequence_name, sample_idx)
common_utils.sa_create(f"shm://{sa_key}", points)
dist.barrier()
self.logger.info('Training data has been saved to shared memory')
def clean_shared_memory(self):
self.logger.info(f'Clean training data from shared memory (file limit={self.shared_memory_file_limit})')
cur_rank, num_gpus = common_utils.get_dist_info()
all_infos = self.infos[:self.shared_memory_file_limit] \
if self.shared_memory_file_limit < len(self.infos) else self.infos
cur_infos = all_infos[cur_rank::num_gpus]
for info in cur_infos:
pc_info = info['point_cloud']
sequence_name = pc_info['lidar_sequence']
sample_idx = pc_info['sample_idx']
sa_key = f'{sequence_name}___{sample_idx}'
if not os.path.exists(f"/dev/shm/{sa_key}"):
continue
SharedArray.delete(f"shm://{sa_key}")
if num_gpus > 1:
dist.barrier()
self.logger.info('Training data has been deleted from shared memory')
@staticmethod
def check_sequence_name_with_all_version(sequence_file):
if not sequence_file.exists():
......@@ -128,6 +179,11 @@ class WaymoDataset(DatasetTemplate):
pc_info = info['point_cloud']
sequence_name = pc_info['lidar_sequence']
sample_idx = pc_info['sample_idx']
if self.use_shared_memory and index < self.shared_memory_file_limit:
sa_key = f'{sequence_name}___{sample_idx}'
points = SharedArray.attach(f"shm://{sa_key}").copy()
else:
points = self.get_lidar(sequence_name, sample_idx)
input_dict = {
......
......@@ -4,6 +4,7 @@ import pickle
import random
import shutil
import subprocess
import SharedArray
import numpy as np
import torch
......@@ -172,7 +173,7 @@ def init_dist_pytorch(tcp_port, local_rank, backend='nccl'):
return num_gpus, rank
def get_dist_info():
def get_dist_info(return_gpu_per_machine=False):
if torch.__version__ < '1.0':
initialized = dist._initialized
else:
......@@ -186,6 +187,11 @@ def get_dist_info():
else:
rank = 0
world_size = 1
if return_gpu_per_machine:
gpu_per_machine = torch.cuda.device_count()
return rank, world_size, gpu_per_machine
return rank, world_size
......@@ -233,3 +239,9 @@ def generate_voxel2pinds(sparse_tensor):
return v2pinds_tensor
def sa_create(name, var):
x = SharedArray.create(name, var.shape, dtype=var.dtype)
x[...] = var[...]
x.flags.writeable = False
return x
......@@ -14,6 +14,9 @@ SAMPLED_INTERVAL: {
'test': 5
}
USE_SHARED_MEMORY: False
SHARED_MEMORY_FILE_LIMIT: 35000 # set it based on the size of your shared memory
DATA_AUGMENTOR:
DISABLE_AUG_LIST: ['placeholder']
AUG_CONFIG_LIST:
......@@ -21,6 +24,12 @@ DATA_AUGMENTOR:
USE_ROAD_PLANE: False
DB_INFO_PATH:
- waymo_processed_data_v0_5_0_waymo_dbinfos_train_sampled_1.pkl
# - waymo_processed_data_v0_3_1_waymo_dbinfos_train_sampled_1_global.pkl
USE_SHARED_MEMORY: False
DB_DATA_PATH: # this file should be generated along with the above DB_INFO_PATH by setting USE_SHARED_MEMORY=True
- waymo_processed_data_v0_3_1_gt_database_train_sampled_1_global.npy
PREPARE: {
filter_by_min_points: ['Vehicle:5', 'Pedestrian:5', 'Cyclist:5'],
filter_by_difficulty: [-1],
......
......@@ -193,6 +193,8 @@ def main():
logger.info('**********************End evaluation %s/%s(%s)**********************' %
(cfg.EXP_GROUP_PATH, cfg.TAG, args.extra_tag))
if hasattr(train_set, 'use_shared_memory') and train_set.use_shared_memory:
train_set.clean_shared_memory()
if __name__ == '__main__':
main()
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment