vote_fusion.py
import torch
from torch import nn as nn

from mmdet3d.core.bbox import Coord3DMode, points_cam2img
from ..builder import FUSION_LAYERS
from . import apply_3d_transformation, bbox_2d_transform, coord_2d_transform

EPS = 1e-6


@FUSION_LAYERS.register_module()
class VoteFusion(nn.Module):
    """Fuse 2d features from 3d seeds.

    Args:
        num_classes (int): Number of classes. Defaults to 10.
        max_imvote_per_pixel (int): Max number of image votes per pixel.
            Defaults to 3.
    """

    def __init__(self, num_classes=10, max_imvote_per_pixel=3):
        super(VoteFusion, self).__init__()
        self.num_classes = num_classes
        self.max_imvote_per_pixel = max_imvote_per_pixel

    def forward(self, imgs, bboxes_2d_rescaled, seeds_3d_depth, img_metas,
                calibs):
        """Forward function.

        Args:
            imgs (list[torch.Tensor]): Image features.
            bboxes_2d_rescaled (list[torch.Tensor]): 2D bboxes.
            seeds_3d_depth (torch.Tensor): 3D seeds.
            img_metas (list[dict]): Meta information of images.
            calibs (dict): Camera calibration information of the images,
                with intrinsic matrices 'K' and depth-to-camera extrinsics
                'Rt'.

        Returns:
            tuple[torch.Tensor]: Concatenated cues of each point and the
                validity mask of each feature.
        """
        img_features = []
        masks = []
        for i, data in enumerate(
                zip(imgs, bboxes_2d_rescaled, seeds_3d_depth, img_metas)):
            img, bbox_2d_rescaled, seed_3d_depth, img_meta = data
            bbox_num = bbox_2d_rescaled.shape[0]
            seed_num = seed_3d_depth.shape[0]

            img_shape = img_meta['img_shape']
            img_h, img_w, _ = img_shape

            # first reverse the data transformations
            xyz_depth = apply_3d_transformation(
                seed_3d_depth, 'DEPTH', img_meta, reverse=True)

            # then convert from depth coords to camera coords
            xyz_cam = Coord3DMode.convert_point(
                xyz_depth,
                Coord3DMode.DEPTH,
                Coord3DMode.CAM,
                rt_mat=calibs['Rt'][i])

            # project to 2d to get image coords (uv)
            uv_origin = points_cam2img(xyz_cam, calibs['K'][i])
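            # shift to 0-based pixel indices; the calibration appears to
            # assume 1-based (MATLAB-style) indexing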
            uv_origin = (uv_origin - 1).round()

            # rescale 2d coordinates and bboxes
            uv_rescaled = coord_2d_transform(img_meta, uv_origin, True)
            bbox_2d_origin = bbox_2d_transform(img_meta, bbox_2d_rescaled,
                                               False)

            if bbox_num == 0:
                imvote_num = seed_num * self.max_imvote_per_pixel

                # use zero features
                two_cues = torch.zeros((15, imvote_num),
                                       device=seed_3d_depth.device)
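                # the flattened imvotes are grouped by vote slot, so the first
                # seed_num entries (one per seed) are marked valid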
                mask_zero = torch.zeros(
                    imvote_num - seed_num, device=seed_3d_depth.device).bool()
                mask_one = torch.ones(
                    seed_num, device=seed_3d_depth.device).bool()
                mask = torch.cat([mask_one, mask_zero], dim=0)
            else:
                # expand bboxes and seeds
                bbox_expanded = bbox_2d_origin.view(1, bbox_num, -1).expand(
                    seed_num, -1, -1)
                seed_2d_expanded = uv_origin.view(seed_num, 1,
                                                  -1).expand(-1, bbox_num, -1)
                seed_2d_expanded_x, seed_2d_expanded_y = \
                    seed_2d_expanded.split(1, dim=-1)

                bbox_expanded_l, bbox_expanded_t, bbox_expanded_r, \
                    bbox_expanded_b, bbox_expanded_conf, bbox_expanded_cls = \
                    bbox_expanded.split(1, dim=-1)
                bbox_expanded_midx = (bbox_expanded_l + bbox_expanded_r) / 2
                bbox_expanded_midy = (bbox_expanded_t + bbox_expanded_b) / 2

                seed_2d_in_bbox_x = (seed_2d_expanded_x > bbox_expanded_l) * \
                    (seed_2d_expanded_x < bbox_expanded_r)
                seed_2d_in_bbox_y = (seed_2d_expanded_y > bbox_expanded_t) * \
                    (seed_2d_expanded_y < bbox_expanded_b)
                seed_2d_in_bbox = seed_2d_in_bbox_x * seed_2d_in_bbox_y

                # semantic cues, dim=num_classes
                sem_cue = torch.zeros_like(bbox_expanded_conf).expand(
                    -1, -1, self.num_classes)
                sem_cue = sem_cue.scatter(-1, bbox_expanded_cls.long(),
                                          bbox_expanded_conf)

                # bbox center - uv
                delta_u = bbox_expanded_midx - seed_2d_expanded_x
                delta_v = bbox_expanded_midy - seed_2d_expanded_y

                seed_3d_expanded = seed_3d_depth.view(seed_num, 1, -1).expand(
                    -1, bbox_num, -1)

                z_cam = xyz_cam[..., 2:3].view(seed_num, 1,
                                               1).expand(-1, bbox_num, -1)

                delta_u = delta_u * z_cam / calibs['K'][i, 0, 0]
                delta_v = delta_v * z_cam / calibs['K'][i, 0, 0]
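                # delta_u/delta_v are now metric offsets in the camera frame
                # at the seed's depth (pixel offset * depth / focal length)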

                imvote = torch.cat(
                    [delta_u, delta_v,
                     torch.zeros_like(delta_v)], dim=-1).view(-1, 3)

                # convert from camera coords to depth coords
                imvote = Coord3DMode.convert_point(
                    imvote.view((-1, 3)),
                    Coord3DMode.CAM,
                    Coord3DMode.DEPTH,
                    rt_mat=calibs['Rt'][i])

                # apply transformation to lifted imvotes
                imvote = apply_3d_transformation(
                    imvote, 'DEPTH', img_meta, reverse=False)

                seed_3d_expanded = seed_3d_expanded.reshape(imvote.shape)

                # ray angle
                ray_angle = seed_3d_expanded + imvote
                ray_angle /= torch.sqrt(torch.sum(ray_angle**2, -1) +
                                        EPS).unsqueeze(-1)

                # imvote lifted to 3d
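                # intersect the ray with the plane y = seed_y and take the
                # x/z displacement from the seed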
                xz = ray_angle[:, [0, 2]] / (ray_angle[:, [1]] + EPS) \
                    * seed_3d_expanded[:, [1]] - seed_3d_expanded[:, [0, 2]]

                # geometric cues, dim=5
                geo_cue = torch.cat([xz, ray_angle],
                                    dim=-1).view(seed_num, -1, 5)

                two_cues = torch.cat([geo_cue, sem_cue], dim=-1)
                # mask to 0 if seed not in bbox
                two_cues = two_cues * seed_2d_in_bbox.float()

                feature_size = two_cues.shape[-1]
                # if bbox number is too small, append zeros
                if bbox_num < self.max_imvote_per_pixel:
                    append_num = self.max_imvote_per_pixel - bbox_num
                    append_zeros = torch.zeros(
                        (seed_num, append_num, 1),
                        device=seed_2d_in_bbox.device).bool()
                    seed_2d_in_bbox = torch.cat(
                        [seed_2d_in_bbox, append_zeros], dim=1)
                    append_zeros = torch.zeros(
                        (seed_num, append_num, feature_size),
                        device=two_cues.device)
                    two_cues = torch.cat([two_cues, append_zeros], dim=1)
                    append_zeros = torch.zeros((seed_num, append_num, 1),
                                               device=two_cues.device)
                    bbox_expanded_conf = torch.cat(
                        [bbox_expanded_conf, append_zeros], dim=1)

                # sort the valid seed-bbox pairs according to confidence
                pair_score = seed_2d_in_bbox.float() + bbox_expanded_conf
                # and find the largest ones
                mask, indices = pair_score.topk(
                    self.max_imvote_per_pixel,
                    dim=1,
                    largest=True,
                    sorted=True)

                indices_img = indices.expand(-1, -1, feature_size)
                two_cues = two_cues.gather(dim=1, index=indices_img)
                two_cues = two_cues.transpose(1, 0)
                two_cues = two_cues.reshape(-1, feature_size).transpose(
                    1, 0).contiguous()
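                # final layout: (feature_size, max_imvote_per_pixel * seed_num),
                # grouped by vote slot first, then by seed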

                # since conf is ~ (0, 1), floor gives us validity
                mask = mask.floor().int()
                mask = mask.transpose(1, 0).reshape(-1).bool()

            # clear the padding
            img = img[:, :img_shape[0], :img_shape[1]]
            img_flatten = img.reshape(3, -1).float()
            img_flatten /= 255.

            # take the normalized pixel value as texture cue
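            # clamp the projected uv coordinates to the valid (unpadded)
            # image range before indexing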
            uv_rescaled[:, 0] = torch.clamp(uv_rescaled[:, 0].round(), 0,
                                            img_shape[1] - 1)
            uv_rescaled[:, 1] = torch.clamp(uv_rescaled[:, 1].round(), 0,
                                            img_shape[0] - 1)
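            # convert (u, v) to flat row-major indices into the image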
            uv_flatten = uv_rescaled[:, 1].round() * \
                img_shape[1] + uv_rescaled[:, 0].round()
            uv_expanded = uv_flatten.unsqueeze(0).expand(3, -1).long()
            txt_cue = torch.gather(img_flatten, dim=-1, index=uv_expanded)
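            # repeat the texture cue for every vote slot, matching the layout
            # of the other cues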
            txt_cue = txt_cue.unsqueeze(1).expand(-1,
                                                  self.max_imvote_per_pixel,
                                                  -1).reshape(3, -1)

            # append texture cue
            img_feature = torch.cat([two_cues, txt_cue], dim=0)
            img_features.append(img_feature)
            masks.append(mask)

        return torch.stack(img_features, 0), torch.stack(masks, 0)
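

# A minimal usage sketch (illustrative only; the shapes are inferred from the
# code above rather than from upstream documentation):
#
#   fusion = VoteFusion(num_classes=10, max_imvote_per_pixel=3)
#   cues, masks = fusion(imgs, bboxes_2d_rescaled, seeds_3d_depth,
#                        img_metas, calibs)
#   # cues:  (batch, 5 + num_classes + 3, seed_num * max_imvote_per_pixel)
#   # masks: (batch, seed_num * max_imvote_per_pixel)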