fcos3d_bbox_coder.py 5.03 KB
Newer Older
1
# Copyright (c) OpenMMLab. All rights reserved.
Tai-Wang's avatar
Tai-Wang committed
2
3
import numpy as np
import torch
4
from mmdet.models.task_modules import BaseBBoxCoder
Tai-Wang's avatar
Tai-Wang committed
5

6
from mmdet3d.registry import TASK_UTILS
zhangshilong's avatar
zhangshilong committed
7
from mmdet3d.structures.bbox_3d import limit_period
Tai-Wang's avatar
Tai-Wang committed
8
9


10
@TASK_UTILS.register_module()
class FCOS3DBBoxCoder(BaseBBoxCoder):
    """Bounding box coder for FCOS3D.

    Only decoding is implemented; :meth:`encode` is currently a placeholder.

    Args:
        base_depths (tuple[tuple[float]], optional): Depth references for
            decoding box depth. ``None`` decodes depth with ``exp``; a single
            ``(mean, std)`` pair is shared by all classes; otherwise one
            ``(mean, std)`` pair per category is expected. Defaults to None.
        base_dims (tuple[tuple[float]], optional): Per-category dimension
            references (three values each) that scale the decoded box size.
            Defaults to None.
        code_size (int): The dimension of boxes to be encoded. Defaults to 7.
        norm_on_bbox (bool): Whether to apply normalization on the bounding
            box 2D attributes. Only ``True`` is currently supported by
            :meth:`decode`. Defaults to True.
    """

    def __init__(self,
                 base_depths=None,
                 base_dims=None,
                 code_size=7,
                 norm_on_bbox=True):
        super().__init__()
        self.base_depths = base_depths
        self.base_dims = base_dims
        self.bbox_code_size = code_size
        self.norm_on_bbox = norm_on_bbox

    def encode(self, gt_bboxes_3d, gt_labels_3d, gt_bboxes, gt_labels):
        """Placeholder for target encoding; not implemented, returns None."""
        # TODO: refactor the encoder in the FCOS3D and PGD head
        pass

    def decode(self, bbox, scale, stride, training, cls_score=None):
        """Decode regressed results into 3D predictions.

        ``bbox`` is modified in place and also returned. Note that offsets
        are not transformed to the projected 3D centers.

        Args:
            bbox (torch.Tensor): Raw bounding box predictions in shape
                [N, C, H, W]. Channels 0-1 hold offsets, channel 2 holds
                depth and channels 3-5 hold sizes.
            scale (tuple[`Scale`]): Learnable scale parameters; the first
                three are applied to offsets, depth and size respectively.
            stride (int): Stride for a specific feature level.
            training (bool): Whether the decoding is in the training
                procedure. Offsets are multiplied by ``stride`` only when
                this is False.
            cls_score (torch.Tensor, optional): Classification score map
                (assumed [N, num_classes, H, W]) for deciding which base
                depth or dim is used. Required when multi-class
                ``base_depths`` or any ``base_dims`` are configured.
                Defaults to None.

        Returns:
            torch.Tensor: Decoded boxes.
        """
        # Scale the bbox of different levels; only the offset, depth and
        # size predictions are scaled.
        scale_offset, scale_depth, scale_size = scale[0:3]

        clone_bbox = bbox.clone()
        bbox[:, :2] = scale_offset(clone_bbox[:, :2]).float()
        bbox[:, 2] = scale_depth(clone_bbox[:, 2]).float()
        bbox[:, 3:6] = scale_size(clone_bbox[:, 3:6]).float()

        # Per-location argmax category; computed at most once and shared by
        # the depth-prior and size-prior lookups below.
        indices = None

        if self.base_depths is None:
            bbox[:, 2] = bbox[:, 2].exp()
        elif len(self.base_depths) == 1:  # only single prior
            mean = self.base_depths[0][0]
            std = self.base_depths[0][1]
            # Clone only the slice being read before it is overwritten.
            bbox[:, 2] = mean + bbox[:, 2].clone() * std
        else:  # multi-class priors
            assert cls_score is not None, \
                'cls_score is required to select multi-class depth priors.'
            assert len(self.base_depths) == cls_score.shape[1], \
                'The number of multi-class depth priors should be equal to ' \
                'the number of categories.'
            indices = cls_score.max(dim=1)[1]
            # Look up (mean, std) per location: [N, H, W, 2] -> [N, 2, H, W]
            depth_priors = cls_score.new_tensor(
                self.base_depths)[indices, :].permute(0, 3, 1, 2)
            mean = depth_priors[:, 0]
            std = depth_priors[:, 1]
            bbox[:, 2] = mean + bbox[:, 2].clone() * std

        bbox[:, 3:6] = bbox[:, 3:6].exp()
        if self.base_dims is not None:
            assert cls_score is not None, \
                'cls_score is required to select per-class size priors.'
            assert len(self.base_dims) == cls_score.shape[1], \
                'The number of anchor sizes should be equal to the number ' \
                'of categories.'
            if indices is None:
                indices = cls_score.max(dim=1)[1]
            # Look up per-location size priors: [N, H, W, 3] -> [N, 3, H, W]
            size_priors = cls_score.new_tensor(
                self.base_dims)[indices, :].permute(0, 3, 1, 2)
            bbox[:, 3:6] = size_priors * bbox[:, 3:6].clone()

        assert self.norm_on_bbox is True, 'Setting norm_on_bbox to False '\
            'has not been thoroughly tested for FCOS3D.'
        # Offsets are rescaled by the feature stride only at test time.
        if self.norm_on_bbox and not training:
            bbox[:, :2] *= stride

        return bbox

    @staticmethod
    def decode_yaw(bbox, centers2d, dir_cls, dir_offset, cam2img):
        """Decode yaw angle and change it from local to global.

        ``bbox`` is modified in place and also returned.

        Args:
            bbox (torch.Tensor): Bounding box predictions in shape
                [N, C] with yaws (column 6) to be decoded.
            centers2d (torch.Tensor): Projected 3D-center on the image planes
                corresponding to the box predictions.
            dir_cls (torch.Tensor): Predicted direction classes.
            dir_offset (float): Direction offset before dividing all the
                directions into several classes.
            cam2img (torch.Tensor): Camera intrinsic matrix in shape [4, 4].

        Returns:
            torch.Tensor: Bounding boxes with decoded yaws.
        """
        if bbox.shape[0] > 0:
            # Fold the regressed yaw into a period of length pi, then use the
            # direction class to pick which half-turn the box faces.
            dir_rot = limit_period(bbox[..., 6] - dir_offset, 0, np.pi)
            bbox[..., 6] = \
                dir_rot + dir_offset + np.pi * dir_cls.to(bbox.dtype)

        # Add the angle of the viewing ray through the projected center,
        # computed from the principal point (cam2img[0, 2]) and the focal
        # length (cam2img[0, 0]), to turn the local yaw into a global one.
        bbox[:, 6] = torch.atan2(centers2d[:, 0] - cam2img[0, 2],
                                 cam2img[0, 0]) + bbox[:, 6]

        return bbox