"""Base dataset class for indoor 3D detection (indoor_base_dataset.py)."""
import copy

import mmcv
import numpy as np
import torch.utils.data as torch_data

from mmdet.datasets import DATASETS
from .pipelines import Compose


@DATASETS.register_module()
class IndoorBaseDataset(torch_data.Dataset):
    """Base dataset for indoor point-cloud 3D detection.

    Loads a list of per-sample info dicts from ``ann_file`` and turns each
    one into a pipeline input dict. Subclasses must provide ``CLASSES``
    (unless ``classes`` is passed) and implement
    ``_get_pts_filename(sample_idx)`` and
    ``_get_ann_info(index, sample_idx)``. This class calls both helpers but
    does not define them.

    Args:
        root_path (str): Dataset root directory, stored as
            ``self.root_path`` for subclasses.
        ann_file (str): Path of the info file read with ``mmcv.load``.
        pipeline (optional): Transform configs passed to ``Compose``. If
            None, no ``self.pipeline`` attribute is created. Default: None.
        training (bool): Sets ``self.mode`` to ``'TRAIN'`` or ``'TEST'``.
            Default: False.
        classes (optional): Class names overriding the class-level
            ``CLASSES`` when truthy. Default: None.
        test_mode (bool): If True, ``__getitem__`` returns test samples
            without resampling or empty-GT filtering. Default: False.
        with_label (bool): Whether ``get_data_info`` loads annotations via
            ``_get_ann_info``. Default: True.
    """

    def __init__(self,
                 root_path,
                 ann_file,
                 pipeline=None,
                 training=False,
                 classes=None,
                 test_mode=False,
                 with_label=True):
        super().__init__()
        self.root_path = root_path
        self.CLASSES = classes if classes else self.CLASSES
        self.test_mode = test_mode
        self.training = training
        self.mode = 'TRAIN' if self.training else 'TEST'
        # Label index -> class name; passed to indoor_eval in evaluate().
        self.label2cat = {i: cat_id for i, cat_id in enumerate(self.CLASSES)}
        mmcv.check_file_exist(ann_file)
        self.data_infos = mmcv.load(ann_file)

        # dataset config
        self.num_class = len(self.CLASSES)
        if pipeline is not None:
            self.pipeline = Compose(pipeline)
        self.with_label = with_label
        # Group flag read by _rand_another. Every sample is placed in
        # group 0. Without this attribute, resampling a filtered-out
        # training sample raises AttributeError.
        self.flag = np.zeros(len(self.data_infos), dtype=np.uint8)

    def __getitem__(self, idx):
        """Return the processed sample at ``idx``.

        In test mode this is ``prepare_test_data(idx)``. In training,
        samples for which ``prepare_train_data`` returns None are replaced
        by a random sample from the same flag group. NOTE: this loops
        forever if no sample in the group ever yields data.
        """
        if self.test_mode:
            return self.prepare_test_data(idx)
        while True:
            data = self.prepare_train_data(idx)
            if data is None:
                idx = self._rand_another(idx)
                continue
            return data

    def prepare_test_data(self, index):
        """Build the input dict for ``index`` and run it through the pipeline."""
        input_dict = self.get_data_info(index)
        example = self.pipeline(input_dict)
        return example

    def prepare_train_data(self, index):
        """Build and transform a training sample.

        Returns:
            The pipeline output, or None when ``get_data_info`` filtered
            the sample out or the pipeline left no GT boxes.
        """
        input_dict = self.get_data_info(index)
        if input_dict is None:
            return None
        example = self.pipeline(input_dict)
        # Assumes the pipeline wraps gt_bboxes_3d in a container exposing
        # ``_data`` (e.g. an mmcv DataContainer). Verify against the
        # pipeline config.
        if len(example['gt_bboxes_3d']._data) == 0:
            return None
        return example

    def get_data_info(self, index):
        """Return the raw pipeline input dict for sample ``index``.

        The dict always holds ``pts_filename``. When ``with_label`` is set,
        it also holds the keys returned by ``_get_ann_info``.

        Returns:
            dict | None: None only outside test mode, for a labelled
            sample whose ``gt_bboxes_3d`` is empty.
        """
        info = self.data_infos[index]
        sample_idx = info['point_cloud']['lidar_idx']
        pts_filename = self._get_pts_filename(sample_idx)

        input_dict = dict(pts_filename=pts_filename)

        if self.with_label:
            annos = self._get_ann_info(index, sample_idx)
            input_dict.update(annos)
            # Filter empty-GT samples for training only. prepare_test_data
            # feeds the result straight into the pipeline, so test samples
            # must never be None.
            if not self.test_mode and len(input_dict['gt_bboxes_3d']) == 0:
                return None
        return input_dict

    def _rand_another(self, idx):
        """Return a random sample index sharing ``self.flag[idx]``."""
        pool = np.where(self.flag == self.flag[idx])[0]
        return np.random.choice(pool)

    @staticmethod
    def _to_numpy(value):
        """Detach and convert tensor-like values; return others unchanged."""
        if hasattr(value, 'detach'):
            value = value.detach().cpu().numpy()
        return value

    def _generate_annotations(self, output):
        """Convert one batch of model outputs into evaluation format.

        Args:
            output (list[dict]): Per-sample predictions with keys
                ``'box3d_lidar'``, ``'label_preds'`` and ``'scores'``.
                ``'box3d_lidar'`` may be None when nothing was predicted.

        Returns:
            list[list[tuple]]: For each sample, a list of
            ``(label, box, score)`` tuples, one per predicted box. The list
            is empty when ``'box3d_lidar'`` is None.
        """
        result = []
        for pred_boxes in output:
            pred_list = []
            box3d_depth = pred_boxes['box3d_lidar']
            if box3d_depth is not None:
                # Boxes were previously indexed without conversion. A torch
                # tensor has no ``.copy()``, so normalize all fields alike.
                box3d_depth = self._to_numpy(box3d_depth)
                label_preds = self._to_numpy(pred_boxes['label_preds'])
                scores = self._to_numpy(pred_boxes['scores'])
                for j in range(box3d_depth.shape[0]):
                    # Copy so the stored box does not alias the model output.
                    pred_list.append(
                        (label_preds[j], box3d_depth[j].copy(), scores[j]))
            result.append(pred_list)
        return result

    def format_results(self, outputs):
        """Apply ``_generate_annotations`` to each batch in ``outputs``."""
        return [self._generate_annotations(output) for output in outputs]

    def evaluate(self, results, metric=None):
        """Evaluate predictions with the indoor protocol.

        Args:
            results (list): Raw model outputs. They are formatted via
                ``format_results`` before evaluation.
            metric (Sequence[float]): AP IoU thresholds. Must be non-empty.

        Returns:
            Whatever ``indoor_eval`` returns for these inputs.

        Raises:
            ValueError: If ``metric`` is None or empty.
        """
        # Validate before any work. ``assert`` is stripped under -O, and
        # the old ``len(None)`` crashed on the default argument.
        if metric is None or len(metric) == 0:
            raise ValueError(
                'metric must be a non-empty sequence of IoU thresholds, '
                'got {!r}'.format(metric))
        from mmdet3d.core.evaluation import indoor_eval
        results = self.format_results(results)
        gt_annos = [copy.deepcopy(info['annos']) for info in self.data_infos]
        ret_dict = indoor_eval(gt_annos, results, metric, self.label2cat)
        return ret_dict

    def __len__(self):
        """Return the number of samples in the info file."""
        return len(self.data_infos)