custom_dataset.py 11.2 KB
Newer Older
YangXiuyu's avatar
YangXiuyu committed
1
2
3
4
5
6
7
import copy
import pickle
import os

import numpy as np

from ...ops.roiaware_pool3d import roiaware_pool3d_utils
jihanyang's avatar
jihanyang committed
8
from ...utils import box_utils, common_utils
YangXiuyu's avatar
YangXiuyu committed
9
10
from ..dataset import DatasetTemplate

jihanyang's avatar
jihanyang committed
11

YangXiuyu's avatar
YangXiuyu committed
12
class CustomDataset(DatasetTemplate):
    """Dataset for user-provided point clouds and 3D box labels.

    Files read relative to ``root_path`` by the methods below:
        ImageSets/<split>.txt    one sample id per line
        points/<sample_id>.npy   point array, loaded with ``np.load``
        labels/<sample_id>.txt   one box per line: x y z dx dy dz heading_angle category_id
    plus the info pickles listed in ``dataset_cfg.INFO_PATH[mode]``.
    """

    def __init__(self, dataset_cfg, class_names, training=True, root_path=None, logger=None):
        """
        Args:
            dataset_cfg: dataset config; this class reads ``DATA_SPLIT``, ``INFO_PATH``,
                ``MAP_CLASS_TO_KITTI`` and (optionally) ``INFO_WITH_FAKELIDAR``.
            class_names: list of class names; a label's ``category_id`` indexes into it.
            training: training flag, forwarded to ``DatasetTemplate``.
            root_path: dataset root directory, forwarded to ``DatasetTemplate``.
            logger: logger used for progress messages.
        """
        super().__init__(
            dataset_cfg=dataset_cfg, class_names=class_names, training=training, root_path=root_path, logger=logger
        )
        self.split = self.dataset_cfg.DATA_SPLIT[self.mode]

        split_dir = os.path.join(self.root_path, 'ImageSets', (self.split + '.txt'))
        self.sample_id_list = self._read_sample_id_list(split_dir)

        self.custom_infos = []
        self.include_data(self.mode)
        self.map_class_to_kitti = self.dataset_cfg.MAP_CLASS_TO_KITTI

    @staticmethod
    def _read_sample_id_list(split_file):
        """Return the stripped lines of ``split_file``, or ``None`` if the file does not exist."""
        if not os.path.exists(split_file):
            return None
        # Use a context manager so the handle is closed promptly.
        with open(split_file) as f:
            return [x.strip() for x in f.readlines()]

    def include_data(self, mode):
        """Load the info pickles configured for ``mode`` and append them to ``self.custom_infos``.

        Each entry of ``dataset_cfg.INFO_PATH[mode]`` is resolved relative to ``root_path``;
        entries whose file does not exist are skipped.
        """
        self.logger.info('Loading Custom dataset.')
        custom_infos = []

        for info_path in self.dataset_cfg.INFO_PATH[mode]:
            info_path = self.root_path / info_path
            if not info_path.exists():
                continue
            # NOTE(review): pickle.load can execute arbitrary code; only load info files you generated.
            with open(info_path, 'rb') as f:
                infos = pickle.load(f)
                custom_infos.extend(infos)

        self.custom_infos.extend(custom_infos)
        self.logger.info('Total samples for CUSTOM dataset: %d' % (len(custom_infos)))

    def get_label(self, idx):
        """Read ``labels/<idx>.txt`` into a float32 array, one row per box.

        Each non-empty line holds whitespace-separated values
        ``x y z dx dy dz heading_angle category_id``. A file with no boxes yields an
        empty ``(0, 8)`` array so callers can still index ``[:, -1]`` / ``[:, :7]``.

        Raises:
            FileNotFoundError: if the label file does not exist.
        """
        label_file = self.root_path / 'labels' / ('%s.txt' % idx)
        if not label_file.exists():
            raise FileNotFoundError('Label file not found: %s' % label_file)
        with open(label_file, 'r') as f:
            lines = f.readlines()

        # [N, 8]: (x y z dx dy dz heading_angle category_id)
        # split() tolerates repeated whitespace; blank lines (e.g. a trailing newline) are skipped.
        gt_boxes = [line.split() for line in lines if line.strip()]
        if not gt_boxes:
            return np.zeros((0, 8), dtype=np.float32)
        return np.array(gt_boxes, dtype=np.float32)

    def get_lidar(self, idx):
        """Load and return the array stored in ``points/<idx>.npy``.

        Raises:
            FileNotFoundError: if the point file does not exist.
        """
        lidar_file = self.root_path / 'points' / ('%s.npy' % idx)
        if not lidar_file.exists():
            raise FileNotFoundError('Point cloud file not found: %s' % lidar_file)
        point_features = np.load(lidar_file)
        return point_features

    def set_split(self, split):
        """Switch to ``split`` and reload ``sample_id_list`` from ``ImageSets/<split>.txt``.

        Re-runs ``DatasetTemplate.__init__`` with the current settings. This method does
        not reload ``self.custom_infos``.
        """
        super().__init__(
            dataset_cfg=self.dataset_cfg, class_names=self.class_names, training=self.training,
            root_path=self.root_path, logger=self.logger
        )
        self.split = split

        split_dir = self.root_path / 'ImageSets' / (self.split + '.txt')
        self.sample_id_list = self._read_sample_id_list(split_dir)

    def __len__(self):
        """Number of samples, times ``total_epochs`` when all epochs are merged into one."""
        if self._merge_all_iters_to_one_epoch:
            # Must agree with __getitem__, which wraps the index with len(self.custom_infos).
            return len(self.custom_infos) * self.total_epochs

        return len(self.custom_infos)

    def __getitem__(self, index):
        """Build the input dict for sample ``index`` and pass it through ``prepare_data``."""
        if self._merge_all_iters_to_one_epoch:
            index = index % len(self.custom_infos)

        info = copy.deepcopy(self.custom_infos[index])
        sample_idx = info['point_cloud']['lidar_idx']
        points = self.get_lidar(sample_idx)
        input_dict = {
            # Take the frame id from the info itself: sample_id_list may be None (no ImageSets
            # file) or may not line up with custom_infos (several info pickles, other splits).
            'frame_id': sample_idx,
            'points': points
        }

        if 'annos' in info:
            annos = info['annos']
            annos = common_utils.drop_info_with_name(annos, name='DontCare')
            gt_names = annos['name']
            gt_boxes_lidar = annos['gt_boxes_lidar']
            input_dict.update({
                'gt_names': gt_names,
                'gt_boxes': gt_boxes_lidar
            })

        data_dict = self.prepare_data(data_dict=input_dict)

        return data_dict

    def evaluation(self, det_annos, class_names, **kwargs):
        """Evaluate ``det_annos`` against the ground truth in ``self.custom_infos``.

        Only ``eval_metric='kitti'`` is supported: class names are mapped with
        ``self.map_class_to_kitti`` and scored with the KITTI evaluation code.

        Returns:
            (result_str, ap_dict); a message and an empty dict when no ground truth is loaded.

        Raises:
            NotImplementedError: for any other (or missing) ``eval_metric``.
        """
        if not self.custom_infos or 'annos' not in self.custom_infos[0]:
            return 'No ground-truth boxes for evaluation', {}

        def run_kitti_eval(eval_det_annos, eval_gt_annos, map_name_to_kitti):
            from ..kitti.kitti_object_eval_python import eval as kitti_eval
            from ..kitti import kitti_utils

            kitti_utils.transform_annotations_to_kitti_format(eval_det_annos, map_name_to_kitti=map_name_to_kitti)
            kitti_utils.transform_annotations_to_kitti_format(
                eval_gt_annos, map_name_to_kitti=map_name_to_kitti,
                info_with_fakelidar=self.dataset_cfg.get('INFO_WITH_FAKELIDAR', False)
            )
            kitti_class_names = [map_name_to_kitti[x] for x in class_names]
            ap_result_str, ap_dict = kitti_eval.get_official_eval_result(
                gt_annos=eval_gt_annos, dt_annos=eval_det_annos, current_classes=kitti_class_names
            )
            return ap_result_str, ap_dict

        # Work on copies: the KITTI conversion is applied to these lists.
        eval_det_annos = copy.deepcopy(det_annos)
        eval_gt_annos = [copy.deepcopy(info['annos']) for info in self.custom_infos]

        eval_metric = kwargs.get('eval_metric')
        if eval_metric == 'kitti':
            ap_result_str, ap_dict = run_kitti_eval(eval_det_annos, eval_gt_annos, self.map_class_to_kitti)
        else:
            raise NotImplementedError('Unsupported eval_metric for CustomDataset: %s' % eval_metric)

        return ap_result_str, ap_dict

    def get_infos(self, class_names, num_workers=4, has_label=True, sample_id_list=None, num_features=4):
        """Build one info dict per sample id using a thread pool.

        Args:
            class_names: class names indexed by the ``category_id`` column of each label file.
            num_workers: number of worker threads.
            has_label: if True, read ``labels/<id>.txt`` and attach ``info['annos']``.
            sample_id_list: ids to process; defaults to ``self.sample_id_list``.
            num_features: stored as ``info['point_cloud']['num_features']``.

        Returns:
            list of info dicts, in the same order as the sample ids.

        Raises:
            ValueError: if no sample ids are given and none were loaded from ImageSets.
        """
        import concurrent.futures as futures

        class_names = np.array(class_names)

        def process_single_scene(sample_idx):
            print('%s sample_idx: %s' % (self.split, sample_idx))
            info = {}
            pc_info = {'num_features': num_features, 'lidar_idx': sample_idx}
            info['point_cloud'] = pc_info

            if has_label:
                annotations = {}
                gt_boxes_lidar = self.get_label(sample_idx)
                # Last column is the category id; the first 7 are the box parameters.
                annotations['name'] = class_names[gt_boxes_lidar[:, -1].astype(np.int64)]
                annotations['gt_boxes_lidar'] = gt_boxes_lidar[:, :7]
                info['annos'] = annotations

            return info

        sample_id_list = sample_id_list if sample_id_list is not None else self.sample_id_list
        if sample_id_list is None:
            raise ValueError('No sample ids to process: ImageSets/%s.txt was not found' % self.split)

        # create a thread pool to improve the velocity
        with futures.ThreadPoolExecutor(num_workers) as executor:
            infos = list(executor.map(process_single_scene, sample_id_list))
        return infos

    def create_groundtruth_database(self, info_path=None, used_classes=None, split='train'):
        """Save the points inside every GT box for GT-sampling augmentation.

        For each object in the infos stored at ``info_path``, the points inside its box are
        shifted so the box center is the origin and dumped with ``ndarray.tofile`` to
        ``<root>/gt_database[_<split>]/<sample>_<name>_<i>.bin``. Per-class metadata for
        objects in ``used_classes`` (all objects when None) is pickled to
        ``<root>/custom_dbinfos_<split>.pkl``.
        """
        import torch
        # Imported here: the module-level Path import only exists when run as __main__.
        from pathlib import Path

        database_save_path = Path(self.root_path) / ('gt_database' if split == 'train' else ('gt_database_%s' % split))
        db_info_save_path = Path(self.root_path) / ('custom_dbinfos_%s.pkl' % split)

        database_save_path.mkdir(parents=True, exist_ok=True)
        all_db_infos = {}

        with open(info_path, 'rb') as f:
            infos = pickle.load(f)

        for k, info in enumerate(infos):
            print('gt_database sample: %d/%d' % (k + 1, len(infos)))
            sample_idx = info['point_cloud']['lidar_idx']
            points = self.get_lidar(sample_idx)
            annos = info['annos']
            names = annos['name']
            gt_boxes = annos['gt_boxes_lidar']

            num_obj = gt_boxes.shape[0]
            if num_obj == 0:
                # Nothing to crop; skip the points-in-boxes op for empty frames.
                continue
            point_indices = roiaware_pool3d_utils.points_in_boxes_cpu(
                torch.from_numpy(points[:, 0:3]), torch.from_numpy(gt_boxes)
            ).numpy()  # (nboxes, npoints)

            for i in range(num_obj):
                filename = '%s_%s_%d.bin' % (sample_idx, names[i], i)
                filepath = database_save_path / filename
                gt_points = points[point_indices[i] > 0]  # boolean indexing returns a copy

                gt_points[:, :3] -= gt_boxes[i, :3]
                # tofile writes raw bytes, so the file must be opened in binary mode.
                with open(filepath, 'wb') as f:
                    gt_points.tofile(f)

                if (used_classes is None) or names[i] in used_classes:
                    db_path = str(filepath.relative_to(self.root_path))  # gt_database/xxxxx.bin
                    db_info = {'name': names[i], 'path': db_path, 'gt_idx': i,
                               'box3d_lidar': gt_boxes[i], 'num_points_in_gt': gt_points.shape[0]}
                    all_db_infos.setdefault(names[i], []).append(db_info)

        # Output the num of all classes in database
        for k, v in all_db_infos.items():
            print('Database %s: %d' % (k, len(v)))

        with open(db_info_save_path, 'wb') as f:
            pickle.dump(all_db_infos, f)

    @staticmethod
    def create_label_file_with_name_and_box(class_names, gt_names, gt_boxes, save_label_path):
        """Write boxes to ``save_label_path`` in the ``labels/*.txt`` format read by ``get_label``.

        Each line is ``x y z l w h angle category_id``, where ``category_id`` is the index of the
        box's name in ``class_names`` (a list). Boxes whose name is not in ``class_names`` are skipped.
        """
        with open(save_label_path, 'w') as f:
            for idx in range(gt_boxes.shape[0]):
                boxes = gt_boxes[idx]
                name = gt_names[idx]
                if name not in class_names:
                    continue
                category_id = class_names.index(name)
                line = "{x} {y} {z} {l} {w} {h} {angle} {category_id}\n".format(
                    x=boxes[0], y=boxes[1], z=boxes[2], l=boxes[3],
                    w=boxes[4], h=boxes[5], angle=boxes[6], category_id=category_id
                )
                f.write(line)


def create_custom_infos(dataset_cfg, class_names, data_path, save_path, workers=4):
    """Generate the train/val info pickles and the GT database for a custom dataset.

    Writes ``custom_infos_train.pkl`` and ``custom_infos_val.pkl`` into ``save_path``,
    then builds the ground-truth database under ``data_path`` from the train infos.

    Args:
        dataset_cfg: dataset config; ``POINT_FEATURE_ENCODING.src_feature_list`` sets ``num_features``.
        class_names: class names indexed by the ``category_id`` column of the label files.
        data_path: dataset root directory (str or ``Path``).
        save_path: directory for the info pickles (str or ``Path``).
        workers: number of threads used by ``CustomDataset.get_infos``.
    """
    from pathlib import Path

    # Accept plain strings: the code below and CustomDataset join paths with "/".
    data_path = Path(data_path)
    save_path = Path(save_path)

    dataset = CustomDataset(
        dataset_cfg=dataset_cfg, class_names=class_names, root_path=data_path,
        training=False, logger=common_utils.create_logger()
    )
    train_split, val_split = 'train', 'val'
    num_features = len(dataset_cfg.POINT_FEATURE_ENCODING.src_feature_list)

    train_filename = save_path / ('custom_infos_%s.pkl' % train_split)
    val_filename = save_path / ('custom_infos_%s.pkl' % val_split)

    print('------------------------Start to generate data infos------------------------')

    dataset.set_split(train_split)
    custom_infos_train = dataset.get_infos(
        class_names, num_workers=workers, has_label=True, num_features=num_features
    )
    with open(train_filename, 'wb') as f:
        pickle.dump(custom_infos_train, f)
    print('Custom info train file is saved to %s' % train_filename)

    dataset.set_split(val_split)
    custom_infos_val = dataset.get_infos(
        class_names, num_workers=workers, has_label=True, num_features=num_features
    )
    with open(val_filename, 'wb') as f:
        pickle.dump(custom_infos_val, f)
    print('Custom info val file is saved to %s' % val_filename)

    print('------------------------Start create groundtruth database for data augmentation------------------------')
    dataset.set_split(train_split)
    dataset.create_groundtruth_database(train_filename, split=train_split)
    print('------------------------Data preparation done------------------------')


if __name__ == '__main__':
    import sys

    # Usage: <this script> create_custom_infos <dataset_cfg.yaml>
    if len(sys.argv) > 1 and sys.argv[1] == 'create_custom_infos':
        if len(sys.argv) < 3:
            sys.exit('usage: custom_dataset.py create_custom_infos <dataset_cfg.yaml>')

        import yaml
        from pathlib import Path
        from easydict import EasyDict

        # Close the config file as soon as it is parsed.
        with open(sys.argv[2]) as cfg_file:
            dataset_cfg = EasyDict(yaml.safe_load(cfg_file))
        # Three directory levels above this file's directory; data lives in <ROOT_DIR>/data/custom.
        ROOT_DIR = (Path(__file__).resolve().parent / '../../../').resolve()
        create_custom_infos(
            dataset_cfg=dataset_cfg,
            class_names=['Vehicle', 'Pedestrian', 'Cyclist'],
            data_path=ROOT_DIR / 'data' / 'custom',
            save_path=ROOT_DIR / 'data' / 'custom',
        )