Commit 880257ab authored by Shaoshuai Shi's avatar Shaoshuai Shi
Browse files

support USE_SHARED_MEMORY=True in WaymoDataset

parent 01b425bd
......@@ -9,6 +9,8 @@ import copy
import numpy as np
import torch
import multiprocessing
import SharedArray
import torch.distributed as dist
from tqdm import tqdm
from pathlib import Path
from ...ops.roiaware_pool3d import roiaware_pool3d_utils
......@@ -29,6 +31,11 @@ class WaymoDataset(DatasetTemplate):
self.infos = []
self.include_waymo_data(self.mode)
self.use_shared_memory = self.dataset_cfg.get('USE_SHARED_MEMORY', False) and self.training
if self.use_shared_memory:
self.shared_memory_file_limit = self.dataset_cfg.get('SHARED_MEMORY_FILE_LIMIT', 0x7FFFFFFF)
self.load_data_to_shared_memory()
def set_split(self, split):
super().__init__(
dataset_cfg=self.dataset_cfg, class_names=self.class_names, training=self.training,
......@@ -67,6 +74,28 @@ class WaymoDataset(DatasetTemplate):
self.infos = sampled_waymo_infos
self.logger.info('Total sampled samples for Waymo dataset: %d' % len(self.infos))
def load_data_to_shared_memory(self):
self.logger.info(f'Loading training data to shared memory (file limit={self.shared_memory_file_limit})')
cur_rank, num_gpus = common_utils.get_dist_info()
all_infos = self.infos[:self.shared_memory_file_limit] \
if self.shared_memory_file_limit < len(self.infos) else self.infos
cur_infos = all_infos[cur_rank::num_gpus]
for info in cur_infos:
pc_info = info['point_cloud']
sequence_name = pc_info['lidar_sequence']
sample_idx = pc_info['sample_idx']
sa_key = f'{sequence_name}___{sample_idx}'
if os.path.exists(f"/dev/shm/{sa_key}"):
continue
points = self.get_lidar(sequence_name, sample_idx)
common_utils.sa_create(f"shm://{sa_key}", points)
dist.barrier()
self.logger.info('Training data has been saved to shared memory')
@staticmethod
def check_sequence_name_with_all_version(sequence_file):
if not sequence_file.exists():
......@@ -128,7 +157,12 @@ class WaymoDataset(DatasetTemplate):
pc_info = info['point_cloud']
sequence_name = pc_info['lidar_sequence']
sample_idx = pc_info['sample_idx']
points = self.get_lidar(sequence_name, sample_idx)
if self.use_shared_memory and index < self.shared_memory_file_limit:
sa_key = f'{sequence_name}___{sample_idx}'
points = SharedArray.attach(f"shm://{sa_key}").copy()
else:
points = self.get_lidar(sequence_name, sample_idx)
input_dict = {
'points': points,
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment