Add a TensorFlow example for SGC (#2480)



This is a TensorFlow implementation of Simple Graph Convolution. For
simplicity, we use the AdamW implementation from TensorFlow Addons,
whose decoupled weight decay differs from the L2-style weight decay in
PyTorch's Adam. This results in worse performance on the evaluation
datasets.
Co-authored-by: Josh Carty <joshuac@spotify.com>
Co-authored-by: Zihao Ye <expye@outlook.com>
# Simple Graph Convolution (SGC)
> Graph Convolutional Networks derive inspiration primarily from recent deep learning approaches, and as a result, may inherit unnecessary complexity and redundant computation. In this paper, we reduce this excess complexity through successively removing nonlinearities and collapsing weight matrices between consecutive layers. We theoretically analyze the resulting linear model and show that it corresponds to a fixed low-pass filter followed by a linear classifier.
* [Paper](https://arxiv.org/abs/1902.07153)
* [Author Implementation](https://github.com/Tiiiger/SGC)
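
Concretely, SGC replaces a K-layer GCN with K rounds of fixed feature propagation followed by a single linear classifier: `logits = S^K X W`, where `S` is the symmetrically normalized adjacency matrix with self-loops. The sketch below is a dense-matrix illustration for intuition only (`sgc_logits` is a hypothetical helper, not part of DGL); the `SGConv` layer used in this example performs the equivalent computation with sparse graph operations.

```python
import numpy as np

def sgc_logits(adjacency, features, weights, k=2):
    """Dense illustration of SGC: k fixed propagation steps, one linear map."""
    a_hat = adjacency + np.eye(adjacency.shape[0])   # add self-loops
    d_inv_sqrt = np.diag(a_hat.sum(axis=1) ** -0.5)  # D^(-1/2) of A with self-loops
    s = d_inv_sqrt @ a_hat @ d_inv_sqrt              # S = D^(-1/2) A_hat D^(-1/2)
    for _ in range(k):                               # precompute S^k X ...
        features = s @ features                      # ... with no nonlinearities
    return features @ weights                        # single linear classifier
```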
Note: TensorFlow Addons' AdamW implements weight decay differently from PyTorch's Adam (decoupled weight decay rather than an L2 penalty), which leads to differences in performance. For comparison, you can manually add the L2 norm of the weights to the loss, as in [this GCN example](https://github.com/dmlc/dgl/blob/d696558b0bbcb60f1c4cf68dc93cd22c1077ce06/examples/tensorflow/gcn/train.py#L99); the sketch below illustrates the difference.
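
For intuition, here is a minimal sketch of the two approaches, using the hyperparameter values from this example; `add_l2_penalty` is a hypothetical helper, not part of this example or any library:

```python
import tensorflow as tf
import tensorflow_addons as tfa

# Decoupled weight decay (AdamW): weights are shrunk directly in the
# parameter update, independently of the loss gradient.
optimizer = tfa.optimizers.AdamW(weight_decay=5e-6, learning_rate=0.2)

# PyTorch-style weight decay: equivalent to adding an L2 penalty to the
# loss, so the decay term is rescaled by Adam's adaptive learning rates.
def add_l2_penalty(loss, variables, l2_coefficient=5e-6):
    # tf.nn.l2_loss(v) computes sum(v ** 2) / 2
    return loss + l2_coefficient * tf.add_n(
        [tf.nn.l2_loss(v) for v in variables]
    )
```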
## Requirements
```bash
$ pip install dgl tensorflow tensorflow_addons
```
## Usage
```bash
$ python sgc.py --help
usage: sgc.py [-h] [--dataset DATASET] [--lr LR] [--bias]
              [--n-epochs N_EPOCHS] [--weight-decay WEIGHT_DECAY]

Run experiment for Simple Graph Convolution (SGC)

optional arguments:
  -h, --help            show this help message and exit
  --dataset DATASET     dataset to run
  --lr LR               learning rate
  --bias                flag to use bias
  --n-epochs N_EPOCHS   number of training epochs
  --weight-decay WEIGHT_DECAY
                        weight for L2 loss
```
## Results
```bash
# Cora citation network dataset
$ python sgc.py --dataset cora --lr 0.2 --n-epochs 100 --weight-decay 5e-6
...
Epoch 100/100
1/1 [==============================] - 0s 40ms/step - loss: 0.0313 - accuracy: 1.0000 - val_loss: 0.7870 - val_accuracy: 0.7620
Test Accuracy: 77.2%
# Citeseer citation network dataset
$ python sgc.py --dataset citeseer --lr 0.2 --n-epochs 150 --bias --weight-decay 5e-5
...
Epoch 150/150
1/1 [==============================] - 0s 65ms/step - loss: 0.0160 - accuracy: 1.0000 - val_loss: 1.1021 - val_accuracy: 0.6420
Test Accuracy: 63.9%
# Pubmed citation network dataset
$ python sgc.py --dataset pubmed --lr 0.2 --n-epochs 100 --bias --weight-decay 5e-5
...
Epoch 100/100
1/1 [==============================] - 0s 52ms/step - loss: 0.0421 - accuracy: 1.0000 - val_loss: 0.5862 - val_accuracy: 0.7680
Test Accuracy: 76.3%
```
| Dataset  | This Example | Paper |
|----------|--------------|-------|
| Cora     | 77.3%        | 81.0% |
| Citeseer | 63.9%        | 71.9% |
| Pubmed   | 76.4%        | 78.9% |
"""
This code was modified from implementations of SGC in other backends.
Simplifying Graph Convolutional Networks (Wu, Zhang and Souza et al, 2019)
Paper: https://arxiv.org/abs/1902.07153
Author Implementation: https://github.com/Tiiiger/SGC
SGC implementation in DGL.
"""
import argparse
import textwrap
import tensorflow as tf
import tensorflow_addons as tfa
from dgl.data import CiteseerGraphDataset, CoraGraphDataset, PubmedGraphDataset
from dgl.nn.tensorflow.conv import SGConv


_DATASETS = {
    'citeseer': CiteseerGraphDataset,
    'cora': CoraGraphDataset,
    'pubmed': PubmedGraphDataset,
}


def load_data(dataset):
    # Instantiate lazily so only the requested dataset is downloaded.
    return _DATASETS[dataset](verbose=False)


def _sum_boolean_tensor(x):
    # Count the True entries of a boolean mask as a Python int so the
    # result works with string format specifiers.
    return int(tf.reduce_sum(tf.cast(x, dtype='int64')))


def describe_data(data):
    g = data[0]
    n_edges = g.number_of_edges()
    num_classes = data.num_classes
    train_mask = g.ndata['train_mask']
    val_mask = g.ndata['val_mask']
    test_mask = g.ndata['test_mask']
    description = textwrap.dedent(
        f"""
        ----Data statistics----
        Edges         {n_edges:,}
        Classes       {num_classes:,}
        Train samples {_sum_boolean_tensor(train_mask):,}
        Val samples   {_sum_boolean_tensor(val_mask):,}
        Test samples  {_sum_boolean_tensor(test_mask):,}
        """
    )
    return description


class SGC(tf.keras.Model):
    """Keras wrapper around DGL's SGConv layer for full-graph training."""

    def __init__(self, g, num_classes, bias=False):
        super().__init__()
        self.num_classes = num_classes
        self.g = self.ensure_self_loop(g)
        self.conv = SGConv(
            in_feats=self.in_feats,
            out_feats=self.num_classes,
            k=2,
            cached=True,  # cache the propagated features after the first pass
            bias=bias
        )

    def call(self, inputs):
        return self.conv(self.g, inputs)

    @property
    def in_feats(self):
        return self.g.ndata['feat'].shape[1]

    @property
    def num_nodes(self):
        return self.g.num_nodes()

    @staticmethod
    def ensure_self_loop(g):
        g = g.remove_self_loop()
        g = g.add_self_loop()
        return g

    def train_step(self, data):
        X, y = data
        # Compute the loss and metrics only on the training nodes.
        mask = self.g.ndata['train_mask']
        with tf.GradientTape() as tape:
            y_pred = self(X, training=True)
            loss = self.compiled_loss(y[mask], y_pred[mask])
        trainable_variables = self.trainable_variables
        gradients = tape.gradient(loss, trainable_variables)
        self.optimizer.apply_gradients(zip(gradients, trainable_variables))
        self.compiled_metrics.update_state(y[mask], y_pred[mask])
        return {m.name: m.result() for m in self.metrics}

    def test_step(self, data):
        X, y = data
        # Validation runs the same full-graph forward pass with the val mask.
        mask = self.g.ndata['val_mask']
        y_pred = self(X, training=False)
        self.compiled_loss(y[mask], y_pred[mask])
        self.compiled_metrics.update_state(y[mask], y_pred[mask])
        return {m.name: m.result() for m in self.metrics}

    def compile(self, *args, **kwargs):
        # Always run eagerly: the custom steps call into DGL, which
        # operates on the graph outside TensorFlow's graph tracing.
        super().compile(*args, **kwargs, run_eagerly=True)

    def fit(self, *args, **kwargs):
        # Full-batch training: one step per epoch over the whole graph,
        # without shuffling, so node order stays aligned with the graph.
        kwargs['batch_size'] = self.num_nodes
        kwargs['shuffle'] = False
        return super().fit(*args, **kwargs)

    def predict(self, *args, **kwargs):
        kwargs['batch_size'] = self.num_nodes
        return super().predict(*args, **kwargs)


def main(dataset, lr, bias, n_epochs, weight_decay):
    data = load_data(dataset)
    print(describe_data(data))
    g = data[0]
    X = g.ndata['feat']
    y = g.ndata['label']
    model = SGC(g=g, num_classes=data.num_classes, bias=bias)
    loss = tf.losses.SparseCategoricalCrossentropy(from_logits=True)
    optimizer = tfa.optimizers.AdamW(weight_decay, lr)
    accuracy = tf.metrics.SparseCategoricalAccuracy(name='accuracy')
    model.compile(optimizer, loss, metrics=[accuracy])
    model.fit(
        x=X,
        y=y,
        epochs=n_epochs,
        validation_data=(X, y)
    )
    # Evaluate on the held-out test nodes with a fresh metric so that
    # accumulated state from training does not skew the result.
    y_pred = model.predict(X, batch_size=len(X))
    test_mask = g.ndata['test_mask']
    test_accuracy = tf.metrics.SparseCategoricalAccuracy()(
        y[test_mask], y_pred[test_mask]
    )
    print(f"Test Accuracy: {float(test_accuracy):.1%}")


def _parse_args():
    parser = argparse.ArgumentParser(
        description='Run experiment for Simple Graph Convolution (SGC)'
    )
    parser.add_argument(
        "--dataset", default='cora', help="dataset to run"
    )
    parser.add_argument(
        "--lr", type=float, default=0.2, help="learning rate"
    )
    parser.add_argument(
        "--bias", action='store_true', default=False, help="flag to use bias"
    )
    parser.add_argument(
        "--n-epochs", type=int, default=100, help="number of training epochs"
    )
    parser.add_argument(
        "--weight-decay", type=float, default=5e-6, help="weight for L2 loss"
    )
    return parser.parse_args()


if __name__ == '__main__':
    args = _parse_args()
    main(
        dataset=args.dataset,
        lr=args.lr,
        bias=args.bias,
        n_epochs=args.n_epochs,
        weight_decay=args.weight_decay
    )