import torch
import torch.nn as nn
import torch.nn.functional as F

import dgl
from dgl.geometry import farthest_point_sampler

"""
esang's avatar
esang committed
9
10
Part of the code are adapted from
https://github.com/yanx27/Pointnet_Pointnet2_pytorch
11
"""


def square_distance(src, dst):
    """
esang's avatar
esang committed
16
    Adapted from https://github.com/yanx27/Pointnet_Pointnet2_pytorch
17
    """
    B, N, _ = src.shape
    _, M, _ = dst.shape
    dist = -2 * torch.matmul(src, dst.permute(0, 2, 1))
    dist += torch.sum(src**2, -1).view(B, N, 1)
    dist += torch.sum(dst**2, -1).view(B, 1, M)
    return dist


def index_points(points, idx):
    """
esang's avatar
esang committed
28
    Adapted from https://github.com/yanx27/Pointnet_Pointnet2_pytorch
29
    """
    device = points.device
    B = points.shape[0]
    view_shape = list(idx.shape)
    view_shape[1:] = [1] * (len(view_shape) - 1)
    repeat_shape = list(idx.shape)
    repeat_shape[0] = 1
    batch_indices = (
        torch.arange(B, dtype=torch.long)
        .to(device)
        .view(view_shape)
        .repeat(repeat_shape)
    )
    new_points = points[batch_indices, idx, :]
    return new_points


class KNearNeighbors(nn.Module):
    """
esang's avatar
esang committed
48
    Find the k nearest neighbors
49
    """

    def __init__(self, n_neighbor):
        super(KNearNeighbors, self).__init__()
        self.n_neighbor = n_neighbor

    def forward(self, pos, centroids):
        """
esang's avatar
esang committed
57
        Adapted from https://github.com/yanx27/Pointnet_Pointnet2_pytorch
58
        """
        center_pos = index_points(pos, centroids)
        sqrdists = square_distance(center_pos, pos)
        group_idx = sqrdists.argsort(dim=-1)[:, :, : self.n_neighbor]
        return group_idx


class KNNGraphBuilder(nn.Module):
    """
esang's avatar
esang committed
67
    Build NN graph
68
    """

    def __init__(self, n_neighbor):
        super(KNNGraphBuilder, self).__init__()
        self.n_neighbor = n_neighbor
        self.knn = KNearNeighbors(n_neighbor)

    def forward(self, pos, centroids, feat=None):
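        # For every sample in the batch, build a small graph whose edges point
        # from the k nearest neighbors (src) to their centroid (dst), then batch
        # the per-sample graphs into a single DGLGraph.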
        dev = pos.device
        group_idx = self.knn(pos, centroids)
        B, N, _ = pos.shape
        glist = []
        for i in range(B):
            center = torch.zeros(N, device=dev)
            center[centroids[i]] = 1
            src = group_idx[i].contiguous().view(-1)
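            # Repeat each centroid index once per neighbor so that src and dst
            # line up element-wise as (neighbor -> centroid) edge pairs.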
            dst = (
                centroids[i]
                .view(-1, 1)
                .repeat(
                    1, min(self.n_neighbor, src.shape[0] // centroids.shape[1])
                )
                .view(-1)
            )

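            # Relabel node ids compactly: torch.unique returns the sorted set of
            # used ids (uniq) and, via return_inverse, each original id's position
            # in that set, which becomes the node id inside the per-sample graph.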
            unified = torch.cat([src, dst])
            uniq, inv_idx = torch.unique(unified, return_inverse=True)
            src_idx = inv_idx[: src.shape[0]]
            dst_idx = inv_idx[src.shape[0] :]

            g = dgl.graph((src_idx, dst_idx))
            g.ndata["pos"] = pos[i][uniq]
            g.ndata["center"] = center[uniq]
            if feat is not None:
                g.ndata["feat"] = feat[i][uniq]
            glist.append(g)
        bg = dgl.batch(glist)
        return bg


class KNNMessage(nn.Module):
    """
esang's avatar
esang committed
110
    Compute the input feature from neighbors
111
    """

    def __init__(self, n_neighbor):
        super(KNNMessage, self).__init__()
        self.n_neighbor = n_neighbor

    def forward(self, edges):
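        # DGL edge UDF: edges.src / edges.dst hold the data of the neighbor and
        # the centroid endpoint of every edge, respectively.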
        norm = edges.src["pos"] - edges.dst["pos"]
        if "feat" in edges.src:
            res = torch.cat([norm, edges.src["feat"]], 1)
        else:
            res = norm
        return {"agg_feat": res}


class KNNConv(nn.Module):
    """
esang's avatar
esang committed
128
    Feature aggregation
129
    """

    def __init__(self, sizes):
        super(KNNConv, self).__init__()
        self.conv = nn.ModuleList()
        self.bn = nn.ModuleList()
        for i in range(1, len(sizes)):
            self.conv.append(nn.Conv2d(sizes[i - 1], sizes[i], 1))
            self.bn.append(nn.BatchNorm2d(sizes[i]))

    def forward(self, nodes):
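        # DGL node UDF (reduce): nodes.mailbox["agg_feat"] has shape
        # (num_nodes, num_neighbors, msg_dim). Reshape it to
        # (num_nodes, msg_dim, num_neighbors, 1) so the 1x1 Conv2d/BatchNorm2d
        # stack acts as a shared per-neighbor MLP, then max-pool over neighbors.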
        shape = nodes.mailbox["agg_feat"].shape
        h = (
            nodes.mailbox["agg_feat"]
            .view(shape[0], -1, shape[1], shape[2])
            .permute(0, 3, 2, 1)
        )
        for conv, bn in zip(self.conv, self.bn):
            h = conv(h)
            h = bn(h)
            h = F.relu(h)
        h = torch.max(h, 2)[0]
        feat_dim = h.shape[1]
        h = h.permute(0, 2, 1).reshape(-1, feat_dim)
        return {"new_feat": h}


class TransitionDown(nn.Module):
    """
    The Transition Down Module
    """

    def __init__(self, in_channels, out_channels, n_neighbor=64):
        super(TransitionDown, self).__init__()
        self.frnn_graph = KNNGraphBuilder(n_neighbor)
        self.message = KNNMessage(n_neighbor)
        self.conv = KNNConv([in_channels, out_channels, out_channels])

    def forward(self, pos, feat, n_point):
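        # 1. Sample n_point centroids with farthest point sampling.
        # 2. Build a batched k-NN graph around the centroids.
        # 3. Run message passing so each centroid aggregates its neighborhood.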
        batch_size = pos.shape[0]
        centroids = farthest_point_sampler(pos, n_point)
        g = self.frnn_graph(pos, centroids, feat)
        g.update_all(self.message, self.conv)

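        # Keep only the centroid nodes and fold them back into dense
        # (batch_size, n_point, C) tensors.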
        mask = g.ndata["center"] == 1
        pos_dim = g.ndata["pos"].shape[-1]
        feat_dim = g.ndata["new_feat"].shape[-1]
        pos_res = g.ndata["pos"][mask].view(batch_size, -1, pos_dim)
        feat_res = g.ndata["new_feat"][mask].view(batch_size, -1, feat_dim)
        return pos_res, feat_res
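

# A minimal smoke-test sketch of how these helpers fit together. The shapes and
# channel sizes below are illustrative assumptions, not values taken from the
# original training script: with KNNMessage emitting [relative position (3) ||
# neighbor feature], in_channels must equal feat_dim + 3.
if __name__ == "__main__":
    B, N, n_point = 2, 1024, 256
    pos = torch.rand(B, N, 3)
    feat = torch.rand(B, N, 3)  # e.g. raw xyz reused as the input feature
    td = TransitionDown(in_channels=6, out_channels=32, n_neighbor=16)
    new_pos, new_feat = td(pos, feat, n_point)
    print(new_pos.shape)   # expected: torch.Size([2, 256, 3])
    print(new_feat.shape)  # expected: torch.Size([2, 256, 32])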