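'''IPM backbone: extract multi-camera image features, project them onto BEV
planes at several heights via inverse perspective mapping (IPM), and
optionally fuse a PointPillar-encoded LiDAR branch.
'''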
import copy
import math

import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from mmdet3d.models.builder import BACKBONES
from mmdet.models import build_backbone, build_neck
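# NOTE: PointPillarEncoder (used below when use_lidar=True) is not imported
# in this file; it is assumed to be provided elsewhere in the project, e.g.:
#   from .pointpillar import PointPillarEncoder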


class UpsampleBlock(nn.Module):
    def __init__(self, ins, outs):
        super(UpsampleBlock, self).__init__()
        self.gn = nn.GroupNorm(32, outs)
        self.conv = nn.Conv2d(ins, outs, kernel_size=3,
                              stride=1, padding=1)  # same
        self.relu = nn.ReLU(inplace=True)

    def forward(self, x):
        x = self.conv(x)
        x = self.relu(self.gn(x))
        x = self.upsample2x(x)

        return x

    def upsample2x(self, x):
        _, _, h, w = x.shape
        x = F.interpolate(x, size=(h * 2, w * 2),
                          mode='bilinear', align_corners=True)
        return x


class Upsample(nn.Module):
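    '''Fuse a feature pyramid into a single map: each level is upsampled to a
    common resolution by its zoom factor, then the levels are summed.

    Usage sketch (assumed pyramid sizes): with zoom_size=(2, 4, 8), maps of
    64x64, 32x32 and 16x16 all reach 128x128 before the sum.
    '''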

    def __init__(self,
                 zoom_size=(2, 4, 8),
                 in_channels=128,
                 out_channels=128,
                 ):
        super(Upsample, self).__init__()

        self.out_channels = out_channels

        input_conv = UpsampleBlock(in_channels, out_channels)
        inter_conv = UpsampleBlock(out_channels, out_channels)

        fscale = []
        for scale_factor in zoom_size:

            layer_num = int(math.log2(scale_factor))
            if layer_num < 1:
                fscale.append(nn.Identity())
                continue

            tmp = [copy.deepcopy(input_conv), ]
            tmp += [copy.deepcopy(inter_conv) for i in range(layer_num - 1)]
            fscale.append(nn.Sequential(*tmp))

        self.fscale = nn.ModuleList(fscale)

    def init_weights(self):

        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                nn.init.kaiming_uniform_(m.weight, a=1)
                nn.init.constant_(m.bias, 0)

    def forward(self, imgs):

        rescale_i = []
        for f, img in zip(self.fscale, imgs):
            rescale_i.append(f(img))

        out = sum(rescale_i)

        return out


@BACKBONES.register_module()
class IPMEncoder(nn.Module):
    '''
    Encode multi-camera image features (and, optionally, LiDAR) into a BEV
    feature map via inverse perspective mapping.
    '''

    def __init__(self,
                 img_backbone,
                 img_neck,
                 upsample,
                 xbound=[-30.0, 30.0, 0.5],
                 ybound=[-15.0, 15.0, 0.5],
                 zbound=[-10.0, 10.0, 20.0],
                 heights=[-1.1, 0, 0.5, 1.1],
                 pretrained=None,
                 out_channels=128,
                 num_cam=6,
                 use_lidar=False,
                 use_image=True,
                 lidar_dim=128,
                 ):
        super(IPMEncoder, self).__init__()
        self.x_bound = xbound
        self.y_bound = ybound
        self.heights = heights

        self.num_cam = num_cam

        num_x = int((xbound[1] - xbound[0]) / xbound[2])
        num_y = int((ybound[1] - ybound[0]) / ybound[2])

        self.img_backbone = build_backbone(img_backbone)
        self.img_neck = build_neck(img_neck)
        self.upsample = Upsample(**upsample)

        self.use_image = use_image
        self.use_lidar = use_lidar
        if self.use_lidar:
            self.pp = PointPillarEncoder(lidar_dim, xbound, ybound, zbound)

            self.outconvs = \
                nn.Conv2d((self.upsample.out_channels + 3) * len(heights), out_channels // 2,
                          kernel_size=3, stride=1, padding=1)  # same
            if self.use_image:
                _out_channels = out_channels // 2
            else:
                _out_channels = out_channels

            self.outconvs_lidar = \
                nn.Conv2d(lidar_dim, _out_channels,
                          kernel_size=3, stride=1, padding=1)  # same
        else:
            self.outconvs = \
                nn.Conv2d((self.upsample.out_channels + 3) * len(heights), out_channels,
                          kernel_size=3, stride=1, padding=1)  # same

        self.init_weights(pretrained=pretrained)

        # bev_plane
        bev_planes = [construct_plane_grid(
            xbound, ybound, h) for h in self.heights]
        self.register_buffer('bev_planes', torch.stack(
            bev_planes), )  # nlvl, bH, bW, 3

        self.masked_embeds = nn.Embedding(len(heights), out_channels)

    def init_weights(self, pretrained=None):
        """Initialize model weights."""

        self.img_backbone.init_weights()
        self.img_neck.init_weights()
        self.upsample.init_weights()

        for p in self.outconvs.parameters():
            if p.dim() > 1:
                nn.init.xavier_uniform_(p)

        if self.use_lidar:
            for p in self.outconvs_lidar.parameters():
                if p.dim() > 1:
                    nn.init.xavier_uniform_(p)

            for p in self.pp.parameters():
                if p.dim() > 1:
                    nn.init.xavier_uniform_(p)

    def extract_img_feat(self, imgs):
        '''
            Extract per-camera image features and fuse the multi-scale maps into one.
            Args:
                imgs: B, n_cam, C, iH, iW
            Returns:
                img_feat: B * n_cam, C, H, W
        '''

        B, n_cam, C, iH, iW = imgs.shape
        imgs = imgs.view(B * n_cam, C, iH, iW)

        img_feats = self.img_backbone(imgs)

        # reduce the channel dim
        img_feats = self.img_neck(img_feats)

        # fuse the multi-scale feature maps into a single map
        img_feat = self.upsample(img_feats)

        return img_feat

    def forward(self, imgs, img_metas, *args, points=None, **kwargs):
        '''
            Args:
                imgs: torch.Tensor of shape [B, N, 3, H, W]
                    N: number of cams
                img_metas:
                    # N=6, ['CAM_FRONT', 'CAM_FRONT_RIGHT', 'CAM_FRONT_LEFT', 'CAM_BACK', 'CAM_BACK_LEFT', 'CAM_BACK_RIGHT']
                    ego2img: [B, N, 4, 4] (ego-to-image projection per camera)
                    cam_intrinsics: [B, N, 3, 3]
                    cam2ego_rotations: [B, N, 3, 3]
                    cam2ego_translations: [B, N, 3]
                    ...
            Outs:
                bev_feature: torch.Tensor of shape [B, out_channels, bH, bW]
        '''
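        # Flow sketch: images -> backbone + neck -> multi-scale fusion ->
        # IPM onto the stacked height planes -> outconvs; the LiDAR branch
        # (if enabled) is PointPillar-encoded, then concatenated or used alone.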

        if self.use_image:
            self.B = imgs.shape[0]

            # Get transform matrix
            ego2cam = []
            for img_meta in img_metas:
                ego2cam.append(img_meta['ego2img'])
            img_shape = imgs.shape[-2:]

            ego2cam = np.asarray(ego2cam)
            # Image backbone
            img_feats = self.extract_img_feat(imgs)

            # IPM
            bev_feat, bev_feat_mask = self.ipm(img_feats, ego2cam, img_shape)

            # multi level into a same
            bev_feat = bev_feat.flatten(1, 2)
            bev_feat = self.outconvs(bev_feat)

        if self.use_lidar:
            lidar_feat = self.get_lidar_feature(points)
            if self.use_image:
                bev_feat = torch.cat([bev_feat, lidar_feat], dim=1)
            else:
                bev_feat = lidar_feat

        return bev_feat

    def ipm(self, cam_feat, ego2cam, img_shape):
        '''
            Inverse perspective mapping: project image features onto the multi-height BEV planes.
            Args:
                cam_feat: B*ncam, C, cH, cW
                ego2cam: B, ncam, 4, 4
                img_shape: tuple(H, W)
            Returns:
                project_feat: B, C, nlvl, bH, bW
                bev_feat_mask: B, 1, nlvl, bH, bW
        '''
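        # Pipeline sketch: (1) flatten the stacked multi-height BEV grid,
        # (2) project every BEV point into each camera via get_campos,
        # (3) bilinearly sample camera features at those pixels, and
        # (4) average over the cameras that actually see each BEV cell.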
        C = cam_feat.shape[1]
        bev_grid = self.bev_planes.unsqueeze(0).repeat(self.B, 1, 1, 1, 1)
        nlvl, bH, bW = bev_grid.shape[1:4]
        bev_grid = bev_grid.flatten(1, 3)  # B, nlvl*bH*bW, 3

        # Find points in cam coords
        # bev_grid_pos: B*ncam, nlvl*bH*bW, 2
        bev_grid_pos, bev_cam_mask = get_campos(bev_grid, ego2cam, img_shape)
        # B*cam, nlvl*bH, bW, 2
        bev_grid_pos = bev_grid_pos.unflatten(-2, (nlvl * bH, bW))

        # project feat from 2D to bev plane
        projected_feature = F.grid_sample(
            cam_feat, bev_grid_pos, align_corners=False).view(self.B, -1, C, nlvl, bH, bW)  # B,cam,C,nlvl,bH,bW

        # B,cam,nlvl,bH,bW
        bev_feat_mask = bev_cam_mask.unflatten(-1, (nlvl, bH, bW))

        # eliminate the ncam
        # The bev feature is the sum of the 6 cameras
        bev_feat_mask = bev_feat_mask.unsqueeze(2)
        projected_feature = (projected_feature * bev_feat_mask).sum(1)
        num_feat = bev_feat_mask.sum(1)

        projected_feature = projected_feature / \
                            num_feat.masked_fill(num_feat == 0, 1)

        # concatenate position information
        # projected_feature: B, C+3, nlvl, bH, bW
        bev_grid = bev_grid.view(self.B, nlvl, bH, bW,
                                 3).permute(0, 4, 1, 2, 3)
        projected_feature = torch.cat(
            (projected_feature, bev_grid), dim=1)

        return projected_feature, bev_feat_mask.sum(1) > 0

    def get_lidar_feature(self, points):
        ptensor, pmask = points
        lidar_feature = self.pp(ptensor, pmask)

        # bev_grid = self.bev_planes[...,:-1].unsqueeze(0).repeat(self.B, 1, 1, 1, 1)
        # bev_grid = bev_grid[:,0]

        # bev_grid = bev_grid.permute(0, 3, 1, 2)
        # lidar_feature = torch.cat(
        #     (lidar_feature, bev_grid), dim=1)

        lidar_feature = self.outconvs_lidar(lidar_feature)

        return lidar_feature


def construct_plane_grid(xbound, ybound, height: float, dtype=torch.float32):
    '''
        Build a regular grid of 3D points on a BEV plane at a fixed height.
        Returns:
            plane: H, W, 3 (ego-frame x, y, z)
    '''
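    # e.g. xbound=[-30.0, 30.0, 0.5], ybound=[-15.0, 15.0, 0.5], height=0.0
    # yields a (60, 120, 3) grid, i.e. (num_y, num_x, 3).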

    xmin, xmax = xbound[0], xbound[1]
    num_x = int((xbound[1] - xbound[0]) / xbound[2])
    ymin, ymax = ybound[0], ybound[1]
    num_y = int((ybound[1] - ybound[0]) / ybound[2])

    x = torch.linspace(xmin, xmax, num_x, dtype=dtype)
    y = torch.linspace(ymin, ymax, num_y, dtype=dtype)

    # [num_y, num_x]
    y, x = torch.meshgrid(y, x)

    z = torch.ones_like(x) * height

    # [num_y, num_x, 3]
    plane = torch.stack([x, y, z], dim=-1)

    return plane


def get_campos(reference_points, ego2cam, img_shape):
    '''
        Find each reference point's corresponding pixel in each camera
        Args:
            reference_points: [B, num_query, 3]
            ego2cam: (B, num_cam, 4, 4)
        Outs:
            reference_points_cam: (B*num_cam, num_query, 2)
            mask:  (B, num_cam, num_query)
            num_query == W*H
    '''
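    # Math sketch: p = ego2cam @ [x, y, z, 1]^T, where ego2cam is assumed to
    # already include the camera intrinsics (it is built from
    # img_meta['ego2img'] upstream); pixel coords are p[:2] / p[2], then
    # normalized by (W, H) and mapped to [-1, 1] for F.grid_sample.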

    ego2cam = reference_points.new_tensor(ego2cam)  # (B, N, 4, 4)
    reference_points = reference_points.clone()

    B, num_query = reference_points.shape[:2]
    num_cam = ego2cam.shape[1]

    # reference_points (B, num_queries, 4)
    reference_points = torch.cat(
        (reference_points, torch.ones_like(reference_points[..., :1])), -1)
    reference_points = reference_points.view(
        B, 1, num_query, 4).repeat(1, num_cam, 1, 1).unsqueeze(-1)

    ego2cam = ego2cam.view(
        B, num_cam, 1, 4, 4).repeat(1, 1, num_query, 1, 1)

    # reference_points_cam (B, num_cam, num_queries, 4)
    reference_points_cam = (ego2cam @ reference_points).squeeze(-1)

    eps = 1e-9
    mask = (reference_points_cam[..., 2:3] > eps)

    # perspective division; eps in the denominator guards against
    # division by zero for points on the camera plane
    reference_points_cam = reference_points_cam[..., 0:2] / \
        (reference_points_cam[..., 2:3] + eps)

    reference_points_cam[..., 0] /= img_shape[1]
    reference_points_cam[..., 1] /= img_shape[0]

    # from 0~1 to -1~1
    reference_points_cam = (reference_points_cam - 0.5) * 2

    mask = (mask & (reference_points_cam[..., 0:1] > -1.0)
            & (reference_points_cam[..., 0:1] < 1.0)
            & (reference_points_cam[..., 1:2] > -1.0)
            & (reference_points_cam[..., 1:2] < 1.0))

    # (B, num_cam, num_query)
    mask = mask.view(B, num_cam, num_query)
    reference_points_cam = reference_points_cam.view(B * num_cam, num_query, 2)

    return reference_points_cam, mask


def _test():
    # Minimal smoke test (a sketch): exercises only the config-free helpers,
    # since the full IPMEncoder needs mmdet3d backbone/neck configs.
    plane = construct_plane_grid([-30.0, 30.0, 0.5], [-15.0, 15.0, 0.5], 0.0)
    assert plane.shape == (60, 120, 3)  # (num_y, num_x, 3)
    ego2cam = np.tile(np.eye(4), (1, 6, 1, 1))  # dummy identity transforms
    pos, mask = get_campos(plane.reshape(1, -1, 3), ego2cam, (128, 352))
    assert pos.shape == (6, 60 * 120, 2) and mask.shape == (1, 6, 60 * 120)


if __name__ == '__main__':
    _test()