Unverified Commit 305d5c16 authored by espylapiza, committed by GitHub

[Example] update ogbn-arxiv and ogbn-proteins results (#3018)



* [Example] GCN on ogbn-arxiv dataset

* Add README.md

* Update GCN implementation on ogbn-arxiv

* Update GCN on ogbn-arxiv

* Fix typo

* Use evaluator to get results

* Fix duplicated

* Fix duplicated

* Update GCN on ogbn-arxiv

* Add GAT for ogbn-arxiv

* Update README.md

* Update GAT on ogbn-arxiv

* Update README.md

* Update README

* Update README.md

* Update README.md

* Add GAT implementation for ogbn-proteins

* Add GAT implementation for ogbn-proteins

* Add GAT implementation for ogbn-proteins

* Add GAT implementation for ogbn-proteins

* Add GAT implementation for ogbn-proteins

* Update examples for ogbn-arxiv

* Update examples for ogbn-arxiv

* Update examples for ogbn-arxiv

* Update examples for ogbn-arxiv

* Update examples for ogbn-arxiv

* Update examples for ogbn-arxiv

* Update examples for ogbn-arxiv

* Update examples for ogbn-arxiv

* Update examples for ogbn-arxiv

* Update examples for ogbn-arxiv

* Update examples for ogbn-arxiv

* Update examples for ogbn-arxiv

* Update examples for ogbn-arxiv

* Update examples for ogbn-arxiv

* Add examples for ogbn-products.

* Update examples for ogbn-proteins.

* Update examples for ogbn-proteins.

* Update examples for ogbn-proteins.

* Update examples for ogbn-proteins.

* Update examples for ogbn-proteins.

* Update examples for ogbn-proteins.

* Update examples for ogbn-proteins.

* Update examples for ogbn-proteins.

* Update examples for ogbn-proteins.

* Update examples for ogbn-products.

* Update examples for ogbn-products.

* Update examples for ogbn-arxiv.

* Update examples for ogbn-arxiv.

* Update examples for ogbn-arxiv.

* Update examples for ogbn-proteins.

* Update examples for ogbn-products.

* Update examples for ogbn-proteins.

* Update examples for ogbn-arxiv.

* Update examples for ogbn-proteins.

* Update README.md

* [Example] ogbn-arxiv & ogbn-proteins

* [Example] ogbn-arxiv & ogbn-proteins

* [Example] ogbn-arxiv & ogbn-proteins

* [Example] ogbn-arxiv & ogbn-proteins
Co-authored-by: Mufei Li <mufeili1996@gmail.com>
Co-authored-by: Zihao Ye <expye@outlook.com>
parent 89322856
# DGL examples for ogbn-arxiv
DGL implementation of GCN and GAT for [ogbn-arxiv](https://ogb.stanford.edu/docs/nodeprop/), using some of the techniques from *Bag of Tricks for Node Classification with Graph Neural Networks* ([https://arxiv.org/abs/2103.13355](https://arxiv.org/abs/2103.13355)).
Requires DGL 0.5 or later versions.
### GCN
For the best score, run `gcn.py` with `--use-linear` and `--use-labels` enabled and you should directly see the result.
```bash
python3 gcn.py --use-linear --use-labels
```
@@ -12,10 +14,23 @@ python3 gcn.py --use-linear --use-labels
### GAT
For the score of `GAT(norm. adj.)+labels`, run the following command and you should directly see the result.
```bash
python3 gat.py --use-norm --use-labels --no-attn-dst --edge-drop=0.1 --input-drop=0.1
```
For the score of `GAT(norm. adj.)+label reuse`, run the following command and you should directly see the result.
```bash
python3 gat.py --use-norm --use-labels --n-label-iters=1 --no-attn-dst --edge-drop=0.3 --input-drop=0.25
```
For the score of `GAT(norm. adj.)+label reuse+C&S`, run the following command and you should directly see the result.
```bash
python3 gat.py --use-norm --use-labels --n-label-iters=1 --no-attn-dst --edge-drop=0.3 --input-drop=0.25 --save-pred
python3 correct_and_smooth.py --use-norm
```
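Note that `--save-pred` makes `gat.py` save the softmax predictions of each run to `./output/`, and `correct_and_smooth.py` loads them through its `--pred-files` glob (default: `./output/*.pt`).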
## Usage & Options
@@ -23,59 +38,72 @@ python3 gat.py --use-norm --use-labels
### GCN
```
usage: GCN on OGBN-Arxiv [-h] [--cpu] [--gpu GPU] [--n-runs N_RUNS] [--n-epochs N_EPOCHS] [--use-labels] [--use-linear] [--lr LR] [--n-layers N_LAYERS] [--n-hidden N_HIDDEN]
                         [--dropout DROPOUT] [--wd WD] [--log-every LOG_EVERY] [--plot-curves]

optional arguments:
  -h, --help            show this help message and exit
  --cpu                 CPU mode. This option overrides --gpu. (default: False)
  --gpu GPU             GPU device ID. (default: 0)
  --n-runs N_RUNS       running times (default: 10)
  --n-epochs N_EPOCHS   number of epochs (default: 1000)
  --use-labels          Use labels in the training set as input features. (default: False)
  --use-linear          Use linear layer. (default: False)
  --lr LR               learning rate (default: 0.005)
  --n-layers N_LAYERS   number of layers (default: 3)
  --n-hidden N_HIDDEN   number of hidden units (default: 256)
  --dropout DROPOUT     dropout rate (default: 0.75)
  --wd WD               weight decay (default: 0)
  --log-every LOG_EVERY
                        log every LOG_EVERY epochs (default: 20)
  --plot-curves         plot learning curves (default: False)
```
### GAT
```
usage: GAT on OGBN-Arxiv [-h] [--cpu] [--gpu GPU] [--n-runs N_RUNS] [--n-epochs N_EPOCHS] [--use-labels] [--n-label-iters N_LABEL_ITERS] [--no-attn-dst]
                         [--use-norm] [--lr LR] [--n-layers N_LAYERS] [--n-heads N_HEADS] [--n-hidden N_HIDDEN] [--dropout DROPOUT] [--input-drop INPUT_DROP]
                         [--attn-drop ATTN_DROP] [--edge-drop EDGE_DROP] [--wd WD] [--log-every LOG_EVERY] [--plot-curves]

optional arguments:
  -h, --help            show this help message and exit
  --cpu                 CPU mode. This option overrides --gpu. (default: False)
  --gpu GPU             GPU device ID. (default: 0)
  --n-runs N_RUNS       running times (default: 10)
  --n-epochs N_EPOCHS   number of epochs (default: 2000)
  --use-labels          Use labels in the training set as input features. (default: False)
  --n-label-iters N_LABEL_ITERS
                        number of label iterations (default: 0)
  --no-attn-dst         Don't use attn_dst. (default: False)
  --use-norm            Use symmetrically normalized adjacency matrix. (default: False)
  --lr LR               learning rate (default: 0.002)
  --n-layers N_LAYERS   number of layers (default: 3)
  --n-heads N_HEADS     number of heads (default: 3)
  --n-hidden N_HIDDEN   number of hidden units (default: 250)
  --dropout DROPOUT     dropout rate (default: 0.75)
  --input-drop INPUT_DROP
                        input drop rate (default: 0.1)
  --attn-drop ATTN_DROP
                        attention dropout rate (default: 0.0)
  --edge-drop EDGE_DROP
                        edge drop rate (default: 0.0)
  --wd WD               weight decay (default: 0)
  --log-every LOG_EVERY
                        log every LOG_EVERY epochs (default: 20)
  --plot-curves         plot learning curves (default: False)
```
## Results
Here are the results over at least 10 runs.
| Method | Validation Accuracy | Test Accuracy | #Parameters |
|:-------------------------------:|:-------------------:|:---------------:|:-----------:|
| GCN | 0.7361 ± 0.0009 | 0.7246 ± 0.0021 | 109,608 |
| GCN+linear | 0.7397 ± 0.0010 | 0.7270 ± 0.0016 | 218,152 |
| GCN+labels | 0.7399 ± 0.0008 | 0.7259 ± 0.0006 | 119,848 |
| GCN+linear+labels | 0.7442 ± 0.0012 | 0.7306 ± 0.0024 | 238,632 |
| GAT(norm. adj.)+labels | 0.7508 ± 0.0009 | 0.7366 ± 0.0011 | 1,441,580 |
| GAT(norm. adj.)+label reuse | 0.7516 ± 0.0008 | 0.7391 ± 0.0012 | 1,441,580 |
| GAT(norm. adj.)+label reuse+C&S | 0.7519 ± 0.0008 | 0.7395 ± 0.0012 | 1,441,580 |
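Roughly, `+labels` (`--use-labels`) appends the hard labels of a random subset of the training nodes to the input features, `+label reuse` (`--n-label-iters`) additionally writes the model's own softmax predictions back into the label slots of the remaining nodes for extra forward passes, and `+C&S` applies the Correct & Smooth post-processing implemented in `correct_and_smooth.py`.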
import argparse
import glob
import numpy as np
import torch
import torch.nn.functional as F
from dgl import function as fn
from ogb.nodeproppred import DglNodePropPredDataset, Evaluator
device = None
dataset = "ogbn-arxiv"
n_node_feats, n_classes = 0, 0
def load_data(dataset):
global n_node_feats, n_classes
data = DglNodePropPredDataset(name=dataset)
evaluator = Evaluator(name=dataset)
splitted_idx = data.get_idx_split()
train_idx, val_idx, test_idx = splitted_idx["train"], splitted_idx["valid"], splitted_idx["test"]
graph, labels = data[0]
n_node_feats = graph.ndata["feat"].shape[1]
n_classes = (labels.max() + 1).item()
return graph, labels, train_idx, val_idx, test_idx, evaluator
def preprocess(graph):
global n_node_feats
# add reverse edges
srcs, dsts = graph.all_edges()
graph.add_edges(dsts, srcs)
# add self-loop
print(f"Total edges before adding self-loop {graph.number_of_edges()}")
graph = graph.remove_self_loop().add_self_loop()
print(f"Total edges after adding self-loop {graph.number_of_edges()}")
graph.create_formats_()
return graph
def general_outcome_correlation(graph, y0, n_prop=50, alpha=0.8, use_norm=False, post_step=None):
with graph.local_scope():
y = y0
for _ in range(n_prop):
if use_norm:
degs = graph.in_degrees().float().clamp(min=1)
norm = torch.pow(degs, -0.5)
shp = norm.shape + (1,) * (y.dim() - 1)
norm = torch.reshape(norm, shp)
y = y * norm
graph.srcdata.update({"y": y})
graph.update_all(fn.copy_u("y", "m"), fn.mean("m", "y"))
y = graph.dstdata["y"]
if use_norm:
degs = graph.in_degrees().float().clamp(min=1)
norm = torch.pow(degs, 0.5)
shp = norm.shape + (1,) * (y.dim() - 1)
norm = torch.reshape(norm, shp)
y = y * norm
y = alpha * y + (1 - alpha) * y0
if post_step is not None:
y = post_step(y)
return y
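# A minimal usage sketch (hypothetical toy tensors, kept as a comment so the module
# stays import-safe): smooth soft predictions over the graph, clamping each
# propagation step back into [0, 1], as done in run() below.
#
#   g = dgl.rand_graph(100, 500)   # toy graph; requires `import dgl`
#   y0 = torch.rand(100, 40)       # soft predictions for 100 nodes, 40 classes
#   y = general_outcome_correlation(
#       g, y0, n_prop=50, alpha=0.8, use_norm=True, post_step=lambda x: x.clamp(0, 1)
#   )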
def evaluate(labels, pred, train_idx, val_idx, test_idx, evaluator):
return (
evaluator(pred[train_idx], labels[train_idx]),
evaluator(pred[val_idx], labels[val_idx]),
evaluator(pred[test_idx], labels[test_idx]),
)
def run(args, graph, labels, pred, train_idx, val_idx, test_idx, evaluator):
evaluator_wrapper = lambda pred, labels: evaluator.eval(
{"y_pred": pred.argmax(dim=-1, keepdim=True), "y_true": labels}
)["acc"]
y = pred.clone()
y[train_idx] = F.one_hot(labels[train_idx], n_classes).float().squeeze(1)
# dy = torch.zeros(graph.number_of_nodes(), n_classes, device=device)
# dy[train_idx] = F.one_hot(labels[train_idx], n_classes).float().squeeze(1) - pred[train_idx]
_train_acc, val_acc, test_acc = evaluate(labels, y, train_idx, val_idx, test_idx, evaluator_wrapper)
# print("train acc:", _train_acc)
print("original val acc:", val_acc)
print("original test acc:", test_acc)
# NOTE: Only "smooth" is performed here.
# smoothed_dy = general_outcome_correlation(
# graph, dy, alpha=args.alpha1, use_norm=args.use_norm, post_step=lambda x: x.clamp(-1, 1)
# )
# y[train_idx] = F.one_hot(labels[train_idx], n_classes).float().squeeze(1)
# smoothed_dy = smoothed_dy
# y = y + args.alpha2 * smoothed_dy # .clamp(0, 1)
smoothed_y = general_outcome_correlation(
graph, y, alpha=args.alpha, use_norm=args.use_norm, post_step=lambda x: x.clamp(0, 1)
)
_train_acc, val_acc, test_acc = evaluate(labels, smoothed_y, train_idx, val_idx, test_idx, evaluator_wrapper)
# print("train acc:", _train_acc)
print("val acc:", val_acc)
print("test acc:", test_acc)
return val_acc, test_acc
def main():
global device
    argparser = argparse.ArgumentParser(description="implementation of C&S")
argparser.add_argument("--cpu", action="store_true", help="CPU mode. This option overrides --gpu.")
argparser.add_argument("--gpu", type=int, default=0, help="GPU device ID.")
argparser.add_argument("--use-norm", action="store_true", help="Use symmetrically normalized adjacency matrix.")
argparser.add_argument("--alpha", type=float, default=0.6, help="alpha")
argparser.add_argument("--pred-files", type=str, default="./output/*.pt", help="address of prediction files")
args = argparser.parse_args()
if args.cpu:
device = torch.device("cpu")
else:
device = torch.device(f"cuda:{args.gpu}")
# load data & preprocess
graph, labels, train_idx, val_idx, test_idx, evaluator = load_data(dataset)
graph = preprocess(graph)
graph, labels, train_idx, val_idx, test_idx = map(
lambda x: x.to(device), (graph, labels, train_idx, val_idx, test_idx)
)
# run
val_accs, test_accs = [], []
for pred_file in glob.iglob(args.pred_files):
print("load:", pred_file)
pred = torch.load(pred_file)
val_acc, test_acc = run(args, graph, labels, pred, train_idx, val_idx, test_idx, evaluator)
val_accs.append(val_acc)
test_accs.append(test_acc)
print(args)
print(f"Runned {len(val_accs)} times")
print("Val Accs:", val_accs)
print("Test Accs:", test_accs)
print(f"Average val accuracy: {np.mean(val_accs)} ± {np.std(val_accs)}")
print(f"Average test accuracy: {np.mean(test_accs)} ± {np.std(test_accs)}")
if __name__ == "__main__":
main()
# Namespace(alpha=0.6, cpu=False, gpu=0, pred_files='./output/*.pt', use_norm=True)
# Ran 20 times
# Val Accs: [0.7523742407463337, 0.750729890264774, 0.7524077989194268, 0.7527098224772644, 0.752508473438706, 0.7509983556495184, 0.751904426323031, 0.7514010537266351, 0.7524077989194268, 0.753716567670056, 0.7523071244001477, 0.7518373099768448, 0.7528440551696366, 0.7509983556495184, 0.7521057753615893, 0.7520386590154032, 0.7500251686298198, 0.7513674955535421, 0.7509312393033323, 0.7518037518037518]
# Test Accs: [0.7392753533732486, 0.7381437359833755, 0.7412093903668497, 0.7402629467316832, 0.7386169578009588, 0.7380408616752052, 0.7397280003291978, 0.7401189227002448, 0.7424233072032591, 0.7397280003291978, 0.7378351130588647, 0.7400160483920746, 0.740921342303973, 0.7385758080776906, 0.7411682406435817, 0.7389667304487377, 0.7396457008826616, 0.7384935086311545, 0.7396251260210275, 0.7379997119519371]
# Average val accuracy: 0.751870868149938 ± 0.0008415008835817228
# Average test accuracy: 0.7395397403452462 ± 0.0012162384423867229
@@ -3,10 +3,13 @@
import argparse
import math
import os
import random
import time
import dgl
import numpy as np
import torch
import torch.nn.functional as F
import torch.optim as optim
from matplotlib import pyplot as plt
@@ -15,56 +18,93 @@ from ogb.nodeproppred import DglNodePropPredDataset, Evaluator
from models import GAT
epsilon = 1 - math.log(2)
device = None
dataset = "ogbn-arxiv"
n_node_feats, n_classes = 0, 0
def seed(seed=0):
random.seed(seed)
np.random.seed(seed)
torch.manual_seed(seed)
torch.cuda.manual_seed(seed)
torch.cuda.manual_seed_all(seed)
torch.backends.cudnn.deterministic = True
torch.backends.cudnn.benchmark = False
dgl.random.seed(seed)
def load_data(dataset):
global n_node_feats, n_classes
data = DglNodePropPredDataset(name=dataset)
evaluator = Evaluator(name=dataset)
splitted_idx = data.get_idx_split()
train_idx, val_idx, test_idx = splitted_idx["train"], splitted_idx["valid"], splitted_idx["test"]
graph, labels = data[0]
n_node_feats = graph.ndata["feat"].shape[1]
n_classes = (labels.max() + 1).item()
return graph, labels, train_idx, val_idx, test_idx, evaluator
def preprocess(graph):
global n_node_feats
# make bidirected
feat = graph.ndata["feat"]
graph = dgl.to_bidirected(graph)
graph.ndata["feat"] = feat
# add self-loop
print(f"Total edges before adding self-loop {graph.number_of_edges()}")
graph = graph.remove_self_loop().add_self_loop()
print(f"Total edges after adding self-loop {graph.number_of_edges()}")
graph.create_formats_()
return graph
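# A short note (not from the original file): ogbn-arxiv is a directed citation
# graph, so preprocess() makes it bidirected to let messages flow both ways, and
# adds a self-loop so every node also aggregates its own features;
# remove_self_loop() first keeps self-loops from being added twice.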
def gen_model(args):
if args.use_labels:
n_node_feats_ = n_node_feats + n_classes
else:
n_node_feats_ = n_node_feats
model = GAT(
n_node_feats_,
n_classes,
n_hidden=args.n_hidden,
n_layers=args.n_layers,
n_heads=args.n_heads,
activation=F.relu,
dropout=args.dropout,
input_drop=args.input_drop,
attn_drop=args.attn_drop,
edge_drop=args.edge_drop,
use_attn_dst=not args.no_attn_dst,
use_symmetric_norm=args.use_norm,
)
return model
def custom_loss_function(x, labels):
    y = F.cross_entropy(x, labels[:, 0], reduction="none")
    y = torch.log(epsilon + y) - math.log(epsilon)
    return torch.mean(y)
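# A note on the loss above (our reading of the loss-modification trick from the
# Bag of Tricks paper cited in the README): cross-entropy y is remapped to
# log(epsilon + y) - log(epsilon) with epsilon = 1 - log(2). The map is zero at
# y = 0, increasing, and only logarithmic for large y, so badly misclassified
# outliers contribute smaller gradients than under plain cross-entropy.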
def add_labels(feat, labels, idx):
    onehot = torch.zeros([feat.shape[0], n_classes], device=device)
onehot[idx, labels[idx, 0]] = 1
    return torch.cat([feat, onehot], dim=-1)
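# Shape sketch (hypothetical sizes): with ogbn-arxiv's 128-dim features and 40
# classes, add_labels(feat, labels, idx) maps an [N, 128] feat to [N, 168];
# rows listed in `idx` carry a one-hot label block, all other rows are zero there.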
def adjust_learning_rate(optimizer, lr, epoch):
@@ -73,68 +113,86 @@ def adjust_learning_rate(optimizer, lr, epoch):
param_group["lr"] = lr * epoch / 50
def train(args, model, graph, labels, train_idx, val_idx, test_idx, optimizer, evaluator):
model.train()
feat = graph.ndata["feat"]
    if args.use_labels:
        mask = torch.rand(train_idx.shape) < args.mask_rate
train_labels_idx = train_idx[mask]
train_pred_idx = train_idx[~mask]
feat = add_labels(feat, labels, train_labels_idx)
else:
        mask = torch.rand(train_idx.shape) < args.mask_rate
train_pred_idx = train_idx[mask]
optimizer.zero_grad()
pred = model(graph, feat)
if args.n_label_iters > 0:
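        # label reuse: overwrite the label slots of nodes whose labels were not fed
        # in with the model's own (detached) softmax predictions, then run extra
        # forward passes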
unlabel_idx = torch.cat([train_pred_idx, val_idx, test_idx])
for _ in range(args.n_label_iters):
pred = pred.detach()
torch.cuda.empty_cache()
feat[unlabel_idx, -n_classes:] = F.softmax(pred[unlabel_idx], dim=-1)
pred = model(graph, feat)
loss = custom_loss_function(pred[train_pred_idx], labels[train_pred_idx])
loss.backward()
optimizer.step()
return evaluator(pred[train_idx], labels[train_idx]), loss.item()
@torch.no_grad()
def evaluate(args, model, graph, labels, train_idx, val_idx, test_idx, evaluator):
model.eval()
feat = graph.ndata["feat"]
if args.use_labels:
feat = add_labels(feat, labels, train_idx)
pred = model(graph, feat)
if args.n_label_iters > 0:
unlabel_idx = torch.cat([val_idx, test_idx])
for _ in range(args.n_label_iters):
feat[unlabel_idx, -n_classes:] = F.softmax(pred[unlabel_idx], dim=-1)
pred = model(graph, feat)
train_loss = custom_loss_function(pred[train_idx], labels[train_idx])
val_loss = custom_loss_function(pred[val_idx], labels[val_idx])
test_loss = custom_loss_function(pred[test_idx], labels[test_idx])
return (
evaluator(pred[train_idx], labels[train_idx]),
evaluator(pred[val_idx], labels[val_idx]),
evaluator(pred[test_idx], labels[test_idx]),
train_loss,
val_loss,
test_loss,
pred,
)
def run(args, graph, labels, train_idx, val_idx, test_idx, evaluator, n_running):
evaluator_wrapper = lambda pred, labels: evaluator.eval(
{"y_pred": pred.argmax(dim=-1, keepdim=True), "y_true": labels}
)["acc"]
# define model and optimizer
model = gen_model(args).to(device)
optimizer = optim.RMSprop(model.parameters(), lr=args.lr, weight_decay=args.wd)
# training loop
total_time = 0
best_val_acc, final_test_acc, best_val_loss = 0, 0, float("inf")
final_pred = None
accs, train_accs, val_accs, test_accs = [], [], [], []
losses, train_losses, val_losses, test_losses = [], [], [], []
@@ -144,11 +202,10 @@ def run(args, graph, labels, train_idx, val_idx, test_idx, evaluator, n_running)
adjust_learning_rate(optimizer, args.lr, epoch)
acc, loss = train(args, model, graph, labels, train_idx, val_idx, test_idx, optimizer, evaluator_wrapper)
train_acc, val_acc, test_acc, train_loss, val_loss, test_loss, pred = evaluate(
args, model, graph, labels, train_idx, val_idx, test_idx, evaluator_wrapper
)
toc = time.time()
@@ -157,25 +214,28 @@ def run(args, graph, labels, train_idx, val_idx, test_idx, evaluator, n_running)
if val_loss < best_val_loss:
best_val_loss = val_loss
best_val_acc = val_acc
final_test_acc = test_acc
final_pred = pred
if epoch == args.n_epochs or epoch % args.log_every == 0:
print(
f"Loss: {loss.item():.4f}, Acc: {acc:.4f}\n"
f"Run: {n_running}/{args.n_runs}, Epoch: {epoch}/{args.n_epochs}, Average epoch time: {total_time / epoch:.2f}\n"
f"Loss: {loss:.4f}, Acc: {acc:.4f}\n"
f"Train/Val/Test loss: {train_loss:.4f}/{val_loss:.4f}/{test_loss:.4f}\n"
f"Train/Val/Test/Best val/Best test acc: {train_acc:.4f}/{val_acc:.4f}/{test_acc:.4f}/{best_val_acc:.4f}/{best_test_acc:.4f}"
f"Train/Val/Test/Best val/Final test acc: {train_acc:.4f}/{val_acc:.4f}/{test_acc:.4f}/{best_val_acc:.4f}/{final_test_acc:.4f}"
)
for l, e in zip(
[accs, train_accs, val_accs, test_accs, losses, train_losses, val_losses, test_losses],
[acc, train_acc, val_acc, test_acc, loss, train_loss, val_loss, test_loss],
):
l.append(e)
print("*" * 50)
print(f"Average epoch time: {total_time / args.n_epochs}, Test acc: {best_test_acc}")
print(f"Best val acc: {best_val_acc}, Final test acc: {final_test_acc}")
print("*" * 50)
# plot learning curves
if args.plot_curves:
fig = plt.figure(figsize=(24, 24))
ax = fig.gca()
@@ -183,7 +243,7 @@ def run(args, graph, labels, train_idx, val_idx, test_idx, evaluator, n_running)
ax.set_yticks(np.linspace(0, 1.0, 101))
ax.tick_params(labeltop=True, labelright=True)
for y, label in zip([accs, train_accs, val_accs, test_accs], ["acc", "train acc", "val acc", "test acc"]):
plt.plot(range(args.n_epochs), y, label=label, linewidth=1)
ax.xaxis.set_major_locator(MultipleLocator(100))
ax.xaxis.set_minor_locator(AutoMinorLocator(1))
ax.yaxis.set_major_locator(MultipleLocator(0.01))
@@ -201,7 +261,7 @@ def run(args, graph, labels, train_idx, val_idx, test_idx, evaluator, n_running)
for y, label in zip(
[losses, train_losses, val_losses, test_losses], ["loss", "train loss", "val loss", "test loss"]
):
plt.plot(range(args.n_epochs), y, label=label, linewidth=1)
ax.xaxis.set_major_locator(MultipleLocator(100))
ax.xaxis.set_minor_locator(AutoMinorLocator(1))
ax.yaxis.set_major_locator(MultipleLocator(0.1))
@@ -212,79 +272,76 @@ def run(args, graph, labels, train_idx, val_idx, test_idx, evaluator, n_running)
plt.tight_layout()
plt.savefig(f"gat_loss_{n_running}.png")
if args.save_pred:
os.makedirs("./output", exist_ok=True)
torch.save(F.softmax(final_pred, dim=1), f"./output/{n_running}.pt")
return best_val_acc, final_test_acc
def count_parameters(args):
model = gen_model(args)
return sum([p.numel() for p in model.parameters() if p.requires_grad])
def main():
global device, n_node_feats, n_classes, epsilon
argparser = argparse.ArgumentParser("GAT on OGBN-Arxiv", formatter_class=argparse.ArgumentDefaultsHelpFormatter)
argparser = argparse.ArgumentParser(
"GAT implementation on ogbn-arxiv", formatter_class=argparse.ArgumentDefaultsHelpFormatter
)
argparser.add_argument("--cpu", action="store_true", help="CPU mode. This option overrides --gpu.")
argparser.add_argument("--gpu", type=int, default=0, help="GPU device ID.")
argparser.add_argument("--n-runs", type=int, default=10)
argparser.add_argument("--n-epochs", type=int, default=2000)
argparser.add_argument("--seed", type=int, default=0, help="seed")
argparser.add_argument("--n-runs", type=int, default=10, help="running times")
argparser.add_argument("--n-epochs", type=int, default=2000, help="number of epochs")
argparser.add_argument(
"--use-labels", action="store_true", help="Use labels in the training set as input features."
)
argparser.add_argument("--n-label-iters", type=int, default=0, help="number of label iterations")
argparser.add_argument("--mask-rate", type=float, default=0.5, help="mask rate")
argparser.add_argument("--no-attn-dst", action="store_true", help="Don't use attn_dst.")
argparser.add_argument("--use-norm", action="store_true", help="Use symmetrically normalized adjacency matrix.")
argparser.add_argument("--lr", type=float, default=0.002)
argparser.add_argument("--n-layers", type=int, default=3)
argparser.add_argument("--n-heads", type=int, default=3)
argparser.add_argument("--n-hidden", type=int, default=256)
argparser.add_argument("--dropout", type=float, default=0.75)
argparser.add_argument("--attn_drop", type=float, default=0.05)
argparser.add_argument("--wd", type=float, default=0)
argparser.add_argument("--log-every", type=int, default=20)
argparser.add_argument("--plot-curves", action="store_true")
argparser.add_argument("--lr", type=float, default=0.002, help="learning rate")
argparser.add_argument("--n-layers", type=int, default=3, help="number of layers")
argparser.add_argument("--n-heads", type=int, default=3, help="number of heads")
argparser.add_argument("--n-hidden", type=int, default=250, help="number of hidden units")
argparser.add_argument("--dropout", type=float, default=0.75, help="dropout rate")
argparser.add_argument("--input-drop", type=float, default=0.1, help="input drop rate")
argparser.add_argument("--attn-drop", type=float, default=0.0, help="attention drop rate")
argparser.add_argument("--edge-drop", type=float, default=0.0, help="edge drop rate")
argparser.add_argument("--wd", type=float, default=0, help="weight decay")
argparser.add_argument("--log-every", type=int, default=20, help="log every LOG_EVERY epochs")
argparser.add_argument("--plot-curves", action="store_true", help="plot learning curves")
argparser.add_argument("--save-pred", action="store_true", help="save final predictions")
args = argparser.parse_args()
if not args.use_labels and args.n_label_iters > 0:
raise ValueError("'--use-labels' must be enabled when n_label_iters > 0")
if args.cpu:
device = th.device("cpu")
device = torch.device("cpu")
else:
device = th.device("cuda:%d" % args.gpu)
# load data
data = DglNodePropPredDataset(name="ogbn-arxiv")
evaluator = Evaluator(name="ogbn-arxiv")
splitted_idx = data.get_idx_split()
train_idx, val_idx, test_idx = splitted_idx["train"], splitted_idx["valid"], splitted_idx["test"]
graph, labels = data[0]
device = torch.device(f"cuda:{args.gpu}")
# load data & preprocess
graph, labels, train_idx, val_idx, test_idx, evaluator = load_data(dataset)
graph = preprocess(graph)
graph, labels, train_idx, val_idx, test_idx = map(
lambda x: x.to(device), (graph, labels, train_idx, val_idx, test_idx)
)
# run
val_accs, test_accs = [], []
for i in range(args.n_runs):
seed(args.seed + i)
val_acc, test_acc = run(args, graph, labels, train_idx, val_idx, test_idx, evaluator, i + 1)
val_accs.append(val_acc)
test_accs.append(test_acc)
print(args)
print(f"Runned {args.n_runs} times")
print("Val Accs:", val_accs)
print("Test Accs:", test_accs)
@@ -297,9 +354,18 @@ if __name__ == "__main__":
main()
# Namespace(attn_drop=0.0, cpu=False, dropout=0.75, edge_drop=0.1, gpu=0, input_drop=0.1, log_every=20, lr=0.002, n_epochs=2000, n_heads=3, n_hidden=250, n_label_iters=0, n_layers=3, n_runs=10, no_attn_dst=True, plot_curves=True, use_labels=True, use_norm=True, wd=0)
# Ran 10 times
# Val Accs: [0.7492868888217725, 0.7524413570925199, 0.7505620993993087, 0.7500251686298198, 0.7501929594952851, 0.7513003792073559, 0.7516695191113796, 0.7505285412262156, 0.7504949830531226, 0.7515017282459143]
# Test Accs: [0.7366829208073575, 0.7384112091846182, 0.7368886694236981, 0.7345019854741477, 0.7373001666563792, 0.7362508487130424, 0.7352221056313396, 0.736477172191017, 0.7380614365368393, 0.7362919984363105]
# Average val accuracy: 0.7508003624282694 ± 0.0008760483047616948
# Average test accuracy: 0.736608851305475 ± 0.0011192876013651112
# Number of params: 1441580
# Namespace(attn_drop=0.0, cpu=False, dropout=0.75, edge_drop=0.3, gpu=0, input_drop=0.25, log_every=20, lr=0.002, n_epochs=2000, n_heads=3, n_hidden=250, n_label_iters=1, n_layers=3, n_runs=10, no_attn_dst=True, plot_curves=True, use_labels=True, use_norm=True, wd=0)
# Ran 20 times
# Val Accs: [0.7529782878620088, 0.7521393335346823, 0.7521728917077755, 0.7504949830531226, 0.7518037518037518, 0.7518373099768448, 0.7516359609382866, 0.7511325883418907, 0.7509312393033323, 0.7515017282459143, 0.7511325883418907, 0.7514346118997282, 0.7509312393033323, 0.7521393335346823, 0.7528776133427296, 0.7522735662270545, 0.7504949830531226, 0.7522735662270545, 0.7511661465149837, 0.7501258431490989]
# Test Accs: [0.7390901796185421, 0.7398720243606361, 0.7394605271279551, 0.7384523589078863, 0.7388638561405675, 0.7397280003291978, 0.7414151389831903, 0.7376499393041582, 0.7399748986688065, 0.7400366232537087, 0.7392547785116145, 0.7388844310022015, 0.7374853404110857, 0.7384317840462523, 0.7418677859391396, 0.737937987367035, 0.7381643108450096, 0.7399543238071724, 0.7377322387506944, 0.7385758080776906]
# Average val accuracy: 0.7515738783180644 ± 0.0007617982474634186
# Average test accuracy: 0.7391416167726272 ± 0.0011522198067958794
# Number of params: 1441580
@@ -116,7 +116,7 @@ def run(args, graph, labels, train_idx, val_idx, test_idx, evaluator, n_running)
# training loop
total_time = 0
best_val_acc, final_test_acc, best_val_loss = 0, 0, float("inf")
accs, train_accs, val_accs, test_accs = [], [], [], []
losses, train_losses, val_losses, test_losses = [], [], [], []
@@ -138,18 +138,17 @@ def run(args, graph, labels, train_idx, val_idx, test_idx, evaluator, n_running)
toc = time.time()
total_time += toc - tic
# if val_acc > best_val_acc:
if val_loss < best_val_loss:
best_val_loss = val_loss
best_val_acc = val_acc
final_test_acc = test_acc
if epoch % args.log_every == 0:
print(f"Epoch: {epoch}/{args.n_epochs}")
print(
f"Run: {n_running}/{args.n_runs}, Epoch: {epoch}/{args.n_epochs}, Average epoch time: {total_time / epoch:.2f}\n"
f"Loss: {loss.item():.4f}, Acc: {acc:.4f}\n"
f"Train/Val/Test loss: {train_loss:.4f}/{val_loss:.4f}/{test_loss:.4f}\n"
f"Train/Val/Test/Best val/Best test acc: {train_acc:.4f}/{val_acc:.4f}/{test_acc:.4f}/{best_val_acc:.4f}/{best_test_acc:.4f}"
f"Train/Val/Test/Best val/Final test acc: {train_acc:.4f}/{val_acc:.4f}/{test_acc:.4f}/{best_val_acc:.4f}/{final_test_acc:.4f}"
)
for l, e in zip(
@@ -159,7 +158,8 @@ def run(args, graph, labels, train_idx, val_idx, test_idx, evaluator, n_running)
l.append(e)
print("*" * 50)
print(f"Average epoch time: {total_time / args.n_epochs}, Test acc: {best_test_acc}")
print(f"Best val acc: {best_val_acc}, Final test acc: {final_test_acc}")
print("*" * 50)
if args.plot_curves:
fig = plt.figure(figsize=(24, 24))
@@ -197,7 +197,7 @@ def run(args, graph, labels, train_idx, val_idx, test_idx, evaluator, n_running)
plt.tight_layout()
plt.savefig(f"gcn_loss_{n_running}.png")
return best_val_acc, final_test_acc
def count_parameters(args):
@@ -211,19 +211,19 @@ def main():
argparser = argparse.ArgumentParser("GCN on OGBN-Arxiv", formatter_class=argparse.ArgumentDefaultsHelpFormatter)
argparser.add_argument("--cpu", action="store_true", help="CPU mode. This option overrides --gpu.")
argparser.add_argument("--gpu", type=int, default=0, help="GPU device ID.")
argparser.add_argument("--n-runs", type=int, default=10)
argparser.add_argument("--n-epochs", type=int, default=1000)
argparser.add_argument("--n-runs", type=int, default=10, help="running times")
argparser.add_argument("--n-epochs", type=int, default=1000, help="number of epochs")
argparser.add_argument(
"--use-labels", action="store_true", help="Use labels in the training set as input features."
)
argparser.add_argument("--use-linear", action="store_true", help="Use linear layer.")
argparser.add_argument("--lr", type=float, default=0.005)
argparser.add_argument("--n-layers", type=int, default=3)
argparser.add_argument("--n-hidden", type=int, default=256)
argparser.add_argument("--dropout", type=float, default=0.5)
argparser.add_argument("--wd", type=float, default=0)
argparser.add_argument("--log-every", type=int, default=20)
argparser.add_argument("--plot-curves", action="store_true")
argparser.add_argument("--lr", type=float, default=0.005, help="learning rate")
argparser.add_argument("--n-layers", type=int, default=3, help="number of layers")
argparser.add_argument("--n-hidden", type=int, default=256, help="number of hidden units")
argparser.add_argument("--dropout", type=float, default=0.5, help="dropout rate")
argparser.add_argument("--wd", type=float, default=0, help="weight decay")
argparser.add_argument("--log-every", type=int, default=20, help="log every LOG_EVERY epochs")
argparser.add_argument("--plot-curves", action="store_true", help="plot learning curves")
args = argparser.parse_args()
if args.cpu:
@@ -250,7 +250,7 @@ def main():
in_feats = graph.ndata["feat"].shape[1]
n_classes = (labels.max() + 1).item()
graph.create_formats_()
train_idx = train_idx.to(device)
val_idx = val_idx.to(device)
......
@@ -2,24 +2,43 @@ import dgl.nn.pytorch as dglnn
import torch
import torch.nn as nn
from dgl import function as fn
from dgl.nn.pytorch.utils import Identity
from dgl.ops import edge_softmax
from dgl.utils import expand_as_pair
class ElementWiseLinear(nn.Module):
def __init__(self, size, weight=True, bias=True, inplace=False):
super().__init__()
if weight:
self.weight = nn.Parameter(torch.Tensor(size))
else:
self.weight = None
if bias:
self.bias = nn.Parameter(torch.Tensor(size))
else:
self.bias = None
self.inplace = inplace
self.reset_parameters()
def reset_parameters(self):
if self.weight is not None:
nn.init.ones_(self.weight)
if self.bias is not None:
nn.init.zeros_(self.bias)
def forward(self, x):
if self.inplace:
if self.weight is not None:
x.mul_(self.weight)
if self.bias is not None:
x.add_(self.bias)
else:
if self.weight is not None:
x = x * self.weight
if self.bias is not None:
x = x + self.bias
return x
class GCN(nn.Module):
@@ -33,7 +52,7 @@ class GCN(nn.Module):
self.convs = nn.ModuleList()
if use_linear:
self.linear = nn.ModuleList()
self.norms = nn.ModuleList()
for i in range(n_layers):
in_hidden = n_hidden if i > 0 else in_feats
@@ -44,15 +63,15 @@ class GCN(nn.Module):
if use_linear:
self.linear.append(nn.Linear(in_hidden, out_hidden, bias=False))
if i < n_layers - 1:
self.norms.append(nn.BatchNorm1d(out_hidden))
self.input_drop = nn.Dropout(min(0.1, dropout))
self.dropout = nn.Dropout(dropout)
self.activation = activation
def forward(self, graph, feat):
h = feat
h = self.input_drop(h)
for i in range(self.n_layers):
conv = self.convs[i](graph, h)
@@ -64,7 +83,7 @@ class GCN(nn.Module):
h = conv
if i < self.n_layers - 1:
h = self.norms[i](h)
h = self.activation(h)
h = self.dropout(h)
@@ -79,35 +98,36 @@ class GATConv(nn.Module):
num_heads=1,
feat_drop=0.0,
attn_drop=0.0,
edge_drop=0.0,
negative_slope=0.2,
use_attn_dst=True,
residual=False,
activation=None,
allow_zero_in_degree=False,
norm="none",
use_symmetric_norm=False,
):
super(GATConv, self).__init__()
self._num_heads = num_heads
self._in_src_feats, self._in_dst_feats = expand_as_pair(in_feats)
self._out_feats = out_feats
self._allow_zero_in_degree = allow_zero_in_degree
self._use_symmetric_norm = use_symmetric_norm
if isinstance(in_feats, tuple):
self.fc_src = nn.Linear(self._in_src_feats, out_feats * num_heads, bias=False)
self.fc_dst = nn.Linear(self._in_dst_feats, out_feats * num_heads, bias=False)
else:
self.fc = nn.Linear(self._in_src_feats, out_feats * num_heads, bias=False)
self.attn_l = nn.Parameter(torch.FloatTensor(size=(1, num_heads, out_feats)))
if use_attn_dst:
self.attn_r = nn.Parameter(torch.FloatTensor(size=(1, num_heads, out_feats)))
else:
self.register_buffer("attn_r", None)
self.feat_drop = nn.Dropout(feat_drop)
self.attn_drop = nn.Dropout(attn_drop)
self.edge_drop = edge_drop
self.leaky_relu = nn.LeakyReLU(negative_slope)
if residual:
if self._in_dst_feats != out_feats:
self.res_fc = nn.Linear(self._in_dst_feats, num_heads * out_feats, bias=False)
else:
self.res_fc = Identity()
else:
self.register_buffer("res_fc", None)
self.reset_parameters()
@@ -121,6 +141,7 @@ class GATConv(nn.Module):
nn.init.xavier_normal_(self.fc_src.weight, gain=gain)
nn.init.xavier_normal_(self.fc_dst.weight, gain=gain)
nn.init.xavier_normal_(self.attn_l, gain=gain)
if isinstance(self.attn_r, nn.Parameter):
nn.init.xavier_normal_(self.attn_r, gain=gain)
if isinstance(self.res_fc, nn.Linear):
nn.init.xavier_normal_(self.res_fc.weight, gain=gain)
@@ -143,13 +164,17 @@ class GATConv(nn.Module):
feat_src = self.fc_src(h_src).view(-1, self._num_heads, self._out_feats)
feat_dst = self.fc_dst(h_dst).view(-1, self._num_heads, self._out_feats)
else:
h_src = self.feat_drop(feat)
feat_src = h_src
feat_src = self.fc(h_src).view(-1, self._num_heads, self._out_feats)
if graph.is_block:
h_dst = h_src[: graph.number_of_dst_nodes()]
feat_dst = feat_src[: graph.number_of_dst_nodes()]
else:
h_dst = h_src
feat_dst = feat_src
if self._norm == "both":
if self._use_symmetric_norm:
degs = graph.out_degrees().float().clamp(min=1)
norm = torch.pow(degs, -0.5)
shp = norm.shape + (1,) * (feat_src.dim() - 1)
@@ -167,19 +192,30 @@ class GATConv(nn.Module):
# addition could be optimized with DGL's built-in function u_add_v,
# which further speeds up computation and saves memory footprint.
el = (feat_src * self.attn_l).sum(dim=-1).unsqueeze(-1)
graph.srcdata.update({"ft": feat_src, "el": el})
graph.dstdata.update({"er": er})
# compute edge attention, el and er are a_l Wh_i and a_r Wh_j respectively.
if self.attn_r is not None:
er = (feat_dst * self.attn_r).sum(dim=-1).unsqueeze(-1)
graph.dstdata.update({"er": er})
graph.apply_edges(fn.u_add_v("el", "er", "e"))
else:
graph.apply_edges(fn.copy_u("el", "e"))
e = self.leaky_relu(graph.edata.pop("e"))
# compute softmax
if self.training and self.edge_drop > 0:
perm = torch.randperm(graph.number_of_edges(), device=e.device)
bound = int(graph.number_of_edges() * self.edge_drop)
eids = perm[bound:]
graph.edata["a"] = torch.zeros_like(e)
graph.edata["a"][eids] = self.attn_drop(edge_softmax(graph, e[eids], eids=eids))
else:
graph.edata["a"] = self.attn_drop(edge_softmax(graph, e))
# message passing
graph.update_all(fn.u_mul_e("ft", "a", "m"), fn.sum("m", "ft"))
rst = graph.dstdata["ft"]
if self._norm == "both":
if self._use_symmetric_norm:
degs = graph.in_degrees().float().clamp(min=1)
norm = torch.pow(degs, 0.5)
shp = norm.shape + (1,) * (feat_dst.dim() - 1)
@@ -190,15 +226,29 @@ class GATConv(nn.Module):
if self.res_fc is not None:
resval = self.res_fc(h_dst).view(h_dst.shape[0], -1, self._out_feats)
rst = rst + resval
# activation
if self._activation is not None:
rst = self._activation(rst)
return rst
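# A minimal shape sketch for this GATConv variant (hypothetical sizes, kept as a
# comment):
#
#   conv = GATConv(128, 64, num_heads=3, edge_drop=0.1, use_attn_dst=False,
#                  use_symmetric_norm=True, residual=True)
#   out = conv(g, feat)   # feat: [N, 128] -> out: [N, 3, 64]
#
# With use_attn_dst=False the destination term of the attention logit is dropped
# (the fn.copy_u branch above), and edge_drop removes a random fraction of edges
# from the attention softmax during training only.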
class GAT(nn.Module):
def __init__(
self,
in_feats,
n_classes,
n_hidden,
n_layers,
n_heads,
activation,
dropout=0.0,
input_drop=0.0,
attn_drop=0.0,
edge_drop=0.0,
use_attn_dst=True,
use_symmetric_norm=False,
):
super().__init__()
self.in_feats = in_feats
@@ -208,42 +258,49 @@ class GAT(nn.Module):
self.num_heads = n_heads
self.convs = nn.ModuleList()
self.norms = nn.ModuleList()
for i in range(n_layers):
in_hidden = n_heads * n_hidden if i > 0 else in_feats
out_hidden = n_hidden if i < n_layers - 1 else n_classes
# in_channels = n_heads if i > 0 else 1
num_heads = n_heads if i < n_layers - 1 else 1
out_channels = n_heads
self.convs.append(
GATConv(
in_hidden,
out_hidden,
num_heads=num_heads,
attn_drop=attn_drop,
edge_drop=edge_drop,
use_attn_dst=use_attn_dst,
use_symmetric_norm=use_symmetric_norm,
residual=True,
)
)
if i < n_layers - 1:
self.norms.append(nn.BatchNorm1d(out_channels * out_hidden))
self.bias_last = ElementWiseLinear(n_classes, weight=False, bias=True, inplace=True)
self.input_drop = nn.Dropout(input_drop)
self.dropout = nn.Dropout(dropout)
self.activation = activation
def forward(self, graph, feat):
h = feat
h = self.input_drop(h)
for i in range(self.n_layers):
conv = self.convs[i](graph, h)
h = conv
if i < self.n_layers - 1:
h = h.flatten(1)
h = self.norms[i](h)
h = self.activation(h, inplace=True)
h = self.dropout(h)
h = h.mean(1)
......
# DGL examples for ogbn-products
## Sample-based GAT
Requires DGL 0.4.3post2 or later versions.
Run `main.py` and you should directly see the result.
Accuracy over 5 runs: 0.7863197 ± 0.00072568655
## GAT (another implementation)
Requires DGL 0.5 or later versions.
For the score of `GAT`, run the following command and you should directly see the result.
```bash
python3 gat.py
```
Or, to speed up training, run with `--estimation-mode` enabled. Test accuracy is then estimated on a subset of the test set during training, and a complete evaluation is performed once training is over.
```bash
python3 gat.py --estimation-mode
```
## Results
Here are the results (5 runs for `main.py`, 10 runs for `gat.py`).
| Method | Validation Accuracy | Test Accuracy | #Parameters |
|:-------------:|:-------------------:|:---------------:|:-----------:|
| GAT (main.py) | N/A | 0.7863 ± 0.0007 | N/A |
| GAT (gat.py) | 0.9327 ± 0.0003 | 0.8126 ± 0.0018 | 1,065,127 |
#!/usr/bin/env python
# -*- coding: utf-8 -*-
import argparse
import math
import random
import time
from collections import OrderedDict
import dgl
import dgl.function as fn
import matplotlib.pyplot as plt
import numpy as np
import torch
import torch.nn.functional as F
import torch.optim as optim
from dgl.dataloading import MultiLayerFullNeighborSampler, MultiLayerNeighborSampler
from dgl.dataloading.pytorch import NodeDataLoader
from matplotlib.ticker import AutoMinorLocator, MultipleLocator
from ogb.nodeproppred import DglNodePropPredDataset, Evaluator
from torch import nn
from tqdm import tqdm
from models import GAT
from utils import BatchSampler, DataLoaderWrapper
epsilon = 1 - math.log(2)
device = None
dataset = "ogbn-products"
n_node_feats, n_edge_feats, n_classes = 0, 0, 0
def seed(seed=0):
random.seed(seed)
np.random.seed(seed)
torch.manual_seed(seed)
torch.cuda.manual_seed(seed)
torch.cuda.manual_seed_all(seed)
torch.backends.cudnn.deterministic = True
torch.backends.cudnn.benchmark = False
dgl.random.seed(seed)
def load_data(dataset):
data = DglNodePropPredDataset(name=dataset)
evaluator = Evaluator(name=dataset)
splitted_idx = data.get_idx_split()
train_idx, val_idx, test_idx = splitted_idx["train"], splitted_idx["valid"], splitted_idx["test"]
graph, labels = data[0]
graph.ndata["labels"] = labels
return graph, labels, train_idx, val_idx, test_idx, evaluator
def preprocess(graph, labels, train_idx):
global n_node_feats, n_classes
n_node_feats = graph.ndata["feat"].shape[1]
n_classes = (labels.max() + 1).item()
# graph = graph.remove_self_loop().add_self_loop()
n_node_feats = graph.ndata["feat"].shape[-1]
graph.ndata["train_labels_onehot"] = torch.zeros(graph.number_of_nodes(), n_classes)
graph.ndata["train_labels_onehot"][train_idx, labels[train_idx, 0]] = 1
graph.ndata["is_train"] = torch.zeros(graph.number_of_nodes(), dtype=torch.bool)
graph.ndata["is_train"][train_idx] = 1
graph.create_formats_()
return graph, labels
def gen_model(args):
if args.use_labels:
n_node_feats_ = n_node_feats + n_classes
else:
n_node_feats_ = n_node_feats
model = GAT(
n_node_feats_,
n_edge_feats,
n_classes,
n_layers=args.n_layers,
n_heads=args.n_heads,
n_hidden=args.n_hidden,
edge_emb=0,
activation=F.relu,
dropout=args.dropout,
input_drop=args.input_drop,
attn_drop=args.attn_dropout,
edge_drop=args.edge_drop,
        use_attn_dst=not args.no_attn_dst,
allow_zero_in_degree=True,
residual=False,
)
return model
def custom_loss_function(x, labels):
y = F.cross_entropy(x, labels[:, 0], reduction="none")
y = torch.log(epsilon + y) - math.log(epsilon)
return torch.mean(y)
def add_soft_labels(graph, soft_labels):
feat = graph.srcdata["feat"]
graph.srcdata["feat"] = torch.cat([feat, soft_labels], dim=-1)
def update_hard_labels(graph, idx=None):
if idx is None:
idx = torch.arange(graph.srcdata["is_train"].shape[0])[graph.srcdata["is_train"]]
graph.srcdata["feat"][idx, -n_classes:] = graph.srcdata["train_labels_onehot"][idx]
def train(args, model, dataloader, labels, train_idx, criterion, optimizer, evaluator):
model.train()
loss_sum, total = 0, 0
preds = torch.zeros(labels.shape[0], n_classes)
for it in range(args.n_label_iters + 1):
preds_old = preds.clone()
for input_nodes, output_nodes, subgraphs in dataloader:
subgraphs = [b.to(device) for b in subgraphs]
new_train_idx = torch.arange(len(output_nodes))
if args.use_labels:
mask = torch.rand(new_train_idx.shape) < args.mask_rate
train_labels_idx = torch.cat([new_train_idx[~mask], torch.arange(len(output_nodes), len(input_nodes))])
train_pred_idx = new_train_idx[mask]
add_soft_labels(subgraphs[0], F.softmax(preds_old[input_nodes].to(device), dim=-1))
update_hard_labels(subgraphs[0], train_labels_idx)
else:
train_pred_idx = new_train_idx
pred = model(subgraphs)
preds[output_nodes] = pred.cpu().detach()
# NOTE: This is not a complete implementation of label reuse, since it is too expensive
# to predict the nodes in validation and test set during training time.
if it == args.n_label_iters:
loss = criterion(pred[train_pred_idx], subgraphs[-1].dstdata["labels"][train_pred_idx])
optimizer.zero_grad()
loss.backward()
optimizer.step()
count = len(train_pred_idx)
loss_sum += loss.item() * count
total += count
torch.cuda.empty_cache()
return (
evaluator(preds[train_idx], labels[train_idx]),
loss_sum / total,
)
@torch.no_grad()
def evaluate(args, model, dataloader, labels, train_idx, val_idx, test_idx, criterion, evaluator):
model.eval()
    # To keep memory usage bounded, we average the logits over 'eval_times' evaluation passes.
eval_times = 1
preds_avg = torch.zeros(labels.shape[0], n_classes)
for _ in range(eval_times):
preds = torch.zeros(labels.shape[0], n_classes)
for _it in range(args.n_label_iters + 1):
preds_old = preds.clone()
for input_nodes, output_nodes, subgraphs in dataloader:
subgraphs = [b.to(device) for b in subgraphs]
if args.use_labels:
add_soft_labels(subgraphs[0], F.softmax(preds_old[input_nodes].to(device), dim=-1))
update_hard_labels(subgraphs[0])
pred = model(subgraphs, inference=True)
preds[output_nodes] = pred.cpu()
torch.cuda.empty_cache()
preds_avg += preds
preds_avg = preds_avg.to(device)
preds_avg /= eval_times
train_loss = criterion(preds_avg[train_idx], labels[train_idx]).item()
val_loss = criterion(preds_avg[val_idx], labels[val_idx]).item()
test_loss = criterion(preds_avg[test_idx], labels[test_idx]).item()
return (
evaluator(preds_avg[train_idx], labels[train_idx]),
evaluator(preds_avg[val_idx], labels[val_idx]),
evaluator(preds_avg[test_idx], labels[test_idx]),
train_loss,
val_loss,
test_loss,
)
def run(args, graph, labels, train_idx, val_idx, test_idx, evaluator, n_running):
evaluator_wrapper = lambda pred, labels: evaluator.eval(
{"y_pred": pred.argmax(dim=-1, keepdim=True), "y_true": labels}
)["acc"]
criterion = custom_loss_function
n_train_samples = train_idx.shape[0]
train_batch_size = (n_train_samples + 29) // 30
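    # (n + 29) // 30 is ceiling division: the training set is split into 30
    # near-equal mini-batches per epoch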
train_sampler = MultiLayerNeighborSampler([10 for _ in range(args.n_layers)])
train_dataloader = DataLoaderWrapper(
NodeDataLoader(
graph.cpu(),
train_idx.cpu(),
train_sampler,
batch_sampler=BatchSampler(len(train_idx), batch_size=train_batch_size, shuffle=True),
num_workers=4,
)
)
eval_batch_size = 32768
eval_sampler = MultiLayerNeighborSampler([15 for _ in range(args.n_layers)])
if args.estimation_mode:
test_idx_during_training = test_idx[torch.arange(start=0, end=len(test_idx), step=45)]
else:
test_idx_during_training = test_idx
eval_idx = torch.cat([train_idx.cpu(), val_idx.cpu(), test_idx_during_training.cpu()])
eval_dataloader = DataLoaderWrapper(
NodeDataLoader(
graph.cpu(),
eval_idx,
eval_sampler,
batch_sampler=BatchSampler(len(eval_idx), batch_size=eval_batch_size, shuffle=False),
num_workers=4,
)
)
model = gen_model(args).to(device)
optimizer = optim.AdamW(model.parameters(), lr=args.lr, weight_decay=args.wd)
lr_scheduler = optim.lr_scheduler.ReduceLROnPlateau(
optimizer, mode="max", factor=0.7, patience=20, verbose=True, min_lr=1e-4
)
best_model_state_dict = None
total_time = 0
val_score, best_val_score, final_test_score = 0, 0, 0
scores, train_scores, val_scores, test_scores = [], [], [], []
losses, train_losses, val_losses, test_losses = [], [], [], []
for epoch in range(1, args.n_epochs + 1):
tic = time.time()
score, loss = train(args, model, train_dataloader, labels, train_idx, criterion, optimizer, evaluator_wrapper)
toc = time.time()
total_time += toc - tic
if epoch == args.n_epochs or epoch % args.eval_every == 0 or epoch % args.log_every == 0:
train_score, val_score, test_score, train_loss, val_loss, test_loss = evaluate(
args,
model,
eval_dataloader,
labels,
train_idx,
val_idx,
test_idx_during_training,
criterion,
evaluator_wrapper,
)
if val_score > best_val_score:
best_val_score = val_score
final_test_score = test_score
if args.estimation_mode:
best_model_state_dict = {k: v.to("cpu") for k, v in model.state_dict().items()}
if epoch == args.n_epochs or epoch % args.log_every == 0:
print(
f"Run: {n_running}/{args.n_runs}, Epoch: {epoch}/{args.n_epochs}, Average epoch time: {total_time / epoch:.2s}\n"
f"Loss: {loss:.4f}, Score: {score:.4f}\n"
f"Train/Val/Test loss: {train_loss:.4f}/{val_loss:.4f}/{test_loss:.4f}\n"
f"Train/Val/Test/Best val/Final test score: {train_score:.4f}/{val_score:.4f}/{test_score:.4f}/{best_val_score:.4f}/{final_test_score:.4f}"
)
for l, e in zip(
[scores, train_scores, val_scores, test_scores, losses, train_losses, val_losses, test_losses],
[score, train_score, val_score, test_score, loss, train_loss, val_loss, test_loss],
):
l.append(e)
lr_scheduler.step(val_score)
if args.estimation_mode:
model.load_state_dict(best_model_state_dict)
eval_dataloader = DataLoaderWrapper(
NodeDataLoader(
graph.cpu(),
test_idx.cpu(),
eval_sampler,
batch_sampler=BatchSampler(len(test_idx), batch_size=eval_batch_size, shuffle=False),
num_workers=4,
)
)
final_test_score = evaluate(
args, model, eval_dataloader, labels, train_idx, val_idx, test_idx, criterion, evaluator_wrapper
)[2]
print("*" * 50)
print(f"Best val score: {best_val_score}, Final test score: {final_test_score}")
print("*" * 50)
if args.plot_curves:
fig = plt.figure(figsize=(24, 24))
ax = fig.gca()
ax.set_xticks(np.arange(0, args.n_epochs, 100))
ax.set_yticks(np.linspace(0, 1.0, 101))
ax.tick_params(labeltop=True, labelright=True)
for y, label in zip([train_scores, val_scores, test_scores], ["train score", "val score", "test score"]):
plt.plot(range(1, args.n_epochs + 1, args.log_every), y, label=label, linewidth=1)
ax.xaxis.set_major_locator(MultipleLocator(10))
ax.xaxis.set_minor_locator(AutoMinorLocator(1))
ax.yaxis.set_major_locator(MultipleLocator(0.01))
ax.yaxis.set_minor_locator(AutoMinorLocator(2))
plt.grid(which="major", color="red", linestyle="dotted")
plt.grid(which="minor", color="orange", linestyle="dotted")
plt.legend()
plt.tight_layout()
plt.savefig(f"gat_score_{n_running}.png")
fig = plt.figure(figsize=(24, 24))
ax = fig.gca()
ax.set_xticks(np.arange(0, args.n_epochs, 100))
ax.tick_params(labeltop=True, labelright=True)
for y, label in zip(
[losses, train_losses, val_losses, test_losses], ["loss", "train loss", "val loss", "test loss"]
):
plt.plot(range(1, args.n_epochs + 1, args.log_every), y, label=label, linewidth=1)
ax.xaxis.set_major_locator(MultipleLocator(10))
ax.xaxis.set_minor_locator(AutoMinorLocator(1))
ax.yaxis.set_major_locator(MultipleLocator(0.1))
ax.yaxis.set_minor_locator(AutoMinorLocator(5))
plt.grid(which="major", color="red", linestyle="dotted")
plt.grid(which="minor", color="orange", linestyle="dotted")
plt.legend()
plt.tight_layout()
plt.savefig(f"gat_loss_{n_running}.png")
return best_val_score, final_test_score
def count_parameters(args):
model = gen_model(args)
return sum([np.prod(p.size()) for p in model.parameters() if p.requires_grad])
def main():
global device
argparser = argparse.ArgumentParser(
"GAT implementation on ogbn-products", formatter_class=argparse.ArgumentDefaultsHelpFormatter
)
argparser.add_argument("--cpu", action="store_true", help="CPU mode. This option overrides '--gpu'.")
argparser.add_argument("--gpu", type=int, default=0, help="GPU device ID")
argparser.add_argument("--seed", type=int, default=0, help="seed")
argparser.add_argument("--n-runs", type=int, default=10, help="running times")
argparser.add_argument("--n-epochs", type=int, default=250, help="number of epochs")
argparser.add_argument(
"--use-labels", action="store_true", help="Use labels in the training set as input features."
)
argparser.add_argument("--n-label-iters", type=int, default=0, help="number of label iterations")
argparser.add_argument("--no-attn-dst", action="store_true", help="Don't use attn_dst.")
argparser.add_argument("--mask-rate", type=float, default=0.5, help="mask rate")
argparser.add_argument("--n-heads", type=int, default=4, help="number of heads")
argparser.add_argument("--lr", type=float, default=0.01, help="learning rate")
argparser.add_argument("--n-layers", type=int, default=3, help="number of layers")
argparser.add_argument("--n-hidden", type=int, default=120, help="number of hidden units")
argparser.add_argument("--dropout", type=float, default=0.5, help="dropout rate")
argparser.add_argument("--input-drop", type=float, default=0.1, help="input drop rate")
argparser.add_argument("--attn-dropout", type=float, default=0.0, help="attention drop rate")
argparser.add_argument("--edge-drop", type=float, default=0.1, help="edge drop rate")
argparser.add_argument("--wd", type=float, default=0, help="weight decay")
argparser.add_argument("--eval-every", type=int, default=2, help="log every EVAL_EVERY epochs")
argparser.add_argument(
"--estimation-mode", action="store_true", help="Estimate the score of test set for speed during training."
)
argparser.add_argument("--log-every", type=int, default=2, help="log every LOG_EVERY epochs")
argparser.add_argument("--plot-curves", action="store_true", help="plot learning curves")
args = argparser.parse_args()
if args.cpu:
device = torch.device("cpu")
else:
device = torch.device("cuda:%d" % args.gpu)
# load data & preprocess
graph, labels, train_idx, val_idx, test_idx, evaluator = load_data(dataset)
graph, labels = preprocess(graph, labels, train_idx)
labels, train_idx, val_idx, test_idx = map(lambda x: x.to(device), (labels, train_idx, val_idx, test_idx))
# run
val_scores, test_scores = [], []
for i in range(1, args.n_runs + 1):
seed(args.seed + i)
val_score, test_score = run(args, graph, labels, train_idx, val_idx, test_idx, evaluator, i)
val_scores.append(val_score)
test_scores.append(test_score)
print(args)
print(f"Runned {args.n_runs} times")
print("Val scores:", val_scores)
print("Test scores:", test_scores)
print(f"Average val score: {np.mean(val_scores)} ± {np.std(val_scores)}")
print(f"Average test score: {np.mean(test_scores)} ± {np.std(test_scores)}")
print(f"Number of params: {count_parameters(args)}")
if args.estimation_mode:
print(
"WARNING: Estimation mode is enabled. The final test score is accurate, but not accurate during training time."
)
if __name__ == "__main__":
main()
# Namespace(attn_dropout=0.0, cpu=False, dropout=0.5, edge_drop=0.1, estimation_mode=True, eval_every=2, gpu=1, input_drop=0.1, log_every=2, lr=0.01, mask_rate=0.5, n_epochs=250, n_heads=4, n_hidden=120, n_label_iters=0, n_layers=3, n_runs=10, no_attn_dst=False, plot_curves=True, seed=0, use_labels=False, wd=0)
# Runned 10 times
# Val scores: [0.9326348447473489, 0.9330163008926073, 0.9327619967957684, 0.932355110240826, 0.9330163008926073, 0.9327365663860845, 0.9329145792538718, 0.9322788190117742, 0.9321516669633548, 0.9329908704829235]
# Test scores: [0.8147550191112792, 0.8115680737936217, 0.8128332725586069, 0.8134062268564646, 0.8118784993477448, 0.8145462613150566, 0.8151228304665284, 0.8115274066904614, 0.8108545920615103, 0.8094583548530088]
# Average val score: 0.9326857055667167 ± 0.00030580001557474636
# Average test score: 0.8125950537054282 ± 0.001765025824381352
# Number of params: 1065127
# Namespace(attn_dropout=0.0, cpu=False, dropout=0.5, edge_drop=0.1, estimation_mode=True, eval_every=2, gpu=0, input_drop=0.1, log_every=2, lr=0.01, mask_rate=0.5, n_epochs=250, n_heads=4, n_hidden=120, n_label_iters=0, n_layers=3, n_runs=5, no_attn_dst=True, plot_curves=True, seed=0, use_labels=False, wd=0)
# Runned 10 times
# Val scores: [0.9332451745797625, 0.9330417313022913, 0.9328128576151362, 0.9323296798311421, 0.9324568318795616, 0.9327874272054523, 0.9327619967957684, 0.9328128576151362, 0.9322025277827226, 0.9329400096635557]
# Test scores: [0.8103399272781824, 0.8115870517750965, 0.8107294277551171, 0.8115771109276573, 0.8130244079434601, 0.8094628734200265, 0.8105681149125815, 0.809217063374258, 0.8108085026779287, 0.8151549122923549]
# Average val score: 0.932739109427053 ± 0.0003061065079170266
# Average test score: 0.8112469392356664 ± 0.0016644261188834386
# Number of params: 1060887
# update time: 2020.11.02 17:33
import torch
import torch.nn as nn
import torch.nn.functional as F
from dgl import function as fn
from dgl.ops import edge_softmax
from dgl.utils import expand_as_pair
class GATConv(nn.Module):
def __init__(
self,
node_feats,
edge_feats,
out_feats,
n_heads=1,
attn_drop=0.0,
edge_drop=0.0,
negative_slope=0.2,
residual=True,
activation=None,
use_attn_dst=True,
allow_zero_in_degree=True,
use_symmetric_norm=False,
):
super(GATConv, self).__init__()
self._n_heads = n_heads
self._in_src_feats, self._in_dst_feats = expand_as_pair(node_feats)
self._out_feats = out_feats
self._allow_zero_in_degree = allow_zero_in_degree
self._use_symmetric_norm = use_symmetric_norm
# feat fc
self.src_fc = nn.Linear(self._in_src_feats, out_feats * n_heads, bias=False)
if residual:
self.dst_fc = nn.Linear(self._in_src_feats, out_feats * n_heads)
self.bias = None
else:
self.dst_fc = None
self.bias = nn.Parameter(torch.zeros(out_feats * n_heads))
# attn fc
self.attn_src_fc = nn.Linear(self._in_src_feats, n_heads, bias=False)
if use_attn_dst:
self.attn_dst_fc = nn.Linear(self._in_src_feats, n_heads, bias=False)
else:
self.attn_dst_fc = None
if edge_feats > 0:
self.attn_edge_fc = nn.Linear(edge_feats, n_heads, bias=False)
else:
self.attn_edge_fc = None
self.attn_drop = nn.Dropout(attn_drop)
self.edge_drop = edge_drop
self.leaky_relu = nn.LeakyReLU(negative_slope, inplace=True)
self.activation = activation
self.reset_parameters()
def reset_parameters(self):
gain = nn.init.calculate_gain("relu")
nn.init.xavier_normal_(self.src_fc.weight, gain=gain)
if self.dst_fc is not None:
nn.init.xavier_normal_(self.dst_fc.weight, gain=gain)
nn.init.xavier_normal_(self.attn_src_fc.weight, gain=gain)
if self.attn_dst_fc is not None:
nn.init.xavier_normal_(self.attn_dst_fc.weight, gain=gain)
if self.attn_edge_fc is not None:
nn.init.xavier_normal_(self.attn_edge_fc.weight, gain=gain)
if self.bias is not None:
nn.init.zeros_(self.bias)
def set_allow_zero_in_degree(self, set_value):
self._allow_zero_in_degree = set_value
def forward(self, graph, feat_src, feat_edge=None):
with graph.local_scope():
if not self._allow_zero_in_degree:
assert not (graph.in_degrees() == 0).any(), "Graph contains 0-in-degree nodes, whose output would be invalid."
if graph.is_block:
feat_dst = feat_src[: graph.number_of_dst_nodes()]
else:
feat_dst = feat_src
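# With use_symmetric_norm, source features are scaled by deg^-0.5 here and
# the aggregated output is scaled by deg^0.5 further below, which together
# mimics the symmetrically normalized adjacency used in GCN.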
if self._use_symmetric_norm:
degs = graph.out_degrees().float().clamp(min=1)
norm = torch.pow(degs, -0.5)
shp = norm.shape + (1,) * (feat_src.dim() - 1)
norm = torch.reshape(norm, shp)
feat_src = feat_src * norm
feat_src_fc = self.src_fc(feat_src).view(-1, self._n_heads, self._out_feats)
feat_dst_fc = self.dst_fc(feat_dst).view(-1, self._n_heads, self._out_feats)
attn_src = self.attn_src_fc(feat_src).view(-1, self._n_heads, 1)
# NOTE: GAT paper uses "first concatenation then linear projection"
# to compute attention scores, while ours is "first projection then
# addition", the two approaches are mathematically equivalent:
# We decompose the weight vector a mentioned in the paper into
# [a_l || a_r], then
# a^T [Wh_i || Wh_j] = a_l Wh_i + a_r Wh_j
# Our implementation is much more efficient because we do not need to
# materialize [Wh_i || Wh_j] on edges, which would not be memory-efficient.
# Plus, the addition can be optimized with DGL's built-in function u_add_v,
# which further speeds up computation and reduces the memory footprint.
graph.srcdata.update({"feat_src_fc": feat_src_fc, "attn_src": attn_src})
if self.attn_dst_fc is not None:
attn_dst = self.attn_dst_fc(feat_dst).view(-1, self._n_heads, 1)
graph.dstdata.update({"attn_dst": attn_dst})
graph.apply_edges(fn.u_add_v("attn_src", "attn_dst", "attn_node"))
else:
graph.apply_edges(fn.copy_u("attn_src", "attn_node"))
e = graph.edata["attn_node"]
if feat_edge is not None:
attn_edge = self.attn_edge_fc(feat_edge).view(-1, self._n_heads, 1)
graph.edata.update({"attn_edge": attn_edge})
e += graph.edata["attn_edge"]
e = self.leaky_relu(e)
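# Edge dropout: keep a random (1 - edge_drop) fraction of edges and run the
# attention softmax only over the kept edges; dropped edges get weight 0.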
if self.training and self.edge_drop > 0:
perm = torch.randperm(graph.number_of_edges(), device=e.device)
bound = int(graph.number_of_edges() * self.edge_drop)
eids = perm[bound:]
graph.edata["a"] = torch.zeros_like(e)
graph.edata["a"][eids] = self.attn_drop(edge_softmax(graph, e[eids], eids=eids))
else:
graph.edata["a"] = self.attn_drop(edge_softmax(graph, e))
# message passing
graph.update_all(fn.u_mul_e("feat_src_fc", "a", "m"), fn.sum("m", "feat_src_fc"))
rst = graph.dstdata["feat_src_fc"]
if self._use_symmetric_norm:
degs = graph.in_degrees().float().clamp(min=1)
norm = torch.pow(degs, 0.5)
shp = norm.shape + (1,) * (feat_dst.dim())
norm = torch.reshape(norm, shp)
rst = rst * norm
# residual
if self.dst_fc is not None:
rst += feat_dst_fc
else:
rst += self.bias
# activation
if self.activation is not None:
rst = self.activation(rst, inplace=True)
return rst
class GAT(nn.Module):
def __init__(
self,
node_feats,
edge_feats,
n_classes,
n_layers,
n_heads,
n_hidden,
edge_emb,
activation,
dropout,
input_drop,
attn_drop,
edge_drop,
use_attn_dst=True,
allow_zero_in_degree=False,
residual=False,
):
super().__init__()
self.n_layers = n_layers
self.n_heads = n_heads
self.n_hidden = n_hidden
self.n_classes = n_classes
self.convs = nn.ModuleList()
self.norms = nn.ModuleList()
self.node_encoder = nn.Linear(node_feats, n_hidden)
if edge_emb > 0:
self.edge_encoder = nn.ModuleList()
else:
self.edge_encoder = None
for i in range(n_layers):
in_hidden = n_heads * n_hidden if i > 0 else node_feats
out_hidden = n_hidden
if self.edge_encoder is not None:
self.edge_encoder.append(nn.Linear(edge_feats, edge_emb))
self.convs.append(
GATConv(
in_hidden,
edge_emb,
out_hidden,
n_heads=n_heads,
attn_drop=attn_drop,
edge_drop=edge_drop,
use_attn_dst=use_attn_dst,
allow_zero_in_degree=allow_zero_in_degree,
)
)
self.norms.append(nn.BatchNorm1d(n_heads * out_hidden))
self.pred_linear = nn.Linear(n_heads * n_hidden, n_classes)
self.input_drop = nn.Dropout(input_drop)
self.dropout = nn.Dropout(dropout)
self.activation = activation
self.residual = residual
def forward(self, g, inference=False):
if not isinstance(g, list):
subgraphs = [g] * self.n_layers
else:
subgraphs = g
h = subgraphs[0].srcdata["feat"]
h = self.input_drop(h)
h_last = None
for i in range(self.n_layers):
if self.edge_encoder is not None:
efeat = subgraphs[i].edata["feat"]
efeat_emb = self.edge_encoder[i](efeat)
efeat_emb = F.relu(efeat_emb, inplace=True)
else:
efeat_emb = None
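# Each GATConv outputs (N, n_heads, n_hidden); flatten(1, -1) merges the
# head dimension so BatchNorm and the next layer see n_heads * n_hidden features.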
h = self.convs[i](subgraphs[i], h, efeat_emb).flatten(1, -1)
if self.residual and h_last is not None:
h += h_last[: h.shape[0], :]
h_last = h
h = self.norms[i](h)
h = self.activation(h, inplace=True)
h = self.dropout(h)
if inference:
torch.cuda.empty_cache()
h = self.pred_linear(h)
return h
class MLP(nn.Module):
def __init__(
self, in_feats, n_classes, n_layers, n_hidden, activation, dropout=0.0, input_drop=0.0, residual=False,
):
super().__init__()
self.n_layers = n_layers
self.n_hidden = n_hidden
self.n_classes = n_classes
self.linears = nn.ModuleList()
self.norms = nn.ModuleList()
for i in range(n_layers):
in_hidden = n_hidden if i > 0 else in_feats
out_hidden = n_hidden if i < n_layers - 1 else n_classes
self.linears.append(nn.Linear(in_hidden, out_hidden))
if i < n_layers - 1:
self.norms.append(nn.BatchNorm1d(out_hidden))
self.activation = activation
self.input_drop = nn.Dropout(input_drop)
self.dropout = nn.Dropout(dropout)
self.residual = residual
def forward(self, h):
h = self.input_drop(h)
h_last = None
for i in range(self.n_layers):
h = self.linears[i](h)
if self.residual and 0 < i < self.n_layers - 1:
h += h_last
h_last = h
if i < self.n_layers - 1:
h = self.norms[i](h)
h = self.activation(h, inplace=True)
h = self.dropout(h)
return h
import torch
class DataLoaderWrapper(object):
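# Wraps the iterator of a NodeDataLoader so that it can be reused across
# epochs: the infinite BatchSampler below yields a None sentinel at each
# epoch boundary, which makes the underlying loader raise; that exception is
# translated into StopIteration here, presumably to avoid re-spawning
# dataloader workers every epoch.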
def __init__(self, dataloader):
self.iter = iter(dataloader)
def __iter__(self):
return self
def __next__(self):
try:
return next(self.iter)
except Exception:
raise StopIteration() from None
class BatchSampler(object):
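# An infinite batch sampler: it cycles over the node permutation forever and
# yields None after each full pass to signal an epoch boundary to
# DataLoaderWrapper above.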
def __init__(self, n, batch_size, shuffle=False):
self.n = n
self.batch_size = batch_size
self.shuffle = shuffle
def __iter__(self):
if not self.shuffle:
perm = torch.arange(start=0, end=self.n)
while True:
if self.shuffle:
perm = torch.randperm(self.n)
shuf = perm.split(self.batch_size)
for shuf_batch in shuf:
yield shuf_batch
yield None
# DGL examples for ogbn-products
Requires DGL 0.5 or later versions.
For the score of `MLP`, run the following command and you should directly see the result.
```bash
python3 mlp.py --eval-last
```
## Results
Here are the results over 10 runs.
| Method | Validation Accuracy | Test Accuracy | #Parameters |
|:------:|:-------------------:|:---------------:|:-----------:|
| MLP | 0.7841 ± 0.0014 | 0.6320 ± 0.0013 | 535,727 |
#!/usr/bin/env python
# -*- coding: utf-8 -*-
import argparse
import math
import random
import time
from collections import OrderedDict
import dgl.function as fn
import matplotlib.pyplot as plt
import numpy as np
import torch
import torch.nn.functional as F
import torch.optim as optim
from dgl.dataloading import MultiLayerFullNeighborSampler, MultiLayerNeighborSampler
from dgl.dataloading.pytorch import NodeDataLoader
from matplotlib.ticker import AutoMinorLocator, MultipleLocator
from ogb.nodeproppred import DglNodePropPredDataset, Evaluator
from torch import nn
from tqdm import tqdm
from models import MLP
from utils import BatchSampler, DataLoaderWrapper
epsilon = 1 - math.log(2)
device = None
dataset = "ogbn-products"
n_node_feats, n_edge_feats, n_classes = 0, 0, 0
def seed(seed=0):
random.seed(seed)
np.random.seed(seed)
torch.manual_seed(seed)
torch.cuda.manual_seed(seed)
torch.cuda.manual_seed_all(seed)
torch.backends.cudnn.deterministic = True
torch.backends.cudnn.benchmark = False
def load_data(dataset):
data = DglNodePropPredDataset(name=dataset)
evaluator = Evaluator(name=dataset)
splitted_idx = data.get_idx_split()
train_idx, val_idx, test_idx = splitted_idx["train"], splitted_idx["valid"], splitted_idx["test"]
graph, labels = data[0]
graph.ndata["labels"] = labels
return graph, labels, train_idx, val_idx, test_idx, evaluator
def preprocess(graph, labels):
global n_node_feats, n_classes
n_node_feats = graph.ndata["feat"].shape[1]
n_classes = (labels.max() + 1).item()
# graph = graph.remove_self_loop().add_self_loop()
n_node_feats = graph.ndata["feat"].shape[-1]
return graph, labels
def gen_model(args):
model = MLP(
n_node_feats,
n_classes,
n_layers=args.n_layers,
n_hidden=args.n_hidden,
activation=F.relu,
dropout=args.dropout,
input_drop=args.input_drop,
residual=False,
)
return model
def custom_loss_function(x, labels):
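# Log-transformed cross-entropy: y = log(epsilon + CE) - log(epsilon) with
# epsilon = 1 - log(2). Relative to plain CE this amplifies gradients on
# already well-classified samples and damps them on hard ones (cf. the
# Bag of Tricks paper referenced in the READMEs).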
y = F.cross_entropy(x, labels[:, 0], reduction="none")
y = torch.log(epsilon + y) - math.log(epsilon)
return torch.mean(y)
def train(args, model, dataloader, labels, train_idx, criterion, optimizer, evaluator):
model.train()
loss_sum, total = 0, 0
preds = torch.zeros(labels.shape[0], n_classes)
for _input_nodes, output_nodes, subgraphs in dataloader:
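# The sampler uses a fan-out of 0, so subgraphs[0] contains only the batch
# nodes themselves; this is plain minibatch training of an MLP on node features.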
subgraphs = [b.to(device) for b in subgraphs]
new_train_idx = list(range(len(output_nodes)))
pred = model(subgraphs[0].srcdata["feat"])
preds[output_nodes] = pred.cpu().detach()
loss = criterion(pred[new_train_idx], labels[output_nodes][new_train_idx])
optimizer.zero_grad()
loss.backward()
optimizer.step()
count = len(new_train_idx)
loss_sum += loss.item() * count
total += count
return (
loss_sum / total,
evaluator(preds[train_idx], labels[train_idx]),
)
@torch.no_grad()
def evaluate(args, model, dataloader, labels, train_idx, val_idx, test_idx, criterion, evaluator):
model.eval()
preds = torch.zeros(labels.shape[0], n_classes, device=device)
eval_times = 1  # Due to memory constraints, we average the logits over 'eval_times' forward passes.
for _ in range(eval_times):
for _input_nodes, output_nodes, subgraphs in dataloader:
subgraphs = [b.to(device) for b in subgraphs]
pred = model(subgraphs[0].srcdata["feat"])
preds[output_nodes] = pred
preds /= eval_times
train_loss = criterion(preds[train_idx], labels[train_idx]).item()
val_loss = criterion(preds[val_idx], labels[val_idx]).item()
test_loss = criterion(preds[test_idx], labels[test_idx]).item()
return (
evaluator(preds[train_idx], labels[train_idx]),
evaluator(preds[val_idx], labels[val_idx]),
evaluator(preds[test_idx], labels[test_idx]),
train_loss,
val_loss,
test_loss,
)
def run(args, graph, labels, train_idx, val_idx, test_idx, evaluator, n_running):
evaluator_wrapper = lambda pred, labels: evaluator.eval(
{"y_pred": pred.argmax(dim=-1, keepdim=True), "y_true": labels}
)["acc"]
criterion = custom_loss_function
train_batch_size = 4096
train_sampler = MultiLayerNeighborSampler([0 for _ in range(args.n_layers)])  # do not sample neighbors
train_dataloader = DataLoaderWrapper(
NodeDataLoader(
graph.cpu(),
train_idx.cpu(),
train_sampler,
batch_sampler=BatchSampler(len(train_idx), batch_size=train_batch_size, shuffle=True),
num_workers=4,
)
)
eval_batch_size = 4096
eval_sampler = MultiLayerNeighborSampler([0 for _ in range(args.n_layers)])  # do not sample neighbors
if args.eval_last:
eval_idx = torch.cat([train_idx.cpu(), val_idx.cpu()])
else:
eval_idx = torch.cat([train_idx.cpu(), val_idx.cpu(), test_idx.cpu()])
eval_dataloader = DataLoaderWrapper(
NodeDataLoader(
graph.cpu(),
eval_idx,
eval_sampler,
batch_sampler=BatchSampler(len(eval_idx), batch_size=eval_batch_size, shuffle=False),
num_workers=4,
)
)
model = gen_model(args).to(device)
optimizer = optim.AdamW(model.parameters(), lr=args.lr, weight_decay=args.wd)
lr_scheduler = optim.lr_scheduler.ReduceLROnPlateau(
optimizer, mode="max", factor=0.7, patience=20, verbose=True, min_lr=1e-4
)
best_model_state_dict = None
total_time = 0
val_score, best_val_score, final_test_score = 0, 0, 0
scores, train_scores, val_scores, test_scores = [], [], [], []
losses, train_losses, val_losses, test_losses = [], [], [], []
for epoch in range(1, args.n_epochs + 1):
tic = time.time()
loss, score = train(args, model, train_dataloader, labels, train_idx, criterion, optimizer, evaluator_wrapper)
toc = time.time()
total_time += toc - tic
if epoch % args.eval_every == 0 or epoch % args.log_every == 0:
train_score, val_score, test_score, train_loss, val_loss, test_loss = evaluate(
args, model, eval_dataloader, labels, train_idx, val_idx, test_idx, criterion, evaluator_wrapper
)
if val_score > best_val_score:
best_val_score = val_score
final_test_score = test_score
if args.eval_last:
best_model_state_dict = {k: v.to("cpu") for k, v in model.state_dict().items()}
best_model_state_dict = OrderedDict(best_model_state_dict)
if epoch % args.log_every == 0:
print(
f"Run: {n_running}/{args.n_runs}, Epoch: {epoch}/{args.n_epochs}, Average epoch time: {total_time / epoch}"
)
print(
f"Loss: {loss:.4f}, Score: {score:.4f}\n"
f"Train/Val/Test loss: {train_loss:.4f}/{val_loss:.4f}/{test_loss:.4f}\n"
f"Train/Val/Test/Best val/Final test score: {train_score:.4f}/{val_score:.4f}/{test_score:.4f}/{best_val_score:.4f}/{final_test_score:.4f}"
)
for l, e in zip(
[scores, train_scores, val_scores, test_scores, losses, train_losses, val_losses, test_losses],
[score, train_score, val_score, test_score, loss, train_loss, val_loss, test_loss],
):
l.append(e)
lr_scheduler.step(val_score)
if args.eval_last:
model.load_state_dict(best_model_state_dict)
eval_dataloader = DataLoaderWrapper(
NodeDataLoader(
graph.cpu(),
test_idx.cpu(),
eval_sampler,
batch_sampler=BatchSampler(len(test_idx), batch_size=eval_batch_size, shuffle=False),
num_workers=4,
)
)
final_test_score = evaluate(
args, model, eval_dataloader, labels, train_idx, val_idx, test_idx, criterion, evaluator_wrapper
)[2]
print("*" * 50)
print(f"Average epoch time: {total_time / args.n_epochs}, Test score: {final_test_score}")
if args.plot_curves:
fig = plt.figure(figsize=(24, 24))
ax = fig.gca()
ax.set_xticks(np.arange(0, args.n_epochs, 100))
ax.set_yticks(np.linspace(0, 1.0, 101))
ax.tick_params(labeltop=True, labelright=True)
for y, label in zip([train_scores, val_scores, test_scores], ["train score", "val score", "test score"]):
plt.plot(range(1, args.n_epochs + 1, args.log_every), y, label=label, linewidth=1)
ax.xaxis.set_major_locator(MultipleLocator(20))
ax.xaxis.set_minor_locator(AutoMinorLocator(1))
ax.yaxis.set_major_locator(MultipleLocator(0.01))
ax.yaxis.set_minor_locator(AutoMinorLocator(2))
plt.grid(which="major", color="red", linestyle="dotted")
plt.grid(which="minor", color="orange", linestyle="dotted")
plt.legend()
plt.tight_layout()
plt.savefig(f"gat_score_{n_running}.png")
fig = plt.figure(figsize=(24, 24))
ax = fig.gca()
ax.set_xticks(np.arange(0, args.n_epochs, 100))
ax.tick_params(labeltop=True, labelright=True)
for y, label in zip(
[losses, train_losses, val_losses, test_losses], ["loss", "train loss", "val loss", "test loss"]
):
plt.plot(range(1, args.n_epochs + 1, args.log_every), y, label=label, linewidth=1)
ax.xaxis.set_major_locator(MultipleLocator(20))
ax.xaxis.set_minor_locator(AutoMinorLocator(1))
ax.yaxis.set_major_locator(MultipleLocator(0.1))
ax.yaxis.set_minor_locator(AutoMinorLocator(5))
plt.grid(which="major", color="red", linestyle="dotted")
plt.grid(which="minor", color="orange", linestyle="dotted")
plt.legend()
plt.tight_layout()
plt.savefig(f"gat_loss_{n_running}.png")
return best_val_score, final_test_score
def count_parameters(args):
model = gen_model(args)
return sum([np.prod(p.size()) for p in model.parameters() if p.requires_grad])
def main():
global device
argparser = argparse.ArgumentParser("GAT on OGBN-Proteins", formatter_class=argparse.ArgumentDefaultsHelpFormatter)
argparser.add_argument("--cpu", action="store_true", help="CPU mode. This option overrides '--gpu'.")
argparser.add_argument("--gpu", type=int, default=0, help="GPU device ID.")
argparser.add_argument("--seed", type=int, help="seed", default=0)
argparser.add_argument("--n-runs", type=int, default=10)
argparser.add_argument("--n-epochs", type=int, default=500)
argparser.add_argument("--lr", type=float, default=0.01)
argparser.add_argument("--n-layers", type=int, default=4)
argparser.add_argument("--n-hidden", type=int, default=480)
argparser.add_argument("--dropout", type=float, default=0.2)
argparser.add_argument("--input-drop", type=float, default=0)
argparser.add_argument("--wd", type=float, default=0)
argparser.add_argument("--estimation-mode", action="store_true", help="Estimate the score of test set for speed.")
argparser.add_argument("--eval-last", action="store_true", help="Evaluate the score of test set at last.")
argparser.add_argument("--eval-every", type=int, default=1)
argparser.add_argument("--log-every", type=int, default=1)
argparser.add_argument("--plot-curves", action="store_true")
args = argparser.parse_args()
if args.cpu:
device = torch.device("cpu")
else:
device = torch.device("cuda:%d" % args.gpu)
if args.estimation_mode:
print("WARNING: Estimation mode is enabled. The test score is not accurate.")
seed(args.seed)
graph, labels, train_idx, val_idx, test_idx, evaluator = load_data(dataset)
graph, labels = preprocess(graph, labels)
graph.create_formats_()
# graph = graph.to(device)
labels = labels.to(device)
train_idx = train_idx.to(device)
val_idx = val_idx.to(device)
test_idx = test_idx.to(device)
if args.estimation_mode:
test_idx = test_idx[torch.arange(start=0, end=len(test_idx), step=50)]
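# Keep only every 50th test node, so test evaluation during training is
# cheap; the reported test score is correspondingly only an estimate.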
val_scores, test_scores = [], []
for i in range(1, args.n_runs + 1):
val_score, test_score = run(args, graph, labels, train_idx, val_idx, test_idx, evaluator, i)
val_scores.append(val_score)
test_scores.append(test_score)
print(args)
print(f"Runned {args.n_runs} times")
print("Val scores:", val_scores)
print("Test scores:", test_scores)
print(f"Average val score: {np.mean(val_scores)} ± {np.std(val_scores)}")
print(f"Average test score: {np.mean(test_scores)} ± {np.std(test_scores)}")
print(f"Number of params: {count_parameters(args)}")
if args.estimation_mode:
print("WARNING: Estimation mode is enabled. The test score is not accurate.")
if __name__ == "__main__":
main()
# Namespace(cpu=False, dropout=0.2, estimation_mode=False, eval_every=1, eval_last=True, gpu=2, input_drop=0, log_every=1, lr=0.01, n_epochs=500, n_hidden=480, n_layers=4, n_runs=10, plot_curves=True, seed=0, wd=0)
# Runned 10 times
# Val scores: [0.7846298603870508, 0.7811713246700405, 0.7828751621188618, 0.7839941001449533, 0.7843501258805279, 0.7841466826030568, 0.7846298603870508, 0.7865880019327112, 0.7832057574447524, 0.7851384685807289]
# Test scores: [0.6318660190656417, 0.6304137516261193, 0.6329961126767946, 0.6312885462007662, 0.6340624944929965, 0.6301507710256831, 0.6314534738969161, 0.6334637843631373, 0.6312465235275007, 0.6329857199726536]
# Average val score: 0.7840729344149735 ± 0.0013702460721628086
# Average test score: 0.6319927196848208 ± 0.001252448369121226
# Number of params: 535727
import torch
import torch.nn as nn
import torch.nn.functional as F
class MLP(nn.Module):
def __init__(
self, in_feats, n_classes, n_layers, n_hidden, activation, dropout=0.0, input_drop=0.0, residual=False,
):
super().__init__()
self.n_layers = n_layers
self.n_hidden = n_hidden
self.n_classes = n_classes
self.linears = nn.ModuleList()
self.norms = nn.ModuleList()
for i in range(n_layers):
in_hidden = n_hidden if i > 0 else in_feats
out_hidden = n_hidden if i < n_layers - 1 else n_classes
self.linears.append(nn.Linear(in_hidden, out_hidden))
if i < n_layers - 1:
self.norms.append(nn.BatchNorm1d(out_hidden))
self.activation = activation
self.input_drop = nn.Dropout(input_drop)
self.dropout = nn.Dropout(dropout)
self.residual = residual
def forward(self, h):
h = self.input_drop(h)
h_last = None
for i in range(self.n_layers):
h = self.linears[i](h)
if self.residual and 0 < i < self.n_layers - 1:
h += h_last
h_last = h
if i < self.n_layers - 1:
h = self.norms[i](h)
h = self.activation(h, inplace=True)
h = self.dropout(h)
return h
import torch
class DataLoaderWrapper(object):
def __init__(self, dataloader):
self.iter = iter(dataloader)
def __iter__(self):
return self
def __next__(self):
try:
return next(self.iter)
except Exception:
raise StopIteration() from None
class BatchSampler(object):
def __init__(self, n, batch_size, shuffle=False):
self.n = n
self.batch_size = batch_size
self.shuffle = shuffle
def __iter__(self):
while True:
if self.shuffle:
perm = torch.randperm(self.n)
else:
perm = torch.arange(start=0, end=self.n)
shuf = perm.split(self.batch_size)
for shuf_batch in shuf:
yield shuf_batch
yield None
# DGL examples for ogbn-proteins
## GAT
DGL implementation of GAT for [ogbn-proteins](https://ogb.stanford.edu/docs/nodeprop/). Using some of the techniques from *Bag of Tricks for Node Classification with Graph Neural Networks* ([https://arxiv.org/abs/2103.13355](https://arxiv.org/abs/2103.13355)).
Requires DGL 0.5 or later versions.
### Usage
For the score of `GAT`, run `gat.py` and you should directly see the result.
```bash
python3 gat.py
```
For the score of `GAT+labels`, run `gat.py` with `--use-labels` enabled and you should directly see the result.
```bash
python3 gat.py --use-labels
```
### Results
Here are the results over 10 runs.
| Method | Validation ROC-AUC | Test ROC-AUC | #Parameters |
|:----------:|:------------------:|:---------------:|:-----------:|
| GAT | 0.9276 ± 0.0007 | 0.8747 ± 0.0016 | 2,475,232 |
| GAT+labels | 0.9280 ± 0.0008 | 0.8765 ± 0.0008 | 2,484,192 |
## MWE-GCN and MWE-DGCN
### Models
[MWE-GCN and MWE-DGCN](https://cims.nyu.edu/~chenzh/files/GCN_with_edge_weights.pdf) are GCN models designed for graphs whose edges contain multi-dimensional edge weights that indicate the strengths of the relations represented by the edges.
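A minimal dense-matrix sketch of one such layer (illustrative only: the function name and the dense form are ours; the actual `MWEConv` in `models.py` operates on a DGL graph with per-channel edge features and additionally applies a bias and a final linear projection):

```python
import torch

def mwe_layer_sketch(adjs, h, weight, aggr_mode="sum"):
    # adjs: list of C dense (N, N) edge-weight matrices, one per channel
    # h: (N, in_feats) node states; weight: (in_feats, out_feats, C)
    outs = []
    for c, A_c in enumerate(adjs):
        # transform node states for channel c, then aggregate them along
        # that channel's weighted edges
        outs.append(torch.relu(A_c @ (h @ weight[:, :, c])))
    if aggr_mode == "sum":
        return torch.stack(outs, dim=1).sum(dim=1)  # (N, out_feats)
    return torch.cat(outs, dim=1)  # (N, C * out_feats)
```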
### Dependencies
- DGL 0.5.2
- PyTorch 1.4.0
- OGB 1.2.0
- Tensorboard 2.1.1
### Usage
To use MWE-GCN:
```python
......
```
#!/usr/bin/env python
# -*- coding: utf-8 -*-
import argparse
import os
import random
import sys
import time
import dgl
import dgl.function as fn
import matplotlib.pyplot as plt
import numpy as np
import torch
import torch.nn.functional as F
import torch.optim as optim
from dgl.dataloading import MultiLayerFullNeighborSampler, MultiLayerNeighborSampler
from dgl.dataloading.pytorch import NodeDataLoader
from matplotlib.ticker import AutoMinorLocator, MultipleLocator
from ogb.nodeproppred import DglNodePropPredDataset, Evaluator
from torch import nn
from models import GAT
from utils import BatchSampler, DataLoaderWrapper
device = None
dataset = "ogbn-proteins"
n_node_feats, n_edge_feats, n_classes = 0, 8, 112
def seed(seed=0):
random.seed(seed)
np.random.seed(seed)
torch.manual_seed(seed)
torch.cuda.manual_seed(seed)
torch.cuda.manual_seed_all(seed)
torch.backends.cudnn.deterministic = True
torch.backends.cudnn.benchmark = False
dgl.random.seed(seed)
def load_data(dataset):
data = DglNodePropPredDataset(name=dataset)
evaluator = Evaluator(name=dataset)
splitted_idx = data.get_idx_split()
train_idx, val_idx, test_idx = splitted_idx["train"], splitted_idx["valid"], splitted_idx["test"]
graph, labels = data[0]
graph.ndata["labels"] = labels
return graph, labels, train_idx, val_idx, test_idx, evaluator
def preprocess(graph, labels, train_idx):
global n_node_feats
# The sum of the weights of adjacent edges is used as node features.
graph.update_all(fn.copy_e("feat", "feat_copy"), fn.sum("feat_copy", "feat"))
n_node_feats = graph.ndata["feat"].shape[-1]
# Only the labels in the training set are used as features, while others are filled with zeros.
graph.ndata["train_labels_onehot"] = torch.zeros(graph.number_of_nodes(), n_classes)
graph.ndata["train_labels_onehot"][train_idx, labels[train_idx, 0]] = 1
graph.ndata["deg"] = graph.out_degrees().float().clamp(min=1)
graph.create_formats_()
return graph, labels
def gen_model(args):
if args.use_labels:
n_node_feats_ = n_node_feats + n_classes
else:
n_node_feats_ = n_node_feats
model = GAT(
n_node_feats_,
n_edge_feats,
n_classes,
n_layers=args.n_layers,
n_heads=args.n_heads,
n_hidden=args.n_hidden,
edge_emb=16,
activation=F.relu,
dropout=args.dropout,
input_drop=args.input_drop,
attn_drop=args.attn_drop,
edge_drop=args.edge_drop,
use_attn_dst=not args.no_attn_dst,
)
return model
def add_labels(graph, idx):
feat = graph.srcdata["feat"]
train_labels_onehot = torch.zeros([feat.shape[0], n_classes], device=device)
train_labels_onehot[idx] = graph.srcdata["train_labels_onehot"][idx]
graph.srcdata["feat"] = torch.cat([feat, train_labels_onehot], dim=-1)
def train(args, model, dataloader, _labels, _train_idx, criterion, optimizer, _evaluator):
model.train()
loss_sum, total = 0, 0
for input_nodes, output_nodes, subgraphs in dataloader:
subgraphs = [b.to(device) for b in subgraphs]
new_train_idx = torch.arange(len(output_nodes), device=device)
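# Label trick: feed the one-hot training labels of sampled nodes outside
# the output batch (indices len(output_nodes)..len(input_nodes)) as input
# features, while predicting on the output nodes themselves.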
if args.use_labels:
train_labels_idx = torch.arange(len(output_nodes), len(input_nodes), device=device)
train_pred_idx = new_train_idx
add_labels(subgraphs[0], train_labels_idx)
else:
train_pred_idx = new_train_idx
pred = model(subgraphs)
loss = criterion(pred[train_pred_idx], subgraphs[-1].dstdata["labels"][train_pred_idx].float())
optimizer.zero_grad()
loss.backward()
optimizer.step()
count = len(train_pred_idx)
loss_sum += loss.item() * count
total += count
# torch.cuda.empty_cache()
return loss_sum / total
@torch.no_grad()
def evaluate(args, model, dataloader, labels, train_idx, val_idx, test_idx, criterion, evaluator):
model.eval()
preds = torch.zeros(labels.shape).to(device)
# Due to memory constraints, we use neighbor sampling for inference and average the predictions over 'eval_times' passes.
eval_times = 1
for _ in range(eval_times):
for input_nodes, output_nodes, subgraphs in dataloader:
subgraphs = [b.to(device) for b in subgraphs]
new_train_idx = list(range(len(input_nodes)))
if args.use_labels:
add_labels(subgraphs[0], new_train_idx)
pred = model(subgraphs)
preds[output_nodes] += pred
# torch.cuda.empty_cache()
preds /= eval_times
train_loss = criterion(preds[train_idx], labels[train_idx].float()).item()
val_loss = criterion(preds[val_idx], labels[val_idx].float()).item()
test_loss = criterion(preds[test_idx], labels[test_idx].float()).item()
return (
evaluator(preds[train_idx], labels[train_idx]),
evaluator(preds[val_idx], labels[val_idx]),
evaluator(preds[test_idx], labels[test_idx]),
train_loss,
val_loss,
test_loss,
preds,
)
def run(args, graph, labels, train_idx, val_idx, test_idx, evaluator, n_running):
evaluator_wrapper = lambda pred, labels: evaluator.eval({"y_pred": pred, "y_true": labels})["rocauc"]
train_batch_size = (len(train_idx) + 9) // 10
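# Ceil-divide the training set into 10 minibatches per epoch.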
# batch_size = len(train_idx)
train_sampler = MultiLayerNeighborSampler([32 for _ in range(args.n_layers)])
# sampler = MultiLayerFullNeighborSampler(args.n_layers)
train_dataloader = DataLoaderWrapper(
NodeDataLoader(
graph.cpu(),
train_idx.cpu(),
train_sampler,
batch_sampler=BatchSampler(len(train_idx), batch_size=train_batch_size),
num_workers=10,
)
)
eval_sampler = MultiLayerNeighborSampler([100 for _ in range(args.n_layers)])
# sampler = MultiLayerFullNeighborSampler(args.n_layers)
eval_dataloader = DataLoaderWrapper(
NodeDataLoader(
graph.cpu(),
torch.cat([train_idx.cpu(), val_idx.cpu(), test_idx.cpu()]),
eval_sampler,
batch_sampler=BatchSampler(graph.number_of_nodes(), batch_size=65536),
num_workers=10,
)
)
criterion = nn.BCEWithLogitsLoss()
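# ogbn-proteins is multi-label binary classification over 112 protein
# functions, hence binary cross-entropy with logits instead of softmax
# cross-entropy.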
model = gen_model(args).to(device)
optimizer = optim.AdamW(model.parameters(), lr=args.lr, weight_decay=args.wd)
lr_scheduler = optim.lr_scheduler.ReduceLROnPlateau(optimizer, mode="max", factor=0.75, patience=50, verbose=True)
total_time = 0
val_score, best_val_score, final_test_score = 0, 0, 0
train_scores, val_scores, test_scores = [], [], []
losses, train_losses, val_losses, test_losses = [], [], [], []
final_pred = None
for epoch in range(1, args.n_epochs + 1):
tic = time.time()
loss = train(args, model, train_dataloader, labels, train_idx, criterion, optimizer, evaluator_wrapper)
toc = time.time()
total_time += toc - tic
if epoch == args.n_epochs or epoch % args.eval_every == 0 or epoch % args.log_every == 0:
train_score, val_score, test_score, train_loss, val_loss, test_loss, pred = evaluate(
args, model, eval_dataloader, labels, train_idx, val_idx, test_idx, criterion, evaluator_wrapper
)
if val_score > best_val_score:
best_val_score = val_score
final_test_score = test_score
final_pred = pred
if epoch % args.log_every == 0:
print(
f"Run: {n_running}/{args.n_runs}, Epoch: {epoch}/{args.n_epochs}, Average epoch time: {total_time / epoch:.2f}s"
)
print(
f"Loss: {loss:.4f}\n"
f"Train/Val/Test loss: {train_loss:.4f}/{val_loss:.4f}/{test_loss:.4f}\n"
f"Train/Val/Test/Best val/Final test score: {train_score:.4f}/{val_score:.4f}/{test_score:.4f}/{best_val_score:.4f}/{final_test_score:.4f}"
)
for l, e in zip(
[train_scores, val_scores, test_scores, losses, train_losses, val_losses, test_losses],
[train_score, val_score, test_score, loss, train_loss, val_loss, test_loss],
):
l.append(e)
lr_scheduler.step(val_score)
print("*" * 50)
print(f"Best val score: {best_val_score}, Final test score: {final_test_score}")
print("*" * 50)
if args.plot:
fig = plt.figure(figsize=(24, 24))
ax = fig.gca()
ax.set_xticks(np.arange(0, args.n_epochs, 100))
ax.set_yticks(np.linspace(0, 1.0, 101))
ax.tick_params(labeltop=True, labelright=True)
for y, label in zip([train_scores, val_scores, test_scores], ["train score", "val score", "test score"]):
plt.plot(range(1, args.n_epochs + 1, args.log_every), y, label=label, linewidth=1)
ax.xaxis.set_major_locator(MultipleLocator(100))
ax.xaxis.set_minor_locator(AutoMinorLocator(1))
ax.yaxis.set_major_locator(MultipleLocator(0.01))
ax.yaxis.set_minor_locator(AutoMinorLocator(2))
plt.grid(which="major", color="red", linestyle="dotted")
plt.grid(which="minor", color="orange", linestyle="dotted")
plt.legend()
plt.tight_layout()
plt.savefig(f"gat_score_{n_running}.png")
fig = plt.figure(figsize=(24, 24))
ax = fig.gca()
ax.set_xticks(np.arange(0, args.n_epochs, 100))
ax.tick_params(labeltop=True, labelright=True)
for y, label in zip(
[losses, train_losses, val_losses, test_losses], ["loss", "train loss", "val loss", "test loss"]
):
plt.plot(range(1, args.n_epochs + 1, args.log_every), y, label=label, linewidth=1)
ax.xaxis.set_major_locator(MultipleLocator(100))
ax.xaxis.set_minor_locator(AutoMinorLocator(1))
ax.yaxis.set_major_locator(MultipleLocator(0.1))
ax.yaxis.set_minor_locator(AutoMinorLocator(5))
plt.grid(which="major", color="red", linestyle="dotted")
plt.grid(which="minor", color="orange", linestyle="dotted")
plt.legend()
plt.tight_layout()
plt.savefig(f"gat_loss_{n_running}.png")
if args.save_pred:
os.makedirs("./output", exist_ok=True)
torch.save(F.softmax(final_pred, dim=1), f"./output/{n_running}.pt")
return best_val_score, final_test_score
def count_parameters(args):
model = gen_model(args)
return sum([np.prod(p.size()) for p in model.parameters() if p.requires_grad])
def main():
global device
argparser = argparse.ArgumentParser(
"GAT implementation on ogbn-proteins", formatter_class=argparse.ArgumentDefaultsHelpFormatter
)
argparser.add_argument("--cpu", action="store_true", help="CPU mode. This option overrides '--gpu'.")
argparser.add_argument("--gpu", type=int, default=0, help="GPU device ID")
argparser.add_argument("--seed", type=int, default=0, help="random seed")
argparser.add_argument("--n-runs", type=int, default=10, help="running times")
argparser.add_argument("--n-epochs", type=int, default=1200, help="number of epochs")
argparser.add_argument(
"--use-labels", action="store_true", help="Use labels in the training set as input features."
)
argparser.add_argument("--no-attn-dst", action="store_true", help="Don't use attn_dst.")
argparser.add_argument("--n-heads", type=int, default=6, help="number of heads")
argparser.add_argument("--lr", type=float, default=0.01, help="learning rate")
argparser.add_argument("--n-layers", type=int, default=6, help="number of layers")
argparser.add_argument("--n-hidden", type=int, default=80, help="number of hidden units")
argparser.add_argument("--dropout", type=float, default=0.25, help="dropout rate")
argparser.add_argument("--input-drop", type=float, default=0.1, help="input drop rate")
argparser.add_argument("--attn-drop", type=float, default=0.0, help="attention dropout rate")
argparser.add_argument("--edge-drop", type=float, default=0.1, help="edge drop rate")
argparser.add_argument("--wd", type=float, default=0, help="weight decay")
argparser.add_argument("--eval-every", type=int, default=5, help="evaluate every EVAL_EVERY epochs")
argparser.add_argument("--log-every", type=int, default=5, help="log every LOG_EVERY epochs")
argparser.add_argument("--plot", action="store_true", help="plot learning curves")
argparser.add_argument("--save-pred", action="store_true", help="save final predictions")
args = argparser.parse_args()
if args.cpu:
device = torch.device("cpu")
else:
device = torch.device(f"cuda:{args.gpu}")
# load data & preprocess
print("Loading data")
graph, labels, train_idx, val_idx, test_idx, evaluator = load_data(dataset)
print("Preprocessing")
graph, labels = preprocess(graph, labels, train_idx)
labels, train_idx, val_idx, test_idx = map(lambda x: x.to(device), (labels, train_idx, val_idx, test_idx))
# run
val_scores, test_scores = [], []
for i in range(args.n_runs):
print("Running", i)
seed(args.seed + i)
val_score, test_score = run(args, graph, labels, train_idx, val_idx, test_idx, evaluator, i + 1)
val_scores.append(val_score)
test_scores.append(test_score)
print(" ".join(sys.argv))
print(args)
print(f"Runned {args.n_runs} times")
print("Val scores:", val_scores)
print("Test scores:", test_scores)
print(f"Average val score: {np.mean(val_scores)} ± {np.std(val_scores)}")
print(f"Average test score: {np.mean(test_scores)} ± {np.std(test_scores)}")
print(f"Number of params: {count_parameters(args)}")
if __name__ == "__main__":
main()
# Namespace(attn_drop=0.0, cpu=False, dropout=0.25, edge_drop=0.1, eval_every=5, gpu=6, input_drop=0.1, log_every=5, lr=0.01, n_epochs=1200, n_heads=6, n_hidden=80, n_layers=6, n_runs=10, no_attn_dst=False, plot=True, save_pred=False, seed=0, use_labels=False, wd=0)
# Runned 10 times
# Val scores: [0.927741031859485, 0.9272113161947824, 0.9271363901359605, 0.9275579074100136, 0.9264291968462317, 0.9275278541203443, 0.9286381790529751, 0.9288245051991526, 0.9269289529175155, 0.9278177920224489]
# Test scores: [0.8754403567694566, 0.8749781870941457, 0.8735933245353141, 0.8759835445000637, 0.8745950242855286, 0.8742530369108132, 0.8784892022402326, 0.873345314887444, 0.8724393129004984, 0.874077975765639]
# Average val score: 0.927581312575891 ± 0.0006953509986591492
# Average test score: 0.8747195279889135 ± 0.001593598488797452
# Number of params: 2475232
# Namespace(attn_drop=0.0, cpu=False, dropout=0.25, edge_drop=0.1, eval_every=5, gpu=7, input_drop=0.1, log_every=5, lr=0.01, n_epochs=1200, n_heads=6, n_hidden=80, n_layers=6, n_runs=10, no_attn_dst=False, plot=True, save_pred=False, seed=0, use_labels=True, wd=0)
# Runned 10 times
# Val scores: [0.9293776332568928, 0.9281066322254939, 0.9286775378440911, 0.9270252685136046, 0.9267937838323375, 0.9277731792338011, 0.9285615428437761, 0.9270819730221879, 0.9276822010553241, 0.9287115722177839]
# Test scores: [0.8761623033485811, 0.8773002619440896, 0.8756680817047869, 0.8751873860287073, 0.875781797307807, 0.8764533839446703, 0.8771202308989311, 0.8765888651476396, 0.8773581283481205, 0.8777751912293709]
# Average val score: 0.9279791324045293 ± 0.0008115348697502517
# Average test score: 0.8765395629902706 ± 0.0008016806017700173
# Number of params: 2484192
import os
import time
import dgl.function as fn
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from ogb.nodeproppred import Evaluator
from ogb.nodeproppred.dataset_dgl import DglNodePropPredDataset
from torch.optim import Adam
from torch.optim.lr_scheduler import ReduceLROnPlateau
from torch.utils.tensorboard import SummaryWriter
from utils import load_model, set_random_seed
def normalize_edge_weights(graph, device, num_ew_channels):
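# Symmetric normalization: divide each edge weight by sqrt(deg(u) * deg(v))
# of its endpoints, then split the result into per-channel edge features.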
degs = graph.in_degrees().float()
degs = torch.clamp(degs, min=1)
norm = torch.pow(degs, 0.5)
norm = norm.to(args["device"])
graph.ndata["norm"] = norm.unsqueeze(1)
graph.apply_edges(fn.e_div_u("feat", "norm", "feat"))
graph.apply_edges(fn.e_div_v("feat", "norm", "feat"))
for channel in range(num_ew_channels):
graph.edata["feat_" + str(channel)] = graph.edata["feat"][:, channel : channel + 1]
def run_a_train_epoch(graph, node_idx, model, criterion, optimizer, evaluator):
model.train()
logits = model(graph)[node_idx]
labels = graph.ndata["labels"][node_idx]
loss = criterion(logits, labels)
optimizer.zero_grad()
loss.backward()
......@@ -39,79 +40,71 @@ def run_a_train_epoch(graph, node_idx, model, criterion, optimizer, evaluator):
labels = labels.cpu().numpy()
preds = logits.cpu().detach().numpy()
return loss, evaluator.eval({"y_true": labels, "y_pred": preds})['rocauc']
return loss, evaluator.eval({"y_true": labels, "y_pred": preds})["rocauc"]
def run_an_eval_epoch(graph, splitted_idx, model, evaluator):
model.eval()
with torch.no_grad():
logits = model(graph)
labels = graph.ndata["labels"].cpu().numpy()
preds = logits.cpu().detach().numpy()
train_score = evaluator.eval({"y_true": labels[splitted_idx["train"]], "y_pred": preds[splitted_idx["train"]]})
val_score = evaluator.eval({"y_true": labels[splitted_idx["valid"]], "y_pred": preds[splitted_idx["valid"]]})
test_score = evaluator.eval({"y_true": labels[splitted_idx["test"]], "y_pred": preds[splitted_idx["test"]]})
return train_score["rocauc"], val_score["rocauc"], test_score["rocauc"]
def main(args):
print(args)
if args["rand_seed"] > -1:
set_random_seed(args["rand_seed"])
dataset = DglNodePropPredDataset(name=args["dataset"])
print(dataset.meta_info)
splitted_idx = dataset.get_idx_split()
graph = dataset.graph[0]
graph.ndata["labels"] = dataset.labels.float().to(args["device"])
graph.edata["feat"] = graph.edata["feat"].float().to(args["device"])
if args["ewnorm"] == "both":
print("Symmetric normalization of edge weights by degree")
normalize_edge_weights(graph, args["device"], args["num_ew_channels"])
elif args["ewnorm"] == "none":
print("Not normalizing edge weights")
for channel in range(args["num_ew_channels"]):
graph.edata["feat_" + str(channel)] = graph.edata["feat"][:, channel : channel + 1]
model = load_model(args).to(args["device"])
optimizer = Adam(model.parameters(), lr=args["lr"], weight_decay=args["weight_decay"])
min_lr = 1e-3
scheduler = ReduceLROnPlateau(optimizer, "max", factor=0.7, patience=100, verbose=True, min_lr=min_lr)
print("scheduler min_lr", min_lr)
criterion = nn.BCEWithLogitsLoss()
evaluator = Evaluator(args["dataset"])
print("model", args["model"])
print("n_layers", args["n_layers"])
print("hidden dim", args["hidden_feats"])
print("lr", args["lr"])
dur = []
best_val_score = 0.0
num_patient_epochs = 0
model_folder = "./saved_models/"
model_path = model_folder + str(args["exp_name"]) + "_" + str(args["postfix"])
if not os.path.exists(model_folder):
os.makedirs(model_folder)
for epoch in range(1, args["num_epochs"] + 1):
if epoch >= 3:
t0 = time.time()
loss, train_score = run_a_train_epoch(graph, splitted_idx["train"], model,
criterion, optimizer, evaluator)
loss, train_score = run_a_train_epoch(graph, splitted_idx["train"], model, criterion, optimizer, evaluator)
if epoch >= 3:
dur.append(time.time() - t0)
......@@ -119,8 +112,7 @@ def main(args):
else:
avg_time = None
train_score, val_score, test_score = run_an_eval_epoch(graph, splitted_idx, model, evaluator)
scheduler.step(val_score)
......@@ -132,50 +124,54 @@ def main(args):
else:
num_patient_epochs += 1
print(
"Epoch {:d}, loss {:.4f}, train score {:.4f}, "
"val score {:.4f}, avg time {}, num patient epochs {:d}".format(
epoch, loss, train_score, val_score, avg_time, num_patient_epochs
)
)
if num_patient_epochs == args["patience"]:
break
model.load_state_dict(torch.load(model_path))
train_score, val_score, test_score = run_an_eval_epoch(graph, splitted_idx, model, evaluator)
print("Train score {:.4f}".format(train_score))
print("Valid score {:.4f}".format(val_score))
print("Test score {:.4f}".format(test_score))
with open("results.txt", "w") as f:
f.write("loss {:.4f}\n".format(loss))
f.write("Best validation rocauc {:.4f}\n".format(best_val_score))
f.write("Test rocauc {:.4f}\n".format(test_score))
print(args)
if __name__ == "__main__":
import argparse
from configure import get_exp_configure
parser = argparse.ArgumentParser(description="OGB node property prediction with DGL using full graph training")
parser.add_argument(
"-m", "--model", type=str, choices=["MWE-GCN", "MWE-DGCN"], default="MWE-DGCN", help="Model to use"
)
parser.add_argument("-c", "--cuda", type=str, default="none")
parser.add_argument("--postfix", type=str, default="", help="a string appended to the file name of the saved model")
parser.add_argument("--rand_seed", type=int, default=-1, help="random seed for torch and numpy")
parser.add_argument("--residual", action="store_true")
parser.add_argument("--ewnorm", type=str, default="none", choices=["none", "both"])
args = parser.parse_args().__dict__
# Get experiment configuration
args["dataset"] = "ogbn-proteins"
args["exp_name"] = "_".join([args["model"], args["dataset"]])
args.update(get_exp_configure(args))
if not (args["cuda"] == "none"):
args["device"] = torch.device("cuda: " + str(args["cuda"]))
else:
args["device"] = torch.device("cpu")
main(args)
import math
from functools import partial
import dgl.function as fn
import dgl.nn.pytorch as dglnn
import torch
import torch.nn as nn
import torch.nn.functional as F
from dgl import DGLGraph
from dgl._ffi.base import DGLError
from dgl.base import ALL
from dgl.nn.pytorch.utils import Identity
from dgl.ops import edge_softmax
from dgl.utils import expand_as_pair
from torch.nn import init
from torch.utils.checkpoint import checkpoint
class MWEConv(nn.Module):
def __init__(self, in_feats, out_feats, activation, bias=True, num_channels=8, aggr_mode="sum"):
super(MWEConv, self).__init__()
self.num_channels = num_channels
self._in_feats = in_feats
......@@ -26,18 +31,18 @@ class MWEConv(nn.Module):
self.reset_parameters()
self.activation = activation
if aggr_mode == "concat":
self.aggr_mode = "concat"
self.final = nn.Linear(out_feats * self.num_channels, out_feats)
elif aggr_mode == "sum":
self.aggr_mode = "sum"
self.final = nn.Linear(out_feats, out_feats)
def reset_parameters(self):
stdv = 1.0 / math.sqrt(self.weight.size(1))
self.weight.data.uniform_(-stdv, stdv)
if self.bias is not None:
stdv = 1.0 / math.sqrt(self.bias.size(0))
self.bias.data.uniform_(-stdv, stdv)
def forward(self, g, node_state_prev):
......@@ -54,20 +59,22 @@ class MWEConv(nn.Module):
for c in range(self.num_channels):
node_state_c = node_state
if self._out_feats < self._in_feats:
g.ndata["feat_" + str(c)] = torch.mm(node_state_c, self.weight[:, :, c])
else:
g.ndata["feat_" + str(c)] = node_state_c
g.update_all(
fn.src_mul_edge("feat_" + str(c), "feat_" + str(c), "m"), fn.sum("m", "feat_" + str(c) + "_new")
)
node_state_c = g.ndata.pop("feat_" + str(c) + "_new")
if self._out_feats >= self._in_feats:
node_state_c = torch.mm(node_state_c, self.weight[:, :, c])
if self.bias is not None:
node_state_c = node_state_c + self.bias[:, c]
node_state_c = self.activation(node_state_c)
new_node_states.append(node_state_c)
if self.aggr_mode == "sum":
node_states = torch.stack(new_node_states, dim=1).sum(1)
elif self.aggr_mode == "concat":
node_states = torch.cat(new_node_states, dim=1)
node_states = self.final(node_states)
......@@ -76,25 +83,15 @@ class MWEConv(nn.Module):
class MWE_GCN(nn.Module):
def __init__(self, n_input, n_hidden, n_output, n_layers, activation, dropout, aggr_mode="sum", device="cpu"):
super(MWE_GCN, self).__init__()
self.dropout = dropout
self.activation = activation
self.layers = nn.ModuleList()
self.layers.append(MWEConv(n_input, n_hidden, activation=activation, aggr_mode=aggr_mode))
for i in range(n_layers - 1):
self.layers.append(MWEConv(n_hidden, n_hidden, activation=activation, aggr_mode=aggr_mode))
self.pred_out = nn.Linear(n_hidden, n_output)
self.device = device
......@@ -112,16 +109,9 @@ class MWE_GCN(nn.Module):
class MWE_DGCN(nn.Module):
def __init__(
self, n_input, n_hidden, n_output, n_layers, activation, dropout, residual=False, aggr_mode="sum", device="cpu"
):
super(MWE_DGCN, self).__init__()
self.n_layers = n_layers
self.activation = activation
......@@ -131,12 +121,10 @@ class MWE_DGCN(nn.Module):
self.layers = nn.ModuleList()
self.layer_norms = nn.ModuleList()
self.layers.append(MWEConv(n_input, n_hidden, activation=activation, aggr_mode=aggr_mode))
for i in range(n_layers - 1):
self.layers.append(MWEConv(n_hidden, n_hidden, activation=activation, aggr_mode=aggr_mode))
for i in range(n_layers):
self.layer_norms.append(nn.LayerNorm(n_hidden, elementwise_affine=True))
......@@ -144,23 +132,22 @@ class MWE_DGCN(nn.Module):
self.pred_out = nn.Linear(n_hidden, n_output)
self.device = device
def forward(self, g, node_state=None):
node_state = torch.ones(g.number_of_nodes(), 1).float().to(self.device)
node_state = self.layers[0](g, node_state)
for layer in range(1, self.n_layers):
node_state_new = self.layer_norms[layer - 1](node_state)
node_state_new = self.activation(node_state_new)
node_state_new = F.dropout(node_state_new, p=self.dropout, training=self.training)
if self.residual == "true":
node_state = node_state + self.layers[layer](g, node_state_new)
else:
node_state = self.layers[layer](g, node_state_new)
node_state = self.layer_norms[self.n_layers - 1](node_state)
node_state = self.activation(node_state)
node_state = F.dropout(node_state, p=self.dropout, training=self.training)
@@ -169,3 +156,249 @@ class MWE_DGCN(nn.Module):
return out
class GATConv(nn.Module):
def __init__(
self,
node_feats,
edge_feats,
out_feats,
n_heads=1,
attn_drop=0.0,
edge_drop=0.0,
negative_slope=0.2,
residual=True,
activation=None,
use_attn_dst=True,
allow_zero_in_degree=True,
use_symmetric_norm=False,
):
super(GATConv, self).__init__()
self._n_heads = n_heads
self._in_src_feats, self._in_dst_feats = expand_as_pair(node_feats)
self._out_feats = out_feats
self._allow_zero_in_degree = allow_zero_in_degree
self._use_symmetric_norm = use_symmetric_norm
# feat fc
self.src_fc = nn.Linear(self._in_src_feats, out_feats * n_heads, bias=False)
if residual:
self.dst_fc = nn.Linear(self._in_src_feats, out_feats * n_heads)
self.bias = None
else:
self.dst_fc = None
            # nn.Parameter expects a tensor, not an int; use a per-head bias
            # that broadcasts against the (N, n_heads, out_feats) output.
            self.bias = nn.Parameter(torch.zeros(n_heads, out_feats))
# attn fc
self.attn_src_fc = nn.Linear(self._in_src_feats, n_heads, bias=False)
if use_attn_dst:
self.attn_dst_fc = nn.Linear(self._in_src_feats, n_heads, bias=False)
else:
self.attn_dst_fc = None
if edge_feats > 0:
self.attn_edge_fc = nn.Linear(edge_feats, n_heads, bias=False)
else:
self.attn_edge_fc = None
self.attn_drop = nn.Dropout(attn_drop)
self.edge_drop = edge_drop
self.leaky_relu = nn.LeakyReLU(negative_slope, inplace=True)
self.activation = activation
self.reset_parameters()
def reset_parameters(self):
gain = nn.init.calculate_gain("relu")
nn.init.xavier_normal_(self.src_fc.weight, gain=gain)
if self.dst_fc is not None:
nn.init.xavier_normal_(self.dst_fc.weight, gain=gain)
nn.init.xavier_normal_(self.attn_src_fc.weight, gain=gain)
if self.attn_dst_fc is not None:
nn.init.xavier_normal_(self.attn_dst_fc.weight, gain=gain)
if self.attn_edge_fc is not None:
nn.init.xavier_normal_(self.attn_edge_fc.weight, gain=gain)
if self.bias is not None:
nn.init.zeros_(self.bias)
def set_allow_zero_in_degree(self, set_value):
self._allow_zero_in_degree = set_value
def forward(self, graph, feat_src, feat_edge=None):
with graph.local_scope():
if not self._allow_zero_in_degree:
if (graph.in_degrees() == 0).any():
                    raise RuntimeError(
                        "There are 0-in-degree nodes in the graph; their output "
                        "will be invalid. Set allow_zero_in_degree=True to "
                        "suppress this check."
                    )
if graph.is_block:
feat_dst = feat_src[: graph.number_of_dst_nodes()]
else:
feat_dst = feat_src
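            # Symmetric normalization (a GCN-style trick): scale inputs by
            # deg^{-1/2} here and rescale the aggregated output by deg^{+1/2}
            # below. Combined with the row-normalizing edge softmax, this
            # approximates the symmetric normalization D^{-1/2} A D^{-1/2}.
            # The "deg" field is expected to be filled in by the training script.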
if self._use_symmetric_norm:
degs = graph.srcdata["deg"]
# degs = graph.out_degrees().float().clamp(min=1)
norm = torch.pow(degs, -0.5)
shp = norm.shape + (1,) * (feat_src.dim() - 1)
norm = torch.reshape(norm, shp)
feat_src = feat_src * norm
feat_src_fc = self.src_fc(feat_src).view(-1, self._n_heads, self._out_feats)
            if self.dst_fc is not None:
                feat_dst_fc = self.dst_fc(feat_dst).view(-1, self._n_heads, self._out_feats)
attn_src = self.attn_src_fc(feat_src).view(-1, self._n_heads, 1)
# NOTE: GAT paper uses "first concatenation then linear projection"
# to compute attention scores, while ours is "first projection then
# addition", the two approaches are mathematically equivalent:
# We decompose the weight vector a mentioned in the paper into
# [a_l || a_r], then
# a^T [Wh_i || Wh_j] = a_l Wh_i + a_r Wh_j
            # Our implementation is much more efficient because we do not need
            # to save [Wh_i || Wh_j] on edges, which is not memory-efficient.
            # Plus, the addition can be optimized with DGL's built-in function
            # u_add_v, which further speeds up computation and reduces the
            # memory footprint.
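            # Worked example of the decomposition (illustrative numbers): with
            # one head, out_feats = 2, a = [a_l || a_r] = [1, 2, 3, 4],
            # Wh_i = [5, 6] and Wh_j = [7, 8]:
            #   a^T [Wh_i || Wh_j] = 1*5 + 2*6 + 3*7 + 4*8 = 70
            #   a_l Wh_i + a_r Wh_j = (5 + 12) + (21 + 32) = 70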
graph.srcdata.update({"feat_src_fc": feat_src_fc, "attn_src": attn_src})
if self.attn_dst_fc is not None:
attn_dst = self.attn_dst_fc(feat_dst).view(-1, self._n_heads, 1)
graph.dstdata.update({"attn_dst": attn_dst})
graph.apply_edges(fn.u_add_v("attn_src", "attn_dst", "attn_node"))
else:
graph.apply_edges(fn.copy_u("attn_src", "attn_node"))
e = graph.edata["attn_node"]
if feat_edge is not None:
attn_edge = self.attn_edge_fc(feat_edge).view(-1, self._n_heads, 1)
graph.edata.update({"attn_edge": attn_edge})
e += graph.edata["attn_edge"]
e = self.leaky_relu(e)
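            # Edge dropout: keep a random (1 - edge_drop) fraction of edges,
            # compute the softmax only over the kept edges, and leave the
            # attention of dropped edges at zero so they pass no message.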
if self.training and self.edge_drop > 0:
perm = torch.randperm(graph.number_of_edges(), device=e.device)
bound = int(graph.number_of_edges() * self.edge_drop)
eids = perm[bound:]
graph.edata["a"] = torch.zeros_like(e)
graph.edata["a"][eids] = self.attn_drop(edge_softmax(graph, e[eids], eids=eids))
else:
graph.edata["a"] = self.attn_drop(edge_softmax(graph, e))
# message passing
graph.update_all(fn.u_mul_e("feat_src_fc", "a", "m"), fn.sum("m", "feat_src_fc"))
rst = graph.dstdata["feat_src_fc"]
if self._use_symmetric_norm:
degs = graph.dstdata["deg"]
# degs = graph.in_degrees().float().clamp(min=1)
norm = torch.pow(degs, 0.5)
shp = norm.shape + (1,) * (feat_dst.dim())
norm = torch.reshape(norm, shp)
rst = rst * norm
# residual
if self.dst_fc is not None:
rst += feat_dst_fc
else:
rst += self.bias
# activation
if self.activation is not None:
rst = self.activation(rst, inplace=True)
return rst
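# A minimal smoke test for GATConv -- a hedged sketch, not part of the original
# training pipeline. `dgl.rand_graph` is assumed to be available (DGL 0.5+);
# with the default use_symmetric_norm=False, no "deg" field is required.
def _gatconv_smoke_test():
    import dgl

    g = dgl.rand_graph(4, 8)  # 4 nodes, 8 random edges
    feat = torch.randn(4, 16)
    conv = GATConv(16, 0, out_feats=8, n_heads=2, use_attn_dst=False)
    out = conv(g, feat)
    assert out.shape == (4, 2, 8)  # one 8-dim output per head per node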
class GAT(nn.Module):
def __init__(
self,
node_feats,
edge_feats,
n_classes,
n_layers,
n_heads,
n_hidden,
edge_emb,
activation,
dropout,
input_drop,
attn_drop,
edge_drop,
use_attn_dst=True,
allow_zero_in_degree=False,
):
super().__init__()
self.n_layers = n_layers
self.n_heads = n_heads
self.n_hidden = n_hidden
self.n_classes = n_classes
self.convs = nn.ModuleList()
self.norms = nn.ModuleList()
self.node_encoder = nn.Linear(node_feats, n_hidden)
        if edge_emb > 0:
            self.edge_encoder = nn.ModuleList()
        else:
            # forward() tests `self.edge_encoder is not None`, so define the
            # attribute explicitly when no edge features are used.
            self.edge_encoder = None
for i in range(n_layers):
in_hidden = n_heads * n_hidden if i > 0 else n_hidden
out_hidden = n_hidden
# bias = i == n_layers - 1
if edge_emb > 0:
self.edge_encoder.append(nn.Linear(edge_feats, edge_emb))
self.convs.append(
GATConv(
in_hidden,
edge_emb,
out_hidden,
n_heads=n_heads,
attn_drop=attn_drop,
edge_drop=edge_drop,
use_attn_dst=use_attn_dst,
allow_zero_in_degree=allow_zero_in_degree,
use_symmetric_norm=False,
)
)
self.norms.append(nn.BatchNorm1d(n_heads * out_hidden))
self.pred_linear = nn.Linear(n_heads * n_hidden, n_classes)
self.input_drop = nn.Dropout(input_drop)
self.dropout = nn.Dropout(dropout)
self.activation = activation
def forward(self, g):
if not isinstance(g, list):
subgraphs = [g] * self.n_layers
else:
subgraphs = g
h = subgraphs[0].srcdata["feat"]
h = self.node_encoder(h)
h = F.relu(h, inplace=True)
h = self.input_drop(h)
h_last = None
for i in range(self.n_layers):
if self.edge_encoder is not None:
efeat = subgraphs[i].edata["feat"]
efeat_emb = self.edge_encoder[i](efeat)
efeat_emb = F.relu(efeat_emb, inplace=True)
else:
efeat_emb = None
h = self.convs[i](subgraphs[i], h, efeat_emb).flatten(1, -1)
if h_last is not None:
h += h_last[: h.shape[0], :]
h_last = h
h = self.norms[i](h)
h = self.activation(h, inplace=True)
h = self.dropout(h)
h = self.pred_linear(h)
return h
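# Hedged usage sketch for the GAT model above (illustrative values, not the
# tuned ogbn-proteins hyperparameters; `dgl.rand_graph` is an assumption):
def _gat_usage_sketch():
    import dgl

    g = dgl.rand_graph(10, 40)
    g.ndata["feat"] = torch.randn(10, 5)  # node features
    g.edata["feat"] = torch.randn(40, 8)  # edge features
    model = GAT(
        node_feats=5, edge_feats=8, n_classes=3, n_layers=2, n_heads=2,
        n_hidden=16, edge_emb=4, activation=F.relu, dropout=0.1,
        input_drop=0.1, attn_drop=0.0, edge_drop=0.0,
        allow_zero_in_degree=True,  # rand_graph may create 0-in-degree nodes
    )
    return model(g)  # logits of shape (10, 3)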
import random

import numpy as np
import torch
import torch.nn.functional as F

from models import MWE_DGCN, MWE_GCN
def set_random_seed(seed):
random.seed(seed)
np.random.seed(seed)
torch.manual_seed(seed)
if torch.cuda.is_available():
torch.cuda.manual_seed(seed)
print("random seed set to be " + str(seed))
def load_model(args):
if args["model"] == "MWE-GCN":
model = MWE_GCN(
n_input=args["in_feats"],
n_hidden=args["hidden_feats"],
n_output=args["out_feats"],
n_layers=args["n_layers"],
activation=torch.nn.Tanh(),
dropout=args["dropout"],
aggr_mode=args["aggr_mode"],
device=args["device"],
)
elif args["model"] == "MWE-DGCN":
model = MWE_DGCN(
n_input=args["in_feats"],
n_hidden=args["hidden_feats"],
n_output=args["out_feats"],
n_layers=args["n_layers"],
activation=torch.nn.ReLU(),
dropout=args["dropout"],
aggr_mode=args["aggr_mode"],
residual=args["residual"],
device=args["device"],
)
else:
raise ValueError("Unexpected model {}".format(args["model"]))
return model
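# Hedged example of the expected args layout (keys follow load_model above;
# the values here are illustrative only, not the tuned hyperparameters):
#
#   args = {
#       "model": "MWE-DGCN",
#       "in_feats": 1, "hidden_feats": 128, "out_feats": 112, "n_layers": 3,
#       "dropout": 0.5, "aggr_mode": "sum", "residual": "true",
#       "device": "cuda:0",
#   }
#   model = load_model(args)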
class Logger(object):
def __init__(self, runs, info=None):
self.info = info
@@ -53,11 +60,11 @@ class Logger(object):
if run is not None:
result = 100 * torch.tensor(self.results[run])
argmax = result[:, 1].argmax().item()
print(f"Run {run + 1:02d}:")
print(f"Highest Train: {result[:, 0].max():.2f}")
print(f"Highest Valid: {result[:, 1].max():.2f}")
print(f" Final Train: {result[argmax, 0]:.2f}")
print(f" Final Test: {result[argmax, 2]:.2f}")
else:
result = 100 * torch.tensor(self.results)
@@ -71,12 +78,39 @@ class Logger(object):
best_result = torch.tensor(best_results)
            print("All runs:")
r = best_result[:, 0]
print(f"Highest Train: {r.mean():.2f} ± {r.std():.2f}")
r = best_result[:, 1]
print(f"Highest Valid: {r.mean():.2f} ± {r.std():.2f}")
r = best_result[:, 2]
print(f" Final Train: {r.mean():.2f} ± {r.std():.2f}")
r = best_result[:, 3]
print(f" Final Test: {r.mean():.2f} ± {r.std():.2f}")
class DataLoaderWrapper(object):
def __init__(self, dataloader):
self.iter = iter(dataloader)
def __iter__(self):
return self
def __next__(self):
try:
return next(self.iter)
except Exception:
raise StopIteration() from None
class BatchSampler(object):
def __init__(self, n, batch_size):
self.n = n
self.batch_size = batch_size
def __iter__(self):
while True:
shuf = torch.randperm(self.n).split(self.batch_size)
for shuf_batch in shuf:
yield shuf_batch
yield None
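# How the two helpers above cooperate (a sketch of the assumed protocol):
# BatchSampler is infinite; after yielding the index batches of one epoch it
# yields a single None. Downstream, the None makes the wrapped dataloader
# raise, and DataLoaderWrapper turns that exception into StopIteration, so a
# plain `for` loop over DataLoaderWrapper(dataloader) sees exactly one epoch.
#
#   sampler = BatchSampler(n=1024, batch_size=256)
#   it = iter(sampler)
#   [next(it) for _ in range(5)]  # four index tensors, then None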