import math

import dgl.function as fn
import torch
import torch.nn as nn

class GraphSAGELayer(nn.Module):
    """A GraphSAGE layer: mean-aggregates neighbor features, concatenates
    them with the node's own features, then applies a linear projection,
    optional layer normalization, and an optional activation."""
    def __init__(self,
                 in_feats,
                 out_feats,
                 activation,
                 dropout,
                 bias=True,
                 use_pp=False,
                 use_lynorm=True):
        super(GraphSAGELayer, self).__init__()
        # The input feature size is doubled because we concatenate the node's
        # own features with the aggregated neighbor features.
        self.linear = nn.Linear(2 * in_feats, out_feats, bias=bias)
        self.activation = activation
        # When use_pp is set, the first-layer aggregation is assumed to be
        # precomputed, so message passing is skipped during training
        # (see forward()).
        self.use_pp = use_pp
        if dropout:
            self.dropout = nn.Dropout(p=dropout)
        else:
            self.dropout = 0.
        if use_lynorm:
            self.lynorm = nn.LayerNorm(out_feats, elementwise_affine=True)
        else:
            self.lynorm = nn.Identity()
        self.reset_parameters()

    def reset_parameters(self):
        stdv = 1. / math.sqrt(self.linear.weight.size(1))
        self.linear.weight.data.uniform_(-stdv, stdv)
        if self.linear.bias is not None:
            self.linear.bias.data.uniform_(-stdv, stdv)

    def forward(self, g, h):
        g = g.local_var()
        # Aggregate neighbor features unless they were precomputed (use_pp)
        # and the layer is in training mode.
        if not self.use_pp or not self.training:
            norm = self.get_norm(g)
            g.ndata['h'] = h
            g.update_all(fn.copy_src(src='h', out='m'),
                         fn.sum(msg='m', out='h'))
            ah = g.ndata.pop('h')
            h = self.concat(h, ah, norm)

        if self.dropout:
            h = self.dropout(h)

        h = self.linear(h)
        h = self.lynorm(h)
        if self.activation:
            h = self.activation(h)
        return h

    def concat(self, h, ah, norm):
        # Scale the summed neighbor features by 1/in-degree (mean aggregation)
        # and concatenate them with the node's own features.
        ah = ah * norm
        h = torch.cat((h, ah), dim=1)
        return h

    def get_norm(self, g):
        # Per-node normalization constant 1/in-degree; isolated nodes get 0
        # instead of inf.
        norm = 1. / g.in_degrees().float().unsqueeze(1)
        norm[torch.isinf(norm)] = 0
        norm = norm.to(self.linear.weight.device)
        return norm

class GraphSAGE(nn.Module):
    def __init__(self,
                 in_feats,
                 n_hidden,
                 n_classes,
                 n_layers,
                 activation,
                 dropout,
                 use_pp):
        super(GraphSAGE, self).__init__()
        self.layers = nn.ModuleList()

        # input layer
        self.layers.append(GraphSAGELayer(in_feats, n_hidden, activation=activation,
                                          dropout=dropout, use_pp=use_pp, use_lynorm=True))
        # hidden layers
        for i in range(n_layers - 1):
            self.layers.append(
                GraphSAGELayer(n_hidden, n_hidden, activation=activation, dropout=dropout,
                               use_pp=False, use_lynorm=True))
        # output layer
        self.layers.append(GraphSAGELayer(n_hidden, n_classes, activation=None,
                                          dropout=dropout, use_pp=False, use_lynorm=False))

    def forward(self, g):
        h = g.ndata['features']
        for layer in self.layers:
            h = layer(g, h)
        return h
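
# Minimal usage sketch (assumptions): a DGL build where ``dgl.rand_graph`` and
# ``fn.copy_src`` are still available; the graph size, feature width, hidden
# size, and class count below are illustrative placeholders.
if __name__ == "__main__":
    import dgl
    import torch.nn.functional as F

    num_nodes, in_feats, n_classes = 100, 16, 7
    g = dgl.rand_graph(num_nodes, 500)                      # random graph
    g.ndata['features'] = torch.randn(num_nodes, in_feats)  # node features

    model = GraphSAGE(in_feats, n_hidden=32, n_classes=n_classes,
                      n_layers=2, activation=F.relu, dropout=0.5, use_pp=False)
    logits = model(g)
    print(logits.shape)  # expected: torch.Size([100, 7])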