model.py
import math

import dgl.function as fn

import torch
import torch.nn as nn
import torch.nn.functional as F


def glorot(tensor):
    if tensor is not None:
        stdv = math.sqrt(6.0 / (tensor.size(-2) + tensor.size(-1)))
        tensor.data.uniform_(-stdv, stdv)


def zeros(tensor):
    if tensor is not None:
        tensor.data.fill_(0)


class ARMAConv(nn.Module):
    """ARMA graph convolution (Bianchi et al., "Graph Neural Networks with
    Convolutional ARMA Filters"): ``num_stacks`` (K) parallel stacks, each
    running ``num_layers`` (T) propagation steps of
    x^(t+1) = activation(D^-1/2 A D^-1/2 x^(t) W + V x^(0) + b),
    with the K stack outputs averaged.
    """
    def __init__(
        self,
        in_dim,
        out_dim,
        num_stacks,
        num_layers,
        activation=None,
        dropout=0.0,
        bias=True,
    ):
        super(ARMAConv, self).__init__()

        self.in_dim = in_dim
        self.out_dim = out_dim
        self.K = num_stacks
        self.T = num_layers
        self.activation = activation
        self.dropout = nn.Dropout(p=dropout)

        # init weight
        self.w_0 = nn.ModuleDict(
            {
                str(k): nn.Linear(in_dim, out_dim, bias=False)
                for k in range(self.K)
            }
        )
        # deeper weight
        self.w = nn.ModuleDict(
            {
                str(k): nn.Linear(out_dim, out_dim, bias=False)
                for k in range(self.K)
            }
        )
        # v
        self.v = nn.ModuleDict(
            {
                str(k): nn.Linear(in_dim, out_dim, bias=False)
                for k in range(self.K)
            }
        )
        # bias
        if bias:
            self.bias = nn.Parameter(
                torch.Tensor(self.K, self.T, 1, self.out_dim)
            )
        else:
            self.register_parameter("bias", None)

        self.reset_parameters()

    def reset_parameters(self):
        for k in range(self.K):
            glorot(self.w_0[str(k)].weight)
            glorot(self.w[str(k)].weight)
            glorot(self.v[str(k)].weight)
        zeros(self.bias)

    def forward(self, g, feats):
        with g.local_scope():
            init_feats = feats
            # assume the graph is undirected, so in_degrees() equals out_degrees()
            degs = g.in_degrees().float().clamp(min=1)
            norm = torch.pow(degs, -0.5).to(feats.device).unsqueeze(1)
            output = []

            for k in range(self.K):
                feats = init_feats
                for t in range(self.T):
                    feats = feats * norm
                    g.ndata["h"] = feats
                    g.update_all(fn.copy_u("h", "m"), fn.sum("m", "h"))
                    feats = g.ndata.pop("h")
                    feats = feats * norm

                    if t == 0:
                        feats = self.w_0[str(k)](feats)
                    else:
                        feats = self.w[str(k)](feats)

                    # skip connection: V applied to the (dropped-out) input features;
                    # the ARMA recurrence adds this term once per step
                    feats += self.v[str(k)](self.dropout(init_feats))

                    if self.bias is not None:
                        feats += self.bias[k][t]

                    if self.activation is not None:
                        feats = self.activation(feats)
                output.append(feats)

            return torch.stack(output).mean(dim=0)


class ARMA4NC(nn.Module):
    """Two-layer ARMA model for node classification:
    ARMAConv -> ReLU -> dropout -> ARMAConv."""
    def __init__(
        self,
        in_dim,
        hid_dim,
        out_dim,
        num_stacks,
        num_layers,
        activation=None,
        dropout=0.0,
    ):
        super(ARMA4NC, self).__init__()

        self.conv1 = ARMAConv(
            in_dim=in_dim,
            out_dim=hid_dim,
            num_stacks=num_stacks,
            num_layers=num_layers,
            activation=activation,
            dropout=dropout,
        )

        self.conv2 = ARMAConv(
            in_dim=hid_dim,
            out_dim=out_dim,
            num_stacks=num_stacks,
            num_layers=num_layers,
            activation=activation,
            dropout=dropout,
        )

        self.dropout = nn.Dropout(p=dropout)

    def forward(self, g, feats):
        feats = F.relu(self.conv1(g, feats))
        feats = self.dropout(feats)
        feats = self.conv2(g, feats)
        return feats
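

# Minimal usage sketch (not part of the original example): build a small toy
# graph with DGL and run ARMA4NC once. The graph, feature sizes, and
# hyper-parameters below are illustrative assumptions, not values taken from
# the repository's training script.
if __name__ == "__main__":
    import dgl

    # toy ring graph on 10 nodes, made undirected so that in-degrees
    # equal out-degrees, as the layer assumes
    src = torch.arange(10)
    dst = (src + 1) % 10
    g = dgl.graph((src, dst), num_nodes=10)
    g = dgl.add_reverse_edges(g)

    feats = torch.randn(10, 16)
    model = ARMA4NC(
        in_dim=16,
        hid_dim=32,
        out_dim=3,  # e.g. 3 node classes
        num_stacks=2,
        num_layers=1,
        activation=nn.ReLU(),
        dropout=0.5,
    )
    logits = model(g, feats)
    print(logits.shape)  # torch.Size([10, 3])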