# Copyright (c) OpenMMLab. All rights reserved.
from typing import List, Tuple

import torch
from torch import Tensor
from torch import nn as nn

from mmdet3d.registry import MODELS
from mmdet3d.structures import points_cam2img
from . import apply_3d_transformation, bbox_2d_transform, coord_2d_transform

EPS = 1e-6


@MODELS.register_module()
class VoteFusion(nn.Module):
    """Fuse 2d features from 3d seeds.

    Args:
        num_classes (int): Number of classes.
        max_imvote_per_pixel (int): Maximum number of image votes per seed.
    """

    def __init__(self,
                 num_classes: int = 10,
                 max_imvote_per_pixel: int = 3) -> None:
        super(VoteFusion, self).__init__()
        self.num_classes = num_classes
        self.max_imvote_per_pixel = max_imvote_per_pixel

    def forward(self, imgs: List[Tensor], bboxes_2d_rescaled: List[Tensor],
                seeds_3d_depth: List[Tensor],
                img_metas: List[dict]) -> Tuple[Tensor]:
        """Forward function.

        Args:
            imgs (List[Tensor]): Images of each sample in (3, H, W) format.
            bboxes_2d_rescaled (List[Tensor]): 2D bboxes of each sample in
                the rescaled image frame, each row being
                (x1, y1, x2, y2, score, label).
            seeds_3d_depth (List[Tensor]): 3D seed points of each sample in
                depth coordinates, with shape (num_seeds, 3).
            img_metas (List[dict]): Meta information of images.

        Returns:
            Tuple[Tensor]:

                - img_features: Concatenated cues of each point.
                - masks: Validity mask of each feature.
        """
        img_features = []
        masks = []
        for i, data in enumerate(
                zip(imgs, bboxes_2d_rescaled, seeds_3d_depth, img_metas)):
            img, bbox_2d_rescaled, seed_3d_depth, img_meta = data
            bbox_num = bbox_2d_rescaled.shape[0]
            seed_num = seed_3d_depth.shape[0]

            img_shape = img_meta['img_shape']
            # first reverse the data transformations
            xyz_depth = apply_3d_transformation(
                seed_3d_depth, 'DEPTH', img_meta, reverse=True)

            # project points from depth to image
            depth2img = xyz_depth.new_tensor(img_meta['depth2img'])
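            # points_cam2img with with_depth=True returns (u, v, depth);
            # the camera depth is kept for lifting 2D votes back to 3D later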
            uvz_origin = points_cam2img(xyz_depth, depth2img, True)
            z_cam = uvz_origin[..., 2]
            uv_origin = (uvz_origin[..., :2] - 1).round()

            # rescale 2d coordinates and bboxes
            uv_rescaled = coord_2d_transform(img_meta, uv_origin, True)
            bbox_2d_origin = bbox_2d_transform(img_meta, bbox_2d_rescaled,
                                               False)
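            # uv_rescaled lives in the rescaled/augmented image (used later
            # to sample pixel values), while bbox_2d_origin is mapped back
            # to the original image frame to compare against uv_origin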

            if bbox_num == 0:
                imvote_num = seed_num * self.max_imvote_per_pixel

                # use zero features
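                # (5 geometric + num_classes semantic channels per imvote)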
                two_cues = torch.zeros((self.num_classes + 5, imvote_num),
                                       device=seed_3d_depth.device)
                mask_zero = torch.zeros(
                    imvote_num - seed_num, device=seed_3d_depth.device).bool()
                mask_one = torch.ones(
                    seed_num, device=seed_3d_depth.device).bool()
                mask = torch.cat([mask_one, mask_zero], dim=0)
            else:
                # expand bboxes and seeds
                bbox_expanded = bbox_2d_origin.view(1, bbox_num, -1).expand(
                    seed_num, -1, -1)
                seed_2d_expanded = uv_origin.view(seed_num, 1,
                                                  -1).expand(-1, bbox_num, -1)
                seed_2d_expanded_x, seed_2d_expanded_y = \
                    seed_2d_expanded.split(1, dim=-1)

                bbox_expanded_l, bbox_expanded_t, bbox_expanded_r, \
                    bbox_expanded_b, bbox_expanded_conf, bbox_expanded_cls = \
                    bbox_expanded.split(1, dim=-1)
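                # (the six box channels are x1, y1, x2, y2, score, label)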
                bbox_expanded_midx = (bbox_expanded_l + bbox_expanded_r) / 2
                bbox_expanded_midy = (bbox_expanded_t + bbox_expanded_b) / 2

                seed_2d_in_bbox_x = (seed_2d_expanded_x > bbox_expanded_l) * \
                    (seed_2d_expanded_x < bbox_expanded_r)
                seed_2d_in_bbox_y = (seed_2d_expanded_y > bbox_expanded_t) * \
                    (seed_2d_expanded_y < bbox_expanded_b)
                seed_2d_in_bbox = seed_2d_in_bbox_x * seed_2d_in_bbox_y
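                # seed_2d_in_bbox: whether each projected seed falls inside
                # each box, shape (seed_num, bbox_num, 1)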

                # semantic cues, dim=num_classes
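                # (the box score is scattered to its class index, giving a
                #  one-hot-style score vector for every seed-box pair)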
                sem_cue = torch.zeros_like(bbox_expanded_conf).expand(
                    -1, -1, self.num_classes)
                sem_cue = sem_cue.scatter(-1, bbox_expanded_cls.long(),
                                          bbox_expanded_conf)

                # bbox center - uv
                delta_u = bbox_expanded_midx - seed_2d_expanded_x
                delta_v = bbox_expanded_midy - seed_2d_expanded_y

                seed_3d_expanded = seed_3d_depth.view(seed_num, 1, -1).expand(
                    -1, bbox_num, -1)

                z_cam = z_cam.view(seed_num, 1, 1).expand(-1, bbox_num, -1)
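                # lift each 2D vote (du, dv, 0) to a 3D offset: scale by the
                # seed's camera depth and apply the inverse projection matrix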
                imvote = torch.cat(
                    [delta_u, delta_v,
                     torch.zeros_like(delta_v)], dim=-1).view(-1, 3)
                imvote = imvote * z_cam.reshape(-1, 1)
                imvote = imvote @ torch.inverse(depth2img.t())

                # apply transformation to lifted imvotes
                imvote = apply_3d_transformation(
                    imvote, 'DEPTH', img_meta, reverse=False)
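                # (re-apply the augmentations so the lifted votes are in the
                #  same frame as seed_3d_depth)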

                seed_3d_expanded = seed_3d_expanded.reshape(imvote.shape)

                # ray angle: unit direction from the origin towards the
                # voted 3D point
                ray_angle = seed_3d_expanded + imvote
                ray_angle /= torch.sqrt(torch.sum(ray_angle**2, -1) +
                                        EPS).unsqueeze(-1)

                # imvote lifted to 3d
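                # intersect the ray with the plane y = seed_y and take the
                # (x, z) offset from the seed as the pseudo 3D vote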
                xz = ray_angle[:, [0, 2]] / (ray_angle[:, [1]] + EPS) \
                    * seed_3d_expanded[:, [1]] - seed_3d_expanded[:, [0, 2]]

                # geometric cues, dim=5
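                # (2 channels for the xz offset + 3 for the ray direction)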
                geo_cue = torch.cat([xz, ray_angle],
                                    dim=-1).view(seed_num, -1, 5)

                two_cues = torch.cat([geo_cue, sem_cue], dim=-1)
                # mask to 0 if seed not in bbox
                two_cues = two_cues * seed_2d_in_bbox.float()

                feature_size = two_cues.shape[-1]
                # if bbox number is too small, append zeros
                if bbox_num < self.max_imvote_per_pixel:
                    append_num = self.max_imvote_per_pixel - bbox_num
                    append_zeros = torch.zeros(
                        (seed_num, append_num, 1),
                        device=seed_2d_in_bbox.device).bool()
                    seed_2d_in_bbox = torch.cat(
                        [seed_2d_in_bbox, append_zeros], dim=1)
                    append_zeros = torch.zeros(
                        (seed_num, append_num, feature_size),
                        device=two_cues.device)
                    two_cues = torch.cat([two_cues, append_zeros], dim=1)
                    append_zeros = torch.zeros((seed_num, append_num, 1),
                                               device=two_cues.device)
                    bbox_expanded_conf = torch.cat(
                        [bbox_expanded_conf, append_zeros], dim=1)

                # score the valid seed-bbox pairs according to confidence
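                # adding the 0/1 in-box indicator to the (0, 1) detection
                # confidence makes in-box pairs rank above out-of-box ones
                # in the topk below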
                pair_score = seed_2d_in_bbox.float() + bbox_expanded_conf
                # and keep the highest-scoring pairs for each seed
                mask, indices = pair_score.topk(
                    self.max_imvote_per_pixel,
                    dim=1,
                    largest=True,
                    sorted=True)

                indices_img = indices.expand(-1, -1, feature_size)
                two_cues = two_cues.gather(dim=1, index=indices_img)
                two_cues = two_cues.transpose(1, 0)
                two_cues = two_cues.reshape(-1, feature_size).transpose(
                    1, 0).contiguous()
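                # two_cues is now (feature_size, max_imvote_per_pixel *
                # seed_num), with columns ordered by imvote rank first so
                # they line up with the flattened mask below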

                # since conf is ~ (0, 1), floor gives us validity
                mask = mask.floor().int()
                mask = mask.transpose(1, 0).reshape(-1).bool()

            # clear the padding
            img = img[:, :img_shape[0], :img_shape[1]]
            img_flatten = img.reshape(3, -1).float()
            img_flatten /= 255.

            # take the normalized pixel value as texture cue
            uv_rescaled[:, 0] = torch.clamp(uv_rescaled[:, 0].round(), 0,
                                            img_shape[1] - 1)
            uv_rescaled[:, 1] = torch.clamp(uv_rescaled[:, 1].round(), 0,
                                            img_shape[0] - 1)
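            # flatten (u, v) into v * W + u to index the flattened image and
            # look up the normalized RGB value for each seed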
            uv_flatten = uv_rescaled[:, 1].round() * \
                img_shape[1] + uv_rescaled[:, 0].round()
            uv_expanded = uv_flatten.unsqueeze(0).expand(3, -1).long()
            txt_cue = torch.gather(img_flatten, dim=-1, index=uv_expanded)
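            # repeat each seed's pixel value for each of its
            # max_imvote_per_pixel votes so it aligns with two_cues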
            txt_cue = txt_cue.unsqueeze(1).expand(-1,
                                                  self.max_imvote_per_pixel,
                                                  -1).reshape(3, -1)

            # append texture cue
            img_feature = torch.cat([two_cues, txt_cue], dim=0)
            img_features.append(img_feature)
            masks.append(mask)

        return torch.stack(img_features, 0), torch.stack(masks, 0)