helper.py 8.55 KB
Newer Older
esang's avatar
esang committed
1
2
3
import torch
import torch.nn as nn
import torch.nn.functional as F
4

esang's avatar
esang committed
5
6
7
import dgl
from dgl.geometry import farthest_point_sampler

8
"""
esang's avatar
esang committed
9
10
Part of the code are adapted from
https://github.com/yanx27/Pointnet_Pointnet2_pytorch
11
"""
esang's avatar
esang committed
12
13
14


def square_distance(src, dst):
    """
    Compute pairwise squared Euclidean distances between two point sets.

    Uses the expansion ||s - d||^2 = ||s||^2 + ||d||^2 - 2 s.d, so the
    whole computation is one batched matmul plus two broadcast additions.

    Adapted from https://github.com/yanx27/Pointnet_Pointnet2_pytorch

    Args:
        src: (B, N, C) source points.
        dst: (B, M, C) target points.

    Returns:
        (B, N, M) tensor of squared distances.
    """
    batch, n_src, _ = src.shape
    n_dst = dst.shape[1]
    cross = torch.matmul(src, dst.transpose(1, 2))
    src_sq = (src * src).sum(dim=-1).view(batch, n_src, 1)
    dst_sq = (dst * dst).sum(dim=-1).view(batch, 1, n_dst)
    return src_sq + dst_sq - 2 * cross


def index_points(points, idx):
    """
    Gather point subsets by per-batch indices.

    Adapted from https://github.com/yanx27/Pointnet_Pointnet2_pytorch

    Args:
        points: (B, N, C) tensor of per-batch points.
        idx: (B, ...) long tensor of indices into the N dimension.

    Returns:
        Tensor of shape (B, ..., C) holding points[b, idx[b]] for each b.
    """
    num_batches = points.shape[0]
    # Broadcast the batch index to idx's shape so advanced indexing pairs
    # every sample index with its own batch row.
    lead_shape = [num_batches] + [1] * (idx.dim() - 1)
    batch_idx = (
        torch.arange(num_batches, dtype=torch.long, device=points.device)
        .view(lead_shape)
        .expand_as(idx)
    )
    return points[batch_idx, idx, :]


class KNearNeighbors(nn.Module):
    """
    Find the k nearest neighbors of each centroid in a point cloud.
    """

    def __init__(self, n_neighbor):
        super(KNearNeighbors, self).__init__()
        self.n_neighbor = n_neighbor

    def forward(self, pos, centroids):
        """
        Adapted from https://github.com/yanx27/Pointnet_Pointnet2_pytorch

        Args:
            pos: (B, N, C) point positions.
            centroids: (B, S) indices of centroid points within pos.

        Returns:
            (B, S, k) indices of each centroid's k nearest points,
            ordered by ascending squared distance.
        """
        # Distances from every centroid to every point, then keep the
        # k closest columns per centroid.
        centroid_pos = index_points(pos, centroids)
        dists = square_distance(centroid_pos, pos)
        neighbor_idx = torch.argsort(dists, dim=-1)[..., : self.n_neighbor]
        return neighbor_idx


class KNNGraphBuilder(nn.Module):
    """
    Build a k-nearest-neighbor graph per batch element, with edges
    pointing from neighbor points to their centroid, and batch them.
    """

    def __init__(self, n_neighbor):
        super(KNNGraphBuilder, self).__init__()
        self.n_neighbor = n_neighbor
        self.knn = KNearNeighbors(n_neighbor)

    def forward(self, pos, centroids, feat=None):
        """
        Args:
            pos: (B, N, C) point positions.
            centroids: (B, S) centroid indices into pos.
            feat: optional (B, N, D) per-point features.

        Returns:
            A batched DGL graph with node data "pos", "center" (1.0 for
            centroid nodes, 0.0 otherwise) and, when given, "feat".
        """
        device = pos.device
        neighbor_idx = self.knn(pos, centroids)
        batch_size, n_points, _ = pos.shape
        n_centroids = centroids.shape[1]
        graphs = []
        for b in range(batch_size):
            # Mark which original points act as centroids.
            is_center = torch.zeros(n_points).to(device)
            is_center[centroids[b]] = 1
            # Edges run neighbor -> centroid; each centroid id is
            # repeated once per neighbor it receives.
            edge_src = neighbor_idx[b].contiguous().view(-1)
            per_center = min(self.n_neighbor, edge_src.shape[0] // n_centroids)
            edge_dst = centroids[b].view(-1, 1).repeat(1, per_center).view(-1)
            # Relabel node ids compactly so the graph contains only the
            # points that participate in at least one edge.
            node_ids, compact = torch.unique(
                torch.cat([edge_src, edge_dst]), return_inverse=True
            )
            n_src = edge_src.shape[0]
            g = dgl.graph((compact[:n_src], compact[n_src:]))
            g.ndata["pos"] = pos[b][node_ids]
            g.ndata["center"] = is_center[node_ids]
            if feat is not None:
                g.ndata["feat"] = feat[b][node_ids]
            graphs.append(g)
        return dgl.batch(graphs)


class RelativePositionMessage(nn.Module):
    """
    DGL message function: each edge carries its source point's position
    relative to the destination point, concatenated with the source
    features when they are present.
    """

    def __init__(self, n_neighbor):
        super(RelativePositionMessage, self).__init__()
        self.n_neighbor = n_neighbor

    def forward(self, edges):
        """Return {"agg_feat": ...} for every edge in the batch."""
        rel_pos = edges.src["pos"] - edges.dst["pos"]
        if "feat" not in edges.src:
            msg = rel_pos
        else:
            msg = torch.cat([rel_pos, edges.src["feat"]], 1)
        return {"agg_feat": msg}
esang's avatar
esang committed
124
125
126


class KNNConv(nn.Module):
    """
    Shared-MLP feature aggregation over each centroid's neighborhood
    (the PointNet++ set-abstraction convolution), used as a DGL reduce
    function.
    """

    def __init__(self, sizes, batch_size):
        """
        Args:
            sizes: channel widths of the shared MLP; sizes[0] is the input
                feature dimension, the rest become Conv2d output channels.
            batch_size: number of graphs batched together, needed to
                un-flatten DGL's node-major mailbox layout in forward().
        """
        super(KNNConv, self).__init__()
        self.batch_size = batch_size
        self.conv = nn.ModuleList()
        self.bn = nn.ModuleList()
        for i in range(1, len(sizes)):
            # 1x1 convolutions implement a per-point shared MLP.
            self.conv.append(nn.Conv2d(sizes[i - 1], sizes[i], 1))
            self.bn.append(nn.BatchNorm2d(sizes[i]))

    def forward(self, nodes):
        """
        DGL reduce function: run the shared MLP over every neighbor
        message and max-pool across the neighborhood.

        Returns:
            {"new_feat": ...} with one pooled feature vector per node.
        """
        shape = nodes.mailbox["agg_feat"].shape
        # Rearrange to (batch, channels, n_neighbors, n_nodes) so Conv2d
        # treats channels as dim 1.
        h = (
            nodes.mailbox["agg_feat"]
            .view(self.batch_size, -1, shape[1], shape[2])
            .permute(0, 3, 2, 1)
        )
        for conv, bn in zip(self.conv, self.bn):
            h = conv(h)
            h = bn(h)
            h = F.relu(h)
        # Max-pool over the neighbor dimension.
        h = torch.max(h, 2)[0]
        feat_dim = h.shape[1]
        h = h.permute(0, 2, 1).reshape(-1, feat_dim)
        return {"new_feat": h}

    def group_all(self, pos, feat):
        """
        Feature aggregation and pooling for the non-sampling layer:
        treat the whole point cloud as one group.

        Args:
            pos: (B, N, C) point positions.
            feat: optional (B, N, D) per-point features, concatenated to
                pos along the channel dimension when given.

        Returns:
            (new_pos, h): new_pos is a (B, 1, C) zero tensor standing in
            for the single group's position; h is the (B, sizes[-1])
            max-pooled feature.
        """
        if feat is not None:
            h = torch.cat([pos, feat], 2)
        else:
            h = pos
        B, N, D = h.shape
        _, _, C = pos.shape
        # Bug fix: allocate on the input's device; the previous code
        # implicitly created this on CPU and broke with GPU inputs.
        new_pos = torch.zeros(B, 1, C, device=pos.device)
        h = h.permute(0, 2, 1).view(B, -1, N, 1)
        for conv, bn in zip(self.conv, self.bn):
            h = conv(h)
            h = bn(h)
            h = F.relu(h)
        h = torch.max(h[:, :, :, 0], 2)[0]  # [B, D]
        return new_pos, h


class TransitionDown(nn.Module):
    """
    The Transition Down Module: sample centroids with farthest point
    sampling, build a kNN graph around them, and aggregate each
    neighborhood's features into its centroid.
    """

    def __init__(self, n_points, batch_size, mlp_sizes, n_neighbors=64):
        super(TransitionDown, self).__init__()
        self.n_points = n_points
        self.frnn_graph = KNNGraphBuilder(n_neighbors)
        self.message = RelativePositionMessage(n_neighbors)
        self.conv = KNNConv(mlp_sizes, batch_size)
        self.batch_size = batch_size

    def forward(self, pos, feat):
        """
        Args:
            pos: (B, N, C) point positions.
            feat: (B, N, D) point features.

        Returns:
            (pos_res, feat_res): positions and aggregated features of the
            sampled centroids only.
        """
        centroids = farthest_point_sampler(pos, self.n_points)
        graph = self.frnn_graph(pos, centroids, feat)
        graph.update_all(self.message, self.conv)

        # Keep only the centroid nodes of the batched graph.
        centroid_mask = graph.ndata["center"] == 1
        n_pos = graph.ndata["pos"].shape[-1]
        n_feat = graph.ndata["new_feat"].shape[-1]
        pos_res = graph.ndata["pos"][centroid_mask].view(self.batch_size, -1, n_pos)
        feat_res = graph.ndata["new_feat"][centroid_mask].view(
            self.batch_size, -1, n_feat
        )
        return pos_res, feat_res


class FeaturePropagation(nn.Module):
    """
    The FeaturePropagation Layer: upsample features from a sampled point
    set back onto a denser one via inverse-distance-weighted
    interpolation, then refine with a shared MLP.
    """

    def __init__(self, input_dims, sizes):
        super(FeaturePropagation, self).__init__()
        self.convs = nn.ModuleList()
        self.bns = nn.ModuleList()

        sizes = [input_dims] + sizes
        for i in range(1, len(sizes)):
            # 1x1 Conv1d acts as a per-point linear layer.
            self.convs.append(nn.Conv1d(sizes[i - 1], sizes[i], 1))
            self.bns.append(nn.BatchNorm1d(sizes[i]))

    def forward(self, x1, x2, feat1, feat2):
        """
        Adapted from https://github.com/yanx27/Pointnet_Pointnet2_pytorch
            Input:
                x1: input points position data, [B, N, C]
                x2: sampled input points position data, [B, S, C]
                feat1: input points data, [B, N, D]
                feat2: input points data, [B, S, D]
            Return:
                new_feat: upsampled points data, [B, D', N]
        """
        B, N, _ = x1.shape
        S = x2.shape[1]

        if S == 1:
            # Only one source point: broadcast its feature everywhere.
            interpolated_feat = feat2.repeat(1, N, 1)
        else:
            # Inverse-distance weighting over the 3 nearest sampled points.
            dists = square_distance(x1, x2)
            dists, idx = dists.sort(dim=-1)
            dists, idx = dists[:, :, :3], idx[:, :, :3]  # [B, N, 3]

            dist_recip = 1.0 / (dists + 1e-8)  # epsilon guards divide-by-zero
            norm = torch.sum(dist_recip, dim=2, keepdim=True)
            weight = dist_recip / norm
            interpolated_feat = torch.sum(
                index_points(feat2, idx) * weight.view(B, N, 3, 1), dim=2
            )

        new_feat = (
            interpolated_feat
            if feat1 is None
            else torch.cat([feat1, interpolated_feat], dim=-1)
        )

        new_feat = new_feat.permute(0, 2, 1)  # channels-first for Conv1d
        for conv, bn in zip(self.convs, self.bns):
            new_feat = F.relu(bn(conv(new_feat)))
        return new_feat


class SwapAxes(nn.Module):
    """Module wrapper around torch.transpose: swaps two tensor dims."""

    def __init__(self, dim1=1, dim2=2):
        super(SwapAxes, self).__init__()
        self.dim1 = dim1
        self.dim2 = dim2

    def forward(self, x):
        """Return x with dims self.dim1 and self.dim2 exchanged (a view)."""
        return torch.transpose(x, self.dim1, self.dim2)


class TransitionUp(nn.Module):
    """
    The Transition Up Module: project both feature sets to a common
    width, propagate the coarse features onto the dense point set, and
    sum the two branches.
    """

    def __init__(self, dim1, dim2, dim_out):
        super(TransitionUp, self).__init__()
        # Linear -> BatchNorm1d -> ReLU applied per point; the SwapAxes
        # pair moves channels into dim 1 for BatchNorm1d and back.
        self.fc1 = nn.Sequential(
            nn.Linear(dim1, dim_out),
            SwapAxes(),
            nn.BatchNorm1d(dim_out),  # TODO
            SwapAxes(),
            nn.ReLU(),
        )
        self.fc2 = nn.Sequential(
            nn.Linear(dim2, dim_out),
            SwapAxes(),
            nn.BatchNorm1d(dim_out),  # TODO
            SwapAxes(),
            nn.ReLU(),
        )
        # Pure interpolation: no MLP inside the propagation itself.
        self.fp = FeaturePropagation(-1, [])

    def forward(self, pos1, feat1, pos2, feat2):
        """Upsample feat1 (at pos1) onto pos2 and add the feat2 branch."""
        coarse = self.fc1(feat1)
        skip = self.fc2(feat2)
        propagated = self.fp(pos2, pos1, None, coarse).transpose(1, 2)
        return propagated + skip