Commit 2e4d090e authored by Shaoshuai Shi's avatar Shaoshuai Shi
Browse files

support SharedMemory option for WOD Data and GT_database to speed up the training process

parents 14e2adc5 14fdedca
import pickle import pickle
import os
import copy
import numpy as np import numpy as np
import SharedArray
import torch.distributed as dist
from ...ops.iou3d_nms import iou3d_nms_utils from ...ops.iou3d_nms import iou3d_nms_utils
from ...utils import box_utils from ...utils import box_utils, common_utils
class DataBaseSampler(object): class DataBaseSampler(object):
...@@ -15,7 +19,7 @@ class DataBaseSampler(object): ...@@ -15,7 +19,7 @@ class DataBaseSampler(object):
self.db_infos = {} self.db_infos = {}
for class_name in class_names: for class_name in class_names:
self.db_infos[class_name] = [] self.db_infos[class_name] = []
for db_info_path in sampler_cfg.DB_INFO_PATH: for db_info_path in sampler_cfg.DB_INFO_PATH:
db_info_path = self.root_path.resolve() / db_info_path db_info_path = self.root_path.resolve() / db_info_path
with open(str(db_info_path), 'rb') as f: with open(str(db_info_path), 'rb') as f:
...@@ -25,9 +29,13 @@ class DataBaseSampler(object): ...@@ -25,9 +29,13 @@ class DataBaseSampler(object):
for func_name, val in sampler_cfg.PREPARE.items(): for func_name, val in sampler_cfg.PREPARE.items():
self.db_infos = getattr(self, func_name)(self.db_infos, val) self.db_infos = getattr(self, func_name)(self.db_infos, val)
self.use_shared_memory = sampler_cfg.get('USE_SHARED_MEMORY', False)
self.gt_database_data_key = self.load_db_to_shared_memory() if self.use_shared_memory else None
self.sample_groups = {} self.sample_groups = {}
self.sample_class_num = {} self.sample_class_num = {}
self.limit_whole_scene = sampler_cfg.get('LIMIT_WHOLE_SCENE', False) self.limit_whole_scene = sampler_cfg.get('LIMIT_WHOLE_SCENE', False)
for x in sampler_cfg.SAMPLE_GROUPS: for x in sampler_cfg.SAMPLE_GROUPS:
class_name, sample_num = x.split(':') class_name, sample_num = x.split(':')
if class_name not in class_names: if class_name not in class_names:
...@@ -47,6 +55,35 @@ class DataBaseSampler(object): ...@@ -47,6 +55,35 @@ class DataBaseSampler(object):
def __setstate__(self, d): def __setstate__(self, d):
self.__dict__.update(d) self.__dict__.update(d)
def __del__(self):
    """Remove the GT database from shared memory on destruction.

    Mirrors load_db_to_shared_memory(): only the local rank-0 process of
    each machine (cur_rank % num_gpus == 0) touches /dev/shm, and all
    ranks synchronize afterwards so no process attaches to a deleted
    segment.
    """
    if self.use_shared_memory:
        self.logger.info('Deleting GT database from shared memory')
        # Use GPUs-per-machine (not world size) so that on multi-node
        # training every machine cleans up its own /dev/shm copy —
        # consistent with load_db_to_shared_memory().
        cur_rank, world_size, num_gpus = common_utils.get_dist_info(return_gpu_per_machine=True)
        sa_key = self.sampler_cfg.DB_DATA_PATH[0]
        if cur_rank % num_gpus == 0 and os.path.exists(f"/dev/shm/{sa_key}"):
            SharedArray.delete(f"shm://{sa_key}")

        # Barrier only when a distributed process group exists.
        if world_size > 1:
            dist.barrier()
        self.logger.info('GT database has been removed from shared memory')
def load_db_to_shared_memory(self):
    """Publish the single GT-database .npy file into POSIX shared memory.

    Only the local rank-0 process per machine performs the load; every
    rank then waits on a barrier before the shared-memory key is handed
    back to the caller.

    Returns:
        str: the shared-memory key (the DB_DATA_PATH filename) under
        which the array can later be attached via ``shm://<key>``.
    """
    self.logger.info('Loading GT database to shared memory')
    cur_rank, world_size, num_gpus = common_utils.get_dist_info(return_gpu_per_machine=True)

    # A single flat database file is required for the shared-memory path.
    assert len(self.sampler_cfg.DB_DATA_PATH) == 1, 'Current only support single DB_DATA'
    sa_key = self.sampler_cfg.DB_DATA_PATH[0]
    db_data_path = self.root_path.resolve() / sa_key

    # One process per machine creates the segment; others reuse it.
    already_loaded = os.path.exists(f"/dev/shm/{sa_key}")
    if cur_rank % num_gpus == 0 and not already_loaded:
        db_array = np.load(db_data_path)
        common_utils.sa_create(f"shm://{sa_key}", db_array)

    if num_gpus > 1:
        dist.barrier()
    self.logger.info('GT database has been saved to shared memory')
    return sa_key
def filter_by_difficulty(self, db_infos, removed_difficulty): def filter_by_difficulty(self, db_infos, removed_difficulty):
new_db_infos = {} new_db_infos = {}
for key, dinfos in db_infos.items(): for key, dinfos in db_infos.items():
...@@ -128,10 +165,20 @@ class DataBaseSampler(object): ...@@ -128,10 +165,20 @@ class DataBaseSampler(object):
data_dict.pop('road_plane') data_dict.pop('road_plane')
obj_points_list = [] obj_points_list = []
if self.use_shared_memory:
gt_database_data = SharedArray.attach(f"shm://{self.gt_database_data_key}")
gt_database_data.setflags(write=0)
else:
gt_database_data = None
for idx, info in enumerate(total_valid_sampled_dict): for idx, info in enumerate(total_valid_sampled_dict):
file_path = self.root_path / info['path'] if self.use_shared_memory:
obj_points = np.fromfile(str(file_path), dtype=np.float32).reshape( start_offset, end_offset = info['global_data_offset']
[-1, self.sampler_cfg.NUM_POINT_FEATURES]) obj_points = copy.deepcopy(gt_database_data[start_offset:end_offset])
else:
file_path = self.root_path / info['path']
obj_points = np.fromfile(str(file_path), dtype=np.float32).reshape(
[-1, self.sampler_cfg.NUM_POINT_FEATURES])
obj_points[:, :3] += info['box3d_lidar'][:3] obj_points[:, :3] += info['box3d_lidar'][:3]
......
...@@ -9,6 +9,8 @@ import copy ...@@ -9,6 +9,8 @@ import copy
import numpy as np import numpy as np
import torch import torch
import multiprocessing import multiprocessing
import SharedArray
import torch.distributed as dist
from tqdm import tqdm from tqdm import tqdm
from pathlib import Path from pathlib import Path
from ...ops.roiaware_pool3d import roiaware_pool3d_utils from ...ops.roiaware_pool3d import roiaware_pool3d_utils
...@@ -29,6 +31,11 @@ class WaymoDataset(DatasetTemplate): ...@@ -29,6 +31,11 @@ class WaymoDataset(DatasetTemplate):
self.infos = [] self.infos = []
self.include_waymo_data(self.mode) self.include_waymo_data(self.mode)
self.use_shared_memory = self.dataset_cfg.get('USE_SHARED_MEMORY', False) and self.training
if self.use_shared_memory:
self.shared_memory_file_limit = self.dataset_cfg.get('SHARED_MEMORY_FILE_LIMIT', 0x7FFFFFFF)
self.load_data_to_shared_memory()
def set_split(self, split): def set_split(self, split):
super().__init__( super().__init__(
dataset_cfg=self.dataset_cfg, class_names=self.class_names, training=self.training, dataset_cfg=self.dataset_cfg, class_names=self.class_names, training=self.training,
...@@ -67,6 +74,50 @@ class WaymoDataset(DatasetTemplate): ...@@ -67,6 +74,50 @@ class WaymoDataset(DatasetTemplate):
self.infos = sampled_waymo_infos self.infos = sampled_waymo_infos
self.logger.info('Total sampled samples for Waymo dataset: %d' % len(self.infos)) self.logger.info('Total sampled samples for Waymo dataset: %d' % len(self.infos))
def load_data_to_shared_memory(self):
    """Cache up to ``shared_memory_file_limit`` training point clouds in /dev/shm.

    The frame list is sharded round-robin across ranks
    (``all_infos[cur_rank::num_gpus]``) so each process loads a disjoint
    subset; keys already present in /dev/shm are skipped, so repeated runs
    reuse the cache.
    """
    self.logger.info(f'Loading training data to shared memory (file limit={self.shared_memory_file_limit})')
    cur_rank, num_gpus = common_utils.get_dist_info()
    all_infos = self.infos[:self.shared_memory_file_limit] \
        if self.shared_memory_file_limit < len(self.infos) else self.infos
    cur_infos = all_infos[cur_rank::num_gpus]
    for info in cur_infos:
        pc_info = info['point_cloud']
        sequence_name = pc_info['lidar_sequence']
        sample_idx = pc_info['sample_idx']

        sa_key = f'{sequence_name}___{sample_idx}'
        if os.path.exists(f"/dev/shm/{sa_key}"):
            continue

        points = self.get_lidar(sequence_name, sample_idx)
        common_utils.sa_create(f"shm://{sa_key}", points)

    # Only synchronize when a distributed process group actually exists;
    # an unconditional barrier crashes single-GPU (non-distributed) runs.
    # Consistent with clean_shared_memory().
    if num_gpus > 1:
        dist.barrier()
    self.logger.info('Training data has been saved to shared memory')
def clean_shared_memory(self):
    """Delete this rank's cached training point clouds from /dev/shm.

    Walks the same round-robin shard of frames that
    load_data_to_shared_memory() populated and removes each existing
    shared-memory segment, then synchronizes all ranks.
    """
    self.logger.info(f'Clean training data from shared memory (file limit={self.shared_memory_file_limit})')
    cur_rank, num_gpus = common_utils.get_dist_info()

    limit = self.shared_memory_file_limit
    all_infos = self.infos if limit >= len(self.infos) else self.infos[:limit]

    # Same sharding as the loader: each rank removes only what it created.
    for info in all_infos[cur_rank::num_gpus]:
        pc_info = info['point_cloud']
        sa_key = f"{pc_info['lidar_sequence']}___{pc_info['sample_idx']}"
        if os.path.exists(f"/dev/shm/{sa_key}"):
            SharedArray.delete(f"shm://{sa_key}")

    if num_gpus > 1:
        dist.barrier()
    self.logger.info('Training data has been deleted from shared memory')
@staticmethod @staticmethod
def check_sequence_name_with_all_version(sequence_file): def check_sequence_name_with_all_version(sequence_file):
if not sequence_file.exists(): if not sequence_file.exists():
...@@ -128,7 +179,12 @@ class WaymoDataset(DatasetTemplate): ...@@ -128,7 +179,12 @@ class WaymoDataset(DatasetTemplate):
pc_info = info['point_cloud'] pc_info = info['point_cloud']
sequence_name = pc_info['lidar_sequence'] sequence_name = pc_info['lidar_sequence']
sample_idx = pc_info['sample_idx'] sample_idx = pc_info['sample_idx']
points = self.get_lidar(sequence_name, sample_idx)
if self.use_shared_memory and index < self.shared_memory_file_limit:
sa_key = f'{sequence_name}___{sample_idx}'
points = SharedArray.attach(f"shm://{sa_key}").copy()
else:
points = self.get_lidar(sequence_name, sample_idx)
input_dict = { input_dict = {
'points': points, 'points': points,
......
...@@ -4,6 +4,7 @@ import pickle ...@@ -4,6 +4,7 @@ import pickle
import random import random
import shutil import shutil
import subprocess import subprocess
import SharedArray
import numpy as np import numpy as np
import torch import torch
...@@ -172,7 +173,7 @@ def init_dist_pytorch(tcp_port, local_rank, backend='nccl'): ...@@ -172,7 +173,7 @@ def init_dist_pytorch(tcp_port, local_rank, backend='nccl'):
return num_gpus, rank return num_gpus, rank
def get_dist_info(): def get_dist_info(return_gpu_per_machine=False):
if torch.__version__ < '1.0': if torch.__version__ < '1.0':
initialized = dist._initialized initialized = dist._initialized
else: else:
...@@ -186,6 +187,11 @@ def get_dist_info(): ...@@ -186,6 +187,11 @@ def get_dist_info():
else: else:
rank = 0 rank = 0
world_size = 1 world_size = 1
if return_gpu_per_machine:
gpu_per_machine = torch.cuda.device_count()
return rank, world_size, gpu_per_machine
return rank, world_size return rank, world_size
...@@ -233,3 +239,9 @@ def generate_voxel2pinds(sparse_tensor): ...@@ -233,3 +239,9 @@ def generate_voxel2pinds(sparse_tensor):
return v2pinds_tensor return v2pinds_tensor
def sa_create(name, var):
    """Create a read-only shared-memory copy of ``var``.

    Args:
        name: SharedArray URI, e.g. ``"shm://some_key"``.
        var: numpy array whose shape, dtype and contents are copied.

    Returns:
        The newly created shared array, marked non-writeable so readers
        cannot mutate the shared segment by accident.
    """
    shared = SharedArray.create(name, var.shape, dtype=var.dtype)
    shared[...] = var[...]
    shared.flags.writeable = False
    return shared
...@@ -8,3 +8,4 @@ scikit-image ...@@ -8,3 +8,4 @@ scikit-image
tqdm tqdm
kornia kornia
torchvision torchvision
SharedArray
...@@ -14,6 +14,9 @@ SAMPLED_INTERVAL: { ...@@ -14,6 +14,9 @@ SAMPLED_INTERVAL: {
'test': 5 'test': 5
} }
USE_SHARED_MEMORY: False
SHARED_MEMORY_FILE_LIMIT: 35000 # set it based on the size of your shared memory
DATA_AUGMENTOR: DATA_AUGMENTOR:
DISABLE_AUG_LIST: ['placeholder'] DISABLE_AUG_LIST: ['placeholder']
AUG_CONFIG_LIST: AUG_CONFIG_LIST:
...@@ -21,6 +24,12 @@ DATA_AUGMENTOR: ...@@ -21,6 +24,12 @@ DATA_AUGMENTOR:
USE_ROAD_PLANE: False USE_ROAD_PLANE: False
DB_INFO_PATH: DB_INFO_PATH:
- waymo_processed_data_v0_5_0_waymo_dbinfos_train_sampled_1.pkl - waymo_processed_data_v0_5_0_waymo_dbinfos_train_sampled_1.pkl
# - waymo_processed_data_v0_3_1_waymo_dbinfos_train_sampled_1_global.pkl
USE_SHARED_MEMORY: False
DB_DATA_PATH: # this file should be generated along with the above DB_INFO_PATH by setting USE_SHARED_MEMORY=True
- waymo_processed_data_v0_3_1_gt_database_train_sampled_1_global.npy
PREPARE: { PREPARE: {
filter_by_min_points: ['Vehicle:5', 'Pedestrian:5', 'Cyclist:5'], filter_by_min_points: ['Vehicle:5', 'Pedestrian:5', 'Cyclist:5'],
filter_by_difficulty: [-1], filter_by_difficulty: [-1],
......
...@@ -193,6 +193,8 @@ def main(): ...@@ -193,6 +193,8 @@ def main():
logger.info('**********************End evaluation %s/%s(%s)**********************' % logger.info('**********************End evaluation %s/%s(%s)**********************' %
(cfg.EXP_GROUP_PATH, cfg.TAG, args.extra_tag)) (cfg.EXP_GROUP_PATH, cfg.TAG, args.extra_tag))
if hasattr(train_set, 'use_shared_memory') and train_set.use_shared_memory:
train_set.clean_shared_memory()
if __name__ == '__main__': if __name__ == '__main__':
main() main()
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment