gen_keypoints.py 3.25 KB
Newer Older
raojy's avatar
raojy committed
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
# Copyright (c) OpenMMLab. All rights reserved.
from typing import List, Tuple

import torch
from torch import Tensor

from mmdet3d.structures import CameraInstance3DBoxes, points_cam2img


def get_keypoints(
        gt_bboxes_3d_list: List[CameraInstance3DBoxes],
        centers2d_list: List[Tensor],
        img_metas: List[dict],
        use_local_coords: bool = True) -> Tuple[List[Tensor], List[Tensor]]:
    """Project the 3D keypoints of every box to 2D and flag visibility.

    For each image, ten keypoints are built per box: the eight entries of
    ``gt_bboxes_3d.corners`` followed by the mean of corners (0, 1, 4, 5)
    and the mean of corners (2, 3, 6, 7) (the "top" and "bottom" centers
    in this module's naming). The keypoints are projected with
    ``points_cam2img``. A keypoint counts as visible when its projection
    lies within ``[0, w - 1] x [0, h - 1]`` and the last coordinate of
    its 3D position is positive.

    Args:
        gt_bboxes_3d_list (List[:obj:`CameraInstance3DBoxes`]): Ground truth
            bboxes of each image.
        centers2d_list (List[Tensor]): Projected 3D centers onto 2D image,
            shape (num_gt, 2). Must have the same length as
            ``gt_bboxes_3d_list``.
        img_metas (List[dict]): Meta information of each image. The keys
            ``'img_shape'`` (its first two entries are read as height and
            width) and ``'cam2img'`` are used.
        use_local_coords (bool): Whether to express the projected keypoints
            relative to the matching entry of ``centers2d_list`` instead of
            in absolute image coordinates. Defaults to True.

    Returns:
        Tuple[List[Tensor], List[Tensor]]: Two lists with one tensor per
        image. The first holds the projected keypoints with the visibility
        flag (as float) appended as the last channel. The second holds a
        boolean mask with three columns per box telling whether all
        keypoints of a group are visible, in the order: the two centers
        (8, 9), corners (0, 3, 5, 6), corners (1, 2, 4, 7).
    """
    assert len(gt_bboxes_3d_list) == len(centers2d_list)

    # Index groups into the per-box keypoint axis; constant for all images.
    top_corner_ids = [0, 1, 4, 5]
    bottom_corner_ids = [2, 3, 6, 7]
    # center, diag-02, diag-13
    depth_groups = ([8, 9], [0, 3, 5, 6], [1, 2, 4, 7])

    projected_per_image = []
    depth_mask_per_image = []

    pairs = zip(gt_bboxes_3d_list, centers2d_list)
    for img_idx, (boxes, centers) in enumerate(pairs):
        meta = img_metas[img_idx]
        img_shape = meta['img_shape']
        cam2img = meta['cam2img']
        height, width = img_shape[:2]

        # (N, 8, 3)
        corners = boxes.corners
        top_center = corners[:, top_corner_ids, :].mean(dim=1)
        bottom_center = corners[:, bottom_corner_ids, :].mean(dim=1)
        # (N, 10, 3): corners first, then the two centers
        points3d = torch.cat(
            (corners, torch.stack((top_center, bottom_center), dim=1)),
            dim=1)
        # (N, 10, 2)
        points2d = points_cam2img(points3d, cam2img)

        # A keypoint is visible only if it projects inside the image
        # and sits in front of the camera.
        xs = points2d[..., 0]
        ys = points2d[..., 1]
        inside_x = (xs >= 0) & (xs <= width - 1)
        inside_y = (ys >= 0) & (ys <= height - 1)
        in_front = points3d[..., -1] > 0
        # (N, 10)
        visible = inside_x & inside_y & in_front

        # (N, 3): a group is valid only when all of its keypoints are visible
        depth_valid = torch.stack(
            [visible[:, ids].all(dim=1) for ids in depth_groups], dim=1)

        if use_local_coords:
            # Offsets relative to each box's projected center.
            coords = points2d - centers.unsqueeze(1)
        else:
            coords = points2d
        flags = visible.float().unsqueeze(-1)

        projected_per_image.append(torch.cat((coords, flags), dim=2))
        depth_mask_per_image.append(depth_valid)

    return (projected_per_image, depth_mask_per_image)