# pointnet2.py
1
import numpy as np
2
3
4
5
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.autograd import Variable
6

7
8
import dgl
import dgl.function as fn
9
10
11
from dgl.geometry import (
    farthest_point_sampler,
)  # dgl.geometry.pytorch -> dgl.geometry
12

13
"""
Parts of the code are adapted from
https://github.com/yanx27/Pointnet_Pointnet2_pytorch
"""

18
19

def square_distance(src, dst):
    """
    Pairwise squared Euclidean distance between two batched point sets.

    Adapted from https://github.com/yanx27/Pointnet_Pointnet2_pytorch

    Uses the expansion ||a - b||^2 = ||a||^2 - 2 <a, b> + ||b||^2, so entries
    may be slightly negative because of floating-point cancellation.

    Args:
        src: tensor of shape [B, N, C].
        dst: tensor of shape [B, M, C].

    Returns:
        Tensor of shape [B, N, M]; entry [b, n, m] is the squared distance
        between src[b, n] and dst[b, m].
    """
    batch, n_src, _ = src.shape
    _, n_dst, _ = dst.shape
    # Inner products of every (src, dst) pair in one batched matmul.
    cross_term = torch.matmul(src, dst.transpose(1, 2))
    src_norms = (src ** 2).sum(dim=-1).view(batch, n_src, 1)
    dst_norms = (dst ** 2).sum(dim=-1).view(batch, 1, n_dst)
    return -2 * cross_term + src_norms + dst_norms

30

31
def index_points(points, idx):
    """
    Gather points per batch element using integer indices.

    Adapted from https://github.com/yanx27/Pointnet_Pointnet2_pytorch

    Args:
        points: tensor of shape [B, N, C].
        idx: integer tensor whose first axis is the batch, e.g. [B, S] or
            [B, S, K]; every entry indexes the N axis of ``points``.

    Returns:
        Tensor of shape ``idx.shape + (C,)`` with
        ``out[b, ...] = points[b, idx[b, ...], :]``.
    """
    # Batch ids shaped [B, 1, ..., 1], broadcast to idx's full shape so they
    # pair element-wise with idx in the advanced-indexing expression below.
    batch_shape = [idx.shape[0]] + [1] * (idx.dim() - 1)
    batch_ids = (
        torch.arange(points.shape[0], dtype=torch.long, device=points.device)
        .view(batch_shape)
        .expand_as(idx)
    )
    return points[batch_ids, idx, :]

50

51
class FixedRadiusNearNeighbors(nn.Module):
    """
    Ball query: for every centroid, select up to ``n_neighbor`` point
    indices that lie within ``radius`` of it.
    """

    def __init__(self, radius, n_neighbor):
        super(FixedRadiusNearNeighbors, self).__init__()
        self.radius = radius
        self.n_neighbor = n_neighbor

    def forward(self, pos, centroids):
        """
        Adapted from https://github.com/yanx27/Pointnet_Pointnet2_pytorch

        Args:
            pos: point positions, [B, N, C].
            centroids: indices of the query points into ``pos``, [B, S].

        Returns:
            Long tensor [B, S, n_neighbor] of point indices. The smallest
            in-radius indices are kept; slots left empty because fewer than
            ``n_neighbor`` points are in range are filled with the first
            (smallest) selected index.
        """
        batch, n_points, _ = pos.shape
        center_pos = index_points(pos, centroids)
        _, n_centers, _ = center_pos.shape

        candidates = (
            torch.arange(n_points, dtype=torch.long, device=pos.device)
            .view(1, 1, n_points)
            .repeat([batch, n_centers, 1])
        )
        # Out-of-ball points get the sentinel value N so they sort last.
        out_of_range = square_distance(center_pos, pos) > self.radius ** 2
        candidates[out_of_range] = n_points
        candidates = torch.sort(candidates, dim=-1).values[:, :, : self.n_neighbor]

        # Replace remaining sentinels with the first index of each group.
        first = candidates[:, :, :1].repeat([1, 1, self.n_neighbor])
        is_padding = candidates == n_points
        candidates[is_padding] = first[is_padding]
        return candidates

85

86
class FixedRadiusNNGraph(nn.Module):
    """
    Build one DGL neighbourhood graph per batch element and batch them.

    In each sample, every centroid receives one edge from each of its
    ``n_neighbor`` ball-query neighbours (duplicate neighbours give duplicate
    edges). Nodes are relabelled to 0..K-1 over the points that occur in any
    edge, in ascending order of their original point index.
    """

    def __init__(self, radius, n_neighbor):
        super(FixedRadiusNNGraph, self).__init__()
        self.radius = radius
        self.n_neighbor = n_neighbor
        self.frnn = FixedRadiusNearNeighbors(radius, n_neighbor)

    def forward(self, pos, centroids, feat=None):
        """
        Args:
            pos: point positions, [B, N, C].
            centroids: indices of sampled centroids into ``pos``, [B, S].
            feat: optional per-point features, indexed per sample as
                ``feat[b][node_ids]``.

        Returns:
            A batched DGLGraph whose node data holds ``"pos"``, ``"center"``
            (1 for centroid nodes, 0 otherwise) and, if ``feat`` was given,
            ``"feat"``.
        """
        neighbor_idx = self.frnn(pos, centroids)
        graphs = []
        for b in range(pos.shape[0]):
            sample_feat = None if feat is None else feat[b]
            graphs.append(
                self._build_sample_graph(
                    pos[b], centroids[b], neighbor_idx[b], sample_feat
                )
            )
        return dgl.batch(graphs)

    def _build_sample_graph(
        self, sample_pos, sample_centroids, sample_neighbors, sample_feat
    ):
        """Build the neighbour -> centroid graph for a single batch element."""
        n_points = sample_pos.shape[0]
        center = torch.zeros(n_points, device=sample_pos.device)
        center[sample_centroids] = 1

        src = sample_neighbors.reshape(-1)
        # Each centroid id repeated n_neighbor times, aligned with ``src``.
        dst = sample_centroids.reshape(-1).repeat_interleave(self.n_neighbor)

        # Compact relabelling: node_ids[k] is the original point of node k.
        node_ids, relabeled = torch.unique(
            torch.cat([src, dst]), return_inverse=True
        )
        n_edges = src.shape[0]
        g = dgl.graph((relabeled[:n_edges], relabeled[n_edges:]))

        g.ndata["pos"] = sample_pos[node_ids]
        g.ndata["center"] = center[node_ids]
        if sample_feat is not None:
            g.ndata["feat"] = sample_feat[node_ids]
        return g

122

123
class RelativePositionMessage(nn.Module):
    """
    DGL message function: each edge carries its source position relative
    to the destination, with the source features appended when present.
    """

    def __init__(self, n_neighbor):
        super(RelativePositionMessage, self).__init__()
        # Kept as configuration; forward() does not read it.
        self.n_neighbor = n_neighbor

    def forward(self, edges):
        """
        Args:
            edges: edge batch exposing ``src`` / ``dst`` node-data mappings
                with a ``"pos"`` entry (``"feat"`` on ``src`` is optional).

        Returns:
            ``{"agg_feat": t}`` where ``t`` is ``src.pos - dst.pos``, with
            ``src.feat`` concatenated along dim 1 when it exists.
        """
        offset = edges.src["pos"] - edges.dst["pos"]
        message = (
            torch.cat([offset, edges.src["feat"]], 1)
            if "feat" in edges.src
            else offset
        )
        return {"agg_feat": message}

140
141

class PointNetConv(nn.Module):
    """
    Shared point-wise MLP followed by max pooling (the PointNet layer).

    The MLP is a stack of 1x1 ``Conv2d`` -> ``BatchNorm2d`` -> ReLU stages.
    It is used either as a DGL reduce function (``forward``) or directly on
    a whole point set (``group_all``).

    Args:
        sizes: channel sizes ``[in, hidden..., out]``; one conv/bn pair is
            created for each consecutive pair of sizes.
        batch_size: number of samples per batch; ``forward`` uses it to
            split the mailbox into per-sample groups.
    """

    def __init__(self, sizes, batch_size):
        super(PointNetConv, self).__init__()
        self.batch_size = batch_size
        self.conv = nn.ModuleList()
        self.bn = nn.ModuleList()
        for c_in, c_out in zip(sizes[:-1], sizes[1:]):
            self.conv.append(nn.Conv2d(c_in, c_out, 1))
            self.bn.append(nn.BatchNorm2d(c_out))

    def _mlp(self, h):
        """Apply every conv -> bn -> ReLU stage in order."""
        for conv, bn in zip(self.conv, self.bn):
            h = F.relu(bn(conv(h)))
        return h

    def forward(self, nodes):
        """
        DGL reduce function: max-pool MLP features over each node's mailbox.

        Args:
            nodes: node batch whose ``mailbox["agg_feat"]`` has shape
                [M, K, D] (M receiving nodes, K messages each, D channels);
                M must be divisible by ``batch_size``.

        Returns:
            ``{"new_feat": t}`` with ``t`` of shape [M, D'], rows in the
            same order as the mailbox.
        """
        msgs = nodes.mailbox["agg_feat"]
        _, n_msgs, in_dim = msgs.shape
        # [M, K, D] -> [batch, M / batch, K, D] -> [batch, D, K, M / batch].
        # reshape (not view) so a non-contiguous mailbox is also accepted.
        h = msgs.reshape(self.batch_size, -1, n_msgs, in_dim).permute(0, 3, 2, 1)
        h = self._mlp(h)
        h = torch.max(h, 2)[0]  # pool over the K messages
        feat_dim = h.shape[1]
        return {"new_feat": h.permute(0, 2, 1).reshape(-1, feat_dim)}

    def group_all(self, pos, feat):
        """
        Feature aggregation and pooling for the non-sampling layer: run the
        MLP on every point and max-pool over all points of each sample.

        Args:
            pos: point positions, [B, N, C].
            feat: per-point features concatenated after ``pos`` on the last
                axis, [B, N, D - C], or None.

        Returns:
            (new_pos, h): ``new_pos`` is a zero placeholder [B, 1, C] on the
            same device and with the same dtype as ``pos``; ``h`` is the
            pooled feature [B, D'].
        """
        h = pos if feat is None else torch.cat([pos, feat], 2)
        B, N, _ = h.shape
        _, _, C = pos.shape
        # Allocate from ``pos`` so the placeholder follows the input's
        # device and dtype instead of defaulting to CPU / float32.
        new_pos = pos.new_zeros((B, 1, C))
        h = h.permute(0, 2, 1).reshape(B, -1, N, 1)  # [B, D, N, 1]
        h = self._mlp(h)
        h = torch.max(h[:, :, :, 0], 2)[0]  # [B, D']
        return new_pos, h
189

190

191
192
193
194
class SAModule(nn.Module):
    """
    The Set Abstraction Layer.

    Selects ``npoints`` centroids with ``farthest_point_sampler``, groups the
    neighbours within ``radius`` of each centroid into a DGL graph, and
    aggregates every group with a PointNet MLP. With ``group_all=True`` the
    whole point set is pooled into one feature per sample instead.
    """

    def __init__(
        self,
        npoints,
        batch_size,
        radius,
        mlp_sizes,
        n_neighbor=64,
        group_all=False,
    ):
        super(SAModule, self).__init__()
        self.group_all = group_all
        if not group_all:
            # Sampling and grouping are only used when not pooling everything.
            self.npoints = npoints
            self.frnn_graph = FixedRadiusNNGraph(radius, n_neighbor)
        self.message = RelativePositionMessage(n_neighbor)
        self.conv = PointNetConv(mlp_sizes, batch_size)
        self.batch_size = batch_size

    def forward(self, pos, feat):
        """
        Args:
            pos: point positions, [B, N, C].
            feat: per-point features, or None.

        Returns:
            (new_pos, new_feat). In grouping mode these are the centroid
            positions and aggregated features, each reshaped to
            [batch_size, -1, dim]; with ``group_all`` it is whatever
            ``PointNetConv.group_all`` returns.
        """
        if self.group_all:
            return self.conv.group_all(pos, feat)

        centroids = farthest_point_sampler(pos, self.npoints)
        graph = self.frnn_graph(pos, centroids, feat)
        graph.update_all(self.message, self.conv)
        return self._centroid_outputs(graph)

    def _centroid_outputs(self, graph):
        """Read back positions and aggregated features of centroid nodes."""
        is_center = graph.ndata["center"] == 1
        positions = graph.ndata["pos"][is_center]
        features = graph.ndata["new_feat"][is_center]
        return (
            positions.view(self.batch_size, -1, positions.shape[-1]),
            features.view(self.batch_size, -1, features.shape[-1]),
        )

229

230
231
232
233
class SAMSGModule(nn.Module):
    """
    The Set Abstraction Multi-Scale Grouping Layer.

    Samples ``npoints`` centroids once. For each scale i it then builds a
    neighbourhood graph with ``radius_list[i]`` / ``n_neighbor_list[i]`` and
    aggregates it with an MLP of sizes ``mlp_sizes_list[i]``. The per-scale
    centroid features are concatenated along the last axis.

    Raises:
        ValueError: if the per-scale lists are empty or differ in length.
    """

    def __init__(
        self, npoints, batch_size, radius_list, n_neighbor_list, mlp_sizes_list
    ):
        super(SAMSGModule, self).__init__()
        n_scales = len(radius_list)
        # Fail fast: an empty list would only break later in forward(), and a
        # length mismatch used to raise IndexError or silently drop entries.
        if n_scales == 0:
            raise ValueError("SAMSGModule needs at least one grouping scale")
        if len(n_neighbor_list) != n_scales or len(mlp_sizes_list) != n_scales:
            raise ValueError(
                "radius_list, n_neighbor_list and mlp_sizes_list must have the "
                "same length, got {}, {} and {}".format(
                    n_scales, len(n_neighbor_list), len(mlp_sizes_list)
                )
            )
        self.batch_size = batch_size
        self.group_size = n_scales

        self.npoints = npoints

        self.frnn_graph_list = nn.ModuleList()
        self.message_list = nn.ModuleList()
        self.conv_list = nn.ModuleList()
        for radius, n_neighbor, mlp_sizes in zip(
            radius_list, n_neighbor_list, mlp_sizes_list
        ):
            self.frnn_graph_list.append(FixedRadiusNNGraph(radius, n_neighbor))
            self.message_list.append(RelativePositionMessage(n_neighbor))
            self.conv_list.append(PointNetConv(mlp_sizes, batch_size))

    def forward(self, pos, feat):
        """
        Args:
            pos: point positions, [B, N, C].
            feat: per-point features, or None.

        Returns:
            (new_pos, new_feat): centroid positions reshaped to
            [batch_size, -1, C] and the per-scale centroid features
            concatenated on the last axis.
        """
        # One shared set of centroids so every scale describes the same points.
        centroids = farthest_point_sampler(pos, self.npoints)
        pos_res = None
        feat_res_list = []
        for frnn_graph, message, conv in zip(
            self.frnn_graph_list, self.message_list, self.conv_list
        ):
            g = frnn_graph(pos, centroids, feat)
            g.update_all(message, conv)
            mask = g.ndata["center"] == 1
            if pos_res is None:
                # All scales share ``centroids``; read positions from the first.
                center_pos = g.ndata["pos"][mask]
                pos_res = center_pos.view(
                    self.batch_size, -1, center_pos.shape[-1]
                )
            center_feat = g.ndata["new_feat"][mask]
            feat_res_list.append(
                center_feat.view(self.batch_size, -1, center_feat.shape[-1])
            )
        return pos_res, torch.cat(feat_res_list, 2)

277

278
279
280
281
class PointNet2FP(nn.Module):
    """
    The Feature Propagation Layer.

    Interpolates features from a sparse point set onto a denser one. Each
    dense point gets a weighted sum over its ``k`` nearest sparse points,
    weighted by inverse squared distance. Optional skip features are then
    concatenated and a Conv1d -> BatchNorm1d -> ReLU stack is applied.

    Args:
        input_dims: channels entering the first conv (skip + interpolated).
        sizes: output channels of each conv stage.
        k: number of nearest sparse points used for interpolation (default
            3). It is clamped to the number of sparse points available.

    Raises:
        ValueError: if ``k`` < 1.
    """

    def __init__(self, input_dims, sizes, k=3):
        super(PointNet2FP, self).__init__()
        if k < 1:
            raise ValueError("k must be a positive integer, got {}".format(k))
        self.k = k
        self.convs = nn.ModuleList()
        self.bns = nn.ModuleList()

        channels = [input_dims] + list(sizes)
        for c_in, c_out in zip(channels[:-1], channels[1:]):
            self.convs.append(nn.Conv1d(c_in, c_out, 1))
            self.bns.append(nn.BatchNorm1d(c_out))

    def forward(self, x1, x2, feat1, feat2):
        """
        Adapted from https://github.com/yanx27/Pointnet_Pointnet2_pytorch

        Args:
            x1: positions of the dense (target) points, [B, N, C].
            x2: positions of the sparse (source) points, [B, S, C].
            feat1: skip features of the dense points, [B, N, D1], or None.
            feat2: features of the sparse points, [B, S, D2].

        Returns:
            Upsampled features, [B, D', N].
        """
        B, N, _ = x1.shape
        S = x2.shape[1]

        if S == 1:
            # A single source point: every target point inherits its feature.
            interpolated_feat = feat2.repeat(1, N, 1)
        else:
            # Clamp to S so the layer also works with fewer than k source
            # points (e.g. S == 2).
            k = min(self.k, S)
            dists, idx = torch.topk(
                square_distance(x1, x2), k, dim=-1, largest=False
            )  # [B, N, k], ascending squared distances

            dist_recip = 1.0 / (dists + 1e-8)
            weight = dist_recip / torch.sum(dist_recip, dim=2, keepdim=True)
            interpolated_feat = torch.sum(
                index_points(feat2, idx) * weight.unsqueeze(-1), dim=2
            )

        if feat1 is not None:
            new_feat = torch.cat([feat1, interpolated_feat], dim=-1)
        else:
            new_feat = interpolated_feat

        new_feat = new_feat.permute(0, 2, 1)  # [B, D, N]
        for conv, bn in zip(self.convs, self.bns):
            new_feat = F.relu(bn(conv(new_feat)))
        return new_feat


333
class PointNet2SSGCls(nn.Module):
    """
    PointNet++ classifier with single-scale grouping (SSG).

    Three set-abstraction stages (512 centroids, 128 centroids, then global
    pooling) feed a Linear 1024 -> 512 -> 256 -> ``output_classes`` head with
    batch norm, ReLU and dropout between the hidden layers.

    Args:
        output_classes: number of output logits.
        batch_size: batch size handed to every set-abstraction stage.
        input_dims: channels per input point; the first 3 are positions and
            any remaining ones are passed along as point features.
        dropout_prob: dropout probability in the classifier head.
    """

    def __init__(
        self, output_classes, batch_size, input_dims=3, dropout_prob=0.4
    ):
        super(PointNet2SSGCls, self).__init__()
        self.input_dims = input_dims

        # Stages 2 and 3 take the previous stage's features plus the 3
        # position channels concatenated in front of them.
        self.sa_module1 = SAModule(
            512, batch_size, 0.2, [input_dims, 64, 64, 128]
        )
        self.sa_module2 = SAModule(
            128, batch_size, 0.4, [128 + 3, 128, 128, 256]
        )
        self.sa_module3 = SAModule(
            None, batch_size, None, [256 + 3, 256, 512, 1024], group_all=True
        )

        self.mlp1 = nn.Linear(1024, 512)
        self.bn1 = nn.BatchNorm1d(512)
        self.drop1 = nn.Dropout(dropout_prob)

        self.mlp2 = nn.Linear(512, 256)
        self.bn2 = nn.BatchNorm1d(256)
        self.drop2 = nn.Dropout(dropout_prob)

        self.mlp_out = nn.Linear(256, output_classes)

    def forward(self, x):
        """
        Args:
            x: input point cloud, [B, N, D]; channels beyond the first 3 are
                treated as point features.

        Returns:
            Class logits, [B, output_classes].
        """
        has_feat = x.shape[-1] > 3
        pos = x[:, :, :3] if has_feat else x
        feat = x[:, :, 3:] if has_feat else None

        pos, feat = self.sa_module1(pos, feat)
        pos, feat = self.sa_module2(pos, feat)
        _, global_feat = self.sa_module3(pos, feat)
        return self._classify(global_feat)

    def _classify(self, h):
        """(Linear -> BN -> ReLU -> Dropout) x 2, then the output Linear."""
        hidden_stages = (
            (self.mlp1, self.bn1, self.drop1),
            (self.mlp2, self.bn2, self.drop2),
        )
        for linear, norm, dropout in hidden_stages:
            h = dropout(F.relu(norm(linear(h))))
        return self.mlp_out(h)

383

384
class PointNet2MSGCls(nn.Module):
    """
    PointNet++ classifier with multi-scale grouping (MSG).

    Two multi-scale set-abstraction stages (512 and 128 centroids, three
    radii each) and a global pooling stage feed a Linear
    1024 -> 512 -> 256 -> ``output_classes`` head with batch norm, ReLU and
    dropout between the hidden layers.

    Args:
        output_classes: number of output logits.
        batch_size: batch size handed to every set-abstraction stage.
        input_dims: channels per input point; the first 3 are positions and
            any remaining ones are passed along as point features.
        dropout_prob: dropout probability in the classifier head.
    """

    def __init__(
        self, output_classes, batch_size, input_dims=3, dropout_prob=0.4
    ):
        super(PointNet2MSGCls, self).__init__()
        self.input_dims = input_dims

        self.sa_msg_module1 = SAMSGModule(
            512,
            batch_size,
            [0.1, 0.2, 0.4],
            [16, 32, 128],
            [
                [input_dims, 32, 32, 64],
                [input_dims, 64, 64, 128],
                [input_dims, 64, 96, 128],
            ],
        )
        # 320 = 64 + 128 + 128 concatenated stage-1 channels, plus 3 for pos.
        self.sa_msg_module2 = SAMSGModule(
            128,
            batch_size,
            [0.2, 0.4, 0.8],
            [32, 64, 128],
            [
                [320 + 3, 64, 64, 128],
                [320 + 3, 128, 128, 256],
                [320 + 3, 128, 128, 256],
            ],
        )
        # 640 = 128 + 256 + 256 concatenated stage-2 channels, plus 3 for pos.
        self.sa_module3 = SAModule(
            None, batch_size, None, [640 + 3, 256, 512, 1024], group_all=True
        )

        self.mlp1 = nn.Linear(1024, 512)
        self.bn1 = nn.BatchNorm1d(512)
        self.drop1 = nn.Dropout(dropout_prob)

        self.mlp2 = nn.Linear(512, 256)
        self.bn2 = nn.BatchNorm1d(256)
        self.drop2 = nn.Dropout(dropout_prob)

        self.mlp_out = nn.Linear(256, output_classes)

    def forward(self, x):
        """
        Args:
            x: input point cloud, [B, N, D]; channels beyond the first 3 are
                treated as point features.

        Returns:
            Class logits, [B, output_classes].
        """
        if x.shape[-1] <= 3:
            pos, feat = x, None
        else:
            pos, feat = x[:, :, :3], x[:, :, 3:]

        pos, feat = self.sa_msg_module1(pos, feat)
        pos, feat = self.sa_msg_module2(pos, feat)
        _, global_feat = self.sa_module3(pos, feat)
        return self._classify(global_feat)

    def _classify(self, h):
        """(Linear -> BN -> ReLU -> Dropout) x 2, then the output Linear."""
        h = self.drop1(F.relu(self.bn1(self.mlp1(h))))
        h = self.drop2(F.relu(self.bn2(self.mlp2(h))))
        return self.mlp_out(h)