Unverified Commit 47bb0593 authored by espylapiza's avatar espylapiza Committed by GitHub
Browse files

[Example] GAT on ogbn-arxiv (#2210)



* [Example] GCN on ogbn-arxiv dataset

* Add README.md

* Update GCN implementation on ogbn-arxiv

* Update GCN on ogbn-arxiv

* Fix typo

* Use evaluator to get results

* Fix duplicated

* Fix duplicated

* Update GCN on ogbn-arxiv

* Add GAT for ogbn-arxiv

* Update README.md

* Update GAT on ogbn-arxiv

* Update README.md
Co-authored-by: default avatarMufei Li <mufeili1996@gmail.com>
Co-authored-by: default avatarZihao Ye <expye@outlook.com>
parent 2fbffd31
# DGL examples for ogbn-arxiv

Requires DGL 0.5 or later versions.

### GCN

Run `gcn.py` with `--use-linear` and `--use-labels` enabled and you should directly see the result.

```bash
python3 gcn.py --use-linear --use-labels
```

### GAT
Run `gat.py` with `--use-labels` enabled and you should directly see the result.
```bash
python3 gat.py --use-norm --use-labels
```
## Usage & Options
### GCN
```
usage: GCN on OGBN-Arxiv [-h] [--cpu] [--gpu GPU] [--n-runs N_RUNS] [--n-epochs N_EPOCHS] [--use-labels] [--use-linear]
...@@ -32,13 +44,38 @@ optional arguments: ...@@ -32,13 +44,38 @@ optional arguments:
  --plot-curves
```
### GAT
```
usage: GAT on OGBN-Arxiv [-h] [--cpu] [--gpu GPU] [--n-runs N_RUNS] [--n-epochs N_EPOCHS] [--use-labels] [--use-norm]
[--lr LR] [--n-layers N_LAYERS] [--n-heads N_HEADS] [--n-hidden N_HIDDEN] [--dropout DROPOUT]
[--attn_drop ATTN_DROP] [--wd WD] [--log-every LOG_EVERY] [--plot-curves]
optional arguments:
-h, --help show this help message and exit
--cpu CPU mode. This option overrides --gpu. (default: False)
--gpu GPU GPU device ID. (default: 0)
--n-runs N_RUNS
--n-epochs N_EPOCHS
--use-labels Use labels in the training set as input features. (default: False)
--use-norm Use symmetrically normalized adjacency matrix. (default: False)
--lr LR
--n-layers N_LAYERS
--n-heads N_HEADS
--n-hidden N_HIDDEN
--dropout DROPOUT
--attn_drop ATTN_DROP
--wd WD
--log-every LOG_EVERY
--plot-curves
```
## Results

Here are the results over 10 runs.

|             |       GCN       |   GCN+linear    |   GCN+labels    | GCN+linear+labels |   GAT*+labels   |
|-------------|:---------------:|:---------------:|:---------------:|:-----------------:|:---------------:|
| Val acc     | 0.7361 ± 0.0009 | 0.7397 ± 0.0010 | 0.7399 ± 0.0008 |  0.7442 ± 0.0012  | 0.7504 ± 0.0006 |
| Test acc    | 0.7246 ± 0.0021 | 0.7270 ± 0.0016 | 0.7259 ± 0.0006 |  0.7306 ± 0.0024  | 0.7365 ± 0.0011 |
| #Parameters | 109608          | 218152          | 119848          |      238632       |     1628440     |
#!/usr/bin/env python
# -*- coding: utf-8 -*-
import argparse
import math
import time
import numpy as np
import torch as th
import torch.nn.functional as F
import torch.optim as optim
from matplotlib import pyplot as plt
from matplotlib.ticker import AutoMinorLocator, MultipleLocator
from ogb.nodeproppred import DglNodePropPredDataset, Evaluator
from models import GAT
device = None
in_feats, n_classes = None, None
epsilon = 1 - math.log(2)
def gen_model(args):
    """Build a GAT model for ogbn-arxiv from parsed command-line arguments.

    Parameters
    ----------
    args : argparse.Namespace
        Must provide use_norm, use_labels, n_hidden, n_layers, n_heads,
        dropout and attn_drop.

    Returns
    -------
    GAT
        The untrained model (not yet moved to a device).
    """
    norm = "both" if args.use_norm else "none"

    # The two branches of the original code built identical models except for
    # the input width: when labels are used as features, a one-hot label block
    # of width n_classes is concatenated to the node features (see add_labels).
    model_in_feats = in_feats + n_classes if args.use_labels else in_feats

    return GAT(
        model_in_feats,
        n_classes,
        n_hidden=args.n_hidden,
        n_layers=args.n_layers,
        n_heads=args.n_heads,
        activation=F.relu,
        dropout=args.dropout,
        attn_drop=args.attn_drop,
        norm=norm,
    )
def cross_entropy(x, labels):
    """Smoothed cross-entropy loss used by the OGB arxiv examples.

    Computes log(epsilon + CE) - log(epsilon) per node, which dampens the
    contribution of already-well-classified nodes, then averages.

    `labels` is expected in OGB format, shape (N, 1).
    """
    per_node = F.cross_entropy(x, labels[:, 0], reduction="none")
    smoothed = th.log(epsilon + per_node) - math.log(epsilon)
    return smoothed.mean()
def compute_acc(pred, labels, evaluator):
    """Return the accuracy of logits `pred` against `labels` (shape (N, 1))
    as computed by the OGB evaluator."""
    y_pred = pred.argmax(dim=-1, keepdim=True)
    return evaluator.eval({"y_pred": y_pred, "y_true": labels})["acc"]
def add_labels(feat, labels, idx):
    """Append a one-hot label block to the node features.

    Rows listed in `idx` get the one-hot encoding of their label; every other
    row gets all zeros. Returns a new tensor of shape (N, F + n_classes).
    """
    num_nodes = feat.shape[0]
    onehot = th.zeros([num_nodes, n_classes]).to(device)
    onehot[idx, labels[idx, 0]] = 1
    return th.cat([feat, onehot], dim=-1)
def adjust_learning_rate(optimizer, lr, epoch):
    """Linearly warm the learning rate up to `lr` over the first 50 epochs.

    After epoch 50 the optimizer's learning rate is left untouched.
    Epochs are 1-based.
    """
    warmup_epochs = 50
    if epoch <= warmup_epochs:
        warmed = lr * epoch / warmup_epochs
        for group in optimizer.param_groups:
            group["lr"] = warmed
def train(model, graph, labels, train_idx, optimizer, use_labels):
    """Run one full-graph optimization step.

    Returns (loss, pred) where pred holds logits for every node.
    """
    model.train()
    feat = graph.ndata["feat"]

    # Randomly split the training nodes in half each step. With label
    # propagation (use_labels), one half feeds its labels in as input
    # features and the other half supplies the loss; otherwise half the
    # training nodes are sampled for the loss.
    mask = th.rand(train_idx.shape) < 0.5
    if use_labels:
        train_labels_idx = train_idx[mask]
        train_pred_idx = train_idx[~mask]
        feat = add_labels(feat, labels, train_labels_idx)
    else:
        train_pred_idx = train_idx[mask]

    optimizer.zero_grad()
    pred = model(graph, feat)
    loss = cross_entropy(pred[train_pred_idx], labels[train_pred_idx])
    loss.backward()
    optimizer.step()

    return loss, pred
@th.no_grad()
def evaluate(model, graph, labels, train_idx, val_idx, test_idx, use_labels, evaluator):
    """Full-graph evaluation on all three splits.

    Returns a 6-tuple:
    (train_acc, val_acc, test_acc, train_loss, val_loss, test_loss).
    """
    model.eval()
    feat = graph.ndata["feat"]

    if use_labels:
        # At evaluation time every training label may be exposed as input.
        feat = add_labels(feat, labels, train_idx)

    pred = model(graph, feat)

    splits = (train_idx, val_idx, test_idx)
    accs = tuple(compute_acc(pred[idx], labels[idx], evaluator) for idx in splits)
    losses = tuple(cross_entropy(pred[idx], labels[idx]) for idx in splits)
    return accs + losses
def run(args, graph, labels, train_idx, val_idx, test_idx, evaluator, n_running):
    """Train one model instance for args.n_epochs epochs.

    Returns (best_val_acc, best_test_acc), where "best" is the epoch with the
    lowest validation loss. When --plot-curves is set, accuracy and loss
    curves are saved as gat_acc_<run>.png / gat_loss_<run>.png.
    """
    # define model and optimizer
    model = gen_model(args)
    model = model.to(device)

    optimizer = optim.RMSprop(model.parameters(), lr=args.lr, weight_decay=args.wd)

    # training loop
    total_time = 0
    best_val_acc, best_test_acc, best_val_loss = 0, 0, float("inf")

    # Per-epoch history, kept for the optional curve plots below.
    accs, train_accs, val_accs, test_accs = [], [], [], []
    losses, train_losses, val_losses, test_losses = [], [], [], []

    for epoch in range(1, args.n_epochs + 1):
        tic = time.time()

        adjust_learning_rate(optimizer, args.lr, epoch)

        loss, pred = train(model, graph, labels, train_idx, optimizer, args.use_labels)
        # Training accuracy on this step's (label-masked) predictions.
        acc = compute_acc(pred[train_idx], labels[train_idx], evaluator)

        train_acc, val_acc, test_acc, train_loss, val_loss, test_loss = evaluate(
            model, graph, labels, train_idx, val_idx, test_idx, args.use_labels, evaluator
        )

        toc = time.time()
        total_time += toc - tic

        # Model selection: report the test accuracy of the epoch with the
        # lowest validation loss.
        if val_loss < best_val_loss:
            best_val_loss = val_loss
            best_val_acc = val_acc
            best_test_acc = test_acc

        if epoch % args.log_every == 0:
            print(f"Run: {n_running}/{args.n_runs}, Epoch: {epoch}/{args.n_epochs}")
            print(
                f"Loss: {loss.item():.4f}, Acc: {acc:.4f}\n"
                f"Train/Val/Test loss: {train_loss:.4f}/{val_loss:.4f}/{test_loss:.4f}\n"
                f"Train/Val/Test/Best val/Best test acc: {train_acc:.4f}/{val_acc:.4f}/{test_acc:.4f}/{best_val_acc:.4f}/{best_test_acc:.4f}"
            )

        # Append this epoch's metrics to the corresponding history lists.
        for l, e in zip(
            [accs, train_accs, val_accs, test_accs, losses, train_losses, val_losses, test_losses],
            [acc, train_acc, val_acc, test_acc, loss.item(), train_loss, val_loss, test_loss],
        ):
            l.append(e)

    print("*" * 50)
    print(f"Average epoch time: {total_time / args.n_epochs}, Test acc: {best_test_acc}")

    if args.plot_curves:
        # Accuracy curves (fine y-grid between 0 and 1).
        fig = plt.figure(figsize=(24, 24))
        ax = fig.gca()
        ax.set_xticks(np.arange(0, args.n_epochs, 100))
        ax.set_yticks(np.linspace(0, 1.0, 101))
        ax.tick_params(labeltop=True, labelright=True)
        for y, label in zip([accs, train_accs, val_accs, test_accs], ["acc", "train acc", "val acc", "test acc"]):
            plt.plot(range(args.n_epochs), y, label=label)
        ax.xaxis.set_major_locator(MultipleLocator(100))
        ax.xaxis.set_minor_locator(AutoMinorLocator(1))
        ax.yaxis.set_major_locator(MultipleLocator(0.01))
        ax.yaxis.set_minor_locator(AutoMinorLocator(2))
        plt.grid(which="major", color="red", linestyle="dotted")
        plt.grid(which="minor", color="orange", linestyle="dotted")
        plt.legend()
        plt.tight_layout()
        plt.savefig(f"gat_acc_{n_running}.png")

        # Loss curves.
        fig = plt.figure(figsize=(24, 24))
        ax = fig.gca()
        ax.set_xticks(np.arange(0, args.n_epochs, 100))
        ax.tick_params(labeltop=True, labelright=True)
        for y, label in zip(
            [losses, train_losses, val_losses, test_losses], ["loss", "train loss", "val loss", "test loss"]
        ):
            plt.plot(range(args.n_epochs), y, label=label)
        ax.xaxis.set_major_locator(MultipleLocator(100))
        ax.xaxis.set_minor_locator(AutoMinorLocator(1))
        ax.yaxis.set_major_locator(MultipleLocator(0.1))
        ax.yaxis.set_minor_locator(AutoMinorLocator(5))
        plt.grid(which="major", color="red", linestyle="dotted")
        plt.grid(which="minor", color="orange", linestyle="dotted")
        plt.legend()
        plt.tight_layout()
        plt.savefig(f"gat_loss_{n_running}.png")

    return best_val_acc, best_test_acc
def count_parameters(args):
    """Build a fresh model from `args` and return its trainable parameter
    count. Also prints the per-tensor sizes, as the original did."""
    model = gen_model(args)
    sizes = [np.prod(p.size()) for p in model.parameters() if p.requires_grad]
    print(sizes)
    return sum(sizes)
def main():
    """CLI entry point: parse arguments, load ogbn-arxiv, run training
    args.n_runs times and print aggregate accuracy statistics."""
    global device, in_feats, n_classes, epsilon

    argparser = argparse.ArgumentParser("GAT on OGBN-Arxiv", formatter_class=argparse.ArgumentDefaultsHelpFormatter)
    argparser.add_argument("--cpu", action="store_true", help="CPU mode. This option overrides --gpu.")
    argparser.add_argument("--gpu", type=int, default=0, help="GPU device ID.")
    argparser.add_argument("--n-runs", type=int, default=10)
    argparser.add_argument("--n-epochs", type=int, default=2000)
    argparser.add_argument(
        "--use-labels", action="store_true", help="Use labels in the training set as input features."
    )
    argparser.add_argument("--use-norm", action="store_true", help="Use symmetrically normalized adjacency matrix.")
    argparser.add_argument("--lr", type=float, default=0.002)
    argparser.add_argument("--n-layers", type=int, default=3)
    argparser.add_argument("--n-heads", type=int, default=3)
    argparser.add_argument("--n-hidden", type=int, default=256)
    argparser.add_argument("--dropout", type=float, default=0.75)
    argparser.add_argument("--attn_drop", type=float, default=0.05)
    argparser.add_argument("--wd", type=float, default=0)
    argparser.add_argument("--log-every", type=int, default=20)
    argparser.add_argument("--plot-curves", action="store_true")
    args = argparser.parse_args()

    if args.cpu:
        device = th.device("cpu")
    else:
        device = th.device("cuda:%d" % args.gpu)

    # load data
    data = DglNodePropPredDataset(name="ogbn-arxiv")
    evaluator = Evaluator(name="ogbn-arxiv")

    splitted_idx = data.get_idx_split()
    train_idx, val_idx, test_idx = splitted_idx["train"], splitted_idx["valid"], splitted_idx["test"]
    graph, labels = data[0]

    # add reverse edges
    # ogbn-arxiv is directed (citations); mirror each edge so messages flow
    # both ways before message passing.
    srcs, dsts = graph.all_edges()
    graph.add_edges(dsts, srcs)

    # add self-loop
    print(f"Total edges before adding self-loop {graph.number_of_edges()}")
    graph = graph.remove_self_loop().add_self_loop()
    print(f"Total edges after adding self-loop {graph.number_of_edges()}")

    # Publish dataset dimensions through the module-level globals used by
    # gen_model/add_labels.
    in_feats = graph.ndata["feat"].shape[1]
    n_classes = (labels.max() + 1).item()

    # graph.create_format_()

    train_idx = train_idx.to(device)
    val_idx = val_idx.to(device)
    test_idx = test_idx.to(device)
    labels = labels.to(device)
    graph = graph.to(device)

    # run
    val_accs = []
    test_accs = []

    for i in range(1, args.n_runs + 1):
        val_acc, test_acc = run(args, graph, labels, train_idx, val_idx, test_idx, evaluator, i)
        val_accs.append(val_acc)
        test_accs.append(test_acc)

    print(f"Runned {args.n_runs} times")
    print("Val Accs:", val_accs)
    print("Test Accs:", test_accs)
    print(f"Average val accuracy: {np.mean(val_accs)} ± {np.std(val_accs)}")
    print(f"Average test accuracy: {np.mean(test_accs)} ± {np.std(test_accs)}")
    print(f"Number of params: {count_parameters(args)}")
# Script entry point.
if __name__ == "__main__":
    main()
# Runned 10 times
# Val Accs: [0.7505956575724018, 0.7489177489177489, 0.7502600758414711, 0.7498573777643545, 0.75079700661096, 0.7504278667069365, 0.7505285412262156, 0.7512332628611699, 0.7503271921876573, 0.750729890264774]
# Test Accs: [0.7374853404110857, 0.7357982017570932, 0.7359216509268975, 0.736826944838796, 0.7385140834927885, 0.7370944180400387, 0.7358187766187272, 0.7365183219142851, 0.7343168117194412, 0.7371767174865749]
# Average val accuracy: 0.750367461995369 ± 0.0005934770264509258
# Average test accuracy: 0.7365471267205728 ± 0.0010945826389317434
# Number of params: 1628440
...@@ -17,6 +17,7 @@ from models import GCN ...@@ -17,6 +17,7 @@ from models import GCN
device = None device = None
in_feats, n_classes = None, None in_feats, n_classes = None, None
epsilon = 1 - math.log(2)
def gen_model(args): def gen_model(args):
...@@ -31,7 +32,7 @@ def gen_model(args): ...@@ -31,7 +32,7 @@ def gen_model(args):
def cross_entropy(x, labels): def cross_entropy(x, labels):
y = F.cross_entropy(x, labels[:, 0], reduction="none") y = F.cross_entropy(x, labels[:, 0], reduction="none")
y = th.log(0.5 + y) - math.log(0.5) y = th.log(epsilon + y) - math.log(epsilon)
return th.mean(y) return th.mean(y)
...@@ -249,7 +250,7 @@ def main(): ...@@ -249,7 +250,7 @@ def main():
in_feats = graph.ndata["feat"].shape[1] in_feats = graph.ndata["feat"].shape[1]
n_classes = (labels.max() + 1).item() n_classes = (labels.max() + 1).item()
graph.create_formats_() # graph.create_format_()
train_idx = train_idx.to(device) train_idx = train_idx.to(device)
val_idx = val_idx.to(device) val_idx = val_idx.to(device)
......
import dgl.nn.pytorch as dglnn
import torch
import torch.nn as nn import torch.nn as nn
from dgl import function as fn
from dgl._ffi.base import DGLError
from dgl.nn.pytorch.utils import Identity
from dgl.ops import edge_softmax
from dgl.utils import expand_as_pair
import dgl.nn.pytorch as dglnn
class Bias(nn.Module):
    """A standalone learnable additive bias vector.

    Applied as ``inputs + bias``, broadcasting over leading dimensions.
    The bias is zero-initialized.
    """

    def __init__(self, size):
        super().__init__()
        # Allocate first, then delegate to reset_parameters() so that
        # re-initialization behaves exactly like construction.
        self.bias = nn.Parameter(torch.Tensor(size))
        self.reset_parameters()

    def reset_parameters(self):
        """Reset the bias to all zeros."""
        nn.init.zeros_(self.bias)

    def forward(self, inputs):
        """Return ``inputs`` shifted by the learned bias."""
        return inputs + self.bias
class GCN(nn.Module): class GCN(nn.Module):
...@@ -50,3 +69,184 @@ class GCN(nn.Module): ...@@ -50,3 +69,184 @@ class GCN(nn.Module):
h = self.dropout(h) h = self.dropout(h)
return h return h
class GATConv(nn.Module):
    """Graph attention convolution, a local variant of DGL's GATConv that adds
    an optional GCN-style symmetric degree normalization (``norm="both"``).

    Output per node has shape (N, num_heads, out_feats).
    """

    def __init__(
        self,
        in_feats,
        out_feats,
        num_heads=1,
        feat_drop=0.0,
        attn_drop=0.0,
        negative_slope=0.2,
        residual=False,
        activation=None,
        allow_zero_in_degree=False,
        norm="none",
    ):
        super(GATConv, self).__init__()
        if norm not in ("none", "both"):
            raise DGLError('Invalid norm value. Must be either "none", "both".' ' But got "{}".'.format(norm))
        self._num_heads = num_heads
        # expand_as_pair handles both homogeneous (int) and bipartite (tuple)
        # input feature sizes.
        self._in_src_feats, self._in_dst_feats = expand_as_pair(in_feats)
        self._out_feats = out_feats
        self._allow_zero_in_degree = allow_zero_in_degree
        self._norm = norm
        if isinstance(in_feats, tuple):
            # Bipartite case: separate projections for source/destination nodes.
            self.fc_src = nn.Linear(self._in_src_feats, out_feats * num_heads, bias=False)
            self.fc_dst = nn.Linear(self._in_dst_feats, out_feats * num_heads, bias=False)
        else:
            self.fc = nn.Linear(self._in_src_feats, out_feats * num_heads, bias=False)
        # Left/right halves of the attention vector a = [a_l || a_r].
        self.attn_l = nn.Parameter(torch.FloatTensor(size=(1, num_heads, out_feats)))
        self.attn_r = nn.Parameter(torch.FloatTensor(size=(1, num_heads, out_feats)))
        self.feat_drop = nn.Dropout(feat_drop)
        self.attn_drop = nn.Dropout(attn_drop)
        self.leaky_relu = nn.LeakyReLU(negative_slope)
        if residual:
            # Project the residual only when dimensions differ.
            if self._in_dst_feats != out_feats:
                self.res_fc = nn.Linear(self._in_dst_feats, num_heads * out_feats, bias=False)
            else:
                self.res_fc = Identity()
        else:
            # Registered as a buffer so `self.res_fc is None` works uniformly.
            self.register_buffer("res_fc", None)
        self.reset_parameters()
        self._activation = activation

    def reset_parameters(self):
        """Xavier-initialize all projection weights and attention vectors."""
        gain = nn.init.calculate_gain("relu")
        if hasattr(self, "fc"):
            nn.init.xavier_normal_(self.fc.weight, gain=gain)
        else:
            nn.init.xavier_normal_(self.fc_src.weight, gain=gain)
            nn.init.xavier_normal_(self.fc_dst.weight, gain=gain)
        nn.init.xavier_normal_(self.attn_l, gain=gain)
        nn.init.xavier_normal_(self.attn_r, gain=gain)
        if isinstance(self.res_fc, nn.Linear):
            nn.init.xavier_normal_(self.res_fc.weight, gain=gain)

    def set_allow_zero_in_degree(self, set_value):
        """Toggle the zero-in-degree check performed in forward()."""
        self._allow_zero_in_degree = set_value

    def forward(self, graph, feat):
        """Compute multi-head attention output for every destination node.

        `feat` is either a single tensor (homogeneous graph) or an
        (src_feat, dst_feat) tuple for bipartite graphs/blocks.
        """
        with graph.local_scope():
            if not self._allow_zero_in_degree:
                # Zero-in-degree nodes would produce all-zero output.
                if (graph.in_degrees() == 0).any():
                    assert False
            if isinstance(feat, tuple):
                h_src = self.feat_drop(feat[0])
                h_dst = self.feat_drop(feat[1])
                if not hasattr(self, "fc_src"):
                    # Fall back to the shared projection for both sides.
                    self.fc_src, self.fc_dst = self.fc, self.fc
                feat_src, feat_dst = h_src, h_dst
                # (N, num_heads, out_feats) after projection + reshape.
                feat_src = self.fc_src(h_src).view(-1, self._num_heads, self._out_feats)
                feat_dst = self.fc_dst(h_dst).view(-1, self._num_heads, self._out_feats)
            else:
                h_src = h_dst = self.feat_drop(feat)
                feat_src, feat_dst = h_src, h_dst
                feat_src = feat_dst = self.fc(h_src).view(-1, self._num_heads, self._out_feats)
                if graph.is_block:
                    # In a sampled block, destination nodes are a prefix of
                    # the source nodes.
                    feat_dst = feat_src[: graph.number_of_dst_nodes()]
            if self._norm == "both":
                # GCN-style pre-normalization by 1/sqrt(out-degree).
                degs = graph.out_degrees().float().clamp(min=1)
                norm = torch.pow(degs, -0.5)
                shp = norm.shape + (1,) * (feat_src.dim() - 1)
                norm = torch.reshape(norm, shp)
                feat_src = feat_src * norm
            # NOTE: GAT paper uses "first concatenation then linear projection"
            # to compute attention scores, while ours is "first projection then
            # addition", the two approaches are mathematically equivalent:
            # We decompose the weight vector a mentioned in the paper into
            # [a_l || a_r], then
            # a^T [Wh_i || Wh_j] = a_l Wh_i + a_r Wh_j
            # Our implementation is much efficient because we do not need to
            # save [Wh_i || Wh_j] on edges, which is not memory-efficient. Plus,
            # addition could be optimized with DGL's built-in function u_add_v,
            # which further speeds up computation and saves memory footprint.
            el = (feat_src * self.attn_l).sum(dim=-1).unsqueeze(-1)
            er = (feat_dst * self.attn_r).sum(dim=-1).unsqueeze(-1)
            graph.srcdata.update({"ft": feat_src, "el": el})
            graph.dstdata.update({"er": er})
            # compute edge attention, el and er are a_l Wh_i and a_r Wh_j respectively.
            graph.apply_edges(fn.u_add_v("el", "er", "e"))
            e = self.leaky_relu(graph.edata.pop("e"))
            # compute softmax
            graph.edata["a"] = self.attn_drop(edge_softmax(graph, e))
            # message passing
            graph.update_all(fn.u_mul_e("ft", "a", "m"), fn.sum("m", "ft"))
            rst = graph.dstdata["ft"]
            if self._norm == "both":
                # GCN-style post-normalization by sqrt(in-degree).
                degs = graph.in_degrees().float().clamp(min=1)
                norm = torch.pow(degs, 0.5)
                shp = norm.shape + (1,) * (feat_dst.dim() - 1)
                norm = torch.reshape(norm, shp)
                rst = rst * norm
            # residual
            if self.res_fc is not None:
                resval = self.res_fc(h_dst).view(h_dst.shape[0], -1, self._out_feats)
                rst = rst + resval
            # activation
            if self._activation is not None:
                rst = self._activation(rst)
            return rst
class GAT(nn.Module):
    """Multi-layer GAT for node classification.

    Every layer sums a GATConv branch and a plain linear branch over the same
    input. Hidden layers apply batch norm, the given activation and dropout;
    the output layer averages the attention heads and adds a learned bias.
    """

    def __init__(
        self, in_feats, n_classes, n_hidden, n_layers, n_heads, activation, dropout=0.0, attn_drop=0.0, norm="none"
    ):
        super().__init__()
        self.in_feats = in_feats
        self.n_hidden = n_hidden
        self.n_classes = n_classes
        self.n_layers = n_layers
        self.num_heads = n_heads

        self.convs = nn.ModuleList()
        self.linear = nn.ModuleList()
        self.bns = nn.ModuleList()
        self.biases = nn.ModuleList()

        for layer in range(n_layers):
            first = layer == 0
            last = layer == n_layers - 1
            # Hidden layers keep all heads concatenated, so the next layer's
            # input width is n_heads * n_hidden.
            dim_in = in_feats if first else n_heads * n_hidden
            dim_out = n_classes if last else n_hidden
            self.convs.append(GATConv(dim_in, dim_out, num_heads=n_heads, attn_drop=attn_drop, norm=norm))
            self.linear.append(nn.Linear(dim_in, n_heads * dim_out, bias=False))
            if not last:
                self.bns.append(nn.BatchNorm1d(n_heads * dim_out))

        self.bias_last = Bias(n_classes)

        # Lighter dropout on the raw input features than between layers.
        self.dropout0 = nn.Dropout(min(0.1, dropout))
        self.dropout = nn.Dropout(dropout)
        self.activation = activation

    def forward(self, graph, feat):
        h = self.dropout0(feat)

        for layer, (conv, linear) in enumerate(zip(self.convs, self.linear)):
            attn_out = conv(graph, h)
            lin_out = linear(h).view(attn_out.shape)
            h = attn_out + lin_out
            if layer < self.n_layers - 1:
                # Flatten heads before batch norm, then nonlinearity + dropout.
                h = self.bns[layer](h.flatten(1))
                h = self.dropout(self.activation(h))

        # Average over attention heads for the final class scores.
        h = h.mean(1)
        return self.bias_last(h)
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment