basic.py 7.57 KB
Newer Older
Hongzhi (Steve), Chen's avatar
Hongzhi (Steve), Chen committed
1
import dgl
彭卓清's avatar
彭卓清 committed
2
3
import torch
import torch.nn as nn
Hongzhi (Steve), Chen's avatar
Hongzhi (Steve), Chen committed
4
import torch.nn.functional as F
彭卓清's avatar
彭卓清 committed
5
6
from torch.autograd import Function
from torch.nn import Parameter
Hongzhi (Steve), Chen's avatar
Hongzhi (Steve), Chen committed
7
from torch.nn.modules.utils import _single
彭卓清's avatar
彭卓清 committed
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36


class BinaryQuantize(Function):
    """
    Sign binarization with a clipped straight-through estimator.

    Forward returns ``torch.sign(input)`` (exact zeros stay 0).
    Backward passes the incoming gradient through unchanged where
    ``-1 <= input <= 1`` and zeroes it where ``|input| > 1``.
    """

    @staticmethod
    def forward(ctx, input):
        ctx.save_for_backward(input)
        return torch.sign(input)

    @staticmethod
    def backward(ctx, grad_output):
        (input,) = ctx.saved_tensors
        # Out-of-place: grad_output may be shared with other graph consumers
        # (or be a tensor the caller passed to .backward()), so it must not be
        # modified in place.
        return grad_output.masked_fill(input.abs() > 1, 0)


class BiLinearLSR(torch.nn.Linear):
    """
    Binarized linear layer with a learned scalar weight scale.

    Weights are mean-centred and sign-binarized, then multiplied by
    ``self.scale``. Inputs are sign-binarized as well when ``binary_act`` is
    true. ``self.scale`` starts at 0.0 and is initialised from the first batch
    seen by ``forward`` (see ``reset_scale``).
    """

    def __init__(self, in_features, out_features, bias=False, binary_act=True):
        super(BiLinearLSR, self).__init__(in_features, out_features, bias=bias)
        self.binary_act = binary_act

        # must register a nn.Parameter placeholder for model loading
        # self.register_parameter('scale', None) doesn't register None into state_dict
        # so it leads to unexpected key error when loading saved model
        # hence, init scale with Parameter
        # however, Parameter(None) actually has size [0], not [] as a scalar
        # hence, init it using the following trick
        self.register_parameter(
            "scale", Parameter(torch.Tensor([0.0]).squeeze())
        )

    def reset_scale(self, input):
        """
        Initialise ``self.scale`` from a batch of inputs.

        The scale is the ratio between the std of the full-precision output and
        the std of the fully binarized output. If that ratio is NaN (e.g. the
        input is all zeros, giving 0 / 0), a weight-only ratio is used instead.

        The existing Parameter is updated in place rather than replaced, so an
        optimizer that already holds ``self.scale`` keeps training it.
        """
        with torch.no_grad():
            bw = self.weight - self.weight.mean()
            ba = input
            scale = (
                F.linear(ba, bw).std()
                / F.linear(torch.sign(ba), torch.sign(bw)).std()
            )
            # corner case when ba is all 0.0
            if torch.isnan(scale):
                scale = bw.std() / torch.sign(bw).std()
            # copy_ converts to the parameter's device/dtype as needed.
            self.scale.copy_(scale)

    def forward(self, input):
        # A scale of exactly 0.0 means "not initialised yet" (see __init__).
        if self.scale.item() == 0.0:
            self.reset_scale(input)

        bw = BinaryQuantize.apply(self.weight - self.weight.mean())
        bw = bw * self.scale
        ba = BinaryQuantize.apply(input) if self.binary_act else input
        # self.bias is None unless the layer was built with bias=True.
        return F.linear(ba, bw, self.bias)


class BiLinear(torch.nn.Linear):
    """
    Linear layer with sign-binarized weights and, optionally, activations.

    ``bias`` now controls whether a bias is created and applied. Previously
    it was ignored and a bias was always created, so checkpoints from
    ``bias=False`` models contain a "bias" entry.
    """

    def __init__(self, in_features, out_features, bias=True, binary_act=True):
        super(BiLinear, self).__init__(in_features, out_features, bias=bias)
        self.binary_act = binary_act
        # Last output of forward(); presumably read by code elsewhere — confirm.
        self.output_ = None

    def forward(self, input):
        bw = BinaryQuantize.apply(self.weight)
        ba = BinaryQuantize.apply(input) if self.binary_act else input
        output = F.linear(ba, bw, self.bias)
        self.output_ = output
        return output


class BiConv2d(torch.nn.Conv2d):
    """
    2-D convolution with sign-binarized weights and activations.

    Weights are mean-centred before binarization. Activations are always
    binarized. Non-``"zeros"`` padding modes are applied explicitly with
    ``F.pad`` before a zero-padding convolution, mirroring ``nn.Conv2d``.
    """

    def __init__(
        self,
        in_channels,
        out_channels,
        kernel_size,
        stride=1,
        padding=0,
        dilation=1,
        groups=1,
        bias=True,
        padding_mode="zeros",
    ):
        super(BiConv2d, self).__init__(
            in_channels,
            out_channels,
            kernel_size,
            stride=stride,
            padding=padding,
            dilation=dilation,
            groups=groups,
            bias=bias,
            padding_mode=padding_mode,
        )

    def forward(self, input):
        bw = BinaryQuantize.apply(self.weight - self.weight.mean())
        ba = BinaryQuantize.apply(input)

        if self.padding_mode != "zeros":
            # nn.Conv2d keeps padding as (pad_h, pad_w), applied on both sides.
            # F.pad takes (left, right, top, bottom) for the last two dims.
            # Assumes integer padding; string padding ("same"/"valid") is not
            # handled in this branch.
            pad_h, pad_w = self.padding
            ba = F.pad(ba, (pad_w, pad_w, pad_h, pad_h), mode=self.padding_mode)
            padding = 0
        else:
            padding = self.padding

        return F.conv2d(
            ba,
            bw,
            self.bias,
            self.stride,
            padding,
            self.dilation,
            self.groups,
        )
彭卓清's avatar
彭卓清 committed
147
148
149


def square_distance(src, dst):
    """
    Pairwise squared Euclidean distances between two batched point sets.

    Adapted from https://github.com/yanx27/Pointnet_Pointnet2_pytorch

    Expands ||s - d||^2 = ||s||^2 - 2 <s, d> + ||d||^2, so entries for
    (near-)coincident points can come out slightly negative from rounding.

    :param src: tensor unpacked as (B, N, C)
    :param dst: tensor unpacked as (B, M, C)
    :return: (B, N, M) tensor; entry [b, n, m] is the squared distance
        between src[b, n] and dst[b, m]
    """
    batch, n_src, _ = src.shape
    _, n_dst, _ = dst.shape
    cross = torch.matmul(src, dst.transpose(1, 2))
    src_norm = (src**2).sum(dim=-1).view(batch, n_src, 1)
    dst_norm = (dst**2).sum(dim=-1).view(batch, 1, n_dst)
    result = -2 * cross
    result += src_norm
    result += dst_norm
    return result

Hongzhi (Steve), Chen's avatar
Hongzhi (Steve), Chen committed
160

彭卓清's avatar
彭卓清 committed
161
def index_points(points, idx):
    """
    Gather points per batch element using an index tensor.

    Adapted from https://github.com/yanx27/Pointnet_Pointnet2_pytorch

    :param points: tensor indexed as points[b, i, :] (assumed (B, N, C))
    :param idx: long tensor whose first dim must equal B; idx[b, ...] holds
        indices into points[b]
    :return: tensor of shape idx.shape + points.shape[2:], where
        out[b, ...] = points[b, idx[b, ...], :]
    """
    n_batch = points.shape[0]
    trailing_ones = [1] * (idx.dim() - 1)
    # One batch id per row, broadcast across all of idx's trailing dims.
    batch_ids = torch.arange(n_batch, dtype=torch.long, device=points.device)
    batch_ids = batch_ids.view(idx.shape[0], *trailing_ones).expand(*idx.shape)
    return points[batch_ids, idx, :]


class FixedRadiusNearNeighbors(nn.Module):
    """
    Ball query: for every centroid, select up to ``n_neighbor`` indices of
    points lying within ``radius`` of it.
    """

    def __init__(self, radius, n_neighbor):
        super(FixedRadiusNearNeighbors, self).__init__()
        self.radius = radius
        self.n_neighbor = n_neighbor

    def forward(self, pos, centroids):
        """
        Adapted from https://github.com/yanx27/Pointnet_Pointnet2_pytorch

        :param pos: point coordinates, unpacked as (B, N, C)
        :param centroids: indices into pos per batch element (assumed (B, S))
        :return: long tensor (B, S, n_neighbor) of point indices, in ascending
            order. When fewer than n_neighbor points are in range, the
            remaining slots repeat the smallest in-range index. Assumes
            n_neighbor <= N; otherwise the sentinel masking below raises on
            mismatched shapes.
        """
        batch, n_points, _ = pos.shape
        center_pos = index_points(pos, centroids)
        _, n_centers, _ = center_pos.shape
        k = self.n_neighbor

        # Every point is a candidate for every centroid. n_points serves as an
        # "out of range" sentinel, larger than any valid index.
        candidates = torch.arange(n_points, dtype=torch.long, device=pos.device)
        candidates = candidates.view(1, 1, n_points).repeat([batch, n_centers, 1])
        too_far = square_distance(center_pos, pos) > self.radius**2
        candidates[too_far] = n_points

        # Ascending sort pushes sentinels to the end; keep the first k.
        nearest = candidates.sort(dim=-1)[0][:, :, :k]

        # Overwrite sentinel slots with the first (smallest) selected index.
        fill = nearest[:, :, 0].view(batch, n_centers, 1).repeat([1, 1, k])
        is_sentinel = nearest == n_points
        nearest[is_sentinel] = fill[is_sentinel]
        return nearest


class FixedRadiusNNGraph(nn.Module):
    """
    Build a batched DGL graph whose edges connect each centroid to the
    points returned by a fixed-radius ball query around it.
    """

    def __init__(self, radius, n_neighbor):
        super(FixedRadiusNNGraph, self).__init__()
        self.radius = radius
        self.n_neighbor = n_neighbor
        self.frnn = FixedRadiusNearNeighbors(radius, n_neighbor)

    def forward(self, pos, centroids, feat=None):
        """
        :param pos: point coordinates, unpacked as (B, N, C)
        :param centroids: per-batch centroid indices into pos
        :param feat: optional per-point features, indexed as feat[b][ids]
        :return: ``dgl.batch`` of B graphs. Each graph has node data
            "pos", "center" (1.0 for centroid nodes, else 0.0) and,
            when feat is given, "feat".
        """
        device = pos.device
        neighbor_idx = self.frnn(pos, centroids)
        batch_size, n_points, _ = pos.shape

        graphs = []
        for b in range(batch_size):
            is_center = torch.zeros(n_points, device=device)
            is_center[centroids[b]] = 1

            # Edge u -> v means "point u is a neighbor of centroid v".
            src = neighbor_idx[b].contiguous().view(-1)
            dst = centroids[b].view(-1, 1).repeat(1, self.n_neighbor).view(-1)

            # Relabel the touched point ids to a compact 0..K-1 range.
            node_ids, compact = torch.unique(
                torch.cat([src, dst]), return_inverse=True
            )
            n_src = src.shape[0]
            graph = dgl.graph((compact[:n_src], compact[n_src:]))

            graph.ndata["pos"] = pos[b][node_ids]
            graph.ndata["center"] = is_center[node_ids]
            if feat is not None:
                graph.ndata["feat"] = feat[b][node_ids]
            graphs.append(graph)

        return dgl.batch(graphs)


class RelativePositionMessage(nn.Module):
    """
    Edge message function. Each edge carries the source node's position
    relative to the destination node, followed by the source node's features
    when the source nodes have a "feat" field.
    """

    def __init__(self, n_neighbor):
        super(RelativePositionMessage, self).__init__()
        # Stored for callers; forward() does not read it.
        self.n_neighbor = n_neighbor

    def forward(self, edges):
        """
        :param edges: edge batch exposing ``src`` / ``dst`` field mappings
            with a "pos" entry (and optionally "feat" on ``src``)
        :return: {"agg_feat": message}, where message is
            src.pos - dst.pos, concatenated along dim 1 with src.feat if present
        """
        src_fields = edges.src
        rel_pos = src_fields["pos"] - edges.dst["pos"]
        if "feat" not in src_fields:
            return {"agg_feat": rel_pos}
        return {"agg_feat": torch.cat([rel_pos, src_fields["feat"]], 1)}