Commit 880257ab authored by Shaoshuai Shi's avatar Shaoshuai Shi
Browse files

support USE_SHARED_MEMORY=True in WaymoDataset

parent 01b425bd
...@@ -9,6 +9,8 @@ import copy ...@@ -9,6 +9,8 @@ import copy
import numpy as np import numpy as np
import torch import torch
import multiprocessing import multiprocessing
import SharedArray
import torch.distributed as dist
from tqdm import tqdm from tqdm import tqdm
from pathlib import Path from pathlib import Path
from ...ops.roiaware_pool3d import roiaware_pool3d_utils from ...ops.roiaware_pool3d import roiaware_pool3d_utils
...@@ -29,6 +31,11 @@ class WaymoDataset(DatasetTemplate): ...@@ -29,6 +31,11 @@ class WaymoDataset(DatasetTemplate):
self.infos = [] self.infos = []
self.include_waymo_data(self.mode) self.include_waymo_data(self.mode)
self.use_shared_memory = self.dataset_cfg.get('USE_SHARED_MEMORY', False) and self.training
if self.use_shared_memory:
self.shared_memory_file_limit = self.dataset_cfg.get('SHARED_MEMORY_FILE_LIMIT', 0x7FFFFFFF)
self.load_data_to_shared_memory()
def set_split(self, split): def set_split(self, split):
super().__init__( super().__init__(
dataset_cfg=self.dataset_cfg, class_names=self.class_names, training=self.training, dataset_cfg=self.dataset_cfg, class_names=self.class_names, training=self.training,
...@@ -67,6 +74,28 @@ class WaymoDataset(DatasetTemplate): ...@@ -67,6 +74,28 @@ class WaymoDataset(DatasetTemplate):
self.infos = sampled_waymo_infos self.infos = sampled_waymo_infos
self.logger.info('Total sampled samples for Waymo dataset: %d' % len(self.infos)) self.logger.info('Total sampled samples for Waymo dataset: %d' % len(self.infos))
def load_data_to_shared_memory(self):
    """Pre-load this rank's slice of the training point clouds into /dev/shm.

    The (optionally capped) info list is split round-robin across GPU ranks so
    every sample is written exactly once cluster-wide. Keys already present in
    /dev/shm are skipped, so a restarted job reuses the existing cache. All
    ranks synchronize on a barrier before returning.
    """
    self.logger.info(f'Loading training data to shared memory (file limit={self.shared_memory_file_limit})')

    rank, world_size = common_utils.get_dist_info()
    # Cap the number of cached samples at the configured limit.
    capped_infos = self.infos if self.shared_memory_file_limit >= len(self.infos) \
        else self.infos[:self.shared_memory_file_limit]

    # Round-robin split: this rank takes every world_size-th sample.
    for info in capped_infos[rank::world_size]:
        pc_info = info['point_cloud']
        seq_name, sample_idx = pc_info['lidar_sequence'], pc_info['sample_idx']
        sa_key = f'{seq_name}___{sample_idx}'
        if os.path.exists(f"/dev/shm/{sa_key}"):
            continue  # already cached by an earlier run — don't rewrite

        cloud = self.get_lidar(seq_name, sample_idx)
        common_utils.sa_create(f"shm://{sa_key}", cloud)

    # Wait for every rank to finish writing its share before training starts.
    dist.barrier()
    self.logger.info('Training data has been saved to shared memory')
@staticmethod @staticmethod
def check_sequence_name_with_all_version(sequence_file): def check_sequence_name_with_all_version(sequence_file):
if not sequence_file.exists(): if not sequence_file.exists():
...@@ -128,6 +157,11 @@ class WaymoDataset(DatasetTemplate): ...@@ -128,6 +157,11 @@ class WaymoDataset(DatasetTemplate):
pc_info = info['point_cloud'] pc_info = info['point_cloud']
sequence_name = pc_info['lidar_sequence'] sequence_name = pc_info['lidar_sequence']
sample_idx = pc_info['sample_idx'] sample_idx = pc_info['sample_idx']
if self.use_shared_memory and index < self.shared_memory_file_limit:
sa_key = f'{sequence_name}___{sample_idx}'
points = SharedArray.attach(f"shm://{sa_key}").copy()
else:
points = self.get_lidar(sequence_name, sample_idx) points = self.get_lidar(sequence_name, sample_idx)
input_dict = { input_dict = {
......
...@@ -8,3 +8,4 @@ scikit-image ...@@ -8,3 +8,4 @@ scikit-image
tqdm tqdm
kornia kornia
torchvision torchvision
SharedArray
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment