Commit 89bda282 authored by zhangwenwei's avatar zhangwenwei
Browse files

Merge branch 'master' into fix-train-runtime

parents ff8623e1 99db60dd
import os
import pickle
from pathlib import Path
from tools.data_converter.sunrgbd_data_utils import SUNRGBDData
def create_sunrgbd_info_file(data_path,
pkl_prefix='sunrgbd',
save_path=None,
use_v1=False):
assert os.path.exists(data_path)
if save_path is None:
save_path = Path(data_path)
else:
save_path = Path(save_path)
assert os.path.exists(save_path)
train_filename = save_path / f'{pkl_prefix}_infos_train.pkl'
val_filename = save_path / f'{pkl_prefix}_infos_val.pkl'
train_dataset = SUNRGBDData(
root_path=data_path, split='train', use_v1=use_v1)
val_dataset = SUNRGBDData(root_path=data_path, split='val', use_v1=use_v1)
sunrgbd_infos_train = train_dataset.get_sunrgbd_infos(has_label=True)
with open(train_filename, 'wb') as f:
pickle.dump(sunrgbd_infos_train, f)
print('Sunrgbd info train file is saved to %s' % train_filename)
sunrgbd_infos_val = val_dataset.get_sunrgbd_infos(has_label=True)
with open(val_filename, 'wb') as f:
pickle.dump(sunrgbd_infos_val, f)
print('Sunrgbd info val file is saved to %s' % val_filename)
if __name__ == '__main__':
create_sunrgbd_info_file(
data_path='./data/sunrgbd/sunrgbd_trainval',
save_path='./data/sunrgbd')
import concurrent.futures as futures
import os
import cv2
import mmcv
import numpy as np
import scipy.io as sio
def random_sampling(pc, num_sample, replace=None, return_choices=False):
""" Input is NxC, output is num_samplexC
def random_sampling(points, num_points, replace=None, return_choices=False):
"""Random Sampling.
Sampling point cloud to a certain number of points.
Args:
points (ndarray): Point cloud.
num_points (int): The number of samples.
replace (bool): Whether the sample is with or without replacement.
return_choices (bool): Whether to return choices.
Returns:
points (ndarray): Point cloud after sampling.
"""
if replace is None:
replace = (pc.shape[0] < num_sample)
choices = np.random.choice(pc.shape[0], num_sample, replace=replace)
replace = (points.shape[0] < num_points)
choices = np.random.choice(points.shape[0], num_points, replace=replace)
if return_choices:
return pc[choices], choices
return points[choices], choices
else:
return pc[choices]
return points[choices]
class SUNRGBDInstance(object):
......@@ -44,7 +57,15 @@ class SUNRGBDInstance(object):
class SUNRGBDData(object):
''' Load and parse object data '''
"""SUNRGBD Data
Generate scannet infos for sunrgbd_converter
Args:
root_path (str): Root path of the raw data.
split (str): Set split type of the data. Default: 'train'.
use_v1 (bool): Whether to use v1. Default: False.
"""
def __init__(self, root_path, split='train', use_v1=False):
self.root_dir = root_path
......@@ -60,11 +81,9 @@ class SUNRGBDData(object):
for label in range(len(self.classes))
}
assert split in ['train', 'val', 'test']
split_dir = os.path.join(self.root_dir, '%s_data_idx.txt' % split)
self.sample_id_list = [
int(x.strip()) for x in open(split_dir).readlines()
] if os.path.exists(split_dir) else None
split_file = os.path.join(self.root_dir, f'{split}_data_idx.txt')
mmcv.check_file_exist(split_file)
self.sample_id_list = map(int, mmcv.list_from_file(split_file))
self.image_dir = os.path.join(self.split_dir, 'image')
self.calib_dir = os.path.join(self.split_dir, 'calib')
self.depth_dir = os.path.join(self.split_dir, 'depth')
......@@ -77,20 +96,20 @@ class SUNRGBDData(object):
return len(self.sample_id_list)
def get_image(self, idx):
img_filename = os.path.join(self.image_dir, '%06d.jpg' % (idx))
return cv2.imread(img_filename)
img_filename = os.path.join(self.image_dir, f'{idx:06d}.jpg')
return mmcv.imread(img_filename)
def get_image_shape(self, idx):
image = self.get_image(idx)
return np.array(image.shape[:2], dtype=np.int32)
def get_depth(self, idx):
depth_filename = os.path.join(self.depth_dir, '%06d.mat' % (idx))
depth_filename = os.path.join(self.depth_dir, f'{idx:06d}.mat')
depth = sio.loadmat(depth_filename)['instance']
return depth
def get_calibration(self, idx):
calib_filepath = os.path.join(self.calib_dir, '%06d.txt' % (idx))
calib_filepath = os.path.join(self.calib_dir, f'{idx:06d}.txt')
lines = [line.rstrip() for line in open(calib_filepath)]
Rt = np.array([float(x) for x in lines[0].split(' ')])
Rt = np.reshape(Rt, (3, 3), order='F')
......@@ -98,33 +117,43 @@ class SUNRGBDData(object):
return K, Rt
def get_label_objects(self, idx):
label_filename = os.path.join(self.label_dir, '%06d.txt' % (idx))
label_filename = os.path.join(self.label_dir, f'{idx:06d}.txt')
lines = [line.rstrip() for line in open(label_filename)]
objects = [SUNRGBDInstance(line) for line in lines]
return objects
def get_sunrgbd_infos(self,
num_workers=4,
has_label=True,
sample_id_list=None):
import concurrent.futures as futures
def get_infos(self, num_workers=4, has_label=True, sample_id_list=None):
"""Get data infos.
This method gets information from the raw data.
Args:
num_workers (int): Number of threads to be used. Default: 4.
has_label (bool): Whether the data has label. Default: True.
sample_id_list (List[int]): Index list of the sample.
Default: None.
Returns:
infos (List[dict]): Information of the raw data.
"""
def process_single_scene(sample_idx):
print('%s sample_idx: %s' % (self.split, sample_idx))
print(f'{self.split} sample_idx: {sample_idx}')
# convert depth to points
SAMPLE_NUM = 50000
# TODO: Check whether can move the point
# sampling process during training.
pc_upright_depth = self.get_depth(sample_idx)
# TODO : sample points in loading process and test
pc_upright_depth_subsampled = random_sampling(
pc_upright_depth, SAMPLE_NUM)
np.savez_compressed(
os.path.join(self.root_dir, 'lidar', '%06d.npz' % sample_idx),
os.path.join(self.root_dir, 'lidar', f'{sample_idx:06d}.npz'),
pc=pc_upright_depth_subsampled)
info = dict()
pc_info = {'num_features': 6, 'lidar_idx': sample_idx}
info['point_cloud'] = pc_info
img_name = os.path.join(self.image_dir, '%06d.jpg' % (sample_idx))
img_name = os.path.join(self.image_dir, f'{sample_idx:06d}')
img_path = os.path.join(self.image_dir, img_name)
image_info = {
'image_idx': sample_idx,
......@@ -183,8 +212,7 @@ class SUNRGBDData(object):
return info
lidar_save_dir = os.path.join(self.root_dir, 'lidar')
if not os.path.exists(lidar_save_dir):
os.mkdir(lidar_save_dir)
mmcv.mkdir_or_exist(lidar_save_dir)
sample_id_list = sample_id_list if \
sample_id_list is not None else self.sample_id_list
with futures.ThreadPoolExecutor(num_workers) as executor:
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment