# gnn.py
"""
Supervised Community Detection with Hierarchical Graph Neural Networks
https://arxiv.org/abs/1705.08415

Deviations from paper:
- Addition of global aggregation operator.
- Message passing is equivalent to `A^j \cdot X`, instead of `\min(1, A^j) \cdot X`.
"""


# TODO self-loop?


import copy
import itertools
import dgl
import dgl.function as fn
import networkx as nx
import torch as th
import torch.nn as nn
import torch.nn.functional as F


class GNNModule(nn.Module):
    """One layer that jointly updates node features on ``g`` and on its line graph ``lg``.

    Each update sums linear projections of: the feature itself, the
    degree-scaled feature, ``radius`` multi-hop neighbourhood aggregates, and
    features pulled across from the other graph. The ``2 * out_feats``
    projection output is split in half; the first half is kept linear, the
    second half goes through a ReLU, the two halves are added, and the result
    is batch-normalised.

    Parameters
    ----------
    in_feats : int
        Width of the incoming features.
    out_feats : int
        Width of the outgoing features.
    radius : int
        Number of neighbourhood aggregates (hops 1, 2, 4, ..., 2**(radius-1)).
    """

    def __init__(self, in_feats, out_feats, radius):
        super().__init__()
        self.out_feats = out_feats
        self.radius = radius

        def new_linear():
            # Twice the width: one half stays linear, the other is ReLU-ed in forward().
            return nn.Linear(in_feats, out_feats * 2)

        def new_module_list():
            # One projection per aggregate returned by ``aggregate``.
            return nn.ModuleList([new_linear() for _ in range(radius)])

        # Projections for the update on g.
        self.theta_x, self.theta_deg, self.theta_y = \
            new_linear(), new_linear(), new_linear()
        self.theta_list = new_module_list()

        # Projections for the update on the line graph lg.
        self.gamma_y, self.gamma_deg, self.gamma_x = \
            new_linear(), new_linear(), new_linear()
        self.gamma_list = new_module_list()

        self.bn_x = nn.BatchNorm1d(out_feats)
        self.bn_y = nn.BatchNorm1d(out_feats)

    def aggregate(self, g, z):
        """Return ``radius`` aggregates of ``z`` after 1, 2, 4, ..., 2**(radius-1) hops.

        Each hop is one copy-source / sum message-passing round over ``g``.
        Side effect: overwrites ``g``'s node representation.
        """
        g.set_n_repr(z)
        g.update_all(fn.copy_src(), fn.sum(), batchable=True)
        z_list = [g.get_n_repr()]
        for i in range(self.radius - 1):
            # 2**i more hops take the 2**i-hop aggregate to the 2**(i+1)-hop one.
            for _ in range(2 ** i):
                g.update_all(fn.copy_src(), fn.sum(), batchable=True)
            z_list.append(g.get_n_repr())
        return z_list

    def forward(self, g, lg, x, y, deg_g, deg_lg, eid2nid):
        """Return updated ``(x, y)``, each with ``out_feats`` columns.

        Parameters
        ----------
        g, lg : DGL graphs
            The graph and its line graph.
        x : Tensor
            Node features of ``g``.
        y : Tensor
            Node features of ``lg`` (set as the edge representation of ``g``).
        deg_g, deg_lg : Tensor
            Degree features of ``g`` / ``lg``, multiplied elementwise with ``x`` / ``y``.
        eid2nid : LongTensor
            For every edge of ``lg``, the row of ``x`` to gather for it.
        """
        # Gathered from the *input* x, before x is overwritten below.
        xy = F.embedding(eid2nid, x)

        x_list = [theta(z) for theta, z in zip(self.theta_list, self.aggregate(g, x))]
        # Aggregate the y features onto g's nodes via copy-edge / sum message passing.
        g.set_e_repr(y)
        g.update_all(fn.copy_edge(), fn.sum(), batchable=True)
        yx = g.get_n_repr()
        x = self.theta_x(x) + self.theta_deg(deg_g * x) + sum(x_list) + self.theta_y(yx)
        x = self.bn_x(x[:, :self.out_feats] + F.relu(x[:, self.out_feats:]))

        y_list = [gamma(z) for gamma, z in zip(self.gamma_list, self.aggregate(lg, y))]
        # Aggregate the gathered x features onto lg's nodes the same way.
        lg.set_e_repr(xy)
        lg.update_all(fn.copy_edge(), fn.sum(), batchable=True)
        xy = lg.get_n_repr()
        y = self.gamma_y(y) + self.gamma_deg(deg_lg * y) + sum(y_list) + self.gamma_x(xy)
        y = self.bn_y(y[:, :self.out_feats] + F.relu(y[:, self.out_feats:]))

        return x, y


class GNN(nn.Module):
    """Stack of ``GNNModule`` layers followed by a linear per-node classifier.

    The graph, its line graph and the degree-based input features are built
    once in ``__init__``; ``forward`` takes no arguments.
    """

    def __init__(self, g, feats, radius, n_classes):
        """
        Parameters
        ----------
        g : networkx.DiGraph
            Input graph. NOTE(review): node ids appear to be assumed 0..n-1 in
            iteration order, so that ``g.degree`` rows line up with
            ``dgl.DGLGraph(g)`` node ids; confirm with callers.
        feats : sequence of int
            Feature widths ``[in, hidden_1, ..., out]``. Consecutive pairs define
            the ``GNNModule`` layers. ``feats[0]`` must be 1 because the input
            features are single-column normalised degrees.
        radius : int
            Aggregation radius passed to every ``GNNModule``.
        n_classes : int
            Output width of the final linear layer.
        """
        super().__init__()

        lg = nx.line_graph(g)
        # Input features: standardised degrees of g and of its line graph.
        x = list(zip(*g.degree))[1]
        self.x = self.normalize(th.tensor(x, dtype=th.float).unsqueeze(1))
        y = list(zip(*lg.degree))[1]
        self.y = self.normalize(th.tensor(y, dtype=th.float).unsqueeze(1))
        # A line-graph edge ((u, v), (v, w)) passes through node v of g; keep v.
        # Explicit long dtype keeps this a valid embedding index even when empty.
        self.eid2nid = th.tensor([n for [[_, n], _] in lg.edges], dtype=th.long)

        self.g = dgl.DGLGraph(g)
        self.lg = dgl.DGLGraph(nx.convert_node_labels_to_integers(lg))

        self.linear = nn.Linear(feats[-1], n_classes)
        self.module_list = nn.ModuleList([GNNModule(m, n, radius)
                                          for m, n in zip(feats[:-1], feats[1:])])

    @staticmethod
    def normalize(x, eps=1e-12):
        """Shift each column of ``x`` to zero mean and scale it to unit RMS.

        ``eps`` floors the per-column scale, so a constant column (for example,
        all degrees equal) maps to zeros instead of NaN.
        """
        x = x - th.mean(x, 0)
        scale = th.sqrt(th.mean(x * x, 0)).clamp(min=eps)
        return x / scale

    def cuda(self, device=None):
        """Move parameters and the cached input tensors to the GPU, then return ``self``.

        ``x``, ``y`` and ``eid2nid`` are plain attributes rather than
        parameters/buffers, so they are moved explicitly. Returning the module
        matches ``nn.Module.cuda`` and keeps ``model = model.cuda()`` working.
        """
        self.x = self.x.cuda(device)
        self.y = self.y.cuda(device)
        self.eid2nid = self.eid2nid.cuda(device)
        return super().cuda(device)

    def forward(self):
        """Run every layer and return class scores for each node of ``g``."""
        x, y = self.x, self.y
        for module in self.module_list:
            # The normalised degrees also serve as deg_g / deg_lg for every layer.
            x, y = module(self.g, self.lg, x, y, self.x, self.y, self.eid2nid)
        return self.linear(x)