Unverified Commit a0090aa1 authored by Ziyi Wu, committed by GitHub

[Feature] Support S3DIS data pre-processing and dataset class (#433)

* support S3DIS data download and pre-processing (to ScanNet format)

* add S3DIS data for unittest

* add S3DIS semseg dataset class and unittest

* add config file for S3DIS dataset

* add eval_pipeline to S3DIS dataset config file

* clean code for S3DIS pre-processing scripts

* reformat code

* fix small bugs

* resolve conflicts & modify show() to use pipeline

* fix small errors

* polish data pre-processing code

* add more comments about S3DIS dataset

* fix markdown lint error
parent 3c540f71
tools/create_data.py

@@ -134,6 +134,19 @@ def scannet_data_prep(root_path, info_prefix, out_dir, workers):
root_path, info_prefix, out_dir, workers=workers)

def s3dis_data_prep(root_path, info_prefix, out_dir, workers):
    """Prepare the info file for the S3DIS dataset.

    Args:
        root_path (str): Path of dataset root.
        info_prefix (str): The prefix of info filenames.
        out_dir (str): Output directory of the generated info file.
        workers (int): Number of threads to be used.
    """
indoor.create_indoor_info_file(
root_path, info_prefix, out_dir, workers=workers)

def sunrgbd_data_prep(root_path, info_prefix, out_dir, workers):
    """Prepare the info file for sunrgbd dataset.
@@ -285,6 +298,12 @@ if __name__ == '__main__':
info_prefix=args.extra_tag,
out_dir=args.out_dir,
workers=args.workers)
elif args.dataset == 's3dis':
s3dis_data_prep(
root_path=args.root_path,
info_prefix=args.extra_tag,
out_dir=args.out_dir,
workers=args.workers)
elif args.dataset == 'sunrgbd':
sunrgbd_data_prep(
root_path=args.root_path,
......
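For reference, the new branch above is exercised through the same command-line entry point as the other indoor datasets; a typical invocation (the flags mirror the existing scannet/sunrgbd usage, so treat the exact paths as illustrative) would be:

    python tools/create_data.py s3dis --root-path ./data/s3dis \
        --out-dir ./data/s3dis --extra-tag s3dis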
tools/data_converter/indoor_converter.py

@@ -2,6 +2,7 @@ import mmcv
import numpy as np
import os
from tools.data_converter.s3dis_data_utils import S3DISData, S3DISSegData
from tools.data_converter.scannet_data_utils import ScanNetData, ScanNetSegData
from tools.data_converter.sunrgbd_data_utils import SUNRGBDData
@@ -23,30 +24,38 @@ def create_indoor_info_file(data_path,
workers (int): Number of threads to be used. Default: 4.
"""
assert os.path.exists(data_path)
    assert pkl_prefix in ['sunrgbd', 'scannet', 's3dis'], \
        f'unsupported indoor dataset {pkl_prefix}'
save_path = data_path if save_path is None else save_path
assert os.path.exists(save_path)
    # generate infos for both detection and segmentation task
    if pkl_prefix in ['sunrgbd', 'scannet']:
        train_filename = os.path.join(save_path,
                                      f'{pkl_prefix}_infos_train.pkl')
        val_filename = os.path.join(save_path, f'{pkl_prefix}_infos_val.pkl')
        if pkl_prefix == 'sunrgbd':
            # SUN RGB-D has a train-val split
            train_dataset = SUNRGBDData(
                root_path=data_path, split='train', use_v1=use_v1)
            val_dataset = SUNRGBDData(
                root_path=data_path, split='val', use_v1=use_v1)
        else:
            # ScanNet has a train-val-test split
            train_dataset = ScanNetData(root_path=data_path, split='train')
            val_dataset = ScanNetData(root_path=data_path, split='val')
            test_dataset = ScanNetData(root_path=data_path, split='test')
            test_filename = os.path.join(save_path,
                                         f'{pkl_prefix}_infos_test.pkl')
        infos_train = train_dataset.get_infos(
            num_workers=workers, has_label=True)
        mmcv.dump(infos_train, train_filename, 'pkl')
        print(f'{pkl_prefix} info train file is saved to {train_filename}')

        infos_val = val_dataset.get_infos(num_workers=workers, has_label=True)
        mmcv.dump(infos_val, val_filename, 'pkl')
        print(f'{pkl_prefix} info val file is saved to {val_filename}')
if pkl_prefix == 'scannet':
infos_test = test_dataset.get_infos(
@@ -56,6 +65,8 @@ def create_indoor_info_file(data_path,
# generate infos for the semantic segmentation task
# e.g. re-sampled scene indexes and label weights
    # scene indexes re-sample rooms that have different numbers of points
    # label weights balance classes that have different numbers of points
if pkl_prefix == 'scannet':
# label weight computation function is adopted from
# https://github.com/charlesq34/pointnet2/blob/master/scannet/scannet_dataset.py#L24
@@ -73,6 +84,24 @@ def create_indoor_info_file(data_path,
num_points=8192,
label_weight_func=lambda x: 1.0 / np.log(1.2 + x))
# no need to generate for test set
train_dataset.get_seg_infos()
val_dataset.get_seg_infos()
    else:
        # S3DIS doesn't have a fixed train-val split
        # it has 6 areas instead, so we generate an info file for each area
        # in training, a dataset wrapper is used to combine different areas
splits = [f'Area_{i}' for i in [1, 2, 3, 4, 5, 6]]
for split in splits:
dataset = S3DISData(root_path=data_path, split=split)
info = dataset.get_infos(num_workers=workers, has_label=True)
filename = os.path.join(save_path,
f'{pkl_prefix}_infos_{split}.pkl')
mmcv.dump(info, filename, 'pkl')
print(f'{pkl_prefix} info {split} file is saved to {filename}')
seg_dataset = S3DISSegData(
data_root=data_path,
ann_file=filename,
split=split,
num_points=4096,
label_weight_func=lambda x: 1.0 / np.log(1.2 + x))
seg_dataset.get_seg_infos()
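To make the generated artifacts concrete, here is a minimal sketch of loading what the branch above writes for a single area; the data/s3dis root is an assumed path, while the file names follow the code above:

    import mmcv
    import numpy as np

    # per-area info file written by mmcv.dump above
    infos = mmcv.load('data/s3dis/s3dis_infos_Area_1.pkl')
    print(len(infos), infos[0]['pts_path'])

    # re-sampled scene indexes and label weights written by get_seg_infos()
    scene_idxs = np.load('data/s3dis/seg_info/Area_1_resampled_scene_idxs.npy')
    label_weight = np.load('data/s3dis/seg_info/Area_1_label_weight.npy')
    print(scene_idxs.shape, label_weight.shape)  # label_weight has 13 entries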
tools/data_converter/s3dis_data_utils.py

import mmcv
import numpy as np
import os
from concurrent import futures
from os import path as osp


class S3DISData(object):
"""S3DIS data.
Generate s3dis infos for s3dis_converter.
Args:
root_path (str): Root path of the raw data.
split (str): Set split type of the data. Default: 'Area_1'.
"""
def __init__(self, root_path, split='Area_1'):
self.root_dir = root_path
self.split = split
self.data_dir = osp.join(root_path,
'Stanford3dDataset_v1.2_Aligned_Version')
self.classes = [
'ceiling', 'floor', 'wall', 'beam', 'column', 'window', 'door',
'table', 'chair', 'sofa', 'bookcase', 'board', 'clutter'
]
self.cat2label = {cat: self.classes.index(cat) for cat in self.classes}
self.label2cat = {self.cat2label[t]: t for t in self.cat2label}
self.cat_ids = np.array([0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12])
self.cat_ids2class = {
cat_id: i
for i, cat_id in enumerate(list(self.cat_ids))
}
assert split in [
'Area_1', 'Area_2', 'Area_3', 'Area_4', 'Area_5', 'Area_6'
]
        # room folders under each area, e.g. conferenceRoom_1; build the
        # list with a comprehension instead of removing items while
        # iterating, which would skip entries
        self.sample_id_list = [
            sample_id
            for sample_id in os.listdir(osp.join(self.data_dir, split))
            if os.path.isdir(osp.join(self.data_dir, split, sample_id))
        ]

def __len__(self):
return len(self.sample_id_list)
def get_infos(self, num_workers=4, has_label=True, sample_id_list=None):
"""Get data infos.
This method gets information from the raw data.
Args:
num_workers (int): Number of threads to be used. Default: 4.
has_label (bool): Whether the data has label. Default: True.
sample_id_list (list[int]): Index list of the sample.
Default: None.
Returns:
infos (list[dict]): Information of the raw data.
"""
def process_single_scene(sample_idx):
print(f'{self.split} sample_idx: {sample_idx}')
info = dict()
pc_info = {
'num_features': 6,
'lidar_idx': f'{self.split}_{sample_idx}'
}
info['point_cloud'] = pc_info
pts_filename = osp.join(self.root_dir, 's3dis_data',
f'{self.split}_{sample_idx}_point.npy')
pts_instance_mask_path = osp.join(
self.root_dir, 's3dis_data',
f'{self.split}_{sample_idx}_ins_label.npy')
pts_semantic_mask_path = osp.join(
self.root_dir, 's3dis_data',
f'{self.split}_{sample_idx}_sem_label.npy')
points = np.load(pts_filename).astype(np.float32)
            # use a fixed-width dtype (np.int was removed in NumPy 1.24)
            pts_instance_mask = np.load(pts_instance_mask_path).astype(
                np.int64)
            pts_semantic_mask = np.load(pts_semantic_mask_path).astype(
                np.int64)
mmcv.mkdir_or_exist(osp.join(self.root_dir, 'points'))
mmcv.mkdir_or_exist(osp.join(self.root_dir, 'instance_mask'))
mmcv.mkdir_or_exist(osp.join(self.root_dir, 'semantic_mask'))
points.tofile(
osp.join(self.root_dir, 'points',
f'{self.split}_{sample_idx}.bin'))
pts_instance_mask.tofile(
osp.join(self.root_dir, 'instance_mask',
f'{self.split}_{sample_idx}.bin'))
pts_semantic_mask.tofile(
osp.join(self.root_dir, 'semantic_mask',
f'{self.split}_{sample_idx}.bin'))
info['pts_path'] = osp.join('points',
f'{self.split}_{sample_idx}.bin')
info['pts_instance_mask_path'] = osp.join(
'instance_mask', f'{self.split}_{sample_idx}.bin')
info['pts_semantic_mask_path'] = osp.join(
'semantic_mask', f'{self.split}_{sample_idx}.bin')
return info
sample_id_list = sample_id_list if sample_id_list is not None \
else self.sample_id_list
with futures.ThreadPoolExecutor(num_workers) as executor:
infos = executor.map(process_single_scene, sample_id_list)
return list(infos)
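A minimal usage sketch of S3DISData (assuming the raw Stanford3dDataset_v1.2_Aligned_Version data and the intermediate s3dis_data/*.npy files already sit under an illustrative ./data/s3dis root):

    dataset = S3DISData(root_path='./data/s3dis', split='Area_1')
    print(len(dataset))  # number of rooms found in Area_1
    infos = dataset.get_infos(num_workers=4, has_label=True)
    print(infos[0]['point_cloud'], infos[0]['pts_path'])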
class S3DISSegData(object):
"""S3DIS dataset used to generate infos for semantic segmentation task.
Args:
data_root (str): Root path of the raw data.
ann_file (str): The generated scannet infos.
split (str): Set split type of the data. Default: 'train'.
num_points (int): Number of points in each data input. Default: 8192.
label_weight_func (function): Function to compute the label weight.
Default: None.
"""
def __init__(self,
data_root,
ann_file,
split='Area_1',
num_points=4096,
label_weight_func=None):
self.data_root = data_root
self.data_infos = mmcv.load(ann_file)
self.split = split
self.num_points = num_points
self.all_ids = np.arange(13) # all possible ids
self.cat_ids = np.array([0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11,
12]) # used for seg task
self.ignore_index = len(self.cat_ids)
        # map each raw id to a train class id; ids outside cat_ids (none
        # for S3DIS, where all 13 classes are kept) map to ignore_index
        self.cat_id2class = np.full((self.all_ids.shape[0], ),
                                    self.ignore_index,
                                    dtype=np.int64)
        for i, cat_id in enumerate(self.cat_ids):
            self.cat_id2class[cat_id] = i
# label weighting function is taken from
# https://github.com/charlesq34/pointnet2/blob/master/scannet/scannet_dataset.py#L24
self.label_weight_func = (lambda x: 1.0 / np.log(1.2 + x)) if \
label_weight_func is None else label_weight_func
def get_seg_infos(self):
scene_idxs, label_weight = self.get_scene_idxs_and_label_weight()
save_folder = osp.join(self.data_root, 'seg_info')
mmcv.mkdir_or_exist(save_folder)
np.save(
osp.join(save_folder, f'{self.split}_resampled_scene_idxs.npy'),
scene_idxs)
np.save(
osp.join(save_folder, f'{self.split}_label_weight.npy'),
label_weight)
print(f'{self.split} resampled scene index and label weight saved')
def _convert_to_label(self, mask):
"""Convert class_id in loaded segmentation mask to label."""
if isinstance(mask, str):
if mask.endswith('npy'):
mask = np.load(mask)
else:
                # .bin masks were written by tofile() as int64 above
                mask = np.fromfile(mask, dtype=np.int64)
label = self.cat_id2class[mask]
return label
def get_scene_idxs_and_label_weight(self):
"""Compute scene_idxs for data sampling and label weight for loss \
calculation.

        We sample more times for scenes with more points. The label weight
        of each class is inversely proportional to its number of points.
        """
num_classes = len(self.cat_ids)
num_point_all = []
        # one extra bin to count points labeled as ignore_index
        label_weight = np.zeros((num_classes + 1, ))
for data_info in self.data_infos:
label = self._convert_to_label(
osp.join(self.data_root, data_info['pts_semantic_mask_path']))
num_point_all.append(label.shape[0])
class_count, _ = np.histogram(label, range(num_classes + 2))
label_weight += class_count
# repeat scene_idx for num_scene_point // num_sample_point times
sample_prob = np.array(num_point_all) / float(np.sum(num_point_all))
num_iter = int(np.sum(num_point_all) / float(self.num_points))
scene_idxs = []
for idx in range(len(self.data_infos)):
scene_idxs.extend([idx] * round(sample_prob[idx] * num_iter))
scene_idxs = np.array(scene_idxs).astype(np.int32)
# calculate label weight, adopted from PointNet++
label_weight = label_weight[:-1].astype(np.float32)
label_weight = label_weight / label_weight.sum()
label_weight = self.label_weight_func(label_weight).astype(np.float32)
return scene_idxs, label_weight
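As a toy illustration of the default weighting function 1.0 / np.log(1.2 + x) used above (the class frequencies here are made up, not S3DIS statistics):

    import numpy as np

    freq = np.array([0.5, 0.3, 0.2], dtype=np.float32)  # hypothetical frequencies
    weights = 1.0 / np.log(1.2 + freq)
    print(weights.round(2))  # [1.88 2.47 2.97], rarest class weighted highest

Since the normalized frequencies sum to 1 over the 13 classes, typically frequent classes such as ceiling or floor end up down-weighted relative to rare ones such as sofa or board.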