# sgc.py
"""
This code was modified from the GCN implementation in DGL examples.
Simplifying Graph Convolutional Networks
Paper: https://arxiv.org/abs/1902.07153
Code: https://github.com/Tiiiger/SGC
SGC implementation in DGL.
"""
import argparse, time, math
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
import dgl.function as fn
14
15
16
import dgl
from dgl.data import register_data_args
from dgl.data import CoraGraphDataset, CiteseerGraphDataset, PubmedGraphDataset
17
from dgl.nn.pytorch.conv import SGConv
18

19
20

def evaluate(model, g, features, labels, mask):
    """Return the classification accuracy of ``model`` on the nodes in ``mask``.

    The model is switched to evaluation mode (and left in that mode) and is
    run without gradient tracking.

    Parameters
    ----------
    model : callable with an ``eval()`` method
        Invoked as ``model(g, features)``; the result is indexed with
        ``mask`` and arg-maxed over dimension 1 to obtain predictions.
    g : graph object passed through to ``model`` unchanged.
    features : node features passed through to ``model`` unchanged.
    labels : tensor of ground-truth labels, indexable by ``mask``.
    mask : index/boolean tensor selecting the nodes to score.

    Returns
    -------
    float
        Fraction of selected nodes whose predicted class equals its label,
        or ``0.0`` when ``mask`` selects no nodes.
    """
    model.eval()
    with torch.no_grad():
        # The model is run on all nodes; only the masked rows are scored.
        logits = model(g, features)[mask]
        labels = labels[mask]
        n_selected = len(labels)
        if n_selected == 0:
            # Nothing to score: avoid a ZeroDivisionError below.
            return 0.0
        _, indices = torch.max(logits, dim=1)
        correct = torch.sum(indices == labels)
        return correct.item() * 1.0 / n_selected

def main(args):
    """Train an SGC model on a citation graph and report its accuracy.

    Loads the dataset named by ``args.dataset``, replaces any existing
    self-loops with exactly one per node, trains a single ``SGConv`` layer
    (``k=2``, ``cached=True``) with Adam and cross-entropy on the training
    nodes, prints validation accuracy after every epoch and the test
    accuracy at the end.

    Parameters
    ----------
    args : argparse.Namespace
        Must provide ``dataset`` ('cora', 'citeseer' or 'pubmed'), ``gpu``
        (device index; negative selects CPU), ``bias``, ``lr``,
        ``weight_decay`` and ``n_epochs``.

    Raises
    ------
    ValueError
        If ``args.dataset`` is not one of the supported dataset names.
    """
    # load and preprocess dataset
    if args.dataset == 'cora':
        data = CoraGraphDataset()
    elif args.dataset == 'citeseer':
        data = CiteseerGraphDataset()
    elif args.dataset == 'pubmed':
        data = PubmedGraphDataset()
    else:
        raise ValueError('Unknown dataset: {}'.format(args.dataset))

    g = data[0]
    cuda = args.gpu >= 0
    if cuda:
        g = g.int().to(args.gpu)

    features = g.ndata['feat']
    labels = g.ndata['label']
    train_mask = g.ndata['train_mask']
    val_mask = g.ndata['val_mask']
    test_mask = g.ndata['test_mask']
    in_feats = features.shape[1]
    n_classes = data.num_labels
    # Counted before self-loops are rewritten below, so the statistics and
    # the throughput figure refer to the graph as loaded.
    n_edges = g.number_of_edges()
    print("""----Data statistics------'
      #Edges %d
      #Classes %d
      #Train samples %d
      #Val samples %d
      #Test samples %d""" %
          (n_edges, n_classes,
              train_mask.int().sum().item(),
              val_mask.int().sum().item(),
              test_mask.int().sum().item()))

    # add self loop (drop existing ones first so each node has exactly one)
    g = dgl.remove_self_loop(g)
    g = dgl.add_self_loop(g)

    # create SGC model
    model = SGConv(in_feats,
                   n_classes,
                   k=2,
                   cached=True,
                   bias=args.bias)

    if cuda:
        # Use the same device index as the graph; a bare ``model.cuda()``
        # would pick the current default CUDA device instead.
        model.cuda(args.gpu)
    loss_fcn = torch.nn.CrossEntropyLoss()

    # use optimizer
    optimizer = torch.optim.Adam(model.parameters(),
                                 lr=args.lr,
                                 weight_decay=args.weight_decay)

    # Per-epoch wall-clock durations; epochs 0-2 are excluded from timing.
    dur = []
    for epoch in range(args.n_epochs):
        model.train()
        timed = epoch >= 3
        if timed:
            t0 = time.time()
        # forward on all nodes; only training nodes contribute to the loss
        logits = model(g, features)
        loss = loss_fcn(logits[train_mask], labels[train_mask])

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        if timed:
            dur.append(time.time() - t0)

        acc = evaluate(model, g, features, labels, val_mask)
        # No timing samples exist for the untimed epochs: report NaN
        # explicitly instead of triggering numpy's "Mean of empty slice"
        # warning.
        mean_dur = np.mean(dur) if dur else float('nan')
        print("Epoch {:05d} | Time(s) {:.4f} | Loss {:.4f} | Accuracy {:.4f} | "
              "ETputs(KTEPS) {:.2f}". format(epoch, mean_dur, loss.item(),
                                             acc, n_edges / mean_dur / 1000))

    print()
    acc = evaluate(model, g, features, labels, test_mask)
    print("Test Accuracy {:.4f}".format(acc))


if __name__ == '__main__':
    # Command-line entry point: build the parser, then hand the parsed
    # options to main().
    parser = argparse.ArgumentParser(description='SGC')
    # Let DGL register its own dataset-related options on the parser.
    register_data_args(parser)

    # Script-specific options, registered in this order.
    option_table = (
        ("--gpu", dict(type=int, default=-1, help="gpu")),
        ("--lr", dict(type=float, default=0.2, help="learning rate")),
        ("--bias", dict(action='store_true', default=False,
                        help="flag to use bias")),
        ("--n-epochs", dict(type=int, default=100,
                            help="number of training epochs")),
        ("--weight-decay", dict(type=float, default=5e-6,
                                help="Weight for L2 loss")),
    )
    for flag, spec in option_table:
        parser.add_argument(flag, **spec)

    args = parser.parse_args()
    print(args)

    main(args)