# model.py - ARMA graph convolution layer (ARMAConv) and a two-layer
# node-classification model (ARMA4NC), implemented with PyTorch and DGL.
import torch
import torch.nn as nn
import torch.nn.functional as F
import dgl.function as fn
import math

def glorot(tensor):
    """Fill ``tensor`` in place with Glorot (Xavier) uniform values.

    Values are drawn uniformly from ``[-b, b]`` with
    ``b = sqrt(6 / (tensor.size(-2) + tensor.size(-1)))``.
    Passing ``None`` is allowed and does nothing.
    """
    if tensor is None:
        return
    # Sum of the two trailing dimensions (fan-in + fan-out of a weight matrix).
    fan_total = tensor.size(-2) + tensor.size(-1)
    bound = math.sqrt(6.0 / fan_total)
    # Write through ``.data`` so the initialisation is not tracked by autograd.
    tensor.data.uniform_(-bound, bound)

def zeros(tensor):
    """Set every element of ``tensor`` to zero in place; ``None`` is ignored."""
    if tensor is None:
        return
    # Write through ``.data`` so the reset is not tracked by autograd.
    tensor.data.zero_()

class ARMAConv(nn.Module):
    """ARMA graph convolution made of ``num_stacks`` parallel recursive stacks.

    Every stack ``k`` performs ``num_layers`` propagation steps. One step

    1. scales the current features by ``deg^-0.5``, sums them over incoming
       edges and scales by ``deg^-0.5`` again,
    2. applies a linear map: ``w_0[k]`` on the first step, ``w[k]`` (shared by
       all later steps of that stack) afterwards,
    3. adds a skip term computed from the layer's original input via ``v[k]``,
    4. adds ``bias[k][t]`` when a bias is enabled, and
    5. applies ``activation`` when one is given.

    The final features of all stacks are averaged.

    Args:
        in_dim: Size of the last dimension of the input features.
        out_dim: Size of the last dimension of the output features.
        num_stacks: Number of parallel stacks (``K``).
        num_layers: Number of propagation steps per stack (``T``).
        activation: Optional callable applied at the end of every step.
        dropout: Dropout probability used on the skip connection.
        bias: If ``True``, learn a bias of shape
            ``(num_stacks, num_layers, 1, out_dim)``; otherwise no bias.
    """

    def __init__(self,
                 in_dim,
                 out_dim,
                 num_stacks,
                 num_layers,
                 activation=None,
                 dropout=0.0,
                 bias=True):
        super().__init__()

        self.in_dim = in_dim
        self.out_dim = out_dim
        self.K = num_stacks
        self.T = num_layers
        self.activation = activation
        self.dropout = nn.Dropout(p=dropout)

        # init weight: used on the first step of each stack (in_dim -> out_dim)
        self.w_0 = nn.ModuleDict({
            str(k): nn.Linear(in_dim, out_dim, bias=False) for k in range(self.K)
        })
        # deeper weight: used on every later step of a stack (out_dim -> out_dim)
        self.w = nn.ModuleDict({
            str(k): nn.Linear(out_dim, out_dim, bias=False) for k in range(self.K)
        })
        # v: skip-connection weight applied to the original input features
        self.v = nn.ModuleDict({
            str(k): nn.Linear(in_dim, out_dim, bias=False) for k in range(self.K)
        })
        # bias: one (1, out_dim) row per (stack, step) pair
        if bias:
            self.bias = nn.Parameter(
                torch.zeros(self.K, self.T, 1, self.out_dim))
        else:
            self.register_parameter('bias', None)

        self.reset_parameters()

    def reset_parameters(self):
        """Re-initialise all weights with ``glorot`` and the bias with ``zeros``."""
        for k in range(self.K):
            key = str(k)
            glorot(self.w_0[key].weight)
            glorot(self.w[key].weight)
            glorot(self.v[key].weight)
        # ``zeros`` accepts None, so this is safe when the bias is disabled.
        zeros(self.bias)

    def forward(self, g, feats):
        """Apply the ARMA convolution.

        Args:
            g: DGL graph; it must provide ``local_scope``, ``in_degrees``,
                ``ndata`` and ``update_all``.
            feats: Node features whose last dimension is ``in_dim``
                (presumably shaped ``(num_nodes, in_dim)`` - confirm with
                the caller).

        Returns:
            The element-wise mean over all stacks of each stack's final
            features.
        """
        # Per the DGL documentation, local_scope keeps the temporary 'h'
        # node data from persisting on the caller's graph.
        with g.local_scope():
            init_feats = feats
            # assume that the graphs are undirected and graph.in_degrees() is the same as graph.out_degrees()
            # clamp(min=1) avoids 0 ** -0.5 = inf for isolated nodes.
            degs = g.in_degrees().float().clamp(min=1)
            norm = torch.pow(degs, -0.5).to(feats.device).unsqueeze(1)
            output = []

            for k in range(self.K):
                key = str(k)
                w_0, w, v = self.w_0[key], self.w[key], self.v[key]

                feats = init_feats
                for t in range(self.T):
                    # Degree-normalised sum over incoming edges.
                    feats = feats * norm
                    g.ndata['h'] = feats
                    g.update_all(fn.copy_u('h', 'm'), fn.sum('m', 'h'))
                    feats = g.ndata.pop('h')
                    feats = feats * norm

                    if t == 0:
                        feats = w_0(feats)
                    else:
                        feats = w(feats)

                    # NOTE(review): the skip term is added twice, once with
                    # dropout after the projection and once with dropout
                    # before it. This looks unintended, but it is kept
                    # because removing either line changes model outputs.
                    feats += self.dropout(v(init_feats))
                    feats += v(self.dropout(init_feats))

                    if self.bias is not None:
                        # bias[k][t] has shape (1, out_dim) and broadcasts.
                        feats += self.bias[k][t]

                    if self.activation is not None:
                        feats = self.activation(feats)

                output.append(feats)

            return torch.stack(output).mean(dim=0)

class ARMA4NC(nn.Module):
    """Two-layer ARMA network.

    ``conv1`` maps ``in_dim -> hid_dim`` and ``conv2`` maps
    ``hid_dim -> out_dim``. Both layers receive the same ``num_stacks``,
    ``num_layers``, ``activation`` and ``dropout`` settings. A ReLU and a
    dropout layer are applied between them.
    """

    def __init__(self,
                 in_dim,
                 hid_dim,
                 out_dim,
                 num_stacks,
                 num_layers,
                 activation=None,
                 dropout=0.0):
        super().__init__()

        # Settings shared by both convolution layers.
        shared = dict(num_stacks=num_stacks,
                      num_layers=num_layers,
                      activation=activation,
                      dropout=dropout)

        self.conv1 = ARMAConv(in_dim=in_dim, out_dim=hid_dim, **shared)
        self.conv2 = ARMAConv(in_dim=hid_dim, out_dim=out_dim, **shared)

        self.dropout = nn.Dropout(p=dropout)

    def forward(self, g, feats):
        """Run both ARMA layers on graph ``g`` and return ``conv2``'s output."""
        hidden = self.conv1(g, feats)
        hidden = self.dropout(F.relu(hidden))
        return self.conv2(g, hidden)