indoor_base_dataset.py 4.86 KB
Newer Older
1
2
3
4
5
6
7
8
9
import mmcv
import numpy as np
import torch.utils.data as torch_data

from mmdet.datasets import DATASETS
from .pipelines import Compose


@DATASETS.register_module()
class IndoorBaseDataset(torch_data.Dataset):
    """Base class for indoor 3D detection datasets.

    Annotation infos are loaded from ``ann_file`` with ``mmcv.load``. Each
    info must provide ``info['point_cloud']['lidar_idx']`` (used to locate
    the point cloud) and ``info['annos']`` (used as ground truth in
    :meth:`evaluate`).

    This class calls, but does not define, ``CLASSES``,
    ``_get_pts_filename(sample_idx)`` and ``_get_ann_info(index,
    sample_idx)``. Subclasses are expected to provide them, and
    ``_get_ann_info`` must return a dict containing ``'gt_bboxes_3d'``.

    Args:
        root_path (str): Root directory of the dataset.
        ann_file (str): Path of the annotation file.
        pipeline (list[dict] | None): Processing pipeline config, built
            with ``Compose``. Default: None.
        classes (Sequence[str] | str | None): Class-name override, see
            :meth:`get_classes`. Default: None.
        test_mode (bool): Whether the dataset is used for testing.
            Default: False.
        with_label (bool): Whether to load annotations through
            ``_get_ann_info``. Default: True.
    """

    def __init__(self,
                 root_path,
                 ann_file,
                 pipeline=None,
                 classes=None,
                 test_mode=False,
                 with_label=True):
        super().__init__()
        self.root_path = root_path
        self.CLASSES = self.get_classes(classes)
        self.test_mode = test_mode
        # Maps a contiguous label index to its class name.
        self.label2cat = dict(enumerate(self.CLASSES))
        mmcv.check_file_exist(ann_file)
        self.data_infos = mmcv.load(ann_file)

        if pipeline is not None:
            self.pipeline = Compose(pipeline)
        self.with_label = with_label
        # ``_rand_another`` reads ``self.flag``. It was never set before,
        # so resampling an empty training sample raised AttributeError.
        self._set_group_flag()

    def _set_group_flag(self):
        """Set ``self.flag``, the sampler group of every sample.

        All samples are placed in the same group (0).
        """
        self.flag = np.zeros(len(self), dtype=np.uint8)

    def __len__(self):
        """Return the number of annotation infos."""
        return len(self.data_infos)

    def get_data_info(self, index):
        """Build the pipeline input dict for the sample at ``index``.

        Args:
            index (int): Index into ``self.data_infos``.

        Returns:
            dict | None: A dict with ``pts_filename``. When ``with_label``
                is True, it also contains the entries returned by
                ``_get_ann_info``. Returns None for a training sample
                whose ``gt_bboxes_3d`` is empty.
        """
        info = self.data_infos[index]
        sample_idx = info['point_cloud']['lidar_idx']
        input_dict = dict(pts_filename=self._get_pts_filename(sample_idx))

        if self.with_label:
            input_dict.update(self._get_ann_info(index, sample_idx))
            # Skip empty-GT samples only in training. Before this fix the
            # check ran unconditionally: it raised KeyError when
            # with_label=False, and in test mode it passed None on to the
            # pipeline.
            if not self.test_mode and len(input_dict['gt_bboxes_3d']) == 0:
                return None
        return input_dict

    def _rand_another(self, idx):
        """Return a random index from the same group as ``idx``.

        The candidate pool includes ``idx`` itself.
        """
        pool = np.where(self.flag == self.flag[idx])[0]
        return np.random.choice(pool)

    def __getitem__(self, idx):
        """Return the processed sample at ``idx``.

        In training mode, an unusable sample (``None`` from
        :meth:`prepare_train_data`) is replaced by a randomly drawn one.
        Note that this loop never terminates if no sample is usable.
        """
        if self.test_mode:
            return self.prepare_test_data(idx)
        while True:
            data = self.prepare_train_data(idx)
            if data is None:
                idx = self._rand_another(idx)
                continue
            return data

    def prepare_train_data(self, index):
        """Run the training pipeline on sample ``index``.

        Returns:
            dict | None: The pipeline output. Returns None when the sample
                has no GT boxes before or after the pipeline, or when the
                pipeline itself returned None.
        """
        input_dict = self.get_data_info(index)
        if input_dict is None:
            return None
        example = self.pipeline(input_dict)
        # The pipeline may drop the sample entirely, or leave no GT boxes.
        # ``gt_bboxes_3d`` is read through ``._data``; presumably it is a
        # DataContainer at this point (verify against the pipeline config).
        if example is None or len(example['gt_bboxes_3d']._data) == 0:
            return None
        return example

    def prepare_test_data(self, index):
        """Run the test pipeline on sample ``index`` and return its output."""
        input_dict = self.get_data_info(index)
        example = self.pipeline(input_dict)
        return example

    @classmethod
    def get_classes(cls, classes=None):
        """Get class names of current dataset.

        Args:
            classes (Sequence[str] | str | None): If classes is None, use
                default CLASSES defined by builtin dataset. If classes is a
                string, take it as a file name. The file contains the name of
                classes where each line contains one class name. If classes is
                a tuple or list, override the CLASSES defined by the dataset.

        Returns:
            Sequence[str]: The class names.

        Raises:
            ValueError: If ``classes`` has an unsupported type.
        """
        if classes is None:
            return cls.CLASSES

        if isinstance(classes, str):
            # take it as a file path
            class_names = mmcv.list_from_file(classes)
        elif isinstance(classes, (tuple, list)):
            class_names = classes
        else:
            raise ValueError(f'Unsupported type {type(classes)} of classes.')

        return class_names

    def _generate_annotations(self, output):
        """Generate annotations.

        Transform results of the model to the form of the evaluation.

        Args:
            output (list[dict]): Per-sample predictions. Each dict has keys
                ``'box3d_lidar'`` (None when there are no boxes),
                ``'label_preds'`` (a tensor, converted to numpy here) and
                ``'scores'``.

        Returns:
            list[list[tuple]]: One list per sample, holding
                ``(label, box, score)`` tuples, one per predicted box.
                The list is empty when ``box3d_lidar`` is None.
        """
        result = []
        for pred_boxes in output:
            pred_list = []
            box3d_depth = pred_boxes['box3d_lidar']
            if box3d_depth is not None:
                label_preds = pred_boxes['label_preds']
                scores = pred_boxes['scores']
                label_preds = label_preds.detach().cpu().numpy()
                for j in range(box3d_depth.shape[0]):
                    # Copy so the evaluation never aliases the model output.
                    bbox = box3d_depth[j].copy()
                    pred_list.append((label_preds[j], bbox, scores[j]))
            result.append(pred_list)
        return result

    def format_results(self, outputs):
        """Convert each batch of model outputs with
        :meth:`_generate_annotations`.

        Args:
            outputs (list): Model outputs, one entry per batch.

        Returns:
            list: One converted entry per element of ``outputs``.
        """
        return [self._generate_annotations(output) for output in outputs]

    def evaluate(self, results, metric=None):
        """Evaluate.

        Evaluation in indoor protocol.

        Args:
            results (list): List of result.
            metric (list[float]): AP IoU thresholds. Must be non-empty.

        Returns:
            The value returned by ``indoor_eval``.

        Raises:
            ValueError: If ``metric`` is None or empty.
        """
        # Validate explicitly instead of with ``assert``: an assert is
        # stripped under ``-O``, and ``len(None)`` raised TypeError for the
        # default ``metric=None``.
        if metric is None or len(metric) == 0:
            raise ValueError(
                'metric must be a non-empty sequence of IoU thresholds.')
        from mmdet3d.core.evaluation import indoor_eval
        results = self.format_results(results)
        gt_annos = [info['annos'] for info in self.data_infos]
        ret_dict = indoor_eval(gt_annos, results, metric, self.label2cat)
        return ret_dict