Commit d9f21dc9 authored by liyinhao's avatar liyinhao
Browse files

change based on the fourth comment

parent d3564d6d
import argparse import argparse
import os.path as osp import os.path as osp
import tools.data_converter.indoor_converter as indoor
import tools.data_converter.kitti_converter as kitti import tools.data_converter.kitti_converter as kitti
import tools.data_converter.nuscenes_converter as nuscenes_converter import tools.data_converter.nuscenes_converter as nuscenes_converter
import tools.data_converter.scannet_converter as scannet
import tools.data_converter.sunrgbd_converter as sunrgbd
from tools.data_converter.create_gt_database import create_groundtruth_database from tools.data_converter.create_gt_database import create_groundtruth_database
...@@ -46,11 +45,11 @@ def nuscenes_data_prep(root_path, ...@@ -46,11 +45,11 @@ def nuscenes_data_prep(root_path,
def scannet_data_prep(root_path, info_prefix, out_dir): def scannet_data_prep(root_path, info_prefix, out_dir):
scannet.create_scannet_info_file(root_path, info_prefix, out_dir) indoor.create_indoor_info_file(root_path, info_prefix, out_dir)
def sunrgbd_data_prep(root_path, info_prefix, out_dir): def sunrgbd_data_prep(root_path, info_prefix, out_dir):
sunrgbd.create_sunrgbd_info_file(root_path, info_prefix, out_dir) indoor.create_indoor_info_file(root_path, info_prefix, out_dir)
parser = argparse.ArgumentParser(description='Data converter arg parser') parser = argparse.ArgumentParser(description='Data converter arg parser')
......
import os import os
import mmcv import mmcv
from tools.data_converter.scannet_data_utils import ScanNetData
from tools.data_converter.sunrgbd_data_utils import SUNRGBDData from tools.data_converter.sunrgbd_data_utils import SUNRGBDData
def create_sunrgbd_info_file(data_path, def create_indoor_info_file(data_path,
pkl_prefix='sunrgbd', pkl_prefix='sunrgbd',
save_path=None, save_path=None,
use_v1=False): use_v1=False):
''' """
Create sunrgbd information file. Create indoor information file.
Get information of the raw data and save it to the pkl file. Get information of the raw data and save it to the pkl file.
Args: Args:
data_path (str): Path of the data. data_path (str): Path of the data.
pkl_prefix (str): Prefix ofr the pkl to be saved. Default: 'sunrgbd'. pkl_prefix (str): Prefix of the pkl to be saved. Default: 'sunrgbd'.
save_path (str): Path of the pkl to be saved. Default: None. save_path (str): Path of the pkl to be saved. Default: None.
use_v1 (bool): Whether to use v1. Default: False. use_v1 (bool): Whether to use v1. Default: False.
Returns: Returns:
None None
''' """
assert os.path.exists(data_path) assert os.path.exists(data_path)
assert pkl_prefix in ['sunrgbd', 'scannet']
if save_path is None: if save_path is None:
save_path = data_path save_path = data_path
else:
save_path = save_path
assert os.path.exists(save_path) assert os.path.exists(save_path)
train_filename = os.path.join(save_path, f'{pkl_prefix}_infos_train.pkl') train_filename = os.path.join(save_path, f'{pkl_prefix}_infos_train.pkl')
val_filename = os.path.join(save_path, f'{pkl_prefix}_infos_val.pkl') val_filename = os.path.join(save_path, f'{pkl_prefix}_infos_val.pkl')
train_dataset = SUNRGBDData( if pkl_prefix == 'sunrgbd':
root_path=data_path, split='train', use_v1=use_v1) train_dataset = SUNRGBDData(
val_dataset = SUNRGBDData(root_path=data_path, split='val', use_v1=use_v1) root_path=data_path, split='train', use_v1=use_v1)
sunrgbd_infos_train = train_dataset.get_sunrgbd_infos(has_label=True) val_dataset = SUNRGBDData(
root_path=data_path, split='val', use_v1=use_v1)
else:
train_dataset = ScanNetData(root_path=data_path, split='train')
val_dataset = ScanNetData(root_path=data_path, split='val')
infos_train = train_dataset.get_infos(has_label=True)
with open(train_filename, 'wb') as f: with open(train_filename, 'wb') as f:
mmcv.dump(sunrgbd_infos_train, f, 'pkl') mmcv.dump(infos_train, f, 'pkl')
print(f'Sunrgbd info train file is saved to {train_filename}') print(f'{pkl_prefix} info train file is saved to {train_filename}')
sunrgbd_infos_val = val_dataset.get_sunrgbd_infos(has_label=True) infos_val = val_dataset.get_infos(has_label=True)
with open(val_filename, 'wb') as f: with open(val_filename, 'wb') as f:
mmcv.dump(sunrgbd_infos_val, f, 'pkl') mmcv.dump(infos_val, f, 'pkl')
print(f'Sunrgbd info val file is saved to {val_filename}') print(f'{pkl_prefix} info val file is saved to {val_filename}')
import os
import mmcv
from tools.data_converter.scannet_data_utils import ScanNetData
def create_scannet_info_file(data_path, pkl_prefix='scannet', save_path=None):
'''
Create scannet information file.
Get information of the raw data and save it to the pkl file.
Args:
data_path (str): Path of the data.
pkl_prefix (str): Prefix ofr the pkl to be saved. Default: 'scannet'. # noqa: E501
save_path (str): Path of the pkl to be saved. Default: None.
Returns:
None
'''
assert os.path.exists(data_path)
if save_path is None:
save_path = data_path
else:
save_path = save_path
assert os.path.exists(save_path)
train_filename = os.path.join(save_path, f'{pkl_prefix}_infos_train.pkl')
val_filename = os.path.join(save_path, f'{pkl_prefix}_infos_val.pkl')
train_dataset = ScanNetData(root_path=data_path, split='train')
val_dataset = ScanNetData(root_path=data_path, split='val')
scannet_infos_train = train_dataset.get_scannet_infos(has_label=True)
with open(train_filename, 'wb') as f:
mmcv.dump(scannet_infos_train, f, 'pkl')
print(f'Scannet info train file is saved to {train_filename}')
scannet_infos_val = val_dataset.get_scannet_infos(has_label=True)
with open(val_filename, 'wb') as f:
mmcv.dump(scannet_infos_val, f, 'pkl')
print(f'Scannet info val file is saved to {val_filename}')
...@@ -5,7 +5,7 @@ import numpy as np ...@@ -5,7 +5,7 @@ import numpy as np
class ScanNetData(object): class ScanNetData(object):
''' """
ScanNet Data ScanNet Data
Generate scannet infos for scannet_converter Generate scannet infos for scannet_converter
...@@ -13,7 +13,7 @@ class ScanNetData(object): ...@@ -13,7 +13,7 @@ class ScanNetData(object):
Args: Args:
root_path (str): Root path of the raw data root_path (str): Root path of the raw data
split (str): Set split type of the data. Default: 'train'. split (str): Set split type of the data. Default: 'train'.
''' """
def __init__(self, root_path, split='train'): def __init__(self, root_path, split='train'):
self.root_dir = root_path self.root_dir = root_path
...@@ -48,23 +48,21 @@ class ScanNetData(object): ...@@ -48,23 +48,21 @@ class ScanNetData(object):
assert os.path.exists(box_file) assert os.path.exists(box_file)
return np.load(box_file) return np.load(box_file)
def get_scannet_infos(self, def get_infos(self, num_workers=4, has_label=True, sample_id_list=None):
num_workers=4, """
has_label=True, Get data infos.
sample_id_list=None):
'''
Get scannet infos.
This method gets information from the raw data. This method gets information from the raw data.
Args: Args:
num_workers (int): Number of threads to be used. Default: 4. num_workers (int): Number of threads to be used. Default: 4.
has_label (bool): Whether the data has label. Default: True. has_label (bool): Whether the data has label. Default: True.
sample_id_list (List[int]): Index list of the sample. Default: None. # noqa: E501 sample_id_list (List[int]): Index list of the sample.
Default: None.
Returns: Returns:
infos (List[dict]): Information of the raw data. infos (List[dict]): Information of the raw data.
''' """
def process_single_scene(sample_idx): def process_single_scene(sample_idx):
print(f'{self.split} sample_idx: {sample_idx}') print(f'{self.split} sample_idx: {sample_idx}')
......
import concurrent.futures as futures import concurrent.futures as futures
import os import os
import cv2 import mmcv
import numpy as np import numpy as np
import scipy.io as sio import scipy.io as sio
def random_sampling(pc, num_samples, replace=None, return_choices=False): def random_sampling(pc, num_samples, replace=None, return_choices=False):
''' """
Random Sampling. Random Sampling.
Sampling point cloud to num_samples points. Sampling point cloud to num_samples points.
...@@ -20,7 +20,8 @@ def random_sampling(pc, num_samples, replace=None, return_choices=False): ...@@ -20,7 +20,8 @@ def random_sampling(pc, num_samples, replace=None, return_choices=False):
Returns: Returns:
pc (ndarray): Point cloud after sampling. pc (ndarray): Point cloud after sampling.
''' """
if replace is None: if replace is None:
replace = (pc.shape[0] < num_samples) replace = (pc.shape[0] < num_samples)
choices = np.random.choice(pc.shape[0], num_samples, replace=replace) choices = np.random.choice(pc.shape[0], num_samples, replace=replace)
...@@ -57,7 +58,7 @@ class SUNRGBDInstance(object): ...@@ -57,7 +58,7 @@ class SUNRGBDInstance(object):
class SUNRGBDData(object): class SUNRGBDData(object):
''' """
SUNRGBD Data SUNRGBD Data
Generate scannet infos for sunrgbd_converter Generate scannet infos for sunrgbd_converter
...@@ -66,7 +67,7 @@ class SUNRGBDData(object): ...@@ -66,7 +67,7 @@ class SUNRGBDData(object):
root_path (str): Root path of the raw data. root_path (str): Root path of the raw data.
split (str): Set split type of the data. Default: 'train'. split (str): Set split type of the data. Default: 'train'.
use_v1 (bool): Whether to use v1. Default: False. use_v1 (bool): Whether to use v1. Default: False.
''' """
def __init__(self, root_path, split='train', use_v1=False): def __init__(self, root_path, split='train', use_v1=False):
self.root_dir = root_path self.root_dir = root_path
...@@ -100,7 +101,7 @@ class SUNRGBDData(object): ...@@ -100,7 +101,7 @@ class SUNRGBDData(object):
def get_image(self, idx): def get_image(self, idx):
img_filename = os.path.join(self.image_dir, f'{idx:06d}.jpg') img_filename = os.path.join(self.image_dir, f'{idx:06d}.jpg')
return cv2.imread(img_filename) return mmcv.imread(img_filename)
def get_image_shape(self, idx): def get_image_shape(self, idx):
image = self.get_image(idx) image = self.get_image(idx)
...@@ -125,23 +126,21 @@ class SUNRGBDData(object): ...@@ -125,23 +126,21 @@ class SUNRGBDData(object):
objects = [SUNRGBDInstance(line) for line in lines] objects = [SUNRGBDInstance(line) for line in lines]
return objects return objects
def get_sunrgbd_infos(self, def get_infos(self, num_workers=4, has_label=True, sample_id_list=None):
num_workers=4, """
has_label=True, Get data infos.
sample_id_list=None):
'''
Get sunrgbd infos.
This method gets information from the raw data. This method gets information from the raw data.
Args: Args:
num_workers (int): Number of threads to be used. Default: 4. num_workers (int): Number of threads to be used. Default: 4.
has_label (bool): Whether the data has label. Default: True. has_label (bool): Whether the data has label. Default: True.
sample_id_list (List[int]): Index list of the sample. Default: None. # noqa: E501 sample_id_list (List[int]): Index list of the sample.
Default: None.
Returns: Returns:
infos (List[dict]): Information of the raw data. infos (List[dict]): Information of the raw data.
''' """
def process_single_scene(sample_idx): def process_single_scene(sample_idx):
print(f'{self.split} sample_idx: {sample_idx}') print(f'{self.split} sample_idx: {sample_idx}')
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment