Commit 5219773a authored by liyinhao's avatar liyinhao
Browse files

change names, change some os method to mmcv method

parent 90396ed6
import concurrent.futures as futures import concurrent.futures as futures
import os import os
import mmcv
import numpy as np import numpy as np
...@@ -33,10 +34,10 @@ class ScanNetData(object): ...@@ -33,10 +34,10 @@ class ScanNetData(object):
for i, nyu40id in enumerate(list(self.cat_ids)) for i, nyu40id in enumerate(list(self.cat_ids))
} }
assert split in ['train', 'val', 'test'] assert split in ['train', 'val', 'test']
split_dir = os.path.join(self.root_dir, 'meta_data', split_file = os.path.join(self.root_dir, 'meta_data',
f'scannetv2_{split}.txt') f'scannetv2_{split}.txt')
self.sample_id_list = [x.strip() for x in open(split_dir).readlines() mmcv.check_file_exist(split_file)
] if os.path.exists(split_dir) else None self.sample_id_list = mmcv.list_from_file(split_file)
def __len__(self): def __len__(self):
return len(self.sample_id_list) return len(self.sample_id_list)
......
...@@ -6,14 +6,14 @@ import numpy as np ...@@ -6,14 +6,14 @@ import numpy as np
import scipy.io as sio import scipy.io as sio
def random_sampling(pc, num_samples, replace=None, return_choices=False): def random_sampling(pc, num_points, replace=None, return_choices=False):
"""Random Sampling. """Random Sampling.
Sampling point cloud to num_samples points. Sampling point cloud to a certain number of points.
Args: Args:
pc (ndarray): Point cloud. pc (ndarray): Point cloud.
num_samples (int): The number of samples. num_points (int): The number of samples.
replace (bool): Whether the sample is with or without replacement. replace (bool): Whether the sample is with or without replacement.
return_choices (bool): Whether to return choices. return_choices (bool): Whether to return choices.
...@@ -22,8 +22,8 @@ def random_sampling(pc, num_samples, replace=None, return_choices=False): ...@@ -22,8 +22,8 @@ def random_sampling(pc, num_samples, replace=None, return_choices=False):
""" """
if replace is None: if replace is None:
replace = (pc.shape[0] < num_samples) replace = (pc.shape[0] < num_points)
choices = np.random.choice(pc.shape[0], num_samples, replace=replace) choices = np.random.choice(pc.shape[0], num_points, replace=replace)
if return_choices: if return_choices:
return pc[choices], choices return pc[choices], choices
else: else:
...@@ -81,11 +81,9 @@ class SUNRGBDData(object): ...@@ -81,11 +81,9 @@ class SUNRGBDData(object):
for label in range(len(self.classes)) for label in range(len(self.classes))
} }
assert split in ['train', 'val', 'test'] assert split in ['train', 'val', 'test']
split_dir = os.path.join(self.root_dir, f'{split}_data_idx.txt') split_file = os.path.join(self.root_dir, f'{split}_data_idx.txt')
self.sample_id_list = [ mmcv.check_file_exist(split_file)
int(x.strip()) for x in open(split_dir).readlines() self.sample_id_list = map(int, mmcv.list_from_file(split_file))
] if os.path.exists(split_dir) else None
self.image_dir = os.path.join(self.split_dir, 'image') self.image_dir = os.path.join(self.split_dir, 'image')
self.calib_dir = os.path.join(self.split_dir, 'calib') self.calib_dir = os.path.join(self.split_dir, 'calib')
self.depth_dir = os.path.join(self.split_dir, 'depth') self.depth_dir = os.path.join(self.split_dir, 'depth')
...@@ -143,8 +141,9 @@ class SUNRGBDData(object): ...@@ -143,8 +141,9 @@ class SUNRGBDData(object):
print(f'{self.split} sample_idx: {sample_idx}') print(f'{self.split} sample_idx: {sample_idx}')
# convert depth to points # convert depth to points
SAMPLE_NUM = 50000 SAMPLE_NUM = 50000
# TODO: Check whether can move the point
# sampling process during training.
pc_upright_depth = self.get_depth(sample_idx) pc_upright_depth = self.get_depth(sample_idx)
# TODO : sample points in loading process and test
pc_upright_depth_subsampled = random_sampling( pc_upright_depth_subsampled = random_sampling(
pc_upright_depth, SAMPLE_NUM) pc_upright_depth, SAMPLE_NUM)
np.savez_compressed( np.savez_compressed(
...@@ -213,8 +212,7 @@ class SUNRGBDData(object): ...@@ -213,8 +212,7 @@ class SUNRGBDData(object):
return info return info
lidar_save_dir = os.path.join(self.root_dir, 'lidar') lidar_save_dir = os.path.join(self.root_dir, 'lidar')
if not os.path.exists(lidar_save_dir): mmcv.mkdir_or_exist(lidar_save_dir)
os.mkdir(lidar_save_dir)
sample_id_list = sample_id_list if \ sample_id_list = sample_id_list if \
sample_id_list is not None else self.sample_id_list sample_id_list is not None else self.sample_id_list
with futures.ThreadPoolExecutor(num_workers) as executor: with futures.ThreadPoolExecutor(num_workers) as executor:
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment