gnn.py 2.84 KB
Newer Older
1
2
import copy
import itertools
Hongzhi (Steve), Chen's avatar
Hongzhi (Steve), Chen committed
3

GaiYu0's avatar
GaiYu0 committed
4
5
import dgl
import dgl.function as fn
6
import networkx as nx
Hongzhi (Steve), Chen's avatar
Hongzhi (Steve), Chen committed
7
import numpy as np
8
9
import torch as th
import torch.nn as nn
GaiYu0's avatar
GaiYu0 committed
10
import torch.nn.functional as F
Hongzhi (Steve), Chen's avatar
Hongzhi (Steve), Chen committed
11

12

GaiYu0's avatar
GaiYu0 committed
13
class GNNModule(nn.Module):
    """One GNN layer that jointly updates features on two graphs.

    ``x`` holds node features of ``g`` and ``y`` holds node features of ``lg``
    (presumably the line graph of ``g``, so ``y`` are edge features of ``g`` —
    confirm with the caller). Each update sums:

    * a linear self term,
    * a linear term on the degree-scaled features,
    * linear terms on multi-scale neighbour aggregations (see ``aggregate``),
    * a linear term on features fused in from the other graph.

    The first ``out_feats // 2`` output columns stay linear. The remaining
    columns pass through ReLU, and the result is batch-normalised.

    Args:
        in_feats: Input feature size shared by ``x`` and ``y``.
        out_feats: Output feature size shared by ``x`` and ``y``.
        radius: Number of aggregation scales. Scale ``k`` (0-based) has
            propagated features ``2**k`` hops.
    """

    def __init__(self, in_feats, out_feats, radius):
        super().__init__()
        self.out_feats = out_feats
        self.radius = radius

        def new_linear():
            return nn.Linear(in_feats, out_feats)

        def new_linear_list():
            return nn.ModuleList([new_linear() for _ in range(radius)])

        # NOTE: construction order is kept fixed so parameter initialisation
        # (which consumes the global RNG) is reproducible across versions.
        self.theta_x, self.theta_deg, self.theta_y = (
            new_linear(),
            new_linear(),
            new_linear(),
        )
        self.theta_list = new_linear_list()

        self.gamma_y, self.gamma_deg, self.gamma_x = (
            new_linear(),
            new_linear(),
            new_linear(),
        )
        self.gamma_list = new_linear_list()

        self.bn_x = nn.BatchNorm1d(out_feats)
        self.bn_y = nn.BatchNorm1d(out_feats)

    def aggregate(self, g, z):
        """Return multi-scale neighbourhood sums of ``z`` over ``g``.

        Element 0 is one round of "sum features of in-neighbours". Each later
        element ``k`` applies that round ``2**(k - 1)`` more times to the
        previous element, so element ``k`` has been propagated ``2**k`` hops.
        At least one element is always returned, even when ``radius`` < 1.

        The temporary ``"z"`` field is written inside ``g.local_scope()``, so
        the caller's graph is left unmodified.
        """
        z_list = []
        with g.local_scope():
            g.ndata["z"] = z
            g.update_all(fn.copy_u(u="z", out="m"), fn.sum(msg="m", out="z"))
            z_list.append(g.ndata["z"])
            for i in range(self.radius - 1):
                for _ in range(2**i):
                    g.update_all(
                        fn.copy_u(u="z", out="m"), fn.sum(msg="m", out="z")
                    )
                z_list.append(g.ndata["z"])
        return z_list

    def _gate(self, h, bn):
        """Apply ReLU to columns ``out_feats // 2`` onward, then batch-norm ``bn``."""
        n = self.out_feats // 2
        return bn(th.cat([h[:, :n], F.relu(h[:, n:])], 1))

    def forward(self, g, lg, x, y, deg_g, deg_lg, pm_pd):
        """Compute the updated ``(x, y)`` pair.

        Args:
            g: Graph whose node features are ``x``.
            lg: Graph whose node features are ``y``.
            x: Node features of ``g``, ``in_feats`` columns.
            y: Features for ``lg`` nodes, also used as ``g`` edge data,
                ``in_feats`` columns.
            deg_g: Multiplied elementwise with ``x``. Presumably per-node
                degrees of ``g`` of shape (N, 1) — confirm with caller.
            deg_lg: Multiplied elementwise with ``y``. Presumably per-node
                degrees of ``lg`` — confirm with caller.
            pm_pd: Integer index tensor selecting rows of ``x``. Its length
                must match the number of rows of ``y``.

        Returns:
            Tuple ``(x, y)`` of updated features, ``out_feats`` columns each.
            Both outputs are computed from the *input* ``x`` and ``y``.
        """
        # Rows of x gathered by pm_pd; fused into the y update below.
        pmpd_x = F.embedding(pm_pd, x)

        sum_x = sum(
            theta(z) for theta, z in zip(self.theta_list, self.aggregate(g, x))
        )

        # Sum edge features y into each destination node of g. local_scope
        # keeps the temporary "y"/"pmpd_y" fields off the caller's graph.
        with g.local_scope():
            g.edata["y"] = y
            g.update_all(fn.copy_e(e="y", out="m"), fn.sum("m", "pmpd_y"))
            pmpd_y = g.ndata["pmpd_y"]

        new_x = (
            self.theta_x(x)
            + self.theta_deg(deg_g * x)
            + sum_x
            + self.theta_y(pmpd_y)
        )
        new_x = self._gate(new_x, self.bn_x)

        sum_y = sum(
            gamma(z) for gamma, z in zip(self.gamma_list, self.aggregate(lg, y))
        )

        new_y = (
            self.gamma_y(y)
            + self.gamma_deg(deg_lg * y)
            + sum_y
            + self.gamma_x(pmpd_x)
        )
        new_y = self._gate(new_y, self.bn_y)

        return new_x, new_y
89

Hongzhi (Steve), Chen's avatar
Hongzhi (Steve), Chen committed
90

91
class GNN(nn.Module):
    """Stack of ``GNNModule`` layers topped by a linear classifier.

    Args:
        feats: Feature sizes. Each consecutive pair ``(feats[i], feats[i + 1])``
            gives one layer's input/output size, and ``feats[-1]`` is the
            classifier's input size.
        radius: Aggregation radius handed to every layer.
        n_classes: Number of classifier outputs.
    """

    def __init__(self, feats, radius, n_classes):
        super().__init__()
        # Classifier is registered before the layers, as before, so parameter
        # ordering and initialisation order are unchanged.
        self.linear = nn.Linear(feats[-1], n_classes)
        layers = [
            GNNModule(d_in, d_out, radius)
            for d_in, d_out in zip(feats[:-1], feats[1:])
        ]
        self.module_list = nn.ModuleList(layers)

    def forward(self, g, lg, deg_g, deg_lg, pm_pd):
        """Run every layer, then classify the final ``g``-side features."""
        # The degree tensors double as the initial feature inputs.
        h_g, h_lg = deg_g, deg_lg
        for layer in self.module_list:
            h_g, h_lg = layer(g, lg, h_g, h_lg, deg_g, deg_lg, pm_pd)
        return self.linear(h_g)