from collections import defaultdict
from pathlib import Path

import numpy as np
import torch.utils.data as torch_data

from ..utils import common_utils
from .augmentor.data_augmentor import DataAugmentor
from .processor.data_processor import DataProcessor
from .processor.point_feature_encoder import PointFeatureEncoder


class DatasetTemplate(torch_data.Dataset):
    def __init__(self, dataset_cfg=None, class_names=None, training=True, root_path=None, logger=None):
        super().__init__()
        self.dataset_cfg = dataset_cfg
        self.training = training
        self.class_names = class_names
        self.logger = logger
        self.root_path = root_path if root_path is not None else Path(self.dataset_cfg.DATA_PATH)
        if self.dataset_cfg is None or class_names is None:
            return

        self.point_cloud_range = np.array(self.dataset_cfg.POINT_CLOUD_RANGE, dtype=np.float32)
        self.point_feature_encoder = PointFeatureEncoder(
            self.dataset_cfg.POINT_FEATURE_ENCODING,
            point_cloud_range=self.point_cloud_range
        )
        self.data_augmentor = DataAugmentor(
            self.root_path, self.dataset_cfg.DATA_AUGMENTOR, self.class_names, logger=self.logger
        ) if self.training else None
        self.data_processor = DataProcessor(
            self.dataset_cfg.DATA_PROCESSOR, point_cloud_range=self.point_cloud_range,
            training=self.training, num_point_features=self.point_feature_encoder.num_point_features
        )

        self.grid_size = self.data_processor.grid_size
        self.voxel_size = self.data_processor.voxel_size
        self.total_epochs = 0
        self._merge_all_iters_to_one_epoch = False

        self.depth_downsample_factor = getattr(self.data_processor, 'depth_downsample_factor', None)

    @property
    def mode(self):
        return 'train' if self.training else 'test'

    def __getstate__(self):
        # Drop the logger when pickling (e.g. for DataLoader worker processes);
        # logger objects are generally not picklable
        d = dict(self.__dict__)
        del d['logger']
        return d

    def __setstate__(self, d):
        self.__dict__.update(d)

    def generate_prediction_dicts(self, batch_dict, pred_dicts, class_names, output_path=None):
        """
        Args:
            batch_dict:
                frame_id:
            pred_dicts: list of pred_dicts
                pred_boxes: (N, 7 or 9), Tensor
                pred_scores: (N), Tensor
                pred_labels: (N), Tensor
            class_names:
            output_path:

        Returns:
            annos: list of dict (one per frame) with keys:
                name, score, boxes_lidar, pred_labels, frame_id (and metadata, if present)

        """

        def get_template_prediction(num_samples):
            # 9-value boxes (e.g. including velocity components) when trained with speed, else 7
            box_dim = 9 if self.dataset_cfg.get('TRAIN_WITH_SPEED', False) else 7
            ret_dict = {
                'name': np.zeros(num_samples), 'score': np.zeros(num_samples),
                'boxes_lidar': np.zeros([num_samples, box_dim]), 'pred_labels': np.zeros(num_samples)
            }
            return ret_dict

        def generate_single_sample_dict(box_dict):
            pred_scores = box_dict['pred_scores'].cpu().numpy()
            pred_boxes = box_dict['pred_boxes'].cpu().numpy()
            pred_labels = box_dict['pred_labels'].cpu().numpy()
            pred_dict = get_template_prediction(pred_scores.shape[0])
            if pred_scores.shape[0] == 0:
                return pred_dict

            # pred_labels are 1-indexed (0 is background), hence the -1 offset
            pred_dict['name'] = np.array(class_names)[pred_labels - 1]
            pred_dict['score'] = pred_scores
            pred_dict['boxes_lidar'] = pred_boxes
            pred_dict['pred_labels'] = pred_labels

            return pred_dict

        annos = []
        for index, box_dict in enumerate(pred_dicts):
            single_pred_dict = generate_single_sample_dict(box_dict)
            single_pred_dict['frame_id'] = batch_dict['frame_id'][index]
            if 'metadata' in batch_dict:
                single_pred_dict['metadata'] = batch_dict['metadata'][index]
            annos.append(single_pred_dict)

        return annos

    def merge_all_iters_to_one_epoch(self, merge=True, epochs=None):
        # When merged, subclasses typically report len(data) * total_epochs from
        # __len__ so that all training iterations run as one long epoch
        if merge:
            self._merge_all_iters_to_one_epoch = True
            self.total_epochs = epochs
        else:
            self._merge_all_iters_to_one_epoch = False

    def __len__(self):
        raise NotImplementedError

    def __getitem__(self, index):
        """
        To support a custom dataset, implement this function to load the raw data (and labels), then transform them to
        the unified normative coordinate and call the function self.prepare_data() to process the data and send them
        to the model.

        Args:
            index:

        Returns:

        """
        raise NotImplementedError

    def prepare_data(self, data_dict):
        """
        Args:
            data_dict:
                points: optional, (N, 3 + C_in)
                gt_boxes: optional, (N, 7 + C) [x, y, z, dx, dy, dz, heading, ...]
                gt_names: optional, (N), string
                ...

        Returns:
            data_dict:
                frame_id: string
                points: (N, 3 + C_in)
                gt_boxes: optional, (N, 7 + C) [x, y, z, dx, dy, dz, heading, ...]
                gt_names: optional, (N), string
                use_lead_xyz: bool
                voxels: optional (num_voxels, max_points_per_voxel, 3 + C)
                voxel_coords: optional (num_voxels, 3)
                voxel_num_points: optional (num_voxels)
                ...
        """
        if self.training:
            assert 'gt_boxes' in data_dict, 'gt_boxes should be provided for training'
            gt_boxes_mask = np.array([n in self.class_names for n in data_dict['gt_names']], dtype=np.bool_)

            if 'calib' in data_dict:
                calib = data_dict['calib']
            data_dict = self.data_augmentor.forward(
                data_dict={
                    **data_dict,
                    'gt_boxes_mask': gt_boxes_mask
                }
            )
            if 'calib' in data_dict:
                data_dict['calib'] = calib
        if data_dict.get('gt_boxes', None) is not None:
            selected = common_utils.keep_arrays_by_name(data_dict['gt_names'], self.class_names)
            data_dict['gt_boxes'] = data_dict['gt_boxes'][selected]
            data_dict['gt_names'] = data_dict['gt_names'][selected]
            # Class labels are 1-indexed (0 is reserved for background) and are
            # appended as the last column of gt_boxes
            gt_classes = np.array([self.class_names.index(n) + 1 for n in data_dict['gt_names']], dtype=np.int32)
            gt_boxes = np.concatenate((data_dict['gt_boxes'], gt_classes.reshape(-1, 1).astype(np.float32)), axis=1)
            data_dict['gt_boxes'] = gt_boxes

            if data_dict.get('gt_boxes2d', None) is not None:
                data_dict['gt_boxes2d'] = data_dict['gt_boxes2d'][selected]

        if data_dict.get('points', None) is not None:
            data_dict = self.point_feature_encoder.forward(data_dict)

        data_dict = self.data_processor.forward(
            data_dict=data_dict
        )

        if self.training and len(data_dict['gt_boxes']) == 0:
            # No valid GT boxes survived processing; resample another frame so
            # training never receives an empty-target sample
            new_index = np.random.randint(self.__len__())
            return self.__getitem__(new_index)

        data_dict.pop('gt_names', None)

        return data_dict

    @staticmethod
    def collate_batch(batch_list, _unused=False):
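        # Merge per-sample dicts into one batch dict. Variable-length entries are
        # either concatenated with a prepended sample index (points, voxel_coords),
        # concatenated directly (voxels), or zero-padded to the largest size in
        # the batch (gt_boxes, images, ...); everything else is stacked.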
        data_dict = defaultdict(list)
        for cur_sample in batch_list:
            for key, val in cur_sample.items():
                data_dict[key].append(val)
        batch_size = len(batch_list)
        ret = {}
        # batch_size_ratio > 1 when each sample holds a list of entries (e.g.
        # several frames per sample), which multiplies the effective batch size
        batch_size_ratio = 1

        for key, val in data_dict.items():
            try:
                if key in ['voxels', 'voxel_num_points']:
                    if isinstance(val[0], list):
                        batch_size_ratio = len(val[0])
                        val = [i for item in val for i in item]
                    ret[key] = np.concatenate(val, axis=0)
                elif key in ['points', 'voxel_coords']:
                    coors = []
                    if isinstance(val[0], list):
                        val = [i for item in val for i in item]
                    for i, coor in enumerate(val):
                        # Prepend the sample index as an extra first column so all
                        # samples can share one flat coordinate array
                        coor_pad = np.pad(coor, ((0, 0), (1, 0)), mode='constant', constant_values=i)
                        coors.append(coor_pad)
                    ret[key] = np.concatenate(coors, axis=0)
                elif key in ['gt_boxes']:
                    max_gt = max([len(x) for x in val])
                    batch_gt_boxes3d = np.zeros((batch_size, max_gt, val[0].shape[-1]), dtype=np.float32)
                    for k in range(batch_size):
                        batch_gt_boxes3d[k, :len(val[k]), :] = val[k]
                    ret[key] = batch_gt_boxes3d

                elif key in ['roi_boxes']:
                    max_gt = max([x.shape[1] for x in val])
                    batch_gt_boxes3d = np.zeros((batch_size, val[0].shape[0], max_gt, val[0].shape[-1]), dtype=np.float32)
                    for k in range(batch_size):
                        batch_gt_boxes3d[k, :, :val[k].shape[1], :] = val[k]
                    ret[key] = batch_gt_boxes3d

                elif key in ['roi_scores', 'roi_labels']:
                    max_gt = max([x.shape[1] for x in val])
                    batch_gt_boxes3d = np.zeros((batch_size, val[0].shape[0], max_gt), dtype=np.float32)
                    for k in range(batch_size):
                        batch_gt_boxes3d[k, :, :val[k].shape[1]] = val[k]
                    ret[key] = batch_gt_boxes3d

                elif key in ['gt_boxes2d']:
                    max_boxes = max([len(x) for x in val])
                    batch_boxes2d = np.zeros((batch_size, max_boxes, val[0].shape[-1]), dtype=np.float32)
                    for k in range(batch_size):
                        if val[k].size > 0:
                            batch_boxes2d[k, :len(val[k]), :] = val[k]
                    ret[key] = batch_boxes2d
                elif key in ["images", "depth_maps"]:
                    # Get largest image size (H, W)
                    max_h = 0
                    max_w = 0
                    for image in val:
                        max_h = max(max_h, image.shape[0])
                        max_w = max(max_w, image.shape[1])

                    # Pad every image to the common (max) size
                    images = []
                    for image in val:
                        pad_h = common_utils.get_pad_params(desired_size=max_h, cur_size=image.shape[0])
                        pad_w = common_utils.get_pad_params(desired_size=max_w, cur_size=image.shape[1])
                        pad_value = 0

                        if key == "images":
                            pad_width = (pad_h, pad_w, (0, 0))
                        elif key == "depth_maps":
                            pad_width = (pad_h, pad_w)

                        image_pad = np.pad(image,
                                           pad_width=pad_width,
                                           mode='constant',
                                           constant_values=pad_value)

                        images.append(image_pad)
                    ret[key] = np.stack(images, axis=0)
                elif key in ['calib']:
                    ret[key] = val
                elif key in ["points_2d"]:
                    max_len = max([len(_val) for _val in val])
                    pad_value = 0
                    points = []
                    for _points in val:
                        pad_width = ((0, max_len-len(_points)), (0,0))
                        points_pad = np.pad(_points,
                                pad_width=pad_width,
                                mode='constant',
                                constant_values=pad_value)
                        points.append(points_pad)
                    ret[key] = np.stack(points, axis=0)
                else:
                    ret[key] = np.stack(val, axis=0)
            except Exception as e:
                print('Error in collate_batch: key=%s' % key)
                raise TypeError('Error in collate_batch: key=%s' % key) from e

        ret['batch_size'] = batch_size * batch_size_ratio
        return ret
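
# Example (hypothetical usage): a DatasetTemplate subclass is normally wrapped in
# a torch DataLoader with collate_batch as the collate function, e.g.:
#
#     import torch
#     loader = torch.utils.data.DataLoader(
#         dataset, batch_size=4, shuffle=True, num_workers=4,
#         collate_fn=dataset.collate_batch,
#     )
#     batch_dict = next(iter(loader))  # includes 'batch_size' and collated keys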