handle_objs.py
# Copyright (c) OpenMMLab. All rights reserved.
import torch


def filter_outside_objs(gt_bboxes_list, gt_labels_list, gt_bboxes_3d_list,
                        gt_labels_3d_list, centers2d_list, img_metas):
    """Function to filter the objects label outside the image.

    Args:
        gt_bboxes_list (list[Tensor]): Ground truth bboxes of each image,
            each has shape (num_gt, 4).
        gt_labels_list (list[Tensor]): Ground truth labels of each box,
            each has shape (num_gt,).
        gt_bboxes_3d_list (list[:obj:`BaseInstance3DBoxes`]): 3D ground
            truth bboxes of each image, each with a tensor of shape
            (num_gt, bbox_code_size).
        gt_labels_3d_list (list[Tensor]): 3D Ground truth labels of each
            box, each has shape (num_gt,).
        centers2d_list (list[Tensor]): Projected 3D centers onto 2D image,
            each has shape (num_gt, 2).
        img_metas (list[dict]): Meta information of each image, e.g.,
            image size, scaling factor, etc.
    """
    bs = len(centers2d_list)

    for i in range(bs):
        centers2d = centers2d_list[i].clone()
        img_shape = img_metas[i]['img_shape']
        keep_inds = (centers2d[:, 0] > 0) & \
            (centers2d[:, 0] < img_shape[1]) & \
            (centers2d[:, 1] > 0) & \
            (centers2d[:, 1] < img_shape[0])
        centers2d_list[i] = centers2d[keep_inds]
        gt_labels_list[i] = gt_labels_list[i][keep_inds]
        gt_bboxes_list[i] = gt_bboxes_list[i][keep_inds]
        gt_bboxes_3d_list[i].tensor = gt_bboxes_3d_list[i].tensor[keep_inds]
        gt_labels_3d_list[i] = gt_labels_3d_list[i][keep_inds]

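# Illustrative usage sketch (added comment; the tensors and metadata are
# made-up stand-ins, not fixtures from this repository). Because
# `filter_outside_objs` mutates its list arguments in place, callers keep
# using the same list objects afterwards:
#
#   centers2d_list = [torch.tensor([[50., 40.], [-5., 10.]])]
#   img_metas = [dict(img_shape=(370, 1224, 3))]  # (h, w, c)
#   filter_outside_objs(gt_bboxes_list, gt_labels_list, gt_bboxes_3d_list,
#                       gt_labels_3d_list, centers2d_list, img_metas)
#   # centers2d_list[0] now holds only the first (inside-image) center.
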

def get_centers2d_target(centers2d, centers, img_shape):
    """Function to get target centers2d.

    Args:
        centers2d (Tensor): Projected 3D centers onto 2D images.
        centers (Tensor): Centers of 2d gt bboxes.
        img_shape (tuple): Resized image shape.

    Returns:
        torch.Tensor: Projected 3D centers (centers2D) target.
    """
    N = centers2d.shape[0]
    h, w = img_shape[:2]
    # Line through the bbox center and the projected center: y = a * x + b.
    # This implicitly assumes the two points do not share an x-coordinate
    # (a vertical line would give an infinite slope).
    a = (centers[:, 1] - centers2d[:, 1]) / (centers[:, 0] - centers2d[:, 0])
    b = centers[:, 1] - a * centers[:, 0]
    # Intersections of this line with the four image borders.
    left_y = b
    right_y = (w - 1) * a + b
    top_x = -b / a
    bottom_x = (h - 1 - b) / a

    left_coors = torch.stack((left_y.new_zeros(N, ), left_y), dim=1)
    right_coors = torch.stack((right_y.new_full((N, ), w - 1), right_y), dim=1)
    top_coors = torch.stack((top_x, top_x.new_zeros(N, )), dim=1)
    bottom_coors = torch.stack((bottom_x, bottom_x.new_full((N, ), h - 1)),
                               dim=1)

    # Stack the four candidate intersections and keep only those that fall
    # on the image border rectangle.
    intersects = torch.stack(
        [left_coors, right_coors, top_coors, bottom_coors], dim=1)
    intersects_x = intersects[:, :, 0]
    intersects_y = intersects[:, :, 1]
    inds = (intersects_x >= 0) & (intersects_x <= w - 1) & \
        (intersects_y >= 0) & (intersects_y <= h - 1)
    # A line from an outside center into the bbox crosses the border
    # rectangle at exactly two valid points, hence the (N, 2, 2) reshape.
    valid_intersects = intersects[inds].reshape(N, 2, 2)
    # For each object, choose the valid intersection closest to its
    # projected center as the final target.
    dist = torch.norm(valid_intersects - centers2d.unsqueeze(1), dim=2)
    min_idx = torch.argmin(dist, dim=1)
    min_idx = min_idx.unsqueeze(-1).unsqueeze(-1).expand(-1, -1, 2)
    centers2d_target = valid_intersects.gather(dim=1, index=min_idx).squeeze(1)

    return centers2d_target

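# Worked example as a comment (numbers invented for illustration): with
# img_shape (h, w) = (100, 200), a bbox center (50, 50) and a projected
# center (-50, 50) define the horizontal line y = 50. It meets the left
# border at (0, 50) and the right border at (199, 50); the top and bottom
# intersections land at infinite x and are filtered out. (0, 50) is closer
# to the projected center, so it is returned as the target.
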

def handle_proj_objs(centers2d_list, gt_bboxes_list, img_metas):
    """Function to handle projected object centers2d, generate target
    centers2d.

    Args:
        gt_bboxes_list (list[Tensor]): Ground truth bboxes of each image,
            shape (num_gt, 4).
        centers2d_list (list[Tensor]): Projected 3D centers onto 2D image,
            shape (num_gt, 2).
        img_metas (list[dict]): Meta information of each image, e.g.,
            image size, scaling factor, etc.

    Returns:
        tuple[list[Tensor]]: It contains three elements. The first is the
        target centers2d after handling the truncated objects. The second
        is the offsets between the target centers2d and their rounded
        integer counterparts, and the last is the truncation mask for
        each object in the batch.
    """
    bs = len(centers2d_list)
    centers2d_target_list = []
    trunc_mask_list = []
    offsets2d_list = []
    # For now, only the pad mode where the image is padded on the right
    # and bottom sides is supported.
    for i in range(bs):
        centers2d = centers2d_list[i]
        gt_bbox = gt_bboxes_list[i]
        img_shape = img_metas[i]['img_shape']
        centers2d_target = centers2d.clone()
        inside_inds = (centers2d[:, 0] > 0) & \
            (centers2d[:, 0] < img_shape[1]) & \
            (centers2d[:, 1] > 0) & \
            (centers2d[:, 1] < img_shape[0])
        outside_inds = ~inside_inds

        # if there are outside objects
        if outside_inds.any():
            centers = (gt_bbox[:, :2] + gt_bbox[:, 2:]) / 2
            outside_centers2d = centers2d[outside_inds]
            match_centers = centers[outside_inds]
            target_outside_centers2d = get_centers2d_target(
                outside_centers2d, match_centers, img_shape)
            centers2d_target[outside_inds] = target_outside_centers2d

        offsets2d = centers2d - centers2d_target.round().int()
        trunc_mask = outside_inds

        centers2d_target_list.append(centers2d_target)
        trunc_mask_list.append(trunc_mask)
        offsets2d_list.append(offsets2d)

    return (centers2d_target_list, offsets2d_list, trunc_mask_list)
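

# Minimal self-contained demo (added for illustration; the shapes, values
# and the tiny box wrapper below are stand-ins, not part of the original
# module or of mmdet3d fixtures).
if __name__ == '__main__':

    class _Boxes3D:
        """Tiny stand-in for a 3D-box structure exposing a `.tensor`."""

        def __init__(self, tensor):
            self.tensor = tensor

    img_metas = [dict(img_shape=(100, 200, 3))]  # (h, w, c)
    # One center inside the image, one truncated far to the left of it.
    centers2d_list = [torch.tensor([[50., 50.], [-50., 50.]])]
    gt_bboxes_list = [
        torch.tensor([[40., 40., 60., 60.], [0., 40., 100., 60.]])
    ]

    targets, offsets, trunc_masks = handle_proj_objs(
        centers2d_list, gt_bboxes_list, img_metas)
    print('targets:', targets[0])  # truncated center snapped to the border
    print('offsets:', offsets[0])
    print('trunc mask:', trunc_masks[0])

    # The filtering helper mutates the lists in place, dropping the
    # truncated object entirely.
    gt_labels_list = [torch.tensor([0, 1])]
    gt_labels_3d_list = [torch.tensor([0, 1])]
    gt_bboxes_3d_list = [_Boxes3D(torch.zeros(2, 7))]
    filter_outside_objs(gt_bboxes_list, gt_labels_list, gt_bboxes_3d_list,
                        gt_labels_3d_list, centers2d_list, img_metas)
    print('kept labels:', gt_labels_list[0])  # tensor([0])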