gnn.py 2.68 KB
Newer Older
1
2
import copy
import itertools
GaiYu0's avatar
GaiYu0 committed
3
4
import dgl
import dgl.function as fn
5
6
7
import networkx as nx
import torch as th
import torch.nn as nn
GaiYu0's avatar
GaiYu0 committed
8
import torch.nn.functional as F
9

GaiYu0's avatar
GaiYu0 committed
10
class GNNModule(nn.Module):
    """One layer of coupled message passing over a graph and its line graph.

    ``x`` holds node features of ``g`` and ``y`` holds node features of
    ``lg`` (``y`` is also written onto the edges of ``g``, so ``lg`` is
    presumably the line graph of ``g`` -- confirm with the caller).

    Each new representation is the sum of four terms:
    - a self term;
    - a degree-scaled term;
    - multi-scale neighbourhood sums (see ``aggregate``);
    - a term pulled across from the other graph.

    The sum is followed by a half-ReLU and batch normalisation.
    """

    def __init__(self, in_feats, out_feats, radius):
        super().__init__()
        self.out_feats = out_feats
        self.radius = radius

        def linear():
            return nn.Linear(in_feats, out_feats)

        def linear_stack():
            # One projection per scale returned by ``aggregate``.
            return nn.ModuleList([linear() for _ in range(radius)])

        # Construction order is deliberate: it fixes parameter registration
        # order and the order in which seeded initialisation consumes the RNG.

        # Transforms that produce the new ``x`` (node features of ``g``).
        self.theta_x = linear()
        self.theta_deg = linear()
        self.theta_y = linear()
        self.theta_list = linear_stack()

        # Transforms that produce the new ``y`` (node features of ``lg``).
        self.gamma_y = linear()
        self.gamma_deg = linear()
        self.gamma_x = linear()
        self.gamma_list = linear_stack()

        self.bn_x = nn.BatchNorm1d(out_feats)
        self.bn_y = nn.BatchNorm1d(out_feats)

    def aggregate(self, g, z):
        """Return ``z`` propagated over ``g`` at exponentially growing depths.

        One propagation step sends each node's ``'z'`` feature along its
        edges (``copy_src``) and sums the arriving messages at each node.

        Entry ``j`` of the returned list holds the result after ``2 ** j``
        cumulative steps, for ``j = 0 .. radius - 1``. At least one entry is
        always returned, even when ``radius <= 0``.

        NOTE(review): the node feature ``'z'`` is left set on ``g`` afterwards.
        """
        def propagate():
            g.update_all(fn.copy_src(src='z', out='m'), fn.sum(msg='m', out='z'))

        g.set_n_repr({'z' : z})

        # The first scale is a single step; each later scale doubles the total.
        step_counts = [1] + [2 ** i for i in range(self.radius - 1)]
        scales = []
        for count in step_counts:
            for _ in range(count):
                propagate()
            scales.append(g.get_n_repr()['z'])
        return scales

    def forward(self, g, lg, x, y, deg_g, deg_lg, pm_pd):
        """Compute the updated ``(x, y)`` pair for one layer.

        Parameters
        ----------
        g, lg :
            The graph and (presumably) its line graph, as legacy DGL graphs.
        x :
            Node features of ``g``.
        y :
            Node features of ``lg``. Also stored as the edge feature ``'y'``
            of ``g`` below.
        deg_g, deg_lg :
            Multiplied element-wise with ``x`` / ``y``. Presumably per-node
            degrees shaped to broadcast -- TODO confirm.
        pm_pd :
            Integer index tensor. It gathers one row of ``x`` per ``lg`` node
            (its result is added to terms computed from ``y``).

        Returns
        -------
        (new_x, new_y), each of width ``out_feats``.

        NOTE(review): the edge feature ``'y'`` is left set on ``g``.
        """
        # Gather rows of x through pm_pd before x is superseded below.
        pmpd_x = F.embedding(pm_pd, x)

        x_scales = self.aggregate(g, x)
        sum_x = sum(theta(h) for theta, h in zip(self.theta_list, x_scales))

        # For each node of g, sum the y features on its incoming edges.
        g.set_e_repr({'y' : y})
        g.update_all(fn.copy_edge(edge='y', out='m'), fn.sum('m', 'pmpd_y'))
        pmpd_y = g.pop_n_repr('pmpd_y')

        new_x = (self.theta_x(x) + self.theta_deg(deg_g * x)
                 + sum_x + self.theta_y(pmpd_y))
        new_x = self.bn_x(self._half_relu(new_x))

        y_scales = self.aggregate(lg, y)
        sum_y = sum(gamma(h) for gamma, h in zip(self.gamma_list, y_scales))

        new_y = (self.gamma_y(y) + self.gamma_deg(deg_lg * y)
                 + sum_y + self.gamma_x(pmpd_x))
        new_y = self.bn_y(self._half_relu(new_y))

        return new_x, new_y

    def _half_relu(self, h):
        """Apply ReLU to all but the first ``out_feats // 2`` columns of ``h``."""
        half = self.out_feats // 2
        return th.cat([h[:, :half], F.relu(h[:, half:])], 1)
62
63

class GNN(nn.Module):
    """Stack of ``GNNModule`` layers followed by a linear classifier.

    The features of ``g`` and ``lg`` start out as their degree tensors. They
    pass through one ``GNNModule`` per consecutive pair of widths in
    ``feats``. The final node features of ``g`` are then mapped to
    ``n_classes`` scores.
    """

    def __init__(self, feats, radius, n_classes):
        """
        Parameters
        ----------
        feats : sequence of int
            Feature widths ``[f0, f1, ..., fk]``.
            - Layer ``i`` maps width ``feats[i]`` to ``feats[i + 1]``.
            - ``f0`` must match the width of the degree tensors passed to
              ``forward``, since those are the initial features.
            - Must be non-empty, because ``feats[-1]`` is read directly.
        radius : int
            Passed to every ``GNNModule``. It sets the number of aggregation
            scales (and per-scale linear maps) in each layer.
        n_classes : int
            Output width of the final linear classifier.
        """
        super().__init__()
        # Created before the layer stack so that parameter registration
        # order (and hence seeded initialisation) stays unchanged.
        self.linear = nn.Linear(feats[-1], n_classes)
        self.module_list = nn.ModuleList(
            [GNNModule(m, n, radius) for m, n in zip(feats[:-1], feats[1:])])

    def forward(self, g, lg, deg_g, deg_lg, pm_pd):
        """Return class scores for the nodes of ``g``.

        ``deg_g`` / ``deg_lg`` play two roles. They are the initial features
        of ``g`` / ``lg``, and they are also the degree inputs of every
        layer. ``pm_pd`` is passed unchanged to each ``GNNModule``.

        Only the ``g``-side features reach the classifier; the ``lg``-side
        features are discarded after the last layer. With a single entry in
        ``feats`` there are no layers, and the classifier is applied
        directly to ``deg_g``.
        """
        x, y = deg_g, deg_lg
        for module in self.module_list:
            x, y = module(g, lg, x, y, deg_g, deg_lg, pm_pd)
        return self.linear(x)