import copy
import os
import os.path as osp

import mmcv
import numpy as np
import torch.utils.data as torch_data

from mmdet.datasets import DATASETS
from .pipelines import Compose


@DATASETS.register_module()
class ScannetDataset(torch_data.Dataset):
    """ScanNet indoor 3D detection dataset.

    Per-scene files are read from ``<root_path>/scannet_train_instance_data``
    (``{sample_idx}_vert.npy`` points plus ``_ins_label.npy`` /
    ``_sem_label.npy`` masks) and scene infos from ``ann_file``.

    Args:
        root_path (str): Root directory of the ScanNet data.
        ann_file (str): Info file loaded with ``mmcv.load``; must exist.
        pipeline (list[dict], optional): Transform configs passed to
            ``Compose``. If None, the unprocessed input dict is returned.
        training (bool): Training flag; only used to set ``self.mode``.
        class_names (Sequence[str], optional): Class names. Falls back to
            ``CLASSES`` when None or empty.
        test_mode (bool): If True, ``__getitem__`` returns test samples and
            never resamples.
        with_label (bool): Whether annotations are added to each sample.
    """

    # Class index -> class name, passed to ``indoor_eval`` in evaluate().
    # NOTE: 'showercurtrain' is the label string used at runtime; its
    # spelling is kept on purpose.
    class2type = {
        0: 'cabinet',
        1: 'bed',
        2: 'chair',
        3: 'sofa',
        4: 'table',
        5: 'door',
        6: 'window',
        7: 'bookshelf',
        8: 'picture',
        9: 'counter',
        10: 'desk',
        11: 'curtain',
        12: 'refrigerator',
        13: 'showercurtrain',
        14: 'toilet',
        15: 'sink',
        16: 'bathtub',
        17: 'garbagebin'
    }
    CLASSES = ('cabinet', 'bed', 'chair', 'sofa', 'table', 'door', 'window',
               'bookshelf', 'picture', 'counter', 'desk', 'curtain',
               'refrigerator', 'showercurtrain', 'toilet', 'sink', 'bathtub',
               'garbagebin')

    # AP IoU thresholds used by evaluate() when no ``metric`` is given.
    DEFAULT_IOU_THRESHOLDS = (0.25, 0.5)

    def __init__(self,
                 root_path,
                 ann_file,
                 pipeline=None,
                 training=False,
                 class_names=None,
                 test_mode=False,
                 with_label=True):
        super().__init__()
        self.root_path = root_path
        self.class_names = class_names if class_names else self.CLASSES
        self.data_path = osp.join(root_path, 'scannet_train_instance_data')

        self.test_mode = test_mode
        self.training = training
        self.mode = 'TRAIN' if self.training else 'TEST'

        mmcv.check_file_exist(ann_file)
        self.scannet_infos = mmcv.load(ann_file)

        # dataset config
        self.num_class = len(self.class_names)
        # NOTE(review): not referenced anywhere in this class; kept so code
        # reading the attribute keeps working.
        self.pcd_limit_range = [0, -40, -3.0, 70.4, 40, 3.0]
        self.nyu40ids = np.array(
            [3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 14, 16, 24, 28, 33, 34, 36, 39])
        # NYU40 id -> position of that id in ``nyu40ids`` (0..17).
        self.nyu40id2class = {
            nyu40id: i
            for i, nyu40id in enumerate(list(self.nyu40ids))
        }
        # Always define the attribute: previously it was left unset when
        # ``pipeline`` was None, so every __getitem__ raised AttributeError.
        self.pipeline = Compose(pipeline) if pipeline is not None else None
        self.with_label = with_label

        # Group flag read by ``_rand_another`` (and by mmdet-style group
        # samplers). It was never set before, so resampling crashed. All
        # samples share group 0.
        self.flag = np.zeros(len(self), dtype=np.uint8)

    def __getitem__(self, idx):
        """Return the sample at ``idx``.

        In training mode a rejected sample (``_prepare_train_data`` returns
        None) is replaced by a randomly drawn index until one is accepted.
        """
        if self.test_mode:
            return self._prepare_test_data(idx)
        while True:
            data = self._prepare_train_data(idx)
            if data is not None:
                return data
            idx = self._rand_another(idx)

    def _prepare_test_data(self, index):
        """Load sample ``index`` and run it through the pipeline, if any."""
        input_dict = self._get_sensor_data(index)
        if self.pipeline is None:
            return input_dict
        return self.pipeline(input_dict)

    def _prepare_train_data(self, index):
        """Load and process sample ``index`` for training.

        Returns:
            dict | None: The processed sample, or None when the sample is
            rejected (no boxes before or after the pipeline, or the
            pipeline itself returned None).
        """
        input_dict = self._train_pre_pipeline(self._get_sensor_data(index))
        if input_dict is None:
            return None
        if self.pipeline is None:
            return input_dict
        example = self.pipeline(input_dict)
        if example is None:
            return None
        # NOTE(review): ``gt_bboxes_3d`` is presumably wrapped by the
        # pipeline in a container exposing ``_data`` (e.g. mmcv
        # DataContainer) - confirm against the pipeline config.
        if len(example['gt_bboxes_3d']._data) == 0:
            return None
        return example

    def _train_pre_pipeline(self, input_dict):
        """Return ``input_dict``, or None if it carries no gt boxes."""
        if len(input_dict['gt_bboxes_3d']) == 0:
            return None
        return input_dict

    def _get_sensor_data(self, index):
        """Build the pipeline input dict for sample ``index``.

        Always contains ``pts_filename``; when ``with_label`` is set, the
        entries returned by ``_get_ann_info`` are merged in.
        """
        info = self.scannet_infos[index]
        sample_idx = info['point_cloud']['lidar_idx']
        input_dict = dict(pts_filename=self._get_pts_filename(sample_idx))

        if self.with_label:
            input_dict.update(self._get_ann_info(index, sample_idx))

        return input_dict

    def _get_pts_filename(self, sample_idx):
        """Return the path of ``{sample_idx}_vert.npy``; it must exist."""
        pts_filename = os.path.join(self.data_path, f'{sample_idx}_vert.npy')
        mmcv.check_file_exist(pts_filename)
        return pts_filename

    def _get_ann_info(self, index, sample_idx):
        """Collect annotations and mask-file paths for one sample.

        Args:
            index (int): Position in ``scannet_infos``. Indexing by position
                lets evaluation hooks reuse this method.
            sample_idx: Scene id used to build the mask file names.

        Returns:
            dict: ``gt_bboxes_3d``, ``gt_labels``, ``gt_bboxes_3d_mask``,
            ``pts_instance_mask_path`` and ``pts_semantic_mask_path``.
        """
        annos = self.scannet_infos[index]['annos']
        if annos['gt_num'] != 0:
            gt_bboxes_3d = annos['gt_boxes_upright_depth']  # (k, 6)
            gt_labels = annos['class']
            gt_bboxes_3d_mask = np.ones_like(gt_labels, dtype=bool)
        else:
            # A single all-zero placeholder box whose mask entry is False.
            # ``np.bool`` (used previously) was removed in NumPy 1.24, and
            # labels are class indices, so use integer zeros for them.
            gt_bboxes_3d = np.zeros((1, 6), dtype=np.float32)
            gt_labels = np.zeros(1, dtype=np.int64)
            gt_bboxes_3d_mask = np.zeros(1, dtype=bool)

        pts_instance_mask_path = osp.join(self.data_path,
                                          f'{sample_idx}_ins_label.npy')
        pts_semantic_mask_path = osp.join(self.data_path,
                                          f'{sample_idx}_sem_label.npy')

        return dict(
            gt_bboxes_3d=gt_bboxes_3d,
            gt_labels=gt_labels,
            gt_bboxes_3d_mask=gt_bboxes_3d_mask,
            pts_instance_mask_path=pts_instance_mask_path,
            pts_semantic_mask_path=pts_semantic_mask_path)

    def _rand_another(self, idx):
        """Return a random index sharing ``idx``'s group flag."""
        pool = np.where(self.flag == self.flag[idx])[0]
        return np.random.choice(pool)

    def _generate_annotations(self, output):
        """Generate Annotations.

        Transform one model output into the per-sample form used for
        evaluation.

        Args:
            output (list[dict]): One dict per sample with keys
                ``box3d_lidar`` (None when there are no predictions, else
                one row per box), ``label_preds`` and ``scores`` (tensors).

        Returns:
            list[list[tuple]]: For each sample, a list of
            ``(label, box, score)`` tuples; empty when ``box3d_lidar`` is
            None.
        """
        result = []
        for pred_boxes in output:
            pred_list = []
            box3d_depth = pred_boxes['box3d_lidar']
            if box3d_depth is not None:
                # Bring boxes to numpy like labels/scores; a tensor has no
                # ``.copy()`` method.
                if hasattr(box3d_depth, 'detach'):
                    box3d_depth = box3d_depth.detach().cpu().numpy()
                label_preds = pred_boxes['label_preds'].detach().cpu().numpy()
                scores = pred_boxes['scores'].detach().cpu().numpy()
                for label, box, score in zip(label_preds, box3d_depth,
                                             scores):
                    # copy() so the tuple does not hold a view into the
                    # whole prediction array.
                    pred_list.append((label, box.copy(), score))
            result.append(pred_list)
        return result

    def format_results(self, outputs):
        """Apply ``_generate_annotations`` to every model output."""
        return [self._generate_annotations(output) for output in outputs]

    def evaluate(self, results, metric=None):
        """Evaluate.

        Evaluation in indoor protocol.

        Args:
            results (List): List of result.
            metric (List[float], optional): AP IoU thresholds. Defaults to
                ``DEFAULT_IOU_THRESHOLDS`` (0.25 and 0.5) when None.

        Returns:
            dict: AP values returned by ``indoor_eval``.

        Raises:
            ValueError: If ``metric`` is an empty sequence.
        """
        if metric is None:
            metric = list(self.DEFAULT_IOU_THRESHOLDS)
        if len(metric) == 0:
            raise ValueError('metric must contain at least one IoU threshold')
        # NOTE(review): results stay nested per model output (one list per
        # batch); confirm indoor_eval expects this layout for batch > 1.
        results = self.format_results(results)
        from mmdet3d.core.evaluation import indoor_eval
        gt_annos = [
            copy.deepcopy(info['annos']) for info in self.scannet_infos
        ]
        _, ap_dict = indoor_eval(gt_annos, results, metric, self.class2type)
        return ap_dict

    def __len__(self):
        return len(self.scannet_infos)