helper.py 8.55 KB
Newer Older
Hongzhi (Steve), Chen's avatar
Hongzhi (Steve), Chen committed
1
import dgl
esang's avatar
esang committed
2
3
4
5
6
import torch
import torch.nn as nn
import torch.nn.functional as F
from dgl.geometry import farthest_point_sampler

7
"""
esang's avatar
esang committed
8
9
Part of the code are adapted from
https://github.com/yanx27/Pointnet_Pointnet2_pytorch
10
"""
esang's avatar
esang committed
11
12
13


def square_distance(src, dst):
    """
    Compute pairwise squared Euclidean distances between two point sets.

    Adapted from https://github.com/yanx27/Pointnet_Pointnet2_pytorch

    Uses the expansion ||s - d||^2 = ||s||^2 - 2 s.d + ||d||^2.

    Parameters
    ----------
    src : (B, N, C) tensor
    dst : (B, M, C) tensor

    Returns
    -------
    (B, N, M) tensor of squared distances between every src/dst pair.
    """
    batch, n_src, _ = src.shape
    _, n_dst, _ = dst.shape
    cross = torch.bmm(src, dst.transpose(1, 2))               # (B, N, M)
    src_sq = torch.sum(src**2, -1).view(batch, n_src, 1)      # (B, N, 1)
    dst_sq = torch.sum(dst**2, -1).view(batch, 1, n_dst)      # (B, 1, M)
    # Same accumulation order as the reference implementation.
    return -2 * cross + src_sq + dst_sq


def index_points(points, idx):
    """
    Gather points by per-batch indices.

    Adapted from https://github.com/yanx27/Pointnet_Pointnet2_pytorch

    Parameters
    ----------
    points : (B, N, C) tensor
    idx : (B, ...) long tensor of indices into dim 1 of ``points``

    Returns
    -------
    (B, ..., C) tensor where entry [b, ...] is points[b, idx[b, ...], :].
    """
    batch = points.shape[0]
    # Batch index shaped (B, 1, ..., 1) then tiled to match idx exactly.
    view_shape = list(idx.shape)
    view_shape[1:] = [1] * (len(view_shape) - 1)
    tile_shape = [1] + list(idx.shape[1:])
    batch_idx = (
        torch.arange(batch, dtype=torch.long)
        .to(points.device)
        .view(view_shape)
        .repeat(tile_shape)
    )
    return points[batch_idx, idx, :]


class KNearNeighbors(nn.Module):
    """
    Find the k nearest neighbors of each sampled centroid point.
    """

    def __init__(self, n_neighbor):
        super(KNearNeighbors, self).__init__()
        self.n_neighbor = n_neighbor

    def forward(self, pos, centroids):
        """
        Adapted from https://github.com/yanx27/Pointnet_Pointnet2_pytorch

        Parameters
        ----------
        pos : (B, N, C) point positions
        centroids : (B, S) indices of sampled centroid points

        Returns
        -------
        (B, S, k) indices of the k closest points to each centroid.
        """
        centroid_pos = index_points(pos, centroids)
        dists = square_distance(centroid_pos, pos)
        order = dists.argsort(dim=-1)
        return order[:, :, : self.n_neighbor]


class KNNGraphBuilder(nn.Module):
    """
    Build a NN graph per batch element, connecting each sampled centroid to
    its nearest neighbors, then batch them into one DGL graph.
    """

    def __init__(self, n_neighbor):
        super(KNNGraphBuilder, self).__init__()
        self.n_neighbor = n_neighbor
        self.knn = KNearNeighbors(n_neighbor)

    def forward(self, pos, centroids, feat=None):
        """
        Parameters
        ----------
        pos : (B, N, C) point positions
        centroids : (B, S) indices of sampled centroids
        feat : optional (B, N, D) point features

        Returns
        -------
        A batched DGLGraph with edges neighbor -> centroid and node data
        "pos", "center" (1.0 on centroid nodes) and optionally "feat".
        """
        device = pos.device
        neighbor_idx = self.knn(pos, centroids)
        batch_size, n_points, _ = pos.shape
        graphs = []
        for b in range(batch_size):
            # 0/1 marker over the original point ids for centroid nodes.
            is_center = torch.zeros((n_points)).to(device)
            is_center[centroids[b]] = 1

            src = neighbor_idx[b].contiguous().view(-1)
            # Each centroid repeats once per incident neighbor edge.
            fanout = min(self.n_neighbor, src.shape[0] // centroids.shape[1])
            dst = centroids[b].view(-1, 1).repeat(1, fanout).view(-1)

            # Compactly relabel the union of endpoints to 0..n_uniq-1.
            combined = torch.cat([src, dst])
            node_ids, relabeled = torch.unique(combined, return_inverse=True)
            n_src = src.shape[0]

            g = dgl.graph((relabeled[:n_src], relabeled[n_src:]))
            g.ndata["pos"] = pos[b][node_ids]
            g.ndata["center"] = is_center[node_ids]
            if feat is not None:
                g.ndata["feat"] = feat[b][node_ids]
            graphs.append(g)
        return dgl.batch(graphs)


class RelativePositionMessage(nn.Module):
    """
    Compute the input feature each neighbor sends to its centroid: the
    relative position, optionally concatenated with the neighbor's feature.
    """

    def __init__(self, n_neighbor):
        super(RelativePositionMessage, self).__init__()
        self.n_neighbor = n_neighbor

    def forward(self, edges):
        """
        DGL edge UDF.

        Returns {"agg_feat": t} where t is (src pos - dst pos), with the
        source "feat" appended along dim 1 when present.
        """
        rel_pos = edges.src["pos"] - edges.dst["pos"]
        if "feat" not in edges.src:
            return {"agg_feat": rel_pos}
        return {"agg_feat": torch.cat([rel_pos, edges.src["feat"]], 1)}


class KNNConv(nn.Module):
    """
    Feature aggregation: a shared MLP (1x1 conv + BN + ReLU) applied to
    each neighbor's message, followed by max pooling over neighbors.
    """

    def __init__(self, sizes, batch_size):
        """
        Parameters
        ----------
        sizes : list of MLP channel sizes, e.g. [in_dim, hidden, out_dim]
        batch_size : number of point clouds per batch
        """
        super(KNNConv, self).__init__()
        self.batch_size = batch_size
        self.conv = nn.ModuleList()
        self.bn = nn.ModuleList()
        for i in range(1, len(sizes)):
            # Kernel-size-1 Conv2d = a linear layer shared across points.
            self.conv.append(nn.Conv2d(sizes[i - 1], sizes[i], 1))
            self.bn.append(nn.BatchNorm2d(sizes[i]))

    def forward(self, nodes):
        """
        DGL node UDF: aggregate the messages in nodes.mailbox["agg_feat"].

        The mailbox is reshaped so channels become the Conv2d input dim,
        the MLP is applied, then features are max-pooled over neighbors.
        """
        shape = nodes.mailbox["agg_feat"].shape
        h = (
            nodes.mailbox["agg_feat"]
            .view(self.batch_size, -1, shape[1], shape[2])
            .permute(0, 3, 2, 1)
        )
        for conv, bn in zip(self.conv, self.bn):
            h = F.relu(bn(conv(h)))
        # Max over the neighbor axis.
        h = torch.max(h, 2)[0]
        feat_dim = h.shape[1]
        h = h.permute(0, 2, 1).reshape(-1, feat_dim)
        return {"new_feat": h}

    def group_all(self, pos, feat):
        """
        Feature aggregation and pooling for the non-sampling layer: the
        whole point cloud is treated as a single group.

        Parameters
        ----------
        pos : (B, N, C) point positions
        feat : optional (B, N, D) point features

        Returns
        -------
        new_pos : (B, 1, C) zero tensor (single virtual centroid)
        h : (B, out_dim) max-pooled features
        """
        if feat is not None:
            h = torch.cat([pos, feat], 2)
        else:
            h = pos
        B, N, D = h.shape
        _, _, C = pos.shape
        # Fix: allocate on the input's device; the original used the
        # default device and mixed CPU/GPU tensors for GPU inputs.
        new_pos = torch.zeros(B, 1, C, device=pos.device)
        h = h.permute(0, 2, 1).view(B, -1, N, 1)
        for conv, bn in zip(self.conv, self.bn):
            h = F.relu(bn(conv(h)))
        h = torch.max(h[:, :, :, 0], 2)[0]  # [B, out_dim]
        return new_pos, h


class TransitionDown(nn.Module):
    """
    The Transition Down Module: sample centroids with farthest point
    sampling, build a k-NN graph around them, and aggregate features.
    """

    def __init__(self, n_points, batch_size, mlp_sizes, n_neighbors=64):
        super(TransitionDown, self).__init__()
        self.n_points = n_points
        self.frnn_graph = KNNGraphBuilder(n_neighbors)
        self.message = RelativePositionMessage(n_neighbors)
        self.conv = KNNConv(mlp_sizes, batch_size)
        self.batch_size = batch_size

    def forward(self, pos, feat):
        centroids = farthest_point_sampler(pos, self.n_points)
        g = self.frnn_graph(pos, centroids, feat)
        g.update_all(self.message, self.conv)

        # Keep only centroid nodes and restore the (B, S, *) layout.
        centroid_mask = g.ndata["center"] == 1
        pos_dim = g.ndata["pos"].shape[-1]
        feat_dim = g.ndata["new_feat"].shape[-1]
        sampled_pos = g.ndata["pos"][centroid_mask].view(
            self.batch_size, -1, pos_dim
        )
        sampled_feat = g.ndata["new_feat"][centroid_mask].view(
            self.batch_size, -1, feat_dim
        )
        return sampled_pos, sampled_feat


class FeaturePropagation(nn.Module):
    """
    The FeaturePropagation Layer: upsample features from a sparse point
    set back to a denser set by inverse-distance-weighted interpolation,
    then run a shared MLP over the result.
    """

    def __init__(self, input_dims, sizes):
        super(FeaturePropagation, self).__init__()
        self.convs = nn.ModuleList()
        self.bns = nn.ModuleList()

        sizes = [input_dims] + sizes
        for i in range(1, len(sizes)):
            # Kernel-size-1 Conv1d = a per-point shared linear layer.
            self.convs.append(nn.Conv1d(sizes[i - 1], sizes[i], 1))
            self.bns.append(nn.BatchNorm1d(sizes[i]))

    def forward(self, x1, x2, feat1, feat2):
        """
        Adapted from https://github.com/yanx27/Pointnet_Pointnet2_pytorch
            Input:
                x1: input points position data, [B, N, C]
                x2: sampled input points position data, [B, S, C]
                feat1: input points data, [B, N, D]
                feat2: input points data, [B, S, D]
            Return:
                new_feat: upsampled points data, [B, D', N]
        """
        B, N, C = x1.shape
        _, S, _ = x2.shape

        if S == 1:
            # A single source point: broadcast its feature to every target.
            interpolated_feat = feat2.repeat(1, N, 1)
        else:
            dists = square_distance(x1, x2)
            dists, idx = dists.sort(dim=-1)
            # Fix: interpolate from at most 3 neighbors but never more
            # than the S available source points; the original hard-coded
            # 3 and crashed when 1 < S < 3.
            k = min(3, S)
            dists, idx = dists[:, :, :k], idx[:, :, :k]  # [B, N, k]

            # Inverse-distance weights; epsilon guards division by zero.
            dist_recip = 1.0 / (dists + 1e-8)
            norm = torch.sum(dist_recip, dim=2, keepdim=True)
            weight = dist_recip / norm
            interpolated_feat = torch.sum(
                index_points(feat2, idx) * weight.view(B, N, k, 1), dim=2
            )

        if feat1 is not None:
            new_feat = torch.cat([feat1, interpolated_feat], dim=-1)
        else:
            new_feat = interpolated_feat

        new_feat = new_feat.permute(0, 2, 1)  # [B, D, N]
        for i, conv in enumerate(self.convs):
            bn = self.bns[i]
            new_feat = F.relu(bn(conv(new_feat)))
        return new_feat


class SwapAxes(nn.Module):
    """
    Exchange two tensor dimensions (default dims 1 and 2) as an nn.Module,
    so a transpose can be placed inside nn.Sequential.
    """

    def __init__(self, dim1=1, dim2=2):
        super(SwapAxes, self).__init__()
        self.dim1 = dim1
        self.dim2 = dim2

    def forward(self, x):
        """Return x with self.dim1 and self.dim2 swapped."""
        return torch.transpose(x, self.dim1, self.dim2)


class TransitionUp(nn.Module):
    """
    The Transition Up Module: project both feature sets to a common
    width, propagate the coarse features to the dense points, and sum.
    """

    def __init__(self, dim1, dim2, dim_out):
        super(TransitionUp, self).__init__()

        def make_projection(in_dim):
            # Linear -> BatchNorm1d over the channel axis (via SwapAxes)
            # -> ReLU.
            return nn.Sequential(
                nn.Linear(in_dim, dim_out),
                SwapAxes(),
                nn.BatchNorm1d(dim_out),  # TODO
                SwapAxes(),
                nn.ReLU(),
            )

        self.fc1 = make_projection(dim1)
        self.fc2 = make_projection(dim2)
        self.fp = FeaturePropagation(-1, [])

    def forward(self, pos1, feat1, pos2, feat2):
        projected1 = self.fc1(feat1)
        projected2 = self.fc2(feat2)
        upsampled = self.fp(pos2, pos1, None, projected1).transpose(1, 2)
        return upsampled + projected2