import dgl.function as fn
import numpy as np
import torch as th
import torch.nn as nn
import torch.nn.functional as F


def drop_node(feats, drop_rate, training):
    """DropNode: randomly zero out entire rows of the node feature matrix.

    During training each node is kept with probability 1 - drop_rate; at
    inference the features are instead scaled by 1 - drop_rate, so the
    expected output matches the training-time distribution.
    """
    n = feats.shape[0]
    drop_rates = th.FloatTensor(np.ones(n) * drop_rate)

    if training:
        # Sample a per-node Bernoulli keep-mask and broadcast it
        # across the feature dimension.
        masks = th.bernoulli(1.0 - drop_rates).unsqueeze(1)
        feats = masks.to(feats.device) * feats
    else:
        # Deterministic rescaling in place of stochastic masking.
        feats = feats * (1.0 - drop_rate)

    return feats
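
# Quick sanity sketch (assumed example, not part of the original module):
# with drop_rate = 0.5, E[mask * x] = 0.5 * x, which matches the
# inference-time scaling above.
#
#   x = th.ones(4, 3)
#   x_train = drop_node(x, 0.5, training=True)   # rows are all-0 or all-1
#   x_eval = drop_node(x, 0.5, training=False)   # every entry equals 0.5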


class MLP(nn.Module):
    """Two-layer MLP prediction head with optional batch normalization."""

    def __init__(
        self, nfeat, nhid, nclass, input_droprate, hidden_droprate, use_bn=False
    ):
        super(MLP, self).__init__()

        self.layer1 = nn.Linear(nfeat, nhid, bias=True)
        self.layer2 = nn.Linear(nhid, nclass, bias=True)

        self.input_dropout = nn.Dropout(input_droprate)
        self.hidden_dropout = nn.Dropout(hidden_droprate)
        self.bn1 = nn.BatchNorm1d(nfeat)
        self.bn2 = nn.BatchNorm1d(nhid)
        self.use_bn = use_bn

    def reset_parameters(self):
        self.layer1.reset_parameters()
        self.layer2.reset_parameters()

    def forward(self, x):
        if self.use_bn:
            x = self.bn1(x)
        x = self.input_dropout(x)
        x = F.relu(self.layer1(x))

        if self.use_bn:
            x = self.bn2(x)
        x = self.hidden_dropout(x)
        x = self.layer2(x)

        return x


def GRANDConv(graph, feats, order):
    r"""Mixed-order propagation: average :math:`\hat{A}^k X` for k = 0..order.

    Parameters
    ----------
    graph : dgl.DGLGraph
        The input graph.
    feats : Tensor of shape (n_nodes, feat_dim)
        Node features.
    order : int
        Number of propagation steps.
    """
    with graph.local_scope():
        # Symmetrically normalized adjacency \hat{A} = D^{-1/2} A D^{-1/2}:
        # each edge (u, v) gets weight 1 / sqrt(deg(u) * deg(v)).
        degs = graph.in_degrees().float().clamp(min=1)
        norm = th.pow(degs, -0.5).to(feats.device).unsqueeze(1)

        graph.ndata["norm"] = norm
        graph.apply_edges(fn.u_mul_v("norm", "norm", "weight"))

        # Graph convolution: accumulate \hat{A}^k X for k = 0..order.
        x = feats
        y = 0 + feats  # running sum, initialized with \hat{A}^0 X = X

        for i in range(order):
            graph.ndata["h"] = x
            graph.update_all(fn.u_mul_e("h", "weight", "m"), fn.sum("m", "h"))
            x = graph.ndata.pop("h")
            y.add_(x)

    return y / (order + 1)
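
# Illustrative usage (assumed example; the graph and sizes are not part of
# the original module):
#
#   import dgl
#   g = dgl.add_self_loop(dgl.rand_graph(5, 10))
#   h = th.randn(5, 16)
#   out = GRANDConv(g, h, order=3)  # (5, 16): mean of A^0 h ... A^3 h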


class GRAND(nn.Module):
    r"""GRAND: Graph Random Neural Network.

    Parameters
    ----------
    in_dim : int
        Input feature size, i.e. the number of dimensions of :math:`H^{(i)}`.
    hid_dim : int
        Hidden feature size.
    n_class : int
        Number of classes.
    S : int
        Number of augmentation samples.
    K : int
        Number of propagation steps.
    node_dropout : float
        DropNode rate on node features.
    input_droprate : float
        Dropout rate of the MLP's input layer.
    hidden_droprate : float
        Dropout rate of the MLP's hidden layer.
    batchnorm : bool, optional
        If True, use batch normalization.
    """

    def __init__(
        self,
        in_dim,
        hid_dim,
        n_class,
        S=1,
        K=3,
        node_dropout=0.0,
        input_droprate=0.0,
        hidden_droprate=0.0,
        batchnorm=False,
    ):
        super(GRAND, self).__init__()
        self.in_dim = in_dim
        self.hid_dim = hid_dim
        self.S = S
        self.K = K
        self.n_class = n_class

        self.mlp = MLP(
            in_dim, hid_dim, n_class, input_droprate, hidden_droprate, batchnorm
        )

        self.dropout = node_dropout  # DropNode rate consumed by drop_node
        self.node_dropout = nn.Dropout(node_dropout)  # unused in forward; kept for compatibility

    def forward(self, graph, feats, training=True):
        """In training mode, run S independently DropNode-augmented passes
        and return a list of S log-softmax predictions (used for consistency
        regularization); in inference mode, return a single prediction."""
        X = feats
        S = self.S

        if training:  # Training Mode
            output_list = []
            for s in range(S):
                drop_feat = drop_node(X, self.dropout, True)  # DropNode augmentation
                feat = GRANDConv(graph, drop_feat, self.K)  # Graph convolution
                output_list.append(
                    th.log_softmax(self.mlp(feat), dim=-1)
                )  # Prediction

            return output_list
        else:  # Inference Mode
            drop_feat = drop_node(X, self.dropout, False)
            X = GRANDConv(graph, drop_feat, self.K)

            return th.log_softmax(self.mlp(X), dim=-1)
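

if __name__ == "__main__":
    # Minimal smoke test; the random graph and all sizes below are
    # illustrative assumptions, not part of the original module.
    import dgl

    g = dgl.add_self_loop(dgl.rand_graph(100, 500))
    feats = th.randn(100, 32)
    model = GRAND(
        in_dim=32,
        hid_dim=16,
        n_class=7,
        S=2,
        K=4,
        node_dropout=0.5,
        input_droprate=0.5,
        hidden_droprate=0.5,
    )
    outputs = model(g, feats, training=True)
    print(len(outputs), outputs[0].shape)  # 2 torch.Size([100, 7])
    pred = model(g, feats, training=False)
    print(pred.shape)  # torch.Size([100, 7])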