# Copyright (c) OpenMMLab. All rights reserved.
r"""Adapted from `Waymo to KITTI converter
    <https://github.com/caizhongang/waymo_kitti_converter>`_.
"""

try:
    from waymo_open_dataset import label_pb2
    from waymo_open_dataset.protos import metrics_pb2
    from waymo_open_dataset.protos.metrics_pb2 import Objects
except ImportError as err:
    # The converter cannot do anything without the official devkit, so fail
    # at import time; chain the original error so the actual missing module
    # is still visible in the traceback.
    raise ImportError(
        'Please run "pip install waymo-open-dataset-tf-2-1-0==1.2.0" '
        'to install the official devkit first.') from err

from typing import List

import mmengine
from mmengine import print_log


class Prediction2Waymo(object):
    """Predictions to Waymo converter.

    This class converts 3D detection results in the original (non-KITTI)
    prediction format into Waymo ``metrics_pb2.Objects`` protos and writes
    all of them, combined, into a single ``.bin`` file.

    Args:
        results (list[dict]): Prediction results, one dict per sample. Each
            dict is read for the keys ``'sample_idx'``, ``'bboxes_3d'``,
            ``'scores_3d'``, ``'labels_3d'``, ``'context_name'`` and
            ``'timestamp'``.
        waymo_results_final_path (str): Path to save combined
            predictions in waymo format (.bin file), like 'a/b/c.bin'.
        classes (dict): Indexed by a predicted label to get its class name
            (a list/tuple of names indexed by label id works the same way).
            Every name must be a key of ``k2w_cls_map``.
        num_workers (int): Number of parallel processes. Currently unused:
            conversion runs sequentially. Defaults to 4.
    """

    def __init__(self,
                 results: List[dict],
                 waymo_results_final_path: str,
                 classes: dict,
                 num_workers: int = 4):
        self.results = results
        self.waymo_results_final_path = waymo_results_final_path
        self.classes = classes
        self.num_workers = num_workers

        # Class name -> Waymo label type. Names missing here raise KeyError
        # in ``parse_objects_from_origin``.
        self.k2w_cls_map = {
            'Car': label_pb2.Label.TYPE_VEHICLE,
            'Pedestrian': label_pb2.Label.TYPE_PEDESTRIAN,
            'Sign': label_pb2.Label.TYPE_SIGN,
            'Cyclist': label_pb2.Label.TYPE_CYCLIST,
        }

    def convert_one(self, res_idx: int):
        """Convert the prediction of a single sample.

        Args:
            res_idx (int): Index into ``self.results``.

        Returns:
            metrics_pb2.Objects: The converted objects of this sample; an
            empty ``Objects`` when the sample has no predicted boxes.
        """
        result = self.results[res_idx]
        if len(result['labels_3d']) == 0:
            # The sample is still represented (with no detections); report
            # it through the shared logger like ``convert`` does.
            print_log(
                f'{result["sample_idx"]} has no predictions.',
                logger='current')
            return metrics_pb2.Objects()
        return self.parse_objects_from_origin(result, result['context_name'],
                                              result['timestamp'])

    def parse_objects_from_origin(self, result: dict, contextname: str,
                                  timestamp: str) -> Objects:
        """Parse objects from the original prediction results.

        Args:
            result (dict): The original prediction results, read for the
                keys ``'bboxes_3d'``, ``'scores_3d'`` and ``'labels_3d'``.
            contextname (str): The ``context_name`` of the sample in waymo.
            timestamp (int | str): The frame timestamp of the sample in
                waymo. It is converted with ``int()`` because the proto field
                ``frame_timestamp_micros`` only accepts integers.

        Returns:
            metrics_pb2.Objects: The parsed objects.
        """
        timestamp_micros = int(timestamp)
        objects = metrics_pb2.Objects()
        for lidar_box, score, label in zip(result['bboxes_3d'],
                                           result['scores_3d'],
                                           result['labels_3d']):
            # Plain floats so numpy / tensor scalars are stored safely.
            # NOTE(review): the layout is assumed to be
            # (x, y, z, length, width, height, heading); extra trailing
            # values are ignored.
            x, y, z, length, width, height, heading = (
                float(value) for value in lidar_box[:7])

            box = label_pb2.Label.Box()
            box.center_x = x
            box.center_y = y
            # Shift z by half the height to get the box center (presumably
            # the input z is the bottom center -- confirm with the producer).
            box.center_z = z + height / 2
            box.length = length
            box.width = width
            box.height = height
            box.heading = heading

            obj = metrics_pb2.Object()
            obj.object.box.CopyFrom(box)
            obj.object.type = self.k2w_cls_map[self.classes[label]]
            obj.score = float(score)
            obj.context_name = contextname
            obj.frame_timestamp_micros = timestamp_micros
            objects.objects.append(obj)

        return objects

    def convert(self):
        """Convert all results and write them to one waymo ``.bin`` file."""
        print_log('Start converting ...', logger='current')

        # TODO: use parallel processes (``self.num_workers``) through
        # mmengine.track_parallel_progress; conversion is sequential for now.
        objects_list = mmengine.track_progress(self.convert_one,
                                               range(len(self)))

        combined = metrics_pb2.Objects()
        for objects in objects_list:
            combined.objects.extend(objects.objects)

        with open(self.waymo_results_final_path, 'wb') as f:
            f.write(combined.SerializeToString())

    def __len__(self):
        """Number of prediction results to convert."""
        return len(self.results)