vote_fusion.py 8.69 KB
Newer Older
dingchang's avatar
dingchang committed
1
# Copyright (c) OpenMMLab. All rights reserved.
2
3
4
import torch
from torch import nn as nn

5
from mmdet3d.registry import MODELS
zhangshilong's avatar
zhangshilong committed
6
from mmdet3d.structures import points_cam2img
7
8
9
10
11
from . import apply_3d_transformation, bbox_2d_transform, coord_2d_transform

EPS = 1e-6


12
@MODELS.register_module()
13
14
15
16
17
18
19
20
21
22
23
24
25
class VoteFusion(nn.Module):
    """Fuse 2d features from 3d seeds.

    For every 3D seed the module projects the seed into the image, pairs it
    with the 2D bboxes that contain its projection and builds, per
    seed-bbox pair, a feature column made of geometric cues (5 channels),
    semantic cues (``num_classes`` channels) and texture cues (3 channels).

    Args:
        num_classes (int): number of classes.
        max_imvote_per_pixel (int): max number of imvotes.
    """

    # 2 channels of lifted xz offset + 3 channels of normalized ray angle
    _GEO_CUE_DIM = 5

    def __init__(self, num_classes=10, max_imvote_per_pixel=3):
        super(VoteFusion, self).__init__()
        self.num_classes = num_classes
        self.max_imvote_per_pixel = max_imvote_per_pixel

    def forward(self, imgs, bboxes_2d_rescaled, seeds_3d_depth, img_metas):
        """Forward function.

        Args:
            imgs (list[torch.Tensor]): Image features.
            bboxes_2d_rescaled (list[torch.Tensor]): 2D bboxes. Each row is
                split into six columns below:
                (left, top, right, bottom, confidence, class index).
            seeds_3d_depth (torch.Tensor): 3D seeds.
            img_metas (list[dict]): Meta information of images. The keys
                ``'img_shape'`` and ``'depth2img'`` are read here.

        Returns:
            torch.Tensor: Concatenated cues of each point, stacked over the
                batch. Per sample the layout is
                ``(5 + num_classes + 3, seed_num * max_imvote_per_pixel)``.
            torch.Tensor: Validity mask of each feature, stacked over the
                batch, ``seed_num * max_imvote_per_pixel`` entries per
                sample.
        """
        img_features = []
        masks = []
        for img, bbox_2d_rescaled, seed_3d_depth, img_meta in zip(
                imgs, bboxes_2d_rescaled, seeds_3d_depth, img_metas):
            bbox_num = bbox_2d_rescaled.shape[0]
            seed_num = seed_3d_depth.shape[0]

            img_shape = img_meta['img_shape']
            # first reverse the data transformations
            xyz_depth = apply_3d_transformation(
                seed_3d_depth, 'DEPTH', img_meta, reverse=True)

            # project points from depth to image
            depth2img = xyz_depth.new_tensor(img_meta['depth2img'])
            uvz_origin = points_cam2img(xyz_depth, depth2img, True)
            z_cam = uvz_origin[..., 2]
            uv_origin = (uvz_origin[..., :2] - 1).round()

            # rescale 2d coordinates and bboxes
            uv_rescaled = coord_2d_transform(img_meta, uv_origin, True)
            bbox_2d_origin = bbox_2d_transform(img_meta, bbox_2d_rescaled,
                                               False)

            if bbox_num == 0:
                # no detection in this image: use zero features
                two_cues, mask = self._empty_cues(seed_num,
                                                  seed_3d_depth.device)
            else:
                # expand bboxes and seeds to (seed_num, bbox_num, C)
                bbox_expanded = bbox_2d_origin.view(1, bbox_num, -1).expand(
                    seed_num, -1, -1)
                seed_2d_expanded = uv_origin.view(seed_num, 1,
                                                  -1).expand(-1, bbox_num, -1)
                seed_2d_expanded_x, seed_2d_expanded_y = \
                    seed_2d_expanded.split(1, dim=-1)

                bbox_expanded_l, bbox_expanded_t, bbox_expanded_r, \
                    bbox_expanded_b, bbox_expanded_conf, bbox_expanded_cls = \
                    bbox_expanded.split(1, dim=-1)
                bbox_expanded_midx = (bbox_expanded_l + bbox_expanded_r) / 2
                bbox_expanded_midy = (bbox_expanded_t + bbox_expanded_b) / 2

                # a seed belongs to a bbox iff it lies strictly inside it
                seed_2d_in_bbox_x = (seed_2d_expanded_x > bbox_expanded_l) * \
                    (seed_2d_expanded_x < bbox_expanded_r)
                seed_2d_in_bbox_y = (seed_2d_expanded_y > bbox_expanded_t) * \
                    (seed_2d_expanded_y < bbox_expanded_b)
                seed_2d_in_bbox = seed_2d_in_bbox_x * seed_2d_in_bbox_y

                # semantic cues, dim=class_num: confidence placed at the
                # channel given by the bbox class index
                sem_cue = torch.zeros_like(bbox_expanded_conf).expand(
                    -1, -1, self.num_classes)
                sem_cue = sem_cue.scatter(-1, bbox_expanded_cls.long(),
                                          bbox_expanded_conf)

                # bbox center - uv
                delta_u = bbox_expanded_midx - seed_2d_expanded_x
                delta_v = bbox_expanded_midy - seed_2d_expanded_y

                seed_3d_expanded = seed_3d_depth.view(seed_num, 1, -1).expand(
                    -1, bbox_num, -1)

                # scale the pixel offset by the seed depth and map it back
                # through the inverse of the projection matrix
                z_cam = z_cam.view(seed_num, 1, 1).expand(-1, bbox_num, -1)
                imvote = torch.cat(
                    [delta_u, delta_v,
                     torch.zeros_like(delta_v)], dim=-1).view(-1, 3)
                imvote = imvote * z_cam.reshape(-1, 1)
                imvote = imvote @ torch.inverse(depth2img.t())

                # apply transformation to lifted imvotes
                imvote = apply_3d_transformation(
                    imvote, 'DEPTH', img_meta, reverse=False)

                seed_3d_expanded = seed_3d_expanded.reshape(imvote.shape)

                # ray angle (unit vector; EPS guards the zero-length case)
                ray_angle = seed_3d_expanded + imvote
                ray_angle /= torch.sqrt(torch.sum(ray_angle**2, -1) +
                                        EPS).unsqueeze(-1)

                # imvote lifted to 3d
                xz = ray_angle[:, [0, 2]] / (ray_angle[:, [1]] + EPS) \
                    * seed_3d_expanded[:, [1]] - seed_3d_expanded[:, [0, 2]]

                # geometric cues, dim=5
                geo_cue = torch.cat([xz, ray_angle], dim=-1).view(
                    seed_num, -1, self._GEO_CUE_DIM)

                two_cues = torch.cat([geo_cue, sem_cue], dim=-1)
                # mask to 0 if seed not in bbox
                two_cues = two_cues * seed_2d_in_bbox.float()

                feature_size = two_cues.shape[-1]
                # if bbox number is too small, append zeros
                if bbox_num < self.max_imvote_per_pixel:
                    append_num = self.max_imvote_per_pixel - bbox_num
                    append_zeros = torch.zeros(
                        (seed_num, append_num, 1),
                        device=seed_2d_in_bbox.device).bool()
                    seed_2d_in_bbox = torch.cat(
                        [seed_2d_in_bbox, append_zeros], dim=1)
                    append_zeros = torch.zeros(
                        (seed_num, append_num, feature_size),
                        device=two_cues.device)
                    two_cues = torch.cat([two_cues, append_zeros], dim=1)
                    append_zeros = torch.zeros((seed_num, append_num, 1),
                                               device=two_cues.device)
                    bbox_expanded_conf = torch.cat(
                        [bbox_expanded_conf, append_zeros], dim=1)

                # sort the valid seed-bbox pair according to confidence
                pair_score = seed_2d_in_bbox.float() + bbox_expanded_conf
                # and find the largests
                mask, indices = pair_score.topk(
                    self.max_imvote_per_pixel,
                    dim=1,
                    largest=True,
                    sorted=True)

                indices_img = indices.expand(-1, -1, feature_size)
                two_cues = two_cues.gather(dim=1, index=indices_img)
                two_cues = two_cues.transpose(1, 0)
                two_cues = two_cues.reshape(-1, feature_size).transpose(
                    1, 0).contiguous()

                # since conf is ~ (0, 1), floor gives us validity
                mask = mask.floor().int()
                mask = mask.transpose(1, 0).reshape(-1).bool()

            # take the normalized pixel value as texture cue
            txt_cue = self._texture_cue(img, uv_rescaled, img_shape)

            # append texture cue
            img_feature = torch.cat([two_cues, txt_cue], dim=0)
            img_features.append(img_feature)
            masks.append(mask)

        return torch.stack(img_features, 0), torch.stack(masks, 0)

    def _empty_cues(self, seed_num, device):
        """Build all-zero geometric/semantic cues for an image without bbox.

        Args:
            seed_num (int): Number of seeds of the sample.
            device (torch.device): Device on which the tensors are created.

        Returns:
            tuple[torch.Tensor]: Zero cues with
                ``5 + num_classes`` rows and
                ``seed_num * max_imvote_per_pixel`` columns, and a boolean
                mask of the same column count whose first ``seed_num``
                entries are True.
        """
        imvote_num = seed_num * self.max_imvote_per_pixel
        # The row count must match the non-empty branch (geo + semantic
        # cues); otherwise samples of one batch cannot be stacked when
        # num_classes differs from the default.
        two_cues = torch.zeros(
            (self._GEO_CUE_DIM + self.num_classes, imvote_num),
            device=device)
        mask_one = torch.ones(seed_num, device=device).bool()
        mask_zero = torch.zeros(imvote_num - seed_num, device=device).bool()
        mask = torch.cat([mask_one, mask_zero], dim=0)
        return two_cues, mask

    def _texture_cue(self, img, uv_rescaled, img_shape):
        """Sample the normalized pixel value under every projected seed.

        Args:
            img (torch.Tensor): Image indexed as (3, H, W); it is reshaped
                to 3 rows below.
            uv_rescaled (torch.Tensor): Per-seed pixel coordinates, column 0
                is u (width axis) and column 1 is v (height axis).
            img_shape (tuple[int]): ``img_shape[0]`` is the valid height and
                ``img_shape[1]`` the valid width.

        Returns:
            torch.Tensor: Texture cue with 3 rows and
                ``max_imvote_per_pixel * seed_num`` columns.
        """
        # clear the padding
        img = img[:, :img_shape[0], :img_shape[1]]
        # Divide out of place: ``reshape`` may return a view and ``float()``
        # returns the same tensor for float32 input, so an in-place division
        # could silently rescale the caller's image.
        img_flatten = img.reshape(3, -1).float() / 255.

        # clamp the projections into the valid image area without writing
        # into ``uv_rescaled``
        u = torch.clamp(uv_rescaled[:, 0].round(), 0, img_shape[1] - 1)
        v = torch.clamp(uv_rescaled[:, 1].round(), 0, img_shape[0] - 1)

        # row-major flat pixel index
        uv_flatten = v * img_shape[1] + u
        uv_expanded = uv_flatten.unsqueeze(0).expand(3, -1).long()
        txt_cue = torch.gather(img_flatten, dim=-1, index=uv_expanded)
        # every imvote slot of a seed shares the seed's texture
        txt_cue = txt_cue.unsqueeze(1).expand(-1,
                                              self.max_imvote_per_pixel,
                                              -1).reshape(3, -1)
        return txt_cue