Commit 72350b2d authored by liyinhao's avatar liyinhao
Browse files

merge funcs, change names

parent cbb549aa
......@@ -16,44 +16,42 @@ class IndoorBaseDataset(torch_data.Dataset):
ann_file,
pipeline=None,
training=False,
cat_ids=None,
classes=None,
test_mode=False,
with_label=True):
super().__init__()
self.root_path = root_path
self.cat_ids = cat_ids if cat_ids else self.CLASSES
self.CLASSES = classes if classes else self.CLASSES
self.test_mode = test_mode
self.training = training
self.mode = 'TRAIN' if self.training else 'TEST'
self.label2cat = {i: cat_id for i, cat_id in enumerate(self.cat_ids)}
self.label2cat = {i: cat_id for i, cat_id in enumerate(self.CLASSES)}
mmcv.check_file_exist(ann_file)
self.infos = mmcv.load(ann_file)
self.data_infos = mmcv.load(ann_file)
# dataset config
self.num_class = len(self.cat_ids)
self.pcd_limit_range = [0, -40, -3.0, 70.4, 40, 3.0]
self.num_class = len(self.CLASSES)
if pipeline is not None:
self.pipeline = Compose(pipeline)
self.with_label = with_label
def __getitem__(self, idx):
if self.test_mode:
return self._prepare_test_data(idx)
return self.prepare_test_data(idx)
while True:
data = self._prepare_train_data(idx)
data = self.prepare_train_data(idx)
if data is None:
idx = self._rand_another(idx)
continue
return data
def _prepare_test_data(self, index):
input_dict = self._get_sensor_data(index)
def prepare_test_data(self, index):
input_dict = self.get_data_info(index)
example = self.pipeline(input_dict)
return example
def _prepare_train_data(self, index):
input_dict = self._get_sensor_data(index)
input_dict = self._train_pre_pipeline(input_dict)
def prepare_train_data(self, index):
input_dict = self.get_data_info(index)
if input_dict is None:
return None
example = self.pipeline(input_dict)
......@@ -61,13 +59,8 @@ class IndoorBaseDataset(torch_data.Dataset):
return None
return example
def _train_pre_pipeline(self, input_dict):
if len(input_dict['gt_bboxes_3d']) == 0:
return None
return input_dict
def _get_sensor_data(self, index):
info = self.infos[index]
def get_data_info(self, index):
info = self.data_infos[index]
sample_idx = info['point_cloud']['lidar_idx']
pts_filename = self._get_pts_filename(sample_idx)
......@@ -76,7 +69,8 @@ class IndoorBaseDataset(torch_data.Dataset):
if self.with_label:
annos = self._get_ann_info(index, sample_idx)
input_dict.update(annos)
if len(input_dict['gt_bboxes_3d']) == 0:
return None
return input_dict
def _rand_another(self, idx):
......@@ -132,9 +126,9 @@ class IndoorBaseDataset(torch_data.Dataset):
results = self.format_results(results)
from mmdet3d.core.evaluation import indoor_eval
assert len(metric) > 0
gt_annos = [copy.deepcopy(info['annos']) for info in self.infos]
gt_annos = [copy.deepcopy(info['annos']) for info in self.data_infos]
ret_dict = indoor_eval(gt_annos, results, metric, self.label2cat)
return ret_dict
def __len__(self):
return len(self.infos)
return len(self.data_infos)
import os.path as osp
import mmcv
import numpy as np
from mmdet.datasets import DATASETS
......@@ -20,22 +19,21 @@ class ScannetBaseDataset(IndoorBaseDataset):
ann_file,
pipeline=None,
training=False,
cat_ids=None,
classes=None,
test_mode=False,
with_label=True):
super().__init__(root_path, ann_file, pipeline, training, cat_ids,
super().__init__(root_path, ann_file, pipeline, training, classes,
test_mode, with_label)
self.data_path = osp.join(root_path, 'scannet_train_instance_data')
def _get_pts_filename(self, sample_idx):
pts_filename = osp.join(self.data_path, f'{sample_idx}_vert.npy')
mmcv.check_file_exist(pts_filename)
return pts_filename
def _get_ann_info(self, index, sample_idx):
# Use index to get the annos, thus the evalhook could also use this api
info = self.infos[index]
info = self.data_infos[index]
if info['annos']['gt_num'] != 0:
gt_bboxes_3d = info['annos']['gt_boxes_upright_depth'] # k, 6
gt_labels = info['annos']['class']
......
import os.path as osp
import mmcv
import numpy as np
from mmdet.datasets import DATASETS
......@@ -18,22 +17,21 @@ class SunrgbdBaseDataset(IndoorBaseDataset):
ann_file,
pipeline=None,
training=False,
cat_ids=None,
classes=None,
test_mode=False,
with_label=True):
super().__init__(root_path, ann_file, pipeline, training, cat_ids,
super().__init__(root_path, ann_file, pipeline, training, classes,
test_mode, with_label)
self.data_path = osp.join(root_path, 'sunrgbd_trainval')
def _get_pts_filename(self, sample_idx):
pts_filename = osp.join(self.data_path, 'lidar',
f'{sample_idx:06d}.npy')
mmcv.check_file_exist(pts_filename)
return pts_filename
def _get_ann_info(self, index, sample_idx):
# Use index to get the annos, thus the evalhook could also use this api
info = self.infos[index]
info = self.data_infos[index]
if info['annos']['gt_num'] != 0:
gt_bboxes_3d = info['annos']['gt_boxes_upright_depth'] # k, 6
gt_labels = info['annos']['class']
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment