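"""Train Deep Graph Infomax (DGI) with the DGL TensorFlow backend.

The script first trains the unsupervised DGI encoder, then freezes the
learned node embeddings and trains a small classifier on top of them.

Example invocation (``--dataset`` is added by ``register_data_args``; the
built-in graph datasets are downloaded automatically by DGL):

    python train.py --dataset cora --gpu 0 --self-loop
"""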
import argparse
import time

import dgl

import networkx as nx
import numpy as np
import tensorflow as tf
from dgi import Classifier, DGI
from dgl.data import (
    CiteseerGraphDataset,
    CoraGraphDataset,
    PubmedGraphDataset,
    register_data_args,
)
from tensorflow.keras import layers


def evaluate(model, features, labels, mask):
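    """Return the accuracy of ``model`` on the nodes selected by ``mask``."""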
    logits = model(features, training=False)
    logits = logits[mask]
    labels = labels[mask]
    indices = tf.math.argmax(logits, axis=1)
    acc = tf.reduce_mean(tf.cast(indices == labels, dtype=tf.float32))
    return acc.numpy().item()


def main(args):
    # load and preprocess dataset
    if args.dataset == "cora":
        data = CoraGraphDataset()
    elif args.dataset == "citeseer":
        data = CiteseerGraphDataset()
    elif args.dataset == "pubmed":
        data = PubmedGraphDataset()
    else:
        raise ValueError("Unknown dataset: {}".format(args.dataset))

    g = data[0]
    if args.gpu < 0:
        device = "/cpu:0"
    else:
        device = "/gpu:{}".format(args.gpu)
        g = g.to(device)

    with tf.device(device):
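        # Run everything below in this device scope so tensors and variables
        # are placed on the CPU or on the selected GPU.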
        features = g.ndata["feat"]
        labels = g.ndata["label"]
        train_mask = g.ndata["train_mask"]
        val_mask = g.ndata["val_mask"]
        test_mask = g.ndata["test_mask"]
        in_feats = features.shape[1]
        n_classes = data.num_classes
        n_edges = g.number_of_edges()

        # add self loop
        if args.self_loop:
            g = dgl.remove_self_loop(g)
            g = dgl.add_self_loop(g)
        n_edges = g.number_of_edges()

        # create DGI model
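        # DGI (from the local dgi module) takes the graph, the input feature
        # size, the hidden size, the number of layers, an activation and the
        # dropout rate; calling it on the node features below returns the
        # unsupervised DGI loss.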
        dgi = DGI(
            g,
            in_feats,
            args.n_hidden,
            args.n_layers,
            tf.keras.layers.PReLU(
                alpha_initializer=tf.constant_initializer(0.25)
            ),
            args.dropout,
        )

        dgi_optimizer = tf.keras.optimizers.Adam(learning_rate=args.dgi_lr)

        # train deep graph infomax
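        # Early stopping: keep the weights with the lowest DGI loss and stop
        # after `patience` epochs without improvement.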
        cnt_wait = 0
        best = 1e9
        best_t = 0
        dur = []
        for epoch in range(args.n_dgi_epochs):
            if epoch >= 3:
                t0 = time.time()

            with tf.GradientTape() as tape:
                loss = dgi(features)
                # Manual weight decay.
                # TensorFlow's Adam(W) optimizer implements weight decay
                # differently from PyTorch, which leads to worse results here.
                # Adding an L2 penalty on the weights to the loss reproduces
                # the PyTorch-style weight decay.
                for weight in dgi.trainable_weights:
                    loss = loss + args.weight_decay * tf.nn.l2_loss(weight)
                grads = tape.gradient(loss, dgi.trainable_weights)
                dgi_optimizer.apply_gradients(zip(grads, dgi.trainable_weights))

            if loss < best:
                best = loss
                best_t = epoch
                cnt_wait = 0
                dgi.save_weights("best_dgi.pkl")
            else:
                cnt_wait += 1

            if cnt_wait == args.patience:
                print("Early stopping!")
                break

            if epoch >= 3:
                dur.append(time.time() - t0)

            print(
                "Epoch {:05d} | Time(s) {:.4f} | Loss {:.4f} | "
                "ETputs(KTEPS) {:.2f}".format(
                    epoch,
                    np.mean(dur),
                    loss.numpy().item(),
                    n_edges / np.mean(dur) / 1000,
                )
            )

        # create classifier model
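        # Linear-evaluation stage: the DGI embeddings are frozen with
        # tf.stop_gradient below, and only this classifier head is trained
        # on the labeled nodes.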
        classifier = Classifier(args.n_hidden, n_classes)

        classifier_optimizer = tf.keras.optimizers.Adam(
            learning_rate=args.classifier_lr
        )

        # train classifier
        print("Loading {}th epoch".format(best_t))
        dgi.load_weights("best_dgi.pkl")
        embeds = dgi.encoder(features, corrupt=False)
        embeds = tf.stop_gradient(embeds)
        dur = []
        loss_fcn = tf.keras.losses.SparseCategoricalCrossentropy(
            from_logits=True
        )
        for epoch in range(args.n_classifier_epochs):
            if epoch >= 3:
                t0 = time.time()
            with tf.GradientTape() as tape:
                preds = classifier(embeds)
                loss = loss_fcn(labels[train_mask], preds[train_mask])
                # Manual weight decay.
                # TensorFlow's Adam(W) optimizer implements weight decay
                # differently from PyTorch, which leads to worse results here.
                # Adding an L2 penalty on the weights to the loss reproduces
                # the PyTorch-style weight decay.
                # The original DGI code applies no weight decay to the
                # classifier, so it is left commented out here:
                # link: https://github.com/PetarV-/DGI/blob/master/execute.py#L121
                # for weight in classifier.trainable_weights:
                #     loss = loss + \
                #         args.weight_decay * tf.nn.l2_loss(weight)
                grads = tape.gradient(loss, classifier.trainable_weights)
                classifier_optimizer.apply_gradients(
                    zip(grads, classifier.trainable_weights)
                )
            if epoch >= 3:
                dur.append(time.time() - t0)

            acc = evaluate(classifier, embeds, labels, val_mask)
            print(
                "Epoch {:05d} | Time(s) {:.4f} | Loss {:.4f} | Accuracy {:.4f} | "
                "ETputs(KTEPS) {:.2f}".format(
                    epoch,
                    np.mean(dur),
                    loss.numpy().item(),
                    acc,
                    n_edges / np.mean(dur) / 1000,
                )
            )

        print()
        acc = evaluate(classifier, embeds, labels, test_mask)
        print("Test Accuracy {:.4f}".format(acc))


if __name__ == "__main__":
    parser = argparse.ArgumentParser(description="DGI")
    register_data_args(parser)
    parser.add_argument(
        "--dropout", type=float, default=0.0, help="dropout probability"
    )
    parser.add_argument("--gpu", type=int, default=-1, help="gpu")
    parser.add_argument(
        "--dgi-lr", type=float, default=1e-3, help="dgi learning rate"
    )
    parser.add_argument(
        "--classifier-lr",
        type=float,
        default=1e-2,
        help="classifier learning rate",
    )
    parser.add_argument(
        "--n-dgi-epochs",
        type=int,
        default=300,
        help="number of DGI training epochs",
    )
    parser.add_argument(
        "--n-classifier-epochs",
        type=int,
        default=300,
        help="number of classifier training epochs",
    )
    parser.add_argument(
        "--n-hidden", type=int, default=512, help="number of hidden gcn units"
    )
    parser.add_argument(
        "--n-layers", type=int, default=1, help="number of hidden gcn layers"
    )
    parser.add_argument(
        "--weight-decay", type=float, default=0.0, help="Weight for L2 loss"
    )
    parser.add_argument(
        "--patience", type=int, default=20, help="early stopping patience"
    )
    parser.add_argument(
        "--self-loop",
        action="store_true",
        help="add self-loop to the graph (default=False)",
    )
    parser.set_defaults(self_loop=False)
    args = parser.parse_args()
    print(args)

    main(args)