"""
Supervised Community Detection with Hierarchical Graph Neural Networks
https://arxiv.org/abs/1705.08415

Deviations from paper:
- Pm Pd
"""


import copy
import itertools
GaiYu0's avatar
GaiYu0 committed
12
13
import dgl
import dgl.function as fn
14
15
16
import networkx as nx
import torch as th
import torch.nn as nn
GaiYu0's avatar
GaiYu0 committed
17
import torch.nn.functional as F
18
19


GaiYu0's avatar
GaiYu0 committed
20
class GNNModule(nn.Module):
    """One layer that jointly updates two coupled representations.

    ``x`` is used as the node representation of ``g``; ``y`` is used both as
    the edge representation of ``g`` and as the node representation of
    ``lg`` (presumably the line graph of ``g`` -- confirm against caller).

    Parameters
    ----------
    in_feats : int
        Input width of every linear sub-layer.
    out_feats : int
        Width of the returned ``x`` and ``y``.  Each linear sub-layer
        produces ``2 * out_feats`` columns which ``forward`` folds back to
        ``out_feats``.
    radius : int
        Number of multi-hop aggregation terms; one linear layer is created
        per term for each of the two representations.
    """

    def __init__(self, in_feats, out_feats, radius):
        super().__init__()
        self.out_feats = out_feats
        self.radius = radius

        def new_linear():
            # Twice the output width: the first half is kept linear and the
            # second half goes through a ReLU before the halves are summed.
            return nn.Linear(in_feats, out_feats * 2)

        def new_module_list():
            return nn.ModuleList([new_linear() for _ in range(radius)])

        # Sub-layers are created and registered in a fixed order so that
        # seeded initialisation and state_dict layout stay stable.
        self.theta_x, self.theta_deg, self.theta_y = \
            new_linear(), new_linear(), new_linear()
        self.theta_list = new_module_list()

        self.gamma_y, self.gamma_deg, self.gamma_x = \
            new_linear(), new_linear(), new_linear()
        self.gamma_list = new_module_list()

        self.bn_x = nn.BatchNorm1d(out_feats)
        self.bn_y = nn.BatchNorm1d(out_feats)

    def _linear_plus_relu(self, h):
        """Add the first ``out_feats`` columns of ``h`` to the ReLU of the rest."""
        return h[:, :self.out_feats] + F.relu(h[:, self.out_feats:])

    def aggregate(self, g, z):
        """Collect the node representation of ``g`` after 1, 2, 4, ... rounds.

        ``z`` is written to ``g`` as its node representation (overwriting
        whatever was stored there) and ``g.update_all`` is applied
        repeatedly.  Entry ``k`` of the result is the representation read
        back after ``2 ** k`` rounds in total.

        NOTE(review): relies on the early DGL API, where
        ``update_all(fn.copy_src(), fn.sum())`` presumably replaces each
        node's representation with the sum over its in-neighbours --
        confirm against the installed DGL version.

        Returns
        -------
        list
            ``max(self.radius, 1)`` representations, one per hop count.
        """
        z_list = []
        g.set_n_repr(z)
        g.update_all(fn.copy_src(), fn.sum())
        z_list.append(g.get_n_repr())
        for i in range(self.radius - 1):
            # 2 ** i further rounds double the total hop count.
            for _ in range(2 ** i):
                g.update_all(fn.copy_src(), fn.sum())
            z_list.append(g.get_n_repr())
        return z_list

    def forward(self, g, lg, x, y, deg_g, deg_lg, eid2nid):
        """Compute the next ``(x, y)`` pair.

        Parameters
        ----------
        g, lg
            Graph objects exposing ``set_n_repr`` / ``get_n_repr`` /
            ``update_all`` (and ``set_e_repr`` for ``g``).  Their stored
            representations are overwritten as a side effect.
        x, y
            Current representations; both are fed to linear layers of
            input width ``in_feats``.
        deg_g, deg_lg
            Multiplied element-wise with ``x`` and ``y`` respectively
            (presumably per-node degrees that broadcast over the feature
            axis -- confirm).
        eid2nid
            Index tensor used to look up rows of ``x`` through
            ``F.embedding`` (presumably edge id -> node id -- confirm).

        Returns
        -------
        tuple
            The updated ``(x, y)``, each with ``out_feats`` columns.
        """
        # Rows of the *incoming* x selected per index; consumed by the y update.
        xy = F.embedding(eid2nid, x)

        x_list = [theta(z) for theta, z in zip(self.theta_list, self.aggregate(g, x))]

        # Reduce y from the edges of g onto its nodes.
        g.set_e_repr(y)
        g.update_all(fn.copy_edge(), fn.sum())
        yx = g.get_n_repr()

        x = self.theta_x(x) + self.theta_deg(deg_g * x) + sum(x_list) + self.theta_y(yx)
        x = self.bn_x(self._linear_plus_relu(x))

        y_list = [gamma(z) for gamma, z in zip(self.gamma_list, self.aggregate(lg, y))]
        lg.set_n_repr(xy)
        lg.update_all(fn.copy_src(), fn.sum())
        xy = lg.get_n_repr()
        y = self.gamma_y(y) + self.gamma_deg(deg_lg * y) + sum(y_list) + self.gamma_x(xy)
        y = self.bn_y(self._linear_plus_relu(y))

        return x, y
71
72
73


class GNN(nn.Module):
    """Stack of ``GNNModule`` layers followed by a linear output layer."""

    def __init__(self, feats, radius, n_classes):
        """
        Parameters
        ----------
        feats : sequence of int
            Feature widths.  Each consecutive pair ``(feats[i], feats[i + 1])``
            becomes the ``(in_feats, out_feats)`` of one ``GNNModule``;
            ``feats[-1]`` is the input width of the final linear layer.
        radius : int
            Passed unchanged to every ``GNNModule``.
        n_classes : int
            Output width of the final linear layer.
        """
        super().__init__()
        self.linear = nn.Linear(feats[-1], n_classes)
        self.module_list = nn.ModuleList([GNNModule(m, n, radius)
                                          for m, n in zip(feats[:-1], feats[1:])])

    def forward(self, g, lg, deg_g, deg_lg, eid2nid):
        """Run all layers and project the final ``x`` to ``n_classes`` columns.

        Parameters
        ----------
        g, lg, eid2nid
            Forwarded unchanged to every ``GNNModule``.
        deg_g, deg_lg
            Float tensors (presumably the node degrees of ``g`` and ``lg``
            -- confirm against caller).  Their column-normalised copies
            are the initial ``x`` and ``y``; the raw tensors are also
            passed to every layer.

        Returns
        -------
        torch.Tensor
            ``self.linear`` applied to the ``x`` produced by the last layer.
        """
        def normalize(x):
            # Centre each column and scale it to unit root-mean-square.
            x = x - th.mean(x, 0)
            scale = th.sqrt(th.mean(x * x, 0))
            # A constant column (e.g. every degree equal) has zero RMS;
            # dividing by it would give NaN, so leave such columns at zero.
            scale = th.where(scale > 0, scale, th.ones_like(scale))
            return x / scale

        x = normalize(deg_g)
        y = normalize(deg_lg)
        for module in self.module_list:
            x, y = module(g, lg, x, y, deg_g, deg_lg, eid2nid)
        return self.linear(x)