sgc_reddit.py 3.14 KB
Newer Older
Tianyi's avatar
Tianyi committed
1
2
3
4
5
6
7
"""
This code was modified from the GCN implementation in DGL examples.
Simplifying Graph Convolutional Networks
Paper: https://arxiv.org/abs/1902.07153
Code: https://github.com/Tiiiger/SGC
SGC implementation in DGL.
"""
8
9
10
11
import argparse
import math
import time

Hongzhi (Steve), Chen's avatar
Hongzhi (Steve), Chen committed
12
13
import dgl.function as fn

Tianyi's avatar
Tianyi committed
14
15
16
17
18
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from dgl import DGLGraph
19
from dgl.data import load_data, register_data_args
20
from dgl.nn.pytorch.conv import SGConv
Tianyi's avatar
Tianyi committed
21

22

23
def normalize(h):
    """Standardize ``h`` column-wise to zero mean and unit standard deviation.

    Passed as the ``norm`` callback of ``SGConv`` in ``main``. Mean and
    (unbiased) standard deviation are reduced over dim 0.

    Degenerate columns would otherwise divide by 0 or NaN and emit NaN/inf
    that contaminate every downstream logit and gradient. These are columns
    whose std is 0 (constant feature) or NaN (only one row, so the unbiased
    std is undefined). Such columns are divided by 1 instead and therefore
    come out as all zeros. Non-degenerate columns are unchanged.

    Parameters
    ----------
    h : torch.Tensor
        Feature tensor; statistics are taken over dim 0
        (presumably one row per node -- confirm against SGConv usage).

    Returns
    -------
    torch.Tensor
        Tensor of the same shape as ``h``.
    """
    mean = h.mean(0)
    std = h.std(0)
    # `std > 0` is False for both 0 and NaN, so both fall back to 1.
    safe_std = torch.where(std > 0, std, torch.ones_like(std))
    return (h - mean) / safe_std

Tianyi's avatar
Tianyi committed
26

27
def evaluate(model, features, graph, labels, mask):
    """Return the accuracy of ``model`` on the nodes selected by ``mask``.

    The model is switched to eval mode and run without gradient tracking on
    the whole graph. Only the rows picked out by ``mask`` are then scored.

    Parameters
    ----------
    model : callable ``nn.Module`` invoked as ``model(graph, features)``.
    features : node feature tensor forwarded to the model.
    graph : graph forwarded to the model.
    labels : ground-truth class index per node.
    mask : indexer (presumably a boolean node mask) selecting the
        evaluation nodes.

    Returns
    -------
    float
        Fraction of selected nodes whose arg-max logit equals the label.
    """
    model.eval()
    with torch.no_grad():
        # Full forward pass, then keep just the evaluation rows.
        eval_logits = model(graph, features)[mask]
        eval_labels = labels[mask]
        predicted = eval_logits.max(dim=1)[1]
        n_correct = (predicted == eval_labels).sum().item()
        return float(n_correct) / len(eval_labels)

36

Tianyi's avatar
Tianyi committed
37
38
39
40
def main(args):
    """Train an SGC model on the Reddit graph with L-BFGS and print test accuracy.

    ``args`` is the namespace built in the ``__main__`` block. This function
    reads ``args.gpu`` (negative means CPU) and ``args.n_epochs``. It also
    overwrites ``args.dataset`` so that ``load_data`` always loads
    ``"reddit-self-loop"``, whatever was passed on the command line.
    """
    # load and preprocess dataset
    args.dataset = "reddit-self-loop"
    data = load_data(args)
    g = data[0]
    if args.gpu >= 0:
        # The model is moved to this same device index further below.
        g = g.int().to(args.gpu)

    features = g.ndata["feat"]
    labels = g.ndata["label"]
    train_mask = g.ndata["train_mask"]
    val_mask = g.ndata["val_mask"]
    test_mask = g.ndata["test_mask"]
    in_feats = features.shape[1]
    n_classes = data.num_labels
    n_edges = g.num_edges()
    print(
        """----Data statistics------
      #Edges %d
      #Classes %d
      #Train samples %d
      #Val samples %d
      #Test samples %d"""
        % (
            n_edges,
            n_classes,
            train_mask.int().sum().item(),
            val_mask.int().sum().item(),
            test_mask.int().sum().item(),
        )
    )

    # create SGC model: k=2 propagation steps, cached=True, and `normalize`
    # as the feature-normalization callback.
    # NOTE: bias is always enabled here; the `--bias` CLI flag registered in
    # the __main__ block is not consulted.
    model = SGConv(
        in_feats, n_classes, k=2, cached=True, bias=True, norm=normalize
    )
    if args.gpu >= 0:
        # Use the exact device the graph was moved to. `model.cuda()` would
        # pick the *current default* CUDA device and mismatch the graph
        # whenever --gpu is not that device.
        model = model.to(torch.device("cuda", args.gpu))

    # use optimizer
    optimizer = torch.optim.LBFGS(model.parameters())

    # L-BFGS may evaluate the objective several times per step, so the loss
    # computation is supplied as a closure.
    def closure():
        optimizer.zero_grad()
        output = model(g, features)[train_mask]
        loss_train = F.cross_entropy(output, labels[train_mask])
        loss_train.backward()
        return loss_train

    # training loop
    for epoch in range(args.n_epochs):
        model.train()
        optimizer.step(closure)

    acc = evaluate(model, features, g, labels, test_mask)
    print("Test Accuracy {:.4f}".format(acc))


107
108
if __name__ == "__main__":
    # Command-line entry point: DGL's shared dataset options are registered
    # first, then the options specific to this script, in a fixed order.
    parser = argparse.ArgumentParser(description="SGC")
    register_data_args(parser)
    script_options = (
        ("--gpu", dict(type=int, default=-1, help="gpu")),
        (
            "--bias",
            dict(action="store_true", default=False, help="flag to use bias"),
        ),
        (
            "--n-epochs",
            dict(type=int, default=2, help="number of training epochs"),
        ),
    )
    for option_flag, option_spec in script_options:
        parser.add_argument(option_flag, **option_spec)

    cli_args = parser.parse_args()
    print(cli_args)
    main(cli_args)