Commit d71edf6c authored by yinchimaoliang's avatar yinchimaoliang
Browse files

finish test getitem

parent 49121b64
......@@ -127,6 +127,7 @@ class DefaultFormatBundle3D(DefaultFormatBundle):
gt_bboxes_3d_mask = results['gt_bboxes_3d_mask']
results['gt_bboxes_3d'] = results['gt_bboxes_3d'][
gt_bboxes_3d_mask]
if 'gt_names_3d' in results:
results['gt_names_3d'] = results['gt_names_3d'][
gt_bboxes_3d_mask]
if 'gt_bboxes_mask' in results:
......@@ -151,8 +152,10 @@ class DefaultFormatBundle3D(DefaultFormatBundle):
dtype=np.int64)
# we still assume one pipeline for one frame LiDAR
# thus, the 3D name is list[string]
if 'gt_names_3d' in results:
results['gt_labels_3d'] = np.array([
self.class_names.index(n) for n in results['gt_names_3d']
self.class_names.index(n)
for n in results['gt_names_3d']
],
dtype=np.int64)
results = super(DefaultFormatBundle3D, self).__call__(results)
......
......@@ -11,7 +11,7 @@ from .pipelines import Compose
@DATASETS.register_module()
class ScannetDataset(torch_data.dataset):
class ScannetDataset(torch_data.Dataset):
type2class = {
'cabinet': 0,
'bed': 1,
......@@ -60,15 +60,14 @@ class ScannetDataset(torch_data.dataset):
def __init__(self,
root_path,
ann_file,
split,
pipeline=None,
training=False,
class_names=None,
test_mode=False):
test_mode=False,
with_label=True):
super().__init__()
self.root_path = root_path
self.class_names = class_names if class_names else self.CLASSES
self.split = split
self.data_path = os.path.join(root_path, 'scannet_train_instance_data')
self.test_mode = test_mode
......@@ -76,10 +75,6 @@ class ScannetDataset(torch_data.dataset):
self.mode = 'TRAIN' if self.training else 'TEST'
self.ann_file = ann_file
# set group flag for the sampler
if not self.test_mode:
self._set_group_flag()
self.scannet_infos = mmcv.load(ann_file)
# dataset config
......@@ -93,25 +88,26 @@ class ScannetDataset(torch_data.dataset):
}
if pipeline is not None:
self.pipeline = Compose(pipeline)
self.with_label = with_label
def __getitem__(self, idx):
if self.test_mode:
return self.prepare_test_data(idx)
return self._prepare_test_data(idx)
while True:
data = self.prepare_train_data(idx)
data = self._prepare_train_data(idx)
if data is None:
idx = self._rand_another(idx)
continue
return data
def prepare_test_data(self, index):
input_dict = self.get_sensor_data(index)
def _prepare_test_data(self, index):
input_dict = self._get_sensor_data(index)
example = self.pipeline(input_dict)
return example
def prepare_train_data(self, index):
input_dict = self.get_sensor_data(index)
input_dict = self.train_pre_pipeline(input_dict)
def _prepare_train_data(self, index):
input_dict = self._get_sensor_data(index)
input_dict = self._train_pre_pipeline(input_dict)
if input_dict is None:
return None
example = self.pipeline(input_dict)
......@@ -119,43 +115,40 @@ class ScannetDataset(torch_data.dataset):
return None
return example
def train_pre_pipeline(self, input_dict):
def _train_pre_pipeline(self, input_dict):
if len(input_dict['gt_bboxes_3d']) == 0:
return None
return input_dict
def get_sensor_data(self, index):
def _get_sensor_data(self, index):
info = self.scannet_infos[index]
sample_idx = info['point_cloud']['lidar_idx']
points = self.get_lidar(sample_idx)
pts_filename = self._get_pts_filename(sample_idx)
input_dict = dict(
sample_idx=sample_idx,
points=points,
)
input_dict = dict(pts_filename=pts_filename)
if self.with_label:
annos = self.get_ann_info(index, sample_idx)
annos = self._get_ann_info(index, sample_idx)
input_dict.update(annos)
return input_dict
def get_lidar(self, sample_idx):
lidar_file = os.path.join(self.data_path, sample_idx + '_vert.npy')
assert os.path.exists(lidar_file)
return np.load(lidar_file)
def _get_pts_filename(self, sample_idx):
pts_filename = os.path.join(self.data_path, sample_idx + '_vert.npy')
mmcv.check_file_exist(pts_filename)
return pts_filename
def get_ann_info(self, index, sample_idx):
def _get_ann_info(self, index, sample_idx):
# Use index to get the annos, thus the evalhook could also use this api
info = self.kitti_infos[index]
info = self.scannet_infos[index]
if info['annos']['gt_num'] != 0:
gt_bboxes_3d = info['annos']['gt_boxes_upright_depth'] # k, 6
gt_labels = info['annos']['class'].reshape(-1, 1)
gt_bboxes_3d_mask = np.ones_like(gt_labels)
gt_labels = info['annos']['class']
gt_bboxes_3d_mask = np.ones_like(gt_labels).astype(np.bool)
else:
gt_bboxes_3d = np.zeros((1, 6), dtype=np.float32)
gt_labels = np.zeros((1, 1))
gt_bboxes_3d_mask = np.zeros((1, 1))
gt_labels = np.zeros(1, ).astype(np.bool)
gt_bboxes_3d_mask = np.zeros(1, ).astype(np.bool)
pts_instance_mask_path = osp.join(self.data_path,
sample_idx + '_ins_label.npy')
pts_semantic_mask_path = osp.join(self.data_path,
......@@ -173,7 +166,7 @@ class ScannetDataset(torch_data.dataset):
pool = np.where(self.flag == self.flag[idx])[0]
return np.random.choice(pool)
def generate_annotations(self, output):
def _generate_annotations(self, output):
'''
transfer input_dict & pred_dicts to anno format
which is needed by AP calculator
......@@ -209,15 +202,15 @@ class ScannetDataset(torch_data.dataset):
return result
def format_results(self, outputs):
def _format_results(self, outputs):
results = []
for output in outputs:
result = self.generate_annotations(output)
result = self._generate_annotations(output)
results.append(result)
return results
def evaluate(self, results, metric=None, logger=None, pklfile_prefix=None):
results = self.format_results(results)
results = self._format_results(results)
from mmdet3d.core.evaluation.scannet_utils.eval import scannet_eval
assert ('AP_IOU_THRESHHOLDS' in metric)
gt_annos = [
......
import numpy as np
from mmdet3d.datasets.scannet_dataset import ScannetDataset
def test_getitem():
np.random.seed(0)
root_path = './tests/data/scannet'
ann_file = './tests/data/scannet/scannet_infos.pkl'
class_names = ('cabinet', 'bed', 'chair', 'sofa', 'table', 'door',
'window', 'bookshelf', 'picture', 'counter', 'desk',
'curtain', 'refrigerator', 'showercurtrain', 'toilet',
'sink', 'bathtub', 'garbagebin')
pipelines = [
dict(
type='IndoorLoadPointsFromFile',
use_height=True,
load_dim=6,
use_dim=[0, 1, 2]),
dict(type='IndoorLoadAnnotations3D'),
dict(type='IndoorPointSample', num_points=5),
dict(type='IndoorFlipData', flip_ratio_yz=1.0, flip_ratio_xz=1.0),
dict(
type='IndoorGlobalRotScale',
use_height=True,
rot_range=[-np.pi * 1 / 36, np.pi * 1 / 36],
scale_range=None),
dict(type='DefaultFormatBundle3D', class_names=class_names),
dict(
type='Collect3D',
keys=[
'points', 'gt_bboxes_3d', 'gt_labels', 'pts_semantic_mask',
'pts_instance_mask'
]),
]
scannet_dataset = ScannetDataset(root_path, ann_file, pipelines, True)
data = scannet_dataset[0]
points = data['points']._data
gt_bboxes_3d = data['gt_bboxes_3d']._data
gt_labels = data['gt_labels']._data
pts_semantic_mask = data['pts_semantic_mask']
pts_instance_mask = data['pts_instance_mask']
expected_points = np.array(
[[-2.9078157, -1.9569951, 2.3543026, 2.389488],
[-0.71360034, -3.4359822, 2.1330001, 2.1681855],
[-1.332374, 1.474838, -0.04405887, -0.00887359],
[2.1336637, -1.3265059, -0.02880373, 0.00638155],
[0.43895668, -3.0259454, 1.5560012, 1.5911865]])
expected_gt_bboxes_3d = np.array([
[-1.5005362, -3.512584, 1.8565295, 1.7457027, 0.24149807, 0.57235193],
[-2.8848705, 3.4961755, 1.5268247, 0.66170084, 0.17433672, 0.67153597],
[-1.1585636, -2.192365, 0.61649567, 0.5557011, 2.5375574, 1.2144762],
[-2.930457, -2.4856408, 0.9722377, 0.6270478, 1.8461524, 0.28697443],
[3.3114715, -0.00476722, 1.0712197, 0.46191898, 3.8605113, 2.1603441]
])
expected_gt_labels = np.array([
6, 6, 4, 9, 11, 11, 10, 0, 15, 17, 17, 17, 3, 12, 4, 4, 14, 1, 0, 0, 0,
0, 0, 0, 5, 5, 5
])
expected_pts_semantic_mask = np.array([3, 1, 2, 2, 15])
expected_pts_instance_mask = np.array([44, 22, 10, 10, 57])
assert np.allclose(points, expected_points)
assert gt_bboxes_3d[:5].shape == (5, 6)
assert np.allclose(gt_bboxes_3d[:5], expected_gt_bboxes_3d)
assert np.all(gt_labels.numpy() == expected_gt_labels)
assert np.all(pts_semantic_mask == expected_pts_semantic_mask)
assert np.all(pts_instance_mask == expected_pts_instance_mask)
......@@ -161,7 +161,7 @@ def main():
mmcv.dump(outputs, args.out)
kwargs = {} if args.options is None else args.options
if args.format_only:
dataset.format_results(outputs, **kwargs)
dataset._format_results(outputs, **kwargs)
if args.eval:
dataset.evaluate(outputs, args.eval, **kwargs)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment