Commit d71edf6c authored by yinchimaoliang

finish test getitem

parent 49121b64
@@ -127,8 +127,9 @@ class DefaultFormatBundle3D(DefaultFormatBundle):
             gt_bboxes_3d_mask = results['gt_bboxes_3d_mask']
             results['gt_bboxes_3d'] = results['gt_bboxes_3d'][
                 gt_bboxes_3d_mask]
-            results['gt_names_3d'] = results['gt_names_3d'][
-                gt_bboxes_3d_mask]
+            if 'gt_names_3d' in results:
+                results['gt_names_3d'] = results['gt_names_3d'][
+                    gt_bboxes_3d_mask]
         if 'gt_bboxes_mask' in results:
             gt_bboxes_mask = results['gt_bboxes_mask']
             if 'gt_bboxes' in results:
@@ -151,10 +152,12 @@ class DefaultFormatBundle3D(DefaultFormatBundle):
                     dtype=np.int64)
             # we still assume one pipeline for one frame LiDAR
             # thus, the 3D name is list[string]
-            results['gt_labels_3d'] = np.array([
-                self.class_names.index(n) for n in results['gt_names_3d']
-            ],
-                                               dtype=np.int64)
+            if 'gt_names_3d' in results:
+                results['gt_labels_3d'] = np.array([
+                    self.class_names.index(n)
+                    for n in results['gt_names_3d']
+                ],
+                                                   dtype=np.int64)
         results = super(DefaultFormatBundle3D, self).__call__(results)
         return results
......
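To illustrate the effect of the new 'gt_names_3d' guard, here is a minimal sketch; the toy `results` dict below is an assumption for illustration only, not the full mmdet3d result dict produced by the loading pipeline:

import numpy as np

# Minimal sketch (assumed toy inputs): filter boxes with the boolean mask,
# but only touch 'gt_names_3d' when the dataset actually provides it.
results = dict(
    gt_bboxes_3d=np.zeros((3, 6), dtype=np.float32),
    gt_bboxes_3d_mask=np.array([True, False, True]),
)

mask = results['gt_bboxes_3d_mask']
results['gt_bboxes_3d'] = results['gt_bboxes_3d'][mask]
# Without the guard this line would raise KeyError for datasets (such as the
# refactored ScannetDataset below) that return numeric labels instead of
# per-box name strings.
if 'gt_names_3d' in results:
    results['gt_names_3d'] = np.array(results['gt_names_3d'])[mask]

print(results['gt_bboxes_3d'].shape)  # (2, 6)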
@@ -11,7 +11,7 @@ from .pipelines import Compose

 @DATASETS.register_module()
-class ScannetDataset(torch_data.dataset):
+class ScannetDataset(torch_data.Dataset):
     type2class = {
         'cabinet': 0,
         'bed': 1,
@@ -60,15 +60,14 @@ class ScannetDataset(torch_data.dataset):
     def __init__(self,
                  root_path,
                  ann_file,
-                 split,
                  pipeline=None,
                  training=False,
                  class_names=None,
-                 test_mode=False):
+                 test_mode=False,
+                 with_label=True):
         super().__init__()
         self.root_path = root_path
         self.class_names = class_names if class_names else self.CLASSES
-        self.split = split
         self.data_path = os.path.join(root_path, 'scannet_train_instance_data')
         self.test_mode = test_mode
@@ -76,10 +75,6 @@ class ScannetDataset(torch_data.dataset):
         self.mode = 'TRAIN' if self.training else 'TEST'
         self.ann_file = ann_file
-        # set group flag for the sampler
-        if not self.test_mode:
-            self._set_group_flag()
         self.scannet_infos = mmcv.load(ann_file)
         # dataset config
@@ -93,25 +88,26 @@ class ScannetDataset(torch_data.dataset):
         }
         if pipeline is not None:
             self.pipeline = Compose(pipeline)
+        self.with_label = with_label

     def __getitem__(self, idx):
         if self.test_mode:
-            return self.prepare_test_data(idx)
+            return self._prepare_test_data(idx)
         while True:
-            data = self.prepare_train_data(idx)
+            data = self._prepare_train_data(idx)
             if data is None:
                 idx = self._rand_another(idx)
                 continue
             return data

-    def prepare_test_data(self, index):
-        input_dict = self.get_sensor_data(index)
+    def _prepare_test_data(self, index):
+        input_dict = self._get_sensor_data(index)
         example = self.pipeline(input_dict)
         return example

-    def prepare_train_data(self, index):
-        input_dict = self.get_sensor_data(index)
-        input_dict = self.train_pre_pipeline(input_dict)
+    def _prepare_train_data(self, index):
+        input_dict = self._get_sensor_data(index)
+        input_dict = self._train_pre_pipeline(input_dict)
         if input_dict is None:
             return None
         example = self.pipeline(input_dict)
@@ -119,43 +115,40 @@ class ScannetDataset(torch_data.dataset):
             return None
         return example

-    def train_pre_pipeline(self, input_dict):
+    def _train_pre_pipeline(self, input_dict):
         if len(input_dict['gt_bboxes_3d']) == 0:
             return None
         return input_dict

-    def get_sensor_data(self, index):
+    def _get_sensor_data(self, index):
         info = self.scannet_infos[index]
         sample_idx = info['point_cloud']['lidar_idx']
-        points = self.get_lidar(sample_idx)
-        input_dict = dict(
-            sample_idx=sample_idx,
-            points=points,
-        )
+        pts_filename = self._get_pts_filename(sample_idx)
+        input_dict = dict(pts_filename=pts_filename)
         if self.with_label:
-            annos = self.get_ann_info(index, sample_idx)
+            annos = self._get_ann_info(index, sample_idx)
             input_dict.update(annos)
         return input_dict

-    def get_lidar(self, sample_idx):
-        lidar_file = os.path.join(self.data_path, sample_idx + '_vert.npy')
-        assert os.path.exists(lidar_file)
-        return np.load(lidar_file)
+    def _get_pts_filename(self, sample_idx):
+        pts_filename = os.path.join(self.data_path, sample_idx + '_vert.npy')
+        mmcv.check_file_exist(pts_filename)
+        return pts_filename

-    def get_ann_info(self, index, sample_idx):
+    def _get_ann_info(self, index, sample_idx):
         # Use index to get the annos, thus the evalhook could also use this api
-        info = self.kitti_infos[index]
+        info = self.scannet_infos[index]
         if info['annos']['gt_num'] != 0:
             gt_bboxes_3d = info['annos']['gt_boxes_upright_depth']  # k, 6
-            gt_labels = info['annos']['class'].reshape(-1, 1)
-            gt_bboxes_3d_mask = np.ones_like(gt_labels)
+            gt_labels = info['annos']['class']
+            gt_bboxes_3d_mask = np.ones_like(gt_labels).astype(np.bool)
         else:
             gt_bboxes_3d = np.zeros((1, 6), dtype=np.float32)
-            gt_labels = np.zeros((1, 1))
-            gt_bboxes_3d_mask = np.zeros((1, 1))
+            gt_labels = np.zeros(1, ).astype(np.bool)
+            gt_bboxes_3d_mask = np.zeros(1, ).astype(np.bool)
         pts_instance_mask_path = osp.join(self.data_path,
                                           sample_idx + '_ins_label.npy')
         pts_semantic_mask_path = osp.join(self.data_path,
@@ -173,7 +166,7 @@ class ScannetDataset(torch_data.dataset):
         pool = np.where(self.flag == self.flag[idx])[0]
         return np.random.choice(pool)

-    def generate_annotations(self, output):
+    def _generate_annotations(self, output):
         '''
         transfer input_dict & pred_dicts to anno format
         which is needed by AP calculator
@@ -209,15 +202,15 @@ class ScannetDataset(torch_data.dataset):
         return result

-    def format_results(self, outputs):
+    def _format_results(self, outputs):
         results = []
         for output in outputs:
-            result = self.generate_annotations(output)
+            result = self._generate_annotations(output)
             results.append(result)
         return results

     def evaluate(self, results, metric=None, logger=None, pklfile_prefix=None):
-        results = self.format_results(results)
+        results = self._format_results(results)
         from mmdet3d.core.evaluation.scannet_utils.eval import scannet_eval
         assert ('AP_IOU_THRESHHOLDS' in metric)
         gt_annos = [
......
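For reference, a hedged sketch of the per-scene entry that `_get_ann_info` expects in `scannet_infos`, reconstructed from the field accesses above; the scene id and values are placeholders, not real ScanNet data, and the sketch uses the builtin bool where the diff uses the now-deprecated np.bool alias:

import numpy as np

# Assumed shape of one entry of `scannet_infos` (placeholder values).
info = {
    'point_cloud': {
        'lidar_idx': 'scene0000_00',  # used to build '<idx>_vert.npy' etc.
    },
    'annos': {
        'gt_num': 2,
        # (k, 6) axis-aligned boxes in upright depth coordinates
        'gt_boxes_upright_depth': np.zeros((2, 6), dtype=np.float32),
        'class': np.array([0, 3]),  # per-box class indices
    },
}

# Mirrors the branch in `_get_ann_info`: build boolean masks so that
# DefaultFormatBundle3D can filter boxes via 'gt_bboxes_3d_mask'.
if info['annos']['gt_num'] != 0:
    gt_bboxes_3d = info['annos']['gt_boxes_upright_depth']
    gt_labels = info['annos']['class']
    gt_bboxes_3d_mask = np.ones_like(gt_labels).astype(bool)
else:
    gt_bboxes_3d = np.zeros((1, 6), dtype=np.float32)
    gt_labels = np.zeros(1).astype(bool)
    gt_bboxes_3d_mask = np.zeros(1).astype(bool)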
import numpy as np

from mmdet3d.datasets.scannet_dataset import ScannetDataset


def test_getitem():
    np.random.seed(0)
    root_path = './tests/data/scannet'
    ann_file = './tests/data/scannet/scannet_infos.pkl'
    class_names = ('cabinet', 'bed', 'chair', 'sofa', 'table', 'door',
                   'window', 'bookshelf', 'picture', 'counter', 'desk',
                   'curtain', 'refrigerator', 'showercurtrain', 'toilet',
                   'sink', 'bathtub', 'garbagebin')
    pipelines = [
        dict(
            type='IndoorLoadPointsFromFile',
            use_height=True,
            load_dim=6,
            use_dim=[0, 1, 2]),
        dict(type='IndoorLoadAnnotations3D'),
        dict(type='IndoorPointSample', num_points=5),
        dict(type='IndoorFlipData', flip_ratio_yz=1.0, flip_ratio_xz=1.0),
        dict(
            type='IndoorGlobalRotScale',
            use_height=True,
            rot_range=[-np.pi * 1 / 36, np.pi * 1 / 36],
            scale_range=None),
        dict(type='DefaultFormatBundle3D', class_names=class_names),
        dict(
            type='Collect3D',
            keys=[
                'points', 'gt_bboxes_3d', 'gt_labels', 'pts_semantic_mask',
                'pts_instance_mask'
            ]),
    ]
    scannet_dataset = ScannetDataset(root_path, ann_file, pipelines, True)
    data = scannet_dataset[0]
    points = data['points']._data
    gt_bboxes_3d = data['gt_bboxes_3d']._data
    gt_labels = data['gt_labels']._data
    pts_semantic_mask = data['pts_semantic_mask']
    pts_instance_mask = data['pts_instance_mask']
    expected_points = np.array(
        [[-2.9078157, -1.9569951, 2.3543026, 2.389488],
         [-0.71360034, -3.4359822, 2.1330001, 2.1681855],
         [-1.332374, 1.474838, -0.04405887, -0.00887359],
         [2.1336637, -1.3265059, -0.02880373, 0.00638155],
         [0.43895668, -3.0259454, 1.5560012, 1.5911865]])
    expected_gt_bboxes_3d = np.array([
        [-1.5005362, -3.512584, 1.8565295, 1.7457027, 0.24149807, 0.57235193],
        [-2.8848705, 3.4961755, 1.5268247, 0.66170084, 0.17433672, 0.67153597],
        [-1.1585636, -2.192365, 0.61649567, 0.5557011, 2.5375574, 1.2144762],
        [-2.930457, -2.4856408, 0.9722377, 0.6270478, 1.8461524, 0.28697443],
        [3.3114715, -0.00476722, 1.0712197, 0.46191898, 3.8605113, 2.1603441]
    ])
    expected_gt_labels = np.array([
        6, 6, 4, 9, 11, 11, 10, 0, 15, 17, 17, 17, 3, 12, 4, 4, 14, 1, 0, 0, 0,
        0, 0, 0, 5, 5, 5
    ])
    expected_pts_semantic_mask = np.array([3, 1, 2, 2, 15])
    expected_pts_instance_mask = np.array([44, 22, 10, 10, 57])
    assert np.allclose(points, expected_points)
    assert gt_bboxes_3d[:5].shape == (5, 6)
    assert np.allclose(gt_bboxes_3d[:5], expected_gt_bboxes_3d)
    assert np.all(gt_labels.numpy() == expected_gt_labels)
    assert np.all(pts_semantic_mask == expected_pts_semantic_mask)
    assert np.all(pts_instance_mask == expected_pts_instance_mask)
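A note on the `._data` accesses in the test: `DefaultFormatBundle3D` and `Collect3D` wrap the collected fields in mmcv `DataContainer` objects, so the test unwraps them before comparing. A minimal sketch of that unwrapping (the tensor contents are arbitrary placeholders):

import torch
from mmcv.parallel import DataContainer as DC

# The formatting bundle wraps tensors so the collate function knows how to
# batch them; `.data` is the public accessor for the private `._data`.
wrapped = DC(torch.zeros(5, 4), stack=False)
assert torch.equal(wrapped.data, wrapped._data)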
@@ -161,7 +161,7 @@ def main():
         mmcv.dump(outputs, args.out)
     kwargs = {} if args.options is None else args.options
     if args.format_only:
-        dataset.format_results(outputs, **kwargs)
+        dataset._format_results(outputs, **kwargs)
     if args.eval:
         dataset.evaluate(outputs, args.eval, **kwargs)
......