gnn.py 3.62 KB
Newer Older
1
2
3
4
5
6
"""
Supervised Community Detection with Hierarchical Graph Neural Networks
https://arxiv.org/abs/1705.08415

Deviations from paper:
- Message passing is equivalent to `A^j \cdot X`, instead of `\min(1, A^j) \cdot X`.
- Pm Pd
"""


import copy
import itertools
GaiYu0's avatar
GaiYu0 committed
13
14
import dgl
import dgl.function as fn
15
16
17
import networkx as nx
import torch as th
import torch.nn as nn
GaiYu0's avatar
GaiYu0 committed
18
import torch.nn.functional as F
19
20


GaiYu0's avatar
GaiYu0 committed
21
class GNNModule(nn.Module):
    """A single layer that jointly updates graph features and line-graph features.

    Every linear map in the layer goes from ``in_feats`` to ``2 * out_feats``
    columns; ``forward`` then folds the two halves together (first half plus
    ReLU of the second half) and batch-normalizes the ``out_feats`` result.
    """

    def __init__(self, in_feats, out_feats, radius):
        """
        Parameters
        ----------
        in_feats : int
            Number of input feature columns of both ``x`` and ``y``.
        out_feats : int
            Number of output feature columns of both ``x`` and ``y``.
        radius : int
            Number of multi-hop terms; one linear map per term is created
            for each of the two branches.
        """
        super().__init__()
        self.out_feats = out_feats
        self.radius = radius

        def make_linear():
            # Twice the output width: forward() merges the two halves.
            return nn.Linear(in_feats, out_feats * 2)

        def make_linear_stack():
            return nn.ModuleList([make_linear() for _ in range(radius)])

        # NOTE: creation/registration order is kept stable on purpose; it
        # determines both parameter ordering and random initialisation.
        self.theta_x = make_linear()
        self.theta_deg = make_linear()
        self.theta_y = make_linear()
        self.theta_list = make_linear_stack()

        self.gamma_y = make_linear()
        self.gamma_deg = make_linear()
        self.gamma_x = make_linear()
        self.gamma_list = make_linear_stack()

        self.bn_x = nn.BatchNorm1d(out_feats)
        self.bn_y = nn.BatchNorm1d(out_feats)

    def aggregate(self, g, z):
        """Collect node representations of ``g`` after repeated propagation.

        ``z`` is stored as the node representation of ``g``; a snapshot of the
        node representation is then taken after 1 round of
        ``update_all(copy_src, sum)`` and again after each further batch of
        1, 2, 4, ... rounds (``radius - 1`` further batches), so snapshots
        correspond to 1, 2, 4, 8, ... cumulative rounds.

        Returns
        -------
        list
            The snapshots, in the order they were taken. The node
            representation stored on ``g`` is left at the last snapshot.
        """

        def propagate(rounds):
            for _ in range(rounds):
                g.update_all(fn.copy_src(), fn.sum(), batchable=True)
            return g.get_n_repr()

        g.set_n_repr(z)
        snapshots = [propagate(1)]
        for exponent in range(self.radius - 1):
            snapshots.append(propagate(2 ** exponent))
        return snapshots

    def forward(self, g, lg, x, y, deg_g, deg_lg, eid2nid):
        """Compute the updated ``(x, y)`` pair.

        Parameters
        ----------
        g, lg
            Graph objects exposing ``set_n_repr`` / ``get_n_repr`` /
            ``set_e_repr`` / ``update_all`` (presumably the graph and its
            line graph as DGL graphs -- confirm against the caller).
        x, y
            Feature matrices set as node representations of ``g`` and ``lg``
            respectively; ``y`` is additionally set as the edge
            representation of ``g``.
        deg_g, deg_lg
            Multiplied elementwise with ``x`` and ``y`` before the
            ``*_deg`` linear maps.
        eid2nid
            Index tensor used to gather rows of ``x``; the gathered rows are
            set as the edge representation of ``lg``.

        Returns
        -------
        tuple
            ``(x, y)`` after this layer, each with ``out_feats`` columns.
        """
        half = self.out_feats

        # Gather from the *input* x before anything is overwritten.
        gathered_x = F.embedding(eid2nid, x)

        # --- x branch ------------------------------------------------------
        x_hops = [
            linear(snapshot)
            for linear, snapshot in zip(self.theta_list, self.aggregate(g, x))
        ]
        g.set_e_repr(y)
        g.update_all(fn.copy_edge(), fn.sum(), batchable=True)
        y_at_nodes = g.get_n_repr()
        x_wide = self.theta_x(x) + self.theta_deg(deg_g * x) + sum(x_hops) + self.theta_y(y_at_nodes)
        x_out = self.bn_x(x_wide[:, :half] + F.relu(x_wide[:, half:]))

        # --- y branch (uses the input y and the x gathered above) -----------
        y_hops = [
            linear(snapshot)
            for linear, snapshot in zip(self.gamma_list, self.aggregate(lg, y))
        ]
        lg.set_e_repr(gathered_x)
        lg.update_all(fn.copy_edge(), fn.sum(), batchable=True)
        x_at_lg_nodes = lg.get_n_repr()
        y_wide = self.gamma_y(y) + self.gamma_deg(deg_lg * y) + sum(y_hops) + self.gamma_x(x_at_lg_nodes)
        y_out = self.bn_y(y_wide[:, :half] + F.relu(y_wide[:, half:]))

        return x_out, y_out
70
71
72


class GNN(nn.Module):
    """Stack of ``GNNModule`` layers followed by a linear classifier.

    Input features are the normalized degrees of the graph and of its line
    graph; both are computed once in the constructor and stored on the
    instance as plain tensors (not parameters or buffers).
    """

    def __init__(self, g, feats, radius, n_classes):
        """
        Parameters
        ----------
        g : networkx.DiGraph
            Node labels must be convertible with ``int``; they are used as
            row indices into the node feature matrix (presumably nodes are
            labelled ``0 .. N-1`` in iteration order -- confirm with caller).
        feats : sequence of int
            Feature sizes; each consecutive pair ``(feats[i], feats[i+1])``
            configures one ``GNNModule`` and ``feats[-1]`` is the input size
            of the final linear layer. The initial features have a single
            column, so ``feats[0]`` is presumably 1.
        radius : int
            Passed through to every ``GNNModule``.
        n_classes : int
            Output size of the final linear layer.
        """
        super(GNN, self).__init__()

        lg = nx.line_graph(g)

        # Degree features as (N, 1) column vectors.
        g_degrees = [degree for _, degree in g.degree]
        self.x = self.normalize(th.tensor(g_degrees, dtype=th.float).unsqueeze(1))
        lg_degrees = [degree for _, degree in lg.degree]
        self.y = self.normalize(th.tensor(lg_degrees, dtype=th.float).unsqueeze(1))

        # Each line-graph edge is ((u, v), (v, w)); keep v for every edge.
        self.eid2nid = th.tensor([int(n) for [[_, n], [_, _]] in lg.edges])

        self.g = dgl.DGLGraph(g)
        self.lg = dgl.DGLGraph(nx.convert_node_labels_to_integers(lg))

        self.linear = nn.Linear(feats[-1], n_classes)
        self.module_list = nn.ModuleList([GNNModule(m, n, radius)
                                          for m, n in zip(feats[:-1], feats[1:])])

    @staticmethod
    def normalize(x):
        """Center every column of ``x`` and scale it to unit root-mean-square.

        A column whose centered values are all zero (e.g. the degrees of a
        regular graph) is returned as zeros; dividing by its zero RMS would
        otherwise fill the column with NaN and poison every later layer.
        """
        x = x - th.mean(x, 0)
        rms = th.sqrt(th.mean(x * x, 0))
        # Only the exactly-zero case is altered, so all other inputs produce
        # the same values as a plain division.
        safe_rms = th.where(rms > 0, rms, th.ones_like(rms))
        return x / safe_rms

    def cuda(self, device=None):
        """Move the module and its plain-tensor attributes to a CUDA device.

        ``x``, ``y`` and ``eid2nid`` are ordinary attributes, so
        ``nn.Module.cuda`` would not move them on its own. Mirrors the
        ``nn.Module.cuda`` contract: accepts an optional ``device`` and
        returns ``self`` so that ``model = GNN(...).cuda()`` works.
        """
        self.x = self.x.cuda(device)
        self.y = self.y.cuda(device)
        self.eid2nid = self.eid2nid.cuda(device)
        return super(GNN, self).cuda(device)

    def forward(self):
        """Run all layers on the stored features and classify the nodes.

        The stored normalized degree features serve both as the initial
        ``x`` / ``y`` and as the ``deg_g`` / ``deg_lg`` arguments of every
        layer. Only the node branch ``x`` is fed to the final linear layer.
        """
        x, y = self.x, self.y
        for module in self.module_list:
            x, y = module(self.g, self.lg, x, y, self.x, self.y, self.eid2nid)
        return self.linear(x)