# appnp.py: APPNP node-classification example for DGL (MXNet backend).
import argparse
import time

4
import mxnet as mx
5
6
import numpy as np
from mxnet import gluon, nd
7
from mxnet.gluon import nn
8

9
import dgl
10
11
from dgl.data import (CiteseerGraphDataset, CoraGraphDataset,
                      PubmedGraphDataset, register_data_args)
12
13
from dgl.nn.mxnet.conv import APPNPConv

14

15
class APPNP(nn.Block):
    """APPNP model: an MLP prediction step followed by APPNP propagation.

    Node features go through a stack of dense layers
    (``in_feats`` -> ``hiddens[0]`` -> ... -> ``hiddens[-1]`` -> ``n_classes``).
    The resulting per-node outputs are then propagated over ``g`` by
    ``APPNPConv``.

    Parameters
    ----------
    g : graph
        Graph the propagation step runs on; it is stored on the model and
        reused by every forward call.
    in_feats : int
        Input feature size (``in_units`` of the first dense layer).
    hiddens : sequence of int
        Hidden layer sizes. Must be non-empty, because the first and the
        output layer are sized from ``hiddens[0]`` and ``hiddens[-1]``.
    n_classes : int
        Output size of the last dense layer.
    activation : callable
        Non-linearity applied after every dense layer except the last one.
    feat_drop : float
        Dropout rate applied to the input features and to the input of the
        output layer; a falsy value disables this dropout.
    edge_drop : float
        Edge dropout rate forwarded to ``APPNPConv``.
    alpha : float
        Teleport probability forwarded to ``APPNPConv``.
    k : int
        Number of propagation steps forwarded to ``APPNPConv``.

    Raises
    ------
    ValueError
        If ``hiddens`` is empty.
    """

    def __init__(
        self,
        g,
        in_feats,
        hiddens,
        n_classes,
        activation,
        feat_drop,
        edge_drop,
        alpha,
        k,
    ):
        # Fail early with a clear message instead of an IndexError below.
        if len(hiddens) == 0:
            raise ValueError("APPNP needs at least one hidden layer size")
        super(APPNP, self).__init__()
        self.g = g

        with self.name_scope():
            self.layers = nn.Sequential()
            # input layer
            self.layers.add(nn.Dense(hiddens[0], in_units=in_feats))
            # hidden layers: each consumes the previous layer's width
            for n_in, n_out in zip(hiddens[:-1], hiddens[1:]):
                self.layers.add(nn.Dense(n_out, in_units=n_in))
            # output layer
            self.layers.add(nn.Dense(n_classes, in_units=hiddens[-1]))
            self.activation = activation
            # A falsy rate (e.g. 0) means "no dropout": use an identity op.
            if feat_drop:
                self.feat_drop = nn.Dropout(feat_drop)
            else:
                self.feat_drop = lambda x: x
            self.propagate = APPNPConv(k, alpha, edge_drop)

    def forward(self, features):
        """Run the prediction (MLP) step, then APPNP propagation over ``self.g``.

        Assumes ``features`` has one row per node of ``self.g`` with
        ``in_feats`` columns -- TODO confirm against callers.
        """
        # prediction step
        h = self.feat_drop(features)
        n_layers = len(self.layers)
        # Every layer but the last is followed by the activation. Indexing
        # avoids building a new Sequential per call, as slicing would.
        for i in range(n_layers - 1):
            h = self.activation(self.layers[i](h))
        # Output layer: dropout on its input, no activation.
        h = self.layers[n_layers - 1](self.feat_drop(h))
        # propagation step
        h = self.propagate(self.g, h)
        return h

59

60
61
62
63
64
def evaluate(model, features, labels, mask):
    """Return the accuracy of ``model`` on the nodes selected by ``mask``.

    The predicted class of each node is the arg-max of ``model(features)``
    along axis 1. Each correct/incorrect indicator is multiplied by
    ``mask``, and the masked total is divided by ``mask``'s sum. With a 0/1
    mask this is the fraction of selected nodes that are classified
    correctly. The result is returned as a Python scalar.
    """
    predicted = model(features).argmax(axis=1)
    hits = (predicted == labels) * mask
    denominator = mask.sum().asscalar()
    ratio = hits.sum() / denominator
    return ratio.asscalar()

65

66
67
def main(args):
    """Train an APPNP model on a citation dataset and report test accuracy.

    Parameters
    ----------
    args : argparse.Namespace
        Parsed command-line options. Reads ``dataset``, ``gpu``,
        ``hidden_sizes``, ``in_drop``, ``edge_drop``, ``alpha``, ``k``,
        ``lr``, ``weight_decay`` and ``n_epochs``.

    Raises
    ------
    ValueError
        If ``args.dataset`` is not ``cora``, ``citeseer`` or ``pubmed``.
    """
    # load and preprocess dataset
    if args.dataset == "cora":
        data = CoraGraphDataset()
    elif args.dataset == "citeseer":
        data = CiteseerGraphDataset()
    elif args.dataset == "pubmed":
        data = PubmedGraphDataset()
    else:
        raise ValueError("Unknown dataset: {}".format(args.dataset))

    g = data[0]
    if args.gpu < 0:
        ctx = mx.cpu(0)
    else:
        ctx = mx.gpu(args.gpu)
        g = g.to(ctx)

    features = g.ndata["feat"]
    labels = mx.nd.array(g.ndata["label"], dtype="float32", ctx=ctx)
    train_mask = g.ndata["train_mask"]
    val_mask = g.ndata["val_mask"]
    test_mask = g.ndata["test_mask"]

    in_feats = features.shape[1]
    # Use the non-deprecated accessors: the class count from ``num_classes``
    # and the edge count from the loaded graph itself (before self-loops
    # are adjusted below).
    n_classes = data.num_classes
    n_edges = g.number_of_edges()

    print(
        """----Data statistics------
      #Edges %d
      #Classes %d
      #Train samples %d
      #Val samples %d
      #Test samples %d"""
        % (
            n_edges,
            n_classes,
            train_mask.sum().asscalar(),
            val_mask.sum().asscalar(),
            test_mask.sum().asscalar(),
        )
    )

    # add self loop: remove existing ones first so each node ends up with
    # exactly one self-loop
    g = dgl.remove_self_loop(g)
    g = dgl.add_self_loop(g)

    # create APPNP model
    model = APPNP(
        g,
        in_feats,
        args.hidden_sizes,
        n_classes,
        nd.relu,
        args.in_drop,
        args.edge_drop,
        args.alpha,
        args.k,
    )

    model.initialize(ctx=ctx)
    n_train_samples = train_mask.sum().asscalar()
    loss_fcn = gluon.loss.SoftmaxCELoss()

    # use optimizer
    print(model.collect_params())
    trainer = gluon.Trainer(
        model.collect_params(),
        "adam",
        {"learning_rate": args.lr, "wd": args.weight_decay},
    )

    # training loop; the first 3 epochs are excluded from the timing stats
    dur = []
    for epoch in range(args.n_epochs):
        if epoch >= 3:
            t0 = time.time()
        # forward
        with mx.autograd.record():
            pred = model(features)
            # train_mask is passed as the per-sample weight, so only
            # training nodes contribute to the loss
            loss = loss_fcn(pred, labels, mx.nd.expand_dims(train_mask, 1))
            loss = loss.sum() / n_train_samples

        loss.backward()
        # The loss is already averaged above, so no extra batch rescaling.
        trainer.step(batch_size=1)

        if epoch >= 3:
            # Reading the value back waits for MXNet's asynchronous engine,
            # so the recorded duration covers the whole training step.
            loss.asscalar()
            dur.append(time.time() - t0)
            acc = evaluate(model, features, labels, val_mask)
            print(
                "Epoch {:05d} | Time(s) {:.4f} | Loss {:.4f} | Accuracy {:.4f} | "
                "ETputs(KTEPS) {:.2f}".format(
                    epoch,
                    np.mean(dur),
                    loss.asscalar(),
                    acc,
                    n_edges / np.mean(dur) / 1000,
                )
            )

    # test set accuracy
    acc = evaluate(model, features, labels, test_mask)
    print("Test accuracy {:.2%}".format(acc))

172
173
174

if __name__ == "__main__":
    # Build the command-line interface, then train/evaluate via main().
    parser = argparse.ArgumentParser(description="APPNP")
    register_data_args(parser)
    parser.add_argument(
        "--in-drop", type=float, default=0.5, help="input feature dropout"
    )
    parser.add_argument(
        "--edge-drop", type=float, default=0.5, help="edge propagation dropout"
    )
    parser.add_argument("--gpu", type=int, default=-1, help="gpu")
    parser.add_argument("--lr", type=float, default=1e-2, help="learning rate")
    parser.add_argument(
        "--n-epochs", type=int, default=200, help="number of training epochs"
    )
    # "--hidden_sizes" is kept for backward compatibility; "--hidden-sizes"
    # matches the hyphenated style of the other flags.
    parser.add_argument(
        "--hidden_sizes",
        "--hidden-sizes",
        dest="hidden_sizes",
        type=int,
        nargs="+",
        default=[64],
        help="hidden unit sizes for appnp",
    )
    parser.add_argument(
        "--k", type=int, default=10, help="Number of propagation steps"
    )
    parser.add_argument(
        "--alpha", type=float, default=0.1, help="Teleport Probability"
    )
    parser.add_argument(
        "--weight-decay", type=float, default=5e-4, help="Weight for L2 loss"
    )
    args = parser.parse_args()
    print(args)

    main(args)