sgc.py 3.89 KB
Newer Older
1
2
3
4
5
6
7
"""
This code was modified from the GCN implementation in DGL examples.
Simplifying Graph Convolutional Networks
Paper: https://arxiv.org/abs/1902.07153
Code: https://github.com/Tiiiger/SGC
SGC implementation in DGL.
"""
8
9
10
11
import argparse
import math
import time

Hongzhi (Steve), Chen's avatar
Hongzhi (Steve), Chen committed
12
13
14
import dgl
import dgl.function as fn

15
16
17
18
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
19
20
21
22
23
24
from dgl.data import (
    CiteseerGraphDataset,
    CoraGraphDataset,
    PubmedGraphDataset,
    register_data_args,
)
25
from dgl.nn.pytorch.conv import SGConv
26

27
28

def evaluate(model, g, features, labels, mask):
    """Return the accuracy of ``model`` on the nodes selected by ``mask``.

    The model is put into evaluation mode (and left there) and is run
    without gradient tracking on the full graph. Only the rows picked out
    by ``mask`` are then compared with ``labels``.

    Args:
        model: callable module invoked as ``model(g, features)``.
        g: the graph passed through to ``model``.
        features: node feature tensor passed through to ``model``.
        labels: ground-truth class index per node.
        mask: selector (e.g. a boolean tensor) for the evaluated nodes.

    Returns:
        float: fraction of selected nodes whose arg-max prediction matches
        the label.
    """
    model.eval()
    with torch.no_grad():
        # The forward pass covers every node; the evaluation split is
        # selected afterwards.
        masked_logits = model(g, features)[mask]
        masked_labels = labels[mask]
        predicted = masked_logits.max(dim=1).indices
        n_correct = (predicted == masked_labels).sum().item()
        return float(n_correct) / len(masked_labels)

37

38
39
def _load_dataset(name):
    """Instantiate the DGL citation dataset named ``name``.

    Raises:
        ValueError: if ``name`` is not "cora", "citeseer" or "pubmed".
    """
    dataset_classes = {
        "cora": CoraGraphDataset,
        "citeseer": CiteseerGraphDataset,
        "pubmed": PubmedGraphDataset,
    }
    if name not in dataset_classes:
        raise ValueError("Unknown dataset: {}".format(name))
    return dataset_classes[name]()


def _print_data_statistics(n_edges, n_classes, train_mask, val_mask, test_mask):
    """Print the edge count, class count and the size of each data split."""
    print(
        """----Data statistics------
      #Edges %d
      #Classes %d
      #Train samples %d
      #Val samples %d
      #Test samples %d"""
        % (
            n_edges,
            n_classes,
            train_mask.int().sum().item(),
            val_mask.int().sum().item(),
            test_mask.int().sum().item(),
        )
    )


def main(args):
    """Train an SGC node classifier and report validation/test accuracy.

    Args:
        args: parsed command-line namespace. Reads ``dataset``, ``gpu``
            (negative means CPU), ``bias``, ``lr``, ``weight_decay`` and
            ``n_epochs``.
    """
    # load and preprocess dataset
    data = _load_dataset(args.dataset)
    g = data[0]
    use_cuda = args.gpu >= 0
    if use_cuda:
        g = g.int().to(args.gpu)

    features = g.ndata["feat"]
    labels = g.ndata["label"]
    train_mask = g.ndata["train_mask"]
    val_mask = g.ndata["val_mask"]
    test_mask = g.ndata["test_mask"]

    in_feats = features.shape[1]
    n_classes = data.num_labels
    # Edge count of the loaded graph, taken before the self-loop rewrite
    # below; also used for the throughput figure in the training log.
    n_edges = g.number_of_edges()
    _print_data_statistics(n_edges, n_classes, train_mask, val_mask, test_mask)

    # Normalise self-loops: drop any existing ones, then add one per node.
    g = dgl.remove_self_loop(g)
    g = dgl.add_self_loop(g)

    # create SGC model (k=2 propagation steps). Per the DGL SGConv docs,
    # cached=True reuses the propagated features after the first forward.
    model = SGConv(in_feats, n_classes, k=2, cached=True, bias=args.bias)
    # Put the model on the graph's device. ``model.cuda()`` would always
    # use the default CUDA device, which differs from the graph's device
    # whenever --gpu is not 0.
    model = model.to(g.device)

    loss_fcn = torch.nn.CrossEntropyLoss()

    # use optimizer
    optimizer = torch.optim.Adam(
        model.parameters(), lr=args.lr, weight_decay=args.weight_decay
    )

    # Per-epoch wall-clock durations; the first 3 epochs are warm-up and
    # are not recorded.
    dur = []
    for epoch in range(args.n_epochs):
        model.train()
        t0 = time.time()
        # forward over all nodes; the loss only uses the training nodes
        logits = model(g, features)
        loss = loss_fcn(logits[train_mask], labels[train_mask])

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        if epoch >= 3:
            if use_cuda:
                # CUDA work is asynchronous; wait for it so the timing
                # covers the whole training step, not just kernel launches.
                torch.cuda.synchronize(g.device)
            dur.append(time.time() - t0)

        acc = evaluate(model, g, features, labels, val_mask)
        # np.mean([]) emits a RuntimeWarning; report NaN explicitly during
        # the warm-up epochs instead.
        mean_dur = np.mean(dur) if dur else float("nan")
        print(
            "Epoch {:05d} | Time(s) {:.4f} | Loss {:.4f} | Accuracy {:.4f} | "
            "ETputs(KTEPS) {:.2f}".format(
                epoch,
                mean_dur,
                loss.item(),
                acc,
                n_edges / mean_dur / 1000,
            )
        )

    print()
    acc = evaluate(model, g, features, labels, test_mask)
    print("Test Accuracy {:.4f}".format(acc))


131
132
if __name__ == "__main__":
    # Command-line entry point: dataset arguments come from DGL's
    # register_data_args; the remaining options are script-specific.
    cli = argparse.ArgumentParser(description="SGC")
    register_data_args(cli)

    # (flag, add_argument keyword arguments), registered in this order.
    option_specs = [
        ("--gpu", dict(type=int, default=-1, help="gpu")),
        ("--lr", dict(type=float, default=0.2, help="learning rate")),
        (
            "--bias",
            dict(action="store_true", default=False, help="flag to use bias"),
        ),
        (
            "--n-epochs",
            dict(type=int, default=100, help="number of training epochs"),
        ),
        (
            "--weight-decay",
            dict(type=float, default=5e-6, help="Weight for L2 loss"),
        ),
    ]
    for flag, options in option_specs:
        cli.add_argument(flag, **options)

    args = cli.parse_args()
    print(args)

    main(args)