waymo_dataset.py 34.7 KB
Newer Older
1
2
3
4
5
6
7
8
9
10
# OpenPCDet PyTorch Dataloader and Evaluation Tools for Waymo Open Dataset
# Reference https://github.com/open-mmlab/OpenPCDet
# Written by Shaoshuai Shi, Chaoxu Guo
# All Rights Reserved 2019-2020.

import os
import pickle
import copy
import numpy as np
import torch
11
import multiprocessing
12
13
import SharedArray
import torch.distributed as dist
14
from tqdm import tqdm
Shaoshuai Shi's avatar
Shaoshuai Shi committed
15
from pathlib import Path
16
17
from functools import partial

18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
from ...ops.roiaware_pool3d import roiaware_pool3d_utils
from ...utils import box_utils, common_utils
from ..dataset import DatasetTemplate


class WaymoDataset(DatasetTemplate):
    def __init__(self, dataset_cfg, class_names, training=True, root_path=None, logger=None):
        """Build the Waymo dataset for the split selected by ``self.mode``.

        Args:
            dataset_cfg: dataset config; reads PROCESSED_DATA_TAG, DATA_SPLIT, SAMPLED_INTERVAL
                (in include_waymo_data) and optionally USE_SHARED_MEMORY / SHARED_MEMORY_FILE_LIMIT.
            class_names: list of class names, forwarded to DatasetTemplate.
            training: whether the dataset is used for training.
            root_path: dataset root, forwarded to DatasetTemplate (which sets ``self.root_path``).
            logger: logger used for progress messages.
        """
        super().__init__(
            dataset_cfg=dataset_cfg, class_names=class_names, training=training, root_path=root_path, logger=logger
        )
        self.data_path = self.root_path / self.dataset_cfg.PROCESSED_DATA_TAG
        self.split = self.dataset_cfg.DATA_SPLIT[self.mode]
        split_dir = self.root_path / 'ImageSets' / (self.split + '.txt')
        # Close the split file deterministically and skip blank lines (e.g. a trailing newline),
        # which would otherwise turn into empty sequence names.
        with open(split_dir) as f:
            self.sample_sequence_list = [line.strip() for line in f if line.strip()]

        self.infos = []
        self.seq_name_to_infos = self.include_waymo_data(self.mode)

        # Shared-memory caching is only enabled for training.
        self.use_shared_memory = self.dataset_cfg.get('USE_SHARED_MEMORY', False) and self.training
        if self.use_shared_memory:
            self.shared_memory_file_limit = self.dataset_cfg.get('SHARED_MEMORY_FILE_LIMIT', 0x7FFFFFFF)
            self.load_data_to_shared_memory()
    def set_split(self, split):
        """Re-initialize the dataset for another split and reload its infos.

        Args:
            split: split name; the sequence list is read from ``ImageSets/<split>.txt``.
                Infos are still sampled with ``SAMPLED_INTERVAL[self.mode]``.
        """
        super().__init__(
            dataset_cfg=self.dataset_cfg, class_names=self.class_names, training=self.training,
            root_path=self.root_path, logger=self.logger
        )
        self.split = split
        split_dir = self.root_path / 'ImageSets' / (self.split + '.txt')
        # Close the split file deterministically and skip blank lines.
        with open(split_dir) as f:
            self.sample_sequence_list = [line.strip() for line in f if line.strip()]
        self.infos = []
        self.seq_name_to_infos = self.include_waymo_data(self.mode)

    def include_waymo_data(self, mode):
        self.logger.info('Loading Waymo dataset')
        waymo_infos = []
55
        seq_name_to_infos = {}
56
57
58
59
60
61
62
63
64
65
66
67
68

        num_skipped_infos = 0
        for k in range(len(self.sample_sequence_list)):
            sequence_name = os.path.splitext(self.sample_sequence_list[k])[0]
            info_path = self.data_path / sequence_name / ('%s.pkl' % sequence_name)
            info_path = self.check_sequence_name_with_all_version(info_path)
            if not info_path.exists():
                num_skipped_infos += 1
                continue
            with open(info_path, 'rb') as f:
                infos = pickle.load(f)
                waymo_infos.extend(infos)

69
70
            seq_name_to_infos[infos[0]['point_cloud']['lidar_sequence']] = infos

71
72
73
74
75
76
77
78
79
80
81
        self.infos.extend(waymo_infos[:])
        self.logger.info('Total skipped info %s' % num_skipped_infos)
        self.logger.info('Total samples for Waymo dataset: %d' % (len(waymo_infos)))

        if self.dataset_cfg.SAMPLED_INTERVAL[mode] > 1:
            sampled_waymo_infos = []
            for k in range(0, len(self.infos), self.dataset_cfg.SAMPLED_INTERVAL[mode]):
                sampled_waymo_infos.append(self.infos[k])
            self.infos = sampled_waymo_infos
            self.logger.info('Total sampled samples for Waymo dataset: %d' % len(self.infos))

82
83
        return seq_name_to_infos

84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
    def load_data_to_shared_memory(self):
        """Cache this rank's share of lidar frames in shared memory (``/dev/shm``).

        Only the first ``self.shared_memory_file_limit`` infos are considered. They are split
        round-robin across ranks, and frames whose ``/dev/shm`` entry already exists are skipped.
        Entries are keyed ``<lidar_sequence>___<sample_idx>``.
        """
        self.logger.info(f'Loading training data to shared memory (file limit={self.shared_memory_file_limit})')

        cur_rank, num_gpus = common_utils.get_dist_info()
        # Slicing past the end of the list is safe, so no explicit length check is needed.
        all_infos = self.infos[:self.shared_memory_file_limit]
        cur_infos = all_infos[cur_rank::num_gpus]
        for info in cur_infos:
            pc_info = info['point_cloud']
            sequence_name = pc_info['lidar_sequence']
            sample_idx = pc_info['sample_idx']

            sa_key = f'{sequence_name}___{sample_idx}'
            if os.path.exists(f"/dev/shm/{sa_key}"):
                continue

            points = self.get_lidar(sequence_name, sample_idx)
            common_utils.sa_create(f"shm://{sa_key}", points)

        # dist.barrier() needs an initialized process group; only synchronize when running
        # distributed (same guard as clean_shared_memory).
        if num_gpus > 1:
            dist.barrier()
        self.logger.info('Training data has been saved to shared memory')
    def clean_shared_memory(self):
        """Delete the shared-memory frames that this rank is responsible for.

        Uses the same info subset and round-robin rank split as load_data_to_shared_memory;
        entries that do not exist in ``/dev/shm`` are ignored.
        """
        self.logger.info(f'Clean training data from shared memory (file limit={self.shared_memory_file_limit})')

        rank, world_size = common_utils.get_dist_info()
        limit = self.shared_memory_file_limit
        candidates = self.infos[:limit] if limit < len(self.infos) else self.infos

        for info in candidates[rank::world_size]:
            point_cloud = info['point_cloud']
            key = f"{point_cloud['lidar_sequence']}___{point_cloud['sample_idx']}"
            # Nothing to delete if this frame was never cached.
            if os.path.exists(f"/dev/shm/{key}"):
                SharedArray.delete(f"shm://{key}")

        if world_size > 1:
            dist.barrier()
        self.logger.info('Training data has been deleted from shared memory')
    @staticmethod
    def check_sequence_name_with_all_version(sequence_file):
130
131
132
133
134
135
136
137
138
139
140
141
        if not sequence_file.exists():
            found_sequence_file = sequence_file
            for pre_text in ['training', 'validation', 'testing']:
                if not sequence_file.exists():
                    temp_sequence_file = Path(str(sequence_file).replace('segment', pre_text + '_segment'))
                    if temp_sequence_file.exists():
                        found_sequence_file = temp_sequence_file
                        break
            if not found_sequence_file.exists():
                found_sequence_file = Path(str(sequence_file).replace('_with_camera_labels', ''))
            if found_sequence_file.exists():
                sequence_file = found_sequence_file
142
143
        return sequence_file

144
    def get_infos(self, raw_data_path, save_path, num_workers=multiprocessing.cpu_count(), has_label=True, sampled_interval=1, update_info_only=False):
        """Process every raw sequence of the current split in a process pool and collect the infos.

        Args:
            raw_data_path: directory holding the raw sequence files named in sample_sequence_list.
            save_path: output directory, forwarded to waymo_utils.process_single_sequence.
            num_workers: number of worker processes.
            has_label: forwarded to waymo_utils.process_single_sequence.
            sampled_interval: forwarded to waymo_utils.process_single_sequence.
            update_info_only: forwarded to waymo_utils.process_single_sequence.

        Returns:
            flat list of infos of all sequences, in sequence-list order.
        """
        from . import waymo_utils
        print('---------------The waymo sample interval is %d, total sequecnes is %d-----------------'
              % (sampled_interval, len(self.sample_sequence_list)))

        worker_fn = partial(
            waymo_utils.process_single_sequence,
            save_path=save_path, sampled_interval=sampled_interval, has_label=has_label,
            update_info_only=update_info_only
        )
        sequence_files = [
            self.check_sequence_name_with_all_version(raw_data_path / name)
            for name in self.sample_sequence_list
        ]

        # imap preserves input order, so the result stays aligned with sample_sequence_list.
        with multiprocessing.Pool(num_workers) as pool:
            per_sequence_infos = list(tqdm(pool.imap(worker_fn, sequence_files), total=len(sequence_files)))

        all_sequences_infos = []
        for sequence_infos in per_sequence_infos:
            all_sequences_infos.extend(sequence_infos)
        return all_sequences_infos

    def get_lidar(self, sequence_name, sample_idx):
        lidar_file = self.data_path / sequence_name / ('%04d.npy' % sample_idx)
        point_features = np.load(lidar_file)  # (N, 7): [x, y, z, intensity, elongation, NLZ_flag]

        points_all, NLZ_flag = point_features[:, 0:5], point_features[:, 5]
170
171
        if not self.dataset_cfg.get('DISABLE_NLZ_FLAG_ON_POINTS', False):
            points_all = points_all[NLZ_flag == -1]
172
173
174
        points_all[:, 3] = np.tanh(points_all[:, 3])
        return points_all

175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
    def get_sequence_data(self, info, points, sequence_name, sample_idx, sequence_cfg):
        """
        Args:
            info:
            points:
            sequence_name:
            sample_idx:
            sequence_cfg:
        Returns:
        """

        def remove_ego_points(points, center_radius=1.0):
            mask = ~((np.abs(points[:, 0]) < center_radius) & (np.abs(points[:, 1]) < center_radius))
            return points[mask]

        pose_cur = info['pose'].reshape((4, 4))
        num_pts_cur = points.shape[0]
        sample_idx_pre_list = np.clip(sample_idx + np.arange(
            sequence_cfg.SAMPLE_OFFSET[0], sequence_cfg.SAMPLE_OFFSET[1]), 0, 0x7FFFFFFF)
        if sequence_cfg.get('ONEHOT_TIMESTAMP', False):
            onehot_cur = np.zeros((points.shape[0], len(sample_idx_pre_list) + 1)).astype(points.dtype)
            onehot_cur[:, 0] = 1
            points = np.hstack([points, onehot_cur])
        else:
            points = np.hstack([points, np.zeros((points.shape[0], 1)).astype(points.dtype)])
        points_pre_all = []
        num_points_pre = []

        sequence_info = self.seq_name_to_infos[sequence_name]

        for i, sample_idx_pre in enumerate(sample_idx_pre_list):
            if sample_idx == sample_idx_pre:
                continue

            points_pre = self.get_lidar(sequence_name, sample_idx_pre)
            pose_pre = sequence_info[sample_idx_pre]['pose'].reshape((4, 4))
            expand_points_pre = np.concatenate([points_pre[:, :3], np.ones((points_pre.shape[0], 1))], axis=-1)
            points_pre_global = np.dot(expand_points_pre, pose_pre.T)[:, :3]
            expand_points_pre_global = np.concatenate([points_pre_global,
                                                       np.ones((points_pre_global.shape[0], 1))], axis=-1)
            points_pre2cur = np.dot(expand_points_pre_global, np.linalg.inv(pose_cur.T))[:, :3]
            points_pre = np.concatenate([points_pre2cur, points_pre[:, 3:]], axis=-1)
            if sequence_cfg.get('ONEHOT_TIMESTAMP', False):
                onehot_vector = np.zeros((points_pre.shape[0], len(sample_idx_pre_list) + 1))
                onehot_vector[:, i + 1] = 1
                points_pre = np.hstack([points_pre, onehot_vector])
            else:
                # add timestamp
                points_pre = np.hstack([points_pre, 0.1 * (sample_idx - sample_idx_pre)
                                        * np.ones((points_pre.shape[0], 1)).astype(points_pre.dtype)])  # one frame 0.1s
            points_pre = remove_ego_points(points_pre, 1.0)
            points_pre_all.append(points_pre)
            num_points_pre.append(points_pre.shape[0])
228
        points = np.concatenate([points] + points_pre_all, axis=0).astype(np.float32)
229
        num_points_all = np.array([num_pts_cur] + num_points_pre).astype(np.int32)
230
231
        return points, num_points_all, sample_idx_pre_list

232
233
234
235
236
237
238
239
240
241
242
243
244
245
    def __len__(self):
        if self._merge_all_iters_to_one_epoch:
            return len(self.infos) * self.total_epochs

        return len(self.infos)

    def __getitem__(self, index):
        """Build the sample for ``index`` and run it through ``self.prepare_data``.

        Points come from shared memory when the index was cached, otherwise from disk.
        Neighbouring frames are stacked when SEQUENCE_CONFIG is enabled. Ground-truth boxes and
        names are attached when the info carries ``annos``.
        """
        if self._merge_all_iters_to_one_epoch:
            index = index % len(self.infos)

        info = copy.deepcopy(self.infos[index])
        point_cloud_info = info['point_cloud']
        sequence_name = point_cloud_info['lidar_sequence']
        sample_idx = point_cloud_info['sample_idx']

        # Only the first shared_memory_file_limit samples are cached in shared memory.
        from_shared_memory = self.use_shared_memory and index < self.shared_memory_file_limit
        if from_shared_memory:
            shm_name = f'{sequence_name}___{sample_idx}'
            points = SharedArray.attach(f"shm://{shm_name}").copy()
        else:
            points = self.get_lidar(sequence_name, sample_idx)

        sequence_enabled = (self.dataset_cfg.get('SEQUENCE_CONFIG', None) is not None
                            and self.dataset_cfg.SEQUENCE_CONFIG.ENABLED)
        if sequence_enabled:
            points, _, _ = self.get_sequence_data(
                info, points, sequence_name, sample_idx, self.dataset_cfg.SEQUENCE_CONFIG
            )

        input_dict = {
            'points': points,
            'frame_id': info['frame_id'],
        }

        if 'annos' in info:
            annos = common_utils.drop_info_with_name(info['annos'], name='unknown')

            gt_boxes_lidar = annos['gt_boxes_lidar']
            if self.dataset_cfg.get('INFO_WITH_FAKELIDAR', False):
                gt_boxes_lidar = box_utils.boxes3d_kitti_fakelidar_to_lidar(gt_boxes_lidar)

            if not self.dataset_cfg.get('TRAIN_WITH_SPEED', False):
                gt_boxes_lidar = gt_boxes_lidar[:, 0:7]
            else:
                # Boxes must carry 9 columns when training with speed.
                assert gt_boxes_lidar.shape[-1] == 9

            if self.training and self.dataset_cfg.get('FILTER_EMPTY_BOXES_FOR_TRAIN', False):
                keep = (annos['num_points_in_gt'] > 0)  # drop boxes without any lidar point
                annos['name'] = annos['name'][keep]
                gt_boxes_lidar = gt_boxes_lidar[keep]
                annos['num_points_in_gt'] = annos['num_points_in_gt'][keep]

            input_dict['gt_names'] = annos['name']
            input_dict['gt_boxes'] = gt_boxes_lidar
            input_dict['num_points_in_gt'] = annos.get('num_points_in_gt', None)

        data_dict = self.prepare_data(data_dict=input_dict)
        data_dict['metadata'] = info.get('metadata', info['frame_id'])
        data_dict.pop('num_points_in_gt', None)
        return data_dict
    def generate_prediction_dicts(self, batch_dict, pred_dicts, class_names, output_path=None):
295
296
297
298
299
        """
        Args:
            batch_dict:
                frame_id:
            pred_dicts: list of pred_dicts
300
                pred_boxes: (N, 7 or 9), Tensor
301
302
303
304
305
306
307
308
309
310
                pred_scores: (N), Tensor
                pred_labels: (N), Tensor
            class_names:
            output_path:

        Returns:

        """

        def get_template_prediction(num_samples):
311
            box_dim = 9 if self.dataset_cfg.get('TRAIN_WITH_SPEED', False) else 7
312
313
            ret_dict = {
                'name': np.zeros(num_samples), 'score': np.zeros(num_samples),
314
                'boxes_lidar': np.zeros([num_samples, box_dim])
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
            }
            return ret_dict

        def generate_single_sample_dict(box_dict):
            pred_scores = box_dict['pred_scores'].cpu().numpy()
            pred_boxes = box_dict['pred_boxes'].cpu().numpy()
            pred_labels = box_dict['pred_labels'].cpu().numpy()
            pred_dict = get_template_prediction(pred_scores.shape[0])
            if pred_scores.shape[0] == 0:
                return pred_dict

            pred_dict['name'] = np.array(class_names)[pred_labels - 1]
            pred_dict['score'] = pred_scores
            pred_dict['boxes_lidar'] = pred_boxes

            return pred_dict

        annos = []
        for index, box_dict in enumerate(pred_dicts):
            single_pred_dict = generate_single_sample_dict(box_dict)
            single_pred_dict['frame_id'] = batch_dict['frame_id'][index]
            single_pred_dict['metadata'] = batch_dict['metadata'][index]
            annos.append(single_pred_dict)

        return annos

    def evaluation(self, det_annos, class_names, **kwargs):
        if 'annos' not in self.infos[0].keys():
            return 'No ground-truth boxes for evaluation', {}

        def kitti_eval(eval_det_annos, eval_gt_annos):
            from ..kitti.kitti_object_eval_python import eval as kitti_eval
            from ..kitti import kitti_utils

            map_name_to_kitti = {
                'Vehicle': 'Car',
                'Pedestrian': 'Pedestrian',
                'Cyclist': 'Cyclist',
                'Sign': 'Sign',
                'Car': 'Car'
            }
            kitti_utils.transform_annotations_to_kitti_format(eval_det_annos, map_name_to_kitti=map_name_to_kitti)
            kitti_utils.transform_annotations_to_kitti_format(
                eval_gt_annos, map_name_to_kitti=map_name_to_kitti,
                info_with_fakelidar=self.dataset_cfg.get('INFO_WITH_FAKELIDAR', False)
            )
            kitti_class_names = [map_name_to_kitti[x] for x in class_names]
            ap_result_str, ap_dict = kitti_eval.get_official_eval_result(
                gt_annos=eval_gt_annos, dt_annos=eval_det_annos, current_classes=kitti_class_names
            )
            return ap_result_str, ap_dict

        def waymo_eval(eval_det_annos, eval_gt_annos):
            from .waymo_eval import OpenPCDetWaymoDetectionMetricsEstimator
            eval = OpenPCDetWaymoDetectionMetricsEstimator()

            ap_dict = eval.waymo_evaluation(
                eval_det_annos, eval_gt_annos, class_name=class_names,
                distance_thresh=1000, fake_gt_infos=self.dataset_cfg.get('INFO_WITH_FAKELIDAR', False)
            )
            ap_result_str = '\n'
            for key in ap_dict:
                ap_dict[key] = ap_dict[key][0]
                ap_result_str += '%s: %.4f \n' % (key, ap_dict[key])

            return ap_result_str, ap_dict

        eval_det_annos = copy.deepcopy(det_annos)
        eval_gt_annos = [copy.deepcopy(info['annos']) for info in self.infos]

        if kwargs['eval_metric'] == 'kitti':
            ap_result_str, ap_dict = kitti_eval(eval_det_annos, eval_gt_annos)
        elif kwargs['eval_metric'] == 'waymo':
            ap_result_str, ap_dict = waymo_eval(eval_det_annos, eval_gt_annos)
        else:
            raise NotImplementedError

        return ap_result_str, ap_dict

    def create_groundtruth_database(self, info_path, save_path, used_classes=None, split='train', sampled_interval=10,
                                    processed_data_tag=None):
        """Crop the points inside every ground-truth box and save them as a GT-sampling database.

        Every ``sampled_interval``-th info of ``info_path`` is processed (index ``k``). Vehicle
        boxes are kept only when ``k % 4 == 0`` and Pedestrian boxes only when ``k % 2 == 0``,
        presumably to limit how many of these classes enter the database. For each kept object:
          * its points are shifted by the box center and written as raw float32 to
            ``<database_save_path>/<seq>_<idx:04d>_<name>_<i>.bin``;
          * a db_info dict is collected per class name and pickled to the db-info file;
          * its points are appended to one stacked array saved as a ``.npy`` file, with
            ``global_data_offset`` giving the row range.

        Requires CUDA (uses ``points_in_boxes_gpu``).

        Args:
            info_path: pickle with the list of infos to process.
            save_path: output directory.
            used_classes: if given, only objects of these class names are saved.
            split: split name, only used in the output names.
            sampled_interval: step between processed infos.
            processed_data_tag: prefix of the output names.
        """
        use_sequence_data = self.dataset_cfg.get('SEQUENCE_CONFIG', None) is not None and self.dataset_cfg.SEQUENCE_CONFIG.ENABLED

        sequence_cfg = None
        if use_sequence_data:
            st_frame, ed_frame = self.dataset_cfg.SEQUENCE_CONFIG.SAMPLE_OFFSET[0], self.dataset_cfg.SEQUENCE_CONFIG.SAMPLE_OFFSET[1]
            st_frame = min(-4, st_frame)  # at least we use 5 frames for generating gt database to support various sequence configs (<= 5 frames)
            # Load the widened frame range that the output names advertise (as the parallel
            # builder does), using a copy so the dataset config itself is left unchanged.
            sequence_cfg = copy.deepcopy(self.dataset_cfg.SEQUENCE_CONFIG)
            sequence_cfg.SAMPLE_OFFSET[0] = st_frame
            database_save_path = save_path / ('%s_gt_database_%s_sampled_%d_multiframe_%s_to_%s' % (processed_data_tag, split, sampled_interval, st_frame, ed_frame))
            db_info_save_path = save_path / ('%s_waymo_dbinfos_%s_sampled_%d_multiframe_%s_to_%s.pkl' % (processed_data_tag, split, sampled_interval, st_frame, ed_frame))
            db_data_save_path = save_path / ('%s_gt_database_%s_sampled_%d_multiframe_%s_to_%s_global.npy' % (processed_data_tag, split, sampled_interval, st_frame, ed_frame))
        else:
            database_save_path = save_path / ('%s_gt_database_%s_sampled_%d' % (processed_data_tag, split, sampled_interval))
            db_info_save_path = save_path / ('%s_waymo_dbinfos_%s_sampled_%d.pkl' % (processed_data_tag, split, sampled_interval))
            db_data_save_path = save_path / ('%s_gt_database_%s_sampled_%d_global.npy' % (processed_data_tag, split, sampled_interval))

        database_save_path.mkdir(parents=True, exist_ok=True)
        all_db_infos = {}
        with open(info_path, 'rb') as f:
            infos = pickle.load(f)

        point_offset_cnt = 0
        stacked_gt_points = []
        for k in tqdm(range(0, len(infos), sampled_interval)):
            print('gt_database sample: %d/%d' % (k + 1, len(infos)))
            info = infos[k]

            pc_info = info['point_cloud']
            sequence_name = pc_info['lidar_sequence']
            sample_idx = pc_info['sample_idx']
            points = self.get_lidar(sequence_name, sample_idx)

            if use_sequence_data:
                points, _, _ = self.get_sequence_data(info, points, sequence_name, sample_idx, sequence_cfg)

            annos = info['annos']
            names = annos['name']
            difficulty = annos['difficulty']
            gt_boxes = annos['gt_boxes_lidar']

            if k % 4 != 0 and len(names) > 0:
                mask = (names == 'Vehicle')
                names = names[~mask]
                difficulty = difficulty[~mask]
                gt_boxes = gt_boxes[~mask]

            if k % 2 != 0 and len(names) > 0:
                mask = (names == 'Pedestrian')
                names = names[~mask]
                difficulty = difficulty[~mask]
                gt_boxes = gt_boxes[~mask]

            num_obj = gt_boxes.shape[0]
            if num_obj == 0:
                continue

            # Index of the box containing each point (only values 0..num_obj-1 are used below).
            box_idxs_of_pts = roiaware_pool3d_utils.points_in_boxes_gpu(
                torch.from_numpy(points[:, 0:3]).unsqueeze(dim=0).float().cuda(),
                torch.from_numpy(gt_boxes[:, 0:7]).unsqueeze(dim=0).float().cuda()
            ).long().squeeze(dim=0).cpu().numpy()

            for i in range(num_obj):
                if used_classes is not None and names[i] not in used_classes:
                    continue

                filename = '%s_%04d_%s_%d.bin' % (sequence_name, sample_idx, names[i], i)
                filepath = database_save_path / filename
                gt_points = points[box_idxs_of_pts == i]
                gt_points[:, :3] -= gt_boxes[i, :3]
                gt_points = gt_points.astype(np.float32)

                # Binary mode: the file holds raw float32 values.
                with open(filepath, 'wb') as f:
                    gt_points.tofile(f)

                db_path = str(filepath.relative_to(self.root_path))  # gt_database/xxxxx.bin
                db_info = {'name': names[i], 'path': db_path, 'sequence_name': sequence_name,
                           'sample_idx': sample_idx, 'gt_idx': i, 'box3d_lidar': gt_boxes[i],
                           'num_points_in_gt': gt_points.shape[0], 'difficulty': difficulty[i]}

                # it will be used if you choose to use shared memory for gt sampling
                stacked_gt_points.append(gt_points)
                db_info['global_data_offset'] = [point_offset_cnt, point_offset_cnt + gt_points.shape[0]]
                point_offset_cnt += gt_points.shape[0]

                all_db_infos.setdefault(names[i], []).append(db_info)

        for k, v in all_db_infos.items():
            print('Database %s: %d' % (k, len(v)))

        with open(db_info_save_path, 'wb') as f:
            pickle.dump(all_db_infos, f)

        # it will be used if you choose to use shared memory for gt sampling
        if stacked_gt_points:
            stacked_gt_points = np.concatenate(stacked_gt_points, axis=0)
        else:
            # np.concatenate rejects an empty list; still write an (empty) file.
            stacked_gt_points = np.zeros((0, 0), dtype=np.float32)
        np.save(db_data_save_path, stacked_gt_points)
    def create_gt_database_of_single_scene(self, info_with_idx, database_save_path=None, use_sequence_data=False, used_classes=None,
                                           total_samples=0, use_cuda=False, crop_gt_with_tail=False):
        """Crop and save the ground-truth object points of one frame (worker of the parallel builder).

        Args:
            info_with_idx: (info, info_idx) pair. Vehicle boxes are kept only when
                ``info_idx % 4 == 0`` and Pedestrian boxes only when ``info_idx % 2 == 0``.
            database_save_path: directory receiving the per-object ``.bin`` files.
            use_sequence_data: stack neighbouring frames with get_sequence_data before cropping.
            used_classes: if given, only objects of these class names are saved.
            total_samples: total number of infos, only used in the progress message.
            use_cuda: use points_in_boxes_gpu instead of points_in_boxes_cpu.
            crop_gt_with_tail: with sequence data, crop with an enlarged box. Its center is
                moved to the midpoint between the box center and
                ``center - speed * (num_frames - 1) * 0.1``, and its length (column 3) is
                increased by that displacement. Saved points are still offset by the original
                box center.

        Returns:
            dict mapping class name -> list of db_info dicts of this frame ({} if no object).
        """
        info, info_idx = info_with_idx
        print('gt_database sample: %d/%d' % (info_idx, total_samples))

        all_db_infos = {}
        pc_info = info['point_cloud']
        sequence_name = pc_info['lidar_sequence']
        sample_idx = pc_info['sample_idx']
        points = self.get_lidar(sequence_name, sample_idx)

        if use_sequence_data:
            points, _, _ = self.get_sequence_data(
                info, points, sequence_name, sample_idx, self.dataset_cfg.SEQUENCE_CONFIG
            )

        annos = info['annos']
        names = annos['name']
        difficulty = annos['difficulty']
        gt_boxes = annos['gt_boxes_lidar']

        if info_idx % 4 != 0 and len(names) > 0:
            mask = (names == 'Vehicle')
            names = names[~mask]
            difficulty = difficulty[~mask]
            gt_boxes = gt_boxes[~mask]

        if info_idx % 2 != 0 and len(names) > 0:
            mask = (names == 'Pedestrian')
            names = names[~mask]
            difficulty = difficulty[~mask]
            gt_boxes = gt_boxes[~mask]

        num_obj = gt_boxes.shape[0]
        if num_obj == 0:
            return {}

        if use_sequence_data and crop_gt_with_tail:
            assert gt_boxes.shape[1] == 9
            speed = gt_boxes[:, 7:9]
            sequence_cfg = self.dataset_cfg.SEQUENCE_CONFIG
            assert sequence_cfg.SAMPLE_OFFSET[1] == 0
            assert sequence_cfg.SAMPLE_OFFSET[0] < 0
            num_frames = sequence_cfg.SAMPLE_OFFSET[1] - sequence_cfg.SAMPLE_OFFSET[0] + 1
            assert num_frames > 1
            latest_center = gt_boxes[:, 0:2]
            oldest_center = latest_center - speed * (num_frames - 1) * 0.1
            new_center = (latest_center + oldest_center) * 0.5
            new_length = gt_boxes[:, 3] + np.linalg.norm(latest_center - oldest_center, axis=-1)
            gt_boxes_crop = gt_boxes.copy()
            gt_boxes_crop[:, 0:2] = new_center
            gt_boxes_crop[:, 3] = new_length
        else:
            gt_boxes_crop = gt_boxes

        if use_cuda:
            box_idxs_of_pts = roiaware_pool3d_utils.points_in_boxes_gpu(
                torch.from_numpy(points[:, 0:3]).unsqueeze(dim=0).float().cuda(),
                torch.from_numpy(gt_boxes_crop[:, 0:7]).unsqueeze(dim=0).float().cuda()
            ).long().squeeze(dim=0).cpu().numpy()
        else:
            box_point_mask = roiaware_pool3d_utils.points_in_boxes_cpu(
                torch.from_numpy(points[:, 0:3]).float(),
                torch.from_numpy(gt_boxes_crop[:, 0:7]).float()
            ).long().numpy()  # (num_boxes, num_points)

        for i in range(num_obj):
            if used_classes is not None and names[i] not in used_classes:
                continue

            filename = '%s_%04d_%s_%d.bin' % (sequence_name, sample_idx, names[i], i)
            filepath = database_save_path / filename
            if use_cuda:
                gt_points = points[box_idxs_of_pts == i]
            else:
                gt_points = points[box_point_mask[i] > 0]

            gt_points[:, :3] -= gt_boxes[i, :3]
            gt_points = gt_points.astype(np.float32)

            # Binary mode: the file holds raw float32 values.
            with open(filepath, 'wb') as f:
                gt_points.tofile(f)

            db_path = str(filepath.relative_to(self.root_path))  # gt_database/xxxxx.bin
            db_info = {'name': names[i], 'path': db_path, 'sequence_name': sequence_name,
                       'sample_idx': sample_idx, 'gt_idx': i, 'box3d_lidar': gt_boxes[i],
                       'num_points_in_gt': gt_points.shape[0], 'difficulty': difficulty[i],
                       'box3d_crop': gt_boxes_crop[i]}

            all_db_infos.setdefault(names[i], []).append(db_info)
        return all_db_infos
    def create_groundtruth_database_parallel(self, info_path, save_path, used_classes=None, split='train', sampled_interval=10,
                                             processed_data_tag=None, num_workers=16, crop_gt_with_tail=False):
        """Build the GT-sampling database with a process pool, one frame per task.

        Each frame is handled by create_gt_database_of_single_scene with the CPU box test. Only
        the per-object ``.bin`` files and the merged db-info pickle are written; unlike
        create_groundtruth_database, no stacked global ``.npy`` file is produced.

        NOTE: with sequence data enabled, ``dataset_cfg.SEQUENCE_CONFIG.SAMPLE_OFFSET[0]`` is
        widened in place to ``min(-4, SAMPLE_OFFSET[0])``.
        NOTE(review): every info in ``info_path`` is processed; ``sampled_interval`` only appears
        in the output names — confirm that callers pass the interval they expect.

        Args:
            info_path: pickle with the list of infos to process.
            save_path: output directory.
            used_classes: if given, only objects of these class names are saved.
            split: split name, only used in the output names.
            sampled_interval: only used in the output names (see note above).
            processed_data_tag: prefix of the output names.
            num_workers: number of worker processes.
            crop_gt_with_tail: forwarded to create_gt_database_of_single_scene.
        """
        use_sequence_data = (self.dataset_cfg.get('SEQUENCE_CONFIG', None) is not None
                             and self.dataset_cfg.SEQUENCE_CONFIG.ENABLED)
        if use_sequence_data:
            offsets = self.dataset_cfg.SEQUENCE_CONFIG.SAMPLE_OFFSET
            st_frame, ed_frame = offsets[0], offsets[1]
            # at least we use 5 frames for generating gt database to support various sequence configs (<= 5 frames)
            offsets[0] = min(-4, st_frame)
            st_frame = offsets[0]
            tail_tag = 'tail_' if crop_gt_with_tail else ''
            database_save_path = save_path / ('%s_gt_database_%s_sampled_%d_multiframe_%s_to_%s_%sparallel' % (processed_data_tag, split, sampled_interval, st_frame, ed_frame, tail_tag))
            db_info_save_path = save_path / ('%s_waymo_dbinfos_%s_sampled_%d_multiframe_%s_to_%s_%sparallel.pkl' % (processed_data_tag, split, sampled_interval, st_frame, ed_frame, tail_tag))
        else:
            database_save_path = save_path / ('%s_gt_database_%s_sampled_%d_parallel' % (processed_data_tag, split, sampled_interval))
            db_info_save_path = save_path / ('%s_waymo_dbinfos_%s_sampled_%d_parallel.pkl' % (processed_data_tag, split, sampled_interval))

        database_save_path.mkdir(parents=True, exist_ok=True)

        with open(info_path, 'rb') as f:
            infos = pickle.load(f)

        print(f'Number workers: {num_workers}')
        worker = partial(
            self.create_gt_database_of_single_scene,
            use_sequence_data=use_sequence_data, database_save_path=database_save_path,
            used_classes=used_classes, total_samples=len(infos), use_cuda=False,
            crop_gt_with_tail=crop_gt_with_tail
        )
        with multiprocessing.Pool(num_workers) as pool:
            per_frame_db_infos = list(pool.map(worker, zip(infos, np.arange(len(infos)))))

        merged_db_infos = {}
        for frame_db_infos in per_frame_db_infos:
            for class_name, class_db_infos in frame_db_infos.items():
                merged_db_infos.setdefault(class_name, []).extend(class_db_infos)

        for class_name, class_db_infos in merged_db_infos.items():
            print('Database %s: %d' % (class_name, len(class_db_infos)))

        with open(db_info_save_path, 'wb') as f:
            pickle.dump(merged_db_infos, f)

def create_waymo_infos(dataset_cfg, class_names, data_path, save_path,
                       raw_data_tag='raw_data', processed_data_tag='waymo_processed_data',
                       workers=min(16, multiprocessing.cpu_count()), update_info_only=False):
    """Generate Waymo info pickles for the train and val splits, then the GT database.

    For each split ('train', then 'val') this calls ``WaymoDataset.get_infos`` and
    pickles the result to ``save_path / '<processed_data_tag>_infos_<split>.pkl'``.
    Unless ``update_info_only`` is set, it afterwards builds the ground-truth
    database from the train infos via ``create_groundtruth_database``.

    Args:
        dataset_cfg: dataset config forwarded to ``WaymoDataset``.
        class_names: class names forwarded to ``WaymoDataset``.
        data_path: dataset root (a ``Path``); ``data_path / raw_data_tag`` is passed
            to ``get_infos`` as ``raw_data_path``.
        save_path: output root (a ``Path``); info pickles are written directly here and
            ``save_path / processed_data_tag`` is passed to ``get_infos`` as ``save_path``.
        raw_data_tag: sub-directory of ``data_path`` holding the raw data.
        processed_data_tag: prefix of the info pickles and name of the processed-data
            sub-directory.
        workers: forwarded to ``get_infos`` as ``num_workers``.
        update_info_only: forwarded to ``get_infos``; when True the GT database step
            is skipped.

    Side effects:
        Sets ``CUDA_VISIBLE_DEVICES`` to "-1" for info generation and to "0" before
        the GT database step. The variable is not restored afterwards.
    """
    dataset = WaymoDataset(
        dataset_cfg=dataset_cfg, class_names=class_names, root_path=data_path,
        training=False, logger=common_utils.create_logger()
    )
    train_split, val_split = 'train', 'val'
    ordered_splits = (train_split, val_split)

    # Output pickle path for each split, e.g. <tag>_infos_train.pkl.
    info_files = {
        split_name: save_path / ('%s_infos_%s.pkl' % (processed_data_tag, split_name))
        for split_name in ordered_splits
    }
    saved_templates = {
        train_split: '----------------Waymo info train file is saved to %s----------------',
        val_split: '----------------Waymo info val file is saved to %s----------------',
    }

    # "-1" is not a valid device index, so CUDA exposes no GPU to this process
    # while infos are generated.
    os.environ["CUDA_VISIBLE_DEVICES"] = "-1"
    print('---------------Start to generate data infos---------------')

    for split_name in ordered_splits:
        dataset.set_split(split_name)
        split_infos = dataset.get_infos(
            raw_data_path=data_path / raw_data_tag,
            save_path=save_path / processed_data_tag,
            num_workers=workers,
            has_label=True,
            sampled_interval=1,
            update_info_only=update_info_only,
        )
        with open(info_files[split_name], 'wb') as out_file:
            pickle.dump(split_infos, out_file)
        print(saved_templates[split_name] % info_files[split_name])

    if update_info_only:
        return

    print('---------------Start create groundtruth database for data augmentation---------------')
    # Re-expose GPU 0 for the GT-database step.
    os.environ["CUDA_VISIBLE_DEVICES"] = "0"
    dataset.set_split(train_split)
    dataset.create_groundtruth_database(
        info_path=info_files[train_split],
        save_path=save_path,
        split='train',
        sampled_interval=1,
        used_classes=['Vehicle', 'Pedestrian', 'Cyclist'],
        processed_data_tag=processed_data_tag,
    )
    print('---------------Data preparation Done---------------')


682
683
def create_waymo_gt_database(
    dataset_cfg, class_names, data_path, save_path, processed_data_tag='waymo_processed_data',
    workers=min(16, multiprocessing.cpu_count()), use_parallel=False, crop_gt_with_tail=False):
    """Build the ground-truth object database from the train-split info pickle.

    Reads ``save_path / '<processed_data_tag>_infos_train.pkl'`` (the file name
    ``create_waymo_infos`` writes for the train split). Then it calls either
    ``create_groundtruth_database_parallel`` or the serial
    ``create_groundtruth_database`` on a ``WaymoDataset``.

    Args:
        dataset_cfg: dataset config forwarded to ``WaymoDataset``.
        class_names: class names forwarded to ``WaymoDataset``.
        data_path: dataset root (a ``Path``) used as ``root_path``.
        save_path: directory (a ``Path``) holding the info pickle and receiving the
            database output.
        processed_data_tag: prefix of the info pickle / database artifacts.
        workers: worker count; only forwarded to the parallel builder.
        use_parallel: select the parallel builder instead of the serial one.
        crop_gt_with_tail: only forwarded to the parallel builder. When it is set
            together with ``use_parallel=False``, a notice is printed and the flag
            is ignored.
    """
    dataset = WaymoDataset(
        dataset_cfg=dataset_cfg, class_names=class_names, root_path=data_path,
        training=False, logger=common_utils.create_logger()
    )
    train_split = 'train'
    train_filename = save_path / ('%s_infos_%s.pkl' % (processed_data_tag, train_split))

    print('---------------Start create groundtruth database for data augmentation---------------')
    dataset.set_split(train_split)

    # Arguments shared by both builders; keeps the two call sites from drifting apart.
    common_kwargs = dict(
        info_path=train_filename, save_path=save_path, split='train', sampled_interval=1,
        used_classes=['Vehicle', 'Pedestrian', 'Cyclist'], processed_data_tag=processed_data_tag
    )

    if use_parallel:
        dataset.create_groundtruth_database_parallel(
            num_workers=workers, crop_gt_with_tail=crop_gt_with_tail, **common_kwargs
        )
    else:
        if crop_gt_with_tail:
            # The serial builder is never given this flag; say so instead of
            # silently producing an uncropped database.
            print('Warning: crop_gt_with_tail is only applied when use_parallel=True; ignoring it.')
        dataset.create_groundtruth_database(**common_kwargs)
    print('---------------Data preparation Done---------------')


709
if __name__ == '__main__':
    import argparse
    import yaml
    from easydict import EasyDict

    parser = argparse.ArgumentParser(description='arg parser')
    parser.add_argument('--cfg_file', type=str, default=None, help='specify the config of dataset')
    parser.add_argument('--func', type=str, default='create_waymo_infos', help='')
    parser.add_argument('--processed_data_tag', type=str, default='waymo_processed_data_v0_5_0', help='')
    parser.add_argument('--update_info_only', action='store_true', default=False, help='')
    parser.add_argument('--use_parallel', action='store_true', default=False, help='')
    parser.add_argument('--crop_gt_with_tail', action='store_true', default=False, help='')
    args = parser.parse_args()

    ROOT_DIR = (Path(__file__).resolve().parent / '../../../').resolve()

    def load_dataset_cfg(cfg_file, processed_data_tag):
        """Load the dataset YAML config into an EasyDict and set PROCESSED_DATA_TAG."""
        # yaml.safe_load() accepts only a stream (no Loader argument), so use it
        # directly. The `with` block guarantees the file handle is closed.
        with open(cfg_file) as cfg_stream:
            cfg = EasyDict(yaml.safe_load(cfg_stream))
        cfg.PROCESSED_DATA_TAG = processed_data_tag
        return cfg

    if args.func == 'create_waymo_infos':
        dataset_cfg = load_dataset_cfg(args.cfg_file, args.processed_data_tag)
        create_waymo_infos(
            dataset_cfg=dataset_cfg,
            class_names=['Vehicle', 'Pedestrian', 'Cyclist'],
            data_path=ROOT_DIR / 'data' / 'waymo',
            save_path=ROOT_DIR / 'data' / 'waymo',
            raw_data_tag='raw_data',
            processed_data_tag=args.processed_data_tag,
            update_info_only=args.update_info_only
        )
    elif args.func == 'create_waymo_gt_database':
        dataset_cfg = load_dataset_cfg(args.cfg_file, args.processed_data_tag)
        create_waymo_gt_database(
            dataset_cfg=dataset_cfg,
            class_names=['Vehicle', 'Pedestrian', 'Cyclist'],
            data_path=ROOT_DIR / 'data' / 'waymo',
            save_path=ROOT_DIR / 'data' / 'waymo',
            processed_data_tag=args.processed_data_tag,
            use_parallel=args.use_parallel,
            crop_gt_with_tail=args.crop_gt_with_tail
        )
    else:
        raise NotImplementedError('Unsupported --func: %s' % args.func)