sgc_reddit.py
"""
This code was modified from the GCN implementation in DGL examples.
Simplifying Graph Convolutional Networks
Paper: https://arxiv.org/abs/1902.07153
Code: https://github.com/Tiiiger/SGC
SGC implementation in DGL.
"""
import argparse, time, math
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
import dgl.function as fn
from dgl import DGLGraph
from dgl.data import register_data_args, load_data
from dgl.nn.pytorch.conv import SGConv
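
# SGC (see the paper above) drops the nonlinearities of a K-layer GCN and
# collapses the stacked weight matrices into a single one, so the model is a
# linear classifier on K-step propagated features:
#     Y_hat = softmax(S^K X W),  with S = D~^(-1/2) A~ D~^(-1/2) and A~ = A + I.
# A rough sketch of one propagation step (not the actual SGConv internals),
# assuming per-node factors deg^(-1/2) stored in g.ndata['norm'] as in main():
#     g.ndata['h'] = features * g.ndata['norm']
#     g.update_all(fn.copy_u('h', 'm'), fn.sum('m', 'h'))
#     features = g.ndata['h'] * g.ndata['norm']
# SGConv applies k such steps before its single linear layer.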

# Standardize each feature dimension (zero mean, unit variance); passed to
# SGConv below as its `norm` callable.
def normalize(h):
    return (h - h.mean(0)) / h.std(0)

def evaluate(model, features, graph, labels, mask):
    model.eval()
    with torch.no_grad():
        logits = model(graph, features)[mask] # only compute the evaluation set
        labels = labels[mask]
        _, indices = torch.max(logits, dim=1)
        correct = torch.sum(indices == labels)
        return correct.item() * 1.0 / len(labels)

def main(args):
    # load and preprocess the dataset; this script is pinned to the Reddit
    # graph with added self-loops, overriding any --dataset value
    args.dataset = "reddit-self-loop"
    data = load_data(args)
    features = torch.FloatTensor(data.features)
    labels = torch.LongTensor(data.labels)
    # node masks for the train/val/test splits (older PyTorch has no BoolTensor,
    # so fall back to ByteTensor masks)
    if hasattr(torch, 'BoolTensor'):
        train_mask = torch.BoolTensor(data.train_mask)
        val_mask = torch.BoolTensor(data.val_mask)
        test_mask = torch.BoolTensor(data.test_mask)
    else:
        train_mask = torch.ByteTensor(data.train_mask)
        val_mask = torch.ByteTensor(data.val_mask)
        test_mask = torch.ByteTensor(data.test_mask)
    in_feats = features.shape[1]
    n_classes = data.num_labels
    n_edges = data.graph.number_of_edges()
    print("""----Data statistics------'
      #Edges %d
      #Classes %d
      #Train samples %d
      #Val samples %d
      #Test samples %d""" %
          (n_edges, n_classes,
           train_mask.int().sum().item(),
           val_mask.int().sum().item(),
           test_mask.int().sum().item()))

    if args.gpu < 0:
        cuda = False
    else:
        cuda = True
        torch.cuda.set_device(args.gpu)
        features = features.cuda()
        labels = labels.cuda()
        train_mask = train_mask.cuda()
        val_mask = val_mask.cuda()
        test_mask = test_mask.cuda()
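    # note: training is full-batch on the whole Reddit graph, so the chosen GPU
    # must have enough memory for the node features (and the cached propagated
    # features); use --gpu -1 to run on CPU instead.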

    # graph preprocess and calculate normalization factor
    g = DGLGraph(data.graph)
    n_edges = g.number_of_edges()
    # symmetric normalization factor deg^(-1/2) per node (zeroed for isolated
    # nodes, where deg^(-1/2) would be inf)
    degs = g.in_degrees().float()
    norm = torch.pow(degs, -0.5)
    norm[torch.isinf(norm)] = 0
    if cuda:
        norm = norm.cuda()
    g.ndata['norm'] = norm.unsqueeze(1)

    # create SGC model
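    # A note on the SGConv arguments used here (per the DGL SGConv docs): k=2
    # gives two propagation hops (S^2 X); cached=True stores the propagated
    # features after the first forward pass, so later L-BFGS evaluations only
    # retrain the linear layer (valid in this transductive setting); and
    # norm=normalize applies the feature standardization defined above to the
    # propagated features.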
    model = SGConv(in_feats, n_classes, k=2, cached=True, bias=args.bias, norm=normalize)
    if args.gpu >= 0:
        model = model.cuda()

    # use the L-BFGS optimizer: SGC reduces to a single linear layer, so
    # full-batch quasi-Newton optimization is practical
    optimizer = torch.optim.LBFGS(model.parameters())

    # define loss closure
    def closure():
        optimizer.zero_grad()
        output = model(g, features)[train_mask]
        loss_train = F.cross_entropy(output, labels[train_mask])
        loss_train.backward()
        return loss_train
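    # torch.optim.LBFGS.step() requires such a closure because L-BFGS
    # re-evaluates the loss and gradients several times per step
    # (history updates / line search).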

    # training loop (each L-BFGS step runs multiple internal function
    # evaluations, so a couple of epochs suffice)
    for epoch in range(args.n_epochs):
        model.train()
        optimizer.step(closure)

    acc = evaluate(model, features, g, labels, test_mask)
    print("Test Accuracy {:.4f}".format(acc))


if __name__ == '__main__':
    parser = argparse.ArgumentParser(description='SGC')
    register_data_args(parser)
    parser.add_argument("--gpu", type=int, default=-1,
            help="gpu")
    parser.add_argument("--bias", action='store_true', default=False,
            help="flag to use bias")
    parser.add_argument("--n-epochs", type=int, default=2,
            help="number of training epochs")
    args = parser.parse_args()
    print(args)
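    # Example invocation (assumes the legacy dgl.data.load_data API and that
    # the Reddit dataset can be downloaded or is already cached locally):
    #   python sgc_reddit.py --gpu 0 --n-epochs 2 --bias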

    main(args)