"""
This code was modified from implementations of SGC in other backends.

Simplifying Graph Convolutional Networks (Wu, Zhang and Souza et al, 2019)
Paper: https://arxiv.org/abs/1902.07153
Author Implementation: https://github.com/Tiiiger/SGC

SGC implementation in DGL.
"""
import argparse
import textwrap

import tensorflow as tf
import tensorflow_addons as tfa

from dgl.data import CiteseerGraphDataset, CoraGraphDataset, PubmedGraphDataset
from dgl.nn.tensorflow.conv import SGConv

# Map dataset name -> dataset *class*. Instantiation is deferred to
# load_data() so that importing this module does not eagerly download and
# process all three citation datasets when only one is needed.
_DATASETS = {
    "citeseer": CiteseerGraphDataset,
    "cora": CoraGraphDataset,
    "pubmed": PubmedGraphDataset,
}


def load_data(dataset):
    """Return the requested DGL dataset instance.

    Parameters
    ----------
    dataset : str
        One of ``"citeseer"``, ``"cora"`` or ``"pubmed"``.

    Raises
    ------
    KeyError
        If *dataset* is not a known dataset name.
    """
    return _DATASETS[dataset](verbose=False)


def _sum_boolean_tensor(x):
    """Count the ``True`` entries of a boolean tensor as an int64 scalar."""
    as_ints = tf.cast(x, dtype="int64")
    return tf.reduce_sum(as_ints)


def describe_data(data):
    """Return a human-readable statistics summary for a DGL dataset."""
    graph = data[0]

    # Boolean node masks for the three standard transductive splits.
    masks = {
        split: graph.ndata[f"{split}_mask"]
        for split in ("train", "val", "test")
    }

    return textwrap.dedent(
        f"""
        ----Data statistics----
        Edges           {graph.number_of_edges():,.0f}
        Classes         {data.num_classes:,.0f}
        Train samples   {_sum_boolean_tensor(masks["train"]):,.0f}
        Val samples     {_sum_boolean_tensor(masks["val"]):,.0f}
        Test samples    {_sum_boolean_tensor(masks["test"]):,.0f}
        """
    )


class SGC(tf.keras.Model):
    """Simple Graph Convolution (SGC) node classifier.

    A single ``SGConv`` layer (k=2 propagation steps, cached) over a fixed
    graph. Training and evaluation are transductive: the forward pass runs on
    the whole graph, and the node masks stored in ``g.ndata`` select which
    rows contribute to the loss and metrics.
    """

    def __init__(self, g, num_classes, bias=False):
        super().__init__()
        self.num_classes = num_classes
        # Self-loops let every node aggregate its own features.
        self.g = self.ensure_self_loop(g)
        self.conv = SGConv(
            in_feats=self.in_feats,
            out_feats=self.num_classes,
            k=2,
            # The graph is static, so the propagated features can be cached
            # after the first forward pass.
            cached=True,
            bias=bias,
        )

    def call(self, inputs):
        """Apply the SGConv layer to *inputs* over the stored graph."""
        return self.conv(self.g, inputs)

    @property
    def in_feats(self):
        # Input feature dimensionality, read off the graph's node features.
        return self.g.ndata["feat"].shape[1]

    @property
    def num_nodes(self):
        return self.g.num_nodes()

    @staticmethod
    def ensure_self_loop(g):
        """Return *g* with exactly one self-loop per node."""
        g = g.remove_self_loop()
        g = g.add_self_loop()
        return g

    def train_step(self, data):
        # Full-graph forward pass; only training-mask rows enter the loss.
        X, y = data
        mask = self.g.ndata["train_mask"]

        with tf.GradientTape() as tape:
            y_pred = self(X, training=True)
            loss = self.compiled_loss(y[mask], y_pred[mask])

        trainable_variables = self.trainable_variables
        gradients = tape.gradient(loss, trainable_variables)
        self.optimizer.apply_gradients(zip(gradients, trainable_variables))
        self.compiled_metrics.update_state(y[mask], y_pred[mask])
        return {m.name: m.result() for m in self.metrics}

    def test_step(self, data):
        # Evaluation restricted to the validation-mask rows.
        X, y = data
        mask = self.g.ndata["val_mask"]
        y_pred = self(X, training=False)
        self.compiled_loss(y[mask], y_pred[mask])
        self.compiled_metrics.update_state(y[mask], y_pred[mask])
        return {m.name: m.result() for m in self.metrics}

    def compile(self, *args, **kwargs):
        # Boolean-mask indexing in train_step/test_step requires eager mode.
        super().compile(*args, **kwargs, run_eagerly=True)

    def fit(self, *args, **kwargs):
        # Full-batch training: a single batch containing every node, and no
        # shuffling so feature rows stay aligned with the graph's node order.
        kwargs["batch_size"] = self.num_nodes
        kwargs["shuffle"] = False
        # Bug fix: propagate the Keras History object instead of dropping it,
        # matching predict() below.
        return super().fit(*args, **kwargs)

    def predict(self, *args, **kwargs):
        kwargs["batch_size"] = self.num_nodes
        return super().predict(*args, **kwargs)


def main(dataset, lr, bias, n_epochs, weight_decay):
    """Train SGC on a citation dataset and print the test accuracy.

    Parameters
    ----------
    dataset : str
        Dataset key understood by :func:`load_data`.
    lr : float
        Learning rate for AdamW.
    bias : bool
        Whether the SGConv layer uses a bias term.
    n_epochs : int
        Number of training epochs.
    weight_decay : float
        AdamW weight-decay coefficient.
    """
    data = load_data(dataset)
    print(describe_data(data))

    g = data[0]
    X = g.ndata["feat"]
    y = g.ndata["label"]

    model = SGC(g=g, num_classes=data.num_classes, bias=bias)

    loss = tf.losses.SparseCategoricalCrossentropy(from_logits=True)
    optimizer = tfa.optimizers.AdamW(weight_decay, lr)
    accuracy = tf.metrics.SparseCategoricalAccuracy(name="accuracy")

    model.compile(optimizer, loss, metrics=[accuracy])
    model.fit(x=X, y=y, epochs=n_epochs, validation_data=(X, y))

    y_pred = model.predict(X, batch_size=len(X))
    test_mask = g.ndata["test_mask"]
    # Bug fix: `accuracy` is the compiled metric and still holds state
    # accumulated during fit(); reset it so the reported figure reflects only
    # the test-mask predictions.
    accuracy.reset_state()
    test_accuracy = accuracy(y[test_mask], y_pred[test_mask])
    print(f"Test Accuracy: {test_accuracy:.1%}")


def _parse_args():
    """Parse command-line options for the SGC experiment."""
    parser = argparse.ArgumentParser(
        description="Run experiment for Simple Graph Convolution (SGC)"
    )
    # (flag, add_argument keyword arguments), registered in display order.
    options = [
        ("--dataset", dict(default="cora", help="dataset to run")),
        ("--lr", dict(type=float, default=0.2, help="learning rate")),
        (
            "--bias",
            dict(action="store_true", default=False, help="flag to use bias"),
        ),
        (
            "--n-epochs",
            dict(type=int, default=100, help="number of training epochs"),
        ),
        (
            "--weight-decay",
            dict(type=float, default=5e-6, help="weight for L2 loss"),
        ),
    ]
    for flag, kwargs in options:
        parser.add_argument(flag, **kwargs)
    return parser.parse_args()


if __name__ == "__main__":
    # argparse dest names (dataset, lr, bias, n_epochs, weight_decay) line up
    # one-to-one with main()'s keyword parameters.
    cli_args = _parse_args()
    main(**vars(cli_args))