import numpy as np
import torch as th
import torch.nn as nn
import torch.nn.functional as F

import dgl.function as fn


def drop_node(feats, drop_rate, training):
    """Randomly zero entire node feature rows (the DropNode augmentation)."""
    n = feats.shape[0]
    drop_rates = th.FloatTensor(np.ones(n) * drop_rate)

    if training:
        # Keep each node with probability 1 - drop_rate; zero out the rest.
        masks = th.bernoulli(1.0 - drop_rates).unsqueeze(1)
        feats = masks.to(feats.device) * feats
    else:
        # At inference, scale by the keep probability instead of sampling.
        feats = feats * (1.0 - drop_rate)

    return feats
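
# Note: with drop_rate = 0.5, the training branch zeroes each row with
# probability 0.5, while the inference branch scales every row by 0.5, so
# the two branches agree in expectation.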


class MLP(nn.Module):
    """Two-layer MLP prediction head with optional batch normalization."""

    def __init__(
        self, nfeat, nhid, nclass, input_droprate, hidden_droprate, use_bn=False
    ):
        super(MLP, self).__init__()

        self.layer1 = nn.Linear(nfeat, nhid, bias=True)
        self.layer2 = nn.Linear(nhid, nclass, bias=True)

        self.input_dropout = nn.Dropout(input_droprate)
        self.hidden_dropout = nn.Dropout(hidden_droprate)
        self.bn1 = nn.BatchNorm1d(nfeat)
        self.bn2 = nn.BatchNorm1d(nhid)
        self.use_bn = use_bn

    def reset_parameters(self):
        self.layer1.reset_parameters()
        self.layer2.reset_parameters()

    def forward(self, x):
        if self.use_bn:
            x = self.bn1(x)
        x = self.input_dropout(x)
        x = F.relu(self.layer1(x))

        if self.use_bn:
            x = self.bn2(x)
        x = self.hidden_dropout(x)
        x = self.layer2(x)

        return x


def GRANDConv(graph, feats, order):
    """Mixed-order propagation: averages the node features propagated for
    0..`order` steps with the symmetrically normalized adjacency matrix.

    Parameters
    -----------
    graph: dgl.Graph
        The input graph
    feats: Tensor (n_nodes * feat_dim)
        Node features
    order: int
        Number of propagation steps

    Returns
    -------
    Tensor (n_nodes * feat_dim)
        The averaged propagated node features
    """
    with graph.local_scope():
        # Calculate the symmetrically normalized adjacency matrix \hat{A}:
        # the weight of edge (u, v) is 1 / sqrt(deg(u) * deg(v)).
        degs = graph.in_degrees().float().clamp(min=1)
        norm = th.pow(degs, -0.5).to(feats.device).unsqueeze(1)

        graph.ndata["norm"] = norm
        graph.apply_edges(fn.u_mul_v("norm", "norm", "weight"))

        # Graph convolution: accumulate \hat{A}^k X for k = 0..order.
        x = feats
        y = feats.clone()  # order-0 term

        for i in range(order):
            graph.ndata["h"] = x
            graph.update_all(fn.u_mul_e("h", "weight", "m"), fn.sum("m", "h"))
            x = graph.ndata.pop("h")
            y.add_(x)

    return y / (order + 1)
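
# In dense notation, GRANDConv computes the mixed-order average
#     Y = 1/(K+1) * \sum_{k=0}^{K} \hat{A}^k X,
# where \hat{A} = D^{-1/2} A D^{-1/2} is the symmetrically normalized
# adjacency matrix built above.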


class GRAND(nn.Module):
    r"""Graph Random Neural Network (GRAND).

    Parameters
    -----------
    in_dim: int
        Input feature size, i.e. the number of dimensions of :math:`H^{(i)}`.
    hid_dim: int
        Hidden feature size.
    n_class: int
        Number of classes.
    S: int
        Number of augmentation samples.
    K: int
        Number of propagation steps.
    node_dropout: float
        Dropout rate on node features.
    input_droprate: float
        Dropout rate of the input layer of the MLP.
    hidden_droprate: float
        Dropout rate of the hidden layer of the MLP.
    batchnorm: bool, optional
        If True, use batch normalization.

    """

    def __init__(
        self,
        in_dim,
        hid_dim,
        n_class,
        S=1,
        K=3,
        node_dropout=0.0,
        input_droprate=0.0,
        hidden_droprate=0.0,
        batchnorm=False,
    ):
        super(GRAND, self).__init__()
        self.in_dim = in_dim
        self.hid_dim = hid_dim
        self.S = S
        self.K = K
        self.n_class = n_class

        self.mlp = MLP(
            in_dim, hid_dim, n_class, input_droprate, hidden_droprate, batchnorm
        )

        self.dropout = node_dropout
        # Node-level dropout is applied functionally via drop_node() in
        # forward(); this module is registered but not used there.
        self.node_dropout = nn.Dropout(node_dropout)

    def forward(self, graph, feats, training=True):
        X = feats
        S = self.S

        if training:  # Training mode
            output_list = []
            for s in range(S):
                drop_feat = drop_node(X, self.dropout, True)  # DropNode
                feat = GRANDConv(graph, drop_feat, self.K)  # Graph convolution
                output_list.append(
                    th.log_softmax(self.mlp(feat), dim=-1)
                )  # Prediction

            return output_list
        else:  # Inference mode
            drop_feat = drop_node(X, self.dropout, False)
            X = GRANDConv(graph, drop_feat, self.K)

            return th.log_softmax(self.mlp(X), dim=-1)
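

# A minimal smoke-test sketch for the model above. The toy graph, feature
# sizes, and hyperparameters below are illustrative assumptions, not part of
# the original module.
if __name__ == "__main__":
    import dgl

    # Small directed cycle plus an extra edge, with self-loops added so each
    # node keeps part of its own signal during propagation.
    src = th.tensor([0, 1, 2, 3, 4, 0])
    dst = th.tensor([1, 2, 3, 4, 0, 2])
    g = dgl.add_self_loop(dgl.graph((src, dst), num_nodes=5))

    feats = th.randn(5, 16)
    model = GRAND(in_dim=16, hid_dim=32, n_class=3, S=2, K=3, node_dropout=0.5)

    # Training mode returns one log-probability tensor per augmentation.
    model.train()
    outputs = model(g, feats, training=True)
    print(len(outputs), outputs[0].shape)  # 2 torch.Size([5, 3])

    # Inference mode returns a single prediction.
    model.eval()
    with th.no_grad():
        log_probs = model(g, feats, training=False)
    print(log_probs.shape)  # torch.Size([5, 3])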