import mmcv
import numpy as np
from torch.utils.data import Dataset

from mmdet.datasets import DATASETS
from .pipelines import Compose


@DATASETS.register_module()
class Custom3DDataset(Dataset):
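    """Base class for customized 3D datasets.

    This class loads annotation infos from ``ann_file``, builds the data
    pipeline and implements the common ``__getitem__`` logic for train and
    test modes. Subclasses are expected to provide ``get_ann_info()`` and
    ``_get_pts_filename()``, which are called by ``get_data_info()`` but are
    not defined here.

    Args:
        data_root (str): Root path of the dataset.
        ann_file (str): Path of the annotation file.
        pipeline (list[dict], optional): Pipeline used for data processing.
        classes (tuple[str] | str, optional): Classes used in the dataset;
            see ``get_classes`` for the accepted formats.
        modality (dict, optional): Modality indicating the sensor data used
            as input.
        test_mode (bool): Whether the dataset is used for testing.
            Defaults to False.
    """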

    def __init__(self,
                 data_root,
                 ann_file,
                 pipeline=None,
                 classes=None,
                 modality=None,
                 test_mode=False):
        super().__init__()
        self.data_root = data_root
        self.ann_file = ann_file
        self.test_mode = test_mode
        self.modality = modality

        self.CLASSES = self.get_classes(classes)
        self.data_infos = self.load_annotations(self.ann_file)

        if pipeline is not None:
            self.pipeline = Compose(pipeline)

        # set group flag for the sampler
        if not self.test_mode:
            self._set_group_flag()

    def load_annotations(self, ann_file):
        return mmcv.load(ann_file)

    def get_data_info(self, index):
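        """Get data info according to the given index.

        Args:
            index (int): Index of the sample data to get.

        Returns:
            dict | None: Data information that will be passed to the data
                pipeline, or None if the sample has no ground truth boxes
                in training mode.
        """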
        info = self.data_infos[index]
        sample_idx = info['point_cloud']['lidar_idx']
        pts_filename = self._get_pts_filename(sample_idx)

        input_dict = dict(pts_filename=pts_filename)

        if not self.test_mode:
            annos = self.get_ann_info(index)
            input_dict['ann_info'] = annos
            if len(annos['gt_bboxes_3d']) == 0:
                return None
        return input_dict

    def pre_pipeline(self, results):
        results['bbox3d_fields'] = []
        results['pts_mask_fields'] = []
        results['pts_seg_fields'] = []

    def prepare_train_data(self, index):
        input_dict = self.get_data_info(index)
        if input_dict is None:
            return None
        self.pre_pipeline(input_dict)
        example = self.pipeline(input_dict)
        if example is None or len(example['gt_bboxes_3d']._data) == 0:
            return None
        return example

    def prepare_test_data(self, index):
        input_dict = self.get_data_info(index)
        self.pre_pipeline(input_dict)
        example = self.pipeline(input_dict)
        return example

    @classmethod
    def get_classes(cls, classes=None):
        """Get class names of current dataset.

        Args:
            classes (Sequence[str] | str | None): If classes is None, use
                default CLASSES defined by builtin dataset. If classes is a
                string, take it as a file name. The file contains the name of
                classes where each line contains one class name. If classes is
                a tuple or list, override the CLASSES defined by the dataset.

        Returns:
            list[str]: A list of class names.
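
        Examples:
            Illustrative calls; the file path and class names below are
            placeholders.

            >>> # load class names from a text file, one name per line
            >>> Custom3DDataset.get_classes('path/to/classes.txt')
            >>> # override CLASSES with an explicit tuple of names
            >>> Custom3DDataset.get_classes(('cabinet', 'bed', 'chair'))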
        """
        if classes is None:
            return cls.CLASSES

        if isinstance(classes, str):
            # take it as a file path
            class_names = mmcv.list_from_file(classes)
        elif isinstance(classes, (tuple, list)):
            class_names = classes
        else:
            raise ValueError(f'Unsupported type {type(classes)} of classes.')

        return class_names

    def _generate_annotations(self, output):
        """Generate annotations.

        Transform the model outputs into the format expected by the
        evaluation.

        Args:
            output (list): The output of the model.
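
        Returns:
            list[list[tuple]]: Per-sample lists of
                ``(label, bbox_lidar, score)`` tuples; a sample with no
                predicted boxes yields an empty list.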
        """
        result = []
        bs = len(output)
        for i in range(bs):
            pred_list_i = list()
            pred_boxes = output[i]
            box3d_depth = pred_boxes['box3d_lidar']
            if box3d_depth is not None:
                label_preds = pred_boxes['label_preds']
                scores = pred_boxes['scores']
                label_preds = label_preds.detach().cpu().numpy()
                for j in range(box3d_depth.shape[0]):
                    bbox_lidar = box3d_depth[j]  # [7] in lidar
                    bbox_lidar_bottom = bbox_lidar.copy()
                    pred_list_i.append(
                        (label_preds[j], bbox_lidar_bottom, scores[j]))
                result.append(pred_list_i)
            else:
                result.append(pred_list_i)

        return result

    def format_results(self, outputs):
        results = []
        for output in outputs:
            result = self._generate_annotations(output)
            results.append(result)
        return results

    def evaluate(self, results, metric=None):
        """Evaluate.

        Evaluation in indoor protocol.

        Args:
            results (list): List of the prediction results.
            metric (list[float]): AP IoU thresholds.
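
        Returns:
            dict: Evaluation results computed by ``indoor_eval``.

        Example:
            An illustrative call; the IoU thresholds are placeholders.

            >>> dataset.evaluate(results, metric=[0.25, 0.5])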
        """
        results = self.format_results(results)
        from mmdet3d.core.evaluation import indoor_eval
        assert len(metric) > 0
        gt_annos = [info['annos'] for info in self.data_infos]
        label2cat = {i: cat_id for i, cat_id in enumerate(self.CLASSES)}
        ret_dict = indoor_eval(gt_annos, results, metric, label2cat)
        return ret_dict

    def __len__(self):
        return len(self.data_infos)

    def _rand_another(self, idx):
        pool = np.where(self.flag == self.flag[idx])[0]
        return np.random.choice(pool)

    def __getitem__(self, idx):
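        """Get the processed data sample of the given index.

        In test mode the sample is returned directly. In training mode,
        invalid samples (e.g. those without ground truth boxes) are skipped
        by randomly re-sampling another index from the same group.
        """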
        if self.test_mode:
            return self.prepare_test_data(idx)
        while True:
            data = self.prepare_train_data(idx)
            if data is None:
                idx = self._rand_another(idx)
                continue
            return data

    def _set_group_flag(self):
        """Set flag according to image aspect ratio.

        Images with aspect ratio greater than 1 will be set as group 1,
        otherwise group 0.
        In 3D datasets, all samples are treated the same, so the flags are
        all zeros.
        """
        self.flag = np.zeros(len(self), dtype=np.uint8)