# Copyright (c) OpenMMLab. All rights reserved.
import numpy as np
import torch
from mmdet.models.task_modules import BaseBBoxCoder
from torch.nn import functional as F

from mmdet3d.registry import TASK_UTILS


@TASK_UTILS.register_module()
class MonoFlexCoder(BaseBBoxCoder):
    """Bbox Coder for MonoFlex.

    Args:
        depth_mode (str): The mode for depth calculation.
            Available options are "linear", "inv_sigmoid", and "exp".
        base_depth (tuple[float]): References for decoding box depth.
        depth_range (list): Depth range of predicted depth.
        combine_depth (bool): Whether to use combined depth (direct depth
            and depth from keypoints) or use direct depth only.
        uncertainty_range (list): Uncertainty range of predicted depth.
        base_dims (tuple[tuple[float]]): Mean and std of bbox dimensions
            [l, h, w] for each category, used to decode the predicted
            dimension offsets.
        dims_mode (str): The mode for dimension calculation.
            Available options are "linear" and "exp".
        multibin (bool): Whether to use multibin representation.
        num_dir_bins (int): Number of bins used to encode the
            direction angle.
        bin_centers (list[float]): Local yaw centers while using multibin
            representations.
        bin_margin (float): Margin of multibin representations.
        code_size (int): The dimension of boxes to be encoded.
        eps (float, optional): A value added to the denominator for numerical
            stability. Default 1e-3.
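
    Example:
        A minimal sketch with illustrative (not tuned) settings:

        >>> import numpy as np
        >>> coder = MonoFlexCoder(
        ...     depth_mode='inv_sigmoid',
        ...     base_depth=(28.0, 13.0),
        ...     depth_range=[0.1, 100],
        ...     combine_depth=True,
        ...     uncertainty_range=[-10, 10],
        ...     base_dims=((3.9, 1.5, 1.6, 0.2, 0.1, 0.1), ),
        ...     dims_mode='linear',
        ...     multibin=True,
        ...     num_dir_bins=4,
        ...     bin_centers=[0, np.pi / 2, np.pi, -np.pi / 2],
        ...     bin_margin=np.pi / 6,
        ...     code_size=7)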
    """

    def __init__(self,
                 depth_mode,
                 base_depth,
                 depth_range,
                 combine_depth,
                 uncertainty_range,
                 base_dims,
                 dims_mode,
                 multibin,
                 num_dir_bins,
                 bin_centers,
                 bin_margin,
                 code_size,
                 eps=1e-3):
        super(MonoFlexCoder, self).__init__()

        # depth related
        self.depth_mode = depth_mode
        self.base_depth = base_depth
        self.depth_range = depth_range
        self.combine_depth = combine_depth
        self.uncertainty_range = uncertainty_range

        # dimensions related
        self.base_dims = base_dims
        self.dims_mode = dims_mode

        # orientation related
        self.multibin = multibin
        self.num_dir_bins = num_dir_bins
        self.bin_centers = bin_centers
        self.bin_margin = bin_margin

        # output related
        self.bbox_code_size = code_size
        self.eps = eps

    def encode(self, gt_bboxes_3d):
        """Encode ground truth to prediction targets.

        Args:
            gt_bboxes_3d (`BaseInstance3DBoxes`): Ground truth 3D bboxes.
                shape: (N, 7).

        Returns:
            torch.Tensor: Targets of orientations.
        """
        local_yaw = gt_bboxes_3d.local_yaw
        # encode local yaw (-pi ~ pi) to multibin format
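        # target layout: channels [0, num_dir_bins) are bin confidences
        # (set to 1 when the yaw falls inside a bin's margin-extended
        # range), channels [num_dir_bins, 2 * num_dir_bins) are the
        # residual offsets to the corresponding bin centers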
        encode_local_yaw = local_yaw.new_zeros(
            [local_yaw.shape[0], self.num_dir_bins * 2])
        bin_size = 2 * np.pi / self.num_dir_bins
        margin_size = bin_size * self.bin_margin

        bin_centers = local_yaw.new_tensor(self.bin_centers)
        range_size = bin_size / 2 + margin_size

        offsets = local_yaw.unsqueeze(1) - bin_centers.unsqueeze(0)
        offsets[offsets > np.pi] = offsets[offsets > np.pi] - 2 * np.pi
        offsets[offsets < -np.pi] = offsets[offsets < -np.pi] + 2 * np.pi

        for i in range(self.num_dir_bins):
            offset = offsets[:, i]
            inds = abs(offset) < range_size
            encode_local_yaw[inds, i] = 1
            encode_local_yaw[inds, i + self.num_dir_bins] = offset[inds]

        orientation_target = encode_local_yaw

        return orientation_target

    def decode(self, bbox, base_centers2d, labels, downsample_ratio, cam2imgs):
        """Decode bounding box regression into 3D predictions.

        Args:
            bbox (Tensor): Raw bounding box predictions at each
                predicted center2d point.
                shape: (N, C)
            base_centers2d (torch.Tensor): Base centers2d for 3D bboxes.
                shape: (N, 2).
            labels (Tensor): Predicted class label for each predicted
                center2d point.
                shape: (N, )
            downsample_ratio (int): The stride of feature map.
            cam2imgs (Tensor): Batch images' camera intrinsic matrix.
                shape: kitti (N, 4, 4)  nuscenes (N, 3, 3)

        Return:
            dict: The 3D prediction dict decoded from regression map.
            The dict contains the following components:
                - bboxes2d (torch.Tensor): Decoded [x1, y1, x2, y2] format
                    2D bboxes.
                - dimensions (torch.Tensor): Decoded dimensions for each
                    object.
                - offsets2d (torch.Tensor): Offsets between base centers2d
                    and real centers2d.
                - direct_depth (torch.Tensor): Decoded directly regressed
                    depth.
                - keypoints2d (torch.Tensor): Keypoints of each projected
                    3D box on image.
                - keypoints_depth (torch.Tensor): Decoded depth from keypoints.
                - combined_depth (torch.Tensor): Combined depth using direct
                    depth and keypoints depth with depth uncertainty.
                - orientations (torch.Tensor): Multibin format orientations
                    (local yaw) for each object.
        """

        # 4 dimensions for FCOS style regression
        pred_bboxes2d = bbox[:, 0:4]

        # change FCOS style to [x1, y1, x2, y2] format for IOU Loss
        pred_bboxes2d = self.decode_bboxes2d(pred_bboxes2d, base_centers2d)

        # 2 dimensions for projected centers2d offsets
        pred_offsets2d = bbox[:, 4:6]

        # 3 dimensions for 3D bbox dimensions offsets
        pred_dimensions_offsets3d = bbox[:, 29:32]

        # the first 8 dimensions are for orientation bin classification
        # and the second 8 dimensions are for orientation offsets.
        pred_orientations = torch.cat((bbox[:, 32:40], bbox[:, 40:48]), dim=1)

        # 3 dimensions for the uncertainties of the solved depths from
        # groups of keypoints
        pred_keypoints_depth_uncertainty = bbox[:, 26:29]

        # 1 dimension for the uncertainty of directly regressed depth
        pred_direct_depth_uncertainty = bbox[:, 49:50].squeeze(-1)

        # 2D offsets of the 10 keypoints (8 corners + top/bottom centers)
        pred_keypoints2d = bbox[:, 6:26].reshape(-1, 10, 2)

        # 1 dimension for depth offsets
        pred_direct_depth_offsets = bbox[:, 48:49].squeeze(-1)

        # decode the pred residual dimensions to real dimensions
        pred_dimensions = self.decode_dims(labels, pred_dimensions_offsets3d)
        pred_direct_depth = self.decode_direct_depth(pred_direct_depth_offsets)
        pred_keypoints_depth = self.keypoints2depth(pred_keypoints2d,
                                                    pred_dimensions, cam2imgs,
                                                    downsample_ratio)

        pred_direct_depth_uncertainty = torch.clamp(
            pred_direct_depth_uncertainty, self.uncertainty_range[0],
            self.uncertainty_range[1])
        pred_keypoints_depth_uncertainty = torch.clamp(
            pred_keypoints_depth_uncertainty, self.uncertainty_range[0],
            self.uncertainty_range[1])

        if self.combine_depth:
            pred_depth_uncertainty = torch.cat(
                (pred_direct_depth_uncertainty.unsqueeze(-1),
                 pred_keypoints_depth_uncertainty),
                dim=1).exp()
            pred_depth = torch.cat(
                (pred_direct_depth.unsqueeze(-1), pred_keypoints_depth), dim=1)
            pred_combined_depth = \
                self.combine_depths(pred_depth, pred_depth_uncertainty)
        else:
            pred_combined_depth = None

        preds = dict(
            bboxes2d=pred_bboxes2d,
            dimensions=pred_dimensions,
            offsets2d=pred_offsets2d,
            keypoints2d=pred_keypoints2d,
            orientations=pred_orientations,
            direct_depth=pred_direct_depth,
            keypoints_depth=pred_keypoints_depth,
            combined_depth=pred_combined_depth,
            direct_depth_uncertainty=pred_direct_depth_uncertainty,
            keypoints_depth_uncertainty=pred_keypoints_depth_uncertainty,
        )

        return preds

    def decode_direct_depth(self, depth_offsets):
        """Transform depth offset to directly regressed depth.

        Args:
            depth_offsets (torch.Tensor): Predicted depth offsets.
                shape: (N, )

        Return:
            torch.Tensor: Directly regressed depth.
                shape: (N, )
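
        Example:
            A small numeric check, assuming a coder built with
            ``depth_mode='inv_sigmoid'`` and ``depth_range=[0.1, 100]``
            (see the class docstring example):

            >>> d = coder.decode_direct_depth(torch.tensor([0.0]))
            >>> # sigmoid(0) = 0.5, so d = 1 / 0.5 - 1 = 1.0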
        """
        if self.depth_mode == 'exp':
            direct_depth = depth_offsets.exp()
        elif self.depth_mode == 'linear':
            base_depth = depth_offsets.new_tensor(self.base_depth)
            direct_depth = depth_offsets * base_depth[1] + base_depth[0]
        elif self.depth_mode == 'inv_sigmoid':
            direct_depth = 1 / torch.sigmoid(depth_offsets) - 1
        else:
            raise ValueError

        if self.depth_range is not None:
            direct_depth = torch.clamp(
                direct_depth, min=self.depth_range[0], max=self.depth_range[1])

        return direct_depth

    def decode_location(self,
                        base_centers2d,
                        offsets2d,
                        depths,
                        cam2imgs,
                        downsample_ratio,
                        pad_mode='default'):
        """Retrieve object location.

        Args:
            base_centers2d (torch.Tensor): predicted base centers2d.
                shape: (N, 2)
            offsets2d (torch.Tensor): The offsets between real centers2d
                and base centers2d.
                shape: (N , 2)
            depths (torch.Tensor): Depths of objects.
                shape: (N, )
            cam2imgs (torch.Tensor): Batch images' camera intrinsic matrix.
                shape: kitti (N, 4, 4)  nuscenes (N, 3, 3)
            downsample_ratio (int): The stride of feature map.
            pad_mode (str, optional): Padding mode used in
                training data augmentation.

        Return:
            torch.Tensor: Centers of 3D boxes.
                shape: (N, 3)
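
        Example:
            A quick sanity check with an identity 4x4 ``cam2imgs``
            (i.e. f_u = f_v = 1, c_u = c_v = 0) and ``downsample_ratio=1``:
            a center at pixel (2, 3) with depth 5 back-projects to the
            camera-frame location (10, 15, 5).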
        """
        N = cam2imgs.shape[0]
        # (N, 4, 4)
        cam2imgs_inv = cam2imgs.inverse()
        if pad_mode == 'default':
            centers2d_img = (base_centers2d + offsets2d) * downsample_ratio
        else:
            raise NotImplementedError
        # scale the pixel coordinates by depth before back-projection
        # (N, 3)
        centers2d_img = \
            torch.cat((centers2d_img * depths.unsqueeze(-1),
                       depths.unsqueeze(-1)), dim=1)
        # (N, 4, 1)
        centers2d_extend = \
            torch.cat((centers2d_img, centers2d_img.new_ones(N, 1)),
                      dim=1).unsqueeze(-1)
        locations = torch.matmul(cam2imgs_inv, centers2d_extend).squeeze(-1)

        return locations[:, :3]

    def keypoints2depth(self,
                        keypoints2d,
                        dimensions,
                        cam2imgs,
                        downsample_ratio=4,
                        group0_index=[(7, 3), (0, 4)],
                        group1_index=[(2, 6), (1, 5)]):
        """Decode depth form three groups of keypoints and geometry projection
        model. 2D keypoints inlucding 8 coreners and top/bottom centers will be
        divided into three groups which will be used to calculate three depths
        of object.

        .. code-block:: none

                Group center keypoints:

                             + --------------- +
                            /|   top center   /|
                           / |      .        / |
                          /  |      |       /  |
                         + ---------|----- +   +
                         |  /       |      |  /
                         | /        .      | /
                         |/ bottom center  |/
                         + --------------- +

                Group 0 keypoints:

                             0
                             + -------------- +
                            /|               /|
                           / |              / |
                          /  |            5/  |
                         + -------------- +   +
                         |  /3            |  /
                         | /              | /
                         |/               |/
                         + -------------- + 6

                Group 1 keypoints:

                                               4
                             + -------------- +
                            /|               /|
                           / |              / |
                          /  |             /  |
                       1 + -------------- +   + 7
                         |  /             |  /
                         | /              | /
                         |/               |/
                       2 + -------------- +


        Args:
            keypoints2d (torch.Tensor): Keypoints of objects.
                8 vertices + top/bottom center.
                shape: (N, 10, 2)
            dimensions (torch.Tensor): Dimensions of objects.
                shape: (N, 3)
            cam2imgs (torch.Tensor): Batch images' camera intrinsic matrix.
                shape: kitti (N, 4, 4)  nuscenes (N, 3, 3)
            downsample_ratio (int, optional): The stride of feature map.
                Defaults to 4.
            group0_index (list[tuple[int]], optional): Pairs of keypoint
                indices in group 0 used to compute the depth.
                Defaults to [(7, 3), (0, 4)].
            group1_index (list[tuple[int]], optional): Pairs of keypoint
                indices in group 1 used to compute the depth.
                Defaults to [(2, 6), (1, 5)].

        Return:
            torch.Tensor: Depths computed from the three keypoint groups
                (top/bottom centers, group0, group1).
                shape: (N, 3)
        """

        pred_height_3d = dimensions[:, 1].clone()
        f_u = cam2imgs[:, 0, 0]
        center_height = keypoints2d[:, -2, 1] - keypoints2d[:, -1, 1]
        corner_group0_height = keypoints2d[:, group0_index[0], 1] \
            - keypoints2d[:, group0_index[1], 1]
        corner_group1_height = keypoints2d[:, group1_index[0], 1] \
            - keypoints2d[:, group1_index[1], 1]
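        # pinhole model: a 3D height H at depth z projects to
        # h = f_u * H / z pixels, hence z = f_u * H / (h * downsample_ratio)
        # when h is measured on the downsampled feature map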
        center_depth = f_u * pred_height_3d / (
            F.relu(center_height) * downsample_ratio + self.eps)
        corner_group0_depth = (f_u * pred_height_3d).unsqueeze(-1) / (
            F.relu(corner_group0_height) * downsample_ratio + self.eps)
        corner_group1_depth = (f_u * pred_height_3d).unsqueeze(-1) / (
            F.relu(corner_group1_height) * downsample_ratio + self.eps)

        corner_group0_depth = corner_group0_depth.mean(dim=1)
        corner_group1_depth = corner_group1_depth.mean(dim=1)

        keypoints_depth = torch.stack(
            (center_depth, corner_group0_depth, corner_group1_depth), dim=1)
        keypoints_depth = torch.clamp(
            keypoints_depth, min=self.depth_range[0], max=self.depth_range[1])

        return keypoints_depth

    def decode_dims(self, labels, dims_offset):
        """Retrieve object dimensions.

        Args:
            labels (torch.Tensor): Category id of each point.
                shape: (N, )
            dims_offset (torch.Tensor): Dimension offsets.
                shape: (N, 3)

        Returns:
            torch.Tensor: Shape (N, 3)
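
        Example:
            A minimal sketch assuming the coder from the class docstring
            example (``dims_mode='linear'``,
            ``base_dims=((3.9, 1.5, 1.6, 0.2, 0.1, 0.1), )``):

            >>> labels = torch.zeros(1)
            >>> dims_offset = torch.tensor([[0.5, -1.0, 0.0]])
            >>> dims = coder.decode_dims(labels, dims_offset)
            >>> # dims = offset * std + mean = [[4.0, 1.4, 1.6]]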
        """

        if self.dims_mode == 'exp':
            dimensions = dims_offset.exp()
        elif self.dims_mode == 'linear':
            labels = labels.long()
            base_dims = dims_offset.new_tensor(self.base_dims)
            dims_mean = base_dims[:, :3]
            dims_std = base_dims[:, 3:6]
            cls_dimension_mean = dims_mean[labels, :]
            cls_dimension_std = dims_std[labels, :]
            # residual decoding: scale by the class std, shift by the mean
            dimensions = dims_offset * cls_dimension_std + cls_dimension_mean
        else:
            raise ValueError

        return dimensions

    def decode_orientation(self, ori_vector, locations):
        """Retrieve object orientation.

        Args:
            ori_vector (torch.Tensor): Local orientation vector
                in [axis_cls, head_cls, sin, cos] format.
                shape: (N, num_dir_bins * 4)
            locations (torch.Tensor): Object location.
                shape: (N, 3)

        Returns:
            tuple[torch.Tensor]: yaws and local yaws of 3d bboxes.
        """
        if self.multibin:
            pred_bin_cls = ori_vector[:, :self.num_dir_bins * 2].view(
                -1, self.num_dir_bins, 2)
            pred_bin_cls = pred_bin_cls.softmax(dim=2)[..., 1]
            orientations = ori_vector.new_zeros(ori_vector.shape[0])
            for i in range(self.num_dir_bins):
                mask_i = (pred_bin_cls.argmax(dim=1) == i)
                start_bin = self.num_dir_bins * 2 + i * 2
                end_bin = start_bin + 2
                pred_bin_offset = ori_vector[mask_i, start_bin:end_bin]
                orientations[mask_i] = pred_bin_offset[:, 0].atan2(
                    pred_bin_offset[:, 1]) + self.bin_centers[i]
        else:
            axis_cls = ori_vector[:, :2].softmax(dim=1)
            axis_cls = axis_cls[:, 0] < axis_cls[:, 1]
            head_cls = ori_vector[:, 2:4].softmax(dim=1)
            head_cls = head_cls[:, 0] < head_cls[:, 1]
            # select the bin center from the axis/head classification
            bin_centers = ori_vector.new_tensor(self.bin_centers)
            orientations = bin_centers[axis_cls.long() + head_cls.long() * 2]
            sin_cos_offset = F.normalize(ori_vector[:, 4:])
            orientations += sin_cos_offset[:, 0].atan2(sin_cos_offset[:, 1])

        locations = locations.view(-1, 3)
        rays = locations[:, 0].atan2(locations[:, 2])
        local_yaws = orientations
        yaws = local_yaws + rays

        larger_idx = (yaws > np.pi).nonzero(as_tuple=False)
        small_idx = (yaws < -np.pi).nonzero(as_tuple=False)
        if len(larger_idx) != 0:
            yaws[larger_idx] -= 2 * np.pi
        if len(small_idx) != 0:
            yaws[small_idx] += 2 * np.pi

        larger_idx = (local_yaws > np.pi).nonzero(as_tuple=False)
        small_idx = (local_yaws < -np.pi).nonzero(as_tuple=False)
        if len(larger_idx) != 0:
            local_yaws[larger_idx] -= 2 * np.pi
        if len(small_idx) != 0:
            local_yaws[small_idx] += 2 * np.pi

        return yaws, local_yaws

    def decode_bboxes2d(self, reg_bboxes2d, base_centers2d):
        """Retrieve [x1, y1, x2, y2] format 2D bboxes.

        Args:
            reg_bboxes2d (torch.Tensor): Predicted FCOS style
                2D bboxes.
                shape: (N, 4)
            base_centers2d (torch.Tensor): predicted base centers2d.
                shape: (N, 2)

        Returns:
            torch.Tensor: [x1, y1, x2, y2] format 2D bboxes.
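
        Example:
            A quick numeric sketch: for a base center (50, 60) and
            FCOS-style distances (l, t, r, b) = (10, 5, 20, 15), the
            decoded box is [40, 55, 70, 75].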
        """
        centers_x = base_centers2d[:, 0]
        centers_y = base_centers2d[:, 1]

        xs_min = centers_x - reg_bboxes2d[..., 0]
        ys_min = centers_y - reg_bboxes2d[..., 1]
        xs_max = centers_x + reg_bboxes2d[..., 2]
        ys_max = centers_y + reg_bboxes2d[..., 3]

        bboxes2d = torch.stack([xs_min, ys_min, xs_max, ys_max], dim=-1)

        return bboxes2d

    def combine_depths(self, depth, depth_uncertainty):
        """Combine all the prediced depths with depth uncertainty.

        Args:
            depth (torch.Tensor): Predicted depths of each object.
                shape: (N, 4)
            depth_uncertainty (torch.Tensor): Depth uncertainty for
                each depth of each object.
                shape: (N, 4)

        Returns:
            torch.Tensor: Combined depth.
                shape: (N, )
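
        Example:
            A small numeric sketch: for depths ``[10, 20, 30, 40]`` with
            uncertainties ``[1, 1, 2, 2]``, the inverse-uncertainty weights
            normalize to ``(1/3, 1/3, 1/6, 1/6)``, giving a combined depth
            of ``10/3 + 20/3 + 30/6 + 40/6 ≈ 21.67``.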
        """
        uncertainty_weights = 1 / depth_uncertainty
        uncertainty_weights = \
            uncertainty_weights / \
            uncertainty_weights.sum(dim=1, keepdim=True)
        combined_depth = torch.sum(depth * uncertainty_weights, dim=1)

        return combined_depth