Commit 72350b2d authored by liyinhao's avatar liyinhao
Browse files

merge funcs, change names

parent cbb549aa
...@@ -16,44 +16,42 @@ class IndoorBaseDataset(torch_data.Dataset): ...@@ -16,44 +16,42 @@ class IndoorBaseDataset(torch_data.Dataset):
ann_file, ann_file,
pipeline=None, pipeline=None,
training=False, training=False,
cat_ids=None, classes=None,
test_mode=False, test_mode=False,
with_label=True): with_label=True):
super().__init__() super().__init__()
self.root_path = root_path self.root_path = root_path
self.cat_ids = cat_ids if cat_ids else self.CLASSES self.CLASSES = classes if classes else self.CLASSES
self.test_mode = test_mode self.test_mode = test_mode
self.training = training self.training = training
self.mode = 'TRAIN' if self.training else 'TEST' self.mode = 'TRAIN' if self.training else 'TEST'
self.label2cat = {i: cat_id for i, cat_id in enumerate(self.cat_ids)} self.label2cat = {i: cat_id for i, cat_id in enumerate(self.CLASSES)}
mmcv.check_file_exist(ann_file) mmcv.check_file_exist(ann_file)
self.infos = mmcv.load(ann_file) self.data_infos = mmcv.load(ann_file)
# dataset config # dataset config
self.num_class = len(self.cat_ids) self.num_class = len(self.CLASSES)
self.pcd_limit_range = [0, -40, -3.0, 70.4, 40, 3.0]
if pipeline is not None: if pipeline is not None:
self.pipeline = Compose(pipeline) self.pipeline = Compose(pipeline)
self.with_label = with_label self.with_label = with_label
def __getitem__(self, idx): def __getitem__(self, idx):
if self.test_mode: if self.test_mode:
return self._prepare_test_data(idx) return self.prepare_test_data(idx)
while True: while True:
data = self._prepare_train_data(idx) data = self.prepare_train_data(idx)
if data is None: if data is None:
idx = self._rand_another(idx) idx = self._rand_another(idx)
continue continue
return data return data
def _prepare_test_data(self, index): def prepare_test_data(self, index):
input_dict = self._get_sensor_data(index) input_dict = self.get_data_info(index)
example = self.pipeline(input_dict) example = self.pipeline(input_dict)
return example return example
def _prepare_train_data(self, index): def prepare_train_data(self, index):
input_dict = self._get_sensor_data(index) input_dict = self.get_data_info(index)
input_dict = self._train_pre_pipeline(input_dict)
if input_dict is None: if input_dict is None:
return None return None
example = self.pipeline(input_dict) example = self.pipeline(input_dict)
...@@ -61,13 +59,8 @@ class IndoorBaseDataset(torch_data.Dataset): ...@@ -61,13 +59,8 @@ class IndoorBaseDataset(torch_data.Dataset):
return None return None
return example return example
def _train_pre_pipeline(self, input_dict): def get_data_info(self, index):
if len(input_dict['gt_bboxes_3d']) == 0: info = self.data_infos[index]
return None
return input_dict
def _get_sensor_data(self, index):
info = self.infos[index]
sample_idx = info['point_cloud']['lidar_idx'] sample_idx = info['point_cloud']['lidar_idx']
pts_filename = self._get_pts_filename(sample_idx) pts_filename = self._get_pts_filename(sample_idx)
...@@ -76,7 +69,8 @@ class IndoorBaseDataset(torch_data.Dataset): ...@@ -76,7 +69,8 @@ class IndoorBaseDataset(torch_data.Dataset):
if self.with_label: if self.with_label:
annos = self._get_ann_info(index, sample_idx) annos = self._get_ann_info(index, sample_idx)
input_dict.update(annos) input_dict.update(annos)
if len(input_dict['gt_bboxes_3d']) == 0:
return None
return input_dict return input_dict
def _rand_another(self, idx): def _rand_another(self, idx):
...@@ -132,9 +126,9 @@ class IndoorBaseDataset(torch_data.Dataset): ...@@ -132,9 +126,9 @@ class IndoorBaseDataset(torch_data.Dataset):
results = self.format_results(results) results = self.format_results(results)
from mmdet3d.core.evaluation import indoor_eval from mmdet3d.core.evaluation import indoor_eval
assert len(metric) > 0 assert len(metric) > 0
gt_annos = [copy.deepcopy(info['annos']) for info in self.infos] gt_annos = [copy.deepcopy(info['annos']) for info in self.data_infos]
ret_dict = indoor_eval(gt_annos, results, metric, self.label2cat) ret_dict = indoor_eval(gt_annos, results, metric, self.label2cat)
return ret_dict return ret_dict
def __len__(self): def __len__(self):
return len(self.infos) return len(self.data_infos)
import os.path as osp import os.path as osp
import mmcv
import numpy as np import numpy as np
from mmdet.datasets import DATASETS from mmdet.datasets import DATASETS
...@@ -20,22 +19,21 @@ class ScannetBaseDataset(IndoorBaseDataset): ...@@ -20,22 +19,21 @@ class ScannetBaseDataset(IndoorBaseDataset):
ann_file, ann_file,
pipeline=None, pipeline=None,
training=False, training=False,
cat_ids=None, classes=None,
test_mode=False, test_mode=False,
with_label=True): with_label=True):
super().__init__(root_path, ann_file, pipeline, training, cat_ids, super().__init__(root_path, ann_file, pipeline, training, classes,
test_mode, with_label) test_mode, with_label)
self.data_path = osp.join(root_path, 'scannet_train_instance_data') self.data_path = osp.join(root_path, 'scannet_train_instance_data')
def _get_pts_filename(self, sample_idx): def _get_pts_filename(self, sample_idx):
pts_filename = osp.join(self.data_path, f'{sample_idx}_vert.npy') pts_filename = osp.join(self.data_path, f'{sample_idx}_vert.npy')
mmcv.check_file_exist(pts_filename)
return pts_filename return pts_filename
def _get_ann_info(self, index, sample_idx): def _get_ann_info(self, index, sample_idx):
# Use index to get the annos, thus the evalhook could also use this api # Use index to get the annos, thus the evalhook could also use this api
info = self.infos[index] info = self.data_infos[index]
if info['annos']['gt_num'] != 0: if info['annos']['gt_num'] != 0:
gt_bboxes_3d = info['annos']['gt_boxes_upright_depth'] # k, 6 gt_bboxes_3d = info['annos']['gt_boxes_upright_depth'] # k, 6
gt_labels = info['annos']['class'] gt_labels = info['annos']['class']
......
import os.path as osp import os.path as osp
import mmcv
import numpy as np import numpy as np
from mmdet.datasets import DATASETS from mmdet.datasets import DATASETS
...@@ -18,22 +17,21 @@ class SunrgbdBaseDataset(IndoorBaseDataset): ...@@ -18,22 +17,21 @@ class SunrgbdBaseDataset(IndoorBaseDataset):
ann_file, ann_file,
pipeline=None, pipeline=None,
training=False, training=False,
cat_ids=None, classes=None,
test_mode=False, test_mode=False,
with_label=True): with_label=True):
super().__init__(root_path, ann_file, pipeline, training, cat_ids, super().__init__(root_path, ann_file, pipeline, training, classes,
test_mode, with_label) test_mode, with_label)
self.data_path = osp.join(root_path, 'sunrgbd_trainval') self.data_path = osp.join(root_path, 'sunrgbd_trainval')
def _get_pts_filename(self, sample_idx): def _get_pts_filename(self, sample_idx):
pts_filename = osp.join(self.data_path, 'lidar', pts_filename = osp.join(self.data_path, 'lidar',
f'{sample_idx:06d}.npy') f'{sample_idx:06d}.npy')
mmcv.check_file_exist(pts_filename)
return pts_filename return pts_filename
def _get_ann_info(self, index, sample_idx): def _get_ann_info(self, index, sample_idx):
# Use index to get the annos, thus the evalhook could also use this api # Use index to get the annos, thus the evalhook could also use this api
info = self.infos[index] info = self.data_infos[index]
if info['annos']['gt_num'] != 0: if info['annos']['gt_num'] != 0:
gt_bboxes_3d = info['annos']['gt_boxes_upright_depth'] # k, 6 gt_bboxes_3d = info['annos']['gt_boxes_upright_depth'] # k, 6
gt_labels = info['annos']['class'] gt_labels = info['annos']['class']
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment