# indoor_base_dataset.py
import copy

import mmcv
import numpy as np
import torch.utils.data as torch_data

from mmdet.datasets import DATASETS
from .pipelines import Compose


@DATASETS.register_module()
class IndoorBaseDataset(torch_data.Dataset):
    """Base dataset for indoor 3D detection.

    This class does not define ``CLASSES``, ``_get_pts_filename`` or
    ``_get_ann_info``; concrete subclasses are expected to provide them.

    Args:
        root_path (str): Dataset root, stored as ``self.root_path``
            (presumably used by subclasses to build file paths).
        ann_file (str): Path of the annotation file; it must exist and is
            loaded with ``mmcv.load`` into ``self.data_infos``.
        pipeline (list[dict] | None): Transform configs passed to
            ``Compose``. When None, ``self.pipeline`` is None.
        classes (Sequence[str] | str | None): Class names override; see
            :meth:`get_classes`.
        test_mode (bool): If True, ``__getitem__`` returns test data and
            samples without ground-truth boxes are not filtered out.
        with_label (bool): Whether :meth:`get_data_info` merges the
            annotation returned by ``_get_ann_info`` into the sample dict.
    """

    def __init__(self,
                 root_path,
                 ann_file,
                 pipeline=None,
                 classes=None,
                 test_mode=False,
                 with_label=True):
        super().__init__()
        self.root_path = root_path
        self.CLASSES = self.get_classes(classes)
        self.test_mode = test_mode
        # Maps contiguous label index -> class name.
        self.label2cat = dict(enumerate(self.CLASSES))
        mmcv.check_file_exist(ann_file)
        self.data_infos = mmcv.load(ann_file)

        # dataset config
        self.num_class = len(self.CLASSES)
        # Always define the attribute so the object has a consistent shape.
        self.pipeline = Compose(pipeline) if pipeline is not None else None
        self.with_label = with_label

        # `_rand_another` (used when resampling during training) reads
        # `self.flag`; previously it was never set and resampling crashed.
        if not self.test_mode:
            self._set_group_flag()

    def __len__(self):
        """Return the number of entries in ``self.data_infos``."""
        return len(self.data_infos)

    def _set_group_flag(self):
        """Assign every sample to group 0.

        ``_rand_another`` draws replacement indices from the samples that
        share the flag of the rejected index.
        """
        self.flag = np.zeros(len(self), dtype=np.uint8)

    def get_data_info(self, index):
        """Build the raw input dict for sample ``index``.

        Args:
            index (int): Index into ``self.data_infos``.

        Returns:
            dict | None: A dict with ``pts_filename`` plus, when
            ``with_label`` is set, the annotation fields. Returns None in
            training mode when the sample has no ground-truth boxes.
        """
        info = self.data_infos[index]
        sample_idx = info['point_cloud']['lidar_idx']
        pts_filename = self._get_pts_filename(sample_idx)

        input_dict = dict(pts_filename=pts_filename)

        if self.with_label:
            annos = self._get_ann_info(index, sample_idx)
            input_dict.update(annos)
            # Only filter empty samples for training: `gt_bboxes_3d` exists
            # only when labels are loaded, and in test mode every sample
            # must be kept so results stay aligned with `data_infos`.
            if (not self.test_mode
                    and len(input_dict['gt_bboxes_3d']) == 0):
                return None
        return input_dict

    def _rand_another(self, idx):
        """Randomly pick an index from the same group as ``idx``."""
        pool = np.where(self.flag == self.flag[idx])[0]
        return np.random.choice(pool)

    def __getitem__(self, idx):
        """Return processed sample ``idx``.

        In training mode, samples rejected by :meth:`prepare_train_data`
        are replaced by randomly drawn ones until a valid sample is found.
        """
        if self.test_mode:
            return self.prepare_test_data(idx)
        while True:
            data = self.prepare_train_data(idx)
            if data is None:
                idx = self._rand_another(idx)
                continue
            return data

    def prepare_train_data(self, index):
        """Run the pipeline on a training sample.

        Returns:
            dict | None: The pipeline output, or None when the sample has
            no ground-truth boxes before or after the pipeline.
        """
        input_dict = self.get_data_info(index)
        if input_dict is None:
            return None
        example = self.pipeline(input_dict)
        # NOTE(review): assumes the pipeline wraps `gt_bboxes_3d` in a
        # container exposing `._data` (e.g. an mmcv DataContainer).
        if len(example['gt_bboxes_3d']._data) == 0:
            return None
        return example

    def prepare_test_data(self, index):
        """Run the pipeline on a test sample and return its output."""
        input_dict = self.get_data_info(index)
        example = self.pipeline(input_dict)
        return example

    @classmethod
    def get_classes(cls, classes=None):
        """Get class names of the current dataset.

        Args:
            classes (Sequence[str] | str | None): If None, use the
                ``CLASSES`` defined on the (sub)class. If a string, treat it
                as a file path with one class name per line. If a tuple or
                list, use it to override the dataset's ``CLASSES``.

        Returns:
            Sequence[str]: The class names.

        Raises:
            ValueError: If ``classes`` has an unsupported type.
        """
        if classes is None:
            return cls.CLASSES

        if isinstance(classes, str):
            # take it as a file path
            class_names = mmcv.list_from_file(classes)
        elif isinstance(classes, (tuple, list)):
            class_names = classes
        else:
            raise ValueError(f'Unsupported type {type(classes)} of classes.')

        return class_names

    def _generate_annotations(self, output):
        """Convert one batch of model outputs into evaluation tuples.

        Args:
            output (list[dict]): Per-sample predictions with keys
                ``box3d_lidar`` (array or None), ``label_preds`` and
                ``scores`` (tensors).

        Returns:
            list[list[tuple]]: For each sample, a list of
            ``(label, box, score)`` tuples; empty when ``box3d_lidar`` is
            None.
        """
        result = []
        for pred_boxes in output:
            pred_list = []
            box3d_depth = pred_boxes['box3d_lidar']
            if box3d_depth is not None:
                scores = pred_boxes['scores'].detach().cpu().numpy()
                label_preds = pred_boxes['label_preds'].detach().cpu().numpy()
                # Iterating the array walks its first axis (one box per row).
                for j, bbox_lidar in enumerate(box3d_depth):
                    pred_list.append(
                        (label_preds[j], bbox_lidar.copy(), scores[j]))
            result.append(pred_list)
        return result

    def format_results(self, outputs):
        """Apply :meth:`_generate_annotations` to every batch in ``outputs``."""
        return [self._generate_annotations(output) for output in outputs]

    def evaluate(self, results, metric=None):
        """Evaluate predictions with the indoor protocol.

        Args:
            results (list): Raw model outputs, one entry per batch.
            metric (list[float]): AP IoU thresholds; must be non-empty.

        Returns:
            dict: The result of ``indoor_eval``.

        Raises:
            ValueError: If ``metric`` is None or empty.
        """
        # Explicit check: `assert` is stripped under -O, and the old
        # `len(None)` raised an unhelpful TypeError for the default.
        if metric is None or len(metric) == 0:
            raise ValueError(
                'metric must be a non-empty list of AP IoU thresholds.')
        results = self.format_results(results)
        from mmdet3d.core.evaluation import indoor_eval
        gt_annos = [copy.deepcopy(info['annos']) for info in self.data_infos]
        return indoor_eval(gt_annos, results, metric, self.label2cat)