Commit ef19d3cb authored by liyinhao's avatar liyinhao
Browse files

finish scannet dataset before evaluate

parent f584b970
import os
import os.path as osp
import mmcv
import numpy as np
import torch.utils.data as torch_data
from mmdet.datasets import DATASETS
from .pipelines import Compose
@DATASETS.register_module()
class ScannetDataset(torch_data.dataset):
type2class = {
'cabinet': 0,
'bed': 1,
'chair': 2,
'sofa': 3,
'table': 4,
'door': 5,
'window': 6,
'bookshelf': 7,
'picture': 8,
'counter': 9,
'desk': 10,
'curtain': 11,
'refrigerator': 12,
'showercurtrain': 13,
'toilet': 14,
'sink': 15,
'bathtub': 16,
'garbagebin': 17
}
class2type = {
0: 'cabinet',
1: 'bed',
2: 'chair',
3: 'sofa',
4: 'table',
5: 'door',
6: 'window',
7: 'bookshelf',
8: 'picture',
9: 'counter',
10: 'desk',
11: 'curtain',
12: 'refrigerator',
13: 'showercurtrain',
14: 'toilet',
15: 'sink',
16: 'bathtub',
17: 'garbagebin'
}
CLASSES = ('cabinet', 'bed', 'chair', 'sofa', 'table', 'door', 'window',
'bookshelf', 'picture', 'counter', 'desk', 'curtain',
'refrigerator', 'showercurtrain', 'toilet', 'sink', 'bathtub',
'garbagebin')
def __init__(self,
root_path,
ann_file,
split,
pipeline=None,
training=False,
class_names=None,
test_mode=False):
super().__init__()
self.root_path = root_path
self.class_names = class_names if class_names else self.CLASSES
self.split = split
self.data_path = os.path.join(root_path, 'scannet_train_instance_data')
self.test_mode = test_mode
self.training = training
self.mode = 'TRAIN' if self.training else 'TEST'
self.ann_file = ann_file
# set group flag for the sampler
if not self.test_mode:
self._set_group_flag()
self.scannet_infos = mmcv.load(ann_file)
# dataset config
self.num_class = len(self.class_names)
self.pcd_limit_range = [0, -40, -3.0, 70.4, 40, 3.0]
self.nyu40ids = np.array(
[3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 14, 16, 24, 28, 33, 34, 36, 39])
self.nyu40id2class = {
nyu40id: i
for i, nyu40id in enumerate(list(self.nyu40ids))
}
if pipeline is not None:
self.pipeline = Compose(pipeline)
def __getitem__(self, idx):
if self.test_mode:
return self.prepare_test_data(idx)
while True:
data = self.prepare_train_data(idx)
if data is None:
idx = self._rand_another(idx)
continue
return data
def prepare_test_data(self, index):
input_dict = self.get_sensor_data(index)
example = self.pipeline(input_dict)
return example
def prepare_train_data(self, index):
input_dict = self.get_sensor_data(index)
input_dict = self.train_pre_pipeline(input_dict)
if input_dict is None:
return None
example = self.pipeline(input_dict)
if len(example['gt_bboxes_3d']._data) == 0:
return None
return example
def train_pre_pipeline(self, input_dict):
if len(input_dict['gt_bboxes_3d']) == 0:
return None
return input_dict
def get_sensor_data(self, index):
info = self.scannet_infos[index]
sample_idx = info['point_cloud']['lidar_idx']
points = self.get_lidar(sample_idx)
input_dict = dict(
sample_idx=sample_idx,
points=points,
)
if self.with_label:
annos = self.get_ann_info(index, sample_idx)
input_dict.update(annos)
return input_dict
def get_lidar(self, sample_idx):
lidar_file = os.path.join(self.data_path, sample_idx + '_vert.npy')
assert os.path.exists(lidar_file)
return np.load(lidar_file)
def get_ann_info(self, index, sample_idx):
# Use index to get the annos, thus the evalhook could also use this api
info = self.kitti_infos[index]
if info['annos']['gt_num'] != 0:
gt_bboxes_3d = info['annos']['gt_boxes_upright_depth'] # k, 6
gt_labels = info['annos']['class'].reshape(-1, 1)
gt_bboxes_3d_mask = np.ones_like(gt_labels)
else:
gt_bboxes_3d = np.zeros((1, 6), dtype=np.float32)
gt_labels = np.zeros((1, 1))
gt_bboxes_3d_mask = np.zeros((1, 1))
pts_instance_mask_path = osp.join(self.data_path,
sample_idx + '_ins_label.npy')
pts_semantic_mask_path = osp.join(self.data_path,
sample_idx + '_sem_label.npy')
anns_results = dict(
gt_bboxes_3d=gt_bboxes_3d,
gt_labels=gt_labels,
gt_bboxes_3d_mask=gt_bboxes_3d_mask,
pts_instance_mask_path=pts_instance_mask_path,
pts_semantic_mask_path=pts_semantic_mask_path)
return anns_results
def _rand_another(self, idx):
pool = np.where(self.flag == self.flag[idx])[0]
return np.random.choice(pool)
def __len__(self):
return len(self.scannet_infos)
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment