import copy
import os
import os.path as osp

import mmcv
import numpy as np
import torch.utils.data as torch_data

from mmdet.datasets import DATASETS
from .pipelines import Compose


@DATASETS.register_module()
class ScannetDataset(torch_data.Dataset):
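    """ScanNet dataset for indoor 3D object detection.

    The dataset reads a pickled list of per-scene info dicts (``ann_file``)
    together with the per-scene ``*_vert.npy``, ``*_ins_label.npy`` and
    ``*_sem_label.npy`` files under
    ``<root_path>/scannet_train_instance_data``, and feeds every sample
    through the configured processing pipeline.

    Args:
        root_path (str): Root directory of the prepared ScanNet data.
        ann_file (str): Path to the pickled annotation info file.
        pipeline (list[dict], optional): Config of the processing pipeline.
        training (bool): Whether the dataset is used for training.
        class_names (tuple[str], optional): Names of the classes to use;
            defaults to ``CLASSES`` when not given.
        test_mode (bool): Whether the dataset runs in test mode.
        with_label (bool): Whether to load annotations for each sample.
    """
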
    type2class = {
        'cabinet': 0,
        'bed': 1,
        'chair': 2,
        'sofa': 3,
        'table': 4,
        'door': 5,
        'window': 6,
        'bookshelf': 7,
        'picture': 8,
        'counter': 9,
        'desk': 10,
        'curtain': 11,
        'refrigerator': 12,
        'showercurtrain': 13,
        'toilet': 14,
        'sink': 15,
        'bathtub': 16,
        'garbagebin': 17
    }
    class2type = {
        0: 'cabinet',
        1: 'bed',
        2: 'chair',
        3: 'sofa',
        4: 'table',
        5: 'door',
        6: 'window',
        7: 'bookshelf',
        8: 'picture',
        9: 'counter',
        10: 'desk',
        11: 'curtain',
        12: 'refrigerator',
        13: 'showercurtrain',
        14: 'toilet',
        15: 'sink',
        16: 'bathtub',
        17: 'garbagebin'
    }
    CLASSES = ('cabinet', 'bed', 'chair', 'sofa', 'table', 'door', 'window',
               'bookshelf', 'picture', 'counter', 'desk', 'curtain',
               'refrigerator', 'showercurtrain', 'toilet', 'sink', 'bathtub',
               'garbagebin')

    def __init__(self,
                 root_path,
                 ann_file,
                 pipeline=None,
                 training=False,
                 class_names=None,
                 test_mode=False,
                 with_label=True):
        super().__init__()
        self.root_path = root_path
        self.class_names = class_names if class_names else self.CLASSES

        self.data_path = osp.join(root_path, 'scannet_train_instance_data')
        self.test_mode = test_mode
        self.training = training
        self.mode = 'TRAIN' if self.training else 'TEST'

        mmcv.check_file_exist(ann_file)
        self.scannet_infos = mmcv.load(ann_file)

        # dataset config
        self.num_class = len(self.class_names)
        self.pcd_limit_range = [0, -40, -3.0, 70.4, 40, 3.0]
        self.nyu40ids = np.array(
            [3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 14, 16, 24, 28, 33, 34, 36, 39])
        self.nyu40id2class = {
            nyu40id: i
            for i, nyu40id in enumerate(list(self.nyu40ids))
        }
        if pipeline is not None:
            self.pipeline = Compose(pipeline)
        self.with_label = with_label
        # Group flag used by `_rand_another`; all samples share one group here.
        self.flag = np.zeros(len(self.scannet_infos), dtype=np.uint8)

    def __getitem__(self, idx):
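        """Return the processed sample at ``idx``.

        In test mode the sample is prepared directly; in training mode,
        samples that end up without ground-truth boxes are skipped and a
        random replacement index is drawn instead.
        """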
        if self.test_mode:
            return self._prepare_test_data(idx)
        while True:
            data = self._prepare_train_data(idx)
            if data is None:
                idx = self._rand_another(idx)
                continue
            return data

    def _prepare_test_data(self, index):
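        """Build the input dict for ``index`` and run the test pipeline."""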
        input_dict = self._get_sensor_data(index)
        example = self.pipeline(input_dict)
        return example

    def _prepare_train_data(self, index):
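        """Build the input dict for ``index`` and run the training pipeline.

        Returns None when the sample has no ground-truth boxes before or
        after the pipeline, so that ``__getitem__`` can resample.
        """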
        input_dict = self._get_sensor_data(index)
        input_dict = self._train_pre_pipeline(input_dict)
        if input_dict is None:
            return None
        example = self.pipeline(input_dict)
        if len(example['gt_bboxes_3d']._data) == 0:
            return None
        return example

    def _train_pre_pipeline(self, input_dict):
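        """Drop samples that have no ground-truth boxes."""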
        if len(input_dict['gt_bboxes_3d']) == 0:
            return None
        return input_dict

    def _get_sensor_data(self, index):
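        """Collect the point cloud path for ``index``.

        Annotations are added to the returned dict when ``with_label`` is
        set.
        """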
        info = self.scannet_infos[index]
        sample_idx = info['point_cloud']['lidar_idx']
        pts_filename = self._get_pts_filename(sample_idx)

        input_dict = dict(pts_filename=pts_filename)

        if self.with_label:
            annos = self._get_ann_info(index, sample_idx)
            input_dict.update(annos)

        return input_dict

    def _get_pts_filename(self, sample_idx):
        pts_filename = osp.join(self.data_path, sample_idx + '_vert.npy')
        mmcv.check_file_exist(pts_filename)
        return pts_filename

    def _get_ann_info(self, index, sample_idx):
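        """Get annotations of the sample at ``index``.

        Returns a dict with ``gt_bboxes_3d`` of shape (k, 6) in upright
        depth coordinates, ``gt_labels``, ``gt_bboxes_3d_mask`` and the
        paths of the instance and semantic mask files.
        """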
        # Use `index` to fetch the annos so that the eval hook can also
        # reuse this API.
        info = self.scannet_infos[index]
        if info['annos']['gt_num'] != 0:
            gt_bboxes_3d = info['annos']['gt_boxes_upright_depth']  # k, 6
            gt_labels = info['annos']['class']
            gt_bboxes_3d_mask = np.ones_like(gt_labels, dtype=bool)
        else:
            gt_bboxes_3d = np.zeros((1, 6), dtype=np.float32)
            gt_labels = np.zeros(1, dtype=bool)
            gt_bboxes_3d_mask = np.zeros(1, dtype=bool)
        pts_instance_mask_path = osp.join(self.data_path,
                                          sample_idx + '_ins_label.npy')
        pts_semantic_mask_path = osp.join(self.data_path,
                                          sample_idx + '_sem_label.npy')

        anns_results = dict(
            gt_bboxes_3d=gt_bboxes_3d,
            gt_labels=gt_labels,
            gt_bboxes_3d_mask=gt_bboxes_3d_mask,
            pts_instance_mask_path=pts_instance_mask_path,
            pts_semantic_mask_path=pts_semantic_mask_path)
        return anns_results

    def _rand_another(self, idx):
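        """Randomly pick another index from the same ``self.flag`` group."""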
        pool = np.where(self.flag == self.flag[idx])[0]
        return np.random.choice(pool)

    def _generate_annotations(self, output):
        """Convert network outputs to the format required by the AP
        calculator.

        Returns:
            list[list[tuple]]: One entry per sample in the batch. Each entry
                is a list of predictions ``(pred_sem_cls, box_params,
                box_score)``, one tuple per predicted object; the list is
                empty when the sample has no predicted boxes.
        """
        result = []
        bs = len(output)
        for i in range(bs):
            pred_list_i = list()
            pred_boxes = output[i]
            box3d_depth = pred_boxes['box3d_lidar']
            if box3d_depth is not None:
                label_preds = pred_boxes['label_preds']
                # `box3d_lidar` is assumed to be a torch tensor like `scores`
                # and `label_preds`; convert everything to numpy first.
                box3d_depth = box3d_depth.detach().cpu().numpy()
                scores = pred_boxes['scores'].detach().cpu().numpy()
                label_preds = label_preds.detach().cpu().numpy()
                num_proposal = box3d_depth.shape[0]
                for j in range(num_proposal):
                    bbox_lidar = box3d_depth[j]  # [7] in lidar
                    bbox_lidar_bottom = bbox_lidar.copy()
                    pred_list_i.append(
                        (label_preds[j], bbox_lidar_bottom, scores[j]))
                result.append(pred_list_i)
            else:
                result.append(pred_list_i)

        return result

    def _format_results(self, outputs):
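        """Convert raw network outputs into per-sample prediction lists."""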
        results = []
        for output in outputs:
            result = self._generate_annotations(output)
            results.append(result)
        return results

    def evaluate(self, results, metric=None, logger=None, pklfile_prefix=None):
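        """Evaluate the detection results with the indoor AP metric.

        ``metric`` must contain the key ``AP_IOU_THRESHHOLDS``; ``logger``
        and ``pklfile_prefix`` are currently unused and kept only for
        interface compatibility. Returns a dict of AP results.
        """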
        results = self._format_results(results)
        from mmdet3d.core.evaluation import indoor_eval
        assert 'AP_IOU_THRESHHOLDS' in metric
        gt_annos = [
            copy.deepcopy(info['annos']) for info in self.scannet_infos
        ]
        ap_result_str, ap_dict = indoor_eval(gt_annos, results, metric,
                                             self.class2type)
        return ap_dict

    def __len__(self):
        return len(self.scannet_infos)
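

# A minimal usage sketch, assuming the ScanNet preparation step has already
# produced the info pickle and the ``scannet_train_instance_data`` folder;
# the paths below are placeholders and may differ in your setup.
if __name__ == '__main__':
    dataset = ScannetDataset(
        root_path='./data/scannet',
        ann_file='./data/scannet/scannet_infos_train.pkl',
        pipeline=None,
        training=True,
        with_label=True)
    print(f'loaded {len(dataset)} scans')
    # Without a pipeline, only the raw input dict can be inspected.
    print(sorted(dataset._get_sensor_data(0).keys()))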