"""
This code was modified from the GCN implementation in the DGL examples.
Simplifying Graph Convolutional Networks (SGC)
Paper: https://arxiv.org/abs/1902.07153
Code: https://github.com/Tiiiger/SGC
SGC implementation in DGL.
"""
8
9
10
11
import argparse
import math
import time

12
13
14
15
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
16

17
import dgl
18
19
20
21
22
23
24
import dgl.function as fn
from dgl.data import (
    CiteseerGraphDataset,
    CoraGraphDataset,
    PubmedGraphDataset,
    register_data_args,
)
25
from dgl.nn.pytorch.conv import SGConv
26

27
28

def evaluate(model, g, features, labels, mask):
    """Return the accuracy of ``model`` on the nodes selected by ``mask``.

    The model is put in eval mode and run once, without gradient tracking,
    over the whole graph. Its logits and the labels are then restricted to
    ``mask``, and the arg-max prediction is compared with the label.

    Args:
        model: module called as ``model(g, features)`` that returns one row
            of logits per node.
        g: graph forwarded unchanged to ``model``.
        features: node feature tensor forwarded to ``model``.
        labels: per-node class labels.
        mask: boolean/index mask selecting the nodes to score.

    Returns:
        float: fraction of the selected nodes predicted correctly.
    """
    model.eval()
    with torch.no_grad():
        # The forward pass covers every node; only the masked rows are scored.
        selected_logits = model(g, features)[mask]
        selected_labels = labels[mask]
        predictions = torch.max(selected_logits, dim=1).indices
        num_correct = (predictions == selected_labels).sum().item()
        return float(num_correct) / len(selected_labels)

37

38
39
def _load_dataset(name):
    """Instantiate the DGL citation dataset called ``name``.

    Raises:
        ValueError: if ``name`` is not "cora", "citeseer" or "pubmed".
    """
    loaders = {
        "cora": CoraGraphDataset,
        "citeseer": CiteseerGraphDataset,
        "pubmed": PubmedGraphDataset,
    }
    if name not in loaders:
        raise ValueError("Unknown dataset: {}".format(name))
    return loaders[name]()


def _mean_or_nan(values):
    """Return ``np.mean(values)``, or NaN when ``values`` is empty.

    ``np.mean([])`` also gives NaN but emits a "Mean of empty slice"
    RuntimeWarning; the training loop hits that case on every untimed epoch.
    The non-empty branch keeps numpy's float64 so a zero duration still
    divides to inf rather than raising ZeroDivisionError.
    """
    return np.mean(values) if values else np.nan


def main(args):
    """Train an SGC (``SGConv``) node classifier and report its accuracy.

    Prints dataset statistics, then per-epoch loss, validation accuracy and
    timing, and finally the test accuracy.

    Args:
        args: parsed command-line namespace; reads ``dataset``, ``gpu``,
            ``bias``, ``lr``, ``weight_decay`` and ``n_epochs``.

    Raises:
        ValueError: if ``args.dataset`` is not a supported dataset name.
    """
    # load and preprocess dataset
    data = _load_dataset(args.dataset)
    g = data[0]
    cuda = args.gpu >= 0
    if cuda:
        g = g.int().to(args.gpu)

    features = g.ndata["feat"]
    labels = g.ndata["label"]
    train_mask = g.ndata["train_mask"]
    val_mask = g.ndata["val_mask"]
    test_mask = g.ndata["test_mask"]
    in_feats = features.shape[1]
    n_classes = data.num_labels
    # Edge count of the graph as loaded (before the self-loop rewrite below);
    # also used for the per-epoch throughput figure.
    n_edges = g.number_of_edges()
    print(
        """----Data statistics------
      #Edges %d
      #Classes %d
      #Train samples %d
      #Val samples %d
      #Test samples %d"""
        % (
            n_edges,
            n_classes,
            train_mask.int().sum().item(),
            val_mask.int().sum().item(),
            test_mask.int().sum().item(),
        )
    )

    # Remove any existing self loops first so every node ends up with exactly
    # one self loop.
    g = dgl.remove_self_loop(g)
    g = dgl.add_self_loop(g)

    # create SGC model: k=2 propagation hops. With cached=True, SGConv reuses
    # the propagated features after the first forward call.
    model = SGConv(in_feats, n_classes, k=2, cached=True, bias=args.bias)

    if cuda:
        # Place the parameters on the same GPU as the graph. A bare .cuda()
        # would use the current default device, which is not args.gpu in
        # general.
        model.cuda(args.gpu)
    loss_fcn = torch.nn.CrossEntropyLoss()

    # use optimizer
    optimizer = torch.optim.Adam(
        model.parameters(), lr=args.lr, weight_decay=args.weight_decay
    )

    # training loop
    dur = []
    for epoch in range(args.n_epochs):
        model.train()
        # The first 3 epochs are left out of the timing (presumably to
        # exclude warm-up cost).
        if epoch >= 3:
            t0 = time.time()
        # forward: logits are produced for every node; the loss only uses
        # the training nodes.
        logits = model(g, features)
        loss = loss_fcn(logits[train_mask], labels[train_mask])

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        if epoch >= 3:
            dur.append(time.time() - t0)

        acc = evaluate(model, g, features, labels, val_mask)
        mean_dur = _mean_or_nan(dur)
        print(
            "Epoch {:05d} | Time(s) {:.4f} | Loss {:.4f} | Accuracy {:.4f} | "
            "ETputs(KTEPS) {:.2f}".format(
                epoch,
                mean_dur,
                loss.item(),
                acc,
                n_edges / mean_dur / 1000,
            )
        )

    print()
    acc = evaluate(model, g, features, labels, test_mask)
    print("Test Accuracy {:.4f}".format(acc))


131
132
if __name__ == "__main__":
    parser = argparse.ArgumentParser(description="SGC")
    # DGL's shared dataset arguments (presumably including --dataset, which
    # main() reads).
    register_data_args(parser)

    # (flag, add_argument options) for the SGC-specific hyper-parameters,
    # registered in this order.
    cli_options = (
        ("--gpu", dict(type=int, default=-1, help="gpu")),
        ("--lr", dict(type=float, default=0.2, help="learning rate")),
        (
            "--bias",
            dict(action="store_true", default=False, help="flag to use bias"),
        ),
        (
            "--n-epochs",
            dict(type=int, default=100, help="number of training epochs"),
        ),
        (
            "--weight-decay",
            dict(type=float, default=5e-6, help="Weight for L2 loss"),
        ),
    )
    for flag, options in cli_options:
        parser.add_argument(flag, **options)

    args = parser.parse_args()
    print(args)

    main(args)